def test_iterable_dataset_cast(generate_examples_fn):
    """Check that ``IterableDataset.cast`` yields examples encoded with the new features.

    The source examples carry ``label=10`` typed as ``int64``; the dataset is
    cast so that ``label`` becomes a ``bool`` column.
    """
    ex_iterable = ExamplesIterable(generate_examples_fn, {"label": 10})
    source_features = Features({"id": Value("int64"), "label": Value("int64")})
    dataset = IterableDataset(ex_iterable, info=DatasetInfo(features=source_features))
    target_features = Features({"id": Value("int64"), "label": Value("bool")})
    casted_dataset = dataset.cast(target_features)
    actual = list(casted_dataset)
    # Reference result: every raw example encoded with the target schema.
    expected = [target_features.encode_example(example) for _, example in ex_iterable]
    assert actual == expected
def test_iterable_dataset_shuffle(dataset: IterableDataset, generate_examples_fn, seed, epoch):
    """Check the shuffling config, the wrapping iterables and the first shuffled example.

    The expected first example is rebuilt independently from the same seed:
    the data sources are shuffled and the first random buffer index is drawn
    with generators seeded from the effective seed.
    """
    buffer_size = 3
    shard_filepaths = ("0.txt", "1.txt")
    # Mutate a deep copy rather than the `dataset` argument itself.
    dataset = deepcopy(dataset)
    dataset._ex_iterable.kwargs["filepaths"] = list(shard_filepaths)
    dataset = dataset.shuffle(seed, buffer_size=buffer_size)
    assert isinstance(dataset._shuffling, ShufflingConfig)
    assert isinstance(dataset._shuffling.generator, np.random.Generator)
    assert is_rng_equal(dataset._shuffling.generator, np.random.default_rng(seed))
    # With no epoch (or epoch 0) the seed is used as is. Otherwise the effective
    # seed is an integer drawn from a generator seeded with `seed`, minus the epoch.
    if epoch is not None and epoch != 0:
        dataset.set_epoch(epoch)
        effective_seed = np.random.default_rng(seed).integers(0, 1 << 63) - epoch
    else:
        effective_seed = seed
    # Shuffling adds a shuffle buffer: draw the first index it is expected to pick.
    random_indices = BufferShuffledExamplesIterable._iter_random_indices(
        np.random.default_rng(effective_seed), buffer_size
    )
    expected_first_example_index = next(iter(random_indices))
    assert isinstance(dataset._ex_iterable, BufferShuffledExamplesIterable)
    # Shuffling also shuffles the data sources of the underlying examples iterable.
    unshuffled_ex_iterable = ExamplesIterable(generate_examples_fn, {"filepaths": list(shard_filepaths)})
    expected_ex_iterable = unshuffled_ex_iterable.shuffle_data_sources(np.random.default_rng(effective_seed))
    assert isinstance(dataset._ex_iterable.ex_iterable, ExamplesIterable)
    first_example = next(iter(dataset))
    expected_prefix = list(islice(expected_ex_iterable, expected_first_example_index + 1))
    # Items of the examples iterable are (key, example) pairs; compare the example only.
    assert first_example == expected_prefix[-1][1]
