def convert_to_binary_multi_subject(e1):
    """Build multi-hot label matrices for a batch of subject entities.

    Returns ``(e1_label, e1_label_abs)``, two ``[len(e1), num_labels]``
    matrices obtained from ``zeros_var_cuda``:

    * row ``i`` of ``e1_label`` gets a 1 at column(s) ``e1[i]``;
    * row ``i`` of ``e1_label_abs`` gets a 1 at column(s)
      ``self.kg.get_typeid(e1[i])``.

    NOTE(review): ``num_labels`` and ``self`` are free variables here, so this
    is presumably nested inside a method that defines them -- confirm.
    """
    label = zeros_var_cuda([len(e1), num_labels])
    label_abs = zeros_var_cuda([len(e1), num_labels])
    for row, subject in enumerate(e1):
        label[row][subject] = 1
        label_abs[row][self.kg.get_typeid(subject)] = 1
    return label, label_abs
def convert_to_binary_multi_object(e2):
    """Build multi-hot label matrices for a batch of object entities.

    Returns ``(e2_label, e2_label_abs)``, two ``[len(e2), num_labels]``
    matrices obtained from ``zeros_var_cuda``:

    * row ``i`` of ``e2_label`` gets a 1 at column(s) ``e2[i]``;
    * row ``i`` of ``e2_label_abs`` gets a 1 at column(s)
      ``self.kg.get_typeid(e2[i])``.

    NOTE(review): ``num_labels`` and ``self`` are free variables here, so this
    is presumably nested inside a method that defines them -- confirm.
    """
    label = zeros_var_cuda([len(e2), num_labels])
    label_abs = zeros_var_cuda([len(e2), num_labels])
    for row, obj in enumerate(e2):
        label[row][obj] = 1
        label_abs[row][self.kg.get_typeid(obj)] = 1
    return label, label_abs
def initialize_path(self, init_action, kg):
    """Seed ``self.path`` with the path-encoder state for the start action.

    When ``self.relation_only_in_path`` is set only ``init_action[0]`` is
    embedded (via ``kg.get_relation_embeddings``); otherwise the whole action
    goes through ``self.get_action_embedding``. The embedding gets a length-1
    sequence axis (in place), is fed to ``self.path_encoder`` from zero
    initial states, and element ``[1]`` of the encoder output -- the
    ``(h_n, c_n)`` pair per the original comment -- becomes the only entry
    of ``self.path``.
    """
    embedding = (kg.get_relation_embeddings(init_action[0])
                 if self.relation_only_in_path
                 else self.get_action_embedding(init_action, kg))
    # (batch_size, seq_len=1, input_size)
    embedding.unsqueeze_(1)

    # Zero initial states, each [num_layers, batch_size, dim].
    state_shape = [self.history_num_layers, len(embedding), self.history_dim]
    h0, c0 = zeros_var_cuda(state_shape), zeros_var_cuda(state_shape)

    self.path = [self.path_encoder(embedding, (h0, c0))[1]]  # list of (h_n, c_n)
def initialize_path(self, init_action, kg):
    """Seed separate entity and relation histories from the start action.

    ``init_action[0]`` is embedded with ``kg.get_relation_embeddings`` and
    ``init_action[1]`` with ``kg.get_entity_embeddings``; each gets a
    length-1 sequence axis and is run through its own encoder
    (``self.entity_agent`` / ``self.relation_agent``) from zero initial
    states. Element ``[1]`` of each encoder output is stored as the single
    entry of ``self.path`` (entity side) and ``self.path_r`` (relation side).
    """
    def zero_state(batch):
        # Fresh (h_0, c_0), each [num_layers, batch_size, dim].
        shape = [self.history_num_layers, batch, self.history_dim]
        return zeros_var_cuda(shape), zeros_var_cuda(shape)

    rel_seq = kg.get_relation_embeddings(init_action[0]).unsqueeze(1)
    ent_seq = kg.get_entity_embeddings(init_action[1]).unsqueeze(1)

    self.path = [self.entity_agent(ent_seq, zero_state(len(ent_seq)))[1]]
    self.path_r = [self.relation_agent(rel_seq, zero_state(len(rel_seq)))[1]]
def predict(self, mini_batch, verbose=False):
    """Beam-search *mini_batch*; return a dense score matrix and the rule score.

    The score matrix is ``[len(e1), kg.num_entities]`` (from
    ``zeros_var_cuda``) with ``exp(beam log-score)`` written at every
    predicted entity of each row. The second return value is
    ``beam_search_output["rule_score"]`` passed through unchanged.

    With *verbose*, each row's beams are printed (score and formatted path)
    until the first beam whose prediction is ``kg.dummy_e``.
    """
    kg, pn = self.kg, self.mdl
    e1, e2, r = self.format_batch(mini_batch)
    output = search.beam_search(pn, e1, r, e2, kg, self.num_rollout_steps,
                                self.beam_size)
    top_entities = output['pred_e2s']
    top_scores = output['pred_e2_scores']

    if verbose:
        # print inference paths
        traces = output['search_traces']
        shown = min(self.beam_size, top_scores.shape[1])
        for row in range(len(e1)):
            for col in range(shown):
                if top_entities[row][col] == kg.dummy_e:
                    break
                flat = row * shown + col
                hops = [(int(step[0][flat]), int(step[1][flat])) for step in traces]
                print('beam {}: score = {} \n<PATH> {}'.format(
                    col, float(top_scores[row][col]), ops.format_path(hops, kg)))

    with torch.no_grad():
        pred_scores = zeros_var_cuda([len(e1), kg.num_entities])
        for row in range(len(e1)):
            pred_scores[row][top_entities[row].long()] = torch.exp(top_scores[row])
    return pred_scores, output["rule_score"]
def initialize_path(self, init_action, kg):
    """Seed ``self.path`` from the start action, optionally with a context vector.

    * ``self.relation_only_in_path``: the path input is the relation
      embedding of ``init_action[0]`` and there is no context.
    * otherwise ``init_action`` is unpacked as ``(relation, entity)``:
      - ``self.context_info is None``: input is ``[relation; entity]``
        concatenated on the last dim, no context;
      - else: input is the entity embedding and the relation embedding is
        passed as context (per the original author's note, the relation is
        meant to generate the LSTM's parameters while the entity forms the path).

    Initial states are zeros laid out batch-first,
    ``[batch_size, num_layers, dim]``. Element ``[1]`` of the encoder output
    is stored as the single entry of ``self.path``.
    """
    init_context = None
    if self.relation_only_in_path:
        # Path made of relation embeddings only.
        init_action_embedding = kg.get_relation_embeddings(init_action[0])
    else:
        init_r, init_e = init_action
        rel_emb = kg.get_relation_embeddings(init_r)
        ent_emb = kg.get_entity_embeddings(init_e)
        if self.context_info is not None:
            init_action_embedding, init_context = ent_emb, rel_emb
        else:
            init_action_embedding = torch.cat([rel_emb, ent_emb], dim=-1)

    # TODO: test that we can squeeze in LSTM layer, to keep inputs consistent.
    batch = len(init_action_embedding)
    init_h = zeros_var_cuda([batch, self.history_num_layers, self.history_dim])
    init_c = zeros_var_cuda([batch, self.history_num_layers, self.history_dim])
    self.path = [
        self.path_encoder(init_action_embedding, (init_h, init_c), init_context)[1]
    ]
