Exemplo n.º 1
0
def test_cut_pairs_sampler_filter():
    """Check that a filtered-out cut is absent from a full sampling epoch."""
    # Every dummy cut lasts 1 second.
    cut_set = DummyManifest(CutSet, begin_id=0, end_id=100)
    sampler = CutPairsSampler(
        cut_set,
        cut_set,
        shuffle=True,
        # 1s == 100 frames per cut, so this limit yields an effective batch
        # size of 10 cuts and the epoch spans multiple batches.
        max_source_frames=1000,
    )
    dropped_id = "dummy-cut-0010"

    def keep(cut):
        # Predicate handed to the sampler: reject only the dropped cut.
        return cut.id != dropped_id

    sampler.filter(keep)

    # Drain one epoch, collecting both sides of every batch pair.
    sampled_sources = []
    sampled_targets = []
    for pair in sampler:
        batch_sources, batch_targets = pair
        sampled_sources += batch_sources
        sampled_targets += batch_targets

    source_ids = [cut.id for cut in sampled_sources]
    target_ids = [cut.id for cut in sampled_targets]

    # The filtered cut exists in the manifest but was never sampled.
    assert dropped_id in set(cut_set.ids)
    assert dropped_id not in source_ids

    # Invariant 1: an epoch yields as many items as there were in the CutSet,
    # minus the filtered item.
    expected_count = len(cut_set) - 1
    assert len(sampled_sources) == expected_count
    assert len(sampled_targets) == expected_count
    # Invariant 2: the items are not duplicated.
    assert len(set(source_ids)) == len(sampled_sources)
    assert len(set(target_ids)) == len(sampled_targets)
Exemplo n.º 2
0
def test_cut_pairs_sampler_len():
    """Check that ``len(sampler)`` matches the number of batches it yields."""
    # Cut durations are 1s..10s (55 seconds in total); each second has 100 frames.
    cuts = CutSet.from_cuts(
        dummy_cut(number, duration=float(number)) for number in range(1, 11)
    )
    frame_budget = 10 * 100
    sampler = CutPairsSampler(
        source_cuts=cuts,
        target_cuts=cuts,
        shuffle=True,
        max_source_frames=frame_budget,
        max_target_frames=frame_budget,
    )

    for epoch in range(5):
        # Query the reported length first, then count what is actually produced.
        reported = len(sampler)
        produced = sum(1 for _ in sampler)
        assert reported == produced
        # The epoch is advanced only after the check, as in the original flow.
        sampler.set_epoch(epoch)
Exemplo n.º 3
0
def test_cut_pairs_sampler():
    """Check the basic epoch invariants of a shuffled ``CutPairsSampler``."""
    # Every dummy cut lasts 1 second.
    cut_set = DummyManifest(CutSet, begin_id=0, end_id=100)

    sampler = CutPairsSampler(
        source_cuts=cut_set,
        target_cuts=cut_set,
        shuffle=True,
        # 1s == 100 frames per cut; these limits keep batches small enough
        # that a full epoch is made of multiple batches.
        max_source_frames=1000,
        max_target_frames=500,
    )

    # Drain one epoch, collecting both sides of every batch pair.
    sampled_sources = []
    sampled_targets = []
    for pair in sampler:
        batch_sources, batch_targets = pair
        sampled_sources += batch_sources
        sampled_targets += batch_targets

    source_ids = [cut.id for cut in sampled_sources]
    target_ids = [cut.id for cut in sampled_targets]
    manifest_size = len(cut_set)

    # Invariant 1: an epoch yields as many items as there were in the CutSet.
    assert len(sampled_sources) == manifest_size
    assert len(sampled_targets) == manifest_size
    # Invariant 2: the items are not duplicated.
    assert len(set(source_ids)) == len(sampled_sources)
    assert len(set(target_ids)) == len(sampled_targets)
    # Invariant 3: the items are shuffled, i.e. their order differs from the CutSet's.
    assert source_ids != [cut.id for cut in cut_set]
    # Invariant 4: the source and target cuts come out in the same order.
    assert source_ids == target_ids
Exemplo n.º 4
0
def test_cut_pairs_sampler_order_is_deterministic_given_epoch():
    """Check that two passes over the sampler in the same epoch are identical."""
    # Every dummy cut lasts 1 second.
    cut_set = DummyManifest(CutSet, begin_id=0, end_id=100)

    sampler = CutPairsSampler(
        source_cuts=cut_set,
        target_cuts=cut_set,
        shuffle=True,
        # 1s == 100 frames per cut; these limits keep batches small enough
        # that a full epoch is made of multiple batches.
        max_source_frames=1000,
        max_target_frames=500,
    )
    sampler.set_epoch(42)

    # Iterate twice without touching the epoch in between; the ordering must match.
    first_pass = [item for item in sampler]
    second_pass = [item for item in sampler]
    assert first_pass == second_pass
Exemplo n.º 5
0
def test_zip_sampler_cut_pairs_merge_batches_true():
    """Check that zipping two pair samplers yields batches mixing both manifests.

    No merge flag is passed to ``ZipSampler`` here, so the test relies on its
    default settings.
    """
    cuts1 = DummyManifest(CutSet, begin_id=0, end_id=100)
    cuts2 = DummyManifest(CutSet, begin_id=1000, end_id=1100)
    # Note: each cut is 1s duration in this test.
    first_sampler = CutPairsSampler(cuts1, cuts1, max_source_duration=10)
    second_sampler = CutPairsSampler(cuts2, cuts2, max_source_duration=2)
    sampler = ZipSampler(first_sampler, second_sampler)

    batches = [batch for batch in sampler]
    assert len(batches) == 10

    for batch_src, batch_tgt in batches:
        assert len(batch_src) == len(batch_tgt)
        assert len(batch_src) == 12  # twelve 1s items

        # The numeric suffix of the cut id tells which manifest a cut came from.
        suffixes = [int(cut.id.split("-")[-1]) for cut in batch_src]
        from_cuts1 = [value for value in suffixes if 0 <= value <= 100]
        from_cuts2 = [value for value in suffixes if 1000 <= value <= 1100]
        assert len(from_cuts1) == 10  # ten come from cuts1
        assert len(from_cuts2) == 2  # two come from cuts2
