def _final_callback(self):
    # This method should be invoked once for each backward pass, after
    # gradients have been synchronized between each replica.
    self._final_callback_queued = False
    # self._sync_start should mark the last time any local gradient
    # from this module was produced. We assume the duration until now was
    # primarily spent waiting for gradient synchronization.
    # TODO: Depends on the internal behavior of DistributedDataParallel,
    # which might break with future versions of PyTorch. Any better
    # and well-supported way to measure the synchronization time?
    if isinstance(self._sync_start, torch.cuda.Event):
        sync_end = torch.cuda.Event(enable_timing=True)
        sync_end.record()
        sync_end.synchronize()
        profile_sync_time(self._sync_start.elapsed_time(sync_end) / 1e3)
    else:
        profile_sync_time(time.time() - self._sync_start)
    dataloader = current_dataloader()
    if dataloader is None:
        # Don't allow backpropagation outside of a dataloader loop, because
        # the batch size would be unknown.
        raise RuntimeError("backpropagation outside AdaptiveDataLoader")
    dataloader.train()
    scale = dataloader.current_batch_size / dataloader.batch_size
    self._state.gain = self.gns.gain(scale)
    self._state.lr_factor = np.average(self.scaling_rule.scale_lr(scale))
    update_progress(self.gns.get_progress())
    if dataloader.max_batch_size and \
            dataloader.max_batch_size > dataloader.batch_size:
        update_grad_params(self._key, self.gns.sqr_avg(), self.gns.var_avg())
    self._sync_start = None
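
# The sketch below illustrates how self._sync_start could be recorded when the
# last local gradient is produced, which is what _final_callback above assumes
# when it attributes the elapsed time to gradient synchronization. This is a
# minimal, hypothetical helper: the name _record_sync_start and its call site
# are assumptions, not part of the original module. It only shows the
# CUDA-event vs. wall-clock bookkeeping that _final_callback expects
# (milliseconds via Event.elapsed_time, seconds via time.time).
def _record_sync_start(self):
    # Local imports keep this sketch self-contained; the real module imports
    # these at module level.
    import time
    import torch
    if torch.cuda.is_available():
        # Use a CUDA event so asynchronous kernel execution is accounted for;
        # _final_callback later converts elapsed_time (ms) to seconds.
        self._sync_start = torch.cuda.Event(enable_timing=True)
        self._sync_start.record()
    else:
        # Fall back to wall-clock time (seconds) on CPU-only runs.
        self._sync_start = time.time()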

def test_profile(num_replicas):
    import adaptdl.checkpoint
    from adaptdl.env import num_restarts
    from adaptdl.torch._metrics import (
        profile_step_start, profile_sync_time,
        profile_step_commit, _metrics_state)
    if num_restarts() == 0:
        profile = _metrics_state().profile
        assert len(profile) == 0
        # Profile local_bsz=1 but don't commit.
        profile_step_start(1)
        profile_sync_time(1.0)
        # Profile local_bsz=2 and commit.
        profile_step_start(2)
        profile_sync_time(1.0)
        profile_sync_time(2.0)
        profile_step_commit()
        # Ensure profile is updated correctly.
        profile = _metrics_state().profile
        assert len(profile) == 1
        assert profile[1, 1, 2]["count"] == 1
        assert profile[1, 1, 2]["sync_time"] == 3.0
        assert profile[1, 1, 2]["step_time"] > 0.0
        # Checkpoint and restart.
        adaptdl.checkpoint.save_all_states()
        return num_replicas
    elif num_restarts() == 1:
        profile = _metrics_state().profile
        # Ensure checkpoint is loaded correctly.
        assert len(profile) == 1
        assert profile[1, 1, 2]["count"] == 1
        assert profile[1, 1, 2]["sync_time"] == 3.0
        assert profile[1, 1, 2]["step_time"] > 0.0
        # Profile local_bsz=3 and commit twice.
        profile_step_start(3)
        profile_sync_time(2.0)
        profile_sync_time(3.0)
        profile_step_commit()
        old_step_time = profile[1, num_replicas, 3]["step_time"]
        profile_step_start(3)
        profile_sync_time(3.0)
        profile_sync_time(4.0)
        profile_step_commit()
        # Ensure profile is updated correctly.
        assert len(profile) == 2
        assert profile[1, num_replicas, 3]["count"] == 2
        assert profile[1, num_replicas, 3]["sync_time"] == 12.0
        assert profile[1, num_replicas, 3]["step_time"] > old_step_time > 0.0
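
# For context, a minimal sketch of the call sequence test_profile exercises,
# assuming a single hand-rolled step rather than a real training loop (the
# local_bsz value and sync durations below are illustrative, not part of the
# metrics module). Each committed step is folded into a profile entry indexed
# by a tuple whose last two elements appear to be the replica count and the
# local batch size, e.g. profile[1, num_replicas, 3] above.
def _example_profile_one_step():
    from adaptdl.torch._metrics import (
        profile_step_start, profile_sync_time, profile_step_commit)
    local_bsz = 32                  # illustrative local batch size
    profile_step_start(local_bsz)   # begin timing this step
    # ...forward/backward would run here; report each gradient sync...
    for sync_seconds in (0.01, 0.02):
        profile_sync_time(sync_seconds)
    profile_step_commit()           # accumulate count/sync_time/step_time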

def test_profile_accumulation(num_replicas):
    import adaptdl.checkpoint
    from adaptdl.env import num_restarts
    from adaptdl.torch._metrics import (
        profile_step_start, profile_sync_time, profile_step_commit,
        _metrics_state, _fit_perf_params)
    if num_restarts() == 0:
        profile = _metrics_state().profile
        assert len(profile) == 0
        # Profile local_bsz=1 but don't commit.
        profile_step_start(1)
        profile_sync_time(1.0)
        # Profile local_bsz=2 and commit.
        profile_step_start(2)
        profile_step_commit(accumulation_step=True)
        profile_step_start(2)
        profile_step_commit(accumulation_step=True)
        profile_step_start(2)
        profile_sync_time(4.0)
        profile_step_commit(accumulation_step=False)
        profile_step_start(5)
        profile_step_commit(accumulation_step=True)
        profile_step_start(5)
        profile_step_commit(accumulation_step=True)
        profile_step_start(5)
        profile_sync_time(6.0)
        profile_step_commit(accumulation_step=False)
        # Ensure profile is updated correctly.
        profile = _metrics_state().profile
        key = (1, 1, 2)
        assert len(profile) == 2
        assert profile[key]["accum_count"] == 2
        assert profile[key]["optim_count"] == 1
        assert profile[key]["optim_sync_time"] == 4.0
        assert profile[key]["accum_step_time"] > 0.0
        assert profile[key]["optim_step_time"] > 0.0
        profile_step_start(3)
        profile_step_commit(accumulation_step=True)
        profile_step_start(3)
        profile_step_commit(accumulation_step=True)
        # Check that fitting parameters works even
        # without a final accumulation_step=False commit.
        for val in profile.values():
            # Ensure step time is at least sync time.
            val["optim_step_time"] += val["optim_sync_time"]
        _fit_perf_params()
        # Checkpoint and restart.
        adaptdl.checkpoint.save_all_states()
        return num_replicas
    elif num_restarts() == 1:
        profile = _metrics_state().profile
        # Ensure checkpoint is loaded correctly.
        key = (1, 1, 2)
        assert len(profile) == 3
        assert profile[key]["accum_count"] == 2
        assert profile[key]["optim_count"] == 1
        assert profile[key]["optim_sync_time"] == 4.0
        assert profile[key]["optim_step_time"] > 0.0
        # Profile local_bsz=3 and commit twice.
        profile_step_start(3)
        profile_sync_time(2.0)
        profile_sync_time(3.0)
        profile_step_commit()
        key = (1, num_replicas, 3)
        old_step_time = profile[key]["optim_step_time"]
        profile_step_start(3)
        profile_sync_time(3.0)
        profile_sync_time(4.0)
        profile_step_commit()
        # Ensure profile is updated correctly.
        if num_replicas == 1:
            assert len(profile) == 3
        else:
            assert len(profile) == 4
        # With a single replica, this key already holds the two accumulation
        # commits from before the restart; parenthesize so the conditional
        # applies to the expected value rather than the whole assertion.
        assert profile[key]["accum_count"] == (0 if num_replicas > 1 else 2)
        assert profile[key]["optim_count"] == 2
        assert profile[key]["optim_sync_time"] == 12.0
        assert profile[key]["optim_step_time"] > old_step_time > 0.0
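
# A similar sketch for gradient accumulation, matching how
# test_profile_accumulation drives the API: accumulation micro-steps are
# committed with accumulation_step=True (tracked under accum_count /
# accum_step_time), while the final micro-step that synchronizes gradients
# and runs the optimizer is committed with accumulation_step=False (tracked
# under optim_count / optim_sync_time / optim_step_time). The constants below
# are illustrative only, not values the metrics module requires.
def _example_profile_accumulated_step(accum_steps=2):
    from adaptdl.torch._metrics import (
        profile_step_start, profile_sync_time, profile_step_commit)
    local_bsz = 32  # illustrative local batch size
    for _ in range(accum_steps):
        profile_step_start(local_bsz)
        # ...forward/backward without gradient synchronization...
        profile_step_commit(accumulation_step=True)
    profile_step_start(local_bsz)
    # ...final forward/backward; gradients synchronize across replicas...
    profile_sync_time(0.01)  # illustrative sync duration in seconds
    profile_step_commit(accumulation_step=False)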