def main(unused_argv):
  # Hide the GPUs and TPUs from TF so it does not reserve memory on them for
  # LPIPS computation or dataset loading.
  tf.config.experimental.set_visible_devices([], "GPU")
  tf.config.experimental.set_visible_devices([], "TPU")

  rng = random.PRNGKey(20200823)

  if FLAGS.config is not None:
    utils.update_flags(FLAGS)
  if FLAGS.train_dir is None:
    raise ValueError("train_dir must be set. None set now.")
  if FLAGS.data_dir is None:
    raise ValueError("data_dir must be set. None set now.")

  dataset = datasets.get_dataset("test", FLAGS)
  rng, key = random.split(rng)
  model, init_variables = models.get_model(key, dataset.peek(), FLAGS)
  optimizer = flax.optim.Adam(FLAGS.lr_init).create(init_variables)
  state = utils.TrainState(optimizer=optimizer)
  del optimizer, init_variables

  # Rendering is forced to be deterministic even if training was randomized, as
  # this eliminates "speckle" artifacts.
  def render_fn(variables, key_0, key_1, rays):
    return jax.lax.all_gather(
        model.apply(variables, key_0, key_1, rays, False), axis_name="batch")

  # pmap over only the data input.
  render_pfn = jax.pmap(
      render_fn,
      in_axes=(None, None, None, 0),
      donate_argnums=3,
      axis_name="batch",
  )

  # Compiling to the CPU because it's faster and more accurate.
  ssim_fn = jax.jit(
      functools.partial(utils.compute_ssim, max_val=1.), backend="cpu")

  last_step = 0
  out_dir = path.join(FLAGS.train_dir,
                      "path_renders" if FLAGS.render_path else "test_preds")
  if not FLAGS.eval_once:
    summary_writer = tensorboard.SummaryWriter(
        path.join(FLAGS.train_dir, "eval"))
  while True:
    state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
    step = int(state.optimizer.state.step)
    if step <= last_step:
      continue
    if FLAGS.save_output and (not utils.isdir(out_dir)):
      utils.makedirs(out_dir)
    psnr_values = []
    ssim_values = []
    if not FLAGS.eval_once:
      showcase_index = np.random.randint(0, dataset.size)
    for idx in range(dataset.size):
      print(f"Evaluating {idx+1}/{dataset.size}")
      batch = next(dataset)
      (pred_color, pred_disp, pred_acc, pred_features,
       pred_specular) = utils.render_image(
           functools.partial(render_pfn, state.optimizer.target),
           batch["rays"],
           rng,
           FLAGS.dataset == "llff",
           chunk=FLAGS.chunk)
      if jax.host_id() != 0:  # Only record via host 0.
        continue
      if not FLAGS.eval_once and idx == showcase_index:
        showcase_color = pred_color
        showcase_disp = pred_disp
        showcase_acc = pred_acc
        showcase_features = pred_features
        showcase_specular = pred_specular
        if not FLAGS.render_path:
          showcase_gt = batch["pixels"]
      if not FLAGS.render_path:
        psnr = utils.compute_psnr(((pred_color - batch["pixels"])**2).mean())
        ssim = ssim_fn(pred_color, batch["pixels"])
        print(f"PSNR = {psnr:.4f}, SSIM = {ssim:.4f}")
        psnr_values.append(float(psnr))
        ssim_values.append(float(ssim))
      if FLAGS.save_output:
        utils.save_img(pred_color,
                       path.join(out_dir, "{:03d}.png".format(idx)))
        utils.save_img(pred_disp[Ellipsis, 0],
                       path.join(out_dir, "disp_{:03d}.png".format(idx)))
    if (not FLAGS.eval_once) and (jax.host_id() == 0):
      summary_writer.image("pred_color", showcase_color, step)
      summary_writer.image("pred_disp", showcase_disp, step)
      summary_writer.image("pred_acc", showcase_acc, step)
      summary_writer.image("pred_features", showcase_features, step)
      summary_writer.image("pred_specular", showcase_specular, step)
      if not FLAGS.render_path:
        summary_writer.scalar("psnr", np.mean(np.array(psnr_values)), step)
        summary_writer.scalar("ssim", np.mean(np.array(ssim_values)), step)
        summary_writer.image("target", showcase_gt, step)
    if FLAGS.save_output and (not FLAGS.render_path) and (jax.host_id() == 0):
      with utils.open_file(path.join(out_dir, f"psnrs_{step}.txt"), "w") as f:
        f.write(" ".join([str(v) for v in psnr_values]))
      with utils.open_file(path.join(out_dir, f"ssims_{step}.txt"), "w") as f:
        f.write(" ".join([str(v) for v in ssim_values]))
      with utils.open_file(path.join(out_dir, "psnr.txt"), "w") as f:
        f.write("{}".format(np.mean(np.array(psnr_values))))
      with utils.open_file(path.join(out_dir, "ssim.txt"), "w") as f:
        f.write("{}".format(np.mean(np.array(ssim_values))))
    if FLAGS.eval_once:
      break
    if int(step) >= FLAGS.max_steps:
      break
    last_step = step
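

# A minimal, self-contained sketch (not from the original script) of the
# pmap + all_gather pattern that `render_fn`/`render_pfn` above rely on:
# the parameters are broadcast to every device (in_axes=None) while the ray
# batch is split along its leading axis (in_axes=0), and `jax.lax.all_gather`
# makes every device's partial result visible on every device, so the host
# can read the complete output from any single device. The toy function and
# shapes below are assumptions for illustration only.
def _example_pmap_all_gather():
  import jax
  import jax.numpy as jnp

  n_devices = jax.local_device_count()
  params = jnp.ones((4,))  # Broadcast unchanged to all devices.
  rays = jnp.arange(n_devices * 8, dtype=jnp.float32).reshape((n_devices, 8))

  def toy_render(params, rays):
    per_device = rays * params.sum()  # Each device processes its own shard.
    return jax.lax.all_gather(per_device, axis_name="batch")

  pfn = jax.pmap(toy_render, in_axes=(None, 0), axis_name="batch")
  out = pfn(params, rays)  # Shape: [n_devices, n_devices, 8].
  return out[0]  # Any single device already holds the gathered result.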
def main(unused_argv):
  rng = random.PRNGKey(20200823)
  # Shift the numpy random seed by host_id() to shuffle data loaded by
  # different hosts.
  np.random.seed(20201473 + jax.host_id())

  if FLAGS.config is not None:
    utils.update_flags(FLAGS)
  if FLAGS.batch_size % jax.device_count() != 0:
    raise ValueError("Batch size must be divisible by the number of devices.")
  if FLAGS.train_dir is None:
    raise ValueError("train_dir must be set. None set now.")
  if FLAGS.data_dir is None:
    raise ValueError("data_dir must be set. None set now.")

  dataset = datasets.get_dataset("train", FLAGS)
  test_dataset = datasets.get_dataset("test", FLAGS)

  rng, key = random.split(rng)
  model, variables = models.get_model(key, dataset.peek(), FLAGS)
  optimizer = flax.optim.Adam(FLAGS.lr_init).create(variables)
  state = utils.TrainState(optimizer=optimizer)
  del optimizer, variables

  learning_rate_fn = functools.partial(
      utils.learning_rate_decay,
      lr_init=FLAGS.lr_init,
      lr_final=FLAGS.lr_final,
      max_steps=FLAGS.max_steps,
      lr_delay_steps=FLAGS.lr_delay_steps,
      lr_delay_mult=FLAGS.lr_delay_mult)

  train_pstep = jax.pmap(
      functools.partial(train_step, model),
      axis_name="batch",
      in_axes=(0, 0, 0, None),
      donate_argnums=(2,))

  def render_fn(variables, key_0, key_1, rays):
    return jax.lax.all_gather(
        model.apply(variables, key_0, key_1, rays, FLAGS.randomized),
        axis_name="batch")

  render_pfn = jax.pmap(
      render_fn,
      in_axes=(None, None, None, 0),  # Only distribute the data input.
      donate_argnums=(3,),
      axis_name="batch",
  )

  # Compiling to the CPU because it's faster and more accurate.
  ssim_fn = jax.jit(
      functools.partial(utils.compute_ssim, max_val=1.), backend="cpu")

  if not utils.isdir(FLAGS.train_dir):
    utils.makedirs(FLAGS.train_dir)
  state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
  # Resume training at the step of the last checkpoint.
  init_step = state.optimizer.state.step + 1
  state = flax.jax_utils.replicate(state)

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(FLAGS.train_dir)

  # Prefetch buffer size = 3 x batch_size.
  pdataset = flax.jax_utils.prefetch_to_device(dataset, 3)
  n_local_devices = jax.local_device_count()
  rng = rng + jax.host_id()  # Make random seed separate across hosts.
  keys = random.split(rng, n_local_devices)  # For pmapping RNG keys.
  gc.disable()  # Disable automatic garbage collection for efficiency.
  stats_trace = []
  reset_timer = True
  for step, batch in zip(range(init_step, FLAGS.max_steps + 1), pdataset):
    if reset_timer:
      t_loop_start = time.time()
      reset_timer = False
    lr = learning_rate_fn(step)
    state, stats, keys = train_pstep(keys, state, batch, lr)
    if jax.host_id() == 0:
      stats_trace.append(stats)
    if step % FLAGS.gc_every == 0:
      gc.collect()

    # Log training summaries. This is put behind a host_id check because in
    # multi-host evaluation, all hosts need to run inference even though we
    # only use host 0 to record results.
    if jax.host_id() == 0:
      if step % FLAGS.print_every == 0:
        summary_writer.scalar("train_loss", stats.loss[0], step)
        summary_writer.scalar("train_psnr", stats.psnr[0], step)
        summary_writer.scalar("train_sparsity", stats.sparsity[0], step)
        summary_writer.scalar("train_loss_coarse", stats.loss_c[0], step)
        summary_writer.scalar("train_psnr_coarse", stats.psnr_c[0], step)
        summary_writer.scalar("train_sparsity_coarse", stats.sparsity_c[0],
                              step)
        summary_writer.scalar("weight_l2", stats.weight_l2[0], step)
        avg_loss = np.mean(np.concatenate([s.loss for s in stats_trace]))
        avg_psnr = np.mean(np.concatenate([s.psnr for s in stats_trace]))
        stats_trace = []
        summary_writer.scalar("train_avg_loss", avg_loss, step)
        summary_writer.scalar("train_avg_psnr", avg_psnr, step)
        summary_writer.scalar("learning_rate", lr, step)
        steps_per_sec = FLAGS.print_every / (time.time() - t_loop_start)
        reset_timer = True
        rays_per_sec = FLAGS.batch_size * steps_per_sec
        summary_writer.scalar("train_steps_per_sec", steps_per_sec, step)
        summary_writer.scalar("train_rays_per_sec", rays_per_sec, step)
        precision = int(np.ceil(np.log10(FLAGS.max_steps))) + 1
        print(("{:" + "{:d}".format(precision) + "d}").format(step) +
              f"/{FLAGS.max_steps:d}: " + f"i_loss={stats.loss[0]:0.4f}, " +
              f"avg_loss={avg_loss:0.4f}, " +
              f"weight_l2={stats.weight_l2[0]:0.2e}, " + f"lr={lr:0.2e}, " +
              f"{rays_per_sec:0.0f} rays/sec")
      if step % FLAGS.save_every == 0:
        state_to_save = jax.device_get(jax.tree_map(lambda x: x[0], state))
        checkpoints.save_checkpoint(
            FLAGS.train_dir, state_to_save, int(step), keep=100)

    # Test-set evaluation.
    if FLAGS.render_every > 0 and step % FLAGS.render_every == 0:
      # We reuse the same random number generator from the optimization step
      # here on purpose so that the visualization matches what happened in
      # training.
      t_eval_start = time.time()
      eval_variables = jax.device_get(jax.tree_map(lambda x: x[0],
                                                   state)).optimizer.target
      test_case = next(test_dataset)
      (pred_color, pred_disp, pred_acc, pred_features,
       pred_specular) = utils.render_image(
           functools.partial(render_pfn, eval_variables),
           test_case["rays"],
           keys[0],
           FLAGS.dataset == "llff",
           chunk=FLAGS.chunk)

      # Log eval summaries on host 0.
      if jax.host_id() == 0:
        psnr = utils.compute_psnr(
            ((pred_color - test_case["pixels"])**2).mean())
        ssim = ssim_fn(pred_color, test_case["pixels"])
        eval_time = time.time() - t_eval_start
        num_rays = jnp.prod(
            jnp.array(test_case["rays"].directions.shape[:-1]))
        rays_per_sec = num_rays / eval_time
        summary_writer.scalar("test_rays_per_sec", rays_per_sec, step)
        print(f"Eval {step}: {eval_time:0.3f}s., {rays_per_sec:0.0f} rays/sec")
        summary_writer.scalar("test_psnr", psnr, step)
        summary_writer.scalar("test_ssim", ssim, step)
        summary_writer.image("test_pred_color", pred_color, step)
        summary_writer.image("test_pred_disp", pred_disp, step)
        summary_writer.image("test_pred_acc", pred_acc, step)
        summary_writer.image("test_pred_features", pred_features, step)
        summary_writer.image("test_pred_specular", pred_specular, step)
        summary_writer.image("test_target", test_case["pixels"], step)

  if FLAGS.max_steps % FLAGS.save_every != 0:
    state = jax.device_get(jax.tree_map(lambda x: x[0], state))
    checkpoints.save_checkpoint(
        FLAGS.train_dir, state, int(FLAGS.max_steps), keep=100)
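

# A minimal sketch (an assumption, not the verbatim library code) of the
# log-linear decay schedule that `utils.learning_rate_decay` is configured
# with above: the rate is interpolated in log space from lr_init to lr_final
# over max_steps, optionally scaled down near step 0 by a warmup controlled
# by lr_delay_steps and lr_delay_mult. The actual implementation in `utils`
# may differ in detail.
def _example_learning_rate_decay(step, lr_init, lr_final, max_steps,
                                 lr_delay_steps=0, lr_delay_mult=1.0):
  import numpy as np
  if lr_delay_steps > 0:
    # Smoothly ramp the delay multiplier from lr_delay_mult up to 1.
    delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
        0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1))
  else:
    delay_rate = 1.0
  t = np.clip(step / max_steps, 0, 1)
  log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
  return delay_rate * log_lerp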
def export_snerg_scene(output_directory, atlas, atlas_block_indices,
                       viewdir_mlp_params, render_params, atlas_params,
                       scene_params, input_height, input_width, input_focal):
  """Exports a scene to web-viewer format: a collection of PNGs and a JSON file.

  The scene gets exported to output_directory/png. Any previous results will
  be overwritten.

  Args:
    output_directory: The root directory where the scene gets written.
    atlas: The SNeRG scene packed as a texture atlas in a [S, S, N, C] numpy
      array, where the channels C contain both RGB and features.
    atlas_block_indices: The indirection grid of the SNeRG scene, represented
      as a numpy int32 array of size (bW, bH, bD, 3).
    viewdir_mlp_params: A dict containing the MLP parameters for the per-sample
      view-dependence MLP.
    render_params: A dict with parameters for high-res rendering.
    atlas_params: A dict with params for building the 3D texture atlas.
    scene_params: A dict for scene specific params (bbox, rotation, resolution).
    input_height: Height (pixels) of the NDC camera (i.e. the training cameras).
    input_width: Width (pixels) of the NDC camera (i.e. the training cameras).
    input_focal: Focal length (pixels) of the NDC camera (i.e. the training
      cameras).
  """
  # Slice the atlas into images.
  rgbs = []
  alphas = []
  for i in range(0, atlas.shape[2], 4):
    rgb_stack = []
    alpha_stack = []
    for j in range(4):
      plane_index = i + j
      rgb_stack.append(
          atlas[:, :, plane_index, :][Ellipsis, 0:3].transpose([1, 0, 2]))
      alpha_stack.append(
          atlas[:, :, plane_index, :][Ellipsis,
                                      scene_params['_channels']].transpose(
                                          [1, 0]))
    rgbs.append(np.concatenate(rgb_stack, axis=0))
    alphas.append(np.concatenate(alpha_stack, axis=0))

  atlas_index_image = np.transpose(atlas_block_indices, [2, 1, 0, 3]).reshape(
      (-1, atlas_block_indices.shape[0], 3)).astype(np.uint8)

  # Build a dictionary of the scene parameters, so we can export it as a json.
  export_scene_params = {}
  export_scene_params['voxel_size'] = float(render_params['_voxel_size'])
  export_scene_params['block_size'] = atlas_params['_data_block_size']
  export_scene_params['grid_width'] = int(render_params['_grid_size'][0])
  export_scene_params['grid_height'] = int(render_params['_grid_size'][1])
  export_scene_params['grid_depth'] = int(render_params['_grid_size'][2])
  export_scene_params['atlas_width'] = atlas.shape[0]
  export_scene_params['atlas_height'] = atlas.shape[1]
  export_scene_params['atlas_depth'] = atlas.shape[2]
  export_scene_params['num_slices'] = len(rgbs)

  export_scene_params['min_x'] = float(scene_params['min_xyz'][0])
  export_scene_params['min_y'] = float(scene_params['min_xyz'][1])
  export_scene_params['min_z'] = float(scene_params['min_xyz'][2])

  export_scene_params['atlas_blocks_x'] = int(
      atlas.shape[0] / atlas_params['atlas_block_size'])
  export_scene_params['atlas_blocks_y'] = int(
      atlas.shape[1] / atlas_params['atlas_block_size'])
  export_scene_params['atlas_blocks_z'] = int(
      atlas.shape[2] / atlas_params['atlas_block_size'])

  export_scene_params['input_height'] = float(input_height)
  export_scene_params['input_width'] = float(input_width)
  export_scene_params['input_focal'] = float(input_focal)

  export_scene_params['worldspace_T_opengl'] = scene_params[
      'worldspace_T_opengl'].tolist()
  export_scene_params['ndc'] = scene_params['ndc']

  # Also include the network weights in this dictionary.
  export_scene_params['0_weights'] = viewdir_mlp_params['params']['Dense_0'][
      'kernel'].tolist()
  export_scene_params['1_weights'] = viewdir_mlp_params['params']['Dense_1'][
      'kernel'].tolist()
  export_scene_params['2_weights'] = viewdir_mlp_params['params']['Dense_3'][
      'kernel'].tolist()
  export_scene_params['0_bias'] = viewdir_mlp_params['params']['Dense_0'][
      'bias'].tolist()
  export_scene_params['1_bias'] = viewdir_mlp_params['params']['Dense_1'][
      'bias'].tolist()
  export_scene_params['2_bias'] = viewdir_mlp_params['params']['Dense_3'][
      'bias'].tolist()

  # To avoid partial overwrites, first dump the scene to a temporary directory.
  output_tmp_directory = output_directory + '/temp'

  if jax.host_id() == 0:
    # Delete the folder if it already exists.
    if utils.isdir(output_tmp_directory):
      tf.io.gfile.rmtree(output_tmp_directory)
    utils.makedirs(output_tmp_directory)

  # Now store the indirection grid.
  atlas_indices_path = '%s/atlas_indices.png' % output_tmp_directory
  if jax.host_id() == 0:
    save_8bit_png((atlas_index_image, atlas_indices_path))

  # Make sure that all JAX hosts have reached this point in the code before we
  # proceed. Things will get tricky if output_tmp_directory doesn't yet exist.
  synchronize_jax_hosts()

  # Save the alpha values and RGB colors as one set of PNG images.
  output_images = []
  output_paths = []
  for i, rgb_and_alpha in enumerate(zip(rgbs, alphas)):
    rgb, alpha = rgb_and_alpha
    rgba = np.concatenate([rgb, np.expand_dims(alpha, -1)], axis=-1)
    uint_multiplier = 2.0**8 - 1.0
    rgba = np.minimum(uint_multiplier,
                      np.maximum(0.0,
                                 np.floor(uint_multiplier * rgba))).astype(
                                     np.uint8)
    output_images.append(rgba)
    atlas_rgba_path = '%s/rgba_%03d.png' % (output_tmp_directory, i)
    output_paths.append(atlas_rgba_path)

  # Save the computed features as a separate collection of PNGs.
  uint_multiplier = 2.0**8 - 1.0
  for i in range(0, atlas.shape[2], 4):
    feature_stack = []
    for j in range(4):
      plane_index = i + j
      feature_slice = atlas[:, :, plane_index, :][Ellipsis,
                                                  3:-1].transpose([1, 0, 2])
      feature_slice = np.minimum(
          uint_multiplier,
          np.maximum(0.0, np.floor(uint_multiplier * feature_slice))).astype(
              np.uint8)
      feature_stack.append(feature_slice)
    output_images.append(np.concatenate(feature_stack, axis=0))

  for i in range(len(rgbs)):
    output_paths.append('%s/feature_%03d.png' % (output_tmp_directory, i))

  parallel_write_images(save_8bit_png, list(zip(output_images, output_paths)))

  # Now export the scene parameters and the network weights as a JSON.
  export_scene_params['format'] = 'png'
  scene_params_path = '%s/scene_params.json' % output_tmp_directory
  if jax.host_id() == 0:
    with utils.open_file(scene_params_path, 'wb') as f:
      f.write(json.dumps(export_scene_params).encode('utf-8'))

  # Again, make sure that the JAX hosts are in sync. Don't delete
  # output_tmp_directory before all files have been written.
  synchronize_jax_hosts()

  # Finally move the scene to the appropriate output path.
  output_png_directory = output_directory + '/png'
  if jax.host_id() == 0:
    # Delete the folder if it already exists.
    if utils.isdir(output_png_directory):
      tf.io.gfile.rmtree(output_png_directory)
    tf.io.gfile.rename(output_tmp_directory, output_png_directory)
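

# An illustrative helper (not part of the exporter above) that factors out the
# 8-bit quantization applied to each RGBA and feature slice: floats in [0, 1]
# are scaled by 255, floored, clamped, and cast to uint8. For example,
# quantizing np.array([0.0, 0.5, 1.0]) yields array([0, 127, 255], dtype=uint8).
def _example_quantize_to_uint8(x):
  import numpy as np
  uint_multiplier = 2.0**8 - 1.0  # 255.0
  return np.minimum(uint_multiplier,
                    np.maximum(0.0, np.floor(uint_multiplier * x))).astype(
                        np.uint8)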