Esempio n. 1
0
def _make_slot_fst(state: int, intent_fst: fst.Fst,
                   slot_to_fst: Dict[str, fst.Fst]):
    """Collect one minimized FST per slot reachable from ``state``.

    Walks every state reachable from ``state`` in ``intent_fst`` depth-first,
    in arc order. For each arc whose output label is ``__begin__<slot>``, the
    sub-FST between that arc and the matching ``__end__<slot>`` arc is copied
    into a new FST, minimized, and stored in ``slot_to_fst[<slot>]``.

    Big assumption here that each instance of a slot (e.g., location) will
    produce the same FST, and therefore doesn't need to be processed again:
    only the first instance found is copied. The walk still continues past
    later instances, so slots that only occur after a repeated slot are not
    missed.

    Each state is expanded at most once, so converging paths are not
    re-walked and the walk terminates even if the FST contains cycles.

    Args:
        state: state of ``intent_fst`` to start walking from.
        intent_fst: FST whose output labels carry ``__begin__``/``__end__`` tags.
        slot_to_fst: filled in place (slot name -> minimized FST); slot names
            already present are not recopied.
    """
    out_symbols = intent_fst.output_symbols()
    begin_prefix = "__begin__"

    visited = {state}
    # Explicit stack of arc iterators: same depth-first, arc-order visit as a
    # recursive walk, but without hitting Python's recursion limit.
    stack = [iter(list(intent_fst.arcs(state)))]
    while stack:
        arc = next(stack[-1], None)
        if arc is None:
            stack.pop()
            continue

        label = out_symbols.find(arc.olabel).decode()
        if label.startswith(begin_prefix):
            slot_name = label[len(begin_prefix):]
            if slot_name not in slot_to_fst:
                slot_to_fst[slot_name] = _copy_slot_fst(
                    intent_fst, arc.nextstate, f"__end__{slot_name}")

        if arc.nextstate not in visited:
            visited.add(arc.nextstate)
            stack.append(iter(list(intent_fst.arcs(arc.nextstate))))


def _copy_slot_fst(intent_fst: fst.Fst, begin_state: int,
                   end_label: str) -> fst.Fst:
    """Copy the region of ``intent_fst`` from ``begin_state`` up to ``end_label``.

    States/arcs are copied (weights reset to One) until an arc whose output
    label is ``end_label`` is found; the source state of such an arc becomes
    final in the copy. The copy is passed through ``minimize_fst``.
    """
    out_symbols = intent_fst.output_symbols()
    one_weight = fst.Weight.One(intent_fst.weight_type())

    slot_fst = fst.Fst()
    slot_fst.set_input_symbols(intent_fst.input_symbols())
    slot_fst.set_output_symbols(intent_fst.output_symbols())

    start_state = slot_fst.add_state()
    slot_fst.set_start(start_state)
    state_map = {begin_state: start_state}
    q = [begin_state]

    # Copy states/arcs from intent FST until __end__ is found
    while q:
        q_state = q.pop()
        for q_arc in intent_fst.arcs(q_state):
            slot_arc_label = out_symbols.find(q_arc.olabel).decode()
            if slot_arc_label == end_label:
                # Mark previous state as final
                slot_fst.set_final(state_map[q_state])
                continue

            if q_arc.nextstate not in state_map:
                state_map[q_arc.nextstate] = slot_fst.add_state()
                # Queue each state only once so its arcs are copied once
                q.append(q_arc.nextstate)

            slot_fst.add_arc(
                state_map[q_state],
                fst.Arc(
                    q_arc.ilabel,
                    q_arc.olabel,
                    one_weight,
                    state_map[q_arc.nextstate],
                ),
            )

    return minimize_fst(slot_fst)
