def test_iterable_dataset_shuffle(dataset: IterableDataset, generate_examples_fn, seed, epoch):
    buffer_size = 3
    dataset = deepcopy(dataset)
    dataset._ex_iterable.kwargs["filepaths"] = ["0.txt", "1.txt"]
    dataset = dataset.shuffle(seed, buffer_size=buffer_size)
    assert isinstance(dataset._shuffling, ShufflingConfig)
    assert isinstance(dataset._shuffling.generator, np.random.Generator)
    assert is_rng_equal(dataset._shuffling.generator, np.random.default_rng(seed))
    # For epoch > 0 the effective seed is a draw from the seeded RNG minus the epoch (epoch 0 keeps the base seed)
    if epoch is None or epoch == 0:
        effective_seed = seed
    else:
        dataset.set_epoch(epoch)
        effective_seed = np.random.default_rng(seed).integers(0, 1 << 63) - epoch
    # Shuffling adds a shuffle buffer
    expected_first_example_index = next(
        iter(BufferShuffledExamplesIterable._iter_random_indices(np.random.default_rng(effective_seed), buffer_size))
    )
    assert isinstance(dataset._ex_iterable, BufferShuffledExamplesIterable)
    # It also shuffles the underlying examples iterable
    expected_ex_iterable = ExamplesIterable(
        generate_examples_fn, {"filepaths": ["0.txt", "1.txt"]}
    ).shuffle_data_sources(np.random.default_rng(effective_seed))
    assert isinstance(dataset._ex_iterable.ex_iterable, ExamplesIterable)
    assert next(iter(dataset)) == list(islice(expected_ex_iterable, expected_first_example_index + 1))[-1][1]
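
# A minimal standalone sketch of the shuffle-buffer technique exercised by the test
# above: examples flow through a fixed-size buffer and are emitted from random buffer
# slots, so a given seed always reproduces the same order. This is only an
# illustration of the idea, not the library's BufferShuffledExamplesIterable code.
import numpy as np


def buffer_shuffle(iterable, buffer_size, generator):
    buffer = []
    for item in iterable:
        if len(buffer) < buffer_size:
            buffer.append(item)  # fill the buffer before emitting anything
        else:
            i = generator.integers(0, buffer_size)  # pick a random slot
            yield buffer[i]  # emit its current occupant
            buffer[i] = item  # and replace it with the incoming example
    generator.shuffle(buffer)  # flush whatever is left in random order
    yield from buffer


# The same seed yields the same permutation, which is what makes the test deterministic.
assert list(buffer_shuffle(range(10), 3, np.random.default_rng(42))) == list(
    buffer_shuffle(range(10), 3, np.random.default_rng(42))
)
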
def test_iterable_dataset_set_epoch_of_shuffled_dataset(dataset: IterableDataset, seed, epoch):
    buffer_size = 10
    shuffled_dataset = dataset.shuffle(seed, buffer_size=buffer_size)
    if epoch is not None:
        shuffled_dataset.set_epoch(epoch)
    if seed is None:
        assert shuffled_dataset._effective_seed is None
    else:
        assert shuffled_dataset._effective_seed == seed + (epoch if epoch is not None else 0)
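
# The test above targets a scheme in which the effective shuffle seed is simply the
# base seed plus the epoch, so each epoch reshuffles differently but reproducibly.
# The per_epoch_rng helper below is a hypothetical standalone illustration of that
# arithmetic, not part of the library's API.
import numpy as np


def per_epoch_rng(seed, epoch):
    # epoch 0 reuses the base seed; later epochs shift it deterministically
    return np.random.default_rng(None if seed is None else seed + epoch)


assert per_epoch_rng(42, 0).integers(0, 100) == np.random.default_rng(42).integers(0, 100)
assert (
    per_epoch_rng(42, 1).integers(0, 100, size=8).tolist()
    != per_epoch_rng(42, 2).integers(0, 100, size=8).tolist()
)  # different epochs draw different sequences
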
def test_iterable_dataset_shuffle_after_skip_or_take(generate_examples_fn, method):
    seed = 42
    n, n_shards = 3, 10
    count = 7
    ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n, "filepaths": [f"{i}.txt" for i in range(n_shards)]})
    dataset = IterableDataset(ex_iterable)
    dataset = dataset.skip(n) if method == "skip" else dataset.take(count)
    shuffled_dataset = dataset.shuffle(seed, buffer_size=DEFAULT_N_EXAMPLES)
    # shuffling a skip/take dataset should keep the same examples and not shuffle the shards
    key = lambda x: f"{x['filepath']}_{x['id']}"  # noqa: E731
    assert sorted(dataset, key=key) == sorted(shuffled_dataset, key=key)
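
# Follow-up to the test above: a shuffle buffer only permutes whatever flows through
# it, so shuffling after skip/take yields exactly the same multiset of examples in a
# different order. This reuses the standalone buffer_shuffle sketch defined earlier,
# again only as an illustration rather than the library's implementation.
from itertools import islice

import numpy as np

taken = list(islice(range(100), 7))  # analogous to .take(7) on a stream of ids
reshuffled = list(buffer_shuffle(taken, 10, np.random.default_rng(42)))
assert sorted(reshuffled) == sorted(taken)  # same examples, new order
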
def test_iterable_dataset_set_epoch_of_shuffled_dataset(dataset: IterableDataset, seed, epoch):
    buffer_size = 10
    shuffled_dataset = dataset.shuffle(seed, buffer_size=buffer_size)
    base_generator = shuffled_dataset._shuffling.generator
    if epoch is not None:
        shuffled_dataset.set_epoch(epoch)
    effective_generator = shuffled_dataset._effective_generator()
    assert effective_generator is not None
    if epoch is None or epoch == 0:
        assert is_rng_equal(base_generator, shuffled_dataset._effective_generator())
    else:
        assert not is_rng_equal(base_generator, shuffled_dataset._effective_generator())
        effective_seed = deepcopy(base_generator).integers(0, 1 << 63) - epoch
        assert is_rng_equal(np.random.default_rng(effective_seed), shuffled_dataset._effective_generator())
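
# Standalone sketch of the per-epoch generator derivation checked above: a deep copy
# of the base generator supplies a large seed, and the epoch offsets it, so epoch 0
# keeps the base generator while later epochs get distinct but reproducible ones.
# This mirrors the arithmetic in the assertions above and is only an illustration,
# not the datasets library's internal implementation.
from copy import deepcopy

import numpy as np


def effective_generator(base_generator, epoch):
    if epoch == 0:
        return base_generator
    # drawing from a deepcopy leaves the base generator's own state untouched
    effective_seed = int(deepcopy(base_generator).integers(0, 1 << 63)) - epoch
    # guard against a (vanishingly unlikely) negative seed, which default_rng rejects
    effective_seed = effective_seed + (1 << 63) if effective_seed < 0 else effective_seed
    return np.random.default_rng(effective_seed)


base = np.random.default_rng(42)
assert effective_generator(base, 0) is base
draws_a = effective_generator(base, 1).integers(0, 10, size=5).tolist()
draws_b = effective_generator(base, 1).integers(0, 10, size=5).tolist()
assert draws_a == draws_b  # the derivation never consumes the base generator's state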