コード例 #1
0
    def forward(self, data):
        """Run a two-level pooled encoder followed by a kNN-interpolating decoder.

        The graph is coarsened twice with ``graclus`` + ``max_pool``; the coarse
        features are then brought back to the finer levels with
        ``knn_interpolate`` and concatenated with the features saved at each
        level before the final linear layers.

        Args:
            data: graph object providing ``x``, ``edge_index`` and ``pos``.
                It is NOT modified by this method.

        Returns:
            ``torch.sigmoid`` of ``self.out`` applied to the full-resolution
            node features.
        """
        # Local import keeps this block self-contained.
        import copy

        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index).relu()

        # First coarsening. Pool from a shallow copy so that the caller's
        # ``data.x`` is not overwritten with the conv1 activations (the
        # previous aliasing ``pooled_1 = data`` corrupted the input for any
        # later forward pass over the same object).
        cluster1 = graclus(edge_index, num_nodes=x.shape[0])
        level0 = copy.copy(data)
        level0.x = x
        pooled_1 = max_pool(cluster1, level0)
        edge_index_2 = pooled_1.edge_index
        x2 = self.conv2(pooled_1.x, edge_index_2).relu()

        # Second coarsening. ``pooled_1`` was created by ``max_pool`` above,
        # so updating its features in place does not leak to the caller.
        cluster2 = graclus(edge_index_2, num_nodes=x2.shape[0])
        pooled_1.x = x2
        pooled_2 = max_pool(cluster2, pooled_1)
        edge_index_3 = pooled_2.edge_index
        x3 = self.conv3(pooled_2.x, edge_index_3).relu()
        x3 = self.conv4(x3, edge_index_3).relu()

        # Decoder: coarsest -> middle level, merged with the middle features.
        # NOTE(review): knn_interpolate is called without batch vectors here;
        # confirm inputs are single graphs or that cross-graph neighbours are
        # acceptable.
        x3 = knn_interpolate(x3, pooled_2.pos, pooled_1.pos)
        x3 = torch.cat((x2, x3), dim=1)
        x3 = self.conv5(x3, edge_index_2).relu()

        # Decoder: middle level -> full resolution, merged with conv1 output.
        x3 = knn_interpolate(x3, pooled_1.pos, data.pos)
        x = torch.cat((x, x3), dim=1)
        x = self.lin1(x).relu()
        x = self.lin2(x).relu()

        return torch.sigmoid(self.out(x))
コード例 #2
0
    def forward(self, data):
        """Run ten graph convolutions with two pooled side branches.

        After conv3 and after conv6 the current features are max-pooled onto a
        ``graclus`` clustering, convolved on the coarse graph, interpolated
        back to the original positions with ``knn_interpolate`` and added to an
        affine projection of the full-resolution features. Both side results
        are summed with the conv9 output before the final layers.

        Args:
            data: graph object providing ``x``, ``edge_index`` and ``pos``.
                It is NOT modified by this method.

        Returns:
            ``torch.sigmoid`` of ``self.out`` applied to the node features.
        """
        # Local import keeps this block self-contained.
        import copy

        x, edge_index = data.x, data.edge_index
        x = self.pre_lin(x) if self.masif_descr else x
        x = self.s1(self.conv1(x, edge_index))
        x = self.s2(self.conv2(x, edge_index))
        x = self.s3(self.conv3(x, edge_index))

        # First side branch. Pool from a shallow copy: the previous aliasing
        # ``inter = data; inter.x = x`` overwrote the caller's ``data.x``.
        cluster = graclus(edge_index, num_nodes=x.shape[0])
        snapshot = copy.copy(data)
        snapshot.x = x
        inter = max_pool(cluster, snapshot)
        interx = self.inters1(self.interconv1(inter.x, inter.edge_index))
        inter = knn_interpolate(interx, inter.pos, data.pos)
        x1 = self.affine1(x)
        x1 += inter

        x = self.s4(self.conv4(x, edge_index))
        x = self.s5(self.conv5(x, edge_index))
        x = self.s6(self.conv6(x, edge_index))

        # Second side branch, reusing the same clustering.
        # NOTE(review): this branch reuses self.interconv1 and self.affine1
        # (only the activation differs: self.inters2) - confirm the weight
        # sharing is intentional.
        snapshot = copy.copy(data)
        snapshot.x = x
        inter = max_pool(cluster, snapshot)
        interx = self.inters2(self.interconv1(inter.x, inter.edge_index))
        inter = knn_interpolate(interx, inter.pos, data.pos)
        x2 = self.affine1(x)
        x2 += inter

        x = self.s7(self.conv7(x, edge_index))
        x = self.s8(self.conv8(x, edge_index))
        x = self.s9(self.conv9(x, edge_index))
        x = x + x1 + x2
        x = self.s10(self.conv10(x, edge_index))
        x = self.s11(self.lin1(x))
        x = self.s12(self.lin2(x))
        x = self.out(x)

        return torch.sigmoid(x)
