Example #1
import numpy as np
from tempfile import NamedTemporaryFile

# `excute_kaldi_commands` (sic) is assumed to be the surrounding project's
# helper for running Kaldi binaries.


def import_occs(occs_file):
    """Reads occupancy counts from a Kaldi occs file."""
    try:
        with open(occs_file) as fid:
            occs = fid.readline().strip()
    except UnicodeDecodeError:
        # The occs file is in Kaldi binary format; convert it to text first.
        with NamedTemporaryFile() as tmpfile:
            excute_kaldi_commands(
                [f'copy-vector --binary=false {occs_file} {tmpfile.name}'],
                'convert occs'
            )
            with open(tmpfile.name) as fid:
                occs = fid.readline().strip()
    # Strip the surrounding brackets and keep the integer part of each count.
    occs = occs.replace('[', '').replace(']', '').split()
    occs = [occ.split('.')[0] for occ in occs]
    return np.array(occs, dtype=np.int32)
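A minimal usage sketch; the path is illustrative (Kaldi training recipes typically write an occupancy-count file such as `final.occs`):

occs = import_occs('s5/exp/tri4b/final.occs')
print(occs.shape, occs.dtype)  # one integer occupancy count per entry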
Example #2
from pathlib import Path

# `kaldi_root`, `kaldi_helper` and `LOG` are assumed to be defined at module
# level in the surrounding project.


def write_transcription_file(
    out_of_vocabulary_mapping_file: Path,
    word_mapping_file: Path,
    word_transcription_file: Path,
    mapped_transcription_file: Path,
):
    """
    Code tested with WSJ database and derived databases.

    Args:
        out_of_vocabulary_mapping_file: Contains an integer to which all OOVs
            are going to be mapped.
            Typically: `db.lang_path / 'oov.int'`
        word_mapping_file: Maps each word to an integer ID.
            Typically: `db.lang_path / 'words.txt'`
            It has this form:
                <eps> 0
                !EXCLAMATION-POINT 1
                !SIL 2
        word_transcription_file: If you want to align your own data, you
            need to write this file first.
            It has this form:
                011c0201 THE SALE OF THE HOTELS ...
        mapped_transcription_file: Output file.
            Typically: `*.tra`
            It has this form:
                011c0201 110920 96431 79225 110920 52031 ...
    Returns:

    """
    sym2int_pl_file = (kaldi_root / "egs" / "wsj" / "s5" / "utils" /
                       "sym2int.pl")

    for file in (sym2int_pl_file, out_of_vocabulary_mapping_file,
                 word_mapping_file, word_transcription_file):
        assert file.is_file(), file
    assert mapped_transcription_file.parent.is_dir(), mapped_transcription_file

    with out_of_vocabulary_mapping_file.open() as f:
        oov = f.read().strip()

    command = (
        f"{sym2int_pl_file.resolve().absolute()} "
        f"--map-oov {oov} "
        f"-f 2- "  # Will map from second item onwards (skipping utt id).
        f"{word_mapping_file.resolve().absolute()} "
        f"{word_transcription_file.resolve().absolute()} "
        f"> {mapped_transcription_file.resolve().absolute()}")

    # Why does this execute in `.../egs/wsj/s5`?
    env = kaldi_helper.get_kaldi_env()
    _, std_err_list, _ = kaldi_helper.excute_kaldi_commands(command, env=env)

    for line in std_err_list[0].split('\n'):
        LOG.info(line)
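A usage sketch, assuming a WSJ-style `db` object with a `lang_path` attribute as in the docstring; all paths are illustrative:

write_transcription_file(
    out_of_vocabulary_mapping_file=db.lang_path / 'oov.int',
    word_mapping_file=db.lang_path / 'words.txt',
    word_transcription_file=Path('exp/word_transcriptions.txt'),
    mapped_transcription_file=Path('exp/train.tra'),
)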
Example #3
import sys
from pathlib import Path


def forced_alignment(
        log_posteriors_ark_file: Path,
        graphs_file: Path,
        model_file: Path,
        alignment_dir: Path,
        beam: int = 200,
        retry_beam: int = 400,
        part: int = 1,
):
    """

    Args:
        log_posteriors_ark_file: E.g. `log_posteriors.ark`
        graphs_file: E.g. `graphs.fsts`
        model_file: E.g. `s5/exp/tri4b/final.mdl`
        alignment_dir:
        beam: Kaldi recipes (e.g. WSJ) typically use 10.
        retry_beam: Kaldi recipes (e.g. WSJ) typically use 40.
        part: Could be used for parallel processing.

    Returns:

    """
    if part != 1:
        raise NotImplementedError(
            "I believe that the `log_posteriors_ark_file` and the "
            "`graphs_file` already need to be chunked to support parallelism."
        )

    command = (
        'align-compiled-mapped '
        f'--beam={beam} '
        f'--retry-beam={retry_beam} '
        f'{model_file} '
        f'ark:{graphs_file} '
        f'ark:{log_posteriors_ark_file} '
        f'ark,t:|gzip -c > {alignment_dir}/ali.{part}.gz'
    )

    # Why does this execute in `.../egs/wsj/s5`?
    env = kaldi_helper.get_kaldi_env()
    _, std_err_list, _ = kaldi_helper.excute_kaldi_commands(
        command,
        name=sys._getframe().f_code.co_name,
        env=env
    )

    for line in std_err_list[0].split('\n'):
        LOG.info(line)
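A sketch of a typical call, using the docstring's example paths (all illustrative) and the beam settings common in Kaldi WSJ recipes:

forced_alignment(
    log_posteriors_ark_file=Path('log_posteriors.ark'),
    graphs_file=Path('graphs.fsts'),
    model_file=Path('s5/exp/tri4b/final.mdl'),
    alignment_dir=Path('exp/ali'),
    beam=10,
    retry_beam=40,
)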
Example #4
import sys
from pathlib import Path


def compile_train_graphs(
        tree_file: Path,
        model_file: Path,
        lexicon_fst_file: Path,
        integer_transcription_file: Path,
        output_graphs_file: Path
):
    """
    Initial step to prepare for forced alignment.

    Args:
        tree_file: E.g. `s5/exp/tri4b/tree`
        model_file: E.g. `s5/exp/tri4b/final.mdl`
        lexicon_fst_file: E.g. `lang_path / 'L.fst'`
        integer_transcription_file: E.g. `train.tra`
        output_graphs_file: E.g. `graphs.fsts`

    Returns:

    """
    command = (
        f"compile-train-graphs "
        f"{tree_file.resolve().absolute()} "
        f"{model_file.resolve().absolute()} "
        f"{lexicon_fst_file.resolve().absolute()} "
        f"ark:{integer_transcription_file.resolve().absolute()} "
        f"ark:{output_graphs_file.resolve().absolute()}"
    )

    # Why does this execute in `.../egs/wsj/s5`?
    env = kaldi_helper.get_kaldi_env()
    _, std_err_list, _ = kaldi_helper.excute_kaldi_commands(
        command,
        name=sys._getframe().f_code.co_name,
        env=env
    )

    for line in std_err_list[0].split('\n'):
        LOG.info(line)
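Since this step prepares the inputs for `forced_alignment` above, a sketch of how the two chain together (paths illustrative):

compile_train_graphs(
    tree_file=Path('s5/exp/tri4b/tree'),
    model_file=Path('s5/exp/tri4b/final.mdl'),
    lexicon_fst_file=Path('lang/L.fst'),
    integer_transcription_file=Path('train.tra'),
    output_graphs_file=Path('graphs.fsts'),
)
# The resulting `graphs.fsts` is then passed to `forced_alignment`
# as its `graphs_file` argument.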
Example #5
import logging
import os
import pickle
from collections import defaultdict

import pandas

# `helper`, `mkdir_p`, `parse_wer_file` and the private `_tra_complete`,
# `_lattices_exists`, `_build_rescale_lattice_cmd` and
# `_build_compute_WER_command` functions are assumed to come from the
# surrounding project.


def compute_scores(decode_dir,
                   hclg_dir,
                   ref_text,
                   min_lmwt=8,
                   max_lmwt=18,
                   force_scoring=False,
                   build_tra=True,
                   strict=True,
                   ignore_return_codes=True):
    LOG = logging.getLogger('compute_scores')

    decode_dir = str(decode_dir)
    hclg_dir = str(hclg_dir)
    ref_text = str(ref_text)

    decode_dir = os.path.abspath(decode_dir)
    mkdir_p(os.path.join(decode_dir, 'scoring'))
    ref_file = f'{decode_dir}/scoring/test_filt.txt'
    cmd = (f"cat {ref_text} | sed 's:<NOISE>::g' | sed 's:<SPOKEN_NOISE>::g' "
           f"> {ref_file}")
    helper.excute_kaldi_commands([cmd],
                                 'copying reference transcription',
                                 log_dir=decode_dir + '/logs',
                                 ignore_return_code=ignore_return_codes)
    cmds = list()
    tra_written = dict()
    for lmwt in range(min_lmwt, max_lmwt + 1):
        tra_file = f'{decode_dir}/scoring/{lmwt}.tra'
        # Rescale only if the tra file does not already exist (and is not
        # complete), lattices are present and tra building is enabled;
        # `force_scoring` forces a rescale regardless.
        rescale = not os.path.exists(tra_file)
        rescale &= not _tra_complete(tra_file, ref_file)
        rescale &= _lattices_exists(ref_file, f'{decode_dir}/lats')
        rescale &= build_tra
        rescale |= force_scoring
        if rescale:
            LOG.info(f'Rescaling lattice for lmwt {lmwt}')
            cmds.append(_build_rescale_lattice_cmd(decode_dir, hclg_dir, lmwt))
            tra_written[lmwt] = True
    if len(cmds):
        helper.excute_kaldi_commands(cmds,
                                     'rescaling lattice',
                                     log_dir=decode_dir + '/logs',
                                     ignore_return_code=ignore_return_codes)
    else:
        LOG.info('All utts already rescaled - skipping')
    cmds = list()
    for lmwt in range(min_lmwt, max_lmwt + 1):
        if lmwt in tra_written:
            LOG.info(f'Computing WER for lmwt {lmwt}')
            cmds.append(
                _build_compute_WER_command(decode_dir,
                                           hclg_dir,
                                           lmwt,
                                           strict=strict))
    if len(cmds):
        helper.excute_kaldi_commands(cmds,
                                     'computing WER',
                                     log_dir=decode_dir + '/logs',
                                     ignore_return_code=ignore_return_codes)
    result = defaultdict(list)
    for lmwt in range(min_lmwt, max_lmwt + 1):
        wer, errors, words, ins, del_, sub = parse_wer_file(
            f'{decode_dir}/wer_{lmwt}')
        result['wer'].append(float(wer))
        result['errors'].append(int(errors))
        result['words'].append(int(words))
        result['ins'].append(int(ins))
        result['del'].append(int(del_))
        result['sub'].append(int(sub))
        result['decode_dir'].append(decode_dir)
        result['lmwt'].append(int(lmwt))
    # The DataFrame is pickled for later inspection; the raw dict of lists
    # is what gets returned.
    res = pandas.DataFrame(result)
    with open(decode_dir + '/result.pkl', 'wb') as fid:
        pickle.dump(res, fid)
    return result.copy()
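A usage sketch (directories illustrative). The returned dict maps each metric to one value per tested language-model weight, so the best weight can be read off directly:

scores = compute_scores(
    decode_dir='exp/tri4b/decode_eval92',
    hclg_dir='exp/tri4b/graph_bd_tgpr',
    ref_text='data/eval92/text',
)
best_wer, best_lmwt = min(zip(scores['wer'], scores['lmwt']))
print(f'best WER {best_wer} at lmwt {best_lmwt}')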