Keras API
In this guide, you’ll learn how to use the Keras API.
Visit the API reference.
This document guides you through training a Keras model in Determined. You need to implement a trial class that inherits from TFKerasTrial and specify it as the entrypoint in the experiment configuration (see the Experiment Configuration Reference).
To learn about this API, you can start by reading the trial definitions in the Iris categorization example.
Load Data
Note
Before loading data, visit Prepare Data to understand how to work with different sources of data.
Loading data is done by defining the build_training_data_loader() and build_validation_data_loader() methods. Each should return one of the following data types:
- A tuple (x, y) of NumPy arrays. x must be a NumPy array (or array-like), a list of arrays (if the model has multiple inputs), or a dict mapping input names to the corresponding arrays (if the model has named inputs). y should be a NumPy array.
- A tuple (x, y, sample_weights) of NumPy arrays.
- A tf.data.Dataset returning a tuple of either (inputs, targets) or (inputs, targets, sample_weights).
- A keras.utils.Sequence returning a tuple of either (inputs, targets) or (inputs, targets, sample_weights).
If using tf.data.Dataset, users are required to wrap both their training and validation datasets using self.context.wrap_dataset. This wrapper shards the dataset for distributed training. For optimal performance, users should wrap a dataset immediately after creating it.
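To make the idea of sharding concrete, the pure-Python sketch below mimics round-robin sharding in the spirit of tf.data.Dataset.shard(num_shards, index): worker index keeps every num_shards-th element starting at position index. This only illustrates the concept and is not Determined's implementation; self.context.wrap_dataset handles sharding for you, and the function shard and the worker counts are hypothetical.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def shard(elements: Sequence[T], num_shards: int, index: int) -> List[T]:
    """Return the elements assigned to worker `index` out of `num_shards` workers."""
    if not 0 <= index < num_shards:
        raise ValueError("index must be in [0, num_shards)")
    # Keep every num_shards-th element, starting at this worker's index.
    return list(elements[index::num_shards])


records = list(range(10))
shards = [shard(records, num_shards=3, index=i) for i in range(3)]
# shards == [[0, 3, 6, 9], [1, 4, 7], [2, 5, 8]]
```

Each worker sees a disjoint slice, so no record is processed twice per pass. Sharding before expensive steps such as decoding, shuffling, or batching means each worker does that work only for its own slice, which is a plausible reason for the advice to wrap a dataset right after creating it.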
Note
To learn more about distributed training with Determined, visit the conceptual overview or the intro to implementing distributed training.
Define the Model
Users are required to wrap their model with self.context.wrap_model before compiling it. This is typically done inside build_model().
Customize How the Model Fitting Function Is Called
The TFKerasTrial interface lets you configure how model.fit is called through self.context.configure_fit().
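One value you might want to pass through to model.fit is a class_weight dict. The sketch below computes "balanced" class weights with NumPy, using the n_samples / (n_classes * count) heuristic that scikit-learn applies for class_weight="balanced". The helper balanced_class_weights is hypothetical, and the commented call assumes configure_fit() accepts a class_weight argument; check the API reference before relying on it.

```python
from typing import Dict

import numpy as np


def balanced_class_weights(labels: np.ndarray) -> Dict[int, float]:
    """Weight each class by n_samples / (n_classes * count_of_class)."""
    classes, counts = np.unique(labels, return_counts=True)
    n_samples, n_classes = labels.shape[0], classes.shape[0]
    return {int(c): float(n_samples / (n_classes * k)) for c, k in zip(classes, counts)}


weights = balanced_class_weights(np.array([0, 0, 0, 1]))
# weights is approximately {0: 0.667, 1: 2.0}: the rarer class gets a larger weight.

# Inside your TFKerasTrial, assuming configure_fit accepts class_weight:
#     self.context.configure_fit(class_weight=weights)
```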
Checkpointing
A checkpoint includes the model definition (Python source code), experiment configuration file, network architecture, and the values of the model’s parameters (i.e., weights) and hyperparameters. When training with a stateful optimizer, checkpoints also include the optimizer’s state (e.g., the learning rate). You can also embed arbitrary metadata in checkpoints via the Python SDK.
TensorFlow Keras trials are checkpointed to a file named determined-keras-model.h5 using tf.keras.models.save_model. You can learn more from the TF Keras docs.
Callbacks
To execute arbitrary Python code during the lifecycle of a TFKerasTrial, implement the determined.keras.callbacks.Callback interface (an extension of the tf.keras.callbacks.Callback interface) and supply your callbacks to the TFKerasTrial by implementing keras_callbacks(), which returns a list of them.
Profiling
Determined supports integration with the native TF Keras profiler. Results will automatically be uploaded to the trial’s TensorBoard path and can be viewed in the Determined Web UI.
The Keras profiler is configured as a callback in the TFKerasTrial class.
The determined.keras.callbacks.TensorBoard callback is a thin wrapper around the native Keras TensorBoard callback, tf.keras.callbacks.TensorBoard. It overrides the log_dir argument to set the Determined TensorBoard path, while other arguments are passed directly into tf.keras.callbacks.TensorBoard. For a list of accepted arguments, consult the official Keras API documentation.
The following code snippet configures profiling for the range of batches 5 through 10 and computes weight histograms every epoch.
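from typing import List

import tensorflow as tf
from determined import keras


# Hypothetical trial class; the other required methods
# (build_model and the data loaders) are omitted for brevity.
class MyTrial(keras.TFKerasTrial):
    def keras_callbacks(self) -> List[tf.keras.callbacks.Callback]:
        return [
            keras.callbacks.TensorBoard(
                update_freq="batch",
                profile_batch="5, 10",  # profile the range of batches 5 through 10
                histogram_freq=1,  # compute weight histograms every epoch
            )
        ]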
Note
Although specifying which batches to profile with profile_batch is optional, profiling every batch may upload a large amount of data to TensorBoard, which can result in long rendering times and memory issues. For long-running experiments, configure profiling only for the batches you need.