def test_auto_dataloader_warning(distributed_context_single_node_gloo):
    """auto_dataloader should warn when ``batch_sampler`` is passed via kwargs."""
    batch_sampler = BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False)
    with pytest.warns(UserWarning, match=r"Found batch_sampler in provided kwargs"):
        auto_dataloader(DummyDS(), batch_sampler=batch_sampler)
def _test_auto_dataloader(ws, nproc, sampler_name=None, dl_type=DataLoader):
    """Check ``auto_dataloader`` scaling on random tensor data.

    Args:
        ws: expected distributed world size.
        nproc: number of processes per node (used to split ``num_workers``).
        sampler_name: ``None`` or ``"WeightedRandomSampler"``.
        dl_type: expected concrete dataloader type returned by ``auto_dataloader``.

    Raises:
        RuntimeError: if ``sampler_name`` is not recognized.
    """
    data = torch.rand(100, 3, 12, 12)

    if sampler_name is None:
        sampler = None
    elif sampler_name == "WeightedRandomSampler":
        sampler = WeightedRandomSampler(weights=torch.ones(100), num_samples=100)
    else:
        # f-string for consistency with the other helpers in this file
        raise RuntimeError(f"Unknown sampler name: {sampler_name}")

    # Test auto_dataloader
    assert idist.get_world_size() == ws
    dataloader = auto_dataloader(data, batch_size=10, num_workers=2, sampler=sampler, shuffle=sampler is None)

    assert isinstance(dataloader, dl_type)
    # batch size is divided across the world; workers are split per local process (ceil)
    assert dataloader.batch_size == 10 // ws
    assert dataloader.num_workers == (2 + nproc - 1) // nproc

    if ws < 2:
        # non-distributed: original sampler (or RandomSampler when shuffling) is kept
        sampler_type = RandomSampler if sampler is None else type(sampler)
        assert isinstance(dataloader.sampler, sampler_type)
    else:
        # distributed: sampler is replaced/wrapped for sharding
        sampler_type = DistributedSampler if sampler is None else DistributedProxySampler
        assert isinstance(dataloader.sampler, sampler_type)
def _test(data):
    """Exercise ``auto_dataloader`` on ``data``.

    Closure over the enclosing scope: ``ws``, ``nproc``, ``batch_size``,
    ``num_workers``, ``sampler_name`` and ``dl_type``.
    """
    if sampler_name is None:
        sampler = None
    elif sampler_name == "WeightedRandomSampler":
        sampler = WeightedRandomSampler(weights=torch.ones(100), num_samples=100)
    elif sampler_name == "DistributedSampler":
        sampler = DistributedSampler(data, num_replicas=ws, rank=idist.get_rank())
    else:
        raise RuntimeError(f"Unknown sampler name: {sampler_name}")

    # Test auto_dataloader
    assert idist.get_world_size() == ws, f"{idist.get_world_size()} vs {ws}"

    # IterableDataset never shuffles; otherwise shuffle only without explicit sampler
    if isinstance(data, IterableDataset):
        shuffle = False
    else:
        shuffle = sampler is None
    dataloader = auto_dataloader(data, batch_size=batch_size, num_workers=num_workers, sampler=sampler, shuffle=shuffle)

    assert isinstance(dataloader, dl_type)
    # unwrap XLA-style wrappers exposing the underlying loader
    if hasattr(dataloader, "_loader"):
        dataloader = dataloader._loader

    expected_bs = batch_size // ws if ws < batch_size else batch_size
    assert dataloader.batch_size == expected_bs

    expected_nw = (num_workers + nproc - 1) // nproc if ws <= num_workers else num_workers
    assert dataloader.num_workers == expected_nw

    # expected sampler type depends on dataset kind and world size
    if isinstance(data, IterableDataset):
        expected_sampler = _InfiniteConstantSampler
    elif ws > 1:
        if sampler is None or isinstance(sampler, DistributedSampler):
            expected_sampler = DistributedSampler
        else:
            expected_sampler = DistributedProxySampler
    else:
        expected_sampler = RandomSampler if sampler is None else type(sampler)
    assert isinstance(dataloader.sampler, expected_sampler)

    if isinstance(dataloader, DataLoader):
        assert dataloader.pin_memory == ("cuda" in idist.device().type)
