def encode(self, whole, wholeAdj, lengs, refMat, maxNodes):
    """Encode a batch of dense graphs through two dense MinCut pooling stages.

    Pipeline (all layers are attributes of ``self``):
      1. ``sage1(whole, wholeAdj)`` -> tanh -> ``bano1`` -> ``drop5``
      2. ``sage22`` (node features only, no adjacency) -> relu -> ``bano2``
         -> ``drop4``
      3. ``dense_mincut_pool`` with ``relu(poolit1(hidden2, wholeAdj))`` as
         the assignment argument
      4. ``sage3(out1, adj1)`` -> relu -> ``bano3`` -> ``drop3``
      5. ``sage42`` (node features only) -> tanh -> ``bano4`` -> ``drop3``
      6. ``dense_mincut_pool`` with ``leaky_relu(poolit2(hidden4, adj1))`` as
         the assignment argument
      7. ``sage5(out2, adj2)`` -> tanh -> ``bano5`` -> ``drop3``
      8. ``tr1`` -> relu -> ``drop3`` -> ``tr2``

    Args:
        whole: node-feature input to ``sage1``.
        wholeAdj: adjacency for ``whole``; used by ``sage1``, ``poolit1`` and
            the first pooling call.
        lengs, refMat, maxNodes: accepted for interface compatibility; not
            read by this method.

    Returns:
        ``(self.tr2(h), self.tr2(h), adj2)``, where ``h`` is the output of
        step 8's dropout and ``adj2`` is the adjacency returned by the second
        pooling call.
    """
    ### 1
    hidden1 = self.sage1(whole, wholeAdj)
    hidden1 = F.tanh(hidden1)  # BxNxL1
    hidden1 = self.bano1(hidden1)
    hidden1 = self.drop5(hidden1)

    ### 2 -- sage22 receives only the node features, not the adjacency
    hidden2 = self.sage22(hidden1)
    hidden2 = F.relu(hidden2)  # BxNxL2
    hidden2 = self.bano2(hidden2)
    hidden2 = self.drop4(hidden2)

    ### Pool1
    pool1 = F.relu(self.poolit1(hidden2, wholeAdj))  # BxNxC1
    # NOTE(review): no node mask is passed, so if `whole` is zero-padded the
    # padded nodes take part in the pooling -- confirm this is intended.
    # NOTE(review): the mincut and orthogonality losses returned by
    # dense_mincut_pool are discarded here (and after Pool2); if training is
    # supposed to use them they must be returned to the caller.
    out1, adj1, _, _ = dense_mincut_pool(hidden2, wholeAdj, pool1)

    ### 3
    hidden3 = self.sage3(out1, adj1)
    hidden3 = F.relu(hidden3)
    hidden3 = self.bano3(hidden3)
    hidden3 = self.drop3(hidden3)

    ### 4 -- sage42 receives only the node features, not adj1
    hidden4 = self.sage42(hidden3)
    hidden4 = F.tanh(hidden4)
    hidden4 = self.bano4(hidden4)
    hidden4 = self.drop3(hidden4)

    ### Pool2
    pool2 = F.leaky_relu(self.poolit2(hidden4, adj1))  # BxN/4xC2
    out2, adj2, _, _ = dense_mincut_pool(hidden4, adj1, pool2)
    out2 = self.sage5(out2, adj2)
    out2 = F.tanh(out2)
    out2 = self.bano5(out2)
    out2 = self.drop3(out2)

    ### 5 -- dense head
    hidden5 = self.tr1(out2)
    hidden5 = F.relu(hidden5)
    hidden5 = self.drop3(hidden5)
    # NOTE(review): tr2 is applied twice to the same input. If the first two
    # outputs are meant to come from separate heads (e.g. a mean and a
    # log-variance), a second projection layer is probably intended -- confirm
    # before changing; the double call is kept to preserve current behavior.
    return self.tr2(hidden5), self.tr2(hidden5), adj2
