コード例 #1
0
def fstprintall(
    in_fst: fst.Fst,
    out_file: Optional[TextIO] = None,
    exclude_meta: bool = True,
    eps: str = "<eps>",
) -> List[List[str]]:
    """Enumerate the output strings accepted by ``in_fst`` (breadth-first).

    A sentence is the list of output symbols along a path from the start
    state to any state whose final weight is not Zero. Epsilon outputs are
    dropped. When ``exclude_meta`` is true, symbols beginning with ``"__"``
    (``__label__...`` and similar) are dropped as well. Arcs leaving a final
    state are still followed, so longer sentences through it are also found.

    :param in_fst: transducer whose output language is enumerated
    :param out_file: if given, each sentence is printed to it as one
        space-separated line instead of being collected
    :param exclude_meta: skip output symbols that start with ``"__"``
    :param eps: name of the epsilon symbol in the output symbol table
    :returns: the collected sentences; empty when ``out_file`` is given or
        when the FST has no start state

    NOTE: there is no cycle detection; a cyclic FST never terminates.
    """
    sentences: List[List[str]] = []

    # An FST without a start state (start() == -1) accepts nothing, and
    # final()/arcs() must not be called with -1.
    start_state = in_fst.start()
    if start_state == -1:
        return sentences

    output_symbols = in_fst.output_symbols()
    out_eps = output_symbols.find(eps)
    zero_weight = fst.Weight.Zero(in_fst.weight_type())

    state_queue: Deque[Tuple[int, List[str]]] = deque()
    state_queue.append((start_state, []))

    while state_queue:
        state, sentence = state_queue.popleft()

        # A non-Zero final weight means the path so far is a full sentence.
        if in_fst.final(state) != zero_weight:
            if out_file:
                print(" ".join(sentence), file=out_file)
            else:
                sentences.append(sentence)

        for arc in in_fst.arcs(state):
            # Each outgoing arc extends its own copy of the sentence.
            arc_sentence = list(sentence)
            if arc.olabel != out_eps:
                out_symbol = output_symbols.find(arc.olabel).decode()
                # Meta symbols (__label__, etc.) are skipped when requested.
                if not (exclude_meta and out_symbol.startswith("__")):
                    arc_sentence.append(out_symbol)

            state_queue.append((arc.nextstate, arc_sentence))

    return sentences
コード例 #2
0
def fstprintall(
    in_fst: fst.Fst,
    out_file: Optional[TextIO] = None,
    exclude_meta: bool = True,
    state: Optional[int] = None,
    path: Optional[List[fst.Arc]] = None,
    zero_weight: Optional[fst.Weight] = None,
    eps: int = 0,
) -> List[List[str]]:
    """Recursively enumerate the output strings of ``in_fst`` (depth-first).

    A sentence is emitted each time an arc enters a state whose final weight
    is not Zero. The search does not continue past such a state, so paths
    that run on beyond a final state are not enumerated. The finality of the
    state the walk starts from is never checked, so the empty string is not
    emitted either.

    :param in_fst: transducer whose output language is enumerated
    :param out_file: if given, each sentence is printed to it as one line
        (each symbol followed by a space) instead of being collected
    :param exclude_meta: skip output symbols that start with ``"__"``
    :param state: recursion state; state to expand (default: start state)
    :param path: recursion state; arcs taken so far (mutated and restored)
    :param zero_weight: recursion state; cached Zero weight of ``in_fst``
    :param eps: output label id treated as epsilon and skipped
    :returns: the collected sentences; empty when ``out_file`` is given

    NOTE: there is no cycle detection; a cycle through non-final states
    recurses without bound.
    """
    sentences: List[List[str]] = []

    if path is None:
        path = []

    # Compare against None rather than relying on truthiness: state 0 is a
    # valid (and common) state id and must not be replaced by the start state.
    if state is None:
        state = in_fst.start()
        if state == -1:
            # FST has no start state, so there is nothing to enumerate.
            return sentences

    if zero_weight is None:
        zero_weight = fst.Weight.Zero(in_fst.weight_type())

    out_syms = in_fst.output_symbols()

    for arc in in_fst.arcs(state):
        path.append(arc)

        if in_fst.final(arc.nextstate) != zero_weight:
            # Final state: turn the output labels on the path into a sentence
            sentence = []
            for p_arc in path:
                if p_arc.olabel == eps:
                    continue

                osym = out_syms.find(p_arc.olabel).decode()
                if exclude_meta and osym.startswith("__"):
                    continue  # skip __label__, etc.

                if out_file:
                    print(osym, "", end="", file=out_file)
                else:
                    sentence.append(osym)

            if out_file:
                print("", file=out_file)
            else:
                sentences.append(sentence)
        else:
            # Non-final state: keep walking with the shared path list
            sentences.extend(
                fstprintall(
                    in_fst,
                    out_file=out_file,
                    state=arc.nextstate,
                    path=path,
                    zero_weight=zero_weight,
                    eps=eps,
                    exclude_meta=exclude_meta,
                ))

        # Undo this arc so sibling arcs see the same prefix
        path.pop()

    return sentences