コード例 #3
0
ファイル: point_net.py プロジェクト: timkartar/geobind
    def forward(self, x, pos, batch, x_skip, pos_skip, batch_skip):
        """Interpolate ``x`` onto ``pos_skip`` and pass the result through ``self.nn``.

        Args:
            x: features attached to the points at ``pos``.
            pos, batch: source positions and their batch vector.
            x_skip: optional skip features; when not ``None`` they are
                concatenated (``dim=1``) after the interpolated features.
            pos_skip, batch_skip: target positions and their batch vector.

        Returns:
            Tuple ``(features, pos_skip, batch_skip)``.
        """
        upsampled = knn_interpolate(x, pos, pos_skip, batch, batch_skip, k=self.k)
        merged = upsampled if x_skip is None else torch.cat([upsampled, x_skip], dim=1)
        return self.nn(merged), pos_skip, batch_skip
コード例 #4
0
    def predict_original_samples(self, batch, conv_type, output):
        """Upsample the network output onto the raw points of every sample.

        Arguments:
            batch -- processed batch; the prediction is stored on it as ``_pred``
            conv_type -- type of convolution (DENSE, PARTIAL_DENSE, etc...)
            output -- output predicted by the model

        Returns:
            dict mapping each sample's filename to an array made of the raw
            positions horizontally stacked with the predicted label column.
        """
        num_sample = BaseDataset.get_num_samples(batch, conv_type)
        if conv_type == "DENSE":
            # [B,N,L]
            output = output.reshape(num_sample, -1, output.shape[-1])

        batch._pred = output

        results = {}
        for index in range(num_sample):
            sample_id = batch.sampleid[index]
            raw_pos = self.test_dataset[0].get_raw(sample_id).pos
            raw_pos = raw_pos.to(output.device)

            prediction = BaseDataset.get_sample(batch, "_pred", index, conv_type)
            kept_ids = BaseDataset.get_sample(
                batch, SaveOriginalPosId.KEY, index, conv_type)

            # Spread the predictions from the kept points to all raw points.
            dense_prediction = knn_interpolate(
                prediction, raw_pos[kept_ids], raw_pos, k=3)
            _, winners = dense_prediction.max(1)
            labels = winners.unsqueeze(-1)

            stacked = np.hstack((raw_pos.cpu().numpy(), labels.cpu().numpy()))
            results[self.test_dataset[0].get_filename(sample_id)] = stacked
        return results
コード例 #5
0
def test_knn_interpolate():
    """Check ``knn_interpolate`` with k=2 on two small batched 2-D point sets."""
    features = torch.Tensor([[1], [10], [100], [-1], [-10], [-100]])
    source_pos = torch.Tensor(
        [[-1, 0], [0, 0], [1, 0], [-2, 0], [0, 0], [2, 0]])
    target_pos = torch.Tensor([[-1, -1], [1, 1], [-2, -2], [2, 2]])
    source_batch = torch.tensor([0, 0, 0, 1, 1, 1])
    target_batch = torch.tensor([0, 0, 1, 1])
    expected = [[4], [70], [-4], [-70]]

    result = knn_interpolate(
        features, source_pos, target_pos, source_batch, target_batch, k=2)
    assert result.tolist() == expected
