def bias_correction(
    decay: float = 0.9,
    accumulator_dtype: Optional[Any] = None
) -> base.GradientTransformation:
  """Compute the Adam-style bias correction.

  Args:
    decay: the decay rate for the exponential moving average.
    accumulator_dtype: optional `dtype` to be used for the accumulator; if
      `None` then the `dtype` is inferred from `params` and `updates`.

  Returns:
    An (init_fn, update_fn) tuple.
  """
  accumulator_dtype = utils.canonicalize_dtype(accumulator_dtype)

  def init_fn(params):
    del params
    return BiasCorrectionState(count=jnp.zeros([], jnp.int32))

  def update_fn(updates, state, params=None):
    del params
    count_inc = numerics.safe_int32_increment(state.count)
    new_vals = _bias_correction(updates, decay, count_inc)
    return new_vals, BiasCorrectionState(count=count_inc)

  return base.GradientTransformation(init_fn, update_fn)
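# A minimal sketch of the bias-correction computation assumed above. The name
# `_bias_correction_sketch` is hypothetical; the module's own `_bias_correction`
# helper is defined elsewhere. A zero-initialised moving average underestimates
# the true moment by a factor of (1 - decay**count), so the correction divides
# that factor out; it becomes a no-op as count grows.
def _bias_correction_sketch(moment, decay, count):
  beta = 1 - decay**count
  return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment)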
def scale_by_adam(
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-8,
    eps_root: float = 0.0,
    mu_dtype: Optional[Any] = None,
) -> base.GradientTransformation:
  """Rescale updates according to the Adam algorithm.

  References:
    [Kingma et al, 2014](https://arxiv.org/abs/1412.6980)

  Args:
    b1: decay rate for the exponentially weighted average of grads.
    b2: decay rate for the exponentially weighted average of squared grads.
    eps: term added to the denominator to improve numerical stability.
    eps_root: term added to the denominator inside the square-root to improve
      numerical stability when backpropagating gradients through the rescaling.
    mu_dtype: optional `dtype` to be used for the first-order accumulator; if
      `None` then the `dtype` is inferred from `params` and `updates`.

  Returns:
    An (init_fn, update_fn) tuple.
  """
  mu_dtype = utils.canonicalize_dtype(mu_dtype)

  def init_fn(params):
    mu = jax.tree_map(  # First moment
        lambda t: jnp.zeros_like(t, dtype=mu_dtype), params)
    nu = jax.tree_map(jnp.zeros_like, params)  # Second moment
    return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu)

  def update_fn(updates, state, params=None):
    del params
    mu = _update_moment(updates, state.mu, b1, 1)
    nu = _update_moment_per_elem_norm(updates, state.nu, b2, 2)
    count_inc = numerics.safe_int32_increment(state.count)
    mu_hat = _bias_correction(mu, b1, count_inc)
    nu_hat = _bias_correction(nu, b2, count_inc)
    updates = jax.tree_multimap(
        lambda m, v: m / (jnp.sqrt(v + eps_root) + eps), mu_hat, nu_hat)
    # Cast the stored accumulator (not the debiased output) so that the
    # first moment actually stays in `mu_dtype` across steps.
    mu = utils.cast_tree(mu, mu_dtype)
    return updates, ScaleByAdamState(count=count_inc, mu=mu, nu=nu)

  return base.GradientTransformation(init_fn, update_fn)
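# Usage sketch (hypothetical, not part of the library): `scale_by_adam` only
# rescales the gradients, so composing its output with a negative step size
# yields the familiar Adam parameter update.
def _scale_by_adam_example(learning_rate=1e-3):
  tx = scale_by_adam(b1=0.9, b2=0.999)
  params = {'w': jnp.zeros(3)}
  grads = {'w': jnp.array([0.1, 0.2, -0.3])}
  state = tx.init(params)
  updates, state = tx.update(grads, state)
  # Gradient-descent step using the rescaled updates.
  return jax.tree_multimap(lambda p, u: p - learning_rate * u, params, updates)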
def ema(
    decay: float,
    debias: bool = True,
    accumulator_dtype: Optional[Any] = None
) -> base.GradientTransformation:
  """Compute an exponential moving average of past updates.

  Note: `trace` and `ema` have very similar but distinct updates;
  `ema = decay * ema + (1-decay) * t`, while `trace = decay * trace + t`.
  Both are frequently found in the optimisation literature.

  Args:
    decay: the decay rate for the exponential moving average.
    debias: whether to debias the transformed gradient.
    accumulator_dtype: optional `dtype` to be used for the accumulator; if
      `None` then the `dtype` is inferred from `params` and `updates`.

  Returns:
    An (init_fn, update_fn) tuple.
  """
  accumulator_dtype = utils.canonicalize_dtype(accumulator_dtype)

  def init_fn(params):
    return EmaState(
        count=jnp.zeros([], jnp.int32),
        ema=jax.tree_map(
            lambda t: jnp.zeros_like(t, dtype=accumulator_dtype), params))

  def update_fn(updates, state, params=None):
    del params
    new_ema = _update_moment(updates, state.ema, decay, order=1)
    count_inc = numerics.safe_int32_increment(state.count)
    if debias:
      updates = _bias_correction(new_ema, decay, count_inc)
    else:
      updates = new_ema
    # Store the *undebiased* average; debiasing only rescales the output,
    # and must not feed back into the accumulator.
    state_ema = utils.cast_tree(new_ema, accumulator_dtype)
    return updates, EmaState(count=count_inc, ema=state_ema)

  return base.GradientTransformation(init_fn, update_fn)
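# A small sketch (hypothetical) of the effect of `debias`: after one step a
# zero-initialised EMA holds (1 - decay) * g, and bias correction divides by
# (1 - decay**1), recovering g exactly.
def _ema_debias_example():
  g = {'w': jnp.array([1.0])}
  debiased = ema(decay=0.9, debias=True)
  raw = ema(decay=0.9, debias=False)
  out_debiased, _ = debiased.update(g, debiased.init(g))  # == g
  out_raw, _ = raw.update(g, raw.init(g))                 # == 0.1 * g
  return out_debiased, out_raw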
def trace(
    decay: float,
    nesterov: bool = False,
    accumulator_dtype: Optional[Any] = None,
) -> base.GradientTransformation:
  """Compute a trace of past updates.

  Note: `trace` and `ema` have very similar but distinct updates;
  `trace = decay * trace + t`, while `ema = decay * ema + (1-decay) * t`.
  Both are frequently found in the optimisation literature.

  Args:
    decay: the decay rate for the tracing of past updates.
    nesterov: whether to use Nesterov momentum.
    accumulator_dtype: optional `dtype` to be used for the accumulator; if
      `None` then the `dtype` is inferred from `params` and `updates`.

  Returns:
    An (init_fn, update_fn) tuple.
  """
  accumulator_dtype = utils.canonicalize_dtype(accumulator_dtype)

  def init_fn(params):
    return TraceState(
        trace=jax.tree_map(
            lambda t: jnp.zeros_like(t, dtype=accumulator_dtype), params))

  def update_fn(updates, state, params=None):
    del params
    f = lambda g, t: g + decay * t
    new_trace = jax.tree_multimap(f, updates, state.trace)
    updates = (
        jax.tree_multimap(f, updates, new_trace) if nesterov else new_trace)
    new_trace = utils.cast_tree(new_trace, accumulator_dtype)
    return updates, TraceState(trace=new_trace)

  return base.GradientTransformation(init_fn, update_fn)
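# A small sketch (hypothetical) contrasting `trace` with `ema` on a constant
# gradient g: after two steps `trace` holds g + decay * g, whereas `ema` would
# hold (1 - decay) * (g + decay * g). `trace` is the conventional heavy-ball
# momentum buffer.
def _trace_example():
  g = {'w': jnp.array([1.0])}
  tx = trace(decay=0.9)
  state = tx.init(g)
  out1, state = tx.update(g, state)  # trace = g
  out2, state = tx.update(g, state)  # trace = g + 0.9 * g = 1.9 * g
  return out1, out2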