def fstprintall(
    in_fst: fst.Fst,
    out_file: Optional[TextIO] = None,
    exclude_meta: bool = True,
    eps: str = "<eps>",
) -> List[List[str]]:
    """Collect the output-symbol sentences reachable in ``in_fst``.

    The FST is walked breadth-first from its start state. Whenever a state
    whose final weight differs from the semiring zero is dequeued, the output
    symbols gathered on the way to it form one sentence.

    Each state is expanded at most once, which makes the walk terminate on
    cyclic FSTs. As a consequence, when several paths lead to the same state,
    only the first path dequeued contributes sentences through that state.

    :param in_fst: transducer to walk; its output symbol table is used to
        turn output labels into text
    :param out_file: when given, each sentence is printed to it as one
        space-joined line instead of being collected
    :param exclude_meta: when true, output symbols starting with ``"__"``
        are left out of the sentences
    :param eps: text of the epsilon symbol in the output symbol table; arcs
        carrying that output label add nothing to a sentence
    :returns: the collected sentences, each a list of output symbols; empty
        when ``out_file`` is given or the FST has no start state
    """
    sentences: List[List[str]] = []

    start_state = in_fst.start()
    if start_state == -1:
        # No start state means nothing is reachable; bail out before asking
        # the FST about a state id that does not exist.
        return sentences

    output_symbols = in_fst.output_symbols()
    out_eps = output_symbols.find(eps)
    zero_weight = fst.Weight.Zero(in_fst.weight_type())

    visited_states = set()
    state_queue = deque([(start_state, [])])

    while state_queue:
        state, sentence = state_queue.popleft()
        if state in visited_states:
            continue
        visited_states.add(state)

        # A final weight other than semiring zero marks an accepting state.
        if in_fst.final(state) != zero_weight:
            if out_file is not None:
                print(" ".join(sentence), file=out_file)
            else:
                sentences.append(sentence)

        for arc in in_fst.arcs(state):
            # Every outgoing arc extends its own copy of the sentence so far.
            arc_sentence = list(sentence)
            if arc.olabel != out_eps:
                out_symbol = output_symbols.find(arc.olabel)
                # NOTE(review): the lookup was originally always decoded, so
                # it presumably returns bytes; str results are passed through
                # untouched in case the binding in use returns text.
                if isinstance(out_symbol, bytes):
                    out_symbol = out_symbol.decode()
                # Symbols such as __label__ are meta markers, not words.
                if not (exclude_meta and out_symbol.startswith("__")):
                    arc_sentence.append(out_symbol)
            state_queue.append((arc.nextstate, arc_sentence))

    return sentences
def _all_valid_strings(td: fst.Fst) -> List[Tuple[List[int], float]]:
    """Enumerate the emission language of a transducer.

    Roughly what fstprint would give, except that output labels are kept as
    their interned integer ids rather than being turned back into strings.
    Output labels of all arcs on a path are recorded as-is.

    The weight paired with each emission is the final weight of the state the
    path ends in, not the accumulated weight of the whole path. A state counts
    as final when that weight is finite.

    No duplicate filtering and no cycle detection is performed; a cycle
    reachable from the start state keeps the traversal from terminating.

    :param td: transducer, the emission language of which to enumerate
    :returns: a list of (interned emission symbols, weight) tuples
    """
    start_state = td.start()
    if start_state == -1:
        # A transducer without a start state emits nothing.
        return []

    emissions = []
    # Depth-first walk: each entry is (state, output labels emitted so far).
    pending = [(start_state, [])]
    while pending:
        current, emitted = pending.pop()

        weight = float(td.final(current))
        if np.isfinite(weight):
            emissions.append((emitted, weight))

        for arc in td.arcs(current):
            pending.append((arc.nextstate, emitted + [arc.olabel]))

    return emissions
