pytorch.experimental.torch_batch_process API Reference
Caution
This is an experimental API and may change at any time.
The main arguments to `torch_batch_process` are `batch_processor_cls`, a subclass of `TorchBatchProcessor`, and `dataset`.
torch_batch_process(
    batch_processor_cls=MyProcessor,
    dataset=dataset,
)
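For instance, a minimal processor might run inference on each batch. The following is a sketch, with a stand-in linear model in place of a real one:

import torch
from determined.pytorch import experimental


class MyProcessor(experimental.TorchBatchProcessor):
    def __init__(self, context):
        self.context = context
        # Stand-in model; substitute your own nn.Module.
        self.model = context.prepare_model_for_inference(torch.nn.Linear(10, 2))

    def process_batch(self, batch, batch_idx):
        # Move the batch to this worker's device before running inference.
        batch = self.context.to_device(batch)
        with torch.no_grad():
            predictions = self.model(batch)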
determined.pytorch.torch_batch_process
- determined.pytorch.experimental.torch_batch_process(batch_processor_cls: Type[determined.pytorch.experimental._torch_batch_process.TorchBatchProcessor], dataset: torch.utils.data.dataset.Dataset, batch_size: Optional[int] = None, max_batches: Optional[int] = None, checkpoint_interval: int = 5, dataloader_kwargs: Optional[Dict[str, Any]] = None, distributed_context: Optional[determined.core._distributed.DistributedContext] = None) -> None
`torch_batch_process` shards the provided dataset across workers, iterates through it, and processes each batch with the user-defined logic in `batch_processor_cls`.
- Parameters
batch_processor_cls – A user-defined class extending `TorchBatchProcessor`
dataset – A torch dataset implementing __len__() and __getitem__()
batch_size – The number of items in each batch
max_batches – The maximum number of batches to iterate over per worker
checkpoint_interval – Interval at which to checkpoint progress (i.e. record the number of batches processed)
dataloader_kwargs – Kwargs to pass to the PyTorch dataloader
distributed_context – Distributed context used to initialize the core context
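For example, the optional arguments might be set as follows (the values and dataloader kwargs here are illustrative):

torch_batch_process(
    batch_processor_cls=MyProcessor,
    dataset=dataset,
    batch_size=64,
    checkpoint_interval=10,
    dataloader_kwargs={"num_workers": 2, "pin_memory": True},
)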
determined.pytorch.TorchBatchProcessorContext
- class determined.pytorch.experimental.TorchBatchProcessorContext(core_context: determined.core._context.Context, storage_path: str)
- to_device(data: Union[Dict[str, Union[numpy.ndarray, torch.Tensor]], Sequence[Union[numpy.ndarray, torch.Tensor]], numpy.ndarray, torch.Tensor], warned_types: Optional[Set[Type]] = None) -> Union[Dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor]
Accept np.ndarray, torch.Tensor, list, or dictionary. Recursively convert any ndarrays to tensors and call .to() on any tensors or data types that have custom serialization logic defined via a callable to() attribute.
If the data cannot be moved to the device, log a warning (only once per type) and return the original data.
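A typical call site is the top of process_batch, as in this sketch:

def process_batch(self, batch, batch_idx):
    # `batch` may be a dict or sequence of ndarrays/tensors; to_device
    # converts and moves it recursively.
    batch = self.context.to_device(batch)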
- get_tensorboard_path() -> pathlib.Path
TensorBoard files should be written to the returned path in order to be shown properly in the UI.
For example, the path should be passed to the PyTorch profiler as shown below:
torch.profiler.profile(
    activities=...,
    schedule=...,
    on_trace_ready=torch.profiler.tensorboard_trace_handler(<tensorboard_path>),
)
- prepare_model_for_inference(model: torch.nn.modules.module.Module) -> torch.nn.modules.module.Module
Set the model to eval mode and send it to the device.
- Parameters
model – a nn.Module
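For instance, assuming a torchvision model (the choice of model is arbitrary):

import torchvision

# eval() is applied and the model is moved to this worker's device.
model = context.prepare_model_for_inference(torchvision.models.resnet18())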
- upload_path() -> AbstractContextManager[pathlib.Path]
Return a context manager that, on exit, uploads files written under the yielded path to the default storage path.
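For example, to persist per-batch predictions from inside process_batch (the file name is illustrative):

with self.context.upload_path() as path:
    # Anything written under `path` is uploaded when the block exits.
    torch.save(predictions, path / f"predictions_{batch_idx}.pt")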
- report_metrics(group: str, steps_completed: int, metrics: Dict[str, Any]) -> None
Report metrics data to the master.
- Parameters
group (string) – metrics group name. Can be used to partition metrics into different logical groups or time series. “training” and “validation” group names map to built-in training and validation time series. Note: group cannot contain the `.` character.
steps_completed (int) – global step number, e.g. the number of batches processed.
metrics (Dict[str, Any]) – metrics data dictionary. Must be JSON-serializable. When reporting metrics with the same `group` and `steps_completed` values, the dictionary keys must not overlap.
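For example (the group and metric names here are arbitrary):

self.context.report_metrics(
    group="inference",
    steps_completed=batch_idx,
    metrics={"batches_processed": batch_idx + 1},
)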
- report_task_using_model_version(model_version: model.ModelVersion) -> None
Associate `model_version` with the current task. This links together the metrics reporting so that any metrics reported to the current task will be visible when querying for metrics associated with this model version.
- Parameters
model_version (model.ModelVersion) – The model version to associate with this task
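One way to obtain a ModelVersion is through the Determined Python SDK, as in this sketch (the model name is a placeholder):

from determined.experimental import client

model = client.get_model("my-model")
model_version = model.get_version()  # latest version when no number is given
self.context.report_task_using_model_version(model_version)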
- report_task_using_checkpoint(checkpoint: checkpoint.Checkpoint) -> None
Associate `checkpoint` with the current task. This links together the metrics reporting so that any metrics reported to the current task will be visible when querying for metrics associated with this checkpoint.
- Parameters
checkpoint (checkpoint.Checkpoint) – The checkpoint to associate with this task
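Similarly, a Checkpoint can be fetched by UUID via the SDK (the UUID is a placeholder):

from determined.experimental import client

checkpoint = client.get_checkpoint("<checkpoint-uuid>")
self.context.report_task_using_checkpoint(checkpoint)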
- get_distributed_rank() -> int
Return the rank of the current process in the trial.
- get_distributed_size() -> int
Return the number of slots the trial is running on.
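These are handy for per-worker bookkeeping, for example naming output files by rank to avoid collisions (a sketch):

rank = self.context.get_distributed_rank()
size = self.context.get_distributed_size()
output_name = f"predictions_rank_{rank}_of_{size}.pt"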
determined.pytorch.TorchBatchProcessor
- class determined.pytorch.experimental.TorchBatchProcessor(context: determined.pytorch.experimental._torch_batch_process.TorchBatchProcessorContext)
- abstract process_batch(batch: Any, batch_idx: int) -> None
This function will be called with every batch of data in the dataset.
- Parameters
batch – a batch of data from the dataset passed into torch_batch_process
batch_idx – index of the batch. Note that the index is per worker. For example, if there are 8 batches of data to process and 4 workers, each worker would get two batches (batch_idx = 0 and batch_idx = 1).
- on_checkpoint_start() -> None
This function will be called right before each checkpoint.
- on_finish() -> None
This function will be called right before exiting, after iteration over the dataset is complete.
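Putting the hooks together, a processor might buffer results in memory and flush them at each checkpoint and at the end of iteration. This is a sketch; _flush is a hypothetical helper:

import torch
from determined.pytorch import experimental


class BufferingProcessor(experimental.TorchBatchProcessor):
    def __init__(self, context):
        self.context = context
        self.buffer = []
        self.flush_count = 0

    def process_batch(self, batch, batch_idx):
        batch = self.context.to_device(batch)
        # Accumulate per-batch results in memory.
        self.buffer.append(batch)

    def _flush(self):
        # Hypothetical helper: upload buffered results, then clear the buffer.
        if self.buffer:
            with self.context.upload_path() as path:
                torch.save(self.buffer, path / f"results_{self.flush_count}.pt")
            self.buffer = []
            self.flush_count += 1

    def on_checkpoint_start(self):
        # Persist buffered results before progress is recorded.
        self._flush()

    def on_finish(self):
        # Persist anything left over after the final batch.
        self._flush()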