コード例 #1
0
 def decode_beam(self, ex, beam_size=1, max_len=100):
   """Decode ``ex`` with beam search over the output vocabulary.

   Each derivation on the beam is extended with its ``beam_size`` most
   probable next tokens. The ``beam_size`` best extensions overall form the
   next beam. A derivation that emits END_OF_SENTENCE_INDEX is moved to the
   finished list instead of the beam.

   Args:
     ex: example; ``ex.x_inds`` is passed to ``self._encode``.
     beam_size: derivations kept per step (and expansions tried per
       derivation).
     max_len: decoding runs for at most ``max_len - 1`` steps.

   Returns:
     Every finished Derivation, sorted by ``p`` (highest first).
     Derivations still on the beam when the step limit is reached are
     dropped, so the list can be empty.
   """
   import heapq

   h_t = self._encode(ex.x_inds)
   beam = [Derivation(ex, 1, [], hidden_state=h_t,
                      attention_list=[], copy_list=[])]
   finished = []
   for _ in range(1, max_len):
     if not beam:
       # Every hypothesis has finished; further steps cannot add anything.
       break
     new_beam = []
     for deriv in beam:
       write_dist = self._decoder_write(deriv.hidden_state)
       # Same result as sorting all (prob, index) pairs descending and
       # slicing, without sorting the whole vocabulary. Unlike indexing
       # 0..beam_size-1, this cannot fail when beam_size > len(write_dist).
       top_k = heapq.nlargest(
           beam_size, ((p_y_t, y_t) for y_t, p_y_t in enumerate(write_dist)))
       for p_y_t, y_t in top_k:
         new_p = deriv.p * p_y_t
         if y_t == Vocabulary.END_OF_SENTENCE_INDEX:
           finished.append(Derivation(ex, new_p, deriv.y_toks))
           continue
         y_tok = self.out_vocabulary.get_word(y_t)
         new_h_t = self._decoder_step(y_t, deriv.hidden_state)
         new_beam.append(Derivation(ex, new_p, deriv.y_toks + [y_tok],
                                    hidden_state=new_h_t))
     new_beam.sort(key=lambda x: x.p, reverse=True)
     beam = new_beam[:beam_size]
   return sorted(finished, key=lambda x: x.p, reverse=True)
コード例 #2
0
ファイル: attn2hist.py プロジェクト: shijogeorge24/seq2sql
 def decode_beam(self, ex, beam_size=1, max_len=100):
   """Decode ``ex`` with beam search, allowing tokens to be copied from input.

   Write-distribution indices below ``self.out_vocabulary.size()`` are
   output-vocabulary words. Larger indices select a token of
   ``ex.copy_toks`` (with END_OF_SENTENCE appended), which is mapped back to
   its output-vocabulary index before being fed to the decoder. Each beam
   derivation is extended with its ``beam_size`` most probable indices, and
   the ``beam_size`` best extensions form the next beam. A derivation that
   emits END_OF_SENTENCE_INDEX moves to the finished list.

   The search stops when any of these happens:
   - the beam is empty;
   - ``max_len - 1`` steps have run;
   - the ``beam_size``-th best finished derivation beats every live one.

   Returns:
     All finished Derivations, sorted by ``p`` (highest first). Each one
     carries its per-step attention vectors and 0/1 copy flags.
   """
   import heapq

   h_t, annotations = self._encode(ex.x_inds)
   beam = [Derivation(ex, 1, [], hidden_state=h_t,
                      attention_list=[], copy_list=[])]
   finished = []
   vocab_size = self.out_vocabulary.size()
   for _ in range(1, max_len):
     if not beam:
       break
     # ``finished`` is re-sorted (best first) at the end of every step. If
     # write_dist entries are probabilities <= 1, p can only shrink as a
     # derivation grows. So once the beam_size-th finished derivation beats
     # the best live one, no live derivation can still reach the top beam_size.
     if len(finished) >= beam_size:
       if beam[0].p < finished[beam_size - 1].p:
         break
     new_beam = []
     for deriv in beam:
       write_dist, c_t, alpha = self._loc_decoder_write(
           annotations, deriv.hidden_state)
       # Top beam_size (prob, index) pairs. Unlike indexing a fully sorted
       # list, this cannot fail when beam_size > len(write_dist).
       top_k = heapq.nlargest(
           beam_size, ((p_y_t, y_t) for y_t, p_y_t in enumerate(write_dist)))
       for p_y_t, y_t in top_k:
         new_p = deriv.p * p_y_t
         if y_t == Vocabulary.END_OF_SENTENCE_INDEX:
           finished.append(Derivation(ex, new_p, deriv.y_toks,
                                      attention_list=deriv.attention_list + [alpha],
                                      copy_list=deriv.copy_list + [0]))
           continue
         if y_t < vocab_size:
           y_tok = self.out_vocabulary.get_word(y_t)
           do_copy = 0
         else:
           # NOTE(review): copying the appended END_OF_SENTENCE slot emits it
           # as an ordinary token and keeps decoding instead of finishing.
           # Confirm this is intended.
           augmented_copy_toks = ex.copy_toks + [Vocabulary.END_OF_SENTENCE]
           y_tok = augmented_copy_toks[y_t - vocab_size]
           y_t = self.out_vocabulary.get_index(y_tok)
           do_copy = 1
         new_h_t = self._decoder_step(y_t, c_t, deriv.hidden_state)
         new_beam.append(Derivation(ex, new_p, deriv.y_toks + [y_tok],
                                    hidden_state=new_h_t,
                                    attention_list=deriv.attention_list + [alpha],
                                    copy_list=deriv.copy_list + [do_copy]))
     new_beam.sort(key=lambda x: x.p, reverse=True)
     beam = new_beam[:beam_size]
     # Kept sorted every step because the early-stop check above indexes it.
     finished.sort(key=lambda x: x.p, reverse=True)
   return sorted(finished, key=lambda x: x.p, reverse=True)
コード例 #3
0
ファイル: attention.py プロジェクト: dadashkarimi/mtl4log
 def decode_greedy(self, ex, max_len=100):
     """Greedily decode ``ex``, taking the argmax index at every step.

     Indices at or beyond ``self.out_vocabulary.size()`` select a token from
     ``ex.copy_toks`` (with END_OF_SENTENCE appended). That token is mapped
     back to its output-vocabulary index before being fed to the decoder,
     together with ``ex.d_inds[0]``.

     Returns:
       A one-element list holding a Derivation with the product of the chosen
       probabilities and the emitted tokens (END_OF_SENTENCE excluded).
     """
     h_t, annotations = self._encode(ex.x_inds)
     tokens = []
     step_probs = []  # per-step probabilities, kept for error analysis
     prob = 1
     for _ in range(max_len):
         dist, c_t, _alpha = self._decoder_write(annotations, h_t)
         best = numpy.argmax(dist)
         best_p = dist[best]
         step_probs.append(best_p)
         prob *= best_p
         if best == Vocabulary.END_OF_SENTENCE_INDEX:
             break
         vocab_size = self.out_vocabulary.size()
         if best >= vocab_size:
             copyable = ex.copy_toks + [Vocabulary.END_OF_SENTENCE]
             token = copyable[best - vocab_size]
             best = self.out_vocabulary.get_index(token)
         else:
             token = self.out_vocabulary.get_word(best)
         tokens.append(token)
         h_t = self._decoder_step(best, ex.d_inds[0], c_t, h_t)
     return [Derivation(ex, prob, tokens)]
