def test_cut_pairs_sampler_filter():
    """sampler.filter() removes matching cuts from the sampled epoch."""
    # Every dummy cut is 1 second long (100 frames), so max_source_frames=1000
    # gives ~10-cut batches and exercises the multi-batch path in a full epoch.
    cut_set = DummyManifest(CutSet, begin_id=0, end_id=100)
    sampler = CutPairsSampler(
        cut_set,
        cut_set,
        shuffle=True,
        max_source_frames=1000,
    )
    removed_cut_id = "dummy-cut-0010"
    sampler.filter(lambda cut: cut.id != removed_cut_id)

    sources, targets = [], []
    for source_batch, target_batch in sampler:
        sources += source_batch
        targets += target_batch

    # The filtered cut exists in the manifest but never shows up in a batch.
    assert removed_cut_id in set(cut_set.ids)
    assert removed_cut_id not in {c.id for c in sources}
    # Invariant 1: the epoch covers the whole CutSet minus the filtered item.
    assert len(sources) == len(cut_set) - 1
    assert len(targets) == len(cut_set) - 1
    # Invariant 2: no item is duplicated.
    assert len({c.id for c in sources}) == len(sources)
    assert len({c.id for c in targets}) == len(targets)
def test_cut_pairs_sampler_len():
    """len(sampler) must match the number of batches it actually yields, at every epoch."""
    # Total duration is 55 seconds; each second has 100 frames.
    cuts = CutSet.from_cuts(dummy_cut(idx, duration=float(idx)) for idx in range(1, 11))
    sampler = CutPairsSampler(
        source_cuts=cuts,
        target_cuts=cuts,
        shuffle=True,
        max_source_frames=10 * 100,
        max_target_frames=10 * 100,
    )
    for epoch in range(5):
        # Fix: set the epoch *before* iterating so each of the five epochs
        # (0-4) is actually exercised. The original called set_epoch() after
        # the assertion, so epoch 0 was checked twice and epoch 4 never —
        # the final set_epoch(4) was dead code.
        sampler.set_epoch(epoch)
        assert len(sampler) == sum(1 for _batch in sampler)
def test_cut_pairs_sampler():
    """Basic CutPairsSampler contract: full coverage, no dups, shuffled, paired order."""
    # Every dummy cut is 1 second long (100 frames), so max_source_frames=1000
    # caps each batch at ~10 cuts and the epoch spans multiple batches.
    cut_set = DummyManifest(CutSet, begin_id=0, end_id=100)
    sampler = CutPairsSampler(
        source_cuts=cut_set,
        target_cuts=cut_set,
        shuffle=True,
        max_source_frames=1000,
        max_target_frames=500,
    )
    sources, targets = [], []
    for source_batch, target_batch in sampler:
        sources += source_batch
        targets += target_batch

    # Invariant 1: a full epoch yields exactly as many items as the CutSet holds.
    assert len(sources) == len(cut_set)
    assert len(targets) == len(cut_set)
    # Invariant 2: no item is duplicated.
    assert len({c.id for c in sources}) == len(sources)
    assert len({c.id for c in targets}) == len(targets)
    # Invariant 3: shuffling changed the order relative to the manifest.
    assert [c.id for c in sources] != [c.id for c in cut_set]
    # Invariant 4: source and target streams stay aligned pairwise.
    assert [c.id for c in sources] == [c.id for c in targets]
def test_cut_pairs_sampler_order_is_deterministic_given_epoch():
    """At a fixed epoch, two full passes over the sampler give identical output."""
    # Every dummy cut is 1 second long (100 frames), so max_source_frames=1000
    # gives ~10-cut batches and several batches per epoch.
    cut_set = DummyManifest(CutSet, begin_id=0, end_id=100)
    sampler = CutPairsSampler(
        source_cuts=cut_set,
        target_cuts=cut_set,
        shuffle=True,
        max_source_frames=1000,
        max_target_frames=500,
    )
    sampler.set_epoch(42)
    # Without an epoch update, re-iterating must reproduce the same ordering.
    first_pass = list(sampler)
    second_pass = list(sampler)
    assert first_pass == second_pass
def test_zip_sampler_cut_pairs_merge_batches_true():
    """ZipSampler merges each sub-sampler's mini-batch into one joint batch."""
    cuts1 = DummyManifest(CutSet, begin_id=0, end_id=100)
    cuts2 = DummyManifest(CutSet, begin_id=1000, end_id=1100)
    # Note: each cut is 1s duration in this test, so the duration caps
    # translate to 10 cuts from the first sampler + 2 from the second.
    sampler = ZipSampler(
        CutPairsSampler(cuts1, cuts1, max_source_duration=10),
        CutPairsSampler(cuts2, cuts2, max_source_duration=2),
    )
    batches = list(sampler)
    assert len(batches) == 10
    for sources, targets in batches:
        assert len(sources) == len(targets)
        assert len(sources) == 12  # twelve 1s items
        from_first = [c for c in sources if 0 <= int(c.id.split("-")[-1]) <= 100]
        from_second = [c for c in sources if 1000 <= int(c.id.split("-")[-1]) <= 1100]
        assert len(from_first) == 10  # ten come from cuts1
        assert len(from_second) == 2  # two come from cuts2