def initialize_path(self, action: Action, kg: KnowledgeGraph):
    """Seed ``self.path`` with the path-encoder state for the start *action*.

    Uses ``kg.get_relation_embeddings(action.rel)`` when
    ``self.relation_only_in_path`` is set, otherwise
    ``self.get_action_embedding(action, kg)``. The embedding gets a length-1
    sequence axis in place, is encoded from zero initial states
    (``[num_layers, batch_size, dim]``), and element ``[1]`` of the encoder
    output becomes the only entry of ``self.path``.
    """
    if self.relation_only_in_path:
        start = kg.get_relation_embeddings(action.rel)
    else:
        start = self.get_action_embedding(action, kg)
    start.unsqueeze_(1)

    batch = len(start)
    h0 = zeros_var_cuda([self.history_num_layers, batch, self.history_dim])
    c0 = zeros_var_cuda([self.history_num_layers, batch, self.history_dim])
    self.path = [self.path_encoder(start, (h0, c0))[1]]
def forward_fact_oracle(e1, r, e2, kg):
    """Return the oracle indicator of e2 for each (e1, r) query.

    Builds a ``[len(e1), kg.num_entities]`` matrix with a 1 at every entity in
    ``kg.all_object_vectors[e1[i]][r[i]]`` for row ``i``, then passes it with
    ``e2.unsqueeze(1)`` to ``ops.batch_lookup`` (presumably gathering column
    ``e2[i]`` of row ``i`` -- confirm against ``ops``).

    :raises ValueError: if some ``(e1[i], r[i])`` has no entry in
        ``kg.all_object_vectors``.
    """
    oracle = zeros_var_cuda([len(e1), kg.num_entities]).cuda()
    for row in range(len(e1)):
        subj, rel = int(e1[row]), int(r[row])
        known = subj in kg.all_object_vectors and rel in kg.all_object_vectors[subj]
        if not known:
            raise ValueError("Query answer not found")
        oracle[row][kg.all_object_vectors[subj][rel]] = 1
    return ops.batch_lookup(oracle, e2.unsqueeze(1))
def initialize_path(self, init_action, kg):
    """Seed ``self.path`` using the aggregated action embedding.

    Embeds only ``init_action[0]`` (``kg.get_relation_embeddings``) when
    ``self.relation_only_in_path`` is set; otherwise uses
    ``self.get_agg_action_embedding(init_action, kg)``. The embedding gets a
    length-1 sequence axis in place, is encoded from zero states
    ``[num_layers, batch_size, dim]``, and element ``[1]`` of the encoder
    output becomes the only entry of ``self.path``.
    """
    if not self.relation_only_in_path:
        seq = self.get_agg_action_embedding(init_action, kg)
    else:
        seq = kg.get_relation_embeddings(init_action[0])
    seq.unsqueeze_(1)

    shape = [self.history_num_layers, len(seq), self.history_dim]
    zero_h = zeros_var_cuda(shape)
    zero_c = zeros_var_cuda(shape)
    encoded = self.path_encoder(seq, (zero_h, zero_c))
    self.path = [encoded[1]]
def convert_to_binary_multi_object(e2):
    """Return a ``[len(e2), num_labels]`` multi-hot matrix for *e2*.

    Row ``i`` has a 1 at column(s) ``e2[i]``; the matrix comes from
    ``zeros_var_cuda``. NOTE(review): ``num_labels`` is a free variable,
    presumably from an enclosing scope -- confirm.
    """
    e2_label = zeros_var_cuda([len(e2), num_labels])
    for row, targets in enumerate(e2):
        e2_label[row][targets] = 1
    return e2_label
def convert_to_binary_multi_subject(e1):
    """Return a ``[len(e1), num_labels]`` multi-hot matrix for *e1*.

    Row ``i`` has a 1 at column(s) ``e1[i]``; the matrix comes from
    ``zeros_var_cuda``. NOTE(review): ``num_labels`` is a free variable,
    presumably from an enclosing scope -- confirm.
    """
    e1_label = zeros_var_cuda([len(e1), num_labels])
    for row, sources in enumerate(e1):
        e1_label[row][sources] = 1
    return e1_label
