Exemple #1
0
    def precache_rays(self, cameras: List[CamerasBase], camera_hashes: List):
        """
        Precaches the rays emitted from the list of cameras `cameras`,
        where each camera is uniquely identified with the corresponding hash
        from `camera_hashes`.

        For every camera the full ray grid is generated through `self.forward`
        in caching mode, detached, moved to cpu and stored in
        `self._ray_cache` under the camera's hash.

        Args:
            cameras: A list of `N` cameras for which the rays are pre-cached.
            camera_hashes: A list of `N` unique identifiers of each
                camera from `cameras`.

        Raises:
            ValueError: If the full ray grid cannot be emitted as a single
                chunk, or if a hash is already present in `self._ray_cache`
                (including a hash repeated within `camera_hashes`).
        """
        n_cameras = len(cameras)
        print(f"Precaching {n_cameras} ray bundles ...")
        # Size of one chunk that covers the whole ray grid.
        # NOTE(review): the `// 2` presumably accounts for the two (x, y)
        # coordinates stored per grid ray in `_xy_grid` - confirm.
        full_chunksize = (
            self._grid_raysampler._xy_grid.numel()
            // 2
            * self._grid_raysampler._n_pts_per_ray
        )
        if self.get_n_chunks(full_chunksize, 1) != 1:
            raise ValueError("There has to be one chunk for precaching rays!")
        for camera_i, (camera, camera_hash) in enumerate(zip(cameras, camera_hashes)):
            # Reject duplicates before emitting any rays so that a redundant
            # camera does not cost a full forward pass first.
            if camera_hash in self._ray_cache:
                raise ValueError("There are redundant cameras!")
            ray_bundle = self.forward(
                camera,
                caching=True,
                chunksize=full_chunksize,
            )
            # Keep the cache on cpu and free of autograd history.
            self._ray_cache[camera_hash] = RayBundle(
                *[v.to("cpu").detach() for v in ray_bundle]
            )
            self._print_precaching_progress(camera_i, n_cameras)
        print("")
Exemple #2
0
 def _get_bundle(self, *, device) -> RayBundle:
     """
     Build a `RayBundle` of uniformly random tensors allocated on `device`.

     Origins and directions have shape `(_BATCH_SIZE, _N_RAYS, 3)`, lengths
     have shape `(_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY)`; `xys` is `None`.
     """
     rays_shape = (_BATCH_SIZE, _N_RAYS)
     # The draw order (origins, directions, lengths) is kept fixed so that a
     # seeded generator yields the same tensors.
     random_origins = torch.rand(*rays_shape, 3, device=device)
     random_directions = torch.rand(*rays_shape, 3, device=device)
     random_lengths = torch.rand(*rays_shape, _N_POINTS_ON_RAY, device=device)
     return RayBundle(
         origins=random_origins,
         directions=random_directions,
         lengths=random_lengths,
         xys=None,
     )
Exemple #3
0
 def _normalize_raybundle(self, ray_bundle: RayBundle):
     """
     Return a copy of `ray_bundle` whose `directions` are rescaled to unit
     norm along the last dimension. All other fields are left untouched.
     """
     unit_directions = torch.nn.functional.normalize(
         ray_bundle.directions, dim=-1
     )
     return ray_bundle._replace(directions=unit_directions)
Exemple #4
0
    def forward(
            self,
            input_ray_bundle: RayBundle,
            ray_weights: torch.Tensor,
            **kwargs,
    ) -> RayBundle:
        """
        Importance-samples additional points along each ray.

        Args:
            input_ray_bundle: An instance of `RayBundle` specifying the
                source rays for sampling of the probability distribution.
            ray_weights: A tensor of shape
                `(..., input_ray_bundle.lengths.shape[-1])` with non-negative
                elements defining the probability distribution to sample
                ray points from.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            ray_bundle: A new `RayBundle` instance whose `lengths` hold the
                `n_pts_per_ray` newly sampled depths per ray, merged with the
                input depths when `self._add_input_samples` is set. The
                depths of each ray are sorted in ascending order. `origins`,
                `directions` and `xys` are passed through unchanged.
        """

        # Calculate the mid-points between the ray depths.
        z_vals = input_ray_bundle.lengths
        z_vals_mid = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1])

        # Sampling is random only when stratification is enabled for the
        # current mode (train / test); otherwise it is deterministic.
        stratify = (self._stratified and self.training) or (
            self._stratified_test and not self.training
        )

        # Carry out the importance sampling on rays flattened to 2D. The
        # result is reshaped back to the leading dimensions of the input
        # lengths, so any number of batch / spatial dimensions is supported.
        z_samples = (
            sample_pdf(
                z_vals_mid.view(-1, z_vals_mid.shape[-1]),
                # The first and the last weight are dropped so that the
                # number of weights is one less than the number of mid-points.
                ray_weights.view(-1, ray_weights.shape[-1])[..., 1:-1],
                self._n_pts_per_ray,
                det=not stratify,
            )
            .detach()
            .view(*z_vals.shape[:-1], self._n_pts_per_ray)
        )

        if self._add_input_samples:
            # Add the new samples to the input ones.
            z_vals = torch.cat((z_vals, z_samples), dim=-1)
        else:
            z_vals = z_samples
        # Resort by depth.
        z_vals, _ = torch.sort(z_vals, dim=-1)

        return RayBundle(
            origins=input_ray_bundle.origins,
            directions=input_ray_bundle.directions,
            lengths=z_vals,
            xys=input_ray_bundle.xys,
        )
