def test_fixed_bucket_sampler_compactness():
    """Check that FixedBucketSampler packs lengths 16..31 into two full batches.

    The sampler gets 16 samples, ``batch_size=8``, two buckets, and a
    ``ConstWidthBucket`` scheme. Exactly two batches can hold all 16 samples
    only if every batch is full. So besides the batch count, the test asserts
    two things:
    - each batch is full;
    - the batches together cover every sample index exactly once.
    """
    lengths = np.arange(16, 32)
    batch_size = 8
    samples = list(
        s.FixedBucketSampler(lengths, batch_size, num_buckets=2,
                             bucket_scheme=s.ConstWidthBucket()))
    assert len(samples) == 2
    # Compactness: no partially filled batch is emitted.
    assert all(len(batch) == batch_size for batch in samples)
    # No sample is dropped or duplicated across batches. int() normalises
    # index types in case the sampler yields numpy integers.
    sampled_ids = sorted(int(idx) for batch in samples for idx in batch)
    assert sampled_ids == list(range(len(lengths)))
reverse=True) sample_ret = list(s.SortedSampler([ele.shape[0] for ele in dataset])) for lhs, rhs in zip(gt_sample_id, sample_ret): assert lhs == rhs @pytest.mark.parametrize( 'seq_lengths', [[np.random.randint(10, 100) for _ in range(N)], [(np.random.randint(10, 100), np.random.randint(10, 100)) for _ in range(N)]]) @pytest.mark.parametrize('ratio', [0.0, 0.5]) @pytest.mark.parametrize('shuffle', [False, True]) @pytest.mark.parametrize('num_buckets', [1, 10, 100, 5000]) @pytest.mark.parametrize( 'bucket_scheme', [s.ConstWidthBucket(), s.LinearWidthBucket(), s.ExpWidthBucket()]) @pytest.mark.parametrize('use_average_length', [False, True]) @pytest.mark.parametrize('num_shards', range(4)) def test_fixed_bucket_sampler(seq_lengths, ratio, shuffle, num_buckets, bucket_scheme, use_average_length, num_shards): with warnings.catch_warnings(): warnings.simplefilter("ignore") sampler = s.FixedBucketSampler(seq_lengths, batch_size=8, num_buckets=num_buckets, ratio=ratio, shuffle=shuffle, use_average_length=use_average_length, bucket_scheme=bucket_scheme,
N = 1000 def test_sorted_sampler(): dataset = data.SimpleDataset([np.random.normal(0, 1, (np.random.randint(10, 100), 1, 1)) for _ in range(N)]) gt_sample_id = sorted(range(len(dataset)), key=lambda i: dataset[i].shape, reverse=True) sample_ret = list(s.SortedSampler([ele.shape[0] for ele in dataset])) for lhs, rhs in zip(gt_sample_id, sample_ret): assert lhs == rhs @pytest.mark.parametrize('seq_lengths', [[np.random.randint(10, 100) for _ in range(N)], [(np.random.randint(10, 100), np.random.randint(10, 100)) for _ in range(N)]]) @pytest.mark.parametrize('ratio', [0.0, 0.5]) @pytest.mark.parametrize('shuffle', [False, True]) @pytest.mark.parametrize('num_buckets', [1, 10, 100, 5000]) @pytest.mark.parametrize('bucket_scheme', [s.ConstWidthBucket(), s.LinearWidthBucket(), s.ExpWidthBucket()]) @pytest.mark.parametrize('use_average_length', [False, True]) @pytest.mark.parametrize('num_shards', range(4)) def test_fixed_bucket_sampler(seq_lengths, ratio, shuffle, num_buckets, bucket_scheme, use_average_length, num_shards): sampler = s.FixedBucketSampler(seq_lengths, batch_size=8, num_buckets=num_buckets, ratio=ratio, shuffle=shuffle, use_average_length=use_average_length, bucket_scheme=bucket_scheme, num_shards=num_shards) print(sampler.stats()) total_sampled_ids = []