Exemplo n.º 1
0
def test_single_cut_sampler_time_constraints(
    max_duration, max_frames, max_samples, exception_expectation
):
    """Check SimpleCutSampler over a full epoch under the given batch constraints.

    ``max_duration``, ``max_frames`` and ``max_samples`` are forwarded as-is to
    the sampler. ``exception_expectation`` is a context manager wrapping both
    sampler construction and iteration. Presumably it is ``pytest.raises(...)``
    for invalid constraint combinations and a no-op context otherwise (TODO:
    confirm against the parametrization, which is not visible here).

    When ``max_frames`` is None, the features are dropped from the dummy cuts
    before sampling.
    """
    # The dummy cuts have a duration of 1 second each.
    cut_set = DummyManifest(CutSet, begin_id=0, end_id=100)
    if max_frames is None:
        cut_set = cut_set.drop_features()

    with exception_expectation:
        sampler = SimpleCutSampler(
            cut_set,
            shuffle=True,
            # The limits are meant to give an effective batch size of 10 cuts
            # (every cut is 1s == 100 frames). That way a full epoch is
            # returned over several batches, which is what we want to test.
            max_frames=max_frames,
            max_samples=max_samples,
            max_duration=max_duration,
        )
        sampled_cuts = []
        for batch in sampler:
            sampled_cuts.extend(batch)

        sampled_ids = [c.id for c in sampled_cuts]
        expected_ids = [c.id for c in cut_set]

        # Invariant 1: an epoch yields as many items as there were in the CutSet.
        assert len(sampled_ids) == len(expected_ids)
        # Invariant 2: the items are not duplicated...
        assert len(set(sampled_ids)) == len(sampled_ids)
        # ...and they are exactly the items of the CutSet (not just the same count).
        assert set(sampled_ids) == set(expected_ids)
        # Invariant 3: the items are shuffled, i.e. the order differs from the CutSet's.
        assert sampled_ids != expected_ids
Exemplo n.º 2
0
def test_cut_pairs_sampler_time_constraints(max_duration, max_frames,
                                            max_samples,
                                            exception_expectation):
    """Check CutPairsSampler over a full epoch under the given batch constraints.

    The same CutSet is used as both source and target. The source limits are
    taken from the arguments, and each target limit is half of the matching
    source limit (or None when the source limit is None).

    ``exception_expectation`` is a context manager wrapping both sampler
    construction and iteration. Presumably it is ``pytest.raises(...)`` for
    invalid constraint combinations and a no-op context otherwise (TODO:
    confirm against the parametrization, which is not visible here).

    When ``max_frames`` is None, the features are dropped from the dummy cuts
    before sampling.
    """

    def _half(limit):
        # Target limits are half the source ones; None means "no limit".
        # True division is used, so integer limits become floats.
        return limit / 2 if limit is not None else None

    # The dummy cuts have a duration of 1 second each.
    cut_set = DummyManifest(CutSet, begin_id=0, end_id=100)
    if max_frames is None:
        cut_set = cut_set.drop_features()

    with exception_expectation:
        sampler = CutPairsSampler(
            source_cuts=cut_set,
            target_cuts=cut_set,
            shuffle=True,
            # The limits are meant to give an effective batch size of 10 cuts
            # (every cut is 1s == 100 frames). That way a full epoch is
            # returned over several batches, which is what we want to test.
            max_source_frames=max_frames,
            max_target_frames=_half(max_frames),
            max_source_samples=max_samples,
            max_target_samples=_half(max_samples),
            max_source_duration=max_duration,
            max_target_duration=_half(max_duration),
        )
        source_cuts, target_cuts = [], []
        for src_batch, tgt_batch in sampler:
            source_cuts.extend(src_batch)
            target_cuts.extend(tgt_batch)

        source_ids = [c.id for c in source_cuts]
        target_ids = [c.id for c in target_cuts]
        expected_ids = [c.id for c in cut_set]

        # Invariant 1: an epoch yields as many items as there were in the CutSet.
        assert len(source_ids) == len(expected_ids)
        assert len(target_ids) == len(expected_ids)
        # Invariant 2: the items are not duplicated...
        assert len(set(source_ids)) == len(source_ids)
        assert len(set(target_ids)) == len(target_ids)
        # ...and they are exactly the items of the CutSet (not just the same count).
        assert set(source_ids) == set(expected_ids)
        assert set(target_ids) == set(expected_ids)
        # Invariant 3: the items are shuffled, i.e. the order differs from the CutSet's.
        assert source_ids != expected_ids
        # Invariant 4: the source and target cuts are in the same order.
        assert source_ids == target_ids