This part of the documentation describes how to train a Tensorpack model in Determined.


Model and Metrics

Tensorpack trials are created by subclassing the abstract class TensorpackTrial. Users must implement the following abstract methods, which specify the deep learning model associated with a trial in the experiment and how to train and evaluate it:

  • build_model(self, hparams, trainer_type): Builds and returns the Tensorpack model (tp.ModelDesc) to be used during training.

  • validation_metrics(self, hparams): If the validation dataflow is specified in build_validation_dataflow, this function returns a list of metric names that will be evaluated on the validation data set (e.g., "cross_entropy_loss"). Otherwise this function returns an instance of Evaluator.

  • build_training_dataflow(self): Covered in the next section.

Optional Methods

  • training_metrics(self, hparams): Specifies the training metrics that should be tracked (e.g., "learning_rate").

  • tensorpack_callbacks(self, hparams): Returns a list of Tensorpack callbacks to use during training. Often users choose to control their learning rate schedule via these callbacks (e.g., det.tensorpack.ScheduleSetter()).

  • tensorpack_monitors(self, hparams): Returns a list of Tensorpack monitors to use during training.

  • load_backbone_weights(self, hparams): Returns the filepath for the backbone weights which are loaded prior to training.

  • build_validation_dataflow(self): Covered in the next section.

Learning Rate Schedule

When training models using Tensorpack, users often choose to control learning rate schedules via callbacks. TensorpackTrial provides Determined users with ScheduleSetter, which subclasses tp.callbacks.HyperParamSetter. (Note: ScheduleSetter can be used for any hyperparameter, but is most commonly used to control the learning rate.) ScheduleSetter takes a list of SchedulePoint objects, which define how the value of the hyperparameter changes during training.

  • SchedulePoint(point, value, interp): Specifies that the value of the hyperparameter should be exactly value at the end of point training steps. interp can either be None to specify that the value should remain the same until the next point or "interp" to specify that it should be linearly interpolated.

  • ScheduleSetter(param, schedule): Defines the parameter name (e.g., "learning_rate") and a list of schedule points.
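The schedule semantics described above can be sketched in plain Python. This is an illustration only, not Determined's implementation: Point is a hypothetical stand-in for SchedulePoint, value_at is a hypothetical helper, and the sketch assumes that any non-None interp on a point means "interpolate linearly from this point to the next one".

```python
from typing import List, NamedTuple, Optional


class Point(NamedTuple):
    """Hypothetical stand-in for SchedulePoint (illustration only)."""
    point: int                    # training step at which `value` applies
    value: float                  # hyperparameter value at that step
    interp: Optional[str] = None  # None: hold the value until the next point


def value_at(schedule: List[Point], step: int) -> Optional[float]:
    """Return the scheduled value at `step`, or None before the first point."""
    pts = sorted(schedule, key=lambda p: p.point)
    current = None
    for i, p in enumerate(pts):
        if step < p.point:
            break
        current = p.value
        nxt = pts[i + 1] if i + 1 < len(pts) else None
        # Assumption: a non-None interp interpolates linearly toward the next point.
        if p.interp is not None and nxt is not None and step < nxt.point:
            frac = (step - p.point) / (nxt.point - p.point)
            current = p.value + frac * (nxt.value - p.value)
    return current


# Interpolate from 0.1 down to 0.01 over the first 1000 steps, then hold each value.
schedule = [Point(0, 0.1, "interp"), Point(1000, 0.01), Point(2000, 0.001)]
for step in (0, 500, 1000, 1500, 2500):
    print(step, value_at(schedule, step))
```

In this sketch, step 500 falls halfway between the first two points and so gets a value halfway between them, while step 1500 simply keeps the value set at step 1000.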

For example, the following schedule reduces the learning rate by a factor of 10 every 1000 steps, starting from an initial value of 0.1:

from determined.tensorpack import SchedulePoint, ScheduleSetter, TensorpackTrial

def make_schedule():
    init_lr = 0.1
    schedule = []

    for idx in range(10):
        mult = 0.1 ** idx
        schedule.append(SchedulePoint(1000 * idx, init_lr * mult))

    return schedule

class YourTrial(TensorpackTrial):

    def tensorpack_callbacks(self, hparams):
        return [ScheduleSetter("learning_rate", make_schedule())]

Data Loading

A Determined user prescribes data access in TensorpackTrial by writing a build_training_dataflow function (required) and, optionally, a build_validation_dataflow function. Each function should return a tp.DataFlow object.

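def build_training_dataflow(self):
    # trainDataset is a placeholder for a tp.DataFlow that yields training data.
    return trainDataset

def build_validation_dataflow(self):
    # validDataset is a placeholder for a tp.DataFlow that yields validation data.
    return validDataset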

For cases where the dataset is too large to be stored locally, Determined supports downloading data from Google Cloud Storage (GCS). To use this feature, replace code that opens a local file, for example:

import cv2

image = cv2.imread(filename)

with a call to GCS:

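import cv2
import numpy as np
from google.cloud import storage
from determined.util import download_gcs_blob_with_backoff

# An anonymous client can only read publicly accessible buckets.
c = storage.Client.create_anonymous_client()
gcs_bucket_name = "bucket_name"
bucket = c.get_bucket(gcs_bucket_name)
blob = bucket.blob(filename)
s = download_gcs_blob_with_backoff(blob)
# Decode the downloaded bytes; IMREAD_COLOR matches the default behavior of cv2.imread.
image = cv2.imdecode(np.asarray(bytearray(s), dtype=np.uint8), cv2.IMREAD_COLOR)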

download_gcs_blob_with_backoff implements a standard error handling strategy for network applications in which a client periodically retries a failed request with increasing delays between requests. This strategy is suggested when reading data from GCS to handle transient network failures and HTTP 429 and 5xx error codes.
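The retry strategy described above can be sketched generically as follows. This is not the implementation of download_gcs_blob_with_backoff: retry_with_backoff and all of its parameters are hypothetical, and the doubling delay with random jitter is one common choice for "increasing delays", assumed here for illustration.

```python
import random
import time
from typing import Callable, TypeVar

T = TypeVar("T")


def retry_with_backoff(
    fn: Callable[[], T],
    max_attempts: int = 5,
    base_delay: float = 1.0,
    max_delay: float = 32.0,
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Call fn(), retrying on failure with growing, jittered delays (hypothetical helper)."""
    for attempt in range(max_attempts):
        try:
            return fn()
        except Exception:
            # Real code should catch only transient errors (e.g. HTTP 429 or 5xx).
            if attempt == max_attempts - 1:
                raise  # out of attempts: surface the last error
            # Double the delay on each attempt, capped at max_delay.
            delay = min(max_delay, base_delay * 2 ** attempt)
            # Add random jitter so that many clients do not retry in lockstep.
            sleep(delay + random.uniform(0, delay / 2))
    raise ValueError("max_attempts must be at least 1")
```

Passing sleep as a parameter keeps the helper easy to exercise without real waiting; a caller would typically wrap the network request in a zero-argument function and pass that as fn.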

Subclassing Evaluator

Instead of defining the validation dataflow in build_validation_dataflow, Determined users can subclass Evaluator to perform validation. Using Evaluator allows users to manually specify the validation graph, as well as custom code for computing validation metrics. Subclasses must implement the following abstract methods:

  • set_up_graph(self, trainer): Builds the validation graph. The trainer argument is an instance of tp.Trainer.

  • compute_validation_metrics(self): Defines the process for computing validation metrics. This function returns the validation metrics.