def gather_class(self, gt_labels, assigned_idx):
    """Look up the ground-truth label selected by each assignment index.

    The labels are given a singleton channel axis and cast to float so
    that ``grouping_operation`` can treat them as features; the gathered
    values are cast back to ``long`` afterwards.

    Args:
        gt_labels: Ground-truth labels; the original author notes the
            layout as ``[bs, gt_num]`` (not verifiable from this block).
        assigned_idx: Indices into the ground-truth axis; cast to int
            before grouping.

    Returns:
        The gathered labels with the trailing axis squeezed (only if it
        has size 1) and axes 1 and 2 swapped. The original author notes
        the result as ``[bs, points_num, cls_num]``.
    """
    labels_as_features = gt_labels.unsqueeze(dim=1).float()
    gathered = grouping_operation(labels_as_features, assigned_idx.int())
    return gathered.squeeze(dim=-1).long().transpose(1, 2)
def test_grouping_points():
    """Compare ``grouping_operation`` on CUDA with hand-written results.

    Every query point repeats one source index across its three samples,
    so the expected output is the feature value at that index repeated
    three times along the last axis. The test is skipped without CUDA.
    """
    if not torch.cuda.is_available():
        pytest.skip()

    nsample = 3

    # Source-point index per (batch, query point); repeated nsample times
    # along a new trailing axis to form the (2, 6, 3) index tensor.
    source_ids = torch.tensor([
        [0, 3, 8, 0, 0, 0],
        [0, 6, 9, 0, 0, 0],
    ])
    indices = source_ids.unsqueeze(-1).repeat(1, 1, nsample).int().cuda()

    # Features laid out as (batch=2, channels=3, points=10).
    features = torch.tensor([
        [
            [0.5798, -0.7981, -0.9280, -1.3311, 1.3687,
             0.9277, -0.4164, -1.8274, 0.9268, 0.8414],
            [5.4247, 1.5113, 2.3944, 1.4740, 5.0300,
             5.1030, 1.9360, 2.1939, 2.1581, 3.4666],
            [-1.6266, -1.0281, -1.0393, -1.6931, -1.3982,
             -0.5732, -1.0830, -1.7561, -1.6786, -1.6967],
        ],
        [
            [-0.0380, -0.1880, -1.5724, 0.6905, -0.3190,
             0.7798, -0.3693, -0.9457, -0.2942, -1.8527],
            [1.1773, 1.5009, 2.6399, 5.9242, 1.0962,
             2.7346, 6.0865, 1.5555, 4.3303, 2.8229],
            [-0.6646, -0.6870, -0.1125, -0.2224, -0.3445,
             -1.4049, 0.4990, -0.7037, -0.9924, 0.0386],
        ],
    ]).cuda()

    grouped = grouping_operation(features, indices)

    # Expected value per (batch, channel, query point); each value is
    # repeated nsample times to give the (2, 3, 6, 3) reference tensor.
    expected_per_point = torch.tensor([
        [
            [0.5798, -1.3311, 0.9268, 0.5798, 0.5798, 0.5798],
            [5.4247, 1.4740, 2.1581, 5.4247, 5.4247, 5.4247],
            [-1.6266, -1.6931, -1.6786, -1.6266, -1.6266, -1.6266],
        ],
        [
            [-0.0380, -0.3693, -1.8527, -0.0380, -0.0380, -0.0380],
            [1.1773, 6.0865, 2.8229, 1.1773, 1.1773, 1.1773],
            [-0.6646, 0.4990, 0.0386, -0.6646, -0.6646, -0.6646],
        ],
    ])
    expected = expected_per_point.unsqueeze(-1).repeat(
        1, 1, 1, nsample).cuda()

    assert torch.allclose(grouped, expected)