コード例 #4
0
ファイル: attn2hist.py プロジェクト: dadashkarimi/mtl4log
    def decode_by_em(self, ex, max_len=100):
        """Greedily decode ``ex`` by taking the argmax of the write distribution.

        Indices below ``self.out_vocabulary.size()`` are output-vocabulary
        words. Larger ones select a token of ``ex.copy_toks`` (with
        END_OF_SENTENCE appended), which is mapped back to its
        output-vocabulary index before being fed to the decoder. The input
        string ``ex.x_str`` is printed first.

        TODO: the step score appears to be meant to be reweighted by domain
        scores (``self.spec.get_domain_scores_h_t`` / ``get_domain_scores_c_t``).
        That is not implemented, so this is plain greedy decoding.

        Returns:
          A one-element list with a Derivation holding the product of the
          chosen tokens' probabilities and the emitted tokens.
        """
        h_t, annotations = self._encode(ex.x_inds)
        y_tok_seq = []
        p_y_seq = []  # Should be handy for error analysis
        p = 1
        print(ex.x_str)
        for _ in range(max_len):
            write_dist, c_t, _alpha = self._decoder_write(annotations, h_t)
            y_t = numpy.argmax(write_dist)
            p_y_t = write_dist[y_t]
            p_y_seq.append(p_y_t)
            p *= p_y_t
            if y_t == Vocabulary.END_OF_SENTENCE_INDEX:
                break
            vocab_size = self.out_vocabulary.size()
            if y_t < vocab_size:
                y_tok = self.out_vocabulary.get_word(y_t)
            else:
                augmented_copy_toks = ex.copy_toks + [
                    Vocabulary.END_OF_SENTENCE
                ]
                y_tok = augmented_copy_toks[y_t - vocab_size]
                y_t = self.out_vocabulary.get_index(y_tok)
            y_tok_seq.append(y_tok)
            h_t = self._decoder_step(y_t, c_t, h_t)
        return [Derivation(ex, p, y_tok_seq)]
コード例 #5
0
    def decode_greedy(self, ex, max_len=100, pair_stat_weight=0.0):
        """Greedily decode ``ex``, optionally biasing scores with pair statistics.

        Indices below ``self.out_vocabulary.size()`` are output-vocabulary
        words. Larger ones select a token of ``ex.copy_toks`` (with
        END_OF_SENTENCE appended), which is mapped back to its
        output-vocabulary index before being fed to the decoder.

        Args:
          ex: example; uses ``x_inds`` and ``copy_toks``.
          max_len: maximum number of decoding steps.
          pair_stat_weight: interpolation weight ``w``. When non-zero, each
            copyable source token is mapped to its input-vocabulary index
            ``x_t``. If ``x_t`` is in ``self.spec.pair_stat``, every listed
            ``(y_, p_xy)`` updates
            ``write_dist[y_] = (1 - w) * write_dist[y_] + w * p_xy``
            before the argmax. The default 0.0 leaves the distribution
            untouched, so ``pair_stat`` and ``in_vocabulary`` are not consulted.

        Returns:
          A one-element list with a Derivation holding the product of the
          chosen tokens' probabilities and the emitted tokens.
        """
        h_t, annotations = self._encode(ex.x_inds)
        y_tok_seq = []
        p_y_seq = []  # Should be handy for error analysis
        p = 1
        vocab_size = self.out_vocabulary.size()
        for _ in range(max_len):
            write_dist, c_t, _alpha = self._decoder_write(annotations, h_t)
            if pair_stat_weight:
                keep_weight = 1.0 - pair_stat_weight
                augmented_copy_toks = ex.copy_toks + [
                    Vocabulary.END_OF_SENTENCE
                ]
                # Only copy indices carry a source token to look up.
                for j in range(vocab_size, len(write_dist)):
                    x_t = self.in_vocabulary.get_index(
                        augmented_copy_toks[j - vocab_size])
                    if x_t in self.spec.pair_stat:
                        for y_, p_xy in self.spec.pair_stat[x_t]:
                            write_dist[y_] = (keep_weight * write_dist[y_]
                                              + pair_stat_weight * p_xy)

            y_t = numpy.argmax(write_dist)
            p_y_t = write_dist[y_t]
            p_y_seq.append(p_y_t)
            p *= p_y_t
            if y_t == Vocabulary.END_OF_SENTENCE_INDEX:
                break
            if y_t < vocab_size:
                y_tok = self.out_vocabulary.get_word(y_t)
            else:
                augmented_copy_toks = ex.copy_toks + [
                    Vocabulary.END_OF_SENTENCE
                ]
                y_tok = augmented_copy_toks[y_t - vocab_size]
                y_t = self.out_vocabulary.get_index(y_tok)

            y_tok_seq.append(y_tok)
            h_t = self._decoder_step(y_t, c_t, h_t)
        return [Derivation(ex, p, y_tok_seq)]
コード例 #6
0
 def decode_greedy(self, ex, max_len=100):
   """Greedily decode ``ex`` by taking the most probable token each step.

   Returns:
     A one-element list with a Derivation holding the product of the chosen
     tokens' probabilities and the emitted tokens (END_OF_SENTENCE is not
     included).
   """
   state = self._encode(ex.x_inds)
   emitted = []
   step_probs = []  # per-step probabilities, kept for error analysis
   total_p = 1
   for _ in range(max_len):
     dist = self._decoder_write(state)
     idx = numpy.argmax(dist)
     idx_p = dist[idx]
     step_probs.append(idx_p)
     total_p *= idx_p
     if idx == Vocabulary.END_OF_SENTENCE_INDEX:
       break
     emitted.append(self.out_vocabulary.get_word(idx))
     state = self._decoder_step(idx, state)
   return [Derivation(ex, total_p, emitted)]
コード例 #7
0
def getBoundDerivation(cfg, string, positionFunction):
    """Breadth-first search for a derivation of ``string`` from ``cfg[3]``.

    ``positionFunction(current, cfg[0])`` chooses the position in ``current``
    to expand, or returns -1 when nothing can be expanded. ``cfg[0]`` is
    presumably the non-terminal set. Each alternative in
    ``cfg[2][current[pos]]`` replaces that symbol to produce a successor.

    Returns:
      The first Derivation (after ``clean()``) whose sentential form equals
      ``string``, or None when the frontier is exhausted. If infinitely many
      sentential forms are reachable and ``string`` is not among them, the
      search does not terminate.
    """
    import collections

    # FIFO frontier: same visiting order as queue.Queue, without its locking.
    frontier = collections.deque([(cfg[3], Derivation(entries=[(cfg[3], 0)]))])
    while frontier:
        current, derivation = frontier.popleft()
        if current == string:
            derivation.add(current, -1)
            derivation.clean()
            return derivation
        pos = positionFunction(current, cfg[0])
        if pos == -1:
            # Fully expanded but not the target: dead end.
            continue
        for v in cfg[2][current[pos]]:
            newDerivation = copy.deepcopy(derivation)
            newDerivation.add(current, pos)
            frontier.append((replaceAtIndexWithString(current, pos, v),
                             newDerivation))
    return None
