Example #1
0
def main(unused_argv):
    """Render the "test" dataset for every new checkpoint found in train_dir.

    The latest checkpoint in FLAGS.train_dir is restored repeatedly. Whenever
    its step is newer than the last one processed, every item of the dataset
    is rendered. Depending on the flags, the predictions are scored against
    the ground-truth pixels (PSNR / SSIM), saved to disk and logged to
    TensorBoard. All hosts run the rendering, but only host 0 records results.

    The loop stops after one evaluation when FLAGS.eval_once is set, or once
    a checkpoint at or beyond FLAGS.max_steps has been evaluated. While no
    checkpoint newer than the last processed step exists, the function keeps
    waiting (sleeping between polls).

    Args:
      unused_argv: Ignored.

    Raises:
      ValueError: If FLAGS.train_dir or FLAGS.data_dir is not set.
    """
    import time  # Local import; only used to pace checkpoint polling.

    # Seconds to wait before looking for a newer checkpoint again.
    checkpoint_poll_seconds = 10

    # Hide the GPUs and TPUs from TF so it does not reserve memory on them for
    # LPIPS computation or dataset loading.
    tf.config.experimental.set_visible_devices([], "GPU")
    tf.config.experimental.set_visible_devices([], "TPU")

    rng = random.PRNGKey(20200823)

    if FLAGS.config is not None:
        utils.update_flags(FLAGS)
    if FLAGS.train_dir is None:
        raise ValueError("train_dir must be set. None set now.")
    if FLAGS.data_dir is None:
        raise ValueError("data_dir must be set. None set now.")

    dataset = datasets.get_dataset("test", FLAGS)
    rng, key = random.split(rng)
    model, init_variables = models.get_model(key, dataset.peek(), FLAGS)
    optimizer = flax.optim.Adam(FLAGS.lr_init).create(init_variables)
    state = utils.TrainState(optimizer=optimizer)
    del optimizer, init_variables

    # Rendering is forced to be deterministic even if training was randomized, as
    # this eliminates "speckle" artifacts.
    def render_fn(variables, key_0, key_1, rays):
        return jax.lax.all_gather(model.apply(variables, key_0, key_1, rays,
                                              False),
                                  axis_name="batch")

    # pmap over only the data input.
    render_pfn = jax.pmap(
        render_fn,
        in_axes=(None, None, None, 0),
        donate_argnums=3,
        axis_name="batch",
    )

    # Compiling to the CPU because it's faster and more accurate.
    ssim_fn = jax.jit(functools.partial(utils.compute_ssim, max_val=1.),
                      backend="cpu")

    last_step = 0
    out_dir = path.join(FLAGS.train_dir,
                        "path_renders" if FLAGS.render_path else "test_preds")
    if not FLAGS.eval_once:
        summary_writer = tensorboard.SummaryWriter(
            path.join(FLAGS.train_dir, "eval"))
    while True:
        state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
        step = int(state.optimizer.state.step)
        if step <= last_step:
            # No new checkpoint yet. Sleep before polling again instead of
            # re-reading the checkpoint directory in a tight loop.
            time.sleep(checkpoint_poll_seconds)
            continue
        if FLAGS.save_output and (not utils.isdir(out_dir)):
            utils.makedirs(out_dir)
        psnr_values = []
        ssim_values = []
        if not FLAGS.eval_once:
            # One randomly chosen image per evaluation is logged to
            # TensorBoard.
            showcase_index = np.random.randint(0, dataset.size)
        for idx in range(dataset.size):
            print(f"Evaluating {idx+1}/{dataset.size}")
            batch = next(dataset)
            (pred_color, pred_disp, pred_acc, pred_features,
             pred_specular) = utils.render_image(functools.partial(
                 render_pfn, state.optimizer.target),
                                                 batch["rays"],
                                                 rng,
                                                 FLAGS.dataset == "llff",
                                                 chunk=FLAGS.chunk)
            if jax.host_id() != 0:  # Only record via host 0.
                continue

            if not FLAGS.eval_once and idx == showcase_index:
                showcase_color = pred_color
                showcase_disp = pred_disp
                showcase_acc = pred_acc
                showcase_features = pred_features
                showcase_specular = pred_specular
                if not FLAGS.render_path:
                    showcase_gt = batch["pixels"]

            # Metrics are only computed when not rendering a path
            # (batch["pixels"] is compared against the prediction).
            if not FLAGS.render_path:
                psnr = utils.compute_psnr(
                    ((pred_color - batch["pixels"])**2).mean())
                ssim = ssim_fn(pred_color, batch["pixels"])
                print(f"PSNR = {psnr:.4f}, SSIM = {ssim:.4f}")
                psnr_values.append(float(psnr))
                ssim_values.append(float(ssim))

            if FLAGS.save_output:
                utils.save_img(pred_color,
                               path.join(out_dir, f"{idx:03d}.png"))
                utils.save_img(pred_disp[Ellipsis, 0],
                               path.join(out_dir, f"disp_{idx:03d}.png"))

        if (not FLAGS.eval_once) and (jax.host_id() == 0):
            summary_writer.image("pred_color", showcase_color, step)
            summary_writer.image("pred_disp", showcase_disp, step)
            summary_writer.image("pred_acc", showcase_acc, step)
            summary_writer.image("pred_features", showcase_features, step)
            summary_writer.image("pred_specular", showcase_specular, step)
            if not FLAGS.render_path:
                summary_writer.scalar("psnr", np.mean(np.array(psnr_values)),
                                      step)
                summary_writer.scalar("ssim", np.mean(np.array(ssim_values)),
                                      step)
                summary_writer.image("target", showcase_gt, step)

        if FLAGS.save_output and (not FLAGS.render_path) and (jax.host_id()
                                                              == 0):
            # Per-image metrics for this step, followed by the running
            # "latest" averages (psnr.txt / ssim.txt are overwritten).
            with utils.open_file(path.join(out_dir, f"psnrs_{step}.txt"),
                                 "w") as f:
                f.write(" ".join([str(v) for v in psnr_values]))
            with utils.open_file(path.join(out_dir, f"ssims_{step}.txt"),
                                 "w") as f:
                f.write(" ".join([str(v) for v in ssim_values]))
            with utils.open_file(path.join(out_dir, "psnr.txt"), "w") as f:
                f.write(f"{np.mean(np.array(psnr_values))}")
            with utils.open_file(path.join(out_dir, "ssim.txt"), "w") as f:
                f.write(f"{np.mean(np.array(ssim_values))}")

        if FLAGS.eval_once:
            break
        if step >= FLAGS.max_steps:
            break
        last_step = step
