def select_patches(self, pts, refer_pts, vicinity, num_points_per_patch=1024):
    """Cut one fixed-size local patch around every reference point.

    Both point sets are copied to the GPU as float32 with a leading batch axis
    of 1. ``pnt2.ball_query`` then collects up to ``num_points_per_patch``
    neighbour indices per reference point within radius ``vicinity``.

    Slots that repeat the first neighbour's index (ball_query padding) are
    replaced by the reference point itself. The last slot of every patch is
    always replaced by the reference point, so the centre is guaranteed to
    appear in its own patch.

    Args:
        pts: point cloud accepted by ``torch.FloatTensor``; presumably
            (N, 3) — TODO confirm against callers.
        refer_pts: reference (patch centre) points; presumably (M, 3).
        vicinity: ball-query radius.
        num_points_per_patch: number of points in every patch.

    Returns:
        GPU tensor of patches, (M, num_points_per_patch, 3) given the shapes
        assumed above.
    """
    # Full GC pass on every call, kept as in the original (presumably to
    # release GPU tensors from previous calls).
    gc.collect()
    cloud = torch.FloatTensor(pts).cuda().unsqueeze(0)
    anchors = torch.FloatTensor(refer_pts).cuda().unsqueeze(0)

    neighbor_idx = pnt2.ball_query(vicinity, num_points_per_patch, cloud, anchors)
    grouped = pnt2.grouping_operation(cloud.transpose(1, 2).contiguous(), neighbor_idx)
    grouped = grouped.permute([0, 2, 3, 1])  # (1, M, K, 3)

    # 1.0 where a slot should be overwritten by the reference point: every
    # slot repeating slot 0's index (padding), never slot 0 itself, and always
    # the final slot.
    replace = (neighbor_idx == neighbor_idx[:, :, :1]).float()
    replace[:, :, 0] = 0
    replace[:, :, num_points_per_patch - 1] = 1
    replace = replace.unsqueeze(3)  # broadcast over the xyz axis

    centers = anchors.unsqueeze(2)  # (1, M, 1, 3), broadcast over the slots
    patches = grouped * (1 - replace) + centers * replace
    return patches.squeeze(0)
def forward(self, x, feature, num_pool):
    """Downsample to ``num_pool`` neighbourhoods and max-pool coordinates and features.

    ``num_pool`` seeds are chosen by furthest point sampling. A ball query of
    radius ``0.2 * self.nlayer`` gathers up to ``self.k`` neighbours per seed.
    Both the neighbours' coordinates and their ``self.mlp``-transformed
    features are max-pooled over the neighbour axis.

    Args:
        x: coordinates; presumably channel-first (B, 3, N), since it is
            transposed before FPS and passed directly to gather_operation —
            TODO confirm.
        feature: per-point features in the layout grouping_operation expects;
            presumably (B, C, N).
        num_pool: number of FPS seeds.

    Returns:
        (pooled_xyz, pooled_feature). NOTE(review): pooled_xyz is the
        per-axis maximum over each neighbourhood, not the FPS seed
        coordinates; confirm this is intended.
    """
    x = x.contiguous()
    x_t = x.transpose(1, 2).contiguous()

    # Pick the seeds and lay their coordinates out point-major for ball_query.
    seed_idx = pointnet2_utils.furthest_point_sample(x_t, num_pool)
    seeds = pointnet2_utils.gather_operation(x, seed_idx)
    seeds = seeds.transpose(1, 2).contiguous()

    radius = 0.2 * self.nlayer
    neighbor_idx = pointnet2_utils.ball_query(radius, self.k, x_t, seeds)

    grouped_xyz = pointnet2_utils.grouping_operation(x, neighbor_idx)
    pooled_xyz = grouped_xyz.max(dim=-1).values

    grouped_feature = pointnet2_utils.grouping_operation(feature, neighbor_idx)
    pooled_feature = self.mlp(grouped_feature).max(dim=-1).values
    return pooled_xyz, pooled_feature
def pool(xyz, points, k, npoint):
    """Downsample with FPS and max-pool each seed's k-nearest-neighbour features.

    Args:
        xyz: point coordinates, point-major (B, N, 3). The layout is inferred
            from ``knn_point(k, xyz, new_xyz)`` receiving the (B, npoint, 3)
            seeds.
        points: per-point features (B, N, C); permuted to (B, C, N) for
            grouping_operation.
        k: neighbours pooled per seed.
        npoint: number of FPS seeds.

    Returns:
        new_xyz: (B, npoint, 3) seed coordinates.
        new_points: (B, npoint, C) neighbourhood-max features.
    """
    xyz = xyz.contiguous()
    xyz_flipped = xyz.transpose(1, 2).contiguous()  # (B, 3, N) for gather_operation
    # furthest_point_sample expects point-major (B, N, 3) input. Passing the
    # flipped (B, 3, N) tensor, as before, made it sample the wrong data.
    seed_idx = pointnet2_utils.furthest_point_sample(xyz, npoint)
    new_xyz = pointnet2_utils.gather_operation(xyz_flipped, seed_idx).transpose(1, 2).contiguous()

    _, idx = knn_point(k, xyz, new_xyz)  # presumably (B, npoint, k)
    grouped = pointnet2_utils.grouping_operation(
        points.permute(0, 2, 1).contiguous(),
        idx.int().permute(0, 2, 1).contiguous(),
    )  # (B, C, k, npoint)
    new_points = torch.max(grouped.permute(0, 3, 2, 1), dim=2).values  # (B, npoint, C)
    return new_xyz, new_points
