Torch Batch Processing API#

In this guide, you’ll learn about the Torch Batch Processing API and how to perform batch inference (also known as offline inference).

Visit the API reference

pytorch.experimental.torch_batch_process API Reference

Caution

This is an experimental API and may change at any time.

Overview#

The Torch Batch Processing API takes in (1) a dataset and (2) a user-defined processor class and runs distributed data processing.

This API automatically handles the following for you:

  • shards the dataset across the available workers

  • applies user-defined logic to each batch of data

  • handles synchronization between workers

  • tracks job progress to enable preemption and resumption of the trial

This is a flexible API that can be used for many different tasks, including batch (offline) inference.

If you have trained models stored in a Checkpoint, or in a Model with one or more ModelVersions, you can associate the trial with the Checkpoint or ModelVersion used in a given inference run in order to aggregate custom inference metrics.

You can then query those Checkpoint or ModelVersion objects using the Python SDK to see all metrics associated with them.

Usage#

The main arguments to torch_batch_process() are a processor class and a dataset.

torch_batch_process(
    batch_processor_cls=MyProcessor,
    dataset=dataset,
)

In the experiment configuration file, use a distributed launcher, since the API requires information (such as the worker rank) that is set by the launcher. Below is an example:

entrypoint: >-
    python3 -m determined.launch.torch_distributed
    python3 batch_processing.py
resources:
  slots_per_trial: 4

TorchBatchProcessor#

Your processor class should be a subclass of TorchBatchProcessor. The two functions you must implement are __init__() and process_batch(); the other lifecycle functions are optional.

TorchBatchProcessor is compatible with Determined’s MetricReducer. You can pass a MetricReducer to TorchBatchProcessor as follows:

class MyProcessor(TorchBatchProcessor):
    def __init__(self, context):
        self.reducer = context.wrap_reducer(reducer=AccuracyMetricReducer(), name="accuracy")
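
The example above wraps an AccuracyMetricReducer, which is not defined in this guide. Below is a minimal sketch of what such a reducer might look like, assuming it subclasses determined.pytorch.MetricReducer and accumulates correct/total prediction counts; the class name and the update() signature are illustrative rather than part of the API.

from determined import pytorch


class AccuracyMetricReducer(pytorch.MetricReducer):
    """Illustrative reducer that averages accuracy across all workers."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Clear the accumulated state.
        self.correct = 0
        self.total = 0

    def update(self, predictions, labels):
        # User-defined accumulation method; call this from process_batch().
        self.correct += (predictions.argmax(dim=1) == labels).sum().item()
        self.total += labels.shape[0]

    def per_slot_reduce(self):
        # Return this worker's partial counts.
        return self.correct, self.total

    def cross_slot_reduce(self, per_slot_metrics):
        # Combine the (correct, total) pairs from all workers.
        correct, total = zip(*per_slot_metrics)
        return sum(correct) / max(sum(total), 1)

With a reducer like this, process_batch() would call self.reducer.update(predictions, labels) after running the model; the name passed to wrap_reducer() identifies the resulting metric.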

TorchBatchProcessorContext#

During __init__() of TorchBatchProcessor, we pass in a TorchBatchProcessorContext object, which contains useful methods that can be used within your TorchBatchProcessor subclass.

How To Perform Batch (Offline) Inference#

In this section, we’ll learn how to perform batch inference using the Torch Batch Processing API.

Step 1: Define an InferenceProcessor#

The first step is to define an InferenceProcessor. You should initialize your model in the __init__() function of the InferenceProcessor and implement the process_batch() function with your inference logic.

You can optionally implement on_checkpoint_start() and on_finish() to be run before every checkpoint and after all the data has been processed, respectively.

"""
Define custom processor class
"""
class InferenceProcessor(TorchBatchProcessor):
    def __init__(self, context):
        self.context = context
        self.model = context.prepare_model_for_inference(get_model())
        self.output = []
        self.last_index = 0

    def process_batch(self, batch, batch_idx) -> None:
        model_input = batch[0]
        model_input = self.context.to_device(model_input)

        with torch.no_grad():
            pred = self.model(model_input)
            output = {"predictions": pred, "input": batch}
            self.output.append(output)

        self.last_index = batch_idx

    def on_checkpoint_start(self):
        """
        During checkpoint, we persist prediction result
        """
        if len(self.output) == 0:
            return
        file_name = f"prediction_output_{self.last_index}"
        with self.context.upload_path() as path:
            file_path = pathlib.Path(path, file_name)
            torch.save(self.output, file_path)

        self.output = []
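
The example above calls a user-defined get_model() helper that is not shown in this guide. A minimal sketch, assuming for illustration that you simply load a pretrained torchvision model (any torch.nn.Module returned in eval mode would work):

import torchvision


def get_model():
    # Hypothetical helper: return any torch.nn.Module that is ready for
    # inference, e.g. a pretrained torchvision model or a model restored
    # from a Determined checkpoint.
    model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
    model.eval()
    return model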

Step 2: Initialize the Dataset#

Initialize the dataset you want to process.

"""
Initialize dataset
"""
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
with filelock.FileLock(os.path.join("/tmp", "inference.lock")):
    inference_data = tv.datasets.CIFAR10(
        root="/data", train=False, download=True, transform=transform
    )

Step 3: Pass the InferenceProcessor Class and Dataset#

Pass the InferenceProcessor class and the dataset from Step 2 to torch_batch_process().

"""
Pass processor class and dataset to torch_batch_process
"""
torch_batch_process(
    InferenceProcessor,
    inference_data,
    batch_size=64,
    checkpoint_interval=10,
)
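
Putting Steps 1 through 3 together, the batch_processing.py entrypoint referenced in the experiment configuration might be laid out roughly as follows. This is only a sketch: the determined import path is an assumption, so check the API reference for the exact module from which to import TorchBatchProcessor and torch_batch_process.

# batch_processing.py (sketch)
import filelock
import torchvision as tv
from torchvision import transforms

# Import path assumed; see the API reference for the exact location of
# TorchBatchProcessor and torch_batch_process.
from determined.pytorch.experimental import TorchBatchProcessor, torch_batch_process


class InferenceProcessor(TorchBatchProcessor):
    ...  # as defined in Step 1


if __name__ == "__main__":
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    with filelock.FileLock("/tmp/inference.lock"):
        inference_data = tv.datasets.CIFAR10(
            root="/data", train=False, download=True, transform=transform
        )

    torch_batch_process(
        InferenceProcessor,
        inference_data,
        batch_size=64,
        checkpoint_interval=10,
    )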

Step 4: Send and Query Custom Inference Metrics (Optional)#

Report metrics anywhere in the trial to have them aggregated for the Checkpoint or ModelVersion in question.

For example, you could send metrics in on_finish().

def on_finish(self):
    self.context.report_metrics(
        group="inference",
        # self.rank is assumed to have been set in __init__ to this
        # worker's distributed rank so that reports do not collide.
        steps_completed=self.rank,
        metrics={
            "my_metric": 1.0,
        },
    )

You can then query the metrics afterward using the Python SDK:

from determined.experimental import client

# Checkpoint
ckpt = client.get_checkpoint("<CHECKPOINT_UUID>")
metrics = ckpt.get_metrics("inference")

# Or Model Version
model = client.get_model("<MODEL_NAME>")
model_version = model.get_version(MODEL_VERSION_NUM)
metrics = model_version.get_metrics("inference")