def load_lexicon(source, symbol_table):
    '''
    Load lexica entries from source interpreting them using a given symbol table.

    Every non-blank line of ``source`` is one entry. It is split into tokens
    ``x`` or ``x:y``, where ``x``/``y`` is either a single character or an
    angle-bracketed symbol such as ``<Suff_Stems>``. A plain token is compiled
    as an acceptor and an ``x:y`` token as a transducer from ``x`` to ``y``.
    The tokens of one entry are concatenated in order, and all entries are
    unioned.

    Args:
        source: iterable of text lines (e.g. an open file).
        symbol_table: symbol table used to compile every token and attached
            as input/output table of the result.

    Returns:
        A pynini.Fst for the union of all entries.
    '''
    # Angle-bracket symbols are tried before single characters, so a complex
    # symbol wins over its individual characters (longest match).
    token_pattern = re.compile("(<[^>]*>|.)(?::(<[^>]*>|.))?", re.U)

    def _labelled(fst):
        # Attach the lexicon symbol table to both tapes and hand the FST back.
        fst.set_input_symbols(symbol_table)
        fst.set_output_symbols(symbol_table)
        return fst

    lexicon = _labelled(pynini.Fst())
    for raw_line in source:
        entry = raw_line.strip()
        if not entry:
            continue
        # Start from the one-state FST that accepts only the empty string.
        path = _labelled(pynini.Fst())
        origin = path.add_state()
        path.set_start(origin)
        path.set_final(origin)
        for in_sym, out_sym in token_pattern.findall(entry):
            if out_sym:
                piece = pynini.transducer(in_sym, out_sym,
                                          input_token_type=symbol_table,
                                          output_token_type=symbol_table)
            else:
                piece = pynini.acceptor(in_sym, token_type=symbol_table)
            path = pynini.concat(path, piece)
        lexicon = pynini.union(lexicon, path)
    return lexicon
def to_wfst(self, recording, optimize=False):
    """Build a weighted FST over the per-frame label posteriors of ``recording``.

    ``self.predict_raw(recording)`` is indexed below as ``[0, frame, column]``,
    and only batch item 0 is used. Presumably its shape is (1, frames, labels);
    TODO confirm.

    For every frame, one state per label is added. Every state of the previous
    frame gets an arc to each of the new states:

    * input label: the column index, except that column 0 is used with label
      0 (epsilon);
    * weight: ``-log`` of that posterior;
    * output label: the label of the state being left when it differs from
      the arc's input label, otherwise epsilon, so consecutive repeats
      collapse.

    The states of the last frame are final.

    NOTE(review): a label is emitted only when its run is left, so the label
    of the last run before a final state never appears on the output side;
    confirm this is intended.

    Args:
        recording: input passed unchanged to ``self.predict_raw``.
        optimize: if True, ``fst.optimize()`` is applied before returning.
            This flag was referenced but never defined in this function.

    Returns:
        The constructed ``pynini.Fst``.
    """
    epsilon = 0
    posteriors = self.predict_raw(recording)
    num_frames = posteriors.shape[1]
    num_labels = posteriors.shape[2]
    # Columns 1..N-1 keep their own label; column 0 is visited last and used
    # as epsilon. This keeps the original per-frame state order.
    labels = (list(range(1, num_labels)) + [epsilon]) if num_labels else []
    # Negative log posteriors of batch item 0. This is loop-invariant, so it
    # is computed once rather than once per frame.
    costs = -np.log(posteriors[0])

    fst = pynini.Fst()
    initial = fst.add_state()
    fst.set_start(initial)
    # (state, label of that state) pairs for the previous frame.
    heads = [(initial, epsilon)]
    for frame in range(num_frames):
        states = [fst.add_state() for _ in labels]
        for source_state, previous_label in heads:
            for label, target_state in zip(labels, states):
                output_label = (previous_label if previous_label != label
                                else epsilon)
                fst.add_arc(source_state,
                            pynini.Arc(label, output_label,
                                       costs[frame, label], target_state))
        # Store the label actually used on the arcs (0 for the epsilon
        # column). Storing N here would emit the out-of-range label N when
        # leaving that state.
        heads = list(zip(states, labels))
    for state, _ in heads:
        fst.set_final(state)
    if optimize:
        fst.optimize()
    return fst
def __call__(self, graphemes: str) -> List[str]:  # pragma: no cover
    """Call the rewrite function.

    Builds a lattice with states ``0..len(graphemes)`` (start 0, final
    ``len(graphemes)``). For every substring ``graphemes[i:i + j]`` with
    ``1 <= j <= self.grapheme_order``, an identity arc ``i -> i + j`` is
    added if its ``seq_sep``-joined form is in ``self.input_token_type``.
    The lattice is then passed to ``self.rewrite``.

    Returns:
        The hypotheses from ``self.rewrite``, with empty ones dropped and
        ``seq_sep`` replaced by a single space.
    """
    fst = pynini.Fst()
    one = pynini.Weight.one(fst.weight_type())
    num_positions = len(graphemes)
    # Create every state up front, so set_final() below always refers to an
    # existing state. This holds even when no substring ending at the last
    # position is in the symbol table.
    fst.add_states(num_positions + 1)
    for i in range(num_positions):
        for j in range(1, min(self.grapheme_order, num_positions - i) + 1):
            substring = self.seq_sep.join(graphemes[i:i + j])
            ilabel = self.input_token_type.find(substring)
            if ilabel != pynini.NO_LABEL:
                fst.add_arc(i, pynini.Arc(ilabel, ilabel, one, i + j))
    fst.set_start(0)
    fst.set_final(num_positions, one)
    fst.set_input_symbols(self.input_token_type)
    fst.set_output_symbols(self.input_token_type)
    hypotheses = self.rewrite(fst)
    return [x.replace(self.seq_sep, " ") for x in hypotheses if x]
