Example #1
0
    def _compute_loss_and_stats(params, model_out, use_elastic_loss=False):
        """Builds the scalar training loss together with summary statistics.

        The photometric MSE is always part of the loss. The elastic term is
        added when `use_elastic_loss` is set and the warp-regularization term
        when the enclosing-scope flag `use_warp_reg_loss` is set. Jacobian
        diagnostics (determinant, divergence, curl magnitude) are logged, but
        not added to the loss, whenever `model_out` has a 'warp_jacobian'.

        Args:
          params: not referenced in this body; kept for caller compatibility.
          model_out: mapping of model outputs. Reads 'rgb' always, plus
            'weights', 'warp_jacobian', 'points' and 'warped_points' depending
            on which terms are enabled.
          use_elastic_loss: whether to add the elastic regularizer.

        Returns:
          `(loss, stats)`, where `stats` maps summary names to scalars.
        """
        # Photometric term: only the first three target channels are compared.
        target_rgb = batch['rgb'][..., :3]
        rgb_loss = ((model_out['rgb'] - target_rgb)**2).mean()
        stats = {'loss/rgb': rgb_loss}
        total = rgb_loss

        if use_elastic_loss:
            elastic_fn = functools.partial(compute_elastic_loss,
                                           loss_type=elastic_loss_type)
            # Apply the per-Jacobian loss across the two leading axes.
            batched_elastic_fn = jax.jit(vmap(vmap(jax.jit(elastic_fn))))
            # stop_gradient: no gradient flows back through the weights.
            weights = lax.stop_gradient(model_out['weights'])
            jacobians = model_out['warp_jacobian']
            if elastic_reduce_method == 'median':
                # Keep only the Jacobian of the median point picked by
                # compute_depth_index; the three unit axes line up with
                # (sample, Jacobian row, Jacobian col).
                median_idx = model_utils.compute_depth_index(weights)
                jacobians = jnp.take_along_axis(
                    jacobians, median_idx[..., None, None, None], axis=-3)
            elastic_loss, elastic_residual = batched_elastic_fn(jacobians)
            if elastic_reduce_method == 'weight':
                # Scale each sample's loss by its (gradient-stopped) weight.
                elastic_loss = weights * elastic_loss
            elastic_loss = elastic_loss.sum(axis=-1).mean()
            stats['loss/elastic'] = elastic_loss
            stats['residual/elastic'] = jnp.mean(elastic_residual)
            total = total + scalar_params.elastic_loss_weight * elastic_loss

        if use_warp_reg_loss:
            weights = lax.stop_gradient(model_out['weights'])
            median_idx = model_utils.compute_depth_index(weights)
            # Squared displacement of every sample under the warp.
            sq_displacement = ((model_out['points'] -
                                model_out['warped_points'])**2).sum(axis=-1)
            warp_reg_residual = jnp.take_along_axis(
                sq_displacement, median_idx[..., None], axis=-1)
            warp_reg_loss = utils.general_loss_with_squared_residual(
                warp_reg_residual,
                alpha=scalar_params.warp_reg_loss_alpha,
                scale=scalar_params.warp_reg_loss_scale).mean()
            stats['loss/warp_reg'] = warp_reg_loss
            stats['residual/warp_reg'] = jnp.mean(jnp.sqrt(warp_reg_residual))
            total = total + scalar_params.warp_reg_loss_weight * warp_reg_loss

        if 'warp_jacobian' in model_out:
            # Diagnostics only; these never contribute to the loss.
            warp_jac = model_out['warp_jacobian']
            stats['metric/jacobian_det'] = jnp.mean(jnp.linalg.det(warp_jac))
            stats['metric/jacobian_div'] = jnp.mean(
                utils.jacobian_to_div(warp_jac))
            stats['metric/jacobian_curl'] = jnp.mean(
                jnp.linalg.norm(utils.jacobian_to_curl(warp_jac), axis=-1))

        stats['loss/total'] = total
        stats['metric/psnr'] = utils.compute_psnr(rgb_loss)
        return total, stats
