class AdaptiveDataParallel(DistributedDataParallel):
    """
    This class extends PyTorch DistributedDataParallel with support for
    adaptive batch sizes and checkpoint-restart elasticity. It automatically
    saves the given model, optimizer, and (optionally) LR scheduler whenever a
    checkpoint is triggered, and restores their states after restart. The
    optimizer is automatically patched with AdaScale.

    Arguments:
        model (torch.nn.Module): Model to be distributed.
        optimizer (torch.optim.Optimizer): Optimizer used to update the given
            model's parameters, will be patched using
            :class:`adaptdl.torch.adascale.AdaScale`.
        lr_scheduler (torch.optim.lr_scheduler._LRScheduler): LR scheduler
            used to anneal the learning rate for the given optimizer.
        mp_scaler (torch.cuda.amp.GradScaler): Gradient scaler used if
            mixed-precision training is enabled.
        name (string): Unique name for each instance of this class, needed
            only if multiple instances exist.
    """
    def __init__(self, model, optimizer, lr_scheduler=None, mp_scaler=None,
                 name="adaptdl-dataparallel", **kwargs):
        super().__init__(model, **kwargs)
        self._key = id(self)
        # Register backward hooks on model parameters. Depends on these hooks
        # being invoked before gradients are averaged. This is technically an
        # internal behavior of DistributedDataParallel, but seems to be abused
        # pretty widely so there should be little chance of it changing.
        # https://discuss.pytorch.org/t/59291
        for param in model.parameters():
            param.register_hook(functools.partial(self._backward_hook, param))
        # Setup for AdaScale, must be after registering backward hooks!
        self.adascale = AdaScale(self, optimizer, mp_scaler=mp_scaler,
                                 patch_optimizer=True)
        self._state = _AdaptiveDataParallelState(model, optimizer,
                                                 lr_scheduler, mp_scaler, name)
        adaptdl.checkpoint.load_state(self._state)
        self._sync_start = None

    def _backward_hook(self, param, grad):
        # This method should be invoked once for each parameter during the
        # backward pass, before gradients are synchronized between replicas.
        if grad.device.type.startswith("cuda"):
            self._sync_start = torch.cuda.Event(enable_timing=True)
            self._sync_start.record()
        else:
            self._sync_start = time.time()
        self._final_callback_queued = False
        Variable._execution_engine.queue_callback(self._queue_callback)

    def _queue_callback(self):
        # This method should be invoked after the entire backward pass. We
        # want to make sure self._final_callback is invoked once, only after
        # all gradients have been synchronized between each replica. However,
        # the synchronization code in DistributedDataParallel is also done in
        # a callback, which might not yet be executed. Therefore, we enqueue
        # self._final_callback from this method, which should ensure it is
        # invoked after the gradient synchronization callback.
        if self._final_callback_queued:
            return
        self._final_callback_queued = True
        Variable._execution_engine.queue_callback(self._final_callback)

    def _final_callback(self):
        # This method should be invoked once for each backward pass, after
        # gradients have been synchronized between each replica.
        self._final_callback_queued = False
        # self._sync_start should mark the last time any local gradient
        # from this module was produced. We assume the duration until now was
        # primarily spent waiting for gradient synchronization.
        # TODO: Depends on the internal behavior of DistributedDataParallel,
        # which might break with future versions of PyTorch. Any better
        # and well-supported way to measure the synchronization time?
        if isinstance(self._sync_start, torch.cuda.Event):
            sync_end = torch.cuda.Event(enable_timing=True)
            sync_end.record()
            sync_end.synchronize()
            profile_sync_time(self._sync_start.elapsed_time(sync_end) / 1e3)
        else:
            profile_sync_time(time.time() - self._sync_start)
        dataloader = current_dataloader()
        if dataloader is None:
            # Don't allow backpropagation outside of a dataloader loop,
            # because the batch size would be unknown.
            raise RuntimeError("backpropagation outside AdaptiveDataLoader")
        dataloader.train()
        scale = dataloader.current_batch_size / dataloader.batch_size
        self.adascale.set_scale(scale)
        self._state.gain = self.adascale.gain()
        adaptdl.torch._metrics.update_progress(self._state.gain)
        if dataloader.max_batch_size and \
                dataloader.max_batch_size > dataloader.batch_size:
            adaptdl.torch._metrics.update_grad_params(self._key,
                                                      self.adascale.norm_avg(),
                                                      self.adascale.var_avg())
        self._sync_start = None

    @property
    def gain(self):
        # TODO: should be tracked in the metrics module instead.
        """
        Current estimate of the AdaScale gain (r_t) value.
        """
        return self._state.gain
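

# Usage sketch (illustrative only, not part of this module): how a training
# script might wrap its model with AdaptiveDataParallel and run
# backpropagation inside an AdaptiveDataLoader loop, which _final_callback
# above requires so that the current batch size is known. Names such as
# MyModel, dataset, and criterion are placeholders, and helpers like
# adaptdl.torch.init_process_group, AdaptiveDataLoader, and
# remaining_epochs_until are assumed to be provided by the surrounding
# adaptdl.torch package.
#
#     import torch
#     import adaptdl.torch
#
#     adaptdl.torch.init_process_group("nccl")
#     model = MyModel().cuda()
#     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#     model = AdaptiveDataParallel(model, optimizer)
#     dataloader = adaptdl.torch.AdaptiveDataLoader(dataset, batch_size=64)
#     for epoch in adaptdl.torch.remaining_epochs_until(30):
#         for inputs, targets in dataloader:
#             optimizer.zero_grad()
#             loss = criterion(model(inputs), targets)
#             loss.backward()  # triggers the hooks and callbacks defined above
#             optimizer.step()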