Estimator API¶
This document guides you through training an Estimator model in Determined. You need to implement a trial class that inherits from EstimatorTrial and specify it as the entrypoint in the experiment configuration.
To learn about this API, start by reading the trial definitions in Determined's Estimator examples.
Define Optimizer and Datasets¶
Note
Before loading data, read the Prepare Data document to understand how to work with different sources of data.
To use tf.estimator models with Determined, wrap your optimizer and datasets using wrap_optimizer() and wrap_dataset(). Both functions are methods of the trial's context object, determined.estimator.EstimatorTrialContext.
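As a rough illustration of the wrapping step, the sketch below builds an optimizer and an input function and passes each through the context. The helper names (make_optimizer, make_input_fn), the choice of Adam, and the shuffle and batch sizes are hypothetical and not part of Determined. Wrapping the dataset right after it is created, before shuffling or batching, is an assumption made here. TensorFlow is imported lazily so the module loads even where it is not installed.

```python
def make_optimizer(context, learning_rate=1e-3):
    """Create an optimizer and wrap it with the trial context (hypothetical helper)."""
    # Lazy import: TensorFlow is only needed when the optimizer is built.
    import tensorflow as tf

    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
    # context is expected to be a determined.estimator.EstimatorTrialContext.
    return context.wrap_optimizer(optimizer)


def make_input_fn(context, features, labels, batch_size=32):
    """Return an input_fn whose dataset is wrapped by the trial context (hypothetical helper)."""

    def input_fn():
        import tensorflow as tf

        dataset = tf.data.Dataset.from_tensor_slices((features, labels))
        # Assumption: wrap immediately after creation, before other transformations.
        dataset = context.wrap_dataset(dataset)
        return dataset.shuffle(1000).batch(batch_size)

    return input_fn
```

Inside a trial, these helpers would be called with self.context, for example from the model function and from build_train_spec().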
Reduce Metrics¶
Determined supports proper reduction of arbitrary validation metrics during distributed training by allowing users to define custom reducers for their metrics. A custom reducer can be either a function or an implementation of the determined.estimator.MetricReducer interface. See context.make_metric() for more details.
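To show why a custom reducer can matter, here is a minimal sketch of a reducer that computes a mean over all examples when slots see different numbers of examples. It is a plain Python class so that it runs without Determined installed. The accumulate() and cross_slot_reduce() method names mirror what the MetricReducer interface is assumed to require, and the class name and the (sum, count) metric layout are hypothetical.

```python
from typing import List, Tuple


class WeightedMeanReducer:
    """Stand-in with the method shape assumed for determined.estimator.MetricReducer."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def accumulate(self, metric: Tuple[float, int]) -> Tuple[float, int]:
        # Each per-batch metric is a (sum of values, number of examples) pair.
        batch_sum, batch_count = metric
        self.total += batch_sum
        self.count += batch_count
        # Return the running state for this slot.
        return self.total, self.count

    def cross_slot_reduce(self, per_slot_metrics: List[Tuple[float, int]]) -> float:
        # Combine the final (sum, count) pair from every slot into one mean.
        total = sum(slot_sum for slot_sum, _ in per_slot_metrics)
        count = sum(slot_count for _, slot_count in per_slot_metrics)
        return total / count if count else 0.0


# Simulate two slots that processed different numbers of examples.
slot_a = WeightedMeanReducer()
slot_b = WeightedMeanReducer()
state_a = slot_a.accumulate((6.0, 3))  # three examples, per-slot mean 2.0
state_b = slot_b.accumulate((1.0, 1))  # one example, per-slot mean 1.0
overall = WeightedMeanReducer().cross_slot_reduce([state_a, state_b])
print(overall)
```

Averaging the two per-slot means would give 1.5, while the reducer weights by example count and gives 7.0 / 4. In a real trial the class would subclass determined.estimator.MetricReducer and be passed to context.make_metric().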
Checkpointing¶
A checkpoint includes the model definition (Python source code), experiment configuration file, network architecture, and the values of the model’s parameters (i.e., weights) and hyperparameters. When using a stateful optimizer during training, checkpoints will also include the state of the optimizer (e.g., learning rate). Users can also embed arbitrary metadata in checkpoints via the Python SDK.
TensorFlow Estimator trials are checkpointed using the SavedModel format. Please consult the TensorFlow documentation for details on how to restore models from the SavedModel format.
Callbacks¶
To execute arbitrary Python code during the lifecycle of an EstimatorTrial, use determined.estimator.RunHook, which extends tf.estimator.SessionRunHook. With determined.estimator.RunHook, users can use native estimator hooks such as before_run() as well as Determined hooks such as on_checkpoint_end().
Example usage of determined.estimator.RunHook that adds custom metadata to checkpoints:
import os

import tensorflow as tf

import determined.estimator


class MyHook(determined.estimator.RunHook):
    def __init__(self, context, metadata) -> None:
        self._context = context
        self._metadata = metadata

    def on_checkpoint_end(self, checkpoint_dir) -> None:
        # Store the custom metadata next to the checkpoint files.
        with open(os.path.join(checkpoint_dir, "metadata.txt"), "w") as fp:
            fp.write(self._metadata)


class MyEstimatorTrial(determined.estimator.EstimatorTrial):
    ...

    def build_train_spec(self) -> tf.estimator.TrainSpec:
        return tf.estimator.TrainSpec(
            make_input_fn(),  # user-defined input function factory (not shown)
            hooks=[MyHook(self.context, "my_metadata")],
        )
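The file handling in on_checkpoint_end() can be tried out without a running trial. The sketch below repeats the same write into a temporary directory and reads the metadata back. The function names write_metadata and read_metadata are hypothetical helpers for illustration, and the temporary directory only stands in for a real checkpoint directory.

```python
import os
import tempfile
from typing import Optional

METADATA_FILE = "metadata.txt"  # same file name the hook above writes


def write_metadata(checkpoint_dir: str, metadata: str) -> None:
    """Write metadata into the checkpoint directory, as the hook does."""
    with open(os.path.join(checkpoint_dir, METADATA_FILE), "w") as fp:
        fp.write(metadata)


def read_metadata(checkpoint_dir: str) -> Optional[str]:
    """Return the stored metadata, or None if the checkpoint has none."""
    path = os.path.join(checkpoint_dir, METADATA_FILE)
    if not os.path.exists(path):
        return None
    with open(path) as fp:
        return fp.read()


# A temporary directory stands in for a downloaded checkpoint.
with tempfile.TemporaryDirectory() as checkpoint_dir:
    before = read_metadata(checkpoint_dir)
    write_metadata(checkpoint_dir, "my_metadata")
    after = read_metadata(checkpoint_dir)
    print(before, after)
```

Returning None for a missing file lets callers tell apart checkpoints that were saved before the hook was added.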