def policy_nn_fun(X2, action_space):
    # Note: defined as a closure; `self` and `kg` are captured from the enclosing scope.
    (r_space, e_space), action_mask = action_space  # each: (batch_size, action_space_size)
    A = self.get_action_embedding((r_space, e_space), kg)  # (batch_size, action_space_size, action_dim)
    # Mask out invalid actions by pushing their logits towards -inf before the softmax.
    action_dist = F.softmax(
        torch.squeeze(A @ torch.unsqueeze(X2, 2), 2) - (1 - action_mask) * ops.HUGE_INT,
        dim=-1)
    # action_dist = ops.weighted_softmax(torch.squeeze(A @ torch.unsqueeze(X2, 2), 2), action_mask)
    return action_dist, ops.entropy(action_dist)  # action_dist: (batch_size, action_space_size)
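# This and the two following variants share the same masking trick: subtracting a huge
# constant from the logits of invalid actions before the softmax drives their
# probabilities to (numerically) zero. A minimal, self-contained sketch of that trick;
# the HUGE_INT value and the toy shapes are assumptions, not taken from ops:
import torch
import torch.nn.functional as F

HUGE_INT = 1e31  # assumed stand-in for ops.HUGE_INT

# Toy logits for a batch of 2 states, each with 4 candidate actions.
scores = torch.tensor([[2.0, 1.0, 0.5, -1.0],
                       [0.3, 0.3, 0.3, 0.3]])
# 1 = valid action, 0 = padding / invalid action.
action_mask = torch.tensor([[1.0, 1.0, 0.0, 0.0],
                            [1.0, 0.0, 1.0, 1.0]])

# Masked positions receive logits of roughly -HUGE_INT, so softmax assigns them ~0.
action_dist = F.softmax(scores - (1 - action_mask) * HUGE_INT, dim=-1)
print(action_dist)  # each row sums to 1, with zeros at the masked positions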
def policy_nn_fun(X2, acs: ActionSpace):
    # Typed variant: the action space is packed into an ActionSpace container.
    A = self.get_action_embedding(Action(acs.r_space, acs.e_space), kg)
    action_dist = F.softmax(
        torch.squeeze(A @ torch.unsqueeze(X2, 2), 2) - (1 - acs.action_mask) * ops.HUGE_INT,
        dim=-1,
    )
    # action_dist = ops.weighted_softmax(torch.squeeze(A @ torch.unsqueeze(X2, 2), 2), action_mask)
    return action_dist, ops.entropy(action_dist)
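# The typed variant assumes ActionSpace and Action containers defined elsewhere in the
# codebase. A plausible minimal definition, inferred only from the attribute accesses
# and constructor call above (the field layout is an assumption):
from typing import NamedTuple
import torch

class ActionSpace(NamedTuple):
    r_space: torch.Tensor      # (batch_size, action_space_size) relation ids
    e_space: torch.Tensor      # (batch_size, action_space_size) entity ids
    action_mask: torch.Tensor  # (batch_size, action_space_size), 1 = valid action

class Action(NamedTuple):
    r: torch.Tensor  # relation ids
    e: torch.Tensor  # entity ids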
def policy_nn_fun(X2, action_space):
    (r_space, e_space), action_mask = action_space
    # ================= newly added ===================
    # A = self.get_action_embedding((r_space, e_space), kg)
    A = self.get_agg_action_embedding((r_space, e_space), kg)
    # ================= newly added ===================
    action_dist = F.softmax(
        torch.squeeze(A @ torch.unsqueeze(X2, 2), 2) - (1 - action_mask) * ops.HUGE_INT,
        dim=-1,
    )
    # action_dist = ops.weighted_softmax(torch.squeeze(A @ torch.unsqueeze(X2, 2), 2), action_mask)
    return action_dist, ops.entropy(action_dist)
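# Each variant also returns the entropy of the action distribution, typically used as
# an exploration bonus in the REINFORCE objective. A sketch of what ops.entropy
# plausibly computes; the epsilon guard is an assumption to keep the log finite on
# masked entries that were driven to exactly zero:
import torch

EPSILON = 1e-20  # assumed guard against log(0)

def entropy(p: torch.Tensor) -> torch.Tensor:
    """Per-row entropy of a (batch_size, action_space_size) distribution."""
    return torch.sum(-p * torch.log(p + EPSILON), dim=1)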
def transit_r(self, e, obs, kg):
    """
    Compute the next relation distribution based on
        (a) the query relation
        (b) the relation history representation
    :param e: agent location (node) at step t.
    :param obs: agent observation at step t.
        e_s: source node
        q: query relation
        e_t: target node
        last_step: if set, the agent is carrying out the last step.
        last_r: label of the edge traversed in the previous step.
        seen_nodes: nodes seen on the paths.
    :param kg: Knowledge graph environment.
    """
    e_s, q, e_t, last_step, last_r, seen_nodes = obs

    # Representation of the current state (current node and other observations)
    Q = kg.get_relation_embeddings(q)
    H_r = self.path_r[-1][0][-1, :, :]  # hidden state of the relation-history encoder
    X_r = torch.cat([H_r, Q], dim=-1)
    policy_net_hidden = F.relu(self.LN1(X_r))
    policy_net_hidden = self.LN1Dropout(policy_net_hidden)
    latent_state_emb = self.LN2(policy_net_hidden)
    latent_state_emb = self.LN2Dropout(latent_state_emb)

    # Score every relation against the latent state
    all_relation_tensor = kg.get_all_relation_embeddings()
    logit = torch.mm(latent_state_emb, all_relation_tensor.transpose(0, 1))
    # Clip the logits from above to avoid overflow
    upper = torch.full_like(logit, 1e31)
    logit = torch.where(logit < upper, logit, upper)
    # Subtract the (detached) row-wise max for numerical stability
    max_logit, _ = torch.max(logit, 1, keepdim=True)
    max_logit.detach_()
    # the_log_of_policy_prob = torch.nn.functional.log_softmax(logit - max_logit, dim=1)
    # r_prob = torch.nn.functional.softmax(logit - max_logit, dim=1)
    # r_prob = r_prob * kg.r_prob_mask
    r_prob = torch.nn.functional.softmax(
        logit - max_logit - (1 - kg.r_prob_mask) * ops.HUGE_INT, dim=1)
    r_entropy = ops.entropy(r_prob)
    # r_prob = r_prob.unsqueeze(1)
    return r_prob, r_entropy
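# transit_r clips the logits and subtracts the detached row-wise maximum before the
# softmax. The subtraction is purely for numerical stability: softmax is invariant to
# adding a per-row constant, so the resulting distribution is unchanged while the
# exponentials stay in range. A quick check of that identity on toy values:
import torch
import torch.nn.functional as F

logit = torch.tensor([[50.0, 49.0, 10.0]])
max_logit, _ = torch.max(logit, 1, keepdim=True)

stable = F.softmax(logit - max_logit, dim=1)
naive = F.softmax(logit, dim=1)
print(torch.allclose(stable, naive))  # True: the shift leaves the softmax unchanged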
def policy_nn_fun_fusion(e, X2, action_space, fn, fn_kg):
    # Note: defined as a closure; `self` and `kg` are captured from the enclosing scope.
    (r_space, e_space), action_mask = action_space
    dim1, dim2 = e_space.size()
    # Tile the current entity so that e1[i][j] == e[i] pairs with each candidate action.
    e1 = e.unsqueeze(1).repeat(1, dim2)
    e2 = e_space
    r = r_space
    # Short-term signal: score each candidate triple (e, r, e') with the fact scorer.
    em_score_can = []
    for i in range(dim1):
        e1_can = torch.squeeze(e1[i])
        r_can = torch.squeeze(r[i])
        e2_can = torch.squeeze(e2[i])
        em_score_can.append(
            torch.squeeze(fn.forward_fact(e1_can, r_can, e2_can, fn_kg)))
    em_score = torch.cat(em_score_can, dim=0)
    # Long-term signal: the policy network's masked action distribution.
    A = self.get_action_embedding((r_space, e_space), kg)
    action_dist = F.softmax(
        torch.squeeze(A @ torch.unsqueeze(X2, 2), 2) - (1 - action_mask) * ops.HUGE_INT,
        dim=-1)
    # action_dist = ops.weighted_softmax(torch.squeeze(A @ torch.unsqueeze(X2, 2), 2), action_mask)
    em_score = em_score.view(dim1, dim2)
    # Long/short-term decision
    if self.long_short_term == 0:    # fuse long- and short-term signals
        action_dist = action_dist * em_score
    elif self.long_short_term == 1:  # long-term (policy) only
        pass
    elif self.long_short_term == 2:  # short-term (embedding score) only
        action_dist = em_score
    return action_dist, ops.entropy(action_dist)
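# The per-row Python loop over forward_fact above can be slow for large batches. If
# fn.forward_fact accepts flat 1-D batches of (head, relation, tail) ids, as the
# per-row calls suggest, the candidate triples can be scored in a single call. This is
# a sketch under that assumption, not the repository's code:
def score_candidates_vectorized(e, r_space, e_space, fn, fn_kg):
    """Score every (e[i], r_space[i][j], e_space[i][j]) triple in one batched call."""
    dim1, dim2 = e_space.size()
    # Tile the current entity so row i pairs e[i] with each of its dim2 candidates.
    e1 = e.unsqueeze(1).expand(dim1, dim2).reshape(-1)
    # Assumed: forward_fact scores a flat batch of triples elementwise.
    em_score = fn.forward_fact(e1, r_space.reshape(-1), e_space.reshape(-1), fn_kg)
    return em_score.reshape(dim1, dim2)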