def test_cutset_from_webdataset_sharded_pipe():
    """Round-trip a CutSet through gzip-piped WebDataset shards and check the data survives."""
    source = CutSet.from_file("test/fixtures/libri/cuts.json")
    template = source[0]
    # Ten copies of the same cut, distinguished only by a numeric id suffix.
    cuts = CutSet.from_cuts(
        [fastcopy(template, id=f"{template.id}-{i}") for i in range(10)]
    )
    with TemporaryDirectory() as dir_path:
        tar_pattern = f"pipe:gzip -c > {dir_path}/shard-%06d.tar.gz"
        export_to_webdataset(cuts, output_path=tar_pattern, shard_size=2)
        # disabling shard shuffling for testing purposes here
        cuts_ds = CutSet.from_webdataset(
            "pipe:gunzip -c " + dir_path + "/shard-{000000..000004}.tar.gz",
            shuffle_shards=False,
        )
        assert list(cuts.ids) == list(cuts_ds.ids)
        for expected, restored in zip(cuts, cuts_ds):
            np.testing.assert_equal(expected.load_audio(), restored.load_audio())
            np.testing.assert_almost_equal(
                expected.load_features(), restored.load_features(), decimal=2
            )
def test_webdataset_sampler_epoch_increment():
    """Check that auto-incrementing the epoch re-shuffles shards between epochs."""
    cuts = CutSet.from_file("test/fixtures/libri/cuts.json").repeat(10)
    with TemporaryDirectory() as dir_path:
        export_to_webdataset(
            cuts, output_path=f"{dir_path}/shard-%06d.tar", shard_size=1
        )
        shard_paths = [str(p) for p in Path(dir_path).glob("*.tar")]
        cuts_ds = CutSet.from_webdataset(shard_paths, shuffle_shards=True)
        sampler = DynamicCutSampler(cuts_ds, max_cuts=1)
        dloader = DataLoader(
            IterableDatasetWrapper(DummyDataset(), sampler, auto_increment_epoch=True),
            batch_size=None,
            num_workers=1,
            persistent_workers=True,
        )
        epoch_batches = {}
        for epoch in (0, 1):
            seen = []
            for batch in dloader:
                seen.extend(batch)
            epoch_batches[epoch] = CutSet.from_cuts(seen)
        # Both epochs have the same cut IDs.
        assert sorted(epoch_batches[0].ids) == sorted(epoch_batches[1].ids)
        # Both epochs have different cut order (shards were re-shuffled).
        assert list(epoch_batches[0].ids) != list(epoch_batches[1].ids)
def test_cutset_from_webdataset():
    """Round-trip a CutSet through a single WebDataset tar file and check the data survives."""
    source = CutSet.from_file("test/fixtures/libri/cuts.json")
    template = source[0]
    # Ten copies of the same cut, distinguished only by a numeric id suffix.
    cuts = CutSet.from_cuts(
        [fastcopy(template, id=f"{template.id}-{i}") for i in range(10)]
    )
    with NamedTemporaryFile(suffix=".tar") as f:
        export_to_webdataset(cuts, output_path=f.name)
        f.flush()
        cuts_ds = CutSet.from_webdataset(f.name)
        assert list(cuts.ids) == list(cuts_ds.ids)
        for expected, restored in zip(cuts, cuts_ds):
            np.testing.assert_equal(expected.load_audio(), restored.load_audio())
            np.testing.assert_almost_equal(
                expected.load_features(), restored.load_features(), decimal=2
            )
def run_test(
    rank: Optional[int],
    n_shards: int,
    root: str,
    world_size: Optional[int],
    expected_cut_ids: List[str],
    num_workers: int,
) -> None:
    """Worker body for (optionally distributed) WebDataset de-duplication tests.

    Opens a sharded WebDataset CutSet with node/worker splitting enabled,
    iterates it through a DataLoader, and asserts that this process saw
    exactly the cut IDs it was expected to see.

    :param rank: DDP rank of this process; ignored when ``world_size`` is None.
    :param n_shards: total number of ``shard-NNNNNN.tar`` files under ``root``.
    :param root: directory containing the shards.
    :param world_size: number of DDP processes, or None for single-process mode.
    :param expected_cut_ids: full list of cut IDs across all ranks.
    :param num_workers: DataLoader worker count.
    :raises AssertionError: when the observed cut IDs differ from the expected slice.
    """
    # Initialize DDP if needed
    if world_size is not None:
        os.environ["MASTER_ADDR"] = "localhost"
        # NOTE(review): hard-coded port may collide if multiple tests run concurrently.
        os.environ["MASTER_PORT"] = "12354"
        torch.distributed.init_process_group(
            "gloo",
            world_size=world_size,
            rank=rank,
        )
        # Each rank should see only every world_size-th cut, offset by its rank.
        expected_cut_ids = expected_cut_ids[rank::world_size]
    else:
        rank = None

    # Open CutSet with options that de-duplicate the data across nodes and workers
    cuts_wds = CutSet.from_webdataset(
        "%s/shard-{000000..%06d}.tar" % (root, n_shards - 1),
        split_by_node=True,
        split_by_worker=True,
        shuffle_shards=True,
    )

    # Iterate the data and collect every cut ID this process observes.
    tot = 0
    cut_ids = []
    sampler = SimpleCutSampler(cuts_wds, max_duration=100, rank=0, world_size=1)
    dloader = DataLoader(
        IterableDatasetWrapper(dataset=DummyDataset(), sampler=sampler),
        batch_size=None,
        num_workers=num_workers,
        worker_init_fn=make_worker_init_fn(
            rank=rank,
            world_size=world_size,
        ),
    )
    for batch in dloader:
        tot += len(batch)
        for c in batch:
            cut_ids.append(c.id)

    print(f"[Rank {rank}/{world_size}] Actual cuts: ", sorted(cut_ids))
    print(f"[Rank {rank}/{world_size}] Expected cuts: ", sorted(expected_cut_ids))
    try:
        assert tot == len(
            expected_cut_ids), f"{tot} != {len(expected_cut_ids)}"
        assert sorted(cut_ids) == sorted(
            expected_cut_ids
        ), f"{sorted(cut_ids)}\n!=\n{sorted(expected_cut_ids)}"
    except AssertionError:
        # Pytest doesn't work great with subprocesses, so emit the traceback
        # ourselves. Fix: the original wrapped this call in print(), which
        # printed the literal "None" because traceback.print_exc() returns None.
        traceback.print_exc()
        raise