def gather_class(self, gt_labels, assigned_idx):
    """Look up a ground-truth label for every index in ``assigned_idx``.

    The labels are given a singleton channel axis and cast to float so they
    can be fed through ``grouping_operation`` like a feature map; the result
    is cast back to ``long`` afterwards. ``self`` is not used.

    Note: ``.float()`` is float32, which represents integers exactly only up
    to 2**24, so label values are assumed to stay well below that.

    Args:
        gt_labels: label tensor; the original shape note reads
            ``[bs, gt_num] -> [bs, points_num, cls_num]``.
        assigned_idx: indices into the label axis; cast to ``int`` before
            the lookup.

    Returns:
        Long tensor of gathered labels with axes 1 and 2 swapped.
        NOTE(review): ``squeeze(dim=-1)`` only removes the trailing axis
        when it has size 1 — confirm the shape callers expect.
    """
    # Singleton channel axis + float dtype so the grouping kernel accepts it.
    label_channel = gt_labels.unsqueeze(dim=1).float()
    gather_index = assigned_idx.int()
    gathered = grouping_operation(label_channel, gather_index)
    # Drop the trailing axis (no-op unless its size is 1), restore the
    # integer dtype, then swap axes 1 and 2.
    gathered = gathered.squeeze(dim=-1)
    gathered = gathered.long()
    return gathered.transpose(1, 2)
def test_grouping_points():
    """Check ``grouping_operation`` against hand-computed values on CUDA.

    Every row of ``idx`` repeats one column index three times, so each
    output row is a single feature value repeated across the last axis:
    batch 0 picks columns 0/3/8 and batch 1 picks columns 0/6/9 of
    ``features``.
    """
    if not torch.cuda.is_available():
        pytest.skip('CUDA is not available')
    idx = torch.tensor([[[0, 0, 0], [3, 3, 3], [8, 8, 8], [0, 0, 0], [0, 0, 0],
                         [0, 0, 0]],
                        [[0, 0, 0], [6, 6, 6], [9, 9, 9], [0, 0, 0], [0, 0, 0],
                         [0, 0, 0]]]).int().cuda()
    # Shape (2, 3, 10): batch, channel, point.
    features = torch.tensor([
        [[0.5798, -0.7981, -0.9280, -1.3311, 1.3687,
          0.9277, -0.4164, -1.8274, 0.9268, 0.8414],
         [5.4247, 1.5113, 2.3944, 1.4740, 5.0300,
          5.1030, 1.9360, 2.1939, 2.1581, 3.4666],
         [-1.6266, -1.0281, -1.0393, -1.6931, -1.3982,
          -0.5732, -1.0830, -1.7561, -1.6786, -1.6967]],
        [[-0.0380, -0.1880, -1.5724, 0.6905, -0.3190,
          0.7798, -0.3693, -0.9457, -0.2942, -1.8527],
         [1.1773, 1.5009, 2.6399, 5.9242, 1.0962,
          2.7346, 6.0865, 1.5555, 4.3303, 2.8229],
         [-0.6646, -0.6870, -0.1125, -0.2224, -0.3445,
          -1.4049, 0.4990, -0.7037, -0.9924, 0.0386]],
    ]).cuda()

    output = grouping_operation(features, idx)
    expected_output = torch.tensor([[[[0.5798, 0.5798, 0.5798],
                                      [-1.3311, -1.3311, -1.3311],
                                      [0.9268, 0.9268, 0.9268],
                                      [0.5798, 0.5798, 0.5798],
                                      [0.5798, 0.5798, 0.5798],
                                      [0.5798, 0.5798, 0.5798]],
                                     [[5.4247, 5.4247, 5.4247],
                                      [1.4740, 1.4740, 1.4740],
                                      [2.1581, 2.1581, 2.1581],
                                      [5.4247, 5.4247, 5.4247],
                                      [5.4247, 5.4247, 5.4247],
                                      [5.4247, 5.4247, 5.4247]],
                                     [[-1.6266, -1.6266, -1.6266],
                                      [-1.6931, -1.6931, -1.6931],
                                      [-1.6786, -1.6786, -1.6786],
                                      [-1.6266, -1.6266, -1.6266],
                                      [-1.6266, -1.6266, -1.6266],
                                      [-1.6266, -1.6266, -1.6266]]],
                                    [[[-0.0380, -0.0380, -0.0380],
                                      [-0.3693, -0.3693, -0.3693],
                                      [-1.8527, -1.8527, -1.8527],
                                      [-0.0380, -0.0380, -0.0380],
                                      [-0.0380, -0.0380, -0.0380],
                                      [-0.0380, -0.0380, -0.0380]],
                                     [[1.1773, 1.1773, 1.1773],
                                      [6.0865, 6.0865, 6.0865],
                                      [2.8229, 2.8229, 2.8229],
                                      [1.1773, 1.1773, 1.1773],
                                      [1.1773, 1.1773, 1.1773],
                                      [1.1773, 1.1773, 1.1773]],
                                     [[-0.6646, -0.6646, -0.6646],
                                      [0.4990, 0.4990, 0.4990],
                                      [0.0386, 0.0386, 0.0386],
                                      [-0.6646, -0.6646, -0.6646],
                                      [-0.6646, -0.6646, -0.6646],
                                      [-0.6646, -0.6646, -0.6646]]]]).cuda()

    # torch.allclose broadcasts its arguments, so a wrongly shaped but
    # broadcast-compatible output would otherwise pass; pin the shape first.
    assert output.shape == expected_output.shape
    assert torch.allclose(output, expected_output)
