Example no. 1
def test_dataloader_restarts():
    import adaptdl.checkpoint
    import adaptdl.collective
    from adaptdl.env import num_restarts, num_replicas
    adaptdl.collective.initialize("0.0.0.0")
    dataset_size = 100
    init_batch_size = 10
    dataset = TensorDataset(torch.rand(dataset_size))
    dataloader = AdaptiveDataLoader(dataset, batch_size=init_batch_size)
    # Load data samples in the following order:
    # 2 batches (20 samples) using 1 replica (local_bsz = 10, batch_size = 10)
    # 5 batches (60 samples) using 4 replicas (local_bsz = 3, batch_size = 12)
    # 2 batches (20 samples) using 2 replicas (local_bsz = 5, batch_size = 10)
    assert current_dataloader() is None
    for idx, batch in enumerate(dataloader):
        if num_restarts() == 0 and idx == 2:
            adaptdl.checkpoint.save_all_states()
            return 4  # Restart with 4 replicas.
        if num_restarts() == 1 and idx == 5:
            adaptdl.checkpoint.save_all_states()
            return 2  # Restart with 2 replicas.
        assert current_dataloader() is dataloader._elastic
        local_bsz = batch[0].size(0)
        assert dataloader.current_local_bsz == local_bsz
        assert local_bsz == math.ceil(init_batch_size / num_replicas())
        assert dataloader.current_batch_size == num_replicas() * local_bsz
    # Only 2 batches remain after the last restart, so idx ends at 1.
    assert idx == 1
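
The assertions above encode the per-replica batch size rule: each replica rounds the requested batch size up, so the effective global batch size can grow after a restart (10 becomes 12 on 4 replicas). Below is a minimal standalone sketch of that arithmetic; the split_batch helper is a hypothetical name, not part of AdaptDL.

import math

def split_batch(batch_size, num_replicas):
    # Each replica takes the ceiling of its share, so the effective
    # global batch size may exceed the requested one.
    local_bsz = math.ceil(batch_size / num_replicas)
    current_batch_size = num_replicas * local_bsz
    return local_bsz, current_batch_size

assert split_batch(10, 1) == (10, 10)  # 1 replica:  local_bsz 10, batch_size 10
assert split_batch(10, 4) == (3, 12)   # 4 replicas: local_bsz 3,  batch_size 12
assert split_batch(10, 2) == (5, 10)   # 2 replicas: local_bsz 5,  batch_size 10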
Example no. 2
    def _final_callback(self):
        # This method should be invoked once per backward pass, after
        # gradients have been synchronized across all replicas.
        self._final_callback_queued = False
        # self._sync_start should mark the last time any local gradient
        # from this module was produced. We assume the duration until now was
        # primarily spent waiting for gradient synchronization.
        # TODO: Depends on the internal behavior of DistributedDataParallel,
        #       which might break with future versions of PyTorch. Is there
        #       a better, well-supported way to measure synchronization time?
        if isinstance(self._sync_start, torch.cuda.Event):
            sync_end = torch.cuda.Event(enable_timing=True)
            sync_end.record()
            sync_end.synchronize()
            profile_sync_time(self._sync_start.elapsed_time(sync_end) / 1e3)
        else:
            profile_sync_time(time.time() - self._sync_start)

        dataloader = current_dataloader()
        if dataloader is None:
            # Don't allow backpropagation outside of a dataloader loop, because
            # the batch size would be unknown.
            raise RuntimeError("backpropagation outside AdaptiveDataLoader")
        dataloader.train()

        scale = dataloader.current_batch_size / dataloader.batch_size
        self._state.gain = self.gns.gain(scale)
        self._state.lr_factor = \
            np.average(self.scaling_rule.scale_lr(scale))
        update_progress(self.gns.get_progress())
        if dataloader.max_batch_size and \
                dataloader.max_batch_size > dataloader.batch_size:
            update_grad_params(self._key, self.gns.sqr_avg(),
                               self.gns.var_avg())
        self._sync_start = None
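
For reference, the timing branch above follows the standard CUDA event pattern: elapsed_time reports milliseconds, which is why the callback divides by 1e3 to get seconds. Here is a self-contained sketch of the same measurement; measure_sync_time is a hypothetical name, not part of AdaptDL.

import time
import torch

def measure_sync_time(sync_start):
    if isinstance(sync_start, torch.cuda.Event):
        sync_end = torch.cuda.Event(enable_timing=True)
        sync_end.record()       # enqueue the end marker on the current stream
        sync_end.synchronize()  # block until the GPU has reached the marker
        return sync_start.elapsed_time(sync_end) / 1e3  # milliseconds -> seconds
    # CPU fallback: sync_start was recorded with time.time().
    return time.time() - sync_start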
Example no. 3
    def forward(self, *args, **kwargs):
        # Skip gradient synchronization during gradient accumulation steps.
        dataloader = current_dataloader()
        if dataloader is not None and dataloader.training:
            self.require_backward_grad_sync = dataloader.is_optim_step()
            accum_scale = (dataloader.current_local_bsz *
                           adaptdl.env.num_replicas() / dataloader.batch_size)
            self.gns.set_accum_scale(accum_scale)
        return super().forward(*args, **kwargs)
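
The accum_scale computed above is simply the ratio of the aggregate batch processed per step to the requested batch size. A small numeric sketch of the same formula; the standalone function and the concrete values are illustrative only.

def accum_scale(current_local_bsz, num_replicas, batch_size):
    # Ratio of the total batch processed per step to the original batch size.
    return current_local_bsz * num_replicas / batch_size

assert accum_scale(16, 4, 32) == 2.0  # 4 replicas x 16 samples vs. a batch of 32
assert accum_scale(8, 4, 32) == 1.0   # exactly matches the original batch size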
Example no. 4
def test_dataloader_break():
    import adaptdl.checkpoint
    import adaptdl.collective
    from adaptdl.env import num_restarts
    if num_restarts() == 0:
        return 2
    adaptdl.collective.initialize("0.0.0.0")
    dataset = TensorDataset(torch.rand(100))
    dataloader = AdaptiveDataLoader(dataset, batch_size=10)
    # Break in the middle of the first for-loop, and start another for-loop.
    assert current_dataloader() is None
    for idx, batch in enumerate(dataloader):
        assert current_dataloader() is dataloader._elastic
        if idx == 5:
            break
    assert current_dataloader() is None
    for idx, batch in enumerate(dataloader):
        assert current_dataloader() is dataloader._elastic
    assert idx == 9  # The second loop runs all 10 batches from the start.
Example no. 5
    def __init__(self, *args, **kwargs):
        if current_dataloader() is not None:
            raise RuntimeError("accumulator may not be initialized during "
                               "dataloader iteration")
        epoch = current_epoch()
        count = _AccumulatorState.init_count[epoch]
        super().__init__("adaptdl-accumulator-epoch{}-{}".format(epoch, count))
        _AccumulatorState.init_count[epoch] += 1

        self.results_history = collections.defaultdict(list)
        self.results = dict(*args, **kwargs)
        self.updates = {}
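
The per-epoch init_count gives every accumulator a deterministic key, presumably so that the N-th accumulator created in a given epoch maps back to the same checkpointed state after a restart. A small standalone sketch of that counting scheme; next_key is a hypothetical helper, not part of AdaptDL.

import collections

init_count = collections.defaultdict(int)  # epoch -> accumulators created so far

def next_key(epoch):
    count = init_count[epoch]
    init_count[epoch] += 1
    return "adaptdl-accumulator-epoch{}-{}".format(epoch, count)

assert next_key(0) == "adaptdl-accumulator-epoch0-0"
assert next_key(0) == "adaptdl-accumulator-epoch0-1"
assert next_key(1) == "adaptdl-accumulator-epoch1-0"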
Example no. 6
    def scale_lr(self, scale):
        dataloader = current_dataloader()
        # Total number of training steps in the warmup period.
        total_steps = self._base_warmup_epochs * scale * \
            self._data_size / dataloader.batch_size
        max_lr_multiplier = math.sqrt(scale)
        # Effective training steps taken so far.
        progress = self.adp.gns.get_progress()
        if progress < total_steps:
            lr_factor = max_lr_multiplier * (progress / total_steps)
        else:
            lr_factor = max_lr_multiplier
        return lr_factor
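
A worked numeric example of the warmup schedule above, with all inputs passed explicitly so it runs standalone; the function name and the concrete numbers are illustrative only.

import math

def warmup_lr_factor(progress, scale, base_warmup_epochs, data_size, batch_size):
    # Same formula as scale_lr above, written without the surrounding class.
    total_steps = base_warmup_epochs * scale * data_size / batch_size
    max_lr_multiplier = math.sqrt(scale)
    if progress < total_steps:
        return max_lr_multiplier * (progress / total_steps)
    return max_lr_multiplier

# scale = 4 -> target multiplier sqrt(4) = 2.0, reached after
# 5 * 4 * 50000 / 100 = 10000 effective steps; the ramp is linear.
assert warmup_lr_factor(0, 4, 5, 50000, 100) == 0.0
assert warmup_lr_factor(5000, 4, 5, 50000, 100) == 1.0
assert warmup_lr_factor(20000, 4, 5, 50000, 100) == 2.0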
Example no. 7
    def forward(self, *args, **kwargs):
        # Skip gradient synchronization during gradient accumulation steps.
        # Otherwise identical to DistributedDataParallel's forward.
        dataloader = current_dataloader()
        accumulation_steps = dataloader.accumulation_steps
        # TODO: move this to the dataloader.__iter__
        self.adascale.set_accumulation_steps(accumulation_steps)
        if self.adascale.is_accumulation_step():
            with super().no_sync():
                dataloader.is_accumulation_step = True
                return super().forward(*args, **kwargs)
        else:
            dataloader.is_accumulation_step = False
            return super().forward(*args, **kwargs)
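
For context, the wrapper above builds on DistributedDataParallel's documented no_sync() idiom for gradient accumulation: suppress the gradient all-reduce on accumulation steps and let it run only on the step that ends with the optimizer update. A sketch of that standard pattern; all arguments (ddp_model, loss_fn, optimizer, loader, accumulation_steps) are placeholders supplied by the caller.

def train_with_accumulation(ddp_model, loss_fn, optimizer, loader,
                            accumulation_steps):
    for step, (inputs, targets) in enumerate(loader):
        if (step + 1) % accumulation_steps != 0:
            # Accumulation step: gradients stay local, no all-reduce.
            with ddp_model.no_sync():
                loss_fn(ddp_model(inputs), targets).backward()
        else:
            # Synchronization step: the all-reduce runs in this backward pass.
            loss_fn(ddp_model(inputs), targets).backward()
            optimizer.step()
            optimizer.zero_grad()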
Example no. 8
    @contextlib.contextmanager
    def synchronized(self):
        """
        A context manager which can be used to define the code to execute in
        *synchronized* mode. Within the context manager, any code can interact
        with this accumulator as if it were a regular Python ``dict``. The
        application must ensure that the operations performed within this
        context block are the same across all replicas.

        .. warning::
            Entering this context manager is a distributed synchronization
            point! Please ensure that all replicas enter this context manager
            at the same point in their code.
        """
        if self._synchronized is not None:
            # Already synchronized, don't need to do anything.
            yield self
            return
        epoch = current_epoch()
        # Remove saved results from all finished epochs. Since finished
        # epochs are never replayed, they should never be needed again.
        for key in list(self._state.results_history.keys()):
            if key is not None and key < epoch:
                self._state.results_history.pop(key)
        # Get the number of synchronizations so far in the current epoch.
        count = self._sync_count[epoch]
        self._sync_count[epoch] += 1
        results_list = self._state.results_history[epoch]
        assert count <= len(results_list)
        if count < len(results_list):
            # Results for this synchronization are saved in the history.
            self._synchronized = results_list[count]
            self._state.updates.clear()
        else:
            self._state.sync()  # Sync results and updates across replicas.
            if current_dataloader() is None:
                # Only save into the results history if outside of a dataloader
                # iteration, since code inside iterations is not replayed.
                results_list.append(copy.deepcopy(self._state.results))
            self._synchronized = self._state.results
        try:
            yield self
        finally:
            self._synchronized = None
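
The branch structure above amounts to a replay scheme: the N-th synchronization within an epoch either replays a result saved before a restart or performs a fresh sync and, when the surrounding code will be replayed, records it. A distilled standalone sketch of that logic; ReplayedSync, do_real_sync, and will_be_replayed are illustrative names, not AdaptDL's actual API.

import collections
import copy

class ReplayedSync:
    def __init__(self):
        self.history = collections.defaultdict(list)    # epoch -> saved results
        self.sync_count = collections.defaultdict(int)  # epoch -> syncs so far

    def synchronize(self, epoch, do_real_sync, will_be_replayed=True):
        count = self.sync_count[epoch]
        self.sync_count[epoch] += 1
        saved = self.history[epoch]
        if count < len(saved):
            # This synchronization already happened before a restart: replay it.
            return saved[count]
        results = do_real_sync()
        if will_be_replayed:
            # Save only results that replayed code will need to see again.
            saved.append(copy.deepcopy(results))
        return results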