def __init__(self, *args, **kwargs):
    if current_dataloader() is not None:
        raise RuntimeError("accumulator may not be initialized during "
                           "dataloader iteration")
    epoch = current_epoch()
    count = _AccumulatorState.init_count[epoch]
    # Name the state deterministically by epoch and initialization order, so
    # it can be matched up with the same state again after a restart.
    super().__init__("adaptdl-accumulator-epoch{}-{}".format(epoch, count))
    _AccumulatorState.init_count[epoch] += 1
    self.results_history = collections.defaultdict(list)  # Epoch -> synchronized results.
    self.results = dict(*args, **kwargs)  # Values visible after synchronization.
    self.updates = {}  # Local updates buffered until the next synchronization.
def __init__(self):
    if current_dataloader() is not None:
        raise RuntimeError("dataloader may not be initialized during "
                           "dataloader iteration")
    epoch = current_epoch()
    count = _AdaptiveDataLoaderState.init_count[epoch]
    super().__init__("adaptdl-dataloader-epoch{}-{}".format(epoch, count))
    _AdaptiveDataLoaderState.init_count[epoch] += 1
    self.current_index = 0   # Index within the current dataloader loop.
    self.end_index = 0       # End index of the current dataloader loop.
    self.last_position = {}  # Epoch -> position of last completed loop.
@contextmanager
def context(self):
    """
    All iterators should be iterated under this context. It ensures
    proper cleanup of the elastic context at the end of each epoch.
    """
    epoch = current_epoch()
    try:
        if AdaptiveDataLoaderHelper._current is not None:
            raise RuntimeError("overlapping dataloader "
                               "iterations detected")
        AdaptiveDataLoaderHelper._current = self
        yield
    finally:
        # Reset per-loop indices and advance the loop position for this epoch.
        self._state.current_index = 0
        self._state.end_index = 0
        self._state.last_position[epoch] = self._position[epoch]
        self._position[epoch] += 1
        AdaptiveDataLoaderHelper._current = None
def skipdone(self):
    """
    Should be called just after entering the `_elastic` context to make
    sure that the dataloader loop is not replayed if it has already
    finished before a restart.
    """
    epoch = current_epoch()
    position = self._position[epoch]
    if position <= self._state.last_position.get(epoch, -1):
        # Already completed the dataloader loop at the current
        # position, skip this loop and keep replaying the application
        # code.
        LOG.info("skipping %s loop at position %s in epoch %s",
                 self.__class__.__name__, position, epoch)
        self._position[epoch] += 1
        return True
    else:
        return False
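
# Illustrative sketch only (not library source): the pattern that context() and
# skipdone() expect from an elastic loop. `helper` stands in for an
# AdaptiveDataLoaderHelper instance and `process_remaining_batches` for the
# application's per-loop work; both are hypothetical placeholders. The real
# pattern appears in AdaptiveDataLoader.__iter__ below.
def _example_elastic_loop(helper, process_remaining_batches):
    with helper.context():
        if helper.skipdone():
            # This loop already finished before a checkpoint-restart, so skip
            # it and let the surrounding application code keep replaying.
            return
        process_remaining_batches()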
@contextmanager
def synchronized(self):
    """
    A context manager which can be used to define the code to execute in
    *synchronized* mode. Within the context manager, any code can
    interact with this accumulator as if it were a regular Python
    ``dict``. The application must ensure that whatever operations are
    performed within this context block are the same across all replicas.

    .. warning::
        Entering this context manager is a distributed synchronization
        point! Please ensure that all replicas enter this context
        manager at the same point in their code.
    """
    if self._synchronized is not None:
        # Already synchronized, don't need to do anything.
        yield self
        return
    epoch = current_epoch()
    # Remove saved results from all finished epochs. Since finished
    # epochs are never replayed, they should never be needed again.
    for key in list(self._state.results_history.keys()):
        if key is not None and key < epoch:
            self._state.results_history.pop(key)
    # Get the number of synchronizations so far in the current epoch.
    count = self._sync_count[epoch]
    self._sync_count[epoch] += 1
    results_list = self._state.results_history[epoch]
    assert count <= len(results_list)
    if count < len(results_list):
        # Results for this synchronization are saved in the history.
        self._synchronized = results_list[count]
        self._state.updates.clear()
    else:
        self._state.sync()  # Sync results and updates across replicas.
        if current_dataloader() is None:
            # Only save into results history if outside of a dataloader
            # iteration, since code inside iterations is not replayed.
            results_list.append(copy.deepcopy(self._state.results))
        self._synchronized = self._state.results
    try:
        yield self
    finally:
        self._synchronized = None
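
# Illustrative usage sketch (not library source), based on the docstring above:
# accumulate values locally, then read them inside synchronized(). It assumes
# the public adaptdl.torch.Accumulator API; `model`, `loader`, and `criterion`
# are hypothetical placeholders supplied by the application.
def _example_accumulator_usage(model, loader, criterion):
    import adaptdl.torch
    stats = adaptdl.torch.Accumulator()
    for inputs, targets in loader:
        loss = criterion(model(inputs), targets)
        stats["loss_sum"] += loss.item()  # Buffered locally until synchronized.
        stats["count"] += 1
    # Entering synchronized() is a collective operation: all replicas must
    # reach this point, after which the accumulator reads like a regular dict.
    with stats.synchronized():
        return stats["loss_sum"] / stats["count"]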
def __iter__(self):
    """
    Iterate over batches of data. When adaptive batch size is disabled,
    stops after the entire dataset has been processed once in total by
    all replicas. This means if there are K replicas, then this method
    will iterate over ~1/K of the dataset. When adaptive batch size is
    enabled, stops after making enough statistical progress roughly
    equivalent to one pass over the dataset with a non-adaptive batch
    size. In this case, the dataset may be processed more than once.

    A checkpoint-restart may be triggered in-between each batch. In this
    case, the current iteration state will be saved and restored after
    the restart, and iteration will continue where it left off.
    """
    epoch = current_epoch()
    num_replicas = adaptdl.env.num_replicas()
    with self._elastic.context():
        if self._elastic.skipdone():
            return
        done = False
        while not done:
            self.sampler.set_epoch(epoch, index=self._elastic.current_index)
            self.batch_sampler.batch_size = self._elastic._sync_local_bsz()
            for idx, batch in enumerate(super().__iter__()):
                with self._elastic.profile(self.training and idx >= 1):
                    yield batch
                    # Increment by the number of data samples processed.
                    self._elastic.current_index += \
                        num_replicas * self.batch_sampler.batch_size
                    if self._elastic.max_batch_size is not None and \
                            get_progress() >= len(self.dataset) * \
                            (epoch + 1) / self.batch_size:
                        # Enough statistical progress for this epoch.
                        done = True
                        break
            if self._elastic.max_batch_size is None:
                # Non-adaptive batch size: one pass over the dataset is enough.
                # Round current_index up to a multiple of the dataset size so
                # the loop terminates cleanly.
                done = True
                self._elastic.current_index -= \
                    self._elastic.current_index % -len(self.dataset)
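
# Illustrative training-loop sketch (not library source): how __iter__ above is
# typically driven. It assumes the public adaptdl.torch.AdaptiveDataLoader,
# AdaptiveDataParallel, and remaining_epochs_until helpers; `dataset`, `model`,
# `optimizer`, and `criterion` are hypothetical placeholders.
def _example_training_loop(dataset, model, optimizer, criterion):
    import adaptdl.torch
    # A real program first calls adaptdl.torch.init_process_group(...).
    model = adaptdl.torch.AdaptiveDataParallel(model, optimizer)
    dataloader = adaptdl.torch.AdaptiveDataLoader(
        dataset, batch_size=64, shuffle=True, drop_last=True)
    dataloader.autoscale_batch_size(1024)  # Enable adaptive (global) batch size.
    for epoch in adaptdl.torch.remaining_epochs_until(30):
        for inputs, targets in dataloader:
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()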