Exemple #5
0
def _add_struct_from_batch(
    batched_struct: Struct,
    scene_num: int,
    subplot_title: str,
    scene_dictionary: Dict[str, Dict[str, Struct]],
    trace_idx: int = 1,
):  # pragma: no cover
    """
    Picks one element out of a batched struct and registers it in
    `scene_dictionary` so that it can later be handed to plot_scene.

    The element index is `scene_num`, clamped to the last valid index of the
    batch. The element is stored under
    `scene_dictionary[subplot_title]["trace{scene_num + 1}-{trace_idx}"]`.

    Args:
        batched_struct: the batched data structure to add to the dict
        scene_num: the subplot from plot_batch_individually which this struct
            should be added to
        subplot_title: the title of the subplot
        scene_dictionary: the dictionary to add the indexed struct to
        trace_idx: the trace number, starting at 1 for this struct's trace
    """
    if isinstance(batched_struct, CamerasBase):
        # Camera batches cannot be indexed directly, so a single-camera
        # batch is rebuilt from the selected rotation and translation.
        rotations, translations = batched_struct.R, batched_struct.T
        # pyre-fixme[6]: Expected `Sized` for 1st param but got `Union[torch.Tensor,
        #  torch.nn.Module]`.
        rotation_idx = min(scene_num, len(rotations) - 1)
        # pyre-fixme[6]: Expected `Sized` for 1st param but got `Union[torch.Tensor,
        #  torch.nn.Module]`.
        translation_idx = min(scene_num, len(translations) - 1)
        # pyre-fixme[29]:
        #  `Union[BoundMethod[typing.Callable(torch.Tensor.__getitem__)[[Named(self,
        #  torch.Tensor), Named(item, typing.Any)], typing.Any], torch.Tensor],
        #  torch.Tensor, torch.nn.Module]` is not a function.
        single_rotation = rotations[rotation_idx].unsqueeze(0)
        # pyre-fixme[29]:
        #  `Union[BoundMethod[typing.Callable(torch.Tensor.__getitem__)[[Named(self,
        #  torch.Tensor), Named(item, typing.Any)], typing.Any], torch.Tensor],
        #  torch.Tensor, torch.nn.Module]` is not a function.
        single_translation = translations[translation_idx].unsqueeze(0)
        selected = CamerasBase(
            device=batched_struct.device,
            R=single_rotation,
            T=single_translation,
        )
    elif isinstance(batched_struct, RayBundle):
        # For a RayBundle the first dimension is treated as the batch index.
        bundle_idx = min(scene_num, len(batched_struct.lengths) - 1)
        selected = RayBundle(
            origins=batched_struct.origins[bundle_idx],
            directions=batched_struct.directions[bundle_idx],
            lengths=batched_struct.lengths[bundle_idx],
            xys=batched_struct.xys[bundle_idx],
        )
    else:
        # Batched meshes and pointclouds support direct indexing.
        selected = batched_struct[min(scene_num, len(batched_struct) - 1)]

    trace_key = f"trace{scene_num + 1}-{trace_idx}"
    scene_dictionary[subplot_title][trace_key] = selected
