def test_fixed_bucket_sampler_compactness():
    """Sixteen lengths in two constant-width buckets yield exactly two batches.

    The lengths 16..31 are bucketed with ``ConstWidthBucket`` into
    ``num_buckets=2`` buckets using a batch size of 8. The sampler is expected
    to produce exactly two batches.
    """
    lengths = np.arange(16, 32)
    sampler = s.FixedBucketSampler(lengths, 8, num_buckets=2,
                                   bucket_scheme=nlp.data.ConstWidthBucket())
    batches = [batch for batch in sampler]
    assert len(batches) == 2
def test_fixed_bucket_sampler_with_single_key(bucket_keys, ratio, shuffle):
    """Each of ``N`` random length pairs is sampled exactly once.

    Builds ``N`` pairs of random lengths in [10, 100) and lets the sampler
    pick its own buckets (``num_buckets=None``) from the given
    ``bucket_keys``. It then checks that iterating the sampler emits every
    index once, with no duplicates.
    """
    seq_lengths = [(np.random.randint(10, 100), np.random.randint(10, 100))
                   for _ in range(N)]
    sampler = s.FixedBucketSampler(seq_lengths, batch_size=8,
                                   num_buckets=None, bucket_keys=bucket_keys,
                                   ratio=ratio, shuffle=shuffle)
    print(sampler.stats())
    # Flatten all batches into one list of sample indices.
    sampled_ids = [idx for batch in sampler for idx in batch]
    assert len(set(sampled_ids)) == len(sampled_ids) == N
def test_fixed_bucket_sampler(seq_lengths, ratio, shuffle, num_buckets,
                              bucket_scheme, use_average_length):
    """Every index in ``seq_lengths`` is sampled exactly once.

    NOTE(review): another function later in this module reuses the name
    ``test_fixed_bucket_sampler``. That rebinds the name at import time, so
    pytest only collects the later definition and this one never runs.
    Consider renaming this function or deleting it.
    """
    sampler = s.FixedBucketSampler(
        seq_lengths,
        batch_size=8,
        num_buckets=num_buckets,
        ratio=ratio,
        shuffle=shuffle,
        use_average_length=use_average_length,
        bucket_scheme=bucket_scheme,
    )
    print(sampler)
    seen = []
    for batch in sampler:
        seen += batch
    assert len(set(seen)) == len(seen) == N
def test_fixed_bucket_sampler(seq_lengths, ratio, shuffle, num_buckets,
                              bucket_scheme, use_average_length, num_shards):
    """Exercise ``FixedBucketSampler`` with and without sharding.

    When ``num_shards > 0``, each item yielded by the sampler must have
    exactly ``num_shards`` entries. When ``num_shards == 0``, the flattened
    batches must cover all ``N`` indices exactly once.
    """
    sampler = s.FixedBucketSampler(
        seq_lengths,
        batch_size=8,
        num_buckets=num_buckets,
        ratio=ratio,
        shuffle=shuffle,
        use_average_length=use_average_length,
        bucket_scheme=bucket_scheme,
        num_shards=num_shards,
    )
    print(sampler.stats())
    sharded = num_shards > 0
    total_sampled_ids = []
    for batch_sample_ids in sampler:
        if sharded:
            # The test expects one entry per shard; coverage is not checked here.
            assert len(batch_sample_ids) == num_shards
            continue
        total_sampled_ids.extend(batch_sample_ids)
    if num_shards == 0:
        assert len(set(total_sampled_ids)) == len(total_sampled_ids) == N