def test_bucketing_sampler_time_constraints(constraint):
    """Sampling under any time constraint must cover every cut exactly once."""
    cut_set = DummyManifest(CutSet, begin_id=0, end_id=1000)
    sampler = BucketingSampler(cut_set, sampler_type=SimpleCutSampler, **constraint)
    # Flatten all batches and compare the set of IDs against the source manifest.
    seen_cuts = [c for batch in sampler for c in batch]
    assert set(cut_set.ids) == set(c.id for c in seen_cuts)
def test_bucketing_sampler_single_cuts_no_proportional_sampling():
    """Disabling proportional sampling must not drop or duplicate any cut."""
    cut_set = DummyManifest(CutSet, begin_id=0, end_id=1000)
    sampler = BucketingSampler(
        cut_set, proportional_sampling=False, sampler_type=SimpleCutSampler
    )
    collected = []
    for batch in sampler:
        collected += list(batch)
    assert set(cut_set.ids) == {c.id for c in collected}
def test_split_randomize(manifest_type):
    """split(shuffle=True) keeps all items but changes their overall order."""
    manifest = DummyManifest(manifest_type, begin_id=0, end_id=100)
    parts = manifest.split(num_splits=2, shuffle=True)
    assert len(parts) == 2
    recombined = list(parts[0]) + list(parts[1])
    assert len(recombined) == len(manifest)
    # The ordering must differ. We compare plain lists because the *Set classes
    # might internally re-order after concatenation (e.g. dict-backed storage
    # or post-init sorting).
    assert recombined != list(manifest)
def test_bucketing_sampler_raises_value_error_on_lazy_cuts_input():
    """BucketingSampler must reject lazily-opened cut manifests with ValueError."""
    cut_set = DummyManifest(CutSet, begin_id=0, end_id=2)
    with NamedTemporaryFile(suffix=".jsonl") as f:
        cut_set.to_jsonl(f.name)
        lazy_cuts = CutSet.from_jsonl_lazy(f.name)
        with pytest.raises(ValueError):
            BucketingSampler(lazy_cuts, max_duration=10.0)
def test_lazy_cuts_combine_split_issue():
    """Regression check: combining lazy cuts then split_lazy must not raise."""
    cuts = DummyManifest(CutSet, begin_id=0, end_id=1000)
    with TemporaryDirectory() as d, NamedTemporaryFile(suffix=".jsonl.gz") as f:
        cuts.to_file(f.name)
        f.flush()
        lazy = load_manifest_lazy(f.name)
        lazy = combine(lazy, lazy.perturb_speed(0.9))
        lazy.split_lazy(d, chunk_size=100)
def test_bucketing_sampler_order_differs_between_epochs():
    """Advancing the epoch must change the bucketing sampler's batch order."""
    cut_set = DummyManifest(CutSet, begin_id=0, end_id=1000)
    sampler = BucketingSampler(cut_set, sampler_type=SimpleCutSampler)
    prev_order = list(sampler)
    for epoch in range(1, 6):
        sampler.set_epoch(epoch)
        cur_order = list(sampler)
        assert cur_order != prev_order
        prev_order = cur_order
def test_zip_sampler_cut_pairs_merge_batches_true():
    """ZipSampler (default merging) fuses cut-pair batches from both samplers."""
    cuts1 = DummyManifest(CutSet, begin_id=0, end_id=100)
    cuts2 = DummyManifest(CutSet, begin_id=1000, end_id=1100)
    # Note: each cut is 1s duration in this test.
    sampler = ZipSampler(
        CutPairsSampler(cuts1, cuts1, max_source_duration=10),
        CutPairsSampler(cuts2, cuts2, max_source_duration=2),
    )
    batches = list(sampler)
    assert len(batches) == 10
    for batch_src, batch_tgt in batches:
        assert len(batch_src) == len(batch_tgt)
        assert len(batch_src) == 12  # twelve 1s items
        from_first = [c for c in batch_src if 0 <= int(c.id.split("-")[-1]) <= 100]
        assert len(from_first) == 10  # ten come from cuts1
        from_second = [
            c for c in batch_src if 1000 <= int(c.id.split("-")[-1]) <= 1100
        ]
        assert len(from_second) == 2  # two come from cuts2
