예제 #1
0
    def update_path(self, action: Action, kg: KnowledgeGraph, offset=None):
        """
        Once an action was selected, update the action history.

        Embeds the selected action, feeds it through ``self.path_encoder``
        conditioned on the latest history entry, and appends element [1] of
        the encoder output to ``self.path``.

        :param action (r, e): (Variable:batch) indices of the most recent action
            - r is the most recently traversed edge;
            - e is the destination entity.
        :param kg: Knowledge graph environment.
        :param offset: (Variable:batch) if not None, re-index every entry of the
            path history with the given offset before appending (used for search).
        """
        def offset_path_history(p, offset):
            # Re-index each history entry in place, so the list object held in
            # self.path keeps its identity. Tuple entries (element [1] of the
            # path_encoder output; presumably LSTM (h, c) laid out as
            # [num_layers, batch, dim] -- confirm) are indexed along dim 1;
            # any other entry is indexed along dim 0.
            for i, x in enumerate(p):
                if isinstance(x, tuple):
                    p[i] = tuple(_x[:, offset, :] for _x in x)
                else:
                    p[i] = x[offset, :]

        # Embed the action: relation embedding only, or the full action embedding.
        if self.relation_only_in_path:
            action_embedding = kg.get_relation_embeddings(action.rel)
        else:
            action_embedding = self.get_action_embedding(action, kg)
        if offset is not None:
            offset_path_history(self.path, offset)

        # unsqueeze(1) inserts a size-1 axis at position 1 (presumably a
        # length-1 sequence for the encoder -- confirm); only element [1] of
        # the encoder output is kept as the new history entry.
        self.path.append(
            self.path_encoder(action_embedding.unsqueeze(1), self.path[-1])[1])
예제 #2
0
    def get_action_embedding(self, action: Action, kg: KnowledgeGraph):
        """
        Embed a batch of actions.

        When ``self.relation_only`` is set, the result is the embedding of the
        traversed edge alone. Otherwise it is the edge embedding concatenated
        with the destination-entity embedding along the last dimension.

        :param action (r, e):
            (Variable:batch) indices of the most recent action
                - r is the most recently traversed edge
                - e is the destination entity.
        :param kg: Knowledge graph environment.
        :return: (batch) action embedding.
        """
        rel_vec = kg.get_relation_embeddings(action.rel)
        if self.relation_only:
            # Relation-only mode: the edge embedding is the whole action.
            return rel_vec
        ent_vec = kg.get_entity_embeddings(action.ent)
        return torch.cat([rel_vec, ent_vec], dim=-1)
예제 #3
0
 def initialize_path(self, action: Action, kg: KnowledgeGraph):
     """
     Reset the path history from an initial action.

     Embeds ``action`` (relation embedding only when
     ``self.relation_only_in_path`` is set, otherwise via
     ``self.get_action_embedding``), runs it through ``self.path_encoder``
     starting from all-zero (h, c) states, and stores element [1] of the
     encoder output as the sole entry of ``self.path``.

     :param action: (Variable:batch) initial action indices.
     :param kg: Knowledge graph environment.
     """
     # [batch_size, action_dim]
     if self.relation_only_in_path:
         init_action_embedding = kg.get_relation_embeddings(action.rel)
     else:
         init_action_embedding = self.get_action_embedding(action, kg)
     # Out-of-place unsqueeze: do not mutate the tensor returned by the
     # embedding lookup (in-place view ops can fail on views or on tensors
     # tracked by autograd).
     init_action_embedding = init_action_embedding.unsqueeze(1)
     batch_size = init_action_embedding.size(0)
     # [num_layers, batch_size, dim]
     init_h = zeros_var_cuda(
         [self.history_num_layers, batch_size, self.history_dim])
     init_c = zeros_var_cuda(
         [self.history_num_layers, batch_size, self.history_dim])
     self.path = [
         self.path_encoder(init_action_embedding, (init_h, init_c))[1]
     ]