示例#1
0
def to_lattice(src_dir, fst, lattice, cost_type="acoustic"):
    """ Convert a binary fst into a single-utterance Kaldi text lattice

    The fst is dumped with ``fstprint`` into a temporary text file inside
    ``src_dir``. Each line of that dump is rewritten into the lattice file
    ``src_dir/lattice`` below the header line ``UTT_ID``. The temporary dump
    is removed afterwards, also when the conversion fails.

    :param src_dir: directory for the temporary dump and the output lattice
    :param fst: input fst in binary format
    :param lattice: file name of the output lattice, relative to src_dir
    :param cost_type: which half of the lattice weight pair receives the fst
        weight: 'acoustic' writes ``0,<w>``, 'lm' writes ``<w>,0``. Any other
        value prints a warning and is treated like 'lm'.
    """
    if cost_type not in ("acoustic", "lm"):
        print("unknown cost_type {!r}: treating weights as lm costs".format(
            cost_type))

    def _lattice_weight(weight):
        # Split the single fst weight into the two-part lattice cost.
        if cost_type == "acoustic":
            return "0,{}".format(weight)
        return "{},0".format(weight)

    fst_tmp_as_txt_path = path.join(src_dir, "fst_as_text_tmp.txt")
    lattice_path = path.join(src_dir, lattice)
    try:
        excute_kaldi_commands("fstprint " + fst + "> " + fst_tmp_as_txt_path)
        with open(fst_tmp_as_txt_path, 'r') as fst_txt, \
                open(lattice_path, 'w') as lattice_fid:
            lattice_fid.write("UTT_ID \n")
            for line in fst_txt:
                columns = line.split()
                if not columns:
                    # Skip blank lines; indexing columns[0] would raise.
                    continue
                if len(columns) == 5:
                    # Arc with weight: src dst ilabel olabel weight.
                    columns[4] = _lattice_weight(columns[4])
                    lattice_fid.write("\t".join(columns) + "\n")
                elif len(columns) == 4:
                    # Arc without weight.
                    lattice_fid.write("\t".join(columns) + "\n")
                elif len(columns) == 2:
                    # fstprint writes final states as "state [weight]"; keep
                    # the final weight instead of silently dropping it.
                    lattice_fid.write("{0}\t{1}\n".format(
                        columns[0], _lattice_weight(columns[1])))
                else:
                    # Final state without weight (or an unexpected line):
                    # keep only the state id.
                    lattice_fid.write("{0}\n".format(columns[0]))
            lattice_fid.write("\n")
    finally:
        if path.exists(fst_tmp_as_txt_path):
            remove(fst_tmp_as_txt_path)
示例#2
0
def arcsort(fst, sort_type='ilabel'):
    """ Sort the arcs of a binary fst in place

    The sorted fst is first written to ``fst.tmp`` in the directory of
    ``fst`` and then renamed over ``fst``.

    :param fst: path of the binary fst; it is overwritten with the result
    :param sort_type: sort type: ilabel or olabel
    """
    # NOTE(review): the temporary name is fixed, so two concurrent calls for
    # fsts in the same directory would share (and clobber) it.
    sorted_fst = path.join(path.dirname(fst), 'fst.tmp')
    excute_kaldi_commands(fstarcsort_cmd(fst, sorted_fst, sort_type))
    rename(sorted_fst, fst)
示例#3
0
def shortestpath(fst, output_file, nshortest=1, **kwargs):
    """ Extract the n shortest paths through a binary fst

    :param fst: input fst in binary format
    :param output_file: output fst in binary format
    :param nshortest: number of shortest paths to keep
    :param kwargs: see fstpostprocess_cmd
    """
    excute_kaldi_commands(
        fstshortestpath_cmd(fst, output_file, nshortest, **kwargs))
