This part of the documentation describes how to train a Tensorpack model in Determined.
Model and Metrics
Tensorpack trials are created by subclassing the abstract class
TensorpackTrial. Users must implement the following abstract methods
that will specify the deep learning model associated with a trial in the
experiment, as well as how to subsequently train and evaluate it:
build_model(self, hparams, trainer_type): Builds and returns the Tensorpack model (tp.ModelDesc) to be used during training.
validation_metrics(self, hparams): If the validation dataflow is specified in
build_validation_dataflow, this function returns a list of metric names that will be evaluated on the validation data set (e.g.,
"cross_entropy_loss"). Otherwise this function returns an instance of
build_training_dataflow(self): Covered in the next section.
training_metrics(self, hparams): Specifies the training metrics that should be tracked (e.g.,
tensorpack_callbacks(self, hparams): Returns a list of Tensorpack callbacks to use during training. Often users choose to control their learning rate schedule via these callbacks (e.g.,
tensorpack_monitors(self, hparams): Returns a list of Tensorpack monitors to use during training.
load_backbone_weights(self, hparams): Returns the filepath for the backbone weights which are loaded prior to training.
build_validation_dataflow(self): Covered in the next section.
Learning Rate Schedule
When training models using Tensorpack, users often choose to control
learning rate schedules via callbacks.
Determined provides users with
ScheduleSetter, which subclasses
(ScheduleSetter can be used for any hyperparameter, but is
most commonly used to control the learning rate.)
ScheduleSetter takes a list of
SchedulePoint objects, which define how the value of the
hyperparameter changes during training.
SchedulePoint(point, value, interp): Specifies that the value of the hyperparameter should be exactly
value at the end of
interp can either be
None to specify that the value should remain the same until the next point, or
"interp" to specify that it should be linearly interpolated.
ScheduleSetter(param, schedule): Defines the parameter name (e.g.,
"learning_rate") and a list of schedule points.
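To make the interpolation behavior concrete, here is a minimal, self-contained sketch of how a list of schedule points could resolve to a value at a given step. It is not Determined's implementation: Point and value_at are hypothetical stand-ins, and the sketch assumes that a point's interp setting governs the segment up to the next point, that any non-None setting means linear interpolation, and that the value is held constant before the first point and after the last one.

```python
from collections import namedtuple

# Hypothetical stand-in for SchedulePoint; not the Determined class.
Point = namedtuple("Point", ["point", "value", "interp"])


def value_at(schedule, step):
    """Return the scheduled value at `step`.

    `schedule` is a list of Points sorted by `point` (assumed non-empty).
    """
    # Assumption: before the first point, hold the first value.
    if step <= schedule[0].point:
        return schedule[0].value
    for cur, nxt in zip(schedule, schedule[1:]):
        if cur.point <= step < nxt.point:
            if cur.interp is None:
                # Hold the value constant until the next point.
                return cur.value
            # Linearly interpolate between the two surrounding points.
            frac = (step - cur.point) / (nxt.point - cur.point)
            return cur.value + frac * (nxt.value - cur.value)
    # Assumption: after the last point, hold the last value.
    return schedule[-1].value


schedule = [
    Point(0, 0.1, None),         # constant until step 1000
    Point(1000, 0.01, "linear"), # ramps linearly toward the next point
    Point(2000, 0.001, None),    # constant from step 2000 onward
]
for step in (500, 1500, 5000):
    print(step, value_at(schedule, step))
```

In this sketch the first segment is a step function and the second is a ramp, which is the distinction the interp argument is described as controlling.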
An example use case in which the learning rate is reduced by a factor of 10 every 1000 steps:
    from determined.tensorpack import SchedulePoint, ScheduleSetter, TensorpackTrial


    def make_schedule():
        init_lr = 0.1
        schedule = []
        for idx in range(10):
            # Divide the learning rate by 10 at each 1000-step boundary.
            mult = 0.1 ** idx
            schedule.append(SchedulePoint(1000 * idx, init_lr * mult))
        return schedule


    class YourTrial(TensorpackTrial):
        ...

        def tensorpack_callbacks(self, hparams):
            return [ScheduleSetter("learning_rate", make_schedule())]
A Determined user prescribes data access in the
build_training_dataflow function (required) and the
build_validation_dataflow function (optional). These functions
should return tp.DataFlow objects:
    def build_training_dataflow(self):
        ...
        return trainDataset

    def build_validation_dataflow(self):
        ...
        return validDataset
For cases where the dataset is too large to be stored locally, Determined supports downloading data from Google Cloud Storage (GCS). To use this feature, replace code that opens a local file, e.g.:
    import cv2

    image = cv2.imread(filename)
with a call to GCS:
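    import cv2
    import numpy as np
    from google.cloud import storage
    from determined.util import download_gcs_blob_with_backoff

    # An anonymous client can only read publicly accessible buckets.
    c = storage.Client.create_anonymous_client()
    gcs_bucket_name = "bucket_name"
    bucket = c.get_bucket(gcs_bucket_name)
    blob = bucket.blob(filename)
    # Fetch the blob contents, retrying on transient failures.
    s = download_gcs_blob_with_backoff(blob)
    # Decode the in-memory bytes; cv2.imdecode requires an explicit read flag.
    image = cv2.imdecode(np.asarray(bytearray(s), dtype=np.uint8), cv2.IMREAD_COLOR)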
download_gcs_blob_with_backoff implements a standard error handling
strategy for network applications in which a client periodically retries
a failed request with increasing delays between requests. This strategy
is suggested when reading data from GCS to handle transient network
failures and HTTP 429 and 5xx error codes.
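The retry strategy described above can be sketched in a few lines. This is a generic illustration of retrying with increasing delays, not the source of download_gcs_blob_with_backoff; the function name retry_with_backoff, its parameters, and the default delays are all assumptions made for the example. A real client would also catch only the exceptions that signal transient failures rather than every Exception.

```python
import random
import time


def retry_with_backoff(fn, max_attempts=5, base_delay=1.0, max_delay=32.0,
                       sleep=time.sleep):
    """Call fn(), retrying on failure with exponentially increasing delays."""
    for attempt in range(max_attempts):
        try:
            return fn()
        except Exception:
            # Give up and re-raise once the attempt budget is exhausted.
            if attempt == max_attempts - 1:
                raise
            # Double the delay on each retry, capped at max_delay.
            delay = min(max_delay, base_delay * 2 ** attempt)
            # Add up to 10% random jitter so that many clients do not
            # retry in lockstep.
            sleep(delay + random.uniform(0, delay * 0.1))
```

The sleep parameter is injectable so that the delays can be observed or skipped in tests without actually waiting.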
Instead of defining the validation dataflow in
build_validation_dataflow, Determined users can subclass
Evaluator to perform validation. Using
Evaluator allows users to
manually specify the validation graph, as well as custom code for
computing validation metrics. Users must implement the following abstract
methods:
set_up_graph(self, trainer): Builds the validation graph. The
trainer argument is an instance of tp.Trainer.
compute_validation_metrics(self): Defines the process for computing validation metrics. This function returns the validation metrics.