def test_has_iterable_dataset():
    assert has_iterable_dataset(DataLoader(RandomIterableDataset(1, 1)))

    assert not has_iterable_dataset(DataLoader(RandomDataset(1, 1)))

    class MockDatasetWithoutIterableDataset(RandomDataset):
        def __iter__(self):
            yield 1
            return self

    assert not has_iterable_dataset(DataLoader(MockDatasetWithoutIterableDataset(1, 1)))
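# For context, a minimal sketch of the `has_iterable_dataset` helper exercised by
# the test above. This assumes the helper only inspects the loader's `dataset`
# attribute; the shipped implementation may perform additional checks.
from torch.utils.data import DataLoader, IterableDataset

def has_iterable_dataset(dataloader: DataLoader) -> bool:
    # an IterableDataset exposes no index-based access, so it cannot be
    # driven by a (Distributed)Sampler and must be left untouched
    return hasattr(dataloader, 'dataset') and isinstance(dataloader.dataset, IterableDataset)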
def auto_add_sampler(self, dataloader: DataLoader, shuffle: bool) -> DataLoader:
    # don't do anything if it's not a dataloader
    is_dataloader = isinstance(dataloader, DataLoader)
    # don't manipulate iterable datasets
    is_iterable_ds = has_iterable_dataset(dataloader)
    if not is_dataloader or is_iterable_ds:
        return dataloader

    is_in_dist = self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu
    need_dist_sampler = is_in_dist and not isinstance(dataloader.sampler, DistributedSampler)

    if self.replace_sampler_ddp and need_dist_sampler:
        if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)):
            raise MisconfigurationException(
                'You seem to have configured a sampler in your DataLoader. This will be replaced'
                ' by `DistributedSampler` since `replace_sampler_ddp` is True and you are using'
                ' distributed training. Either remove the sampler from your DataLoader or set'
                ' `replace_sampler_ddp=False` if you want to use your custom sampler.'
            )

        # replace with distributed sampler
        sampler = self._get_distributed_sampler(dataloader, shuffle)
        dataloader = self.replace_sampler(dataloader, sampler)

    return dataloader
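# For context, a hedged sketch of the `replace_sampler` helper called above. The
# name and argument carry-over here are illustrative: it assumes the loader was
# built from the common DataLoader arguments, whereas the real helper
# re-instantiates the loader from all of its original constructor kwargs.
from torch.utils.data import DataLoader

def replace_sampler_sketch(dataloader: DataLoader, sampler) -> DataLoader:
    # rebuild the DataLoader around the new sampler, carrying over the
    # commonly used construction arguments
    return DataLoader(
        dataloader.dataset,
        batch_size=dataloader.batch_size,
        sampler=sampler,
        num_workers=dataloader.num_workers,
        collate_fn=dataloader.collate_fn,
        pin_memory=dataloader.pin_memory,
        drop_last=dataloader.drop_last,
        timeout=dataloader.timeout,
    )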
def _requires_distributed_sampler(self, dataloader) -> bool:
    return (
        self.trainer._accelerator_connector.replace_sampler_ddp
        and self.trainer._accelerator_connector.is_distributed
        and not isinstance(dataloader.sampler, DistributedSampler)
        and not has_iterable_dataset(dataloader)
        # `DistributedSampler` is never used with `poptorch.DataLoader`
        and not isinstance(self.trainer.accelerator, IPUAccelerator)
    )
def auto_add_sampler(
    self, dataloader: DataLoader, shuffle: bool, mode: Optional[RunningStage] = None
) -> DataLoader:
    # don't do anything if it's not a dataloader
    is_dataloader = isinstance(dataloader, DataLoader)
    # don't manipulate iterable datasets
    is_iterable_ds = has_iterable_dataset(dataloader)

    if isinstance(dataloader, CombinedLoader):
        dataloader.loaders = apply_to_collection(dataloader.loaders, DataLoader, self.auto_add_sampler, shuffle)
        return dataloader

    if not is_dataloader or is_iterable_ds:
        return dataloader

    need_dist_sampler = self.accelerator_connector.is_distributed and not isinstance(
        dataloader.sampler, DistributedSampler
    )
    if self.accelerator_connector.replace_sampler_ddp and need_dist_sampler:
        if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)):
            raise MisconfigurationException(
                'You seem to have configured a sampler in your DataLoader. This will be replaced'
                ' by `DistributedSampler` since `replace_sampler_ddp` is True and you are using'
                ' distributed training. Either remove the sampler from your DataLoader or set'
                ' `replace_sampler_ddp=False` if you want to use your custom sampler.'
            )

        # replace with distributed sampler
        sampler = self._get_distributed_sampler(dataloader, shuffle, mode=mode)
        dataloader = self.replace_sampler(dataloader, sampler, mode=mode)

    return dataloader
def test_warning_with_iterable_dataset_and_len(tmpdir):
    """Tests that a warning message is shown when an IterableDataset defines `__len__`."""
    model = EvalModelTemplate()
    original_dataset = model.train_dataloader().dataset

    class IterableWithLen(IterableDataset):
        def __iter__(self):
            return iter(original_dataset)

        def __len__(self):
            return len(original_dataset)

    dataloader = DataLoader(IterableWithLen(), batch_size=16)
    assert has_len(dataloader)
    assert has_iterable_dataset(dataloader)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_steps=3,
    )
    with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
        trainer.fit(model, train_dataloader=dataloader, val_dataloaders=[dataloader])
    with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
        trainer.test(model, test_dataloaders=[dataloader])
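# For context, a minimal sketch of the `has_len` check asserted in the test
# above, assuming it simply probes whether `len()` is usable on the loader; the
# shipped helper additionally emits the `IterableDataset` + `__len__` warning
# that the test matches against.
def has_len(dataloader) -> bool:
    try:
        # loaders backed by a plain IterableDataset typically raise TypeError here
        return len(dataloader) > 0
    except TypeError:
        return False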
def _requires_distributed_sampler(self, dataloader) -> bool:
    return (
        self._accelerator_connector.replace_sampler_ddp
        and self._accelerator_connector.is_distributed
        and not isinstance(dataloader.sampler, DistributedSampler)
        and not has_iterable_dataset(dataloader)
    )
def auto_add_sampler(self, dataloader: Any, shuffle: bool, mode: Optional[RunningStage] = None) -> Any:
    if isinstance(dataloader, CombinedLoader):
        # apply `auto_add_sampler` to each DataLoader in the collection of loaders
        dataloader.loaders = apply_to_collection(dataloader.loaders, DataLoader, self.auto_add_sampler, shuffle)
        return dataloader

    # don't do anything if it's not a dataloader
    if not isinstance(dataloader, DataLoader):
        return dataloader

    if (
        self.accelerator_connector.replace_sampler_ddp
        and self.accelerator_connector.is_distributed
        and not isinstance(dataloader.sampler, DistributedSampler)
        and not has_iterable_dataset(dataloader)
    ):
        if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)):
            raise MisconfigurationException(
                "You seem to have configured a sampler in your DataLoader. This will be replaced"
                " by `DistributedSampler` since `replace_sampler_ddp` is True and you are using"
                " distributed training. Either remove the sampler from your DataLoader or set"
                " `replace_sampler_ddp=False` if you want to use your custom sampler."
            )
        sampler = self._get_distributed_sampler(dataloader, shuffle, mode=mode)
    else:
        # keep the current sampler
        sampler = dataloader.sampler

    dataloader = self.replace_sampler(dataloader, sampler, mode=mode)
    return dataloader
def auto_add_sampler(self, dataloader: DataLoader, shuffle: bool) -> DataLoader:
    # don't do anything if it's not a dataloader
    is_dataloader = isinstance(dataloader, DataLoader)
    # don't manipulate iterable datasets
    is_iterable_ds = has_iterable_dataset(dataloader)
    if not is_dataloader or is_iterable_ds:
        return dataloader

    need_dist_sampler = self.require_distributed_sampler and not isinstance(dataloader.sampler, DistributedSampler)

    if self.replace_sampler_ddp and need_dist_sampler:
        wrap_sampler = False
        if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)):
            rank_zero_warn(
                f'Wrapping current custom sampler {dataloader.sampler} with DistributedSampler ...'
                ' If this is not the desired behavior, please specify `replace_sampler_ddp=False`'
            )
            wrap_sampler = True

        # replace with distributed sampler
        sampler = self._get_distributed_sampler(dataloader, shuffle, wrap=wrap_sampler)
        dataloader = self.replace_sampler(dataloader, sampler)

    return dataloader
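# For context, a hedged sketch of what `_get_distributed_sampler` might build in
# the simple (non-wrapping) case, assuming torch.distributed is already
# initialized; the `wrap=` path above, which wraps a custom sampler, is omitted,
# and the function name here is illustrative.
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler

def get_distributed_sampler_sketch(dataloader, shuffle: bool) -> DistributedSampler:
    # shard the dataset across all participating processes, with each rank
    # drawing a disjoint subset of the indices
    return DistributedSampler(
        dataloader.dataset,
        num_replicas=dist.get_world_size(),
        rank=dist.get_rank(),
        shuffle=shuffle,
    )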