def loss_fn(params):
    """Compute the photometric training loss plus L2 weight decay.

    The names ``model``, ``key_0``, ``key_1``, ``batch``, ``config``,
    ``weight_decay``, ``model_utils`` and ``train_utils`` are free variables
    resolved from the enclosing scope; they are not defined here.

    Args:
      params: model parameters; wrapped as ``{"params": params}`` for
        ``model.apply``.

    Returns:
      ``(total_loss, stats)`` where ``total_loss = loss + loss_c +
      weight_penalty`` and ``stats`` is a ``train_utils.Stats`` with the fine
      loss/PSNR, coarse loss/PSNR (0. when there is no coarse output) and the
      summed squared weights. The ``(value, aux)`` shape presumably targets
      ``jax.value_and_grad(..., has_aux=True)`` — confirm at the call site.

    Raises:
      ValueError: if the model returns neither 1 nor 2 sets of outputs.
    """
    variables = {"params": params}
    ret = model.apply(variables,
                      key_0,
                      key_1,
                      batch,
                      randomized=config.model.randomized)
    if len(ret) not in (1, 2):
        raise ValueError(
            "ret should contain either 1 set of output (coarse only), or 2 sets "
            "of output (coarse as ret[0] and fine as ret[1]).")

    # Ground-truth colors in float form; only the first 3 channels are used.
    target_rgb = model_utils.uint2float(batch.target_view.rgb)[Ellipsis, :3]

    #------------------------------------------------------------------------
    # Main prediction: always the last entry of `ret`.
    rgb, _, _ = ret[-1]
    loss = ((rgb - target_rgb)**2).mean()
    psnr = model_utils.compute_psnr(loss)

    #------------------------------------------------------------------------
    # Coarse / regularization prediction (ret[0]) when two outputs exist.
    if len(ret) > 1:
        rgb_c, _, _ = ret[0]
        loss_c = ((rgb_c - target_rgb)**2).mean()
        psnr_c = model_utils.compute_psnr(loss_c)
    else:
        loss_c = 0.
        psnr_c = 0.

    #------------------------------------------------------------------------
    # L2 weight regularization on every parameter with ndim > 1 (i.e. biases
    # and other 1-D parameters are excluded).
    weight_l2 = sum(
        jnp.sum(x**2)
        for x in jax.tree_util.tree_leaves(params)
        if x.ndim > 1)
    weight_penalty = weight_decay * 0.5 * weight_l2

    #------------------------------------------------------------------------
    # Total loss and logged statistics.
    total_loss = loss + loss_c + weight_penalty
    stats = train_utils.Stats(loss=loss,
                              psnr=psnr,
                              loss_c=loss_c,
                              psnr_c=psnr_c,
                              weight_l2=weight_l2)
    return total_loss, stats
