def add_decayed_weights(
    weight_decay: float = 0.0,
    mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None
) -> base.GradientTransformation:
  """Adds the parameters, scaled by `weight_decay`, to the updates.

  Args:
    weight_decay: scalar rate each parameter is multiplied by before being
      added to its update.
    mask: a pytree with the same structure as (or a prefix of) the params
      pytree, or a callable producing such a pytree from the params/updates.
      Leaves are booleans: `True` where the transformation should be applied,
      `False` where it should be skipped.

  Returns:
    An (init_fn, update_fn) tuple.
  """

  def init_fn(_):
    return AddDecayedWeightsState()

  def update_fn(updates, state, params):
    if params is None:
      raise ValueError(base.NO_PARAMS_MSG)

    def _decay(g, p):
      return g + weight_decay * p

    return jax.tree_multimap(_decay, updates, params), state

  if mask is None:
    return base.GradientTransformation(init_fn, update_fn)

  # A mask restricts the decay to selected leaves; e.g. it is common to skip
  # weight decay on bias units and batch stats.
  return wrappers.masked(
      base.GradientTransformation(init_fn, update_fn), mask)
def add_noise(eta: float, gamma: float, seed: int) -> base.GradientTransformation:
  """Add annealed gaussian noise to the gradients.

  At step `t` (counted from 1) the noise added to every gradient element is
  drawn from a normal distribution with zero mean and variance
  `eta / t**gamma`.

  References:
    [Neelakantan et al, 2015](https://arxiv.org/abs/1511.06807)

  Args:
    eta: base variance of the gaussian noise added to the gradient.
    gamma: decay exponent for annealing of the variance.
    seed: seed for random number generation.

  Returns:
    An (init_fn, update_fn) tuple.
  """

  def init_fn(_):
    return AddNoiseState(
        count=jnp.zeros([], jnp.int32), rng_key=jax.random.PRNGKey(seed))

  def update_fn(updates, state, params=None):  # pylint: disable=missing-docstring
    del params
    num_vars = len(jax.tree_leaves(updates))
    treedef = jax.tree_structure(updates)
    count_inc = numerics.safe_int32_increment(state.count)
    # `eta / t**gamma` is a *variance*. Samples from `jax.random.normal` have
    # unit standard deviation, so they must be multiplied by the square root
    # of the variance; multiplying by the variance itself would yield noise
    # whose variance is `variance**2`.
    standard_deviation = jnp.sqrt(eta / count_inc**gamma)
    # One key per leaf, plus one carried over to the next step.
    all_keys = jax.random.split(state.rng_key, num=num_vars + 1)
    noise = jax.tree_multimap(
        lambda g, k: jax.random.normal(k, shape=g.shape, dtype=g.dtype),
        updates, jax.tree_unflatten(treedef, all_keys[1:]))
    updates = jax.tree_multimap(
        lambda g, n: g + standard_deviation.astype(g.dtype) * n,
        updates, noise)
    return updates, AddNoiseState(count=count_inc, rng_key=all_keys[0])

  return base.GradientTransformation(init_fn, update_fn)
def scale_by_param_block_norm(
    min_scale: float = 1e-3
) -> base.GradientTransformation:
  """Multiplies each block of updates by the norm of the matching params.

  A `block` is a single leaf of the grads/param pytree, e.g. a weight vector
  (in a Linear layer) or a weight matrix (in a convolutional layer).

  Args:
    min_scale: minimum scaling factor.

  Returns:
    An (init_fn, update_fn) tuple.
  """

  def init_fn(params):
    del params
    return base.EmptyState()

  def update_fn(updates, state, params):
    if params is None:
      raise ValueError(base.NO_PARAMS_MSG)

    def _rescale(u, p):
      return u * numerics.safe_norm(p, min_scale)

    return jax.tree_multimap(_rescale, updates, params), state

  return base.GradientTransformation(init_fn, update_fn)
def apply_every(k: int = 1) -> base.GradientTransformation:
  """Accumulates gradients, emitting the accumulated value every k steps.

  When this transformation is one link of a chain, the states of the other
  transformations are still updated at every step. In particular, using
  `apply_every` with a batch size of N/2 and k=2 is not necessarily equivalent
  to not using `apply_every` with a batch size of N. If this equivalence is
  important for you, consider using the `optax.MultiSteps`.

  Args:
    k: emit non-zero gradients every k steps, otherwise accumulate them.

  Returns:
    An (init_fn, update_fn) tuple.
  """

  def init_fn(params):
    return ApplyEvery(
        count=jnp.zeros([], jnp.int32),
        grad_acc=jax.tree_map(jnp.zeros_like, params))

  def update_fn(updates, state, params=None):
    del params
    phase = state.count % k
    # At phase 0 the previous accumulator is dropped (multiplied by False).
    keep_previous = phase != 0
    accumulated = jax.tree_multimap(
        lambda g, ga: keep_previous * ga + g, updates, state.grad_acc)
    # Only the last phase of a cycle lets the accumulated gradient through.
    is_last_phase = phase == (k - 1)
    emitted = jax.tree_map(lambda ga: is_last_phase * ga, accumulated)
    next_count = numerics.safe_int32_increment(state.count) % k
    return emitted, ApplyEvery(count=next_count, grad_acc=accumulated)

  return base.GradientTransformation(init_fn, update_fn)