コード例 #6
0
 def forward(self, data):
     """Interpolate features onto the skip positions and apply ``self.nn``.

     Args:
         data: 6-tuple ``(x, pos, batch, x_skip, pos_skip, batch_skip)``.
             ``x_skip`` may be ``None``, in which case nothing is concatenated.

     Returns:
         Tuple ``(features, pos_skip, batch_skip)``.
     """
     x, pos, batch, x_skip, pos_skip, batch_skip = data
     upsampled = knn_interpolate(x, pos, pos_skip, batch, batch_skip, k=self.k)
     if x_skip is None:
         merged = upsampled
     else:
         merged = torch.cat([upsampled, x_skip], dim=1)
     return (self.nn(merged), pos_skip, batch_skip)
コード例 #7
0
    def forward(self, data):
        """Run an encoder/decoder over two graclus/max-pool coarsening levels.

        Three convolutions run at full resolution, three on the first pooled
        graph and one on the second pooled graph; the features are then
        interpolated back level by level with ``knn_interpolate`` and refined
        by further convolutions and linear layers.

        Args:
            data: graph object providing ``x``, ``edge_index`` and ``pos``.
                It is NOT modified by this method.

        Returns:
            ``torch.sigmoid`` of ``self.out`` applied to the node features.
        """
        # Local import keeps this block self-contained.
        import copy

        x, edge_index = data.x, data.edge_index
        x = self.pre_lin(x) if self.masif_descr else x
        x = self.s1(self.conv1(x, edge_index))
        x = self.s2(self.conv2(x, edge_index))
        x = self.s3(self.conv3(x, edge_index))

        # First coarsening. Pool from a shallow copy: the previous aliasing
        # ``inter1 = data; inter1.x = x`` overwrote the caller's ``data.x``.
        cluster1 = graclus(edge_index, num_nodes=x.shape[0])
        full_res = copy.copy(data)
        full_res.x = x
        inter1 = max_pool(cluster1, full_res)
        x = self.s4(self.conv4(inter1.x, inter1.edge_index))
        edge_index = inter1.edge_index
        x = self.s5(self.conv5(x, edge_index))
        x = self.s6(self.conv6(x, edge_index))

        # Second coarsening. ``inter1`` was created by ``max_pool`` above, so
        # replacing its features in place stays local to this call.
        cluster2 = graclus(edge_index, num_nodes=x.shape[0])
        inter1.x = x
        inter2 = max_pool(cluster2, inter1)
        x = self.s7(self.conv7(inter2.x, inter2.edge_index))

        # Back to the first pooled level (edge_index still refers to it).
        x = knn_interpolate(x, inter2.pos, inter1.pos)
        x = self.s8(self.conv8(x, edge_index))

        # Back to full resolution.
        x = knn_interpolate(x, inter1.pos, data.pos)
        edge_index = data.edge_index
        x = self.s9(self.conv9(x, edge_index))
        x = self.s10(self.conv10(x, edge_index))
        x = self.s11(self.lin1(x))
        x = self.s12(self.lin2(x))
        x = self.out(x)

        return torch.sigmoid(x)