コード例 #3
0
def longest_path(the_fst: fst.Fst, eps: str = "<eps>") -> fst.Fst:
    """Return a linear FST spelling out the longest BFS path of ``the_fst``.

    The FST is searched breadth-first from its start state, and each state
    is expanded only on its first visit. The chosen path is therefore the
    longest of the first-visit paths, i.e. the BFS path to the deepest state.
    It is not necessarily the true longest path, and it need not end in a
    final state.

    The result has one arc per output label on that path:
    - output labels keep their original ids and the original output table;
    - input labels come from a fresh input table in which each output symbol
      string is re-added, with ``eps`` added first.
    The last state of the chain is final.
    """
    output_symbols = the_fst.output_symbols()
    visited_states: Set[int] = set()
    best_path: List[int] = []
    state_queue: Deque[Tuple[int, List[int]]] = deque()
    state_queue.append((the_fst.start(), []))

    # Determine longest path (BFS; each state is expanded once)
    while state_queue:
        state, path = state_queue.popleft()
        if state in visited_states:
            continue

        visited_states.add(state)

        if len(path) > len(best_path):
            best_path = path

        for arc in the_fst.arcs(state):
            state_queue.append((arc.nextstate, path + [arc.olabel]))

    # Create FST with longest path
    path_fst = fst.Fst()

    input_symbols = fst.SymbolTable()
    input_symbols.add_symbol(eps)  # added first so it takes the first key
    path_fst.set_output_symbols(output_symbols)
    weight_one = fst.Weight.One(path_fst.weight_type())

    state = path_fst.add_state()
    path_fst.set_start(state)

    for olabel in best_path:
        osym = output_symbols.find(olabel).decode()
        next_state = path_fst.add_state()
        path_fst.add_arc(
            state,
            fst.Arc(input_symbols.add_symbol(osym), olabel, weight_one,
                    next_state),
        )
        state = next_state

    path_fst.set_final(state)
    path_fst.set_input_symbols(input_symbols)

    return path_fst
コード例 #4
0
def fst_to_graph(the_fst: fst.Fst) -> nx.MultiDiGraph:
    """Converts a finite state transducer to a directed graph.

    Each FST state becomes a node with boolean ``final`` (final weight is not
    Zero) and ``start`` attributes. Each arc becomes an edge carrying the
    decoded ``in_label`` and ``out_label`` strings.
    """
    zero = fst.Weight.Zero(the_fst.weight_type())
    isyms = the_fst.input_symbols()
    osyms = the_fst.output_symbols()

    graph = nx.MultiDiGraph()

    for src in the_fst.states():
        graph.add_node(src, final=(the_fst.final(src) != zero), start=False)

        for arc in the_fst.arcs(src):
            graph.add_edge(
                src,
                arc.nextstate,
                in_label=isyms.find(arc.ilabel).decode(),
                out_label=osyms.find(arc.olabel).decode(),
            )

    # Done last so it overrides the start=False set in the loop above
    graph.add_node(the_fst.start(), start=True)

    return graph
