def _normalize_raybundle(self, ray_bundle: RayBundle):
    """
    Rescales every ray direction of `ray_bundle` to unit L2 norm.

    Args:
        ray_bundle: The input `RayBundle`.

    Returns:
        A new bundle obtained via `ray_bundle._replace`, whose `directions`
        are normalized along the last dimension. All other fields are
        passed through unchanged.
    """
    # `normalize` clamps the denominator with its default `eps`, so
    # zero-length directions are not divided by zero.
    unit_directions = torch.nn.functional.normalize(
        ray_bundle.directions, dim=-1
    )
    return ray_bundle._replace(directions=unit_directions)
def _stratify_ray_bundle(self, ray_bundle: RayBundle):
    """
    Jitters the sample depths of `ray_bundle` inside their stratification bins.

    Along the last dimension of `ray_bundle.lengths`, the midpoints between
    consecutive depths act as bin edges. Each depth `z_i` is replaced by one
    sample drawn uniformly between `(z_{i-1} + z_i) / 2` and
    `(z_i + z_{i+1}) / 2`. The first bin starts at `z_0` itself and the
    last bin ends at `z_{n-1}` itself. With a single depth per ray, the
    bins collapse and the depth is returned unchanged.
    (Presumably the depths are sorted ascending per ray; the formula below
    still samples between the two edges otherwise.)

    Args:
        `ray_bundle`: The input `RayBundle`.

    Returns:
        `stratified_ray_bundle`: `ray_bundle` with its `lengths` field replaced
            by the jittered depths.
    """
    depths = ray_bundle.lengths

    # Midpoints between neighbouring depths form the inner bin edges.
    midpoints = (depths[..., :-1] + depths[..., 1:]) * 0.5
    bin_lo = torch.cat((depths[..., :1], midpoints), dim=-1)
    bin_hi = torch.cat((midpoints, depths[..., -1:]), dim=-1)

    # Draw one uniform sample inside every bin.
    bin_width = bin_hi - bin_lo
    jittered_depths = bin_lo + bin_width * torch.rand_like(bin_lo)
    return ray_bundle._replace(lengths=jittered_depths)
