det.pytorch API Reference
Determined offers a PyTorch-based training loop that is fully integrated with the Determined platform. It includes:
PyTorchTrial, which you must subclass to define things like model architecture, optimizer, data loaders, and how to train or validate a single batch.
PyTorchTrialContext, which can be accessed from within PyTorchTrial and contains runtime methods used for training with the PyTorch API.
Trainer, which is used for customizing and executing the training loop around a PyTorchTrial.
determined.pytorch.PyTorchTrial
- class determined.pytorch.PyTorchTrial(context: determined.pytorch._pytorch_context.PyTorchTrialContext)
PyTorch trials are created by subclassing this abstract class.
In this trial class, you can:
Define models, optimizers, and LR schedulers. In the __init__() method, initialize models, optimizers, and LR schedulers and wrap them with wrap_model, wrap_optimizer, and wrap_lr_scheduler, provided by PyTorchTrialContext.
Run forward and backward passes. In train_batch(), call backward and step_optimizer, provided by PyTorchTrialContext. We support arbitrary numbers of models, optimizers, and LR schedulers, and forward and backward passes can run in any order.
Configure automatic mixed precision. In the __init__() method, call configure_apex_amp, provided by PyTorchTrialContext.
Clip gradients. In train_batch(), pass a function into step_optimizer(optimizer, clip_grads=...), provided by PyTorchTrialContext.
- trial_context_class
alias of determined.pytorch._pytorch_context.PyTorchTrialContext
- abstract __init__(context: determined.pytorch._pytorch_context.PyTorchTrialContext) -> None
Initializes a trial using the provided context. The general steps are:
Initialize model(s) and wrap them with context.wrap_model.
Initialize optimizer(s) and wrap them with context.wrap_optimizer.
Initialize learning rate schedulers and wrap them with context.wrap_lr_scheduler.
If desired, wrap models and optimizers with context.configure_apex_amp to use apex.amp for automatic mixed precision.
Define custom loss functions and metric functions.
Warning
If some of your models, optimizers, or learning rate schedulers are not wrapped, trials that are paused and later continued may report metrics that differ significantly from trials that were never paused. This is because the state of the unwrapped objects may not be restored accurately or completely from the checkpoint when training resumes. When using PyTorch, this can also happen if the PyTorch API is not used correctly.
Here is a code example.
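    import torch
    from torch.optim.lr_scheduler import LambdaLR

    from determined.pytorch import LRScheduler

    # Inside the __init__() method of a PyTorchTrial subclass:
    self.context = context

    # Wrap every model and optimizer so Determined can manage and checkpoint them.
    self.a = self.context.wrap_model(MyModelA())
    self.b = self.context.wrap_model(MyModelB())
    self.opt1 = self.context.wrap_optimizer(torch.optim.Adam(self.a.parameters()))
    self.opt2 = self.context.wrap_optimizer(torch.optim.Adam(self.b.parameters()))

    # Optional: enable apex.amp mixed precision for both models.
    (self.a, self.b), (self.opt1, self.opt2) = self.context.configure_apex_amp(
        models=[self.a, self.b],
        optimizers=[self.opt1, self.opt2],
        num_losses=2,
    )

    # Decay the learning rate of opt1 by 5% after every epoch.
    self.lrs1 = self.context.wrap_lr_scheduler(
        lr_scheduler=LambdaLR(self.opt1, lr_lambda=lambda epoch: 0.95 ** epoch),
        step_mode=LRScheduler.StepMode.STEP_EVERY_EPOCH,
    )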
- abstract train_batch(batch: Union[Dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor], epoch_idx: int, batch_idx: int) -> Union[torch.Tensor, Dict[str, Any]]
Train on one batch.
Users should implement this function by doing the following things:
Run forward passes on the models.
Compute the gradients from the losses with context.backward.
Call an optimization step for the optimizers with context.step_optimizer. You can clip gradients by specifying the argument clip_grads.
Step LR schedulers if using manual step mode.
Return training metrics in a dictionary.
Here is a code example.
    import torch

    # Assume two models, two optimizers, and two LR schedulers were initialized
    # in ``__init__``.

    # Calculate the losses using the models.
    loss1 = self.model1(batch)
    loss2 = self.model2(batch)

    # Run backward passes on losses and step optimizers. These can happen
    # in arbitrary orders.
    self.context.backward(loss1)
    self.context.backward(loss2)
    self.context.step_optimizer(
        self.opt1,
        clip_grads=lambda params: torch.nn.utils.clip_grad_norm_(params, 0.0001),
    )
    self.context.step_optimizer(self.opt2)

    # Step the learning rate (only needed for schedulers in manual step mode).
    self.lrs1.step()
    self.lrs2.step()

    return {"loss1": loss1, "loss2": loss2}
- Parameters
batch (Dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor) – batch of data for training.
epoch_idx (integer) – index of the current epoch among all the epochs processed per device (slot) since the start of training.
batch_idx (integer) – index of the current batch among all the batches processed per device (slot) since the start of training.
- Returns
training metrics to return.
- Return type
torch.Tensor or Dict[str, Any]
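To illustrate how the two indices relate, the sketch below assumes that every epoch contains the same number of batches per slot (for example, the length of the training data loader). Under that assumption, the epoch index is the global batch index divided by the epoch length. The helper name epoch_index is hypothetical and not part of Determined.

```python
def epoch_index(batch_idx: int, batches_per_epoch: int) -> int:
    """Hypothetical helper: the epoch that contains a given global batch index.

    Assumes every epoch has exactly `batches_per_epoch` batches per slot.
    """
    if batches_per_epoch <= 0:
        raise ValueError("batches_per_epoch must be positive")
    return batch_idx // batches_per_epoch


# With 100 batches per epoch, global batch 250 falls in epoch 2 (zero-based).
print(epoch_index(250, 100))  # 2
```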
- abstract build_training_data_loader() -> Union[determined.pytorch._data.DataLoader, torch.utils.data.dataloader.DataLoader]
Defines the data loader to use during training.
Most implementations of determined.pytorch.PyTorchTrial will return a determined.pytorch.DataLoader here. Some use cases may not fit the assumptions of determined.pytorch.DataLoader. In that event, a bare torch.utils.data.DataLoader may be returned if the steps in the note atop Customize a Reproducible Dataset are followed.
- abstract build_validation_data_loader() -> Union[determined.pytorch._data.DataLoader, torch.utils.data.dataloader.DataLoader]
Defines the data loader to use during validation.
Users with a MapDataset will normally return a determined.pytorch.DataLoader, but users with an IterableDataset or with other advanced needs may sacrifice some Determined-managed functionality (e.g., automatic data sharding) to return a bare torch.utils.data.DataLoader following the best practices described in Customize a Reproducible Dataset.
- build_callbacks() -> Dict[str, determined.pytorch._callback.PyTorchCallback]
Defines a dictionary of string names to callbacks to be used during training and/or validation.
The string name will be used as the key to save and restore callback state for any callback that defines load_state_dict() and state_dict().
- evaluate_batch(batch: Union[Dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor], batch_idx: int) -> Dict[str, Any]
Calculate validation metrics for a batch and return them as a dictionary mapping metric names to metric values. Per-batch validation metrics are reduced (aggregated) to produce a single set of validation metrics for the entire validation set (see evaluation_reducer()).
There are two ways to specify evaluation metrics: override either evaluate_batch() or evaluate_full_dataset(). While evaluate_full_dataset() is more flexible, evaluate_batch() should be preferred, since it can be parallelized in distributed environments, whereas evaluate_full_dataset() cannot. Only one of evaluate_full_dataset() and evaluate_batch() should be overridden by a trial.
The metrics returned from this function must be JSON-serializable.
- Parameters
batch (Dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor) – batch of data for evaluating.
batch_idx (integer) – index of the current batch among all the batches processed per device (slot).
- evaluation_reducer() -> Union[determined.pytorch._reducer.Reducer, Dict[str, determined.pytorch._reducer.Reducer]]
Return a reducer for all evaluation metrics, or a dict mapping metric names to individual reducers. Defaults to determined.pytorch.Reducer.AVG.
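To make the reduction step concrete, the following sketch combines a list of per-batch metric dictionaries into one value per metric, using either a single reducer for every metric or a dict of per-metric reducers. It assumes reducer choices named AVG, SUM, MAX, and MIN and treats AVG as an unweighted mean over batches; SketchReducer and reduce_metrics are hypothetical stand-ins, not determined.pytorch.Reducer or Determined's own aggregation code.

```python
import enum
from statistics import fmean
from typing import Dict, List, Union


class SketchReducer(enum.Enum):
    """Hypothetical stand-in for a reducer enum (not determined.pytorch.Reducer)."""
    AVG = "avg"
    SUM = "sum"
    MAX = "max"
    MIN = "min"


_FUNCS = {
    SketchReducer.AVG: fmean,  # unweighted mean of the per-batch values
    SketchReducer.SUM: sum,
    SketchReducer.MAX: max,
    SketchReducer.MIN: min,
}


def reduce_metrics(
    per_batch: List[Dict[str, float]],
    reducer: Union[SketchReducer, Dict[str, SketchReducer]] = SketchReducer.AVG,
) -> Dict[str, float]:
    """Combine per-batch metrics into a single value per metric name."""
    if not per_batch:
        return {}
    result: Dict[str, float] = {}
    for name in per_batch[0]:
        # One reducer applies to every metric; a dict picks a reducer per metric.
        r = reducer[name] if isinstance(reducer, dict) else reducer
        result[name] = _FUNCS[r]([batch[name] for batch in per_batch])
    return result


batches = [{"loss": 0.5, "correct": 30.0}, {"loss": 0.3, "correct": 28.0}]
print(reduce_metrics(batches))
print(reduce_metrics(batches, {"loss": SketchReducer.AVG, "correct": SketchReducer.SUM}))
```

A per-metric dict is useful when some metrics are counts (which should be summed) and others are rates (which should be averaged).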
- evaluate_full_dataset(data_loader: torch.utils.data.dataloader.DataLoader) -> Dict[str, Any]
Calculate validation metrics on the entire validation dataset and return them as a dictionary mapping metric names to reduced metric values (i.e., each returned metric is the average or sum of that metric across the entire validation set).
This validation cannot be distributed and is performed on a single device, even when multiple devices (slots) are used for training. Only one of evaluate_full_dataset() and evaluate_batch() should be overridden by a trial.
The metrics returned from this function must be JSON-serializable.
- Parameters
data_loader (torch.utils.data.DataLoader) – data loader for evaluating.
- get_batch_length(batch: Any) -> int
Count the number of records in a given batch.
Override this method when you are using custom batch types, as produced when iterating over the determined.pytorch.DataLoader. For example, when using pytorch_geometric:
    # Extra imports:
    from determined.pytorch import DataLoader
    from torch_geometric.data.dataloader import Collater

    # Trial methods:
    def build_training_data_loader(self):
        return DataLoader(
            self.train_subset,
            batch_size=self.context.get_per_slot_batch_size(),
            collate_fn=Collater([], []),
        )

    def get_batch_length(self, batch):
        # `batch` is `torch_geometric.data.batch.Batch`.
        return batch.num_graphs
- Parameters
batch (Any) – input training or validation data batch object.
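For standard batch types, a natural default is to read the leading dimension of the first tensor in the batch. The dependency-free sketch below illustrates that idea with a hypothetical FakeTensor class that carries only a shape: it descends into the first value of a dict batch or the first element of a sequence batch and returns that tensor's leading dimension. This is an assumption about reasonable default behavior, not Determined's implementation; custom batch objects such as a pytorch_geometric Batch are exactly the case where it cannot infer the length and get_batch_length() must be overridden.

```python
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, Tuple


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a tensor: only its shape matters here."""
    shape: Tuple[int, ...]


def default_batch_length(batch: Any) -> int:
    """Sketch: count records as the leading dimension of the first tensor found."""
    if isinstance(batch, FakeTensor):
        return batch.shape[0]
    if isinstance(batch, dict) and batch:
        # Dict batches: look at the first value, e.g. {"img": ..., "label": ...}.
        return default_batch_length(next(iter(batch.values())))
    if isinstance(batch, Sequence) and not isinstance(batch, (str, bytes)) and batch:
        # Sequence batches: look at the first element, e.g. (data, labels).
        return default_batch_length(batch[0])
    raise TypeError(
        f"cannot infer the length of a {type(batch).__name__} batch; "
        "override get_batch_length() for custom batch types"
    )


print(default_batch_length({"img": FakeTensor((8, 3, 32, 32)), "label": FakeTensor((8,))}))  # 8
```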
determined.pytorch.PyTorchTrialContext
- class determined.pytorch.PyTorchTrialContext(core_context: determined.core._context.Context, trial_seed: Optional[int], hparams: Optional[Dict], slots_per_trial: int, num_gpus: int, exp_conf: Optional[Dict[str, Any]], aggregation_frequency: int, steps_completed: int, managed_training: bool, debug_enabled: bool, enable_tensorboard_logging: bool = True)
Contains runtime information for any Determined workflow that uses the PyTorch API.
With this class, users can do the following things:
Wrap PyTorch models, optimizers, and LR schedulers with their Determined-compatible counterparts using wrap_model(), wrap_optimizer(), and wrap_lr_scheduler(), respectively. The Determined-compatible objects are capable of transparent distributed training, checkpointing and exporting, mixed-precision training, and gradient aggregation.
Configure apex amp by calling configure_apex_amp() (optional).
Calculate the gradients with backward() on a specified loss.
Run an optimization step with step_optimizer().
Use functionality inherited from determined.TrialContext, including getting the runtime information and properly handling training data in distributed training.
- backward(loss: torch.Tensor, gradient: Optional[torch.Tensor] = None, retain_graph: bool = False, create_graph: bool = False) -> None
Compute the gradient of the current tensor w.r.t. graph leaves.
The arguments are used in the same way as torch.Tensor.backward. See https://pytorch.org/docs/1.4.0/_modules/torch/tensor.html#Tensor.backward for details.
Warning
When using distributed training, we don't support manual gradient accumulation. That means the gradient on each parameter can only be calculated once on each batch. If a parameter is associated with multiple losses, you can either call backward on only one of those losses, or set the requires_grad flag of that parameter (or module) to False to avoid manual gradient accumulation on it. However, you can do gradient accumulation across batches by setting optimizations.aggregation_frequency in the experiment configuration to be greater than 1.
- Parameters
gradient (Tensor or None) – Gradient w.r.t. the tensor. If it is a tensor, it will be automatically converted to a Tensor that does not require grad unless create_graph is True. None values can be specified for scalar Tensors or ones that don't require grad. If a None value would be acceptable, then this argument is optional.
retain_graph (bool, optional) – If False, the graph used to compute the grads will be freed. Note that in nearly all cases setting this option to True is not needed and often can be worked around in a much more efficient way. Defaults to the value of create_graph.
create_graph (bool, optional) – If True, the graph of the derivative will be constructed, allowing higher-order derivative products to be computed. Defaults to False.
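The warning above separates two things: computing a parameter's gradient more than once within a batch (unsupported in distributed training) and aggregating gradients across batches with optimizations.aggregation_frequency. The sketch below models only the second case with a single scalar parameter and plain SGD: each batch contributes exactly one gradient, and the optimizer steps once per aggregation_frequency batches. The helper aggregate_and_step is hypothetical, and the average flag is an assumption used to show averaging versus summing; it is not a description of Determined's internals.

```python
from typing import List


def aggregate_and_step(
    param: float,
    per_batch_grads: List[float],
    aggregation_frequency: int,
    lr: float = 0.1,
    average: bool = True,
) -> List[float]:
    """Hypothetical sketch of cross-batch gradient aggregation with plain SGD.

    Returns the parameter value after every optimizer step.
    """
    if aggregation_frequency < 1:
        raise ValueError("aggregation_frequency must be at least 1")
    history = []
    accumulated = 0.0
    for i, grad in enumerate(per_batch_grads, start=1):
        # Each batch contributes exactly one gradient (one backward per batch).
        accumulated += grad
        if i % aggregation_frequency == 0:
            step_grad = accumulated / aggregation_frequency if average else accumulated
            param -= lr * step_grad
            history.append(param)
            accumulated = 0.0  # zero the gradients after the optimizer step
    return history


# Four batches aggregated in pairs produce two optimizer steps.
print(aggregate_and_step(1.0, [1.0, 3.0, 2.0, 2.0], aggregation_frequency=2))
```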
- configure_apex_amp(models: Union[torch.nn.modules.module.Module, List[torch.nn.modules.module.Module]], optimizers: Union[torch.optim.optimizer.Optimizer, List[torch.optim.optimizer.Optimizer]], enabled: Optional[bool] = True, opt_level: Optional[str] = 'O1', cast_model_type: Optional[torch.dtype] = None, patch_torch_functions: Optional[bool] = None, keep_batchnorm_fp32: Optional[Union[bool, str]] = None, master_weights: Optional[bool] = None, loss_scale: Optional[Union[float, str]] = None, cast_model_outputs: Optional[torch.dtype] = None, num_losses: Optional[int] = 1, verbosity: Optional[int] = 1, min_loss_scale: Optional[float] = None, max_loss_scale: Optional[float] = 16777216.0) -> Tuple
Configure automatic mixed precision for your models and optimizers using NVIDIA's Apex PyTorch extension. Note that details for apex.amp are handled automatically within Determined after this call.
This function must be called after you have finished constructing your models and optimizers with wrap_model() and wrap_optimizer().
This function has the same arguments as apex.amp.initialize.
Warning
When using distributed training and automatic mixed precision, we only support num_losses=1 and calling backward on the loss once.
- Parameters
models (torch.nn.Module or list of torch.nn.Modules) – Model(s) to modify/cast.
optimizers (torch.optim.Optimizer or list of torch.optim.Optimizers) – Optimizers to modify/cast. REQUIRED for training.
enabled (bool, optional, default=True) – If False, renders all Amp calls no-ops, so your script should run as if Amp were not present.
opt_level (str, optional, default="O1") – Pure or mixed precision optimization level. Accepted values are “O0”, “O1”, “O2”, and “O3”, explained in detail in the apex.amp documentation.
cast_model_type (torch.dtype, optional, default=None) – Optional property override; see the apex.amp documentation.
patch_torch_functions (bool, optional, default=None) – Optional property override.
keep_batchnorm_fp32 (bool or str, optional, default=None) – Optional property override. If passed as a string, must be the string “True” or “False”.
master_weights (bool, optional, default=None) – Optional property override.
loss_scale (float or str, optional, default=None) – Optional property override. If passed as a string, must be a string representing a number, e.g., “128.0”, or the string “dynamic”.
cast_model_outputs (torch.dtype, optional, default=None) – Option to ensure that the outputs of your model are always cast to a particular type regardless of opt_level.
num_losses (int, optional, default=1) – Option to tell Amp in advance how many losses/backward passes you plan to use. When used in conjunction with the loss_id argument to amp.scale_loss, enables Amp to use a different loss scale per loss/backward pass, which can improve stability. If num_losses is left at 1, Amp will still support multiple losses/backward passes, but will use a single global loss scale for all of them.
verbosity (int, default=1) – Set to 0 to suppress Amp-related output.
min_loss_scale (float, default=None) – Sets a floor for the loss scale values that can be chosen by dynamic loss scaling. The default value of None means that no floor is imposed. If dynamic loss scaling is not used, min_loss_scale is ignored.
max_loss_scale (float, default=2.**24) – Sets a ceiling for the loss scale values that can be chosen by dynamic loss scaling. If dynamic loss scaling is not used, max_loss_scale is ignored.
- Returns
Model(s) and optimizer(s) modified according to the opt_level. If the models or optimizers arguments were lists, the corresponding return value will also be a list.
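The min_loss_scale and max_loss_scale parameters only matter under dynamic loss scaling. The sketch below shows the general idea under stated assumptions: when the scaled gradients overflow (contain inf or NaN), the step is skipped and the scale is halved, never going below the floor; after a fixed number of overflow-free steps, the scale is doubled, never exceeding the ceiling. The growth interval of 2000 steps, the factor of 2, and the initial scale are illustrative assumptions, and DynamicLossScaler is a hypothetical class, not apex.amp's code.

```python
import math
from typing import List, Optional


class DynamicLossScaler:
    """Hypothetical dynamic loss scaler (illustration only, not apex.amp)."""

    def __init__(
        self,
        init_scale: float = 2.0 ** 16,
        min_loss_scale: Optional[float] = None,
        max_loss_scale: float = 2.0 ** 24,
        growth_interval: int = 2000,
    ) -> None:
        self.scale = init_scale
        self.min_loss_scale = min_loss_scale
        self.max_loss_scale = max_loss_scale
        self.growth_interval = growth_interval
        self._clean_steps = 0

    def update(self, grads: List[float]) -> bool:
        """Return True if the optimizer step should run for these scaled grads."""
        overflow = any(math.isinf(g) or math.isnan(g) for g in grads)
        if overflow:
            # Skip the step and shrink the scale, respecting the floor if one is set.
            self.scale /= 2.0
            if self.min_loss_scale is not None:
                self.scale = max(self.scale, self.min_loss_scale)
            self._clean_steps = 0
            return False
        self._clean_steps += 1
        if self._clean_steps == self.growth_interval:
            # Grow the scale after enough overflow-free steps, up to the ceiling.
            self.scale = min(self.scale * 2.0, self.max_loss_scale)
            self._clean_steps = 0
        return True


scaler = DynamicLossScaler(init_scale=8.0, min_loss_scale=4.0, max_loss_scale=16.0, growth_interval=2)
print(scaler.update([float("inf")]), scaler.scale)
```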
- current_train_batch() -> int
Return the current global batch index.
- get_data_config() -> Dict[str, Any]
Return the data configuration.
- get_enable_tensorboard_logging() -> bool
Return whether automatic TensorBoard logging is enabled.
- get_experiment_id() -> int
Return the experiment ID of the current trial.
- get_global_batch_size() -> int
Return the global batch size.
- get_hparam(name: str) -> Any
Return the current value of the hyperparameter with the given name.
- get_per_slot_batch_size() -> int
Return the per-slot batch size. When a model is trained with a single GPU, this is equal to the global batch size. When multi-GPU training is used, this is equal to the global batch size divided by the number of GPUs used to train the model.
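The relationship between the global and per-slot batch sizes is plain integer arithmetic. A minimal sketch, assuming the global batch size divides evenly across slots; the helper name per_slot_batch_size is hypothetical, and this sketch simply rejects uneven splits rather than guessing how Determined handles them.

```python
def per_slot_batch_size(global_batch_size: int, slots: int) -> int:
    """Hypothetical helper: split a global batch evenly across slots (GPUs)."""
    if slots < 1:
        raise ValueError("need at least one slot")
    if global_batch_size % slots != 0:
        # This sketch refuses uneven splits instead of choosing a rounding rule.
        raise ValueError("global batch size must be divisible by the slot count")
    return global_batch_size // slots


print(per_slot_batch_size(256, 1))  # 256: a single GPU sees the whole global batch
print(per_slot_batch_size(256, 8))  # 32: each of 8 GPUs processes 32 records
```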
- get_stop_requested() -> bool
Return whether a trial stoppage has been requested.
- get_tensorboard_path() -> pathlib.Path
Get the path where files for consumption by TensorBoard should be written.
- get_tensorboard_writer() -> Any
This function returns an instance of torch.utils.tensorboard.SummaryWriter.
Trial users who wish to log to TensorBoard can use this writer object. We provide and manage a writer in order to save and upload TensorBoard files automatically on behalf of the user.
Usage example:
    import numpy as np
    import torch
    from determined.pytorch import PyTorchTrial

    class MyModel(PyTorchTrial):
        def __init__(self, context):
            ...
            self.writer = context.get_tensorboard_writer()

        def train_batch(self, batch, epoch_idx, batch_idx):
            self.writer.add_scalar('my_metric', np.random.random(), batch_idx)
            self.writer.add_image('my_image', torch.ones((3, 32, 32)), batch_idx)
- get_trial_id() -> int
Return the trial ID of the current trial.
- is_epoch_end() -> bool
Return True if the current batch is the last batch of the epoch.
Warning
Not accurate for variable-size epochs.
- is_epoch_start() -> bool
Return True if the current batch is the first batch of the epoch.
Warning
Not accurate for variable-size epochs.
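Why the warnings matter becomes clear from a sketch that derives both flags from a global batch index and a fixed epoch length. The two helpers below are hypothetical and assume every epoch has exactly epoch_len batches; if the real epoch length changes from one epoch to the next, a fixed epoch_len places the boundaries in the wrong batches.

```python
def is_epoch_start(batch_idx: int, epoch_len: int) -> bool:
    """Hypothetical: True for the first batch of each fixed-length epoch."""
    return batch_idx % epoch_len == 0


def is_epoch_end(batch_idx: int, epoch_len: int) -> bool:
    """Hypothetical: True for the last batch of each fixed-length epoch."""
    return (batch_idx + 1) % epoch_len == 0


# Epochs of 4 batches: boundaries at global batch indices 0/3, 4/7, 8/...
starts = [i for i in range(10) if is_epoch_start(i, 4)]
ends = [i for i in range(10) if is_epoch_end(i, 4)]
print(starts, ends)  # [0, 4, 8] [3, 7]
```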
- set_enable_tensorboard_logging(enable_tensorboard_logging: bool) -> None
Set a flag to indicate whether automatic upload to TensorBoard is enabled.
- set_profiler(*args: List[str], **kwargs: Any) -> None
set_profiler() is a thin wrapper around the native PyTorch profiler, torch-tb-profiler. It overrides the on_trace_ready parameter to point to the Determined TensorBoard path, while all other arguments are passed directly into torch.profiler.profile. Stepping the profiler is handled automatically during the training loop.
See the PyTorch profiler plugin for details.
Examples:
Profile GPU and CPU activities, skipping batch 1, warming up on batch 2, and profiling batches 3 and 4:
    import torch

    self.context.set_profiler(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ],
        schedule=torch.profiler.schedule(
            wait=1,
            warmup=1,
            active=2
        ),
    )
- set_stop_requested(stop_requested: bool) -> None
Set a flag to request a trial stoppage. When this flag is set to True, we finish the step, checkpoint, then exit.
- step_optimizer(optimizer: torch.optim.optimizer.Optimizer, clip_grads: Optional[Callable[[Iterator], None]] = None, auto_zero_grads: bool = True, scaler: Optional[Any] = None) -> None
Perform a single optimization step.
This function must be called once for each optimizer. However, the order of different optimizers' steps can be specified by calling this function in different orders. Gradient accumulation across iterations is performed by the Determined training loop when you set the experiment configuration field optimizations.aggregation_frequency.
Here is a code example:
    import torch

    def clip_grads(params):
        torch.nn.utils.clip_grad_norm_(params, 0.0001)

    self.context.step_optimizer(self.opt1, clip_grads)
- Parameters
optimizer (torch.optim.Optimizer) – Which optimizer should be stepped.
clip_grads (a function, optional) – This function should take one argument, the parameters, and clip their gradients.
auto_zero_grads (bool, optional) – Automatically zero out gradients after stepping the optimizer. If False, you need to call optimizer.zero_grad() manually. Note that if optimizations.aggregation_frequency is greater than 1, auto_zero_grads must be True.
scaler (torch.cuda.amp.GradScaler, optional) – The scaler to use for stepping the optimizer. This should be unset if not using AMP, and is necessary if wrap_scaler() was called directly.
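A clip_grads function receives the parameters and clips their gradients in place, for example with torch.nn.utils.clip_grad_norm_. To show what norm-based clipping computes, here is a dependency-free sketch over plain lists of gradient values, modeled on the L2-norm case: when the global norm exceeds max_norm, every gradient is multiplied by max_norm / (total_norm + eps). The function name and the eps value of 1e-6 are assumptions for illustration, not Determined or PyTorch code.

```python
import math
from typing import List


def clip_grad_norm(grads: List[List[float]], max_norm: float, eps: float = 1e-6) -> float:
    """Scale every gradient in place so the global L2 norm is at most max_norm.

    `grads` holds one list of gradient values per parameter. Returns the
    total norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef < 1.0:
        # Only shrink; gradients already within the limit are left untouched.
        for grad in grads:
            for i, g in enumerate(grad):
                grad[i] = g * clip_coef
    return total_norm


grads = [[3.0], [4.0]]  # global L2 norm is 5.0
norm = clip_grad_norm(grads, max_norm=1.0)
print(norm, grads)
```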
- to_device(data: Union[Dict[str, Union[numpy.ndarray, torch.Tensor]], Sequence[Union[numpy.ndarray, torch.Tensor]], numpy.ndarray, torch.Tensor]) -> Union[Dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor]
Map generated data to the device allocated by the Determined cluster.
All the data in the data loader and the models are automatically moved to the allocated device. This method is intended for data generated on the fly.
- wrap_lr_scheduler(lr_scheduler: torch.optim.lr_scheduler._LRScheduler, step_mode: determined.pytorch._lr_scheduler.LRScheduler.StepMode, frequency: int = 1) -> torch.optim.lr_scheduler._LRScheduler
Returns a wrapped LR scheduler.
The LR scheduler must use an optimizer wrapped by wrap_optimizer(). If apex.amp is in use, the optimizer must also have been configured with configure_apex_amp().
- wrap_model(model: torch.nn.modules.module.Module) -> torch.nn.modules.module.Module
Returns a wrapped model.
- wrap_optimizer(optimizer: torch.optim.optimizer.Optimizer, backward_passes_per_step: int = 1, fp16_compression: Optional[bool] = None, average_aggregated_gradients: Optional[bool] = None) -> torch.optim.optimizer.Optimizer
Returns a wrapped optimizer.
The optimizer must use the models wrapped by wrap_model(). This function creates a horovod.DistributedOptimizer if using parallel/distributed training.
backward_passes_per_step can be used to specify how many gradient aggregation steps will be performed in a single train_batch call per optimizer step. In most cases, this will just be the default value 1. However, this advanced functionality can be used to support training loops like the one shown below, where the optimizer was wrapped with backward_passes_per_step=2:
    from typing import Dict, Sequence, Union

    import torch

    TorchData = Union[Dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor]

    # In __init__(): two backward passes feed each optimizer step.
    # self.optimizer = self.context.wrap_optimizer(optimizer, backward_passes_per_step=2)

    def train_batch(
        self, batch: TorchData, epoch_idx: int, batch_idx: int
    ) -> Dict[str, torch.Tensor]:
        data, labels = batch
        output = self.model(data)
        loss1 = output['loss1']
        loss2 = output['loss2']
        self.context.backward(loss1)
        self.context.backward(loss2)
        self.context.step_optimizer(self.optimizer)
        return {"loss1": loss1, "loss2": loss2}
- wrap_scaler(scaler: Any) -> Any
Prepares to use automatic mixed precision through PyTorch's native AMP API. The returned scaler should be passed to step_optimizer, but usage does not otherwise differ from vanilla PyTorch APIs. Loss should be scaled before calling backward, unscale_ should be called before clipping gradients, update should be called after stepping all optimizers, etc.
PyTorch 1.6 or greater is required for this feature.
- Parameters
scaler (torch.cuda.amp.GradScaler) – Scaler to wrap and track.
- Returns
The scaler. It may be wrapped to add additional functionality for use in Determined.
determined.pytorch.PyTorchTrialContext.distributed
- class determined.core._distributed.DistributedContext(*, rank: int, size: int, local_rank: int, local_size: int, cross_rank: int, cross_size: int, chief_ip: Optional[str] = None, pub_port: int = 12360, pull_port: int = 12376, port_offset: int = 0, force_tcp: bool = False)
DistributedContext provides useful methods for effective distributed training.
- A DistributedContext has the following required args:
rank: the index of this worker in the entire job
size: the number of workers in the entire job
local_rank: the index of this worker on this machine
local_size: the number of workers on this machine
cross_rank: the index of this machine in the entire job
cross_size: the number of machines in the entire job
- Additionally, any time that cross_size > 1, you must also provide:
chief_ip: the IP address to reach the chief worker (where rank==0)
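The six required values are not independent. A minimal sketch, assuming a homogeneous layout in which every machine runs local_size workers and workers are numbered machine by machine, derives the global rank and size from the per-machine values. The helper global_rank_and_size is hypothetical and not part of DistributedContext.

```python
from typing import Tuple


def global_rank_and_size(
    local_rank: int, local_size: int, cross_rank: int, cross_size: int
) -> Tuple[int, int]:
    """Hypothetical: map (machine index, worker-on-machine index) to a global rank.

    Assumes every machine runs exactly `local_size` workers.
    """
    if not (0 <= local_rank < local_size and 0 <= cross_rank < cross_size):
        raise ValueError("rank out of range")
    rank = cross_rank * local_size + local_rank
    size = cross_size * local_size
    return rank, size


# Two machines with 4 GPUs each: worker 1 on the second machine is rank 5 of 8.
print(global_rank_and_size(local_rank=1, local_size=4, cross_rank=1, cross_size=2))  # (5, 8)
```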
Note
DistributedContext has .allgather(), .gather(), and .broadcast() methods, which are easy to use and can be useful for coordinating work across workers, but they are not a replacement for the allgather/gather/broadcast operations in your particular distributed training framework.
- classmethod from_horovod(hvd: Any, chief_ip: Optional[str] = None) -> determined.core._distributed.DistributedContext
Create a DistributedContext using the provided hvd module to determine rank information.
Example:
    import horovod.torch as hvd
    from determined.core import DistributedContext

    hvd.init()
    distributed = DistributedContext.from_horovod(hvd)
The IP address for the chief worker is required whenever hvd.cross_size() > 1. The value may be provided using the chief_ip argument or the DET_CHIEF_IP environment variable.
- classmethod from_deepspeed(chief_ip: Optional[str] = None) -> determined.core._distributed.DistributedContext
Create a DistributedContext using the standard DeepSpeed environment variables to determine rank information.
The IP address for the chief worker is required whenever CROSS_SIZE > 1. The value may be provided using the chief_ip argument or the DET_CHIEF_IP environment variable.
- classmethod from_torch_distributed(chief_ip: Optional[str] = None) -> determined.core._distributed.DistributedContext
Create a DistributedContext using the standard torch distributed environment variables to determine rank information.
The IP address for the chief worker is required whenever CROSS_SIZE > 1. The value may be provided using the chief_ip argument or the DET_CHIEF_IP environment variable.
- get_rank() -> int
Return the rank of the process in the trial. The rank of a process is a unique ID within the trial. That is, no two processes in the same trial are assigned the same rank.
- get_local_rank() -> int
Return the rank of the process on the agent. The local rank of a process is a unique ID within a given agent and trial; that is, no two processes in the same trial that are executing on the same agent are assigned the same local rank.
- get_size() -> int
Return the number of slots this trial is running on.
- get_num_agents() -> int
Return the number of agents this trial is running on.
- gather(stuff: Any) -> Optional[List]
Gather stuff to the chief. The chief returns a list of all stuff, and workers return None.
gather() is not a replacement for the gather functionality of your distributed training framework.
- gather_local(stuff: Any) -> Optional[List]
Gather stuff to the local chief. The local chief returns a list of all stuff, and local workers return None.
gather_local() is not a replacement for the gather functionality of your distributed training framework.
- allgather(stuff: Any) -> List
Gather stuff to the chief and broadcast all of it back to the workers.
allgather() is not a replacement for the allgather functionality of your distributed training framework.
- allgather_local(stuff: Any) -> List
Gather stuff to the local chief and broadcast all of it back to the local workers.
allgather_local() is not a replacement for the allgather functionality of your distributed training framework.
- broadcast(stuff: Any) -> Any
Every worker gets the stuff sent by the chief.
broadcast() is not a replacement for the broadcast functionality of your distributed training framework.
- broadcast_local(stuff: Optional[Any] = None) -> Any
Every worker gets the stuff sent by the local chief.
broadcast_local() is not a replacement for the broadcast functionality of your distributed training framework.
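The return conventions of these helpers can be illustrated with a single-process simulation: for gather, the chief gets a list and other workers get None; for allgather, every worker gets the list; for broadcast, every worker gets the chief's value. The functions below are hypothetical. Each takes the contributions of all workers at once, indexed by rank with rank 0 as the chief, and returns what each rank would see; ordering the gathered list by rank is an assumption, and no networking is modeled.

```python
from typing import Any, List, Optional


def simulate_gather(stuff_by_rank: List[Any]) -> List[Optional[List[Any]]]:
    """What each rank would see from gather(): only the chief gets the list."""
    return [
        list(stuff_by_rank) if rank == 0 else None
        for rank in range(len(stuff_by_rank))
    ]


def simulate_allgather(stuff_by_rank: List[Any]) -> List[List[Any]]:
    """What each rank would see from allgather(): every rank gets the list."""
    return [list(stuff_by_rank) for _ in stuff_by_rank]


def simulate_broadcast(stuff_by_rank: List[Any]) -> List[Any]:
    """What each rank would see from broadcast(): the chief's value everywhere."""
    return [stuff_by_rank[0] for _ in stuff_by_rank]


# Three workers each report how many records they processed.
counts = [10, 12, 9]
print(simulate_gather(counts))     # [[10, 12, 9], None, None]
print(simulate_allgather(counts))  # [[10, 12, 9], [10, 12, 9], [10, 12, 9]]
print(simulate_broadcast(counts))  # [10, 10, 10]
```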
determined.pytorch.PyTorchExperimentalContext
- class determined.pytorch.PyTorchExperimentalContext(parent: Any)
- disable_auto_to_device() -> None
Prevent the PyTorchTrialController from automatically moving batched data to the device. Call this if you want to override the default behavior of moving all items of a list, tuple, and/or dict to the GPU. Then, you can control how data is moved to the GPU directly in the train_batch and evaluate_batch methods of your PyTorchTrial definition. You should call context.to_device on primitive data types that you do want to move to the GPU, as in the example below.
    # PyTorchTrial methods.
    def __init__(self, context):
        # PyTorchTrial init
        self.context = context
        self.context.experimental.disable_auto_to_device()
        ...

    def train_batch(self, batch, epoch_idx, batch_idx):
        # Move only the image tensor to the GPU; other items stay where they are.
        for k, item in batch.items():
            if k == "img":
                batch[k] = self.context.to_device(item)
        ...
- disable_dataset_reproducibility_checks() -> None
disable_dataset_reproducibility_checks() allows you to return an arbitrary DataLoader from build_training_data_loader() or build_validation_data_loader().
Normally you would be required to return a det.pytorch.DataLoader instead, which would guarantee that an appropriate Sampler is used that ensures:
When shuffle=True, the shuffle is reproducible.
The dataset will start at the right location, even after pausing/continuing.
Proper sharding is used during distributed training.
However, there may be cases where either reproducibility of the dataset is not needed or where the nature of the dataset may cause the det.pytorch.DataLoader to be unsuitable.
In those cases, you may call disable_dataset_reproducibility_checks() and you will be free to return any torch.utils.data.DataLoader you like. Dataset reproducibility will still be possible, but it will be your responsibility. If desired, you may find the Sampler classes in determined.pytorch.samplers to be helpful.
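The three guarantees listed above (reproducible shuffling, resuming at the right position, and sharding across workers) can be sketched in a few lines of plain Python. This is a hypothetical index generator, not one of the classes in determined.pytorch.samplers: it shuffles with a seeded RNG, gives shard r every num_shards-th position starting at r, and skips the indices that shard already consumed before a pause.

```python
import random
from typing import Iterator


def sharded_indices(
    dataset_len: int,
    seed: int,
    shard_rank: int,
    num_shards: int,
    skip: int = 0,
) -> Iterator[int]:
    """Yield this shard's sample indices for one epoch (hypothetical sketch).

    - Reproducible: the same seed always yields the same permutation.
    - Sharded: shard r takes positions r, r + num_shards, r + 2 * num_shards, ...
    - Resumable: `skip` drops the first indices this shard already consumed.
    """
    order = list(range(dataset_len))
    random.Random(seed).shuffle(order)
    shard = order[shard_rank::num_shards]
    yield from shard[skip:]


full = list(sharded_indices(10, seed=42, shard_rank=0, num_shards=2))
resumed = list(sharded_indices(10, seed=42, shard_rank=0, num_shards=2, skip=3))
print(full, resumed)  # resumed continues exactly where `full` left off after 3 samples
```

Using a private random.Random(seed) instead of the global RNG keeps the permutation independent of any other random calls made during training.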
- use_amp() -> None
Handles all operations for the simplest cases automatically with a default gradient scaler. Specifically, it wraps the forward pass in an autocast context, scales the loss before the backward pass, unscales before clipping gradients, uses the scaler when stepping optimizer(s), and updates the scaler afterwards. Do not call wrap_scaler directly when using this method.
PyTorch 1.6 or greater is required for this feature.
determined.pytorch.DataLoader
- class determined.pytorch.DataLoader(dataset: torch.utils.data.dataset.Dataset, batch_size: Optional[int] = 1, shuffle: bool = False, sampler: Optional[torch.utils.data.sampler.Sampler] = None, batch_sampler: Optional[torch.utils.data.sampler.BatchSampler] = None, num_workers: int = 0, collate_fn: Optional[Callable[[List[determined.pytorch._data.T]], Any]] = None, pin_memory: bool = False, drop_last: bool = False, timeout: float = 0, worker_init_fn: Optional[Callable[[int], None]] = None, multiprocessing_context: Optional[Any] = None, generator: Optional[Any] = None, *, prefetch_factor: Optional[int] = None, persistent_workers: bool = False)
DataLoader is meant to contain a user's data.Dataset, configuration for sampling data in batches, and performance configuration like multiprocessing.
The __init__ function determines the defaults in the same way as a torch.utils.data.DataLoader would, so the behavior should be familiar. However, the torch.utils.data.DataLoader that is used for training and validation is not created until get_data_loader(...) is called. This is done so that Determined can ensure that sampling restarts from the right location and distributed sampling is handled correctly.
Note that the arguments are from PyTorch.
- Parameters
dataset (Dataset) – dataset from which to load the data.
batch_size (int, optional) – how many samples per batch to load (default: 1).
shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False).
sampler (Sampler, optional) – defines the strategy to draw samples from the dataset. If specified, shuffle must be False.
batch_sampler (Sampler, optional) – like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.
num_workers (int, optional) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process (default: 0).
collate_fn (callable, optional) – merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.
pin_memory (bool, optional) – If True, the data loader will copy Tensors into CUDA pinned memory before returning them. If your data elements are a custom type, or your collate_fn returns a batch that is a custom type, see the memory pinning example in the PyTorch data loading documentation.
drop_last (bool, optional) – set to True to drop the last incomplete batch if the dataset size is not divisible by the batch size. If False and the size of the dataset is not divisible by the batch size, the last batch will be smaller (default: False).
timeout (numeric, optional) – if positive, the timeout value for collecting a batch from workers. Should always be non-negative (default: 0).
worker_init_fn (callable, optional) – If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading (default: None).
generator (torch.Generator, optional) – If not None, this RNG will be used by RandomSampler to generate random indexes and by multiprocessing to generate base_seed for workers (default: None).
prefetch_factor (int, optional, keyword-only arg) – Number of batches loaded in advance by each worker. 2 means there will be a total of 2 * num_workers batches prefetched across all workers (default: 2).
persistent_workers (bool, optional) – If True, the data loader will not shut down the worker processes after a dataset has been consumed once. This keeps the workers' Dataset instances alive (default: False).
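The two guarantees mentioned above (resuming from the right location and correct distributed sampling) can be illustrated with a small, standard-library sketch. It gives each rank a disjoint, interleaved shard of the dataset indices, groups them into batches, honors drop_last, and skips batches already trained before a pause. The helper shard_and_skip and its parameters are hypothetical; this is not Determined's implementation, which works with torch samplers.

```python
from typing import List, Sequence


def shard_and_skip(
    indices: Sequence[int],
    batch_size: int,
    rank: int,
    num_replicas: int,
    skip_batches: int = 0,
    drop_last: bool = False,
) -> List[List[int]]:
    """Hypothetical helper: split dataset indices into per-rank batches.

    Each rank takes every num_replicas-th index starting at its own rank,
    groups the result into batches, and drops the first skip_batches
    batches (for example, those already trained before a preemption).
    """
    if not 0 <= rank < num_replicas:
        raise ValueError("rank must be in [0, num_replicas)")
    if batch_size < 1:
        raise ValueError("batch_size must be positive")
    # Interleaved sharding: rank r gets indices r, r + R, r + 2R, ...
    shard = list(indices)[rank::num_replicas]
    batches = [shard[i : i + batch_size] for i in range(0, len(shard), batch_size)]
    # Mirror drop_last: discard a trailing batch smaller than batch_size.
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches.pop()
    # Resume: skip the batches this rank has already consumed.
    return batches[skip_batches:]
```

For example, with 10 samples, 2 ranks, and a batch size of 2, rank 0 sees [[0, 2], [4, 6], [8]], while rank 1 resuming after one batch sees [[5, 7], [9]]. Real distributed samplers usually also pad or truncate so that every rank gets the same number of batches; this sketch does not.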
determined.pytorch.LRScheduler
- class determined.pytorch.LRScheduler(scheduler: torch.optim.lr_scheduler._LRScheduler, step_mode: determined.pytorch._lr_scheduler.LRScheduler.StepMode, frequency: int = 1)
Wrapper for a PyTorch LRScheduler.
This wrapper fulfills two main functions:
Save and restore the learning rate when a trial is paused, preempted, etc.
Step the learning rate scheduler at the configured frequency (e.g., every batch or every epoch).
- class StepMode(value)
Specifies when and how scheduler.step() should be executed.
- STEP_EVERY_EPOCH
- STEP_EVERY_BATCH
- MANUAL_STEP
- STEP_EVERY_OPTIMIZER_STEP
- __init__(scheduler: torch.optim.lr_scheduler._LRScheduler, step_mode: determined.pytorch._lr_scheduler.LRScheduler.StepMode, frequency: int = 1)
LRScheduler constructor.
- Parameters
scheduler (torch.optim.lr_scheduler._LRScheduler) – Learning rate scheduler to be used by Determined.
step_mode (determined.pytorch.LRSchedulerStepMode) – The strategy Determined will use to call (or not call) scheduler.step().
STEP_EVERY_EPOCH: Determined will call scheduler.step() after every N training epochs, where N is frequency. No arguments will be passed to step().
STEP_EVERY_BATCH: Determined will call scheduler.step() after every N training batches, where N is frequency. No arguments will be passed to step(). This option does not take gradient aggregation into account; STEP_EVERY_OPTIMIZER_STEP, which does, is recommended instead.
STEP_EVERY_OPTIMIZER_STEP: Determined will call scheduler.step() in sync with optimizer steps. With optimizations.aggregation_frequency unset, this is equivalent to STEP_EVERY_BATCH; when it is set, it ensures the LR scheduler is stepped every effective batch. If frequency is set to some value N, Determined will step the LR scheduler every N optimizer steps.
MANUAL_STEP: Determined will not call scheduler.step() at all. It is up to the user to decide when to call scheduler.step(), and whether to pass any arguments.
frequency – Sets how often the epoch, batch, and optimizer-step modes trigger scheduler.step() (default: 1).
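The interaction between step_mode, frequency, and gradient aggregation is easiest to see by counting when scheduler.step() would fire. The standard-library sketch below simulates this under the semantics described above. StepMode here is a local stand-in enum and scheduler_step_batches is a hypothetical helper; neither is part of Determined, and Determined's own bookkeeping may differ at edge cases such as partial epochs or aggregation windows that cross an epoch boundary.

```python
from enum import Enum
from typing import List


class StepMode(Enum):
    # Local stand-in mirroring the documented LRScheduler.StepMode members.
    STEP_EVERY_EPOCH = "epoch"
    STEP_EVERY_BATCH = "batch"
    STEP_EVERY_OPTIMIZER_STEP = "optimizer_step"
    MANUAL_STEP = "manual"


def scheduler_step_batches(
    mode: StepMode,
    frequency: int,
    num_epochs: int,
    batches_per_epoch: int,
    aggregation_frequency: int = 1,
) -> List[int]:
    """Return the 1-based global batch numbers after which step() would run."""
    if frequency < 1 or aggregation_frequency < 1:
        raise ValueError("frequency and aggregation_frequency must be >= 1")
    steps = []
    for batch in range(1, num_epochs * batches_per_epoch + 1):
        if mode is StepMode.STEP_EVERY_BATCH:
            # Every N-th batch, ignoring gradient aggregation.
            if batch % frequency == 0:
                steps.append(batch)
        elif mode is StepMode.STEP_EVERY_EPOCH:
            # An epoch ends on its last batch; fire on every N-th epoch.
            epoch_done = batch % batches_per_epoch == 0
            if epoch_done and (batch // batches_per_epoch) % frequency == 0:
                steps.append(batch)
        elif mode is StepMode.STEP_EVERY_OPTIMIZER_STEP:
            # One optimizer step per aggregation window; fire on every N-th one.
            step_done = batch % aggregation_frequency == 0
            if step_done and (batch // aggregation_frequency) % frequency == 0:
                steps.append(batch)
        # MANUAL_STEP: the framework never calls step(); the user does.
    return steps
```

With aggregation_frequency=3 and six batches, STEP_EVERY_OPTIMIZER_STEP fires after batches 3 and 6, whereas STEP_EVERY_BATCH with frequency=1 would fire six times, which is why the optimizer-step mode is the recommended choice when aggregation is enabled. In an actual trial, you would instead pass a torch scheduler and a step mode to the context's wrap_lr_scheduler method.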
determined.pytorch.Reducer
determined.pytorch.MetricReducer
- class determined.pytorch.MetricReducer
Efficiently aggregating training and validation metrics during a multi-slot distributed trial is done in three steps:
1. Gather all the values to be reduced during the reduction window (either a training or a validation workload). In a multi-slot trial, this is done on each slot in parallel.
2. Calculate the per-slot reduction. This will return some intermediate value that each slot will contribute to the final metric calculation. It can be as simple as a list of all the raw values from step 1, but reducing the intermediate value locally distributes the final metric calculation more efficiently and reduces network communication costs.
3. Reduce the per-slot reduction values from step 2 into a final metric.
The MetricReducer API makes it possible for users to define a maximally efficient custom metric by exposing these steps:
Step 1 is defined by the user; it is not part of the interface. This flexibility gives the user full control when gathering individual values for reduction.
Step 2 is the MetricReducer.per_slot_reduce() interface.
Step 3 is the MetricReducer.cross_slot_reduce() interface.
The MetricReducer.reset() interface allows a MetricReducer to be reused across many training and validation workloads.
Example implementation and usage:
from determined import pytorch

class MyAvgMetricReducer(pytorch.MetricReducer):
    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0
        self.counts = 0

    # User-defined mechanism for collecting values throughout
    # training or validation. This update() mechanism demonstrates
    # a computationally- and memory-efficient way to store the values.
    def update(self, value):
        self.sum += sum(value)
        # Count individual values, not update() calls, so the result
        # is a true average.
        self.counts += len(value)

    def per_slot_reduce(self):
        # Because the chosen update() mechanism is so
        # efficient, this is basically a no-op.
        return self.sum, self.counts

    def cross_slot_reduce(self, per_slot_metrics):
        # per_slot_metrics is a list of (sum, counts) tuples
        # returned by self.per_slot_reduce() on each slot.
        sums, counts = zip(*per_slot_metrics)
        return sum(sums) / sum(counts)

class MyPyTorchTrial(pytorch.PyTorchTrial):
    def __init__(self, context):
        # Register your custom reducer.
        self.my_avg = context.wrap_reducer(
            MyAvgMetricReducer(), name="my_avg"
        )
        ...

    def train_batch(self, batch, epoch_idx, batch_idx):
        ...
        # You decide how/when you call update(); my_val stands for
        # an iterable of values computed earlier in this method.
        self.my_avg.update(my_val)
        # The "my_avg" metric will be included in the final
        # metrics after the workload has completed; no need
        # to return it here.
        return {"loss": loss}
See also: determined.pytorch.PyTorchExperimentalContext.wrap_reducer().
- abstract reset() → None
Reset reducer state for another set of values.
This will be called before any train or validation workload begins.
- abstract per_slot_reduce() → Any
This will be called after any train or validation workload ends, once all workers have finished (even when there is only one worker).
It should return some picklable value that is meaningful for cross_slot_reduce.
- abstract cross_slot_reduce(per_slot_metrics: List) → Any
This will be called after per_slot_reduce has finished (even when there is only one worker).
The per_slot_metrics will be a list containing the output of per_slot_reduce() from each worker.
- The return value should be either:
A dict mapping string metric names to metric values, if the call to context.wrap_reducer() omitted the name parameter, or
A non-dict metric value, if the call to context.wrap_reducer() had name set to a string (an error will be raised if a dict-type metric is returned but name was set).
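The three-step flow can be checked without a cluster by running the same method names in a single process. The sketch below is a standalone stand-in (AvgReducer and simulate are hypothetical names; AvgReducer does not subclass determined.pytorch.MetricReducer so that it runs without Determined installed). It simulates several slots, collects each slot's per_slot_reduce() output into a list, and hands that list to cross_slot_reduce(), with the first reducer playing the role of whichever process performs the final reduction.

```python
from typing import List, Tuple


class AvgReducer:
    """Standalone stand-in using the MetricReducer method names."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        self.sum = 0.0
        self.count = 0

    def update(self, values: List[float]) -> None:
        # Step 1 (user-defined): fold each batch of values into running totals.
        self.sum += sum(values)
        self.count += len(values)

    def per_slot_reduce(self) -> Tuple[float, int]:
        # Step 2: a small, picklable summary of this slot's values.
        return self.sum, self.count

    def cross_slot_reduce(self, per_slot_metrics: List[Tuple[float, int]]) -> float:
        # Step 3: combine every slot's summary into the final metric.
        sums, counts = zip(*per_slot_metrics)
        return sum(sums) / sum(counts)


def simulate(slot_values: List[List[List[float]]]) -> float:
    """Simulate one workload; slot_values[slot] lists that slot's update() batches."""
    reducers = [AvgReducer() for _ in slot_values]
    for reducer, batches in zip(reducers, slot_values):
        for batch in batches:
            reducer.update(batch)
    summaries = [r.per_slot_reduce() for r in reducers]
    return reducers[0].cross_slot_reduce(summaries)
```

With slot 0 seeing [1, 2] and [3] and slot 1 seeing [4, 5, 6, 7], simulate returns 4.0, the mean of all seven values. Averaging the per-slot averages instead would give (2.0 + 5.5) / 2 = 3.75, which is why the per-slot summary carries counts rather than a local mean.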
determined.pytorch.PyTorchCallback
- class determined.pytorch.PyTorchCallback
Abstract base class used to define a callback that should execute during the lifetime of a PyTorchTrial or DeepSpeedTrial.
Warning
If you are defining a stateful callback (e.g., it mutates a self attribute over its lifetime), you must also override state_dict() and load_state_dict() to ensure this state can be serialized and deserialized over checkpoints.
Warning
If distributed training is enabled, every GPU will execute a copy of this callback (except for on_checkpoint_write_end() during PyTorchTrial training, which is only called on the chief). To configure a callback implementation to execute on a subset of GPUs, condition your implementation on trial.context.distributed.get_rank().
- load_state_dict(state_dict: Dict[str, Any]) → None
Load the state of this callback using the deserialized state_dict.
- on_checkpoint_end(checkpoint_dir: str) → None
Deprecated. Please use on_checkpoint_write_end() instead.
Warning
This callback only executes on the chief GPU when doing distributed training with PyTorchTrial.
- on_checkpoint_load_start(checkpoint: Dict[str, Any]) → None
Run before state_dict is restored.
- on_checkpoint_save_start(checkpoint: Dict[str, Any]) → None
Run before the checkpoint is persisted.
- on_checkpoint_upload_end(uuid: str) → None
Run after every checkpoint finishes uploading.
- on_checkpoint_write_end(checkpoint_dir: str) → None
Run after every checkpoint finishes writing to checkpoint_dir.
Warning
This callback only executes on the chief GPU when doing distributed training with PyTorchTrial.
- on_training_epoch_end(epoch_idx: int) → None
Run at the end of a training epoch.
- on_training_epoch_start(epoch_idx: int) → None
Run at the start of a new training epoch.
- on_training_start() → None
Run after checkpoint loads and before training begins.
- on_training_workload_end(avg_metrics: Dict[str, Any], batch_metrics: Dict[str, Any]) → None
Run at the end of a training workload. Workloads can contain varying numbers of batches. In the current implementation of PyTorchTrial and DeepSpeedTrial, the maximum number of batches in a workload is equal to the scheduling_unit field defined in the experiment config.
- on_trial_shutdown() → None
Runs just before training shuts down so the trial can leave the cluster. This does not imply that the trial is complete; it may just be paused or preempted by a higher-priority task.
Warning
This callback runs each time a trial shuts down gracefully to come off the cluster. It does not mean that the trial is done training. Additionally, if the trial is killed, the container will be destroyed without this callback running.
- on_trial_startup(first_batch_idx: int, checkpoint_uuid: Optional[str]) → None
Runs before training, validation, or building dataloaders.
- Parameters
first_batch_idx (int) – The first batch index to be trained. If the trial has already completed some amount of training in a previous allocation on the cluster, this will be nonzero.
checkpoint_uuid (str or None) – The checkpoint from which weights, optimizer state, etc. will be loaded. When first_batch_idx > 0, this will contain the uuid of the most recent checkpoint saved by this trial. Otherwise, it will contain the uuid of the checkpoint from which this trial was configured to warm start (via source_trial_id or source_checkpoint_uuid in the searcher config), or None if no warm start was configured.
- on_validation_end(metrics: Dict[str, Any]) → None
Run after every validation ends.
- on_validation_epoch_end(outputs: List[Any]) → None
Run after a validation epoch has finished.
- on_validation_epoch_start() → None
Run at the start of a new validation epoch.
- on_validation_start() → None
Run before every validation begins.
- state_dict() → Dict[str, Any]
Serialize the state of this callback to a dictionary. The return value must be pickle-able.
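The first warning above (stateful callbacks must override state_dict() and load_state_dict()) can be illustrated with a small round trip. EpochCounterCallback and checkpoint_round_trip are hypothetical names; in a real trial the class would subclass determined.pytorch.PyTorchCallback and typically be returned from the trial's build_callbacks() method, but it is kept standalone here so the sketch runs without Determined. Pickling the state dict stands in for writing it into a checkpoint and restoring it after a pause.

```python
import pickle
from typing import Any, Dict


class EpochCounterCallback:
    """Hypothetical stateful callback using PyTorchCallback hook names."""

    def __init__(self) -> None:
        self.epochs_seen = 0

    def on_training_epoch_end(self, epoch_idx: int) -> None:
        # Mutates self, so state_dict()/load_state_dict() must be overridden.
        self.epochs_seen += 1

    def state_dict(self) -> Dict[str, Any]:
        # Must be pickle-able so it can be stored in a checkpoint.
        return {"epochs_seen": self.epochs_seen}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.epochs_seen = state_dict["epochs_seen"]


def checkpoint_round_trip(callback: EpochCounterCallback) -> EpochCounterCallback:
    """Simulate saving callback state to a checkpoint and restoring it after a restart."""
    blob = pickle.dumps(callback.state_dict())
    restored = EpochCounterCallback()
    restored.load_state_dict(pickle.loads(blob))
    return restored
```

Without the two overrides, a freshly constructed callback after a preemption would restart its counter at zero. To run a hook on only one worker during distributed training, guard its body with a rank check as described in the second warning.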
determined.pytorch.load_trial_from_checkpoint_path
- determined.pytorch.load_trial_from_checkpoint_path(path: str, trial_class: Optional[Type[determined.pytorch._pytorch_trial.PyTorchTrial]] = None, trial_kwargs: Optional[Dict[str, Any]] = None, torch_load_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Dict[str, Any]) → determined.pytorch._pytorch_trial.PyTorchTrial
Loads a checkpoint written by a PyTorchTrial.
You should have already downloaded the checkpoint files, likely with Checkpoint.download(). The return value will be a restored instance of the PyTorchTrial subclass you used for training.
- Parameters
path (string) – Top-level directory to load the checkpoint from.
trial_class (optional) – Provide your PyTorchTrial class to be loaded. Only necessary if the automatic import logic is insufficient.
trial_kwargs (optional) – Additional keyword arguments to be passed to your PyTorchTrial class, in addition to the context, which will always be the first positional parameter.
torch_load_kwargs (optional) – Keyword arguments for torch.load. See the documentation for torch.load.
**kwargs (deprecated) – Use torch_load_kwargs instead.
determined.pytorch.Trainer
- class determined.pytorch.Trainer(trial: determined.pytorch._pytorch_trial.PyTorchTrial, context: determined.pytorch._pytorch_context.PyTorchTrialContext)
pytorch.Trainer is an abstraction on top of a vanilla PyTorch training loop that handles many training details under the hood, and exposes APIs for configuring training-related features such as automatic checkpointing, validation, profiling, and metrics reporting.
Trainer must be initialized and called from within a pytorch.PyTorchTrialContext.
- configure_profiler(sync_timings: bool = True, enabled: bool = False, begin_on_batch: int = 0, end_after_batch: Optional[int] = None) → None
Deprecated: use fit(..., profiling_enabled=True) instead.
Configures the Determined profiler. This functionality is only supported for on-cluster training. For local training mode, this method is a no-op.
This method should only be called before .fit(), and only once within the scope of init(). If called multiple times, the last call's configuration will be used.
- Parameters
sync_timings – (Optional) Specifies whether Determined should wait for all GPU kernel streams before considering a timing as ended. Defaults to true. Applies only to frameworks that collect timing metrics (currently just PyTorch).
enabled – (Optional) Defines whether profiles should be collected or not. Defaults to false.
begin_on_batch – (Optional) Specifies the batch on which profiling should begin. Defaults to 0.
end_after_batch – (Optional) Specifies the batch after which profiling should end.
Note
Profiles are collected for a maximum of 5 minutes, regardless of the settings above.
- fit(checkpoint_period: Optional[determined.pytorch._pytorch_trial.TrainUnit] = None, validation_period: Optional[determined.pytorch._pytorch_trial.TrainUnit] = None, max_length: Optional[determined.pytorch._pytorch_trial.TrainUnit] = None, reporting_period: determined.pytorch._pytorch_trial.TrainUnit = <determined.pytorch._pytorch_trial.Batch object>, checkpoint_policy: str = 'best', latest_checkpoint: Optional[str] = None, step_zero_validation: bool = False, test_mode: bool = False, profiling_enabled: bool = False) → None
fit() trains a PyTorchTrial configured from the Trainer and handles checkpointing, validation, and metrics reporting.
- Parameters
checkpoint_period – The number of steps to train for before checkpointing. This is a TrainUnit type (Batch or Epoch) which can take an int or an instance of collections.abc.Container (list, tuple, etc.). For example, Batch(100) would checkpoint every 100 batches, while Batch([5, 30, 45]) would checkpoint after the 5th, 30th, and 45th batches.
validation_period – The number of steps to train for before validating. This is a TrainUnit type (Batch or Epoch) which can take an int or an instance of collections.abc.Container (list, tuple, etc.). For example, Batch(100) would validate every 100 batches, while Batch([5, 30, 45]) would validate after the 5th, 30th, and 45th batches.
max_length – The maximum number of steps to train for. This value is required and only applicable in local training mode. For on-cluster training, this value will be ignored; the searcher's max_length must be configured from the experiment configuration. This is a TrainUnit type (Batch or Epoch) which takes an int. For example, Epoch(1) would train for a maximum length of one epoch.
reporting_period – The number of steps to train for before reporting metrics and searcher progress. For local training mode, metrics are printed to stdout. This is a TrainUnit type (Batch or Epoch) which can take an int or an instance of collections.abc.Container (list, tuple, etc.). For example, Batch(100) would report every 100 batches, while Batch([5, 30, 45]) would report after the 5th, 30th, and 45th batches.
checkpoint_policy – Controls how Determined performs checkpoints after validation operations, if at all. Should be set to one of the following values:
- best (default): A checkpoint will be taken after every validation operation that performs better than all previous validations for this experiment. Validation metrics are compared according to the metric and smaller_is_better fields in the searcher configuration. This option is only supported for on-cluster training.
- all: A checkpoint will be taken after every validation, no matter the validation performance.
- none: A checkpoint will never be taken due to a validation. However, even with this policy selected, checkpoints are still expected to be taken after the trial is finished training, due to cluster scheduling decisions, before search method decisions, or due to min_checkpoint_period.
latest_checkpoint – Configures the checkpoint used to start or continue training. This value should be set to det.get_cluster_info().latest_checkpoint for standard continue-training functionality.
step_zero_validation – Configures whether to perform an initial validation before training. Defaults to false.
test_mode – Runs a minimal loop of training for testing and debugging purposes. Will train and validate one batch. Defaults to false.
profiling_enabled – Enables system metric profiling functionality for on-cluster training. Defaults to false.
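The period parameters of fit() accept either an int or a container, with the semantics spelled out in the Batch(100) and Batch([5, 30, 45]) examples above. The sketch below encodes those documented semantics in a hypothetical helper, period_triggers, to make the difference concrete; it is not Determined's internal scheduling code, and it only models batch counts (the same idea would apply to completed epochs with an Epoch unit).

```python
from collections.abc import Container
from typing import Union


def period_triggers(period: Union[int, Container], batches_trained: int) -> bool:
    """Hypothetical helper: decide whether a Batch-style period fires.

    An int N fires after every N batches; a container fires only after the
    batch counts it lists.
    """
    if isinstance(period, int):
        if period <= 0:
            raise ValueError("period must be positive")
        return batches_trained % period == 0
    return batches_trained in period
```

For instance, an int period of 40 over 100 batches fires after batches 40 and 80, while the container [5, 30, 45] fires exactly three times and never again, so a long run configured with a container gets no further checkpoints or validations from that parameter after its last listed batch.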
determined.pytorch.init()
- determined.pytorch.init(*, hparams: Optional[Dict] = None, exp_conf: Optional[Dict[str, Any]] = None, distributed: Optional[determined.core._distributed.DistributedContext] = None, aggregation_frequency: int = 1, enable_tensorboard_logging: bool = True) → Iterator[determined.pytorch._pytorch_context.PyTorchTrialContext]
Creates a PyTorchTrialContext for use with a PyTorchTrial. All trainer.* calls must be within the scope of this context because there are resources started in __enter__ that must be cleaned up in __exit__.
- Parameters
hparams – (Optional) Hyperparameter values for the trial.
exp_conf – (Optional) Experiment configuration, for local training mode. If unset, calling context.get_experiment_config() will fail.
distributed – (Optional) Custom distributed training configuration.
aggregation_frequency – Number of batches before gradients are exchanged in distributed training. This value is configured here because it is used in context.wrap_optimizer.
enable_tensorboard_logging – Configures whether uploading to TensorBoard is enabled. Defaults to true.
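To illustrate what aggregation_frequency means for a single worker, the sketch below accumulates gradients from consecutive batches and applies one optimizer update per window. It is a hypothetical, framework-free scalar example (train_with_aggregation is not a Determined function): it assumes the aggregated gradients are averaged, whereas the description above only says they are exchanged after that many batches, and it ignores the cross-worker communication that distributed training adds.

```python
from typing import List


def train_with_aggregation(
    grads_per_batch: List[float],
    aggregation_frequency: int,
    lr: float = 0.1,
    weight: float = 0.0,
) -> List[float]:
    """Hypothetical scalar SGD loop with gradient aggregation.

    Gradients from aggregation_frequency consecutive batches are averaged
    and applied in a single optimizer step. Returns the weight after each
    optimizer step.
    """
    if aggregation_frequency < 1:
        raise ValueError("aggregation_frequency must be >= 1")
    history = []
    pending: List[float] = []
    for grad in grads_per_batch:
        pending.append(grad)
        if len(pending) == aggregation_frequency:
            # One optimizer step per aggregation window (averaged gradient).
            weight -= lr * sum(pending) / len(pending)
            history.append(weight)
            pending.clear()
    return history
```

With gradients [1.0, 3.0, 2.0, 2.0] and aggregation_frequency=2, only two optimizer steps happen instead of four, which is also why an LR scheduler tied to optimizer steps advances more slowly than one tied to batches. Gradients from an incomplete final window are simply discarded in this sketch; how Determined handles that case is not covered here.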