def example1():
    """Build, print, write and re-read the three-state FST from the OpenFst tour.

    ref: http://www.openfst.org/twiki/bin/view/FST/FstQuickTour#CreatingShellFsts

    The FST is written to ``example1.fst`` in the current working directory
    and read back. The Python-2 ``print`` statements of the original were a
    SyntaxError under Python 3; they are now ``print()`` calls with the same
    output.
    """
    Arc = pynini.Arc
    # A vector FST is a general mutable FST
    example = pynini.Fst()
    example.add_state()  # 1st state will be state 0 (returned by AddState)
    example.set_start(0)  # arg is state ID
    # Adds two arcs exiting state 0.
    # Arc constructor args: ilabel, olabel, weight, dest state ID.
    example.add_arc(0, Arc(1, 1, 0.5, 1))  # 1st arg is src state ID
    example.add_arc(0, Arc(2, 2, 1.5, 1))
    # Adds state 1 and its arc.
    example.add_state()
    example.add_arc(1, Arc(3, 3, 2.5, 2))
    # Adds state 2 and set its final weight.
    example.add_state()
    example.set_final(2, 3.5)  # 1st arg is state ID, 2nd arg weight
    print("example1=")
    print(example)
    # next works, but the output is in a binary format
    filename = "example1.fst"
    print("writing", filename)
    example.write(filename)
    print("reading", filename)
    exfile = pynini.Fst.read(filename)
    print("printing fst read from file")
    print(exfile)
def _label_list_to_string_fsa(labels: List[int]) -> pynini.Fst:
    """Return the optimized string FSA that accepts exactly ``labels`` in order.

    State ``i`` has one arc labelled ``labels[i]`` on both tapes, with weight
    tropical One, to state ``i + 1``. State ``len(labels)`` is final.
    """
    last = len(labels)
    fsa = pynini.Fst()
    fsa.add_states(last + 1)
    fsa.set_start(0)
    fsa.set_final(last)
    one = pynini.Weight.one("tropical")
    for src, label in enumerate(labels):
        fsa.add_arc(src, pynini.Arc(label, label, one, src + 1))
    return fsa.optimize()
def convert_to_automaton(self, keys: List[int]) -> 'Automaton':
    """Compile ``keys`` into a linear acceptor over ``self.symbol_table``.

    Requires ``self._compiled`` to be truthy (checked with ``assert``). The
    result has ``len(keys) + 1`` states in a chain: the first is the start,
    the last is final, and arc ``i`` carries ``keys[i]`` on both tapes.

    NOTE(review): the annotation says 'Automaton', but a pynini.Fst is
    returned.
    """
    assert self._compiled
    chain = pynini.Fst()
    chain.set_input_symbols(self.symbol_table)
    chain.set_output_symbols(self.symbol_table)
    states = [chain.add_state() for _ in range(len(keys) + 1)]
    chain.set_start(states[0])
    chain.set_final(states[-1])
    for position, key in enumerate(keys):
        # Weight None: presumably interpreted by pynini as semiring One.
        chain.add_arc(states[position],
                      pynini.Arc(key, key, None, states[position + 1]))
    return chain
def recombine_windows(window_fsts):
    '''
    Recombine processed window FSTs (containing hypotheses for a given
    window) to a lattice, which is also represented as an FST.

    A skeleton FST is built with states ``0..num_tokens`` (start 0, final
    ``num_tokens``). It has one arc ``pos -> pos + length`` per window, whose
    output label is the nonterminal ``WIN-pos-length``. ``pynini.replace``
    then substitutes each nonterminal with its window FST. Every window that
    does not end at the last token is followed by a space.

    Args:
        window_fsts: dict mapping ``(pos, length)`` to the FST of hypotheses
            for that window. The FSTs are NOT modified; the original appended
            a space to them in place.

    Returns:
        The optimized lattice FST.

    Raises:
        ValueError: if ``window_fsts`` is empty.

    NOTE(review): the merged symbol table is only attached to the skeleton;
    the window FSTs keep their own labels. Confirm that their tables agree.
    '''
    def _label(pos, length):
        return 'WIN-{}-{}'.format(pos, length)

    if not window_fsts:
        raise ValueError('window_fsts must not be empty')

    t1 = time.time()
    space_tr = pynini.acceptor(' ')

    # determine the input string length
    # (TODO without iterating!!!)
    num_tokens = max(pos for (pos, _) in window_fsts) + 1

    # Skeleton: states 0..num_tokens, start 0, final num_tokens.
    root = pynini.Fst()
    for _ in range(num_tokens + 1):
        root.add_state()
    root.set_start(0)
    root.set_final(num_tokens, 0)

    # FIXME refactor the merging of symbol tables into a separate function
    symbol_table = pynini.SymbolTable()
    for window_fst in window_fsts.values():
        symbol_table = pynini.merge_symbol_table(
            symbol_table, window_fst.input_symbols())
        symbol_table = pynini.merge_symbol_table(
            symbol_table, window_fst.output_symbols())
    # One nonterminal symbol per window, replaced by its FST below.
    for (pos, length) in window_fsts:
        symbol_table.add_symbol(_label(pos, length))
    root.set_input_symbols(symbol_table)
    root.set_output_symbols(symbol_table)

    replacements = []
    for (pos, length), window_fst in window_fsts.items():
        label = _label(pos, length)
        sym = root.output_symbols().find(label)
        if pos + length < num_tokens:
            # Append a space if this is not the last token, so that the final
            # string consists of tokens separated by spaces. Work on a copy so
            # the caller's FST is left untouched.
            window_fst = window_fst.copy()
            window_fst.concat(space_tr)
        replacements.append((label, window_fst))
        root.add_arc(pos, pynini.Arc(0, sym, 0, pos + length))

    result = pynini.replace(root, replacements)
    result.optimize()

    t2 = time.time()
    logging.debug('Recombining time: %ss', t2 - t1)
    return result
def BuildSigmaFstFromSymbolTable(syms: pynini.SymbolTableView) -> pynini.Fst:
    """Return a two-state FSA with one identity arc per label of ``syms``.

    Every (label, symbol) pair yielded by ``syms`` becomes an arc
    ``label:label`` with weight tropical One, from the start state to the
    single final state. This includes label 0 if the table yields it.
    """
    sigma = pynini.Fst()
    source = sigma.add_state()
    sink = sigma.add_state()
    sigma.set_start(source)
    sigma.set_final(sink)
    one = pynini.Weight.one("tropical")
    for label in (key for key, _ in syms):
        sigma.add_arc(source, pynini.Arc(label, label, one, sink))
    return sigma
