tf.keras Model Definition

This part of the documentation describes how to train a tf.keras model in PEDL. There are two steps needed to define a tf.keras model in PEDL:

  1. Define a make_data_loaders() function. See Data Loading for more information.

  2. Implement the TFKerasTrial interface.
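The two steps above can be sketched as follows. This is a minimal illustration, not a complete trial: the signature shown for make_data_loaders() is an assumption, and a stub stands in for the PEDL-provided TFKerasTrial base class so the sketch is self-contained.

```python
# Skeleton of the two steps. In real code you would import and subclass
# the PEDL-provided TFKerasTrial and return real data loaders and a
# compiled tf.keras.Model; the stub below only makes the sketch runnable.

class TFKerasTrial:  # stub standing in for the PEDL base class
    def __init__(self, hparams):
        self.hparams = hparams

def make_data_loaders(experiment_config, hparams):
    # Step 1: return a (training, validation) pair of data loaders.
    # Toy in-memory batches are used here; real code returns one of the
    # supported types described under Data Loading.
    training = [([0, 1], [0, 1]), ([2, 3], [2, 3])]
    validation = [([4, 5], [4, 5])]
    return training, validation

class MyTrial(TFKerasTrial):  # hypothetical trial class name
    # Step 2: implement the TFKerasTrial interface.
    def build_model(self, hparams):
        # Real code builds and compiles a tf.keras.Model here.
        raise NotImplementedError
```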

Data Loading

There are three supported data types for loading data into tf.keras models. Loading data is done by defining a make_data_loaders() function. This function should return a pair of objects (one for training and one for validation) of one of the three following types:

  • A tf.data.Dataset. Examples can be found in the TensorFlow Documentation.

  • An object that implements the tf.keras.utils.Sequence interface.

  • A KerasDataAdapter, which wraps a Sequence (see Multithreading / Multiprocessing below).

The behavior of Sequence objects should be familiar if you have used fit_generator. Like in tf.keras, these Sequence objects should return batches of data (i.e., either (inputs, targets) or (inputs, targets, sample_weights)). Examples can be found in TFKerasTrial.
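As a minimal illustration of the Sequence contract, the sketch below implements only the two required methods: __len__ (the number of batches per epoch) and __getitem__ (one batch by index). To keep the sketch self-contained it uses plain Python lists and omits the base class; real code would subclass tf.keras.utils.Sequence and typically return NumPy arrays.

```python
import math

class InMemorySequence:  # in real code: class InMemorySequence(tf.keras.utils.Sequence)
    """Serves (inputs, targets) batches from in-memory data."""

    def __init__(self, inputs, targets, batch_size):
        assert len(inputs) == len(targets)
        self.inputs = inputs
        self.targets = targets
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches per epoch (the last batch may be smaller).
        return math.ceil(len(self.inputs) / self.batch_size)

    def __getitem__(self, idx):
        # Return the idx-th batch as an (inputs, targets) pair.
        lo = idx * self.batch_size
        hi = lo + self.batch_size
        return self.inputs[lo:hi], self.targets[lo:hi]
```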


If you are using tf.data.Dataset inputs with distributed training, PEDL’s support for automatically checkpointing and resuming workloads does not work correctly. Therefore, using tf.data.Dataset inputs with distributed training is currently not recommended.

Required Wrappers

To use tf.data.Dataset objects, users need to wrap their datasets using PEDL-provided wrappers.

  • pedl.frameworks.keras.wrap_dataset(dataset): This should be used to wrap tf.data.Dataset objects immediately after they have been created. Users should use the output of this wrapper as the new instance of their dataset. If users create multiple datasets (e.g., one for training and one for testing), they should wrap each dataset independently.

Multithreading / Multiprocessing

We support multithreading and multiprocessing only for Sequences. This can be done by returning an instance of KerasDataAdapter as one of the data loaders in make_data_loaders(). KerasDataAdapter is a small abstraction that provides a way to define the parameters for multithreading / multiprocessing. Its behavior is similar to the data inputs to Keras’ fit_generator().

Usage Examples

  • Use the main Python process with no multithreading and no multiprocessing

KerasDataAdapter(sequence, workers=0, use_multiprocessing=False)
  • Use one background process

KerasDataAdapter(sequence, workers=1, use_multiprocessing=True)
  • Use two background threads

KerasDataAdapter(sequence, workers=2, use_multiprocessing=False)

KerasDataAdapter accepts the following arguments:

  • sequence: A Sequence that holds the data.

  • use_multiprocessing: If True, use multiprocessing; otherwise, use multithreading. If unspecified, use_multiprocessing will default to False. Note that because this implementation relies on multiprocessing, you should not pass non-pickleable arguments for the data loaders, as they can’t easily be passed to child processes.

  • workers: Maximum number of processes to create when using multiprocessing; otherwise, the maximum number of threads. If unspecified, workers will default to 1. If 0, data loading will be executed on the main thread.

  • max_queue_size: Maximum size for the generator queue. If unspecified, max_queue_size will default to 10.
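To illustrate how workers and max_queue_size interact, the sketch below (not PEDL’s or Keras’ actual implementation) prefetches batches with a pool of background threads feeding a bounded queue. The workers=0 main-thread case and the multiprocessing path are omitted for brevity.

```python
import queue
import threading

def prefetch_batches(sequence, workers=2, max_queue_size=10):
    """Illustrative sketch: worker threads load batches and push them onto
    a bounded queue, mirroring the workers / max_queue_size parameters."""
    todo = queue.Queue()
    for i in range(len(sequence)):
        todo.put(i)
    done = queue.Queue(maxsize=max_queue_size)  # bounded, like max_queue_size

    def worker():
        while True:
            try:
                i = todo.get_nowait()
            except queue.Empty:
                return  # no batches left to load
            done.put((i, sequence[i]))  # blocks while the queue is full

    threads = [threading.Thread(target=worker) for _ in range(workers)]
    for t in threads:
        t.start()
    # The main thread consumes batches as the workers produce them.
    out = {}
    for _ in range(len(sequence)):
        i, batch = done.get()
        out[i] = batch
    for t in threads:
        t.join()
    return [out[i] for i in range(len(sequence))]
```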

Multi-GPU Training

When performing multi-GPU training with a TFKerasTrial, see Data Downloading if you are downloading data within the trial.

TFKerasTrial


class pedl.frameworks.keras.tf_keras_trial.TFKerasTrial(hparams: Dict[str, Any], *_: Any, **__: Any)

tf.keras trials are created by subclassing the abstract class TFKerasTrial.

Users must define all the abstract methods to create the deep learning model associated with a specific trial, and to subsequently train and evaluate it.

abstract build_model(hparams: Dict[str, Any]) → tensorflow.python.keras.engine.training.Model

Defines the deep learning architecture associated with a trial, which may depend on the trial’s specific hyperparameter settings stored in the hparams dictionary. This function returns a tf.keras.Model object. Users must compile the model by calling model.compile() on the tf.keras.Model instance before returning it.
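A minimal build_model() might look like the sketch below; the hyperparameter names hidden_size and learning_rate are illustrative, not part of the PEDL API. Note the required model.compile() call before the model is returned.

```python
import tensorflow as tf

def build_model(hparams):
    # Hypothetical hyperparameters "hidden_size" and "learning_rate";
    # real names come from your experiment configuration.
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(hparams["hidden_size"],
                              activation="relu", input_shape=(8,)),
        tf.keras.layers.Dense(1),
    ])
    # The model must be compiled before it is returned.
    model.compile(
        optimizer=tf.keras.optimizers.SGD(hparams["learning_rate"]),
        loss="mse",
    )
    return model
```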

session_config(hparams: Dict[str, Any]) → tensorflow.core.protobuf.config_pb2.ConfigProto

Specifies the tf.ConfigProto to be used by the TensorFlow session. By default, tf.ConfigProto(allow_soft_placement=True) is used.

keras_callbacks(hparams: Dict[str, Any]) → List[tensorflow.python.keras.callbacks.Callback]

Specifies a list of tf.keras.callbacks.Callback objects to be used during the trial’s lifetime.