Tensorpack Model Definition

This part of the documentation describes how to train a Tensorpack model in PEDL.

There are three steps needed to define a Tensorpack model in PEDL using a Standard Model Definition:

  1. Define a make_data_loaders() function to specify data access and any preprocessing in the data pipeline.

  2. Optionally, subclass the abstract class Evaluator. This part of the interface defines the validation process, and it is required only if make_data_loaders() does not return a validation dataflow.

  3. Subclass the abstract class TensorpackTrial. This part of the interface defines the deep learning model, including the graph, loss, and optimizers.

Data Loading via make_data_loaders()

A PEDL user specifies data access for a TensorpackTrial by writing a make_data_loaders() function. This function should return a pair of tp.DataFlow objects: the first for the training set and the second for the validation set. Alternatively, this function can return a single tp.DataFlow object to use for training; in that case, a subclass of Evaluator must be provided to define the validation process.

def make_data_loaders(experiment_config, hparams):
    # Build the training and validation tp.DataFlow objects here.
    return trainDataset, valDataset

For cases where the dataset is too large to be stored locally, PEDL supports downloading data from Google Cloud Storage (GCS). To use this feature, replace code that opens a local file, e.g.:

import cv2
image = cv2.imread(filename)

with a call to GCS:

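import cv2
import numpy as np
from google.cloud import storage
from pedl.util import download_gcs_blob_with_backoff

# An anonymous client can only read publicly accessible buckets.
c = storage.Client.create_anonymous_client()
gcs_bucket_name = "bucket_name"
bucket = c.get_bucket(gcs_bucket_name)
blob = bucket.blob(filename)
# Download the object's bytes, retrying transient failures with backoff.
s = download_gcs_blob_with_backoff(blob)
# Decode the raw bytes into an image array; cv2.imdecode requires a flags argument.
image = cv2.imdecode(np.asarray(bytearray(s), dtype=np.uint8), cv2.IMREAD_COLOR)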

download_gcs_blob_with_backoff implements a standard error-handling strategy for network applications, in which a client repeatedly retries a failed request with increasing delays between attempts. This strategy is recommended when reading data from GCS to handle transient network failures and HTTP 429 and 5xx error codes.
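The retry strategy described above can be sketched in plain Python. This is not PEDL's implementation of download_gcs_blob_with_backoff; it is a hypothetical call_with_backoff helper, assuming truncated exponential backoff with random jitter, and a hypothetical TransientError exception standing in for retryable failures such as HTTP 429 or 5xx responses.

```python
import random
import time


class TransientError(Exception):
    """Hypothetical stand-in for a retryable failure (e.g. HTTP 429 or 5xx)."""


def call_with_backoff(fn, max_retries=5, base_delay=1.0, max_delay=32.0, sleep=time.sleep):
    """Call fn(), retrying TransientError with exponentially growing delays."""
    for attempt in range(max_retries + 1):
        try:
            return fn()
        except TransientError:
            if attempt == max_retries:
                raise  # out of retries: surface the error to the caller
            # The delay doubles on each attempt (1s, 2s, 4s, ...) plus up to 1s of
            # random jitter, so many clients do not retry in lockstep; capped at max_delay.
            delay = min(base_delay * 2 ** attempt + random.uniform(0, 1), max_delay)
            sleep(delay)
```

The `sleep` parameter is injectable so the retry timing can be observed or skipped in tests.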

Subclassing Evaluator

Instead of defining the validation dataflow in make_data_loaders, PEDL users can subclass Evaluator to perform validation. Using Evaluator allows users to specify the validation graph manually and to supply custom code for computing validation metrics. Users must define the following abstract methods:

  • set_up_graph(self, trainer): Builds the validation graph. The trainer argument is an instance of tp.Trainer.

  • compute_validation_metrics(self): Defines the process for computing validation metrics and returns the resulting metrics.
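One common pattern inside compute_validation_metrics is to evaluate the validation graph batch by batch and then combine the per-batch results into dataset-level values. The sketch below shows only that combining step, independent of PEDL and Tensorpack; the helper name average_metrics and the dict-of-floats format are assumptions, not part of PEDL's documented interface.

```python
def average_metrics(batch_results):
    """Combine per-batch metrics into dataset-level averages.

    batch_results: list of (batch_size, {metric_name: value}) pairs.
    Assumes every batch reports every metric. Each value is weighted by its
    batch size so that a smaller final batch does not skew the average.
    """
    totals = {}
    num_examples = 0
    for batch_size, metrics in batch_results:
        num_examples += batch_size
        for name, value in metrics.items():
            totals[name] = totals.get(name, 0.0) + value * batch_size
    return {name: total / num_examples for name, total in totals.items()}
```

For example, batches of 10 and 30 examples with losses 1.0 and 2.0 average to 1.75 rather than the unweighted 1.5.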

Subclassing TensorpackTrial

Tensorpack trials are created by subclassing the abstract class TensorpackTrial. Users must define the following abstract methods that will specify the deep learning model associated with a trial in the experiment, as well as how to subsequently train and evaluate it:

  • build_model(self, hparams, trainer_type): Builds and returns the Tensorpack model (tp.ModelDesc) to be used during training.

  • validation_metrics(self, hparams): If the validation dataflow is specified in make_data_loaders, this function returns a list of metric names that will be evaluated on the validation dataset (e.g., "cross_entropy_loss"). Otherwise, this function returns an instance of Evaluator.

Optional Methods

  • training_metrics(self, hparams): Specifies the training metrics that should be tracked (e.g., "learning_rate").

  • tensorpack_callbacks(self, hparams): Returns a list of Tensorpack callbacks to use during training. Users often control their learning rate schedule via these callbacks (e.g., pedl.frameworks.tensorflow.tensorpack_trial.ScheduleSetter()).

  • tensorpack_monitors(self, hparams): Returns a list of Tensorpack monitors to use during training.

  • load_backbone_weights(self, hparams): Returns the filepath for the backbone weights which are loaded prior to training.

Learning Rate Schedule

When training models using Tensorpack, users often control learning rate schedules via callbacks. TensorpackTrial provides PEDL users with ScheduleSetter, which subclasses tp.callbacks.HyperParamSetter. (Note: ScheduleSetter can be used for any hyperparameter, but is most commonly used to control the learning rate.) ScheduleSetter takes a list of SchedulePoint() objects, which define how the value of the hyperparameter changes during training.

  • SchedulePoint(point, value, interp): Specifies that the hyperparameter should be exactly value after point training steps. interp can be either None, to specify that the value should remain the same until the next point, or "interp", to specify that it should be linearly interpolated toward the next point's value.

  • ScheduleSetter(param, schedule): Defines the parameter name (e.g., "learning_rate") and a list of schedule points.
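The SchedulePoint semantics above can be modeled in plain Python to check what a schedule produces at a given step. This is not PEDL's ScheduleSetter implementation: Point is a hypothetical stand-in for SchedulePoint, and the sketch assumes that a point's interp governs the segment from that point to the next one, and that steps before the first point use the first point's value.

```python
from collections import namedtuple

# Hypothetical stand-in mirroring SchedulePoint(point, value, interp).
Point = namedtuple("Point", ["point", "value", "interp"])


def value_at(schedule, step):
    """Return the scheduled value after `step` training steps (schedule sorted by point)."""
    if step <= schedule[0].point:
        return schedule[0].value
    for current, nxt in zip(schedule, schedule[1:]):
        if current.point <= step < nxt.point:
            if current.interp == "interp":
                # Linearly interpolate between this point's value and the next one's.
                frac = (step - current.point) / (nxt.point - current.point)
                return current.value + frac * (nxt.value - current.value)
            return current.value  # interp=None: hold the value until the next point
    return schedule[-1].value  # past the last point: keep its value
```

For example, a linear warmup from 0.0 to 0.1 over the first 100 steps, followed by a step decay at step 1000, would be `[Point(0, 0.0, "interp"), Point(100, 0.1, None), Point(1000, 0.01, None)]`; at step 50 it yields 0.05.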

The following example starts the learning rate at 0.1 and reduces it by a factor of 10 every 1000 steps:

from pedl.frameworks.tensorflow.tensorpack_trial import SchedulePoint, ScheduleSetter, TensorpackTrial

def make_schedule():
    init_lr = 0.1
    schedule = []

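    for idx in range(10):
        # Divide the learning rate by 10 at every 1000-step boundary.
        mult = 0.1 ** idx
        # interp=None: hold each value constant until the next point.
        schedule.append(SchedulePoint(1000 * idx, init_lr * mult, None))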

    return schedule

class YourTrial(TensorpackTrial):
    # build_model() and validation_metrics() must also be implemented; omitted here.

    def tensorpack_callbacks(self, hparams):
        return [
            ScheduleSetter("learning_rate", make_schedule()),
        ]

Performance Optimization

When training in a distributed setting, TensorpackTrial supports a performance optimization that shortens training time by reducing communication. PEDL users can enable this feature by setting the hyperparameter aggregation_frequency to a value greater than 1:

  aggregation_frequency: 4

Aggregation frequency controls how often updates are communicated between workers. Because this changes the effective training batch size (the number of training samples processed per gradient update), PEDL users are encouraged to scale their learning rate in proportion to the aggregation frequency: new_learning_rate = original_learning_rate * aggregation_frequency.
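The scaling rule above is simple arithmetic; for instance, a learning rate of 0.1 with aggregation_frequency: 4 becomes 0.4. The helper below is a hypothetical sketch of that rule, and it assumes the effective batch size is batch_size multiplied by aggregation_frequency, where batch_size is the number of samples per update without aggregation.

```python
def scale_for_aggregation(learning_rate, batch_size, aggregation_frequency):
    """Return (new_learning_rate, effective_batch_size) for a given aggregation frequency.

    Applies new_learning_rate = original_learning_rate * aggregation_frequency and
    assumes each gradient update combines aggregation_frequency batches.
    """
    if aggregation_frequency < 1:
        raise ValueError("aggregation_frequency must be >= 1")
    new_learning_rate = learning_rate * aggregation_frequency
    effective_batch_size = batch_size * aggregation_frequency
    return new_learning_rate, effective_batch_size
```

With aggregation_frequency set to 1 the learning rate and batch size are unchanged, which matches training without this optimization.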