Example #1
0
 def _normalize_raybundle(self, ray_bundle: RayBundle):
     """
     Normalizes the ray directions of the input `RayBundle` to unit norm.
     """
     ray_bundle = ray_bundle._replace(
         directions=torch.nn.functional.normalize(ray_bundle.directions, dim=-1)
     )
     return ray_bundle
Example #2
0
    def _stratify_ray_bundle(self, ray_bundle: RayBundle):
        """
        Stratifies the lengths of the input `ray_bundle`.

        More specifically, the stratification replaces each ray points' depth `z`
        with a sample from a uniform random distribution on
        `[z - delta_depth, z+delta_depth]`, where `delta_depth` is the difference
        of depths of the consecutive ray depth values.

        Args:
            `ray_bundle`: The input `RayBundle`.

        Returns:
            `stratified_ray_bundle`: `ray_bundle` whose `lengths` field is replaced
                with the stratified samples.
        """
        z_vals = ray_bundle.lengths
        # Get intervals between samples.
        mids = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1])
        upper = torch.cat((mids, z_vals[..., -1:]), dim=-1)
        lower = torch.cat((z_vals[..., :1], mids), dim=-1)
        # Stratified samples in those intervals.
        z_vals = lower + (upper - lower) * torch.rand_like(lower)
        return ray_bundle._replace(lengths=z_vals)
Example #3
0
def _add_ray_bundle_trace(
    fig: go.Figure,
    ray_bundle: RayBundle,
    trace_name: str,
    subplot_idx: int,
    ncols: int,
    max_rays: int,
    max_points_per_ray: int,
    marker_size: int,
    line_width: int,
):  # pragma: no cover
    """
    Adds a trace rendering a RayBundle object to the passed in figure, with
    a given name and in a specific subplot.

    Two traces are added: one drawing each ray as a line segment between its
    smallest and largest depth, and one (named `trace_name + "_points"`)
    drawing the individual ray points as markers.

    Args:
        fig: plotly figure to add the trace within.
        ray_bundle: the RayBundle object to render. It can be batched; all
            leading batch dimensions are flattened into a single set of rays.
        trace_name: name to label the trace with.
        subplot_idx: identifies the subplot, with 0 being the top left.
        ncols: the number of subplots per row.
        max_rays: maximum number of plotted rays in total. Randomly subsamples
            without replacement in case the number of rays is bigger than max_rays.
        max_points_per_ray: maximum number of points plotted per ray.
        marker_size: the size of the ray point markers.
        line_width: the width of the ray lines.
    """

    ray_attrs = ("origins", "directions", "lengths", "xys")
    n_pts_per_ray = ray_bundle.lengths.shape[-1]
    n_rays = ray_bundle.lengths.shape[:-1].numel()  # pyre-ignore[16]

    # flatten all batches of rays into a single big bundle
    ray_bundle_flat = RayBundle(
        **{
            attr: torch.flatten(
                getattr(ray_bundle, attr), start_dim=0, end_dim=-2
            )
            for attr in ray_attrs
        }
    )

    # subsample the rays (if needed)
    if n_rays > max_rays:
        indices_rays = torch.randperm(n_rays)[:max_rays]
        ray_bundle_flat = RayBundle(
            **{
                attr: getattr(ray_bundle_flat, attr)[indices_rays]
                for attr in ray_attrs
            }
        )

    # make ray line endpoints: the nearest and farthest depth of every ray
    min_max_ray_depth = torch.stack(
        [
            ray_bundle_flat.lengths.min(dim=1).values,
            ray_bundle_flat.lengths.max(dim=1).values,
        ],
        dim=-1,
    )
    # expected shape (n_rays, 2, 3): two 3D endpoints per ray
    ray_lines_endpoints = ray_bundle_to_ray_points(
        ray_bundle_flat._replace(lengths=min_max_ray_depth)
    )

    # make the ray lines for plotly plotting
    # We combine the ray lines into a single tensor to plot them in a
    # single trace. A NaN row is inserted in front of every pair of endpoints
    # so that the lines drawn by Plotly are not drawn between
    # lines that belong to different rays.
    # The separators are created on the same device / dtype as the endpoints
    # (so the concatenation also works for non-CPU bundles), and no
    # uninitialized row is ever added to the plotted data.
    nan_separators = ray_lines_endpoints.new_full(
        (ray_lines_endpoints.shape[0], 1, 3), float("NaN")
    )
    ray_lines = torch.cat(
        (nan_separators, ray_lines_endpoints), dim=1
    ).reshape(-1, 3)
    x, y, z = ray_lines.detach().cpu().numpy().T.astype(float)
    row, col = subplot_idx // ncols + 1, subplot_idx % ncols + 1
    fig.add_trace(
        go.Scatter3d(
            x=x,
            y=y,
            z=z,
            marker={"size": 0.1},
            line={"width": line_width},
            name=trace_name,
        ),
        row=row,
        col=col,
    )

    # subsample the ray points (if needed)
    if n_pts_per_ray > max_points_per_ray:
        n_rays_flat = ray_bundle_flat.lengths.shape[0]
        # an independent random subset of point indices for every ray, offset
        # so that it addresses the row-major flattened `lengths` tensor
        indices_ray_pts = torch.cat(
            [
                torch.randperm(n_pts_per_ray)[:max_points_per_ray]
                + ri * n_pts_per_ray
                for ri in range(n_rays_flat)
            ]
        )
        ray_bundle_flat = ray_bundle_flat._replace(
            lengths=ray_bundle_flat.lengths.reshape(-1)[
                indices_ray_pts
            ].reshape(n_rays_flat, -1)
        )

    # plot the ray points
    ray_points = (
        ray_bundle_to_ray_points(ray_bundle_flat)
        .view(-1, 3)
        .detach()
        .cpu()
        .numpy()
        .astype(float)
    )
    fig.add_trace(
        go.Scatter3d(
            x=ray_points[:, 0],
            y=ray_points[:, 1],
            z=ray_points[:, 2],
            mode="markers",
            name=trace_name + "_points",
            marker={"size": marker_size},
        ),
        row=row,
        col=col,
    )

    # Access the current subplot's scene configuration
    plot_scene = "scene" + str(subplot_idx + 1)
    current_layout = fig["layout"][plot_scene]

    # update the bounds of the axes for the current trace; the bounds are
    # computed from all points of the input bundle, not only the plotted subset
    all_ray_points = ray_bundle_to_ray_points(ray_bundle).view(-1, 3)
    ray_points_center = all_ray_points.mean(dim=0)
    max_expand = (
        all_ray_points.max(0)[0] - all_ray_points.min(0)[0]
    ).max().item()
    _update_axes_bounds(ray_points_center, float(max_expand), current_layout)
