Example #1
def fsa_from_list_of_symbols(input, symbol_table):
    """
    Build a linear finite-state acceptor for a sequence of tokens.

    :param input: iterable of tokens to accept, in order
    :type input: iterable
    :param symbol_table: table mapping tokens to integer labels;
        extended in place when a token is not yet present
    :type symbol_table: SymbolTable
    :return: An acceptor for the given list of tokens.
    :rtype: Fst
    The symbol table gets extended, if new tokens occur in the input.
    """
    fsa = Fst()
    fsa.set_input_symbols(symbol_table)
    fsa.set_output_symbols(symbol_table)
    state = fsa.add_state()
    fsa.set_start(state)
    for x in input:
        next_state = fsa.add_state()
        # Resolve the token's label exactly once: look it up, and only
        # extend the table on a miss. (Previously find/add_symbol was
        # invoked twice per arc, doing redundant table work.)
        try:
            label = symbol_table.find(x)
        except KeyError:
            label = symbol_table.add_symbol(x)
        # Acceptor arc: identical input/output label, zero weight.
        fsa.add_arc(state, Arc(label, label, 0, next_state))
        state = next_state
    fsa.set_final(state)
    return fsa
Example #2
def compile_wfst_from_left_branching_grammar(grammar):
    """
        :type grammar: LCFRS
        :rtype: Fst, Enumerator
        Create a FST from a left-branching hybrid grammar.
        The Output of the is a rule tree in `reverse polish notation <https://en.wikipedia.org/wiki/Reverse_Polish_notation>`_
        """
    myfst = Fst()

    # One FST state per nonterminal; the INITIAL symbol marks the start state.
    # NOTE(review): symbol ids are assigned in encounter order by add_symbol,
    # so the iteration order below is load-bearing.
    nonterminals = SymbolTable()
    sid = myfst.add_state()
    nonterminals.add_symbol(INITIAL)
    myfst.set_start(sid)

    for nont in grammar.nonts():
        sid = myfst.add_state()
        # Register the nonterminal under its state id so find(nont) == state.
        nonterminals.add_symbol(nont, sid)
        if nont == grammar.start():
            myfst.set_final(sid)

    # Enumerate rules starting at 1; index 0 is reserved (epsilon below).
    rules = Enumerator(first_index=1)
    for rule in grammar.rules():
        rules.object_index(rule)

    terminals = SymbolTable()
    terminals.add_symbol('<epsilon>', 0)

    for rule in grammar.rules():
        if rule.rank() == 2:
            # Binary rule A -> B C: fold in each terminal-producing rule2 for C.
            # The arc reads C's terminal and emits "rule2-rule" (postfix order),
            # going from B's state to A's state.
            assert len(rule.lhs().arg(0)) == 2
            for rule2 in grammar.lhs_nont_to_rules(rule.rhs_nont(1)):
                if len(rule2.rhs()) == 0:
                    arc = Arc(terminals.add_symbol(rule2.lhs().args()[0][0]),
                              terminals.add_symbol(
                                  str(rules.object_index(rule2)) + '-' +
                                  str(rules.object_index(rule))),
                              make_weight(rule.weight() * rule2.weight()),
                              nonterminals.find(rule.lhs().nont()))
                    myfst.add_arc(nonterminals.find(rule.rhs_nont(0)), arc)
        elif rule.rank() == 0:
            # Lexical rule A -> terminal: arc from the INITIAL state to A,
            # reading the terminal and emitting the rule's index.
            assert len(rule.lhs().arg(0)) == 1
            arc = Arc(terminals.add_symbol(rule.lhs().args()[0][0]),
                      terminals.add_symbol(str(rules.object_index(rule))), make_weight(rule.weight()),
                      nonterminals.find(rule.lhs().nont()))
            myfst.add_arc(nonterminals.find(INITIAL), arc)
        else:
            # Unary rule S -> B (only allowed at the start symbol):
            # epsilon-input arc from B's state to the final (start-symbol) state.
            assert rule.rank() == 1
            assert rule.lhs().nont() == grammar.start()
            assert len(rule.lhs().arg(0)) == 1
            arc = Arc(0, terminals.add_symbol(str(rules.object_index(rule))), make_weight(rule.weight()),
                      nonterminals.find(grammar.start())
                      )
            myfst.add_arc(nonterminals.find(rule.rhs_nont(0)), arc)

    myfst.set_input_symbols(terminals)
    myfst.set_output_symbols(terminals)

    myfst.optimize(True)

    return myfst, rules
Example #3
    def _flip_lemmatizer_feature_labels(self,
                                        lemmatizer: pynini.Fst) -> pynini.Fst:
        """Swaps the lemmatizer's feature labels from the input to output side.

    Destructive operation.

    Args:
      lemmatizer: FST representing a partially constructed lemmatizer.

    Returns:
      Modified lemmatizer.
    """
        # Gather every nonzero input label appearing on the category's
        # feature-label acceptor.
        feature_labels = set()
        for state in self.category.feature_labels.states():
            arc_iter = self.category.feature_labels.arcs(state)
            while not arc_iter.done():
                current = arc_iter.value()
                if current.ilabel:
                    feature_labels.add(current.ilabel)
                arc_iter.next()
        lemmatizer.set_input_symbols(lemmatizer.output_symbols())
        # Rewrite in place: wherever a feature label sits on the output
        # side, swap it onto the input side.
        for state in lemmatizer.states():
            mutable_iter = lemmatizer.mutable_arcs(state)
            while not mutable_iter.done():
                current = mutable_iter.value()
                if current.olabel in feature_labels:
                    # This assertion should always be true by construction.
                    assert current.ilabel == 0, (
                        f"ilabel = "
                        f"{lemmatizer.input_symbols().find(current.ilabel)},"
                        f" olabel = "
                        f"{lemmatizer.output_symbols().find(current.olabel)}")
                    swapped = pynini.Arc(current.olabel, current.ilabel,
                                         current.weight, current.nextstate)
                    mutable_iter.set_value(swapped)
                mutable_iter.next()
        return lemmatizer