Estimator API

API reference: Estimator API Reference

This document guides you through training an Estimator model in Determined. You need to implement a trial class that inherits from EstimatorTrial and specify it as the entrypoint in the experiment configuration.

To learn about this API, you can start by reading the trial definitions from the following examples:

Define Optimizer and Datasets

Note

Before loading data, read the Prepare Data document to learn how to work with different data sources.

To use tf.estimator models with Determined, wrap your optimizer and datasets with wrap_optimizer() and wrap_dataset(). Both functions are methods of the trial's context object, an instance of determined.estimator.EstimatorTrialContext.

Reduce Metrics

During distributed training, Determined can correctly reduce arbitrary validation metrics across slots if you define a custom reducer for each metric. A custom reducer can be either a function or an implementation of the determined.estimator.MetricReducer interface.

See context.make_metric() for more details.
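The two reducer forms can be sketched in plain Python. This is a hypothetical, standalone illustration: it does not import Determined, the class only mirrors the accumulate() / cross_slot_reduce() shape of MetricReducer rather than subclassing it, and it assumes the function form receives a flat list of per-batch values. It shows why a custom reducer matters: averaging per-slot accuracies gives a slot with 20 examples the same weight as a slot with 30, while summing the raw counts first does not.

```python
from typing import List, Tuple


def mean_reducer(values: List[float]) -> float:
    """Function-style reducer (assumed input: every per-batch value from every slot)."""
    return sum(values) / len(values) if values else 0.0


class AccuracyReducer:
    """Hypothetical stand-in with the same two-step shape as MetricReducer.

    accumulate() runs once per batch on each slot and returns that slot's
    running (correct, total) state; cross_slot_reduce() combines the final
    state from every slot into one metric.
    """

    def __init__(self) -> None:
        self.correct = 0
        self.total = 0

    def accumulate(self, metric: Tuple[int, int]) -> Tuple[int, int]:
        correct, total = metric
        self.correct += correct
        self.total += total
        return self.correct, self.total

    def cross_slot_reduce(self, per_slot_metrics: List[Tuple[int, int]]) -> float:
        # Sum raw counts before dividing so slots with more examples weigh more.
        correct = sum(c for c, _ in per_slot_metrics)
        total = sum(t for _, t in per_slot_metrics)
        return correct / total if total else 0.0


# Simulate two slots that evaluate different numbers of examples.
slot_a, slot_b = AccuracyReducer(), AccuracyReducer()
for batch in [(9, 10), (1, 10)]:
    state_a = slot_a.accumulate(batch)  # ends at (10, 20): 50% accurate
for batch in [(30, 30)]:
    state_b = slot_b.accumulate(batch)  # ends at (30, 30): 100% accurate

overall = slot_a.cross_slot_reduce([state_a, state_b])
naive = (0.5 + 1.0) / 2  # mean of per-slot accuracies, ignoring slot size
print(overall, naive)  # 0.8 0.75
```

The naive mean of per-slot accuracies (0.75) understates the true accuracy over all 50 examples (0.8), which is the kind of error a custom reducer avoids.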

Checkpointing

A checkpoint includes the model definition (Python source code), the experiment configuration file, the network architecture, and the values of the model's parameters (i.e., weights) and hyperparameters. When training with a stateful optimizer, checkpoints also include the optimizer's state (e.g., momentum or moving-average estimates). Users can also embed arbitrary metadata in checkpoints via a Python API.

TensorFlow Estimator trials are checkpointed using the SavedModel format. Please consult the TensorFlow documentation for details on how to restore models from the SavedModel format.

Callbacks

To execute arbitrary Python code during the lifecycle of an EstimatorTrial, use determined.estimator.RunHook, which extends tf.estimator.SessionRunHook. A RunHook can implement both native Estimator hook methods, such as before_run(), and Determined-specific hooks, such as on_checkpoint_end().

The following example uses determined.estimator.RunHook to add custom metadata to checkpoints:

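import os

import tensorflow as tf

import determined.estimator


class MyHook(determined.estimator.RunHook):
    """Writes user-supplied metadata into every checkpoint directory."""

    def __init__(self, context, metadata: str) -> None:
        self._context = context
        self._metadata = metadata

    def on_checkpoint_end(self, checkpoint_dir: str) -> None:
        # Invoked at the end of checkpointing; add a file next to the checkpoint.
        with open(os.path.join(checkpoint_dir, "metadata.txt"), "w") as fp:
            fp.write(self._metadata)


class MyEstimatorTrial(determined.estimator.EstimatorTrial):
    ...

    def build_train_spec(self) -> tf.estimator.TrainSpec:
        # make_input_fn() is user-defined and returns an Estimator input_fn.
        return tf.estimator.TrainSpec(
            make_input_fn(),
            hooks=[MyHook(self.context, "my_metadata")],
        )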
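Because a hook like MyHook writes an ordinary file, the metadata can be read back from a local copy of the checkpoint directory with plain file I/O. The sketch below is standalone: read_checkpoint_metadata() is a hypothetical helper, and the temporary directory only stands in for a real checkpoint directory, populated the same way MyHook.on_checkpoint_end() writes it.

```python
import os
import tempfile


def read_checkpoint_metadata(checkpoint_dir: str) -> str:
    """Hypothetical helper: return the contents of metadata.txt in a checkpoint dir."""
    path = os.path.join(checkpoint_dir, "metadata.txt")
    with open(path) as fp:
        return fp.read()


# Stand-in for a checkpoint directory: write the file as the hook would, then read it.
with tempfile.TemporaryDirectory() as ckpt_dir:
    with open(os.path.join(ckpt_dir, "metadata.txt"), "w") as fp:
        fp.write("my_metadata")
    restored = read_checkpoint_metadata(ckpt_dir)

print(restored)  # my_metadata
```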