コード例 #8
0
def getAllDerivationsRecursive(cfg, string, currentString, maxSubs, maxDevs,
                               derivations, subs, currentDevs):
    """Depth-first enumeration of derivations of ``string`` from ``currentString``.

    Every symbol of ``currentString`` contained in ``cfg[0]`` is replaced, one
    position at a time, by each alternative in ``cfg[2][symbol]``, and the
    search recurses on the result. When ``currentString == string``, a
    Derivation is recorded. Its entries are the path steps in ``currentDevs``
    plus a final ``(string, -1)`` entry.

    Args:
      cfg: grammar tuple; only ``cfg[0]`` and ``cfg[2]`` are used here.
      string: target string.
      currentString: sentential form being expanded.
      maxSubs: the branch is cut once ``subs == maxSubs``. This check runs
        before the match test, so a derivation needing exactly ``maxSubs``
        substitutions is not recorded.
      maxDevs: stop once ``len(derivations) == maxDevs``.
      derivations: accumulator list, mutated in place.
      subs: substitutions applied so far.
      currentDevs: ``(sentential_form, position)`` path stack. It is restored
        to its original contents on return.

    Returns:
      ``derivations`` (the same list object that was passed in).
    """
    if len(derivations) == maxDevs or maxSubs == subs:
        return derivations
    if string == currentString:
        # Build a fresh list: currentDevs keeps being mutated by callers, so
        # handing it over directly could let later pushes/pops alter
        # derivations that were already recorded.
        derivations.append(
            Derivation(entries=currentDevs + [(currentString, -1)]))
        return derivations
    for i, symbol in enumerate(currentString):
        if symbol not in cfg[0]:
            continue
        for v in cfg[2][symbol]:
            cs = replaceAtIndexWithString(currentString, i, v)
            currentDevs.append((currentString, i))
            getAllDerivationsRecursive(cfg, string, cs, maxSubs, maxDevs,
                                       derivations, subs + 1, currentDevs)
            currentDevs.pop()
            if len(derivations) == maxDevs:
                # Quota reached: every further call would return immediately.
                return derivations
    return derivations
コード例 #9
0
    def top_down_beam_search(self, words, oracle_actions, oracle_tokens, buffer, stack_top, action_top, beam_size=5,
                             max_open_nt=10, max_children=10):
        """Decode a tree top-down with beam search over parser actions.

        The root non-terminal comes from ``oracle_tokens``; its first element is
        popped, so the caller's list is mutated. Everything after the root is
        predicted. At each step, every derivation on the beam proposes up to
        ``beam_size`` valid actions, and NT / TER actions are further expanded
        with their ``beam_size`` best tokens. The ``beam_size`` best candidates
        overall are then executed on their stacks. A candidate whose stack
        shrinks to a single element is finished.

        Args:
          words: input words; their count bounds the stack depth and the number
            of non-terminals opened.
          oracle_actions: unused.
          oracle_tokens: list whose first element names the root non-terminal.
          buffer: input representation passed to ``self.attention`` and
            ``attention_output``.
          stack_top: initial stack RNN state.
          action_top: action-history RNN state.
          beam_size: beam width. The search stops once this many derivations
            have finished, or when no unfinished candidate remains.
          max_open_nt: NT actions are disabled once this many open ('p') items
            are on the stack.
          max_children: TER actions are disabled once this many completed ('c')
            items sit above the innermost open non-terminal.

        Returns:
          ``(output_actions, output_tokens)`` of the highest-logp finished
          derivation.

        Raises:
          RuntimeError: if the beam empties before any derivation finishes.
        """
        finished = []
        stack = []
        action = self._ROOT
        nt = self.nt_vocab[oracle_tokens.pop(0)]
        nt_embedding = self.nt_input_layer(self.nt_lookup[nt])

        stack_state = stack_top.add_input(nt_embedding)
        stack.append((stack_state, 'p', nt_embedding))

        action_embedding = self.act_input_layer(self.act_lookup[action])
        action_top = action_top.add_input(action_embedding)
        # first, set up the initial beam
        beam = [Derivation(stack, action_top, [self.act_vocab.token(action)], [self.nt_vocab.token(nt)], 0,
                           nt_allowed=1, ter_allowed=0, reducable=0, total_nt=1)]

        # loop until enough derivations have finished
        while len(finished) < beam_size:
            new_beam = []
            unfinished = []

            # collect all possible expansions of the current beam
            for der in beam:
                valid_actions = []
                if len(der.stack) >= 1:
                    if der.ter_allowed == 1:
                        valid_actions += [self._TER]
                    if der.nt_allowed == 1:
                        valid_actions += [self._NT] + self._NT_general
                if len(der.stack) >= 2 and der.reducable != 0:
                    valid_actions += [self._ACT]

                stack_embedding = der.stack[-1][0].output()
                action_summary = der.action_top.output()
                word_weights = self.attention(stack_embedding, buffer)
                buffer_embedding, _ = attention_output(buffer, word_weights, 'soft_average')

                # embedding of the innermost open non-terminal
                for entry in reversed(der.stack):
                    if entry[1] == 'p':
                        parent_embedding = entry[2]
                        break
                parser_state = dy.concatenate([buffer_embedding, stack_embedding, parent_embedding, action_summary])
                h = self.mlp_layer(parser_state)
                log_probs = dy.log_softmax(self.act_proj_layer(h), valid_actions)

                # Keep the best (at most beam_size) actions scoring above -999.
                # This assumes the restricted log_softmax gives excluded actions
                # very negative scores.
                sorted_actions_logprobs = sorted(enumerate(log_probs.vec_value()), key=itemgetter(1), reverse=True)
                top_actions = [x for x in sorted_actions_logprobs if x[1] > -999][:beam_size]

                for action, logprob in top_actions:
                    if action == self._NT or action == self._TER:
                        output_feature, _ = attention_output(buffer, word_weights, self.test_selection, argmax=True)
                        if action == self._NT:
                            log_probs_tok = dy.log_softmax(self.nt_proj_layer(output_feature))
                        else:
                            log_probs_tok = dy.log_softmax(self.ter_proj_layer(output_feature))
                        sorted_tok_logprobs = sorted(enumerate(log_probs_tok.vec_value()), key=itemgetter(1),
                                                     reverse=True)
                        # Slicing, rather than indexing 0..beam_size-1, tolerates
                        # a token vocabulary smaller than beam_size.
                        for tok, tok_logprob in sorted_tok_logprobs[:beam_size]:
                            new_beam.append(Derivation(der.stack, der.action_top, der.output_actions, der.output_tokens,
                                                       der.logp + logprob + tok_logprob, 1, 1, 1, der.total_nt,
                                                       action, tok))
                    else:
                        new_beam.append(Derivation(der.stack, der.action_top, der.output_actions, der.output_tokens,
                                                   der.logp + logprob, 1, 1, 1, der.total_nt, next_act=action))

            # sort the expansions and keep only the top k
            new_beam.sort(key=lambda x: x.logp, reverse=True)
            new_beam = new_beam[:beam_size]

            # execute the pending action of each surviving candidate
            for cand in new_beam:
                # Siblings expanded from one parent were built from the same
                # stack / output lists. Derivation's constructor is not visible
                # here and may not copy them, so copy before mutating to keep one
                # candidate's pushes and pops from leaking into another's.
                cand.stack = list(cand.stack)
                cand.output_actions = list(cand.output_actions)
                cand.output_tokens = list(cand.output_tokens)

                action = cand.next_act
                if action == self._NT:
                    nt = cand.next_tok
                    nt_embedding = self.nt_input_layer(self.nt_lookup[nt])

                    stack_state = cand.stack[-1][0].add_input(nt_embedding)
                    cand.stack.append((stack_state, 'p', nt_embedding))

                    cand.output_actions.append(self.act_vocab.token(action))
                    cand.output_tokens.append(self.nt_vocab.token(nt))
                    cand.total_nt += 1

                elif action == self._TER:
                    ter = cand.next_tok
                    ter_embedding = self.ter_input_layer(self.ter_lookup[ter])

                    stack_state = cand.stack[-1][0].add_input(ter_embedding)
                    cand.stack.append((stack_state, 'c', ter_embedding))

                    cand.output_actions.append(self.act_vocab.token(action))
                    cand.output_tokens.append(self.ter_vocab.token(ter))

                elif action in self._NT_general:
                    # Action tokens are presumably of the form 'NT(label)'.
                    # str.lstrip('NT(') strips any leading 'N', 'T' or '('
                    # characters, which would mangle labels such as 'NT(Name)',
                    # so remove the exact prefix instead.
                    label = self.act_vocab.token(action)
                    if label.startswith('NT('):
                        label = label[len('NT('):]
                    nt = self.nt_vocab[label.rstrip(')')]
                    nt_embedding = self.nt_input_layer(self.nt_lookup[nt])

                    stack_state = cand.stack[-1][0].add_input(nt_embedding)
                    cand.stack.append((stack_state, 'p', nt_embedding))

                    cand.output_actions.append(self.act_vocab.token(action))
                    cand.output_tokens.append(self.nt_vocab.token(nt))
                    cand.total_nt += 1

                else:
                    # reduce: pop up to and including the innermost open NT and
                    # compose the popped children with it into one subtree item
                    path_input = []
                    while True:
                        _top_state, top_label, top_raw_rep = cand.stack.pop()
                        path_input.append(top_raw_rep)
                        if top_label == 'p':
                            break
                    parent_rep = path_input.pop()
                    composed_rep = self.subtree_input_layer(dy.concatenate([dy.average(path_input), parent_rep]))

                    stack_state = cand.stack[-1][0] if cand.stack else stack_top
                    stack_state = stack_state.add_input(composed_rep)
                    cand.stack.append((stack_state, 'c', composed_rep))

                    cand.output_actions.append(self.act_vocab.token(action))

                # a candidate is complete once its stack is down to one element
                if len(cand.stack) == 1:
                    finished.append(cand)
                    continue

                # if not completed, proceed
                action_embedding = self.act_input_layer(self.act_lookup[action])
                cand.action_top = cand.action_top.add_input(action_embedding)

                # reduce cannot follow nt
                reducable = 0 if cand.stack[-1][1] == 'p' else 1

                # nt is disabled if the maximum number of open non-terminals is
                # reached, or the stack / opened non-terminals exceed the input
                nt_allowed = 1
                count_p = sum(1 for item in cand.stack if item[1] == 'p')
                if count_p >= max_open_nt:
                    nt_allowed = 0
                if len(cand.stack) > len(words) or cand.total_nt > len(words):
                    nt_allowed = 0

                # ter is disabled if the maximum children under the open nt is reached
                count_c = 0
                for item in reversed(cand.stack):
                    if item[1] != 'c':
                        break
                    count_c += 1
                ter_allowed = 0 if count_c >= max_children else 1

                cand.nt_allowed = nt_allowed
                cand.ter_allowed = ter_allowed
                cand.reducable = reducable
                cand.next_act = None
                cand.next_tok = None

                unfinished.append(cand)

            beam = unfinished
            if not beam:
                # Nothing left to expand. Without this check the loop would spin
                # forever whenever fewer than beam_size derivations finished.
                break

        if not finished:
            raise RuntimeError('beam search did not finish any derivation')
        # max() returns the first maximal item, like a stable descending sort's [0]
        best = max(finished, key=lambda x: x.logp)
        return best.output_actions, best.output_tokens