def scale_by_param_block_rms(
    min_scale: float = 1e-3
) -> base.GradientTransformation:
  """Scale updates by the rms of the parameters of each param vector or matrix.

  A `block` is here a weight vector (e.g. in a Linear layer) or a weight matrix
  (e.g. in a convolutional layer) appearing as a leaf in the grads/param pytree.

  Args:
    min_scale: minimum scaling factor.

  Returns:
    An (init_fn, update_fn) tuple.

  Raises:
    ValueError: from `update_fn`, if it is called with `params=None`.
  """

  def init_fn(_):
    return base.EmptyState()

  def update_fn(updates, state, params):
    # The scale is computed from the parameters, so they are mandatory; fail
    # with the same message as the other param-dependent transformations.
    if params is None:
      raise ValueError(base.NO_PARAMS_MSG)
    updates = jax.tree_map(
        lambda u, p: u * numerics.safe_root_mean_squares(p, min_scale),
        updates, params)
    return updates, state

  return base.GradientTransformation(init_fn, update_fn)
def clip_by_global_norm(max_norm) -> base.GradientTransformation:
  """Rescales updates so that their global norm does not exceed `max_norm`.

  References:
    [Pascanu et al, 2012](https://arxiv.org/abs/1211.5063)

  Args:
    max_norm: the maximum global norm for an update.

  Returns:
    An (init_fn, update_fn) tuple.
  """

  def init_fn(_):
    return ClipByGlobalNormState()

  def update_fn(updates, state, params=None):
    del params
    g_norm = linear_algebra.global_norm(updates)
    # TODO(b/163995078): revert back to the following (faster) implementation
    # once analysed how it affects backprop through update (e.g. meta-gradients)
    # g_norm = jnp.maximum(max_norm, g_norm)
    # updates = jax.tree_map(lambda t: (t / g_norm) * max_norm, updates)
    within_bound = g_norm < max_norm

    def _clip(t):
      # Leaves pass through untouched when the global norm is already small.
      return jnp.where(within_bound, t, (t / g_norm) * max_norm)

    return jax.tree_map(_clip, updates), state

  return base.GradientTransformation(init_fn, update_fn)
def keep_params_nonnegative() -> base.GradientTransformation:
  """Adjusts the updates so that parameters stay non-negative, i.e. >= 0.

  Wherever `param + update` would be negative, the update is replaced by
  `-param`, which brings that parameter to exactly zero. In a chain of
  transformations, this should be the last one.

  WARNING: the transformation expects input params to be non-negative. When
  params is negative the transformed update will move them to 0.

  Returns:
    An (init_fn, update_fn) tuple.
  """

  def init_fn(params):
    del params
    return NonNegativeParamsState()

  def update_fn(updates, state, params):
    if params is None:
      raise ValueError(base.NO_PARAMS_MSG)

    def _bounded_update(p, u):
      would_go_negative = (p + u) < 0.
      return jnp.where(would_go_negative, -p, u)

    return jax.tree_multimap(_bounded_update, params, updates), state

  return base.GradientTransformation(init_fn, update_fn)
def clip_by_block_rms(threshold: float) -> base.GradientTransformation:
  """Clips each block of updates to a maximum root-mean-square value.

  A `block` is a single leaf of the grads/param pytree, e.g. a weight vector
  (in a Linear layer) or a weight matrix (in a convolutional layer).

  Args:
    threshold: The maximum rms for the gradient of each param vector or matrix.

  Returns:
    An (init_fn, update_fn) tuple.
  """

  def init_fn(params):
    del params
    return base.EmptyState()

  def update_fn(updates, state, params=None):
    del params

    def _clip_fn(u):
      block_rms = jnp.sqrt(jnp.mean(numerics.abs_sq(u)))
      # A denominator of at least 1 means blocks under the threshold are kept.
      return u / jnp.maximum(1.0, block_rms / threshold)

    return jax.tree_map(_clip_fn, updates), state

  return base.GradientTransformation(init_fn, update_fn)
def adaptive_grad_clip(clipping, eps=1e-3) -> base.GradientTransformation:
  """Clips updates, unit-wise, to at most `clipping * parameter_norm`.

  References:
    [Brock, Smith, De, Simonyan 2021] High-Performance Large-Scale Image
    Recognition Without Normalization. (https://arxiv.org/abs/2102.06171)

  Args:
    clipping: Maximum allowed ratio of update norm to parameter norm.
    eps: epsilon term to prevent clipping of zero-initialized params.

  Returns:
    An (init_fn, update_fn) tuple.
  """

  def init_fn(_):
    return AdaptiveGradClipState()

  def update_fn(updates, state, params):
    if params is None:
      raise ValueError(base.NO_PARAMS_MSG)

    def _max_allowed_norm(param):
      # The parameter norm is floored at `eps` before being scaled.
      return clipping * jnp.maximum(unitwise_norm(param), eps)

    update_norms = jax.tree_map(unitwise_norm, updates)
    norm_limits = jax.tree_map(_max_allowed_norm, params)
    # Rescaling of over-sized updates is delegated to `unitwise_clip`.
    clipped = jax.tree_multimap(
        unitwise_clip, update_norms, norm_limits, updates)
    return clipped, state

  return base.GradientTransformation(init_fn, update_fn)
