def forward(self, data):
    """Classify each graph in ``data`` with two dynamic EdgeConv + graclus rounds.

    Pipeline:
      1. ``self.inputnet`` embeds the node features.
      2. Twice: build an undirected kNN graph (``self.k`` neighbours, no self
         loops) in the current feature space, apply the matching EdgeConv,
         weight edges by normalized cut on the node features and cluster the
         nodes with graclus.
      3. After the first round the clusters are merged with ``avg_pool``; after
         the second, ``avg_pool_x`` pools only features and batch ids.
      4. ``global_mean_pool`` reduces per batch index, then ``self.output``
         produces the logits.

    Note: the caller's ``data`` is modified in place (``x``, ``edge_index`` and
    ``edge_attr`` are reassigned during the first round).

    Returns:
        ``self.output(...)`` with its last dimension squeezed.
    """

    def _edgeconv_then_cluster(graph, conv):
        # Rebuild the graph in feature space; the kNN flow must match the conv's
        # message-passing direction.
        graph.edge_index = to_undirected(
            knn_graph(graph.x, self.k, graph.batch, loop=False, flow=conv.flow))
        graph.x = conv(graph.x, graph.edge_index)
        cut_weight = normalized_cut_2d(graph.edge_index, graph.x)
        return graclus(graph.edge_index, cut_weight, graph.x.size(0))

    data.x = self.inputnet(data.x)

    first_clusters = _edgeconv_then_cluster(data, self.edgeconv1)
    # Any pre-existing edge_attr no longer lines up with the rebuilt kNN
    # edge_index, so drop it before pooling.
    data.edge_attr = None
    data = avg_pool(first_clusters, data)

    second_clusters = _edgeconv_then_cluster(data, self.edgeconv2)
    pooled_x, pooled_batch = avg_pool_x(second_clusters, data.x, data.batch)

    graph_repr = global_mean_pool(pooled_x, pooled_batch)
    return self.output(graph_repr).squeeze(-1)
def test_single_voxel_grid():
    """voxel_grid + avg_pool when every point of a graph falls into one voxel.

    With ``size=5`` all five points share a single voxel, so clusters can only
    differ by batch index. The pooled graph must then hold one node per cluster
    whose features/positions are the cluster means. All input edges are
    intra-cluster, so after pooling only self-loops would remain, and those are
    dropped.
    """
    pos = torch.tensor([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4]],
                       dtype=torch.float)
    edge_index = torch.tensor([[0, 0, 3], [1, 2, 4]])
    batch = torch.tensor([0, 0, 0, 1, 1])
    x = torch.randn(5, 16)

    # Batched: one voxel per graph -> one cluster per graph.
    cluster = voxel_grid(pos, size=5, batch=batch)
    assert cluster.tolist() == [0, 0, 0, 1, 1]

    data = Batch(x=x, edge_index=edge_index, pos=pos, batch=batch)
    data = avg_pool(cluster, data)
    assert data.batch.tolist() == [0, 1]
    assert data.pos.tolist() == [[1, 1], [3.5, 3.5]]
    assert torch.allclose(
        data.x, torch.stack([x[:3].mean(dim=0), x[3:].mean(dim=0)]))
    assert data.edge_index.size(1) == 0

    # Unbatched: everything collapses into a single cluster.
    cluster_no_batch = voxel_grid(pos, size=5)
    assert cluster_no_batch.tolist() == [0, 0, 0, 0, 0]

    data_no_batch = Batch(x=x, edge_index=edge_index, pos=pos)
    data_no_batch = avg_pool(cluster_no_batch, data_no_batch)
    assert data_no_batch.pos.tolist() == [[2, 2]]
    assert torch.allclose(data_no_batch.x, x.mean(dim=0, keepdim=True))
    assert data_no_batch.edge_index.size(1) == 0