Exemple #6
0
def _chunk_generator(
    chunk_size: int,
    ray_bundle: RayBundle,
    object_mask: Optional[torch.Tensor],
    tqdm_trigger_threshold: int,
    *args,
    **kwargs,
):
    """
    Helper function which yields chunks of rays from the
    input ray_bundle, to be used when the number of rays is
    large and will not fit in memory for rendering.

    Args:
        chunk_size: Upper bound on the number of ray points per chunk. When
            the bundle has at least one point per ray, it has to be divisible
            by the number of points per ray.
        ray_bundle: The rays to split. `lengths` is interpreted as
            `(batch, ..., n_pts_per_ray)`; all dimensions between the first
            and the last are flattened into a single ray dimension.
            `xys` may be `None`, in which case every chunk has `xys=None`.
        object_mask: Optional per-ray mask; if given, it is flattened and
            chunked along the ray dimension in the same way as the rays and
            passed on as the `object_mask` keyword.
        tqdm_trigger_threshold: A progress bar is shown when the number of
            chunks is at least this value.
        *args: Extra positional arguments, appended after the ray chunk.
        **kwargs: Extra keyword arguments, copied for every chunk.

    Yields:
        A pair `([ray_bundle_chunk, *args], extra_kwargs)`.

    Raises:
        ValueError: If `chunk_size` is not divisible by a positive
            number of points per ray.
    """
    (
        batch_size,
        *spatial_dim,
        n_pts_per_ray,
    ) = ray_bundle.lengths.shape  # B x ... x n_pts_per_ray
    if n_pts_per_ray > 0 and chunk_size % n_pts_per_ray != 0:
        raise ValueError(f"chunk_size_grid ({chunk_size}) should be divisible "
                         f"by n_pts_per_ray ({n_pts_per_ray})")

    n_rays = math.prod(spatial_dim)
    # special handling for raytracing-based methods: a bundle with zero
    # points per ray is chunked as if each ray carried a single point.
    # `-(-a // b)` is integer ceil-division.
    n_chunks = -(-n_rays * max(n_pts_per_ray, 1) // chunk_size)
    chunk_size_in_rays = -(-n_rays // n_chunks)

    chunk_starts = range(0, n_rays, chunk_size_in_rays)
    if len(chunk_starts) >= tqdm_trigger_threshold:
        chunk_starts = tqdm.tqdm(chunk_starts)

    # Flatten the spatial dimensions once instead of once per chunk.
    flat_origins = ray_bundle.origins.reshape(batch_size, -1, 3)
    flat_directions = ray_bundle.directions.reshape(batch_size, -1, 3)
    flat_lengths = ray_bundle.lengths.reshape(batch_size, n_rays,
                                              n_pts_per_ray)
    # `xys` is optional on a RayBundle; keep it as None when absent.
    flat_xys = (None if ray_bundle.xys is None else
                ray_bundle.xys.reshape(batch_size, -1, 2))
    flat_mask = (None if object_mask is None else
                 object_mask.reshape(batch_size, -1, 1))

    for start_idx in chunk_starts:
        end_idx = min(start_idx + chunk_size_in_rays, n_rays)
        ray_bundle_chunk = RayBundle(
            origins=flat_origins[:, start_idx:end_idx],
            directions=flat_directions[:, start_idx:end_idx],
            lengths=flat_lengths[:, start_idx:end_idx],
            xys=None if flat_xys is None else flat_xys[:, start_idx:end_idx],
        )
        extra_args = kwargs.copy()
        if flat_mask is not None:
            extra_args["object_mask"] = flat_mask[:, start_idx:end_idx]
        yield [ray_bundle_chunk, *args], extra_args
    def test_simple(self):
        """
        Check `RayPointRefiner` on uniform weights, with and without keeping
        the input samples, in both deterministic and random sampling modes.
        """
        n_input_pts = 15
        n_refined_pts = 10
        leading_shape = (3, 25)

        for add_input_samples in (False, True):
            # Deterministic refinement has a closed-form expected result.
            deterministic_refiner = RayPointRefiner(
                n_pts_per_ray=n_refined_pts,
                random_sampling=False,
                add_input_samples=add_input_samples,
            )
            input_lengths = torch.arange(
                n_input_pts, dtype=torch.float32
            ).expand(*leading_shape, n_input_pts)
            input_bundle = RayBundle(
                lengths=input_lengths,
                origins=None,
                directions=None,
                xys=None,
            )
            uniform_weights = torch.ones(*leading_shape, n_input_pts)
            refined_bundle = deterministic_refiner(input_bundle, uniform_weights)

            # Fields other than the lengths must be passed through untouched.
            self.assertIsNone(refined_bundle.directions)
            self.assertIsNone(refined_bundle.origins)
            self.assertIsNone(refined_bundle.xys)

            expected_new = torch.linspace(
                0.5, n_input_pts - 1.5, n_refined_pts
            ).expand(*leading_shape, n_refined_pts)
            if add_input_samples:
                merged = torch.cat((input_lengths, expected_new), dim=-1)
                expected_lengths = merged.sort()[0]
            else:
                expected_lengths = expected_new
            self.assertClose(refined_bundle.lengths, expected_lengths)

            # Random refinement: only shape, range and ordering are checked.
            random_refiner = RayPointRefiner(
                n_pts_per_ray=n_refined_pts,
                random_sampling=True,
                add_input_samples=add_input_samples,
            )
            random_lengths = random_refiner(input_bundle, uniform_weights).lengths
            self.assertEqual(random_lengths.shape, expected_lengths.shape)
            if not add_input_samples:
                self.assertGreater(random_lengths.min().item(), 0.5)
                self.assertLess(random_lengths.max().item(), n_input_pts - 1.5)

            # Check sorted
            consecutive_gaps = random_lengths[..., 1:] - random_lengths[..., :-1]
            self.assertTrue((consecutive_gaps > 0).all())
Exemple #8
0
    def forward(
        self,
        input_ray_bundle: RayBundle,
        ray_weights: torch.Tensor,
        **kwargs,
    ) -> RayBundle:
        """
        Samples additional ray points according to `ray_weights`.

        Args:
            input_ray_bundle: An instance of `RayBundle` specifying the
                source rays for sampling of the probability distribution.
            ray_weights: A tensor of shape
                `(..., input_ray_bundle.lengths.shape[-1])` with non-negative
                elements defining the probability distribution to sample
                ray points from.

        Returns:
            ray_bundle: A new `RayBundle` instance whose `lengths` contain
                `n_pts_per_ray` newly sampled points per ray, together with
                the input ray points if `self.add_input_samples` is set.
                For each ray, the lengths are sorted.
        """

        input_lengths = input_ray_bundle.lengths
        leading_shape = input_lengths.shape[:-1]

        # The sampling itself is excluded from gradient tracking.
        with torch.no_grad():
            # Mid-points between consecutive input depths.
            bin_centers = torch.lerp(
                input_lengths[..., 1:], input_lengths[..., :-1], 0.5
            )
            flat_centers = bin_centers.view(-1, bin_centers.shape[-1])
            # The first and last weight of every ray are dropped.
            flat_inner_weights = ray_weights.view(-1, ray_weights.shape[-1])[
                ..., 1:-1
            ]
            flat_samples = sample_pdf(
                flat_centers,
                flat_inner_weights,
                self.n_pts_per_ray,
                det=not self.random_sampling,
            )
            sampled_lengths = flat_samples.view(*leading_shape, self.n_pts_per_ray)

        # Optionally merge the new samples with the input ones.
        if self.add_input_samples:
            output_lengths = torch.cat((input_lengths, sampled_lengths), dim=-1)
        else:
            output_lengths = sampled_lengths
        # Resort by depth.
        output_lengths = torch.sort(output_lengths, dim=-1).values

        return RayBundle(
            origins=input_ray_bundle.origins,
            directions=input_ray_bundle.directions,
            lengths=output_lengths,
            xys=input_ray_bundle.xys,
        )
Exemple #9
0
    def _stratify_ray_bundle(self, ray_bundle: RayBundle):
        """
        Stratifies the lengths of the input `ray_bundle`.

        Every depth value is replaced with a sample drawn uniformly at random
        from the bin surrounding it. The bins are delimited by the mid-points
        between consecutive depths of a ray; the first bin starts at the first
        depth and the last bin ends at the last depth.

        Args:
            `ray_bundle`: The input `RayBundle`.

        Returns:
            `stratified_ray_bundle`: `ray_bundle` whose `lengths` field is replaced
                with the stratified samples.
        """
        depths = ray_bundle.lengths
        first_depth = depths[..., :1]
        last_depth = depths[..., -1:]
        # Boundaries of the per-sample bins.
        midpoints = 0.5 * (depths[..., 1:] + depths[..., :-1])
        bin_lower = torch.cat((first_depth, midpoints), dim=-1)
        bin_upper = torch.cat((midpoints, last_depth), dim=-1)
        # One uniform draw per bin, mapped into [bin_lower, bin_upper].
        uniform_draws = torch.rand_like(bin_lower)
        stratified_depths = bin_lower + (bin_upper - bin_lower) * uniform_draws
        return ray_bundle._replace(lengths=stratified_depths)
Exemple #10
0
def _add_ray_bundle_trace(
    fig: go.Figure,
    ray_bundle: RayBundle,
    trace_name: str,
    subplot_idx: int,
    ncols: int,
    max_rays: int,
    max_points_per_ray: int,
    marker_size: int,
    line_width: int,
):  # pragma: no cover
    """
    Adds a trace rendering a RayBundle object to the passed in figure, with
    a given name and in a specific subplot.

    Two traces are added: one with a line segment per ray, spanning the
    smallest to the largest depth of that ray, and one (named
    `trace_name + "_points"`) with markers at the ray points. Finally the
    axes bounds of the subplot's scene are updated from all points of the
    input (not subsampled) bundle.

    Args:
        fig: plotly figure to add the trace within.
        ray_bundle: the RayBundle object to render. It can be batched; all
            leading dimensions are flattened into a single ray dimension.
        trace_name: name to label the trace with.
        subplot_idx: identifies the subplot, with 0 being the top left.
        ncols: the number of subplots per row.
        max_rays: maximum number of plotted rays in total. Randomly subsamples
            without replacement in case the number of rays is bigger than max_rays.
        max_points_per_ray: maximum number of points plotted per ray.
        marker_size: the size of the ray point markers.
        line_width: the width of the ray lines.
    """
    bundle_fields = ["origins", "directions", "lengths", "xys"]

    n_pts_per_ray = ray_bundle.lengths.shape[-1]
    n_rays = ray_bundle.lengths.shape[:-1].numel()  # pyre-ignore[16]

    # flatten all batches of rays into a single big bundle
    ray_bundle_flat = RayBundle(
        **{
            attr: torch.flatten(
                getattr(ray_bundle, attr), start_dim=0, end_dim=-2)
            for attr in bundle_fields
        })

    # subsample the rays (if needed)
    if n_rays > max_rays:
        indices_rays = torch.randperm(n_rays)[:max_rays]
        ray_bundle_flat = RayBundle(
            **{
                attr: getattr(ray_bundle_flat, attr)[indices_rays]
                for attr in bundle_fields
            })

    # make ray line endpoints
    min_max_ray_depth = torch.stack(
        [
            ray_bundle_flat.lengths.min(dim=1).values,
            ray_bundle_flat.lengths.max(dim=1).values,
        ],
        dim=-1,
    )
    # one (start, end) pair of 3D points per ray
    ray_lines_endpoints = ray_bundle_to_ray_points(
        ray_bundle_flat._replace(lengths=min_max_ray_depth))

    # make the ray lines for plotly plotting
    # We combine the ray lines into a single tensor to plot them in a
    # single trace. A NaN row is inserted in front of every ray's endpoints
    # so that the lines drawn by Plotly are not drawn between
    # lines that belong to different rays. The separators are created with
    # the dtype / device of the endpoints so that the concatenation also
    # works for bundles that do not live on the cpu, and no uninitialized
    # values end up in the plotted data.
    nan_separators = torch.full(
        (ray_lines_endpoints.shape[0], 1, 3),
        float("NaN"),
        dtype=ray_lines_endpoints.dtype,
        device=ray_lines_endpoints.device,
    )
    ray_lines = torch.flatten(
        torch.cat((nan_separators, ray_lines_endpoints), dim=1),
        start_dim=0,
        end_dim=1,
    )
    x, y, z = ray_lines.detach().cpu().numpy().T.astype(float)
    row, col = subplot_idx // ncols + 1, subplot_idx % ncols + 1
    fig.add_trace(
        go.Scatter3d(
            x=x,
            y=y,
            z=z,
            marker={"size": 0.1},
            line={"width": line_width},
            name=trace_name,
        ),
        row=row,
        col=col,
    )

    # subsample the ray points (if needed)
    if n_pts_per_ray > max_points_per_ray:
        # For every ray draw an independent random subset of point indices,
        # offset into the flattened (n_rays * n_pts_per_ray) lengths.
        indices_ray_pts = torch.cat([
            torch.randperm(n_pts_per_ray)[:max_points_per_ray] +
            ri * n_pts_per_ray
            for ri in range(ray_bundle_flat.lengths.shape[0])
        ])
        ray_bundle_flat = ray_bundle_flat._replace(
            lengths=ray_bundle_flat.lengths.reshape(-1)
            [indices_ray_pts].reshape(ray_bundle_flat.lengths.shape[0], -1))

    # plot the ray points
    ray_points = (ray_bundle_to_ray_points(ray_bundle_flat).view(
        -1, 3).detach().cpu().numpy().astype(float))
    fig.add_trace(
        go.Scatter3d(
            x=ray_points[:, 0],
            y=ray_points[:, 1],
            z=ray_points[:, 2],
            mode="markers",
            name=trace_name + "_points",
            marker={"size": marker_size},
        ),
        row=row,
        col=col,
    )

    # Access the current subplot's scene configuration
    plot_scene = "scene" + str(subplot_idx + 1)
    current_layout = fig["layout"][plot_scene]

    # update the bounds of the axes for the current trace
    all_ray_points = ray_bundle_to_ray_points(ray_bundle).view(-1, 3)
    ray_points_center = all_ray_points.mean(dim=0)
    max_expand = (all_ray_points.max(0)[0] -
                  all_ray_points.min(0)[0]).max().item()
    _update_axes_bounds(ray_points_center, float(max_expand), current_layout)
Exemple #11
0
    def test_input_types(self, batch_size: int = 10):
        """
        Verify that `ValueError` is raised for invalid renderer components,
        mismatched camera / volume batch sizes and malformed ray bundles.
        """
        # The constructor must reject every pairing of invalid components.
        invalid_components = (None, 5, [])
        for bad_raysampler in invalid_components:
            for bad_raymarcher in invalid_components:
                with self.assertRaises(ValueError):
                    VolumeRenderer(
                        raysampler=bad_raysampler,
                        raymarcher=bad_raymarcher,
                    )

        # A trivial but valid renderer.
        renderer = VolumeRenderer(
            raysampler=NDCMultinomialRaysampler(
                image_width=100,
                image_height=100,
                n_pts_per_ray=10,
                min_depth=0.1,
                max_depth=1.0,
            ),
            raymarcher=EmissionAbsorptionRaymarcher(),
        )

        cameras = init_cameras(batch_size=batch_size)
        volumes = init_boundary_volume(
            volume_size=(10, 10, 10), batch_size=batch_size
        )[0]

        # The numbers of cameras and volumes have to agree.
        with self.assertRaises(ValueError):
            renderer(cameras=cameras, volumes=volumes[:-1])

        # Ray checks for VolumeSampler. Each entry lists the shapes of
        # (origins, directions, lengths) of one malformed bundle.
        volume_sampler = VolumeSampler(volumes=volumes)
        n_rays = 100
        malformed_shapes = (
            # origins and directions differ in the number of rays
            ((batch_size, n_rays, 3), (batch_size, n_rays + 1, 3),
             (batch_size, n_rays, 10)),
            # origins have a different batch size
            ((batch_size + 1, n_rays, 3), (batch_size, n_rays, 3),
             (batch_size, n_rays, 10)),
            # directions are not 3-dimensional vectors
            ((batch_size, n_rays, 3), (batch_size, n_rays, 2),
             (batch_size, n_rays, 10)),
            # lengths lack the points-per-ray dimension
            ((batch_size, n_rays, 3), (batch_size, n_rays, 3),
             (batch_size, n_rays)),
        )
        malformed_bundles = [
            tuple(torch.rand(*shape) for shape in shapes)
            for shapes in malformed_shapes
        ]
        for bad_ray_bundle in malformed_bundles:
            bad_origins, bad_directions, bad_lengths = bad_ray_bundle
            ray_bundle = RayBundle(
                origins=bad_origins.to(cameras.device),
                directions=bad_directions.to(cameras.device),
                lengths=bad_lengths.to(cameras.device),
                xys=None,
            )
            with self.assertRaises(ValueError):
                volume_sampler(ray_bundle)

            # check also explicitly the ray bundle validation function
            with self.assertRaises(ValueError):
                _validate_ray_bundle_variables(*bad_ray_bundle)
Exemple #12
0
    def forward(
        self,
        ray_bundle: RayBundle,
        implicit_functions: List[ImplicitFunctionWrapper],
        evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
        **kwargs,
    ) -> RendererOutput:
        """
        Renders by marching a single point along each ray with an LSTM that
        predicts the step size, then evaluates the implicit function at the
        final point.

        Args:
            ray_bundle: A `RayBundle` object containing the parametrizations of the
                sampled rendering rays. It has to carry exactly one point per
                ray (`lengths.shape[-1] == 1`), which is the initial
                raymarching point.
            implicit_functions: A single-element list of ImplicitFunctionWrappers which
                defines the implicit function to be used.
            evaluation_mode: one of EvaluationMode.TRAINING or
                EvaluationMode.EVALUATION. Not referenced in the body of this
                method.
            **kwargs: Not referenced in the body of this method.

        Returns:
            instance of RendererOutput with `features` and `masks` taken at
            the single ray point, and `depths` equal to the final ray length
            multiplied by the norm of the ray direction.

        Raises:
            ValueError: If `implicit_functions` does not contain exactly one
                element or `ray_bundle` has more than one point per ray.
        """
        if len(implicit_functions) != 1:
            raise ValueError(
                "LSTM renderer expects a single implicit function.")

        implicit_function = implicit_functions[0]

        if ray_bundle.lengths.shape[-1] != 1:
            raise ValueError(
                "LSTM renderer requires a ray-bundle with a single point per ray"
                + " which is the initial raymarching point.")

        # jitter the initial depths with Gaussian noise scaled by
        # `self.init_depth_noise_std`
        ray_bundle_t = ray_bundle._replace(
            lengths=ray_bundle.lengths +
            torch.randn_like(ray_bundle.lengths) * self.init_depth_noise_std)

        # History of the LSTM (hidden, cell) states; the leading `None` makes
        # the first LSTM call start without a previous state.
        states: List[Optional[Tuple[torch.Tensor, torch.Tensor]]] = [None]
        # The step applied in the first iteration is zero.
        signed_distance = torch.zeros_like(ray_bundle_t.lengths)
        raymarch_features = None
        # `num_raymarch_steps` marching steps are taken; the extra iteration
        # only moves the points by the last predicted step and re-evaluates
        # the features there.
        for t in range(self.num_raymarch_steps + 1):
            # move signed_distance along each ray
            ray_bundle_t = ray_bundle_t._replace(lengths=ray_bundle_t.lengths +
                                                 signed_distance)

            # eval the raymarching function
            raymarch_features, _ = implicit_function(
                ray_bundle_t,
                raymarch_features=None,
            )
            if self.verbose:
                # Log mean / std of the last step and of the current depths.
                msg = (
                    f"{t}: mu={float(signed_distance.mean()):1.2e};" +
                    f" std={float(signed_distance.std()):1.2e};"
                    # pyre-fixme[6]: Expected `Union[bytearray, bytes, str,
                    #  typing.SupportsFloat, typing_extensions.SupportsIndex]` for 1st
                    #  param but got `Tensor`.
                    + f" mu_d={float(ray_bundle_t.lengths.mean()):1.2e};"
                    # pyre-fixme[6]: Expected `Union[bytearray, bytes, str,
                    #  typing.SupportsFloat, typing_extensions.SupportsIndex]` for 1st
                    #  param but got `Tensor`.
                    + f" std_d={float(ray_bundle_t.lengths.std()):1.2e};")
                logger.info(msg)
            # After the final move no further step has to be predicted.
            if t == self.num_raymarch_steps:
                break

            # run the lstm marcher on the features flattened to 2D
            # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
            state_h, state_c = self._lstm(
                raymarch_features.view(-1, raymarch_features.shape[-1]),
                states[-1],
            )
            if state_h.requires_grad:
                # Clamp the gradient flowing back through the hidden state
                # to [-10, 10].
                state_h.register_hook(lambda x: x.clamp(min=-10, max=10))
            # predict the next step size
            # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
            signed_distance = self._out_layer(state_h).view(
                ray_bundle_t.lengths.shape)
            # log the lstm states
            states.append((state_h, state_c))

        # Final evaluation at the marched points, reusing the features
        # computed in the last loop iteration.
        opacity_logits, features = implicit_function(
            raymarch_features=raymarch_features,
            ray_bundle=ray_bundle_t,
        )
        mask = torch.sigmoid(opacity_logits)
        # Scale the ray length by the direction norm to obtain the depth.
        depth = ray_bundle_t.lengths * ray_bundle_t.directions.norm(
            dim=-1, keepdim=True)

        return RendererOutput(
            # Index 0 selects the single point of every ray.
            features=features[..., 0, :],
            depths=depth,
            masks=mask[..., 0, :],
        )
Exemple #13
0
    def forward(
        self,
        cameras: CamerasBase,
        chunksize: int = None,
        chunk_idx: int = 0,
        camera_hash: str = None,
        caching: bool = False,
        **kwargs,
    ) -> RayBundle:
        """
        Emits rays from `cameras`.

        In training mode without a cache hit and without `caching`, random
        rays are sampled from scratch. Otherwise a full ray grid is obtained
        (from the cache or freshly generated) and either randomly subsampled
        (training) or cut to the requested chunk (evaluation).

        Args:
            cameras: A batch of `batch_size` cameras from which the rays are emitted.
            chunksize: The number of rays per chunk.
                Active only when `self.training==False`.
            chunk_idx: The index of the ray chunk. The number has to be in
                `[0, self.get_n_chunks(chunksize, batch_size)-1]`.
                Active only when `self.training==False`.
            camera_hash: A unique identifier of a pre-cached camera. If `None`,
                the cache is not searched and the rays are calculated from scratch.
            caching: If `True`, activates the caching mode that returns the `RayBundle`
                that should be stored into the cache.
        Returns:
            A named tuple `RayBundle` with the following fields:
                origins: A tensor of shape
                    `(batch_size, n_rays_per_image, 3)`
                    denoting the locations of ray origins in the world coordinates.
                directions: A tensor of shape
                    `(batch_size, n_rays_per_image, 3)`
                    denoting the directions of each ray in the world coordinates.
                lengths: A tensor of shape
                    `(batch_size, n_rays_per_image, n_pts_per_ray)`
                    containing the z-coordinate (=depth) of each ray in world units.
                xys: A tensor of shape
                    `(batch_size, n_rays_per_image, 2)`
                    containing the 2D image coordinates of each ray.

        Raises:
            NotImplementedError: If `camera_hash` is given for a batch with
                more than one camera.
        """

        batch_size = cameras.R.shape[0]  # pyre-ignore
        device = cameras.device

        if (camera_hash is None) and (not caching) and self.training:
            # Sample random rays from scratch.
            ray_bundle = self._mc_raysampler(cameras)
            ray_bundle = self._normalize_raybundle(ray_bundle)
        else:
            if camera_hash is not None:
                # The case where we retrieve a camera from cache.
                if batch_size != 1:
                    raise NotImplementedError(
                        "Ray caching works only for batches with a single camera!"
                    )
                full_ray_bundle = self._ray_cache[camera_hash]
            else:
                # We generate a full ray grid from scratch.
                full_ray_bundle = self._grid_raysampler(cameras)
                full_ray_bundle = self._normalize_raybundle(full_ray_bundle)

            n_pixels = full_ray_bundle.directions.shape[:-1].numel()
            # The selection indices have to live on the same device as the
            # tensors they index. Cached bundles are kept on cpu, which may
            # differ from the device of `cameras`.
            bundle_device = full_ray_bundle.lengths.device

            if self.training:
                # During training we randomly subsample rays.
                sel_rays = torch.randperm(n_pixels, device=bundle_device)[
                    : self._mc_raysampler._n_rays_per_image
                ]
            else:
                # In case we test, we take only the requested chunk.
                if chunksize is None:
                    chunksize = n_pixels * batch_size
                # NOTE(review): `n_pixels` already spans all leading
                # dimensions of the bundle; the extra `batch_size` factors
                # here are kept as in the original - confirm for
                # batch_size > 1.
                start = chunk_idx * chunksize * batch_size
                end = min(start + chunksize, n_pixels)
                sel_rays = torch.arange(
                    start,
                    end,
                    dtype=torch.long,
                    device=bundle_device,
                )

            # Take the "sel_rays" rays from the full ray bundle and move the
            # selection to the device of the cameras.
            ray_bundle = RayBundle(
                *[
                    v.view(n_pixels, -1)[sel_rays]
                    .view(batch_size, sel_rays.numel() // batch_size, -1)
                    .to(device)
                    for v in full_ray_bundle
                ]
            )

        if (
            (self._stratified and self.training)
            or (self._stratified_test and not self.training)
        ) and not caching:  # Make sure not to stratify when caching!
            ray_bundle = self._stratify_ray_bundle(ray_bundle)

        return ray_bundle
Exemple #14
0
def batched_forward(
    net,
    ray_bundle: RayBundle,
    n_batches: int = 16,
    path=None,
    **kwargs,
):
    """
    Evaluates `net` on the input rays chunk by chunk to limit the memory
    footprint of a single forward pass.

    The rays are flattened, split into `n_batches` chunks and passed to
    `net` one chunk at a time in a loop. The two per-chunk outputs are
    concatenated and reshaped back to the layout of the input rays.
    This function does not disable gradient tracking by itself; wrap the call
    in `torch.no_grad()` when gradients are not needed.

    Args:
        net: A callable invoked as `net(ray_bundle_chunk, path=...)` whose
            result is indexable, with the densities at index 0 and the
            colors at index 1.
        ray_bundle: A RayBundle object containing the following variables:
            origins: A tensor of shape `(minibatch, ..., 3)` denoting the
                origins of the sampling rays in world coords.
            directions: A tensor of shape `(minibatch, ..., 3)`
                containing the direction vectors of sampling rays in world coords.
            lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)`
                containing the lengths at which the rays are sampled.
        n_batches: Specifies the number of batches the input rays are split into.
            The larger the number of batches, the smaller the memory footprint
            and the lower the processing speed.
        path: Optional directory-like prefix. If given, chunk `i` is passed
            `path=f"{path}/batch{i}"` and the output spatial size is saved to
            `f"{path}/spatial_size.pt"`; otherwise `net` receives `path=None`.
        **kwargs: Accepted but not used.

    Returns:
        rays_densities: A tensor of shape `(minibatch, ..., num_points_per_ray, 1)`
            denoting the opacity of each ray point.
        rays_colors: A tensor of shape `(minibatch, ..., num_points_per_ray, 3)`
            denoting the color of each ray point.

    """

    # Shapes needed to undo the flattening at the end.
    n_pts_per_ray = ray_bundle.lengths.shape[-1]
    rays_shape = ray_bundle.origins.shape[:-1]
    spatial_size = [*rays_shape, n_pts_per_ray]

    # Indices of the flattened rays, split to `n_batches` chunks.
    ray_indices = torch.arange(rays_shape.numel())
    index_chunks = torch.chunk(ray_indices, n_batches)

    # Run the network on one chunk of rays at a time.
    batch_outputs = []
    for i, batch_idx in enumerate(index_chunks):
        bundle_chunk = RayBundle(
            origins=ray_bundle.origins.view(-1, 3)[batch_idx],
            directions=ray_bundle.directions.view(-1, 3)[batch_idx],
            lengths=ray_bundle.lengths.view(-1, n_pts_per_ray)[batch_idx],
            xys=None,
        )
        chunk_path = None if path is None else f"{path}/batch{i}"
        batch_outputs.append(net(bundle_chunk, path=chunk_path))

    # Stitch the per-chunk densities and colors back together and restore
    # the layout of the input rays.
    stitched = []
    for output_i in (0, 1):
        parts = [batch_output[output_i] for batch_output in batch_outputs]
        stitched.append(torch.cat(parts, dim=0).view(*spatial_size, -1))
    rays_densities, rays_colors = stitched

    if path is not None:
        torch.save(spatial_size, f"{path}/spatial_size.pt")
    return rays_densities, rays_colors