Example No. 1
0
def test_fixed_bucket_sampler_compactness():
    """Check that the sampler packs lengths 16..31 into exactly two batches.

    Sixteen lengths are bucketed into two buckets using the
    ``ConstWidthBucket`` scheme with a batch size of 8; the sampler is
    expected to yield exactly two batches.
    """
    lengths = np.arange(16, 32)
    scheme = nlp.data.ConstWidthBucket()
    sampler = s.FixedBucketSampler(lengths, 8, num_buckets=2, bucket_scheme=scheme)
    batches = list(sampler)
    assert len(batches) == 2
def test_fixed_bucket_sampler_with_single_key(bucket_keys, ratio, shuffle):
    """Check that explicit ``bucket_keys`` still sample each of N ids exactly once.

    NOTE(review): despite the "single_key" name, every generated length is a
    pair of ints. Confirm this matches the ``bucket_keys`` parametrization,
    which is not visible here.
    """
    seq_lengths = [(np.random.randint(10, 100), np.random.randint(10, 100))
                   for _ in range(N)]
    sampler = s.FixedBucketSampler(seq_lengths,
                                   batch_size=8,
                                   num_buckets=None,
                                   bucket_keys=bucket_keys,
                                   ratio=ratio,
                                   shuffle=shuffle)
    print(sampler.stats())
    # Flatten every yielded batch into one list of sample ids.
    sampled = [idx for batch in sampler for idx in batch]
    assert len(set(sampled)) == len(sampled) == N
Example No. 3
0
def test_fixed_bucket_sampler(seq_lengths, ratio, shuffle, num_buckets,
                              bucket_scheme, use_average_length):
    """Check that the sampler yields each of the N sample ids exactly once.

    NOTE(review): another function named ``test_fixed_bucket_sampler`` appears
    later in this file and replaces this one at import time, so this body
    is likely never collected. Consider renaming one of them.
    """
    sampler = s.FixedBucketSampler(seq_lengths, batch_size=8,
                                   num_buckets=num_buckets, ratio=ratio,
                                   shuffle=shuffle,
                                   use_average_length=use_average_length,
                                   bucket_scheme=bucket_scheme)
    print(sampler)
    # Flatten every yielded batch into one list of sample ids.
    sampled = [idx for batch in sampler for idx in batch]
    assert len(set(sampled)) == len(sampled) == N
def test_fixed_bucket_sampler(seq_lengths, ratio, shuffle, num_buckets, bucket_scheme,
                              use_average_length, num_shards):
    """Check FixedBucketSampler output with and without sharding.

    When ``num_shards > 0``, each yielded item is expected to have exactly
    ``num_shards`` entries. When ``num_shards == 0``, the sampled ids across
    all batches must be unique, and there must be exactly N of them.
    """
    sampler = s.FixedBucketSampler(seq_lengths,
                                   batch_size=8,
                                   num_buckets=num_buckets,
                                   ratio=ratio,
                                   shuffle=shuffle,
                                   use_average_length=use_average_length,
                                   bucket_scheme=bucket_scheme,
                                   num_shards=num_shards)
    print(sampler.stats())
    if num_shards > 0:
        # Sharded mode: only the per-item shard count is checked.
        for batch in sampler:
            assert len(batch) == num_shards
    else:
        collected = [idx for batch in sampler for idx in batch]
        if num_shards == 0:
            assert len(set(collected)) == len(collected) == N