Example #1
0
    def forward(self, data):
        """Return one logit per graph for the batched graph ``data``.

        Pipeline:

        1. Embed node features with ``self.inputnet``.
        2. Run two stages. Each stage rebuilds an undirected kNN graph
           (``self.k`` neighbours, no self-loops) in the current feature
           space and applies an EdgeConv layer. It then clusters nodes with
           graclus, weighted by ``normalized_cut_2d`` on the node features.
        3. The first clustering average-pools the whole graph object. The
           second only pools features and batch. The result is then
           mean-pooled per graph and passed to ``self.output``.

        The caller's ``data`` is left untouched. The first stage writes to a
        shallow copy; later stages work on objects returned by ``avg_pool``.

        Args:
            data: batched graph exposing ``x`` and ``batch``. ``batch`` is
                handed directly to ``knn_graph`` and ``avg_pool_x``.

        Returns:
            ``self.output(...)`` applied to the per-graph features, with the
            trailing dimension squeezed. This is presumably shape
            ``[num_graphs]`` for a single-output head — confirm against
            ``self.output``.
        """
        import copy

        # Shallow copy. Previously data.x / data.edge_index / data.edge_attr
        # were assigned on the caller's object, so a second forward() on the
        # same batch would run inputnet on already-embedded features.
        data = copy.copy(data)
        data.x = self.inputnet(data.x)

        # Stage 1: dynamic kNN graph in embedding space, symmetrised. The kNN
        # flow is taken from the conv that consumes this graph.
        data.edge_index = to_undirected(
            knn_graph(data.x,
                      self.k,
                      data.batch,
                      loop=False,
                      flow=self.edgeconv1.flow))
        data.x = self.edgeconv1(data.x, data.edge_index)

        weight = normalized_cut_2d(data.edge_index, data.x)
        cluster = graclus(data.edge_index, weight, data.x.size(0))
        # Any incoming edge_attr describes the original edges, not the kNN
        # edges built above, so it must not be pooled along with them.
        data.edge_attr = None
        data = avg_pool(cluster, data)

        # Stage 2: the same recipe on the coarsened graph.
        data.edge_index = to_undirected(
            knn_graph(data.x,
                      self.k,
                      data.batch,
                      loop=False,
                      flow=self.edgeconv2.flow))
        data.x = self.edgeconv2(data.x, data.edge_index)

        weight = normalized_cut_2d(data.edge_index, data.x)
        cluster = graclus(data.edge_index, weight, data.x.size(0))
        # Only features and graph membership are needed from here on.
        x, batch = avg_pool_x(cluster, data.x, data.batch)

        x = global_mean_pool(x, batch)
        return self.output(x).squeeze(-1)
def test_single_voxel_grid():
    """Check voxel_grid cluster ids and the result of avg_pool on them.

    With voxel size 5, all five diagonal points share a voxel. The batch
    vector alone separates them into clusters 0 and 1. Without a batch
    vector, every point lands in cluster 0. Every edge connects two nodes of
    the same cluster, so after pooling only self-loops would remain, and
    avg_pool drops them.
    """
    pos = torch.Tensor([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4]])
    edge_index = torch.tensor([[0, 0, 3], [1, 2, 4]])
    batch = torch.tensor([0, 0, 0, 1, 1])
    x = torch.randn(5, 16)

    cluster = voxel_grid(pos, size=5, batch=batch)
    assert cluster.tolist() == [0, 0, 0, 1, 1]

    data = Batch(x=x, edge_index=edge_index, pos=pos, batch=batch)
    data = avg_pool(cluster, data)
    # The pooled result used to be computed but never checked.
    expected_x = torch.stack([x[:3].mean(dim=0), x[3:].mean(dim=0)])
    assert torch.allclose(data.x, expected_x, atol=1e-6)
    assert data.pos.tolist() == [[1, 1], [3.5, 3.5]]
    assert data.batch.tolist() == [0, 1]
    assert data.edge_index.size(1) == 0

    cluster_no_batch = voxel_grid(pos, size=5)
    assert cluster_no_batch.tolist() == [0, 0, 0, 0, 0]

    data_no_batch = Batch(x=x, edge_index=edge_index, pos=pos)
    data_no_batch = avg_pool(cluster_no_batch, data_no_batch)
    assert torch.allclose(data_no_batch.x, x.mean(dim=0, keepdim=True),
                          atol=1e-6)
    assert data_no_batch.pos.tolist() == [[2, 2]]
    assert data_no_batch.edge_index.size(1) == 0
