def fstprintall(
    in_fst: fst.Fst,
    out_file: Optional[TextIO] = None,
    exclude_meta: bool = True,
    eps: str = "<eps>",
) -> List[List[str]]:
    """Enumerate all output sentences of an FST breadth-first.

    Args:
        in_fst: transducer whose output language is enumerated
        out_file: if given, print each sentence (space-joined) instead of
            collecting it in the returned list
        exclude_meta: drop meta symbols beginning with "__" (e.g. __label__)
        eps: name of the output epsilon symbol

    Returns:
        List of sentences (each a list of output symbol strings);
        empty when out_file is provided.

    Note: no cycle detection — a cyclic FST will loop forever.
    """
    sentences: List[List[str]] = []

    # Guard against an empty FST: start() is -1 when no start state is set
    # (same check as _all_valid_strings).
    if in_fst.start() == -1:
        return sentences

    output_symbols = in_fst.output_symbols()
    out_eps = output_symbols.find(eps)
    zero_weight = fst.Weight.Zero(in_fst.weight_type())

    state_queue: Deque[Tuple[int, List[str]]] = deque()
    state_queue.append((in_fst.start(), []))

    while len(state_queue) > 0:
        state, sentence = state_queue.popleft()

        # A non-Zero final weight marks a final state: emit the sentence
        if in_fst.final(state) != zero_weight:
            if out_file:
                print(" ".join(sentence), file=out_file)
            else:
                sentences.append(sentence)

        for arc in in_fst.arcs(state):
            arc_sentence = list(sentence)
            if arc.olabel != out_eps:
                out_symbol = output_symbols.find(arc.olabel).decode()
                if exclude_meta and out_symbol.startswith("__"):
                    pass  # skip __label__, etc.
                else:
                    arc_sentence.append(out_symbol)

            state_queue.append((arc.nextstate, arc_sentence))

    return sentences
def fstprintall(
    in_fst: fst.Fst,
    out_file: Optional[TextIO] = None,
    exclude_meta: bool = True,
    state: Optional[int] = None,
    path: Optional[List[fst.Arc]] = None,
    zero_weight: Optional[fst.Weight] = None,
    eps: int = 0,
) -> List[List[str]]:
    """Recursively enumerate all output sentences of an FST.

    Args:
        in_fst: transducer whose output language is enumerated
        out_file: if given, print each sentence instead of collecting it
        exclude_meta: drop meta symbols beginning with "__" (e.g. __label__)
        state: state to expand (defaults to the start state)
        path: arcs taken so far (internal recursion accumulator)
        zero_weight: cached Zero weight (internal recursion accumulator)
        eps: output epsilon *key* (note: an int here, unlike the BFS variant)

    Returns:
        List of sentences; empty when out_file is provided.
    """
    sentences = []
    path = path or []
    # BUG FIX: compare against None so an explicit state=0 is respected
    # (0 is falsy, so "state or in_fst.start()" discarded it).
    state = in_fst.start() if state is None else state
    zero_weight = zero_weight or fst.Weight.Zero(in_fst.weight_type())

    for arc in in_fst.arcs(state):
        path.append(arc)
        if in_fst.final(arc.nextstate) != zero_weight:
            # Final state: decode and emit the accumulated path
            out_syms = in_fst.output_symbols()
            sentence = []
            for p_arc in path:
                if p_arc.olabel != eps:
                    osym = out_syms.find(p_arc.olabel).decode()
                    if exclude_meta and osym.startswith("__"):
                        continue  # skip __label__, etc.

                    if out_file:
                        print(osym, "", end="", file=out_file)
                    else:
                        sentence.append(osym)

            if out_file:
                print("", file=out_file)
            else:
                sentences.append(sentence)

        # BUG FIX: always continue past the next state. Previously the
        # recursion only ran for non-final states, so any sentence passing
        # *through* a final state was silently dropped (inconsistent with
        # the breadth-first fstprintall variant).
        sentences.extend(
            fstprintall(
                in_fst,
                out_file=out_file,
                state=arc.nextstate,
                path=path,
                zero_weight=zero_weight,
                eps=eps,
                exclude_meta=exclude_meta,
            ))

        path.pop()

    return sentences
