def to_wfst(self, recording, *, optimize=True):
    """Build a weighted FST from the per-frame label scores of *recording*.

    The FST has one layer of states per frame. Layer t holds one state per
    label, each entered by an arc whose input is that label and whose weight
    is the negative log of the label's score at frame t. The arc's output is
    the label of the state it leaves, or EPSILON when the label repeats.
    Every state of the last layer is final.

    Args:
        recording: input forwarded to ``self.predict_raw``.
        optimize: if True (the default), call ``fst.optimize()`` before
            returning.

    Returns:
        The constructed ``pynini.Fst``.
    """
    phonemes = self.predict_raw(recording)
    EPSILON = 0
    # assumes phonemes is (batch, frames, labels) and only batch item 0 is
    # used — TODO confirm against predict_raw
    num_frames = phonemes.shape[1]
    num_of_letters = phonemes.shape[2]
    # Labels are 1..num_of_letters - 1 plus EPSILON. The label that would be
    # num_of_letters is mapped to EPSILON, which indexes score column 0.
    labels = [
        label if label != num_of_letters else EPSILON
        for label in range(1, num_of_letters + 1)
    ]
    # The scores do not depend on the frame index, so take the log once.
    log_phonemes = -np.log(phonemes[0])

    fst = pynini.Fst()
    init = fst.add_state()
    fst.set_start(init)
    # (state, label that entered it); the start state counts as EPSILON.
    heads = [(init, EPSILON)]
    for t in range(num_frames):
        states = [fst.add_state() for _ in labels]
        for entering_state, head in heads:
            for label, label_state in zip(labels, states):
                # Emit the previous label only when the label changes.
                # NOTE(review): with this scheme the label of the final
                # segment is never emitted — confirm this is intended.
                output_sign = head if head != label else EPSILON
                weight = log_phonemes[t, label]
                fst.add_arc(entering_state, pynini.Arc(
                    label, output_sign, weight, label_state))
        # Record the *remapped* labels: the state entered via EPSILON must
        # have head EPSILON, matching the start state. Storing
        # num_of_letters here would emit that label on the next arcs.
        heads = list(zip(states, labels))
    for final_state, _ in heads:
        fst.set_final(final_state)
    if optimize:
        fst.optimize()
    return fst
def __call__(self, graphemes: str) -> List[str]:  # pragma: no cover
    """Call the rewrite function.

    Builds a lattice over grapheme positions and passes it to
    ``self.rewrite``. State i stands for position i. An arc i -> i + j
    (1 <= j <= self.grapheme_order) is added for every substring
    ``graphemes[i:i + j]``, joined with ``self.seq_sep``, that is found in
    ``self.input_token_type``. State 0 is the start and state
    ``len(graphemes)`` is final.

    Returns:
        The non-empty hypotheses from ``self.rewrite``, with every
        ``self.seq_sep`` replaced by a single space. These are strings, as the
        ``.replace`` call requires.
    """
    fst = pynini.Fst()
    one = pynini.Weight.one(fst.weight_type())
    max_state = 0
    for i in range(len(graphemes)):
        # Exactly one state is added per position, so this state's id is i.
        start_state = fst.add_state()
        for j in range(1, min(self.grapheme_order, len(graphemes) - i) + 1):
            substring = self.seq_sep.join(graphemes[i:i + j])
            ilabel = self.input_token_type.find(substring)
            if ilabel == pynini.NO_LABEL:
                continue
            fst.add_arc(start_state, pynini.Arc(ilabel, ilabel, one, i + j))
            max_state = max(max_state, i + j)
    # Arcs may target states past the last one created above; add them.
    for _ in range(fst.num_states(), max_state + 1):
        fst.add_state()
    fst.set_start(0)
    fst.set_final(len(graphemes), one)
    fst.set_input_symbols(self.input_token_type)
    fst.set_output_symbols(self.input_token_type)
    hypotheses = self.rewrite(fst)
    return [x.replace(self.seq_sep, " ") for x in hypotheses if x]
