# Example #1 (score: 0)
 def imread(fs):
     """Read an image through `file_utils.open_file` with imageio.

     For file names ending in "png" (compared case-insensitively),
     `ignoregamma=True` is passed to `imageio.imread`, asking the PNG reader
     not to apply gamma correction. All other files are read with imageio's
     defaults.

     Args:
       fs: path of the image file.

     Returns:
       The image returned by `imageio.imread`.
     """
     # Lower-case first so upper-case names such as "IMG_0001.PNG" get the
     # same PNG handling as lower-case ones.
     png_kwargs = {"ignoregamma": True} if fs.lower().endswith("png") else {}
     with file_utils.open_file(fs) as f:
         return imageio.imread(f, **png_kwargs)
# Example #2 (score: 0)
def get_train_scene_list(config):
  """Function to get the list of training scenes.

  Args:
    config: experiment config. `config.dataset.name` selects the dataset.

  Returns:
    scene_path_list: list of scene names. For "ff_epipolar" these are the
      entries of `config.dataset.ff_base_dir`, minus the corrupted/test
      scenes, in sorted order. For "dtu" they are the non-empty lines of
      `configs/lists/dtu_train_all.txt`, in file order. For "blender_rot"
      the list is `["lego"]`.

  Raises:
    ValueError: if `config.dataset.name` is not one of the supported datasets.
  """
  dataset_name = config.dataset.name

  if dataset_name == "ff_epipolar":
    # Scenes that are corrupted or held out for testing are never trained on.
    corrupted_and_test_list = {
        "howardzhou_010_internal_drawing_vase", "howardzhou_059_narcissus",
        "howardzhou_087_yellow_chain_links",
        "howardzhou_089_whilte_bmw_x3_front", "howardzhou_085_sweet_onions",
        "qq18", "qq33", "data2_fernvlsb", "data2_hugetrike", "data2_trexsanta",
        "data3_orchid", "data5_leafscene", "data5_lotr", "data5_redflower"
    }
    all_scenes = set(file_utils.listdir(config.dataset.ff_base_dir))
    # Iteration order of a set of strings depends on hash randomization and
    # differs between interpreter runs. Sort so the order is reproducible.
    scene_path_list = sorted(all_scenes - corrupted_and_test_list)

  elif dataset_name == "dtu":
    list_path = os.path.join(config.dataset.dtu_base_dir, "configs", "lists",
                             "dtu_train_all.txt")
    with file_utils.open_file(list_path) as f:
      # Lines are decoded from bytes; presumably open_file defaults to binary
      # mode -- NOTE(review): confirm.
      names = [line.rstrip().decode("utf-8") for line in f.readlines()]
    # Skip blank lines so they do not turn into empty scene names.
    scene_path_list = [name for name in names if name]

  elif dataset_name == "blender_rot":
    scene_path_list = ["lego"]

  else:
    raise ValueError(
        f"Unsupported dataset for training scene list: {dataset_name!r}")

  return scene_path_list
