Example 1
    def __init__(
        self,
        n_pts_per_ray: int,
        min_depth: float,
        max_depth: float,
        n_rays_per_image: int,
        image_width: int,
        image_height: int,
        stratified: bool = False,
        stratified_test: bool = False,
    ):
        """
        Args:
            n_pts_per_ray: The number of points sampled along each ray.
            min_depth: The minimum depth of a ray-point.
            max_depth: The maximum depth of a ray-point.
            n_rays_per_image: Number of Monte Carlo ray samples when training
                (`self.training==True`).
            image_width: The horizontal size of the image grid.
            image_height: The vertical size of the image grid.
            stratified: If `True`, stratifies (=randomly offsets) the depths
                of each ray point during training (`self.training==True`).
            stratified_test: If `True`, stratifies (=randomly offsets) the depths
                of each ray point during evaluation (`self.training==False`).
        """

        super().__init__()
        self._stratified = stratified
        self._stratified_test = stratified_test

        # Initialize the grid ray sampler.
        self._grid_raysampler = NDCMultinomialRaysampler(
            image_width=image_width,
            image_height=image_height,
            n_pts_per_ray=n_pts_per_ray,
            min_depth=min_depth,
            max_depth=max_depth,
        )

        # Initialize the Monte Carlo ray sampler.
        self._mc_raysampler = MonteCarloRaysampler(
            min_x=-1.0,
            max_x=1.0,
            min_y=-1.0,
            max_y=1.0,
            n_rays_per_image=n_rays_per_image,
            n_pts_per_ray=n_pts_per_ray,
            min_depth=min_depth,
            max_depth=max_depth,
        )

        # create empty ray cache
        self._ray_cache = {}
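For context, a hedged sketch of how a forward pass of this module could dispatch between the two samplers, following the docstring above (Monte Carlo sampling of `n_rays_per_image` rays when `self.training==True`, the full image grid otherwise). The method body and the stratification helper below are illustrative assumptions, not the module's actual implementation:

    def forward(self, cameras):  # illustrative sketch only
        if self.training:
            # Random subset of n_rays_per_image rays per image (Monte Carlo).
            ray_bundle = self._mc_raysampler(cameras)
            stratified = self._stratified
        else:
            # Deterministic full-image grid of rays.
            ray_bundle = self._grid_raysampler(cameras)
            stratified = self._stratified_test
        if stratified:
            # Hypothetical helper that randomly offsets the per-ray depths.
            ray_bundle = self._stratify_ray_depths(ray_bundle)
        return ray_bundle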
Example 2
    def renderer(
        batch_size=10,
        raymarcher_type=EmissionAbsorptionRaymarcher,
        n_rays_per_image=10,
        n_pts_per_ray=10,
        sphere_diameter=0.75,
    ):
        # generate NDC camera extrinsics and intrinsics
        cameras = init_cameras(batch_size, image_size=None, ndc=True)

        # get a random offset for the volume centre
        sphere_centroid = torch.randn(batch_size, 3,
                                      device=cameras.device) * 0.1

        # init the mc raysampler
        raysampler = MonteCarloRaysampler(
            min_x=-1.0,
            max_x=1.0,
            min_y=-1.0,
            max_y=1.0,
            n_rays_per_image=n_rays_per_image,
            n_pts_per_ray=n_pts_per_ray,
            min_depth=0.1,
            max_depth=2.0,
        ).to(cameras.device)

        # get the raymarcher
        raymarcher = raymarcher_type()

        # get the implicit renderer
        renderer = ImplicitRenderer(raysampler=raysampler,
                                    raymarcher=raymarcher)

        def run_renderer():
            renderer(
                cameras=cameras,
                volumetric_function=spherical_volumetric_function,
                sphere_centroid=sphere_centroid,
                sphere_diameter=sphere_diameter,
            )

        return run_renderer
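A brief usage sketch of the factory above (illustrative only, not part of the original code; the argument values are arbitrary):

run_renderer = renderer(batch_size=4, n_rays_per_image=32, n_pts_per_ray=16)
run_renderer()  # typically invoked repeatedly, e.g. inside a benchmarking loop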
Example 3
    def renderer(
        volume_size=(25, 25, 25),
        batch_size=10,
        shape="sphere",
        raymarcher_type=EmissionAbsorptionRaymarcher,
        n_rays_per_image=10,
        n_pts_per_ray=10,
    ):
        # get the volumes
        volumes = init_boundary_volume(volume_size=volume_size,
                                       batch_size=batch_size,
                                       shape=shape)[0]

        # init the mc raysampler
        raysampler = MonteCarloRaysampler(
            min_x=-1.0,
            max_x=1.0,
            min_y=-1.0,
            max_y=1.0,
            n_rays_per_image=n_rays_per_image,
            n_pts_per_ray=n_pts_per_ray,
            min_depth=0.1,
            max_depth=2.0,
        ).to(volumes.device)

        # get the raymarcher
        raymarcher = raymarcher_type()

        renderer = VolumeRenderer(raysampler=raysampler,
                                  raymarcher=raymarcher,
                                  sample_mode="bilinear")

        # generate NDC camera extrinsics and intrinsics
        cameras = init_cameras(batch_size, image_size=None, ndc=True)

        def run_renderer():
            renderer(cameras=cameras, volumes=volumes)

        return run_renderer
Example 4
    def test_monte_carlo_rendering(self,
                                   n_frames=20,
                                   volume_size=(30, 30, 30),
                                   image_size=(40, 50)):
        """
        Tests that rendering with MonteCarloRaysampler matches rendering with
        MultinomialRaysampler when the grid render is sampled at the
        corresponding Monte Carlo ray locations.
        """
        volumes = init_boundary_volume(volume_size=volume_size,
                                       batch_size=n_frames,
                                       shape="sphere")[0]

        # generate camera extrinsics and intrinsics
        cameras = init_cameras(n_frames, image_size=image_size)

        # init the grid raysampler
        raysampler_multinomial = MultinomialRaysampler(
            min_x=0.5,
            max_x=image_size[1] - 0.5,
            min_y=0.5,
            max_y=image_size[0] - 0.5,
            image_width=image_size[1],
            image_height=image_size[0],
            n_pts_per_ray=256,
            min_depth=0.5,
            max_depth=2.0,
        )

        # init the mc raysampler
        raysampler_mc = MonteCarloRaysampler(
            min_x=0.5,
            max_x=image_size[1] - 0.5,
            min_y=0.5,
            max_y=image_size[0] - 0.5,
            n_rays_per_image=3000,
            n_pts_per_ray=256,
            min_depth=0.5,
            max_depth=2.0,
        )

        # get the EA raymarcher
        raymarcher = EmissionAbsorptionRaymarcher()

        # get both mc and grid renders
        (
            (images_opacities_mc, ray_bundle_mc),
            (images_opacities_grid, ray_bundle_grid),
        ) = [
            VolumeRenderer(
                raysampler=raysampler,
                raymarcher=raymarcher,
                sample_mode="bilinear",
            )(cameras=cameras, volumes=volumes)
            for raysampler in (raysampler_mc, raysampler_multinomial)
        ]

        # convert the mc sampling locations to [-1, 1]
        sample_loc = ray_bundle_mc.xys.clone()
        sample_loc[..., 0] = 2 * (sample_loc[..., 0] / image_size[1]) - 1
        sample_loc[..., 1] = 2 * (sample_loc[..., 1] / image_size[0]) - 1

        # sample the grid render at the mc locations
        images_opacities_mc_ = torch.nn.functional.grid_sample(
            images_opacities_grid.permute(0, 3, 1, 2),
            sample_loc,
            align_corners=False)

        # check that the samples are the same
        self.assertClose(images_opacities_mc.permute(0, 3, 1, 2),
                         images_opacities_mc_,
                         atol=1e-4)
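As a standalone check of the pixel-to-NDC conversion used above (not part of the test itself):

import torch

# With image_size = (40, 50), the x coordinate of the first pixel centre (0.5)
# maps to 2 * (0.5 / 50) - 1 = -0.98 and the last (49.5) to 2 * (49.5 / 50) - 1 = 0.98,
# so the Monte Carlo sampling locations land strictly inside [-1, 1], which is what
# grid_sample expects with align_corners=False.
xs = torch.tensor([0.5, 49.5])
print(2 * (xs / 50) - 1)  # tensor([-0.9800, 0.9800])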
Example 5
def main(inference, n_iter, save_state_dict, load_state_dict,
         kl_annealing_iters, zero_kl_iters, max_kl_factor, init_scale,
         save_visualization):
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        torch.cuda.set_device(device)
    else:
        print('Please note that NeRF is a resource-demanding method.' +
              ' Running this notebook on CPU will be extremely slow.' +
              ' We recommend running the example on a GPU' +
              ' with at least 10 GB of memory.')
        device = torch.device("cpu")

    target_cameras, target_images, target_silhouettes = generate_cow_renders(
        num_views=30, azimuth_low=-180, azimuth_high=90)
    print(f'Generated {len(target_images)} images/silhouettes/cameras.')

    # render_size describes the size of both sides of the
    # rendered images in pixels. Since an advantage of
    # Neural Radiance Fields is high-quality renders
    # with a significant amount of detail, we render
    # the implicit function at double the size of the
    # target images.
    render_size = target_images.shape[1] * 2

    # Our rendered scene is centered around (0,0,0)
    # and is enclosed inside a bounding box
    # whose side is roughly equal to 3.0 (world units).
    volume_extent_world = 3.0

    # 1) Instantiate the raysamplers.

    # Here, NDCGridRaysampler generates a rectangular image
    # grid of rays whose coordinates follow the PyTorch3D
    # coordinate conventions.
    raysampler_grid = NDCGridRaysampler(
        image_height=render_size,
        image_width=render_size,
        n_pts_per_ray=128,
        min_depth=0.1,
        max_depth=volume_extent_world,
    )

    # MonteCarloRaysampler generates a random subset
    # of `n_rays_per_image` rays emitted from the image plane.
    raysampler_mc = MonteCarloRaysampler(
        min_x=-1.0,
        max_x=1.0,
        min_y=-1.0,
        max_y=1.0,
        n_rays_per_image=750,
        n_pts_per_ray=128,
        min_depth=0.1,
        max_depth=volume_extent_world,
    )

    # 2) Instantiate the raymarcher.
    # Here, we use the standard EmissionAbsorptionRaymarcher
    # which marches along each ray in order to render
    # the ray into a single 3D color vector
    # and an opacity scalar.
    raymarcher = EmissionAbsorptionRaymarcher()

    # Finally, instantiate the implicit renderers
    # for both raysamplers.
    renderer_grid = ImplicitRenderer(
        raysampler=raysampler_grid,
        raymarcher=raymarcher,
    )
    renderer_mc = ImplicitRenderer(
        raysampler=raysampler_mc,
        raymarcher=raymarcher,
    )

    # First move all relevant variables to the correct device.
    renderer_grid = renderer_grid.to(device)
    renderer_mc = renderer_mc.to(device)
    target_cameras = target_cameras.to(device)
    target_images = target_images.to(device)
    target_silhouettes = target_silhouettes.to(device)

    # Set the seed for reproducibility
    torch.manual_seed(1)

    # Instantiate the radiance field model.
    neural_radiance_field_net = NeuralRadianceField().to(device)
    if load_state_dict is not None:
        sd = torch.load(load_state_dict)
        sd["harmonic_embedding.frequencies"] = neural_radiance_field_net.harmonic_embedding.frequencies
        neural_radiance_field_net.load_state_dict(sd)

    # TYXE comment: set up the BNN depending on the desired inference
    standard_normal = dist.Normal(
        torch.tensor(0.).to(device),
        torch.tensor(1.).to(device))
    prior_kwargs = {}
    test_samples = 1
    if inference == "ml":
        prior_kwargs.update(expose_all=False, hide_all=True)
        guide = None
    elif inference == "map":
        guide = partial(pyro.infer.autoguide.AutoDelta,
                        init_loc_fn=tyxe.guides.PretrainedInitializer.from_net(
                            neural_radiance_field_net))
    elif inference == "mean-field":
        guide = partial(tyxe.guides.AutoNormal,
                        init_scale=init_scale,
                        init_loc_fn=tyxe.guides.PretrainedInitializer.from_net(
                            neural_radiance_field_net))
        test_samples = 8
    else:
        raise RuntimeError(f"Unreachable inference: {inference}")

    prior = tyxe.priors.IIDPrior(standard_normal, **prior_kwargs)
    neural_radiance_field = tyxe.PytorchBNN(neural_radiance_field_net, prior,
                                            guide)

    # TYXE comment: we need a batch of dummy data for the BNN to trace the parameters
    dummy_data = namedtuple("RayBundle", "origins directions lengths")(
        torch.randn(1, 1, 3).to(device), torch.randn(1, 1, 3).to(device),
        torch.randn(1, 1, 8).to(device))
    # Instantiate the Adam optimizer. We set its master learning rate to 1e-3.
    lr = 1e-3
    optimizer = torch.optim.Adam(
        neural_radiance_field.pytorch_parameters(dummy_data), lr=lr)

    # We sample 6 random cameras in a minibatch. Each camera
    # emits raysampler_mc.n_rays_per_image rays.
    batch_size = 6

    # Init the loss history buffers.
    loss_history_color, loss_history_sil = [], []

    if kl_annealing_iters > 0 or zero_kl_iters > 0:
        kl_factor = 0.
        kl_annealing_rate = max_kl_factor / max(kl_annealing_iters, 1)
    else:
        kl_factor = max_kl_factor
        kl_annealing_rate = 0.
    # The main optimization loop.
    for iteration in range(n_iter):
        # Once we reach 75% of the iterations,
        # decrease the learning rate of the optimizer 10-fold.
        if iteration == round(n_iter * 0.75):
            print('Decreasing LR 10-fold ...')
            optimizer = torch.optim.Adam(
                neural_radiance_field.pytorch_parameters(dummy_data),
                lr=lr * 0.1)

        # Zero the optimizer gradient.
        optimizer.zero_grad()

        # Sample random batch indices.
        batch_idx = torch.randperm(len(target_cameras))[:batch_size]

        # Sample the minibatch of cameras.
        batch_cameras = FoVPerspectiveCameras(
            R=target_cameras.R[batch_idx],
            T=target_cameras.T[batch_idx],
            znear=target_cameras.znear[batch_idx],
            zfar=target_cameras.zfar[batch_idx],
            aspect_ratio=target_cameras.aspect_ratio[batch_idx],
            fov=target_cameras.fov[batch_idx],
            device=device,
        )

        rendered_images_silhouettes, sampled_rays = renderer_mc(
            cameras=batch_cameras,
            volumetric_function=partial(batched_forward,
                                        net=neural_radiance_field))
        rendered_images, rendered_silhouettes = (
            rendered_images_silhouettes.split([3, 1], dim=-1))

        # Compute the silhouette error as the mean huber
        # loss between the predicted masks and the
        # sampled target silhouettes.
        silhouettes_at_rays = sample_images_at_mc_locs(
            target_silhouettes[batch_idx, ..., None], sampled_rays.xys)
        sil_err = huber(
            rendered_silhouettes,
            silhouettes_at_rays,
        ).abs().mean()

        # Compute the color error as the mean huber
        # loss between the rendered colors and the
        # sampled target images.
        colors_at_rays = sample_images_at_mc_locs(target_images[batch_idx],
                                                  sampled_rays.xys)
        color_err = huber(
            rendered_images,
            colors_at_rays,
        ).abs().mean()

        # The optimization loss is a simple
        # sum of the color and silhouette errors.
        # TYXE comment: we also add a KL loss for the variational posterior, scaled by the size of the data,
        # i.e. the total number of data points times the number of values that the data-dependent part of the
        # objective averages over. Effectively I'm treating this as if it were something like a Bernoulli
        # likelihood in a VAE, where the expected log likelihood is averaged over both data points and pixels.
        beta = kl_factor / (target_images.numel() + target_silhouettes.numel())
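        # Worked example of this scaling (the shapes are assumptions, purely for
        # illustration): with 30 target images of 128 x 128 x 3 plus 30 silhouettes
        # of 128 x 128, the denominator is
        # 30 * 128 * 128 * 3 + 30 * 128 * 128 = 1_474_560 + 491_520 = 1_966_080,
        # so beta is roughly kl_factor / 2e6, i.e. the KL term is weighted per pixel value.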
        kl_err = neural_radiance_field.cached_kl_loss
        loss = color_err + sil_err + beta * kl_err

        # Log the loss history.
        loss_history_color.append(float(color_err))
        loss_history_sil.append(float(sil_err))

        # Every 10 iterations, print the current values of the losses.
        if iteration % 10 == 0:
            print(f'Iteration {iteration:05d}:' +
                  f' loss color = {float(color_err):1.2e}' +
                  f' loss silhouette = {float(sil_err):1.2e}' +
                  f' loss kl = {float(kl_err):1.2e}' +
                  f' kl_factor = {kl_factor:1.3e}')

        # Take the optimization step.
        loss.backward()
        optimizer.step()

        # TYXE comment: anneal the kl rate
        if iteration >= zero_kl_iters:
            kl_factor = min(max_kl_factor, kl_factor + kl_annealing_rate)

        # Visualize the full renders every 1000 iterations.
        if iteration % 1000 == 0:
            show_idx = torch.randperm(len(target_cameras))[:1]
            fig = show_full_render(
                neural_radiance_field,
                FoVPerspectiveCameras(
                    R=target_cameras.R[show_idx],
                    T=target_cameras.T[show_idx],
                    znear=target_cameras.znear[show_idx],
                    zfar=target_cameras.zfar[show_idx],
                    aspect_ratio=target_cameras.aspect_ratio[show_idx],
                    fov=target_cameras.fov[show_idx],
                    device=device,
                ),
                target_images[show_idx][0],
                target_silhouettes[show_idx][0],
                loss_history_color,
                loss_history_sil,
                renderer_grid,
                num_forward=test_samples)
            plt.savefig(f"nerf/full_render{iteration}.png")
            plt.close(fig)

    with torch.no_grad():
        rotating_nerf_frames, uncertainty_frames = generate_rotating_nerf(
            neural_radiance_field,
            target_cameras,
            renderer_grid,
            device,
            n_frames=3 * 5,
            num_forward=test_samples,
            save_visualization=save_visualization)

    for i, (img, uncertainty) in enumerate(
            zip(
                rotating_nerf_frames.clamp(0., 1.).cpu().numpy(),
                uncertainty_frames.cpu().numpy())):
        f, ax = plt.subplots(figsize=(1.625, 1.625))
        f.subplots_adjust(0, 0, 1, 1)
        ax.imshow(img)
        ax.set_axis_off()
        f.savefig(f"nerf/final_image{i}.jpg",
                  bbox_inches="tight",
                  pad_inches=0)
        plt.close(f)

        f, ax = plt.subplots(figsize=(1.625, 1.625))
        f.subplots_adjust(0, 0, 1, 1)
        ax.imshow(uncertainty, cmap="hot", vmax=0.75**0.5)
        ax.set_axis_off()
        f.savefig(f"nerf/final_uncertainty{i}.jpg",
                  bbox_inches="tight",
                  pad_inches=0)
        plt.close(f)

    if save_state_dict is not None:
        if inference != "ml":
            raise ValueError(
                "Saving the state dict is only available for ml inference for now."
            )
        state_dict = dict(
            neural_radiance_field.named_pytorch_parameters(dummy_data))
        torch.save(state_dict, save_state_dict)

    test_cameras, test_images, test_silhouettes = generate_cow_renders(
        num_views=10, azimuth_low=90, azimuth_high=180)

    del renderer_mc
    del target_cameras
    del target_images
    del target_silhouettes
    torch.cuda.empty_cache()

    test_cameras = test_cameras.to(device)
    test_images = test_images.to(device)
    test_silhouettes = test_silhouettes.to(device)

    # TODO remove duplication from training code for test error
    with torch.no_grad():
        sil_err = 0.
        color_err = 0.
        for i in range(len(test_cameras)):
            batch_idx = [i]

            # Sample the minibatch of cameras.
            batch_cameras = FoVPerspectiveCameras(
                R=test_cameras.R[batch_idx],
                T=test_cameras.T[batch_idx],
                znear=test_cameras.znear[batch_idx],
                zfar=test_cameras.zfar[batch_idx],
                aspect_ratio=test_cameras.aspect_ratio[batch_idx],
                fov=test_cameras.fov[batch_idx],
                device=device,
            )

            img_list, sils_list, sampled_rays_list = [], [], []
            for _ in range(test_samples):
                rendered_images_silhouettes, sampled_rays = renderer_grid(
                    cameras=batch_cameras,
                    volumetric_function=partial(batched_forward,
                                                net=neural_radiance_field))
                imgs, sils = (rendered_images_silhouettes.split([3, 1],
                                                                dim=-1))
                img_list.append(imgs)
                sils_list.append(sils)
                sampled_rays_list.append(sampled_rays.xys)

            assert sampled_rays_list[0].eq(
                torch.stack(sampled_rays_list)).all()

            rendered_images = torch.stack(img_list).mean(0)
            rendered_silhouettes = torch.stack(sils_list).mean(0)

            # Compute the silhouette error as the mean huber
            # loss between the predicted masks and the
            # sampled target silhouettes.
            # TYXE comment: sampled_rays are always the same for renderer_grid
            silhouettes_at_rays = sample_images_at_mc_locs(
                test_silhouettes[batch_idx, ..., None], sampled_rays.xys)
            sil_err += huber(
                rendered_silhouettes,
                silhouettes_at_rays,
            ).abs().mean().item() / len(test_cameras)

            # Compute the color error as the mean huber
            # loss between the rendered colors and the
            # sampled target images.
            colors_at_rays = sample_images_at_mc_locs(test_images[batch_idx],
                                                      sampled_rays.xys)
            color_err += huber(
                rendered_images,
                colors_at_rays,
            ).abs().mean().item() / len(test_cameras)

    print(f"Test error: sil={sil_err:1.3e}; col={color_err:1.3e}")
Example 6
    # coordinate conventions.
    raysampler_grid = NDCGridRaysampler(
        image_height=render_size,
        image_width=render_size,
        n_pts_per_ray=128,
        min_depth=0.1,
        max_depth=volume_extent_world,
    )

    # MonteCarloRaysampler generates a random subset
    # of `n_rays_per_image` rays emitted from the image plane.
    raysampler_mc = MonteCarloRaysampler(
        min_x=-1.0,
        max_x=1.0,
        min_y=-1.0,
        max_y=1.0,
        n_rays_per_image=750,
        n_pts_per_ray=128,
        min_depth=0.1,
        max_depth=volume_extent_world,
    )

    # 2) Instantiate the raymarcher.
    # Here, we use the standard EmissionAbsorptionRaymarcher
    # which marches along each ray in order to render
    # the ray into a single 3D color vector
    # and an opacity scalar.
    raymarcher = EmissionAbsorptionRaymarcher()

    # Finally, instantiate the implicit renderers
    # for both raysamplers.
    renderer_grid = ImplicitRenderer(