コード例 #10
0
ファイル: attn2hist.py プロジェクト: shijogeorge24/seq2sql
  def decode_by_em(self, ex, max_len=100, verbose=True):
    """Greedily decode ``ex`` over output, input-vocabulary and copy indices.

    The write distribution is read as three consecutive index ranges:
    - ``[0, n_out)``: output-vocabulary words;
    - ``[n_out, n_out + n_in)``: input-vocabulary words;
    - from ``n_out + n_in`` on: positions in ``ex.copy_toks`` with
      END_OF_SENTENCE appended.
    Here ``n_out``/``n_in`` are the output/input vocabulary sizes. Tokens from
    the last two ranges are mapped back to output-vocabulary indices before
    being fed to the decoder.

    Args:
      ex: example; uses ``x_inds``, ``x_str`` and ``copy_toks``.
      max_len: maximum number of decoding steps.
      verbose: when true (the default, as before), print ``ex.x_str`` and each
        token chosen from the input-vocabulary or copy range with its score.

    Returns:
      A one-element list with a Derivation holding the product of the chosen
      tokens' probabilities and the emitted tokens.
    """
    h_t, annotations = self._encode(ex.x_inds)
    y_tok_seq = []
    p_y_seq = []  # Should be handy for error analysis
    p = 1
    if verbose:
      print(ex.x_str)
    for _ in range(max_len):
      write_dist, c_t, _alpha = self._decoder_write(annotations, h_t)
      y_t = numpy.argmax(write_dist)
      p_y_t = write_dist[y_t]
      p_y_seq.append(p_y_t)
      p *= p_y_t
      if y_t == Vocabulary.END_OF_SENTENCE_INDEX:
        break
      out_size = self.out_vocabulary.size()
      if y_t < out_size:
        y_tok = self.out_vocabulary.get_word(y_t)
      else:
        in_size = self.in_vocabulary.size()
        if y_t < out_size + in_size:
          y_tok = self.in_vocabulary.get_word(y_t - out_size)
          label = '----> second:'
        else:
          augmented_copy_toks = ex.copy_toks + [Vocabulary.END_OF_SENTENCE]
          y_tok = augmented_copy_toks[y_t - (out_size + in_size)]
          label = '---> third:'
        if verbose:
          print(label)
          print(y_tok, write_dist[y_t])
        y_t = self.out_vocabulary.get_index(y_tok)
      y_tok_seq.append(y_tok)
      h_t = self._decoder_step(y_t, c_t, h_t)
    return [Derivation(ex, p, y_tok_seq)]