def beam_search(pn, e_s, q, e_t, kg, num_steps, beam_size, return_path_components=False):
    """
    Beam search from source.

    :param pn: Policy network.
    :param e_s: (Variable:batch) source entity indices.
    :param q: (Variable:batch) query relation indices.
    :param e_t: (Variable:batch) target entity indices.
    :param kg: Knowledge graph environment.
    :param num_steps: Number of search steps.
    :param beam_size: Beam size used in search.
    :param return_path_components: If set, return all path components at the end of search.
    :return: dict with 'pred_e2s' and 'pred_e2_scores' (each reshaped to
        [batch_size, -1]); 'search_traces' when ``kg.args.save_beam_search_paths``
        is set; 'path_components_list' when ``return_path_components`` is set.
    """
    assert (num_steps >= 1)
    batch_size = len(e_s)

    def top_k_action(log_action_dist, action_space):
        """
        Get top k actions.
            - k = beam_size if the beam size is smaller than or equal to the beam action space size
            - k = beam_action_space_size otherwise
        :param log_action_dist: [batch_size*k, action_space_size]
        :param action_space (r_space, e_space):
            r_space: [batch_size*k, action_space_size]
            e_space: [batch_size*k, action_space_size]
        :return:
            (next_r, next_e), action_prob, action_offset: [batch_size*new_k]
        """
        full_size = len(log_action_dist)
        assert (full_size % batch_size == 0)
        last_k = full_size // batch_size

        (r_space, e_space), _ = action_space
        action_space_size = r_space.size()[1]
        # => [batch_size, k'*action_space_size]
        log_action_dist = log_action_dist.view(batch_size, -1)
        beam_action_space_size = log_action_dist.size()[1]
        k = min(beam_size, beam_action_space_size)
        # [batch_size, k]
        log_action_prob, action_ind = torch.topk(log_action_dist, k)
        next_r = ops.batch_lookup(r_space.view(batch_size, -1), action_ind).view(-1)
        next_e = ops.batch_lookup(e_space.view(batch_size, -1), action_ind).view(-1)
        # [batch_size, k] => [batch_size*k]
        log_action_prob = log_action_prob.view(-1)
        # *** compute parent offset
        # [batch_size, k]: index of the parent beam (among last_k) of each pick
        action_beam_offset = action_ind // action_space_size
        # [batch_size, 1]
        action_batch_offset = int_var_cuda(torch.arange(batch_size) * last_k).unsqueeze(1)
        # [batch_size, k] => [batch_size*k]
        action_offset = (action_batch_offset + action_beam_offset).view(-1)
        return (next_r, next_e), log_action_prob, action_offset

    def top_k_answer_unique(log_action_dist, action_space):
        """
        Get top k unique entities
            - k = beam_size if the beam size is smaller than or equal to the beam action space size
            - k = beam_action_space_size otherwise
        :param log_action_dist: [batch_size*beam_size, action_space_size]
        :param action_space (r_space, e_space):
            r_space: [batch_size*beam_size, action_space_size]
            e_space: [batch_size*beam_size, action_space_size]
        :return:
            (next_r, next_e), action_prob, action_offset: [batch_size*k]
        """
        full_size = len(log_action_dist)
        assert (full_size % batch_size == 0)
        last_k = full_size // batch_size
        (r_space, e_space), _ = action_space
        action_space_size = r_space.size()[1]

        r_space = r_space.view(batch_size, -1)
        e_space = e_space.view(batch_size, -1)
        log_action_dist = log_action_dist.view(batch_size, -1)
        beam_action_space_size = log_action_dist.size()[1]
        assert (beam_action_space_size % action_space_size == 0)
        k = min(beam_size, beam_action_space_size)
        next_r_list, next_e_list = [], []
        log_action_prob_list = []
        action_offset_list = []
        for i in range(batch_size):
            log_action_dist_b = log_action_dist[i]
            r_space_b = r_space[i]
            e_space_b = e_space[i]
            unique_e_space_b = var_cuda(torch.unique(e_space_b.data.cpu()))
            unique_log_action_dist, unique_idx = unique_max(
                unique_e_space_b, e_space_b, log_action_dist_b)
            k_prime = min(len(unique_e_space_b), k)
            top_unique_log_action_dist, top_unique_idx2 = torch.topk(
                unique_log_action_dist, k_prime)
            top_unique_idx = unique_idx[top_unique_idx2]
            top_unique_beam_offset = top_unique_idx // action_space_size
            top_r = r_space_b[top_unique_idx]
            top_e = e_space_b[top_unique_idx]
            next_r_list.append(top_r.unsqueeze(0))
            next_e_list.append(top_e.unsqueeze(0))
            log_action_prob_list.append(top_unique_log_action_dist.unsqueeze(0))
            top_unique_batch_offset = i * last_k
            top_unique_action_offset = top_unique_batch_offset + top_unique_beam_offset
            action_offset_list.append(top_unique_action_offset.unsqueeze(0))
        next_r = ops.pad_and_cat(next_r_list, padding_value=kg.dummy_r).view(-1)
        next_e = ops.pad_and_cat(next_e_list, padding_value=kg.dummy_e).view(-1)
        log_action_prob = ops.pad_and_cat(log_action_prob_list, padding_value=-ops.HUGE_INT)
        action_offset = ops.pad_and_cat(action_offset_list, padding_value=-1)
        return (next_r, next_e), log_action_prob.view(-1), action_offset.view(-1)

    def adjust_search_trace(search_trace, action_offset):
        """Re-order every recorded (r, e) step so it follows the surviving beams."""
        for i, (r, e) in enumerate(search_trace):
            new_r = r[action_offset]
            new_e = e[action_offset]
            search_trace[i] = (new_r, new_e)

    def collect_path_components(search_trace, log_action_probs, output_beam_size,
                                path_width=10):
        """Build per-example (start entity name, edge labels, log-probs) summaries.

        Also prints each summary, as before.
        """
        path_components_list = []
        num_shown = min(output_beam_size, path_width)
        for i in range(batch_size):
            beam_start = i * output_beam_size
            p_c = []
            for k, step_log_prob in enumerate(log_action_probs):
                top_k_edge_labels = []
                for j in range(num_shown):
                    ind = beam_start + j
                    r = kg.id2relation[int(search_trace[k + 1][0][ind])]
                    e = kg.id2entity[int(search_trace[k + 1][1][ind])]
                    if r.endswith('_inv'):
                        edge_label = ' <-{}- {} {}'.format(
                            r[:-4], e, float(log_action_probs[k][ind]))
                    else:
                        edge_label = ' -{}-> {} {}'.format(
                            r, e, float(log_action_probs[k][ind]))
                    top_k_edge_labels.append(edge_label)
                # This example's own beams; slicing the head of the flat
                # [batch_size*beam] vector would always show example 0.
                top_k_action_prob = step_log_prob[beam_start:beam_start + num_shown]
                # search_trace[0] is the re-ordered start action (r_s, e_s); its
                # entity component is the source entity (search_trace[1][0] is
                # a relation id and must not be looked up in id2entity).
                e_name = kg.id2entity[int(search_trace[0][1][beam_start])] if k == 0 else ''
                p_c.append((e_name, top_k_edge_labels, var_to_numpy(top_k_action_prob)))
                print(e_name, top_k_edge_labels)
            path_components_list.append(p_c)
        return path_components_list

    # Initialization
    r_s = int_fill_var_cuda(e_s.size(), kg.dummy_start_r)
    seen_nodes = int_fill_var_cuda(e_s.size(), kg.dummy_e).unsqueeze(1)
    init_action = (r_s, e_s)
    # path encoder (also returns the encoded source entity used in obs)
    emb_e_s = pn.initialize_path(init_action, q, kg, 'eval')
    # The path-component report reads the trace, so record it whenever either
    # feature needs it (otherwise return_path_components alone hit a NameError).
    record_trace = kg.args.save_beam_search_paths or return_path_components
    if record_trace:
        search_trace = [(r_s, e_s)]

    # Run beam search for num_steps
    # [batch_size*k], k=1
    log_action_prob = zeros_var_cuda(batch_size)
    if return_path_components:
        log_action_probs = []

    action = init_action
    for t in range(num_steps):
        last_r, e = action
        assert (q.size() == e_s.size())
        assert (q.size() == e_t.size())
        assert (e.size()[0] % batch_size == 0)
        assert (q.size()[0] % batch_size == 0)
        k = int(e.size()[0] / batch_size)
        # => [batch_size*k]
        q = ops.tile_along_beam(q.view(batch_size, -1)[:, 0], k)
        e_s = ops.tile_along_beam(e_s.view(batch_size, -1)[:, 0], k)
        e_t = ops.tile_along_beam(e_t.view(batch_size, -1)[:, 0], k)
        obs = [
            e_s, emb_e_s, q, e_t, t == 0, t == (num_steps - 1), last_r,
            seen_nodes
        ]
        # one step forward in search
        db_outcomes, _, _ = pn.transit(e, obs, kg, 'eval',
                                       use_action_space_bucketing=True,
                                       merge_aspace_batching_outcome=True)
        action_space, action_dist = db_outcomes[0]
        # => [batch_size*k, action_space_size]
        log_action_dist = log_action_prob.view(-1, 1) + ops.safe_log(action_dist)
        # [batch_size*k, action_space_size] => [batch_size*new_k]
        if t == num_steps - 1:
            action, log_action_prob, action_offset = top_k_answer_unique(
                log_action_dist, action_space)
        else:
            action, log_action_prob, action_offset = top_k_action(
                log_action_dist, action_space)
        if return_path_components:
            ops.rearrange_vector_list(log_action_probs, action_offset)
            log_action_probs.append(log_action_prob)
        pn.update_path(action, kg, offset=action_offset)
        seen_nodes = torch.cat(
            [seen_nodes[action_offset], action[1].unsqueeze(1)], dim=1)
        if record_trace:
            adjust_search_trace(search_trace, action_offset)
            search_trace.append(action)

    output_beam_size = int(action[0].size()[0] / batch_size)
    # [batch_size*beam_size] => [batch_size, beam_size]
    beam_search_output = dict()
    beam_search_output['pred_e2s'] = action[1].view(batch_size, -1)
    beam_search_output['pred_e2_scores'] = log_action_prob.view(batch_size, -1)
    if kg.args.save_beam_search_paths:
        beam_search_output['search_traces'] = search_trace

    if return_path_components:
        beam_search_output['path_components_list'] = collect_path_components(
            search_trace, log_action_probs, output_beam_size)

    return beam_search_output