コード例 #8
0
    def forward(self, data):
        """Combine learned pooled features with average-pooled input features.

        Two coarser versions of the input are first built with ``graclus`` +
        ``avg_pool``. Convolution outputs are max-pooled onto the same
        clusterings, concatenated with affine projections of the
        average-pooled inputs, and finally interpolated back to full
        resolution with ``knn_interpolate``.

        Args:
            data: graph object providing ``x``, ``edge_index`` and ``pos``.
                It is NOT modified by this method.

        Returns:
            ``torch.sigmoid`` of ``self.out`` applied to the node features.
        """
        # Local import keeps this block self-contained.
        import copy

        x, edge_index_1 = data.x, data.edge_index

        # Define downscaled samples from the untouched input features.
        cluster1 = graclus(edge_index_1, num_nodes=x.shape[0])
        downsample_1 = avg_pool(cluster1, data)
        edge_index_2 = downsample_1.edge_index
        cluster2 = graclus(edge_index_2, num_nodes=downsample_1.x.shape[0])
        downsample_2 = avg_pool(cluster2, downsample_1)
        edge_index_3 = downsample_2.edge_index

        x = self.s1(self.conv1(x, edge_index_1))

        # Pool from a shallow copy: the previous aliasing
        # ``inter1 = data; inter1.x = x`` overwrote the caller's ``data.x``.
        full_res = copy.copy(data)
        full_res.x = x
        inter1 = max_pool(cluster1, full_res)
        x2 = torch.cat((self.affine1(downsample_1.x), inter1.x), dim=1)
        x2 = self.s2(self.conv2(x2, edge_index_2))

        # ``inter1`` was created by ``max_pool`` above, so replacing its
        # features in place stays local to this call.
        inter1.x = x2
        inter2 = max_pool(cluster2, inter1)
        x3 = torch.cat((self.affine2(downsample_2.x), inter2.x), dim=1)
        x3 = self.s3(self.conv3(x3, edge_index_3))

        # Upsample coarsest -> middle -> full resolution, concatenating the
        # features saved at each level.
        x3 = knn_interpolate(x3, downsample_2.pos, downsample_1.pos)
        x2 = torch.cat((x2, x3), dim=1)
        x2 = knn_interpolate(x2, downsample_1.pos, data.pos)
        x = torch.cat((x, x2), dim=1)

        x = self.s4(self.conv4(x, edge_index_1))
        x = self.s5(self.conv5(x, edge_index_1))
        x = self.s6(self.lin1(x))
        x = self.s7(self.lin2(x))

        return torch.sigmoid(self.out(x))
コード例 #9
0
    def forward(self, x, pos, batch, x_skip, pos_skip, batch_skip):
        """Upsample coarse features to the skip positions and run ``self.mlp``.

        Args:
            x, pos, batch: coarse features, their positions and batch vector.
            x_skip: optional dense features to append (``dim=1``) to the
                interpolated ones; skipped when ``None``.
            pos_skip, batch_skip: dense positions and their batch vector.

        Returns:
            Tuple ``(features, pos_skip, batch_skip)``.
        """
        # Interpolate coarse features to dense positions.
        features = gnn.knn_interpolate(
            x, pos, pos_skip, batch_x=batch, batch_y=batch_skip, k=self.k)

        # Concatenate new and old dense features.
        if x_skip is not None:
            features = torch.cat((features, x_skip), dim=1)

        return self.mlp(features), pos_skip, batch_skip
コード例 #10
0
    def forward(self, x, x_sub, pos, pos_sub, batch=None, batch_sub=None):
        """Add interpolated low-resolution features to transformed high-resolution ones.

        Args:
            x, pos, batch: high-resolution features, positions and batch vector.
            x_sub, pos_sub, batch_sub: low-resolution counterparts.

        Returns:
            ``self.mlp(x)`` plus the low-resolution features (after
            ``self.mlp_sub``) interpolated onto ``pos`` with k=3.
        """
        # Transform the low-res features first, then carry them to the
        # high-res points.
        low_res = self.mlp_sub(x_sub)
        upsampled = knn_interpolate(
            low_res, pos_sub, pos, batch_x=batch_sub, batch_y=batch, k=3)

        return self.mlp(x) + upsampled
