Beispiel #1
0
    def loss_fn(variables):
        """Compute the total training loss and its per-term statistics.

        `batch`, `model`, `key_0`, `key_1`, `FLAGS` and `utils` are free names
        resolved from the surrounding scope; they are not parameters.

        Args:
            variables: Pytree passed to `model.apply`; the weight penalty is
                also computed over its leaves.

        Returns:
            A `(total_loss, stats)` pair. `total_loss` is
            `loss + loss_c + FLAGS.weight_decay_mult * weight_l2`; `stats` is a
            `utils.Stats` carrying `loss`, `psnr`, `loss_c`, `psnr_c` and
            `weight_l2`.

        Raises:
            ValueError: If `model.apply` returns neither 1 nor 2 output sets.
        """
        rays = batch["rays"]
        ret = model.apply(variables, key_0, key_1, rays, FLAGS.randomized)
        if len(ret) not in (1, 2):
            # Adjacent string literals are joined verbatim, so the first piece
            # has to end with a space to keep "2 sets of output" readable.
            raise ValueError(
                "ret should contain either 1 set of output (coarse only), or 2 "
                "sets of output (coarse as ret[0] and fine as ret[1]).")

        # Only the first three channels of the target pixels are compared.
        target = batch["pixels"][Ellipsis, :3]

        # The main prediction is always at the end of the ret list.
        rgb, unused_disp, unused_acc = ret[-1]
        loss = ((rgb - target)**2).mean()
        psnr = utils.compute_psnr(loss)
        if len(ret) > 1:
            # If there are both coarse and fine predictions, we compute the loss
            # for the coarse prediction (ret[0]) as well.
            rgb_c, unused_disp_c, unused_acc_c = ret[0]
            loss_c = ((rgb_c - target)**2).mean()
            psnr_c = utils.compute_psnr(loss_c)
        else:
            # Single-output case: the coarse terms contribute nothing.
            loss_c = 0.
            psnr_c = 0.

        def tree_sum_fn(fn):
            """Sum `fn(leaf)` over every leaf of `variables`."""
            return jax.tree_util.tree_reduce(lambda x, y: x + fn(y),
                                             variables,
                                             initializer=0)

        # Sum of squared entries divided by the total number of entries, i.e.
        # the mean squared value over all leaves of `variables`.
        weight_l2 = (tree_sum_fn(lambda z: jnp.sum(z**2)) /
                     tree_sum_fn(lambda z: jnp.prod(jnp.array(z.shape))))

        stats = utils.Stats(loss=loss,
                            psnr=psnr,
                            loss_c=loss_c,
                            psnr_c=psnr_c,
                            weight_l2=weight_l2)
        return loss + loss_c + FLAGS.weight_decay_mult * weight_l2, stats
Beispiel #2
0
 def loss_fn(model):
     """Compute the photometric loss for one batch.

     `batch`, `state`, `key_0`, `key_1`, `nn` and `utils` are free names
     resolved from the surrounding scope; they are not parameters.

     Args:
         model: Callable invoked as
             `model(key_0, key_1, origins, directions, viewdirs)`.

     Returns:
         A `(total_loss, (new_model_state, stats))` pair. `total_loss` is the
         fine loss plus the coarse loss (0 when there is a single output set);
         `stats` is a list of `utils.Stats` with the main prediction first and,
         when present, the coarse prediction second.

     Raises:
         ValueError: If the model returns neither 1 nor 2 output sets.
     """
     with nn.stateful(state.model_state) as new_model_state:
         rays = batch["rays"]
         ret = model(key_0, key_1, rays.origins, rays.directions,
                     rays.viewdirs)
     if len(ret) not in (1, 2):
         # Adjacent string literals are joined verbatim, so the first piece
         # has to end with a space to keep "2 sets of output" readable.
         raise ValueError(
             "ret should contain either 1 set of output (coarse only), or 2 "
             "sets of output (coarse as ret[0] and fine as ret[1]).")

     # Only the first three channels of the target pixels are compared.
     target = batch["pixels"][Ellipsis, :3]

     # The main prediction is always at the end of the ret list.
     rgb, unused_disp, unused_acc = ret[-1]
     loss = ((rgb - target)**2).mean()
     psnr = utils.compute_psnr(loss)
     stats = [utils.Stats(loss=loss, psnr=psnr)]

     # Stays zero unless a coarse prediction is present.
     loss_c = 0.
     if len(ret) > 1:
         # If there are both coarse and fine predictions, we compute the loss
         # for the coarse prediction (ret[0]) as well.
         rgb_c, unused_disp_c, unused_acc_c = ret[0]
         loss_c = ((rgb_c - target)**2).mean()
         psnr_c = utils.compute_psnr(loss_c)
         stats.append(utils.Stats(loss=loss_c, psnr=psnr_c))
     return loss + loss_c, (new_model_state, stats)
