Code example #1
File: pn.py  Project: h-peng17/KGE
 def policy_nn_fun(X2, action_space):
     (r_space, e_space), action_mask = action_space # (batch_size, action_space_size)
     A = self.get_action_embedding((r_space, e_space), kg) # (batch_size, action_space_size, action_dim)
     action_dist = F.softmax(
         torch.squeeze(A @ torch.unsqueeze(X2, 2), 2) - (1 - action_mask) * ops.HUGE_INT, dim=-1)
     # action_dist = ops.weighted_softmax(torch.squeeze(A @ torch.unsqueeze(X2, 2), 2), action_mask)
     return action_dist, ops.entropy(action_dist) # action_dist (batch_size, action_space_size)
Code example #2
File: graph_walk_agent.py  Project: patzaa/MultiHopKG
 def policy_nn_fun(X2, acs: ActionSpace):
     A = self.get_action_embedding(Action(acs.r_space, acs.e_space), kg)
     action_dist = F.softmax(
         torch.squeeze(A @ torch.unsqueeze(X2, 2), 2) -
         (1 - acs.action_mask) * ops.HUGE_INT,
         dim=-1,
     )
     # action_dist = ops.weighted_softmax(torch.squeeze(A @ torch.unsqueeze(X2, 2), 2), acs.action_mask)
     return action_dist, ops.entropy(action_dist)
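The masked softmax shared by these snippets can be exercised on dummy tensors. The sketch below is not project code: masked_action_softmax is a made-up helper, and HUGE_INT is assumed to be a very large constant (the same 1e31 that example #4 uses for clipping), standing in for ops.HUGE_INT.

import torch
import torch.nn.functional as F

HUGE_INT = 1e31  # assumed stand-in for ops.HUGE_INT

def masked_action_softmax(A, X2, action_mask):
    # A:           (batch_size, action_space_size, action_dim) action embeddings
    # X2:          (batch_size, action_dim) policy-network output
    # action_mask: (batch_size, action_space_size), 1 for valid actions, 0 for padding
    scores = torch.squeeze(A @ torch.unsqueeze(X2, 2), 2)  # (batch_size, action_space_size)
    scores = scores - (1 - action_mask) * HUGE_INT         # push padded actions towards -inf
    return F.softmax(scores, dim=-1)

batch_size, action_space_size, action_dim = 2, 4, 8
A = torch.randn(batch_size, action_space_size, action_dim)
X2 = torch.randn(batch_size, action_dim)
action_mask = torch.tensor([[1., 1., 0., 0.],
                            [1., 1., 1., 0.]])
dist = masked_action_softmax(A, X2, action_mask)
print(dist.sum(dim=-1))  # each row sums to 1; masked positions receive ~0 probability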
Code example #3
 def policy_nn_fun(X2, action_space):
     (r_space, e_space), action_mask = action_space
     # ================= newly added ===================
     # A = self.get_action_embedding((r_space, e_space), kg)
     A = self.get_agg_action_embedding((r_space, e_space), kg)
     # ================= newly added ===================
     action_dist = F.softmax(
         torch.squeeze(A @ torch.unsqueeze(X2, 2), 2) -
         (1 - action_mask) * ops.HUGE_INT,
         dim=-1,
     )
     # action_dist = ops.weighted_softmax(torch.squeeze(A @ torch.unsqueeze(X2, 2), 2), action_mask)
     return action_dist, ops.entropy(action_dist)
Code example #4
File: pn.py  Project: nwpusunyue/KG-RuleGuider
    def transit_r(self, e, obs, kg):
        """
        Compute the next relation distribution based on
            (a) query relation
            (b) relation history representation
        :param e: agent location (node) at step t.
        :param obs: agent observation at step t.
            e_s: source node
            q: query relation
            e_t: target node
            last_step: If set, the agent is carrying out the last step.
            last_r: label of edge traversed in the previous step
            seen_nodes: nodes seen on the paths
        :param kg: Knowledge graph environment.

        """
        e_s, q, e_t, last_step, last_r, seen_nodes = obs


        # Representation of the current state (current node and other observations)
        Q = kg.get_relation_embeddings(q)
        H_r = self.path_r[-1][0][-1, :, :]
        X_r = torch.cat([H_r, Q], dim=-1)


        policy_net_hidden = F.relu(self.LN1(X_r))
        policy_net_hidden = self.LN1Dropout(policy_net_hidden)
        latent_state_emb = self.LN2(policy_net_hidden)
        latent_state_emb = self.LN2Dropout(latent_state_emb)
        
        all_relation_tensor = kg.get_all_relation_embeddings()
        logit = torch.mm(latent_state_emb, all_relation_tensor.transpose(0, 1))
        # clip
        upper = 1e31 * torch.ones(logit.size()).cuda()
        logit = torch.where(logit < upper, logit, upper)

        max_logit, _ = torch.max(logit, 1, keepdim=True)
        max_logit.detach_()
        # the_log_of_policy_prob = torch.nn.functional.log_softmax(logit - max_logit, dim=1)
        # r_prob = torch.nn.functional.softmax(logit - max_logit, dim=1)
        # r_prob = r_prob * kg.r_prob_mask
        r_prob = torch.nn.functional.softmax(logit - max_logit - (1 - kg.r_prob_mask) * ops.HUGE_INT, dim=1)
        r_entropy = ops.entropy(r_prob)
        # r_prob = r_prob.unsqueeze(1)

        return r_prob, r_entropy
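The tail of transit_r combines two steps: subtracting the detached row-wise maximum of the logits for numerical stability, and masking out disallowed relations via kg.r_prob_mask before the softmax. A minimal reproduction on dummy tensors, assuming r_prob_mask is a 0/1 matrix of shape (batch_size, num_relations) and reusing a large constant in place of ops.HUGE_INT:

import torch
import torch.nn.functional as F

HUGE_INT = 1e31  # assumed stand-in for ops.HUGE_INT

def masked_relation_softmax(logit, r_prob_mask):
    # logit:       (batch_size, num_relations) raw scores from the policy MLP
    # r_prob_mask: (batch_size, num_relations), 1 for relations the agent may choose
    max_logit, _ = torch.max(logit, 1, keepdim=True)
    max_logit = max_logit.detach()  # no gradient through the max, as in the snippet
    return F.softmax(logit - max_logit - (1 - r_prob_mask) * HUGE_INT, dim=1)

logit = torch.randn(3, 5) * 10
r_prob_mask = torch.tensor([1., 1., 1., 0., 0.]).expand(3, 5)
r_prob = masked_relation_softmax(logit, r_prob_mask)
print(r_prob)  # masked relations end up with ~0 probability in every row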
Code example #5
        def policy_nn_fun_fusion(e, X2, action_space, fn, fn_kg):
            (r_space, e_space), action_mask = action_space
            dim1, dim2 = e_space.size()
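            # replicate the current entity tensor e and reshape it to (dim1, dim2)
            # so it lines up with r_space / e_space for per-action triple scoring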
            e1 = e
            for i in range(dim2 - 1):
                e1 = torch.cat((e1, e), 0)
            e1 = e1.view(dim1, dim2)
            e2 = e_space
            r = r_space
            em_score_can = []
            for i in range(dim1):
                e1_can = torch.squeeze(e1[i])
                r_can = torch.squeeze(r[i])
                e2_can = torch.squeeze(e2[i])
                em_score_can.append(
                    torch.squeeze(fn.forward_fact(e1_can, r_can, e2_can,
                                                  fn_kg)))

            em_score = em_score_can[0]
            for i in range(1, dim1):
                em_score = torch.cat((em_score, em_score_can[i]), 0)
            A = self.get_action_embedding((r_space, e_space), kg)
            # X2_2=torch.unsqueeze(X2, 2)
            # X2_3=A @ X2_2
            # X2_4=torch.squeeze(X2_3, 2)
            # X2_5=X2_4-(1 - action_mask) * ops.HUGE_INT
            # action_dist=F.softmax(X2_5,dim=-1)
            action_dist = F.softmax(
                torch.squeeze(A @ torch.unsqueeze(X2, 2), 2) -
                (1 - action_mask) * ops.HUGE_INT,
                dim=-1)
            # action_dist = ops.weighted_softmax(torch.squeeze(A @ torch.unsqueeze(X2, 2), 2), action_mask)
            #em_score=torch.squeeze(em_score)
            em_score = em_score.view(dim1, dim2)
            # long + short term decision
            if self.long_short_term == 0:
                action_dist = action_dist * em_score
            # only long term
            if self.long_short_term == 1:
                action_dist = action_dist
            # only short term
            if self.long_short_term == 2:
                action_dist = em_score
            return action_dist, ops.entropy(action_dist)
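The long/short-term switch at the end of policy_nn_fun_fusion chooses between the policy-network distribution (long term), the embedding-model fact scores (short term), or their element-wise product. A toy illustration with made-up numbers; fuse_scores is a hypothetical helper, and, as in the snippet, the product is not renormalized:

import torch

def fuse_scores(action_dist, em_score, long_short_term):
    # action_dist: (batch_size, action_space_size) policy-network distribution
    # em_score:    (batch_size, action_space_size) embedding-model fact scores
    if long_short_term == 0:   # long + short term
        return action_dist * em_score
    if long_short_term == 1:   # only long term
        return action_dist
    return em_score            # only short term

action_dist = torch.tensor([[0.7, 0.2, 0.1]])
em_score = torch.tensor([[0.5, 0.9, 0.1]])
print(fuse_scores(action_dist, em_score, 0))  # tensor([[0.3500, 0.1800, 0.0100]])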