def loss_fn(params):
  """Computes the total training loss and its statistics for `params`.

  This function closes over `model`, `key_0`, `key_1`, `batch`, `config` and
  `weight_decay` from the enclosing scope.

  Args:
    params: model parameters, passed to `model.apply` as the "params"
      collection.

  Returns:
    A tuple `(total_loss, stats)` where `total_loss` is the sum of the main
    MSE loss, the coarse MSE loss (0 when there is a single output set) and
    the L2 weight penalty, and `stats` is a `train_utils.Stats` holding the
    individual terms.

  Raises:
    ValueError: if the model returns neither 1 nor 2 sets of outputs.
  """
  variables = {"params": params}
  ret = model.apply(variables,
                    key_0,
                    key_1,
                    batch,
                    randomized=config.model.randomized)
  if len(ret) not in (1, 2):
    # Note the trailing space after "sets": the two literals are concatenated.
    raise ValueError(
        "ret should contain either 1 set of output (coarse only), or 2 sets "
        "of output (coarse as ret[0] and fine as ret[1]).")

  #------------------------------------------------------------------------
  # Main prediction
  # The main prediction is always at the end of the ret list.
  rgb, unused_disp, unused_acc = ret[-1]
  batch_pixels = model_utils.uint2float(batch.target_view.rgb)
  target_rgb = batch_pixels[Ellipsis, :3]
  loss = ((rgb - target_rgb)**2).mean()
  psnr = model_utils.compute_psnr(loss)

  #------------------------------------------------------------------------
  # Coarse / Regularization Prediction
  if len(ret) > 1:
    # If there are both coarse and fine predictions, we compute the loss for
    # the coarse prediction (ret[0]) as well.
    rgb_c, unused_disp_c, unused_acc_c = ret[0]
    loss_c = ((rgb_c - target_rgb)**2).mean()
    psnr_c = model_utils.compute_psnr(loss_c)
  else:
    loss_c = 0.
    psnr_c = 0.

  #------------------------------------------------------------------------
  # Weight Regularization
  # Only leaves with more than one dimension are penalized (x.ndim > 1).
  # `jax.tree_util.tree_leaves` replaces the deprecated `jax.tree_leaves`
  # alias.
  weight_penalty_params = jax.tree_util.tree_leaves(variables["params"])
  weight_l2 = sum(
      jnp.sum(x**2) for x in weight_penalty_params if x.ndim > 1)
  weight_penalty = weight_decay * 0.5 * weight_l2

  #------------------------------------------------------------------------
  # Compute total loss and wrap the stats
  total_loss = loss + loss_c + weight_penalty
  stats = train_utils.Stats(loss=loss,
                            psnr=psnr,
                            loss_c=loss_c,
                            psnr_c=psnr_c,
                            weight_l2=weight_l2)
  return total_loss, stats
def __call__(self, rng_0, rng_1, batch, randomized):
  """Light Field Neural Rendering Model.

  Projects the target rays onto the reference views, gathers pixel values /
  features at the projected locations, and predicts a color per target ray
  from query (target ray) and key (projected ray) representations.

  Note: `batch.reference_views.rgb` is set to None as a side effect.

  Args:
    rng_0: random key forwarded to `self.projector.epipolar_projection`.
    rng_1: unused; deleted on entry.
    batch: data batch. data_types.Batch
    randomized: bool, forwarded to the projection, feature-averaging and
      color-prediction steps.

  Returns:
    `[(rgb_coarse, None, None), (rgb, None, None)]`. When `self.return_attn`
    is truthy, a tuple of that list and a dict with the keys "e_attn",
    "n_attn" and "p_coord" is returned instead.
  """
  del rng_1

  # Rays of the view being rendered.
  target_rays = batch.target_view.rays

  #---------------------------------------------------------------------------
  # Operations relating epipolar projections.
  # Using the given rays, project them to the world and then to nearby
  # cameras.
  projection = self.projector.epipolar_projection(
      rng_0,
      target_rays,
      batch.reference_views.ref_worldtocamera,
      batch.reference_views.intrinsic_matrix,
      randomized)
  projected_coordinates, _, wcoords = projection

  # Next we need to get the rgb values and the rays corresponding to these
  # projections.
  projected_rays = self.projector.get_near_rays(
      projected_coordinates,
      batch.reference_views.ref_cameratoworld,
      batch.reference_views.intrinsic_matrix)
  reference_rgb = model_utils.uint2float(batch.reference_views.rgb)
  projected_rgb_and_feat = self._get_pixel_projection(projected_coordinates,
                                                      reference_rgb)
  # The raw reference images are dropped from the batch from here on.
  batch.reference_views.rgb = None

  #---------------------------------------------------------------------------
  # Get LF representation of the batch and the projected rays.
  # Below we consider the representation extracted from the batch rays as the
  # query and representation extracted from the projected rays as keys and the
  # projected rgb as the values.
  _, query, _ = self._get_query(target_rays)
  _, key, _, learned_embedding = self._get_key(projected_rays,
                                               projected_rgb_and_feat,
                                               wcoords,
                                               batch.reference_views.idx)

  # Get the average feature over each epipolar line
  avg_projection_features, e_attn = self._get_avg_features(
      query, key, randomized=randomized)

  rgb, n_attn = self._predict_color(query, avg_projection_features,
                                    learned_embedding, randomized)
  rgb_coarse = self._get_reg_prediction(projected_rgb_and_feat, e_attn,
                                        n_attn)

  # Coarse prediction first, main prediction last.
  outputs = [(rgb_coarse, None, None), (rgb, None, None)]
  if not self.return_attn:
    return outputs

  attention = {
      "e_attn": e_attn,
      "n_attn": n_attn,
      "p_coord": projected_coordinates.swapaxes(0, 1)
  }
  return outputs, attention
