def __init__(
    self,
    initial_entity_emb,
    initial_relation_emb,
    entity_out_dim,
    relation_out_dim,
    drop_GAT,
    drop_conv,
    alpha,
    alpha_conv,
    nheads_GAT,
    conv_out_channels,
):
    """Set up the ConvKB-only stage of sparse KBGAT.

    Only the shapes of the two initial embedding tensors are read; their
    values are NOT copied. The final entity/relation embeddings are fresh
    ``torch.randn`` parameters of width ``entity_out_dim[0] * nheads_GAT[0]``.

    Args:
        initial_entity_emb: tensor whose ``shape[0]`` is taken as the number
            of nodes and ``shape[1]`` as the entity input dimension.
        initial_relation_emb: tensor whose ``shape[0]`` is taken as the number
            of relations and ``shape[1]`` as the relation dimension.
        entity_out_dim: per-layer entity output dims; indices 0 and 1 are read.
        relation_out_dim: relation output dims; index 0 is read.
        drop_GAT: stored as ``self.drop_GAT``; not otherwise used here.
        drop_conv: stored and forwarded to ``ConvKB``.
        alpha: stored as ``self.alpha`` (leaky-ReLU slope).
        alpha_conv: stored and forwarded to ``ConvKB``.
        nheads_GAT: per-layer attention head counts; indices 0 and 1 are read.
        conv_out_channels: stored and forwarded to ``ConvKB``.
    """
    super().__init__()

    # Graph / entity properties.
    self.num_nodes = initial_entity_emb.shape[0]
    self.entity_in_dim = initial_entity_emb.shape[1]
    self.entity_out_dim_1, self.entity_out_dim_2 = entity_out_dim[0], entity_out_dim[1]
    self.nheads_GAT_1, self.nheads_GAT_2 = nheads_GAT[0], nheads_GAT[1]

    # Relation properties.
    relation_shape = initial_relation_emb.shape
    self.num_relation = relation_shape[0]
    self.relation_dim = relation_shape[1]
    self.relation_out_dim_1 = relation_out_dim[0]

    # Regularisation / activation hyper-parameters.
    self.drop_GAT = drop_GAT
    self.drop_conv = drop_conv
    self.alpha = alpha  # leaky-ReLU slope
    self.alpha_conv = alpha_conv
    self.conv_out_channels = conv_out_channels

    # Entities and relations share one width so ConvKB can stack them.
    emb_width = self.entity_out_dim_1 * self.nheads_GAT_1

    # Keep this order (entity randn, relation randn, then ConvKB) so global
    # RNG draws and parameter registration order match the original.
    self.final_entity_embeddings = nn.Parameter(
        torch.randn(self.num_nodes, emb_width)
    )
    self.final_relation_embeddings = nn.Parameter(
        torch.randn(self.num_relation, emb_width)
    )
    self.convKB = ConvKB(
        emb_width,
        3,
        1,
        self.conv_out_channels,
        self.drop_conv,
        self.alpha_conv,
    )
def __init__(self, final_entity_emb, final_relation_emb, entity_out_dim,
             relation_out_dim, drop_conv, alpha_conv, nheads_GAT,
             conv_out_channels, variational, temperature, sigma_p):
    """Set up the (optionally variational) ConvKB stage on top of GAT output.

    The learnable embedding means are initialised from the supplied GAT
    embeddings. A detached snapshot of those embeddings is also kept as
    ``entity_embeddings_from_gat`` / ``relation_embeddings_from_gat``.

    Args:
        final_entity_emb: entity embeddings produced by the GAT; must have
            shape ``(num_nodes, entity_out_dim[0] * nheads_GAT[0])``.
        final_relation_emb: relation embeddings; must have shape
            ``(num_relation, entity_out_dim[0] * nheads_GAT[0])``.
        entity_out_dim: per-layer entity output dims; index 0 is read.
        relation_out_dim: relation output dims; index 0 is read.
        drop_conv: stored and forwarded to ``ConvKB``.
        alpha_conv: stored and forwarded to ``ConvKB``.
        nheads_GAT: per-layer attention head counts; index 0 is read.
        conv_out_channels: stored and forwarded to ``ConvKB``.
        variational: if truthy, per-element log-std-dev parameters are
            created for entities and relations, initialised to -2.
        temperature: stored as ``self.temperature``; not used here.
        sigma_p: stored as ``self.sigma_p``; not used here.

    Raises:
        AssertionError: if either embedding tensor has an unexpected shape.
    """
    # NOTE removed alpha as it doesn't seem to get used
    super().__init__()
    self.num_nodes = final_entity_emb.shape[0]
    emb_dim = entity_out_dim[0] * nheads_GAT[0]

    # Properties of Relations
    self.num_relation = final_relation_emb.shape[0]
    self.relation_dim = final_relation_emb.shape[1]
    self.relation_out_dim_1 = relation_out_dim[0]

    self.drop_conv = drop_conv
    # self.alpha = alpha  # For leaky relu
    self.alpha_conv = alpha_conv
    self.conv_out_channels = conv_out_channels
    self.variational = variational
    self.temperature = temperature
    self.sigma_p = sigma_p

    assert final_entity_emb.shape == (self.num_nodes, emb_dim), (
        f"final_entity_emb shape {tuple(final_entity_emb.shape)} != "
        f"{(self.num_nodes, emb_dim)}"
    )
    assert final_relation_emb.shape == (self.num_relation, emb_dim), (
        f"final_relation_emb shape {tuple(final_relation_emb.shape)} != "
        f"{(self.num_relation, emb_dim)}"
    )

    # Fixed snapshot of the GAT output (requires the GAT to be loaded before
    # this module is built). detach() so the snapshot carries no grad_fn:
    # otherwise it would route gradients back into the GAT parameters and
    # make copy.deepcopy(module) fail on a non-leaf tensor.
    self.entity_embeddings_from_gat = final_entity_emb.detach().clone()
    self.relation_embeddings_from_gat = final_relation_emb.detach().clone()

    # Learnable means, initialised from the GAT output.
    # NOTE(review): these keep training further; confirm that is desired.
    self.final_entity_embeddings_mean = nn.Parameter(
        final_entity_emb.detach().clone())
    self.final_relation_embeddings_mean = nn.Parameter(
        final_relation_emb.detach().clone())

    if self.variational:
        # Constant -2 log-std-dev with the same shape/dtype/device as the
        # embeddings; full_like avoids NaN from `x * 0` when x is inf/NaN.
        self.entity_logstddev = nn.Parameter(
            torch.full_like(final_entity_emb, -2.0))
        self.relation_logstddev = nn.Parameter(
            torch.full_like(final_relation_emb, -2.0))

    self.convKB = ConvKB(emb_dim, 3, 1, self.conv_out_channels,
                         self.drop_conv, self.alpha_conv)