Code example #1
0
def test_auto_dataloader_warning(distributed_context_single_node_gloo):
    """auto_dataloader must warn when an explicit batch_sampler is supplied."""
    # Build the sampler up-front so the warning assertion reads cleanly.
    explicit_batch_sampler = BatchSampler(
        SequentialSampler(range(10)), batch_size=3, drop_last=False
    )
    with pytest.warns(UserWarning,
                      match=r"Found batch_sampler in provided kwargs"):
        auto_dataloader(DummyDS(), batch_sampler=explicit_batch_sampler)
Code example #2
0
File: test_auto.py  Project: Joel-hanson/ignite
def _test_auto_dataloader(ws, nproc, sampler_name=None, dl_type=DataLoader):
    """Check that auto_dataloader adapts batch size, workers and sampler.

    Args:
        ws: expected distributed world size.
        nproc: number of processes per node used to split workers.
        sampler_name: None or "WeightedRandomSampler".
        dl_type: expected concrete dataloader class.
    """
    data = torch.rand(100, 3, 12, 12)

    # Resolve the requested sampler; anything unrecognised is an error.
    if sampler_name == "WeightedRandomSampler":
        sampler = WeightedRandomSampler(weights=torch.ones(100),
                                        num_samples=100)
    elif sampler_name is not None:
        raise RuntimeError("Unknown sampler name: {}".format(sampler_name))
    else:
        sampler = None

    # Test auto_dataloader
    assert idist.get_world_size() == ws
    dataloader = auto_dataloader(
        data,
        batch_size=10,
        num_workers=2,
        sampler=sampler,
        shuffle=sampler is None,
    )

    assert isinstance(dataloader, dl_type)
    # Batch size is divided across the world; workers across local procs.
    assert dataloader.batch_size == 10 // ws
    assert dataloader.num_workers == (2 + nproc - 1) // nproc

    if ws >= 2:
        expected_sampler = (DistributedSampler if sampler is None
                            else DistributedProxySampler)
    else:
        expected_sampler = RandomSampler if sampler is None else type(sampler)
    assert isinstance(dataloader.sampler, expected_sampler)
Code example #3
0
File: test_auto.py  Project: gucifer/ignite
    def _test(data):
        """Run the auto_dataloader checks on *data*.

        Closure variables (from the enclosing test): ws, nproc, batch_size,
        num_workers, sampler_name, dl_type.
        """
        # Resolve the requested sampler; unknown names are an error.
        if sampler_name == "WeightedRandomSampler":
            sampler = WeightedRandomSampler(weights=torch.ones(100),
                                            num_samples=100)
        elif sampler_name == "DistributedSampler":
            sampler = DistributedSampler(data,
                                         num_replicas=ws,
                                         rank=idist.get_rank())
        elif sampler_name is not None:
            raise RuntimeError(f"Unknown sampler name: {sampler_name}")
        else:
            sampler = None

        # Test auto_dataloader
        assert idist.get_world_size(
        ) == ws, f"{idist.get_world_size()} vs {ws}"

        # IterableDataset cannot be shuffled; otherwise shuffle only when
        # no explicit sampler was given.
        is_iterable = isinstance(data, IterableDataset)
        shuffle = False if is_iterable else sampler is None
        dataloader = auto_dataloader(
            data,
            batch_size=batch_size,
            num_workers=num_workers,
            sampler=sampler,
            shuffle=shuffle,
        )

        assert isinstance(dataloader, dl_type)
        # Some wrappers (e.g. XLA) keep the real loader under "_loader".
        dataloader = getattr(dataloader, "_loader", dataloader)

        expected_bs = batch_size // ws if ws < batch_size else batch_size
        assert dataloader.batch_size == expected_bs
        if ws <= num_workers:
            expected_workers = (num_workers + nproc - 1) // nproc
        else:
            expected_workers = num_workers
        assert dataloader.num_workers == expected_workers

        if is_iterable:
            expected_sampler = _InfiniteConstantSampler
        elif ws > 1:
            if sampler is None or isinstance(sampler, DistributedSampler):
                expected_sampler = DistributedSampler
            else:
                expected_sampler = DistributedProxySampler
        else:
            expected_sampler = (RandomSampler if sampler is None
                                else type(sampler))

        assert isinstance(dataloader.sampler, expected_sampler)

        if isinstance(dataloader, DataLoader):
            assert dataloader.pin_memory == ("cuda" in idist.device().type)