Exemplo n.º 6
0
def test_cut_pairs_sampler_order_differs_between_epochs():
    """Check that each new epoch produces a different ordering than the previous one."""
    # Every dummy cut lasts 1 second.
    cut_set = DummyManifest(CutSet, begin_id=0, end_id=100)

    sampler = CutPairsSampler(
        source_cuts=cut_set,
        target_cuts=cut_set,
        shuffle=True,
        # 1s == 100 frames per cut; these limits keep batches small enough
        # that a full epoch is made of multiple batches.
        max_source_frames=1000,
        max_target_frames=500,
    )

    # Baseline ordering, taken before any explicit epoch is set.
    previous_order = [item for item in sampler]
    for epoch in range(1, 6):
        sampler.set_epoch(epoch)
        current_order = [item for item in sampler]
        # Consecutive epochs must not repeat the same ordering.
        assert current_order != previous_order
        previous_order = current_order
Exemplo n.º 7
0
def test_cut_pairs_sampler_2():
    """Check the size of the first batch for two cuts that fit within the limits."""
    # Two cuts of 10s and 20s; the 50s source/target limits exceed their 30s total.
    shorter_cut = dummy_cut(0, duration=10)
    longer_cut = dummy_cut(1, duration=20)
    cut_set = CutSet.from_cuts([shorter_cut, longer_cut])

    sampler = CutPairsSampler(
        source_cuts=cut_set,
        target_cuts=cut_set,
        max_source_duration=50,
        max_target_duration=50,
    )

    first_batch = next(iter(sampler))
    # NOTE(review): other tests in this file unpack each batch as a
    # (source, target) pair. If that applies here too, this checks the length
    # of the pair rather than the number of cuts — confirm the intent.
    assert len(first_batch) == 2
Exemplo n.º 8
0
def test_cut_pairs_sampler_time_constraints(
    max_duration,
    max_frames,
    max_samples,
    exception_expectation,
):
    """Check epoch invariants under duration/frame/sample limits.

    ``exception_expectation`` is used as a context manager around sampler
    construction, iteration and the assertions. The arguments are presumably
    supplied by a pytest parametrization that is not visible in this block.
    """

    def halved(limit):
        # Target-side limits are half the source-side ones; None stays None.
        return limit / 2 if limit is not None else None

    # Every dummy cut lasts 1 second.
    cut_set = DummyManifest(CutSet, begin_id=0, end_id=100)
    if max_frames is None:
        # Without a frame limit the cuts are used with their features dropped.
        cut_set = cut_set.drop_features()

    with exception_expectation:
        sampler = CutPairsSampler(
            source_cuts=cut_set,
            target_cuts=cut_set,
            shuffle=True,
            # Each source limit is paired with a target limit of half its value.
            max_source_frames=max_frames,
            max_target_frames=halved(max_frames),
            max_source_samples=max_samples,
            max_target_samples=halved(max_samples),
            max_source_duration=max_duration,
            max_target_duration=halved(max_duration),
        )

        # Flatten every batch of the epoch into one list.
        sampled = []
        for batch in sampler:
            sampled += batch

        # Invariant 1: an epoch yields as many items as there were in the CutSet.
        assert len(sampled) == len(cut_set)
        # Invariant 2: the items are not duplicated.
        assert len(set(sampled)) == len(sampled)
        # Invariant 3: the items are shuffled, i.e. their order differs from the CutSet's.
        manifest_ids = [cut.id for cut in cut_set]
        assert sampled != manifest_ids
Exemplo n.º 9
0
        expected_num_cuts = 50
        expected_discarded_cuts = 0

    num_sampled_cuts = sum(len(b) for b in batches)
    num_discarded_cuts = len(cut_set) - num_sampled_cuts
    assert len(batches) == expected_num_batches
    assert num_sampled_cuts == expected_num_cuts
    assert num_discarded_cuts == expected_discarded_cuts


@pytest.mark.parametrize(
    "sampler",
    [
        SimpleCutSampler(DummyManifest(CutSet, begin_id=0, end_id=10)),
        CutPairsSampler(
            DummyManifest(CutSet, begin_id=0, end_id=10),
            DummyManifest(CutSet, begin_id=0, end_id=10),
        ),
        BucketingSampler(DummyManifest(CutSet, begin_id=0, end_id=10)),
        ZipSampler(
            SimpleCutSampler(DummyManifest(CutSet, begin_id=0, end_id=10)),
            SimpleCutSampler(DummyManifest(CutSet, begin_id=10, end_id=20)),
        ),
    ],
)
def test_sampler_get_report(sampler):
    """Smoke test: ``get_report()`` can be called after a full pass over the sampler."""
    # Exhaust the sampler once; the batches themselves are not inspected.
    for _batch in sampler:
        pass
    report = sampler.get_report()
    print(report)
    # Nothing is asserted: the test passes as long as nothing above raises.


@pytest.mark.parametrize(