det.pytorch.deepspeed
API Reference#
determined.pytorch.deepspeed.DeepSpeedTrial
- class determined.pytorch.deepspeed.DeepSpeedTrial(context: determined.pytorch.deepspeed._deepspeed_context.DeepSpeedTrialContext)#
DeepSpeed trials are created by subclassing this abstract class.
We can do the following things in this trial class:
- Define the DeepSpeed model engine, which includes the model, optimizer, and lr_scheduler.
  In the __init__() method, initialize models and, optionally, optimizers and LR schedulers, and pass them to deepspeed.initialize to build the model engine. Then pass the created model engine to wrap_model_engine, provided by DeepSpeedTrialContext. We support multiple DeepSpeed model engines if they only use data parallelism or if they use the same model parallel unit.
- Run forward and backward passes.
  In train_batch(), use the methods provided by the DeepSpeed model engine to perform the backward pass and optimizer step. These methods will differ depending on whether you are using pipeline parallelism or not.
- trial_controller_class#
alias of
determined.pytorch.deepspeed._deepspeed_trial.DeepSpeedTrialController
- trial_context_class#
alias of
determined.pytorch.deepspeed._deepspeed_context.DeepSpeedTrialContext
- abstract __init__(context: determined.pytorch.deepspeed._deepspeed_context.DeepSpeedTrialContext) None #
Initializes a trial using the provided context. The general steps are:
1. Initialize the model(s) and, optionally, the optimizer and lr_scheduler. The latter two can also be configured using the DeepSpeed config.
2. Build the DeepSpeed model engine by calling deepspeed.initialize with the model (optionally optimizer and lr scheduler) and a DeepSpeed config. Wrap it with context.wrap_model_engine.
3. If you want, use a custom model parallel unit by calling context.set_mpu.
4. If you want, disable automatic gradient accumulation by calling context.disable_auto_grad_accumulation.
5. If you want, use a custom data loader by calling context.disable_dataset_reproducibility_checks.
Here is a code example.
    self.context = context
    self.args = AttrDict(self.context.get_hparams())

    # Build deepspeed model engine.
    model = ...  # build model
    # deepspeed.initialize returns (engine, optimizer, training_dataloader, lr_scheduler).
    model_engine, optimizer, _, lr_scheduler = deepspeed.initialize(
        args=self.args,
        model=model,
    )
    self.model_engine = self.context.wrap_model_engine(model_engine)
- abstract train_batch(dataloader_iter: Optional[Iterator[Union[Dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor]]], epoch_idx: int, batch_idx: int) Union[torch.Tensor, Dict[str, Any]] #
Train one full batch (i.e. train on train_batch_size samples, perhaps consisting of multiple micro-batches).
If training without pipeline parallelism, users should implement this function by doing the following things:
1. Get a batch from the dataloader_iter and pass it to the GPU.
2. Compute the loss in the forward pass.
3. Perform the backward pass.
4. Perform an optimizer step.
5. Return training metrics in a dictionary.
Here is a code example.
    # Assume one model_engine wrapped in ``__init__``.
    batch = self.context.to_device(next(dataloader_iter))
    loss = self.model_engine(batch)
    self.model_engine.backward(loss)
    self.model_engine.step()
    return {"loss": loss}
If using gradient accumulation over multiple micro-batches, Determined will automatically call train_batch multiple times according to gradient_accumulation_steps in the DeepSpeed config.
With pipeline parallelism there is no need to manually get a batch from the dataloader_iter, and the forward, backward, and optimizer steps are combined in the model engine’s train_batch method.
    # Assume one model_engine wrapped in ``__init__``.
    loss = self.model_engine.train_batch(dataloader_iter)
    return {"loss": loss}
- Parameters
dataloader_iter (Iterator[torch.utils.data.DataLoader], optional) – iterator over the train DataLoader.
epoch_idx (integer) – index of the current epoch among all the epochs processed per device (slot) since the start of training.
batch_idx (integer) – index of the current batch among all the batches processed per device (slot) since the start of training.
- Returns
training metrics to return.
- Return type
torch.Tensor or Dict[str, Any]
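The relationship between a full batch and its micro-batches can be sketched with a little arithmetic. This is a minimal illustration, assuming DeepSpeed's usual convention that train_batch_size equals train_micro_batch_size_per_gpu times gradient_accumulation_steps times the number of data-parallel ranks. The helper name micro_batches_per_slot and the config values are hypothetical and not part of Determined or DeepSpeed.

```python
def micro_batches_per_slot(ds_config: dict, data_parallel_world_size: int) -> int:
    """Return how many micro-batches each slot processes for one full batch.

    Hypothetical helper: it only restates the batch-size arithmetic and does
    not call into DeepSpeed or Determined.
    """
    train_batch_size = ds_config["train_batch_size"]
    micro_batch_size = ds_config["train_micro_batch_size_per_gpu"]
    samples_per_micro_step = micro_batch_size * data_parallel_world_size
    if train_batch_size % samples_per_micro_step != 0:
        raise ValueError("train_batch_size must divide evenly into micro-batches")
    return train_batch_size // samples_per_micro_step


# Example values chosen for illustration only.
ds_config = {"train_batch_size": 64, "train_micro_batch_size_per_gpu": 4}
steps = micro_batches_per_slot(ds_config, data_parallel_world_size=4)
print(steps)  # 64 / (4 * 4) = 4
```

Under these assumed values, each slot would see four micro-batches of four samples for every full batch of 64 samples.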
- abstract build_training_data_loader() Optional[determined.pytorch._data.DataLoader] #
Defines the data loader to use during training.
Must return an instance of determined.pytorch.DataLoader unless context.disable_dataset_reproducibility_checks is called.
If using data parallel training, the batch size should be the per-GPU batch size. If using gradient aggregation, the data loader should return batches with train_micro_batch_size_per_gpu samples each.
- abstract build_validation_data_loader() Optional[determined.pytorch._data.DataLoader] #
Defines the data loader to use during validation.
Must return an instance of determined.pytorch.DataLoader unless context.disable_dataset_reproducibility_checks is called.
If using data parallel training, the batch size should be the per-GPU batch size. If using gradient aggregation, the data loader should return batches with a desired micro batch size (most of the time this is the same as train_micro_batch_size_per_gpu).
- build_callbacks() Dict[str, determined.pytorch._callback.PyTorchCallback] #
Defines a dictionary of string names to callbacks to be used during training and/or validation.
The string name will be used as the key to save and restore callback state for any callback that defines load_state_dict() and state_dict().
- abstract evaluate_batch(dataloader_iter: Optional[Iterator[Union[Dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor]]], batch_idx: int) Dict[str, Any] #
Calculate validation metrics for a batch and return them as a dictionary mapping metric names to metric values. Per-batch validation metrics are averaged to produce a single set of validation metrics for the entire validation set by default.
The metrics returned from this function must be JSON-serializable.
DeepSpeedTrial supports more flexible metrics computation via our custom reducer API; see MetricReducer for more details.
- Parameters
dataloader_iter (Iterator[torch.utils.data.DataLoader], optional) – iterator over the validation DataLoader.
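The default reduction described above can be pictured with a small stand-alone sketch. It assumes an unweighted mean over batches, which is an assumption about the default rather than something this page states, and the function name average_metrics is hypothetical. The json.dumps call only demonstrates the JSON-serializable requirement.

```python
import json
from typing import Dict, List


def average_metrics(per_batch: List[Dict[str, float]]) -> Dict[str, float]:
    """Average each metric over all batches (unweighted mean, by assumption)."""
    keys = per_batch[0].keys()
    return {key: sum(m[key] for m in per_batch) / len(per_batch) for key in keys}


# Two dictionaries shaped like evaluate_batch() return values.
per_batch = [
    {"loss": 0.5, "accuracy": 0.75},
    {"loss": 0.25, "accuracy": 1.0},
]
averaged = average_metrics(per_batch)

# Plain floats serialize without error; tensors would need converting first.
serialized = json.dumps(averaged)
print(serialized)
```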
- save(context: determined.pytorch.deepspeed._deepspeed_context.DeepSpeedTrialContext, path: pathlib.Path) None #
Save is called on every GPU to make sure all checkpoint shards are saved.
By default, we loop through the wrapped model engines and call DeepSpeed’s save:
    for i, m in enumerate(context.models):
        m.save_checkpoint(path, tag=f"model{i}")
This method can be overridden for more custom save behavior.
- load(context: determined.pytorch.deepspeed._deepspeed_context.DeepSpeedTrialContext, load_path: pathlib.Path) None #
By default, we loop through the wrapped model engines and call DeepSpeed’s load.
    for i, m in enumerate(context.models):
        m.load_checkpoint(load_path, tag=f"model{i}")
This method can be overridden for more custom load behavior.
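The default save and load loops can be exercised without a GPU by swapping in a recording object. The sketch below is only an illustration: RecordingEngine is a hypothetical stand-in that mimics the save_checkpoint and load_checkpoint calls shown above, and save_all and load_all are hypothetical names for the two loops.

```python
import pathlib
from typing import List, Optional, Tuple


class RecordingEngine:
    """Hypothetical stand-in for a model engine; it only records calls."""

    def __init__(self) -> None:
        self.calls: List[Tuple[str, str, Optional[str]]] = []

    def save_checkpoint(self, save_dir: pathlib.Path, tag: Optional[str] = None) -> None:
        self.calls.append(("save", str(save_dir), tag))

    def load_checkpoint(self, load_dir: pathlib.Path, tag: Optional[str] = None) -> None:
        self.calls.append(("load", str(load_dir), tag))


def save_all(engines: List[RecordingEngine], path: pathlib.Path) -> None:
    # Same loop as the default save: one tag per wrapped engine.
    for i, m in enumerate(engines):
        m.save_checkpoint(path, tag=f"model{i}")


def load_all(engines: List[RecordingEngine], load_path: pathlib.Path) -> None:
    # Same loop as the default load, using matching tags.
    for i, m in enumerate(engines):
        m.load_checkpoint(load_path, tag=f"model{i}")


engines = [RecordingEngine(), RecordingEngine()]
save_all(engines, pathlib.Path("ckpt"))
load_all(engines, pathlib.Path("ckpt"))
print(engines[1].calls)
```

The point of the sketch is that the tag pairs each engine with its own shard, so a custom override should keep the save and load tags in agreement.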
determined.pytorch.deepspeed.DeepSpeedTrialContext
- class determined.pytorch.deepspeed.DeepSpeedTrialContext(*args: Any, **kwargs: Any)#
Bases: determined._trial_context.TrialContext, determined.pytorch._reducer._PyTorchReducerContext
Contains runtime information for any Determined workflow that uses the DeepSpeedTrial API.
With this class, users can do the following things:
- Wrap DeepSpeed model engines that contain the model, optimizer, lr_scheduler, etc. This will make sure Determined can automatically provide gradient aggregation, checkpointing and fault tolerance. In contrast to determined.pytorch.PyTorchTrial, the user does not need to wrap the optimizer and lr_scheduler, as those should instead be passed to the DeepSpeed initialize function (see https://www.deepspeed.ai/getting-started/#writing-deepspeed-models) when building the model engine.
- Overwrite a DeepSpeed config file or dictionary with values from Determined’s experiment config to ensure consistency in batch size and support hyperparameter tuning.
- Set a custom model parallel configuration that should instantiate a determined.pytorch.deepspeed.ModelParallelUnit dataclass. We automatically set the mpu for data parallel and standard pipeline parallel training. This should only be needed if there is additional model parallelism outside DeepSpeed’s supported methods.
- Disable data reproducibility checks to allow custom data loaders.
- Disable automatic gradient aggregation for non-pipeline-parallel training.
- current_train_batch() int #
Current global batch index
- disable_auto_grad_accumulation() None #
Prevent the DeepSpeedTrialController from automatically calling train_batch multiple times to process enough micro batches to meet the per slot batch size. Thus, the user is responsible for manually training on enough micro batches in train_batch to meet the expected per slot batch size.
- disable_dataset_reproducibility_checks() None #
disable_dataset_reproducibility_checks() allows you to return an arbitrary DataLoader from build_training_data_loader() or build_validation_data_loader().
Normally you would be required to return a det.pytorch.DataLoader instead, which would guarantee that an appropriate Sampler is used that ensures:
- When shuffle=True, the shuffle is reproducible.
- The dataset will start at the right location, even after pausing/continuing.
- Proper sharding is used during distributed training.
However, there can be cases where either reproducibility of the dataset is not needed or where the nature of the dataset can cause the det.pytorch.DataLoader to be unsuitable.
In those cases, you can call disable_dataset_reproducibility_checks() and you will be free to return any torch.utils.data.DataLoader you like. Dataset reproducibility will still be possible, but it will be your responsibility. The Sampler classes in determined.pytorch.samplers can help in this regard.
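To make the three guarantees concrete, here is a minimal pure-Python sketch of index selection that is seeded, sharded, and resumable. It is not how determined.pytorch.samplers is implemented; the function shard_indices and its parameters are hypothetical, and a real data loader would wrap logic like this in a torch Sampler.

```python
import random
from typing import List


def shard_indices(
    dataset_len: int,
    seed: int,
    rank: int,
    world_size: int,
    skip: int = 0,
    shuffle: bool = True,
) -> List[int]:
    """Return the dataset indices one worker should read, in order."""
    indices = list(range(dataset_len))
    if shuffle:
        # A fixed seed makes the shuffle identical on every worker and run.
        random.Random(seed).shuffle(indices)
    # Strided slicing gives each rank a disjoint shard.
    shard = indices[rank::world_size]
    # Skipping already-consumed samples resumes at the right location.
    return shard[skip:]


rank0 = shard_indices(10, seed=7, rank=0, world_size=2)
rank1 = shard_indices(10, seed=7, rank=1, world_size=2)
resumed = shard_indices(10, seed=7, rank=0, world_size=2, skip=2)
print(rank0, rank1, resumed)
```

If you disable the checks, a custom loader would need to provide equivalent behavior itself whenever reproducibility matters.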
- classmethod from_config(config: Dict[str, Any]) determined._trial_context.TrialContext #
Create a context object suitable for debugging outside of Determined.
An example for a subclass of DeepSpeedTrial:
    config = { ... }
    context = det.pytorch.deepspeed.DeepSpeedTrialContext.from_config(config)
    my_trial = MyDeepSpeedTrial(context)

    train_ds = my_trial.build_training_data_loader()
    for epoch_idx in range(3):
        for batch_idx, batch in enumerate(train_ds):
            metrics = my_trial.train_batch(batch, epoch_idx, batch_idx)
            ...
An example for a subclass of TFKerasTrial:
    config = { ... }
    context = det.keras.TFKerasTrialContext.from_config(config)
    my_trial = tf_keras_one_var_model.OneVarTrial(context)

    model = my_trial.build_model()
    model.fit(my_trial.build_training_data_loader())
    eval_metrics = model.evaluate(my_trial.build_validation_data_loader())
- Parameters
config – An experiment config file, in dictionary form.
- get_data_config() Dict[str, Any] #
Return the data configuration.
- get_enable_tensorboard_logging() bool #
Return whether automatic tensorboard logging is enabled
- get_experiment_config() Dict[str, Any] #
Return the experiment configuration.
- get_experiment_id() int #
Return the experiment ID of the current trial.
- get_hparam(name: str) Any #
Return the current value of the hyperparameter with the given name.
- get_hparams() Dict[str, Any] #
Return a dictionary of hyperparameter names to values.
- get_stop_requested() bool #
Return whether a trial stoppage has been requested.
- get_tensorboard_path() pathlib.Path #
Get the path where files for consumption by TensorBoard should be written
- get_tensorboard_writer() Any #
This function returns an instance of torch.utils.tensorboard.SummaryWriter.
Trial users who wish to log to TensorBoard can use this writer object. We provide and manage a writer in order to save and upload TensorBoard files automatically on behalf of the user.
Usage example:
    class MyModel(PyTorchTrial):
        def __init__(self, context):
            ...
            self.writer = context.get_tensorboard_writer()

        def train_batch(self, batch, epoch_idx, batch_idx):
            self.writer.add_scalar('my_metric', np.random.random(), batch_idx)
            self.writer.add_image('my_image', torch.ones((3,32,32)), batch_idx)
- get_trial_id() int #
Return the trial ID of the current trial.
- is_epoch_end() bool #
Returns true if the current batch is the last batch of the epoch.
Warning
Not accurate for variable size epochs.
- is_epoch_start() bool #
Returns true if the current batch is the first batch of the epoch.
Warning
Not accurate for variable size epochs.
- set_enable_tensorboard_logging(enable_tensorboard_logging: bool) None #
Set a flag to indicate whether automatic upload to tensorboard is enabled.
- set_mpu(mpu: determined.pytorch.deepspeed._mpu.ModelParallelUnit) None #
Use a custom model parallel configuration.
The argument mpu should implement a determined.pytorch.deepspeed.ModelParallelUnit dataclass to provide information on data parallel topology and whether a rank should compute metrics/build data loaders.
This should only be needed if training with custom model parallelism.
In the case of multiple model parallel engines, we assume that the MPU and data loaders correspond to the first wrapped model engine.
- set_profiler(*args: List[str], **kwargs: Any) None #
set_profiler() is a thin wrapper around PyTorch profiler, torch-tb-profiler. It overrides the on_trace_ready parameter to the Determined tensorboard path, while all other arguments are passed directly into torch.profiler.profile. Stepping the profiler will be handled automatically during the training loop.
See the PyTorch profiler plugin for details.
Examples:
Profiling GPU and CPU activities, skipping batch 1, warming up on batch 2, and profiling batches 3 and 4.
    self.context.set_profiler(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ],
        schedule=torch.profiler.schedule(
            wait=1,
            warmup=1,
            active=2
        ),
    )
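How the wait, warmup, and active values map onto batches can be checked with a small pure-Python model. This sketch does not import torch; profiler_phase is a hypothetical helper, and the modulo that restarts the cycle is an assumption about the schedule's default repeating behavior.

```python
def profiler_phase(step: int, wait: int, warmup: int, active: int) -> str:
    """Classify a zero-based step as 'wait', 'warmup', or 'active'."""
    position = step % (wait + warmup + active)  # assumed: the cycle repeats
    if position < wait:
        return "wait"
    if position < wait + warmup:
        return "warmup"
    return "active"


# Steps 0..3 correspond to batches 1..4 in the description above.
phases = [profiler_phase(step, wait=1, warmup=1, active=2) for step in range(4)]
print(phases)  # ['wait', 'warmup', 'active', 'active']
```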
- set_stop_requested(stop_requested: bool) None #
Set a flag to request a trial stoppage. When this flag is set to True, we finish the step, checkpoint, then exit.
- to_device(data: Union[Dict[str, Union[numpy.ndarray, torch.Tensor]], Sequence[Union[numpy.ndarray, torch.Tensor]], numpy.ndarray, torch.Tensor]) Union[Dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor] #
Map data to the device allocated by the Determined cluster.
Since we pass an iterable over the data loader to train_batch and evaluate_batch for DeepSpeedTrial, the user is responsible for moving data to GPU if needed. This is basically a helper function to make that easier.
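The signature accepts a dictionary, a sequence, or a single array or tensor, which suggests a structural traversal. The sketch below illustrates that idea only; it is not Determined's implementation. map_structure and the move callback are hypothetical, and strings stand in for tensors so the example runs without torch.

```python
from typing import Any, Callable


def map_structure(data: Any, move: Callable[[Any], Any]) -> Any:
    """Apply move() to every leaf of a nested dict/list/tuple structure."""
    if isinstance(data, dict):
        return {key: map_structure(value, move) for key, value in data.items()}
    if isinstance(data, (list, tuple)):
        # Rebuild the same container type around the converted items.
        return type(data)(map_structure(value, move) for value in data)
    return move(data)


# Strings stand in for tensors; with torch, move could be lambda t: t.to(device).
batch = {"x": "features", "y": ["labels", "weights"]}
moved = map_structure(batch, lambda leaf: f"{leaf}@gpu")
print(moved)
```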
- wrap_model_engine(model: deepspeed.DeepSpeedEngine) deepspeed.DeepSpeedEngine #
Register a DeepSpeed model engine.
In the background, we track the model engine for checkpointing, set batch size information using the first wrapped model engine, and perform checks to properly handle pipeline parallelism if the model engine is a PipelineEngine.
- wrap_reducer(reducer: Union[Callable, determined.pytorch._reducer.MetricReducer], name: Optional[str] = None, for_training: bool = True, for_validation: bool = True) determined.pytorch._reducer.MetricReducer #
Register a custom reducer that will calculate a metric properly, even with distributed training.
During distributed training and evaluation, many types of metrics must be calculated globally, rather than calculated on each shard of the dataset and then averaged or summed. For example, an accurate ROC AUC for a dataset cannot be derived from the individual ROC AUC metrics calculated by each worker.
Determined solves this problem by offering fully customizable metric reducers which are distributed-aware. These are registered by calling context.wrap_reducer() and are updated by the user during train_batch() or evaluate_batch().
- Parameters
reducer (Union[Callable, pytorch.MetricReducer]) – Either a reducer function or a pytorch.MetricReducer. See below for more details.
name (Optional[str] = None) – Either a string name to associate with the metric returned by the reducer, or None to indicate the metric will return a dict mapping string names to metric values. This allows for a single reducer to return many metrics, such as for a per-class mean IOU calculation. Note that if name is a string, the returned metric must NOT be a dict-type metric.
for_training (bool = True) – Indicate that the reducer should be used for training workloads.
for_validation (bool = True) – Indicate that the reducer should be used for validation workloads.
- Returns
pytorch.MetricReducer: If reducer was a function, the returned MetricReducer will have a single user-facing method like def update(value: Any) -> None that you should call during train_batch or evaluate_batch. Otherwise, the return value will just be the reducer that was passed in.
Reducer functions: the simple API
If the reducer parameter is a function, it must have the following properties:
- It accepts a single parameter, which will be a flat list of all inputs the users pass when they call .update() on the object returned by wrap_reducer(). See the code example below for more details.
- It returns either a single (non-dict) metric or a dictionary mapping names to metrics, as described above.
The primary motivation for passing a function as the reducer is simplicity. Metrics from all batches will be buffered in memory and passed over the network before they are reduced all at once. This introduces some overhead, but it is likely unnoticeable for scalar metrics or on validation datasets of small or medium size. This single function strategy can be useful for quick prototyping or for calculating metrics that are difficult or impossible to calculate incrementally.
For example, ROC AUC could be properly calculated by passing a small wrapper function calling sklearn.metrics.roc_auc_score:
    # Custom reducer function.
    def roc_auc_reducer(values):
        # values will be a flat list of all inputs to
        # .update(), which in this code example are
        # tuples of (y_true, y_score). We reshape
        # that list into two separate lists:
        y_trues, y_scores = zip(*values)

        # Then we return a metric value:
        return sklearn.metrics.roc_auc_score(
            np.array(y_trues), np.array(y_scores)
        )

    class MyPyTorchTrial(PyTorchTrial):
        def __init__(self, context):
            self.roc_auc = context.wrap_reducer(
                roc_auc_reducer, name="roc_auc"
            )
            ...

        def evaluate_batch(self, batch):
            ...
            # Function-based reducers are updated with .update().
            # The roc_auc_reducer function will get a list of all
            # inputs that we pass in here:
            self.roc_auc.update((y_true, y_score))

            # The "roc_auc" metric will be included in the final
            # metrics after the workload has completed; no need
            # to return it here. If that is your only metric,
            # just return an empty dict.
            return {}
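The buffering behavior of function-based reducers can be simulated in plain Python to show why a global reduction differs from combining per-worker results. Everything here is a stand-in: BufferedReducer and reduce_across_workers are hypothetical names, and no Determined code is called.

```python
from typing import Any, Callable, List


class BufferedReducer:
    """Hypothetical stand-in that buffers every value passed to update()."""

    def __init__(self) -> None:
        self.values: List[Any] = []

    def update(self, value: Any) -> None:
        self.values.append(value)


def reduce_across_workers(
    workers: List[BufferedReducer], reducer_fn: Callable[[List[Any]], Any]
) -> Any:
    # Flatten the buffers from all workers, then reduce once.
    flat = [value for worker in workers for value in worker.values]
    return reducer_fn(flat)


def global_mean(values: List[float]) -> float:
    return sum(values) / len(values)


worker_a, worker_b = BufferedReducer(), BufferedReducer()
for value in (1.0, 2.0, 3.0):
    worker_a.update(value)
worker_b.update(10.0)

global_result = reduce_across_workers([worker_a, worker_b], global_mean)
naive_result = (global_mean(worker_a.values) + global_mean(worker_b.values)) / 2
print(global_result, naive_result)  # 4.0 6.0
```

With uneven shards, averaging the per-worker means gives 6.0 while the true mean over all four values is 4.0, which is the kind of discrepancy a distributed-aware reducer avoids.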
MetricReducer objects: the advanced API
The primary motivation for passing a det.pytorch.MetricReducer as the reducer is performance. det.pytorch.MetricReducer allows the user more control in how values are stored and exposes a per_slot_reduce() call which lets users minimize the cost of the network communication before the final cross_slot_reduce().
An additional reason for using the det.pytorch.MetricReducer is for flexibility of the update mechanism, which is completely user-defined when subclassing MetricReducer.
For the full details and a code example, see: MetricReducer.
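The two-stage reduction can be sketched with a plain class that borrows the per_slot_reduce() and cross_slot_reduce() names mentioned above. This is a simulation under stated assumptions: MeanReducer does not subclass det.pytorch.MetricReducer, the reset() and update() methods are illustrative choices, and the cross-slot gather is replaced by a Python list.

```python
from typing import List, Tuple


class MeanReducer:
    """Plain-Python sketch of a two-stage mean; not the Determined base class."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, value: float) -> None:
        # Keep only a running sum and count instead of buffering every value.
        self.total += value
        self.count += 1

    def per_slot_reduce(self) -> Tuple[float, int]:
        # Only this small tuple would need to cross the network.
        return self.total, self.count

    def cross_slot_reduce(self, per_slot_metrics: List[Tuple[float, int]]) -> float:
        totals, counts = zip(*per_slot_metrics)
        return sum(totals) / sum(counts)


slot0, slot1 = MeanReducer(), MeanReducer()
for value in (1.0, 2.0, 3.0):
    slot0.update(value)
slot1.update(10.0)

gathered = [slot0.per_slot_reduce(), slot1.per_slot_reduce()]
final = slot0.cross_slot_reduce(gathered)
print(gathered, final)
```

Compared with a function-based reducer, each slot here sends one (sum, count) pair rather than every buffered value.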
determined.pytorch.deepspeed.overwrite_deepspeed_config
- determined.pytorch.deepspeed.overwrite_deepspeed_config(base_ds_config: Union[str, Dict], source_ds_dict: Dict[str, Any]) Dict[str, Any] #
Overwrite a base_ds_config with values from a source_ds_dict.
You can use source_ds_dict to overwrite leaf nodes of the base_ds_config. More precisely, we will iterate depth first into source_ds_dict and if a node corresponds to a leaf node of base_ds_config, we copy the node value over to base_ds_config.
- Parameters
base_ds_config (str or Dict) – either a path to a DeepSpeed config file or a dictionary.
source_ds_dict (Dict) – dictionary with fields that we want to copy to base_ds_config
- Returns
The resulting dictionary when base_ds_config is overwritten with source_ds_dict.
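The depth-first leaf overwrite described above can be sketched as follows. This is an illustration written from the description, not the library's code: overwrite_leaves is a hypothetical name, it works on a copy, and its handling of keys that are missing from the base config (they are skipped) is an assumption that the real function may not share.

```python
import copy
from typing import Any, Dict


def overwrite_leaves(base: Dict[str, Any], source: Dict[str, Any]) -> Dict[str, Any]:
    """Copy values from source onto matching leaf nodes of a copy of base."""
    result = copy.deepcopy(base)

    def visit(base_node: Dict[str, Any], source_node: Dict[str, Any]) -> None:
        for key, value in source_node.items():
            if key not in base_node:
                continue  # assumption: keys absent from base are skipped
            if isinstance(base_node[key], dict):
                # Not a leaf of base: descend only if source also has a dict.
                if isinstance(value, dict):
                    visit(base_node[key], value)
            else:
                # Leaf of base: copy the source value over.
                base_node[key] = value

    visit(result, source)
    return result


base = {"train_batch_size": 16, "optimizer": {"type": "Adam", "params": {"lr": 0.001}}}
source = {"train_batch_size": 64, "optimizer": {"params": {"lr": 0.01}}}
merged = overwrite_leaves(base, source)
print(merged)
```

In a trial, the source dictionary would typically come from hyperparameters so that searched values replace the defaults in the DeepSpeed config.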
determined.pytorch.deepspeed.ModelParallelUnit
- class determined.pytorch.deepspeed.ModelParallelUnit(data_parallel_rank: int, data_parallel_world_size: int, should_report_metrics: bool, should_build_data_loader: bool)#
This class contains the functions we expect in order to accurately carry out parallel training. For custom model parallel training, you need to subclass and override the functions before passing it to the DeepSpeedTrialContext by calling context.set_mpu(mpu).
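As a rough illustration of what the four fields describe, the sketch below derives them for one rank. It uses a stand-in dataclass with the same field names as the signature above so that it runs without Determined installed. The layout, in which consecutive ranks form one model parallel group and the first rank of each group reports metrics and builds data loaders, is an assumption, and mpu_for_rank is a hypothetical helper.

```python
from dataclasses import dataclass


@dataclass
class ModelParallelUnitSketch:
    """Stand-in mirroring the documented fields; not the Determined class."""

    data_parallel_rank: int
    data_parallel_world_size: int
    should_report_metrics: bool
    should_build_data_loader: bool


def mpu_for_rank(rank: int, world_size: int, model_parallel_size: int) -> ModelParallelUnitSketch:
    if world_size % model_parallel_size != 0:
        raise ValueError("world_size must be a multiple of model_parallel_size")
    # Assumed layout: ranks [0, 1] form group 0, ranks [2, 3] form group 1, ...
    is_first_in_group = rank % model_parallel_size == 0
    return ModelParallelUnitSketch(
        data_parallel_rank=rank // model_parallel_size,
        data_parallel_world_size=world_size // model_parallel_size,
        should_report_metrics=is_first_in_group,
        should_build_data_loader=is_first_in_group,
    )


mpu = mpu_for_rank(rank=5, world_size=8, model_parallel_size=2)
print(mpu)
```

With the real class, an object built from the same four values would be passed to context.set_mpu inside __init__().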
The following classes and methods overlap with PyTorchTrial (click to go to respective documentation):