def lat2fsg(lat, fsgfile, lmfst, prune=15):
    """Expand a lattice with a language model and write it as a
    PocketSphinx FSG.

    lat may be a lattice filename (a ".slf" suffix is read as an HTK
    lattice) or an already-loaded lattice.Dag.  The expanded FST is
    pruned to weight `prune` of the best path, written to fsgfile, and
    returned.
    """
    if isinstance(lat, str):
        if lat.endswith(".slf"):
            # BUG FIX: this previously passed htkfile=lat; every other
            # call site in this file uses the htk_file keyword.
            dag = lattice.Dag(htk_file=lat)
        else:
            dag = lattice.Dag(lat)
    else:
        dag = lat
    fst = build_lattice_fsg(dag, lmfst.InputSymbols())
    # Compose it (intersect, really) with the language model to get
    # correct N-gram scores (otherwise it is just a unigram LM).  This
    # is the same thing as "lattice expansion".
    phi = lmfst.InputSymbols().Find("φ")
    if phi != -1:
        # The LM uses phi (failure) arcs for backoff; match them only
        # on the LM side of the composition.
        opts = openfst.StdPhiComposeOptions()
        opts.matcher1 = openfst.StdPhiMatcher(fst, openfst.MATCH_NONE)
        opts.matcher2 = openfst.StdPhiMatcher(lmfst, openfst.MATCH_INPUT, phi)
        cfst = openfst.StdComposeFst(fst, lmfst, opts)
    else:
        cfst = openfst.StdComposeFst(fst, lmfst)
    outfst = openfst.StdVectorFst(cfst)
    openfst.Prune(outfst, prune)
    # Write it back out as an FSG for PocketSphinx.
    build_fsg_fst(outfst, fsgfile)
    return outfst
def lat2flat(latfile, fsgfile, lmfst):
    """
    Subset a language model using the vocabulary of a lattice.
    """
    dag = lattice.Dag(latfile)
    # Build a single-state acceptor that loops over every in-vocabulary
    # word found in the lattice.
    fst = openfst.StdVectorFst()
    fst.SetStart(fst.AddState())
    fst.SetFinal(0, 0)
    syms = lmfst.InputSymbols()
    seen = set()
    for node in dag.nodes:
        word = node.sym
        # Skip fillers as they have been "bypassed" by PocketSphinx
        if word.startswith("++") or word == "<sil>":
            continue
        if word in seen:
            continue
        seen.add(word)
        label = syms.Find(baseword(word))
        if label == -1:
            # Not in the LM's vocabulary; nothing to subset.
            continue
        fst.AddArc(0, label, label, 0, 0)
    fst.SetOutputSymbols(lmfst.InputSymbols())
    # Compose with the LM, treating phi symbols (if present) as
    # failure transitions for backoff.
    phi = lmfst.InputSymbols().Find("φ")
    if phi == -1:
        cfst = openfst.StdComposeFst(fst, lmfst)
    else:
        opts = openfst.StdPhiComposeOptions()
        opts.matcher1 = openfst.StdPhiMatcher(fst, openfst.MATCH_NONE)
        opts.matcher2 = openfst.StdPhiMatcher(lmfst, openfst.MATCH_INPUT, phi)
        cfst = openfst.StdComposeFst(fst, lmfst, opts)
    outfst = openfst.StdVectorFst()
    openfst.Determinize(cfst, outfst)
    # Write it back out as an FSG for PocketSphinx.
    build_fsg_fst(outfst, fsgfile)
    return outfst
def lat_rescore(latfile, lmfst, lw=7.0):
    """
    Rescore a lattice using a language model.

    latfile may be a lattice filename or an already-loaded lattice.Dag
    (the driver script in this file passes a Dag).  Returns a tuple
    (words, score): the best-path word sequence with class tags
    stripped, and the score of that path.
    """
    if isinstance(latfile, str):
        dag = lattice.Dag(latfile)
    else:
        dag = latfile
    # BUG FIX: this called dag.bestpath(lm), but no `lm` exists in this
    # scope (guaranteed NameError) -- the parameter is lmfst.  The lw
    # (language weight) parameter is accepted because the driver below
    # calls lat_rescore(dag, lmfst, opts.lw); assumes Dag.bestpath
    # takes (lm, lw) -- TODO confirm against the lattice module.
    end = dag.bestpath(lmfst, lw)
    return [lattice.baseword(x.sym) for x in dag.backtrace(end)], end.score