Exemplo n.º 3
0
    def forward(self, xyz, points, former_fps_idx, vote_ctr):
        """Sample centre points, group neighbours around them and run the MLPs.

        Steps visible below:
          1. For each (range, method, npoint) entry, pick sample indices inside
             a slice of the points (vote centres / FS / D-FPS / F-FPS or all
             points) and shift them to global indices.
          2. Optionally append ``former_fps_idx`` and gather the new centre
             coordinates from ``vote_ctr`` (if given) or from ``xyz``.
          3. For each radius, ball-query + group coordinates (relative to the
             centre) and features, apply ``self.mlp_modules[i]`` and max-pool
             over the neighbour axis; concatenate over radii.

        Args:
            xyz: coordinates, indexed as ``[batch, point, coord]`` (the
                grouping comment below gives 3 coordinates).
            points: per-point features, indexed as ``[batch, point, channel]``
                (channel-last); transposed to channel-first before grouping.
            former_fps_idx: optional extra indices appended to the sampled
                ones, or None.
            vote_ctr: optional centres; when given, sampling is replaced by
                ``arange(vote_ctr.shape[1])`` and centres are gathered from it.

        Returns:
            ``(new_xyz, new_points_concat, fps_idx)``: centre coordinates
            (channel-last), aggregated features transposed to channel-last,
            and the int sample indices used.
        """
        bs = xyz.shape[0]
        num_points = xyz.shape[1]

        cur_fps_idx_list = []
        last_fps_end_index = 0
        for fps_sample_range, fps_method, npoint in zip(
                self.fps_sample_range_list, self.fps_method_list,
                self.npoint_list):
            # Negative ranges count from the end; -1 maps to num_points so the
            # slice below runs to the last point.
            if fps_sample_range < 0:
                fps_sample_range_tmp = fps_sample_range + num_points + 1
            else:
                fps_sample_range_tmp = fps_sample_range
            # fps_sample_range_tmp is used as an exclusive END index here.
            tmp_xyz = xyz[:, last_fps_end_index:
                          fps_sample_range_tmp, :].contiguous()
            tmp_points = points[:, last_fps_end_index:
                                fps_sample_range_tmp, :].contiguous()
            if npoint == 0:
                # NOTE(review): the raw (possibly negative) range is added as
                # if it were a LENGTH, while the slice above treats it as an
                # END index. The two only agree while the start is 0 (e.g. a
                # two-entry list ending in -1) — confirm before using 3+ ranges.
                last_fps_end_index += fps_sample_range
                continue
            if vote_ctr is not None:
                # Vote centres are used as-is: take every one of them, in order.
                npoint = vote_ctr.shape[1]
                fps_idx = torch.arange(npoint).int().view(1, npoint).repeat(
                    (bs, 1)).to(tmp_xyz.device)
            elif fps_method == 'FS':
                # Fusion sampling: npoint indices from feature-space distances
                # plus npoint indices from coordinate FPS, concatenated.
                features_for_fps = torch.cat([tmp_xyz, tmp_points], dim=-1)
                # dist1 = nn_distance(tmp_xyz, tmp_xyz)
                # dist2 = calc_square_dist(tmp_xyz, tmp_xyz, norm=False)
                features_for_fps_distance = calc_square_dist(
                    features_for_fps, features_for_fps)
                features_for_fps_distance = features_for_fps_distance.contiguous(
                )
                fps_idx_1 = pointnet2_utils.furthest_point_sample_with_dist(
                    features_for_fps_distance, npoint)
                fps_idx_2 = pointnet2_utils.furthest_point_sample(
                    tmp_xyz, npoint)
                fps_idx = torch.cat([fps_idx_1, fps_idx_2],
                                    dim=-1)  # [bs, npoint * 2]
            elif npoint == tmp_xyz.shape[1]:
                # Asking for every point in the slice: skip sampling entirely.
                fps_idx = torch.arange(npoint).int().view(1, npoint).repeat(
                    (bs, 1)).to(tmp_xyz.device)
            elif fps_method == 'F-FPS':
                # FPS driven by pairwise distances over [xyz, features].
                features_for_fps = torch.cat([tmp_xyz, tmp_points], dim=-1)
                features_for_fps_distance = calc_square_dist(
                    features_for_fps, features_for_fps)
                features_for_fps_distance = features_for_fps_distance.contiguous(
                )
                fps_idx = pointnet2_utils.furthest_point_sample_with_dist(
                    features_for_fps_distance, npoint)
            else:  # D-FPS
                fps_idx = pointnet2_utils.furthest_point_sample(
                    tmp_xyz, npoint)

            # Indices are relative to the slice start; shift to global indices.
            fps_idx = fps_idx + last_fps_end_index
            cur_fps_idx_list.append(fps_idx)
            # NOTE(review): same length-vs-end-index mismatch as above.
            last_fps_end_index += fps_sample_range
        fps_idx = torch.cat(cur_fps_idx_list, dim=-1)

        if former_fps_idx is not None:
            fps_idx = torch.cat([fps_idx, former_fps_idx], dim=-1)

        # Gather centre coordinates; gather_operation is called on
        # channel-first input, so transpose in and back out.
        if vote_ctr is not None:
            vote_ctr_transpose = vote_ctr.transpose(1, 2).contiguous()
            new_xyz = pointnet2_utils.gather_operation(vote_ctr_transpose,
                                                       fps_idx).transpose(
                                                           1, 2).contiguous()
        else:
            new_xyz = pointnet2_utils.gather_operation(
                xyz.transpose(1, 2).contiguous(),
                fps_idx).transpose(1, 2).contiguous()

        # Disabled attention-based grouping sketch (references TensorFlow's
        # tf.nn.top_k); not executed.
        # # if deformed_xyz is not None, then no attention model
        # if use_attention:
        #     # first gather the points out
        #     new_points = gather_point(points, fps_idx) # [bs, npoint, c]
        #
        #     # choose farthest feature to center points
        #     # [bs, npoint, ndataset]
        #     relation = model_util.calc_square_dist(new_points, points)
        #     # choose these points with largest distance to center_points
        #     _, relation_idx = tf.nn.top_k(relation, k=relation.shape.as_list()[-1])

        new_points_list = []
        # Features become channel-first for grouping/gathering from here on.
        points = points.transpose(1, 2).contiguous()
        xyz = xyz.contiguous()
        for i in range(len(self.radius_list)):
            nsample = self.nsample_list[i]
            if self.dilated_group:
                # Query between the previous radius and this one (presumably a
                # shell around each centre — confirm ball_query_dilated).
                if i == 0:
                    min_radius = 0.0
                else:
                    min_radius = self.radius_list[i - 1]
                max_radius = self.radius_list[i]
                idx = pointnet2_utils.ball_query_dilated(
                    max_radius, min_radius, nsample, xyz, new_xyz)
            else:
                radius = self.radius_list[i]
                idx = pointnet2_utils.ball_query(radius, nsample, xyz, new_xyz)

            xyz_trans = xyz.transpose(1, 2).contiguous()
            grouped_xyz = pointnet2_utils.grouping_operation(
                xyz_trans, idx)  # (B, 3, npoint, nsample)
            # Make neighbour coordinates relative to their centre.
            grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1)
            # NOTE(review): `points` was already indexed and transposed above,
            # so a None value would fail earlier; the else branch looks
            # unreachable.
            if points is not None:
                grouped_points = pointnet2_utils.grouping_operation(
                    points, idx)
                grouped_points = torch.cat(
                    [grouped_xyz, grouped_points],
                    dim=1)  # (B, C + 3, npoint, nsample)
            else:
                grouped_points = grouped_xyz

            new_points = self.mlp_modules[i](grouped_points)
            # Max-pool over the last (neighbour) axis, then drop it.
            new_points = F.max_pool2d(new_points,
                                      kernel_size=[1, new_points.size(3)])
            new_points_list.append(new_points.squeeze(-1))

        if len(new_points_list) > 0:
            # Concatenate multi-radius features along the channel axis.
            new_points_concat = torch.cat(new_points_list, dim=1)
            if cfg.MODEL.NETWORK.AGGREGATION_SA_FEATURE:
                new_points_concat = self.aggregation_layer(new_points_concat)
        else:
            # No radii configured: just gather the sampled points' features.
            new_points_concat = pointnet2_utils.gather_operation(
                points, fps_idx)
        new_points_concat = new_points_concat.transpose(1, 2).contiguous()

        return new_xyz, new_points_concat, fps_idx