# PyTorchTrial to DeepSpeedTrial
Adapting an existing PyTorchTrial to use DeepSpeed mirrors the process of adapting existing code to use DeepSpeed outside of Determined.
The first step is to switch to the DeepSpeed trial and context objects. Next, initialize the model engine and replace the existing context calls with their DeepSpeed equivalents. Finally, remember to update the experiment configuration so that it specifies an appropriate DeepSpeed configuration.
Reference conversion example:
+import deepspeed
-from determined import pytorch
+from determined.pytorch import deepspeed as det_ds
-class MyTrial(pytorch.PyTorchTrial):
+class MyTrial(det_ds.DeepSpeedTrial):
def __init__(self, context):
self.context = context
self.args = AttrDict(self.context.get_hparams())
net = ...
optimizer = ...
- self.model = self.context.wrap_model(net)
- self.optimizer = self.context.wrap_optimizer(optimizer)
+ model_engine, _, _, _ = deepspeed.initialize(
+ args=self.args,
+ model=net,
+ optimizer=optimizer,
+ ...
+ )
+ self.model = self.context.wrap_model_engine(model_engine)
def build_training_data_loader(self) -> Any:
trainset = ...
return DataLoader(
trainset,
- batch_size=self.context.get_per_slot_batch_size(),
+ batch_size=self.model.train_micro_batch_size_per_gpu(),
shuffle=True
)
def build_validation_data_loader(self) -> Any:
valset = ...
return DataLoader(
valset,
- batch_size=self.context.get_per_slot_batch_size(),
+ batch_size=self.model.train_micro_batch_size_per_gpu(),
shuffle=True
)
- def train_batch(self, batch, epoch_idx, batch_idx):
+ def train_batch(self, iter_dataloader, epoch_idx, batch_idx):
- inputs, targets = batch
+ inputs, targets = self.context.to_device(
+ next(iter_dataloader)
+ ) # Get a batch from the iterator
outputs = self.model(inputs)
loss = self.criterion(outputs, targets)
- self.context.backward(loss)
- self.context.step_optimizer(self.optimizer)
+ self.model.backward(loss)
+ self.model.step()
return {"loss": loss}
- def evaluate_batch(self, batch, batch_idx):
+ def evaluate_batch(self, iter_dataloader, batch_idx):
- inputs, targets = batch
+ inputs, targets = self.context.to_device(
+ next(iter_dataloader)
+ ) # Get a batch from the iterator
outputs = self.model(inputs)
metric = ...
return {"metric": metric}