def construct_any(symbol_table):
    '''
    Return an FST for Sigma*.

    A single state is both start and final. It has an identity self-loop for
    every symbol of ``symbol_table``.

    NOTE(review): each loop carries weight 1, which is not the tropical One
    (0.0), so every consumed symbol adds a cost of 1. The loop for label 0
    (if present) is also an epsilon cycle. Confirm both are intended.
    '''
    sigma_star = pynini.Fst()
    state = sigma_star.add_state()
    sigma_star.set_start(state)
    sigma_star.set_final(state)
    cursor = pynini.SymbolTableIterator(symbol_table)
    while not cursor.done():
        label = symbol_table.find(cursor.symbol())
        sigma_star.add_arc(state, pynini.Arc(label, label, 1, state))
        cursor.next()
    return sigma_star
def _label_union(labels: Set[int], epsilon: bool) -> pynini.Fst:
    """Creates FSA over a union of the labels.

    Args:
        labels: labels to accept. The caller's set is not modified; the
            original added 0 to it in place.
        epsilon: if True, label 0 (epsilon) is accepted as well.

    Returns:
        A two-state FSA with one identity arc per label from the start state
        to the final state.
    """
    if epsilon:
        # Build a new set rather than mutating the argument.
        labels = set(labels) | {0}
    side = pynini.Fst()
    src = side.add_state()
    side.set_start(src)
    dst = side.add_state()
    for label in labels:
        side.add_arc(src, pynini.Arc(label, label, None, dst))
    side.set_final(dst)
    assert side.verify(), "FST is ill-formed"
    return side
def _make_fst(self, string, symbols):
    """Return a linear acceptor spelling ``string``, one arc per character.

    Arc labels are ``symbols.find(char)`` with weight 0. ``symbols`` is
    attached as both input and output table.

    NOTE(review): ``find`` presumably returns a negative "no label" value for
    characters missing from ``symbols``, which would make the FST ill-formed.
    Callers are expected to have added every character first.
    """
    labels = [symbols.find(char) for char in string]
    fst = py.Fst()
    states = [fst.add_state() for _ in range(len(labels) + 1)]
    fst.set_start(states[0])
    for src, dst, label in zip(states, states[1:], labels):
        fst.add_arc(src, py.Arc(label, label, 0, dst))
    fst.set_final(states[-1])
    fst.set_input_symbols(symbols)
    fst.set_output_symbols(symbols)
    return fst
def expand(self, token: pynini.FstLike) -> pynini.Fst:
    """Finds regexps candidates for a token.

    Args:
        token: a "zomggg"-like token.

    Returns:
        The lattice produced by rewriting ``token`` with ``self._regexps``,
        or an empty FST if ``rewrite.rewrite_lattice`` raises
        ``rewrite.Error``.
    """
    try:
        lattice = rewrite.rewrite_lattice(token, self._regexps)
    except rewrite.Error:
        lattice = pynini.Fst()
    return lattice
def expand(self, token: pynini.FstLike) -> pynini.Fst:
    """Finds deduplication candidates for a token in a lexicon.

    ``token`` is first rewritten with ``self._dedup``. The resulting lattice
    is then rewritten with ``self._lexicon``.

    Args:
        token: a "cooooool"-like token.

    Returns:
        The lexicon-filtered lattice, or an empty FST if either rewrite
        raises ``rewrite.Error``.
    """
    try:
        deduplicated = rewrite.rewrite_lattice(token, self._dedup)
        lattice = rewrite.rewrite_lattice(deduplicated, self._lexicon)
    except rewrite.Error:
        lattice = pynini.Fst()
    return lattice
def __suff_stems_filter(self, features):
    '''
    Return a union over filters for each feature given.

    The filter for a feature deletes that feature, accepts ``<Suff_Stems>``,
    then deletes the feature again. Strings are compiled with
    ``self.__syms.alphabet`` as the default token type. The union is
    optimized before being returned.
    '''
    with pynini.default_token_type(self.__syms.alphabet):
        union_fst = pynini.Fst()
        union_fst.set_input_symbols(self.__syms.alphabet)
        union_fst.set_output_symbols(self.__syms.alphabet)
        marker = pynini.accep("<Suff_Stems>")
        for feature in features:
            drop_feature = pynini.cross(feature, "")
            branch = drop_feature + marker + drop_feature
            union_fst = pynini.union(union_fst, branch)
        return union_fst.optimize()
def __suff_stems_filter(self, features):
    '''
    Return a union over filters for each feature given.

    Older-pynini variant (``acceptor``/``transducer``). Each filter deletes
    the feature, accepts ``<Suff_Stems>`` and deletes the feature again. All
    strings are compiled against ``self.__syms.alphabet``.
    '''
    union_fst = pynini.Fst()
    union_fst.set_input_symbols(self.__syms.alphabet)
    union_fst.set_output_symbols(self.__syms.alphabet)
    marker = pynini.acceptor("<Suff_Stems>", token_type=self.__syms.alphabet)
    branches = []
    for feature in features:
        drop_feature = pynini.transducer(
            feature, "", input_token_type=self.__syms.alphabet)
        branches.append(pynini.concat(drop_feature, marker, drop_feature))
    for branch in branches:
        union_fst = pynini.union(union_fst, branch)
    return union_fst.optimize()
