Example #1
class AdaptiveDataParallel(DistributedDataParallel):
    def __init__(self,
                 model,
                 optimizer,
                 lr_scheduler=None,
                 mp_scaler=None,
                 name="adaptdl-dataparallel",
                 **kwargs):
        super().__init__(model, **kwargs)
        self._key = id(self)
        # Register backward hooks on model parameters. Depends on these hooks
        # being invoked before gradients are averaged. This is technically an
        # internal behavior of DistributedDataParallel, but seems to be abused
        # pretty widely so there should be little chance of it changing.
        # https://discuss.pytorch.org/t/59291
        for param in model.parameters():
            param.register_hook(functools.partial(self._backward_hook, param))

        # Setup for AdaScale, must be after registering backward hooks!
        self.adascale = AdaScale(self,
                                 optimizer,
                                 mp_scaler=mp_scaler,
                                 patch_optimizer=True)

        self._state = _AdaptiveDataParallelState(model, optimizer,
                                                 lr_scheduler, mp_scaler, name)
        adaptdl.checkpoint.load_state(self._state)

        self._sync_start = None
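For context, here is a minimal construction sketch for the variant shown in Example #1, which additionally accepts an mp_scaler for mixed-precision training. It is a sketch under assumptions: mp_scaler is taken to be a torch.cuda.amp.GradScaler, a CUDA device is assumed to be available, the class is assumed to be exported as adaptdl.torch.AdaptiveDataParallel, and the model, optimizer, and scheduler are placeholders; whether an installed adaptdl version accepts the mp_scaler keyword depends on it matching the variant above.

import torch
import adaptdl.torch

# Initialize the process group before constructing any AdaptDL modules.
adaptdl.torch.init_process_group("nccl")

model = torch.nn.Linear(10, 2).cuda()                     # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)   # placeholder optimizer
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30)
mp_scaler = torch.cuda.amp.GradScaler()                   # assumed type of mp_scaler

model = adaptdl.torch.AdaptiveDataParallel(model, optimizer,
                                           lr_scheduler=lr_scheduler,
                                           mp_scaler=mp_scaler)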
Example #2
class AdaptiveDataParallel(DistributedDataParallel):
    """
    This class extends PyTorch DistributedDataParallel with support for
    adaptive batch sizes and checkpoint-restart elasticity. It automatically
    saves the given model, optimizer, and (optionally) LR scheduler whenever a
    checkpoint is triggered, and restores their states after restart. The
    optimizer is automatically patched with AdaScale.

    Arguments:
        model (torch.nn.Module): Model to be distributed.
        optimizer (torch.optim.Optimizer): Optimizer used to update the given
            model's parameters; it will be patched using
            :class:`adaptdl.torch.adascale.AdaScale`.
        lr_scheduler (torch.optim.lr_scheduler._LRScheduler): LR scheduler used
            to anneal the learning rate for the given optimizer.
        name (string): Unique name for each instance of this class, needed only
            if multiple instances exist.
    """
    def __init__(self,
                 model,
                 optimizer,
                 lr_scheduler=None,
                 name="adaptdl-dataparallel",
                 **kwargs):
        super().__init__(model, **kwargs)
        self._key = id(self)
        # Register backward hooks on model parameters. Depends on these hooks
        # being invoked before gradients are averaged. This is technically an
        # internal behavior of DistributedDataParallel, but seems to be abused
        # pretty widely so there should be little chance of it changing.
        # https://discuss.pytorch.org/t/59291
        for param in model.parameters():
            param.register_hook(functools.partial(self._backward_hook, param))

        # Setup for AdaScale, must be after registering backward hooks!
        self.adascale = AdaScale(optimizer, patch_optimizer=True)

        self._state = _AdaptiveDataParallelState(model, optimizer,
                                                 lr_scheduler, name)
        adaptdl.checkpoint.load_state(self._state)

        self._sync_start = None

    def _backward_hook(self, param, grad):
        # This method should be invoked once for each parameter during the
        # backward pass, before gradients are synchronized between replicas.
        if grad.device.type.startswith("cuda"):
            self._sync_start = torch.cuda.Event(enable_timing=True)
            self._sync_start.record()
        else:
            self._sync_start = time.time()
        self._final_callback_queued = False
        Variable._execution_engine.queue_callback(self._queue_callback)

    def _queue_callback(self):
        # This method should be invoked after the entire backward pass. We want
        # to make sure self._final_callback is invoked once, only after all
        # gradients have been synchronized across replicas. However, the
        # synchronization in DistributedDataParallel is also performed in a
        # callback, which might not have executed yet. Therefore, we enqueue
        # self._final_callback from this method, which should ensure it is
        # invoked after the gradient synchronization callback.
        if self._final_callback_queued:
            return
        self._final_callback_queued = True
        Variable._execution_engine.queue_callback(self._final_callback)

    def _final_callback(self):
        # This method should be invoked once for each backward pass, after
        # gradients have been synchronized across replicas.
        self._final_callback_queued = False
        # self._sync_start should mark the last time any local gradient
        # from this module was produced. We assume the duration until now was
        # primarily spent waiting for gradient synchronization.
        # TODO: Depends on the internal behavior of DistributedDataParallel,
        #       which might break with future versions of PyTorch. Any better
        #       and well-supported way to measure the synchronization time?
        if isinstance(self._sync_start, torch.cuda.Event):
            sync_end = torch.cuda.Event(enable_timing=True)
            sync_end.record()
            sync_end.synchronize()
            profile_sync_time(self._sync_start.elapsed_time(sync_end) / 1e3)
        else:
            profile_sync_time(time.time() - self._sync_start)

        dataloader = current_dataloader()
        if dataloader is None:
            # Don't allow backpropagation outside of a dataloader loop, because
            # the batch size would be unknown.
            raise RuntimeError("backpropagation outside AdaptiveDataLoader")
        dataloader.train()

        scale = dataloader.current_batch_size / dataloader.batch_size
        self.adascale.set_scale(scale)
        self._state.gain = self.adascale.gain()
        adaptdl.torch._metrics.update_progress(self._state.gain)
        if dataloader.max_batch_size and \
                dataloader.max_batch_size > dataloader.batch_size:
            adaptdl.torch._metrics.update_grad_params(self._key,
                                                      self.adascale.norm_avg(),
                                                      self.adascale.var_avg())
        self._sync_start = None

    @property
    def gain(self):  # TODO: should be tracked in the metrics module instead.
        """
        Current estimate of the AdaScale gain (r_t) value.
        """
        return self._state.gain
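To tie Example #2 together, here is a minimal end-to-end training sketch following AdaptDL's tutorial-style public API (adaptdl.torch.init_process_group, AdaptiveDataLoader, remaining_epochs_until). The synthetic dataset, model, and hyperparameters are placeholders, and exact signatures such as autoscale_batch_size may vary between adaptdl versions. Note that loss.backward() must run inside the AdaptiveDataLoader loop: as _final_callback above shows, backpropagation outside a dataloader loop raises a RuntimeError because the current batch size would be unknown.

import torch
import adaptdl.torch

adaptdl.torch.init_process_group("nccl" if torch.cuda.is_available() else "gloo")

# Placeholder synthetic dataset and a dataloader with adaptive batch sizing.
dataset = torch.utils.data.TensorDataset(torch.randn(1024, 10),
                                         torch.randint(0, 2, (1024,)))
dataloader = adaptdl.torch.AdaptiveDataLoader(dataset, batch_size=32, shuffle=True)
dataloader.autoscale_batch_size(1024, local_bsz_bounds=(16, 128))

model = torch.nn.Linear(10, 2)                            # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model = adaptdl.torch.AdaptiveDataParallel(model, optimizer)

criterion = torch.nn.CrossEntropyLoss()
for epoch in adaptdl.torch.remaining_epochs_until(10):    # resumes at the correct epoch after a restart
    for inputs, targets in dataloader:
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()                                    # must stay inside the dataloader loop
        optimizer.step()
    print("AdaScale gain estimate:", model.gain)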