Exemple #1
0
def lat_rescore(latfile, lmfst):
    """
    Rescore a lattice using a language model.

    The lattice is loaded from ``latfile``, the best path through it is
    searched with ``lmfst`` as the language model, and the words along
    that path are returned together with the path score.

    :param latfile: lattice source, handed unchanged to ``lattice.Dag``.
    :param lmfst: language model, handed unchanged to ``Dag.bestpath``.
    :returns: a ``(words, score)`` tuple; ``words`` is the list of
        ``lattice.baseword(sym)`` for every entry of the backtrace and
        ``score`` is the ``score`` attribute of the best-path end point.
    """
    dag = lattice.Dag(latfile)
    # Score with the model supplied by the caller instead of depending on
    # a name `lm` that is not defined anywhere in this function.
    end = dag.bestpath(lmfst)
    words = [lattice.baseword(x.sym) for x in dag.backtrace(end)]
    return words, end.score
Exemple #2
0
def load_denlat(latfile):
    """
    Load a lattice and return an adjacency map over its arcs.

    Every arc is described by ``(src.entry, src.sym, dst.entry, dst.sym)``.
    The result maps ``(src.entry, src.sym, dst.entry)`` to a pair of lists
    ``(left, right)`` of ``(entry, sym, dest_entry)`` triples:

    * ``left``  -- arcs that end at the source node of the key arc;
    * ``right`` -- arcs that start at the destination node of the key arc.
      When the destination symbol is ``'</s>'`` the list starts with the
      marker triple ``(dst.entry, '</s>', dst.entry)``.

    Arcs leaving a ``'<s>'`` node do not get a key of their own; they only
    show up inside the ``left`` lists of their successors.

    :param latfile: lattice source, handed unchanged to ``lattice.Dag``.
    :returns: dict as described above.
    """
    dag = lattice.Dag(latfile)
    dag.remove_unreachable()

    # Collect every distinct arc once, in first-seen order.  The set gives
    # constant-time duplicate detection; the list keeps the ordering.
    arcs = []
    seen = set()
    for node in dag.nodes:
        for link in node.exits:
            arc = (node.entry, node.sym, link.dest.entry, link.dest.sym)
            if arc not in seen:
                seen.add(arc)
                arcs.append(arc)

    lat = {}
    for src_entry, src_sym, dst_entry, dst_sym in arcs:
        if src_sym == '<s>':
            continue

        left = []
        right = []
        end_marker = (dst_entry, dst_sym, dst_entry)
        for m in arcs:
            head = m[:3]
            # The if/elif order is deliberate and kept as is: on the first
            # pass of an arc into '</s>' the marker branch wins, so arcs[0]
            # is never tested as a successor of such an arc.
            if dst_sym == '</s>' and end_marker not in right:
                right.append(end_marker)
            elif m[0] == dst_entry and m[1] == dst_sym:
                if head not in right:
                    right.append(head)
            if m[2] == src_entry and m[3] == src_sym:
                if head not in left:
                    left.append(head)

        # Arcs that differ only in the destination symbol share one key,
        # so later arcs are merged into the lists already stored.
        key = (src_entry, src_sym, dst_entry)
        if key not in lat:
            lat[key] = (left, right)
        else:
            known_left, known_right = lat[key]
            for item in left:
                if item not in known_left:
                    known_left.append(item)
            for item in right:
                if item not in known_right:
                    known_right.append(item)

    return lat
Exemple #3
0
# Optional fourth command-line argument: a pruning threshold, parsed as float.
# NOTE(review): `prune` is tested further down even when this branch is
# skipped, so it presumably receives a default before this point -- confirm.
if len(sys.argv) > 4:
    prune = float(sys.argv[4])