def test_filter(manifest_type):
    """filter() yields the same result for eager and lazy manifests."""
    expected = DummyManifest(manifest_type, begin_id=0, end_id=5)
    for i, item in enumerate(expected):
        item.duration = i

    # Keep only items shorter than 5 seconds.
    def keep(item):
        return item.duration < 5

    data = DummyManifest(manifest_type, begin_id=0, end_id=10)
    for i, item in enumerate(data):
        item.duration = i

    assert list(data.filter(keep)) == list(expected)
    with as_lazy(data) as lazy_data:
        assert list(lazy_data.filter(keep)) == list(expected)
def test_dynamic_bucketing_sampler_cut_pairs():
    """DynamicBucketingSampler over paired cuts covers each cut once and, for
    this fixed seed, produces a known batch structure."""
    cuts = DummyManifest(CutSet, begin_id=0, end_id=10)
    for i, c in enumerate(cuts):
        c.duration = 1 if i < 5 else 2
    sampler = DynamicBucketingSampler(cuts, cuts, max_duration=5, num_buckets=2, seed=0)
    batches = list(sampler)
    pairs = [pair for b in batches for pair in zip(*b)]
    source_cuts = [src for src, _ in pairs]
    target_cuts = [tgt for _, tgt in pairs]
    # Invariant: no duplicated cut IDs on either side.
    assert len({c.id for c in source_cuts}) == len(cuts)
    assert len({c.id for c in target_cuts}) == len(cuts)
    # Same number of sampled and source cuts.
    assert len(pairs) == len(cuts)
    # We sampled 4 batches with this RNG; check (size, total duration) per batch.
    assert len(batches) == 4
    expected_stats = [(2, 4), (2, 4), (5, 5), (1, 2)]
    for (src_batch, tgt_batch), (size, total_dur) in zip(batches, expected_stats):
        assert len(src_batch) == size
        assert len(tgt_batch) == size
        assert sum(c.duration for c in src_batch) == total_dur
        assert sum(c.duration for c in tgt_batch) == total_dur
def test_estimate_duration_buckets_2b():
    """Two buckets over five 1s cuts and five 2s cuts split at 2s."""
    cuts = DummyManifest(CutSet, begin_id=0, end_id=10)
    for i, c in enumerate(cuts):
        c.duration = 1 if i < 5 else 2
    assert estimate_duration_buckets(cuts, num_buckets=2) == [2]
def test_repeat(manifest_type):
    """repeat(times=2) equals self-concatenation, for eager and lazy manifests."""
    data = DummyManifest(manifest_type, begin_id=0, end_id=10)
    expected = data + data
    assert list(data.repeat(times=2)) == list(expected)
    with as_lazy(data) as lazy_data:
        assert list(lazy_data.repeat(times=2)) == list(expected)
def test_cut_set_subset_cut_ids_preserves_order_with_lazy_manifest():
    """subset(cut_ids=...) on a lazy manifest returns cuts in the requested order,
    not the on-disk order."""
    cuts = DummyManifest(CutSet, begin_id=0, end_id=1000)
    cut_ids = ["dummy-cut-0010", "dummy-cut-0171", "dummy-cut-0009"]
    with NamedTemporaryFile(suffix=".jsonl.gz") as f:
        cuts.to_file(f.name)
        cuts = cuts.from_jsonl_lazy(f.name)
        subcuts = cuts.subset(cut_ids=cut_ids)
        first, second, third = subcuts
        assert first.id == "dummy-cut-0010"
        assert second.id == "dummy-cut-0171"
        assert third.id == "dummy-cut-0009"
def test_perturb_speed():
    """PerturbSpeed with p=0.5 leaves every cut's duration near 0.9s, 1.0s, or 1.1s.

    Fix: removed a leftover debug ``print`` of the perturbed durations that
    polluted test output.
    """
    tfnm = PerturbSpeed(factors=[0.9, 1.1], p=0.5, randgen=random.Random(42))
    cuts = DummyManifest(CutSet, begin_id=0, end_id=10)
    cuts_sp = tfnm(cuts)
    assert all(
        # The duration will not be exactly 0.9 and 1.1 because perturb speed
        # will round to a physically-viable duration based on the sampling_rate
        # (i.e. round to the nearest sample count).
        any(isclose(cut.duration, v, abs_tol=0.0125) for v in [0.9, 1.0, 1.1])
        for cut in cuts_sp
    )
