def test_bucketing_sampler_order_is_deterministic_given_epoch(): cut_set = DummyManifest(CutSet, begin_id=0, end_id=1000) sampler = BucketingSampler(cut_set, sampler_type=SimpleCutSampler) sampler.set_epoch(42) # calling the sampler twice without epoch update gives identical ordering assert [item for item in sampler] == [item for item in sampler]
def test_bucketing_sampler_order_differs_between_epochs(): cut_set = DummyManifest(CutSet, begin_id=0, end_id=1000) sampler = BucketingSampler(cut_set, sampler_type=SimpleCutSampler) last_order = [item for item in sampler] for epoch in range(1, 6): sampler.set_epoch(epoch) new_order = [item for item in sampler] assert new_order != last_order last_order = new_order
def test_bucketing_sampler_single_cuts_equal_duration(): cut_set = DummyManifest(CutSet, begin_id=0, end_id=1000) for idx, c in enumerate(cut_set): c.duration = (3 + idx * 1 / 50 ) # each cut has a different duration between [3, 23] sampler = BucketingSampler( cut_set, sampler_type=SimpleCutSampler, bucket_method="equal_duration", num_buckets=10, ) # Ensure that each consecutive bucket has less cuts than the previous one sampled_cuts, bucket_cum_durs = [], [] prev_min, prev_max = 0, 0 num_overlapping_bins = 0 for (bucket, ) in sampler.buckets: bucket_durs = [c.duration for c in bucket] sampled_cuts.extend(c for c in bucket) bucket_cum_durs.append(sum(bucket_durs)) bucket_min, bucket_max = min(bucket_durs), max(bucket_durs) # Ensure that bucket lengths do not overlap, except for the middle # 3 buckets maybe if prev_max > bucket_min: num_overlapping_bins += 1 assert num_overlapping_bins < 3 prev_min = bucket_min prev_max = bucket_max # Assert that all bucket cumulative durations are within 1/10th of the mean mean_bucket_dur = mean(bucket_cum_durs) # ~ 1300s for d in bucket_cum_durs: assert abs(d - mean_bucket_dur) < 0.1 * mean_bucket_dur assert set(cut_set.ids) == set(c.id for c in sampled_cuts)
def test_bucketing_sampler_single_cuts(): cut_set = DummyManifest(CutSet, begin_id=0, end_id=1000) sampler = BucketingSampler(cut_set, sampler_type=SingleCutSampler) cut_ids = [] for batch in sampler: cut_ids.extend(batch) assert set(cut_set.ids) == set(cut_ids)
def test_bucketing_sampler_buckets_have_different_durations(): cut_set_1s = DummyManifest(CutSet, begin_id=0, end_id=10) cut_set_2s = DummyManifest(CutSet, begin_id=10, end_id=20) for c in cut_set_2s: c.duration = 2.0 cut_set = cut_set_1s + cut_set_2s # The bucketing sampler should return 5 batches with two 1s cuts, and 10 batches with one 2s cut. sampler = BucketingSampler( cut_set, sampler_type=SingleCutSampler, max_frames=200, num_buckets=2 ) batches = list(sampler) assert len(batches) == 15 # All cuts have the same durations (i.e. are from the same bucket in this case) for batch in batches: batch_durs = [cut_set[cid].duration for cid in batch] assert all(d == batch_durs[0] for d in batch_durs) batches = sorted(batches, key=len) assert all(len(b) == 1 for b in batches[:10]) assert all(len(b) == 2 for b in batches[10:])
def test_bucketing_sampler_single_cuts_equal_duration(): cut_set = DummyManifest(CutSet, begin_id=0, end_id=1000) for idx, c in enumerate(cut_set): c.duration = ( 3 + idx * 1 / 50 ) # each cut has a different duration between [3, 23] sampler = BucketingSampler( cut_set, sampler_type=SimpleCutSampler, bucket_method="equal_duration", num_buckets=10, ) # Ensure that each consecutive bucket has less cuts than the previous one prev_len = float("inf") bucket_cum_durs = [] for (bucket,) in sampler.buckets: bucket_cum_durs.append(sum(c.duration for c in bucket)) curr_len = len(bucket) assert curr_len < prev_len prev_len = curr_len # Assert that all bucket cumulative durations are within 1/10th of the mean mean_bucket_dur = mean(bucket_cum_durs) # ~ 1300s for d in bucket_cum_durs: assert abs(d - mean_bucket_dur) < 0.1 * mean_bucket_dur
def test_bucketing_sampler_len(): # total duration is 550 seconds # each second has 100 frames cuts = CutSet.from_cuts( dummy_cut(idx, duration=float(duration)) for idx, duration in enumerate(list(range(1, 11)) * 10)) sampler = BucketingSampler(cuts, num_buckets=4, shuffle=True, max_frames=64 * 100, max_cuts=6) for epoch in range(5): assert len(sampler) == len([item for item in sampler]) sampler.set_epoch(epoch)
def test_bucketing_sampler_time_constraints(constraint): cut_set = DummyManifest(CutSet, begin_id=0, end_id=1000) sampler = BucketingSampler(cut_set, sampler_type=SimpleCutSampler, **constraint) sampled_cuts = [] for batch in sampler: sampled_cuts.extend(batch) assert set(cut_set.ids) == set(c.id for c in sampled_cuts)
def test_bucketing_sampler_cut_pairs_equal_duration(shuffle): cut_set = DummyManifest(CutSet, begin_id=0, end_id=1000) for idx, c in enumerate(cut_set): c.duration = ( 3 + idx * 1 / 50 ) # each cut has a different duration between [3, 23] # Target CutSet is going to have different durations # -- make sure the bucketing works well with that. cut_set_tgt = cut_set.map(lambda c: fastcopy(c, duration=1 / c.duration)) sampler = BucketingSampler( cut_set, cut_set_tgt, sampler_type=CutPairsSampler, bucket_method="equal_duration", num_buckets=10, shuffle=shuffle, ) # Ensure that each consecutive bucket has less cuts than the previous one prev_len = float("inf") bucket_cum_durs = [] for bucket_src, bucket_tgt in sampler.buckets: assert list(bucket_src.ids) == list(bucket_tgt.ids) bucket_cum_durs.append(sum(c.duration for c in bucket_src)) curr_len = len(bucket_src) assert curr_len < prev_len prev_len = curr_len # Assert that all bucket cumulative durations are within 1/10th of the mean mean_bucket_dur = mean(bucket_cum_durs) # ~ 1300s for d in bucket_cum_durs: assert abs(d - mean_bucket_dur) < 0.1 * mean_bucket_dur
def test_bucketing_sampler_chooses_buckets_randomly(): # Construct a CutSet that has 1000 cuts with 100 unique durations. # Makes it simple to track which bucket was selected. cut_set = CutSet({}) # empty for i in range(100): new_cuts = DummyManifest(CutSet, begin_id=i * 10, end_id=(i + 1) * 10) for c in new_cuts: c.duration = i cut_set = cut_set + new_cuts # Sampler that always select one cut. sampler = BucketingSampler( cut_set, sampler_type=SimpleCutSampler, max_cuts=1, max_frames=1000000000, num_buckets=100, ) # Batches of 1 guarantee that item is always a single-element list of cut IDs. durations = [cut_set[item[0].id].duration for item in sampler] # This is the "trick" part - 'groupby' groups the cuts together by their duration. # If there is a group that has a size of 10, that means the same bucket was chosen # for 10 consecutive batches, which is not what BucketingSampler is supposed to do # (the probability of that is extremely low). # We're actually setting that threshold lower to 8 which should never be triggered # anyway. lens = [] for key, group in groupby(durations): lens.append(len(list(group))) assert all(l < 8 for l in lens) print(lens)
def test_bucketing_sampler_cut_pairs_equal_len(shuffle): cut_set = DummyManifest(CutSet, begin_id=0, end_id=1000) for idx, c in enumerate(cut_set): c.duration = ( 3 + idx * 1 / 50 ) # each cut has a different duration between [3, 23] # Target CutSet is going to have different durations # -- make sure the bucketing works well with that. cut_set_tgt = cut_set.map(lambda c: fastcopy(c, duration=1 / c.duration)) sampler = BucketingSampler( cut_set, cut_set_tgt, sampler_type=CutPairsSampler, bucket_method="equal_len", num_buckets=10, shuffle=shuffle, ) bucket_cum_durs = [] for bucket_src, bucket_tgt in sampler.buckets: bucket_cum_durs.append(sum(c.duration for c in bucket_src)) assert len(bucket_src) == 100 assert list(bucket_src.ids) == list(bucket_tgt.ids) # The variations in duration are over 10% of the mean bucket duration (because of equal lengths). mean_bucket_dur = mean(bucket_cum_durs) assert not all( abs(d - mean_bucket_dur) < 0.1 * mean_bucket_dur for d in bucket_cum_durs )
def test_bucketing_sampler_single_cuts_no_proportional_sampling(): cut_set = DummyManifest(CutSet, begin_id=0, end_id=1000) sampler = BucketingSampler( cut_set, proportional_sampling=False, sampler_type=SimpleCutSampler ) sampled_cuts = [] for batch in sampler: sampled_cuts.extend(batch) assert set(cut_set.ids) == set(c.id for c in sampled_cuts)
def test_bucketing_sampler_raises_value_error_on_lazy_cuts_input(): cut_set = DummyManifest(CutSet, begin_id=0, end_id=2) with NamedTemporaryFile(suffix=".jsonl") as f: cut_set.to_jsonl(f.name) lazy_cuts = CutSet.from_jsonl_lazy(f.name) with pytest.raises(ValueError): sampler = BucketingSampler( lazy_cuts, max_duration=10.0, )
def test_bucketing_sampler_cut_pairs(): cut_set1 = DummyManifest(CutSet, begin_id=0, end_id=1000) cut_set2 = DummyManifest(CutSet, begin_id=0, end_id=1000) sampler = BucketingSampler(cut_set1, cut_set2, sampler_type=CutPairsSampler) cut_ids = [] for batch in sampler: cut_ids.extend(batch) assert set(cut_set1.ids) == set(cut_ids) assert set(cut_set2.ids) == set(cut_ids)
def test_bucketing_sampler_cut_pairs(): cut_set1 = DummyManifest(CutSet, begin_id=0, end_id=1000) cut_set2 = DummyManifest(CutSet, begin_id=0, end_id=1000) sampler = BucketingSampler(cut_set1, cut_set2, sampler_type=CutPairsSampler) src_cuts, tgt_cuts = [], [] for src_batch, tgt_batch in sampler: src_cuts.extend(src_batch) tgt_cuts.extend(tgt_batch) assert set(cut_set1.ids) == set(c.id for c in src_cuts) assert set(cut_set2.ids) == set(c.id for c in tgt_cuts)
def test_bucketing_sampler_shuffle(): cut_set = DummyManifest(CutSet, begin_id=0, end_id=10) sampler = BucketingSampler( cut_set, sampler_type=SimpleCutSampler, shuffle=True, num_buckets=2, max_frames=200, ) sampler.set_epoch(0) batches_ep0 = [] for batch in sampler: # Convert List[str] to Tuple[str, ...] so that it's hashable batches_ep0.append(tuple(c.id for c in batch)) assert set(cut_set.ids) == set(cid for batch in batches_ep0 for cid in batch) sampler.set_epoch(1) batches_ep1 = [] for batch in sampler: batches_ep1.append(tuple(c.id for c in batch)) assert set(cut_set.ids) == set(cid for batch in batches_ep1 for cid in batch) # BucketingSampler ordering may be different in different epochs (=> use set() to make it irrelevant) # Internal sampler (SimpleCutSampler) ordering should be different in different epochs assert set(batches_ep0) != set(batches_ep1)
def test_bucketing_sampler_drop_last(drop_last): # CutSet that has 50 cuts: 10 have 1s, 10 have 2s, etc. cut_set = CutSet() for i in range(5): new_cuts = DummyManifest(CutSet, begin_id=i * 10, end_id=(i + 1) * 10) for c in new_cuts: c.duration = i + 1 cut_set = cut_set + new_cuts # Sampler that always select one cut. sampler = BucketingSampler( cut_set, sampler_type=SimpleCutSampler, max_duration=10.5, num_buckets=5, drop_last=drop_last, ) batches = [] for batch in sampler: # Assert there is a consistent cut duration per bucket in this test. for cut in batch: assert cut.duration == batch[0].duration batches.append(batch) # Expectation: if drop_last: # When drop_last = True: # 10 x 1s cuts == 1 batch (10 cuts each, 0 left over) # 10 x 2s cuts == 2 batches (5 cuts each, 0 left over) # 10 x 3s cuts == 3 batches (3 cuts each, 1 left over) # 10 x 4s cuts == 5 batches (2 cuts each, 0 left over) # 10 x 5s cuts == 5 batches (2 cuts each, 0 left over) expected_num_batches = 16 expected_num_cuts = 49 expected_discarded_cuts = 1 else: # When drop_last = False: # There will be one more batch with a single 3s cut. expected_num_batches = 17 expected_num_cuts = 50 expected_discarded_cuts = 0 num_sampled_cuts = sum(len(b) for b in batches) num_discarded_cuts = len(cut_set) - num_sampled_cuts assert len(batches) == expected_num_batches assert num_sampled_cuts == expected_num_cuts assert num_discarded_cuts == expected_discarded_cuts
def test_bucketing_sampler_single_cuts_equal_len(): cut_set = DummyManifest(CutSet, begin_id=0, end_id=1000) for idx, c in enumerate(cut_set): c.duration = (3 + idx * 1 / 50 ) # each cut has a different duration between [3, 23] sampler = BucketingSampler( cut_set, sampler_type=SingleCutSampler, bucket_method="equal_len", num_buckets=10, ) bucket_cum_durs = [] for (bucket, ) in sampler.buckets: bucket_cum_durs.append(sum(c.duration for c in bucket)) assert len(bucket) == 100 # The variations in duration are over 10% of the mean bucket duration (because of equal lengths). mean_bucket_dur = mean(bucket_cum_durs) assert not all( abs(d - mean_bucket_dur) < 0.1 * mean_bucket_dur for d in bucket_cum_durs)
num_sampled_cuts = sum(len(b) for b in batches) num_discarded_cuts = len(cut_set) - num_sampled_cuts assert len(batches) == expected_num_batches assert num_sampled_cuts == expected_num_cuts assert num_discarded_cuts == expected_discarded_cuts SAMPLERS_FACTORIES_FOR_REPORT_TEST = [ lambda: SimpleCutSampler(DummyManifest(CutSet, begin_id=0, end_id=10)), lambda: DynamicCutSampler(DummyManifest(CutSet, begin_id=0, end_id=10)), lambda: CutPairsSampler( DummyManifest(CutSet, begin_id=0, end_id=10), DummyManifest(CutSet, begin_id=0, end_id=10), ), lambda: BucketingSampler(DummyManifest(CutSet, begin_id=0, end_id=10), num_buckets=2), lambda: DynamicBucketingSampler( DummyManifest(CutSet, begin_id=0, end_id=10), max_duration=1.0, num_buckets=2, ), lambda: ZipSampler( SimpleCutSampler(DummyManifest(CutSet, begin_id=0, end_id=10)), SimpleCutSampler(DummyManifest(CutSet, begin_id=10, end_id=20)), ), lambda: RoundRobinSampler( SimpleCutSampler(DummyManifest(CutSet, begin_id=0, end_id=10)), SimpleCutSampler(DummyManifest(CutSet, begin_id=10, end_id=20)), ), ]
num_sampled_cuts = sum(len(b) for b in batches) num_discarded_cuts = len(cut_set) - num_sampled_cuts assert len(batches) == expected_num_batches assert num_sampled_cuts == expected_num_cuts assert num_discarded_cuts == expected_discarded_cuts @pytest.mark.parametrize( "sampler", [ SimpleCutSampler(DummyManifest(CutSet, begin_id=0, end_id=10)), CutPairsSampler( DummyManifest(CutSet, begin_id=0, end_id=10), DummyManifest(CutSet, begin_id=0, end_id=10), ), BucketingSampler(DummyManifest(CutSet, begin_id=0, end_id=10)), ZipSampler( SimpleCutSampler(DummyManifest(CutSet, begin_id=0, end_id=10)), SimpleCutSampler(DummyManifest(CutSet, begin_id=10, end_id=20)), ), ], ) def test_sampler_get_report(sampler): _ = [b for b in sampler] print(sampler.get_report()) # It runs - voila! @pytest.mark.parametrize( "sampler_cls", [