def zero_nans() -> base.GradientTransformation:
  """A transformation which replaces NaNs with 0.

  Zeroing values in gradients is guaranteed to produce a direction of
  non-increasing loss.

  The state has the same tree structure as the parameters. Each leaf is a
  single boolean which is True iff a NaN was found in the corresponding array
  at the last call to `update`. The transformation itself never reads this
  state; it only lets users know when NaNs have been zeroed out.

  Returns:
    A `GradientTransformation`.
  """

  def _no_nan_seen(p):
    del p
    return jnp.array(False, dtype=jnp.bool_)

  def _has_nan(p):
    return jnp.any(jnp.isnan(p))

  def _replace_nans(p):
    return jnp.where(jnp.isnan(p), jnp.zeros_like(p), p)

  def init_fn(params):
    return ZeroNansState(jax.tree_map(_no_nan_seen, params))

  def update_fn(updates, opt_state, params=None):
    del params
    # The incoming state is discarded; flags are recomputed from `updates`.
    opt_state = ZeroNansState(jax.tree_map(_has_nan, updates))
    return jax.tree_map(_replace_nans, updates), opt_state

  return base.GradientTransformation(init=init_fn, update=update_fn)
def scale_by_rss(
    initial_accumulator_value: float = 0.1,
    eps: float = 1e-7
) -> base.GradientTransformation:
  """Rescales updates by the root of the running sum of squared gradients.

  References:
    [Duchi et al, 2011](https://jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
    [McMahan et al., 2010](https://arxiv.org/abs/1002.4908)

  Args:
    initial_accumulator_value: Starting value for accumulators, must be >= 0.
    eps: A small floating point value to avoid zero denominator.

  Returns:
    An (init_fn, update_fn) tuple.
  """

  def init_fn(params):
    def _initial_accumulator(t):
      return jnp.full_like(t, initial_accumulator_value)

    return ScaleByRssState(
        sum_of_squares=jax.tree_map(_initial_accumulator, params))

  def update_fn(updates, state, params=None):
    del params
    accumulated = jax.tree_multimap(
        lambda g, t: jnp.square(g) + t, updates, state.sum_of_squares)

    def _inverse_root(t):
      # Entries whose accumulator is not positive get a zero scale.
      return jnp.where(t > 0, jax.lax.rsqrt(t + eps), 0.0)

    scales = jax.tree_map(_inverse_root, accumulated)
    rescaled = jax.tree_multimap(
        lambda scale, g: scale * g, scales, updates)
    return rescaled, ScaleByRssState(sum_of_squares=accumulated)

  return base.GradientTransformation(init_fn, update_fn)
def scale_by_schedule(
    step_size_fn: base.Schedule
) -> base.GradientTransformation:
  """Multiplies updates by a `step_size` obtained from a custom schedule.

  Args:
    step_size_fn: a function that takes an update count as input and proposes
      the step_size to multiply the updates by.

  Returns:
    An (init_fn, update_fn) tuple.
  """

  def init_fn(_):
    return ScaleByScheduleState(count=jnp.zeros([], jnp.int32))

  def update_fn(updates, state, params=None):
    del params
    # The schedule is queried with the count *before* it is incremented.
    step_size = step_size_fn(state.count)

    def _scale(g):
      return jnp.array(step_size, dtype=g.dtype) * g

    next_state = ScaleByScheduleState(
        count=numerics.safe_int32_increment(state.count))
    return jax.tree_map(_scale, updates), next_state

  return base.GradientTransformation(init_fn, update_fn)
def chain(*args: base.GradientTransformation) -> base.GradientTransformation:
  """Applies a list of chainable update transformations.

  Given a sequence of chainable transforms, `chain` returns an `init_fn`
  that constructs a `state` by concatenating the states of the individual
  transforms, and returns an `update_fn` which chains the update transformations
  feeding the appropriate state to each.

  Calling `chain()` with no transforms yields the identity transformation: its
  state is the empty tuple and updates are returned unchanged.

  Args:
    *args: a sequence of chainable (init_fn, update_fn) tuples.

  Returns:
    A single (init_fn, update_fn) tuple.

  Raises:
    ValueError: from `update_fn`, if the state passed in does not hold exactly
      one entry per chained transformation.
  """

  # `zip(*args)` produces no items at all for an empty chain, so unpacking it
  # into two names would fail with an unrelated-looking error.
  init_fns, update_fns = zip(*args) if args else ((), ())

  def init_fn(params):
    return tuple(fn(params) for fn in init_fns)

  def update_fn(updates, state, params=None):
    if len(update_fns) != len(state):
      raise ValueError(
          'The number of updates and states has to be the same in '
          'chain! Make sure you have called init first!')

    new_state = []
    # Each transformation consumes the updates produced by the previous one.
    for s, fn in zip(state, update_fns):
      updates, new_s = fn(updates, s, params)
      new_state.append(new_s)
    return updates, tuple(new_state)

  return base.GradientTransformation(init_fn, update_fn)
