def loss_fn(variables):
  rays = batch["rays"]
  ret = model.apply(variables, key_0, key_1, rays, FLAGS.randomized)
  if len(ret) not in (1, 2):
    raise ValueError(
        "ret should contain either 1 set of output (coarse only), or 2 sets"
        " of output (coarse as ret[0] and fine as ret[1]).")
  # The main prediction is always at the end of the ret list.
  rgb, _, _, sigma, _, _ = ret[-1]
  loss = ((rgb - batch["pixels"][Ellipsis, :3])**2).mean()
  psnr = utils.compute_psnr(loss)
  if len(ret) > 1:
    # If there are both coarse and fine predictions, we compute the loss for
    # the coarse prediction (ret[0]) as well.
    rgb_c, _, _, sigma_c, _, _ = ret[0]
    loss_c = ((rgb_c - batch["pixels"][Ellipsis, :3])**2).mean()
    psnr_c = utils.compute_psnr(loss_c)
    sparsity_c = FLAGS.sparsity_strength * jax.numpy.log(
        1.0 + sigma_c**2 / 0.5).mean()
  else:
    loss_c = 0.
    psnr_c = 0.
    sparsity_c = 0.0

  def tree_sum_fn(fn):
    return jax.tree_util.tree_reduce(
        lambda x, y: x + fn(y), variables, initializer=0)

  # Mean squared parameter value, used for weight decay below.
  weight_l2 = (
      tree_sum_fn(lambda z: jnp.sum(z**2)) /
      tree_sum_fn(lambda z: jnp.prod(jnp.array(z.shape))))

  # Sparsity regularizer on the predicted density (a Cauchy-style penalty).
  sparsity = FLAGS.sparsity_strength * jax.numpy.log(
      1.0 + sigma**2 / 0.5).mean()

  stats = utils.Stats(
      loss=loss,
      psnr=psnr,
      loss_c=loss_c,
      psnr_c=psnr_c,
      weight_l2=weight_l2,
      sparsity=sparsity,
      sparsity_c=sparsity_c)
  return (loss + loss_c + FLAGS.weight_decay_mult * weight_l2 + sparsity +
          sparsity_c), stats
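
# For reference: a minimal sketch of the PSNR helper used above. The real
# implementation lives in utils.py; this assumes pixel values in [0, 1]
# (consistent with max_val=1. used for SSIM below), in which case PSNR is
# simply -10 * log10(MSE). The name is hypothetical.
def _compute_psnr_sketch(mse):
  return -10.0 * jnp.log(mse) / jnp.log(10.0)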
def main(unused_argv):
  rng = random.PRNGKey(20200823)
  # Shift the numpy random seed by host_id() to shuffle data loaded by
  # different hosts.
  np.random.seed(20201473 + jax.host_id())

  if FLAGS.config is not None:
    utils.update_flags(FLAGS)
  if FLAGS.batch_size % jax.device_count() != 0:
    raise ValueError("Batch size must be divisible by the number of devices.")
  if FLAGS.train_dir is None:
    raise ValueError("train_dir must be set. None set now.")
  if FLAGS.data_dir is None:
    raise ValueError("data_dir must be set. None set now.")

  dataset = datasets.get_dataset("train", FLAGS)
  test_dataset = datasets.get_dataset("test", FLAGS)

  rng, key = random.split(rng)
  model, variables = models.get_model(key, dataset.peek(), FLAGS)
  optimizer = flax.optim.Adam(FLAGS.lr_init).create(variables)
  state = utils.TrainState(optimizer=optimizer)
  del optimizer, variables

  learning_rate_fn = functools.partial(
      utils.learning_rate_decay,
      lr_init=FLAGS.lr_init,
      lr_final=FLAGS.lr_final,
      max_steps=FLAGS.max_steps,
      lr_delay_steps=FLAGS.lr_delay_steps,
      lr_delay_mult=FLAGS.lr_delay_mult)

  train_pstep = jax.pmap(
      functools.partial(train_step, model),
      axis_name="batch",
      in_axes=(0, 0, 0, None),
      donate_argnums=(2,))

  def render_fn(variables, key_0, key_1, rays):
    return jax.lax.all_gather(
        model.apply(variables, key_0, key_1, rays, FLAGS.randomized),
        axis_name="batch")

  render_pfn = jax.pmap(
      render_fn,
      in_axes=(None, None, None, 0),  # Only distribute the data input.
      donate_argnums=(3,),
      axis_name="batch",
  )

  # Compiling to the CPU because it's faster and more accurate.
  ssim_fn = jax.jit(
      functools.partial(utils.compute_ssim, max_val=1.), backend="cpu")

  if not utils.isdir(FLAGS.train_dir):
    utils.makedirs(FLAGS.train_dir)
  state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
  # Resume training at the step of the last checkpoint.
  init_step = state.optimizer.state.step + 1
  state = flax.jax_utils.replicate(state)

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(FLAGS.train_dir)

  # Prefetch buffer size = 3 x batch_size.
  pdataset = flax.jax_utils.prefetch_to_device(dataset, 3)
  n_local_devices = jax.local_device_count()
  rng = rng + jax.host_id()  # Make random seed separate across hosts.
  keys = random.split(rng, n_local_devices)  # For pmapping RNG keys.
  gc.disable()  # Disable automatic garbage collection for efficiency.
  stats_trace = []
  reset_timer = True
  for step, batch in zip(range(init_step, FLAGS.max_steps + 1), pdataset):
    if reset_timer:
      t_loop_start = time.time()
      reset_timer = False
    lr = learning_rate_fn(step)
    state, stats, keys = train_pstep(keys, state, batch, lr)
    if jax.host_id() == 0:
      stats_trace.append(stats)
    if step % FLAGS.gc_every == 0:
      gc.collect()

    # Log training summaries. This is put behind a host_id check because in
    # multi-host evaluation, all hosts need to run inference even though we
    # only use host 0 to record results.
    if jax.host_id() == 0:
      if step % FLAGS.print_every == 0:
        summary_writer.scalar("train_loss", stats.loss[0], step)
        summary_writer.scalar("train_psnr", stats.psnr[0], step)
        summary_writer.scalar("train_sparsity", stats.sparsity[0], step)
        summary_writer.scalar("train_loss_coarse", stats.loss_c[0], step)
        summary_writer.scalar("train_psnr_coarse", stats.psnr_c[0], step)
        summary_writer.scalar("train_sparsity_coarse", stats.sparsity_c[0],
                              step)
        summary_writer.scalar("weight_l2", stats.weight_l2[0], step)
        avg_loss = np.mean(np.concatenate([s.loss for s in stats_trace]))
        avg_psnr = np.mean(np.concatenate([s.psnr for s in stats_trace]))
        stats_trace = []
        summary_writer.scalar("train_avg_loss", avg_loss, step)
        summary_writer.scalar("train_avg_psnr", avg_psnr, step)
        summary_writer.scalar("learning_rate", lr, step)
        steps_per_sec = FLAGS.print_every / (time.time() - t_loop_start)
        reset_timer = True
        rays_per_sec = FLAGS.batch_size * steps_per_sec
        summary_writer.scalar("train_steps_per_sec", steps_per_sec, step)
        summary_writer.scalar("train_rays_per_sec", rays_per_sec, step)
        precision = int(np.ceil(np.log10(FLAGS.max_steps))) + 1
        print(("{:" + "{:d}".format(precision) + "d}").format(step) +
              f"/{FLAGS.max_steps:d}: " +
              f"i_loss={stats.loss[0]:0.4f}, " +
              f"avg_loss={avg_loss:0.4f}, " +
              f"weight_l2={stats.weight_l2[0]:0.2e}, " +
              f"lr={lr:0.2e}, " +
              f"{rays_per_sec:0.0f} rays/sec")
      if step % FLAGS.save_every == 0:
        state_to_save = jax.device_get(jax.tree_map(lambda x: x[0], state))
        checkpoints.save_checkpoint(
            FLAGS.train_dir, state_to_save, int(step), keep=100)

    # Test-set evaluation.
    if FLAGS.render_every > 0 and step % FLAGS.render_every == 0:
      # We reuse the same random number generator from the optimization step
      # here on purpose so that the visualization matches what happened in
      # training.
      t_eval_start = time.time()
      eval_variables = jax.device_get(
          jax.tree_map(lambda x: x[0], state)).optimizer.target
      test_case = next(test_dataset)
      (pred_color, pred_disp, pred_acc, pred_features,
       pred_specular) = utils.render_image(
           functools.partial(render_pfn, eval_variables),
           test_case["rays"],
           keys[0],
           FLAGS.dataset == "llff",
           chunk=FLAGS.chunk)

      # Log eval summaries on host 0.
      if jax.host_id() == 0:
        psnr = utils.compute_psnr(
            ((pred_color - test_case["pixels"])**2).mean())
        ssim = ssim_fn(pred_color, test_case["pixels"])
        eval_time = time.time() - t_eval_start
        num_rays = jnp.prod(
            jnp.array(test_case["rays"].directions.shape[:-1]))
        rays_per_sec = num_rays / eval_time
        summary_writer.scalar("test_rays_per_sec", rays_per_sec, step)
        print(f"Eval {step}: {eval_time:0.3f}s., {rays_per_sec:0.0f} rays/sec")
        summary_writer.scalar("test_psnr", psnr, step)
        summary_writer.scalar("test_ssim", ssim, step)
        summary_writer.image("test_pred_color", pred_color, step)
        summary_writer.image("test_pred_disp", pred_disp, step)
        summary_writer.image("test_pred_acc", pred_acc, step)
        summary_writer.image("test_pred_features", pred_features, step)
        summary_writer.image("test_pred_specular", pred_specular, step)
        summary_writer.image("test_target", test_case["pixels"], step)

  if FLAGS.max_steps % FLAGS.save_every != 0:
    state = jax.device_get(jax.tree_map(lambda x: x[0], state))
    checkpoints.save_checkpoint(
        FLAGS.train_dir, state, int(FLAGS.max_steps), keep=100)
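
# A sketch of the learning-rate schedule consistent with the call to
# utils.learning_rate_decay above: log-linear interpolation from lr_init to
# lr_final over max_steps, optionally scaled down near step 0 by a warmup
# ramp controlled by lr_delay_steps / lr_delay_mult. The authoritative
# version lives in utils.py; this name is hypothetical.
def _learning_rate_decay_sketch(step, lr_init, lr_final, max_steps,
                                lr_delay_steps=0, lr_delay_mult=1):
  if lr_delay_steps > 0:
    # A sine ramp from lr_delay_mult up to 1 over the first lr_delay_steps.
    delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
        0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1))
  else:
    delay_rate = 1.
  t = np.clip(step / max_steps, 0, 1)
  log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
  return delay_rate * log_lerp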
# Evaluation entry point (a separate script from the training main above).
def main(unused_argv):
  # Hide the GPUs and TPUs from TF so it does not reserve memory on them for
  # LPIPS computation or dataset loading.
  tf.config.experimental.set_visible_devices([], "GPU")
  tf.config.experimental.set_visible_devices([], "TPU")

  rng = random.PRNGKey(20200823)

  if FLAGS.config is not None:
    utils.update_flags(FLAGS)
  if FLAGS.train_dir is None:
    raise ValueError("train_dir must be set. None set now.")
  if FLAGS.data_dir is None:
    raise ValueError("data_dir must be set. None set now.")

  dataset = datasets.get_dataset("test", FLAGS)
  rng, key = random.split(rng)
  model, init_variables = models.get_model(key, dataset.peek(), FLAGS)
  optimizer = flax.optim.Adam(FLAGS.lr_init).create(init_variables)
  state = utils.TrainState(optimizer=optimizer)
  del optimizer, init_variables

  # Rendering is forced to be deterministic even if training was randomized,
  # as this eliminates "speckle" artifacts.
  def render_fn(variables, key_0, key_1, rays):
    return jax.lax.all_gather(
        model.apply(variables, key_0, key_1, rays, False), axis_name="batch")

  # pmap over only the data input.
  render_pfn = jax.pmap(
      render_fn,
      in_axes=(None, None, None, 0),
      donate_argnums=3,
      axis_name="batch",
  )

  # Compiling to the CPU because it's faster and more accurate.
  ssim_fn = jax.jit(
      functools.partial(utils.compute_ssim, max_val=1.), backend="cpu")

  last_step = 0
  out_dir = path.join(FLAGS.train_dir,
                      "path_renders" if FLAGS.render_path else "test_preds")
  if not FLAGS.eval_once:
    summary_writer = tensorboard.SummaryWriter(
        path.join(FLAGS.train_dir, "eval"))
  while True:
    state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
    step = int(state.optimizer.state.step)
    if step <= last_step:
      continue
    if FLAGS.save_output and (not utils.isdir(out_dir)):
      utils.makedirs(out_dir)
    psnr_values = []
    ssim_values = []
    if not FLAGS.eval_once:
      showcase_index = np.random.randint(0, dataset.size)
    for idx in range(dataset.size):
      print(f"Evaluating {idx+1}/{dataset.size}")
      batch = next(dataset)
      (pred_color, pred_disp, pred_acc, pred_features,
       pred_specular) = utils.render_image(
           functools.partial(render_pfn, state.optimizer.target),
           batch["rays"],
           rng,
           FLAGS.dataset == "llff",
           chunk=FLAGS.chunk)
      if jax.host_id() != 0:  # Only record via host 0.
        continue
      if not FLAGS.eval_once and idx == showcase_index:
        showcase_color = pred_color
        showcase_disp = pred_disp
        showcase_acc = pred_acc
        showcase_features = pred_features
        showcase_specular = pred_specular
        if not FLAGS.render_path:
          showcase_gt = batch["pixels"]
      if not FLAGS.render_path:
        psnr = utils.compute_psnr(((pred_color - batch["pixels"])**2).mean())
        ssim = ssim_fn(pred_color, batch["pixels"])
        print(f"PSNR = {psnr:.4f}, SSIM = {ssim:.4f}")
        psnr_values.append(float(psnr))
        ssim_values.append(float(ssim))
      if FLAGS.save_output:
        utils.save_img(pred_color,
                       path.join(out_dir, "{:03d}.png".format(idx)))
        utils.save_img(pred_disp[Ellipsis, 0],
                       path.join(out_dir, "disp_{:03d}.png".format(idx)))
    if (not FLAGS.eval_once) and (jax.host_id() == 0):
      summary_writer.image("pred_color", showcase_color, step)
      summary_writer.image("pred_disp", showcase_disp, step)
      summary_writer.image("pred_acc", showcase_acc, step)
      summary_writer.image("pred_features", showcase_features, step)
      summary_writer.image("pred_specular", showcase_specular, step)
      if not FLAGS.render_path:
        summary_writer.scalar("psnr", np.mean(np.array(psnr_values)), step)
        summary_writer.scalar("ssim", np.mean(np.array(ssim_values)), step)
        summary_writer.image("target", showcase_gt, step)
    if FLAGS.save_output and (not FLAGS.render_path) and (jax.host_id() == 0):
      with utils.open_file(path.join(out_dir, f"psnrs_{step}.txt"), "w") as f:
        f.write(" ".join([str(v) for v in psnr_values]))
      with utils.open_file(path.join(out_dir, f"ssims_{step}.txt"), "w") as f:
        f.write(" ".join([str(v) for v in ssim_values]))
      with utils.open_file(path.join(out_dir, "psnr.txt"), "w") as f:
        f.write("{}".format(np.mean(np.array(psnr_values))))
      with utils.open_file(path.join(out_dir, "ssim.txt"), "w") as f:
        f.write("{}".format(np.mean(np.array(ssim_values))))
    if FLAGS.eval_once:
      break
    if int(step) >= FLAGS.max_steps:
      break
    last_step = step
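
# utils.render_image (used in both mains) renders a full image in fixed-size
# chunks of rays so device memory stays bounded regardless of resolution. A
# much-simplified sketch, assuming rays carry a [height, width, 3]
# .directions field (as referenced in the training loop) and a render_fn that
# returns just RGB per chunk; the real helper also pads each chunk so it
# divides evenly across hosts and devices, then strips the padding. The name
# is hypothetical.
def _render_image_sketch(render_fn, rays, rng, chunk=8192):
  height, width = rays.directions.shape[:2]
  num_rays = height * width
  flat_rays = jax.tree_map(lambda r: r.reshape((num_rays, -1)), rays)
  _, key_0, key_1 = random.split(rng, 3)
  rgb_chunks = []
  for i in range(0, num_rays, chunk):
    # Slice one chunk of rays out of the flattened image and render it.
    chunk_rays = jax.tree_map(lambda r: r[i:i + chunk], flat_rays)
    rgb_chunks.append(render_fn(key_0, key_1, chunk_rays))
  return jnp.concatenate(rgb_chunks, axis=0).reshape((height, width, -1))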