コード例 #1
0
ファイル: lat2fsg.py プロジェクト: lwllovewf2010/jarvis
def lat2fsg(lat, fsgfile, lmfst, prune=15):
    """
    Expand a lattice with an N-gram language model FST and write it out
    as a PocketSphinx FSG.

    lat: either a lattice.Dag instance or a filename.  Filenames ending in
         ".slf" are loaded as HTK lattices; anything else is passed to
         lattice.Dag() directly.
    fsgfile: output handed to build_fsg_fst().
    lmfst: language model FST; its input symbol table labels the lattice.
    prune: threshold passed to openfst.Prune() on the composed FST.

    Returns the pruned, composed openfst.StdVectorFst.
    """
    if isinstance(lat, str):
        if lat.endswith(".slf"):
            # lattice.Dag takes HTK lattices through the ``htk_file``
            # keyword (``htkfile`` is not the keyword used anywhere else).
            dag = lattice.Dag(htk_file=lat)
        else:
            dag = lattice.Dag(lat)
    else:
        dag = lat
    syms = lmfst.InputSymbols()
    fst = build_lattice_fsg(dag, syms)
    # Compose it (intersect, really) with the language model to get
    # correct N-gram scores (otherwise it is just a unigram LM).  This
    # is the same thing as "lattice expansion".
    # Same phi symbol name that lat2flat looks up.  When the LM contains
    # it (presumably its backoff arcs), compose through phi matchers.
    phi = syms.Find("&phi;")
    if phi != -1:
        opts = openfst.StdPhiComposeOptions()
        opts.matcher1 = openfst.StdPhiMatcher(fst, openfst.MATCH_NONE)
        opts.matcher2 = openfst.StdPhiMatcher(lmfst, openfst.MATCH_INPUT, phi)
        cfst = openfst.StdComposeFst(fst, lmfst, opts)
    else:
        cfst = openfst.StdComposeFst(fst, lmfst)
    # Materialize the (lazy) composition so it can be pruned in place.
    outfst = openfst.StdVectorFst(cfst)
    openfst.Prune(outfst, prune)
    # Write it back out as an FSG for PocketSphinx.
    build_fsg_fst(outfst, fsgfile)
    return outfst
コード例 #2
0
ファイル: lat2fsg.py プロジェクト: lwllovewf2010/jarvis
def lat2flat(latfile, fsgfile, lmfst):
    """
    Subset a language model using the vocabulary of a lattice.

    Builds a single-state FST with one self-loop per (non-filler) lattice
    word that exists in the LM's input symbol table.  It then composes
    that FST with the LM FST, determinizes the result, and writes it out
    as a PocketSphinx FSG.

    latfile: lattice file passed to lattice.Dag().
    fsgfile: output handed to build_fsg_fst().
    lmfst: language model FST; its input symbol table defines the labels.

    Returns the determinized openfst.StdVectorFst.
    """
    dag = lattice.Dag(latfile)
    fst = openfst.StdVectorFst()
    fst.SetStart(fst.AddState())
    fst.SetFinal(0, 0)
    syms = lmfst.InputSymbols()
    # Deduplicate on the LM symbol id rather than on the raw lattice word.
    # Different lattice words can map to the same base word (presumably
    # alternate pronunciations stripped by baseword()).  They would
    # otherwise each add an identical self-loop.
    seen = set()
    for n in dag.nodes:
        # Skip fillers as they have been "bypassed" by PocketSphinx
        if n.sym.startswith("++") or n.sym == "<sil>":
            continue
        sym = syms.Find(baseword(n.sym))
        if sym == -1 or sym in seen:
            # Out-of-vocabulary for the LM, or already added.
            continue
        seen.add(sym)
        fst.AddArc(0, sym, sym, 0, 0)
    fst.SetOutputSymbols(syms)
    phi = syms.Find("&phi;")
    if phi != -1:
        opts = openfst.StdPhiComposeOptions()
        opts.matcher1 = openfst.StdPhiMatcher(fst, openfst.MATCH_NONE)
        opts.matcher2 = openfst.StdPhiMatcher(lmfst, openfst.MATCH_INPUT, phi)
        cfst = openfst.StdComposeFst(fst, lmfst, opts)
    else:
        cfst = openfst.StdComposeFst(fst, lmfst)
    outfst = openfst.StdVectorFst()
    openfst.Determinize(cfst, outfst)
    # Write it back out as an FSG for PocketSphinx.
    build_fsg_fst(outfst, fsgfile)
    return outfst
