def _compute_loss_and_stats(params, model_out, use_elastic_loss=False):
    """Compute the training loss and a dict of scalar summaries.

    Besides its arguments, this function reads the names `batch`,
    `elastic_loss_type`, `elastic_reduce_method`, `use_warp_reg_loss` and
    `scalar_params` from an enclosing/module scope that is not visible here.

    Args:
      params: accepted for interface compatibility; not used.
      model_out: model output dict. Always reads 'rgb'; depending on which
        terms are enabled also reads 'weights', 'warp_jacobian', 'points' and
        'warped_points'.
      use_elastic_loss: whether to add the elastic regularization term.

    Returns:
      A `(loss, stats)` pair: `loss` is the RGB MSE plus the weighted enabled
      regularizers, and `stats` maps summary names ('loss/...', 'residual/...',
      'metric/...') to scalar values.
    """
    # Photometric term: squared error against the first three target channels.
    pixel_err = model_out['rgb'] - batch['rgb'][..., :3]
    rgb_loss = (pixel_err**2).mean()
    total_loss = rgb_loss
    stats = {'loss/rgb': rgb_loss}

    if use_elastic_loss:
        per_jacobian_loss = jax.jit(
            functools.partial(compute_elastic_loss, loss_type=elastic_loss_type))
        # Vectorize over the two leading axes of the Jacobian tensor.
        batched_elastic_loss = jax.jit(vmap(vmap(per_jacobian_loss)))
        sample_weights = lax.stop_gradient(model_out['weights'])
        jac = model_out['warp_jacobian']
        if elastic_reduce_method == 'median':
            # Keep only the Jacobian at the depth index; the three appended
            # axes line up with (sample, Jacobian row, Jacobian col).
            median_idx = model_utils.compute_depth_index(sample_weights)
            jac = jnp.take_along_axis(
                jac, median_idx[..., None, None, None], axis=-3)
        elastic_terms, elastic_residual = batched_elastic_loss(jac)
        if elastic_reduce_method == 'weight':
            # Scale each sample's loss by its gradient-stopped weight.
            elastic_terms = sample_weights * elastic_terms
        elastic_loss = elastic_terms.sum(axis=-1).mean()
        stats['loss/elastic'] = elastic_loss
        stats['residual/elastic'] = jnp.mean(elastic_residual)
        total_loss += scalar_params.elastic_loss_weight * elastic_loss

    if use_warp_reg_loss:
        sample_weights = lax.stop_gradient(model_out['weights'])
        depth_idx = model_utils.compute_depth_index(sample_weights)
        displacement = model_out['points'] - model_out['warped_points']
        sq_warp_mag = (displacement**2).sum(axis=-1)
        # Squared warp magnitude at the selected depth index only.
        warp_reg_residual = jnp.take_along_axis(
            sq_warp_mag, depth_idx[..., None], axis=-1)
        warp_reg_loss = utils.general_loss_with_squared_residual(
            warp_reg_residual,
            alpha=scalar_params.warp_reg_loss_alpha,
            scale=scalar_params.warp_reg_loss_scale,
        ).mean()
        stats['loss/warp_reg'] = warp_reg_loss
        stats['residual/warp_reg'] = jnp.mean(jnp.sqrt(warp_reg_residual))
        total_loss += scalar_params.warp_reg_loss_weight * warp_reg_loss

    if 'warp_jacobian' in model_out:
        # Diagnostics only: these metrics do not contribute to the loss.
        warp_jac = model_out['warp_jacobian']
        stats.update({
            'metric/jacobian_det': jnp.mean(jnp.linalg.det(warp_jac)),
            'metric/jacobian_div': jnp.mean(utils.jacobian_to_div(warp_jac)),
            'metric/jacobian_curl': jnp.mean(
                jnp.linalg.norm(utils.jacobian_to_curl(warp_jac), axis=-1)),
        })

    stats['loss/total'] = total_loss
    stats['metric/psnr'] = utils.compute_psnr(rgb_loss)
    return total_loss, stats