def _add_ray_bundle_trace(
    fig: go.Figure,
    ray_bundle: RayBundle,
    trace_name: str,
    subplot_idx: int,
    ncols: int,
    max_rays: int,
    max_points_per_ray: int,
    marker_size: int,
    line_width: int,
):  # pragma: no cover
    """
    Adds traces rendering a RayBundle object to the passed in figure, with
    a given name and in a specific subplot.

    Two traces are added:
    - `trace_name` draws each ray as a segment from its nearest to its
      farthest sampled depth.
    - `trace_name + "_points"` draws the sampled ray points as markers.

    The axes bounds of the subplot are then updated from all points of the
    full (non-subsampled) `ray_bundle`.

    Args:
        fig: plotly figure to add the trace within.
        ray_bundle: the RayBundle object to render. It can be batched: all
            leading dimensions of its `origins`, `directions`, `lengths`
            and `xys` fields are flattened into a single batch of rays.
        trace_name: name to label the trace with.
        subplot_idx: identifies the subplot, with 0 being the top left.
        ncols: the number of subplots per row.
        max_rays: maximum number of plotted rays in total. Randomly subsamples
            without replacement in case the number of rays is bigger than
            max_rays.
        max_points_per_ray: maximum number of points plotted per ray.
        marker_size: the size of the ray point markers.
        line_width: the width of the ray lines.
    """
    ray_fields = ["origins", "directions", "lengths", "xys"]
    n_pts_per_ray = ray_bundle.lengths.shape[-1]
    n_rays = ray_bundle.lengths.shape[:-1].numel()  # pyre-ignore[16]

    # flatten all batches of rays into a single big bundle
    ray_bundle_flat = RayBundle(
        **{
            attr: torch.flatten(getattr(ray_bundle, attr), start_dim=0, end_dim=-2)
            for attr in ray_fields
        }
    )

    # subsample the rays (if needed)
    if n_rays > max_rays:
        indices_rays = torch.randperm(n_rays)[:max_rays]
        ray_bundle_flat = RayBundle(
            **{
                attr: getattr(ray_bundle_flat, attr)[indices_rays]
                for attr in ray_fields
            }
        )

    # make ray line endpoints: each ray spans its min and max sampled depth
    min_max_ray_depth = torch.stack(
        [
            ray_bundle_flat.lengths.min(dim=1).values,
            ray_bundle_flat.lengths.max(dim=1).values,
        ],
        dim=-1,
    )
    ray_lines_endpoints = ray_bundle_to_ray_points(
        ray_bundle_flat._replace(lengths=min_max_ray_depth)
    )

    # Make the ray lines for plotly plotting. All ray lines are combined into
    # a single tensor so that they are plotted in a single trace.
    # A NaN row is inserted before each ray's pair of endpoints. Plotly breaks
    # the line at NaNs, so no segments are drawn between points that belong
    # to different rays.
    # The separators are allocated with `new_full`, so they share the device
    # and dtype of the endpoints. A CPU separator concatenated with CUDA
    # endpoints would raise. The lines are assembled with one vectorized
    # `cat`, with no uninitialized `torch.empty` seed row and no quadratic
    # per-ray concatenation.
    nan_separators = ray_lines_endpoints.new_full(
        (ray_lines_endpoints.shape[0], 1, 3), float("NaN")
    )
    ray_lines = torch.cat((nan_separators, ray_lines_endpoints), dim=1).reshape(-1, 3)

    x, y, z = ray_lines.detach().cpu().numpy().T.astype(float)
    row, col = subplot_idx // ncols + 1, subplot_idx % ncols + 1
    fig.add_trace(
        go.Scatter3d(
            x=x,
            y=y,
            z=z,
            marker={"size": 0.1},
            line={"width": line_width},
            name=trace_name,
        ),
        row=row,
        col=col,
    )

    # subsample the ray points (if needed)
    if n_pts_per_ray > max_points_per_ray:
        n_plotted_rays = ray_bundle_flat.lengths.shape[0]
        # per-ray random subset, offset into the flattened `lengths` tensor
        indices_ray_pts = torch.cat(
            [
                torch.randperm(n_pts_per_ray)[:max_points_per_ray]
                + ri * n_pts_per_ray
                for ri in range(n_plotted_rays)
            ]
        )
        ray_bundle_flat = ray_bundle_flat._replace(
            lengths=ray_bundle_flat.lengths.reshape(-1)[indices_ray_pts].reshape(
                n_plotted_rays, -1
            )
        )

    # plot the ray points
    ray_points = (
        ray_bundle_to_ray_points(ray_bundle_flat)
        .view(-1, 3)
        .detach()
        .cpu()
        .numpy()
        .astype(float)
    )
    fig.add_trace(
        go.Scatter3d(
            x=ray_points[:, 0],
            y=ray_points[:, 1],
            z=ray_points[:, 2],
            mode="markers",
            name=trace_name + "_points",
            marker={"size": marker_size},
        ),
        row=row,
        col=col,
    )

    # Access the current subplot's scene configuration
    plot_scene = "scene" + str(subplot_idx + 1)
    current_layout = fig["layout"][plot_scene]

    # update the bounds of the axes using all points of the full bundle
    all_ray_points = ray_bundle_to_ray_points(ray_bundle).view(-1, 3)
    ray_points_center = all_ray_points.mean(dim=0)
    max_expand = (all_ray_points.max(0)[0] - all_ray_points.min(0)[0]).max().item()
    _update_axes_bounds(ray_points_center, float(max_expand), current_layout)
