コード例 #1
0
 def _load_renderings(self, args):
     """Load images from disk."""
     if args.render_path:
         raise ValueError(
             "render_path cannot be used for the blender dataset.")
     with utils.open_file(
             path.join(args.data_dir,
                       "transforms_{}.json".format(self.split)), "r") as fp:
         meta = json.load(fp)
     images = []
     cams = []
     for i in range(len(meta["frames"])):
         frame = meta["frames"][i]
         fname = os.path.join(args.data_dir, frame["file_path"] + ".png")
         with utils.open_file(fname, "rb") as imgin:
             image = np.array(Image.open(imgin), dtype=np.float32) / 255.
             if args.factor == 2:
                 [halfres_h,
                  halfres_w] = [hw // 2 for hw in image.shape[:2]]
                 image = cv2.resize(image, (halfres_w, halfres_h),
                                    interpolation=cv2.INTER_AREA)
             elif args.factor > 0:
                 raise ValueError(
                     "Blender dataset only supports factor=0 or 2, {} "
                     "set.".format(args.factor))
         cams.append(frame["transform_matrix"])
         images.append(image)
     self.images = np.stack(images, axis=0)
     if args.white_bkgd:
         self.images = (
             self.images[Ellipsis, :3] * self.images[Ellipsis, -1:] +
             (1. - self.images[Ellipsis, -1:]))
     else:
         self.images = self.images[Ellipsis, :3]
     self.h, self.w = self.images.shape[1:3]
     self.resolution = self.h * self.w
     self.camtoworlds = np.stack(cams, axis=0)
     camera_angle_x = float(meta["camera_angle_x"])
     self.focal = .5 * self.w / np.tan(.5 * camera_angle_x)
     self.n_examples = self.images.shape[0]
コード例 #2
0
  def _load_renderings(self, args):
    """Load LLFF-style images and poses from disk and keep the current split.

    Images are read from ``<data_dir>/images`` (or ``images_<factor>`` when
    ``args.factor`` is positive) and poses/bounds from
    ``<data_dir>/poses_bounds.npy``. Results are stored on the instance as
    ``images``, ``camtoworlds``, ``focal``, ``h``, ``w``, ``resolution``,
    ``spherify`` and ``n_examples``.

    Args:
      args: Configuration object. This method reads ``factor``, ``data_dir``,
        ``spherify``, ``llffhold`` and ``render_path``.

    Raises:
      ValueError: If the image folder does not exist.
      RuntimeError: If the number of images differs from the number of poses.
    """
    # Load images.
    if args.factor > 0:
      imgdir_suffix = "_{}".format(args.factor)
      factor = args.factor
    else:
      imgdir_suffix = ""
      factor = 1
    imgdir = path.join(args.data_dir, "images" + imgdir_suffix)
    if not utils.file_exists(imgdir):
      raise ValueError("Image folder {} doesn't exist.".format(imgdir))
    # str.endswith accepts a tuple, so one call covers every extension.
    imgfiles = [
        path.join(imgdir, f)
        for f in sorted(utils.listdir(imgdir))
        if f.endswith(("JPG", "jpg", "png"))
    ]
    images = []
    for imgfile in imgfiles:
      with utils.open_file(imgfile, "rb") as imgin:
        images.append(np.array(Image.open(imgin), dtype=np.float32) / 255.)
    # Stack with the image index last; it is moved to axis 0 further down.
    images = np.stack(images, axis=-1)

    # Load poses and bds.
    with utils.open_file(path.join(args.data_dir, "poses_bounds.npy"),
                         "rb") as fp:
      poses_arr = np.load(fp)
    # All but the last two columns of each row form a 3x5 pose block; the last
    # two columns are the bounds. Both end up with the image index last.
    poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1, 2, 0])
    bds = poses_arr[:, -2:].transpose([1, 0])
    if poses.shape[-1] != images.shape[-1]:
      raise RuntimeError("Mismatch between imgs {} and poses {}".format(
          images.shape[-1], poses.shape[-1]))

    # Update poses according to downsampling: rows 0-1 of column 4 are
    # overwritten with the loaded image height and width, and row 2 of column 4
    # is divided by the downsampling factor.
    poses[:2, 4, :] = np.array(images.shape[:2]).reshape([2, 1])
    poses[2, 4, :] = poses[2, 4, :] * 1. / factor

    # Correct rotation matrix ordering and move variable dim to axis 0.
    poses = np.concatenate(
        [poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1)
    poses = np.moveaxis(poses, -1, 0).astype(np.float32)
    images = np.moveaxis(images, -1, 0)
    bds = np.moveaxis(bds, -1, 0).astype(np.float32)

    # Rescale according to a default bd factor.
    scale = 1. / (bds.min() * .75)
    poses[:, :3, 3] *= scale
    bds *= scale

    # Recenter poses.
    poses = self._recenter_poses(poses)

    # Generate a spiral/spherical ray path for rendering videos.
    self.spherify = bool(args.spherify)
    if args.spherify:
      poses = self._generate_spherical_poses(poses, bds)
    elif self.split == "test":
      self._generate_spiral_poses(poses, bds)

    # Select the split: every `llffhold`-th image is held out for testing and
    # the remaining ones are used for training.
    n_images = images.shape[0]
    i_test = np.arange(n_images)[::args.llffhold]
    # setdiff1d yields the sorted indices that are not held out and stays
    # integer-typed even when nothing remains, so it is always a valid index.
    i_train = np.setdiff1d(np.arange(n_images), i_test)
    indices = i_train if self.split == "train" else i_test
    images = images[indices]
    poses = poses[indices]

    self.images = images
    self.camtoworlds = poses[:, :3, :4]
    # Last row / last column of the first pose; presumably the value that was
    # divided by `factor` above -- TODO confirm _recenter_poses keeps the shape.
    self.focal = poses[0, -1, -1]
    self.h, self.w = images.shape[1:3]
    self.resolution = self.h * self.w
    if args.render_path:
      # NOTE(review): self.render_poses is not assigned in this method;
      # presumably one of the pose-generation helpers sets it -- confirm.
      self.n_examples = self.render_poses.shape[0]
    else:
      self.n_examples = images.shape[0]
コード例 #3
0
ファイル: eval.py プロジェクト: boxofpasta/jaxnerf
def main(unused_argv):
  """Evaluate checkpoints from FLAGS.train_dir on the "test" dataset split.

  Repeatedly restores a checkpoint from FLAGS.train_dir, renders every item of
  the test dataset and, unless FLAGS.render_path is set, computes PSNR, SSIM
  and LPIPS against the ground-truth pixels. Depending on the flags, rendered
  images and metric files are written to an output directory and summaries are
  written for TensorBoard. The loop ends after one evaluation when
  FLAGS.eval_once is set, or once the restored step reaches FLAGS.max_steps.

  Args:
    unused_argv: Ignored.

  Raises:
    ValueError: If FLAGS.train_dir or FLAGS.data_dir is None.
  """
  import time  # Local import: only needed for the checkpoint polling delay.

  # Hide the GPUs and TPUs from TF so it does not reserve memory on them for
  # LPIPS computation or dataset loading.
  tf.config.experimental.set_visible_devices([], "GPU")
  tf.config.experimental.set_visible_devices([], "TPU")

  rng = random.PRNGKey(20200823)

  if FLAGS.config is not None:
    utils.update_flags(FLAGS)
  if FLAGS.train_dir is None:
    raise ValueError("train_dir must be set. None set now.")
  if FLAGS.data_dir is None:
    raise ValueError("data_dir must be set. None set now.")

  dataset = datasets.get_dataset("test", FLAGS)
  rng, key = random.split(rng)
  model, init_variables = models.get_model(key, dataset.peek(), FLAGS)
  # The optimizer is only built so that the checkpoint can be restored into a
  # state object; no training step is taken in this function.
  optimizer = flax.optim.Adam(FLAGS.lr_init).create(init_variables)
  state = utils.TrainState(optimizer=optimizer)
  del optimizer, init_variables

  lpips_model = tf_hub.load(LPIPS_TFHUB_PATH)

  # Rendering is forced to be deterministic even if training was randomized, as
  # this eliminates "speckle" artifacts.
  def render_fn(variables, key_0, key_1, rays):
    return jax.lax.all_gather(
        model.apply(variables, key_0, key_1, rays, False), axis_name="batch")

  # pmap over only the data input.
  render_pfn = jax.pmap(
      render_fn,
      in_axes=(None, None, None, 0),
      donate_argnums=3,
      axis_name="batch",
  )

  # Compiling to the CPU because it's faster and more accurate.
  ssim_fn = jax.jit(
      functools.partial(utils.compute_ssim, max_val=1.), backend="cpu")

  # Seconds to wait before looking for a newer checkpoint again.
  poll_interval_secs = 10

  last_step = 0
  out_dir = path.join(FLAGS.train_dir,
                      "path_renders" if FLAGS.render_path else "test_preds")
  if not FLAGS.eval_once:
    summary_writer = tensorboard.SummaryWriter(
        path.join(FLAGS.train_dir, "eval"))
  while True:
    state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
    step = int(state.optimizer.state.step)
    if step <= last_step:
      # Nothing newer than the last evaluated step: wait instead of spinning
      # on restore_checkpoint at full speed.
      time.sleep(poll_interval_secs)
      continue
    if FLAGS.save_output and (not utils.isdir(out_dir)):
      utils.makedirs(out_dir)
    psnr_values = []
    ssim_values = []
    lpips_values = []
    if not FLAGS.eval_once:
      # One randomly chosen test item is shown in the summaries.
      showcase_index = np.random.randint(0, dataset.size)
    for idx in range(dataset.size):
      print(f"Evaluating {idx+1}/{dataset.size}")
      batch = next(dataset)
      pred_color, pred_disp, pred_acc = utils.render_image(
          functools.partial(render_pfn, state.optimizer.target),
          batch["rays"],
          rng,
          FLAGS.dataset == "llff",
          chunk=FLAGS.chunk)
      if jax.host_id() != 0:  # Only record via host 0.
        continue
      if not FLAGS.eval_once and idx == showcase_index:
        showcase_color = pred_color
        showcase_disp = pred_disp
        showcase_acc = pred_acc
        if not FLAGS.render_path:
          showcase_gt = batch["pixels"]
      if not FLAGS.render_path:
        psnr = utils.compute_psnr(((pred_color - batch["pixels"])**2).mean())
        ssim = ssim_fn(pred_color, batch["pixels"])
        lpips = compute_lpips(pred_color, batch["pixels"], lpips_model)
        print(f"PSNR = {psnr:.4f}, SSIM = {ssim:.4f}")
        psnr_values.append(float(psnr))
        ssim_values.append(float(ssim))
        lpips_values.append(float(lpips))
      if FLAGS.save_output:
        utils.save_img(pred_color, path.join(out_dir, "{:03d}.png".format(idx)))
        utils.save_img(pred_disp[Ellipsis, 0],
                       path.join(out_dir, "disp_{:03d}.png".format(idx)))
    if (not FLAGS.eval_once) and (jax.host_id() == 0):
      summary_writer.image("pred_color", showcase_color, step)
      summary_writer.image("pred_disp", showcase_disp, step)
      summary_writer.image("pred_acc", showcase_acc, step)
      if not FLAGS.render_path:
        summary_writer.scalar("psnr", np.mean(np.array(psnr_values)), step)
        summary_writer.scalar("ssim", np.mean(np.array(ssim_values)), step)
        summary_writer.scalar("lpips", np.mean(np.array(lpips_values)), step)
        summary_writer.image("target", showcase_gt, step)
    if FLAGS.save_output and (not FLAGS.render_path) and (jax.host_id() == 0):
      # (per-image file stem, mean file stem, values) for each metric.
      metric_files = (
          ("psnrs", "psnr", psnr_values),
          ("ssims", "ssim", ssim_values),
          ("lpips", "lpips", lpips_values),
      )
      for per_image_stem, mean_stem, values in metric_files:
        # Per-image values for this step, space separated.
        with utils.open_file(
            path.join(out_dir, f"{per_image_stem}_{step}.txt"), "w") as f:
          f.write(" ".join(str(v) for v in values))
        # Mean over the test set; overwritten on every evaluation.
        with utils.open_file(path.join(out_dir, f"{mean_stem}.txt"), "w") as f:
          f.write("{}".format(np.mean(np.array(values))))
    if FLAGS.eval_once or step >= FLAGS.max_steps:
      break
    last_step = step
コード例 #4
0
ファイル: eval.py プロジェクト: tomastokar/google-research
def main(unused_argv):
    """Evaluate checkpoints from FLAGS.train_dir on the "test" dataset split.

    Repeatedly restores a checkpoint from FLAGS.train_dir, renders every item
    of the test dataset and, unless FLAGS.render_path is set, computes the
    PSNR against the ground-truth pixels. Depending on the flags, rendered
    images and the mean PSNR are written to an output directory and summaries
    are written for TensorBoard. The loop ends after one evaluation when
    FLAGS.eval_once is set, or once the restored step reaches FLAGS.max_steps.

    Args:
        unused_argv: Ignored.

    Raises:
        ValueError: If FLAGS.train_dir or FLAGS.data_dir is None.
    """
    import time  # Local import: only needed for the checkpoint polling delay.

    rng = random.PRNGKey(20200823)

    if FLAGS.config is not None:
        utils.update_flags(FLAGS)
    if FLAGS.train_dir is None:
        raise ValueError("train_dir must be set. None set now.")
    if FLAGS.data_dir is None:
        raise ValueError("data_dir must be set. None set now.")

    dataset = datasets.get_dataset("test", FLAGS)
    rng, key = random.split(rng)
    model, init_variables = models.get_model(key, dataset.peek(), FLAGS)
    # The optimizer is only built so that the checkpoint can be restored into
    # a state object; no training step is taken in this function.
    optimizer = flax.optim.Adam(FLAGS.lr_init).create(init_variables)
    state = utils.TrainState(optimizer=optimizer)
    del optimizer, init_variables

    # Rendering is forced to be deterministic even if training was randomized, as
    # this eliminates "speckle" artifacts.
    def render_fn(variables, key_0, key_1, rays):
        return jax.lax.all_gather(model.apply(variables, key_0, key_1, *rays,
                                              False),
                                  axis_name="batch")

    # pmap over only the data input.
    render_pfn = jax.pmap(
        render_fn,
        in_axes=(None, None, None, 0),
        donate_argnums=3,
        axis_name="batch",
    )

    # Seconds to wait before looking for a newer checkpoint again.
    poll_interval_secs = 10

    last_step = 0
    out_dir = path.join(FLAGS.train_dir,
                        "path_renders" if FLAGS.render_path else "test_preds")
    if not FLAGS.eval_once:
        summary_writer = tensorboard.SummaryWriter(
            path.join(FLAGS.train_dir, "eval"))
    while True:
        state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
        step = int(state.optimizer.state.step)
        if step <= last_step:
            # Nothing newer than the last evaluated step: wait instead of
            # spinning on restore_checkpoint at full speed.
            time.sleep(poll_interval_secs)
            continue
        if FLAGS.save_output and (not utils.isdir(out_dir)):
            utils.makedirs(out_dir)
        psnrs = []
        if not FLAGS.eval_once:
            # One randomly chosen test item is shown in the summaries.
            showcase_index = np.random.randint(0, dataset.size)
        for idx in range(dataset.size):
            print(f"Evaluating {idx+1}/{dataset.size}")
            batch = next(dataset)
            pred_color, pred_disp, pred_acc = utils.render_image(
                functools.partial(render_pfn, state.optimizer.target),
                batch["rays"],
                rng,
                FLAGS.dataset == "llff",
                chunk=FLAGS.chunk)
            if jax.host_id() != 0:  # Only record via host 0.
                continue
            if not FLAGS.eval_once and idx == showcase_index:
                showcase_color = pred_color
                showcase_disp = pred_disp
                showcase_acc = pred_acc
                if not FLAGS.render_path:
                    showcase_gt = batch["pixels"]
            if not FLAGS.render_path:
                psnr = utils.compute_psnr(
                    ((pred_color - batch["pixels"])**2).mean())
                print(f"  PSNR = {psnr:.4f}")
                psnrs.append(float(psnr))
            if FLAGS.save_output:
                utils.save_img(pred_color,
                               path.join(out_dir, "{:03d}.png".format(idx)))
                utils.save_img(
                    pred_disp[Ellipsis, 0],
                    path.join(out_dir, "disp_{:03d}.png".format(idx)))
        if (not FLAGS.eval_once) and (jax.host_id() == 0):
            summary_writer.image("pred_color", showcase_color, step)
            summary_writer.image("pred_disp", showcase_disp, step)
            summary_writer.image("pred_acc", showcase_acc, step)
            if not FLAGS.render_path:
                summary_writer.scalar("psnr", np.mean(np.array(psnrs)), step)
                summary_writer.image("target", showcase_gt, step)
        if FLAGS.save_output and (not FLAGS.render_path) and (jax.host_id()
                                                              == 0):
            # Mean PSNR over the test set; overwritten on every evaluation.
            with utils.open_file(path.join(out_dir, "psnr.txt"), "w") as pout:
                pout.write("{}".format(np.mean(np.array(psnrs))))
        if FLAGS.eval_once or step >= FLAGS.max_steps:
            break
        last_step = step