Esempio n. 2
0
def replace_and_patch(
    outer_fst: fst.Fst,
    outer_start_state: int,
    outer_final_state: int,
    inner_fst: fst.Fst,
    label_sym: int,
    eps: int = 0,
) -> None:
    """Copies an inner FST into an outer FST, creating states and mapping symbols.

    Creates arcs from outer start/final states to inner start/final states.

    Every state of ``inner_fst`` gets a new state in ``outer_fst``. Every inner
    arc is copied with weight One, and its labels are re-keyed by symbol
    *string* into the outer symbol tables. Then:

    * an ``eps:label_sym`` arc leads from ``outer_start_state`` to the copy of
      the inner start state, and
    * exactly one ``eps:eps`` arc leads from the copy of each final inner
      state (including a final start state) to ``outer_final_state``.

    Args:
        outer_fst: FST modified in place.
        outer_start_state: outer state the copy is entered from.
        outer_final_state: outer state the copy exits to.
        inner_fst: FST to copy; its symbols are looked up by string in the
            outer tables (assumed present -- lookups are not validated).
        label_sym: output label placed on the entry arc.
        eps: label id used for epsilon on the entry/exit arcs.
    """

    def _map_symbols(inner_table, outer_table):
        # Map every inner key to the outer key of the same symbol string.
        # get_nth_key is used because keys need not be dense 0..n-1.
        mapping = {}
        for i in range(inner_table.num_symbols()):
            key = inner_table.get_nth_key(i)
            mapping[key] = outer_table.find(inner_table.find(key).decode())
        return mapping

    out_symbol_map = _map_symbols(inner_fst.output_symbols(),
                                  outer_fst.output_symbols())
    in_symbol_map = _map_symbols(inner_fst.input_symbols(),
                                 outer_fst.input_symbols())

    inner_zero = fst.Weight.Zero(inner_fst.weight_type())
    outer_one = fst.Weight.One(outer_fst.weight_type())

    # Create states in outer FST
    state_map = {
        inner_state: outer_fst.add_state()
        for inner_state in inner_fst.states()
    }

    # Create arcs in outer FST
    inner_start = inner_fst.start()
    for inner_state in inner_fst.states():
        outer_state = state_map[inner_state]

        if inner_state == inner_start:
            outer_fst.add_arc(
                outer_start_state,
                fst.Arc(eps, label_sym, outer_one, outer_state),
            )

        for inner_arc in inner_fst.arcs(inner_state):
            outer_fst.add_arc(
                outer_state,
                fst.Arc(
                    in_symbol_map[inner_arc.ilabel],
                    out_symbol_map[inner_arc.olabel],
                    outer_one,
                    state_map[inner_arc.nextstate],
                ),
            )

        # One exit arc per final state (not one per incoming arc)
        if inner_fst.final(inner_state) != inner_zero:
            outer_fst.add_arc(
                outer_state,
                fst.Arc(eps, eps, outer_one, outer_final_state),
            )
Esempio n. 3
0
def fst_to_graph(the_fst: fst.Fst) -> nx.MultiDiGraph:
    """Build a networkx multigraph mirroring the states and arcs of an FST.

    Each FST state becomes a node with boolean ``final`` and ``start``
    attributes. Each arc becomes an edge carrying its decoded input and output
    symbols as ``in_label`` / ``out_label``.
    """
    no_weight = fst.Weight.Zero(the_fst.weight_type())
    isyms, osyms = the_fst.input_symbols(), the_fst.output_symbols()
    graph = nx.MultiDiGraph()

    for src in the_fst.states():
        graph.add_node(src, final=(the_fst.final(src) != no_weight), start=False)
        for arc in the_fst.arcs(src):
            graph.add_edge(
                src,
                arc.nextstate,
                in_label=isyms.find(arc.ilabel).decode(),
                out_label=osyms.find(arc.olabel).decode(),
            )

    # Flag the start node last so it overrides the start=False set above.
    graph.add_node(the_fst.start(), start=True)
    return graph
