def _test_compare_random_heterogeneous(self, device="cpu"):
        """
        Compare sample_farthest_points against sample_farthest_points_naive
        on random inputs.

        Two scenarios are checked:
          1. A batch without ``lengths`` and K (8) <= P (20): indices and
             sampled points must match the naive implementation, all
             indices must be non-negative, and every batch element must
             have exactly K valid indices.
          2. A batch with random per-element ``lengths`` and K (30) > P:
             indices and points must match, each element must select at
             most ``lengths[n]`` points, every valid index must lie in
             ``[0, lengths[n])``, no index may repeat, and gradients of a
             sum over the sampled points must agree between the two
             implementations.

        Args:
            device: device on which the random inputs are created.
        """
        N, P, D, K = 5, 20, 5, 8
        points = torch.randn((N, P, D), device=device, dtype=torch.float32)
        out_points_naive, out_idxs_naive = sample_farthest_points_naive(points,
                                                                        K=K)
        out_points, out_idxs = sample_farthest_points(points, K=K)
        self.assertTrue(out_idxs.min() >= 0)
        self.assertClose(out_idxs, out_idxs_naive)
        self.assertClose(out_points, out_points_naive)
        for n in range(N):
            self.assertEqual(out_idxs[n].ne(-1).sum(), K)

        # Test case where K > P
        K = 30
        points1 = torch.randn((N, P, D), dtype=torch.float32, device=device)
        points2 = points1.clone()
        points1.requires_grad = True
        points2.requires_grad = True
        # torch.randint's upper bound is exclusive, so every length is in
        # [1, P - 1] and therefore strictly smaller than K.
        lengths = torch.randint(low=1, high=P, size=(N, ), device=device)
        out_points_naive, out_idxs_naive = sample_farthest_points_naive(
            points1, lengths, K=K)
        out_points, out_idxs = sample_farthest_points(points2, lengths, K=K)
        self.assertClose(out_idxs, out_idxs_naive)
        self.assertClose(out_points, out_points_naive)

        for n in range(N):
            # For heterogeneous batches, the number of selected points is
            # at most the length of that batch element.
            self.assertTrue(out_idxs[n].ne(-1).sum() <= lengths[n])
            # Indices are 0-based positions within the first lengths[n]
            # points, so the largest valid index is lengths[n] - 1.
            self.assertTrue(out_idxs[n].max() < lengths[n])

            # Check there are no duplicate indices
            val_mask = out_idxs[n].ne(-1)
            _, counts = torch.unique(out_idxs[n][val_mask],
                                     return_counts=True)
            self.assertTrue(counts.le(1).all())

        # Check gradients: backprop the same upstream gradient through both
        # implementations and compare the gradients on the input points.
        grad_sampled_points = torch.ones((N, K, D),
                                         dtype=torch.float32,
                                         device=device)
        loss1 = (out_points_naive * grad_sampled_points).sum()
        loss1.backward()
        loss2 = (out_points * grad_sampled_points).sum()
        loss2.backward()
        self.assertClose(points1.grad, points2.grad, atol=5e-6)
 def output():
     """Run one forward and backward pass of sample_farthest_points.

     Reads ``pts``, ``K`` and ``grad_pts`` from the enclosing scope. The
     trailing ``torch.cuda.synchronize()`` waits for queued CUDA work to
     finish before returning (presumably so a benchmark harness times the
     full pass; confirm against the caller).
     """
     sampled_pts, _ = sample_farthest_points(pts, K=K)
     weighted = sampled_pts * grad_pts
     weighted.sum().backward()
     torch.cuda.synchronize()
    def test_cuda_vs_cpu(self):
        """
        Compare cuda vs cpu on a complex object.

        Loads the cow mesh vertices and runs sample_farthest_points with
        K=250 on CPU and on a CUDA device. It checks that both backends pick
        the same indices and gather the same points, and that the CUDA
        result has no duplicate indices. When ``DEBUG`` is set, it writes a
        side-by-side plot of CPU FPS, CUDA FPS and random sampling to
        ``DATA_DIR / "DEBUG_fps.jpg"``.
        """
        obj_filename = TUTORIAL_DATA_DIR / "cow_mesh/cow.obj"
        K = 250

        # Run on CPU
        device = "cpu"
        points, _, _ = load_obj(obj_filename,
                                device=device,
                                load_textures=False)
        # Add a leading batch dimension of size 1.
        points = points[None, ...]
        out_points_cpu, out_idxs_cpu = sample_farthest_points(points, K=K)

        # Run on GPU
        device = get_random_cuda_device()
        points_cuda = points.to(device)
        out_points_cuda, out_idxs_cuda = sample_farthest_points(points_cuda,
                                                                K=K)

        # Check that the indices, and the points gathered with them, from
        # CUDA and CPU match
        self.assertClose(out_idxs_cpu, out_idxs_cuda.cpu())
        self.assertClose(out_points_cpu, out_points_cuda.cpu())

        # Check there are no duplicate indices
        val_mask = out_idxs_cuda[0].ne(-1)
        _, counts = torch.unique(out_idxs_cuda[0][val_mask],
                                 return_counts=True)
        self.assertTrue(counts.le(1).all())

        # Plot all results
        if DEBUG:
            # mplot3d is required for 3d projection plots
            import matplotlib.pyplot as plt
            from mpl_toolkits import mplot3d  # noqa: F401

            # Move to cpu and convert to numpy for plotting. All three
            # arrays are converted so the plotting and np.random code below
            # works on plain numpy arrays.
            points = points.squeeze().numpy()
            out_points_cpu = out_points_cpu.squeeze().numpy()
            out_points_cuda = out_points_cuda.squeeze().cpu().numpy()

            # Farthest point sampling CPU
            fig = plt.figure(figsize=plt.figaspect(1.0 / 3))
            ax1 = fig.add_subplot(1, 3, 1, projection="3d")
            ax1.scatter(*points.T, alpha=0.1)
            ax1.scatter(*out_points_cpu.T, color="black")
            ax1.set_title("FPS CPU")

            # Farthest point sampling CUDA
            ax2 = fig.add_subplot(1, 3, 2, projection="3d")
            ax2.scatter(*points.T, alpha=0.1)
            ax2.scatter(*out_points_cuda.T, color="red")
            ax2.set_title("FPS CUDA")

            # Random Sampling (baseline for visual comparison)
            random_points = np.random.permutation(points)[:K]
            ax3 = fig.add_subplot(1, 3, 3, projection="3d")
            ax3.scatter(*points.T, alpha=0.1)
            ax3.scatter(*random_points.T, color="green")
            ax3.set_title("Random")

            # Save image
            filename = "DEBUG_fps.jpg"
            filepath = DATA_DIR / filename
            plt.savefig(filepath)
            # Release the figure so repeated runs don't accumulate open
            # matplotlib figures.
            plt.close(fig)