コード例 #5
0
def make_slot_fsts(intent_fst: fst.Fst) -> Dict[str, Dict[str, fst.Fst]]:
    """Map every intent name to a dict of per-slot acceptor FSTs.

    Each arc leaving the start state of ``intent_fst`` must emit a
    ``__label__<intent>`` output symbol (asserted). The part of the FST
    reached through that arc is scanned by ``_make_slot_fst``, which fills
    the slot-name -> FST dict for that intent.
    """
    label_prefix = "__label__"
    symbols = intent_fst.output_symbols()
    result: Dict[str, Dict[str, fst.Fst]] = {}

    for arc in intent_fst.arcs(intent_fst.start()):
        label = symbols.find(arc.olabel).decode()
        assert label.startswith(label_prefix), label

        # Register the (still empty) slot dict before filling it
        slots: Dict[str, fst.Fst] = {}
        result[label[len(label_prefix):]] = slots

        _make_slot_fst(arc.nextstate, intent_fst, slots)

    return result
コード例 #6
0
def _make_slot_fst(state: int, intent_fst: fst.Fst,
                   slot_to_fst: Dict[str, fst.Fst],
                   _visited: Optional[set] = None):
    """Fill ``slot_to_fst`` with one acceptor per slot reachable from ``state``.

    Walks ``intent_fst`` depth-first from ``state``. Suppose an arc emits
    ``__begin__<slot>`` for a slot not yet in ``slot_to_fst``. Then the part
    of the FST between that arc and the ``__end__<slot>`` arcs is copied into
    a new FST, passed through ``minimize_fst`` and stored under ``<slot>``.

    :param state: intent-FST state to start from
    :param intent_fst: FST whose output labels tag slots with
        ``__begin__<slot>`` / ``__end__<slot>``
    :param slot_to_fst: filled in place, slot name -> slot FST
    :param _visited: internal; states already explored by this walk
    """
    if _visited is None:
        _visited = set()

    # Explore each state once. slot_to_fst only grows, so revisiting an
    # explored state can only skip more slots and never finds a new one.
    # Without this check, shared states would be re-walked once per incoming
    # path, and cycles would recurse forever.
    if state in _visited:
        return
    _visited.add(state)

    out_symbols = intent_fst.output_symbols()
    one_weight = fst.Weight.One(intent_fst.weight_type())
    begin_prefix = "__begin__"

    for arc in intent_fst.arcs(state):
        label = out_symbols.find(arc.olabel).decode()
        if label.startswith(begin_prefix):
            slot_name = label[len(begin_prefix):]

            # Big assumption here that each instance of a slot (e.g., location)
            # will produce the same FST, and therefore doesn't need to be
            # processed again.
            if slot_name in slot_to_fst:
                continue  # skip duplicate slots (and what follows them)

            slot_fst = _copy_slot_subfst(intent_fst, arc.nextstate,
                                         f"__end__{slot_name}", one_weight)
            slot_to_fst[slot_name] = minimize_fst(slot_fst)

        # Recurse
        _make_slot_fst(arc.nextstate, intent_fst, slot_to_fst, _visited)


def _copy_slot_subfst(intent_fst: fst.Fst, slot_start: int, end_label: str,
                      weight: fst.Weight) -> fst.Fst:
    """Copy the part of ``intent_fst`` from ``slot_start`` up to ``end_label`` arcs.

    Every arc reachable from ``slot_start`` without crossing an arc whose
    output symbol is ``end_label`` is copied, with ``weight`` as its weight.
    The source state of each ``end_label`` arc becomes final. The copy shares
    the symbol tables of ``intent_fst``.
    """
    out_symbols = intent_fst.output_symbols()

    slot_fst = fst.Fst()
    slot_fst.set_input_symbols(intent_fst.input_symbols())
    slot_fst.set_output_symbols(intent_fst.output_symbols())

    start_state = slot_fst.add_state()
    slot_fst.set_start(start_state)
    state_map = {slot_start: start_state}
    pending = [slot_start]

    # Copy states/arcs from intent FST until __end__ is found
    while pending:
        q_state = pending.pop()
        for q_arc in intent_fst.arcs(q_state):
            if out_symbols.find(q_arc.olabel).decode() == end_label:
                # The __end__ arc closes the slot: its source state is final
                slot_fst.set_final(state_map[q_state])
                continue

            next_state = state_map.get(q_arc.nextstate)
            if next_state is None:
                next_state = slot_fst.add_state()
                state_map[q_arc.nextstate] = next_state
                # Enqueue a state only when first seen. Otherwise shared
                # states would have their arcs copied again (duplicate arcs),
                # and a cycle in the slot would never terminate.
                pending.append(q_arc.nextstate)

            slot_fst.add_arc(
                state_map[q_state],
                fst.Arc(q_arc.ilabel, q_arc.olabel, weight, next_state),
            )

    return slot_fst