def combine_error_transducers(transducers, max_context, max_errors):
    """Combine single-error transducers into one transducer per setting.

    A setting is a pair (context range, number of errors). For every
    contiguous range ``[m..n]`` of context sizes (``1 <= m <= n <=
    max_context``), E is the union of ``transducers[k - 1]`` for k in the
    range. For every ``k`` in ``1..max_errors``, the transducer
    ``(Sigma* E){0,k} Sigma*`` is built and arc-sorted.

    NOTE: every transducer in ``transducers`` is relabelled in place to the
    merged symbol table.

    Args:
        transducers: list where ``transducers[n - 1]`` is the error transducer
            for context size ``n``.
        max_context: largest context size considered.
        max_errors: largest number of errors allowed.

    Returns:
        A list of dicts with keys ``'max_error'``, ``'context'`` (the context
        sizes concatenated as a string, e.g. ``'12'``) and ``'transducer'``.
    """
    def _universal_acceptor(symbol_table):
        # One-state machine with an identity self-loop for every non-epsilon
        # label of symbol_table, i.e. Sigma*.
        sigma_star = pynini.epsilon_machine()
        sigma_star.set_input_symbols(symbol_table)
        sigma_star.set_output_symbols(symbol_table)
        for label, _ in symbol_table:
            if label > 0:
                sigma_star.add_arc(0, pynini.Arc(label, label, 0, 0))
        return sigma_star

    # Same order as two nested loops: outer upper bound, inner lower bound.
    contexts = [list(range(low, high + 1))
                for high in range(1, max_context + 1)
                for low in range(1, high + 1)]

    # FIXME refactor the merging of symbol tables into a separate function
    symtab = pynini.SymbolTable()
    for t in transducers:
        symtab = pynini.merge_symbol_table(symtab, t.input_symbols())
        symtab = pynini.merge_symbol_table(symtab, t.output_symbols())
    for t in transducers:
        t.relabel_tables(new_isymbols=symtab, new_osymbols=symtab)
    acceptor = _universal_acceptor(symtab)

    combined = []
    for context in contexts:
        print('Context: ', context)
        one_error = pynini.Fst()
        for size in context:
            one_error.union(transducers[size - 1])
        context_key = ''.join(map(str, context))
        for num_errors in range(1, max_errors + 1):
            print('Number of errors:', num_errors)
            result_tr = acceptor.copy()
            result_tr.concat(one_error)
            result_tr.closure(0, num_errors)
            result_tr.concat(acceptor)
            result_tr.arcsort()
            combined.append({
                'max_error': num_errors,
                'context': context_key,
                'transducer': result_tr,
            })
    return combined
def generate_fst_for_factor_digit(factor, include_zero=False):
    """Return a transducer rewriting each decimal digit with a caret suffix.

    For ``factor > 0``, digit ``d`` maps to ``d`` followed by ``factor``
    carets and a space (e.g. ``"7" -> "7^^ "`` for ``factor == 2``). For
    ``factor <= 0``, every digit maps to itself.

    Args:
        factor: number of carets appended after the digit.
        include_zero: currently unused; kept for backward compatibility.
            A commented-out earlier variant deleted ``0`` when this was
            False.

    Returns:
        The optimized union of the ten digit mappings.
    """
    carets = ('^' * factor + ' ') if factor > 0 else ''
    digit_fsts = [pn.t(str(num), str(num) + carets) for num in range(10)]
    # Optimize once after building the union, rather than after every digit.
    return pn.union(*digit_fsts).optimize()
def build(self, dataset):
    """Build ``self.chain``, a pronunciation FST over the transcript words.

    NOTE(review): this is an unfinished stub and cannot run as written:
    * ``transcript_words`` and ``transcript`` are not defined in this block;
    * ``dataset`` is never used;
    * ``start_state`` is the ``...`` placeholder and is never used;
    * ``Arc()`` is called without arguments (and ``Arc`` is not imported
      here);
    * ``pynini.Fst`` does not appear to provide ``add_symbol``; confirm the
      intended API.
    """
    # get vocabulary if needed
    # get words from an adapter
    # get transcript words
    self.chain = pynini.Fst()
    self.word_symbols = {}
    for word in transcript_words:
        state = self.chain.add_symbol()
    print("Getting word transcriptions")
    for sentence in tqdm.tqdm(transcript):
        for word in sentence:
            # presumably each word holds one or more phoneme transcriptions
            for word_trans in word:
                start_state = ...
                for phon in word_trans:
                    # add arcs per each phoneme
                    self.chain.add_arc(Arc())
                # add epsilons to the dummy state for new word
    print("Optimizing the model")
    # minimize()
    self.chain.determinize()
def make_lattice(tokens: List[str], key_table: KeyTable) -> pynini.Fst:
    """Creates a lattice from a list of tokens.

    The lattice is an unweighted FSA with ``len(tokens) + 1`` states (a
    "string FSA"). For the token at position ``i``, every index in
    ``key_table[get_key(token)]`` becomes an arc ``i -> i + 1`` labelled with
    that index. Each such index is an in-vocabulary word that represents a
    possible unscrambling of the token.
    """
    num_tokens = len(tokens)
    lattice = pynini.Fst()
    lattice.add_states(num_tokens + 1)
    lattice.set_start(0)
    lattice.set_final(num_tokens)
    for position, token in enumerate(tokens):
        candidates = key_table[get_key(token)]
        following = position + 1
        for index in candidates:
            # Weight 0 is the tropical One, so the arc is unweighted.
            lattice.add_arc(position, pynini.Arc(index, index, 0, following))
    assert lattice.verify(), "ill-formed lattice"
    return lattice
def example1a():
    """Variant of example1 that tries to use letters as arc labels.

    According to the original notes this fails with
    ``TypeError: an integer is required``, because ``pynini.Arc`` expects
    integer labels. The prints at the end are therefore presumably never
    reached. The Python-2 ``print`` statement is now a ``print()`` call, so
    the function at least compiles under Python 3.
    """
    Arc = pynini.Arc
    # A vector FST is a general mutable FST
    example = pynini.Fst()
    example.add_state()  # 1st state will be state 0 (returned by AddState)
    example.set_start(0)  # arg is state ID
    # Adds two arcs exiting state 0.
    # Arc constructor args: ilabel, olabel, weight, dest state ID.
    example.add_arc(0, Arc('a', 'x', 0.5, 1))  # 1st arg is src state ID
    example.add_arc(0, Arc('b', 'y', 1.5, 1))
    # Adds state 1 and its arc.
    example.add_state()
    example.add_arc(1, Arc('c', 'z', 2.5, 2))
    # Adds state 2 and set its final weight.
    example.add_state()
    example.set_final(2, 3.5)  # 1st arg is state ID, 2nd arg weight
    print("example1a=")
    print(example)
