Ejemplo n.º 1
0
    def recreate_data_iterator(self, phase_type, epoch, compute_start_iter):
        """
        Rebuild ``self.data_iterator`` for ``phase_type``, tearing down the old one.

        Invoked whenever a checkpoint is loaded or the training phase advances
        (one epoch to the next). Before the new iterator (and its
        multiprocessing workers) is created, the dataloader's sampler is
        updated through whichever hooks it exposes: ``set_epoch`` receives
        ``epoch`` and ``set_start_iter`` receives the iteration to resume
        from. This keeps the data shuffle deterministic across restarts.

        Args:
            phase_type: key into ``self.dataloaders`` selecting the loader.
            epoch: value forwarded to ``sampler.set_epoch``.
            compute_start_iter: when truthy and ``self.checkpoint`` is loaded
                with a positive ``"iteration"``, the resume offset is derived
                from the checkpoint. Otherwise the sampler starts at 0.
        """
        loader = self.dataloaders[phase_type]
        if hasattr(loader, "sampler"):
            sampler = loader.sampler
            # Re-seed the shuffle of a distributed (or fairstore) sampler.
            if hasattr(sampler, "set_epoch"):
                sampler.set_epoch(epoch)
            if hasattr(sampler, "set_start_iter"):
                checkpoint = self.checkpoint
                resuming = (
                    compute_start_iter
                    and checkpoint is not None
                    and checkpoint["iteration"] > 0
                )
                start_iter = 0
                if resuming:
                    # Subtract the iterations consumed by the finished train
                    # phases. The code treats train_phase_idx as the 0-based
                    # index of the last finished phase, hence the +1.
                    iters_per_epoch = len(loader)
                    epochs_done = checkpoint["train_phase_idx"] + 1
                    start_iter = checkpoint["iteration"] - epochs_done * iters_per_epoch
                sampler.set_start_iter(start_iter)
            print_sampler_config(sampler)
        # Release the previous iterator (and its workers) before building anew.
        del self.data_iterator
        gc.collect()
        self.data_iterator = iter(loader)
Ejemplo n.º 2
0
    def set_epoch(self, phase_type: str, epoch: int, start_iter: int):
        """
        Push ``epoch`` and ``start_iter`` to every component that shuffles data.

        Two kinds of objects are updated through ``self.set_epoch_start_iter``:

        * the sampler of ``self.dataloaders[phase_type]``, if the loader has
          one. Its configuration is printed afterwards.
        * every entry of ``data_objs`` on ``self.datasets[phase_type]``, if
          that attribute exists. This covers datasets such as AirstoreDataset,
          which shuffle and skip samples internally.
        """
        loader = self.dataloaders[phase_type]
        if hasattr(loader, "sampler"):
            sampler = loader.sampler
            # Re-shuffle (distributed / fairstore sampler) and resume from the
            # requested iteration, if valid.
            self.set_epoch_start_iter(sampler, epoch, start_iter)
            print_sampler_config(sampler)

        dataset = self.datasets[phase_type]
        if not hasattr(dataset, "data_objs"):
            return
        for obj in dataset.data_objs:
            self.set_epoch_start_iter(obj, epoch, start_iter)
Ejemplo n.º 3
0
    def recreate_data_iterator(self, phase_type, epoch, compute_start_iter):
        """
        Rebuild ``self.data_iterator`` for ``phase_type`` after informing the
        data pipeline of the new epoch and resume position.

        Called when a checkpoint is loaded or when training moves to a new
        phase (one epoch to the next). The start iteration comes from
        ``self._compute_start_iter_from_checkpoint`` when
        ``compute_start_iter`` is truthy, and is 0 otherwise. The
        ``epoch`` / start-iteration pair is then pushed to:

        * the dataloader's sampler, through whichever of ``set_epoch`` and
          ``set_start_iter`` it implements, so shuffling is deterministic;
        * every ``AirstoreDataset`` inside a ``GenericSSLDataset``, since
          those datasets shuffle and skip samples on their own.

        Finally the previous iterator (including any multiprocessing workers)
        is released and a fresh one is created.
        """
        start_iter = (
            self._compute_start_iter_from_checkpoint(phase_type)
            if compute_start_iter
            else 0
        )
        loader = self.dataloaders[phase_type]

        if hasattr(loader, "sampler"):
            sampler = loader.sampler
            # Re-seed the shuffle of a distributed (or fairstore) sampler.
            if hasattr(sampler, "set_epoch"):
                sampler.set_epoch(epoch)
            # Resume from the computed iteration when supported.
            if hasattr(sampler, "set_start_iter"):
                sampler.set_start_iter(start_iter)
            print_sampler_config(sampler)

        # AirstoreDataset handles shuffling and sample skipping internally, so
        # it must be told about the epoch and start iteration directly.
        dataset = loader.dataset if hasattr(loader, "dataset") else None
        if dataset is not None and isinstance(dataset, GenericSSLDataset):
            for data_obj in dataset.data_objs:
                if not isinstance(data_obj, AirstoreDataset):
                    continue
                data_obj.set_epoch(epoch)
                data_obj.set_start_iter(start_iter)

        # Release the previous iterator (and its workers) before building anew.
        del self.data_iterator
        gc.collect()
        self.data_iterator = iter(loader)