def test_interleave_datasets(dataset: IterableDataset, probas, seed, expected_length):
    """Interleave three views of ``dataset`` and verify the merged result.

    The checks cover, in order: the type of the underlying examples
    iterable, determinism when a seed is given, the first yielded example,
    and the total number of examples. When ``expected_length`` is ``None``
    the expected count is derived by replaying the random source indices.
    """
    plain = dataset
    with_extra_column = dataset.map(lambda x: {"id+1": x["id"] + 1, **x})
    python_formatted = dataset.with_format("python")
    datasets = [plain, with_extra_column, python_formatted]
    merged_dataset = interleave_datasets(datasets, probabilities=probas, seed=seed)

    # The merged dataset must be backed by one of the cycling iterables.
    accepted_iterables = (
        CyclingMultiSourcesExamplesIterable,
        RandomlyCyclingMultiSourcesExamplesIterable,
    )
    assert isinstance(merged_dataset._ex_iterable, accepted_iterables)

    if seed is not None:
        # Same inputs and same seed must yield the same sequence.
        merged_again = interleave_datasets(
            [plain, with_extra_column, python_formatted],
            probabilities=probas,
            seed=seed,
        )
        assert list(merged_dataset) == list(merged_again)

        # The first example must come from the first randomly drawn source.
        rng = np.random.default_rng(seed)
        index_stream = RandomlyCyclingMultiSourcesExamplesIterable._iter_random_indices(
            rng, len(datasets), p=probas
        )
        first_source = next(iter(index_stream))
        assert next(iter(merged_dataset)) == next(iter(datasets[first_source]))
    else:
        # Without a seed, the first example must match the first example
        # of at least one of the sources.
        assert any(
            next(iter(merged_dataset)) == next(iter(source)) for source in datasets
        )

    if expected_length is None:
        # Compute the length in case it's random: replay the index draws
        # and stop at the first draw that hits an exhausted source.
        total = 0
        remaining = [len(list(source)) for source in datasets]
        rng = np.random.default_rng(seed)
        for source_idx in RandomlyCyclingMultiSourcesExamplesIterable._iter_random_indices(
            rng, len(datasets), p=probas
        ):
            if counts_exhausted := (remaining[source_idx] == 0):
                break
            remaining[source_idx] -= 1
            total += 1
        expected_length = total

    # Check length
    assert len(list(merged_dataset)) == expected_length
def test_iterable_dataset_map(dataset: IterableDataset, generate_examples_fn):
    """Check a non-batched ``map`` on an iterable dataset.

    Asserts that the mapped dataset is backed by a ``MappedExamplesIterable``
    holding the exact function object with ``batched`` set to ``False``, and
    that the first mapped example is the first source example updated with
    the function's output.
    """

    def add_one(example):
        return {"id+1": example["id"] + 1}

    mapped = dataset.map(add_one)

    ex_iterable = mapped._ex_iterable
    assert isinstance(ex_iterable, MappedExamplesIterable)
    assert ex_iterable.function is add_one
    assert ex_iterable.batched is False

    first_mapped = next(iter(mapped))
    first_source = next(iter(dataset))
    # generate_examples_fn yields pairs; index 1 is the example itself.
    first_generated = next(iter(generate_examples_fn()))[1]
    expected = {**first_source, **add_one(first_generated)}
    assert first_mapped == expected
def test_iterable_dataset_map_batched(dataset: IterableDataset, generate_examples_fn):
    """Check a batched ``map`` on an iterable dataset.

    Asserts that the mapped dataset is backed by a ``MappedExamplesIterable``
    holding the exact function object and the requested batch size, and that
    the first yielded example is ``{"id": 0, "id+1": 1}``.
    """

    def increment_ids(batch):
        return {"id+1": [value + 1 for value in batch["id"]]}

    batch_size = 3
    mapped = dataset.map(increment_ids, batched=True, batch_size=batch_size)

    ex_iterable = mapped._ex_iterable
    assert isinstance(ex_iterable, MappedExamplesIterable)
    assert ex_iterable.function is increment_ids
    assert ex_iterable.batch_size == batch_size

    first_example = next(iter(mapped))
    assert first_example == {"id": 0, "id+1": 1}
def test_iterable_dataset_map_complex_features(dataset: IterableDataset, generate_examples_fn):
    """Check that ``map`` works after casting a column to ``ClassLabel``.

    Regression test for https://github.com/huggingface/datasets/issues/3505.
    The ``dataset`` fixture is replaced by a dataset built locally from
    ``generate_examples_fn`` with a constant ``"label"`` value.
    """
    ex_iterable = ExamplesIterable(generate_examples_fn, {"label": "positive"})
    features = Features({"id": Value("int64"), "label": Value("string")})
    info = DatasetInfo(features=features)

    dataset = IterableDataset(ex_iterable, info=info)
    dataset = dataset.cast_column("label", ClassLabel(names=["negative", "positive"]))
    dataset = dataset.map(lambda x: {"id+1": x["id"] + 1, **x})
    assert isinstance(dataset._ex_iterable, MappedExamplesIterable)

    # Update the reference features to the cast type so they can encode
    # the raw examples into the expected values.
    features["label"] = ClassLabel(names=["negative", "positive"])

    # Drop the column added by map so only the original columns are compared.
    observed = [
        {column: value for column, value in example.items() if column != "id+1"}
        for example in dataset
    ]
    expected = [features.encode_example(raw) for _, raw in ex_iterable]
    assert observed == expected