示例#4
0
def draw_search_graph(search_graph, isym_table, osym_table):
    """ Render a Kaldi lattice archive as a pdf

    The archive is converted to a text fst with ``lattice-to-fst``; its first
    output line (presumably the utterance key — confirm) is dropped with
    ``tail -n +2``. The text fst is then compiled into a temporary binary fst
    without determinization or minimization and drawn to
    ``<search_graph>.pdf``.

    :param search_graph: path of the lattice archive ('.pdf' is appended to it)
    :param isym_table: input symbol table file or None
    :param osym_table: output symbol table file or None
    """
    to_text_template = 'lattice-to-fst ark:{} ark,t:- | tail -n +2 > {}'
    with tempfile.NamedTemporaryFile() as text_fst:
        excute_kaldi_commands(to_text_template.format(search_graph,
                                                      text_fst.name))
        with tempfile.NamedTemporaryFile() as binary_fst:
            build_from_txt(text_fst.name, binary_fst.name,
                           determinize=False, minimize=False)
            draw(isym_table, osym_table, binary_fst.name,
                 search_graph + '.pdf')
示例#5
0
def compose(fst1, fst2, output_file, phi=None, **kwargs):
    """ Compose two fsts

    :param fst1: first fst in binary format (first argument to fstcompose_cmd)
    :param fst2: second fst in binary format (second argument to
                 fstcompose_cmd)
    :param output_file: output fst in binary format
    :param phi: phi symbol to be used for phi composition
    :param kwargs: see fstpostprocess_cmd
    """
    # NOTE(review): an earlier docstring called fst1 the "right" and fst2 the
    # "left" fst. In OpenFst `fstcompose a b`, `a` is the left operand —
    # confirm the argument order used by fstcompose_cmd.
    excute_kaldi_commands(
        fstcompose_cmd(fst1, fst2, output_file, phi, **kwargs))
示例#6
0
def build_from_txt(transducer_as_txt,
                   output_file,
                   isym_table=None,
                   osym_table=None,
                   determinize=True,
                   minimize=True,
                   addselfloops=False,
                   disambig_in=0,
                   disambig_out=0,
                   rmepsilon=False,
                   sort_type="ilabel",
                   input_as_txt=None):
    """ Compile a text fst (file or text input) into a binary fst

    :param transducer_as_txt: input fst in text format
    :param output_file: output fst in binary format
    :param isym_table: input symbol table file
    :param osym_table: output symbol table file
    :param determinize: determinize fst
    :param minimize: minimize fst
    :param addselfloops: append an fstaddselfloops stage to the command
    :param disambig_in: forwarded to fstaddselfloops_cmd (described as a list
                        of input symbols; the default is 0)
    :param disambig_out: forwarded to fstaddselfloops_cmd (described as the
                         list of corresponding output symbols; default 0)
    :param rmepsilon: rmepsilons
    :param sort_type: sort type - ilabel or olabel
    :param input_as_txt: optional input in text format, passed as ``inputs``
                         to the command runner
                         (only used if transducer_as_txt is None)
    """
    if not addselfloops:
        cmd = fstcompile_cmd(transducer_as_txt,
                             output_file,
                             isym_table,
                             osym_table,
                             determinize=determinize,
                             minimize=minimize,
                             rmepsilon=rmepsilon,
                             sort_type=sort_type)
    else:
        # The compile stage gets no output file and arcsort=False; the output
        # file, rmepsilon and sort_type are passed to the self-loop stage that
        # is appended to the same command.
        compile_stage = fstcompile_cmd(transducer_as_txt,
                                       isym_table=isym_table,
                                       osym_table=osym_table,
                                       determinize=determinize,
                                       minimize=minimize,
                                       arcsort=False)
        selfloop_stage = fstaddselfloops_cmd(out_fst=output_file,
                                             disambig_in=disambig_in,
                                             disambig_out=disambig_out,
                                             rmepsilon=rmepsilon,
                                             sort_type=sort_type)
        cmd = compile_stage + selfloop_stage

    excute_kaldi_commands(cmd, inputs=input_as_txt)
