Example #1
0
    def topk_choice(self, word_sequence, topk_wds=None):
        """Build the lattice of likely completions for a word history.

        The history lattice is intersected with the language model
        (``self.lm``), reduced to its 10 shortest paths, mapped through
        ``self.refiner`` and projected onto the output side. The result is
        epsilon-free, determinized and minimized.

        Args:
            word_sequence: FST encoding the word history (lattice).
            topk_wds: optional list. Whenever it is not ``None`` -- an empty
                list included -- it is extended in place with
                ``self.get_topk_words`` applied to the returned FST.

        Returns:
            The output-projected FST of candidate completions.
        """
        # Keep only LM paths consistent with the history.
        lattice = fst.determinize(fst.intersect(word_sequence, self.lm))
        pushed = fst.push(lattice, push_weights=True, to_final=True)
        pushed.rmepsilon()
        nbest = fst.determinize(
            fst.shortestpath(fst.determinize(pushed), nshortest=10))
        nbest.rmepsilon()

        # Map the n-best paths through the refiner; keep output labels only.
        completions = fst.compose(nbest, self.refiner).project(
            project_output=True)
        completions.rmepsilon()
        completions = fst.determinize(completions)
        completions.minimize()

        # An empty list must still be filled, so compare against None.
        if topk_wds is not None:
            topk_wds.extend(self.get_topk_words(completions))
        return completions
Example #2
0
 def separate_sausage(self, sausage, helper_fst):
     """Split a history sausage at its last space.

     Which side of the space (before or after) survives is decided
     entirely by ``helper_fst``.

     Args:
         sausage: the history sausage FST.
         helper_fst: FST that selects the part of the history to keep.

     Returns:
         The projected, epsilon-free, determinized, minimized and
         topologically sorted FST.
     """
     result = fst.compose(sausage, helper_fst)
     result.project(True)
     result.rmepsilon()
     result = fst.determinize(result)
     # minimize() and topsort() both modify ``result`` in place.
     result.minimize()
     result.topsort()
     return result
Example #3
0
def easyCompose(*fsts, determinize=True, minimize=True):
    """Compose FSTs left to right and return the result.

    Args:
        *fsts: one or more FSTs. ``fsts[0]`` is composed with ``fsts[1]``,
            that result with ``fsts[2]``, and so on. At least one FST is
            required; indexing ``fsts[0]`` raises ``IndexError`` otherwise.
        determinize: if true, determinize after every composition step.
        minimize: if true, minimize the final result.

    Returns:
        The composed FST. The caller's FSTs are never modified. When a
        single FST is passed with ``minimize=False``, that same object is
        returned.
    """
    composed = fsts[0]
    for other in fsts[1:]:
        composed = openfst.compose(composed, other)
        if determinize:
            composed = openfst.determinize(composed)

    if minimize:
        if len(fsts) == 1:
            # Nothing was composed, so ``composed`` is still the caller's
            # FST; minimize a copy instead of mutating the argument.
            composed = composed.copy()
        composed.minimize()

    return composed
Example #4
0
    def word_sequence_history(self, eeg_saus):
        """Build a probable word-sequence lattice from EEG evidence.

        The EEG sausage is composed with the letter-to-word transducer
        ``self.ltr2wrd``. Weights are pushed towards the final states, and
        the result is projected onto its output side with epsilons removed.

        Args:
            eeg_saus: sausage FST built from the EEG samples.

        Returns:
            The output-projected, epsilon-free word-sequence FST.
        """
        word_seq = fst.compose(eeg_saus, self.ltr2wrd)
        # The module-level ``fst.push`` is constructive: it returns a new FST
        # and leaves its argument untouched, so the result must be kept
        # (previously it was discarded and the push had no effect).
        word_seq = fst.push(word_seq, push_weights=True, to_final=True)
        word_seq.project(project_output=True)
        word_seq.rmepsilon()
        return word_seq
Example #5
0
    def fit(self, train_samples, train_labels, observation_scores=None, num_epochs=1):
        """Fit the model parameters with repeated gradient steps on one batch.

        Args:
            train_samples: inputs scored with ``self.score`` when
                ``observation_scores`` is not supplied.
            train_labels: one label per sample. Each is converted to arc
                labels via ``toTransitionSeq`` and ``self._transition_to_arc``.
            observation_scores: optional precomputed per-sample scores. Each
                item is reshaped to 2-D, keeping its first axis, before
                ``fromArray``.
            num_epochs: number of parameter updates to perform.

        Returns:
            ``(losses, params)``: a numpy array holding one float loss per
            epoch, and a list of parameter copies taken after each update.
        """
        all_scores = (
            observation_scores if observation_scores is not None
            else self.score(train_samples)
        )

        # One observation FST per sample, unioned into a single batch FST.
        obs_batch_fst = easyUnion(
            *[
                fromArray(s.reshape(s.shape[0], -1),
                          output_labels=self._transition_to_arc)
                for s in all_scores
            ],
            disambiguate=True,
        )

        # Ground-truth sequences expressed as arc labels, then batched.
        arc_label_seqs = [
            [self._transition_to_arc[t] for t in toTransitionSeq(label)]
            for label in train_labels
        ]
        gt_batch_fst = easyUnion(
            *[
                fromSequence(seq, symbol_table=obs_batch_fst.output_symbols())
                for seq in arc_label_seqs
            ],
            disambiguate=True,
        )

        loss_history, param_history = [], []
        for _ in range(num_epochs):
            # denominator = observations o sequence model;
            # numerator = denominator o ground truth.
            seq_fst = seqFstToBatch(self._makeSeqFst(), gt_batch_fst)
            denom_fst = openfst.compose(obs_batch_fst, seq_fst)
            num_fst = openfst.compose(denom_fst, gt_batch_fst)
            batch_loss, batch_arcgrad = fstProb(num_fst, denom_fst)

            step = self._backward(batch_arcgrad)
            self._params = self._update_params(self._params, step)

            param_history.append(self._params.copy())
            loss_history.append(float(batch_loss))

        return np.array(loss_history), param_history
