Ejemplo n.º 1
0
            def getChamferDist(x, y):
                '''
                Compute the per-point nearest-neighbour distances used by the
                chamfer distance, in both directions.

                :param x: batched points, indexed as (batch, points, dim)
                :param y: batched points, indexed as (batch, points, dim)
                :return: (cham_x, cham_y) where cham_x[n, i] is the distance
                    reported by knn_points from x[n, i] to its nearest point in
                    y[n], and cham_y[n, j] is the same from y[n, j] into x[n].
                '''
                def _full_lengths(pts):
                    # Every cloud in the batch is treated as fully populated.
                    return torch.full((pts.shape[0], ),
                                      pts.shape[1],
                                      dtype=torch.int64,
                                      device=pts.device)

                len_x = _full_lengths(x)
                len_y = _full_lengths(y)

                forward = knn_points(x, y, lengths1=len_x, lengths2=len_y, K=1)
                backward = knn_points(y, x, lengths1=len_y, lengths2=len_x, K=1)

                # K=1, so only the single nearest-neighbour slot exists.
                return forward.dists[..., 0], backward.dists[..., 0]
Ejemplo n.º 2
0
    def compute_link_loss(self, scene_output, supervision):
        """Penalise the gap between the object(s) and the hands linked to them.

        For every batch element, the minimum knn distance between any object
        vertex and any left / right hand vertex is computed. ``links`` selects
        which hand(s) count for each element, and the selected minima are
        averaged over the batch.

        Args:
            scene_output: dict with ``["body_info"]["verts"]`` (body vertices,
                indexed as ``[:, vertex_ids]``) and ``["obj_infos"]``, a list
                of dicts each holding object ``"verts"``.
            supervision: dict with ``"mano_corresp"`` (``"left_hand"`` /
                ``"right_hand"`` vertex ids into the body vertices) and
                ``"links"``, presumably one ``[left_flag, right_flag]`` row per
                batch element (TODO confirm against callers).

        Returns:
            (loss, loss_info): the scalar loss, and a dict whose
            ``"link_min_dists"`` entry holds detached CPU tensors.
        """
        corresp = supervision["mano_corresp"]
        links = supervision["links"]
        body_verts = scene_output["body_info"]["verts"]
        left_hand_verts = body_verts[:, corresp["left_hand"]]
        right_hand_verts = body_verts[:, corresp["right_hand"]]
        # Stack all objects' vertices along the point axis (dim 1), matching
        # how hand vertices are selected (``[:, ids]``). Concatenating on the
        # coordinate axis (dim 2) changes the point dimension as soon as there
        # is more than one object, and knn_points then cannot pair the clouds.
        obj_verts = torch.cat(
            [info["verts"] for info in scene_output["obj_infos"]], 1)

        # knn_points(...)[0] holds, for each object vertex, the distance to the
        # nearest hand vertex (K=1). min over dim 1 keeps the closest object
        # vertex, and [:, 0] drops the K axis -> one value per batch element.
        left2obj_mins = knn_points(obj_verts, left_hand_verts,
                                   K=1)[0].min(1)[0][:, 0]
        right2obj_mins = knn_points(obj_verts, right_hand_verts,
                                    K=1)[0].min(1)[0][:, 0]
        # Same dtype/device as the vertices so the flags multiply cleanly.
        link_flags = obj_verts.new(links)
        left_flags = link_flags[:, 0]
        right_flags = link_flags[:, 1]

        batch_min_dists = (left_flags * left2obj_mins +
                           right_flags * right2obj_mins)
        loss = batch_min_dists.mean()
        min_dists = {
            "left": left2obj_mins.detach().cpu(),
            "right": right2obj_mins.detach().cpu(),
            "left_flags": left_flags.detach().cpu(),
            "right_flags": right_flags.detach().cpu(),
        }
        loss_info = {"link_min_dists": min_dists}
        return loss, loss_info
Ejemplo n.º 3
0
def _distance_to_vertex_colors(dists, vmin, vmax):
    """Map per-vertex distances to RGBA vertex colours (blue = near, red = far).

    Distances are clipped to ``[vmin, vmax]`` (the lower bound is applied
    last) and mapped linearly onto HSV hue 0.667 (blue, at ``vmin``) down to
    0 (red, at ``vmax``), with saturation 1 and value 0.8.

    Args:
        dists: 1-D numpy array of per-vertex distances.
        vmin: distance mapped to blue.
        vmax: distance mapped to red; must differ from ``vmin``.

    Returns:
        (V, 4) float array of RGBA values scaled to [0, 255].
    """
    blue_hue = 0.667
    clipped = np.maximum(np.minimum(dists, vmax), vmin)
    ratio = (clipped - vmin) / (vmax - vmin)
    hues = blue_hue - ratio * blue_hue
    rgb = np.array([colorsys.hsv_to_rgb(float(h), 1, 0.8) for h in hues],
                   dtype=np.float64).reshape(-1, 3)
    alpha = np.ones([rgb.shape[0], 1])
    return np.hstack([rgb, alpha]) * 255.