示例#7
0
def randgen(fst, output_file, select='log_prob', npath=1, **kwargs):
    """ Sample random paths through a binary fst

    :param fst: input fst in binary format
    :param output_file: output fst in binary format
    :param select: arc selector used while sampling:
                   'log_prob' (weights are treated as negative log probs),
                   'uniform'  (every arc equally likely)
    :param npath: how many paths to sample
    :param kwargs: see fstpostprocess_cmd
    """
    sampling_cmd = fstrandgen_cmd(fst, output_file, select, npath, **kwargs)
    excute_kaldi_commands(sampling_cmd)
示例#8
0
def compile_train_graphs(tree_file: Path, model_file: Path,
                         lexicon_fst_file: Path,
                         integer_transcription_file: Path,
                         output_graphs_file: Path):
    """
    Initial step to prepare for forced alignment.

    Runs Kaldi's `compile-train-graphs` on absolute versions of all paths and
    forwards its stderr to `LOG.info`, one line at a time.

    Args:
        tree_file: E.g. `s5/exp/tri4b/tree`
        model_file: E.g. `s5/exp/tri4b/final.mdl`
        lexicon_fst_file: E.g. `lang_path / 'L.fst'`
        integer_transcription_file: E.g. `train.tra` (passed as `ark:`)
        output_graphs_file: E.g. `graphs.fsts` (passed as `ark:`)

    Returns:
        None
    """
    def _absolute(file):
        return file.resolve().absolute()

    command = ' '.join([
        'compile-train-graphs',
        f'{_absolute(tree_file)}',
        f'{_absolute(model_file)}',
        f'{_absolute(lexicon_fst_file)}',
        f'ark:{_absolute(integer_transcription_file)}',
        f'ark:{_absolute(output_graphs_file)}',
    ])

    # Why does this execute in `.../egs/wsj/s5`?
    env = kaldi_helper.get_kaldi_env()
    _, stderr_per_command, _ = kaldi_helper.excute_kaldi_commands(
        command, name='compile_train_graphs', env=env)

    for message in stderr_per_command[0].split('\n'):
        LOG.info(message)
示例#9
0
def import_occs(occs_file):
    """ Read the occupation counts stored in a Kaldi occs file

    Text files are parsed directly. If the file cannot be decoded as text,
    or starts with Kaldi's binary marker ``\\0B``, it is first converted to
    text with ``copy-vector --binary=false``.

    :param occs_file: path of the occs file (Kaldi vector, text or binary)
    :return: np.int32 array with one entry per count; fractional parts are
             truncated toward zero
    """
    def _first_line(file_name):
        with open(file_name) as fid:
            return fid.readline().strip()

    try:
        occs = _first_line(occs_file)
        # A binary file may happen to decode as text; Kaldi binary files
        # start with "\0B".
        is_binary = occs.startswith('\x00B')
    except UnicodeDecodeError:
        is_binary = True

    if is_binary:
        with NamedTemporaryFile() as tmpfile:
            excute_kaldi_commands(
                [f'copy-vector --binary=false {occs_file} {tmpfile.name}'],
                'convert occs')
            occs = _first_line(tmpfile.name)

    values = occs.replace('[', '').replace(']', '').split()
    # Large counts may be printed in scientific notation (e.g. 1.5e+03);
    # cutting the string at '.' would turn that into 1, so parse as float
    # first and then truncate.
    return np.array([float(value) for value in values],
                    dtype=np.float64).astype(np.int32)
示例#10
0
def check_fst_valid(fst, fast=True):
    """ Check whether an fst file looks usable

    :param fst: path of the fst file
    :param fast: if True, only check that the file exists and is non-empty;
                 otherwise run ``fstinfo`` on it and inspect its return code
    :return: True if the check passes, False otherwise
    """
    if not fast:
        _, _, codes = excute_kaldi_commands('fstinfo {}'.format(fst),
                                            ignore_return_code=True)
        return codes[0] == 0
    return path.exists(fst) and (stat(fst).st_size > 0)
