class determined.pytorch.PyTorchTrial(context: determined.pytorch._pytorch_context.PyTorchTrialContext)

PyTorch trials are created by subclassing this abstract class.

A trial class can do the following:

  • Define models, optimizers, and LR schedulers.

    In __init__(), initialize models, optimizers, and LR schedulers and wrap them with the wrap_model, wrap_optimizer, and wrap_lr_scheduler methods provided by PyTorchTrialContext.

  • Run forward and backward passes.

    In train_batch(), call backward and step_optimizer, both provided by PyTorchTrialContext. Any number of models, optimizers, and LR schedulers is supported, and forward and backward passes may run in any order.

  • Configure automatic mixed precision.

    In __init__(), call configure_apex_amp, provided by PyTorchTrialContext.

  • Clip gradients.

    In train_batch(), pass a function to step_optimizer(optimizer, clip_grads=...), provided by PyTorchTrialContext.

abstract __init__(context: determined.pytorch._pytorch_context.PyTorchTrialContext) → None

Initializes a trial using the provided context. The general steps are:

  1. Initialize model(s) and wrap them with context.wrap_model.

  2. Initialize optimizer(s) and wrap them with context.wrap_optimizer.

  3. Initialize learning rate schedulers and wrap them with context.wrap_lr_scheduler.

  4. If desired, wrap models and optimizers with context.configure_apex_amp to use apex.amp for automatic mixed precision.

Here is a code example.

self.context = context

self.a = self.context.wrap_model(MyModelA())
self.b = self.context.wrap_model(MyModelB())
# Optimizers take the wrapped model's parameters, not the model itself.
self.opt1 = self.context.wrap_optimizer(torch.optim.Adam(self.a.parameters()))
self.opt2 = self.context.wrap_optimizer(torch.optim.Adam(self.b.parameters()))

(self.a, self.b), (self.opt1, self.opt2) = self.context.configure_apex_amp(
    models=[self.a, self.b],
    optimizers=[self.opt1, self.opt2],
)

self.lrs1 = self.context.wrap_lr_scheduler(
    lr_scheduler=LambdaLR(self.opt1, lr_lambda=lambda epoch: 0.95 ** epoch),
    step_mode=LRScheduler.StepMode.STEP_EVERY_EPOCH,
)

abstract train_batch(batch: Union[Dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor], epoch_idx: int, batch_idx: int) → Union[torch.Tensor, Dict[str, Any]]

Train on one batch.

Users should implement this function by doing the following things:

  1. Run forward passes on the models.

  2. Compute gradients from the losses with context.backward.

  3. Call an optimization step for the optimizers with context.step_optimizer. You can clip gradients by specifying the argument clip_grads.

  4. Step LR schedulers if using manual step mode.

  5. Return training metrics in a dictionary.

Here is a code example.

# Assume two models, two optimizers, and two LR schedulers were initialized
# in ``__init__``.

# Calculate the losses using the models.
loss1 = self.model1(batch)
loss2 = self.model2(batch)

# Run backward passes on losses and step optimizers. These can happen
# in arbitrary orders.
self.context.backward(loss1)
self.context.backward(loss2)
self.context.step_optimizer(
    self.opt1,
    clip_grads=lambda params: torch.nn.utils.clip_grad_norm_(params, 0.0001),
)
self.context.step_optimizer(self.opt2)

# Step the learning rate.
self.lrs1.step()
self.lrs2.step()

return {"loss1": loss1, "loss2": loss2}

Parameters

  • batch (Dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor) – batch of data for training.

  • epoch_idx (integer) – index of the current epoch, counted per device (slot) since the start of training.

  • batch_idx (integer) – index of the current batch, counted across all epochs processed per device (slot) since the start of training.

Returns

training metrics to return.

Return type

torch.Tensor or Dict[str, Any]

abstract build_training_data_loader() → determined.pytorch._data.DataLoader

Defines the data loader to use during training.

Must return an instance of determined.pytorch.DataLoader.

abstract build_validation_data_loader() → determined.pytorch._data.DataLoader

Defines the data loader to use during validation.

Must return an instance of determined.pytorch.DataLoader.

build_callbacks() → Dict[str, determined.pytorch._callback.PyTorchCallback]

Defines a dictionary of string names to callbacks to be used during training and/or validation.

The string name will be used as the key to save and restore callback state for any callback that defines load_state_dict() and state_dict().

evaluate_batch(batch: Union[Dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor]) → Dict[str, Any]

Calculate evaluation metrics for a batch and return them as a dictionary mapping metric names to metric values.

There are two ways to specify evaluation metrics. Either override evaluate_batch() or evaluate_full_dataset(). While evaluate_full_dataset() is more flexible, evaluate_batch() should be preferred, since it can be parallelized in distributed environments, whereas evaluate_full_dataset() cannot. Only one of evaluate_full_dataset() and evaluate_batch() should be overridden by a trial.

The metrics returned from this function must be JSON-serializable.


Parameters

  • batch (Dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor) – batch of data for evaluating.
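
The JSON requirement above usually means converting tensor-like values to plain Python numbers before returning them. The following is a minimal plain-Python sketch, not Determined code: FakeScalar and to_json_safe are hypothetical names, and FakeScalar only stands in for any object that exposes .item() the way a zero-dimensional tensor does.

```python
import json


class FakeScalar:
    """Hypothetical stand-in for a 0-d tensor; only .item() is modelled."""

    def __init__(self, value):
        self._value = value

    def item(self):
        return self._value


def to_json_safe(metrics):
    """Return a copy of ``metrics`` whose values are plain Python floats."""
    safe = {}
    for name, value in metrics.items():
        # Tensor-like objects expose .item(); plain numbers do not.
        if hasattr(value, "item"):
            value = value.item()
        safe[name] = float(value)
    return safe


raw = {"accuracy": FakeScalar(0.75), "loss": 1}
safe = to_json_safe(raw)
encoded = json.dumps(safe, sort_keys=True)  # succeeds once values are floats
```

Calling json.dumps(raw) directly would raise TypeError, because FakeScalar is not a type the json module knows how to encode.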

evaluation_reducer() → Union[determined.pytorch._reducer.Reducer, Dict[str, determined.pytorch._reducer.Reducer]]

Return a reducer for all evaluation metrics, or a dict mapping metric names to individual reducers. Defaults to determined.pytorch.Reducer.AVG.
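
To picture what a reducer does with per-batch results, here is a small plain-Python sketch. It is an illustration only and does not use determined.pytorch.Reducer: reduce_metrics and the string labels "avg" and "sum" are hypothetical, chosen to mirror the default averaging behaviour and the per-metric dictionary form described above.

```python
def reduce_metrics(per_batch, reducers, default="avg"):
    """Combine a list of per-batch metric dicts into one dict.

    ``reducers`` maps metric names to "avg" or "sum"; metrics that are
    not listed fall back to ``default``.
    """
    reduced = {}
    for name in per_batch[0]:
        values = [metrics[name] for metrics in per_batch]
        how = reducers.get(name, default)
        if how == "avg":
            reduced[name] = sum(values) / len(values)
        elif how == "sum":
            reduced[name] = sum(values)
        else:
            raise ValueError(f"unknown reducer: {how}")
    return reduced


batches = [
    {"loss": 0.5, "num_correct": 3},
    {"loss": 0.25, "num_correct": 5},
]
# "loss" uses the default average; "num_correct" is summed.
reduced = reduce_metrics(batches, {"num_correct": "sum"})
```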

evaluate_full_dataset(data_loader: torch.utils.data.dataloader.DataLoader) → Dict[str, Any]

Calculate validation metrics on the entire validation dataset and return them as a dictionary mapping metric names to reduced metric values (i.e., each returned metric is the average or sum of that metric across the entire validation set).

This validation cannot be distributed and is performed on a single device, even when multiple devices (slots) are used for training. Only one of evaluate_full_dataset() and evaluate_batch() should be overridden by a trial.

The metrics returned from this function must be JSON-serializable.


Parameters

  • data_loader (torch.utils.data.DataLoader) – data loader for evaluating.

class determined.pytorch.LRScheduler(scheduler: torch.optim.lr_scheduler._LRScheduler, step_mode: determined.pytorch._lr_scheduler.LRScheduler.StepMode)

Wrapper for a PyTorch LRScheduler.

This wrapper fulfills two main functions:

  1. Save and restore the learning rate when a trial is paused, preempted, etc.

  2. Step the learning rate scheduler at the configured frequency (e.g., every batch or every epoch).

class StepMode

Specifies when and how scheduler.step() should be executed.

__init__(scheduler: torch.optim.lr_scheduler._LRScheduler, step_mode: determined.pytorch._lr_scheduler.LRScheduler.StepMode)

LRScheduler constructor.

Parameters

  • scheduler (torch.optim.lr_scheduler._LRScheduler) – Learning rate scheduler to be used by Determined.

  • step_mode (determined.pytorch.LRScheduler.StepMode) –

    The strategy Determined will use to call (or not call) scheduler.step().

    1. STEP_EVERY_EPOCH: Determined will call scheduler.step() after every training epoch. No arguments will be passed to step().

    2. STEP_EVERY_BATCH: Determined will call scheduler.step() after every training batch. No arguments will be passed to step().

    3. MANUAL_STEP: Determined will not call scheduler.step() at all. It is up to the user to decide when to call scheduler.step(), and whether to pass any arguments.
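
The three step modes differ only in how often scheduler.step() is called automatically. The sketch below simulates that in plain Python, under stated assumptions: the StepMode enum here is a hypothetical stand-in (the real enum's values are not given in this document), count_auto_steps and learning_rate are illustrative helpers, and the 0.95 decay factor is borrowed from the LambdaLR example in __init__().

```python
from enum import Enum


class StepMode(Enum):
    """Hypothetical stand-in for LRScheduler.StepMode."""

    STEP_EVERY_EPOCH = 1
    STEP_EVERY_BATCH = 2
    MANUAL_STEP = 3


def count_auto_steps(step_mode, num_epochs, batches_per_epoch):
    """Count how many times the framework would call scheduler.step()."""
    steps = 0
    for _epoch in range(num_epochs):
        for _batch in range(batches_per_epoch):
            if step_mode is StepMode.STEP_EVERY_BATCH:
                steps += 1
        if step_mode is StepMode.STEP_EVERY_EPOCH:
            steps += 1
    # MANUAL_STEP never steps automatically; the user calls step() instead.
    return steps


def learning_rate(base_lr, steps, decay=0.95):
    """Learning rate after ``steps`` scheduler steps with lr_lambda = decay ** n."""
    return base_lr * decay ** steps
```

With 3 epochs of 10 batches, the same lr_lambda therefore decays 3 times under STEP_EVERY_EPOCH and 30 times under STEP_EVERY_BATCH, so the step mode should match the unit the schedule was written for.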

class determined.pytorch.Reducer

A Reducer defines a method for reducing (aggregating) evaluation metrics. See determined.pytorch.PyTorchTrial.evaluation_reducer() for details.

class determined.tensorboard.metric_writers.pytorch.TorchWriter

TorchWriter uses PyTorch file writers and summary operations to write out tfevent files containing scalar batch metrics. It creates an instance of torch.utils.tensorboard.SummaryWriter which can be accessed via the writer field and configures the SummaryWriter to write to the correct directory inside the trial container.

Usage example:

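import numpy as np

from determined.pytorch import PyTorchTrial
from determined.tensorboard.metric_writers.pytorch import TorchWriter

class MyModel(PyTorchTrial):
    def __init__(self, context):
        # TorchWriter exposes the underlying SummaryWriter as ``writer``.
        self.logger = TorchWriter()

    def train_batch(self, batch, epoch_idx, batch_idx):
        # Log a scalar under the tag "my_metric" at the current batch index.
        self.logger.writer.add_scalar('my_metric', np.random.random(), batch_idx)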

Data Loading

Loading data into PyTorchTrial models is done by defining two functions, build_training_data_loader() and build_validation_data_loader(). Each function should return an instance of determined.pytorch.DataLoader, which behaves the same as torch.utils.data.DataLoader and is a drop-in replacement for it.

Each DataLoader is allowed to return batches with arbitrary structures of the following types, which will be fed directly to the train_batch and evaluate_batch functions:

  • np.ndarray

    np.array([[0, 0], [0, 0]])
  • torch.Tensor

    torch.Tensor([[0, 0], [0, 0]])
  • tuple of np.ndarrays or torch.Tensors

    (torch.Tensor([0, 0]), torch.Tensor([[0, 0], [0, 0]]))
  • list of np.ndarrays or torch.Tensors

    [torch.Tensor([0, 0]), torch.Tensor([[0, 0], [0, 0]])]
  • dictionary mapping strings to np.ndarrays or torch.Tensors

    {"data": torch.Tensor([[0, 0], [0, 0]]), "label": torch.Tensor([[1, 1], [1, 1]])}
  • combination of the above

    {
        "data": [
            {"sub_data1": torch.Tensor([[0, 0], [0, 0]])},
            {"sub_data2": torch.Tensor([0, 0])},
        ],
        "label": (torch.Tensor([0, 0]), torch.Tensor([[0, 0], [0, 0]])),
    }

Trial Context

determined.pytorch.PyTorchTrialContext subclasses determined.TrialContext. It provides useful methods for writing Trial subclasses.

class determined.pytorch.PyTorchTrialContext(*args: Any, **kwargs: Any)

Contains runtime information for any Determined workflow that uses the PyTorch API.

With this class, users can do the following things:

  1. Wrap PyTorch models, optimizers, and LR schedulers with their Determined-compatible counterparts using wrap_model(), wrap_optimizer(), wrap_lr_scheduler(), respectively. The Determined-compatible objects are capable of transparent distributed training, checkpointing and exporting, mixed-precision training, and gradient aggregation.

  2. Configure apex amp by calling configure_apex_amp() (optional).

  3. Calculate the gradients with backward() on a specified loss.

  4. Run an optimization step with step_optimizer().

  5. Functionalities inherited from determined.TrialContext, including getting the runtime information and properly handling training data in distributed training.

backward(loss: torch.Tensor, gradient: Optional[torch.Tensor] = None, retain_graph: bool = False, create_graph: bool = False) → None

Compute the gradient of current tensor w.r.t. graph leaves.

The arguments are used in the same way as torch.Tensor.backward. See the torch.Tensor.backward documentation for details.


Manual gradient accumulation is not supported in distributed training. The gradient on each parameter can therefore be calculated only once per batch. If a parameter is associated with multiple losses, either call backward on only one of those losses, or set the requires_grad flag of the parameter or module to False to avoid manual gradient accumulation on that parameter. Gradient accumulation across batches is still available by setting optimizations.aggregation_frequency in the experiment configuration to a value greater than 1.

Parameters

  • gradient (Tensor or None) – Gradient w.r.t. the tensor. If it is a tensor, it will be automatically converted to a Tensor that does not require grad unless create_graph is True. None values can be specified for scalar Tensors or ones that don’t require grad. If a None value would be acceptable then this argument is optional.

  • retain_graph (bool, optional) – If False, the graph used to compute the grads will be freed. Note that in nearly all cases setting this option to True is not needed and often can be worked around in a much more efficient way. Defaults to the value of create_graph.

  • create_graph (bool, optional) – If True, graph of the derivative will be constructed, allowing to compute higher order derivative products. Defaults to False.

configure_apex_amp(models: Union[torch.nn.modules.module.Module, List[torch.nn.modules.module.Module]], optimizers: Union[torch.optim.optimizer.Optimizer, List[torch.optim.optimizer.Optimizer]], enabled: Optional[bool] = True, opt_level: Optional[str] = 'O1', cast_model_type: Optional[torch.dtype] = None, patch_torch_functions: Optional[bool] = None, keep_batchnorm_fp32: Union[str, bool, None] = None, master_weights: Optional[bool] = None, loss_scale: Union[float, str, None] = None, cast_model_outputs: Optional[torch.dtype] = None, num_losses: Optional[int] = 1, verbosity: Optional[int] = 1, min_loss_scale: Optional[float] = None, max_loss_scale: Optional[float] = 16777216.0) → Tuple

Configure automatic mixed precision for your models and optimizers. Note that details for apex.amp are handled automatically within Determined after this call.

This function must be called after you have finished constructing your models and optimizers with wrap_model() and wrap_optimizer().

This function has the same arguments as apex.amp.initialize.


When using distributed training and automatic mixed precision, we only support num_losses=1 and calling backward on the loss once.

Parameters

  • models (torch.nn.Module or list of torch.nn.Module s) – Model(s) to modify/cast.

  • optimizers (torch.optim.Optimizer or list of torch.optim.Optimizer s) – Optimizers to modify/cast. REQUIRED for training.

  • enabled (bool, optional, default=True) – If False, renders all Amp calls no-ops, so your script should run as if Amp were not present.

  • opt_level (str, optional, default="O1") – Pure or mixed precision optimization level. Accepted values are “O0”, “O1”, “O2”, and “O3”, explained in detail in the apex.amp documentation.

  • cast_model_type (torch.dtype, optional, default=None) – Optional property override; see the apex.amp documentation.

  • patch_torch_functions (bool, optional, default=None) – Optional property override.

  • keep_batchnorm_fp32 (bool or str, optional, default=None) – Optional property override. If passed as a string, must be the string “True” or “False”.

  • master_weights (bool, optional, default=None) – Optional property override.

  • loss_scale (float or str, optional, default=None) – Optional property override. If passed as a string, must be a string representing a number, e.g., “128.0”, or the string “dynamic”.

  • cast_model_outputs (torch.dtype, optional, default=None) – Option to ensure that the outputs of your model are always cast to a particular type regardless of opt_level.

  • num_losses (int, optional, default=1) – Option to tell Amp in advance how many losses/backward passes you plan to use. When used in conjunction with the loss_id argument to amp.scale_loss, enables Amp to use a different loss scale per loss/backward pass, which can improve stability. If num_losses is left to 1, Amp will still support multiple losses/backward passes, but use a single global loss scale for all of them.

  • verbosity (int, default=1) – Set to 0 to suppress Amp-related output.

  • min_loss_scale (float, default=None) – Sets a floor for the loss scale values that can be chosen by dynamic loss scaling. The default value of None means that no floor is imposed. If dynamic loss scaling is not used, min_loss_scale is ignored.

  • max_loss_scale (float, default=2.**24) – Sets a ceiling for the loss scale values that can be chosen by dynamic loss scaling. If dynamic loss scaling is not used, max_loss_scale is ignored.


Returns

Model(s) and optimizer(s) modified according to the opt_level. If the models or optimizers arguments were lists, the corresponding return value will also be a list.

is_epoch_end() → bool

Returns true if the current batch is the last batch of the epoch.


Not accurate for variable size epochs.

is_epoch_start() → bool

Returns true if the current batch is the first batch of the epoch.


Not accurate for variable size epochs.
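
The caveat about variable-size epochs is easier to see with a sketch of how such checks can be derived from a batch index. This is an assumption about the general approach, not Determined's actual implementation: both helper functions and the batches_per_epoch parameter are hypothetical.

```python
def is_epoch_start(batch_idx, batches_per_epoch):
    """True when ``batch_idx`` is the first batch of an epoch."""
    return batch_idx % batches_per_epoch == 0


def is_epoch_end(batch_idx, batches_per_epoch):
    """True when ``batch_idx`` is the last batch of an epoch."""
    return batch_idx % batches_per_epoch == batches_per_epoch - 1


# With a fixed epoch length of 4 batches, indices 0 and 4 start an epoch
# and indices 3 and 7 end one.
starts = [i for i in range(8) if is_epoch_start(i, 4)]
ends = [i for i in range(8) if is_epoch_end(i, 4)]
```

Any calculation of this kind relies on a single, fixed epoch length, which is why the result cannot be trusted when the number of batches varies from epoch to epoch.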

step_optimizer(optimizer: torch.optim.optimizer.Optimizer, clip_grads: Optional[Callable[[Iterator], None]] = None, auto_zero_grads: bool = True) → None

Perform a single optimization step.

This function must be called once for each optimizer. However, the order of different optimizers’ steps can be specified by calling this function in different orders. Also, gradient accumulation across iterations is performed by the Determined training loop by setting the experiment configuration field optimizations.aggregation_frequency.

Here is a code example:

def clip_grads(params):
    torch.nn.utils.clip_grad_norm_(params, 0.0001)

self.context.step_optimizer(self.opt1, clip_grads)

Parameters

  • optimizer (torch.optim.Optimizer) – Which optimizer should be stepped.

  • clip_grads (a function, optional) – A function that takes the parameters as its single argument and clips their gradients.

  • auto_zero_grads (bool, optional) – Automatically zero out gradients after stepping the optimizer. If false, you need to call optimizer.zero_grad() manually. Note that if optimizations.aggregation_frequency is greater than 1, auto_zero_grads must be true.
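
The effect of optimizations.aggregation_frequency on stepping can be pictured with a toy training loop. This is a plain-Python sketch under stated assumptions, not the Determined training loop: train_with_aggregation is a hypothetical function, the model is a single scalar weight updated by plain SGD, and averaging the accumulated gradient before the step is an assumption made for the illustration.

```python
def train_with_aggregation(grads, aggregation_frequency, lr=0.1, weight=1.0):
    """Apply one SGD update per ``aggregation_frequency`` batches.

    ``grads`` holds one precomputed gradient per batch.
    Returns the final weight and the number of optimizer steps taken.
    """
    accumulated = 0.0
    steps = 0
    for batch_number, grad in enumerate(grads, start=1):
        accumulated += grad  # gradients add up between optimizer steps
        if batch_number % aggregation_frequency == 0:
            # Assumption: the accumulated gradient is averaged before the update.
            weight -= lr * accumulated / aggregation_frequency
            accumulated = 0.0  # zero the gradients after stepping
            steps += 1
    return weight, steps


final_weight, num_steps = train_with_aggregation([1.0, 3.0, 2.0, 2.0], 2)
```

Four batches with a frequency of 2 yield two optimizer steps. Gradients must be cleared after each step for the next group of batches to start from zero, which is consistent with the requirement that auto_zero_grads be true when aggregation is in use.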

to_device(data: Union[Dict[str, Union[numpy.ndarray, torch.Tensor]], Sequence[Union[numpy.ndarray, torch.Tensor]], numpy.ndarray, torch.Tensor]) → Union[Dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor]

Map generated data to the device allocated by the Determined cluster.

All the data in the data loader and the models are automatically moved to the allocated device. This method aims at providing a function for the data generated on the fly.
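
Because batches may be arbitrarily nested dictionaries, lists, and tuples, moving data to a device amounts to applying one operation to every leaf of the structure. The following plain-Python sketch shows that traversal only; it is not the to_device implementation. map_leaves is a hypothetical helper, and strings stand in for tensors so that the example runs without PyTorch.

```python
def map_leaves(data, fn):
    """Apply ``fn`` to every leaf of a nested dict/list/tuple structure."""
    if isinstance(data, dict):
        return {key: map_leaves(value, fn) for key, value in data.items()}
    if isinstance(data, (list, tuple)):
        # Rebuild the same container type (list or tuple) around mapped items.
        return type(data)(map_leaves(item, fn) for item in data)
    return fn(data)


batch = {
    "data": [{"sub_data1": "x"}, {"sub_data2": "y"}],
    "label": ("l1", "l2"),
}
# Stand-in for moving a tensor to a device: tag each leaf with a device name.
moved = map_leaves(batch, lambda leaf: f"{leaf}@gpu")
```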

wrap_lr_scheduler(lr_scheduler: torch.optim.lr_scheduler._LRScheduler, step_mode: determined.pytorch._lr_scheduler.LRScheduler.StepMode) → torch.optim.lr_scheduler._LRScheduler

Returns a wrapped LR scheduler.

The LR scheduler must use an optimizer wrapped by wrap_optimizer(). If apex.amp is in use, the optimizer must also have been configured with configure_apex_amp().

wrap_model(model: torch.nn.modules.module.Module) → torch.nn.modules.module.Module

Returns a wrapped model.

wrap_optimizer(optimizer: torch.optim.optimizer.Optimizer) → torch.optim.optimizer.Optimizer

Returns a wrapped optimizer.

The optimizer must use the models wrapped by wrap_model(). This function creates a horovod.DistributedOptimizer if using parallel/distributed training.

Gradient Clipping

Users need to pass a gradient clipping function to determined.pytorch.PyTorchTrialContext.step_optimizer().
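
For readers unfamiliar with norm-based clipping, the sketch below shows the arithmetic that a function such as torch.nn.utils.clip_grad_norm_ performs, written in plain Python over a flat list of gradient values. It is an illustration rather than the PyTorch implementation: clip_grad_norm is a hypothetical helper, it returns a new list instead of modifying gradients in place, and the small epsilon is included only to avoid division by zero.

```python
import math


def clip_grad_norm(grads, max_norm, eps=1e-6):
    """Scale ``grads`` so their L2 norm does not exceed ``max_norm``.

    Returns the (possibly rescaled) gradients and the original norm.
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    scale = max_norm / (total_norm + eps)
    if scale < 1.0:
        # Only shrink gradients; never scale them up.
        grads = [g * scale for g in grads]
    return grads, total_norm


clipped, norm_before = clip_grad_norm([3.0, 4.0], max_norm=1.0)
untouched, _ = clip_grad_norm([0.3, 0.4], max_norm=1.0)
```

The first call rescales a gradient of norm 5 down to a norm of roughly 1, while the second leaves the gradient unchanged because its norm is already below the limit.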


Callbacks

To execute arbitrary Python code during the lifecycle of a PyTorchTrial, implement the callback interface:

class determined.pytorch.PyTorchCallback

Abstract base class used to define a callback that should execute during the lifetime of a PyTorchTrial.


If you are defining a stateful callback (e.g., it mutates a self attribute over its lifetime), you must also override state_dict() and load_state_dict() to ensure this state can be serialized and deserialized over checkpoints.


If distributed training is enabled, every GPU will execute a copy of this callback (except for on_validation_end(), on_validation_step_end() and on_checkpoint_end()). To configure a callback implementation to execute on a subset of GPUs, please condition your implementation on trial.context.distributed.get_rank().

load_state_dict(state_dict: Dict[str, Any]) → None

Load the state of this callback using the deserialized state_dict.

on_before_optimizer_step(parameters: Iterator) → None

Run before every optimizer.step(). For multi-GPU training, executes after gradient updates have been communicated. Typically used to perform gradient clipping.


This is deprecated. To clip gradients, pass a function to context.step_optimizer(optimizer, clip_grads=...) instead.

on_checkpoint_end(checkpoint_dir: str) → None

Run after every checkpoint.


This callback only executes on the chief GPU when doing distributed training.

on_validation_end(metrics: Dict[str, Any]) → None

Run after every validation ends.


This callback only executes on the chief GPU when doing distributed training.

on_validation_start() → None

Run before every validation begins.

on_validation_step_end(metrics: Dict[str, Any]) → None

Run after every validation step ends.


This callback only executes on the chief GPU when doing distributed training.

on_validation_step_start() → None

Run before every validation step begins.

state_dict() → Dict[str, Any]

Serialize the state of this callback to a dictionary. Return value must be pickle-able.
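
A stateful callback can be sketched as follows. This is a minimal illustration, assuming a callback that only counts validations: CountingCallback is a hypothetical class, and it is written without subclassing PyTorchCallback so that it runs on its own. The pickle round trip stands in for saving and restoring a checkpoint.

```python
import pickle


class CountingCallback:
    """Hypothetical callback; a real one would subclass PyTorchCallback."""

    def __init__(self):
        self.validations_seen = 0

    def on_validation_end(self, metrics):
        self.validations_seen += 1  # mutable state that must survive restarts

    def state_dict(self):
        # Plain dict of built-in types, so it can be pickled.
        return {"validations_seen": self.validations_seen}

    def load_state_dict(self, state_dict):
        self.validations_seen = state_dict["validations_seen"]


callback = CountingCallback()
callback.on_validation_end({"loss": 0.5})
callback.on_validation_end({"loss": 0.4})

# Simulate a checkpoint: serialize the state, then restore it into a new object.
blob = pickle.dumps(callback.state_dict())
restored = CountingCallback()
restored.load_state_dict(pickle.loads(blob))
```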


To use the torch.optim.lr_scheduler.ReduceLROnPlateau class with PyTorchTrial, implement the following callback:

class ReduceLROnPlateauEveryValidationStep(PyTorchCallback):
    def __init__(self, context):
        self.reduce_lr = torch.optim.lr_scheduler.ReduceLROnPlateau(
            context.get_optimizer(), "min", verbose=True
        )  # customize arguments as desired here

    def on_validation_end(self, metrics):
        # "validation_loss" is a placeholder name; use a metric your trial reports.
        self.reduce_lr.step(metrics["validation_loss"])

    def state_dict(self):
        return self.reduce_lr.state_dict()

    def load_state_dict(self, state_dict):
        self.reduce_lr.load_state_dict(state_dict)
Then, implement the build_callbacks function in PyTorchTrial:

def build_callbacks(self):
    return {"reduce_lr": ReduceLROnPlateauEveryValidationStep(self.context)}

Migration from deprecated interface

The current PyTorch interface is designed to be flexible and to support multiple models, optimizers, and LR schedulers. The ability to run forward and backward passes in an arbitrary order affords users much greater flexibility compared to the deprecated approach used in Determined 0.12.12 and earlier.

To migrate from the previous PyTorch API, please change the following places in your code:

  1. Wrap models, optimizers, and LR schedulers in the __init__() method with the wrap_model, wrap_optimizer, and wrap_lr_scheduler methods provided by PyTorchTrialContext. At the same time, remove the implementations of build_model(), optimizer(), and create_lr_scheduler().

  2. If using automatic mixed precision (AMP), configure Apex AMP in the __init__ method with the context.configure_apex_amp method. At the same time, remove the experiment configuration field optimizations.mixed_precision.

  3. Run backward passes on losses and step optimizers in the train_batch() method with the backward and step_optimizer methods provided by PyTorchTrialContext. Clip gradients by passing a function to the clip_grads argument of step_optimizer while removing the PyTorchCallback counterpart in the build_callbacks() method.