def test_interleave_datasets(dataset: IterableDataset, probas, seed, expected_length):
    """Check the iterable type, determinism, first example and length of an interleaved dataset.

    Three sources are interleaved: the dataset itself, a mapped version of it
    and a version formatted as "python". When ``expected_length`` is ``None``
    the expected length is computed by replaying the random source choice.
    """
    sources = [
        dataset,
        dataset.map(lambda x: {"id+1": x["id"] + 1, **x}),
        dataset.with_format("python"),
    ]
    merged_dataset = interleave_datasets(sources, probabilities=probas, seed=seed)
    # Check the examples iterable
    accepted_types = (CyclingMultiSourcesExamplesIterable, RandomlyCyclingMultiSourcesExamplesIterable)
    assert isinstance(merged_dataset._ex_iterable, accepted_types)
    if seed is None:
        # No seed: only require the first example to come from one of the sources.
        assert any(next(iter(merged_dataset)) == next(iter(source)) for source in sources)
    else:
        # Same inputs and same seed must give the same sequence of examples.
        merged_again = interleave_datasets(list(sources), probabilities=probas, seed=seed)
        assert list(merged_dataset) == list(merged_again)
        # The first example must come from the first randomly chosen source.
        rng = np.random.default_rng(seed)
        random_indices = RandomlyCyclingMultiSourcesExamplesIterable._iter_random_indices(
            rng, len(sources), p=probas
        )
        first_source_idx = next(iter(random_indices))
        assert next(iter(merged_dataset)) == next(iter(sources[first_source_idx]))
    # Compute the length in case it's random
    if expected_length is None:
        expected_length = 0
        remaining = [len(list(source)) for source in sources]
        rng = np.random.default_rng(seed)
        # Replay the source choices until a chosen source has no example left.
        for source_idx in RandomlyCyclingMultiSourcesExamplesIterable._iter_random_indices(
            rng, len(sources), p=probas
        ):
            if remaining[source_idx] == 0:
                break
            remaining[source_idx] -= 1
            expected_length += 1
    # Check length
    assert len(list(merged_dataset)) == expected_length
def test_iterable_dataset_shuffle_after_skip_or_take(generate_examples_fn, method):
    """Check that shuffling after ``skip``/``take`` keeps the same set of examples.

    ``method`` selects the operation: "skip" skips ``n`` examples, any other
    value takes ``count`` examples.
    """
    seed = 42
    n, n_shards = 3, 10
    count = 7
    shard_filepaths = [f"{i}.txt" for i in range(n_shards)]
    ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n, "filepaths": shard_filepaths})
    dataset = IterableDataset(ex_iterable)
    if method == "skip":
        dataset = dataset.skip(n)
    else:
        dataset = dataset.take(count)
    shuffled_dataset = dataset.shuffle(seed, buffer_size=DEFAULT_N_EXAMPLES)

    def example_key(x):
        # Sort key built from the example's source file and id.
        return f"{x['filepath']}_{x['id']}"

    # Shuffling a skip/take dataset should keep the same examples and not shuffle the shards,
    # so both datasets must be equal once sorted.
    assert sorted(dataset, key=example_key) == sorted(shuffled_dataset, key=example_key)
def test_iterable_dataset_map(dataset: IterableDataset, generate_examples_fn):
    """Check that a non-batched ``map`` wraps the iterable and merges the function output."""

    def func(x):
        return {"id+1": x["id"] + 1}

    mapped_dataset = dataset.map(func)
    mapped_ex_iterable = mapped_dataset._ex_iterable
    assert isinstance(mapped_ex_iterable, MappedExamplesIterable)
    assert mapped_ex_iterable.function is func
    assert mapped_ex_iterable.batched is False
    first_mapped = next(iter(mapped_dataset))
    # Expected: the original first example updated with the function's output
    # on the first raw example (items of the generator are (key, example) pairs).
    expected = dict(next(iter(dataset)))
    expected.update(func(next(iter(generate_examples_fn()))[1]))
    assert first_mapped == expected
def test_iterable_dataset_with_format(dataset: IterableDataset, format_type):
    """Check that ``with_format`` records the format type on the returned dataset.

    For the "torch" format the result must also be a
    ``torch.utils.data.IterableDataset``.
    """
    formatted_dataset = dataset.with_format(format_type)
    assert formatted_dataset._format_type == format_type
    if format_type != "torch":
        return

    # Imported lazily so torch is only needed for the "torch" case.
    import torch

    assert isinstance(formatted_dataset, torch.utils.data.IterableDataset)
def test_iterable_dataset_map_batched(dataset: IterableDataset, generate_examples_fn):
    """Check that a batched ``map`` keeps the function and batch size and yields single examples."""

    def func(x):
        # In batched mode `x["id"]` is iterated as a collection of ids.
        return {"id+1": [i + 1 for i in x["id"]]}

    batch_size = 3
    batched_dataset = dataset.map(func, batched=True, batch_size=batch_size)
    mapped_ex_iterable = batched_dataset._ex_iterable
    assert isinstance(mapped_ex_iterable, MappedExamplesIterable)
    assert mapped_ex_iterable.function is func
    assert mapped_ex_iterable.batch_size == batch_size
    first_example = next(iter(batched_dataset))
    assert first_example == {"id": 0, "id+1": 1}
