Example #1
0
    def encode(self, whole, wholeAdj, lengs, refMat, maxNodes):
        """Encode a dense graph batch through two rounds of MinCut pooling.

        Module call order (kept fixed because dropout draws from the RNG):
        sage1 -> sage22 -> poolit1 + dense_mincut_pool -> sage3 -> sage42
        -> poolit2 + dense_mincut_pool -> sage5 -> tr1 -> tr2 (twice).

        Args:
            whole: node features fed to ``self.sage1`` (presumably B x N x F,
                per the original shape comments — TODO confirm with caller).
            wholeAdj: dense adjacency used by ``self.sage1``,
                ``self.poolit1`` and the first ``dense_mincut_pool`` call.
            lengs, refMat, maxNodes: accepted for interface compatibility;
                not used by this method.

        Returns:
            ``(out_a, out_b, adj2)``: two separate evaluations of
            ``self.tr2`` on the final hidden state, and the adjacency
            produced by the second pooling step. The auxiliary MinCut and
            orthogonality losses returned by ``dense_mincut_pool`` are
            discarded here.
        """
        # Stage 1: graph conv on the input graph, tanh activation.
        h = self.drop5(self.bano1(F.tanh(self.sage1(whole, wholeAdj))))
        # Stage 2: sage22 is called without an adjacency argument.
        h = self.drop4(self.bano2(F.relu(self.sage22(h))))

        # First MinCut pooling; assignment scores come from poolit1.
        assign = F.relu(self.poolit1(h, wholeAdj))
        h, adj1, _, _ = dense_mincut_pool(h, wholeAdj, assign)

        # Stages 3 and 4 on the coarsened graph (sage42 takes no adjacency).
        h = self.drop3(self.bano3(F.relu(self.sage3(h, adj1))))
        h = self.drop3(self.bano4(F.tanh(self.sage42(h))))

        # Second MinCut pooling.
        assign = F.leaky_relu(self.poolit2(h, adj1))
        h, adj2, _, _ = dense_mincut_pool(h, adj1, assign)

        h = self.drop3(self.bano5(F.tanh(self.sage5(h, adj2))))

        # Readout head.
        h = self.drop3(F.relu(self.tr1(h)))

        # NOTE(review): tr2 is evaluated twice on the same input; kept as-is
        # to preserve behavior (the results differ if tr2 is stochastic).
        first = self.tr2(h)
        second = self.tr2(h)
        return first, second, adj2
Example #2
0
 def forward(self, x, edge_index):
     """Run the GNN encoder and collect its outputs in a dict.

     Always sets ``'z'``. Depending on flags it also sets:
     ``'mu'`` / ``'logvar'`` (``self.variational``; ``'z'`` is then the
     reparametrized sample), ``'y'`` (``self.prediction_task``),
     ``'s'`` / ``'mc1'`` / ``'o1'`` (``self.use_mincut``), or ``'s'`` from
     ``self.kmeans`` (``self.activate_kmeans``, only when MinCut is off).
     """
     h = x
     # Every conv except the last is followed by self.relu; the last
     # conv's output is used without an activation.
     for layer in self.convs[:-1]:
         h = self.relu(layer(h, edge_index))
     h = self.convs[-1](h, edge_index)

     if self.use_mincut:
         # batch=None: all nodes are treated as one graph.
         h_dense, node_mask = to_dense_batch(h, None)
         dense_adj = to_dense_adj(edge_index, None)
         assign = self.pool1(h)
         # Only the auxiliary losses are used; the pooled features and
         # pooled adjacency are not needed further down.
         _, _, mc1, o1 = dense_mincut_pool(
             h_dense, dense_adj, assign, node_mask)

     result = {}
     if self.variational:
         mu = self.conv_mu(h, edge_index)
         logvar = self.conv_logvar(h, edge_index)
         result['mu'] = mu
         result['logvar'] = logvar
         result['z'] = self.reparametrize(mu, logvar)
     else:
         result['z'] = h

     if self.prediction_task:
         # NOTE(review): the classifier sees the deterministic encoding h,
         # not the reparametrized sample in result['z'] — confirm intent.
         result['y'] = self.classification_layer(h)

     if self.use_mincut:
         result.update(s=assign, mc1=mc1, o1=o1)
     elif self.activate_kmeans:
         result['s'] = self.kmeans(h)
     return result
