Example #1
    def test_basic_3chan(self):
        """Test rendering one image with one sphere, 3 channels."""
        from pytorch3d.renderer.points.pulsar import Renderer

        LOGGER.info("Setting up rendering test for 3 channels...")
        n_points = 1
        width = 1_000
        height = 1_000
        renderer = Renderer(width, height, n_points)
        vert_pos = torch.tensor([[0.0, 0.0, 25.0]], dtype=torch.float32)
        vert_col = torch.tensor([[0.3, 0.5, 0.7]], dtype=torch.float32)
        vert_rad = torch.tensor([1.0], dtype=torch.float32)
        cam_params = torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 2.0],
                                  dtype=torch.float32)
        for device in devices:
            vert_pos = vert_pos.to(device)
            vert_col = vert_col.to(device)
            vert_rad = vert_rad.to(device)
            cam_params = cam_params.to(device)
            renderer = renderer.to(device)
            LOGGER.info("Rendering...")
            # Measurements.
            result = renderer.forward(vert_pos, vert_col, vert_rad, cam_params,
                                      1.0e-1, 45.0)
            hits = renderer.forward(
                vert_pos,
                vert_col,
                vert_rad,
                cam_params,
                1.0e-1,
                45.0,
                percent_allowed_difference=0.01,
                mode=1,
            )
            if not os.environ.get("FB_TEST", False):
                imageio.imsave(
                    path.join(
                        path.dirname(__file__),
                        "test_out",
                        "test_forward_TestForward_test_basic_3chan.png",
                    ),
                    (result * 255.0).cpu().to(torch.uint8).numpy(),
                )
                imageio.imsave(
                    path.join(
                        path.dirname(__file__),
                        "test_out",
                        "test_forward_TestForward_test_basic_3chan_hits.png",
                    ),
                    (hits * 255.0).cpu().to(torch.uint8).numpy(),
                )
            self.assertEqual(hits[500, 500, 0].item(), 1.0)
            self.assertTrue(
                np.allclose(
                    result[500, 500, :].cpu().numpy(),
                    [0.3, 0.5, 0.7],
                    rtol=1e-2,
                    atol=1e-2,
                ))
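These snippets are excerpts: Examples #1, #3, #4 and #5 are methods of a unittest.TestCase, and all of them rely on module-level names (torch, np, imageio, os, path, LOGGER, devices, and for Example #2 also Variable and Renderer) that the original test files define elsewhere. A minimal scaffold roughly like the following, with devices and LOGGER as assumed stand-ins for whatever the original modules set up, is enough to run them:

import logging
import os
from os import path

import imageio
import numpy as np
import torch
from torch.autograd import Variable

from pytorch3d.renderer.points.pulsar import Renderer

LOGGER = logging.getLogger(__name__)

# Assumed device list: CPU plus CUDA when available. The original test
# modules may build this differently (for example, CUDA only).
devices = [torch.device("cpu")]
if torch.cuda.is_available():
    devices.append(torch.device("cuda"))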
Example #2
def _bm_pulsar_backward():
    n_points = 1_000_000
    width = 1_000
    height = 1_000
    renderer = Renderer(width, height, n_points)
    # Generate sample data.
    torch.manual_seed(1)
    vert_pos = torch.rand(n_points, 3, dtype=torch.float32) * 10.0
    vert_pos[:, 2] += 25.0
    vert_pos[:, :2] -= 5.0
    vert_col = torch.rand(n_points, 3, dtype=torch.float32)
    vert_rad = torch.rand(n_points, dtype=torch.float32)
    cam_params = torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 2.0],
                              dtype=torch.float32)
    device = torch.device("cuda")
    vert_pos = vert_pos.to(device)
    vert_col = vert_col.to(device)
    vert_rad = vert_rad.to(device)
    cam_params = cam_params.to(device)
    renderer = renderer.to(device)
    vert_pos_var = Variable(vert_pos, requires_grad=True)
    vert_col_var = Variable(vert_col, requires_grad=True)
    vert_rad_var = Variable(vert_rad, requires_grad=True)
    cam_params_var = Variable(cam_params, requires_grad=True)
    res = renderer.forward(
        vert_pos_var,
        vert_col_var,
        vert_rad_var,
        cam_params_var,
        1.0e-1,
        45.0,
        percent_allowed_difference=0.01,
    )
    loss = res.sum()

    def bm_closure():
        loss.backward(retain_graph=True)
        torch.cuda.synchronize()

    return bm_closure
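The function returns a closure rather than running the benchmark itself. A sketch of how it might be timed, assuming a plain wall-clock loop rather than whatever harness the original benchmark suite uses:

import time

closure = _bm_pulsar_backward()
closure()  # Warm-up; gradients accumulate across calls, which is fine for timing.
start = time.time()
for _ in range(10):
    closure()
print("mean backward pass: %.4f s" % ((time.time() - start) / 10.0))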
Example #3
    def test_basic(self):
        """Basic forward test."""
        from pytorch3d.renderer.points.pulsar import Renderer
        import torch

        n_points = 10
        width = 1_000
        height = 1_000
        renderer_1 = Renderer(width, height, n_points, n_channels=1)
        renderer_3 = Renderer(width, height, n_points, n_channels=3)
        renderer_8 = Renderer(width, height, n_points, n_channels=8)
        # Generate sample data.
        torch.manual_seed(1)
        vert_pos = torch.rand(n_points, 3, dtype=torch.float32) * 10.0
        vert_pos[:, 2] += 25.0
        vert_pos[:, :2] -= 5.0
        vert_col = torch.rand(n_points, 8, dtype=torch.float32)
        vert_rad = torch.rand(n_points, dtype=torch.float32)
        cam_params = torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 2.0],
                                  dtype=torch.float32)
        for device in devices:
            vert_pos = vert_pos.to(device)
            vert_col = vert_col.to(device)
            vert_rad = vert_rad.to(device)
            cam_params = cam_params.to(device)
            renderer_1 = renderer_1.to(device)
            renderer_3 = renderer_3.to(device)
            renderer_8 = renderer_8.to(device)
            result_1 = (renderer_1.forward(
                vert_pos,
                vert_col[:, :1],
                vert_rad,
                cam_params,
                1.0e-1,
                45.0,
                percent_allowed_difference=0.01,
            ).cpu().detach().numpy())
            hits_1 = (renderer_1.forward(
                vert_pos,
                vert_col[:, :1],
                vert_rad,
                cam_params,
                1.0e-1,
                45.0,
                percent_allowed_difference=0.01,
                mode=1,
            ).cpu().detach().numpy())
            result_3 = (renderer_3.forward(
                vert_pos,
                vert_col[:, :3],
                vert_rad,
                cam_params,
                1.0e-1,
                45.0,
                percent_allowed_difference=0.01,
            ).cpu().detach().numpy())
            hits_3 = (renderer_3.forward(
                vert_pos,
                vert_col[:, :3],
                vert_rad,
                cam_params,
                1.0e-1,
                45.0,
                percent_allowed_difference=0.01,
                mode=1,
            ).cpu().detach().numpy())
            result_8 = (renderer_8.forward(
                vert_pos,
                vert_col,
                vert_rad,
                cam_params,
                1.0e-1,
                45.0,
                percent_allowed_difference=0.01,
            ).cpu().detach().numpy())
            hits_8 = (renderer_8.forward(
                vert_pos,
                vert_col,
                vert_rad,
                cam_params,
                1.0e-1,
                45.0,
                percent_allowed_difference=0.01,
                mode=1,
            ).cpu().detach().numpy())
            self.assertClose(result_1, result_3[:, :, :1])
            self.assertClose(result_3, result_8[:, :, :3])
            self.assertClose(hits_1, hits_3)
            self.assertClose(hits_8, hits_3)
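The assertions at the end confirm that the channel count only changes the width of the output: the 1- and 3-channel renders match the leading channels of the 8-channel render, and the hit maps are identical regardless of how many channels are used.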
Example #4
    def test_principal_point(self):
        """Test shifting the principal point."""
        from pytorch3d.renderer.points.pulsar import Renderer

        LOGGER.info("Setting up rendering test for shifted principal point...")
        n_points = 1
        width = 1_000
        height = 1_000
        renderer = Renderer(width, height, n_points, n_channels=1)
        vert_pos = torch.tensor([[0.0, 0.0, 25.0]], dtype=torch.float32)
        vert_col = torch.tensor([[0.0]], dtype=torch.float32)
        vert_rad = torch.tensor([1.0], dtype=torch.float32)
        cam_params = torch.tensor(
            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 2.0, 0.0, 0.0],
            dtype=torch.float32)
        for device in devices:
            vert_pos = vert_pos.to(device)
            vert_col = vert_col.to(device)
            vert_rad = vert_rad.to(device)
            cam_params = cam_params.to(device)
            cam_params[-2] = -250.0
            cam_params[-1] = -250.0
            renderer = renderer.to(device)
            LOGGER.info("Rendering...")
            # Measurements.
            result = renderer.forward(vert_pos, vert_col, vert_rad, cam_params,
                                      1.0e-1, 45.0)
            if not os.environ.get("FB_TEST", False):
                imageio.imsave(
                    path.join(
                        path.dirname(__file__),
                        "test_out",
                        "test_forward_TestForward_test_principal_point.png",
                    ),
                    (result * 255.0).cpu().to(torch.uint8).numpy(),
                )
            self.assertTrue(
                np.allclose(result[750, 750, :].cpu().numpy(), [0.0],
                            rtol=1e-2,
                            atol=1e-2))
        for device in devices:
            vert_pos = vert_pos.to(device)
            vert_col = vert_col.to(device)
            vert_rad = vert_rad.to(device)
            cam_params = cam_params.to(device)
            cam_params[-2] = 250.0
            cam_params[-1] = 250.0
            renderer = renderer.to(device)
            LOGGER.info("Rendering...")
            # Measurements.
            result = renderer.forward(vert_pos, vert_col, vert_rad, cam_params,
                                      1.0e-1, 45.0)
            if not os.environ.get("FB_TEST", False):
                imageio.imsave(
                    path.join(
                        path.dirname(__file__),
                        "test_out",
                        "test_forward_TestForward_test_principal_point.png",
                    ),
                    (result * 255.0).cpu().to(torch.uint8).numpy(),
                )
            self.assertTrue(
                np.allclose(result[250, 250, :].cpu().numpy(), [0.0],
                            rtol=1e-2,
                            atol=1e-2))
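Compared with the other examples, cam_params here has ten entries instead of eight; the trailing pair holds the principal-point shift that the test varies. With a shift of (-250, -250) the check moves from the image centre (500, 500) used elsewhere to pixel (750, 750), and with (+250, +250) to pixel (250, 250), verifying that the rendered sphere follows the shifted principal point.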
Example #5
    def test_basic(self):
        """Basic forward test."""
        from pytorch3d.renderer.points.pulsar import Renderer

        n_points = 10
        width = 1000
        height = 1000
        renderer_left = Renderer(width,
                                 height,
                                 n_points,
                                 right_handed_system=False)
        renderer_right = Renderer(width,
                                  height,
                                  n_points,
                                  right_handed_system=True)
        # Generate sample data.
        torch.manual_seed(1)
        vert_pos = torch.rand(n_points, 3, dtype=torch.float32) * 10.0
        vert_pos[:, 2] += 25.0
        vert_pos[:, :2] -= 5.0
        vert_pos_neg = vert_pos.clone()
        vert_pos_neg[:, 2] *= -1.0
        vert_col = torch.rand(n_points, 3, dtype=torch.float32)
        vert_rad = torch.rand(n_points, dtype=torch.float32)
        cam_params = torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 2.0],
                                  dtype=torch.float32)
        for device in devices:
            vert_pos = vert_pos.to(device)
            vert_pos_neg = vert_pos_neg.to(device)
            vert_col = vert_col.to(device)
            vert_rad = vert_rad.to(device)
            cam_params = cam_params.to(device)
            renderer_left = renderer_left.to(device)
            renderer_right = renderer_right.to(device)
            result_left = (renderer_left.forward(
                vert_pos,
                vert_col,
                vert_rad,
                cam_params,
                1.0e-1,
                45.0,
                percent_allowed_difference=0.01,
            ).cpu().detach().numpy())
            hits_left = (renderer_left.forward(
                vert_pos,
                vert_col,
                vert_rad,
                cam_params,
                1.0e-1,
                45.0,
                percent_allowed_difference=0.01,
                mode=1,
            ).cpu().detach().numpy())
            result_right = (renderer_right.forward(
                vert_pos_neg,
                vert_col,
                vert_rad,
                cam_params,
                1.0e-1,
                45.0,
                percent_allowed_difference=0.01,
            ).cpu().detach().numpy())
            hits_right = (renderer_right.forward(
                vert_pos_neg,
                vert_col,
                vert_rad,
                cam_params,
                1.0e-1,
                45.0,
                percent_allowed_difference=0.01,
                mode=1,
            ).cpu().detach().numpy())
            self.assertClose(result_left, result_right)
            self.assertClose(hits_left, hits_right)
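The closing assertions tie the two coordinate conventions together: rendering vert_pos with right_handed_system=False and the z-negated copy vert_pos_neg with right_handed_system=True yields the same image and the same hit map.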