Пример #1
0
def test_bucketing_sampler_shuffle():
    """Check that a shuffling BucketingSampler covers all cuts and reshuffles per epoch.

    For epochs 0 and 1, every cut ID of the manifest must appear in the
    collected batches, and the set of batches must differ between the two
    epochs.
    """
    cut_set = DummyManifest(CutSet, begin_id=0, end_id=10)
    sampler = BucketingSampler(
        cut_set,
        sampler_type=SimpleCutSampler,
        shuffle=True,
        num_buckets=2,
        max_frames=200,
    )

    def collect_epoch(epoch):
        """Set the epoch and return its batches as tuples of cut IDs."""
        sampler.set_epoch(epoch)
        # Each batch is stored as a tuple of IDs so that it is hashable
        # and can later be placed in a set.
        return [tuple(cut.id for cut in batch) for batch in sampler]

    epoch0_batches = collect_epoch(0)
    assert set(cut_set.ids) == {
        cut_id for id_tuple in epoch0_batches for cut_id in id_tuple
    }

    epoch1_batches = collect_epoch(1)
    assert set(cut_set.ids) == {
        cut_id for id_tuple in epoch1_batches for cut_id in id_tuple
    }

    # The order in which BucketingSampler emits batches may change between
    # epochs, so compare sets of batches to make that ordering irrelevant.
    # The inner sampler (SimpleCutSampler) is expected to group cuts
    # differently in different epochs, hence the sets must not be equal.
    assert set(epoch0_batches) != set(epoch1_batches)
Пример #2
0
def test_bucketing_sampler_order_is_deterministic_given_epoch():
    """Check that two passes over the sampler within one epoch yield equal batches."""
    cut_set = DummyManifest(CutSet, begin_id=0, end_id=1000)
    sampler = BucketingSampler(cut_set, sampler_type=SimpleCutSampler)

    sampler.set_epoch(42)

    # Iterate twice without touching the epoch in between. Comprehensions are
    # used on purpose: unlike list(), they only rely on the iteration protocol
    # and never ask the sampler for a length hint.
    first_pass = [batch for batch in sampler]
    second_pass = [batch for batch in sampler]

    assert first_pass == second_pass
Пример #3
0
def test_bucketing_sampler_order_differs_between_epochs():
    """Check that each of epochs 1..5 produces batches differing from the previous pass."""
    cut_set = DummyManifest(CutSet, begin_id=0, end_id=1000)
    sampler = BucketingSampler(cut_set, sampler_type=SimpleCutSampler)

    # Baseline pass, taken before any explicit set_epoch() call.
    previous_pass = [batch for batch in sampler]

    for epoch in range(1, 6):
        sampler.set_epoch(epoch)

        current_pass = []
        for batch in sampler:
            current_pass.append(batch)

        # Every new epoch has to differ from the pass that came right before it.
        assert current_pass != previous_pass
        previous_pass = current_pass
Пример #4
0
def test_bucketing_sampler_len():
    """Check that ``len(sampler)`` equals the number of batches actually yielded.

    The check runs once in the sampler's initial state (no ``set_epoch()``
    call yet) and then once for each of epochs 0..4, with the epoch set
    *before* the length is compared, so that every epoch that is set is
    also verified.
    """
    # Durations 1..10 repeated 10 times: 100 cuts, 550 seconds in total.
    # (Per the original note, each second corresponds to 100 frames.)
    durations = list(range(1, 11)) * 10
    cuts = CutSet.from_cuts(
        dummy_cut(idx, duration=float(duration))
        for idx, duration in enumerate(durations)
    )

    sampler = BucketingSampler(
        cuts,
        num_buckets=4,
        shuffle=True,
        max_frames=64 * 100,
        max_cuts=6,
    )

    # Initial state, before any explicit epoch has been set.
    # len() is evaluated first, then the sampler is iterated to count batches.
    assert len(sampler) == len([item for item in sampler])

    for epoch in range(5):
        # Set the epoch first; previously it was set after the assertion,
        # which left the last epoch unchecked.
        sampler.set_epoch(epoch)
        assert len(sampler) == len([item for item in sampler])