def test_bucketing_sampler_shuffle():
    """A shuffling BucketingSampler covers every cut each epoch, yet regroups cuts across epochs."""
    cut_set = DummyManifest(CutSet, begin_id=0, end_id=10)
    sampler = BucketingSampler(
        cut_set,
        sampler_type=SimpleCutSampler,
        shuffle=True,
        num_buckets=2,
        max_frames=200,
    )
    expected_ids = set(cut_set.ids)

    sampler.set_epoch(0)
    # Each batch is turned into a tuple of cut IDs so whole batches are hashable.
    first_epoch = [tuple(cut.id for cut in batch) for batch in sampler]
    assert expected_ids == {cut_id for ids in first_epoch for cut_id in ids}

    sampler.set_epoch(1)
    second_epoch = [tuple(cut.id for cut in batch) for batch in sampler]
    assert expected_ids == {cut_id for ids in second_epoch for cut_id in ids}

    # Comparing sets ignores the order in which BucketingSampler emits batches;
    # what must differ is how the inner SimpleCutSampler groups cuts into batches.
    assert set(first_epoch) != set(second_epoch)
def test_bucketing_sampler_order_is_deterministic_given_epoch():
    """Two full passes at the same epoch must yield identical batch sequences."""
    cut_set = DummyManifest(CutSet, begin_id=0, end_id=1000)
    sampler = BucketingSampler(cut_set, sampler_type=SimpleCutSampler)
    sampler.set_epoch(42)
    # No set_epoch() between the passes: the ordering must be reproducible.
    # Explicit iteration (not list(sampler)) avoids triggering __len__ via length hints.
    first_pass = [batch for batch in sampler]
    second_pass = [batch for batch in sampler]
    assert first_pass == second_pass
def test_bucketing_sampler_order_differs_between_epochs():
    """Every consecutive epoch (1 through 5) must change the batch ordering."""
    cut_set = DummyManifest(CutSet, begin_id=0, end_id=1000)
    sampler = BucketingSampler(cut_set, sampler_type=SimpleCutSampler)

    def _one_pass():
        # Drain the sampler once at its current epoch.
        return [batch for batch in sampler]

    # Baseline: whatever epoch the sampler starts at (set_epoch not yet called).
    previous = _one_pass()
    for epoch in range(1, 6):
        sampler.set_epoch(epoch)
        current = _one_pass()
        assert current != previous
        previous = current
def test_bucketing_sampler_len():
    """``len(sampler)`` must equal the number of batches actually yielded, for each epoch."""
    # Durations 1..10 s repeated 10 times -> 10 * 55 = 550 seconds in total.
    # Each second has 100 frames, so max_frames=64 * 100 corresponds to 64 s per batch.
    cuts = CutSet.from_cuts(
        dummy_cut(idx, duration=float(duration))
        for idx, duration in enumerate(list(range(1, 11)) * 10)
    )
    sampler = BucketingSampler(
        cuts, num_buckets=4, shuffle=True, max_frames=64 * 100, max_cuts=6
    )
    for epoch in range(5):
        # Set the epoch before checking so each of epochs 0..4 is validated exactly once.
        # With set_epoch after the assert, the initial epoch was checked twice and
        # the final epoch was never checked.
        sampler.set_epoch(epoch)
        assert len(sampler) == len([batch for batch in sampler])