Beispiel #1
0
 def bounding_fn(lag_grad: Tensor, acts: ParamSet) -> Tensor:
     """Reduce the box-constrained linear term for the activation at `index`.

     With x0 = acts[index], each element contributes
     max(g, 0) * (lb - x0) + min(g, 0) * (ub - x0), which (assuming lb <= ub)
     is the minimum of g * (x - x0) over x in [lb, ub]. The elementwise
     contributions are reduced with `_sum_over_acts`.
     """
     current = acts[index]
     positive_part = jnp.maximum(lag_grad, 0.)
     negative_part = jnp.minimum(lag_grad, 0.)
     return _sum_over_acts(positive_part * (lb - current) +
                           negative_part * (ub - current))
Beispiel #2
0
def clipped_objective(probab_ratios, advantages, reward_mask, epsilon=0.2):
    """Clipped surrogate objective, masked by `reward_mask`.

    Takes the elementwise minimum of the unclipped term
    `probab_ratios * advantages` and the term built from
    `clipped_probab_ratios(probab_ratios, epsilon=epsilon)`, as in the PPO
    objective (assuming `clipped_probab_ratios` clips ratios to
    [1 - epsilon, 1 + epsilon] -- TODO confirm), then applies the mask.
    """
    unclipped = probab_ratios * advantages
    clipped = clipped_probab_ratios(probab_ratios, epsilon=epsilon) * advantages
    return np.minimum(unclipped, clipped) * reward_mask
Beispiel #3
0
def compute_ssim(img0,
                 img1,
                 max_val,
                 filter_size=11,
                 filter_sigma=1.5,
                 k1=0.01,
                 k2=0.03,
                 return_map=False):
    """Computes SSIM from two images.

    This function was modeled after tf.image.ssim, and should produce comparable
    output.

    Args:
      img0: array. An image of size [..., width, height, num_channels].
      img1: array. An image of size [..., width, height, num_channels].
      max_val: float > 0. The maximum magnitude that `img0` or `img1` can have.
      filter_size: int >= 1. Window size.
      filter_sigma: float > 0. The bandwidth of the Gaussian used for filtering.
      k1: float > 0. One of the SSIM dampening parameters.
      k2: float > 0. One of the SSIM dampening parameters.
      return_map: bool. If True, return the per-pixel SSIM "map" instead of
        the per-image mean.

    Returns:
      Each image's mean SSIM (one value per leading batch index), or, if
      `return_map`, the per-pixel SSIM map. Because the blur uses "valid"
      convolution, the map is `filter_size - 1` smaller than the inputs along
      each of the two spatial axes.

    Raises:
      ValueError: if `filter_size` < 1 (the filter would be empty and the
        result NaN).
    """
    if filter_size < 1:
        raise ValueError(f"filter_size must be >= 1, got {filter_size}.")

    # Construct a normalized 1D Gaussian blur filter. `shift` is 0 for odd
    # `filter_size` and 0.5 for even sizes, which keeps the taps symmetric
    # about zero in both cases.
    hw = filter_size // 2
    shift = (2 * hw - filter_size + 1) / 2
    f_i = ((jnp.arange(filter_size) - hw + shift) / filter_sigma)**2
    filt = jnp.exp(-0.5 * f_i)
    filt /= jnp.sum(filt)

    # Blur in x and y separately (faster than a single 2D convolution).
    filt_fn1 = lambda z: jsp.signal.convolve2d(z, filt[:, None], mode="valid")
    filt_fn2 = lambda z: jsp.signal.convolve2d(z, filt[None, :], mode="valid")

    # Vmap the 2D blurs over every axis except the two spatial ones (all
    # leading batch axes plus the channel axis), then compose them.
    num_dims = len(img0.shape)
    map_axes = tuple(list(range(num_dims - 3)) + [num_dims - 1])
    for d in map_axes:
        filt_fn1 = jax.vmap(filt_fn1, in_axes=d, out_axes=d)
        filt_fn2 = jax.vmap(filt_fn2, in_axes=d, out_axes=d)
    filt_fn = lambda z: filt_fn1(filt_fn2(z))

    # Local means, variances and covariance under the Gaussian window.
    mu0 = filt_fn(img0)
    mu1 = filt_fn(img1)
    mu00 = mu0 * mu0
    mu11 = mu1 * mu1
    mu01 = mu0 * mu1
    sigma00 = filt_fn(img0**2) - mu00
    sigma11 = filt_fn(img1**2) - mu11
    sigma01 = filt_fn(img0 * img1) - mu01

    # Clip the variances and covariances to valid values: rounding can make
    # variances slightly negative, and by Cauchy-Schwarz the covariance
    # magnitude cannot exceed sqrt(var0 * var1).
    sigma00 = jnp.maximum(0., sigma00)
    sigma11 = jnp.maximum(0., sigma11)
    sigma01 = jnp.sign(sigma01) * jnp.minimum(jnp.sqrt(sigma00 * sigma11),
                                              jnp.abs(sigma01))

    c1 = (k1 * max_val)**2
    c2 = (k2 * max_val)**2
    numer = (2 * mu01 + c1) * (2 * sigma01 + c2)
    denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)
    ssim_map = numer / denom
    # Average over the two spatial axes and the channel axis.
    ssim = jnp.mean(ssim_map, axis=tuple(range(num_dims - 3, num_dims)))
    return ssim_map if return_map else ssim
