Example #1
def test_has_iterable_dataset():
    assert has_iterable_dataset(DataLoader(RandomIterableDataset(1, 1)))

    assert not has_iterable_dataset(DataLoader(RandomDataset(1, 1)))

    class MockDatasetWithoutIterableDataset(RandomDataset):
        def __iter__(self):
            yield 1
            return self

    assert not has_iterable_dataset(
        DataLoader(MockDatasetWithoutIterableDataset(1, 1)))
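
For reference, a hedged sketch of a `has_iterable_dataset` helper that is consistent with this test (an illustration, not necessarily the library's exact implementation):

from torch.utils.data import DataLoader, IterableDataset


def has_iterable_dataset(dataloader: DataLoader) -> bool:
    # True only when the loader wraps a real `IterableDataset` subclass; a map-style
    # dataset that merely defines `__iter__` (as in the mock above) does not count.
    return hasattr(dataloader, "dataset") and isinstance(dataloader.dataset, IterableDataset)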
Example #2
    def auto_add_sampler(self, dataloader: DataLoader,
                         shuffle: bool) -> DataLoader:

        # don't do anything if it's not a dataloader
        is_dataloader = isinstance(dataloader, DataLoader)
        # don't manipulate iterable datasets
        is_iterable_ds = has_iterable_dataset(dataloader)

        if not is_dataloader or is_iterable_ds:
            return dataloader

        is_in_dist = self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu
        need_dist_sampler = is_in_dist and not isinstance(
            dataloader.sampler, DistributedSampler)
        if self.replace_sampler_ddp and need_dist_sampler:
            if not isinstance(dataloader.sampler,
                              (SequentialSampler, RandomSampler)):
                raise MisconfigurationException(
                    'You seem to have configured a sampler in your DataLoader. This will be replaced'
                    ' by `DistributedSampler` since `replace_sampler_ddp` is True and you are using'
                    ' distributed training. Either remove the sampler from your DataLoader or set'
                    ' `replace_sampler_ddp`=False if you want to use your custom sampler.'
                )

            # replace with distributed sampler
            sampler = self._get_distributed_sampler(dataloader, shuffle)
            dataloader = self.replace_sampler(dataloader, sampler)

        return dataloader
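
`_get_distributed_sampler` is not shown in this example; a hedged sketch of the shape such a method could take, assuming the object exposes `world_size` and `global_rank` attributes (assumed names) and using torch's `DistributedSampler`:

    def _get_distributed_sampler(self, dataloader: DataLoader, shuffle: bool) -> DistributedSampler:
        # Sketch only: `self.world_size` / `self.global_rank` stand in for the
        # process-group size and the index of the current process.
        return DistributedSampler(
            dataloader.dataset,
            num_replicas=self.world_size,
            rank=self.global_rank,
            shuffle=shuffle,
        )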
Example #3
    def _requires_distributed_sampler(self, dataloader) -> bool:
        return (
            self.trainer._accelerator_connector.replace_sampler_ddp
            and self.trainer._accelerator_connector.is_distributed
            and not isinstance(dataloader.sampler, DistributedSampler)
            and not has_iterable_dataset(dataloader)
            # `DistributedSampler` is never used with `poptorch.DataLoader`
            and not isinstance(self.trainer.accelerator, IPUAccelerator)
        )
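
As a standalone illustration (plain torch, no trainer involved) of the `DistributedSampler` that this predicate tests for, each rank receives a disjoint shard of the dataset indices:

import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

ds = TensorDataset(torch.arange(6).float())
shard_0 = list(DistributedSampler(ds, num_replicas=2, rank=0, shuffle=False))
shard_1 = list(DistributedSampler(ds, num_replicas=2, rank=1, shuffle=False))
assert sorted(shard_0 + shard_1) == list(range(6))  # every index appears exactly once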
Example #4
    def auto_add_sampler(
        self, dataloader: DataLoader, shuffle: bool, mode: Optional[RunningStage] = None
    ) -> DataLoader:
        # don't do anything if it's not a dataloader
        is_dataloader = isinstance(dataloader, DataLoader)
        # don't manipulate iterable datasets
        is_iterable_ds = has_iterable_dataset(dataloader)

        if isinstance(dataloader, CombinedLoader):
            dataloader.loaders = apply_to_collection(dataloader.loaders, DataLoader, self.auto_add_sampler, shuffle)
            return dataloader

        if not is_dataloader or is_iterable_ds:
            return dataloader

        need_dist_sampler = self.accelerator_connector.is_distributed and not isinstance(
            dataloader.sampler, DistributedSampler
        )
        if self.accelerator_connector.replace_sampler_ddp and need_dist_sampler:
            if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)):
                raise MisconfigurationException(
                    'You seem to have configured a sampler in your DataLoader. This will be replaced'
                    ' by `DistributedSampler` since `replace_sampler_ddp` is True and you are using'
                    ' distributed training. Either remove the sampler from your DataLoader or set'
                    ' `replace_sampler_ddp`=False if you want to use your custom sampler.'
                )

            # replace with distributed sampler
            sampler = self._get_distributed_sampler(dataloader, shuffle, mode=mode)
            dataloader = self.replace_sampler(dataloader, sampler, mode=mode)

        return dataloader
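
`apply_to_collection` maps `auto_add_sampler` over every `DataLoader` nested inside a `CombinedLoader`. A minimal stand-in restricted to dicts, lists and tuples (the real utility handles more container types) could look roughly like this:

from typing import Any, Callable


def apply_to_collection(data: Any, dtype: type, function: Callable, *args: Any, **kwargs: Any) -> Any:
    # Apply `function` to every element matching `dtype`, recursing through common
    # containers and returning anything else unchanged.
    if isinstance(data, dtype):
        return function(data, *args, **kwargs)
    if isinstance(data, dict):
        return {k: apply_to_collection(v, dtype, function, *args, **kwargs) for k, v in data.items()}
    if isinstance(data, (list, tuple)):
        return type(data)(apply_to_collection(v, dtype, function, *args, **kwargs) for v in data)
    return data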
Example #5
def test_warning_with_iterable_dataset_and_len(tmpdir):
    """ Tests that a warning message is shown when an IterableDataset defines `__len__`. """
    model = EvalModelTemplate()
    original_dataset = model.train_dataloader().dataset

    class IterableWithLen(IterableDataset):
        def __iter__(self):
            return iter(original_dataset)

        def __len__(self):
            return len(original_dataset)

    dataloader = DataLoader(IterableWithLen(), batch_size=16)
    assert has_len(dataloader)
    assert has_iterable_dataset(dataloader)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_steps=3,
    )
    with pytest.warns(UserWarning,
                      match='Your `IterableDataset` has `__len__` defined.'):
        trainer.fit(model,
                    train_dataloader=dataloader,
                    val_dataloaders=[dataloader])
    with pytest.warns(UserWarning,
                      match='Your `IterableDataset` has `__len__` defined.'):
        trainer.test(model, test_dataloaders=[dataloader])
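
`has_len` is asserted above but not shown; one hedged way to express it is to treat a loader as sized whenever `len()` succeeds (the real helper additionally warns in the `IterableDataset`-with-`__len__` case exercised by this test):

from torch.utils.data import DataLoader


def has_len(dataloader: DataLoader) -> bool:
    # `len(DataLoader)` raises TypeError when the underlying dataset defines no `__len__`,
    # which is the usual situation for an `IterableDataset`.
    try:
        len(dataloader)
        return True
    except TypeError:
        return False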
Example #6
    def _requires_distributed_sampler(self, dataloader) -> bool:
        return (
            self._accelerator_connector.replace_sampler_ddp
            and self._accelerator_connector.is_distributed
            and not isinstance(dataloader.sampler, DistributedSampler)
            and not has_iterable_dataset(dataloader)
        )
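
A hypothetical sketch of how a predicate shaped like `_requires_distributed_sampler` is typically consumed (the `_resolve_sampler` name and the `mode` argument are assumptions, not taken from this example):

    def _resolve_sampler(self, dataloader: DataLoader, shuffle: bool, mode=None):
        # Use the distributed sampler only when the predicate above says so;
        # otherwise keep whatever sampler the DataLoader already carries.
        if self._requires_distributed_sampler(dataloader):
            return self._get_distributed_sampler(dataloader, shuffle, mode=mode)
        return dataloader.sampler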
Example #7
    def auto_add_sampler(self, dataloader: Any, shuffle: bool, mode: Optional[RunningStage] = None) -> Any:
        if isinstance(dataloader, CombinedLoader):
            # apply `auto_add_sampler` on all the collection of loaders
            dataloader.loaders = apply_to_collection(dataloader.loaders, DataLoader, self.auto_add_sampler, shuffle)
            return dataloader

        # don't do anything if it's not a dataloader
        if not isinstance(dataloader, DataLoader):
            return dataloader

        if (
            self.accelerator_connector.replace_sampler_ddp
            and self.accelerator_connector.is_distributed
            and not isinstance(dataloader.sampler, DistributedSampler)
            and not has_iterable_dataset(dataloader)
        ):
            if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)):
                raise MisconfigurationException(
                    "You seem to have configured a sampler in your DataLoader. This will be replaced "
                    " by `DistributedSampler` since `replace_sampler_ddp` is True and you are using"
                    " distributed training. Either remove the sampler from your DataLoader or set"
                    " `replace_sampler_ddp`=False if you want to use your custom sampler."
                )
            sampler = self._get_distributed_sampler(dataloader, shuffle, mode=mode)
        else:
            # use current sampler
            sampler = dataloader.sampler

        dataloader = self.replace_sampler(dataloader, sampler, mode=mode)

        return dataloader
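
`replace_sampler` rebuilds the loader around the new sampler. A simplified sketch that copies only a handful of common `DataLoader` settings (the real implementation re-inspects the `DataLoader` constructor arguments, so treat this as illustrative):

    def replace_sampler(self, dataloader: DataLoader, sampler, mode=None) -> DataLoader:
        # Re-create the DataLoader with the new sampler; `shuffle` must not be passed
        # alongside an explicit sampler, so it is intentionally omitted here.
        return DataLoader(
            dataloader.dataset,
            batch_size=dataloader.batch_size,
            sampler=sampler,
            num_workers=dataloader.num_workers,
            collate_fn=dataloader.collate_fn,
            pin_memory=dataloader.pin_memory,
            drop_last=dataloader.drop_last,
        )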
Example #8
    def auto_add_sampler(self, dataloader: DataLoader, shuffle: bool) -> DataLoader:

        # don't do anything if it's not a dataloader
        is_dataloader = isinstance(dataloader, DataLoader)
        # don't manipulate iterable datasets
        is_iterable_ds = has_iterable_dataset(dataloader)

        if not is_dataloader or is_iterable_ds:
            return dataloader

        need_dist_sampler = self.require_distributed_sampler and not isinstance(dataloader.sampler, DistributedSampler)
        if self.replace_sampler_ddp and need_dist_sampler:
            wrap_sampler = False
            if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)):
                rank_zero_warn(f'Wrapping current custom sampler {dataloader.sampler} with DistributedSampler ...'
                               ' If this is not the desired behavior, please specify `replace_sampler_ddp=False`')
                wrap_sampler = True

            # replace with distributed sampler
            sampler = self._get_distributed_sampler(dataloader, shuffle, wrap=wrap_sampler)
            dataloader = self.replace_sampler(dataloader, sampler)

        return dataloader
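
The `wrap=wrap_sampler` path implies the custom sampler is wrapped rather than discarded. A purely illustrative wrapper that shards the indices produced by the inner sampler across ranks (not the library's actual wrapper):

from torch.utils.data import Sampler


class NaiveDistributedSamplerWrapper(Sampler):
    # Illustration only: materialize the inner sampler's indices and let each rank keep
    # every `num_replicas`-th one. No epoch-wise reshuffling, no padding of uneven shards.
    def __init__(self, sampler: Sampler, num_replicas: int, rank: int):
        self.sampler = sampler
        self.num_replicas = num_replicas
        self.rank = rank

    def __iter__(self):
        indices = list(self.sampler)
        return iter(indices[self.rank::self.num_replicas])

    def __len__(self):
        # Assumes the wrapped sampler is sized.
        return len(range(self.rank, len(self.sampler), self.num_replicas))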