def _replace_fsts(
    outer_fst: fst.Fst, replacements: Dict[int, fst.Fst], eps: str = "<eps>"
) -> fst.Fst:
    """Build a copy of ``outer_fst`` with selected arcs expanded into sub-FSTs.

    Every arc of ``outer_fst`` whose output label is a key of ``replacements``
    is swapped for a fresh copy of the mapped FST: an epsilon arc leads from
    the arc's source state to the copy's start state, and each final state of
    the copy gets an epsilon arc to the arc's original destination. All other
    arcs are copied with their labels remapped into the new symbol tables.

    Weights are not carried over: every arc of the result has weight one and
    final states are marked with ``set_final``'s default weight.

    :param outer_fst: FST to copy; must have input and output symbol tables
    :param replacements: maps an output label of ``outer_fst`` to the FST
        that should be spliced in for arcs carrying that label; each of
        these FSTs must have input and output symbol tables and is expected
        to have a start state (a missing one presumably surfaces as a
        ``KeyError`` when the entry arc is created)
    :param eps: text of the epsilon symbol used for the connecting arcs
    :returns: a new FST with freshly built input and output symbol tables;
        it has no start state when ``outer_fst`` has none
    """
    input_symbol_map: Dict[int, int] = {}
    output_symbol_map: Dict[int, int] = {}
    # Keys are outer state ids, or (replaced label, replacement state id)
    # pairs for the states of the copy currently being spliced in.
    state_map: Dict[Union[int, Tuple[int, int]], int] = {}

    # Create new FST
    new_fst = fst.Fst()
    new_input_symbols = fst.SymbolTable()
    new_output_symbols = fst.SymbolTable()

    weight_one = fst.Weight.One(new_fst.weight_type())
    outer_zero = fst.Weight.Zero(outer_fst.weight_type())

    # Copy symbols, remembering how each old key maps to its new key
    outer_input_symbols = outer_fst.input_symbols()
    for i in range(outer_input_symbols.num_symbols()):
        key = outer_input_symbols.get_nth_key(i)
        input_symbol_map[key] = new_input_symbols.add_symbol(
            outer_input_symbols.find(key)
        )

    outer_output_symbols = outer_fst.output_symbols()
    for i in range(outer_output_symbols.num_symbols()):
        key = outer_output_symbols.get_nth_key(i)
        output_symbol_map[key] = new_output_symbols.add_symbol(
            outer_output_symbols.find(key)
        )

    # NOTE(review): eps is registered after the copied symbols, so it
    # presumably only ends up with label 0 (the label OpenFst treats as
    # epsilon) when the outer tables already list it first - confirm that
    # callers guarantee this.
    in_eps = new_input_symbols.add_symbol(eps)
    out_eps = new_output_symbols.add_symbol(eps)

    # Copy states; a final weight other than semiring zero marks a final state
    for outer_state in outer_fst.states():
        new_state = new_fst.add_state()
        state_map[outer_state] = new_state
        if outer_fst.final(outer_state) != outer_zero:
            new_fst.set_final(new_state)

    # Set start state (an FST without one reports -1, which is not a state)
    outer_start = outer_fst.start()
    if outer_start != -1:
        new_fst.set_start(state_map[outer_start])

    # Copy arcs
    for outer_state in outer_fst.states():
        new_state = state_map[outer_state]
        for outer_arc in outer_fst.arcs(outer_state):
            next_state = state_map[outer_arc.nextstate]
            replace_fst = replacements.get(outer_arc.olabel)

            if replace_fst is None:
                # Copy arc as-is
                new_fst.add_arc(
                    new_state,
                    fst.Arc(
                        input_symbol_map[outer_arc.ilabel],
                        output_symbol_map[outer_arc.olabel],
                        weight_one,
                        next_state,
                    ),
                )
                continue

            # Replace in-line. The (r, state) entries are overwritten for
            # every arc carrying label r, so each arc gets its own copy.
            r = outer_arc.olabel
            replace_zero = fst.Weight.Zero(replace_fst.weight_type())
            replace_input_symbols = replace_fst.input_symbols()
            replace_output_symbols = replace_fst.output_symbols()

            # Copy states
            for replace_state in replace_fst.states():
                state_map[(r, replace_state)] = new_fst.add_state()

                # Create final arc to next state
                if replace_fst.final(replace_state) != replace_zero:
                    new_fst.add_arc(
                        state_map[(r, replace_state)],
                        fst.Arc(in_eps, out_eps, weight_one, next_state),
                    )

            # Copy arcs, registering the replacement's symbols on the fly
            for replace_state in replace_fst.states():
                for replace_arc in replace_fst.arcs(replace_state):
                    new_fst.add_arc(
                        state_map[(r, replace_state)],
                        fst.Arc(
                            new_input_symbols.add_symbol(
                                replace_input_symbols.find(replace_arc.ilabel)
                            ),
                            new_output_symbols.add_symbol(
                                replace_output_symbols.find(replace_arc.olabel)
                            ),
                            weight_one,
                            state_map[(r, replace_arc.nextstate)],
                        ),
                    )

            # Create arc into start state
            new_fst.add_arc(
                new_state,
                fst.Arc(
                    in_eps,
                    out_eps,
                    weight_one,
                    state_map[(r, replace_fst.start())],
                ),
            )

    # Fix symbol tables
    new_fst.set_input_symbols(new_input_symbols)
    new_fst.set_output_symbols(new_output_symbols)

    return new_fst