def all_suffixes(self, fsa: pynini.Fst) -> pynini.Fst:
    """Return an optimized copy of *fsa* that also accepts every suffix.

    An epsilon arc with weight One is added from the start state to every
    other state, so an accepting path may begin at any state. The input FST
    is not modified.
    """
    fsa = fsa.copy()
    start_state = fsa.start()
    if start_state == pynini.NO_STATE_ID:
        # No start state: nothing is accepted, so there are no suffixes.
        return fsa.optimize()
    one = pynini.Weight.one(fsa.weight_type())
    for s in fsa.states():
        # Skip the start state itself. An epsilon self-loop there adds no
        # suffixes but forms an epsilon cycle, whose closure does not
        # converge in non-idempotent semirings (e.g. log) during optimize().
        if s != start_state:
            fsa.add_arc(start_state, pynini.Arc(0, 0, one, s))
    return fsa.optimize()
def _universal_acceptor(symbol_table):
    """Return a one-state acceptor looping over every non-epsilon label.

    Starts from ``pynini.epsilon_machine()`` and adds, on its state 0, a
    weight-0 self-loop ``label:label`` for every entry of *symbol_table*
    whose key is positive. Both symbol tables are set to *symbol_table*.
    """
    acceptor = pynini.epsilon_machine()
    acceptor.set_input_symbols(symbol_table)
    acceptor.set_output_symbols(symbol_table)
    non_epsilon = (key for key, _ in symbol_table if key > 0)
    for key in non_epsilon:
        acceptor.add_arc(0, pynini.Arc(key, key, 0, 0))
    return acceptor
def _label_list_to_string_fsa(labels: List[int]) -> pynini.Fst:
    """Return the optimized string acceptor spelling out *labels*.

    State i gets a single arc ``lbl:lbl`` (tropical weight One) to state
    i + 1 for the i-th label; state 0 is the start and state ``len(labels)``
    is final.
    """
    length = len(labels)
    fsa = pynini.Fst()
    fsa.add_states(length + 1)
    fsa.set_start(0)
    fsa.set_final(length)
    one = pynini.Weight.one("tropical")
    for state, label in enumerate(labels):
        fsa.add_arc(state, pynini.Arc(label, label, one, state + 1))
    return fsa.optimize()
def convert_to_automaton(self, keys: List[int]) -> 'Automaton':
    """Build a linear acceptor spelling out *keys* over ``self.symbol_table``.

    Only valid once the object is compiled (asserted). State k has a single
    arc ``keys[k]:keys[k]`` (weight argument None) to state k + 1. The first
    state is the start and the last is final.

    NOTE(review): the annotation says 'Automaton' but a pynini.Fst is
    returned — confirm the alias or fix the annotation.
    """
    assert self._compiled
    acceptor = pynini.Fst()
    table = self.symbol_table
    acceptor.set_input_symbols(table)
    acceptor.set_output_symbols(table)
    state_ids = [acceptor.add_state() for _ in range(len(keys) + 1)]
    acceptor.set_start(state_ids[0])
    acceptor.set_final(state_ids[-1])
    for position, key in enumerate(keys):
        arc = pynini.Arc(key, key, None, state_ids[position + 1])
        acceptor.add_arc(state_ids[position], arc)
    return acceptor
def recombine_windows(window_fsts):
    '''
    Recombine processed window FSTs (containing hypotheses for a given window)
    to a lattice, which is also represented as an FST.

    Args:
        window_fsts: dict mapping (pos, length) to the hypothesis FST for the
            window of `length` tokens starting at token `pos`. The FSTs in
            the dict are not modified.

    Returns:
        The optimized lattice. It is built from a skeleton with states
        0..num_tokens and one arc pos -> pos + length per window, carrying the
        output label WIN-pos-length. Each such label is replaced by its window
        FST, followed by a space unless the window reaches the last token.
    '''
    def _label(pos, length):
        return 'WIN-{}-{}'.format(pos, length)

    t1 = time.time()
    space_tr = pynini.acceptor(' ')
    # determine the input string length (TODO without iterating!!!)
    # assumes some window starts at the last token — TODO confirm
    num_tokens = max(pos for (pos, _) in window_fsts) + 1

    root = pynini.Fst()
    root.add_states(num_tokens + 1)
    root.set_start(0)
    root.set_final(num_tokens, 0)

    # FIXME refactor the merging of symbol tables into a separate function
    symbol_table = pynini.SymbolTable()
    for window_fst in window_fsts.values():
        symbol_table = pynini.merge_symbol_table(
            symbol_table, window_fst.input_symbols())
        symbol_table = pynini.merge_symbol_table(
            symbol_table, window_fst.output_symbols())
    for pos, length in window_fsts:
        symbol_table.add_symbol(_label(pos, length))
    root.set_input_symbols(symbol_table)
    root.set_output_symbols(symbol_table)

    replacements = []
    for (pos, length), window_fst in window_fsts.items():
        label = _label(pos, length)
        sym = root.output_symbols().find(label)
        if pos + length < num_tokens:
            # Append a space if this is not the last token, so that the final
            # string consists of tokens separated by spaces. Fst.concat works
            # in place, so copy first to leave the caller's FST untouched.
            window_fst = window_fst.copy()
            window_fst.concat(space_tr)
        replacements.append((label, window_fst))
        root.add_arc(pos, pynini.Arc(0, sym, 0, pos + length))

    result = pynini.replace(root, replacements)
    result.optimize()
    t2 = time.time()
    logging.debug('Recombining time: %ss', t2 - t1)
    return result
