Example #1
    def forward(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N]
        Return:
            new_xyz: sampled points position data, [B, C, S]
            new_points_concat: sampled points feature data, [B, D', S]
        """
        device = xyz.device
        B, C, N = xyz.shape
        xyz_t = xyz.permute(0, 2, 1).contiguous()  # [B, N, C]

        fps_idx = pointutils.furthest_point_sample(xyz_t,
                                                   self.npoint)  # [B, npoint]
        new_xyz = pointutils.gather_operation(xyz, fps_idx)  # [B, 3, npoint]
        new_xyz_t = new_xyz.permute(0, 2, 1).contiguous()

        _, idx = pointutils.knn(self.nsample, new_xyz_t,
                                xyz_t)  # [B, npoint, nsample]
        neighbors = pointutils.grouping_operation(
            xyz, idx)  # [B, 3, npoint, nsample]
        centers = new_xyz.view(B, -1, self.npoint, 1).repeat(
            1, 1, 1, self.nsample)  # [B, 3, npoint, nsample]
        pos_diff = centers - neighbors  # [B, 3, npoint, nsample]
        distances = torch.norm(pos_diff, p=2, dim=1,
                               keepdim=True)  # [B, 1, npoint, nsample]
        h_xi_xj = torch.cat([distances, pos_diff, centers, neighbors],
                            dim=1)  # [B, 1+3+3+3, npoint, nsample]

        x = pointutils.grouping_operation(points,
                                          idx)  # [B, D, npoint, nsample]
        x = torch.cat([neighbors, x], dim=1)  # [B, D+3, npoint, nsample]

        h_xi_xj = self.mapping_func2(
            F.relu(self.bn_mapping(
                self.mapping_func1(h_xi_xj))))  # [B, c_in, npoint, nsample]
        if self.first_layer:
            x = F.relu(self.bn_xyz_raising(
                self.xyz_raising(x)))  # [B, c_in, npoint, nsample]
        x = F.relu(self.bn_rsconv(torch.mul(h_xi_xj,
                                            x)))  # (B, c_in, npoint, nsample)

        for i, conv in enumerate(self.mlp_convs):
            bn = self.mlp_bns[i]
            x = F.relu(bn(conv(x)))  # [B, c_out, npoint, nsample]

        x = torch.max(x, -1)[0]  # [B, c_out, npoint]
        # x = F.relu(self.bn_channel_raising(self.cr_mapping(x)))   # [B, c_out, npoint]

        return new_xyz, x
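The two compiled ops doing the heavy lifting here, pointutils.knn and pointutils.grouping_operation, can be approximated in plain PyTorch for shape-checking on CPU. The sketch below is only a reference under that assumption; knn_ref and grouping_ref are illustrative names, not part of pointutils.

import torch

def knn_ref(nsample, query_xyz, xyz):
    """query_xyz: [B, S, 3], xyz: [B, N, 3] -> (dists, idx), each [B, S, nsample]."""
    d = torch.cdist(query_xyz, xyz)                      # pairwise distances [B, S, N]
    dists, idx = d.topk(nsample, dim=-1, largest=False)  # nsample closest reference points
    return dists, idx

def grouping_ref(points, idx):
    """points: [B, C, N], idx: [B, S, K] (int64) -> grouped features [B, C, S, K]."""
    B, C, N = points.shape
    _, S, K = idx.shape
    idx_exp = idx.unsqueeze(1).expand(B, C, S, K).reshape(B, C, S * K)  # broadcast over channels
    return torch.gather(points, 2, idx_exp).view(B, C, S, K)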
Example #2
    def forward(self, pos1, pos2, feature1, feature2):
        """
        Input:
            xyz1: input points position data, [B, C, N]
            xyz2: sampled input points position data, [B, C, S]
            points1: input points data, [B, D, N]
            points2: input points data, [B, D, S]
        Return:
            new_points: upsampled points data, [B, D', N]
        """
        pos1_t = pos1.permute(0, 2, 1).contiguous()
        pos2_t = pos2.permute(0, 2, 1).contiguous()
        B, C, N = pos1.shape

        # dists = square_distance(pos1, pos2)
        # dists, idx = dists.sort(dim=-1)
        # dists, idx = dists[:, :, :3], idx[:, :, :3]  # [B, N, 3]
        dists, idx = pointutils.three_nn(pos1_t, pos2_t)  # [B, N, K=3]
        dists[dists < 1e-10] = 1e-10
        weight = 1.0 / dists
        weight = weight / torch.sum(weight, -1, keepdim=True)
        interpolated_feat = torch.sum(
            pointutils.grouping_operation(feature2, idx) *
            weight.view(B, 1, N, 3),
            dim=-1)  # [B, C, N, S=3] -> [B, C, N]

        if feature1 is not None:
            feat_new = torch.cat([interpolated_feat, feature1], 1)
        else:
            feat_new = interpolated_feat

        for i, conv in enumerate(self.mlp_convs):
            bn = self.mlp_bns[i]
            feat_new = F.relu(bn(conv(feat_new)))
        return feat_new
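Without the compiled three_nn / grouping_operation kernels, the same three-neighbor inverse-distance interpolation can be written directly in PyTorch. The helper below is a sketch under that assumption (three_interpolate_ref is my own name, not a drop-in for pointutils); it is dense in memory but fine for small clouds.

import torch

def three_interpolate_ref(pos1_t, pos2_t, feature2):
    """pos1_t: [B, N, 3], pos2_t: [B, S, 3], feature2: [B, C, S] -> [B, C, N]."""
    d = torch.cdist(pos1_t, pos2_t)                      # [B, N, S]
    dists, idx = d.topk(3, dim=-1, largest=False)        # three nearest sampled points
    dists = dists.clamp(min=1e-10)
    weight = 1.0 / dists
    weight = weight / weight.sum(-1, keepdim=True)       # [B, N, 3]
    B, C, S = feature2.shape
    N = pos1_t.shape[1]
    grouped = torch.gather(feature2.unsqueeze(2).expand(B, C, N, S), 3,
                           idx.unsqueeze(1).expand(B, C, N, 3))  # [B, C, N, 3]
    return (grouped * weight.view(B, 1, N, 3)).sum(-1)   # inverse-distance weighted sum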
Example #3
    def forward(self, pos1, pos2, feature1, feature2):
        """
            Feature propagation from xyz2 (less points) to xyz1 (more points)
        Inputs:
            xyz1: (batch_size, 3, npoint1)
            xyz2: (batch_size, 3, npoint2)
            feat1: (batch_size, channel1, npoint1) features for xyz1 points (earlier layers, more points)
            feat2: (batch_size, channel1, npoint2) features for xyz2 points
        Output:
            feat1_new: (batch_size, npoint2, mlp[-1] or mlp2[-1] or channel1+3)
            TODO: Add support for skip links. Study how delta(XYZ) plays a role in feature updating.
        """
        pos1_t = pos1.permute(0, 2, 1).contiguous()
        pos2_t = pos2.permute(0, 2, 1).contiguous()
        B, C, N = pos1.shape
        if self.knn:
            _, idx = pointutils.knn(self.nsample, pos1_t, pos2_t)  # [B, N1, S]
        else:
            idx = pointutils.ball_query(self.radius, self.nsample, pos2_t,
                                        pos1_t)

        pos2_grouped = pointutils.grouping_operation(pos2, idx)
        pos_diff = pos2_grouped - pos1.view(B, -1, N, 1)  # [B, 3, N1, S]

        feat2_grouped = pointutils.grouping_operation(feature2, idx)
        feat_new = torch.cat([feat2_grouped, pos_diff],
                             dim=1)  # [B, C1+3, N1, S]

        for conv in self.mlp1_convs:
            feat_new = conv(feat_new)

        # max pooling
        feat_new = feat_new.max(-1)[0]  # [B, mlp1[-1], N1]

        # concatenate feature in early layer
        if feature1 is not None:
            feat_new = torch.cat([feat_new, feature1],
                                 dim=1)  # [B, mlp1[-1]+feat1_channel, N1]

        for conv in self.mlp2_convs:
            feat_new = conv(feat_new)

        return feat_new
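pointutils.ball_query is likewise a compiled kernel. A rough CPU reference for what it returns is sketched below; ball_query_ref is an illustrative name, and padding the unused slots with the closest index is my assumption about how out-of-radius entries are filled.

import torch

def ball_query_ref(radius, nsample, ref_xyz, query_xyz):
    """ref_xyz: [B, M, 3], query_xyz: [B, N, 3] -> idx [B, N, nsample] into ref_xyz."""
    d = torch.cdist(query_xyz, ref_xyz)            # [B, N, M]
    idx = d.argsort(dim=-1)[:, :, :nsample]        # nearest-first candidates
    in_ball = torch.gather(d, -1, idx) <= radius   # which candidates fall inside the ball
    first = idx[:, :, :1].expand_as(idx)           # pad slots outside the radius
    return torch.where(in_ball, idx, first)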
Example #4
    def forward(self, pos1, pos2, feature1, feature2):
        """
        Input:
            xyz1: (batch_size, 3, npoint)
            xyz2: (batch_size, 3, npoint)
            feat1: (batch_size, channel, npoint)
            feat2: (batch_size, channel, npoint)
        Output:
            xyz1: (batch_size, 3, npoint)
            feat1_new: (batch_size, mlp[-1], npoint)
        """
        pos1_t = pos1.permute(0, 2, 1).contiguous()
        pos2_t = pos2.permute(0, 2, 1).contiguous()
        B, N, C = pos1_t.shape
        if self.knn:
            _, idx = pointutils.knn(self.nsample, pos1_t, pos2_t)  # [B, N, S]
        else:
            idx = pointutils.ball_query(self.radius, self.nsample, pos2_t,
                                        pos1_t)

        pos2_grouped = pointutils.grouping_operation(pos2, idx)  # [B, 3, N, S]
        pos_diff = pos2_grouped - pos1.view(B, -1, N, 1)  # [B, 3, N, S]

        feat2_grouped = pointutils.grouping_operation(feature2,
                                                      idx)  # [B, C, N, S]
        if self.corr_func == 'concat':
            feat_diff = torch.cat([
                feat2_grouped,
                feature1.view(B, -1, N, 1).repeat(1, 1, 1, self.nsample)
            ],
                                  dim=1)  # [B, 2*C, N, S]

        feat1_new = torch.cat([pos_diff, feat_diff], dim=1)  # [B, 2*C+3, N, S]
        for i, conv in enumerate(self.mlp_convs):
            bn = self.mlp_bns[i]
            feat1_new = F.relu(bn(conv(feat1_new)))

        feat1_new = torch.max(feat1_new, -1)[0]  # [B, mlp[-1], npoint]
        return pos1, feat1_new
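A quick shape walk-through of the 'concat' correlation above, with illustrative sizes (B=2, C=64 channels, N=1024 points, S=16 neighbors), shows where the 2*C+3 channels come from:

import torch

B, C, N, S = 2, 64, 1024, 16
feat2_grouped = torch.randn(B, C, N, S)       # grouped features from the second cloud
feature1 = torch.randn(B, C, N)               # per-point features of the first cloud
pos_diff = torch.randn(B, 3, N, S)            # relative coordinates of the neighbors
feat_diff = torch.cat(
    [feat2_grouped, feature1.view(B, C, N, 1).repeat(1, 1, 1, S)], dim=1)  # [B, 2*C, N, S]
feat1_new = torch.cat([pos_diff, feat_diff], dim=1)
print(feat1_new.shape)                        # torch.Size([2, 131, 1024, 16]) = 2*C + 3 channels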
Example #5
def group(p: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Group point cloud indices.

    Args:
        p: Reference point cloud of shape [batch_size, dim, num_point].
        idx: Indices tensor of shape [batch_size, num_query, k].

    Returns:
        A tensor of shape [batch_size, dim, num_query, k].
    """

    p = p.contiguous()
    return _PU.grouping_operation(p, idx)
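A hypothetical call (tensor sizes are illustrative): grouping the 16 nearest neighbors of 512 query points out of a 2048-point cloud. The indices would normally come from a kNN or ball-query op, and the compiled kernel typically expects an int32 index tensor on the same CUDA device as p (as the .int() cast in a later example suggests).

import torch

p = torch.randn(8, 3, 2048, device='cuda')    # [batch_size, dim, num_point]
idx = torch.randint(0, 2048, (8, 512, 16),    # placeholder indices; use a real kNN in practice
                    device='cuda', dtype=torch.int32)
neighbors = group(p, idx)                     # [8, 3, 512, 16]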
Example #6
    def forward(self, point_cloud):
        """
        point_cloud: [batch_size, n_dims, n_points]
        idx:         [batch_size, k, n_points] neighbor indices (self-match removed)
        point_cloud_neighbors: [batch_size, n_dims, k, n_points]
        """
        dist, idx = self.KNN(point_cloud, point_cloud)
        idx = idx[:, 1:, :]  # drop each point's match with itself
        point_cloud_neighbors = grouping_operation(point_cloud, idx.contiguous().int())
        point_cloud_central = point_cloud.unsqueeze(2).repeat(1, 1, self.k, 1)
        # edge feature: center point concatenated with the offsets to its neighbors
        edge_feature = torch.cat(
            [point_cloud_central, point_cloud_neighbors - point_cloud_central], dim=1)

        return edge_feature, idx
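The same edge features can be built without the CUDA KNN for a CPU sanity check. The sketch below (edge_feature_ref is my own name) gathers neighbors with plain indexing and reproduces the [B, 2*C, k, N] layout of the snippet; note that its neighbor indices are laid out as [B, N, k] rather than the [B, k, N] returned above.

import torch

def edge_feature_ref(point_cloud, k):
    """point_cloud: [B, C, N] -> edge features [B, 2*C, k, N]."""
    B, C, N = point_cloud.shape
    pc_t = point_cloud.transpose(1, 2)                               # [B, N, C]
    d = torch.cdist(pc_t, pc_t)                                      # pairwise distances [B, N, N]
    idx = d.topk(k + 1, dim=-1, largest=False).indices[:, :, 1:]     # drop self-match -> [B, N, k]
    neighbors = pc_t[torch.arange(B).view(B, 1, 1), idx]             # [B, N, k, C]
    neighbors = neighbors.permute(0, 3, 2, 1)                        # [B, C, k, N]
    central = point_cloud.unsqueeze(2).expand(B, C, k, N)            # center repeated k times
    return torch.cat([central, neighbors - central], dim=1)          # [B, 2*C, k, N]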
Example #7
    def get_uniform_loss(self,
                         pcd,
                         percentage=[0.004, 0.006, 0.008, 0.010, 0.012],
                         radius=1.0):
        B, N, C = pcd.shape
        npoint = int(N * 0.05)
        loss = 0
        further_point_idx = pn2_utils.furthest_point_sample(
            pcd.permute(0, 2, 1).contiguous(), npoint)
        new_xyz = pn2_utils.gather_operation(
            pcd.permute(0, 2, 1).contiguous(), further_point_idx)  # [B, C, npoint]
        for p in percentage:
            nsample = int(N * p)
            r = math.sqrt(p * radius)
            disk_area = math.pi * (radius**2) / N

            idx = pn2_utils.ball_query(r, nsample, pcd.contiguous(),
                                       new_xyz.permute(
                                           0, 2, 1).contiguous())  # [B, npoint, nsample]

            expect_len = math.sqrt(disk_area)

            grouped_pcd = pn2_utils.grouping_operation(
                pcd.permute(0, 2, 1).contiguous(), idx)  # [B, C, npoint, nsample]
            grouped_pcd = grouped_pcd.permute(0, 2, 3, 1)  # [B, npoint, nsample, C]

            grouped_pcd = torch.cat(torch.unbind(grouped_pcd, dim=1),
                                    dim=0)  # [B*npoint, nsample, C]

            dist, _ = self.knn_uniform(grouped_pcd, grouped_pcd)
            uniform_dist = dist[:, :, 1:]  # [B*npoint, nsample, 1], self-distance removed
            uniform_dist = torch.abs(uniform_dist + 1e-8)
            uniform_dist = torch.mean(uniform_dist, dim=1)
            uniform_dist = (uniform_dist - expect_len)**2 / (expect_len + 1e-8)
            mean_loss = torch.mean(uniform_dist)
            mean_loss = mean_loss * math.pow(p * 100, 2)
            loss += mean_loss
        return loss / len(percentage)
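Read on its own, the uniformity penalty is easier to follow without the CUDA calls. The standalone helper below (uniformity_term is an illustrative name) restates the per-ball term, assuming the points inside each of the M query balls are already stacked into a [M, nsample, 3] tensor, with n_total mirroring N in the snippet above.

import math
import torch

def uniformity_term(grouped_pcd, p, radius=1.0, n_total=1024):
    """grouped_pcd: [M, nsample, 3] points inside M balls of radius sqrt(p * radius)."""
    disk_area = math.pi * (radius ** 2) / n_total                # expected per-point disk area
    expect_len = math.sqrt(disk_area)                            # expected point spacing
    d = torch.cdist(grouped_pcd, grouped_pcd)                    # [M, nsample, nsample]
    nn_dist = d.topk(2, dim=-1, largest=False).values[..., 1]    # nearest non-self distance
    dist_mean = nn_dist.mean(dim=1)                              # mean spacing per ball
    term = (dist_mean - expect_len) ** 2 / (expect_len + 1e-8)   # deviation from uniform spacing
    return term.mean() * math.pow(p * 100, 2)                    # same percentage re-weighting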