def denoise_normals(points, normals, num_points, sharpness_sigma=30,
                    knn_result=None, neighborhood_size=16):
    """Bilateral-style smoothing of per-point normals.

    Each normal is replaced by a weighted average of its K neighbors'
    normals, weighting by normal similarity (``weights_n``) and spatial
    proximity (``weights_p``).

    NOTE(review): despite the original docstring
    ("exp(-(1-<n, n_i>)/(1-cos(sharpness_sigma)))"), the code divides
    (1 - <n, n_i>) by ``sharpness_sigma`` directly and squares the result;
    the cosine transform is not applied here.

    Args:
        points: (N, P, 3) padded points or a Pointclouds object.
        normals: (N, P, 3) per-point normals (re-normalized internally).
        num_points: (N,) number of valid points per cloud.
        sharpness_sigma: divisor of the normal-similarity term.
        knn_result: optional precomputed ``_KNN(dists, idx, knn)`` with the
            self-match already removed; when None a fixed-radius grid
            search builds it.
        neighborhood_size: K used for the neighbor search.

    Returns:
        (N, P, 3) tensor of unit-length denoised normals.
    """
    points, num_points = convert_pointclouds_to_tensor(points)
    normals = F.normalize(normals, dim=-1)
    if knn_result is None:
        # Bounding-box diagonal; math.sqrt on a tensor assumes it is
        # effectively scalar (batch size 1) -- TODO confirm callers.
        diag = (points.max(dim=-2)[0] - points.min(dim=-2)[0]).norm(dim=-1)
        avg_spacing = math.sqrt(diag / points.shape[1])
        search_radius = min(4 * avg_spacing * neighborhood_size, 0.2)
        dists, idxs, _, grid = frnn.frnn_grid_points(
            points, points, num_points, num_points,
            K=neighborhood_size + 1, r=search_radius,
            grid=None, return_nn=True)
        # Drop the first neighbor (the query point itself).
        knn_result = _KNN(dists=dists[..., 1:], idx=idxs[..., 1:], knn=None)
    if knn_result.knn is None:
        # Gather neighbor coordinates when the caller did not supply them.
        knn = frnn.frnn_gather(points, knn_result.idx, num_points)
        knn_result = _KNN(idx=knn_result.idx, knn=knn, dists=knn_result.dists)
    knn_normals = frnn.frnn_gather(normals, knn_result.idx, num_points)
    # Normal-similarity weights: close-to-parallel neighbors weigh more.
    weights_n = ((1 - torch.sum(knn_normals * normals[:, :, None, :], dim=-1))
                 / sharpness_sigma)**2
    weights_n = torch.exp(-weights_n)
    inv_sigma_spatial = num_points / 2.0
    spatial_dist = 16 / inv_sigma_spatial
    # BUGFIX: use knn_result.knn instead of the local name `knn`, which was
    # only bound inside the branch above; with a caller-supplied knn_result
    # that already contains neighbor coordinates this raised NameError.
    deltap = knn_result.knn - points[:, :, None, :]
    deltap = torch.sum(deltap * deltap, dim=-1)
    weights_p = torch.exp(-deltap * inv_sigma_spatial)
    # Cut off neighbors beyond the spatial support.
    weights_p[deltap > spatial_dist] = 0
    weights = weights_p * weights_n
    normals_denoised = torch.sum(knn_normals * weights[:, :, :, None], dim=-2) / \
        eps_denom(torch.sum(weights, dim=-1, keepdim=True))
    normals_denoised = F.normalize(normals_denoised, dim=-1)
    return normals_denoised.view_as(normals)
def _knn_points_naive(p1, p2, lengths1, lengths2, K: int, norm: int = 2) -> torch.Tensor:
    """Naive PyTorch implementation of K-Nearest Neighbors.

    Results are always sorted by ascending distance. The ``norm``
    parameter was added (default 2, the original behavior) so this
    reference implementation matches the ``knn_points(..., norm=...)``
    API exercised by the test helper.

    Args:
        p1: (N, P1, D) query points.
        p2: (N, P2, D) reference points.
        lengths1: (N,) valid lengths of p1 per batch, or None for P1.
        lengths2: (N,) valid lengths of p2 per batch, or None for P2.
        K: number of neighbors to return per query point.
        norm: 1 for sum of absolute differences, 2 for squared
            Euclidean distance (pytorch3d convention).

    Returns:
        _KNN namedtuple with dists (N, P1, K), idx (N, P1, K), knn=None.
        Slots beyond the number of valid reference points stay zero.
    """
    N, P1, D = p1.shape
    _N, P2, _D = p2.shape
    assert N == _N and D == _D
    assert norm in (1, 2), "norm must be 1 or 2"
    if lengths1 is None:
        lengths1 = torch.full((N,), P1, dtype=torch.int64, device=p1.device)
    if lengths2 is None:
        lengths2 = torch.full((N,), P2, dtype=torch.int64, device=p1.device)

    dists = torch.zeros((N, P1, K), dtype=torch.float32, device=p1.device)
    idx = torch.zeros((N, P1, K), dtype=torch.int64, device=p1.device)

    for n in range(N):
        num1 = lengths1[n].item()
        num2 = lengths2[n].item()
        pp1 = p1[n, :num1].view(num1, 1, D)
        pp2 = p2[n, :num2].view(1, num2, D)
        diff = pp1 - pp2
        if norm == 2:
            diff = (diff * diff).sum(2)
        else:
            diff = diff.abs().sum(2)
        num2 = min(num2, K)
        for i in range(num1):
            dd = diff[i]
            srt_dd, srt_idx = dd.sort()
            dists[n, i, :num2] = srt_dd[:num2]
            idx[n, i, :num2] = srt_idx[:num2]
    return _KNN(dists=dists, idx=idx, knn=None)