コード例 #11
0
    def decode_beam(self,
                    domain,
                    ex,
                    domain_convertor,
                    domain_controller,
                    general_controller,
                    beam_size=1,
                    max_len=100):
        h_t, annotations = self._encode(ex.x_inds)
        beam = [[
            Derivation(ex,
                       1, [], [],
                       hidden_state=h_t,
                       p_list=[],
                       entity_lex_map=ex.entity_lex_map,
                       attention_list=[])
        ]]
        finished = []
        final_finished = []
        action_all = self.out_vocabulary.get_action_list()

        #print('entity_lex in decode_beam = %s' % ex.entity_lex_map)

        for i in range(1, max_len):
            #print >> sys.stderr, 'decode_beam: length = %d' % i
            if len(beam[i - 1]) == 0: break
            # See if beam_size-th finished deriv is best than everything on beam now.
            if len(finished) >= beam_size:
                finished_p = finished[beam_size - 1].p
                cur_best_p = beam[i - 1][0].p
                if cur_best_p < finished_p:
                    break
            new_beam = []

            for deriv in beam[i - 1]:
                cur_p = deriv.p
                h_t = deriv.hidden_state
                y_tok_seq = deriv.y_toks
                p_list = deriv.p_list
                attention_list = deriv.attention_list
                entity_lex_map = deriv.entity_lex_map

                #print('entity_lex in loop for beam = %s' % entity_lex_map)

                gen_pre_action_for_test = deriv.gen_pre_action_in_deriv
                gen_pre_action_class_for_test = deriv.gen_pre_action_class_in_deriv
                gen_pre_arg_list_for_test = copy.deepcopy(
                    deriv.gen_pre_arg_list_in_deriv)

                node_dict_for_test = copy.deepcopy(deriv.node_dict_in_deriv)
                type_node_dict_for_test = copy.deepcopy(
                    deriv.type_node_dict_in_deriv)
                entity_node_dict_for_test = copy.deepcopy(
                    deriv.entity_node_dict_in_deriv)
                operation_dict_for_test = copy.deepcopy(
                    deriv.operation_dict_in_deriv)
                edge_dict_for_test = copy.deepcopy(deriv.edge_dict_in_deriv)
                return_node_for_test = copy.deepcopy(
                    deriv.return_node_in_deriv)
                db_triple_for_test = copy.deepcopy(deriv.db_triple_in_deriv)
                fun_trace_list_for_test = copy.deepcopy(
                    deriv.fun_trace_list_in_deriv)
                #print('***************************************')
                #print('y_tok_seq: %s' % ' '.join(y_tok_seq))
                #print('pre_action_for_test: ', gen_pre_action_for_test)
                #print('pre_action_class_for_test: ', gen_pre_action_class_for_test)
                #print('pre_arg_list_for_test: ', gen_pre_arg_list_for_test)
                #print('type_node_dict_for_test: ', type_node_dict_for_test)
                #print('edge_dict_for_test: ', edge_dict_for_test)
                #print('node_dict_for_test: ', node_dict_for_test)
                #print('entity_node_dict_for_test: ', entity_node_dict_for_test)
                #print('db_trible_for_test: ', db_triple_for_test)
                #print('operation_dict_for_test: ', operation_dict_for_test)
                #print('return_node_for_test: ', return_node_for_test)
                #print('fun_trace_list_for_test: ', fun_trace_list_for_test)

                write_dist, c_t, alpha = self._decoder_write(annotations, h_t)

                legal_dist_gen = self.get_legal_action_list(
                    general_controller,
                    gen_pre_action_class_for_test,
                    gen_pre_arg_list_for_test,
                    gen_pre_action_for_test,
                    node_dict_for_test,
                    type_node_dict_for_test,
                    entity_node_dict_for_test,
                    operation_dict_for_test,
                    edge_dict_for_test,
                    return_node_for_test,
                    db_triple_for_test,
                    fun_trace_list_for_test,
                    action_all,
                    entity_lex_map=entity_lex_map)

                action_all_for_domain = []
                for ii in range(len(legal_dist_gen)):
                    if legal_dist_gen[ii]:
                        action_all_for_domain.append(action_all[ii])
                    else:
                        action_all_for_domain.append('<COPY>')

                legal_dist_dom = self.get_legal_action_list(
                    domain_controller,
                    gen_pre_action_class_for_test,
                    gen_pre_arg_list_for_test,
                    gen_pre_action_for_test,
                    node_dict_for_test,
                    type_node_dict_for_test,
                    entity_node_dict_for_test,
                    operation_dict_for_test,
                    edge_dict_for_test,
                    return_node_for_test,
                    db_triple_for_test,
                    fun_trace_list_for_test,
                    action_all_for_domain,
                    entity_lex_map=entity_lex_map)

                #print('write_dist: (', len(write_dist), ') ', write_dist)
                #print('legal_dist_gen: (', len(legal_dist_gen), ') ', legal_dist_gen)
                #print('legal_dist_dom: (', len(legal_dist_dom), ') ', legal_dist_dom)

                final_dist = write_dist * legal_dist_gen * legal_dist_dom

                #print('final_dist: (', len(final_dist), ') ', final_dist)

                sorted_dist = sorted([(p_y_t, y_t)
                                      for y_t, p_y_t in enumerate(final_dist)],
                                     reverse=True)

                for j in range(beam_size):
                    gen_pre_action_for_read = gen_pre_action_for_test
                    gen_pre_action_class_for_read = gen_pre_action_class_for_test
                    gen_pre_arg_list_for_read = copy.deepcopy(
                        gen_pre_arg_list_for_test)

                    node_dict_for_read = copy.deepcopy(node_dict_for_test)
                    type_node_dict_for_read = copy.deepcopy(
                        type_node_dict_for_test)
                    entity_node_dict_for_read = copy.deepcopy(
                        entity_node_dict_for_test)
                    operation_dict_for_read = copy.deepcopy(
                        operation_dict_for_test)
                    edge_dict_for_read = copy.deepcopy(edge_dict_for_test)
                    return_node_for_read = copy.deepcopy(return_node_for_test)
                    db_triple_for_read = copy.deepcopy(db_triple_for_test)
                    fun_trace_list_for_read = copy.deepcopy(
                        fun_trace_list_for_test)

                    #print('--------------')
                    #print('pre_action_for_read: ', gen_pre_action_for_read)
                    #print('pre_action_class_for_read: ', gen_pre_action_class_for_read)
                    #print('pre_arg_list_for_read: ', gen_pre_arg_list_for_read)
                    #print('type_node_dict_for_read: ', type_node_dict_for_read)
                    #print('edge_dict_for_read: ', edge_dict_for_read)
                    #print('node_dict_for_read: ', node_dict_for_read)
                    #print('entity_node_dict_for_read: ', entity_node_dict_for_read)
                    #print('db_trible_for_read: ', db_triple_for_read)
                    #print('operation_dict_for_read: ', operation_dict_for_read)
                    #print('return_node_for_read: ', return_node_for_read)
                    #print('fun_trace_list_for_read: ', fun_trace_list_for_read)
                    #print('--------------')

                    p_y_t, y_t = sorted_dist[j]
                    if p_y_t == 0.0:
                        continue
                    new_p = cur_p * p_y_t
                    append_flag = False
                    if self.out_vocabulary.action_is_end(domain, y_t):
                        append_flag = True
                    if y_t < self.out_vocabulary.size():
                        y_tok = self.out_vocabulary.get_action(y_t)
                    else:
                        print('error action index!')
                    new_h_t = self._decoder_step(y_t, c_t, h_t)
                    #print('y_tok: ', y_tok, ' p_y_t: ', p_y_t)
                    action_token = y_tok
                    gen_flag, gen_pre_action_class_out, gen_pre_arg_list_out, gen_pre_action_out, fun_trace_list_out = \
                      general_controller.is_legal_action_then_read(gen_pre_action_class_for_read, gen_pre_arg_list_for_read, action_token,
                                                                   gen_pre_action_for_read, node_dict_for_read, type_node_dict_for_read, entity_node_dict_for_read,
                                                                   operation_dict_for_read, edge_dict_for_read, return_node_for_read, db_triple_for_read,
                                                                   fun_trace_list_for_read)
                    if not gen_flag:
                        print('test is right, but read is wrong!')
                        continue
                    if append_flag:
                        finished.append(
                            Derivation(
                                ex,
                                new_p,
                                y_tok_seq + [y_tok], [],
                                p_list=p_list + [p_y_t],
                                entity_lex_map=entity_lex_map,
                                attention_list=attention_list + [alpha],
                                gen_pre_action_in_deriv=gen_pre_action_out,
                                gen_pre_action_class_in_deriv=
                                gen_pre_action_class_out,
                                gen_pre_arg_list_in_deriv=gen_pre_arg_list_out,
                                node_dict_in_deriv=node_dict_for_read,
                                type_node_dict_in_deriv=type_node_dict_for_read,
                                entity_node_dict_in_deriv=
                                entity_node_dict_for_read,
                                operation_dict_in_deriv=operation_dict_for_read,
                                edge_dict_in_deriv=edge_dict_for_read,
                                return_node_in_deriv=return_node_for_read,
                                db_triple_in_deriv=db_triple_for_read,
                                fun_trace_list_in_deriv=fun_trace_list_out))
                        continue
                    new_entry = Derivation(
                        ex,
                        new_p,
                        y_tok_seq + [y_tok], [],
                        hidden_state=new_h_t,
                        p_list=p_list + [p_y_t],
                        entity_lex_map=entity_lex_map,
                        attention_list=attention_list + [alpha],
                        gen_pre_action_in_deriv=gen_pre_action_out,
                        gen_pre_action_class_in_deriv=gen_pre_action_class_out,
                        gen_pre_arg_list_in_deriv=gen_pre_arg_list_out,
                        node_dict_in_deriv=node_dict_for_read,
                        type_node_dict_in_deriv=type_node_dict_for_read,
                        entity_node_dict_in_deriv=entity_node_dict_for_read,
                        operation_dict_in_deriv=operation_dict_for_read,
                        edge_dict_in_deriv=edge_dict_for_read,
                        return_node_in_deriv=return_node_for_read,
                        db_triple_in_deriv=db_triple_for_read,
                        fun_trace_list_in_deriv=fun_trace_list_out)
                    new_beam.append(new_entry)

            new_beam.sort(key=lambda x: x.p, reverse=True)
            beam.append(new_beam[:beam_size])
            finished.sort(key=lambda x: x.p, reverse=True)
        for deriv in finished:
            y_new_toks = []
            entity_lex = deriv.example.entity_lex_map
            for y_tok in deriv.y_toks:
                if y_tok.startswith('add_entity_node:'):
                    entity = y_tok[y_tok.index(':-:') + 3:]
                    if entity in entity_lex:
                        new_entity = entity_lex[entity]
                        y_new_tok = y_tok.replace(entity, new_entity)
                        y_new_toks.append(y_new_tok)
                        continue
                y_new_toks.append(y_tok)
            y_new_str = ' '.join(y_new_toks)

            #print('entity_lex in finished = %s' % entity_lex)
            #print('y_new_str: %s' % y_new_str)
            y_toks_lf = domain_convertor(y_new_str,
                                         domain_controller,
                                         general_controller,
                                         entity_lex_map=entity_lex)
            #print('y_toks_lf: %s' % ' '.join(y_toks_lf))
            new_entry = Derivation(deriv.example, deriv.p, deriv.y_toks,
                                   y_toks_lf, deriv.hidden_state, deriv.p_list,
                                   deriv.entity_lex_map, deriv.attention_list)
            final_finished.append(new_entry)
        return sorted(final_finished, key=lambda x: x.p, reverse=True)