def longest_path(the_fst: fst.Fst, eps: str = "<eps>") -> fst.Fst:
    """Return a new linear FST containing the longest path (most arcs) of the_fst.

    The result keeps the_fst's output symbol table; its input symbol table is
    rebuilt from the output symbols along the chosen path.

    (Removed: unused local `out_eps` that was computed but never read.)
    """
    output_symbols = the_fst.output_symbols()

    visited_states: Set[int] = set()
    best_path: List[int] = []

    state_queue: Deque[Tuple[int, List[int]]] = deque()
    state_queue.append((the_fst.start(), []))

    # Determine longest path.
    # NOTE(review): each state is expanded only on first arrival, so this
    # finds the longest among first-arrival paths — exact for tree-shaped
    # FSTs; confirm inputs are tree-like if exactness on DAGs matters.
    while len(state_queue) > 0:
        state, path = state_queue.popleft()
        if state in visited_states:
            continue

        visited_states.add(state)

        if len(path) > len(best_path):
            best_path = path

        for arc in the_fst.arcs(state):
            next_path = list(path)
            next_path.append(arc.olabel)
            state_queue.append((arc.nextstate, next_path))

    # Create FST with longest path
    path_fst = fst.Fst()
    input_symbols = fst.SymbolTable()
    input_symbols.add_symbol(eps)  # ensure epsilon is present (key 0)

    path_fst.set_output_symbols(output_symbols)
    weight_one = fst.Weight.One(path_fst.weight_type())

    state = path_fst.add_state()
    path_fst.set_start(state)

    for olabel in best_path:
        osym = output_symbols.find(olabel).decode()
        next_state = path_fst.add_state()
        path_fst.add_arc(
            state,
            fst.Arc(input_symbols.add_symbol(osym), olabel, weight_one, next_state),
        )
        state = next_state

    path_fst.set_final(state)
    path_fst.set_input_symbols(input_symbols)

    return path_fst
def fst_to_graph(the_fst: fst.Fst) -> nx.MultiDiGraph:
    """Converts a finite state transducer to a directed graph."""
    zero = fst.Weight.Zero(the_fst.weight_type())
    isyms = the_fst.input_symbols()
    osyms = the_fst.output_symbols()

    graph = nx.MultiDiGraph()

    for state in the_fst.states():
        # Nodes carry a "final" flag (non-Zero final weight) and a "start"
        # flag that is corrected for the actual start state afterwards.
        graph.add_node(state, final=(the_fst.final(state) != zero), start=False)

        # One edge per arc, labeled with the decoded input/output symbols
        for arc in the_fst.arcs(state):
            graph.add_edge(
                state,
                arc.nextstate,
                in_label=isyms.find(arc.ilabel).decode(),
                out_label=osyms.find(arc.olabel).decode(),
            )

    # Mark the start state (add_node merges attributes on an existing node)
    graph.add_node(the_fst.start(), start=True)

    return graph
def make_slot_fsts(intent_fst: fst.Fst) -> Dict[str, Dict[str, fst.Fst]]:
    """Build, for each intent, a mapping from slot (tag) name to acceptor FST.

    Assumes every arc out of the start state carries a "__label__<intent>"
    output symbol naming the intent.
    """
    output_symbols = intent_fst.output_symbols()
    intent_to_slots: Dict[str, Dict[str, fst.Fst]] = {}

    for intent_arc in intent_fst.arcs(intent_fst.start()):
        # Extract intent name from output label
        label = output_symbols.find(intent_arc.olabel).decode()
        assert label.startswith("__label__"), label
        intent_name = label[9:]  # strip "__label__" (9 chars)

        # Create mapping from slot (tag) name to acceptor FST
        slots: Dict[str, fst.Fst] = {}
        intent_to_slots[intent_name] = slots
        _make_slot_fst(intent_arc.nextstate, intent_fst, slots)

    return intent_to_slots
