def main(unused_argv):
  # Hide the GPUs and TPUs from TF so it does not reserve memory on them for
  # LPIPS computation or dataset loading.
  tf.config.experimental.set_visible_devices([], "GPU")
  tf.config.experimental.set_visible_devices([], "TPU")

  rng = random.PRNGKey(20200823)

  if FLAGS.config is not None:
    utils.update_flags(FLAGS)
  if FLAGS.train_dir is None:
    raise ValueError("train_dir must be set. None set now.")
  if FLAGS.data_dir is None:
    raise ValueError("data_dir must be set. None set now.")

  dataset = datasets.get_dataset("test", FLAGS)
  rng, key = random.split(rng)
  model, init_variables = models.get_model(key, dataset.peek(), FLAGS)
  optimizer = flax.optim.Adam(FLAGS.lr_init).create(init_variables)
  state = utils.TrainState(optimizer=optimizer)
  del optimizer, init_variables

  # Rendering is forced to be deterministic even if training was randomized, as
  # this eliminates "speckle" artifacts.
  def render_fn(variables, key_0, key_1, rays):
    return jax.lax.all_gather(
        model.apply(variables, key_0, key_1, rays, False), axis_name="batch")

  # pmap over only the data input.
  render_pfn = jax.pmap(
      render_fn,
      in_axes=(None, None, None, 0),
      donate_argnums=3,
      axis_name="batch",
  )

  # Compiling to the CPU because it's faster and more accurate.
  ssim_fn = jax.jit(
      functools.partial(utils.compute_ssim, max_val=1.), backend="cpu")

  last_step = 0
  out_dir = path.join(FLAGS.train_dir,
                      "path_renders" if FLAGS.render_path else "test_preds")
  if not FLAGS.eval_once:
    summary_writer = tensorboard.SummaryWriter(
        path.join(FLAGS.train_dir, "eval"))
  while True:
    state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
    step = int(state.optimizer.state.step)
    if step <= last_step:
      continue
    if FLAGS.save_output and (not utils.isdir(out_dir)):
      utils.makedirs(out_dir)
    psnr_values = []
    ssim_values = []
    if not FLAGS.eval_once:
      showcase_index = np.random.randint(0, dataset.size)
    for idx in range(dataset.size):
      print(f"Evaluating {idx+1}/{dataset.size}")
      batch = next(dataset)
      (pred_color, pred_disp, pred_acc, pred_features,
       pred_specular) = utils.render_image(
           functools.partial(render_pfn, state.optimizer.target),
           batch["rays"],
           rng,
           FLAGS.dataset == "llff",
           chunk=FLAGS.chunk)
      if jax.host_id() != 0:  # Only record via host 0.
        continue
      if not FLAGS.eval_once and idx == showcase_index:
        showcase_color = pred_color
        showcase_disp = pred_disp
        showcase_acc = pred_acc
        showcase_features = pred_features
        showcase_specular = pred_specular
        if not FLAGS.render_path:
          showcase_gt = batch["pixels"]
      if not FLAGS.render_path:
        psnr = utils.compute_psnr(((pred_color - batch["pixels"])**2).mean())
        ssim = ssim_fn(pred_color, batch["pixels"])
        print(f"PSNR = {psnr:.4f}, SSIM = {ssim:.4f}")
        psnr_values.append(float(psnr))
        ssim_values.append(float(ssim))
      if FLAGS.save_output:
        utils.save_img(pred_color,
                       path.join(out_dir, "{:03d}.png".format(idx)))
        utils.save_img(pred_disp[Ellipsis, 0],
                       path.join(out_dir, "disp_{:03d}.png".format(idx)))
    if (not FLAGS.eval_once) and (jax.host_id() == 0):
      summary_writer.image("pred_color", showcase_color, step)
      summary_writer.image("pred_disp", showcase_disp, step)
      summary_writer.image("pred_acc", showcase_acc, step)
      summary_writer.image("pred_features", showcase_features, step)
      summary_writer.image("pred_specular", showcase_specular, step)
      if not FLAGS.render_path:
        summary_writer.scalar("psnr", np.mean(np.array(psnr_values)), step)
        summary_writer.scalar("ssim", np.mean(np.array(ssim_values)), step)
        summary_writer.image("target", showcase_gt, step)
    if FLAGS.save_output and (not FLAGS.render_path) and (jax.host_id() == 0):
      with utils.open_file(path.join(out_dir, f"psnrs_{step}.txt"), "w") as f:
        f.write(" ".join([str(v) for v in psnr_values]))
      with utils.open_file(path.join(out_dir, f"ssims_{step}.txt"), "w") as f:
        f.write(" ".join([str(v) for v in ssim_values]))
      with utils.open_file(path.join(out_dir, "psnr.txt"), "w") as f:
        f.write("{}".format(np.mean(np.array(psnr_values))))
      with utils.open_file(path.join(out_dir, "ssim.txt"), "w") as f:
        f.write("{}".format(np.mean(np.array(ssim_values))))
    if FLAGS.eval_once:
      break
    if int(step) >= FLAGS.max_steps:
      break
    last_step = step
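# For reference: the eval loop above converts each test image's mean squared
# error into PSNR via utils.compute_psnr. A minimal standalone sketch of that
# conversion, assuming pixel values in [0, 1]; the helper name below is
# illustrative and not claimed to match utils.compute_psnr verbatim.
import jax.numpy as jnp

def mse_to_psnr(mse):
  # With a peak signal value of 1.0, PSNR = 10 * log10(1 / MSE), which
  # simplifies to -10 * log10(MSE).
  return -10.0 * jnp.log10(mse)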
def main(unused_argv):
  rng = random.PRNGKey(20200823)
  # Shift the numpy random seed by host_id() to shuffle data loaded by
  # different hosts.
  np.random.seed(20201473 + jax.host_id())

  if FLAGS.config is not None:
    utils.update_flags(FLAGS)
  if FLAGS.batch_size % jax.device_count() != 0:
    raise ValueError("Batch size must be divisible by the number of devices.")
  if FLAGS.train_dir is None:
    raise ValueError("train_dir must be set. None set now.")
  if FLAGS.data_dir is None:
    raise ValueError("data_dir must be set. None set now.")

  dataset = datasets.get_dataset("train", FLAGS)
  test_dataset = datasets.get_dataset("test", FLAGS)

  rng, key = random.split(rng)
  model, variables = models.get_model(key, dataset.peek(), FLAGS)
  optimizer = flax.optim.Adam(FLAGS.lr_init).create(variables)
  state = utils.TrainState(optimizer=optimizer)
  del optimizer, variables

  learning_rate_fn = functools.partial(
      utils.learning_rate_decay,
      lr_init=FLAGS.lr_init,
      lr_final=FLAGS.lr_final,
      max_steps=FLAGS.max_steps,
      lr_delay_steps=FLAGS.lr_delay_steps,
      lr_delay_mult=FLAGS.lr_delay_mult)

  train_pstep = jax.pmap(
      functools.partial(train_step, model),
      axis_name="batch",
      in_axes=(0, 0, 0, None),
      donate_argnums=(2,))

  def render_fn(variables, key_0, key_1, rays):
    return jax.lax.all_gather(
        model.apply(variables, key_0, key_1, rays, FLAGS.randomized),
        axis_name="batch")

  render_pfn = jax.pmap(
      render_fn,
      in_axes=(None, None, None, 0),  # Only distribute the data input.
      donate_argnums=(3,),
      axis_name="batch",
  )

  # Compiling to the CPU because it's faster and more accurate.
  ssim_fn = jax.jit(
      functools.partial(utils.compute_ssim, max_val=1.), backend="cpu")

  if not utils.isdir(FLAGS.train_dir):
    utils.makedirs(FLAGS.train_dir)
  state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
  # Resume training at the step of the last checkpoint.
  init_step = state.optimizer.state.step + 1
  state = flax.jax_utils.replicate(state)

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(FLAGS.train_dir)

  # Prefetch_buffer_size = 3 x batch_size.
  pdataset = flax.jax_utils.prefetch_to_device(dataset, 3)
  n_local_devices = jax.local_device_count()
  rng = rng + jax.host_id()  # Make random seed separate across hosts.
  keys = random.split(rng, n_local_devices)  # For pmapping RNG keys.
  gc.disable()  # Disable automatic garbage collection for efficiency.
  stats_trace = []
  reset_timer = True
  for step, batch in zip(range(init_step, FLAGS.max_steps + 1), pdataset):
    if reset_timer:
      t_loop_start = time.time()
      reset_timer = False
    lr = learning_rate_fn(step)
    state, stats, keys = train_pstep(keys, state, batch, lr)
    if jax.host_id() == 0:
      stats_trace.append(stats)
    if step % FLAGS.gc_every == 0:
      gc.collect()

    # Log training summaries. This is put behind a host_id check because in
    # multi-host evaluation, all hosts need to run inference even though we
    # only use host 0 to record results.
    if jax.host_id() == 0:
      if step % FLAGS.print_every == 0:
        summary_writer.scalar("train_loss", stats.loss[0], step)
        summary_writer.scalar("train_psnr", stats.psnr[0], step)
        summary_writer.scalar("train_sparsity", stats.sparsity[0], step)
        summary_writer.scalar("train_loss_coarse", stats.loss_c[0], step)
        summary_writer.scalar("train_psnr_coarse", stats.psnr_c[0], step)
        summary_writer.scalar("train_sparsity_coarse", stats.sparsity_c[0],
                              step)
        summary_writer.scalar("weight_l2", stats.weight_l2[0], step)
        avg_loss = np.mean(np.concatenate([s.loss for s in stats_trace]))
        avg_psnr = np.mean(np.concatenate([s.psnr for s in stats_trace]))
        stats_trace = []
        summary_writer.scalar("train_avg_loss", avg_loss, step)
        summary_writer.scalar("train_avg_psnr", avg_psnr, step)
        summary_writer.scalar("learning_rate", lr, step)
        steps_per_sec = FLAGS.print_every / (time.time() - t_loop_start)
        reset_timer = True
        rays_per_sec = FLAGS.batch_size * steps_per_sec
        summary_writer.scalar("train_steps_per_sec", steps_per_sec, step)
        summary_writer.scalar("train_rays_per_sec", rays_per_sec, step)
        precision = int(np.ceil(np.log10(FLAGS.max_steps))) + 1
        print(("{:" + "{:d}".format(precision) + "d}").format(step) +
              f"/{FLAGS.max_steps:d}: " +
              f"i_loss={stats.loss[0]:0.4f}, " +
              f"avg_loss={avg_loss:0.4f}, " +
              f"weight_l2={stats.weight_l2[0]:0.2e}, " +
              f"lr={lr:0.2e}, " +
              f"{rays_per_sec:0.0f} rays/sec")
      if step % FLAGS.save_every == 0:
        state_to_save = jax.device_get(jax.tree_map(lambda x: x[0], state))
        checkpoints.save_checkpoint(
            FLAGS.train_dir, state_to_save, int(step), keep=100)

    # Test-set evaluation.
    if FLAGS.render_every > 0 and step % FLAGS.render_every == 0:
      # We reuse the same random number generator from the optimization step
      # here on purpose so that the visualization matches what happened in
      # training.
      t_eval_start = time.time()
      eval_variables = jax.device_get(
          jax.tree_map(lambda x: x[0], state)).optimizer.target
      test_case = next(test_dataset)
      (pred_color, pred_disp, pred_acc, pred_features,
       pred_specular) = utils.render_image(
           functools.partial(render_pfn, eval_variables),
           test_case["rays"],
           keys[0],
           FLAGS.dataset == "llff",
           chunk=FLAGS.chunk)

      # Log eval summaries on host 0.
      if jax.host_id() == 0:
        psnr = utils.compute_psnr(
            ((pred_color - test_case["pixels"])**2).mean())
        ssim = ssim_fn(pred_color, test_case["pixels"])
        eval_time = time.time() - t_eval_start
        num_rays = jnp.prod(
            jnp.array(test_case["rays"].directions.shape[:-1]))
        rays_per_sec = num_rays / eval_time
        summary_writer.scalar("test_rays_per_sec", rays_per_sec, step)
        print(f"Eval {step}: {eval_time:0.3f}s., {rays_per_sec:0.0f} rays/sec")
        summary_writer.scalar("test_psnr", psnr, step)
        summary_writer.scalar("test_ssim", ssim, step)
        summary_writer.image("test_pred_color", pred_color, step)
        summary_writer.image("test_pred_disp", pred_disp, step)
        summary_writer.image("test_pred_acc", pred_acc, step)
        summary_writer.image("test_pred_features", pred_features, step)
        summary_writer.image("test_pred_specular", pred_specular, step)
        summary_writer.image("test_target", test_case["pixels"], step)

  if FLAGS.max_steps % FLAGS.save_every != 0:
    state = jax.device_get(jax.tree_map(lambda x: x[0], state))
    checkpoints.save_checkpoint(
        FLAGS.train_dir, state, int(FLAGS.max_steps), keep=100)
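# For reference: learning_rate_fn above is built from the flags lr_init,
# lr_final, max_steps, lr_delay_steps, and lr_delay_mult. One plausible
# implementation consistent with those flags interpolates log-linearly from
# lr_init to lr_final over max_steps, scaled by a warmup ramp for the first
# lr_delay_steps. This is a sketch under those assumptions, not necessarily
# utils.learning_rate_decay verbatim.
import numpy as np

def learning_rate_decay_sketch(step, lr_init, lr_final, max_steps,
                               lr_delay_steps=0, lr_delay_mult=1.0):
  if lr_delay_steps > 0:
    # Ramp the multiplier from lr_delay_mult up to 1 over lr_delay_steps.
    delay_rate = lr_delay_mult + (1.0 - lr_delay_mult) * np.sin(
        0.5 * np.pi * np.clip(step / lr_delay_steps, 0.0, 1.0))
  else:
    delay_rate = 1.0
  t = np.clip(step / max_steps, 0.0, 1.0)
  # Interpolate in log space so the decay is exponential in step.
  log_lerp = np.exp(np.log(lr_init) * (1.0 - t) + np.log(lr_final) * t)
  return delay_rate * log_lerp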
def main(unused_argv):
  # Hide the GPUs and TPUs from TF so it does not reserve memory on them for
  # dataset loading.
  tf.config.experimental.set_visible_devices([], "GPU")
  tf.config.experimental.set_visible_devices([], "TPU")

  rng = random.PRNGKey(20200823)

  if FLAGS.config is not None:
    utils.update_flags(FLAGS)
  if FLAGS.train_dir is None:
    raise ValueError("train_dir must be set. None set now.")
  if FLAGS.data_dir is None:
    raise ValueError("data_dir must be set. None set now.")

  # The viewdir MLP refinement code needs this, as it assumes that both
  # datasets are split into images, rather than an unordered bunch of rays.
  FLAGS.__dict__["batching"] = "single_image"

  train_dataset = datasets.get_dataset("train", FLAGS)
  test_dataset = datasets.get_dataset("test", FLAGS)

  rng, key = random.split(rng)
  model, init_variables = models.get_model(key, test_dataset.peek(), FLAGS)
  optimizer = flax.optim.Adam(FLAGS.lr_init).create(init_variables)
  state = utils.TrainState(optimizer=optimizer)
  del optimizer, init_variables

  # Initialize the parameter dictionaries for SNeRG.
  (render_params_init, culling_params_init, atlas_params_init,
   scene_params_init) = params.initialize_params(FLAGS)

  # Also initialize the JAX functions and tensorflow models needed to evaluate
  # image quality.
  quality_evaluator = eval_and_refine.ImageQualityEvaluator()

  last_step = 0
  out_dir = path.join(FLAGS.train_dir, "baked")
  out_render_dir = path.join(out_dir, "test_preds")

  if jax.host_id() == 0:
    utils.makedirs(out_dir)
    utils.makedirs(out_render_dir)

  # Make sure that all JAX hosts have reached this point before proceeding. We
  # need to make sure that out_dir and out_render_dir both exist.
  export.synchronize_jax_hosts()

  if not FLAGS.eval_once:
    summary_writer = tensorboard.SummaryWriter(
        path.join(FLAGS.train_dir, "bake"))

  while True:
    state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
    step = int(state.optimizer.state.step)
    if step <= last_step:
      continue

    # We interleave explicit calls to garbage collection throughout this loop,
    # with the hope of alleviating out-of-memory errors on systems with limited
    # CPU RAM.
    gc.collect()

    # Extract the MLPs we need for baking a SNeRG.
    (mlp_model, mlp_params, viewdir_mlp_model,
     viewdir_mlp_params) = model_utils.extract_snerg_mlps(
         state.optimizer.target, scene_params_init)

    # Render out the low-res grid used for culling.
    culling_grid_coordinates = baking.build_3d_grid(
        scene_params_init["min_xyz"], culling_params_init["_voxel_size"],
        culling_params_init["_grid_size"],
        scene_params_init["worldspace_T_opengl"],
        np.dtype(scene_params_init["dtype"]))
    _, culling_grid_alpha = baking.render_voxel_block(
        mlp_model, mlp_params, culling_grid_coordinates,
        culling_params_init["_voxel_size"], scene_params_init)

    # Early out in case the culling grid is completely empty.
    if culling_grid_alpha.max() < culling_params_init["alpha_threshold"]:
      if FLAGS.eval_once:
        break
      else:
        continue

    # Using this grid, maximize resolution with a tight crop on the scene.
    (render_params, culling_params, atlas_params,
     scene_params) = culling.crop_alpha_grid(render_params_init,
                                             culling_params_init,
                                             atlas_params_init,
                                             scene_params_init,
                                             culling_grid_alpha)

    # Recompute the low-res grid using the cropped scene bounds.
    culling_grid_coordinates = baking.build_3d_grid(
        scene_params["min_xyz"], culling_params["_voxel_size"],
        culling_params["_grid_size"], scene_params["worldspace_T_opengl"],
        np.dtype(scene_params["dtype"]))
    _, culling_grid_alpha = baking.render_voxel_block(
        mlp_model, mlp_params, culling_grid_coordinates,
        culling_params["_voxel_size"], scene_params)

    # Determine which voxels are visible from the training views.
    num_training_cameras = train_dataset.camtoworlds.shape[0]
    culling_grid_visibility = np.zeros_like(culling_grid_alpha)
    for camera_index in range(0, num_training_cameras,
                              culling_params["visibility_subsample_factor"]):
      culling.integrate_visibility_from_image(
          train_dataset.h * culling_params["visibility_image_factor"],
          train_dataset.w * culling_params["visibility_image_factor"],
          train_dataset.focal * culling_params["visibility_image_factor"],
          train_dataset.camtoworlds[camera_index], culling_grid_alpha,
          culling_grid_visibility, scene_params, culling_params)

    # Finally, using this updated low-res grid, compute the maximum alpha
    # within each macroblock.
    atlas_grid_alpha = culling.max_downsample_grid(culling_params,
                                                   atlas_params,
                                                   culling_grid_alpha)
    atlas_grid_visibility = culling.max_downsample_grid(
        culling_params, atlas_params, culling_grid_visibility)

    # Make the visibility grid more conservative by dilating it. We need to
    # temporarily cast to float32 here, as ndimage.maximum_filter doesn't work
    # with float16.
    atlas_grid_visibility = ndimage.maximum_filter(
        atlas_grid_visibility.astype(np.float32),
        culling_params["visibility_grid_dilation"]).astype(
            atlas_grid_visibility.dtype)

    # Now we're ready to extract the scene and pack it into a 3D texture atlas.
    atlas, atlas_block_indices = baking.extract_3d_atlas(
        mlp_model, mlp_params, scene_params, render_params, atlas_params,
        culling_params, atlas_grid_alpha, atlas_grid_visibility)

    # Free up CPU memory wherever we can to avoid OOM in the larger scenes.
    del atlas_grid_alpha
    del atlas_grid_visibility
    del culling_grid_alpha
    del culling_grid_visibility
    gc.collect()

    # Convert the atlas to a tensor, so we can use tensorflow's massive CPU
    # parallelism for ray marching.
    atlas_block_indices_t = tf.convert_to_tensor(atlas_block_indices)
    del atlas_block_indices
    gc.collect()

    atlas_t_list = []
    for i in range(atlas.shape[2]):
      atlas_t_list.append(tf.convert_to_tensor(atlas[:, :, i, :]))
    del atlas
    gc.collect()

    atlas_t = tf.stack(atlas_t_list, 2)
    del atlas_t_list
    gc.collect()

    # Quantize the atlas to 8-bit precision, as this is the precision we will
    # be working with for the exported PNGs.
    uint_multiplier = 2.0**8 - 1.0
    atlas_t *= uint_multiplier
    gc.collect()

    atlas_t = tf.floor(atlas_t)
    gc.collect()

    atlas_t = tf.maximum(0.0, atlas_t)
    gc.collect()

    atlas_t = tf.minimum(uint_multiplier, atlas_t)
    gc.collect()

    atlas_t /= uint_multiplier
    gc.collect()

    # Ray march through the baked SNeRG scene to create training data for the
    # view-dependence MLP.
    (train_rgbs, _, train_directions,
     train_refs) = eval_and_refine.build_sharded_dataset_for_view_dependence(
         train_dataset, atlas_t, atlas_block_indices_t, atlas_params,
         scene_params, render_params)

    # Refine the view-dependence MLP to alleviate the domain gap between a
    # deferred NeRF scene and the baked SNeRG scene.
    refined_viewdir_mlp_params = eval_and_refine.refine_view_dependence_mlp(
        train_rgbs, train_directions, train_refs, viewdir_mlp_model,
        viewdir_mlp_params, scene_params)
    del train_rgbs
    del train_directions
    del train_refs
    gc.collect()

    # Now that we've refined the MLP, create test data with ray marching too.
    (test_rgbs, _, test_directions,
     _) = eval_and_refine.build_sharded_dataset_for_view_dependence(
         test_dataset, atlas_t, atlas_block_indices_t, atlas_params,
         scene_params, render_params)

    # Now run the view-dependence MLP on the ray marched output images to add
    # back view-dependent effects. Note that we do this both before and after
    # refining the parameters.
    pre_refined_images = eval_and_refine.eval_dataset_and_unshard(
        viewdir_mlp_model, viewdir_mlp_params, test_rgbs, test_directions,
        test_dataset, scene_params)
    post_refined_images = eval_and_refine.eval_dataset_and_unshard(
        viewdir_mlp_model, refined_viewdir_mlp_params, test_rgbs,
        test_directions, test_dataset, scene_params)
    del test_rgbs
    del test_directions
    gc.collect()

    # Evaluate image quality metrics for the baked SNeRG scene, both before
    # and after refining the view-dependence MLP.
    pre_image_metrics = quality_evaluator.eval_image_list(
        pre_refined_images, test_dataset.images)
    post_image_metrics = quality_evaluator.eval_image_list(
        post_refined_images, test_dataset.images)
    pre_psnr, pre_ssim = pre_image_metrics[0], pre_image_metrics[1]
    post_psnr, post_ssim = post_image_metrics[0], post_image_metrics[1]
    gc.collect()

    # Export the baked scene so we can view it in the web viewer.
    export.export_snerg_scene(out_dir, atlas_t.numpy(),
                              atlas_block_indices_t.numpy(),
                              refined_viewdir_mlp_params, render_params,
                              atlas_params, scene_params, test_dataset.h,
                              test_dataset.w, test_dataset.focal)
    gc.collect()

    # Compute the size of the exported SNeRG scene.
    png_size_gb, byte_size_gb, float_size_gb = export.compute_scene_size(
        out_dir, atlas_block_indices_t.numpy(), atlas_params, scene_params)
    gc.collect()

    # Finally, export the rendered test set images and update tensorboard.
    # Parallelize the image export over JAX hosts to speed this up.
    renders_and_paths = []
    for i in range(test_dataset.camtoworlds.shape[0]):
      renders_and_paths.append((post_refined_images[i],
                                path.join(out_render_dir,
                                          "{:03d}.png".format(i))))
    export.parallel_write_images(
        lambda render_and_path: utils.save_img(  # pylint: disable=g-long-lambda
            render_and_path[0], render_and_path[1]),
        renders_and_paths)

    if (not FLAGS.eval_once) and (jax.host_id() == 0):
      summary_writer.image("baked_raw_color", pre_refined_images[0], step)
      summary_writer.image("baked_refined_color", post_refined_images[0],
                           step)
      summary_writer.image("baked_target", test_dataset.images[0], step)
      summary_writer.scalar("baked_raw_psnr", pre_psnr, step)
      summary_writer.scalar("baked_raw_ssim", pre_ssim, step)
      summary_writer.scalar("baked_refined_psnr", post_psnr, step)
      summary_writer.scalar("baked_refined_ssim", post_ssim, step)
      summary_writer.scalar("baked_size_png_gb", png_size_gb, step)
      summary_writer.scalar("baked_size_byte_gb", byte_size_gb, step)
      summary_writer.scalar("baked_size_float_gb", float_size_gb, step)

    if FLAGS.save_output and (not FLAGS.render_path) and (jax.host_id() == 0):
      with utils.open_file(path.join(out_dir, "psnr.txt"), "w") as f:
        f.write("{}".format(post_psnr))
      with utils.open_file(path.join(out_dir, "ssim.txt"), "w") as f:
        f.write("{}".format(post_ssim))
      with utils.open_file(path.join(out_dir, "png_gb.txt"), "w") as f:
        f.write("{}".format(png_size_gb))
      with utils.open_file(path.join(out_dir, "byte_gb.txt"), "w") as f:
        f.write("{}".format(byte_size_gb))
      with utils.open_file(path.join(out_dir, "float_gb.txt"), "w") as f:
        f.write("{}".format(float_size_gb))

    if FLAGS.eval_once:
      break
    if int(step) >= FLAGS.max_steps:
      break
    last_step = step
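# For reference: the baking loop above quantizes the atlas in place by scaling
# to [0, 255], flooring, clamping, and rescaling, which simulates the precision
# loss of the 8-bit PNGs that export writes. An equivalent standalone sketch of
# that round trip (illustrative, using NumPy instead of the staged TF ops):
import numpy as np

def simulate_uint8_roundtrip(x):
  # x: float array in [0, 1]. Returns x as it would look after being written
  # to an 8-bit PNG and read back, mirroring the tf.floor / tf.maximum /
  # tf.minimum sequence in the loop above.
  uint_multiplier = 2.0**8 - 1.0
  quantized = np.minimum(uint_multiplier,
                         np.maximum(0.0, np.floor(x * uint_multiplier)))
  return quantized / uint_multiplier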