def BuildSigmaFstFromSymbolTable(syms: pynini.SymbolTableView) -> pynini.Fst:
    """Return a two-state FST with one ``lbl:lbl`` arc per entry of *syms*.

    State 0 is the start and state 1 is final. Every arc goes 0 -> 1 with
    tropical weight One, so exactly one symbol is consumed.

    NOTE(review): every key of *syms* is used, including 0 if the table has an
    epsilon entry; in that case the FST also accepts the empty string —
    confirm this is intended.
    """
    sigma = pynini.Fst()
    source, sink = sigma.add_state(), sigma.add_state()
    sigma.set_start(source)
    sigma.set_final(sink)
    one = pynini.Weight.one("tropical")
    for label in (key for key, _ in syms):
        sigma.add_arc(source, pynini.Arc(label, label, one, sink))
    return sigma
def second_pass(_, vertex):
    """Second DFS visit: wire start and continuation-class transitions.

    For the 'start' vertex, adds an epsilon arc (weight 0.0) from its state to
    the start state of every starting slot. For any other vertex, connects
    each final state of the slot to its continuation classes. A falsy class
    marks the final state accepting with the given weight. Otherwise an
    epsilon arc with that weight leads to the class's start state.
    """
    def _link_continuations(final_state, cont_classes):
        """Connect one slot final state to its (class, weight) continuations."""
        for continuation_class, weight in cont_classes:
            if not continuation_class:
                # mark final_state as accepting by setting weight to semiring
                # One or weight specified by user
                fst.set_final(final_state, weight)
            else:
                arc = pynini.Arc(0, 0, weight, start_states[continuation_class])
                fst.add_arc(final_state, arc)

    if vertex == 'start':
        # add an epsilon transition between each starting state and starting
        # slots; these are removed during optimization. The starting slots are
        # not unioned because we do not know when the slots will be finished
        # processing.
        s = start_states[vertex]
        for start_slot in starting_slots.keys():
            # note: we currently do not support setting weights for starting classes
            arc = pynini.Arc(0, 0, 0.0, start_states[start_slot])
            fst.add_arc(s, arc)
        return

    slot = slot_map[vertex]
    if isinstance(slot, StemGuesser):
        # A StemGuesser does not assign weights or transitions itself; every
        # one of its final states uses the continuation classes of its first rule.
        cont_classes = slot.rules[0][2]
        for final_state in slot.final_states:
            _link_continuations(final_state, cont_classes)
    else:  # regular Slot
        # Each rule's final state uses that rule's own continuation classes,
        # each carrying its own weight.
        for (_, _, cont_classes, _), final_state in zip(slot.rules, slot.final_states):
            _link_continuations(final_state, cont_classes)
    return
def construct_any(symbol_table):
    '''
    Return an FST for Sigma*.

    The result has a single state that is both start and final. It carries
    one self-loop ``label:label`` with weight 1 for every entry of
    `symbol_table`.

    NOTE(review): in the tropical semiring (pynini.Fst()'s default) weight 1
    is a per-symbol cost, not the free weight One (0). Also, an epsilon entry
    in the table adds an epsilon self-loop. Confirm both are intended.
    '''
    sigma_star = pynini.Fst()
    loop_state = sigma_star.add_state()
    sigma_star.set_start(loop_state)
    sigma_star.set_final(loop_state)
    entries = pynini.SymbolTableIterator(symbol_table)
    while not entries.done():
        label = symbol_table.find(entries.symbol())
        sigma_star.add_arc(loop_state, pynini.Arc(label, label, 1, loop_state))
        entries.next()
    return sigma_star
