def sample(acsp: ActionSpace, action_dist):
    sample_action_dist = apply_action_dropout_mask(
        action_dist, acsp.action_mask)
    idx = torch.multinomial(sample_action_dist, 1, replacement=True)
    next_r = ops.batch_lookup(acsp.r_space, idx)
    next_e = ops.batch_lookup(acsp.e_space, idx)
    action_prob = ops.batch_lookup(action_dist, idx)
    return Action(next_r, next_e), action_prob
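Every example on this page funnels through `ops.batch_lookup`, which gathers one score per row given per-row column indices. A minimal sketch of what such a helper typically looks like, assuming a 2-D tensor and a [batch_size, 1] index tensor (the signatures in the individual projects may differ):

import torch

def batch_lookup(M, idx, vector_output=True):
    """Gather M[i, idx[i, 0]] for every row i.
    M:   [batch_size, n] score tensor
    idx: [batch_size, 1] long tensor of column indices
    """
    samples = torch.gather(M, 1, idx)   # [batch_size, 1]
    if vector_output:
        samples = samples.view(-1)      # [batch_size]
    return samples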
Example #2
File: pg.py Project: h-peng17/KGE
 def sample(action_space, action_dist):
     sample_outcome = {}
     ((r_space, e_space), action_mask) = action_space
     sample_action_dist = apply_action_dropout_mask(action_dist, action_mask)
     idx = torch.multinomial(sample_action_dist, 1, replacement=True)
     next_r = ops.batch_lookup(r_space, idx) # (batch_size, sample_size)
     next_e = ops.batch_lookup(e_space, idx)
     action_prob = ops.batch_lookup(action_dist, idx)
     sample_outcome['action_sample'] = (next_r, next_e)
     sample_outcome['action_prob'] = action_prob
     return sample_outcome
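Neither sample variant defines `apply_action_dropout_mask`. In MultiHopKG-style code it is usually a small nested helper that randomly masks actions and re-weights the distribution, along the lines of the sketch below (a sketch only, assuming `self.action_dropout_rate`, `var_cuda`, and `ops.EPSILON` as used in the later examples, with `self` captured from the enclosing method):

def apply_action_dropout_mask(action_dist, action_mask):
    # Randomly drop actions, keeping a tiny EPSILON mass on the dropped (but
    # valid) actions so the multinomial sampler never sees an all-zero row.
    if self.action_dropout_rate > 0:
        rand = torch.rand(action_dist.size())
        action_keep_mask = var_cuda(rand > self.action_dropout_rate).float()
        sample_action_dist = \
            action_dist * action_keep_mask + ops.EPSILON * (1 - action_keep_mask) * action_mask
        return sample_action_dist
    else:
        return action_dist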
Example #3
def forward_fact_oracle(e1, r, e2, kg):
    oracle = zeros_var_cuda([len(e1), kg.num_entities]).cuda()
    for i in range(len(e1)):
        _e1, _r = int(e1[i]), int(r[i])
        if _e1 in kg.all_object_vectors and _r in kg.all_object_vectors[_e1]:
            answer_vector = kg.all_object_vectors[_e1][_r]
            oracle[i][answer_vector] = 1
        else:
            raise ValueError("Query answer not found")
    oracle_e2 = ops.batch_lookup(oracle, e2.unsqueeze(1))
    return oracle_e2
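The oracle construction is easier to see on a toy graph. A hypothetical, self-contained sketch (ToyKG and the tensors below are made up, and `ops.batch_lookup` is replaced by a plain torch.gather):

import torch

# kg.all_object_vectors[e1][r] holds all entities o with (e1, r, o) in the graph.
class ToyKG:
    num_entities = 4
    all_object_vectors = {0: {5: torch.tensor([1, 3])}}   # answers of (e1=0, r=5)

kg = ToyKG()
e1, r, e2 = torch.tensor([0]), torch.tensor([5]), torch.tensor([3])
oracle = torch.zeros(len(e1), kg.num_entities)
oracle[0][kg.all_object_vectors[0][5]] = 1                # multi-hot answer row
oracle_e2 = torch.gather(oracle, 1, e2.unsqueeze(1))      # batch_lookup equivalent
print(oracle_e2)                                          # tensor([[1.]]) -> e2 is a correct answer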
Example #4
    def top_k_action(log_action_dist, action_space):
        """
        Get top k actions.
            - k = beam_size if the beam size is smaller than or equal to the beam action space size
            - k = beam_action_space_size otherwise
        :param log_action_dist: [batch_size*k, action_space_size]
        :param action_space ((r_space, e_space), action_mask):
            r_space: [batch_size*k, action_space_size]
            e_space: [batch_size*k, action_space_size]
        :return:
            (next_r, next_e), action_prob, action_offset: [batch_size*new_k]
        """
        full_size = len(log_action_dist)
        assert (full_size % batch_size == 0)
        last_k = int(full_size / batch_size)

        (r_space, e_space), _ = action_space
        action_space_size = r_space.size()[1]
        # => [batch_size, k'*action_space_size]
        log_action_dist = log_action_dist.view(batch_size, -1)
        beam_action_space_size = log_action_dist.size()[1]
        k = min(beam_size, beam_action_space_size)
        # [batch_size, k]
        log_action_prob, action_ind = torch.topk(log_action_dist, k)
        next_r = ops.batch_lookup(r_space.view(batch_size, -1),
                                  action_ind).view(-1)  # [batch_size*k]
        next_e = ops.batch_lookup(e_space.view(batch_size, -1),
                                  action_ind).view(-1)  # [batch_size*k]
        # [batch_size, k] => [batch_size*k]
        log_action_prob = log_action_prob.view(-1)  # [batch_size*k]
        # *** compute parent offset
        # [batch_size, k]
        action_beam_offset = action_ind // action_space_size  # integer division: parent beam per action
        # [batch_size, 1]
        action_batch_offset = int_var_cuda(torch.arange(batch_size) *
                                           last_k).unsqueeze(1)
        # [batch_size, k] => [batch_size*k]
        action_offset = (action_batch_offset + action_beam_offset).view(-1)
        return (next_r,
                next_e), log_action_prob, action_offset  # [batch_size*k]
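The only non-obvious step in top_k_action is the parent-offset bookkeeping at the end, which maps each selected action back to the row of the flattened [batch_size*last_k, ...] tensors it was expanded from. A small numeric sketch with made-up values:

import torch

# batch_size=2, each example currently carries last_k=3 beams,
# each beam exposes action_space_size=4 candidate actions, and we keep k=2.
batch_size, last_k, action_space_size = 2, 3, 4
action_ind = torch.tensor([[1, 6], [9, 2]])               # top-k flat indices per example
action_beam_offset = action_ind // action_space_size      # which of the last_k beams
action_batch_offset = (torch.arange(batch_size) * last_k).unsqueeze(1)
action_offset = (action_batch_offset + action_beam_offset).view(-1)
print(action_offset)                                      # tensor([0, 1, 5, 3]) -> parent beam rows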
Example #5
        def apply_action_dropout_mask(r_space, action_dist, action_mask):
            # `offset`, `r_prob`, and `self` are not defined here; this is a nested
            # helper that captures them from the enclosing method's scope.
            bucket_size = len(r_space)
            batch_idx = torch.tensor(offset[self.ct:self.ct +
                                            bucket_size]).cuda()
            r_prob_b = r_prob[batch_idx]

            if not self.pretrain:
                uni_mask = torch.zeros(r_prob_b.shape).cuda()
                uni_mask = torch.scatter(uni_mask, 1, r_space,
                                         torch.ones(
                                             r_space.shape).cuda()).cuda()
                r_prob_b = r_prob_b * uni_mask
            r_prob_sum = torch.sum(r_prob_b, 1)
            is_zero = (r_prob_sum == 0).float().unsqueeze(1)
            r_prob_b = r_prob_b + is_zero * torch.ones(
                r_prob_b[0].size()).cuda()

            r_chosen_b = torch.multinomial(r_prob_b, 1, replacement=True)

            r_prob_chosen_b = ops.batch_lookup(r_prob_b,
                                               r_chosen_b).unsqueeze(1)

            action_mute_mask = (r_space == r_chosen_b).float()
            if self.pretrain:
                action_dist_muted = action_mute_mask * r_prob_chosen_b
            else:
                action_dist_muted = action_dist * action_mute_mask * r_prob_chosen_b

            dist_sum = torch.sum(action_dist_muted, 1)
            is_zero = (dist_sum == 0).float().unsqueeze(1)
            if self.pretrain:
                uniform_dist = torch.ones(action_dist[0].size()).float().cuda()
                action_dist_muted = action_dist_muted + is_zero * uniform_dist
            else:
                action_dist_muted = action_dist_muted + is_zero * action_dist
            self.ct += bucket_size
            new_action_dist = action_dist_muted

            if self.action_dropout_rate > 0:
                rand = torch.rand(action_dist.size())
                action_keep_mask = var_cuda(
                    rand > self.action_dropout_rate).float()
                sample_action_dist = new_action_dist * action_keep_mask + ops.EPSILON * (
                    1 - action_keep_mask) * action_mask
                # if dropout to 0, keep original value
                dist_sum = torch.sum(sample_action_dist, 1)
                is_zero = (dist_sum == 0).float().unsqueeze(1)
                sample_action_dist = sample_action_dist + is_zero * new_action_dist
                return sample_action_dist
            else:
                return new_action_dist
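The scatter step above builds a per-row 0/1 mask over relations that actually occur in r_space. A standalone toy sketch of just that step (shapes and values are made up):

import torch

num_relations = 6
r_space = torch.tensor([[1, 3, 3], [0, 2, 5]])     # relations available in each row
r_prob = torch.rand(2, num_relations)              # policy's relation distribution
uni_mask = torch.zeros(2, num_relations).scatter(
    1, r_space, torch.ones(r_space.shape))         # 1 where the relation is reachable
masked_r_prob = r_prob * uni_mask                  # zero out unreachable relations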
Example #6
    def sample_relation(self, r_dist, inv_offset=None):
        """
        Sample a relation based on the current policy.
        :param r_dist: Relation probability distribution, [batch_size, num_relations].
        :param inv_offset: Indexes for restoring the original order in a batch.
        :return sample_outcome: Dict with 'action_sample' (sampled relation indices)
            and 'action_prob' (probabilities of the sampled relations).
        """

        if inv_offset is not None:
            raise NotImplementedError('Relation bucket not implemented!')
        else:
            sample_dist = r_dist
            next_r = torch.multinomial(sample_dist, 1, replacement=True)
            r_prob = ops.batch_lookup(sample_dist, next_r)
            sample_outcome = {}
            sample_outcome['action_sample'] = next_r.view(-1)
            sample_outcome['action_prob'] = r_prob.view(-1)

        return sample_outcome
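A hypothetical call of sample_relation (the `agent` object and the shapes below are assumptions, not from the source project):

import torch

r_dist = torch.softmax(torch.randn(4, 10), dim=1)   # [batch_size=4, num_relations=10]
out = agent.sample_relation(r_dist)                  # `agent` exposes the method above
next_r = out['action_sample']                        # [4] sampled relation indices
r_prob = out['action_prob']                          # [4] probabilities of those relations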
Example #7
    def top_k_action(log_action_dist, action_space, return_merge_scores=None):
        """
        Get top k actions.
            - k = beam_size if the beam size is smaller than or equal to the beam action space size
            - k = beam_action_space_size otherwise
        :param log_action_dist: [batch_size*k, action_space_size]
        :param action_space ((r_space, e_space), action_mask):
            r_space: [batch_size*k, action_space_size]
            e_space: [batch_size*k, action_space_size]
        :param return_merge_scores: If set to 'sum' or 'mean', duplicate successor entities
            are merged and their scores reduced with the named method before taking the top k.
        :return:
            (next_r, next_e), action_prob, action_offset: [batch_size*new_k]
        """
        full_size = len(log_action_dist)
        assert (full_size % batch_size == 0)
        last_k = int(full_size / batch_size)

        (r_space, e_space), _ = action_space
        action_space_size = r_space.size()[1]
        # => [batch_size, k'*action_space_size]
        log_action_dist = log_action_dist.view(batch_size, -1)
        beam_action_space_size = log_action_dist.size()[1]
        k = min(beam_size, beam_action_space_size)

        if return_merge_scores is not None:
            if return_merge_scores == 'sum':
                reduce_method = torch.sum
            elif return_merge_scores == 'mean':
                reduce_method = torch.mean
            else:
                reduce_method = None

            all_action_ind = torch.LongTensor([range(beam_action_space_size) for _ in range(len(log_action_dist))]).cuda()
            all_next_r = ops.batch_lookup(r_space.view(batch_size, -1), all_action_ind)
            all_next_e = ops.batch_lookup(e_space.view(batch_size, -1), all_action_ind)

            real_log_action_prob, real_next_e, real_action_ind, real_next_r = ops.merge_same(log_action_dist, all_next_e, all_next_r, method=reduce_method)

            next_e_list, next_r_list, action_ind_list, log_action_prob_list = [], [], [], []
            for i in range(batch_size):
                k_prime = min(len(real_log_action_prob[i]), k)
                top_log_prob, top_ind = torch.topk(real_log_action_prob[i], k_prime)
                top_next_e, top_next_r, top_ind = real_next_e[i][top_ind], real_next_r[i][top_ind], real_action_ind[i][top_ind]
                next_e_list.append(top_next_e.unsqueeze(0))
                next_r_list.append(top_next_r.unsqueeze(0))
                action_ind_list.append(top_ind.unsqueeze(0))
                log_action_prob_list.append(top_log_prob.unsqueeze(0))

            next_r = ops.pad_and_cat(next_r_list, padding_value=kg.dummy_r).view(-1)
            next_e = ops.pad_and_cat(next_e_list, padding_value=kg.dummy_e).view(-1)
            log_action_prob = ops.pad_and_cat(log_action_prob_list, padding_value=0.0).view(-1)
            action_ind = ops.pad_and_cat(action_ind_list, padding_value=-1).view(-1)

        else:
            log_action_prob, action_ind = torch.topk(log_action_dist, k)
            next_r = ops.batch_lookup(r_space.view(batch_size, -1), action_ind).view(-1)
            next_e = ops.batch_lookup(e_space.view(batch_size, -1), action_ind).view(-1)

        # [batch_size, k] => [batch_size*k]
        log_action_prob = log_action_prob.view(-1)
        # *** compute parent offset
        # [batch_size, k]
        action_beam_offset = action_ind // action_space_size  # integer division: parent beam per action
        # [batch_size, 1]
        action_batch_offset = int_var_cuda(torch.arange(batch_size) * last_k).unsqueeze(1)
        # [batch_size, k] => [batch_size*k]
        action_offset = (action_batch_offset + action_beam_offset).view(-1)
        return (next_r, next_e), log_action_prob, action_offset
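`ops.merge_same` is not shown on this page; from its call site it presumably collapses duplicate successor entities per row and combines their scores with the chosen reduce_method. A simplified, standalone per-row sketch of that idea (an assumption, not the project's implementation):

import torch

def merge_same_row(log_probs, next_e, next_r, reduce_method=torch.sum):
    # Collapse duplicate successor entities in one beam row and combine their
    # scores with reduce_method, keeping one representative relation/index each.
    merged_scores, merged_e, merged_ind, merged_r = [], [], [], []
    for e in next_e.unique():
        sel = (next_e == e).nonzero(as_tuple=True)[0]
        merged_scores.append(reduce_method(log_probs[sel]))
        merged_e.append(e)
        merged_r.append(next_r[sel[0]])            # first matching relation as representative
        merged_ind.append(sel[0])                  # original action index of that representative
    return (torch.stack(merged_scores), torch.stack(merged_e),
            torch.stack(merged_ind), torch.stack(merged_r))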
Example #8
    def teacher_forcing_pretrain(self, e_s, q, e_t, num_steps):
        """
        Perform a multi-step rollout from the source entity conditioned on the query
        relation, teacher-forcing the relation choices along a sampled rule path.
        :param e_s: (Variable:batch) source entity indices.
        :param q: (Variable:batch) query relation indices.
        :param e_t: (Variable:batch) target entity indices.
        :param num_steps: Number of rollout steps.
        :return: Dict with 'log_action_probs' (log probabilities of the teacher-forced
            relations), 'action_entropy' (entropy regularization terms), and
            'path_trace' (the relation chosen at each step).
        """
        assert (num_steps > 0)
        kg, pn = self.kg, self.mdl

        # Initialization
        log_action_probs = []
        action_entropy = []

        r_s = int_fill_var_cuda(e_s.size(), kg.dummy_start_r)
        path_label = []
        cnt = 0
        for q_b in q:
            if int(q_b) not in self.rel2rules.keys():
                path_label.append(
                    torch.randint(kg.num_relations,
                                  (num_steps, )).numpy().tolist())
                continue
            rules = self.rel2rules[int(q_b)]
            cnt += 1
            # uniform
            # sample_rule_id = torch.randint(len(rules.keys()), (1,)).item()

            # weighted by confidence score
            rule_dist_orig = torch.tensor(list(rules.values())).cuda()
            rand = torch.rand(rule_dist_orig.size())
            keep_mask = var_cuda(rand > self.action_dropout_rate).float()
            rule_dist = rule_dist_orig * keep_mask
            rule_sum = torch.sum(rule_dist, 0)
            is_zero = (rule_sum == 0).float()
            rule_dist = rule_dist + is_zero * rule_dist_orig
            sample_rule_id = torch.multinomial(rule_dist, 1).item()
            path_label.append(list(rules.keys())[sample_rule_id])
        path_label = torch.tensor(path_label).cuda()

        path_trace = [r_s]
        pn.initialize_path((r_s, e_s), kg)
        for t in range(num_steps):
            last_r = path_trace[-1]
            obs = [e_s, q, e_t, t == (num_steps - 1), last_r, None]

            # relation selection
            r_prob, policy_entropy = pn.transit_r(None, obs, kg)
            action_r = path_label[:, t]
            action_prob = ops.batch_lookup(r_prob, action_r.view(-1, 1))
            pn.update_path_r(action_r, kg)

            action_entropy.append(policy_entropy)

            log_action_probs.append(ops.safe_log(action_prob))
            path_trace.append(action_r)

        return {
            'log_action_probs': log_action_probs,
            'action_entropy': action_entropy,
            'path_trace': path_trace
        }
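The confidence-weighted rule sampling inside the loop above can be exercised in isolation. A toy sketch with a made-up rules dictionary (rule path -> confidence score; the keys and dropout rate here are hypothetical):

import torch

rules = {(3, 7): 0.9, (1, 4): 0.4, (2, 2): 0.1}    # hypothetical rule path -> confidence
action_dropout_rate = 0.3

rule_dist_orig = torch.tensor(list(rules.values()))
keep_mask = (torch.rand(rule_dist_orig.size()) > action_dropout_rate).float()
rule_dist = rule_dist_orig * keep_mask
# If dropout zeroed every rule, fall back to the original confidence scores.
rule_dist = rule_dist + (rule_dist.sum() == 0).float() * rule_dist_orig
sample_rule_id = torch.multinomial(rule_dist, 1).item()
print(list(rules.keys())[sample_rule_id])          # e.g. (3, 7)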