def _knn_vs_python_square_helper(self, device, return_sorted):
    """Check `knn_points` against the naive reference implementation.

    Sweeps the cartesian product of batch sizes, dims, point counts, K,
    norms and kernel versions; compares forward distances/indices and
    input gradients.

    Args:
        device: torch device the tensors are created on.
        return_sorted: forwarded to knn_points. When False and K > 1 the
            kernel output is expected to be unsorted; it is sorted here
            before comparing against the (always sorted) naive result.
    """
    Ns = [1, 4]
    Ds = [3, 5, 8]
    P1s = [8, 24]
    P2s = [8, 16, 32]
    Ks = [1, 3, 10]
    norms = [1, 2]
    versions = [0, 1, 2, 3]
    factors = [Ns, Ds, P1s, P2s, Ks, norms]
    for N, D, P1, P2, K, norm in product(*factors):
        for version in versions:
            # Kernel version 3 only supports small K.
            if version == 3 and K > 4:
                continue
            x = torch.randn(N, P1, D, device=device, requires_grad=True)
            # Independent leaf copy so gradients can be compared per-impl.
            x_cuda = x.clone().detach()
            x_cuda.requires_grad_(True)
            y = torch.randn(N, P2, D, device=device, requires_grad=True)
            y_cuda = y.clone().detach()
            y_cuda.requires_grad_(True)
            # forward
            out1 = self._knn_points_naive(
                x, y, lengths1=None, lengths2=None, K=K, norm=norm
            )
            out2 = knn_points(
                x_cuda, y_cuda, K=K, norm=norm, version=version,
                return_sorted=return_sorted,
            )
            if K > 1 and not return_sorted:
                # check out2 is not sorted
                self.assertFalse(torch.allclose(out1[0], out2[0]))
                self.assertFalse(torch.allclose(out1[1], out2[1]))
                # now sort out2
                dists, idx, _ = out2
                if P2 < K:
                    # Padding slots (beyond the P2 valid neighbors) hold 0;
                    # push them to the end with +inf before sorting, then
                    # restore them to 0 to match the naive output.
                    dists[..., P2:] = float("inf")
                    dists, sort_idx = dists.sort(dim=2)
                    dists[..., P2:] = 0
                else:
                    dists, sort_idx = dists.sort(dim=2)
                # Reorder indices consistently with the sorted distances.
                idx = idx.gather(2, sort_idx)
                out2 = _KNN(dists, idx, None)
            self.assertClose(out1[0], out2[0])
            self.assertTrue(torch.all(out1[1] == out2[1]))
            # backward: same upstream gradient through both implementations.
            grad_dist = torch.ones((N, P1, K), dtype=torch.float32, device=device)
            loss1 = (out1.dists * grad_dist).sum()
            loss1.backward()
            loss2 = (out2.dists * grad_dist).sum()
            loss2.backward()
            self.assertClose(x_cuda.grad, x.grad, atol=5e-6)
            self.assertClose(y_cuda.grad, y.grad, atol=5e-6)
def resample_uniformly(
    pointclouds: Union[Pointclouds, torch.Tensor],
    neighborhood_size: int = 8,
    knn=None,
    normals=None,
    shrink_ratio: float = 0.5,
    repulsion_mu: float = 1.0
) -> Union[Pointclouds, Tuple[torch.Tensor, torch.Tensor]]:
    """Resample a point cloud to a more uniform distribution.

    First uses WLOP to consolidate the cloud to a smaller one
    (``shrink_ratio`` of the points), then upsamples back to the original
    per-cloud point counts.

    Args:
        pointclouds: Pointclouds object or (N, P, 3) padded tensor.
        neighborhood_size: kept for interface compatibility (unused).
        knn: kept for interface compatibility (unused).
        normals: kept for interface compatibility (unused).
        shrink_ratio: downsampling ratio passed to wlop, in (0, 1].
        repulsion_mu: repulsion weight passed to wlop.

    Returns:
        Pointclouds if the input was a Pointclouds object, otherwise a
        tuple of (padded points, number of points per batch).
    """
    # Only the per-cloud point counts are needed from the input tensor.
    points_init, num_points = convert_pointclouds_to_tensor(pointclouds)
    # NOTE(review): the original body also ran an frnn neighbor search and
    # estimated normals here, but neither result was used by wlop/upsample;
    # that dead (and expensive) work has been removed. The knn/normals
    # parameters are kept so existing callers continue to work.
    wlop_result = wlop(pointclouds, ratio=shrink_ratio, repulsion_mu=repulsion_mu)
    up_result = upsample(wlop_result, num_points)
    if is_pointclouds(pointclouds):
        return up_result
    return up_result.points_padded(), up_result.num_points_per_cloud()