コード例 #3
0
ファイル: lat_rescore.py プロジェクト: lwllovewf2010/jarvis
def lat_rescore(latfile, lmfst):
    """
    Rescore a lattice using a language model.

    latfile: a lattice filename (passed to lattice.Dag()) or an already
             loaded lattice.Dag instance.
    lmfst: language model handed to Dag.bestpath().

    Returns a tuple (words, score).  ``words`` is the list of base words
    (lattice.baseword) along the path recovered by Dag.backtrace() from
    the node that bestpath() returns.  ``score`` is that node's ``score``.
    """
    if isinstance(latfile, str):
        dag = lattice.Dag(latfile)
    else:
        dag = latfile
    # Use the language model that was actually passed in; the undefined
    # global ``lm`` was used here before.
    end = dag.bestpath(lmfst)
    return [lattice.baseword(x.sym) for x in dag.backtrace(end)], end.score
コード例 #4
0
ファイル: lat2fsg.py プロジェクト: lwllovewf2010/jarvis
def lat2fsg_posterior(lat, fsgfile, prune=5, errfst=None):
    """
    Posterior-prune a lattice and write it out as a PocketSphinx FSG.

    lat: either a lattice.Dag instance or a filename.  Filenames ending in
         ".slf" are loaded as HTK lattices; anything else is passed to
         lattice.Dag() directly.
    fsgfile: output handed to build_fsg_fst().
    prune: pruning beam; the lattice is pruned with
           dag.posterior_prune(-prune).
    errfst: optional FST.  When given, the lattice FST is built with
            errfst's input symbols (adding missing ones) and is then
            passed through apply_errfst().

    Returns the FST that was written.
    """
    if isinstance(lat, str):
        if lat.endswith(".slf"):
            dag = lattice.Dag(htk_file=lat)
        else:
            dag = lattice.Dag(lat)
    else:
        dag = lat
    dag.posterior_prune(-prune)
    # Build the lattice FST exactly once.  Both branches used to rebuild
    # an FST that had already been built (and then discarded) just above.
    if errfst is not None:
        fst = build_lattice_fsg(dag, errfst.InputSymbols(), pscale=1, addsyms=True)
        errfst.SetOutputSymbols(errfst.InputSymbols())
        fst = apply_errfst(fst, errfst)
    else:
        fst = build_lattice_fsg(dag, pscale=1)
    build_fsg_fst(fst, fsgfile)
    return fst
コード例 #5
0
def load_denlat(latfile):
    """
    Load a lattice and record, for every arc, its neighbouring arcs.

    An arc is keyed by (src_entry, src_sym, dest_entry): the ``entry`` and
    ``sym`` of its source node plus the ``entry`` of its destination node.
    Unreachable nodes are removed first.

    Returns a dict mapping each arc key whose source word is not '<s>' to
    a (left, right) pair of lists of arc keys:
      left  - arcs whose destination node has this arc's source
              (entry, sym);
      right - arcs whose source node has this arc's destination
              (entry, sym).  If the destination word is '</s>', the
              pseudo-arc (dest_entry, '</s>', dest_entry) is listed first.
    Several lattice arcs can share one key (destinations with the same
    entry but different words).  Their neighbour lists are merged without
    duplicates.
    """
    dag = lattice.Dag(latfile)
    dag.remove_unreachable()

    # Unique (src_entry, src_sym, dest_entry, dest_sym) tuples in
    # first-seen order.  The set makes duplicate checks O(1); the list
    # keeps the order.
    arcs = []
    seen = set()
    for w in dag.nodes:
        for link in w.exits:
            a = (w.entry, w.sym, link.dest.entry, link.dest.sym)
            if a not in seen:
                seen.add(a)
                arcs.append(a)

    # Index arc keys by source node and by destination node, so that
    # neighbour lookups are one dict access instead of a scan over every
    # arc for every arc.
    outgoing = {}  # (entry, sym) of source node -> keys of arcs leaving it
    incoming = {}  # (entry, sym) of dest node -> keys of arcs entering it
    for a in arcs:
        key = a[:3]
        for index, node in ((outgoing, a[:2]), (incoming, a[2:])):
            bucket = index.setdefault(node, [])
            if key not in bucket:
                bucket.append(key)

    lat = {}
    for a in arcs:
        if a[1] == '<s>':
            continue
        key = a[:3]
        left = incoming.get(a[:2], [])
        right = []
        if a[3] == '</s>':
            right.append((a[2], a[3], a[2]))
        for r in outgoing.get(a[2:], []):
            if r not in right:
                right.append(r)
        if key not in lat:
            # Copy ``left``: the index lists are shared between arcs and
            # the merge below appends to the stored lists.
            lat[key] = (list(left), right)
        else:
            stored_left, stored_right = lat[key]
            for l in left:
                if l not in stored_left:
                    stored_left.append(l)
            for r in right:
                if r not in stored_right:
                    stored_right.append(r)

    return lat
