Example #1
def test_buckets_by_boundaries(key_values, bucket_boundaries):

    buckets = buckets_by_boundaries(key_values=key_values,
                                    bucket_boundaries=bucket_boundaries)

    num_buckets = len(bucket_boundaries) + 1
    expected_bucket_ids = list(range(num_buckets))

    data_size = key_values.shape[0]
    bucket_ids = [buckets(idx) for idx in range(data_size)]
    key_values = [key_values[idx] for idx in range(data_size)]

    # bucket ids may only range from 0 to num_buckets - 1;
    # a missing id is allowed only if the corresponding bucket is empty
    missing_buckets = list(set(expected_bucket_ids) - set(bucket_ids))
    for bucket_id in missing_buckets:
        if bucket_id == 0:
            assert not any(
                [v < bucket_boundaries[bucket_id] for v in key_values])
        elif bucket_id == expected_bucket_ids[-1]:
            assert not any(
                [v > bucket_boundaries[(bucket_id - 1)] for v in key_values])
        else:
            assert not any([
                v in range(bucket_boundaries[(bucket_id - 1)],
                           bucket_boundaries[bucket_id]) for v in key_values
            ])

    sort_indices = sorted(range(data_size), key=lambda i: key_values[i])
    sorted_buckets = [bucket_ids[i] for i in sort_indices]
    diff_buckets = np.diff(sorted_buckets)

    # are bucket ids monotonically non-decreasing when sorted by key value?
    assert all(diff >= 0 for diff in diff_buckets)
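
The assertions above only state necessary conditions on the key function returned by buckets_by_boundaries: every produced id lies between 0 and num_buckets - 1, a bucket may be missing only if no key value falls into it, and ids are non-decreasing once the key values are sorted. The following is a minimal, self-contained sketch of a boundary-based key function that satisfies those checks, assuming half-open intervals via bisect_right; the names simple_buckets_by_boundaries, lengths and boundaries are illustrative and not the library's API.

import bisect

import numpy as np


def simple_buckets_by_boundaries(key_values, bucket_boundaries):
    """Return a key function mapping a dataset index to a bucket id.

    Values below the first boundary go to bucket 0, values in
    [boundary[i - 1], boundary[i]) go to bucket i, and values at or
    above the last boundary go to bucket len(bucket_boundaries).
    """
    def key_func(index):
        return bisect.bisect_right(bucket_boundaries, key_values[index])
    return key_func


lengths = np.array([3, 7, 12, 5, 20])
boundaries = [5, 10]
key_func = simple_buckets_by_boundaries(lengths, boundaries)
print([key_func(i) for i in range(len(lengths))])  # [0, 1, 2, 1, 2]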
Example #2


@pytest.mark.parametrize("data", [data])
@pytest.mark.parametrize("key_func", [
    buckets_of_even_size(lengths, num_buckets, reverse=False),
    buckets_by_boundaries(lengths, bucket_boundaries)
])
@pytest.mark.parametrize("expected_num_datasets", [num_buckets])
@pytest.mark.parametrize("batch_sizes", [batch_sizes])
@pytest.mark.parametrize("expected_num_batches", [num_batches])
@pytest.mark.parametrize(
    "permuted_order", [False, random.sample(range(num_buckets), num_buckets)])
@pytest.mark.parametrize("shuffle_each_bucket", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_bucket_sampler(data, key_func, expected_num_datasets, batch_sizes,
                        expected_num_batches, permuted_order,
                        shuffle_each_bucket, drop_last):

    subsets = defined_split(data, key_func)
    concat_dataset = ConcatDataset(subsets)