Ejemplo n.º 1
0
    def __init__(
        self,
        initial_entity_emb,
        initial_relation_emb,
        entity_out_dim,
        relation_out_dim,
        drop_GAT,
        drop_conv,
        alpha,
        alpha_conv,
        nheads_GAT,
        conv_out_channels,
    ):
        """Sparse version of KBGAT that owns only a ConvKB decoder.

        Only the *shapes* of the two embedding inputs are read here; the
        learnable final embeddings are freshly drawn with ``torch.randn``
        rather than copied from the inputs.

        Args:
            initial_entity_emb: object with a ``.shape``; ``shape[0]`` is
                stored as ``num_nodes`` and ``shape[1]`` as ``entity_in_dim``.
            initial_relation_emb: object with a ``.shape``; ``shape[0]`` is
                stored as ``num_relation`` and ``shape[1]`` as
                ``relation_dim``.
            entity_out_dim: sequence of per-layer entity output dims; items
                0 and 1 are read.
            relation_out_dim: sequence; only item 0 is read.
            drop_GAT: stored as ``self.drop_GAT`` (not used in this method).
            drop_conv: stored and forwarded to ``ConvKB``.
            alpha: stored as ``self.alpha`` (LeakyReLU slope, per the
                original comment).
            alpha_conv: stored and forwarded to ``ConvKB``.
            nheads_GAT: sequence of per-layer attention-head counts; items
                0 and 1 are read.
            conv_out_channels: stored and forwarded to ``ConvKB``.
        """
        super().__init__()

        # Graph / entity properties, grouped per GAT layer.
        ent_shape = initial_entity_emb.shape
        self.num_nodes = ent_shape[0]
        self.entity_in_dim = ent_shape[1]
        self.entity_out_dim_1, self.nheads_GAT_1 = entity_out_dim[0], nheads_GAT[0]
        self.entity_out_dim_2, self.nheads_GAT_2 = entity_out_dim[1], nheads_GAT[1]

        # Relation properties.
        rel_shape = initial_relation_emb.shape
        self.num_relation = rel_shape[0]
        self.relation_dim = rel_shape[1]
        self.relation_out_dim_1 = relation_out_dim[0]

        # Hyper-parameters kept on the instance.
        self.drop_GAT = drop_GAT
        self.drop_conv = drop_conv
        self.alpha = alpha  # For leaky relu
        self.alpha_conv = alpha_conv
        self.conv_out_channels = conv_out_channels

        # Entities and relations share one width (first-layer output dim times
        # first-layer head count); it is also the single size given to ConvKB.
        emb_dim = self.entity_out_dim_1 * self.nheads_GAT_1

        # Keep this order: entity table, relation table, then the decoder.
        # It fixes parameters() ordering and the RNG draw sequence.
        self.final_entity_embeddings = nn.Parameter(
            torch.randn((self.num_nodes, emb_dim)))
        self.final_relation_embeddings = nn.Parameter(
            torch.randn((self.num_relation, emb_dim)))

        self.convKB = ConvKB(emb_dim, 3, 1, self.conv_out_channels,
                             self.drop_conv, self.alpha_conv)
Ejemplo n.º 2
0
    def __init__(self, final_entity_emb, final_relation_emb, entity_out_dim,
                 relation_out_dim, drop_conv, alpha_conv, nheads_GAT,
                 conv_out_channels, variational, temperature,
                 sigma_p):  # NOTE removed alpha as it doesn't seem to get used
        """Sparse version of KBGAT: a ConvKB decoder seeded with GAT output.

        Args:
            final_entity_emb: tensor of shape
                ``(num_nodes, entity_out_dim[0] * nheads_GAT[0])``
                (enforced below).
            final_relation_emb: tensor of shape
                ``(num_relation, entity_out_dim[0] * nheads_GAT[0])``
                (enforced below).
            entity_out_dim: sequence; only item 0 is read (embedding width).
            relation_out_dim: sequence; only item 0 is read.
            drop_conv: stored and forwarded to ``ConvKB``.
            alpha_conv: stored and forwarded to ``ConvKB``.
            nheads_GAT: sequence; only item 0 is read (embedding width).
            conv_out_channels: stored and forwarded to ``ConvKB``.
            variational: if truthy, also create per-element log-std-dev
                parameters initialised to -2.
            temperature: stored as ``self.temperature``; not used here.
            sigma_p: stored as ``self.sigma_p``; not used here (presumably a
                prior std-dev for the variational case -- TODO confirm).

        Raises:
            ValueError: if either embedding tensor does not have shape
                ``(rows, entity_out_dim[0] * nheads_GAT[0])``.
        """
        super().__init__()

        self.num_nodes = final_entity_emb.shape[0]
        emb_dim = entity_out_dim[0] * nheads_GAT[0]

        # Properties of Relations
        self.num_relation = final_relation_emb.shape[0]
        self.relation_dim = final_relation_emb.shape[1]
        self.relation_out_dim_1 = relation_out_dim[0]

        self.drop_conv = drop_conv
        self.alpha_conv = alpha_conv
        self.conv_out_channels = conv_out_channels

        self.variational = variational
        self.temperature = temperature
        self.sigma_p = sigma_p

        # Explicit checks instead of `assert`, which `python -O` strips.
        for label, emb, rows in (
            ("final_entity_emb", final_entity_emb, self.num_nodes),
            ("final_relation_emb", final_relation_emb, self.num_relation),
        ):
            if tuple(emb.shape) != (rows, emb_dim):
                raise ValueError(
                    "{} has shape {}; expected ({}, {}) = "
                    "(rows, entity_out_dim[0] * nheads_GAT[0])".format(
                        label, tuple(emb.shape), rows, emb_dim))

        # Snapshot of the incoming GAT embeddings. detach() keeps the copy
        # from carrying autograd history back into the source tensor (which
        # may be another model's nn.Parameter). Requires we always load GAT
        # before initialising this.
        self.entity_embeddings_from_gat = final_entity_emb.detach().clone()
        self.relation_embeddings_from_gat = final_relation_emb.detach().clone()
        # NOTE(review): these are plain tensor attributes, not registered
        # buffers, so Module.to()/.cuda() will not move them -- confirm intent.

        # Means are initialised from the GAT output but, being nn.Parameters,
        # are trainable, unlike the snapshots above. Is this desired?
        self.final_entity_embeddings_mean = nn.Parameter(
            final_entity_emb.detach().clone())
        self.final_relation_embeddings_mean = nn.Parameter(
            final_relation_emb.detach().clone())
        if self.variational:
            # Per-element log std-dev initialised to -2 (same dtype/device as
            # the input). new_full avoids the NaN that `x * 0 - 2` yields for
            # non-finite entries.
            self.entity_logstddev = nn.Parameter(
                final_entity_emb.new_full(final_entity_emb.shape, -2.0))
            self.relation_logstddev = nn.Parameter(
                final_relation_emb.new_full(final_relation_emb.shape, -2.0))

        self.convKB = ConvKB(emb_dim, 3, 1, self.conv_out_channels,
                             self.drop_conv, self.alpha_conv)