Beispiel #2
0
    def __call__(self, rng_0, rng_1, batch, randomized):
        """Light Field Neural Rendering Model.

        Args:
          rng_0: jnp.ndarray, PRNG key passed to the epipolar projection
            (coarse sampling).
          rng_1: jnp.ndarray, PRNG key for fine sampling; not used by this
            model.
          batch: data batch (data_types.Batch). Side effect:
            ``batch.reference_views.rgb`` is set to None after it is read.
          randomized: bool, use randomized stratified sampling.

        Returns:
          ``ret``: list ``[(rgb_coarse, None, None), (rgb, None, None)]``. The
          first entry comes from ``_get_reg_prediction``; the last one is the
          main color prediction. When ``self.return_attn`` is truthy, returns
          ``(ret, attn)`` instead, where ``attn`` maps ``"e_attn"`` and
          ``"n_attn"`` to the attention weights and ``"p_coord"`` to the
          projected coordinates with axes 0 and 1 swapped.
        """
        del rng_1  # Not used by this model.

        target_rays = batch.target_view.rays
        ref_views = batch.reference_views

        #-------------------------------------------------------------------
        # Epipolar projection: map the target rays into the world and then
        # onto the nearby reference cameras.
        proj_coords, _, world_coords = self.projector.epipolar_projection(
            rng_0, target_rays, ref_views.ref_worldtocamera,
            ref_views.intrinsic_matrix, randomized)

        # Rays through the projected pixels of the reference cameras.
        proj_rays = self.projector.get_near_rays(proj_coords,
                                                 ref_views.ref_cameratoworld,
                                                 ref_views.intrinsic_matrix)

        # Colors/features sampled from the reference images at the projected
        # coordinates.
        proj_rgb_feat = self._get_pixel_projection(
            proj_coords, model_utils.uint2float(ref_views.rgb))

        # The reference images are no longer needed past this point.
        ref_views.rgb = None

        #-------------------------------------------------------------------
        # Light-field representation: target rays are the queries, projected
        # rays are the keys and the projected colors are the values.
        _, query, _ = self._get_query(target_rays)
        _, input_key, _, learned_embedding = self._get_key(
            proj_rays, proj_rgb_feat, world_coords, ref_views.idx)

        # Aggregate features along each epipolar line.
        epipolar_feat, e_attn = self._get_avg_features(query,
                                                       input_key,
                                                       randomized=randomized)

        rgb, n_attn = self._predict_color(query, epipolar_feat,
                                          learned_embedding, randomized)
        rgb_coarse = self._get_reg_prediction(proj_rgb_feat, e_attn, n_attn)

        ret = [(rgb_coarse, None, None), (rgb, None, None)]
        if not self.return_attn:
            return ret
        return ret, {
            "e_attn": e_attn,
            "n_attn": n_attn,
            "p_coord": proj_coords.swapaxes(0, 1),
        }
