from unittest import mock

import pytest
import torch
from torch.utils.data import DataLoader, Dataset, Sampler
from torch.utils.data.distributed import DistributedSampler

from pytorch_lightning import Trainer
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException


# NOTE: three variants of test_dataloader_reinit_for_subclass appear in this file,
# each written against a different Trainer API; under pytest the last definition wins.

# Variant 1: older API (`distributed_backend=`, `train=`).
def test_dataloader_reinit_for_subclass():
    # A DataLoader subclass with an extra constructor argument: the trainer must
    # preserve both the subclass type and the extra attribute when it swaps in
    # a distributed sampler.
    class CustomDataLoader(torch.utils.data.DataLoader):
        def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                     batch_sampler=None, num_workers=0, collate_fn=None,
                     pin_memory=False, drop_last=False, timeout=0,
                     worker_init_fn=None, dummy_kwarg=None):
            super().__init__(dataset, batch_size, shuffle, sampler, batch_sampler,
                             num_workers, collate_fn, pin_memory, drop_last,
                             timeout, worker_init_fn)
            self.dummy_kwarg = dummy_kwarg

    trainer = Trainer(gpus=[0, 1], num_nodes=1, distributed_backend='ddp')

    # Objects that merely look like loaders (duck-typed, with a `sampler`
    # attribute) must be passed through untouched.
    class CustomDummyObj:
        sampler = None

    result = trainer.auto_add_sampler(CustomDummyObj(), train=True)
    assert isinstance(result, CustomDummyObj), "Wrongly reinstantiated data loader"

    result = trainer.auto_add_sampler(CustomDataLoader(list(range(1000))), train=True)
    assert isinstance(result, torch.utils.data.DataLoader)
    assert isinstance(result, CustomDataLoader)
    assert hasattr(result, 'dummy_kwarg')
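
# A rough illustration of the re-instantiation step the assertions above rely on.
# This is a simplified sketch, not Lightning's actual implementation, and the
# helper name `_reinstantiate_with_sampler` is hypothetical: read the subclass's
# __init__ signature, collect the matching attributes off the existing loader,
# and rebuild it with only the sampler swapped out.
import inspect


def _reinstantiate_with_sampler(dataloader, new_sampler):
    skip = ("self", "shuffle", "sampler", "batch_sampler")
    params = inspect.signature(type(dataloader).__init__).parameters
    # DataLoader stores most constructor arguments under the same attribute
    # name, which is what makes this round-trip possible; extra subclass
    # arguments such as `dummy_kwarg` are picked up the same way.
    kwargs = {
        name: getattr(dataloader, name)
        for name in params
        if name not in skip and hasattr(dataloader, name)
    }
    return type(dataloader)(sampler=new_sampler, **kwargs)

# Example (kept as a comment; the classes above are local to the test):
# _reinstantiate_with_sampler(DataLoader(list(range(10)), batch_size=2),
#                             DistributedSampler(list(range(10)), num_replicas=2, rank=0))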

# Variant 2: `accelerator='ddp_spawn'` and `shuffle=`; also covers shuffled
# loaders and rejects replacing a user-provided sampler.
def test_dataloader_reinit_for_subclass(tmpdir):
    class CustomDataLoader(torch.utils.data.DataLoader):
        def __init__(
            self,
            dataset,
            batch_size=1,
            shuffle=False,
            sampler=None,
            batch_sampler=None,
            num_workers=0,
            collate_fn=None,
            pin_memory=False,
            drop_last=False,
            timeout=0,
            worker_init_fn=None,
            dummy_kwarg=None,
            **kwargs
        ):
            super().__init__(
                dataset, batch_size, shuffle, sampler, batch_sampler, num_workers,
                collate_fn, pin_memory, drop_last, timeout, worker_init_fn
            )
            self.dummy_kwarg = dummy_kwarg

    trainer = Trainer(
        gpus=[0, 1],
        num_nodes=1,
        accelerator='ddp_spawn',
        default_root_dir=tmpdir,
    )

    class CustomDummyObj:
        sampler = None

    result = trainer.auto_add_sampler(CustomDummyObj(), shuffle=True)
    assert isinstance(result, CustomDummyObj), "Wrongly reinstantiated data loader"

    dataset = list(range(1000))
    result = trainer.auto_add_sampler(CustomDataLoader(dataset), shuffle=True)
    assert isinstance(result, torch.utils.data.DataLoader)
    assert isinstance(result, CustomDataLoader)
    assert hasattr(result, 'dummy_kwarg')

    # Shuffled DataLoader should also work
    result = trainer.auto_add_sampler(CustomDataLoader(list(range(1000)), shuffle=True), shuffle=True)
    assert isinstance(result, torch.utils.data.DataLoader)
    assert isinstance(result, CustomDataLoader)
    assert hasattr(result, 'dummy_kwarg')

    class CustomSampler(torch.utils.data.Sampler):
        pass

    # Should raise an error if existing sampler is being replaced
    with pytest.raises(MisconfigurationException, match='DistributedSampler'):
        trainer.auto_add_sampler(
            CustomDataLoader(list(range(1000)), sampler=CustomSampler(list(range(1000)))),
            shuffle=True,
        )
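
# The `CustomSampler` above is deliberately left empty because the test only
# checks that it gets rejected. For reference, a minimal working Sampler needs
# `__iter__` and `__len__`; this sketch (the class name is ours, not part of
# the original suite) walks a data source sequentially:
class SequentialListSampler(torch.utils.data.Sampler):
    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        # Yield dataset indices in order; a shuffling sampler would permute them.
        return iter(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)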

# The mock arguments in the signature imply `mock.patch` decorators; they are
# restored here so the test can emulate a 2-GPU machine on any host.
@mock.patch('torch.cuda.device_count', return_value=2)
@mock.patch('torch.cuda.is_available', return_value=True)
def test_combined_data_loader_validation_test(cuda_available_mock, device_count_mock, tmpdir):
    """This test makes sure distributed sampler has been properly injected in
    dataloaders when using CombinedLoader."""

    class CustomDataset(Dataset):
        def __init__(self, data):
            self.data = data

        def __len__(self):
            return len(self.data)

        def __getitem__(self, index):
            return self.data[index]

    dataloader = CombinedLoader(
        {
            "a": DataLoader(CustomDataset(range(10))),
            "b": {"c": DataLoader(CustomDataset(range(10))), "d": DataLoader(CustomDataset(range(10)))},
            "e": [DataLoader(CustomDataset(range(10))), DataLoader(CustomDataset(range(10)))],
        }
    )

    trainer = Trainer(replace_sampler_ddp=True, accelerator="ddp", gpus=2)
    dataloader = trainer.auto_add_sampler(dataloader, shuffle=True)
    _count = 0

    def _assert_distributed_sampler(v):
        nonlocal _count
        _count += 1
        assert isinstance(v, DistributedSampler)

    # Every sampler nested anywhere in the CombinedLoader (5 in total) must
    # have been replaced by a DistributedSampler.
    apply_to_collection(dataloader.sampler, Sampler, _assert_distributed_sampler)
    assert _count == 5
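
# A standalone sketch (our own illustration, not part of the original suite) of
# how `apply_to_collection` walks nested containers, which is what the assertion
# above relies on to visit every sampler inside the CombinedLoader:
def test_apply_to_collection_traversal_sketch():
    nested = {"a": 1, "b": {"c": 2, "d": 3}, "e": [4, 5]}
    # Apply the function to every element matching the dtype, recursing into
    # dicts and lists while preserving the container structure.
    doubled = apply_to_collection(nested, int, lambda x: x * 2)
    assert doubled == {"a": 2, "b": {"c": 4, "d": 6}, "e": [8, 10]}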

# Variant 3: CPU-only via `accelerator='ddp_cpu'`; checks the extra attribute's
# value directly, and adds an attribute that is not a constructor argument to
# confirm re-instantiation ignores it.
def test_dataloader_reinit_for_subclass():
    class CustomDataLoader(DataLoader):
        def __init__(
            self,
            dataset,
            batch_size=1,
            shuffle=False,
            sampler=None,
            batch_sampler=None,
            num_workers=0,
            collate_fn=None,
            pin_memory=False,
            drop_last=False,
            timeout=0,
            worker_init_fn=None,
            dummy_kwarg=None,
        ):
            super().__init__(
                dataset, batch_size, shuffle, sampler, batch_sampler, num_workers,
                collate_fn, pin_memory, drop_last, timeout, worker_init_fn,
            )
            self.dummy_kwarg = dummy_kwarg
            self.something_unrelated = 1

    trainer = Trainer(num_processes=1, accelerator="ddp_cpu")

    class CustomDummyObj:
        sampler = None

    result = trainer.auto_add_sampler(CustomDummyObj(), shuffle=True)
    assert isinstance(result, CustomDummyObj), "Wrongly reinstantiated data loader"

    dataset = list(range(10))
    result = trainer.auto_add_sampler(CustomDataLoader(dataset), shuffle=True)
    assert isinstance(result, DataLoader)
    assert isinstance(result, CustomDataLoader)
    assert result.dummy_kwarg is None

    # Shuffled DataLoader should also work
    result = trainer.auto_add_sampler(CustomDataLoader(dataset, shuffle=True), shuffle=True)
    assert isinstance(result, DataLoader)
    assert isinstance(result, CustomDataLoader)
    assert result.dummy_kwarg is None

    class CustomSampler(Sampler):
        pass

    # Should raise an error if existing sampler is being replaced
    dataloader = CustomDataLoader(dataset, sampler=CustomSampler(dataset))
    with pytest.raises(MisconfigurationException, match="will be replaced by `DistributedSampler`"):
        trainer.auto_add_sampler(dataloader, shuffle=True)
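
# For context on what the replacement sampler actually does: a DistributedSampler
# splits the dataset's indices across processes, so each rank sees a disjoint
# shard. A standalone sketch (our own illustration, with ranks passed explicitly
# so no process group is needed):
def test_distributed_sampler_sharding_sketch():
    dataset = list(range(10))
    shards = [
        list(DistributedSampler(dataset, num_replicas=2, rank=rank, shuffle=False))
        for rank in range(2)
    ]
    # With shuffle disabled, the split is a simple stride over the indices.
    assert shards[0] == [0, 2, 4, 6, 8]
    assert shards[1] == [1, 3, 5, 7, 9]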