def test_iterable_dataset_info(generate_examples_fn):
    """Check that the dataset exposes its ``DatasetInfo`` and selected fields of it."""
    info = DatasetInfo(description="desc", citation="@article{}", size_in_bytes=42)
    dataset = IterableDataset(ExamplesIterable(generate_examples_fn, {}), info=info)
    assert dataset.info == info
    # These info fields must be readable directly on the dataset.
    for field_name in ("description", "citation", "size_in_bytes"):
        assert getattr(dataset, field_name) == getattr(info, field_name)
def test_iterable_dataset_features(generate_examples_fn, features):
    """Check that examples are encoded with ``features`` when some are given.

    When ``features`` is falsy the dataset must yield the raw examples.
    """
    ex_iterable = ExamplesIterable(generate_examples_fn, {"label": 0})
    dataset = IterableDataset(ex_iterable, info=DatasetInfo(features=features))
    if not features:
        expected = [example for _, example in ex_iterable]
    else:
        expected = [features.encode_example(example) for _, example in ex_iterable]
    actual = list(dataset)
    assert actual == expected
def test_iterable_dataset_map_complex_features(dataset: IterableDataset, generate_examples_fn):
    """Check ``map`` after casting a string column to ``ClassLabel``.

    Regression test for https://github.com/huggingface/datasets/issues/3505
    """
    ex_iterable = ExamplesIterable(generate_examples_fn, {"label": "positive"})
    features = Features({"id": Value("int64"), "label": Value("string")})
    # The `dataset` argument is replaced by a dataset built from the typed examples.
    dataset = IterableDataset(ex_iterable, info=DatasetInfo(features=features))
    dataset = dataset.cast_column("label", ClassLabel(names=["negative", "positive"]))
    dataset = dataset.map(lambda x: {"id+1": x["id"] + 1, **x})
    assert isinstance(dataset._ex_iterable, MappedExamplesIterable)
    # Build the reference schema only now, after the pipeline has been set up.
    features["label"] = ClassLabel(names=["negative", "positive"])
    # Drop the column added by `map` before comparing with the encoded source examples.
    without_mapped_column = [
        {name: value for name, value in example.items() if name != "id+1"} for example in dataset
    ]
    expected = [features.encode_example(example) for _, example in ex_iterable]
    assert without_mapped_column == expected
# Example #11
# 0
def test_iterable_dataset_set_epoch_of_shuffled_dataset(dataset: IterableDataset, seed, epoch):
    """Check the effective seed of a shuffled dataset after optionally setting the epoch.

    Without a seed the effective seed must be ``None``; otherwise it must be
    the seed plus the epoch, with a missing epoch counting as 0.
    """
    shuffled_dataset = dataset.shuffle(10, seed=seed)
    epoch_offset = 0
    if epoch is not None:
        shuffled_dataset.set_epoch(epoch)
        epoch_offset = epoch
    if seed is not None:
        assert shuffled_dataset._effective_seed == seed + epoch_offset
    else:
        assert shuffled_dataset._effective_seed is None
# Example #12
# 0
def dataset_with_several_columns(generate_examples_fn):
    """Build a "train" ``IterableDataset`` whose generator receives file paths and metadata."""
    generator_kwargs = {
        "filepath": ["data0.txt", "data1.txt", "data2.txt"],
        "metadata": {"sources": ["https://foo.bar"]},
    }
    ex_iterable = ExamplesIterable(generate_examples_fn, generator_kwargs)
    info = DatasetInfo(description="dummy")
    return IterableDataset(ex_iterable, info=info, split="train")