コード例 #7
0
def _all_valid_strings(td: fst.Fst) -> List[Tuple[List[int], float]]:
    """Return an enumeration of the emission language. Essentially the equivalent of
    fstprint, but not handling the de-interning of strings.

    The search is depth-first (explicit stack). A path is reported whenever
    it reaches a state whose final weight converts to a finite float.

    The weight returned is not the weight of the whole sequence, but the weight
    of the final state in the sequence as a final state.

    Does not check for duplicate emissions or cycles in the transducer.

    :param td: transducer, the emission language of which to enumerate
    :returns: a list of (interned emission symbols, weight) tuples

    """
    start = td.start()
    if start == -1:
        # No start state: the language is empty
        return []

    emissions = []
    pending = [(start, [])]
    while pending:
        state, labels = pending.pop()
        weight = float(td.final(state))
        if np.isfinite(weight):
            emissions.append((labels, weight))
        pending.extend(
            (arc.nextstate, labels + [arc.olabel]) for arc in td.arcs(state))
    return emissions
コード例 #8
0
def linear_fst(
    elements: List[str],
    automata_op: fst.Fst,
    keep_isymbols: bool = True,
    **kwargs: Mapping[Any, Any],
) -> fst.Fst:
    """Produce a linear automaton accepting exactly ``elements`` in order.

    States ``0..len(elements)`` are chained by one arc per element. Only the
    last state is final. Labels are resolved against a copy of
    ``automata_op``'s input symbol table.

    :param elements: non-empty sequence of symbol strings (asserted)
    :param automata_op: FST whose input symbols are used for compilation
    :param keep_isymbols: attach the input symbol table to the result
    :param kwargs: extra ``fst.Compiler`` arguments; ``acceptor`` may be
        overridden here
    """
    assert len(elements) > 0, "No elements"

    # Arcs are written as "src dst label" (three columns), which is the
    # acceptor text format. Compile as an acceptor regardless of
    # keep_isymbols: previously acceptor=keep_isymbols made
    # keep_isymbols=False parse these lines as (malformed) transducer arcs.
    acceptor = kwargs.pop("acceptor", True)
    compiler = fst.Compiler(
        isymbols=automata_op.input_symbols().copy(),
        acceptor=acceptor,
        keep_isymbols=keep_isymbols,
        **kwargs,
    )

    for i, el in enumerate(elements):
        print(f"{i} {i + 1} {el}", file=compiler)

    # The state after the last arc is the only final state
    print(str(len(elements)), file=compiler)

    return compiler.compile()