Example #3
0
    def forward(self, x, edge_index, batch):
        """Classify a batch of graphs using two MinCut pooling stages.

        Args:
            x: node features, one row per node (densified below with
                ``to_dense_batch``).
            edge_index: graph connectivity for ``x``.
            batch: graph-membership vector passed to ``to_dense_batch``
                and ``to_dense_adj``.

        Returns:
            ``(log_probs, mincut_loss, ortho_loss)``: ``log_softmax`` over
            the last dimension, plus the summed MinCut and orthogonality
            losses of the two pooling stages.
        """
        feats = F.relu(self.conv1(x, edge_index))

        # Move from the sparse per-node layout to a dense padded batch.
        feats, node_mask = to_dense_batch(feats, batch)
        dense_adj = to_dense_adj(edge_index, batch)

        # Pool 1 receives the node mask produced by to_dense_batch.
        feats, dense_adj, cut_a, orth_a = dense_mincut_pool(
            feats, dense_adj, self.pool1(feats), node_mask)

        # Pool 2 is called without a mask.
        feats = F.relu(self.conv2(feats, dense_adj))
        feats, dense_adj, cut_b, orth_b = dense_mincut_pool(
            feats, dense_adj, self.pool2(feats))

        feats = self.conv3(feats, dense_adj)

        # Mean readout over the node dimension, then a two-layer head.
        graph_repr = feats.mean(dim=1)
        logits = self.lin2(F.relu(self.lin1(graph_repr)))
        return F.log_softmax(logits, dim=-1), cut_a + cut_b, orth_a + orth_b
Example #4
0
def test_dense_mincut_pool():
    """Check output shapes and loss ranges of ``dense_mincut_pool``.

    A local, fixed-seed generator makes the test reproducible without
    touching the global RNG. Every graph also keeps at least one unmasked
    node: in PyG's implementation, a fully masked graph zeroes its
    assignment matrix, which makes the MinCut and orthogonality ratios
    0/0 (NaN). That would make the range assertions fail spuriously.
    """
    batch_size, num_nodes, channels, num_clusters = (2, 20, 16, 10)
    gen = torch.Generator().manual_seed(12345)
    x = torch.randn((batch_size, num_nodes, channels), generator=gen)
    adj = torch.ones((batch_size, num_nodes, num_nodes))
    s = torch.randn((batch_size, num_nodes, num_clusters), generator=gen)
    mask = torch.randint(0, 2, (batch_size, num_nodes), generator=gen,
                         dtype=torch.bool)
    mask[:, 0] = True  # never mask out every node of a graph

    out_x, out_adj, mincut_loss, ortho_loss = dense_mincut_pool(
        x, adj, s, mask)
    assert out_x.size() == (batch_size, num_clusters, channels)
    assert out_adj.size() == (batch_size, num_clusters, num_clusters)
    assert -1 <= mincut_loss <= 0
    assert 0 <= ortho_loss <= 2
Example #5
0
    def encode(self, whole, adj, lengs, mask, maxNodes):
        """Encode a dense graph batch with two MinCut pooling stages.

        Args:
            whole: node features fed to ``self.sage1``.
            adj: dense adjacency for ``whole``.
            lengs, maxNodes: accepted for interface compatibility; unused.
            mask: node mask, forwarded only to the first
                ``dense_mincut_pool`` call.

        Returns:
            7-tuple ``(out_a, out_b, adj, pool1, pool2, mincut, ortho)``:
            two separate evaluations of ``self.tr2`` on the final
            embeddings, the adjacency after the second pooling, the raw
            ``poolit1`` / ``poolit2`` assignment outputs, and the summed
            MinCut and orthogonality losses of both pooling stages.
        """
        def conv_block(h, graph, conv, norm, drop):
            # graph conv -> ReLU -> normalization -> dropout
            return drop(norm(F.relu(conv(h, graph))))

        h = conv_block(whole, adj, self.sage1, self.bano1, self.drop5)
        h = conv_block(h, adj, self.sage2, self.bano2, self.drop4)

        # First pooling step uses the caller-supplied node mask.
        assign1 = self.poolit1(h)
        h, pooled_adj, mc1, o1 = dense_mincut_pool(h, adj, assign1, mask)

        h = conv_block(h, pooled_adj, self.sage3, self.bano3, self.drop3)

        # Second pooling step: no mask is passed.
        assign2 = self.poolit2(h)
        h, pooled_adj, mc2, o2 = dense_mincut_pool(h, pooled_adj, assign2)

        h = conv_block(h, pooled_adj, self.sage5, self.bano5, self.drop3)

        # NOTE(review): tr2 is evaluated twice; kept to preserve behavior.
        out_a = self.tr2(h)
        out_b = self.tr2(h)
        return out_a, out_b, pooled_adj, assign1, assign2, mc1 + mc2, o1 + o2