def scale_by_stddev(
    decay: float = 0.9,
    eps: float = 1e-8,
    initial_scale: float = 0.
) -> base.GradientTransformation:
  """Rescales updates by the root of the centered moving average of squares.

  References:
    [Hinton](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)

  Args:
    decay: decay rate for the exponentially weighted average of squared grads.
    eps: term added to the denominator to improve numerical stability.
    initial_scale: initial value for second moment

  Returns:
    An (init_fn, update_fn) tuple.
  """

  def init_fn(params):
    first_moment = jax.tree_map(jnp.zeros_like, params)
    second_moment = jax.tree_map(
        lambda n: jnp.full_like(n, initial_scale), params)
    return ScaleByRStdDevState(mu=first_moment, nu=second_moment)

  def update_fn(updates, state, params=None):
    del params
    mu = _update_moment(updates, state.mu, decay, 1)
    nu = _update_moment(updates, state.nu, decay, 2)

    def _normalize(g, m, n):
      # `n - m**2` is the centered second moment.
      return g * jax.lax.rsqrt(n - jnp.square(m) + eps)

    rescaled = jax.tree_multimap(_normalize, updates, mu, nu)
    return rescaled, ScaleByRStdDevState(mu=mu, nu=nu)

  return base.GradientTransformation(init_fn, update_fn)
def apply_if_finite(
    inner: base.GradientTransformation,
    max_consecutive_errors: int
) -> base.GradientTransformation:
  """Wraps an optimiser to make it robust to a few NaNs or Infs.

  While the gradients contain a NaN or an Inf, the wrapped optimiser is not
  called: zero updates are emitted and the inner state is left as it was. If
  the non-finite values persist for more than `max_consecutive_errors`
  consecutive updates, the wrapper gives up and forwards the update to the
  inner optimiser anyway.

  Args:
    inner: Inner transformation to be wrapped.
    max_consecutive_errors: Maximum number of consecutive gradient updates
      containing NaNs or Infs that the wrapped optimiser will ignore. After
      that many ignored updates, the optimiser will give up and accept.

  Returns:
    New GradientTransformation.
  """

  def init(params):
    return ApplyIfFiniteState(
        notfinite_count=jnp.zeros([], jnp.int32),
        last_finite=jnp.array(True, jnp.bool_),
        total_notfinite=jnp.zeros([], jnp.int32),
        inner_state=inner.init(params))

  def update(updates, state, params=None):
    inner_state = state.inner_state
    leaves, _ = tree_flatten(updates)
    isfinite = jnp.all(
        jnp.array([jnp.all(jnp.isfinite(leaf)) for leaf in leaves]))
    # The consecutive counter resets on a finite update; the total never does.
    notfinite_count = jnp.where(
        isfinite,
        jnp.zeros([], jnp.int32),
        numerics.safe_int32_increment(state.notfinite_count))
    total_notfinite = jnp.where(
        isfinite,
        state.total_notfinite,
        numerics.safe_int32_increment(state.total_notfinite))

    def do_update(_):
      return inner.update(updates, inner_state, params)

    def reject_update(_):
      return (tree_map(jnp.zeros_like, updates), inner_state)

    should_apply = jnp.logical_or(
        isfinite, notfinite_count > max_consecutive_errors)
    updates, new_inner_state = lax.cond(
        should_apply, do_update, reject_update, operand=None)

    return updates, ApplyIfFiniteState(
        notfinite_count=notfinite_count,
        last_finite=isfinite,
        total_notfinite=total_notfinite,
        inner_state=new_inner_state)

  return base.GradientTransformation(init=init, update=update)
def wrapped_transform(*args, **kwargs) -> base.GradientTransformation:
  """Builds a transformation whose hyperparameters are kept in its state.

  The arguments are bound against `inner_signature` (defaults applied) and
  sorted into three groups: callables are treated as schedules evaluated on
  the step count, ints/floats/arrays as numeric hyperparameters, and anything
  else (including booleans and names listed in `static_args`) is passed to
  `inner_factory` untouched.

  Args:
    *args: positional arguments for the inner factory.
    **kwargs: keyword arguments for the inner factory.

  Returns:
    An (init_fn, update_fn) tuple.
  """
  bound = inner_signature.bind(*args, **kwargs)
  bound.apply_defaults()

  sched_hps, numeric_hps, other_hps = {}, {}, {}
  for name, value in bound.arguments.items():
    # Order matters: static names and booleans must be claimed before the
    # callable / numeric checks get a chance to see them.
    if name in static_args or isinstance(value, bool):
      bucket = other_hps
    elif callable(value):
      bucket = sched_hps
    elif isinstance(value, (int, float, chex.Array)):
      bucket = numeric_hps
    else:
      bucket = other_hps
    bucket[name] = value

  def _first_leaf_dtype(tree):
    # `None` when the tree has no leaves or the first leaf has no `dtype`.
    first_leaf = next(iter(jax.tree_leaves(tree)), None)
    return getattr(first_leaf, 'dtype', None)

  def schedule_fn(count, dtype):
    evaluated = {}
    for k, f in sched_hps.items():
      evaluated[k] = _convert_floats(f(count), dtype)
    return evaluated

  def init_fn(params):
    count = jnp.zeros([], jnp.int32)
    dtype = _first_leaf_dtype(params)
    hparams = {}
    for k, v in numeric_hps.items():
      hparams[k] = jnp.asarray(_convert_floats(v, dtype))
    hparams.update(schedule_fn(count, dtype))
    inner_state = inner_factory(**other_hps, **hparams).init(params)
    return InjectHyperparamsState(  # pylint:disable=too-many-function-args
        count, hparams, inner_state)

  def update_fn(updates, state, params=None):
    count_inc = utils.safe_int32_increment(state.count)
    dtype = _first_leaf_dtype(updates)
    hparams = {}
    for k, v in state.hyperparams.items():
      hparams[k] = _convert_floats(v, dtype)
    # Scheduled values override whatever was stored for the same name.
    hparams.update(schedule_fn(count_inc, dtype))
    inner_transform = inner_factory(**other_hps, **hparams)
    updates, inner_state = inner_transform.update(
        updates, state.inner_state, params)

    # pylint:disable=too-many-function-args
    return updates, InjectHyperparamsState(count_inc, hparams, inner_state)
    # pylint:enable=too-many-function-args

  return base.GradientTransformation(init_fn, update_fn)
