Keras API Reference

This document guides you through training a Keras model in Determined. You need to implement a trial class that inherits from TFKerasTrial and specify it as the entrypoint in the experiment configuration.

To learn about this API, you can start by reading the trial definitions from the following examples:

Load Data

Note

Before loading data, read Prepare Data to understand how to work with different sources of data.

To load data, define the build_training_data_loader() and build_validation_data_loader() methods. Each should return one of the following data types:

  1. A tuple (x, y) of NumPy arrays. x must be a NumPy array (or array-like), a list of arrays (if the model has multiple inputs), or a dict mapping input names to the corresponding arrays (if the model has named inputs). y should be a NumPy array.

  2. A tuple (x, y, sample_weights) of NumPy arrays.

  3. A tf.data.Dataset returning a tuple of either (inputs, targets) or (inputs, targets, sample_weights).

  4. A keras.utils.Sequence returning a tuple of either (inputs, targets) or (inputs, targets, sample_weights).
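As a minimal sketch of the NumPy forms (items 1 and 2), the hypothetical helper below builds synthetic data and returns it as an (x, y, sample_weights) tuple, with x given either as a single array or as a dict keyed by input name. The helper name, the input names "image" and "meta", the shapes, and the random data are all illustrative assumptions; inside a trial, build_training_data_loader() would return such a tuple directly.

```python
import numpy as np


def make_numpy_loader(num_samples=128, named_inputs=False, seed=0):
    """Hypothetical helper returning data in one of the accepted tuple formats.

    Returns (x, y, sample_weights), where x is either a single array or a dict
    mapping input names (assumed here to be "image" and "meta") to arrays.
    """
    rng = np.random.default_rng(seed)
    images = rng.random((num_samples, 28, 28), dtype=np.float32)
    meta = rng.random((num_samples, 3), dtype=np.float32)
    # Integer class labels in [0, 10), one per sample.
    y = rng.integers(0, 10, size=num_samples)
    # Per-sample loss weights; here every sample counts equally.
    sample_weights = np.ones(num_samples, dtype=np.float32)

    if named_inputs:
        # Dict form: keys must match the names of the model's Input layers.
        x = {"image": images, "meta": meta}
    else:
        x = images
    return x, y, sample_weights
```

Returning only x and y instead gives the two-element form from item 1.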

If you use tf.data.Dataset, you must wrap both the training and validation datasets with self.context.wrap_dataset. This wrapper shards the dataset for distributed training (see Introduction to Distributed Training). For optimal performance, wrap each dataset immediately after creating it.

Define the Model

You must wrap your model with self.context.wrap_model before compiling it. This is typically done inside build_model().

Customize How model.fit Is Called

The TFKerasTrial interface lets you configure how model.fit is called through self.context.configure_fit().

Checkpointing

A checkpoint includes the model definition (Python source code), experiment configuration file, network architecture, and the values of the model’s parameters (i.e., weights) and hyperparameters. When using a stateful optimizer during training, checkpoints also include the state of the optimizer (e.g., the learning rate). Users can also embed arbitrary metadata in checkpoints via a Python API.

TensorFlow Keras trials are checkpointed to a file named determined-keras-model.h5 using tf.keras.models.save_model. You can learn more from the TF Keras docs.
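Because the weights are written with tf.keras.models.save_model, a checkpoint's model file can be read back with tf.keras.models.load_model. The sketch below assumes you already have the checkpoint contents in a local directory (how you download it is not covered here); load_trial_model and roundtrip_check are hypothetical helper names. Models with custom layers would also need the custom_objects argument of load_model.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

CHECKPOINT_FILE = "determined-keras-model.h5"


def load_trial_model(checkpoint_dir):
    """Load the Keras model stored in a local checkpoint directory."""
    path = os.path.join(checkpoint_dir, CHECKPOINT_FILE)
    # compile=False skips restoring the training configuration,
    # which is enough for inference.
    return tf.keras.models.load_model(path, compile=False)


def roundtrip_check():
    """Save a tiny model under the checkpoint file name, reload it, compare outputs."""
    model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(2)])
    x = np.ones((1, 4), dtype=np.float32)
    with tempfile.TemporaryDirectory() as ckpt_dir:
        tf.keras.models.save_model(model, os.path.join(ckpt_dir, CHECKPOINT_FILE))
        restored = load_trial_model(ckpt_dir)
        return bool(np.allclose(model(x), restored(x)))
```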

Callbacks

To execute arbitrary Python code during the lifecycle of a TFKerasTrial, implement the determined.keras.callbacks.Callback interface (an extension of the tf.keras.callbacks.Callback interface) and supply your callbacks to the TFKerasTrial by implementing keras_callbacks().