Code Example #1
File: latticegen.py Project: bertsky/cor-asv-fst
import logging
import time

import pynini


def recombine_windows(window_fsts):
    '''
    Recombine processed window FSTs (containing hypotheses for a given
    window) to a lattice, which is also represented as an FST.
    '''
    def _label(pos, length):
        return 'WIN-{}-{}'.format(pos, length)

    t1 = time.time()
    space_tr = pynini.acceptor(' ')

    # determine the input string length and max. window size
    # (TODO without iterating!!!)
    num_tokens = max(i for (i, j) in window_fsts) + 1
    max_window_size = max(j for (i, j) in window_fsts)

    root = pynini.Fst()
    for _ in range(num_tokens + 1):
        root.add_state()
    root.set_start(0)
    root.set_final(num_tokens, 0)

    # FIXME refactor the merging of symbol tables into a separate function
    symbol_table = pynini.SymbolTable()
    for window_fst in window_fsts.values():
        symbol_table = pynini.merge_symbol_table(symbol_table,
                                                 window_fst.input_symbols())
        symbol_table = pynini.merge_symbol_table(symbol_table,
                                                 window_fst.output_symbols())
    for (pos, length) in window_fsts:
        symbol_table.add_symbol(_label(pos, length))

    root.set_input_symbols(symbol_table)
    root.set_output_symbols(symbol_table)

    replacements = []
    for (pos, length), window_fst in window_fsts.items():
        label = _label(pos, length)
        sym = root.output_symbols().find(label)
        if pos + length < num_tokens:
            # append a space unless this window reaches the end of the input,
            # so that the final string consists of tokens separated by spaces
            window_fst.concat(space_tr)
        replacements.append((label, window_fst))
        root.add_arc(pos, pynini.Arc(0, sym, 0, pos + length))

    result = pynini.replace(root, replacements)
    result.optimize()

    t2 = time.time()
    logging.debug('Recombining time: {}s'.format(t2 - t1))

    return result
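
To make the skeleton concrete: for a two-token input with windows (0, 1), (1, 1) and (0, 2), the root FST built above has three states and one WIN-labeled arc per window. A minimal standalone sketch of that skeleton (an illustration, not project code):

import pynini

# Toy version of the skeleton recombine_windows builds: states 0..2,
# one arc per window, labeled 'WIN-pos-length' on the output side.
symtab = pynini.SymbolTable()
symtab.add_symbol('<epsilon>')
root = pynini.Fst()
for _ in range(3):
    root.add_state()
root.set_start(0)
root.set_final(2, 0)
for pos, length in [(0, 1), (1, 1), (0, 2)]:
    sym = symtab.add_symbol('WIN-{}-{}'.format(pos, length))
    root.add_arc(pos, pynini.Arc(0, sym, 0, pos + length))
root.set_input_symbols(symtab)
root.set_output_symbols(symtab)
# pynini.replace then splices each window FST in place of its label,
# yielding the lattice of all window combinations.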
Code Example #2
File: error_simp.py Project: kba/cor-asv-fst
import pynini


def combine_error_transducers(transducers, max_context, max_errors):
    '''
    Combine single-error transducers (one per window size) into
    transducers that apply up to max_errors errors within each
    contiguous range of window sizes up to max_context.
    '''

    def _universal_acceptor(symbol_table):
        # single-state acceptor that loops over every non-epsilon symbol
        fst = pynini.epsilon_machine()
        fst.set_input_symbols(symbol_table)
        fst.set_output_symbols(symbol_table)
        for x, y in symbol_table:
            if x > 0:  # skip <epsilon> (key 0)
                fst.add_arc(0, pynini.Arc(x, x, 0, 0))
        return fst

    contexts = []
    for n in range(1, max_context + 1):
        for m in range(1, n + 1):
            contexts.append(list(range(m, n + 1)))

    # FIXME refactor the merging of symbol tables into a separate function
    symtab = pynini.SymbolTable()
    for t in transducers:
        symtab = pynini.merge_symbol_table(symtab, t.input_symbols())
        symtab = pynini.merge_symbol_table(symtab, t.output_symbols())
    for t in transducers:
        t.relabel_tables(new_isymbols=symtab, new_osymbols=symtab)
    
    acceptor = _universal_acceptor(symtab)
    combined_transducers_dicts = []
    for context in contexts:
        print('Context: ', context)
        one_error = pynini.Fst()
        for n in context:
            one_error.union(transducers[n-1])
        
        for num_errors in range(1, max_errors+1):
            print('Number of errors:', num_errors)
            result_tr = acceptor.copy()
            result_tr.concat(one_error)
            result_tr.closure(0, num_errors)
            result_tr.concat(acceptor)
            result_tr.arcsort()
            combined_transducers_dicts.append({
                'max_error' : num_errors,
                'context' : ''.join(map(str, context)),
                'transducer' : result_tr })
    return combined_transducers_dicts
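
The context enumeration above produces every contiguous range of window sizes. Traced for max_context=3 (illustration only):

# every contiguous run m..n of window sizes with 1 <= m <= n <= 3
contexts = []
for n in range(1, 4):
    for m in range(1, n + 1):
        contexts.append(list(range(m, n + 1)))
print(contexts)  # [[1], [1, 2], [2], [1, 2, 3], [2, 3], [3]]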
Code Example #3
import numpy as np
import pynini

# escape_for_pynini is a small helper from the same project; a hypothetical
# sketch of it follows this example.
def compile_transducer(mappings,
                       ngr_probs,
                       max_errors=3,
                       max_context=3,
                       weight_threshold=5.0):
    '''
    Compile a weighted error transducer from (input, output, weight)
    mappings, using the n-gram context probabilities ngr_probs and
    allowing up to max_errors error mappings per string.
    '''
    ngr_weights = -np.log(ngr_probs)
    identity_trs, error_trs = {}, {}
    identity_mappings, error_mappings = {}, {}
    for i in range(max_context):
        identity_trs[i], error_trs[i] = [], []
        identity_mappings[i], error_mappings[i] = [], []
    for x, y, weight in mappings:
        mapping = (escape_for_pynini(x), escape_for_pynini(y), str(weight))
        if x == y:
            identity_mappings[len(x) - 1].append(mapping)
        else:
            error_mappings[len(x) - 1].append(mapping)
    for i in range(max_context):
        identity_trs[i] = pynini.string_map(identity_mappings[i])
        error_trs[i] = pynini.string_map(error_mappings[i])
    # TODO refactor as a subfunction
    # - build the "master transducer" containing ID-n and ERR-n symbols
    #   on transitions for n in 1..max_context and containing ngr_weights[n] in
    #   arcs leading to those
    state_dict = {}
    root = pynini.Fst()

    # FIXME refactor the merging of symbol tables into a separate function
    symbol_table = pynini.SymbolTable()
    for i in range(max_context):
        symbol_table = pynini.merge_symbol_table(
            symbol_table, identity_trs[i].input_symbols())
        symbol_table = pynini.merge_symbol_table(symbol_table,
                                                 error_trs[i].input_symbols())
        symbol_table = pynini.merge_symbol_table(
            symbol_table, identity_trs[i].output_symbols())
        symbol_table = pynini.merge_symbol_table(symbol_table,
                                                 error_trs[i].output_symbols())
        symbol_table.add_symbol('id-{}'.format(i + 1))
        symbol_table.add_symbol('err-{}'.format(i + 1))

    root.set_input_symbols(symbol_table)
    root.set_output_symbols(symbol_table)

    for i in range(max_errors + 1):
        for j in range(max_context + 1):
            s = root.add_state()
            state_dict[(i, j)] = s
            if j > 0:
                # (i, 0) -> (i, j) with epsilon
                root.add_arc(state_dict[(i, 0)],
                             pynini.Arc(0, 0, ngr_weights[j - 1], s))
                # (i, j) -> (i, 0) with identity
                sym = root.output_symbols().find('id-{}'.format(j))
                root.add_arc(s, pynini.Arc(0, sym, 0, state_dict[(i, 0)]))
                if i > 0:
                    # arc: (i-1, j) -> (i, 0) with error
                    sym = root.output_symbols().find('err-{}'.format(j))
                    root.add_arc(state_dict[(i - 1, j)],
                                 pynini.Arc(0, sym, 0, state_dict[(i, 0)]))
        root.set_final(state_dict[(i, 0)], 0)

    root.set_start(state_dict[(0, 0)])
    replacements = []
    for i in range(max_context):
        replacements.append(('id-{}'.format(i + 1), identity_trs[i]))
        replacements.append(('err-{}'.format(i + 1), error_trs[i]))
    result = pynini.replace(root, replacements)
    result.invert()
    result.optimize()
    return result
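
The escape_for_pynini helper used above is imported from elsewhere in the project and not shown here. A hypothetical sketch of it, assuming it does the standard escaping for pynini's string compiler (where '[', ']' and '\' are special characters):

def escape_for_pynini(s):
    # hypothetical sketch: backslash-escape the characters that pynini
    # treats specially when compiling strings ('[' and ']' delimit
    # generated symbols, '\' is the escape character)
    return s.replace('\\', '\\\\').replace('[', '\\[').replace(']', '\\]')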
Code Example #4
    def __init__(self) -> None:
        self.fst = pynini.Fst()
        self.symbol_table = pynini.SymbolTable()
        self._compiled = False   # whether the FST has been compiled yet
        self.token_to_key = {}   # token string -> symbol-table key
        self.key_to_token = {}   # symbol-table key -> token string
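
The rest of the class is not shown in this snippet. As a sketch of how the two mappings and the _compiled flag might interact, here is a hypothetical add_token method (an assumption, not the original code):

    def add_token(self, token: str) -> int:
        # hypothetical companion method: register a token and record the
        # bidirectional token <-> key mapping
        if token in self.token_to_key:
            return self.token_to_key[token]
        key = self.symbol_table.add_symbol(token)
        self.token_to_key[token] = key
        self.key_to_token[key] = token
        self._compiled = False  # any mutation invalidates a prior compile
        return key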
Code Example #5
    def compute_alignments(self,
                           pairs,
                           max_zeroes=2,
                           max_allowed_mappings=2,
                           print_mappings=False,
                           initial_only=False):
        """Generalization of Kessler's initials-only match approach.

    Ref: Kessler, Brett. 2001. "The Significance of Word Lists."
    University of Chicago Press.

    Args:
      pairs: list of pairs
      max_allowed_mappings: int, maximum number of mappings allowed
      print_mappings: bool, whether or not to print mappings
      initial_only: bool, if True, only look at the initial segment
    Returns:
      number of matches
    """
        tot = 0
        ins_weight = del_weight = 100
        if initial_only:
            new_pairs = []
            for (p1, p2) in pairs:
                new_pairs.append((p1[:1], p2[:1]))
            pairs = new_pairs
        # Computes initial statistics for any symbol mapping to any symbol assuming
        # no reordering.
        for (p1, p2) in pairs:
            for i in range(len(p1)):
                for j in range(i, len(p2)):
                    self._stats[p1[i], p2[j]] += 1
                    tot += 1
        if not initial_only:  # If we only consider initials, we don't need 2nd pass
            for (p1, p2) in pairs:
                for i in range(len(p2)):
                    for j in range(i, len(p1)):
                        self._stats[p1[j], p2[i]] += 1
                        tot += 1
        symbols = py.SymbolTable()
        symbols.add_symbol("<epsilon>")
        # Constructs a matcher FST using the initial statistics.
        for (c1, c2) in self._stats:
            label1 = symbols.add_symbol(c1)
            label2 = symbols.add_symbol(c2)
            weight = -math.log(self._stats[c1, c2] / tot)
            self._aligner.add_arc(
                self._aligner.start(),
                py.Arc(label1, label2, weight, self._aligner.start()))
            self._aligner.add_arc(
                self._aligner.start(),
                py.Arc(label1, 0, del_weight, self._aligner.start()))
            self._aligner.add_arc(
                self._aligner.start(),
                py.Arc(0, label2, ins_weight, self._aligner.start()))
        self._aligner.optimize()
        self._aligner.set_input_symbols(symbols)
        self._aligner.set_output_symbols(symbols)
        left_to_right = collections.defaultdict(
            lambda: collections.defaultdict(int))
        if not initial_only:
            right_to_left = collections.defaultdict(
                lambda: collections.defaultdict(int))
        # Realigns the data using the matcher. NB: we could get fancy and use EM for
        # this...
        for (p1, p2) in pairs:
            f1 = self._make_fst(p1, symbols)
            f2 = self._make_fst(p2, symbols)
            alignment = py.shortestpath(f1 * self._aligner * f2).topsort()
            for s in alignment.states():
                aiter = alignment.arcs(s)
                while not aiter.done():
                    arc = aiter.value()
                    left_to_right[arc.ilabel][arc.olabel] += 1
                    if not initial_only:
                        right_to_left[arc.olabel][arc.ilabel] += 1
                    aiter.next()
        mappings = set()
        # Finds the best match for a symbol, going in both directions. So if
        # language 1 /k/ matches to language 2 /s/, /o/ or /m/, and /s/ is most
        # common, then we propose /k/ -> /s/. Going the other way if language 1 /k/,
        # /t/ or /s/ matches to language 2 /s/, and /s/ is most common then we also
        # get /s/ -> /s/.
        for left in left_to_right:
            d = left_to_right[left]
            rights = sorted(d, key=d.get, reverse=True)[:max_allowed_mappings]
            for right in rights:
                mappings.add((left, right))
        if not initial_only:
            for right in right_to_left:
                d = right_to_left[right]
                lefts = sorted(d, key=d.get,
                               reverse=True)[:max_allowed_mappings]
                for left in lefts:
                    mappings.add((left, right))
        # Now build a new pared down aligner...
        new_aligner = py.Fst()
        s = new_aligner.add_state()
        new_aligner.set_start(s)
        new_aligner.set_final(s)
        new_aligner.set_input_symbols(symbols)
        new_aligner.set_output_symbols(symbols)
        for (ilabel, olabel) in mappings:
            new_aligner.add_arc(new_aligner.start(),
                                py.Arc(ilabel, olabel, 0, new_aligner.start()))
            if print_mappings:
                left = symbols.find(ilabel)
                right = symbols.find(olabel)
                left = left.replace("<epsilon>", "Ø")
                right = right.replace("<epsilon>", "Ø")
                print("{}\t->\t{}".format(left, right))
        self._aligner = new_aligner
        matched = 0
        # ... and realign with it, counting how many alignments succeed, and
        # computing how many homophones there are.
        input_homophones = collections.defaultdict(int)
        output_homophones = collections.defaultdict(int)
        matching_homophones = collections.defaultdict(int)
        for (p1, p2) in pairs:
            f1 = self._make_fst(p1, symbols)
            f2 = self._make_fst(p2, symbols)
            alignment = py.shortestpath(f1 * self._aligner * f2).topsort()
            if alignment.num_states() == 0:
                continue
            inp = []
            out = []
            n_deletions = 0
            n_insertions = 0
            for s in alignment.states():
                aiter = alignment.arcs(s)
                while not aiter.done():
                    arc = aiter.value()
                    if arc.ilabel:
                        inp.append(symbols.find(arc.ilabel))
                    else:
                        inp.append("-")
                        n_insertions += 1
                    if arc.olabel:
                        out.append(symbols.find(arc.olabel))
                    else:
                        out.append("-")
                        n_deletions += 1
                    aiter.next()
            inp = " ".join(inp).encode("utf8")
            input_homophones[inp] += 1
            out = " ".join(out).encode("utf8")
            output_homophones[out] += 1
            if n_deletions + n_insertions <= max_zeroes:
                matched += 1
                match = "{}\t{}".format(inp.decode("utf8"), out.decode("utf8"))
                print(match)
                matching_homophones[match] += 1
        # Counts the homophone groups --- the number of unique forms each of which
        # is assigned to more than one slot, for each language.
        inp_lang_homophones = 0
        for w in input_homophones:
            if input_homophones[w] > 1:
                inp_lang_homophones += 1
        out_lang_homophones = 0
        for w in output_homophones:
            if output_homophones[w] > 1:
                out_lang_homophones += 1
        print("HOMOPHONE_GROUPS:\t{}\t{}".format(inp_lang_homophones,
                                                 out_lang_homophones))
        for match in matching_homophones:
            if matching_homophones[match] > 1:
                print("HOMOPHONE:\t{}\t{}".format(matching_homophones[match],
                                                  match))
        return matched
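
The _make_fst helper called above is not included in this snippet. A plausible sketch of it (an assumption, not the original code): a linear chain acceptor over the segments of one word, sharing the aligner's symbol table:

    def _make_fst(self, word, symbols):
        # hypothetical sketch: one arc per segment, labels drawn from the
        # shared symbol table (add_symbol reuses existing keys)
        fst = py.Fst()
        state = fst.add_state()
        fst.set_start(state)
        for segment in word:
            label = symbols.add_symbol(segment)
            nxt = fst.add_state()
            fst.add_arc(state, py.Arc(label, label, 0, nxt))
            state = nxt
        fst.set_final(state)
        fst.set_input_symbols(symbols)
        fst.set_output_symbols(symbols)
        return fst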