Example #1
0
    def test_ball_query_output_simple(self):
        """Check ``ball_query`` on inputs whose answer is known in advance.

        ``points_1`` are sampled from an ico sphere and ``points_2`` from the
        same sphere scaled by 5.0. Assuming ``ico_sphere`` yields (roughly) a
        unit sphere -- TODO confirm -- every query/reference distance lies in
        about [4, 6]. A radius of 6.0 should therefore capture reference points
        for every query, and a radius of 0.5 should capture none. Both the
        naive reference implementation and ``ball_query`` are checked in each
        regime.
        """
        device = get_random_cuda_device()
        N, P1, P2, K = 5, 8, 16, 4
        sphere = ico_sphere(level=2, device=device).extend(N)
        points_1 = sample_points_from_meshes(sphere, P1)
        points_2 = sample_points_from_meshes(sphere, P2) * 5.0
        radius = 6.0

        naive_out = self._ball_query_naive(
            points_1, points_2, lengths1=None, lengths2=None, K=K, radius=radius
        )
        cuda_out = ball_query(points_1, points_2, K=K, radius=radius)

        # Large radius: every query point should get K neighbors.
        # Index 0 is a valid neighbor index but can appear at most once per
        # query point. More than one zero in a row would indicate zero padding.
        naive_out_zeros = (naive_out.idx == 0).sum(dim=-1).max()
        cuda_out_zeros = (cuda_out.idx == 0).sum(dim=-1).max()
        self.assertTrue(naive_out_zeros <= 1)
        self.assertTrue(cuda_out_zeros <= 1)

        # Small radius: no query point should have any neighbor, so every entry
        # of idx must be -1. Use .all() rather than .sum(): .sum() would pass as
        # soon as a single -1 appeared anywhere in the tensor.
        radius = 0.5
        naive_out = self._ball_query_naive(
            points_1, points_2, lengths1=None, lengths2=None, K=K, radius=radius
        )
        cuda_out = ball_query(points_1, points_2, K=K, radius=radius)
        naive_out_allzeros = (naive_out.idx == -1).all()
        cuda_out_allzeros = (cuda_out.idx == -1).all()
        self.assertTrue(naive_out_allzeros)
        self.assertTrue(cuda_out_allzeros)
Example #2
0
 def output():
     """Run one forward + backward pass of ``ball_query``.

     All inputs (``pts1``, ``pts2``, ``lengths1``, ``lengths2``, ``K``,
     ``radius``, ``grad_dists``) are captured from the enclosing scope, which
     is not shown here. Returns nothing.
     """
     result = ball_query(
         pts1,
         pts2,
         lengths1=lengths1,
         lengths2=lengths2,
         K=K,
         radius=radius,
     )
     (result.dists * grad_dists).sum().backward()
     # Block until queued CUDA work has finished (presumably so that an
     # external timer measures the whole pass -- confirm with the caller).
     torch.cuda.synchronize()
Example #3
0
    def _ball_query_vs_python_ragged_helper(self, device):
        """Compare ``ball_query`` against the naive reference on ragged inputs.

        For every combination in the parameter grid below, uniform random clouds
        are drawn together with random per-batch lengths. The test then checks
        that both implementations agree on ``idx`` and ``dists`` in the forward
        pass and on the input gradients in the backward pass.

        Args:
            device: torch device on which all tensors are created.
        """
        grid = product(
            [1, 4],  # N: batch size
            [3, 5, 8],  # D: point dimension
            [8, 24],  # P1: points per query cloud
            [8, 16, 32],  # P2: points per reference cloud
            [2, 3, 10],  # K: neighbors requested
            [1.4, 5],  # R: radius
        )
        for N, D, P1, P2, K, R in grid:
            x_ref = torch.rand((N, P1, D), device=device, requires_grad=True)
            y_ref = torch.rand((N, P2, D), device=device, requires_grad=True)
            # randint's upper bound is exclusive: lengths fall in [1, P - 1].
            len1 = torch.randint(low=1, high=P1, size=(N,), device=device)
            len2 = torch.randint(low=1, high=P2, size=(N,), device=device)

            # Independent leaf copies so each implementation gets its own grads.
            x_op = x_ref.detach().clone().requires_grad_(True)
            y_op = y_ref.detach().clone().requires_grad_(True)

            # Forward pass: naive reference vs. ball_query.
            expected = self._ball_query_naive(
                x_ref, y_ref, lengths1=len1, lengths2=len2, K=K, radius=R
            )
            actual = ball_query(
                x_op, y_op, lengths1=len1, lengths2=len2, K=K, radius=R
            )
            self.assertClose(expected.idx, actual.idx)
            self.assertClose(expected.dists, actual.dists)

            # Backward pass: propagate an all-ones upstream gradient through dists.
            upstream = torch.ones((N, P1, K), dtype=torch.float32, device=device)
            (expected.dists * upstream).sum().backward()
            (actual.dists * upstream).sum().backward()

            self.assertClose(x_op.grad, x_ref.grad, atol=5e-6)
            self.assertClose(y_op.grad, y_ref.grad, atol=5e-6)
Example #4
0
    def _ball_query_vs_python_square_helper(self, device):
        """Compare ``ball_query`` against the naive reference without lengths.

        For every combination in the parameter grid below, normally distributed
        clouds are drawn. Neither implementation receives ``lengths1`` or
        ``lengths2``. The test checks matching ``dists`` and ``idx`` in the
        forward pass and matching input gradients in the backward pass.

        Args:
            device: torch device on which all tensors are created.
        """
        grid = product(
            [1, 4],  # N: batch size
            [3, 5, 8],  # D: point dimension
            [8, 24],  # P1: points per query cloud
            [8, 16, 32],  # P2: points per reference cloud
            [1, 5],  # K: neighbors requested
            [3, 5],  # R: radius
        )
        for N, D, P1, P2, K, R in grid:
            # Each reference leaf gets an independent copy for ball_query so the
            # two implementations accumulate separate gradients.
            x_ref = torch.randn(N, P1, D, device=device, requires_grad=True)
            x_op = x_ref.detach().clone().requires_grad_(True)
            y_ref = torch.randn(N, P2, D, device=device, requires_grad=True)
            y_op = y_ref.detach().clone().requires_grad_(True)

            # Forward pass: naive reference vs. ball_query.
            expected = self._ball_query_naive(
                x_ref, y_ref, lengths1=None, lengths2=None, K=K, radius=R
            )
            actual = ball_query(x_op, y_op, K=K, radius=R)

            self.assertClose(expected.dists, actual.dists)
            idx_match = (expected.idx == actual.idx).all()
            self.assertTrue(idx_match)

            # Backward pass: propagate an all-ones upstream gradient through dists.
            upstream = torch.ones((N, P1, K), dtype=torch.float32, device=device)
            (expected.dists * upstream).sum().backward()
            (actual.dists * upstream).sum().backward()

            self.assertClose(x_op.grad, x_ref.grad, atol=5e-6)
            self.assertClose(y_op.grad, y_ref.grad, atol=5e-6)