Example #1
0
def test_dynamic_bucketing_sampler_filter():
    """Check that ``DynamicBucketingSampler.filter`` drops cuts failing the predicate.

    Ten dummy cuts are created: the first five get ``duration = 1`` and the
    last five get ``duration = 2``. After filtering to cuts with
    ``duration > 1``, exactly the five 2-duration cuts must be sampled, with
    no duplicates, in the batch layout produced by ``seed=0``.
    """
    cuts = DummyManifest(CutSet, begin_id=0, end_id=10)
    for i, c in enumerate(cuts):
        # First half: duration 1 (removed by the filter below); second half: duration 2.
        if i < 5:
            c.duration = 1
        else:
            c.duration = 2

    sampler = DynamicBucketingSampler(cuts,
                                      max_duration=5,
                                      num_buckets=2,
                                      seed=0)
    sampler.filter(lambda cut: cut.duration > 1)
    batches = list(sampler)
    sampled_cuts = [c for b in batches for c in b]

    # Invariant: no duplicated cut IDs
    assert len({c.id for c in sampled_cuts}) == len(sampled_cuts)

    # Fewer sampled cuts than source cuts: only the five duration-2 cuts pass the filter.
    assert len(sampled_cuts) < len(cuts)
    assert len(sampled_cuts) == 5

    # We sampled 3 batches with this RNG (seed=0), like the following:
    assert len(batches) == 3

    assert len(batches[0]) == 2
    assert sum(c.duration for c in batches[0]) == 4

    assert len(batches[1]) == 2
    assert sum(c.duration for c in batches[1]) == 4

    assert len(batches[2]) == 1
    assert sum(c.duration for c in batches[2]) == 2
Example #2
0
def test_dynamic_bucketing_sampler_cut_pairs_filter():
    """Check that filtering a paired (source, target) sampler keeps both sides in sync.

    The same CutSet is passed as both source and target. The first five cuts
    get ``duration = 1`` and the rest ``duration = 2``. After filtering to
    ``duration > 1``, five cut pairs must remain, split into three batches
    for ``seed=0``.
    """
    cuts = DummyManifest(CutSet, begin_id=0, end_id=10)
    for idx, cut in enumerate(cuts):
        cut.duration = 2 if idx >= 5 else 1

    sampler = DynamicBucketingSampler(
        cuts, cuts, max_duration=5, num_buckets=2, seed=0
    )
    sampler.filter(lambda c: c.duration > 1)
    batches = list(sampler)

    # Each batch holds a source and a target collection; zip them into per-cut pairs.
    sampled_cut_pairs = [pair for batch in batches for pair in zip(*batch)]
    source_cuts = [src for src, _tgt in sampled_cut_pairs]
    target_cuts = [tgt for _src, tgt in sampled_cut_pairs]

    # Invariant: no duplicated cut IDs (there are 5 unique IDs)
    assert len({c.id for c in source_cuts}) == 5
    assert len({c.id for c in target_cuts}) == 5

    # Smaller number of sampled cuts than the source cuts.
    assert len(sampled_cut_pairs) < len(cuts)
    assert len(sampled_cut_pairs) == 5

    # We sampled 3 batches with this RNG, like the following:
    assert len(batches) == 3

    # Per batch: (number of cuts on each side, total duration on each side).
    expected_layout = [(2, 4), (2, 4), (1, 2)]
    for bidx, (num_cuts, total_duration) in enumerate(expected_layout):
        sc, tc = batches[bidx][0], batches[bidx][1]
        assert len(sc) == num_cuts
        assert len(tc) == num_cuts
        assert sum(c.duration for c in sc) == total_duration
        assert sum(c.duration for c in tc) == total_duration