def getChamferDist(x, y):
    '''
    Compute the per-point nearest-neighbour terms of the chamfer distance
    in both directions, without any reduction.

    Every cloud in the batch is treated as fully populated (no padding).

    :param x: point batch indexed as (N, P1, D)
    :param y: point batch indexed as (N, P2, D)
    :return: (cham_x, cham_y) where cham_x is (N, P1) and holds, for each
        point of x, the K=1 value reported by knn_points against y, and
        cham_y is (N, P2) with the roles swapped. Presumably these are
        squared distances (pytorch3d convention) -- confirm against the
        knn_points in scope.
    '''

    def _full_lengths(pts):
        # One entry per batch element, each equal to the padded point count.
        return torch.full((pts.shape[0], ),
                          pts.shape[1],
                          dtype=torch.int64,
                          device=pts.device)

    len_x = _full_lengths(x)
    len_y = _full_lengths(y)

    x_to_y = knn_points(x, y, lengths1=len_x, lengths2=len_y, K=1)
    y_to_x = knn_points(y, x, lengths1=len_y, lengths2=len_x, K=1)

    # Drop the trailing K axis (K == 1).
    return x_to_y.dists[..., 0], y_to_x.dists[..., 0]
def compute_link_loss(self, scene_output, supervision):
    """
    Penalise the minimum object-to-hand distance for hands flagged as linked.

    Args:
        scene_output: dict read at ``["body_info"]["verts"]`` and
            ``["obj_infos"][i]["verts"]``.
        supervision: dict read at ``["mano_corresp"]`` (with ``"left_hand"``
            and ``"right_hand"`` vertex indices into the body vertices) and
            ``["links"]`` (column 0 = left-hand flag, column 1 = right-hand
            flag, one row per batch element).

    Returns:
        Tuple ``(loss, loss_info)``. ``loss`` is the batch mean of
        ``left_flag * left_min + right_flag * right_min``. ``loss_info`` is
        ``{"link_min_dists": {...}}`` holding detached CPU copies of the
        per-sample minima and flags.
    """
    corresp = supervision["mano_corresp"]
    links = supervision["links"]

    body_verts = scene_output["body_info"]["verts"]
    left_hand_verts = body_verts[:, corresp["left_hand"]]
    right_hand_verts = body_verts[:, corresp["right_hand"]]

    # NOTE(review): concatenating along axis 2 stacks the coordinate axis if
    # each "verts" entry is (B, V, 3); with more than one object axis 1 looks
    # intended. Left unchanged -- confirm the layout of obj_infos["verts"].
    obj_verts = torch.cat(
        [info["verts"] for info in scene_output["obj_infos"]], 2)

    # Smallest K=1 distance over all object vertices, per batch element.
    left2obj_mins = knn_points(obj_verts, left_hand_verts,
                               K=1)[0].min(1)[0][:, 0]
    right2obj_mins = knn_points(obj_verts, right_hand_verts,
                                K=1)[0].min(1)[0][:, 0]

    # Build the flag tensor once (same dtype/device as obj_verts) and slice it.
    link_flags = obj_verts.new(links)
    left_flags = link_flags[:, 0]
    right_flags = link_flags[:, 1]

    batch_min_dists = (left_flags * left2obj_mins +
                       right_flags * right2obj_mins)
    loss = batch_min_dists.mean()

    # All logged values are moved to the CPU, including right_flags, which
    # previously stayed on the compute device unlike its three siblings.
    min_dists = {
        "left": left2obj_mins.detach().cpu(),
        "right": right2obj_mins.detach().cpu(),
        "left_flags": left_flags.detach().cpu(),
        "right_flags": right_flags.detach().cpu(),
    }
    loss_info = {"link_min_dists": min_dists}
    return loss, loss_info
def _distances_to_rgba(dists, lo, hi):
    """Map a 1-D tensor of distances to an (n, 4) RGBA array scaled to 0-255."""
    blue_hue = 0.667  # hue at or below ``lo``; hue 0 (red) is reached at ``hi``
    # Cap at ``hi`` first and floor at ``lo`` second, in that order.
    clipped = dists.clamp(max=hi).clamp(min=lo)
    ratio = (clipped - lo) / (hi - lo)
    # Plain Python floats keep colorsys/numpy away from 0-d tensors.
    hues = (blue_hue - ratio * blue_hue).tolist()
    rgb = np.stack([colorsys.hsv_to_rgb(hue, 1, 0.8) for hue in hues])
    alpha = np.ones([rgb.shape[0], 1])
    return np.hstack([rgb, alpha]) * 255.