def forward(self, data):
    """Three-level graclus U-Net over ``data``; returns per-node sigmoid scores.

    Encoder:
      * Two coarser graphs are precomputed. Level 2 is graclus on the input
        graph followed by ``avg_pool``; level 3 repeats that on level 2.
      * ``conv1``/``s1`` run on the input graph. Their output is ``max_pool``-ed
        onto level 2, concatenated with ``affine1`` of the level-2 averaged
        features, and passed through ``conv2``/``s2``.
      * The same pattern repeats for level 3 with ``affine2``/``conv3``/``s3``.

    Decoder:
      * Features are brought back up with ``knn_interpolate`` and concatenated
        with the skip features of the same level.
      * ``conv4``/``s4``, ``conv5``/``s5``, ``lin1``/``s6``, ``lin2``/``s7`` and
        ``out`` then follow, with a final sigmoid.
      * ``self.s1``..``self.s7`` are assumed to be elementwise modules
        (activations/normalisation). TODO confirm.

    The caller's ``data`` is not modified.
    """
    import copy

    x, edge_index_1 = data.x, data.edge_index

    # Precompute the coarsened graphs (levels 2 and 3) from the input graph.
    cluster1 = graclus(edge_index_1, num_nodes=x.shape[0])
    downsample_1 = avg_pool(cluster1, data)
    edge_index_2 = downsample_1.edge_index
    cluster2 = graclus(edge_index_2, num_nodes=downsample_1.x.shape[0])
    downsample_2 = avg_pool(cluster2, downsample_1)
    edge_index_3 = downsample_2.edge_index

    # Level 1.
    x = self.conv1(x, edge_index_1)
    x = self.s1(x)

    # Level 2: pool the level-1 features with the same clustering. Use a
    # shallow copy so assigning .x does not overwrite the caller's data.x.
    level1 = copy.copy(data)
    level1.x = x
    pooled_1 = max_pool(cluster1, level1)
    x2 = torch.cat((self.affine1(downsample_1.x), pooled_1.x), dim=1)
    x2 = self.conv2(x2, edge_index_2)
    x2 = self.s2(x2)

    # Level 3. pooled_1 is a fresh object owned by this method, so reusing it is safe.
    pooled_1.x = x2
    pooled_2 = max_pool(cluster2, pooled_1)
    x3 = torch.cat((self.affine2(downsample_2.x), pooled_2.x), dim=1)
    x3 = self.conv3(x3, edge_index_3)
    x3 = self.s3(x3)

    # Decoder. The batch vectors keep interpolation neighbours inside the
    # same graph when ``data`` is a mini-batch. They are None for a single
    # graph, which matches the unbatched call.
    x3 = knn_interpolate(x3, downsample_2.pos, downsample_1.pos,
                         batch_x=downsample_2.batch,
                         batch_y=downsample_1.batch)
    x2 = torch.cat((x2, x3), dim=1)
    x2 = knn_interpolate(x2, downsample_1.pos, data.pos,
                         batch_x=downsample_1.batch,
                         batch_y=data.batch)
    x = torch.cat((x, x2), dim=1)

    x = self.conv4(x, edge_index_1)
    x = self.s4(x)
    x = self.conv5(x, edge_index_1)
    x = self.s5(x)
    x = self.s6(self.lin1(x))
    x = self.s7(self.lin2(x))
    return torch.sigmoid(self.out(x))
def test_avg_pool():
    """avg_pool must average node features/positions within each cluster.

    It must also sum the attributes of edges merged between the same pair of
    clusters, drop edges that become self-loops, and pool the batch vector.
    """
    assignment = torch.tensor([0, 1, 0, 1, 2, 2])
    node_feats = torch.tensor(
        [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], dtype=torch.float)
    node_pos = torch.tensor(
        [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5]], dtype=torch.float)
    # Nodes 0-3 form a complete graph; nodes 4 and 5 are connected to each other.
    edges = torch.tensor([
        [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5],
        [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 4],
    ])
    edge_weights = torch.ones(14)
    graph_ids = torch.tensor([0, 0, 0, 0, 1, 1])

    pooled = avg_pool(
        assignment,
        Batch(x=node_feats, pos=node_pos, edge_index=edges,
              edge_attr=edge_weights, batch=graph_ids),
        transform=lambda graph: graph,
    )

    assert pooled.x.tolist() == [[3, 4], [5, 6], [10, 11]]
    assert pooled.pos.tolist() == [[1, 1], [2, 2], [4.5, 4.5]]
    assert pooled.edge_index.tolist() == [[0, 1], [1, 0]]
    assert pooled.edge_attr.tolist() == [4, 4]
    assert pooled.batch.tolist() == [0, 0, 1]
