def save_8bit_png(img_and_path):
  """Write an 8-bit image array to disk in PNG format.

  The single-tuple argument makes this function usable directly as the worker
  of a parallel map over (image, path) pairs.

  Args:
    img_and_path: A pair `(image, path)`. `image` is an 8-bit numpy array laid
      out as [height, width, channels]; `path` is the destination file path
      (string).
  """
  image_array, out_path = img_and_path
  with utils.open_file(out_path, 'wb') as out_file:
    png_image = Image.fromarray(image_array)
    png_image.save(out_file, 'PNG')
def _load_renderings(self, args):
  """Load the Blender-format images and camera poses for `self.split`.

  Reads `<data_dir>/transforms_<split>.json`, loads every referenced PNG,
  optionally halves its resolution, and sets `self.images`, `self.h`, `self.w`,
  `self.resolution`, `self.camtoworlds`, `self.focal` and `self.n_examples`.

  Args:
    args: Flags/config object. Fields read here: render_path, data_dir,
      factor, white_bkgd.

  Raises:
    ValueError: If `args.render_path` is set, or if `args.factor` is positive
      and not 2 (only full and half resolution are supported).
  """
  if args.render_path:
    raise ValueError(
        "render_path cannot be used for the blender dataset.")
  # Validate the downsampling factor once, before touching any image on disk.
  if args.factor > 0 and args.factor != 2:
    raise ValueError(
        "Blender dataset only supports factor=0 or 2, {} "
        "set.".format(args.factor))
  with utils.open_file(
      path.join(args.data_dir, "transforms_{}.json".format(self.split)),
      "r") as fp:
    meta = json.load(fp)
  images = []
  cams = []
  for frame in meta["frames"]:
    fname = path.join(args.data_dir, frame["file_path"] + ".png")
    with utils.open_file(fname, "rb") as imgin:
      # Pixel values are scaled from [0, 255] to [0, 1].
      image = np.array(Image.open(imgin), dtype=np.float32) / 255.
      if args.factor == 2:
        [halfres_h, halfres_w] = [hw // 2 for hw in image.shape[:2]]
        image = cv2.resize(
            image, (halfres_w, halfres_h), interpolation=cv2.INTER_AREA)
    cams.append(np.array(frame["transform_matrix"], dtype=np.float32))
    images.append(image)
  self.images = np.stack(images, axis=0)
  if args.white_bkgd:
    # The last channel is used as alpha (presumably RGBA PNGs -- confirm) to
    # composite the RGB channels over a white background.
    self.images = (
        self.images[Ellipsis, :3] * self.images[Ellipsis, -1:] +
        (1. - self.images[Ellipsis, -1:]))
  else:
    self.images = self.images[Ellipsis, :3]
  self.h, self.w = self.images.shape[1:3]
  self.resolution = self.h * self.w
  self.camtoworlds = np.stack(cams, axis=0)
  camera_angle_x = float(meta["camera_angle_x"])
  # Focal length in pixels derived from the horizontal field of view.
  self.focal = .5 * self.w / np.tan(.5 * camera_angle_x)
  self.n_examples = self.images.shape[0]
def _load_renderings(self, args):
  """Load the LLFF-format images and camera poses for `self.split`.

  Reads all images from `<data_dir>/images[_<factor>]` and the per-image pose
  and depth-bound data from `<data_dir>/poses_bounds.npy`. It then rescales
  and recenters the poses and optionally generates spherical/spiral render
  poses. Finally it keeps only the images and poses of the requested split:
  every `args.llffhold`-th image is a test view, all others are training
  views.

  Args:
    args: Flags/config object. Fields read here: factor, data_dir, spherify,
      llffhold, render_path.

  Raises:
    ValueError: If the image folder does not exist or contains no images.
    RuntimeError: If the number of images and poses differ.
  """
  # Load images.
  imgdir_suffix = ""
  if args.factor > 0:
    imgdir_suffix = "_{}".format(args.factor)
    factor = args.factor
  else:
    factor = 1
  imgdir = path.join(args.data_dir, "images" + imgdir_suffix)
  if not utils.file_exists(imgdir):
    raise ValueError("Image folder {} doesn't exist.".format(imgdir))
  imgfiles = [
      path.join(imgdir, f)
      for f in sorted(utils.listdir(imgdir))
      if f.endswith(("JPG", "jpg", "png"))
  ]
  if not imgfiles:
    raise ValueError("Image folder {} contains no images.".format(imgdir))
  images = []
  for imgfile in imgfiles:
    with utils.open_file(imgfile, "rb") as imgin:
      image = np.array(Image.open(imgin), dtype=np.float32) / 255.
      images.append(image)
  # Images are stacked along a trailing axis to match the pose layout below.
  images = np.stack(images, axis=-1)

  # Load poses and bds.
  with utils.open_file(path.join(args.data_dir, "poses_bounds.npy"),
                       "rb") as fp:
    poses_arr = np.load(fp)
  poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1, 2, 0])
  bds = poses_arr[:, -2:].transpose([1, 0])
  if poses.shape[-1] != images.shape[-1]:
    raise RuntimeError("Mismatch between imgs {} and poses {}".format(
        images.shape[-1], poses.shape[-1]))

  # Update poses according to downsampling: column 4 holds (height, width,
  # focal), so store the actual image size and scale the focal length.
  poses[:2, 4, :] = np.array(images.shape[:2]).reshape([2, 1])
  poses[2, 4, :] = poses[2, 4, :] * 1. / factor

  # Correct rotation matrix ordering and move variable dim to axis 0.
  poses = np.concatenate(
      [poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1)
  poses = np.moveaxis(poses, -1, 0).astype(np.float32)
  images = np.moveaxis(images, -1, 0)
  bds = np.moveaxis(bds, -1, 0).astype(np.float32)

  # Rescale according to a default bd factor.
  scale = 1. / (bds.min() * .75)
  poses[:, :3, 3] *= scale
  bds *= scale

  # Recenter poses.
  poses = self._recenter_poses(poses)

  # Generate a spiral/spherical ray path for rendering videos.
  if args.spherify:
    poses = self._generate_spherical_poses(poses, bds)
    self.spherify = True
  else:
    self.spherify = False
  if not args.spherify and self.split == "test":
    # Presumably populates self.render_poses -- confirm in the helper.
    self._generate_spiral_poses(poses, bds)

  # Select the split.
  num_images = images.shape[0]
  i_test = np.arange(num_images)[::args.llffhold]
  # All non-test views are training views. setdiff1d returns them in
  # ascending order as an integer array, even when the result is empty.
  i_train = np.setdiff1d(np.arange(num_images), i_test)
  if self.split == "train":
    indices = i_train
  else:
    indices = i_test
  images = images[indices]
  poses = poses[indices]

  self.images = images
  self.camtoworlds = poses[:, :3, :4]
  self.focal = poses[0, -1, -1]
  self.h, self.w = images.shape[1:3]
  self.resolution = self.h * self.w
  if args.render_path:
    self.n_examples = self.render_poses.shape[0]
  else:
    self.n_examples = images.shape[0]
def main(unused_argv):
  """Evaluates (or renders) checkpoints from FLAGS.train_dir on the test split.

  Repeatedly restores the latest checkpoint and renders every example of the
  test dataset. Unless FLAGS.render_path is set, it computes PSNR/SSIM
  against the ground-truth pixels. Per-image renders and metric files are
  written when FLAGS.save_output is set. Tensorboard summaries are written
  when FLAGS.eval_once is not set. Recording happens on JAX host 0 only.

  Args:
    unused_argv: Unused command line arguments.

  Raises:
    ValueError: If FLAGS.train_dir or FLAGS.data_dir is not set.
  """
  # Hide the GPUs and TPUs from TF so it does not reserve memory on them for
  # LPIPS computation or dataset loading.
  tf.config.experimental.set_visible_devices([], "GPU")
  tf.config.experimental.set_visible_devices([], "TPU")

  rng = random.PRNGKey(20200823)

  if FLAGS.config is not None:
    utils.update_flags(FLAGS)
  if FLAGS.train_dir is None:
    raise ValueError("train_dir must be set. None set now.")
  if FLAGS.data_dir is None:
    raise ValueError("data_dir must be set. None set now.")

  dataset = datasets.get_dataset("test", FLAGS)
  rng, key = random.split(rng)
  # The model is initialized only to build a state skeleton of the right
  # structure; its values are replaced by restore_checkpoint below.
  model, init_variables = models.get_model(key, dataset.peek(), FLAGS)
  optimizer = flax.optim.Adam(FLAGS.lr_init).create(init_variables)
  state = utils.TrainState(optimizer=optimizer)
  del optimizer, init_variables

  # Rendering is forced to be deterministic even if training was randomized, as
  # this eliminates "speckle" artifacts.
  def render_fn(variables, key_0, key_1, rays):
    # The trailing False is the randomized flag passed to model.apply.
    return jax.lax.all_gather(model.apply(variables, key_0, key_1, rays, False),
                              axis_name="batch")

  # pmap over only the data input.
  render_pfn = jax.pmap(
      render_fn,
      in_axes=(None, None, None, 0),
      donate_argnums=3,
      axis_name="batch",
  )

  # Compiling to the CPU because it's faster and more accurate.
  ssim_fn = jax.jit(functools.partial(utils.compute_ssim, max_val=1.), backend="cpu")

  last_step = 0
  out_dir = path.join(FLAGS.train_dir, "path_renders" if FLAGS.render_path else "test_preds")
  if not FLAGS.eval_once:
    summary_writer = tensorboard.SummaryWriter(
        path.join(FLAGS.train_dir, "eval"))
  while True:
    state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
    step = int(state.optimizer.state.step)
    # NOTE(review): this polls restore_checkpoint in a tight loop with no sleep
    # until a checkpoint newer than last_step appears. With eval_once and no
    # checkpoint beyond step 0, it spins until one is written.
    if step <= last_step:
      continue
    if FLAGS.save_output and (not utils.isdir(out_dir)):
      utils.makedirs(out_dir)
    psnr_values = []
    ssim_values = []
    if not FLAGS.eval_once:
      # One random test image per evaluation is sent to tensorboard.
      showcase_index = np.random.randint(0, dataset.size)
    for idx in range(dataset.size):
      print(f"Evaluating {idx+1}/{dataset.size}")
      batch = next(dataset)
      # Every host renders (the pmap needs all devices); only host 0 records.
      (pred_color, pred_disp, pred_acc, pred_features,
       pred_specular) = utils.render_image(
           functools.partial(render_pfn, state.optimizer.target),
           batch["rays"],
           rng,
           FLAGS.dataset == "llff",
           chunk=FLAGS.chunk)
      if jax.host_id() != 0:  # Only record via host 0.
        continue
      if not FLAGS.eval_once and idx == showcase_index:
        showcase_color = pred_color
        showcase_disp = pred_disp
        showcase_acc = pred_acc
        showcase_features = pred_features
        showcase_specular = pred_specular
        if not FLAGS.render_path:
          showcase_gt = batch["pixels"]
      if not FLAGS.render_path:
        psnr = utils.compute_psnr(
            ((pred_color - batch["pixels"])**2).mean())
        ssim = ssim_fn(pred_color, batch["pixels"])
        print(f"PSNR = {psnr:.4f}, SSIM = {ssim:.4f}")
        psnr_values.append(float(psnr))
        ssim_values.append(float(ssim))
      if FLAGS.save_output:
        utils.save_img(pred_color, path.join(out_dir, "{:03d}.png".format(idx)))
        utils.save_img(
            pred_disp[Ellipsis, 0],
            path.join(out_dir, "disp_{:03d}.png".format(idx)))
    if (not FLAGS.eval_once) and (jax.host_id() == 0):
      summary_writer.image("pred_color", showcase_color, step)
      summary_writer.image("pred_disp", showcase_disp, step)
      summary_writer.image("pred_acc", showcase_acc, step)
      summary_writer.image("pred_features", showcase_features, step)
      summary_writer.image("pred_specular", showcase_specular, step)
      if not FLAGS.render_path:
        summary_writer.scalar("psnr", np.mean(np.array(psnr_values)), step)
        summary_writer.scalar("ssim", np.mean(np.array(ssim_values)), step)
        summary_writer.image("target", showcase_gt, step)
    if FLAGS.save_output and (not FLAGS.render_path) and (jax.host_id() == 0):
      # Per-image metrics for this step, plus the mean metrics (overwritten
      # on every evaluated checkpoint).
      with utils.open_file(path.join(out_dir, f"psnrs_{step}.txt"), "w") as f:
        f.write(" ".join([str(v) for v in psnr_values]))
      with utils.open_file(path.join(out_dir, f"ssims_{step}.txt"), "w") as f:
        f.write(" ".join([str(v) for v in ssim_values]))
      with utils.open_file(path.join(out_dir, "psnr.txt"), "w") as f:
        f.write("{}".format(np.mean(np.array(psnr_values))))
      with utils.open_file(path.join(out_dir, "ssim.txt"), "w") as f:
        f.write("{}".format(np.mean(np.array(ssim_values))))
    if FLAGS.eval_once:
      break
    if int(step) >= FLAGS.max_steps:
      break
    last_step = step
def main(unused_argv):
  """Bakes trained checkpoints into SNeRG scenes, then evaluates and exports them.

  For every new checkpoint found in FLAGS.train_dir this:
  - extracts the MLPs and builds a low-res culling grid;
  - crops the scene to its occupied region and computes voxel visibility from
    the training cameras;
  - packs the scene into a 3D texture atlas quantized to 8 bits;
  - refines the view-dependence MLP on ray-marched training data;
  - evaluates PSNR/SSIM on the test set before and after refinement;
  - exports the scene with export.export_snerg_scene into
    <train_dir>/baked and writes the refined test renders to
    <train_dir>/baked/test_preds.

  Args:
    unused_argv: Unused command line arguments.

  Raises:
    ValueError: If FLAGS.train_dir or FLAGS.data_dir is not set.
  """
  # Hide the GPUs and TPUs from TF so it does not reserve memory on them for
  # dataset loading.
  tf.config.experimental.set_visible_devices([], "GPU")
  tf.config.experimental.set_visible_devices([], "TPU")

  rng = random.PRNGKey(20200823)

  if FLAGS.config is not None:
    utils.update_flags(FLAGS)
  if FLAGS.train_dir is None:
    raise ValueError("train_dir must be set. None set now.")
  if FLAGS.data_dir is None:
    raise ValueError("data_dir must be set. None set now.")

  # The viewdir MLP refinement code needs this, as it assumes that both datasets
  # are split into images, rather than a unordered bunch of rays.
  FLAGS.__dict__["batching"] = "single_image"

  train_dataset = datasets.get_dataset("train", FLAGS)
  test_dataset = datasets.get_dataset("test", FLAGS)
  rng, key = random.split(rng)
  model, init_variables = models.get_model(key, test_dataset.peek(), FLAGS)
  optimizer = flax.optim.Adam(FLAGS.lr_init).create(init_variables)
  state = utils.TrainState(optimizer=optimizer)
  del optimizer, init_variables

  # Initialize the parameters dictionaries for SNeRG.
  (render_params_init, culling_params_init, atlas_params_init,
   scene_params_init) = params.initialize_params(FLAGS)

  # Also initialize the JAX functions and tensorflow models needed to evaluate
  # image quality.
  quality_evaluator = eval_and_refine.ImageQualityEvaluator()

  last_step = 0
  out_dir = path.join(FLAGS.train_dir, "baked")
  out_render_dir = path.join(out_dir, "test_preds")
  if jax.host_id() == 0:
    utils.makedirs(out_dir)
    utils.makedirs(out_render_dir)

  # Make sure that all JAX hosts have reached this point before proceeding. We
  # need to make sure that out_dir and out_render_dir both exist.
  export.synchronize_jax_hosts()

  if not FLAGS.eval_once:
    summary_writer = tensorboard.SummaryWriter(
        path.join(FLAGS.train_dir, "bake"))

  while True:
    state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
    step = int(state.optimizer.state.step)
    # NOTE(review): polls without sleeping until a newer checkpoint appears.
    if step <= last_step:
      continue

    # We interleave explicit calls to garbage collection throughout this loop,
    # with the hope of alleviating out-of-memory errors on systems with limited
    # CPU RAM.
    gc.collect()

    # Extract the MLPs we need for baking a SNeRG.
    (mlp_model, mlp_params, viewdir_mlp_model,
     viewdir_mlp_params) = model_utils.extract_snerg_mlps(
         state.optimizer.target, scene_params_init)

    # Render out the low-res grid used for culling.
    culling_grid_coordinates = baking.build_3d_grid(
        scene_params_init["min_xyz"], culling_params_init["_voxel_size"],
        culling_params_init["_grid_size"],
        scene_params_init["worldspace_T_opengl"],
        np.dtype(scene_params_init["dtype"]))
    _, culling_grid_alpha = baking.render_voxel_block(
        mlp_model, mlp_params, culling_grid_coordinates,
        culling_params_init["_voxel_size"], scene_params_init)

    # Early out in case the culling grid is completely empty.
    if culling_grid_alpha.max() < culling_params_init["alpha_threshold"]:
      if FLAGS.eval_once or step >= FLAGS.max_steps:
        break
      # Remember this checkpoint so the loop waits for a newer one instead of
      # re-baking the same empty scene over and over.
      last_step = step
      continue

    # Using this grid, maximize resolution with a tight crop on the scene.
    (render_params, culling_params, atlas_params,
     scene_params) = culling.crop_alpha_grid(render_params_init,
                                             culling_params_init,
                                             atlas_params_init,
                                             scene_params_init,
                                             culling_grid_alpha)

    # Recompute the low-res grid using the cropped scene bounds.
    culling_grid_coordinates = baking.build_3d_grid(
        scene_params["min_xyz"], culling_params["_voxel_size"],
        culling_params["_grid_size"], scene_params["worldspace_T_opengl"],
        np.dtype(scene_params["dtype"]))
    _, culling_grid_alpha = baking.render_voxel_block(
        mlp_model, mlp_params, culling_grid_coordinates,
        culling_params["_voxel_size"], scene_params)

    # Determine which voxels are visible from the training views.
    num_training_cameras = train_dataset.camtoworlds.shape[0]
    culling_grid_visibility = np.zeros_like(culling_grid_alpha)
    for camera_index in range(
        0, num_training_cameras, culling_params["visibility_subsample_factor"]):
      # Presumably accumulates into culling_grid_visibility in place -- the
      # return value is not used.
      culling.integrate_visibility_from_image(
          train_dataset.h * culling_params["visibility_image_factor"],
          train_dataset.w * culling_params["visibility_image_factor"],
          train_dataset.focal * culling_params["visibility_image_factor"],
          train_dataset.camtoworlds[camera_index], culling_grid_alpha,
          culling_grid_visibility, scene_params, culling_params)

    # Finally, using this updated low-res grid, compute the maximum alpha
    # within each macroblock.
    atlas_grid_alpha = culling.max_downsample_grid(culling_params,
                                                   atlas_params,
                                                   culling_grid_alpha)
    atlas_grid_visibility = culling.max_downsample_grid(
        culling_params, atlas_params, culling_grid_visibility)

    # Make the visibility grid more conservative by dilating it. We need to
    # temporarily cast to float32 here, as ndimage.maximum_filter doesn't work
    # with float16.
    atlas_grid_visibility = ndimage.maximum_filter(
        atlas_grid_visibility.astype(np.float32),
        culling_params["visibility_grid_dilation"]).astype(
            atlas_grid_visibility.dtype)

    # Now we're ready to extract the scene and pack it into a 3D texture atlas.
    atlas, atlas_block_indices = baking.extract_3d_atlas(
        mlp_model, mlp_params, scene_params, render_params, atlas_params,
        culling_params, atlas_grid_alpha, atlas_grid_visibility)

    # Free up CPU memory wherever we can to avoid OOM in the larger scenes.
    del atlas_grid_alpha
    del atlas_grid_visibility
    del culling_grid_alpha
    del culling_grid_visibility
    gc.collect()

    # Convert the atlas to a tensor, so we can use tensorflow's massive CPU
    # parallelism for ray marching.
    atlas_block_indices_t = tf.convert_to_tensor(atlas_block_indices)
    del atlas_block_indices
    gc.collect()

    # Convert plane by plane along axis 2 so the numpy atlas can be freed
    # before the full tensor is assembled.
    atlas_t_list = []
    for i in range(atlas.shape[2]):
      atlas_t_list.append(tf.convert_to_tensor(atlas[:, :, i, :]))
    del atlas
    gc.collect()

    atlas_t = tf.stack(atlas_t_list, 2)
    del atlas_t_list
    gc.collect()

    # Quantize the atlas to 8-bit precision, as this is the precision we will be
    # working with for the exported PNGs.
    uint_multiplier = 2.0**8 - 1.0
    atlas_t *= uint_multiplier
    gc.collect()
    atlas_t = tf.floor(atlas_t)
    gc.collect()
    atlas_t = tf.maximum(0.0, atlas_t)
    gc.collect()
    atlas_t = tf.minimum(uint_multiplier, atlas_t)
    gc.collect()
    atlas_t /= uint_multiplier
    gc.collect()

    # Ray march through the baked SNeRG scene to create training data for the
    # view-dependence MLP.
    (train_rgbs, _, train_directions,
     train_refs) = eval_and_refine.build_sharded_dataset_for_view_dependence(
         train_dataset, atlas_t, atlas_block_indices_t, atlas_params,
         scene_params, render_params)

    # Refine the view-dependence MLP to alleviate the domain gap between a
    # deferred NeRF scene and the baked SNeRG scene.
    refined_viewdir_mlp_params = eval_and_refine.refine_view_dependence_mlp(
        train_rgbs, train_directions, train_refs, viewdir_mlp_model,
        viewdir_mlp_params, scene_params)
    del train_rgbs
    del train_directions
    del train_refs
    gc.collect()

    # Now that we've refined the MLP, create test data with ray marching too.
    (test_rgbs, _, test_directions,
     _) = eval_and_refine.build_sharded_dataset_for_view_dependence(
         test_dataset, atlas_t, atlas_block_indices_t, atlas_params,
         scene_params, render_params)

    # Now run the view-dependence on the ray marched output images to add
    # back view-dependent effects. Note that we do this both before and after
    # refining the parameters.
    pre_refined_images = eval_and_refine.eval_dataset_and_unshard(
        viewdir_mlp_model, viewdir_mlp_params, test_rgbs, test_directions,
        test_dataset, scene_params)
    post_refined_images = eval_and_refine.eval_dataset_and_unshard(
        viewdir_mlp_model, refined_viewdir_mlp_params, test_rgbs,
        test_directions, test_dataset, scene_params)
    del test_rgbs
    del test_directions
    gc.collect()

    # Evaluate image quality metrics for the baked SNeRG scene, both before and
    # after refining the view-dependence MLP.
    pre_image_metrics = quality_evaluator.eval_image_list(
        pre_refined_images, test_dataset.images)
    post_image_metrics = quality_evaluator.eval_image_list(
        post_refined_images, test_dataset.images)
    pre_psnr, pre_ssim = pre_image_metrics[0], pre_image_metrics[1]
    post_psnr, post_ssim = post_image_metrics[0], post_image_metrics[1]
    gc.collect()

    # Export the baked scene so we can view it in the web-viewer.
    export.export_snerg_scene(out_dir, atlas_t.numpy(),
                              atlas_block_indices_t.numpy(),
                              refined_viewdir_mlp_params, render_params,
                              atlas_params, scene_params, test_dataset.h,
                              test_dataset.w, test_dataset.focal)
    gc.collect()

    # Compute the size of the exported SNeRG scene.
    png_size_gb, byte_size_gb, float_size_gb = export.compute_scene_size(
        out_dir, atlas_block_indices_t.numpy(), atlas_params, scene_params)
    gc.collect()

    # Finally, export the rendered test set images and update tensorboard.

    # Parallelize the image export over JAX hosts to speed this up.
    renders_and_paths = [
        (post_refined_images[i],
         path.join(out_render_dir, "{:03d}.png".format(i)))
        for i in range(test_dataset.camtoworlds.shape[0])
    ]
    export.parallel_write_images(
        lambda render_and_path: utils.save_img(  # pylint: disable=g-long-lambda
            render_and_path[0], render_and_path[1]),
        renders_and_paths)

    if (not FLAGS.eval_once) and (jax.host_id() == 0):
      summary_writer.image("baked_raw_color", pre_refined_images[0], step)
      summary_writer.image("baked_refined_color", post_refined_images[0], step)
      summary_writer.image("baked_target", test_dataset.images[0], step)
      summary_writer.scalar("baked_raw_psnr", pre_psnr, step)
      summary_writer.scalar("baked_raw_ssim", pre_ssim, step)
      summary_writer.scalar("baked_refined_psnr", post_psnr, step)
      summary_writer.scalar("baked_refined_ssim", post_ssim, step)
      summary_writer.scalar("baked_size_png_gb", png_size_gb, step)
      summary_writer.scalar("baked_size_byte_gb", byte_size_gb, step)
      summary_writer.scalar("baked_size_float_gb", float_size_gb, step)

    if FLAGS.save_output and (not FLAGS.render_path) and (jax.host_id() == 0):
      with utils.open_file(path.join(out_dir, "psnr.txt"), "w") as f:
        f.write("{}".format(post_psnr))
      with utils.open_file(path.join(out_dir, "ssim.txt"), "w") as f:
        f.write("{}".format(post_ssim))
      with utils.open_file(path.join(out_dir, "png_gb.txt"), "w") as f:
        f.write("{}".format(png_size_gb))
      with utils.open_file(path.join(out_dir, "byte_gb.txt"), "w") as f:
        f.write("{}".format(byte_size_gb))
      with utils.open_file(path.join(out_dir, "float_gb.txt"), "w") as f:
        f.write("{}".format(float_size_gb))

    if FLAGS.eval_once:
      break
    if int(step) >= FLAGS.max_steps:
      break
    last_step = step
def export_snerg_scene(output_directory, atlas, atlas_block_indices,
                       viewdir_mlp_params, render_params, atlas_params,
                       scene_params, input_height, input_width, input_focal):
  """Exports a scene to web-viewer format: a collection of PNGs and a JSON file.

  The scene gets exported to output_directory/png. Any previous results will
  be overwritten. The scene is first written to output_directory/temp and then
  renamed, so a partially written scene never replaces a complete one. Every
  JAX host must call this function, because it contains cross-host barriers.
  When it returns on any host, output_directory/png is in place.

  Args:
    output_directory: The root directory where the scene gets written.
    atlas: The SNeRG scene packed as a texture atlas in a [S, S, N, C] numpy
      array, where the channels C contain both RGB and features.
    atlas_block_indices: The indirection grid of the SNeRG scene, represented
      as a numpy int32 array of size (bW, bH, bD, 3).
    viewdir_mlp_params: A dict containing the MLP parameters for the
      per-sample view-dependence MLP.
    render_params: A dict with parameters for high-res rendering.
    atlas_params: A dict with params for building the 3D texture atlas.
    scene_params: A dict for scene specific params (bbox, rotation,
      resolution).
    input_height: Height (pixels) of the NDC camera (i.e. the training
      cameras).
    input_width: Width (pixels) of the NDC camera (i.e. the training cameras).
    input_focal: Focal length (pixels) of the NDC camera (i.e. the training
      cameras).
  """
  # Values in [0, 1] are stored as 8-bit PNG channels.
  uint_multiplier = 2.0**8 - 1.0

  def _quantize_to_uint8(values):
    """Floors values * 255, clamps the result to [0, 255] and casts to uint8."""
    return np.minimum(
        uint_multiplier,
        np.maximum(0.0, np.floor(uint_multiplier * values))).astype(np.uint8)

  # Slice the atlas into images: each slice stacks 4 consecutive depth planes
  # vertically. Channels 0:3 are RGB and channel scene_params['_channels'] is
  # alpha.
  rgbs = []
  alphas = []
  for i in range(0, atlas.shape[2], 4):
    rgb_stack = []
    alpha_stack = []
    for j in range(4):
      plane_index = i + j
      rgb_stack.append(atlas[:, :, plane_index, :][Ellipsis,
                                                    0:3].transpose([1, 0, 2]))
      alpha_stack.append(
          atlas[:, :, plane_index, :][Ellipsis,
                                      scene_params['_channels']].transpose(
                                          [1, 0]))
    rgbs.append(np.concatenate(rgb_stack, axis=0))
    alphas.append(np.concatenate(alpha_stack, axis=0))

  atlas_index_image = np.transpose(atlas_block_indices, [2, 1, 0, 3]).reshape(
      (-1, atlas_block_indices.shape[0], 3)).astype(np.uint8)

  # Build a dictionary of the scene parameters, so we can export it as a json.
  export_scene_params = {}
  export_scene_params['voxel_size'] = float(render_params['_voxel_size'])
  export_scene_params['block_size'] = atlas_params['_data_block_size']
  export_scene_params['grid_width'] = int(render_params['_grid_size'][0])
  export_scene_params['grid_height'] = int(render_params['_grid_size'][1])
  export_scene_params['grid_depth'] = int(render_params['_grid_size'][2])
  export_scene_params['atlas_width'] = atlas.shape[0]
  export_scene_params['atlas_height'] = atlas.shape[1]
  export_scene_params['atlas_depth'] = atlas.shape[2]
  export_scene_params['num_slices'] = len(rgbs)

  export_scene_params['min_x'] = float(scene_params['min_xyz'][0])
  export_scene_params['min_y'] = float(scene_params['min_xyz'][1])
  export_scene_params['min_z'] = float(scene_params['min_xyz'][2])

  atlas_block_size = atlas_params['atlas_block_size']
  export_scene_params['atlas_blocks_x'] = int(atlas.shape[0] // atlas_block_size)
  export_scene_params['atlas_blocks_y'] = int(atlas.shape[1] // atlas_block_size)
  export_scene_params['atlas_blocks_z'] = int(atlas.shape[2] // atlas_block_size)

  export_scene_params['input_height'] = float(input_height)
  export_scene_params['input_width'] = float(input_width)
  export_scene_params['input_focal'] = float(input_focal)

  export_scene_params['worldspace_T_opengl'] = scene_params[
      'worldspace_T_opengl'].tolist()
  export_scene_params['ndc'] = scene_params['ndc']

  # Also include the network weights in this dictionary. The viewer's third
  # layer is taken from 'Dense_3' of the MLP params.
  export_scene_params['0_weights'] = viewdir_mlp_params['params']['Dense_0'][
      'kernel'].tolist()
  export_scene_params['1_weights'] = viewdir_mlp_params['params']['Dense_1'][
      'kernel'].tolist()
  export_scene_params['2_weights'] = viewdir_mlp_params['params']['Dense_3'][
      'kernel'].tolist()
  export_scene_params['0_bias'] = viewdir_mlp_params['params']['Dense_0'][
      'bias'].tolist()
  export_scene_params['1_bias'] = viewdir_mlp_params['params']['Dense_1'][
      'bias'].tolist()
  export_scene_params['2_bias'] = viewdir_mlp_params['params']['Dense_3'][
      'bias'].tolist()

  # To avoid partial overwrites, first dump the scene to a temporary directory.
  output_tmp_directory = output_directory + '/temp'

  if jax.host_id() == 0:
    # Delete the folder if it already exists.
    if utils.isdir(output_tmp_directory):
      tf.io.gfile.rmtree(output_tmp_directory)
    utils.makedirs(output_tmp_directory)

  # Now store the indirection grid.
  atlas_indices_path = '%s/atlas_indices.png' % output_tmp_directory
  if jax.host_id() == 0:
    save_8bit_png((atlas_index_image, atlas_indices_path))

  # Make sure that all JAX hosts have reached this point in the code before we
  # proceed. Things will get tricky if output_tmp_directory doesn't yet exist.
  synchronize_jax_hosts()

  # Save the alpha values and RGB colors as one set of PNG images.
  output_images = []
  output_paths = []
  for i, (rgb, alpha) in enumerate(zip(rgbs, alphas)):
    rgba = np.concatenate([rgb, np.expand_dims(alpha, -1)], axis=-1)
    output_images.append(_quantize_to_uint8(rgba))
    output_paths.append('%s/rgba_%03d.png' % (output_tmp_directory, i))

  # Save the computed features a separate collection of PNGs. Features are
  # the channels between RGB (0:3) and the trailing alpha channel.
  for i in range(0, atlas.shape[2], 4):
    feature_stack = []
    for j in range(4):
      plane_index = i + j
      feature_slice = atlas[:, :, plane_index, :][Ellipsis,
                                                  3:-1].transpose([1, 0, 2])
      feature_stack.append(_quantize_to_uint8(feature_slice))
    output_images.append(np.concatenate(feature_stack, axis=0))

  for i in range(len(rgbs)):
    output_paths.append('%s/feature_%03d.png' % (output_tmp_directory, i))

  parallel_write_images(save_8bit_png, list(zip(output_images, output_paths)))

  # Now export the scene parameters and the network weights as a JSON.
  export_scene_params['format'] = 'png'
  scene_params_path = '%s/scene_params.json' % output_tmp_directory
  if jax.host_id() == 0:
    with utils.open_file(scene_params_path, 'wb') as f:
      f.write(json.dumps(export_scene_params).encode('utf-8'))

  # Again, make sure that the JAX hosts are in sync. Don't delete
  # output_tmp_directory before all files have been written.
  synchronize_jax_hosts()

  # Finally move the scene to the appropriate output path.
  output_png_directory = output_directory + '/png'
  if jax.host_id() == 0:
    # Delete the folder if it already exists.
    if utils.isdir(output_png_directory):
      tf.io.gfile.rmtree(output_png_directory)
    tf.io.gfile.rename(output_tmp_directory, output_png_directory)

  # Only host 0 performs the rename. Wait for it so that no host returns (and
  # e.g. starts reading the exported files) before output_png_directory exists.
  synchronize_jax_hosts()