def hausdorff_distance_vismesh(ma: trimesh.Trimesh,
                               mb: trimesh.Trimesh,
                               min=0,
                               max=10 * 1e-3):
    '''
    Bidirectional Hausdorff distance between the vertices of two meshes,
    plus a per-vertex colouring of both meshes by nearest-vertex distance.

    ma: trimesh; its ``visual.vertex_colors`` is overwritten in place
    mb: trimesh; its ``visual.vertex_colors`` is overwritten in place
    min: for visualization, distances at or below this value are drawn blue
    max: for visualization, distances at or above this value are drawn red
    return: (hd, ma, mb, cham_x, cham_y)
        hd: 0-d tensor, the larger of the two directed maxima
        ma, mb: the input meshes, now coloured
        cham_x: numpy array, one value per vertex of ma (distance to mb)
        cham_y: numpy array, one value per vertex of mb (distance to ma)
    '''
    va = torch.from_numpy(ma.vertices).float().unsqueeze(0)
    vb = torch.from_numpy(mb.vertices).float().unsqueeze(0)

    # Index the batch and K axes explicitly: a bare squeeze() would also drop
    # the vertex axis of a single-vertex mesh and break the colouring below.
    # The square root turns the knn_points output into a distance (assumes
    # knn_points reports squared distances, as pytorch3d does).
    cham_x = knn_points(va, vb, K=1).dists[0, :, 0]**0.5  # (P1,)
    cham_y = knn_points(vb, va, K=1).dists[0, :, 0]**0.5  # (P2,)

    hd = torch.max(cham_x.max(), cham_y.max())

    ma.visual.vertex_colors = _distances_to_rgba(cham_x, min, max)
    mb.visual.vertex_colors = _distances_to_rgba(cham_y, min, max)
    return hd, ma, mb, cham_x.numpy(), cham_y.numpy()
def test_invalid_norm(self):
    """knn_points must raise ValueError for a norm other than 1 or 2."""
    device = get_random_cuda_device()
    N, P1, P2, K, D = 4, 16, 12, 8, 3
    x = torch.rand((N, P1, D), device=device)
    y = torch.rand((N, P2, D), device=device)

    # One value above and one below the supported set.
    for bad_norm in (3, 0):
        with self.assertRaisesRegex(ValueError, "Support for 1 or 2 norm."):
            knn_points(x, y, K=K, norm=bad_norm)
def discrete_project(self, pc: torch.Tensor, thres=0.9, cpu=False):
    """
    For each face mid-point (or each row of ``self``), pick the closest point
    of ``pc`` that lies roughly along the mid-point's normal direction.

    Args:
        pc: point cloud; assumed to be (1, P, 3) or (1, P, 6) with normals in
            the last three columns -- the indexing ``pc[:, :, :3]`` / ``pc[0]``
            suggests a leading batch axis of 1, TODO confirm.
        thres: minimum absolute cosine between the mid-point-to-point
            direction and the mid-point normal for a point to be eligible.
        cpu: when True the dense pairwise stage runs on the CPU instead of
            ``self.device``.

    Returns:
        Tuple ``(pc_per_face, valid)``: ``pc_per_face`` has 6 columns and one
        row per mid-point, on ``self.device``; rows without a match are NaN.
        ``valid`` is a boolean mask of the rows whose first column is not
        NaN, on the device selected by ``cpu``.
    """
    with torch.no_grad():
        device = torch.device('cpu') if cpu else self.device
        pc = pc.double()
        if isinstance(self, Mesh):
            # Mesh: project from face centroids using the face normals.
            mid_points = self.vs[self.faces].mean(dim=1)
            normals = self.normals
        else:
            # Otherwise self is indexed as rows of [xyz | normal].
            mid_points = self[:, :3]
            normals = self[:, 3:]
        # 3-NN in both directions: mid-point -> cloud and cloud -> mid-point.
        pk12 = knn_points(mid_points[:, :3].unsqueeze(0), pc[:, :, :3], K=3).idx[0]
        pk21 = knn_points(pc[:, :, :3], mid_points[:, :3].unsqueeze(0), K=3).idx[0]
        # For every mid-point, the mid-point neighbours of its cloud neighbours.
        loop = pk21[pk12].view(pk12.shape[0], -1)
        # True where a mid-point finds itself in that set (mutual neighbours).
        knn_mask = (loop == torch.arange(
            0, pk12.shape[0], device=self.device)[:, None]).sum(dim=1) > 0
        mid_points = mid_points.to(device)
        pc = pc[0].to(device)
        # Only the mid-points that are NOT mutual neighbours go through the
        # dense pairwise search below.
        normals = normals.to(device)[~knn_mask, :]
        masked_mid_points = mid_points[~knn_mask, :]
        displacement = masked_mid_points[:, None, :] - pc[:, :3]
        torch.cuda.empty_cache()
        distance = displacement.norm(dim=-1)
        # Eligible when |cos(angle between displacement and normal)| > thres.
        mask = (torch.abs(
            torch.sum((displacement / distance[:, :, None]) * normals[:, None, :],
                      dim=-1)) > thres)
        if pc.shape[-1] == 6:
            # With cloud normals available, also require them to agree in
            # sign with the mid-point normal.
            pc_normals = pc[:, 3:]
            normals_correlation = torch.sum(normals[:, None, :] * pc_normals, dim=-1)
            mask = mask * (normals_correlation > 0)
        torch.cuda.empty_cache()
        # Push ineligible pairs to +inf so the row-wise min ignores them.
        distance[~mask] += float('inf')
        # NOTE(review): `min` shadows the builtin for the rest of this scope.
        min, argmin = distance.min(dim=-1)
        pc_per_face_masked = pc[argmin, :].clone()
        # Rows where no pair was eligible are marked NaN.
        pc_per_face_masked[min == float('inf'), :] = float('nan')
        pc_per_face = torch.zeros(mid_points.shape[0], 6).\
            type(pc_per_face_masked.dtype).to(pc_per_face_masked.device)
        pc_per_face[~knn_mask, :pc.shape[-1]] = pc_per_face_masked
        # Mutual-neighbour mid-points receive no projection target here.
        pc_per_face[knn_mask, :] = float('nan')
        # clean up
        del knn_mask
    # x == x is False only for NaN, so the second value flags the matched rows.
    return pc_per_face.to(
        self.device), (pc_per_face[:, 0] == pc_per_face[:, 0]).to(device)
