Code example #1
def get_valid_datasets(task: str) -> Set[str]:
    """Returns the set of valid datasets that can be loaded."""
    assert task.lower() in ["asr", "tts"], task.lower()
    root = get_root()
    non_datasets = {"TEMPLATE", "COMBINE"}
    return {d.lower() for d in os.listdir(root) if d not in non_datasets
            and os.path.isdir(os.path.join(root, d, f"{task}1"))}
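A brief usage sketch (assuming get_root() resolves to a prepared checkout of the repository; the printed result is only illustrative):

# Hypothetical call: list every dataset with a prepared asr1/ directory.
valid = get_valid_datasets("asr")
print(sorted(valid))  # e.g. ['librispeech', 'wsj'] -- names are lower-cased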
Code example #2
def get_combo_scp(datasets: List[str], task: str, feats_type: str) -> \
        Tuple[Optional[str], Optional[str]]:
    """Checks the dataset combination directory to see if this particular
    combination of datasets has been prepared already. Returns the SCP file
    and data format (hdf5 or mat) if it has, (None, None) otherwise."""
    combo_idx = get_combo_idx(datasets, task)
    dirname = os.path.join(get_root(), "COMBINE", f"{task}1", "dump",
                           feats_type, str(combo_idx))
    if combo_idx < 0 or not os.path.isdir(dirname):
        return None, None

    try:
        with open(os.path.join(dirname, "archive_format")) as f:
            fmt = f.readline().rstrip()
            assert fmt in ["hdf5", "mat"], \
                f"Combined dataset must be dumped to 'hdf5' or 'mat' " \
                f"archives, but got {fmt} instead."

        with open(os.path.join(dirname, "data_format")) as f:
            scp = f"{dirname}/{f.readline().rstrip()}.scp"
            assert os.path.isfile(scp), f"{scp} is not a file."

        return scp, fmt

    except Exception as e:
        logger.warning(
            f"{feats_type} data has not been properly dumped/prepared for "
            f"combined dataset. Failed loading with exception:\n{str(e)}")
        logger.warning("Loading combined dataset from separate SCP's instead")
        return None, None
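A hedged sketch of calling get_combo_scp; the dataset names are the illustrative ones used elsewhere in this code, and feats_type="fbank" is only an assumed value:

# Returns (None, None) unless this exact combination has been dumped already.
scp, fmt = get_combo_scp(["wsj/train_si284", "librispeech/dev-other"],
                         task="asr", feats_type="fbank")
if scp is not None:
    print(f"combined SCP: {scp} ({fmt} archives)")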
Code example #3
File: cmvn.py Project: salesforce/speech-datasets
    def __init__(self, cmvn_type: str, stats: str = None, norm_means=True,
                 norm_vars=False, utt2spk: str = None, reverse=False,
                 std_floor=1.0e-20):
        self.norm_means = norm_means
        self.norm_vars = norm_vars
        self.reverse = reverse
        self.std_floor = std_floor

        assert cmvn_type in ["global", "speaker", "utterance"], cmvn_type
        self.accept_uttid = (cmvn_type != "global")
        self.cmvn_type = cmvn_type
        if cmvn_type != "utterance":
            assert stats is not None, "stats required if cmvn_type != 'utterance'"
            try:
                self.stats_file = stats
                stats_dict = read_cmvn_stats(self.stats_file, cmvn_type)
            except FileNotFoundError:
                self.stats_file = os.path.join(get_root(), stats)
                stats_dict = read_cmvn_stats(self.stats_file, cmvn_type)
        else:
            if stats is not None:
                logger.warning("stats file is not used if cmvn_type is 'utterance'")
            self.stats_file = None
            stats_dict = {}

        if cmvn_type == "speaker":
            assert utt2spk is not None, "utt2spk required if cmvn_type is 'speaker'"
            self.utt2spk = {}
            with io.open(utt2spk, "r", encoding="utf-8") as f:
                for line in f:
                    utt, spk = line.rstrip().split(None, maxsplit=1)
                    self.utt2spk[utt] = spk
        else:
            if utt2spk is not None:
                logger.warning("utt2spk is only used if cmvn_type is 'speaker'")
            self.utt2spk = None

        # Kaldi stores CMVN statistics as a matrix of shape (2, feat_dim + 1):
        # the first row is the sum of the features and the second is the sum of
        # their squares. The last entry of the first row, stats[0, -1], is the
        # number of samples used to accumulate these statistics.
        self.bias = {}
        self.scale = {}
        for spk, stats in stats_dict.items():
            # Var[x] = E[x^2] - E[x]^2
            mean = stats.sum / stats.count
            var = stats.sum_squares / stats.count - mean * mean
            std = np.maximum(np.sqrt(var), std_floor)
            self.bias[spk] = -mean
            self.scale[spk] = 1 / std
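A construction sketch, assuming the enclosing class in cmvn.py is named CMVN and that global statistics have been dumped to the (hypothetical) path below; neither is confirmed by the snippet:

# Global mean normalization; variance normalization left off.
cmvn = CMVN(cmvn_type="global", stats="cmvn_stats.ark",
            norm_means=True, norm_vars=False)
# __init__ only precomputes, for each key in the stats file,
#   bias[key] = -mean   and   scale[key] = 1 / maximum(sqrt(var), std_floor);
# applying them (presumably in __call__, not shown here) gives (x - mean) / std.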
Code example #4
def validate_datasets(datasets: List[str], task: str, feats_type: str) \
        -> Dict[str, List[str]]:
    """Makes sure that the dataset names given (in form '<dataset>/<sub_dataset>')
    are valid, and that they have been prepared appropriately. Returns a
    Dict[str, Set[str]] dataset2subs, where dataset2subs[dataset] contains all
    sub-speech_datasets of dataset that were given."""
    dataset2subs = defaultdict(set)
    for d in datasets:
        try:
            dataset, sub = d.split("/", maxsplit=1)
        except ValueError:
            raise ValueError(
                f"Datasets must be specified as <dataset>/<sub_dataset> "
                f"(e.g. wsj/train_si284, librispeech/dev-other, etc.), "
                f"but got {d} instead")
        dataset2subs[dataset.lower()].add(sub)

    # Make sure the task, feature type, and datasets are all valid
    valid = get_valid_datasets(task)
    invalid = set(dataset2subs.keys()).difference(valid)
    assert len(invalid) == 0, \
        f"Invalid datasets: {invalid}. Valid {task} datasets are: {valid}."

    for dataset, subs in dataset2subs.items():
        root = os.path.join(get_root(), dataset, f"{task}1", "dump", feats_type)
        if not os.path.exists(root):
            valid_subs = set()
        else:
            valid_subs = {d for d in os.listdir(root)
                          if os.path.isdir(os.path.join(root, d))}
            valid_subs = valid_subs.difference({"orig", "no_short"})

        invalid_subs = {f"{dataset}/{s}" for s in subs.difference(valid_subs)}
        valid_subs = {f"{dataset}/{s}" for s in valid_subs}
        if len(valid_subs) == 0:
            raise FileNotFoundError(
                f"{feats_type} data has not been dumped/prepared for dataset "
                f"{dataset}.")
        if len(invalid_subs) > 0:
            raise FileNotFoundError(
                f"The following are invalid splits of dataset {dataset}: "
                f"{invalid_subs}\nValid splits are: {valid_subs}")

    # Convert dataset2subs to a SORTED OrderedDict[str, List[str]]
    # This is to ensure reproducibility across different processes
    dataset2subs = OrderedDict((k, sorted(dataset2subs[k]))
                               for k in sorted(dataset2subs.keys()))
    return dataset2subs
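A short sketch of what validate_datasets returns for a valid request; the splits are the illustrative ones from the error message above, and feats_type="fbank" is assumed:

# Hypothetical call; raises if either split has not been dumped as fbank.
dataset2subs = validate_datasets(["wsj/train_si284", "librispeech/dev-other"],
                                 task="asr", feats_type="fbank")
# -> OrderedDict([('librispeech', ['dev-other']), ('wsj', ['train_si284'])])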
Code example #5
def get_combo_idx(datasets: List[str], task: str) -> int:
    """Check if a particular combination of `datasets` (each formatted as
    f"{dataset}/{split}") has been registered yet for task `task`."""
    # Check if data combination registry has been created yet
    combine_dir = os.path.join(get_root(), "COMBINE", f"{task}1")
    registry = os.path.join(combine_dir, "data", "registry.txt")
    if not os.path.isfile(registry):
        return -1

    # Check if this particular combo is in the data combo registry
    datasets = sorted(set(datasets))
    with open(registry) as f:
        registry = [sorted(set(line.rstrip().split())) for line in f]
    combo_idxs = [i + 1 for i, d in enumerate(registry) if d == datasets]

    return -1 if len(combo_idxs) == 0 else combo_idxs[0]
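A sketch of the registry lookup; the index is 1-based (line i of registry.txt is combo i), and the dataset names are the same illustrative ones used above:

idx = get_combo_idx(["wsj/train_si284", "librispeech/dev-other"], task="asr")
if idx < 0:
    print("combination not registered yet")  # no registry.txt, or no matching line
else:
    print(f"already registered as combo #{idx}")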
Code example #6
def get_dataset_scps(datasets: List[str], task: str, feats_type: str,
                     check_combine=True) -> \
        Tuple[List[str], type]:
    """Given a list of <dataset>/<split> specifiers, obtain the SCP file(s) that
    index all the datasets (and splits) that we care about."""
    scp_files = []
    fmts = []
    dsets = []
    fmt2reader = {"hdf5": HDF5Reader, "mat": KaldiReader}

    if check_combine:
        scp, fmt = get_combo_scp(datasets, task, feats_type)
        if scp is not None:
            return [scp], fmt2reader[fmt]

    dataset2subs = validate_datasets(datasets, task, feats_type)
    for dataset, subs in dataset2subs.items():
        root = os.path.join(get_root(), dataset, f"{task}1", "dump", feats_type)
        for sub in subs:
            name = f"{dataset}/{sub}"
            dsets.append(name)
            dirname = os.path.join(root, sub)
            try:
                with open(os.path.join(dirname, "archive_format")) as f:
                    fmt = f.readline().rstrip()
                assert fmt in fmt2reader.keys(), \
                    f"Dataset {name} must be dumped to one of " \
                    f"{set(fmt2reader.keys())} archives, but got {fmt} instead."
                fmts.append(fmt)

                with open(os.path.join(dirname, "data_format")) as f:
                    scp = f"{f.readline().rstrip()}.scp"
                scp_files.append(os.path.join(dirname, scp))

            except Exception:
                logger.error(f"{feats_type} data has not been properly dumped/"
                             f"prepared for dataset {name}.")
                raise

    if not all(f == fmts[0] for f in fmts):
        err = "\n".join(f"{dset}: {fmt}" for dset, fmt in zip(dsets, fmts))
        raise RuntimeError(
            "Expected all datasets to be dumped to the same archive "
            "format, but got:\n" + err)

    return scp_files, fmt2reader[fmts[0]]
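A usage sketch, again with illustrative dataset names and an assumed feats_type; HDF5Reader and KaldiReader are the reader classes referenced by the code above:

scp_files, reader_cls = get_dataset_scps(
    ["wsj/train_si284", "librispeech/dev-other"], task="asr",
    feats_type="fbank", check_combine=True)
# reader_cls is HDF5Reader or KaldiReader, matching how the data was dumped.
for scp in scp_files:
    print(scp)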
Code example #7
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, choices=["asr", "tts"])
    parser.add_argument("--write_dir", type=str2bool, default=True)
    parser.add_argument("datasets", nargs="+", type=str)
    args = parser.parse_args()

    # Ensure that all datasets are specified as <dataset>/<split>
    datasets = sorted(set(args.datasets))
    dataset_splits = [d.split("/", maxsplit=1) for d in datasets]
    assert all(len(d) == 2 for d in dataset_splits), \
        f"All datasets must be specified as <dataset>/<split>, but got " \
        f"{datasets} instead"

    # Verify that all datasets have been prepared
    dataset_dirs = [os.path.join(get_root(), ds[0], f"{args.task}1", "data", ds[1])
                    for ds in dataset_splits]
    assert all(os.path.isdir(d) for d in dataset_dirs), \
        f"Please make sure that all dataset splits are valid, and that all " \
        f"datasets you wish to combine have already been prepared by stage 1 " \
        f"of {args.task}.sh"

    # Get the index of this dataset combination (add to the registry if needed)
    idx = get_combo_idx(datasets, args.task)
    data_dir = os.path.join(get_root(), "COMBINE", f"{args.task}1", "data")
    if idx < 0:
        os.makedirs(data_dir, exist_ok=True)
        with open(os.path.join(data_dir, "registry.txt"), "a") as f:
            f.write(" ".join(datasets) + "\n")
        idx = get_combo_idx(datasets, args.task)

    if not args.write_dir:
        return idx

    # Create a directory for this dataset combo & prepare it
    dirname = os.path.join(data_dir, str(idx))
    os.makedirs(dirname, exist_ok=True)
    write_segments = any(os.path.isfile(os.path.join(d, "segments"))
                         for d in dataset_dirs)
    with open(os.path.join(dirname, "wav.scp"), "wb") as wav, \
            open(os.path.join(dirname, "text"), "wb") as text, \
            open(os.path.join(dirname, "utt2spk"), "wb") as utt2spk, \
            open(os.path.join(dirname, "segments"), "w") as segments:
        for d in dataset_dirs:

            # wav.scp, text, and utt2spk can just be concatenated on
            with open(os.path.join(d, "wav.scp"), "rb") as src_wav:
                shutil.copyfileobj(src_wav, wav)
            with open(os.path.join(d, "text"), "rb") as src_text:
                shutil.copyfileobj(src_text, text)
            with open(os.path.join(d, "utt2spk"), "rb") as src_utt2spk:
                shutil.copyfileobj(src_utt2spk, utt2spk)

            if write_segments:
                # If a segments file exists, we can just concatenate it on
                if os.path.isfile(os.path.join(d, "segments")):
                    with open(os.path.join(d, "segments"), "r") as src_segments:
                        shutil.copyfileobj(src_segments, segments)

                # Otherwise, we need to use wav.scp to create a dummy segments
                # line format is <segment_id> <record_id> <start_time> <end_time>
                # <start_time> = 0, <end_time> = -1 means use the whole recording
                else:
                    with open(os.path.join(d, "wav.scp"), "r") as src_wav:
                        for line in src_wav:
                            utt_id, _ = line.rstrip().split(None, maxsplit=1)
                            segments.write(f"{utt_id} {utt_id} 0.0 -1.0\n")

    return idx
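A hedged invocation sketch; the script name in sys.argv[0] is a placeholder, and the dataset splits are the illustrative ones used above:

import sys

# Simulate the command line: combine two prepared splits for ASR.
sys.argv = ["combine.py", "--task", "asr",
            "wsj/train_si284", "librispeech/dev-other"]
combo_idx = main()  # registers the combo if new and writes the combined data dir
print(f"combined data prepared under COMBINE/asr1/data/{combo_idx}")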