def _make_slot_fst(state: int, intent_fst: fst.Fst, slot_to_fst: Dict[str, fst.Fst]):
    """Recursively extract an acceptor FST for each __begin__/__end__ tagged
    slot region reachable from ``state``, adding results to slot_to_fst.

    Args:
        state: intent-FST state to start scanning from
        intent_fst: source transducer containing tagged slot regions
        slot_to_fst: output mapping from slot name to (minimized) acceptor
    """
    out_symbols = intent_fst.output_symbols()
    one_weight = fst.Weight.One(intent_fst.weight_type())

    for arc in intent_fst.arcs(state):
        label = out_symbols.find(arc.olabel).decode()
        if label.startswith("__begin__"):
            slot_name = label[9:]  # strip "__begin__" (9 chars)

            # Big assumption here that each instance of a slot (e.g., location)
            # will produce the same FST, and therefore doesn't need to be
            # processed again.
            if slot_name in slot_to_fst:
                continue  # skip duplicate slots

            end_label = f"__end__{slot_name}"

            # Create new FST
            slot_fst = fst.Fst()
            slot_fst.set_input_symbols(intent_fst.input_symbols())
            slot_fst.set_output_symbols(intent_fst.output_symbols())

            start_state = slot_fst.add_state()
            slot_fst.set_start(start_state)

            q = [arc.nextstate]
            state_map = {arc.nextstate: start_state}

            # Copy states/arcs from intent FST until __end__ is found
            while len(q) > 0:
                q_state = q.pop()
                for q_arc in intent_fst.arcs(q_state):
                    slot_arc_label = out_symbols.find(q_arc.olabel).decode()
                    if slot_arc_label != end_label:
                        if q_arc.nextstate not in state_map:
                            state_map[q_arc.nextstate] = slot_fst.add_state()
                            # BUG FIX: enqueue a state only the first time it
                            # is mapped. Previously it was enqueued on every
                            # incoming arc, so states with multiple in-edges
                            # were expanded repeatedly (duplicate arc copies,
                            # and an infinite loop on cyclic slot regions).
                            q.append(q_arc.nextstate)

                        # Create arc
                        slot_fst.add_arc(
                            state_map[q_state],
                            fst.Arc(
                                q_arc.ilabel,
                                q_arc.olabel,
                                one_weight,
                                state_map[q_arc.nextstate],
                            ),
                        )
                    else:
                        # Mark previous state as final
                        slot_fst.set_final(state_map[q_state])

            slot_to_fst[slot_name] = minimize_fst(slot_fst)

        # Recurse
        _make_slot_fst(arc.nextstate, intent_fst, slot_to_fst)
def _all_valid_strings(td: fst.Fst) -> List[Tuple[List[int], float]]:
    """Return an enumeration of the emission language.

    Essentially the equivalent of fstprint, but not handling the
    de-interning of strings. The weight returned is not the weight of the
    whole sequence, but the weight of the final state in the sequence as a
    final state.

    Does not check for duplicate emissions or cycles in the transducer.

    :param td: transducer, the emission language of which to enumerate
    :returns: a list of (interned emission symbols, weight) tuples
    """
    # An FST with no start state (-1) has an empty language
    if td.start() == -1:
        return []

    results: List[Tuple[List[int], float]] = []
    pending = [(td.start(), [])]

    # Depth-first traversal (LIFO stack)
    while pending:
        current, emitted = pending.pop()

        # A finite final weight marks an accepting state
        weight = float(td.final(current))
        if np.isfinite(weight):
            results.append((emitted, weight))

        for arc in td.arcs(current):
            pending.append((arc.nextstate, emitted + [arc.olabel]))

    return results
def linear_fst(
    elements: List[str],
    automata_op: fst.Fst,
    keep_isymbols: bool = True,
    **kwargs: Mapping[Any, Any],
) -> fst.Fst:
    """Produce a linear automata."""
    assert len(elements) > 0, "No elements"

    compiler = fst.Compiler(
        isymbols=automata_op.input_symbols().copy(),
        acceptor=keep_isymbols,
        keep_isymbols=keep_isymbols,
        **kwargs,
    )

    # One arc per element: state i -> state i+1 labeled with the element
    for index, element in enumerate(elements):
        print("{} {} {}".format(index, index + 1, element), file=compiler)

    # The final state is the one past the last arc, i.e. len(elements)
    print(str(len(elements)), file=compiler)

    return compiler.compile()
