Example #1
    def _final_callback(self):
        # This method should be invoked once for each backward pass, after
        # gradients have been synchronized across all replicas. (This excerpt
        # assumes module-level `import time`, `import torch`, and
        # `import numpy as np`.)
        self._final_callback_queued = False
        # self._sync_start should mark the last time any local gradient
        # from this module was produced. We assume the duration until now was
        # primarily spent waiting for gradient synchronization.
        # TODO: This depends on internal behavior of DistributedDataParallel
        #       and might break with future versions of PyTorch. Is there a
        #       better, well-supported way to measure the synchronization time?
        if isinstance(self._sync_start, torch.cuda.Event):
            sync_end = torch.cuda.Event(enable_timing=True)
            sync_end.record()
            sync_end.synchronize()
            # Event.elapsed_time returns milliseconds; convert to seconds.
            profile_sync_time(self._sync_start.elapsed_time(sync_end) / 1e3)
        else:
            profile_sync_time(time.time() - self._sync_start)

        dataloader = current_dataloader()
        if dataloader is None:
            # Don't allow backpropagation outside of a dataloader loop, because
            # the batch size would be unknown.
            raise RuntimeError("backpropagation outside AdaptiveDataLoader")
        dataloader.train()

        scale = dataloader.current_batch_size / dataloader.batch_size
        self._state.gain = self.gns.gain(scale)
        self._state.lr_factor = np.average(self.scaling_rule.scale_lr(scale))
        update_progress(self.gns.get_progress())
        if dataloader.max_batch_size and \
                dataloader.max_batch_size > dataloader.batch_size:
            update_grad_params(self._key, self.gns.sqr_avg(),
                               self.gns.var_avg())
        self._sync_start = None
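
The callback above does not run by itself; something must queue it once per
backward pass. Below is a minimal sketch of one plausible wiring, assuming
PyTorch's internal Variable._execution_engine.queue_callback mechanism (the
same unstable internals the TODO above alludes to). The _ProfilingMixin,
_register_hooks, and _on_grad names are illustrative, not AdaptDL's actual
API.

# Sketch only: class and method names here are illustrative assumptions,
# not AdaptDL's actual API. queue_callback is an internal, unstable PyTorch
# mechanism that runs a function after the current backward pass finishes.
import time
import torch
from torch.autograd import Variable

class _ProfilingMixin(torch.nn.Module):
    _final_callback_queued = False
    _sync_start = None

    def _register_hooks(self):
        # A gradient hook fires per parameter during backward; use it to
        # record when the last local gradient was produced and to queue
        # the final callback exactly once per backward pass.
        for param in self.parameters():
            param.register_hook(self._on_grad)

    def _on_grad(self, grad):
        if grad.is_cuda:
            self._sync_start = torch.cuda.Event(enable_timing=True)
            self._sync_start.record()
        else:
            self._sync_start = time.time()
        if not self._final_callback_queued:
            self._final_callback_queued = True
            Variable._execution_engine.queue_callback(self._final_callback)
        return grad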
Example #2
def test_profile(num_replicas):
    import adaptdl.checkpoint
    from adaptdl.env import num_restarts
    from adaptdl.torch._metrics import (
            profile_step_start, profile_sync_time,
            profile_step_commit, _metrics_state)
    if num_restarts() == 0:
        profile = _metrics_state().profile
        assert len(profile) == 0
        # Profile local_bsz=1 but don't commit.
        profile_step_start(1)
        profile_sync_time(1.0)
        # Profile local_bsz=2 and commit.
        profile_step_start(2)
        profile_sync_time(1.0)
        profile_sync_time(2.0)
        profile_step_commit()
        # Ensure profile is updated correctly. Keys are
        # (num_nodes, num_replicas, local_bsz).
        profile = _metrics_state().profile
        assert len(profile) == 1
        assert profile[1, 1, 2]["count"] == 1
        assert profile[1, 1, 2]["sync_time"] == 3.0
        assert profile[1, 1, 2]["step_time"] > 0.0
        # Checkpoint and restart.
        adaptdl.checkpoint.save_all_states()
        return num_replicas
    elif num_restarts() == 1:
        profile = _metrics_state().profile
        # Ensure checkpoint is loaded correctly.
        assert len(profile) == 1
        assert profile[1, 1, 2]["count"] == 1
        assert profile[1, 1, 2]["sync_time"] == 3.0
        assert profile[1, 1, 2]["step_time"] > 0.0
        # Profile local_bsz=3 and commit twice.
        profile_step_start(3)
        profile_sync_time(2.0)
        profile_sync_time(3.0)
        profile_step_commit()
        old_step_time = profile[1, num_replicas, 3]["step_time"]
        profile_step_start(3)
        profile_sync_time(3.0)
        profile_sync_time(4.0)
        profile_step_commit()
        # Ensure profile is updated correctly.
        assert len(profile) == 2
        assert profile[1, num_replicas, 3]["count"] == 2
        assert profile[1, num_replicas, 3]["sync_time"] == 12.0
        assert profile[1, num_replicas, 3]["step_time"] > old_step_time > 0.0
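
The assertions above pin down the contract of the three profiling hooks. A
minimal sketch of a backing accumulator that satisfies them follows; the key
layout (num_nodes, num_replicas, local_bsz) and the module-level state are
assumptions inferred from this test, not AdaptDL's actual implementation
(which, among other things, persists this state across restarts).

# Minimal sketch of the accumulator contract exercised by test_profile.
# Semantics are inferred from the assertions above; assumed, not actual.
import time
from collections import defaultdict

_profile = defaultdict(
    lambda: {"count": 0, "sync_time": 0.0, "step_time": 0.0})
_step = None  # [start time, local_bsz, accumulated sync time]

def profile_step_start(local_bsz):
    global _step
    # Starting a new step silently discards any uncommitted one.
    _step = [time.time(), local_bsz, 0.0]

def profile_sync_time(seconds):
    # Sync times within one step accumulate (1.0 + 2.0 == 3.0 above).
    _step[2] += seconds

def profile_step_commit():
    start, local_bsz, sync_time = _step
    key = (1, 1, local_bsz)  # (num_nodes, num_replicas, local_bsz)
    _profile[key]["count"] += 1
    _profile[key]["sync_time"] += sync_time
    _profile[key]["step_time"] += time.time() - start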
Example #3
def test_profile_accumulation(num_replicas):
    import adaptdl.checkpoint
    from adaptdl.env import num_restarts
    from adaptdl.torch._metrics import (
            profile_step_start, profile_sync_time,
            profile_step_commit, _metrics_state, _fit_perf_params)
    if num_restarts() == 0:
        profile = _metrics_state().profile
        assert len(profile) == 0
        # Profile local_bsz=1 but don't commit.
        profile_step_start(1)
        profile_sync_time(1.0)
        # Profile local_bsz=2 and commit.
        profile_step_start(2)
        profile_step_commit(accumulation_step=True)
        profile_step_start(2)
        profile_step_commit(accumulation_step=True)
        profile_step_start(2)
        profile_sync_time(4.0)
        profile_step_commit(accumulation_step=False)
        profile_step_start(5)
        profile_step_commit(accumulation_step=True)
        profile_step_start(5)
        profile_step_commit(accumulation_step=True)
        profile_step_start(5)
        profile_sync_time(6.0)
        profile_step_commit(accumulation_step=False)
        # Ensure profile is updated correctly.
        profile = _metrics_state().profile
        key = (1, 1, 2)
        assert len(profile) == 2
        assert profile[key]["accum_count"] == 2
        assert profile[key]["optim_count"] == 1
        assert profile[key]["optim_sync_time"] == 4.0
        assert profile[key]["accum_step_time"] > 0.0
        assert profile[key]["optim_step_time"] > 0.0
        profile_step_start(3)
        profile_step_commit(accumulation_step=True)
        profile_step_start(3)
        profile_step_commit(accumulation_step=True)
        # Check that fitting the performance parameters works even without a
        # final accumulation_step=False commit.
        for val in profile.values():
            # Ensure step time is at least sync time.
            val["optim_step_time"] += val["optim_sync_time"]
        _fit_perf_params()
        # Checkpoint and restart.
        adaptdl.checkpoint.save_all_states()
        return num_replicas
    elif num_restarts() == 1:
        profile = _metrics_state().profile
        # Ensure checkpoint is loaded correctly.
        key = (1, 1, 2)
        assert len(profile) == 3
        assert profile[key]["accum_count"] == 2
        assert profile[key]["optim_count"] == 1
        assert profile[key]["optim_sync_time"] == 4.0
        assert profile[key]["optim_step_time"] > 0.0
        # Profile local_bsz=3 and commit twice.
        profile_step_start(3)
        profile_sync_time(2.0)
        profile_sync_time(3.0)
        profile_step_commit()
        key = (1, num_replicas, 3)
        old_step_time = profile[key]["optim_step_time"]
        profile_step_start(3)
        profile_sync_time(3.0)
        profile_sync_time(4.0)
        profile_step_commit()
        # Ensure profile is updated correctly.
        if num_replicas == 1:
            assert len(profile) == 3
        else:
            assert len(profile) == 4
        assert profile[key]["accum_count"] == 0 if num_replicas > 1 else 2
        assert profile[key]["optim_count"] == 2
        assert profile[key]["optim_sync_time"] == 12.0
        assert profile[key]["optim_step_time"] > old_step_time > 0.0