# Rebind `ctl` and `ref` to open file objects (both names are passed to
# open(), so they presumably held paths until here).
# NOTE(review): neither file is closed in the visible part of the script.
ctl = open(ctl)
ref = open(ref)
# Running totals; nothing in the visible part of the loop updates them.
wordcount = 0
errcount = 0
# Walk both files in lockstep: one line of `ctl` per line of `ref`.
for c, r in zip(ctl, ref):
    c = c.strip()
    r = r.split()
    # Drop the last token of the reference line (presumably a trailing
    # utterance ID -- confirm against the transcript format).
    # NOTE(review): raises IndexError when the line contains no tokens.
    del r[-1]
    # Make sure the reference is wrapped in sentence markers.
    if len(r) == 0 or r[0] != '<s>': r.insert(0, '<s>')
    if r[-1] != '</s>': r.append('</s>')
    # Word count without the two sentence markers; taken before fillers are
    # removed, so filler tokens are still included in it.
    nw = len(r) - 2
    r = [x for x in r if not lattice.is_filler(x)]
    # Load the lattice named by the control line, trying in order:
    # "<c>.lat.gz" and "<c>.lat" via sphinx2dag, then "<c>.slf" via htk2dag.
    # An IOError from the last attempt is not caught here.
    l = lattice.Dag()
    try:
        l.sphinx2dag(os.path.join(latdir, c + ".lat.gz"))
    except IOError:
        try:
            l.sphinx2dag(os.path.join(latdir, c + ".lat"))
        except IOError:
            l.htk2dag(os.path.join(latdir, c + ".slf"))
    # A falsy `prune` (0, None, ...) skips pruning; otherwise the negated
    # value is handed to posterior_prune.
    if prune:
        l.posterior_prune(-prune)
    # `bt` is consumed below as a sequence of items whose elements 0 and 1
    # are strings; `err` is used as the numerator of the error rate.
    err, bt = l.minimum_error(r)
    # Column width for each item: the longest string it contains.
    maxlen = [max([len(y) for y in x]) for x in bt]
    # Print two rows (element 0 of every item, then element 1), each cell
    # right-justified to its column width so the rows line up.
    print(" ".join(["%*s" % (m, x[0]) for m, x in zip(maxlen, bt)]))
    print(" ".join(["%*s" % (m, x[1]) for m, x in zip(maxlen, bt)]))
    # Skip the percentage when there are no words, avoiding division by zero.
    if nw:
        print("Error: %.2f%%" % (float(err) / nw * 100))
Exemple #4
0
    # prune lattices one by one
    # `start`, `end`, `ctl`, `ref`, `denlatdir`, `lm`, `lw` and `abeam` are
    # defined outside this excerpt; `ctl` and `ref` are indexed by position
    # here, so they must be sequences of lines.
    for i in range(start, end):
        c = ctl[i].strip()
        r = ref[i].split()
        # Drop the last token of the reference line (presumably a trailing
        # utterance ID -- confirm against the transcript format).
        # NOTE(review): raises IndexError when the line has no tokens, and
        # the `r[0]` test below does the same when only one token was there.
        del r[-1]
        # Make sure the reference is wrapped in sentence markers.
        if r[0] != '<s>':
            r.insert(0, '<s>')
        if r[-1] != '</s>':
            r.append('</s>')
        # Remove filler tokens; `r` is not used again in the visible part.
        r = [x for x in r if not lattice.is_filler(x)]

        print("process sent: %s" % c)

        # load lattice
        # Only "<c>.lat.gz" under `denlatdir` is tried; no fallback format.
        print("\t load lattice ...")
        dag = lattice.Dag(os.path.join(denlatdir, c + ".lat.gz"))
        dag.bypass_fillers()
        dag.remove_unreachable()

        # prune lattice
        # Score the edges with `lm` and `lw`, then run dt_posterior() before
        # any pruning call below.
        dag.edges_unigram_score(lm, lw)
        dag.dt_posterior()

        # edge pruning
        # Forward and backward passes use the same beam `abeam`; unreachable
        # parts are removed again afterwards.
        print("\t edge pruning ...")
        dag.forward_edge_prune(abeam)
        dag.backward_edge_prune(abeam)
        dag.remove_unreachable()

        # node pruning
        # (the excerpt ends right after this progress message)
        print("\t node pruning ...")
