class determined.pytorch.PyTorchTrial(context: determined.pytorch._pytorch_context.PyTorchTrialContext)

PyTorch trials are created by subclassing this abstract class.

A PyTorchTrial subclass can do the following:

  • Define models, optimizers, and LR schedulers.

    In __init__(), initialize models, optimizers, and LR schedulers and wrap them with the wrap_model, wrap_optimizer, and wrap_lr_scheduler methods provided by PyTorchTrialContext.

  • Run forward and backward passes.

    In train_batch(), call the backward and step_optimizer methods provided by PyTorchTrialContext. Any number of models, optimizers, and LR schedulers is supported, and forward and backward passes can be run in any order.

  • Configure automatic mixed precision.

    In __init__(), call the configure_apex_amp method provided by PyTorchTrialContext.

  • Clip gradients.

    In train_batch(), pass a function to step_optimizer(optimizer, clip_grads=...) provided by PyTorchTrialContext.

abstract __init__(context: determined.pytorch._pytorch_context.PyTorchTrialContext) → None

Initializes a trial using the provided context. The general steps are:

  1. Initialize model(s) and wrap them with context.wrap_model.

  2. Initialize optimizer(s) and wrap them with context.wrap_optimizer.

  3. Initialize learning rate schedulers and wrap them with context.wrap_lr_scheduler.

  4. If desired, pass the wrapped models and optimizers to context.configure_apex_amp to use apex.amp for automatic mixed precision.

Here is a code example.

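# Module-level imports used by this example.
import torch
from torch.optim.lr_scheduler import LambdaLR
from determined.pytorch import LRScheduler

self.context = context

# MyModelA and MyModelB are placeholders for your own torch.nn.Module classes.
self.a = self.context.wrap_model(MyModelA())
self.b = self.context.wrap_model(MyModelB())
self.opt1 = self.context.wrap_optimizer(torch.optim.Adam(self.a.parameters()))
self.opt2 = self.context.wrap_optimizer(torch.optim.Adam(self.b.parameters()))

(self.a, self.b), (self.opt1, self.opt2) = self.context.configure_apex_amp(
    models=[self.a, self.b],
    optimizers=[self.opt1, self.opt2],
)

self.lrs1 = self.context.wrap_lr_scheduler(
    lr_scheduler=LambdaLR(self.opt1, lr_lambda=lambda epoch: 0.95 ** epoch),
    step_mode=LRScheduler.StepMode.STEP_EVERY_EPOCH,
)
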
abstract train_batch(batch: Union[Dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor], epoch_idx: int, batch_idx: int) → Union[torch.Tensor, Dict[str, Any]]

Train on one batch.

Users should implement this function by doing the following things:

  1. Run forward passes on the models.

  2. Calculate the gradients of the losses with context.backward.

  3. Call an optimization step for the optimizers with context.step_optimizer. You can clip gradients by specifying the argument clip_grads.

  4. Step LR schedulers if using manual step mode.

  5. Return training metrics in a dictionary.

Here is a code example.

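# Assume two models, two optimizers, and two LR schedulers were initialized
# in ``__init__``.

# Calculate the losses using the models.
loss1 = self.model1(batch)
loss2 = self.model2(batch)

# Run backward passes on losses and step optimizers. These can happen
# in arbitrary orders.
self.context.backward(loss1)
self.context.backward(loss2)
self.context.step_optimizer(
    self.opt1,
    clip_grads=lambda params: torch.nn.utils.clip_grad_norm_(params, 0.0001),
)
self.context.step_optimizer(self.opt2)

# Step the learning rates (only needed for schedulers wrapped with MANUAL_STEP).
self.lrs1.step()
self.lrs2.step()

return {"loss1": loss1, "loss2": loss2}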

Parameters

  • batch (Dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor) – batch of data for training.

  • epoch_idx (integer) – index of the current epoch since the start of training.

  • batch_idx (integer) – index of the current batch among all the batches processed per device (slot) since the start of training.

Returns

Training metrics to return.

Return type

torch.Tensor or Dict[str, Any]

abstract build_training_data_loader() → determined.pytorch._data.DataLoader

Defines the data loader to use during training.

Must return an instance of determined.pytorch.DataLoader.

abstract build_validation_data_loader() → determined.pytorch._data.DataLoader

Defines the data loader to use during validation.

Must return an instance of determined.pytorch.DataLoader.

build_callbacks() → Dict[str, determined.pytorch._callback.PyTorchCallback]

Defines a dictionary of string names to callbacks to be used during training and/or validation.

The string name will be used as the key to save and restore callback state for any callback that defines load_state_dict() and state_dict().
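To illustrate how the names returned by build_callbacks() can act as checkpoint keys, here is a minimal pure-Python sketch. The CountingCallback class and the save_callback_states()/restore_callback_states() helpers are hypothetical and only mimic the idea; they are not Determined's checkpointing code.

```python
from typing import Any, Dict


class CountingCallback:
    """Hypothetical stateful callback that counts completed validations."""

    def __init__(self) -> None:
        self.validations_seen = 0

    def on_validation_end(self, metrics: Dict[str, Any]) -> None:
        self.validations_seen += 1

    def state_dict(self) -> Dict[str, Any]:
        return {"validations_seen": self.validations_seen}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.validations_seen = state_dict["validations_seen"]


def save_callback_states(callbacks: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
    # Hypothetical helper: each callback's name (its dict key) keys its saved state.
    return {name: cb.state_dict() for name, cb in callbacks.items()}


def restore_callback_states(
    callbacks: Dict[str, Any], saved: Dict[str, Dict[str, Any]]
) -> None:
    # Hypothetical helper: match saved state to callbacks by name.
    for name, cb in callbacks.items():
        if name in saved:
            cb.load_state_dict(saved[name])


callbacks = {"counter": CountingCallback()}
callbacks["counter"].on_validation_end({})
callbacks["counter"].on_validation_end({})
saved = save_callback_states(callbacks)

fresh = {"counter": CountingCallback()}
restore_callback_states(fresh, saved)
print(fresh["counter"].validations_seen)  # 2
```

Because this sketch matches states by name, renaming a key between runs would leave the old state unmatched.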

evaluate_batch(batch: Union[Dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor]) → Dict[str, Any]

Calculate evaluation metrics for a batch and return them as a dictionary mapping metric names to metric values.

There are two ways to specify evaluation metrics. Either override evaluate_batch() or evaluate_full_dataset(). While evaluate_full_dataset() is more flexible, evaluate_batch() should be preferred, since it can be parallelized in distributed environments, whereas evaluate_full_dataset() cannot. Only one of evaluate_full_dataset() and evaluate_batch() should be overridden by a trial.

The metrics returned from this function must be JSON-serializable.

Parameters

batch (Dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor) – batch of data for evaluating.

evaluation_reducer() → Union[determined.pytorch._reducer.Reducer, Dict[str, determined.pytorch._reducer.Reducer]]

Return a reducer for all evaluation metrics, or a dict mapping metric names to individual reducers. Defaults to determined.pytorch.Reducer.AVG.

evaluate_full_dataset(data_loader) → Dict[str, Any]

Calculate validation metrics on the entire validation dataset and return them as a dictionary mapping metric names to reduced metric values (i.e., each returned metric is the average or sum of that metric across the entire validation set).

This validation cannot be distributed and is performed on a single device, even when multiple devices (slots) are used for training. Only one of evaluate_full_dataset() and evaluate_batch() should be overridden by a trial.

The metrics returned from this function must be JSON-serializable.

Parameters

data_loader – data loader for evaluating.

class determined.pytorch.LRScheduler(scheduler: torch.optim.lr_scheduler._LRScheduler, step_mode: determined.pytorch._lr_scheduler.LRScheduler.StepMode)

Wrapper for a PyTorch LRScheduler.

This wrapper fulfills two main functions:

  1. Save and restore the learning rate when a trial is paused, preempted, etc.

  2. Step the learning rate scheduler at the configured frequency (e.g., every batch or every epoch).

class StepMode

Specifies when and how scheduler.step() should be executed.

__init__(scheduler: torch.optim.lr_scheduler._LRScheduler, step_mode: determined.pytorch._lr_scheduler.LRScheduler.StepMode)

LRScheduler constructor

Parameters

  • scheduler (torch.optim.lr_scheduler._LRScheduler) – Learning rate scheduler to be used by Determined.

  • step_mode (determined.pytorch.LRScheduler.StepMode) –

    The strategy Determined will use to call (or not call) scheduler.step().

    1. STEP_EVERY_EPOCH: Determined will call scheduler.step() after every training epoch. No arguments will be passed to step().

    2. STEP_EVERY_BATCH: Determined will call scheduler.step() after every training batch. No arguments will be passed to step().

    3. MANUAL_STEP: Determined will not call scheduler.step() at all. It is up to the user to decide when to call scheduler.step(), and whether to pass any arguments.
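The three step modes can be pictured with a small pure-Python simulation of a training loop. Everything here (the local StepMode enum, ToyScheduler, and run_loop()) is a hypothetical stand-in that only counts step() calls; Determined's real training loop also handles checkpointing, distributed training, and more.

```python
import enum


class StepMode(enum.Enum):
    # Hypothetical local mirror of the three modes described above.
    STEP_EVERY_EPOCH = 1
    STEP_EVERY_BATCH = 2
    MANUAL_STEP = 3


class ToyScheduler:
    """Stand-in for a PyTorch LR scheduler that only records step() calls."""

    def __init__(self) -> None:
        self.steps = 0

    def step(self) -> None:
        self.steps += 1


def run_loop(mode: StepMode, epochs: int, batches_per_epoch: int) -> int:
    # Hypothetical sketch of when a training loop would call scheduler.step().
    scheduler = ToyScheduler()
    for _epoch in range(epochs):
        for _batch in range(batches_per_epoch):
            # ... forward pass, backward pass, optimizer step ...
            if mode is StepMode.STEP_EVERY_BATCH:
                scheduler.step()
        if mode is StepMode.STEP_EVERY_EPOCH:
            scheduler.step()
    # MANUAL_STEP: the loop never calls step(); user code decides when to.
    return scheduler.steps


for mode in StepMode:
    print(mode.name, run_loop(mode, epochs=3, batches_per_epoch=4))
# STEP_EVERY_EPOCH 3
# STEP_EVERY_BATCH 12
# MANUAL_STEP 0
```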

class determined.pytorch.Reducer

A Reducer defines a method for reducing (aggregating) evaluation metrics. See determined.pytorch.PyTorchTrial.evaluation_reducer() for details.

class determined.tensorboard.metric_writers.pytorch.TorchWriter

TorchWriter uses PyTorch file writers and summary operations to write out tfevent files containing scalar batch metrics. It creates an instance of torch.utils.tensorboard.SummaryWriter which can be accessed via the writer field and configures the SummaryWriter to write to the correct directory inside the trial container.

Usage example:

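import numpy as np

from determined.pytorch import PyTorchTrial
from determined.tensorboard.metric_writers.pytorch import TorchWriter

class MyModel(PyTorchTrial):
    # Other required PyTorchTrial methods are omitted for brevity.
    def __init__(self, context):
        self.context = context
        self.logger = TorchWriter()

    def train_batch(self, batch, epoch_idx, batch_idx):
        # Logs a random placeholder value; replace it with a real metric.
        self.logger.writer.add_scalar('my_metric', np.random.random(), batch_idx)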

Data Loading

Loading data into PyTorchTrial models is done by defining two functions, build_training_data_loader() and build_validation_data_loader(). These functions should each return an instance of determined.pytorch.DataLoader, which behaves the same as torch.utils.data.DataLoader and is a drop-in replacement for it.

Each DataLoader is allowed to return batches with arbitrary structures of the following types, which will be fed directly to the train_batch and evaluate_batch functions:

  • np.ndarray

    np.array([[0, 0], [0, 0]])
  • torch.Tensor

    torch.Tensor([[0, 0], [0, 0]])
  • tuple of np.ndarrays or torch.Tensors

    (torch.Tensor([0, 0]), torch.Tensor([[0, 0], [0, 0]]))
  • list of np.ndarrays or torch.Tensors

    [torch.Tensor([0, 0]), torch.Tensor([[0, 0], [0, 0]])]
  • dictionary mapping strings to np.ndarrays or torch.Tensors

    {"data": torch.Tensor([[0, 0], [0, 0]]), "label": torch.Tensor([[1, 1], [1, 1]])}
  • combination of the above

    {
        "data": [
            {"sub_data1": torch.Tensor([[0, 0], [0, 0]])},
            {"sub_data2": torch.Tensor([0, 0])},
        ],
        "label": (torch.Tensor([0, 0]), torch.Tensor([[0, 0], [0, 0]])),
    }

Trial Context

determined.pytorch.PyTorchTrialContext subclasses determined.TrialContext. It provides useful methods for writing Trial subclasses.

class determined.pytorch.PyTorchTrialContext(*args: Any, **kwargs: Any)

Contains runtime information for any Determined workflow that uses the PyTorch API.

With this class, users can do the following things:

  1. Wrap PyTorch models, optimizers, and LR schedulers with their Determined-compatible counterparts using wrap_model(), wrap_optimizer(), wrap_lr_scheduler(), respectively. The Determined-compatible objects are capable of transparent distributed training, checkpointing and exporting, mixed-precision training, and gradient aggregation.

  2. Configure apex amp by calling configure_apex_amp() (optional).

  3. Calculate the gradients with backward() on a specified loss.

  4. Run an optimization step with step_optimizer().

  5. Use functionality inherited from determined.TrialContext, including getting runtime information and properly handling training data in distributed training.

backward(loss: torch.Tensor, gradient: Optional[torch.Tensor] = None, retain_graph: bool = False, create_graph: bool = False) → None

Compute the gradient of the loss tensor with respect to the graph leaves.

The arguments are used in the same way as in torch.Tensor.backward; see the PyTorch documentation of torch.Tensor.backward for details.

Note

When using distributed training, we don’t support manual gradient accumulation. That means the gradient on each parameter can only be calculated once per batch. If a parameter is associated with multiple losses, you can either call backward on only one of those losses, or set requires_grad to False on that parameter (or call requires_grad_(False) on its module) to avoid manual gradient accumulation on it. However, you can accumulate gradients across batches by setting optimizations.aggregation_frequency in the experiment configuration to a value greater than 1.

Parameters

  • gradient (Tensor or None) – Gradient w.r.t. the tensor. If it is a tensor, it will be automatically converted to a Tensor that does not require grad unless create_graph is True. None values can be specified for scalar Tensors or ones that don’t require grad. If a None value would be acceptable then this argument is optional.

  • retain_graph (bool, optional) – If False, the graph used to compute the grads will be freed. Note that in nearly all cases setting this option to True is not needed and often can be worked around in a much more efficient way. Defaults to the value of create_graph.

  • create_graph (bool, optional) – If True, a graph of the derivative will be constructed, making it possible to compute higher-order derivative products. Defaults to False.

configure_apex_amp(models: Union[torch.nn.modules.module.Module, List[torch.nn.modules.module.Module]], optimizers: Union[torch.optim.optimizer.Optimizer, List[torch.optim.optimizer.Optimizer]], enabled: Optional[bool] = True, opt_level: Optional[str] = 'O1', cast_model_type: Optional[torch.dtype] = None, patch_torch_functions: Optional[bool] = None, keep_batchnorm_fp32: Union[str, bool, None] = None, master_weights: Optional[bool] = None, loss_scale: Union[float, str, None] = None, cast_model_outputs: Optional[torch.dtype] = None, num_losses: Optional[int] = 1, verbosity: Optional[int] = 1, min_loss_scale: Optional[float] = None, max_loss_scale: Optional[float] = 16777216.0) → Tuple

Configure automatic mixed precision for your models and optimizers. Note that details for apex.amp are handled automatically within Determined after this call.

This function must be called after you have finished constructing your models and optimizers with wrap_model() and wrap_optimizer().

This function has the same arguments as apex.amp.initialize.

Note

When using distributed training and automatic mixed precision, we only support num_losses=1 and calling backward on the loss once.

Parameters

  • models (torch.nn.Module or list of torch.nn.Module s) – Model(s) to modify/cast.

  • optimizers (torch.optim.Optimizer or list of torch.optim.Optimizer s) – Optimizers to modify/cast. REQUIRED for training.

  • enabled (bool, optional, default=True) – If False, renders all Amp calls no-ops, so your script should run as if Amp were not present.

  • opt_level (str, optional, default="O1") – Pure or mixed precision optimization level. Accepted values are “O0”, “O1”, “O2”, and “O3”, explained in detail in the apex.amp documentation.

  • cast_model_type (torch.dtype, optional, default=None) – Optional property override; see the apex.amp documentation.

  • patch_torch_functions (bool, optional, default=None) – Optional property override.

  • keep_batchnorm_fp32 (bool or str, optional, default=None) – Optional property override. If passed as a string, must be the string “True” or “False”.

  • master_weights (bool, optional, default=None) – Optional property override.

  • loss_scale (float or str, optional, default=None) – Optional property override. If passed as a string, must be a string representing a number, e.g., “128.0”, or the string “dynamic”.

  • cast_model_outputs (torch.dtype, optional, default=None) – Option to ensure that the outputs of your model are always cast to a particular type regardless of opt_level.

  • num_losses (int, optional, default=1) – Option to tell Amp in advance how many losses/backward passes you plan to use. When used in conjunction with the loss_id argument to amp.scale_loss, enables Amp to use a different loss scale per loss/backward pass, which can improve stability. If num_losses is left at 1, Amp will still support multiple losses/backward passes, but use a single global loss scale for all of them.

  • verbosity (int, default=1) – Set to 0 to suppress Amp-related output.

  • min_loss_scale (float, default=None) – Sets a floor for the loss scale values that can be chosen by dynamic loss scaling. The default value of None means that no floor is imposed. If dynamic loss scaling is not used, min_loss_scale is ignored.

  • max_loss_scale (float, default=2.**24) – Sets a ceiling for the loss scale values that can be chosen by dynamic loss scaling. If dynamic loss scaling is not used, max_loss_scale is ignored.

Returns

Model(s) and optimizer(s) modified according to the opt_level. If the models or optimizers arguments were lists, the corresponding return values will also be lists.

is_epoch_end() → bool

Returns True if the current batch is the last batch of the epoch.

Note

Not accurate for variable-size epochs.

is_epoch_start() → bool

Returns True if the current batch is the first batch of the epoch.

Note

Not accurate for variable-size epochs.
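One way to see why these helpers assume fixed-size epochs: if every epoch is the same number of batches, both checks reduce to modular arithmetic on a global batch index. The functions below are a hypothetical sketch of that arithmetic, not Determined's implementation.

```python
# Hypothetical sketch: epoch boundaries for a fixed number of batches per epoch.
def is_epoch_start(batch_idx: int, batches_per_epoch: int) -> bool:
    # batch_idx counts batches since the start of training (0-based).
    return batch_idx % batches_per_epoch == 0


def is_epoch_end(batch_idx: int, batches_per_epoch: int) -> bool:
    return batch_idx % batches_per_epoch == batches_per_epoch - 1


# With 4 batches per epoch, batches 0, 4, 8 start epochs and 3, 7, 11 end them.
starts = [i for i in range(12) if is_epoch_start(i, 4)]
ends = [i for i in range(12) if is_epoch_end(i, 4)]
print(starts, ends)  # [0, 4, 8] [3, 7, 11]
```

If epochs vary in length, no single batches_per_epoch value describes every epoch, so arithmetic of this kind marks the wrong batches.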

step_optimizer(optimizer: torch.optim.optimizer.Optimizer, clip_grads: Optional[Callable[Iterator, None]] = None, auto_zero_grads: bool = True) → None

Perform a single optimization step.

This function must be called once for each optimizer. However, the order of different optimizers’ steps can be specified by calling this function in different orders. Also, gradient accumulation across iterations is performed by the Determined training loop by setting the experiment configuration field optimizations.aggregation_frequency.

Here is a code example:

import torch

def clip_grads(params):
    torch.nn.utils.clip_grad_norm_(params, 0.0001)

self.context.step_optimizer(self.opt1, clip_grads)

Parameters

  • optimizer (torch.optim.Optimizer) – Which optimizer should be stepped.

  • clip_grads (a function, optional) – This function should have one argument for parameters in order to clip the gradients.

  • auto_zero_grads (bool, optional) – Automatically zero out gradients after stepping the optimizer. If False, you need to call optimizer.zero_grad() manually. Note that if optimizations.aggregation_frequency is greater than 1, auto_zero_grads must be True.
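The clip_grads function receives the parameters and modifies their gradients in place; torch.nn.utils.clip_grad_norm_ does this by rescaling all gradients when their combined L2 norm exceeds a threshold. The pure-Python function below is a sketch of that global-norm rescaling on plain lists of floats so that it runs without PyTorch; it is an illustration, not the torch implementation.

```python
import math
from typing import List


def clip_grad_norm(grads: List[List[float]], max_norm: float, eps: float = 1e-6) -> float:
    """Sketch of global-norm clipping: rescale all gradients in place.

    Returns the total L2 norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef < 1.0:
        # Every gradient is scaled by the same factor, preserving direction.
        for grad in grads:
            for i in range(len(grad)):
                grad[i] *= clip_coef
    return total_norm


grads = [[3.0], [4.0]]  # two "parameters" with a combined norm of 5
before = clip_grad_norm(grads, max_norm=1.0)
print(before)  # 5.0
print(grads)   # approximately [[0.6], [0.8]]
```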

to_device(data: Union[Dict[str, Union[numpy.ndarray, torch.Tensor]], Sequence[Union[numpy.ndarray, torch.Tensor]], numpy.ndarray, torch.Tensor]) → Union[Dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor]

Map generated data to the device allocated by the Determined cluster.

All the data in the data loader and the models are automatically moved to the allocated device. This method is intended for data generated on the fly, such as arrays or tensors created inside train_batch().
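Conceptually, moving a batch to a device means walking its nested structure (dicts, lists, tuples) and converting each leaf. The sketch below shows that recursion with a caller-supplied leaf function; map_nested() is a hypothetical helper, and in real code the leaf function would be something like torch.as_tensor(x).to(device). Here a tuple tagging each leaf with a device string stands in for the move.

```python
from typing import Any, Callable


def map_nested(data: Any, fn: Callable[[Any], Any]) -> Any:
    """Hypothetical helper: apply fn to every leaf of a nested dict/list/tuple."""
    if isinstance(data, dict):
        return {key: map_nested(value, fn) for key, value in data.items()}
    if isinstance(data, list):
        return [map_nested(item, fn) for item in data]
    if isinstance(data, tuple):
        return tuple(map_nested(item, fn) for item in data)
    # Leaf: an np.ndarray or torch.Tensor in real code.
    return fn(data)


batch = {"data": [1, 2], "label": (3, 4)}
moved = map_nested(batch, lambda leaf: ("cuda:0", leaf))
print(moved)
# {'data': [('cuda:0', 1), ('cuda:0', 2)], 'label': (('cuda:0', 3), ('cuda:0', 4))}
```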

wrap_lr_scheduler(lr_scheduler: torch.optim.lr_scheduler._LRScheduler, step_mode: determined.pytorch._lr_scheduler.LRScheduler.StepMode) → torch.optim.lr_scheduler._LRScheduler

Returns a wrapped LR scheduler.

The LR scheduler must use an optimizer wrapped by wrap_optimizer(). If apex.amp is in use, the optimizer must also have been configured with configure_apex_amp().

wrap_model(model: torch.nn.modules.module.Module) → torch.nn.modules.module.Module

Returns a wrapped model.

wrap_optimizer(optimizer: torch.optim.optimizer.Optimizer, backward_passes_per_step: int = 1) → torch.optim.optimizer.Optimizer

Returns a wrapped optimizer.

The optimizer must use the models wrapped by wrap_model(). This function creates a horovod.DistributedOptimizer if using parallel/distributed training.

backward_passes_per_step can be used to specify how many gradient aggregation steps will be performed in a single train_batch call per optimizer step. In most cases, this will just be the default value 1. However, this advanced functionality can be used to support training loops like the one shown below:

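from typing import Dict
import torch
from determined.pytorch import TorchData

# Assumes __init__ used wrap_optimizer(..., backward_passes_per_step=2).
def train_batch(
    self, batch: TorchData, epoch_idx: int, batch_idx: int
) -> Dict[str, torch.Tensor]:
    data, labels = batch
    output = self.model(data)
    loss1 = output['loss1']
    loss2 = output['loss2']
    # Two backward passes contribute to a single optimizer step.
    self.context.backward(loss1)
    self.context.backward(loss2)
    self.context.step_optimizer(self.optimizer)
    return {"loss1": loss1, "loss2": loss2}

class determined.pytorch.PyTorchExperimentalContext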
wrap_reducer(reducer: Union[Callable, determined.pytorch._reducer.MetricReducer], name: Optional[str] = None, for_training: bool = True, for_validation: bool = True) → determined.pytorch._reducer.MetricReducer

Register a custom reducer that will calculate a metric properly, even with distributed training.

During distributed training and evaluation, many types of metrics must be calculated globally, rather than calculated on each shard of the dataset and then averaged or summed. For example, an accurate ROC AUC for a dataset cannot be derived from the individual ROC AUC metrics calculated by each worker.

Determined solves this problem by offering fully customizable metric reducers which are distributed-aware. These are registered by calling context.experimental.wrap_reducer() and are updated by the user during train_batch() or evaluate_batch().
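A small worked example of the ROC AUC claim: below, each of two hypothetical workers sees a perfectly separated shard (AUC 1.0 each), yet the AUC over the combined data is 0.75. The roc_auc() helper is a pure-Python pairwise (Mann-Whitney) formulation used only to avoid dependencies; it is not how Determined or scikit-learn compute the metric.

```python
from typing import Sequence


def roc_auc(y_true: Sequence[int], y_score: Sequence[float]) -> float:
    """ROC AUC as the fraction of (positive, negative) pairs ranked correctly."""
    pos = [s for t, s in zip(y_true, y_score) if t == 1]
    neg = [s for t, s in zip(y_true, y_score) if t == 0]
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5  # ties count as half a correct ranking
    return wins / (len(pos) * len(neg))


# Hypothetical data: two workers, each holding one shard of the validation set.
shard_a = ([0, 1], [0.1, 0.2])
shard_b = ([0, 1], [0.6, 0.7])

per_worker = [roc_auc(*shard_a), roc_auc(*shard_b)]
averaged = sum(per_worker) / len(per_worker)
global_auc = roc_auc(shard_a[0] + shard_b[0], shard_a[1] + shard_b[1])
print(averaged, global_auc)  # 1.0 0.75
```

The global value is lower because the positive example in shard A (0.2) scores below the negative example in shard B (0.6), a comparison that no single worker ever sees.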

Parameters

  • reducer (Union[Callable, pytorch.MetricReducer]) – Either a reducer function or a pytorch.MetricReducer. See below for more details.

  • name – (Optional[str] = None): Either a string name to associate with the metric returned by the reducer, or None to indicate the metric will return a dict mapping string names to metric values. This allows for a single reducer to return many metrics, such as for a per-class mean IOU calculation. Note that if name is a string, the returned metric must NOT be a dict-type metric.

  • for_training – (bool = True): Indicate that the reducer should be used for training workloads.

  • for_validation – (bool = True): Indicate that the reducer should be used for validation workloads.

Return Value:

If reducer was a function, the returned MetricReducer will have a single user-facing method like def update(value: Any) -> None that you should call during train_batch or evaluate_batch. Otherwise, the return value will just be the reducer that was passed in.

Reducer functions: the simple API

If the reducer parameter is a function, it must have the following properties:

  • It accepts a single parameter, which will be a flat list of all inputs the user passes when calling .update() on the object returned by wrap_reducer(). See the code example below for more details.

  • It returns either a single (non-dict) metric or a dictionary mapping names to metrics, as described above.

The primary motivation for passing a function as the reducer is simplicity. Metrics from all batches will be buffered in memory and passed over the network before they are reduced all at once. This introduces some overhead, but it is likely unnoticeable for scalar metrics or on validation datasets of small or medium size. This single function strategy may also be desirable for quick prototyping or for calculating metrics that are difficult or impossible to calculate incrementally.
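The buffering behavior described in the previous paragraph can be simulated in a few lines: each worker appends whatever it passes to update() to a local list, the lists are gathered into one flat list, and the reducer function runs once on the result. BufferingReducer, gather(), and reduce_all() are hypothetical stand-ins for what happens behind wrap_reducer(); the network transfer is replaced by list concatenation.

```python
import statistics
from typing import Any, Callable, List


class BufferingReducer:
    """Hypothetical per-worker object returned for a function-based reducer."""

    def __init__(self) -> None:
        self.values: List[Any] = []

    def update(self, value: Any) -> None:
        # Every update() input is buffered in memory until the workload ends.
        self.values.append(value)


def gather(workers: List[BufferingReducer]) -> List[Any]:
    # Stand-in for sending each worker's buffer to the chief: one flat list.
    return [value for worker in workers for value in worker.values]


def reduce_all(fn: Callable[[List[Any]], Any], workers: List[BufferingReducer]) -> Any:
    # The reducer function runs exactly once, on all buffered values.
    return fn(gather(workers))


workers = [BufferingReducer(), BufferingReducer()]
for v in [1, 2, 3]:
    workers[0].update(v)
for v in [4, 100, 200]:
    workers[1].update(v)

# A median also cannot be computed by averaging per-worker medians (2 and 100).
print(reduce_all(statistics.median, workers))  # 3.5
```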

For example, ROC AUC could be properly calculated by passing a small wrapper function calling sklearn.metrics.roc_auc_score:

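import numpy as np
import sklearn.metrics
from determined.pytorch import PyTorchTrial

# Custom reducer function.
def roc_auc_reducer(values):
    # values will be a flat list of all inputs to
    # .update(), which in this code example are
    # (y_true, y_score) tuples of 1-D arrays, one per batch.
    # We reshape that list into two separate sequences:
    y_trues, y_scores = zip(*values)

    # Then we concatenate the batches and return a metric value:
    return sklearn.metrics.roc_auc_score(
        np.concatenate(y_trues), np.concatenate(y_scores)
    )

class MyPyTorchTrial(PyTorchTrial):
    def __init__(self, context):
        self.context = context
        # MyModel is a placeholder for your own torch.nn.Module.
        self.model = context.wrap_model(MyModel())
        self.roc_auc = context.experimental.wrap_reducer(
            roc_auc_reducer, name="roc_auc"
        )

    def evaluate_batch(self, batch):
        # Hypothetical batch layout: (inputs, binary labels).
        data, labels = batch
        y_true = labels.detach().cpu().numpy().ravel()
        y_score = self.model(data).detach().cpu().numpy().ravel()

        # Function-based reducers are updated with .update().
        # The roc_auc_reducer function will get a list of all
        # inputs that we pass in here:
        self.roc_auc.update((y_true, y_score))

        # The "roc_auc" metric will be included in the final
        # metrics after the workload has completed; no need
        # to return it here.  If that is your only metric,
        # just return an empty dict.
        return {}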

MetricReducer objects: the advanced API

The primary motivation for passing a det.pytorch.MetricReducer as the reducer is performance. det.pytorch.MetricReducer allows the user more control in how values are stored and exposes a per_slot_reduce() call which lets users minimize the cost of the network communication before the final cross_slot_reduce().

An additional reason for using the det.pytorch.MetricReducer

For the full details and a code example, see: MetricReducer.

Gradient Clipping

Users need to pass a gradient clipping function to determined.pytorch.PyTorchTrialContext.step_optimizer() through its clip_grads argument.

Reducing Metrics

Determined supports proper reduction of arbitrary training and validation metrics, even during distributed training, by allowing users to define custom reducers. Custom reducers can be either a function or an implementation of the determined.pytorch.MetricReducer interface.

See determined.pytorch.PyTorchExperimentalContext.wrap_reducer() for more details.

class determined.pytorch.MetricReducer

Efficiently aggregating validation metrics during a multi-slot distributed trial is done in three steps:

  1. Gather all the values to be reduced during the reduction window (either a training or a validation workload). In a multi-slot trial, this is done on each slot in parallel.

  2. Calculate the per-slot reduction. This will return some intermediate value that each slot will contribute to the final metric calculation. It can be as simple as a list of all the raw values from step 1, but reducing the intermediate value locally will distribute the final metric calculation more efficiently and will reduce network communication costs.

  3. Reduce the per-slot reduction values from Step 2 into a final metric.

The MetricReducer API makes it possible for users to define a maximally efficient custom metric by exposing these steps to users:

  • Step 1 is defined by the user; it is not part of the interface. This flexibility gives the user full control when gathering individual values for reduction.

  • Step 2 is the MetricReducer.per_slot_reduce() interface.

  • Step 3 is the MetricReducer.cross_slot_reduce() interface.

  • The MetricReducer.reset() interface allows for MetricReducer reuse across many train and validation workloads.
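The three steps map onto a simple simulation with two hypothetical slots. SumCountReducer below deliberately does not subclass det.pytorch.MetricReducer so that it runs without Determined installed; it only mirrors the reset(), per_slot_reduce(), and cross_slot_reduce() method names from the interface, and the lines at the bottom play the role of Determined's orchestration.

```python
from typing import List, Sequence, Tuple


class SumCountReducer:
    """Hypothetical averaging reducer that sends two numbers per slot."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, values: Sequence[float]) -> None:
        # Step 1 (user-defined): fold each batch into a running sum and count.
        self.total += sum(values)
        self.count += len(values)

    def per_slot_reduce(self) -> Tuple[float, int]:
        # Step 2: the compact intermediate value each slot contributes.
        return self.total, self.count

    def cross_slot_reduce(self, per_slot_metrics: List[Tuple[float, int]]) -> float:
        # Step 3: combine every slot's (total, count) into the final metric.
        totals, counts = zip(*per_slot_metrics)
        return sum(totals) / sum(counts)


slots = [SumCountReducer(), SumCountReducer()]
slots[0].update([1.0, 2.0, 3.0])
slots[1].update([4.0])
per_slot = [slot.per_slot_reduce() for slot in slots]
print(per_slot)  # [(6.0, 3), (4.0, 1)]
print(slots[0].cross_slot_reduce(per_slot))  # 2.5
```

The final average weights every value equally; averaging the two per-slot means (2.0 and 4.0) would instead give 3.0.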

Example implementation and usage:

from determined import pytorch

class MyAvgMetricReducer(pytorch.MetricReducer):
    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0
        self.counts = 0

    # User-defined mechanism for collecting values throughout
    # training or validation. This update() mechanism demonstrates
    # a computationally- and memory-efficient way to store the values.
    def update(self, value):
        # value is assumed to be a sequence of numbers (e.g., one batch).
        self.sum += sum(value)
        self.counts += len(value)

    def per_slot_reduce(self):
        # Because the chosen update() mechanism is so
        # efficient, this is basically a noop.
        return self.sum, self.counts

    def cross_slot_reduce(self, per_slot_metrics):
        # per_slot_metrics is a list of (sum, counts) tuples
        # returned by the self.per_slot_reduce() on each slot
        sums, counts = zip(*per_slot_metrics)
        return sum(sums) / sum(counts)

class MyPyTorchTrial(pytorch.PyTorchTrial):
    def __init__(self, context):
        self.context = context
        # Register your custom reducer.
        self.my_avg = context.experimental.wrap_reducer(
            MyAvgMetricReducer(), name="my_avg"
        )

    def train_batch(self, batch, epoch_idx, batch_idx):
        # train_step() is a hypothetical helper that runs the forward and
        # backward passes and the optimizer step, returning the loss tensor.
        loss = self.train_step(batch)

        # You decide how/when you call update(); here, once per batch.
        self.my_avg.update([loss.item()])

        # The "my_avg" metric will be included in the final
        # metrics after the workload has completed; no need
        # to return it here.
        return {"loss": loss}

See also: determined.pytorch.PyTorchExperimentalContext.wrap_reducer().

abstract reset() → None

Reset reducer state for another set of values.

This will be called before any train or validation workload begins.

abstract per_slot_reduce() → Any

This will be called on each worker after any train or validation workload ends (even when there is only one worker).

It should return some picklable value that is meaningful for cross_slot_reduce.

abstract cross_slot_reduce(per_slot_metrics: List) → Any

This will be called after per_slot_reduce has finished (even when there is only one worker).

The per_slot_metrics will be a list containing the output of per_slot_reduce() from each worker.

The return value should be either:

  • A dict mapping string metric names to metric values, if the call to context.experimental.wrap_reducer() omitted the name parameter, or

  • A non-dict metric value if the call to context.experimental.wrap_reducer() had name set to a string (an error will be raised if a dict-type metric is returned but name was set).


To execute arbitrary Python code during the lifecycle of a PyTorchTrial, implement the callback interface:

class determined.pytorch.PyTorchCallback

Abstract base class used to define a callback that should execute during the lifetime of a PyTorchTrial.

Note

If you are defining a stateful callback (e.g., it mutates a self attribute over its lifetime), you must also override state_dict() and load_state_dict() to ensure this state can be serialized and deserialized over checkpoints.

Note

If distributed training is enabled, every GPU will execute a copy of this callback (except for on_validation_end(), on_validation_step_end() and on_checkpoint_end()). To configure a callback implementation to execute on a subset of GPUs, please condition your implementation on trial.context.distributed.get_rank().

load_state_dict(state_dict: Dict[str, Any]) → None

Load the state of this callback using the deserialized state_dict.

on_before_optimizer_step(parameters: Iterator) → None

Run before every optimizer.step(). For multi-GPU training, executes after gradient updates have been communicated. Typically used to perform gradient clipping.

Warning

This is deprecated. If you want to clip gradients, pass a function to context.step_optimizer(optimizer, clip_grads=...) instead.

on_checkpoint_end(checkpoint_dir: str) → None

Run after every checkpoint.

Note

This callback only executes on the chief GPU when doing distributed training.

on_validation_end(metrics: Dict[str, Any]) → None

Run after every validation ends.

Note

This callback only executes on the chief GPU when doing distributed training.

on_validation_start() → None

Run before every validation begins.

on_validation_step_end(metrics: Dict[str, Any]) → None

Run after every validation step ends.

Note

This callback only executes on the chief GPU when doing distributed training.

on_validation_step_start() → None

Run before every validation step begins.

state_dict() → Dict[str, Any]

Serialize the state of this callback to a dictionary. Return value must be pickle-able.


To use the torch.optim.lr_scheduler.ReduceLROnPlateau class with PyTorchTrial, implement the following callback:

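import torch
from determined.pytorch import PyTorchCallback

class ReduceLROnPlateauEveryValidationStep(PyTorchCallback):
    def __init__(self, context):
        self.reduce_lr = torch.optim.lr_scheduler.ReduceLROnPlateau(
            context.get_optimizer(), "min", verbose=True
        )  # customize arguments as desired here

    def on_validation_end(self, metrics):
        # "validation_loss" is a placeholder; use a metric your trial reports.
        self.reduce_lr.step(metrics["validation_loss"])

    def state_dict(self):
        return self.reduce_lr.state_dict()

    def load_state_dict(self, state_dict):
        self.reduce_lr.load_state_dict(state_dict)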

Then, implement the build_callbacks function in PyTorchTrial:

def build_callbacks(self):
    return {"reduce_lr": ReduceLROnPlateauEveryValidationStep(self.context)}

Migration from deprecated interface

The current PyTorch interface is designed to be flexible and to support multiple models, optimizers, and LR schedulers. The ability to run forward and backward passes in an arbitrary order affords users much greater flexibility compared to the deprecated approach used in Determined 0.12.12 and earlier.

To migrate from the previous PyTorch API, please change the following places in your code:

  1. Wrap models, optimizers, and LR schedulers in the __init__() method with the wrap_model, wrap_optimizer, and wrap_lr_scheduler methods that are provided by PyTorchTrialContext. At the same time, remove your implementations of build_model(), optimizer(), and create_lr_scheduler().

  2. If using automatic mixed precision (AMP), configure Apex AMP in the __init__ method with the context.configure_apex_amp method. At the same time, remove the experiment configuration field optimizations.mixed_precision.

  3. Run backward passes on losses and step optimizers in the train_batch() method with the backward and step_optimizer methods provided by PyTorchTrialContext. Clip gradients by passing a function to the clip_grads argument of step_optimizer, and remove the corresponding PyTorchCallback (its on_before_optimizer_step() clipping) from the build_callbacks() method.