def forward(
    self,
    ray_bundle: RayBundle,
    implicit_functions: List[ImplicitFunctionWrapper],
    evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
    **kwargs,
) -> RendererOutput:
    """
    Renders `ray_bundle` by LSTM-driven ray marching through a single
    implicit function.

    The single depth of every ray is first jittered with Gaussian noise
    scaled by `self.init_depth_noise_std`. Then, for
    `self.num_raymarch_steps` iterations:
    - the implicit function is queried at the current points;
    - its raymarch features are fed to `self._lstm`;
    - the LSTM hidden state is mapped by `self._out_layer` to a signed step
      that is added to the ray lengths.

    The implicit function is evaluated at the final points one more time,
    and then once more with the last raymarch features to obtain opacity
    logits and features.

    Args:
        ray_bundle: A `RayBundle` whose `lengths` hold exactly one depth per
            ray, the initial raymarching point.
        implicit_functions: A list holding exactly one
            `ImplicitFunctionWrapper` to march through.
        evaluation_mode: Accepted for interface compatibility; it is not
            read by this method.
        **kwargs: Ignored.

    Returns:
        A `RendererOutput` containing:
        - `features`: the implicit function's features, indexed at `[..., 0, :]`;
        - `depths`: the final `lengths` multiplied by
          `directions.norm(dim=-1, keepdim=True)`;
        - `masks`: the sigmoid of the opacity logits, indexed at `[..., 0, :]`.

    Raises:
        ValueError: if `implicit_functions` does not contain exactly one
            element, or if `ray_bundle.lengths` does not hold exactly one
            point per ray.
    """
    if len(implicit_functions) != 1:
        raise ValueError("LSTM renderer expects a single implicit function.")
    (implicit_function,) = implicit_functions

    initial_depths = ray_bundle.lengths
    if initial_depths.shape[-1] != 1:
        raise ValueError(
            "LSTM renderer requires a ray-bundle with a single point per ray"
            + " which is the initial raymarching point."
        )

    # Gaussian jitter of the starting depth of every ray.
    depth_noise = torch.randn_like(initial_depths) * self.init_depth_noise_std
    marched_bundle = ray_bundle._replace(lengths=initial_depths + depth_noise)

    # Only the most recent LSTM (h, c) pair is ever fed back into the cell.
    lstm_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
    step = torch.zeros_like(marched_bundle.lengths)
    raymarch_features = None
    for t in range(self.num_raymarch_steps + 1):
        # Advance each ray's point by the most recently predicted step.
        marched_bundle = marched_bundle._replace(
            lengths=marched_bundle.lengths + step
        )

        # Query the implicit function at the current points.
        raymarch_features, _ = implicit_function(
            marched_bundle,
            raymarch_features=None,
        )

        if self.verbose:
            msg = (
                f"{t}: mu={float(step.mean()):1.2e};"
                + f" std={float(step.std()):1.2e};"
                # pyre-fixme[6]: Expected `Union[bytearray, bytes, str,
                #  typing.SupportsFloat, typing_extensions.SupportsIndex]` for 1st
                #  param but got `Tensor`.
                + f" mu_d={float(marched_bundle.lengths.mean()):1.2e};"
                # pyre-fixme[6]: Expected `Union[bytearray, bytes, str,
                #  typing.SupportsFloat, typing_extensions.SupportsIndex]` for 1st
                #  param but got `Tensor`.
                + f" std_d={float(marched_bundle.lengths.std()):1.2e};"
            )
            logger.info(msg)

        if t == self.num_raymarch_steps:
            # The final points have been evaluated; no further step is needed.
            break

        flat_features = raymarch_features.view(-1, raymarch_features.shape[-1])
        # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
        state_h, state_c = self._lstm(flat_features, lstm_state)
        if state_h.requires_grad:
            # Clip the gradient flowing back through the hidden state.
            state_h.register_hook(lambda grad: grad.clamp(min=-10, max=10))

        # Map the hidden state to the next signed step along each ray.
        # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
        step = self._out_layer(state_h).view(marched_bundle.lengths.shape)
        lstm_state = (state_h, state_c)

    opacity_logits, features = implicit_function(
        raymarch_features=raymarch_features,
        ray_bundle=marched_bundle,
    )
    direction_norms = marched_bundle.directions.norm(dim=-1, keepdim=True)
    return RendererOutput(
        features=features[..., 0, :],
        depths=marched_bundle.lengths * direction_norms,
        masks=torch.sigmoid(opacity_logits)[..., 0, :],
    )