def evaluate(config, workdir):
  """Evaluation function.

  Repeatedly restores the latest checkpoint from `workdir`, renders every
  example of the test dataset and, on process 0, computes PSNR / SSIM against
  the target pixels (unless `config.dataset.render_path` is set), optionally
  saving the rendered images and metric files under `workdir`.

  The loop ends after one evaluation when `config.eval.eval_once` is set, or
  once the restored step reaches `config.train.max_steps`.

  Args:
    config: Configuration to use.
    workdir: Working directory holding the checkpoints; outputs and eval
      summaries are written below it.
  """
  import time  # Local import: only needed to pace checkpoint polling.

  # Seconds to wait before looking for a newer checkpoint again.
  poll_interval_secs = 5

  # Hide the GPUs and TPUs from TF so it does not reserve memory on them for
  # LPIPS computation or dataset loading.
  tf.config.experimental.set_visible_devices([], "GPU")
  tf.config.experimental.set_visible_devices([], "TPU")

  rng = jax.random.PRNGKey(config.seed)

  #----------------------------------------------------------------------------
  # Build input pipeline.
  # The second half of this split is not used, but the split itself is kept so
  # that `rng` follows the same stream as before.
  rng, _ = jax.random.split(rng)
  config.dataset.data_dir = os.path.join(config.dataset.base_dir,
                                         config.dataset.scene)
  train_ds, test_ds = datasets.create_dataset(config)
  example_batch = train_ds.peek()

  rng, key = jax.random.split(rng)

  #----------------------------------------------------------------------------
  # Initialize model.
  learning_rate_fn = train_utils.create_learning_rate_fn(config)
  model, state = models.create_train_state(config,
                                           key,
                                           learning_rate_fn=learning_rate_fn,
                                           example_batch=example_batch)

  #----------------------------------------------------------------------------
  # Get the rendering function. Rendering is forced to be deterministic even
  # if training is randomized.
  render_pfn = render_utils.get_render_function(model,
                                                config,
                                                randomized=False)

  last_step = 0
  out_dir = os.path.join(
      workdir, "path_renders" if config.dataset.render_path else "test_preds")

  if not config.eval.eval_once:
    # Prepare Metric Writers
    summary_writer = tensorboard.SummaryWriter(os.path.join(workdir, "eval"))

  while True:
    state = checkpoints.restore_checkpoint(workdir, state)
    step = int(state.step)
    if step <= last_step:
      # No new checkpoint yet: wait instead of re-reading the checkpoint
      # directory in a tight loop.
      time.sleep(poll_interval_secs)
      continue

    if config.eval.save_output and (not file_utils.isdir(out_dir)):
      file_utils.makedirs(out_dir)

    psnr_values = []
    ssim_values = []

    if not config.eval.eval_once:
      # One randomly chosen test example is logged as an image summary.
      showcase_index = np.random.randint(0, test_ds.size)

    # The parameters do not change while iterating over the test set.
    variables = {"params": state.params}

    for idx in range(test_ds.size):
      logging.info("Evaluating [%d / %d].", idx, test_ds.size)
      batch = next(test_ds)
      test_pixels = batch.target_view.rgb
      if test_pixels is not None:
        test_pixels = model_utils.uint2float(test_pixels)

      #-----------------------------------------------------------
      # Render Image
      pred_color, pred_disp, pred_acc = render_utils.render_image(
          functools.partial(render_pfn, variables),
          batch,
          rng,
          render_utils.normalize_disp(config.dataset.name),
          chunk=config.eval.chunk)

      # Every process takes part in rendering; only process 0 goes on to
      # compute metrics and save outputs.
      if jax.process_index() != 0:
        continue

      #-----------------------------------------------------------
      # Get showcase example for logging
      if not config.eval.eval_once and idx == showcase_index:
        showcase_color = pred_color
        showcase_disp = pred_disp
        showcase_acc = pred_acc
        if not config.dataset.render_path:
          showcase_gt = test_pixels

      #-----------------------------------------------------------
      # If target pixels are available, evaluate
      if not config.dataset.render_path:
        psnr = float(
            model_utils.compute_psnr(((pred_color - test_pixels)**2).mean()))
        ssim = float(
            skmetrics.structural_similarity(pred_color.astype(np.float32),
                                            test_pixels.astype(np.float32),
                                            win_size=11,
                                            multichannel=True,
                                            gaussian_weights=True))
        logging.info("PSNR = %.4f, SSIM = %.4f", psnr, ssim)
        psnr_values.append(psnr)
        ssim_values.append(ssim)

      #-----------------------------------------------------------
      # Save generated image
      if config.eval.save_output:
        model_utils.save_img(
            pred_color, os.path.join(out_dir, "{:03d}.png".format(idx)))
        if pred_disp is not None:
          model_utils.save_img(
              pred_disp[Ellipsis, 0],
              os.path.join(out_dir, "disp_{:03d}.png".format(idx)))

    #-----------------------------------------------------------
    # Write image / scalar summaries for this checkpoint.
    if (not config.eval.eval_once) and (jax.process_index() == 0):
      summary_writer.image("pred_color", showcase_color, step)
      if showcase_disp is not None:
        summary_writer.image("pred_disp", showcase_disp, step)
      if showcase_acc is not None:
        summary_writer.image("pred_acc", showcase_acc, step)
      if not config.dataset.render_path:
        summary_writer.scalar("eval_metric/psnr",
                              np.mean(np.array(psnr_values)), step)
        summary_writer.scalar("eval_metric/ssim",
                              np.mean(np.array(ssim_values)), step)
        summary_writer.image("target", showcase_gt, step)

    #-----------------------------------------------------------
    # Save the metric to file
    if config.eval.save_output and (not config.dataset.render_path) and (
        jax.process_index() == 0):
      with file_utils.open_file(os.path.join(out_dir, f"psnrs_{step}.txt"),
                                "w") as f:
        f.write(" ".join([str(v) for v in psnr_values]))
      with file_utils.open_file(os.path.join(out_dir, f"ssims_{step}.txt"),
                                "w") as f:
        f.write(" ".join([str(v) for v in ssim_values]))
      with file_utils.open_file(os.path.join(out_dir, "psnr.txt"), "w") as f:
        f.write("{}".format(np.mean(np.array(psnr_values))))
      with file_utils.open_file(os.path.join(out_dir, "ssim.txt"), "w") as f:
        f.write("{}".format(np.mean(np.array(ssim_values))))

    if config.eval.eval_once:
      break
    if step >= config.train.max_steps:
      break
    last_step = step

  logging.info("Finishing evaluation at step %d", last_step)
