def _setup_dataloader(
    self, dataloader: DataLoader, replace_sampler: bool = True, move_to_device: bool = True
) -> DataLoader:
    """Set up a single dataloader for accelerated training.

    Args:
        dataloader: The dataloader to accelerate.
        replace_sampler: If set to ``True`` (default), automatically wraps or replaces the sampler on the
            dataloader for distributed training. If you have a custom sampler defined, set this argument to
            ``False``.
        move_to_device: If set to ``True`` (default), automatically moves the data returned by the dataloader
            to the correct device. Set this to ``False`` and alternatively use :meth:`to_device` manually on
            the returned data.

    Returns:
        The wrapped dataloader.
    """
    sampler = dataloader.sampler
    if replace_sampler and self._requires_distributed_sampler(dataloader):
        sampler = self._get_distributed_sampler(dataloader, **self._strategy.distributed_sampler_kwargs)

    # the dataloader needs to be re-instantiated because we want to update the input arguments (e.g., sampler)
    dataloader = _update_dataloader(dataloader, sampler)

    # add worker_init_fn for correct seeding in worker processes
    _auto_add_worker_init_fn(dataloader, self.global_rank)

    dataloader = self._strategy.process_dataloader(dataloader)
    device = self.device if move_to_device and not isinstance(self._strategy, TPUSpawnStrategy) else None
    lite_dataloader = _LiteDataLoader(dataloader=dataloader, device=device)
    lite_dataloader = cast(DataLoader, lite_dataloader)
    return lite_dataloader
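# For orientation, a hedged example of how this private method is reached through
# the public `LightningLite.setup_dataloaders` entry point, which forwards each
# dataloader (and the `replace_sampler` / `move_to_device` flags) to
# `_setup_dataloader`. The subclass name and toy dataset are illustrative only.
from torch.utils.data import DataLoader

from pytorch_lightning.lite import LightningLite


class Lite(LightningLite):
    def run(self) -> None:
        dataloader = self.setup_dataloaders(DataLoader(range(8), batch_size=2))
        for batch in dataloader:
            ...  # with `move_to_device=True`, batches already live on `self.device`


Lite(accelerator="cpu").run()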
def test_lite_dataloader_device_placement(src_device, dest_device):
    """Test that the LiteDataLoader moves data to the device in its iterator."""
    sample0 = torch.tensor(0, device=src_device)
    sample1 = torch.tensor(1, device=src_device)
    sample2 = {"data": torch.tensor(2, device=src_device)}
    sample3 = {"data": torch.tensor(3, device=src_device)}
    dataloader = DataLoader([sample0, sample1, sample2, sample3], batch_size=2)
    lite_dataloader = _LiteDataLoader(dataloader=dataloader, device=dest_device)
    iterator = iter(lite_dataloader)

    batch0 = next(iterator)
    assert torch.equal(batch0, torch.tensor([0, 1], device=dest_device))

    batch1 = next(iterator)
    assert torch.equal(batch1["data"], torch.tensor([2, 3], device=dest_device))
def test_lite_dataloader_iterator():
    """Test that the iteration over a LiteDataLoader wraps the iterator of the underlying dataloader (no automatic
    device placement)."""
    dataloader = DataLoader(range(5), batch_size=2)
    lite_dataloader = _LiteDataLoader(dataloader)
    assert len(lite_dataloader) == len(dataloader) == 3

    iterator = iter(dataloader)
    lite_iterator = iter(lite_dataloader)

    assert torch.equal(next(iterator), next(lite_iterator))
    assert torch.equal(next(iterator), next(lite_iterator))
    assert torch.equal(next(iterator), next(lite_iterator))

    with pytest.raises(StopIteration):
        next(iterator)

    with pytest.raises(StopIteration):
        next(lite_iterator)
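# A minimal sketch of the behaviour the two tests above exercise, assuming the
# wrapper only needs `__len__` and `__iter__` and can reuse Lightning's
# `move_data_to_device` utility. The class name `_LiteDataLoaderSketch` is
# illustrative, not the actual implementation.
from typing import Any, Iterator, Optional

import torch
from torch.utils.data import DataLoader

from pytorch_lightning.utilities.apply_func import move_data_to_device


class _LiteDataLoaderSketch:
    def __init__(self, dataloader: DataLoader, device: Optional[torch.device] = None) -> None:
        self._dataloader = dataloader
        self._device = device

    def __len__(self) -> int:
        # delegate to the wrapped dataloader, as asserted in `test_lite_dataloader_iterator`
        return len(self._dataloader)

    def __iter__(self) -> Iterator[Any]:
        for batch in self._dataloader:
            if self._device is None:
                # no device given: behave exactly like the wrapped iterator
                yield batch
            else:
                # recursively moves tensors nested in dicts/lists, which is why the
                # dict batches in the device-placement test land on `dest_device`
                yield move_data_to_device(batch, self._device)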
def _setup_dataloader(
    self, dataloader: DataLoader, replace_sampler: bool = True, move_to_device: bool = True
) -> Iterable:
    """Set up a single dataloader for accelerated training.

    Args:
        dataloader: The dataloader to accelerate.
        replace_sampler: If set to ``True`` (default), automatically wraps or replaces the sampler on the
            dataloader for distributed training. If you have a custom sampler defined, set this argument to
            ``False``.
        move_to_device: If set to ``True`` (default), automatically moves the data returned by the dataloader
            to the correct device. Set this to ``False`` and alternatively use :meth:`to_device` manually on
            the returned data.

    Returns:
        The wrapped dataloader.
    """
    sampler = dataloader.sampler
    if replace_sampler and self._requires_distributed_sampler(dataloader):
        if not isinstance(sampler, (SequentialSampler, RandomSampler)):
            raise MisconfigurationException(
                "You seem to have configured a sampler in your DataLoader. This will be replaced"
                " by `DistributedSampler` since `replace_sampler_ddp` is True and you are using"
                " distributed training. Either remove the sampler from your DataLoader or set"
                " `replace_sampler=False` if you want to use your custom sampler."
            )
        sampler = self._get_distributed_sampler(dataloader, **self._strategy.distributed_sampler_kwargs)

    # the dataloader needs to be re-instantiated because we want to update the input arguments (e.g., sampler)
    dataloader = TrainerDataLoadingMixin._update_dataloader(dataloader, sampler)

    # add worker_init_fn for correct seeding in worker processes
    TrainerDataLoadingMixin._auto_add_worker_init_fn(dataloader, self.global_rank)

    dataloader = self._strategy.process_dataloader(dataloader)
    device = self.device if move_to_device and not isinstance(self._strategy, TPUSpawnPlugin) else None
    return _LiteDataLoader(dataloader=dataloader, device=device)
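# A hedged sketch of the custom-sampler path guarded above: under distributed
# training, a non-default sampler triggers the `MisconfigurationException`, so a
# user who wants to keep it passes `replace_sampler=False`. The `num_replicas` /
# `rank` values here are illustrative.
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

dataset = list(range(8))
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)
dataloader = DataLoader(dataset, sampler=sampler, batch_size=2)
# inside a `LightningLite.run` implementation:
# dataloader = self.setup_dataloaders(dataloader, replace_sampler=False)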