def test_interleave_datasets(dataset: IterableDataset, probas, seed, expected_length):
    """Interleave three variants of ``dataset`` and check the merged result.

    The variants are the dataset itself, a mapped copy that adds an ``"id+1"``
    column, and a copy formatted as ``"python"``. The test checks the type of
    the underlying examples iterable, reproducibility for a fixed ``seed``,
    the first yielded example, and the total number of yielded examples.

    Args:
        dataset: Iterable dataset used as the base of every source.
        probas: Sampling probabilities forwarded to ``interleave_datasets``.
        seed: Seed forwarded to ``interleave_datasets``; when ``None`` the
            seed-dependent checks are skipped.
        expected_length: Expected number of merged examples, or ``None`` to
            derive it by replaying the seeded stream of source indices.
    """
    sources = [
        dataset,
        dataset.map(lambda x: {"id+1": x["id"] + 1, **x}),
        dataset.with_format("python"),
    ]
    n_sources = len(sources)
    interleaved = interleave_datasets(sources, probabilities=probas, seed=seed)

    # The merged dataset must be backed by one of the multi-source iterables.
    accepted_iterables = (CyclingMultiSourcesExamplesIterable, RandomlyCyclingMultiSourcesExamplesIterable)
    assert isinstance(interleaved._ex_iterable, accepted_iterables)

    if seed is None:
        # Without a seed, only require that the first merged example matches
        # the first example of at least one source.
        assert any(next(iter(interleaved)) == next(iter(source)) for source in sources)
    else:
        # Same sources, probabilities and seed must yield the same examples.
        replica = interleave_datasets(list(sources), probabilities=probas, seed=seed)
        assert list(interleaved) == list(replica)
        # The first merged example must come from the first source index drawn
        # by a generator created from the same seed.
        index_stream = RandomlyCyclingMultiSourcesExamplesIterable._iter_random_indices(
            np.random.default_rng(seed), n_sources, p=probas
        )
        first_source = sources[next(iter(index_stream))]
        assert next(iter(interleaved)) == next(iter(first_source))

    # Compute the expected length in case it depends on the random draws.
    if expected_length is None:
        remaining = [sum(1 for _ in source) for source in sources]
        expected_length = 0
        draws = RandomlyCyclingMultiSourcesExamplesIterable._iter_random_indices(
            np.random.default_rng(seed), n_sources, p=probas
        )
        for picked in draws:
            # Counting stops the first time the drawn source has nothing left.
            if not remaining[picked]:
                break
            remaining[picked] -= 1
            expected_length += 1

    assert len(list(interleaved)) == expected_length
def test_iterable_dataset_map(dataset: IterableDataset, generate_examples_fn):
    """Check that a non-batched ``map`` wraps the dataset and applies the function.

    Args:
        dataset: Iterable dataset to map over.
        generate_examples_fn: Callable whose iteration yields items where the
            example is at index 1; used to build the expected mapped output.
    """

    def add_one(x):
        return {"id+1": x["id"] + 1}

    mapped = dataset.map(add_one)
    mapped_iterable = mapped._ex_iterable

    # The mapped dataset keeps the exact function object and is not batched.
    assert isinstance(mapped_iterable, MappedExamplesIterable)
    assert mapped_iterable.function is add_one
    assert mapped_iterable.batched is False

    # The first mapped example is the original first example updated with the
    # function's output for the first generated example.
    first_mapped = next(iter(mapped))
    expected = dict(next(iter(dataset)))
    expected.update(add_one(next(iter(generate_examples_fn()))[1]))
    assert first_mapped == expected
def test_iterable_dataset_map_batched(dataset: IterableDataset, generate_examples_fn):
    """Check that a batched ``map`` records its settings and yields single examples.

    Args:
        dataset: Iterable dataset to map over.
        generate_examples_fn: Unused in the body; kept in the signature as-is.
    """

    def add_one_to_batch(x):
        return {"id+1": [i + 1 for i in x["id"]]}

    batch_size = 3
    mapped = dataset.map(add_one_to_batch, batched=True, batch_size=batch_size)
    mapped_iterable = mapped._ex_iterable

    assert isinstance(mapped_iterable, MappedExamplesIterable)
    assert mapped_iterable.function is add_one_to_batch
    assert mapped_iterable.batch_size == batch_size
    # Iterating the batched-mapped dataset still yields one example at a time.
    first_example = next(iter(mapped))
    assert first_example == {"id": 0, "id+1": 1}
def test_iterable_dataset_map_complex_features(dataset: IterableDataset, generate_examples_fn):
    """Check ``map`` on a dataset whose ``label`` column was cast to ``ClassLabel``.

    See https://github.com/huggingface/datasets/issues/3505

    Args:
        dataset: Fixture value; not read, the test builds its own dataset.
        generate_examples_fn: Examples generator given to ``ExamplesIterable``
            together with a constant ``"label"`` keyword argument.
    """
    label_names = ("negative", "positive")
    source_iterable = ExamplesIterable(generate_examples_fn, {"label": "positive"})
    features = Features({"id": Value("int64"), "label": Value("string")})
    info = DatasetInfo(features=features)

    labelled = IterableDataset(source_iterable, info=info)
    labelled = labelled.cast_column("label", ClassLabel(names=list(label_names)))
    mapped = labelled.map(lambda x: {"id+1": x["id"] + 1, **x})
    assert isinstance(mapped._ex_iterable, MappedExamplesIterable)

    # Update the local features so they can encode the raw examples for comparison.
    features["label"] = ClassLabel(names=list(label_names))

    # Ignoring the column added by map, every mapped example must equal the
    # encoded form of the corresponding raw example.
    observed = [
        {column: value for column, value in example.items() if column != "id+1"} for example in mapped
    ]
    encoded = [features.encode_example(example) for _, example in source_iterable]
    assert observed == encoded