def hausdorff_distance_vismesh(ma: trimesh.Trimesh,
                               mb: trimesh.Trimesh,
                               min=0,
                               max=10 * 1e-3):
    '''
    Bidirectional Hausdorff distance between the vertex sets of two meshes,
    plus per-vertex colouring of both meshes for visualisation.

    ma: trimesh; its vertex colours are overwritten, coloured by cham_ab
    mb: trimesh; its vertex colours are overwritten, coloured by cham_ba
    min: for visualization, distances are clipped below at min (colour blue)
    max: for visualization, distances are clipped above at max (colour red)
    (``min`` / ``max`` shadow the builtins but are kept for API compatibility.)

    return:
    bidirectional hausdorff distance (0-d tensor),
    colored mesh ma,
    colored mesh mb,
    per-vertex distance from ma to its nearest vertex of mb (numpy, 1-D),
    per-vertex distance from mb to its nearest vertex of ma (numpy, 1-D)
    '''
    va = torch.from_numpy(ma.vertices).float().unsqueeze(0)  # (1, Va, 3)
    vb = torch.from_numpy(mb.vertices).float().unsqueeze(0)  # (1, Vb, 3)

    # knn_points reports squared distances, hence the sqrt. Index batch
    # element 0 instead of squeeze() so a single-vertex mesh still yields a
    # 1-D result (squeeze would produce a non-iterable 0-d tensor).
    cham_x = knn_points(va, vb, K=1).dists[0, :, 0].sqrt()  # (Va,)
    cham_y = knn_points(vb, va, K=1).dists[0, :, 0].sqrt()  # (Vb,)

    hd = torch.max(cham_x.max(), cham_y.max())

    cham_x_np = cham_x.numpy()
    cham_y_np = cham_y.numpy()
    ma.visual.vertex_colors = _distance_to_vertex_colors(cham_x_np, min, max)
    mb.visual.vertex_colors = _distance_to_vertex_colors(cham_y_np, min, max)

    return hd, ma, mb, cham_x_np, cham_y_np
Ejemplo n.º 4
0
    def test_invalid_norm(self):
        """knn_points must reject norms other than 1 or 2 with a ValueError."""
        dev = get_random_cuda_device()
        batch, n_src, n_dst, k, dim = 4, 16, 12, 8, 3
        src = torch.rand((batch, n_src, dim), device=dev)
        dst = torch.rand((batch, n_dst, dim), device=dev)
        for bad_norm in (3, 0):
            with self.assertRaisesRegex(ValueError, "Support for 1 or 2 norm."):
                knn_points(src, dst, K=k, norm=bad_norm)
Ejemplo n.º 5
0
    def discrete_project(self, pc: torch.Tensor, thres=0.9, cpu=False):
        """Assign to each query of ``self`` one point of ``pc``.

        Queries are face centroids with face normals when ``self`` is a
        ``Mesh``, otherwise the rows ``[:, :3]`` (positions) and ``[:, 3:]``
        (normals) of ``self``. For each query, the closest point of ``pc``
        whose direction from the query is nearly parallel to the query normal
        (|cos| > ``thres``) is chosen. If ``pc`` has 6 channels, the point's
        normal must also have a positive dot product with the query normal.

        Queries that lie in a 3-NN loop with ``pc`` (the query is among the 3
        nearest queries of one of its own 3 nearest cloud points) are filled
        with NaN. So are queries for which no candidate passes the tests.

        Args:
            pc: point cloud, indexed as ``pc[:, :, :3]`` and ``pc[0]`` --
                presumably (1, P, 3) or (1, P, 6); TODO confirm.
            thres: minimum |cos| between the query normal and the direction
                from the query to a candidate point.
            cpu: if True, run the dense query-by-point computation on the CPU
                (presumably to limit GPU memory).

        Returns:
            (pc_per_face, valid): ``pc_per_face`` is a (Q, 6) tensor on
            ``self.device``. Unassigned rows are NaN; columns beyond pc's
            channel count stay zero. ``valid`` is a boolean (Q,) mask on the
            compute device, True where a point was assigned.
        """
        with torch.no_grad():
            device = torch.device('cpu') if cpu else self.device
            pc = pc.double()
            if isinstance(self, Mesh):
                mid_points = self.vs[self.faces].mean(dim=1)
                normals = self.normals
            else:
                mid_points = self[:, :3]
                normals = self[:, 3:]
            # pk12[q]: 3 nearest cloud points of query q.
            # pk21[p]: 3 nearest queries of cloud point p.
            pk12 = knn_points(mid_points[:, :3].unsqueeze(0),
                              pc[:, :, :3],
                              K=3).idx[0]
            pk21 = knn_points(pc[:, :, :3],
                              mid_points[:, :3].unsqueeze(0),
                              K=3).idx[0]
            loop = pk21[pk12].view(pk12.shape[0], -1)
            knn_mask = (loop == torch.arange(
                0, pk12.shape[0], device=loop.device)[:, None]).any(dim=1)
            # Move the mask to the compute device too. With cpu=True it would
            # otherwise stay on the GPU and be used to index CPU tensors.
            knn_mask = knn_mask.to(device)
            mid_points = mid_points.to(device)
            pc = pc[0].to(device)
            normals = normals.to(device)[~knn_mask, :]
            masked_mid_points = mid_points[~knn_mask, :]
            displacement = masked_mid_points[:, None, :] - pc[:, :3]
            torch.cuda.empty_cache()
            distance = displacement.norm(dim=-1)
            mask = (torch.abs(
                torch.sum((displacement / distance[:, :, None]) *
                          normals[:, None, :],
                          dim=-1)) > thres)
            if pc.shape[-1] == 6:
                pc_normals = pc[:, 3:]
                normals_correlation = torch.sum(normals[:, None, :] *
                                                pc_normals,
                                                dim=-1)
                mask = mask & (normals_correlation > 0)
            torch.cuda.empty_cache()
            # Rejected candidates can never be the minimum.
            distance[~mask] += float('inf')
            min_dist, argmin = distance.min(dim=-1)

            pc_per_face_masked = pc[argmin, :].clone()
            pc_per_face_masked[min_dist == float('inf'), :] = float('nan')
            pc_per_face = torch.zeros(mid_points.shape[0], 6,
                                      dtype=pc_per_face_masked.dtype,
                                      device=pc_per_face_masked.device)
            pc_per_face[~knn_mask, :pc.shape[-1]] = pc_per_face_masked
            pc_per_face[knn_mask, :] = float('nan')

            valid = ~torch.isnan(pc_per_face[:, 0])
        return pc_per_face.to(self.device), valid.to(device)
