Example #1
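All five examples assume the same preamble, which this page omits. The sketch below fills that gap; the `sampler` import path and the exact shape of `MyDataset` are assumptions inferred from how the tests use them (the names match ONNX Runtime's `onnxruntime.training.utils.data.sampler` module).

import random

import torch
from torch.utils.data import Dataset

# Assumed import: `sampler` must provide LoadBalancingDistributedSampler
# and LoadBalancingDistributedBatchSampler, e.g. in ONNX Runtime:
# from onnxruntime.training.utils.data import sampler


class MyDataset(Dataset):
    """Minimal dataset returning stored (sample, complexity) pairs."""

    def __init__(self, samples):
        self.samples = samples

    def __getitem__(self, index):
        return self.samples[index]

    def __len__(self):
        return len(self.samples)

This first test checks that load balancing survives shuffling: every complexity value appears exactly twice in the dataset, so with world_size=2 a balanced split must hand the two equally complex samples to rank 0 and rank 1 at the same position.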
def test_load_balancing_data_sampler_shuffles_and_balances_load():
    complexities = []
    for i in range(50):
        c = torch.randint(0, 100, (1, )).item()
        complexities.append(c)
        complexities.append(c)
    random.shuffle(complexities)

    samples = [torch.FloatTensor([val]) for val in range(100)]
    samples_and_complexities = list(zip(samples, complexities))
    dataset = MyDataset(samples_and_complexities)

    def complexity_fn(sample):
        return sample[1]

    data_sampler0 = sampler.LoadBalancingDistributedSampler(
        dataset,
        complexity_fn=complexity_fn,
        world_size=2,
        rank=0,
        shuffle=True)
    data_sampler1 = sampler.LoadBalancingDistributedSampler(
        dataset,
        complexity_fn=complexity_fn,
        world_size=2,
        rank=1,
        shuffle=True)

    for index0, index1 in zip(data_sampler0, data_sampler1):
        assert samples_and_complexities[index0][1] == \
            samples_and_complexities[index1][1]
Example #2
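With shuffle=False, the test expects each rank to receive its indices sorted by ascending complexity, so it walks both ranks' sequences and asserts that the complexities never decrease.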
def test_load_balancing_data_sampler_balances_load():
    samples_and_complexities = [
        (torch.FloatTensor([val]), torch.randint(0, 100, (1,)).item())
        for val in range(100)
    ]
    dataset = MyDataset(samples_and_complexities)

    def complexity_fn(sample):
        return sample[1]

    data_sampler0 = sampler.LoadBalancingDistributedSampler(
        dataset,
        complexity_fn=complexity_fn,
        world_size=2,
        rank=0,
        shuffle=False)
    data_sampler1 = sampler.LoadBalancingDistributedSampler(
        dataset,
        complexity_fn=complexity_fn,
        world_size=2,
        rank=1,
        shuffle=False)

    largest_complexity = -1
    for index in data_sampler0:
        assert samples_and_complexities[index][1] >= largest_complexity
        largest_complexity = samples_and_complexities[index][1]

    largest_complexity = -1
    for index in data_sampler1:
        assert samples_and_complexities[index][1] >= largest_complexity
        largest_complexity = samples_and_complexities[index][1]
Example #3
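With group_size=8 and no shuffling, the test expects complexities to be sorted within each consecutive group of eight samples while the groups themselves stay in place; it builds that expected ordering by hand and compares the sampler's output index by index.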
def test_load_balancing_data_sampler_sorts_in_groups():
    samples_and_complexities = [
        (torch.FloatTensor([val]), torch.randint(0, 100, (1,)).item())
        for val in range(100)
    ]
    dataset = MyDataset(samples_and_complexities)

    def complexity_fn(sample):
        return sample[1]

    group_size = 8
    samples_and_complexities_sorted = samples_and_complexities.copy()
    for begin_index in range(0, len(samples_and_complexities), group_size):
        end_index = min(begin_index + group_size,
                        len(samples_and_complexities))
        samples_and_complexities_sorted[begin_index:end_index] = sorted(
            samples_and_complexities_sorted[begin_index:end_index],
            key=lambda x: x[1])

    data_sampler = sampler.LoadBalancingDistributedSampler(
        dataset,
        complexity_fn=complexity_fn,
        world_size=1,
        rank=0,
        shuffle=False,
        group_size=group_size)

    for index, sorted_sample in zip(data_sampler,
                                    samples_and_complexities_sorted):
        assert samples_and_complexities[index][1] == sorted_sample[1]
Example #4
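The batch sampler defers batching policy to a user-supplied batch_fn, here one that chops the index list into fixed-size chunks; every batch must therefore be full except possibly the last (100 % 12 = 4).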
def test_load_balancing_batch_sampler_uses_data_sampler():
    samples_and_complexities = [
        (torch.FloatTensor([val]), torch.randint(0, 100, (1,)).item())
        for val in range(100)
    ]
    dataset = MyDataset(samples_and_complexities)

    def complexity_fn(sample):
        return sample[1]

    data_sampler = sampler.LoadBalancingDistributedSampler(
        dataset,
        complexity_fn=complexity_fn,
        world_size=1,
        rank=0,
        shuffle=False)

    batch_size = 12

    def batch_fn(indices):
        batches = []
        for batch_index_begin in range(0, len(indices), batch_size):
            batch_index_end = min(batch_index_begin + batch_size, len(indices))
            batches.append(indices[batch_index_begin:batch_index_end])
        return batches

    batch_sampler = sampler.LoadBalancingDistributedBatchSampler(
        data_sampler, batch_fn)

    for batch in batch_sampler:
        assert len(batch) == batch_size or \
            len(batch) == len(samples_and_complexities) % batch_size
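Example #5

The final test combines group sorting (Example #3) with shuffling (Example #1): complexities are sorted within each group of eight, then whole groups are reordered with torch.randperm seeded to 0 (evidently the sampler's default seed, since the test passes none), and the sampler's output must match this hand-built reference order.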
def test_load_balancing_data_sampler_sorts_and_shuffles_in_groups():
    samples_and_complexities = [
        (torch.FloatTensor([val]), torch.randint(0, 100, (1,)).item())
        for val in range(100)
    ]
    dataset = MyDataset(samples_and_complexities)

    def complexity_fn(sample):
        return sample[1]

    group_size = 8
    samples_and_complexities_sorted = samples_and_complexities.copy()
    for begin_index in range(0, len(samples_and_complexities), group_size):
        end_index = min(begin_index + group_size,
                        len(samples_and_complexities))
        samples_and_complexities_sorted[begin_index:end_index] = sorted(
            samples_and_complexities_sorted[begin_index:end_index],
            key=lambda x: x[1])

    samples_and_complexities_sorted_and_shuffled = samples_and_complexities_sorted.copy()
    shuffled_group_order = torch.randperm(
        (len(samples_and_complexities) + group_size - 1) // group_size,
        generator=torch.Generator().manual_seed(0)).tolist()
    end = 0
    for group_index in shuffled_group_order:
        original_begin = group_index * group_size
        original_end = min(original_begin + group_size,
                           len(samples_and_complexities))
        begin = end
        end = begin + (original_end - original_begin)
        samples_and_complexities_sorted_and_shuffled[begin:end] = \
            samples_and_complexities_sorted[original_begin:original_end]

    data_sampler = sampler.LoadBalancingDistributedSampler(
        dataset,
        complexity_fn=complexity_fn,
        world_size=1,
        rank=0,
        shuffle=True,
        group_size=group_size)

    for index, sorted_and_shuffled_sample in zip(
            data_sampler, samples_and_complexities_sorted_and_shuffled):
        assert (samples_and_complexities[index][1]
                == sorted_and_shuffled_sample[1])
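Taken together, the pieces above plug straight into a standard PyTorch training loop: the batch sampler yields lists of dataset indices (as Example #4 shows), which is exactly what DataLoader's batch_sampler argument expects. Below is a minimal sketch under the same assumed fixture; batches_for_rank is a hypothetical helper, not part of the original tests.

def batches_for_rank(dataset, complexity_fn, rank, world_size, batch_size):
    # Hypothetical helper wiring the samplers into a DataLoader; the
    # sampler arguments mirror the tests above.
    data_sampler = sampler.LoadBalancingDistributedSampler(
        dataset,
        complexity_fn=complexity_fn,
        world_size=world_size,
        rank=rank,
        shuffle=True)

    def batch_fn(indices):
        # Fixed-size chunks, same batching policy as Example #4.
        return [indices[i:i + batch_size]
                for i in range(0, len(indices), batch_size)]

    batch_sampler = sampler.LoadBalancingDistributedBatchSampler(
        data_sampler, batch_fn)
    # The batch sampler yields lists of dataset indices, so it can drive
    # a standard PyTorch DataLoader directly.
    return torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler)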