def replace_and_patch(
    outer_fst: fst.Fst,
    outer_start_state: int,
    outer_final_state: int,
    inner_fst: fst.Fst,
    label_sym: int,
    eps: int = 0,
) -> None:
    """Copies an inner FST into an outer FST, creating states and mapping symbols.
    Creates arcs from outer start/final states to inner start/final states.

    Args:
        outer_fst: destination FST (mutated in place)
        outer_start_state: outer state to bridge into the inner start
        outer_final_state: outer state to bridge inner final states into
        inner_fst: FST whose states/arcs are copied in
        label_sym: output symbol emitted on the bridge arc into the inner start
        eps: epsilon symbol key in the outer tables
    """
    in_symbols = outer_fst.input_symbols()
    out_symbols = outer_fst.output_symbols()
    inner_zero = fst.Weight.Zero(inner_fst.weight_type())
    outer_one = fst.Weight.One(outer_fst.weight_type())

    state_map = {}
    in_symbol_map = {}
    out_symbol_map = {}

    # Map inner symbol keys to outer symbol keys by symbol string.
    # Use get_nth_key (consistent with make_slot_acceptor/_replace_fsts)
    # instead of assuming keys are contiguous 0..n-1.
    inner_out_symbols = inner_fst.output_symbols()
    for i in range(inner_out_symbols.num_symbols()):
        key = inner_out_symbols.get_nth_key(i)
        sym_str = inner_out_symbols.find(key).decode()
        out_symbol_map[key] = out_symbols.find(sym_str)

    inner_in_symbols = inner_fst.input_symbols()
    for i in range(inner_in_symbols.num_symbols()):
        key = inner_in_symbols.get_nth_key(i)
        sym_str = inner_in_symbols.find(key).decode()
        in_symbol_map[key] = in_symbols.find(sym_str)

    # Create states in outer FST
    for inner_state in inner_fst.states():
        state_map[inner_state] = outer_fst.add_state()

    # Bridge arc from the outer start into the inner start (emits label_sym)
    outer_fst.add_arc(
        outer_start_state,
        fst.Arc(eps, label_sym, outer_one, state_map[inner_fst.start()]),
    )

    # Create arcs in outer FST
    for inner_state in inner_fst.states():
        for inner_arc in inner_fst.arcs(inner_state):
            outer_fst.add_arc(
                state_map[inner_state],
                fst.Arc(
                    in_symbol_map[inner_arc.ilabel],
                    out_symbol_map[inner_arc.olabel],
                    outer_one,
                    state_map[inner_arc.nextstate],
                ),
            )

        # BUG FIX: patch each inner final state exactly once. Previously an
        # eps arc was added per *incoming* arc (duplicating arcs for states
        # with several in-edges), and a final start state with no incoming
        # arcs was never patched at all.
        if inner_fst.final(inner_state) != inner_zero:
            outer_fst.add_arc(
                state_map[inner_state],
                fst.Arc(eps, eps, outer_one, outer_final_state),
            )