Beispiel #4
0
 def step_fn(step):
     """Learning rate at `step`: cosine decay multiplied by a linear warmup.

     The cosine schedule is evaluated at `epoch - warmup_epochs` over
     `num_epochs - warmup_epochs` epochs, and the result is scaled by a
     factor that ramps linearly with the epoch and is capped at 1 once
     `warmup_epochs` epochs have elapsed.
     """
     current_epoch = step / steps_per_epoch
     decayed_lr = cosine_decay(base_learning_rate,
                               current_epoch - warmup_epochs,
                               num_epochs - warmup_epochs)
     warmup_factor = jnp.minimum(1., current_epoch / warmup_epochs)
     return decayed_lr * warmup_factor
Beispiel #5
0
 def clip(x, lo, hi):
   """Clamp `x` elementwise into [lo, hi].

   The lower bound is applied first and the upper bound last, so when
   lo > hi every element ends up equal to hi.
   """
   at_least_lo = np.maximum(lo, x)
   return np.minimum(hi, at_least_lo)
Beispiel #6
0
 def acceptance_probability(
         self, scenario: Scenario, reject_state: cdict, reject_extra: cdict,
         proposed_state: cdict,
         proposed_extra: cdict) -> Union[float, jnp.ndarray]:
     """Return min(1, exp(reject.potential - proposed.potential)).

     This is the Metropolis acceptance rule written in terms of potentials
     (it matches the Metropolis-Hastings ratio only for a symmetric proposal
     -- NOTE(review): confirm). `scenario`, `reject_extra` and
     `proposed_extra` are not used here.
     """
     log_ratio = reject_state.potential - proposed_state.potential
     return jnp.minimum(1., jnp.exp(log_ratio))