Code example #4
0
File: test_auto.py  Project: vieozhu/ignite
def _test_auto_dataloader(ws,
                          nproc,
                          batch_size,
                          num_workers=1,
                          sampler_name=None,
                          dl_type=DataLoader):
    """Verify auto_dataloader's batch-size / worker / sampler adaptation.

    Args:
        ws: expected distributed world size.
        nproc: number of processes per node used to split workers.
        batch_size: requested global batch size.
        num_workers: requested worker count.
        sampler_name: None or "WeightedRandomSampler".
        dl_type: expected concrete dataloader class.
    """
    data = torch.rand(100, 3, 12, 12)

    # Resolve the requested sampler; unknown names are an error.
    if sampler_name == "WeightedRandomSampler":
        sampler = WeightedRandomSampler(weights=torch.ones(100),
                                        num_samples=100)
    elif sampler_name is not None:
        raise RuntimeError("Unknown sampler name: {}".format(sampler_name))
    else:
        sampler = None

    # Test auto_dataloader
    assert idist.get_world_size() == ws, "{} vs {}".format(
        idist.get_world_size(), ws)
    dataloader = auto_dataloader(
        data,
        batch_size=batch_size,
        num_workers=num_workers,
        sampler=sampler,
        shuffle=sampler is None,
    )

    assert isinstance(dataloader, dl_type)
    # Some wrappers (e.g. XLA) keep the real loader under "_loader".
    dataloader = getattr(dataloader, "_loader", dataloader)

    expected_bs = batch_size // ws if ws < batch_size else batch_size
    assert dataloader.batch_size == expected_bs
    if ws <= num_workers:
        expected_workers = (num_workers + nproc - 1) // nproc
    else:
        expected_workers = num_workers
    assert dataloader.num_workers == expected_workers

    if ws >= 2:
        expected_sampler = (DistributedSampler if sampler is None
                            else DistributedProxySampler)
    else:
        expected_sampler = RandomSampler if sampler is None else type(sampler)
    assert isinstance(dataloader.sampler, expected_sampler)

    if isinstance(dataloader, DataLoader):
        assert dataloader.pin_memory == ("cuda" in idist.device().type)
Code example #5
0
def test_auto_dataloader_warning_distributed_sampler(distributed_context_single_node_gloo):
    """auto_dataloader must warn on DistributedSampler rank/num_replicas mismatches."""
    dataset = DummyDS()
    rank = idist.get_rank()
    world_size = idist.get_world_size()

    # A correctly configured sampler raises no warning.
    auto_dataloader(dataset, sampler=DistributedSampler(dataset, num_replicas=world_size, rank=rank))

    # Mismatched rank (only meaningful when more than one process exists).
    if world_size > 1:
        wrong_rank = (rank + 1) % world_size
        with pytest.warns(
            UserWarning,
            match=f"Found distributed sampler with rank={wrong_rank}, but process rank is {rank}",
        ):
            auto_dataloader(dataset, sampler=DistributedSampler(dataset, num_replicas=world_size, rank=wrong_rank))

    # Mismatched num_replicas always warns.
    with pytest.warns(
        UserWarning,
        match=f"Found distributed sampler with num_replicas={world_size + 1}, but world size is {world_size}",
    ):
        auto_dataloader(dataset, sampler=DistributedSampler(dataset, num_replicas=world_size + 1, rank=rank))
Code example #6
0
def test_auto_dataloader_warning_tpu():
    """On XLA, requesting pin_memory must trigger an incompatibility warning."""
    expected = r"Found incompatible options: xla support and pin_memory"
    with pytest.warns(UserWarning, match=expected):
        auto_dataloader(DummyDS(), pin_memory=True)