def initialize_path(self, init_action, q, kg, mode):
    """Encode the source entity and seed ``self.path`` with the first encoder state.

    The start embedding is ``[emb_r_0; emb_e_s]``:

    * ``emb_r_0`` is ``self.graph_transformer.emb_r(init_action[0])`` passed
      through ``self.graph_transformer.dropout``;
    * ``emb_e_s`` comes from ``self.graph_transformer`` applied to the source
      entities ``init_action[1]`` and queries ``q``.

    :param init_action: indexable start action; [0] = start relations,
        [1] = source entities.
    :param q: query relations.
    :param kg: not used by this method; kept for interface compatibility.
    :param mode: ``'train'`` encodes one source per group of
        ``self.num_rollouts`` consecutive rows on ``self.dg.training_graph``
        and repeats the result for every row of the group. Any other value
        encodes every row, on ``self.dg.aux_graph`` ('test') when
        ``self.args.inference`` is set, else on ``self.dg.eval_graph`` ('eval').
    :return: the source-entity embedding ``emb_e_s``.
    """
    def encode_source():
        """Run the graph transformer over the source entities for *mode*."""
        if mode == 'train':
            # Rows come as num_rollouts consecutive copies per query (implied
            # by the view below); encode each query once and broadcast back.
            q_first = q.view(-1, self.num_rollouts)[:, 0]
            e_s = init_action[1]
            e_first = e_s.view(-1, self.num_rollouts)[:, 0]
            emb, _ = self.graph_transformer(e_first, q_first, self.dg.training_graph,
                                            self.dg.seen_id2entity, self.bandwidth,
                                            'train')
            emb = emb.unsqueeze(1).expand(-1, self.num_rollouts, -1)
            return torch.flatten(emb, start_dim=0).view(-1, self.entity_dim)

        e_s = init_action[1]
        if self.args.inference:
            emb, _ = self.graph_transformer(e_s, q, self.dg.aux_graph,
                                            self.dg.seen_id2entity, self.bandwidth,
                                            'test')
        else:
            emb, _ = self.graph_transformer(e_s, q, self.dg.eval_graph,
                                            self.dg.seen_id2entity, self.bandwidth,
                                            'eval')
        return emb

    # Both settings of self.relation_only_in_path previously ran byte-identical
    # code, so the flag is not consulted here -- TODO confirm that is intended.
    emb_e_s = encode_source()
    emb_r_0 = self.graph_transformer.dropout(
        self.graph_transformer.emb_r(init_action[0]))
    init_action_embedding = torch.cat([emb_r_0, emb_e_s], dim=-1)
    # (batch_size, seq_len=1, input_size)
    init_action_embedding.unsqueeze_(1)

    # [num_layers, batch_size, dim]
    init_h = zeros_var_cuda([
        self.history_num_layers, len(init_action_embedding), self.history_dim
    ])
    init_c = zeros_var_cuda([
        self.history_num_layers, len(init_action_embedding), self.history_dim
    ])
    self.path = [
        self.path_encoder(init_action_embedding, (init_h, init_c))[1]
    ]
    return emb_e_s
def predict(self, mini_batch, verbose=False):
    """Beam-search *mini_batch* and return a dense entity-score matrix.

    Beams whose prediction is ``kg.dummy_end_e`` are relabelled in place with
    the latest non-end entity on their trace. For one hard-coded debug query
    (e1=104, r=5, e2=72), traces, accumulated path scores and the final score
    row are appended to ``traces.txt``, ``accum_scores.txt`` and
    ``scores_pred.txt`` in the working directory; those files are opened
    (and thus created) on every call.

    :param mini_batch: batch in the format ``self.format_batch`` expects.
    :param verbose: if set, print each beam's score and path.
    :return: ``[len(e1), kg.num_entities]`` tensor from ``zeros_var_cuda`` with
        ``exp(beam log-score)`` written at each predicted entity.
    """
    kg, pn = self.kg, self.mdl
    e1, e2, r = self.format_batch(mini_batch)
    # Stride between examples in the flattened trace tensors; assumes beam
    # search keeps num_entities - 1 beams per example -- TODO confirm.
    width = kg.num_entities - 1
    # Hard-coded query whose internals are dumped for debugging.
    debug_e1, debug_r, debug_e2 = 104, 5, 72

    def is_debug_query(i):
        return e1[i] == debug_e1 and r[i] == debug_r and e2[i] == debug_e2

    beam_search_output = search.beam_search(pn, e1, r, e2, kg,
                                            self.num_rollout_steps,
                                            self.beam_size)

    with torch.no_grad():
        pred_e2s = beam_search_output['pred_e2s']
        pred_e2_scores = beam_search_output['pred_e2_scores']
        pred_traces = beam_search_output['pred_traces']
        # Walk back from the last step to replace a terminal dummy_end_e
        # prediction with the latest real entity on that beam's trace.
        for u in range(pred_e2s.size()[0]):
            for v in range(pred_e2s.size()[1]):
                if pred_e2s[u][v].item() != kg.dummy_end_e:
                    continue
                flat = u * width + v
                for i in range(self.num_rollout_steps, 0, -1):
                    if pred_traces[i][0][1][flat] != kg.dummy_end_e:
                        pred_e2s[u][v] = pred_traces[i][0][1][flat]
                        break
        accum_scores = ops.sort_by_path(pred_traces, kg)

        # for testing; `with` closes the dump files even if printing fails
        with open("traces.txt", "a") as ggg, open("accum_scores.txt", "a") as hhh:
            for i in range(len(e1)):
                if is_debug_query(i):
                    lo, hi = i * width, (i + 1) * width
                    print("Epoch: ", pred_traces[1][0][0][lo:hi], "\n",
                          pred_traces[1][0][1][lo:hi], "\n",
                          pred_traces[2][0][0][lo:hi], "\n",
                          pred_traces[2][0][1][lo:hi], file=ggg)
                    print(sorted(accum_scores[i].items(),
                                 key=lambda x: x[1],
                                 reverse=True), file=hhh)

    if verbose:
        # print inference paths
        search_traces = beam_search_output['search_traces']
        output_beam_size = min(self.beam_size, pred_e2_scores.shape[1])
        for i in range(len(e1)):
            for j in range(output_beam_size):
                ind = i * output_beam_size + j
                if pred_e2s[i][j] == kg.dummy_e:
                    break
                search_trace = []
                for k in range(len(search_traces)):
                    search_trace.append((int(search_traces[k][0][ind]),
                                         int(search_traces[k][1][ind])))
                print('beam {}: score = {} \n<PATH> {}'.format(
                    j, float(pred_e2_scores[i][j]),
                    ops.format_path(search_trace, kg)))

    with torch.no_grad(), open("scores_pred.txt", "a") as uu:
        pred_scores = zeros_var_cuda([len(e1), kg.num_entities])
        for i in range(len(e1)):
            pred_scores[i][pred_e2s[i]] = torch.exp(pred_e2_scores[i])
            if is_debug_query(i):
                print(pred_scores[i], file=uu)
            # NOTE (translated from the author's notes): how the ranking is
            # computed still needs review; in theory a second, two-step scheme
            # (scaling exp(score) by the accumulated path score via
            # ops.check_path) should be used. An earlier variant averaged the
            # scores of two searches, but in practice only "rma" was effective.
    return pred_scores