示例#11
0
def write_transcription_file(
    out_of_vocabulary_mapping_file: Path,
    word_mapping_file: Path,
    word_transcription_file: Path,
    mapped_transcription_file: Path,
):
    """
    Map a word transcription to integer word ids with Kaldi's `sym2int.pl`.

    Code tested with WSJ database and derived databases.

    Args:
        out_of_vocabulary_mapping_file: Contains an integer to which all OOVs
            are going to be mapped.
            Typically: `db.lang_path / 'oov.int'`
        word_mapping_file: Word symbol table.
            Typically: `db.lang_path / 'words.txt'`
            It has this form:
                <eps> 0
                !EXCLAMATION-POINT 1
                !SIL 2
        word_transcription_file: If you want to align your own data, you
            need to write this file first. The first item of each line is
            the utterance id.
            It has this form:
                011c0201 THE SALE OF THE HOTELS ...
        mapped_transcription_file: Output file.
            Typically: `*.tra`
            It has this form:
                011c0201 110920 96431 79225 110920 52031 ...

    Raises:
        FileNotFoundError: If `sym2int.pl` or one of the input files does not
            exist, or if the directory of `mapped_transcription_file` does not
            exist.
    """
    sym2int_pl_file = (kaldi_root / "egs" / "wsj" / "s5" / "utils" /
                       "sym2int.pl")

    # Explicit checks instead of `assert`, which is skipped under `python -O`.
    for file in (sym2int_pl_file, out_of_vocabulary_mapping_file,
                 word_mapping_file, word_transcription_file):
        if not file.is_file():
            raise FileNotFoundError(f'No such file: {file}')
    if not mapped_transcription_file.parent.is_dir():
        raise FileNotFoundError(
            f'No such directory: {mapped_transcription_file.parent}')

    with out_of_vocabulary_mapping_file.open() as f:
        oov = f.read().strip()

    command = (
        f"{sym2int_pl_file.resolve().absolute()} "
        f"--map-oov {oov} "
        f"-f 2- "  # Will map from second item onwards (skipping utt id).
        f"{word_mapping_file.resolve().absolute()} "
        f"{word_transcription_file.resolve().absolute()} "
        f"> {mapped_transcription_file.resolve().absolute()}")

    # Why does this execute in `.../egs/wsj/s5`?
    env = kaldi_helper.get_kaldi_env()
    _, std_err_list, _ = kaldi_helper.excute_kaldi_commands(command, env=env)

    for line in std_err_list[0].split('\n'):
        LOG.info(line)
示例#12
0
def draw(isym_table, osym_table, fst_file, output_file):
    """ draw fst as pdf

    Besides the pdf, the dot source produced by ``fstdraw`` is written next
    to it, with the extension of ``output_file`` replaced by ``.dot``.

    :param isym_table: input symbol table file (None: no input symbols)
    :param osym_table: output symbol table file (None: no output symbols)
    :param fst_file: input fst in binary format
    :param output_file: output pdf file (str or os.PathLike)
    """
    import os

    cmd = 'fstdraw --portrait=true --height=17 --width=22'
    if isym_table is not None:
        cmd += ' --isymbols={}'.format(isym_table)
    if osym_table is not None:
        cmd += ' --osymbols={}'.format(osym_table)
    cmd += ' {}'.format(fst_file)

    # Swap only the extension: str.replace('pdf', 'dot') would also rewrite
    # any "pdf" in directory or file names. If output_file contains no "pdf",
    # replace() would make the dot file collide with the output.
    dot_file = os.path.splitext(output_file)[0] + '.dot'
    excute_kaldi_commands(cmd + ' > {}'.format(dot_file))

    excute_kaldi_commands(cmd + ' | dot -Tpdf > {}'.format(output_file))