def forward(self, xyz, points, former_fps_idx, vote_ctr):
    """Sample centre points, group their neighbours and extract features.

    The method works in three stages:

    1. For every ``(fps_sample_range, fps_method, npoint)`` triple taken
       from ``self.fps_sample_range_list``, ``self.fps_method_list`` and
       ``self.npoint_list``, a slice of the input along dim 1 is sampled
       and the resulting indices are concatenated. ``'FS'`` concatenates
       feature-distance and coordinate sampling, ``'F-FPS'`` uses
       feature-distance sampling, and any other method name falls back
       to coordinate sampling (D-FPS). Entries with ``npoint == 0`` are
       skipped.
    2. The centres are gathered from ``vote_ctr`` when it is given and
       from ``xyz`` otherwise.
    3. For every radius in ``self.radius_list`` a ball query groups
       neighbours around the centres; centre-relative coordinates are
       concatenated with the grouped features, passed through the
       matching entry of ``self.mlp_modules`` and max-pooled over the
       sample axis.

    Args:
        xyz: Point coordinates; dim 0 is the batch and dim 1 the point
            axis.
        points: Point features, sliced along dim 1 like ``xyz``. Must be
            a tensor: it is transposed unconditionally, so ``None`` is
            not supported.
        former_fps_idx: Optional indices appended after the newly
            sampled ones along the last axis.
        vote_ctr: Optional centres. When given, every processed range
            selects all ``vote_ctr.shape[1]`` indices instead of running
            a sampler, and the centres are gathered from ``vote_ctr``.

    Returns:
        tuple: ``(new_xyz, new_points_concat, fps_idx)`` - the gathered
        centres, the aggregated features with axes 1 and 2 swapped back,
        and the indices used for gathering.
    """
    bs = xyz.shape[0]
    num_points = xyz.shape[1]
    device = xyz.device

    # ---- Stage 1: choose the indices of the centre points. ----
    cur_fps_idx_list = []
    last_fps_end_index = 0
    for fps_sample_range, fps_method, npoint in zip(
            self.fps_sample_range_list, self.fps_method_list,
            self.npoint_list):
        # NOTE(review): the range is used as a slice *end* below but is
        # added to the running start index as if it were a length (and
        # negative values are added as-is). This matches the original
        # behaviour; confirm the intended semantics before relying on
        # more than two ranges.
        if npoint == 0:
            last_fps_end_index += fps_sample_range
            continue

        # A negative range counts back from the end (-1 -> num_points).
        if fps_sample_range < 0:
            range_end = fps_sample_range + num_points + 1
        else:
            range_end = fps_sample_range
        tmp_xyz = xyz[:, last_fps_end_index:range_end, :].contiguous()

        use_vote_centers = vote_ctr is not None
        if use_vote_centers:
            npoint = vote_ctr.shape[1]

        # 'FS' always runs its samplers, so the "take everything"
        # shortcut only applies to the other methods.
        take_all = fps_method != 'FS' and npoint == tmp_xyz.shape[1]

        if use_vote_centers or take_all:
            # Select indices 0..npoint-1 for every batch element.
            fps_idx = torch.arange(
                npoint, dtype=torch.int32,
                device=device).view(1, npoint).repeat(bs, 1)
        elif fps_method in ('FS', 'F-FPS'):
            # Only the feature-based samplers need the feature slice.
            tmp_points = points[:, last_fps_end_index:range_end, :] \
                .contiguous()
            features_for_fps = torch.cat([tmp_xyz, tmp_points], dim=-1)
            features_for_fps_distance = calc_square_dist(
                features_for_fps, features_for_fps).contiguous()
            fps_idx = pointnet2_utils.furthest_point_sample_with_dist(
                features_for_fps_distance, npoint)
            if fps_method == 'FS':
                fps_idx_xyz = pointnet2_utils.furthest_point_sample(
                    tmp_xyz, npoint)
                # [bs, npoint * 2]
                fps_idx = torch.cat([fps_idx, fps_idx_xyz], dim=-1)
        else:  # D-FPS
            fps_idx = pointnet2_utils.furthest_point_sample(
                tmp_xyz, npoint)

        # Indices are relative to the slice; shift them back.
        cur_fps_idx_list.append(fps_idx + last_fps_end_index)
        last_fps_end_index += fps_sample_range

    fps_idx = torch.cat(cur_fps_idx_list, dim=-1)
    if former_fps_idx is not None:
        fps_idx = torch.cat([fps_idx, former_fps_idx], dim=-1)

    # ---- Stage 2: gather the centre coordinates. ----
    xyz = xyz.contiguous()
    # Channel-first copy, computed once and shared by the gather below
    # and by every grouping step.
    xyz_trans = xyz.transpose(1, 2).contiguous()
    if vote_ctr is not None:
        center_source = vote_ctr.transpose(1, 2).contiguous()
    else:
        center_source = xyz_trans
    new_xyz = pointnet2_utils.gather_operation(
        center_source, fps_idx).transpose(1, 2).contiguous()

    # ---- Stage 3: group neighbours and extract features per radius. ----
    points = points.transpose(1, 2).contiguous()
    center_offsets = new_xyz.transpose(1, 2).unsqueeze(-1)

    new_points_list = []
    for i, max_radius in enumerate(self.radius_list):
        nsample = self.nsample_list[i]
        if self.dilated_group:
            # Dilated grouping starts where the previous radius ended.
            min_radius = 0.0 if i == 0 else self.radius_list[i - 1]
            idx = pointnet2_utils.ball_query_dilated(
                max_radius, min_radius, nsample, xyz, new_xyz)
        else:
            idx = pointnet2_utils.ball_query(
                max_radius, nsample, xyz, new_xyz)

        grouped_xyz = pointnet2_utils.grouping_operation(
            xyz_trans, idx)  # (B, 3, npoint, nsample)
        # Express neighbour coordinates relative to their centre.
        grouped_xyz -= center_offsets

        grouped_points = pointnet2_utils.grouping_operation(points, idx)
        grouped_points = torch.cat(
            [grouped_xyz, grouped_points],
            dim=1)  # (B, C + 3, npoint, nsample)

        new_points = self.mlp_modules[i](grouped_points)
        # Max-pool over the whole sample axis.
        new_points = F.max_pool2d(
            new_points, kernel_size=[1, new_points.size(3)])
        new_points_list.append(new_points.squeeze(-1))

    if new_points_list:
        new_points_concat = torch.cat(new_points_list, dim=1)
        if cfg.MODEL.NETWORK.AGGREGATION_SA_FEATURE:
            new_points_concat = self.aggregation_layer(new_points_concat)
    else:
        # No grouping scales configured: pass the sampled features on.
        new_points_concat = pointnet2_utils.gather_operation(
            points, fps_idx)

    new_points_concat = new_points_concat.transpose(1, 2).contiguous()
    return new_xyz, new_points_concat, fps_idx