Example #1
def test_update_dataloader_typerror_custom_exception():
    class BadImpl(DataLoader):
        def __init__(self, foo, *args, **kwargs):
            self.foo = foo
            # positional conflict with `dataset`
            super().__init__(foo, *args, **kwargs)

    dataloader = BadImpl([1, 2, 3])
    with pytest.raises(
            MisconfigurationException,
            match="`DataLoader` implementation has an error.*`dataset`"):
        _update_dataloader(dataloader, dataloader.sampler)

    class BadImpl2(DataLoader):
        def __init__(self, randomize, *args, **kwargs):
            self.randomize = randomize
            # keyword conflict with `shuffle`
            super().__init__(*args, shuffle=randomize, **kwargs)

    dataloader = BadImpl2(False, [])
    with pytest.raises(
            MisconfigurationException,
            match="`DataLoader` implementation has an error.*`shuffle`"):
        _update_dataloader(dataloader, dataloader.sampler)

    class GoodImpl(DataLoader):
        def __init__(self, randomize, *args, **kwargs):
            # fixed implementation, kwargs are filtered
            self.randomize = randomize or kwargs.pop("shuffle", False)
            super().__init__(*args, shuffle=randomize, **kwargs)

    dataloader = GoodImpl(False, [])
    new_dataloader = _update_dataloader(dataloader, dataloader.sampler)
    assert isinstance(new_dataloader, GoodImpl)
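
    # Hedged addition (not part of the original test): a further sketch of a
    # subclass that `_update_dataloader` can rebuild without raising. The class
    # name and the `scale` argument are made up; the pattern mirrors `GoodImpl`
    # above: every extra constructor argument is stored under its own parameter
    # name and the remaining arguments are forwarded untouched.
    class FriendlyImpl(DataLoader):
        def __init__(self, scale, *args, **kwargs):
            self.scale = scale  # attribute name matches the parameter name
            super().__init__(*args, **kwargs)  # no clash with `dataset` or `shuffle`

    dataloader = FriendlyImpl(2, [1, 2, 3])
    new_dataloader = _update_dataloader(dataloader, dataloader.sampler)
    assert isinstance(new_dataloader, FriendlyImpl)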
Example #2
def test_replace_dataloader_init_method(cls, args, kwargs, arg_names, dataset,
                                        checked_values):
    with _replace_dataloader_init_method():
        dataloader = cls(*args, **kwargs)

    assert dataloader.__pl_dl_args == args
    assert dataloader.__pl_dl_kwargs == kwargs
    assert dataloader.__pl_dl_arg_names == arg_names
    assert dataloader.__dataset == dataset

    assert dataloader.dataset == dataset

    for key, value in checked_values.items():
        dataloader_value = getattr(dataloader, key)
        if isinstance(dataloader_value, torch.Tensor):
            assert dataloader_value is value
        else:
            assert getattr(dataloader, key) == value

    dataloader = _update_dataloader(dataloader, dataloader.sampler)

    assert isinstance(dataloader, cls)
    assert not hasattr(dataloader, "__pl_dl_kwargs")
    assert not hasattr(dataloader, "__pl_dl_arg_names")
    assert not hasattr(dataloader, "__pl_dl_args")
    assert not hasattr(dataloader, "__dataset")

    assert dataloader.dataset == dataset

    for key, value in checked_values.items():
        dataloader_value = getattr(dataloader, key)
        if isinstance(dataloader_value, torch.Tensor):
            assert dataloader_value is value
        else:
            assert getattr(dataloader, key) == value
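
The parametrized test above asserts that `_replace_dataloader_init_method` records the constructor call (`__pl_dl_args`, `__pl_dl_kwargs`, `__pl_dl_arg_names`, `__dataset`) and that `_update_dataloader` can afterwards rebuild the instance from that record. A hedged usage sketch of the same pattern, with a made-up subclass that does not expose its extra argument under a matching attribute name, so the rebuild presumably relies on the captured arguments rather than on attribute inspection:

class LoaderWithExtra(DataLoader):
    def __init__(self, extra, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # deliberately stored under a different name than the parameter
        self._extra = extra

with _replace_dataloader_init_method():
    # the constructor arguments are captured while the patch is active
    loader = LoaderWithExtra("foo", range(10), batch_size=2)

new_loader = _update_dataloader(loader, SequentialSampler(loader.dataset))
assert isinstance(new_loader, LoaderWithExtra)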
Example #3
def test_update_dataloader_with_multiprocessing_context():
    """This test verifies that replace_sampler conserves multiprocessing context."""
    train = RandomDataset(32, 64)
    context = "spawn"
    train = DataLoader(train, batch_size=32, num_workers=2, multiprocessing_context=context, shuffle=True)
    new_data_loader = _update_dataloader(train, SequentialSampler(train.dataset))
    assert new_data_loader.multiprocessing_context == train.multiprocessing_context
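
    # Hedged additions (not in the original test): since `_update_dataloader`
    # rebuilds the loader from its original constructor arguments, the other
    # settings of `train` are expected to survive the swap as well.
    assert new_data_loader.num_workers == train.num_workers
    assert new_data_loader.batch_size == train.batch_size
    assert isinstance(new_data_loader.sampler, SequentialSampler)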
Example #4
    def _setup_dataloader(self,
                          dataloader: DataLoader,
                          replace_sampler: bool = True,
                          move_to_device: bool = True) -> DataLoader:
        """Set up a single dataloader for accelerated training.

        Args:
            dataloader: The dataloader to accelerate.
            replace_sampler: If set ``True`` (default), automatically wraps or replaces the sampler on the dataloader
                for distributed training. If you have a custom sampler defined, set this argument to ``False``.
            move_to_device: If set ``True`` (default), moves the data returned by the dataloader automatically to
                the correct device. Set this to ``False`` and alternatively use :meth:`to_device` manually on the
                returned data.

        Returns:
            The wrapped dataloader.
        """
        sampler = dataloader.sampler
        if replace_sampler and self._requires_distributed_sampler(dataloader):
            sampler = self._get_distributed_sampler(
                dataloader, **self._strategy.distributed_sampler_kwargs)

        # the dataloader needs to be re-instantiated because we want to update the input arguments (e.g., sampler)
        dataloader = _update_dataloader(dataloader, sampler)

        # add worker_init_fn for correct seeding in worker processes
        _auto_add_worker_init_fn(dataloader, self.global_rank)

        dataloader = self._strategy.process_dataloader(dataloader)
        device = self.device if move_to_device and not isinstance(
            self._strategy, TPUSpawnStrategy) else None
        lite_dataloader = _LiteDataLoader(dataloader=dataloader, device=device)
        lite_dataloader = cast(DataLoader, lite_dataloader)
        return lite_dataloader
    def _prepare_dataloader(self,
                            dataloader: Any,
                            shuffle: Optional[bool] = None,
                            mode: Optional[RunningStage] = None) -> Any:
        """This function handles to following functionalities:

        - Injecting a `DistributedDataSampler` into the `DataLoader` if on a distributed environment
        - Wrapping the datasets and samplers into fault-tolerant components
        - Wrapping the dataloader based on strategy-specific logic
        """
        if isinstance(dataloader, CombinedLoader):
            # apply `_prepare_dataloader` on all the collection of loaders
            dataloader.loaders = apply_to_collection(
                dataloader.loaders, (DataLoader, CycleIterator),
                self._prepare_dataloader,
                shuffle,
                mode=mode)
            # the length needs to be recomputed across all dataloaders in case of special behavior.
            dataloader._apply_cycle_iterator_length()
            return dataloader

        # don't do anything if it's not a dataloader
        if not isinstance(dataloader, (DataLoader, CycleIterator)):
            return dataloader

        cycle_iterator: Optional[CycleIterator] = None

        if isinstance(dataloader, CycleIterator):
            cycle_iterator = dataloader
            dataloader = dataloader.loader

        if (
            _fault_tolerant_training()  # injects components to track the state
            or self._requires_distributed_sampler(dataloader)  # sets the distributed sampler
            or mode == RunningStage.PREDICTING  # to track indices for the predictions
            # IPUs use a custom `poptorch.DataLoader` which we might need to convert to
            or isinstance(self.trainer.accelerator, IPUAccelerator)
        ):
            if shuffle is None:
                # for training, set to True always
                # for evaluation, decide based on existing sampler
                shuffle = True if mode == RunningStage.TRAINING else _is_dataloader_shuffled(
                    dataloader)

            sampler = self._resolve_sampler(dataloader,
                                            shuffle=shuffle,
                                            mode=mode)
            dataloader = _update_dataloader(dataloader, sampler, mode=mode)

        dataloader = self.trainer.strategy.process_dataloader(dataloader)

        if cycle_iterator is not None:
            cycle_iterator.loader = dataloader
            return cycle_iterator

        return dataloader
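
Both methods above funnel into the same core step: choose (or build) a sampler and re-instantiate the dataloader around it with `_update_dataloader`. A condensed, illustrative sketch of that step on its own, outside any strategy object; the helper name and the explicit `num_replicas`/`rank` arguments are assumptions for the sketch, not Lightning API:

from torch.utils.data import DataLoader, DistributedSampler

def _swap_in_distributed_sampler(dataloader: DataLoader, num_replicas: int, rank: int) -> DataLoader:
    # build a distributed sampler over the same dataset, then rebuild the
    # dataloader so the new sampler becomes part of its constructor arguments
    sampler = DistributedSampler(dataloader.dataset, num_replicas=num_replicas, rank=rank)
    return _update_dataloader(dataloader, sampler)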
Example #6
    def _setup_dataloader(self,
                          dataloader: DataLoader,
                          replace_sampler: bool = True,
                          move_to_device: bool = True) -> DataLoader:
        """Set up a single dataloader for accelerated training.

        Args:
            dataloader: The dataloader to accelerate.
            replace_sampler: If set ``True`` (default), automatically wraps or replaces the sampler on the dataloader
                for distributed training. If you have a custom sampler defined, set this argument to ``False``.
            move_to_device: If set ``True`` (default), moves the data returned by the dataloader automatically to
                the correct device. Set this to ``False`` and alternatively use :meth:`to_device` manually on the
                returned data.

        Returns:
            The wrapped dataloader.
        """
        sampler = dataloader.sampler
        if replace_sampler and self._requires_distributed_sampler(dataloader):
            if not isinstance(sampler, (SequentialSampler, RandomSampler)):
                raise MisconfigurationException(
                    "You seem to have configured a sampler in your DataLoader. This will be replaced "
                    " by `DistributedSampler` since `replace_sampler_ddp` is True and you are using"
                    " distributed training. Either remove the sampler from your DataLoader or set"
                    " `replace_sampler=False` if you want to use your custom sampler."
                )
            sampler = self._get_distributed_sampler(
                dataloader, **self._strategy.distributed_sampler_kwargs)

        # the dataloader needs to be re-instantiated because we want to update the input arguments (e.g., sampler)
        dataloader = _update_dataloader(dataloader, sampler)

        # add worker_init_fn for correct seeding in worker processes
        _auto_add_worker_init_fn(dataloader, self.global_rank)

        dataloader = self._strategy.process_dataloader(dataloader)
        device = self.device if move_to_device and not isinstance(
            self._strategy, TPUSpawnStrategy) else None
        lite_dataloader = _LiteDataLoader(dataloader=dataloader, device=device)
        lite_dataloader = cast(DataLoader, lite_dataloader)
        return lite_dataloader
    def prepare_dataloader(self, dataloader: Any, shuffle: bool, mode: Optional[RunningStage] = None) -> Any:
        """This function handles to following functionalities:

        - Injecting a `DistributedDataSampler` into the `DataLoader` if on a distributed environment
        - Wrapping the datasets and samplers into fault-tolerant components
        """
        if isinstance(dataloader, CombinedLoader):
            # apply `prepare_dataloader` on all the collection of loaders
            dataloader.loaders = apply_to_collection(
                dataloader.loaders, (DataLoader, CycleIterator), self.prepare_dataloader, shuffle, mode=mode
            )
            # the length needs to be recomputed across all dataloaders in case of special behavior.
            dataloader._apply_cycle_iterator_length()
            return dataloader

        # don't do anything if it's not a dataloader
        if not isinstance(dataloader, (DataLoader, CycleIterator)):
            return dataloader

        cycle_iterator: Optional[CycleIterator] = None

        if isinstance(dataloader, CycleIterator):
            cycle_iterator = dataloader
            dataloader = dataloader.loader

        if (
            _fault_tolerant_training()  # injects components to track the state
            or self._requires_distributed_sampler(dataloader)  # sets the distributed sampler
            or mode == RunningStage.PREDICTING  # to track indices for the predictions
            or self._accelerator_connector.use_ipu  # IPUs use a custom `DataLoader`
        ):
            sampler = self._resolve_sampler(dataloader, shuffle=shuffle, mode=mode)
            dataloader = _update_dataloader(dataloader, sampler, mode=mode)

        if cycle_iterator is not None:
            cycle_iterator.loader = dataloader
            return cycle_iterator

        return dataloader
def test_dataloaders_with_missing_keyword_arguments():
    ds = RandomDataset(10, 20)

    class TestDataLoader(DataLoader):
        def __init__(self, dataset):
            super().__init__(dataset)

    loader = TestDataLoader(ds)
    sampler = SequentialSampler(ds)
    match = escape(
        "missing arguments are ['batch_sampler', 'sampler', 'shuffle']")
    with pytest.raises(MisconfigurationException, match=match):
        _update_dataloader(loader, sampler, mode="fit")
    match = escape(
        "missing arguments are ['batch_sampler', 'batch_size', 'drop_last', 'sampler', 'shuffle']"
    )
    with pytest.raises(MisconfigurationException, match=match):
        _update_dataloader(loader, sampler, mode="predict")

    class TestDataLoader(DataLoader):
        def __init__(self, dataset, *args, **kwargs):
            super().__init__(dataset)

    loader = TestDataLoader(ds)
    sampler = SequentialSampler(ds)
    _update_dataloader(loader, sampler, mode="fit")
    _update_dataloader(loader, sampler, mode="predict")

    class TestDataLoader(DataLoader):
        def __init__(self, *foo, **bar):
            super().__init__(*foo, **bar)

    loader = TestDataLoader(ds)
    sampler = SequentialSampler(ds)
    _update_dataloader(loader, sampler, mode="fit")
    _update_dataloader(loader, sampler, mode="predict")

    class TestDataLoader(DataLoader):
        def __init__(self, num_feat, dataset, *args, shuffle=False):
            self.num_feat = num_feat
            super().__init__(dataset)

    loader = TestDataLoader(1, ds)
    sampler = SequentialSampler(ds)
    match = escape("missing arguments are ['batch_sampler', 'sampler']")
    with pytest.raises(MisconfigurationException, match=match):
        _update_dataloader(loader, sampler, mode="fit")
    match = escape(
        "missing arguments are ['batch_sampler', 'batch_size', 'drop_last', 'sampler']"
    )
    with pytest.raises(MisconfigurationException, match=match):
        _update_dataloader(loader, sampler, mode="predict")

    class TestDataLoader(DataLoader):
        def __init__(self, num_feat, dataset, **kwargs):
            self.feat_num = num_feat
            super().__init__(dataset)

    loader = TestDataLoader(1, ds)
    sampler = SequentialSampler(ds)
    match = escape("missing attributes are ['num_feat']")
    with pytest.raises(MisconfigurationException, match=match):
        _update_dataloader(loader, sampler, mode="fit")
    match = escape("missing attributes are ['num_feat']")
    with pytest.raises(MisconfigurationException, match=match):
        _update_dataloader(loader, sampler, mode="predict")
def test_update_dataloader_raises():
    with pytest.raises(ValueError,
                       match="needs to subclass `torch.utils.data.DataLoader"):
        _update_dataloader(object(), object(), mode="fit")
Example #10
def replace_sampler(dataloader: DataLoader) -> DataLoader:
    return _update_dataloader(dataloader, SequentialSampler(dataloader.dataset), mode=RunningStage.TRAINING)
Example #11
def replace_sampler(dataloader: DataLoader) -> DataLoader:
    return _update_dataloader(dataloader, sampler=SequentialSampler(dataloader.dataset), mode=mode)