def train_and_evaluate(config, workdir):
  """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains checkpoint training will be resumed from the latest checkpoint.

  Raises:
    ValueError: if `config.dataset.batch_size` is not divisible by
      `jax.device_count()`.
  """
  if config.dataset.batch_size % jax.device_count() != 0:
    raise ValueError("Batch size must be divisible by the number of devices.")

  tf.io.gfile.makedirs(workdir)
  # Deterministic training.
  rng = jax.random.PRNGKey(config.seed)
  # Shift the numpy random seed by process_index() to shuffle data loaded
  # by different hosts
  np.random.seed(20201473 + jax.process_index())

  #----------------------------------------------------------------------------
  # Build input pipeline.
  # The second half of this split is not used, but the split itself is kept so
  # that `rng` follows the same stream as before.
  rng, _ = jax.random.split(rng)
  config.dataset.data_dir = os.path.join(config.dataset.base_dir,
                                         config.dataset.scene)
  train_ds, eval_ds = datasets.create_dataset(config)
  example_batch = train_ds.peek()

  #----------------------------------------------------------------------------
  # Learning rate schedule.
  num_train_steps = config.train.max_steps
  if num_train_steps == -1:
    # Other code in this file reads the dataset `size` as a plain attribute,
    # while this branch used to call it; accept either form so that this path
    # works regardless of how the dataset exposes it.
    ds_size = train_ds.size
    num_train_steps = ds_size() if callable(ds_size) else ds_size
  steps_per_epoch = num_train_steps // config.train.num_epochs
  logging.info("num_train_steps=%d, steps_per_epoch=%d", num_train_steps,
               steps_per_epoch)
  learning_rate_fn = train_utils.create_learning_rate_fn(config)

  #----------------------------------------------------------------------------
  # Initialize model.
  rng, model_rng = jax.random.split(rng)
  model, state = models.create_train_state(
      config,
      model_rng,
      learning_rate_fn=learning_rate_fn,
      example_batch=example_batch,
  )

  #----------------------------------------------------------------------------
  # Set up checkpointing of the model and the input pipeline.
  state = checkpoints.restore_checkpoint(workdir, state)
  initial_step = int(state.step) + 1

  #----------------------------------------------------------------------------
  # Distribute training.
  state = flax_utils.replicate(state)
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          model=model,
          learning_rate_fn=learning_rate_fn,
          weight_decay=config.train.weight_decay,
          config=config,
      ),
      axis_name="batch",
  )

  # Get distributed rendering function
  render_pfn = render_utils.get_render_function(
      model=model,
      config=config,
      randomized=False,  # No randomization for evaluation.
  )

  #----------------------------------------------------------------------------
  # Prepare Metric Writers
  writer = metric_writers.create_default_writer(
      workdir, just_logging=jax.process_index() > 0)
  if initial_step == 1:
    writer.write_hparams(dict(config))

  logging.info("Starting training loop at step %d.", initial_step)
  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=num_train_steps, writer=writer)
  if jax.process_index() == 0:
    hooks += [
        report_progress,
    ]
  train_metrics = None

  # Prefetch_buffer_size = 6 x batch_size
  ptrain_ds = flax.jax_utils.prefetch_to_device(train_ds, 6)
  n_local_devices = jax.local_device_count()
  rng = rng + jax.process_index()  # Make random seed separate across hosts.
  keys = jax.random.split(rng, n_local_devices)  # For pmapping RNG keys.

  with metric_writers.ensure_flushes(writer):
    for step in range(initial_step, num_train_steps + 1):
      # `step` is a Python integer. `state.step` is JAX integer on the GPU/TPU
      # devices.
      is_last_step = step == num_train_steps

      with jax.profiler.StepTraceAnnotation("train", step_num=step):
        batch = next(ptrain_ds)
        state, metrics_update, keys = p_train_step(rng=keys,
                                                   state=state,
                                                   batch=batch)
        metric_update = flax_utils.unreplicate(metrics_update)
        train_metrics = (metric_update if train_metrics is None else
                         train_metrics.merge(metric_update))

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)
      for h in hooks:
        h(step)

      if step % config.train.log_loss_every_steps == 0 or is_last_step:
        writer.write_scalars(step, train_metrics.compute())
        train_metrics = None

      if step % config.train.render_every_steps == 0 or is_last_step:
        test_batch = next(eval_ds)
        test_pixels = model_utils.uint2float(
            test_batch.target_view.rgb)  # extract for evaluation
        with report_progress.timed("eval"):
          pred_color, pred_disp, pred_acc = eval_step(state, keys[0],
                                                      test_batch, render_pfn,
                                                      config)
        #------------------------------------------------------------------
        # Log metrics and images for host 0
        #------------------------------------------------------------------
        if jax.process_index() == 0:
          psnr = model_utils.compute_psnr(
              ((pred_color - test_pixels)**2).mean())
          # The keyword is `gaussian_weights` (plural). The misspelled
          # singular form is either silently ignored or rejected, depending
          # on the scikit-image version.
          ssim = skmetrics.structural_similarity(
              pred_color.astype(np.float32),
              test_pixels.astype(np.float32),
              win_size=11,
              multichannel=True,
              gaussian_weights=True)
          writer.write_scalars(
              step, {
                  "train_eval/test_psnr": psnr,
                  "train_eval/test_ssim": ssim,
              })
          writer.write_images(
              step, {
                  "test_pred_color": pred_color[None, :],
                  "test_target": test_pixels[None, :]
              })
          if pred_disp is not None:
            writer.write_images(step, {"test_pred_disp": pred_disp[None, :]})
          if pred_acc is not None:
            writer.write_images(step, {"test_pred_acc": pred_acc[None, :]})
        #------------------------------------------------------------------

      if (jax.process_index() == 0) and (
          step % config.train.checkpoint_every_steps == 0 or is_last_step):
        # Write the metrics of the most recent step to file
        with file_utils.open_file(os.path.join(workdir, "train_logs.json"),
                                  "w") as f:
          log_dict = {
              k: v.item() for k, v in metric_update.compute().items()
          }
          f.write(json.dumps(log_dict))

        with report_progress.timed("checkpoint"):
          # Keep only the first replica of the replicated state.
          # `jax.tree_util.tree_map` replaces the deprecated `jax.tree_map`
          # alias.
          state_to_save = jax.device_get(
              jax.tree_util.tree_map(lambda x: x[0], state))
          checkpoints.save_checkpoint(workdir, state_to_save, step, keep=100)

  logging.info("Finishing training at step %d", num_train_steps)