def all_suffixes(self, fsa: pynini.Fst) -> pynini.Fst:
    """Builds an acceptor for every suffix of the strings accepted by `fsa`.

    An epsilon arc (label 0, weight one) is added from the start state to
    every other state, so a path may begin at any state and must still end
    in one of the original final states. The argument is not modified; the
    work is done on a copy.

    Args:
        fsa: FST whose suffixes are wanted.

    Returns:
        A new, optimized FST built from the trimmed copy of `fsa`.
    """
    fsa = fsa.copy()
    # Trim first: a state that cannot be reached from the start state is not
    # on any accepted path, so jumping into it would admit strings that are
    # not suffixes of anything `fsa` accepts.
    fsa.connect()
    # Read the start state only after trimming, which may renumber states.
    start_state = fsa.start()
    # The weight is the same for every arc; build it once.
    one = pynini.Weight.one(fsa.weight_type())
    for s in fsa.states():
        if s == start_state:
            # The start state already yields the full strings; an epsilon
            # self-loop here would be redundant.
            continue
        fsa.add_arc(start_state, pynini.Arc(0, 0, one, s))
    return fsa.optimize()
def _flip_lemmatizer_feature_labels(self, lemmatizer: pynini.Fst) -> pynini.Fst:
    """Swaps input and output labels on the lemmatizer's feature arcs.

    The feature label set is every non-zero input label found on the arcs of
    `self.category.feature_labels`. The lemmatizer's input symbol table is
    replaced by its output symbol table, and each arc whose output label is
    a feature label gets its input and output labels exchanged.

    Destructive operation: `lemmatizer` is modified in place.

    Args:
        lemmatizer: FST representing a partially constructed lemmatizer.

    Returns:
        The same `lemmatizer` object, after modification.
    """
    feature_fst = self.category.feature_labels
    # Label 0 is skipped, so it never counts as a feature.
    feature_ids = {
        feature_arc.ilabel
        for state in feature_fst.states()
        for feature_arc in feature_fst.arcs(state)
        if feature_arc.ilabel
    }

    lemmatizer.set_input_symbols(lemmatizer.output_symbols())

    for state in lemmatizer.states():
        # Explicit cursor protocol: set_value() must act on the arc the
        # cursor currently points at, before advancing.
        cursor = lemmatizer.mutable_arcs(state)
        while not cursor.done():
            current = cursor.value()
            if current.olabel in feature_ids:
                # This assertion should always be true by construction.
                assert current.ilabel == 0, (
                    f"ilabel = "
                    f"{lemmatizer.input_symbols().find(current.ilabel)},"
                    f" olabel = "
                    f"{lemmatizer.output_symbols().find(current.olabel)}")
                flipped = pynini.Arc(
                    current.olabel,
                    current.ilabel,
                    current.weight,
                    current.nextstate,
                )
                cursor.set_value(flipped)
            cursor.next()
    return lemmatizer
def _narcs(f: pynini.Fst) -> int:
    """Returns the total number of arcs in an FST, summed over its states."""
    total = 0
    for state in f.states():
        total += f.num_arcs(state)
    return total
def all_prefixes(self, fsa: pynini.Fst) -> pynini.Fst:
    """Builds an acceptor for every prefix of the strings accepted by `fsa`.

    Every state of a trimmed copy is marked final, so a path may stop at any
    point along an accepted path. The argument is not modified; the work is
    done on a copy.

    Args:
        fsa: FST whose prefixes are wanted.

    Returns:
        A new, optimized FST built from the trimmed copy of `fsa`.
    """
    fsa = fsa.copy()
    # Trim first: a state that cannot reach a final state is not on any
    # accepted path, so making it final would admit strings that are not
    # prefixes of anything `fsa` accepts.
    fsa.connect()
    for s in fsa.states():
        fsa.set_final(s)
    return fsa.optimize()