def _make_fst(self, string, symbols):
    """Compile *string* into a linear acceptor over *symbols*.

    Each character c becomes one arc ``symbols.find(c)`` on both sides with
    weight 0. The first state is the start, the last state is final, and both
    symbol tables are set to *symbols*.

    NOTE(review): a character missing from *symbols* gets whatever ``find``
    returns for unknown symbols (presumably NO_LABEL, -1), producing an
    invalid arc — callers must ensure the table covers *string*.
    """
    chain = py.Fst()
    current = chain.add_state()
    chain.set_start(current)
    for label in map(symbols.find, string):
        successor = chain.add_state()
        chain.add_arc(current, py.Arc(label, label, 0, successor))
        current = successor
    chain.set_final(current)
    chain.set_input_symbols(symbols)
    chain.set_output_symbols(symbols)
    return chain
def _label_union(labels: Set[int], epsilon: bool) -> pynini.Fst:
    """Creates FSA over a union of the labels.

    Args:
      labels: labels to accept. The caller's collection is not modified.
      epsilon: if True, epsilon (label 0) is accepted as well.

    Returns:
      A two-state FSA with one arc ``label:label`` (weight argument None) from
      the start state to the final state for each label.
    """
    if epsilon:
        # Build a new set instead of calling labels.add(0), which would leak
        # the epsilon label into the caller's set.
        labels = {*labels, 0}
    side = pynini.Fst()
    src = side.add_state()
    side.set_start(src)
    dst = side.add_state()
    for label in labels:
        side.add_arc(src, pynini.Arc(label, label, None, dst))
    side.set_final(dst)
    assert side.verify(), "FST is ill-formed"
    return side
def make_lattice(tokens: List[str], key_table: KeyTable) -> pynini.Fst:
    """Creates a lattice from a list of tokens.

    The lattice is a string-shaped FSA with ``len(tokens) + 1`` states. For
    the token at position i, every index in ``key_table[get_key(token)]``
    becomes an arc ``index:index`` with weight 0 from state i to state i + 1.
    Each index names an in-vocabulary word that is a possible unscrambling
    of the token.
    """
    num_tokens = len(tokens)
    lattice = pynini.Fst()
    lattice.add_states(num_tokens + 1)
    lattice.set_start(0)
    lattice.set_final(num_tokens)
    for src, token in enumerate(tokens):
        dst = src + 1
        for index in key_table[get_key(token)]:
            lattice.add_arc(src, pynini.Arc(index, index, 0, dst))
    assert lattice.verify(), "ill-formed lattice"
    return lattice
def _flip_lemmatizer_feature_labels(self, lemmatizer: pynini.Fst) -> pynini.Fst:
    """Move the lemmatizer's feature labels from the output to the input side.

    This is destructive: *lemmatizer* is modified in place and returned. The
    feature labels are the non-epsilon input labels of
    ``self.category.feature_labels``. Every arc of *lemmatizer* whose output
    label is one of them (and whose input is asserted to be epsilon) has its
    labels swapped. The input symbol table is replaced by the output table.

    Args:
      lemmatizer: FST representing a partially constructed lemmatizer.

    Returns:
      The modified lemmatizer.
    """
    category_fst = self.category.feature_labels
    feature_labels = {
        arc.ilabel
        for state in category_fst.states()
        for arc in category_fst.arcs(state)
        if arc.ilabel
    }
    lemmatizer.set_input_symbols(lemmatizer.output_symbols())
    for state in lemmatizer.states():
        arc_it = lemmatizer.mutable_arcs(state)
        while not arc_it.done():
            current = arc_it.value()
            if current.olabel in feature_labels:
                # This assertion should always be true by construction.
                assert current.ilabel == 0, (
                    f"ilabel = "
                    f"{lemmatizer.input_symbols().find(current.ilabel)},"
                    f" olabel = "
                    f"{lemmatizer.output_symbols().find(current.olabel)}")
                arc_it.set_value(
                    pynini.Arc(current.olabel, current.ilabel, current.weight,
                               current.nextstate))
            arc_it.next()
    return lemmatizer