def scale_by_radam(
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-8,
    eps_root: float = 0.0,
    threshold: float = 5.0
) -> base.GradientTransformation:
  """Rescales updates according to the Rectified Adam algorithm.

  References:
    [Liu et al, 2020](https://arxiv.org/abs/1908.03265)

  Args:
    b1: decay rate for the exponentially weighted average of grads.
    b2: decay rate for the exponentially weighted average of squared grads.
    eps: term added to the denominator to improve numerical stability.
    eps_root: term added to the denominator inside the square-root to improve
      numerical stability when backpropagating gradients through the rescaling.
    threshold: Threshold for variance tractability

  Returns:
    An (init_fn, update_fn) tuple.
  """

  ro_inf = 2. / (1 - b2) - 1

  def _radam_update(params):
    # `params` is the (ro, mu_hat, nu_hat) operand handed over by `lax.cond`.
    ro, mu_hat, nu_hat = params
    numerator = (ro - 4) * (ro - 2) * ro_inf
    denominator = (ro_inf - 4) * (ro_inf - 2) * ro
    r = jnp.sqrt(numerator / denominator)
    return jax.tree_multimap(
        lambda m, v: r * m / (jnp.sqrt(v + eps_root) + eps), mu_hat, nu_hat)

  def init_fn(params):
    return ScaleByAdamState(
        count=jnp.zeros([], jnp.int32),
        mu=jax.tree_map(jnp.zeros_like, params),
        nu=jax.tree_map(jnp.zeros_like, params))

  def update_fn(updates, state, params=None):
    del params
    mu = _update_moment(updates, state.mu, b1, 1)
    nu = _update_moment(updates, state.nu, b2, 2)
    count_inc = numerics.safe_int32_increment(state.count)
    b2t = b2**count_inc
    ro = ro_inf - 2 * count_inc * b2t / (1 - b2t)
    mu_hat = _bias_correction(mu, b1, count_inc)
    nu_hat = _bias_correction(nu, b2, count_inc)
    # Below the threshold only the bias-corrected first moment is returned.
    rescaled = jax.lax.cond(
        ro >= threshold, _radam_update, lambda _: mu_hat,
        (ro, mu_hat, nu_hat))
    return rescaled, ScaleByAdamState(count=count_inc, mu=mu, nu=nu)

  return base.GradientTransformation(init_fn, update_fn)
def differentially_private_aggregate(
    l2_norm_clip: float,
    noise_multiplier: float,
    seed: int
) -> base.GradientTransformation:
  """Aggregates gradients based on the DPSGD algorithm.

  WARNING: Unlike other transforms, `differentially_private_aggregate` expects
  the input updates to have a batch dimension in the 0th axis. That is, this
  function expects per-example gradients as input (which are easy to obtain in
  JAX using `jax.vmap`). It can still be composed with other transformations as
  long as it is the first in the chain.

  References:
    [Abadi et al, 2016](https://arxiv.org/abs/1607.00133)

  Args:
    l2_norm_clip: maximum L2 norm of the per-example gradients.
    noise_multiplier: ratio of standard deviation to the clipping norm.
    seed: initial seed used for the jax.random.PRNGKey

  Returns:
    A `GradientTransformation`.

  Raises:
    ValueError: from `update_fn`, if `updates` has no leaves, if any leaf is a
      scalar, or if the leaves disagree on the size of their leading axis.
  """
  noise_std = l2_norm_clip * noise_multiplier

  def init_fn(_):
    return DifferentiallyPrivateAggregateState(
        rng_key=jax.random.PRNGKey(seed))

  def update_fn(updates, state, params=None):
    del params
    grads_flat, grads_treedef = jax.tree_flatten(updates)

    # Validate *before* reading the batch size: indexing `shape[0]` of a
    # scalar leaf (or `grads_flat[0]` of an empty tree) would otherwise raise
    # an IndexError instead of the explanatory ValueError below.
    has_batch_axis = bool(grads_flat) and all(g.ndim > 0 for g in grads_flat)
    if not has_batch_axis or any(
        g.shape[0] != grads_flat[0].shape[0] for g in grads_flat):
      raise ValueError(
          'Unlike other transforms, `differentially_private_aggregate` expects'
          ' `updates` to have a batch dimension in the 0th axis. That is, this'
          ' function expects per-example gradients as input.')
    bsize = grads_flat[0].shape[0]

    # One key per leaf for the noise, plus one carried over to the next step.
    new_key, *rngs = jax.random.split(state.rng_key, len(grads_flat) + 1)
    global_grad_norms = jax.vmap(linear_algebra.global_norm)(grads_flat)
    # Examples whose norm is already within the clip keep a divisor of 1.
    divisors = jnp.maximum(global_grad_norms / l2_norm_clip, 1.0)
    # The batch axis is moved last so it broadcasts against `divisors`, then
    # it is summed away.
    clipped = [(jnp.moveaxis(g, 0, -1) / divisors).sum(-1) for g in grads_flat]
    noised = [
        (g + noise_std * jax.random.normal(r, g.shape, g.dtype)) / bsize
        for g, r in zip(clipped, rngs)
    ]
    return (jax.tree_unflatten(grads_treedef, noised),
            DifferentiallyPrivateAggregateState(rng_key=new_key))

  return base.GradientTransformation(init_fn, update_fn)