def write_dictionary(self, intent_fst: fst.Fst) -> Set[str]:
    """Writes all required words to a CMU dictionary.
    Unknown words have their pronunciations guessed and written to a separate dictionary.
    Fails if any unknown words are found."""
    start_time = time.time()
    words_needed: Set[str] = set()

    # Gather all words needed
    in_symbols = intent_fst.input_symbols()
    # NOTE(review): assumes symbol keys are contiguous 0..n-1; other code in
    # this file iterates with get_nth_key instead — confirm tables are dense.
    for i in range(in_symbols.num_symbols()):
        word = in_symbols.find(i).decode()
        # Skip meta symbols (__label__, __begin__, ...) and <eps>-style markers
        if word.startswith("__") or word.startswith("<"):
            continue  # skip metadata

        # Dictionary uses upper-case letters
        if self.dictionary_upper:
            word = word.upper()
        else:
            word = word.lower()

        words_needed.add(word)

    # Load base and custom dictionaries
    base_dictionary_path = self.profile.read_path(
        self.profile.get(
            f"speech_to_text.{self.system}.base_dictionary", "base_dictionary.txt"
        )
    )

    custom_path = self.profile.read_path(
        self.profile.get(
            f"speech_to_text.{self.system}.custom_words", "custom_words.txt"
        )
    )

    # word -> list of pronunciations (custom file loaded last, merged in)
    word_dict: Dict[str, List[str]] = {}
    for word_dict_path in [base_dictionary_path, custom_path]:
        if os.path.exists(word_dict_path):
            self._logger.debug(f"Loading dictionary from {word_dict_path}")
            with open(word_dict_path, "r") as dictionary_file:
                read_dict(dictionary_file, word_dict)

    # Add words from wake word if using pocketsphinx
    if self.profile.get("wake.system") == "pocketsphinx":
        wake_keyphrase = self.profile.get("wake.pocketsphinx.keyphrase", "")
        if len(wake_keyphrase) > 0:
            self._logger.debug(f"Adding words from keyphrase: {wake_keyphrase}")
            _, wake_tokens = sanitize_sentence(
                wake_keyphrase,
                self.dictionary_casing,
                self.replace_patterns,
                self.split_pattern,
            )

            for word in wake_tokens:
                # Dictionary uses upper-case letters
                if self.dictionary_upper:
                    word = word.upper()
                else:
                    word = word.lower()

                words_needed.add(word)

    # Determine if we need to include the entire base dictionary
    mix_weight = float(
        self.profile.get(f"speech_to_text.{self.system}.mix_weight", 0)
    )

    if mix_weight > 0:
        self._logger.debug(
            "Including base dictionary because base language model will be mixed"
        )

        # Add in all the words
        words_needed.update(word_dict.keys())

    # Write out dictionary with only the necessary words (speeds up loading)
    dictionary_path = self.profile.write_path(
        self.profile.get(
            f"speech_to_text.{self.system}.dictionary", "dictionary.txt"
        )
    )

    words_written = 0
    # When True, repeated pronunciations are written as word(2), word(3), ...
    number_duplicates = self.profile.get(
        "training.dictionary_number_duplicates", True
    )

    with open(dictionary_path, "w") as dictionary_file:
        for word in sorted(words_needed):
            if not word in word_dict:
                continue  # unknown word; reported via the return value

            for i, pronounce in enumerate(word_dict[word]):
                if (i < 1) or (not number_duplicates):
                    print(word, pronounce, file=dictionary_file)
                else:
                    # CMU convention: alternate pronunciations get (2), (3), ...
                    print("%s(%s)" % (word, i + 1), pronounce, file=dictionary_file)

                words_written += 1

    dictionary_time = time.time() - start_time
    self._logger.debug(
        f"Wrote {words_written} word(s) to {dictionary_path} in {dictionary_time} second(s)"
    )

    # Check for unknown words: needed words with no known pronunciation
    return words_needed - word_dict.keys()
