コード例 #1
0
def test_zip_sampler_merge_batches_true():
    """ZipSampler with ``merge_batches=True`` yields one flat batch per step.

    Each yielded batch combines ten cuts drawn by the first sub-sampler with
    two cuts drawn by the second one, for twelve cuts in total.
    """
    # Note: each cut is 1s duration in this test.
    cuts1 = DummyManifest(CutSet, begin_id=0, end_id=100)
    cuts2 = DummyManifest(CutSet, begin_id=1000, end_id=1100)
    sampler = ZipSampler(
        SimpleCutSampler(cuts1, max_duration=10),
        SimpleCutSampler(cuts2, max_duration=2),
        # Pass explicitly so the test does not silently depend on the default.
        merge_batches=True,
    )

    def count_ids_in_range(batch, lo, hi):
        """Count cuts whose numeric id suffix lies in the half-open range [lo, hi)."""
        return sum(1 for c in batch if lo <= int(c.id.split("-")[-1]) < hi)

    # Iterate explicitly (rather than list(sampler)) so the sampler's
    # __len__ is never consulted as a length hint.
    batches = [b for b in sampler]
    assert len(batches) == 10
    for batch in batches:
        assert len(batch) == 12  # twelve 1s items
        assert count_ids_in_range(batch, 0, 100) == 10  # ten come from cuts1
        assert count_ids_in_range(batch, 1000, 1100) == 2  # two come from cuts2
コード例 #2
0
def test_zip_sampler_merge_batches_false():
    """ZipSampler with ``merge_batches=False`` yields one batch per sub-sampler.

    Every step is a pair: ten cuts from the first sub-sampler and two cuts
    from the second one.
    """
    first_cuts = DummyManifest(CutSet, begin_id=0, end_id=100)
    second_cuts = DummyManifest(CutSet, begin_id=1000, end_id=1100)
    # Each dummy cut lasts 1s, so these max_duration values correspond to the
    # 10-cut and 2-cut batch sizes asserted below.
    zipped = ZipSampler(
        SimpleCutSampler(first_cuts, max_duration=10),
        SimpleCutSampler(second_cuts, max_duration=2),
        merge_batches=False,
    )

    def id_suffix(cut):
        # Numeric index encoded after the last "-" of the cut id.
        return int(cut.id.split("-")[-1])

    all_steps = [step for step in zipped]
    assert len(all_steps) == 10
    for left, right in all_steps:
        assert len(left) == 10
        # ten come from cuts1
        assert sum(1 for cut in left if 0 <= id_suffix(cut) <= 100) == 10
        assert len(right) == 2
        # two come from cuts2
        assert sum(1 for cut in right if 1000 <= id_suffix(cut) <= 1100) == 2
コード例 #3
0
    assert len(batches) == expected_num_batches
    assert num_sampled_cuts == expected_num_cuts
    assert num_discarded_cuts == expected_discarded_cuts


@pytest.mark.parametrize(
    "sampler",
    [
        SimpleCutSampler(DummyManifest(CutSet, begin_id=0, end_id=10)),
        CutPairsSampler(
            DummyManifest(CutSet, begin_id=0, end_id=10),
            DummyManifest(CutSet, begin_id=0, end_id=10),
        ),
        BucketingSampler(DummyManifest(CutSet, begin_id=0, end_id=10)),
        ZipSampler(
            SimpleCutSampler(DummyManifest(CutSet, begin_id=0, end_id=10)),
            SimpleCutSampler(DummyManifest(CutSet, begin_id=10, end_id=20)),
        ),
    ],
)
def test_sampler_get_report(sampler):
    """Smoke test: ``get_report()`` can be called after iterating the sampler.

    The sampler is exhausted first (presumably so the report reflects a full
    pass); the test succeeds as long as nothing raises.
    """
    for _batch in sampler:
        pass  # drain the sampler; the batches themselves are not inspected
    report = sampler.get_report()
    print(report)


@pytest.mark.parametrize(
    "sampler_cls",
    [
        SimpleCutSampler,
        BucketingSampler,
    ],