def test_dataloader_restarts():
    import adaptdl.checkpoint
    import adaptdl.collective
    from adaptdl.env import num_restarts, num_replicas
    adaptdl.collective.initialize("0.0.0.0")
    dataset_size = 100
    init_batch_size = 10
    dataset = TensorDataset(torch.rand(dataset_size))
    dataloader = AdaptiveDataLoader(dataset, batch_size=init_batch_size)
    # Load data samples in the following order:
    # 2 batches (20 samples) using 1 replica (local_bsz = 10, batch_size = 10)
    # 5 batches (60 samples) using 4 replicas (local_bsz = 3, batch_size = 12)
    # 2 batches (20 samples) using 2 replicas (local_bsz = 5, batch_size = 10)
    assert current_dataloader() is None
    for idx, batch in enumerate(dataloader):
        if num_restarts() == 0 and idx == 2:
            adaptdl.checkpoint.save_all_states()
            return 4  # Restart with 4 replicas.
        if num_restarts() == 1 and idx == 5:
            adaptdl.checkpoint.save_all_states()
            return 2  # Restart with 2 replicas.
        assert current_dataloader() is dataloader._elastic
        local_bsz = batch[0].size(0)
        assert dataloader.current_local_bsz == local_bsz
        assert local_bsz == math.ceil(init_batch_size / num_replicas())
        assert dataloader.current_batch_size == num_replicas() * local_bsz
    # After the last 2 batches.
    assert idx == 1
def _final_callback(self):
    # This method should be invoked once for each backward pass, after
    # gradients have been synchronized between each replica.
    self._final_callback_queued = False
    # self._sync_start should mark the last time any local gradient from
    # this module was produced. We assume the duration until now was
    # primarily spent waiting for gradient synchronization.
    # TODO: Depends on the internal behavior of DistributedDataParallel,
    # which might break with future versions of PyTorch. Any better
    # and well-supported way to measure the synchronization time?
    if isinstance(self._sync_start, torch.cuda.Event):
        sync_end = torch.cuda.Event(enable_timing=True)
        sync_end.record()
        sync_end.synchronize()
        profile_sync_time(self._sync_start.elapsed_time(sync_end) / 1e3)
    else:
        profile_sync_time(time.time() - self._sync_start)
    dataloader = current_dataloader()
    if dataloader is None:
        # Don't allow backpropagation outside of a dataloader loop, because
        # the batch size would be unknown.
        raise RuntimeError("backpropagation outside AdaptiveDataLoader")
    dataloader.train()
    scale = dataloader.current_batch_size / dataloader.batch_size
    self._state.gain = self.gns.gain(scale)
    self._state.lr_factor = np.average(self.scaling_rule.scale_lr(scale))
    update_progress(self.gns.get_progress())
    if dataloader.max_batch_size and \
            dataloader.max_batch_size > dataloader.batch_size:
        update_grad_params(self._key, self.gns.sqr_avg(), self.gns.var_avg())
    self._sync_start = None
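# A minimal standalone sketch of the CUDA-event timing pattern used in
# _final_callback above to profile gradient synchronization time
# (elapsed_time() returns milliseconds, hence the division by 1e3). The
# helper name is illustrative only, not part of AdaptDL.
def cuda_elapsed_seconds(fn):
    """Run fn() on the current CUDA stream and return its duration in seconds."""
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    fn()
    end.record()
    end.synchronize()  # Block until the recorded work has finished on the GPU.
    return start.elapsed_time(end) / 1e3  # elapsed_time() is in milliseconds.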
def forward(self, *args, **kwargs):
    # Do not do gradient synchronization during gradient accumulation.
    dataloader = current_dataloader()
    if dataloader is not None and dataloader.training:
        self.require_backward_grad_sync = dataloader.is_optim_step()
        accum_scale = (dataloader.current_local_bsz *
                       adaptdl.env.num_replicas() / dataloader.batch_size)
        self.gns.set_accum_scale(accum_scale)
    return super().forward(*args, **kwargs)
def test_dataloader_break():
    import adaptdl.checkpoint
    import adaptdl.collective
    from adaptdl.env import num_restarts
    if num_restarts() == 0:
        return 2
    adaptdl.collective.initialize("0.0.0.0")
    dataset = TensorDataset(torch.rand(100))
    dataloader = AdaptiveDataLoader(dataset, batch_size=10)
    # Break in the middle of the first for-loop, and start another for-loop.
    assert current_dataloader() is None
    for idx, batch in enumerate(dataloader):
        assert current_dataloader() is dataloader._elastic
        if idx == 5:
            break
    assert current_dataloader() is None
    for idx, batch in enumerate(dataloader):
        assert current_dataloader() is dataloader._elastic
    assert idx == 9  # Run 10 batches total.
def __init__(self, *args, **kwargs):
    if current_dataloader() is not None:
        raise RuntimeError("accumulator may not be initialized during "
                           "dataloader iteration")
    epoch = current_epoch()
    count = _AccumulatorState.init_count[epoch]
    super().__init__("adaptdl-accumulator-epoch{}-{}".format(epoch, count))
    _AccumulatorState.init_count[epoch] += 1
    self.results_history = collections.defaultdict(list)
    self.results = dict(*args, **kwargs)
    self.updates = {}
def scale_lr(self, scale):
    dataloader = current_dataloader()
    # Total number of training steps used for warm-up.
    total_steps = self._base_warmup_epochs * scale * \
        self._data_size / dataloader.batch_size
    max_lr_multiplier = math.sqrt(scale)
    # Effective training steps taken so far.
    progress = self.adp.gns.get_progress()
    if progress < total_steps:
        lr_factor = max_lr_multiplier * (progress / total_steps)
    else:
        lr_factor = max_lr_multiplier
    return lr_factor
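# A minimal standalone sketch of the square-root warm-up rule implemented by
# scale_lr above, with hypothetical numbers plugged in. The helper name and
# the example values are illustrative only, not part of AdaptDL.
import math

def sqrt_warmup_lr_factor(scale, progress, total_steps):
    """LR multiplier: ramp linearly from 0 up to sqrt(scale) over total_steps."""
    max_lr_multiplier = math.sqrt(scale)
    if progress < total_steps:
        return max_lr_multiplier * (progress / total_steps)
    return max_lr_multiplier

# Example: scale = 4.0 (batch size grew 4x). sqrt(4) = 2.0, so halfway through
# warm-up the factor is 2.0 * 0.5 = 1.0; after warm-up it stays at 2.0.
assert sqrt_warmup_lr_factor(4.0, progress=500, total_steps=1000) == 1.0
assert sqrt_warmup_lr_factor(4.0, progress=2000, total_steps=1000) == 2.0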
def forward(self, *args, **kwargs):
    # Do not do gradient synchronization during gradient accumulation.
    # Otherwise, exactly the same as DistributedDataParallel's forward.
    dataloader = current_dataloader()
    accumulation_steps = dataloader.accumulation_steps
    # TODO: move this to the dataloader.__iter__
    self.adascale.set_accumulation_steps(accumulation_steps)
    if self.adascale.is_accumulation_step():
        with super().no_sync():
            dataloader.is_accumulation_step = True
            return super().forward(*args, **kwargs)
    else:
        dataloader.is_accumulation_step = False
        return super().forward(*args, **kwargs)
@contextmanager
def synchronized(self):
    """
    A context manager which can be used to define the code to execute in
    *synchronized* mode. Within the context manager, any code can interact
    with this accumulator as if it were a regular Python ``dict``. The
    application must ensure that whatever operations are performed within
    this context block are the same across all replicas.

    .. warning::
        Entering this context manager is a distributed synchronization
        point! Please ensure that all replicas enter this context manager
        at the same point in their code.
    """
    if self._synchronized is not None:
        # Already synchronized, don't need to do anything.
        yield self
        return
    epoch = current_epoch()
    # Remove saved results from all finished epochs. Since finished epochs
    # are never replayed, they should never be needed again.
    for key in list(self._state.results_history.keys()):
        if key is not None and key < epoch:
            self._state.results_history.pop(key)
    # Get the number of synchronizations so far in the current epoch.
    count = self._sync_count[epoch]
    self._sync_count[epoch] += 1
    results_list = self._state.results_history[epoch]
    assert count <= len(results_list)
    if count < len(results_list):
        # Results for this synchronization are saved in the history.
        self._synchronized = results_list[count]
        self._state.updates.clear()
    else:
        self._state.sync()  # Sync results and updates across replicas.
        if current_dataloader() is None:
            # Only save into the results history if outside of a dataloader
            # iteration, since code inside iterations is not replayed.
            results_list.append(copy.deepcopy(self._state.results))
        self._synchronized = self._state.results
    try:
        yield self
    finally:
        self._synchronized = None
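# A hypothetical usage sketch of the synchronized() context manager above,
# assuming the Accumulator API from adaptdl.torch: per-replica values are
# accumulated during the epoch, then replica-wide totals are read inside
# synchronized(), which all replicas must enter together. The function name
# and the dict keys used here are illustrative, not part of AdaptDL.
import adaptdl.torch

def epoch_average_loss(model, criterion, dataloader):
    accum = adaptdl.torch.Accumulator()
    for inputs, targets in dataloader:
        loss = criterion(model(inputs), targets)
        # ... backward pass and optimizer step elided ...
        accum["loss_sum"] += loss.item() * targets.size(0)
        accum["count"] += targets.size(0)
    with accum.synchronized():  # Collective synchronization point.
        return accum["loss_sum"] / accum["count"]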