Example #1
def backward(lattice, neglog_to_log=False):
    if lattice.arc_type() != 'log':
        lattice = openfst.arcmap(lattice, map_type='to_log')

    if neglog_to_log:
        inverted = openfst.arcmap(lattice, map_type='invert')
        one = openfst.Weight.one(lattice.weight_type())
        betas = [openfst.divide(one, a) for a in backward(inverted)]
        return betas

    betas = openfst.shortestdistance(lattice, reverse=True)
    return betas
Example #2
def forward(lattice, neglog_to_log=False):
    if lattice.arc_type() != 'log':
        lattice = openfst.arcmap(lattice, map_type='to_log')

    if neglog_to_log:
        inverted = openfst.arcmap(lattice, map_type='invert')
        one = openfst.Weight.one(lattice.weight_type())
        alphas = [openfst.divide(one, a) for a in forward(inverted)]
        return alphas

    alphas = openfst.shortestdistance(lattice)
    return alphas
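
A minimal usage sketch for the two helpers above (hypothetical, not part of the source listing): build a tiny two-path lattice in the log semiring with OpenFst's pywrapfst bindings (imported as openfst to match the examples; OpenFst releases before 1.8 spell the constructor openfst.Fst() rather than openfst.VectorFst()) and check that the backward score of the start state is the total lattice weight.

import pywrapfst as openfst

lattice = openfst.VectorFst(arc_type='log')
s0, s1, s2, s3 = (lattice.add_state() for _ in range(4))
lattice.set_start(s0)
lattice.set_final(s3, openfst.Weight.one('log'))
# Two disjoint paths with -log weights 0.5 + 0.5 = 1.0 and 1.0 + 1.0 = 2.0.
lattice.add_arc(s0, openfst.Arc(1, 1, openfst.Weight('log', 0.5), s1))
lattice.add_arc(s1, openfst.Arc(2, 2, openfst.Weight('log', 0.5), s3))
lattice.add_arc(s0, openfst.Arc(3, 3, openfst.Weight('log', 1.0), s2))
lattice.add_arc(s2, openfst.Arc(4, 4, openfst.Weight('log', 1.0), s3))

alphas = forward(lattice)   # alphas[s]: -log sum of path weights from the start into s
betas = backward(lattice)   # betas[s]:  -log sum of path weights from s to a final state
print(float(betas[lattice.start()]))  # total weight: -log(e**-1.0 + e**-2.0) ~ 0.687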
Example #3
    def _decode(self, lattice):
        if self.decode_type == 'marginal':
            fst = libfst.fstArcGradient(lattice)
            if self.arc_type == 'standard':
                fst = openfst.arcmap(fst, map_type='to_std')
        elif self.decode_type == 'joint':
            fst = libfst.viterbi(lattice)
            if self.arc_type == 'log':
                fst = openfst.arcmap(fst, map_type='to_log')
        else:
            err_str = f"Unrecognized value: self.decode_type={self.decode_type}"
            raise AssertionError(err_str)

        if self.reduce_order == 'post':
            fst = openfst.compose(fst, self.reducers[self.output_stage - 1])
        return fst
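
The two branches mirror the module-level helpers shown in the other examples. A hypothetical standalone equivalent that drops the class state (libfst is the module defining fstArcGradient and viterbi below):

def decode(lattice, decode_type='marginal'):
    if decode_type == 'marginal':
        # Arc posteriors in the log semiring (see Example #5).
        return libfst.fstArcGradient(lattice)
    if decode_type == 'joint':
        # Single best path in the tropical semiring (see Example #4).
        return libfst.viterbi(lattice)
    raise ValueError(f"Unrecognized value: decode_type={decode_type}")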
Example #4
def viterbi(lattice):
    if lattice.arc_type() != 'standard':
        lattice = openfst.arcmap(lattice, map_type='to_std')

    shortest_paths = openfst.shortestpath(lattice).topsort().rmepsilon()

    return shortest_paths
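
Hypothetical usage, continuing the tiny lattice built under Example #2: the result is a topsorted, epsilon-free single-path FST, so iterating its states in order reads off the best label sequence.

best = viterbi(lattice)
labels = [arc.ilabel for state in best.states() for arc in best.arcs(state)]
print(labels)  # -> [1, 2], the path with the lower total -log weight (1.0 < 2.0)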
Example #5
def fstArcGradient(lattice, alphas=None, betas=None):
    if lattice.arc_type() != 'log':
        lattice = openfst.arcmap(lattice, map_type='to_log')

    if alphas is None:
        alphas = forward(lattice, neglog_to_log=True)

    if betas is None:
        betas = backward(lattice, neglog_to_log=True)

    total_weight = betas[lattice.start()]
    zero = openfst.Weight.zero(lattice.weight_type())

    arc_gradient = lattice.copy()
    for state in arc_gradient.states():
        w_incoming = alphas[state]
        arc_iterator = arc_gradient.mutable_arcs(state)
        while not arc_iterator.done():
            arc = arc_iterator.value()

            w_outgoing = betas[arc.nextstate]
            weight_thru_arc = openfst.times(w_incoming, w_outgoing)
            arc_neglogprob = openfst.divide(total_weight, weight_thru_arc)

            arc.weight = arc_neglogprob
            arc_iterator.set_value(arc)
            arc_iterator.next()

        if lattice.final(state) != zero:
            # w_outgoing = one --> final weight = w_in \otimes one = w_in
            weight_thru_arc = alphas[state]
            arc_neglogprob = openfst.divide(total_weight, weight_thru_arc)
            arc_gradient.set_final(state, arc_neglogprob)

    return arc_gradient
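
A hypothetical continuation of the tiny-lattice sketch above: per the comment in Example #11, the result is in the log semiring with weights as negative log probabilities, so exponentiating each negated arc weight inspects the scores as probabilities.

import math

arc_grad = fstArcGradient(lattice)
for state in arc_grad.states():
    for arc in arc_grad.arcs(state):
        print(state, '->', arc.nextstate, 'p =', math.exp(-float(arc.weight)))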
Example #6
def normalize_fst(f):
    f2 = f.copy()
    z = fst.shortestdistance(fst.arcmap(f, map_type="to_log"),
                             reverse=True)[0]
    for s in f2.states():
        w = f2.final(s)
        nw = fst.Weight(f2.weight_type(), float(w) - float(z))
        f2.set_final(s, nw)
    return f2
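
A quick hypothetical sanity check for the helper above (reusing the fst alias, for any FST f): after normalization the total path weight should be semiring one, i.e. about 0.0 in -log space. Note that the [0] in the helper assumes the start state is numbered 0, which holds for FSTs read or built the usual way; f.start() is the general form.

f2 = normalize_fst(f)
z = fst.shortestdistance(fst.arcmap(f2, map_type="to_log"),
                         reverse=True)[f2.start()]
print(float(z))  # ~0.0 once the total weight has been folded into the final weights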
Example #7
def processLattices(lats_sets, folders, statePruneTh=10000, pruneTh=10, silence=False):
    '''Applies standard pre-processing operations to SMT lattices
    @lats_sets: lattices to be processed
    @folders: output folders for processed lattices
    @statePruneTh: fsts above this threshold are pruned
    @pruneTh: pruning threshold
    @silence: if True, the function does not print which lattice is being processed'''
    for lats_set, folder in zip(lats_sets, folders):
        print(lats_set)
        print(folder)
        for f in sorted(glob.glob(lats_set), key=numericalSort):
            lattice = fst.Fst.read(f)
            if lattice.num_states() > statePruneTh:
                # Determinize, minimize and push weights in the log semiring,
                # then prune in the tropical semiring and push again.
                detminpush = fst.push(
                    fst.arcmap(fst.determinize(lattice.rmepsilon()).minimize(),
                               map_type="to_log"),
                    push_weights=True)
                pruned = fst.prune(fst.arcmap(detminpush, map_type="to_standard"),
                                   weight=pruneTh).minimize()
                out = fst.arcmap(
                    fst.push(fst.arcmap(pruned, map_type="to_log"), push_weights=True),
                    map_type="to_standard")
            else:
                # detminpush = fst.push(fst.determinize(fst.arcmap(lattice.rmepsilon(),map_type="to_log")).minimize(),push_weights=True)
                detminpush = fst.push(
                    fst.arcmap(fst.determinize(lattice.rmepsilon()).minimize(),
                               map_type="to_log"),
                    push_weights=True)
                out = fst.arcmap(detminpush, map_type="to_standard")
            out.write(folder + os.path.basename(f))
            if not silence:
                print(os.path.basename(f))
Example #8
    def get_grammar_fst(self):
        print("load grammar fst from:", self.grammar_fst_path)
        with gzip.open(self.grammar_fst_path) as gf:
            g_fst = fst.Fst.read_from_string(gf.read())
        g_fst = fst.arcmap(g_fst, map_type="to_log")
        # remap symbols to network vocab
        n_syms = fst.SymbolTable()
        for i, s in enumerate(self.vocabulary.itos):
            s = {' ': '<spc>', '<pad>': '<eps>'}.get(s, s)
            assert i == n_syms.add_symbol(s)
        # This stops a warning and is harmless - these will not occur anyway
        n_syms.add_symbol('<s>')
        n_syms.add_symbol('</s>')
        g_fst.relabel_tables(new_isymbols=n_syms, new_osymbols=n_syms)
        return g_fst
Example #9
    def __init__(self,
                 sample_batch,
                 num_classes,
                 graph_generator,
                 normalize_by_dim=None,
                 numerator_red='logsumexp',
                 denominator_red='logsumexp',
                 embedder='LutLinear',
                 embedder_kwargs={},
                 **kwargs):
        super(FSTDecoder, self).__init__(**kwargs)

        self.graph_generator = utils.contruct_from_kwargs(
            graph_generator, 'att_speech.fst_utils', {
                'num_classes': num_classes,
                'num_symbols': self.num_symbols
            })

        self.context_order = self.graph_generator.context_order
        self.normalize_by_dim = normalize_by_dim
        self.numerator_red = numerator_red
        self.denominator_red = denominator_red
        self._verify = False

        # None means no normalization
        # 0 means softmax over all symbols
        # >0 means normalize within contexts
        if self.normalize_by_dim not in [None, 0]:
            assert (self.graph_generator.num_classes ==
                    self.graph_generator.num_symbols**self.context_order)

        # An fst used for decoding most probable state sequences
        self.dec_fst = fst.arcmap(self.graph_generator.decoding_fst, 0,
                                  'to_standard', 0).arcsort('ilabel')

        rnn_hidden_size = sample_batch["features"].size()[2]
        embedder = globals()[embedder]
        # Keep the sequential for compatibility and easier initialization
        # from old checkpoints
        ngram_to_class = self.graph_generator.ngram_to_class
        if Globals.cuda:
            ngram_to_class = ngram_to_class.cuda()
        modules = []
        modules.append(
            embedder(rnn_hidden_size, self.graph_generator.num_symbols,
                     ngram_to_class, **embedder_kwargs))
        fully_connected = nn.Sequential(*modules)
        self.fc = nn.Sequential(SequenceWise(fully_connected), )
Example #10
def buildNgramCounter(wmap, maxngram=1):
    ''' Build an n-gram counting transducer
    @wmap: file containing the vocabulary
    @maxngram: maximum order of the grams to be counted'''
    filenames = []
    counters = []
    for order in range(1, maxngram + 1):
        initial_state = 0
        final_state = order + 1
        filename = 'counter' + str(order)
        filenames.append(filename)
        with open(filename, 'w') as outfile:
            with open(wmap, 'r') as infile:
                for line in infile:
                    line = line.strip()
                    outfile.write(str(initial_state) + " " + str(initial_state) + " " + line + " " + str(0) + "\n")
                    for state in range(order):
                        outfile.write(str(state) + " " + str(state + 1) + " " + line + " " + line + "\n")
                    outfile.write(str(final_state - 1) + " " + str(final_state - 1) + " " + line + " " + str(0) + "\n")
                outfile.write(str(order) + "\n")
    compiler = fst.Compiler()
    for filename in filenames:
        with open(filename, 'r') as f:
            for line in f:
                compiler.write(line)
        tmp = compiler.compile()
        tmp.write(filename + ".fst")
    for filename in filenames:
        counters.append(fst.Fst.read(filename + ".fst"))
    # Union the counters of all orders (the original hard-coded exactly four
    # counters, which breaks for any maxngram other than 4).
    merged = counters[0]
    for counter in counters[1:]:
        merged = merged.union(counter)
    tmp = merged.rmepsilon().arcsort()
    # tmp = fst.determinize(tmp).minimize()
    ngramCounter = fst.arcmap(tmp, map_type='to_log64', delta=0.0000001)
    ngramCounter.write("ngramCounter.fst")
    return ngramCounter
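
For reference, a minimal hypothetical round trip through fst.Compiler, the mechanism used above: it consumes AT&T-format text lines (src dst ilabel olabel [weight], with final states on a line of their own). Labels are numeric here because no symbol tables are attached; the single label 7 stands in for one vocabulary word.

compiler = fst.Compiler()
compiler.write("0 0 7 0\n")   # self-loop: skip words before the counted unigram (output epsilon)
compiler.write("0 1 7 7\n")   # consume and emit the counted word
compiler.write("1 1 7 0\n")   # self-loop: skip words after it
compiler.write("1\n")         # final state
unigram_counter = compiler.compile()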
Example #11
def __test(fig_dir, num_classes=2, num_samples=10, min_dur=2, max_dur=4):
    eps_str = 'ε'
    bos_str = '<BOS>'
    eos_str = '<EOS>'
    dur_internal_str = 'I'
    dur_final_str = 'F'

    aux_symbols = (eps_str, bos_str, eos_str)
    dur_vocab = (dur_internal_str, dur_final_str)
    sample_vocab = tuple(i for i in range(num_samples))
    dur_vocab = tuple(i for i in range(1, max_dur + 1))
    class_vocab = tuple(i for i in range(num_classes))
    class_dur_vocab = tuple((c, s) for c in class_vocab for s in dur_vocab)

    def to_strings(vocab):
        return tuple(map(str, vocab))

    def to_integerizer(vocab):
        return {item: i for i, item in enumerate(vocab)}

    sample_vocab_str = to_strings(sample_vocab)
    dur_vocab_str = to_strings(dur_vocab)
    class_vocab_str = to_strings(class_vocab)
    class_dur_vocab_str = to_strings(class_dur_vocab)

    # sample_integerizer = to_integerizer(sample_vocab)
    dur_integerizer = to_integerizer(dur_vocab)
    class_integerizer = to_integerizer(class_vocab)
    # class_dur_integerizer = to_integerizer(class_dur_vocab)

    sample_str_integerizer = to_integerizer(sample_vocab_str)
    class_str_integerizer = to_integerizer(class_vocab_str)
    class_dur_str_integerizer = to_integerizer(class_dur_vocab_str)

    def get_parts(class_dur_key):
        c, s = class_dur_vocab[class_dur_str_integerizer[class_dur_key]]
        c_str = class_vocab_str[class_integerizer[c]]
        s_str = dur_vocab_str[dur_integerizer[s]]
        return c_str, s_str

    class_dur_to_str = {get_parts(name): name for name in class_dur_vocab_str}

    sample_symbols = libfst.makeSymbolTable(aux_symbols + sample_vocab,
                                            prepend_epsilon=False)
    # dur_symbols = libfst.makeSymbolTable(aux_symbols + dur_vocab, prepend_epsilon=False)
    class_symbols = libfst.makeSymbolTable(aux_symbols + class_vocab,
                                           prepend_epsilon=False)
    class_dur_symbols = libfst.makeSymbolTable(aux_symbols + class_dur_vocab,
                                               prepend_epsilon=False)

    obs_scores = np.zeros((num_samples, num_classes))

    dur_scores = np.array([[0 if d >= min_dur else np.inf for d in dur_vocab]
                           for c in class_vocab],
                          dtype=float)

    def score_transition(c_prev, s_prev, c_cur, s_cur):
        if c_prev != c_cur:
            if s_prev == dur_final_str and s_cur == dur_internal_str:
                score = 0
            else:
                score = np.inf
        else:
            if s_prev == dur_internal_str and s_cur == dur_final_str:
                score = 0
            elif s_prev == dur_internal_str and s_cur == dur_internal_str:
                score = 0
            else:
                score = np.inf
        return score

    transition_scores = np.array([[
        score_transition(c_prev, s_prev, c_cur, s_cur)
        for (c_cur, s_cur) in class_dur_vocab
    ] for (c_prev, s_prev) in class_dur_vocab],
                                 dtype=float)
    init_scores = np.array([
        0 if c == 0 and s == dur_internal_str else np.inf
        for (c, s) in class_dur_vocab
    ],
                           dtype=float)
    final_scores = np.array([
        0 if c == 1 and s == dur_final_str else np.inf
        for (c, s) in class_dur_vocab
    ],
                            dtype=float)

    def score_arc_state(class_dur_key, class_key):
        c, s = class_dur_vocab[class_dur_str_integerizer[class_dur_key]]
        c_prime = class_str_integerizer[class_key]

        if c == c_prime:
            score = 0
        else:
            score = np.inf

        return score

    class_dur_to_class_scores = np.array([[
        score_arc_state(class_dur_key, class_key)
        for class_key in class_vocab_str
    ] for class_dur_key in class_dur_vocab_str],
                                         dtype=float)

    def log_normalize(arr, axis=1):
        denom = -scipy.special.logsumexp(-arr, axis=axis, keepdims=True)
        return arr - denom

    obs_scores = log_normalize(obs_scores)
    dur_scores = log_normalize(dur_scores)
    transition_scores = log_normalize(transition_scores)
    init_scores = log_normalize(init_scores, axis=None)
    final_scores = log_normalize(final_scores, axis=None)

    obs_fst = add_endpoints(fromArray(obs_scores,
                                      sample_vocab_str,
                                      class_vocab_str,
                                      input_symbols=sample_symbols,
                                      output_symbols=class_symbols,
                                      arc_type='standard'),
                            bos_str=bos_str,
                            eos_str=eos_str).arcsort(sort_type='ilabel')

    dur_fst = add_endpoints(make_duration_fst(
        dur_scores,
        class_vocab_str,
        class_dur_to_str,
        dur_internal_str=dur_internal_str,
        dur_final_str=dur_final_str,
        input_symbols=class_symbols,
        output_symbols=class_dur_symbols,
        allow_self_transitions=False,
        arc_type='standard',
    ),
                            bos_str=bos_str,
                            eos_str=eos_str).arcsort(sort_type='ilabel')

    transition_fst = fromTransitions(
        transition_scores,
        class_dur_vocab_str,
        class_dur_vocab_str,
        init_weights=init_scores,
        final_weights=final_scores,
        input_symbols=class_dur_symbols,
        output_symbols=class_dur_symbols).arcsort(sort_type='ilabel')

    class_dur_to_class_fst = single_state_transducer(
        class_dur_to_class_scores,
        class_dur_vocab_str,
        class_vocab_str,
        input_symbols=class_dur_symbols,
        output_symbols=class_symbols,
        arc_type='standard').arcsort(sort_type='ilabel')

    seq_model = openfst.compose(dur_fst, transition_fst)
    decode_lattice = openfst.compose(obs_fst, seq_model).rmepsilon()

    # Result is in the log semiring (ie weights are negative log probs)
    arc_scores = libfst.fstArcGradient(decode_lattice).arcsort(
        sort_type='ilabel')
    best_arcs = openfst.shortestpath(decode_lattice).arcsort(
        sort_type='ilabel')

    state_scores = openfst.compose(
        arc_scores, openfst.arcmap(class_dur_to_class_fst, map_type='to_log'))
    best_states = openfst.compose(best_arcs, class_dur_to_class_fst)

    state_scores_arr, weight_type = toArray(state_scores,
                                            sample_str_integerizer,
                                            class_str_integerizer)

    draw_fst(os.path.join(fig_dir, 'obs_fst'),
             obs_fst,
             vertical=True,
             width=50,
             height=50,
             portrait=True)

    draw_fst(os.path.join(fig_dir, 'dur_fst'),
             dur_fst,
             vertical=True,
             width=50,
             height=50,
             portrait=True)

    draw_fst(os.path.join(fig_dir, 'transition_fst'),
             transition_fst,
             vertical=True,
             width=50,
             height=50,
             portrait=True)

    draw_fst(os.path.join(fig_dir, 'seq_model'),
             seq_model,
             vertical=True,
             width=50,
             height=50,
             portrait=True)

    draw_fst(os.path.join(fig_dir, 'decode_lattice'),
             decode_lattice,
             vertical=True,
             width=50,
             height=50,
             portrait=True)

    draw_fst(os.path.join(fig_dir, 'state_scores'),
             state_scores,
             vertical=True,
             width=50,
             height=50,
             portrait=True)

    draw_fst(os.path.join(fig_dir, 'best_states'),
             best_states,
             vertical=True,
             width=50,
             height=50,
             portrait=True)

    utils.plot_array(obs_scores.T, (state_scores_arr.T, ), ('-logprobs', ),
                     fn=os.path.join(fig_dir, "test_io.png"))