Example #2
0
    def _compute_loss_and_stats(params, model_out, use_elastic_loss=False):
        """Builds the scalar training loss together with summary statistics.

        The photometric MSE is always part of the loss. The elastic term is
        added only when `use_elastic_loss` is set. Jacobian diagnostics
        (determinant, divergence, curl magnitude) are logged, but not added
        to the loss, whenever `model_out` has a "warp_jacobian".

        Args:
          params: not referenced in this body; kept for caller compatibility.
          model_out: mapping of model outputs. Reads "rgb" always, plus
            "weights" and "warp_jacobian" when the elastic term is enabled.
          use_elastic_loss: whether to add the elastic regularizer.

        Returns:
          `(loss, stats)`, where `stats` maps summary names to scalars.
        """
        # Photometric term: only the first three target channels are compared.
        target_rgb = batch["rgb"][..., :3]
        rgb_loss = ((model_out["rgb"] - target_rgb)**2).mean()
        stats = {"loss/rgb": rgb_loss}
        total = rgb_loss

        if use_elastic_loss:
            # Apply the per-Jacobian loss across the two leading axes.
            batched_elastic_fn = jax.jit(vmap(vmap(compute_elastic_loss)))
            # stop_gradient: no gradient flows back through the weights.
            weights = lax.stop_gradient(model_out["weights"])
            jacobians = model_out["warp_jacobian"]
            if elastic_reduce_method == "median":
                # Keep only the Jacobian of the median point picked by
                # compute_depth_index; the three unit axes line up with
                # (sample, Jacobian row, Jacobian col).
                median_idx = model_utils.compute_depth_index(weights)
                jacobians = jnp.take_along_axis(
                    jacobians, median_idx[..., None, None, None], axis=-3)
            elastic_loss, elastic_residual = batched_elastic_fn(jacobians)
            if elastic_reduce_method == "weight":
                # Scale each sample's loss by its (gradient-stopped) weight.
                elastic_loss = weights * elastic_loss
            elastic_loss = elastic_loss.sum(axis=-1).mean()
            stats["loss/elastic"] = elastic_loss
            stats["residual/elastic"] = jnp.mean(elastic_residual)
            total = total + scalar_params.elastic_loss_weight * elastic_loss

        if "warp_jacobian" in model_out:
            # Diagnostics only; these never contribute to the loss.
            warp_jac = model_out["warp_jacobian"]
            stats["metric/jacobian_det"] = jnp.mean(jnp.linalg.det(warp_jac))
            stats["metric/jacobian_div"] = jnp.mean(
                utils.jacobian_to_div(warp_jac))
            stats["metric/jacobian_curl"] = jnp.mean(
                jnp.linalg.norm(utils.jacobian_to_curl(warp_jac), axis=-1))

        stats["loss/total"] = total
        stats["metric/psnr"] = utils.compute_psnr(rgb_loss)
        return total, stats