def flatten(
    inner: base.GradientTransformation
) -> base.GradientTransformation:
  """Flattens parameters and gradients for init and update of inner transform.

  This can reduce the overhead of performing many calculations on lots of small
  variables, at the cost of slightly increased memory usage.

  Args:
    inner: Inner transformation to flatten inputs for.

  Returns:
    New GradientTransformation.
  """

  def _flatten(params):
    """Flattens and concatenates all tensors in params to a single vector."""
    params, _ = tree_flatten(params)
    return jnp.concatenate([jnp.reshape(param, [-1]) for param in params])

  def _unflatten(updates, flat):
    """Extracts tensors from flat, using the structure and shapes of params."""
    updates_flat, treedef = tree_flatten(updates)
    # Split points are the cumulative sizes of every leaf but the last. The
    # sizes are forced to Python ints because `np.prod(())` (the size of a
    # scalar leaf) is the float 1.0, which would otherwise turn the offsets
    # into floats.
    offsets = []
    end = 0
    for update in updates_flat[:-1]:
      end += int(np.prod(update.shape))
      offsets.append(end)
    flat_split = jnp.split(flat, offsets)
    reshaped = [
        jnp.reshape(flat_update, update.shape)
        for flat_update, update in zip(flat_split, updates_flat)
    ]
    return tree_unflatten(treedef, reshaped)

  def init_fn(params):
    flat = _flatten(params)
    return inner.init(flat)

  def update_fn(updates, state, params=None):
    if params is not None:
      params = _flatten(params)
    updates_flat, state = inner.update(_flatten(updates), state, params)
    # The original `updates` only supply the tree structure and leaf shapes.
    updates = _unflatten(updates, updates_flat)
    return updates, state

  return base.GradientTransformation(init_fn, update_fn)
def scale_by_yogi(
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-3,
    eps_root: float = 0.0,
    initial_accumulator_value: float = 1e-6
) -> base.GradientTransformation:
  """Rescales updates according to the Yogi algorithm.

  Supports complex numbers, see
  https://gist.github.com/wdphy16/118aef6fb5f82c49790d7678cf87da29

  References:
    [Zaheer et al, 2018](https://papers.nips.cc/paper/2018/hash/90365351ccc7437a1309dc64e4db32a3-Abstract.html) #pylint:disable=line-too-long

  Args:
    b1: decay rate for the exponentially weighted average of grads.
    b2: decay rate for the exponentially weighted average of variance of grads.
    eps: term added to the denominator to improve numerical stability.
    eps_root: term added to the denominator inside the square-root to improve
      numerical stability when backpropagating gradients through the rescaling.
    initial_accumulator_value: The starting value for accumulators.
      Only positive values are allowed.

  Returns:
    An (init_fn, update_fn) tuple.
  """

  def init_fn(params):
    def _filled(p):
      return jnp.full_like(p, initial_accumulator_value)

    # Both moments start from the same constant value.
    return ScaleByAdamState(
        count=jnp.zeros([], jnp.int32),
        mu=jax.tree_map(_filled, params),
        nu=jax.tree_map(_filled, params))

  def update_fn(updates, state, params=None):
    del params

    def _yogi_second_moment(g, v):
      # The sign term decides whether the accumulator grows or shrinks.
      return v - (1 - b2) * jnp.sign(v - _abs_sq(g)) * _abs_sq(g)

    mu = _update_moment(updates, state.mu, b1, 1)
    nu = jax.tree_multimap(_yogi_second_moment, updates, state.nu)
    count_inc = numerics.safe_int32_increment(state.count)
    mu_hat = _bias_correction(mu, b1, count_inc)
    nu_hat = _bias_correction(nu, b2, count_inc)
    rescaled = jax.tree_multimap(
        lambda m, v: m / (jnp.sqrt(v + eps_root) + eps), mu_hat, nu_hat)
    return rescaled, ScaleByAdamState(count=count_inc, mu=mu, nu=nu)

  return base.GradientTransformation(init_fn, update_fn)
def _test_optimizer(step_size: float) -> base.GradientTransformation:
  """Fast optimizer for the lookahead tests.

  It behaves like SGD (updates are multiplied by `step_size`) but also keeps
  a running aggregate of the gradients plus an `is_reset` flag, so that the
  resetting behaviour of lookahead has some non-trivial state to act on.

  Args:
    step_size: factor the updates are multiplied by.

  Returns:
    An (init_fn, update_fn) tuple.
  """

  def init_fn(params):
    zeros = jax.tree_map(jnp.zeros_like, params)
    return TestOptimizerState(zeros, is_reset=True)

  def update_fn(updates, state, params=None):
    del params  # unused by the test optimizer
    # The aggregate is built from the raw, unscaled updates.
    new_aggregate = update.apply_updates(state.aggregate_grads, updates)
    scaled = jax.tree_map(lambda u: step_size * u, updates)
    return scaled, TestOptimizerState(new_aggregate, is_reset=False)

  return base.GradientTransformation(init_fn, update_fn)