コード例 #12
0
    def decode_greedy(self,
                      domain,
                      ex,
                      domain_convertor,
                      domain_controller,
                      general_controller,
                      max_len=100):
        """Greedily decode an action sequence for ``ex``.

        At every step the decoder's write distribution is multiplied by the
        legality masks that ``get_legal_action_list`` returns for the general
        controller and for the domain controller, and the arg-max action is
        taken. The chosen action is appended (even when it is an end action)
        and replayed through ``general_controller.is_legal_action_then_read``
        to advance the parser state. Decoding stops after an end action or
        after ``max_len`` steps.

        Args:
            domain: domain identifier forwarded to
                ``out_vocabulary.action_is_end`` (as the beam decoder does).
            ex: example to decode; ``ex.x_inds`` is fed to the encoder.
            domain_convertor: callable mapping the space-joined action string
                (plus both controllers) to the returned logical form.
            domain_controller: controller used for the domain legality mask.
            general_controller: controller used for the general legality mask
                and for tracking the parser state.
            max_len: maximum number of decoding steps.

        Returns:
            A one-element list holding a ``Derivation`` built from ``ex``, the
            product of the chosen actions' unmasked write probabilities, the
            action tokens, and the converted logical form.
        """
        h_t, annotations = self._encode(ex.x_inds)
        y_tok_seq = []
        p_y_seq = []  # Should be handy for error analysis
        p = 1
        action_all = self.out_vocabulary.get_action_list()

        # Parser state threaded through general_controller. The dict-valued
        # parts are never rebound here; any updates to them come from
        # is_legal_action_then_read (presumably in place -- confirm).
        gen_pre_action_in = ''
        gen_pre_action_class_in = 'start'
        gen_pre_arg_list_in = []
        node_dict = {}
        type_node_dict = {}
        entity_node_dict = {}
        operation_dict = {}
        edge_dict = {}
        return_node = {}
        db_triple = {}
        fun_trace_list_in = []

        for _ in range(max_len):
            write_dist, c_t, alpha = self._decoder_write(annotations, h_t)
            legal_dist_gen = self.get_legal_action_list(
                general_controller, gen_pre_action_class_in,
                gen_pre_arg_list_in, gen_pre_action_in, node_dict,
                type_node_dict, entity_node_dict, operation_dict, edge_dict,
                return_node, db_triple, fun_trace_list_in, action_all)
            legal_dist_dom = self.get_legal_action_list(
                domain_controller, gen_pre_action_class_in,
                gen_pre_arg_list_in, gen_pre_action_in, node_dict,
                type_node_dict, entity_node_dict, operation_dict, edge_dict,
                return_node, db_triple, fun_trace_list_in, action_all)

            # Greedy choice over the legality-masked distribution.
            final_dist = write_dist * legal_dist_gen * legal_dist_dom
            y_t = numpy.argmax(final_dist)

            # The derivation score accumulates the *unmasked* write
            # probability of the chosen action.
            p_y_t = write_dist[y_t]
            p_y_seq.append(p_y_t)
            p *= p_y_t

            # Pass ``domain`` like the beam decoders do; previously the
            # parameter was accepted but never used.
            is_end = self.out_vocabulary.action_is_end(domain, y_t)
            if y_t >= self.out_vocabulary.all_size():
                # Report before the lookup, in case get_action rejects the
                # index.
                print('error in attention for out vocabulary')
            y_tok = self.out_vocabulary.get_action(y_t)
            y_tok_seq.append(y_tok)

            # Advance the parser state. The returned legality flag is not
            # acted on here (the masks above already filtered the choice).
            (_gen_flag, gen_pre_action_class_in, gen_pre_arg_list_in,
             gen_pre_action_in, fun_trace_list_in) = (
                 general_controller.is_legal_action_then_read(
                     gen_pre_action_class_in, gen_pre_arg_list_in, y_tok,
                     gen_pre_action_in, node_dict, type_node_dict,
                     entity_node_dict, operation_dict, edge_dict, return_node,
                     db_triple, fun_trace_list_in))

            if is_end:
                break
            h_t = self._decoder_step(y_t, c_t, h_t)
        y_tok_lf = domain_convertor(' '.join(y_tok_seq), domain_controller,
                                    general_controller)
        return [Derivation(ex, p, y_tok_seq, y_tok_lf)]