Exemple #5
0
 # Running error total; nothing in the visible part of the loop updates it.
 errcount = 0
 # Walk the control and reference inputs in lockstep.
 for c, r in zip(ctl, ref):
     # Normalize reference, etc.
     # NOTE(review): this rebinds `ref`, the name that was handed to zip()
     # above.  zip() already holds its own iterator, so the iteration goes
     # on, but the original object is no longer reachable through `ref`.
     ref, refid = get_utt(r)
     c = c.strip()
     r = ref.split()
     # Make sure the reference is wrapped in sentence markers.
     if len(r) == 0 or r[0] != '<s>':
         r.insert(0, '<s>')
     if r[-1] != '</s>':
         r.append('</s>')
     r = [x for x in r if not is_filler(x)]
     # Turn it into an FSM
     rfst = fstutils.sent2fst(r)
     # Get the hypothesis lattice
     # Tried in order: "<c>.lat", "<c>.lat.gz", then "<c>.slf" through the
     # `htk_file` keyword.  An IOError from the last attempt is not caught.
     try:
         l = lattice.Dag(os.path.join(latdir, c + ".lat"))
     except IOError:
         try:
             l = lattice.Dag(os.path.join(latdir, c + ".lat.gz"))
         except IOError:
             l = lattice.Dag(htk_file=os.path.join(latdir, c + ".slf"))
     # Pruning is optional; the negated option value is what gets passed on.
     if opts.prune is not None:
         l.posterior_prune(-opts.prune)
     # Convert it to an FSM
     # The reference FST's output symbols are passed as the second argument
     # (presumably so both machines share one symbol table -- confirm).
     lfst = lat2fsg.build_lattice_fsg(l,
                                      rfst.OutputSymbols(),
                                      addsyms=True,
                                      determinize=False,
                                      baseword=lattice.baseword_noclass)
     openfst.ArcSortInput(lfst)
     # Apply Levenshtein model to the input
     # (the excerpt ends here; that step is not shown)
    # `parser` is created outside this excerpt.
    # --lmdir: directory that holds the per-utterance LM FSTs (default ".").
    parser.add_option("--lmdir", default=".")
    # --lw: float option, default 7, forwarded to lat_rescore below.
    parser.add_option("--lw", type="float", default=7)
    opts, args = parser.parse_args(sys.argv[1:])
    # First two positional arguments: control file and lattice directory.
    ctlfile, latdir = args[0:2]
    if len(args) > 2:
        # A third positional argument names one LM FST used for every line.
        lmfst = openfst.StdVectorFst.Read(args[2])
        lmnamectl = None
    elif opts.lmnamectl:
        # Otherwise LM names are read from a second control file, one per
        # utterance, and the loaded FSTs are cached by name in `lmfsts`.
        # NOTE(review): the --lmnamectl option itself must be registered
        # outside this excerpt -- confirm.
        lmnamectl = open(opts.lmnamectl)
        lmfsts = {}
    else:
        parser.error("either --lmnamectl or LMFST must be given")
    # NOTE(review): neither `ctlfile` nor `lmnamectl` is closed in the
    # visible code.
    for spam in open(ctlfile):
        if lmnamectl:
            # Consume one LM name per control line; load
            # "<lmdir>/<lmname>.arpa.fst" the first time a name is seen.
            lmname = lmnamectl.readline().strip()
            if lmname not in lmfsts:
                lmfsts[lmname] = openfst.StdVectorFst.Read(
                    os.path.join(opts.lmdir, lmname + ".arpa.fst"))
            lmfst = lmfsts[lmname]
        # Load the lattice, trying in order: "<id>.lat.gz", "<id>.lat", then
        # "<id>.slf" through the `htk_file` keyword.  An IOError from the
        # last attempt is not caught here.
        try:
            dag = lattice.Dag(os.path.join(latdir, spam.strip() + ".lat.gz"))
        except IOError:
            try:
                dag = lattice.Dag(os.path.join(latdir, spam.strip() + ".lat"))
            except IOError:
                dag = lattice.Dag(
                    htk_file=os.path.join(latdir,
                                          spam.strip() + ".slf"))
        # lat_rescore is called with three arguments here (lattice, LM FST,
        # value of --lw) and must return a (words, score) pair.
        words, score = lat_rescore(dag, lmfst, opts.lw)
        # Output: the words, then "(<control line> <score>)".
        print(" ".join(words), "(%s %f)" % (spam.strip(), score))