def test_interleave_datasets(dataset: IterableDataset, probas, seed, expected_length):
    """Interleave three variants of ``dataset`` and check type, determinism, first example and length.

    The three sources are the dataset itself, a mapped copy that adds an
    ``"id+1"`` column, and a copy formatted as ``"python"``. When
    ``expected_length`` is ``None`` the expected number of merged examples is
    recomputed here by replaying the random source indices with the same seed.
    """
    sources = [
        dataset,
        dataset.map(lambda example: {"id+1": example["id"] + 1, **example}),
        dataset.with_format("python"),
    ]
    num_sources = len(sources)
    merged_dataset = interleave_datasets(sources, probabilities=probas, seed=seed)

    # The merged dataset must be backed by one of the multi-source iterables.
    accepted_iterable_types = (
        CyclingMultiSourcesExamplesIterable,
        RandomlyCyclingMultiSourcesExamplesIterable,
    )
    assert isinstance(merged_dataset._ex_iterable, accepted_iterable_types)

    if seed is None:
        # Without a seed we only require the first merged example to be the
        # first example of at least one of the sources.
        first_matches_a_source = any(
            next(iter(merged_dataset)) == next(iter(source)) for source in sources
        )
        assert first_matches_a_source
    else:
        # Same inputs and same seed must produce the same sequence of examples.
        merged_again = interleave_datasets(list(sources), probabilities=probas, seed=seed)
        assert list(merged_dataset) == list(merged_again)

        # The first merged example must come from the first source index drawn
        # by a generator seeded identically.
        index_rng = np.random.default_rng(seed)
        drawn_indices = RandomlyCyclingMultiSourcesExamplesIterable._iter_random_indices(
            index_rng, num_sources, p=probas
        )
        first_source = sources[next(iter(drawn_indices))]
        assert next(iter(merged_dataset)) == next(iter(first_source))

    # Compute the length in case it's random: replay the index draws and stop
    # at the first draw that lands on a source with no examples left.
    if expected_length is None:
        remaining = [sum(1 for _ in source) for source in sources]
        replay_rng = np.random.default_rng(seed)
        expected_length = 0
        for source_idx in RandomlyCyclingMultiSourcesExamplesIterable._iter_random_indices(
            replay_rng, num_sources, p=probas
        ):
            if not remaining[source_idx]:
                break
            remaining[source_idx] -= 1
            expected_length += 1

    # Check length
    merged_examples = list(merged_dataset)
    assert len(merged_examples) == expected_length
def test_interleave_datasets_with_features(dataset: IterableDataset, generate_examples_fn):
    """Interleave an untyped dataset with one that declares features and inspect the typed source.

    The second source is built from ``generate_examples_fn`` with an explicit
    ``id``/``label`` schema and is given probability 1, the first probability 0.
    """
    features = Features({"id": Value("int64"), "label": ClassLabel(names=["negative", "positive"])})
    typed_info = DatasetInfo(features=features)
    dataset_with_features = IterableDataset(
        ExamplesIterable(generate_examples_fn, {"label": 0}),
        info=typed_info,
    )

    merged_dataset = interleave_datasets([dataset, dataset_with_features], probabilities=[0, 1])

    merged_iterable = merged_dataset._ex_iterable
    assert isinstance(merged_iterable, CyclingMultiSourcesExamplesIterable)

    # The source at index 1 must be wrapped as a typed iterable carrying the
    # exact schema declared above.
    typed_source = merged_iterable.ex_iterables[1]
    assert isinstance(typed_source, TypedExamplesIterable)
    assert typed_source.features == features

    # The first merged example must equal the first example of the typed dataset.
    first_merged = next(iter(merged_dataset))
    first_typed = next(iter(dataset_with_features))
    assert first_merged == first_typed