Example #1
File: oclm.py Project: shiranD/oclm
    def get_priors(self):
        '''
        get priors from oclm
        '''
        history_chars = self.lm.separate_sausage(
            self.lm.history_fst, self.lm.ch_before_last_space_fst)
        trailing_chars = self.lm.separate_sausage(
            self.lm.history_fst, self.lm.ch_after_last_space_fst)

        word_lattice = fst.compose(history_chars, self.lm.ltr2wrd)
        word_lattice.project(project_output=True).rmepsilon()
        word_lattice = fst.determinize(word_lattice)
        word_lattice.minimize().topsort()
        if word_lattice.num_states() == 0:
            word_lattice = self.lm.create_empty_fst(self.lm.wd_syms,
                                                    self.lm.wd_syms)
        trailing_chars_sigma = self.lm.add_char_selfloop(
            trailing_chars, self.lm.ch_syms)
        trailing_chars_sigma.arcsort(sort_type="olabel")
        current_words = fst.compose(trailing_chars_sigma, self.lm.tr_ltr2wrd)
        current_words.project(project_output=True).rmepsilon()
        current_words = fst.determinize(current_words)
        current_words.minimize().topsort()
        word_seq = word_lattice.copy().concat(current_words).rmepsilon()
        topk_wds = []
        topk = self.lm.topk_choice(word_seq)
        united_LM = self.lm.combine_ch_lm(topk, topk_wds)
        return self.lm.next_char_dist(trailing_chars, united_LM), topk
Example #2
def processLattices(lats_sets,folders,statePruneTh=10000,pruneTh=10,silence=False):
    '''Applies standard pre-processing operations to SMT lattices
    @lats_sets: lattices to be processed
    @folders: output folders for processed lattices
    @statePruneTh: fsts above this threshold are pruned
    @pruneTh: pruning threshold
    @silence: if True, then the function does not print which lattice is being processed'''
    for lats_set,folder in zip(lats_sets,folders):
        print(lats_set)
        print(folder)
        for f in sorted(glob.glob(lats_set),key=numericalSort):
            lattice = fst.Fst.read(f)
            if lattice.num_states() > statePruneTh:
                detminpush = fst.push(
                    fst.arcmap(fst.determinize(lattice.rmepsilon()).minimize(),
                               map_type="to_log"),
                    push_weights=True)
                pruned = fst.prune(
                    fst.arcmap(detminpush, map_type="to_standard"),
                    weight=pruneTh).minimize()
                out = fst.arcmap(
                    fst.push(fst.arcmap(pruned, map_type="to_log"),
                             push_weights=True),
                    map_type="to_standard")
                out.write(folder+os.path.basename(f))
                if not silence:
                    print(os.path.basename(f))
            else:
                # detminpush = fst.push(fst.determinize(fst.arcmap(lattice.rmepsilon(),map_type="to_log")).minimize(),push_weights=True)
                detminpush = fst.push(fst.arcmap(fst.determinize(lattice.rmepsilon()).minimize(),map_type="to_log"),push_weights=True)
                out = fst.arcmap(detminpush,map_type="to_standard")
                out.write(folder+os.path.basename(f))
                if not silence:
                    print(os.path.basename(f))
Example #3
    def topk_choice(self, word_sequence, topk_wds=None):
        '''
        Extracts the top-k choices of the LM given a word history (lattice).
        Input: the LM FST and a sentence lattice (word_sequence).
        Output: the top-k words to complete the lattice.
        '''

        # generate sentence fst
        fstout = fst.intersect(word_sequence, self.lm)
        fst_det = fst.determinize(fstout)
        fst_p = fst.push(fst_det, push_weights=True, to_final=True)
        fst_p.rmepsilon()
        fst_rm = fst.determinize(fst_p)
        short = fst.shortestpath(fst_rm, nshortest=10)
        short_det = fst.determinize(short)
        short_det.rmepsilon()
        two_state = fst.compose(short_det, self.refiner)
        output = two_state.project(project_output=True)
        output.rmepsilon()
        output = fst.determinize(output)
        output.minimize()
        if topk_wds is not None:  # Needs to distinguish None and []
            topk_wds.extend(self.get_topk_words(output))
        return output
Example #4
 def expand_rtn(self, func):
     """This method expands the RTN as far as necessary. This means
     that the RTN is expanded s.t. we can build the posterior for 
     ``cur_history``. In practice, this means that we follow all 
     epsilon edges and replaces all NT edges until all paths with 
     the prefix ``cur_history`` in the RTN have at least one more 
     terminal token. Then, we apply ``func`` to all reachable nodes.
     """
     updated = True
     while updated:
         updated = False
         label_fst_map = {}
         self.visited_nodes = {}
         self.cur_fst.arcsort(sort_type="olabel")
         self.add_to_label_fst_map_recursive(label_fst_map, {},
                                              self.cur_fst.start(), 0.0,
                                             self.cur_history, func)
         if label_fst_map:
             logging.debug("Replace %d NT arcs for history %s" %
                           (len(label_fst_map), self.cur_history))
             # First in the list is the root FST and label
             replaced_fst = fst.replace(
                 [(len(label_fst_map) + 2000000000, self.cur_fst)] +
                 [(nt_label, f)
                  for (nt_label, f) in label_fst_map.iteritems()],
                 epsilon_on_replace=True)
             self.cur_fst = replaced_fst
             updated = True
     if self.rmeps or self.minimize_rtns:
         self.cur_fst.rmepsilon()
     if self.minimize_rtns:
          tmp = fst.determinize(self.cur_fst)
         self.cur_fst = tmp
         self.cur_fst.minimize()
Example #5
def test_compose_token_and_lexicon_fst_with_homophones(workdir,
                                                       words_with_homophones):
    vocab = get_vocabulary_table(workdir, words_with_homophones)
    lexicon = get_lexicon(words_with_homophones)

    phoneme_table = PhonemeTable()
    phoneme_table.add_labels(phonemes)

    lexicon_fst = lexicon.create_fst(phoneme_table, vocab, min_freq=0)

    token = Token()
    token_fst = token.create_fst(phoneme_table)

    fst = pywrapfst.compose(token_fst.arcsort('olabel'), lexicon_fst)
    with pytest.raises(pywrapfst.FstOpError):
        pywrapfst.determinize(fst)
Example #6
def write_fst(f, path):
    """Writes FST f to the file system after epsilon removal, determinization,
    and minimization.
    """
    f.rmepsilon()
    f = fst.determinize(f)
    f.minimize()
    f.write(path)
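
A minimal usage sketch for write_fst, assuming the example's module-level alias `import pywrapfst as fst`; the toy two-state acceptor and the output path are hypothetical:

import pywrapfst as fst

def build_toy_acceptor():
    # Two-state acceptor with a single arc labeled 1.
    f = fst.Fst()
    s0 = f.add_state()
    s1 = f.add_state()
    f.set_start(s0)
    f.add_arc(s0, fst.Arc(1, 1, fst.Weight.One(f.weight_type()), s1))
    f.set_final(s1)
    return f

write_fst(build_toy_acceptor(), 'toy.fst')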
Example #7
 def get_decoding_fst(self):
     hc_fst = self.get_hc_fst()
     if self.grammar_fst_path:
         ret = fst.compose(hc_fst, self.get_grammar_fst())
         ret = fst.rmepsilon(fst.determinize(ret))
         return ret
     else:
         return hc_fst
Example #8
def run(words):

    construct_verbalizer('here is a check for $40 and $5.95 and 10 puppy')
    words = ['10', '$1.95', '$50', 'hello', '$40', "new", "york", "110th"]

    with open('CNN_HD_2018-12-12_14-29-00.001.srt', 'r') as f:
        words = f.read().strip().split()

    big_verb_fst = None
    unique_vocab_set = set()

    verbalizers = []

    a = time.time()
    for w in words:
        verbalizer, unique_vocab = construct_verbalizer(w)
        for v in unique_vocab:
            unique_vocab_set.add(v)
        verbalizers.append(verbalizer)
    print(time.time() - a, 'seconds for verbalizers')
    unique_vocab_set = list(unique_vocab_set)

    a = time.time()
    lexicon_fst, epsilon_fst = make_lexicon_fst(words, unique_vocab_set)
    print(time.time() - a, 'seconds for lexicon')

    unique_vocab_set += words

    big_verb_fst = None
    a = time.time()
    for w, v in zip(words, verbalizers):
        composed = fst.compose(v.project().arcsort(sort_type='olabel'),
                               lexicon_fst).project(project_output=True)
        composed = fst.compose(epsilon_fst.arcsort(sort_type='olabel'),
                               composed)
        trivial_word_fst = get_trivial_fst(unique_vocab_set.index(w) + 1)
        concat = fst.determinize(
            trivial_word_fst.concat(composed).rmepsilon().invert()).minimize()

        if big_verb_fst is None:
            big_verb_fst = concat
        else:
            big_verb_fst = big_verb_fst.union(concat)
    big_verb_fst = big_verb_fst.rmepsilon()
    print(time.time() - a, 'seconds for compositions')

    big_verb_fst.write('a.fst')
    return

    #    verbalizer = verbalizer.project()
    #verbalizer.write('/home/philip/graves_loss/hive-speech/alignment/src/test.fst')

    verbalizer = fst.compose(
        verbalizer, space_deduper)  #.project(project_output=True).rmepsilon()
    verbalizer.write('check.fst')
    verbalizer.write(
        '/home/philip/graves_loss/hive-speech/alignment/src/test.fst')
Example #9
def supervisor(MK, P=None, As=None, Aa=None):
    """Synthesizes an attack-resilient supervisor for the plant P, the desired language MK, the sensor attacker As and the actuator attacker Aa.
    
    Parameters
    ----------
    MK : pywrapfst.Fst 
        The FST for the desired language
    P : pywrapfst.Fst, optional
        The FST for the plant
    As : pywrapfst.Fst, optional
        The FST for the sensor attacker
    Aa : pywrapfst.Fst, optional
        The FST for the actuator attacker
    
    Returns
    -------
    S : pywrapfst.Fst 
        The attack-resilient supervisor
    controllable : bool
        True if the desired language is controllable
        
    Examples
    --------
    Examples should be written in doctest format, and should illustrate how
    to use the function.

    >>> import arsc
    >>> MK,P,As,Aa = arsc.example()
    >>> MK
    <vector Fst at 0x2a0a94178f0>
    >>> S, controllable = arsc.supervisor(MK,P,As,Aa)
    >>> S
    <vector Fst at 0x2a0a9417c00>
    >>> controllable
    True
    """

    S = MK.copy().arcsort()
    if P:
        S = fst.compose(P.copy().invert().arcsort(), S).arcsort()
    if As:
        S = fst.compose(As.copy().invert().arcsort(), S).arcsort()
    if Aa:
        S = fst.compose(S, Aa.copy().invert().arcsort()).arcsort()

        LO = fst.compose(
            fst.compose(MK.copy().arcsort(),
                        Aa.copy().invert().arcsort()),
            Aa.copy().arcsort()).arcsort().project(project_output=True)
        LO = fst.determinize(fst.epsnormalize(LO)).minimize().arcsort()
        K = fst.epsnormalize(MK.copy().arcsort().project(
            project_output=True)).minimize().arcsort()
        controllable = fst.equivalent(LO, K)
    else:
        controllable = True

    return S, controllable
Example #10
def fst_finalize(c, last_node, eos_node, path):
  fst_arc(c, last_node, eos_node, args.eos_id)
  c.write("%d\n" % eos_node)
  f = c.compile()
  f.rmepsilon()
  f = fst.determinize(f)
  f.minimize()
  f.topsort()
  f = fst.push(f, push_weights=True)
  f.write(path)
Example #11
 def create_fst(self, token_fst, lexicon_fst, grammar_fst):
     """Create FST by composing token, lexicon and grammar FST
     Args:
         token_fst (pywrapfst._MutableFst): A token FST
         lexicon_fst (pywrapfst._MutableFst): A lexicon FST
         grammar_fst (pywrapfst._MutableFst): A grammar FST
     """
     LG = pywrapfst.determinize(
         pywrapfst.compose(lexicon_fst, grammar_fst.arcsort()).rmepsilon()
     ).minimize()
     self._fst = pywrapfst.compose(token_fst, LG.arcsort())
Example #12
 def restrict_chars_length(self, trailing_chars):
     '''
     Restrict the length of trailing characters to avoid complex computation.
     The exact length limit depends on the implementation of the FST.
     '''
     restricted_length = fst.compose(trailing_chars, self.length_fst)
     restricted_length.project(project_output=True)
     restricted_length.rmepsilon()
     restricted_length = fst.determinize(restricted_length)
     restricted_length.minimize()
     return restricted_length
Example #13
 def separate_sausage(self, sausage, helper_fst):
     '''
     Separates history sausage based on the last space. The direction
     (before/after) depends on the helper fst passed in.
     '''
     chars = fst.compose(sausage, helper_fst)
     chars.project(True)
     chars.rmepsilon()
     chars = fst.determinize(chars)
     chars.minimize().topsort()
     return chars
Example #14
def easyCompose(*fsts, determinize=True, minimize=True):
    composed = fsts[0]
    for fst in fsts[1:]:
        composed = openfst.compose(composed, fst)
        if determinize:
            composed = openfst.determinize(composed)

    if minimize:
        composed.minimize()

    return composed
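
A small way to exercise easyCompose, assuming `openfst` is the `pywrapfst` module as in the example; the linear-acceptor helper and the integer labels are invented for illustration. The first argument is arc-sorted on output labels so the composition is well defined:

import pywrapfst as openfst

def linear_acceptor(labels):
    # Straight-line acceptor over the given integer labels.
    f = openfst.Fst()
    state = f.add_state()
    f.set_start(state)
    for label in labels:
        nxt = f.add_state()
        f.add_arc(state, openfst.Arc(label, label,
                                     openfst.Weight.One(f.weight_type()), nxt))
        state = nxt
    f.set_final(state)
    return f

a = linear_acceptor([1, 2]).arcsort(sort_type="olabel")
b = linear_acceptor([1, 2])
result = easyCompose(a, b, determinize=True, minimize=True)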
Example #15
 def determinize(self):
     """
     Transforms a nondeterministic finite automaton into a deterministic one.
     Args:
         None
     Returns:
         DFA: The resulting DFA
     """
     # This function is not necessary
     self.automaton = fst.determinize(self.automaton)
     return self
Example #16
def make_lexicon_fst(input_words, words):
    compiler = fst.Compiler()
    lexicon_fst = fst.Fst()
    start = lexicon_fst.add_state()
    lexicon_fst.set_start(start)

    last = lexicon_fst.add_state()

    # map the space character (label 32) to epsilon on the output side
    lexicon_fst.add_arc(
        start, fst.Arc(32, 0, fst.Weight.One(lexicon_fst.weight_type()), last))
    lexicon_fst.set_final(last)
    for i, w in enumerate(words):
        w = w.strip()
        index = i + 1
        last = lexicon_fst.add_state()
        lexicon_fst.add_arc(
            start,
            fst.Arc(ord(w[0]), index,
                    fst.Weight.One(lexicon_fst.weight_type()), last))
        for c in w[1:]:
            this = lexicon_fst.add_state()
            lexicon_fst.add_arc(
                last,
                fst.Arc(ord(c), 0, fst.Weight.One(lexicon_fst.weight_type()),
                        this))
            last = this
        lexicon_fst.set_final(last, 0)
    lexicon_fst = fst.determinize(lexicon_fst).minimize().closure()

    with open('words.syms', 'w') as f:
        f.write('<eps> 0\n')
        for i, w in enumerate(words + input_words):  # we put word symbol here
            f.write('{} {}'.format(w, str(i + 1)))
            f.write('\n')
        f.write('<SPACE> {}\n'.format(str(32)))

    epsilon_fst = fst.Fst()
    start = epsilon_fst.add_state()
    end = epsilon_fst.add_state()
    for i, w in enumerate(words):
        index = i + 1
        epsilon_fst.add_arc(
            start,
            fst.Arc(0, index, fst.Weight.One(epsilon_fst.weight_type()), end))

    epsilon_fst.add_arc(
        start, fst.Arc(0, 32, fst.Weight.One(epsilon_fst.weight_type()), end))
    epsilon_fst.set_final(end, 0)
    epsilon_fst.set_start(start)
    epsilon_fst = epsilon_fst.closure()

    return lexicon_fst, epsilon_fst
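
A hedged usage sketch for make_lexicon_fst, assuming a module-level `import pywrapfst as fst` as in the surrounding examples; the word lists are hypothetical, and note that the function writes a words.syms file into the working directory as a side effect:

lexicon_fst, epsilon_fst = make_lexicon_fst(['ten'], ['one', 'two'])
print(lexicon_fst.num_states(), epsilon_fst.num_states())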
Example #17
 def spellout(self, ifst, topk_wds=None):
     '''
     Spells out all the words in the ifst.
     '''
     if ifst is None:
         return None
     ifst.rmepsilon()
     ofst = fst.determinize(ifst)
     ofst.minimize()
     if topk_wds is not None:
         topk_wds.extend(self.get_topk_words(ofst))
     return self.wrd2ltr(ofst)
Example #18
    def next_char_dist(self, history, char_lm):
        '''
        Get the distribution of next character.
        '''
        history = self.concat_alphabet(history)
        history.arcsort(sort_type="olabel")
        output = fst.intersect(history, char_lm)
        output.rmepsilon()
        output = fst.determinize(output)
        output.minimize()

        # Composes with an FST that combines the weights of the next character.
        last_ltr = fst.compose(output, self.ltr_dist)
        last_ltr.project(True)
        last_ltr.push(to_final=True)
        last_ltr.rmepsilon()
        last_ltr = fst.determinize(last_ltr)
        last_ltr.minimize()

        # Extracts priors. Although it's a two-state machine, we have the
        # generic traverse procedure here just in case.
        prev_stateid = curr_stateid = None
        for state in last_ltr.states():
            if curr_stateid is not None:
                prev_stateid = curr_stateid
            curr_stateid = state
        priors = []
        syms = last_ltr.input_symbols()
        for arc in last_ltr.arcs(prev_stateid):
            ch = syms.find(arc.ilabel)
            w = float(arc.weight)
            if len(ch) == 1:
                priors.append((ch, w))

        # Sorts the priors by probability and normalizes them.
        priors = sorted(priors, key=lambda prior: prior[1])
        priors_vals = [BitWeight(prob) for _,prob in priors]
        total = sum(priors_vals, BitWeight(1e6))
        norm_priors = [(prob / total).loge() for prob in priors_vals]
        return zip([ch for ch,_ in priors], norm_priors)
Example #19
def test_compose_token_and_lexicon_fst(workdir, words_without_homophones):
    vocab = get_vocabulary_table(workdir, words_without_homophones)
    lexicon = get_lexicon(words_without_homophones)

    phoneme_table = PhonemeTable()
    phoneme_table.add_labels(phonemes)

    lexicon_fst = lexicon.create_fst(phoneme_table, vocab)

    token = Token()
    token_fst = token.create_fst(phoneme_table)

    fst = pywrapfst.compose(token_fst.arcsort('olabel'), lexicon_fst)
    fst = pywrapfst.determinize(fst)
Example #20
 def wrd2ltr(self, fstout):
     # if there are normalization methods in fst land..
     norm_fst = self.normalize(fstout)
     letter = fst.compose(norm_fst, self.spell)
     letter.push(to_final=False)
     letter.project(project_output=True)
     letter.rmepsilon()
     letter = fst.determinize(letter)
     for state in letter.states():
         if state==0:
             continue
         letter.set_final(state)
     
     return letter
Example #21
    def update(self, ch_dist):
        '''
        Update the history with the new likelihood array, given in the correct
        scale (negative log space).
        '''
        new_ch = fst.Fst()
        new_ch.set_input_symbols(self.ch_syms)
        new_ch.set_output_symbols(self.ch_syms)
        new_ch.add_state()
        new_ch.set_start(0)
        new_ch.add_state()
        new_ch.set_final(1)
        space_code = -1
        space_pr = 0.
        for ch, pr in ch_dist:
            code = self.ch_syms.find(ch)
            if ch == '#':  # Adds space after we finish updating trailing chars.
                space_code = code
                space_pr = pr
                continue
            new_ch.add_arc(0, fst.Arc(code, code, pr, 1))
        new_ch.arcsort(sort_type="olabel")

        # Adds the trailing characters to existing binned history.
        for words_bin in self.prefix_words:
            if words_bin[2] >= 10:  # We discard the whole trail in this case (TODO)
                continue
            # Unless we are testing a straight line machine, this normally
            # doesn't happen in practice.
            if new_ch.num_arcs(0) == 0:
                continue
            words_bin[1].concat(new_ch).rmepsilon()
            words_bin[2] += 1

        # Continues updating the history and adds back the space if necessary.
        if space_code >= 0:
            new_ch.add_arc(0, fst.Arc(space_code, space_code, space_pr, 1))
        self.history_fst.concat(new_ch).rmepsilon()

        # Respectively update the binned history
        if space_code >= 0:  # If there is a space
            # Finishes the prefix words in current position
            word_lattice = fst.compose(self.history_fst, self.ltr2wrd)
            word_lattice.project(project_output=True).rmepsilon()
            word_lattice = fst.determinize(word_lattice)
            word_lattice.minimize()
            if word_lattice.num_states() == 0:
                word_lattice = self.create_empty_fst(self.wd_syms, self.wd_syms)
            trailing_chars = self.create_empty_fst(self.ch_syms, self.ch_syms)
            self.prefix_words.append([word_lattice, trailing_chars, 0])
Example #22
    def get_prior(self):
        '''
        Sets up an array of priors.
        In the future, priors will be given by the RSVP EEG vector.

        OUTPUTS:
            an array of tuples, each consisting of a character and its
            corresponding probability.
        '''
        sigma_h = self.create_machine_history()
        print(sigma_h)
        # intersect
        sigma_h.arcsort(sort_type="olabel")
        output_dist = fst.intersect(sigma_h, self.lm)
        print(output_dist)
        # process result
        output_dist = output_dist.rmepsilon()
        #output_dist = fst.rmepsilon(output_dist)
        output_dist = fst.determinize(output_dist)
        output_dist.minimize()
        output_dist = fst.push(output_dist, push_weights=True, to_final=True)

        # worth converting this history to np.array if vector computations
        # will be involved
        #output_dist.arcsort(sort_type="olabel")

        # traverses the shortest path until we get to the second to
        # last state. And the arcs from that state to the final state contain
        # the distribution that we want.
        prev_stateid = curr_stateid = None
        for state in output_dist.states():
            if curr_stateid is not None:
                prev_stateid = curr_stateid
            curr_stateid = state
        priors = []
        for arc in output_dist.arcs(prev_stateid):
            ch = self.lm_syms.find(
                arc.ilabel)  #ilabel and olabel are the same.
            w = float(arc.weight)

            # TODO: for this demo we only need distribution over the characters
            # from 'a' to 'z'
            if len(ch) == 1 and ch in self.legit_ch_dict:
                priors.append((ch, w))

        # assuming the EEG input is an array like [("a", 0.3),("b", 0.2),...]
        # sort the array depending on the probability
        priors = sorted(priors, key=lambda prior: prior[1])
        normalized_dist = self._normalize([prob for _, prob in priors])
        return zip([ch for ch, _ in priors], normalized_dist)
Example #23
def make_termfst(s, paths, stoi):
    fst = wfst.Fst()
    start = fst.add_state()
    fst.set_start(start)
    for path in paths:
        a = start
        for g in path:
            b = fst.add_state()
            fst.add_arc(a, wfst.Arc(stoi[g], stoi[g], wfst.Weight.One(fst.weight_type()), b))
            a = b
        fst.set_final(b, wfst.Weight.One(fst.weight_type()))
    fst = wfst.determinize(fst)
    fst.minimize()
    return fst
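
A quick sketch of calling make_termfst, assuming `import pywrapfst as wfst`; the symbol-to-index map and the grapheme paths are made up for illustration:

import pywrapfst as wfst

stoi = {'<eps>': 0, 'd': 1, 'i': 2, 'e': 3}  # hypothetical symbol map
det_fst = make_termfst('det', [['d', 'i', 'e'], ['d', 'e']], stoi)
print(det_fst.num_states())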
Example #24
def make_rtn(s, dcg, stoi, fstcoll):
    fst = wfst.Fst()
    start = fst.add_state()
    fst.set_start(start)
    for path in dcg[s]:
        #print(path, file=sys.stderr)
        a = start
        for ss in path:
            b = fst.add_state()
            if ss in dcg and ss not in fstcoll:
                print("\t\tnew fst: {}".format(ss.upper()), file=sys.stderr)
                make_rtn(ss, dcg, stoi, fstcoll)
            fst.add_arc(a, wfst.Arc(stoi[ss], stoi[ss], wfst.Weight.One(fst.weight_type()), b))
            a = b
        fst.set_final(b, wfst.Weight.One(fst.weight_type()))
    fst = wfst.determinize(fst)
    fst.minimize()
    fstcoll[s] = fst
    return fstcoll
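
The recursion in make_rtn can be exercised with a tiny hand-made DCG; the grammar, symbol map, and alias `import pywrapfst as wfst` below are assumptions for illustration only:

import pywrapfst as wfst

dcg = {'np': [['det', 'n']], 'det': [['the'], ['a']]}  # hypothetical nonterminal expansions
stoi = {'<eps>': 0, 'np': 1, 'det': 2, 'n': 3, 'the': 4, 'a': 5}
fstcoll = make_rtn('np', dcg, stoi, {})
# fstcoll now holds one determinized, minimized acceptor per nonterminal: 'np' and 'det'.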
Example #25
def main(argv):
    training_input_dir = os.path.join(FLAGS.work_dir, 'training_inputs')
    os.makedirs(training_input_dir, exist_ok=True)
    words_txt = os.path.join(FLAGS.lang_dir, 'words.txt')
    if FLAGS.stage <= 0:
        word_table = openfst.SymbolTable.read_text(words_txt)
        alphabet_table = spelling_fst.create_alphabet_symbol_table(word_table)
        S = spelling_fst.create_spelling_fst(word_table, alphabet_table,
                                             FLAGS.repeat_letter)
        G = openfst.Fst.read(os.path.join(FLAGS.lang_dir, 'G.fst'))
        SG = openfst.determinize(openfst.compose(S, G))
        SG.minimize()

        S.write(os.path.join(training_input_dir, 'S.fst'))
        SG.write(os.path.join(training_input_dir, 'SG.fst'))

    with open(os.path.join(FLAGS.lang_dir, 'oov.txt')) as oov_fh:
        oov = oov_fh.read().strip()
    if FLAGS.stage <= 1:
        results = subprocess.run([
            "utils/sym2int.pl", "--map-oov", oov, "-f", "2-", "<", words_txt,
            os.path.join(FLAGS.train_data_dir, 'text')
        ],
                                 stdout=subprocess.PIPE,
                                 check=True,
                                 universal_newlines=True)
        openfst.FarWriter()
        for line in results.stdout:
            key, value = line.split(None, maxsplit=1)
            value = [int(index) for index in value.split()]
            value

    neural_net_dir = os.path.join(FLAGS.work_dir, 'nnet')
    os.makedirs(neural_net_dir, exist_ok=True)
    if FLAGS.stage <= 2:
        input_dataset = kaldi_table_dataset.KaldiFloat32MatrixDataset(
            "scp:" + os.path.join(FLAGS.train_data_dir, 'feats.scp'))
        label_dataset = kaldi_table_dataset.KaldiInt32VectorDataset(" ".join([
            "ark:utils/sym2int.pl", "--map-oov", oov, "-f", "2-", "<",
            words_txt, os.path.join(FLAGS.train_data_dir, 'text'), "|"]))
        dataset = (tf.data.Dataset.zip(
            (input_dataset,
             label_dataset)).batch(FLAGS.batch_size).repeat(FLAGS.num_repeats))
Example #26
def build_ctc_trigram_decoding_fst_v2(S,
                                      trigrams,
                                      arc_type='log',
                                      use_context_blanks=False,
                                      prevent_epsilons=True,
                                      determinize=False,
                                      add_syms=False):
    """
    Args:
    """
    CTC = fst.Fst(arc_type=arc_type)
    weight_one = fst.Weight.One(CTC.weight_type())

    # Need a hashable type
    trigrams = sorted([tuple(tg) for tg in trigrams])

    # Translate from tuples of letter indices to indices
    # NOTE: Not all trigrams are present, states are enumerated without gaps.
    # in_ind = lambda s: sum([S**(2-i) * c for i,c in enumerate(s)]) + 1
    def in_ind(s):
        if not use_context_blanks and s[1] == 0:
            return tg2state[(
                0,
                0,
                0,
            )] + 1
        return tg2state[s] + 1

    def out_ind(s):
        return s

    def state_ind(s):
        return tg2state[s]

    # Build
    tg2state = {}
    # Add a final looping BBB state
    # for i, tg in enumerate(itertools.chain(trigrams, [(0, 0, 0)])):
    for i, tg in enumerate(trigrams):
        s1 = CTC.add_state()
        if tg[-1] == 0:
            CTC.set_final(s1)
        assert s1 == i
        tg2state[tg] = i

    assert trigrams[0] == (0, 0, 0)  # blank, blank, blank
    CTC.set_start(0)

    if 0:
        # Add a special state to handle empty labels
        s_final = CTC.add_state()
        CTC.set_final(s_final)
        CTC.add_arc(0, fst.Arc(in_ind((0, 0, 0)), 0, weight_one, s_final))
        CTC.add_arc(s_final, fst.Arc(in_ind((0, 0, 0)), 0, weight_one,
                                     s_final))


#     # Add the self-loop in the extra final state
#     CTC.add_arc(tg2state[(0, 0, 0)], fst.Arc(in_ind(s1), 0, weight_one, tg2state[(0, 0, 0)]))

    for i1, s1 in enumerate(trigrams):
        # Handle the self loop. Please note, that it correctly handles the start and final
        # (0, 0, 0) states
        if s1 != (0, 0, 0):
            CTC.add_arc(i1, fst.Arc(in_ind(s1), 0, weight_one, i1))

        if s1[1] == 0:
            base_low = (s1[0], s1[2], 0)
            base_high = (s1[0], s1[2], np.inf)
        else:
            base_low = (s1[1], s1[2], 0)
            base_high = (s1[1], s1[2], np.inf)

        base_index = bisect.bisect_left(trigrams, base_low)
        # assert trigrams[base_index] == base_low
        if not trigrams[base_index] == base_low:
            global WARNINGS
            if base_low not in WARNINGS:
                print("missing trigram ", base_low)
                WARNINGS.add(base_low)
        high_index = bisect.bisect_left(trigrams, base_high)
        for s2 in set(
                itertools.chain(trigrams[base_index:high_index],
                                [(s1[1], 0, s1[2])])):
            # self loop is already handled
            if s1 == s2:
                continue
            if s2 == (0, 0, 0):
                continue
            if s1 != (0, 0, 0):
                # once we emit the final blank, we need to terminate
                if s1[-1] == 0 and s2[-1] != 0:
                    continue
                if s1[-2] == (0, 0) and s2[-1] != 0:
                    continue
                # we can't emit a starting blank, unless we start
                if s2[0] == 0 and s1[0] != 0:
                    continue
            if s1 != (0, 0, 0) or prevent_epsilons:
                in_label = in_ind(s2)
            else:
                in_label = 0
            out_label = out_ind(s2[1])
            CTC.add_arc(
                i1, fst.Arc(in_label, out_label, weight_one, state_ind(s2)))
    print("Dec g is det?",
          CTC.properties(fst.I_DETERMINISTIC, fst.I_DETERMINISTIC) > 1)
    print("Determinizing the decoding graph")
    CTC = CTC.rmepsilon()
    if determinize:
        CTC = fst.determinize(CTC)
    print("Decoding graph nas %d states and max %d out-degree" %
          (CTC.num_states(),
           max([CTC.num_arcs(s) for s in range(CTC.num_states())])))

    CTC.arcsort('olabel')
    if add_syms:
        in_syms = fst.SymbolTable()
        in_syms.add_symbol('<eps>', 0)
        max_sym = 0
        for tg, key in tg2state.items():
            sym = ''.join(
                ['B' if c == 0 else chr(ord('a') + c - 1) for c in tg])
            max_sym = max(max_sym, max(tg))
            in_syms.add_symbol(sym, key + 1)

        out_syms = fst.SymbolTable()
        out_syms.add_symbol('<eps>', 0)
        for s in range(1, max_sym + 1):
            out_syms.add_symbol(chr(ord('a') + s - 1), s)
        CTC.set_input_symbols(in_syms)
        CTC.set_output_symbols(out_syms)

    return CTC
Example #27
    def __init__(self, dcg, descr):
        #print("Morphparse_DCG.__init__()", file=sys.stderr)
        #print("dcg[nonterminals]: {}".format(pprint.pformat(dcg["nonterminals"])), file=sys.stderr)
        ###Make symbol tables
        othersyms = set()
        for pos in descr["renamesyms"]:
            othersyms.update([e[1] for e in descr["renamesyms"][pos]])
        self.bounds = descr["bounds"]
        self.itos, self.stoi = make_symmaps(dcg, descr["graphs"], othersyms)

        # #DEBUG DUMP SYMTABLES
        # with codecs.open("tmp/stoi.pickle", "w", encoding="utf-8") as outfh:
        #     pickle.dump(self.stoi, outfh)
        # with codecs.open("tmp/itos.pickle", "w", encoding="utf-8") as outfh:
        #     pickle.dump(self.itos, outfh)

        termfsts = make_termfsts(dcg, descr["graphs"], self.stoi)
        # #DEBUG DUMP FST
        # for k in termfsts:
        #     print("DEBUG dumping:", k, file=sys.stderr)
        #     save_dot(termfsts[k], self.stoi, "tmp/termfst_"+k+".dot")
        #     termfsts[k].write("tmp/termfst_"+k+".fst")
            
        self.fsts = {}
        ###Expand/make non-terminal FSTs for each POS category
        for pos in descr["pos"]:
            print("Making/expanding non-terminal fst for POS:", pos, file=sys.stderr)
            fstcoll = make_rtn(pos, dcg["nonterminals"], self.stoi, {})
            # print("__init__(): fstcoll: {}".format(fstcoll.keys()), file=sys.stderr)
            # for sym in fstcoll:
            #     #DEBUG DUMP FST
            #     save_dot(fstcoll[sym], self.stoi, "tmp/"+pos+"_orig_"+sym+".dot")
            #     fstcoll[sym].write("tmp/"+pos+"_orig_"+sym+".fst")

            #replace non-terminals
            replace_pairs = [(self.stoi[pos], fstcoll.pop(pos))]
            for k, v in fstcoll.iteritems():
                replace_pairs.append((self.stoi[k], v))
            fst = wfst.replace(replace_pairs, call_arc_labeling="both")
            fst.rmepsilon()
            fst = wfst.determinize(fst)
            fst.minimize()
            # #DEBUG DUMP FST
            # save_dot(fst, self.stoi, "tmp/"+pos+"_expanded.dot")
            # fst.write("tmp/"+pos+"_expanded.fst")
            # if True: #DEBUGGING
            #     fst2 = fst.copy()
            #     #rename symbols (simplify) 
            #     if pos in descr["renamesyms"] and descr["renamesyms"][pos]:
            #         labpairs = map(lambda x: (self.stoi[x[0]], self.stoi[x[1]]), descr["renamesyms"][pos])
            #         fst2.relabel_pairs(opairs=labpairs, ipairs=labpairs)
            #     fst2.rmepsilon()
            #     fst2 = wfst.determinize(fst2)
            #     fst2.minimize()            
            #     #DEBUG DUMP FST
            #     save_dot(fst2, self.stoi, "tmp/"+pos+"_expandedsimple.dot")
            #     fst2.write("tmp/"+pos+"_expandedsimple.fst")            

            #replace terminals
            replace_pairs = [(self.stoi[pos], fst)]
            for k, v in termfsts.iteritems():
                replace_pairs.append((self.stoi[k], v))
            fst = wfst.replace(replace_pairs, call_arc_labeling="both")
            fst.rmepsilon()
            fst = wfst.determinize(fst)
            fst.minimize()
            # #DEBUG DUMP FST
            # save_dot(fst, self.stoi, "tmp/"+pos+"_expanded2.dot")
            # fst.write("tmp/"+pos+"_expanded2.fst")

            #rename symbols (simplify) JUST FOR DEBUGGING
            if pos in descr["renamesyms"] and descr["renamesyms"][pos]:
                labpairs = map(lambda x: (self.stoi[x[0]], self.stoi[x[1]]), descr["renamesyms"][pos])
                fst.relabel_pairs(opairs=labpairs, ipairs=labpairs)
            fst.rmepsilon()
            fst = wfst.determinize(fst)
            fst.minimize()            
            # #DEBUG DUMP FST
            # save_dot(fst, self.stoi, "tmp/"+pos+"_prefinal.dot")
            # fst.write("tmp/"+pos+"_prefinal.fst")

            #Convert into transducer:
            #split I/O symbols by convention here: input symbols are single characters:
            #Input syms (relabel outputs to EPS):
            syms = [k for k in self.stoi if len(k) == 1]
            labpairs = map(lambda x: (self.stoi[x], self.stoi[EPS]), syms)
            fst.relabel_pairs(opairs=labpairs)
            #Output syms (relabel inputs to EPS):
            syms = [k for k in self.stoi if len(k) != 1]
            labpairs = map(lambda x: (self.stoi[x], self.stoi[EPS]), syms)
            fst.relabel_pairs(ipairs=labpairs)
            # #DEBUG DUMP FST
            # save_dot(fst, self.stoi, "tmp/"+pos+"_final.dot")
            # fst.write("tmp/"+pos+"_final.fst")
            self.fsts[pos] = fst
Example #28
                for arc in oov.arcs(state):
                    arc_string = str(state) + " " + str(
                        arc.nextstate) + " " + str(arc.ilabel) + " " + str(
                            arc.olabel) + " " + str(arc.weight.to_string())
                    if arc_string not in traversed_arcs:
                        traversed_arcs.append(arc_string)
                        found_oov.add_arc(
                            state,
                            fst.Arc(arc.ilabel, arc.ilabel, arc.weight,
                                    arc.nextstate))
                        if arc.ilabel == endlabel:

                            found_oov.set_final(
                                arc.nextstate,
                                fst.Weight.One(found_oov.weight_type()))
                            found_oov = fst.determinize(found_oov).minimize()
                            found_oov.verify()

                            if found_oov.num_states() > 3:
                                oovs.append(found_oov)
                                print("NUM OOV CANDIDATES: " + str(len(oovs)))
                            oov.delete_states(states=[arc.nextstate])
                            found_oov = fst.Fst(
                                arc_type=b'TripleTropicalWeight')
                            for i in range(0, oov.num_states()):
                                i = found_oov.add_state()
                            found_oov.set_start(0)
                            fst_states = []
                            fst_states.insert(0, 0)
                            traversed_arcs = []
                            break
Example #29
def main(out_dir=None, data_dir=None, annotation_dir=None):
    out_dir = os.path.expanduser(out_dir)
    data_dir = os.path.expanduser(data_dir)
    annotation_dir = os.path.expanduser(annotation_dir)

    annotation_dir = os.path.join(annotation_dir, 'action_annotations')

    vocab_fn = os.path.join(data_dir, 'ANU_ikea_dataset', 'indexing_files',
                            'atomic_action_list.txt')
    with open(vocab_fn, 'rt') as file_:
        action_vocab = file_.read().split('\n')
    part_names = (label.split('pick up ')[1] for label in action_vocab
                  if label.startswith('pick up'))
    new_action_vocab = tuple(f"{part}" for part in part_names)

    logger = utils.setupRootLogger(filename=os.path.join(out_dir, 'log.txt'))

    fig_dir = os.path.join(out_dir, 'figures')
    if not os.path.exists(fig_dir):
        os.makedirs(fig_dir)

    out_labels_dir = os.path.join(out_dir, 'labels')
    if not os.path.exists(out_labels_dir):
        os.makedirs(out_labels_dir)

    # gt_action = np.load(os.path.join(annotation_dir, 'gt_action.npy'), allow_pickle=True)
    with open(os.path.join(annotation_dir, 'gt_segments.json'), 'r') as _file:
        gt_segments = json.load(_file)

    ann_seqs = {
        seq_name: [ann for ann in ann_seq['annotation']]
        for seq_name, ann_seq in gt_segments['database'].items()
    }

    kinem_vocab = [lib_asm.Assembly()]

    all_label_index_seqs = collections.defaultdict(list)
    for seq_name, ann_seq in ann_seqs.items():
        logger.info(f"Processing sequence {seq_name}...")
        furn_name, other_name = seq_name.split('/')
        goal_state = make_goal_state(furn_name)

        label_seq = tuple(ann['label'] for ann in ann_seq)
        segment_seq = tuple(ann['segment'] for ann in ann_seq)
        start_seq, end_seq = tuple(zip(*segment_seq))
        df = pd.DataFrame({
            'start': start_seq,
            'end': end_seq,
            'label': label_seq
        })
        df = df.loc[df['label'] != 'NA']

        if not df.any().any():
            warn_str = f"No labels: {furn_name}, {other_name}"
            logger.warning(warn_str)
            continue

        out_path = os.path.join(out_labels_dir, furn_name)
        if not os.path.exists(out_path):
            os.makedirs(out_path)

        try:
            new_df = convert_labels(df)
        except AssertionError as e:
            logger.warning(f"  Skipping video: {e}")
            continue

        new_label_seq = tuple(' '.join(part_name.split()[:-1])
                              for part_name in new_df['arg1'])
        label_index_seq = tuple(
            new_action_vocab.index(label) for label in new_label_seq)
        all_label_index_seqs[furn_name].append(label_index_seq)

        kinem_df = parse_assembly_actions(new_df, kinem_vocab)
        kinem_states = tuple(kinem_vocab[i] for i in kinem_df['state'])
        if not kinem_states[-1] == goal_state:
            warn_str = f"  Final structure != goal structure:\n{kinem_states[-1]}"
            logger.warning(warn_str)

        lib_asm.writeAssemblies(
            os.path.join(out_path, f"{other_name}_kinem-state.txt"),
            kinem_states)

        df.to_csv(os.path.join(out_path, f"{other_name}_human.csv"),
                  index=False)
        new_df.to_csv(os.path.join(out_path, f"{other_name}_kinem-action.csv"),
                      index=False)
        kinem_df.to_csv(os.path.join(out_path,
                                     f"{other_name}_kinem-state.csv"),
                        index=False)

        if not any(label_seq):
            logger.warning(f"No labels: {seq_name}")

    lib_asm.writeAssemblies(os.path.join(out_labels_dir, "kinem-vocab.txt"),
                            kinem_vocab)
    symbol_table = fstutils.makeSymbolTable(new_action_vocab)
    for furn_name, label_index_seqs in all_label_index_seqs.items():
        label_fsts = tuple(
            fstutils.fromSequence(label_index_seq, symbol_table=symbol_table)
            for label_index_seq in label_index_seqs)
        union_fst = libfst.determinize(fstutils.easyUnion(*label_fsts))
        union_fst.minimize()

        # for i, label_fst in enumerate(label_fsts):
        #     fn = os.path.join(fig_dir, f"{furn_name}-{i}")
        #     label_fst.draw(
        #         fn, isymbols=symbol_table, osymbols=symbol_table,
        #         # vertical=True,
        #         portrait=True,
        #         acceptor=True
        #     )
        #     gv.render('dot', 'pdf', fn)
        fn = os.path.join(fig_dir, f"{furn_name}-union")
        union_fst.draw(
            fn,
            isymbols=symbol_table,
            osymbols=symbol_table,
            # vertical=True,
            portrait=True,
            acceptor=True)
        gv.render('dot', 'pdf', fn)
Example #30
    word_begin = c
    print('0 ' + str(c) + ' <eps> <eps>', file=compiler_accept_vocab)
    for char in word:
        print(str(c) + ' ' + str(c + 1) + ' ' + char + ' ' + char,
              file=compiler_accept_vocab)
        c = c + 1
    print(str(c) + ' ' + str(c + 1) + ' </w> </w>', file=compiler_accept_vocab)
    c = c + 1
    print(str(c), file=compiler_accept_vocab)
    word_nodes[word] = (word_begin, c)
    c = c + 1

for word in word_nodes.keys():
    for follower in vocab_words[
            word]:  # add an arc from last state of word to first state of follower
        print(str(word_nodes[word][1]) + ' ' + str(word_nodes[follower][0]) +
              ' <eps> <eps>', file=compiler_accept_vocab)

for c in special_characters:  #+['<sil>']:
    print('0 0 ' + c + ' ' + c, file=compiler)

fst_vocab = compiler_accept_vocab.compile()
# save_autom(fst_vocab.arcsort(), 'vocab')
fst_vocab = fst.determinize(fst_vocab.rmepsilon()).minimize().arcsort()
index_name = '.'.join(index_file.split('/')[-1].split('.')[0:-1])
fst_vocab.write('FSTs/vocab_' + index_name + '.fst')

# erroneous_vocab = fst.compose(error_maker_fst, fst_vocab).project(project_output=False).rmepsilon()
# erroneous_vocab = fst.determinize(erroneous_vocab).minimize()
# erroneous_vocab.write('FSTs/erroneous_vocab_'+index_name+'.fst')