def lat2fsg_posterior(lat, fsgfile, prune=5, errfst=None):
    """Convert a lattice to a posterior-pruned FSG for PocketSphinx.

    lat may be a lattice filename (".slf" is read as HTK format) or an
    already-loaded lattice.Dag.  Nodes are pruned with a posterior
    beam of -prune.  If errfst is given, it is composed with the
    lattice FST via apply_errfst().  The result is written to fsgfile
    and returned.
    """
    if isinstance(lat, str):
        if lat.endswith(".slf"):
            dag = lattice.Dag(htk_file=lat)
        else:
            dag = lattice.Dag(lat)
    else:
        dag = lat
    dag.posterior_prune(-prune)
    # BUG FIX: an unconditional fst = build_lattice_fsg(dag, pscale=1)
    # used to run here, but both branches below overwrite fst, so it
    # was pure dead work.
    if errfst:
        fst = build_lattice_fsg(dag, errfst.InputSymbols(), pscale=1,
                                addsyms=True)
        errfst.SetOutputSymbols(errfst.InputSymbols())
        fst = apply_errfst(fst, errfst)
    else:
        fst = build_lattice_fsg(dag, pscale=1)
    build_fsg_fst(fst, fsgfile)
    return fst
def load_denlat(latfile):
    """Load a denominator lattice and index each arc with its context.

    Returns a dict mapping each arc triple
    (start_frame, word, end_frame) -- using node entry frames as
    endpoints -- to a pair (left, right) of lists of triples: the arcs
    that immediately precede and follow it in the lattice.  Arcs whose
    source word is '<s>' are not indexed.  NOTE(review): the original
    comment said "annotate with posterior probabilities", but no
    posterior computation happens here -- only context extraction.
    """
    # Annotate the lattice with posterior probabilities
    dag = lattice.Dag(latfile)
    dag.remove_unreachable()
    # Collect the unique set of arcs as 4-tuples:
    # (src_frame, src_word, dst_frame, dst_word).
    arc = []
    for w in dag.nodes:
        for l in w.exits:
            if (w.entry, w.sym, l.dest.entry, l.dest.sym) not in arc:
                arc.append((w.entry, w.sym, l.dest.entry, l.dest.sym))
    lat = {}
    for n in arc:
        if n[1] != '<s>':
            left = []
            right = []
            # O(|arc|^2) context scan: for the current arc n, gather
            # predecessor arcs (left) and successor arcs (right).
            for m in arc:
                if n[3] == '</s>' and (n[2], n[3], n[2]) not in right:
                    # Arc ends in </s>: synthesize a zero-length
                    # right-context triple at the final frame.
                    right.append((n[2], n[3], n[2]))
                elif m[0] == n[2] and m[1] == n[3]:
                    # m starts where n ends, with the same word.
                    if (m[0], m[1], m[2]) not in right:
                        right.append((m[0], m[1], m[2]))
                if m[2] == n[0] and m[3] == n[1]:
                    # m ends where n starts, with the same word.
                    if (m[0], m[1], m[2]) not in left:
                        left.append((m[0], m[1], m[2]))
            if (n[0], n[1], n[2]) not in lat:
                lat[(n[0], n[1], n[2])] = (left, right)
            else:
                # Same (frame, word, frame) key seen again: merge the
                # new contexts into the existing lists without dupes.
                for l in left:
                    if l not in lat[(n[0], n[1], n[2])][0]:
                        lat[(n[0], n[1], n[2])][0].append(l)
                for r in right:
                    if r not in lat[(n[0], n[1], n[2])][1]:
                        lat[(n[0], n[1], n[2])][1].append(r)
    return lat
# NOTE(review): this chunk starts mid-script (ctl, opts, latdir, get_utt,
# is_filler are bound earlier) and the loop body continues past the end
# of the visible region.
ref = open(ref)
wordcount = 0
errcount = 0
for c, r in izip(ctl, ref):
    # Normalize reference, etc.
    ref, refid = get_utt(r)
    c = c.strip()
    r = ref.split()
    # Ensure the reference is bracketed by sentence start/end markers.
    if len(r) == 0 or r[0] != '<s>':
        r.insert(0, '<s>')
    if r[-1] != '</s>':
        r.append('</s>')
    # Drop filler words (presumably noise/silence tokens -- confirm
    # against is_filler's definition).
    r = filter(lambda x: not is_filler(x), r)
    # Turn it into an FSM
    rfst = fstutils.sent2fst(r)
    # Get the hypothesis lattice
    # Try ".lat", then ".lat.gz", then fall back to HTK ".slf".
    try:
        l = lattice.Dag(os.path.join(latdir, c + ".lat"))
    except IOError:
        try:
            l = lattice.Dag(os.path.join(latdir, c + ".lat.gz"))
        except IOError:
            l = lattice.Dag(htk_file=os.path.join(latdir, c + ".slf"))
    # Optional posterior pruning with beam -opts.prune.
    if opts.prune != None:
        l.posterior_prune(-opts.prune)
    # Convert it to an FSM
    lfst = lat2fsg.build_lattice_fsg(l, rfst.OutputSymbols(), addsyms=True,
                                     determinize=False,
                                     baseword=lattice.baseword_noclass)
    # Sort arcs on input labels so lfst can be composed efficiently.
    openfst.ArcSortInput(lfst)
    # Apply Levenshtein model to the input