def _compute_loss_and_stats(params, model_out, use_elastic_loss=False):
    """Compute the training loss and a dict of scalar summaries.

    Besides its arguments, this function reads the names `batch`,
    `elastic_reduce_method` and `scalar_params` from an enclosing/module scope
    that is not visible here.

    Args:
      params: accepted for interface compatibility; not used.
      model_out: model output dict. Always reads "rgb"; reads "weights" and
        "warp_jacobian" when the elastic term is enabled, and "warp_jacobian"
        for diagnostics whenever it is present.
      use_elastic_loss: whether to add the elastic regularization term.

    Returns:
      A `(loss, stats)` pair: `loss` is the RGB MSE plus the weighted elastic
      term (if enabled), and `stats` maps summary names to scalar values.
    """
    # Photometric term: squared error against the first three target channels.
    pixel_err = model_out["rgb"] - batch["rgb"][..., :3]
    rgb_loss = (pixel_err**2).mean()
    total_loss = rgb_loss
    stats = {"loss/rgb": rgb_loss}

    if use_elastic_loss:
        # Vectorize the per-Jacobian loss over the two leading axes.
        batched_elastic_loss = jax.jit(vmap(vmap(compute_elastic_loss)))
        sample_weights = lax.stop_gradient(model_out["weights"])
        jac = model_out["warp_jacobian"]
        if elastic_reduce_method == "median":
            # Keep only the Jacobian at the depth index; the three appended
            # axes line up with (sample, Jacobian row, Jacobian col).
            median_idx = model_utils.compute_depth_index(sample_weights)
            jac = jnp.take_along_axis(
                jac, median_idx[..., None, None, None], axis=-3)
        elastic_terms, elastic_residual = batched_elastic_loss(jac)
        if elastic_reduce_method == "weight":
            # Scale each sample's loss by its gradient-stopped weight.
            elastic_terms = sample_weights * elastic_terms
        elastic_loss = elastic_terms.sum(axis=-1).mean()
        stats["loss/elastic"] = elastic_loss
        stats["residual/elastic"] = jnp.mean(elastic_residual)
        total_loss += scalar_params.elastic_loss_weight * elastic_loss

    if "warp_jacobian" in model_out:
        # Diagnostics only: these metrics do not contribute to the loss.
        warp_jac = model_out["warp_jacobian"]
        stats.update({
            "metric/jacobian_det": jnp.mean(jnp.linalg.det(warp_jac)),
            "metric/jacobian_div": jnp.mean(utils.jacobian_to_div(warp_jac)),
            "metric/jacobian_curl": jnp.mean(
                jnp.linalg.norm(utils.jacobian_to_curl(warp_jac), axis=-1)),
        })

    stats["loss/total"] = total_loss
    stats["metric/psnr"] = utils.compute_psnr(rgb_loss)
    return total_loss, stats
