def test_bucketing_sampler_cut_pairs_equal_duration(shuffle):
    """Check ``equal_duration`` bucketing of paired (source, target) cut sets.

    The source cuts get strictly increasing durations, while the target cuts
    get the reciprocal durations, so the two sides disagree on what is
    "long". The test then verifies that:

    * every bucket pairs the same cut IDs on the source and target side,
    * each bucket holds fewer source cuts than the one before it,
    * the summed source duration of every bucket lies within 10% of the
      mean across buckets.
    """
    source_cuts = DummyManifest(CutSet, begin_id=0, end_id=1000)
    for position, cut in enumerate(source_cuts):
        # Distinct duration per cut: starts at 3s and grows by 1/50s per cut,
        # i.e. the values span roughly [3, 23].
        cut.duration = 3 + position / 50

    def with_reciprocal_duration(cut):
        # The target side deliberately gets very different durations, to make
        # sure the bucketing copes with mismatched source/target lengths.
        return fastcopy(cut, duration=1 / cut.duration)

    target_cuts = source_cuts.map(with_reciprocal_duration)

    sampler = BucketingSampler(
        source_cuts,
        target_cuts,
        sampler_type=CutPairsSampler,
        bucket_method="equal_duration",
        num_buckets=10,
        shuffle=shuffle,
    )

    total_durations = []
    previous_size = float("inf")  # any real bucket size is smaller than this
    for src_bucket, tgt_bucket in sampler.buckets:
        assert list(src_bucket.ids) == list(tgt_bucket.ids)
        total_durations.append(sum(cut.duration for cut in src_bucket))
        # Bucket sizes must be strictly decreasing from one bucket to the next.
        size = len(src_bucket)
        assert size < previous_size
        previous_size = size

    # Every bucket's cumulative duration must be within 1/10th of the mean
    # (about 1300s, assuming the buckets together cover all 1000 cuts).
    average = mean(total_durations)
    tolerance = 0.1 * average
    assert all(abs(total - average) < tolerance for total in total_durations)
def test_bucketing_sampler_cut_pairs_equal_len(shuffle):
    """Check ``equal_len`` bucketing of paired (source, target) cut sets.

    The source cuts get strictly increasing durations and the target cuts get
    the reciprocal durations. With 1000 cuts and 10 buckets, the test expects
    every bucket to contain exactly 100 source cuts with matching target IDs,
    and expects that the summed source durations are NOT all within 10% of the
    mean across buckets.
    """
    source_cuts = DummyManifest(CutSet, begin_id=0, end_id=1000)
    for position, cut in enumerate(source_cuts):
        # Distinct duration per cut: starts at 3s and grows by 1/50s per cut,
        # i.e. the values span roughly [3, 23].
        cut.duration = 3 + position / 50

    def with_reciprocal_duration(cut):
        # The target side deliberately gets very different durations, to make
        # sure the bucketing copes with mismatched source/target lengths.
        return fastcopy(cut, duration=1 / cut.duration)

    target_cuts = source_cuts.map(with_reciprocal_duration)

    sampler = BucketingSampler(
        source_cuts,
        target_cuts,
        sampler_type=CutPairsSampler,
        bucket_method="equal_len",
        num_buckets=10,
        shuffle=shuffle,
    )

    total_durations = []
    for src_bucket, tgt_bucket in sampler.buckets:
        total_durations.append(sum(cut.duration for cut in src_bucket))
        assert len(src_bucket) == 100
        assert list(src_bucket.ids) == list(tgt_bucket.ids)

    # Equal-sized buckets of cuts with growing durations cannot also have
    # equal total durations: at least one bucket must fall outside the
    # 10%-of-mean band.
    average = mean(total_durations)
    tolerance = 0.1 * average
    within_tolerance = [
        abs(total - average) < tolerance for total in total_durations
    ]
    assert not all(within_tolerance)
def test_map(manifest_type):
    """Check that ``map`` applies a function to every item, eagerly and lazily.

    A reference manifest is built by setting ``duration`` to 3.14 on every
    item by hand. The same transformation is then applied through ``map`` on
    a regular manifest and on one obtained from the ``as_lazy`` context
    manager, and both results must equal the reference.

    Each branch gets its own freshly built manifest. ``transform_fn`` mutates
    items in place, so reusing one manifest would hand already-transformed
    items to the second branch, and its assertion would pass even if ``map``
    never called the function there.

    Args:
        manifest_type: Manifest class forwarded to ``DummyManifest``.
    """

    def transform_fn(item):
        item.duration = 3.14
        return item

    def fresh_manifest():
        # Independent, untouched input for each branch of the test.
        return DummyManifest(manifest_type, begin_id=0, end_id=10)

    expected = fresh_manifest()
    for item in expected:
        item.duration = 3.14
    expected_items = list(expected)

    eager_result = fresh_manifest().map(transform_fn)
    assert list(eager_result) == expected_items

    # Use a separate manifest here so the lazy ``map`` has real work to do.
    with as_lazy(fresh_manifest()) as lazy_data:
        lazy_result = lazy_data.map(transform_fn)
        assert list(lazy_result) == expected_items