def compile_transducer(mappings, ngr_probs, max_errors=3, max_context=3,
                       weight_threshold=5.0):
    """Compile an error-model transducer from n-gram mappings.

    A skeleton FST has states (e, n) for e in 0..max_errors and
    n in 0..max_context. From (e, 0), an epsilon arc weighted
    -log(ngr_probs[n - 1]) enters (e, n). From (e, n), an arc with output
    'id-n' returns to (e, 0), and an arc with output 'err-n' leads to
    (e + 1, 0). States (e, 0) are final and (0, 0) is the start. Each id-n /
    err-n label is then replaced by the string map of identity / error
    mappings whose source has length n. The result is inverted and optimized.

    Args:
        mappings: iterable of (x, y, weight) triples, where
            1 <= len(x) <= max_context. Triples with x == y are identities;
            the others are errors.
        ngr_probs: n-gram probabilities indexed by length - 1; at least
            max_context entries are read.
        max_errors: maximum number of err-n arcs on a path.
        max_context: maximum source n-gram length.
        weight_threshold: currently unused; kept for API compatibility.

    Returns:
        The compiled transducer.

    Raises:
        ValueError: if a mapping's source string is empty or longer than
            max_context.
    """
    ngr_weights = -np.log(ngr_probs)

    # Bucket n - 1 holds the mappings whose source has length n.
    identity_mappings = {i: [] for i in range(max_context)}
    error_mappings = {i: [] for i in range(max_context)}
    for x, y, weight in mappings:
        if not 1 <= len(x) <= max_context:
            raise ValueError(
                'mapping source {!r} has length {}; expected 1..{}'.format(
                    x, len(x), max_context))
        mapping = (escape_for_pynini(x), escape_for_pynini(y), str(weight))
        bucket = identity_mappings if x == y else error_mappings
        bucket[len(x) - 1].append(mapping)
    identity_trs = {i: pynini.string_map(identity_mappings[i])
                    for i in range(max_context)}
    error_trs = {i: pynini.string_map(error_mappings[i])
                 for i in range(max_context)}

    # FIXME refactor the merging of symbol tables into a separate function
    symbol_table = pynini.SymbolTable()
    for i in range(max_context):
        # Keep this merge order; it may affect the keys assigned in the table.
        for table in (identity_trs[i].input_symbols(),
                      error_trs[i].input_symbols(),
                      identity_trs[i].output_symbols(),
                      error_trs[i].output_symbols()):
            symbol_table = pynini.merge_symbol_table(symbol_table, table)
        symbol_table.add_symbol('id-{}'.format(i + 1))
        symbol_table.add_symbol('err-{}'.format(i + 1))

    root = pynini.Fst()
    root.set_input_symbols(symbol_table)
    root.set_output_symbols(symbol_table)
    state_dict = {}
    for i in range(max_errors + 1):
        for j in range(max_context + 1):
            s = root.add_state()
            state_dict[(i, j)] = s
            if j > 0:
                # (i, 0) -> (i, j) with epsilon
                root.add_arc(state_dict[(i, 0)],
                             pynini.Arc(0, 0, ngr_weights[j - 1], s))
                # (i, j) -> (i, 0) with identity
                id_sym = root.output_symbols().find('id-{}'.format(j))
                root.add_arc(s, pynini.Arc(0, id_sym, 0, state_dict[(i, 0)]))
                if i > 0:
                    # arc: (i-1, j) -> (i, 0) with error
                    err_sym = root.output_symbols().find('err-{}'.format(j))
                    root.add_arc(state_dict[(i - 1, j)],
                                 pynini.Arc(0, err_sym, 0, state_dict[(i, 0)]))
        root.set_final(state_dict[(i, 0)], 0)
    root.set_start(state_dict[(0, 0)])

    replacements = []
    for i in range(max_context):
        replacements.append(('id-{}'.format(i + 1), identity_trs[i]))
        replacements.append(('err-{}'.format(i + 1), error_trs[i]))
    result = pynini.replace(root, replacements)
    result.invert()
    result.optimize()
    return result