def _test_auto_dataloader(ws, nproc, batch_size, num_workers=1, sampler_name=None, dl_type=DataLoader):
    """Check ``auto_dataloader`` scaling of batch size, workers and sampler.

    Args:
        ws: expected distributed world size.
        nproc: number of processes per node (used to split ``num_workers``).
        batch_size: requested global batch size.
        num_workers: requested number of dataloader workers.
        sampler_name: ``None`` or ``"WeightedRandomSampler"``.
        dl_type: expected concrete dataloader type returned by ``auto_dataloader``.

    Raises:
        RuntimeError: if ``sampler_name`` is not recognized.
    """
    data = torch.rand(100, 3, 12, 12)

    if sampler_name is None:
        sampler = None
    elif sampler_name == "WeightedRandomSampler":
        sampler = WeightedRandomSampler(weights=torch.ones(100), num_samples=100)
    else:
        # f-string for consistency with the other helpers in this file
        raise RuntimeError(f"Unknown sampler name: {sampler_name}")

    # Test auto_dataloader
    assert idist.get_world_size() == ws, f"{idist.get_world_size()} vs {ws}"
    dataloader = auto_dataloader(
        data, batch_size=batch_size, num_workers=num_workers, sampler=sampler, shuffle=sampler is None
    )

    assert isinstance(dataloader, dl_type)
    # unwrap XLA-style wrappers exposing the underlying loader
    if hasattr(dataloader, "_loader"):
        dataloader = dataloader._loader

    # batch size is divided across the world only when divisible sensibly
    if ws < batch_size:
        assert dataloader.batch_size == batch_size // ws
    else:
        assert dataloader.batch_size == batch_size

    # workers are split per local process (ceil) only when enough workers requested
    if ws <= num_workers:
        assert dataloader.num_workers == (num_workers + nproc - 1) // nproc
    else:
        assert dataloader.num_workers == num_workers

    if ws < 2:
        # non-distributed: original sampler (or RandomSampler when shuffling) is kept
        sampler_type = RandomSampler if sampler is None else type(sampler)
        assert isinstance(dataloader.sampler, sampler_type)
    else:
        # distributed: sampler is replaced/wrapped for sharding
        sampler_type = DistributedSampler if sampler is None else DistributedProxySampler
        assert isinstance(dataloader.sampler, sampler_type)

    if isinstance(dataloader, DataLoader):
        assert dataloader.pin_memory == ("cuda" in idist.device().type)
def test_auto_dataloader_warning_distributed_sampler(distributed_context_single_node_gloo):
    """Warn when a DistributedSampler disagrees with the process rank or world size."""
    ds = DummyDS()
    this_rank = idist.get_rank()
    ws = idist.get_world_size()

    # Consistent sampler: no warning expected
    auto_dataloader(ds, sampler=DistributedSampler(ds, num_replicas=ws, rank=this_rank))

    if ws > 1:
        bad_rank = (this_rank + 1) % ws
        with pytest.warns(
            UserWarning, match=f"Found distributed sampler with rank={bad_rank}, but process rank is {this_rank}"
        ):
            auto_dataloader(ds, sampler=DistributedSampler(ds, num_replicas=ws, rank=bad_rank))

        with pytest.warns(
            UserWarning, match=f"Found distributed sampler with num_replicas={ws + 1}, but world size is {ws}"
        ):
            auto_dataloader(ds, sampler=DistributedSampler(ds, num_replicas=ws + 1, rank=this_rank))
def test_auto_dataloader_warning_tpu():
    """On XLA devices, ``pin_memory=True`` is incompatible and should warn."""
    expected = r"Found incompatible options: xla support and pin_memory"
    with pytest.warns(UserWarning, match=expected):
        auto_dataloader(DummyDS(), pin_memory=True)