Example 1
    def test_input_types(self):
        """
        Check that ValueErrors are thrown where expected.
        """
        # check the constructor
        for bad_raysampler in (None, 5, []):
            for bad_raymarcher in (None, 5, []):
                with self.assertRaises(ValueError):
                    ImplicitRenderer(raysampler=bad_raysampler,
                                     raymarcher=bad_raymarcher)

        # init a trivial renderer
        renderer = ImplicitRenderer(
            raysampler=NDCMultinomialRaysampler(
                image_width=100,
                image_height=100,
                n_pts_per_ray=10,
                min_depth=0.1,
                max_depth=1.0,
            ),
            raymarcher=EmissionAbsorptionRaymarcher(),
        )

        # get default cameras
        cameras = init_cameras()

        for bad_volumetric_function in (None, 5, []):
            with self.assertRaises(ValueError):
                renderer(cameras=cameras,
                         volumetric_function=bad_volumetric_function)
Example 2
    def renderer(
        batch_size=10,
        raymarcher_type=EmissionAbsorptionRaymarcher,
        n_rays_per_image=10,
        n_pts_per_ray=10,
        sphere_diameter=0.75,
    ):
        # generate NDC camera extrinsics and intrinsics
        cameras = init_cameras(batch_size, image_size=None, ndc=True)

        # get a random offset of the volume
        sphere_centroid = torch.randn(batch_size, 3,
                                      device=cameras.device) * 0.1

        # init the mc raysampler
        raysampler = MonteCarloRaysampler(
            min_x=-1.0,
            max_x=1.0,
            min_y=-1.0,
            max_y=1.0,
            n_rays_per_image=n_rays_per_image,
            n_pts_per_ray=n_pts_per_ray,
            min_depth=0.1,
            max_depth=2.0,
        ).to(cameras.device)

        # get the raymarcher
        raymarcher = raymarcher_type()

        # get the implicit renderer
        renderer = ImplicitRenderer(raysampler=raysampler,
                                    raymarcher=raymarcher)

        def run_renderer():
            renderer(
                cameras=cameras,
                volumetric_function=spherical_volumetric_function,
                sphere_centroid=sphere_centroid,
                sphere_diameter=sphere_diameter,
            )

        return run_renderer
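
Examples 2-4 rely on a `spherical_volumetric_function` helper that is not shown on this page. Purely as an illustration of what such a function could look like (an assumption, not the actual test helper), the sketch below emits high density inside the sphere and colors each point by its normalized offset from the centroid:

    import torch
    from pytorch3d.renderer import ray_bundle_to_ray_points

    def spherical_volumetric_function_sketch(
        ray_bundle, sphere_centroid, sphere_diameter, **kwargs
    ):
        # world-space coordinates of every sampled ray point
        rays_points_world = ray_bundle_to_ray_points(ray_bundle)
        batch = rays_points_world.shape[0]
        # broadcast the (batch, 3) centroids over all ray/point dimensions
        centroid = sphere_centroid.reshape(
            batch, *([1] * (rays_points_world.dim() - 2)), 3
        )
        dist = ((rays_points_world - centroid) ** 2).sum(-1, keepdim=True).sqrt()
        # density 1 inside the sphere, 0 outside
        rays_densities = (dist < 0.5 * sphere_diameter).float()
        # color each point by its normalized offset, mapped to [0, 1]
        rays_features = (
            torch.nn.functional.normalize(rays_points_world - centroid, dim=-1)
            * 0.5 + 0.5
        )
        return rays_densities, rays_features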
Example 3
    def _rotating_gif(self,
                      image_size,
                      n_frames=50,
                      fps=15,
                      sphere_diameter=0.5):
        """
        Render a gif animation of a rotating sphere (runs only if `DEBUG==True`).
        """

        if not DEBUG:
            # do not run this if debug is False
            return

        # generate camera extrinsics and intrinsics
        cameras = init_cameras(n_frames, image_size=image_size)

        # init the grid raysampler
        raysampler = MultinomialRaysampler(
            min_x=0.5,
            max_x=image_size[1] - 0.5,
            min_y=0.5,
            max_y=image_size[0] - 0.5,
            image_width=image_size[1],
            image_height=image_size[0],
            n_pts_per_ray=256,
            min_depth=0.1,
            max_depth=2.0,
        )

        # get the EA raymarcher
        raymarcher = EmissionAbsorptionRaymarcher()

        # get the implicit renderer
        renderer = ImplicitRenderer(raysampler=raysampler,
                                    raymarcher=raymarcher)

        # the sphere centroid is fixed at the origin
        sphere_centroid = torch.zeros(n_frames, 3, device=cameras.device)

        # run the renderer
        images_opacities = renderer(
            cameras=cameras,
            volumetric_function=spherical_volumetric_function,
            sphere_centroid=sphere_centroid,
            sphere_diameter=sphere_diameter,
        )[0]

        # split the output into the rendered images and the alpha channel
        images, opacities = images_opacities[..., :3], images_opacities[..., 3]

        # export the gif
        outdir = tempfile.gettempdir() + "/test_implicit_renderer_gifs"
        os.makedirs(outdir, exist_ok=True)
        frames = []
        for image, opacity in zip(images, opacities):
            image_pil = Image.fromarray((torch.cat(
                (image, opacity[..., None].clamp(0.0, 1.0).repeat(1, 1, 3)),
                dim=1,
            ).detach().cpu().numpy() * 255.0).astype(np.uint8))
            frames.append(image_pil)
        outfile = os.path.join(outdir, "rotating_sphere.gif")
        frames[0].save(
            outfile,
            save_all=True,
            append_images=frames[1:],
            duration=1000 // fps,  # per-frame display time in ms
            loop=0,
        )
        print(f"exported {outfile}")
Example 4
    def _compare_with_meshes_renderer(self,
                                      image_size,
                                      batch_size=11,
                                      sphere_diameter=0.6):
        """
        Generate a spherical RGB volumetric function and its corresponding mesh
        and check whether MeshRenderer returns the same images as the
        corresponding ImplicitRenderer.
        """

        # generate NDC camera extrinsics and intrinsics
        cameras = init_cameras(batch_size, image_size=image_size, ndc=True)

        # get a random offset of the volume
        sphere_centroid = torch.randn(batch_size, 3,
                                      device=cameras.device) * 0.1
        sphere_centroid.requires_grad = True

        # init the grid raysampler with the ndc grid
        raysampler = NDCMultinomialRaysampler(
            image_width=image_size[1],
            image_height=image_size[0],
            n_pts_per_ray=256,
            min_depth=0.1,
            max_depth=2.0,
        )

        # get the EA raymarcher
        raymarcher = EmissionAbsorptionRaymarcher()

        # jitter the camera intrinsics a bit for each render
        cameras_randomized = cameras.clone()
        cameras_randomized.principal_point = (
            torch.randn_like(cameras.principal_point) * 0.3)
        cameras_randomized.focal_length = (
            cameras.focal_length +
            torch.randn_like(cameras.focal_length) * 0.2)

        # the list of differentiable camera vars
        cam_vars = ("R", "T", "focal_length", "principal_point")
        # enable the gradient caching for the camera variables
        for cam_var in cam_vars:
            getattr(cameras_randomized, cam_var).requires_grad = True

        # get the implicit renderer
        images_opacities = ImplicitRenderer(
            raysampler=raysampler, raymarcher=raymarcher)(
                cameras=cameras_randomized,
                volumetric_function=spherical_volumetric_function,
                sphere_centroid=sphere_centroid,
                sphere_diameter=sphere_diameter,
            )[0]

        # check that the renderer does not erase gradients
        loss = images_opacities.sum()
        loss.backward()
        for check_var in (
                *[
                    getattr(cameras_randomized, cam_var)
                    for cam_var in cam_vars
                ],
                sphere_centroid,
        ):
            self.assertIsNotNone(check_var.grad)

        # instantiate the corresponding spherical mesh
        ico = ico_sphere(level=4, device=cameras.device).extend(batch_size)
        verts = (torch.nn.functional.normalize(ico.verts_padded(), dim=-1) *
                 sphere_diameter + sphere_centroid[:, None])
        meshes = Meshes(
            verts=verts,
            faces=ico.faces_padded(),
            textures=TexturesVertex(verts_features=(
                torch.nn.functional.normalize(verts, dim=-1) * 0.5 + 0.5)),
        )

        # instantiate the corresponding mesh renderer
        lights = PointLights(device=cameras.device, location=[[0.0, 0.0, 0.0]])
        renderer_textured = MeshRenderer(
            rasterizer=MeshRasterizer(
                cameras=cameras_randomized,
                raster_settings=RasterizationSettings(
                    image_size=image_size,
                    blur_radius=1e-3,
                    faces_per_pixel=10,
                    z_clip_value=None,
                    perspective_correct=False,
                ),
            ),
            shader=SoftPhongShader(
                device=cameras.device,
                cameras=cameras_randomized,
                lights=lights,
                materials=Materials(
                    ambient_color=((2.0, 2.0, 2.0), ),
                    diffuse_color=((0.0, 0.0, 0.0), ),
                    specular_color=((0.0, 0.0, 0.0), ),
                    shininess=64,
                    device=cameras.device,
                ),
                blend_params=BlendParams(sigma=1e-3,
                                         gamma=1e-4,
                                         background_color=(0.0, 0.0, 0.0)),
            ),
        )

        # render the meshes
        images_opacities_meshes = renderer_textured(meshes,
                                                    cameras=cameras_randomized,
                                                    lights=lights)

        if DEBUG:
            outdir = tempfile.gettempdir() + "/test_implicit_vs_mesh_renderer"
            os.makedirs(outdir, exist_ok=True)

            frames = []
            for (image_opacity,
                 image_opacity_mesh) in zip(images_opacities,
                                            images_opacities_meshes):
                image, opacity = image_opacity.split([3, 1], dim=-1)
                image_mesh, opacity_mesh = image_opacity_mesh.split([3, 1],
                                                                    dim=-1)
                diff_image = (((image - image_mesh) * 0.5 + 0.5).mean(
                    dim=2, keepdim=True).repeat(1, 1, 3))
                image_pil = Image.fromarray((torch.cat(
                    (
                        image,
                        image_mesh,
                        diff_image,
                        opacity.repeat(1, 1, 3),
                        opacity_mesh.repeat(1, 1, 3),
                    ),
                    dim=1,
                ).detach().cpu().numpy() * 255.0).astype(np.uint8))
                frames.append(image_pil)

            # export gif
            outfile = os.path.join(outdir, "implicit_vs_mesh_render.gif")
            frames[0].save(
                outfile,
                save_all=True,
                append_images=frames[1:],
                duration=batch_size // 15,
                loop=0,
            )
            print(f"exported {outfile}")

            # export concatenated frames
            outfile_cat = os.path.join(outdir, "implicit_vs_mesh_render.png")
            Image.fromarray(
                np.concatenate([np.array(f) for f in frames],
                               axis=0)).save(outfile_cat)
            print(f"exported {outfile_cat}")

        # compare the renders
        diff = (images_opacities - images_opacities_meshes).abs().mean(dim=-1)
        mu_diff = diff.mean(dim=(1, 2))
        std_diff = diff.std(dim=(1, 2))
        self.assertClose(mu_diff, torch.zeros_like(mu_diff), atol=5e-2)
        self.assertClose(std_diff, torch.zeros_like(std_diff), atol=6e-2)
Example 5
    def __init__(
            self,
            image_size: Tuple[int, int],
            n_pts_per_ray: int,
            n_pts_per_ray_fine: int,
            n_rays_per_image: int,
            min_depth: float,
            max_depth: float,
            stratified: bool,
            stratified_test: bool,
            chunk_size_test: int,
            n_harmonic_functions_xyz: int = 6,
            n_harmonic_functions_dir: int = 4,
            n_hidden_neurons_xyz: int = 256,
            n_hidden_neurons_dir: int = 128,
            n_layers_xyz: int = 8,
            append_xyz: List[int] = (5, ),
            density_noise_std: float = 0.0,
    ):
        """
        Args:
            image_size: The size of the rendered image (`[height, width]`).
            n_pts_per_ray: The number of points sampled along each ray for the
                coarse rendering pass.
            n_pts_per_ray_fine: The number of points sampled along each ray for the
                fine rendering pass.
            n_rays_per_image: Number of Monte Carlo ray samples when training
                (`self.training==True`).
            min_depth: The minimum depth of a sampled ray-point for the coarse rendering.
            max_depth: The maximum depth of a sampled ray-point for the coarse rendering.
            stratified: If `True`, stratifies (=randomly offsets) the depths
                of each ray point during training (`self.training==True`).
            stratified_test: If `True`, stratifies (=randomly offsets) the depths
                of each ray point during evaluation (`self.training==False`).
            chunk_size_test: The number of rays in each chunk of image rays.
                Active only when `self.training==False`.
            n_harmonic_functions_xyz: The number of harmonic functions
                used to form the harmonic embedding of 3D point locations.
            n_harmonic_functions_dir: The number of harmonic functions
                used to form the harmonic embedding of the ray directions.
            n_hidden_neurons_xyz: The number of hidden units in the
                fully connected layers of the MLP that accepts the 3D point
                locations and outputs the occupancy field with the intermediate
                features.
            n_hidden_neurons_dir: The number of hidden units in the
                fully connected layers of the MLP that accepts the intermediate
                features and ray directions and outputs the radiance field
                (per-point colors).
            n_layers_xyz: The number of layers of the MLP that outputs the
                occupancy field.
            append_xyz: The list of indices of the skip layers of the occupancy MLP.
                Prior to evaluating the skip layers, the tensor which was input to MLP
                is appended to the skip layer input.
            density_noise_std: The standard deviation of the random normal noise
                added to the output of the occupancy MLP.
                Active only when `self.training==True`.
        """

        super().__init__()

        # The renderers and implicit functions are stored under the fine/coarse
        # keys in ModuleDict PyTorch modules.
        self._renderer = torch.nn.ModuleDict()
        self._implicit_function = torch.nn.ModuleDict()

        # Init the EA raymarcher used by both passes.
        raymarcher = EmissionAbsorptionNeRFRaymarcher()

        # Parse out image dimensions.
        image_height, image_width = image_size

        for render_pass in ("coarse", "fine"):
            if render_pass == "coarse":
                # Initialize the coarse raysampler.
                raysampler = NeRFRaysampler(
                    n_pts_per_ray=n_pts_per_ray,
                    min_depth=min_depth,
                    max_depth=max_depth,
                    stratified=stratified,
                    stratified_test=stratified_test,
                    n_rays_per_image=n_rays_per_image,
                    image_height=image_height,
                    image_width=image_width,
                )
            elif render_pass == "fine":
                # Initialize the fine raysampler.
                raysampler = ProbabilisticRaysampler(
                    n_pts_per_ray=n_pts_per_ray_fine,
                    stratified=stratified,
                    stratified_test=stratified_test,
                )
            else:
                raise ValueError(f"No such rendering pass {render_pass}")

            # Initialize the fine/coarse renderer.
            self._renderer[render_pass] = ImplicitRenderer(
                raysampler=raysampler,
                raymarcher=raymarcher,
            )

            # Instantiate the fine/coarse NeuralRadianceField module.
            self._implicit_function[render_pass] = NeuralRadianceField(
                n_harmonic_functions_xyz=n_harmonic_functions_xyz,
                n_harmonic_functions_dir=n_harmonic_functions_dir,
                n_hidden_neurons_xyz=n_hidden_neurons_xyz,
                n_hidden_neurons_dir=n_hidden_neurons_dir,
                n_layers_xyz=n_layers_xyz,
                append_xyz=append_xyz,
            )

        self._density_noise_std = density_noise_std
        self._chunk_size_test = chunk_size_test
        self._image_size = image_size
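
The constructor above belongs to a two-pass (coarse/fine) NeRF-style renderer module; the snippet does not show the class name, so the name used below is an assumption. A minimal sketch of instantiating it with typical NeRF hyper-parameters:

    # `RadianceFieldRenderer` is an assumed name for the class whose
    # __init__ is shown above.
    model = RadianceFieldRenderer(
        image_size=(400, 400),
        n_pts_per_ray=64,
        n_pts_per_ray_fine=64,
        n_rays_per_image=1024,
        min_depth=0.1,
        max_depth=3.0,
        stratified=True,
        stratified_test=False,
        chunk_size_test=4096,
    )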
Example 6
def main(inference, n_iter, save_state_dict, load_state_dict,
         kl_annealing_iters, zero_kl_iters, max_kl_factor, init_scale,
         save_visualization):
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        torch.cuda.set_device(device)
    else:
        print('Please note that NeRF is a resource-demanding method.' +
              ' Running this notebook on CPU will be extremely slow.' +
              ' We recommend running the example on a GPU' +
              ' with at least 10 GB of memory.')
        device = torch.device("cpu")

    target_cameras, target_images, target_silhouettes = generate_cow_renders(
        num_views=30, azimuth_low=-180, azimuth_high=90)
    print(f'Generated {len(target_images)} images/silhouettes/cameras.')

    # render_size describes the size of both sides of the
    # rendered images in pixels. Since an advantage of
    # Neural Radiance Fields is high-quality rendering
    # with a significant amount of detail, we render
    # the implicit function at double the size of the
    # target images.
    render_size = target_images.shape[1] * 2

    # Our rendered scene is centered around (0,0,0)
    # and is enclosed inside a bounding box
    # whose side is roughly equal to 3.0 (world units).
    volume_extent_world = 3.0

    # 1) Instantiate the raysamplers.

    # Here, NDCGridRaysampler generates a rectangular image
    # grid of rays whose coordinates follow the PyTorch3D
    # coordinate conventions.
    raysampler_grid = NDCGridRaysampler(
        image_height=render_size,
        image_width=render_size,
        n_pts_per_ray=128,
        min_depth=0.1,
        max_depth=volume_extent_world,
    )

    # MonteCarloRaysampler generates a random subset
    # of `n_rays_per_image` rays emitted from the image plane.
    raysampler_mc = MonteCarloRaysampler(
        min_x=-1.0,
        max_x=1.0,
        min_y=-1.0,
        max_y=1.0,
        n_rays_per_image=750,
        n_pts_per_ray=128,
        min_depth=0.1,
        max_depth=volume_extent_world,
    )

    # 2) Instantiate the raymarcher.
    # Here, we use the standard EmissionAbsorptionRaymarcher
    # which marches along each ray in order to render
    # the ray into a single 3D color vector
    # and an opacity scalar.
    raymarcher = EmissionAbsorptionRaymarcher()

    # Finally, instantiate the implicit renderers
    # for both raysamplers.
    renderer_grid = ImplicitRenderer(
        raysampler=raysampler_grid,
        raymarcher=raymarcher,
    )
    renderer_mc = ImplicitRenderer(
        raysampler=raysampler_mc,
        raymarcher=raymarcher,
    )

    # First move all relevant variables to the correct device.
    renderer_grid = renderer_grid.to(device)
    renderer_mc = renderer_mc.to(device)
    target_cameras = target_cameras.to(device)
    target_images = target_images.to(device)
    target_silhouettes = target_silhouettes.to(device)

    # Set the seed for reproducibility
    torch.manual_seed(1)

    # Instantiate the radiance field model.
    neural_radiance_field_net = NeuralRadianceField().to(device)
    if load_state_dict is not None:
        sd = torch.load(load_state_dict)
        sd["harmonic_embedding.frequencies"] = neural_radiance_field_net.harmonic_embedding.frequencies
        neural_radiance_field_net.load_state_dict(sd)

    # TYXE comment: set up the BNN depending on the desired inference
    standard_normal = dist.Normal(
        torch.tensor(0.).to(device),
        torch.tensor(1.).to(device))
    prior_kwargs = {}
    test_samples = 1
    if inference == "ml":
        prior_kwargs.update(expose_all=False, hide_all=True)
        guide = None
    elif inference == "map":
        guide = partial(pyro.infer.autoguide.AutoDelta,
                        init_loc_fn=tyxe.guides.PretrainedInitializer.from_net(
                            neural_radiance_field_net))
    elif inference == "mean-field":
        guide = partial(tyxe.guides.AutoNormal,
                        init_scale=init_scale,
                        init_loc_fn=tyxe.guides.PretrainedInitializer.from_net(
                            neural_radiance_field_net))
        test_samples = 8
    else:
        raise RuntimeError(f"Unreachable inference: {inference}")

    prior = tyxe.priors.IIDPrior(standard_normal, **prior_kwargs)
    neural_radiance_field = tyxe.PytorchBNN(neural_radiance_field_net, prior,
                                            guide)

    # TYXE comment: we need a batch of dummy data for the BNN to trace the parameters
    dummy_data = namedtuple("RayBundle", "origins directions lengths")(
        torch.randn(1, 1, 3).to(device), torch.randn(1, 1, 3).to(device),
        torch.randn(1, 1, 8).to(device))
    # Instantiate the Adam optimizer. We set its master learning rate to 1e-3.
    lr = 1e-3
    optimizer = torch.optim.Adam(
        neural_radiance_field.pytorch_parameters(dummy_data), lr=lr)

    # We sample 6 random cameras in a minibatch. Each camera
    # emits raysampler_mc.n_rays_per_image rays.
    batch_size = 6

    # Init the loss history buffers.
    loss_history_color, loss_history_sil = [], []

    if kl_annealing_iters > 0 or zero_kl_iters > 0:
        kl_factor = 0.
        kl_annealing_rate = max_kl_factor / max(kl_annealing_iters, 1)
    else:
        kl_factor = max_kl_factor
        kl_annealing_rate = 0.
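    # Example: with zero_kl_iters=200, kl_annealing_iters=1000 and
    # max_kl_factor=1.0, the KL weight stays at 0 for the first 200
    # iterations and then grows by 1e-3 per iteration (see the end of
    # the loop body) until it saturates at 1.0.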
    # The main optimization loop.
    for iteration in range(n_iter):
        # Once we reach 75% of the iterations,
        # decrease the learning rate of the optimizer 10-fold.
        if iteration == round(n_iter * 0.75):
            print('Decreasing LR 10-fold ...')
            optimizer = torch.optim.Adam(
                neural_radiance_field.pytorch_parameters(dummy_data),
                lr=lr * 0.1)

        # Zero the optimizer gradient.
        optimizer.zero_grad()

        # Sample random batch indices.
        batch_idx = torch.randperm(len(target_cameras))[:batch_size]

        # Sample the minibatch of cameras.
        batch_cameras = FoVPerspectiveCameras(
            R=target_cameras.R[batch_idx],
            T=target_cameras.T[batch_idx],
            znear=target_cameras.znear[batch_idx],
            zfar=target_cameras.zfar[batch_idx],
            aspect_ratio=target_cameras.aspect_ratio[batch_idx],
            fov=target_cameras.fov[batch_idx],
            device=device,
        )

        rendered_images_silhouettes, sampled_rays = renderer_mc(
            cameras=batch_cameras,
            volumetric_function=partial(batched_forward,
                                        net=neural_radiance_field))
        rendered_images, rendered_silhouettes = (
            rendered_images_silhouettes.split([3, 1], dim=-1))

        # Compute the silhouette error as the mean huber
        # loss between the predicted masks and the
        # sampled target silhouettes.
        silhouettes_at_rays = sample_images_at_mc_locs(
            target_silhouettes[batch_idx, ..., None], sampled_rays.xys)
        sil_err = huber(
            rendered_silhouettes,
            silhouettes_at_rays,
        ).abs().mean()

        # Compute the color error as the mean huber
        # loss between the rendered colors and the
        # sampled target images.
        colors_at_rays = sample_images_at_mc_locs(target_images[batch_idx],
                                                  sampled_rays.xys)
        color_err = huber(
            rendered_images,
            colors_at_rays,
        ).abs().mean()

        # The optimization loss is a simple
        # sum of the color and silhouette errors.
        # TYXE comment: we also add a kl loss for the variational posterior scaled by the size of the data
        # i.e. the total number of data points times the number of values that the data-dependent part of the
        # objective averages over. Effectively I'm treating this as if it were something like a Bernoulli likelihood
        # in a VAE where the expected log likelihood is averaged over both data points and pixels
        beta = kl_factor / (target_images.numel() + target_silhouettes.numel())
        kl_err = neural_radiance_field.cached_kl_loss
        loss = color_err + sil_err + beta * kl_err

        # Log the loss history.
        loss_history_color.append(float(color_err))
        loss_history_sil.append(float(sil_err))

        # Every 10 iterations, print the current values of the losses.
        if iteration % 10 == 0:
            print(f'Iteration {iteration:05d}:' +
                  f' loss color = {float(color_err):1.2e}' +
                  f' loss silhouette = {float(sil_err):1.2e}' +
                  f' loss kl = {float(kl_err):1.2e}' +
                  f' kl_factor = {kl_factor:1.3e}')

        # Take the optimization step.
        loss.backward()
        optimizer.step()

        # TYXE comment: anneal the kl rate
        if iteration >= zero_kl_iters:
            kl_factor = min(max_kl_factor, kl_factor + kl_annealing_rate)

        # Visualize the full renders every 1000 iterations.
        if iteration % 1000 == 0:
            show_idx = torch.randperm(len(target_cameras))[:1]
            fig = show_full_render(
                neural_radiance_field,
                FoVPerspectiveCameras(
                    R=target_cameras.R[show_idx],
                    T=target_cameras.T[show_idx],
                    znear=target_cameras.znear[show_idx],
                    zfar=target_cameras.zfar[show_idx],
                    aspect_ratio=target_cameras.aspect_ratio[show_idx],
                    fov=target_cameras.fov[show_idx],
                    device=device,
                ),
                target_images[show_idx][0],
                target_silhouettes[show_idx][0],
                loss_history_color,
                loss_history_sil,
                renderer_grid,
                num_forward=test_samples)
            plt.savefig(f"nerf/full_render{iteration}.png")
            plt.close(fig)

    with torch.no_grad():
        rotating_nerf_frames, uncertainty_frames = generate_rotating_nerf(
            neural_radiance_field,
            target_cameras,
            renderer_grid,
            device,
            n_frames=3 * 5,
            num_forward=test_samples,
            save_visualization=save_visualization)

    for i, (img, uncertainty) in enumerate(
            zip(
                rotating_nerf_frames.clamp(0., 1.).cpu().numpy(),
                uncertainty_frames.cpu().numpy())):
        f, ax = plt.subplots(figsize=(1.625, 1.625))
        f.subplots_adjust(0, 0, 1, 1)
        ax.imshow(img)
        ax.set_axis_off()
        f.savefig(f"nerf/final_image{i}.jpg",
                  bbox_inches="tight",
                  pad_inches=0)
        plt.close(f)

        f, ax = plt.subplots(figsize=(1.625, 1.625))
        f.subplots_adjust(0, 0, 1, 1)
        ax.imshow(uncertainty, cmap="hot", vmax=0.75**0.5)
        ax.set_axis_off()
        f.savefig(f"nerf/final_uncertainty{i}.jpg",
                  bbox_inches="tight",
                  pad_inches=0)
        plt.close(f)

    if save_state_dict is not None:
        if inference != "ml":
            raise ValueError(
                "Saving the state dict is only available for ml inference for now."
            )
        state_dict = dict(
            neural_radiance_field.named_pytorch_parameters(dummy_data))
        torch.save(state_dict, save_state_dict)

    test_cameras, test_images, test_silhouettes = generate_cow_renders(
        num_views=10, azimuth_low=90, azimuth_high=180)

    del renderer_mc
    del target_cameras
    del target_images
    del target_silhouettes
    torch.cuda.empty_cache()

    test_cameras = test_cameras.to(device)
    test_images = test_images.to(device)
    test_silhouettes = test_silhouettes.to(device)

    # TODO remove duplication from training code for test error
    with torch.no_grad():
        sil_err = 0.
        color_err = 0.
        for i in range(len(test_cameras)):
            batch_idx = [i]

            # Sample the minibatch of cameras.
            batch_cameras = FoVPerspectiveCameras(
                R=test_cameras.R[batch_idx],
                T=test_cameras.T[batch_idx],
                znear=test_cameras.znear[batch_idx],
                zfar=test_cameras.zfar[batch_idx],
                aspect_ratio=test_cameras.aspect_ratio[batch_idx],
                fov=test_cameras.fov[batch_idx],
                device=device,
            )

            img_list, sils_list, sampled_rays_list = [], [], []
            for _ in range(test_samples):
                rendered_images_silhouettes, sampled_rays = renderer_grid(
                    cameras=batch_cameras,
                    volumetric_function=partial(batched_forward,
                                                net=neural_radiance_field))
                imgs, sils = (rendered_images_silhouettes.split([3, 1],
                                                                dim=-1))
                img_list.append(imgs)
                sils_list.append(sils)
                sampled_rays_list.append(sampled_rays.xys)

            assert sampled_rays_list[0].eq(
                torch.stack(sampled_rays_list)).all()

            rendered_images = torch.stack(img_list).mean(0)
            rendered_silhouettes = torch.stack(sils_list).mean(0)

            # Compute the silhouette error as the mean huber
            # loss between the predicted masks and the
            # sampled target silhouettes.
            # TYXE comment: sampled_rays are always the same for renderer_grid
            silhouettes_at_rays = sample_images_at_mc_locs(
                test_silhouettes[batch_idx, ..., None], sampled_rays.xys)
            sil_err += huber(
                rendered_silhouettes,
                silhouettes_at_rays,
            ).abs().mean().item() / len(test_cameras)

            # Compute the color error as the mean huber
            # loss between the rendered colors and the
            # sampled target images.
            colors_at_rays = sample_images_at_mc_locs(test_images[batch_idx],
                                                      sampled_rays.xys)
            color_err += huber(
                rendered_images,
                colors_at_rays,
            ).abs().mean().item() / len(test_cameras)

    print(f"Test error: sil={sil_err:1.3e}; col={color_err:1.3e}")