コード例 #13
0
# -*- coding: utf-8 -*-

from derivation import Derivation, InvalidStep
from wff_quick import is_well_formed_formula

# Replays two propositional-calculus derivations with the Derivation /
# fantasy API. Each ``with X.fantasy(premise) as Y:`` opens a nested fantasy
# on ``premise``; each ``Y.step(formula)`` adds one line to that fantasy, and
# the trailing comment names the rule cited for it.
# NOTE(review): the page numbers and rule names (double-tilde, switcheroo,
# carry-over, fantasy rule) appear to follow Hofstadter's "Gödel, Escher,
# Bach" -- confirm. Presumably an invalid step raises InvalidStep; verify
# against the derivation module.

# Page 188:
# Starting from a fantasy on <p=0⊃~~p=0>, reach <p=0∨~p=0> in three steps.
d = Derivation()
with d.fantasy('<p=0⊃~~p=0>') as f:
    f.step('<~~~p=0⊃~p=0>')  # contrapositive
    f.step('<~p=0⊃~p=0>')    # double-tilde
    f.step('<p=0∨~p=0>')     # switcheroo

# Pages 189–190:
# Fresh derivation on the premise <<p=0⊃q=0>∧<~p=0⊃q=0>>.
d = Derivation()
with d.fantasy('<<p=0⊃q=0>∧<~p=0⊃q=0>>') as f:
    f.step('<p=0⊃q=0>')     # separation
    f.step('<~q=0⊃~p=0>')   # contrapositive
    f.step('<~p=0⊃q=0>')    # separation
    f.step('<~q=0⊃~~p=0>')  # contrapositive
    # Inner fantasy on ~q=0. Its two "carry-over" steps restate the outer
    # formulas <~q=0⊃~p=0> and <~q=0⊃~~p=0> derived just above; "line 4" /
    # "line 6" refer to the source text's line numbering, not Python lines.
    with f.fantasy('~q=0') as g:  # push again
        g.step('~q=0')            # premise
        g.step('<~q=0⊃~p=0>')     # carry-over of line 4
        g.step('~p=0')            # detachment
        g.step('<~q=0⊃~~p=0>')    # carry-over of line 6
        g.step('~~p=0')           # detachment
        g.step('<~p=0∧~~p=0>')    # joining
        g.step('~<p=0∨~p=0>')     # De Morgan
    f.step('<~q=0⊃~<p=0∨~p=0>>')  # fantasy rule
    f.step('<<p=0∨~p=0>⊃q=0>')  # contrapositive
    with f.fantasy('~p=0') as g:
        # NOTE(review): this fantasy contains no steps; the derivation looks
        # truncated here -- confirm the intended remaining lines.
        pass