def compute_alignments(self,
                       pairs,
                       max_zeroes=2,
                       max_allowed_mappings=2,
                       print_mappings=False,
                       initial_only=False):
    """Generalization of Kessler's initials-only match approach.

    Ref: Kessler, Brett. 2001. "The Significance of Word Lists."
    University of Chicago Press.

    The method works in three passes:
      1. Count co-occurrences into self._stats and add one substitution, one
         deletion and one insertion arc per symbol pair to self._aligner.
      2. Realign every pair with that aligner and keep, in each direction, the
         max_allowed_mappings most frequent mappings per symbol.
      3. Replace self._aligner with an aligner restricted to those mappings,
         realign, and count the pairs needing at most max_zeroes
         insertions/deletions. Matches and homophone statistics are printed
         to stdout.

    Args:
      pairs: list of pairs of symbol sequences.
      max_zeroes: int, maximum number of insertions + deletions for a pair to
        count as matched.
      max_allowed_mappings: int, maximum number of mappings allowed per symbol
        and direction.
      print_mappings: bool, whether or not to print mappings.
      initial_only: bool, if True, only look at the initial segment.

    Returns:
      number of matches
    """
    tot = 0
    ins_weight = del_weight = 100
    if initial_only:
        new_pairs = []
        for (p1, p2) in pairs:
            new_pairs.append((p1[:1], p2[:1]))
        pairs = new_pairs
    # Computes initial statistics for any symbol mapping to any symbol assuming
    # no reordering.
    # NOTE(review): self._stats is not reset here while tot is; if this method
    # runs twice on one instance, _stats / tot can exceed 1 and yield negative
    # weights below — confirm _stats is reset elsewhere.
    for (p1, p2) in pairs:
        for i in range(len(p1)):
            for j in range(i, len(p2)):
                self._stats[p1[i], p2[j]] += 1
                tot += 1
    if not initial_only:  # If we only consider initials, we don't need 2nd pass
        for (p1, p2) in pairs:
            for i in range(len(p2)):
                for j in range(i, len(p1)):
                    self._stats[p1[j], p2[i]] += 1
                    tot += 1
    # Labels are assigned in insertion order, so epsilon gets key 0.
    symbols = py.SymbolTable()
    symbols.add_symbol("<epsilon>")
    # Constructs a matcher FST using the initial statistics.
    # NOTE(review): assumes self._aligner already has a start state.
    for (c1, c2) in self._stats:
        label1 = symbols.add_symbol(c1)
        label2 = symbols.add_symbol(c2)
        weight = -math.log(self._stats[c1, c2] / tot)
        self._aligner.add_arc(
            self._aligner.start(),
            py.Arc(label1, label2, weight, self._aligner.start()))
        self._aligner.add_arc(
            self._aligner.start(),
            py.Arc(label1, 0, del_weight, self._aligner.start()))
        self._aligner.add_arc(
            self._aligner.start(),
            py.Arc(0, label2, ins_weight, self._aligner.start()))
    self._aligner.optimize()
    self._aligner.set_input_symbols(symbols)
    self._aligner.set_output_symbols(symbols)
    left_to_right = collections.defaultdict(
        lambda: collections.defaultdict(int))
    if not initial_only:
        right_to_left = collections.defaultdict(
            lambda: collections.defaultdict(int))
    # Realigns the data using the matcher. NB: we could get fancy and use EM for
    # this...
    for (p1, p2) in pairs:
        f1 = self._make_fst(p1, symbols)
        f2 = self._make_fst(p2, symbols)
        alignment = py.shortestpath(f1 * self._aligner * f2).topsort()
        for s in alignment.states():
            aiter = alignment.arcs(s)
            while not aiter.done():
                arc = aiter.value()
                left_to_right[arc.ilabel][arc.olabel] += 1
                if not initial_only:
                    right_to_left[arc.olabel][arc.ilabel] += 1
                aiter.next()
    mappings = set()
    # Finds the best match for a symbol, going in both directions. So if
    # language 1 /k/ matches to language 2 /s/, /o/ or /m/, and /s/ is most
    # common, then we propose /k/ -> /s/. Going the other way if language 1 /k/,
    # /t/ or /s/ matches to language 2 /s/, and /s/ is most common then we also
    # get /s/ -> /s/.
    for left in left_to_right:
        d = left_to_right[left]
        rights = sorted(d, key=d.get, reverse=True)[:max_allowed_mappings]
        for right in rights:
            mappings.add((left, right))
    if not initial_only:
        for right in right_to_left:
            d = right_to_left[right]
            lefts = sorted(d, key=d.get, reverse=True)[:max_allowed_mappings]
            for left in lefts:
                mappings.add((left, right))
    # Now build a new pared down aligner...
    new_aligner = py.Fst()
    s = new_aligner.add_state()
    new_aligner.set_start(s)
    new_aligner.set_final(s)
    new_aligner.set_input_symbols(symbols)
    new_aligner.set_output_symbols(symbols)
    for (ilabel, olabel) in mappings:
        new_aligner.add_arc(new_aligner.start(),
                            py.Arc(ilabel, olabel, 0, new_aligner.start()))
        if print_mappings:
            left = symbols.find(ilabel)
            right = symbols.find(olabel)
            left = left.replace("<epsilon>", "Ø")
            right = right.replace("<epsilon>", "Ø")
            print("{}\t->\t{}".format(left, right))
    self._aligner = new_aligner
    matched = 0
    # ... and realign with it, counting how many alignments succeed, and
    # computing how many homophones there are.
    input_homophones = collections.defaultdict(int)
    output_homophones = collections.defaultdict(int)
    matching_homophones = collections.defaultdict(int)
    for (p1, p2) in pairs:
        f1 = self._make_fst(p1, symbols)
        f2 = self._make_fst(p2, symbols)
        alignment = py.shortestpath(f1 * self._aligner * f2).topsort()
        # An empty alignment means the pared-down aligner cannot map p1 to p2.
        if alignment.num_states() == 0:
            continue
        inp = []
        out = []
        n_deletions = 0
        n_insertions = 0
        for s in alignment.states():
            aiter = alignment.arcs(s)
            while not aiter.done():
                arc = aiter.value()
                if arc.ilabel:
                    inp.append(symbols.find(arc.ilabel))
                else:
                    inp.append("-")
                    n_insertions += 1
                if arc.olabel:
                    out.append(symbols.find(arc.olabel))
                else:
                    out.append("-")
                    n_deletions += 1
                aiter.next()
        # NOTE(review): the encode/decode round trip below does not change the
        # text; the homophone dicts are keyed by bytes as a result.
        inp = " ".join(inp).encode("utf8")
        input_homophones[inp] += 1
        out = " ".join(out).encode("utf8")
        output_homophones[out] += 1
        if n_deletions + n_insertions <= max_zeroes:
            matched += 1
            match = "{}\t{}".format(inp.decode("utf8"), out.decode("utf8"))
            # NOTE(review): printed unconditionally, unlike the mappings.
            print(match)
            matching_homophones[match] += 1
    # Counts the homophone groups --- the number of unique forms each of which
    # is assigned to more than one slot, for each language.
    inp_lang_homophones = 0
    for w in input_homophones:
        if input_homophones[w] > 1:
            inp_lang_homophones += 1
    out_lang_homophones = 0
    for w in output_homophones:
        if output_homophones[w] > 1:
            out_lang_homophones += 1
    print("HOMOPHONE_GROUPS:\t{}\t{}".format(inp_lang_homophones,
                                             out_lang_homophones))
    for match in matching_homophones:
        if matching_homophones[match] > 1:
            print("HOMOPHONE:\t{}\t{}".format(matching_homophones[match], match))
    return matched