Example #3
0
    def forward(self, data):
        """Run a three-level graph encoder/decoder; return per-node sigmoid scores.

        Two coarser graphs are built up front with ``graclus`` + ``avg_pool``
        (``downsample_1``, ``downsample_2``). On each coarser level:

        - the features from the finer level are max-pooled with the same
          clustering;
        - they are concatenated with an affine projection
          (``affine1``/``affine2``) of the average-pooled input features;
        - the level's convolution is applied.

        Coarse outputs are brought back with ``knn_interpolate`` over node
        positions and concatenated with the finer-level features. Two more
        convolutions and an MLP head ending in ``torch.sigmoid`` follow.

        ``data`` itself is not modified.

        Args:
            data: graph (batch) with ``x``, ``edge_index`` and ``pos``.
                ``batch`` is optional.

        Returns:
            ``torch.sigmoid(self.out(x))`` computed on the input graph's nodes.
        """
        import copy

        x, edge_index_1 = data.x, data.edge_index
        # Define the downscaled graphs. graclus is run without edge weights.
        cluster1 = graclus(edge_index_1, num_nodes=x.shape[0])
        downsample_1 = avg_pool(cluster1, data)
        edge_index_2 = downsample_1.edge_index
        cluster2 = graclus(edge_index_2, num_nodes=downsample_1.x.shape[0])
        downsample_2 = avg_pool(cluster2, downsample_1)
        edge_index_3 = downsample_2.edge_index

        # Graph-membership vectors for knn_interpolate. Without them,
        # neighbours may be taken from other graphs of a mini-batch.
        # getattr tolerates inputs that carry no batch vector.
        batch_0 = getattr(data, 'batch', None)
        batch_1 = getattr(downsample_1, 'batch', None)
        batch_2 = getattr(downsample_2, 'batch', None)

        # Level 1 (input resolution).
        x = self.conv1(x, edge_index_1)
        x = self.s1(x)

        # Level 2: max-pool level-1 features onto the level-2 nodes. Work on a
        # shallow copy; the original aliased `data` here and overwrote the
        # caller's data.x with the conv output.
        inter1 = copy.copy(data)
        inter1.x = x
        inter1 = max_pool(cluster1, inter1)
        x2 = torch.cat((self.affine1(downsample_1.x), inter1.x), dim=1)
        x2 = self.conv2(x2, edge_index_2)
        x2 = self.s2(x2)

        # Level 3: same pattern. inter1 is now a fresh object returned by
        # max_pool, so reusing it does not touch the caller's data.
        inter1.x = x2
        inter2 = max_pool(cluster2, inter1)
        x3 = torch.cat((self.affine2(downsample_2.x), inter2.x), dim=1)
        x3 = self.conv3(x3, edge_index_3)
        x3 = self.s3(x3)

        # Decoder: interpolate back level by level and concatenate skips.
        x3 = knn_interpolate(x3, downsample_2.pos, downsample_1.pos,
                             batch_x=batch_2, batch_y=batch_1)
        x2 = torch.cat((x2, x3), dim=1)
        x2 = knn_interpolate(x2, downsample_1.pos, data.pos,
                             batch_x=batch_1, batch_y=batch_0)
        x = torch.cat((x, x2), dim=1)

        x = self.conv4(x, edge_index_1)
        x = self.s4(x)
        x = self.conv5(x, edge_index_1)
        x = self.s5(x)
        x = self.s6(self.lin1(x))
        x = self.s7(self.lin2(x))

        return torch.sigmoid(self.out(x))
def test_avg_pool():
    """avg_pool mean-pools x/pos per cluster, merges edges and pools batch.

    Nodes {0, 2} form cluster 0, {1, 3} cluster 1 and {4, 5} cluster 2.
    Graph 0 is the complete directed graph on nodes 0-3; graph 1 is the edge
    pair 4 <-> 5. After pooling:

    - Edges inside a cluster vanish.
    - The four unit-weight edges from cluster 0 to cluster 1 merge into one
      edge with attribute 4, and likewise in the reverse direction.
    """
    assignment = torch.tensor([0, 1, 0, 1, 2, 2])
    feats = torch.Tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]])
    coords = torch.Tensor([[i, i] for i in range(6)])
    src = [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5]
    dst = [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 4]
    edges = torch.tensor([src, dst])
    weights = torch.Tensor([1] * len(src))
    graph_ids = torch.tensor([0, 0, 0, 0, 1, 1])

    graph = Batch(x=feats, pos=coords, edge_index=edges, edge_attr=weights,
                  batch=graph_ids)
    pooled = avg_pool(assignment, graph, transform=lambda g: g)

    assert pooled.x.tolist() == [[3, 4], [5, 6], [10, 11]]
    assert pooled.pos.tolist() == [[1, 1], [2, 2], [4.5, 4.5]]
    assert pooled.edge_index.tolist() == [[0, 1], [1, 0]]
    assert pooled.edge_attr.tolist() == [4, 4]
    assert pooled.batch.tolist() == [0, 0, 1]
