Example #1
def save_img(img, pth):
  """Save an image to disk.

  Args:
    img: jnp.ndarray, [height, width, channels]; img is clipped to [0, 1]
      before being saved to pth.
    pth: string, path to save the image to.
  """
  with file_utils.open_file(pth, "wb") as imgout:
    Image.fromarray(np.array(
        (np.clip(img, 0., 1.) * 255.).astype(jnp.uint8))).save(imgout, "PNG")
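A minimal usage sketch for save_img (not from the source; the array and path are illustrative, and np/PIL/file_utils are assumed to be imported by this module):

rendered = np.random.uniform(size=(64, 64, 3)).astype(np.float32)  # stand-in for a rendered image in [0, 1]
save_img(rendered, "/tmp/example_render.png")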
  def _load_images(self, imgdir):
    """Function to load images."""

    if not file_utils.file_exists(imgdir):
      raise ValueError("Image folder {} doesn't exist.".format(imgdir))
    imgfiles = [
        path.join(imgdir, f)
        for f in sorted(file_utils.listdir(imgdir))
        if f.endswith("JPG") or f.endswith("jpg") or f.endswith("png")
    ]
    images = []
    for imgfile in tqdm.tqdm(imgfiles):
      with file_utils.open_file(imgfile, "rb") as imgin:
        image = np.array(Image.open(imgin), dtype=np.uint8)
        images.append(image)
    images = np.stack(images, axis=-1)
    return images
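  # Note (added): images are stacked along the last axis, so the returned array
  # has shape [height, width, channels, num_images]; _load_renderings later
  # moves that axis to the front via np.moveaxis(images, -1, 0).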
  def _get_suffix(self, args):
    """Get the suffix for reshaped image directory."""

    imgdir_suffix = ""
    if args.dataset.factor > 0:
      imgdir_suffix = "_{}".format(int(args.dataset.factor))
      factor = args.dataset.factor
    elif args.dataset.image_height > 0:
      img_path = path.join(args.dataset.data_dir, "images")
      img0 = [
          path.join(args.dataset.data_dir, "images", f)
          for f in sorted(file_utils.listdir(img_path))
          if f.endswith("JPG") or f.endswith("jpg") or f.endswith("png")
      ][0]
      with file_utils.open_file(img0, "rb") as imgin:
        sh = np.array(Image.open(imgin), dtype=np.uint8).shape
      factor = sh[0] / float(args.dataset.image_height)
      width = int(sh[1] / factor)
      imgdir_suffix = "_{}x{}".format(width, args.dataset.image_height)
    else:
      factor = 1

    return imgdir_suffix, factor
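A worked example of the suffix logic above (hypothetical numbers, not from the source): with factor=4 the image directory becomes "images_4"; with image_height=480 and a 1080x1920 first image the computation is:

factor = 1080 / 480.0                  # 2.25
width = int(1920 / factor)             # 853
suffix = "_{}x{}".format(width, 480)   # "_853x480" -> directory "images_853x480"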
Example #4
def evaluate(config, workdir):
    """Evalution function."""

    # Hide the GPUs and TPUs from TF so it does not reserve memory on them for
    # LPIPS computation or dataset loading.
    tf.config.experimental.set_visible_devices([], "GPU")
    tf.config.experimental.set_visible_devices([], "TPU")

    rng = jax.random.PRNGKey(config.seed)

    #----------------------------------------------------------------------------
    # Build input pipeline.
    rng, data_rng = jax.random.split(rng)
    data_rng = jax.random.fold_in(data_rng, jax.process_index())
    config.dataset.data_dir = os.path.join(config.dataset.base_dir,
                                           config.dataset.scene)
    train_ds, test_ds = datasets.create_dataset(config)
    example_batch = train_ds.peek()

    rng, key = jax.random.split(rng)

    #----------------------------------------------------------------------------
    # Initialize model.
    learning_rate_fn = train_utils.create_learning_rate_fn(config)
    model, state = models.create_train_state(config,
                                             key,
                                             learning_rate_fn=learning_rate_fn,
                                             example_batch=example_batch)

    #----------------------------------------------------------------------------
    # Get the rendering function. Rendering is forced to be deterministic even
    # if training is randomized.
    render_pfn = render_utils.get_render_function(model,
                                                  config,
                                                  randomized=False)

    last_step = 0
    out_dir = os.path.join(
        workdir,
        "path_renders" if config.dataset.render_path else "test_preds")

    if not config.eval.eval_once:
        # Prepare Metric Writers
        summary_writer = tensorboard.SummaryWriter(
            os.path.join(workdir, "eval"))

    while True:
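        # Poll workdir for a newer checkpoint; if no new step has appeared yet,
        # skip this iteration and check again.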
        state = checkpoints.restore_checkpoint(workdir, state)
        step = int(state.step)
        if step <= last_step:
            continue

        if config.eval.save_output and (not file_utils.isdir(out_dir)):
            file_utils.makedirs(out_dir)

        psnr_values = []
        ssim_values = []

        if not config.eval.eval_once:
            showcase_index = np.random.randint(0, test_ds.size)

        for idx in range(test_ds.size):
            logging.info("Evaluating [%d / %d].", idx, test_ds.size)
            batch = next(test_ds)
            test_pixels = batch.target_view.rgb
            if test_pixels is not None:
                test_pixels = model_utils.uint2float(test_pixels)

            #-----------------------------------------------------------
            # Render Image
            variables = {"params": state.params}
            pred_color, pred_disp, pred_acc = render_utils.render_image(
                functools.partial(render_pfn, variables),
                batch,
                rng,
                render_utils.normalize_disp(config.dataset.name),
                chunk=config.eval.chunk)

            if jax.process_index() != 0:
                continue

            #-----------------------------------------------------------
            # Get showcase example for logging
            if not config.eval.eval_once and idx == showcase_index:
                showcase_color = pred_color
                showcase_disp = pred_disp
                showcase_acc = pred_acc
                if not config.dataset.render_path:
                    showcase_gt = test_pixels
            #-----------------------------------------------------------
            # If ground-truth pixels are available, evaluate
            if not config.dataset.render_path:
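                # compute_psnr takes the mean squared error; the usual
                # psnr = -10 * log10(mse) convention is assumed here.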
                psnr = model_utils.compute_psnr(
                    ((pred_color - test_pixels)**2).mean())
                ssim = skmetrics.structural_similarity(
                    pred_color.astype(np.float32),
                    test_pixels.astype(np.float32),
                    win_size=11,
                    multichannel=True,
                    gaussian_weights=True)
                logging.info(f"PSNR = {psnr:.4f}, SSIM = {ssim:.4f}")  # pylint: disable=logging-format-interpolation
                psnr_values.append(float(psnr))
                ssim_values.append(float(ssim))

            #-----------------------------------------------------------
            # Save generated image
            if config.eval.save_output:
                model_utils.save_img(
                    pred_color, os.path.join(out_dir,
                                             "{:03d}.png".format(idx)))
                if pred_disp is not None:
                    model_utils.save_img(
                        pred_disp[Ellipsis, 0],
                        os.path.join(out_dir, "disp_{:03d}.png".format(idx)))
        #-----------------------------------------------------------
        if (not config.eval.eval_once) and (jax.process_index() == 0):
            summary_writer.image("pred_color", showcase_color, step)
            if showcase_disp is not None:
                summary_writer.image("pred_disp", showcase_disp, step)
            if showcase_acc is not None:
                summary_writer.image("pred_acc", showcase_acc, step)
            if not config.dataset.render_path:
                summary_writer.scalar("eval_metric/psnr",
                                      np.mean(np.array(psnr_values)), step)
                summary_writer.scalar("eval_metric/ssim",
                                      np.mean(np.array(ssim_values)), step)
                summary_writer.image("target", showcase_gt, step)

        #-----------------------------------------------------------
        # Save the metric to file
        if config.eval.save_output and (not config.dataset.render_path) and (
                jax.process_index() == 0):
            with file_utils.open_file(
                    os.path.join(out_dir, f"psnrs_{step}.txt"), "w") as f:
                f.write(" ".join([str(v) for v in psnr_values]))
            with file_utils.open_file(
                    os.path.join(out_dir, f"ssims_{step}.txt"), "w") as f:
                f.write(" ".join([str(v) for v in ssim_values]))
            with file_utils.open_file(os.path.join(out_dir, "psnr.txt"),
                                      "w") as f:
                f.write("{}".format(np.mean(np.array(psnr_values))))
            with file_utils.open_file(os.path.join(out_dir, "ssim.txt"),
                                      "w") as f:
                f.write("{}".format(np.mean(np.array(ssim_values))))
        if config.eval.eval_once:
            break
        if int(step) >= config.train.max_steps:
            break
        last_step = step

    logging.info("Finishing evaluation at step %d", last_step)
  def _load_renderings(self, args):
    """Load images and camera information."""

    #-------------------------------------------
    # Load images.
    #-------------------------------------------
    imgdir_suffix, factor = self._get_suffix(args)
    imgdir = path.join(args.dataset.data_dir, "images" + imgdir_suffix)
    images = self._load_images(imgdir)

    #-------------------------------------------
    # Load poses and bds.
    #-------------------------------------------
    with file_utils.open_file(
        path.join(args.dataset.data_dir, "poses_bounds.npy"), "rb") as fp:
      poses_arr = np.load(fp)

    # Get the intrinsic matrix
    with file_utils.open_file(
        path.join(args.dataset.data_dir, "hwf_cxcy.npy"), "rb") as fp:
      self.intrinsic_arr = np.load(fp)

    # Update the intrinsic matrix to account for resizing.
    self.intrinsic_arr = self.intrinsic_arr * 1. / factor

    # poses_arr contains a 3x4 pose matrix and 2 depth bounds for each image.
    # The pose matrix holds [R | t] as its left 3x4 block.
    # poses_arr has shape (num_images, 14) = 3 * 4 + 2.
    poses = poses_arr[:, :-2].reshape([-1, 3, 4]).transpose([1, 2, 0])
    bds = poses_arr[:, -2:].transpose([1, 0])

    # Convert R matrix from the form [down right back] to [right up back]
    poses = np.concatenate(
        [poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1)

    # Move the last axis to the front so the first dimension indexes images.
    images = np.moveaxis(images, -1, 0)
    poses = np.moveaxis(poses, -1, 0).astype(np.float32)
    bds = np.moveaxis(bds, -1, 0).astype(np.float32)

    scale = 1. / (bds.min() * .75)
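    # The scale below maps the nearest depth bound to 1 / 0.75 ≈ 1.33, keeping
    # the scene within a bounded range for the near/far planes set further down.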

    poses[:, :3, 3] *= scale
    bds *= scale
    poses = self._recenter_poses(poses)

    # Get the min and max depth of the scene
    self.min_depth = bds.min()
    self.max_depth = bds.max()
    # Use this to set the near and far plane
    args.model.near = self.min_depth.item()
    args.model.far = self.max_depth.item()

    if self.split == "test":
      self._generate_spiral_poses(poses, bds)

    # Select the split.
    i_test = np.arange(images.shape[0])[::args.dataset.llffhold]
    i_train = np.array(
        [i for i in np.arange(int(images.shape[0])) if i not in i_test])
    if self.split == "train":
      indices = i_train
    else:
      indices = i_test
    images = images[indices]
    poses = poses[indices]

    self.images = images
    self.camtoworlds = poses[:, :3, :4]

    # intrinsic_arr holds [H, W, fx, fy, cx, cy].
    self.focal = self.intrinsic_arr[2][0]
    self.h, self.w = images.shape[1:3]
    self.resolution = self.h * self.w

    if args.dataset.render_path and self.split == "test":
      self.n_examples = self.render_poses.shape[0]
    else:
      self.n_examples = images.shape[0]

    _, _, fx, fy, cx, cy = self.intrinsic_arr[:, 0]
    self.intrinsic_matrix = np.array([[fx, 0, -cx, 0], [0, -fy, -cy, 0],
                                      [0, 0, 1, 0]]).astype(np.float32)
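A standalone sketch (hypothetical values) of how the hwf_cxcy.npy intrinsics are consumed above: a column of [H, W, fx, fy, cx, cy] is scaled by 1/factor after resizing, the focal length is read from index 2, and the 3x4 intrinsic matrix follows the sign convention used in this codebase:

import numpy as np

intrinsic_arr = np.array([[1080.], [1920.], [1000.], [1000.], [960.], [540.]])
factor = 2.0                                  # hypothetical downsampling factor
intrinsic_arr = intrinsic_arr / factor
_, _, fx, fy, cx, cy = intrinsic_arr[:, 0]
K = np.array([[fx, 0, -cx, 0],
              [0, -fy, -cy, 0],
              [0, 0, 1, 0]], dtype=np.float32)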
Example #6
def train_and_evaluate(config, workdir):
    """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest checkpoint.
  """
    if config.dataset.batch_size % jax.device_count() != 0:
        raise ValueError(
            "Batch size must be divisible by the number of devices.")

    tf.io.gfile.makedirs(workdir)
    # Deterministic training.
    rng = jax.random.PRNGKey(config.seed)
    # Shift the numpy random seed by process_index() to shuffle data loaded
    # by different hosts
    np.random.seed(20201473 + jax.process_index())

    #----------------------------------------------------------------------------
    # Build input pipeline.
    rng, data_rng = jax.random.split(rng)
    data_rng = jax.random.fold_in(data_rng, jax.process_index())
    config.dataset.data_dir = os.path.join(config.dataset.base_dir,
                                           config.dataset.scene)
    train_ds, eval_ds = datasets.create_dataset(config)
    example_batch = train_ds.peek()

    #----------------------------------------------------------------------------
    # Learning rate schedule.
    num_train_steps = config.train.max_steps
    if num_train_steps == -1:
        num_train_steps = train_ds.size()
    steps_per_epoch = num_train_steps // config.train.num_epochs
    logging.info("num_train_steps=%d, steps_per_epoch=%d", num_train_steps,
                 steps_per_epoch)

    learning_rate_fn = train_utils.create_learning_rate_fn(config)

    #----------------------------------------------------------------------------
    # Initialize model.
    rng, model_rng = jax.random.split(rng)
    model, state = models.create_train_state(
        config,
        model_rng,
        learning_rate_fn=learning_rate_fn,
        example_batch=example_batch,
    )

    #----------------------------------------------------------------------------
    # Set up checkpointing of the model and the input pipeline.
    state = checkpoints.restore_checkpoint(workdir, state)
    initial_step = int(state.step) + 1

    #----------------------------------------------------------------------------
    # Distribute training.
    state = flax_utils.replicate(state)
    p_train_step = jax.pmap(
        functools.partial(
            train_step,
            model=model,
            learning_rate_fn=learning_rate_fn,
            weight_decay=config.train.weight_decay,
            config=config,
        ),
        axis_name="batch",
    )

    # Get distributed rendering function
    render_pfn = render_utils.get_render_function(
        model=model,
        config=config,
        randomized=False,  # No randomization for evaluation.
    )

    #----------------------------------------------------------------------------
    # Prepare Metric Writers
    writer = metric_writers.create_default_writer(
        workdir, just_logging=jax.process_index() > 0)
    if initial_step == 1:
        writer.write_hparams(dict(config))

    logging.info("Starting training loop at step %d.", initial_step)
    hooks = []
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=num_train_steps, writer=writer)
    if jax.process_index() == 0:
        hooks += [
            report_progress,
        ]
    train_metrics = None

    # Prefetch up to 6 batches ahead onto the local devices.
    ptrain_ds = flax.jax_utils.prefetch_to_device(train_ds, 6)
    n_local_devices = jax.local_device_count()
    rng = rng + jax.process_index()  # Make random seed separate across hosts.
    keys = jax.random.split(rng, n_local_devices)  # For pmapping RNG keys.

    with metric_writers.ensure_flushes(writer):
        for step in range(initial_step, num_train_steps + 1):
            # `step` is a Python integer. `state.step` is JAX integer on the GPU/TPU
            # devices.
            is_last_step = step == num_train_steps
            with jax.profiler.StepTraceAnnotation("train", step_num=step):
                batch = next(ptrain_ds)
                state, metrics_update, keys = p_train_step(rng=keys,
                                                           state=state,
                                                           batch=batch)
                metric_update = flax_utils.unreplicate(metrics_update)
                train_metrics = (metric_update if train_metrics is None else
                                 train_metrics.merge(metric_update))
            # Quick indication that training is happening.
            logging.log_first_n(logging.INFO, "Finished training step %d.", 5,
                                step)
            for h in hooks:
                h(step)

            if step % config.train.log_loss_every_steps == 0 or is_last_step:
                writer.write_scalars(step, train_metrics.compute())
                train_metrics = None

            if step % config.train.render_every_steps == 0 or is_last_step:
                test_batch = next(eval_ds)
                test_pixels = model_utils.uint2float(
                    test_batch.target_view.rgb)  # extract for evaluation
                with report_progress.timed("eval"):
                    pred_color, pred_disp, pred_acc = eval_step(
                        state, keys[0], test_batch, render_pfn, config)
                #------------------------------------------------------------------
                # Log metrics and images for host 0
                #------------------------------------------------------------------
                if jax.process_index() == 0:
                    psnr = model_utils.compute_psnr(
                        ((pred_color - test_pixels)**2).mean())
                    ssim = skmetrics.structural_similarity(
                        pred_color.astype(np.float32),
                        test_pixels.astype(np.float32),
                        win_size=11,
                        multichannel=True,
                        gaussian_weights=True)
                    writer.write_scalars(
                        step, {
                            "train_eval/test_psnr": psnr,
                            "train_eval/test_ssim": ssim,
                        })
                    writer.write_images(
                        step, {
                            "test_pred_color": pred_color[None, :],
                            "test_target": test_pixels[None, :]
                        })
                    if pred_disp is not None:
                        writer.write_images(
                            step, {"test_pred_disp": pred_disp[None, :]})
                    if pred_acc is not None:
                        writer.write_images(
                            step, {"test_pred_acc": pred_acc[None, :]})
                #------------------------------------------------------------------

            if (jax.process_index()
                    == 0) and (step % config.train.checkpoint_every_steps == 0
                               or is_last_step):
                # Write final metrics to file
                with file_utils.open_file(
                        os.path.join(workdir, "train_logs.json"), "w") as f:
                    log_dict = metric_update.compute()
                    for k, v in log_dict.items():
                        log_dict[k] = v.item()
                    f.write(json.dumps(log_dict))
                with report_progress.timed("checkpoint"):
                    state_to_save = jax.device_get(
                        jax.tree_map(lambda x: x[0], state))
                    checkpoints.save_checkpoint(workdir,
                                                state_to_save,
                                                step,
                                                keep=100)

    logging.info("Finishing training at step %d", num_train_steps)