def forward(self, x, edge_index):
    """Run the convolution stack and gather the requested outputs in a dict.

    ``x`` goes through every layer of ``self.convs``; all layers except the
    last are followed by ``self.relu``. Further stages depend on flags on
    ``self``:

    * ``use_mincut``: densify the embeddings (as one graph, since ``None``
      is passed as the batch vector) and call ``dense_mincut_pool`` with
      ``self.pool1(z)`` as the assignment. Only the assignment and the two
      returned losses are kept; the pooled features and adjacency are unused.
    * ``variational``: ``conv_mu`` / ``conv_logvar`` heads feed
      ``self.reparametrize`` to produce ``'z'``. Otherwise ``'z'`` is the
      raw convolution output.
    * ``prediction_task``: ``'y'`` is ``classification_layer`` applied to
      the raw convolution output (not the reparametrised sample).
    * ``activate_kmeans``: consulted only when ``use_mincut`` is off; stores
      ``self.kmeans`` of the raw convolution output under ``'s'``.

    Returns:
        dict whose keys appear in the order
        ``'mu', 'logvar', 'z', 'y', 's', 'mc1', 'o1'`` (each only if enabled).
    """
    hidden_layers, output_layer = self.convs[:-1], self.convs[-1]
    emb = x
    for layer in hidden_layers:
        emb = self.relu(layer(emb, edge_index))
    emb = output_layer(emb, edge_index)

    if self.use_mincut:
        dense_emb, node_mask = to_dense_batch(emb, None)
        dense_adj = to_dense_adj(edge_index, None)
        assignment = self.pool1(emb)
        _, _, mincut_loss, ortho_loss = dense_mincut_pool(
            dense_emb, dense_adj, assignment, node_mask
        )

    result = {}
    if not self.variational:
        result['z'] = emb
    else:
        mu = self.conv_mu(emb, edge_index)
        logvar = self.conv_logvar(emb, edge_index)
        result.update(mu=mu, logvar=logvar, z=self.reparametrize(mu, logvar))

    if self.prediction_task:
        result['y'] = self.classification_layer(emb)

    if self.use_mincut:
        result.update(s=assignment, mc1=mincut_loss, o1=ortho_loss)
    elif self.activate_kmeans:
        result['s'] = self.kmeans(emb)
    return result
def forward(self, x, edge_index, batch):
    """Classify a batch of graphs using two dense MinCut pooling stages.

    Args:
        x: node features, consumed by ``self.conv1`` with ``edge_index``.
        edge_index: edge list of the batched sparse graph.
        batch: node-to-graph assignment passed to ``to_dense_batch`` and
            ``to_dense_adj``.

    Returns:
        ``(log_probs, mincut_loss, ortho_loss)``. ``log_probs`` is
        ``log_softmax`` over the last dimension of ``self.lin2``'s output.
        The two losses are the sums of the corresponding values returned by
        the two ``dense_mincut_pool`` calls.
    """
    # Sparse message passing, then pad nodes and edges into dense batches.
    h = F.relu(self.conv1(x, edge_index))
    h, node_mask = to_dense_batch(h, batch)
    dense_adj = to_dense_adj(edge_index, batch)

    # First pooling forwards the mask produced by to_dense_batch.
    h, dense_adj, cut_a, ortho_a = dense_mincut_pool(
        h, dense_adj, self.pool1(h), node_mask
    )
    h = F.relu(self.conv2(h, dense_adj))

    # Second pooling operates on the already-pooled tensors, without a mask.
    h, dense_adj, cut_b, ortho_b = dense_mincut_pool(h, dense_adj, self.pool2(h))
    h = self.conv3(h, dense_adj)

    # Readout: mean over dim 1 (presumably the cluster axis), then a 2-layer MLP.
    graph_repr = h.mean(dim=1)
    logits = self.lin2(F.relu(self.lin1(graph_repr)))
    return F.log_softmax(logits, dim=-1), cut_a + cut_b, ortho_a + ortho_b