コード例 #6
0
 ref = open(ref)
 wordcount = 0
 errcount = 0
 for c, r in izip(ctl, ref):
     # Normalize reference, etc.
     ref, refid = get_utt(r)
     c = c.strip()
     r = ref.split()
     if len(r) == 0 or r[0] != '<s>': r.insert(0, '<s>')
     if r[-1] != '</s>': r.append('</s>')
     r = filter(lambda x: not is_filler(x), r)
     # Turn it into an FSM
     rfst = fstutils.sent2fst(r)
     # Get the hypothesis lattice
     try:
         l = lattice.Dag(os.path.join(latdir, c + ".lat"))
     except IOError:
         try:
             l = lattice.Dag(os.path.join(latdir, c + ".lat.gz"))
         except IOError:
             l = lattice.Dag(htk_file=os.path.join(latdir, c + ".slf"))
     if opts.prune != None:
         l.posterior_prune(-opts.prune)
     # Convert it to an FSM
     lfst = lat2fsg.build_lattice_fsg(l,
                                      rfst.OutputSymbols(),
                                      addsyms=True,
                                      determinize=False,
                                      baseword=lattice.baseword_noclass)
     openfst.ArcSortInput(lfst)
     # Apply Levenshtein model to the input
コード例 #7
0
import sys
import os
import lattice
from itertools import izip

ctl, ref, latdir = sys.argv[1:]

ctl = open(ctl)
ref = open(ref)
wordcount = 0
errcount = 0
for c,r in izip(ctl, ref):
    c = c.strip()
    r = r.split()
    del r[-1]
    if r[0] != '<s>': r.insert(0, '<s>')
    if r[-1] != '</s>': r.append('</s>')
    l = lattice.Dag()
    l.sphinx2dag(os.path.join(latdir, c + ".lat.gz"))
    err, bt = l.minimum_error(r)
    maxlen = [max([len(y) for y in x]) for x in bt]
    print " ".join(["%*s" % (m, x[0]) for m, x in izip(maxlen, bt)])
    print " ".join(["%*s" % (m, x[1]) for m, x in izip(maxlen, bt)])
    print "Error: %.2f%%" % (float(err) / len(r) * 100)
    print
    wordcount += len(r)
    errcount += err

print "TOTAL Error: %.2f%%" % (float(errcount) / wordcount * 100)
コード例 #8
0
ファイル: lattice_prune.py プロジェクト: lwllovewf2010/jarvis
    edgecount = 0
    density = 0
    # prune lattices one by one
    for i in range(start, end):
        c = ctl[i].strip()
        r = ref[i].split()
        del r[-1]
        if r[0] != '<s>': r.insert(0, '<s>')
        if r[-1] != '</s>': r.append('</s>')
        r = filter(lambda x: not lattice.is_filler(x), r)

        print "process sent: %s" % c

        # load lattice
        print "\t load lattice ..."
        dag = lattice.Dag(os.path.join(denlatdir, c + ".lat.gz"))
        dag.bypass_fillers()
        dag.remove_unreachable()

        # prune lattice
        dag.edges_unigram_score(lm, lw)
        dag.dt_posterior()

        # edge pruning
        print "\t edge pruning ..."
        dag.forward_edge_prune(abeam)
        dag.backward_edge_prune(abeam)
        dag.remove_unreachable()

        # node pruning
        print "\t node pruning ..."
コード例 #9
0
    parser = OptionParser(usage="%prog CTL LATDIR [LMFST]")
    parser.add_option("--lmnamectl")
    parser.add_option("--lmdir", default=".")
    parser.add_option("--lw", type="float", default=7)
    opts, args = parser.parse_args(sys.argv[1:])
    ctlfile, latdir = args[0:2]
    if len(args) > 2:
        lmfst = openfst.StdVectorFst.Read(args[2])
        lmnamectl = None
    elif opts.lmnamectl:
        lmnamectl = file(opts.lmnamectl)
        lmfsts = {}
    else:
        parser.error("either --lmnamectl or LMFST must be given")
    for spam in file(ctlfile):
        if lmnamectl:
            lmname = lmnamectl.readline().strip()
            if lmname not in lmfsts:
                lmfsts[lmname] = openfst.StdVectorFst.Read(os.path.join(opts.lmdir,
                                                                        lmname + ".arpa.fst"))
            lmfst = lmfsts[lmname]
        try:
            dag = lattice.Dag(os.path.join(latdir, spam.strip() + ".lat.gz"))
        except IOError:
            try:
                dag = lattice.Dag(os.path.join(latdir, spam.strip() + ".lat"))
            except IOError:
                dag = lattice.Dag(htk_file=os.path.join(latdir, spam.strip() + ".slf"))
        words, score = lat_rescore(dag, lmfst, opts.lw)
        print " ".join(words), "(%s %f)" % (spam.strip(), score)