determined.tensorpack¶
This part of the documentation describes how to train a Tensorpack model in Determined.
determined.tensorpack.TensorpackTrial¶
Model and Metrics¶
Tensorpack trials are created by subclassing the abstract class
TensorpackTrial
. Users must implement the following abstract methods
that will specify the deep learning model associated with a trial in the
experiment, as well as how to subsequently train and evaluate it:
build_model(self, hparams, trainer_type): Builds and returns the Tensorpack model (tp.ModelDesc) to be used during training.

validation_metrics(self, hparams): If the validation dataflow is specified in build_validation_dataflow, this function returns a list of metric names that will be evaluated on the validation data set (e.g., "cross_entropy_loss"). Otherwise this function returns an instance of Evaluator.

build_training_dataflow(self): Covered in the next section.
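For example, a minimal sketch of a trial that implements these required methods might look like the following, where MyModelDesc and make_dataflow are hypothetical helpers that return a tp.ModelDesc and a tp.DataFlow:
from determined.tensorpack import TensorpackTrial


class MyTrial(TensorpackTrial):
    def build_model(self, hparams, trainer_type):
        # Build and return the Tensorpack model (a tp.ModelDesc);
        # hyperparameters arrive via `hparams`.
        return MyModelDesc(learning_rate=hparams["learning_rate"])

    def validation_metrics(self, hparams):
        # Names of the metrics to evaluate on the validation dataflow.
        return ["cross_entropy_loss"]

    def build_training_dataflow(self):
        # Return a tp.DataFlow that yields training batches.
        return make_dataflow(train=True)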
Optional Methods
training_metrics(self, hparams): Specifies the training metrics that should be tracked (e.g., "learning_rate").

tensorpack_callbacks(self, hparams): Returns a list of Tensorpack callbacks to use during training. Often users choose to control their learning rate schedule via these callbacks (e.g., det.tensorpack.ScheduleSetter()).

tensorpack_monitors(self, hparams): Returns a list of Tensorpack monitors to use during training.

load_backbone_weights(self, hparams): Returns the filepath for the backbone weights, which are loaded prior to training.

build_validation_dataflow(self): Covered in the next section.
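The optional hooks follow the same pattern; below is a sketch continuing the hypothetical MyTrial above, with a placeholder weights path:
from determined.tensorpack import TensorpackTrial


class MyTrial(TensorpackTrial):
    # ... required methods as above ...

    def training_metrics(self, hparams):
        # Track the learning rate alongside the default training metrics.
        return ["learning_rate"]

    def load_backbone_weights(self, hparams):
        # Hypothetical path to pretrained backbone weights loaded before training.
        return "/path/to/backbone_weights.npz"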
Learning Rate Schedule¶
When training models using Tensorpack, users often choose to control
learning rate schedules via callbacks. TensorpackTrial
provides Determined
users with ScheduleSetter
, which subclasses
tp.callbacks.HyperParamSetter.
(Note: ScheduleSetter
can be used for any hyperparameter, but is
most commonly used to control learning rate.) ScheduleSetter
takes a list of SchedulePoint() objects, which define how the value of the hyperparameter changes during training.
SchedulePoint(point, value, interp): Specifies that the value of the hyperparameter should be exactly value at the end of point training steps. interp can either be None to specify that the value should remain the same until the next point or "interp" to specify that it should be linearly interpolated.

ScheduleSetter(param, schedule): Defines the parameter name (e.g., "learning_rate") and a list of schedule points.
An example use case where we reduce the initial learning rate every 1000 steps by a factor of 10:
from determined.tensorpack import SchedulePoint, ScheduleSetter, TensorpackTrial


def make_schedule():
    init_lr = 0.1
    schedule = []
    for idx in range(10):
        mult = 0.1 ** idx
        schedule.append(SchedulePoint(1000 * idx, init_lr * mult))
    return schedule


class YourTrial(TensorpackTrial):
    ...

    def tensorpack_callbacks(self, hparams):
        return [
            ScheduleSetter("learning_rate", make_schedule())
        ]
Data Loading¶
A Determined user prescribes data access in TensorpackTrial
by writing a
build_training_dataflow
function (required) and
build_validation_dataflow
function (optional). These functions should return the training and validation tp.DataFlow objects, respectively.
def build_training_dataflow(self):
    ...
    return trainDataset


def build_validation_dataflow(self):
    ...
    return validDataset
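As a concrete illustration, a small training dataflow could be assembled from Tensorpack's dataflow primitives; the in-memory samples and batch size below are placeholder assumptions:
import numpy as np
from tensorpack import dataflow


def build_training_dataflow(self):
    # Placeholder in-memory dataset of (image, label) pairs.
    samples = [(np.zeros((28, 28), dtype=np.float32), 0) for _ in range(1000)]
    df = dataflow.DataFromList(samples, shuffle=True)
    # Group individual samples into batches of 32.
    return dataflow.BatchData(df, batch_size=32)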
For cases where the dataset is too large to be stored locally, Determined supports downloading data from Google Cloud Storage (GCS). To use this feature, replace code that opens a local file, e.g.:
import cv2
image = cv2.imread(filename)
with a call to GCS:
import cv2
import numpy as np
from google.cloud import storage

from determined.util import download_gcs_blob_with_backoff

c = storage.Client.create_anonymous_client()
gcs_bucket_name = "bucket_name"
bucket = c.get_bucket(gcs_bucket_name)
blob = bucket.blob(filename)
s = download_gcs_blob_with_backoff(blob)
# cv2.imdecode requires a flags argument; IMREAD_COLOR decodes a 3-channel image.
image = cv2.imdecode(np.asarray(bytearray(s), dtype=np.uint8), cv2.IMREAD_COLOR)
download_gcs_blob_with_backoff
implements a standard error handling
strategy for network applications in which a client periodically retries
a failed request with increasing delays between requests. This strategy
is suggested when reading data from GCS to handle transient network
failures and HTTP 429 and 5xx error codes.
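One way to combine this pattern with a dataflow, sketched here using Tensorpack's MapData and a hypothetical list of GCS object names, is to download and decode each image lazily as it is drawn:
import cv2
import numpy as np
from google.cloud import storage
from tensorpack import dataflow

from determined.util import download_gcs_blob_with_backoff


def make_gcs_dataflow(filenames, bucket_name):
    # `filenames` is a hypothetical list of object names in the GCS bucket.
    client = storage.Client.create_anonymous_client()
    bucket = client.get_bucket(bucket_name)

    def load_image(dp):
        blob = bucket.blob(dp[0])
        raw = download_gcs_blob_with_backoff(blob)
        image = cv2.imdecode(np.asarray(bytearray(raw), dtype=np.uint8), cv2.IMREAD_COLOR)
        return [image]

    df = dataflow.DataFromList([[name] for name in filenames], shuffle=True)
    # Apply the GCS read to each datapoint as it flows through.
    return dataflow.MapData(df, load_image)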
Subclassing Evaluator¶
Instead of defining the validation dataflow in build_validation_dataflow
,
Determined users can subclass Evaluator
to perform validation. Using
Evaluator
allows users to manually specify the validation graph, as
well as custom code for computing validation metrics. Users must implement the following abstract methods:
set_up_graph(self, trainer): Builds the validation graph. The trainer argument is an instance of tp.Trainer.

compute_validation_metrics(self): Defines the process for computing validation metrics. This function returns the validation metrics.
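A minimal sketch of a custom Evaluator, assuming Evaluator is importable from determined.tensorpack and that the metrics are returned as a mapping of names to values:
from determined.tensorpack import Evaluator  # assumed import location


class MyEvaluator(Evaluator):
    def set_up_graph(self, trainer):
        # `trainer` is a tp.Trainer; build the ops needed for validation here.
        self.trainer = trainer

    def compute_validation_metrics(self):
        # Run custom validation logic and return the resulting metrics
        # (a name-to-value mapping is assumed here).
        return {"custom_accuracy": 0.0}  # placeholder value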