Beispiel #3
0
def evaluate(config, workdir):
    """Evaluation function.

    Repeatedly restores the latest checkpoint from ``workdir``. Whenever its
    step is newer than the last evaluated one, it renders every example of
    the test split. Process 0 then:
      * computes PSNR/SSIM against the ground truth (skipped when
        ``config.dataset.render_path`` is set),
      * optionally saves renders and metric files
        (``config.eval.save_output``),
      * unless ``config.eval.eval_once``, writes a randomly chosen showcase
        example and the mean metrics to TensorBoard under ``workdir/eval``.

    Returns after one evaluation when ``config.eval.eval_once`` is set, or
    after evaluating a checkpoint whose step is ``>= config.train.max_steps``.
    While no newer checkpoint exists it waits, polling periodically.

    Args:
      config: experiment configuration.
      workdir: directory holding the training checkpoints. Outputs go to
        ``workdir/path_renders`` or ``workdir/test_preds``.
    """
    import time  # Local import: only used to throttle checkpoint polling.

    # Seconds to wait before checking again for a newer checkpoint.
    poll_interval_secs = 10

    # Hide the GPUs and TPUs from TF so it does not reserve memory on them for
    # LPIPS computation or dataset loading.
    tf.config.experimental.set_visible_devices([], "GPU")
    tf.config.experimental.set_visible_devices([], "TPU")

    rng = jax.random.PRNGKey(config.seed)

    #----------------------------------------------------------------------------
    # Build input pipeline.
    # The second key of this split is unused; the split is kept so that `rng`
    # (and every key derived from it below) stays the same.
    rng, _ = jax.random.split(rng)
    config.dataset.data_dir = os.path.join(config.dataset.base_dir,
                                           config.dataset.scene)
    train_ds, test_ds = datasets.create_dataset(config)
    example_batch = train_ds.peek()

    rng, key = jax.random.split(rng)

    #----------------------------------------------------------------------------
    # Initialize model.
    learning_rate_fn = train_utils.create_learning_rate_fn(config)
    model, state = models.create_train_state(config,
                                             key,
                                             learning_rate_fn=learning_rate_fn,
                                             example_batch=example_batch)

    #----------------------------------------------------------------------------
    # Get the rendering function. Rendering is forced to be deterministic even
    # if training is randomized.
    render_pfn = render_utils.get_render_function(model,
                                                  config,
                                                  randomized=False)

    last_step = 0
    out_dir = os.path.join(
        workdir,
        "path_renders" if config.dataset.render_path else "test_preds")

    if not config.eval.eval_once:
        # Prepare Metric Writers
        summary_writer = tensorboard.SummaryWriter(
            os.path.join(workdir, "eval"))

    while True:
        state = checkpoints.restore_checkpoint(workdir, state)
        step = int(state.step)
        if step <= last_step:
            # No newer checkpoint yet: wait instead of busy-spinning on the
            # filesystem.
            time.sleep(poll_interval_secs)
            continue

        if config.eval.save_output and (not file_utils.isdir(out_dir)):
            file_utils.makedirs(out_dir)

        psnr_values = []
        ssim_values = []

        if not config.eval.eval_once:
            showcase_index = np.random.randint(0, test_ds.size)

        for idx in range(test_ds.size):
            logging.info("Evaluating [%d / %d].", idx, test_ds.size)
            batch = next(test_ds)
            test_pixels = batch.target_view.rgb
            if test_pixels is not None:
                test_pixels = model_utils.uint2float(test_pixels)

            #-----------------------------------------------------------
            # Render Image
            variables = {"params": state.params}
            pred_color, pred_disp, pred_acc = render_utils.render_image(
                functools.partial(render_pfn, variables),
                batch,
                rng,
                render_utils.normalize_disp(config.dataset.name),
                chunk=config.eval.chunk)

            if jax.process_index() != 0:
                continue

            #-----------------------------------------------------------
            # Keep the showcase example for TensorBoard logging.
            if not config.eval.eval_once and idx == showcase_index:
                showcase_color = pred_color
                showcase_disp = pred_disp
                showcase_acc = pred_acc
                if not config.dataset.render_path:
                    showcase_gt = test_pixels
            #-----------------------------------------------------------
            # If ground-truth pixels are available, evaluate.
            if not config.dataset.render_path:
                psnr = float(
                    model_utils.compute_psnr(
                        ((pred_color - test_pixels)**2).mean()))
                ssim = float(
                    skmetrics.structural_similarity(
                        pred_color.astype(np.float32),
                        test_pixels.astype(np.float32),
                        win_size=11,
                        multichannel=True,
                        gaussian_weights=True))
                logging.info("PSNR = %.4f, SSIM = %.4f", psnr, ssim)
                psnr_values.append(psnr)
                ssim_values.append(ssim)

            #-----------------------------------------------------------
            # Save generated image
            if config.eval.save_output:
                model_utils.save_img(
                    pred_color, os.path.join(out_dir,
                                             "{:03d}.png".format(idx)))
                if pred_disp is not None:
                    model_utils.save_img(
                        pred_disp[Ellipsis, 0],
                        os.path.join(out_dir, "disp_{:03d}.png".format(idx)))
        #-----------------------------------------------------------
        if (not config.eval.eval_once) and (jax.process_index() == 0):
            summary_writer.image("pred_color", showcase_color, step)
            if showcase_disp is not None:
                summary_writer.image("pred_disp", showcase_disp, step)
            if showcase_acc is not None:
                summary_writer.image("pred_acc", showcase_acc, step)
            if not config.dataset.render_path:
                summary_writer.scalar("eval_metric/psnr",
                                      np.mean(np.array(psnr_values)), step)
                summary_writer.scalar("eval_metric/ssim",
                                      np.mean(np.array(ssim_values)), step)
                summary_writer.image("target", showcase_gt, step)

        #-----------------------------------------------------------
        # Save the metrics to file: per-image values for this step, and the
        # mean values (overwritten on every evaluation).
        if config.eval.save_output and (not config.dataset.render_path) and (
                jax.process_index() == 0):
            with file_utils.open_file(
                    os.path.join(out_dir, f"psnrs_{step}.txt"), "w") as f:
                f.write(" ".join(map(str, psnr_values)))
            with file_utils.open_file(
                    os.path.join(out_dir, f"ssims_{step}.txt"), "w") as f:
                f.write(" ".join(map(str, ssim_values)))
            with file_utils.open_file(os.path.join(out_dir, "psnr.txt"),
                                      "w") as f:
                f.write("{}".format(np.mean(np.array(psnr_values))))
            with file_utils.open_file(os.path.join(out_dir, "ssim.txt"),
                                      "w") as f:
                f.write("{}".format(np.mean(np.array(ssim_values))))
        if config.eval.eval_once:
            break
        if step >= config.train.max_steps:
            break
        last_step = step

    logging.info("Finishing evaluation at step %d", last_step)