コード例 #11
0
    def forward(self, data, backsampling):
        """Move ``data`` onto the graph described by ``backsampling``.

        When ``self.conv_layer`` is set, ``data.x`` is first replaced by
        ``F.elu(self.conva(...))``. The features are then interpolated onto
        ``backsampling.pos`` and ``data`` takes over the positions, edges,
        edge attributes and batch vector of ``backsampling``.

        Args:
            data: graph object to upsample; modified in place.
            backsampling: graph object describing the target resolution.

        Returns:
            The same ``data`` object, updated.
        """
        if self.conv_layer:
            activated = self.conva(data.x, data.edge_index, data.edge_attr)
            data.x = F.elu(activated)

        data.x = knn_interpolate(
            data.x, data.pos, backsampling.pos,
            data.batch, backsampling.batch, k=self.k)

        # Adopt the target graph's structure, in the original order.
        for field in ("pos", "edge_index", "edge_attr", "batch"):
            setattr(data, field, getattr(backsampling, field))
        return data
コード例 #12
0
    def forward(self, x_hr, pos_hr, batch_hr, x_lr, pos_lr, batch_lr):
        """Append nearest low-resolution features to every high-resolution point.

        Positions are concatenated to the features on both resolutions. Each
        high-resolution point then receives the low-resolution feature row
        selected by ``knn_interpolate`` with k=1.

        Args:
            x_hr: high-resolution features, or ``None`` to use positions only.
            pos_hr, batch_hr: high-resolution positions and batch vector.
            x_lr, pos_lr, batch_lr: low-resolution features, positions and
                batch vector.

        Returns:
            Tuple of the merged features, a zero tensor with one 3-column row
            per output row (created via ``pos_hr.new_zeros``), and ``batch_hr``.
        """
        if x_hr is None:
            hr_features = pos_hr
        else:
            hr_features = torch.cat([x_hr, pos_hr], dim=1)

        lr_features = torch.cat([x_lr, pos_lr], dim=1)

        # With k=1 every high-res point takes the features of its single
        # nearest low-res neighbour.
        lr_on_hr = knn_interpolate(
            lr_features, pos_lr, pos_hr, batch_lr, batch_hr, k=1)
        merged = torch.cat([hr_features, lr_on_hr], dim=1)

        zero_pos = pos_hr.new_zeros((merged.size(0), 3))
        return merged, zero_pos, batch_hr
コード例 #13
0
    def __call__(self, query, support, precomputed: Data = None):
        """Interpolate the features of ``query`` onto the positions of ``support``.

        Args:
            - query: data structure that holds the low res data (position + features)
            - support: data structure that holds the position to which we will interpolate
            - precomputed: optional data structure holding ``x_idx``, ``y_idx``,
              ``weights`` and ``normalisation``; when given, these are used
              instead of running a neighbour search.
        Returns:
            - torch.tensor: interpolated features
        Raises:
            - ValueError: if ``precomputed.num_nodes`` differs from the number
              of support positions.
        """
        if precomputed:
            num_points = support.pos.size(0)
            if num_points != precomputed.num_nodes:
                raise ValueError("Precomputed indices do not match with the data given to the transform")

            weighted = query.x[precomputed.x_idx] * precomputed.weights
            y = scatter_add(weighted, precomputed.y_idx, dim=0, dim_size=num_points)
            return y / precomputed.normalisation

        x, pos = query.x, query.pos
        pos_support = support.pos

        # Fall back to a single-sample batch vector when none is available.
        # The fallback is created on the same device as the positions (it used
        # to be always on the CPU, which mismatches GPU inputs), and a ``batch``
        # attribute that exists but is None is treated like a missing one so
        # both batch vectors are always provided together.
        batch_support = getattr(support, "batch", None)
        if batch_support is None:
            batch_support = torch.zeros(
                (support.num_nodes,), dtype=torch.long, device=pos_support.device)
        batch = getattr(query, "batch", None)
        if batch is None:
            batch = torch.zeros(
                (query.num_nodes,), dtype=torch.long, device=pos.device)

        return knn_interpolate(x, pos, pos_support, batch, batch_support, k=self.k)