Example #4
0
    def forward(
        self,
        ray_bundle: RayBundle,
        implicit_functions: List[ImplicitFunctionWrapper],
        evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
        **kwargs,
    ) -> RendererOutput:
        """
        Marches a single point per ray for `self.num_raymarch_steps` steps,
        with each step length predicted by the LSTM from the features that
        the implicit function returns at the current point, and finally
        evaluates the implicit function at the reached point.

        Args:
            ray_bundle: A `RayBundle` object containing the parametrizations of the
                sampled rendering rays. Its `lengths` must hold exactly one
                value per ray, which is the initial raymarching depth.
            implicit_functions: A single-element list of ImplicitFunctionWrappers which
                defines the implicit function to be used.
            evaluation_mode: one of EvaluationMode.TRAINING or
                EvaluationMode.EVALUATION. Not read by this method.
            **kwargs: Not read by this method.

        Returns:
            instance of RendererOutput holding the features and the sigmoid of
            the opacity logits at the final point of each ray, and the final
            depths scaled by the norm of the ray directions.

        Raises:
            ValueError: if `implicit_functions` does not contain exactly one
                element, or if `ray_bundle` has more than one point per ray.
        """
        if len(implicit_functions) != 1:
            raise ValueError(
                "LSTM renderer expects a single implicit function.")
        if ray_bundle.lengths.shape[-1] != 1:
            raise ValueError(
                "LSTM renderer requires a ray-bundle with a single point per ray"
                + " which is the initial raymarching point.")

        implicit_fn = implicit_functions[0]

        # perturb the starting depths with zero-mean Gaussian noise
        depth_jitter = (
            torch.randn_like(ray_bundle.lengths) * self.init_depth_noise_std
        )
        marched_bundle = ray_bundle._replace(
            lengths=ray_bundle.lengths + depth_jitter
        )

        lstm_states: List[Optional[Tuple[torch.Tensor, torch.Tensor]]] = [None]
        step_size = torch.zeros_like(marched_bundle.lengths)
        raymarch_features = None
        for step in range(self.num_raymarch_steps + 1):
            # advance every ray by the most recently predicted step size
            # (zero on the very first iteration)
            marched_bundle = marched_bundle._replace(
                lengths=marched_bundle.lengths + step_size
            )

            # query the raymarching features at the current points
            raymarch_features, _ = implicit_fn(
                marched_bundle,
                raymarch_features=None,
            )

            if self.verbose:
                step_mean = float(step_size.mean())
                step_std = float(step_size.std())
                # pyre-fixme[6]: Expected `Union[bytearray, bytes, str,
                #  typing.SupportsFloat, typing_extensions.SupportsIndex]` for 1st
                #  param but got `Tensor`.
                depth_mean = float(marched_bundle.lengths.mean())
                # pyre-fixme[6]: Expected `Union[bytearray, bytes, str,
                #  typing.SupportsFloat, typing_extensions.SupportsIndex]` for 1st
                #  param but got `Tensor`.
                depth_std = float(marched_bundle.lengths.std())
                logger.info(
                    f"{step}: mu={step_mean:1.2e};"
                    f" std={step_std:1.2e};"
                    f" mu_d={depth_mean:1.2e};"
                    f" std_d={depth_std:1.2e};"
                )

            # the last iteration only moves the points and refreshes the
            # features; no further step is predicted
            if step == self.num_raymarch_steps:
                break

            # run the lstm marcher on the flattened per-ray features
            flat_features = raymarch_features.view(
                -1, raymarch_features.shape[-1]
            )
            # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
            hidden, cell = self._lstm(flat_features, lstm_states[-1])
            if hidden.requires_grad:
                # clamp the gradient flowing back into the hidden state
                hidden.register_hook(lambda x: x.clamp(min=-10, max=10))

            # predict the next step size, shaped like the ray depths
            # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
            step_size = self._out_layer(hidden).view(
                marched_bundle.lengths.shape
            )

            # log the lstm states
            lstm_states.append((hidden, cell))

        opacity_logits, features = implicit_fn(
            raymarch_features=raymarch_features,
            ray_bundle=marched_bundle,
        )
        mask = torch.sigmoid(opacity_logits)
        # scale the ray lengths by the norm of the ray directions
        direction_norms = marched_bundle.directions.norm(dim=-1, keepdim=True)
        depth = marched_bundle.lengths * direction_norms

        return RendererOutput(
            features=features[..., 0, :],
            depths=depth,
            masks=mask[..., 0, :],
        )