Esempio n. 4
0
def fstprintall(
    in_fst: fst.Fst,
    out_file: Optional[TextIO] = None,
    exclude_meta: bool = True,
    eps: str = "<eps>",
) -> List[List[str]]:
    """Enumerate the output-symbol sentence of every path to a final state.

    States are explored breadth-first from the start state, carrying the
    output symbols collected so far. Epsilon outputs are skipped. When
    ``exclude_meta`` is true, symbols starting with ``__`` are also skipped.
    Each time a final state is dequeued, its sentence is either printed to
    ``out_file`` (space-separated, one per line) or collected.

    Returns:
        The collected sentences (empty when ``out_file`` is given).

    There is no visited-state tracking, so a cyclic FST never terminates.
    """
    found = []
    osyms = in_fst.output_symbols()
    eps_key = osyms.find(eps)
    no_weight = fst.Weight.Zero(in_fst.weight_type())

    pending: Deque[Tuple[int, List[str]]] = deque([(in_fst.start(), [])])
    while pending:
        current, words = pending.popleft()

        if in_fst.final(current) != no_weight:
            if not out_file:
                found.append(words)
            else:
                print(" ".join(words), file=out_file)

        for arc in in_fst.arcs(current):
            extended = words[:]
            if arc.olabel != eps_key:
                symbol = osyms.find(arc.olabel).decode()
                # Meta symbols (__label__, __begin__, ...) are optionally hidden
                if not (exclude_meta and symbol.startswith("__")):
                    extended.append(symbol)
            pending.append((arc.nextstate, extended))

    return found
Esempio n. 5
0
def fstprintall(
    in_fst: fst.Fst,
    out_file: Optional[TextIO] = None,
    exclude_meta: bool = True,
    state: Optional[int] = None,
    path: Optional[List[fst.Arc]] = None,
    zero_weight: Optional[fst.Weight] = None,
    eps: int = 0,
) -> List[List[str]]:
    """Enumerate output sentences of every path from ``state`` to a final state.

    Depth-first, recursive walk. Whenever the current state is final, a
    sentence is emitted from the output labels of the arcs in ``path``. This
    includes the starting state itself and final states that still have
    outgoing arcs, whose longer paths are explored too. Labels equal to
    ``eps`` are skipped. When ``exclude_meta`` is true, symbols starting with
    ``__`` are also skipped.

    Args:
        in_fst: FST to enumerate.
        out_file: when given, each sentence is printed as one space-separated
            line instead of being collected.
        exclude_meta: skip output symbols starting with "__".
        state: state to walk from; None means the start state.
        path: arcs taken to reach ``state`` (used as a stack while recursing).
        zero_weight: cached Zero weight of ``in_fst``; computed when None.
        eps: output label id treated as epsilon.

    Returns:
        The collected sentences (empty when ``out_file`` is given).

    There is no cycle detection: on a cyclic FST the recursion does not
    terminate normally.
    """
    sentences: List[List[str]] = []
    # Explicit None checks: state 0 and other falsy values are legitimate.
    if path is None:
        path = []
    if state is None:
        state = in_fst.start()
    if zero_weight is None:
        zero_weight = fst.Weight.Zero(in_fst.weight_type())

    if in_fst.final(state) != zero_weight:
        out_syms = in_fst.output_symbols()
        sentence = []
        for p_arc in path:
            if p_arc.olabel == eps:
                continue
            osym = out_syms.find(p_arc.olabel).decode()
            if exclude_meta and osym.startswith("__"):
                continue  # skip __label__, etc.
            sentence.append(osym)

        if out_file:
            print(" ".join(sentence), file=out_file)
        else:
            sentences.append(sentence)

    # Keep exploring even from final states: longer paths may exist.
    for arc in in_fst.arcs(state):
        path.append(arc)
        sentences.extend(
            fstprintall(
                in_fst,
                out_file=out_file,
                exclude_meta=exclude_meta,
                state=arc.nextstate,
                path=path,
                zero_weight=zero_weight,
                eps=eps,
            ))
        path.pop()

    return sentences
