예제 #1
0
def fstprintall(
    in_fst: fst.Fst,
    out_file: Optional[TextIO] = None,
    exclude_meta: bool = True,
    eps: str = "<eps>",
) -> List[List[str]]:
    """Collect the output-symbol sentences reachable in ``in_fst``.

    States are explored breadth-first from the start state. Whenever a state
    whose final weight differs from the semiring zero is reached, the output
    symbols accumulated on the way there form one sentence.

    NOTE: every state is expanded at most once (see ``visited_states``).
    That keeps the walk finite on cyclic FSTs, but it also means only the
    first path that reaches a given state contributes a sentence.

    :param in_fst: FST to walk; its output symbol table converts arc labels
        into strings
    :param out_file: when given, each sentence is printed to it (space-joined,
        one per line) instead of being collected
    :param exclude_meta: drop output symbols that start with ``__``
    :param eps: epsilon symbol; arcs carrying it as output label add no word
    :returns: the sentences as lists of symbols; empty when ``out_file`` is
        given or when the FST has no start state
    """
    sentences: List[List[str]] = []

    # An FST without a start state reports -1 and has nothing to enumerate;
    # querying final(-1) would fail instead of yielding an empty result.
    start_state = in_fst.start()
    if start_state == -1:
        return sentences

    output_symbols = in_fst.output_symbols()
    out_eps = output_symbols.find(eps)
    zero_weight = fst.Weight.Zero(in_fst.weight_type())
    visited_states = set()

    state_queue = deque([(start_state, [])])

    while state_queue:
        state, sentence = state_queue.popleft()
        if state in visited_states:
            continue

        visited_states.add(state)

        # A final weight other than zero marks an accepting state.
        if in_fst.final(state) != zero_weight:
            if out_file:
                print(" ".join(sentence), file=out_file)
            else:
                sentences.append(sentence)

        for arc in in_fst.arcs(state):
            # Each outgoing arc extends its own copy of the sentence so far.
            arc_sentence = list(sentence)
            if arc.olabel != out_eps:
                out_symbol = output_symbols.find(arc.olabel)
                # Depending on the pywrapfst version the lookup yields bytes
                # or str; only bytes need decoding.
                if isinstance(out_symbol, bytes):
                    out_symbol = out_symbol.decode()

                # Meta symbols such as __label__ are skipped on request.
                if not (exclude_meta and out_symbol.startswith("__")):
                    arc_sentence.append(out_symbol)

            state_queue.append((arc.nextstate, arc_sentence))

    return sentences
예제 #2
0
def _all_valid_strings(td: fst.Fst) -> List[Tuple[List[int], float]]:
    """Enumerate every emission sequence accepted by the transducer.

    Walks all paths depth-first from the start state and records the output
    labels collected along each path that ends in a state with a finite final
    weight. Labels are returned as raw integers; no symbol-table lookup is
    performed.

    The weight paired with each sequence is only the final weight of the
    state the path ends in, not the accumulated weight of the path.

    No visited-state tracking is done, so duplicate emissions are reported
    as-is and a transducer containing a reachable cycle never terminates.

    :param td: transducer, the emission language of which to enumerate
    :returns: a list of (output labels, final weight) tuples; empty when the
        transducer has no start state
    """
    start_state = td.start()
    if start_state == -1:
        return []

    emissions = []
    pending = [(start_state, [])]
    while pending:
        current, labels = pending.pop()

        # States whose final weight is not finite are not accepting.
        weight = float(td.final(current))
        if np.isfinite(weight):
            emissions.append((labels, weight))

        # Every arc continues the path with its own copy of the labels.
        for arc in td.arcs(current):
            pending.append((arc.nextstate, [*labels, arc.olabel]))

    return emissions