def forward(self, data):
    """Encode ``data`` through three pooling levels, then recover it level by level.

    Each encoder level does the following:
      * Remember the pre-pooling positions, edges and batch vector.
      * Apply two ELU-activated convolutions (``conv*a``, ``conv*b``) using
        ``edge_attr``. Batch norm follows on levels 2 and 3 only.
      * Cluster nodes with graclus, using normalized-cut weights computed from
        ``data.pos``.
      * Derive per-cluster weights with ``pweights`` from the detached
        pre-convolution features.
      * Average-pool, recomputing Cartesian edge attributes.

    Decoder:
      * ``score_fr1`` is applied on the coarsest graph, followed by one
        ``recover_grid_barycentric`` step and ``score_fr2``.
      * The graph saved right after the second pooling, scored by
        ``score_pool2``, is added as a skip connection.
      * Two more recovery steps follow, presumably restoring the input nodes
        (``recover_grid_barycentric`` is defined elsewhere; confirm).

    Note: ``data`` is modified in place by the first level (its ``x`` is overwritten).

    Returns:
        The node features after the last recovery step.
    """

    def encode_level(graph, conv_a, conv_b, norm=None):
        # Snapshot what the matching decoder step needs to undo this pooling.
        saved_pos, saved_edges = graph.pos, graph.edge_index
        features_before = graph.x.clone().detach()
        saved_batch = graph.batch if hasattr(graph, 'batch') else None

        graph.x = F.elu(conv_a(graph.x, graph.edge_index, graph.edge_attr))
        graph.x = F.elu(conv_b(graph.x, graph.edge_index, graph.edge_attr))
        if norm is not None:
            graph.x = norm(graph.x)

        cut = normalized_cut_2d(graph.edge_index, graph.pos)
        clusters = graclus(graph.edge_index, cut, graph.x.size(0))
        cluster_weights = pweights(features_before, clusters)

        coarse = avg_pool(clusters, graph, transform=T.Cartesian(cat=False))
        return coarse, (saved_pos, saved_edges, saved_batch, clusters, cluster_weights)

    def decode_level(graph, level):
        saved_pos, saved_edges, saved_batch, clusters, cluster_weights = level
        return recover_grid_barycentric(graph, weights=cluster_weights,
                                        pos=saved_pos, edge_index=saved_edges,
                                        cluster=clusters, batch=saved_batch,
                                        transform=T.Cartesian(cat=False))

    # Encoder: (1 -> 32 channels, V_0 -> V_1), (32 -> 64, V_1 -> V_2), (64 -> 128, V_2 -> V_3).
    data, level1 = encode_level(data, self.conv1a, self.conv1b)
    data, level2 = encode_level(data, self.conv2a, self.conv2b, self.bn2)
    skip = data.clone()  # graph after the second pooling, reused as a skip connection
    data, level3 = encode_level(data, self.conv3a, self.conv3b, self.bn3)

    # Decoder.
    data.x = F.elu(self.score_fr1(data.x, data.edge_index, data.edge_attr))
    data = decode_level(data, level3)
    data.x = F.elu(self.score_fr2(data.x, data.edge_index, data.edge_attr))
    skip.x = F.elu(self.score_pool2(skip.x, skip.edge_index, skip.edge_attr))
    data.x = data.x + skip.x
    data = decode_level(data, level2)
    data = decode_level(data, level1)
    return data.x