Ejemplo n.º 6
0
    def _knn_vs_python_ragged_helper(self, device):
        """Compare knn_points against the naive reference on ragged batches.

        Sweeps batch size, feature dim, both cloud sizes and K. For every
        combination, random per-cloud lengths are drawn, and the forward
        outputs (dists, idx) and the input gradients must agree.
        """
        grid = product([1, 4], [3, 5, 8], [8, 24], [8, 16, 32], [1, 3, 10])
        for N, D, P1, P2, K in grid:
            x = torch.rand((N, P1, D), device=device, requires_grad=True)
            y = torch.rand((N, P2, D), device=device, requires_grad=True)
            lengths1 = torch.randint(low=1, high=P1, size=(N,), device=device)
            lengths2 = torch.randint(low=1, high=P2, size=(N,), device=device)

            # Independent leaf copies so each implementation gets its own grads.
            x_csrc = x.clone().detach().requires_grad_(True)
            y_csrc = y.clone().detach().requires_grad_(True)

            ref = self._knn_points_naive(
                x, y, lengths1=lengths1, lengths2=lengths2, K=K
            )
            got = knn_points(x_csrc, y_csrc, lengths1=lengths1, lengths2=lengths2, K=K)
            self.assertClose(ref[0], got[0])
            self.assertTrue(torch.all(ref[1] == got[1]))

            weights = torch.ones((N, P1, K), dtype=torch.float32, device=device)
            (ref.dists * weights).sum().backward()
            (got.dists * weights).sum().backward()

            self.assertClose(x_csrc.grad, x.grad, atol=5e-6)
            self.assertClose(y_csrc.grad, y.grad, atol=5e-6)
Ejemplo n.º 7
0
    def _knn_vs_python_square_helper(self, device):
        """Check each knn_points kernel version against the naive reference.

        Uses dense (unpadded) batches swept over N, D, P1, P2 and K. Version 3
        is skipped whenever K > 4. Forward outputs (dists, idx) and gradients
        with respect to both inputs must match.
        """
        sweep = product(
            [1, 4],  # N
            [3, 5, 8],  # D
            [8, 24],  # P1
            [8, 16, 32],  # P2
            [1, 3, 10],  # K
            [0, 1, 2, 3],  # version
        )
        for N, D, P1, P2, K, version in sweep:
            if K > 4 and version == 3:
                continue
            x = torch.randn(N, P1, D, device=device, requires_grad=True)
            x_cuda = x.clone().detach().requires_grad_(True)
            y = torch.randn(N, P2, D, device=device, requires_grad=True)
            y_cuda = y.clone().detach().requires_grad_(True)

            expected = self._knn_points_naive(x, y, lengths1=None, lengths2=None, K=K)
            actual = knn_points(x_cuda, y_cuda, K=K, version=version)
            self.assertClose(expected[0], actual[0])
            self.assertTrue(torch.all(expected[1] == actual[1]))

            ones = torch.ones((N, P1, K), dtype=torch.float32, device=device)
            (expected.dists * ones).sum().backward()
            (actual.dists * ones).sum().backward()

            self.assertClose(x_cuda.grad, x.grad, atol=5e-6)
            self.assertClose(y_cuda.grad, y.grad, atol=5e-6)