示例#13
0
def remove_disambig_symbols(fst1,
                            fst2,
                            special_symbol_ids,
                            sort_type=None,
                            minimize=True):
    """ Remove the given symbol ids from an fst

    Every id in ``special_symbol_ids`` that is not None is written to a
    temporary list file, which is then used in the pipeline
    ``cat fst1 | fstrmsymbols <list> | fstrmepslocal
    [| fstminimizeencoded] [| fstarcsort --sort_type=<sort_type>] > fst2``.

    :param fst1: input fst
    :param fst2: output fst
    :param special_symbol_ids: iterable of symbol ids; None entries are skipped
    :param sort_type: arc sort type, or None to skip arc sorting
    :param minimize: append ``fstminimizeencoded`` to the pipeline
    """
    with tempfile.NamedTemporaryFile() as symbol_list:
        with open(symbol_list.name, 'w') as list_fid:
            list_fid.writelines('{}\n'.format(symbol_id)
                                for symbol_id in special_symbol_ids
                                if symbol_id is not None)
        minimize_stage = ' | fstminimizeencoded' if minimize else ''
        sort_stage = (' | fstarcsort --sort_type={} '.format(sort_type)
                      if sort_type is not None else '')
        pipeline = 'cat {} | fstrmsymbols {} | fstrmepslocal{}{} > {}'.format(
            fst1, symbol_list.name, minimize_stage, sort_stage, fst2)
        excute_kaldi_commands(pipeline)
示例#14
0
def forced_alignment(log_posteriors_ark_file: Path,
                     graphs_file: Path,
                     model_file: Path,
                     alignment_dir: Path,
                     beam: int = 200,
                     retry_beam: int = 400,
                     part=1):
    """
    Align log posteriors against precompiled training graphs.

    Runs Kaldi's `align-compiled-mapped` and pipes the text alignments
    through `gzip` into `<alignment_dir>/ali.<part>.gz`. The tool's stderr
    is forwarded to `LOG.info`, one line at a time.

    Args:
        log_posteriors_ark_file: E.g. `log_posteriors.ark`
        graphs_file: E.g. `graphs.fsts`
        model_file: E.g. `s5/exp/tri4b/final.mdl`
        alignment_dir: Directory that receives `ali.<part>.gz`.
        beam: Kaldi recipes (e.g. WSJ) typically use 10.
        retry_beam: Kaldi recipes (e.g. WSJ) typically use 40.
        part: Could be used for parallel processing; only 1 is supported.

    Returns:
        None

    Raises:
        NotImplementedError: If `part` is not 1.
    """
    if part != 1:
        raise NotImplementedError(
            "I believe that the `log_posteriors_ark_file` and the "
            "`graphs_file` already needs to be chunked to support parallelism."
        )

    command = ' '.join([
        'align-compiled-mapped',
        f'--beam={beam}',
        f'--retry-beam={retry_beam}',
        f'{model_file}',
        f'ark:{graphs_file}',
        f'ark:{log_posteriors_ark_file}',
        f'ark,t:|gzip -c > {alignment_dir}/ali.{part}.gz',
    ])

    # Why does this execute in `.../egs/wsj/s5`?
    env = kaldi_helper.get_kaldi_env()
    _, stderr_per_command, _ = kaldi_helper.excute_kaldi_commands(
        command, name='forced_alignment', env=env)

    for message in stderr_per_command[0].split('\n'):
        LOG.info(message)
