Code example #1: unit test for `has_len`; a finite `DataLoader` reports a length, a zero-length loader raises a `ValueError`, and an `IterableDataset`-backed loader does not.
def test_has_len():
    assert has_len(DataLoader(RandomDataset(1, 1)))

    with pytest.raises(ValueError, match="`Dataloader` returned 0 length."):
        assert has_len(DataLoader(RandomDataset(0, 0)))

    assert not has_len(DataLoader(RandomIterableDataset(1, 1)))
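
Both this test and the variant in example #2 rely on the `RandomDataset` and `RandomIterableDataset` helpers from the Lightning test suite, which are not shown here. A minimal sketch of stand-ins that make the test self-contained, assuming the helpers only need to produce random tensors of a given size and count:

import torch
from torch.utils.data import Dataset, IterableDataset


class RandomDataset(Dataset):
    """Map-style stand-in: `length` random vectors of dimension `size`, with `__len__`."""

    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class RandomIterableDataset(IterableDataset):
    """Iterable-style stand-in: yields `count` random vectors and defines no `__len__`."""

    def __init__(self, size, count):
        self.count = count
        self.size = size

    def __iter__(self):
        for _ in range(self.count):
            yield torch.randn(self.size)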
Code example #2: another revision of the same test, in which the zero-length case emits a `UserWarning` instead of raising.
def test_has_len():
    assert has_len(DataLoader(RandomDataset(1, 1)))

    with pytest.warns(UserWarning, match="`DataLoader` returned 0 length."):
        assert has_len(DataLoader(RandomDataset(0, 0)))

    assert not has_len(DataLoader(RandomIterableDataset(1, 1)))
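
Examples #1 and #2 are two revisions of the same test: in one, a zero-length loader makes `has_len` raise a `ValueError`; in the other it only emits a `UserWarning`. A minimal sketch of the shape of `has_len` that both revisions assume (the exact message and the raise-versus-warn behaviour differ between Lightning versions, so this is illustrative, not the library source):

def has_len(dataloader) -> bool:
    """Sketch: True if the dataloader exposes a usable `__len__`, False for iterable-style loaders."""
    try:
        if len(dataloader) == 0:
            # older revisions raise here, newer ones only warn
            raise ValueError("`Dataloader` returned 0 length.")
        return True
    except (TypeError, NotImplementedError):
        # e.g. a DataLoader wrapping an IterableDataset has no length
        return False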
Code example #3: `reset_train_dataloader` in the Trainer, which uses `has_len` to choose between a finite `num_training_batches` and `float('inf')` and to validate `limit_train_batches` and `val_check_interval`.
    def reset_train_dataloader(self, model: LightningModule) -> None:
        """Resets the train dataloader and initialises required variables
        (number of batches, when to validate, etc.).

        Args:
            model: The current `LightningModule`
        """
        self.train_dataloader = self.request_dataloader(model.train_dataloader)

        # debugging
        self.dev_debugger.track_load_dataloader_call(
            'train_dataloader', dataloaders=[self.train_dataloader])

        self.num_training_batches = 0

        # automatically add samplers
        self.train_dataloader = self.auto_add_sampler(self.train_dataloader,
                                                      train=True)

        self.num_training_batches = len(self.train_dataloader) if has_len(
            self.train_dataloader) else float('inf')
        self._worker_check(self.train_dataloader, 'train dataloader')

        if isinstance(self.limit_train_batches,
                      int) or self.limit_train_batches == 0.0:
            self.num_training_batches = min(self.num_training_batches,
                                            int(self.limit_train_batches))
        elif self.num_training_batches != float('inf'):
            self.num_training_batches = int(self.num_training_batches *
                                            self.limit_train_batches)
        elif self.limit_train_batches != 1.0:
            raise MisconfigurationException(
                'When using an IterableDataset for `limit_train_batches`,'
                ' `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies'
                ' `num_training_batches` to use.')

        # determine when to check validation
        # if int passed in, val checks that often
        # otherwise, it checks in [0, 1.0] % range of a training epoch
        if isinstance(self.val_check_interval, int):
            self.val_check_batch = self.val_check_interval
            if self.val_check_batch > self.num_training_batches:
                raise ValueError(
                    f'`val_check_interval` ({self.val_check_interval}) must be less than or equal '
                    f'to the number of the training batches ({self.num_training_batches}). '
                    'If you want to disable validation set `limit_val_batches` to 0.0 instead.'
                )
        else:
            if not has_len(self.train_dataloader):
                if self.val_check_interval == 1.0:
                    self.val_check_batch = float('inf')
                else:
                    raise MisconfigurationException(
                        'When using an IterableDataset for `train_dataloader`,'
                        ' `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies'
                        ' checking validation every k training batches.')
            else:
                self.val_check_batch = int(self.num_training_batches *
                                           self.val_check_interval)
                self.val_check_batch = max(1, self.val_check_batch)
Code example #4: test asserting that a warning is shown when an `IterableDataset` defines `__len__`; such a loader satisfies both `has_len` and `has_iterable_dataset`.
def test_warning_with_iterable_dataset_and_len(tmpdir):
    """ Tests that a warning message is shown when an IterableDataset defines `__len__`. """
    model = EvalModelTemplate()
    original_dataset = model.train_dataloader().dataset

    class IterableWithLen(IterableDataset):
        def __iter__(self):
            return iter(original_dataset)

        def __len__(self):
            return len(original_dataset)

    dataloader = DataLoader(IterableWithLen(), batch_size=16)
    assert has_len(dataloader)
    assert has_iterable_dataset(dataloader)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_steps=3,
    )
    with pytest.warns(UserWarning,
                      match='Your `IterableDataset` has `__len__` defined.'):
        trainer.fit(model,
                    train_dataloader=dataloader,
                    val_dataloaders=[dataloader])
    with pytest.warns(UserWarning,
                      match='Your `IterableDataset` has `__len__` defined.'):
        trainer.test(model, test_dataloaders=[dataloader])
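
Example #4 also asserts `has_iterable_dataset(dataloader)`. A plausible one-line implementation, assuming the check simply inspects the loader's `dataset` attribute (a sketch, not necessarily the library source):

from torch.utils.data import DataLoader, IterableDataset


def has_iterable_dataset(dataloader: DataLoader) -> bool:
    # a loader counts as "iterable-style" when its underlying dataset is an IterableDataset
    return hasattr(dataloader, "dataset") and isinstance(dataloader.dataset, IterableDataset)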
Code example #5: a `training_epoch_end` hook from a distributed test that gathers sample ids across processes; `has_len` decides how the `received` mask over the training dataset is sized.
    def training_epoch_end(self, outputs):
        ids = torch.cat([o['ids'] for o in outputs], dim=0)

        # in distributed mode collect ids from every process (gpu)
        if distributed_available():
            gather_ids = [
                torch.zeros_like(ids)
                for _ in range(torch.distributed.get_world_size())
            ]
            torch.distributed.all_gather(gather_ids, ids)
            ids = torch.cat(gather_ids, dim=0)

        if has_len(self.trainer.datamodule.train_dataset):
            received = torch.zeros(len(
                self.trainer.datamodule.train_dataset)).to(dtype=bool)
        else:
            received = torch.zeros(
                len(list(
                    self.trainer.datamodule.train_dataset))).to(dtype=bool)
        received[ids] = True

        if self.check_ids:
            # assert no duplicate element received
            assert len(set(ids.tolist())) == len(
                ids.tolist()), (f"Received {len(ids.tolist())} ids but only"
                                f" {len(set(ids.tolist()))} are unique: {ids}")
            # assert all elements received
            assert all(received), (
                f"({self.trainer.max_steps}) Received not all {len(received)} ids: {received}"
            )
Code example #6: test that `IndexBatchSamplerWrapper` is iterable and satisfies `has_len`.
def test_index_batch_sampler_methods():
    dataset = range(15)
    sampler = SequentialSampler(dataset)
    batch_sampler = BatchSampler(sampler, 3, False)
    index_batch_sampler = IndexBatchSamplerWrapper(batch_sampler)

    assert isinstance(index_batch_sampler, Iterable)
    assert has_len(index_batch_sampler)
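
Both assertions in example #6 hold only if the wrapper forwards iteration and `__len__` to the wrapped `BatchSampler`. A minimal sketch of such a wrapper that would satisfy the test; remembering the last batch of indices mirrors what the name suggests, but treat the details as illustrative:

from torch.utils.data import BatchSampler


class IndexBatchSamplerWrapper:
    """Sketch: iterate the wrapped BatchSampler while remembering the last batch of indices."""

    def __init__(self, sampler: BatchSampler) -> None:
        self._sampler = sampler
        self.batch_indices = None

    def __iter__(self):
        for batch in self._sampler:
            self.batch_indices = batch
            yield batch

    def __len__(self) -> int:
        return len(self._sampler)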
Code example #7: a `setup` override that caches `has_len(dataloader)` and an optional `batch_to_device` callable.
    def setup(  # type: ignore[override]
            self,
            dataloader: Iterable,
            batch_to_device: Optional[Callable[[Any], Any]] = None) -> None:
        super().setup(dataloader)
        self._has_len = has_len(dataloader)
        if batch_to_device is not None:
            self.batch_to_device = batch_to_device
Code example #8: `_validate_dataloader`, which rejects any dataloader for which `has_len` is false because TPUs do not support `IterableDataset` objects.
    def _validate_dataloader(dataloaders: Union[List['DataLoader'], 'DataLoader']):
        if not isinstance(dataloaders, list):
            dataloaders = [dataloaders]

        for dataloader in dataloaders:
            if not has_len(dataloader):
                raise MisconfigurationException(
                    "TPUs do not currently support IterableDataset objects, the dataset must implement `__len__`."
                    " HINT: You can mock the length on your dataset to bypass this MisconfigurationException."
                )
Code example #9: a user-side `num_training_steps` helper that infers the total number of training steps from the dataset length, devices, batch limits and gradient accumulation; `has_len` guards against `IterableDataset`s.
    def num_training_steps(self) -> Optional[int]:
        r""" Total training steps inferred from the dataset length, number of nodes and devices. """
        if self.trainer.max_steps is not None and self.trainer.max_steps >= 0:
            return self.trainer.max_steps

        if not has_len(self.trainer.datamodule.train_dataset):
            rank_zero_warn("Using IterableDataset, cannot compute max_steps, returning None")
            return None

        # train samples
        train_samples = len(self.trainer.datamodule.train_dataset)

        # number of training devices
        if self.trainer._accelerator_connector.use_dp:
            total_devices = 1    # with dp, a single batch is divided across many gpus
        elif self.trainer._accelerator_connector.use_ddp2:
            total_devices = self.trainer.num_nodes
        else:
            total_devices = self.trainer.num_processes * self.trainer.num_nodes

        # the number of training samples may be modified in distributed training
        # to be divisible by the number of GPUs...
        train_samples_per_device = math.ceil(train_samples / total_devices)

        # train batches from the dataloader
        train_batches_per_device = math.ceil(train_samples_per_device / self.hyperparameters.batch_size)

        # eventually limit train batches
        limit_batches = self.trainer.limit_train_batches
        train_batches_per_device = (
            min(train_batches_per_device, limit_batches)
            if isinstance(limit_batches, int) else int(limit_batches * train_batches_per_device)
        )

        # train steps for each device
        train_steps_per_device = math.ceil(train_batches_per_device / self.trainer.accumulate_grad_batches)

        # total train steps across all epochs
        total_train_steps = train_steps_per_device * self.trainer.max_epochs
        rank_zero_warn(f"Automatically computed total steps equal to {total_train_steps}")

        return total_train_steps
Code example #10: the `_is_valid_batch_size` helper; any candidate size is valid when the dataloader has no length.
def _is_valid_batch_size(current_size, dataloader):
    return not has_len(dataloader) or current_size <= len(dataloader)
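
`_is_valid_batch_size` accepts any candidate size when the dataloader has no length, and otherwise caps it at the loader's length. A hedged usage sketch of the kind of batch-size-scaling driver that would call it; `_grow_batch_size_sketch` and the doubling factor are illustrative, only `_is_valid_batch_size` and `has_len` come from the examples above:

def _grow_batch_size_sketch(dataloader, current_size: int, factor: float = 2.0) -> int:
    # illustrative driver: keep the enlarged size only while it is still valid for this dataloader
    new_size = int(current_size * factor)
    if _is_valid_batch_size(new_size, dataloader):
        return new_size
    return current_size  # stop growing once the finite dataloader can no longer support it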
Code example #11: `_reset_eval_dataloader` taking the model and a `'val'`/`'test'` mode string, computing per-loader batch counts with `has_len`.
    def _reset_eval_dataloader(
            self, model: LightningModule,
            mode: str) -> Tuple[List[Union[int, float]], List[DataLoader]]:
        """Generic method to reset a dataloader for evaluation.

        Args:
            model: The current `LightningModule`
            mode: Either `'val'` or `'test'`

        Returns:
            Tuple (num_batches, dataloaders)
        """
        # always get the loaders first so we can count how many there are
        loader_name = f'{mode}_dataloader'
        dataloaders = self.request_dataloader(getattr(model, loader_name))

        if not isinstance(dataloaders, list):
            dataloaders = [dataloaders]

        # when overfitting use the training loader as val and test
        # duplicate it the numb of times needed to match the train loaders
        if self.overfit_batches > 0:
            num_loaders = len(dataloaders)
            train_dataloader = self.request_dataloader(
                getattr(model, 'train_dataloader'))
            dataloaders = [
                deepcopy(train_dataloader) for _ in range(num_loaders)
            ]

        self.dev_debugger.track_load_dataloader_call(loader_name,
                                                     dataloaders=dataloaders)

        for loader_i in range(len(dataloaders)):
            loader = dataloaders[loader_i]

            # shuffling in val and test set is bad practice
            if mode in ('val', 'test') and hasattr(
                    loader, 'sampler') and isinstance(loader.sampler,
                                                      RandomSampler):

                # when overfitting, the dataloader should not have sampler
                if self.overfit_batches > 0:
                    rank_zero_warn(
                        'You requested to overfit but enabled test/val dataloader shuffling.'
                        ' We are turning it off for you.')
                    dataloaders[loader_i] = self.replace_sampler(
                        loader, SequentialSampler(loader.dataset))

                else:
                    rank_zero_warn(
                        f'Your {mode}_dataloader has `shuffle=True`, it is best practice to turn'
                        ' this off for validation and test dataloaders.')

        if any([dl is None for dl in dataloaders]):
            rank_zero_warn(
                "One of given dataloaders is None and it will be skipped.")

        # add samplers
        dataloaders = [
            self.auto_add_sampler(dl, shuffle=False) for dl in dataloaders
            if dl is not None
        ]

        loader_num_batches = []

        # determine number of batches
        # datasets could be none, 1 or 2+
        if len(dataloaders) != 0:
            for i, dataloader in enumerate(dataloaders):
                num_batches = len(dataloader) if has_len(
                    dataloader) else float('inf')
                self._worker_check(dataloader, f'{mode} dataloader {i}')

                # percent or num_steps
                limit_eval_batches = getattr(self, f'limit_{mode}_batches')

                # limit num batches either as a percent or num steps
                if isinstance(limit_eval_batches,
                              int) or limit_eval_batches == 0.0:
                    num_batches = min(num_batches, int(limit_eval_batches))
                elif num_batches != float('inf'):
                    num_batches = int(num_batches * limit_eval_batches)
                elif limit_eval_batches != 1.0:
                    raise MisconfigurationException(
                        f'When using an IterableDataset for `limit_{mode}_batches`,'
                        f' `Trainer(limit_{mode}_batches)` must be `0.0`, `1.0` or an int. An int k specifies'
                        f' `num_{mode}_batches` to use.')

                if num_batches == 0 and limit_eval_batches > 0.0 and isinstance(
                        limit_eval_batches, float):
                    min_pct = 1.0 / len(dataloader)
                    raise MisconfigurationException(
                        f'you requested to check {limit_eval_batches} of the {mode} dataloader but'
                        f' {limit_eval_batches}*{num_batches} < 1. Please increase the limit_{mode}_batches.'
                        f' Try at least limit_{mode}_batches={min_pct}')

                loader_num_batches.append(num_batches)

        return loader_num_batches, dataloaders
Code example #12: a revised `reset_train_dataloader` that wraps the train loaders in a `CombinedLoader` before applying `has_len`.
    def reset_train_dataloader(self, model: LightningModule) -> None:
        """Resets the train dataloader and initialises required variables
        (number of batches, when to validate, etc.).

        Args:
            model: The current `LightningModule`
        """
        self.train_dataloader = self.request_dataloader(model, "train")

        if self.overfit_batches > 0:
            if hasattr(self.train_dataloader, 'sampler') and isinstance(self.train_dataloader.sampler, RandomSampler):
                rank_zero_warn(
                    'You requested to overfit but enabled training dataloader shuffling.'
                    ' We are turning it off for you.'
                )
                self.train_dataloader = self.replace_sampler(
                    self.train_dataloader, SequentialSampler(self.train_dataloader.dataset)
                )

        # debugging
        self.dev_debugger.track_load_dataloader_call('train_dataloader', dataloaders=[self.train_dataloader])

        # automatically add samplers
        self.train_dataloader = apply_to_collection(
            self.train_dataloader, DataLoader, self.auto_add_sampler, shuffle=True
        )

        # check the workers recursively
        apply_to_collection(self.train_dataloader, DataLoader, self._worker_check, 'train dataloader')

        # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
        self.train_dataloader = CombinedLoader(self.train_dataloader, self._multiple_trainloader_mode)

        self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float('inf')

        if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
            self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches))
        elif self.num_training_batches != float('inf'):
            self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
        elif self.limit_train_batches != 1.0:
            raise MisconfigurationException(
                'When using an IterableDataset for `limit_train_batches`,'
                ' `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies'
                ' `num_training_batches` to use.'
            )

        # determine when to check validation
        # if int passed in, val checks that often
        # otherwise, it checks in [0, 1.0] % range of a training epoch
        if isinstance(self.val_check_interval, int):
            self.val_check_batch = self.val_check_interval
            if self.val_check_batch > self.num_training_batches:
                raise ValueError(
                    f'`val_check_interval` ({self.val_check_interval}) must be less than or equal '
                    f'to the number of the training batches ({self.num_training_batches}). '
                    'If you want to disable validation set `limit_val_batches` to 0.0 instead.'
                )
        else:
            if not has_len(self.train_dataloader):
                if self.val_check_interval == 1.0:
                    self.val_check_batch = float('inf')
                else:
                    raise MisconfigurationException(
                        'When using an IterableDataset for `train_dataloader`,'
                        ' `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies'
                        ' checking validation every k training batches.'
                    )
            else:
                self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
                self.val_check_batch = max(1, self.val_check_batch)
Code example #13: a `_reset_eval_dataloader` variant keyed on a `RunningStage`.
    def _reset_eval_dataloader(
        self,
        mode: RunningStage,
        model: Optional["pl.LightningModule"] = None
    ) -> Tuple[List[Union[int, float]], List[DataLoader]]:
        """Generic method to reset a dataloader for evaluation.

        Args:
            mode: The running stage of the ``Trainer``
            model: The ``LightningModule`` if calling this outside of the trainer scope.

        Returns:
            Tuple (num_batches, dataloaders)
        """
        assert mode.evaluating or mode == RunningStage.PREDICTING

        # always get the loaders first so we can count how many there are
        loader_name = f"{mode.dataloader_prefix}_dataloader"
        dataloaders = self.request_dataloader(mode, model=model)

        if not isinstance(dataloaders, list):
            dataloaders = [dataloaders]

        # when overfitting, use the training loader as val and test
        # duplicate it the numb of times needed to match the train loaders
        if self.overfit_batches > 0:
            train_dataloader = self.request_dataloader(RunningStage.TRAINING,
                                                       model=model)
            dataloaders = [
                deepcopy(train_dataloader) for _ in range(len(dataloaders))
            ]

        self.dev_debugger.track_load_dataloader_call(loader_name,
                                                     dataloaders=dataloaders)

        for loader_i in range(len(dataloaders)):
            loader = dataloaders[loader_i]

            if hasattr(loader, "sampler") and isinstance(
                    loader.sampler, RandomSampler):

                # when overfitting, the dataloader should not have sampler
                if self.overfit_batches > 0 and mode.evaluating:
                    rank_zero_warn(
                        "You requested to overfit but enabled val/test dataloader shuffling."
                        " We are turning it off for you.")
                    dataloaders[loader_i] = self.replace_sampler(
                        loader, SequentialSampler(loader.dataset), mode=mode)
                else:
                    rank_zero_warn(
                        f"Your `{mode.dataloader_prefix}_dataloader` has `shuffle=True`,"
                        "it is strongly recommended that you turn this off for val/test/predict dataloaders."
                    )

        if any(dl is None for dl in dataloaders):
            rank_zero_warn(
                "One of given dataloaders is None and it will be skipped.")

        # add samplers
        dataloaders = [
            self.auto_add_sampler(dl, False, mode=mode) for dl in dataloaders
            if dl is not None
        ]

        # add worker_init_fn for correct seeding in worker processes
        apply_to_collection(dataloaders,
                            dtype=DataLoader,
                            function=self.auto_add_worker_init_fn)

        loader_num_batches = []

        # determine number of batches
        # datasets could be none, 1 or 2+
        if len(dataloaders) != 0:
            for i, dataloader in enumerate(dataloaders):
                num_batches = len(dataloader) if has_len(
                    dataloader) else float("inf")
                self._worker_check(dataloader,
                                   f"{mode.dataloader_prefix}_dataloader {i}")

                # percent or num_steps
                limit_eval_batches = getattr(
                    self, f"limit_{mode.dataloader_prefix}_batches")

                # limit num batches either as a percent or num steps
                if isinstance(limit_eval_batches,
                              int) or limit_eval_batches == 0.0:
                    num_batches = min(num_batches, int(limit_eval_batches))
                elif num_batches != float("inf"):
                    num_batches = int(num_batches * limit_eval_batches)
                elif limit_eval_batches != 1.0:
                    raise MisconfigurationException(
                        f"When using an IterableDataset for `limit_{mode}_batches`,"
                        f" `Trainer(limit_{mode.dataloader_prefix}_batches)` must be `0.0`, `1.0` or an int. An int k"
                        f" specifies `num_{mode.dataloader_prefix}_batches` to use."
                    )

                if num_batches == 0 and limit_eval_batches > 0.0 and isinstance(
                        limit_eval_batches, float):
                    min_pct = 1.0 / len(dataloader)
                    raise MisconfigurationException(
                        f"you requested to check {limit_eval_batches} of the `{mode.dataloader_prefix}_dataloader` but"
                        f" {limit_eval_batches}*{num_batches} < 1. Please increase the"
                        f" `limit_{mode.dataloader_prefix}_batches` flag. Try at least"
                        f" `limit_{mode.dataloader_prefix}_batches={min_pct}`")

                loader_num_batches.append(num_batches)

        return loader_num_batches, dataloaders
Code example #14: a further `reset_train_dataloader` revision with worker seeding and fault-tolerant-training hooks; the `has_len` logic is unchanged.
    def reset_train_dataloader(self,
                               model: Optional["pl.LightningModule"] = None
                               ) -> None:
        """Resets the train dataloader and initialises required variables
        (number of batches, when to validate, etc.).

        Args:
            model: The `LightningModule` if calling this outside of the trainer scope.
        """
        self.train_dataloader = self.request_dataloader(RunningStage.TRAINING,
                                                        model=model)

        if self.overfit_batches > 0:
            if hasattr(self.train_dataloader, "sampler") and isinstance(
                    self.train_dataloader.sampler, RandomSampler):
                rank_zero_warn(
                    "You requested to overfit but enabled training dataloader shuffling."
                    " We are turning off the training dataloader shuffling for you."
                )
                self.train_dataloader = self.replace_sampler(
                    self.train_dataloader,
                    SequentialSampler(self.train_dataloader.dataset),
                    mode=RunningStage.TRAINING)

        # debugging
        self.dev_debugger.track_load_dataloader_call(
            "train_dataloader", dataloaders=[self.train_dataloader])

        # automatically add samplers
        self.train_dataloader = apply_to_collection(self.train_dataloader,
                                                    DataLoader,
                                                    self.auto_add_sampler,
                                                    shuffle=True,
                                                    mode=RunningStage.TRAINING)

        # check the workers recursively
        apply_to_collection(self.train_dataloader, DataLoader,
                            self._worker_check, "train_dataloader")

        # add worker_init_fn for correct seeding in worker processes
        apply_to_collection(self.train_dataloader, DataLoader,
                            self.auto_add_worker_init_fn)

        # add collate_fn to collect metadata for fault tolerant training
        if _fault_tolerant_training():
            apply_to_collection(self.train_dataloader, DataLoader,
                                self._add_sampler_metadata_collate)

        # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
        self.train_dataloader = CombinedLoader(
            self.train_dataloader,
            self.data_connector.multiple_trainloader_mode)

        self.num_training_batches = len(self.train_dataloader) if has_len(
            self.train_dataloader) else float("inf")

        if isinstance(self.limit_train_batches,
                      int) or self.limit_train_batches == 0.0:
            self.num_training_batches = min(self.num_training_batches,
                                            int(self.limit_train_batches))
        elif self.num_training_batches != float("inf"):
            self.num_training_batches = int(self.num_training_batches *
                                            self.limit_train_batches)
        elif self.limit_train_batches != 1.0:
            raise MisconfigurationException(
                "When using an IterableDataset for `limit_train_batches`,"
                " `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies"
                " `num_training_batches` to use.")

        # determine when to check validation
        # if int passed in, val checks that often
        # otherwise, it checks in [0, 1.0] % range of a training epoch
        if isinstance(self.val_check_interval, int):
            self.val_check_batch = self.val_check_interval
            if self.val_check_batch > self.num_training_batches:
                raise ValueError(
                    f"`val_check_interval` ({self.val_check_interval}) must be less than or equal "
                    f"to the number of the training batches ({self.num_training_batches}). "
                    "If you want to disable validation set `limit_val_batches` to 0.0 instead."
                )
        else:
            if not has_len(self.train_dataloader):
                if self.val_check_interval == 1.0:
                    self.val_check_batch = float("inf")
                else:
                    raise MisconfigurationException(
                        "When using an IterableDataset for `train_dataloader`,"
                        " `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies"
                        " checking validation every k training batches.")
            else:
                self.val_check_batch = int(self.num_training_batches *
                                           self.val_check_interval)
                self.val_check_batch = max(1, self.val_check_batch)

        if self.logger and self.num_training_batches < self.log_every_n_steps:
            rank_zero_warn(
                f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"
                f" Trainer(log_every_n_steps={self.log_every_n_steps}). Set a lower value for log_every_n_steps if"
                " you want to see logs for the training epoch.")
Code example #15: another `_reset_eval_dataloader` variant keyed on a mode string, with an accelerator hook for resetting the dataloaders.
    def _reset_eval_dataloader(
        self, mode: str, model: Optional["pl.LightningModule"] = None
    ) -> Tuple[List[Union[int, float]], List[DataLoader]]:
        """Generic method to reset a dataloader for evaluation.

        Args:
            mode: Either `'val'`, `'test'` or `'predict'`
            model: The `LightningModule` if calling this outside of the trainer scope.

        Returns:
            Tuple (num_batches, dataloaders)
        """
        # always get the loaders first so we can count how many there are
        loader_name = f"{mode}_dataloader"
        dataloaders = self.request_dataloader(mode, model=model)

        if not isinstance(dataloaders, list):
            dataloaders = [dataloaders]

        # when overfitting use the training loader as val and test
        # duplicate it the numb of times needed to match the train loaders
        if self.overfit_batches > 0:
            num_loaders = len(dataloaders)
            train_dataloader = self.request_dataloader("train", model=model)
            dataloaders = [deepcopy(train_dataloader) for _ in range(num_loaders)]

        self.dev_debugger.track_load_dataloader_call(loader_name, dataloaders=dataloaders)

        for loader_i in range(len(dataloaders)):
            loader = dataloaders[loader_i]

            # shuffling in val and test set is bad practice
            modes = ("val", "test", "predict")
            if mode in modes and hasattr(loader, "sampler") and isinstance(loader.sampler, RandomSampler):

                # when overfitting, the dataloader should not have sampler
                if self.overfit_batches > 0 and mode != "predict":
                    rank_zero_warn(
                        "You requested to overfit but enabled val/test dataloader shuffling."
                        " We are turning it off for you."
                    )
                    dataloaders[loader_i] = self.replace_sampler(loader, SequentialSampler(loader.dataset))

                else:
                    rank_zero_warn(
                        f"Your {mode}_dataloader has `shuffle=True`, it is best practice to turn"
                        " this off for val/test/predict dataloaders."
                    )

        if any(dl is None for dl in dataloaders):
            rank_zero_warn("One of given dataloaders is None and it will be skipped.")

        # add samplers
        dataloaders = [
            self.auto_add_sampler(dl, shuffle=False, mode=self.state.stage) for dl in dataloaders if dl is not None
        ]

        # add worker_init_fn for correct seeding in worker processes
        apply_to_collection(dataloaders, dtype=DataLoader, function=self.auto_add_worker_init_fn)

        # allow accelerator to modify dataloader
        hook_name = f"on_reset_{mode}_dataloader"
        dataloaders = getattr(self.accelerator, hook_name)(dataloaders)

        loader_num_batches = []

        # determine number of batches
        # datasets could be none, 1 or 2+
        if len(dataloaders) != 0:
            for i, dataloader in enumerate(dataloaders):
                num_batches = len(dataloader) if has_len(dataloader) else float("inf")
                self._worker_check(dataloader, f"{mode} dataloader {i}")

                # percent or num_steps
                limit_eval_batches = getattr(self, f"limit_{mode}_batches")

                # limit num batches either as a percent or num steps
                if isinstance(limit_eval_batches, int) or limit_eval_batches == 0.0:
                    num_batches = min(num_batches, int(limit_eval_batches))
                elif num_batches != float("inf"):
                    num_batches = int(num_batches * limit_eval_batches)
                elif limit_eval_batches != 1.0:
                    raise MisconfigurationException(
                        "When using an IterableDataset for `limit_{mode}_batches`,"
                        f" `Trainer(limit_{mode}_batches)` must be `0.0`, `1.0` or an int. An int k specifies"
                        f" `num_{mode}_batches` to use."
                    )

                if num_batches == 0 and limit_eval_batches > 0.0 and isinstance(limit_eval_batches, float):
                    min_pct = 1.0 / len(dataloader)
                    raise MisconfigurationException(
                        f"you requested to check {limit_eval_batches} of the {mode} dataloader but"
                        f" {limit_eval_batches}*{num_batches} < 1. Please increase the limit_{mode}_batches."
                        f" Try at least limit_{mode}_batches={min_pct}"
                    )

                loader_num_batches.append(num_batches)

        return loader_num_batches, dataloaders