Beispiel #3
0
def main(unused_argv):
    """Train the model, with periodic logging, test rendering and checkpointing.

    Configuration is read from the module-level `FLAGS`. State is restored from
    `FLAGS.train_dir` via `checkpoints.restore_checkpoint` before the loop, and
    the loop's step counter starts at `state.step + 1`.

    Args:
        unused_argv: Ignored.

    Raises:
        ValueError: If `FLAGS.batch_size` is not divisible by
            `jax.device_count()`, or if `FLAGS.train_dir` or `FLAGS.data_dir`
            is None.
    """
    rng = random.PRNGKey(20200823)
    # Shift the numpy random seed by host_id() to shuffle data loaded by different
    # hosts.
    np.random.seed(20201473 + jax.host_id())

    # Validate the flags before any dataset or model is built.
    if FLAGS.config is not None:
        utils.update_flags(FLAGS)
    if FLAGS.batch_size % jax.device_count() != 0:
        raise ValueError(
            "Batch size must be divisible by the number of devices.")
    if FLAGS.train_dir is None:
        raise ValueError("train_dir must be set. None set now.")
    if FLAGS.data_dir is None:
        raise ValueError("data_dir must be set. None set now.")
    dataset = datasets.get_dataset("train", FLAGS)
    test_dataset = datasets.get_dataset("test", FLAGS)
    test_render_fn = jax.pmap(
        # Note rng_keys are useless in eval mode since there's no randomness.
        # pylint: disable=g-long-lambda
        lambda key_0, key_1, model, rays: jax.lax.all_gather(
            model(key_0, key_1, *rays), axis_name="batch"),
        in_axes=(None, None, None, 0),  # Only distribute the data input.
        donate_argnums=3,
        axis_name="batch",
    )
    rng, key = random.split(rng)
    init_model, init_state = models.get_model(key, dataset.peek(), FLAGS)
    optimizer_def = optim.Adam(FLAGS.lr_init)
    optimizer = optimizer_def.create(init_model)
    state = model_utils.TrainState(step=0,
                                   optimizer=optimizer,
                                   model_state=init_state)
    if not utils.isdir(FLAGS.train_dir):
        utils.makedirs(FLAGS.train_dir)
    # The freshly built state is passed in as the restore target.
    state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
    # First step to run: one past the step stored in the (restored) state.
    offset = state.step + 1
    state = jax_utils.replicate(state)
    del init_model, init_state

    # Only host 0 creates a summary writer; every later use of
    # `summary_writer` is guarded by a host_id() == 0 check.
    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(FLAGS.train_dir)
    t_loop_start = time.time()
    learning_rate_fn = functools.partial(utils.learning_rate_decay,
                                         lr_init=FLAGS.lr_init,
                                         lr_final=FLAGS.lr_final,
                                         max_steps=FLAGS.max_steps,
                                         lr_delay_steps=FLAGS.lr_delay_steps,
                                         lr_delay_mult=FLAGS.lr_delay_mult)

    # in_axes=(0, 0, 0, None): keys, state and batch are mapped over their
    # leading axis; the learning rate is broadcast. The batch argument
    # (index 2) is donated.
    ptrain_step = jax.pmap(train_step,
                           axis_name="batch",
                           in_axes=(0, 0, 0, None),
                           donate_argnums=2)
    # Prefetch_buffer_size = 3 x batch_size
    pdataset = jax_utils.prefetch_to_device(dataset, 3)
    # NOTE(review): `n_local_deices` is a misspelling of "n_local_devices";
    # the name is left as-is here.
    n_local_deices = jax.local_device_count()
    rng = rng + jax.host_id()  # Make random seed separate across hosts.
    keys = random.split(rng, n_local_deices)  # For pmapping RNG keys.
    gc.disable()  # Disable automatic garbage collection for efficiency.
    # Stats accumulated since the last print; reset at every print step.
    stats_trace = []
    # zip() stops at whichever ends first: the step range or the dataset.
    for step, batch in zip(range(offset, FLAGS.max_steps + 1), pdataset):
        lr = learning_rate_fn(step)
        state, stats, keys = ptrain_step(keys, state, batch, lr)
        if jax.host_id() == 0:
            stats_trace.append(stats[0])
        # Automatic GC is disabled above, so collect manually on a schedule.
        if step % FLAGS.gc_every == 0:
            gc.collect()
        # --- Train logs start ---
        # Put the training time visualization before the host_id check as in
        # multi-host evaluation, all hosts need to run inference even though we
        # only use host 0 to record results.
        if FLAGS.render_every > 0 and step % FLAGS.render_every == 0:
            # We reuse the same random number generator from the optimization step
            # here on purpose so that the visualization matches what happened in
            # training.
            # Index 0 of every leaf of the replicated state is taken
            # (presumably one device's copy - confirm against replicate()).
            state_to_eval = jax.device_get(jax.tree_map(lambda x: x[0], state))
            test_case = next(test_dataset)
            pred_color, pred_disp, pred_acc = utils.render_image(
                state_to_eval,
                test_case["rays"],
                test_render_fn,
                keys[0],
                FLAGS.dataset == "llff",
                chunk=FLAGS.chunk)
            if jax.host_id() == 0:
                psnr = utils.compute_psnr(
                    ((pred_color - test_case["pixels"])**2).mean())
                summary_writer.scalar("test_psnr", psnr, step)
                summary_writer.image("test_pred_color", pred_color, step)
                summary_writer.image("test_pred_disp", pred_disp, step)
                summary_writer.image("test_pred_acc", pred_acc, step)
                summary_writer.image("test_target", test_case["pixels"], step)
        if jax.host_id() != 0:  # Only log via host 0.
            continue
        if step % FLAGS.print_every == 0:
            # Instantaneous values: element 0 of this step's stats.
            summary_writer.scalar("train_loss", stats[0].loss[0], step)
            summary_writer.scalar("train_psnr", stats[0].psnr[0], step)
            if len(stats) > 1:
                summary_writer.scalar("train_loss_coarse", stats[1].loss[0],
                                      step)
                summary_writer.scalar("train_psnr_coarse", stats[1].psnr[0],
                                      step)
            # Averages over everything accumulated since the last print.
            avg_loss = np.mean(np.concatenate([s.loss for s in stats_trace]))
            avg_psnr = np.mean(np.concatenate([s.psnr for s in stats_trace]))
            stats_trace = []
            summary_writer.scalar("train_avg_loss", avg_loss, step)
            summary_writer.scalar("train_avg_psnr", avg_psnr, step)
            summary_writer.scalar("learning_rate", lr, step)
            # Throughput assumes exactly print_every steps ran since the timer
            # was last reset.
            steps_per_sec = FLAGS.print_every / (time.time() - t_loop_start)
            t_loop_start = time.time()
            rays_per_sec = FLAGS.batch_size * steps_per_sec
            summary_writer.scalar("steps_per_sec", steps_per_sec, step)
            summary_writer.scalar("rays_per_sec", rays_per_sec, step)
            # Field width for the step counter, derived from max_steps.
            precision = int(np.ceil(np.log10(FLAGS.max_steps))) + 1
            print(("{:" + "{:d}".format(precision) + "d}").format(step) +
                  f"/{FLAGS.max_steps:d}: " +
                  f"i_loss={stats[0].loss[0]:0.5f}, " +
                  f"avg_loss={avg_loss:0.5f}, " + f"lr={lr:0.2e}, " +
                  f"{rays_per_sec:0.3f} rays/sec")
        if step % FLAGS.save_every == 0:
            state_to_save = jax.device_get(jax.tree_map(lambda x: x[0], state))
            checkpoints.save_checkpoint(FLAGS.train_dir,
                                        state_to_save,
                                        state_to_save.step,
                                        keep=100)
        # --- Train logs end ---

    # Final save when max_steps is not a multiple of save_every (in which case
    # the loop's periodic save did not fire on max_steps).
    # NOTE(review): unlike the in-loop save, this is not gated on
    # host_id() == 0 - confirm that is intended.
    if FLAGS.max_steps % FLAGS.save_every != 0:
        state = jax.device_get(jax.tree_map(lambda x: x[0], state))
        checkpoints.save_checkpoint(FLAGS.train_dir,
                                    state,
                                    int(state.step),
                                    keep=100)
