Example #1
    def forward(self, pos1, pos2, feature1, feature2):
        """
        Input:
            pos1: input points position data, [B, C, N]
            pos2: sampled input points position data, [B, C, S]
            feature1: input points data, [B, D, N]
            feature2: sampled input points data, [B, D, S]
        Return:
            feat_new: upsampled points data, [B, D', N]
        """
        pos1_t = pos1.permute(0, 2, 1).contiguous()
        pos2_t = pos2.permute(0, 2, 1).contiguous()
        B, C, N = pos1.shape

        # dists = square_distance(pos1, pos2)
        # dists, idx = dists.sort(dim=-1)
        # dists, idx = dists[:, :, :3], idx[:, :, :3]  # [B, N, 3]
        dists, idx = pointutils.three_nn(pos1_t, pos2_t)  # [B, N, K=3]
        dists[dists < 1e-10] = 1e-10
        weight = 1.0 / dists
        weight = weight / torch.sum(weight, -1, keepdim=True)
        interpolated_feat = torch.sum(
            pointutils.grouping_operation(feature2, idx) *
            weight.view(B, 1, N, 3),
            dim=-1)  # [B, C, N, 3] -> [B, C, N]

        if feature1 is not None:
            feat_new = torch.cat([interpolated_feat, feature1], 1)
        else:
            feat_new = interpolated_feat

        for i, conv in enumerate(self.mlp_convs):
            bn = self.mlp_bns[i]
            feat_new = F.relu(bn(conv(feat_new)))
        return feat_new
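
For reference, the inverse-distance upsampling that pointutils.three_nn plus pointutils.grouping_operation perform above can be sketched in plain PyTorch. The function and variable names below are illustrative, and the CUDA kernels may differ in details (for example squared vs. Euclidean distances), so treat this as a sketch of the idea rather than a drop-in replacement:

import torch

def three_nn_interpolate(pos1, pos2, feature2, eps=1e-10):
    """Upsample feature2 (defined at pos2) onto pos1 by inverse-distance
    weighting over the 3 nearest neighbours.
    pos1: (B, N, 3), pos2: (B, S, 3), feature2: (B, C, S) -> (B, C, N)."""
    dists = torch.cdist(pos1, pos2)                     # (B, N, S) pairwise distances
    dists, idx = dists.topk(3, dim=-1, largest=False)   # three nearest neighbours
    weight = 1.0 / dists.clamp(min=eps)                 # inverse-distance weights
    weight = weight / weight.sum(dim=-1, keepdim=True)  # normalise to sum to 1
    # gather the neighbour features: (B, C, N, 3)
    idx = idx.unsqueeze(1).expand(-1, feature2.size(1), -1, -1)
    neighbours = torch.gather(
        feature2.unsqueeze(2).expand(-1, -1, pos1.size(1), -1), -1, idx)
    return (neighbours * weight.unsqueeze(1)).sum(dim=-1)  # (B, C, N)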
Example #2
    def forward(self, unknown, known, unknow_feats, known_feats):
        # type: (PointnetFPModule, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
        r"""
        Parameters
        ----------
        unknown : torch.Tensor
            (B, n, 3) tensor of the xyz positions of the unknown features
        known : torch.Tensor
            (B, m, 3) tensor of the xyz positions of the known features
        unknow_feats : torch.Tensor
            (B, C1, n) tensor of the features to be propagated to
        known_feats : torch.Tensor
            (B, C2, m) tensor of features to be propagated

        Returns
        -------
        new_features : torch.Tensor
            (B, mlp[-1], n) tensor of the features of the unknown features
        """

        if known is not None:
            dist, idx = pointnet2_utils.three_nn(unknown, known)
            dist_recip = 1.0 / (dist + 1e-8)
            norm = torch.sum(dist_recip, dim=2, keepdim=True)
            weight = dist_recip / norm

            interpolated_feats = pointnet2_utils.three_interpolate(
                known_feats, idx, weight)
        else:
            interpolated_feats = known_feats.expand(*known_feats.size()[0:2],
                                                    unknown.size(1))

        if unknow_feats is not None:
            new_features = torch.cat([interpolated_feats, unknow_feats],
                                     dim=1)  # (B, C2 + C1, n)
        else:
            new_features = interpolated_feats

        new_features = new_features.unsqueeze(-1)
        new_features = self.mlp(new_features)

        return new_features.squeeze(-1)
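
The unsqueeze(-1) / squeeze(-1) pair exists because self.mlp is typically a shared point-wise MLP built from 1x1 Conv2d layers, so the (B, C, n) feature map has to be lifted to (B, C, n, 1) first. A minimal sketch of such a head (the channel sizes here are assumptions, not taken from the original module):

import torch
import torch.nn as nn

# hypothetical shared MLP: 1x1 convolutions act as a per-point fully connected layer
mlp = nn.Sequential(
    nn.Conv2d(256 + 128, 256, kernel_size=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True),
    nn.Conv2d(256, 128, kernel_size=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
)

new_features = torch.rand(4, 256 + 128, 1024)      # (B, C2 + C1, n)
out = mlp(new_features.unsqueeze(-1)).squeeze(-1)  # (B, 128, n)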
Example #3
    def forward(self, unknown: torch.Tensor, known: torch.Tensor,
                unknow_feats: torch.Tensor,
                known_feats: torch.Tensor) -> torch.Tensor:
        r"""
        Parameters
        ----------
        unknown : torch.Tensor
            (B, n, 3) tensor of the xyz positions of the unknown features
        known : torch.Tensor
            (B, m, 3) tensor of the xyz positions of the known features
        unknow_feats : torch.Tensor
            (B, C1, n) tensor of the features to be propagated to
        known_feats : torch.Tensor
            (B, C2, m) tensor of features to be propagated

        Returns
        -------
        new_features : torch.Tensor
            (B, mlp[-1], n) tensor of the features of the unknown features
        """

        if known is not None:
            dist, idx = pointnet2_utils.three_nn(unknown, known)
            dist_recip = 1.0 / (dist + 1e-8)
            norm = torch.sum(dist_recip, dim=2, keepdim=True)
            weight = dist_recip / norm

            interpolated_feats = pointnet2_utils.three_interpolate(
                known_feats, idx, weight)
        else:
            interpolated_feats = known_feats.expand(*known_feats.size()[0:2],
                                                    unknown.size(1))

        if unknow_feats is not None:
            new_features = torch.cat([interpolated_feats, unknow_feats],
                                     dim=1)  #(B, C2 + C1, n)
        else:
            new_features = interpolated_feats

        new_features = new_features.unsqueeze(-1)
        new_features = self.mlp(new_features)

        return new_features.squeeze(-1)
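
One detail worth noting about the known is None branch: Tensor.expand can only broadcast dimensions of size 1, so this fallback assumes known_feats holds a single global feature vector per batch element, i.e. shape (B, C2, 1). A minimal sketch of that case (shapes are illustrative):

import torch

known_feats = torch.rand(4, 1024, 1)  # one global descriptor per batch element
unknown = torch.rand(4, 2048, 3)      # target point positions
interpolated = known_feats.expand(*known_feats.size()[0:2], unknown.size(1))
print(interpolated.shape)             # torch.Size([4, 1024, 2048])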
Example #4
    def forward(self, unknown, known, known_feats):
        # type: (PointnetFPModule, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
        r"""
        Parameters
        ----------
        unknown : torch.Tensor
            (B, n, 3) tensor of the xyz positions of the unknown features
        known : torch.Tensor
            (B, m, 3) tensor of the xyz positions of the known features
        known_feats : torch.Tensor
            (B, C2, m) tensor of features to be propagated

        Returns
        -------
        new_features : torch.Tensor
            (B, mlp[-1], n) tensor of the features of the unknown features
        """

        if known is not None:
            dist, idx = pointnet2_utils.three_nn(unknown, known)  # three nearest neighbours
            dist_recip = 1.0 / (dist + 1e-8)
            norm = torch.sum(dist_recip, dim=2, keepdim=True)
            weight = dist_recip / norm  # normalized inverse-distance weights

            interpolated_feats = pointnet2_utils.three_interpolate(
                known_feats, idx, weight)  # interpolation

        else:
            interpolated_feats = known_feats.expand(*known_feats.size()[0:2],
                                                    unknown.size(1))

        new_features = interpolated_feats

        # new_features = new_features.unsqueeze(-1)
        # new_features = self.mlp(new_features)

        # return new_features.squeeze(-1)
        return new_features  # no MLP applied
Example #5
    def forward(self, pointcloud: torch.cuda.FloatTensor, cls):
        # x: B,3,N

        xyz, features = self._break_up_pc(pointcloud)
        num_pts = xyz.size(1)
        batch_size = xyz.size(0)
        # FPS to find different point subsets and their relations
        subset1_idx = pointnet2_utils.furthest_point_sample(
            xyz, num_pts // 4).long()  # B,N/2
        subset1_xyz = torch.unsqueeze(subset1_idx, -1).repeat(1, 1, 3)  # B,N/2,3
        # gather the sampled points per batch (torch.take would index the flattened tensor)
        subset1_xyz = torch.gather(xyz, 1, subset1_xyz)  # B,N/2,3

        dist, idx1 = pointnet2_utils.three_nn(xyz, subset1_xyz)
        dist_recip = 1.0 / (dist + 1e-8)
        norm = torch.sum(dist_recip, dim=2, keepdim=True)
        weight1 = dist_recip / norm

        subset12_idx = pointnet2_utils.furthest_point_sample(
            subset1_xyz, num_pts // 16).long()  # B,N/4
        subset12_xyz = torch.unsqueeze(subset12_idx, -1).repeat(1, 1, 3)  # B,N/4,3
        subset12_xyz = torch.gather(subset1_xyz, 1, subset12_xyz)  # B,N/4,3

        dist, idx12 = pointnet2_utils.three_nn(subset1_xyz, subset12_xyz)
        dist_recip = 1.0 / (dist + 1e-8)
        norm = torch.sum(dist_recip, dim=2, keepdim=True)
        weight12 = dist_recip / norm

        device = torch.device('cuda')
        centroid = torch.zeros([batch_size, 1, 3], device=device)
        dist, idx0 = pointnet2_utils.three_nn(subset12_xyz, centroid)
        dist_recip = 1.0 / (dist + 1e-8)
        norm = torch.sum(dist_recip, dim=2, keepdim=True)
        weight0 = dist_recip / norm
        #######################################
        # Error-minimizing module 1:
        # Encoding
        x = xyz.transpose(2, 1)  # x: B,3,N
        x1_1 = x
        x = get_adptive_dilated_graph_feature(x,
                                              self.conv_op1,
                                              self.conv_op11,
                                              self.conv_op12,
                                              d=5,
                                              k=20)
        x = self.conv1(x)  # B,64,N,k
        x = self.conv14(x)  # B,64,N,k
        x1_2 = x
        # Back-projection
        x = self.conv11(x)  # B,3,N,1
        x = torch.squeeze(x, -1)  # B,3,N
        x1_3 = x
        # Calculating Error
        delta_1 = x1_3 - x1_1  # B,3,N
        # Output
        x = x1_2  # B,64,N,k
        x1 = x.max(dim=-1, keepdim=False)[0]  # B,64,N
        #######################################

        #######################################
        # Multi-resolution (MR) Branch
        # Down-scaling 1
        subset1_feat = torch.unsqueeze(subset1_idx, -1).repeat(1, 1, 64)  # B,N/2,64
        x1_subset1 = torch.gather(x1.transpose(1, 2).contiguous(), 1,
                                  subset1_feat).transpose(1, 2).contiguous()  # B,64,N/2

        x2_1 = x1_subset1  # B,64,N/2
        x = get_graph_feature(x1_subset1, k=self.k // 2)
        x = self.conv2(x)  # B,64,N/2,k
        x = self.conv24(x)  # B,128,N/2,k
        x2 = x.max(dim=-1, keepdim=False)[0]  # B,128,N/2

        # Dense-connection
        x12 = pointnet2_utils.three_interpolate(x2, idx1, weight1)  # B,128,N
        x12 = torch.cat((x12, x1), dim=1)  # B,192,N
        x12 = self.conv23(x12)  # B,128,N

        # Down-scaling 2
        subset12_feat = torch.unsqueeze(subset12_idx, -1).repeat(1, 1, 128)  # B,N/4,128
        x2_subset12 = torch.gather(x2.transpose(1, 2).contiguous(), 1,
                                   subset12_feat).transpose(1, 2).contiguous()  # B,128,N/4

        x3_1 = x2_subset12  # B,128,N/4
        x = get_graph_feature(x2_subset12, k=self.k // 4)
        x = self.conv3(x)  # B,256,N/4,k
        x3 = x.max(dim=-1, keepdim=False)[0]  # B,256,N/4

        # Dense-connection
        x23 = pointnet2_utils.three_interpolate(x3, idx12,
                                                weight12)  # B,256,N/2
        x23 = torch.cat((x23, x2), dim=1)  # B,384,N/2
        x23 = self.conv34(x23)  # B,128,N/2
        x123 = pointnet2_utils.three_interpolate(x23, idx1, weight1)  # B,128,N
        x123 = torch.cat((x123, x12, x1), dim=1)  # B,320,N
        x123 = self.conv35(x123)  # B,128,N

        # Down-scaling 3
        x_bot = self.conv53(x3)
        x_bot = self.conv54(x_bot)  # B,1024,N/128
        x_bot = F.adaptive_max_pool1d(x_bot, 1)  # B,1024,1

        # Upsampling 3:
        interpolated_feats1 = pointnet2_utils.three_interpolate(
            x_bot, idx0, weight0)  # B,1024,N/4
        interpolated_feats2 = x3  # B,256,N/4
        x3_up = torch.cat((interpolated_feats1, interpolated_feats2),
                          dim=1)  # B,1280,N/4
        x3_up = self.conv32(x3_up)  # B,256,N/4
        x3_up = self.conv33(x3_up)  # B,256,N/4

        # Upsampling 2:
        interpolated_feats1 = pointnet2_utils.three_interpolate(
            x3_up, idx12, weight12)  # B,256,N/2
        interpolated_feats2 = x2  # B,128,N/2
        interpolated_feats3 = x23  # B,128,N/2
        x2_up = torch.cat(
            (interpolated_feats1, interpolated_feats3, interpolated_feats2),
            dim=1)  # B,512,N/2
        x2_up = self.conv21(x2_up)  # B,256,N/2
        x2_up = self.conv22(x2_up)  # B,128,N/2

        # Upsampling 1:
        interpolated_feats1 = pointnet2_utils.three_interpolate(
            x2_up, idx1, weight1)  # B,128,N
        interpolated_feats2 = x1  # B,64,N
        interpolated_feats3 = x12  # B,128,N
        interpolated_feats4 = x123  # B,128,N
        x1_up = torch.cat((interpolated_feats1, interpolated_feats4,
                           interpolated_feats3, interpolated_feats2),
                          dim=1)  # B,448,N
        x1_up = self.conv12(x1_up)  # B,512,N
        x1_up = self.conv13(x1_up)  # B,1024,N

        x_mr = x1_up
        #############################################################################

        #############################################################################
        # Full-resolution Branch
        # Error-minimizing module 2:
        # Encoding
        x2_1 = x1  # B,64,N
        x = get_adptive_dilated_graph_feature(x1,
                                              self.conv_op2,
                                              self.conv_op21,
                                              self.conv_op22,
                                              d=5,
                                              k=20)
        x = self.convfc2(x)  # B,64,N,k
        x = self.convfc24(x)  # B,64,N,k
        x2_2 = x
        # Back-projection
        x = self.convfc21(x)  # B,64,N,1
        x = torch.squeeze(x, -1)  # B,64,N
        x2_3 = x
        # Calculating Error
        delta_2 = x2_3 - x2_1  # B,64,N
        # Output
        x = x2_2  # B,64,N,k
        x2 = x.max(dim=-1, keepdim=False)[0]  # B,64,N
        #######################################
        # Error-minimizing module 3:
        # Encoding
        x3_1 = x2  # B,64,N
        x = get_adptive_dilated_graph_feature(x2,
                                              self.conv_op3,
                                              self.conv_op31,
                                              self.conv_op32,
                                              d=5,
                                              k=20)
        x = self.convfc3(x)  # B,128,N,k
        x3_2 = x
        # Back-projection
        x = self.convfc31(x)  # B,64,N,1
        x = torch.squeeze(x, -1)  # B,64,N
        x3_3 = x
        # Calculating Error
        delta_3 = x3_3 - x3_1  # B,64,N
        # Output
        x = x3_2  # B,128,N,k
        x3 = x.max(dim=-1, keepdim=False)[0]  # B,128,N
        #######################################
        # Error-minimizing module 4:
        # Encoding
        x4_1 = x3  # B,128,N
        x = get_adptive_dilated_graph_feature(x3,
                                              self.conv_op4,
                                              self.conv_op41,
                                              self.conv_op42,
                                              d=5,
                                              k=20)
        x = self.convfc4(x)  # B,256,N,k
        x4_2 = x
        # Back-projection
        x = self.convfc41(x)  # B,128,N,1
        x = torch.squeeze(x, -1)  # B,128,N
        x4_3 = x
        # Calculating Error
        delta_4 = x4_3 - x4_1  # B,128,N
        # Output
        x = x4_2  # B,256,N,k
        x4 = x.max(dim=-1, keepdim=False)[0]  # B,256,N

        x = torch.cat((x1, x2, x3, x4), dim=1)  # B,512,N
        x_fr = self.conv7(x)  # B,1024,N

        # Fusing FR and MR outputs
        fusion_score = self.fuse(x_mr)
        x = x_fr + x_fr * fusion_score
        x_all = self.conv9(x)  # B,1024,N

        # Collecting global feature
        one_hot_label = cls.view(-1, 16, 1)  # B,16,1
        one_hot_label = self.conv5(one_hot_label)  # B,64,1
        x_max = F.adaptive_max_pool1d(x_all, 1)  # B,1024,1
        x_global = torch.cat((x_max, one_hot_label), dim=1)  # B,1088,1

        x_global = x_global.repeat(1, 1, num_pts)  # B,1088,N
        x = torch.cat((x_all, x_global), dim=1)  # B,2112,N

        x = self.conv8(x)  # B,1024,N

        x = self.conv63(x)  # B,128,N
        x = self.dp(x)
        x = self.conv64(x)  # B,50,N

        return (x.transpose(2, 1).contiguous(),
                delta_1.transpose(2, 1).contiguous(),
                delta_2.transpose(2, 1).contiguous(),
                delta_3.transpose(2, 1).contiguous(),
                delta_4.transpose(2, 1).contiguous())
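
The Multi-resolution (MR) branch above repeats one pattern several times: sample a subset with furthest_point_sample, gather its coordinates and features, process them, then propagate the result back with three_nn + three_interpolate. Below is a minimal sketch of one such round trip; the sizes are illustrative, it assumes the pointnet2_utils CUDA extension is built and the tensors live on the GPU, and the import line depends on how the extension is exposed in your setup:

import torch
import pointnet2_utils  # adjust to however your pointnet2 build exposes the ops

B, N, C = 2, 1024, 64
xyz = torch.rand(B, N, 3).cuda()    # point coordinates
feats = torch.rand(B, C, N).cuda()  # per-point features

# Down-scale: FPS picks N // 4 representative points; gather keeps their data per batch.
sub_idx = pointnet2_utils.furthest_point_sample(xyz, N // 4).long()         # (B, N/4)
sub_xyz = torch.gather(xyz, 1, sub_idx.unsqueeze(-1).expand(-1, -1, 3))     # (B, N/4, 3)
sub_feats = torch.gather(feats, 2, sub_idx.unsqueeze(1).expand(-1, C, -1))  # (B, C, N/4)

# Up-sample: inverse-distance weights over the 3 nearest subset points, then
# three_interpolate propagates the subset features back to all N points.
dist, idx = pointnet2_utils.three_nn(xyz, sub_xyz)
weight = 1.0 / (dist + 1e-8)
weight = weight / weight.sum(dim=2, keepdim=True)
up_feats = pointnet2_utils.three_interpolate(sub_feats, idx, weight)        # (B, C, N)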