def _knn_vs_python_ragged_helper(self, device):
    """
    Compare knn_points with the naive reference on ragged batches.

    For every combination of batch size, feature dim, point counts and K,
    random per-cloud lengths are drawn and both the forward outputs and the
    gradients with respect to the two inputs are checked.
    """
    batch_sizes = [1, 4]
    feature_dims = [3, 5, 8]
    p1_sizes = [8, 24]
    p2_sizes = [8, 16, 32]
    k_values = [1, 3, 10]

    for N, D, P1, P2, K in product(batch_sizes, feature_dims, p1_sizes,
                                   p2_sizes, k_values):
        x = torch.rand((N, P1, D), device=device, requires_grad=True)
        y = torch.rand((N, P2, D), device=device, requires_grad=True)
        lengths1 = torch.randint(low=1, high=P1, size=(N,), device=device)
        lengths2 = torch.randint(low=1, high=P2, size=(N,), device=device)

        # Independent leaf copies so each implementation gets its own grads.
        x_csrc = x.clone().detach().requires_grad_(True)
        y_csrc = y.clone().detach().requires_grad_(True)

        # forward
        expected = self._knn_points_naive(
            x, y, lengths1=lengths1, lengths2=lengths2, K=K
        )
        actual = knn_points(
            x_csrc, y_csrc, lengths1=lengths1, lengths2=lengths2, K=K
        )
        self.assertClose(expected[0], actual[0])
        self.assertTrue(torch.all(expected[1] == actual[1]))

        # backward
        grad_dist = torch.ones((N, P1, K), dtype=torch.float32, device=device)
        (expected.dists * grad_dist).sum().backward()
        (actual.dists * grad_dist).sum().backward()

        self.assertClose(x_csrc.grad, x.grad, atol=5e-6)
        self.assertClose(y_csrc.grad, y.grad, atol=5e-6)
def _knn_vs_python_square_helper(self, device):
    """
    Compare knn_points with the naive reference on full (non-ragged) batches.

    Sweeps batch size, feature dim, point counts, K and the ``version``
    argument of knn_points; version 3 is skipped whenever K > 4. Forward
    outputs and input gradients are both checked.
    """
    batch_sizes = [1, 4]
    feature_dims = [3, 5, 8]
    p1_sizes = [8, 24]
    p2_sizes = [8, 16, 32]
    k_values = [1, 3, 10]
    versions = [0, 1, 2, 3]

    # `versions` is the last factor, so it varies fastest, like an inner loop.
    for N, D, P1, P2, K, version in product(batch_sizes, feature_dims,
                                            p1_sizes, p2_sizes, k_values,
                                            versions):
        if version == 3 and K > 4:
            continue

        x = torch.randn(N, P1, D, device=device, requires_grad=True)
        x_cuda = x.clone().detach().requires_grad_(True)
        y = torch.randn(N, P2, D, device=device, requires_grad=True)
        y_cuda = y.clone().detach().requires_grad_(True)

        # forward
        expected = self._knn_points_naive(
            x, y, lengths1=None, lengths2=None, K=K)
        actual = knn_points(x_cuda, y_cuda, K=K, version=version)
        self.assertClose(expected[0], actual[0])
        self.assertTrue(torch.all(expected[1] == actual[1]))

        # backward
        grad_dist = torch.ones((N, P1, K), dtype=torch.float32, device=device)
        (expected.dists * grad_dist).sum().backward()
        (actual.dists * grad_dist).sum().backward()

        self.assertClose(x_cuda.grad, x.grad, atol=5e-6)
        self.assertClose(y_cuda.grad, y.grad, atol=5e-6)