Ejemplo n.º 8
0
    def _knn_vs_python_square_helper(self, device, return_sorted):
        """Check every kernel version and norm against the naive reference.

        Dense batches are swept over N, D, P1, P2, K, norm and version
        (version 3 is skipped for K > 4).

        When ``return_sorted`` is False and K > 1, the raw output must differ
        from the sorted reference. It is then sorted locally before the
        comparison. Forward outputs and input gradients must match.
        """
        sweep = product(
            [1, 4],  # N
            [3, 5, 8],  # D
            [8, 24],  # P1
            [8, 16, 32],  # P2
            [1, 3, 10],  # K
            [1, 2],  # norm
            [0, 1, 2, 3],  # version
        )
        for N, D, P1, P2, K, norm, version in sweep:
            if K > 4 and version == 3:
                continue
            x = torch.randn(N, P1, D, device=device, requires_grad=True)
            x_cuda = x.clone().detach().requires_grad_(True)
            y = torch.randn(N, P2, D, device=device, requires_grad=True)
            y_cuda = y.clone().detach().requires_grad_(True)

            expected = self._knn_points_naive(
                x, y, lengths1=None, lengths2=None, K=K, norm=norm
            )
            actual = knn_points(
                x_cuda,
                y_cuda,
                K=K,
                norm=norm,
                version=version,
                return_sorted=return_sorted,
            )
            if not return_sorted and K > 1:
                # The unsorted result must not already match the sorted one.
                self.assertFalse(torch.allclose(expected[0], actual[0]))
                self.assertFalse(torch.allclose(expected[1], actual[1]))
                dists, idx, _ = actual
                # With P2 < K the trailing slots are padding. Push them to the
                # end while sorting, then restore them to 0 (presumably the
                # reference's padding value).
                has_padding = P2 < K
                if has_padding:
                    dists[..., P2:] = float("inf")
                dists, order = dists.sort(dim=2)
                if has_padding:
                    dists[..., P2:] = 0
                actual = _KNN(dists, idx.gather(2, order), None)

            self.assertClose(expected[0], actual[0])
            self.assertTrue(torch.all(expected[1] == actual[1]))

            ones = torch.ones((N, P1, K), dtype=torch.float32, device=device)
            (expected.dists * ones).sum().backward()
            (actual.dists * ones).sum().backward()

            self.assertClose(x_cuda.grad, x.grad, atol=5e-6)
            self.assertClose(y_cuda.grad, y.grad, atol=5e-6)
Ejemplo n.º 9
0
 def output():
     """Run one knn_points forward + backward pass (presumably a benchmark body).

     Reads pts1, pts2, lengths1, lengths2, K and grad_dists from the enclosing
     scope and synchronizes CUDA so all launched kernels have finished.
     """
     dists = knn_points(pts1, pts2, lengths1=lengths1,
                        lengths2=lengths2, K=K).dists
     (dists * grad_dists).sum().backward()
     torch.cuda.synchronize()
Ejemplo n.º 10
0
def feeature_pooling(x, x_lengths, y, y_lengths, neighbors):
    """Multi-scale point score (MPED) between point clouds x and y.

    Gathers the ``neighbors`` nearest points of every x point, both within x
    and within y. It then evaluates ``SPED`` at three scales (10, 5, 1) and
    returns the sum of the three totals, divided by P1 and by N.

    Args:
        x: (N, P1, D) tensor.
        x_lengths: per-cloud point counts for x.
        y: (N, P2, D) tensor.
        y_lengths: per-cloud point counts for y.
        neighbors: K used for both neighbour searches.

    Returns:
        Scalar score tensor.

    Raises:
        ValueError: if x or y has fewer than 15 points, or if y's batch or
            feature dimension differs from x's.
    """
    N, P1, D = x.shape
    P2 = y.shape[1]
    # T = 0.001
    T2 = 1e-20
    if P1 < 15 or P2 < 15:
        raise ValueError(
            "x or y does not have the enough points (at least 15 points).")
    # Validate shapes before spending time on the neighbour searches.
    if y.shape[0] != N or y.shape[2] != D:
        raise ValueError("y does not have the correct shape.")

    x_nn = knn_points(x,
                      x,
                      lengths1=x_lengths,
                      lengths2=x_lengths,
                      K=neighbors)
    y_nn = knn_points(x,
                      y,
                      lengths1=x_lengths,
                      lengths2=y_lengths,
                      K=neighbors)

    # knn_gather's lengths describe the cloud being gathered FROM. The
    # y-neighbourhood must therefore be masked with y_lengths, as
    # chamfer_distance does for its normals.
    x_coor_near = knn_gather(x, x_nn.idx, x_lengths)  # batch,points,nei,D
    y_coor_near = knn_gather(y, y_nn.idx, y_lengths)

    # three scales
    sped1, sped2, sped3 = (
        SPED(x, x_nn, y_nn, T2, N, P1, scale, x_coor_near, y_coor_near,
             D).sum() for scale in (10, 5, 1))

    mped_score = (sped1 + sped2 + sped3) / P1 / N
    return mped_score