Example #2
0
def main(unused_argv):
  """Train the model, logging summaries and saving checkpoints periodically.

  Training resumes from the latest checkpoint in FLAGS.train_dir (if any) and
  runs up to and including step FLAGS.max_steps. Host 0 writes TensorBoard
  summaries and checkpoints; every host takes part in the pmapped training
  and test-set rendering.

  Args:
    unused_argv: Ignored.

  Raises:
    ValueError: If FLAGS.batch_size is not divisible by the device count, or
      if FLAGS.train_dir or FLAGS.data_dir is not set.
  """
  rng = random.PRNGKey(20200823)
  # Shift the numpy random seed by host_id() to shuffle data loaded by different
  # hosts.
  np.random.seed(20201473 + jax.host_id())

  if FLAGS.config is not None:
    utils.update_flags(FLAGS)
  if FLAGS.batch_size % jax.device_count() != 0:
    raise ValueError("Batch size must be divisible by the number of devices.")
  if FLAGS.train_dir is None:
    raise ValueError("train_dir must be set. None set now.")
  if FLAGS.data_dir is None:
    raise ValueError("data_dir must be set. None set now.")
  dataset = datasets.get_dataset("train", FLAGS)
  test_dataset = datasets.get_dataset("test", FLAGS)

  rng, key = random.split(rng)
  model, variables = models.get_model(key, dataset.peek(), FLAGS)
  optimizer = flax.optim.Adam(FLAGS.lr_init).create(variables)
  state = utils.TrainState(optimizer=optimizer)
  del optimizer, variables

  learning_rate_fn = functools.partial(
      utils.learning_rate_decay,
      lr_init=FLAGS.lr_init,
      lr_final=FLAGS.lr_final,
      max_steps=FLAGS.max_steps,
      lr_delay_steps=FLAGS.lr_delay_steps,
      lr_delay_mult=FLAGS.lr_delay_mult)

  train_pstep = jax.pmap(
      functools.partial(train_step, model),
      axis_name="batch",
      in_axes=(0, 0, 0, None),
      donate_argnums=(2,))

  def render_fn(variables, key_0, key_1, rays):
    return jax.lax.all_gather(
        model.apply(variables, key_0, key_1, rays, FLAGS.randomized),
        axis_name="batch")

  render_pfn = jax.pmap(
      render_fn,
      in_axes=(None, None, None, 0),  # Only distribute the data input.
      donate_argnums=(3,),
      axis_name="batch",
  )

  # Compiling to the CPU because it's faster and more accurate.
  ssim_fn = jax.jit(
      functools.partial(utils.compute_ssim, max_val=1.), backend="cpu")

  if not utils.isdir(FLAGS.train_dir):
    utils.makedirs(FLAGS.train_dir)
  state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
  # Resume training at the step of the last checkpoint. The step is converted
  # to a plain Python int so it can be used directly as a range() bound.
  init_step = int(state.optimizer.state.step) + 1
  state = flax.jax_utils.replicate(state)

  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(FLAGS.train_dir)

  # Prefetch_buffer_size = 3 x batch_size
  pdataset = flax.jax_utils.prefetch_to_device(dataset, 3)
  n_local_devices = jax.local_device_count()
  rng = rng + jax.host_id()  # Make random seed separate across hosts.
  keys = random.split(rng, n_local_devices)  # For pmapping RNG keys.
  gc.disable()  # Disable automatic garbage collection for efficiency.
  stats_trace = []
  reset_timer = True
  for step, batch in zip(range(init_step, FLAGS.max_steps + 1), pdataset):
    if reset_timer:
      t_loop_start = time.time()
      reset_timer = False
    lr = learning_rate_fn(step)
    state, stats, keys = train_pstep(keys, state, batch, lr)
    if jax.host_id() == 0:
      stats_trace.append(stats)
    # Automatic GC is disabled above, so collect manually every gc_every steps.
    if step % FLAGS.gc_every == 0:
      gc.collect()

    # Log training summaries. This is put behind a host_id check because in
    # multi-host evaluation, all hosts need to run inference even though we
    # only use host 0 to record results.
    if jax.host_id() == 0:
      if step % FLAGS.print_every == 0:
        summary_writer.scalar("train_loss", stats.loss[0], step)
        summary_writer.scalar("train_psnr", stats.psnr[0], step)
        summary_writer.scalar("train_sparsity", stats.sparsity[0], step)
        summary_writer.scalar("train_loss_coarse", stats.loss_c[0], step)
        summary_writer.scalar("train_psnr_coarse", stats.psnr_c[0], step)
        summary_writer.scalar("train_sparsity_coarse", stats.sparsity_c[0],
                              step)

        summary_writer.scalar("weight_l2", stats.weight_l2[0], step)
        # Averages cover every step since the previous print.
        avg_loss = np.mean(np.concatenate([s.loss for s in stats_trace]))
        avg_psnr = np.mean(np.concatenate([s.psnr for s in stats_trace]))
        stats_trace = []
        summary_writer.scalar("train_avg_loss", avg_loss, step)
        summary_writer.scalar("train_avg_psnr", avg_psnr, step)
        summary_writer.scalar("learning_rate", lr, step)
        steps_per_sec = FLAGS.print_every / (time.time() - t_loop_start)
        reset_timer = True
        rays_per_sec = FLAGS.batch_size * steps_per_sec
        summary_writer.scalar("train_steps_per_sec", steps_per_sec, step)
        summary_writer.scalar("train_rays_per_sec", rays_per_sec, step)
        # Field width used to right-align the step counter in the log line.
        precision = int(np.ceil(np.log10(FLAGS.max_steps))) + 1
        print(("{:" + "{:d}".format(precision) + "d}").format(step) +
              f"/{FLAGS.max_steps:d}: " + f"i_loss={stats.loss[0]:0.4f}, " +
              f"avg_loss={avg_loss:0.4f}, " +
              f"weight_l2={stats.weight_l2[0]:0.2e}, " + f"lr={lr:0.2e}, " +
              f"{rays_per_sec:0.0f} rays/sec")
      if step % FLAGS.save_every == 0:
        # Take the first replica of the replicated state for saving.
        state_to_save = jax.device_get(jax.tree_map(lambda x: x[0], state))
        checkpoints.save_checkpoint(
            FLAGS.train_dir, state_to_save, int(step), keep=100)

    # Test-set evaluation.
    if FLAGS.render_every > 0 and step % FLAGS.render_every == 0:
      # We reuse the same random number generator from the optimization step
      # here on purpose so that the visualization matches what happened in
      # training.
      t_eval_start = time.time()
      eval_variables = jax.device_get(jax.tree_map(lambda x: x[0],
                                                   state)).optimizer.target
      test_case = next(test_dataset)
      (pred_color, pred_disp, pred_acc, pred_features,
       pred_specular) = utils.render_image(
           functools.partial(render_pfn, eval_variables),
           test_case["rays"],
           keys[0],
           FLAGS.dataset == "llff",
           chunk=FLAGS.chunk)

      # Log eval summaries on host 0.
      if jax.host_id() == 0:
        psnr = utils.compute_psnr(
            ((pred_color - test_case["pixels"])**2).mean())
        ssim = ssim_fn(pred_color, test_case["pixels"])
        eval_time = time.time() - t_eval_start
        num_rays = jnp.prod(jnp.array(test_case["rays"].directions.shape[:-1]))
        rays_per_sec = num_rays / eval_time
        summary_writer.scalar("test_rays_per_sec", rays_per_sec, step)
        print(f"Eval {step}: {eval_time:0.3f}s., {rays_per_sec:0.0f} rays/sec")
        summary_writer.scalar("test_psnr", psnr, step)
        summary_writer.scalar("test_ssim", ssim, step)
        summary_writer.image("test_pred_color", pred_color, step)
        summary_writer.image("test_pred_disp", pred_disp, step)
        summary_writer.image("test_pred_acc", pred_acc, step)
        summary_writer.image("test_pred_features", pred_features, step)
        summary_writer.image("test_pred_specular", pred_specular, step)
        summary_writer.image("test_target", test_case["pixels"], step)

  # Save a final checkpoint when max_steps did not coincide with a regular
  # save. As with the in-loop saves, only host 0 writes, so hosts do not write
  # the same checkpoint file concurrently.
  if jax.host_id() == 0 and FLAGS.max_steps % FLAGS.save_every != 0:
    state = jax.device_get(jax.tree_map(lambda x: x[0], state))
    checkpoints.save_checkpoint(
        FLAGS.train_dir, state, int(FLAGS.max_steps), keep=100)