コード例 #14
0
ファイル: attention.py プロジェクト: zhuang-li/Seq2Act
  def decode_beam(self, domain, ex, domain_convertor, domain_controller, general_controller, beam_size=1, max_len=100):
    """Beam-search decoding constrained by the general and domain controllers.

    For every derivation on the beam, the decoder's write distribution is
    multiplied by two legality masks from ``get_legal_action_list``:
    - the mask from ``general_controller``;
    - the mask from ``domain_controller``, computed over an action list where
      every slot the general mask rejects is replaced by '<COPY>'.

    Up to ``beam_size`` best candidates with non-zero probability are replayed
    through ``general_controller.is_legal_action_then_read`` on private deep
    copies of the parser state. Candidates rejected there are dropped.
    Candidates whose action is an end action go to the finished list. The rest
    form the next beam, which keeps the top ``beam_size`` by probability.

    Search stops when the beam is empty, when the ``beam_size``-th finished
    derivation is more probable than the best live one, or after
    ``max_len - 1`` steps.

    When ``self.spec.attention_copying`` is set, the candidate list is extended
    with one ``'add_entity_node:-:' + tok`` action per token in ``ex.copy_toks``
    (tokens equal to '<COPY>' stay '<COPY>'), plus one trailing '<COPY>' slot.
    Choosing index ``k >= out_vocabulary.all_size()`` copies
    ``ex.copy_toks[k - all_size()]``.

    Args:
      domain: domain identifier forwarded to ``out_vocabulary.action_is_end``.
      ex: example to decode (``x_inds`` is encoded; ``copy_toks`` feeds copying).
      domain_convertor: callable mapping the space-joined action string and
        both controllers to a logical form.
      domain_controller: controller for the domain legality mask.
      general_controller: controller for the general legality mask and for
        tracking the parser state.
      beam_size: number of live derivations kept per step.
      max_len: bound on the number of decoding steps (``max_len - 1``).

    Returns:
      Finished derivations with their converted logical forms, sorted by
      decreasing probability.
    """
    h_t, annotations = self._encode(ex.x_inds)
    beam = [[Derivation(ex, 1, [], [], hidden_state=h_t, p_list=[],
                        attention_list=[], copy_list=[], copy_entity_list=ex.copy_toks)]]
    finished = []
    final_finished = []
    vocab_size = self.out_vocabulary.size()
    all_size = self.out_vocabulary.all_size()
    action_all_raw = self.out_vocabulary.get_action_list()
    # Actions past the first vocab_size entries become '<COPY>' placeholders,
    # so the list keeps the raw list's length.
    action_all = (action_all_raw[:vocab_size]
                  + ['<COPY>'] * len(action_all_raw[vocab_size:]))
    copy_entity_list = ex.copy_toks

    # The candidate action list depends only on the example, so build it once
    # instead of once per derivation per step.
    expanded_action_all = action_all
    if self.spec.attention_copying:
      expanded_action_all = expanded_action_all + [
          '<COPY>' if copy_item == '<COPY>' else 'add_entity_node:-:' + copy_item
          for copy_item in copy_entity_list] + ['<COPY>']

    for _ in range(1, max_len):
      cur_beam = beam[-1]
      if not cur_beam:
        break
      # Stop once the beam_size-th best finished derivation is more probable
      # than the best derivation still on the beam.
      if len(finished) >= beam_size and cur_beam[0].p < finished[beam_size - 1].p:
        break
      new_beam = []

      for deriv in cur_beam:
        cur_p = deriv.p
        h_t = deriv.hidden_state
        y_tok_seq = deriv.y_toks
        p_list = deriv.p_list
        attention_list = deriv.attention_list
        copy_list = deriv.copy_list

        # Private copies of the parser state for the legality queries.
        gen_pre_action_for_test = deriv.gen_pre_action_in_deriv
        gen_pre_action_class_for_test = deriv.gen_pre_action_class_in_deriv
        gen_pre_arg_list_for_test = copy.deepcopy(deriv.gen_pre_arg_list_in_deriv)
        node_dict_for_test = copy.deepcopy(deriv.node_dict_in_deriv)
        type_node_dict_for_test = copy.deepcopy(deriv.type_node_dict_in_deriv)
        entity_node_dict_for_test = copy.deepcopy(deriv.entity_node_dict_in_deriv)
        operation_dict_for_test = copy.deepcopy(deriv.operation_dict_in_deriv)
        edge_dict_for_test = copy.deepcopy(deriv.edge_dict_in_deriv)
        return_node_for_test = copy.deepcopy(deriv.return_node_in_deriv)
        db_triple_for_test = copy.deepcopy(deriv.db_triple_in_deriv)
        fun_trace_list_for_test = copy.deepcopy(deriv.fun_trace_list_in_deriv)

        write_dist, c_t, alpha = self._decoder_write(annotations, h_t)

        legal_dist_gen = self.get_legal_action_list(
            general_controller, gen_pre_action_class_for_test, gen_pre_arg_list_for_test,
            gen_pre_action_for_test, node_dict_for_test, type_node_dict_for_test,
            entity_node_dict_for_test, operation_dict_for_test, edge_dict_for_test,
            return_node_for_test, db_triple_for_test, fun_trace_list_for_test,
            expanded_action_all)
        # The domain controller only sees actions the general controller
        # allows; every other slot is masked to '<COPY>'.
        expanded_action_all_for_domain = [
            expanded_action_all[ii] if legal else '<COPY>'
            for ii, legal in enumerate(legal_dist_gen)]
        legal_dist_dom = self.get_legal_action_list(
            domain_controller, gen_pre_action_class_for_test, gen_pre_arg_list_for_test,
            gen_pre_action_for_test, node_dict_for_test, type_node_dict_for_test,
            entity_node_dict_for_test, operation_dict_for_test, edge_dict_for_test,
            return_node_for_test, db_triple_for_test, fun_trace_list_for_test,
            expanded_action_all_for_domain)

        final_dist = write_dist * legal_dist_gen * legal_dist_dom
        sorted_dist = sorted([(p_y_t, y_t) for y_t, p_y_t in enumerate(final_dist)],
                             reverse=True)

        # Slicing (instead of indexing range(beam_size)) tolerates an action
        # space smaller than beam_size.
        for p_y_t, y_t in sorted_dist[:beam_size]:
          if p_y_t == 0.0:
            # Masked-out action: skip it before paying for the deep copies.
            continue

          # Fresh state copies for this candidate; the read below may update them.
          gen_pre_action_for_read = gen_pre_action_for_test
          gen_pre_action_class_for_read = gen_pre_action_class_for_test
          gen_pre_arg_list_for_read = copy.deepcopy(gen_pre_arg_list_for_test)
          node_dict_for_read = copy.deepcopy(node_dict_for_test)
          type_node_dict_for_read = copy.deepcopy(type_node_dict_for_test)
          entity_node_dict_for_read = copy.deepcopy(entity_node_dict_for_test)
          operation_dict_for_read = copy.deepcopy(operation_dict_for_test)
          edge_dict_for_read = copy.deepcopy(edge_dict_for_test)
          return_node_for_read = copy.deepcopy(return_node_for_test)
          db_triple_for_read = copy.deepcopy(db_triple_for_test)
          fun_trace_list_for_read = copy.deepcopy(fun_trace_list_for_test)

          new_p = cur_p * p_y_t
          append_flag = bool(self.out_vocabulary.action_is_end(domain, y_t))
          if y_t < all_size:
            do_copy = 0
            y_tok = self.out_vocabulary.get_action(y_t)
          else:
            # Copy slot: rebuild the entity action from the source token and
            # use its vocabulary index as the decoder input.
            do_copy = 1
            y_tok = 'add_entity_node:-:' + ex.copy_toks[y_t - all_size]
            y_t = self.out_vocabulary.get_index(y_tok)

          gen_flag, gen_pre_action_class_out, gen_pre_arg_list_out, gen_pre_action_out, fun_trace_list_out = \
            general_controller.is_legal_action_then_read(
                gen_pre_action_class_for_read, gen_pre_arg_list_for_read, y_tok,
                gen_pre_action_for_read, node_dict_for_read, type_node_dict_for_read,
                entity_node_dict_for_read, operation_dict_for_read, edge_dict_for_read,
                return_node_for_read, db_triple_for_read, fun_trace_list_for_read)
          if not gen_flag:
            print('test is right, but read is wrong!')
            continue

          # Fields shared by finished and live derivations.
          shared = dict(
              p_list=p_list + [p_y_t],
              attention_list=attention_list + [alpha],
              copy_list=copy_list + [do_copy],
              copy_entity_list=copy_entity_list,
              gen_pre_action_in_deriv=gen_pre_action_out,
              gen_pre_action_class_in_deriv=gen_pre_action_class_out,
              gen_pre_arg_list_in_deriv=gen_pre_arg_list_out,
              node_dict_in_deriv=node_dict_for_read,
              type_node_dict_in_deriv=type_node_dict_for_read,
              entity_node_dict_in_deriv=entity_node_dict_for_read,
              operation_dict_in_deriv=operation_dict_for_read,
              edge_dict_in_deriv=edge_dict_for_read,
              return_node_in_deriv=return_node_for_read,
              db_triple_in_deriv=db_triple_for_read,
              fun_trace_list_in_deriv=fun_trace_list_out)
          if append_flag:
            finished.append(Derivation(ex, new_p, y_tok_seq + [y_tok], [], **shared))
            continue
          # Only derivations that stay on the beam need a next hidden state.
          new_h_t = self._decoder_step(y_t, c_t, h_t)
          new_beam.append(Derivation(ex, new_p, y_tok_seq + [y_tok], [],
                                     hidden_state=new_h_t, **shared))

      new_beam.sort(key=lambda x: x.p, reverse=True)
      beam.append(new_beam[:beam_size])
      finished.sort(key=lambda x: x.p, reverse=True)

    for deriv in finished:
      y_toks_lf = domain_convertor(' '.join(deriv.y_toks), domain_controller, general_controller)
      # Keyword arguments (the names used when the beam is seeded) so the
      # fields cannot land in the wrong constructor slot.
      final_finished.append(Derivation(
          deriv.example, deriv.p, deriv.y_toks, y_toks_lf,
          hidden_state=deriv.hidden_state, p_list=deriv.p_list,
          attention_list=deriv.attention_list, copy_list=deriv.copy_list,
          copy_entity_list=deriv.copy_entity_list))
    return sorted(final_finished, key=lambda x: x.p, reverse=True)