Beispiel #7
0
def train_step(model, rng, state, batch, lr):
    """One optimization step.

    Must run inside a transformation that binds the "batch" axis name (e.g.
    `jax.pmap(..., axis_name="batch")`), because gradients and stats are
    averaged with `jax.lax.pmean` over that axis.

    Args:
      model: The linen model.
      rng: jnp.ndarray, random number generator.
      state: utils.TrainState, state of the model/optimizer.
      batch: dict, a mini-batch of data for training; "rays" is fed to the
        model and the first three channels of "pixels" are the targets.
      lr: float, real-time learning rate.

    Returns:
      new_state: utils.TrainState, new training state.
      stats: utils.Stats with loss, psnr, loss_c, psnr_c and weight_l2
        (loss_c and psnr_c are 0 when the model returns a single output).
      rng: jnp.ndarray, updated random number generator.

    Raises:
      ValueError: if the model returns neither 1 nor 2 sets of outputs.
    """
    rng, key_0, key_1 = random.split(rng, 3)

    def loss_fn(variables):
        rays = batch["rays"]
        ret = model.apply(variables, key_0, key_1, rays, FLAGS.randomized)
        if len(ret) not in (1, 2):
            raise ValueError(
                "ret should contain either 1 set of output (coarse only), or 2 sets "
                "of output (coarse as ret[0] and fine as ret[1]).")
        # The main prediction is always at the end of the ret list.
        rgb, unused_disp, unused_acc = ret[-1]
        loss = ((rgb - batch["pixels"][Ellipsis, :3])**2).mean()
        psnr = utils.compute_psnr(loss)
        if len(ret) > 1:
            # If there are both coarse and fine predictions, we compute the loss for
            # the coarse prediction (ret[0]) as well.
            rgb_c, unused_disp_c, unused_acc_c = ret[0]
            loss_c = ((rgb_c - batch["pixels"][Ellipsis, :3])**2).mean()
            psnr_c = utils.compute_psnr(loss_c)
        else:
            loss_c = 0.
            psnr_c = 0.

        def tree_sum_fn(fn):
            return jax.tree_util.tree_reduce(lambda x, y: x + fn(y),
                                             variables,
                                             initializer=0)

        # Mean squared parameter value: sum of squares over all leaves divided
        # by the total number of parameters.
        weight_l2 = (tree_sum_fn(lambda z: jnp.sum(z**2)) /
                     tree_sum_fn(lambda z: jnp.prod(jnp.array(z.shape))))

        stats = utils.Stats(loss=loss,
                            psnr=psnr,
                            loss_c=loss_c,
                            psnr_c=psnr_c,
                            weight_l2=weight_l2)
        return loss + loss_c + FLAGS.weight_decay_mult * weight_l2, stats

    (_, stats), grad = jax.value_and_grad(
        loss_fn, has_aux=True)(state.optimizer.target)
    grad = jax.lax.pmean(grad, axis_name="batch")
    stats = jax.lax.pmean(stats, axis_name="batch")

    # Clip the gradient by value.
    if FLAGS.grad_max_val > 0:
        clip_fn = lambda z: jnp.clip(z, -FLAGS.grad_max_val, FLAGS.grad_max_val)
        grad = jax.tree_util.tree_map(clip_fn, grad)

    # Clip the (possibly value-clipped) gradient by its global L2 norm; the
    # 1e-7 avoids dividing by zero for an all-zero gradient.
    if FLAGS.grad_max_norm > 0:
        grad_norm = jnp.sqrt(
            jax.tree_util.tree_reduce(lambda x, y: x + jnp.sum(y**2),
                                      grad,
                                      initializer=0))
        mult = jnp.minimum(1, FLAGS.grad_max_norm / (1e-7 + grad_norm))
        grad = jax.tree_util.tree_map(lambda z: mult * z, grad)

    new_optimizer = state.optimizer.apply_gradient(grad, learning_rate=lr)
    new_state = state.replace(optimizer=new_optimizer)
    return new_state, stats, rng