def test_cut_pairs_sampler_order_differs_between_epochs():
    """Advancing the epoch reshuffles the cuts into a new order."""
    # Every dummy cut is 1 second long (100 frames), so max_source_frames=1000
    # gives ~10-cut batches and several batches per epoch.
    cut_set = DummyManifest(CutSet, begin_id=0, end_id=100)
    sampler = CutPairsSampler(
        source_cuts=cut_set,
        target_cuts=cut_set,
        shuffle=True,
        max_source_frames=1000,
        max_target_frames=500,
    )
    previous = list(sampler)
    for epoch in range(1, 6):
        sampler.set_epoch(epoch)
        current = list(sampler)
        assert current != previous
        previous = current
def test_cut_pairs_sampler_2():
    """When the duration budget fits every cut, the first batch holds them all."""
    cut_set = CutSet.from_cuts(
        [dummy_cut(0, duration=10), dummy_cut(1, duration=20)]
    )
    sampler = CutPairsSampler(
        source_cuts=cut_set,
        target_cuts=cut_set,
        max_source_duration=50,
        max_target_duration=50,
    )
    first_batch = next(iter(sampler))
    assert len(first_batch) == 2
def test_cut_pairs_sampler_time_constraints(
    max_duration, max_frames, max_samples, exception_expectation
):
    """Exercise CutPairsSampler under parametrized time constraints.

    NOTE(review): the parameters come from a @pytest.mark.parametrize
    decorator outside this view; `exception_expectation` is presumably a
    context manager (pytest.raises(...) or a does-not-raise stand-in) —
    confirm against the decorator.
    """
    # The dummy cuts have a duration of 1 second each
    cut_set = DummyManifest(CutSet, begin_id=0, end_id=100)
    if max_frames is None:
        # No features -> frame-based constraints cannot apply to this cut set.
        cut_set = cut_set.drop_features()
    # The whole sampling run sits inside the expectation: when construction is
    # expected to raise, nothing after it executes.
    with exception_expectation:
        sampler = CutPairsSampler(
            source_cuts=cut_set,
            target_cuts=cut_set,
            shuffle=True,
            # Set an effective batch size of 10 cuts, as all have 1s duration == 100 frames
            # This way we're testing that it works okay when returning multiple batches in
            # a full epoch.
            max_source_frames=max_frames,
            max_target_frames=max_frames / 2 if max_frames is not None else None,
            max_source_samples=max_samples,
            max_target_samples=max_samples / 2 if max_samples is not None else None,
            max_source_duration=max_duration,
            max_target_duration=max_duration / 2 if max_duration is not None else None,
        )
        sampler_cut_ids = []
        for batch in sampler:
            sampler_cut_ids.extend(batch)
        # Invariant 1: we receive the same amount of items in a dataloader epoch as there we in the CutSet
        assert len(sampler_cut_ids) == len(cut_set)
        # Invariant 2: the items are not duplicated
        assert len(set(sampler_cut_ids)) == len(sampler_cut_ids)
        # Invariant 3: the items are shuffled, i.e. the order is different than that in the CutSet
        # NOTE(review): this compares the sampled items against cut *IDs*; if the
        # sampler yields cut objects rather than IDs, the inequality holds
        # trivially — verify what the sampler actually iterates over.
        assert sampler_cut_ids != [c.id for c in cut_set]
    # NOTE(review): this tail belongs to a test function whose start is outside
    # this view; indentation reconstructed assuming a plain function body.
    expected_num_cuts = 50
    expected_discarded_cuts = 0
    num_sampled_cuts = sum(len(b) for b in batches)
    num_discarded_cuts = len(cut_set) - num_sampled_cuts
    assert len(batches) == expected_num_batches
    assert num_sampled_cuts == expected_num_cuts
    assert num_discarded_cuts == expected_discarded_cuts


# Run get_report() on every sampler flavor to make sure it does not crash.
@pytest.mark.parametrize(
    "sampler",
    [
        SimpleCutSampler(DummyManifest(CutSet, begin_id=0, end_id=10)),
        CutPairsSampler(
            DummyManifest(CutSet, begin_id=0, end_id=10),
            DummyManifest(CutSet, begin_id=0, end_id=10),
        ),
        BucketingSampler(DummyManifest(CutSet, begin_id=0, end_id=10)),
        ZipSampler(
            SimpleCutSampler(DummyManifest(CutSet, begin_id=0, end_id=10)),
            SimpleCutSampler(DummyManifest(CutSet, begin_id=10, end_id=20)),
        ),
    ],
)
def test_sampler_get_report(sampler):
    """get_report() is callable after a full epoch for every sampler type."""
    # Exhaust one epoch so the report has statistics to summarize.
    _ = [b for b in sampler]
    print(sampler.get_report())  # It runs - voila!


@pytest.mark.parametrize(