Ejemplo n.º 1
0
def test_split_sampler(num_samples, num_parts):
    """Check that the parts of a ``SplitSampler`` together cover every index once.

    One sampler is built per part index in ``range(num_parts)``. Each
    sampler must yield exactly ``len(sampler)`` items, the per-part counts
    must add up to ``num_samples``, and the pooled indices, once sorted,
    must equal ``0 .. num_samples - 1``.
    """
    seen = []
    grand_total = 0
    for part in range(num_parts):
        part_sampler = s.SplitSampler(num_samples, num_parts, part)
        # Plain iteration (not list()) so len() is only queried afterwards.
        drawn = [idx for idx in part_sampler]
        seen += drawn
        grand_total += len(drawn)
        assert len(drawn) == len(part_sampler)
    assert grand_total == num_samples
    assert sorted(seen) == list(range(num_samples))
Ejemplo n.º 2
0
def test_split_sampler_even_size(num_samples, num_parts):
    """Check the total number of samples drawn when ``even_size=True``.

    One ``SplitSampler`` is built per part index in ``range(num_parts)``
    with ``even_size=True``. Each sampler must yield exactly
    ``len(sampler)`` items, and the counts summed over all parts must equal
    ``num_samples`` rounded up to the next multiple of ``num_parts``.

    Only counts are verified here; the yielded index values themselves are
    not inspected.
    """
    total_count = 0
    for part_idx in range(num_parts):
        sampler = s.SplitSampler(num_samples,
                                 num_parts,
                                 part_idx,
                                 even_size=True)
        # Count by iterating; the index values are irrelevant to this test.
        count = sum(1 for _ in sampler)
        total_count += count
        assert count == len(sampler)
    # ceil(num_samples / num_parts) * num_parts, using integer arithmetic.
    expected_count = int(num_samples + num_parts - 1) // num_parts * num_parts
    assert total_count == expected_count, (total_count, expected_count)
Ejemplo n.º 3
0
def test_split_sampler(num_samples, num_parts, repeat):
    """Check that a repeated ``SplitSampler`` covers every index ``repeat`` times.

    One sampler is built per part index in ``range(num_parts)`` with the
    given ``repeat``. Each sampler must yield exactly ``len(sampler)``
    items, the per-part counts must add up to ``num_samples * repeat``, and
    the pooled indices, once sorted, must equal each of
    ``0 .. num_samples - 1`` repeated ``repeat`` times.
    """
    total_count = 0
    indices = []
    for part_idx in range(num_parts):
        sampler = s.SplitSampler(num_samples,
                                 num_parts,
                                 part_idx,
                                 repeat=repeat)
        count = 0
        for i in sampler:
            count += 1
            indices.append(i)
        total_count += count
        assert count == len(sampler)
    assert total_count == num_samples * repeat
    # Exact comparison: np.allclose applies a relative tolerance (1e-5), which
    # would let off-by-one indices slip through once values reach ~1e5.
    assert np.array_equal(sorted(indices),
                          np.repeat(list(range(num_samples)), repeat))