def beam_search_same(pn, e_s, q, e_t, kg, num_steps, beam_size,
                     return_path_components=False, return_merge_scores=None,
                     same_start=False):
    """
    Beam search from source.
    :param pn: Policy network.
    :param e_s: (Variable:batch) source entity indices.
    :param q: (Variable:batch) query relation indices.
    :param e_t: (Variable:batch) target entity indices.
    :param kg: Knowledge graph environment.
    :param num_steps: Number of search steps.
    :param beam_size: Beam size used in search.
    :param return_path_components: If set, return all path components at the
        end of search.
    :param return_merge_scores: If set to 'sum' or 'mean', merge the scores of
        duplicate (relation, entity) actions with the given reduction.
    :param same_start: Passed through to pn.initialize_abs_path.
    """
    assert (num_steps >= 1)
    batch_size = len(e_s)

    def top_k_action(log_action_dist, action_space, return_merge_scores=None):
        """
        Get top k actions.
            - k = beam_size if the beam size is smaller than or equal to the
              beam action space size
            - k = beam_action_space_size otherwise
        :param log_action_dist: [batch_size*k, action_space_size]
        :param action_space (r_space, e_space):
            r_space: [batch_size*k, action_space_size]
            e_space: [batch_size*k, action_space_size]
        :return: (next_r, next_e), action_prob, action_offset: [batch_size*new_k]
        """
        full_size = len(log_action_dist)
        assert (full_size % batch_size == 0)
        last_k = int(full_size / batch_size)

        (r_space, e_space), _ = action_space
        action_space_size = r_space.size()[1]
        # => [batch_size, k'*action_space_size]
        log_action_dist = log_action_dist.view(batch_size, -1)
        beam_action_space_size = log_action_dist.size()[1]
        k = min(beam_size, beam_action_space_size)

        if return_merge_scores is not None:
            if return_merge_scores == 'sum':
                reduce_method = torch.sum
            elif return_merge_scores == 'mean':
                reduce_method = torch.mean
            else:
                reduce_method = None

            all_action_ind = torch.LongTensor([
                range(beam_action_space_size)
                for _ in range(len(log_action_dist))
            ]).cuda(device=0)
            all_next_r = ops.batch_lookup(r_space.view(batch_size, -1), all_action_ind)
            all_next_e = ops.batch_lookup(e_space.view(batch_size, -1), all_action_ind)
            real_log_action_prob, real_next_e, real_action_ind, real_next_r = ops.merge_same(
                log_action_dist, all_next_e, all_next_r, method=reduce_method)

            next_e_list, next_r_list, action_ind_list, log_action_prob_list = [], [], [], []
            for i in range(batch_size):
                k_prime = min(len(real_log_action_prob[i]), k)
                top_log_prob, top_ind = torch.topk(real_log_action_prob[i], k_prime)
                top_next_e, top_next_r, top_ind = real_next_e[i][top_ind], \
                    real_next_r[i][top_ind], real_action_ind[i][top_ind]
                next_e_list.append(top_next_e.unsqueeze(0))
                next_r_list.append(top_next_r.unsqueeze(0))
                action_ind_list.append(top_ind.unsqueeze(0))
                log_action_prob_list.append(top_log_prob.unsqueeze(0))

            next_r = ops.pad_and_cat(next_r_list, padding_value=kg.dummy_r).view(-1)
            next_e = ops.pad_and_cat(next_e_list, padding_value=kg.dummy_e).view(-1)
            log_action_prob = ops.pad_and_cat(log_action_prob_list, padding_value=0.0).view(-1)
            action_ind = ops.pad_and_cat(action_ind_list, padding_value=-1).view(-1)
        else:
            log_action_prob, action_ind = torch.topk(log_action_dist, k)
            next_r = ops.batch_lookup(r_space.view(batch_size, -1), action_ind).view(-1)
            next_e = ops.batch_lookup(e_space.view(batch_size, -1), action_ind).view(-1)
            # [batch_size, k] => [batch_size*k]
            log_action_prob = log_action_prob.view(-1)

        # *** compute parent offset
        # [batch_size, k] (integer division keeps the offsets usable as indices)
        action_beam_offset = action_ind // action_space_size
        # [batch_size, 1]
        action_batch_offset = int_var_cuda(torch.arange(batch_size) * last_k).unsqueeze(1)
        # [batch_size, k] => [batch_size*k]
        action_offset = (action_batch_offset + action_beam_offset).view(-1)
        return (next_r, next_e), log_action_prob, action_offset
    def top_k_answer_unique(log_action_dist, action_space):
        """
        Get top k unique entities.
            - k = beam_size if the beam size is smaller than or equal to the
              beam action space size
            - k = beam_action_space_size otherwise
        :param log_action_dist: [batch_size*beam_size, action_space_size] log-probabilities
        :param action_space (r_space, e_space): candidate entities/relations
            r_space: [batch_size*beam_size, action_space_size]
            e_space: [batch_size*beam_size, action_space_size]
        :return: (next_r, next_e), action_prob, action_offset: [batch_size*k]
        """
        full_size = len(log_action_dist)
        assert (full_size % batch_size == 0)
        last_k = int(full_size / batch_size)
        (r_space, e_space), _ = action_space
        action_space_size = r_space.size()[1]

        r_space = r_space.view(batch_size, -1)
        e_space = e_space.view(batch_size, -1)
        log_action_dist = log_action_dist.view(batch_size, -1)
        beam_action_space_size = log_action_dist.size()[1]
        assert (beam_action_space_size % action_space_size == 0)
        k = min(beam_size, beam_action_space_size)
        next_r_list, next_e_list = [], []
        log_action_prob_list = []
        action_offset_list = []
        for i in range(batch_size):
            log_action_dist_b = log_action_dist[i]
            r_space_b = r_space[i]
            e_space_b = e_space[i]
            unique_e_space_b = var_cuda(torch.unique(e_space_b.data.cpu()))
            unique_log_action_dist, unique_idx = unique_max(
                unique_e_space_b, e_space_b, log_action_dist_b)
            k_prime = min(len(unique_e_space_b), k)
            top_unique_log_action_dist, top_unique_idx2 = torch.topk(
                unique_log_action_dist, k_prime)
            top_unique_idx = unique_idx[top_unique_idx2]
            # integer division keeps the beam offset integral
            top_unique_beam_offset = top_unique_idx // action_space_size
            top_r = r_space_b[top_unique_idx]
            top_e = e_space_b[top_unique_idx]
            next_r_list.append(top_r.unsqueeze(0))
            next_e_list.append(top_e.unsqueeze(0))
            log_action_prob_list.append(top_unique_log_action_dist.unsqueeze(0))
            top_unique_batch_offset = i * last_k
            top_unique_action_offset = top_unique_batch_offset + top_unique_beam_offset
            action_offset_list.append(top_unique_action_offset.unsqueeze(0))
        next_r = ops.pad_and_cat(next_r_list, padding_value=kg.dummy_r).view(-1)
        next_e = ops.pad_and_cat(next_e_list, padding_value=kg.dummy_e).view(-1)
        log_action_prob = ops.pad_and_cat(log_action_prob_list, padding_value=-ops.HUGE_INT)
        action_offset = ops.pad_and_cat(action_offset_list, padding_value=-1)
        return (next_r, next_e), log_action_prob.view(-1), action_offset.view(-1)

    def adjust_search_trace(search_trace, action_offset):
        for i, (r, e) in enumerate(search_trace):
            new_r = r[action_offset]
            new_e = e[action_offset]
            search_trace[i] = (new_r, new_e)

    # Initialization
    # WHY: why does the path have to start with an empty (dummy) relation?
    r_s = int_fill_var_cuda(e_s.size(), kg.dummy_start_r)
    seen_nodes = int_fill_var_cuda(e_s.size(), kg.dummy_e).unsqueeze(1)  # TODO
    init_action = (r_s, e_s)
    e_s_abs = var_cuda(torch.LongTensor([kg.get_typeid(_) for _ in e_s]),
                       requires_grad=False)
    e_t_abs = var_cuda(torch.LongTensor([kg.get_typeid(_) for _ in e_t]),
                       requires_grad=False)
    # print("()()()es_size() e_s_abs.size():", e_s.size(), e_s_abs.size())
    seen_nodes_abs = int_fill_var_cuda(e_s_abs.size(), kg.dummy_e).unsqueeze(1)
    init_action_abs = (r_s, e_s_abs)

    # path encoder
    # pn.initialize_path(init_action, kg)
    pn.initialize_abs_path(init_action, init_action_abs, kg, same_start)
    if kg.args.save_paths_to_csv:
        search_trace = [(r_s, e_s)]

    # Run beam search for num_steps
    # [batch_size*k], k=1
    log_action_prob = zeros_var_cuda(batch_size)
    if return_path_components:
        log_action_probs = []

    action = init_action
    action_abs = init_action_abs
    for t in range(num_steps):
        last_r, e = action
        last_r, e_abs = action_abs
        assert (q.size() == e_s.size())
        assert (q.size() == e_t.size())
        assert (e.size()[0] % batch_size == 0)
        assert (q.size()[0] % batch_size == 0)
        k = int(e.size()[0] / batch_size)
        # => [batch_size*k]
        q = ops.tile_along_beam(q.view(batch_size, -1)[:, 0], k)
        e_s = ops.tile_along_beam(e_s.view(batch_size, -1)[:, 0], k)
        e_s_abs = ops.tile_along_beam(e_abs.view(batch_size, -1)[:, 0], k)
        e_t = ops.tile_along_beam(e_t.view(batch_size, -1)[:, 0], k)
        e_t_abs = ops.tile_along_beam(e_t_abs.view(batch_size, -1)[:, 0], k)
        obs = [e_s, q, e_t, t == (num_steps - 1), last_r, seen_nodes]
        obs_abs = [
            e_s_abs, q, e_t_abs, t == (num_steps - 1), last_r, seen_nodes_abs
        ]
        # one step forward in search
        # db_outcomes, _, _ = pn.transit_same(
        #     e, obs, kg, use_action_space_bucketing=False,
        #     merge_aspace_batching_outcome=True)
        # TODO: step through get_action_space_in_buckets inside transit_same
        db_outcomes, _, _ = pn.transit_same(
            e, obs, e_abs, obs_abs, kg,
            use_action_space_bucketing=False,
            merge_aspace_batching_outcome=True)
        action_space, action_dist = db_outcomes[0]
        # => [batch_size*k, action_space_size]
        log_action_dist = log_action_prob.view(-1, 1) + ops.safe_log(action_dist)
        # [batch_size*k, action_space_size] => [batch_size*new_k]
        if t == num_steps - 1:
            if return_merge_scores is None:
                action, log_action_prob, action_offset = top_k_answer_unique(
                    log_action_dist, action_space)
            else:
                action, log_action_prob, action_offset = top_k_action(
                    log_action_dist, action_space, return_merge_scores)
        else:
            action, log_action_prob, action_offset = top_k_action(
                log_action_dist, action_space, None)

        # (next_r, next_e)
        action_abs = (action[0],
                      var_cuda(torch.LongTensor(
                          [kg.get_typeid(_) for _ in action[1]]),
                          requires_grad=False))
        if return_path_components:
            ops.rearrange_vector_list(log_action_probs, action_offset)
            log_action_probs.append(log_action_prob)
        pn.update_path_abs(action_abs, kg, offset=action_offset)
        seen_nodes = torch.cat(
            [seen_nodes[action_offset], action[1].unsqueeze(1)], dim=1)
        seen_nodes_abs = torch.cat(
            [seen_nodes_abs[action_offset], action_abs[1].unsqueeze(1)], dim=1)
        if kg.args.save_paths_to_csv:
            adjust_search_trace(search_trace, action_offset)
            search_trace.append(action)

    output_beam_size = int(action[0].size()[0] / batch_size)
    # [batch_size*beam_size] => [batch_size, beam_size]
    beam_search_output = dict()
    beam_search_output['pred_e2s'] = action[1].view(batch_size, -1)
    beam_search_output['pred_e2_scores'] = log_action_prob.view(batch_size, -1)
    if return_path_components:
        path_width = 10
        path_components_list = []
        for i in range(batch_size):
            p_c = []
            for k, log_action_prob in enumerate(log_action_probs):
                top_k_edge_labels = []
                for j in range(min(output_beam_size, path_width)):
                    ind = i * output_beam_size + j
                    r = kg.id2relation[int(search_trace[k + 1][0][ind])]
                    e = kg.id2entity[int(search_trace[k + 1][1][ind])]
                    if r.endswith('_inv'):
                        edge_label = '<-{}-{} {}'.format(
                            r[:-4], e, float(log_action_probs[k][ind]))
                    else:
                        edge_label = '-{}->{} {}'.format(
                            r, e, float(log_action_probs[k][ind]))
                    top_k_edge_labels.append(edge_label)
                top_k_action_prob = log_action_prob[:path_width]
                e_name = kg.id2entity[int(
                    search_trace[1][0][i * output_beam_size])] if k == 0 else ''
                p_c.append((e_name, top_k_edge_labels,
                            var_to_numpy(top_k_action_prob)))
            path_components_list.append(p_c)
    beam_search_output['search_traces'] = search_trace
    SAME_ALL_PATHS.append(beam_search_output)
    return beam_search_output
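
# The sketch below is illustrative only and not part of the model: it reproduces, with
# plain PyTorch and made-up sizes, how top_k_action recovers the parent beam of each
# selected action. After flattening the per-beam scores to
# [batch_size, beam_size*action_space_size], each flat top-k index decomposes into a
# beam offset (which parent hypothesis) and a position inside that beam's action space;
# adding batch_index * beam_size turns the beam offset into a row index into the
# previous step's [batch_size*beam_size, ...] tensors. All names here are hypothetical.
def _demo_parent_offset(batch_size=2, beam_size=3, action_space_size=4):
    import torch
    # fake per-action scores for batch_size*beam_size partial paths
    log_action_dist = torch.randn(batch_size * beam_size, action_space_size)
    flat = log_action_dist.view(batch_size, -1)            # [batch_size, beam_size*action_space_size]
    k = min(beam_size, flat.size(1))
    top_log_prob, action_ind = torch.topk(flat, k)         # indices into the flattened beam*action grid
    action_beam_offset = action_ind // action_space_size   # which parent beam each action came from
    action_batch_offset = (torch.arange(batch_size) * beam_size).unsqueeze(1)
    action_offset = (action_batch_offset + action_beam_offset).view(-1)
    # action_offset now selects rows of the previous step's [batch_size*beam_size, ...] tensors
    return action_offset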
def beam_search(pn, e_s, q, e_t, kg, num_steps, beam_size,
                return_path_components=False):
    """
    Beam search from source.
    :param pn: Policy network.
    :param e_s: (Variable:batch) source entity indices.
    :param q: (Variable:batch) query relation indices.
    :param e_t: (Variable:batch) target entity indices.
    :param kg: Knowledge graph environment.
    :param num_steps: Number of search steps.
    :param beam_size: Beam size used in search.
    :param return_path_components: If set, return all path components at the
        end of search.
    """
    assert (num_steps >= 1)
    batch_size = len(e_s)

    def top_k_action(log_action_dist, action_space):
        """
        Get top k actions.
            - k = beam_size if the beam size is smaller than or equal to the
              beam action space size
            - k = beam_action_space_size otherwise
        :param log_action_dist: [batch_size*k, action_space_size]
        :param action_space (r_space, e_space):
            r_space: [batch_size*k, action_space_size]
            e_space: [batch_size*k, action_space_size]
        :return: (next_r, next_e), action_prob, action_offset: [batch_size*new_k]
        """
        full_size = len(log_action_dist)
        assert (full_size % batch_size == 0)
        last_k = int(full_size / batch_size)

        (r_space, e_space), _ = action_space
        action_space_size = r_space.size()[1]
        # => [batch_size, k'*action_space_size]
        log_action_dist = log_action_dist.view(batch_size, -1)
        beam_action_space_size = log_action_dist.size()[1]
        k = min(beam_size, beam_action_space_size)
        # [batch_size, k]
        log_action_prob, action_ind = torch.topk(log_action_dist, k)
        next_r = ops.batch_lookup(r_space.view(batch_size, -1), action_ind).view(-1)
        next_e = ops.batch_lookup(e_space.view(batch_size, -1), action_ind).view(-1)
        # [batch_size, k] => [batch_size*k]
        log_action_prob = log_action_prob.view(-1)
        # *** compute parent offset
        # [batch_size, k]
        action_beam_offset = action_ind // action_space_size
        # [batch_size, 1]
        action_batch_offset = int_var_cuda(torch.arange(batch_size) * last_k).unsqueeze(1)
        # [batch_size, k] => [batch_size*k]
        action_offset = (action_batch_offset + action_beam_offset).view(-1)
        return (next_r, next_e), log_action_prob, action_offset

    def top_k_answer_unique(log_action_dist, action_space):
        """
        Get top k unique entities.
            - k = beam_size if the beam size is smaller than or equal to the
              beam action space size
            - k = beam_action_space_size otherwise
        :param log_action_dist: [batch_size*beam_size, action_space_size]
        :param action_space (r_space, e_space):
            r_space: [batch_size*beam_size, action_space_size]
            e_space: [batch_size*beam_size, action_space_size]
        :return: (next_r, next_e), action_prob, action_offset: [batch_size*k]
        """
        full_size = len(log_action_dist)
        assert (full_size % batch_size == 0)
        last_k = int(full_size / batch_size)
        (r_space, e_space), _ = action_space
        action_space_size = r_space.size()[1]

        r_space = r_space.view(batch_size, -1)
        e_space = e_space.view(batch_size, -1)
        log_action_dist = log_action_dist.view(batch_size, -1)
        beam_action_space_size = log_action_dist.size()[1]
        assert (beam_action_space_size % action_space_size == 0)
        k = min(beam_size, beam_action_space_size)
        next_r_list, next_e_list = [], []
        log_action_prob_list = []
        action_offset_list = []
        for i in range(batch_size):
            log_action_dist_b = log_action_dist[i]
            r_space_b = r_space[i]
            e_space_b = e_space[i]
            unique_e_space_b = var_cuda(torch.unique(e_space_b.data.cpu()))
            unique_log_action_dist, unique_idx = unique_max(
                unique_e_space_b, e_space_b, log_action_dist_b)
            k_prime = min(len(unique_e_space_b), k)
            top_unique_log_action_dist, top_unique_idx2 = torch.topk(
                unique_log_action_dist, k_prime)
            top_unique_idx = unique_idx[top_unique_idx2]
            top_unique_beam_offset = top_unique_idx // action_space_size
            top_r = r_space_b[top_unique_idx]
            top_e = e_space_b[top_unique_idx]
            next_r_list.append(top_r.unsqueeze(0))
            next_e_list.append(top_e.unsqueeze(0))
            log_action_prob_list.append(top_unique_log_action_dist.unsqueeze(0))
            top_unique_batch_offset = i * last_k
            top_unique_action_offset = top_unique_batch_offset + top_unique_beam_offset
            action_offset_list.append(top_unique_action_offset.unsqueeze(0))
        next_r = ops.pad_and_cat(next_r_list, padding_value=kg.dummy_r).view(-1)
        next_e = ops.pad_and_cat(next_e_list, padding_value=kg.dummy_e).view(-1)
        log_action_prob = ops.pad_and_cat(log_action_prob_list, padding_value=-ops.HUGE_INT)
        action_offset = ops.pad_and_cat(action_offset_list, padding_value=-1)
        return (next_r, next_e), log_action_prob.view(-1), action_offset.view(-1)

    def adjust_search_trace(search_trace, action_offset):
        for i, (r, e) in enumerate(search_trace):
            new_r = r[action_offset]
            new_e = e[action_offset]
            search_trace[i] = (new_r, new_e)

    # Initialization
    r_s = int_fill_var_cuda(e_s.size(), kg.dummy_start_r)
    seen_nodes = int_fill_var_cuda(e_s.size(), kg.dummy_e).unsqueeze(1)
    init_action = (r_s, e_s)
    # path encoder
    emb_e_s = pn.initialize_path(init_action, q, kg, 'eval')
    if kg.args.save_beam_search_paths:
        search_trace = [(r_s, e_s)]

    # Run beam search for num_steps
    # [batch_size*k], k=1
    log_action_prob = zeros_var_cuda(batch_size)
    if return_path_components:
        log_action_probs = []

    action = init_action
    for t in range(num_steps):
        last_r, e = action
        assert (q.size() == e_s.size())
        assert (q.size() == e_t.size())
        assert (e.size()[0] % batch_size == 0)
        assert (q.size()[0] % batch_size == 0)
        k = int(e.size()[0] / batch_size)
        # => [batch_size*k]
        q = ops.tile_along_beam(q.view(batch_size, -1)[:, 0], k)
        e_s = ops.tile_along_beam(e_s.view(batch_size, -1)[:, 0], k)
        e_t = ops.tile_along_beam(e_t.view(batch_size, -1)[:, 0], k)
        obs = [
            e_s, emb_e_s, q, e_t, t == 0, t == (num_steps - 1), last_r, seen_nodes
        ]
        # one step forward in search
        db_outcomes, _, _ = pn.transit(e, obs, kg, 'eval',
                                       use_action_space_bucketing=True,
                                       merge_aspace_batching_outcome=True)
        action_space, action_dist = db_outcomes[0]
        # => [batch_size*k, action_space_size]
        log_action_dist = log_action_prob.view(-1, 1) + ops.safe_log(action_dist)
        # [batch_size*k, action_space_size] => [batch_size*new_k]
        if t == num_steps - 1:
            action, log_action_prob, action_offset = top_k_answer_unique(
                log_action_dist, action_space)
        else:
            action, log_action_prob, action_offset = top_k_action(
                log_action_dist, action_space)
        if return_path_components:
            ops.rearrange_vector_list(log_action_probs, action_offset)
            log_action_probs.append(log_action_prob)
        pn.update_path(action, kg, offset=action_offset)
        seen_nodes = torch.cat(
            [seen_nodes[action_offset], action[1].unsqueeze(1)], dim=1)
        if kg.args.save_beam_search_paths:
            adjust_search_trace(search_trace, action_offset)
            search_trace.append(action)

    output_beam_size = int(action[0].size()[0] / batch_size)
    # [batch_size*beam_size] => [batch_size, beam_size]
    beam_search_output = dict()
    beam_search_output['pred_e2s'] = action[1].view(batch_size, -1)
    beam_search_output['pred_e2_scores'] = log_action_prob.view(batch_size, -1)
    if kg.args.save_beam_search_paths:
        beam_search_output['search_traces'] = search_trace

    if return_path_components:
        path_width = 10
        path_components_list = []
        for i in range(batch_size):
            p_c = []
            for k, log_action_prob in enumerate(log_action_probs):
                top_k_edge_labels = []
                for j in range(min(output_beam_size, path_width)):
                    ind = i * output_beam_size + j
                    r = kg.id2relation[int(search_trace[k + 1][0][ind])]
                    e = kg.id2entity[int(search_trace[k + 1][1][ind])]
                    if r.endswith('_inv'):
                        edge_label = ' <-{}- {} {}'.format(
                            r[:-4], e, float(log_action_probs[k][ind]))
                    else:
                        edge_label = ' -{}-> {} {}'.format(
                            r, e, float(log_action_probs[k][ind]))
                    top_k_edge_labels.append(edge_label)
                top_k_action_prob = log_action_prob[:path_width]
                e_name = kg.id2entity[int(
                    search_trace[1][0][i * output_beam_size])] if k == 0 else ''
                p_c.append((e_name, top_k_edge_labels,
                            var_to_numpy(top_k_action_prob)))
                print(e_name, top_k_edge_labels)
            path_components_list.append(p_c)
        beam_search_output['path_components_list'] = path_components_list
    return beam_search_output
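
# Illustrative sketch only (plain PyTorch, hypothetical names and sizes): the core of
# top_k_answer_unique is "deduplicate candidate tail entities, keep the best-scored
# occurrence of each, then take the top k of the deduplicated set". The helper below
# mimics what the unique_max(...) call is used for here, without the library's padding
# and CUDA-variable wrappers.
def _demo_topk_unique_entities(e_space_b, log_action_dist_b, k):
    import torch
    unique_e = torch.unique(e_space_b)
    best_scores, best_idx = [], []
    for ent in unique_e:
        mask = (e_space_b == ent)
        # best score among all occurrences of this entity, and the flat index achieving it
        scores = log_action_dist_b.masked_fill(~mask, float('-inf'))
        s, idx = scores.max(dim=0)
        best_scores.append(s)
        best_idx.append(idx)
    best_scores = torch.stack(best_scores)
    best_idx = torch.stack(best_idx)
    k_prime = min(len(unique_e), k)
    top_scores, top2 = torch.topk(best_scores, k_prime)
    top_idx = best_idx[top2]  # indices back into the flattened beam*action grid
    return unique_e[top2], top_scores, top_idx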
def beam_search(pn, e_s, q, e_t, kg, num_steps, beam_size,
                return_path_components=False):
    """
    Beam search from source.
    :param pn: Policy network.
    :param e_s: (Variable:batch) source entity indices.
    :param q: (Variable:batch) query relation indices.
    :param e_t: (Variable:batch) target entity indices.
    :param kg: Knowledge graph environment.
    :param num_steps: Number of search steps.
    :param beam_size: Beam size used in search.
    :param return_path_components: If set, return all path components at the
        end of search.
    """
    assert (num_steps >= 1)
    batch_size = len(e_s)

    def top_k_action_r(log_action_dist):
        """
        Get top k relations.
            - k = beam_size if the beam size is smaller than or equal to the
              beam action space size
            - k = beam_action_space_size otherwise
        :param log_action_dist: [batch_size*k, action_space_size]
        :return: next_r, log_action_prob, action_offset: [batch_size*new_k]
        """
        full_size = len(log_action_dist)
        assert (full_size % batch_size == 0)
        last_k = int(full_size / batch_size)
        action_space_size = log_action_dist.size()[1]
        # => [batch_size, k'*action_space_size]
        log_action_dist = log_action_dist.view(batch_size, -1)
        beam_action_space_size = log_action_dist.size()[1]
        k = min(beam_size, beam_action_space_size)
        # [batch_size, k]
        log_action_prob, action_ind = torch.topk(log_action_dist, k)
        next_r = (action_ind % action_space_size).view(-1)
        # [batch_size, k] => [batch_size*k]
        log_action_prob = log_action_prob.view(-1)
        # compute parent offset
        # [batch_size, k] (integer division keeps the offsets usable as indices)
        action_beam_offset = action_ind // action_space_size
        # [batch_size, 1]
        action_batch_offset = int_var_cuda(torch.arange(batch_size) * last_k).unsqueeze(1)
        # [batch_size, k] => [batch_size*k]
        action_offset = (action_batch_offset + action_beam_offset).view(-1)
        return next_r, log_action_prob, action_offset

    def top_k_action(log_action_dist, action_space):
        """
        Get top k actions.
            - k = beam_size if the beam size is smaller than or equal to the
              beam action space size
            - k = beam_action_space_size otherwise
        :param log_action_dist: [batch_size*k, action_space_size]
        :param action_space (r_space, e_space):
            r_space: [batch_size*k, action_space_size]
            e_space: [batch_size*k, action_space_size]
        :return: (next_r, next_e), log_action_prob, action_offset: [batch_size*new_k]
        """
        full_size = len(log_action_dist)
        assert (full_size % batch_size == 0)
        last_k = int(full_size / batch_size)

        (r_space, e_space), _ = action_space
        action_space_size = r_space.size()[1]
        # => [batch_size, k'*action_space_size]
        log_action_dist = log_action_dist.view(batch_size, -1)
        beam_action_space_size = log_action_dist.size()[1]
        k = min(beam_size, beam_action_space_size)
        # [batch_size, k]
        log_action_prob, action_ind = torch.topk(log_action_dist, k)
        next_r = ops.batch_lookup(r_space.view(batch_size, -1), action_ind).view(-1)
        next_e = ops.batch_lookup(e_space.view(batch_size, -1), action_ind).view(-1)
        # [batch_size, k] => [batch_size*k]
        log_action_prob = log_action_prob.view(-1)
        # compute parent offset
        # [batch_size, k]
        action_beam_offset = action_ind // action_space_size
        # [batch_size, 1]
        action_batch_offset = int_var_cuda(torch.arange(batch_size) * last_k).unsqueeze(1)
        # [batch_size, k] => [batch_size*k]
        action_offset = (action_batch_offset + action_beam_offset).view(-1)
        return (next_r, next_e), log_action_prob, action_offset

    def top_k_answer_unique(log_action_dist, action_space):
        """
        Get top k unique entities.
            - k = beam_size if the beam size is smaller than or equal to the
              beam action space size
            - k = beam_action_space_size otherwise
        :param log_action_dist: [batch_size*beam_size, action_space_size]
        :param action_space (r_space, e_space):
            r_space: [batch_size*beam_size, action_space_size]
            e_space: [batch_size*beam_size, action_space_size]
        :return: (next_r, next_e), log_action_prob, action_offset: [batch_size*k]
        """
        full_size = len(log_action_dist)
        assert (full_size % batch_size == 0)
        last_k = int(full_size / batch_size)
        (r_space, e_space), _ = action_space
        action_space_size = r_space.size()[1]

        r_space = r_space.view(batch_size, -1)
        e_space = e_space.view(batch_size, -1)
        log_action_dist = log_action_dist.view(batch_size, -1)
        beam_action_space_size = log_action_dist.size()[1]
        assert (beam_action_space_size % action_space_size == 0)
        k = min(beam_size, beam_action_space_size)
        next_r_list, next_e_list = [], []
        log_action_prob_list = []
        action_offset_list = []
        for i in range(batch_size):
            log_action_dist_b = log_action_dist[i]
            r_space_b = r_space[i]
            e_space_b = e_space[i]
            unique_e_space_b = var_cuda(torch.unique(e_space_b.data.cpu()))
            unique_log_action_dist, unique_idx = unique_max(
                unique_e_space_b, e_space_b, log_action_dist_b)
            k_prime = min(len(unique_e_space_b), k)
            top_unique_log_action_dist, top_unique_idx2 = torch.topk(
                unique_log_action_dist, k_prime)
            top_unique_idx = unique_idx[top_unique_idx2]
            top_unique_beam_offset = top_unique_idx // action_space_size
            top_r = r_space_b[top_unique_idx]
            top_e = e_space_b[top_unique_idx]
            next_r_list.append(top_r.unsqueeze(0))
            next_e_list.append(top_e.unsqueeze(0))
            log_action_prob_list.append(top_unique_log_action_dist.unsqueeze(0))
            top_unique_batch_offset = i * last_k
            top_unique_action_offset = top_unique_batch_offset + top_unique_beam_offset
            action_offset_list.append(top_unique_action_offset.unsqueeze(0))
        next_r = ops.pad_and_cat(next_r_list, padding_value=kg.dummy_r).view(-1)
        next_e = ops.pad_and_cat(next_e_list, padding_value=kg.dummy_e).view(-1)
        log_action_prob = ops.pad_and_cat(log_action_prob_list, padding_value=-ops.HUGE_INT)
        action_offset = ops.pad_and_cat(action_offset_list, padding_value=-1)
        return (next_r, next_e), log_action_prob.view(-1), action_offset.view(-1)

    def adjust_search_trace(search_trace, action_offset):
        for i, (r, e) in enumerate(search_trace):
            new_r = r[action_offset]
            new_e = e[action_offset]
            search_trace[i] = (new_r, new_e)

    def to_one_hot(x):
        y = torch.eye(kg.num_relations).cuda()
        return y[x]

    def adjust_relation_trace(relation_trace, action_offset, relation):
        return torch.cat(
            [relation_trace[action_offset], relation.unsqueeze(1)], 1)

    # Initialization
    r_s = int_fill_var_cuda(e_s.size(), kg.dummy_start_r)
    seen_nodes = int_fill_var_cuda(e_s.size(), kg.dummy_e).unsqueeze(1)
    init_action = (r_s, e_s)
    # record original q
    init_q = q
    # path encoder
    pn.initialize_path(init_action, kg)
    if kg.args.save_beam_search_paths:
        search_trace = [(r_s, e_s)]
    # relation_trace is tracked unconditionally: the rule score below needs it
    relation_trace = r_s.unsqueeze(1)

    # Run beam search for num_steps
    # [batch_size*k], k=1
    log_action_prob = zeros_var_cuda(batch_size)
    if return_path_components:
        log_action_probs = []

    action = init_action

    # pretrain evaluation without traversing in the graph
    if kg.args.pretrain and kg.args.pretrain_out_of_graph:
        for t in range(num_steps):
            last_r, e = action
            assert (q.size() == e_s.size())
            assert (q.size() == e_t.size())
            assert (e.size()[0] % batch_size == 0)
            assert (q.size()[0] % batch_size == 0)
            k = int(e.size()[0] / batch_size)
            # => [batch_size*k]
            q = ops.tile_along_beam(q.view(batch_size, -1)[:, 0], k)
            e_s = ops.tile_along_beam(e_s.view(batch_size, -1)[:, 0], k)
            e_t = ops.tile_along_beam(e_t.view(batch_size, -1)[:, 0], k)
            obs = [e_s, q, e_t, t == (num_steps - 1), last_r, None]
            # one step forward in relation search
            r_prob, _ = pn.transit_r(e, obs, kg)
            # => [batch_size*k, relation_space_size]
            log_action_dist = log_action_prob.view(-1, 1) + ops.safe_log(r_prob)
            # action_space = torch.stack([torch.arange(kg.num_relations)] * len(r_prob), 0).long().cuda()
            action_r, log_action_prob, action_offset = top_k_action_r(log_action_dist)
            if return_path_components:
                ops.rearrange_vector_list(log_action_probs, action_offset)
                log_action_probs.append(log_action_prob)
            pn.update_path_r(action_r, kg, offset=action_offset)
            if kg.args.save_beam_search_paths:
                adjust_search_trace(search_trace, action_offset)
            relation_trace = adjust_relation_trace(relation_trace, action_offset, action_r)

            # one step forward in entity search
            k = int(action_r.size()[0] / batch_size)
            # => [batch_size*k]
            q = ops.tile_along_beam(q.view(batch_size, -1)[:, 0], k)
            e_s = ops.tile_along_beam(e_s.view(batch_size, -1)[:, 0], k)
            e_t = ops.tile_along_beam(e_t.view(batch_size, -1)[:, 0], k)
            e = e[action_offset]
            action = (action_r, e)
            seen_nodes = torch.cat(
                [seen_nodes[action_offset], action[1].unsqueeze(1)], dim=1)
            if kg.args.save_beam_search_paths:
                search_trace.append(action)

        output_beam_size = int(action[0].size()[0] / batch_size)
        # [batch_size*beam_size] => [batch_size, beam_size]
        beam_search_output = dict()
        beam_search_output['pred_e2s'] = action[1].view(batch_size, -1)
        beam_search_output['pred_e2_scores'] = log_action_prob.view(batch_size, -1)
        if kg.args.save_beam_search_paths:
            beam_search_output['search_traces'] = search_trace

        rule_score = torch.zeros(batch_size, output_beam_size).cuda().float()
        top_rules = pickle.load(open(kg.args.rule, 'rb'))
        for i in range(batch_size):
            for j in range(output_beam_size):
                path_ij = relation_trace[i * output_beam_size + j, 1:].cpu().numpy().tolist()
                if not int(init_q[i]) in top_rules.keys():
                    rule_score[i][j] = 0.0
                elif tuple(path_ij) in top_rules[int(init_q[i])].keys():
                    rule_score[i][j] = top_rules[int(init_q[i])][tuple(path_ij)]
                    # print(tuple(relation_trace[i * output_beam_size + j, 1:]))
                    # print(top_rules[int(init_q[i])].keys())
        # rule_score = (rule_score > 0).float()
        beam_search_output["rule_score"] = torch.mean(rule_score, 1)
        assert len(beam_search_output["rule_score"]) == batch_size
        return beam_search_output

    for t in range(num_steps):
        last_r, e = action
        assert (q.size() == e_s.size())
        assert (q.size() == e_t.size())
        assert (e.size()[0] % batch_size == 0)
        assert (q.size()[0] % batch_size == 0)
        k = int(e.size()[0] / batch_size)
        # => [batch_size*k]
        q = ops.tile_along_beam(q.view(batch_size, -1)[:, 0], k)
        e_s = ops.tile_along_beam(e_s.view(batch_size, -1)[:, 0], k)
        e_t = ops.tile_along_beam(e_t.view(batch_size, -1)[:, 0], k)
        obs = [e_s, q, e_t, t == (num_steps - 1), last_r, seen_nodes]
        # one step forward in search
        db_outcomes, _, _ = pn.transit(e, obs, kg,
                                       use_action_space_bucketing=True,
                                       merge_aspace_batching_outcome=True)
        action_space, action_dist = db_outcomes[0]

        # incorporate r_prob
        (r_space, e_space), action_mask = action_space
        r_prob, _ = pn.transit_r(e, obs, kg)
        if kg.args.pretrain:
            r_space = torch.ones(len(r_space), kg.num_relations).long() * torch.arange(kg.num_relations)
            r_space = r_space.cuda()
            e_space = torch.ones(len(e_space), kg.num_relations).long().cuda()
            action_dist = r_prob
            action_dist = action_dist * kg.r_prob_mask + (1 - kg.r_prob_mask) * ops.EPSILON
            action_space = (r_space, e_space), action_mask
        else:
            # r_space_dist = torch.matmul(r_prob.unsqueeze(1), to_one_hot(r_space).transpose(1, 2))
            # action_dist = torch.mul(r_space_dist.squeeze(1), action_dist)
            r_space_dist = torch.gather(r_prob, 1, r_space)
            action_dist = torch.mul(r_space_dist, action_dist)
            action_dist = action_dist * action_mask + (1 - action_mask) * ops.EPSILON

        # => [batch_size*k, action_space_size]
        log_action_dist = log_action_prob.view(-1, 1) + ops.safe_log(action_dist)
        # [batch_size*k, action_space_size] => [batch_size*new_k]
        if t == num_steps - 1 and not kg.args.pretrain:
            action, log_action_prob, action_offset = top_k_answer_unique(
                log_action_dist, action_space)
        else:
            action, log_action_prob, action_offset = top_k_action(
                log_action_dist, action_space)
        if return_path_components:
            ops.rearrange_vector_list(log_action_probs, action_offset)
            log_action_probs.append(log_action_prob)
        pn.update_path(action, kg, offset=action_offset)
        seen_nodes = torch.cat(
            [seen_nodes[action_offset], action[1].unsqueeze(1)], dim=1)
        if kg.args.save_beam_search_paths:
            adjust_search_trace(search_trace, action_offset)
            search_trace.append(action)
        relation_trace = adjust_relation_trace(relation_trace, action_offset, action[0])

    output_beam_size = int(action[0].size()[0] / batch_size)
    # [batch_size*beam_size] => [batch_size, beam_size]
    beam_search_output = dict()
    beam_search_output['pred_e2s'] = action[1].view(batch_size, -1)
    beam_search_output['pred_e2_scores'] = log_action_prob.view(batch_size, -1)
    if kg.args.save_beam_search_paths:
        beam_search_output['search_traces'] = search_trace

    rule_score = torch.zeros(batch_size, output_beam_size).cuda().float()
    top_rules = pickle.load(open(kg.args.rule, 'rb'))
    for i in range(batch_size):
        for j in range(output_beam_size):
            path_ij = relation_trace[i * output_beam_size + j, 1:].cpu().numpy().tolist()
            if not int(init_q[i]) in top_rules.keys():
                rule_score[i][j] = 0.0
            elif tuple(path_ij) in top_rules[int(init_q[i])].keys():
                rule_score[i][j] = top_rules[int(init_q[i])][tuple(path_ij)]
                # print(tuple(relation_trace[i * output_beam_size + j, 1:]))
                # print(top_rules[int(init_q[i])].keys())
    # rule_score = (rule_score > 0).float()
    beam_search_output["rule_score"] = torch.mean(rule_score, 1)
    assert len(beam_search_output["rule_score"]) == batch_size

    if return_path_components:
        path_width = 10
        path_components_list = []
        for i in range(batch_size):
            p_c = []
            for k, log_action_prob in enumerate(log_action_probs):
                top_k_edge_labels = []
                for j in range(min(output_beam_size, path_width)):
                    ind = i * output_beam_size + j
                    r = kg.id2relation[int(search_trace[k + 1][0][ind])]
                    e = kg.id2entity[int(search_trace[k + 1][1][ind])]
                    if r.endswith('_inv'):
                        edge_label = ' <-{}- {} {}'.format(
                            r[:-4], e, float(log_action_probs[k][ind]))
                    else:
                        edge_label = ' -{}-> {} {}'.format(
                            r, e, float(log_action_probs[k][ind]))
                    top_k_edge_labels.append(edge_label)
                top_k_action_prob = log_action_prob[:path_width]
                e_name = kg.id2entity[int(
                    search_trace[1][0][i * output_beam_size])] if k == 0 else ''
                p_c.append((e_name, top_k_edge_labels,
                            var_to_numpy(top_k_action_prob)))
            path_components_list.append(p_c)
        beam_search_output['path_components_list'] = path_components_list

    return beam_search_output
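
# Illustrative sketch only, with made-up ids: how the rule_score lookup above scores a
# decoded relation path against a rule table of the form
# {query_relation_id: {(r1_id, r2_id, ...): confidence}}. All names and values here are
# hypothetical; the real table is loaded from kg.args.rule.
def _demo_rule_score():
    import torch
    top_rules = {7: {(3, 5): 0.8, (2,): 0.4}}        # hypothetical rule table
    relation_trace = torch.tensor([[0, 3, 5],        # dummy start relation + decoded path
                                   [0, 2, 6]])       # output_beam_size = 2 for one query
    init_q, output_beam_size = 7, 2
    rule_score = torch.zeros(1, output_beam_size)
    for j in range(output_beam_size):
        path_j = tuple(relation_trace[j, 1:].tolist())       # drop the dummy start relation
        rule_score[0][j] = top_rules.get(init_q, {}).get(path_j, 0.0)
    return rule_score.mean(1)                                # -> tensor([0.4]) here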