def decode_lattice(lattice: pynini.Fst, lm: pynini.Fst, sym: pynini.SymbolTable) -> str:
    """Decodes the lattice.

    Composes ``lattice`` with the language model ``lm``, takes the single
    shortest path through the result, removes epsilons from it, and renders
    that path as a string using ``sym``.

    Args:
      lattice: The lattice FST to decode.
      lm: The language-model FST composed on the right of ``lattice``.
      sym: Symbol table used to render the best path as a string.

    Returns:
      The best path as a string.

    Raises:
      AssertionError: The composition has no start state ("composition failure").
    """
    lattice = pynini.compose(lattice, lm)
    # Explicit check instead of `assert` so it still fires under `python -O`;
    # AssertionError is kept so callers that already catch it are unaffected.
    if lattice.start() == pynini.NO_STATE_ID:
        raise AssertionError("composition failure")
    # Pynini can join the string for us.
    return pynini.shortestpath(lattice).rmepsilon().string(sym)
def all_suffixes(self, fsa: pynini.Fst) -> pynini.Fst:
    """Returns an FSA accepting every suffix of the paths in ``fsa``.

    Works on a copy, so ``fsa`` itself is not modified. An epsilon arc of
    weight one is added from the start state to every other state, which lets
    a path begin at any state. The result is then optimized.

    Args:
      fsa: The input automaton.

    Returns:
      The optimized suffix automaton.
    """
    fsa = fsa.copy()
    start_state = fsa.start()
    one = pynini.Weight.one(fsa.weight_type())
    for state in list(fsa.states()):
        # Skip the start state. An epsilon self-loop there adds no new suffix,
        # and in non-idempotent semirings (e.g. log) a weight-one epsilon
        # cycle has no finite closure, so epsilon removal cannot converge.
        if state == start_state:
            continue
        fsa.add_arc(start_state, pynini.Arc(0, 0, one, state))
    return fsa.optimize()
def compile_wfst_from_right_branching_grammar(grammar):
    """
    :type grammar: LCFRS
    :rtype: tuple(Fst, Enumerator)
    Create a FST from a right-branching hybrid grammar.
    The output of the FST is a rule tree in `polish notation <https://en.wikipedia.org/wiki/Polish_notation>`_

    Construction:

    * one FST state per nonterminal, plus one extra ``FINAL`` state, which is
      the only state marked final;
    * the start state is the state of ``grammar.start()``;
    * input and output labels share one symbol table, with label 0 reserved
      for ``<epsilon>``;
    * output labels are rule indices, or ``"i-j"`` for a binary rule ``i``
      combined with a terminal rule ``j``.

    Returns the pair ``(myfst, rules)``: the optimized FST and the Enumerator
    that assigns each grammar rule its integer index, starting at 1.
    """
    # One state per nonterminal. The symbol key equals the state id, so
    # `nonterminals.find(nont)` maps a nonterminal to its state.
    myfst = Fst()
    nonterminals = SymbolTable()
    for nont in grammar.nonts():
        sid = myfst.add_state()
        nonterminals.add_symbol(nont, sid)
        if nont == grammar.start():
            myfst.set_start(sid)
    # Extra state reached by terminal-producing rules.
    sid = myfst.add_state()
    nonterminals.add_symbol(FINAL, sid)
    # Re-adding an existing symbol returns its existing key (OpenFst
    # SymbolTable semantics), i.e. the FINAL state id set just above.
    myfst.set_final(nonterminals.add_symbol(FINAL))

    # Assign every rule an index (starting at 1); these indices form the
    # output labels below.
    rules = Enumerator(first_index=1)
    for rule in grammar.rules():
        rules.object_index(rule)

    # Shared input/output symbol table. Ids are handed out in first-call
    # order of add_symbol, and inside each Arc(...) the input label is added
    # before the output label. Reordering these calls changes label ids.
    terminals = SymbolTable()
    terminals.add_symbol('<epsilon>', 0)

    for rule in grammar.rules():
        if len(rule.rhs()) == 2:
            # Binary rule A -> B C: fold each terminal rule of B into a single
            # arc A --terminal:"rule-rule2"--> C, weighted by the product of
            # both rule weights.
            for rule2 in grammar.lhs_nont_to_rules(rule.rhs_nont(0)):
                if len(rule2.rhs()) == 0:
                    arc = Arc(terminals.add_symbol(rule2.lhs().args()[0][0]),
                              terminals.add_symbol(str(rules.object_index(rule)) + '-' + str(rules.object_index(rule2))),
                              make_weight(rule.weight() * rule2.weight()),
                              nonterminals.find(rule.rhs_nont(1)))
                    myfst.add_arc(nonterminals.find(rule.lhs().nont()), arc)
        elif len(rule.rhs()) == 0:
            # Terminal rule A -> t: arc A --t:"rule"--> FINAL.
            arc = Arc(terminals.add_symbol(rule.lhs().args()[0][0]),
                      terminals.add_symbol(str(rules.object_index(rule))),
                      make_weight(rule.weight()),
                      nonterminals.find(FINAL))
            myfst.add_arc(nonterminals.find(rule.lhs().nont()), arc)
        else:
            # Any other rhs length (presumably unary rules) is only allowed on
            # the start nonterminal. It becomes an epsilon-input arc from the
            # start state to the state of rhs_nont(0).
            # NOTE(review): `assert` is stripped under `python -O`; this
            # grammar-shape check then silently disappears.
            assert rule.lhs().nont() == grammar.start()
            arc = Arc(0, terminals.add_symbol(str(rules.object_index(rule))), make_weight(rule.weight()), nonterminals.find(rule.rhs_nont(0)))
            myfst.add_arc(myfst.start(), arc)

    myfst.set_input_symbols(terminals)
    myfst.set_output_symbols(terminals)
    # NOTE(review): the meaning of the positional `True` depends on the FST
    # binding in use (Fst/Arc/SymbolTable are imported elsewhere) - confirm.
    myfst.optimize(True)
    return myfst, rules
def check_wellformed_lattice(lattice: pynini.Fst) -> None:
    """Validates that a lattice is not empty.

    A lattice without a start state is treated as empty.

    Args:
      lattice: A lattice FST.

    Raises:
      Error: The lattice has no start state.
    """
    has_start_state = lattice.start() != pynini.NO_STATE_ID
    if not has_start_state:
        raise Error("Lattice is empty")
def total_weight(self, fst: pynini.Fst) -> float:
    """Returns the total weight of all paths in ``fst`` as a float.

    Computed as the reverse shortest distance (distance to the final states)
    of the start state. For an FST that accepts nothing, returns the
    semiring's zero weight (e.g. ``inf`` in the tropical and log semirings)
    instead of reading an unrelated list entry.

    Args:
      fst: The FST to sum over.

    Returns:
      The total path weight as a float.
    """
    zero = float(pynini.Weight.zero(fst.weight_type()))
    start = fst.start()
    # Without this guard, NO_STATE_ID (-1) would index the *last* element of
    # the distance list and silently return the wrong state's weight.
    if start == pynini.NO_STATE_ID:
        return zero
    distances = pynini.shortestdistance(fst, reverse=True)
    # Defensive: the reverse-distance list may not cover a start state from
    # which no final state is reachable.
    if start >= len(distances):
        return zero
    return float(distances[start])