def test_update_dataloader_typerror_custom_exception():
    class BadImpl(DataLoader):
        def __init__(self, foo, *args, **kwargs):
            self.foo = foo
            # positional conflict with `dataset`
            super().__init__(foo, *args, **kwargs)

    dataloader = BadImpl([1, 2, 3])
    with pytest.raises(MisconfigurationException, match="`DataLoader` implementation has an error.*`dataset`"):
        _update_dataloader(dataloader, dataloader.sampler)

    class BadImpl2(DataLoader):
        def __init__(self, randomize, *args, **kwargs):
            self.randomize = randomize
            # keyword conflict with `shuffle`
            super().__init__(*args, shuffle=randomize, **kwargs)

    dataloader = BadImpl2(False, [])
    with pytest.raises(MisconfigurationException, match="`DataLoader` implementation has an error.*`shuffle`"):
        _update_dataloader(dataloader, dataloader.sampler)

    class GoodImpl(DataLoader):
        def __init__(self, randomize, *args, **kwargs):
            # fixed implementation, kwargs are filtered
            self.randomize = randomize or kwargs.pop("shuffle", False)
            super().__init__(*args, shuffle=randomize, **kwargs)

    dataloader = GoodImpl(False, [])
    new_dataloader = _update_dataloader(dataloader, dataloader.sampler)
    assert isinstance(new_dataloader, GoodImpl)

def test_replace_dataloader_init_method(cls, args, kwargs, arg_names, dataset, checked_values):
    with _replace_dataloader_init_method():
        dataloader = cls(*args, **kwargs)

    assert dataloader.__pl_dl_args == args
    assert dataloader.__pl_dl_kwargs == kwargs
    assert dataloader.__pl_dl_arg_names == arg_names
    assert dataloader.__dataset == dataset
    assert dataloader.dataset == dataset

    for key, value in checked_values.items():
        dataloader_value = getattr(dataloader, key)
        if isinstance(dataloader_value, torch.Tensor):
            assert dataloader_value is value
        else:
            assert dataloader_value == value

    dataloader = _update_dataloader(dataloader, dataloader.sampler)

    assert isinstance(dataloader, cls)
    assert not hasattr(dataloader, "__pl_dl_kwargs")
    assert not hasattr(dataloader, "__pl_dl_arg_names")
    assert not hasattr(dataloader, "__pl_dl_args")
    assert not hasattr(dataloader, "__dataset")
    assert dataloader.dataset == dataset

    for key, value in checked_values.items():
        dataloader_value = getattr(dataloader, key)
        if isinstance(dataloader_value, torch.Tensor):
            assert dataloader_value is value
        else:
            assert dataloader_value == value

def test_update_dataloader_with_multiprocessing_context():
    """This test verifies that replacing the sampler preserves the multiprocessing context."""
    train = RandomDataset(32, 64)
    context = "spawn"
    train = DataLoader(train, batch_size=32, num_workers=2, multiprocessing_context=context, shuffle=True)
    new_data_loader = _update_dataloader(train, SequentialSampler(train.dataset))
    assert new_data_loader.multiprocessing_context == train.multiprocessing_context

def _setup_dataloader(
    self, dataloader: DataLoader, replace_sampler: bool = True, move_to_device: bool = True
) -> DataLoader:
    """Set up a single dataloader for accelerated training.

    Args:
        dataloader: The dataloader to accelerate.
        replace_sampler: If set ``True`` (default), automatically wraps or replaces the sampler on the
            dataloader for distributed training. If you have a custom sampler defined, set this argument to
            ``False``.
        move_to_device: If set ``True`` (default), moves the data returned by the dataloader automatically to
            the correct device. Set this to ``False`` and use :meth:`to_device` manually on the returned data
            instead.

    Returns:
        The wrapped dataloader.
    """
    sampler = dataloader.sampler
    if replace_sampler and self._requires_distributed_sampler(dataloader):
        sampler = self._get_distributed_sampler(dataloader, **self._strategy.distributed_sampler_kwargs)

    # the dataloader needs to be re-instantiated because we want to update the input arguments (e.g., sampler)
    dataloader = _update_dataloader(dataloader, sampler)

    # add worker_init_fn for correct seeding in worker processes
    _auto_add_worker_init_fn(dataloader, self.global_rank)

    dataloader = self._strategy.process_dataloader(dataloader)
    device = self.device if move_to_device and not isinstance(self._strategy, TPUSpawnStrategy) else None
    lite_dataloader = _LiteDataLoader(dataloader=dataloader, device=device)
    lite_dataloader = cast(DataLoader, lite_dataloader)
    return lite_dataloader

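# A hedged usage sketch (not part of the source): `_setup_dataloader` backs the
# public `LightningLite.setup_dataloaders` API, which a user would call from
# `run`. `RandomDataset` is defined inline so the example is self-contained;
# any map-style dataset behaves the same way.
import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning.lite import LightningLite


class RandomDataset(Dataset):
    def __init__(self, size: int, length: int) -> None:
        self.data = torch.randn(length, size)

    def __getitem__(self, index: int) -> torch.Tensor:
        return self.data[index]

    def __len__(self) -> int:
        return len(self.data)


class Lite(LightningLite):
    def run(self) -> None:
        loader = DataLoader(RandomDataset(32, 64), batch_size=4)
        # wraps or replaces the sampler for distributed training and returns a
        # loader that moves each batch to the right device automatically
        loader = self.setup_dataloaders(loader)
        for batch in loader:
            ...
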
def _prepare_dataloader(
    self, dataloader: Any, shuffle: Optional[bool] = None, mode: Optional[RunningStage] = None
) -> Any:
    """This function handles the following functionalities:

    - Injecting a `DistributedSampler` into the `DataLoader` if running in a distributed environment
    - Wrapping the datasets and samplers into fault-tolerant components
    - Wrapping the dataloader based on strategy-specific logic
    """
    if isinstance(dataloader, CombinedLoader):
        # apply `_prepare_dataloader` to every loader in the collection
        dataloader.loaders = apply_to_collection(
            dataloader.loaders, (DataLoader, CycleIterator), self._prepare_dataloader, shuffle, mode=mode
        )
        # the length needs to be recomputed across all dataloaders in case of special behavior
        dataloader._apply_cycle_iterator_length()
        return dataloader

    # don't do anything if it's not a dataloader
    if not isinstance(dataloader, (DataLoader, CycleIterator)):
        return dataloader

    cycle_iterator: Optional[CycleIterator] = None
    if isinstance(dataloader, CycleIterator):
        cycle_iterator = dataloader
        dataloader = dataloader.loader

    if (
        _fault_tolerant_training()  # injects components to track the state
        or self._requires_distributed_sampler(dataloader)  # sets the distributed sampler
        or mode == RunningStage.PREDICTING  # to track indices for the predictions
        # IPUs use a custom `poptorch.DataLoader` which the user's loader might need to be converted to
        or isinstance(self.trainer.accelerator, IPUAccelerator)
    ):
        if shuffle is None:
            # for training, always set to True
            # for evaluation, decide based on the existing sampler
            shuffle = True if mode == RunningStage.TRAINING else _is_dataloader_shuffled(dataloader)

        sampler = self._resolve_sampler(dataloader, shuffle=shuffle, mode=mode)
        dataloader = _update_dataloader(dataloader, sampler, mode=mode)

    dataloader = self.trainer.strategy.process_dataloader(dataloader)

    if cycle_iterator is not None:
        cycle_iterator.loader = dataloader
        return cycle_iterator

    return dataloader

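# A hedged illustration of the shuffle default above: training always
# reshuffles, while evaluation infers the setting from the loader's existing
# sampler. For plain map-style loaders, `_is_dataloader_shuffled` is assumed
# to reduce to a sampler-type check like this:
from torch.utils.data import DataLoader, RandomSampler


def infer_shuffle(dataloader: DataLoader, training: bool) -> bool:
    # a standard DataLoader exposes `shuffle=True` as a `RandomSampler`
    return True if training else isinstance(dataloader.sampler, RandomSampler)
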
def _setup_dataloader(
    self, dataloader: DataLoader, replace_sampler: bool = True, move_to_device: bool = True
) -> DataLoader:
    """Set up a single dataloader for accelerated training.

    Args:
        dataloader: The dataloader to accelerate.
        replace_sampler: If set ``True`` (default), automatically wraps or replaces the sampler on the
            dataloader for distributed training. If you have a custom sampler defined, set this argument to
            ``False``.
        move_to_device: If set ``True`` (default), moves the data returned by the dataloader automatically to
            the correct device. Set this to ``False`` and use :meth:`to_device` manually on the returned data
            instead.

    Returns:
        The wrapped dataloader.
    """
    sampler = dataloader.sampler
    if replace_sampler and self._requires_distributed_sampler(dataloader):
        if not isinstance(sampler, (SequentialSampler, RandomSampler)):
            raise MisconfigurationException(
                "You seem to have configured a sampler in your DataLoader. This will be replaced"
                " by `DistributedSampler` since `replace_sampler_ddp` is True and you are using"
                " distributed training. Either remove the sampler from your DataLoader or set"
                " `replace_sampler=False` if you want to use your custom sampler."
            )
        sampler = self._get_distributed_sampler(dataloader, **self._strategy.distributed_sampler_kwargs)

    # the dataloader needs to be re-instantiated because we want to update the input arguments (e.g., sampler)
    dataloader = _update_dataloader(dataloader, sampler)

    # add worker_init_fn for correct seeding in worker processes
    _auto_add_worker_init_fn(dataloader, self.global_rank)

    dataloader = self._strategy.process_dataloader(dataloader)
    device = self.device if move_to_device and not isinstance(self._strategy, TPUSpawnStrategy) else None
    lite_dataloader = _LiteDataLoader(dataloader=dataloader, device=device)
    lite_dataloader = cast(DataLoader, lite_dataloader)
    return lite_dataloader

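# A hedged sketch of how a user sidesteps the guard above: when a custom
# sampler must survive distributed training, replacement is disabled through
# the public `setup_dataloaders` entry point, which forwards `replace_sampler`
# to `_setup_dataloader`. `EveryOtherSampler` is hypothetical; `RandomDataset`
# reuses the stand-in defined in the earlier sketch.
from torch.utils.data import DataLoader, Dataset, Sampler

from pytorch_lightning.lite import LightningLite


class EveryOtherSampler(Sampler):
    """Hypothetical custom sampler that yields every second index."""

    def __init__(self, data_source: Dataset) -> None:
        self.data_source = data_source

    def __iter__(self):
        return iter(range(0, len(self.data_source), 2))

    def __len__(self) -> int:
        return (len(self.data_source) + 1) // 2


class LiteWithCustomSampler(LightningLite):
    def run(self) -> None:
        dataset = RandomDataset(32, 64)
        loader = DataLoader(dataset, sampler=EveryOtherSampler(dataset))
        # with the default `replace_sampler=True` and a distributed strategy,
        # this call would raise the MisconfigurationException shown above
        loader = self.setup_dataloaders(loader, replace_sampler=False)
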
def prepare_dataloader(self, dataloader: Any, shuffle: bool, mode: Optional[RunningStage] = None) -> Any:
    """This function handles the following functionalities:

    - Injecting a `DistributedSampler` into the `DataLoader` if running in a distributed environment
    - Wrapping the datasets and samplers into fault-tolerant components
    """
    if isinstance(dataloader, CombinedLoader):
        # apply `prepare_dataloader` to every loader in the collection
        dataloader.loaders = apply_to_collection(
            dataloader.loaders, (DataLoader, CycleIterator), self.prepare_dataloader, shuffle, mode=mode
        )
        # the length needs to be recomputed across all dataloaders in case of special behavior
        dataloader._apply_cycle_iterator_length()
        return dataloader

    # don't do anything if it's not a dataloader
    if not isinstance(dataloader, (DataLoader, CycleIterator)):
        return dataloader

    cycle_iterator: Optional[CycleIterator] = None
    if isinstance(dataloader, CycleIterator):
        cycle_iterator = dataloader
        dataloader = dataloader.loader

    if (
        _fault_tolerant_training()  # injects components to track the state
        or self._requires_distributed_sampler(dataloader)  # sets the distributed sampler
        or mode == RunningStage.PREDICTING  # to track indices for the predictions
        or self._accelerator_connector.use_ipu  # IPUs use a custom `DataLoader`
    ):
        sampler = self._resolve_sampler(dataloader, shuffle=shuffle, mode=mode)
        dataloader = _update_dataloader(dataloader, sampler, mode=mode)

    if cycle_iterator is not None:
        cycle_iterator.loader = dataloader
        return cycle_iterator

    return dataloader

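# A hedged illustration of the `CycleIterator` handling above: the wrapper is
# peeled off, the inner loader is re-instantiated with the new sampler, and
# the original wrapper is returned with its `.loader` swapped in place. The
# import paths are assumptions about where these helpers live in
# pytorch_lightning.
from torch.utils.data import DataLoader, SequentialSampler

from pytorch_lightning.trainer.supporters import CycleIterator
from pytorch_lightning.utilities.data import _update_dataloader

loader = DataLoader(range(10))
wrapped = CycleIterator(loader)
# mimic the unwrap -> update -> rewrap dance from `prepare_dataloader`
inner = _update_dataloader(wrapped.loader, SequentialSampler(wrapped.loader.dataset))
wrapped.loader = inner
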
def test_dataloaders_with_missing_keyword_arguments():
    ds = RandomDataset(10, 20)

    class TestDataLoader(DataLoader):
        def __init__(self, dataset):
            super().__init__(dataset)

    loader = TestDataLoader(ds)
    sampler = SequentialSampler(ds)
    match = escape("missing arguments are ['batch_sampler', 'sampler', 'shuffle']")
    with pytest.raises(MisconfigurationException, match=match):
        _update_dataloader(loader, sampler, mode="fit")
    match = escape("missing arguments are ['batch_sampler', 'batch_size', 'drop_last', 'sampler', 'shuffle']")
    with pytest.raises(MisconfigurationException, match=match):
        _update_dataloader(loader, sampler, mode="predict")

    class TestDataLoader(DataLoader):
        def __init__(self, dataset, *args, **kwargs):
            super().__init__(dataset)

    loader = TestDataLoader(ds)
    sampler = SequentialSampler(ds)
    _update_dataloader(loader, sampler, mode="fit")
    _update_dataloader(loader, sampler, mode="predict")

    class TestDataLoader(DataLoader):
        def __init__(self, *foo, **bar):
            super().__init__(*foo, **bar)

    loader = TestDataLoader(ds)
    sampler = SequentialSampler(ds)
    _update_dataloader(loader, sampler, mode="fit")
    _update_dataloader(loader, sampler, mode="predict")

    class TestDataLoader(DataLoader):
        def __init__(self, num_feat, dataset, *args, shuffle=False):
            self.num_feat = num_feat
            super().__init__(dataset)

    loader = TestDataLoader(1, ds)
    sampler = SequentialSampler(ds)
    match = escape("missing arguments are ['batch_sampler', 'sampler']")
    with pytest.raises(MisconfigurationException, match=match):
        _update_dataloader(loader, sampler, mode="fit")
    match = escape("missing arguments are ['batch_sampler', 'batch_size', 'drop_last', 'sampler']")
    with pytest.raises(MisconfigurationException, match=match):
        _update_dataloader(loader, sampler, mode="predict")

    class TestDataLoader(DataLoader):
        def __init__(self, num_feat, dataset, **kwargs):
            self.feat_num = num_feat
            super().__init__(dataset)

    loader = TestDataLoader(1, ds)
    sampler = SequentialSampler(ds)
    match = escape("missing attributes are ['num_feat']")
    with pytest.raises(MisconfigurationException, match=match):
        _update_dataloader(loader, sampler, mode="fit")
    match = escape("missing attributes are ['num_feat']")
    with pytest.raises(MisconfigurationException, match=match):
        _update_dataloader(loader, sampler, mode="predict")

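# A hedged sketch of the fix for the last failing loader above: store each
# extra constructor argument under an attribute of the *same* name and forward
# the remaining keyword arguments, so `_update_dataloader` can read the value
# back and re-instantiate the class.
class FixedDataLoader(DataLoader):
    def __init__(self, num_feat, dataset, **kwargs):
        self.num_feat = num_feat  # attribute name matches the parameter name
        super().__init__(dataset, **kwargs)
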
def test_update_dataloader_raises():
    with pytest.raises(ValueError, match="needs to subclass `torch.utils.data.DataLoader"):
        _update_dataloader(object(), object(), mode="fit")

def replace_sampler(dataloader: DataLoader) -> DataLoader:
    return _update_dataloader(dataloader, SequentialSampler(dataloader.dataset), mode=RunningStage.TRAINING)

def replace_sampler(dataloader: DataLoader) -> DataLoader:
    # `mode` is captured from the enclosing (parametrized) test scope
    return _update_dataloader(dataloader, sampler=SequentialSampler(dataloader.dataset), mode=mode)