def _ball_query_naive(p1, p2, lengths1, lengths2, K: int, radius: float) -> torch.Tensor:
    """Naive PyTorch implementation of ball query.

    For every valid point of ``p1``, collect (in index order) up to K
    points of ``p2`` whose squared distance is strictly below
    ``radius**2``. Unfilled slots keep dists == 0 and idx == -1.
    """
    N, P1, D = p1.shape
    _N, P2, _D = p2.shape
    assert N == _N and D == _D

    device = p1.device
    if lengths1 is None:
        lengths1 = torch.full((N,), P1, dtype=torch.int64, device=device)
    if lengths2 is None:
        lengths2 = torch.full((N,), P2, dtype=torch.int64, device=device)

    max_dist2 = radius * radius
    dists = torch.zeros((N, P1, K), dtype=torch.float32, device=device)
    idx = torch.full((N, P1, K), fill_value=-1, dtype=torch.int64, device=device)

    # Walk every (batch, query, reference) triple explicitly.
    for n in range(N):
        n1 = int(lengths1[n])
        n2 = int(lengths2[n])
        for i in range(n1):
            filled = 0
            for j in range(n2):
                offset = p2[n, j] - p1[n, i]
                d2 = (offset * offset).sum()
                # First-come-first-served up to K hits inside the ball.
                if d2 < max_dist2 and filled < K:
                    dists[n, i, filled] = d2
                    idx[n, i, filled] = j
                    filled += 1
    return _KNN(dists=dists, idx=idx, knn=None)