Beispiel #4
0
def main(unused_argv):
    """Train the model, with periodic logging, checkpointing and test-set eval.

    Configuration is read from the module-level `FLAGS`. State is restored from
    `FLAGS.train_dir` via `checkpoints.restore_checkpoint` before the loop, and
    the loop's step counter starts at `state.optimizer.state.step + 1`.

    Args:
        unused_argv: Ignored.

    Raises:
        ValueError: If `FLAGS.batch_size` is not divisible by
            `jax.device_count()`, or if `FLAGS.train_dir` or `FLAGS.data_dir`
            is None.
    """
    rng = random.PRNGKey(20200823)
    # Shift the numpy random seed by host_id() to shuffle data loaded by different
    # hosts.
    np.random.seed(20201473 + jax.host_id())

    # Validate the flags before any dataset or model is built.
    if FLAGS.config is not None:
        utils.update_flags(FLAGS)
    if FLAGS.batch_size % jax.device_count() != 0:
        raise ValueError(
            "Batch size must be divisible by the number of devices.")
    if FLAGS.train_dir is None:
        raise ValueError("train_dir must be set. None set now.")
    if FLAGS.data_dir is None:
        raise ValueError("data_dir must be set. None set now.")
    dataset = datasets.get_dataset("train", FLAGS)
    test_dataset = datasets.get_dataset("test", FLAGS)

    rng, key = random.split(rng)
    model, variables = models.get_model(key, dataset.peek(), FLAGS)
    optimizer = flax.optim.Adam(FLAGS.lr_init).create(variables)
    state = utils.TrainState(optimizer=optimizer)
    # From here on the variables are reached through `state`.
    del optimizer, variables

    learning_rate_fn = functools.partial(utils.learning_rate_decay,
                                         lr_init=FLAGS.lr_init,
                                         lr_final=FLAGS.lr_final,
                                         max_steps=FLAGS.max_steps,
                                         lr_delay_steps=FLAGS.lr_delay_steps,
                                         lr_delay_mult=FLAGS.lr_delay_mult)

    # `model` is bound as the first argument of train_step. in_axes then
    # applies to (keys, state, batch, lr): the first three are mapped over
    # their leading axis, the learning rate is broadcast. The batch argument
    # (index 2) is donated.
    train_pstep = jax.pmap(functools.partial(train_step, model),
                           axis_name="batch",
                           in_axes=(0, 0, 0, None),
                           donate_argnums=(2, ))

    def render_fn(variables, key_0, key_1, rays):
        """Apply the model to `rays` and all-gather the result over "batch"."""
        return jax.lax.all_gather(model.apply(variables, key_0, key_1, rays,
                                              FLAGS.randomized),
                                  axis_name="batch")

    render_pfn = jax.pmap(
        render_fn,
        in_axes=(None, None, None, 0),  # Only distribute the data input.
        donate_argnums=(3, ),
        axis_name="batch",
    )

    # Compiling to the CPU because it's faster and more accurate.
    ssim_fn = jax.jit(functools.partial(utils.compute_ssim, max_val=1.),
                      backend="cpu")

    if not utils.isdir(FLAGS.train_dir):
        utils.makedirs(FLAGS.train_dir)
    # The freshly built state is passed in as the restore target.
    state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
    # Resume training at the step of the last checkpoint.
    init_step = state.optimizer.state.step + 1
    state = flax.jax_utils.replicate(state)

    # Only host 0 creates a summary writer; every later use of
    # `summary_writer` is guarded by a host_id() == 0 check.
    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(FLAGS.train_dir)

    # Prefetch_buffer_size = 3 x batch_size
    pdataset = flax.jax_utils.prefetch_to_device(dataset, 3)
    # NOTE(review): `n_local_deices` is a misspelling of "n_local_devices";
    # the name is left as-is here.
    n_local_deices = jax.local_device_count()
    rng = rng + jax.host_id()  # Make random seed separate across hosts.
    keys = random.split(rng, n_local_deices)  # For pmapping RNG keys.
    gc.disable()  # Disable automatic garbage collection for efficiency.
    # Stats accumulated since the last print; reset at every print step.
    stats_trace = []
    # The timer is (re)started at the top of the iteration that follows a
    # print, so the print's own logging work is not counted.
    reset_timer = True
    # zip() stops at whichever ends first: the step range or the dataset.
    for step, batch in zip(range(init_step, FLAGS.max_steps + 1), pdataset):
        if reset_timer:
            t_loop_start = time.time()
            reset_timer = False
        lr = learning_rate_fn(step)
        state, stats, keys = train_pstep(keys, state, batch, lr)
        if jax.host_id() == 0:
            stats_trace.append(stats)
        # Automatic GC is disabled above, so collect manually on a schedule.
        if step % FLAGS.gc_every == 0:
            gc.collect()

        # Log training summaries. This is put behind a host_id check because in
        # multi-host evaluation, all hosts need to run inference even though we
        # only use host 0 to record results.
        if jax.host_id() == 0:
            if step % FLAGS.print_every == 0:
                # Instantaneous values: element 0 of this step's stats.
                summary_writer.scalar("train_loss", stats.loss[0], step)
                summary_writer.scalar("train_psnr", stats.psnr[0], step)
                summary_writer.scalar("train_loss_coarse", stats.loss_c[0],
                                      step)
                summary_writer.scalar("train_psnr_coarse", stats.psnr_c[0],
                                      step)
                summary_writer.scalar("weight_l2", stats.weight_l2[0], step)
                # Averages over everything accumulated since the last print.
                avg_loss = np.mean(
                    np.concatenate([s.loss for s in stats_trace]))
                avg_psnr = np.mean(
                    np.concatenate([s.psnr for s in stats_trace]))
                stats_trace = []
                summary_writer.scalar("train_avg_loss", avg_loss, step)
                summary_writer.scalar("train_avg_psnr", avg_psnr, step)
                summary_writer.scalar("learning_rate", lr, step)
                # Throughput assumes exactly print_every steps ran since the
                # timer was last reset.
                steps_per_sec = FLAGS.print_every / (time.time() -
                                                     t_loop_start)
                reset_timer = True
                rays_per_sec = FLAGS.batch_size * steps_per_sec
                summary_writer.scalar("train_steps_per_sec", steps_per_sec,
                                      step)
                summary_writer.scalar("train_rays_per_sec", rays_per_sec, step)
                # Field width for the step counter, derived from max_steps.
                precision = int(np.ceil(np.log10(FLAGS.max_steps))) + 1
                print(("{:" + "{:d}".format(precision) + "d}").format(step) +
                      f"/{FLAGS.max_steps:d}: " +
                      f"i_loss={stats.loss[0]:0.4f}, " +
                      f"avg_loss={avg_loss:0.4f}, " +
                      f"weight_l2={stats.weight_l2[0]:0.2e}, " +
                      f"lr={lr:0.2e}, " + f"{rays_per_sec:0.0f} rays/sec")
            if step % FLAGS.save_every == 0:
                # Index 0 of every leaf of the replicated state is taken
                # (presumably one device's copy - confirm against replicate()).
                state_to_save = jax.device_get(
                    jax.tree_map(lambda x: x[0], state))
                checkpoints.save_checkpoint(FLAGS.train_dir,
                                            state_to_save,
                                            int(step),
                                            keep=100)

        # Test-set evaluation.
        if FLAGS.render_every > 0 and step % FLAGS.render_every == 0:
            # We reuse the same random number generator from the optimization step
            # here on purpose so that the visualization matches what happened in
            # training.
            t_eval_start = time.time()
            eval_variables = jax.device_get(jax.tree_map(
                lambda x: x[0], state)).optimizer.target
            test_case = next(test_dataset)
            pred_color, pred_disp, pred_acc = utils.render_image(
                functools.partial(render_pfn, eval_variables),
                test_case["rays"],
                keys[0],
                FLAGS.dataset == "llff",
                chunk=FLAGS.chunk)

            # Log eval summaries on host 0.
            if jax.host_id() == 0:
                psnr = utils.compute_psnr(
                    ((pred_color - test_case["pixels"])**2).mean())
                ssim = ssim_fn(pred_color, test_case["pixels"])
                eval_time = time.time() - t_eval_start
                # Ray count: product of every dimension of `directions` except
                # the last one.
                num_rays = jnp.prod(
                    jnp.array(test_case["rays"].directions.shape[:-1]))
                rays_per_sec = num_rays / eval_time
                summary_writer.scalar("test_rays_per_sec", rays_per_sec, step)
                print(
                    f"Eval {step}: {eval_time:0.3f}s., {rays_per_sec:0.0f} rays/sec"
                )
                summary_writer.scalar("test_psnr", psnr, step)
                summary_writer.scalar("test_ssim", ssim, step)
                summary_writer.image("test_pred_color", pred_color, step)
                summary_writer.image("test_pred_disp", pred_disp, step)
                summary_writer.image("test_pred_acc", pred_acc, step)
                summary_writer.image("test_target", test_case["pixels"], step)

    # Final save when max_steps is not a multiple of save_every (in which case
    # the loop's periodic save did not fire on max_steps).
    # NOTE(review): unlike the in-loop save, this is not gated on
    # host_id() == 0 - confirm that is intended.
    if FLAGS.max_steps % FLAGS.save_every != 0:
        state = jax.device_get(jax.tree_map(lambda x: x[0], state))
        checkpoints.save_checkpoint(FLAGS.train_dir,
                                    state,
                                    int(FLAGS.max_steps),
                                    keep=100)