Ejemplo n.º 11
0
 def knn(self):
     """Sorted K-nearest-neighbour query from ``pc1_knn`` into ``pc2_knn``.

     Returns:
         (dists, idxs, nn) as produced by knn_points with return_nn=True.
         No radius masking is applied: for the backward pass, all K
         neighbours are assumed to lie within the radius.
     """
     result = knn_points(self.pc1_knn,
                         self.pc2_knn,
                         lengths1=self.lengths1_cuda,
                         lengths2=self.lengths2_cuda,
                         K=self.K,
                         version=-1,
                         return_nn=True,
                         return_sorted=True)
     return result.dists, result.idx, result.knn
Ejemplo n.º 12
0
    def test_knn_gather(self):
        """knn_gather must return y[idx] for valid slots and zeros beyond lengths2."""
        device = get_random_cuda_device()
        N, P1, P2, K, D = 4, 16, 12, 8, 3
        x = torch.rand((N, P1, D), device=device)
        y = torch.rand((N, P2, D), device=device)
        lengths1 = torch.randint(low=1, high=P1, size=(N,), device=device)
        lengths2 = torch.randint(low=1, high=P2, size=(N,), device=device)

        out = knn_points(x, y, lengths1=lengths1, lengths2=lengths2, K=K)
        y_nn = knn_gather(y, out.idx, lengths2)

        for n in range(N):
            n_valid = lengths2[n]
            for p1 in range(P1):
                for k in range(K):
                    gathered = y_nn[n, p1, k]
                    if k >= n_valid:
                        # Slots past the cloud's length must be zero-filled.
                        self.assertTrue(torch.all(gathered == 0.0))
                    else:
                        self.assertClose(gathered, y[n, out.idx[n, p1, k]])
Ejemplo n.º 13
0
def get_NN(src_xyz, trg_xyz, k=1):
    '''
    Nearest neighbour in trg_xyz of every point in src_xyz.

    :param src_xyz: [B, N1, 3]
    :param trg_xyz: [B, N2, 3]
    :param k: number of neighbours searched. Only the closest one (slot 0)
        is returned, so values other than 1 only add work.
    :return: nn_dists, nn_idx: [B, N1] tensors. For each source point they
        hold the knn_points distance to its nearest target point and that
        point's index in N2.
    '''
    # Every cloud is treated as fully populated.
    src_lengths = torch.full((src_xyz.shape[0], ),
                             src_xyz.shape[1],
                             dtype=torch.int64,
                             device=src_xyz.device)
    trg_lengths = torch.full((trg_xyz.shape[0], ),
                             trg_xyz.shape[1],
                             dtype=torch.int64,
                             device=trg_xyz.device)
    src_nn = knn_points(src_xyz,
                        trg_xyz,
                        lengths1=src_lengths,
                        lengths2=trg_lengths,
                        K=k)
    nn_dists = src_nn.dists[..., 0]
    nn_idx = src_nn.idx[..., 0]
    return nn_dists, nn_idx
