Estimator API

Warning

EstimatorTrial is deprecated and will be removed in a future version. Since the release of TensorFlow 2.0, TensorFlow has advised Estimator users to switch to Keras. Consequently, we recommend that users of EstimatorTrial switch to the TFKerasTrial class.

In this guide, you’ll learn how to use the Estimator API.

For the full API, see the det.estimator API Reference.

This document guides you through training an Estimator model in Determined. You need to implement a trial class that inherits from EstimatorTrial and specify it as the entrypoint in the experiment configuration.

Define Optimizer and Datasets

Note

Before loading data, read Prepare Data to understand how to work with different data sources.

To use tf.estimator models with Determined, wrap your optimizer and datasets with wrap_optimizer() and wrap_dataset(), respectively. Both functions are methods of the trial context, determined.estimator.EstimatorTrialContext.

Reduce Metrics

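During distributed training, some validation metrics cannot be computed correctly by combining the final per-worker results; for example, the ROC AUC of a full dataset cannot be derived from the ROC AUCs of its shards. Determined handles this by letting you define a custom reducer for each such metric. A custom reducer can be either a function or an implementation of the determined.estimator.MetricReducer interface.

To create a metric that uses a custom reducer, call context.make_metric(); see its API reference for details.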
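The two reducer styles can be sketched in plain Python. This is a minimal illustration, not Determined's implementation: median_reducer and MeanReducer are hypothetical names, the function form assumes each gathered value is a scalar float, and the class only mirrors the accumulate() and cross_slot_reduce() method names that we understand determined.estimator.MetricReducer to define. In a real trial, the class would subclass that interface and either reducer would be passed to context.make_metric(); check the API reference for the exact contract.

```python
from typing import List, Tuple


# Hypothetical function-style reducer: receives the metric values gathered
# from all batches on all workers (assumed here to be scalar floats) and
# returns a single number, in this case the median.
def median_reducer(values: List[float]) -> float:
    ordered = sorted(values)
    n = len(ordered)
    mid = n // 2
    if n % 2:
        return ordered[mid]
    return (ordered[mid - 1] + ordered[mid]) / 2


# Hypothetical class-style reducer mirroring the MetricReducer shape:
# accumulate() runs per batch on each slot, cross_slot_reduce() merges the
# final per-slot states. In a real trial this would subclass
# determined.estimator.MetricReducer.
class MeanReducer:
    def __init__(self) -> None:
        # Running (sum, count) for the batches seen on this slot.
        self._sum = 0.0
        self._count = 0

    def accumulate(self, metric: List[float]) -> Tuple[float, int]:
        # Add one batch of values and return the accumulated state.
        self._sum += sum(metric)
        self._count += len(metric)
        return self._sum, self._count

    def cross_slot_reduce(self, per_slot_metrics: List[Tuple[float, int]]) -> float:
        # Merge every slot's final (sum, count) into one global mean.
        total = sum(s for s, _ in per_slot_metrics)
        count = sum(c for _, c in per_slot_metrics)
        return total / count if count else float("nan")
```

The class carries a running sum and count rather than a per-slot mean because slots may process different numbers of examples. With slot states (6.0, 3) and (10.0, 1), the merged mean is 16 / 4 = 4.0, whereas averaging the per-slot means (2.0 and 10.0) would give 6.0.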

Checkpointing

A checkpoint includes the model definition (Python source code), the experiment configuration file, the network architecture, the values of the model's parameters (i.e., weights), and the hyperparameters. When you train with a stateful optimizer, checkpoints also include the optimizer's state (e.g., the learning rate). You can also embed arbitrary metadata in checkpoints via the Python SDK.

TensorFlow Estimator trials are checkpointed using the SavedModel format. Please consult the TensorFlow documentation for details on how to restore models from the SavedModel format.

Callbacks

To execute arbitrary Python code during the lifecycle of an EstimatorTrial, use determined.estimator.RunHook, which extends tf.estimator.SessionRunHook. A RunHook can implement both native Estimator hooks, such as before_run(), and Determined hooks, such as on_checkpoint_end().

The following example uses determined.estimator.RunHook to add custom metadata to checkpoints:

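import os

import tensorflow as tf

import determined.estimator


class MyHook(determined.estimator.RunHook):
    """Writes a metadata file into every checkpoint directory."""

    def __init__(self, context, metadata) -> None:
        self._context = context
        self._metadata = metadata

    def on_checkpoint_end(self, checkpoint_dir) -> None:
        # Runs after Determined finishes writing a checkpoint to checkpoint_dir.
        with open(os.path.join(checkpoint_dir, "metadata.txt"), "w") as fp:
            fp.write(self._metadata)


class MyEstimatorTrial(determined.estimator.EstimatorTrial):
    ...

    def build_train_spec(self) -> tf.estimator.TrainSpec:
        # make_input_fn() is a placeholder for your own input function.
        return tf.estimator.TrainSpec(
            make_input_fn(),
            hooks=[MyHook(self.context, "my_metadata")],
        )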