Beispiel #5
0
def main(unused_argv):
  """Render and score the "test" dataset for checkpoints in FLAGS.train_dir.

  State is restored from `FLAGS.train_dir` in a loop. Whenever the restored
  step is greater than the last evaluated one, every item of the dataset is
  rendered; unless `FLAGS.render_path` is set, PSNR, SSIM and LPIPS are
  computed against `batch["pixels"]`. Results are written to TensorBoard
  (when not `FLAGS.eval_once`) and to files under the output directory (when
  `FLAGS.save_output`), on host 0 only.

  The loop ends after one evaluation if `FLAGS.eval_once` is set, or once the
  evaluated step reaches `FLAGS.max_steps`. While no newer checkpoint is
  available it sleeps between polls instead of spinning.

  Args:
    unused_argv: Ignored.

  Raises:
    ValueError: If `FLAGS.train_dir` or `FLAGS.data_dir` is None.
  """
  import time  # Local import: only needed for the checkpoint polling delay.

  # Seconds to wait before looking for a newer checkpoint again.
  poll_interval_secs = 5.0

  # Hide the GPUs and TPUs from TF so it does not reserve memory on them for
  # LPIPS computation or dataset loading.
  tf.config.experimental.set_visible_devices([], "GPU")
  tf.config.experimental.set_visible_devices([], "TPU")

  rng = random.PRNGKey(20200823)

  if FLAGS.config is not None:
    utils.update_flags(FLAGS)
  if FLAGS.train_dir is None:
    raise ValueError("train_dir must be set. None set now.")
  if FLAGS.data_dir is None:
    raise ValueError("data_dir must be set. None set now.")

  dataset = datasets.get_dataset("test", FLAGS)
  rng, key = random.split(rng)
  model, init_variables = models.get_model(key, dataset.peek(), FLAGS)
  # The optimizer/state built here only serves as the restore target below.
  optimizer = flax.optim.Adam(FLAGS.lr_init).create(init_variables)
  state = utils.TrainState(optimizer=optimizer)
  del optimizer, init_variables

  lpips_model = tf_hub.load(LPIPS_TFHUB_PATH)

  # Rendering is forced to be deterministic even if training was randomized, as
  # this eliminates "speckle" artifacts.
  def render_fn(variables, key_0, key_1, rays):
    """Apply the model to `rays` and all-gather the result over "batch"."""
    return jax.lax.all_gather(
        model.apply(variables, key_0, key_1, rays, False), axis_name="batch")

  # pmap over only the data input.
  render_pfn = jax.pmap(
      render_fn,
      in_axes=(None, None, None, 0),
      donate_argnums=3,
      axis_name="batch",
  )

  # Compiling to the CPU because it's faster and more accurate.
  ssim_fn = jax.jit(
      functools.partial(utils.compute_ssim, max_val=1.), backend="cpu")

  last_step = 0
  out_dir = path.join(FLAGS.train_dir,
                      "path_renders" if FLAGS.render_path else "test_preds")
  if not FLAGS.eval_once:
    summary_writer = tensorboard.SummaryWriter(
        path.join(FLAGS.train_dir, "eval"))
  while True:
    state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
    step = int(state.optimizer.state.step)
    if step <= last_step:
      # Nothing newer than what was already evaluated. Wait before polling
      # again so the loop does not hammer the checkpoint directory.
      time.sleep(poll_interval_secs)
      continue
    if FLAGS.save_output and (not utils.isdir(out_dir)):
      utils.makedirs(out_dir)
    psnr_values = []
    ssim_values = []
    lpips_values = []
    if not FLAGS.eval_once:
      # One randomly chosen item per evaluation is logged as an image.
      showcase_index = np.random.randint(0, dataset.size)
    for idx in range(dataset.size):
      print(f"Evaluating {idx+1}/{dataset.size}")
      batch = next(dataset)
      # Every host runs the render; only host 0 records results below.
      pred_color, pred_disp, pred_acc = utils.render_image(
          functools.partial(render_pfn, state.optimizer.target),
          batch["rays"],
          rng,
          FLAGS.dataset == "llff",
          chunk=FLAGS.chunk)
      if jax.host_id() != 0:  # Only record via host 0.
        continue
      if not FLAGS.eval_once and idx == showcase_index:
        showcase_color = pred_color
        showcase_disp = pred_disp
        showcase_acc = pred_acc
        if not FLAGS.render_path:
          showcase_gt = batch["pixels"]
      if not FLAGS.render_path:
        psnr = utils.compute_psnr(((pred_color - batch["pixels"])**2).mean())
        ssim = ssim_fn(pred_color, batch["pixels"])
        lpips = compute_lpips(pred_color, batch["pixels"], lpips_model)
        print(f"PSNR = {psnr:.4f}, SSIM = {ssim:.4f}")
        psnr_values.append(float(psnr))
        ssim_values.append(float(ssim))
        lpips_values.append(float(lpips))
      if FLAGS.save_output:
        utils.save_img(pred_color, path.join(out_dir, "{:03d}.png".format(idx)))
        utils.save_img(pred_disp[Ellipsis, 0],
                       path.join(out_dir, "disp_{:03d}.png".format(idx)))
    if (not FLAGS.eval_once) and (jax.host_id() == 0):
      summary_writer.image("pred_color", showcase_color, step)
      summary_writer.image("pred_disp", showcase_disp, step)
      summary_writer.image("pred_acc", showcase_acc, step)
      if not FLAGS.render_path:
        summary_writer.scalar("psnr", np.mean(np.array(psnr_values)), step)
        summary_writer.scalar("ssim", np.mean(np.array(ssim_values)), step)
        summary_writer.scalar("lpips", np.mean(np.array(lpips_values)), step)
        summary_writer.image("target", showcase_gt, step)
    if FLAGS.save_output and (not FLAGS.render_path) and (jax.host_id() == 0):
      # Per-image values go to step-tagged files; the means go to fixed file
      # names, which are rewritten on every evaluation.
      with utils.open_file(path.join(out_dir, f"psnrs_{step}.txt"), "w") as f:
        f.write(" ".join([str(v) for v in psnr_values]))
      with utils.open_file(path.join(out_dir, f"ssims_{step}.txt"), "w") as f:
        f.write(" ".join([str(v) for v in ssim_values]))
      with utils.open_file(path.join(out_dir, f"lpips_{step}.txt"), "w") as f:
        f.write(" ".join([str(v) for v in lpips_values]))
      with utils.open_file(path.join(out_dir, "psnr.txt"), "w") as f:
        f.write("{}".format(np.mean(np.array(psnr_values))))
      with utils.open_file(path.join(out_dir, "ssim.txt"), "w") as f:
        f.write("{}".format(np.mean(np.array(ssim_values))))
      with utils.open_file(path.join(out_dir, "lpips.txt"), "w") as f:
        f.write("{}".format(np.mean(np.array(lpips_values))))
    # `step` is already a Python int here, so it is compared directly.
    if FLAGS.eval_once or step >= FLAGS.max_steps:
      break
    last_step = step