Example #3
0
def export_snerg_scene(output_directory, atlas, atlas_block_indices,
                       viewdir_mlp_params, render_params, atlas_params,
                       scene_params, input_height, input_width, input_focal):
    """Writes a SNeRG scene as a set of PNG images plus one JSON file.

    All files are first written to output_directory/temp, which is then
    renamed to output_directory/png. Existing contents of either directory
    are removed first.

    Args:
      output_directory: The root directory where the scene gets written.
      atlas: The SNeRG scene packed as a texture atlas in a [S, S, N, C] numpy
        array, where the channels C contain both RGB and features.
      atlas_block_indices: The indirection grid of the SNeRG scene, a numpy
        int32 array of size (bW, bH, bD, 3).
      viewdir_mlp_params: A dict containing the MLP parameters for the
        per-sample view-dependence MLP.
      render_params: A dict with parameters for high-res rendering.
      atlas_params: A dict with params for building the 3D texture atlas.
      scene_params: A dict for scene specific params (bbox, rotation,
        resolution).
      input_height: Height (pixels) of the NDC camera (i.e. the training
        cameras).
      input_width: Width (pixels) of the NDC camera (i.e. the training
        cameras).
      input_focal: Focal length (pixels) of the NDC camera (i.e. the training
        cameras).
    """
    # Each output image stacks four consecutive atlas planes along axis 0.
    rgbs = []
    alphas = []
    for first_plane in range(0, atlas.shape[2], 4):
        planes = [atlas[:, :, first_plane + offset, :] for offset in range(4)]
        rgbs.append(
            np.concatenate(
                [p[Ellipsis, 0:3].transpose([1, 0, 2]) for p in planes],
                axis=0))
        alphas.append(
            np.concatenate([
                p[Ellipsis, scene_params['_channels']].transpose([1, 0])
                for p in planes
            ],
                           axis=0))

    # Flatten the indirection grid into a single 3-channel 8-bit image.
    transposed_indices = np.transpose(atlas_block_indices, [2, 1, 0, 3])
    atlas_index_image = transposed_indices.reshape(
        (-1, atlas_block_indices.shape[0], 3)).astype(np.uint8)

    # Scene parameters destined for the JSON file. The insertion order below
    # determines the key order of the serialized output.
    export_scene_params = {
        'voxel_size': float(render_params['_voxel_size']),
        'block_size': atlas_params['_data_block_size'],
        'grid_width': int(render_params['_grid_size'][0]),
        'grid_height': int(render_params['_grid_size'][1]),
        'grid_depth': int(render_params['_grid_size'][2]),
        'atlas_width': atlas.shape[0],
        'atlas_height': atlas.shape[1],
        'atlas_depth': atlas.shape[2],
        'num_slices': len(rgbs),
        'min_x': float(scene_params['min_xyz'][0]),
        'min_y': float(scene_params['min_xyz'][1]),
        'min_z': float(scene_params['min_xyz'][2]),
        'atlas_blocks_x': int(atlas.shape[0] /
                              atlas_params['atlas_block_size']),
        'atlas_blocks_y': int(atlas.shape[1] /
                              atlas_params['atlas_block_size']),
        'atlas_blocks_z': int(atlas.shape[2] /
                              atlas_params['atlas_block_size']),
        'input_height': float(input_height),
        'input_width': float(input_width),
        'input_focal': float(input_focal),
        'worldspace_T_opengl': scene_params['worldspace_T_opengl'].tolist(),
        'ndc': scene_params['ndc'],
    }

    # The network weights go into the same dictionary: exported layers 0, 1
    # and 2 are read from Dense_0, Dense_1 and Dense_3 respectively.
    exported_layers = ('Dense_0', 'Dense_1', 'Dense_3')
    for layer_number, layer_name in enumerate(exported_layers):
        export_scene_params[f'{layer_number}_weights'] = viewdir_mlp_params[
            'params'][layer_name]['kernel'].tolist()
    for layer_number, layer_name in enumerate(exported_layers):
        export_scene_params[f'{layer_number}_bias'] = viewdir_mlp_params[
            'params'][layer_name]['bias'].tolist()

    # Everything is staged in a temporary directory to avoid partial
    # overwrites of an existing export.
    output_tmp_directory = output_directory + '/temp'
    atlas_indices_path = f'{output_tmp_directory}/atlas_indices.png'
    is_first_host = jax.host_id() == 0

    if is_first_host:
        # Start from an empty temporary directory, then store the
        # indirection grid.
        if utils.isdir(output_tmp_directory):
            tf.io.gfile.rmtree(output_tmp_directory)
        utils.makedirs(output_tmp_directory)
        save_8bit_png((atlas_index_image, atlas_indices_path))

    # All JAX hosts must get here before continuing; output_tmp_directory has
    # to exist before the remaining images are written.
    synchronize_jax_hosts()

    uint_multiplier = 2.0**8 - 1.0

    def to_uint8(values):
        # Scale to [0, 255], round down and clamp before the 8-bit cast.
        scaled = np.floor(uint_multiplier * values)
        return np.minimum(uint_multiplier,
                          np.maximum(0.0, scaled)).astype(np.uint8)

    # RGB and alpha are stored together as RGBA images.
    output_images = []
    output_paths = []
    for slice_number, (rgb, alpha) in enumerate(zip(rgbs, alphas)):
        rgba = np.concatenate([rgb, np.expand_dims(alpha, -1)], axis=-1)
        output_images.append(to_uint8(rgba))
        output_paths.append(
            f'{output_tmp_directory}/rgba_{slice_number:03d}.png')

    # The feature channels go into a separate collection of images.
    for first_plane in range(0, atlas.shape[2], 4):
        feature_planes = [
            to_uint8(atlas[:, :, first_plane + offset, :][Ellipsis,
                                                          3:-1].transpose(
                                                              [1, 0, 2]))
            for offset in range(4)
        ]
        output_images.append(np.concatenate(feature_planes, axis=0))

    output_paths.extend(
        f'{output_tmp_directory}/feature_{slice_number:03d}.png'
        for slice_number in range(len(rgbs)))

    parallel_write_images(save_8bit_png, list(zip(output_images,
                                                  output_paths)))

    # Scene parameters and network weights are written as one JSON file.
    export_scene_params['format'] = 'png'
    scene_params_path = f'{output_tmp_directory}/scene_params.json'
    if is_first_host:
        with utils.open_file(scene_params_path, 'wb') as f:
            f.write(json.dumps(export_scene_params).encode('utf-8'))

    # Wait for every host again so that output_tmp_directory is not moved
    # before all files have been written.
    synchronize_jax_hosts()

    # Replace any previous export with the freshly written one.
    output_png_directory = output_directory + '/png'
    if is_first_host:
        if utils.isdir(output_png_directory):
            tf.io.gfile.rmtree(output_png_directory)
        tf.io.gfile.rename(output_tmp_directory, output_png_directory)