def setUpClass(cls):
    """Build the shared verb paradigm used by the tests of this class.

    The paradigm has four aspect slots:
    * root: the bare stem;
    * dubitative: ``+al``;
    * gerundial: ``+inay`` on the stem composed with the CVCC template;
    * durative: ``+ʔaa`` on the stem composed with the CVCVVC template.
    """
    super().setUpClass()
    # Not clear "aspect" is exactly the right concept.
    aspect = features.Feature("aspect", "root", "dubitative", "gerundial",
                              "durative")
    verb = features.Category(aspect)
    root = features.FeatureVector(verb, "aspect=root")
    stem = paradigms.make_byte_star_except_boundary()

    # Short names for space reasons: v = any vowel, c = any consonant.
    vowels = ("a", "i", "o", "u")
    v = pynini.union(*vowels)
    c = pynini.union("c", "m", "h", "l", "y", "k", "ʔ", "d", "n", "w", "t")

    # First template: apply Procrustean transformation to CVCC^?.
    cvcc = (c + v + pynutil.delete(v).ques + c + pynutil.delete(v).star +
            c.ques).optimize()

    # Second template: apply Procrustean transformation to CVCVVC^?. It
    # involves copying vowels, which is easiest done with one branch per
    # vowel.
    cvcvvc = pynini.Fst()
    for vowel in vowels:
        maybe_drop = pynutil.delete(vowel).ques
        drop_rest = pynutil.delete(vowel).star
        doubled = pynutil.insert(vowel + vowel)
        cvcvvc.union(c + vowel + maybe_drop + c + drop_rest + doubled +
                     c.ques)
    cvcvvc.optimize()

    slots = [
        (stem, root),
        (paradigms.suffix("+al", stem),
         features.FeatureVector(verb, "aspect=dubitative")),
        (paradigms.suffix("+inay", stem @ cvcc),
         features.FeatureVector(verb, "aspect=gerundial")),
        (paradigms.suffix("+ʔaa", stem @ cvcvvc),
         features.FeatureVector(verb, "aspect=durative")),
    ]
    cls.paradigm = paradigms.Paradigm(
        category=verb,
        slots=slots,
        lemma_feature_vector=root,
        stems=["caw", "cuum", "hoyoo", "diiyl", "ʔilk", "hiwiit"])
def _make_fst(self) -> None:
    """Build ``self._decoder`` and its inverse ``self._encoder``.

    ``self._t9_map`` is iterated as ``(input, outputs)`` pairs. Each input is
    mapped to the union of its outputs. The decoder is the closure of the
    union of these mappings, optimized in place. The encoder is the inverted
    decoder.
    """
    self._decoder = pynini.Fst()
    for key_seq, candidates in self._t9_map:
        alternatives = pynini.union(*candidates)
        self._decoder |= pynini.cross(key_seq, alternatives)
    # closure() and optimize() mutate self._decoder in place.
    self._decoder.closure().optimize()
    self._encoder = pynini.invert(self._decoder)
def compile_transducer(mappings, ngr_probs, max_errors=3, max_context=3,
                       weight_threshold=5.0):
    """Compile a weighted identity/error transducer from n-gram mappings.

    The skeleton ("master transducer") has states ``(i, j)``, where ``i`` is
    the number of errors used so far (0..max_errors) and ``j`` is a context
    size (0..max_context). Its arcs are:
    * ``(i, 0) -> (i, j)``: epsilon with weight ``-log(ngr_probs[j - 1])``;
    * ``(i, j) -> (i, 0)``: outputs the nonterminal ``id-j``;
    * ``(i - 1, j) -> (i, 0)``: outputs the nonterminal ``err-j``.
    Start is ``(0, 0)``; every ``(i, 0)`` is final. ``pynini.replace``
    substitutes the nonterminals with string_map transducers built from
    ``mappings``. The result is inverted (tapes swapped) and optimized.

    Args:
        mappings: iterable of ``(x, y, weight)`` triples. ``x == y`` is an
            identity mapping; anything else is an error mapping. Each is filed
            under context size ``len(x)``. Presumably
            ``1 <= len(x) <= max_context``, since otherwise the dict lookup
            raises KeyError; confirm with callers.
        ngr_probs: probabilities indexed by context size - 1.
        max_errors: maximum number of error mappings along a path.
        max_context: maximum context size.
        weight_threshold: NOTE(review): not used in this function.

    Returns:
        The optimized, inverted pynini FST.
    """
    # Costs (negative log probabilities) of entering a context of size j.
    ngr_weights = -np.log(ngr_probs)
    identity_trs, error_trs = {}, {}
    identity_mappings, error_mappings = {}, {}
    for i in range(max_context):
        identity_trs[i], error_trs[i] = [], []
        identity_mappings[i], error_mappings[i] = [], []
    # Bucket each mapping by context length (index len(x) - 1).
    for x, y, weight in mappings:
        mapping = (escape_for_pynini(x), escape_for_pynini(y), str(weight))
        if x == y:
            identity_mappings[len(x) - 1].append(mapping)
        else:
            error_mappings[len(x) - 1].append(mapping)
    for i in range(max_context):
        identity_trs[i] = pynini.string_map(identity_mappings[i])
        error_trs[i] = pynini.string_map(error_mappings[i])
    # TODO refactor as a subfunction
    # - build the "master transducer" containing ID-n and ERR-n symbols
    #   on transitions for n in 1..max_context and containing ngr_weights[n] in
    #   arcs leading to those
    state_dict = {}
    root = pynini.Fst()

    # FIXME refactor the merging of symbol tables into a separate function
    symbol_table = pynini.SymbolTable()
    for i in range(max_context):
        symbol_table = pynini.merge_symbol_table(
            symbol_table, identity_trs[i].input_symbols())
        symbol_table = pynini.merge_symbol_table(symbol_table, error_trs[i].input_symbols())
        symbol_table = pynini.merge_symbol_table(
            symbol_table, identity_trs[i].output_symbols())
        symbol_table = pynini.merge_symbol_table(symbol_table, error_trs[i].output_symbols())
        # Nonterminals for context size i + 1; the returned labels are unused.
        sym = symbol_table.add_symbol('id-{}'.format(i + 1))
        sym = symbol_table.add_symbol('err-{}'.format(i + 1))
    root.set_input_symbols(symbol_table)
    root.set_output_symbols(symbol_table)

    for i in range(max_errors + 1):
        for j in range(max_context + 1):
            s = root.add_state()
            state_dict[(i, j)] = s
            if j > 0:
                # (i, 0) -> (i, j) with epsilon
                root.add_arc(state_dict[(i, 0)], pynini.Arc(0, 0, ngr_weights[j - 1], s))
                # (i, j) -> (i, 0) with identity
                sym = root.output_symbols().find('id-{}'.format(j))
                root.add_arc(s, pynini.Arc(0, sym, 0, state_dict[(i, 0)]))
                if i > 0:
                    # arc: (i-1, j) -> (i, 0) with error
                    sym = root.output_symbols().find('err-{}'.format(j))
                    root.add_arc(state_dict[(i - 1, j)], pynini.Arc(0, sym, 0, state_dict[(i, 0)]))
        root.set_final(state_dict[(i, 0)], 0)
    root.set_start(state_dict[(0, 0)])

    # Map each nonterminal to the transducer that replaces it.
    replacements = []
    for i in range(max_context):
        replacements.append(('id-{}'.format(i + 1), identity_trs[i]))
        replacements.append(('err-{}'.format(i + 1), error_trs[i]))
    result = pynini.replace(root, replacements)
    result.invert()
    result.optimize()
    return result