コード例 #9
0
def replace_and_patch(
    outer_fst: fst.Fst,
    outer_start_state: int,
    outer_final_state: int,
    inner_fst: fst.Fst,
    label_sym: int,
    eps: int = 0,
) -> None:
    """Copies an inner FST into an outer FST, creating states and mapping symbols.
    Creates arcs from outer start/final states to inner start/final states.

    - Every inner state gets a new outer state.
    - Inner arcs are copied with their labels re-mapped by symbol string into
      the outer symbol tables.
    - An ``eps:label_sym`` arc links ``outer_start_state`` to the copied start
      state.
    - Each copied final state gets exactly one ``eps:eps`` arc to
      ``outer_final_state``.
    - All added arcs have weight One; inner weights are not preserved.

    ``outer_fst`` is modified in place.
    """
    in_symbols = outer_fst.input_symbols()
    out_symbols = outer_fst.output_symbols()
    inner_zero = fst.Weight.Zero(inner_fst.weight_type())
    outer_one = fst.Weight.One(outer_fst.weight_type())

    def map_symbols(inner_table, outer_table) -> Dict[int, int]:
        # Symbol keys need not be contiguous, so walk them with get_nth_key
        # instead of assuming keys 0..num_symbols()-1.
        mapping = {}
        for i in range(inner_table.num_symbols()):
            key = inner_table.get_nth_key(i)
            mapping[key] = outer_table.find(inner_table.find(key).decode())
        return mapping

    out_symbol_map = map_symbols(inner_fst.output_symbols(), out_symbols)
    in_symbol_map = map_symbols(inner_fst.input_symbols(), in_symbols)

    # Create states in outer FST
    state_map = {
        inner_state: outer_fst.add_state()
        for inner_state in inner_fst.states()
    }

    # Enter the copy from the outer start state, emitting label_sym
    inner_start = inner_fst.start()
    if inner_start in state_map:
        outer_fst.add_arc(
            outer_start_state,
            fst.Arc(eps, label_sym, outer_one, state_map[inner_start]),
        )

    # Create arcs in outer FST
    for inner_state in inner_fst.states():
        outer_state = state_map[inner_state]

        for inner_arc in inner_fst.arcs(inner_state):
            outer_fst.add_arc(
                outer_state,
                fst.Arc(
                    in_symbol_map[inner_arc.ilabel],
                    out_symbol_map[inner_arc.olabel],
                    outer_one,
                    state_map[inner_arc.nextstate],
                ),
            )

        # Exit once per final state. Checking finality per state (not per
        # incoming arc) avoids duplicate exit arcs and also covers a final
        # start state.
        if inner_fst.final(inner_state) != inner_zero:
            outer_fst.add_arc(
                outer_state,
                fst.Arc(eps, eps, outer_one, outer_final_state),
            )
コード例 #10
0
ファイル: stt_train.py プロジェクト: mihaicoli/rhasspy
    def write_dictionary(self, intent_fst: fst.Fst) -> Set[str]:
        """Writes all required words to a CMU dictionary.

        Needed words come from two sources:
        - the input symbols of ``intent_fst``, skipping symbols that start
          with ``__`` or ``<``;
        - the wake keyphrase, when the wake system is pocketsphinx.
        Words are upper- or lower-cased according to ``self.dictionary_upper``.

        When ``mix_weight`` > 0, every word of the loaded dictionaries is
        needed as well. Pronunciations come from the base and custom
        dictionaries. Alternate pronunciations are written as ``word(2)``,
        ``word(3)``, ... unless ``training.dictionary_number_duplicates`` is
        false.

        Returns the set of needed words that have no pronunciation. Nothing
        is raised for them here; handling them is left to the caller.
        """

        start_time = time.time()
        words_needed: Set[str] = set()

        def fix_case(word: str) -> str:
            # Dictionary uses either upper-case or lower-case words
            return word.upper() if self.dictionary_upper else word.lower()

        # Gather all words needed. Symbol keys need not be contiguous, so
        # look each symbol up by position with get_nth_key.
        in_symbols = intent_fst.input_symbols()
        for i in range(in_symbols.num_symbols()):
            word = in_symbols.find(in_symbols.get_nth_key(i)).decode()

            if word.startswith(("__", "<")):
                continue  # skip metadata

            words_needed.add(fix_case(word))

        # Load base and custom dictionaries
        base_dictionary_path = self.profile.read_path(
            self.profile.get(
                f"speech_to_text.{self.system}.base_dictionary", "base_dictionary.txt"
            )
        )

        custom_path = self.profile.read_path(
            self.profile.get(
                f"speech_to_text.{self.system}.custom_words", "custom_words.txt"
            )
        )

        word_dict: Dict[str, List[str]] = {}
        for word_dict_path in [base_dictionary_path, custom_path]:
            if os.path.exists(word_dict_path):
                self._logger.debug(f"Loading dictionary from {word_dict_path}")
                with open(word_dict_path, "r") as dictionary_file:
                    read_dict(dictionary_file, word_dict)

        # Add words from wake word if using pocketsphinx
        if self.profile.get("wake.system") == "pocketsphinx":
            wake_keyphrase = self.profile.get("wake.pocketsphinx.keyphrase", "")
            if len(wake_keyphrase) > 0:
                self._logger.debug(f"Adding words from keyphrase: {wake_keyphrase}")
                _, wake_tokens = sanitize_sentence(
                    wake_keyphrase,
                    self.dictionary_casing,
                    self.replace_patterns,
                    self.split_pattern,
                )

                words_needed.update(fix_case(word) for word in wake_tokens)

        # Determine if we need to include the entire base dictionary
        mix_weight = float(
            self.profile.get(f"speech_to_text.{self.system}.mix_weight", 0)
        )

        if mix_weight > 0:
            self._logger.debug(
                "Including base dictionary because base language model will be mixed"
            )

            # Add in all the words
            words_needed.update(word_dict.keys())

        # Write out dictionary with only the necessary words (speeds up loading)
        dictionary_path = self.profile.write_path(
            self.profile.get(
                f"speech_to_text.{self.system}.dictionary", "dictionary.txt"
            )
        )

        words_written = 0
        number_duplicates = self.profile.get(
            "training.dictionary_number_duplicates", True
        )
        with open(dictionary_path, "w") as dictionary_file:
            for word in sorted(words_needed):
                if word not in word_dict:
                    continue

                for i, pronounce in enumerate(word_dict[word]):
                    if i == 0 or not number_duplicates:
                        print(word, pronounce, file=dictionary_file)
                    else:
                        # Alternate pronunciations are numbered from 2
                        print(f"{word}({i + 1})", pronounce, file=dictionary_file)

                words_written += 1

        dictionary_time = time.time() - start_time
        self._logger.debug(
            f"Wrote {words_written} word(s) to {dictionary_path} in {dictionary_time} second(s)"
        )

        # Check for unknown words
        return words_needed - word_dict.keys()
