def compute_tagging(self, corpus, first_lex_in, first_lex_out,
                        second_lex_in, second_lex_out, automaton1, automaton2,
                        class_cluster, trials1, trials2, result, key):
        results = []
        # load the smoothing models for the two tagging passes
        smooth1 = fst.Fst.read(trials1)
        smooth2 = fst.Fst.read(trials2)
        for idc, l in enumerate(corpus):
            # first pass: compile the sentence into an acceptor, compose it
            # with the first automaton and its smoothing model, then decode
            acceptor = self.far_compile_string(l, first_lex_in, "<unk>")
            cv = fst.compose(acceptor, automaton1)
            res = fst.compose(cv, smooth1)
            res = res.rmepsilon()
            res = fst.shortestpath(res)
            res = res.topsort()
            res = self.fst_print(res, first_lex_out)
            res = self.sanity_check(l.split(" "), res, class_cluster)
            first_o = " ".join(res)

            # second pass: re-tag the first-pass output with the second
            # automaton; any label without "." is mapped to the "O" tag below
            acceptor = self.far_compile_string(first_o, second_lex_in, "<unk>")
            cv = fst.compose(acceptor, automaton2)
            res = fst.compose(cv, smooth2)
            res = res.rmepsilon()
            res = fst.shortestpath(res)
            res = res.topsort()
            output = [
                x if "." in x else "O"
                for x in self.fst_print(res, second_lex_out)
            ]

            results.append(output)
        result[key] = results
Example 2
    def decompound(self, word):
        tree = Tree(word)
        self._split(word, tree)
        #print(tree)
        nleafnodes = tree.nleafnodes()
        #print("Number of leaf nodes:", nleafnodes)
        symtablel = sorted(tree.getsyms())
        symtable = {s: i for i, s in enumerate(symtablel)}
        #print("Symbols:")
        #print(symtable)
        fst = wfst.Fst()
        # one state per leaf node, plus a final state
        for _ in range(nleafnodes + 1):
            fst.add_state()
        fst.set_final(nleafnodes, wfst.Weight.One(fst.weight_type()))
        fst.set_start(0)
        tree.makelattice(fst, 0, symtable, self.wordcost, firstword=True)
        #output fst for debugging
        # fstsymtable = wfst.SymbolTable(b"default")
        # for i, sym in enumerate(symtablel):
        #     fstsymtable.add_symbol(sym.encode("utf-8"), i)
        # fst.set_input_symbols(fstsymtable)
        # fst.set_output_symbols(fstsymtable)
        # fst.write("/tmp/debug.fst")
        best = wfst.shortestpath(fst, nshortest=1)
        wordseq = label_seq(best, symtablel)
        return wordseq
Example 3
    def topk_choice(self, word_sequence, topk_wds=None):
        '''
        Extracts the top-k choices of the LM given a word history (lattice).
        Input: lm.fst and a sentence string.
        Output: the top-k words that complete the lattice.
        '''

        # generate sentence fst
        fstout = fst.intersect(word_sequence, self.lm)
        fst_det = fst.determinize(fstout)
        fst_p = fst.push(fst_det, push_weights=True, to_final=True)
        fst_p.rmepsilon()
        fst_rm = fst.determinize(fst_p)
        short = fst.shortestpath(fst_rm, nshortest=10)
        short_det = fst.determinize(short)
        short_det.rmepsilon()
        two_state = fst.compose(short_det, self.refiner)
        output = two_state.project(project_output=True)
        output.rmepsilon()
        output = fst.determinize(output)
        output.minimize()
        if topk_wds is not None:  # Needs to distinguish None and []
            topk_wds.extend(self.get_topk_words(output))
        return output
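The helper get_topk_words is referenced above but not shown in this snippet. Below is a hypothetical sketch of what it might look like (my assumption, not the source's code), reading word labels off the small determinized result lattice:

    def get_topk_words(self, output_fst):
        # hypothetical helper: walk every arc of the top-k lattice and
        # collect its word labels, using the symbol table when present
        syms = output_fst.input_symbols()
        words = []
        for state in output_fst.states():
            for arc in output_fst.arcs(state):
                if arc.ilabel != 0:  # skip epsilon
                    words.append(syms.find(arc.ilabel) if syms else str(arc.ilabel))
        return words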
Example 4
def viterbi(lattice):
    if lattice.arc_type() != 'standard':
        lattice = openfst.arcmap(lattice, map_type='to_std')

    shortest_paths = openfst.shortestpath(lattice).topsort().rmepsilon()

    return shortest_paths
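A minimal usage sketch for viterbi (my addition, assuming pywrapfst is imported as openfst): build a toy lattice with two competing arcs and keep the cheaper one.

import pywrapfst as openfst

compiler = openfst.Compiler()
print("0 1 1 1 0.5", file=compiler)  # cheaper arc, label 1
print("0 1 2 2 1.5", file=compiler)  # costlier arc, label 2
print("1", file=compiler)            # state 1 is final
lattice = compiler.compile()

best = viterbi(lattice)  # path FST that keeps only the 0.5-weight arc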
Example 5
def get_shortest_path(fst_in):
    """
    Generates the shortest path through an FST
    :param fst_in: <openfst.Fst>
    :return: <openfst.Fst> or None
    """
    try:
        return openfst.shortestpath(fst_in)
    except Exception:
        return None
Example 6
def printstrings(a,
                 nshortest=1,
                 project_output=False,
                 syms=None,
                 weight=False):
    """
    Return the nshortest unique input strings in the FST a.  The FST a is projected 
    onto the input or output prior to finding the shortest paths. An optional symbol 
    table syms can be provided.  Results are returned as strings; if the weight 
    flag is specified, the path scores are included
    """
    import pywrapfst as fst
    b = a.copy().project(project_output=project_output)
    if nshortest == 1:
        c = fst.shortestpath(b)
    else:
        c = fst.shortestpath(b, nshortest=nshortest, unique=True)
    nba = fst.push(c, push_weights=True).rmepsilon()
    nb = []
    if nba.start() != -1:
        for arc1 in nba.arcs(nba.start()):
            w = arc1.weight
            nextstate = arc1.nextstate
            nbi = []
            if syms:
                nbi.append(syms.find(arc1.ilabel))
            else:
                nbi.append(str(arc1.ilabel))
            # follow the single outgoing arc at each state to the path's end
            while True:
                try:
                    nextarc = next(nba.arcs(nextstate))
                except StopIteration:
                    break
                if syms:
                    nbi.append(syms.find(nextarc.ilabel))
                else:
                    nbi.append(str(nextarc.ilabel))
                nextstate = nextarc.nextstate
            if weight:
                nb.append((' '.join(nbi), w.to_string()))
            else:
                nb.append(' '.join(nbi))
    return nb
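A usage sketch for printstrings (added here, not part of the source): compile a tiny two-arc FST and list its single best string, with and without the path score.

import pywrapfst as fst

c = fst.Compiler()
print("0 1 3 3 1.0", file=c)
print("1 2 5 5 0.5", file=c)
print("2", file=c)
a = c.compile()

print(printstrings(a))               # ['3 5']
print(printstrings(a, weight=True))  # [('3 5', '1.5')]; pushing moves the total score onto the first arc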
Example 7
def get_likelihood_for_fsas_over_paths(d_fsa,
                                       w_fsas,
                                       w,
                                       num_paths=10,
                                       return_type="probability"):
    '''Approximate the likelihood of d_fsa under the FSA w_fsas[w] by summing over the num_paths shortest paths of their composition.'''
    if num_paths <= 0:
        raise ValueError('num_paths must be a positive integer')

    w_fsa = w_fsas[w]
    dw_composed = pywrapfst.compose(w_fsa, d_fsa)
    dw_composed.arcsort(sort_type="ilabel")

    if num_paths > 1:
        shortest_paths = pywrapfst.epsnormalize(
            pywrapfst.shortestpath(dw_composed, nshortest=num_paths))
        if return_type == "shortest_paths":
            return (shortest_paths)
        if shortest_paths.num_states() > 0:

            # take the reverse distance because with multiple shortest paths, 0 is the start state, 1 is the final state
            shortest_distance = pywrapfst.shortestdistance(shortest_paths,
                                                           reverse=True)

            # iterate over all outgoing arcs from the start state
            path_weights = get_weights_for_paths(shortest_paths)
            if return_type == "path_weights":
                return (path_weights)
            shortest_paths_sum = np.sum(np.exp(-1. * np.array(path_weights)))
            if return_type == "probability":
                return (shortest_paths_sum)
        else:
            # this is the case where there is no way to compose the d_fsa and the w_fsa
            return (10**-20)

    else:
        shortest_path = pywrapfst.shortestpath(dw_composed)
        if shortest_path.num_states() > 0:
            shortest_distance = pywrapfst.shortestdistance(shortest_path)
            return (np.exp(-1 * float(shortest_distance[0])))
        else:
            return (10**-20)
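The helper get_weights_for_paths is not shown in this snippet. Below is a hypothetical sketch consistent with the reverse-distance comment above (my assumption, not the source's implementation):

def get_weights_for_paths(shortest_paths):
    # after shortestpath with nshortest > 1, every arc leaving the start
    # state begins one path; the reverse shortest distance at its target
    # state is the remaining weight of that path
    dist = pywrapfst.shortestdistance(shortest_paths, reverse=True)
    start = shortest_paths.start()
    return [float(arc.weight) + float(dist[arc.nextstate])
            for arc in shortest_paths.arcs(start)]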
Example 8
def shortest_path(f, n=1):
    # create the FST that will encode the shortest paths,
    # with the same symbol table as f
    s = fst(f)

    # OpenFST shortest path
    s.f = openfst.shortestpath(f.f, nshortest=n)

    # Keep track of the states
    for stateid in s.f.states():
        s.states.append(str(stateid))
        s.stateids[str(stateid)] = stateid

    return s
Example 9
def get_shortest_path(fst):
    shortest_path = openfst.shortestpath(fst, weight=None)

	## shortestpath() numbers states from the final state back to the start,
	## so arcs print in reverse path order; sorting by state id and reversing
	## recovers the path (TODO: use proper operations, e.g. topsort)
    data = [line.split('\t') for line in shortest_path.text().split('\n')]
    data = [line for line in data if len(line) in [4, 5]]
    data = [(int(line[0]), int(line[2]))
            for line in data]  # (i,o,lab1,lab2,[weight])
    data.sort()
    data.reverse()

    shortest_path = [cat - 1 for (index, cat) in data
                     if cat != 0]  ## remove epsilon 0
    #            ^--- back to python indices
    return shortest_path
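A possible cleaner extraction addressing the TODO (a sketch, not the original author's fix): topsort the path FST so labels can be read front to back without parsing text().

def get_shortest_path_topsorted(fst_in):
    sp = openfst.shortestpath(fst_in).topsort()  # states now in path order
    path = []
    for state in sp.states():
        for arc in sp.arcs(state):
            if arc.olabel != 0:              # drop epsilon 0
                path.append(arc.olabel - 1)  # back to python indices
    return path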
Example 10
def get_shortest_path(fst_in, quiet=True):

    s = openfst.shortestpath(fst_in, weight=None)  ## weight:
    # A Weight or weight string indicating the desired weight threshold
    # below which paths are pruned; if omitted, no paths are pruned.

    if not quiet:
        print(s)

    ## shortestpath() numbers states from the final state back to the start,
    ## so arcs print in reverse path order (TODO: use topsort instead)
    data = [line.split('\t') for line in s.text().split('\n')]
    data = [line for line in data if len(line) in [4, 5]]
    data = [(int(line[0]), int(line[2]))
            for line in data]  # (i,o,lab1,lab2,[weight])
    data.sort()
    data.reverse()

    data = [frame - 1 for (index, frame) in data
            if frame != 0]  ## remove epsilon 0
    #            ^--- back to python indices

    #    print 'shortest path FST:'
    #    print s
    return data
Example 11
def process_line(line):
    global isym
    global osym
    global tm
    global lm
    # Read input
    compiler = fst.Compiler()
    arr = line.strip().split() + ["</s>"]
    unks = []
    for i, x in enumerate(arr):
        if x not in isym:
            unks.append(x)
        xsym = isym[x] if x in isym else isym["<unk>"]
        print("%d %d %s %s" % (i, i + 1, xsym, xsym), file=compiler)
    print("%s" % (len(arr)), file=compiler)
    ifst = compiler.compile()

    # Create the search graph and do search
    graph = fst.compose(ifst, tm)
    graph = fst.compose(graph, lm)
    graph = fst.shortestpath(graph)

    # Read off the output
    out = []
    unkspot = 0
    for state in graph.states():
        for arc in graph.arcs(state):
            if arc.olabel != 0:
                tok = osym[arc.olabel]
                # unk substitution (original words in same order)
                if unkspot < len(unks) and tok == "<unk>":
                    out.append(unks[unkspot])
                    unkspot += 1
                else:
                    out.append(tok)
    return " ".join(reversed(out[1:]))
Example 12
                            traversed_arcs = []
                            break
                        elif arc.nextstate not in fst_states:
                            fst_states.insert(0, arc.nextstate)

            print("FINAL NUM OOV CANDIDATES: " + str(len(oovs)))
            #merge overlapping candidates
            comp_oov_i = 0
            while comp_oov_i < len(oovs):
                if_merged = 0
                print(str(comp_oov_i))
                comp_oov_1 = oovs[comp_oov_i]
                comp_oov_1_start = 0
                comp_oov_1_end = 0
                comp_oov_1_shortest = fst.determinize(
                    fst.shortestpath(comp_oov_1)).minimize().rmepsilon()
                for tmp_state in comp_oov_1_shortest.states():
                    for tmp_arc in comp_oov_1_shortest.arcs(tmp_state):
                        str_w = tmp_arc.weight.to_string()
                        comp_oov_1_start = comp_oov_1_start + int(
                            str_w[str_w.find(b' ') +
                                  1:str_w.find(b' ',
                                               str_w.find(b' ') + 1)])
                        comp_oov_1_end = comp_oov_1_end + int(
                            str_w[str_w.find(b' ',
                                             str_w.find(b' ') + 1) + 1:])
                for comp_oov_j in range(comp_oov_i + 1, len(oovs)):
                    # if_merged = 0
                    comp_oov_2 = oovs[comp_oov_j]
                    comp_oov_2_start = 0
                    comp_oov_2_end = 0
Example 13
def _shortest_path(graph):
    return fst.shortestpath(graph)
Example 14
    if not l: break
    unks = []
    words = l.split()
    word_ids = []
    for word in words:
      word_id = syms.get(word, unk_id)
      word_ids.append(word_id)
      if word_id == unk_id:
        unks.append(word)
  
  
    sentence = make_sentence_fsa(syms, word_ids)
    sentence.arcsort(sort_type="olabel")
    
    composed = fst.compose(sentence, g)

    alignment = fst.shortestpath(composed)
    alignment.rmepsilon()
    alignment.topsort()
        
    labels = []
    for state in alignment.states():
      for arc in alignment.arcs(state):
        if arc.olabel > 0:
          if arc.olabel == unk_id:
            labels.append(unks.pop(0))
          else:
            labels.append(syms_list[arc.olabel])
    print(" ".join(labels))
    sys.stdout.flush()
Example 15
            trials.append("smooths/" + folder + "/" + filename)
    corpus = []

    with open(p.dataset_folder + "test_set.txt") as file:
        for l in file.readlines():
            corpus.append(l.strip())

    for t in tqdm(trials):
        results = []
        for l in corpus:
            acceptor = p.far_compile_string(l, lex_in, "<unk>")
            cv = fst.compose(acceptor, automaton)
            smooth = fst.Fst.read(t)
            res = fst.compose(cv, smooth)
            res = res.rmepsilon()
            res = fst.shortestpath(res)
            res = res.topsort()
            results.append(p.fst_print(res, lex_out))
        r = p.evaluation_print(results, "test_features.txt")
        folder1, folder2, filename = t.split("/")
        filename = filename.split(".")[0]
        if not os.path.exists("results/cut_off_lb_" + str(lower_bound) +
                              "_ub_" + str(upper_bound) + "/" + folder2):
            os.makedirs("results/cut_off_lb_" + str(lower_bound) + "_ub_" +
                        str(upper_bound) + "/" + folder2)

        with open(
                "results/cut_off_lb_" + str(lower_bound) + "_ub_" +
                str(upper_bound) + "/" + folder2 + "/" + filename +
                "_result.txt", "w") as file:
            file.write(r)
Example 16
def __test(fig_dir, num_classes=2, num_samples=10, min_dur=2, max_dur=4):
    eps_str = 'ε'
    bos_str = '<BOS>'
    eos_str = '<EOS>'
    dur_internal_str = 'I'
    dur_final_str = 'F'

    aux_symbols = (eps_str, bos_str, eos_str)
    dur_vocab = (dur_internal_str, dur_final_str)
    sample_vocab = tuple(i for i in range(num_samples))
    dur_vocab = tuple(i for i in range(1, max_dur + 1))
    class_vocab = tuple(i for i in range(num_classes))
    class_dur_vocab = tuple((c, s) for c in class_vocab for s in dur_vocab)

    def to_strings(vocab):
        return tuple(map(str, vocab))

    def to_integerizer(vocab):
        return {item: i for i, item in enumerate(vocab)}

    sample_vocab_str = to_strings(sample_vocab)
    dur_vocab_str = to_strings(dur_vocab)
    class_vocab_str = to_strings(class_vocab)
    class_dur_vocab_str = to_strings(class_dur_vocab)

    # sample_integerizer = to_integerizer(sample_vocab)
    dur_integerizer = to_integerizer(dur_vocab)
    class_integerizer = to_integerizer(class_vocab)
    # class_dur_integerizer = to_integerizer(class_dur_vocab)

    sample_str_integerizer = to_integerizer(sample_vocab_str)
    class_str_integerizer = to_integerizer(class_vocab_str)
    class_dur_str_integerizer = to_integerizer(class_dur_vocab_str)

    def get_parts(class_dur_key):
        c, s = class_dur_vocab[class_dur_str_integerizer[class_dur_key]]
        c_str = class_vocab_str[class_integerizer[c]]
        s_str = dur_vocab_str[dur_integerizer[s]]
        return c_str, s_str

    class_dur_to_str = {get_parts(name): name for name in class_dur_vocab_str}

    sample_symbols = libfst.makeSymbolTable(aux_symbols + sample_vocab,
                                            prepend_epsilon=False)
    # dur_symbols = libfst.makeSymbolTable(aux_symbols + dur_vocab, prepend_epsilon=False)
    class_symbols = libfst.makeSymbolTable(aux_symbols + class_vocab,
                                           prepend_epsilon=False)
    # dur_symbols = libfst.makeSymbolTable(aux_symbols + dur_vocab, prepend_epsilon=False)
    class_dur_symbols = libfst.makeSymbolTable(aux_symbols + class_dur_vocab,
                                               prepend_epsilon=False)

    obs_scores = np.zeros((num_samples, num_classes))

    dur_scores = np.array([[0 if d >= min_dur else np.inf for d in dur_vocab]
                           for c in class_vocab],
                          dtype=float)

    def score_transition(c_prev, s_prev, c_cur, s_cur):
        if c_prev != c_cur:
            if s_prev == dur_final_str and s_cur == dur_internal_str:
                score = 0
            else:
                score = np.inf
        else:
            if s_prev == dur_internal_str and s_cur == dur_final_str:
                score = 0
            elif s_prev == dur_internal_str and s_cur == dur_internal_str:
                score = 0
            else:
                score = np.inf
        return score

    transition_scores = np.array([[
        score_transition(c_prev, s_prev, c_cur, s_cur)
        for (c_cur, s_cur) in class_dur_vocab
    ] for (c_prev, s_prev) in class_dur_vocab],
                                 dtype=float)
    init_scores = np.array([
        0 if c == 0 and s == dur_internal_str else np.inf
        for (c, s) in class_dur_vocab
    ],
                           dtype=float)
    final_scores = np.array([
        0 if c == 1 and s == dur_final_str else np.inf
        for (c, s) in class_dur_vocab
    ],
                            dtype=float)

    def score_arc_state(class_dur_key, class_key):
        c, s = class_dur_vocab[class_dur_str_integerizer[class_dur_key]]
        c_prime = class_str_integerizer[class_key]

        if c == c_prime:
            score = 0
        else:
            score = np.inf

        return score

    class_dur_to_class_scores = np.array([[
        score_arc_state(class_dur_key, class_key)
        for class_key in class_vocab_str
    ] for class_dur_key in class_dur_vocab_str],
                                         dtype=float)

    def log_normalize(arr, axis=1):
        denom = -scipy.special.logsumexp(-arr, axis=axis, keepdims=True)
        return arr - denom

    obs_scores = log_normalize(obs_scores)
    dur_scores = log_normalize(dur_scores)
    transition_scores = log_normalize(transition_scores)
    init_scores = log_normalize(init_scores, axis=None)
    final_scores = log_normalize(final_scores, axis=None)

    obs_fst = add_endpoints(fromArray(obs_scores,
                                      sample_vocab_str,
                                      class_vocab_str,
                                      input_symbols=sample_symbols,
                                      output_symbols=class_symbols,
                                      arc_type='standard'),
                            bos_str=bos_str,
                            eos_str=eos_str).arcsort(sort_type='ilabel')

    dur_fst = add_endpoints(make_duration_fst(
        dur_scores,
        class_vocab_str,
        class_dur_to_str,
        dur_internal_str=dur_internal_str,
        dur_final_str=dur_final_str,
        input_symbols=class_symbols,
        output_symbols=class_dur_symbols,
        allow_self_transitions=False,
        arc_type='standard',
    ),
                            bos_str=bos_str,
                            eos_str=eos_str).arcsort(sort_type='ilabel')

    transition_fst = fromTransitions(
        transition_scores,
        class_dur_vocab_str,
        class_dur_vocab_str,
        init_weights=init_scores,
        final_weights=final_scores,
        input_symbols=class_dur_symbols,
        output_symbols=class_dur_symbols).arcsort(sort_type='ilabel')

    class_dur_to_class_fst = single_state_transducer(
        class_dur_to_class_scores,
        class_dur_vocab_str,
        class_vocab_str,
        input_symbols=class_dur_symbols,
        output_symbols=class_symbols,
        arc_type='standard').arcsort(sort_type='ilabel')

    seq_model = openfst.compose(dur_fst, transition_fst)
    decode_lattice = openfst.compose(obs_fst, seq_model).rmepsilon()

    # Result is in the log semiring (ie weights are negative log probs)
    arc_scores = libfst.fstArcGradient(decode_lattice).arcsort(
        sort_type='ilabel')
    best_arcs = openfst.shortestpath(decode_lattice).arcsort(
        sort_type='ilabel')

    state_scores = openfst.compose(
        arc_scores, openfst.arcmap(class_dur_to_class_fst, map_type='to_log'))
    best_states = openfst.compose(best_arcs, class_dur_to_class_fst)

    state_scores_arr, weight_type = toArray(state_scores,
                                            sample_str_integerizer,
                                            class_str_integerizer)

    draw_fst(os.path.join(fig_dir, 'obs_fst'),
             obs_fst,
             vertical=True,
             width=50,
             height=50,
             portrait=True)

    draw_fst(os.path.join(fig_dir, 'dur_fst'),
             dur_fst,
             vertical=True,
             width=50,
             height=50,
             portrait=True)

    draw_fst(os.path.join(fig_dir, 'transition_fst'),
             transition_fst,
             vertical=True,
             width=50,
             height=50,
             portrait=True)

    draw_fst(os.path.join(fig_dir, 'seq_model'),
             seq_model,
             vertical=True,
             width=50,
             height=50,
             portrait=True)

    draw_fst(os.path.join(fig_dir, 'decode_lattice'),
             decode_lattice,
             vertical=True,
             width=50,
             height=50,
             portrait=True)

    draw_fst(os.path.join(fig_dir, 'state_scores'),
             state_scores,
             vertical=True,
             width=50,
             height=50,
             portrait=True)

    draw_fst(os.path.join(fig_dir, 'best_states'),
             best_states,
             vertical=True,
             width=50,
             height=50,
             portrait=True)

    utils.plot_array(obs_scores.T, (state_scores_arr.T, ), ('-logprobs', ),
                     fn=os.path.join(fig_dir, "test_io.png"))
Example 17
osym = {}
with open(sys.argv[5], "r") as osymfile:
    for line in osymfile:
        x, y = line.strip().split()
        osym[int(y)] = x

for line in sys.stdin:
    # Read input
    compiler = fst.Compiler()
    arr = line.strip().split() + ["</s>"]
    for i, x in enumerate(arr):
        xsym = isym[x] if x in isym else isym["<unk>"]
        print("%d %d %s %s" % (i, i + 1, xsym, xsym), file=compiler)
    print("%s" % (len(arr)), file=compiler)
    ifst = compiler.compile()

    # Create the search graph and do search
    graph = fst.compose(ifst, tm)
    graph = fst.compose(graph, lm)
    graph = fst.compose(graph, wp)
    graph = fst.shortestpath(graph)

    # Read off the output
    out = []
    for state in graph.states():
        for arc in graph.arcs(state):
            if arc.olabel != 0:
                out.append(osym[arc.olabel])
    print(" ".join(reversed(out[1:])))
Example 18
    def decode(self,
               encoded,
               encoded_lens,
               texts=None,
               text_lens=None,
               return_texts_and_generated_loss=False,
               return_logits_text_diff=False,
               spkids=None,
               **other_data_in_batch):
        logits = self.logits(encoded, encoded_lens)

        denominator_matrices = self.graph_generator.get_decoding_matrices(
            logits.device)

        # The gradient wrt Viterbi gives the symbols on the shortest path.
        with torch.enable_grad():
            logits = logits.detach().requires_grad_()
            loss = -fst_utils.path_reduction(
                logits,
                encoded_lens,
                denominator_matrices,
                red_kind='viterbi',
                neg_inf=self.graph_generator.nc_weight).sum()
            loss.backward()

        selidx = logits.grad.min(-1)[1].cpu()
        decoded_texts = []
        for i in range(logits.size(1)):
            # add 1 because our FST wants symbols from range 1..Num_Symbols
            idx = selidx[:encoded_lens[i], i] + 1
            decoded_fst = fst.shortestpath(
                fst.compose(
                    fst_utils.build_chain_fst(idx, arc_type='standard'),
                    self.dec_fst))
            decoded_text = []
            n = decoded_fst.start()
            while decoded_fst.num_arcs(n) != 0:
                a, = decoded_fst.arcs(n)
                n = a.nextstate
                if a.olabel > 0:
                    decoded_text.append(a.olabel)
            decoded_texts.append(decoded_text)

        ret = {
            'decoded': decoded_texts,
            # 'decoded_frames': selidx,
            'logits': logits
        }

        if texts is not None and text_lens is not None:
            fst_text_losses = self.get_fst_loss(logits, encoded_lens, texts,
                                                text_lens, other_data_in_batch)
            fst_text_loss = fst_text_losses.sum()
            ret['loss'] = dict(fst_loss=fst_text_loss, loss=fst_text_loss)

        if return_texts_and_generated_loss:
            decoded_lens = torch.IntTensor([len(x) for x in decoded_texts])
            fst_generated_losses = self.get_fst_loss(logits, encoded_lens,
                                                     decoded_texts,
                                                     decoded_lens, None)
            ret['text_loss'] = fst_text_losses.tolist()
            ret['generated_loss'] = fst_generated_losses.tolist()

        if return_logits_text_diff:
            ret['logits_text_diff'] = (encoded_lens - text_lens).tolist()
        return ret
Example 19
    def fst_alter_sent(self, words, numalts=5):
        # create new empty FST
        altfst = fst.Fst()
        altfst.add_state()

        for idx, word in enumerate(words):
            # add the word to the lattice or <unk> if out-of-vocabulary
            if word in self.lmfst.input_symbols():
                word_id = self.lmfst.input_symbols().find(word)
                arc = fst.Arc(word_id, word_id, 0,
                              self.get_state_id(idx + 1, altfst))
                altfst.add_arc(self.get_state_id(idx, altfst), arc)
            else:
                word_id = self.lmfst.input_symbols().find("<unk>")
                arc = fst.Arc(word_id, word_id, 0,
                              self.get_state_id(idx + 1, altfst))
                altfst.add_arc(self.get_state_id(idx, altfst), arc)

            # add word alternatives to the lattice
            nearlist = []
            for i in range(1):
                r = random.random()
                altword = '<unk>'
                p = 0
                for w, wp in self.unigrams:
                    p = p + wp
                    if p > r:
                        altword = w
                        break
                nearlist.append(altword)
            #nearlist = None

            # check if there are any neighbors at all
            if nearlist is None:
                continue

            # add each neighbor to the lattice
            for widx, w in enumerate(nearlist):
                if w in self.lmfst.input_symbols() and w != word:
                    w_id = self.lmfst.input_symbols().find(w)
                    arc = fst.Arc(w_id, w_id, 0,
                                  self.get_state_id(idx + 1, altfst))
                    altfst.add_arc(self.get_state_id(idx, altfst), arc)

        # mark the final state in the FST
        altfst.set_final(len(words))
        altfst.set_start(0)

        # sort lattice prior to rescoring
        altfst.arcsort()

        # rescore the lattice using the language model
        scoredfst = fst.compose(self.lmfst, altfst)

        # get best paths in the rescored lattice
        bestpaths = fst.shortestpath(scoredfst, nshortest=numalts)
        bestpaths.rmepsilon()

        altstrings = {}

        # get the strings and weights from the best paths
        for i, path in enumerate(self.paths(bestpaths)):
            path_string = ' '.join(
                (bestpaths.input_symbols().find(arc.ilabel)).decode('utf-8')
                for arc in path)
            path_weight = functools.reduce(operator.add,
                                           (float(arc.weight) for arc in path))
            if path_string not in altstrings:
                altstrings[path_string] = path_weight

        # sort strings by weight
        scoredstrings = []
        for sent in altstrings:
            score = altstrings[sent]
            scoredstrings.append((score, sent))
        scoredstrings.sort()

        if len(scoredstrings) > numalts:
            scoredstrings = scoredstrings[:numalts]

        return scoredstrings
Example 20
        l = sys.stdin.readline()
        if not l: break
        unks = []
        words = l.split()
        word_ids = []
        for word in words:
            word_id = syms.get(word, unk_id)
            word_ids.append(word_id)
            if word_id == unk_id:
                unks.append(word)

        sentence = make_sentence_fsa(syms, word_ids)
        sentence.arcsort(sort_type="olabel")

        composed = fst.compose(sentence, g)

        alignment = fst.shortestpath(composed)
        alignment.rmepsilon()
        alignment.topsort()

        labels = []
        for state in alignment.states():
            for arc in alignment.arcs(state):
                if arc.olabel > 0:
                    if arc.olabel == unk_id:
                        labels.append(unks.pop(0))
                    else:
                        labels.append(syms_list[arc.olabel])
        print(" ".join(labels))
        sys.stdout.flush()