def fsa_from_list_of_symbols(input, symbol_table):
    """
    Build a linear acceptor for the given sequence of tokens.

    The acceptor has one state per position and one arc per token, carrying
    the token's label on both tapes.

    :param input: the tokens to accept, in order.
    :type input: iterable of str
    :param symbol_table: table used for both input and output labels. It is
        extended in place with every token of ``input`` it does not contain yet.
    :type symbol_table: SymbolTable
    :return: An acceptor for the given list of tokens.
    :rtype: Fst
    """
    fsa = Fst()
    state = fsa.add_state()
    fsa.set_start(state)
    for token in input:
        # add_symbol returns the existing key for a known token and assigns a
        # fresh key otherwise, so one call handles both cases without relying
        # on find() raising KeyError (newer bindings return NO_LABEL instead).
        label = symbol_table.add_symbol(token)
        next_state = fsa.add_state()
        fsa.add_arc(state, Arc(label, label, 0, next_state))
        state = next_state
    fsa.set_final(state)
    # Attach the table only after it has been extended: the FST stores a copy,
    # so attaching it earlier would leave the FST with a stale table that
    # lacks the tokens added above.
    fsa.set_input_symbols(symbol_table)
    fsa.set_output_symbols(symbol_table)
    return fsa
# Example #2
# 0
 def all_suffixes(self, fsa: pynini.Fst) -> pynini.Fst:
     """Return an optimized copy of ``fsa`` that can start in any of its states.

     An epsilon arc of weight one is added from the start state to every other
     state, so any path from some state to a final state is accepted. For a
     trim (accessible and coaccessible) input this means every suffix of every
     accepted string. ``fsa`` itself is not modified.

     :param fsa: the automaton whose suffixes are wanted.
     :return: the extended, optimized copy.
     """
     fsa = fsa.copy()
     start_state = fsa.start()
     one = pynini.Weight.one(fsa.weight_type())
     for state in fsa.states():
         # The start state is already where paths begin; an epsilon self-loop
         # there adds nothing and, in non-idempotent semirings (e.g. log),
         # forms a weight-one epsilon cycle whose closure diverges when
         # optimize() removes epsilons.
         if state == start_state:
             continue
         fsa.add_arc(start_state, pynini.Arc(0, 0, one, state))
     return fsa.optimize()
def compile_wfst_from_left_branching_grammar(grammar):
    """
    Create a weighted FST from a left-branching hybrid grammar.

    The FST has one state per nonterminal of ``grammar`` plus an extra initial
    state (registered under ``INITIAL``), which is the start state. The state
    of ``grammar.start()`` is final. Input labels are terminals and output
    labels are rule indices, so the output of the FST is a rule tree in
    `reverse polish notation <https://en.wikipedia.org/wiki/Reverse_Polish_notation>`_.

    Arcs are created per rule rank (``i`` is the index of the rule itself):

    * rank 0, ``A -> a``: initial state --a:"i"--> A
    * rank 2, ``A -> B C``, for every rank-0 rule ``C -> c`` with index ``j``:
      B --c:"j-i"--> A, weighted with the product of both rule weights
    * rank 1, ``S -> A`` (``S`` must be the start nonterminal):
      A --<epsilon>:"i"--> S

    :param grammar: the grammar to compile; rules must have rank 0, 1 or 2.
    :type grammar: LCFRS
    :return: the optimized FST, and the enumerator that assigned the rule
        indices (starting at 1) used in the output labels.
    :rtype: Fst, Enumerator
    """
    myfst = Fst()

    # The nonterminal table maps every nonterminal to the id of the state that
    # represents it (all symbols are added with explicit keys).
    nonterminals = SymbolTable()
    initial_state = myfst.add_state()
    # Use an explicit key instead of relying on the fresh table's first
    # auto-assigned key happening to equal the first state id.
    nonterminals.add_symbol(INITIAL, initial_state)
    myfst.set_start(initial_state)

    for nont in grammar.nonts():
        state = myfst.add_state()
        nonterminals.add_symbol(nont, state)
        if nont == grammar.start():
            myfst.set_final(state)

    # Assign every rule its index before any output label is built.
    rules = Enumerator(first_index=1)
    for rule in grammar.rules():
        rules.object_index(rule)

    # One table serves both tapes: terminals (input) and rule-index strings
    # (output); label 0 is epsilon.
    terminals = SymbolTable()
    terminals.add_symbol('<epsilon>', 0)

    for rule in grammar.rules():
        if rule.rank() == 2:
            assert len(rule.lhs().arg(0)) == 2
            # Left-branching: the right child must be lexical, so combine this
            # rule with each rank-0 rule for its second rhs nonterminal.
            for rule2 in grammar.lhs_nont_to_rules(rule.rhs_nont(1)):
                if len(rule2.rhs()) == 0:
                    arc = Arc(terminals.add_symbol(rule2.lhs().args()[0][0]),
                              terminals.add_symbol(
                                  str(rules.object_index(rule2)) + '-' +
                                  str(rules.object_index(rule))),
                              make_weight(rule.weight() * rule2.weight()),
                              nonterminals.find(rule.lhs().nont()))
                    myfst.add_arc(nonterminals.find(rule.rhs_nont(0)), arc)
        elif rule.rank() == 0:
            assert len(rule.lhs().arg(0)) == 1
            arc = Arc(terminals.add_symbol(rule.lhs().args()[0][0]),
                      terminals.add_symbol(str(rules.object_index(rule))), make_weight(rule.weight()),
                      nonterminals.find(rule.lhs().nont()))
            myfst.add_arc(initial_state, arc)
        else:
            assert rule.rank() == 1
            assert rule.lhs().nont() == grammar.start()
            assert len(rule.lhs().arg(0)) == 1
            # Unary start rules consume no input (epsilon, label 0).
            arc = Arc(0, terminals.add_symbol(str(rules.object_index(rule))), make_weight(rule.weight()),
                      nonterminals.find(grammar.start())
                      )
            myfst.add_arc(nonterminals.find(rule.rhs_nont(0)), arc)

    myfst.set_input_symbols(terminals)
    myfst.set_output_symbols(terminals)

    myfst.optimize(True)

    return myfst, rules