def forward(
            self, xyz: torch.Tensor, features: Optional[torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Parameters
        ----------
        xyz : torch.Tensor
            (B, N, 3) tensor of the xyz coordinates of the features
        features : torch.Tensor or None
            (B, C, N) tensor of the descriptors of the features

        Returns
        -------
        new_xyz : torch.Tensor
            (B, npoint, 3) tensor of the new features' xyz
        new_features : torch.Tensor
            (B, \sum_k(mlps[k][-1]), npoint) tensor of the new features' descriptors
        """

        new_features_list = []

        xyz_flipped = xyz.transpose(1, 2).contiguous()
        new_xyz = (pointnet2_utils.gather_operation(
            xyz_flipped, pointnet2_utils.furthest_point_sample(
                xyz, self.npoint)).transpose(1, 2).contiguous()
                   if self.npoint is not None else None)

        if self.fuse == 'add':
            for i in range(len(self.groupers)):
                new_features = self.groupers[i](
                    xyz, new_xyz, features)  # (B, C, npoint, nsample)

                new_features = self.mlps[i](
                    new_features)  # (B, mlp[-1], npoint, nsample)
                new_features = F.max_pool2d(new_features,
                                            kernel_size=[
                                                1, new_features.size(3)
                                            ])  # (B, mlp[-1], npoint, 1)
                new_features = new_features.squeeze(-1)  # (B, mlp[-1], npoint)
                if i == 0:
                    new_features_sum = new_features
                else:
                    new_features_sum += new_features
            return new_xyz, new_features_sum

        # any other fuse mode (e.g. 'concat'): collect per-scale features and
        # concatenate them along the channel dimension
        for i in range(len(self.groupers)):
            new_features = self.groupers[i](
                xyz, new_xyz, features)  # (B, C, npoint, nsample)

            new_features = self.mlps[i](
                new_features)  # (B, mlp[-1], npoint, nsample)
            new_features = F.max_pool2d(new_features,
                                        kernel_size=[
                                            1, new_features.size(3)
                                        ])  # (B, mlp[-1], npoint, 1)
            new_features = new_features.squeeze(-1)  # (B, mlp[-1], npoint)

            new_features_list.append(new_features)

        return new_xyz, torch.cat(new_features_list, dim=1)
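The reduction in both branches above is just a channel-wise max over the nsample axis; a minimal standalone sketch of that step, with shapes taken from the comments in the snippet:

import torch
import torch.nn.functional as F

# (B, C, npoint, nsample) grouped features -> (B, C, npoint) pooled features
B, C, npoint, nsample = 2, 64, 128, 16
grouped = torch.randn(B, C, npoint, nsample)
pooled = F.max_pool2d(grouped, kernel_size=[1, grouped.size(3)])  # (B, C, npoint, 1)
pooled = pooled.squeeze(-1)                                       # (B, C, npoint)
# equivalent to taking the max over the last axis directly
assert torch.equal(pooled, grouped.max(dim=-1).values)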
Example #2
def pool(xyz, points, k, npoint):
    # xyz: (B, N, 3) coordinates, points: (B, N, C) per-point features.
    xyz_flipped = xyz.transpose(1, 2).contiguous()  # (B, 3, N) for gather_operation
    # furthest_point_sample expects (B, N, 3), so it takes xyz, not xyz_flipped
    new_xyz = pointnet2_utils.gather_operation(
        xyz_flipped, pointnet2_utils.furthest_point_sample(
            xyz, npoint)).transpose(1, 2).contiguous()  # (B, npoint, 3)
    _, idx = knn_point(k, xyz, new_xyz)
    new_points = torch.max(pointnet2_utils.grouping_operation(
        points.permute(0, 2, 1).contiguous(),
        idx.int().permute(0, 2, 1).contiguous()).permute(0, 3, 2, 1),
                           dim=2).values  # (B, npoint, C): max over the k neighbors

    return new_xyz, new_points
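pool depends on a knn_point helper that is not shown here; a plausible pure-PyTorch stand-in (hypothetical, not the original implementation) that matches the (dist, idx) return convention used above:

import torch

def knn_point(k, xyz, new_xyz):
    # xyz: (B, N, 3) reference points, new_xyz: (B, M, 3) query points.
    # Returns (dist, idx) for the k nearest reference points per query,
    # each of shape (B, M, k).
    dist = torch.cdist(new_xyz, xyz)                 # (B, M, N) pairwise distances
    dist, idx = dist.topk(k, dim=-1, largest=False)  # keep the k smallest
    return dist, idx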
Example #3
def fps_subsample(pcd, n_points=2048):
    """
    Args
        pcd: (b, 16384, 3)

    returns
        new_pcd: (b, n_points, 3)
    """
    new_pcd = gather_operation(
        pcd.permute(0, 2, 1).contiguous(),
        furthest_point_sample(pcd, n_points))
    new_pcd = new_pcd.permute(0, 2, 1).contiguous()
    return new_pcd
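A hypothetical call, assuming the compiled pointnet2 CUDA ops and a GPU are available:

import torch

pcd = torch.rand(4, 16384, 3).cuda()     # (b, 16384, 3) input cloud
sub = fps_subsample(pcd, n_points=2048)  # (b, 2048, 3) FPS-subsampled cloud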
Example #4
def forward(self, x, feature, num_pool):
    # x: (B, 3, N) coordinates, feature: (B, C, N) descriptors.
    x = x.contiguous()
    xyz_flipped = x.transpose(1, 2).contiguous()  # (B, N, 3)
    x_sub = (pointnet2_utils.gather_operation(
        x, pointnet2_utils.furthest_point_sample(xyz_flipped,
                                                 num_pool)).transpose(
                                                     1, 2).contiguous())
    # sub_index = self.knn(x_sub, xyz_flipped).int()
    sub_index = pointnet2_utils.ball_query(0.2 * self.nlayer, self.k,
                                           xyz_flipped, x_sub)
    x = pointnet2_utils.grouping_operation(x, sub_index)  # (B, 3, num_pool, k)
    x = torch.max(x, dim=-1)[0]  # (B, 3, num_pool)
    # x = soft_pool2d(x, [1, x.shape[-1]]).squeeze(-1)
    feature = pointnet2_utils.grouping_operation(feature, sub_index)
    feature = self.mlp(feature)
    feature = torch.max(feature, dim=-1)[0]
    # feature = soft_pool2d(feature, [1, feature.shape[-1]]).squeeze(-1)
    return x, feature
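ball_query above is a compiled CUDA op; a rough pure-PyTorch reference (hypothetical, and not bit-exact: the real op keeps the first nsample points found within the radius rather than the nearest ones) for checking shapes on CPU:

import torch

def ball_query_ref(radius, k, xyz, new_xyz):
    # xyz: (B, N, 3) all points, new_xyz: (B, M, 3) query centers -> (B, M, k) int32
    dist = torch.cdist(new_xyz, xyz)                   # (B, M, N)
    in_ball = dist <= radius
    dist = dist.masked_fill(~in_ball, float('inf'))
    idx = dist.topk(k, dim=-1, largest=False).indices  # k closest candidates
    valid = in_ball.gather(-1, idx)                    # which candidates are real hits
    first = idx[..., :1].expand_as(idx)                # pad empty slots with the first hit
    return torch.where(valid, idx, first).int()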
Example #5
def sample_and_group(xyz, points, npoint, nsample, radius, use_xyz=True):
    """
    Args:
        xyz: Tensor, (B, 3, N)
        points: Tensor, (B, f, N)
        npoint: int
        nsample: int
        radius: float
        use_xyz: boolean

    Returns:
        new_xyz: Tensor, (B, 3, npoint)
        new_points: Tensor, (B, 3 | f+3 | f, npoint, nsample)
        idx_local: Tensor, (B, npoint, nsample)
        grouped_xyz: Tensor, (B, 3, npoint, nsample)

    """
    xyz_flipped = xyz.permute(0, 2, 1).contiguous()  # (B, N, 3)
    new_xyz = gather_operation(xyz,
                               furthest_point_sample(xyz_flipped,
                                                     npoint))  # (B, 3, npoint)

    idx = ball_query(radius, nsample, xyz_flipped,
                     new_xyz.permute(0, 2,
                                     1).contiguous())  # (B, npoint, nsample)
    grouped_xyz = grouping_operation(xyz, idx)  # (B, 3, npoint, nsample)
    grouped_xyz -= new_xyz.unsqueeze(3).repeat(1, 1, 1, nsample)

    if points is not None:
        grouped_points = grouping_operation(points,
                                            idx)  # (B, f, npoint, nsample)
        if use_xyz:
            new_points = torch.cat([grouped_xyz, grouped_points], 1)
        else:
            new_points = grouped_points
    else:
        new_points = grouped_xyz

    return new_xyz, new_points, idx, grouped_xyz
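A hypothetical call, again assuming the compiled pointnet2 ops and a GPU:

import torch

xyz = torch.rand(2, 3, 1024).cuda()     # (B, 3, N)
feats = torch.rand(2, 32, 1024).cuda()  # (B, f, N)
new_xyz, new_points, idx, grouped_xyz = sample_and_group(
    xyz, feats, npoint=256, nsample=16, radius=0.2)
# new_points: (B, 3 + 32, 256, 16) because use_xyz=True prepends coordinates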
Example #6
            vis.add_geometry(pcd)

            ctr = vis.get_view_control()
            ctr.rotate(10.0, 0.0)

            vis.update_geometry(pcd)  # recent Open3D versions require the geometry argument
            vis.poll_events()
            vis.update_renderer()
            time.sleep(1)
            vis.capture_screen_image('{}b.png'.format(i))
            vis.remove_geometry(pcd)

            point = torch.from_numpy(np.asarray(pcd.points))  # Vector3dVector -> ndarray first
            # xyz_flipped, x and num_pool come from the enclosing scope of this excerpt
            x_sub = (
                pointnet2_utils.gather_operation(
                    xyz_flipped,
                    pointnet2_utils.furthest_point_sample(x, num_pool)
                ).transpose(1, 2).contiguous()
            )
            vis.update_geometry(pcd)
            vis.poll_events()
            vis.update_renderer()
            time.sleep(1)
            vis.capture_screen_image('{}b.png'.format(i))
            vis.remove_geometry(pcd)

            # pcd = o3d.geometry.PointCloud()
            # pcd.points = o3d.utility.Vector3dVector(data)
            # o3d.visualization.draw_geometries([pcd],
            #                         zoom=1,
            #                         front=[1, 1, -1],
            #                         lookat=[0, 0, 0],
Example #7
    def forward(self, xyz, rgb, istrain=False):
        hs = []
        #xyz_copy = xyz.clone()
        #rgb_copy = rgb.clone()
        batch_size, n_points, _ = xyz.shape
        part_length = n_points // self.npart
        last_point = -1
        last_feature_dim = -1
        #h = self.proj_in(rgb)
        h = rgb
        s2_count = 0
        for i in range(self.num_layers):
            h_input = h.clone()
            xyz_input = xyz.clone()
            batch_size, n_points, feature_dim = h.shape

            ######## Build Graph #########
            last_point = n_points

            ######### Dynamic Graph Conv #########
            xyz = xyz.transpose(1, 2).contiguous()
            #print(h.shape) # batchsize x point_number x feature_dim
            h = h.transpose(1, 2).contiguous()
            for j in range(self.num_conv):
                index = self.num_conv * i + j
                ####### BN + ReLU #####
                if self.pre_act:
                    if self.norm == 'ln':
                        h = h.transpose(1, 2).contiguous()
                        h = self.bn[index](h)
                        h = h.transpose(1, 2).contiguous()
                    else:
                        h = self.bn[index](h)
                    h = F.leaky_relu(h, 0.2)

                ####### Graph Feature ###########
                if self.k == 1 and j == 0:
                    h = h.unsqueeze(-1)
                else:
                    if i == self.num_layers - 1:
                        if self.cluster == 'xyz':
                            h = get_graph_feature(xyz, h, k=self.k)
                        elif self.cluster in ('xyzrgb', 'allxyzrgb'):
                            h = get_graph_feature(torch.cat((xyz, h), 1),
                                                  h,
                                                  k=self.k)
                    else:
                        # Common Layers
                        if self.cluster == 'allxyzrgb':
                            h = get_graph_feature(torch.cat((xyz, h), 1),
                                                  h,
                                                  k=self.k)
                        else:
                            h = get_graph_feature(xyz, h, k=self.k)

                ####### Conv ##########
                if self.light and i > 0:
                    #shuffle after the first layer
                    h = channel_shuffle(h, 2)
                    h = self.conv[index](h)
                else:
                    h = self.conv[index](h)
                h = h.max(dim=-1, keepdim=False)[0]
                ####### BN + ReLU #####
                if not self.pre_act:
                    if self.norm == 'ln':
                        h = h.transpose(1, 2).contiguous()
                        h = self.bn[index](h)
                        h = h.transpose(1, 2).contiguous()
                    else:
                        h = self.bn[index](h)
                    h = F.leaky_relu(h, 0.2)

            h = h.transpose(1, 2).contiguous()
            #print(h.shape) # batchsize x point_number x feature_dim
            batch_size, n_points, feature_dim = h.shape

            ######### Residual Before Downsampling#############
            if self.id_skip == 1:
                if istrain and self.drop_connect_rate > 0:
                    h = drop_connect(h,
                                     p=self.drop_connect_rate,
                                     training=istrain)
                if feature_dim != last_feature_dim:
                    h_input = self.conv_s2[s2_count](h_input)
                h = h_input + self.res_scale * h

            ######### PointNet++ MSG ########
            if feature_dim != last_feature_dim:
                h = h.transpose(1, 2).contiguous()
                xyz, h = self.sa[s2_count](xyz_input, h)
                h = h.transpose(1, 2).contiguous()
                if self.id_skip == 2:
                    h_input = pointnet2_utils.gather_operation(
                        h_input.transpose(1, 2).contiguous(),
                        pointnet2_utils.furthest_point_sample(
                            xyz_input, h.shape[1])).transpose(1,
                                                              2).contiguous()
            else:
                xyz = xyz.transpose(1, 2).contiguous()

            ######### Residual After Downsampling (Paper) #############
            if self.id_skip == 2:
                if istrain and self.drop_connect_rate > 0:
                    h = drop_connect(h,
                                     p=self.drop_connect_rate,
                                     training=istrain)
                if feature_dim != last_feature_dim:
                    h_input = self.conv_s2[s2_count](h_input)
                h = h_input + self.res_scale * h

            if feature_dim != last_feature_dim:
                s2_count += 1
                last_feature_dim = feature_dim

            #print(xyz.shape, h.shape)
        if self.npart == 1:
            # Pooling
            h_max, _ = torch.max(h, 1)
            h_avg = torch.mean(h, 1)
            hs.append(h_max)
            hs.append(h_avg)

            h = torch.cat(hs, 1)
            h = self.embs[0](h)
            h = self.bn_embs[0](h)
            h = self.dropouts[0](h)
            h = self.proj_output(h)
        else:
            # Sort
            #batch_size, n_points, _ = h.shape
            #y_index = torch.argsort(xyz[:, :, 1],dim = 1).view(batch_size * n_points, -1)
            #h = h.view(batch_size * n_points, -1)
            #h = h[y_index, :].view(batch_size, n_points, -1)
            h = h.transpose(1, 2)
            # Part Pooling
            h = self.partpool(h)
            for i in range(self.npart):
                part_h = h[:, :, i]
                part_h = self.embs[i](part_h)
                part_h = self.bn_embs[i](part_h)
                part_h = self.dropouts[i](part_h)
                part_h = self.proj_outputs[i](part_h)
                hs.append(part_h)
            h = hs
        return h
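drop_connect is used but not defined in this snippet; a common EfficientNet-style definition (hypothetical here) that zeroes whole samples on the residual branch during training:

import torch

def drop_connect(x, p, training):
    # Stochastic depth: with probability p, drop an entire sample from the
    # residual branch; rescale survivors so the expectation is unchanged.
    if not training or p == 0:
        return x
    keep = 1.0 - p
    # one Bernoulli draw per sample in the batch, broadcast over remaining dims
    mask = torch.rand(x.shape[0], *([1] * (x.dim() - 1)), device=x.device) < keep
    return x * mask.to(x.dtype) / keep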
Example #8
    def forward(self, xyz, rgb, istrain=False):
        hs = []
        #xyz_copy = xyz.clone()
        #rgb_copy = rgb.clone()
        batch_size, n_points, _ = xyz.shape
        part_length = n_points // self.npart
        last_point = -1
        #h = self.proj_in(rgb)
        h = rgb
        s2_count = 0
        for i in range(self.num_layers):
            h_input = h.clone()
            xyz_input = xyz.clone()
            batch_size, n_points, _ = h.shape
            if self.k > 1:
                if i == self.num_layers - 1:
                    if self.cluster == 'xyz':
                        g = self.nng(xyz, istrain=istrain and self.graph_jitter)
                    elif self.cluster == 'rgb':
                        g = self.nng(h, istrain=istrain and self.graph_jitter)
                    elif self.cluster == 'xyzrgb':
                        g = self.nng(torch.cat((xyz, h), 2), istrain=istrain and self.graph_jitter)
                elif i == 0 or n_points != last_point:
                    g = self.nng(xyz, istrain=istrain and self.graph_jitter)
            last_point = n_points
            h = h.view(batch_size * n_points, -1)

            if self.k == 1:
                h = self.conv[i](h)
            elif self.conv_type == 'GatedGCN':
                h = self.conv[i](g, h, g.edata['feat'],
                                 snorm_n=1 / g.number_of_nodes(),
                                 snorm_e=1 / g.number_of_edges())
            else:
                h = self.conv[i](g, h)
            h = F.leaky_relu(h, 0.2)
            h = h.view(batch_size, n_points, -1)
            h = h.transpose(1, 2).contiguous()
            xyz, h = self.sa[i](xyz_input, h)
            h = h.transpose(1, 2).contiguous()
            #h = self.conv_s1[i](h)
            if self.id_skip and h.shape[1] <= self.init_points // 4:
                # We could use identity mapping here or add drop connect
                if istrain and self.drop_connect_rate > 0:
                    h = drop_connect(h, p=self.drop_connect_rate, training=istrain)

                if h.shape[1] == n_points:
                    h = h_input + self.res_scale * h  # Here I borrow the idea from Inception-ResNet-v2
                elif h.shape[1] == n_points // 2:
                    h_input_s2 = pointnet2_utils.gather_operation(
                        h_input.transpose(1, 2).contiguous(),
                        pointnet2_utils.furthest_point_sample(xyz_input, h.shape[1])
                    ).transpose(1, 2).contiguous()
                    h = self.conv_s2[s2_count](h_input_s2) + self.res_scale * h
                    s2_count += 1
        if self.npart == 1:
            # Pooling
            h_max, _ = torch.max(h, 1)
            h_avg = torch.mean(h, 1)
            hs.append(h_max)
            hs.append(h_avg)

            h = torch.cat(hs, 1)
            h = self.embs[0](h)
            h = self.bn_embs[0](h)
            h = self.dropouts[0](h)
            h = self.proj_output(h)
        else:
            # Sort
            batch_size, n_points, _ = h.shape
            y_index = torch.argsort(xyz[:, :, 1], dim=1)  # (B, N) per-batch order
            # offset into the flattened (B*N, C) view so each batch indexes its own rows
            y_index = (y_index + n_points * torch.arange(
                batch_size, device=h.device).unsqueeze(1)).view(batch_size * n_points)
            h = h.view(batch_size * n_points, -1)
            h = h[y_index, :].view(batch_size, n_points, -1)
            h = h.transpose(1, 2)
            # Part Pooling
            h = self.partpool(h)
            for i in range(self.npart):
                part_h = h[:, :, i]
                part_h = self.embs[i](part_h)
                part_h = self.bn_embs[i](part_h)
                part_h = self.dropouts[i](part_h)
                part_h = self.proj_outputs[i](part_h)
                hs.append(part_h)
            h = hs
        return h
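The part branch above sorts points by their y coordinate before part pooling; a standalone sketch of that pattern, where torch.gather replaces the flattened-index trick and nn.AdaptiveMaxPool1d stands in for the undefined self.partpool:

import torch
import torch.nn as nn

B, N, C, npart = 2, 1024, 256, 4
xyz = torch.rand(B, N, 3)
h = torch.rand(B, N, C)
order = torch.argsort(xyz[:, :, 1], dim=1)                     # (B, N) per-batch order
h = torch.gather(h, 1, order.unsqueeze(-1).expand(-1, -1, C))  # sort points by height
h = nn.AdaptiveMaxPool1d(npart)(h.transpose(1, 2))             # (B, C, npart) part features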