コード例 #14
0
 def conv(self, x, pos, pos_skip, batch, batch_skip, *args):
     """Interpolate ``x`` from ``pos`` onto ``pos_skip`` using ``self.k`` neighbours.

     Extra positional arguments are accepted and ignored.
     """
     interpolated = knn_interpolate(
         x, pos, pos_skip, batch, batch_skip, k=self.k)
     return interpolated
コード例 #15
0
    def forward(self, x):
        """Compute per-point outputs with a three-level encoder/decoder.

        The input is aligned with a learned 3x3 transform, processed by
        ``edge1``..``edge3`` on progressively subsampled point sets (farthest
        point sampling via ``fps``), brought back to the denser levels with
        ``knn_interpolate`` and concatenated with the features saved at those
        levels (``edge4``..``edge6``), then passed through ``conv4``..``conv6``.

        Args:
            x: point coordinates; the inline comments state the layout
                (batch_size, 3, num_points).

        Returns:
            Output of ``self.conv6``; per the inline comment,
            (batch_size, seg_num_all, num_points).
        """
        batch_size = x.size(0)
        num_points = x.size(2)  # NOTE(review): not used later in this method

        x0 = get_graph_feature(
            x, x, k=self.k, dilation=False, dilation_type=self.dilation
        )  # (batch_size, 3, num_points) -> (batch_size, 3*2, num_points, k)
        t = self.transform_net(x0)  # (batch_size, 3, 3)
        x = x.transpose(
            2, 1)  # (batch_size, 3, num_points) -> (batch_size, num_points, 3)
        x = torch.bmm(
            x, t
        )  # (batch_size, num_points, 3) * (batch_size, 3, 3) -> (batch_size, num_points, 3)
        x = x.transpose(
            2, 1)  # (batch_size, num_points, 3) -> (batch_size, 3, num_points)

        # ---- Level 1: full resolution ----
        pos = x
        pos_1 = pos
        x1 = self.edge1(x, pos_1)
        x1_upsample = x1  # kept for the skip connection before edge5
        x1_cpy = x1  # NOTE(review): not used later in this method

        # Batch vector for the flattened level-1 points.
        bs = torch.arange(x1.size(0)).to(self.device)
        bs = bs.repeat_interleave(x1.size(2))
        # NOTE(review): reshape(3, -1) on a (batch_size, 3, num_points) tensor
        # reinterprets the element order instead of moving the batch axis -
        # confirm this is intended for batch_size > 1.
        pos_upsample = pos.reshape(3, -1)
        pos = pos.permute(2, 1, 0).contiguous()
        pos = pos.view(pos.size(0) * pos.size(2), -1)  #(n_pts*bs),channel
        x1 = x1.permute(2, 1, 0).contiguous()  #n_pts,channel,bs
        x1 = x1.view(x1.size(0) * x1.size(2), -1)  #(n_pts*bs),channel

        # ---- Level 2: subsample with farthest point sampling ----
        idx = fps(pos, batch=bs, ratio=0.375)  #2048->768
        pos2 = pos[idx]
        pos2 = pos2.view(-1, pos2.size(1), batch_size)
        pos2 = pos2.permute(2, 1, 0)  #batchsize,channel,pts
        pos2_1 = pos2
        x2 = x1[idx]  #pts*bs,channel
        x2 = x2.view(-1, x1.size(1), batch_size)  #pts,channel,bs
        x2 = x2.permute(2, 1, 0)  #.view(x2.size(2),x2.size(1),x2.size(0))
        #########################
        x2 = self.edge2(x2, pos2_1)
        x2_upsample = x2  # kept for the skip connection before edge4

        # Batch vector for the flattened level-2 points.
        bs2 = torch.arange(x2.size(0)).to(self.device)
        bs2 = bs2.repeat_interleave(x2.size(2))
        pos2_upsample = pos2.reshape(3, -1)
        pos2 = pos2.permute(2, 1, 0).contiguous()
        pos2 = pos2.view(pos2.size(0) * pos2.size(2), -1)  #(n_pts*bs),channel
        x2 = x2.permute(
            2, 1, 0).contiguous()  #view(x1.size(2),x1.size(1),x1.size(0))
        x2 = x2.view(x2.size(0) * x2.size(2), -1)  #(n_pts*bs),channel

        # ---- Level 3: subsample again ----
        idx = fps(pos2, batch=bs2, ratio=1 / 3)  #768 -> 384
        pos3 = pos2[idx]
        pos3 = pos3.view(-1, pos3.size(1), batch_size)
        pos3 = pos3.permute(2, 1, 0)  #batchsize,channel,pts
        pos3_1 = pos3
        pos3_upsample = pos3.reshape(3, -1)
        x3 = x2[idx]
        x3 = x3.view(-1, x3.size(1), batch_size)  #pts,channel,bs
        x3 = x3.permute(2, 1, 0)  #.view(x3.size(2),x3.size(1),x3.size(0))
        ##########################
        #pos3 = pos2[:,idx]
        x3 = self.edge3(x3, pos3_1)  #bottleneck
        x3_upsample = x3  #bs,channel,n_pts  (NOTE(review): not used later in this method)
        bs3 = torch.arange(x3.size(0)).to(self.device)
        bs3 = bs3.repeat_interleave(x3.size(2))
        # From here on pos/pos2/pos3 refer to the (3, -1) reshaped copies.
        pos3 = pos3_upsample
        pos2 = pos2_upsample
        pos = pos_upsample

        # ---- Decoder: level 3 -> level 2 ----
        # NOTE(review): reshape(size(1), size(0)) on a (3, M) tensor is not a
        # transpose; if (M, 3) coordinate rows are expected here, confirm the
        # resulting values are the intended point positions.
        x3_2 = x3.reshape(x3.size(0) * x3.size(2), -1)
        x4 = knn_interpolate(x3_2,
                             pos3.reshape(pos3.size(1), pos3.size(0)),
                             pos2.reshape(pos2.size(1), pos2.size(0)),
                             batch_x=bs3,
                             batch_y=bs2)
        x4 = x4.reshape(batch_size, x3_2.size(1), -1)
        x4 = torch.cat((x4, x2_upsample), 1)
        x4 = self.edge4(x4, pos2_1)

        # ---- Decoder: level 2 -> level 1 ----
        x4_2 = x4.reshape(x4.size(0) * x4.size(2), -1)
        x5 = knn_interpolate(x4_2,
                             pos2.reshape(pos2.size(1), pos2.size(0)),
                             pos.reshape(pos.size(1), pos.size(0)),
                             batch_x=bs2,
                             batch_y=bs)
        x5 = x5.reshape(batch_size, x4_2.size(1), -1)
        x5 = torch.cat((x5, x1_upsample), 1)
        x5 = self.edge5(x5, pos_1)

        x5 = self.edge6(x5, pos_1)

        residual = x5  # NOTE(review): not used later in this method

        # ---- Output head ----
        x = self.conv4(
            x5
        )  # (batch_size, 256, num_points) -> (batch_size, 256, num_points)

        x = self.conv5(
            x
        )  # (batch_size, 256, num_points) -> (batch_size, 128, num_points)

        x = self.dp3(x)
        x = self.conv6(
            x
        )  # (batch_size, 256, num_points) -> (batch_size, seg_num_all, num_points)

        return x
コード例 #16
0
 def knn_interpolate(self, pos, k):
     """Return a ``Prediction`` of this object's features interpolated at ``pos``.

     The module-level ``knn_interpolate`` function is applied to ``self.x``
     located at ``self.pos``, using ``k`` neighbours.
     """
     interpolated = knn_interpolate(self.x, self.pos, pos, k=k)
     return Prediction(pos, interpolated)