def isoneutral_diffusion_pre(maskT, maskU, maskV, maskW, dxt, dxu, dyt, dyu,
                             dzt, dzw, cost, cosu, salt, temp, zt, K_iso, K_11,
                             K_22, K_33, Ai_ez, Ai_nz, Ai_bx, Ai_by):
    """
    Isopycnal diffusion for tracer
    following functional formulation by Griffies et al
    Code adopted from MOM2.1

    Computes the isoneutral diffusivities K_11 (east faces), K_22 (north
    faces) and K_33 (top faces) together with the tapered slope arrays
    Ai_ez, Ai_nz, Ai_bx and Ai_by. Temperature and salinity are read at
    time level ``tau = 0`` (the last axis of ``temp`` and ``salt``).

    Index conventions inferred from the code (NOTE(review): confirm): the
    first three axes of the 3-D fields are presumably (x, y, z), since
    ``dxu`` is broadcast along axis 0, ``dyu`` along axis 1 and ``dzw``
    along axis 2. The ``Ai_*`` arrays carry two trailing axes indexed by
    (ip or jp, kr), each taking the values 0 and 1.

    Only the interior slices written below are updated; entries of the
    incoming K_11, K_22, K_33 and Ai_* arrays outside those slices are
    returned as they were. Within the interior x/y range, K_33 is set to 0
    at the last vertical level.

    NOTE(review): ``jax.ops.index_update`` / ``jax.ops.index`` have been
    removed from newer JAX releases; the replacement is
    ``arr.at[idx].set(value)``. TODO migrate.

    Returns:
        (K_11, K_22, K_33, Ai_ez, Ai_nz, Ai_bx, Ai_by)
    """
    epsln = 1e-20  # keeps the slope denominators strictly negative below
    iso_slopec = 1e-3  # slope magnitude at which dm_taper equals 0.5
    iso_dslope = 1e-3  # width of the dm_taper transition
    K_iso_steep = 50.  # floor applied to each summand of K_11 / K_22
    tau = 0  # time level of temp/salt that is used

    # Gradient buffers; entries not written below stay 0.
    dTdx = np.zeros_like(K_11)
    dSdx = np.zeros_like(K_11)
    dTdy = np.zeros_like(K_11)
    dSdy = np.zeros_like(K_11)
    dTdz = np.zeros_like(K_11)
    dSdz = np.zeros_like(K_11)
    """
    drho_dt and drho_ds at centers of T cells
    """
    # NOTE(review): get_drhodT / get_drhodS are defined elsewhere; their third
    # argument is |zt| (presumably depth) -- confirm against their definitions.
    drdT = maskT * get_drhodT(salt[:, :, :, tau], temp[:, :, :, tau],
                              np.abs(zt))
    drdS = maskT * get_drhodS(salt[:, :, :, tau], temp[:, :, :, tau],
                              np.abs(zt))
    """
    gradients at top face of T cells
    """
    # Vertical differences divided by dzw; the last vertical index stays 0.
    dTdz = jax.ops.index_update(
        dTdz, jax.ops.index[:, :, :-1], maskW[:, :, :-1] * \
        (temp[:, :, 1:, tau] - temp[:, :, :-1, tau]) / \
        dzw[np.newaxis, np.newaxis, :-1]
    )
    dSdz = jax.ops.index_update(
        dSdz, jax.ops.index[:, :, :-1], maskW[:, :, :-1] * \
        (salt[:, :, 1:, tau] - salt[:, :, :-1, tau]) / \
        dzw[np.newaxis, np.newaxis, :-1]
    )
    """
    gradients at eastern face of T cells
    """
    # Differences along axis 0 divided by dxu * cost; the last index stays 0.
    dTdx = jax.ops.index_update(
        dTdx, jax.ops.index[:-1, :, :], maskU[:-1, :, :] * (temp[1:, :, :, tau] - temp[:-1, :, :, tau]) \
        / (dxu[:-1, np.newaxis, np.newaxis] * cost[np.newaxis, :, np.newaxis])
    )
    dSdx = jax.ops.index_update(
        dSdx, jax.ops.index[:-1, :, :],
        maskU[:-1, :, :] * (salt[1:, :, :, tau] - salt[:-1, :, :, tau]) /
        (dxu[:-1, np.newaxis, np.newaxis] * cost[np.newaxis, :, np.newaxis]))
    """
    gradients at northern face of T cells
    """
    # Differences along axis 1 divided by dyu; the last index stays 0.
    dTdy = jax.ops.index_update(
        dTdy, jax.ops.index[:, :-1, :], maskV[:, :-1, :] * \
        (temp[:, 1:, :, tau] - temp[:, :-1, :, tau]) \
        / dyu[np.newaxis, :-1, np.newaxis]
    )
    dSdy = jax.ops.index_update(dSdy, jax.ops.index[:, :-1, :], maskV[:, :-1, :] * \
        (salt[:, 1:, :, tau] - salt[:, :-1, :, tau]) \
        / dyu[np.newaxis, :-1, np.newaxis]
    )

    def dm_taper(sx):
        """
        tapering function for isopycnal slopes

        Close to 1 for |sx| well below iso_slopec, close to 0 well above it,
        and exactly 0.5 at |sx| == iso_slopec.
        """
        return 0.5 * (1. + np.tanh((-np.abs(sx) + iso_slopec) / iso_dslope))

    """
    Compute Ai_ez and K11 on center of east face of T cell.
    """
    # diffloc: K_iso averaged over the neighbours (i, i+1) x (k, k-1); at
    # vertical index 0 only the two x neighbours are averaged.
    diffloc = np.zeros_like(K_11)
    diffloc = jax.ops.index_update(
        diffloc, jax.ops.index[1:-2, 2:-2, 1:],
        0.25 * (K_iso[1:-2, 2:-2, 1:] + K_iso[1:-2, 2:-2, :-1] +
                K_iso[2:-1, 2:-2, 1:] + K_iso[2:-1, 2:-2, :-1]))
    diffloc = jax.ops.index_update(
        diffloc, jax.ops.index[1:-2, 2:-2, 0],
        0.5 * (K_iso[1:-2, 2:-2, 0] + K_iso[2:-1, 2:-2, 0]))

    # Accumulate four (ip, kr) contributions per east face. ip shifts the
    # T cell the density derivatives come from (i or i+1); kr picks dTdz/dSdz
    # at level k-1 (kr=0, paired with ki=1 so level 0 is skipped) or at
    # level k (kr=1, ki=0). The slice end `:-1 + kr or None` evaluates to
    # `:-1` for kr=0 and to `:None` (whole axis) for kr=1, since -1 + 1 == 0
    # is falsy.
    sumz = np.zeros_like(K_11)[1:-2, 2:-2]
    for kr in range(2):
        ki = 0 if kr == 1 else 1
        for ip in range(2):
            drodxe = drdT[1 + ip:-2 + ip, 2:-2, ki:] * dTdx[1:-2, 2:-2, ki:] \
                + drdS[1 + ip:-2 + ip, 2:-2, ki:] * dSdx[1:-2, 2:-2, ki:]
            drodze = drdT[1 + ip:-2 + ip, 2:-2, ki:] * dTdz[1 + ip:-2 + ip, 2:-2, :-1 + kr or None] \
                + drdS[1 + ip:-2 + ip, 2:-2, ki:] * \
                dSdz[1 + ip:-2 + ip, 2:-2, :-1 + kr or None]
            # Slope -drho/dx / drho/dz, with the denominator forced to be
            # strictly negative (min(0, .) - epsln) so it never divides by 0.
            sxe = -drodxe / (np.minimum(0., drodze) - epsln)
            taper = dm_taper(sxe)
            sumz = jax.ops.index_update(
                sumz, jax.ops.index[:, :, ki:], sumz[..., ki:] +
                dzw[np.newaxis, np.newaxis, :-1 + kr or None] *
                maskU[1:-2, 2:-2, ki:] *
                np.maximum(K_iso_steep, diffloc[1:-2, 2:-2, ki:] * taper))
            Ai_ez = jax.ops.index_update(
                Ai_ez, jax.ops.index[1:-2, 2:-2, ki:, ip, kr],
                taper * sxe * maskU[1:-2, 2:-2, ki:])

    # Average of the four dzw-weighted contributions, normalized by dzt.
    K_11 = jax.ops.index_update(K_11, jax.ops.index[1:-2, 2:-2, :],
                                sumz / (4. * dzt[np.newaxis, np.newaxis, :]))
    """
    Compute Ai_nz and K_22 on center of north face of T cell.
    """
    # Same construction as for the east face, with x and y roles swapped.
    diffloc = jax.ops.index_update(diffloc, jax.ops.index[...], 0)
    diffloc = jax.ops.index_update(
        diffloc, jax.ops.index[2:-2, 1:-2, 1:],
        0.25 * (K_iso[2:-2, 1:-2, 1:] + K_iso[2:-2, 1:-2, :-1] +
                K_iso[2:-2, 2:-1, 1:] + K_iso[2:-2, 2:-1, :-1]))
    diffloc = jax.ops.index_update(
        diffloc, jax.ops.index[2:-2, 1:-2, 0],
        0.5 * (K_iso[2:-2, 1:-2, 0] + K_iso[2:-2, 2:-1, 0]))

    sumz = np.zeros_like(K_11)[2:-2, 1:-2]
    for kr in range(2):
        ki = 0 if kr == 1 else 1
        for jp in range(2):
            drodyn = drdT[2:-2, 1 + jp:-2 + jp, ki:] * dTdy[2:-2, 1:-2, ki:] + \
                drdS[2:-2, 1 + jp:-2 + jp, ki:] * dSdy[2:-2, 1:-2, ki:]
            drodzn = drdT[2:-2, 1 + jp:-2 + jp, ki:] * dTdz[2:-2, 1 + jp:-2 + jp, :-1 + kr or None] \
                + drdS[2:-2, 1 + jp:-2 + jp, ki:] * \
                dSdz[2:-2, 1 + jp:-2 + jp, :-1 + kr or None]
            syn = -drodyn / (np.minimum(0., drodzn) - epsln)
            taper = dm_taper(syn)
            sumz = jax.ops.index_update(
                sumz, jax.ops.index[:, :, ki:], sumz[..., ki:] +
                dzw[np.newaxis, np.newaxis, :-1 + kr or None] *
                maskV[2:-2, 1:-2, ki:] *
                np.maximum(K_iso_steep, diffloc[2:-2, 1:-2, ki:] * taper))
            Ai_nz = jax.ops.index_update(
                Ai_nz, jax.ops.index[2:-2, 1:-2, ki:, jp, kr],
                taper * syn * maskV[2:-2, 1:-2, ki:])
    K_22 = jax.ops.index_update(K_22, jax.ops.index[2:-2, 1:-2, :],
                                sumz / (4. * dzt[np.newaxis, np.newaxis, :]))
    """
    compute Ai_bx, Ai_by and K33 on top face of T cell.
    """
    # sumx/sumy accumulate K_iso * taper * slope**2 weighted by the
    # horizontal grid spacing, for every face except the last vertical one.
    sumx = np.zeros_like(K_11)[2:-2, 2:-2, :-1]
    sumy = np.zeros_like(K_11)[2:-2, 2:-2, :-1]

    for kr in range(2):
        # kr selects whether drdT/drdS are taken at level k (kr=0) or k+1
        # (kr=1) relative to the top-face gradients dTdz/dSdz[..., k].
        drodzb = drdT[2:-2, 2:-2, kr:-1 + kr or None] * dTdz[2:-2, 2:-2, :-1] \
            + drdS[2:-2, 2:-2, kr:-1 + kr or None] * dSdz[2:-2, 2:-2, :-1]

        # eastward slopes at the top of T cells
        for ip in range(2):
            drodxb = drdT[2:-2, 2:-2, kr:-1 + kr or None] * dTdx[1 + ip:-3 + ip, 2:-2, kr:-1 + kr or None] \
                + drdS[2:-2, 2:-2, kr:-1 + kr or None] * dSdx[1 + ip:-3 + ip, 2:-2, kr:-1 + kr or None]
            sxb = -drodxb / (np.minimum(0., drodzb) - epsln)
            taper = dm_taper(sxb)
            sumx += dxu[1 + ip:-3 + ip, np.newaxis, np.newaxis] * \
                K_iso[2:-2, 2:-2, :-1] * taper * \
                sxb**2 * maskW[2:-2, 2:-2, :-1]
            Ai_bx = jax.ops.index_update(
                Ai_bx, jax.ops.index[2:-2, 2:-2, :-1, ip, kr],
                taper * sxb * maskW[2:-2, 2:-2, :-1])

        # northward slopes at the top of T cells
        for jp in range(2):
            facty = cosu[1 + jp:-3 + jp] * dyu[1 + jp:-3 + jp]
            drodyb = drdT[2:-2, 2:-2, kr:-1 + kr or None] * dTdy[2:-2, 1 + jp:-3 + jp, kr:-1 + kr or None] \
                + drdS[2:-2, 2:-2, kr:-1 + kr or None] * dSdy[2:-2, 1 + jp:-3 + jp, kr:-1 + kr or None]
            syb = -drodyb / (np.minimum(0., drodzb) - epsln)
            taper = dm_taper(syb)
            sumy += facty[np.newaxis, :, np.newaxis] * K_iso[2:-2, 2:-2, :-1] \
                * taper * syb**2 * maskW[2:-2, 2:-2, :-1]
            Ai_by = jax.ops.index_update(
                Ai_by, jax.ops.index[2:-2, 2:-2, :-1, jp, kr],
                taper * syb * maskW[2:-2, 2:-2, :-1])

    K_33 = jax.ops.index_update(
        K_33, jax.ops.index[2:-2, 2:-2, :-1],
        sumx / (4 * dxt[2:-2, np.newaxis, np.newaxis]) + \
        sumy / (4 * dyt[np.newaxis, 2:-2, np.newaxis]
                * cost[np.newaxis, 2:-2, np.newaxis])
    )
    # No top-face contribution is computed for the last vertical level.
    K_33 = jax.ops.index_update(K_33, jax.ops.index[2:-2, 2:-2, -1], 0.)

    return K_11, K_22, K_33, Ai_ez, Ai_nz, Ai_bx, Ai_by