def sample_and_group(xyz, points, npoint, nsample, radius, use_xyz=True):
    """
    Sample ``npoint`` centroids by FPS and group up to ``nsample`` neighbours within ``radius``.

    Args:
        xyz: Tensor, (B, 3, N)
        points: Tensor, (B, f, N), or None
        npoint: int
        nsample: int
        radius: float
        use_xyz: boolean, prepend relative coordinates to the grouped features

    Returns:
        new_xyz: Tensor, (B, 3, npoint)
        new_points: Tensor, (B, 3 | f+3 | f, npoint, nsample)
        idx: Tensor, (B, npoint, nsample)
        grouped_xyz: Tensor, (B, 3, npoint, nsample), relative to each centroid
    """
    # The gather/grouping kernels operate on contiguous memory.
    xyz = xyz.contiguous()
    xyz_flipped = xyz.permute(0, 2, 1).contiguous()  # (B, N, 3)
    new_xyz = gather_operation(xyz, furthest_point_sample(xyz_flipped, npoint))  # (B, 3, npoint)
    idx = ball_query(radius, nsample, xyz_flipped, new_xyz.permute(0, 2, 1).contiguous())  # (B, npoint, nsample)
    grouped_xyz = grouping_operation(xyz, idx)  # (B, 3, npoint, nsample)
    # Express neighbours relative to their centroid. Broadcasting over the
    # nsample axis avoids materialising a repeated copy of new_xyz.
    grouped_xyz -= new_xyz.unsqueeze(3)

    if points is None:
        return new_xyz, grouped_xyz, idx, grouped_xyz

    grouped_points = grouping_operation(points.contiguous(), idx)  # (B, f, npoint, nsample)
    new_points = torch.cat([grouped_xyz, grouped_points], 1) if use_xyz else grouped_points
    return new_xyz, new_points, idx, grouped_xyz
def group(xyz, points, k, dilation=1, use_xyz=False):
    """Group every point with its k dilated nearest neighbours.

    Args:
        xyz: point-major coordinates (B, N, 3); transposed to (B, 3, N) for
            grouping_operation.
        points: optional point-major features (B, N, C), or None.
        k: neighbours kept per point.
        dilation: keep every ``dilation``-th neighbour.
        use_xyz: prepend the relative coordinates to the grouped features.

    Returns:
        grouped_xyz: (B, 3, k, N) neighbour coordinates relative to each point.
        grouped_points: (B, C, k, N), (B, 3 + C, k, N) when ``use_xyz``, or
            ``grouped_xyz`` itself when ``points`` is None.
        idx: int32 neighbour indices; presumably (B, N, k), since knn_point's
            last axis is sliced by dilation — TODO confirm.
    """
    # Query k*dilation+1 neighbours, skip the first (presumably the point
    # itself) and keep every dilation-th one, which leaves exactly k.
    _, idx = knn_point(k * dilation + 1, xyz, xyz)
    idx = idx[:, :, 1::dilation].int().contiguous()
    # grouping_operation takes (B, npoint, nsample) indices. Passing the
    # transposed (B, k, N) indices yields channel-first (B, C, k, N) output.
    idx_t = idx.permute(0, 2, 1).contiguous()

    xyz_trans = xyz.transpose(1, 2).contiguous()  # (B, 3, N)
    grouped_xyz = pointnet2_utils.grouping_operation(xyz_trans, idx_t)  # (B, 3, k, N)
    grouped_xyz -= torch.unsqueeze(xyz_trans, 2)  # translation normalization

    if points is None:
        return grouped_xyz, grouped_xyz, idx

    grouped_points = pointnet2_utils.grouping_operation(
        points.permute(0, 2, 1).contiguous(), idx_t)  # (B, C, k, N)
    if use_xyz:
        # Channels live on dim 1. Concatenating on dim=-1 joined the point axis
        # instead, which fails for C != 3.
        grouped_points = torch.cat([grouped_xyz, grouped_points], dim=1)  # (B, 3+C, k, N)
    return grouped_xyz, grouped_points, idx
def sphere_query_new(pts, new_pts, radius, nsample):
    """
    Gather the neighbours of each query point inside a sphere, zeroing unused slots.

    :param pts: all points, [B, N, 3]
    :param new_pts: query points, [B, S, 3]
    :param radius: local spherical radius
    :param nsample: max sample number in local sphere
    :return: [B, S, nsample, 3] neighbour coordinates. Slots that ball_query
        padded, and every slot of a query with no neighbour, are zero.
    """
    pts = pts.contiguous()
    new_pts = new_pts.contiguous()
    group_idx = pnt2.ball_query(radius, nsample, pts, new_pts)  # (B, S, nsample)

    # C implementation
    pts_trans = pts.transpose(1, 2).contiguous()
    new_points = pnt2.grouping_operation(pts_trans, group_idx)  # (B, 3, S, nsample)
    new_points = new_points.permute([0, 2, 3, 1])  # (B, S, nsample, 3)

    # Padding slots repeat the first hit's index. Slot 0 is the hit itself.
    mask = (group_idx == group_idx[:, :, :1]).float()
    # Slot 0 is masked only when the query found nothing. In that case
    # ball_query leaves every slot at index 0. Index 0 can also be a genuine
    # first hit, so additionally require point 0 to lie outside the sphere.
    first_d2 = ((new_points[:, :, 0, :] - new_pts) ** 2).sum(dim=-1)
    empty = (group_idx[:, :, 0] == 0) & (first_d2 >= radius * radius)
    mask[:, :, 0] = empty.float()

    # Zero the masked slots (mask broadcasts over the xyz axis).
    return new_points * (1 - mask.unsqueeze(3))