def loss_fn(model):
    """Compute the summed fine + coarse MSE loss for the current batch.

    `batch`, `state`, `key_0` and `key_1` are not parameters; they are read
    from the surrounding scope.

    Args:
        model: callable invoked as
            `model(key_0, key_1, origins, directions, viewdirs)`; it must
            return a sequence of 1 or 2 `(rgb, disp, acc)` triples.

    Returns:
        A pair `(total_loss, (new_model_state, stats))`. `total_loss` is the
        loss of the last prediction plus the coarse loss (0. when only one
        prediction is returned). `stats` is a list with one `utils.Stats`
        for the last prediction, followed by one for the coarse prediction
        when it exists.

    Raises:
        ValueError: if the model returns neither 1 nor 2 sets of output.
    """
    # Run the forward pass while collecting the updated model state.
    with nn.stateful(state.model_state) as new_model_state:
        rays = batch["rays"]
        ret = model(key_0, key_1, rays.origins, rays.directions, rays.viewdirs)

    if len(ret) not in (1, 2):
        raise ValueError(
            "ret should contain either 1 set of output (coarse only), or 2 sets "
            "of output (coarse as ret[0] and fine as ret[1]).")

    # Only the first three channels of the target pixels enter the loss.
    target = batch["pixels"][Ellipsis, :3]

    # The main prediction is always at the end of the ret list.
    rgb, unused_disp, unused_acc = ret[-1]
    loss = ((rgb - target)**2).mean()
    psnr = utils.compute_psnr(loss)
    stats = [utils.Stats(loss=loss, psnr=psnr)]

    if len(ret) > 1:
        # If there are both coarse and fine predictions, we compute the loss
        # for the coarse prediction (ret[0]) as well.
        rgb_c, unused_disp_c, unused_acc_c = ret[0]
        loss_c = ((rgb_c - target)**2).mean()
        psnr_c = utils.compute_psnr(loss_c)
        stats.append(utils.Stats(loss=loss_c, psnr=psnr_c))
    else:
        # No coarse prediction: it contributes nothing to the total loss.
        loss_c = 0.

    return loss + loss_c, (new_model_state, stats)
def loss_fn(variables):
    """Compute the training loss (fine + coarse MSE + weight decay) for a batch.

    `batch`, `model`, `key_0`, `key_1` and `FLAGS` are not parameters; they
    are read from the surrounding scope.

    Args:
        variables: pytree of model variables passed to `model.apply`; its
            leaves are also used for the L2 weight penalty.

    Returns:
        A pair `(total_loss, stats)`. `total_loss` is
        `loss + loss_c + FLAGS.weight_decay_mult * weight_l2`, and `stats` is
        a `utils.Stats` holding `loss`, `psnr`, `loss_c`, `psnr_c` and
        `weight_l2`. `loss_c` and `psnr_c` are 0. when the model returns a
        single prediction.

    Raises:
        ValueError: if the model returns neither 1 nor 2 sets of output.
    """
    rays = batch["rays"]
    ret = model.apply(variables, key_0, key_1, rays, FLAGS.randomized)
    if len(ret) not in (1, 2):
        raise ValueError(
            "ret should contain either 1 set of output (coarse only), or 2 sets "
            "of output (coarse as ret[0] and fine as ret[1]).")

    # Only the first three channels of the target pixels enter the loss.
    target = batch["pixels"][Ellipsis, :3]

    # The main prediction is always at the end of the ret list.
    rgb, unused_disp, unused_acc = ret[-1]
    loss = ((rgb - target)**2).mean()
    psnr = utils.compute_psnr(loss)

    if len(ret) > 1:
        # If there are both coarse and fine predictions, we compute the loss
        # for the coarse prediction (ret[0]) as well.
        rgb_c, unused_disp_c, unused_acc_c = ret[0]
        loss_c = ((rgb_c - target)**2).mean()
        psnr_c = utils.compute_psnr(loss_c)
    else:
        loss_c = 0.
        psnr_c = 0.

    def tree_sum_fn(fn):
        # Sum fn(leaf) over every leaf of `variables`.
        return jax.tree_util.tree_reduce(
            lambda x, y: x + fn(y), variables, initializer=0)

    # Mean squared parameter value: sum of squares over total element count.
    # `z.size` is a static Python int, so the count neither overflows a
    # fixed-width integer dtype nor adds array ops to the computation.
    weight_l2 = (tree_sum_fn(lambda z: jnp.sum(z**2)) /
                 tree_sum_fn(lambda z: z.size))

    stats = utils.Stats(
        loss=loss, psnr=psnr, loss_c=loss_c, psnr_c=psnr_c,
        weight_l2=weight_l2)
    return loss + loss_c + FLAGS.weight_decay_mult * weight_l2, stats