Example #5
0
    def forward(self, data):
        """Encode ``data`` through three graclus pooling levels, then decode back.

        Encoder, per level:

        - two ELU-activated convolutions, plus batch norm on levels 2 and 3;
        - graclus clustering with ``normalized_cut_2d`` weights on
          ``data.pos``;
        - ``avg_pool`` with a ``T.Cartesian(cat=False)`` transform.

        Decoder:

        - score the coarsest graph and unpool it with
          ``recover_grid_barycentric``;
        - add a skip connection scored from the level-2 pooled graph;
        - unpool twice more and apply ``self.convout``.

        The caller's ``data`` is not modified; the first level works on a
        shallow copy.

        Args:
            data: graph (batch) with ``x``, ``pos``, ``edge_index`` and
                ``edge_attr``. ``batch`` is optional.

        Returns:
            The node features produced by ``self.convout``. This is presumably
            one row per input node, assuming ``recover_grid_barycentric``
            restores the pre-pool graph — confirm.
        """
        import copy

        # Shallow copy. Level 1 used to assign data.x on the caller's object,
        # so a second forward() on the same batch would feed conv1a the
        # already-convolved features.
        data = copy.copy(data)

        # Encoder. Original channel / node-set annotations per level:
        # (1/32, V_0/V_1), (32/64, V_1/V_2), (64/128, V_2/V_3).
        data, level1 = self._encode_level(data, self.conv1a, self.conv1b)
        data, level2 = self._encode_level(data, self.conv2a, self.conv2b,
                                          self.bn2)
        # Keep the level-2 pooled graph for the skip connection below.
        pool2 = data.clone()
        data, level3 = self._encode_level(data, self.conv3a, self.conv3b,
                                          self.bn3)

        # Decoder: score the coarsest graph and unpool it onto the graph
        # that level 3 pooled from, i.e. the one cloned into pool2.
        data.x = F.elu(self.score_fr1(data.x, data.edge_index, data.edge_attr))
        data = self._decode_level(data, level3)
        data.x = F.elu(self.score_fr2(data.x, data.edge_index, data.edge_attr))

        # Skip connection from the level-2 pooled graph.
        pool2.x = F.elu(
            self.score_pool2(pool2.x, pool2.edge_index, pool2.edge_attr))
        data.x = data.x + pool2.x
        data = self._decode_level(data, level2)
        data = self._decode_level(data, level1)

        data.x = F.elu(self.convout(data.x, data.edge_index, data.edge_attr))
        return data.x

    def _encode_level(self, data, conv_a, conv_b, norm=None):
        """Apply one encoder level and pool it; return ``(pooled, level)``.

        ``level`` is the tuple ``(pos, edge_index, batch, cluster, weights)``
        that ``_decode_level`` needs to undo the pooling. ``data.x`` is
        reassigned on the object passed in.
        """
        # Pre-pool geometry and connectivity, needed to recover this grid.
        pos = data.pos
        edge_index = data.edge_index
        batch = getattr(data, 'batch', None)
        # Snapshot of this level's input features, detached from autograd.
        x_pre = data.x.clone().detach()

        data.x = F.elu(conv_a(data.x, data.edge_index, data.edge_attr))
        data.x = F.elu(conv_b(data.x, data.edge_index, data.edge_attr))
        if norm is not None:
            data.x = norm(data.x)

        weight = normalized_cut_2d(data.edge_index, data.pos)
        cluster = graclus(data.edge_index, weight, data.x.size(0))
        weights = pweights(x_pre, cluster)
        pooled = avg_pool(cluster, data, transform=T.Cartesian(cat=False))
        return pooled, (pos, edge_index, batch, cluster, weights)

    @staticmethod
    def _decode_level(data, level):
        """Undo one ``_encode_level`` pooling via ``recover_grid_barycentric``."""
        pos, edge_index, batch, cluster, weights = level
        return recover_grid_barycentric(data,
                                        weights=weights,
                                        pos=pos,
                                        edge_index=edge_index,
                                        cluster=cluster,
                                        batch=batch,
                                        transform=T.Cartesian(cat=False))