Example 1
def save_8bit_png(img_and_path):
    """Save an 8bit numpy array as a PNG on disk.

  Args:
    img_and_path: A tuple of an image (numpy array, 8bit,
      [height, width, channels]) and a path where the image is saved (string).
  """
    img, pth = img_and_path
    with utils.open_file(pth, 'wb') as imgout:
        Image.fromarray(img).save(imgout, 'PNG')
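
A minimal usage sketch (hypothetical array and path; it assumes `utils.open_file` behaves like the built-in `open` and that PIL's `Image` is imported). The single-tuple signature is what lets Example 6 map this function over a list of (image, path) pairs with `parallel_write_images`:

import numpy as np

img = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
save_8bit_png((img, '/tmp/example.png'))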
Example 2
    def _load_renderings(self, args):
        """Load images from disk."""
        if args.render_path:
            raise ValueError(
                "render_path cannot be used for the blender dataset.")
        with utils.open_file(
                path.join(args.data_dir,
                          "transforms_{}.json".format(self.split)), "r") as fp:
            meta = json.load(fp)
        images = []
        cams = []
        for i in range(len(meta["frames"])):
            frame = meta["frames"][i]
            fname = path.join(args.data_dir, frame["file_path"] + ".png")
            with utils.open_file(fname, "rb") as imgin:
                image = np.array(Image.open(imgin), dtype=np.float32) / 255.
                if args.factor == 2:
                    halfres_h, halfres_w = [hw // 2 for hw in image.shape[:2]]
                    image = cv2.resize(image, (halfres_w, halfres_h),
                                       interpolation=cv2.INTER_AREA)
                elif args.factor > 0:
                    raise ValueError(
                        "Blender dataset only supports factor=0 or 2, got "
                        "factor={}.".format(args.factor))
            cams.append(np.array(frame["transform_matrix"], dtype=np.float32))
            images.append(image)
        self.images = np.stack(images, axis=0)
        if args.white_bkgd:
            # Composite RGBA onto a white background using the alpha channel.
            self.images = (
                self.images[Ellipsis, :3] * self.images[Ellipsis, -1:] +
                (1. - self.images[Ellipsis, -1:]))
        else:
            self.images = self.images[Ellipsis, :3]
        self.h, self.w = self.images.shape[1:3]
        self.resolution = self.h * self.w
        self.camtoworlds = np.stack(cams, axis=0)
        camera_angle_x = float(meta["camera_angle_x"])
        # Pinhole model: focal = 0.5 * width / tan(0.5 * horizontal FoV).
        self.focal = .5 * self.w / np.tan(.5 * camera_angle_x)
        self.n_examples = self.images.shape[0]
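
As a sanity check on the focal-length line above (a sketch, not part of the original file): with the camera_angle_x of roughly 0.6911 radians found in the synthetic Blender scenes' transforms_*.json and 800-pixel-wide images, the pinhole relation gives a focal length of about 1111 pixels:

import numpy as np

camera_angle_x = 0.6911112070083618  # typical Blender-scene value
w = 800
print(.5 * w / np.tan(.5 * camera_angle_x))  # ~1111.1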
Example 3
    def _load_renderings(self, args):
        """Load images from disk."""
        # Load images.
        imgdir_suffix = ""
        if args.factor > 0:
            imgdir_suffix = "_{}".format(args.factor)
            factor = args.factor
        else:
            factor = 1
        imgdir = path.join(args.data_dir, "images" + imgdir_suffix)
        if not utils.file_exists(imgdir):
            raise ValueError("Image folder {} doesn't exist.".format(imgdir))
        imgfiles = [
            path.join(imgdir, f) for f in sorted(utils.listdir(imgdir))
            if f.endswith("JPG") or f.endswith("jpg") or f.endswith("png")
        ]
        images = []
        for imgfile in imgfiles:
            with utils.open_file(imgfile, "rb") as imgin:
                image = np.array(Image.open(imgin), dtype=np.float32) / 255.
                images.append(image)
        images = np.stack(images, axis=-1)

        # Load poses and bds.
        with utils.open_file(path.join(args.data_dir, "poses_bounds.npy"),
                             "rb") as fp:
            poses_arr = np.load(fp)
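        # Each row of poses_arr stores a flattened 3x5 matrix (a 3x4
        # camera-to-world pose plus a [height, width, focal] column) followed
        # by the near/far depth bounds: 17 values per image.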
        poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1, 2, 0])
        bds = poses_arr[:, -2:].transpose([1, 0])
        if poses.shape[-1] != images.shape[-1]:
            raise RuntimeError("Mismatch between imgs {} and poses {}".format(
                images.shape[-1], poses.shape[-1]))

        # Update poses according to downsampling.
        poses[:2, 4, :] = np.array(images.shape[:2]).reshape([2, 1])
        poses[2, 4, :] = poses[2, 4, :] * 1. / factor

        # Correct rotation matrix ordering and move variable dim to axis 0.
        poses = np.concatenate(
            [poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1)
        poses = np.moveaxis(poses, -1, 0).astype(np.float32)
        images = np.moveaxis(images, -1, 0)
        bds = np.moveaxis(bds, -1, 0).astype(np.float32)

        # Rescale according to a default bd factor.
        scale = 1. / (bds.min() * .75)
        poses[:, :3, 3] *= scale
        bds *= scale

        # Recenter poses.
        poses = self._recenter_poses(poses)

        # Generate a spiral/spherical ray path for rendering videos.
        if args.spherify:
            poses = self._generate_spherical_poses(poses, bds)
            self.spherify = True
        else:
            self.spherify = False
        if not args.spherify and self.split == "test":
            self._generate_spiral_poses(poses, bds)

        # Select the split.
        i_test = np.arange(images.shape[0])[::args.llffhold]
        i_train = np.array(
            [i for i in np.arange(int(images.shape[0])) if i not in i_test])
        if self.split == "train":
            indices = i_train
        else:
            indices = i_test
        images = images[indices]
        poses = poses[indices]

        self.images = images
        self.camtoworlds = poses[:, :3, :4]
        self.focal = poses[0, -1, -1]
        self.h, self.w = images.shape[1:3]
        self.resolution = self.h * self.w
        if args.render_path:
            self.n_examples = self.render_poses.shape[0]
        else:
            self.n_examples = images.shape[0]
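
A self-contained sketch of the reshape logic above, run on a synthetic stand-in for poses_bounds.npy (hypothetical data, same 17-values-per-image layout):

import numpy as np

n = 4  # hypothetical number of images
poses_arr = np.random.rand(n, 17).astype(np.float32)
poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1, 2, 0])
bds = poses_arr[:, -2:].transpose([1, 0])
print(poses.shape)  # (3, 5, 4): 3x5 matrices with the image index last
print(bds.shape)    # (2, 4): near/far bounds per image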
Example 4
def main(unused_argv):
    # Hide the GPUs and TPUs from TF so it does not reserve memory on them for
    # LPIPS computation or dataset loading.
    tf.config.experimental.set_visible_devices([], "GPU")
    tf.config.experimental.set_visible_devices([], "TPU")

    rng = random.PRNGKey(20200823)

    if FLAGS.config is not None:
        utils.update_flags(FLAGS)
    if FLAGS.train_dir is None:
        raise ValueError("train_dir must be set; it is currently None.")
    if FLAGS.data_dir is None:
        raise ValueError("data_dir must be set; it is currently None.")

    dataset = datasets.get_dataset("test", FLAGS)
    rng, key = random.split(rng)
    model, init_variables = models.get_model(key, dataset.peek(), FLAGS)
    optimizer = flax.optim.Adam(FLAGS.lr_init).create(init_variables)
    state = utils.TrainState(optimizer=optimizer)
    del optimizer, init_variables

    # Rendering is forced to be deterministic even if training was randomized, as
    # this eliminates "speckle" artifacts.
    def render_fn(variables, key_0, key_1, rays):
        return jax.lax.all_gather(model.apply(variables, key_0, key_1, rays,
                                              False),
                                  axis_name="batch")

    # pmap over only the data input.
    render_pfn = jax.pmap(
        render_fn,
        in_axes=(None, None, None, 0),
        donate_argnums=3,
        axis_name="batch",
    )

    # Compiling to the CPU because it's faster and more accurate.
    ssim_fn = jax.jit(functools.partial(utils.compute_ssim, max_val=1.),
                      backend="cpu")

    last_step = 0
    out_dir = path.join(FLAGS.train_dir,
                        "path_renders" if FLAGS.render_path else "test_preds")
    if not FLAGS.eval_once:
        summary_writer = tensorboard.SummaryWriter(
            path.join(FLAGS.train_dir, "eval"))
    while True:
        state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
        step = int(state.optimizer.state.step)
        if step <= last_step:
            continue
        if FLAGS.save_output and (not utils.isdir(out_dir)):
            utils.makedirs(out_dir)
        psnr_values = []
        ssim_values = []
        if not FLAGS.eval_once:
            showcase_index = np.random.randint(0, dataset.size)
        for idx in range(dataset.size):
            print(f"Evaluating {idx+1}/{dataset.size}")
            batch = next(dataset)
            (pred_color, pred_disp, pred_acc, pred_features,
             pred_specular) = utils.render_image(
                 functools.partial(render_pfn, state.optimizer.target),
                 batch["rays"],
                 rng,
                 FLAGS.dataset == "llff",
                 chunk=FLAGS.chunk)
            if jax.host_id() != 0:  # Only record via host 0.
                continue

            if not FLAGS.eval_once and idx == showcase_index:
                showcase_color = pred_color
                showcase_disp = pred_disp
                showcase_acc = pred_acc
                showcase_features = pred_features
                showcase_specular = pred_specular
                if not FLAGS.render_path:
                    showcase_gt = batch["pixels"]

            if not FLAGS.render_path:
                psnr = utils.compute_psnr(
                    ((pred_color - batch["pixels"])**2).mean())
                ssim = ssim_fn(pred_color, batch["pixels"])
                print(f"PSNR = {psnr:.4f}, SSIM = {ssim:.4f}")
                psnr_values.append(float(psnr))
                ssim_values.append(float(ssim))

            if FLAGS.save_output:
                utils.save_img(pred_color,
                               path.join(out_dir, "{:03d}.png".format(idx)))
                utils.save_img(
                    pred_disp[Ellipsis, 0],
                    path.join(out_dir, "disp_{:03d}.png".format(idx)))

        if (not FLAGS.eval_once) and (jax.host_id() == 0):
            summary_writer.image("pred_color", showcase_color, step)
            summary_writer.image("pred_disp", showcase_disp, step)
            summary_writer.image("pred_acc", showcase_acc, step)
            summary_writer.image("pred_features", showcase_features, step)
            summary_writer.image("pred_specular", showcase_specular, step)
            if not FLAGS.render_path:
                summary_writer.scalar("psnr", np.mean(np.array(psnr_values)),
                                      step)
                summary_writer.scalar("ssim", np.mean(np.array(ssim_values)),
                                      step)
                summary_writer.image("target", showcase_gt, step)

        if (FLAGS.save_output and (not FLAGS.render_path)
                and jax.host_id() == 0):
            with utils.open_file(path.join(out_dir, f"psnrs_{step}.txt"),
                                 "w") as f:
                f.write(" ".join([str(v) for v in psnr_values]))
            with utils.open_file(path.join(out_dir, f"ssims_{step}.txt"),
                                 "w") as f:
                f.write(" ".join([str(v) for v in ssim_values]))
            with utils.open_file(path.join(out_dir, "psnr.txt"), "w") as f:
                f.write("{}".format(np.mean(np.array(psnr_values))))
            with utils.open_file(path.join(out_dir, "ssim.txt"), "w") as f:
                f.write("{}".format(np.mean(np.array(ssim_values))))

        if FLAGS.eval_once:
            break
        if int(step) >= FLAGS.max_steps:
            break
        last_step = step
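
`utils.compute_psnr` is not shown in these examples. Given how it is called above (on the mean squared error of colors in [0, 1]), a sketch consistent with that usage:

import numpy as np

def compute_psnr(mse):
    # PSNR for signals in [0, 1]: -10 * log10(mse).
    return -10. * np.log10(mse)

print(compute_psnr(1e-3))  # 30.0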
Example 5
def main(unused_argv):
    # Hide the GPUs and TPUs from TF so it does not reserve memory on them for
    # dataset loading.
    tf.config.experimental.set_visible_devices([], "GPU")
    tf.config.experimental.set_visible_devices([], "TPU")

    rng = random.PRNGKey(20200823)

    if FLAGS.config is not None:
        utils.update_flags(FLAGS)
    if FLAGS.train_dir is None:
        raise ValueError("train_dir must be set; it is currently None.")
    if FLAGS.data_dir is None:
        raise ValueError("data_dir must be set; it is currently None.")

    # The viewdir MLP refinement code needs this, as it assumes that both
    # datasets are split into images, rather than an unordered collection of
    # rays.
    FLAGS.__dict__["batching"] = "single_image"

    train_dataset = datasets.get_dataset("train", FLAGS)
    test_dataset = datasets.get_dataset("test", FLAGS)
    rng, key = random.split(rng)
    model, init_variables = models.get_model(key, test_dataset.peek(), FLAGS)
    optimizer = flax.optim.Adam(FLAGS.lr_init).create(init_variables)
    state = utils.TrainState(optimizer=optimizer)
    del optimizer, init_variables

    # Initialize the parameters dictionaries for SNeRG.
    (render_params_init, culling_params_init, atlas_params_init,
     scene_params_init) = params.initialize_params(FLAGS)

    # Also initialize the JAX functions and tensorflow models needed to evaluate
    # image quality.
    quality_evaluator = eval_and_refine.ImageQualityEvaluator()

    last_step = 0
    out_dir = path.join(FLAGS.train_dir, "baked")
    out_render_dir = path.join(out_dir, "test_preds")
    if jax.host_id() == 0:
        utils.makedirs(out_dir)
        utils.makedirs(out_render_dir)

    # Make sure that all JAX hosts have reached this point before proceeding. We
    # need to make sure that out_dir and out_render_dir both exist.
    export.synchronize_jax_hosts()

    if not FLAGS.eval_once:
        summary_writer = tensorboard.SummaryWriter(
            path.join(FLAGS.train_dir, "bake"))

    while True:
        state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
        step = int(state.optimizer.state.step)
        if step <= last_step:
            continue

        # We interleave explicit calls to garbage collection throughout this loop,
        # with the hope of alleviating out-of-memory errors on systems with limited
        # CPU RAM.
        gc.collect()

        # Extract the MLPs we need for baking a SNeRG.
        (mlp_model, mlp_params, viewdir_mlp_model,
         viewdir_mlp_params) = model_utils.extract_snerg_mlps(
             state.optimizer.target, scene_params_init)

        # Render out the low-res grid used for culling.
        culling_grid_coordinates = baking.build_3d_grid(
            scene_params_init["min_xyz"], culling_params_init["_voxel_size"],
            culling_params_init["_grid_size"],
            scene_params_init["worldspace_T_opengl"],
            np.dtype(scene_params_init["dtype"]))
        _, culling_grid_alpha = baking.render_voxel_block(
            mlp_model, mlp_params, culling_grid_coordinates,
            culling_params_init["_voxel_size"], scene_params_init)

        # Early out in case the culling grid is completely empty.
        if culling_grid_alpha.max() < culling_params_init["alpha_threshold"]:
            if FLAGS.eval_once:
                break
            else:
                continue

        # Using this grid, maximize resolution with a tight crop on the scene.
        (render_params, culling_params, atlas_params,
         scene_params) = culling.crop_alpha_grid(render_params_init,
                                                 culling_params_init,
                                                 atlas_params_init,
                                                 scene_params_init,
                                                 culling_grid_alpha)

        # Recompute the low-res grid using the cropped scene bounds.
        culling_grid_coordinates = baking.build_3d_grid(
            scene_params["min_xyz"], culling_params["_voxel_size"],
            culling_params["_grid_size"], scene_params["worldspace_T_opengl"],
            np.dtype(scene_params["dtype"]))
        _, culling_grid_alpha = baking.render_voxel_block(
            mlp_model, mlp_params, culling_grid_coordinates,
            culling_params["_voxel_size"], scene_params)

        # Determine which voxels are visible from the training views.
        num_training_cameras = train_dataset.camtoworlds.shape[0]
        culling_grid_visibility = np.zeros_like(culling_grid_alpha)
        for camera_index in range(
                0, num_training_cameras,
                culling_params["visibility_subsample_factor"]):
            culling.integrate_visibility_from_image(
                train_dataset.h * culling_params["visibility_image_factor"],
                train_dataset.w * culling_params["visibility_image_factor"],
                train_dataset.focal *
                culling_params["visibility_image_factor"],
                train_dataset.camtoworlds[camera_index], culling_grid_alpha,
                culling_grid_visibility, scene_params, culling_params)

        # Finally, using this updated low-res grid, compute the maximum alpha
        # within each macroblock.
        atlas_grid_alpha = culling.max_downsample_grid(culling_params,
                                                       atlas_params,
                                                       culling_grid_alpha)
        atlas_grid_visibility = culling.max_downsample_grid(
            culling_params, atlas_params, culling_grid_visibility)

        # Make the visibility grid more conservative by dilating it. We need to
        # temporarily cast to float32 here, as ndimage.maximum_filter doesn't
        # work with float16.
        atlas_grid_visibility = ndimage.maximum_filter(
            atlas_grid_visibility.astype(np.float32),
            culling_params["visibility_grid_dilation"]).astype(
                atlas_grid_visibility.dtype)

        # Now we're ready to extract the scene and pack it into a 3D texture atlas.
        atlas, atlas_block_indices = baking.extract_3d_atlas(
            mlp_model, mlp_params, scene_params, render_params, atlas_params,
            culling_params, atlas_grid_alpha, atlas_grid_visibility)

        # Free up CPU memory wherever we can to avoid OOM in the larger scenes.
        del atlas_grid_alpha
        del atlas_grid_visibility
        del culling_grid_alpha
        del culling_grid_visibility
        gc.collect()

        # Convert the atlas to a tensor, so we can use tensorflow's massive
        # CPU parallelism for ray marching.
        atlas_block_indices_t = tf.convert_to_tensor(atlas_block_indices)
        del atlas_block_indices
        gc.collect()

        atlas_t_list = []
        for i in range(atlas.shape[2]):
            atlas_t_list.append(tf.convert_to_tensor(atlas[:, :, i, :]))
        del atlas
        gc.collect()

        atlas_t = tf.stack(atlas_t_list, 2)
        del atlas_t_list
        gc.collect()

        # Quantize the atlas to 8-bit precision, as this is the precision we
        # will be working with for the exported PNGs.
        uint_multiplier = 2.0**8 - 1.0
        atlas_t *= uint_multiplier
        gc.collect()
        atlas_t = tf.floor(atlas_t)
        gc.collect()
        atlas_t = tf.maximum(0.0, atlas_t)
        gc.collect()
        atlas_t = tf.minimum(uint_multiplier, atlas_t)
        gc.collect()
        atlas_t /= uint_multiplier
        gc.collect()

        # Ray march through the baked SNeRG scene to create training data for
        # the view-dependence MLP.
        (train_rgbs, _, train_directions, train_refs
         ) = eval_and_refine.build_sharded_dataset_for_view_dependence(
             train_dataset, atlas_t, atlas_block_indices_t, atlas_params,
             scene_params, render_params)

        # Refine the view-dependence MLP to alleviate the domain gap between a
        # deferred NeRF scene and the baked SNeRG scene.
        refined_viewdir_mlp_params = eval_and_refine.refine_view_dependence_mlp(
            train_rgbs, train_directions, train_refs, viewdir_mlp_model,
            viewdir_mlp_params, scene_params)
        del train_rgbs
        del train_directions
        del train_refs
        gc.collect()

        # Now that we've refined the MLP, create test data with ray marching too.
        (test_rgbs, _, test_directions,
         _) = eval_and_refine.build_sharded_dataset_for_view_dependence(
             test_dataset, atlas_t, atlas_block_indices_t, atlas_params,
             scene_params, render_params)

        # Now run the view-dependence MLP on the ray-marched output images to
        # add back view-dependent effects. Note that we do this both before and
        # after refining the parameters.
        pre_refined_images = eval_and_refine.eval_dataset_and_unshard(
            viewdir_mlp_model, viewdir_mlp_params, test_rgbs, test_directions,
            test_dataset, scene_params)
        post_refined_images = eval_and_refine.eval_dataset_and_unshard(
            viewdir_mlp_model, refined_viewdir_mlp_params, test_rgbs,
            test_directions, test_dataset, scene_params)
        del test_rgbs
        del test_directions
        gc.collect()

        # Evaluate image quality metrics for the baked SNeRG scene, both before
        # and after refining the view-dependence MLP.
        pre_image_metrics = quality_evaluator.eval_image_list(
            pre_refined_images, test_dataset.images)
        post_image_metrics = quality_evaluator.eval_image_list(
            post_refined_images, test_dataset.images)
        pre_psnr, pre_ssim = pre_image_metrics[0], pre_image_metrics[1]
        post_psnr, post_ssim = post_image_metrics[0], post_image_metrics[1]
        gc.collect()

        # Export the baked scene so we can view it in the web-viewer.
        export.export_snerg_scene(out_dir, atlas_t.numpy(),
                                  atlas_block_indices_t.numpy(),
                                  refined_viewdir_mlp_params, render_params,
                                  atlas_params, scene_params, test_dataset.h,
                                  test_dataset.w, test_dataset.focal)
        gc.collect()

        # Compute the size of the exported SNeRG scene.
        png_size_gb, byte_size_gb, float_size_gb = export.compute_scene_size(
            out_dir, atlas_block_indices_t.numpy(), atlas_params, scene_params)
        gc.collect()

        # Finally, export the rendered test set images and update tensorboard.

        # Parallelize the image export over JAX hosts to speed this up.
        renders_and_paths = []
        paths = []
        for i in range(test_dataset.camtoworlds.shape[0]):
            renders_and_paths.append((post_refined_images[i],
                                      path.join(out_render_dir,
                                                "{:03d}.png".format(i))))
        export.parallel_write_images(
            lambda render_and_path: utils.save_img(  # pylint: disable=g-long-lambda
                render_and_path[0], render_and_path[1]),
            renders_and_paths)

        if (not FLAGS.eval_once) and (jax.host_id() == 0):
            summary_writer.image("baked_raw_color", pre_refined_images[0],
                                 step)
            summary_writer.image("baked_refined_color", post_refined_images[0],
                                 step)
            summary_writer.image("baked_target", test_dataset.images[0], step)
            summary_writer.scalar("baked_raw_psnr", pre_psnr, step)
            summary_writer.scalar("baked_raw_ssim", pre_ssim, step)
            summary_writer.scalar("baked_refined_psnr", post_psnr, step)
            summary_writer.scalar("baked_refined_ssim", post_ssim, step)
            summary_writer.scalar("baked_size_png_gb", png_size_gb, step)
            summary_writer.scalar("baked_size_byte_gb", byte_size_gb, step)
            summary_writer.scalar("baked_size_float_gb", float_size_gb, step)

        if (FLAGS.save_output and (not FLAGS.render_path)
                and jax.host_id() == 0):
            with utils.open_file(path.join(out_dir, "psnr.txt"), "w") as f:
                f.write("{}".format(post_psnr))
            with utils.open_file(path.join(out_dir, "ssim.txt"), "w") as f:
                f.write("{}".format(post_ssim))
            with utils.open_file(path.join(out_dir, "png_gb.txt"), "w") as f:
                f.write("{}".format(png_size_gb))
            with utils.open_file(path.join(out_dir, "byte_gb.txt"), "w") as f:
                f.write("{}".format(byte_size_gb))
            with utils.open_file(path.join(out_dir, "float_gb.txt"), "w") as f:
                f.write("{}".format(float_size_gb))

        if FLAGS.eval_once:
            break

        if int(step) >= FLAGS.max_steps:
            break

        last_step = step
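
The floor/clamp sequence applied to atlas_t above snaps float values in [0, 1] onto the 255-level grid that 8-bit PNGs can represent. The same quantization as a standalone sketch, returning uint8 directly (as the export code in the next example does):

import numpy as np

def quantize_to_8bit(x):
    # Snap [0, 1] floats to the uint8 grid: scale, floor, clamp.
    levels = 2.0**8 - 1.0
    return np.minimum(
        levels, np.maximum(0.0, np.floor(x * levels))).astype(np.uint8)

print(quantize_to_8bit(np.array([0.0, 0.5, 1.0])))  # [  0 127 255]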
Example 6
def export_snerg_scene(output_directory, atlas, atlas_block_indices,
                       viewdir_mlp_params, render_params, atlas_params,
                       scene_params, input_height, input_width, input_focal):
    """Exports a scene to web-viewer format: a collection of PNGs and a JSON file.

  The scene gets exported to output_directory/png. Any previous results will
  be overwritten.

  Args:
    output_directory: The root directory where the scene gets written.
    atlas: The SNeRG scene packed as a texture atlas in a [S, S, N, C] numpy
      array, where the channels C contain both RGB and features.
    atlas_block_indices: The indirection grid of the SNeRG scene, represented as
      a numpy int32 array of size (bW, bH, bD, 3).
    viewdir_mlp_params: A dict containing the MLP parameters for the per-sample
      view-dependence MLP.
    render_params: A dict with parameters for high-res rendering.
    atlas_params: A dict with params for building the 3D texture atlas.
    scene_params: A dict for scene specific params (bbox, rotation, resolution).
    input_height: Height (pixels) of the NDC camera (i.e. the training cameras).
    input_width: Width (pixels) of the NDC camera (i.e. the training cameras).
    input_focal: Focal length (pixels) of the NDC camera (i.e. the
      training cameras).
  """
    # Slice the atlas into images.
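    # Groups of 4 consecutive atlas planes are stacked vertically into one
    # image, so a [S, S, N, C] atlas becomes N/4 PNGs of height 4*S.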
    rgbs = []
    alphas = []
    for i in range(0, atlas.shape[2], 4):
        rgb_stack = []
        alpha_stack = []
        for j in range(4):
            plane_index = i + j
            rgb_stack.append(
                atlas[:, :, plane_index, 0:3].transpose([1, 0, 2]))
            alpha_stack.append(
                atlas[:, :, plane_index,
                      scene_params['_channels']].transpose([1, 0]))
        rgbs.append(np.concatenate(rgb_stack, axis=0))
        alphas.append(np.concatenate(alpha_stack, axis=0))

    atlas_index_image = np.transpose(
        atlas_block_indices, [2, 1, 0, 3]).reshape(
            (-1, atlas_block_indices.shape[0], 3)).astype(np.uint8)

    # Build a dictionary of the scene parameters, so we can export it as a json.
    export_scene_params = {}
    export_scene_params['voxel_size'] = float(render_params['_voxel_size'])
    export_scene_params['block_size'] = atlas_params['_data_block_size']
    export_scene_params['grid_width'] = int(render_params['_grid_size'][0])
    export_scene_params['grid_height'] = int(render_params['_grid_size'][1])
    export_scene_params['grid_depth'] = int(render_params['_grid_size'][2])
    export_scene_params['atlas_width'] = atlas.shape[0]
    export_scene_params['atlas_height'] = atlas.shape[1]
    export_scene_params['atlas_depth'] = atlas.shape[2]
    export_scene_params['num_slices'] = len(rgbs)

    export_scene_params['min_x'] = float(scene_params['min_xyz'][0])
    export_scene_params['min_y'] = float(scene_params['min_xyz'][1])
    export_scene_params['min_z'] = float(scene_params['min_xyz'][2])

    export_scene_params['atlas_blocks_x'] = int(
        atlas.shape[0] / atlas_params['atlas_block_size'])
    export_scene_params['atlas_blocks_y'] = int(
        atlas.shape[1] / atlas_params['atlas_block_size'])
    export_scene_params['atlas_blocks_z'] = int(
        atlas.shape[2] / atlas_params['atlas_block_size'])

    export_scene_params['input_height'] = float(input_height)
    export_scene_params['input_width'] = float(input_width)
    export_scene_params['input_focal'] = float(input_focal)

    export_scene_params['worldspace_T_opengl'] = scene_params[
        'worldspace_T_opengl'].tolist()
    export_scene_params['ndc'] = scene_params['ndc']

    # Also include the network weights in this dictionary.
    export_scene_params['0_weights'] = viewdir_mlp_params['params']['Dense_0'][
        'kernel'].tolist()
    export_scene_params['1_weights'] = viewdir_mlp_params['params']['Dense_1'][
        'kernel'].tolist()
    export_scene_params['2_weights'] = viewdir_mlp_params['params']['Dense_3'][
        'kernel'].tolist()
    export_scene_params['0_bias'] = viewdir_mlp_params['params']['Dense_0'][
        'bias'].tolist()
    export_scene_params['1_bias'] = viewdir_mlp_params['params']['Dense_1'][
        'bias'].tolist()
    export_scene_params['2_bias'] = viewdir_mlp_params['params']['Dense_3'][
        'bias'].tolist()

    # To avoid partial overwrites, first dump the scene to a temporary directory.
    output_tmp_directory = output_directory + '/temp'

    if jax.host_id() == 0:
        # Delete the folder if it already exists.
        if utils.isdir(output_tmp_directory):
            tf.io.gfile.rmtree(output_tmp_directory)
        utils.makedirs(output_tmp_directory)

    # Now store the indirection grid.
    atlas_indices_path = '%s/atlas_indices.png' % output_tmp_directory
    if jax.host_id() == 0:
        save_8bit_png((atlas_index_image, atlas_indices_path))

    # Make sure that all JAX hosts have reached this point in the code before we
    # proceed. Things will get tricky if output_tmp_directory doesn't yet exist.
    synchronize_jax_hosts()

    # Save the alpha values and RGB colors as one set of PNG images.
    output_images = []
    output_paths = []
    for i, (rgb, alpha) in enumerate(zip(rgbs, alphas)):
        rgba = np.concatenate([rgb, np.expand_dims(alpha, -1)], axis=-1)
        uint_multiplier = 2.0**8 - 1.0
        rgba = np.minimum(uint_multiplier,
                          np.maximum(0.0, np.floor(uint_multiplier *
                                                   rgba))).astype(np.uint8)
        output_images.append(rgba)
        atlas_rgba_path = '%s/rgba_%03d.png' % (output_tmp_directory, i)
        output_paths.append(atlas_rgba_path)

    # Save the computed features as a separate collection of PNGs.
    uint_multiplier = 2.0**8 - 1.0
    for i in range(0, atlas.shape[2], 4):
        feature_stack = []
        for j in range(4):
            plane_index = i + j
            feature_slice = atlas[:, :, plane_index,
                                  3:-1].transpose([1, 0, 2])
            feature_slice = np.minimum(
                uint_multiplier,
                np.maximum(0.0, np.floor(uint_multiplier *
                                         feature_slice))).astype(np.uint8)
            feature_stack.append(feature_slice)
        output_images.append(np.concatenate(feature_stack, axis=0))

    for i in range(len(rgbs)):
        output_paths.append('%s/feature_%03d.png' % (output_tmp_directory, i))

    parallel_write_images(save_8bit_png, list(zip(output_images,
                                                  output_paths)))

    # Now export the scene parameters and the network weights as a JSON.
    export_scene_params['format'] = 'png'
    scene_params_path = '%s/scene_params.json' % output_tmp_directory
    if jax.host_id() == 0:
        with utils.open_file(scene_params_path, 'wb') as f:
            f.write(json.dumps(export_scene_params).encode('utf-8'))

    # Again, make sure that the JAX hosts are in sync. Don't delete
    # output_tmp_directory before all files have been written.
    synchronize_jax_hosts()

    # Finally move the scene to the appropriate output path.
    output_png_directory = output_directory + '/png'
    if jax.host_id() == 0:
        # Delete the folder if it already exists.
        if utils.isdir(output_png_directory):
            tf.io.gfile.rmtree(output_png_directory)
        tf.io.gfile.rename(output_tmp_directory, output_png_directory)
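
A shape-only sketch of the slicing layout above (hypothetical atlas dimensions), showing how a [S, S, N, C] atlas turns into N/4 vertically stacked RGB slices:

import numpy as np

s, n, c = 8, 16, 8  # hypothetical: 8 channels = RGB + 4 features + alpha
atlas = np.random.rand(s, s, n, c).astype(np.float32)

slices = []
for i in range(0, n, 4):
    planes = [atlas[:, :, i + j, 0:3].transpose([1, 0, 2]) for j in range(4)]
    slices.append(np.concatenate(planes, axis=0))
print(len(slices), slices[0].shape)  # 4 (32, 8, 3)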