def test_dynamic_cut_sampler_as_cut_pairs_sampler():
    """Check DynamicCutSampler when it is given two CutSets (source/target pairs).

    The same CutSet is passed as both source and target, so every batch is a
    (source, target) pair. After one full pass the test verifies that nothing
    was lost, nothing was duplicated, the order was shuffled, and both sides
    stayed aligned.
    """
    # The dummy cuts have a duration of 1 second each
    cuts = DummyManifest(CutSet, begin_id=0, end_id=100)
    pair_sampler = DynamicCutSampler(cuts, cuts, shuffle=True, max_duration=5.0)

    # Flatten all mini-batches into one list per side.
    sources, targets = [], []
    for batch in pair_sampler:
        src_part, tgt_part = batch
        sources.extend(src_part)
        targets.extend(tgt_part)

    # Invariant 1: an epoch yields exactly as many items as there were in the CutSet
    expected_total = len(cuts)
    assert len(sources) == expected_total
    assert len(targets) == expected_total

    source_ids = [cut.id for cut in sources]
    target_ids = [cut.id for cut in targets]

    # Invariant 2: the items are not duplicated
    assert len(set(source_ids)) == len(sources)
    assert len(set(target_ids)) == len(targets)

    # Invariant 3: the items are shuffled, i.e. the order differs from that in the CutSet
    assert source_ids != [cut.id for cut in cuts]

    # Invariant 4: the source and target cuts are in the same order
    assert source_ids == target_ids
drop_last=True, num_buckets=2, sampler_type=CutPairsSampler, ), DynamicBucketingSampler(CUTS, max_duration=10.0, shuffle=True, drop_last=True, num_buckets=2), DynamicBucketingSampler(CUTS, CUTS_MOD, max_duration=10.0, shuffle=True, drop_last=True, num_buckets=2), DynamicCutSampler(CUTS, max_duration=10.0, shuffle=True, drop_last=True), DynamicCutSampler(CUTS, CUTS, max_duration=10.0, shuffle=True, drop_last=True), ] @pytest.mark.parametrize("sampler", SAMPLERS_TO_TEST) def test_sampler_pickling(sampler): with NamedTemporaryFile(mode="w+b", suffix=".pkl") as f: pickle.dump(sampler, f) f.flush() f.seek(0) restored = pickle.load(f)
# When drop_last = False: # There will be one more batch with a single 3s cut. expected_num_batches = 17 expected_num_cuts = 50 expected_discarded_cuts = 0 num_sampled_cuts = sum(len(b) for b in batches) num_discarded_cuts = len(cut_set) - num_sampled_cuts assert len(batches) == expected_num_batches assert num_sampled_cuts == expected_num_cuts assert num_discarded_cuts == expected_discarded_cuts SAMPLERS_FACTORIES_FOR_REPORT_TEST = [ lambda: SimpleCutSampler(DummyManifest(CutSet, begin_id=0, end_id=10)), lambda: DynamicCutSampler(DummyManifest(CutSet, begin_id=0, end_id=10)), lambda: CutPairsSampler( DummyManifest(CutSet, begin_id=0, end_id=10), DummyManifest(CutSet, begin_id=0, end_id=10), ), lambda: BucketingSampler(DummyManifest(CutSet, begin_id=0, end_id=10), num_buckets=2), lambda: DynamicBucketingSampler( DummyManifest(CutSet, begin_id=0, end_id=10), max_duration=1.0, num_buckets=2, ), lambda: ZipSampler( SimpleCutSampler(DummyManifest(CutSet, begin_id=0, end_id=10)), SimpleCutSampler(DummyManifest(CutSet, begin_id=10, end_id=20)), ),