def decode_beam(self, ex, beam_size=1, max_len=100):
    """Beam-search decoding without attention or copying."""
    h_t = self._encode(ex.x_inds)
    beam = [[Derivation(ex, 1, [], hidden_state=h_t,
                        attention_list=[], copy_list=[])]]
    finished = []
    for i in range(1, max_len):
        new_beam = []
        for deriv in beam[i-1]:
            cur_p = deriv.p
            h_t = deriv.hidden_state
            y_tok_seq = deriv.y_toks
            write_dist = self._decoder_write(h_t)
            sorted_dist = sorted(
                [(p_y_t, y_t) for y_t, p_y_t in enumerate(write_dist)],
                reverse=True)
            # Expand each derivation with its beam_size best next tokens.
            for j in range(beam_size):
                p_y_t, y_t = sorted_dist[j]
                new_p = cur_p * p_y_t
                if y_t == Vocabulary.END_OF_SENTENCE_INDEX:
                    finished.append(Derivation(ex, new_p, y_tok_seq))
                    continue
                y_tok = self.out_vocabulary.get_word(y_t)
                new_h_t = self._decoder_step(y_t, h_t)
                new_beam.append(Derivation(ex, new_p, y_tok_seq + [y_tok],
                                           hidden_state=new_h_t))
        new_beam.sort(key=lambda x: x.p, reverse=True)
        beam.append(new_beam[:beam_size])
    finished.sort(key=lambda x: x.p, reverse=True)
    return finished
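# --- Illustration (not part of the model) ----------------------------------
# A minimal sketch of the beam invariant used in decode_beam above: every
# hypothesis carries a running product of token probabilities, and after each
# step only the top `beam_size` hypotheses survive. The fixed `toy_dist`
# table below is hypothetical.
def toy_beam_step(beam, toy_dist, beam_size=2):
    """beam: list of (prob, tokens); toy_dist: prob of each next token."""
    expanded = [(p * toy_dist[t], toks + [t])
                for p, toks in beam for t in range(len(toy_dist))]
    expanded.sort(reverse=True)
    return expanded[:beam_size]

toy_beam = [(1.0, [])]
for _ in range(3):
    toy_beam = toy_beam_step(toy_beam, [0.6, 0.3, 0.1])
print(toy_beam)  # the two most probable length-3 sequences and their scores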
def decode_beam(self, ex, beam_size=1, max_len=100):
    """Beam-search decoding with attention and a copy mechanism."""
    h_t, annotations = self._encode(ex.x_inds)
    beam = [[Derivation(ex, 1, [], hidden_state=h_t,
                        attention_list=[], copy_list=[])]]
    finished = []
    for i in range(1, max_len):
        if len(beam[i-1]) == 0:
            break
        # Stop if the beam_size-th finished derivation is better than
        # everything still on the beam.
        if len(finished) >= beam_size:
            finished_p = finished[beam_size-1].p
            cur_best_p = beam[i-1][0].p
            if cur_best_p < finished_p:
                break
        new_beam = []
        for deriv in beam[i-1]:
            cur_p = deriv.p
            h_t = deriv.hidden_state
            y_tok_seq = deriv.y_toks
            attention_list = deriv.attention_list
            copy_list = deriv.copy_list
            write_dist, c_t, alpha = self._loc_decoder_write(annotations, h_t)
            sorted_dist = sorted(
                [(p_y_t, y_t) for y_t, p_y_t in enumerate(write_dist)],
                reverse=True)
            for j in range(beam_size):
                p_y_t, y_t = sorted_dist[j]
                new_p = cur_p * p_y_t
                if y_t == Vocabulary.END_OF_SENTENCE_INDEX:
                    finished.append(Derivation(
                        ex, new_p, y_tok_seq,
                        attention_list=attention_list + [alpha],
                        copy_list=copy_list + [0]))
                    continue
                if y_t < self.out_vocabulary.size():
                    # Ordinary write from the output vocabulary.
                    y_tok = self.out_vocabulary.get_word(y_t)
                    do_copy = 0
                else:
                    # Copy from the EOS-augmented input sequence.
                    new_ind = y_t - self.out_vocabulary.size()
                    augmented_copy_toks = ex.copy_toks + [Vocabulary.END_OF_SENTENCE]
                    y_tok = augmented_copy_toks[new_ind]
                    y_t = self.out_vocabulary.get_index(y_tok)
                    do_copy = 1
                new_h_t = self._decoder_step(y_t, c_t, h_t)
                new_entry = Derivation(ex, new_p, y_tok_seq + [y_tok],
                                       hidden_state=new_h_t,
                                       attention_list=attention_list + [alpha],
                                       copy_list=copy_list + [do_copy])
                new_beam.append(new_entry)
        new_beam.sort(key=lambda x: x.p, reverse=True)
        beam.append(new_beam[:beam_size])
    finished.sort(key=lambda x: x.p, reverse=True)
    return finished
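# --- Illustration (hypothetical helper, not part of the model) -------------
# The extended-index convention used by the copy mechanism above: indices
# [0, V) address the output vocabulary, and indices from V onward address
# positions of the EOS-augmented input, which are copied verbatim.
def resolve_extended_index(y_t, out_vocab_words, copy_toks, eos='</s>'):
    V = len(out_vocab_words)
    if y_t < V:
        return out_vocab_words[y_t], 0       # ordinary write
    return (copy_toks + [eos])[y_t - V], 1   # copy from input

print(resolve_extended_index(5, ['a', 'b', 'c', 'd'], ['foo', 'bar']))
# -> ('bar', 1)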
def decode_greedy(self, ex, max_len=100):
    h_t, annotations = self._encode(ex.x_inds)
    y_tok_seq = []
    p_y_seq = []  # Should be handy for error analysis
    p = 1
    for i in range(max_len):
        write_dist, c_t, alpha = self._decoder_write(annotations, h_t)
        y_t = numpy.argmax(write_dist)
        p_y_t = write_dist[y_t]
        p_y_seq.append(p_y_t)
        p *= p_y_t
        if y_t == Vocabulary.END_OF_SENTENCE_INDEX:
            break
        if y_t < self.out_vocabulary.size():
            y_tok = self.out_vocabulary.get_word(y_t)
        else:
            new_ind = y_t - self.out_vocabulary.size()
            augmented_copy_toks = ex.copy_toks + [Vocabulary.END_OF_SENTENCE]
            y_tok = augmented_copy_toks[new_ind]
            y_t = self.out_vocabulary.get_index(y_tok)
        y_tok_seq.append(y_tok)
        h_t = self._decoder_step(y_t, ex.d_inds[0], c_t, h_t)
    return [Derivation(ex, p, y_tok_seq)]
def decode_by_em(self, ex, max_len=100):
    h_t, annotations = self._encode(ex.x_inds)
    y_tok_seq = []
    p_y_seq = []  # Should be handy for error analysis
    p = 1
    print(ex.x_str)
    for i in range(max_len):
        # Step 1
        write_dist, c_t, alpha = self._decoder_write(annotations, h_t)
        domain_scores_h_t = T.nnet.softmax(self.spec.get_domain_scores_h_t(h_t))
        domain_scores_c_t = T.nnet.softmax(self.spec.get_domain_scores_c_t(c_t))
        y_t = numpy.argmax(write_dist)
        # FIXME: the domain scores are computed but not yet folded into the
        # token probability, e.g. p_y_t = domain_scores[d_t] * write_dist[y_t]
        # with d_t = numpy.argmax(domain_scores_h_t).
        p_y_t = write_dist[y_t]
        p_y_seq.append(p_y_t)
        p *= p_y_t
        if y_t == Vocabulary.END_OF_SENTENCE_INDEX:
            break
        if y_t < self.out_vocabulary.size():
            y_tok = self.out_vocabulary.get_word(y_t)
        else:
            # Copy from the EOS-augmented input sequence.
            new_ind = y_t - self.out_vocabulary.size()
            augmented_copy_toks = ex.copy_toks + [Vocabulary.END_OF_SENTENCE]
            y_tok = augmented_copy_toks[new_ind]
            y_t = self.out_vocabulary.get_index(y_tok)
        y_tok_seq.append(y_tok)
        h_t = self._decoder_step(y_t, c_t, h_t)
    return [Derivation(ex, p, y_tok_seq)]
def decode_greedy(self, ex, max_len=100):
    h_t, annotations = self._encode(ex.x_inds)
    y_tok_seq = []
    p_y_seq = []  # Should be handy for error analysis
    p = 1
    for i in range(max_len):
        write_dist, c_t, alpha = self._decoder_write(annotations, h_t)
        # Interpolate the model's write distribution with pairwise
        # co-occurrence statistics for copyable input tokens.
        for j in range(len(write_dist)):
            y_t = j
            x_t = 0
            if y_t < self.out_vocabulary.size():
                y_tok = self.out_vocabulary.get_word(y_t)
            if y_t >= self.out_vocabulary.size():
                new_ind = y_t - self.out_vocabulary.size()
                augmented_copy_toks = ex.copy_toks + [Vocabulary.END_OF_SENTENCE]
                y_tok = augmented_copy_toks[new_ind]
                x_t = self.in_vocabulary.get_index(y_tok)
            if x_t in self.spec.pair_stat:
                for y_, p_xy in self.spec.pair_stat[x_t]:
                    # Mixing weight 1.0 keeps the model distribution unchanged;
                    # weights 0.9/0.1 through 0.5/0.5 were also tried.
                    write_dist[y_] = 1.0 * write_dist[y_] + 0.0 * p_xy
        y_t = numpy.argmax(write_dist)
        p_y_t = write_dist[y_t]
        p_y_seq.append(p_y_t)
        p *= p_y_t
        if y_t == Vocabulary.END_OF_SENTENCE_INDEX:
            break
        if y_t < self.out_vocabulary.size():
            y_tok = self.out_vocabulary.get_word(y_t)
        else:
            new_ind = y_t - self.out_vocabulary.size()
            augmented_copy_toks = ex.copy_toks + [Vocabulary.END_OF_SENTENCE]
            y_tok = augmented_copy_toks[new_ind]
            y_t = self.out_vocabulary.get_index(y_tok)
        y_tok_seq.append(y_tok)
        h_t = self._decoder_step(y_t, c_t, h_t)
    return [Derivation(ex, p, y_tok_seq)]
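# --- Illustration (hypothetical numbers) ------------------------------------
# The commented-out mixing weights above interpolate the model's write
# distribution with co-occurrence statistics p_xy: lam = 1.0 reproduces the
# model distribution unchanged, lam = 0.5 weights both sources equally.
def interpolate(write_dist, pair_probs, lam):
    mixed = list(write_dist)
    for y_, p_xy in pair_probs:
        mixed[y_] = lam * mixed[y_] + (1.0 - lam) * p_xy
    return mixed

print(interpolate([0.5, 0.3, 0.2], [(1, 0.9)], lam=0.7))  # ~[0.5, 0.48, 0.2]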
def decode_greedy(self, ex, max_len=100):
    h_t = self._encode(ex.x_inds)
    y_tok_seq = []
    p_y_seq = []  # Should be handy for error analysis
    p = 1
    for i in range(max_len):
        write_dist = self._decoder_write(h_t)
        y_t = numpy.argmax(write_dist)
        p_y_t = write_dist[y_t]
        p_y_seq.append(p_y_t)
        p *= p_y_t
        if y_t == Vocabulary.END_OF_SENTENCE_INDEX:
            break
        y_tok = self.out_vocabulary.get_word(y_t)
        y_tok_seq.append(y_tok)
        h_t = self._decoder_step(y_t, h_t)
    return [Derivation(ex, p, y_tok_seq)]
import copy
import queue as q

def getBoundDerivation(cfg, string, positionFunction):
    # Breadth-first search over sentential forms, always expanding the
    # nonterminal chosen by positionFunction, until the target string is derived.
    queue = q.Queue()
    queue.put((cfg[3], Derivation(entries=[(cfg[3], 0)])))
    while not queue.empty():
        current, derivation = queue.get()
        if current == string:
            derivation.add(current, -1)
            derivation.clean()
            return derivation
        elif positionFunction(current, cfg[0]) != -1:
            pos = positionFunction(current, cfg[0])
            for v in cfg[2][current[pos]]:
                newDerivation = copy.deepcopy(derivation)
                newDerivation.add(current, pos)
                queue.put((replaceAtIndexWithString(current, pos, v),
                           newDerivation))
def getAllDerivationsRecursive(cfg, string, currentString, maxSubs, maxDevs,
                               derivations, subs, currentDevs):
    # Depth-first enumeration of derivations, capped at maxDevs derivations
    # and maxSubs substitutions.
    if len(derivations) == maxDevs or maxSubs == subs:
        pass
    elif string == currentString:
        currentDevs.append((currentString, -1))
        derivations.append(Derivation(entries=currentDevs))
        currentDevs.pop()
    else:
        for i in range(len(currentString)):
            if currentString[i] in cfg[0]:
                for v in cfg[2][currentString[i]]:
                    cs = replaceAtIndexWithString(currentString, i, v)
                    currentDevs.append((currentString, i))
                    getAllDerivationsRecursive(cfg, string, cs, maxSubs,
                                               maxDevs, derivations, subs + 1,
                                               currentDevs)
                    currentDevs.pop()
    return derivations
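# --- Illustration (format inferred from usage; treat as an assumption) ------
# The cfg tuple these two functions index appears to be:
#   cfg[0] nonterminals, cfg[1] terminals, cfg[2] productions, cfg[3] start.
# A toy grammar deriving 'ab' from S via S -> aB, B -> b:
toy_cfg = (['S', 'B'],                 # cfg[0]: nonterminals
           ['a', 'b'],                 # cfg[1]: terminals
           {'S': ['aB'], 'B': ['b']},  # cfg[2]: productions
           'S')                        # cfg[3]: start string
# e.g. getAllDerivationsRecursive(toy_cfg, 'ab', 'S', maxSubs=5, maxDevs=1,
#                                 derivations=[], subs=0, currentDevs=[])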
def top_down_beam_search(self, words, oracle_actions, oracle_tokens, buffer,
                         stack_top, action_top, beam_size=5):
    finished = []
    stack = []
    action = self._ROOT
    nt = self.nt_vocab[oracle_tokens.pop(0)]
    nt_embedding = self.nt_input_layer(self.nt_lookup[nt])
    stack_state = stack_top.add_input(nt_embedding)
    stack.append((stack_state, 'p', nt_embedding))
    action_embedding = self.act_input_layer(self.act_lookup[action])
    action_top = action_top.add_input(action_embedding)
    # First, set up the initial beam.
    beam = [Derivation(stack, action_top, [self.act_vocab.token(action)],
                       [self.nt_vocab.token(nt)], 0, nt_allowed=1,
                       ter_allowed=0, reducable=0, total_nt=1)]
    # Loop until we obtain enough finished derivations.
    while len(finished) < beam_size:
        new_beam = []
        unfinished = []
        # Collect all possible expansions of the current beam.
        for der in beam:
            valid_actions = []
            if len(der.stack) >= 1:
                if der.ter_allowed == 1:
                    valid_actions += [self._TER]
                if der.nt_allowed == 1:
                    valid_actions += [self._NT] + self._NT_general
            if len(der.stack) >= 2 and der.reducable != 0:
                valid_actions += [self._ACT]
            stack_embedding = der.stack[-1][0].output()
            action_summary = der.action_top.output()
            word_weights = self.attention(stack_embedding, buffer)
            buffer_embedding, _ = attention_output(buffer, word_weights,
                                                   'soft_average')
            # Find the closest open nonterminal on the stack.
            for i in range(len(der.stack)):
                if der.stack[len(der.stack)-1-i][1] == 'p':
                    parent_embedding = der.stack[len(der.stack)-1-i][2]
                    break
            parser_state = dy.concatenate([buffer_embedding, stack_embedding,
                                           parent_embedding, action_summary])
            h = self.mlp_layer(parser_state)
            log_probs = dy.log_softmax(self.act_proj_layer(h), valid_actions)
            sorted_actions_logprobs = sorted(enumerate(log_probs.vec_value()),
                                             key=itemgetter(1), reverse=True)
            sorted_actions = [x[0] for x in sorted_actions_logprobs if x[1] > -999]
            sorted_logprobs = [x[1] for x in sorted_actions_logprobs if x[1] > -999]
            if len(sorted_actions) > beam_size:
                sorted_actions = sorted_actions[:beam_size]
                sorted_logprobs = sorted_logprobs[:beam_size]
            for action, logprob in zip(sorted_actions, sorted_logprobs):
                if action == self._NT:
                    output_feature, output_logprob = attention_output(
                        buffer, word_weights, self.test_selection, argmax=True)
                    log_probs_nt = dy.log_softmax(self.nt_proj_layer(output_feature))
                    sorted_nt_logprobs = sorted(enumerate(log_probs_nt.vec_value()),
                                                key=itemgetter(1), reverse=True)
                    for i in range(beam_size):
                        new_beam.append(Derivation(
                            der.stack, der.action_top, der.output_actions,
                            der.output_tokens,
                            der.logp + logprob + sorted_nt_logprobs[i][1],
                            1, 1, 1, der.total_nt, action,
                            sorted_nt_logprobs[i][0]))
                elif action == self._TER:
                    output_feature, output_logprob = attention_output(
                        buffer, word_weights, self.test_selection, argmax=True)
                    log_probs_ter = dy.log_softmax(self.ter_proj_layer(output_feature))
                    sorted_ter_logprobs = sorted(enumerate(log_probs_ter.vec_value()),
                                                 key=itemgetter(1), reverse=True)
                    for i in range(beam_size):
                        new_beam.append(Derivation(
                            der.stack, der.action_top, der.output_actions,
                            der.output_tokens,
                            der.logp + logprob + sorted_ter_logprobs[i][1],
                            1, 1, 1, der.total_nt, action,
                            sorted_ter_logprobs[i][0]))
                else:
                    new_beam.append(Derivation(
                        der.stack, der.action_top, der.output_actions,
                        der.output_tokens, der.logp + logprob,
                        1, 1, 1, der.total_nt, next_act=action))
        # Sort the expanded candidates and keep only the top k.
        new_beam.sort(key=lambda x: x.logp, reverse=True)
        if len(new_beam) > beam_size:
            new_beam = new_beam[:beam_size]
        # Execute and update the top k remaining candidates.
        for i in range(len(new_beam)):
            action = new_beam[i].next_act
            if action == self._NT:
                nt = new_beam[i].next_tok
                nt_embedding = self.nt_input_layer(self.nt_lookup[nt])
                stack_state, label, _ = new_beam[i].stack[-1]
                stack_state = stack_state.add_input(nt_embedding)
                new_beam[i].stack.append((stack_state, 'p', nt_embedding))
                new_beam[i].output_actions.append(self.act_vocab.token(action))
                new_beam[i].output_tokens.append(self.nt_vocab.token(nt))
                new_beam[i].total_nt += 1
            elif action == self._TER:
                ter = new_beam[i].next_tok
                ter_embedding = self.ter_input_layer(self.ter_lookup[ter])
                stack_state, label, _ = new_beam[i].stack[-1]
                stack_state = stack_state.add_input(ter_embedding)
                new_beam[i].stack.append((stack_state, 'c', ter_embedding))
                new_beam[i].output_actions.append(self.act_vocab.token(action))
                new_beam[i].output_tokens.append(self.ter_vocab.token(ter))
            elif action in self._NT_general:
                nt = self.act_vocab.token(new_beam[i].next_act).rstrip(')').lstrip('NT(')
                nt = self.nt_vocab[nt]
                nt_embedding = self.nt_input_layer(self.nt_lookup[nt])
                stack_state, label, _ = new_beam[i].stack[-1]
                stack_state = stack_state.add_input(nt_embedding)
                new_beam[i].stack.append((stack_state, 'p', nt_embedding))
                new_beam[i].output_actions.append(self.act_vocab.token(action))
                new_beam[i].output_tokens.append(self.nt_vocab.token(nt))
                new_beam[i].total_nt += 1
            else:
                # REDUCE: pop children up to the nearest open nonterminal and
                # compose them into a single subtree representation.
                found_p = 0
                path_input = []
                while found_p != 1:
                    top = new_beam[i].stack.pop()
                    top_raw_rep, top_label, top_rep = top[2], top[1], top[0]
                    path_input.append(top_raw_rep)
                    if top_label == 'p':
                        found_p = 1
                parent_rep = path_input.pop()
                composed_rep = self.subtree_input_layer(
                    dy.concatenate([dy.average(path_input), parent_rep]))
                stack_state = new_beam[i].stack[-1][0] if new_beam[i].stack else stack_top
                stack_state = stack_state.add_input(composed_rep)
                new_beam[i].stack.append((stack_state, 'c', composed_rep))
                new_beam[i].output_actions.append(self.act_vocab.token(action))
            # Append the derivation to the finished set if it is complete.
            if len(new_beam[i].stack) == 1:
                finished.append(new_beam[i])
                continue
            # If not complete, proceed.
            action_embedding = self.act_input_layer(self.act_lookup[action])
            new_beam[i].action_top = new_beam[i].action_top.add_input(action_embedding)
            reducable = 1
            nt_allowed = 1
            ter_allowed = 1
            # Reduce cannot follow nt.
            if new_beam[i].stack[-1][1] == 'p':
                reducable = 0
            # nt is disabled if the maximum number of open nonterminals is reached.
            count_p = 0
            for item in new_beam[i].stack:
                if item[1] == 'p':
                    count_p += 1
            if count_p >= 10:
                nt_allowed = 0
            if len(new_beam[i].stack) > len(words) or new_beam[i].total_nt > len(words):
                nt_allowed = 0
            # ter is disabled if the maximum number of children under the open
            # nt is reached.
            count_c = 0
            for item in new_beam[i].stack[::-1]:
                if item[1] == 'c':
                    count_c += 1
                else:
                    break
            if count_c >= 10:
                ter_allowed = 0
            new_beam[i].nt_allowed = nt_allowed
            new_beam[i].ter_allowed = ter_allowed
            new_beam[i].reducable = reducable
            new_beam[i].next_act = None
            new_beam[i].next_tok = None
            unfinished.append(new_beam[i])
        beam = unfinished
    finished.sort(key=lambda x: x.logp, reverse=True)
    return finished[0].output_actions, finished[0].output_tokens
def decode_by_em(self, ex, max_len=100):
    h_t, annotations = self._encode(ex.x_inds)
    y_tok_seq = []
    p_y_seq = []  # Should be handy for error analysis
    p = 1
    # IBM Model 2 translation and alignment tables:
    # s_total[e] += t[(e, f)] * a[(i, j, l_e, l_f)]
    t, a = self.spec.em_model
    l_f = 0
    print(ex.x_str)
    for i in range(max_len):
        write_dist, c_t, alpha = self._decoder_write(annotations, h_t)
        y_t = numpy.argmax(write_dist)
        p_y_t = write_dist[y_t]
        p_y_seq.append(p_y_t)
        p *= p_y_t
        if y_t == Vocabulary.END_OF_SENTENCE_INDEX:
            l_f = i  # length of the predicted sequence
            break
        if y_t < self.out_vocabulary.size():
            y_tok = self.out_vocabulary.get_word(y_t)
        elif y_t < self.out_vocabulary.size() + self.in_vocabulary.size():
            # Attend over the input vocabulary.
            new_ind = y_t - self.out_vocabulary.size()
            y_tok = self.in_vocabulary.get_word(new_ind)
            print('----> second:')
            print(y_tok, write_dist[y_t])
            y_t = self.out_vocabulary.get_index(y_tok)
        else:
            # Copy from the EOS-augmented input sequence.
            new_ind = y_t - (self.out_vocabulary.size() + self.in_vocabulary.size())
            augmented_copy_toks = ex.copy_toks + [Vocabulary.END_OF_SENTENCE]
            y_tok = augmented_copy_toks[new_ind]
            print('---> third:')
            print(y_tok, write_dist[y_t])
            y_t = self.out_vocabulary.get_index(y_tok)
        y_tok_seq.append(y_tok)
        h_t = self._decoder_step(y_t, c_t, h_t)
    # A disabled second decoding pass rescored write_dist with the ibm2
    # alignment probabilities a[(new_ind, i, len(ex.x_str), l_f)] and printed
    # the alignment matrix via ibm2.show_matrix.
    return [Derivation(ex, p, y_tok_seq)]
def decode_beam(self, domain, ex, domain_convertor, domain_controller,
                general_controller, beam_size=1, max_len=100):
    h_t, annotations = self._encode(ex.x_inds)
    beam = [[Derivation(ex, 1, [], [], hidden_state=h_t, p_list=[],
                        entity_lex_map=ex.entity_lex_map, attention_list=[])]]
    finished = []
    final_finished = []
    action_all = self.out_vocabulary.get_action_list()
    for i in range(1, max_len):
        if len(beam[i-1]) == 0:
            break
        # Stop if the beam_size-th finished derivation is better than
        # everything still on the beam.
        if len(finished) >= beam_size:
            finished_p = finished[beam_size-1].p
            cur_best_p = beam[i-1][0].p
            if cur_best_p < finished_p:
                break
        new_beam = []
        for deriv in beam[i-1]:
            cur_p = deriv.p
            h_t = deriv.hidden_state
            y_tok_seq = deriv.y_toks
            p_list = deriv.p_list
            attention_list = deriv.attention_list
            entity_lex_map = deriv.entity_lex_map
            # Copy the parser state so legality tests do not mutate it.
            gen_pre_action_for_test = deriv.gen_pre_action_in_deriv
            gen_pre_action_class_for_test = deriv.gen_pre_action_class_in_deriv
            gen_pre_arg_list_for_test = copy.deepcopy(deriv.gen_pre_arg_list_in_deriv)
            node_dict_for_test = copy.deepcopy(deriv.node_dict_in_deriv)
            type_node_dict_for_test = copy.deepcopy(deriv.type_node_dict_in_deriv)
            entity_node_dict_for_test = copy.deepcopy(deriv.entity_node_dict_in_deriv)
            operation_dict_for_test = copy.deepcopy(deriv.operation_dict_in_deriv)
            edge_dict_for_test = copy.deepcopy(deriv.edge_dict_in_deriv)
            return_node_for_test = copy.deepcopy(deriv.return_node_in_deriv)
            db_triple_for_test = copy.deepcopy(deriv.db_triple_in_deriv)
            fun_trace_list_for_test = copy.deepcopy(deriv.fun_trace_list_in_deriv)
            write_dist, c_t, alpha = self._decoder_write(annotations, h_t)
            legal_dist_gen = self.get_legal_action_list(
                general_controller, gen_pre_action_class_for_test,
                gen_pre_arg_list_for_test, gen_pre_action_for_test,
                node_dict_for_test, type_node_dict_for_test,
                entity_node_dict_for_test, operation_dict_for_test,
                edge_dict_for_test, return_node_for_test, db_triple_for_test,
                fun_trace_list_for_test, action_all,
                entity_lex_map=entity_lex_map)
            # Mask out actions the general controller already rejected before
            # consulting the domain controller.
            action_all_for_domain = []
            for ii in range(len(legal_dist_gen)):
                if legal_dist_gen[ii]:
                    action_all_for_domain.append(action_all[ii])
                else:
                    action_all_for_domain.append('<COPY>')
            legal_dist_dom = self.get_legal_action_list(
                domain_controller, gen_pre_action_class_for_test,
                gen_pre_arg_list_for_test, gen_pre_action_for_test,
                node_dict_for_test, type_node_dict_for_test,
                entity_node_dict_for_test, operation_dict_for_test,
                edge_dict_for_test, return_node_for_test, db_triple_for_test,
                fun_trace_list_for_test, action_all_for_domain,
                entity_lex_map=entity_lex_map)
            final_dist = write_dist * legal_dist_gen * legal_dist_dom
            sorted_dist = sorted(
                [(p_y_t, y_t) for y_t, p_y_t in enumerate(final_dist)],
                reverse=True)
            for j in range(beam_size):
                gen_pre_action_for_read = gen_pre_action_for_test
                gen_pre_action_class_for_read = gen_pre_action_class_for_test
                gen_pre_arg_list_for_read = copy.deepcopy(gen_pre_arg_list_for_test)
                node_dict_for_read = copy.deepcopy(node_dict_for_test)
                type_node_dict_for_read = copy.deepcopy(type_node_dict_for_test)
                entity_node_dict_for_read = copy.deepcopy(entity_node_dict_for_test)
                operation_dict_for_read = copy.deepcopy(operation_dict_for_test)
                edge_dict_for_read = copy.deepcopy(edge_dict_for_test)
                return_node_for_read = copy.deepcopy(return_node_for_test)
                db_triple_for_read = copy.deepcopy(db_triple_for_test)
                fun_trace_list_for_read = copy.deepcopy(fun_trace_list_for_test)
                p_y_t, y_t = sorted_dist[j]
                if p_y_t == 0.0:
                    continue
                new_p = cur_p * p_y_t
                append_flag = False
                if self.out_vocabulary.action_is_end(domain, y_t):
                    append_flag = True
                if y_t < self.out_vocabulary.size():
                    y_tok = self.out_vocabulary.get_action(y_t)
                else:
                    print('error action index!')
                new_h_t = self._decoder_step(y_t, c_t, h_t)
                action_token = y_tok
                gen_flag, gen_pre_action_class_out, gen_pre_arg_list_out, \
                    gen_pre_action_out, fun_trace_list_out = \
                    general_controller.is_legal_action_then_read(
                        gen_pre_action_class_for_read, gen_pre_arg_list_for_read,
                        action_token, gen_pre_action_for_read,
                        node_dict_for_read, type_node_dict_for_read,
                        entity_node_dict_for_read, operation_dict_for_read,
                        edge_dict_for_read, return_node_for_read,
                        db_triple_for_read, fun_trace_list_for_read)
                if not gen_flag:
                    print('test is right, but read is wrong!')
                    continue
                if append_flag:
                    finished.append(Derivation(
                        ex, new_p, y_tok_seq + [y_tok], [],
                        p_list=p_list + [p_y_t],
                        entity_lex_map=entity_lex_map,
                        attention_list=attention_list + [alpha],
                        gen_pre_action_in_deriv=gen_pre_action_out,
                        gen_pre_action_class_in_deriv=gen_pre_action_class_out,
                        gen_pre_arg_list_in_deriv=gen_pre_arg_list_out,
                        node_dict_in_deriv=node_dict_for_read,
                        type_node_dict_in_deriv=type_node_dict_for_read,
                        entity_node_dict_in_deriv=entity_node_dict_for_read,
                        operation_dict_in_deriv=operation_dict_for_read,
                        edge_dict_in_deriv=edge_dict_for_read,
                        return_node_in_deriv=return_node_for_read,
                        db_triple_in_deriv=db_triple_for_read,
                        fun_trace_list_in_deriv=fun_trace_list_out))
                    continue
                new_entry = Derivation(
                    ex, new_p, y_tok_seq + [y_tok], [],
                    hidden_state=new_h_t,
                    p_list=p_list + [p_y_t],
                    entity_lex_map=entity_lex_map,
                    attention_list=attention_list + [alpha],
                    gen_pre_action_in_deriv=gen_pre_action_out,
                    gen_pre_action_class_in_deriv=gen_pre_action_class_out,
                    gen_pre_arg_list_in_deriv=gen_pre_arg_list_out,
                    node_dict_in_deriv=node_dict_for_read,
                    type_node_dict_in_deriv=type_node_dict_for_read,
                    entity_node_dict_in_deriv=entity_node_dict_for_read,
                    operation_dict_in_deriv=operation_dict_for_read,
                    edge_dict_in_deriv=edge_dict_for_read,
                    return_node_in_deriv=return_node_for_read,
                    db_triple_in_deriv=db_triple_for_read,
                    fun_trace_list_in_deriv=fun_trace_list_out)
                new_beam.append(new_entry)
        new_beam.sort(key=lambda x: x.p, reverse=True)
        beam.append(new_beam[:beam_size])
    finished.sort(key=lambda x: x.p, reverse=True)
    # Map abstract entity placeholders back to surface forms, then convert
    # each finished action sequence to a logical form.
    for deriv in finished:
        y_new_toks = []
        entity_lex = deriv.example.entity_lex_map
        for y_tok in deriv.y_toks:
            if y_tok.startswith('add_entity_node:'):
                entity = y_tok[y_tok.index(':-:') + 3:]
                if entity in entity_lex:
                    new_entity = entity_lex[entity]
                    y_new_tok = y_tok.replace(entity, new_entity)
                    y_new_toks.append(y_new_tok)
                    continue
            y_new_toks.append(y_tok)
        y_new_str = ' '.join(y_new_toks)
        y_toks_lf = domain_convertor(y_new_str, domain_controller,
                                     general_controller,
                                     entity_lex_map=entity_lex)
        new_entry = Derivation(deriv.example, deriv.p, deriv.y_toks, y_toks_lf,
                               deriv.hidden_state, deriv.p_list,
                               deriv.entity_lex_map, deriv.attention_list)
        final_finished.append(new_entry)
    return sorted(final_finished, key=lambda x: x.p, reverse=True)
def decode_greedy(self, domain, ex, domain_convertor, domain_controller,
                  general_controller, max_len=100):
    h_t, annotations = self._encode(ex.x_inds)
    y_tok_seq = []
    p_y_seq = []  # Should be handy for error analysis
    p = 1
    break_flag = False
    action_all = self.out_vocabulary.get_action_list()
    gen_pre_action_in = ''
    gen_pre_action_class_in = 'start'
    gen_pre_arg_list_in = []
    node_dict = {}
    type_node_dict = {}
    entity_node_dict = {}
    operation_dict = {}
    edge_dict = {}
    return_node = {}
    db_triple = {}
    fun_trace_list_in = []
    for i in range(max_len):
        write_dist, c_t, alpha = self._decoder_write(annotations, h_t)
        legal_dist_gen = self.get_legal_action_list(
            general_controller, gen_pre_action_class_in, gen_pre_arg_list_in,
            gen_pre_action_in, node_dict, type_node_dict, entity_node_dict,
            operation_dict, edge_dict, return_node, db_triple,
            fun_trace_list_in, action_all)
        legal_dist_dom = self.get_legal_action_list(
            domain_controller, gen_pre_action_class_in, gen_pre_arg_list_in,
            gen_pre_action_in, node_dict, type_node_dict, entity_node_dict,
            operation_dict, edge_dict, return_node, db_triple,
            fun_trace_list_in, action_all)
        # Keep only actions both controllers consider legal.
        final_dist = write_dist * legal_dist_gen * legal_dist_dom
        y_t = numpy.argmax(final_dist)
        p_y_t = write_dist[y_t]
        p_y_seq.append(p_y_t)
        p *= p_y_t
        if self.out_vocabulary.action_is_end(y_t):
            break_flag = True
        y_tok = self.out_vocabulary.get_action(y_t)
        if y_t >= self.out_vocabulary.all_size():
            print('error in attention for out vocabulary')
        y_tok_seq.append(y_tok)
        action_token = y_tok
        gen_flag, gen_pre_action_class_out, gen_pre_arg_list_out, \
            gen_pre_action_out, fun_trace_list_out = \
            general_controller.is_legal_action_then_read(
                gen_pre_action_class_in, gen_pre_arg_list_in, action_token,
                gen_pre_action_in, node_dict, type_node_dict,
                entity_node_dict, operation_dict, edge_dict, return_node,
                db_triple, fun_trace_list_in)
        gen_pre_action_class_in = gen_pre_action_class_out
        gen_pre_arg_list_in = gen_pre_arg_list_out
        gen_pre_action_in = gen_pre_action_out
        fun_trace_list_in = fun_trace_list_out
        if break_flag:
            break
        h_t = self._decoder_step(y_t, c_t, h_t)
    y_tok_lf = domain_convertor(' '.join(y_tok_seq), domain_controller,
                                general_controller)
    return [Derivation(ex, p, y_tok_seq, y_tok_lf)]
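# --- Illustration (hypothetical numbers) ------------------------------------
# The legality masking used above in miniature: the write distribution is
# multiplied elementwise by 0/1 vectors from the general and domain
# controllers, so argmax can only select an action both controllers allow.
import numpy

toy_write_dist = numpy.array([0.5, 0.3, 0.2])
toy_legal_gen = numpy.array([1, 1, 0])   # general controller: action 2 illegal
toy_legal_dom = numpy.array([0, 1, 1])   # domain controller: action 0 illegal
toy_final_dist = toy_write_dist * toy_legal_gen * toy_legal_dom
print(numpy.argmax(toy_final_dist))      # -> 1, the best legal action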
# -*- coding: utf-8 -*-
from derivation import Derivation, InvalidStep
from wff_quick import is_well_formed_formula

# Page 188:
d = Derivation()
with d.fantasy('<p=0⊃~~p=0>') as f:
    f.step('<~~~p=0⊃~p=0>')   # contrapositive
    f.step('<~p=0⊃~p=0>')     # double-tilde
    f.step('<p=0∨~p=0>')      # switcheroo

# Pages 189–190:
d = Derivation()
with d.fantasy('<<p=0⊃q=0>∧<~p=0⊃q=0>>') as f:
    f.step('<p=0⊃q=0>')       # separation
    f.step('<~q=0⊃~p=0>')     # contrapositive
    f.step('<~p=0⊃q=0>')      # separation
    f.step('<~q=0⊃~~p=0>')    # contrapositive
    with f.fantasy('~q=0') as g:   # push again
        g.step('~q=0')             # premise
        g.step('<~q=0⊃~p=0>')      # carry-over of line 4
        g.step('~p=0')             # detachment
        g.step('<~q=0⊃~~p=0>')     # carry-over of line 6
        g.step('~~p=0')            # detachment
        g.step('<~p=0∧~~p=0>')     # joining
        g.step('~<p=0∨~p=0>')      # De Morgan
    f.step('<~q=0⊃~<p=0∨~p=0>>')   # fantasy rule
    f.step('<<p=0∨~p=0>⊃q=0>')     # contrapositive
    with f.fantasy('~p=0') as g:
        pass
def decode_beam(self, domain, ex, domain_convertor, domain_controller,
                general_controller, beam_size=1, max_len=100):
    h_t, annotations = self._encode(ex.x_inds)
    beam = [[Derivation(ex, 1, [], [], hidden_state=h_t, p_list=[],
                        attention_list=[], copy_list=[],
                        copy_entity_list=ex.copy_toks)]]
    finished = []
    final_finished = []
    # Replace the attention-copy slots at the end of the raw action list
    # with a <COPY> placeholder.
    action_all_raw = self.out_vocabulary.get_action_list()
    action_all = action_all_raw[:self.out_vocabulary.size()]
    for _ in action_all_raw[self.out_vocabulary.size():]:
        action_all.append('<COPY>')
    copy_entity_list = ex.copy_toks
    for i in range(1, max_len):
        if len(beam[i-1]) == 0:
            break
        # Stop if the beam_size-th finished derivation is better than
        # everything still on the beam.
        if len(finished) >= beam_size:
            finished_p = finished[beam_size-1].p
            cur_best_p = beam[i-1][0].p
            if cur_best_p < finished_p:
                break
        new_beam = []
        for deriv in beam[i-1]:
            cur_p = deriv.p
            expanded_action_all = action_all
            h_t = deriv.hidden_state
            y_tok_seq = deriv.y_toks
            p_list = deriv.p_list
            attention_list = deriv.attention_list
            copy_list = deriv.copy_list
            if self.spec.attention_copying:
                # Extend the action list with one add_entity_node action per
                # copyable input token.
                added_action_list = []
                for copy_item in copy_entity_list:
                    if copy_item == '<COPY>':
                        added_action_list.append('<COPY>')
                    else:
                        added_action_list.append('add_entity_node:-:' + copy_item)
                added_action_list.append('<COPY>')
                expanded_action_all = expanded_action_all + added_action_list
            # Copy the parser state so legality tests do not mutate it.
            gen_pre_action_for_test = deriv.gen_pre_action_in_deriv
            gen_pre_action_class_for_test = deriv.gen_pre_action_class_in_deriv
            gen_pre_arg_list_for_test = copy.deepcopy(deriv.gen_pre_arg_list_in_deriv)
            node_dict_for_test = copy.deepcopy(deriv.node_dict_in_deriv)
            type_node_dict_for_test = copy.deepcopy(deriv.type_node_dict_in_deriv)
            entity_node_dict_for_test = copy.deepcopy(deriv.entity_node_dict_in_deriv)
            operation_dict_for_test = copy.deepcopy(deriv.operation_dict_in_deriv)
            edge_dict_for_test = copy.deepcopy(deriv.edge_dict_in_deriv)
            return_node_for_test = copy.deepcopy(deriv.return_node_in_deriv)
            db_triple_for_test = copy.deepcopy(deriv.db_triple_in_deriv)
            fun_trace_list_for_test = copy.deepcopy(deriv.fun_trace_list_in_deriv)
            write_dist, c_t, alpha = self._decoder_write(annotations, h_t)
            legal_dist_gen = self.get_legal_action_list(
                general_controller, gen_pre_action_class_for_test,
                gen_pre_arg_list_for_test, gen_pre_action_for_test,
                node_dict_for_test, type_node_dict_for_test,
                entity_node_dict_for_test, operation_dict_for_test,
                edge_dict_for_test, return_node_for_test, db_triple_for_test,
                fun_trace_list_for_test, expanded_action_all)
            expanded_action_all_for_domain = []
            for ii in range(len(legal_dist_gen)):
                if legal_dist_gen[ii]:
                    expanded_action_all_for_domain.append(expanded_action_all[ii])
                else:
                    expanded_action_all_for_domain.append('<COPY>')
            legal_dist_dom = self.get_legal_action_list(
                domain_controller, gen_pre_action_class_for_test,
                gen_pre_arg_list_for_test, gen_pre_action_for_test,
                node_dict_for_test, type_node_dict_for_test,
                entity_node_dict_for_test, operation_dict_for_test,
                edge_dict_for_test, return_node_for_test, db_triple_for_test,
                fun_trace_list_for_test, expanded_action_all_for_domain)
            final_dist = write_dist * legal_dist_gen * legal_dist_dom
            sorted_dist = sorted(
                [(p_y_t, y_t) for y_t, p_y_t in enumerate(final_dist)],
                reverse=True)
            for j in range(beam_size):
                gen_pre_action_for_read = gen_pre_action_for_test
                gen_pre_action_class_for_read = gen_pre_action_class_for_test
                gen_pre_arg_list_for_read = copy.deepcopy(gen_pre_arg_list_for_test)
                node_dict_for_read = copy.deepcopy(node_dict_for_test)
                type_node_dict_for_read = copy.deepcopy(type_node_dict_for_test)
                entity_node_dict_for_read = copy.deepcopy(entity_node_dict_for_test)
                operation_dict_for_read = copy.deepcopy(operation_dict_for_test)
                edge_dict_for_read = copy.deepcopy(edge_dict_for_test)
                return_node_for_read = copy.deepcopy(return_node_for_test)
                db_triple_for_read = copy.deepcopy(db_triple_for_test)
                fun_trace_list_for_read = copy.deepcopy(fun_trace_list_for_test)
                p_y_t, y_t = sorted_dist[j]
                if p_y_t == 0.0:
                    continue
                new_p = cur_p * p_y_t
                append_flag = False
                if self.out_vocabulary.action_is_end(domain, y_t):
                    append_flag = True
                if y_t < self.out_vocabulary.all_size():
                    do_copy = 0
                    y_tok = self.out_vocabulary.get_action(y_t)
                else:
                    do_copy = 1
                    new_index = y_t - self.out_vocabulary.all_size()
                    y_tok = 'add_entity_node:-:' + ex.copy_toks[new_index]
                    y_t = self.out_vocabulary.get_index(y_tok)
                new_h_t = self._decoder_step(y_t, c_t, h_t)
                action_token = y_tok
                gen_flag, gen_pre_action_class_out, gen_pre_arg_list_out, \
                    gen_pre_action_out, fun_trace_list_out = \
                    general_controller.is_legal_action_then_read(
                        gen_pre_action_class_for_read, gen_pre_arg_list_for_read,
                        action_token, gen_pre_action_for_read,
                        node_dict_for_read, type_node_dict_for_read,
                        entity_node_dict_for_read, operation_dict_for_read,
                        edge_dict_for_read, return_node_for_read,
                        db_triple_for_read, fun_trace_list_for_read)
                if not gen_flag:
                    print('test is right, but read is wrong!')
                    continue
                if append_flag:
                    finished.append(Derivation(
                        ex, new_p, y_tok_seq + [y_tok], [],
                        p_list=p_list + [p_y_t],
                        attention_list=attention_list + [alpha],
                        copy_list=copy_list + [do_copy],
                        copy_entity_list=copy_entity_list,
                        gen_pre_action_in_deriv=gen_pre_action_out,
                        gen_pre_action_class_in_deriv=gen_pre_action_class_out,
                        gen_pre_arg_list_in_deriv=gen_pre_arg_list_out,
                        node_dict_in_deriv=node_dict_for_read,
                        type_node_dict_in_deriv=type_node_dict_for_read,
                        entity_node_dict_in_deriv=entity_node_dict_for_read,
                        operation_dict_in_deriv=operation_dict_for_read,
                        edge_dict_in_deriv=edge_dict_for_read,
                        return_node_in_deriv=return_node_for_read,
                        db_triple_in_deriv=db_triple_for_read,
                        fun_trace_list_in_deriv=fun_trace_list_out))
                    continue
                new_entry = Derivation(
                    ex, new_p, y_tok_seq + [y_tok], [],
                    hidden_state=new_h_t,
                    p_list=p_list + [p_y_t],
                    attention_list=attention_list + [alpha],
                    copy_list=copy_list + [do_copy],
                    copy_entity_list=copy_entity_list,
                    gen_pre_action_in_deriv=gen_pre_action_out,
                    gen_pre_action_class_in_deriv=gen_pre_action_class_out,
                    gen_pre_arg_list_in_deriv=gen_pre_arg_list_out,
                    node_dict_in_deriv=node_dict_for_read,
                    type_node_dict_in_deriv=type_node_dict_for_read,
                    entity_node_dict_in_deriv=entity_node_dict_for_read,
                    operation_dict_in_deriv=operation_dict_for_read,
                    edge_dict_in_deriv=edge_dict_for_read,
                    return_node_in_deriv=return_node_for_read,
                    db_triple_in_deriv=db_triple_for_read,
                    fun_trace_list_in_deriv=fun_trace_list_out)
                new_beam.append(new_entry)
        new_beam.sort(key=lambda x: x.p, reverse=True)
        beam.append(new_beam[:beam_size])
    finished.sort(key=lambda x: x.p, reverse=True)
    # Convert each finished action sequence to a logical form.
    for deriv in finished:
        y_toks_lf = domain_convertor(' '.join(deriv.y_toks),
                                     domain_controller, general_controller)
        new_entry = Derivation(deriv.example, deriv.p, deriv.y_toks, y_toks_lf,
                               deriv.hidden_state, deriv.p_list,
                               deriv.attention_list, deriv.copy_list,
                               deriv.copy_entity_list)
        final_finished.append(new_entry)
    return sorted(final_finished, key=lambda x: x.p, reverse=True)