Exemple #1
0
def replace_and_patch(
    outer_fst: fst.Fst,
    outer_start_state: int,
    outer_final_state: int,
    inner_fst: fst.Fst,
    label_sym: int,
    eps: int = 0,
) -> None:
    """Copy ``inner_fst`` into ``outer_fst`` and wire it between two outer states.

    A fresh outer state is created for every inner state, and every inner arc
    is re-created between the corresponding new states.  Arc labels are
    translated by symbol text: each inner symbol is looked up in the matching
    outer symbol table.  Nothing is added to the outer tables here, so a
    symbol they do not contain maps to whatever ``find`` returns for an
    unknown symbol (presumably -1 -- confirm for the pywrapfst version in use).

    The copy is then attached to the outer FST:

    * ``outer_start_state`` gets an ``eps:label_sym`` arc to the copy of the
      inner start state;
    * the copy of every inner final state gets exactly one ``eps:eps`` arc to
      ``outer_final_state``.

    Every arc created here carries the outer FST's ``One`` weight; inner arc
    weights and final weights are not carried over.

    Args:
        outer_fst: FST that is modified in place.
        outer_start_state: Outer state the entry arc leaves from.
        outer_final_state: Outer state the exit arcs point to.
        inner_fst: FST to copy; it is only read.
        label_sym: Output label placed on the entry arc.
        eps: Label used as epsilon on the entry and exit arcs.
    """
    inner_zero = fst.Weight.Zero(inner_fst.weight_type())
    outer_one = fst.Weight.One(outer_fst.weight_type())

    def label_map(inner_table, outer_table):
        # Walk the table by position and ask for the key stored there, so
        # tables whose keys are not the dense range 0..n-1 are still mapped
        # completely.
        mapping = {}
        for index in range(inner_table.num_symbols()):
            key = inner_table.get_nth_key(index)
            mapping[key] = outer_table.find(inner_table.find(key).decode())
        return mapping

    out_symbol_map = label_map(
        inner_fst.output_symbols(), outer_fst.output_symbols()
    )
    in_symbol_map = label_map(
        inner_fst.input_symbols(), outer_fst.input_symbols()
    )

    # Create one outer state per inner state before adding any arc, because
    # arcs may point forward to states that have not been visited yet.
    state_map = {}
    for inner_state in inner_fst.states():
        state_map[inner_state] = outer_fst.add_state()

    inner_start = inner_fst.start()

    for inner_state in inner_fst.states():
        copied_state = state_map[inner_state]

        # Entry arc: outer start -> copy of the inner start state.
        if inner_state == inner_start:
            outer_fst.add_arc(
                outer_start_state,
                fst.Arc(eps, label_sym, outer_one, copied_state),
            )

        # Copy the inner arcs with translated labels.
        for inner_arc in inner_fst.arcs(inner_state):
            outer_fst.add_arc(
                copied_state,
                fst.Arc(
                    in_symbol_map[inner_arc.ilabel],
                    out_symbol_map[inner_arc.olabel],
                    outer_one,
                    state_map[inner_arc.nextstate],
                ),
            )

        # Exit arc: decided once per state rather than once per incoming arc.
        # This avoids duplicate exit arcs for final states with several
        # incoming arcs, and still connects final states that have no
        # incoming arc at all (e.g. a start state that is also final).
        if inner_fst.final(inner_state) != inner_zero:
            outer_fst.add_arc(
                copied_state,
                fst.Arc(eps, eps, outer_one, outer_final_state),
            )
def fst_to_graph(the_fst: fst.Fst) -> nx.MultiDiGraph:
    """Convert a finite state transducer to a directed multigraph.

    Every FST state becomes a node with two boolean attributes:

    * ``final`` -- True when the state's final weight differs from the
      ``Zero`` weight of the FST's weight type;
    * ``start`` -- True only for the FST's start state.

    Every arc becomes an edge from its source state to ``arc.nextstate`` with
    string attributes ``in_label`` and ``out_label`` taken from the FST's
    input and output symbol tables.  Arc weights are not stored on the edges.

    Args:
        the_fst: FST to convert; it must have input and output symbol tables
            attached, since labels are resolved through them.

    Returns:
        A new ``nx.MultiDiGraph`` mirroring the states and arcs of the FST.
    """
    zero_weight = fst.Weight.Zero(the_fst.weight_type())
    in_symbols = the_fst.input_symbols()
    out_symbols = the_fst.output_symbols()

    g = nx.MultiDiGraph()

    for state in the_fst.states():
        # A state counts as final when its final weight is not Zero.
        g.add_node(
            state,
            final=the_fst.final(state) != zero_weight,
            start=False,
        )

        for arc in the_fst.arcs(state):
            g.add_edge(
                state,
                arc.nextstate,
                in_label=in_symbols.find(arc.ilabel).decode(),
                out_label=out_symbols.find(arc.olabel).decode(),
            )

    # Flag the start state, but only when it is one of the states converted
    # above.  An FST without a start state reports a sentinel id (presumably
    # -1 in OpenFst -- confirm); adding it blindly would create a phantom
    # node that has no ``final`` attribute.
    start_state = the_fst.start()
    if start_state in g:
        g.add_node(start_state, start=True)

    return g
