Ejemplo n.º 1
0
def test_bucketing_sampler_cut_pairs_equal_duration(shuffle):
    """Check "equal_duration" bucketing of paired source/target cut sets.

    The source cuts get linearly increasing durations, while the target cuts
    get the reciprocal durations, so both sides of each pair differ. The test
    then verifies that:

    * every bucket holds the same cut IDs, in the same order, on both sides;
    * each bucket contains strictly fewer cuts than the one before it;
    * each bucket's total source duration lies within 10% of the mean.

    ``shuffle`` is forwarded to the sampler unchanged (presumably supplied by
    a pytest parametrization/fixture that is not visible here).
    """

    def with_reciprocal_duration(cut):
        # Copy of the cut whose duration is the inverse of the source one.
        return fastcopy(cut, duration=1 / cut.duration)

    src_cuts = DummyManifest(CutSet, begin_id=0, end_id=1000)
    # Give every cut a distinct duration: 3 s for the first one, growing by
    # 1/50 s per position (roughly the [3, 23] range for 1000 cuts).
    for position, cut in enumerate(src_cuts):
        cut.duration = 3 + position * 1 / 50

    # The target side deliberately has different durations than the source
    # side, to make sure that bucketing copes with that.
    tgt_cuts = src_cuts.map(with_reciprocal_duration)

    sampler = BucketingSampler(
        src_cuts,
        tgt_cuts,
        sampler_type=CutPairsSampler,
        bucket_method="equal_duration",
        num_buckets=10,
        shuffle=shuffle,
    )

    bucket_sizes = []
    bucket_total_durations = []
    for src_bucket, tgt_bucket in sampler.buckets:
        # Source and target buckets must stay aligned cut-for-cut.
        assert list(src_bucket.ids) == list(tgt_bucket.ids)
        bucket_sizes.append(len(src_bucket))
        bucket_total_durations.append(sum(cut.duration for cut in src_bucket))

    # Each consecutive bucket has fewer cuts than the previous one.
    for earlier, later in zip(bucket_sizes, bucket_sizes[1:]):
        assert later < earlier

    # All cumulative bucket durations are within 1/10th of their mean
    # (roughly 1300 s when the 1000 cuts are spread over the 10 buckets).
    mean_bucket_dur = mean(bucket_total_durations)
    tolerance = 0.1 * mean_bucket_dur
    assert all(
        abs(total - mean_bucket_dur) < tolerance for total in bucket_total_durations
    )
Ejemplo n.º 2
0
def test_bucketing_sampler_cut_pairs_equal_len(shuffle):
    """Check "equal_len" bucketing of paired source/target cut sets.

    The source cuts get linearly increasing durations, while the target cuts
    get the reciprocal durations, so both sides of each pair differ. The test
    then verifies that:

    * every bucket holds exactly 100 source cuts;
    * every bucket holds the same cut IDs, in the same order, on both sides;
    * at least one bucket's total source duration is NOT within 10% of the
      mean, i.e. equal-length buckets are not duration-balanced here.

    ``shuffle`` is forwarded to the sampler unchanged (presumably supplied by
    a pytest parametrization/fixture that is not visible here).
    """

    def with_reciprocal_duration(cut):
        # Copy of the cut whose duration is the inverse of the source one.
        return fastcopy(cut, duration=1 / cut.duration)

    src_cuts = DummyManifest(CutSet, begin_id=0, end_id=1000)
    # Give every cut a distinct duration: 3 s for the first one, growing by
    # 1/50 s per position (roughly the [3, 23] range for 1000 cuts).
    for position, cut in enumerate(src_cuts):
        cut.duration = 3 + position * 1 / 50

    # The target side deliberately has different durations than the source
    # side, to make sure that bucketing copes with that.
    tgt_cuts = src_cuts.map(with_reciprocal_duration)

    sampler = BucketingSampler(
        src_cuts,
        tgt_cuts,
        sampler_type=CutPairsSampler,
        bucket_method="equal_len",
        num_buckets=10,
        shuffle=shuffle,
    )

    bucket_total_durations = []
    for src_bucket, tgt_bucket in sampler.buckets:
        bucket_total_durations.append(sum(cut.duration for cut in src_bucket))
        assert len(src_bucket) == 100
        # Source and target buckets must stay aligned cut-for-cut.
        assert list(src_bucket.ids) == list(tgt_bucket.ids)

    # Because the buckets are equal in length rather than in duration, some
    # bucket has to deviate from the mean total duration by 10% or more.
    mean_bucket_dur = mean(bucket_total_durations)
    tolerance = 0.1 * mean_bucket_dur
    outliers = [
        total
        for total in bucket_total_durations
        if abs(total - mean_bucket_dur) >= tolerance
    ]
    assert outliers
Ejemplo n.º 3
0
def test_map(manifest_type):
    """Check that ``map`` applies a function to every item, eagerly and lazily.

    A dummy manifest of ``manifest_type`` is mapped with a function that sets
    each item's ``duration`` to 3.14; the result must equal a reference
    manifest whose durations were set to 3.14 directly.

    ``manifest_type`` is passed straight to ``DummyManifest`` (presumably it is
    provided by a pytest parametrization/fixture that is not visible here).
    """

    def make_manifest():
        # Every call builds a new manifest, so the reference, the eager input
        # and the lazy input never share item objects.
        return DummyManifest(manifest_type, begin_id=0, end_id=10)

    def transform_fn(item):
        # NOTE: mutates the item in place and returns the very same object.
        item.duration = 3.14
        return item

    expected = make_manifest()
    for item in expected:
        item.duration = 3.14

    eager_result = make_manifest().map(transform_fn)
    assert list(eager_result) == list(expected)

    # The lazy branch gets its own untouched manifest. Because transform_fn
    # mutates items in place, reusing the manifest from the eager branch would
    # feed already-transformed items to the lazy map, and this assertion could
    # pass even if the lazy map never applied the function.
    with as_lazy(make_manifest()) as lazy_data:
        lazy_result = lazy_data.map(transform_fn)
        assert list(lazy_result) == list(expected)