def _knn_vs_python_square_helper(self, device, return_sorted):
    """
    Compare knn_points with the naive reference on full batches, sweeping
    the norm and the ``version`` argument.

    When ``return_sorted`` is False and K > 1 the raw output is first
    asserted to differ from the sorted reference, then sorted here before
    the usual forward and gradient comparisons. Version 3 is skipped
    whenever K > 4.
    """
    batch_sizes = [1, 4]
    feature_dims = [3, 5, 8]
    p1_sizes = [8, 24]
    p2_sizes = [8, 16, 32]
    k_values = [1, 3, 10]
    norms = [1, 2]
    versions = [0, 1, 2, 3]

    # `versions` is the last factor, so it varies fastest, like an inner loop.
    for N, D, P1, P2, K, norm, version in product(batch_sizes, feature_dims,
                                                  p1_sizes, p2_sizes,
                                                  k_values, norms, versions):
        if version == 3 and K > 4:
            continue

        x = torch.randn(N, P1, D, device=device, requires_grad=True)
        x_cuda = x.clone().detach().requires_grad_(True)
        y = torch.randn(N, P2, D, device=device, requires_grad=True)
        y_cuda = y.clone().detach().requires_grad_(True)

        # forward
        out1 = self._knn_points_naive(
            x, y, lengths1=None, lengths2=None, K=K, norm=norm
        )
        out2 = knn_points(
            x_cuda,
            y_cuda,
            K=K,
            norm=norm,
            version=version,
            return_sorted=return_sorted,
        )

        if K > 1 and not return_sorted:
            # check out2 is not sorted
            self.assertFalse(torch.allclose(out1[0], out2[0]))
            self.assertFalse(torch.allclose(out1[1], out2[1]))

            # now sort out2
            dists, idx, _ = out2
            needs_padding = P2 < K
            if needs_padding:
                # Slots past P2 are moved to the end by the sort ...
                dists[..., P2:] = float("inf")
            dists, sort_idx = dists.sort(dim=2)
            if needs_padding:
                # ... and reset to 0 before comparing with the reference.
                dists[..., P2:] = 0
            idx = idx.gather(2, sort_idx)
            out2 = _KNN(dists, idx, None)

        self.assertClose(out1[0], out2[0])
        self.assertTrue(torch.all(out1[1] == out2[1]))

        # backward
        grad_dist = torch.ones((N, P1, K), dtype=torch.float32, device=device)
        (out1.dists * grad_dist).sum().backward()
        (out2.dists * grad_dist).sum().backward()

        self.assertClose(x_cuda.grad, x.grad, atol=5e-6)
        self.assertClose(y_cuda.grad, y.grad, atol=5e-6)
def output():
    """Run one KNN forward + backward pass and wait for the GPU to finish."""
    knn_out = knn_points(
        pts1, pts2, lengths1=lengths1, lengths2=lengths2, K=K)
    weighted_sum = (knn_out.dists * grad_dists).sum()
    weighted_sum.backward()
    # Block until all queued CUDA work has completed.
    torch.cuda.synchronize()
def feeature_pooling(x, x_lengths, y, y_lengths, neighbors):
    """
    Sum the SPED scores of ``x`` against ``y`` over three scales and
    normalise by the number of points and the batch size.

    Args:
        x: tensor unpacked as (N, P1, D).
        x_lengths: per-cloud point counts for ``x``, forwarded to
            knn_points / knn_gather.
        y: tensor of shape (N, P2, D).
        y_lengths: per-cloud point counts for ``y``.
        neighbors: K, the number of neighbours searched around each point
            of ``x`` (both within ``x`` and within ``y``).

    Returns:
        ``(SPED@10 + SPED@5 + SPED@1).sum() / P1 / N``.

    Raises:
        ValueError: if either cloud has fewer than 15 points, or if ``y``
            does not share the batch size and feature dimension of ``x``.
    """
    N, P1, D = x.shape
    P2 = y.shape[1]
    T2 = 1e-20  # forwarded to SPED; presumably a numerical floor -- confirm

    if P1 < 15 or P2 < 15:
        raise ValueError(
            "x or y does not have the enough points (at lest 15 points).")
    # Validate before the neighbour search so a mismatch is reported with
    # this message rather than by whatever knn_points raises first.
    if y.shape[0] != N or y.shape[2] != D:
        raise ValueError("y does not have the correct shape.")

    # Both searches are anchored at the points of x.
    x_nn = knn_points(x, x, lengths1=x_lengths, lengths2=x_lengths,
                      K=neighbors)
    y_nn = knn_points(x, y, lengths1=x_lengths, lengths2=y_lengths,
                      K=neighbors)

    x_coor_near = knn_gather(x, x_nn.idx, x_lengths)  # batch,points,nei,3
    # The lengths passed to knn_gather describe the tensor being gathered
    # FROM, so y needs y_lengths (x_lengths was passed here before).
    y_coor_near = knn_gather(y, y_nn.idx, y_lengths)

    # three scales
    scale_scores = [
        SPED(x, x_nn, y_nn, T2, N, P1, scale, x_coor_near, y_coor_near,
             D).sum() for scale in (10, 5, 1)
    ]
    MPED_SCORE = scale_scores[0] + scale_scores[1] + scale_scores[2]
    MPED_SCORE = MPED_SCORE / P1
    MPED_SCORE = MPED_SCORE / N
    return MPED_SCORE
def knn(self):
    """
    Run knn_points on the stored clouds and return ``(dists, idxs, nn)``.

    Reads ``self.pc1_knn``, ``self.pc2_knn``, ``self.lengths1_cuda``,
    ``self.lengths2_cuda`` and ``self.K``; results are requested sorted and
    with the neighbour coordinates included.
    """
    neighbours = knn_points(self.pc1_knn,
                            self.pc2_knn,
                            self.lengths1_cuda,
                            self.lengths2_cuda,
                            K=self.K,
                            version=-1,
                            return_nn=True,
                            return_sorted=True)
    dists, idxs, nn = neighbours
    # No radius filtering is applied: for the backward pass all K neighbours
    # are assumed to lie within the radius, so entries with
    # dists > self.r ** 2 are deliberately left untouched.
    return dists, idxs, nn