inessive_regular_transduce = pynini.transducer( "", inessive_regular) #, output_token_type="utf8") inessive_harmony_transduce = pynini.transducer( "", inessive_harmony) #, output_token_type="utf8") transducer_adessive_harmony = harmony_state + adessive_harmony_transduce transducer_adessive_regular = regular_state + adessive_regular_transduce transducer_inessive_harmony = harmony_state + inessive_harmony_transduce transducer_inessive_regular = regular_state + inessive_regular_transduce transducer_adessive_base = transducer_adessive_regular | transducer_adessive_harmony transducer_inessive_base = transducer_inessive_regular | transducer_inessive_harmony ###Creates arcs between harmony and regular paths to allow setting and resetting for i in vowels_harmony_trigger: arc = pynini.Arc(ord(i), ord(i), 0, 5) transducer_adessive_base.add_arc(0, arc) transducer_inessive_base.add_arc(0, arc) for i in vowels_harmony_holster: arc = pynini.Arc(ord(i), ord(i), 0, 0) transducer_adessive_base.add_arc(5, arc) transducer_inessive_base.add_arc(5, arc) ####Ensures regular path is default transducer_adessive_base.set_start(0) transducer_inessive_base.set_start(0) transducer_adessive_base.optimize() transducer_inessive_base.optimize()
def add_arc(self, from_state: int, to_state: int, key: int) -> None:
    """Add an acceptor arc ``key:key`` from *from_state* to *to_state*.

    The weight argument is None. Only allowed before compilation (asserted).
    """
    assert not self._compiled
    arc = pynini.Arc(key, key, None, to_state)
    self.fst.add_arc(from_state, arc)