Esempio n. 6
0
def _longest_olabel_path(the_fst: fst.Fst) -> List[int]:
    """Return the output labels along the longest arc path from the start state.

    Longest-path dynamic programming over a depth-first post-order traversal.
    Arcs pointing back to a state still on the DFS stack (i.e. closing a
    cycle) are ignored, so the search terminates on cyclic FSTs. On an
    acyclic FST the result is an exact longest path; ties go to the earlier
    arc.
    """
    start = the_fst.start()
    if start < 0:
        return []  # no start state: nothing to walk

    best_len: Dict[int, int] = {}
    best_arc: Dict[int, Tuple[int, int]] = {}  # state -> (olabel, nextstate)
    out_arcs: Dict[int, List[Tuple[int, int]]] = {
        start: [(a.olabel, a.nextstate) for a in the_fst.arcs(start)]
    }
    discovered: Set[int] = {start}
    stack = [(start, iter(out_arcs[start]))]

    while stack:
        state, pending = stack[-1]
        for _olabel, nxt in pending:
            if nxt not in discovered:
                discovered.add(nxt)
                out_arcs[nxt] = [(a.olabel, a.nextstate)
                                 for a in the_fst.arcs(nxt)]
                stack.append((nxt, iter(out_arcs[nxt])))
                break
        else:
            # All successors finished (or are ancestors): settle this state.
            stack.pop()
            length, choice = 0, None
            for olabel, nxt in out_arcs[state]:
                # Successors not yet settled are ancestors (back edges).
                if nxt in best_len and best_len[nxt] + 1 > length:
                    length, choice = best_len[nxt] + 1, (olabel, nxt)
            best_len[state] = length
            if choice is not None:
                best_arc[state] = choice

    labels: List[int] = []
    state = start
    while state in best_arc:
        olabel, state = best_arc[state]
        labels.append(olabel)
    return labels


def longest_path(the_fst: fst.Fst, eps: str = "<eps>") -> fst.Fst:
    """Return a single-path FST spelling out the longest path of ``the_fst``.

    The longest path from the start state is measured in arcs, epsilon arcs
    included. It need not end in a final state of ``the_fst``. Arcs that
    would close a cycle are not followed.

    The result is a linear FST whose last state is final. Its output labels
    (and output symbol table) are those of ``the_fst``. Its input side uses a
    fresh symbol table that starts with ``eps``, with each output symbol
    string added to it.
    """
    output_symbols = the_fst.output_symbols()

    # Determine longest path
    best_path = _longest_olabel_path(the_fst)

    # Create FST with longest path
    path_fst = fst.Fst()

    input_symbols = fst.SymbolTable()
    input_symbols.add_symbol(eps)
    path_fst.set_output_symbols(output_symbols)
    weight_one = fst.Weight.One(path_fst.weight_type())

    state = path_fst.add_state()
    path_fst.set_start(state)

    for olabel in best_path:
        osym = output_symbols.find(olabel).decode()
        next_state = path_fst.add_state()
        path_fst.add_arc(
            state,
            fst.Arc(input_symbols.add_symbol(osym), olabel, weight_one,
                    next_state),
        )
        state = next_state

    path_fst.set_final(state)
    path_fst.set_input_symbols(input_symbols)

    return path_fst
Esempio n. 7
0
def make_slot_fsts(intent_fst: fst.Fst) -> Dict[str, Dict[str, fst.Fst]]:
    """Map each intent name to its slot-name -> slot-FST dictionary.

    Every arc leaving the start state of ``intent_fst`` must carry a
    ``__label__<intent>`` output symbol (asserted). The slot FSTs reachable
    after that arc are gathered by ``_make_slot_fst``.
    """
    prefix = "__label__"
    symbols = intent_fst.output_symbols()
    result: Dict[str, Dict[str, fst.Fst]] = {}

    for arc in intent_fst.arcs(intent_fst.start()):
        label = symbols.find(arc.olabel).decode()
        assert label.startswith(prefix), label

        slots: Dict[str, fst.Fst] = {}
        result[label[len(prefix):]] = slots
        _make_slot_fst(arc.nextstate, intent_fst, slots)

    return result