def test_knn_gather(self):
    """
    knn_gather must return the indexed points of ``y`` for neighbour slots
    below the cloud length and zeros for the remaining slots.
    """
    device = get_random_cuda_device()
    N, P1, P2, K, D = 4, 16, 12, 8, 3
    x = torch.rand((N, P1, D), device=device)
    y = torch.rand((N, P2, D), device=device)
    lengths1 = torch.randint(low=1, high=P1, size=(N,), device=device)
    lengths2 = torch.randint(low=1, high=P2, size=(N,), device=device)

    out = knn_points(x, y, lengths1=lengths1, lengths2=lengths2, K=K)
    y_nn = knn_gather(y, out.idx, lengths2)

    for n in range(N):
        n_valid = lengths2[n]
        for p1 in range(P1):
            for k in range(K):
                gathered = y_nn[n, p1, k]
                if k >= n_valid:
                    # Slots beyond the cloud length are zero-filled.
                    self.assertTrue(torch.all(gathered == 0.0))
                    continue
                self.assertClose(gathered, y[n, out.idx[n, p1, k]])
def get_NN(src_xyz, trg_xyz, k=1):
    '''
    Nearest neighbour in ``trg_xyz`` for every point of ``src_xyz``.

    Both clouds are treated as fully populated (no padding).

    :param src_xyz: [B, N1, 3]
    :param trg_xyz: [B, N2, 3]
    :param k: number of neighbours searched; only the first one is returned
    :return: nn_dists, nn_idx: both [B, N1]; the value reported by
        knn_points for the first neighbour and its index into N2
    '''
    src_lengths = torch.full((src_xyz.shape[0], ),
                             src_xyz.shape[1],
                             dtype=torch.int64,
                             device=src_xyz.device)  # [B], N1 for each cloud
    trg_lengths = torch.full((trg_xyz.shape[0], ),
                             trg_xyz.shape[1],
                             dtype=torch.int64,
                             device=trg_xyz.device)  # [B], N2 for each cloud
    src_nn = knn_points(src_xyz,
                        trg_xyz,
                        lengths1=src_lengths,
                        lengths2=trg_lengths,
                        K=k)  # [dists, idx]
    # Keep only the first of the k neighbours.
    nn_dists = src_nn.dists[..., 0]
    nn_idx = src_nn.idx[..., 0]
    return nn_dists, nn_idx