Exemple #3
0
def _replace_fsts(outer_fst: fst.Fst,
                  replacements: Dict[int, fst.Fst],
                  eps: str = "<eps>") -> fst.Fst:
    """Build a copy of ``outer_fst`` with selected arcs expanded in-line.

    A new FST with fresh input/output symbol tables is created.  All symbols
    of ``outer_fst`` are copied into the new tables, ``eps`` is added to both,
    and all outer states are copied (keeping which states are final and which
    state is the start state).

    Outer arcs are then processed one by one:

    * an arc whose output label is a key of ``replacements`` is replaced by a
      fresh copy of the corresponding FST: an ``eps:eps`` arc leads from the
      arc's source state to the copy of the replacement's start state, and
      every final state of the copy gets an ``eps:eps`` arc to the arc's
      target state.  The labels of the replaced outer arc itself are dropped;
    * any other arc is copied unchanged (with labels re-mapped to the new
      symbol tables).

    Arcs inside a replacement FST are copied verbatim; their labels are not
    looked up in ``replacements`` again, so replacements are not nested.
    Every arc of the result carries the ``One`` weight of the new FST and
    final states get ``set_final``'s default weight; weights of the inputs
    are not carried over.

    NOTE(review): OpenFst treats label 0 as epsilon.  The "epsilon" arcs made
    here use whatever key ``add_symbol(eps)`` yields in the new tables, which
    is only 0 if ``eps`` is the first symbol copied from ``outer_fst`` --
    confirm against callers.

    Args:
        outer_fst: FST to copy; it is only read.
        replacements: Maps an output label of ``outer_fst`` to the FST that
            replaces every arc carrying that label.
        eps: Text of the epsilon symbol used on the connecting arcs.

    Returns:
        The newly built FST with its own symbol tables attached.

    Raises:
        KeyError: If an arc refers to a state or label that is missing from
            the corresponding table, or an FST has no valid start state.
    """
    new_fst = fst.Fst()
    new_input_symbols = fst.SymbolTable()
    new_output_symbols = fst.SymbolTable()

    weight_one = fst.Weight.One(new_fst.weight_type())
    # A state is final when its final weight differs from Zero.
    outer_zero = fst.Weight.Zero(outer_fst.weight_type())

    # Copy symbols; the maps translate outer keys to keys in the new tables.
    input_symbol_map = _copy_symbol_table(
        outer_fst.input_symbols(), new_input_symbols)
    output_symbol_map = _copy_symbol_table(
        outer_fst.output_symbols(), new_output_symbols)

    in_eps = new_input_symbols.add_symbol(eps)
    out_eps = new_output_symbols.add_symbol(eps)

    # Copy states first so arcs can refer to any target state.
    state_map: Dict[int, int] = {}
    for outer_state in outer_fst.states():
        new_state = new_fst.add_state()
        state_map[outer_state] = new_state

        if outer_fst.final(outer_state) != outer_zero:
            new_fst.set_final(new_state)

    new_fst.set_start(state_map[outer_fst.start()])

    # Copy arcs, expanding the ones that have a replacement.
    for outer_state in outer_fst.states():
        new_state = state_map[outer_state]
        for outer_arc in outer_fst.arcs(outer_state):
            next_state = state_map[outer_arc.nextstate]
            replace_fst = replacements.get(outer_arc.olabel)

            if replace_fst is None:
                # No replacement: copy the arc as-is.
                new_fst.add_arc(
                    new_state,
                    fst.Arc(
                        input_symbol_map[outer_arc.ilabel],
                        output_symbol_map[outer_arc.olabel],
                        weight_one,
                        next_state,
                    ),
                )
                continue

            replace_zero = fst.Weight.Zero(replace_fst.weight_type())
            replace_input_symbols = replace_fst.input_symbols()
            replace_output_symbols = replace_fst.output_symbols()

            # Each replaced arc gets its own private copy of the replacement
            # states, so the mapping only needs to live for this one splice.
            replace_state_map: Dict[int, int] = {}
            for replace_state in replace_fst.states():
                copied_state = new_fst.add_state()
                replace_state_map[replace_state] = copied_state

                # Leave the copy through its final states.
                if replace_fst.final(replace_state) != replace_zero:
                    new_fst.add_arc(
                        copied_state,
                        fst.Arc(in_eps, out_eps, weight_one, next_state),
                    )

            for replace_state in replace_fst.states():
                for replace_arc in replace_fst.arcs(replace_state):
                    # add_symbol both registers the replacement's symbol text
                    # in the new table and yields the key to use as label.
                    new_fst.add_arc(
                        replace_state_map[replace_state],
                        fst.Arc(
                            new_input_symbols.add_symbol(
                                replace_input_symbols.find(
                                    replace_arc.ilabel)),
                            new_output_symbols.add_symbol(
                                replace_output_symbols.find(
                                    replace_arc.olabel)),
                            weight_one,
                            replace_state_map[replace_arc.nextstate],
                        ),
                    )

            # Enter the copy from the source state of the replaced arc.
            new_fst.add_arc(
                new_state,
                fst.Arc(in_eps, out_eps, weight_one,
                        replace_state_map[replace_fst.start()]),
            )

    # Attach the freshly built symbol tables.
    new_fst.set_input_symbols(new_input_symbols)
    new_fst.set_output_symbols(new_output_symbols)

    return new_fst


def _copy_symbol_table(source, target: fst.SymbolTable) -> Dict[int, int]:
    """Add every symbol of ``source`` to ``target``; map old keys to new keys."""
    key_map: Dict[int, int] = {}
    for index in range(source.num_symbols()):
        key = source.get_nth_key(index)
        key_map[key] = target.add_symbol(source.find(key))
    return key_map