Beispiel #6
0
def main(unused_argv):
    """Render and score the "test" dataset for checkpoints in FLAGS.train_dir.

    State is restored from `FLAGS.train_dir` in a loop. Whenever the restored
    step is greater than the last evaluated one, every item of the dataset is
    rendered; unless `FLAGS.render_path` is set, PSNR is computed against
    `batch["pixels"]`. Results are written to TensorBoard (when not
    `FLAGS.eval_once`) and to files under the output directory (when
    `FLAGS.save_output`), on host 0 only.

    The loop ends after one evaluation if `FLAGS.eval_once` is set, or once
    the evaluated step reaches `FLAGS.max_steps`. While no newer checkpoint is
    available it sleeps between polls instead of spinning.

    Args:
        unused_argv: Ignored.

    Raises:
        ValueError: If `FLAGS.train_dir` or `FLAGS.data_dir` is None.
    """
    import time  # Local import: only needed for the checkpoint polling delay.

    # Seconds to wait before looking for a newer checkpoint again.
    poll_interval_secs = 5.0

    rng = random.PRNGKey(20200823)

    if FLAGS.config is not None:
        utils.update_flags(FLAGS)
    if FLAGS.train_dir is None:
        raise ValueError("train_dir must be set. None set now.")
    if FLAGS.data_dir is None:
        raise ValueError("data_dir must be set. None set now.")

    dataset = datasets.get_dataset("test", FLAGS)
    rng, key = random.split(rng)
    model, init_variables = models.get_model(key, dataset.peek(), FLAGS)
    # The optimizer/state built here only serves as the restore target below.
    optimizer = flax.optim.Adam(FLAGS.lr_init).create(init_variables)
    state = utils.TrainState(optimizer=optimizer)
    del optimizer, init_variables

    # Rendering is forced to be deterministic even if training was randomized, as
    # this eliminates "speckle" artifacts.
    def render_fn(variables, key_0, key_1, rays):
        """Apply the model to the unpacked `rays`, all-gathered over "batch"."""
        return jax.lax.all_gather(model.apply(variables, key_0, key_1, *rays,
                                              False),
                                  axis_name="batch")

    # pmap over only the data input.
    render_pfn = jax.pmap(
        render_fn,
        in_axes=(None, None, None, 0),
        donate_argnums=3,
        axis_name="batch",
    )

    last_step = 0
    out_dir = path.join(FLAGS.train_dir,
                        "path_renders" if FLAGS.render_path else "test_preds")
    if not FLAGS.eval_once:
        summary_writer = tensorboard.SummaryWriter(
            path.join(FLAGS.train_dir, "eval"))
    while True:
        state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
        step = int(state.optimizer.state.step)
        if step <= last_step:
            # Nothing newer than what was already evaluated. Wait before
            # polling again so the loop does not hammer the checkpoint
            # directory.
            time.sleep(poll_interval_secs)
            continue
        if FLAGS.save_output and (not utils.isdir(out_dir)):
            utils.makedirs(out_dir)
        psnrs = []
        if not FLAGS.eval_once:
            # One randomly chosen item per evaluation is logged as an image.
            showcase_index = np.random.randint(0, dataset.size)
        for idx in range(dataset.size):
            print(f"Evaluating {idx+1}/{dataset.size}")
            batch = next(dataset)
            # Every host runs the render; only host 0 records results below.
            pred_color, pred_disp, pred_acc = utils.render_image(
                functools.partial(render_pfn, state.optimizer.target),
                batch["rays"],
                rng,
                FLAGS.dataset == "llff",
                chunk=FLAGS.chunk)
            if jax.host_id() != 0:  # Only record via host 0.
                continue
            if not FLAGS.eval_once and idx == showcase_index:
                showcase_color = pred_color
                showcase_disp = pred_disp
                showcase_acc = pred_acc
                if not FLAGS.render_path:
                    showcase_gt = batch["pixels"]
            if not FLAGS.render_path:
                psnr = utils.compute_psnr(
                    ((pred_color - batch["pixels"])**2).mean())
                print(f"  PSNR = {psnr:.4f}")
                psnrs.append(float(psnr))
            if FLAGS.save_output:
                utils.save_img(pred_color,
                               path.join(out_dir, "{:03d}.png".format(idx)))
                utils.save_img(
                    pred_disp[Ellipsis, 0],
                    path.join(out_dir, "disp_{:03d}.png".format(idx)))
        if (not FLAGS.eval_once) and (jax.host_id() == 0):
            summary_writer.image("pred_color", showcase_color, step)
            summary_writer.image("pred_disp", showcase_disp, step)
            summary_writer.image("pred_acc", showcase_acc, step)
            if not FLAGS.render_path:
                summary_writer.scalar("psnr", np.mean(np.array(psnrs)), step)
                summary_writer.image("target", showcase_gt, step)
        if FLAGS.save_output and (not FLAGS.render_path) and (jax.host_id()
                                                              == 0):
            # The mean goes to a fixed file name, rewritten on every
            # evaluation.
            with utils.open_file(path.join(out_dir, "psnr.txt"), "w") as pout:
                pout.write("{}".format(np.mean(np.array(psnrs))))
        # `step` is already a Python int here, so it is compared directly.
        if FLAGS.eval_once or step >= FLAGS.max_steps:
            break
        last_step = step