Esempio n. 8
0
def make_slot_acceptor(intent_fst: fst.Fst, eps: str = "<eps>") -> fst.Fst:
    """Build an FST that copies tagged slot regions of ``intent_fst``.

    One symbol table holds every input and output symbol of ``intent_fst``;
    symbols starting with ``__`` are meta symbols. The intent FST is walked
    breadth-first:

    * ``__label__*`` arcs are copied as label:label into a new state, which
      gets self-loops over all non-eps, non-meta symbols (so intents without
      slots still accept words).
    * A ``__begin__*`` arc starts a copy region and gives the current state
      the same self-loops (once per state).
    * Inside a copy region, every arc that is not eps:eps is copied as
      input-symbol:output-symbol into a new state.
    * An ``__end__*`` arc closes the region after being copied.

    Every reached state with no copied arc leaving it becomes final, except
    the start state. There is no visited set, so states reachable by several
    intent paths are processed once per path (and a cyclic intent FST would
    not terminate).

    Returns:
        The new FST, using the merged symbol table for input and output.
    """
    in_symbols = intent_fst.input_symbols()
    out_symbols = intent_fst.output_symbols()
    in_eps = in_symbols.find(eps)
    out_eps = out_symbols.find(eps)
    slot_fst = fst.Fst()

    # Copy symbol tables
    all_symbols = fst.SymbolTable()
    meta_keys = set()

    for table in [in_symbols, out_symbols]:
        for i in range(table.num_symbols()):
            key = table.get_nth_key(i)
            sym = table.find(key).decode()
            all_key = all_symbols.add_symbol(sym)
            if sym.startswith("__"):
                meta_keys.add(all_key)

    weight_one = fst.Weight.One(slot_fst.weight_type())

    # States that will be set to final
    final_states: Set[int] = set()

    # States that already have all-word loops
    loop_states: Set[int] = set()

    all_eps = all_symbols.find(eps)

    # Add self transitions to a state for all input words (besides <eps>)
    def add_loop_state(state):
        for sym_idx in range(all_symbols.num_symbols()):
            all_key = all_symbols.get_nth_key(sym_idx)
            if (all_key != all_eps) and (all_key not in meta_keys):
                slot_fst.add_arc(state,
                                 fst.Arc(all_key, all_key, weight_one, state))

    slot_fst.set_start(slot_fst.add_state())

    # Queue of (intent state, acceptor state, copy count)
    state_queue: Deque[Tuple[int, int, int]] = deque()
    state_queue.append((intent_fst.start(), slot_fst.start(), 0))

    # BFS
    while state_queue:
        intent_state, slot_state, do_copy = state_queue.popleft()
        final_states.add(slot_state)
        for intent_arc in intent_fst.arcs(intent_state):
            # Copy depth is per arc: __begin__/__end__ on one arc must not
            # affect its sibling arcs leaving the same state.
            copy_depth = do_copy
            out_symbol = out_symbols.find(intent_arc.olabel).decode()
            all_key = all_symbols.find(out_symbol)

            if out_symbol.startswith("__label__"):
                # Create corresponding __label__ arc
                next_state = slot_fst.add_state()
                slot_fst.add_arc(
                    slot_state,
                    fst.Arc(all_key, all_key, weight_one, next_state))

                # Must create a loop here for intents with no slots
                add_loop_state(next_state)
                # Record the state that received the loops, so a __begin__
                # right after the label does not add them a second time.
                loop_states.add(next_state)
            else:
                # Non-label arc
                if out_symbol.startswith("__begin__"):
                    # States/arcs will be copied until __end__ is reached
                    copy_depth += 1

                    # Add loop transitions to soak up non-tag words
                    if slot_state not in loop_states:
                        add_loop_state(slot_state)
                        loop_states.add(slot_state)

                if (copy_depth > 0) and ((intent_arc.ilabel != in_eps) or
                                         (intent_arc.olabel != out_eps)):
                    # Copy state/arc
                    in_symbol = in_symbols.find(intent_arc.ilabel).decode()
                    next_state = slot_fst.add_state()
                    slot_fst.add_arc(
                        slot_state,
                        fst.Arc(all_symbols.find(in_symbol), all_key,
                                weight_one, next_state),
                    )
                    final_states.discard(slot_state)
                else:
                    next_state = slot_state

                if out_symbol.startswith("__end__"):
                    # Stop copying after this state until next __begin__
                    copy_depth -= 1

            state_queue.append((intent_arc.nextstate, next_state, copy_depth))

    # Mark all dangling states as final (excluding start)
    for state in final_states:
        if state != slot_fst.start():
            slot_fst.set_final(state)

    # Fix symbol tables
    slot_fst.set_input_symbols(all_symbols)
    slot_fst.set_output_symbols(all_symbols)

    return slot_fst
