Ejemplo n.º 1
0
 def _check_eval_shuffling(dataloader, mode):
     """Warn when a dataloader used for evaluation has shuffling enabled.

     Args:
         dataloader: The dataloader to inspect; it is handed to
             ``_is_dataloader_shuffled`` to decide whether to warn.
         mode: Object exposing a ``dataloader_prefix`` attribute, which is
             used to name the offending dataloader in the warning text.

     Returns:
         None. The only effect is a possible ``rank_zero_warn`` call.
     """
     # Nothing to report when the dataloader is not shuffled.
     if not _is_dataloader_shuffled(dataloader):
         return

     message = (
         f"Your `{mode.dataloader_prefix}_dataloader`'s sampler has shuffling enabled,"
         " it is strongly recommended that you turn shuffling off for val/test/predict dataloaders."
     )
     rank_zero_warn(message, category=PossibleUserWarning)
Ejemplo n.º 2
0
    def _prepare_dataloader(self,
                            dataloader: Any,
                            shuffle: Optional[bool] = None,
                            mode: Optional[RunningStage] = None) -> Any:
        """Prepare a dataloader (or a collection of them) for the trainer.

        This function handles the following functionalities:

        - Injecting a `DistributedDataSampler` into the `DataLoader` if on a distributed environment
        - Wrapping the datasets and samplers into fault-tolerant components
        - Wrapping the dataloader based on strategy-specific logic

        Args:
            dataloader: A ``CombinedLoader``, ``DataLoader`` or
                ``CycleIterator``. Any other object is returned untouched.
            shuffle: Whether the resolved sampler should shuffle. When
                ``None`` it defaults to ``True`` for the training stage and
                otherwise to the result of ``_is_dataloader_shuffled``.
            mode: The running stage the dataloader is being prepared for.

        Returns:
            The prepared object. A ``CombinedLoader`` or ``CycleIterator``
            input is updated in place and returned as the same object.
        """
        loader_types = (DataLoader, CycleIterator)

        if isinstance(dataloader, CombinedLoader):
            # Recurse into every loader held by the combined loader.
            dataloader.loaders = apply_to_collection(
                dataloader.loaders,
                loader_types,
                self._prepare_dataloader,
                shuffle,
                mode=mode,
            )
            # The length needs to be recomputed across all dataloaders in
            # case of special behavior.
            dataloader._apply_cycle_iterator_length()
            return dataloader

        # Anything that is not a dataloader is passed through unchanged.
        if not isinstance(dataloader, loader_types):
            return dataloader

        # Remember the CycleIterator wrapper (if any) so the processed
        # loader can be put back inside it before returning.
        wrapper: Optional[CycleIterator] = None
        if isinstance(dataloader, CycleIterator):
            wrapper, dataloader = dataloader, dataloader.loader

        # Operand order matters: `or` short-circuits, so later checks only
        # run when the earlier ones are falsy.
        needs_new_sampler = (
            # injects components to track the state
            _fault_tolerant_training()
            # sets the distributed sampler
            or self._requires_distributed_sampler(dataloader)
            # to track indices for the predictions
            or mode == RunningStage.PREDICTING
            # IPUs use a custom `poptorch.DataLoader` which we might need to
            # convert to
            or isinstance(self.trainer.accelerator, IPUAccelerator)
        )

        if needs_new_sampler:
            if shuffle is None:
                if mode == RunningStage.TRAINING:
                    # for training, always shuffle
                    shuffle = True
                else:
                    # for evaluation, decide based on the existing sampler
                    shuffle = _is_dataloader_shuffled(dataloader)

            sampler = self._resolve_sampler(dataloader, shuffle=shuffle, mode=mode)
            dataloader = _update_dataloader(dataloader, sampler, mode=mode)

        dataloader = self.trainer.strategy.process_dataloader(dataloader)

        if wrapper is None:
            return dataloader

        wrapper.loader = dataloader
        return wrapper