Beispiel #1
0
def test_dist_proxy_sampler():
    """Check DistributedProxySampler against the sampler it wraps.

    For 100 different seeds, the indices collected from all rank-specific
    proxies must form the same set as the indices drawn directly from the
    wrapped WeightedRandomSampler after seeding with the same value.
    Afterwards, two invalid ``sampler`` arguments are expected to raise
    ``TypeError``.
    """
    world_size = 8
    draws = 200

    # First 50 of the 100 entries get weight 2, the remaining 50 keep weight 1.
    probs = torch.ones(100)
    probs[:50] += 1
    base = WeightedRandomSampler(probs, draws)

    proxies = []
    for rank in range(world_size):
        proxies.append(DistributedProxySampler(base, num_replicas=world_size, rank=rank))

    for epoch in range(100):
        # Seed the global torch RNG, then draw the reference indices.
        torch.manual_seed(epoch)
        set_true_indices = set(base)

        # Every proxy is given the same value as epoch that was used as seed above.
        set_indices_per_rank = set()
        for proxy in proxies:
            proxy.set_epoch(epoch)
            set_indices_per_rank.update(proxy)

        # On failure, report what is missing on each side of the comparison.
        assert (
            set_indices_per_rank == set_true_indices
        ), f"{set_true_indices - set_indices_per_rank} | {set_indices_per_rank - set_true_indices}"

    # A non-Sampler argument must be rejected.
    with pytest.raises(TypeError, match=r"Argument sampler should be instance of torch Sampler"):
        DistributedProxySampler(None)

    # A bare Sampler instance must be rejected with the "should have length" message.
    with pytest.raises(TypeError, match=r"Argument sampler should have length"):
        DistributedProxySampler(Sampler([1]))
Beispiel #2
0
def test_dist_proxy_sampler():
    """Check that the rank-wise proxies jointly cover the wrapped sampler.

    For each of 100 seeds, the union of indices yielded by the eight
    DistributedProxySampler instances must equal the set of indices drawn
    directly from the wrapped WeightedRandomSampler after seeding with the
    same value.
    """
    import torch
    from torch.utils.data import WeightedRandomSampler

    world_size = 8
    draws = 200

    # Weight 2 for the first 50 entries, weight 1 for the other 50.
    probs = torch.ones(100)
    probs[:50] += 1
    base = WeightedRandomSampler(probs, draws)

    proxies = []
    for rank in range(world_size):
        proxy = DistributedProxySampler(base, num_replicas=world_size, rank=rank)
        proxies.append(proxy)

    for epoch in range(100):
        # Reference draw from the wrapped sampler under a known seed.
        torch.manual_seed(epoch)
        expected = set(base)

        # Collect what each rank yields once it is told the matching epoch.
        collected = set()
        for proxy in proxies:
            proxy.set_epoch(epoch)
            collected.update(proxy)

        # The message lists indices missing from the ranks, then extra ones.
        missing = expected - collected
        extra = collected - expected
        assert collected == expected, "{} | {}".format(missing, extra)
Beispiel #3
0
def test_dist_proxy_sampler():
    """Single-seed check of DistributedProxySampler over four ranks.

    With seed 0 and epoch 0, the indices gathered from all four proxies must
    form the same set as the indices drawn directly from the wrapped
    WeightedRandomSampler.
    """
    import torch
    from torch.utils.data import WeightedRandomSampler

    world_size = 4
    draws = 100

    # Weight 2 for the first 50 entries, weight 1 for the other 50.
    probs = torch.ones(100)
    probs[:50] += 1
    base = WeightedRandomSampler(probs, draws)

    proxies = []
    for rank in range(world_size):
        proxies.append(DistributedProxySampler(base, num_replicas=world_size, rank=rank))

    # Reference draw from the wrapped sampler under a known seed.
    torch.manual_seed(0)
    expected = set(base)

    # Gather every rank's share after setting the epoch to the same value.
    collected = set()
    for proxy in proxies:
        proxy.set_epoch(0)
        collected.update(proxy)

    assert collected == expected