def recreate_data_iterator(self, phase_type, epoch, compute_start_iter):
    """
    Recreate data iterator (including multiprocessing workers) and destroy
    the previous iterators. This is called when we load a new checkpoint or
    when phase changes during the training (one epoch to the next).
    DataSampler may need to be informed on those events to update the epoch
    and start_iteration so that the data is deterministically shuffled, so we
    call them here.

    Args:
        phase_type: key into self.dataloaders selecting the phase
            ("train", "test", ...)
        epoch: epoch number forwarded to the sampler for deterministic
            shuffling
        compute_start_iter: when True and a checkpoint with a positive
            iteration count is loaded, derive the in-epoch start iteration
            from the checkpoint so sample skipping resumes correctly
    """
    if hasattr(self.dataloaders[phase_type], "sampler"):
        sampler = self.dataloaders[phase_type].sampler
        # (Re-)Shuffle data: set epoch of distributed (or fairstore) sampler.
        if hasattr(sampler, "set_epoch"):
            sampler.set_epoch(epoch)
        # Resume from the iteration if valid
        if hasattr(sampler, "set_start_iter"):
            if (
                compute_start_iter
                and self.checkpoint is not None
                and self.checkpoint["iteration"] > 0
            ):
                num_iters_in_epochs = len(self.dataloaders[phase_type])
                num_epochs = self.checkpoint["train_phase_idx"] + 1
                num_train_iters_done = num_epochs * num_iters_in_epochs
                # FIX: clamp at 0. The global iteration count can be smaller
                # than num_train_iters_done (e.g. the dataloader length
                # changed between runs), and a negative start iteration is
                # never a valid value for the sampler.
                start_iter = max(
                    0, self.checkpoint["iteration"] - num_train_iters_done
                )
            else:
                start_iter = 0
            sampler.set_start_iter(start_iter)
        print_sampler_config(sampler)
    # delete the old data iterator so its worker processes are torn down
    # before new ones are spawned
    del self.data_iterator
    gc.collect()
    # recreate the data iterator
    self.data_iterator = iter(self.dataloaders[phase_type])
def set_epoch(self, phase_type: str, epoch: int, start_iter: int): if hasattr(self.dataloaders[phase_type], "sampler"): sampler = self.dataloaders[phase_type].sampler # (Re-)Shuffle data: set epoch of distributed (or fairstore) sampler # Resume from the iteration if valid self.set_epoch_start_iter(sampler, epoch, start_iter) print_sampler_config(sampler) # call set_epoch and set_start_iter for AirstoreDataset since it handles # shuffle and sample skipping behavior internally dataset = self.datasets[phase_type] if hasattr(dataset, "data_objs"): for data_obj in dataset.data_objs: self.set_epoch_start_iter(data_obj, epoch, start_iter)
def recreate_data_iterator(self, phase_type, epoch, compute_start_iter):
    """
    Recreate data iterator (including multiprocessing workers) and destroy
    the previous iterators. This is called when we load a new checkpoint or
    when phase changes during the training (one epoch to the next).
    DataSampler may need to be informed on those events to update the epoch
    and start_iteration so that the data is deterministically shuffled, so
    we call them here.
    """
    # Derive the in-epoch start iteration from the checkpoint if requested;
    # otherwise start from the beginning of the epoch.
    start_iter = (
        self._compute_start_iter_from_checkpoint(phase_type)
        if compute_start_iter
        else 0
    )

    loader = self.dataloaders[phase_type]
    if hasattr(loader, "sampler"):
        sampler = loader.sampler
        # (Re-)Shuffle data: set epoch of distributed (or fairstore) sampler.
        if hasattr(sampler, "set_epoch"):
            sampler.set_epoch(epoch)
        # Resume from the iteration if valid
        if hasattr(sampler, "set_start_iter"):
            sampler.set_start_iter(start_iter)
        print_sampler_config(sampler)

    # call set_epoch and set_start_iter for AirstoreDataset since it handles
    # shuffle and sample skipping behavior internally
    if hasattr(loader, "dataset"):
        dataset = loader.dataset
        if isinstance(dataset, GenericSSLDataset):
            for data_obj in dataset.data_objs:
                if isinstance(data_obj, AirstoreDataset):
                    data_obj.set_epoch(epoch)
                    data_obj.set_start_iter(start_iter)

    # delete the old data iterator, then recreate it
    del self.data_iterator
    gc.collect()
    self.data_iterator = iter(loader)