def upsample(
    pcl: Union[Pointclouds, torch.Tensor],
    n_points: Union[int, torch.Tensor],
    num_points=None,
    neighborhood_size=16,
    knn_result=None
) -> Union[Pointclouds, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Iteratively add points to the sparsest region

    Candidate insertions are points located 1/3 of the way from a point to
    each of its K nearest neighbors; per round, the candidates farthest
    from all neighbors (the sparsest) are inserted, at most P // 8 per
    batch, until every cloud reaches its target count.

    Args:
        points (tensor of [N, P, 3] or Pointclouds)
        n_points (tensor of [N] or integer): target number of points per cloud
        num_points: unused here; per-cloud counts come from the input -- NOTE(review).
        neighborhood_size: K used for the KNN queries.
        knn_result: optional precomputed _KNN (self-match already removed).

    Returns:
        Pointclouds or (padded_points, num_points)
    """
    def _return_value(points, num_points, return_pcl):
        # Return the same container type the caller passed in.
        if return_pcl:
            points_list = padded_to_list(points, num_points.tolist())
            return pcl.__class__(points_list)
        else:
            return points, num_points

    return_pcl = is_pointclouds(pcl)
    points, num_points = convert_pointclouds_to_tensor(pcl)
    knn_k = neighborhood_size
    if not ((num_points - num_points[0]) == 0).all():
        logger_py.warn(
            "Upsampling operation may encounter unexpected behavior for heterogeneous batches"
        )
    if num_points.sum() == 0:
        return _return_value(points, num_points, return_pcl)

    n_remaining = (n_points - num_points).to(dtype=torch.long)
    if (n_remaining <= 0).all():
        return _return_value(points, num_points, return_pcl)

    if knn_result is None:
        knn_result = knn_points(points,
                                points,
                                num_points,
                                num_points,
                                K=knn_k + 1,
                                return_nn=True,
                                return_sorted=True)
        # Drop the first neighbor (the query point itself).
        knn_result = _KNN(dists=knn_result.dists[..., 1:],
                          idx=knn_result.idx[..., 1:],
                          knn=knn_result.knn[..., 1:, :])
    while True:
        if (n_remaining == 0).all():
            break
        # half of the points per batch
        sparse_pts = points
        sparse_dists = knn_result.dists
        sparse_knn = knn_result.knn
        batch_size, P, _ = sparse_pts.shape
        # Insert at most P // 8 new points per round.
        max_P = (P // 8)
        # sparse_knn_normals = frnn.frnn_gather(
        #     normals_init, knn_result.idx, num_points)[:, 1:]
        # get all mid points
        mid_points = (sparse_knn + 2 * sparse_pts[..., None, :]) / 3
        # N,P,K,K,3
        mid_nn_diff = mid_points.unsqueeze(-2) - sparse_knn.unsqueeze(-3)
        # minimize among all the neighbors
        min_dist2 = torch.norm(mid_nn_diff, dim=-1)  # N,P,K,K
        min_dist2 = min_dist2.min(dim=-1)[0]  # N,P,K
        # Per point: the candidate farthest from all neighbors (sparsest)
        # and the neighbor index that generated it.
        father_sparsity, father_nb = min_dist2.max(dim=-1)  # N,P
        # neighborhood to insert
        sparsity_sorted = father_sparsity.sort(dim=1).indices
        n_new_points = n_remaining.clone()
        n_new_points[n_new_points > max_P] = max_P
        # Keep the max_P sparsest locations (sort is ascending).
        sparsity_sorted = sparsity_sorted[:, -max_P:]
        # Select each point's chosen candidate, then gather the sparsest ones.
        new_pts = torch.gather(
            mid_points[torch.arange(mid_points.shape[0]).view(-1, 1, 1),
                       torch.arange(mid_points.shape[1]).view(1, -1, 1),
                       father_nb.unsqueeze(-1)].squeeze(-2), 1,
            sparsity_sorted.unsqueeze(-1).expand(-1, -1, 3))
        sparse_selected = torch.gather(
            sparse_pts, 1, sparsity_sorted.unsqueeze(-1).expand(-1, -1, 3))
        total_pts_list = []
        for b, pts_batch in enumerate(
                padded_to_list(points, num_points.tolist())):
            # Take only the last n_new_points[b] (sparsest) candidates.
            # NOTE(review): when n_new_points[b] == 0 this slice takes ALL
            # candidates -- only reachable with heterogeneous targets; confirm.
            total_pts_list.append(
                torch.cat([new_pts[b][-n_new_points[b]:], pts_batch], dim=0))
        points = list_to_padded(total_pts_list)
        n_remaining = n_remaining - n_new_points
        num_points = n_new_points + num_points
        # Refresh the KNN structure for the enlarged clouds.
        knn_result = knn_points(points,
                                points,
                                num_points,
                                num_points,
                                K=knn_k + 1,
                                return_nn=True)
        knn_result = _KNN(dists=knn_result.dists[..., 1:],
                          idx=knn_result.idx[..., 1:],
                          knn=knn_result.knn[..., 1:, :])
    return _return_value(points, num_points, return_pcl)
def wlop(pointclouds: PointClouds3D, ratio: float = 0.5, neighborhood_size=16, iters=3, repulsion_mu=0.5) -> PointClouds3D:
    """
    Consolidation of Unorganized Point Clouds for Surface Reconstruction

    Weighted locally optimal projection: a subsampled set X is iteratively
    projected onto the input cloud P by a data term (attraction to P,
    weighted by local density of P) plus a repulsion term among the X
    points themselves.

    Args:
        pointclouds containing max J points per cloud
        ratio: downsampling ratio (0, 1]
        neighborhood_size: K for the fixed-radius neighbor searches.
        iters: number of LOP iterations.
        repulsion_mu: weight of the repulsion term.
    """
    P, num_points_P = convert_pointclouds_to_tensor(pointclouds)
    # (N, 3, 2)
    bbox = pointclouds.get_bounding_boxes()
    # (N,)
    diag = torch.norm(bbox[..., 0] - bbox[..., 1], dim=-1)
    # Support radius h derived from average spacing.
    h = 4 * torch.sqrt(diag / num_points_P.float())
    # NOTE(review): min(tensor, float) compares a (N,) tensor with a scalar;
    # this presumably assumes batch size 1 -- confirm against callers.
    search_radius = min(h * neighborhood_size, 0.2)
    theta_sigma_inv = 16 / h / h
    if ratio < 1.0:
        X0 = farthest_sampling(pointclouds, ratio=ratio)
    elif ratio == 1.0:
        X0 = pointclouds.clone()
    else:
        raise ValueError('ratio must be less or equal to 1.0')
    # slightly perturb so that we don't find the same point when searching NN XtoP
    offset = torch.randn_like(X0.points_packed()) * h * 0.1
    X0.offset_(offset)
    X, num_points_X = convert_pointclouds_to_tensor(X0)

    def theta(r2):
        # Gaussian weight on squared distances.
        return torch.exp(-r2 * theta_sigma_inv)

    def eta(r):
        # Repulsion potential (only its derivative is used below).
        return -r

    def deta(r):
        # Derivative of eta: constant 1 (sign folded into the update).
        return torch.ones_like(r)

    grid = None
    # Neighbors of P within P (K+1 to account for the self-match).
    dists, idxs, _, grid = frnn.frnn_grid_points(P, P, num_points_P, num_points_P,
                                                 K=neighborhood_size + 1,
                                                 r=search_radius,
                                                 grid=grid, return_nn=False)
    # Drop the first neighbor (the query point itself).
    knn_PtoP = _KNN(dists=dists[..., 1:], idx=idxs[..., 1:], knn=None)

    deltapp = torch.norm(P.unsqueeze(-2) -
                         frnn.frnn_gather(P, knn_PtoP.idx, num_points_P),
                         dim=-1)
    theta_pp_nn = theta(deltapp**2)  # (B, P, K)
    # Zero out padded neighbors (frnn uses idx < 0 for misses).
    theta_pp_nn[knn_PtoP.idx < 0] = 0
    # Local density of P (the +1 counts the point itself).
    density_P = torch.sum(theta_pp_nn, dim=-1) + 1

    for it in range(iters):
        # from each x find closest neighbors in pointclouds
        dists, idxs, _, grid = frnn.frnn_grid_points(X, P, num_points_X, num_points_P,
                                                     K=neighborhood_size,
                                                     r=search_radius,
                                                     grid=grid, return_nn=False)
        knn_XtoP = _KNN(dists=dists, idx=idxs, knn=None)

        # Neighbors of X within X (self-match dropped again).
        dists, idxs, _, _ = frnn.frnn_grid_points(X, X, num_points_X, num_points_X,
                                                  K=neighborhood_size + 1,
                                                  r=search_radius,
                                                  grid=None, return_nn=False)
        knn_XtoX = _KNN(dists=dists[..., 1:], idx=idxs[..., 1:],
                        knn=None)

        # LOP local optimal projection
        nn_XtoP = frnn.frnn_gather(P, knn_XtoP.idx, num_points_P)
        epsilon = X.unsqueeze(-2) - frnn.frnn_gather(P, knn_XtoP.idx, num_points_P)
        delta = X.unsqueeze(-2) - frnn.frnn_gather(X, knn_XtoX.idx, num_points_X)

        # (B, I, I)
        deltaxx2 = (delta**2).sum(dim=-1)
        # (B, I, K)
        deltaxp2 = (epsilon**2).sum(dim=-1)

        # (B, I, K) data-term weights: attraction of x towards points of P.
        alpha = theta(deltaxp2) / eps_denom(epsilon.norm(dim=-1))
        # (B, I, K) repulsion-term weights among the X points.
        beta = theta(deltaxx2) * deta(delta.norm(dim=-1)) / eps_denom(
            delta.norm(dim=-1))

        density_X = torch.sum(theta(deltaxx2), dim=-1) + 1

        # Density-normalize the attraction so dense regions of P don't
        # dominate; zero out padded neighbor slots.
        new_alpha = alpha / frnn.frnn_gather(
            density_P.unsqueeze(-1), knn_XtoP.idx, num_points_P).squeeze(-1)
        new_alpha[knn_XtoP.idx < 0] = 0
        new_beta = density_X.unsqueeze(-1) * beta
        new_beta[knn_XtoX.idx < 0] = 0

        term_data = torch.sum(new_alpha[..., None] * nn_XtoP, dim=-2) / \
            eps_denom(torch.sum(new_alpha, dim=-1, keepdim=True))
        term_repul = repulsion_mu * torch.sum(new_beta[..., None] * delta, dim=-2) / \
            eps_denom(torch.sum(new_beta, dim=-1, keepdim=True))

        X = term_data + term_repul

    if is_pointclouds(X0):
        return X0.update_padded(X)
    return X
def project_to_latent_surface(points, normals, sharpness_angle=60, neighborhood_size=31, max_proj_iters=10, max_est_iter=5):
    """
    RIMLS

    Robust implicit moving least squares projection: each point is moved
    along the estimated implicit-function gradient (move = f * grad_f)
    until the step norm falls below a threshold or max_proj_iters is hit.
    The inner loop (max_est_iter) re-estimates per-point robust weights.

    Args:
        points: (N, P, 3) padded points or a Pointclouds object.
        normals: (N, P, 3) per-point normals (re-normalized internally).
        sharpness_angle: degrees; converted to a sigma but
            NOTE(review): sharpness_sigma is computed and never used below.
        neighborhood_size: K for the neighbor search.
        max_proj_iters: outer projection iterations.
        max_est_iter: inner robust-weight estimation iterations.

    Returns:
        (N, P, 3) tensor of projected points.
    """
    points, num_points = convert_pointclouds_to_tensor(points)
    normals = F.normalize(normals, dim=-1)
    sharpness_sigma = 1 - math.cos(sharpness_angle / 180 * math.pi)
    # Bounding-box diagonal; math.sqrt on a tensor presumably assumes
    # batch size 1 -- TODO confirm.
    diag = (points.max(dim=-2)[0] - points.min(dim=-2)[0]).norm(dim=-1)
    avg_spacing = math.sqrt(diag / points.shape[1])
    search_radius = min(16 * avg_spacing * neighborhood_size, 0.2)
    dists, idxs, _, grid = frnn.frnn_grid_points(points, points, num_points, num_points,
                                                 K=neighborhood_size + 1,
                                                 r=search_radius,
                                                 grid=None, return_nn=False)
    # Drop the first neighbor (the query point itself).
    knn_result = _KNN(dists=dists[..., 1:], idx=idxs[..., 1:], knn=None)
    # knn_normals = knn_gather(normals, knn_result.idx, num_points)
    knn_normals = frnn.frnn_gather(normals, knn_result.idx, num_points)
    # Spatial bandwidth from the distance to the closest neighbor.
    inv_sigma_spatial = 1 / knn_result.dists[..., 0] / 16
    # spatial_dist = 16 / inv_sigma_spatial
    not_converged = torch.full(points.shape[:-1], True,
                               device=points.device, dtype=torch.bool)
    itt = 0
    it = 0
    while True:
        # Work only on the points that have not converged yet.
        knn_pts = frnn.frnn_gather(points, knn_result.idx, num_points)
        pts_diff = points[not_converged].unsqueeze(-2) - knn_pts[not_converged]
        # Signed distance of each point to its neighbors' tangent planes.
        fx = torch.sum(pts_diff * knn_normals[not_converged], dim=-1)
        not_converged_1 = torch.full(fx.shape[:-1], True,
                                     dtype=torch.bool, device=fx.device)
        knn_normals_1 = knn_normals[not_converged]
        inv_sigma_spatial_1 = inv_sigma_spatial[not_converged]
        f = points.new_zeros(points[not_converged].shape[:-1],
                             device=points.device)
        grad_f = points.new_zeros(points[not_converged].shape,
                                  device=points.device)
        alpha = torch.ones_like(fx)
        # Inner loop: iteratively re-weight (robust MLS) until alpha stops
        # changing or max_est_iter is reached.
        for itt in range(max_est_iter):
            if itt > 0:
                alpha_old = alpha
                # Down-weight neighbors whose normals disagree with grad_f.
                weights_n = ((knn_normals_1[not_converged_1] -
                              grad_f[not_converged_1].unsqueeze(-2)).norm(dim=-1)
                             / 0.5)**2
                weights_n = torch.exp(-weights_n)
                # Down-weight neighbors whose plane distance disagrees with f.
                weights_p = torch.exp(-(
                    (fx[not_converged_1] - f[not_converged_1].unsqueeze(-1))**2 *
                    inv_sigma_spatial_1[not_converged_1].unsqueeze(-1) / 4))
                alpha[not_converged_1] = weights_n * weights_p
                # NOTE(review): this marks entries with SMALL alpha change as
                # still-not-converged (`< 1e-4`), which looks inverted; also
                # the original flat text leaves the exact nesting of this
                # block ambiguous -- confirm against the upstream source.
                not_converged_1[not_converged_1] = (
                    alpha[not_converged_1] -
                    alpha_old[not_converged_1]).abs().max(dim=-1)[0] < 1e-4
                if not not_converged_1.any():
                    break
            deltap = torch.sum(pts_diff[not_converged_1] *
                               pts_diff[not_converged_1], dim=-1)
            phi = torch.exp(-deltap *
                            inv_sigma_spatial_1[not_converged_1].unsqueeze(-1))
            # phi[deltap > spatial_dist] = 0
            dphi = inv_sigma_spatial_1[not_converged_1].unsqueeze(-1) * phi
            weights = phi * alpha[not_converged_1]
            # NOTE(review): uses the full pts_diff while the other terms are
            # masked by not_converged_1; shapes only agree while the mask is
            # all-True -- confirm intended behavior.
            grad_weights = 2 * pts_diff * (dphi * weights).unsqueeze(-1)
            sum_grad_weights = torch.sum(grad_weights, dim=-2)
            sum_weight = torch.sum(weights, dim=-1)
            sum_f = torch.sum(fx[not_converged_1] * weights, dim=-1)
            sum_Gf = torch.sum(grad_weights * fx[not_converged_1].unsqueeze(-1),
                               dim=-2)
            sum_N = torch.sum(weights.unsqueeze(-1) *
                              knn_normals_1[not_converged_1], dim=-2)
            # Weighted MLS estimate of the implicit value and its gradient.
            tmp_f = sum_f / eps_denom(sum_weight)
            tmp_grad_f = (sum_Gf - tmp_f.unsqueeze(-1) * sum_grad_weights +
                          sum_N) / eps_denom(sum_weight).unsqueeze(-1)
            grad_f[not_converged_1] = tmp_grad_f
            f[not_converged_1] = tmp_f
        # Project along the gradient by the implicit value.
        move = f.unsqueeze(-1) * grad_f
        points[not_converged] = points[not_converged] - move
        # Keep iterating points that moved more than the threshold.
        mask = move.norm(dim=-1) > 5e-4
        not_converged[not_converged] = mask
        it = it + 1
        if not not_converged.any() or it >= max_proj_iters:
            break
    return points
def upsample(points, n_points: Union[int, torch.Tensor], num_points=None,
             neighborhood_size=16, knn_result=None):
    """Iteratively insert points into the sparsest regions of a batch.

    Candidate insertions are points 1/3 of the way from each point to each
    of its K nearest neighbors; per round, the sparsest candidates are
    inserted (at most P // 10 per batch) until each cloud reaches its
    target count.

    Args:
        points: (N, P, 3) padded points.
        n_points: target number of points per cloud ((N,) tensor or int).
        num_points: (N,) valid points per cloud; defaults to P for all.
        neighborhood_size: K used for the KNN queries.
        knn_result: optional precomputed _KNN (self-match already removed).

    Returns:
        (points, num_points): updated padded points and per-cloud counts.
    """
    batch_size = points.shape[0]
    knn_k = neighborhood_size
    if num_points is None:
        num_points = torch.tensor([points.shape[1]] * points.shape[0],
                                  device=points.device, dtype=torch.long)
    if not ((num_points - num_points[0]) == 0).all():
        logger_py.warn(
            "May encounter unexpected behavior for heterogeneous batches")
    if num_points.sum() == 0:
        return points, num_points

    n_remaining = n_points - num_points
    # BUGFIX: was `(n_remaining == 0).all()`; a cloud that already exceeds
    # its target made the loop below spin forever. Matches the sibling
    # Pointclouds implementation, which uses <= 0.
    if (n_remaining <= 0).all():
        return points, num_points

    point_cloud_diag = (points.max(dim=-2)[0] - points.min(dim=-2)[0]).norm(dim=-1)
    inv_sigma_spatial = num_points / point_cloud_diag
    spatial_dist = 16 / inv_sigma_spatial

    if knn_result is None:
        knn_result = knn_points(points, points, num_points, num_points,
                                K=knn_k + 1, return_nn=True, return_sorted=True)
        # Drop the first neighbor (the query point itself).
        knn_result = _KNN(dists=knn_result.dists[..., 1:],
                          idx=knn_result.idx[..., 1:],
                          knn=knn_result.knn[..., 1:, :])
    while True:
        if (n_remaining <= 0).all():
            break
        sparse_pts = points
        sparse_dists = knn_result.dists
        sparse_knn = knn_result.knn
        batch_size, P, _ = sparse_pts.shape
        # Insert at most P // 10 points per round.
        max_P = (P // 10)
        # Candidates: points 1/3 of the way from each point to each neighbor.
        mid_points = (sparse_knn + 2 * sparse_pts[..., None, :]) / 3
        # (N, P, K, K, 3): every candidate vs every neighbor.
        mid_nn_diff = mid_points.unsqueeze(-2) - sparse_knn.unsqueeze(-3)
        # minimize among all the neighbors
        min_dist2 = torch.norm(mid_nn_diff, dim=-1)  # N,P,K,K
        min_dist2 = min_dist2.min(dim=-1)[0]  # N,P,K
        # Per point: the sparsest candidate and the neighbor producing it.
        father_sparsity, father_nb = min_dist2.max(dim=-1)  # N,P
        # neighborhood to insert
        sparsity_sorted = father_sparsity.sort(dim=1).indices
        n_new_points = n_remaining.clone()
        n_new_points[n_new_points > max_P] = max_P
        # Keep the max_P sparsest locations (ascending sort).
        sparsity_sorted = sparsity_sorted[:, -max_P:]
        # BUGFIX: the original indexed mid_points with 1-D aranges of
        # different lengths (N and P), which broadcasts incorrectly (and
        # errors) for batched input. Use explicit broadcasting shapes as in
        # the sibling Pointclouds implementation.
        selected_mid = mid_points[
            torch.arange(mid_points.shape[0]).view(-1, 1, 1),
            torch.arange(mid_points.shape[1]).view(1, -1, 1),
            father_nb.unsqueeze(-1)].squeeze(-2)  # (N, P, 3)
        new_pts = torch.gather(
            selected_mid, 1, sparsity_sorted.unsqueeze(-1).expand(-1, -1, 3))
        sparse_selected = torch.gather(
            sparse_pts, 1, sparsity_sorted.unsqueeze(-1).expand(-1, -1, 3))
        total_pts_list = []
        for b, pts_batch in enumerate(
                padded_to_list(points, num_points.tolist())):
            # Take the last n_new_points[b] (sparsest) candidates.
            # NOTE(review): n_new_points[b] <= 0 slices unexpectedly; only
            # reachable with heterogeneous targets, which are warned about.
            total_pts_list.append(
                torch.cat([new_pts[b][-n_new_points[b]:], pts_batch], dim=0))
        points = list_to_padded(total_pts_list)
        n_remaining = n_remaining - n_new_points
        num_points = n_new_points + num_points
        # Refresh KNN for the enlarged clouds.
        knn_result = knn_points(points, points, num_points, num_points,
                                K=knn_k + 1, return_nn=True)
        knn_result = _KNN(dists=knn_result.dists[..., 1:],
                          idx=knn_result.idx[..., 1:],
                          knn=knn_result.knn[..., 1:, :])
    return points, num_points
def resample_uniformly(pointclouds, neighborhood_size=8, iters=1, knn=None, normals=None, reproject=False, repulsion_mu=1.0):
    """
    resample sample_iters times

    Repulsion-based uniform resampling: each iteration pushes every point
    away from its neighbors along the tangent plane (normal component of
    the displacement removed), optionally re-projecting onto the latent
    surface first.

    Args:
        pointclouds: Pointclouds object or (N, P, 3) padded tensor.
        neighborhood_size: K for the neighbor search.
        iters: number of repulsion iterations.
        knn: optional precomputed _KNN (self-match already removed).
        normals: optional (N, P, 3) normals when a tensor is passed.
        reproject: when True, denoise normals and project points to the
            latent surface before each repulsion step.
        repulsion_mu: step-size multiplier for the repulsion move.

    Returns:
        Pointclouds if the input was a Pointclouds object, else a tensor.
    """
    import math
    import frnn
    points_init, num_points = convert_pointclouds_to_tensor(pointclouds)
    batch_size = num_points.shape[0]
    # knn_result = knn_points(
    #     points_init, points_init, num_points, num_points, K=neighborhood_size + 1, return_nn=True)
    diag = (points_init.view(-1, 3).max(dim=0).values -
            points_init.view(-1, 3).min(0).values).norm().item()
    avg_spacing = math.sqrt(diag / points_init.shape[1])
    search_radius = min(4 * avg_spacing * neighborhood_size, 0.2)
    if knn is None:
        dists, idxs, _, grid = frnn.frnn_grid_points(points_init, points_init,
                                                     num_points, num_points,
                                                     K=neighborhood_size + 1,
                                                     r=search_radius,
                                                     grid=None, return_nn=False)
        # Drop the first neighbor (the query point itself).
        knn = _KNN(dists=dists[..., 1:], idx=idxs[..., 1:], knn=None)

    # estimate normals
    if isinstance(pointclouds, torch.Tensor):
        normals = normals
    else:
        normals = pointclouds.normals_padded()
    if normals is None:
        normals = estimate_pointcloud_normals(
            points_init, neighborhood_size=neighborhood_size,
            disambiguate_directions=False)
    else:
        normals = F.normalize(normals, dim=-1)

    points = points_init
    for i in range(iters):
        if reproject:
            normals = denoise_normals(points, normals, num_points, knn_result=knn)
            points = project_to_latent_surface(points, normals,
                                               max_proj_iters=2, max_est_iter=3)
        # Periodically rebuild the neighbor structure.
        # NOTE(review): this searches points_init, not the moved `points`,
        # so the refreshed knn ignores all movement so far -- looks like a
        # bug; confirm against the upstream source before changing.
        if i > 0 and i % 3 == 0:
            dists, idxs, _, grid = frnn.frnn_grid_points(points_init, points_init,
                                                         num_points, num_points,
                                                         K=neighborhood_size + 1,
                                                         r=search_radius,
                                                         grid=None,
                                                         return_nn=False)
            knn = _KNN(dists=dists[..., 1:], idx=idxs[..., 1:], knn=None)
        # Recompute exact neighbor offsets/distances at current positions.
        nn = frnn.frnn_gather(points, knn.idx, num_points)
        pts_diff = points.unsqueeze(-2) - nn
        dists = torch.sum(pts_diff**2, dim=-1)
        knn_result = _KNN(dists=dists, idx=knn.idx, knn=nn)
        deltap = knn_result.dists
        # NOTE(review): num_points is (N,); broadcasting against (N, P, K)
        # only works for batch size 1 -- confirm intended usage.
        inv_sigma_spatial = num_points / 2.0 / 16
        spatial_w = torch.exp(-deltap * inv_sigma_spatial)
        # Zero out padded neighbor slots (frnn uses idx < 0 for misses).
        spatial_w[knn_result.idx < 0] = 0
        # density_w = torch.sum(spatial_w, dim=-1) + 1.0
        # 0.5 * derivative of (-r)exp(-r^2*inv)
        # Neighbor-local density (+1 counts the point itself).
        density = frnn.frnn_gather(
            spatial_w.sum(-1, keepdim=True) + 1.0, knn.idx, num_points)
        nn_normals = frnn.frnn_gather(normals, knn_result.idx, num_points)
        # Remove the normal component: repulsion acts in the tangent plane.
        pts_diff_proj = pts_diff - (pts_diff * nn_normals).sum(
            dim=-1, keepdim=True) * nn_normals
        # move = 0.5 * torch.sum(density*spatial_w[..., None] * pts_diff_proj, dim=-2) / torch.sum(density.view_as(spatial_w)*spatial_w, dim=-1).unsqueeze(-1)
        # move = F.normalize(move, dim=-1) * move.norm(dim=-1, keepdim=True).clamp_max(2*avg_spacing)
        move = repulsion_mu * avg_spacing * torch.mean(
            density * spatial_w[..., None] * F.normalize(pts_diff_proj, dim=-1),
            dim=-2)
        points = points + move

    # then project to latent surface again
    if is_pointclouds(pointclouds):
        return pointclouds.update_padded(points)
    return points