def test_iterable_dataset_set_epoch_of_shuffled_dataset(dataset: IterableDataset, seed, epoch):
    """Check the effective generator of a shuffled dataset after optionally setting the epoch.

    With no epoch (or epoch 0) the effective generator must equal the base
    shuffling generator. Otherwise it must equal a generator seeded with an
    integer drawn from a copy of the base generator, minus the epoch.
    """
    shuffled_dataset = dataset.shuffle(seed, buffer_size=10)
    base_generator = shuffled_dataset._shuffling.generator
    if epoch is not None:
        shuffled_dataset.set_epoch(epoch)
    assert shuffled_dataset._effective_generator() is not None
    if epoch is not None and epoch != 0:
        assert not is_rng_equal(base_generator, shuffled_dataset._effective_generator())
        # Draw from a copy so the base generator's state is left untouched.
        effective_seed = deepcopy(base_generator).integers(0, 1 << 63) - epoch
        expected_generator = np.random.default_rng(effective_seed)
        assert is_rng_equal(expected_generator, shuffled_dataset._effective_generator())
    else:
        assert is_rng_equal(base_generator, shuffled_dataset._effective_generator())
def test_interleave_datasets_with_features(dataset: IterableDataset, generate_examples_fn):
    """Check interleaving a dataset without features with one that has features.

    The second source must be wrapped in a ``TypedExamplesIterable`` carrying
    its features. With probabilities ``[0, 1]`` the first merged example must
    be the first example of the typed dataset.
    """
    label_feature = ClassLabel(names=["negative", "positive"])
    features = Features({"id": Value("int64"), "label": label_feature})
    ex_iterable = ExamplesIterable(generate_examples_fn, {"label": 0})
    dataset_with_features = IterableDataset(ex_iterable, info=DatasetInfo(features=features))

    sources = [dataset, dataset_with_features]
    merged_dataset = interleave_datasets(sources, probabilities=[0, 1])
    assert isinstance(merged_dataset._ex_iterable, CyclingMultiSourcesExamplesIterable)
    typed_source = merged_dataset._ex_iterable.ex_iterables[1]
    assert isinstance(typed_source, TypedExamplesIterable)
    assert typed_source.features == features
    first_merged = next(iter(merged_dataset))
    assert first_merged == next(iter(dataset_with_features))
def test_iterable_dataset(generate_examples_fn):
    """Check that iterating the dataset yields the generator's examples, in order."""
    ex_iterable = ExamplesIterable(generate_examples_fn, {})
    dataset = IterableDataset(ex_iterable)
    # The generator yields (key, example) pairs; only the examples are expected.
    expected = [example for _, example in generate_examples_fn()]
    first_example = next(iter(dataset))
    assert first_example == expected[0]
    all_examples = list(dataset)
    assert all_examples == expected
def test_iterable_dataset_set_epoch(dataset: IterableDataset):
    """Check that the epoch starts at 0 and that ``set_epoch`` updates it."""
    new_epoch = 42
    assert dataset._epoch == 0
    dataset.set_epoch(new_epoch)
    assert dataset._epoch == new_epoch
def test_iterable_dataset_take(dataset: IterableDataset, n):
    """Check that ``take(n)`` wraps the iterable and yields the first ``n`` examples."""
    take_dataset = dataset.take(n)
    take_ex_iterable = take_dataset._ex_iterable
    assert isinstance(take_ex_iterable, TakeExamplesIterable)
    assert take_ex_iterable.n == n
    taken = list(take_dataset)
    all_examples = list(dataset)
    assert taken == all_examples[:n]
def test_iterable_dataset_skip(dataset: IterableDataset, n):
    """Check that ``skip(n)`` wraps the iterable and yields everything after the first ``n`` examples."""
    skip_dataset = dataset.skip(n)
    skip_ex_iterable = skip_dataset._ex_iterable
    assert isinstance(skip_ex_iterable, SkipExamplesIterable)
    assert skip_ex_iterable.n == n
    remaining = list(skip_dataset)
    all_examples = list(dataset)
    assert remaining == all_examples[n:]
def dataset(generate_examples_fn):
    """Build a "train" ``IterableDataset`` over ``generate_examples_fn`` called with no kwargs."""
    info = DatasetInfo(description="dummy")
    return IterableDataset(ExamplesIterable(generate_examples_fn, {}), info=info, split="train")