def test_single_cut_sampler_time_constraints(
    max_duration, max_frames, max_samples, exception_expectation
):
    """Test that SimpleCutSampler enforces time-based batch-size constraints.

    Parametrized (via fixtures/parametrize) over the three mutually related
    constraints — ``max_duration``, ``max_frames``, ``max_samples`` — and an
    ``exception_expectation`` context manager (e.g. ``does_not_raise()`` or
    ``pytest.raises(...)``) describing whether that combination is legal.

    For legal combinations, verifies three invariants over a full epoch:
    completeness, no duplication, and shuffling.
    """
    # The dummy cuts have a duration of 1 second each.
    cut_set = DummyManifest(CutSet, begin_id=0, end_id=100)
    if max_frames is None:
        # Frame-based constraints require features; drop them when unused.
        cut_set = cut_set.drop_features()
    with exception_expectation:
        sampler = SimpleCutSampler(
            cut_set,
            shuffle=True,
            # Set an effective batch size of 10 cuts, as all have 1s duration == 100 frames.
            # This way we're testing that it works okay when returning multiple batches in
            # a full epoch.
            max_frames=max_frames,
            max_samples=max_samples,
            max_duration=max_duration,
        )
        sampler_cut_ids = []
        for batch in sampler:
            sampler_cut_ids.extend(batch)

        # Invariant 1: we receive the same amount of items in a dataloader epoch
        # as there were in the CutSet.
        assert len(sampler_cut_ids) == len(cut_set)
        # Invariant 2: the items are not duplicated.
        assert len({c.id for c in sampler_cut_ids}) == len(sampler_cut_ids)
        # Invariant 3: the items are shuffled, i.e. the order is different than
        # that in the CutSet.
        assert [c.id for c in sampler_cut_ids] != [c.id for c in cut_set]
def test_cut_pairs_sampler_time_constraints(
    max_duration, max_frames, max_samples, exception_expectation
):
    """Test that CutPairsSampler enforces time-based constraints on both sides.

    Mirrors the single-cut sampler test, but samples (source, target) cut
    pairs drawn from the same CutSet, with the target-side constraints set to
    half of the source-side ones. ``exception_expectation`` is a context
    manager (e.g. ``does_not_raise()`` or ``pytest.raises(...)``) describing
    whether the given constraint combination is legal.
    """
    # The dummy cuts have a duration of 1 second each.
    cut_set = DummyManifest(CutSet, begin_id=0, end_id=100)
    if max_frames is None:
        # Frame-based constraints require features; drop them when unused.
        cut_set = cut_set.drop_features()
    with exception_expectation:
        sampler = CutPairsSampler(
            source_cuts=cut_set,
            target_cuts=cut_set,
            shuffle=True,
            # Set an effective batch size of 10 cuts, as all have 1s duration == 100 frames.
            # This way we're testing that it works okay when returning multiple batches in
            # a full epoch.
            max_source_frames=max_frames,
            # NOTE(review): true division yields a float here (e.g. 50.0 frames);
            # presumably the sampler accepts float thresholds — confirm.
            max_target_frames=max_frames / 2 if max_frames is not None else None,
            max_source_samples=max_samples,
            max_target_samples=max_samples / 2 if max_samples is not None else None,
            max_source_duration=max_duration,
            max_target_duration=max_duration / 2 if max_duration is not None else None,
        )
        source_cuts, target_cuts = [], []
        for src_batch, tgt_batch in sampler:
            source_cuts.extend(src_batch)
            target_cuts.extend(tgt_batch)

        # Invariant 1: we receive the same amount of items in a dataloader epoch
        # as there were in the CutSet.
        assert len(source_cuts) == len(cut_set)
        assert len(target_cuts) == len(cut_set)
        # Invariant 2: the items are not duplicated.
        assert len({c.id for c in source_cuts}) == len(source_cuts)
        assert len({c.id for c in target_cuts}) == len(target_cuts)
        # Invariant 3: the items are shuffled, i.e. the order is different than
        # that in the CutSet.
        assert [c.id for c in source_cuts] != [c.id for c in cut_set]
        # Invariant 4: the source and target cuts are in the same order.
        assert [c.id for c in source_cuts] == [c.id for c in target_cuts]