import sys import os import lattice from itertools import izip ctl, ref, latdir = sys.argv[1:] ctl = open(ctl) ref = open(ref) wordcount = 0 errcount = 0 for c,r in izip(ctl, ref): c = c.strip() r = r.split() del r[-1] if r[0] != '<s>': r.insert(0, '<s>') if r[-1] != '</s>': r.append('</s>') l = lattice.Dag() l.sphinx2dag(os.path.join(latdir, c + ".lat.gz")) err, bt = l.minimum_error(r) maxlen = [max([len(y) for y in x]) for x in bt] print " ".join(["%*s" % (m, x[0]) for m, x in izip(maxlen, bt)]) print " ".join(["%*s" % (m, x[1]) for m, x in izip(maxlen, bt)]) print "Error: %.2f%%" % (float(err) / len(r) * 100) print wordcount += len(r) errcount += err print "TOTAL Error: %.2f%%" % (float(errcount) / wordcount * 100)
# NOTE(review): this chunk starts mid-script (start, end, ctl, ref, lm,
# lw, abeam, denlatdir are bound earlier) and the loop body continues
# past the end of the visible region.
edgecount = 0
density = 0
# prune lattices one by one
for i in range(start, end):
    c = ctl[i].strip()
    r = ref[i].split()
    # Drop the trailing token (presumably the utterance id -- confirm
    # against the reference file format).
    del r[-1]
    # Bracket the reference with sentence start/end markers.
    if r[0] != '<s>':
        r.insert(0, '<s>')
    if r[-1] != '</s>':
        r.append('</s>')
    # Remove filler words from the reference.
    r = filter(lambda x: not lattice.is_filler(x), r)
    print "process sent: %s" % c
    # load lattice
    print "\t load lattice ..."
    dag = lattice.Dag(os.path.join(denlatdir, c + ".lat.gz"))
    dag.bypass_fillers()
    dag.remove_unreachable()
    # prune lattice
    # Score edges with a unigram LM (weight lw), then compute
    # posteriors for pruning.
    dag.edges_unigram_score(lm, lw)
    dag.dt_posterior()
    # edge pruning
    # Forward and backward passes with posterior beam abeam, then drop
    # anything left unreachable.
    print "\t edge pruning ..."
    dag.forward_edge_prune(abeam)
    dag.backward_edge_prune(abeam)
    dag.remove_unreachable()
    # node pruning
    print "\t node pruning ..."
parser = OptionParser(usage="%prog CTL LATDIR [LMFST]") parser.add_option("--lmnamectl") parser.add_option("--lmdir", default=".") parser.add_option("--lw", type="float", default=7) opts, args = parser.parse_args(sys.argv[1:]) ctlfile, latdir = args[0:2] if len(args) > 2: lmfst = openfst.StdVectorFst.Read(args[2]) lmnamectl = None elif opts.lmnamectl: lmnamectl = file(opts.lmnamectl) lmfsts = {} else: parser.error("either --lmnamectl or LMFST must be given") for spam in file(ctlfile): if lmnamectl: lmname = lmnamectl.readline().strip() if lmname not in lmfsts: lmfsts[lmname] = openfst.StdVectorFst.Read(os.path.join(opts.lmdir, lmname + ".arpa.fst")) lmfst = lmfsts[lmname] try: dag = lattice.Dag(os.path.join(latdir, spam.strip() + ".lat.gz")) except IOError: try: dag = lattice.Dag(os.path.join(latdir, spam.strip() + ".lat")) except IOError: dag = lattice.Dag(htk_file=os.path.join(latdir, spam.strip() + ".slf")) words, score = lat_rescore(dag, lmfst, opts.lw) print " ".join(words), "(%s %f)" % (spam.strip(), score)