def test_invalid_inputs_shapes(self, device="cuda:0"):
    with self.assertRaisesRegex(ValueError, "input can only be 2-dimensional."):
        values = torch.rand((100, 50, 2), device=device)
        first_idxs = torch.tensor([0, 80], dtype=torch.int64, device=device)
        packed_to_padded(values, first_idxs, 100)
    with self.assertRaisesRegex(ValueError, "input can only be 3-dimensional."):
        values = torch.rand((100,), device=device)
        first_idxs = torch.tensor([0, 80], dtype=torch.int64, device=device)
        padded_to_packed(values, first_idxs, 20)
    with self.assertRaisesRegex(ValueError, "input can only be 3-dimensional."):
        values = torch.rand((100, 50, 2, 2), device=device)
        first_idxs = torch.tensor([0, 80], dtype=torch.int64, device=device)
        padded_to_packed(values, first_idxs, 20)
def _test_packed_to_padded_helper(self, D, device):
    """
    Check that the results from packed_to_padded and the pure PyTorch
    implementation are the same.
    """
    meshes = self.init_meshes(16, 100, 300, device=device)
    faces = meshes.faces_packed()
    mesh_to_faces_packed_first_idx = meshes.mesh_to_faces_packed_first_idx()
    max_faces = meshes.num_faces_per_mesh().max().item()

    if D == 0:
        values = torch.rand((faces.shape[0],), device=device,
                            requires_grad=True)
    else:
        values = torch.rand((faces.shape[0], D), device=device,
                            requires_grad=True)
    values_torch = values.detach().clone()
    values_torch.requires_grad = True
    values_padded = packed_to_padded(values, mesh_to_faces_packed_first_idx,
                                     max_faces)
    values_padded_torch = TestPackedToPadded.packed_to_padded_python(
        values_torch, mesh_to_faces_packed_first_idx, max_faces, device)
    # check forward
    self.assertClose(values_padded, values_padded_torch)

    # check backward
    if D == 0:
        grad_inputs = torch.rand((len(meshes), max_faces), device=device)
    else:
        grad_inputs = torch.rand((len(meshes), max_faces, D), device=device)
    values_padded.backward(grad_inputs)
    grad_outputs = values.grad
    values_padded_torch.backward(grad_inputs)
    grad_outputs_torch1 = values_torch.grad
    grad_outputs_torch2 = TestPackedToPadded.padded_to_packed_python(
        grad_inputs,
        mesh_to_faces_packed_first_idx,
        values.size(0),
        device=device,
    )
    self.assertClose(grad_outputs, grad_outputs_torch1)
    self.assertClose(grad_outputs, grad_outputs_torch2)
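# For reference, the pure-PyTorch baseline compared against above can be
# written as a simple per-mesh loop. A minimal sketch (the actual
# TestPackedToPadded.packed_to_padded_python staticmethod may differ in
# detail; padded_to_packed_python is the analogous inverse):
@staticmethod
def packed_to_padded_python(inputs, first_idxs, max_size, device):
    # scatter packed rows into a zero-initialized (N, max_size, ...) tensor
    num_meshes = first_idxs.size(0)
    out_shape = (num_meshes, max_size) + tuple(inputs.shape[1:])
    inputs_padded = torch.zeros(out_shape, device=device)
    for m in range(num_meshes):
        start = first_idxs[m]
        end = first_idxs[m + 1] if m + 1 < num_meshes else inputs.shape[0]
        inputs_padded[m, :end - start] = inputs[start:end]
    return inputs_padded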
def reduce_mask_padded(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """
    Remove the invalid values from a padded tensor as much as possible.

    Args:
        values: (N, P, ..., C) tensor of numbers.
        mask: (N, P, ...) boolean tensor.

    Returns:
        reduced: (N, Pmax, ..., C) tensor, where Pmax is the maximum number
            of True values over the batches of mask.
    """
    from pytorch3d.ops import packed_to_padded

    batch_size = values.shape[0]
    value_packed = values[mask]
    first_idx = torch.zeros((values.shape[0],), device=values.device,
                            dtype=torch.long)
    num_true_in_batch = mask.view(batch_size, -1).sum(dim=1)
    first_idx[1:] = num_true_in_batch.cumsum(dim=0)[:-1]
    value_padded = packed_to_padded(value_packed, first_idx,
                                    num_true_in_batch.max().item())
    return value_padded
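# Usage sketch for reduce_mask_padded (illustrative shapes and values only):
# keep only masked entries per batch and re-pad to the smallest common length.
values = torch.rand((2, 4, 3))                      # (N=2, P=4, C=3)
mask = torch.tensor([[True, False, True, False],
                     [True, True, True, False]])    # 2 and 3 valid entries
reduced = reduce_mask_padded(values, mask)          # -> (2, 3, 3), zero-padded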
def out():
    packed_to_padded(values, mesh_to_faces_packed_first_idx, max_faces)
    torch.cuda.synchronize()
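# A minimal sketch of how such a closure is typically consumed by a timing
# harness (hypothetical helper; the real benchmark driver may differ).
# The closure captures `values`, `mesh_to_faces_packed_first_idx` and
# `max_faces` from the enclosing benchmark function:
import time

def benchmark(fn, iters=100):
    fn()  # warm-up run to exclude one-time CUDA setup costs
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.time() - start) / iters  # average seconds per call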
def backward(ctx, idx_grad, zbuf_grad, qvalue_grad, occ_grad):
    # idx_grad and zbuf_grad are None (unless we make the weights depend on z,
    # i.e. volumetric splatting)
    grad_radii = None
    grad_cloud_to_packed_first_idx = None
    grad_num_points_per_cloud = None
    grad_cutoff_thres = None
    grad_depth_merging_thres = None
    grad_image_size = None
    grad_points_per_pixel = None
    grad_bin_size = None
    grad_max_points_per_bin = None
    grad_radii_s = None
    grad_backward_rbf = None
    grads = (grad_cutoff_thres, grad_radii, grad_cloud_to_packed_first_idx,
             grad_num_points_per_cloud, grad_depth_merging_thres,
             grad_image_size, grad_points_per_pixel, grad_bin_size,
             grad_max_points_per_bin, grad_radii_s, grad_backward_rbf)
    radii_s = ctx.radii_backward_scaler
    # either use OccRBFBackward or use OccBackward
    (pts_screen, ellipse_param, cutoff_threshold, radii, idx, zbuf0,
     cloud_to_packed_first_idx, num_points_per_cloud) = ctx.saved_tensors
    depth_merging_threshold = ctx.depth_merging_threshold
    backward_occ_fast = True
    if not backward_occ_fast:
        device = pts_screen.device
        grads_input_xy = pts_screen.new_zeros((pts_screen.shape[0], 2))
        grads_input_z = pts_screen.new_zeros((pts_screen.shape[0], 1))
        mask = (idx[..., 0] >= 0).bool()
        pts_visibility = torch.full((pts_screen.shape[0],), False,
                                    dtype=torch.bool, device=pts_screen.device)
        # all rendered points (indices in packed points)
        visible_idx = idx[mask].unique().long().view(-1)
        visible_idx = visible_idx[visible_idx >= 0]
        pts_visibility[visible_idx] = True
        num_points_per_cloud = torch.stack([
            x.sum() for x in torch.split(
                pts_visibility, num_points_per_cloud.tolist(), dim=0)
        ])
        cloud_to_packed_first_idx = num_points_2_cloud_to_packed_first_idx(
            num_points_per_cloud)
        pts_screen = pts_screen[pts_visibility]
        radii = radii[pts_visibility]
        grad_visible = _C._splat_points_occ_backward(
            pts_screen, radii, occ_grad, cloud_to_packed_first_idx,
            num_points_per_cloud, radii_s, depth_merging_threshold)
        if torch.isnan(grad_visible).any() or \
                not torch.isfinite(grad_visible).all():
            print('invalid grad_visible')
        assert pts_visibility.sum() == grad_visible.shape[0]
        grads_input_xy[pts_visibility] = grad_visible
        _C._backward_zbuf(idx, zbuf_grad, grads_input_z)
        # TODO: is the concatenation necessary?
        grads_input = torch.cat([grads_input_xy, grads_input_z], dim=-1)
    else:
        # We only care about rasterized (visible) points:
        # 1. Filter [P, *] data down to [P_visible, *] data
        # 2. Fast backward in CUDA
        #    2a. call FRNN insertion
        #    2b. counting sort
        device = pts_screen.device
        mask = (idx[..., 0] >= 0).bool()
        pts_visibility = torch.full((pts_screen.shape[0],), False,
                                    dtype=torch.bool, device=pts_screen.device)
        # all rendered points (indices in packed points)
        visible_idx = idx[mask].unique().long().view(-1)
        visible_idx = visible_idx[visible_idx >= 0]
        pts_visibility[visible_idx] = True
        num_points_per_cloud = torch.stack([
            x.sum() for x in torch.split(
                pts_visibility, num_points_per_cloud.tolist(), dim=0)
        ])
        cloud_to_packed_first_idx = num_points_2_cloud_to_packed_first_idx(
            num_points_per_cloud)
        pts_screen_visible = pts_screen[pts_visibility]
        radii_visible = radii[pts_visibility]

        #####################################
        # 2a. call FRNN insertion
        #####################################
        N = num_points_per_cloud.shape[0]
        P = pts_screen_visible.shape[0]
        assert num_points_per_cloud.sum().item() == P
        # from frnn.frnn import GRID_PARAMS_SIZE, MAX_RES, prefix_sum_cuda
        # prefix_sum_cuda is imported via: from prefix_sum import prefix_sum_cuda
        GRID_2D_PARAMS_SIZE = 6
        GRID_2D_MAX_RES = 1024
        GRID_2D_DELTA = 2
        GRID_2D_TOTAL = 5
        RADIUS_CELL_RATIO = 2
        # first convert to padded
        max_P = num_points_per_cloud.max().item()
        pts_padded = ops3d.packed_to_padded(
            pts_screen_visible, cloud_to_packed_first_idx, max_P)
        radii_padded = ops3d.packed_to_padded(
            radii_visible, cloud_to_packed_first_idx, max_P)
        # determine the search radius as median(radii) * radii_s
        search_radius = torch.tensor(
            [radii_padded[i, :num_points_per_cloud[i]].median() * radii_s
             for i in range(N)],
            dtype=torch.float, device=device)
        # create the grid from scratch: set up the grid parameters
        grid_params_cuda = torch.zeros((N, GRID_2D_PARAMS_SIZE),
                                       dtype=torch.float,
                                       device=pts_padded.device)
        G = -1
        pts_padded_2D = pts_padded[:, :, :2].clone().contiguous()
        for i in range(N):
            # per-row layout: 0-1 grid_min; 2 grid_delta; 3-4 grid_res;
            # 5 grid_total
            grid_min = pts_padded_2D[i, :num_points_per_cloud[i]].min(dim=0)[0]
            grid_max = pts_padded_2D[i, :num_points_per_cloud[i]].max(dim=0)[0]
            grid_params_cuda[i, :GRID_2D_DELTA] = grid_min
            grid_size = grid_max - grid_min
            cell_size = search_radius[i].item() / RADIUS_CELL_RATIO
            if cell_size < grid_size.min() / GRID_2D_MAX_RES:
                cell_size = grid_size.min() / GRID_2D_MAX_RES
            grid_params_cuda[i, GRID_2D_DELTA] = 1 / cell_size
            grid_params_cuda[i, GRID_2D_DELTA + 1:GRID_2D_TOTAL] = \
                torch.floor(grid_size / cell_size) + 1
            grid_params_cuda[i, GRID_2D_TOTAL] = torch.prod(
                grid_params_cuda[i, GRID_2D_DELTA + 1:GRID_2D_TOTAL])
            if G < grid_params_cuda[i, GRID_2D_TOTAL]:
                G = int(grid_params_cuda[i, GRID_2D_TOTAL].item())
        # insert points into the grid
        pc_grid_cnt = torch.zeros((N, G), dtype=torch.int, device=device)
        pc_grid_cell = torch.full((N, max_P), -1, dtype=torch.int,
                                  device=device)
        pc_grid_idx = torch.full((N, max_P), -1, dtype=torch.int,
                                 device=device)
        frnn._C.insert_points_cuda(pts_padded_2D, num_points_per_cloud,
                                   grid_params_cuda, pc_grid_cnt,
                                   pc_grid_cell, pc_grid_idx, G)
        # use prefix_sum from Matt Dean
        grid_params = grid_params_cuda.cpu()
        pc_grid_off = torch.full((N, G), 0, dtype=torch.int, device=device)
        for i in range(N):
            prefix_sum_cuda(pc_grid_cnt[i], grid_params[i, GRID_2D_TOTAL],
                            pc_grid_off[i])

        #####################################
        # 2b. counting sort
        #####################################
        # sort points according to their grid positions and insertion order:
        # sort based on x, y first, then use points_sorted_idxs to recover
        # points_sorted with z
        points_sorted = torch.zeros((N, max_P, 2), dtype=torch.float,
                                    device=device)
        points_sorted_idxs = torch.full((N, max_P), -1, dtype=torch.int,
                                        device=device)
        frnn._C.counting_sort_cuda(
            pts_padded_2D, num_points_per_cloud, pc_grid_cell, pc_grid_idx,
            pc_grid_off,
            points_sorted,       # (N, P, 2)
            points_sorted_idxs,  # (N, P)
        )
        new_points_sorted = torch.zeros_like(pts_padded)
        for i in range(N):
            points_sorted_idxs_i = points_sorted_idxs[
                i, :num_points_per_cloud[i]].long().unsqueeze(1).expand(-1, 3)
            new_points_sorted[i, :num_points_per_cloud[i]] = torch.gather(
                pts_padded[i], 0, points_sorted_idxs_i)
        assert (new_points_sorted is not None and pc_grid_off is not None
                and points_sorted_idxs is not None
                and grid_params_cuda is not None)
        # convert sorted_points and sorted_points_idxs to packed (P,) form;
        # padded_to_packed only supports torch.float32, so shift the indices
        # as floats by the packed offsets and cast back to long afterwards
        points_sorted = ops3d.padded_to_packed(new_points_sorted,
                                               cloud_to_packed_first_idx, P)
        shifted_points_sorted_idxs = points_sorted_idxs + \
            cloud_to_packed_first_idx.float().unsqueeze(1)
        points_sorted_idxs = ops3d.padded_to_packed(
            shifted_points_sorted_idxs, cloud_to_packed_first_idx, P)
        points_sorted_idxs_2D = points_sorted_idxs.long().unsqueeze(
            1).expand(-1, 2)
        radii_sorted = torch.gather(radii_visible, 0, points_sorted_idxs_2D)
        pc_grid_off += cloud_to_packed_first_idx.unsqueeze(1)
        grad_sorted = _C._splat_points_occ_fast_cuda_backward(
            points_sorted, radii_sorted, search_radius, occ_grad,
            num_points_per_cloud, cloud_to_packed_first_idx, pc_grid_off,
            grid_params_cuda)
        # scatter the gradients of the sorted (visible) points back to their
        # pre-sort order
        grad_visible = torch.zeros_like(grad_sorted).scatter_(
            0, points_sorted_idxs_2D, grad_sorted)
        if torch.isnan(grad_visible).any() or \
                not torch.isfinite(grad_visible).all():
            print('invalid grad_visible')
        assert pts_visibility.sum() == grad_visible.shape[0]
        grads_input_xy = pts_screen.new_zeros(pts_screen.shape[0], 2)
        grads_input_z = pts_screen.new_zeros(pts_screen.shape[0], 1)
        grads_input_xy[pts_visibility] = grad_visible
        _C._backward_zbuf(idx, zbuf_grad, grads_input_z)
        grads_input = torch.cat([grads_input_xy, grads_input_z], dim=-1)

    pts_grad = grads_input
    return (pts_grad, None) + grads
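# num_points_2_cloud_to_packed_first_idx, used twice in backward above, is
# not defined in this snippet. A minimal sketch, assuming it is the usual
# exclusive prefix sum over per-cloud point counts (the same pattern used to
# build first_idx elsewhere in this section):
def num_points_2_cloud_to_packed_first_idx(num_points_per_cloud):
    # first packed index of cloud i = sum of the counts of clouds 0..i-1
    first_idx = torch.zeros_like(num_points_per_cloud)
    first_idx[1:] = num_points_per_cloud.cumsum(dim=0)[:-1]
    return first_idx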
def forward(self, point_clouds, point_clouds_filter=None,
            **kwargs) -> PointFragments:
    """
    Args:
        point_clouds (Pointclouds3D): a set of point clouds with coordinates.
        per_point_info (dict):
            radii_packed: (N, 2) axis-aligned radii in packed form
            ellipse_params_packed: (N, 3) ellipse parameters in packed form

    Returns:
        PointFragments: Rasterization outputs as a named tuple.
    """
    raster_settings = kwargs.get("raster_settings", self.raster_settings)
    cameras = kwargs.get('cameras', self.cameras)
    max_P = point_clouds.num_points_per_cloud().max().item()
    total_P = point_clouds.num_points_per_cloud().sum().item()

    point_clouds_filtered, mask_filtered = self.filter_renderable(
        point_clouds, point_clouds_filter, **kwargs)

    if point_clouds_filtered.isempty():
        return self._empty_fragments(cameras.R.shape[0], **kwargs)

    # compute per-point features for elliptical gaussian weights
    with torch.autograd.no_grad():
        per_point_info = self._get_per_point_info(point_clouds_filtered,
                                                  **kwargs)

    _tmp = point_clouds_filtered.points_padded()
    if _tmp.requires_grad:
        _tmp.register_hook(lambda x: _check_grad(x, 'transform'))

    pcls_screen = self.transform(point_clouds_filtered, **kwargs)

    idx, zbuf, qvalue_map, occ_map = rasterize_elliptical_points(
        pcls_screen,
        per_point_info["ellipse_params"],
        per_point_info['cutoff_threshold'],
        per_point_info["radii"],
        depth_merging_threshold=raster_settings.depth_merging_threshold,
        image_size=raster_settings.image_size,
        points_per_pixel=raster_settings.points_per_pixel,
        bin_size=raster_settings.bin_size,
        max_points_per_bin=raster_settings.max_points_per_bin,
        radii_backward_scaler=raster_settings.radii_backward_scaler,
        clip_pts_grad=raster_settings.clip_pts_grad)

    # compute weight: scaler * exp(-0.5 * Q)
    frag_scaler = gather_with_neg_idx(per_point_info['scaler'], 0,
                                      idx.view(-1).long())
    frag_scaler = frag_scaler.view_as(qvalue_map)

    fragments = PointFragments(idx=idx, zbuf=zbuf, qvalue=qvalue_map,
                               scaler=frag_scaler, occupancy=occ_map)

    # returns a (P,) boolean mask for visibility
    visibility_mask = get_per_point_visibility_mask(point_clouds_filtered,
                                                    fragments)
    mask_filtered[mask_filtered] = visibility_mask

    if point_clouds_filter is not None:
        # update the point_clouds visibility filter
        # (we use this information in the projection loss):
        # put the packed visibility mask (num_active,) into the original
        # (P,) mask, then transform to padded
        original_visibility_mask = ops3d.packed_to_padded(
            mask_filtered.float(),
            point_clouds.cloud_to_packed_first_idx(), max_P).bool()
        point_clouds_filter.set_filter(visibility=original_visibility_mask)

    if kwargs.get('verbose', False):
        # use scatter to get per-point info for the original point cloud
        original_per_point_info = {}
        for k in per_point_info:
            original_per_point_info[k] = per_point_info[k].new_zeros(
                (total_P,) + per_point_info[k].shape[1:])
            original_per_point_info[k][mask_filtered] = per_point_info[k]
        return fragments, point_clouds_filtered, original_per_point_info

    return fragments, point_clouds_filtered
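# gather_with_neg_idx (used above for frag_scaler) must behave like
# torch.gather while tolerating the -1 "empty pixel" entries in idx.
# A hypothetical sketch of such a helper (the repo's actual implementation
# may differ, e.g. in what it writes for negative indices):
def gather_with_neg_idx(input: torch.Tensor, dim: int,
                        index: torch.Tensor) -> torch.Tensor:
    # clamp so gather never sees a negative index, then zero out those slots
    out = torch.gather(input, dim, index.clamp(min=0))
    out[index < 0] = 0.0
    return out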
def test_dataset(self):
    # 1. re-render input point clouds / meshes using the saved camera_mat
    #    and compare the mask image with the saved mask image
    # 2. backproject masked points to space with the dense depth map,
    #    fuse all views and save
    batch_size = 1
    device = torch.device('cuda:0')
    data_dir = 'data/synthetic/cube_mesh'
    output_dir = os.path.join('tests', 'outputs', 'test_data')
    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)

    # dataset
    dataset = MVRDataset(data_dir=data_dir, load_dense_depth=True,
                         mode="train")
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=batch_size,
                                              num_workers=0,
                                              shuffle=False)
    meshes = load_objs_as_meshes([os.path.join(data_dir,
                                               'mesh.obj')]).to(device)
    cams = dataset.get_cameras().to(device)
    image_size = imageio.imread(dataset.image_files[0]).shape[0]

    # initialize the rasterizer; we check mask pngs only, so there is no
    # need to create lights, shaders, etc.
    raster_settings = RasterizationSettings(
        image_size=image_size,
        blur_radius=0.0,
        faces_per_pixel=5,
        # this setting controls whether naive or coarse-to-fine
        # rasterization is used
        bin_size=None,
        max_faces_per_bin=None  # this setting is for coarse rasterization
    )
    rasterizer = MeshRasterizer(cameras=None,
                                raster_settings=raster_settings)

    # render with the loaded camera positions and training transformation
    # functions
    pixel_world_all = []
    for idx, data in enumerate(data_loader):
        # get data
        img = data.get('img.rgb').to(device)
        assert img.min() >= 0 and img.max() <= 1, \
            "Image must be a floating point number between 0 and 1."
        mask_gt = data.get('img.mask').to(device).permute(0, 2, 3, 1)
        camera_mat = data['camera_mat'].to(device)

        cams.R, cams.T = decompose_to_R_and_t(camera_mat)
        cams._N = cams.R.shape[0]
        cams.to(device)
        self.assertTrue(
            torch.equal(cams.get_world_to_view_transform().get_matrix(),
                        camera_mat))

        # transform to view space and re-render with a non-rotated camera
        verts_padded = transform_to_camera_space(meshes.verts_padded(),
                                                 cams)
        meshes_in_view = meshes.offset_verts(
            -meshes.verts_packed() + padded_to_packed(
                verts_padded, meshes.mesh_to_verts_packed_first_idx(),
                meshes.verts_packed().shape[0]))
        fragments = rasterizer(meshes_in_view,
                               cameras=dataset.get_cameras().to(device))

        # compare masks
        mask = fragments.pix_to_face[..., :1] >= 0
        imageio.imwrite(os.path.join(output_dir, "mask_%06d.png" % idx),
                        mask[0, ...].cpu().to(dtype=torch.uint8) * 255)
        # allow a difference of up to 5 pixels
        self.assertTrue(torch.sum(mask_gt != mask) < 5)

        # check dense maps: backproject points to the world;
        # pixel range (-1, 1)
        pixels = arange_pixels((image_size, image_size),
                               batch_size)[1].to(device)
        depth_img = data.get('img.depth').to(device)
        # get the depth and mask at the sampled pixel positions
        depth_gt = get_tensor_values(depth_img, pixels,
                                     squeeze_channel_dim=True)
        mask_gt = get_tensor_values(mask.permute(0, 3, 1, 2).float(),
                                    pixels,
                                    squeeze_channel_dim=True).bool()
        # get pixels and depth inside the masked area
        pixels_packed = pixels[mask_gt]
        depth_gt_packed = depth_gt[mask_gt]
        first_idx = torch.zeros((pixels.shape[0],), device=device,
                                dtype=torch.long)
        num_pts_in_mask = mask_gt.sum(dim=1)
        first_idx[1:] = num_pts_in_mask.cumsum(dim=0)[:-1]
        pixels_padded = packed_to_padded(pixels_packed, first_idx,
                                         num_pts_in_mask.max().item())
        depth_gt_padded = packed_to_padded(depth_gt_packed, first_idx,
                                           num_pts_in_mask.max().item())
        # backproject to world coordinates; contains nan and infinite
        # values where depth_gt_padded contains 0.0
        pixel_world_padded = transform_to_world(pixels_padded,
                                                depth_gt_padded[..., None],
                                                cams)
        # transform back to a list containing no padded values
        split_size = num_pts_in_mask[..., None].repeat(1, 2)
        split_size[:, 1] = 3
        pixel_world_list = padded_to_list(pixel_world_padded, split_size)
        pixel_world_all.extend(pixel_world_list)

        idx += 1
        if idx >= 10:
            break

    pixel_world_all = torch.cat(pixel_world_all, dim=0)
    mesh = trimesh.Trimesh(vertices=pixel_world_all.cpu(), faces=None,
                           process=False)
    mesh.export(os.path.join(output_dir, 'pixel_to_world.ply'))
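# The split_size built above pairs each batch element's valid-point count
# with the coordinate dimension, so padded_to_list can crop each padded
# (max_P, 3) slice back to its valid (num_valid, 3) block. Illustrative
# check with hypothetical counts:
num_pts_in_mask = torch.tensor([2, 4])
split_size = num_pts_in_mask[..., None].repeat(1, 2)
split_size[:, 1] = 3
# split_size is now tensor([[2, 3], [4, 3]])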