def test_zip_sampler_merge_batches_false():
    """With merge_batches=False, ZipSampler yields per-sampler batch tuples."""
    cuts1 = DummyManifest(CutSet, begin_id=0, end_id=100)
    cuts2 = DummyManifest(CutSet, begin_id=1000, end_id=1100)
    # Note: each cut is 1s duration in this test.
    sampler = ZipSampler(
        SingleCutSampler(cuts1, max_duration=10),
        SingleCutSampler(cuts2, max_duration=2),
        merge_batches=False,
    )
    batches = list(sampler)
    assert len(batches) == 10
    for first_batch, second_batch in batches:
        assert len(first_batch) == 10
        # ten come from cuts1
        assert (
            sum(1 for c in first_batch if 0 <= int(c.id.split("-")[-1]) <= 100) == 10
        )
        assert len(second_batch) == 2
        # two come from cuts2
        assert (
            sum(1 for c in second_batch if 1000 <= int(c.id.split("-")[-1]) <= 1100)
            == 2
        )
def test_bucketing_sampler_buckets_have_different_durations():
    """With two buckets, every batch contains cuts of a single duration."""
    cut_set_1s = DummyManifest(CutSet, begin_id=0, end_id=10)
    cut_set_2s = DummyManifest(CutSet, begin_id=10, end_id=20)
    for c in cut_set_2s:
        c.duration = 2.0
    cut_set = cut_set_1s + cut_set_2s
    # The bucketing sampler should return 5 batches with two 1s cuts,
    # and 10 batches with one 2s cut.
    sampler = BucketingSampler(
        cut_set, sampler_type=SimpleCutSampler, max_frames=200, num_buckets=2
    )
    batches = list(sampler)
    assert len(batches) == 15
    # All cuts within a batch share one duration (i.e. came from one bucket).
    for batch in batches:
        durations = [cut_set[c.id].duration for c in batch]
        assert all(d == durations[0] for d in durations)
    by_size = sorted(batches, key=len)
    assert all(len(b) == 1 for b in by_size[:10])
    assert all(len(b) == 2 for b in by_size[10:])
def test_single_cut_sampler_order_is_deterministic_given_epoch():
    """With a fixed epoch, two passes over the sampler are identical."""
    cut_set = DummyManifest(CutSet, begin_id=0, end_id=100)
    sampler = SingleCutSampler(
        cut_set,
        shuffle=True,
        # Effective batch size of 10 cuts, as all have 1s duration == 100 frames.
        # This way we're testing that it works okay when returning multiple
        # batches in a full epoch.
        max_frames=1000,
    )
    sampler.set_epoch(42)
    # Calling the sampler twice without an epoch update gives identical ordering.
    first_pass = list(sampler)
    second_pass = list(sampler)
    assert first_pass == second_pass
def test_sequential_jsonl_writer_overwrite(overwrite):
    """open_writer's overwrite flag controls whether prior items are remembered."""
    cuts = DummyManifest(CutSet, begin_id=0, end_id=100)
    half = cuts.split(num_splits=2)[0]
    with NamedTemporaryFile(suffix='.jsonl') as jsonl_f:
        # Store the first half.
        half.to_file(jsonl_f.name)
        # Re-open the same file with a sequential writer.
        with CutSet.open_writer(jsonl_f.name, overwrite=overwrite) as writer:
            for id_ in half.ids:
                if overwrite:
                    assert not writer.contains(id_)
                else:
                    assert writer.contains(id_)
def test_sampler_does_not_drop_cuts_with_multiple_ranks(world_size, sampler_fn):
    """Summed across all ranks, the samplers must emit every cut exactly once."""
    cuts = DummyManifest(CutSet, begin_id=0, end_id=10)
    total = 0
    for rank in range(world_size):
        sampler = sampler_fn(cuts, max_duration=1.0, world_size=world_size, rank=rank)
        total += sum(len(batch) for batch in sampler)
    assert total == len(cuts)
def test_partitions_are_equal(world_size, n_cuts, sampler_cls):
    """Every distributed rank must receive the same number of batches."""
    # Create a dummy CutSet.
    cut_set = DummyManifest(CutSet, begin_id=0, end_id=n_cuts)
    # Randomize the durations of cuts to increase the chance we run into edge cases.
    for c in cut_set:
        c.duration += 10 * random.random()
    # Create a sampler for each "distributed worker."
    samplers = [
        sampler_cls(cut_set, max_duration=25.0, rank=i, world_size=world_size)
        for i in range(world_size)
    ]
    # Check that it worked.
    batch_counts = [sum(1 for _ in s) for s in samplers]
    assert all(count == batch_counts[0] for count in batch_counts)
