pytorch.experimental.torch_batch_process API Reference

User Guide: Torch Batch Processing API


This is an experimental API and may change at any time.

The main arguments to torch_batch_process are batch_processor_cls (a subclass of TorchBatchProcessor) and dataset.



determined.pytorch.experimental.torch_batch_process(batch_processor_cls: Type[determined.pytorch.experimental._torch_batch_process.TorchBatchProcessor], dataset:, batch_size: Optional[int] = None, max_batches: Optional[int] = None, checkpoint_interval: int = 5, dataloader_kwargs: Optional[Dict[str, Any]] = None, distributed_context: Optional[determined.core._distributed.DistributedContext] = None) -> None

`torch_batch_process` shards the provided dataset, iterates through it, and processes each batch with the user-defined logic in `batch_processor_cls`.

  • batch_processor_cls – A user-defined class extending `TorchBatchProcessor`

  • dataset – A torch dataset class implementing __len__() and __getitem__()

  • batch_size – The number of items in each batch

  • max_batches – The maximum number of batches to iterate over per worker

  • checkpoint_interval – Interval to checkpoint progress (i.e. record number of batches processed)

  • dataloader_kwargs – Kwargs to pass to PyTorch dataloader

  • distributed_context – Distributed context to initialize core context


class determined.pytorch.experimental.TorchBatchProcessorContext(core_context: determined.core._context.Context, storage_path: str)

to_device(data: Union[Dict[str, Union[numpy.ndarray, torch.Tensor]], Sequence[Union[numpy.ndarray, torch.Tensor]], numpy.ndarray, torch.Tensor], warned_types: Optional[Set[Type]] = None) -> Union[Dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor]

Accept np.ndarray, torch.Tensor, list, or dictionary. Recursively convert any ndarrays to tensors and call .to() on any tensors or data types that have custom serialization logic defined via a callable to() attribute.

If the data cannot be moved to device, log a warning (only once per type) and return the original data.
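The recursive behavior described above can be illustrated with a small sketch. This is not the library's implementation: `to_device_sketch` and `FakeTensor` are hypothetical names, the ndarray-to-tensor conversion step is left out so that the example runs without NumPy or PyTorch, and the warning goes through the standard `logging` module.

```python
import logging
from typing import Any, Optional, Set, Type

logger = logging.getLogger(__name__)


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor: it only remembers its device."""

    def __init__(self, device: str = "cpu") -> None:
        self.device = device

    def to(self, device: str) -> "FakeTensor":
        return FakeTensor(device)


def to_device_sketch(
    data: Any, device: str, warned_types: Optional[Set[Type]] = None
) -> Any:
    """Recursively move data to device, warning once per unsupported type."""
    if warned_types is None:
        warned_types = set()
    if isinstance(data, dict):
        return {k: to_device_sketch(v, device, warned_types) for k, v in data.items()}
    if isinstance(data, (list, tuple)):
        return type(data)(to_device_sketch(v, device, warned_types) for v in data)
    # Anything exposing a callable to() is treated like a tensor.
    if callable(getattr(data, "to", None)):
        return data.to(device)
    # Unsupported type: warn only the first time it is seen, return it unchanged.
    if type(data) not in warned_types:
        warned_types.add(type(data))
        logger.warning("Cannot move %s to device; returning it unchanged.", type(data).__name__)
    return data


batch = {"x": FakeTensor(), "ys": [FakeTensor(), "label"]}
moved = to_device_sketch(batch, "cuda:0")
```

Here the two `FakeTensor` objects end up on `cuda:0`, while the string `"label"` is returned as-is after a single warning.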

get_tensorboard_path() -> pathlib.Path

Return the path to which TensorBoard files should be written so that they are displayed properly in the UI.

For example, this path can be passed to the PyTorch profiler through torch.profiler.tensorboard_trace_handler.
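One way to wire the TensorBoard path into the PyTorch profiler is sketched here. The helper name `build_profiler` is hypothetical, and the schedule values and the CPU-only activity list are arbitrary choices for illustration; only `torch.profiler.profile`, `torch.profiler.schedule`, and `torch.profiler.tensorboard_trace_handler` are PyTorch APIs.

```python
import pathlib


def build_profiler(tensorboard_path: pathlib.Path):
    """Return a torch.profiler.profile that writes traces under tensorboard_path."""
    # Imported lazily so this helper can be defined where PyTorch is not installed.
    import torch

    return torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU],
        # Arbitrary example schedule: skip 1 step, warm up for 1, record 2.
        schedule=torch.profiler.schedule(wait=1, warmup=1, active=2, repeat=1),
        # The trace handler takes the output directory name as a string.
        on_trace_ready=torch.profiler.tensorboard_trace_handler(str(tensorboard_path)),
    )
```

Inside a processor, the helper could be called with the value returned by get_tensorboard_path(), with the resulting profiler's step() method called once per processed batch.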

prepare_model_for_inference(model: torch.nn.modules.module.Module) -> torch.nn.modules.module.Module

Set the model to eval mode and send it to the device.

  • model – an nn.Module

upload_path() -> AbstractContextManager[pathlib.Path]

Return a context manager that uploads files to the default storage path on exit.
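The "upload on exit" pattern can be pictured with a standard-library stand-in. This sketch does not call Determined: `upload_path_sketch` and `storage` are hypothetical names, a local directory plays the role of the storage path, and the choice to skip the copy when the body raises is an assumption of the sketch, not documented behavior.

```python
import contextlib
import pathlib
import shutil
import tempfile
from typing import Iterator


@contextlib.contextmanager
def upload_path_sketch(storage_dir: pathlib.Path) -> Iterator[pathlib.Path]:
    """Yield a scratch directory and copy its contents to storage_dir on exit."""
    with tempfile.TemporaryDirectory() as tmp:
        scratch = pathlib.Path(tmp)
        yield scratch
        # Reached only when the with-body finishes without raising.
        shutil.copytree(scratch, storage_dir, dirs_exist_ok=True)


storage = pathlib.Path(tempfile.mkdtemp())  # hypothetical stand-in for the storage path
with upload_path_sketch(storage) as path:
    (path / "predictions.txt").write_text("0.1\n0.9\n")

uploaded = sorted(p.name for p in storage.iterdir())
```

With the real context, the analogous usage would be to write output files under the yielded path inside the `with` block and let the context handle the upload.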

report_metrics(group: str, steps_completed: int, metrics: Dict[str, Any]) -> None

Report metrics data to the master.

  • group (string) – metrics group name. Can be used to partition metrics into different logical groups or time series. The “training” and “validation” group names map to the built-in training and validation time series. Note: the group name cannot contain the . character.

  • steps_completed (int) – global step number, e.g. the number of batches processed.

  • metrics (Dict[str, Any]) – metrics data dictionary. Must be JSON-serializable. When reporting metrics with the same group and steps_completed values, the dictionary keys must not overlap.

report_task_using_model_version(model_version: model.ModelVersion) -> None

Associate model_version with the current task. This links the metrics reporting so that any metrics reported to the current task are visible when querying for metrics associated with this model version.

  • model_version (model.ModelVersion) – The model version to associate with this task

report_task_using_checkpoint(checkpoint: checkpoint.Checkpoint) -> None

Associate checkpoint with the current task. This links the metrics reporting so that any metrics reported to the current task are visible when querying for metrics associated with this checkpoint.

  • checkpoint (checkpoint.Checkpoint) – The checkpoint to associate with this task

get_distributed_rank() -> int

The rank of the current process in a trial.

get_distributed_size() -> int

The number of slots this trial is running on.


class determined.pytorch.experimental.TorchBatchProcessor(context: determined.pytorch.experimental._torch_batch_process.TorchBatchProcessorContext)

abstract process_batch(batch: Any, batch_idx: int) -> None

This function is called with every batch of data in the dataset.

  • batch – a batch of data of the dataset passed into torch_batch_process

  • batch_idx – index of the batch. Note that the index is per worker. For example, if there are 8 batches of data to process and 4 workers, each worker gets two batches of data (batch_idx = 0 and batch_idx = 1)
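The per-worker numbering in the example above (8 batches, 4 workers) can be reproduced with a short sketch. The function name `assign_batches` is hypothetical, and the round-robin assignment of batches to workers is an assumption made for illustration; the text specifies only that batch_idx restarts at 0 on every worker.

```python
from typing import Dict, List, Tuple


def assign_batches(num_batches: int, num_workers: int) -> Dict[int, List[Tuple[int, int]]]:
    """Map each worker rank to (batch_idx, global_batch) pairs, round-robin."""
    plan: Dict[int, List[Tuple[int, int]]] = {rank: [] for rank in range(num_workers)}
    for global_batch in range(num_batches):
        rank = global_batch % num_workers  # assumed round-robin sharding
        # batch_idx counts only the batches this worker has received so far.
        plan[rank].append((len(plan[rank]), global_batch))
    return plan


plan = assign_batches(num_batches=8, num_workers=4)
per_worker_idx = {rank: [idx for idx, _ in pairs] for rank, pairs in plan.items()}
```

Every worker sees batch_idx values 0 and 1, so batch_idx alone does not identify a batch globally. Combining it with the worker's rank is one way to build unique names for output files.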

on_checkpoint_start() -> None

This function is called right before each checkpoint.

on_finish() -> None

This function is called right before exiting, after iteration over the dataset has completed.
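The order in which the three hooks fire can be pictured with a simplified driver loop. This is a sketch under stated assumptions rather than the library's loop: `RecordingProcessor` and `drive` are hypothetical names, the processor does not inherit from TorchBatchProcessor, and checkpointing after every `checkpoint_interval` batches is an assumed reading of that parameter.

```python
from typing import Any, Iterable, List


class RecordingProcessor:
    """Hypothetical processor that records the order in which hooks are called."""

    def __init__(self) -> None:
        self.events: List[str] = []

    def process_batch(self, batch: Any, batch_idx: int) -> None:
        self.events.append(f"batch {batch_idx}")

    def on_checkpoint_start(self) -> None:
        self.events.append("checkpoint")

    def on_finish(self) -> None:
        self.events.append("finish")


def drive(processor: RecordingProcessor, batches: Iterable[Any], checkpoint_interval: int) -> None:
    """Simplified stand-in for the loop a single worker might run."""
    for batch_idx, batch in enumerate(batches):
        processor.process_batch(batch, batch_idx)
        # Assumption: progress is checkpointed after every checkpoint_interval batches.
        if (batch_idx + 1) % checkpoint_interval == 0:
            processor.on_checkpoint_start()
    processor.on_finish()


processor = RecordingProcessor()
drive(processor, batches=["a", "b", "c", "d", "e"], checkpoint_interval=2)
```

A typical use of on_checkpoint_start is to flush results accumulated in process_batch, so that work completed before the checkpoint is not lost if the job is interrupted.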