def test_zip_sampler_merge_batches_true():
    """ZipSampler in its default (merging) mode yields one combined batch per step.

    Each dummy cut in this test is 1s long. A 10s budget on ``cuts1`` plus a 2s
    budget on ``cuts2`` should therefore produce merged batches of 12 cuts:
    10 from the first manifest and 2 from the second.
    """
    cuts1 = DummyManifest(CutSet, begin_id=0, end_id=100)
    cuts2 = DummyManifest(CutSet, begin_id=1000, end_id=1100)
    # Use SimpleCutSampler, matching the rest of this file (SingleCutSampler is
    # the older name for the same sampler).
    sampler = ZipSampler(
        # Note: each cut is 1s duration in this test.
        SimpleCutSampler(cuts1, max_duration=10),
        SimpleCutSampler(cuts2, max_duration=2),
    )

    def cut_index(cut):
        # The numeric suffix of a dummy cut ID tells which manifest it came from.
        return int(cut.id.split("-")[-1])

    batches = list(sampler)
    # NOTE(review): presumably 100 cuts / 10 per batch from cuts1 -> 10 steps.
    assert len(batches) == 10
    for batch in batches:
        assert len(batch) == 12  # twelve 1s items
        indices = [cut_index(c) for c in batch]
        assert sum(0 <= i <= 100 for i in indices) == 10  # ten come from cuts1
        assert sum(1000 <= i <= 1100 for i in indices) == 2  # two come from cuts2
def test_zip_sampler_merge_batches_false():
    """ZipSampler with ``merge_batches=False`` yields a pair of batches per step.

    The first element of each pair is drawn from the first sampler (10 cuts of
    1s each). The second element is drawn from the second sampler (2 cuts). The
    two are kept separate rather than merged.
    """

    def count_in_range(batch, lo, hi):
        # Count cuts whose trailing numeric ID lies within [lo, hi].
        return len([cut for cut in batch if lo <= int(cut.id.split("-")[-1]) <= hi])

    first_cuts = DummyManifest(CutSet, begin_id=0, end_id=100)
    second_cuts = DummyManifest(CutSet, begin_id=1000, end_id=1100)
    zipped = ZipSampler(
        # Note: each cut is 1s duration in this test.
        SimpleCutSampler(first_cuts, max_duration=10),
        SimpleCutSampler(second_cuts, max_duration=2),
        merge_batches=False,
    )
    all_batches = list(zipped)
    assert len(all_batches) == 10
    for left, right in all_batches:
        assert len(left) == 10
        assert count_in_range(left, 0, 100) == 10  # ten come from cuts1
        assert len(right) == 2
        assert count_in_range(right, 1000, 1100) == 2  # two come from cuts2
assert len(batches) == expected_num_batches assert num_sampled_cuts == expected_num_cuts assert num_discarded_cuts == expected_discarded_cuts @pytest.mark.parametrize( "sampler", [ SimpleCutSampler(DummyManifest(CutSet, begin_id=0, end_id=10)), CutPairsSampler( DummyManifest(CutSet, begin_id=0, end_id=10), DummyManifest(CutSet, begin_id=0, end_id=10), ), BucketingSampler(DummyManifest(CutSet, begin_id=0, end_id=10)), ZipSampler( SimpleCutSampler(DummyManifest(CutSet, begin_id=0, end_id=10)), SimpleCutSampler(DummyManifest(CutSet, begin_id=10, end_id=20)), ), ], ) def test_sampler_get_report(sampler): _ = [b for b in sampler] print(sampler.get_report()) # It runs - voila! @pytest.mark.parametrize( "sampler_cls", [ SimpleCutSampler, BucketingSampler, ],