def compile_wfst_from_left_branching_grammar(grammar):
    """
    Create an FST from a left-branching hybrid grammar.

    :type grammar: LCFRS
    :rtype: Fst, Enumerator

    The output tape of the FST spells a rule tree in
    `reverse polish notation <https://en.wikipedia.org/wiki/Reverse_Polish_notation>`_
    (rule indices, encoded as output symbols).
    """
    myfst = Fst()
    nonterminals = SymbolTable()
    # Dedicated initial state; INITIAL is entered into the symbol table so
    # its key lines up with this state id.
    sid = myfst.add_state()
    nonterminals.add_symbol(INITIAL)
    myfst.set_start(sid)
    # One FST state per nonterminal; the symbol-table key equals the state id.
    for nont in grammar.nonts():
        sid = myfst.add_state()
        nonterminals.add_symbol(nont, sid)
        if nont == grammar.start():
            myfst.set_final(sid)
    # Number all rules starting from 1 (0 is reserved for <epsilon>).
    rules = Enumerator(first_index=1)
    for rule in grammar.rules():
        rules.object_index(rule)
    terminals = SymbolTable()
    terminals.add_symbol('<epsilon>', 0)
    for rule in grammar.rules():
        if rule.rank() == 2:
            # Binary rule: combine it with each terminal-generating rule of
            # its second RHS nonterminal into a single arc whose output
            # symbol encodes both rule indices ("<rule2>-<rule>").
            assert len(rule.lhs().arg(0)) == 2
            for rule2 in grammar.lhs_nont_to_rules(rule.rhs_nont(1)):
                if len(rule2.rhs()) == 0:
                    arc = Arc(terminals.add_symbol(rule2.lhs().args()[0][0]),
                              terminals.add_symbol(
                                  str(rules.object_index(rule2)) + '-' + str(rules.object_index(rule))),
                              make_weight(rule.weight() * rule2.weight()),
                              nonterminals.find(rule.lhs().nont()))
                    myfst.add_arc(nonterminals.find(rule.rhs_nont(0)), arc)
        elif rule.rank() == 0:
            # Terminal rule: arc from the initial state reading the terminal
            # and emitting the rule index.
            assert len(rule.lhs().arg(0)) == 1
            arc = Arc(terminals.add_symbol(rule.lhs().args()[0][0]),
                      terminals.add_symbol(str(rules.object_index(rule))),
                      make_weight(rule.weight()),
                      nonterminals.find(rule.lhs().nont()))
            myfst.add_arc(nonterminals.find(INITIAL), arc)
        else:
            # Unary rule: only allowed at the start symbol; epsilon input,
            # rule index output, leading into the (final) start state.
            assert rule.rank() == 1
            assert rule.lhs().nont() == grammar.start()
            assert len(rule.lhs().arg(0)) == 1
            arc = Arc(0,
                      terminals.add_symbol(str(rules.object_index(rule))),
                      make_weight(rule.weight()),
                      nonterminals.find(grammar.start())
                      )
            myfst.add_arc(nonterminals.find(rule.rhs_nont(0)), arc)
    myfst.set_input_symbols(terminals)
    myfst.set_output_symbols(terminals)
    myfst.optimize(True)
    return myfst, rules
def fsa_from_list_of_symbols(input, symbol_table):
    """
    Build a linear acceptor for the given token sequence.

    :param input: iterable of tokens to accept, in order
    :type input: Iterable
    :param symbol_table: table mapping tokens to labels; extended in place
        when a token is not yet present
    :type symbol_table: SymbolTable
    :return: An acceptor for the given list of tokens (identical input and
        output labels on every arc).
    :rtype: Fst
    """
    fsa = Fst()
    fsa.set_input_symbols(symbol_table)
    fsa.set_output_symbols(symbol_table)
    state = fsa.add_state()
    fsa.set_start(state)
    for token in input:
        next_state = fsa.add_state()
        # SymbolTable.add_symbol returns the existing key when the symbol is
        # already present (OpenFst AddSymbol semantics), so the previous
        # find()/KeyError fallback was a redundant double lookup.
        label = symbol_table.add_symbol(token)
        fsa.add_arc(state, Arc(label, label, 0, next_state))
        state = next_state
    fsa.set_final(state)
    return fsa
def all_substrings(self, fsa: pynini.Fst) -> pynini.Fst:
    """Return an optimized acceptor for every substring of *fsa*'s language.

    Works on a copy of *fsa*: every state is made final (so recognition
    may stop anywhere) and an epsilon arc is added from the start state
    to each state (so recognition may begin anywhere).
    """
    result = fsa.copy()
    origin = result.start()
    one = pynini.Weight.one(result.weight_type())
    for state in result.states():
        result.set_final(state)
        result.add_arc(origin, pynini.Arc(0, 0, one, state))
    return result.optimize()
def all_prefixes(self, fsa: pynini.Fst) -> pynini.Fst:
    """Return an optimized acceptor for every prefix of *fsa*'s language.

    Works on a copy of *fsa* in which every state has been made final,
    so recognition may stop at any point.
    """
    result = fsa.copy()
    for state in result.states():
        result.set_final(state)
    return result.optimize()