Esempio n. 9
0
def _replace_fsts(outer_fst: fst.Fst,
                  replacements: Dict[int, fst.Fst],
                  eps="<eps>") -> fst.Fst:
    """Return a copy of ``outer_fst`` with some arcs expanded into sub-FSTs.

    An arc whose output label is a key of ``replacements`` is replaced by an
    epsilon arc into a fresh inline copy of the mapped FST. Every final state
    of that copy gets an epsilon arc to the replaced arc's destination. All
    other arcs are copied as-is.

    Symbols are re-keyed into new tables: the outer symbols first, then
    ``eps``, then any symbols first seen in a replacement. All weights in the
    result (arcs and finals) are One.
    """
    result = fst.Fst()
    one = fst.Weight.One(result.weight_type())
    outer_zero = fst.Weight.Zero(outer_fst.weight_type())

    def clone_symbols(source):
        # Fresh table holding every symbol of `source`, plus old->new key map.
        target = fst.SymbolTable()
        remap = {}
        for nth in range(source.num_symbols()):
            old_key = source.get_nth_key(nth)
            remap[old_key] = target.add_symbol(source.find(old_key))
        return target, remap

    isyms, ikey_of = clone_symbols(outer_fst.input_symbols())
    osyms, okey_of = clone_symbols(outer_fst.output_symbols())
    ieps = isyms.add_symbol(eps)
    oeps = osyms.add_symbol(eps)

    # One new state per outer state, keeping finality
    mapped = {}
    for outer_state in outer_fst.states():
        mapped[outer_state] = result.add_state()
        if outer_fst.final(outer_state) != outer_zero:
            result.set_final(mapped[outer_state])
    result.set_start(mapped[outer_fst.start()])

    def splice(entry, exit_state, sub_fst):
        # Inline a fresh copy of `sub_fst` between `entry` and `exit_state`.
        sub_zero = fst.Weight.Zero(sub_fst.weight_type())
        sub_isyms = sub_fst.input_symbols()
        sub_osyms = sub_fst.output_symbols()

        copy_of = {}
        for sub_state in sub_fst.states():
            copy_of[sub_state] = result.add_state()
            if sub_fst.final(sub_state) != sub_zero:
                # Leaving the sub-FST resumes at the replaced arc's target
                result.add_arc(copy_of[sub_state],
                               fst.Arc(ieps, oeps, one, exit_state))

        for sub_state in sub_fst.states():
            for sub_arc in sub_fst.arcs(sub_state):
                ilabel = isyms.add_symbol(sub_isyms.find(sub_arc.ilabel))
                olabel = osyms.add_symbol(sub_osyms.find(sub_arc.olabel))
                result.add_arc(
                    copy_of[sub_state],
                    fst.Arc(ilabel, olabel, one, copy_of[sub_arc.nextstate]),
                )

        result.add_arc(entry,
                       fst.Arc(ieps, oeps, one, copy_of[sub_fst.start()]))

    for outer_state in outer_fst.states():
        source = mapped[outer_state]
        for outer_arc in outer_fst.arcs(outer_state):
            target = mapped[outer_arc.nextstate]
            sub_fst = replacements.get(outer_arc.olabel)
            if sub_fst is None:
                result.add_arc(
                    source,
                    fst.Arc(ikey_of[outer_arc.ilabel],
                            okey_of[outer_arc.olabel], one, target),
                )
            else:
                splice(source, target, sub_fst)

    result.set_input_symbols(isyms)
    result.set_output_symbols(osyms)
    return result