Ejemplo n.º 14
0
def chamfer_distance(
    x,
    y,
    x_lengths=None,
    y_lengths=None,
    x_normals=None,
    y_normals=None,
    weights=None,
    batch_reduction: Union[str, None] = "mean",
    point_reduction: str = "mean",
):
    """
    Chamfer distance between two pointclouds x and y.

    Args:
        x: FloatTensor of shape (N, P1, D) or a Pointclouds object representing
            a batch of point clouds with at most P1 points in each batch element,
            batch size N and feature dimension D.
        y: FloatTensor of shape (N, P2, D) or a Pointclouds object representing
            a batch of point clouds with at most P2 points in each batch element,
            batch size N and feature dimension D.
        x_lengths: Optional LongTensor of shape (N,) giving the number of points in each
            cloud in x.
        y_lengths: Optional LongTensor of shape (N,) giving the number of points in each
            cloud in y.
        x_normals: Optional FloatTensor of shape (N, P1, D).
        y_normals: Optional FloatTensor of shape (N, P2, D).
        weights: Optional FloatTensor of shape (N,) giving weights for
            batch elements for reduction operation.
        batch_reduction: Reduction operation to apply for the loss across the
            batch, can be one of ["mean", "sum"] or None.
        point_reduction: Reduction operation to apply for the loss across the
            points, can be one of ["mean", "sum"].

    Returns:
        2-element tuple containing

        - **loss**: Tensor giving the reduced distance between the pointclouds
          in x and the pointclouds in y.
        - **loss_normals**: Tensor giving the reduced cosine distance of normals
          between pointclouds in x and pointclouds in y. Returns None if
          x_normals and y_normals are None.

        If every weight is zero, both entries are zero tensors, kept attached
        to x's graph. They are scalars when batch_reduction is "mean"/"sum",
        otherwise shape (N,).

    Raises:
        ValueError: if y's batch/feature dimensions do not match x, or if
            weights has the wrong length or contains negative values.
    """
    _validate_chamfer_reduction_inputs(batch_reduction, point_reduction)

    x, x_lengths, x_normals = _handle_pointcloud_input(x, x_lengths, x_normals)
    y, y_lengths, y_normals = _handle_pointcloud_input(y, y_lengths, y_normals)

    return_normals = x_normals is not None and y_normals is not None

    N, P1, D = x.shape
    P2 = y.shape[1]

    # Check if inputs are heterogeneous and create a lengths mask.
    # The masks are True on padded (invalid) point slots.
    is_x_heterogeneous = (x_lengths != P1).any()
    is_y_heterogeneous = (y_lengths != P2).any()
    x_mask = (torch.arange(P1, device=x.device)[None] >= x_lengths[:, None]
              )  # shape [N, P1]
    y_mask = (torch.arange(P2, device=y.device)[None] >= y_lengths[:, None]
              )  # shape [N, P2]

    if y.shape[0] != N or y.shape[2] != D:
        raise ValueError("y does not have the correct shape.")
    if weights is not None:
        if weights.size(0) != N:
            raise ValueError("weights must be of shape (N,).")
        if not (weights >= 0).all():
            raise ValueError("weights cannot be negative.")
        if weights.sum() == 0.0:
            # Every cloud is weighted out: return zeros that stay attached to
            # the autograd graph. Multiply by the flat (N,) weights. Viewing
            # them as (N, 1) would broadcast against the (N,) per-cloud sums
            # and produce an (N, N) result.
            zero_per_cloud = x.sum((1, 2)) * weights.view(N) * 0.0  # (N,)
            if batch_reduction in ["mean", "sum"]:
                return zero_per_cloud.sum(), zero_per_cloud.sum()
            return zero_per_cloud, zero_per_cloud.clone()

    cham_norm_x = x.new_zeros(())
    cham_norm_y = x.new_zeros(())

    x_nn = knn_points(x, y, lengths1=x_lengths, lengths2=y_lengths, K=1)
    y_nn = knn_points(y, x, lengths1=y_lengths, lengths2=x_lengths, K=1)

    cham_x = x_nn.dists[..., 0]  # (N, P1)
    cham_y = y_nn.dists[..., 0]  # (N, P2)

    if is_x_heterogeneous:
        cham_x[x_mask] = 0.0
    if is_y_heterogeneous:
        cham_y[y_mask] = 0.0

    if weights is not None:
        cham_x *= weights.view(N, 1)
        cham_y *= weights.view(N, 1)

    if return_normals:
        # Gather the normals using the indices and keep only value for k=0
        x_normals_near = knn_gather(y_normals, x_nn.idx, y_lengths)[..., 0, :]
        y_normals_near = knn_gather(x_normals, y_nn.idx, x_lengths)[..., 0, :]

        cham_norm_x = 1 - torch.abs(
            F.cosine_similarity(x_normals, x_normals_near, dim=2, eps=1e-6))
        cham_norm_y = 1 - torch.abs(
            F.cosine_similarity(y_normals, y_normals_near, dim=2, eps=1e-6))

        if is_x_heterogeneous:
            cham_norm_x[x_mask] = 0.0
        if is_y_heterogeneous:
            cham_norm_y[y_mask] = 0.0

        if weights is not None:
            cham_norm_x *= weights.view(N, 1)
            cham_norm_y *= weights.view(N, 1)

    # Apply point reduction
    cham_x = cham_x.sum(1)  # (N,)
    cham_y = cham_y.sum(1)  # (N,)
    if return_normals:
        cham_norm_x = cham_norm_x.sum(1)  # (N,)
        cham_norm_y = cham_norm_y.sum(1)  # (N,)
    if point_reduction == "mean":
        cham_x /= x_lengths
        cham_y /= y_lengths
        if return_normals:
            cham_norm_x /= x_lengths
            cham_norm_y /= y_lengths

    if batch_reduction is not None:
        # Sum over the batch first; "mean" then divides by N (or the
        # total weight when weights are given).
        cham_x = cham_x.sum()
        cham_y = cham_y.sum()
        if return_normals:
            cham_norm_x = cham_norm_x.sum()
            cham_norm_y = cham_norm_y.sum()
        if batch_reduction == "mean":
            div = weights.sum() if weights is not None else N
            cham_x /= div
            cham_y /= div
            if return_normals:
                cham_norm_x /= div
                cham_norm_y /= div

    cham_dist = cham_x + cham_y
    cham_normals = cham_norm_x + cham_norm_y if return_normals else None

    return cham_dist, cham_normals