def beam_search_same(pn,
                     e_s,
                     q,
                     e_t,
                     kg,
                     num_steps,
                     beam_size,
                     return_path_components=False,
                     return_merge_scores=None,
                     same_start=False):
    """
    Beam search from source, tracking entities and their type ids in lockstep.

    :param pn: Policy network.
    :param e_s: (Variable:batch) source entity indices.
    :param q: (Variable:batch) query relation indices.
    :param e_t: (Variable:batch) target entity indices.
    :param kg: Knowledge graph environment.
    :param num_steps: Number of search steps.
    :param beam_size: Beam size used in search.
    :param return_path_components: If set, return all path components at the end of search.
    :param return_merge_scores: None, or a merge mode for the last step:
        'sum' -> torch.sum, 'mean' -> torch.mean, any other value -> None,
        passed to ``ops.merge_same``. None uses unique-entity top-k instead.
    :param same_start: forwarded to ``pn.initialize_abs_path``.
    :return: dict with 'pred_e2s' and 'pred_e2_scores' (each reshaped to
        [batch_size, -1]); 'search_traces' whenever traces were recorded
        (``kg.args.save_paths_to_csv`` or ``return_path_components``);
        'path_components_list' when ``return_path_components`` is set.
        The dict is also appended to the module-level ``SAME_ALL_PATHS``.
    """
    assert (num_steps >= 1)
    batch_size = len(e_s)

    def type_ids(entities):
        """Map each entity to ``kg.get_typeid(entity)`` as a non-trainable tensor."""
        return var_cuda(torch.LongTensor([kg.get_typeid(_) for _ in entities]),
                        requires_grad=False)

    def top_k_action(log_action_dist, action_space, return_merge_scores=None):
        """
        Get top k actions.
            - k = beam_size if the beam size is smaller than or equal to the beam action space size
            - k = beam_action_space_size otherwise
        :param log_action_dist: [batch_size*k, action_space_size]
        :param action_space (r_space, e_space):
            r_space: [batch_size*k, action_space_size]
            e_space: [batch_size*k, action_space_size]
        :param return_merge_scores: if not None, merge candidates via ops.merge_same.
        :return:
            (next_r, next_e), action_prob, action_offset: [batch_size*new_k]
        """
        full_size = len(log_action_dist)
        assert (full_size % batch_size == 0)
        last_k = full_size // batch_size

        (r_space, e_space), _ = action_space
        action_space_size = r_space.size()[1]
        # => [batch_size, k'*action_space_size]
        log_action_dist = log_action_dist.view(batch_size, -1)
        beam_action_space_size = log_action_dist.size()[1]
        k = min(beam_size, beam_action_space_size)

        if return_merge_scores is not None:
            if return_merge_scores == 'sum':
                reduce_method = torch.sum
            elif return_merge_scores == 'mean':
                reduce_method = torch.mean
            else:
                reduce_method = None

            # [batch_size, k'*action_space_size]: every flat candidate index,
            # created on the action-space device (was hard-coded to cuda:0).
            all_action_ind = torch.arange(
                beam_action_space_size,
                device=r_space.device).unsqueeze(0).repeat(batch_size, 1)
            all_next_r = ops.batch_lookup(r_space.view(batch_size, -1), all_action_ind)
            all_next_e = ops.batch_lookup(e_space.view(batch_size, -1), all_action_ind)

            real_log_action_prob, real_next_e, real_action_ind, real_next_r = ops.merge_same(
                log_action_dist, all_next_e, all_next_r, method=reduce_method)

            next_e_list, next_r_list, action_ind_list, log_action_prob_list = [], [], [], []
            for i in range(batch_size):
                k_prime = min(len(real_log_action_prob[i]), k)
                top_log_prob, top_ind = torch.topk(real_log_action_prob[i], k_prime)
                top_next_e = real_next_e[i][top_ind]
                top_next_r = real_next_r[i][top_ind]
                top_ind = real_action_ind[i][top_ind]
                next_e_list.append(top_next_e.unsqueeze(0))
                next_r_list.append(top_next_r.unsqueeze(0))
                action_ind_list.append(top_ind.unsqueeze(0))
                log_action_prob_list.append(top_log_prob.unsqueeze(0))
            next_r = ops.pad_and_cat(next_r_list, padding_value=kg.dummy_r).view(-1)
            next_e = ops.pad_and_cat(next_e_list, padding_value=kg.dummy_e).view(-1)
            # NOTE(review): padded slots score 0.0; whether that is neutral
            # depends on the scale ops.merge_same returns (in log space it
            # would be the best possible score) -- confirm.
            log_action_prob = ops.pad_and_cat(log_action_prob_list,
                                              padding_value=0.0).view(-1)
            action_ind = ops.pad_and_cat(action_ind_list, padding_value=-1).view(-1)
        else:
            log_action_prob, action_ind = torch.topk(log_action_dist, k)
            next_r = ops.batch_lookup(r_space.view(batch_size, -1), action_ind).view(-1)
            next_e = ops.batch_lookup(e_space.view(batch_size, -1), action_ind).view(-1)
            # [batch_size, k] => [batch_size*k]
            log_action_prob = log_action_prob.view(-1)
        # *** compute parent offset
        # [batch_size, k]. Reshape first: the merge branch flattened action_ind,
        # which would broadcast against action_batch_offset into a
        # [batch_size, batch_size*k] grid. Floor division keeps the offset an
        # integer index (`/` yields floats on current PyTorch).
        action_beam_offset = action_ind.view(batch_size, -1) // action_space_size
        # [batch_size, 1]
        action_batch_offset = int_var_cuda(torch.arange(batch_size) * last_k).unsqueeze(1)
        # [batch_size, k] => [batch_size*k]
        action_offset = (action_batch_offset + action_beam_offset).view(-1)
        return (next_r, next_e), log_action_prob, action_offset

    def top_k_answer_unique(log_action_dist, action_space):
        """
        Get top k unique entities
            - k = beam_size if the beam size is smaller than or equal to the beam action space size
            - k = beam_action_space_size otherwise
        :param log_action_dist: [batch_size*beam_size, action_space_size] (log-probabilities)
        :param action_space (r_space, e_space):
            r_space: [batch_size*beam_size, action_space_size]
            e_space: [batch_size*beam_size, action_space_size] (entities)
        :return:
            (next_r, next_e), action_prob, action_offset: [batch_size*k]
        """
        full_size = len(log_action_dist)
        assert (full_size % batch_size == 0)
        last_k = full_size // batch_size
        (r_space, e_space), _ = action_space
        action_space_size = r_space.size()[1]

        r_space = r_space.view(batch_size, -1)
        e_space = e_space.view(batch_size, -1)
        log_action_dist = log_action_dist.view(batch_size, -1)
        beam_action_space_size = log_action_dist.size()[1]
        assert (beam_action_space_size % action_space_size == 0)
        k = min(beam_size, beam_action_space_size)
        next_r_list, next_e_list = [], []
        log_action_prob_list = []
        action_offset_list = []
        for i in range(batch_size):
            log_action_dist_b = log_action_dist[i]
            r_space_b = r_space[i]
            e_space_b = e_space[i]
            unique_e_space_b = var_cuda(torch.unique(e_space_b.data.cpu()))
            unique_log_action_dist, unique_idx = unique_max(
                unique_e_space_b, e_space_b, log_action_dist_b)
            k_prime = min(len(unique_e_space_b), k)
            top_unique_log_action_dist, top_unique_idx2 = torch.topk(
                unique_log_action_dist, k_prime)
            top_unique_idx = unique_idx[top_unique_idx2]
            # Floor division keeps the parent-beam offset an integer index.
            top_unique_beam_offset = top_unique_idx // action_space_size
            top_r = r_space_b[top_unique_idx]
            top_e = e_space_b[top_unique_idx]
            next_r_list.append(top_r.unsqueeze(0))
            next_e_list.append(top_e.unsqueeze(0))
            log_action_prob_list.append(top_unique_log_action_dist.unsqueeze(0))
            top_unique_batch_offset = i * last_k
            top_unique_action_offset = top_unique_batch_offset + top_unique_beam_offset
            action_offset_list.append(top_unique_action_offset.unsqueeze(0))
        next_r = ops.pad_and_cat(next_r_list, padding_value=kg.dummy_r).view(-1)
        next_e = ops.pad_and_cat(next_e_list, padding_value=kg.dummy_e).view(-1)
        log_action_prob = ops.pad_and_cat(log_action_prob_list, padding_value=-ops.HUGE_INT)
        action_offset = ops.pad_and_cat(action_offset_list, padding_value=-1)
        return (next_r, next_e), log_action_prob.view(-1), action_offset.view(-1)

    def adjust_search_trace(search_trace, action_offset):
        """Re-order every recorded (r, e) step so it follows the surviving beams."""
        for i, (r, e) in enumerate(search_trace):
            new_r = r[action_offset]
            new_e = e[action_offset]
            search_trace[i] = (new_r, new_e)

    def collect_path_components(search_trace, log_action_probs, output_beam_size,
                                path_width=10):
        """Build per-example (start entity name, edge labels, log-probs) summaries."""
        path_components_list = []
        num_shown = min(output_beam_size, path_width)
        for i in range(batch_size):
            beam_start = i * output_beam_size
            p_c = []
            for k, step_log_prob in enumerate(log_action_probs):
                top_k_edge_labels = []
                for j in range(num_shown):
                    ind = beam_start + j
                    r = kg.id2relation[int(search_trace[k + 1][0][ind])]
                    e = kg.id2entity[int(search_trace[k + 1][1][ind])]
                    if r.endswith('_inv'):
                        edge_label = '<-{}-{} {}'.format(
                            r[:-4], e, float(log_action_probs[k][ind]))
                    else:
                        edge_label = '-{}->{} {}'.format(
                            r, e, float(log_action_probs[k][ind]))
                    top_k_edge_labels.append(edge_label)
                # This example's own beams, not the head of the flat vector.
                top_k_action_prob = step_log_prob[beam_start:beam_start + num_shown]
                # search_trace[0] is the re-ordered start action (r_s, e_s);
                # its entity component is the source entity.
                e_name = kg.id2entity[int(search_trace[0][1][beam_start])] if k == 0 else ''
                p_c.append((e_name, top_k_edge_labels, var_to_numpy(top_k_action_prob)))
            path_components_list.append(p_c)
        return path_components_list

    # Initialization
    # WHY: the search starts from a dummy "start" relation.
    r_s = int_fill_var_cuda(e_s.size(), kg.dummy_start_r)
    seen_nodes = int_fill_var_cuda(e_s.size(), kg.dummy_e).unsqueeze(1)  # TODO
    e_s_abs = type_ids(e_s)
    e_t_abs = type_ids(e_t)
    seen_nodes_abs = int_fill_var_cuda(e_s_abs.size(), kg.dummy_e).unsqueeze(1)
    init_action = (r_s, e_s)
    init_action_abs = (r_s, e_s_abs)
    # path encoder
    pn.initialize_abs_path(init_action, init_action_abs, kg, same_start)
    # Traces are needed both for CSV export and for the path-component report;
    # they used to be read unconditionally but only created for CSV export.
    record_trace = kg.args.save_paths_to_csv or return_path_components
    if record_trace:
        search_trace = [(r_s, e_s)]

    # Run beam search for num_steps
    # [batch_size*k], k=1
    log_action_prob = zeros_var_cuda(batch_size)
    if return_path_components:
        log_action_probs = []

    action = init_action
    action_abs = init_action_abs
    for t in range(num_steps):
        last_r, e = action
        e_abs = action_abs[1]
        assert (q.size() == e_s.size())
        assert (q.size() == e_t.size())
        assert (e.size()[0] % batch_size == 0)
        assert (q.size()[0] % batch_size == 0)
        k = int(e.size()[0] / batch_size)
        # => [batch_size*k]
        q = ops.tile_along_beam(q.view(batch_size, -1)[:, 0], k)
        e_s = ops.tile_along_beam(e_s.view(batch_size, -1)[:, 0], k)
        # Tile the *source* type exactly like e_s (this previously tiled e_abs,
        # i.e. the first beam's current-node type, from step 1 on).
        e_s_abs = ops.tile_along_beam(e_s_abs.view(batch_size, -1)[:, 0], k)
        e_t = ops.tile_along_beam(e_t.view(batch_size, -1)[:, 0], k)
        e_t_abs = ops.tile_along_beam(e_t_abs.view(batch_size, -1)[:, 0], k)
        obs = [e_s, q, e_t, t == (num_steps - 1), last_r, seen_nodes]
        obs_abs = [
            e_s_abs, q, e_t_abs, t == (num_steps - 1), last_r, seen_nodes_abs
        ]
        # one step forward in search
        # TODO: trace get_action_space_in_buckets inside transit_same in detail.
        db_outcomes, _, _ = pn.transit_same(
            e, obs, e_abs, obs_abs, kg,
            use_action_space_bucketing=False,
            merge_aspace_batching_outcome=True)
        action_space, action_dist = db_outcomes[0]
        # => [batch_size*k, action_space_size]
        log_action_dist = log_action_prob.view(-1, 1) + ops.safe_log(action_dist)
        # [batch_size*k, action_space_size] => [batch_size*new_k]
        if t == num_steps - 1:
            if return_merge_scores is None:
                action, log_action_prob, action_offset = top_k_answer_unique(
                    log_action_dist, action_space)
            else:
                action, log_action_prob, action_offset = top_k_action(
                    log_action_dist, action_space, return_merge_scores)
        else:
            action, log_action_prob, action_offset = top_k_action(
                log_action_dist, action_space, None)

        # (next_r, type ids of next_e)
        action_abs = (action[0], type_ids(action[1]))

        if return_path_components:
            ops.rearrange_vector_list(log_action_probs, action_offset)
            log_action_probs.append(log_action_prob)
        pn.update_path_abs(action_abs, kg, offset=action_offset)
        seen_nodes = torch.cat(
            [seen_nodes[action_offset], action[1].unsqueeze(1)], dim=1)
        seen_nodes_abs = torch.cat(
            [seen_nodes_abs[action_offset], action_abs[1].unsqueeze(1)], dim=1)
        if record_trace:
            adjust_search_trace(search_trace, action_offset)
            search_trace.append(action)

    output_beam_size = int(action[0].size()[0] / batch_size)
    # [batch_size*beam_size] => [batch_size, beam_size]
    beam_search_output = dict()
    beam_search_output['pred_e2s'] = action[1].view(batch_size, -1)
    beam_search_output['pred_e2_scores'] = log_action_prob.view(batch_size, -1)

    if return_path_components:
        beam_search_output['path_components_list'] = collect_path_components(
            search_trace, log_action_probs, output_beam_size)

    if record_trace:
        beam_search_output['search_traces'] = search_trace
    SAME_ALL_PATHS.append(beam_search_output)
    return beam_search_output
