Python API determined.pytorch.samplers

Guidelines for Reproducible Datasets

Note

Normally, using det.pytorch.DataLoader is required, and it handles all of the details below without any special effort on your part (see Data Loading). When det.pytorch.DataLoader() is not suitable (especially in the case of IterableDatasets), you may disable this requirement by calling context.experimental.disable_dataset_reproducibility_checks() in your Trial’s __init__() method. You may then choose to follow the guidelines below for ensuring dataset reproducibility on your own.
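
If you do opt out, a minimal sketch of that call looks like the following; MyTrial is a hypothetical trial class, and the rest of the Trial (data loaders, train_batch, and so on) is omitted.

    from determined import pytorch

    class MyTrial(pytorch.PyTorchTrial):  # hypothetical trial; other required methods omitted
        def __init__(self, context: pytorch.PyTorchTrialContext) -> None:
            self.context = context
            # Opt out of the det.pytorch.DataLoader requirement; dataset
            # reproducibility is now your responsibility.
            self.context.experimental.disable_dataset_reproducibility_checks()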

Achieving a reproducible dataset that is able to pause and continue (sometimes called “incremental training”) is easy if you follow a few rules.

  • Even if you are going to ultimately return an IterableDataset, it is best to use PyTorch’s Sampler class as the basis for choosing the order of records. Operations on Samplers are quick and cheap, while operations on data afterwards are expensive. For more details, see the discussion of random vs sequential access here. If you don’t have a custom sampler, start with a simple one, such as a torch.utils.data.SequentialSampler (see the example code after these guidelines).

  • Shuffle first: Always use a reproducible shuffle when you shuffle. Determined provides two shuffling samplers for this purpose: the ReproducibleShuffleSampler for operating on records and the ReproducibleShuffleBatchSampler for operating on batches. You should prefer to shuffle on records (use the ReproducibleShuffleSampler) whenever possible, to achieve the highest-quality shuffle.

  • Repeat when training: In Determined, you always repeat your training dataset and you never repeat your validation dataset. Determined provides a RepeatSampler and a RepeatBatchSampler to wrap your sampler or batch_sampler. For your training dataset, make sure that you always repeat AFTER you shuffle; otherwise, your shuffle will hang.

  • Always shard, and not before a repeat: Use Determined’s DistributedSampler or DistributedBatchSampler to provide a unique shard of data to each worker based on your sampler or batch_sampler. It is best to always shard your data, even when you are not doing distributed training, because sharding is nearly zero-cost in non-distributed settings and it makes distributed training seamless if you ever want to use it in the future.

    It is generally important to shard after you repeat, unless you can guarantee that each shard of the dataset will have the same length. Otherwise, differences between the epoch boundaries for each worker can grow over time, especially on small datasets. If you shard after you repeat, you can change the number of workers arbitrarily without issue.

  • Skip when training, and always last: In Determined, training datasets should always be able to start from an arbitrary point in the dataset. This allows for advanced hyperparameter searches and responsive preemption for training on spot instances in the cloud. The easiest way to do this, which is also very efficient, is to apply a skip to the sampler.

    Determined provides a SkipBatchSampler that you can apply to your batch_sampler for this purpose. There is also a SkipSampler that you can apply to your sampler, but you should prefer to skip on batches unless you are confident that your dataset always yields identically sized batches, in which case the number of records to skip can be reliably calculated from the number of batches already trained.

    Always skip AFTER your repeat, so that the skip only happens once, and not on every epoch.

    Always skip AFTER your shuffle, to preserve the reproducibility of the shuffle.

Here is some example code that follows each of these rules; you can use it as a starting point if you find that the built-in context.DataLoader() does not support your use case.
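
The sketch below is one way to assemble such a chain, in the order the guidelines prescribe (start from a cheap sampler, shuffle, repeat, shard, then skip). The toy dataset, seed, worker count, rank, batch size, and batches-trained values are placeholders; in a real Trial they would typically come from your trial context (for example, the seed from context.get_trial_seed()) and from the training state you are resuming.

    import torch
    from determined.pytorch import samplers

    # Placeholders -- substitute values from your own Trial:
    my_dataset = list(range(1000))  # stand-in for your map-style dataset
    seed = 777                      # e.g. context.get_trial_seed()
    num_workers = 1                 # total number of distributed workers
    rank = 0                        # this worker's rank
    batch_size = 64
    batches_trained = 0             # batches already trained on before resuming

    # Rule: base everything on a cheap index-based Sampler.
    sampler = torch.utils.data.SequentialSampler(my_dataset)

    # Rule: shuffle first, reproducibly, and on records when possible.
    sampler = samplers.ReproducibleShuffleSampler(sampler, seed)

    # Rule: repeat when training (never for validation), only AFTER shuffling.
    sampler = samplers.RepeatSampler(sampler)

    # Rule: always shard, and not before the repeat.
    sampler = samplers.DistributedSampler(sampler, num_workers=num_workers, rank=rank)

    # Batch the indices, then skip last so the skip is applied only once.
    batch_sampler = torch.utils.data.BatchSampler(sampler, batch_size, drop_last=False)
    batch_sampler = samplers.SkipBatchSampler(batch_sampler, batches_trained)

    # batch_sampler can now drive your own DataLoader or IterableDataset.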

class determined.pytorch.samplers.DistributedBatchSampler(batch_sampler: torch.utils.data.sampler.BatchSampler, num_workers: int, rank: int)

DistributedBatchSampler will iterate through an underlying batch sampler and return batches which belong to this shard.

DistributedBatchSampler is different from the PyTorch built-in torch.utils.data.distributed.DistributedSampler, because that DistributedSampler expects to be called before the BatchSampler, and additionally that DistributedSampler is meant to be a stand-alone sampler.

DistributedBatchSampler has a potential gotcha: when wrapping a non-repeating BatchSampler, if the length of the BatchSampler is not divisible by the number of replicas, the length of the resulting DistributedBatchSampler will differ based on the rank. In that case, the divergent paths of multiple workers could cause problems during training. Because PyTorchTrial always uses RepeatBatchSampler during training, and because PyTorchTrial does not require that workers stay in step during validation, this gotcha is not a problem in Determined.
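
A minimal usage sketch (the toy range(10) data and the two-worker setup are placeholders): wrap an ordinary BatchSampler and keep only the batches that belong to this rank.

    import torch
    from determined.pytorch import samplers

    base = torch.utils.data.BatchSampler(
        torch.utils.data.SequentialSampler(range(10)), batch_size=2, drop_last=False
    )
    shard = samplers.DistributedBatchSampler(base, num_workers=2, rank=0)
    print(list(shard))  # only the batches belonging to rank 0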

class determined.pytorch.samplers.DistributedSampler(sampler: torch.utils.data.sampler.Sampler, num_workers: int, rank: int)

DistributedSampler will iterate through an underlying sampler and return samples which belong to this shard.

DistributedSampler is different from the PyTorch built-in torch.utils.data.DistributedSampler because theirs is meant to be a stand-alone sampler. Theirs does shuffling and assumes a constant-size dataset as input. Ours is meant to be used as a building block in a chain of samplers, so it accepts a sampler as input that may or may not be constant-size.
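
A minimal usage sketch (toy data, two hypothetical workers): wrap any record-level sampler and keep only the indices that belong to this rank.

    import torch
    from determined.pytorch import samplers

    base = torch.utils.data.SequentialSampler(range(10))
    shard = samplers.DistributedSampler(base, num_workers=2, rank=0)
    print(list(shard))  # only the record indices belonging to rank 0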

class determined.pytorch.samplers.RepeatBatchSampler(batch_sampler: torch.utils.data.sampler.BatchSampler)

RepeatBatchSampler yields an infinite sequence of batches of indices by repeatedly iterating through the batches of another BatchSampler. __len__ is just the length of the underlying BatchSampler.
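
A minimal usage sketch (toy data; itertools.islice is used only to take a finite slice of the infinite stream):

    import itertools

    import torch
    from determined.pytorch import samplers

    base = torch.utils.data.BatchSampler(
        torch.utils.data.SequentialSampler(range(6)), batch_size=2, drop_last=False
    )
    repeated = samplers.RepeatBatchSampler(base)
    print(len(repeated))                        # 3, the length of the underlying BatchSampler
    print(list(itertools.islice(repeated, 5)))  # the 3 batches, then the cycle begins again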

class determined.pytorch.samplers.RepeatSampler(sampler: torch.utils.data.sampler.Sampler)

RepeatSampler yields an infinite sequence of indices by repeatedly iterating through the indices of another Sampler. __len__ is just the length of the underlying Sampler.
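
A minimal usage sketch (toy data; itertools.islice takes a finite slice of the infinite stream):

    import itertools

    import torch
    from determined.pytorch import samplers

    base = torch.utils.data.SequentialSampler(range(3))
    repeated = samplers.RepeatSampler(base)
    print(list(itertools.islice(repeated, 7)))  # [0, 1, 2, 0, 1, 2, 0]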

class determined.pytorch.samplers.ReproducibleShuffleBatchSampler(batch_sampler: torch.utils.data.sampler.BatchSampler, seed: int)

ReproducibleShuffleBatchSampler will apply a deterministic shuffle, based on a seed, to the batches of the underlying BatchSampler.

Warning

Always shuffle before skipping and before repeating. Skip-before-shuffle would break the reproducibility of the shuffle, and repeat-before-shuffle would cause the shuffle to hang as it iterates through an infinite sampler.

Warning

Always prefer ReproducibleShuffleSampler over this class when possible. The reason is that shuffling at the record level results in a superior shuffle, where the contents of each batch are varied between epochs, rather than just the order of batches.
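
A minimal usage sketch (toy data and an arbitrary seed): the same seed always yields the same batch order.

    import torch
    from determined.pytorch import samplers

    base = torch.utils.data.BatchSampler(
        torch.utils.data.SequentialSampler(range(8)), batch_size=2, drop_last=False
    )
    shuffled = samplers.ReproducibleShuffleBatchSampler(base, seed=777)
    print(list(shuffled))  # a deterministic reordering of the 4 batches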

class determined.pytorch.samplers.ReproducibleShuffleSampler(sampler: torch.utils.data.sampler.Sampler, seed: int)

ReproducibleShuffleSampler will apply a deterministic shuffle, based on a seed, to the records of the underlying Sampler.

Warning

Always shuffle before skipping and before repeating. Skip-before-shuffle would break the reproducibility of the shuffle, and repeat-before-shuffle would cause the shuffle to hang as it iterates through an infinite sampler.
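
A minimal usage sketch (toy data and an arbitrary seed): the same seed always yields the same record order.

    import torch
    from determined.pytorch import samplers

    base = torch.utils.data.SequentialSampler(range(8))
    shuffled = samplers.ReproducibleShuffleSampler(base, seed=777)
    print(list(shuffled))  # a deterministic permutation of 0..7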

class determined.pytorch.samplers.SkipBatchSampler(batch_sampler: torch.utils.data.sampler.BatchSampler, skip: int)

SkipBatchSampler skips some batches from an underlying BatchSampler, and yields the rest.

Always skip AFTER you repeat when you are continuing training, or you will apply the skip on every epoch.

Because the SkipBatchSampler is only meant to be used on a training dataset (we never checkpoint during evaluation), and because the training dataset should always be repeated before applying the skip (so you only skip once rather than many times), the length reported is always the length of the underlying sampler, regardless of the size of the skip.
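
A minimal usage sketch (toy data; the skip of 2 stands in for the number of batches already trained): repeat first, then skip, so the skip is applied only once.

    import itertools

    import torch
    from determined.pytorch import samplers

    base = torch.utils.data.BatchSampler(
        torch.utils.data.SequentialSampler(range(8)), batch_size=2, drop_last=False
    )
    batch_sampler = samplers.SkipBatchSampler(samplers.RepeatBatchSampler(base), skip=2)
    print(list(itertools.islice(batch_sampler, 3)))  # resumes from the third batch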

class determined.pytorch.samplers.SkipSampler(sampler: torch.utils.data.sampler.Sampler, skip: int)

SkipSampler skips some records from an underlying Sampler, and yields the rest.

Always skip AFTER you repeat when you are continuing training, or you will apply the skip on every epoch.

Warning

When trying to achieve reproducibility after pausing and restarting, you should prefer the SkipBatchSampler over this SkipSampler, unless you are sure that your dataset will always yield identically sized batches. This is because Determined counts batches trained but does not count records trained; reproducibility when skipping records is only possible if the number of records to skip can be reliably calculated from the batch size and the number of batches already trained.

Because the SkipSampler is only meant to be used on a training dataset (we never checkpoint during evaluation), and because the training dataset should always be repeated before applying the skip (so you only skip once rather than many times), the length reported is always the length of the underlying sampler, regardless of the size of the skip.
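
A minimal usage sketch (toy data; the record count to skip is computed from a hypothetical batch size and batch count): repeat first, then skip, and only rely on this when every batch has the same size.

    import itertools

    import torch
    from determined.pytorch import samplers

    base = torch.utils.data.SequentialSampler(range(6))
    records_to_skip = 2 * 2  # hypothetical: batch_size=2, batches_trained=2
    sampler = samplers.SkipSampler(samplers.RepeatSampler(base), skip=records_to_skip)
    print(list(itertools.islice(sampler, 4)))  # [4, 5, 0, 1]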