def update_path(self, action: Action, kg: KnowledgeGraph, offset=None):
    """
    Once an action was selected, update the action history.

    The new action is embedded: only the relation when
    ``self.relation_only_in_path`` is set, otherwise via
    ``self.get_action_embedding``. It is then fed through
    ``self.path_encoder`` as a length-1 sequence starting from the last state
    in ``self.path``, and the encoder's returned state is appended to
    ``self.path``.

    :param action: the most recent action, read through ``action.rel``
        (indices of the most recently traversed edges) and ``action.ent``
        (indices of the destination entities).
    :param kg: Knowledge graph environment that provides the embeddings.
    :param offset: if not None, index used to re-select the batch rows of
        every stored history entry before the new step is encoded (used for
        search). If None, the history is extended as is.
    """
    def offset_path_history(history, index):
        # Re-index each stored entry in place. Tuple entries (presumably the
        # (h, c) pairs built by initialize_path, [num_layers, batch, dim])
        # are indexed along dim 1; plain tensors along dim 0.
        for i, entry in enumerate(history):
            if isinstance(entry, tuple):
                history[i] = tuple(t[:, index, :] for t in entry)
            else:
                history[i] = entry[index, :]

    # update action history
    if self.relation_only_in_path:
        action_embedding = kg.get_relation_embeddings(action.rel)
    else:
        action_embedding = self.get_action_embedding(action, kg)
    if offset is not None:
        offset_path_history(self.path, offset)

    # unsqueeze(1) turns the batch of embeddings into length-1 sequences.
    # Element [1] of the encoder output is presumably its final state (the
    # (h, c) pair for an LSTM-style encoder), which the next step resumes from.
    self.path.append(
        self.path_encoder(action_embedding.unsqueeze(1), self.path[-1])[1])
def get_action_embedding(self, action: Action, kg: KnowledgeGraph):
    """
    Embed a batch of actions.

    When ``self.relation_only`` is set, the result is just the relation
    (edge) embedding of ``action.rel``. Otherwise it is that relation
    embedding concatenated, along the last dimension, with the entity
    embedding of ``action.ent`` (the destination node).

    :param action: action whose ``rel`` field indexes the traversed edges and
        whose ``ent`` field indexes the destination entities.
    :param kg: Knowledge graph environment supplying the embeddings.
    :return: the embedding tensor described above.
    """
    rel_emb = kg.get_relation_embeddings(action.rel)
    if self.relation_only:
        # Entity embeddings are not looked up at all in this mode.
        return rel_emb
    ent_emb = kg.get_entity_embeddings(action.ent)
    return torch.cat([rel_emb, ent_emb], dim=-1)
def initialize_path(self, action: Action, kg: KnowledgeGraph):
    """
    Reset the path history so that it starts from the encoding of ``action``.

    The action is embedded: only the relation when
    ``self.relation_only_in_path`` is set, otherwise via
    ``self.get_action_embedding``. It is then run through
    ``self.path_encoder`` as a length-1 sequence from zero-initialised
    (h, c) states. ``self.path`` becomes a one-element list holding the
    encoder's returned state.

    :param action: the initial action, read through ``action.rel`` (and
        ``action.ent`` when full action embeddings are used).
    :param kg: Knowledge graph environment that provides the embeddings.
    """
    # [batch_size, action_dim]
    if self.relation_only_in_path:
        init_action_embedding = kg.get_relation_embeddings(action.rel)
    else:
        init_action_embedding = self.get_action_embedding(action, kg)
    # Out-of-place unsqueeze -> [batch_size, 1, action_dim]. The in-place
    # unsqueeze_ would reshape a tensor this method does not own (it may be
    # the very object returned by kg.get_relation_embeddings).
    init_action_embedding = init_action_embedding.unsqueeze(1)
    batch_size = init_action_embedding.size(0)
    # [num_layers, batch_size, dim]
    state_shape = [self.history_num_layers, batch_size, self.history_dim]
    init_h = zeros_var_cuda(state_shape)
    init_c = zeros_var_cuda(state_shape)
    # Element [1] of the encoder output is presumably its final (h, c) state.
    self.path = [
        self.path_encoder(init_action_embedding, (init_h, init_c))[1]
    ]