Example no. 1
def test_bucketing_sampler_order_is_deterministic_given_epoch():
    cut_set = DummyManifest(CutSet, begin_id=0, end_id=1000)
    sampler = BucketingSampler(cut_set, sampler_type=SimpleCutSampler)

    sampler.set_epoch(42)
    # calling the sampler twice without epoch update gives identical ordering
    assert [item for item in sampler] == [item for item in sampler]
Example no. 2
def test_bucketing_sampler_order_differs_between_epochs():
    cut_set = DummyManifest(CutSet, begin_id=0, end_id=1000)
    sampler = BucketingSampler(cut_set, sampler_type=SimpleCutSampler)

    last_order = [item for item in sampler]
    for epoch in range(1, 6):
        sampler.set_epoch(epoch)
        new_order = [item for item in sampler]
        assert new_order != last_order
        last_order = new_order
Example no. 3
def test_bucketing_sampler_single_cuts_equal_duration():
    cut_set = DummyManifest(CutSet, begin_id=0, end_id=1000)
    for idx, c in enumerate(cut_set):
        c.duration = (
            3 + idx * 1 / 50
        )  # each cut has a different duration between [3, 23]
    sampler = BucketingSampler(
        cut_set,
        sampler_type=SimpleCutSampler,
        bucket_method="equal_duration",
        num_buckets=10,
    )

    # Collect the cuts and per-bucket durations; the duration ranges of
    # consecutive buckets are checked for overlap below.
    sampled_cuts, bucket_cum_durs = [], []
    prev_max = 0
    num_overlapping_bins = 0
    for (bucket,) in sampler.buckets:
        bucket_durs = [c.duration for c in bucket]
        sampled_cuts.extend(bucket)
        bucket_cum_durs.append(sum(bucket_durs))
        bucket_min, bucket_max = min(bucket_durs), max(bucket_durs)
        # Ensure that the bucket duration ranges do not overlap,
        # except maybe around up to 3 bucket boundaries.
        if prev_max > bucket_min:
            num_overlapping_bins += 1
        assert num_overlapping_bins < 3
        prev_max = bucket_max

    # Assert that all bucket cumulative durations are within 1/10th of the mean
    mean_bucket_dur = mean(bucket_cum_durs)  # ~ 1300s
    for d in bucket_cum_durs:
        assert abs(d - mean_bucket_dur) < 0.1 * mean_bucket_dur

    assert set(cut_set.ids) == set(c.id for c in sampled_cuts)
Example no. 4
def test_bucketing_sampler_single_cuts():
    cut_set = DummyManifest(CutSet, begin_id=0, end_id=1000)
    sampler = BucketingSampler(cut_set, sampler_type=SingleCutSampler)
    cut_ids = []
    for batch in sampler:
        cut_ids.extend(batch)
    assert set(cut_set.ids) == set(cut_ids)
Example no. 5
def test_bucketing_sampler_buckets_have_different_durations():
    cut_set_1s = DummyManifest(CutSet, begin_id=0, end_id=10)
    cut_set_2s = DummyManifest(CutSet, begin_id=10, end_id=20)
    for c in cut_set_2s:
        c.duration = 2.0
    cut_set = cut_set_1s + cut_set_2s

    # The bucketing sampler should return 5 batches with two 1s cuts, and 10 batches with one 2s cut.
    sampler = BucketingSampler(
        cut_set,
        sampler_type=SingleCutSampler,
        max_frames=200,
        num_buckets=2
    )
    batches = list(sampler)
    assert len(batches) == 15

    # All cuts in a batch have the same duration (i.e. they come from the same bucket in this case)
    for batch in batches:
        batch_durs = [cut_set[cid].duration for cid in batch]
        assert all(d == batch_durs[0] for d in batch_durs)

    batches = sorted(batches, key=len)
    assert all(len(b) == 1 for b in batches[:10])
    assert all(len(b) == 2 for b in batches[10:])
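The batch counts asserted above follow from simple arithmetic: at 100 frames per second (the rate these tests assume), max_frames=200 fits two 1s cuts or one 2s cut per batch. A minimal sketch of that arithmetic in plain Python (not Lhotse code, just a check of the expectation):

import math

FRAMES_PER_SECOND = 100  # rate assumed throughout these tests
MAX_FRAMES = 200

def expected_batches(num_cuts: int, cut_duration: float) -> int:
    # How many cuts of this duration fit under the frame budget,
    # and how many batches it takes to exhaust the bucket.
    cuts_per_batch = MAX_FRAMES // int(cut_duration * FRAMES_PER_SECOND)
    return math.ceil(num_cuts / cuts_per_batch)

assert expected_batches(10, 1.0) == 5   # two 1s cuts per batch
assert expected_batches(10, 2.0) == 10  # one 2s cut per batch
# 5 + 10 == 15 batches in total, matching the test's assertion.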
Example no. 6
def test_bucketing_sampler_single_cuts_equal_duration():
    cut_set = DummyManifest(CutSet, begin_id=0, end_id=1000)
    for idx, c in enumerate(cut_set):
        c.duration = (
            3 + idx * 1 / 50
        )  # each cut has a different duration between [3, 23]
    sampler = BucketingSampler(
        cut_set,
        sampler_type=SimpleCutSampler,
        bucket_method="equal_duration",
        num_buckets=10,
    )

    # Ensure that each consecutive bucket has fewer cuts than the previous one
    prev_len = float("inf")
    bucket_cum_durs = []
    for (bucket,) in sampler.buckets:
        bucket_cum_durs.append(sum(c.duration for c in bucket))
        curr_len = len(bucket)
        assert curr_len < prev_len
        prev_len = curr_len

    # Assert that all bucket cumulative durations are within 1/10th of the mean
    mean_bucket_dur = mean(bucket_cum_durs)  # ~ 1300s
    for d in bucket_cum_durs:
        assert abs(d - mean_bucket_dur) < 0.1 * mean_bucket_dur
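For intuition, bucket_method="equal_duration" splits the duration-sorted cuts so that each bucket covers roughly the same total duration, which is why buckets of longer cuts end up with fewer items. A rough greedy stand-in for that behaviour, under the same 3-23s linear-duration assumption (only a sketch, not Lhotse's actual algorithm):

durations = [3 + i / 50 for i in range(1000)]  # sorted: 3.0 .. 22.98 seconds
num_buckets = 10
target = sum(durations) / num_buckets  # ~1299 seconds per bucket

buckets, current, acc = [], [], 0.0
for d in durations:
    current.append(d)
    acc += d
    # Close the bucket once it reaches the per-bucket duration target.
    if acc >= target and len(buckets) < num_buckets - 1:
        buckets.append(current)
        current, acc = [], 0.0
buckets.append(current)

sizes = [len(b) for b in buckets]
# Buckets of short cuts hold more items than buckets of long cuts.
assert sizes == sorted(sizes, reverse=True)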
Example no. 7
def test_bucketing_sampler_len():
    # total duration is 550 seconds
    # each second has 100 frames
    cuts = CutSet.from_cuts(
        dummy_cut(idx, duration=float(duration))
        for idx, duration in enumerate(list(range(1, 11)) * 10))

    sampler = BucketingSampler(cuts,
                               num_buckets=4,
                               shuffle=True,
                               max_frames=64 * 100,
                               max_cuts=6)

    for epoch in range(5):
        assert len(sampler) == len([item for item in sampler])
        sampler.set_epoch(epoch)
Example no. 8
def test_bucketing_sampler_time_constraints(constraint):
    cut_set = DummyManifest(CutSet, begin_id=0, end_id=1000)
    sampler = BucketingSampler(cut_set, sampler_type=SimpleCutSampler, **constraint)
    sampled_cuts = []
    for batch in sampler:
        sampled_cuts.extend(batch)
    assert set(cut_set.ids) == set(c.id for c in sampled_cuts)
Example no. 9
def test_bucketing_sampler_cut_pairs_equal_duration(shuffle):
    cut_set = DummyManifest(CutSet, begin_id=0, end_id=1000)
    for idx, c in enumerate(cut_set):
        c.duration = (
            3 + idx * 1 / 50
        )  # each cut has a different duration between [3, 23]
    # Target CutSet is going to have different durations
    # -- make sure the bucketing works well with that.
    cut_set_tgt = cut_set.map(lambda c: fastcopy(c, duration=1 / c.duration))

    sampler = BucketingSampler(
        cut_set,
        cut_set_tgt,
        sampler_type=CutPairsSampler,
        bucket_method="equal_duration",
        num_buckets=10,
        shuffle=shuffle,
    )

    # Ensure that each consecutive bucket has fewer cuts than the previous one
    prev_len = float("inf")
    bucket_cum_durs = []
    for bucket_src, bucket_tgt in sampler.buckets:
        assert list(bucket_src.ids) == list(bucket_tgt.ids)
        bucket_cum_durs.append(sum(c.duration for c in bucket_src))
        curr_len = len(bucket_src)
        assert curr_len < prev_len
        prev_len = curr_len

    # Assert that all bucket cumulative durations are within 1/10th of the mean
    mean_bucket_dur = mean(bucket_cum_durs)  # ~ 1300s
    for d in bucket_cum_durs:
        assert abs(d - mean_bucket_dur) < 0.1 * mean_bucket_dur
Example no. 10
def test_bucketing_sampler_chooses_buckets_randomly():
    # Construct a CutSet that has 1000 cuts with 100 unique durations.
    # Makes it simple to track which bucket was selected.
    cut_set = CutSet({})  # empty
    for i in range(100):
        new_cuts = DummyManifest(CutSet, begin_id=i * 10, end_id=(i + 1) * 10)
        for c in new_cuts:
            c.duration = i
        cut_set = cut_set + new_cuts

    # Sampler that always selects one cut.
    sampler = BucketingSampler(
        cut_set,
        sampler_type=SimpleCutSampler,
        max_cuts=1,
        max_frames=1000000000,
        num_buckets=100,
    )

    # Batches of 1 guarantee that item is always a single-element list of cut IDs.
    durations = [cut_set[item[0].id].duration for item in sampler]

    # This is the "trick" part - 'groupby' groups the cuts together by their duration.
    # If there is a group that has a size of 10, that means the same bucket was chosen
    # for 10 consecutive batches, which is not what BucketingSampler is supposed to do
    # (the probability of that is extremely low).
    # We actually set the threshold lower, at 8, which should never be
    # reached anyway.
    lens = []
    for key, group in groupby(durations):
        lens.append(len(list(group)))
    assert all(l < 8 for l in lens)
    print(lens)
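The groupby call above is effectively a run-length check: it measures the longest streak of consecutive batches drawn from the same bucket (identified here by duration). A standalone illustration of that idea:

from itertools import groupby

# Each element stands for the bucket a consecutive batch was drawn from.
drawn_buckets = [1, 1, 2, 3, 3, 3, 1, 2, 2]

run_lengths = [len(list(group)) for _, group in groupby(drawn_buckets)]
assert run_lengths == [2, 1, 3, 1, 2]
assert max(run_lengths) == 3  # three consecutive draws from bucket 3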
Example no. 11
def test_bucketing_sampler_cut_pairs_equal_len(shuffle):
    cut_set = DummyManifest(CutSet, begin_id=0, end_id=1000)
    for idx, c in enumerate(cut_set):
        c.duration = (
            3 + idx * 1 / 50
        )  # each cut has a different duration between [3, 23]
    # Target CutSet is going to have different durations
    # -- make sure the bucketing works well with that.
    cut_set_tgt = cut_set.map(lambda c: fastcopy(c, duration=1 / c.duration))

    sampler = BucketingSampler(
        cut_set,
        cut_set_tgt,
        sampler_type=CutPairsSampler,
        bucket_method="equal_len",
        num_buckets=10,
        shuffle=shuffle,
    )

    bucket_cum_durs = []
    for bucket_src, bucket_tgt in sampler.buckets:
        bucket_cum_durs.append(sum(c.duration for c in bucket_src))
        assert len(bucket_src) == 100
        assert list(bucket_src.ids) == list(bucket_tgt.ids)

    # The variations in duration are over 10% of the mean bucket duration (because of equal lengths).
    mean_bucket_dur = mean(bucket_cum_durs)
    assert not all(
        abs(d - mean_bucket_dur) < 0.1 * mean_bucket_dur for d in bucket_cum_durs
    )
Example no. 12
def test_bucketing_sampler_single_cuts_no_proportional_sampling():
    cut_set = DummyManifest(CutSet, begin_id=0, end_id=1000)
    sampler = BucketingSampler(
        cut_set, proportional_sampling=False, sampler_type=SimpleCutSampler
    )
    sampled_cuts = []
    for batch in sampler:
        sampled_cuts.extend(batch)
    assert set(cut_set.ids) == set(c.id for c in sampled_cuts)
Example no. 13
def test_bucketing_sampler_raises_value_error_on_lazy_cuts_input():
    cut_set = DummyManifest(CutSet, begin_id=0, end_id=2)
    with NamedTemporaryFile(suffix=".jsonl") as f:
        cut_set.to_jsonl(f.name)
        lazy_cuts = CutSet.from_jsonl_lazy(f.name)
        with pytest.raises(ValueError):
            sampler = BucketingSampler(
                lazy_cuts,
                max_duration=10.0,
            )
Example no. 14
def test_bucketing_sampler_cut_pairs():
    cut_set1 = DummyManifest(CutSet, begin_id=0, end_id=1000)
    cut_set2 = DummyManifest(CutSet, begin_id=0, end_id=1000)
    sampler = BucketingSampler(cut_set1, cut_set2, sampler_type=CutPairsSampler)

    cut_ids = []
    for batch in sampler:
        cut_ids.extend(batch)
    assert set(cut_set1.ids) == set(cut_ids)
    assert set(cut_set2.ids) == set(cut_ids)
Example no. 15
def test_bucketing_sampler_cut_pairs():
    cut_set1 = DummyManifest(CutSet, begin_id=0, end_id=1000)
    cut_set2 = DummyManifest(CutSet, begin_id=0, end_id=1000)
    sampler = BucketingSampler(cut_set1, cut_set2, sampler_type=CutPairsSampler)

    src_cuts, tgt_cuts = [], []
    for src_batch, tgt_batch in sampler:
        src_cuts.extend(src_batch)
        tgt_cuts.extend(tgt_batch)
    assert set(cut_set1.ids) == set(c.id for c in src_cuts)
    assert set(cut_set2.ids) == set(c.id for c in tgt_cuts)
Example no. 16
def test_bucketing_sampler_shuffle():
    cut_set = DummyManifest(CutSet, begin_id=0, end_id=10)
    sampler = BucketingSampler(
        cut_set,
        sampler_type=SimpleCutSampler,
        shuffle=True,
        num_buckets=2,
        max_frames=200,
    )

    sampler.set_epoch(0)
    batches_ep0 = []
    for batch in sampler:
        # Convert the batch into a tuple of cut IDs so that it's hashable
        batches_ep0.append(tuple(c.id for c in batch))
    assert set(cut_set.ids) == set(cid for batch in batches_ep0 for cid in batch)

    sampler.set_epoch(1)
    batches_ep1 = []
    for batch in sampler:
        batches_ep1.append(tuple(c.id for c in batch))
    assert set(cut_set.ids) == set(cid for batch in batches_ep1 for cid in batch)

    # BucketingSampler ordering may be different in different epochs (=> use set() to make it irrelevant)
    # Internal sampler (SimpleCutSampler) ordering should be different in different epochs
    assert set(batches_ep0) != set(batches_ep1)
Example no. 17
def test_bucketing_sampler_drop_last(drop_last):
    # CutSet that has 50 cuts: 10 have 1s, 10 have 2s, etc.
    cut_set = CutSet()
    for i in range(5):
        new_cuts = DummyManifest(CutSet, begin_id=i * 10, end_id=(i + 1) * 10)
        for c in new_cuts:
            c.duration = i + 1
        cut_set = cut_set + new_cuts

    # With 5 buckets over 5 distinct durations, each bucket holds cuts of a
    # single duration; max_duration caps every batch at 10.5s of audio.
    sampler = BucketingSampler(
        cut_set,
        sampler_type=SimpleCutSampler,
        max_duration=10.5,
        num_buckets=5,
        drop_last=drop_last,
    )
    batches = []
    for batch in sampler:
        # Assert that all cuts in a batch have a single, consistent duration.
        for cut in batch:
            assert cut.duration == batch[0].duration
        batches.append(batch)

    # Expectation:
    if drop_last:
        # When drop_last = True:
        #   10 x 1s cuts == 1 batch (10 cuts each, 0 left over)
        #   10 x 2s cuts == 2 batches (5 cuts each, 0 left over)
        #   10 x 3s cuts == 3 batches (3 cuts each, 1 left over)
        #   10 x 4s cuts == 5 batches (2 cuts each, 0 left over)
        #   10 x 5s cuts == 5 batches (2 cuts each, 0 left over)
        expected_num_batches = 16
        expected_num_cuts = 49
        expected_discarded_cuts = 1
    else:
        # When drop_last = False:
        #   There will be one more batch with a single 3s cut.
        expected_num_batches = 17
        expected_num_cuts = 50
        expected_discarded_cuts = 0

    num_sampled_cuts = sum(len(b) for b in batches)
    num_discarded_cuts = len(cut_set) - num_sampled_cuts
    assert len(batches) == expected_num_batches
    assert num_sampled_cuts == expected_num_cuts
    assert num_discarded_cuts == expected_discarded_cuts
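The expected counts in the comments above come from dividing each bucket's 10 cuts by how many fit under the 10.5s duration cap. A small plain-Python sketch of that arithmetic (independent of Lhotse):

MAX_DURATION = 10.5
CUTS_PER_BUCKET = 10

def bucket_stats(cut_duration: float, drop_last: bool):
    # Number of batches a bucket yields and the number of cuts actually kept.
    cuts_per_batch = int(MAX_DURATION // cut_duration)
    full_batches, leftover = divmod(CUTS_PER_BUCKET, cuts_per_batch)
    if drop_last or leftover == 0:
        return full_batches, full_batches * cuts_per_batch
    return full_batches + 1, CUTS_PER_BUCKET

for drop_last, (exp_batches, exp_cuts) in [(True, (16, 49)), (False, (17, 50))]:
    stats = [bucket_stats(d, drop_last) for d in (1.0, 2.0, 3.0, 4.0, 5.0)]
    assert sum(b for b, _ in stats) == exp_batches
    assert sum(c for _, c in stats) == exp_cuts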
Example no. 18
def test_bucketing_sampler_single_cuts_equal_len():
    cut_set = DummyManifest(CutSet, begin_id=0, end_id=1000)
    for idx, c in enumerate(cut_set):
        c.duration = (
            3 + idx * 1 / 50
        )  # each cut has a different duration between [3, 23]
    sampler = BucketingSampler(
        cut_set,
        sampler_type=SingleCutSampler,
        bucket_method="equal_len",
        num_buckets=10,
    )

    bucket_cum_durs = []
    for (bucket,) in sampler.buckets:
        bucket_cum_durs.append(sum(c.duration for c in bucket))
        assert len(bucket) == 100

    # The variations in duration are over 10% of the mean bucket duration (because of equal lengths).
    mean_bucket_dur = mean(bucket_cum_durs)
    assert not all(
        abs(d - mean_bucket_dur) < 0.1 * mean_bucket_dur
        for d in bucket_cum_durs)
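In contrast to "equal_duration", bucket_method="equal_len" gives every bucket the same number of cuts, so the per-bucket total durations spread well beyond 10% of the mean. A plain-Python sketch of that effect under the same 3-23s assumption (not Lhotse's implementation):

from statistics import mean

durations = [3 + i / 50 for i in range(1000)]  # sorted: 3.0 .. 22.98 seconds
buckets = [durations[i : i + 100] for i in range(0, 1000, 100)]  # 10 equal-size buckets

cum_durs = [sum(b) for b in buckets]
mean_dur = mean(cum_durs)  # ~1299 seconds
# The shortest-cut bucket (~399s) and the longest-cut bucket (~2199s) deviate
# far more than 10% from the mean.
assert not all(abs(d - mean_dur) < 0.1 * mean_dur for d in cum_durs)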
Example no. 19
SAMPLERS_FACTORIES_FOR_REPORT_TEST = [
    lambda: SimpleCutSampler(DummyManifest(CutSet, begin_id=0, end_id=10)),
    lambda: DynamicCutSampler(DummyManifest(CutSet, begin_id=0, end_id=10)),
    lambda: CutPairsSampler(
        DummyManifest(CutSet, begin_id=0, end_id=10),
        DummyManifest(CutSet, begin_id=0, end_id=10),
    ),
    lambda: BucketingSampler(DummyManifest(CutSet, begin_id=0, end_id=10),
                             num_buckets=2),
    lambda: DynamicBucketingSampler(
        DummyManifest(CutSet, begin_id=0, end_id=10),
        max_duration=1.0,
        num_buckets=2,
    ),
    lambda: ZipSampler(
        SimpleCutSampler(DummyManifest(CutSet, begin_id=0, end_id=10)),
        SimpleCutSampler(DummyManifest(CutSet, begin_id=10, end_id=20)),
    ),
    lambda: RoundRobinSampler(
        SimpleCutSampler(DummyManifest(CutSet, begin_id=0, end_id=10)),
        SimpleCutSampler(DummyManifest(CutSet, begin_id=10, end_id=20)),
    ),
]
Example no. 20
@pytest.mark.parametrize(
    "sampler",
    [
        SimpleCutSampler(DummyManifest(CutSet, begin_id=0, end_id=10)),
        CutPairsSampler(
            DummyManifest(CutSet, begin_id=0, end_id=10),
            DummyManifest(CutSet, begin_id=0, end_id=10),
        ),
        BucketingSampler(DummyManifest(CutSet, begin_id=0, end_id=10)),
        ZipSampler(
            SimpleCutSampler(DummyManifest(CutSet, begin_id=0, end_id=10)),
            SimpleCutSampler(DummyManifest(CutSet, begin_id=10, end_id=20)),
        ),
    ],
)
def test_sampler_get_report(sampler):
    _ = [b for b in sampler]
    print(sampler.get_report())
    # It runs - voila!


@pytest.mark.parametrize(
    "sampler_cls",
    [