def scale_by_trust_ratio(
    min_norm: float = 0.0,
    trust_coefficient: float = 1.,
    eps: float = 0.,
) -> base.GradientTransformation:
  """Scales each update by the ratio of its param norm to its update norm.

  References:
    [You et. al 2020](https://arxiv.org/abs/1904.00962)

  Args:
    min_norm: minimum norm for params and gradient norms; by default is zero.
    trust_coefficient: a multiplier for the trust ratio.
    eps: additive constant added to the denominator for numerical stability.

  Returns:
    An (init_fn, update_fn) tuple.
  """

  def init_fn(params):
    del params
    return ScaleByTrustRatioState()

  def update_fn(updates, state, params):
    if params is None:
      raise ValueError(base.NO_PARAMS_MSG)

    def _scale_update(update, param):
      # Both norms are clipped from below at `min_norm` (no clipping by
      # default).
      param_norm = numerics.safe_norm(param, min_norm)
      update_norm = numerics.safe_norm(update, min_norm)
      ratio = trust_coefficient * param_norm / (update_norm + eps)

      # With no minimum norm, a zero norm on either side would mean the
      # parameters never get updated; fall back to a ratio of 1 in that case.
      either_is_zero = jnp.logical_or(param_norm == 0., update_norm == 0.)
      fallback = jnp.array(1.0, dtype=param.dtype)
      return update * jnp.where(either_is_zero, fallback, ratio)

    return jax.tree_multimap(_scale_update, updates, params), state

  return base.GradientTransformation(init_fn, update_fn)
def scale_by_adam(
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-8,
    eps_root: float = 0.0,
    mu_dtype: Optional[Any] = None,
) -> base.GradientTransformation:
  """Rescales updates according to the Adam algorithm.

  References:
    [Kingma et al, 2014](https://arxiv.org/abs/1412.6980)

  Args:
    b1: decay rate for the exponentially weighted average of grads.
    b2: decay rate for the exponentially weighted average of squared grads.
    eps: term added to the denominator to improve numerical stability.
    eps_root: term added to the denominator inside the square-root to improve
      numerical stability when backpropagating gradients through the rescaling.
    mu_dtype: optional `dtype` to be used for the first order accumulator; if
      `None` then the `dtype` is inferred from `params` and `updates`.

  Returns:
    An (init_fn, update_fn) tuple.
  """

  mu_dtype = utils.canonicalize_dtype(mu_dtype)

  def init_fn(params):
    def _zeros_for_mu(t):
      return jnp.zeros_like(t, dtype=mu_dtype)

    first_moment = jax.tree_map(_zeros_for_mu, params)
    second_moment = jax.tree_map(jnp.zeros_like, params)
    return ScaleByAdamState(
        count=jnp.zeros([], jnp.int32), mu=first_moment, nu=second_moment)

  def update_fn(updates, state, params=None):
    del params
    mu = _update_moment(updates, state.mu, b1, 1)
    nu = _update_moment_per_elem_norm(updates, state.nu, b2, 2)
    count_inc = numerics.safe_int32_increment(state.count)
    # Only the bias-corrected first moment is cast to `mu_dtype`.
    mu_hat = utils.cast_tree(_bias_correction(mu, b1, count_inc), mu_dtype)
    nu_hat = _bias_correction(nu, b2, count_inc)

    def _adam_step(m, v):
      return m / (jnp.sqrt(v + eps_root) + eps)

    rescaled = jax.tree_multimap(_adam_step, mu_hat, nu_hat)
    return rescaled, ScaleByAdamState(count=count_inc, mu=mu, nu=nu)

  return base.GradientTransformation(init_fn, update_fn)
def centralize() -> base.GradientTransformation:
  """Centralizes gradients by applying `_subtract_mean` to every leaf.

  References:
    [Yong et al, 2020](https://arxiv.org/abs/2004.01461)

  Returns:
    An (init_fn, update_fn) tuple.
  """

  def init_fn(_):
    return CentralState()

  def update_fn(updates, state, params=None):
    del params
    centered = jax.tree_map(_subtract_mean, updates)
    return centered, state

  return base.GradientTransformation(init_fn, update_fn)
def scale(step_size: float) -> base.GradientTransformation:
  """Multiplies every update by the fixed scalar `step_size`.

  Args:
    step_size: a scalar corresponding to a fixed scaling factor for updates.

  Returns:
    An (init_fn, update_fn) tuple.
  """

  def init_fn(_):
    return ScaleState()

  def update_fn(updates, state, params=None):
    del params

    def _scaled(g):
      return step_size * g

    return jax.tree_map(_scaled, updates), state

  return base.GradientTransformation(init_fn, update_fn)