コード例 #11
0
def make_slot_acceptor(intent_fst: fst.Fst, eps: str = "<eps>") -> fst.Fst:
    """Build an acceptor over intent labels, free words and tagged slot values.

    The result uses one symbol table that merges the input and output
    symbols of ``intent_fst``. It is built breadth-first from
    ``intent_fst``:

    - each ``__label__`` arc becomes an arc with the same symbol, leading to
      a state with self-loops over every non-meta, non-eps symbol;
    - before a ``__begin__`` arc, the current state also gets those
      self-loops (once), to soak up untagged words;
    - arcs inside ``__begin__``/``__end__`` regions (inclusive) are copied,
      except arcs that are eps on both sides;
    - every reached state with no copied outgoing arc, other than the start
      state, is made final.

    NOTE: intent states are expanded once per path (no visited set), so
    cycles in ``intent_fst`` would not terminate.
    """
    intent_in_symbols = intent_fst.input_symbols()
    intent_out_symbols = intent_fst.output_symbols()
    in_eps = intent_in_symbols.find(eps)
    out_eps = intent_out_symbols.find(eps)
    slot_fst = fst.Fst()

    # Copy symbol tables (merged) and remember keys of meta (__*) symbols
    all_symbols = fst.SymbolTable()
    meta_keys: Set[int] = set()

    for table in [intent_in_symbols, intent_out_symbols]:
        for i in range(table.num_symbols()):
            key = table.get_nth_key(i)
            sym = table.find(key).decode()
            all_key = all_symbols.add_symbol(sym)
            if sym.startswith("__"):
                meta_keys.add(all_key)

    weight_one = fst.Weight.One(slot_fst.weight_type())

    # States that will be set to final
    final_states: Set[int] = set()

    # States that already have all-word loops
    loop_states: Set[int] = set()

    all_eps = all_symbols.find(eps)

    # Add self transitions to a state for all input words (besides <eps>)
    def add_loop_state(state):
        for sym_idx in range(all_symbols.num_symbols()):
            all_key = all_symbols.get_nth_key(sym_idx)
            if (all_key != all_eps) and (all_key not in meta_keys):
                slot_fst.add_arc(state,
                                 fst.Arc(all_key, all_key, weight_one, state))

    slot_fst.set_start(slot_fst.add_state())

    # Queue of (intent state, acceptor state, copy depth)
    state_queue: Deque[Tuple[int, int, int]] = deque()
    state_queue.append((intent_fst.start(), slot_fst.start(), 0))

    # BFS
    while state_queue:
        intent_state, slot_state, state_copy = state_queue.popleft()
        final_states.add(slot_state)
        for intent_arc in intent_fst.arcs(intent_state):
            # The copy depth belongs to the path. Every arc starts from the
            # depth its source state was reached with, so a __begin__/__end__
            # on one arc does not leak into its sibling arcs.
            do_copy = state_copy
            out_symbol = intent_out_symbols.find(intent_arc.olabel).decode()
            all_key = all_symbols.find(out_symbol)

            if out_symbol.startswith("__label__"):
                # Create corresponding __label__ arc
                next_state = slot_fst.add_state()
                slot_fst.add_arc(
                    slot_state,
                    fst.Arc(all_key, all_key, weight_one, next_state))

                # Must create a loop here for intents with no slots. Record
                # the state that actually got the loops, so a following
                # __begin__ does not add a second set of self-loops.
                add_loop_state(next_state)
                loop_states.add(next_state)
            else:
                # Non-label arc
                if out_symbol.startswith("__begin__"):
                    # States/arcs will be copied until __end__ is reached
                    do_copy += 1

                    # Add loop transitions to soak up non-tag words
                    if slot_state not in loop_states:
                        add_loop_state(slot_state)
                        loop_states.add(slot_state)

                if do_copy > 0 and (intent_arc.ilabel != in_eps or
                                    intent_arc.olabel != out_eps):
                    # Copy state/arc
                    in_symbol = intent_in_symbols.find(
                        intent_arc.ilabel).decode()
                    next_state = slot_fst.add_state()
                    slot_fst.add_arc(
                        slot_state,
                        fst.Arc(all_symbols.find(in_symbol), all_key,
                                weight_one, next_state),
                    )
                    final_states.discard(slot_state)
                else:
                    next_state = slot_state

                if out_symbol.startswith("__end__"):
                    # Stop copying after this state until next __begin__
                    do_copy -= 1

            state_queue.append((intent_arc.nextstate, next_state, do_copy))

    # Mark all dangling states as final (excluding start)
    for state in final_states:
        if state != slot_fst.start():
            slot_fst.set_final(state)

    # Fix symbol tables
    slot_fst.set_input_symbols(all_symbols)
    slot_fst.set_output_symbols(all_symbols)

    return slot_fst
