def test_examples_iterable_shuffle_data_sources(generate_examples_fn):
    """Check that ``shuffle_data_sources`` reorders the ``filepaths`` list.

    An ``ExamplesIterable`` is built over two filepaths, shuffled with a
    generator seeded with 40, and its output is compared with what
    ``generate_examples_fn`` produces when called directly on the swapped
    filepaths.

    Args:
        generate_examples_fn: Callable accepting a ``filepaths`` keyword and
            returning an iterable of examples (presumably supplied as a pytest
            fixture -- confirm against the test module's fixtures).
    """
    source_kwargs = {"filepaths": ["0.txt", "1.txt"]}
    unshuffled = ExamplesIterable(generate_examples_fn, source_kwargs)
    rng = np.random.default_rng(40)
    shuffled = unshuffled.shuffle_data_sources(rng)

    # With this seed the two filepaths are expected to come out swapped.
    swapped_filepaths = ["1.txt", "0.txt"]
    reference = list(generate_examples_fn(filepaths=swapped_filepaths))

    observed = list(shuffled)
    assert observed == reference
def test_examples_iterable_shuffle_shards_and_metadata():
    """Check that parallel per-shard lists are shuffled with one shared permutation.

    ``filepaths`` and ``all_metadata`` are equal-length lists whose i-th
    entries belong together (``"{i}.txt"`` and ``{"id": str(i)}``). After
    ``shuffle_data_sources``, every yielded example must still pair a filepath
    with its own metadata.

    The pairing check alone would also pass if the shuffle were a no-op or
    silently dropped shards, so the test additionally verifies that the output
    is a permutation of the input shards and that the order actually changed.
    """
    num_shards = 100

    def gen(filepaths, all_metadata):
        # Emit one example per (filepath, metadata) pair so that the pairing
        # can be inspected after shuffling.
        for i, (filepath, metadata) in enumerate(zip(filepaths, all_metadata)):
            yield i, {"filepath": filepath, "metadata": metadata}

    ex_iterable = ExamplesIterable(
        gen,
        {
            "filepaths": [f"{i}.txt" for i in range(num_shards)],
            "all_metadata": [{"id": str(i)} for i in range(num_shards)],
        },
    )
    ex_iterable = ex_iterable.shuffle_data_sources(np.random.default_rng(42))
    out = list(ex_iterable)

    # Recover the numeric id carried by each side of the pair.
    filepaths_ids = [x["filepath"].split(".")[0] for _, x in out]
    metadata_ids = [x["metadata"]["id"] for _, x in out]
    assert filepaths_ids == metadata_ids, "entangled lists of shards/metadata should be shuffled the same way"

    # Guard against a vacuous pass: the shuffle must keep every shard exactly
    # once and must not leave the shards in their original order.
    unshuffled_ids = [str(i) for i in range(num_shards)]
    assert sorted(filepaths_ids, key=int) == unshuffled_ids, "shuffling should neither drop nor duplicate shards"
    assert filepaths_ids != unshuffled_ids, "shuffling should change the order of the shards"