Example #1
    def rollout(self, e_s, q, e_t, num_steps, visualize_action_probs=False):
        """
        Perform multi-step rollout from the source entity conditioned on the query relation.
        :param pn: Policy network.
        :param e_s: (Variable:batch) source entity indices.
        :param q: (Variable:batch) query embedding.
        :param e_t: (Variable:batch) target entity indices.
        :param kg: Knowledge graph environment.
        :param num_steps: Number of rollout steps.
        :param visualize_action_probs: If set, save action probabilities for visualization.
        :return pred_e2: Target entities reached at the end of rollout.
        :return log_path_prob: Log probability of the sampled path.
        :return action_entropy: Entropy regularization term.
        """
        assert num_steps > 0
        kg, pn = self.kg, self.mdl

        # Initialization
        log_action_probs = []
        action_entropy = []
        r_s = int_fill_var_cuda(e_s.size(), kg.dummy_start_r)
        seen_nodes = int_fill_var_cuda(e_s.size(), kg.dummy_e).unsqueeze(1)
        path_components = []

        path_trace = [(r_s, e_s)]
        pn.initialize_path((r_s, e_s), kg)

        for t in range(num_steps):
            last_r, e = path_trace[-1]
            obs = [e_s, q, e_t, t == (num_steps - 1), last_r, seen_nodes]
            db_outcomes, inv_offset, policy_entropy = pn.transit(
                e,
                obs,
                kg,
                use_action_space_bucketing=self.use_action_space_bucketing)
            sample_outcome = self.sample_action(db_outcomes, inv_offset)
            action = sample_outcome["action_sample"]
            pn.update_path(action, kg)
            action_prob = sample_outcome["action_prob"]
            log_action_probs.append(ops.safe_log(action_prob))
            action_entropy.append(policy_entropy)
            seen_nodes = torch.cat([seen_nodes, e.unsqueeze(1)], dim=1)
            path_trace.append(action)

            if visualize_action_probs:
                top_k_action = sample_outcome["top_actions"]
                top_k_action_prob = sample_outcome["top_action_probs"]
                path_components.append((e, top_k_action, top_k_action_prob))

        pred_e2 = path_trace[-1][1]
        self.record_path_trace(path_trace)

        return {
            "pred_e2": pred_e2,
            "log_action_probs": log_action_probs,
            "action_entropy": action_entropy,
            "path_trace": path_trace,
            "path_components": path_components,
        }
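
The dictionary returned by rollout is typically consumed by a REINFORCE-style training step. Below is a minimal sketch of that step; the binary terminal reward, the discounting scheme, and the reinforce_loss name are illustrative assumptions, not part of the example above.

import torch

def reinforce_loss(rollout_output, e_t, gamma=1.0):
    # Hypothetical consumer of the rollout dict: a binary terminal reward
    # (did the walk end on the gold answer?) propagated back to every step.
    final_reward = (rollout_output["pred_e2"] == e_t).float()   # [batch]
    log_action_probs = rollout_output["log_action_probs"]       # list of [batch] tensors
    loss = 0.0
    for t, log_prob in enumerate(log_action_probs):
        discount = gamma ** (len(log_action_probs) - t - 1)
        loss = loss - discount * final_reward * log_prob
    return loss.mean()
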
Example #2
    def rollout(self, e_s, q, e_t, num_steps, visualize_action_probs=False):
        """
        Perform multi-step rollout from the source entity conditioned on the query relation.
        :param agent: Policy network.
        :param e_s: (Variable:batch) source entity indices.
        :param q: (Variable:batch) query relation indices.
        :param e_t: (Variable:batch) target entity indices.
        :param kg: Knowledge graph environment.
        :param num_steps: Number of rollout steps.
        :param visualize_action_probs: If set, save action probabilities for visualization.
        :return pred_e2: Target entities reached at the end of rollout.
        :return log_path_prob: Log probability of the sampled path.
        :return action_entropy: Entropy regularization term.
        """
        assert num_steps > 0
        kg, agent = self.kg, self.agent

        # Initialization
        log_action_probs = []
        action_entropy = []
        r_s = int_fill_var_cuda(e_s.size(), kg.dummy_start_r)
        seen_nodes = int_fill_var_cuda(e_s.size(), kg.dummy_e).unsqueeze(1)
        path_components = []

        path_trace: List[Action] = [Action(r_s, e_s)]
        agent.initialize_path(path_trace[0], kg)

        for t in range(num_steps):
            last_r, e = path_trace[-1]
            obs = Observation(e_s, q, e_t, t == (num_steps - 1), last_r,
                              seen_nodes)
            # e_t is needed to form the ground_truth_edge_mask
            ab: BucketActions = agent.transit(e, obs, kg,
                                              self.use_action_space_bucketing)
            action, action_prob = self.sample_action(ab)
            agent.update_path(action, kg)
            log_action_probs.append(ops.safe_log(action_prob))
            action_entropy.append(ab.entropy)
            seen_nodes = torch.cat([seen_nodes, e.unsqueeze(1)], dim=1)
            path_trace.append(action)

        pred_e2 = path_trace[-1][1]
        self.record_path_trace(path_trace)

        return {
            "pred_e2": pred_e2,
            "log_action_probs": log_action_probs,
            "action_entropy": action_entropy,
            "path_trace": path_trace,
            "path_components": path_components,
        }
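
This variant references Action, Observation, and BucketActions container types that are not shown in the snippet. A plausible minimal definition, consistent with how they are unpacked and accessed above; any field beyond those actually used is a guess:

from typing import NamedTuple
import torch

class Action(NamedTuple):
    rel: torch.Tensor   # [batch] relation indices; unpacks as (last_r, e)
    ent: torch.Tensor   # [batch] entity indices

class Observation(NamedTuple):
    source_entity: torch.Tensor
    query_relation: torch.Tensor
    target_entity: torch.Tensor
    last_step: bool
    last_r: torch.Tensor
    seen_nodes: torch.Tensor

class BucketActions(NamedTuple):
    action_spaces: list          # bucketed candidate actions
    action_dists: list           # bucketed action distributions
    inv_offset: torch.Tensor     # mapping back to the original batch order
    entropy: torch.Tensor        # the only field read directly above (ab.entropy)
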
Example #3
    def rollout_pretrain(self, e_s, q, e_t, num_steps):
        """
        Perform multi-step rollout from the source entity conditioned on the query relation.
        :param pn: Policy network.
        :param e_s: (Variable:batch) source entity indices.
        :param q: (Variable:batch) query relation indices.
        :param e_t: (Variable:batch) target entity indices.
        :param kg: Knowledge graph environment.
        :param num_steps: Number of rollout steps.
        :return log_action_probs: Log probability of the sampled path.
        :return action_entropy: Entropy regularization term.
        """
        assert (num_steps > 0)
        kg, pn = self.kg, self.mdl

        # Initialization
        log_action_probs = []
        action_entropy = []

        r_s = int_fill_var_cuda(e_s.size(), kg.dummy_start_r)

        path_trace = [r_s]
        pn.initialize_path((r_s, e_s), kg)
        for t in range(num_steps):

            last_r = path_trace[-1]
            obs = [e_s, q, e_t, t == (num_steps - 1), last_r, None]

            # relation selection
            r_prob, policy_entropy = pn.transit_r(None, obs, kg)
            sample_outcome_r = self.sample_relation(r_prob)
            action_r = sample_outcome_r['action_sample']
            action_prob = sample_outcome_r['action_prob']
            pn.update_path_r(action_r, kg)

            action_entropy.append(policy_entropy)

            log_action_probs.append(ops.safe_log(action_prob))
            path_trace.append(action_r)

        return {
            'log_action_probs': log_action_probs,
            'action_entropy': action_entropy,
            'path_trace': path_trace
        }
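
rollout_pretrain calls a sample_relation helper that is not part of the snippet. A minimal sketch consistent with how its outputs are consumed above; the helper body is an assumption:

import torch

def sample_relation(r_prob):
    # r_prob: [batch, num_relations] categorical distribution over relations.
    idx = torch.multinomial(r_prob, 1)                        # [batch, 1]
    return {
        "action_sample": idx.view(-1),                        # sampled relation ids
        "action_prob": torch.gather(r_prob, 1, idx).view(-1)  # their probabilities
    }
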
Example #4
def beam_search(pn,
                e_s,
                q,
                e_t,
                kg,
                num_steps,
                beam_size,
                return_path_components=False):
    """
    Beam search from source.

    :param pn: Policy network.
    :param e_s: (Variable:batch) source entity indices.
    :param q: (Variable:batch) query relation indices.
    :param e_t: (Variable:batch) target entity indices.
    :param kg: Knowledge graph environment.
    :param num_steps: Number of search steps.
    :param beam_size: Beam size used in search.
    :param return_path_components: If set, return all path components at the end of search.
    """
    assert (num_steps >= 1)
    batch_size = len(e_s)

    def top_k_action(log_action_dist, action_space):
        """
        Get top k actions.
            - k = beam_size if the beam size is smaller than or equal to the beam action space size
            - k = beam_action_space_size otherwise
        :param log_action_dist: [batch_size*k, action_space_size]
        :param action_space (r_space, e_space):
            r_space: [batch_size*k, action_space_size]
            e_space: [batch_size*k, action_space_size]
        :return:
            (next_r, next_e), action_prob, action_offset: [batch_size*new_k]
        """
        full_size = len(log_action_dist)
        assert (full_size % batch_size == 0)
        last_k = int(full_size / batch_size)

        (r_space, e_space), _ = action_space
        action_space_size = r_space.size()[1]
        # => [batch_size, k'*action_space_size]
        log_action_dist = log_action_dist.view(batch_size, -1)
        beam_action_space_size = log_action_dist.size()[1]
        k = min(beam_size, beam_action_space_size)
        # [batch_size, k]
        log_action_prob, action_ind = torch.topk(log_action_dist, k)
        next_r = ops.batch_lookup(r_space.view(batch_size, -1),
                                  action_ind).view(-1)
        next_e = ops.batch_lookup(e_space.view(batch_size, -1),
                                  action_ind).view(-1)
        # [batch_size, k] => [batch_size*k]
        log_action_prob = log_action_prob.view(-1)
        # *** compute parent offset
        # [batch_size, k]
        action_beam_offset = action_ind // action_space_size
        # [batch_size, 1]
        action_batch_offset = int_var_cuda(torch.arange(batch_size) *
                                           last_k).unsqueeze(1)
        # [batch_size, k] => [batch_size*k]
        action_offset = (action_batch_offset + action_beam_offset).view(-1)
        return (next_r, next_e), log_action_prob, action_offset

    def top_k_answer_unique(log_action_dist, action_space):
        """
        Get top k unique entities
            - k = beam_size if the beam size is smaller than or equal to the beam action space size
            - k = beam_action_space_size otherwise
        :param log_action_dist: [batch_size*beam_size, action_space_size]
        :param action_space (r_space, e_space):
            r_space: [batch_size*beam_size, action_space_size]
            e_space: [batch_size*beam_size, action_space_size]
        :return:
            (next_r, next_e), action_prob, action_offset: [batch_size*k]
        """
        full_size = len(log_action_dist)
        assert (full_size % batch_size == 0)
        last_k = int(full_size / batch_size)
        (r_space, e_space), _ = action_space
        action_space_size = r_space.size()[1]

        r_space = r_space.view(batch_size, -1)
        e_space = e_space.view(batch_size, -1)
        log_action_dist = log_action_dist.view(batch_size, -1)
        beam_action_space_size = log_action_dist.size()[1]
        assert (beam_action_space_size % action_space_size == 0)
        k = min(beam_size, beam_action_space_size)
        next_r_list, next_e_list = [], []
        log_action_prob_list = []
        action_offset_list = []
        for i in range(batch_size):
            log_action_dist_b = log_action_dist[i]
            r_space_b = r_space[i]
            e_space_b = e_space[i]
            unique_e_space_b = var_cuda(torch.unique(e_space_b.data.cpu()))
            unique_log_action_dist, unique_idx = unique_max(
                unique_e_space_b, e_space_b, log_action_dist_b)
            k_prime = min(len(unique_e_space_b), k)
            top_unique_log_action_dist, top_unique_idx2 = torch.topk(
                unique_log_action_dist, k_prime)
            top_unique_idx = unique_idx[top_unique_idx2]
            top_unique_beam_offset = top_unique_idx // action_space_size
            top_r = r_space_b[top_unique_idx]
            top_e = e_space_b[top_unique_idx]
            next_r_list.append(top_r.unsqueeze(0))
            next_e_list.append(top_e.unsqueeze(0))
            log_action_prob_list.append(
                top_unique_log_action_dist.unsqueeze(0))
            top_unique_batch_offset = i * last_k
            top_unique_action_offset = top_unique_batch_offset + top_unique_beam_offset
            action_offset_list.append(top_unique_action_offset.unsqueeze(0))
        next_r = ops.pad_and_cat(next_r_list,
                                 padding_value=kg.dummy_r).view(-1)
        next_e = ops.pad_and_cat(next_e_list,
                                 padding_value=kg.dummy_e).view(-1)
        log_action_prob = ops.pad_and_cat(log_action_prob_list,
                                          padding_value=-ops.HUGE_INT)
        action_offset = ops.pad_and_cat(action_offset_list, padding_value=-1)
        return (next_r,
                next_e), log_action_prob.view(-1), action_offset.view(-1)

    def adjust_search_trace(search_trace, action_offset):
        for i, (r, e) in enumerate(search_trace):
            new_r = r[action_offset]
            new_e = e[action_offset]
            search_trace[i] = (new_r, new_e)

    # Initialization
    r_s = int_fill_var_cuda(e_s.size(), kg.dummy_start_r)
    seen_nodes = int_fill_var_cuda(e_s.size(), kg.dummy_e).unsqueeze(1)
    init_action = (r_s, e_s)
    # path encoder
    emb_e_s = pn.initialize_path(init_action, q, kg, 'eval')
    if kg.args.save_beam_search_paths:
        search_trace = [(r_s, e_s)]

    # Run beam search for num_steps
    # [batch_size*k], k=1
    log_action_prob = zeros_var_cuda(batch_size)
    if return_path_components:
        log_action_probs = []

    action = init_action
    for t in range(num_steps):
        last_r, e = action
        assert (q.size() == e_s.size())
        assert (q.size() == e_t.size())
        assert (e.size()[0] % batch_size == 0)
        assert (q.size()[0] % batch_size == 0)
        k = int(e.size()[0] / batch_size)
        # => [batch_size*k]
        q = ops.tile_along_beam(q.view(batch_size, -1)[:, 0], k)
        e_s = ops.tile_along_beam(e_s.view(batch_size, -1)[:, 0], k)
        e_t = ops.tile_along_beam(e_t.view(batch_size, -1)[:, 0], k)
        obs = [
            e_s, emb_e_s, q, e_t, t == 0, t == (num_steps - 1), last_r,
            seen_nodes
        ]
        # one step forward in search
        db_outcomes, _, _ = pn.transit(e,
                                       obs,
                                       kg,
                                       'eval',
                                       use_action_space_bucketing=True,
                                       merge_aspace_batching_outcome=True)
        action_space, action_dist = db_outcomes[0]
        # => [batch_size*k, action_space_size]
        log_action_dist = log_action_prob.view(-1,
                                               1) + ops.safe_log(action_dist)
        # [batch_size*k, action_space_size] => [batch_size*new_k]
        if t == num_steps - 1:
            action, log_action_prob, action_offset = top_k_answer_unique(
                log_action_dist, action_space)
        else:
            action, log_action_prob, action_offset = top_k_action(
                log_action_dist, action_space)
        if return_path_components:
            ops.rearrange_vector_list(log_action_probs, action_offset)
            log_action_probs.append(log_action_prob)
        pn.update_path(action, kg, offset=action_offset)
        seen_nodes = torch.cat(
            [seen_nodes[action_offset], action[1].unsqueeze(1)], dim=1)
        if kg.args.save_beam_search_paths:
            adjust_search_trace(search_trace, action_offset)
            search_trace.append(action)

    output_beam_size = int(action[0].size()[0] / batch_size)
    # [batch_size*beam_size] => [batch_size, beam_size]
    beam_search_output = dict()
    beam_search_output['pred_e2s'] = action[1].view(batch_size, -1)
    beam_search_output['pred_e2_scores'] = log_action_prob.view(batch_size, -1)
    if kg.args.save_beam_search_paths:
        beam_search_output['search_traces'] = search_trace

    if return_path_components:
        path_width = 10
        path_components_list = []
        for i in range(batch_size):
            p_c = []
            for k, log_action_prob in enumerate(log_action_probs):
                top_k_edge_labels = []
                for j in range(min(output_beam_size, path_width)):
                    ind = i * output_beam_size + j
                    r = kg.id2relation[int(search_trace[k + 1][0][ind])]
                    e = kg.id2entity[int(search_trace[k + 1][1][ind])]
                    if r.endswith('_inv'):
                        edge_label = ' <-{}- {} {}'.format(
                            r[:-4], e, float(log_action_probs[k][ind]))
                    else:
                        edge_label = ' -{}-> {} {}'.format(
                            r, e, float(log_action_probs[k][ind]))
                    top_k_edge_labels.append(edge_label)
                top_k_action_prob = log_action_prob[i * output_beam_size:i * output_beam_size + path_width]
                e_name = kg.id2entity[int(
                    search_trace[1][0][i *
                                       output_beam_size])] if k == 0 else ''
                p_c.append((e_name, top_k_edge_labels,
                            var_to_numpy(top_k_action_prob)))
                print(e_name, top_k_edge_labels)
            path_components_list.append(p_c)
        beam_search_output['path_components_list'] = path_components_list

    return beam_search_output
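
beam_search leans on a couple of small tensor utilities from ops. Minimal reference implementations under the shapes used above; the originals may carry extra options, so treat these as sketches:

import torch

def batch_lookup(M, idx):
    # Row-wise lookup: M is [batch, n], idx is [batch, k] -> [batch, k].
    return torch.gather(M, 1, idx)

def tile_along_beam(v, beam_size):
    # Repeat each element of a [batch] vector beam_size times, preserving the
    # batch-major beam layout assumed by ind = i * output_beam_size + j.
    return v.unsqueeze(1).repeat(1, beam_size).view(-1)
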
Example #5
    def export_fuzzy_facts(self):
        """
        Export high confidence facts according to the model.
        """
        kg, mdl = self.kg, self.mdl

        # Gather all possible (subject, relation) and (relation, object) pairs
        sub_rel, rel_obj = {}, {}
        for file_name in ['raw.kb', 'train.triples', 'dev.triples', 'test.triples']:
            with open(os.path.join(self.data_dir, file_name)) as f:
                for line in f:
                    e1, e2, r = line.strip().split()
                    e1_id, e2_id, r_id = kg.triple2ids((e1, e2, r))
                    if e1_id not in sub_rel:
                        sub_rel[e1_id] = {}
                    if r_id not in sub_rel[e1_id]:
                        sub_rel[e1_id][r_id] = set()
                    sub_rel[e1_id][r_id].add(e2_id)
                    if e2_id not in rel_obj:
                        rel_obj[e2_id] = {}
                    if r_id not in rel_obj[e2_id]:
                        rel_obj[e2_id][r_id] = set()
                    rel_obj[e2_id][r_id].add(e1_id)

        o_f = open(os.path.join(self.data_dir, 'train.fuzzy.triples'), 'w')
        print('Saving fuzzy facts to {}'.format(os.path.join(self.data_dir, 'train.fuzzy.triples')))
        count = 0
        # Save recovered objects
        e1_ids, r_ids = [], []
        for e1_id in sub_rel:
            for r_id in sub_rel[e1_id]:
                e1_ids.append(e1_id)
                r_ids.append(r_id)
        for i in range(0, len(e1_ids), self.batch_size):
            e1_ids_b = e1_ids[i:i+self.batch_size]
            r_ids_b = r_ids[i:i+self.batch_size]
            e1 = var_cuda(torch.LongTensor(e1_ids_b))
            r = var_cuda(torch.LongTensor(r_ids_b))
            pred_scores = mdl.forward(e1, r, kg)
            for j in range(pred_scores.size(0)):
                for _e2 in range(pred_scores.size(1)):
                    if _e2 in [NO_OP_ENTITY_ID, DUMMY_ENTITY_ID]:
                        continue
                    if pred_scores[j, _e2] >= self.theta:
                        _e1 = int(e1[j])
                        _r = int(r[j])
                        o_f.write('{}\t{}\t{}\t{}\n'.format(
                            kg.id2entity[_e1], kg.id2entity[_e2], kg.id2relation[_r], float(pred_scores[j, _e2])))
                        count += 1
                        if count % 1000 == 0:
                            print('{} fuzzy facts exported'.format(count))
        # Save recovered subjects
        e2_ids, r_ids = [], []
        for e2_id in rel_obj:
            for r_id in rel_obj[e2_id]:
                e2_ids.append(e2_id)
                r_ids.append(r_id)
        e1 = int_var_cuda(torch.arange(kg.num_entities))
        for i in range(len(e2_ids)):
            r = int_fill_var_cuda(e1.size(), r_ids[i])
            e2 = int_fill_var_cuda(e1.size(), e2_ids[i])
            pred_scores = mdl.forward_fact(e1, r, e2, kg)
            for j in range(pred_scores.size(0)):  # one score per candidate subject
                if pred_scores[j] > self.theta:
                    _e1 = int(e1[j])
                    if _e1 in [NO_OP_ENTITY_ID, DUMMY_ENTITY_ID]:
                        continue
                    _r = int(r[j])
                    _e2 = int(e2[j])
                    if _e1 in sub_rel and _r in sub_rel[_e1]:
                        continue
                    o_f.write('{}\t{}\t{}\t{}\n'.format(
                        kg.id2entity[_e1], kg.id2entity[_e2], kg.id2relation[_r], float(pred_scores[j])))
                    count += 1
                    if count % 1000 == 0:
                        print('{} fuzzy facts exported'.format(count))
        o_f.close()
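
The exported file is a four-column TSV (subject, object, relation, score), mirroring the o_f.write format above. A small reader for consuming it downstream; the function name and the score filter are assumptions:

import os

def load_fuzzy_triples(data_dir, min_score=0.0):
    triples = []
    with open(os.path.join(data_dir, 'train.fuzzy.triples')) as f:
        for line in f:
            e1, e2, r, score = line.strip().split('\t')
            if float(score) >= min_score:
                triples.append((e1, e2, r, float(score)))
    return triples
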
Example #6
    def teacher_forcing_pretrain(self, e_s, q, e_t, num_steps):
        """
        Perform multi-step rollout from the source entity conditioned on the query relation.
        :param pn: Policy network.
        :param e_s: (Variable:batch) source entity indices.
        :param q: (Variable:batch) query relation indices.
        :param e_t: (Variable:batch) target entity indices.
        :param kg: Knowledge graph environment.
        :param num_steps: Number of rollout steps.
        :return log_action_probs: Log probability of the sampled path.
        :return action_entropy: Entropy regularization term.
        """
        assert (num_steps > 0)
        kg, pn = self.kg, self.mdl

        # Initialization
        log_action_probs = []
        action_entropy = []

        r_s = int_fill_var_cuda(e_s.size(), kg.dummy_start_r)
        path_label = []
        cnt = 0
        for q_b in q:
            if int(q_b) not in self.rel2rules.keys():
                path_label.append(
                    torch.randint(kg.num_relations,
                                  (num_steps, )).numpy().tolist())
                continue
            rules = self.rel2rules[int(q_b)]
            cnt += 1
            # uniform
            # sample_rule_id = torch.randint(len(rules.keys()), (1,)).item()

            # weighted by confidence score
            rule_dist_orig = torch.tensor(list(rules.values())).cuda()
            rand = torch.rand(rule_dist_orig.size())
            keep_mask = var_cuda(rand > self.action_dropout_rate).float()
            rule_dist = rule_dist_orig * keep_mask
            rule_sum = torch.sum(rule_dist, 0)
            is_zero = (rule_sum == 0).float()
            rule_dist = rule_dist + is_zero * rule_dist_orig
            sample_rule_id = torch.multinomial(rule_dist, 1).item()
            path_label.append(list(rules.keys())[sample_rule_id])
        path_label = torch.tensor(path_label).cuda()
        #print('rule_path_percentage:', cnt/len(path_label))

        path_trace = [r_s]
        pn.initialize_path((r_s, e_s), kg)
        for t in range(num_steps):
            last_r = path_trace[-1]
            obs = [e_s, q, e_t, t == (num_steps - 1), last_r, None]

            # relation selection
            r_prob, policy_entropy = pn.transit_r(None, obs, kg)
            action_r = path_label[:, t]
            action_prob = ops.batch_lookup(r_prob, action_r.view(-1, 1))
            pn.update_path_r(action_r, kg)

            action_entropy.append(policy_entropy)

            log_action_probs.append(ops.safe_log(action_prob))
            path_trace.append(action_r)

        return {
            'log_action_probs': log_action_probs,
            'action_entropy': action_entropy,
            'path_trace': path_trace
        }
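
The rule-sampling block above applies action dropout to the rule confidence scores and falls back to the original distribution when every rule happens to be dropped. A CPU-only toy run of that trick; the confidence values are made up:

import torch

rule_dist_orig = torch.tensor([0.6, 0.3, 0.1])     # rule confidence scores
keep_mask = (torch.rand(3) > 0.5).float()          # action dropout on rules
rule_dist = rule_dist_orig * keep_mask
if rule_dist.sum() == 0:                           # everything dropped:
    rule_dist = rule_dist_orig                     # fall back to the original scores
sample_rule_id = torch.multinomial(rule_dist, 1).item()  # multinomial self-normalizes
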
Example #7
    def rollout(self, e_s, q, e_t, num_steps, visualize_action_probs=False):
        # Change from the base rollout: compute the reward on the fly.

        assert (num_steps > 0)
        kg, pn = self.kg, self.mdl

        def reward_fun_binary(e1, r, e2, pred_e2, reward_binary):
            reward = (pred_e2 == e2).float()
            for i in range(e1.size()[0]):
                if reward_binary[i] and pred_e2[i] == kg.dummy_end_e:
                    reward[i] = 1
            return reward

        # Initialization
        log_action_probs = []
        action_entropy = []
        r_s = int_fill_var_cuda(e_s.size(), kg.dummy_start_r)
        seen_nodes = int_fill_var_cuda(e_s.size(), kg.dummy_e).unsqueeze(1)
        path_components = []
        reward = torch.zeros(e_s.size()).cuda()

        path_trace = [(r_s, e_s)]
        pn.initialize_path((r_s, e_s), kg)

        logr = open("traces.txt", "a")  # only used by the commented-out trace prints below
        for t in range(num_steps):
            last_r, e = path_trace[-1]
            obs = [e_s, q, e_t, t == (num_steps - 1), last_r, seen_nodes]
            db_outcomes, inv_offset, policy_entropy = pn.transit(
                e,
                obs,
                kg,
                use_action_space_bucketing=self.use_action_space_bucketing)
            sample_outcome = self.sample_action(db_outcomes, inv_offset)
            action = sample_outcome['action_sample']
            reward = reward + reward_fun_binary(e_s, q, e_t, action[1],
                                                reward)  # compute the reward on the fly
            torch.set_printoptions(threshold=5000)
            pn.update_path(action, kg)
            action_prob = sample_outcome['action_prob']
            log_action_probs.append(ops.safe_log(action_prob))
            action_entropy.append(policy_entropy)
            seen_nodes = torch.cat([seen_nodes, e.unsqueeze(1)], dim=1)
            path_trace.append(action)
            #print(action[0], file=logr)

            if visualize_action_probs:
                top_k_action = sample_outcome['top_actions']
                top_k_action_prob = sample_outcome['top_action_probs']
                path_components.append((e, top_k_action, top_k_action_prob))

        pred_e2 = path_trace[-1][1]  # In theory this should change, but in practice it seems unused and would slow down backprop...
        reward = self.reward_fun(e_s, q, e_t, pred_e2, reward)
        #print(reward, file=logr)
        self.record_path_trace(path_trace)

        return {
            'pred_e2': pred_e2,
            'log_action_probs': log_action_probs,
            'action_entropy': action_entropy,
            'path_trace': path_trace,
            'path_components': path_components,
            'reward': reward
        }
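
reward_fun_binary grants credit when a step lands on the target entity, and lets a rollout keep that credit if it then idles on the dummy end node. A CPU toy run of the same logic; dummy_end_e=0 is an assumed id:

import torch

def reward_fun_binary(e2, pred_e2, reward_binary, dummy_end_e=0):
    reward = (pred_e2 == e2).float()
    keep = reward_binary.bool() & (pred_e2 == dummy_end_e)
    reward[keep] = 1.0
    return reward

e_t = torch.tensor([5, 7])
reward = torch.zeros(2)
reward = reward + reward_fun_binary(e_t, torch.tensor([5, 3]), reward)  # row 0 hits the target
reward = reward + reward_fun_binary(e_t, torch.tensor([0, 7]), reward)  # row 0 idles, row 1 hits
print(reward)  # tensor([2., 1.])
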
Example #8
def beam_search_same(pn,
                     e_s,
                     q,
                     e_t,
                     kg,
                     num_steps,
                     beam_size,
                     return_path_components=False,
                     return_merge_scores=None,
                     same_start=False):
    """
    Beam search from source.

    :param pn: Policy network.
    :param e_s: (Variable:batch) source entity indices.
    :param q: (Variable:batch) query relation indices.
    :param e_t: (Variable:batch) target entity indices.
    :param kg: Knowledge graph environment.
    :param num_steps: Number of search steps.
    :param beam_size: Beam size used in search.
    :param return_path_components: If set, return all path components at the end of search.
    """
    assert (num_steps >= 1)
    batch_size = len(e_s)

    def top_k_action(log_action_dist, action_space, return_merge_scores=None):
        """
        Get top k actions.
            - k = beam_size if the beam size is smaller than or equal to the beam action space size
            - k = beam_action_space_size otherwise
        :param log_action_dist: [batch_size*k, action_space_size]
        :param action_space (r_space, e_space):
            r_space: [batch_size*k, action_space_size]
            e_space: [batch_size*k, action_space_size]
        :return:
            (next_r, next_e), action_prob, action_offset: [batch_size*new_k]
        """
        full_size = len(log_action_dist)
        assert (full_size % batch_size == 0)
        last_k = int(full_size / batch_size)

        (r_space, e_space), _ = action_space
        action_space_size = r_space.size()[1]
        # => [batch_size, k'*action_space_size]
        log_action_dist = log_action_dist.view(batch_size, -1)
        beam_action_space_size = log_action_dist.size()[1]
        k = min(beam_size, beam_action_space_size)

        if return_merge_scores is not None:
            if return_merge_scores == 'sum':
                reduce_method = torch.sum
            elif return_merge_scores == 'mean':
                reduce_method = torch.mean
            else:
                reduce_method = None

            all_action_ind = torch.LongTensor([
                range(beam_action_space_size)
                for _ in range(len(log_action_dist))
            ]).cuda(device=0)

            all_next_r = ops.batch_lookup(r_space.view(batch_size, -1),
                                          all_action_ind)
            all_next_e = ops.batch_lookup(e_space.view(batch_size, -1),
                                          all_action_ind)

            real_log_action_prob, real_next_e, real_action_ind, real_next_r = ops.merge_same(
                log_action_dist, all_next_e, all_next_r, method=reduce_method)

            next_e_list, next_r_list, action_ind_list, log_action_prob_list = [], [], [], []
            for i in range(batch_size):
                k_prime = min(len(real_log_action_prob[i]), k)
                top_log_prob, top_ind = torch.topk(real_log_action_prob[i],
                                                   k_prime)
                top_next_e, top_next_r, top_ind = real_next_e[i][
                    top_ind], real_next_r[i][top_ind], real_action_ind[i][
                        top_ind]
                next_e_list.append(top_next_e.unsqueeze(0))
                next_r_list.append(top_next_r.unsqueeze(0))
                action_ind_list.append(top_ind.unsqueeze(0))
                log_action_prob_list.append(top_log_prob.unsqueeze(0))

            next_r = ops.pad_and_cat(next_r_list,
                                     padding_value=kg.dummy_r).view(-1)
            next_e = ops.pad_and_cat(next_e_list,
                                     padding_value=kg.dummy_e).view(-1)
            log_action_prob = ops.pad_and_cat(log_action_prob_list,
                                              padding_value=0.0).view(-1)
            action_ind = ops.pad_and_cat(action_ind_list,
                                         padding_value=-1).view(-1)
        else:
            log_action_prob, action_ind = torch.topk(log_action_dist, k)
            next_r = ops.batch_lookup(r_space.view(batch_size, -1),
                                      action_ind).view(-1)
            next_e = ops.batch_lookup(e_space.view(batch_size, -1),
                                      action_ind).view(-1)

        # [batch_size, k] => [batch_size*k]
        log_action_prob = log_action_prob.view(-1)
        # *** compute parent offset
        # [batch_size, k]
        action_beam_offset = action_ind // action_space_size
        # [batch_size, 1]
        action_batch_offset = int_var_cuda(torch.arange(batch_size) *
                                           last_k).unsqueeze(1)
        # [batch_size, k] => [batch_size*k]
        action_offset = (action_batch_offset + action_beam_offset).view(-1)
        return (next_r, next_e), log_action_prob, action_offset

    def top_k_answer_unique(log_action_dist, action_space):
        """
        Get top k unique entities
            - k = beam_size if the beam size is smaller than or equal to the beam action space size
            - k = beam_action_space_size otherwise
        :param log_action_dist: [batch_size*beam_size, action_space_size] log probabilities
        :param action_space (r_space, e_space): candidate relations/entities
            r_space: [batch_size*beam_size, action_space_size]
            e_space: [batch_size*beam_size, action_space_size]
        :return:
            (next_r, next_e), action_prob, action_offset: [batch_size*k]
        """
        full_size = len(log_action_dist)
        assert (full_size % batch_size == 0)
        last_k = int(full_size / batch_size)
        (r_space, e_space), _ = action_space
        action_space_size = r_space.size()[1]

        r_space = r_space.view(batch_size, -1)
        e_space = e_space.view(batch_size, -1)
        log_action_dist = log_action_dist.view(batch_size, -1)
        beam_action_space_size = log_action_dist.size()[1]
        assert (beam_action_space_size % action_space_size == 0)
        k = min(beam_size, beam_action_space_size)
        next_r_list, next_e_list = [], []
        log_action_prob_list = []
        action_offset_list = []
        for i in range(batch_size):
            log_action_dist_b = log_action_dist[i]
            r_space_b = r_space[i]
            e_space_b = e_space[i]
            unique_e_space_b = var_cuda(torch.unique(e_space_b.data.cpu()))
            unique_log_action_dist, unique_idx = unique_max(
                unique_e_space_b, e_space_b, log_action_dist_b)
            k_prime = min(len(unique_e_space_b), k)
            top_unique_log_action_dist, top_unique_idx2 = torch.topk(
                unique_log_action_dist, k_prime)
            top_unique_idx = unique_idx[top_unique_idx2]
            top_unique_beam_offset = top_unique_idx // action_space_size
            top_r = r_space_b[top_unique_idx]
            top_e = e_space_b[top_unique_idx]
            next_r_list.append(top_r.unsqueeze(0))
            next_e_list.append(top_e.unsqueeze(0))
            log_action_prob_list.append(
                top_unique_log_action_dist.unsqueeze(0))
            top_unique_batch_offset = i * last_k
            top_unique_action_offset = top_unique_batch_offset + top_unique_beam_offset
            action_offset_list.append(top_unique_action_offset.unsqueeze(0))
        next_r = ops.pad_and_cat(next_r_list,
                                 padding_value=kg.dummy_r).view(-1)
        next_e = ops.pad_and_cat(next_e_list,
                                 padding_value=kg.dummy_e).view(-1)
        log_action_prob = ops.pad_and_cat(log_action_prob_list,
                                          padding_value=-ops.HUGE_INT)
        action_offset = ops.pad_and_cat(action_offset_list, padding_value=-1)
        return (next_r,
                next_e), log_action_prob.view(-1), action_offset.view(-1)

    def adjust_search_trace(search_trace, action_offset):
        for i, (r, e) in enumerate(search_trace):
            new_r = r[action_offset]
            new_e = e[action_offset]
            search_trace[i] = (new_r, new_e)

    # Initialization
    r_s = int_fill_var_cuda(e_s.size(), kg.dummy_start_r)  # WHY does the path start with a dummy relation?
    seen_nodes = int_fill_var_cuda(e_s.size(), kg.dummy_e).unsqueeze(1)  #TODO
    init_action = (r_s, e_s)
    e_s_abs = var_cuda(torch.LongTensor([kg.get_typeid(_) for _ in e_s]),
                       requires_grad=False)
    e_t_abs = var_cuda(torch.LongTensor([kg.get_typeid(_) for _ in e_t]),
                       requires_grad=False)
    # print("()()()es_size() e_s_abs.size():", e_s.size(), e_s_abs.size())
    seen_nodes_abs = int_fill_var_cuda(e_s_abs.size(), kg.dummy_e).unsqueeze(1)
    init_action_abs = (r_s, e_s_abs)

    # path encoder
    # pn.initialize_path(init_action, kg)
    pn.initialize_abs_path(init_action, init_action_abs, kg, same_start)
    if kg.args.save_paths_to_csv:
        search_trace = [(r_s, e_s)]

    # Run beam search for num_steps
    # [batch_size*k], k=1
    log_action_prob = zeros_var_cuda(batch_size)
    if return_path_components:
        log_action_probs = []

    action = init_action
    action_abs = init_action_abs

    for t in range(num_steps):
        last_r, e = action
        _, e_abs = action_abs  # action_abs carries the same relation as action
        assert (q.size() == e_s.size())
        assert (q.size() == e_t.size())
        assert (e.size()[0] % batch_size == 0)
        assert (q.size()[0] % batch_size == 0)
        k = int(e.size()[0] / batch_size)

        # => [batch_size*k]
        q = ops.tile_along_beam(q.view(batch_size, -1)[:, 0], k)
        e_s = ops.tile_along_beam(e_s.view(batch_size, -1)[:, 0], k)
        e_s_abs = ops.tile_along_beam(e_s_abs.view(batch_size, -1)[:, 0], k)
        e_t = ops.tile_along_beam(e_t.view(batch_size, -1)[:, 0], k)
        e_t_abs = ops.tile_along_beam(e_t_abs.view(batch_size, -1)[:, 0], k)
        obs = [e_s, q, e_t, t == (num_steps - 1), last_r, seen_nodes]
        obs_abs = [
            e_s_abs, q, e_t_abs, t == (num_steps - 1), last_r, seen_nodes_abs
        ]
        # one step forward in search
        # db_outcomes, _, _ = pn.transit_same(
        #     e, obs, kg, use_action_space_bucketing=False,
        #     merge_aspace_batching_outcome=True)  # TODO: step through get_action_space_in_buckets here
        db_outcomes, _, _ = pn.transit_same(
            e,
            obs,
            e_abs,
            obs_abs,
            kg,
            use_action_space_bucketing=False,
            merge_aspace_batching_outcome=True
        )  # TODO: step through get_action_space_in_buckets here
        action_space, action_dist = db_outcomes[0]
        # => [batch_size*k, action_space_size]
        log_action_dist = log_action_prob.view(-1,
                                               1) + ops.safe_log(action_dist)
        # [batch_size*k, action_space_size] => [batch_size*new_k]

        if t == num_steps - 1:
            if return_merge_scores is None:
                action, log_action_prob, action_offset = top_k_answer_unique(
                    log_action_dist, action_space)
            else:
                action, log_action_prob, action_offset = top_k_action(
                    log_action_dist, action_space, return_merge_scores)
        else:
            action, log_action_prob, action_offset = top_k_action(
                log_action_dist, action_space, None)

        #(next_r, next_e)
        action_abs = (action[0],
                      var_cuda(torch.LongTensor(
                          [kg.get_typeid(_) for _ in action[1]]),
                               requires_grad=False))

        if return_path_components:
            ops.rearrange_vector_list(log_action_probs, action_offset)
            log_action_probs.append(log_action_prob)
        pn.update_path_abs(action_abs, kg, offset=action_offset)
        seen_nodes = torch.cat(
            [seen_nodes[action_offset], action[1].unsqueeze(1)], dim=1)
        seen_nodes_abs = torch.cat(
            [seen_nodes_abs[action_offset], action_abs[1].unsqueeze(1)], dim=1)
        if kg.args.save_paths_to_csv:
            adjust_search_trace(search_trace, action_offset)
            search_trace.append(action)

    output_beam_size = int(action[0].size()[0] / batch_size)
    # [batch_size*beam_size] => [batch_size, beam_size]
    beam_search_output = dict()
    beam_search_output['pred_e2s'] = action[1].view(batch_size, -1)
    beam_search_output['pred_e2_scores'] = log_action_prob.view(batch_size, -1)

    if return_path_components:
        path_width = 10
        path_components_list = []
        for i in range(batch_size):
            p_c = []
            for k, log_action_prob in enumerate(log_action_probs):
                top_k_edge_labels = []
                for j in range(min(output_beam_size, path_width)):
                    ind = i * output_beam_size + j
                    r = kg.id2relation[int(search_trace[k + 1][0][ind])]
                    e = kg.id2entity[int(search_trace[k + 1][1][ind])]
                    if r.endswith('_inv'):
                        edge_label = '<-{}-{} {}'.format(
                            r[:-4], e, float(log_action_probs[k][ind]))
                    else:
                        edge_label = '-{}->{} {}'.format(
                            r, e, float(log_action_probs[k][ind]))
                    top_k_edge_labels.append(edge_label)
                top_k_action_prob = log_action_prob[i * output_beam_size:i * output_beam_size + path_width]
                e_name = kg.id2entity[int(
                    search_trace[1][0][i *
                                       output_beam_size])] if k == 0 else ''
                p_c.append((e_name, top_k_edge_labels,
                            var_to_numpy(top_k_action_prob)))
            path_components_list.append(p_c)
        beam_search_output['search_traces'] = search_trace
        SAME_ALL_PATHS.append(beam_search_output)

    return beam_search_output
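
beam_search_same merges duplicate successor entities via ops.merge_same before taking the top k. A per-row sketch of what such a merge could look like; the grouping key, the choice of representative relation, and whether the reduction runs over log or raw probabilities are all assumptions here:

import torch

def merge_same_row(log_probs, next_e, next_r, method=torch.sum):
    merged_p, merged_e, merged_ind, merged_r = [], [], [], []
    for v in torch.unique(next_e):
        idx = (next_e == v).nonzero(as_tuple=True)[0]
        merged_p.append(method(log_probs[idx]))   # reduce scores of duplicates
        merged_e.append(v)
        merged_ind.append(idx[0])                 # first occurrence as parent pointer
        merged_r.append(next_r[idx[0]])           # its relation as representative
    return (torch.stack(merged_p), torch.stack(merged_e),
            torch.stack(merged_ind), torch.stack(merged_r))
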
Example #9
def beam_search(pn,
                e_s,
                q,
                e_t,
                kg,
                num_steps,
                beam_size,
                return_path_components=False):
    """
    Beam search from source.

    :param pn: Policy network.
    :param e_s: (Variable:batch) source entity indices.
    :param q: (Variable:batch) query relation indices.
    :param e_t: (Variable:batch) target entity indices.
    :param kg: Knowledge graph environment.
    :param num_steps: Number of search steps.
    :param beam_size: Beam size used in search.
    :param return_path_components: If set, return all path components at the end of search.
    """
    assert (num_steps >= 1)
    batch_size = len(e_s)

    def top_k_action_r(log_action_dist):
        """
        Get top k relations.
            - k = beam_size if the beam size is smaller than or equal to the beam action space size
            - k = beam_action_space_size otherwise
        :param log_action_dist: [batch_size*k, action_space_size]
        :return:
            next_r, log_action_prob, action_offset: [batch_size*new_k]
        """
        full_size = len(log_action_dist)
        assert (full_size % batch_size == 0)
        last_k = int(full_size / batch_size)

        action_space_size = log_action_dist.size()[1]
        # => [batch_size, k'*action_space_size]
        log_action_dist = log_action_dist.view(batch_size, -1)
        beam_action_space_size = log_action_dist.size()[1]
        k = min(beam_size, beam_action_space_size)
        # [batch_size, k]
        log_action_prob, action_ind = torch.topk(log_action_dist, k)
        next_r = (action_ind % action_space_size).view(-1)
        # [batch_size, k] => [batch_size*k]
        log_action_prob = log_action_prob.view(-1)
        # compute parent offset
        # [batch_size, k]
        action_beam_offset = action_ind // action_space_size
        # [batch_size, 1]
        action_batch_offset = int_var_cuda(torch.arange(batch_size) *
                                           last_k).unsqueeze(1)
        # [batch_size, k] => [batch_size*k]
        action_offset = (action_batch_offset + action_beam_offset).view(-1)
        return next_r, log_action_prob, action_offset

    def top_k_action(log_action_dist, action_space):
        """
        Get top k actions.
            - k = beam_size if the beam size is smaller than or equal to the beam action space size
            - k = beam_action_space_size otherwise
        :param log_action_dist: [batch_size*k, action_space_size]
        :param action_space (r_space, e_space):
            r_space: [batch_size*k, action_space_size]
            e_space: [batch_size*k, action_space_size]
        :return:
            (next_r, next_e), log_action_prob, action_offset: [batch_size*new_k]
        """
        full_size = len(log_action_dist)
        assert (full_size % batch_size == 0)
        last_k = int(full_size / batch_size)

        (r_space, e_space), _ = action_space
        action_space_size = r_space.size()[1]
        # => [batch_size, k'*action_space_size]
        log_action_dist = log_action_dist.view(batch_size, -1)
        beam_action_space_size = log_action_dist.size()[1]
        k = min(beam_size, beam_action_space_size)
        # [batch_size, k]
        log_action_prob, action_ind = torch.topk(log_action_dist, k)
        next_r = ops.batch_lookup(r_space.view(batch_size, -1),
                                  action_ind).view(-1)
        next_e = ops.batch_lookup(e_space.view(batch_size, -1),
                                  action_ind).view(-1)
        # [batch_size, k] => [batch_size*k]
        log_action_prob = log_action_prob.view(-1)
        # compute parent offset
        # [batch_size, k]
        action_beam_offset = action_ind // action_space_size
        # [batch_size, 1]
        action_batch_offset = int_var_cuda(torch.arange(batch_size) *
                                           last_k).unsqueeze(1)
        # [batch_size, k] => [batch_size*k]
        action_offset = (action_batch_offset + action_beam_offset).view(-1)
        return (next_r, next_e), log_action_prob, action_offset

    def top_k_answer_unique(log_action_dist, action_space):
        """
        Get top k unique entities
            - k = beam_size if the beam size is smaller than or equal to the beam action space size
            - k = beam_action_space_size otherwise
        :param log_action_dist: [batch_size*beam_size, action_space_size]
        :param action_space (r_space, e_space):
            r_space: [batch_size*beam_size, action_space_size]
            e_space: [batch_size*beam_size, action_space_size]
        :return:
            (next_r, next_e), log_action_prob, action_offset: [batch_size*k]
        """
        full_size = len(log_action_dist)
        assert (full_size % batch_size == 0)
        last_k = int(full_size / batch_size)
        (r_space, e_space), _ = action_space
        action_space_size = r_space.size()[1]

        r_space = r_space.view(batch_size, -1)
        e_space = e_space.view(batch_size, -1)
        log_action_dist = log_action_dist.view(batch_size, -1)
        beam_action_space_size = log_action_dist.size()[1]
        assert (beam_action_space_size % action_space_size == 0)
        k = min(beam_size, beam_action_space_size)
        next_r_list, next_e_list = [], []
        log_action_prob_list = []
        action_offset_list = []
        for i in range(batch_size):
            log_action_dist_b = log_action_dist[i]
            r_space_b = r_space[i]
            e_space_b = e_space[i]
            unique_e_space_b = var_cuda(torch.unique(e_space_b.data.cpu()))
            unique_log_action_dist, unique_idx = unique_max(
                unique_e_space_b, e_space_b, log_action_dist_b)
            k_prime = min(len(unique_e_space_b), k)
            top_unique_log_action_dist, top_unique_idx2 = torch.topk(
                unique_log_action_dist, k_prime)
            top_unique_idx = unique_idx[top_unique_idx2]
            top_unique_beam_offset = top_unique_idx // action_space_size
            top_r = r_space_b[top_unique_idx]
            top_e = e_space_b[top_unique_idx]
            next_r_list.append(top_r.unsqueeze(0))
            next_e_list.append(top_e.unsqueeze(0))
            log_action_prob_list.append(
                top_unique_log_action_dist.unsqueeze(0))
            top_unique_batch_offset = i * last_k
            top_unique_action_offset = top_unique_batch_offset + top_unique_beam_offset
            action_offset_list.append(top_unique_action_offset.unsqueeze(0))
        next_r = ops.pad_and_cat(next_r_list,
                                 padding_value=kg.dummy_r).view(-1)
        next_e = ops.pad_and_cat(next_e_list,
                                 padding_value=kg.dummy_e).view(-1)
        log_action_prob = ops.pad_and_cat(log_action_prob_list,
                                          padding_value=-ops.HUGE_INT)
        action_offset = ops.pad_and_cat(action_offset_list, padding_value=-1)
        return (next_r,
                next_e), log_action_prob.view(-1), action_offset.view(-1)

    def adjust_search_trace(search_trace, action_offset):
        for i, (r, e) in enumerate(search_trace):
            new_r = r[action_offset]
            new_e = e[action_offset]
            search_trace[i] = (new_r, new_e)

    def to_one_hot(x):
        y = torch.eye(kg.num_relations).cuda()
        return y[x]

    def adjust_relation_trace(relation_trace, action_offset, relation):
        return torch.cat(
            [relation_trace[action_offset],
             relation.unsqueeze(1)], 1)

    # Initialization
    r_s = int_fill_var_cuda(e_s.size(), kg.dummy_start_r)
    seen_nodes = int_fill_var_cuda(e_s.size(), kg.dummy_e).unsqueeze(1)
    init_action = (r_s, e_s)

    # record original q
    init_q = q

    # path encoder
    pn.initialize_path(init_action, kg)
    if kg.args.save_beam_search_paths:
        search_trace = [(r_s, e_s)]
    relation_trace = r_s.unsqueeze(1)

    # Run beam search for num_steps
    # [batch_size*k], k=1
    log_action_prob = zeros_var_cuda(batch_size)
    if return_path_components:
        log_action_probs = []

    action = init_action

    # pretrain evaluation without traversing in the graph
    if kg.args.pretrain and kg.args.pretrain_out_of_graph:
        for t in range(num_steps):
            last_r, e = action
            assert (q.size() == e_s.size())
            assert (q.size() == e_t.size())
            assert (e.size()[0] % batch_size == 0)
            assert (q.size()[0] % batch_size == 0)
            k = int(e.size()[0] / batch_size)
            # => [batch_size*k]
            q = ops.tile_along_beam(q.view(batch_size, -1)[:, 0], k)
            e_s = ops.tile_along_beam(e_s.view(batch_size, -1)[:, 0], k)
            e_t = ops.tile_along_beam(e_t.view(batch_size, -1)[:, 0], k)
            obs = [e_s, q, e_t, t == (num_steps - 1), last_r, None]

            # one step forward in relation search
            r_prob, _ = pn.transit_r(e, obs, kg)

            # => [batch_size*k, relation_space_size]
            log_action_dist = log_action_prob.view(-1,
                                                   1) + ops.safe_log(r_prob)
            #action_space = torch.stack([torch.arange(kg.num_relations)] * len(r_prob), 0).long().cuda()

            action_r, log_action_prob, action_offset = top_k_action_r(
                log_action_dist)

            if return_path_components:
                ops.rearrange_vector_list(log_action_probs, action_offset)
                log_action_probs.append(log_action_prob)
            pn.update_path_r(action_r, kg, offset=action_offset)

            if kg.args.save_beam_search_paths:
                adjust_search_trace(search_trace, action_offset)
            relation_trace = adjust_relation_trace(relation_trace,
                                                   action_offset, action_r)

            # one step forward in entity search
            k = int(action_r.size()[0] / batch_size)
            # => [batch_size*k]
            q = ops.tile_along_beam(q.view(batch_size, -1)[:, 0], k)
            e_s = ops.tile_along_beam(e_s.view(batch_size, -1)[:, 0], k)
            e_t = ops.tile_along_beam(e_t.view(batch_size, -1)[:, 0], k)
            e = e[action_offset]
            action = (action_r, e)

            seen_nodes = torch.cat(
                [seen_nodes[action_offset], action[1].unsqueeze(1)], dim=1)
            if kg.args.save_beam_search_paths:
                search_trace.append(action)

        output_beam_size = int(action[0].size()[0] / batch_size)
        # [batch_size*beam_size] => [batch_size, beam_size]
        beam_search_output = dict()
        beam_search_output['pred_e2s'] = action[1].view(batch_size, -1)
        beam_search_output['pred_e2_scores'] = log_action_prob.view(
            batch_size, -1)
        if kg.args.save_beam_search_paths:
            beam_search_output['search_traces'] = search_trace
        rule_score = torch.zeros(batch_size, output_beam_size).cuda().float()
        with open(kg.args.rule, 'rb') as rule_f:
            top_rules = pickle.load(rule_f)
        for i in range(batch_size):
            for j in range(output_beam_size):
                path_ij = relation_trace[i * output_beam_size + j,
                                         1:].cpu().numpy().tolist()
                if not int(init_q[i]) in top_rules.keys():
                    rule_score[i][j] = 0.0
                elif tuple(path_ij) in top_rules[int(init_q[i])].keys():
                    rule_score[i][j] = top_rules[int(
                        init_q[i])][tuple(path_ij)]
                #print(tuple(relation_trace[i * output_beam_size + j, 1 : ]))
                #print(top_rules[int(init_q[i])].keys())
        # rule_score = (rule_score > 0).float()
        beam_search_output["rule_score"] = torch.mean(rule_score, 1)
        assert len(beam_search_output["rule_score"]) == batch_size
        return beam_search_output

    for t in range(num_steps):
        last_r, e = action
        assert (q.size() == e_s.size())
        assert (q.size() == e_t.size())
        assert (e.size()[0] % batch_size == 0)
        assert (q.size()[0] % batch_size == 0)
        k = int(e.size()[0] / batch_size)
        # => [batch_size*k]
        q = ops.tile_along_beam(q.view(batch_size, -1)[:, 0], k)
        e_s = ops.tile_along_beam(e_s.view(batch_size, -1)[:, 0], k)
        e_t = ops.tile_along_beam(e_t.view(batch_size, -1)[:, 0], k)
        obs = [e_s, q, e_t, t == (num_steps - 1), last_r, seen_nodes]
        # one step forward in search
        db_outcomes, _, _ = pn.transit(e,
                                       obs,
                                       kg,
                                       use_action_space_bucketing=True,
                                       merge_aspace_batching_outcome=True)
        action_space, action_dist = db_outcomes[0]

        # incorporate r_prob
        (r_space, e_space), action_mask = action_space
        r_prob, _ = pn.transit_r(e, obs, kg)
        if kg.args.pretrain:
            r_space = torch.ones(len(r_space),
                                 kg.num_relations).long() * torch.arange(
                                     kg.num_relations)
            r_space = r_space.cuda()
            e_space = torch.ones(len(e_space), kg.num_relations).long().cuda()
            action_dist = r_prob
            action_dist = action_dist * kg.r_prob_mask + (
                1 - kg.r_prob_mask) * ops.EPSILON
            action_space = (r_space, e_space), action_mask
        else:
            # r_space_dist = torch.matmul(r_prob.unsqueeze(1), to_one_hot(r_space).transpose(1, 2))
            # action_dist = torch.mul(r_space_dist.squeeze(1), action_dist)
            r_space_dist = torch.gather(r_prob, 1, r_space)
            action_dist = torch.mul(r_space_dist, action_dist)
            action_dist = action_dist * action_mask + (
                1 - action_mask) * ops.EPSILON

        # => [batch_size*k, action_space_size]
        log_action_dist = log_action_prob.view(-1,
                                               1) + ops.safe_log(action_dist)
        # [batch_size*k, action_space_size] => [batch_size*new_k]
        if t == num_steps - 1 and not kg.args.pretrain:
            action, log_action_prob, action_offset = top_k_answer_unique(
                log_action_dist, action_space)
        else:
            action, log_action_prob, action_offset = top_k_action(
                log_action_dist, action_space)
        if return_path_components:
            ops.rearrange_vector_list(log_action_probs, action_offset)
            log_action_probs.append(log_action_prob)
        pn.update_path(action, kg, offset=action_offset)
        seen_nodes = torch.cat(
            [seen_nodes[action_offset], action[1].unsqueeze(1)], dim=1)
        if kg.args.save_beam_search_paths:
            adjust_search_trace(search_trace, action_offset)
            search_trace.append(action)
        relation_trace = adjust_relation_trace(relation_trace, action_offset,
                                               action[0])

    output_beam_size = int(action[0].size()[0] / batch_size)
    # [batch_size*beam_size] => [batch_size, beam_size]
    beam_search_output = dict()
    beam_search_output['pred_e2s'] = action[1].view(batch_size, -1)
    beam_search_output['pred_e2_scores'] = log_action_prob.view(batch_size, -1)
    if kg.args.save_beam_search_paths:
        beam_search_output['search_traces'] = search_trace

    rule_score = torch.zeros(batch_size, output_beam_size).cuda().float()
    with open(kg.args.rule, 'rb') as rule_f:
        top_rules = pickle.load(rule_f)
    for i in range(batch_size):
        for j in range(output_beam_size):
            path_ij = relation_trace[i * output_beam_size + j,
                                     1:].cpu().numpy().tolist()
            if not int(init_q[i]) in top_rules.keys():
                rule_score[i][j] = 0.0
            elif tuple(path_ij) in top_rules[int(init_q[i])].keys():
                rule_score[i][j] = top_rules[int(init_q[i])][tuple(path_ij)]
            #print(tuple(relation_trace[i * output_beam_size + j, 1 : ]))
            #print(top_rules[int(init_q[i])].keys())
    # rule_score = (rule_score > 0).float()
    beam_search_output["rule_score"] = torch.mean(rule_score, 1)
    assert len(beam_search_output["rule_score"]) == batch_size

    if return_path_components:
        path_width = 10
        path_components_list = []
        for i in range(batch_size):
            p_c = []
            for k, log_action_prob in enumerate(log_action_probs):
                top_k_edge_labels = []
                for j in range(min(output_beam_size, path_width)):
                    ind = i * output_beam_size + j
                    r = kg.id2relation[int(search_trace[k + 1][0][ind])]
                    e = kg.id2entity[int(search_trace[k + 1][1][ind])]
                    if r.endswith('_inv'):
                        edge_label = ' <-{}- {} {}'.format(
                            r[:-4], e, float(log_action_probs[k][ind]))
                    else:
                        edge_label = ' -{}-> {} {}'.format(
                            r, e, float(log_action_probs[k][ind]))
                    top_k_edge_labels.append(edge_label)
                top_k_action_prob = log_action_prob[i * output_beam_size:i * output_beam_size + path_width]
                e_name = kg.id2entity[int(
                    search_trace[1][0][i *
                                       output_beam_size])] if k == 0 else ''
                p_c.append((e_name, top_k_edge_labels,
                            var_to_numpy(top_k_action_prob)))
            path_components_list.append(p_c)
        beam_search_output['path_components_list'] = path_components_list
    return beam_search_output
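
This last variant re-scores beams against mined rules loaded from kg.args.rule: a pickle mapping each query relation id to a dict from relation-path tuples to confidence scores, exactly the shape accessed by top_rules[int(init_q[i])][tuple(path_ij)]. A sketch of producing a compatible file; the ids and scores are illustrative:

import pickle

top_rules = {
    12: {(3, 7): 0.82, (5,): 0.41},   # rules for query relation 12
    30: {(7, 7, 2): 0.63},
}
with open('rules.pkl', 'wb') as f:
    pickle.dump(top_rules, f)
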