def test_iterable_dataset_shuffle_after_skip_or_take(generate_examples_fn, method):
    """Shuffling a skipped/taken dataset must keep exactly the same examples.

    ``skip``/``take`` select examples by position, so shuffling the dataset
    afterwards must not reorder the underlying shards; otherwise a different
    set of examples would be selected. Only the order of the examples may change.

    Args:
        generate_examples_fn: fixture used as the example generator of ``ExamplesIterable``.
        method: ``"skip"`` to test ``IterableDataset.skip``; any other value tests ``take``.
    """
    seed = 42
    n, n_shards = 3, 10
    # Use the same amount for skip and take. 7 is not a multiple of n, so the
    # skip/take boundary falls inside a shard rather than exactly between two.
    count = 7
    ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n, "filepaths": [f"{i}.txt" for i in range(n_shards)]})
    dataset = IterableDataset(ex_iterable)
    dataset = dataset.skip(count) if method == "skip" else dataset.take(count)
    shuffled_dataset = dataset.shuffle(seed, buffer_size=DEFAULT_N_EXAMPLES)

    examples = list(dataset)
    # Guard against a vacuous pass (two empty lists compare equal).
    # NOTE(review): assumes generate_examples_fn yields `n` examples per filepath,
    # i.e. n * n_shards examples in total — confirm against the fixture.
    expected_len = n * n_shards - count if method == "skip" else count
    assert len(examples) == expected_len

    def key(example):
        # (filepath, id) identifies an example uniquely across shards.
        return f"{example['filepath']}_{example['id']}"

    # shuffling a skip/take dataset should keep the same examples and not shuffle the shards
    assert sorted(examples, key=key) == sorted(shuffled_dataset, key=key)
def test_iterable_dataset_skip(dataset: IterableDataset, n):
    """``skip(n)`` wraps the source in a ``SkipExamplesIterable`` and drops the first ``n`` examples."""
    skipped = dataset.skip(n)
    wrapped = skipped._ex_iterable

    # The skip must be implemented by the dedicated iterable, configured with n.
    assert isinstance(wrapped, SkipExamplesIterable)
    assert wrapped.n == n

    # Iterating the skipped dataset yields the original examples minus the first n.
    actual = list(skipped)
    assert actual == list(dataset)[n:]