예제 #3
0
def _replace_fsts(outer_fst: fst.Fst,
                  replacements: Dict[int, fst.Fst],
                  eps: str = "<eps>") -> fst.Fst:
    """Build a copy of ``outer_fst`` with selected arcs expanded in-line.

    Every arc of ``outer_fst`` whose output label is a key of
    ``replacements`` is substituted by a fresh copy of the mapped FST: an
    epsilon arc leads from the arc's source state into the copy's start
    state, and every final state of the copy gets an epsilon arc to the
    arc's original destination. The replaced arc's own labels are not
    carried over. All other arcs are copied with their labels re-mapped
    into the new symbol tables.

    Weights are not preserved: every arc in the result carries the semiring
    one, and final states are marked with ``set_final``'s default weight.

    :param outer_fst: FST whose arcs are copied or expanded; must carry input
        and output symbol tables
    :param replacements: maps an output label of ``outer_fst`` to the FST
        that replaces arcs carrying that label; each replacement must carry
        input and output symbol tables
    :param eps: epsilon symbol used on the connecting arcs
    :returns: a new FST with freshly built input and output symbol tables;
        it has no start state when ``outer_fst`` has none
    """
    input_symbol_map: Dict[int, int] = {}
    output_symbol_map: Dict[int, int] = {}
    state_map: Dict[int, int] = {}

    # Create new FST
    new_fst = fst.Fst()
    new_input_symbols = fst.SymbolTable()
    new_output_symbols = fst.SymbolTable()

    weight_one = fst.Weight.One(new_fst.weight_type())
    outer_zero = fst.Weight.Zero(outer_fst.weight_type())

    # Copy symbols, remembering how each outer key maps to its new key
    outer_input_symbols = outer_fst.input_symbols()
    for i in range(outer_input_symbols.num_symbols()):
        key = outer_input_symbols.get_nth_key(i)
        input_symbol_map[key] = new_input_symbols.add_symbol(
            outer_input_symbols.find(key))

    outer_output_symbols = outer_fst.output_symbols()
    for i in range(outer_output_symbols.num_symbols()):
        key = outer_output_symbols.get_nth_key(i)
        output_symbol_map[key] = new_output_symbols.add_symbol(
            outer_output_symbols.find(key))

    in_eps = new_input_symbols.add_symbol(eps)
    out_eps = new_output_symbols.add_symbol(eps)

    # Copy states; a final weight other than zero marks an accepting state
    for outer_state in outer_fst.states():
        new_state = new_fst.add_state()
        state_map[outer_state] = new_state

        if outer_fst.final(outer_state) != outer_zero:
            new_fst.set_final(new_state)

    # Set start state. An FST without one reports -1, which has no entry in
    # state_map; the result is then left without a start state as well.
    outer_start = outer_fst.start()
    if outer_start in state_map:
        new_fst.set_start(state_map[outer_start])

    # Copy arcs
    for outer_state in outer_fst.states():
        new_state = state_map[outer_state]
        for outer_arc in outer_fst.arcs(outer_state):
            next_state = state_map[outer_arc.nextstate]
            replace_fst = replacements.get(outer_arc.olabel)

            if replace_fst is None:
                # Copy arc as-is
                new_fst.add_arc(
                    new_state,
                    fst.Arc(
                        input_symbol_map[outer_arc.ilabel],
                        output_symbol_map[outer_arc.olabel],
                        weight_one,
                        next_state,
                    ),
                )
                continue

            replace_start = replace_fst.start()
            if replace_start == -1:
                # A replacement without a start state accepts nothing, so
                # no path can pass through this arc: drop it.
                continue

            # Replace in-line. Each replaced arc gets its own copy of the
            # replacement's states, tracked in a map local to this arc.
            replace_zero = fst.Weight.Zero(replace_fst.weight_type())
            replace_input_symbols = replace_fst.input_symbols()
            replace_output_symbols = replace_fst.output_symbols()
            inner_states: Dict[int, int] = {}

            # Copy states
            for replace_state in replace_fst.states():
                inner_state = new_fst.add_state()
                inner_states[replace_state] = inner_state

                # Final states of the copy exit to the arc's destination
                if replace_fst.final(replace_state) != replace_zero:
                    new_fst.add_arc(
                        inner_state,
                        fst.Arc(in_eps, out_eps, weight_one, next_state),
                    )

            # Copy arcs; add_symbol supplies the key for each symbol string
            # in the new tables.
            for replace_state in replace_fst.states():
                for replace_arc in replace_fst.arcs(replace_state):
                    new_fst.add_arc(
                        inner_states[replace_state],
                        fst.Arc(
                            new_input_symbols.add_symbol(
                                replace_input_symbols.find(
                                    replace_arc.ilabel)),
                            new_output_symbols.add_symbol(
                                replace_output_symbols.find(
                                    replace_arc.olabel)),
                            weight_one,
                            inner_states[replace_arc.nextstate],
                        ),
                    )

            # Create arc into start state of the copy
            new_fst.add_arc(
                new_state,
                fst.Arc(in_eps, out_eps, weight_one,
                        inner_states[replace_start]),
            )

    # Attach the rebuilt symbol tables
    new_fst.set_input_symbols(new_input_symbols)
    new_fst.set_output_symbols(new_output_symbols)

    return new_fst