TensorFlow Estimator Model Definition

EstimatorTrial Interface

class pedl.frameworks.tensorflow.estimator_trial.EstimatorTrial(hparams: Dict[str, Any], *_: Any, **__: Any)
abstract build_estimator(hparams: Dict[str, Any]) → tensorflow_estimator.python.estimator.estimator.Estimator

Specifies the tf.estimator.Estimator instance to be used during training and validation. This may be an instance of a Premade Estimator provided by the TensorFlow team, or a Custom Estimator created by the user.
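For example, a minimal build_estimator() on an EstimatorTrial subclass might return a premade estimator. This is a sketch only; the feature column name, its shape, and the hidden_size hyperparameter are illustrative, not part of the PEDL API.

import tensorflow as tf

def build_estimator(self, hparams):
    # Illustrative premade estimator; any tf.estimator.Estimator works here.
    feature_columns = [tf.feature_column.numeric_column("x", shape=(4,))]
    return tf.estimator.DNNClassifier(
        feature_columns=feature_columns,
        hidden_units=[hparams["hidden_size"]],
        n_classes=3,
    )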

abstract build_train_spec(hparams: Dict[str, Any]) → tensorflow_estimator.python.estimator.training.TrainSpec

Specifies the tf.estimator.TrainSpec to be used for training steps. This training specification will contain a TensorFlow input_fn which constructs the input data for a training step. Unlike the standard TensorFlow input_fn interface, EstimatorTrial only supports an input_fn that returns a tf.data.Dataset object; a function that returns a tuple of features and labels is currently not supported. Additionally, the max_steps attribute of the training specification will be ignored; instead, the batches_per_step option in the experiment configuration determines how many batches each training step uses.

Note

When doing distributed training or optimized_parallel single-machine training of an Estimator model, see Data Downloading if build_train_spec() downloads the entire data set.
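A sketch of a matching build_train_spec(), assuming a small in-memory data set and a batch_size hyperparameter (both illustrative). The dataset is passed through wrap_dataset(); see Required Wrappers below.

import numpy as np
import tensorflow as tf
from pedl.frameworks.tensorflow.estimator_wrap import wrap_dataset

def build_train_spec(self, hparams):
    def input_fn():
        # Illustrative in-memory data; real code would load its own dataset.
        features = {"x": np.random.rand(1000, 4).astype(np.float32)}
        labels = np.random.randint(0, 3, size=(1000,))
        dataset = tf.data.Dataset.from_tensor_slices((features, labels))
        dataset = dataset.shuffle(1000).repeat().batch(hparams["batch_size"])
        return wrap_dataset(dataset)

    # max_steps is ignored; batches_per_step in the experiment
    # configuration controls how many batches each training step uses.
    return tf.estimator.TrainSpec(input_fn=input_fn)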

abstract build_validation_spec(hparams: Dict[str, Any]) → tensorflow_estimator.python.estimator.training.EvalSpec

Specifies the tf.estimator.EvalSpec to be used for validation steps. This evaluation spec will contain a TensorFlow input_fn which constructs the input data for a validation step. The validation step will evaluate the number of batches given by the steps attribute of the EvalSpec, or, if steps is None, evaluate until the input_fn raises an end-of-input exception.

Note

When doing distributed training or optimized_parallel single-machine training of an Estimator model, see Data Downloading if build_validation_spec() downloads the entire data set.
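A corresponding sketch of build_validation_spec(), using the same illustrative feature names as above. With steps=None, evaluation runs until the dataset is exhausted.

import numpy as np
import tensorflow as tf
from pedl.frameworks.tensorflow.estimator_wrap import wrap_dataset

def build_validation_spec(self, hparams):
    def input_fn():
        # Illustrative in-memory data; real code would load its own dataset.
        features = {"x": np.random.rand(200, 4).astype(np.float32)}
        labels = np.random.randint(0, 3, size=(200,))
        dataset = tf.data.Dataset.from_tensor_slices((features, labels))
        return wrap_dataset(dataset.batch(hparams["batch_size"]))

    # steps=None evaluates until input_fn raises an end-of-input exception.
    return tf.estimator.EvalSpec(input_fn=input_fn, steps=None)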

build_serving_input_receiver_fns(hparams: Dict[str, Any]) → Dict[str, Callable[..., Union[tensorflow_estimator.python.estimator.export.export.ServingInputReceiver, tensorflow_estimator.python.estimator.export.export.TensorServingInputReceiver]]]

Optionally returns a Python dictionary mapping string names to serving_input_receiver_fns. If specified, each serving input receiver function will be used to export a distinct SavedModel inference graph when a PEDL checkpoint is saved, using Estimator.export_saved_model. The exported models are saved under subdirectories named by the keys of the respective serving input receiver functions. For example, returning

{
    "raw": tf.estimator.export.build_raw_serving_input_receiver_fn(...),
    "parsing": tf.estimator.export.build_parsing_serving_input_receiver_fn(...)
}

from this function would configure PEDL to export two SavedModel inference graphs in every checkpoint under raw and parsing subdirectories, respectively. By default, this function returns an empty dictionary and the PEDL checkpoint directory only contains metadata associated with the training graph.
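A minimal sketch of how the "raw" receiver above might be built, using TF1-style placeholders; the feature name and shape are illustrative.

import tensorflow as tf

def build_serving_input_receiver_fns(self, hparams):
    # Placeholder feature spec; adjust names, dtypes, and shapes to the model.
    features = {"x": tf.placeholder(dtype=tf.float32, shape=(None, 4), name="x")}
    return {
        "raw": tf.estimator.export.build_raw_serving_input_receiver_fn(features),
    }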

Required Wrappers

To use EstimatorTrial, users need to wrap their optimizers and datasets using the following PEDL-provided wrappers.

pedl.frameworks.tensorflow.estimator_wrap.wrap_optimizer(optimizer: Any) → Any

This should be used to wrap optimizer objects immediately after they have been created. Users should use the output of this wrapper as the new instance of their optimizer. For example, if users create their optimizer within build_estimator(), they should call optimizer = wrap_optimizer(optimizer) prior to passing the optimizer into their Estimator.
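A sketch of this inside build_estimator(), assuming an Adam optimizer and a learning_rate hyperparameter; the feature column and layer sizes are illustrative, as in the earlier sketch.

import tensorflow as tf
from pedl.frameworks.tensorflow.estimator_wrap import wrap_optimizer

def build_estimator(self, hparams):
    optimizer = tf.train.AdamOptimizer(learning_rate=hparams["learning_rate"])
    optimizer = wrap_optimizer(optimizer)  # Wrap immediately after creation.
    return tf.estimator.DNNClassifier(
        feature_columns=[tf.feature_column.numeric_column("x", shape=(4,))],
        hidden_units=[64],
        n_classes=3,
        optimizer=optimizer,
    )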

pedl.frameworks.tensorflow.estimator_wrap.wrap_dataset(dataset: Any) → Any

This should be used to wrap tf.data.Dataset objects immediately after they have been created. Users should use the output of this wrapper as the new instance of their dataset. If users create multiple datasets (e.g., one for training and one for testing), they should wrap each dataset independently. For example, if users instantiate their training dataset within build_train_spec(), they should call dataset = wrap_dataset(dataset) prior to passing it into tf.estimator.TrainSpec.
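For instance, with separate training and validation datasets (the variable names below are illustrative), each one is wrapped on its own:

from pedl.frameworks.tensorflow.estimator_wrap import wrap_dataset

# In build_train_spec():
train_dataset = wrap_dataset(train_dataset)

# In build_validation_spec():
validation_dataset = wrap_dataset(validation_dataset)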

Examples for EstimatorTrial