def first_visit(state, vertex):
    """First DFS visit: copy a slot's FST(s) into the shared `fst`.

    `state` is the dict mapping vertex -> start state id, threaded through
    the traversal. For 'start', a fresh state becomes the start of `fst`.
    Otherwise a new state becomes the slot's start, the slot's FST (for a
    StemGuesser) or one FST per rule (for a regular Slot) is copied in arc by
    arc, and the resulting final-state ids are appended to
    `slot.final_states`. Returns the updated dict.
    """
    start_states = state
    if vertex == 'start':
        s = fst.add_state()
        fst.set_start(s)
        start_states[vertex] = s
        return start_states

    slot = slot_map[vertex]
    slot_start_state = fst.add_state()

    if isinstance(slot, StemGuesser):
        # copy the regex FSA to fst with pywrapfst
        # Sub-FST state 0 maps to slot_start_state; state k >= 1 maps to
        # old_num_states + k - 1. NOTE(review): assumes the sub-FST's start
        # state is 0 — confirm.
        fsa = slot.fst
        old_num_states = fst.num_states()
        fst.add_states(
            fsa.num_states() - 1)  # do not need to copy over slot_start_state again
        # NOTE(review): this loop variable shadows the `state` parameter
        # (already copied into start_states above).
        for state in fsa.states():
            new_state = slot_start_state if state == 0 else (
                old_num_states + state - 1)
            # only arcs are copied, not final weights, so record every state
            # whose final weight is not Zero as a slot final state
            if fsa.final(state) != pynini.Weight.zero('tropical'):
                slot.final_states.append(new_state)
            for arc in fsa.arcs(state):
                nextstate = slot_start_state if arc.nextstate == 0 else (
                    old_num_states + arc.nextstate - 1)
                fst.add_arc(
                    new_state,
                    pynini.Arc(arc.ilabel, arc.olabel, arc.weight, nextstate))
    else:  # regular Slot
        # create an FST for each rule with pynini and copy over to fst with pywrapfst
        for (upper, lower, _, rule_weight) in slot.rules:
            # transitions within same slot could have different continuation classes
            # we will concatenate the rule with the continuation class' FST in the second DFS
            # place lower on the input side so that FST can take in input from lower alphabet to perform analysis
            rule = pynutil.add_weight(pynini.cross(lower, upper), rule_weight)
            # copy rule to fst arc by arc, starting from state start_slot
            # NOTE(review): only arcs are copied; if add_weight stores
            # rule_weight as a final weight, it is dropped here — confirm.
            old_num_states = fst.num_states()
            fst.add_states(
                rule.num_states() - 1)  # do not need to copy over slot_start_state again
            for state in rule.states():
                new_state = slot_start_state if state == 0 else (
                    old_num_states + state - 1)
                for arc in rule.arcs(state):
                    nextstate = slot_start_state if arc.nextstate == 0 else (
                        old_num_states + arc.nextstate - 1)
                    fst.add_arc(
                        new_state,
                        pynini.Arc(arc.ilabel, arc.olabel, arc.weight, nextstate))
            # NOTE(review): assumes the rule has exactly one final state and
            # it is its highest-numbered state — confirm for cross/add_weight.
            rule_final_state = fst.num_states() - 1
            slot.final_states.append(rule_final_state)

    # add current slot's FST to finished set of slots
    start_states[vertex] = slot_start_state
    return start_states