def _test_optimizer(step_size: float) -> base.GradientTransformation:
  """Fast optimizer for the lookahead tests.

  It behaves like SGD (updates are multiplied by `step_size`) but also keeps
  a running aggregate of the gradients plus an `is_reset` flag, so that the
  resetting behaviour of lookahead has some non-trivial state to act on.

  Args:
    step_size: factor the updates are multiplied by.

  Returns:
    An (init_fn, update_fn) tuple.
  """

  def init_fn(params):
    zeros = jax.tree_map(jnp.zeros_like, params)
    return TestOptimizerState(zeros, is_reset=True)

  def update_fn(updates, state, params):
    # The parameters play no part in the computation; the shape assertion only
    # verifies that they were passed through correctly.
    chex.assert_trees_all_equal_shapes(updates, params)
    # The aggregate is built from the raw, unscaled updates.
    new_aggregate = update.apply_updates(state.aggregate_grads, updates)
    scaled = jax.tree_map(lambda u: step_size * u, updates)
    return scaled, TestOptimizerState(new_aggregate, is_reset=False)

  return base.GradientTransformation(init_fn, update_fn)
def scale_by_param_norm(
    min_scale: float = 1e-3) -> base.GradientTransformation:
  """Scale updates for each layer by the norm of that layer's parameters.

  Args:
    min_scale: minimum scaling factor.

  Returns:
    An (init_fn, update_fn) tuple.

  Raises:
    ValueError: from `update_fn`, if it is called with `params=None`.
  """

  def init_fn(_):
    return base.EmptyState()

  def update_fn(updates, state, params):
    # The scale is computed from the parameters, so they are mandatory; fail
    # with the same message as the other param-dependent transformations.
    if params is None:
      raise ValueError(base.NO_PARAMS_MSG)
    updates = jax.tree_multimap(
        lambda u, p: u * numerics.safe_norm(p, min_scale),
        updates, params)
    return updates, state

  return base.GradientTransformation(init_fn, update_fn)
def clip(max_delta) -> base.GradientTransformation:
  """Clips updates element-wise to the range [-max_delta, +max_delta].

  Args:
    max_delta: the maximum absolute value for each element in the update.

  Returns:
    An (init_fn, update_fn) tuple.
  """

  def init_fn(_):
    return ClipState()

  def update_fn(updates, state, params=None):
    del params

    def _clip(g):
      return jnp.clip(g, -max_delta, max_delta)

    return jax.tree_map(_clip, updates), state

  return base.GradientTransformation(init_fn, update_fn)
def scale_by_belief(
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-16,
    eps_root: float = 1e-16
) -> base.GradientTransformation:
  """Rescales updates according to the AdaBelief algorithm.

  References:
    [Zhuang et al, 2020](https://arxiv.org/abs/2010.07468)

  Args:
    b1: decay rate for the exponentially weighted average of grads.
    b2: decay rate for the exponentially weighted average of variance of grads.
    eps: term added to the denominator to improve numerical stability.
    eps_root: term added to the second moment of the prediction error to
      improve numerical stability. If backpropagating gradients through the
      gradient transformation (e.g. for meta-learning), this must be non-zero.

  Returns:
    An (init_fn, update_fn) tuple.
  """

  def init_fn(params):
    return ScaleByBeliefState(
        count=jnp.zeros([], jnp.int32),
        mu=jax.tree_map(jnp.zeros_like, params),
        nu=jax.tree_map(jnp.zeros_like, params))

  def update_fn(updates, state, params=None):
    del params
    mu = _update_moment(updates, state.mu, b1, 1)
    # The prediction error is measured against the *previous* first moment.
    prediction_error = jax.tree_multimap(
        lambda g, m: g - m, updates, state.mu)
    nu = _update_moment_per_elem_norm(prediction_error, state.nu, b2, 2)
    # `eps_root` is folded into the stored second moment itself.
    nu = jax.tree_map(lambda v: v + eps_root, nu)
    count_inc = numerics.safe_int32_increment(state.count)
    mu_hat = _bias_correction(mu, b1, count_inc)
    nu_hat = _bias_correction(nu, b2, count_inc)

    def _belief_step(m, v):
      return m / (jnp.sqrt(v) + eps)

    rescaled = jax.tree_multimap(_belief_step, mu_hat, nu_hat)
    return rescaled, ScaleByBeliefState(count=count_inc, mu=mu, nu=nu)

  return base.GradientTransformation(init_fn, update_fn)
def maybe_update(
    inner: base.GradientTransformation,
    should_update_fn: Callable[[Array], Array]
) -> base.GradientTransformation:
  """Calls the inner update function only at certain steps.

  The wrapper counts how many times `update` has been called and hands that
  counter to `should_update_fn`, which decides whether the inner update runs.
  When it does not run, the `updates` and the inner state are passed through
  untouched. The step counter is increased regardless.

  Args:
    inner: the inner transformation.
    should_update_fn: this function takes in a step counter (array of shape []
      and dtype int32), and returns a boolean array of shape [].

  Returns:
    An `optax.GradientTransformation`.
  """

  def init_fn(params):
    return MaybeUpdateState(
        inner_state=inner.init(params),
        step=jnp.zeros([], dtype=jnp.int32))

  def update_fn(updates, state, params=None):

    def do_update(_):
      return inner.update(updates, state.inner_state, params)

    def reject_update(_):
      return updates, state.inner_state

    # The decision is based on the step count before it is incremented.
    run_inner = should_update_fn(state.step)
    updates, new_inner_state = lax.cond(
        run_inner, do_update, reject_update, operand=None)
    next_step = numerics.safe_int32_increment(state.step)
    return updates, MaybeUpdateState(new_inner_state, next_step)

  return base.GradientTransformation(init_fn, update_fn)