def __init__(self):
    """Start with empty pair statistics and a trivial one-state aligner."""
    # (symbol1, symbol2) -> count; missing pairs default to 0.
    self._stats = collections.defaultdict(int)
    # Single state that is both start and final; arcs are added later.
    aligner = py.Fst()
    only_state = aligner.add_state()
    aligner.set_start(only_state)
    aligner.set_final(only_state)
    self._aligner = aligner
def compute_alignments(self, pairs, max_zeroes=2, max_allowed_mappings=2,
                       print_mappings=False, initial_only=False):
    """Generalization of Kessler's initials-only match approach.

    Ref: Kessler, Brett. 2001. "The Significance of Word Lists."
    University of Chicago Press.

    Procedure:
    1. Count symbol co-occurrences (no reordering) into ``self._stats``.
    2. Build a weighted aligner from those counts.
    3. Align every pair with the aligner (shortest path).
    4. Keep the ``max_allowed_mappings`` most frequent partners of each
       symbol, in both directions unless ``initial_only``.
    5. Realign with a pared-down, unweighted aligner built from those
       mappings and count the pairs whose alignment needs at most
       ``max_zeroes`` insertions plus deletions.

    Matches and homophone statistics are printed to stdout.

    Side effects: updates ``self._stats`` and replaces ``self._aligner``.
    NOTE(review): ``self._stats`` accumulates across calls while ``tot`` is
    reset per call, and arcs are added to whatever ``self._aligner`` holds
    from a previous call. Repeated calls presumably give skewed weights;
    confirm single use.

    Args:
      pairs: list of pairs (each element is indexable per symbol, e.g. a
        string or list of segments)
      max_zeroes: int, maximum number of insertions plus deletions for an
        alignment to count as a match
      max_allowed_mappings: int, maximum number of mappings allowed
      print_mappings: bool, whether or not to print mappings
      initial_only: bool, if True, only look at the initial segment

    Returns:
      number of matches
    """
    tot = 0
    # Cost of deleting / inserting a symbol in the first-pass aligner.
    ins_weight = del_weight = 100
    if initial_only:
        new_pairs = []
        for (p1, p2) in pairs:
            new_pairs.append((p1[:1], p2[:1]))
        pairs = new_pairs
    # Computes initial statistics for any symbol mapping to any symbol assuming
    # no reordering.
    for (p1, p2) in pairs:
        for i in range(len(p1)):
            for j in range(i, len(p2)):
                self._stats[p1[i], p2[j]] += 1
                tot += 1
    if not initial_only:  # If we only consider initials, we don't need 2nd pass
        for (p1, p2) in pairs:
            for i in range(len(p2)):
                for j in range(i, len(p1)):
                    self._stats[p1[j], p2[i]] += 1
                    tot += 1
    symbols = py.SymbolTable()
    # Added first so that it presumably receives label 0 (epsilon).
    symbols.add_symbol("<epsilon>")
    # Constructs a matcher FST using the initial statistics.
    for (c1, c2) in self._stats:
        label1 = symbols.add_symbol(c1)
        label2 = symbols.add_symbol(c2)
        # Substitution cost: negative log relative frequency of the pair.
        weight = -math.log(self._stats[c1, c2] / tot)
        self._aligner.add_arc(
            self._aligner.start(),
            py.Arc(label1, label2, weight, self._aligner.start()))
        # Deletion of c1 (output epsilon).
        self._aligner.add_arc(
            self._aligner.start(),
            py.Arc(label1, 0, del_weight, self._aligner.start()))
        # Insertion of c2 (input epsilon).
        self._aligner.add_arc(
            self._aligner.start(),
            py.Arc(0, label2, ins_weight, self._aligner.start()))
    self._aligner.optimize()
    self._aligner.set_input_symbols(symbols)
    self._aligner.set_output_symbols(symbols)
    # left label -> right label -> count (and the reverse direction).
    left_to_right = collections.defaultdict(
        lambda: collections.defaultdict(int))
    if not initial_only:
        right_to_left = collections.defaultdict(
            lambda: collections.defaultdict(int))
    # Realigns the data using the matcher. NB: we could get fancy and use EM for
    # this...
    for (p1, p2) in pairs:
        f1 = self._make_fst(p1, symbols)
        f2 = self._make_fst(p2, symbols)
        alignment = py.shortestpath(f1 * self._aligner * f2).topsort()
        for s in alignment.states():
            aiter = alignment.arcs(s)
            while not aiter.done():
                arc = aiter.value()
                left_to_right[arc.ilabel][arc.olabel] += 1
                if not initial_only:
                    right_to_left[arc.olabel][arc.ilabel] += 1
                aiter.next()
    mappings = set()
    # Finds the best match for a symbol, going in both directions. So if
    # language 1 /k/ matches to language 2 /s/, /o/ or /m/, and /s/ is most
    # common, then we propose /k/ -> /s/. Going the other way if language 1 /k/,
    # /t/ or /s/ matches to language 2 /s/, and /s/ is most common then we also
    # get /s/ -> /s/.
    for left in left_to_right:
        d = left_to_right[left]
        rights = sorted(d, key=d.get, reverse=True)[:max_allowed_mappings]
        for right in rights:
            mappings.add((left, right))
    if not initial_only:
        for right in right_to_left:
            d = right_to_left[right]
            lefts = sorted(d, key=d.get, reverse=True)[:max_allowed_mappings]
            for left in lefts:
                mappings.add((left, right))
    # Now build a new pared down aligner...
    # (one state, weight-0 loops, only the selected label pairs)
    new_aligner = py.Fst()
    s = new_aligner.add_state()
    new_aligner.set_start(s)
    new_aligner.set_final(s)
    new_aligner.set_input_symbols(symbols)
    new_aligner.set_output_symbols(symbols)
    for (ilabel, olabel) in mappings:
        new_aligner.add_arc(new_aligner.start(),
                            py.Arc(ilabel, olabel, 0, new_aligner.start()))
        if print_mappings:
            left = symbols.find(ilabel)
            right = symbols.find(olabel)
            left = left.replace("<epsilon>", "Ø")
            right = right.replace("<epsilon>", "Ø")
            print("{}\t->\t{}".format(left, right))
    self._aligner = new_aligner
    matched = 0
    # ... and realign with it, counting how many alignments succeed, and
    # computing how many homophones there are.
    input_homophones = collections.defaultdict(int)
    output_homophones = collections.defaultdict(int)
    matching_homophones = collections.defaultdict(int)
    for (p1, p2) in pairs:
        f1 = self._make_fst(p1, symbols)
        f2 = self._make_fst(p2, symbols)
        alignment = py.shortestpath(f1 * self._aligner * f2).topsort()
        # No path: the pared-down aligner cannot align this pair at all.
        if alignment.num_states() == 0:
            continue
        inp = []
        out = []
        n_deletions = 0
        n_insertions = 0
        for s in alignment.states():
            aiter = alignment.arcs(s)
            while not aiter.done():
                arc = aiter.value()
                # Label 0 is epsilon: rendered as "-" and counted as an
                # insertion (input side) or deletion (output side).
                if arc.ilabel:
                    inp.append(symbols.find(arc.ilabel))
                else:
                    inp.append("-")
                    n_insertions += 1
                if arc.olabel:
                    out.append(symbols.find(arc.olabel))
                else:
                    out.append("-")
                    n_deletions += 1
                aiter.next()
        inp = " ".join(inp).encode("utf8")
        input_homophones[inp] += 1
        out = " ".join(out).encode("utf8")
        output_homophones[out] += 1
        if n_deletions + n_insertions <= max_zeroes:
            matched += 1
            match = "{}\t{}".format(inp.decode("utf8"), out.decode("utf8"))
            print(match)
            matching_homophones[match] += 1
    # Counts the homophone groups --- the number of unique forms each of which
    # is assigned to more than one slot, for each language.
    inp_lang_homophones = 0
    for w in input_homophones:
        if input_homophones[w] > 1:
            inp_lang_homophones += 1
    out_lang_homophones = 0
    for w in output_homophones:
        if output_homophones[w] > 1:
            out_lang_homophones += 1
    print("HOMOPHONE_GROUPS:\t{}\t{}".format(inp_lang_homophones,
                                             out_lang_homophones))
    for match in matching_homophones:
        if matching_homophones[match] > 1:
            print("HOMOPHONE:\t{}\t{}".format(matching_homophones[match], match))
    return matched
