Example #1
    def forward(
        self,
        ray_bundle: RayBundle,
        fun_viewpool=None,
        camera: Optional[CamerasBase] = None,
        global_code=None,
        **kwargs,
    ):
        """
        Args:
            ray_bundle: A RayBundle object containing the following variables:
                origins: A tensor of shape `(minibatch, ..., 3)` denoting the
                    origins of the sampling rays in world coords.
                directions: A tensor of shape `(minibatch, ..., 3)`
                    containing the direction vectors of sampling rays in world coords.
                lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)`
                    containing the lengths at which the rays are sampled.
            fun_viewpool: an optional callback with the signature
                    fun_viewpool(points) -> pooled_features
                where points is a [N_TGT x N x 3] tensor of world coords,
                and pooled_features is a [N_TGT x ... x N_SRC x latent_dim] tensor
                of the features pooled from the context images.

        Returns:
            rays_densities: A tensor of shape `(minibatch, ..., num_points_per_ray, 1)`
                denoting the opacity of each ray point.
            rays_colors: Set to None.
        """
        # We first convert the ray parametrizations to world
        # coordinates with `ray_bundle_to_ray_points`.
        rays_points_world = ray_bundle_to_ray_points(ray_bundle)

        embeds = create_embeddings_for_implicit_function(
            xyz_world=ray_bundle_to_ray_points(ray_bundle),
            # pyre-fixme[6]: Expected `Optional[typing.Callable[..., typing.Any]]`
            #  for 2nd param but got `Union[torch.Tensor, torch.nn.Module]`.
            xyz_embedding_function=self._harmonic_embedding,
            global_code=global_code,
            fun_viewpool=fun_viewpool,
            xyz_in_camera_coords=self.xyz_in_camera_coords,
            camera=camera,
        )

        # Before running the network, we have to resize embeds to ndims=3,
        # otherwise the SRN layers consume huge amounts of memory.
        # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
        raymarch_features = self._net(
            embeds.view(embeds.shape[0], -1, embeds.shape[-1]))
        # raymarch_features.shape = [minibatch x ... x self.n_hidden_neurons_xyz]

        # NNs operate on the flattened rays; reshaping to the correct spatial size
        raymarch_features = raymarch_features.reshape(
            *rays_points_world.shape[:-1], -1)

        return raymarch_features, None
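A minimal sketch of the flatten-then-restore pattern used by the forward pass above, with a dummy tensor standing in for `embeds` (all shape values here are illustrative assumptions):

import torch

minibatch, n_rays, n_pts, latent_dim = 2, 4, 8, 16
embeds = torch.randn(minibatch, n_rays, n_pts, latent_dim)

# flatten every intermediate dim so the network sees an ndims=3 tensor
flat = embeds.view(embeds.shape[0], -1, embeds.shape[-1])
assert flat.shape == (minibatch, n_rays * n_pts, latent_dim)

# after the network, restore the original spatial layout
restored = flat.reshape(minibatch, n_rays, n_pts, -1)
assert restored.shape == embeds.shape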
Example #2
def spherical_volumetric_function(
    ray_bundle: RayBundle,
    sphere_centroid: torch.Tensor,
    sphere_diameter: float,
    **kwargs,
):
    """
    Volumetric function of a simple RGB sphere with diameter `sphere_diameter`
    and centroid `sphere_centroid`.
    """
    # convert the ray bundle to world points
    rays_points_world = ray_bundle_to_ray_points(ray_bundle)
    batch_size = rays_points_world.shape[0]

    # surface_vectors = vectors from world coords towards the sphere centroid
    surface_vectors = (rays_points_world.view(batch_size, -1, 3) -
                       sphere_centroid[:, None])

    # the squared distance of each ray point to the centroid of the sphere
    surface_dist = ((surface_vectors**2).sum(-1, keepdim=True).view(
        *rays_points_world.shape[:-1], 1))

    # set all ray densities within the sphere_diameter distance from the centroid to 1
    rays_densities = torch.sigmoid(-100.0 *
                                   (surface_dist - sphere_diameter**2))

    # ray colors are proportional to the normalized surface_vectors
    rays_features = (torch.nn.functional.normalize(
        surface_vectors.view(rays_points_world.shape), dim=-1) * 0.5 + 0.5)

    return rays_densities, rays_features
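A hedged usage sketch for the sphere function above: build a toy RayBundle by hand and query it. The tensor shapes follow the docstrings in the other examples; all values are arbitrary assumptions:

import torch
from pytorch3d.renderer import RayBundle

batch, n_rays, n_pts = 1, 64, 16
ray_bundle = RayBundle(
    origins=torch.zeros(batch, n_rays, 3),  # all rays start at the world origin
    directions=torch.nn.functional.normalize(
        torch.randn(batch, n_rays, 3), dim=-1),  # random unit directions
    lengths=torch.linspace(0.5, 2.0, n_pts).expand(batch, n_rays, n_pts),
    xys=torch.zeros(batch, n_rays, 2),
)
rays_densities, rays_features = spherical_volumetric_function(
    ray_bundle,
    sphere_centroid=torch.tensor([[0.0, 0.0, 1.0]]),
    sphere_diameter=0.7,
)
print(rays_densities.shape, rays_features.shape)  # (1, 64, 16, 1) (1, 64, 16, 3)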
Example #3
def visualize_nerf_outputs(nerf_out: dict, output_cache: List, viz: Visdom,
                           visdom_env: str):
    """
    Visualizes the outputs of the `RadianceFieldRenderer`.

    Args:
        nerf_out: An output of the validation rendering pass.
        output_cache: A list with outputs of several training render passes.
        viz: A visdom connection object.
        visdom_env: The name of visdom environment for visualization.
    """

    # Show the training images.
    ims = torch.stack([o["image"] for o in output_cache])
    ims = torch.cat(list(ims), dim=1)
    viz.image(
        ims.permute(2, 0, 1),
        env=visdom_env,
        win="images",
        opts={"title": "train_images"},
    )

    # Show the coarse and fine renders together with the ground truth images.
    ims_full = torch.cat(
        [
            nerf_out[imvar][0].permute(2, 0, 1).detach().cpu().clamp(0.0, 1.0)
            for imvar in ("rgb_coarse", "rgb_fine", "rgb_gt")
        ],
        dim=2,
    )
    viz.image(
        ims_full,
        env=visdom_env,
        win="images_full",
        opts={"title": "coarse | fine | target"},
    )

    # Make a 3D plot of training cameras and their emitted rays.
    camera_trace = {
        f"camera_{ci:03d}": o["camera"].cpu()
        for ci, o in enumerate(output_cache)
    }
    ray_pts_trace = {
        f"ray_pts_{ci:03d}": Pointclouds(
            ray_bundle_to_ray_points(
                o["coarse_ray_bundle"]).detach().cpu().view(1, -1, 3))
        for ci, o in enumerate(output_cache)
    }
    plotly_plot = plot_scene(
        {
            "training_scene": {
                **camera_trace,
                **ray_pts_trace,
            },
        },
        pointcloud_max_points=5000,
        pointcloud_marker_size=1,
        camera_scale=0.3,
    )
    viz.plotlyplot(plotly_plot, env=visdom_env, win="scenes")
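A short usage sketch, assuming a running Visdom server and that `nerf_out` and `output_cache` come from the renderer's validation and training passes as described in the docstring:

from visdom import Visdom

viz = Visdom(server="http://localhost", port=8097)  # default Visdom address
visualize_nerf_outputs(nerf_out, output_cache, viz, visdom_env="nerf_training")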
Example #4
    def forward(
        self,
        ray_bundle: RayBundle,
        density_noise_std: float = 0.0,
        **kwargs,
    ):
        """
        The forward function accepts the parametrizations of
        3D points sampled along projection rays. The forward
        pass is responsible for attaching a 3D vector
        and a 1D scalar representing the point's
        RGB color and opacity respectively.

        Args:
            ray_bundle: A RayBundle object containing the following variables:
                origins: A tensor of shape `(minibatch, ..., 3)` denoting the
                    origins of the sampling rays in world coords.
                directions: A tensor of shape `(minibatch, ..., 3)`
                    containing the direction vectors of sampling rays in world coords.
                lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)`
                    containing the lengths at which the rays are sampled.
            density_noise_std: A floating point value representing the
                standard deviation of the random normal noise added to the
                output of the opacity function. This can prevent floating artifacts.

        Returns:
            rays_densities: A tensor of shape `(minibatch, ..., num_points_per_ray, 1)`
                denoting the opacity of each ray point.
            rays_colors: A tensor of shape `(minibatch, ..., num_points_per_ray, 3)`
                denoting the color of each ray point.
        """
        # We first convert the ray parametrizations to world
        # coordinates with `ray_bundle_to_ray_points`.
        rays_points_world = ray_bundle_to_ray_points(ray_bundle)
        # rays_points_world.shape = [minibatch x ... x 3]

        # For each 3D world coordinate, we obtain its harmonic embedding.
        embeds_xyz = torch.cat(
            (self.harmonic_embedding_xyz(rays_points_world),
             rays_points_world),
            dim=-1,
        )
        # embeds_xyz.shape = [minibatch x ... x self.n_harmonic_functions*6 + 3]

        # self.mlp maps each harmonic embedding to a latent feature space.
        features = self.mlp_xyz(embeds_xyz, embeds_xyz)
        # features.shape = [minibatch x ... x self.n_hidden_neurons_xyz]

        rays_densities = self._get_densities(features, ray_bundle.lengths,
                                             density_noise_std)
        # rays_densities.shape = [minibatch x ... x 1] in [0-1]

        rays_colors = self._get_colors(features, ray_bundle.directions)
        # rays_colors.shape = [minibatch x ... x 3] in [0-1]

        return rays_densities, rays_colors
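`_get_densities` is not shown in this snippet; a hypothetical sketch of how it could apply the `density_noise_std` term described above (the real implementation may differ, e.g. by also weighting with inter-sample deltas):

import torch

def _get_densities_sketch(raw_densities, density_noise_std):
    # perturb the raw predictions with zero-mean Gaussian noise ...
    if density_noise_std > 0.0:
        raw_densities = raw_densities + (
            torch.randn_like(raw_densities) * density_noise_std)
    # ... then squash into [0, 1] to obtain per-point opacities
    return 1 - (-torch.relu(raw_densities)).exp()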
Example #5
    def forward(
        self,
        ray_bundle: RayBundle,
        path=None,
        **kwargs,
    ):
        """
        The forward function accepts the parametrizations of
        3D points sampled along projection rays. The forward
        pass is responsible for attaching a 3D vector
        and a 1D scalar representing the point's
        RGB color and opacity respectively.

        Args:
            ray_bundle: A RayBundle object containing the following variables:
                origins: A tensor of shape `(minibatch, ..., 3)` denoting the
                    origins of the sampling rays in world coords.
                directions: A tensor of shape `(minibatch, ..., 3)`
                    containing the direction vectors of sampling rays in world coords.
                lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)`
                    containing the lengths at which the rays are sampled.

        Returns:
            rays_densities: A tensor of shape `(minibatch, ..., num_points_per_ray, 1)`
                denoting the opacity of each ray point.
            rays_colors: A tensor of shape `(minibatch, ..., num_points_per_ray, 3)`
                denoting the color of each ray point.
        """
        # We first convert the ray parametrizations to world
        # coordinates with `ray_bundle_to_ray_points`.
        rays_points_world = ray_bundle_to_ray_points(ray_bundle)
        # rays_points_world.shape = [minibatch x ... x 3]

        # For each 3D world coordinate, we obtain its harmonic embedding.
        embeds = self.harmonic_embedding(rays_points_world)
        # embeds.shape = [minibatch x ... x self.n_harmonic_functions*6]

        # self.mlp maps each harmonic embedding to a latent feature space.
        features = self.mlp(embeds)
        # features.shape = [minibatch x ... x n_hidden_neurons]

        # Finally, given the per-point features,
        # execute the density and color branches.

        rays_densities = self._get_densities(features)
        # rays_densities.shape = [minibatch x ... x 1]

        rays_colors = self._get_colors(features, ray_bundle.directions)
        # rays_colors.shape = [minibatch x ... x 3]

        if path is not None:
            save_io_nerf(rays_points_world, rays_densities, rays_colors, path)

        return rays_densities, rays_colors
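`save_io_nerf` is not defined in this snippet; a plausible stand-in that dumps the sampled points and predictions to disk could look like this (the name and file format are assumptions):

import torch

def save_io_nerf(points, densities, colors, path):
    # detach and move to CPU so the saved file does not reference GPU memory
    torch.save(
        {
            "points": points.detach().cpu(),
            "densities": densities.detach().cpu(),
            "colors": colors.detach().cpu(),
        },
        path,
    )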
Example #6
def get_rgbd_point_cloud(
    camera: CamerasBase,
    image_rgb: torch.Tensor,
    depth_map: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    mask_thr: float = 0.5,
    mask_points: bool = True,
) -> Pointclouds:
    """
    Given a batch of images, depths, masks and cameras, generate a colored
    point cloud by unprojecting the depth maps to world coordinates and
    coloring the points with the source pixel colors.
    """
    imh, imw = image_rgb.shape[2:]

    # convert the depth maps to point clouds using the grid ray sampler
    pts_3d = ray_bundle_to_ray_points(
        NDCMultinomialRaysampler(
            image_width=imw,
            image_height=imh,
            n_pts_per_ray=1,
            min_depth=1.0,
            max_depth=1.0,
        )(camera)._replace(lengths=depth_map[:, 0, ..., None])
    )

    pts_mask = depth_map > 0.0
    if mask is not None:
        pts_mask *= mask > mask_thr
    pts_mask = pts_mask.reshape(-1)

    pts_3d = pts_3d.reshape(-1, 3)[pts_mask]

    pts_colors = torch.nn.functional.interpolate(
        image_rgb,
        # pyre-fixme[6]: Expected `Optional[int]` for 2nd param but got
        #  `List[typing.Any]`.
        size=[imh, imw],
        mode="bilinear",
        align_corners=False,
    )
    pts_colors = pts_colors.permute(0, 2, 3, 1).reshape(-1, 3)[pts_mask]

    return Pointclouds(points=pts_3d[None], features=pts_colors[None])
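A hedged usage sketch with dummy inputs; the shapes follow the function body above (`image_rgb`: `(B, 3, H, W)`, `depth_map`: `(B, 1, H, W)`), and the identity camera is an assumption for illustration:

import torch
from pytorch3d.renderer import PerspectiveCameras

B, H, W = 1, 8, 8
camera = PerspectiveCameras()  # identity rotation/translation, batch size 1
image_rgb = torch.rand(B, 3, H, W)
depth_map = torch.full((B, 1, H, W), 2.0)  # constant depth of 2 units

point_cloud = get_rgbd_point_cloud(camera, image_rgb, depth_map)
print(point_cloud.points_padded().shape)  # (1, H*W, 3)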
Example #7
def _add_ray_bundle_trace(
    fig: go.Figure,
    ray_bundle: RayBundle,
    trace_name: str,
    subplot_idx: int,
    ncols: int,
    max_rays: int,
    max_points_per_ray: int,
    marker_size: int,
    line_width: int,
):  # pragma: no cover
    """
    Adds a trace rendering a RayBundle object to the passed in figure, with
    a given name and in a specific subplot.

    Args:
        fig: plotly figure to add the trace within.
        ray_bundle: the RayBundle object to render. It can be batched.
        trace_name: name to label the trace with.
        subplot_idx: identifies the subplot, with 0 being the top left.
        ncols: the number of subplots per row.
        max_rays: maximum number of plotted rays in total. Randomly subsamples
            without replacement in case the number of rays is bigger than max_rays.
        max_points_per_ray: maximum number of points plotted per ray.
        marker_size: the size of the ray point markers.
        line_width: the width of the ray lines.
    """

    n_pts_per_ray = ray_bundle.lengths.shape[-1]
    n_rays = ray_bundle.lengths.shape[:-1].numel()  # pyre-ignore[16]

    # flatten all batches of rays into a single big bundle
    ray_bundle_flat = RayBundle(
        **{
            attr: torch.flatten(
                getattr(ray_bundle, attr), start_dim=0, end_dim=-2)
            for attr in ["origins", "directions", "lengths", "xys"]
        })

    # subsample the rays (if needed)
    if n_rays > max_rays:
        indices_rays = torch.randperm(n_rays)[:max_rays]
        ray_bundle_flat = RayBundle(
            **{
                attr: getattr(ray_bundle_flat, attr)[indices_rays]
                for attr in ["origins", "directions", "lengths", "xys"]
            })

    # make ray line endpoints
    min_max_ray_depth = torch.stack(
        [
            ray_bundle_flat.lengths.min(dim=1).values,
            ray_bundle_flat.lengths.max(dim=1).values,
        ],
        dim=-1,
    )
    ray_lines_endpoints = ray_bundle_to_ray_points(
        ray_bundle_flat._replace(lengths=min_max_ray_depth))

    # make the ray lines for plotly plotting
    nan_tensor = torch.Tensor([[float("NaN")] * 3])
    ray_lines = torch.empty(size=(1, 3))
    for ray_line in ray_lines_endpoints:
        # We combine the ray lines into a single tensor to plot them in a
        # single trace. The NaNs are inserted between sets of ray lines
        # so that the lines drawn by Plotly are not drawn between
        # lines that belong to different rays.
        ray_lines = torch.cat((ray_lines, nan_tensor, ray_line))
    x, y, z = ray_lines.detach().cpu().numpy().T.astype(float)
    row, col = subplot_idx // ncols + 1, subplot_idx % ncols + 1
    fig.add_trace(
        go.Scatter3d(
            x=x,
            y=y,
            z=z,
            marker={"size": 0.1},
            line={"width": line_width},
            name=trace_name,
        ),
        row=row,
        col=col,
    )

    # subsample the ray points (if needed)
    if n_pts_per_ray > max_points_per_ray:
        indices_ray_pts = torch.cat([
            torch.randperm(n_pts_per_ray)[:max_points_per_ray] +
            ri * n_pts_per_ray
            for ri in range(ray_bundle_flat.lengths.shape[0])
        ])
        ray_bundle_flat = ray_bundle_flat._replace(
            lengths=ray_bundle_flat.lengths.reshape(-1)
            [indices_ray_pts].reshape(ray_bundle_flat.lengths.shape[0], -1))

    # plot the ray points
    ray_points = (ray_bundle_to_ray_points(ray_bundle_flat).view(
        -1, 3).detach().cpu().numpy().astype(float))
    fig.add_trace(
        go.Scatter3d(
            x=ray_points[:, 0],
            y=ray_points[:, 1],
            z=ray_points[:, 2],
            mode="markers",
            name=trace_name + "_points",
            marker={"size": marker_size},
        ),
        row=row,
        col=col,
    )

    # Access the current subplot's scene configuration
    plot_scene = "scene" + str(subplot_idx + 1)
    current_layout = fig["layout"][plot_scene]

    # update the bounds of the axes for the current trace
    all_ray_points = ray_bundle_to_ray_points(ray_bundle).view(-1, 3)
    ray_points_center = all_ray_points.mean(dim=0)
    max_expand = (all_ray_points.max(0)[0] -
                  all_ray_points.min(0)[0]).max().item()
    _update_axes_bounds(ray_points_center, float(max_expand), current_layout)
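A minimal sketch of the NaN-separator trick used when building `ray_lines` above: Plotly breaks a 3D line trace wherever a NaN coordinate appears, so concatenating per-ray segments with NaN rows draws every ray as a separate line within a single trace:

import torch

endpoints = torch.tensor([
    [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]],  # ray 0: near point, far point
    [[0.0, 1.0, 0.0], [0.0, 1.0, 1.0]],  # ray 1: near point, far point
])
nan_row = torch.full((1, 3), float("nan"))
ray_lines = torch.cat([torch.cat((nan_row, seg)) for seg in endpoints])
# ray_lines interleaves NaN separators with the 2-point segments, so a
# single go.Scatter3d line trace renders two disconnected rays.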
Example #8
    def forward(
        self,
        ray_bundle: RayBundle,
        fun_viewpool=None,
        camera: Optional[CamerasBase] = None,
        global_code=None,
        **kwargs,
    ):
        """
        The forward function accepts the parametrizations of
        3D points sampled along projection rays. The forward
        pass is responsible for attaching a 3D vector
        and a 1D scalar representing the point's
        RGB color and opacity respectively.

        Args:
            ray_bundle: A RayBundle object containing the following variables:
                origins: A tensor of shape `(minibatch, ..., 3)` denoting the
                    origins of the sampling rays in world coords.
                directions: A tensor of shape `(minibatch, ..., 3)`
                    containing the direction vectors of sampling rays in world coords.
                lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)`
                    containing the lengths at which the rays are sampled.
            fun_viewpool: an optional callback with the signature
                    fun_viewpool(points) -> pooled_features
                where points is a [N_TGT x N x 3] tensor of world coords,
                and pooled_features is a [N_TGT x ... x N_SRC x latent_dim] tensor
                of the features pooled from the context images.

        Returns:
            rays_densities: A tensor of shape `(minibatch, ..., num_points_per_ray, 1)`
                denoting the opacity of each ray point.
            rays_colors: A tensor of shape `(minibatch, ..., num_points_per_ray, 3)`
                denoting the color of each ray point.
        """
        # We first convert the ray parametrizations to world
        # coordinates with `ray_bundle_to_ray_points`.
        rays_points_world = ray_bundle_to_ray_points(ray_bundle)
        # rays_points_world.shape = [minibatch x ... x pts_per_ray x 3]

        embeds = create_embeddings_for_implicit_function(
            xyz_world=ray_bundle_to_ray_points(ray_bundle),
            # pyre-fixme[6]: Expected `Optional[typing.Callable[..., typing.Any]]`
            #  for 2nd param but got `Union[None, torch.Tensor, torch.nn.Module]`.
            xyz_embedding_function=self.harmonic_embedding_xyz
            if self.input_xyz else None,
            global_code=global_code,
            fun_viewpool=fun_viewpool,
            xyz_in_camera_coords=self.xyz_ray_dir_in_camera_coords,
            camera=camera,
        )

        # embeds.shape = [minibatch x n_src x n_rays x n_pts x self.n_harmonic_functions*6+3]
        # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
        features = self.xyz_encoder(embeds)
        # features.shape = [minibatch x ... x self.n_hidden_neurons_xyz]
        # NNs operate on the flattened rays; reshaping to the correct spatial size
        # TODO: maybe make the transformer work on non-flattened tensors to avoid this reshape
        features = features.reshape(*rays_points_world.shape[:-1], -1)

        # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
        raw_densities = self.density_layer(features)
        # raw_densities.shape = [minibatch x ... x 1] in [0-1]

        if self.xyz_ray_dir_in_camera_coords:
            if camera is None:
                raise ValueError(
                    "Camera must be given if xyz_ray_dir_in_camera_coords")

            directions = ray_bundle.directions @ camera.R
        else:
            directions = ray_bundle.directions

        rays_colors = self._get_colors(features, directions)
        # rays_colors.shape = [minibatch x ... x 3] in [0-1]

        return raw_densities, rays_colors, {}
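A small sketch of the world-to-camera direction transform used above. PyTorch3D applies rotations row-major, i.e. `x_cam = x_world @ R`; the identity camera here is an assumption for illustration:

import torch
from pytorch3d.renderer import PerspectiveCameras

camera = PerspectiveCameras(R=torch.eye(3)[None])  # (1, 3, 3) rotation
directions_world = torch.nn.functional.normalize(
    torch.randn(1, 5, 3), dim=-1)
directions_camera = directions_world @ camera.R  # (1, 5, 3)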