コード例 #12
0
def filter_words(words: Iterable[str], the_fst: fst.Fst) -> List[str]:
    """Keep only the words that the FST's input symbol table knows.

    A word is kept when its lookup in the input symbol table yields a
    non-negative key. Input order and duplicates are preserved.
    """
    symbols = the_fst.input_symbols()
    known = []
    for word in words:
        if symbols.find(word) >= 0:
            known.append(word)
    return known
コード例 #13
0
def _replace_fsts(outer_fst: fst.Fst,
                  replacements: Dict[int, fst.Fst],
                  eps="<eps>") -> fst.Fst:
    """Return a copy of ``outer_fst`` with selected arcs replaced by whole FSTs.

    Each arc whose output label is a key of ``replacements`` is replaced by
    a fresh copy of the corresponding FST. The copy is spliced in with ``eps``
    arcs:
    - one from the arc's source state into the copy's start state;
    - one from every final state of the copy to the arc's destination.
    All other arcs are copied as-is.

    New input/output symbol tables are built from the outer tables, plus
    ``eps`` and any symbols used by the replacement FSTs. All arc weights
    become One. Final outer states are marked with ``set_final`` using its
    default weight, so the original weights are not preserved.
    """
    input_symbol_map: Dict[Union[int, Tuple[int, int]], int] = {}
    output_symbol_map: Dict[Union[int, Tuple[int, int]], int] = {}
    state_map: Dict[Union[int, Tuple[int, int]], int] = {}

    # Create new FST
    new_fst = fst.Fst()
    new_input_symbols = fst.SymbolTable()
    new_output_symbols = fst.SymbolTable()

    weight_one = fst.Weight.One(new_fst.weight_type())
    outer_zero = fst.Weight.Zero(outer_fst.weight_type())

    # Copy symbols, remembering old key -> new key
    outer_input_symbols = outer_fst.input_symbols()
    for i in range(outer_input_symbols.num_symbols()):
        key = outer_input_symbols.get_nth_key(i)
        input_symbol_map[key] = new_input_symbols.add_symbol(
            outer_input_symbols.find(key))

    outer_output_symbols = outer_fst.output_symbols()
    for i in range(outer_output_symbols.num_symbols()):
        key = outer_output_symbols.get_nth_key(i)
        output_symbol_map[key] = new_output_symbols.add_symbol(
            outer_output_symbols.find(key))

    in_eps = new_input_symbols.add_symbol(eps)
    out_eps = new_output_symbols.add_symbol(eps)

    # Copy states (finality included)
    for outer_state in outer_fst.states():
        new_state = new_fst.add_state()
        state_map[outer_state] = new_state

        if outer_fst.final(outer_state) != outer_zero:
            new_fst.set_final(new_state)

    # Set start state
    new_fst.set_start(state_map[outer_fst.start()])

    # Copy arcs
    for outer_state in outer_fst.states():
        new_state = state_map[outer_state]
        for outer_arc in outer_fst.arcs(outer_state):
            next_state = state_map[outer_arc.nextstate]
            replace_fst = replacements.get(outer_arc.olabel)

            if replace_fst is None:
                # Copy arc as-is
                new_fst.add_arc(
                    new_state,
                    fst.Arc(
                        input_symbol_map[outer_arc.ilabel],
                        output_symbol_map[outer_arc.olabel],
                        weight_one,
                        next_state,
                    ),
                )
                continue

            # Replace in-line. Each replaced arc gets its own copy of the
            # replacement's states. (r, state) entries are overwritten per
            # occurrence, which is fine: they are only read while this
            # occurrence is being copied.
            r = outer_arc.olabel
            replace_zero = fst.Weight.Zero(replace_fst.weight_type())
            replace_input_symbols = replace_fst.input_symbols()
            replace_output_symbols = replace_fst.output_symbols()

            # Copy states, linking each final one to the arc's destination
            for replace_state in replace_fst.states():
                state_map[(r, replace_state)] = new_fst.add_state()

                if replace_fst.final(replace_state) != replace_zero:
                    new_fst.add_arc(
                        state_map[(r, replace_state)],
                        fst.Arc(in_eps, out_eps, weight_one, next_state),
                    )

            # Copy arcs, adding replacement symbols to the new tables
            for replace_state in replace_fst.states():
                for replace_arc in replace_fst.arcs(replace_state):
                    new_fst.add_arc(
                        state_map[(r, replace_state)],
                        fst.Arc(
                            new_input_symbols.add_symbol(
                                replace_input_symbols.find(
                                    replace_arc.ilabel)),
                            new_output_symbols.add_symbol(
                                replace_output_symbols.find(
                                    replace_arc.olabel)),
                            weight_one,
                            state_map[(r, replace_arc.nextstate)],
                        ),
                    )

            # Create arc into start state
            new_fst.add_arc(
                new_state,
                fst.Arc(in_eps, out_eps, weight_one,
                        state_map[(r, replace_fst.start())]),
            )

    # Fix symbol tables
    new_fst.set_input_symbols(new_input_symbols)
    new_fst.set_output_symbols(new_output_symbols)

    return new_fst
コード例 #14
0
def minimize_fst(the_fst: fst.Fst) -> fst.Fst:
    """Minimize ``the_fst`` by piping it through the ``fstminimize`` command.

    The FST is serialized, passed on stdin to ``fstminimize --allow_nondet``,
    and the command's stdout is read back as a new FST.
    ``subprocess.check_output`` raises ``CalledProcessError`` if the command
    exits with a non-zero status.
    """
    # BUG: Fst.minimize does not pass allow_nondet through, so we have to call out to the command-line
    serialized = the_fst.write_to_string()
    minimized = subprocess.check_output(
        ["fstminimize", "--allow_nondet"], input=serialized)
    return fst.Fst.read_from_string(minimized)