def train_and_evaluate(config, workdir):
    """Runs a training and evaluation loop.

    Args:
      config: Configuration to use.
      workdir: Working directory for checkpoints and TF summaries. If this
        contains a checkpoint, training is resumed from the latest one.

    Raises:
      ValueError: if ``config.dataset.batch_size`` is not divisible by the
        total number of devices.
    """
    if config.dataset.batch_size % jax.device_count() != 0:
        raise ValueError(
            "Batch size must be divisible by the number of devices.")

    tf.io.gfile.makedirs(workdir)
    # Deterministic training.
    rng = jax.random.PRNGKey(config.seed)
    # Shift the numpy random seed by process_index() to shuffle data loaded
    # by different hosts.
    np.random.seed(20201473 + jax.process_index())

    #----------------------------------------------------------------------------
    # Build input pipeline.
    # The second key of this split is unused; the split is kept so that `rng`
    # (and every key derived from it below) stays the same.
    rng, _ = jax.random.split(rng)
    config.dataset.data_dir = os.path.join(config.dataset.base_dir,
                                           config.dataset.scene)
    train_ds, eval_ds = datasets.create_dataset(config)
    example_batch = train_ds.peek()

    #----------------------------------------------------------------------------
    # Learning rate schedule.
    num_train_steps = config.train.max_steps
    if num_train_steps == -1:
        # `size` is an attribute of the dataset (`evaluate` reads it as
        # `test_ds.size`); calling it would raise a TypeError.
        num_train_steps = train_ds.size
    steps_per_epoch = num_train_steps // config.train.num_epochs
    logging.info("num_train_steps=%d, steps_per_epoch=%d", num_train_steps,
                 steps_per_epoch)

    learning_rate_fn = train_utils.create_learning_rate_fn(config)

    #----------------------------------------------------------------------------
    # Initialize model.
    rng, model_rng = jax.random.split(rng)
    model, state = models.create_train_state(
        config,
        model_rng,
        learning_rate_fn=learning_rate_fn,
        example_batch=example_batch,
    )

    #----------------------------------------------------------------------------
    # Set up checkpointing of the model and the input pipeline.
    state = checkpoints.restore_checkpoint(workdir, state)
    initial_step = int(state.step) + 1

    #----------------------------------------------------------------------------
    # Distribute training.
    state = flax_utils.replicate(state)
    p_train_step = jax.pmap(
        functools.partial(
            train_step,
            model=model,
            learning_rate_fn=learning_rate_fn,
            weight_decay=config.train.weight_decay,
            config=config,
        ),
        axis_name="batch",
    )

    # Get distributed rendering function
    render_pfn = render_utils.get_render_function(
        model=model,
        config=config,
        randomized=False,  # No randomization for evaluation.
    )

    #----------------------------------------------------------------------------
    # Prepare Metric Writers
    writer = metric_writers.create_default_writer(
        workdir, just_logging=jax.process_index() > 0)
    if initial_step == 1:
        writer.write_hparams(dict(config))

    logging.info("Starting training loop at step %d.", initial_step)
    hooks = []
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=num_train_steps, writer=writer)
    if jax.process_index() == 0:
        hooks.append(report_progress)
    train_metrics = None

    # Prefetch up to 6 batches onto the local devices.
    ptrain_ds = flax.jax_utils.prefetch_to_device(train_ds, 6)
    n_local_devices = jax.local_device_count()
    rng = rng + jax.process_index()  # Make random seed separate across hosts.
    keys = jax.random.split(rng, n_local_devices)  # For pmapping RNG keys.

    with metric_writers.ensure_flushes(writer):
        for step in range(initial_step, num_train_steps + 1):
            # `step` is a Python integer. `state.step` is JAX integer on the
            # GPU/TPU devices.
            is_last_step = step == num_train_steps
            with jax.profiler.StepTraceAnnotation("train", step_num=step):
                batch = next(ptrain_ds)
                state, metrics_update, keys = p_train_step(rng=keys,
                                                           state=state,
                                                           batch=batch)
                # Metrics of this step only (used for the JSON log below).
                metric_update = flax_utils.unreplicate(metrics_update)
                # Metrics accumulated since the last scalar write.
                train_metrics = (metric_update if train_metrics is None else
                                 train_metrics.merge(metric_update))
            # Quick indication that training is happening.
            logging.log_first_n(logging.INFO, "Finished training step %d.", 5,
                                step)
            for h in hooks:
                h(step)

            if step % config.train.log_loss_every_steps == 0 or is_last_step:
                writer.write_scalars(step, train_metrics.compute())
                train_metrics = None

            if step % config.train.render_every_steps == 0 or is_last_step:
                test_batch = next(eval_ds)
                test_pixels = model_utils.uint2float(
                    test_batch.target_view.rgb)  # extract for evaluation
                with report_progress.timed("eval"):
                    pred_color, pred_disp, pred_acc = eval_step(
                        state, keys[0], test_batch, render_pfn, config)
                #------------------------------------------------------------------
                # Log metrics and images for host 0
                #------------------------------------------------------------------
                if jax.process_index() == 0:
                    psnr = model_utils.compute_psnr(
                        ((pred_color - test_pixels)**2).mean())
                    # `gaussian_weights` (plural) must be spelled exactly:
                    # a misspelled keyword is silently ignored and SSIM falls
                    # back to a uniform window.
                    ssim = skmetrics.structural_similarity(
                        pred_color.astype(np.float32),
                        test_pixels.astype(np.float32),
                        win_size=11,
                        multichannel=True,
                        gaussian_weights=True)
                    writer.write_scalars(
                        step, {
                            "train_eval/test_psnr": psnr,
                            "train_eval/test_ssim": ssim,
                        })
                    writer.write_images(
                        step, {
                            "test_pred_color": pred_color[None, :],
                            "test_target": test_pixels[None, :]
                        })
                    if pred_disp is not None:
                        writer.write_images(
                            step, {"test_pred_disp": pred_disp[None, :]})
                    if pred_acc is not None:
                        writer.write_images(
                            step, {"test_pred_acc": pred_acc[None, :]})
                #------------------------------------------------------------------

            if (jax.process_index() == 0) and (
                    step % config.train.checkpoint_every_steps == 0
                    or is_last_step):
                # Overwrite train_logs.json with the metrics of the current
                # step, then save a checkpoint.
                with file_utils.open_file(
                        os.path.join(workdir, "train_logs.json"), "w") as f:
                    log_dict = {
                        k: v.item()
                        for k, v in metric_update.compute().items()
                    }
                    f.write(json.dumps(log_dict))
                with report_progress.timed("checkpoint"):
                    # Take the first replica's copy of the replicated state.
                    state_to_save = jax.device_get(
                        jax.tree_util.tree_map(lambda x: x[0], state))
                    checkpoints.save_checkpoint(workdir,
                                                state_to_save,
                                                step,
                                                keep=100)

    logging.info("Finishing training at step %d", num_train_steps)