def precache_rays(self, cameras: List[CamerasBase], camera_hashes: List):
    """
    Precaches the rays emitted from the list of cameras `cameras`, where
    each camera is uniquely identified with the corresponding hash from
    `camera_hashes`.

    The cached rays are moved to cpu and stored in `self._ray_cache`.

    Raises `ValueError` when caching two cameras with the same hash.

    Args:
        cameras: A list of `N` cameras for which the rays are pre-cached.
        camera_hashes: A list of `N` unique identifiers of each camera from
            `cameras`.
    """
    print(f"Precaching {len(cameras)} ray bundles ...")
    full_chunksize = (
        self._grid_raysampler._xy_grid.numel()
        // 2
        * self._grid_raysampler._n_pts_per_ray
    )
    if self.get_n_chunks(full_chunksize, 1) != 1:
        raise ValueError("There has to be one chunk for precaching rays!")

    for camera_i, (camera, camera_hash) in enumerate(zip(cameras, camera_hashes)):
        ray_bundle = self.forward(
            camera,
            caching=True,
            chunksize=full_chunksize,
        )
        if camera_hash in self._ray_cache:
            raise ValueError("There are redundant cameras!")
        self._ray_cache[camera_hash] = RayBundle(
            *[v.to("cpu").detach() for v in ray_bundle]
        )
        self._print_precaching_progress(camera_i, len(cameras))
    print("")
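
# Usage sketch for `precache_rays` (not part of the original file): it assumes
# `raysampler` is an instance of the class defining the method above and
# `train_cameras` is a list of single-camera batches; the hash strings are
# illustrative.
def _example_precache_rays(raysampler, train_cameras):
    camera_hashes = [f"train_cam_{i}" for i in range(len(train_cameras))]
    with torch.no_grad():
        raysampler.precache_rays(train_cameras, camera_hashes)
    # Cached rays are later retrieved by passing the camera's hash to forward;
    # caching only supports batches with a single camera.
    return raysampler(train_cameras[0], camera_hash="train_cam_0")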
def _get_bundle(self, *, device) -> RayBundle:
    """Return a random RayBundle of fixed test dimensions on `device`."""
    origins = torch.rand(_BATCH_SIZE, _N_RAYS, 3, device=device)
    directions = torch.rand(_BATCH_SIZE, _N_RAYS, 3, device=device)
    lengths = torch.rand(_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY, device=device)
    bundle = RayBundle(
        lengths=lengths, origins=origins, directions=directions, xys=None
    )
    return bundle
def forward(
    self,
    input_ray_bundle: RayBundle,
    ray_weights: torch.Tensor,
    **kwargs,
) -> RayBundle:
    """
    Args:
        input_ray_bundle: An instance of `RayBundle` specifying the
            source rays for sampling of the probability distribution.
        ray_weights: A tensor of shape
            `(..., input_ray_bundle.lengths.shape[-1])` with non-negative
            elements defining the probability distribution to sample
            ray points from.

    Returns:
        ray_bundle: A new `RayBundle` instance containing the input ray
            points together with `n_pts_per_ray` additional sampled
            points per ray.
    """
    # Calculate the mid-points between the ray depths.
    z_vals = input_ray_bundle.lengths
    batch_size = z_vals.shape[0]
    z_vals_mid = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1])

    # Carry out the importance sampling.
    z_samples = (
        sample_pdf(
            z_vals_mid.view(-1, z_vals_mid.shape[-1]),
            ray_weights.view(-1, ray_weights.shape[-1])[..., 1:-1],
            self._n_pts_per_ray,
            det=not (
                (self._stratified and self.training)
                or (self._stratified_test and not self.training)
            ),
        )
        .detach()
        .view(batch_size, z_vals.shape[1], self._n_pts_per_ray)
    )

    if self._add_input_samples:
        # Add the new samples to the input ones.
        z_vals = torch.cat((z_vals, z_samples), dim=-1)
    else:
        z_vals = z_samples
    # Resort by depth.
    z_vals, _ = torch.sort(z_vals, dim=-1)

    return RayBundle(
        origins=input_ray_bundle.origins,
        directions=input_ray_bundle.directions,
        lengths=z_vals,
        xys=input_ray_bundle.xys,
    )
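
# A minimal hierarchical-sampling sketch (illustrative, not from the original
# file): `fine_raysampler` is assumed to be an instance of the probabilistic
# raysampler defining the `forward` above, and `coarse_bundle` /
# `coarse_weights` would come from a preceding coarse rendering pass, as in
# NeRF.
def _example_importance_sampling(fine_raysampler, coarse_bundle, coarse_weights):
    # `coarse_weights` has shape (batch, n_rays, n_pts_per_ray) and holds the
    # per-point opacities of the coarse pass; `forward` draws extra depths
    # proportionally to these weights and returns a depth-sorted bundle that
    # reuses the coarse origins, directions and xys.
    return fine_raysampler(coarse_bundle, coarse_weights)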
def _add_struct_from_batch(
    batched_struct: Struct,
    scene_num: int,
    subplot_title: str,
    scene_dictionary: Dict[str, Dict[str, Struct]],
    trace_idx: int = 1,
):  # pragma: no cover
    """
    Adds the struct corresponding to the given scene_num index to
    a provided scene_dictionary to be passed in to plot_scene.

    Args:
        batched_struct: the batched data structure to add to the dict
        scene_num: the subplot from plot_batch_individually which this struct
            should be added to
        subplot_title: the title of the subplot
        scene_dictionary: the dictionary to add the indexed struct to
        trace_idx: the trace number, starting at 1 for this struct's trace
    """
    struct = None
    if isinstance(batched_struct, CamerasBase):
        # we can't index directly into camera batches
        R, T = batched_struct.R, batched_struct.T
        # pyre-fixme[6]: Expected `Sized` for 1st param but got
        #  `Union[torch.Tensor, torch.nn.Module]`.
        r_idx = min(scene_num, len(R) - 1)
        # pyre-fixme[6]: Expected `Sized` for 1st param but got
        #  `Union[torch.Tensor, torch.nn.Module]`.
        t_idx = min(scene_num, len(T) - 1)
        # pyre-fixme[29]:
        #  `Union[BoundMethod[typing.Callable(torch.Tensor.__getitem__)[[Named(self,
        #  torch.Tensor), Named(item, typing.Any)], typing.Any], torch.Tensor],
        #  torch.Tensor, torch.nn.Module]` is not a function.
        R = R[r_idx].unsqueeze(0)
        # pyre-fixme[29]:
        #  `Union[BoundMethod[typing.Callable(torch.Tensor.__getitem__)[[Named(self,
        #  torch.Tensor), Named(item, typing.Any)], typing.Any], torch.Tensor],
        #  torch.Tensor, torch.nn.Module]` is not a function.
        T = T[t_idx].unsqueeze(0)
        struct = CamerasBase(device=batched_struct.device, R=R, T=T)
    elif isinstance(batched_struct, RayBundle):
        # for RayBundle we treat the 1st dim as the batch index
        struct_idx = min(scene_num, len(batched_struct.lengths) - 1)
        struct = RayBundle(
            **{
                attr: getattr(batched_struct, attr)[struct_idx]
                for attr in ["origins", "directions", "lengths", "xys"]
            }
        )
    else:
        # batched meshes and pointclouds are indexable
        struct_idx = min(scene_num, len(batched_struct) - 1)
        struct = batched_struct[struct_idx]
    trace_name = "trace{}-{}".format(scene_num + 1, trace_idx)
    scene_dictionary[subplot_title][trace_name] = struct
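
# Illustrative sketch (not in the original module) of the scene_dictionary
# layout that `_add_struct_from_batch` fills in; plot_scene expects
# {subplot_title: {trace_name: struct}}. `batched_cameras` and
# `batched_bundle` are assumed to be batched inputs.
def _example_scene_dictionary(batched_cameras, batched_bundle):
    scene_dictionary = {"subplot 1": {}}
    # "trace1-1" holds the first camera of the batch, "trace1-2" the first
    # ray bundle.
    _add_struct_from_batch(batched_cameras, 0, "subplot 1", scene_dictionary)
    _add_struct_from_batch(
        batched_bundle, 0, "subplot 1", scene_dictionary, trace_idx=2
    )
    return scene_dictionary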
def _chunk_generator(
    chunk_size: int,
    ray_bundle: RayBundle,
    object_mask: Optional[torch.Tensor],
    tqdm_trigger_threshold: int,
    *args,
    **kwargs,
):
    """
    Helper function which yields chunks of rays from the input ray_bundle, to
    be used when the number of rays is large and will not fit in memory for
    rendering.
    """
    (
        batch_size,
        *spatial_dim,
        n_pts_per_ray,
    ) = ray_bundle.lengths.shape  # B x ... x n_pts_per_ray
    if n_pts_per_ray > 0 and chunk_size % n_pts_per_ray != 0:
        raise ValueError(
            f"chunk_size_grid ({chunk_size}) should be divisible "
            f"by n_pts_per_ray ({n_pts_per_ray})"
        )

    n_rays = math.prod(spatial_dim)
    # special handling for raytracing-based methods
    n_chunks = -(-n_rays * max(n_pts_per_ray, 1) // chunk_size)
    chunk_size_in_rays = -(-n_rays // n_chunks)

    iterator = range(0, n_rays, chunk_size_in_rays)
    if len(iterator) >= tqdm_trigger_threshold:
        iterator = tqdm.tqdm(iterator)

    for start_idx in iterator:
        end_idx = min(start_idx + chunk_size_in_rays, n_rays)
        ray_bundle_chunk = RayBundle(
            origins=ray_bundle.origins.reshape(batch_size, -1, 3)[
                :, start_idx:end_idx
            ],
            directions=ray_bundle.directions.reshape(batch_size, -1, 3)[
                :, start_idx:end_idx
            ],
            lengths=ray_bundle.lengths.reshape(
                batch_size, math.prod(spatial_dim), n_pts_per_ray
            )[:, start_idx:end_idx],
            xys=ray_bundle.xys.reshape(batch_size, -1, 2)[:, start_idx:end_idx],
        )
        extra_args = kwargs.copy()
        if object_mask is not None:
            extra_args["object_mask"] = object_mask.reshape(batch_size, -1, 1)[
                :, start_idx:end_idx
            ]
        yield [ray_bundle_chunk, *args], extra_args
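
# A sketch of consuming `_chunk_generator` (names are illustrative):
# `render_fn` is assumed to be any callable taking a RayBundle chunk plus
# keyword arguments and returning a per-ray tensor of shape
# (batch, n_rays_chunk, ...). Note `chunk_size` must be divisible by the
# bundle's n_pts_per_ray.
def _example_render_in_chunks(render_fn, ray_bundle, chunk_size=4096):
    chunk_outputs = []
    for chunked_args, extra_args in _chunk_generator(
        chunk_size, ray_bundle, None, tqdm_trigger_threshold=16
    ):
        chunk_outputs.append(render_fn(*chunked_args, **extra_args))
    # The chunks partition the flattened ray dimension, so concatenating
    # along dim=1 restores the full set of rays.
    return torch.cat(chunk_outputs, dim=1)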
def test_simple(self):
    length = 15
    n_pts_per_ray = 10

    for add_input_samples in [False, True]:
        ray_point_refiner = RayPointRefiner(
            n_pts_per_ray=n_pts_per_ray,
            random_sampling=False,
            add_input_samples=add_input_samples,
        )
        lengths = torch.arange(length, dtype=torch.float32).expand(3, 25, length)
        bundle = RayBundle(lengths=lengths, origins=None, directions=None, xys=None)
        weights = torch.ones(3, 25, length)
        refined = ray_point_refiner(bundle, weights)

        self.assertIsNone(refined.directions)
        self.assertIsNone(refined.origins)
        self.assertIsNone(refined.xys)

        expected = torch.linspace(0.5, length - 1.5, n_pts_per_ray)
        expected = expected.expand(3, 25, n_pts_per_ray)
        if add_input_samples:
            full_expected = torch.cat((lengths, expected), dim=-1).sort()[0]
        else:
            full_expected = expected
        self.assertClose(refined.lengths, full_expected)

        ray_point_refiner_random = RayPointRefiner(
            n_pts_per_ray=n_pts_per_ray,
            random_sampling=True,
            add_input_samples=add_input_samples,
        )
        refined_random = ray_point_refiner_random(bundle, weights)
        lengths_random = refined_random.lengths
        self.assertEqual(lengths_random.shape, full_expected.shape)
        if not add_input_samples:
            self.assertGreater(lengths_random.min().item(), 0.5)
            self.assertLess(lengths_random.max().item(), length - 1.5)
        # Check sorted
        self.assertTrue(
            (lengths_random[..., 1:] - lengths_random[..., :-1] > 0).all()
        )
def forward(
    self,
    input_ray_bundle: RayBundle,
    ray_weights: torch.Tensor,
    **kwargs,
) -> RayBundle:
    """
    Args:
        input_ray_bundle: An instance of `RayBundle` specifying the
            source rays for sampling of the probability distribution.
        ray_weights: A tensor of shape
            `(..., input_ray_bundle.lengths.shape[-1])` with non-negative
            elements defining the probability distribution to sample
            ray points from.

    Returns:
        ray_bundle: A new `RayBundle` instance containing the input ray
            points together with `n_pts_per_ray` additionally sampled
            points per ray. For each ray, the lengths are sorted.
    """
    z_vals = input_ray_bundle.lengths
    with torch.no_grad():
        z_vals_mid = torch.lerp(z_vals[..., 1:], z_vals[..., :-1], 0.5)
        z_samples = sample_pdf(
            z_vals_mid.view(-1, z_vals_mid.shape[-1]),
            ray_weights.view(-1, ray_weights.shape[-1])[..., 1:-1],
            self.n_pts_per_ray,
            det=not self.random_sampling,
        ).view(*z_vals.shape[:-1], self.n_pts_per_ray)

    if self.add_input_samples:
        # Add the new samples to the input ones.
        z_vals = torch.cat((z_vals, z_samples), dim=-1)
    else:
        z_vals = z_samples
    # Resort by depth.
    z_vals, _ = torch.sort(z_vals, dim=-1)

    return RayBundle(
        origins=input_ray_bundle.origins,
        directions=input_ray_bundle.directions,
        lengths=z_vals,
        xys=input_ray_bundle.xys,
    )
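
# A short usage sketch (illustrative, not from the original file):
# `RayPointRefiner` is assumed to be the module defining the `forward` above,
# constructed as in `test_simple` earlier in this section.
def _example_refine_points():
    lengths = torch.linspace(0.1, 1.0, 16).expand(2, 8, 16)
    bundle = RayBundle(origins=None, directions=None, lengths=lengths, xys=None)
    weights = torch.ones(2, 8, 16)  # a uniform pdf over the coarse bins
    refiner = RayPointRefiner(
        n_pts_per_ray=32, random_sampling=True, add_input_samples=True
    )
    refined = refiner(bundle, weights)
    # 16 input depths + 32 importance-sampled depths, sorted along each ray.
    assert refined.lengths.shape == (2, 8, 48)
    return refined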
def _add_ray_bundle_trace(
    fig: go.Figure,
    ray_bundle: RayBundle,
    trace_name: str,
    subplot_idx: int,
    ncols: int,
    max_rays: int,
    max_points_per_ray: int,
    marker_size: int,
    line_width: int,
):  # pragma: no cover
    """
    Adds a trace rendering a RayBundle object to the passed in figure, with
    a given name and in a specific subplot.

    Args:
        fig: plotly figure to add the trace within.
        ray_bundle: the RayBundle object to render. It can be batched.
        trace_name: name to label the trace with.
        subplot_idx: identifies the subplot, with 0 being the top left.
        ncols: the number of subplots per row.
        max_rays: maximum number of plotted rays in total. Randomly subsamples
            without replacement in case the number of rays is bigger than max_rays.
        max_points_per_ray: maximum number of points plotted per ray.
        marker_size: the size of the ray point markers.
        line_width: the width of the ray lines.
    """
    n_pts_per_ray = ray_bundle.lengths.shape[-1]
    n_rays = ray_bundle.lengths.shape[:-1].numel()  # pyre-ignore[16]

    # flatten all batches of rays into a single big bundle
    ray_bundle_flat = RayBundle(
        **{
            attr: torch.flatten(getattr(ray_bundle, attr), start_dim=0, end_dim=-2)
            for attr in ["origins", "directions", "lengths", "xys"]
        }
    )

    # subsample the rays (if needed)
    if n_rays > max_rays:
        indices_rays = torch.randperm(n_rays)[:max_rays]
        ray_bundle_flat = RayBundle(
            **{
                attr: getattr(ray_bundle_flat, attr)[indices_rays]
                for attr in ["origins", "directions", "lengths", "xys"]
            }
        )

    # make ray line endpoints
    min_max_ray_depth = torch.stack(
        [
            ray_bundle_flat.lengths.min(dim=1).values,
            ray_bundle_flat.lengths.max(dim=1).values,
        ],
        dim=-1,
    )
    ray_lines_endpoints = ray_bundle_to_ray_points(
        ray_bundle_flat._replace(lengths=min_max_ray_depth)
    )

    # make the ray lines for plotly plotting
    nan_tensor = torch.Tensor([[float("NaN")] * 3])
    ray_lines = torch.empty(size=(1, 3))
    for ray_line in ray_lines_endpoints:
        # We combine the ray lines into a single tensor to plot them in a
        # single trace. The NaNs are inserted between sets of ray lines
        # so that the lines drawn by Plotly are not drawn between
        # lines that belong to different rays.
        ray_lines = torch.cat((ray_lines, nan_tensor, ray_line))

    x, y, z = ray_lines.detach().cpu().numpy().T.astype(float)
    row, col = subplot_idx // ncols + 1, subplot_idx % ncols + 1
    fig.add_trace(
        go.Scatter3d(
            x=x,
            y=y,
            z=z,
            marker={"size": 0.1},
            line={"width": line_width},
            name=trace_name,
        ),
        row=row,
        col=col,
    )

    # subsample the ray points (if needed)
    if n_pts_per_ray > max_points_per_ray:
        indices_ray_pts = torch.cat(
            [
                torch.randperm(n_pts_per_ray)[:max_points_per_ray]
                + ri * n_pts_per_ray
                for ri in range(ray_bundle_flat.lengths.shape[0])
            ]
        )
        ray_bundle_flat = ray_bundle_flat._replace(
            lengths=ray_bundle_flat.lengths.reshape(-1)[indices_ray_pts].reshape(
                ray_bundle_flat.lengths.shape[0], -1
            )
        )

    # plot the ray points
    ray_points = (
        ray_bundle_to_ray_points(ray_bundle_flat)
        .view(-1, 3)
        .detach()
        .cpu()
        .numpy()
        .astype(float)
    )
    fig.add_trace(
        go.Scatter3d(
            x=ray_points[:, 0],
            y=ray_points[:, 1],
            z=ray_points[:, 2],
            mode="markers",
            name=trace_name + "_points",
            marker={"size": marker_size},
        ),
        row=row,
        col=col,
    )

    # Access the current subplot's scene configuration
    plot_scene = "scene" + str(subplot_idx + 1)
    current_layout = fig["layout"][plot_scene]

    # update the bounds of the axes for the current trace
    all_ray_points = ray_bundle_to_ray_points(ray_bundle).view(-1, 3)
    ray_points_center = all_ray_points.mean(dim=0)
    max_expand = (all_ray_points.max(0)[0] - all_ray_points.min(0)[0]).max().item()
    _update_axes_bounds(ray_points_center, float(max_expand), current_layout)
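
# Sketch of invoking `_add_ray_bundle_trace` directly (it is normally called
# via plot_scene); the make_subplots configuration here is an assumption for
# illustration, and the trace parameters are arbitrary.
def _example_ray_bundle_figure(ray_bundle):
    from plotly.subplots import make_subplots

    fig = make_subplots(
        rows=1, cols=1, specs=[[{"type": "scene"}]], subplot_titles=("rays",)
    )
    _add_ray_bundle_trace(
        fig=fig,
        ray_bundle=ray_bundle,
        trace_name="rays",
        subplot_idx=0,
        ncols=1,
        max_rays=128,
        max_points_per_ray=8,
        marker_size=3,
        line_width=2,
    )
    return fig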
def test_input_types(self, batch_size: int = 10):
    """
    Check that ValueErrors are thrown where expected.
    """
    # check the constructor
    for bad_raysampler in (None, 5, []):
        for bad_raymarcher in (None, 5, []):
            with self.assertRaises(ValueError):
                VolumeRenderer(raysampler=bad_raysampler, raymarcher=bad_raymarcher)

    raysampler = NDCMultinomialRaysampler(
        image_width=100,
        image_height=100,
        n_pts_per_ray=10,
        min_depth=0.1,
        max_depth=1.0,
    )

    # init a trivial renderer
    renderer = VolumeRenderer(
        raysampler=raysampler, raymarcher=EmissionAbsorptionRaymarcher()
    )

    # get cameras
    cameras = init_cameras(batch_size=batch_size)

    # get volumes
    volumes = init_boundary_volume(volume_size=(10, 10, 10), batch_size=batch_size)[0]

    # different batch sizes for cameras / volumes
    with self.assertRaises(ValueError):
        renderer(cameras=cameras, volumes=volumes[:-1])

    # ray checks for VolumeSampler
    volume_sampler = VolumeSampler(volumes=volumes)
    n_rays = 100
    for bad_ray_bundle in (
        (
            torch.rand(batch_size, n_rays, 3),
            torch.rand(batch_size, n_rays + 1, 3),
            torch.rand(batch_size, n_rays, 10),
        ),
        (
            torch.rand(batch_size + 1, n_rays, 3),
            torch.rand(batch_size, n_rays, 3),
            torch.rand(batch_size, n_rays, 10),
        ),
        (
            torch.rand(batch_size, n_rays, 3),
            torch.rand(batch_size, n_rays, 2),
            torch.rand(batch_size, n_rays, 10),
        ),
        (
            torch.rand(batch_size, n_rays, 3),
            torch.rand(batch_size, n_rays, 3),
            torch.rand(batch_size, n_rays),
        ),
    ):
        ray_bundle = RayBundle(
            **dict(
                zip(
                    ("origins", "directions", "lengths"),
                    [r.to(cameras.device) for r in bad_ray_bundle],
                )
            ),
            xys=None,
        )
        with self.assertRaises(ValueError):
            volume_sampler(ray_bundle)

        # check also explicitly the ray bundle validation function
        with self.assertRaises(ValueError):
            _validate_ray_bundle_variables(*bad_ray_bundle)
def forward(
    self,
    cameras: CamerasBase,
    chunksize: Optional[int] = None,
    chunk_idx: int = 0,
    camera_hash: Optional[str] = None,
    caching: bool = False,
    **kwargs,
) -> RayBundle:
    """
    Args:
        cameras: A batch of `batch_size` cameras from which the rays are emitted.
        chunksize: The number of rays per chunk. Active only when
            `self.training==False`.
        chunk_idx: The index of the ray chunk. The number has to be in
            `[0, self.get_n_chunks(chunksize, batch_size)-1]`. Active only when
            `self.training==False`.
        camera_hash: A unique identifier of a pre-cached camera. If `None`,
            the cache is not searched and the rays are calculated from scratch.
        caching: If `True`, activates the caching mode that returns the
            `RayBundle` that should be stored into the cache.

    Returns:
        A named tuple `RayBundle` with the following fields:
            origins: A tensor of shape `(batch_size, n_rays_per_image, 3)`
                denoting the locations of ray origins in the world coordinates.
            directions: A tensor of shape `(batch_size, n_rays_per_image, 3)`
                denoting the directions of each ray in the world coordinates.
            lengths: A tensor of shape
                `(batch_size, n_rays_per_image, n_pts_per_ray)`
                containing the z-coordinate (=depth) of each ray in world units.
            xys: A tensor of shape `(batch_size, n_rays_per_image, 2)`
                containing the 2D image coordinates of each ray.
    """
    batch_size = cameras.R.shape[0]  # pyre-ignore
    device = cameras.device

    if (camera_hash is None) and (not caching) and self.training:
        # Sample random rays from scratch.
        ray_bundle = self._mc_raysampler(cameras)
        ray_bundle = self._normalize_raybundle(ray_bundle)
    else:
        if camera_hash is not None:
            # The case where we retrieve a camera from cache.
            if batch_size != 1:
                raise NotImplementedError(
                    "Ray caching works only for batches with a single camera!"
                )
            full_ray_bundle = self._ray_cache[camera_hash]
        else:
            # We generate a full ray grid from scratch.
            full_ray_bundle = self._grid_raysampler(cameras)
            full_ray_bundle = self._normalize_raybundle(full_ray_bundle)

        n_pixels = full_ray_bundle.directions.shape[:-1].numel()

        if self.training:
            # During training we randomly subsample rays.
            sel_rays = torch.randperm(n_pixels, device=device)[
                : self._mc_raysampler._n_rays_per_image
            ]
        else:
            # In case we test, we take only the requested chunk.
            if chunksize is None:
                chunksize = n_pixels * batch_size
            start = chunk_idx * chunksize * batch_size
            end = min(start + chunksize, n_pixels)
            sel_rays = torch.arange(
                start,
                end,
                dtype=torch.long,
                device=full_ray_bundle.lengths.device,
            )

        # Take the "sel_rays" rays from the full ray bundle.
        ray_bundle = RayBundle(
            *[
                v.view(n_pixels, -1)[sel_rays]
                .view(batch_size, sel_rays.numel() // batch_size, -1)
                .to(device)
                for v in full_ray_bundle
            ]
        )

    if (
        (self._stratified and self.training)
        or (self._stratified_test and not self.training)
    ) and not caching:
        # Make sure not to stratify when caching!
        ray_bundle = self._stratify_ray_bundle(ray_bundle)

    return ray_bundle
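
# Evaluation-time sketch (illustrative, not from the original file): iterate
# over all ray chunks of a single camera. `raysampler` is assumed to be an
# eval-mode instance of the class defining `forward` above, and `get_n_chunks`
# is the helper already referenced in `precache_rays`.
def _example_forward_in_chunks(raysampler, camera, chunksize=8192):
    ray_bundle_chunks = []
    for chunk_idx in range(raysampler.get_n_chunks(chunksize, 1)):
        ray_bundle_chunks.append(
            raysampler(camera, chunksize=chunksize, chunk_idx=chunk_idx)
        )
    return ray_bundle_chunks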
def batched_forward(
    net,
    ray_bundle: RayBundle,
    n_batches: int = 16,
    path=None,
    **kwargs,
):
    """
    This function is used to allow for memory efficient processing of input
    rays. The input rays are first split to `n_batches` chunks and passed
    through `net` one at a time in a for loop. Combined with disabling PyTorch
    gradient caching (`torch.no_grad()`), this allows for rendering large
    batches of rays that do not all fit into GPU memory in a single forward
    pass. In our case, batched_forward is used to export a fully-sized render
    of the radiance field for visualisation purposes.

    Args:
        net: The radiance field callable; maps a `RayBundle` to a tuple
            `(rays_densities, rays_colors)`.
        ray_bundle: A RayBundle object containing the following variables:
            origins: A tensor of shape `(minibatch, ..., 3)` denoting the
                origins of the sampling rays in world coords.
            directions: A tensor of shape `(minibatch, ..., 3)`
                containing the direction vectors of sampling rays in world coords.
            lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)`
                containing the lengths at which the rays are sampled.
        n_batches: Specifies the number of batches the input rays are split into.
            The larger the number of batches, the smaller the memory footprint
            and the lower the processing speed.
        path: Optional path prefix; when given, `{path}/batch{i}` is forwarded
            to `net` for each batch and the output spatial size is saved to
            `{path}/spatial_size.pt`.

    Returns:
        rays_densities: A tensor of shape `(minibatch, ..., num_points_per_ray, 1)`
            denoting the opacity of each ray point.
        rays_colors: A tensor of shape `(minibatch, ..., num_points_per_ray, 3)`
            denoting the color of each ray point.
    """
    # Parse out shapes needed for tensor reshaping in this function.
    n_pts_per_ray = ray_bundle.lengths.shape[-1]
    spatial_size = [*ray_bundle.origins.shape[:-1], n_pts_per_ray]

    # Split the rays to `n_batches` batches.
    tot_samples = ray_bundle.origins.shape[:-1].numel()
    batches = torch.chunk(torch.arange(tot_samples), n_batches)

    # For each batch, execute the standard forward pass.
    batch_outputs = [
        net(
            RayBundle(
                origins=ray_bundle.origins.view(-1, 3)[batch_idx],
                directions=ray_bundle.directions.view(-1, 3)[batch_idx],
                lengths=ray_bundle.lengths.view(-1, n_pts_per_ray)[batch_idx],
                xys=None,
            ),
            path=None if path is None else f"{path}/batch{i}",
        )
        for i, batch_idx in enumerate(batches)
    ]

    # Concatenate the per-batch rays_densities and rays_colors
    # and reshape according to the sizes of the inputs.
    rays_densities, rays_colors = [
        torch.cat(
            [batch_output[output_i] for batch_output in batch_outputs], dim=0
        ).view(*spatial_size, -1)
        for output_i in (0, 1)
    ]
    if path is not None:
        torch.save(spatial_size, f"{path}/spatial_size.pt")
    return rays_densities, rays_colors
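
# Usage sketch (illustrative): render a full grid of rays without exhausting
# GPU memory; `net` is assumed to follow the two-output convention documented
# in `batched_forward`, and `full_ray_bundle` would typically come from a
# grid raysampler.
def _example_full_render(net, full_ray_bundle):
    with torch.no_grad():
        rays_densities, rays_colors = batched_forward(
            net, full_ray_bundle, n_batches=16
        )
    return rays_densities, rays_colors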