Torch Batch Processing API#

In this guide, you’ll learn about the Torch Batch Processing API and how to perform batch inference (also known as offline inference).

Visit the API reference: pytorch.experimental.torch_batch_process


This is an experimental API and may change at any time.


The Torch Batch Processing API takes in (1) a dataset and (2) a user-defined processor class and runs distributed data processing.

This API automatically handles the following for you:

  • shards the dataset across the available workers

  • applies user-defined logic to each batch of data

  • handles synchronization between workers

  • tracks job progress to enable preemption and resumption of the trial
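The sharding step above can be illustrated with a small, self-contained sketch. This is illustrative plain Python, not the API's actual implementation; the `shard` function is a hypothetical name:

```python
def shard(indices, num_workers, rank):
    # Round-robin sharding: each worker takes every num_workers-th item,
    # starting at its own rank.
    return indices[rank::num_workers]

dataset = list(range(10))
shards = [shard(dataset, 4, rank) for rank in range(4)]

# Every item is assigned to exactly one worker.
assert sorted(i for s in shards for i in s) == dataset
```

Each worker then applies the user-defined per-batch logic only to its own shard, which is what makes the processing embarrassingly parallel.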

This is a flexible API that can be used for many different tasks, including batch (offline) inference.

If you have trained models in a :class:`~determined.experimental.checkpoint.Checkpoint`, or a :class:`~determined.experimental.model.Model` with more than one :class:`~determined.experimental.model.ModelVersion` inside, you can associate the trial with the :class:`~determined.experimental.checkpoint.Checkpoint` or :class:`~determined.experimental.model.ModelVersion` used in a given inference run to aggregate custom inference metrics.

You can then query those :class:`~determined.experimental.checkpoint.Checkpoint` or :class:`~determined.experimental.model.ModelVersion` objects using the :ref:`Python SDK <python-sdk>` to see all metrics associated with them.


The main arguments to torch_batch_process() are the processor class and the dataset.


In the experiment configuration file, use a distributed launcher, since the API requires information (such as each worker's rank) that is set by the launcher. Below is an example.

entrypoint: >-
    python3 -m determined.launch.torch_distributed
slots_per_trial: 4


During __init__() of TorchBatchProcessor, a TorchBatchProcessorContext object is passed in; it contains useful methods that can be used within the TorchBatchProcessor class.

Your processor class should be a subclass of TorchBatchProcessor. The two functions you must implement are __init__() and process_batch(); the other lifecycle functions are optional.

TorchBatchProcessor is also compatible with Determined’s MetricReducer. You can pass a MetricReducer to TorchBatchProcessor as follows:

class MyProcessor(TorchBatchProcessor):
    def __init__(self, context):
        self.reducer = context.wrap_reducer(reducer=AccuracyMetricReducer(), name="accuracy")

How To Perform Batch (Offline) Inference#

In this section, we’ll learn how to perform batch inference using the Torch Batch Processing API.

Step 1: Define an InferenceProcessor#

The first step is to define an InferenceProcessor. Initialize your model in the __init__() function of the InferenceProcessor, and implement the process_batch() function with your inference logic.

You can optionally implement on_checkpoint_start() and on_finish() to be run before every checkpoint and after all the data has been processed, respectively.

Define custom processor class
import pathlib

import torch

class InferenceProcessor(TorchBatchProcessor):
    def __init__(self, context):
        self.context = context
        self.model = context.prepare_model_for_inference(get_model())
        self.output = []
        self.last_index = 0

    def process_batch(self, batch, batch_idx) -> None:
        model_input = batch[0]
        model_input = self.context.to_device(model_input)

        with torch.no_grad():
            pred = self.model(model_input)
            output = {"predictions": pred, "input": batch}
            self.output.append(output)

        self.last_index = batch_idx

    def on_checkpoint_start(self):
        # During checkpoint, persist the accumulated prediction results.
        if len(self.output) == 0:
            return
        file_name = f"prediction_output_{self.last_index}"
        with self.context.upload_path() as path:
            file_path = pathlib.Path(path, file_name)
            torch.save(self.output, file_path)

        self.output = []
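The accumulate-and-flush pattern behind process_batch() and on_checkpoint_start() can be reduced to a small, self-contained sketch. This is plain Python with illustrative names, not the real TorchBatchProcessor API; the `saved` list stands in for files written via upload_path():

```python
class SketchProcessor:
    def __init__(self):
        self.output = []  # results accumulated since the last checkpoint
        self.saved = []   # stand-in for files persisted at each checkpoint

    def process_batch(self, batch):
        # Accumulate per-batch results in memory.
        self.output.append([x * 2 for x in batch])

    def on_checkpoint_start(self):
        # Persist the accumulated results, then reset the buffer so the
        # next checkpoint only flushes new work.
        if not self.output:
            return
        self.saved.append(list(self.output))  # stand-in for torch.save(...)
        self.output = []

p = SketchProcessor()
p.process_batch([1, 2])
p.process_batch([3])
p.on_checkpoint_start()
assert p.output == []
assert p.saved == [[[2, 4], [6]]]
```

Resetting the buffer after each flush is what makes preemption and resumption safe: work persisted at a checkpoint is never written twice.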

Step 2: Initialize the Dataset#

Initialize the dataset you want to process.

Initialize dataset
import os

import filelock
import torchvision as tv
from torchvision import transforms

transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

with filelock.FileLock(os.path.join("/tmp", "inference.lock")):
    inference_data = tv.datasets.CIFAR10(
        root="/data", train=False, download=True, transform=transform
    )

Step 3: Pass the InferenceProcessor Class and Dataset#

Pass the InferenceProcessor class and the dataset to torch_batch_process. The batch_size and checkpoint_interval values below are illustrative.

Pass processor class and dataset to torch_batch_process
torch_batch_process(
    InferenceProcessor,
    inference_data,
    batch_size=64,
    checkpoint_interval=10,
)

Step 4: Send and Query Custom Inference Metrics (Optional)#

Report metrics anywhere in the trial to have them aggregated for the Checkpoint or ModelVersion in question.

For example, you could send metrics in on_finish().

def on_finish(self):
    self.context.report_metrics(
        group="inference",
        steps_completed=self.last_index,
        metrics={
            "my_metric": 1.0,
        },
    )

And check the metric afterwards from the SDK:

from determined.experimental import client

# Checkpoint
ckpt = client.get_checkpoint("<CHECKPOINT_UUID>")
metrics = ckpt.get_metrics("inference")

# Or Model Version
model = client.get_model("<MODEL_NAME>")
model_version = model.get_version(MODEL_VERSION_NUM)
metrics = model_version.get_metrics("inference")