def chamfer_distance(
    x,
    y,
    x_lengths=None,
    y_lengths=None,
    x_normals=None,
    y_normals=None,
    weights=None,
    batch_reduction: Union[str, None] = "mean",
    point_reduction: str = "mean",
):
    """
    Chamfer distance between two pointclouds x and y.

    Args:
        x: FloatTensor of shape (N, P1, D) or a Pointclouds object representing
            a batch of point clouds with at most P1 points in each batch element,
            batch size N and feature dimension D.
        y: FloatTensor of shape (N, P2, D) or a Pointclouds object representing
            a batch of point clouds with at most P2 points in each batch element,
            batch size N and feature dimension D.
        x_lengths: Optional LongTensor of shape (N,) giving the number of points in each
            cloud in x.
        y_lengths: Optional LongTensor of shape (N,) giving the number of points in each
            cloud in y.
        x_normals: Optional FloatTensor of shape (N, P1, D).
        y_normals: Optional FloatTensor of shape (N, P2, D).
        weights: Optional FloatTensor of shape (N,) giving weights for
            batch elements for reduction operation.
        batch_reduction: Reduction operation to apply for the loss across the
            batch, can be one of ["mean", "sum"] or None.
        point_reduction: Reduction operation to apply for the loss across the
            points, can be one of ["mean", "sum"].

    Returns:
        2-element tuple containing

        - **loss**: Tensor giving the reduced distance between the pointclouds
          in x and the pointclouds in y.
        - **loss_normals**: Tensor giving the reduced cosine distance of normals
          between pointclouds in x and pointclouds in y. Returns None if
          x_normals and y_normals are None.

        If all weights are zero, both elements are zeros that stay attached to
        x's autograd graph: a scalar when batch_reduction is "mean" or "sum",
        otherwise a tensor of shape (N,).

    Raises:
        ValueError: if y does not match the batch size / feature dimension of
            x, or if weights has the wrong length or a negative entry.
    """
    _validate_chamfer_reduction_inputs(batch_reduction, point_reduction)

    x, x_lengths, x_normals = _handle_pointcloud_input(x, x_lengths, x_normals)
    y, y_lengths, y_normals = _handle_pointcloud_input(y, y_lengths, y_normals)

    return_normals = x_normals is not None and y_normals is not None

    N, P1, D = x.shape
    P2 = y.shape[1]

    if y.shape[0] != N or y.shape[2] != D:
        raise ValueError("y does not have the correct shape.")

    # Check if inputs are heterogeneous and create a lengths mask.
    # The masks are True on padded slots.
    is_x_heterogeneous = (x_lengths != P1).any()
    is_y_heterogeneous = (y_lengths != P2).any()
    x_mask = (torch.arange(P1, device=x.device)[None] >= x_lengths[:, None]
              )  # shape [N, P1]
    y_mask = (torch.arange(P2, device=y.device)[None] >= y_lengths[:, None]
              )  # shape [N, P2]

    if weights is not None:
        if weights.size(0) != N:
            raise ValueError("weights must be of shape (N,).")
        if not (weights >= 0).all():
            raise ValueError("weights cannot be negative.")
        if weights.sum() == 0.0:
            # Multiply by 0.0 instead of creating fresh zeros so the result
            # still depends on x. Weights are kept as (N,) so the unreduced
            # result is (N,) rather than an (N, N) broadcast.
            zero = (x.sum((1, 2)) * weights.view(N)) * 0.0
            if batch_reduction in ["mean", "sum"]:
                zero = zero.sum()
            return zero, zero.clone()

    cham_norm_x = x.new_zeros(())
    cham_norm_y = x.new_zeros(())

    x_nn = knn_points(x, y, lengths1=x_lengths, lengths2=y_lengths, K=1)
    y_nn = knn_points(y, x, lengths1=y_lengths, lengths2=x_lengths, K=1)

    cham_x = x_nn.dists[..., 0]  # (N, P1)
    cham_y = y_nn.dists[..., 0]  # (N, P2)

    # Padded slots must not contribute to the sums below.
    if is_x_heterogeneous:
        cham_x[x_mask] = 0.0
    if is_y_heterogeneous:
        cham_y[y_mask] = 0.0

    if weights is not None:
        cham_x *= weights.view(N, 1)
        cham_y *= weights.view(N, 1)

    if return_normals:
        # Gather the normals using the indices and keep only value for k=0
        x_normals_near = knn_gather(y_normals, x_nn.idx, y_lengths)[..., 0, :]
        y_normals_near = knn_gather(x_normals, y_nn.idx, x_lengths)[..., 0, :]

        cham_norm_x = 1 - torch.abs(
            F.cosine_similarity(x_normals, x_normals_near, dim=2, eps=1e-6))
        cham_norm_y = 1 - torch.abs(
            F.cosine_similarity(y_normals, y_normals_near, dim=2, eps=1e-6))

        if is_x_heterogeneous:
            cham_norm_x[x_mask] = 0.0
        if is_y_heterogeneous:
            cham_norm_y[y_mask] = 0.0

        if weights is not None:
            cham_norm_x *= weights.view(N, 1)
            cham_norm_y *= weights.view(N, 1)

    # Apply point reduction
    cham_x = cham_x.sum(1)  # (N,)
    cham_y = cham_y.sum(1)  # (N,)
    if return_normals:
        cham_norm_x = cham_norm_x.sum(1)  # (N,)
        cham_norm_y = cham_norm_y.sum(1)  # (N,)
    if point_reduction == "mean":
        # An empty cloud has a zero sum; dividing by max(length, 1) keeps its
        # mean at 0 instead of producing NaN from 0 / 0.
        x_lengths_clamped = x_lengths.clamp(min=1)
        y_lengths_clamped = y_lengths.clamp(min=1)
        cham_x /= x_lengths_clamped
        cham_y /= y_lengths_clamped
        if return_normals:
            cham_norm_x /= x_lengths_clamped
            cham_norm_y /= y_lengths_clamped

    if batch_reduction is not None:
        # batch_reduction == "sum"
        cham_x = cham_x.sum()
        cham_y = cham_y.sum()
        if return_normals:
            cham_norm_x = cham_norm_x.sum()
            cham_norm_y = cham_norm_y.sum()
        if batch_reduction == "mean":
            # Guard the empty batch against a division by zero.
            div = weights.sum() if weights is not None else max(N, 1)
            cham_x /= div
            cham_y /= div
            if return_normals:
                cham_norm_x /= div
                cham_norm_y /= div

    cham_dist = cham_x + cham_y
    cham_normals = cham_norm_x + cham_norm_y if return_normals else None

    return cham_dist, cham_normals