Ejemplo n.º 15
0
def upsample_ear(points,
                 normals,
                 n_points: Union[int, torch.Tensor],
                 num_points=None,
                 neighborhood_size=16,
                 repulsion_mu=0.4,
                 edge_sensitivity=1.0):
    """
    One LOP-style projection / repulsion step, followed by iterative insertion
    of points into the sparsest neighbourhoods (the name suggests edge-aware
    resampling).

    Args:
        points (N, P, 3)
        normals: per-point normals, broadcast as ``normals[:, :, None, :]``
            against (N, P, K, 3) neighbour offsets
        n_points (tensor of [N] or integer): target number of points per cloud
        num_points (tensor of [N], optional): current number of points per
            cloud; defaults to P for every cloud
        neighborhood_size: number of neighbours, excluding the point itself
        repulsion_mu: weight of the repulsion displacement
        edge_sensitivity: currently unused

    Returns:
        (padded_points, num_points). Clouds already at or above
        ``n_points`` receive no new points, but are still projected.
    """
    knn_k = neighborhood_size
    if num_points is None:
        num_points = torch.tensor([points.shape[1]] * points.shape[0],
                                  device=points.device,
                                  dtype=torch.long)
    if not ((num_points - num_points[0]) == 0).all():
        logger_py.warn(
            "May encounter unexpected behavior for heterogeneous batches")
    if num_points.sum() == 0:
        return points, num_points

    # Per-cloud scale factors, shaped (N, 1, 1) so they broadcast against the
    # per-neighbour tensors of shape (N, P, K) for any batch size.
    point_cloud_diag = (points.max(dim=-2)[0] -
                        points.min(dim=-2)[0]).norm(dim=-1)
    inv_sigma_spatial = (num_points / point_cloud_diag).view(-1, 1, 1)
    spatial_dist = 16 / inv_sigma_spatial

    knn_result = knn_points(points,
                            points,
                            num_points,
                            num_points,
                            K=knn_k + 1,
                            return_nn=True,
                            return_sorted=True)
    # Slot 0 of the sorted result is the query point itself; drop it.
    _knn_dists = knn_result.dists[..., 1:]
    _knn_nn = knn_result.knn[..., 1:, :]
    move_clip = knn_result.dists[..., 1].mean().sqrt()

    # 2. LOP projection
    # NOTE(review): ``denoise_normals`` is tested for truthiness and then
    # called. If it names a function, this branch always runs -- confirm
    # whether a separate on/off flag was intended.
    if denoise_normals:
        normals_denoised, weights_p, weights_n = denoise_normals(
            points, normals, num_points, knn_result=knn_result)
        normals = normals_denoised

    # e(-(<n, p-pi>)^2/sigma_p)
    weight_lop = torch.exp(-torch.sum(normals[:, :, None, :] *
                                      (points[:, :, None, :] - _knn_nn),
                                      dim=-1)**2 * inv_sigma_spatial)
    weight_lop[_knn_dists > spatial_dist] = 0

    # spatial weight
    spatial_w = torch.exp(-_knn_dists * inv_sigma_spatial)
    spatial_w[_knn_dists > spatial_dist] = 0
    density_w = torch.sum(spatial_w, dim=-1) + 1.0
    move_data = torch.sum(
        weight_lop[..., None] * (points[:, :, None, :] - _knn_nn), dim=-2) / \
        eps_denom(torch.sum(weight_lop, dim=-1, keepdim=True))
    move_repul = repulsion_mu * density_w[..., None] * torch.sum(spatial_w[..., None] * (
        _knn_nn - points[:, :, None, :]), dim=-2) / \
        eps_denom(torch.sum(spatial_w, dim=-1, keepdim=True))
    # Cap each per-point displacement at move_clip. Normalise over xyz
    # (dim=-1); F.normalize's default dim=1 would normalise across points.
    move_repul = F.normalize(move_repul, dim=-1) * move_repul.norm(
        dim=-1, keepdim=True).clamp_max(move_clip)
    move_data = F.normalize(move_data, dim=-1) * move_data.norm(
        dim=-1, keepdim=True).clamp_max(move_clip)
    points = points - (move_data + move_repul)

    # Clouds already at/above the target never lose points.
    n_remaining = (n_points - num_points).to(dtype=torch.long).clamp(min=0)
    while not (n_remaining == 0).all():
        P = points.shape[1]
        # Insert at most ~P/10 points per round, but always at least one so
        # small clouds still make progress.
        max_P = max(P // 10, 1)
        # Candidate points 1/3 of the way from each point to each neighbour.
        mid_points = (_knn_nn + 2 * points[..., None, :]) / 3  # N,P,K,3
        mid_nn_diff = mid_points.unsqueeze(-2) - _knn_nn.unsqueeze(-3)  # N,P,K,K,3
        # Distance from each candidate to its closest neighbour.
        min_dist2 = torch.norm(mid_nn_diff, dim=-1).min(dim=-1)[0]  # N,P,K
        father_sparsity, father_nb = min_dist2.max(dim=-1)  # N,P
        # Sparsest neighbourhoods last.
        sparsity_sorted = father_sparsity.sort(dim=1).indices[:, -max_P:]
        n_new_points = n_remaining.clamp(max=max_P)
        batch_idx = torch.arange(mid_points.shape[0], device=mid_points.device)
        point_idx = torch.arange(mid_points.shape[1], device=mid_points.device)
        best_mid = mid_points[batch_idx[:, None], point_idx[None, :],
                              father_nb]  # N,P,3
        new_pts = torch.gather(best_mid, 1,
                               sparsity_sorted.unsqueeze(-1).expand(-1, -1, 3))
        n_candidates = new_pts.shape[1]
        total_pts_list = []
        for b, pts_batch in enumerate(
                padded_to_list(points, num_points.tolist())):
            n_new = int(n_new_points[b])
            # Explicit start index: ``[-0:]`` would select every candidate.
            total_pts_list.append(
                torch.cat([new_pts[b][n_candidates - n_new:], pts_batch],
                          dim=0))

        points = list_to_padded(total_pts_list)
        n_remaining = n_remaining - n_new_points
        num_points = n_new_points + num_points
        knn_result = knn_points(points,
                                points,
                                num_points,
                                num_points,
                                K=knn_k + 1,
                                return_nn=True)
        _knn_nn = knn_result.knn[..., 1:, :]

    return points, num_points
Ejemplo n.º 16
0
def upsample(
        pcl: Union[Pointclouds, torch.Tensor],
        n_points: Union[int, torch.Tensor],
        num_points=None,
        neighborhood_size=16,
        knn_result=None
) -> Union[Pointclouds, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Iteratively add points to the sparsest region
    Args:
        pcl (tensor of [N, P, 3] or Pointclouds)
        n_points (tensor of [N] or integer): target number of points per cloud
        num_points: unused; per-cloud counts are taken from ``pcl``
        neighborhood_size: number of neighbours, excluding the point itself,
            used to measure sparsity
        knn_result: optional precomputed ``_KNN`` for the initial points, with
            the self-match already removed and ``knn`` populated
    Returns:
        Pointclouds or (padded_points, num_points). Clouds already at or
        above ``n_points`` are left unchanged.
    """
    def _return_value(points, num_points, return_pcl):
        if return_pcl:
            points_list = padded_to_list(points, num_points.tolist())
            return pcl.__class__(points_list)
        else:
            return points, num_points

    return_pcl = is_pointclouds(pcl)
    points, num_points = convert_pointclouds_to_tensor(pcl)

    knn_k = neighborhood_size

    if not ((num_points - num_points[0]) == 0).all():
        logger_py.warn(
            "Upsampling operation may encounter unexpected behavior for heterogeneous batches"
        )

    if num_points.sum() == 0:
        return _return_value(points, num_points, return_pcl)

    n_remaining = (n_points - num_points).to(dtype=torch.long)
    if (n_remaining <= 0).all():
        return _return_value(points, num_points, return_pcl)
    # Clouds already at/above the target must not "insert" a negative count.
    n_remaining = n_remaining.clamp(min=0)

    if knn_result is None:
        knn_result = knn_points(points,
                                points,
                                num_points,
                                num_points,
                                K=knn_k + 1,
                                return_nn=True,
                                return_sorted=True)
        # Drop slot 0 (the query point itself).
        knn_result = _KNN(dists=knn_result.dists[..., 1:],
                          idx=knn_result.idx[..., 1:],
                          knn=knn_result.knn[..., 1:, :])

    while not (n_remaining == 0).all():
        sparse_knn = knn_result.knn  # N,P,K,3 (self excluded)
        P = points.shape[1]
        # Insert at most ~P/8 points per round, but always at least one so
        # small clouds still make progress.
        max_P = max(P // 8, 1)
        # Candidate points 1/3 of the way from each point to each neighbour.
        mid_points = (sparse_knn + 2 * points[..., None, :]) / 3  # N,P,K,3
        mid_nn_diff = mid_points.unsqueeze(-2) - sparse_knn.unsqueeze(-3)  # N,P,K,K,3
        # Distance from each candidate to its closest neighbour.
        min_dist2 = torch.norm(mid_nn_diff, dim=-1).min(dim=-1)[0]  # N,P,K
        father_sparsity, father_nb = min_dist2.max(dim=-1)  # N,P
        # Sparsest neighbourhoods last.
        sparsity_sorted = father_sparsity.sort(dim=1).indices[:, -max_P:]
        n_new_points = n_remaining.clamp(max=max_P)
        batch_idx = torch.arange(mid_points.shape[0], device=mid_points.device)
        point_idx = torch.arange(mid_points.shape[1], device=mid_points.device)
        best_mid = mid_points[batch_idx[:, None], point_idx[None, :],
                              father_nb]  # N,P,3
        new_pts = torch.gather(best_mid, 1,
                               sparsity_sorted.unsqueeze(-1).expand(-1, -1, 3))
        n_candidates = new_pts.shape[1]

        total_pts_list = []
        for b, pts_batch in enumerate(
                padded_to_list(points, num_points.tolist())):
            n_new = int(n_new_points[b])
            # Explicit start index. ``[-0:]`` would prepend every candidate to
            # a finished cloud without updating its count, and later
            # truncation would then discard its original points.
            total_pts_list.append(
                torch.cat([new_pts[b][n_candidates - n_new:], pts_batch],
                          dim=0))

        points = list_to_padded(total_pts_list)
        n_remaining = n_remaining - n_new_points
        num_points = n_new_points + num_points
        knn_result = knn_points(points,
                                points,
                                num_points,
                                num_points,
                                K=knn_k + 1,
                                return_nn=True)
        knn_result = _KNN(dists=knn_result.dists[..., 1:],
                          idx=knn_result.idx[..., 1:],
                          knn=knn_result.knn[..., 1:, :])

    return _return_value(points, num_points, return_pcl)