def make_slot_acceptor(intent_fst: fst.Fst, eps: str = "<eps>") -> fst.Fst:
    """Builds an acceptor FST that extracts tagged (__begin__/__end__) slot
    regions from sentences of the intent FST, soaking up untagged words with
    self-loops over the full vocabulary."""
    in_eps = intent_fst.input_symbols().find(eps)
    out_eps = intent_fst.output_symbols().find(eps)

    slot_fst = fst.Fst()

    # Copy symbol tables: merge input and output symbols into one table,
    # remembering which merged keys are meta symbols ("__...")
    all_symbols = fst.SymbolTable()
    meta_keys = set()

    for table in [intent_fst.input_symbols(), intent_fst.output_symbols()]:
        for i in range(table.num_symbols()):
            key = table.get_nth_key(i)
            sym = table.find(key).decode()
            all_key = all_symbols.add_symbol(sym)
            if sym.startswith("__"):
                meta_keys.add(all_key)

    weight_one = fst.Weight.One(slot_fst.weight_type())
    weight_zero = fst.Weight.Zero(slot_fst.weight_type())  # NOTE(review): unused

    # States that will be set to final
    final_states: Set[int] = set()

    # States that already have all-word loops
    loop_states: Set[int] = set()

    all_eps = all_symbols.find(eps)

    # Add self transitions to a state for all input words (besides <eps>)
    def add_loop_state(state):
        for sym_idx in range(all_symbols.num_symbols()):
            all_key = all_symbols.get_nth_key(sym_idx)
            if (all_key != all_eps) and (all_key not in meta_keys):
                slot_fst.add_arc(state, fst.Arc(all_key, all_key, weight_one, state))

    slot_fst.set_start(slot_fst.add_state())

    # Queue of (intent state, acceptor state, copy count)
    # do_copy > 0 means we are inside one (or more nested) tagged regions
    state_queue: Deque[Tuple[int, int, int]] = deque()
    state_queue.append((intent_fst.start(), slot_fst.start(), 0))

    # BFS
    while len(state_queue) > 0:
        intent_state, slot_state, do_copy = state_queue.popleft()

        # Tentatively final; discarded again if a copied arc leaves it
        final_states.add(slot_state)

        for intent_arc in intent_fst.arcs(intent_state):
            out_symbol = intent_fst.output_symbols().find(
                intent_arc.olabel).decode()
            all_key = all_symbols.find(out_symbol)

            if out_symbol.startswith("__label__"):
                # Create corresponding __label__ arc
                next_state = slot_fst.add_state()
                slot_fst.add_arc(
                    slot_state,
                    fst.Arc(all_key, all_key, weight_one, next_state))

                # Must create a loop here for intents with no slots
                add_loop_state(next_state)
                loop_states.add(slot_state)
            else:
                # Non-label arc
                if out_symbol.startswith("__begin__"):
                    # States/arcs will be copied until __end__ is reached
                    do_copy += 1

                    # Add loop transitions to soak up non-tag words
                    if not slot_state in loop_states:
                        add_loop_state(slot_state)
                        loop_states.add(slot_state)

                # Copy only real (non-double-epsilon) arcs inside a tag region
                if (do_copy > 0) and ((intent_arc.ilabel != in_eps) or
                                      (intent_arc.olabel != out_eps)):
                    # Copy state/arc
                    in_symbol = (intent_fst.input_symbols().find(
                        intent_arc.ilabel).decode())
                    next_state = slot_fst.add_state()
                    slot_fst.add_arc(
                        slot_state,
                        fst.Arc(all_symbols.find(in_symbol), all_key,
                                weight_one, next_state),
                    )
                    # This state has an outgoing copied arc, so not final
                    final_states.discard(slot_state)
                else:
                    # Nothing copied: stay on the same acceptor state
                    next_state = slot_state

                if out_symbol.startswith("__end__"):
                    # Stop copying after this state until next __begin__
                    do_copy -= 1

            next_info = (intent_arc.nextstate, next_state, do_copy)
            state_queue.append(next_info)

    # Mark all dangling states as final (excluding start)
    for state in final_states:
        if state != slot_fst.start():
            slot_fst.set_final(state)

    # Fix symbol tables
    slot_fst.set_input_symbols(all_symbols)
    slot_fst.set_output_symbols(all_symbols)

    return slot_fst
def filter_words(words: Iterable[str], the_fst: fst.Fst) -> List[str]:
    """Return only the words present in the FST's input symbol table."""
    symbols = the_fst.input_symbols()
    known: List[str] = []
    for word in words:
        # find() returns a negative key when the symbol is absent
        if symbols.find(word) >= 0:
            known.append(word)
    return known