def process_batch(*,
                  batch: Dict[str, jnp.ndarray],
                  rng: types.PRNGKey,
                  state: model_utils.TrainState,
                  tag: str,
                  item_id: str,
                  step: int,
                  summary_writer: tensorboard.SummaryWriter,
                  render_fn: Any,
                  save_dir: Optional[gpath.GPath],
                  datasource: datasets.DataSource):
    """Process and plot a single batch.

    Renders `batch` with `render_fn` and logs RGB, depth, disparity and
    accumulation images to `summary_writer`. When `save_dir` is given, the RGB
    and depth images are also written to disk. If the batch carries ground
    truth ('rgb' and/or 'depth'), the matching metrics and error images are
    produced as well.

    Every process calls `render_fn`, but only process 0 logs or computes
    metrics; other processes return an empty dict.

    Args:
      batch: input batch; optional 'rgb' and 'depth' entries are used as
        ground truth.
      rng: PRNG key forwarded to `render_fn`.
      state: train state forwarded to `render_fn`.
      tag: summary tag prefix (e.g. the split name).
      item_id: identifier of the item; '/' is replaced by '_' before use.
      step: global step used for the summaries.
      summary_writer: writer receiving the image summaries.
      render_fn: called as `render_fn(state, batch, rng=rng)`; must return a
        mapping with 'rgb', 'acc', 'depth' and 'med_depth'.
      save_dir: optional directory to save images into.
      datasource: provides the `near`/`far` range used to colorize depth.

    Returns:
      Dict with 'mse', 'psnr' and 'ssim' when 'rgb' is in `batch`, and
      'depth_abs' when 'depth' is in `batch`. Empty on processes other than 0.
    """
    item_id = item_id.replace('/', '_')
    render = render_fn(state, batch, rng=rng)
    out = {}
    if jax.process_index() != 0:
        return out

    rgb = render['rgb']
    acc = render['acc']
    depth_exp = render['depth']
    depth_med = render['med_depth']
    colorize_depth = functools.partial(viz.colorize,
                                       cmin=datasource.near,
                                       cmax=datasource.far,
                                       invert=True)

    depth_exp_viz = colorize_depth(depth_exp)
    depth_med_viz = colorize_depth(depth_med)
    disp_exp_viz = viz.colorize(1.0 / depth_exp)
    disp_med_viz = viz.colorize(1.0 / depth_med)
    acc_viz = viz.colorize(acc, cmin=0.0, cmax=1.0)

    if save_dir:
        save_dir.mkdir(parents=True, exist_ok=True)
        image_utils.save_image(save_dir / f'rgb_{item_id}.png',
                               image_utils.image_to_uint8(rgb))
        image_utils.save_image(save_dir / f'depth_expected_viz_{item_id}.png',
                               image_utils.image_to_uint8(depth_exp_viz))
        image_utils.save_depth(save_dir / f'depth_expected_{item_id}.png',
                               depth_exp)
        image_utils.save_image(save_dir / f'depth_median_viz_{item_id}.png',
                               image_utils.image_to_uint8(depth_med_viz))
        image_utils.save_depth(save_dir / f'depth_median_{item_id}.png',
                               depth_med)

    summary_writer.image(f'rgb/{tag}/{item_id}', rgb, step)
    summary_writer.image(f'depth-expected/{tag}/{item_id}', depth_exp_viz, step)
    summary_writer.image(f'depth-median/{tag}/{item_id}', depth_med_viz, step)
    summary_writer.image(f'disparity-expected/{tag}/{item_id}', disp_exp_viz,
                         step)
    summary_writer.image(f'disparity-median/{tag}/{item_id}', disp_med_viz,
                         step)
    summary_writer.image(f'acc/{tag}/{item_id}', acc_viz, step)

    if 'rgb' in batch:
        # Compare colour channels only, matching the training loss, which uses
        # batch['rgb'][..., :3]; an extra (e.g. alpha) channel would otherwise
        # make the subtraction fail.
        rgb_target = batch['rgb'][..., :3]
        mse = ((rgb - rgb_target)**2).mean()
        psnr = utils.compute_psnr(mse)
        ssim = compute_multiscale_ssim(rgb_target, rgb)
        out['mse'] = mse
        out['psnr'] = psnr
        out['ssim'] = ssim
        logging.info('\tMetrics: mse=%.04f, psnr=%.02f, ssim=%.02f',
                     mse, psnr, ssim)

        rgb_abs_error = viz.colorize(
            abs(rgb_target - rgb).sum(axis=-1), cmin=0, cmax=1)
        rgb_sq_error = viz.colorize(
            ((rgb_target - rgb)**2).sum(axis=-1), cmin=0, cmax=1)
        summary_writer.image(f'rgb-target/{tag}/{item_id}', rgb_target, step)
        summary_writer.image(f'rgb-abs-error/{tag}/{item_id}', rgb_abs_error,
                             step)
        summary_writer.image(f'rgb-sq-error/{tag}/{item_id}', rgb_sq_error,
                             step)

    if 'depth' in batch:
        depth_target = batch['depth']
        # The target has a trailing singleton axis (it is indexed [..., 0]),
        # while the rendered depths are colorized above without slicing.
        # Subtracting (H, W, 1) from (H, W) would broadcast to (H, W, W), so
        # bring everything onto the target's 2-D grid first. The reshape is a
        # no-op when the rendered depth already has that shape, and it drops
        # the singleton axis when the rendered depth is (H, W, 1).
        target_2d = depth_target[..., 0]
        med_2d = depth_med.reshape(target_2d.shape)
        exp_2d = depth_exp.reshape(target_2d.shape)

        depth_target_viz = colorize_depth(target_2d)
        med_abs_error = abs(target_2d - med_2d)
        out['depth_abs'] = jnp.nanmean(med_abs_error)
        summary_writer.image(
            f'depth-target/{tag}/{item_id}', depth_target_viz, step)
        depth_med_error = viz.colorize(med_abs_error, cmin=0, cmax=1)
        summary_writer.image(
            f'depth-median-error/{tag}/{item_id}', depth_med_error, step)
        depth_exp_error = viz.colorize(abs(target_2d - exp_2d), cmin=0, cmax=1)
        summary_writer.image(
            f'depth-expected-error/{tag}/{item_id}', depth_exp_error, step)

    return out