def test_single_cut_sampler_with_lazy_cuts(sampler_cls):
    """A sampler over lazily-combined cuts yields each cut exactly once."""
    # The dummy cuts have a duration of 1 second each.
    eager1 = DummyManifest(CutSet, begin_id=0, end_id=100)
    eager2 = DummyManifest(CutSet, begin_id=1000, end_id=1100)
    eager_cuts = eager1 + eager2
    with as_lazy(eager1) as lazy1, as_lazy(eager2) as lazy2:
        sampler = sampler_cls(
            lazy1 + lazy2,
            shuffle=False,
            # Effective batch size of 10 cuts, as all have 1s duration == 100
            # frames. This way we're testing that it works okay when returning
            # multiple batches in a full epoch.
            max_frames=1000,
        )
        sampled = [c for batch in sampler for c in batch]
        # Invariant 1: a dataloader epoch yields as many items as the CutSet holds.
        assert len(sampled) == len(eager_cuts)
        # Invariant 2: no item is duplicated.
        assert len({c.id for c in sampled}) == len(sampled)
def test_estimate_duration_buckets_4b():
    """Four buckets over 5 cuts each of 1s/2s/3s/4s split at [2, 3, 4]."""
    cuts = DummyManifest(CutSet, begin_id=0, end_id=20)
    # Groups of five cuts get durations 1, 2, 3, 4 respectively.
    group_durations = [1, 2, 3, 4]
    for i, c in enumerate(cuts):
        c.duration = group_durations[i // 5]
    assert estimate_duration_buckets(cuts, num_buckets=4) == [2, 3, 4]
def test_repeat_infinite(manifest_type):
    """repeat() with no count yields an endless stream, eagerly and lazily."""
    data = DummyManifest(manifest_type, begin_id=0, end_id=10)
    # Hard to test infinite iterables: draw 10x more items than the original
    # size and verify the stream never ran dry.
    for idx, _ in enumerate(data.repeat()):
        if idx == 105:
            break
    assert idx == 105
    with as_lazy(data) as lazy_data:
        for idx, _ in enumerate(lazy_data.repeat()):
            if idx == 105:
                break
        assert idx == 105
def test_shuffle(manifest_type):
    """shuffle() with a seeded RNG produces a known order, eagerly and lazily."""
    data = DummyManifest(manifest_type, begin_id=0, end_id=4)
    for i, item in enumerate(data):
        item.duration = i
    expected_durations = [2, 1, 3, 0]
    rng = random.Random(42)
    # Note: the eager shuffle consumes RNG state before the lazy one runs,
    # so the order of these two checks matters.
    shuffled = data.shuffle(rng=rng)
    assert [item.duration for item in shuffled] == list(expected_durations)
    with as_lazy(data) as lazy_data:
        lazy_shuffled = lazy_data.shuffle(rng=rng)
        assert [item.duration for item in lazy_shuffled] == list(expected_durations)
def test_combine(manifest_type):
    """combine() accepts manifests either as varargs or as a single iterable."""
    expected = DummyManifest(manifest_type, begin_id=0, end_id=200)

    # Three contiguous, non-overlapping pieces that together span [0, 200).
    def make_parts():
        return [
            DummyManifest(manifest_type, begin_id=0, end_id=68),
            DummyManifest(manifest_type, begin_id=68, end_id=136),
            DummyManifest(manifest_type, begin_id=136, end_id=200),
        ]

    assert combine(*make_parts()) == expected
    assert combine(make_parts()) == expected
def test_single_cut_sampler_order_differs_between_epochs():
    """Advancing the epoch must change the shuffled batch order."""
    cut_set = DummyManifest(CutSet, begin_id=0, end_id=100)
    sampler = SingleCutSampler(
        cut_set,
        shuffle=True,
        # Effective batch size of 10 cuts, as all have 1s duration == 100 frames.
        # This way we're testing that it works okay when returning multiple
        # batches in a full epoch.
        max_frames=1000,
    )
    prev_order = list(sampler)
    for epoch in range(1, 6):
        sampler.set_epoch(epoch)
        cur_order = list(sampler)
        assert cur_order != prev_order
        prev_order = cur_order
def test_bucketing_sampler_drop_last(drop_last):
    """drop_last controls whether incomplete per-bucket batches are discarded."""
    # CutSet that has 50 cuts: 10 have 1s, 10 have 2s, etc.
    cut_set = CutSet()
    for i in range(5):
        extra = DummyManifest(CutSet, begin_id=i * 10, end_id=(i + 1) * 10)
        for c in extra:
            c.duration = i + 1
        cut_set = cut_set + extra
    sampler = BucketingSampler(
        cut_set,
        sampler_type=SimpleCutSampler,
        max_duration=10.5,
        num_buckets=5,
        drop_last=drop_last,
    )
    batches = []
    for batch in sampler:
        # Assert there is a consistent cut duration per bucket in this test.
        for cut in batch:
            assert cut.duration == batch[0].duration
        batches.append(batch)
    # Expectation:
    if drop_last:
        # When drop_last = True:
        # 10 x 1s cuts == 1 batch (10 cuts each, 0 left over)
        # 10 x 2s cuts == 2 batches (5 cuts each, 0 left over)
        # 10 x 3s cuts == 3 batches (3 cuts each, 1 left over)
        # 10 x 4s cuts == 5 batches (2 cuts each, 0 left over)
        # 10 x 5s cuts == 5 batches (2 cuts each, 0 left over)
        expected_num_batches, expected_num_cuts, expected_discarded = 16, 49, 1
    else:
        # When drop_last = False there is one more batch with a single 3s cut.
        expected_num_batches, expected_num_cuts, expected_discarded = 17, 50, 0
    num_sampled = sum(len(b) for b in batches)
    assert len(batches) == expected_num_batches
    assert num_sampled == expected_num_cuts
    assert len(cut_set) - num_sampled == expected_discarded
def test_dynamic_bucketing_sampler_too_small_data_drop_last_true_results_in_no_batches():
    """With drop_last=True and too little data to fill a batch, nothing is yielded."""
    cuts = DummyManifest(CutSet, begin_id=0, end_id=10)
    for i, c in enumerate(cuts):
        c.duration = 1 if i < 5 else 2
    # The total duration (5x1s + 5x2s = 15s) cannot satisfy max_duration of 100
    # with 2 buckets, so every partial batch is dropped.
    sampler = DynamicBucketingSampler(
        cuts, max_duration=100, num_buckets=2, seed=0, drop_last=True
    )
    assert list(sampler) == []
def test_cut_pairs_sampler_order_is_deterministic_given_epoch():
    """With a fixed epoch, two passes over CutPairsSampler are identical."""
    # The dummy cuts have a duration of 1 second each.
    cut_set = DummyManifest(CutSet, begin_id=0, end_id=100)
    sampler = CutPairsSampler(
        source_cuts=cut_set,
        target_cuts=cut_set,
        shuffle=True,
        # Effective batch size of 10 cuts, as all have 1s duration == 100 frames.
        # This way we're testing that it works okay when returning multiple
        # batches in a full epoch.
        max_source_frames=1000,
        max_target_frames=500,
    )
    sampler.set_epoch(42)
    # Calling the sampler twice without an epoch update gives identical ordering.
    first_pass = list(sampler)
    second_pass = list(sampler)
    assert first_pass == second_pass
def test_single_cut_sampler_drop_last():
    """drop_last=True discards the final incomplete batch."""
    # The dummy cuts have a duration of 1 second each.
    cut_set = DummyManifest(CutSet, begin_id=0, end_id=100)
    sampler = SimpleCutSampler(
        cut_set,
        # Effective batch size of 15 cuts, as all have 1s duration == 100 frames.
        # 100 cuts give six full batches; the 10-cut remainder is dropped.
        max_frames=1500,
        drop_last=True,
    )
    batch_count = 0
    for batch in sampler:
        assert len(batch) == 15
        batch_count += 1
    assert batch_count == 6
def test_perturb_tempo(preserve_id: bool):
    """PerturbTempo keeps durations near 0.9s/1.0s/1.1s and honors preserve_id."""
    tfnm = PerturbTempo(
        factors=[0.9, 1.1], p=0.5, randgen=random.Random(42), preserve_id=preserve_id
    )
    cuts = DummyManifest(CutSet, begin_id=0, end_id=10)
    cuts_tp = tfnm(cuts)
    # The duration will not be exactly 0.9 and 1.1 because the perturbation
    # rounds to a physically-viable duration based on the sampling_rate
    # (i.e. round to the nearest sample count).
    for cut in cuts_tp:
        assert any(isclose(cut.duration, v, abs_tol=0.0125) for v in [0.9, 1.0, 1.1])
    if preserve_id:
        assert all(orig.id == perturbed.id for orig, perturbed in zip(cuts, cuts_tp))
    else:
        # Note: not using all() because PerturbTempo has p=0.5, so some cuts
        # keep their original ID untouched.
        assert any(orig.id != perturbed.id for orig, perturbed in zip(cuts, cuts_tp))