def _replace_fsts(outer_fst: fst.Fst,
                  replacements: Dict[int, fst.Fst],
                  eps="<eps>") -> fst.Fst:
    """Returns a copy of outer_fst in which every arc whose output label key
    appears in ``replacements`` is spliced out and replaced, in-line, by a
    copy of the corresponding replacement FST (connected with eps arcs)."""
    # Maps are keyed by outer key, or by (replacement label, inner key) tuples
    input_symbol_map: Dict[Union[int, Tuple[int, int]], int] = {}
    output_symbol_map: Dict[Union[int, Tuple[int, int]], int] = {}
    state_map: Dict[Union[int, Tuple[int, int]], int] = {}

    # Create new FST
    new_fst = fst.Fst()
    new_input_symbols = fst.SymbolTable()
    new_output_symbols = fst.SymbolTable()

    weight_one = fst.Weight.One(new_fst.weight_type())
    weight_zero = fst.Weight.Zero(new_fst.weight_type())  # NOTE(review): unused
    weight_final = fst.Weight.Zero(outer_fst.weight_type())

    # Copy symbols
    outer_input_symbols = outer_fst.input_symbols()
    for i in range(outer_input_symbols.num_symbols()):
        key = outer_input_symbols.get_nth_key(i)
        input_symbol_map[key] = new_input_symbols.add_symbol(
            outer_input_symbols.find(key))

    outer_output_symbols = outer_fst.output_symbols()
    for i in range(outer_output_symbols.num_symbols()):
        key = outer_output_symbols.get_nth_key(i)
        output_symbol_map[key] = new_output_symbols.add_symbol(
            outer_output_symbols.find(key))

    in_eps = new_input_symbols.add_symbol(eps)
    out_eps = new_output_symbols.add_symbol(eps)

    # Copy states (final states are those with a non-Zero final weight)
    for outer_state in outer_fst.states():
        new_state = new_fst.add_state()
        state_map[outer_state] = new_state

        if outer_fst.final(outer_state) != weight_final:
            new_fst.set_final(new_state)

    # Set start state
    new_fst.set_start(state_map[outer_fst.start()])

    # Copy arcs
    for outer_state in outer_fst.states():
        new_state = state_map[outer_state]
        for outer_arc in outer_fst.arcs(outer_state):
            next_state = state_map[outer_arc.nextstate]
            replace_fst = replacements.get(outer_arc.olabel)
            if replace_fst is not None:
                # Replace in-line: splice a copy of replace_fst between
                # new_state and next_state, dropping the original arc
                r = outer_arc.olabel
                replace_final = fst.Weight.Zero(replace_fst.weight_type())
                replace_input_symbols = replace_fst.input_symbols()
                replace_output_symbols = replace_fst.output_symbols()

                # Copy states
                for replace_state in replace_fst.states():
                    state_map[(r, replace_state)] = new_fst.add_state()

                    # Create final arc to next state (inner final -> eps -> next)
                    if replace_fst.final(replace_state) != replace_final:
                        new_fst.add_arc(
                            state_map[(r, replace_state)],
                            fst.Arc(in_eps, out_eps, weight_one, next_state),
                        )

                # Copy arcs, interning inner symbols into the new tables
                for replace_state in replace_fst.states():
                    for replace_arc in replace_fst.arcs(replace_state):
                        new_fst.add_arc(
                            state_map[(r, replace_state)],
                            fst.Arc(
                                new_input_symbols.add_symbol(
                                    replace_input_symbols.find(
                                        replace_arc.ilabel)),
                                new_output_symbols.add_symbol(
                                    replace_output_symbols.find(
                                        replace_arc.olabel)),
                                weight_one,
                                state_map[(r, replace_arc.nextstate)],
                            ),
                        )

                # Create arc into start state (new_state -> eps -> inner start)
                new_fst.add_arc(
                    new_state,
                    fst.Arc(in_eps, out_eps, weight_one,
                            state_map[(r, replace_fst.start())]),
                )
            else:
                # Copy arc as-is
                new_fst.add_arc(
                    new_state,
                    fst.Arc(
                        input_symbol_map[outer_arc.ilabel],
                        output_symbol_map[outer_arc.olabel],
                        weight_one,
                        next_state,
                    ),
                )

    # Fix symbol tables
    new_fst.set_input_symbols(new_input_symbols)
    new_fst.set_output_symbols(new_output_symbols)

    return new_fst
def minimize_fst(the_fst: fst.Fst) -> fst.Fst:
    """Minimize an FST by shelling out to the fstminimize binary.

    BUG: Fst.minimize does not pass allow_nondet through, so we have to call
    out to the command-line tool instead of the Python binding.
    """
    serialized = the_fst.write_to_string()
    minimized_bytes = subprocess.check_output(
        ["fstminimize", "--allow_nondet"], input=serialized
    )
    return fst.Fst.read_from_string(minimized_bytes)