pytorch.experimental.torch_batch_process API Reference#
Caution
This is an experimental API and may change at any time.
The main arguments to torch_batch_process are batch_processor_cls, a subclass of TorchBatchProcessor, and dataset.

torch_batch_process(
    batch_processor_cls=MyProcessor,
    dataset=dataset,
)
determined.pytorch.torch_batch_process#
- determined.pytorch.experimental.torch_batch_process(batch_processor_cls: Type[determined.pytorch.experimental._torch_batch_process.TorchBatchProcessor], dataset: torch.utils.data.dataset.Dataset, batch_size: Optional[int] = None, max_batches: Optional[int] = None, checkpoint_interval: int = 5, dataloader_kwargs: Optional[Dict[str, Any]] = None, distributed_context: Optional[determined.core._distributed.DistributedContext] = None) → None #
`torch_batch_process` shards the provided dataset, iterates through it, and processes each batch with the user-defined logic in `batch_processor_cls`.
- Parameters
batch_processor_cls – A user-defined class extending
`TorchBatchProcessor`
dataset – A torch dataset implementing __len__() and __getitem__()
batch_size – The number of items in each batch
max_batches – The maximum number of batches to iterate over per worker
checkpoint_interval – Interval to checkpoint progress (i.e. record number of batches processed)
dataloader_kwargs – Kwargs to pass to PyTorch dataloader
distributed_context – Distributed context to initialize core context
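Putting these arguments together, a minimal run might look like the sketch below. The EmbeddingProcessor class, the my_model module, and the dataset object are assumptions for illustration; only the call pattern follows the signature above.

```python
import torch
from determined.pytorch import experimental


class EmbeddingProcessor(experimental.TorchBatchProcessor):
    def __init__(self, context):
        self.context = context
        # my_model is a placeholder for any nn.Module you construct or load.
        self.model = context.prepare_model_for_inference(my_model)
        self.outputs = []

    def process_batch(self, batch, batch_idx):
        # Move the batch to this worker's device before the forward pass.
        batch = self.context.to_device(batch)
        with torch.no_grad():
            self.outputs.append(self.model(batch))


experimental.torch_batch_process(
    batch_processor_cls=EmbeddingProcessor,
    dataset=dataset,  # any torch Dataset implementing __len__ and __getitem__
    batch_size=64,
    checkpoint_interval=5,
)
```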
determined.pytorch.TorchBatchProcessorContext#
- class determined.pytorch.experimental.TorchBatchProcessorContext(core_context: determined.core._context.Context, storage_path: str)#
- to_device(data: Union[Dict[str, Union[numpy.ndarray, torch.Tensor]], Sequence[Union[numpy.ndarray, torch.Tensor]], numpy.ndarray, torch.Tensor], warned_types: Optional[Set[Type]] = None) → Union[Dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor] #
Accept np.ndarray, torch.Tensor, list, or dictionary. Recursively convert any ndarrays to tensors and call .to() on any tensors or data types that have custom serialization logic defined via a callable to() attribute.
If the data cannot be moved to device, log a warning (only once per type) and return the original data.
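As an illustration, a processor would typically call this at the top of process_batch before the forward pass; the model attribute here is an assumption carried over from the sketch above.

```python
def process_batch(self, batch, batch_idx):
    # Recursively converts ndarrays to tensors and moves them to the device.
    batch = self.context.to_device(batch)
    with torch.no_grad():
        predictions = self.model(batch)
```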
- get_tensorboard_path() → pathlib.Path #
TensorBoard files should be written to the returned path in order to be shown properly in the UI.
For example, the path should be passed to the PyTorch profiler as shown below:
torch.profiler.profile(
    activities=...,
    schedule=...,
    on_trace_ready=torch.profiler.tensorboard_trace_handler(<tensorboard_path>),
)
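The same path works for any TensorBoard writer; for example, inside process_batch (a sketch using torch.utils.tensorboard; the scalar name is illustrative):

```python
from torch.utils.tensorboard import SummaryWriter

tensorboard_path = self.context.get_tensorboard_path()
writer = SummaryWriter(log_dir=str(tensorboard_path))
writer.add_scalar("batches_processed", batch_idx + 1, batch_idx)
writer.flush()
```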
- prepare_model_for_inference(model: torch.nn.modules.module.Module) → torch.nn.modules.module.Module #
Set the model to eval mode and move it to the device.
- Parameters
model – a nn.Module
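A typical call site is the processor's __init__, as in the sketch above; my_model stands in for any module you construct or load:

```python
self.model = self.context.prepare_model_for_inference(my_model)
```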
- upload_path() → AbstractContextManager[pathlib.Path] #
Returns a context manager that yields a local path and, on exit, uploads files written under it to the default storage path.
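For example, inference outputs can be persisted like this (a sketch; the file name and the outputs buffer are assumptions):

```python
with self.context.upload_path() as path:
    # Everything written under `path` is uploaded to storage on exit.
    torch.save(self.outputs, path / "predictions.pt")
```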
determined.pytorch.TorchBatchProcessor#
- class determined.pytorch.experimental.TorchBatchProcessor(context: determined.pytorch.experimental._torch_batch_process.TorchBatchProcessorContext)#
- abstract process_batch(batch: Any, batch_idx: int) → None #
This function will be called with every batch of data in the dataset.
- Parameters
batch – a batch of data from the dataset passed into torch_batch_process
batch_idx – index of the batch. Note that the index is per worker. For example, if there are 8 batches of data to process and 4 workers, each worker would get two batches of data (batch_idx = 0 and batch_idx = 1)
- on_checkpoint_start() → None #
This function will be called right before each checkpoint.
- on_finish() → None #
This function will be called right before exiting, after iteration over the dataset has completed.
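Putting the hooks together, one plausible pattern (a sketch, not the only design) buffers results in memory, flushes them at each checkpoint so completed work survives a restart, and writes whatever remains on finish. run_inference and the file names are assumptions:

```python
import torch
from determined.pytorch import experimental


class BufferedProcessor(experimental.TorchBatchProcessor):
    def __init__(self, context):
        self.context = context
        self.outputs = []
        self.flush_count = 0

    def process_batch(self, batch, batch_idx):
        batch = self.context.to_device(batch)
        self.outputs.append(run_inference(batch))  # run_inference is hypothetical

    def _flush(self, filename):
        # Files written under `path` are uploaded to storage on exit.
        with self.context.upload_path() as path:
            torch.save(self.outputs, path / filename)
        self.outputs = []

    def on_checkpoint_start(self):
        # Called right before each checkpoint; flushing here means work done
        # so far is not redone if the job is interrupted and resumed.
        self._flush(f"outputs_{self.flush_count}.pt")
        self.flush_count += 1

    def on_finish(self):
        # Called once, right before exit, after the dataset is exhausted.
        self._flush(f"outputs_{self.flush_count}.pt")
```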