pytorch.experimental.torch_batch_process API Reference

User Guide: Torch Batch Processing API

Caution

This is an experimental API and may change at any time.

The main arguments to `torch_batch_process` are `batch_processor_cls`, a subclass of `TorchBatchProcessor`, and `dataset`.

torch_batch_process(
    batch_processor_cls=MyProcessor,
    dataset=dataset,
)
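
For orientation, here is a minimal, self-contained sketch of the pattern, assuming the script is launched as a Determined task; `MyProcessor`, the toy `TensorDataset`, and the per-batch logic are illustrative assumptions, not part of the API:

import torch
from torch.utils.data import TensorDataset

from determined.pytorch import experimental


class MyProcessor(experimental.TorchBatchProcessor):
    def __init__(self, context):
        # context is the TorchBatchProcessorContext documented below.
        self.context = context

    def process_batch(self, batch, batch_idx) -> None:
        # Move the batch to this worker's device; real per-batch
        # inference or preprocessing logic would go here.
        batch = self.context.to_device(batch)


if __name__ == "__main__":
    dataset = TensorDataset(torch.randn(128, 8))  # any map-style Dataset
    experimental.torch_batch_process(
        batch_processor_cls=MyProcessor,
        dataset=dataset,
    )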

determined.pytorch.experimental.torch_batch_process

determined.pytorch.experimental.torch_batch_process(batch_processor_cls: Type[determined.pytorch.experimental._torch_batch_process.TorchBatchProcessor], dataset: torch.utils.data.dataset.Dataset, batch_size: Optional[int] = None, max_batches: Optional[int] = None, checkpoint_interval: int = 5, dataloader_kwargs: Optional[Dict[str, Any]] = None, distributed_context: Optional[determined.core._distributed.DistributedContext] = None) → None

`torch_batch_process` shards the provided dataset, iterates through it, and processes each batch with the user-defined logic in `batch_processor_cls`.

Parameters
  • batch_processor_cls – A user-defined class extending `TorchBatchProcessor`

  • dataset – A torch map-style dataset implementing `__len__()` and `__getitem__()`

  • batch_size – The number of items in each batch

  • max_batches – The maximum number of batches to iterate over per worker

  • checkpoint_interval – The interval, in batches, at which to checkpoint progress (i.e., record the number of batches processed)

  • dataloader_kwargs – Keyword arguments to pass to the PyTorch DataLoader

  • distributed_context – The distributed context used to initialize the core context
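
For illustration, a call exercising the optional parameters might look like the following; the values are arbitrary, and `MyProcessor` and `dataset` are the sketch objects from above:

torch_batch_process(
    batch_processor_cls=MyProcessor,
    dataset=dataset,
    batch_size=64,                 # 64 items per batch
    max_batches=100,               # stop each worker after 100 batches
    checkpoint_interval=10,        # record progress every 10 batches
    dataloader_kwargs={"num_workers": 2, "pin_memory": True},
)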

determined.pytorch.experimental.TorchBatchProcessorContext

class determined.pytorch.experimental.TorchBatchProcessorContext(core_context: determined.core._context.Context, storage_path: str)
to_device(data: Union[Dict[str, Union[numpy.ndarray, torch.Tensor]], Sequence[Union[numpy.ndarray, torch.Tensor]], numpy.ndarray, torch.Tensor], warned_types: Optional[Set[Type]] = None) → Union[Dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor]

Accept np.ndarray, torch.Tensor, list, or dictionary. Recursively convert any ndarrays to tensors and call .to() on any tensors or data types that have custom serialization logic defined via a callable to() attribute.

If the data cannot be moved to device, log a warning (only once per type) and return the original data.
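
As a sketch, a `process_batch` implementation might move an entire dictionary batch at once; the "features" key is a hypothetical batch layout, and `self.model` is assumed to have been prepared earlier (e.g., via `prepare_model_for_inference`):

def process_batch(self, batch, batch_idx) -> None:
    # Recursively moves every ndarray/tensor in the batch onto the device.
    inputs = self.context.to_device(batch)
    with torch.no_grad():
        outputs = self.model(inputs["features"])  # "features" is illustrative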

get_tensorboard_path() → pathlib.Path

TensorBoard files should be written to the returned path in order to be shown properly in the UI.

For example, the path should be passed to the PyTorch profiler as shown below:

tensorboard_path = context.get_tensorboard_path()
torch.profiler.profile(
    activities=...,
    schedule=...,
    on_trace_ready=torch.profiler.tensorboard_trace_handler(str(tensorboard_path)),
)
prepare_model_for_inference(model: torch.nn.modules.module.Module) → torch.nn.modules.module.Module

Set the model to eval mode and move it to the device.

Parameters
  • model – a `nn.Module`
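
This is typically called once in the processor's `__init__`; the linear model below is a stand-in for whatever model is actually loaded:

class MyProcessor(experimental.TorchBatchProcessor):
    def __init__(self, context):
        self.context = context
        model = torch.nn.Linear(8, 2)  # hypothetical placeholder model
        # Returns the model in eval mode, moved to this worker's device.
        self.model = context.prepare_model_for_inference(model)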

upload_path() → AbstractContextManager[pathlib.Path]

Returns a context manager that yields a local path; on exit, files written under that path are uploaded to the default storage path.
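
A sketch of typical usage inside a processor, assuming `self.output` is a list of results the processor has accumulated:

with self.context.upload_path() as path:
    # Anything written under `path` is uploaded when the block exits.
    torch.save(self.output, path / "output.pt")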

determined.pytorch.experimental.TorchBatchProcessor

class determined.pytorch.experimental.TorchBatchProcessor(context: determined.pytorch.experimental._torch_batch_process.TorchBatchProcessorContext)
abstract process_batch(batch: Any, batch_idx: int) → None

This function will be called on every batch of data in the dataset.

Parameters
  • batch – a batch of data from the dataset passed into `torch_batch_process`

  • batch_idx – the index of the batch. Note that the index is per worker. For example, with 8 batches of data and 4 workers, each worker would process two batches (batch_idx = 0 and batch_idx = 1).

on_checkpoint_start() → None

This function will be called right before each checkpoint.

on_finish() → None

This function will be called right before exiting, after iteration over the dataset has completed.
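
Putting the hooks together, a processor that buffers results and flushes them at checkpoint boundaries might look like this sketch; the buffering scheme and file names are illustrative assumptions, not requirements of the API:

class BufferedProcessor(experimental.TorchBatchProcessor):
    def __init__(self, context):
        self.context = context
        self.buffer = []
        self.flush_count = 0

    def process_batch(self, batch, batch_idx) -> None:
        batch = self.context.to_device(batch)
        self.buffer.append({"batch_idx": batch_idx})  # placeholder result

    def on_checkpoint_start(self) -> None:
        # Flush buffered results before progress is recorded, so a
        # resumed run does not reprocess batches already persisted.
        if self.buffer:
            with self.context.upload_path() as path:
                torch.save(self.buffer, path / f"results_{self.flush_count}.pt")
            self.flush_count += 1
            self.buffer = []

    def on_finish(self) -> None:
        # Flush anything processed since the last checkpoint.
        self.on_checkpoint_start()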