This document guides you through training an Estimator model in Determined. You need to implement a
trial class that inherits from
EstimatorTrial and specify it as the
entrypoint in the experiment configuration.
To learn about this API, you can start by reading the trial definitions from the following examples:
Define Optimizer and Datasets
Before loading data, read Prepare Data to understand how to work with different data sources.
To use tf.estimator models with Determined, users need to wrap their optimizer and datasets using
wrap_optimizer() and wrap_dataset(). Note that these functions are found on the trial's context
object (self.context inside the trial class).
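The call pattern for wrapping can be sketched without TensorFlow or Determined installed. In the sketch below, StandInContext is a hypothetical stand-in for the trial's context object, and the dictionary and list are placeholders for a real optimizer and a real dataset; only the method names wrap_optimizer() and wrap_dataset() come from this document. The point it illustrates is that the training code should use the object returned by each wrap call rather than the original.

```python
class StandInContext:
    """Hypothetical stand-in for the trial context; NOT the Determined class."""

    def __init__(self) -> None:
        self.calls = []

    def wrap_optimizer(self, optimizer):
        # The stand-in only records the call and returns the object unchanged.
        self.calls.append("wrap_optimizer")
        return optimizer

    def wrap_dataset(self, dataset):
        # The stand-in only records the call and returns the object unchanged.
        self.calls.append("wrap_dataset")
        return dataset


def build_optimizer(context, learning_rate):
    # Placeholder for constructing a real TensorFlow optimizer.
    optimizer = {"name": "sgd", "learning_rate": learning_rate}
    # Use the wrapped object from here on, not the original.
    return context.wrap_optimizer(optimizer)


def make_input_dataset(context, records):
    # Placeholder for constructing a real tf.data.Dataset.
    dataset = list(records)
    # Use the wrapped object from here on, not the original.
    return context.wrap_dataset(dataset)
```

In a real trial, the same two calls would be made on self.context when the estimator and its input functions are built.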
Determined supports proper reduction of arbitrary validation metrics during distributed training by
allowing users to define custom reducers for their metrics. Custom reducers can be either a function
or an implementation of a reducer interface. See context.make_metric() for more
details.
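To see why a custom reducer matters, consider a mean computed over workers that process different numbers of examples: averaging the per-worker means gives the wrong answer. The sketch below is plain Python and does not import Determined. The class WeightedMeanReducer and the helper simulate_distributed_validation are hypothetical, and the method names accumulate, cumulative, and reduce are an assumption about the shape of a class-based reducer that should be checked against the API reference.

```python
from typing import List, Tuple


class WeightedMeanReducer:
    """Hypothetical class-based reducer computing an example-weighted mean."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def accumulate(self, metric: Tuple[float, int]) -> None:
        # Called once per batch on each worker with (sum of values, batch size).
        batch_sum, batch_size = metric
        self.total += batch_sum
        self.count += batch_size

    def cumulative(self) -> Tuple[float, int]:
        # The per-worker partial result that is shared across workers.
        return self.total, self.count

    def reduce(self, per_worker: List[Tuple[float, int]]) -> float:
        # Combine the partial results from all workers into the final metric.
        total = sum(t for t, _ in per_worker)
        count = sum(c for _, c in per_worker)
        return total / count if count else 0.0


def simulate_distributed_validation(worker_batches) -> float:
    # Each inner list holds the (sum, size) pairs seen by one worker.
    per_worker = []
    for batches in worker_batches:
        reducer = WeightedMeanReducer()
        for batch in batches:
            reducer.accumulate(batch)
        per_worker.append(reducer.cumulative())
    return WeightedMeanReducer().reduce(per_worker)
```

With one worker seeing batches (6.0, 3) and (2.0, 1) and another seeing (10.0, 2), the reducer returns 18.0 / 6 = 3.0, whereas averaging the two per-worker means (2.0 and 5.0) would give 3.5.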
A checkpoint includes the model definition (Python source code), experiment configuration file, network architecture, and the values of the model’s parameters (i.e., weights) and hyperparameters. When using a stateful optimizer during training, checkpoints will also include the state of the optimizer (e.g., learning rate). Users can also embed arbitrary metadata in checkpoints via a Python API.
TensorFlow Estimator trials are checkpointed using the SavedModel format. Please consult the TensorFlow documentation for details on how to restore models from the SavedModel format.
To execute arbitrary Python code during the lifecycle of an EstimatorTrial, use
determined.estimator.RunHook, which extends tf.estimator.SessionRunHook. When utilizing
determined.estimator.RunHook, users can use native estimator hooks inherited from
tf.estimator.SessionRunHook and Determined hooks such as on_checkpoint_end().
Example usage of
determined.estimator.RunHook that adds custom metadata to checkpoints:
import os

import tensorflow as tf

import determined.estimator


class MyHook(determined.estimator.RunHook):
    def __init__(self, context, metadata) -> None:
        self._context = context
        self._metadata = metadata

    def on_checkpoint_end(self, checkpoint_dir) -> None:
        # Write the custom metadata into the checkpoint directory.
        with open(os.path.join(checkpoint_dir, "metadata.txt"), "w") as fp:
            fp.write(self._metadata)


class MyEstimatorTrial(determined.estimator.EstimatorTrial):
    ...

    def build_train_spec(self) -> tf.estimator.TrainSpec:
        # make_input_fn() is assumed to be defined elsewhere in the model code.
        return tf.estimator.TrainSpec(
            make_input_fn(),
            hooks=[MyHook(self.context, "my_metadata")],
        )