def expand(self, token: pynini.FstLike) -> pynini.Fst:
    """Finds lexicon candidates for a token.

    Returns:
        The lattice produced by rewriting ``token`` with ``self._lexicon``,
        or an empty FST if ``rewrite.rewrite_lattice`` raises
        ``rewrite.Error``.
    """
    try:
        lattice = rewrite.rewrite_lattice(token, self._lexicon)
    except rewrite.Error:
        lattice = pynini.Fst()
    return lattice
def compile(slots: Set[Slot]) -> pynini.Fst:
    """
    Returns an OpenFST FST representing the morphotactic rules of an entire lexicon

    Resolves all dependencies between continuation classes of multiple slots.
    Two DFS passes over the slots reachable from the starting slots are
    used:
    1. Copy every rule (or StemGuesser FSA) into one shared FST and record
       each slot's start state and each rule's final state(s).
    2. Link those final states to continuation-class start states with
       weighted epsilon arcs, or mark them final when the continuation class
       is empty.
    Epsilons are then removed, and the FST is optimized only if it is both
    input- and output-deterministic.

    Note: no slot can be named 'start'.
    Note: the ``final_states`` lists of the given slots are appended to; a
    second call on the same slots presumably sees stale entries. Confirm.

    Args:
        slots: set of Slot objects (not a list)

    Returns:
        (Fst) FST connecting the slots

    Raises:
        Exception: if no slot is a starting slot, or if the built FST fails
            verification.
    """
    # if dependencies are cyclic, we cannot use pynini to concatenate
    # rules to continuation classes' FSTs since pynini creates a copy of the
    # continuation class' FST, and we might not have finished mutating it
    # by the time we are doing the concatenation
    # thus we must manually add the arcs from the rules to the continuation classes
    # through 2 passes (1 pass to create the rules, 1 pass to add the arcs)
    # we use DFS to process the Slots that are reachable from the starting Slots
    # Slots are the vertices, and transitions between classes are edges
    slot_map = {slot.name: slot for slot in slots}
    starting_slots = {slot.name: slot for slot in slots if slot.start}
    fst = pynini.Fst()
    if len(starting_slots) == 0:
        raise Exception('need at least 1 slot to be a starting slot')

    # copy the FST for each rule or the Slot's FSA into the main fst with pywrapfst
    # store each Slot's start state in this main FST as DFS state
    # store each rule's/FSA's final state in the Slot
    def first_visit(state, vertex):
        # `state` is the DFS accumulator (name -> start state). It is bound to
        # start_states immediately because the loops below reuse the name
        # `state` for FST state IDs.
        start_states = state
        if vertex == 'start':
            s = fst.add_state()
            fst.set_start(s)
            start_states[vertex] = s
            return start_states
        slot = slot_map[vertex]
        slot_start_state = fst.add_state()
        if isinstance(slot, StemGuesser):
            # copy the regex FSA to fst with pywrapfst
            fsa = slot.fst
            old_num_states = fst.num_states()
            fst.add_states(
                fsa.num_states() - 1)  # do not need to copy over slot_start_state again
            # Renumbering: FSA state 0 (assumed to be its start state) maps
            # to slot_start_state; state k > 0 maps to old_num_states + k - 1.
            for state in fsa.states():
                new_state = slot_start_state if state == 0 else (
                    old_num_states + state - 1)
                # final states of FST may not be accepting, so must manually find the final states
                if fsa.final(state) != pynini.Weight.zero('tropical'):
                    slot.final_states.append(new_state)
                for arc in fsa.arcs(state):
                    nextstate = slot_start_state if arc.nextstate == 0 else (
                        old_num_states + arc.nextstate - 1)
                    fst.add_arc(
                        new_state,
                        pynini.Arc(arc.ilabel, arc.olabel, arc.weight, nextstate))
        else:  # regular Slot
            # create an FST for each rule with pynini and copy over to fst with pywrapfst
            for (upper, lower, _, rule_weight) in slot.rules:
                # transitions within same slot could have different continuation classes
                # we will concatenate the rule with the continuation class' FST in the second DFS
                # place lower on the input side so that FST can take in input from lower alphabet to perform analysis
                rule = pynutil.add_weight(pynini.cross(lower, upper), rule_weight)
                # copy rule to fst arc by arc, starting from state start_slot
                # NOTE(review): only arcs are copied. Any final weight on the
                # rule's final state (where add_weight may place rule_weight)
                # is not transferred; confirm rule weights survive.
                old_num_states = fst.num_states()
                fst.add_states(
                    rule.num_states() - 1)  # do not need to copy over slot_start_state again
                for state in rule.states():
                    new_state = slot_start_state if state == 0 else (
                        old_num_states + state - 1)
                    for arc in rule.arcs(state):
                        nextstate = slot_start_state if arc.nextstate == 0 else (
                            old_num_states + arc.nextstate - 1)
                        fst.add_arc(
                            new_state,
                            pynini.Arc(arc.ilabel, arc.olabel, arc.weight, nextstate))
                # Assumes the rule's highest-numbered state is its (only)
                # final state; one entry per rule, in slot.rules order.
                rule_final_state = fst.num_states() - 1
                slot.final_states.append(rule_final_state)
        # add current slot's FST to finished set of slots
        start_states[vertex] = slot_start_state
        return start_states

    def revisit(state, vertex):
        # do nothing because Slot only needs to be processed once
        return state

    def neighbors(vertex):
        # Names of the non-empty continuation classes reachable from vertex.
        if vertex == 'start':
            return list(starting_slots.keys())
        conts = set()
        # we only care about visiting the continuation class so only retrieve its name
        # the linking of rules to continuation class' FSTs is done in the finish function
        slot = slot_map[vertex]
        # works if the slot is a Slot or StemGuesser
        for (_, _, continuation_classes, _) in slot.rules:
            conts |= set([cc for (cc, _) in continuation_classes if cc])
        return list(conts)

    # make a first pass through all of the Slots with DFS
    # convert each Slot's rules into an FST
    # DFS guarantees that the Slots processed are reachable from the start
    # start_states maps Slot name to start state of the Slot so that we can concatenate a rule with its continuation classes
    start_states = {}
    start_states = _dfs(start_states, set(), 'start', neighbors, first_visit, revisit)

    # second pass through all of the Slots
    # by this time, all Slots reachable from the start have been converted into FSTs
    # add transition from each rule to continuation class' start state
    # glue all Slots together, Slot by Slot
    # note that we cannot concatenate each Slot to its continuation's FST
    # because its continuation's FST is not guaranteed to have finished processing
    def second_pass(_, vertex):
        if vertex == 'start':
            # add an epsilon transition between each starting state and starting slots
            # will be removed during optimization
            # we do not union the starting slots because we do not know when the slots will be finished processing
            s = start_states[vertex]
            for start_slot in starting_slots.keys():
                # note: we currently do not support setting weights for starting classes
                arc = pynini.Arc(0, 0, 0.0, start_states[start_slot])
                fst.add_arc(s, arc)
            return
        slot = slot_map[vertex]
        if isinstance(slot, StemGuesser):
            # only care about a StemGuesser's continuation classes
            # StemGuesser does not assign weights or transitions
            # (only the first rule's continuation classes are consulted)
            cont_classes = slot.rules[0][2]
            for final_state in slot.final_states:
                # add epsilon transition between FSA's final states and continuation classes
                for (continuation_class, weight) in cont_classes:
                    if not continuation_class:
                        # mark final_state as accepting by setting weight to semiring One or weight specified by user
                        fst.set_final(final_state, weight)
                    else:
                        arc = pynini.Arc(0, 0, weight, start_states[continuation_class])
                        fst.add_arc(final_state, arc)
        else:  # regular Slot
            # Relies on first_visit having appended exactly one final state
            # per rule, in slot.rules order.
            for ((_, _, cont_classes, _), final_state) in zip(slot.rules, slot.final_states):
                # add epsilon transition between each rule's final state and continuation classes
                # note: we currently do not support setting weights for continuation classes
                # NOTE(review): the weight from cont_classes IS applied to the
                # epsilon arc / final weight below; the note above looks stale.
                for (continuation_class, weight) in cont_classes:
                    if not continuation_class:
                        # mark final_state as accepting by setting weight to semiring One or weight specified by user
                        fst.set_final(final_state, weight)
                    else:
                        arc = pynini.Arc(0, 0, weight, start_states[continuation_class])
                        fst.add_arc(final_state, arc)
        return

    _dfs(None, set(), 'start', neighbors, second_pass, revisit)

    # verify the FST
    if not fst.verify():
        raise Exception('FST malformed')

    # epsilon transitions may interfere with determining determinism
    fst.rmepsilon()
    if fst.properties(pywrapfst.I_DETERMINISTIC, True) == pywrapfst.I_DETERMINISTIC and \
            fst.properties(pywrapfst.O_DETERMINISTIC, True) == pywrapfst.O_DETERMINISTIC:
        # optimize() determinizes the FST, which we do not want if it's non-deterministic
        fst.optimize()
    return fst
def __init__(self) -> None:
    """Create an empty, not-yet-compiled FST with an empty symbol table."""
    self._compiled = False
    self.fst = pynini.Fst()
    self.symbol_table = pynini.SymbolTable()
    # token_to_key / key_to_token lookup dicts, both initially empty.
    self.token_to_key, self.key_to_token = {}, {}