# Example #3 (score: 0)
def save_img(img, pth):
    """Write an image to `pth` as an 8-bit PNG.

    Args:
      img: jnp.ndarray, [height, width, channels]. Values are clipped to
        [0, 1], multiplied by 255 and cast (truncated) to uint8 before saving.
      pth: string, path to save the image to.
    """
    with file_utils.open_file(pth, "wb") as imgout:
        scaled = np.clip(img, 0., 1.) * 255.
        # np.array() produces a host numpy array that PIL can consume.
        pixels = np.array(scaled.astype(np.uint8))
        Image.fromarray(pixels).save(imgout, "PNG")
    def _load_renderings(self, args):
        """Load images and camera information for one evaluation scene.

        Reads the scene directory `args.dataset.eval_ff_dir/<self.scene>` and
        sets `self.images`, `self.camtoworlds`, `self.intrinsic_arr`,
        `self.intrinsic_matrix`, `self.focal`, `self.h`, `self.w`,
        `self.resolution`, `self.n_examples`, `self.min_depth`,
        `self.max_depth`, `self.cam_transform` and, for the "test" split,
        `self.render_poses`. As a side effect it writes the scene's scaled depth
        bounds into `args.model.near` and `args.model.far`.

        Args:
          args: experiment config.

        Raises:
          ValueError: if the "images" folder holds no JPG/jpg/png files, if the
            first image's height/width ratio differs from the configured eval
            height/width ratio, or if no image directory exists.
        """
        self.cam_transform = np.array([[1, 0, 0, 0], [0, -1, 0, 0],
                                       [0, 0, -1, 0], [0, 0, 0, 1]])
        #-------------------------------------------
        # Load images.
        #-------------------------------------------
        basedir = os.path.join(args.dataset.eval_ff_dir, self.scene)
        full_res_dir = os.path.join(basedir, "images")
        img_paths = [
            os.path.join(full_res_dir, f)
            for f in sorted(file_utils.listdir(full_res_dir))
            if f.endswith(("JPG", "jpg", "png"))
        ]
        if not img_paths:
            raise ValueError(
                "No JPG/jpg/png images found in {}".format(full_res_dir))
        # The first full-resolution image fixes the resize factor used below
        # to rescale the intrinsics.
        with file_utils.open_file(img_paths[0]) as f:
            sh = imageio.imread(f).shape
        expected_ratio = (args.dataset.eval_ff_image_height /
                          args.dataset.eval_ff_image_width)
        if sh[0] / sh[1] != expected_ratio:
            raise ValueError("not expected height width ratio")

        factor = sh[0] / args.dataset.eval_ff_image_height

        # Use the "images_4" folder when it exists (presumably a downsampled
        # copy -- NOTE(review)); otherwise fall back to "images".
        sfx = "_4"
        imgdir = os.path.join(basedir, "images" + sfx)
        if not file_utils.file_exists(imgdir):
            imgdir = os.path.join(basedir, "images")
            if not file_utils.file_exists(imgdir):
                raise ValueError("{} does not exist".format(imgdir))

        images = self._load_images(imgdir, args.dataset.eval_ff_image_width,
                                   args.dataset.eval_ff_image_height)

        #-------------------------------------------
        # Load poses and bds.
        #-------------------------------------------
        with file_utils.open_file(os.path.join(basedir, "poses_bounds.npy"),
                                  "rb") as fp:
            poses_arr = np.load(fp)

        # Get the intrinsic matrix. Rows are H, W, fx, fy, cx, cy (see the
        # unpacking at the end of this method).
        with file_utils.open_file(os.path.join(basedir, "hwf_cxcy.npy"),
                                  "rb") as fp:
            self.intrinsic_arr = np.load(fp)

        # Update the intrinsic matrix to account for resizing.
        self.intrinsic_arr = self.intrinsic_arr * 1. / factor

        # poses_arr contains, for each image, a 3x4 pose matrix [R t] followed
        # by 2 depth bounds, i.e. shape (num_images, 14).
        poses = poses_arr[:, :-2].reshape([-1, 3, 4]).transpose([1, 2, 0])
        bds = poses_arr[:, -2:].transpose([1, 0])

        # Convert R matrix from the form [down right back] to [right up back].
        poses = np.concatenate(
            [poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1)

        # Transpose such that the first dimension is number of images.
        images = np.moveaxis(images, -1, 0)
        poses = np.moveaxis(poses, -1, 0).astype(np.float32)
        bds = np.moveaxis(bds, -1, 0).astype(np.float32)

        if args.dataset.normalize:
            scale = 1. / bds.max()
        else:
            scale = 1. / (bds.min() * .75)

        poses[:, :3, 3] *= scale
        bds *= scale
        # Recentred poses without the camera transform are only used to build
        # the spiral render path. A copy is passed so `poses` stays untouched
        # even if `recenter_poses` modifies its argument in place.
        poses_copy = pose_utils.recenter_poses(poses.copy(), None)
        poses = pose_utils.recenter_poses(poses, self.cam_transform)

        # Get the min and max depth of the scene.
        self.min_depth = np.array([bds.min()])
        self.max_depth = np.array([bds.max()])
        # Use this to set the near and far plane.
        args.model.near = self.min_depth.item()
        args.model.far = self.max_depth.item()

        if self.split == "test":
            self.render_poses = pose_utils.generate_spiral_poses(
                poses_copy, bds, self.cam_transform)

        # Select the split.
        if args.eval.mvsn_style:
            scene_name = os.path.basename(basedir)
            with file_utils.open_file(
                    os.path.join(os.path.dirname(basedir), "pairs.npz")) as f:
                img_ids = np.load(f)
                i_test = img_ids["{}_{}".format(scene_name, "test")]
                i_train = img_ids["{}_{}".format(scene_name, "train")]
        else:
            num_images = int(images.shape[0])
            i_test = np.arange(num_images)[::args.dataset.llffhold]
            # setdiff1d keeps the integer dtype even when every image is held
            # out, whereas np.array([]) would be float64 and unusable as an
            # index.
            i_train = np.setdiff1d(np.arange(num_images), i_test)

        indices = i_train if self.split == "train" else i_test
        images = images[indices]
        poses = poses[indices]

        self.images = images
        self.camtoworlds = poses[:, :3, :4]

        # intrinsic arr has H, W, fx, fy, cx, cy. Only column 0 (the first
        # image's intrinsics) is used.
        self.focal = self.intrinsic_arr[2][0]
        self.h, self.w = images.shape[1:3]
        self.resolution = self.h * self.w

        if args.dataset.render_path and self.split == "test":
            self.n_examples = self.render_poses.shape[0]
        else:
            self.n_examples = images.shape[0]

        _, _, fx, fy, cx, cy = self.intrinsic_arr[:, 0]
        self.intrinsic_matrix = np.array([
            [fx, 0, cx, 0],
            [0, fy, cy, 0],
            [0, 0, 1, 0],
        ]).astype(np.float32)
# Example #5 (score: 0)
def evaluate(config, workdir):
    """Evaluation function.

    Each round restores a checkpoint from `workdir`: the latest one, or
    `config.eval.checkpoint_step` when that is not -1. It then renders every
    view of every test scene. Unless `config.dataset.render_path` is set, it
    also computes PSNR and SSIM against the target views. Renders and metric
    files are saved under `workdir` when `config.eval.save_output` is set.
    Process 0 writes TensorBoard summaries when `config.eval.eval_once` is
    False.

    The loop stops after one round when `config.eval.eval_once` is set or a
    fixed checkpoint step was requested. Otherwise it keeps polling for newer
    checkpoints until one with step >= `config.train.max_steps` is evaluated.

    Args:
      config: experiment config.
      workdir: directory holding the checkpoints; outputs are written below it.
    """
    import time  # pylint: disable=g-import-not-at-top

    # Seconds to wait before polling `workdir` again when no new checkpoint is
    # available, so waiting does not busy-spin on the filesystem.
    poll_interval_secs = 10

    # Hide the GPUs and TPUs from TF so it does not reserve memory on them for
    # LPIPS computation or dataset loading.
    tf.config.experimental.set_visible_devices([], "GPU")
    tf.config.experimental.set_visible_devices([], "TPU")

    rng = jax.random.PRNGKey(config.seed)

    sfx = "coarse" if config.eval.return_coarse else ""
    is_host0 = jax.process_index() == 0
    # Summaries are only written by process 0, in continuous-evaluation mode.
    write_summaries = (not config.eval.eval_once) and is_host0
    fixed_step = config.eval.checkpoint_step != -1

    #----------------------------------------------------------------------------
    # Build input pipeline.
    # The second half of this split is unused. The split itself is kept so the
    # keys derived from `rng` below are unchanged.
    rng, _ = jax.random.split(rng)

    train_ds, test_ds_dict = datasets.create_eval_dataset(config)
    example_batch = train_ds.peek()

    rng, init_key = jax.random.split(rng)

    #----------------------------------------------------------------------------
    # Initialize model.
    learning_rate_fn = train_utils.create_learning_rate_fn(config)
    model, state = models.create_train_state(config,
                                             init_key,
                                             learning_rate_fn=learning_rate_fn,
                                             example_batch=example_batch)

    #----------------------------------------------------------------------------
    # Get the rendering function. Rendering is forced to be deterministic even
    # if training is randomized.
    render_pfn = render_utils.get_render_function(model,
                                                  config,
                                                  randomized=False)

    out_subdir = ("path_renders"
                  if config.dataset.render_path else "test_preds") + sfx
    out_dir_dict = {
        scene_name: os.path.join(workdir, out_subdir, scene_name)
        for scene_name in test_ds_dict
    }

    if write_summaries:
        # Prepare Metric Writers: one per scene plus one for scene averages.
        summary_writer_dict = {
            scene_name: tensorboard.SummaryWriter(
                os.path.join(workdir, "eval" + sfx, scene_name))
            for scene_name in test_ds_dict
        }
        summary_writer_dict["all"] = tensorboard.SummaryWriter(
            os.path.join(workdir, "eval" + sfx))

    last_step = 0
    while True:
        state = checkpoints.restore_checkpoint(
            workdir,
            state,
            step=config.eval.checkpoint_step if fixed_step else None)
        step = int(state.step)
        if step <= last_step:
            # Nothing new to evaluate yet; wait instead of spinning.
            time.sleep(poll_interval_secs)
            continue
        total_psnr = 0

        for scene_name, test_ds in test_ds_dict.items():
            out_dir = out_dir_dict[scene_name]
            if config.eval.save_output and (not file_utils.isdir(out_dir)):
                file_utils.makedirs(out_dir)

            psnr_values = []
            ssim_values = []
            showcase_color = showcase_disp = showcase_acc = showcase_gt = None

            if not config.eval.eval_once:
                # Index of the view whose render is logged as an image summary.
                showcase_index = np.random.randint(0, test_ds.size)

            for idx in range(test_ds.size):
                logging.info("Evaluating scene %s [%d / %d].", scene_name, idx,
                             test_ds.size)
                batch = next(test_ds)
                test_pixels = batch.target_view.rgb
                test_mask = batch.target_view.mask
                if test_pixels is not None:
                    test_pixels = model_utils.uint2float(test_pixels)

                #-----------------------------------------------------------
                # Render Image
                variables = {"params": state.params}
                pred_color, pred_disp, pred_acc = render_utils.render_image(
                    functools.partial(render_pfn, variables),
                    batch,
                    rng,
                    render_utils.normalize_disp(config.dataset.name),
                    chunk=config.eval.chunk,
                    return_coarse=config.eval.return_coarse,
                )

                # Every process takes part in rendering; only process 0 logs
                # and saves.
                if not is_host0:
                    continue

                #-----------------------------------------------------------
                # Get showcase example for logging
                if not config.eval.eval_once and idx == showcase_index:
                    showcase_color = pred_color
                    showcase_disp = pred_disp
                    showcase_acc = pred_acc
                    if not config.dataset.render_path:
                        showcase_gt = test_pixels
                #-----------------------------------------------------------
                # If ground-truth pixels are available, evaluate
                if not config.dataset.render_path:
                    if config.eval.mvsn_style:
                        # Drop a border of 1/10 of the height/width on each
                        # side before computing metrics. Explicit end indices
                        # avoid the empty `[0:-0]` slice for tiny images.
                        height, width = pred_color.shape[:2]
                        h_crop, w_crop = height // 10, width // 10
                        crop = (slice(h_crop, height - h_crop),
                                slice(w_crop, width - w_crop))
                        pred_color = pred_color[crop]
                        test_pixels = test_pixels[crop]
                        if test_mask is not None:
                            # Keep the mask aligned with the cropped images.
                            test_mask = test_mask[crop]

                    if test_mask is not None:
                        sq_err = (pred_color[test_mask] -
                                  test_pixels[test_mask])**2
                    else:
                        sq_err = (pred_color - test_pixels)**2
                    psnr = model_utils.compute_psnr(sq_err.mean())
                    ssim = skmetrics.structural_similarity(
                        pred_color.astype(np.float32),
                        test_pixels.astype(np.float32),
                        win_size=11,
                        multichannel=True,
                        gaussian_weights=True)
                    logging.info(f"PSNR = {psnr:.4f}, SSIM = {ssim:.4f}")  # pylint: disable=logging-fstring-interpolation
                    psnr_values.append(float(psnr))
                    ssim_values.append(float(ssim))

                #-----------------------------------------------------------
                # Save generated image
                if config.eval.save_output:
                    model_utils.save_img(
                        pred_color,
                        os.path.join(out_dir, "{:03d}.png".format(idx)))
                    if pred_disp is not None:
                        model_utils.save_img(
                            pred_disp[Ellipsis, 0],
                            os.path.join(out_dir,
                                         "disp_{:03d}.png".format(idx)))
            #-----------------------------------------------------------
            if write_summaries:
                scene_writer = summary_writer_dict[scene_name]
                scene_writer.image("eval/{}_pred_color".format(scene_name),
                                   showcase_color, step)
                if showcase_disp is not None:
                    scene_writer.image("eval/{}_pred_disp".format(scene_name),
                                       showcase_disp, step)
                if showcase_acc is not None:
                    scene_writer.image("eval/{}_pred_acc".format(scene_name),
                                       showcase_acc, step)
                if not config.dataset.render_path:
                    scene_writer.scalar("eval_metric/psnr",
                                        np.mean(np.array(psnr_values)), step)
                    scene_writer.scalar("eval_metric/ssim",
                                        np.mean(np.array(ssim_values)), step)
                    scene_writer.image("eval/{}_target".format(scene_name),
                                       showcase_gt, step)

            #-----------------------------------------------------------
            # Save the metric to file
            if config.eval.save_output and (
                    not config.dataset.render_path) and is_host0:
                with file_utils.open_file(
                        os.path.join(out_dir, f"psnrs_{step}.txt"), "w") as f:
                    f.write(" ".join([str(v) for v in psnr_values]))
                with file_utils.open_file(
                        os.path.join(out_dir, f"ssims_{step}.txt"), "w") as f:
                    f.write(" ".join([str(v) for v in ssim_values]))
                with file_utils.open_file(os.path.join(out_dir, "psnr.txt"),
                                          "w") as f:
                    f.write("{}".format(np.mean(np.array(psnr_values))))
                with file_utils.open_file(os.path.join(out_dir, "ssim.txt"),
                                          "w") as f:
                    f.write("{}".format(np.mean(np.array(ssim_values))))
            # PSNR values exist only on process 0 and only outside render-path
            # mode; elsewhere the mean of an empty list would be NaN.
            if is_host0 and not config.dataset.render_path:
                total_psnr += np.mean(np.array(psnr_values))

        if write_summaries and not config.dataset.render_path:
            summary_writer_dict["all"].scalar(
                "eval_metric/avg_psnr_{}".format(config.dataset.eval_dataset),
                total_psnr / len(test_ds_dict), step)

        last_step = step
        if config.eval.eval_once or fixed_step:
            # A fixed checkpoint step would never change, so stop here.
            break
        if step >= config.train.max_steps:
            break

    logging.info("Finishing evaluation at step %d", last_step)
# Example #6 (score: 0)
def train_and_evaluate(config, workdir):
    """Runs a training and evaluation loop.

    Training switches to a randomly chosen training scene every
    `config.train.switch_scene_iter` steps, or every 3 steps when
    `config.dev_run` is set. Training metrics are logged periodically. One
    evaluation view is periodically rendered, and its PSNR and images are
    logged. On process 0, `train_logs.json` and a checkpoint are written
    periodically.

    Args:
      config: Configuration to use.
      workdir: Working directory for checkpoints and TF summaries. If this
        contains checkpoint training will be resumed from the latest checkpoint.

    Raises:
      ValueError: if the batch size is not divisible by the number of devices,
        or if no training scenes are found.
    """
    if config.dataset.batch_size % jax.device_count() != 0:
        raise ValueError(
            "Batch size must be divisible by the number of devices.")

    tf.io.gfile.makedirs(workdir)
    # Deterministic training.
    rng = jax.random.PRNGKey(config.seed)
    # Shift the numpy random seed by process_index() to shuffle data loaded
    # by different hosts
    np.random.seed(20201473 + jax.process_index())

    #----------------------------------------------------------------------------
    # Build input pipeline.
    # The second half of this split is unused. The split itself is kept so the
    # model-initialization key derived from `rng` below is unchanged.
    rng, _ = jax.random.split(rng)

    scene_path_list = train_utils.get_train_scene_list(config)
    if not scene_path_list:
        raise ValueError("No training scenes found for dataset {}".format(
            config.dataset.name))

    train_ds = datasets.create_train_dataset(config, scene_path_list[0])
    _, eval_ds_dict = datasets.create_eval_dataset(config)
    # A single evaluation scene is used for the periodic renders below.
    _, eval_ds = eval_ds_dict.popitem()
    example_batch = train_ds.peek()

    #----------------------------------------------------------------------------
    # Learning rate schedule.
    num_train_steps = config.train.max_steps
    if num_train_steps == -1:
        # NOTE(review): `size` is called as a method here; confirm this matches
        # the train dataset's API.
        num_train_steps = train_ds.size()
    steps_per_epoch = num_train_steps // config.train.num_epochs
    logging.info("num_train_steps=%d, steps_per_epoch=%d", num_train_steps,
                 steps_per_epoch)

    learning_rate_fn = train_utils.create_learning_rate_fn(config)

    #----------------------------------------------------------------------------
    # Initialize model.
    rng, model_rng = jax.random.split(rng)
    model, state = models.create_train_state(
        config,
        model_rng,
        learning_rate_fn=learning_rate_fn,
        example_batch=example_batch,
    )

    #----------------------------------------------------------------------------
    # Set up checkpointing of the model and the input pipeline.
    # Resume from `workdir` if the job was stopped and relaunched. Otherwise
    # optionally start from pretrained weights.
    if checkpoints.latest_checkpoint(workdir) is None:
        if config.train.pretrain_dir:
            state = checkpoints.restore_checkpoint(config.train.pretrain_dir,
                                                   state)
    else:
        state = checkpoints.restore_checkpoint(workdir, state)

    initial_step = int(state.step) + 1
    step_per_scene = config.train.switch_scene_iter
    if config.dev_run:
        jnp.set_printoptions(precision=2)
        np.set_printoptions(precision=2)
        step_per_scene = 3

    #----------------------------------------------------------------------------
    # Distribute training.
    state = flax_utils.replicate(state)
    p_train_step = jax.pmap(
        functools.partial(
            train_step,
            model=model,
            learning_rate_fn=learning_rate_fn,
            weight_decay=config.train.weight_decay,
            config=config,
        ),
        axis_name="batch",
    )

    # Get distributed rendering function
    render_pfn = render_utils.get_render_function(
        model=model,
        config=config,
        randomized=False,  # No randomization for evaluation.
    )

    #----------------------------------------------------------------------------
    # Prepare Metric Writers
    writer = metric_writers.create_default_writer(
        workdir, just_logging=jax.process_index() > 0)
    if initial_step == 1:
        writer.write_hparams(dict(config))

    logging.info("Starting training loop at step %d.", initial_step)
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=num_train_steps, writer=writer)
    hooks = [report_progress] if jax.process_index() == 0 else []
    train_metrics = None

    # Keep 6 batches prefetched on the local devices.
    ptrain_ds = flax.jax_utils.prefetch_to_device(train_ds, 6)
    n_local_devices = jax.local_device_count()
    # Give each host its own key stream. fold_in is the supported way to derive
    # keys; arithmetic on keys is not.
    rng = jax.random.fold_in(rng, jax.process_index())
    keys = jax.random.split(rng, n_local_devices)  # For pmapping RNG keys.

    with metric_writers.ensure_flushes(writer):
        for step in range(initial_step, num_train_steps + 1):
            # `step` is a Python integer. `state.step` is JAX integer on the
            # GPU/TPU devices.
            if step % step_per_scene == 0:
                curr_scene = scene_path_list[np.random.randint(
                    len(scene_path_list))]
                logging.info("Loading scene %s", curr_scene)
                if config.dataset.name == "dtu":
                    # lighting can take values between 0 and 6 (both included)
                    config.dataset.dtu_light_idx = np.random.randint(low=0,
                                                                     high=7)
                train_ds = datasets.create_train_dataset(config, curr_scene)
                ptrain_ds = flax.jax_utils.prefetch_to_device(train_ds, 6)

            is_last_step = step == num_train_steps
            with jax.profiler.StepTraceAnnotation("train", step_num=step):
                batch = next(ptrain_ds)
                state, metrics_update, keys = p_train_step(rng=keys,
                                                           state=state,
                                                           batch=batch)
                metric_update = flax_utils.unreplicate(metrics_update)
                train_metrics = (metric_update if train_metrics is None else
                                 train_metrics.merge(metric_update))
            # Quick indication that training is happening.
            logging.log_first_n(logging.INFO, "Finished training step %d.", 5,
                                step)
            for h in hooks:
                h(step)

            if step % config.train.log_loss_every_steps == 0 or is_last_step:
                writer.write_scalars(step, train_metrics.compute())
                train_metrics = None

            if step % config.train.render_every_steps == 0 or is_last_step:
                test_batch = next(eval_ds)
                test_pixels = model_utils.uint2float(
                    test_batch.target_view.rgb)  # extract for evaluation
                with report_progress.timed("eval"):
                    pred_color, pred_disp, pred_acc = eval_step(
                        state, keys[0], test_batch, render_pfn, config)
                #------------------------------------------------------------------
                # Log metrics and images for host 0
                #------------------------------------------------------------------
                if jax.process_index() == 0:
                    psnr = model_utils.compute_psnr(
                        ((pred_color - test_pixels)**2).mean())
                    ssim = 0.  # SSIM is not computed during training.
                    writer.write_scalars(
                        step, {
                            "train_eval/test_psnr": psnr,
                            "train_eval/test_ssim": ssim,
                        })
                    writer.write_images(
                        step, {
                            "test_pred_color": pred_color[None, :],
                            "test_target": test_pixels[None, :]
                        })
                    if pred_disp is not None:
                        writer.write_images(
                            step, {"test_pred_disp": pred_disp[None, :]})
                    if pred_acc is not None:
                        writer.write_images(
                            step, {"test_pred_acc": pred_acc[None, :]})
                #------------------------------------------------------------------

            if (jax.process_index() == 0) and (
                    step % config.train.checkpoint_every_steps == 0 or
                    is_last_step):
                # Write the metrics of the most recent step to file.
                with file_utils.open_file(
                        os.path.join(workdir, "train_logs.json"), "w") as f:
                    log_dict = {
                        k: v.item() for k, v in metric_update.compute().items()
                    }
                    f.write(json.dumps(log_dict))
                with report_progress.timed("checkpoint"):
                    # Take the first replica's copy and move it to the host.
                    state_to_save = jax.device_get(
                        flax_utils.unreplicate(state))
                    checkpoints.save_checkpoint(workdir,
                                                state_to_save,
                                                step,
                                                keep=100)

    logging.info("Finishing training at step %d", num_train_steps)