Torch Batch Processing API

In this guide, you’ll learn about the Torch Batch Processing API and how to use it to perform batch inference (also known as offline inference).

This is an experimental API and may change at any time.


The Torch Batch Processing API takes in (1) a dataset and (2) a user-defined processor class and runs distributed data processing.

This API automatically handles the following for you:

  • shards the dataset across the available workers

  • applies user-defined logic to each batch of data

  • handles synchronization between workers

  • tracks job progress so that a trial can be preempted and later resumed
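To make the sharding and resumption bullets more concrete, here is a minimal pure-Python sketch of one way to split batches among workers and skip work that finished before a preemption. The strided assignment and the `batches_for_worker` helper are illustrative assumptions, not Determined's actual scheme.

```python
def batches_for_worker(num_batches, rank, world_size, start_batch=0):
    """Return the global batch indices that worker `rank` should process.

    Strided sharding (an assumption for illustration): worker `rank` takes
    every `world_size`-th batch. `start_batch` models resumption: batches
    before it are assumed to be finished already, for example recorded in
    a checkpoint before preemption.
    """
    if not 0 <= rank < world_size:
        raise ValueError("rank must satisfy 0 <= rank < world_size")
    return [i for i in range(start_batch, num_batches) if i % world_size == rank]


# Four workers splitting ten batches from scratch.
for r in range(4):
    print(r, batches_for_worker(10, r, 4))

# After a restart that resumes at global batch 6, worker 0 has one batch left.
print(batches_for_worker(10, 0, 4, start_batch=6))  # → [8]
```

Every batch is assigned to exactly one worker, and a resumed job never revisits batches that were already processed.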

This is a flexible API that can be used for many different tasks, including batch (offline) inference.


The main arguments to torch_batch_process() are the processor class and the dataset.


In the experiment configuration file, use a distributed launcher, because the API relies on information set by the launcher, such as each worker's rank. For example:

entrypoint: >-
    python3 -m determined.launch.torch_distributed

resources:
  slots_per_trial: 4
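Rank information typically reaches each worker through environment variables. Assuming the launcher follows the torchrun convention of exporting RANK, LOCAL_RANK, and WORLD_SIZE, a worker could read them as in this minimal sketch. WorkerInfo and worker_info_from_env are hypothetical names; the API reads this information itself, so the sketch only shows why a launcher is required.

```python
import os
from dataclasses import dataclass


@dataclass(frozen=True)
class WorkerInfo:
    rank: int        # global index of this worker across all slots
    local_rank: int  # index of this worker on its own machine
    world_size: int  # total number of workers in the trial


def worker_info_from_env(env=None):
    """Read torchrun-style rank variables, defaulting to a single worker.

    The defaults apply when the script runs without a launcher, which is
    exactly the situation in which sharding cannot work.
    """
    env = os.environ if env is None else env
    return WorkerInfo(
        rank=int(env.get("RANK", 0)),
        local_rank=int(env.get("LOCAL_RANK", 0)),
        world_size=int(env.get("WORLD_SIZE", 1)),
    )


# With slots_per_trial: 4 on one machine, the third worker might see:
info = worker_info_from_env({"RANK": "2", "LOCAL_RANK": "2", "WORLD_SIZE": "4"})
print(info)  # WorkerInfo(rank=2, local_rank=2, world_size=4)
```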


Your processor class should be a subclass of TorchBatchProcessor. The two functions you must implement are __init__() and process_batch(); the other lifecycle functions are optional.

When your processor is instantiated, a TorchBatchProcessorContext object is passed to its __init__(). This context provides helper methods that you can use within your TorchBatchProcessor subclass.

TorchBatchProcessor is compatible with Determined’s MetricReducer. Register a MetricReducer through the context's wrap_reducer() method, as follows:

from determined.pytorch.experimental import TorchBatchProcessor

class MyProcessor(TorchBatchProcessor):
    def __init__(self, context):
        # AccuracyMetricReducer is a user-defined subclass of determined.pytorch.MetricReducer
        self.reducer = context.wrap_reducer(reducer=AccuracyMetricReducer(), name="accuracy")
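The AccuracyMetricReducer used above is user-defined. Below is a minimal sketch of what it might look like, mirroring the reset(), per_slot_reduce(), and cross_slot_reduce() methods of Determined's MetricReducer interface. It is standalone (it does not subclass determined.pytorch.MetricReducer) so that it runs without Determined installed, and the update() method is a hypothetical helper you would call from process_batch().

```python
class AccuracyMetricReducer:
    """Standalone sketch of an accuracy reducer.

    In a real job this would subclass determined.pytorch.MetricReducer;
    here it is standalone so the example runs without Determined.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.correct = 0
        self.total = 0

    def update(self, predictions, labels):
        # Hypothetical helper, called from process_batch() with plain lists.
        self.correct += sum(p == y for p, y in zip(predictions, labels))
        self.total += len(labels)

    def per_slot_reduce(self):
        # Runs on each worker; return raw counts, not a ratio, so workers
        # that saw different amounts of data are weighted correctly.
        return self.correct, self.total

    def cross_slot_reduce(self, per_slot_metrics):
        # Combines the (correct, total) pairs gathered from every worker.
        correct = sum(c for c, _ in per_slot_metrics)
        total = sum(t for _, t in per_slot_metrics)
        return correct / total if total else 0.0


# Simulate two workers that saw different amounts of data.
worker_a, worker_b = AccuracyMetricReducer(), AccuracyMetricReducer()
worker_a.update([1, 0, 1, 1], [1, 0, 0, 1])  # 3 of 4 correct
worker_b.update([2, 2], [2, 1])              # 1 of 2 correct
print(worker_a.cross_slot_reduce([worker_a.per_slot_reduce(), worker_b.per_slot_reduce()]))
```

Returning counts from per_slot_reduce() is the key design choice: averaging the two per-worker accuracies (0.75 and 0.5) would give 0.625, while pooling the counts gives the correct 4/6.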

How To Perform Batch (Offline) Inference

In this section, we’ll learn how to perform batch inference using the Torch Batch Processing API.

Step 1: Define an InferenceProcessor

The first step is to define an InferenceProcessor. Initialize your model in the InferenceProcessor’s __init__() function, and implement your inference logic in process_batch().

You can optionally implement on_checkpoint_start() and on_finish(), which run before every checkpoint and after all the data has been processed, respectively. For a complete example, see our Torch Batch Process Embeddings example.

import pathlib

import torch
from determined.pytorch.experimental import TorchBatchProcessor


# Define custom processor class
class InferenceProcessor(TorchBatchProcessor):
    def __init__(self, context):
        self.context = context
        # get_model() is a user-defined helper that builds or loads your model
        self.model = context.prepare_model_for_inference(get_model())
        self.output = []     # predictions buffered since the last checkpoint
        self.last_index = 0  # index of the most recently processed batch

    def process_batch(self, batch, batch_idx) -> None:
        model_input = batch[0]
        model_input = self.context.to_device(model_input)

        with torch.no_grad():
            pred = self.model(model_input)
            output = {"predictions": pred, "input": batch}
            self.output.append(output)

        self.last_index = batch_idx

    def on_checkpoint_start(self):
        # During checkpointing, persist the buffered prediction results
        if len(self.output) == 0:
            return
        file_name = f"prediction_output_{self.last_index}"
        # Files written under upload_path() are uploaded to checkpoint storage
        with self.context.upload_path() as path:
            file_path = pathlib.Path(path, file_name)
            torch.save(self.output, file_path)

        self.output = []
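To make the call order of these hooks concrete, the following toy driver runs a stand-in processor that buffers outputs and flushes them the same way InferenceProcessor does. ToyProcessor, run_toy_job, the fixed every-N-batches checkpoint schedule, and the extra checkpoint before on_finish() are all assumptions for illustration; the real torch_batch_process() decides when checkpoints happen.

```python
class ToyProcessor:
    """Stand-in for a TorchBatchProcessor subclass; needs neither Determined nor torch."""

    def __init__(self, log):
        self.log = log
        self.output = []
        self.last_index = 0

    def process_batch(self, batch, batch_idx):
        self.output.append([x * 2 for x in batch])  # pretend "inference"
        self.last_index = batch_idx
        self.log.append(("process_batch", batch_idx))

    def on_checkpoint_start(self):
        # Mirrors InferenceProcessor: do nothing if the buffer is empty.
        if len(self.output) == 0:
            return
        self.log.append(("flush", self.last_index, len(self.output)))
        self.output = []

    def on_finish(self):
        self.log.append(("on_finish",))


def run_toy_job(batches, checkpoint_interval):
    """Drive the hooks on an assumed schedule: checkpoint every N batches."""
    log = []
    processor = ToyProcessor(log)
    for idx, batch in enumerate(batches):
        processor.process_batch(batch, idx)
        if (idx + 1) % checkpoint_interval == 0:
            processor.on_checkpoint_start()
    processor.on_checkpoint_start()  # assumed final checkpoint for leftovers
    processor.on_finish()
    return log


for entry in run_toy_job([[1, 2], [3], [4], [5, 6], [7]], checkpoint_interval=2):
    print(entry)
```

Because each output file is named after the last batch it contains, a flush after batch 3 and a later flush after batch 4 never overwrite each other, and the empty-buffer guard keeps a checkpoint that lands right after a flush from writing an empty file.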

Step 2: Initialize the Dataset

Initialize the dataset you want to process.

# Initialize dataset
import os
import filelock
import torchvision as tv
from torchvision import transforms
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
with filelock.FileLock(os.path.join("/tmp", "inference.lock")):  # serialize downloads among workers sharing /tmp
    inference_data = tv.datasets.CIFAR10(
        root="/data", train=False, download=True, transform=transform
    )
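The transform converts each image to a tensor with values in [0, 1] and then applies Normalize, which computes (x - mean) / std for each channel. With mean and std both set to 0.5, every channel ends up in [-1, 1]. This pure-Python sketch reproduces that arithmetic for a single 8-bit channel value; normalize_channel and pixel_to_model_input are illustrative names, not torchvision functions.

```python
def normalize_channel(value, mean=0.5, std=0.5):
    """Per-channel formula applied by Normalize: (value - mean) / std."""
    return (value - mean) / std


def pixel_to_model_input(pixel_8bit):
    """Mimic ToTensor() followed by Normalize() for one 8-bit channel value."""
    scaled = pixel_8bit / 255.0  # ToTensor maps 0..255 to 0.0..1.0
    return normalize_channel(scaled)


for p in (0, 255):
    print(p, pixel_to_model_input(p))  # 0 maps to -1.0, 255 maps to 1.0
```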

Step 3: Pass the InferenceProcessor Class and Dataset

Finally, pass the InferenceProcessor class and the dataset to torch_batch_process.

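# Pass processor class and dataset to torch_batch_process
from determined.pytorch.experimental import torch_batch_process

torch_batch_process(InferenceProcessor, inference_data)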