def compute(self, points: torch.Tensor, sdf: torch.Tensor, mesh_gt: Meshes):
    """
    Squared error between predicted SDF values and the signed distance of
    `points` to the ground-truth mesh.

    The sign comes from a ray-parity test. `mesh_gt` is rasterized from a
    camera at (2, 2, 2) looking at the origin (a position assumed to lie
    outside the mesh). Each point is transformed to that camera's view and
    projected to normalized image coordinates. The number of zbuf entries at
    that pixel that are deeper than the point is then counted. An even count
    gives sign +1 (outside) and an odd count gives -1 (inside). Points that
    project outside the image get a count of 0, i.e. sign +1.

    The magnitude is the point-to-triangle distance to `mesh_gt`, taken as
    the square root of the value returned by `point_face_distance`.

    Args:
        points: (P, 3) query points.
        sdf: predicted signed distance for each point, usually shape (P,).
            If it holds exactly P values in another shape (e.g. (P, 1)), the
            ground-truth signed distances are reshaped to match it instead of
            being broadcast against it.
        mesh_gt: ground-truth mesh. NOTE(review): the zbuf is sampled with a
            batch-1 grid and the points are treated as a single cloud, so this
            presumably expects exactly one mesh -- confirm.

    Returns:
        Per-point squared error `(signed_distance - sdf) ** 2`. It has the
        shape of `sdf` when `sdf` has P elements.

    Limitations:
        Only the nearest `faces_per_pixel` (4) surface layers are rasterized
        per pixel, so rays crossing more surfaces than that may be
        misclassified.
    """
    assert points.ndim == 2 and points.shape[-1] == 3
    device = points.device
    faces_per_pixel = 4

    # The parity sign is piecewise constant, so it is computed without grad.
    # NOTE(review): the distance term below is deliberately outside no_grad so
    # gradients can flow through it -- confirm this is intended.
    with torch.no_grad():
        # a point that is definitely outside the mesh as camera center
        ray0 = torch.tensor([2, 2, 2], device=device, dtype=points.dtype).view(1, 3)
        R, T = look_at_view_transform(eye=ray0, at=((0, 0, 0),), up=((0, 0, 1),))
        cameras = PerspectiveCameras(R=R, T=T, device=device)
        rasterizer = MeshRasterizer(
            cameras=cameras,
            raster_settings=RasterizationSettings(faces_per_pixel=faces_per_pixel),
        )
        # fragments holds pix_to_face, zbuf, bary_coords, dists
        fragments = rasterizer(mesh_gt)
        assert fragments.zbuf.shape[-1] == faces_per_pixel

        points_batch = points.unsqueeze(0)  # (1, P, 3)
        # view-space depth of each point: (1, P, 1)
        z_predicted = cameras.get_world_to_view_transform().transform_points(
            points=points_batch)[..., -1:]
        # normalized pixel coordinates (top-left smallest values): (1, P, 2)
        screen_xy = -cameras.transform_points(points_batch)[..., :2]
        outside_screen = (screen_xy.abs() > 1.0).any(dim=-1)  # (1, P)

        # zbuf (N, H, W, K) -> (N, K, H, W), sampled at P locations -> (N, K, P, 1)
        zbuf = torch.nn.functional.grid_sample(
            fragments.zbuf.permute(0, 3, 1, 2),
            screen_xy.clamp(-1.0, 1.0).view(1, -1, 1, 2),
            align_corners=False,
            mode='nearest')
        # Off-screen points were sampled at the clamped border; discard those.
        zbuf[outside_screen.unsqueeze(1).expand(-1, zbuf.shape[1], -1)] = -1.0

        # count surface layers farther along the ray than the point
        n_behind = (zbuf > z_predicted).sum(dim=1)  # (1, P, 1)
        sign = ((n_behind % 2) == 0).type_as(points).view(screen_xy.shape[1])
        sign = sign * 2 - 1  # {0, 1} -> {-1, +1}

    # A single point cloud: it starts at packed index 0 and has P points.
    points_first_idx = torch.zeros(1, dtype=torch.int64, device=device)
    max_points = points.shape[0]

    # packed representation for faces
    verts_packed = mesh_gt.verts_packed()
    faces_packed = mesh_gt.faces_packed()
    tris = verts_packed[faces_packed]  # (T, 3, 3)
    tris_first_idx = mesh_gt.mesh_to_faces_packed_first_idx()

    # (squared) point to face distance: shape (P,)
    point_to_face = point_face_distance(points, points_first_idx, tris,
                                        tris_first_idx, max_points)
    # eps_sqrt presumably keeps the sqrt gradient finite at zero distance
    point_to_face = sign * torch.sqrt(eps_sqrt(point_to_face))

    # Avoid an accidental (P, P) outer broadcast when sdf is e.g. (P, 1).
    if sdf.shape != point_to_face.shape and sdf.numel() == point_to_face.numel():
        point_to_face = point_to_face.reshape(sdf.shape)

    loss = (point_to_face - sdf)**2
    return loss
def test_ndc_grid_sample_rendering(self):
    """
    Round-trip check of `ndc_grid_sample` against a point-cloud render.

    A regular grid of randomly colored points is placed on the plane
    z = 2 and rendered with a nearest-neighbor point renderer. That renderer
    does no soft splatting, so covered pixels carry a single point's color.
    The render is then sampled with `ndc_grid_sample` (nearest mode) at each
    point's projection. The sampled colors must equal the input colors.
    """
    grid_size = 10
    half_extent = 0.9
    depth = 2.0
    resolution = [128, 128]
    radius = 0.015
    num_points = grid_size * grid_size

    # regular (x, y) grid, every point at z = depth
    axis = torch.linspace(-half_extent, half_extent, grid_size)
    xy = torch.stack(meshgrid_ij([axis, axis]), dim=-1)
    z = torch.full_like(xy[..., :1], depth)
    points = torch.cat([xy, z], dim=-1).reshape(1, num_points, 3)

    # one random RGB color per point
    colors = torch.rand(1, num_points, 3)

    # identity rotation and zero translation
    device = points.device
    camera = PerspectiveCameras(
        R=eyes(dim=3, N=1),
        T=torch.zeros(1, 3, dtype=torch.float32, device=device),
        device=device,
    )

    cloud = Pointclouds(points=points, features=colors)
    settings = PointsRasterizationSettings(
        image_size=resolution,
        radius=radius,
        points_per_pixel=1,
    )
    render = NearestNeighborPointsRenderer(
        rasterizer=PointsRasterizer(cameras=camera, raster_settings=settings),
        compositor=AlphaCompositor(),
    )
    image = render(cloud)

    # read the rendered image back at the projected point locations
    ndc_xy = camera.transform_points(cloud.points_padded())[..., :2]
    sampled = ndc_grid_sample(image, ndc_xy, mode="nearest", align_corners=False)
    sampled = sampled.permute(0, 2, 1)

    self.assertClose(colors, sampled, atol=1e-4)