def test_dense_mincut_pool():
    """Check shapes, loss ranges and mask handling of ``dense_mincut_pool``.

    Inputs are drawn from a locally seeded generator, so the test is
    reproducible and does not disturb the global RNG. Each graph is also
    guaranteed at least one unmasked node. With a fully random mask, a graph
    could end up with every node masked out; its pooled losses then become
    0/0 (NaN), and the range assertions fail intermittently.
    """
    batch_size, num_nodes, channels, num_clusters = (2, 20, 16, 10)
    gen = torch.Generator().manual_seed(0)
    x = torch.randn((batch_size, num_nodes, channels), generator=gen)
    adj = torch.ones((batch_size, num_nodes, num_nodes))
    s = torch.randn((batch_size, num_nodes, num_clusters), generator=gen)
    mask = torch.randint(0, 2, (batch_size, num_nodes), dtype=torch.bool,
                         generator=gen)
    # Guarantee every graph keeps at least one real node (see docstring).
    mask[:, 0] = True

    out, out_adj, mincut_loss, ortho_loss = dense_mincut_pool(x, adj, s, mask)
    assert out.size() == (batch_size, num_clusters, channels)
    assert out_adj.size() == (batch_size, num_clusters, num_clusters)
    assert -1 <= mincut_loss <= 0
    assert 0 <= ortho_loss <= 2

    # Features of masked-out nodes must not influence the pooled features.
    x_shifted = x.clone()
    x_shifted[~mask] += 1.0
    out_shifted, _, _, _ = dense_mincut_pool(x_shifted, adj, s, mask)
    assert torch.allclose(out, out_shifted)
def encode(self, whole, adj, lengs, mask, maxNodes):
    """Encode a batch of dense graphs through two dense MinCut pooling stages.

    Every convolution stage is ``conv(features, adjacency) -> ReLU ->
    batch-norm -> dropout``, using the ``sage*`` / ``bano*`` / ``drop*``
    layers on ``self``. The assignment argument for each pooling call is
    ``poolit1`` / ``poolit2`` applied to the current node features. Only the
    first pooling call receives ``mask``.

    Args:
        whole: node-feature input to ``sage1``.
        adj: adjacency matching ``whole``.
        lengs, maxNodes: accepted for interface compatibility; not read here.
        mask: node mask forwarded to the first ``dense_mincut_pool`` call.

    Returns:
        A 7-tuple ``(self.tr2(h), self.tr2(h), adj2, pool1, pool2,
        mc1 + mc2, o1 + o2)``. It holds the ``tr2`` projection (evaluated
        twice), the adjacency from the second pooling call, and the outputs
        of ``poolit1`` / ``poolit2``. The last two entries are the summed
        mincut and orthogonality losses from both pooling calls.
    """
    def conv_block(conv, norm, drop, feats, graph):
        # conv -> ReLU -> norm -> dropout, in exactly this order.
        return drop(norm(F.relu(conv(feats, graph))))

    h = conv_block(self.sage1, self.bano1, self.drop5, whole, adj)
    h = conv_block(self.sage2, self.bano2, self.drop4, h, adj)

    assign1 = self.poolit1(h)
    h, coarse_adj, cut1, orth1 = dense_mincut_pool(h, adj, assign1, mask)
    h = conv_block(self.sage3, self.bano3, self.drop3, h, coarse_adj)

    assign2 = self.poolit2(h)
    h, coarse_adj, cut2, orth2 = dense_mincut_pool(h, coarse_adj, assign2)
    h = conv_block(self.sage5, self.bano5, self.drop3, h, coarse_adj)

    return (
        self.tr2(h),
        self.tr2(h),
        coarse_adj,
        assign1,
        assign2,
        cut1 + cut2,
        orth1 + orth2,
    )