Example #3
0
def process_batch(*,
                  batch: Dict[str, jnp.ndarray],
                  rng: types.PRNGKey,
                  state: model_utils.TrainState,
                  tag: str,
                  item_id: str,
                  step: int,
                  summary_writer: tensorboard.SummaryWriter,
                  render_fn: Any,
                  save_dir: Optional[gpath.GPath],
                  datasource: datasets.DataSource):
  """Process and plot a single batch.

  Renders `batch` with `render_fn` and writes RGB, depth, disparity and
  opacity visualisations to `summary_writer` (and to `save_dir` when given).
  When `batch` has ground truth, image metrics are computed against it.

  Rendering happens on every process. Only process 0 logs, saves, or computes
  metrics; every other process returns an empty dict.

  Args:
    batch: input batch; optional 'rgb' / 'depth' entries are used as targets.
    rng: random key forwarded to `render_fn`.
    state: model state forwarded to `render_fn`.
    tag: summary tag prefix.
    item_id: item identifier; '/' is replaced by '_' before use.
    step: summary step.
    summary_writer: destination for image summaries.
    render_fn: callable `(state, batch, rng=...)` returning a dict with keys
      'rgb', 'acc', 'depth' and 'med_depth'.
    save_dir: if truthy, directory where images and depth maps are written.
    datasource: provides `near` / `far` used to colorize depth.

  Returns:
    On process 0, a dict with 'mse', 'psnr' and 'ssim' when `batch` has
    'rgb', and 'depth_abs' when it has 'depth'. Otherwise an empty dict.
  """
  item_id = item_id.replace('/', '_')
  render = render_fn(state, batch, rng=rng)
  out = {}
  if jax.process_index() != 0:
    return out

  rgb = render['rgb']
  acc = render['acc']
  depth_exp = render['depth']
  depth_med = render['med_depth']
  colorize_depth = functools.partial(viz.colorize,
                                     cmin=datasource.near,
                                     cmax=datasource.far,
                                     invert=True)

  depth_exp_viz = colorize_depth(depth_exp)
  depth_med_viz = colorize_depth(depth_med)
  disp_exp_viz = viz.colorize(1.0 / depth_exp)
  disp_med_viz = viz.colorize(1.0 / depth_med)
  acc_viz = viz.colorize(acc, cmin=0.0, cmax=1.0)
  if save_dir:
    save_dir.mkdir(parents=True, exist_ok=True)
    image_utils.save_image(save_dir / f'rgb_{item_id}.png',
                           image_utils.image_to_uint8(rgb))
    image_utils.save_image(save_dir / f'depth_expected_viz_{item_id}.png',
                           image_utils.image_to_uint8(depth_exp_viz))
    image_utils.save_depth(save_dir / f'depth_expected_{item_id}.png',
                           depth_exp)
    image_utils.save_image(save_dir / f'depth_median_viz_{item_id}.png',
                           image_utils.image_to_uint8(depth_med_viz))
    image_utils.save_depth(save_dir / f'depth_median_{item_id}.png',
                           depth_med)

  summary_writer.image(f'rgb/{tag}/{item_id}', rgb, step)
  summary_writer.image(f'depth-expected/{tag}/{item_id}', depth_exp_viz, step)
  summary_writer.image(f'depth-median/{tag}/{item_id}', depth_med_viz, step)
  summary_writer.image(f'disparity-expected/{tag}/{item_id}', disp_exp_viz,
                       step)
  summary_writer.image(f'disparity-median/{tag}/{item_id}', disp_med_viz, step)
  summary_writer.image(f'acc/{tag}/{item_id}', acc_viz, step)

  if 'rgb' in batch:
    rgb_target = batch['rgb']
    mse = ((rgb - rgb_target)**2).mean()
    psnr = utils.compute_psnr(mse)
    ssim = compute_multiscale_ssim(rgb_target, rgb)
    out['mse'] = mse
    out['psnr'] = psnr
    out['ssim'] = ssim
    logging.info('\tMetrics: mse=%.04f, psnr=%.02f, ssim=%.02f',
                 mse, psnr, ssim)

    rgb_abs_error = viz.colorize(
        abs(rgb_target - rgb).sum(axis=-1), cmin=0, cmax=1)
    rgb_sq_error = viz.colorize(
        ((rgb_target - rgb)**2).sum(axis=-1), cmin=0, cmax=1)
    summary_writer.image(f'rgb-target/{tag}/{item_id}', rgb_target, step)
    summary_writer.image(f'rgb-abs-error/{tag}/{item_id}', rgb_abs_error, step)
    summary_writer.image(f'rgb-sq-error/{tag}/{item_id}', rgb_sq_error, step)

  if 'depth' in batch:
    depth_target = batch['depth']
    # The target depth carries a trailing singleton axis (it is indexed with
    # `[..., 0]`), but the rendered depths are colorized as-is above, which
    # implies they lack that axis. Subtracting arrays of different rank would
    # broadcast instead of comparing pixel-wise, so put the predictions on the
    # target's grid first. This is a no-op when the shapes already agree.
    # NOTE(review): assumes one predicted depth per target pixel.
    depth_target_2d = depth_target[..., 0]
    depth_med_2d = jnp.reshape(depth_med, depth_target_2d.shape)
    depth_exp_2d = jnp.reshape(depth_exp, depth_target_2d.shape)
    depth_target_viz = colorize_depth(depth_target_2d)
    out['depth_abs'] = jnp.nanmean(jnp.abs(depth_target_2d - depth_med_2d))
    summary_writer.image(
        f'depth-target/{tag}/{item_id}', depth_target_viz, step)
    depth_med_error = viz.colorize(
        abs(depth_target_2d - depth_med_2d), cmin=0, cmax=1)
    summary_writer.image(
        f'depth-median-error/{tag}/{item_id}', depth_med_error, step)
    depth_exp_error = viz.colorize(
        abs(depth_target_2d - depth_exp_2d), cmin=0, cmax=1)
    summary_writer.image(
        f'depth-expected-error/{tag}/{item_id}', depth_exp_error, step)

  return out
