import logging
import time

import numpy as np   # used by compile_transducer() below
import pynini


def recombine_windows(window_fsts):
    '''
    Recombine processed window FSTs (containing hypotheses for a given
    window) to a lattice, which is also represented as an FST.
    '''

    def _label(pos, length):
        return 'WIN-{}-{}'.format(pos, length)

    t1 = time.time()
    space_tr = pynini.acceptor(' ')

    # determine the input string length and max. window size
    # (TODO without iterating!!!)
    num_tokens = max(i for (i, j) in window_fsts) + 1
    max_window_size = max(j for (i, j) in window_fsts)

    root = pynini.Fst()
    for i in range(num_tokens + 1):
        s = root.add_state()
    root.set_start(0)
    root.set_final(num_tokens, 0)

    # FIXME refactor the merging of symbol tables into a separate function
    symbol_table = pynini.SymbolTable()
    for window_fst in window_fsts.values():
        symbol_table = pynini.merge_symbol_table(
            symbol_table, window_fst.input_symbols())
        symbol_table = pynini.merge_symbol_table(
            symbol_table, window_fst.output_symbols())
    for (pos, length), window_fst in window_fsts.items():
        label = _label(pos, length)
        sym = symbol_table.add_symbol(label)

    root.set_input_symbols(symbol_table)
    root.set_output_symbols(symbol_table)

    replacements = []
    for (pos, length), window_fst in window_fsts.items():
        label = _label(pos, length)
        sym = root.output_symbols().find(label)
        if pos + length < num_tokens:
            # append a space if this is not the last token, so that the final
            # string consists of tokens separated by spaces
            window_fst.concat(space_tr)
        replacements.append((label, window_fst))
        # skeleton arc: traverse tokens pos..pos+length via the window's label
        root.add_arc(pos, pynini.Arc(0, sym, 0, pos + length))

    result = pynini.replace(root, replacements)
    result.optimize()

    t2 = time.time()
    logging.debug('Recombining time: {}s'.format(t2 - t1))
    return result
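# Usage sketch for recombine_windows(): a minimal, hypothetical example.
# It assumes the same pre-2.1 pynini API as above (pynini.acceptor compiles
# a plain string into an acceptor); depending on the pynini build, explicit
# symbol tables may need to be attached to the window FSTs first. Keys are
# (position, length) window coordinates over a two-token input; in real use
# each value would be a weighted union of correction hypotheses, not a
# single string.
window_fsts = {
    (0, 1): pynini.acceptor('the'),      # hypotheses for token 0
    (1, 1): pynini.acceptor('cat'),      # hypotheses for token 1
    (0, 2): pynini.acceptor('the cat'),  # hypotheses for the 2-token window
}
lattice = recombine_windows(window_fsts)
print(pynini.shortestpath(lattice))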
def combine_error_transducers(transducers, max_context, max_errors):

    def _universal_acceptor(symbol_table):
        fst = pynini.epsilon_machine()
        fst.set_input_symbols(symbol_table)
        fst.set_output_symbols(symbol_table)
        for x, y in symbol_table:
            if x > 0:
                fst.add_arc(0, pynini.Arc(x, x, 0, 0))
        return fst

    contexts = []
    for n in range(1, max_context + 1):
        for m in range(1, n + 1):
            contexts.append(list(range(m, n + 1)))

    # FIXME refactor the merging of symbol tables into a separate function
    symtab = pynini.SymbolTable()
    for t in transducers:
        symtab = pynini.merge_symbol_table(symtab, t.input_symbols())
        symtab = pynini.merge_symbol_table(symtab, t.output_symbols())
    for t in transducers:
        t.relabel_tables(new_isymbols=symtab, new_osymbols=symtab)

    acceptor = _universal_acceptor(symtab)
    combined_transducers_dicts = []
    for context in contexts:
        print('Context: ', context)
        one_error = pynini.Fst()
        for n in context:
            one_error.union(transducers[n - 1])
        for num_errors in range(1, max_errors + 1):
            print('Number of errors:', num_errors)
            result_tr = acceptor.copy()
            result_tr.concat(one_error)
            result_tr.closure(0, num_errors)
            result_tr.concat(acceptor)
            result_tr.arcsort()
            combined_transducers_dicts.append({
                'max_error': num_errors,
                'context': ''.join(map(str, context)),
                'transducer': result_tr})
    return combined_transducers_dicts
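# Usage sketch for combine_error_transducers(): toy single-mapping error
# models stand in for real ones learned from data (index n-1 holds the
# transducer for n-gram context size n; pynini.transducer is the pre-2.1
# string-pair constructor, and the mappings here are made up).
toy_transducers = [
    pynini.transducer('e', 'c'),    # hypothetical 1-gram confusion
    pynini.transducer('rn', 'm'),   # hypothetical 2-gram confusion
    pynini.transducer('iii', 'm'),  # hypothetical 3-gram confusion
]
for entry in combine_error_transducers(toy_transducers,
                                       max_context=3, max_errors=2):
    print(entry['context'], entry['max_error'],
          entry['transducer'].num_states())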
def compile_transducer(mappings, ngr_probs, max_errors=3, max_context=3,
                       weight_threshold=5.0):
    # NOTE: weight_threshold is currently unused here;
    # escape_for_pynini() is assumed to be defined elsewhere in this module.
    ngr_weights = -np.log(ngr_probs)
    identity_trs, error_trs = {}, {}
    identity_mappings, error_mappings = {}, {}
    for i in range(max_context):
        identity_trs[i], error_trs[i] = [], []
        identity_mappings[i], error_mappings[i] = [], []
    for x, y, weight in mappings:
        mapping = (escape_for_pynini(x), escape_for_pynini(y), str(weight))
        if x == y:
            identity_mappings[len(x) - 1].append(mapping)
        else:
            error_mappings[len(x) - 1].append(mapping)
    for i in range(max_context):
        identity_trs[i] = pynini.string_map(identity_mappings[i])
        error_trs[i] = pynini.string_map(error_mappings[i])

    # TODO refactor as a subfunction
    # - build the "master transducer" containing ID-n and ERR-n symbols on
    #   transitions for n in 1..max_context and containing ngr_weights[n] in
    #   arcs leading to those
    state_dict = {}
    root = pynini.Fst()

    # FIXME refactor the merging of symbol tables into a separate function
    symbol_table = pynini.SymbolTable()
    for i in range(max_context):
        symbol_table = pynini.merge_symbol_table(
            symbol_table, identity_trs[i].input_symbols())
        symbol_table = pynini.merge_symbol_table(
            symbol_table, error_trs[i].input_symbols())
        symbol_table = pynini.merge_symbol_table(
            symbol_table, identity_trs[i].output_symbols())
        symbol_table = pynini.merge_symbol_table(
            symbol_table, error_trs[i].output_symbols())
        sym = symbol_table.add_symbol('id-{}'.format(i + 1))
        sym = symbol_table.add_symbol('err-{}'.format(i + 1))
    root.set_input_symbols(symbol_table)
    root.set_output_symbols(symbol_table)

    for i in range(max_errors + 1):
        for j in range(max_context + 1):
            s = root.add_state()
            state_dict[(i, j)] = s
            if j > 0:
                # arc: (i, 0) -> (i, j) with epsilon
                root.add_arc(
                    state_dict[(i, 0)],
                    pynini.Arc(0, 0, ngr_weights[j - 1], s))
                # arc: (i, j) -> (i, 0) with identity
                sym = root.output_symbols().find('id-{}'.format(j))
                root.add_arc(
                    s, pynini.Arc(0, sym, 0, state_dict[(i, 0)]))
                if i > 0:
                    # arc: (i-1, j) -> (i, 0) with error
                    sym = root.output_symbols().find('err-{}'.format(j))
                    root.add_arc(
                        state_dict[(i - 1, j)],
                        pynini.Arc(0, sym, 0, state_dict[(i, 0)]))
        root.set_final(state_dict[(i, 0)], 0)
    root.set_start(state_dict[(0, 0)])

    replacements = []
    for i in range(max_context):
        replacements.append(('id-{}'.format(i + 1), identity_trs[i]))
        replacements.append(('err-{}'.format(i + 1), error_trs[i]))
    result = pynini.replace(root, replacements)
    result.invert()
    result.optimize()
    return result
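# Usage sketch for compile_transducer(), with made-up confusion statistics:
# mappings are (input n-gram, output n-gram, weight) triples, split
# internally into identity and error transducers by n-gram length, and
# ngr_probs holds one (here invented) probability per context size.
toy_mappings = [
    ('a', 'a', 0.01),    # 1-gram identity
    ('a', 'o', 3.2),     # 1-gram error
    ('th', 'th', 0.02),  # 2-gram identity
    ('th', 'tb', 4.1),   # 2-gram error
]
toy_ngr_probs = np.array([0.6, 0.4])
error_model = compile_transducer(toy_mappings, toy_ngr_probs,
                                 max_errors=2, max_context=2)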
def __init__(self) -> None:
    self.fst = pynini.Fst()
    self.symbol_table = pynini.SymbolTable()
    self._compiled = False
    self.token_to_key = {}
    self.key_to_token = {}
# NB: this method belongs to an aligner class whose module imports pynini
# under the alias `py`; the imports below are the module-level ones it needs.
import collections
import math

import pynini as py


def compute_alignments(self, pairs, max_zeroes=2, max_allowed_mappings=2,
                       print_mappings=False, initial_only=False):
    """Generalization of Kessler's initials-only match approach.

    Ref: Kessler, Brett. 2001. "The Significance of Word Lists."
    University of Chicago Press.

    Args:
      pairs: list of pairs
      max_zeroes: int, maximum number of insertions plus deletions allowed
        for an alignment to count as a match
      max_allowed_mappings: int, maximum number of mappings allowed
      print_mappings: bool, whether or not to print mappings
      initial_only: bool, if True, only look at the initial segment

    Returns:
      number of matches
    """
    tot = 0
    ins_weight = del_weight = 100
    if initial_only:
        new_pairs = []
        for (p1, p2) in pairs:
            new_pairs.append((p1[:1], p2[:1]))
        pairs = new_pairs
    # Computes initial statistics for any symbol mapping to any symbol,
    # assuming no reordering.
    for (p1, p2) in pairs:
        for i in range(len(p1)):
            for j in range(i, len(p2)):
                self._stats[p1[i], p2[j]] += 1
                tot += 1
    if not initial_only:  # If we only consider initials, we don't need 2nd pass
        for (p1, p2) in pairs:
            for i in range(len(p2)):
                for j in range(i, len(p1)):
                    self._stats[p1[j], p2[i]] += 1
                    tot += 1
    symbols = py.SymbolTable()
    symbols.add_symbol("<epsilon>")
    # Constructs a matcher FST using the initial statistics.
    for (c1, c2) in self._stats:
        label1 = symbols.add_symbol(c1)
        label2 = symbols.add_symbol(c2)
        weight = -math.log(self._stats[c1, c2] / tot)
        self._aligner.add_arc(
            self._aligner.start(),
            py.Arc(label1, label2, weight, self._aligner.start()))
        self._aligner.add_arc(
            self._aligner.start(),
            py.Arc(label1, 0, del_weight, self._aligner.start()))
        self._aligner.add_arc(
            self._aligner.start(),
            py.Arc(0, label2, ins_weight, self._aligner.start()))
    self._aligner.optimize()
    self._aligner.set_input_symbols(symbols)
    self._aligner.set_output_symbols(symbols)
    left_to_right = collections.defaultdict(
        lambda: collections.defaultdict(int))
    if not initial_only:
        right_to_left = collections.defaultdict(
            lambda: collections.defaultdict(int))
    # Realigns the data using the matcher. NB: we could get fancy and use EM
    # for this...
    for (p1, p2) in pairs:
        f1 = self._make_fst(p1, symbols)
        f2 = self._make_fst(p2, symbols)
        alignment = py.shortestpath(f1 * self._aligner * f2).topsort()
        for s in alignment.states():
            aiter = alignment.arcs(s)
            while not aiter.done():
                arc = aiter.value()
                left_to_right[arc.ilabel][arc.olabel] += 1
                if not initial_only:
                    right_to_left[arc.olabel][arc.ilabel] += 1
                aiter.next()
    mappings = set()
    # Finds the best match for a symbol, going in both directions. So if
    # language 1 /k/ matches to language 2 /s/, /o/ or /m/, and /s/ is most
    # common, then we propose /k/ -> /s/. Going the other way, if language 1
    # /k/, /t/ or /s/ matches to language 2 /s/, and /s/ is most common, then
    # we also get /s/ -> /s/.
    for left in left_to_right:
        d = left_to_right[left]
        rights = sorted(d, key=d.get, reverse=True)[:max_allowed_mappings]
        for right in rights:
            mappings.add((left, right))
    if not initial_only:
        for right in right_to_left:
            d = right_to_left[right]
            lefts = sorted(d, key=d.get, reverse=True)[:max_allowed_mappings]
            for left in lefts:
                mappings.add((left, right))
    # Now build a new pared-down aligner...
    new_aligner = py.Fst()
    s = new_aligner.add_state()
    new_aligner.set_start(s)
    new_aligner.set_final(s)
    new_aligner.set_input_symbols(symbols)
    new_aligner.set_output_symbols(symbols)
    for (ilabel, olabel) in mappings:
        new_aligner.add_arc(
            new_aligner.start(),
            py.Arc(ilabel, olabel, 0, new_aligner.start()))
        if print_mappings:
            left = symbols.find(ilabel)
            right = symbols.find(olabel)
            left = left.replace("<epsilon>", "Ø")
            right = right.replace("<epsilon>", "Ø")
            print("{}\t->\t{}".format(left, right))
    self._aligner = new_aligner
    matched = 0
    # ... and realign with it, counting how many alignments succeed, and
    # computing how many homophones there are.
    input_homophones = collections.defaultdict(int)
    output_homophones = collections.defaultdict(int)
    matching_homophones = collections.defaultdict(int)
    for (p1, p2) in pairs:
        f1 = self._make_fst(p1, symbols)
        f2 = self._make_fst(p2, symbols)
        alignment = py.shortestpath(f1 * self._aligner * f2).topsort()
        if alignment.num_states() == 0:
            continue
        inp = []
        out = []
        n_deletions = 0
        n_insertions = 0
        for s in alignment.states():
            aiter = alignment.arcs(s)
            while not aiter.done():
                arc = aiter.value()
                if arc.ilabel:
                    inp.append(symbols.find(arc.ilabel))
                else:
                    inp.append("-")
                    n_insertions += 1
                if arc.olabel:
                    out.append(symbols.find(arc.olabel))
                else:
                    out.append("-")
                    n_deletions += 1
                aiter.next()
        inp = " ".join(inp).encode("utf8")
        input_homophones[inp] += 1
        out = " ".join(out).encode("utf8")
        output_homophones[out] += 1
        if n_deletions + n_insertions <= max_zeroes:
            matched += 1
            match = "{}\t{}".format(inp.decode("utf8"), out.decode("utf8"))
            print(match)
            matching_homophones[match] += 1
    # Counts the homophone groups --- the number of unique forms each of
    # which is assigned to more than one slot, for each language.
    inp_lang_homophones = 0
    for w in input_homophones:
        if input_homophones[w] > 1:
            inp_lang_homophones += 1
    out_lang_homophones = 0
    for w in output_homophones:
        if output_homophones[w] > 1:
            out_lang_homophones += 1
    print("HOMOPHONE_GROUPS:\t{}\t{}".format(inp_lang_homophones,
                                             out_lang_homophones))
    for match in matching_homophones:
        if matching_homophones[match] > 1:
            print("HOMOPHONE:\t{}\t{}".format(
                matching_homophones[match], match))
    return matched
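# Usage sketch for compute_alignments(). The enclosing class is not shown
# above, so `WordAligner` is hypothetical: it is assumed to initialize
# self._stats as a collections.defaultdict(int) and self._aligner as a
# one-state FST, and to provide _make_fst(), compiling a segment sequence
# into a linear acceptor over the given symbol table.
pairs = [
    (["k", "a", "t"], ["g", "a", "d"]),  # hypothetical cognate pair
    (["s", "u", "n"], ["z", "o", "n"]),  # hypothetical cognate pair
]
aligner = WordAligner()  # hypothetical wrapper class
matches = aligner.compute_alignments(pairs, max_zeroes=1, print_mappings=True)
print("matched pairs:", matches)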