示例#15
0
def compute_scores(decode_dir,
                   hclg_dir,
                   ref_text,
                   min_lmwt=8,
                   max_lmwt=18,
                   force_scoring=False,
                   build_tra=True,
                   strict=True,
                   ignore_return_codes=True):
    """ Score a decoding directory for a range of LM weights

    Steps:
      1. Copy ``ref_text`` to ``<decode_dir>/scoring/test_filt.txt`` with the
         ``<NOISE>`` and ``<SPOKEN_NOISE>`` tokens removed.
      2. Rescale the lattices for every LM weight in
         ``[min_lmwt, max_lmwt]`` that needs it (see condition below).
      3. Compute the WER for exactly the LM weights rescaled in step 2.
      4. Parse ``<decode_dir>/wer_<lmwt>`` for every LM weight in the range,
         pickle the table as a pandas DataFrame to
         ``<decode_dir>/result.pkl`` and return it as a dict of lists.

    :param decode_dir: decoding directory (converted to an absolute path)
    :param hclg_dir: graph directory passed to the command builders
    :param ref_text: reference transcription file
    :param min_lmwt: smallest LM weight (inclusive)
    :param max_lmwt: largest LM weight (inclusive)
    :param force_scoring: rescale every LM weight in the range
    :param build_tra: if False, only force_scoring triggers rescaling
    :param strict: forwarded to _build_compute_WER_command
    :param ignore_return_codes: forwarded as ``ignore_return_code`` to every
                                command invocation
    :return: dict mapping 'wer', 'errors', 'words', 'ins', 'del', 'sub',
             'decode_dir' and 'lmwt' to one entry per LM weight
    """
    logger = logging.getLogger('computer_scores')

    decode_dir = os.path.abspath(str(decode_dir))
    hclg_dir = str(hclg_dir)
    ref_text = str(ref_text)
    log_dir = decode_dir + '/logs'
    lmwts = range(min_lmwt, max_lmwt + 1)

    def _run(commands, description):
        # Every Kaldi call shares the log directory and return-code policy.
        helper.excute_kaldi_commands(commands,
                                     description,
                                     log_dir=log_dir,
                                     ignore_return_code=ignore_return_codes)

    mkdir_p(os.path.join(decode_dir, 'scoring'))
    ref_file = f'{decode_dir}/scoring/test_filt.txt'
    _run([f"cat {ref_text} | sed 's:<NOISE>::g' | sed 's:<SPOKEN_NOISE>::g' "
          f"> {ref_file}"],
         'copying reference transcription')

    rescaled_lmwts = []
    rescale_cmds = []
    for lmwt in lmwts:
        tra_file = f'{decode_dir}/scoring/{lmwt}.tra'
        # All checks are evaluated for every LM weight (no short-circuit).
        tra_missing = not os.path.exists(tra_file)
        tra_incomplete = not _tra_complete(tra_file, ref_file)
        lats_present = _lattices_exists(ref_file, f'{decode_dir}/lats')
        # NOTE(review): this requires the .tra file to be missing AND
        # incomplete, so an existing but incomplete .tra is never rebuilt
        # unless force_scoring is set; `|` may have been intended — confirm.
        needs_rescale = ((tra_missing & tra_incomplete & lats_present
                          & build_tra) | force_scoring)
        if needs_rescale:
            logger.info(f'Rescaling lattice for lmwt {lmwt}')
            rescale_cmds.append(
                _build_rescale_lattice_cmd(decode_dir, hclg_dir, lmwt))
            rescaled_lmwts.append(lmwt)
    if rescale_cmds:
        _run(rescale_cmds, 'rescaling lattice')
    else:
        logger.info('All utts already rescaled - skipping')

    wer_cmds = []
    for lmwt in rescaled_lmwts:
        logger.info(f'Computing WER for lmwt {lmwt}')
        wer_cmds.append(
            _build_compute_WER_command(decode_dir,
                                       hclg_dir,
                                       lmwt,
                                       strict=strict))
    if wer_cmds:
        _run(wer_cmds, 'computing WER')

    # (column name, converter) in the column order of the result table.
    converters = (('wer', float), ('errors', int), ('words', int),
                  ('ins', int), ('del', int), ('sub', int))
    result = defaultdict(list)
    for lmwt in lmwts:
        wer, errors, words, ins, del_, sub = parse_wer_file(
            f'{decode_dir}/wer_{lmwt}')
        for (column, convert), value in zip(
                converters, (wer, errors, words, ins, del_, sub)):
            result[column].append(convert(value))
        result['decode_dir'].append(decode_dir)
        result['lmwt'].append(int(lmwt))

    table = pandas.DataFrame(result)
    with open(decode_dir + '/result.pkl', 'wb') as fid:
        pickle.dump(table, fid)
    return result.copy()