Example #4
0
def process_batch(
    *,
    batch: Dict[str, jnp.ndarray],
    rng: types.PRNGKey,
    state: model_utils.TrainState,
    tag: str,
    item_id: str,
    step: int,
    summary_writer: tensorboard.SummaryWriter,
    render_fn: Any,
    save_dir: Optional[gpath.GPath],
    datasource: datasets.DataSource,
):
    """Process and plot a single batch.

    Renders `batch` with `render_fn` and writes RGB, depth and opacity images
    to `summary_writer` (and to `save_dir` when given). When `batch` has
    ground truth, image metrics are computed against it.

    Rendering happens on every process. Only process 0 logs, saves, or
    computes metrics; every other process returns an empty dict.

    Args:
        batch: input batch; optional "rgb" / "depth" entries are used as
            targets.
        rng: random key forwarded to `render_fn`.
        state: model state forwarded to `render_fn`.
        tag: summary tag prefix.
        item_id: item identifier used in summary tags and file names.
        step: summary step.
        summary_writer: destination for image summaries.
        render_fn: callable `(state, batch, rng=...)` returning the tuple
            `(rgb, depth_exp, depth_med, acc)`; the depths are indexed with
            `[..., 0]` here.
        save_dir: if truthy, directory where images and depth maps are
            written.
        datasource: provides `near` / `far` used to colorize depth.

    Returns:
        On process 0, a dict with "mse", "psnr" and "ssim" when `batch` has
        "rgb", and "depth_abs" when it has "depth". Otherwise an empty dict.
    """
    rgb, depth_exp, depth_med, acc = render_fn(state, batch, rng=rng)
    out = {}
    # `jax.host_id` is a deprecated alias of `jax.process_index`; use the
    # current name when available and fall back only on older JAX releases.
    process_index = getattr(jax, "process_index", None) or jax.host_id
    if process_index() != 0:
        return out

    colorize_depth = functools.partial(viz.colorize,
                                       cmin=datasource.near,
                                       cmax=datasource.far,
                                       invert=True)

    depth_exp_viz = colorize_depth(depth_exp[..., 0])
    depth_med_viz = colorize_depth(depth_med[..., 0])
    if save_dir:
        save_dir.mkdir(parents=True, exist_ok=True)
        image_utils.save_image(save_dir / f"rgb_{item_id}.png",
                               image_utils.image_to_uint8(rgb))
        image_utils.save_image(
            save_dir / f"depth_expected_viz_{item_id}.png",
            image_utils.image_to_uint8(depth_exp_viz),
        )
        # Previously this wrote the *median* depth into the "expected" file.
        image_utils.save_depth(save_dir / f"depth_expected_{item_id}.png",
                               depth_exp[..., 0])
        image_utils.save_image(
            save_dir / f"depth_median_viz_{item_id}.png",
            image_utils.image_to_uint8(depth_med_viz),
        )
        image_utils.save_depth(save_dir / f"depth_median_{item_id}.png",
                               depth_med[..., 0])

    summary_writer.image(f"rgb/{tag}/{item_id}", rgb, step)
    summary_writer.image(f"depth-expected/{tag}/{item_id}", depth_exp_viz,
                         step)
    summary_writer.image(f"depth-median/{tag}/{item_id}", depth_med_viz, step)
    summary_writer.image(f"acc/{tag}/{item_id}", acc, step)

    if "rgb" in batch:
        rgb_target = batch["rgb"]
        mse = ((rgb - rgb_target)**2).mean()
        psnr = utils.compute_psnr(mse)
        ssim = compute_multiscale_ssim(rgb_target, rgb)
        out["mse"] = mse
        out["psnr"] = psnr
        out["ssim"] = ssim
        logging.info("\tMetrics: mse=%.04f, psnr=%.02f, ssim=%.02f", mse, psnr,
                     ssim)

        rgb_abs_error = viz.colorize(abs(rgb_target - rgb).sum(axis=-1),
                                     cmin=0,
                                     cmax=1)
        rgb_sq_error = viz.colorize(((rgb_target - rgb)**2).sum(axis=-1),
                                    cmin=0,
                                    cmax=1)
        summary_writer.image(f"rgb-target/{tag}/{item_id}", rgb_target, step)
        summary_writer.image(f"rgb-abs-error/{tag}/{item_id}", rgb_abs_error,
                             step)
        summary_writer.image(f"rgb-sq-error/{tag}/{item_id}", rgb_sq_error,
                             step)

    if "depth" in batch:
        depth_target = batch["depth"]
        depth_target_viz = colorize_depth(depth_target[..., 0])
        out["depth_abs"] = jnp.nanmean(jnp.abs(depth_target - depth_med))
        summary_writer.image(f"depth-target/{tag}/{item_id}", depth_target_viz,
                             step)
        depth_med_error = viz.colorize(abs(depth_target -
                                           depth_med).squeeze(axis=-1),
                                       cmin=0,
                                       cmax=1)
        summary_writer.image(f"depth-median-error/{tag}/{item_id}",
                             depth_med_error, step)
        depth_exp_error = viz.colorize(abs(depth_target -
                                           depth_exp).squeeze(axis=-1),
                                       cmin=0,
                                       cmax=1)
        summary_writer.image(f"depth-expected-error/{tag}/{item_id}",
                             depth_exp_error, step)
        rel_disp_pred = viz.colorize(1.0 / depth_exp[..., 0])
        summary_writer.image(f"relative-disparity/{tag}/{item_id}",
                             rel_disp_pred, step)

    return out