Beispiel #9
0
def piecewise_constant_pdf(key, bins, weights, num_samples, randomized):
    """Piecewise-Constant PDF sampling.

    Draws samples from the piecewise-constant distribution whose bin edges are
    `bins` and whose unnormalized per-bin mass is `weights`. It does this by
    inverting the CDF and interpolating linearly inside the selected bin.

    Args:
      key: jax PRNG key (e.g. from `jax.random.PRNGKey`); only used when
        `randomized` is True.
      bins: jnp.ndarray(float32), [batch_size, num_bins + 1], sorted bin edges.
      weights: jnp.ndarray(float32), [batch_size, num_bins].
      num_samples: int, the number of samples.
      randomized: bool, use randomized samples; otherwise evenly spaced
        values spanning [0, 1 - eps] are used.

    Returns:
      z_samples: jnp.ndarray(float32), [batch_size, num_samples]. Gradients do
        not flow through the samples.
    """
    # Pad each weight vector (only if necessary) to bring its sum to `eps`. This
    # avoids NaNs when the input is zeros or small, but has no effect otherwise.
    eps = 1e-5
    weight_sum = jnp.sum(weights, axis=-1, keepdims=True)
    padding = jnp.maximum(0, eps - weight_sum)
    # Explicit rebinding rather than in-place `+=`, so the caller's array is
    # never a candidate for in-place modification.
    weights = weights + padding / weights.shape[-1]
    weight_sum = weight_sum + padding

    # Compute the PDF and CDF for each weight vector, while ensuring that the CDF
    # starts with exactly 0 and ends with exactly 1.
    pdf = weights / weight_sum
    cdf = jnp.minimum(1, jnp.cumsum(pdf[Ellipsis, :-1], axis=-1))
    cdf = jnp.concatenate(
        [
            jnp.zeros(list(cdf.shape[:-1]) + [1]),
            cdf,
            jnp.ones(list(cdf.shape[:-1]) + [1]),
        ],
        axis=-1,
    )

    # Draw uniform samples.
    if randomized:
        # Note that `u` is in [0, 1) --- it can be zero, but it can never be 1.
        u = random.uniform(key, list(cdf.shape[:-1]) + [num_samples])
    else:
        # Match the behavior of random.uniform() by spanning [0, 1-eps].
        u = jnp.linspace(0.0, 1.0 - jnp.finfo("float32").eps, num_samples)
        u = jnp.broadcast_to(u, list(cdf.shape[:-1]) + [num_samples])

    # Identify the location in `cdf` that corresponds to a random sample.
    # The final `True` index in `mask` will be the start of the sampled interval.
    mask = u[Ellipsis, None, :] >= cdf[Ellipsis, :, None]

    def find_interval(x):
        # Grab the value where `mask` switches from True to False, and vice versa.
        # This approach takes advantage of the fact that `x` is sorted.
        x0 = jnp.max(jnp.where(mask, x[Ellipsis, None], x[Ellipsis, :1, None]),
                     -2)
        x1 = jnp.min(
            jnp.where(~mask, x[Ellipsis, None], x[Ellipsis, -1:, None]), -2)
        return x0, x1

    bins_g0, bins_g1 = find_interval(bins)
    cdf_g0, cdf_g1 = find_interval(cdf)

    # Fractional position inside the chosen bin. A zero-width CDF interval
    # gives 0/0 = NaN, which is mapped to 0; `nan` must be passed by keyword
    # because the second positional parameter of nan_to_num is `copy`.
    t = jnp.clip(jnp.nan_to_num((u - cdf_g0) / (cdf_g1 - cdf_g0), nan=0.), 0, 1)
    samples = bins_g0 + t * (bins_g1 - bins_g0)

    # Prevent gradient from backprop-ing through `samples`.
    return lax.stop_gradient(samples)
