Example #1
    def compute(self, points: torch.Tensor, sdf: torch.Tensor,
                mesh_gt: Meshes):
        """
        Rasterize the ground-truth mesh faces from a distant camera facing the
        origin, transform the predicted point positions to camera view, and
        project them to get normalized image coordinates.
        The number of zbuf entries at those image coordinates that are larger
        than a predicted point's depth determines the sign of its sdf.
        """
        assert (points.ndim == 2 and points.shape[-1] == 3)
        device = points.device
        faces_per_pixel = 4
        with torch.no_grad():
            # use a point that is definitely outside the mesh as the camera center
            ray0 = torch.tensor([2, 2, 2], device=device,
                                dtype=points.dtype).view(1, 3)
            R, T = look_at_view_transform(eye=ray0,
                                          at=((0, 0, 0), ),
                                          up=((0, 0, 1), ))
            cameras = PerspectiveCameras(R=R, T=T, device=device)
            rasterizer = MeshRasterizer(cameras=cameras,
                                        raster_settings=RasterizationSettings(
                                            faces_per_pixel=faces_per_pixel, ))
            fragments = rasterizer(mesh_gt)

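            # depth of each predicted point in the camera's view space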
            z_predicted = cameras.get_world_to_view_transform(
            ).transform_points(points=points.unsqueeze(0))[..., -1:]
            # negate to convert PyTorch3D NDC (+X left, +Y up) to grid_sample's
            # convention, where the top-left corner has the smallest values
            screen_xy = -cameras.transform_points(points.unsqueeze(0))[..., :2]
            outside_screen = (screen_xy.abs() > 1.0).any(dim=-1)

            # pix_to_face, zbuf, bary_coords, dists
            assert (fragments.zbuf.shape[-1] == faces_per_pixel)
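            # sample the per-pixel depth layers at each point's projected location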
            zbuf = torch.nn.functional.grid_sample(
                fragments.zbuf.permute(0, 3, 1, 2),
                screen_xy.clamp(-1.0, 1.0).view(1, -1, 1, 2),
                align_corners=False,
                mode='nearest')
            zbuf[outside_screen.unsqueeze(1).expand(-1, zbuf.shape[1],
                                                    -1)] = -1.0
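            # parity test: an even number of surfaces behind a point means the
            # point lies outside the (closed) mesh, so its sdf sign is positive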
            sign = (((zbuf > z_predicted).sum(dim=1) %
                     2) == 0).type_as(points).view(screen_xy.shape[1])
            sign = sign * 2 - 1

        pcls = PointClouds3D(points.unsqueeze(0)).to(device=device)

        points_first_idx = pcls.cloud_to_packed_first_idx()
        max_points = pcls.num_points_per_cloud().max().item()

        # packed representation for faces
        verts_packed = mesh_gt.verts_packed()
        faces_packed = mesh_gt.faces_packed()
        tris = verts_packed[faces_packed]  # (T, 3, 3)
        tris_first_idx = mesh_gt.mesh_to_faces_packed_first_idx()
        max_tris = mesh_gt.num_faces_per_mesh().max().item()

        # point to face distance: shape (P,)
        point_to_face = point_face_distance(points, points_first_idx, tris,
                                            tris_first_idx, max_points)
        point_to_face = sign * torch.sqrt(eps_sqrt(point_to_face))
        loss = (point_to_face - sdf)**2
        return loss
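
The sign trick above can be exercised in isolation: given the depth layers a ray sees and a query depth, the parity of the number of surfaces behind the query point decides inside versus outside. Below is a minimal pure-PyTorch sketch of just that rule; the helper name `parity_sign` and the synthetic depth values are made up for illustration and are not part of the code above.

import torch

def parity_sign(zbuf: torch.Tensor, z_query: torch.Tensor) -> torch.Tensor:
    """zbuf: (K, P) depths of the K nearest surfaces along each point's ray,
    with missing hits encoded as -1; z_query: (P,) view-space point depths.
    Returns +1 where a point lies outside the closed mesh, -1 where inside."""
    # Count the surfaces farther from the camera than the query point; for a
    # closed mesh an even count means the ray exits the surface as often as
    # it enters behind the point, so the point is outside.
    behind = (zbuf > z_query).sum(dim=0)
    outside = (behind % 2) == 0
    return outside.to(z_query.dtype) * 2 - 1

# One ray through a closed surface: entry at z=1, exit at z=3.
zbuf = torch.tensor([[1.0, 1.0], [3.0, 3.0]])
z_query = torch.tensor([0.5, 2.0])  # in front of the mesh / between the layers
print(parity_sign(zbuf, z_query))   # tensor([ 1., -1.])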
Example #2
    def test_ndc_grid_sample_rendering(self):
        """
        Use PyTorch3D point renderer to render a colored point cloud, then
        sample the image at the locations of the point projections with
        `ndc_grid_sample`. Finally, assert that the sampled colors are equal to the
        original point cloud colors.

        Note that, in order to ensure correctness, we use a nearest-neighbor
        assignment point renderer (i.e. no soft splatting).
        """

        # generate 3D points on a regular grid lying in a constant-z plane
        n_grid_pts = 10
        grid_scale = 0.9
        z_plane = 2.0
        image_size = [128, 128]
        point_radius = 0.015
        n_pts = n_grid_pts * n_grid_pts
        pts = torch.stack(
            meshgrid_ij([torch.linspace(-grid_scale, grid_scale, n_grid_pts)] * 2),
            dim=-1,
        )
        pts = torch.cat([pts, z_plane * torch.ones_like(pts[..., :1])], dim=-1)
        pts = pts.reshape(1, n_pts, 3)

        # color the points randomly
        pts_colors = torch.rand(1, n_pts, 3)

        # make trivial rendering cameras
        cameras = PerspectiveCameras(
            R=eyes(dim=3, N=1),
            device=pts.device,
            T=torch.zeros(1, 3, dtype=torch.float32, device=pts.device),
        )

        # render the point cloud
        pcl = Pointclouds(points=pts, features=pts_colors)
        renderer = NearestNeighborPointsRenderer(
            rasterizer=PointsRasterizer(
                cameras=cameras,
                raster_settings=PointsRasterizationSettings(
                    image_size=image_size,
                    radius=point_radius,
                    points_per_pixel=1,
                ),
            ),
            compositor=AlphaCompositor(),
        )
        im_render = renderer(pcl)

        # sample the render at projected pts
        pts_proj = cameras.transform_points(pcl.points_padded())[..., :2]
        pts_colors_sampled = ndc_grid_sample(
            im_render,
            pts_proj,
            mode="nearest",
            align_corners=False,
        ).permute(0, 2, 1)

        # assert that the sampled colors match the original point colors
        self.assertClose(pts_colors, pts_colors_sampled, atol=1e-4)
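
What `ndc_grid_sample` has to get right for this test to pass is the coordinate-convention flip: PyTorch3D's NDC space points +X left and +Y up, while `torch.nn.functional.grid_sample` expects +x right and +y down, so the projected coordinates must be negated before sampling (Example #1 does the same thing by hand with `screen_xy = -cameras.transform_points(...)`). A minimal stand-in, assuming a square image so the NDC range matches grid_sample's on both axes; the name `ndc_sample` is hypothetical and this is not the real function's signature:

import torch
import torch.nn.functional as F

def ndc_sample(image: torch.Tensor, xy_ndc: torch.Tensor, **kwargs) -> torch.Tensor:
    """image: (N, C, H, W); xy_ndc: (N, P, 2) PyTorch3D NDC coordinates.
    Returns (N, C, P) features sampled at those locations."""
    # Negating both axes converts PyTorch3D NDC (+X left, +Y up) into
    # grid_sample's convention, where (-1, -1) is the top-left corner.
    grid = -xy_ndc.view(xy_ndc.shape[0], -1, 1, 2)
    out = F.grid_sample(image, grid, **kwargs)  # (N, C, P, 1)
    return out[..., 0]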