Exemplo n.º 1
0
 def loss_fn(model):
     """Compute the training loss and per-level stats for one ray batch.

     Args:
         model: the NeRF model; called on the batch's ray origins,
             directions and viewdirs under the current model state.

     Returns:
         A tuple ``(total_loss, (new_model_state, stats))`` where
         ``total_loss`` is the fine-level MSE plus the coarse-level MSE
         (0. when the model emits only one prediction level), and
         ``stats`` is a list of ``utils.Stats`` — fine level first, coarse
         appended when present.

     Raises:
         ValueError: if the model returns other than 1 or 2 output sets.
     """
     with nn.stateful(state.model_state) as new_model_state:
         rays = batch["rays"]
         ret = model(key_0, key_1, rays.origins, rays.directions,
                     rays.viewdirs)
     if len(ret) not in (1, 2):
         # NOTE: previously the two literals concatenated without a space
         # ("...2 setsof output..."); fixed the message here.
         raise ValueError(
             "ret should contain either 1 set of output (coarse only), or 2 "
             "sets of output (coarse as ret[0] and fine as ret[1]).")
     # The main prediction is always at the end of the ret list.
     rgb, unused_disp, unused_acc = ret[-1]
     loss = ((rgb - batch["pixels"][Ellipsis, :3])**2).mean()
     psnr = utils.compute_psnr(loss)
     stats = [utils.Stats(loss=loss, psnr=psnr)]
     if len(ret) > 1:
         # If there are both coarse and fine predictions, we compute the
         # loss for the coarse prediction (ret[0]) as well.
         rgb_c, unused_disp_c, unused_acc_c = ret[0]
         loss_c = ((rgb_c - batch["pixels"][Ellipsis, :3])**2).mean()
         psnr_c = utils.compute_psnr(loss_c)
         stats.append(utils.Stats(loss=loss_c, psnr=psnr_c))
     else:
         # No coarse level: contribute nothing to the total loss. (The old
         # dead `psnr_c = 0.` assignment was removed — it was never read.)
         loss_c = 0.
     return loss + loss_c, (new_model_state, stats)
Exemplo n.º 2
0
    def loss_fn(variables):
        """Compute total training loss (MSE + weight decay) and stats.

        Args:
            variables: parameter pytree passed to ``model.apply``; the same
                pytree is reduced over to compute the weight-decay penalty.

        Returns:
            A tuple ``(total_loss, stats)`` where ``total_loss`` is the
            fine-level MSE, plus the coarse-level MSE (0. when only one
            prediction level is returned), plus
            ``FLAGS.weight_decay_mult * weight_l2``; ``stats`` is a single
            ``utils.Stats`` carrying all per-level metrics.

        Raises:
            ValueError: if the model returns other than 1 or 2 output sets.
        """
        rays = batch["rays"]
        ret = model.apply(variables, key_0, key_1, rays, FLAGS.randomized)
        if len(ret) not in (1, 2):
            # NOTE: previously the two literals concatenated without a space
            # ("...2 setsof output..."); fixed the message here.
            raise ValueError(
                "ret should contain either 1 set of output (coarse only), or 2 "
                "sets of output (coarse as ret[0] and fine as ret[1]).")
        # The main prediction is always at the end of the ret list.
        rgb, unused_disp, unused_acc = ret[-1]
        loss = ((rgb - batch["pixels"][Ellipsis, :3])**2).mean()
        psnr = utils.compute_psnr(loss)
        if len(ret) > 1:
            # If there are both coarse and fine predictions, we compute the
            # loss for the coarse prediction (ret[0]) as well.
            rgb_c, unused_disp_c, unused_acc_c = ret[0]
            loss_c = ((rgb_c - batch["pixels"][Ellipsis, :3])**2).mean()
            psnr_c = utils.compute_psnr(loss_c)
        else:
            loss_c = 0.
            psnr_c = 0.

        def tree_sum_fn(fn):
            # Fold fn over every leaf array of the variables pytree.
            return jax.tree_util.tree_reduce(lambda x, y: x + fn(y),
                                             variables,
                                             initializer=0)

        # Mean squared weight: sum of squared params / total param count.
        weight_l2 = (tree_sum_fn(lambda z: jnp.sum(z**2)) /
                     tree_sum_fn(lambda z: jnp.prod(jnp.array(z.shape))))

        stats = utils.Stats(loss=loss,
                            psnr=psnr,
                            loss_c=loss_c,
                            psnr_c=psnr_c,
                            weight_l2=weight_l2)
        return loss + loss_c + FLAGS.weight_decay_mult * weight_l2, stats