Beispiel #10
0
 def value(self, step):
     """Return the parent's step value together with min(100 / step, 0.01).

     The rate is computed from the value returned by the parent's `value`,
     not from the raw `step` argument.
     """
     parent_step = super().value(step)
     rate = jnp.minimum(100.0 / parent_step, 0.01)
     return parent_step, rate
Beispiel #11
0
def cosine_decay(base_learning_rate, step, decay_steps, alpha=0.001):
    """Cosine-decay `base_learning_rate` towards `alpha * base_learning_rate`.

    Progress `step / decay_steps` is clamped to [0, 1], so steps before 0
    return the full rate and steps past `decay_steps` stay at the floor.

    Args:
      base_learning_rate: initial learning rate.
      step: current step (may be negative or exceed `decay_steps`).
      decay_steps: length of the decay period.
      alpha: fraction of `base_learning_rate` reached at the end.

    Returns:
      The decayed learning rate.
    """
    progress = jnp.clip(step / decay_steps, 0., 1.)
    cosine = 0.5 * (1. + jnp.cos(jnp.pi * progress))
    return base_learning_rate * ((1 - alpha) * cosine + alpha)
Beispiel #12
0
def HardTanh(x, **unused_kwargs):
    """Linear approximation to tanh: clamps `x` elementwise to [-1, 1]."""
    capped_above = np.minimum(1, x)
    return np.maximum(-1, capped_above)