def upsample_ear(points,
                 normals,
                 n_points: Union[int, torch.Tensor],
                 num_points=None,
                 neighborhood_size=16,
                 repulsion_mu=0.4,
                 edge_sensitivity=1.0):
    """
    Project the points with a normal-weighted attraction and a repulsion
    term, then insert new points in the sparsest neighbourhoods until every
    cloud holds ``n_points`` points.

    Args:
        points (N, P, 3)
        normals (N, P, 3): per-point normals (passed through
            ``denoise_normals`` before use)
        n_points (tensor of [N] or integer): target number of points per cloud
        num_points (tensor of [N], optional): valid points per cloud;
            defaults to P for every cloud
        neighborhood_size: number of neighbours used (excluding the point
            itself)
        repulsion_mu: scale of the repulsion displacement
        edge_sensitivity: currently unused

    Returns:
        (padded_points, num_points) after projection and insertion. When no
        cloud needs more points the projected input is returned.
    """
    knn_k = neighborhood_size
    if num_points is None:
        num_points = torch.tensor([points.shape[1]] * points.shape[0],
                                  device=points.device,
                                  dtype=torch.long)
    if not ((num_points - num_points[0]) == 0).all():
        logger_py.warn(
            "May encounter unexpected behavior for heterogeneous batches")
    if num_points.sum() == 0:
        return points, num_points

    point_cloud_diag = (points.max(dim=-2)[0] -
                        points.min(dim=-2)[0]).norm(dim=-1)  # (N,)
    # Per-cloud bandwidth, shaped (N, 1, 1) so it broadcasts over the
    # (N, P, K) neighbour tensors instead of being matched against K.
    inv_sigma_spatial = (num_points / point_cloud_diag).view(-1, 1, 1)
    spatial_dist = 16 / inv_sigma_spatial

    knn_result = knn_points(points,
                            points,
                            num_points,
                            num_points,
                            K=knn_k + 1,
                            return_nn=True,
                            return_sorted=True)
    # Column 0 of the sorted result is dropped everywhere below: with a
    # cloud searched against itself it is presumably the query point.
    _knn_dists = knn_result.dists[..., 1:]
    _knn_nn = knn_result.knn[..., 1:, :]
    move_clip = knn_result.dists[..., 1].mean().sqrt()

    # 2. LOP projection
    # NOTE(review): this tests the truthiness of the name `denoise_normals`
    # itself, so it is always taken when that name is a function -- confirm
    # whether an actual flag was intended.
    if denoise_normals:
        normals, _, _ = denoise_normals(points,
                                        normals,
                                        num_points,
                                        knn_result=knn_result)

    # (optional) search knn in the original points
    offsets = points[:, :, None, :] - _knn_nn  # N,P,K,3 (p - p_i)

    # e(-(<n, p-pi>)^2/sigma_p)
    weight_lop = torch.exp(
        -torch.sum(normals[:, :, None, :] * offsets, dim=-1)**2 *
        inv_sigma_spatial)
    weight_lop[_knn_dists > spatial_dist] = 0

    # spatial weight
    spatial_w = torch.exp(-_knn_dists * inv_sigma_spatial)
    spatial_w[_knn_dists > spatial_dist] = 0
    density_w = torch.sum(spatial_w, dim=-1) + 1.0

    move_data = torch.sum(weight_lop[..., None] * offsets, dim=-2) / \
        eps_denom(torch.sum(weight_lop, dim=-1, keepdim=True))
    move_repul = repulsion_mu * density_w[..., None] * torch.sum(
        spatial_w[..., None] * (_knn_nn - points[:, :, None, :]), dim=-2) / \
        eps_denom(torch.sum(spatial_w, dim=-1, keepdim=True))

    # Clip each displacement's length to move_clip. The direction must be
    # normalised over the coordinate axis; F.normalize defaults to dim=1,
    # which is the point axis here.
    move_repul = F.normalize(move_repul, dim=-1) * move_repul.norm(
        dim=-1, keepdim=True).clamp_max(move_clip)
    move_data = F.normalize(move_data, dim=-1) * move_data.norm(
        dim=-1, keepdim=True).clamp_max(move_clip)
    points = points - (move_data + move_repul)

    # Clouds that already have enough points need zero additions.
    n_remaining = (n_points - num_points).to(dtype=torch.long).clamp(min=0)

    # NOTE(review): the first pass reuses the neighbours found before the
    # projection above (as the original code did).
    while (n_remaining > 0).any():
        P = points.shape[1]
        # At most a tenth of the points per pass, but always at least one so
        # that small clouds still make progress.
        max_P = max(P // 10, 1)

        # Candidate positions: one third of the way towards each neighbour.
        mid_points = (_knn_nn + 2 * points[..., None, :]) / 3  # N,P,K,3
        # N,P,K,K,3
        mid_nn_diff = mid_points.unsqueeze(-2) - _knn_nn.unsqueeze(-3)
        # Distance of each candidate to its closest neighbour ...
        min_dist2 = torch.norm(mid_nn_diff, dim=-1).min(dim=-1)[0]  # N,P,K
        # ... and, per point, the candidate that is farthest from all of them.
        father_sparsity, father_nb = min_dist2.max(dim=-1)  # N,P

        # Sparsest neighbourhoods are at the end of the ascending sort.
        sparsity_sorted = father_sparsity.sort(dim=1).indices[:, -max_P:]
        n_new_points = n_remaining.clamp(max=max_P)

        # Select each point's best candidate (N,P,3), then the sparsest ones.
        n_coords = mid_points.shape[-1]
        father_idx = father_nb[..., None, None].expand(-1, -1, 1, n_coords)
        best_mid = torch.gather(mid_points, 2, father_idx).squeeze(2)
        new_pts = torch.gather(
            best_mid, 1,
            sparsity_sorted.unsqueeze(-1).expand(-1, -1, n_coords))

        total_pts_list = []
        per_cloud = zip(padded_to_list(points, num_points.tolist()),
                        n_new_points.tolist())
        for b, (pts_batch, n_new) in enumerate(per_cloud):
            # Explicit start index: a `[-n_new:]` slice would select every
            # candidate when n_new == 0.
            picked = new_pts[b][new_pts.shape[1] - n_new:]
            total_pts_list.append(torch.cat([picked, pts_batch], dim=0))

        # Carry the enlarged clouds into the next pass and the return value.
        points = list_to_padded(total_pts_list)
        n_remaining = n_remaining - n_new_points
        num_points = n_new_points + num_points

        knn_result = knn_points(points,
                                points,
                                num_points,
                                num_points,
                                K=knn_k + 1,
                                return_nn=True)
        _knn_dists = knn_result.dists[..., 1:]
        _knn_nn = knn_result.knn[..., 1:, :]

    return points, num_points
def upsample(
    pcl: Union[Pointclouds, torch.Tensor],
    n_points: Union[int, torch.Tensor],
    num_points=None,
    neighborhood_size=16,
    knn_result=None
) -> Union[Pointclouds, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Iteratively add points to the sparsest region

    Args:
        pcl (tensor of [N, P, 3] or Pointclouds)
        n_points (tensor of [N] or integer): target number of points per cloud
        num_points: NOTE(review) -- currently ignored; the per-cloud counts
            are always taken from ``convert_pointclouds_to_tensor(pcl)``.
            Confirm whether callers expect it to be honoured.
        neighborhood_size: number of neighbours used (excluding the point
            itself)
        knn_result: optional precomputed result of knn_points(points, points,
            K=neighborhood_size + 1, return_nn=True) for the input points
    Returns:
        Pointclouds or (padded_points, num_points). Clouds that already hold
        at least ``n_points`` points are left unchanged.
    """

    def _return_value(points, num_points, return_pcl):
        # Mirror the input type: Pointclouds in, Pointclouds out.
        if return_pcl:
            points_list = padded_to_list(points, num_points.tolist())
            return pcl.__class__(points_list)
        else:
            return points, num_points

    return_pcl = is_pointclouds(pcl)
    points, num_points = convert_pointclouds_to_tensor(pcl)

    knn_k = neighborhood_size

    if not ((num_points - num_points[0]) == 0).all():
        logger_py.warn(
            "Upsampling operation may encounter unexpected behavior for heterogeneous batches"
        )

    if num_points.sum() == 0:
        return _return_value(points, num_points, return_pcl)

    n_remaining = (n_points - num_points).to(dtype=torch.long)
    if (n_remaining <= 0).all():
        return _return_value(points, num_points, return_pcl)
    # In a mixed batch, clouds that are already large enough need zero
    # additions; a negative entry would corrupt the slicing and counts below.
    n_remaining = n_remaining.clamp(min=0)

    if knn_result is None:
        knn_result = knn_points(points,
                                points,
                                num_points,
                                num_points,
                                K=knn_k + 1,
                                return_nn=True,
                                return_sorted=True)

    # Drop column 0 of the sorted result (presumably the query point itself).
    knn_result = _KNN(dists=knn_result.dists[..., 1:],
                      idx=knn_result.idx[..., 1:],
                      knn=knn_result.knn[..., 1:, :])

    while (n_remaining > 0).any():
        sparse_knn = knn_result.knn
        P = points.shape[1]
        # At most an eighth of the points per pass, but always at least one
        # so that small clouds still make progress.
        max_P = max(P // 8, 1)

        # Candidate positions: one third of the way towards each neighbour.
        mid_points = (sparse_knn + 2 * points[..., None, :]) / 3  # N,P,K,3
        # N,P,K,K,3
        mid_nn_diff = mid_points.unsqueeze(-2) - sparse_knn.unsqueeze(-3)
        # minimize among all the neighbors
        min_dist2 = torch.norm(mid_nn_diff, dim=-1).min(dim=-1)[0]  # N,P,K
        father_sparsity, father_nb = min_dist2.max(dim=-1)  # N,P

        # neighborhood to insert; sparsest at the end of the ascending sort
        sparsity_sorted = father_sparsity.sort(dim=1).indices[:, -max_P:]
        n_new_points = n_remaining.clamp(max=max_P)

        # Select each point's best candidate (N,P,3), then the sparsest ones.
        n_coords = mid_points.shape[-1]
        father_idx = father_nb[..., None, None].expand(-1, -1, 1, n_coords)
        best_mid = torch.gather(mid_points, 2, father_idx).squeeze(2)
        new_pts = torch.gather(
            best_mid, 1,
            sparsity_sorted.unsqueeze(-1).expand(-1, -1, n_coords))

        total_pts_list = []
        per_cloud = zip(padded_to_list(points, num_points.tolist()),
                        n_new_points.tolist())
        for b, (pts_batch, n_new) in enumerate(per_cloud):
            # Explicit start index: a `[-n_new:]` slice would select every
            # candidate when n_new == 0.
            picked = new_pts[b][new_pts.shape[1] - n_new:]
            total_pts_list.append(torch.cat([picked, pts_batch], dim=0))

        points = list_to_padded(total_pts_list)
        n_remaining = n_remaining - n_new_points
        num_points = n_new_points + num_points

        knn_result = knn_points(points,
                                points,
                                num_points,
                                num_points,
                                K=knn_k + 1,
                                return_nn=True)
        knn_result = _KNN(dists=knn_result.dists[..., 1:],
                          idx=knn_result.idx[..., 1:],
                          knn=knn_result.knn[..., 1:, :])

    return _return_value(points, num_points, return_pcl)