Example #1
def add_decayed_weights(
    weight_decay: float = 0.0,
    mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None
) -> base.GradientTransformation:
    """Add parameter scaled by `weight_decay`.

  Args:
    weight_decay: a scalar weight decay rate.
    mask: a tree with same structure as (or a prefix of) the params PyTree,
      or a Callable that returns such a pytree given the params/updates.
      The leaves should be booleans, `True` for leaves/subtrees you want to
      apply the transformation to, and `False` for those you want to skip.

  Returns:
    An (init_fn, update_fn) tuple.
  """
    def init_fn(_):
        return AddDecayedWeightsState()

    def update_fn(updates, state, params):
        if params is None:
            raise ValueError(base.NO_PARAMS_MSG)
        updates = jax.tree_multimap(lambda g, p: g + weight_decay * p, updates,
                                    params)
        return updates, state

    # If mask is not `None`, apply mask to the gradient transformation.
    # E.g. it is common to skip weight decay on bias units and batch stats.
    if mask is not None:
        return wrappers.masked(base.GradientTransformation(init_fn, update_fn),
                               mask)
    return base.GradientTransformation(init_fn, update_fn)
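
A minimal usage sketch for the transformation above, assuming it is exposed as `optax.add_decayed_weights` as in the optax library; the parameter tree and mask below are illustrative.

import jax
import jax.numpy as jnp
import optax

params = {'w': jnp.ones((3, 2)), 'b': jnp.zeros((2,))}
grads = jax.tree_map(jnp.ones_like, params)

# Skip weight decay on the bias by passing a boolean mask with the same structure.
tx = optax.add_decayed_weights(weight_decay=1e-4, mask={'w': True, 'b': False})
state = tx.init(params)
updates, state = tx.update(grads, state, params)  # params must be passed here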
Example #2
def add_noise(eta: float, gamma: float,
              seed: int) -> base.GradientTransformation:
    """Add gradient noise.

  References:
    [Neelakantan et al, 2015](https://arxiv.org/abs/1511.06807)

  Args:
    eta: base variance of the gaussian noise added to the gradient.
    gamma: decay exponent for annealing of the variance.
    seed: seed for random number generation.

  Returns:
    An (init_fn, update_fn) tuple.
  """
    def init_fn(_):
        return AddNoiseState(count=jnp.zeros([], jnp.int32),
                             rng_key=jax.random.PRNGKey(seed))

    def update_fn(updates, state, params=None):  # pylint: disable=missing-docstring
        del params
        num_vars = len(jax.tree_leaves(updates))
        treedef = jax.tree_structure(updates)
        count_inc = numerics.safe_int32_increment(state.count)
        variance = eta / count_inc**gamma
        all_keys = jax.random.split(state.rng_key, num=num_vars + 1)
        noise = jax.tree_multimap(
            lambda g, k: jax.random.normal(k, shape=g.shape, dtype=g.dtype),
            updates, jax.tree_unflatten(treedef, all_keys[1:]))
        updates = jax.tree_multimap(
            lambda g, n: g + variance.astype(g.dtype) * n, updates, noise)
        return updates, AddNoiseState(count=count_inc, rng_key=all_keys[0])

    return base.GradientTransformation(init_fn, update_fn)
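
A short usage sketch, assuming the function above is available as `optax.add_noise`; the values are illustrative.

import jax
import jax.numpy as jnp
import optax

params = {'w': jnp.ones((3,))}
grads = jax.tree_map(jnp.ones_like, params)

tx = optax.add_noise(eta=0.01, gamma=0.55, seed=42)
state = tx.init(params)
# Adds zero-mean gaussian noise whose variance anneals as eta / step**gamma
# (step counted from 1), as in the update_fn above.
noisy_updates, state = tx.update(grads, state)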
Example #3
def scale_by_param_block_norm(
    min_scale: float = 1e-3
) -> base.GradientTransformation:
  """Scale updates for each param block by the norm of that block's parameters.

  A `block` is here a weight vector (e.g. in a Linear layer) or a weight matrix
  (e.g. in a convolutional layer) appearing as a leaf in the grads/param pytree.

  Args:
    min_scale: minimum scaling factor.

  Returns:
    An (init_fn, update_fn) tuple.
  """

  def init_fn(params):
    del params
    return base.EmptyState()

  def update_fn(updates, state, params):
    if params is None:
      raise ValueError(base.NO_PARAMS_MSG)
    updates = jax.tree_multimap(
        lambda u, p: u * numerics.safe_norm(p, min_scale),
        updates, params)
    return updates, state

  return base.GradientTransformation(init_fn, update_fn)
Example #4
def apply_every(k: int = 1) -> base.GradientTransformation:
    """Accumulate gradients and apply them every k steps.

  Note that if this transformation is part of a chain, the states of the other
  transformations will still be updated at every step. In particular, using
  `apply_every` with a batch size of N/2 and k=2 is not necessarily equivalent
  to not using `apply_every` with a batch size of N. If this equivalence is
  important to you, consider using `optax.MultiSteps` instead.

  Args:
    k: emit non-zero gradients every k steps, otherwise accumulate them.

  Returns:
    An (init_fn, update_fn) tuple.
  """
    def init_fn(params):
        grad_acc = jax.tree_map(jnp.zeros_like, params)
        return ApplyEvery(count=jnp.zeros([], jnp.int32), grad_acc=grad_acc)

    def update_fn(updates, state, params=None):
        del params
        c = state.count % k
        acc = c != 0
        grad_acc = jax.tree_multimap(lambda g, ga: acc * ga + g, updates,
                                     state.grad_acc)
        emit = c == (k - 1)
        updates = jax.tree_map(lambda ga: emit * ga, grad_acc)
        count_inc = numerics.safe_int32_increment(state.count)
        return updates, ApplyEvery(count=count_inc % k, grad_acc=grad_acc)

    return base.GradientTransformation(init_fn, update_fn)
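
A sketch of gradient accumulation with `apply_every`, assuming the standard optax aliases (`optax.apply_every`, `optax.scale`, `optax.apply_updates`).

import jax.numpy as jnp
import optax

k = 4
tx = optax.chain(optax.apply_every(k=k), optax.scale(-0.1))
params = {'w': jnp.ones((2,))}
state = tx.init(params)
for _ in range(k):
  grads = {'w': jnp.ones((2,))}
  updates, state = tx.update(grads, state)
  # updates are all zeros except on every k-th call, when the accumulated
  # gradient is released.
  params = optax.apply_updates(params, updates)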
Example #5
def scale_by_param_block_rms(
    min_scale: float = 1e-3
) -> base.GradientTransformation:
  """Scale updates by rms of the gradient for each param vector or matrix.

  A `block` is here a weight vector (e.g. in a Linear layer) or a weight matrix
  (e.g. in a convolutional layer) appearing as a leaf in the grads/param pytree.

  Args:
    min_scale: minimum scaling factor.

  Returns:
    An (init_fn, update_fn) tuple.
  """

  def init_fn(_):
    return base.EmptyState()

  def update_fn(updates, state, params):
    updates = jax.tree_multimap(
        lambda u, p: u * numerics.safe_root_mean_squares(p, min_scale),
        updates, params)
    return updates, state

  return base.GradientTransformation(init_fn, update_fn)
Example #6
def clip_by_global_norm(max_norm) -> base.GradientTransformation:
  """Clip updates using their global norm.

  References:
    [Pascanu et al, 2012](https://arxiv.org/abs/1211.5063)

  Args:
    max_norm: the maximum global norm for an update.

  Returns:
    An (init_fn, update_fn) tuple.
  """

  def init_fn(_):
    return ClipByGlobalNormState()

  def update_fn(updates, state, params=None):
    del params
    g_norm = linear_algebra.global_norm(updates)
    # TODO(b/163995078): revert back to the following (faster) implementation
    # once analysed how it affects backprop through update (e.g. meta-gradients)
    # g_norm = jnp.maximum(max_norm, g_norm)
    # updates = jax.tree_map(lambda t: (t / g_norm) * max_norm, updates)
    trigger = g_norm < max_norm
    updates = jax.tree_map(
        lambda t: jnp.where(trigger, t, (t / g_norm) * max_norm), updates)
    return updates, state

  return base.GradientTransformation(init_fn, update_fn)
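
A typical composition, clipping by global norm before the optimiser step; this assumes the usual optax aliases.

import optax

tx = optax.chain(
    optax.clip_by_global_norm(1.0),  # rescale updates whose global norm exceeds 1.0
    optax.adam(learning_rate=1e-3),
)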
Example #7
def keep_params_nonnegative() -> base.GradientTransformation:
    """Modifies the updates to keep parameters non-negative, i.e. >= 0.

  This transformation ensures that parameters after the update will be
  larger than or equal to zero.
  In a chain of transformations, this should be the last one.

  WARNING: the transformation expects the input params to be non-negative.
  When params are negative, the transformed update will move them to 0.

  Returns:
    An (init_fn, update_fn) tuple.
  """
    def init_fn(params):
        del params
        return NonNegativeParamsState()

    def update_fn(updates, state, params):
        if params is None:
            raise ValueError(base.NO_PARAMS_MSG)

        updates = jax.tree_multimap(
            lambda p, u: jnp.where((p + u) < 0., -p, u), params, updates)
        return updates, state

    return base.GradientTransformation(init_fn, update_fn)
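
A sketch showing the transformation used last in a chain, as the docstring recommends (assumes `optax.sgd` and `optax.apply_updates`).

import jax.numpy as jnp
import optax

tx = optax.chain(optax.sgd(learning_rate=0.5), optax.keep_params_nonnegative())
params = {'rate': jnp.array([0.1, 0.0])}
state = tx.init(params)
grads = {'rate': jnp.array([1.0, 1.0])}  # a plain SGD step would go negative
updates, state = tx.update(grads, state, params)
params = optax.apply_updates(params, updates)  # parameters are clipped at exactly 0.0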
Example #8
def clip_by_block_rms(threshold: float) -> base.GradientTransformation:
  """Clips updates to a max rms for the gradient of each param vector or matrix.

  A `block` is here a weight vector (e.g. in a Linear layer) or a weight matrix
  (e.g. in a convolutional layer) appearing as a leaf in the grads/param pytree.

  Args:
    threshold: The maximum rms for the gradient of each param vector or matrix.

  Returns:
    An (init_fn, update_fn) tuple.
  """

  def init_fn(params):
    del params
    return base.EmptyState()

  def update_fn(updates, state, params=None):
    del params

    def _clip_fn(u):
      clip_denom = jnp.maximum(
          1.0,
          jnp.sqrt(jnp.mean(numerics.abs_sq(u))) / threshold)
      return u / clip_denom

    updates = jax.tree_map(_clip_fn, updates)
    return updates, state

  return base.GradientTransformation(init_fn, update_fn)
Example #9
def adaptive_grad_clip(clipping, eps=1e-3) -> base.GradientTransformation:
  """Clip updates to be at most clipping * parameter_norm, unit-wise.

  References:
    [Brock, Smith, De, Simonyan 2021] High-Performance Large-Scale Image
    Recognition Without Normalization. (https://arxiv.org/abs/2102.06171)

  Args:
    clipping: Maximum allowed ratio of update norm to parameter norm.
    eps: epsilon term to prevent clipping of zero-initialized params.

  Returns:
    An (init_fn, update_fn) tuple.
  """

  def init_fn(_):
    return AdaptiveGradClipState()

  def update_fn(updates, state, params):
    if params is None:
      raise ValueError(base.NO_PARAMS_MSG)
    g_norm = jax.tree_map(unitwise_norm, updates)
    p_norm = jax.tree_map(unitwise_norm, params)
    # Maximum allowable norm
    max_norm = jax.tree_map(lambda x: clipping * jnp.maximum(x, eps), p_norm)
    # If grad norm > clipping * param_norm, rescale
    updates = jax.tree_multimap(unitwise_clip, g_norm, max_norm, updates)
    return updates, state

  return base.GradientTransformation(init_fn, update_fn)
Example #10
def zero_nans() -> base.GradientTransformation:
    """A transformation which replaces NaNs with 0.

  Zeroing values in gradients is guaranteed to produce a direction of
  non-increasing loss.

  The state of the transformation has the same tree structure as that of the
  parameters. Each leaf is a single boolean which contains True iff a NaN was
  detected in the corresponding parameter array at the last call to `update`.
  This state is not used by the transformation internally, but lets users be
  aware when NaNs have been zeroed out.

  Returns:
    A `GradientTransformation`.
  """
    def init_fn(params):
        return ZeroNansState(
            jax.tree_map(lambda p: jnp.array(False, dtype=jnp.bool_), params))

    def update_fn(updates, opt_state, params=None):
        del params
        opt_state = ZeroNansState(
            jax.tree_map(lambda p: jnp.any(jnp.isnan(p)), updates))
        updates = jax.tree_map(
            lambda p: jnp.where(jnp.isnan(p), jnp.zeros_like(p), p), updates)
        return updates, opt_state

    return base.GradientTransformation(init=init_fn, update=update_fn)
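
A small sketch, assuming the transformation is exposed as `optax.zero_nans`; the `found_nan` field name comes from optax's `ZeroNansState` and is not visible in the snippet above, so treat it as an assumption.

import jax.numpy as jnp
import optax

tx = optax.zero_nans()
params = {'w': jnp.zeros((2,))}
state = tx.init(params)
grads = {'w': jnp.array([1.0, jnp.nan])}
updates, state = tx.update(grads, state)
# updates['w'] is [1., 0.] and state.found_nan['w'] is True, flagging that a
# NaN was zeroed for this leaf on the last update.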
Example #11
def scale_by_rss(initial_accumulator_value: float = 0.1,
                 eps: float = 1e-7) -> base.GradientTransformation:
    """Rescale updates by the root of the sum of all squared gradients to date.

  References:
    [Duchi et al, 2011](https://jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
    [McMahan et al., 2010](https://arxiv.org/abs/1002.4908)

  Args:
    initial_accumulator_value: Starting value for accumulators, must be >= 0.
    eps: A small floating point value to avoid zero denominator.

  Returns:
    An (init_fn, update_fn) tuple.
  """
    def init_fn(params):
        sum_of_squares = jax.tree_map(
            lambda t: jnp.full_like(t, initial_accumulator_value), params)
        return ScaleByRssState(sum_of_squares=sum_of_squares)

    def update_fn(updates, state, params=None):
        del params
        sum_of_squares = jax.tree_multimap(lambda g, t: jnp.square(g) + t,
                                           updates, state.sum_of_squares)
        inv_sqrt_g_square = jax.tree_map(
            lambda t: jnp.where(t > 0, jax.lax.rsqrt(t + eps), 0.0),
            sum_of_squares)
        updates = jax.tree_multimap(lambda scale, g: scale * g,
                                    inv_sqrt_g_square, updates)
        return updates, ScaleByRssState(sum_of_squares=sum_of_squares)

    return base.GradientTransformation(init_fn, update_fn)
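
Composed with a negative step size, this gives an Adagrad-style optimiser; the sketch below mirrors how `optax.adagrad` is typically assembled.

import optax

learning_rate = 0.1
adagrad_like = optax.chain(
    optax.scale_by_rss(initial_accumulator_value=0.1),
    optax.scale(-learning_rate),
)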
Example #12
def scale_by_schedule(
    step_size_fn: base.Schedule
) -> base.GradientTransformation:
  """Scale updates using a custom schedule for the `step_size`.

  Args:
    step_size_fn: a function that takes an update count as input and proposes
      the step_size to multiply the updates by.

  Returns:
    An (init_fn, update_fn) tuple.
  """

  def init_fn(_):
    return ScaleByScheduleState(count=jnp.zeros([], jnp.int32))

  def update_fn(updates, state, params=None):
    del params
    step_size = step_size_fn(state.count)
    updates = jax.tree_map(
        lambda g: jnp.array(step_size, dtype=g.dtype) * g, updates)
    return updates, ScaleByScheduleState(
        count=numerics.safe_int32_increment(state.count))

  return base.GradientTransformation(init_fn, update_fn)
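
A sketch using a simple exponential-decay schedule; any callable from step count to scalar works here, and the negative sign turns the scaled gradients into descent directions.

import optax

tx = optax.chain(
    optax.scale_by_adam(),
    # step_size_fn receives the int32 step counter kept in the state above
    optax.scale_by_schedule(lambda count: -1e-3 * (0.99 ** count)),
)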
Example #13
def chain(*args: base.GradientTransformation) -> base.GradientTransformation:
    """Applies a list of chainable update transformations.

  Given a sequence of chainable transforms, `chain` returns an `init_fn`
  that constructs a `state` by concatenating the states of the individual
  transforms, and returns an `update_fn` which chains the update transformations
  feeding the appropriate state to each.

  Args:
    *args: a sequence of chainable (init_fn, update_fn) tuples.

  Returns:
    A single (init_fn, update_fn) tuple.
  """

    init_fns, update_fns = zip(*args)

    def init_fn(params):
        return tuple(fn(params) for fn in init_fns)

    def update_fn(updates, state, params=None):
        if len(update_fns) != len(state):
            raise ValueError(
                'The number of updates and states has to be the same in '
                'chain! Make sure you have called init first!')

        new_state = []
        for s, fn in zip(state, update_fns):
            updates, new_s = fn(updates, s, params)
            new_state.append(new_s)
        return updates, tuple(new_state)

    return base.GradientTransformation(init_fn, update_fn)
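
A sketch of a clipped SGD-with-momentum optimiser built from chainable parts (assumes the standard `optax.clip`, `optax.trace` and `optax.scale` aliases).

import optax

tx = optax.chain(
    optax.clip(1.0),         # element-wise clipping
    optax.trace(decay=0.9),  # classic momentum accumulator
    optax.scale(-0.01),      # negative step size, i.e. gradient descent
)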
Example #14
def scale_by_stddev(decay: float = 0.9,
                    eps: float = 1e-8,
                    initial_scale: float = 0.) -> base.GradientTransformation:
    """Rescale updates by the root of the centered exp. moving average of squares.

  References:
    [Hinton](https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)

  Args:
    decay: decay rate for the exponentially weighted average of squared grads.
    eps: term added to the denominator to improve numerical stability.
    initial_scale: initial value for the second moment.

  Returns:
    An (init_fn, update_fn) tuple.
  """
    def init_fn(params):
        mu = jax.tree_map(jnp.zeros_like, params)  # First moment
        nu = jax.tree_map(lambda n: jnp.full_like(n, initial_scale),
                          params)  # second moment
        return ScaleByRStdDevState(mu=mu, nu=nu)

    def update_fn(updates, state, params=None):
        del params
        mu = _update_moment(updates, state.mu, decay, 1)
        nu = _update_moment(updates, state.nu, decay, 2)
        updates = jax.tree_multimap(
            lambda g, m, n: g * jax.lax.rsqrt(n - jnp.square(m) + eps),
            updates, mu, nu)
        return updates, ScaleByRStdDevState(mu=mu, nu=nu)

    return base.GradientTransformation(init_fn, update_fn)
Example #15
def apply_if_finite(
        inner: base.GradientTransformation,
        max_consecutive_errors: int) -> base.GradientTransformation:
    """A function that wraps an optimiser to make it robust to a few NaNs or Infs.

  The purpose of this function is to prevent any optimisation from happening if
  the gradients contain NaNs or Infs. That is, when a NaN or Inf is detected in
  the gradients, the wrapped optimiser ignores that gradient update. If the NaNs or
  Infs persist after a given number of updates, the wrapped optimiser gives up
  and accepts the update.

  Args:
    inner: Inner transformation to be wrapped.
    max_consecutive_errors: Maximum number of consecutive gradient updates
      containing NaNs or Infs that the wrapped optimiser will ignore. After
      that many ignored updates, the optimiser will give up and accept.

  Returns:
    New GradientTransformation.
  """
    def init(params):
        return ApplyIfFiniteState(notfinite_count=jnp.zeros([], jnp.int32),
                                  last_finite=jnp.array(True, jnp.bool_),
                                  total_notfinite=jnp.zeros([], jnp.int32),
                                  inner_state=inner.init(params))

    def update(updates, state, params=None):
        inner_state = state.inner_state
        flat_updates = tree_flatten(updates)[0]
        isfinite = jnp.all(
            jnp.array([jnp.all(jnp.isfinite(p)) for p in flat_updates]))
        notfinite_count = jnp.where(
            isfinite, jnp.zeros([], jnp.int32),
            numerics.safe_int32_increment(state.notfinite_count))

        def do_update(_):
            return inner.update(updates, inner_state, params)

        def reject_update(_):
            return (tree_map(jnp.zeros_like, updates), inner_state)

        updates, new_inner_state = lax.cond(jnp.logical_or(
            isfinite, notfinite_count > max_consecutive_errors),
                                            do_update,
                                            reject_update,
                                            operand=None)

        return updates, ApplyIfFiniteState(notfinite_count=notfinite_count,
                                           last_finite=isfinite,
                                           total_notfinite=jnp.where(
                                               isfinite, state.total_notfinite,
                                               numerics.safe_int32_increment(
                                                   state.total_notfinite)),
                                           inner_state=new_inner_state)

    return base.GradientTransformation(init=init, update=update)
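
A sketch wrapping Adam so that non-finite gradients are skipped; the field names used in the comment (`notfinite_count`, `total_notfinite`) are the ones defined in the state above.

import jax.numpy as jnp
import optax

tx = optax.apply_if_finite(optax.adam(1e-3), max_consecutive_errors=5)
params = {'w': jnp.ones((2,))}
state = tx.init(params)
bad_grads = {'w': jnp.array([jnp.inf, 1.0])}
updates, state = tx.update(bad_grads, state, params)
# updates are all zeros; state.notfinite_count and state.total_notfinite are now 1.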
Example #16
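    # Note: `wrapped_transform` below appears to be the inner closure excerpted
    # from optax's `inject_hyperparams` factory; the enclosing scope that
    # defines `inner_factory`, `inner_signature` and `static_args` is not shown.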
    def wrapped_transform(*args, **kwargs) -> base.GradientTransformation:
        bound_arguments = inner_signature.bind(*args, **kwargs)
        bound_arguments.apply_defaults()

        sched_hps, numeric_hps, other_hps = {}, {}, {}
        for name, value in bound_arguments.arguments.items():
            if name in static_args or isinstance(value, bool):
                other_hps[name] = value
            elif callable(value):
                sched_hps[name] = value
            elif isinstance(value, (int, float, chex.Array)):
                numeric_hps[name] = value
            else:
                other_hps[name] = value

        def schedule_fn(count, dtype):
            return {
                k: _convert_floats(f(count), dtype)
                for k, f in sched_hps.items()
            }

        def init_fn(params):
            count = jnp.zeros([], jnp.int32)
            dtype = getattr(next(iter(jax.tree_leaves(params)), None), 'dtype',
                            None)
            hparams = {
                k: jnp.asarray(_convert_floats(v, dtype))
                for k, v in numeric_hps.items()
            }
            hparams.update(schedule_fn(count, dtype))
            return InjectHyperparamsState(  # pylint:disable=too-many-function-args
                count, hparams,
                inner_factory(**other_hps, **hparams).init(params))

        def update_fn(updates, state, params=None):
            count_inc = utils.safe_int32_increment(state.count)
            dtype = getattr(next(iter(jax.tree_leaves(updates)), None),
                            'dtype', None)
            hparams = {
                k: _convert_floats(v, dtype)
                for k, v in state.hyperparams.items()
            }
            hparams.update(schedule_fn(count_inc, dtype))
            updates, inner_state = inner_factory(**other_hps,
                                                 **hparams).update(
                                                     updates,
                                                     state.inner_state, params)

            # pylint:disable=too-many-function-args
            return updates, InjectHyperparamsState(count_inc, hparams,
                                                   inner_state)
            # pylint:enable=too-many-function-args

        return base.GradientTransformation(init_fn, update_fn)
Example #17
def scale_by_radam(b1: float = 0.9,
                   b2: float = 0.999,
                   eps: float = 1e-8,
                   eps_root: float = 0.0,
                   threshold: float = 5.0) -> base.GradientTransformation:
    """Rescale updates according to the Rectified Adam algorithm.

  References:
    [Liu et al, 2020](https://arxiv.org/abs/1908.03265)

  Args:
    b1: decay rate for the exponentially weighted average of grads.
    b2: decay rate for the exponentially weighted average of squared grads.
    eps: term added to the denominator to improve numerical stability.
    eps_root: term added to the denominator inside the square-root to improve
      numerical stability when backpropagating gradients through the rescaling.
    threshold: Threshold for variance tractability

  Returns:
    An (init_fn, update_fn) tuple.
  """

    ro_inf = 2. / (1 - b2) - 1

    def _radam_update(params):
        ro = params[0]
        mu_hat = params[1]
        nu_hat = params[2]
        r = jnp.sqrt(
            (ro - 4) * (ro - 2) * ro_inf / ((ro_inf - 4) * (ro_inf - 2) * ro))
        updates = jax.tree_multimap(
            lambda m, v: r * m / (jnp.sqrt(v + eps_root) + eps), mu_hat,
            nu_hat)
        return updates

    def init_fn(params):
        mu = jax.tree_map(jnp.zeros_like, params)  # First moment
        nu = jax.tree_map(jnp.zeros_like, params)  # Second moment
        return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu)

    def update_fn(updates, state, params=None):
        del params
        mu = _update_moment(updates, state.mu, b1, 1)
        nu = _update_moment(updates, state.nu, b2, 2)
        count_inc = numerics.safe_int32_increment(state.count)
        b2t = b2**count_inc
        ro = ro_inf - 2 * count_inc * b2t / (1 - b2t)
        mu_hat = _bias_correction(mu, b1, count_inc)
        nu_hat = _bias_correction(nu, b2, count_inc)
        updates = jax.lax.cond(ro >= threshold, _radam_update,
                               lambda _: mu_hat, (ro, mu_hat, nu_hat))
        return updates, ScaleByAdamState(count=count_inc, mu=mu, nu=nu)

    return base.GradientTransformation(init_fn, update_fn)
Example #18
def differentially_private_aggregate(l2_norm_clip: float,
                                     noise_multiplier: float,
                                     seed: int) -> base.GradientTransformation:
    """Aggregates gradients based on the DPSGD algorithm.

  WARNING: Unlike other transforms, `differentially_private_aggregate` expects
  the input updates to have a batch dimension in the 0th axis. That is, this
  function expects per-example gradients as input (which are easy to obtain in
  JAX using `jax.vmap`). It can still be composed with other transformations as
  long as it is the first in the chain.

  References:
    [Abadi et al, 2016](https://arxiv.org/abs/1607.00133)

  Args:
    l2_norm_clip: maximum L2 norm of the per-example gradients.
    noise_multiplier: ratio of standard deviation to the clipping norm.
    seed: initial seed used for the jax.random.PRNGKey

  Returns:
    A `GradientTransformation`.
  """
    noise_std = l2_norm_clip * noise_multiplier

    def init_fn(_):
        return DifferentiallyPrivateAggregateState(
            rng_key=jax.random.PRNGKey(seed))

    def update_fn(updates, state, params=None):
        del params
        grads_flat, grads_treedef = jax.tree_flatten(updates)
        bsize = grads_flat[0].shape[0]

        if any(g.ndim == 0 or bsize != g.shape[0] for g in grads_flat):
            raise ValueError(
                'Unlike other transforms, `differentially_private_aggregate` expects'
                ' `updates` to have a batch dimension in the 0th axis. That is, this'
                ' function expects per-example gradients as input.')

        new_key, *rngs = jax.random.split(state.rng_key, len(grads_flat) + 1)
        global_grad_norms = jax.vmap(linear_algebra.global_norm)(grads_flat)
        divisors = jnp.maximum(global_grad_norms / l2_norm_clip, 1.0)
        clipped = [(jnp.moveaxis(g, 0, -1) / divisors).sum(-1)
                   for g in grads_flat]
        noised = [
            (g + noise_std * jax.random.normal(r, g.shape, g.dtype)) / bsize
            for g, r in zip(clipped, rngs)
        ]
        return (jax.tree_unflatten(grads_treedef, noised),
                DifferentiallyPrivateAggregateState(rng_key=new_key))

    return base.GradientTransformation(init_fn, update_fn)
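
A sketch of how per-example gradients can be produced with `jax.vmap` and fed to this transformation (assumed available as `optax.differentially_private_aggregate`); the loss function and shapes are illustrative.

import jax
import jax.numpy as jnp
import optax

def loss_fn(params, x, y):
  pred = x @ params['w']
  return jnp.mean((pred - y) ** 2)

params = {'w': jnp.zeros((4, 1))}
x = jnp.ones((8, 4))  # batch of 8 examples
y = jnp.ones((8, 1))

# Per-example gradients: every leaf keeps the batch dimension in axis 0.
per_example_grads = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))(params, x, y)

tx = optax.differentially_private_aggregate(
    l2_norm_clip=1.0, noise_multiplier=1.1, seed=0)
state = tx.init(params)
updates, state = tx.update(per_example_grads, state)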
Example #19
def flatten(
    inner: base.GradientTransformation
) -> base.GradientTransformation:
  """Flattens parameters and gradients for init and update of inner transform.

  This can reduce the overhead of performing many calculations on lots of small
  variables, at the cost of slightly increased memory usage.

  Args:
    inner: Inner transformation to flatten inputs for.

  Returns:
    New GradientTransformation.
  """

  def _flatten(params):
    """Flattens and concatenates all tensors in params to a single vector."""
    params, _ = tree_flatten(params)
    return jnp.concatenate([jnp.reshape(param, [-1]) for param in params])

  def _unflatten(updates, flat):
    """Extracts tensors from flat, using the structure and shapes of params."""
    updates_flat, treedef = tree_flatten(updates)
    offsets = []
    for update in updates_flat:
      size = np.prod(update.shape)
      if offsets:
        offsets.append(size + offsets[-1])
      else:
        offsets.append(size)
    del offsets[-1]
    flat_split = jnp.split(flat, offsets)
    reshaped = [
        jnp.reshape(flat_update, update.shape)
        for flat_update, update in zip(flat_split, updates_flat)
    ]
    return tree_unflatten(treedef, reshaped)

  def init_fn(params):
    flat = _flatten(params)
    return inner.init(flat)

  def update_fn(updates, state, params=None):
    if params is not None:
      params = _flatten(params)
    updates_flat, state = inner.update(_flatten(updates), state, params)
    updates = _unflatten(updates, updates_flat)
    return updates, state

  return base.GradientTransformation(init_fn, update_fn)
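
A one-line usage sketch: wrapping Adam so that its per-leaf bookkeeping runs on a single concatenated vector (assumes `optax.flatten` and `optax.adam`).

import optax

tx = optax.flatten(optax.adam(learning_rate=1e-3))
# tx.init / tx.update are then used exactly as for the unwrapped optimiser.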
Example #20
def scale_by_yogi(
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-3,
    eps_root: float = 0.0,
    initial_accumulator_value: float = 1e-6
) -> base.GradientTransformation:
  """Rescale updates according to the Yogi algorithm.

  Supports complex numbers, see
  https://gist.github.com/wdphy16/118aef6fb5f82c49790d7678cf87da29

  References:
    [Zaheer et al, 2018](https://papers.nips.cc/paper/2018/hash/90365351ccc7437a1309dc64e4db32a3-Abstract.html) #pylint:disable=line-too-long

  Args:
    b1: decay rate for the exponentially weighted average of grads.
    b2: decay rate for the exponentially weighted average of variance of grads.
    eps: term added to the denominator to improve numerical stability.
    eps_root: term added to the denominator inside the square-root to improve
      numerical stability when backpropagating gradients through the rescaling.
    initial_accumulator_value: The starting value for accumulators.
      Only positive values are allowed.

  Returns:
    An (init_fn, update_fn) tuple.
  """

  def init_fn(params):
    value_like = lambda p: jnp.full_like(p, initial_accumulator_value)
    mu = jax.tree_map(value_like, params)  # First moment
    nu = jax.tree_map(value_like, params)  # Second Central moment
    return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu)

  def update_fn(updates, state, params=None):
    del params
    mu = _update_moment(updates, state.mu, b1, 1)
    nu = jax.tree_multimap(
        lambda g, v: v - (1 - b2) * jnp.sign(v - _abs_sq(g)) * _abs_sq(g),
        updates, state.nu)
    count_inc = numerics.safe_int32_increment(state.count)
    mu_hat = _bias_correction(mu, b1, count_inc)
    nu_hat = _bias_correction(nu, b2, count_inc)
    updates = jax.tree_multimap(
        lambda m, v: m / (jnp.sqrt(v + eps_root) + eps), mu_hat, nu_hat)
    return updates, ScaleByAdamState(count=count_inc, mu=mu, nu=nu)

  return base.GradientTransformation(init_fn, update_fn)
Example #21
def _test_optimizer(step_size: float) -> base.GradientTransformation:
    """Fast optimizer for the lookahead tests."""

    # Use SGD for simplicity but add non-trivial optimizer state so that the
    # resetting behaviour of lookahead can be tested.
    def init_fn(params):
        aggregate_grads = jax.tree_map(jnp.zeros_like, params)
        return TestOptimizerState(aggregate_grads, is_reset=True)

    def update_fn(updates, state, params=None):
        del params  # unused by the test optimizer
        aggregate_grads = update.apply_updates(state.aggregate_grads, updates)
        updates = jax.tree_map(lambda u: step_size * u, updates)
        return updates, TestOptimizerState(aggregate_grads, is_reset=False)

    return base.GradientTransformation(init_fn, update_fn)
Example #22
def scale_by_trust_ratio(
    min_norm: float = 0.0,
    trust_coefficient: float = 1.,
    eps: float = 0.,
) -> base.GradientTransformation:
  """Scale updates by trust ratio`.

  References:
    [You et al, 2020](https://arxiv.org/abs/1904.00962)

  Args:
    min_norm: minimum norm for params and gradient norms; by default is zero.
    trust_coefficient: a multiplier for the trust ratio.
    eps: additive constant added to the denominator for numerical stability.

  Returns:
    An (init_fn, update_fn) tuple.
  """

  def init_fn(params):
    del params
    return ScaleByTrustRatioState()

  def update_fn(updates, state, params):
    if params is None:
      raise ValueError(base.NO_PARAMS_MSG)

    def _scale_update(update, param):

      # Clip norms to minimum value, by default no clipping.
      param_norm = numerics.safe_norm(param, min_norm)
      update_norm = numerics.safe_norm(update, min_norm)
      trust_ratio = trust_coefficient * param_norm / (update_norm + eps)

      # If no minimum-norm clipping is used, set trust_ratio to 1 in cases
      # where the parameters would otherwise never be updated.
      zero_norm = jnp.logical_or(param_norm == 0., update_norm == 0.)
      safe_trust_ratio = jnp.where(
          zero_norm, jnp.array(1.0, dtype=param.dtype), trust_ratio)

      return update * safe_trust_ratio

    updates = jax.tree_multimap(_scale_update, updates, params)
    return updates, state

  return base.GradientTransformation(init_fn, update_fn)
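
A sketch of a LAMB-flavoured composition using the trust-ratio scaling; this roughly mirrors how `optax.lamb` chains these pieces, though the exact alias and defaults may differ slightly.

import optax

lamb_like = optax.chain(
    optax.scale_by_adam(),
    optax.add_decayed_weights(1e-4),
    optax.scale_by_trust_ratio(),
    optax.scale(-1e-3),
)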
Example #23
def scale_by_adam(
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-8,
    eps_root: float = 0.0,
    mu_dtype: Optional[Any] = None,
) -> base.GradientTransformation:
  """Rescale updates according to the Adam algorithm.

  References:
    [Kingma et al, 2014](https://arxiv.org/abs/1412.6980)

  Args:
    b1: decay rate for the exponentially weighted average of grads.
    b2: decay rate for the exponentially weighted average of squared grads.
    eps: term added to the denominator to improve numerical stability.
    eps_root: term added to the denominator inside the square-root to improve
      numerical stability when backpropagating gradients through the rescaling.
    mu_dtype: optional `dtype` to be used for the first order accumulator; if
      `None` then the `dtype` is inferred from `params` and `updates`.

  Returns:
    An (init_fn, update_fn) tuple.
  """

  mu_dtype = utils.canonicalize_dtype(mu_dtype)

  def init_fn(params):
    mu = jax.tree_map(  # First moment
        lambda t: jnp.zeros_like(t, dtype=mu_dtype), params)
    nu = jax.tree_map(jnp.zeros_like, params)  # Second moment
    return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu)

  def update_fn(updates, state, params=None):
    del params
    mu = _update_moment(updates, state.mu, b1, 1)
    nu = _update_moment_per_elem_norm(updates, state.nu, b2, 2)
    count_inc = numerics.safe_int32_increment(state.count)
    mu_hat = utils.cast_tree(_bias_correction(mu, b1, count_inc), mu_dtype)
    nu_hat = _bias_correction(nu, b2, count_inc)
    updates = jax.tree_multimap(
        lambda m, v: m / (jnp.sqrt(v + eps_root) + eps), mu_hat, nu_hat)
    return updates, ScaleByAdamState(count=count_inc, mu=mu, nu=nu)

  return base.GradientTransformation(init_fn, update_fn)
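
The familiar Adam optimiser is just this transformation followed by a negative learning-rate scaling; the sketch below mirrors how `optax.adam` is composed.

import optax

learning_rate = 1e-3
adam_like = optax.chain(
    optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8),
    optax.scale(-learning_rate),
)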
Example #24
def centralize() -> base.GradientTransformation:
    """Centralize gradients.

  References:
    [Yong et al, 2020](https://arxiv.org/abs/2004.01461)

  Returns:
    An (init_fn, update_fn) tuple.
  """
    def init_fn(_):
        return CentralState()

    def update_fn(updates, state, params=None):
        del params
        updates = jax.tree_map(_subtract_mean, updates)
        return updates, state

    return base.GradientTransformation(init_fn, update_fn)
Example #25
def scale(step_size: float) -> base.GradientTransformation:
    """Scale updates by some fixed scalar `step_size`.

  Args:
    step_size: a scalar corresponding to a fixed scaling factor for updates.

  Returns:
    An (init_fn, update_fn) tuple.
  """
    def init_fn(_):
        return ScaleState()

    def update_fn(updates, state, params=None):
        del params
        updates = jax.tree_map(lambda g: step_size * g, updates)
        return updates, state

    return base.GradientTransformation(init_fn, update_fn)
Example #26
def _test_optimizer(step_size: float) -> base.GradientTransformation:
    """Fast optimizer for the lookahead tests."""

    # Use SGD for simplicity but add non-trivial optimizer state so that the
    # resetting behaviour of lookahead can be tested.
    def init_fn(params):
        aggregate_grads = jax.tree_map(jnp.zeros_like, params)
        return TestOptimizerState(aggregate_grads, is_reset=True)

    def update_fn(updates, state, params):
        # The test optimizer does not use the parameters, but we check that they
        # have been passed correctly.
        chex.assert_trees_all_equal_shapes(updates, params)
        aggregate_grads = update.apply_updates(state.aggregate_grads, updates)
        updates = jax.tree_map(lambda u: step_size * u, updates)
        return updates, TestOptimizerState(aggregate_grads, is_reset=False)

    return base.GradientTransformation(init_fn, update_fn)
Example #27
def scale_by_param_norm(
        min_scale: float = 1e-3) -> base.GradientTransformation:
    """Scale updates for each layer by the norm of that layer's parameters.

  Args:
    min_scale: minimum scaling factor.

  Returns:
    An (init_fn, update_fn) tuple.
  """
    def init_fn(_):
        return base.EmptyState()

    def update_fn(updates, state, params):
        updates = jax.tree_multimap(
            lambda u, p: u * numerics.safe_norm(p, min_scale), updates, params)
        return updates, state

    return base.GradientTransformation(init_fn, update_fn)
Example #28
def clip(max_delta) -> base.GradientTransformation:
    """Clip updates element-wise, to be between -max_delta and +max_delta.

  Args:
    max_delta: the maximum absolute value for each element in the update.

  Returns:
    An (init_fn, update_fn) tuple.
  """
    def init_fn(_):
        return ClipState()

    def update_fn(updates, state, params=None):
        del params
        updates = jax.tree_map(lambda g: jnp.clip(g, -max_delta, max_delta),
                               updates)
        return updates, state

    return base.GradientTransformation(init_fn, update_fn)
Example #29
def scale_by_belief(
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-16,
    eps_root: float = 1e-16
) -> base.GradientTransformation:
  """Rescale updates according to the AdaBelief algorithm.

  References:
    [Zhuang et al, 2020](https://arxiv.org/abs/2010.07468)

  Args:
    b1: decay rate for the exponentially weighted average of grads.
    b2: decay rate for the exponentially weighted average of variance of grads.
    eps: term added to the denominator to improve numerical stability.
    eps_root: term added to the second moment of the prediction error to
      improve numerical stability. If backpropagating gradients through the
      gradient transformation (e.g. for meta-learning), this must be non-zero.

  Returns:
    An (init_fn, update_fn) tuple.
  """

  def init_fn(params):
    mu = jax.tree_map(jnp.zeros_like, params)  # First moment
    s = jax.tree_map(jnp.zeros_like, params)  # Second Central moment
    return ScaleByBeliefState(count=jnp.zeros([], jnp.int32), mu=mu, nu=s)

  def update_fn(updates, state, params=None):
    del params
    mu = _update_moment(updates, state.mu, b1, 1)
    prediction_error = jax.tree_multimap(lambda g, m: g-m, updates, state.mu)
    nu = _update_moment_per_elem_norm(prediction_error, state.nu, b2, 2)
    nu = jax.tree_map(lambda v: v + eps_root, nu)
    count_inc = numerics.safe_int32_increment(state.count)
    mu_hat = _bias_correction(mu, b1, count_inc)
    nu_hat = _bias_correction(nu, b2, count_inc)
    updates = jax.tree_multimap(
        lambda m, v: m / (jnp.sqrt(v) + eps), mu_hat, nu_hat)
    return updates, ScaleByBeliefState(count=count_inc, mu=mu, nu=nu)

  return base.GradientTransformation(init_fn, update_fn)
Example #30
def maybe_update(
        inner: base.GradientTransformation,
        should_update_fn: Callable[[Array],
                                   Array]) -> base.GradientTransformation:
    """Calls the inner update function only at certain steps.

  Creates a transformation wrapper which counts the number of times the `update`
  function has been called. This counter is passed to the `should_update_fn` to
  decide when to call the inner update function.

  When not calling the inner update function, the `updates` and the inner state
  are left untouched and just passed through. The step counter is increased
  regardless.

  Args:
    inner: the inner transformation.
    should_update_fn: this function takes in a step counter (array of shape []
      and dtype int32), and returns a boolean array of shape [].

  Returns:
    An `optax.GradientTransformation`.
  """
    def init_fn(params):
        return MaybeUpdateState(inner_state=inner.init(params),
                                step=jnp.zeros([], dtype=jnp.int32))

    def update_fn(updates, state, params=None):
        def do_update(_):
            return inner.update(updates, state.inner_state, params)

        def reject_update(_):
            return updates, state.inner_state

        updates, new_inner_state = lax.cond(should_update_fn(state.step),
                                            do_update,
                                            reject_update,
                                            operand=None)
        return updates, MaybeUpdateState(
            new_inner_state, numerics.safe_int32_increment(state.step))

    return base.GradientTransformation(init_fn, update_fn)
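
A sketch that only applies the inner transformation every 10th step; `optax.maybe_update` is assumed to expose the wrapper above (newer optax versions may use a different name for it).

import optax

tx = optax.maybe_update(
    optax.clip_by_global_norm(1.0),
    # should_update_fn: int32 step counter in, boolean scalar out
    lambda step: step % 10 == 0,
)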