Example #1
    def _setup_dataloader(self,
                          dataloader: DataLoader,
                          replace_sampler: bool = True,
                          move_to_device: bool = True) -> DataLoader:
        """Set up a single dataloader for accelerated training.

        Args:
            dataloader: The dataloader to accelerate.
            replace_sampler: If set to ``True`` (default), automatically wraps or replaces the sampler on the dataloader
                for distributed training. If you have a custom sampler defined, set this argument to ``False``.
            move_to_device: If set to ``True`` (default), automatically moves the data returned by the dataloader to
                the correct device. Set this to ``False`` and instead call :meth:`to_device` manually on the
                returned data.

        Returns:
            The wrapped dataloader.
        """
        sampler = dataloader.sampler
        if replace_sampler and self._requires_distributed_sampler(dataloader):
            sampler = self._get_distributed_sampler(
                dataloader, **self._strategy.distributed_sampler_kwargs)

        # the dataloader needs to be re-instantiated because we want to update the input arguments (e.g., sampler)
        dataloader = _update_dataloader(dataloader, sampler)

        # add worker_init_fn for correct seeding in worker processes
        _auto_add_worker_init_fn(dataloader, self.global_rank)

        dataloader = self._strategy.process_dataloader(dataloader)
        device = self.device if move_to_device and not isinstance(
            self._strategy, TPUSpawnStrategy) else None
        lite_dataloader = _LiteDataLoader(dataloader=dataloader, device=device)
        lite_dataloader = cast(DataLoader, lite_dataloader)
        return lite_dataloader
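
The method above is internal; in typical use it is reached through LightningLite.setup_dataloaders, which forwards the same replace_sampler and move_to_device flags. Below is a minimal, hedged usage sketch assuming the public LightningLite API from pytorch_lightning.lite; the LoaderDemo class and the toy dataset are illustrative only, not part of the snippet above.

import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning.lite import LightningLite


class LoaderDemo(LightningLite):  # hypothetical subclass for illustration
    def run(self):
        dataset = TensorDataset(torch.arange(8, dtype=torch.float32))
        loader = DataLoader(dataset, batch_size=2)
        # setup_dataloaders() delegates to _setup_dataloader() for each loader: the
        # sampler is swapped for a distributed one when needed, and the returned
        # loader moves batches to self.device because move_to_device defaults to True.
        loader = self.setup_dataloaders(loader)
        for (batch,) in loader:
            assert batch.device == self.device


LoaderDemo(accelerator="cpu").run()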
Example #2
import torch
from torch.utils.data import DataLoader

# assumed import location of the private _LiteDataLoader wrapper (Lightning 1.5/1.6 layout)
from pytorch_lightning.lite.wrappers import _LiteDataLoader


# `src_device` and `dest_device` are supplied by pytest parametrization (not shown in this excerpt).
def test_lite_dataloader_device_placement(src_device, dest_device):
    """Test that the LiteDataLoader moves data to the device in its iterator."""
    sample0 = torch.tensor(0, device=src_device)
    sample1 = torch.tensor(1, device=src_device)
    sample2 = {"data": torch.tensor(2, device=src_device)}
    sample3 = {"data": torch.tensor(3, device=src_device)}
    dataloader = DataLoader([sample0, sample1, sample2, sample3], batch_size=2)
    lite_dataloader = _LiteDataLoader(dataloader=dataloader,
                                      device=dest_device)
    iterator = iter(lite_dataloader)

    batch0 = next(iterator)
    assert torch.equal(batch0, torch.tensor([0, 1], device=dest_device))

    batch1 = next(iterator)
    assert torch.equal(batch1["data"], torch.tensor([2, 3],
                                                    device=dest_device))
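
The test above pins down the contract that the _LiteDataLoader iterator moves every yielded batch, including dict-valued batches, to the target device, while the next example checks that iteration is a plain pass-through when no device is given. A conceptual sketch of that behaviour (not the library's actual implementation) might look like the following; the class name is hypothetical and move_data_to_device is assumed from pytorch_lightning.utilities.apply_func.

from typing import Iterator, Optional

import torch
from torch.utils.data import DataLoader

from pytorch_lightning.utilities.apply_func import move_data_to_device


class _DeviceMovingDataLoader:  # hypothetical stand-in for _LiteDataLoader
    def __init__(self, dataloader: DataLoader, device: Optional[torch.device] = None) -> None:
        self._dataloader = dataloader
        self._device = device

    def __len__(self) -> int:
        return len(self._dataloader)

    def __iter__(self) -> Iterator:
        for batch in self._dataloader:
            # with a device set, tensors (also nested in dicts/lists) are moved;
            # with device=None the batch is yielded unchanged
            yield move_data_to_device(batch, self._device) if self._device is not None else batch

Note that the real wrapper also forwards other DataLoader attributes; this sketch only covers the iteration behaviour exercised by these two tests.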
Example #3
import pytest
import torch
from torch.utils.data import DataLoader

# assumed import location of the private _LiteDataLoader wrapper (Lightning 1.5/1.6 layout)
from pytorch_lightning.lite.wrappers import _LiteDataLoader


def test_lite_dataloader_iterator():
    """Test that the iteration over a LiteDataLoader wraps the iterator of the underlying dataloader (no automatic
    device placement)."""
    dataloader = DataLoader(range(5), batch_size=2)
    lite_dataloader = _LiteDataLoader(dataloader)
    assert len(lite_dataloader) == len(dataloader) == 3

    iterator = iter(dataloader)
    lite_iterator = iter(lite_dataloader)

    assert torch.equal(next(iterator), next(lite_iterator))
    assert torch.equal(next(iterator), next(lite_iterator))
    assert torch.equal(next(iterator), next(lite_iterator))

    with pytest.raises(StopIteration):
        next(iterator)

    with pytest.raises(StopIteration):
        next(lite_iterator)
Example #4
    def _setup_dataloader(self,
                          dataloader: DataLoader,
                          replace_sampler: bool = True,
                          move_to_device: bool = True) -> Iterable:
        """Setup a single dataloader for accelerated training.

        Args:
            dataloader: The dataloader to accelerate.
            replace_sampler: If set to ``True`` (default), automatically wraps or replaces the sampler on the dataloader
                for distributed training. If you have a custom sampler defined, set this argument to ``False``.
            move_to_device: If set to ``True`` (default), automatically moves the data returned by the dataloader to
                the correct device. Set this to ``False`` and instead call :meth:`to_device` manually on the
                returned data.

        Returns:
            The wrapped dataloader.
        """
        sampler = dataloader.sampler
        if replace_sampler and self._requires_distributed_sampler(dataloader):
            if not isinstance(sampler, (SequentialSampler, RandomSampler)):
                raise MisconfigurationException(
                    "You seem to have configured a sampler in your DataLoader. This will be replaced "
                    " by `DistributedSampler` since `replace_sampler_ddp` is True and you are using"
                    " distributed training. Either remove the sampler from your DataLoader or set"
                    " `replace_sampler=False` if you want to use your custom sampler."
                )
            sampler = self._get_distributed_sampler(
                dataloader, **self._strategy.distributed_sampler_kwargs)

        # the dataloader needs to be re-instantiated because we want to update the input arguments (e.g., sampler)
        dataloader = TrainerDataLoadingMixin._update_dataloader(
            dataloader, sampler)

        # add worker_init_fn for correct seeding in worker processes
        TrainerDataLoadingMixin._auto_add_worker_init_fn(
            dataloader, self.global_rank)

        dataloader = self._strategy.process_dataloader(dataloader)
        device = self.device if move_to_device and not isinstance(
            self._strategy, TPUSpawnPlugin) else None
        return _LiteDataLoader(dataloader=dataloader, device=device)
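
This earlier variant (note TPUSpawnPlugin and the TrainerDataLoadingMixin helpers) additionally rejects custom samplers rather than silently replacing them, and the error message points to replace_sampler=False as the escape hatch. Below is a hedged usage sketch of that escape hatch, assuming setup_dataloaders on the public LightningLite API accepts the same replace_sampler flag; the SamplerDemo class and the SequentialSampler stand-in are illustrative only.

import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset

from pytorch_lightning.lite import LightningLite


class SamplerDemo(LightningLite):  # hypothetical subclass for illustration
    def run(self):
        dataset = TensorDataset(torch.arange(8, dtype=torch.float32))
        my_sampler = SequentialSampler(dataset)  # stands in for a user-defined sampler
        loader = DataLoader(dataset, sampler=my_sampler, batch_size=2)
        # Under a distributed strategy, the default replace_sampler=True would replace a
        # plain sampler and reject a truly custom one with the exception shown above;
        # passing replace_sampler=False keeps the user's sampler untouched.
        loader = self.setup_dataloaders(loader, replace_sampler=False)
        for (batch,) in loader:
            ...


SamplerDemo(accelerator="cpu").run()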