def beam_search(pn, e_s, q, e_t, kg, num_steps, beam_size, return_path_components=False):
    """
    Beam search from source.

    :param pn: Policy network.
    :param e_s: (Variable:batch) source entity indices.
    :param q: (Variable:batch) query relation indices.
    :param e_t: (Variable:batch) target entity indices.
    :param kg: Knowledge graph environment.
    :param num_steps: Number of search steps.
    :param beam_size: Beam size used in search.
    :param return_path_components: If set, return all path components at the end of search.
    :return: dict with
        'pred_e2s': [batch_size, output_beam_size] entity index reached by each beam.
        'pred_e2_scores': [batch_size, output_beam_size] accumulated log-probability of each beam.
        'search_traces': list of (relations, entities) per step, re-ordered to the final beam
            order; only present if ``kg.args.save_beam_search_paths`` is set.
        'rule_score': [batch_size] mean over the beam of the score each beam's relation path
            has in the rule table loaded from ``kg.args.rule`` (paths not in the table score 0).
        'path_components_list': human-readable path pieces per example; only present if
            ``return_path_components`` is set and the search walks the graph (i.e. not the
            pretrain out-of-graph mode).
    """
    assert (num_steps >= 1)
    batch_size = len(e_s)

    def tile_per_example(x, k):
        """Take the first entry of each example of ``x`` (viewed as [batch_size, -1]) and
        repeat it ``k`` times => [batch_size*k]."""
        return ops.tile_along_beam(x.view(batch_size, -1)[:, 0], k)

    def top_k_action_r(log_action_dist):
        """
        Get top k relations.
            - k = beam_size if the beam size is smaller than or equal to the beam action space size
            - k = beam_action_space_size otherwise
        :param log_action_dist: [batch_size*k, action_space_size]
        :return:
            next_r, log_action_prob, action_offset: [batch_size*new_k]
        """
        full_size = len(log_action_dist)
        assert (full_size % batch_size == 0)
        last_k = full_size // batch_size

        action_space_size = log_action_dist.size()[1]
        # => [batch_size, k'*action_space_size]
        log_action_dist = log_action_dist.view(batch_size, -1)
        beam_action_space_size = log_action_dist.size()[1]
        k = min(beam_size, beam_action_space_size)
        # [batch_size, k]
        log_action_prob, action_ind = torch.topk(log_action_dist, k)
        # The column index inside a parent beam is used directly as the relation id
        # (assumes the relation distribution's columns are indexed by relation id).
        next_r = (action_ind % action_space_size).view(-1)
        # [batch_size, k] => [batch_size*k]
        log_action_prob = log_action_prob.view(-1)
        # compute parent offset
        # [batch_size, k]; floor division keeps an integer index tensor (`/` on integer
        # tensors is true division in recent PyTorch and would produce floats).
        action_beam_offset = action_ind // action_space_size
        # [batch_size, 1]
        action_batch_offset = int_var_cuda(torch.arange(batch_size) * last_k).unsqueeze(1)
        # [batch_size, k] => [batch_size*k]
        action_offset = (action_batch_offset + action_beam_offset).view(-1)
        return next_r, log_action_prob, action_offset

    def top_k_action(log_action_dist, action_space):
        """
        Get top k actions.
            - k = beam_size if the beam size is smaller than or equal to the beam action space size
            - k = beam_action_space_size otherwise
        :param log_action_dist: [batch_size*k, action_space_size]
        :param action_space (r_space, e_space):
            r_space: [batch_size*k, action_space_size]
            e_space: [batch_size*k, action_space_size]
        :return:
            (next_r, next_e), log_action_prob, action_offset: [batch_size*new_k]
        """
        full_size = len(log_action_dist)
        assert (full_size % batch_size == 0)
        last_k = full_size // batch_size

        (r_space, e_space), _ = action_space
        action_space_size = r_space.size()[1]
        # => [batch_size, k'*action_space_size]
        log_action_dist = log_action_dist.view(batch_size, -1)
        beam_action_space_size = log_action_dist.size()[1]
        k = min(beam_size, beam_action_space_size)
        # [batch_size, k]
        log_action_prob, action_ind = torch.topk(log_action_dist, k)
        next_r = ops.batch_lookup(r_space.view(batch_size, -1), action_ind).view(-1)
        next_e = ops.batch_lookup(e_space.view(batch_size, -1), action_ind).view(-1)
        # [batch_size, k] => [batch_size*k]
        log_action_prob = log_action_prob.view(-1)
        # compute parent offset
        # [batch_size, k]; floor division keeps an integer index tensor.
        action_beam_offset = action_ind // action_space_size
        # [batch_size, 1]
        action_batch_offset = int_var_cuda(torch.arange(batch_size) * last_k).unsqueeze(1)
        # [batch_size, k] => [batch_size*k]
        action_offset = (action_batch_offset + action_beam_offset).view(-1)
        return (next_r, next_e), log_action_prob, action_offset

    def top_k_answer_unique(log_action_dist, action_space):
        """
        Get top k unique entities
            - k = beam_size if the beam size is smaller than or equal to the beam action space size
            - k = beam_action_space_size otherwise
        :param log_action_dist: [batch_size*beam_size, action_space_size]
        :param action_space (r_space, e_space):
            r_space: [batch_size*beam_size, action_space_size]
            e_space: [batch_size*beam_size, action_space_size]
        :return:
            (next_r, next_e), log_action_prob, action_offset: [batch_size*k]
            Examples with fewer than k unique entities are padded with dummy_r / dummy_e,
            a score of -HUGE_INT and an offset of -1.
        """
        full_size = len(log_action_dist)
        assert (full_size % batch_size == 0)
        last_k = full_size // batch_size
        (r_space, e_space), _ = action_space
        action_space_size = r_space.size()[1]

        r_space = r_space.view(batch_size, -1)
        e_space = e_space.view(batch_size, -1)
        log_action_dist = log_action_dist.view(batch_size, -1)
        beam_action_space_size = log_action_dist.size()[1]
        assert (beam_action_space_size % action_space_size == 0)
        k = min(beam_size, beam_action_space_size)
        next_r_list, next_e_list = [], []
        log_action_prob_list = []
        action_offset_list = []
        for i in range(batch_size):
            log_action_dist_b = log_action_dist[i]
            r_space_b = r_space[i]
            e_space_b = e_space[i]
            unique_e_space_b = var_cuda(torch.unique(e_space_b.data.cpu()))
            # Best score per unique entity, plus the flat index where it was achieved.
            unique_log_action_dist, unique_idx = unique_max(
                unique_e_space_b, e_space_b, log_action_dist_b)
            k_prime = min(len(unique_e_space_b), k)
            top_unique_log_action_dist, top_unique_idx2 = torch.topk(
                unique_log_action_dist, k_prime)
            top_unique_idx = unique_idx[top_unique_idx2]
            # Floor division keeps the parent-beam offset an integer tensor.
            top_unique_beam_offset = top_unique_idx // action_space_size
            top_r = r_space_b[top_unique_idx]
            top_e = e_space_b[top_unique_idx]
            next_r_list.append(top_r.unsqueeze(0))
            next_e_list.append(top_e.unsqueeze(0))
            log_action_prob_list.append(top_unique_log_action_dist.unsqueeze(0))
            top_unique_batch_offset = i * last_k
            top_unique_action_offset = top_unique_batch_offset + top_unique_beam_offset
            action_offset_list.append(top_unique_action_offset.unsqueeze(0))
        next_r = ops.pad_and_cat(next_r_list, padding_value=kg.dummy_r).view(-1)
        next_e = ops.pad_and_cat(next_e_list, padding_value=kg.dummy_e).view(-1)
        log_action_prob = ops.pad_and_cat(log_action_prob_list, padding_value=-ops.HUGE_INT)
        action_offset = ops.pad_and_cat(action_offset_list, padding_value=-1)
        return (next_r, next_e), log_action_prob.view(-1), action_offset.view(-1)

    def adjust_search_trace(search_trace, action_offset):
        """Re-order every recorded step in place so it follows the new beam order."""
        for i, (r, e) in enumerate(search_trace):
            new_r = r[action_offset]
            new_e = e[action_offset]
            search_trace[i] = (new_r, new_e)

    def adjust_relation_trace(relation_trace, action_offset, relation):
        """Re-order the relation history by parent beam and append the new relation column."""
        return torch.cat([relation_trace[action_offset], relation.unsqueeze(1)], 1)

    def compute_rule_score(relation_trace, output_beam_size):
        """
        Score every beam's relation path against the rule table of its original query.

        :param relation_trace: [batch_size*output_beam_size, num_steps+1] relation ids;
            column 0 is the dummy start relation and is ignored.
        :param output_beam_size: number of beams per example.
        :return: [batch_size] mean rule score over each example's beams.
        """
        rule_score = torch.zeros(batch_size, output_beam_size).cuda().float()
        # NOTE(review): pickle.load can execute arbitrary code; kg.args.rule must point to a
        # trusted file.
        with open(kg.args.rule, 'rb') as rule_file:
            # Indexed as top_rules[query_relation_id][tuple_of_relation_ids] -> score.
            top_rules = pickle.load(rule_file)
        # A single device->host copy instead of one per (example, beam) pair.
        paths = relation_trace[:, 1:].cpu().numpy().tolist()
        for i in range(batch_size):
            rules_for_q = top_rules.get(int(init_q[i]))
            if rules_for_q is None:
                continue  # no rules for this query: score stays 0
            for j in range(output_beam_size):
                path_ij = tuple(paths[i * output_beam_size + j])
                if path_ij in rules_for_q:
                    rule_score[i][j] = rules_for_q[path_ij]
        return torch.mean(rule_score, 1)

    def finalize_output(action, log_action_prob, relation_trace):
        """Package the final beam into the output dict (see the docstring of beam_search)."""
        output_beam_size = action[0].size()[0] // batch_size
        beam_search_output = dict()
        # [batch_size*beam_size] => [batch_size, beam_size]
        beam_search_output['pred_e2s'] = action[1].view(batch_size, -1)
        beam_search_output['pred_e2_scores'] = log_action_prob.view(batch_size, -1)
        if kg.args.save_beam_search_paths:
            beam_search_output['search_traces'] = search_trace
        beam_search_output["rule_score"] = compute_rule_score(relation_trace, output_beam_size)
        assert len(beam_search_output["rule_score"]) == batch_size
        return beam_search_output

    def build_path_components(log_action_probs, output_beam_size, path_width=10):
        """Build human-readable (head, edge labels, scores) tuples for each step of each example."""
        path_components_list = []
        num_shown = min(output_beam_size, path_width)
        for i in range(batch_size):
            p_c = []
            row_start = i * output_beam_size
            for k, step_log_prob in enumerate(log_action_probs):
                top_k_edge_labels = []
                for j in range(num_shown):
                    ind = row_start + j
                    r = kg.id2relation[int(search_trace[k + 1][0][ind])]
                    e = kg.id2entity[int(search_trace[k + 1][1][ind])]
                    if r.endswith('_inv'):
                        edge_label = ' <-{}- {} {}'.format(
                            r[:-4], e, float(step_log_prob[ind]))
                    else:
                        edge_label = ' -{}-> {} {}'.format(
                            r, e, float(step_log_prob[ind]))
                    top_k_edge_labels.append(edge_label)
                # Scores of this example's beams, matching the edge labels above
                # (previously always example 0's scores).
                top_k_action_prob = step_log_prob[row_start:row_start + num_shown]
                # Path head: the source entity (search_trace[0] holds (r_s, e_s)), shown on
                # the first hop only. Previously a relation id was looked up as an entity.
                e_name = kg.id2entity[int(search_trace[0][1][row_start])] if k == 0 else ''
                p_c.append((e_name, top_k_edge_labels, var_to_numpy(top_k_action_prob)))
            path_components_list.append(p_c)
        return path_components_list

    # Initialization
    r_s = int_fill_var_cuda(e_s.size(), kg.dummy_start_r)
    seen_nodes = int_fill_var_cuda(e_s.size(), kg.dummy_e).unsqueeze(1)
    init_action = (r_s, e_s)
    # record original q (q itself is re-tiled to the beam width every step)
    init_q = q
    # path encoder
    pn.initialize_path(init_action, kg)
    # The trace is needed both for saving paths and for building path components; creating it
    # only for the former made return_path_components alone fail with a NameError.
    track_search_trace = kg.args.save_beam_search_paths or return_path_components
    search_trace = [(r_s, e_s)] if track_search_trace else None
    # [batch_size*k, t+1] relation ids chosen so far; column 0 is the dummy start relation.
    relation_trace = r_s.unsqueeze(1)

    # Run beam search for num_steps
    # [batch_size*k], k=1
    log_action_prob = zeros_var_cuda(batch_size)
    log_action_probs = []
    action = init_action

    # pretrain evaluation without traversing in the graph: only relations are searched and
    # each beam keeps its (re-ordered) source entity.
    if kg.args.pretrain and kg.args.pretrain_out_of_graph:
        for t in range(num_steps):
            last_r, e = action
            assert (q.size() == e_s.size())
            assert (q.size() == e_t.size())
            assert (e.size()[0] % batch_size == 0)
            assert (q.size()[0] % batch_size == 0)
            k = e.size()[0] // batch_size
            # => [batch_size*k]
            q = tile_per_example(q, k)
            e_s = tile_per_example(e_s, k)
            e_t = tile_per_example(e_t, k)
            obs = [e_s, q, e_t, t == (num_steps - 1), last_r, None]
            # one step forward in relation search
            r_prob, _ = pn.transit_r(e, obs, kg)
            # => [batch_size*k, relation_space_size]
            log_action_dist = log_action_prob.view(-1, 1) + ops.safe_log(r_prob)
            action_r, log_action_prob, action_offset = top_k_action_r(log_action_dist)
            if return_path_components:
                ops.rearrange_vector_list(log_action_probs, action_offset)
                log_action_probs.append(log_action_prob)
            pn.update_path_r(action_r, kg, offset=action_offset)
            if track_search_trace:
                adjust_search_trace(search_trace, action_offset)
            relation_trace = adjust_relation_trace(relation_trace, action_offset, action_r)
            # one step forward in entity search (entities only follow their parent beam)
            k = action_r.size()[0] // batch_size
            # => [batch_size*k]
            q = tile_per_example(q, k)
            e_s = tile_per_example(e_s, k)
            e_t = tile_per_example(e_t, k)
            e = e[action_offset]
            action = (action_r, e)
            if track_search_trace:
                search_trace.append(action)
        return finalize_output(action, log_action_prob, relation_trace)

    for t in range(num_steps):
        last_r, e = action
        assert (q.size() == e_s.size())
        assert (q.size() == e_t.size())
        assert (e.size()[0] % batch_size == 0)
        assert (q.size()[0] % batch_size == 0)
        k = e.size()[0] // batch_size
        # => [batch_size*k]
        q = tile_per_example(q, k)
        e_s = tile_per_example(e_s, k)
        e_t = tile_per_example(e_t, k)
        obs = [e_s, q, e_t, t == (num_steps - 1), last_r, seen_nodes]
        # one step forward in search
        db_outcomes, _, _ = pn.transit(
            e, obs, kg, use_action_space_bucketing=True,
            merge_aspace_batching_outcome=True)
        action_space, action_dist = db_outcomes[0]
        # incorporate r_prob
        (r_space, e_space), action_mask = action_space
        r_prob, _ = pn.transit_r(e, obs, kg)
        if kg.args.pretrain:
            # Replace the graph action space by "every relation" and score with r_prob only.
            r_space = torch.ones(len(r_space), kg.num_relations).long() * torch.arange(
                kg.num_relations)
            r_space = r_space.cuda()
            e_space = torch.ones(len(e_space), kg.num_relations).long().cuda()
            action_dist = r_prob
            action_dist = action_dist * kg.r_prob_mask + (1 - kg.r_prob_mask) * ops.EPSILON
            action_space = (r_space, e_space), action_mask
        else:
            # Weight each graph action by the probability of its relation.
            r_space_dist = torch.gather(r_prob, 1, r_space)
            action_dist = torch.mul(r_space_dist, action_dist)
            action_dist = action_dist * action_mask + (1 - action_mask) * ops.EPSILON
        # => [batch_size*k, action_space_size]
        log_action_dist = log_action_prob.view(-1, 1) + ops.safe_log(action_dist)
        # [batch_size*k, action_space_size] => [batch_size*new_k]
        if t == num_steps - 1 and not kg.args.pretrain:
            action, log_action_prob, action_offset = top_k_answer_unique(
                log_action_dist, action_space)
        else:
            action, log_action_prob, action_offset = top_k_action(
                log_action_dist, action_space)
        if return_path_components:
            ops.rearrange_vector_list(log_action_probs, action_offset)
            log_action_probs.append(log_action_prob)
        pn.update_path(action, kg, offset=action_offset)
        seen_nodes = torch.cat([seen_nodes[action_offset], action[1].unsqueeze(1)], dim=1)
        if track_search_trace:
            adjust_search_trace(search_trace, action_offset)
            search_trace.append(action)
        relation_trace = adjust_relation_trace(relation_trace, action_offset, action[0])

    beam_search_output = finalize_output(action, log_action_prob, relation_trace)
    if return_path_components:
        output_beam_size = beam_search_output['pred_e2s'].size()[1]
        beam_search_output['path_components_list'] = build_path_components(
            log_action_probs, output_beam_size)
    return beam_search_output