Example #6
0
 def wrd2ltr(self, fstout):
     """Convert a word-level FST into a letter-level FST.

     ``fstout`` is normalized with ``self.normalize`` and composed with the
     spelling transducer ``self.spell``. The result is weight-pushed,
     projected onto its output side, epsilon-removed and determinized.
     Every state except the start state is then made final, so accepted
     paths may end at any non-start state.

     Args:
         fstout: word-level FST to spell out.

     Returns:
         The letter-level FST.
     """
     norm_fst = self.normalize(fstout)
     letter = fst.compose(norm_fst, self.spell)
     letter.push(to_final=False)
     letter.project(project_output=True)
     letter.rmepsilon()
     letter = fst.determinize(letter)
     # Ask the FST for its start state rather than assuming it is state 0.
     start = letter.start()
     for state in letter.states():
         if state != start:
             letter.set_final(state)
     return letter
Example #7
0
def alternatives(sequence):
    """Return the n-best alternatives to ``sequence`` built from vocabulary units.

    ``sequence`` (a list of words) is compiled into a linear FST over
    ``printable_ST``: one arc per character, plus a ``</w>`` arc after each
    word. It is composed with ``grapheme_confusion`` and then ``fst_vocab``,
    and the ``n_best`` shortest paths are read back with ``printstrings``.

    Returns:
        ``(alters, scores)``. ``alters`` is a list of alternatives, each a
        list of words; ``scores`` holds the matching path weights as floats.
        When ``printstrings`` returns nothing truthy, that value is returned
        as ``alters`` unchanged and ``scores`` is ``[]``.
    """
    # Build a linear FST: one arc per character, one "</w>" arc per word.
    compiler_sequence = fst.Compiler(isymbols=printable_ST,
                                     osymbols=printable_ST,
                                     keep_isymbols=True,
                                     keep_osymbols=True)
    c = 0
    for word in sequence:
        for char in word:
            print(c, c + 1, char, char, file=compiler_sequence)
            c += 1
        print(c, c + 1, '</w>', '</w>', file=compiler_sequence)
        c += 1
    # Last state is final.
    print(c, file=compiler_sequence)
    fst_sequence = compiler_sequence.compile()
    # set_*_symbols modify the FST in place; don't depend on their return value.
    fst_sequence.set_input_symbols(printable_ST)
    fst_sequence.set_output_symbols(printable_ST)

    composition = fst.compose(fst_vocab,
                              fst.compose(grapheme_confusion,
                                          fst_sequence)).rmepsilon().arcsort()
    alters = printstrings(composition,
                          nshortest=n_best,
                          syms=printable_ST,
                          weight=True)
    scores = []
    if alters:
        print(alters)
        scores = [float(alt[1]) for alt in alters]
        # Each path string looks like "c h a r </w> c h a r </w>": split it
        # into words and drop the spaces between characters.
        alters = [
            [''.join(word.split(' ')) for word in alt[0].split(' </w>')[:-1]]
            for alt in alters
        ]
    return alters, scores


def align_stdin_sentences(syms, syms_list, unk_id, g):
    """Align each stdin line against the FST ``g`` and print the output labels.

    NOTE(review): reconstructed from a loop-body fragment that had been
    pasted after ``return`` in ``alternatives``. There it was unreachable,
    and its ``break`` had no enclosing loop. The data names it used freely
    are now parameters; confirm against the original script.

    Each line is processed as follows:

    - Its words are mapped to ids with ``syms``. Unknown words map to
      ``unk_id`` and are remembered in order.
    - An acceptor is built with ``make_sentence_fsa`` and composed with ``g``.
    - The non-epsilon output labels along the shortest path are printed.
      Each ``unk_id`` is replaced by the next remembered unknown word.

    Args:
        syms: mapping from word to symbol id.
        syms_list: sequence indexed by symbol id, giving the word.
        unk_id: symbol id used for out-of-vocabulary words.
        g: FST the sentence acceptor is composed with.
    """
    import sys

    while True:
        line = sys.stdin.readline()
        if not line:
            break
        unks = []
        word_ids = []
        for word in line.split():
            word_id = syms.get(word, unk_id)
            word_ids.append(word_id)
            if word_id == unk_id:
                unks.append(word)

        sentence = make_sentence_fsa(syms, word_ids)
        sentence.arcsort(sort_type="olabel")
        composed = fst.compose(sentence, g)

        alignment = fst.shortestpath(composed)
        alignment.rmepsilon()
        # Topological order makes states() walk the path from start to end.
        alignment.topsort()

        labels = []
        for state in alignment.states():
            for arc in alignment.arcs(state):
                # Label 0 is epsilon in OpenFst; skip it.
                if arc.olabel > 0:
                    if arc.olabel == unk_id:
                        labels.append(unks.pop(0))
                    else:
                        labels.append(syms_list[arc.olabel])
        print(" ".join(labels))
        sys.stdout.flush()