def process_batch(
    *,
    batch: Dict[str, jnp.ndarray],
    rng: types.PRNGKey,
    state: model_utils.TrainState,
    tag: str,
    item_id: str,
    step: int,
    summary_writer: tensorboard.SummaryWriter,
    render_fn: Any,
    save_dir: Optional[gpath.GPath],
    datasource: datasets.DataSource,
):
    """Process and plot a single batch.

    Renders `batch` with `render_fn` and logs RGB, depth and accumulation
    images to `summary_writer`. When `save_dir` is given, the RGB and depth
    images are also written to disk. If the batch carries ground truth
    ("rgb" and/or "depth"), the matching metrics and error images are
    produced as well.

    Every process calls `render_fn`, but only process 0 logs or computes
    metrics; other processes return an empty dict.

    Args:
      batch: input batch; optional "rgb" and "depth" entries are used as
        ground truth.
      rng: PRNG key forwarded to `render_fn`.
      state: train state forwarded to `render_fn`.
      tag: summary tag prefix (e.g. the split name).
      item_id: identifier of the item, used in summary tags and file names.
      step: global step used for the summaries.
      summary_writer: writer receiving the image summaries.
      render_fn: called as `render_fn(state, batch, rng=rng)`; must return the
        tuple `(rgb, depth_exp, depth_med, acc)`. The depth maps carry a
        trailing channel axis (they are indexed with [..., 0]).
      save_dir: optional directory to save images into.
      datasource: provides the `near`/`far` range used to colorize depth.

    Returns:
      Dict with "mse", "psnr" and "ssim" when "rgb" is in `batch`, and
      "depth_abs" when "depth" is in `batch`. Empty on processes other than 0.
    """
    rgb, depth_exp, depth_med, acc = render_fn(state, batch, rng=rng)
    out = {}
    # `jax.host_id` is deprecated; `jax.process_index` is its replacement.
    if jax.process_index() != 0:
        return out

    colorize_depth = functools.partial(viz.colorize,
                                       cmin=datasource.near,
                                       cmax=datasource.far,
                                       invert=True)
    depth_exp_viz = colorize_depth(depth_exp[..., 0])
    depth_med_viz = colorize_depth(depth_med[..., 0])

    if save_dir:
        # A '/' in item_id would point into a subdirectory that mkdir below
        # does not create, so flatten it for file names only.
        file_id = item_id.replace("/", "_")
        save_dir.mkdir(parents=True, exist_ok=True)
        image_utils.save_image(save_dir / f"rgb_{file_id}.png",
                               image_utils.image_to_uint8(rgb))
        image_utils.save_image(
            save_dir / f"depth_expected_viz_{file_id}.png",
            image_utils.image_to_uint8(depth_exp_viz),
        )
        # Previously this wrote depth_med here, so the "expected" depth file
        # duplicated the median depth.
        image_utils.save_depth(save_dir / f"depth_expected_{file_id}.png",
                               depth_exp[..., 0])
        image_utils.save_image(
            save_dir / f"depth_median_viz_{file_id}.png",
            image_utils.image_to_uint8(depth_med_viz),
        )
        image_utils.save_depth(save_dir / f"depth_median_{file_id}.png",
                               depth_med[..., 0])

    summary_writer.image(f"rgb/{tag}/{item_id}", rgb, step)
    summary_writer.image(f"depth-expected/{tag}/{item_id}", depth_exp_viz, step)
    summary_writer.image(f"depth-median/{tag}/{item_id}", depth_med_viz, step)
    summary_writer.image(f"acc/{tag}/{item_id}", acc, step)

    if "rgb" in batch:
        # Compare colour channels only, matching the training loss, which uses
        # batch["rgb"][..., :3]; an extra (e.g. alpha) channel would otherwise
        # make the subtraction fail.
        rgb_target = batch["rgb"][..., :3]
        mse = ((rgb - rgb_target)**2).mean()
        psnr = utils.compute_psnr(mse)
        ssim = compute_multiscale_ssim(rgb_target, rgb)
        out["mse"] = mse
        out["psnr"] = psnr
        out["ssim"] = ssim
        logging.info("\tMetrics: mse=%.04f, psnr=%.02f, ssim=%.02f",
                     mse, psnr, ssim)

        rgb_abs_error = viz.colorize(abs(rgb_target - rgb).sum(axis=-1),
                                     cmin=0, cmax=1)
        rgb_sq_error = viz.colorize(((rgb_target - rgb)**2).sum(axis=-1),
                                    cmin=0, cmax=1)
        summary_writer.image(f"rgb-target/{tag}/{item_id}", rgb_target, step)
        summary_writer.image(f"rgb-abs-error/{tag}/{item_id}", rgb_abs_error,
                             step)
        summary_writer.image(f"rgb-sq-error/{tag}/{item_id}", rgb_sq_error,
                             step)

    if "depth" in batch:
        depth_target = batch["depth"]
        depth_target_viz = colorize_depth(depth_target[..., 0])
        out["depth_abs"] = jnp.nanmean(jnp.abs(depth_target - depth_med))
        summary_writer.image(f"depth-target/{tag}/{item_id}", depth_target_viz,
                             step)
        depth_med_error = viz.colorize(
            abs(depth_target - depth_med).squeeze(axis=-1), cmin=0, cmax=1)
        summary_writer.image(f"depth-median-error/{tag}/{item_id}",
                             depth_med_error, step)
        depth_exp_error = viz.colorize(
            abs(depth_target - depth_exp).squeeze(axis=-1), cmin=0, cmax=1)
        summary_writer.image(f"depth-expected-error/{tag}/{item_id}",
                             depth_exp_error, step)
        rel_disp_pred = viz.colorize(1.0 / depth_exp[..., 0])
        summary_writer.image(f"relative-disparity/{tag}/{item_id}",
                             rel_disp_pred, step)

    return out