Beispiel #13
0
def HardSigmoid(x, **unused_kwargs):
    """Linear approximation to sigmoid.

    Computes clamp(1 + x, 0, 1) elementwise: 0 for x <= -1, 1 for x >= 0,
    linear in between. NOTE(review): unlike the logistic sigmoid this is not
    centred at 0 (it equals 0.5 at x = -0.5); confirm this is intended.
    """
    shifted = 1 + x
    return np.maximum(0, np.minimum(1, shifted))
# Outer optimization loop used when args.estimate selects 'tbptt'
# (presumably truncated backpropagation through time).
# NOTE(review): this fragment relies on names defined outside this view
# (args, x, t, theta, optim_params, initial_point, grad_unroll, unroll,
# optimizer_step, iteration_logger, start_time, time). It is also cut off
# inside the `writerow({...})` dict literal below, so the rest of the loop
# body is not shown here.
if args.estimate == 'tbptt':
    for i in range(args.outer_iterations):

        # Once the inner problem has reached the horizon args.T, restart it
        # from the initial point.
        if t >= args.T:
            x = jnp.array(initial_point)  # Reset the inner parameters
            t = 0

        # Gradient w.r.t. theta from a partial unroll starting at step t;
        # presumably args.K inner steps per call (TODO confirm against
        # grad_unroll).
        theta_grad, aux = grad_unroll(x, theta, t, args.T, args.K)
        x = aux[
            0]  # Update to be the x we get by unrolling the optimization with theta
        t = aux[2]  # Update to t_current from the unroll we used to get x

        # Gradient clipping: rescale theta_grad so its L2 norm is at most
        # args.outer_clip; the 1e-8 guards against a zero norm.
        if args.outer_clip > 0:
            theta_grad = theta_grad * jnp.minimum(
                1., args.outer_clip / (jnp.linalg.norm(theta_grad) + 1e-8))

        theta, optim_params = optimizer_step(theta, theta_grad, optim_params,
                                             i)

        # Every args.log_interval iterations, evaluate a full unroll from the
        # initial point and log progress.
        # NOTE(review): `L` is not used in the visible lines.
        if i % args.log_interval == 0:
            L, _ = unroll(jnp.array(initial_point), theta, 0, args.T,
                          args.T)  # Evaluate on the full unroll
            iteration_logger.writerow({
                'time_elapsed': time.time() - start_time,
                'iteration': i,
                'inner_problem_steps': i * args.K,
                'theta0': float(theta[0]),
                'theta1': float(theta[1]),
                'theta0_grad': float(theta_grad[0]),
                'theta1_grad': float(theta_grad[1]),
Beispiel #15
0
 def loss_fn(variables):
   """Mean squared error of the view-dependent RGB prediction against `ref`.

   The prediction is the first three channels of `rgb_features` plus the
   residual from `model_utils.viewdir_fn`, capped from above at 1.0.
   """
   view_residual = model_utils.viewdir_fn(model, variables, rgb_features,
                                          directions, scene_params)
   predicted = jnp.minimum(1.0, rgb_features[Ellipsis, 0:3] + view_residual)
   squared_error = (predicted - ref[Ellipsis, :3])**2
   return squared_error.mean()
Beispiel #16
0
def clipped_objective(probab_ratios, advantages, action_mask, epsilon):
  """Clipped surrogate objective, masked by `action_mask`.

  Takes the elementwise minimum of `probab_ratios * advantages` and
  `clipped_probab_ratios(probab_ratios, epsilon=epsilon) * advantages`, as in
  the PPO objective (assuming `clipped_probab_ratios` clips ratios to
  [1 - epsilon, 1 + epsilon] -- TODO confirm), then applies the mask.
  """
  # (Removed the no-op `advantages = advantages` self-assignment.)
  unclipped = probab_ratios * advantages
  clipped = clipped_probab_ratios(probab_ratios, epsilon=epsilon) * advantages
  return np.minimum(unclipped, clipped) * action_mask
Beispiel #17
0
 def schedule(count):
     """Cosine-decay `init_value` down to `alpha * init_value` over `decay_steps`.

     `count` is capped at `decay_steps`, so the value stays at
     `alpha * init_value` once the decay period has ended.
     """
     capped = jnp.minimum(count, decay_steps)
     cosine = 0.5 * (1 + jnp.cos(jnp.pi * capped / decay_steps))
     return init_value * ((1 - alpha) * cosine + alpha)
Beispiel #18
0
 def schedule(step):
   """Cosine-anneal from `init_lr` to `final_lr` over `burnin_steps` steps.

   Progress is capped at 1, so the schedule stays at `final_lr` afterwards.
   """
   progress = jnp.minimum(step / burnin_steps, 1.)
   weight_init = (1 + jnp.cos(progress * onp.pi)) * 0.5
   weight_final = 1 - weight_init
   return weight_init * init_lr + weight_final * final_lr