def test_split_odd(manifest_type):
    """Split 100 items into 3 parts with the default (``drop_last=False``) behaviour.

    No item is dropped. The leftover item from ``100 % 3`` goes to the first
    shard, giving sizes 34, 33, 33. These boundaries must match the
    ``drop_last=False`` branch of ``test_split_odd_1``, which checks the same
    input. The previous expectation (34, 34, 32) contradicted that test, so
    both could never pass at once.
    """
    manifest = DummyManifest(manifest_type, begin_id=0, end_id=100)
    manifest_subsets = manifest.split(num_splits=3)
    assert len(manifest_subsets) == 3
    assert manifest_subsets[0] == DummyManifest(manifest_type, begin_id=0, end_id=34)
    assert manifest_subsets[1] == DummyManifest(manifest_type, begin_id=34, end_id=67)
    assert manifest_subsets[2] == DummyManifest(manifest_type, begin_id=67, end_id=100)
def test_split_randomize(manifest_type):
    """Check that a shuffled 2-way split keeps every item but changes their order."""
    manifest = DummyManifest(manifest_type, begin_id=0, end_id=100)
    halves = manifest.split(num_splits=2, shuffle=True)
    assert len(halves) == 2

    # Materialise as plain lists: the *Set classes might re-order items
    # internally after concatenation (e.g. dict storage or post-init sorting),
    # which would hide the shuffle.
    rejoined = [*halves[0], *halves[1]]
    assert len(rejoined) == len(manifest)

    # The concatenated halves should not come back in the original order.
    assert rejoined != list(manifest)
def test_split_even(manifest_type):
    """Check that 100 items split into 2 parts gives two contiguous halves of 50."""
    manifest = DummyManifest(manifest_type, begin_id=0, end_id=100)
    parts = manifest.split(num_splits=2)
    assert len(parts) == 2

    expected_bounds = [(0, 50), (50, 100)]
    for part, (start, stop) in zip(parts, expected_bounds):
        assert part == DummyManifest(manifest_type, begin_id=start, end_id=stop)
def test_sequential_jsonl_writer_overwrite(overwrite):
    """Check that the sequential writer honours ``overwrite`` for an existing file.

    The first half of a dummy CutSet is written to a JSONL file. A writer is
    then opened on that same path. With ``overwrite=True`` it must report none
    of the previously stored cut IDs. With ``overwrite=False`` it must report
    all of them.
    """
    cuts = DummyManifest(CutSet, begin_id=0, end_id=100)
    stored = cuts.split(num_splits=2)[0]

    with NamedTemporaryFile(suffix='.jsonl') as tmp:
        # Persist the first half before opening the writer on the same path.
        stored.to_file(tmp.name)

        with CutSet.open_writer(tmp.name, overwrite=overwrite) as writer:
            if not overwrite:
                assert all(writer.contains(cut_id) for cut_id in stored.ids)
            else:
                assert not any(writer.contains(cut_id) for cut_id in stored.ids)
def test_split_odd_2(manifest_type, drop_last):
    """Split 32 items into 3 parts, with and without dropping the remainder.

    With ``drop_last`` every shard has 32 // 3 == 10 items and the last 2 items
    are discarded. Without it, the 2 leftover items go to the first two
    shards (sizes 11, 11, 10).
    """
    manifest = DummyManifest(manifest_type, begin_id=0, end_id=32)
    shards = manifest.split(num_splits=3, drop_last=drop_last)
    assert len(shards) == 3

    expected_bounds = (
        [(0, 10), (10, 20), (20, 30)]
        if drop_last
        else [(0, 11), (11, 22), (22, 32)]
    )
    for shard, (start, stop) in zip(shards, expected_bounds):
        assert shard == DummyManifest(manifest_type, begin_id=start, end_id=stop)
def test_split_odd_1(manifest_type, drop_last):
    """Split 100 items into 3 parts, with and without dropping the remainder.

    With ``drop_last`` every shard has 100 // 3 == 33 items and the final item
    is discarded. Without it, the single leftover item goes to the first
    shard (sizes 34, 33, 33).
    """
    manifest = DummyManifest(manifest_type, begin_id=0, end_id=100)
    shards = manifest.split(num_splits=3, drop_last=drop_last)
    assert len(shards) == 3

    expected_bounds = (
        [(0, 33), (33, 66), (66, 99)]
        if drop_last
        else [(0, 34), (34, 67), (67, 100)]
    )
    for shard, (start, stop) in zip(shards, expected_bounds):
        assert shard == DummyManifest(manifest_type, begin_id=start, end_id=stop)
def test_cannot_split_to_more_chunks_than_items(manifest_type):
    """Check that asking for more splits than there are items raises ``ValueError``."""
    single_item = DummyManifest(manifest_type, begin_id=0, end_id=1)
    with pytest.raises(ValueError):
        single_item.split(num_splits=10)