示例#1
0
    def forward(self, inputs, input_z, n_in_node, rel_rec, rel_send, origin_A, adj_A_tilt, Wa):
        """Propagate the latent through the graph, then decode it with an MLP + softmax head.

        Args:
            inputs: Not referenced by this decoder (presumably kept so all decoders
                share one call signature).
            input_z: Latent tensor propagated through the graph.
            n_in_node: Not referenced by this decoder.
            rel_rec: Not referenced by this decoder.
            rel_send: Not referenced by this decoder.
            origin_A: Adjacency matrix handed to ``preprocess_adj_new1``
                (presumably returning (I - A^T)^(-1) -- confirm against the helper).
            adj_A_tilt: Returned unchanged.
            Wa: Offset added to ``input_z`` before propagation and subtracted after.

        Returns:
            Tuple ``(mat_z, out, adj_A_tilt)``:
              * ``mat_z``  -- ``preprocess_adj_new1(origin_A) @ (input_z + Wa) - Wa``;
              * ``out``    -- ``self.softmax(self.out_fc3(relu(self.out_fc1(mat_z))))``;
              * ``adj_A_tilt`` -- the argument, passed through.
        """
        adj_A_new1 = preprocess_adj_new1(origin_A)
        # Shift by Wa, propagate through the graph operator, then undo the shift.
        mat_z = torch.matmul(adj_A_new1, input_z + Wa) - Wa

        H3 = F.relu(self.out_fc1(mat_z))

        # self.softmax is defined outside this method; the original author called
        # this a "discretized log" output, so it is presumably a (log-)softmax over
        # categories for discrete data -- TODO confirm.
        out = self.softmax(self.out_fc3(H3))

        return mat_z, out, adj_A_tilt
示例#2
0
    def forward(self, inputs, input_z, n_in_node, rel_rec, rel_send, origin_A, adj_A_tilt, Wa):
        """Propagate the latent through the graph, then decode it with a two-layer MLP.

        Args:
            inputs: Not referenced by this decoder (presumably kept so all decoders
                share one call signature).
            input_z: Latent tensor propagated through the graph.
            n_in_node: Not referenced by this decoder.
            rel_rec: Not referenced by this decoder.
            rel_send: Not referenced by this decoder.
            origin_A: Adjacency matrix handed to ``preprocess_adj_new1``
                (presumably returning (I - A^T)^(-1) -- confirm against the helper).
            adj_A_tilt: Returned unchanged.
            Wa: Offset added to ``input_z`` before propagation and subtracted after.

        Returns:
            Tuple ``(mat_z, out, adj_A_tilt)``:
              * ``mat_z``  -- ``preprocess_adj_new1(origin_A) @ (input_z + Wa) - Wa``;
              * ``out``    -- ``self.out_fc3(relu(self.out_fc1(mat_z)))`` (no output
                activation is applied);
              * ``adj_A_tilt`` -- the argument, passed through.
        """
        adj_A_new1 = preprocess_adj_new1(origin_A)
        # Shift by Wa, propagate through the graph operator, then undo the shift.
        mat_z = torch.matmul(adj_A_new1, input_z + Wa) - Wa

        H3 = F.relu(self.out_fc1(mat_z))

        # Linear output head; any activation/likelihood is left to the caller.
        out = self.out_fc3(H3)

        return mat_z, out, adj_A_tilt
示例#3
0
    def forward(self, inputs, input_z, n_in_node, rel_rec, rel_send, origin_A, adj_A_tilt, Wa):
        """Linearly propagate the shifted latent through the graph operator.

        Args:
            inputs: Not referenced here.
            input_z: Latent tensor to propagate.
            n_in_node: Not referenced here.
            rel_rec: Not referenced here.
            rel_send: Not referenced here.
            origin_A: Adjacency matrix handed to ``preprocess_adj_new1``; the original
                comment describes the result as (I - A^T)^(-1).
            adj_A_tilt: Returned unchanged.
            Wa: Offset added to ``input_z`` before propagation.

        Returns:
            ``(P, P - Wa, adj_A_tilt)`` where
            ``P = preprocess_adj_new1(origin_A) @ (input_z + Wa)``.
        """
        inv_operator = preprocess_adj_new1(origin_A)
        propagated = torch.matmul(inv_operator, input_z + Wa)
        # NOTE(review): the first return value keeps the Wa shift while the second
        # removes it -- confirm this asymmetry is intended.
        return propagated, propagated - Wa, adj_A_tilt
示例#4
0
    def forward(self, inputs, input_z, n_in_node, rel_rec, rel_send, origin_A, adj_A_tilt, Wa):
        """Propagate the latent through the graph and decode it with a two-layer MLP.

        Args:
            inputs: Not referenced here.
            input_z: Latent tensor to propagate.
            n_in_node: Not referenced here.
            rel_rec: Not referenced here.
            rel_send: Not referenced here.
            origin_A: Adjacency matrix handed to ``preprocess_adj_new1``; the original
                comment describes the result as (I - A^T)^(-1).
            adj_A_tilt: Returned unchanged.
            Wa: Offset added to ``input_z`` before propagation and removed after.

        Returns:
            ``(mat_z, self.out_fc2(relu(self.out_fc1(mat_z))), adj_A_tilt)``.
        """
        inv_operator = preprocess_adj_new1(origin_A)
        shifted_z = input_z + Wa
        mat_z = torch.matmul(inv_operator, shifted_z) - Wa

        hidden = F.relu(self.out_fc1(mat_z))
        return mat_z, self.out_fc2(hidden), adj_A_tilt
示例#5
0
    def forward(self, inputs, rel_rec, rel_send):
        """Encode ``inputs`` with the learned adjacency parameter ``self.adj_A``.

        If ``self.adj_A`` contains NaN, a message is printed and computation
        continues regardless.

        Args:
            inputs: Data tensor, left-multiplied by the preprocessed adjacency.
            rel_rec: Not referenced here.
            rel_send: Not referenced here.

        Returns:
            Tuple ``(inputs - meanF, logits, adj_A1, adj_A, self.z,
            self.z_positive, self.adj_A)`` where:
              * ``adj_A1 = sinh(3 * self.adj_A)``;
              * ``adj_A = preprocess_adj_new(adj_A1)`` (I - A^T per the original
                comment) and ``adj_A_inv = preprocess_adj_new1(adj_A1)``
                ((I - A^T)^(-1) per the original comment);
              * ``meanF = adj_A_inv @ mean(adj_A @ inputs, dim=0)``;
              * ``logits = adj_A @ (inputs - meanF)``.
        """
        raw_adj = self.adj_A
        # Diagnostic only: report NaNs in the parameter but keep going.
        if torch.isnan(raw_adj).any():
            print('nan error \n')

        adj_A1 = torch.sinh(3. * raw_adj)

        adj_A = preprocess_adj_new(adj_A1)
        adj_A_inv = preprocess_adj_new1(adj_A1)

        # Average adj_A @ inputs over dim 0 (presumably the batch axis -- confirm),
        # then map that mean back through the inverse operator.
        reduced_mean = torch.mean(torch.matmul(adj_A, inputs), 0)
        meanF = torch.matmul(adj_A_inv, reduced_mean)

        logits = torch.matmul(adj_A, inputs - meanF)
        return inputs - meanF, logits, adj_A1, adj_A, self.z, self.z_positive, self.adj_A
示例#6
0
    def forward(self, inputs, input_z, n_in_node, rel_rec, rel_send, origin_A, adj_A_tilt, Wa):

        # adj_A_new1 = (I-A^T)^(-1)
        adj_A_new1 = preprocess_adj_new1(origin_A)