Example #1
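This first test round-trips a small CutSet through gzip-compressed, pipe-written WebDataset shards and verifies that cut IDs, audio, and features survive intact. The snippets in this listing omit their imports; a plausible preamble covering all of them, assuming Lhotse's public module layout (exact import paths may differ between versions), is:

import os
import traceback
from pathlib import Path
from tempfile import NamedTemporaryFile, TemporaryDirectory
from typing import List, Optional

import numpy as np
import torch
from torch.utils.data import DataLoader

from lhotse import CutSet
from lhotse.dataset import DynamicCutSampler, SimpleCutSampler, make_worker_init_fn
from lhotse.dataset.iterable_dataset import IterableDatasetWrapper
from lhotse.dataset.webdataset import export_to_webdataset
from lhotse.utils import fastcopy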
def test_cutset_from_webdataset_sharded_pipe():
    cuts = CutSet.from_file("test/fixtures/libri/cuts.json")
    cut = cuts[0]
    cuts = []
    for i in range(10):
        cuts.append(fastcopy(cut, id=cut.id + "-" + str(i)))
    cuts = CutSet.from_cuts(cuts)

    with TemporaryDirectory() as dir_path:
        tar_pattern = f"pipe:gzip -c > {dir_path}/shard-%06d.tar.gz"
        export_to_webdataset(cuts, output_path=tar_pattern, shard_size=2)

        # Disable shard shuffling so the read order matches the export order.
        cuts_ds = CutSet.from_webdataset(
            "pipe:gunzip -c " + dir_path + "/shard-{000000..000004}.tar.gz",
            shuffle_shards=False,
        )

        assert list(cuts.ids) == list(cuts_ds.ids)

        for c, cds in zip(cuts, cuts_ds):
            np.testing.assert_equal(c.load_audio(), cds.load_audio())
            np.testing.assert_almost_equal(
                c.load_features(), cds.load_features(), decimal=2
            )
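The pipe: prefix follows the WebDataset convention of streaming through a shell subprocess: gzip compresses each shard on export, gunzip decompresses on load, and the brace pattern {000000..000004} enumerates the five shards that shard_size=2 yields from ten cuts.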
Example #2
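This test writes one cut per shard, then checks that a persistent-worker DataLoader with auto_increment_epoch=True re-shuffles the shard order between epochs while still yielding the same set of cuts. It, and the DDP helper further below, reference a DummyDataset defined elsewhere in the original test module; a minimal stand-in, assuming the usual identity pattern where the map-style dataset simply returns the sampled batch, could be:

import torch

from lhotse import CutSet


class DummyDataset(torch.utils.data.Dataset):
    def __getitem__(self, cuts: CutSet) -> CutSet:
        # Identity "collation": IterableDatasetWrapper indexes the dataset
        # with the sampler's CutSet batch; pass it through unchanged.
        return cuts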
def test_webdataset_sampler_epoch_increment():
    cuts = CutSet.from_file("test/fixtures/libri/cuts.json").repeat(10)

    with TemporaryDirectory() as dir_path:
        tar_pattern = f"{dir_path}/shard-%06d.tar"
        export_to_webdataset(cuts, output_path=tar_pattern, shard_size=1)

        cuts_ds = CutSet.from_webdataset(
            [str(p) for p in Path(dir_path).glob("*.tar")], shuffle_shards=True
        )
        sampler = DynamicCutSampler(cuts_ds, max_cuts=1)
        dloader = DataLoader(
            IterableDatasetWrapper(DummyDataset(), sampler, auto_increment_epoch=True),
            batch_size=None,
            num_workers=1,
            persistent_workers=True,
        )

        epoch_batches = {}
        for epoch in [0, 1]:
            batches = []
            for batch in dloader:
                for cut in batch:
                    batches.append(cut)
            epoch_batches[epoch] = CutSet.from_cuts(batches)

        # Both epochs have the same cut IDs.
        assert sorted(epoch_batches[0].ids) == sorted(epoch_batches[1].ids)
        # Both epochs have different cut order (shards were re-shuffled).
        assert list(epoch_batches[0].ids) != list(epoch_batches[1].ids)
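persistent_workers=True matters here: the same worker process survives both epochs, so the shard order can only differ if auto_increment_epoch=True actually advances the sampler epoch (and with it the shuffling seed) inside that long-lived worker.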
Example #3
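The single-archive variant of Example #1: with no shard_size, the entire CutSet is exported into one uncompressed tar file, and the same identity checks are applied after re-reading it.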
def test_cutset_from_webdataset():
    cuts = CutSet.from_file("test/fixtures/libri/cuts.json")
    cut = cuts[0]
    cuts = []
    for i in range(10):
        cuts.append(fastcopy(cut, id=cut.id + "-" + str(i)))
    cuts = CutSet.from_cuts(cuts)

    with NamedTemporaryFile(suffix=".tar") as f:
        export_to_webdataset(cuts, output_path=f.name)
        f.flush()

        cuts_ds = CutSet.from_webdataset(f.name)

        assert list(cuts.ids) == list(cuts_ds.ids)

        for c, cds in zip(cuts, cuts_ds):
            np.testing.assert_equal(c.load_audio(), cds.load_audio())
            np.testing.assert_almost_equal(
                c.load_features(), cds.load_features(), decimal=2
            )
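The helper below backs the distributed (DDP) tests: when world_size is given, it joins a gloo process group, narrows the expected cut IDs down to the current rank, opens the shards with node- and worker-level de-duplication, and asserts that the rank saw exactly its share of the data.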
def run_test(
    rank: Optional[int],
    n_shards: int,
    root: str,
    world_size: Optional[int],
    expected_cut_ids: List[str],
    num_workers: int,
) -> None:

    # Initialize DDP if needed
    if world_size is not None:
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "12354"
        torch.distributed.init_process_group(
            "gloo",
            world_size=world_size,
            rank=rank,
        )
        # Keep only the cut IDs this rank is expected to see
        # (cuts are distributed to ranks round-robin).
        expected_cut_ids = expected_cut_ids[rank::world_size]
    else:
        rank = None

    # Open CutSet with options that de-duplicate the data across nodes and workers
    cuts_wds = CutSet.from_webdataset(
        "%s/shard-{000000..%06d}.tar" % (root, n_shards - 1),
        split_by_node=True,
        split_by_worker=True,
        shuffle_shards=True,
    )

    # Iterate the data
    tot = 0
    cut_ids = []
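    # The sampler is deliberately created with rank=0 and world_size=1:
    # the CutSet already de-duplicates across nodes and workers via
    # split_by_node/split_by_worker, so letting the sampler partition
    # the data again would drop cuts.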
    sampler = SimpleCutSampler(
        cuts_wds, max_duration=100, rank=0, world_size=1
    )
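    # make_worker_init_fn propagates the DDP rank and world size into each
    # DataLoader worker subprocess, so the shard-splitting logic knows
    # which node/worker it is running in.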
    dloader = DataLoader(
        IterableDatasetWrapper(dataset=DummyDataset(), sampler=sampler),
        batch_size=None,
        num_workers=num_workers,
        worker_init_fn=make_worker_init_fn(
            rank=rank,
            world_size=world_size,
        ),
    )
    for batch in dloader:
        tot += len(batch)
        for c in batch:
            cut_ids.append(c.id)

    print(f"[Rank {rank}/{world_size}] Actual   cuts: ", sorted(cut_ids))
    print(f"[Rank {rank}/{world_size}] Expected cuts: ",
          sorted(expected_cut_ids))
    try:
        assert tot == len(expected_cut_ids), f"{tot} != {len(expected_cut_ids)}"
        assert sorted(cut_ids) == sorted(
            expected_cut_ids
        ), f"{sorted(cut_ids)}\n!=\n{sorted(expected_cut_ids)}"
    except AssertionError:
        # Pytest doesn't surface assertion failures from subprocesses well,
        # so print the traceback here before re-raising.
        traceback.print_exc()
        raise
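A hypothetical driver for run_test (not part of the original listing; the shard directory, shard count, and expected IDs are placeholders you would derive from an exported CutSet) could simulate a two-rank DDP job by launching one process per rank:

import multiprocessing
from typing import List


def launch_two_rank_test(
    root: str, n_shards: int, expected_cut_ids: List[str]
) -> None:
    # One process per simulated DDP rank; each calls run_test, which joins
    # the gloo group via the MASTER_ADDR/MASTER_PORT it sets internally.
    world_size = 2
    procs = [
        multiprocessing.Process(
            target=run_test,
            kwargs=dict(
                rank=rank,
                n_shards=n_shards,
                root=root,
                world_size=world_size,
                expected_cut_ids=expected_cut_ids,
                num_workers=2,
            ),
        )
        for rank in range(world_size)
    ]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    assert all(p.exitcode == 0 for p in procs), "a rank failed; see its traceback above"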