def bias_correction(
    decay: float = 0.9,
    accumulator_dtype: Optional[Any] = None,
) -> optax.GradientTransformation:
    """Divide incoming updates by the Adam-style bias-correction factor.

    Every call to `update_fn` advances an int32 step counter (saturating via
    `utils.safe_int32_increment`) and passes the updates through
    `_bias_correction(updates, decay, step)`.

    Args:
        decay: decay rate of the exponential moving average whose bias is
            being corrected.
        accumulator_dtype: optional `dtype`. It is canonicalized with
            `utils.canonicalize_dtype` but is not otherwise applied by this
            transformation.

    Returns:
        An (init_fn, update_fn) tuple.
    """
    # NOTE(review): the canonicalized dtype is never used below; confirm
    # whether the corrected updates were meant to be cast to it.
    accumulator_dtype = utils.canonicalize_dtype(accumulator_dtype)

    def init_fn(params):
        del params  # Only a step counter is tracked.
        zero_step = jnp.zeros([], jnp.int32)
        return BiasCorrectionState(count=zero_step)

    def update_fn(updates, state, params=None):
        del params
        step = utils.safe_int32_increment(state.count)
        corrected = _bias_correction(updates, decay, step)
        return corrected, BiasCorrectionState(count=step)

    return optax.GradientTransformation(init_fn, update_fn)
def scale_by_adam(b1: float = 0.9,
                  b2: float = 0.999,
                  eps: float = 1e-8,
                  eps_root: float = 0.0,
                  debias: bool = True) -> optax.GradientTransformation:
    """Rescale updates according to the Adam algorithm.

    Args:
        b1: decay rate for the exponentially weighted average of grads.
        b2: decay rate for the exponentially weighted average of squared
            grads.
        eps: term added to the denominator to improve numerical stability.
        eps_root: term added to the denominator inside the square-root to
            improve numerical stability when backpropagating gradients through
            the rescaling.
        debias: whether to use bias correction.

    Returns:
        An (init_fn, update_fn) tuple.
    """

    def init_fn(params):
        mu = jax.tree_util.tree_map(jnp.zeros_like, params)  # First moment
        nu = jax.tree_util.tree_map(jnp.zeros_like, params)  # Second moment
        return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu)

    def update_fn(updates, state, params=None):
        del params
        mu = _update_moment(updates, state.mu, b1, 1)
        nu = _update_moment(updates, state.nu, b2, 2)
        count = state.count + jnp.array(1, dtype=jnp.int32)
        mu_hat = _bias_correction(mu, b1, count) if debias else mu
        nu_hat = _bias_correction(nu, b2, count) if debias else nu
        # `jax.tree_multimap` has been removed from JAX; `tree_map` accepts
        # several trees with the same semantics.
        updates = jax.tree_util.tree_map(
            lambda m, v: m / (jnp.sqrt(v + eps_root) + eps), mu_hat, nu_hat)
        return updates, ScaleByAdamState(count=count, mu=mu, nu=nu)

    return optax.GradientTransformation(init_fn, update_fn)
def adaptive_grad_clip(clip, eps=1e-3) -> optax.GradientTransformation:
    """Clip updates to be at most `clip * parameter_norm`, unit-wise.

    References:
        [Brock, Smith, De, Simonyan 2021] High-Performance Large-Scale Image
        Recognition Without Normalization.

    Args:
        clip: Maximum allowed ratio of update norm to parameter norm.
        eps: lower bound applied to each parameter norm, so that
            zero-initialized parameters are not clipped to zero.

    Returns:
        An (init_fn, update_fn) tuple. `update_fn` raises ValueError when it
        is called without `params`.
    """

    def init_fn(_):
        # The transformation is stateless; an empty optax state is reused.
        return optax.ClipByGlobalNormState()

    def update_fn(updates, state, params=None):
        if params is None:
            raise ValueError(
                'adaptive_grad_clip requires `params` to be passed to update.')
        g_norm = jax.tree_util.tree_map(unitwise_norm, updates)
        p_norm = jax.tree_util.tree_map(unitwise_norm, params)
        # Maximum allowable norm per unit.
        max_norm = jax.tree_util.tree_map(
            lambda x: clip * jnp.maximum(x, eps), p_norm)
        # `my_clip` (defined elsewhere) presumably rescales an update whose
        # norm exceeds `max_norm`; verify against its definition.
        updates = jax.tree_util.tree_map(my_clip, g_norm, max_norm, updates)
        return updates, state

    return optax.GradientTransformation(init_fn, update_fn)
def add_weight_decay(
        weight_decay: float,
        filter_fn: Optional[FilterFn] = None) -> optax.GradientTransformation:
    """Adds `weight_decay * params` to the updates.

    Args:
        weight_decay: weight decay coefficient.
        filter_fn: an optional filter function, forwarded to
            `_partial_update` together with the original and decayed updates.
            NOTE(review): which leaves it selects is defined by
            `_partial_update`, which is not visible here.

    Returns:
        An (init_fn, update_fn) tuple. `update_fn` raises ValueError when it
        is called without `params`.
    """

    def init_fn(_) -> AddWeightDecayState:
        return AddWeightDecayState()

    def update_fn(
        updates: optax.Updates,
        state: AddWeightDecayState,
        params: Optional[optax.Params] = None,
    ) -> Tuple[optax.Updates, AddWeightDecayState]:
        if params is None:
            raise ValueError(
                'add_weight_decay requires `params` to be passed to update.')
        new_updates = jax.tree_util.tree_map(
            lambda g, p: g + weight_decay * p, updates, params)
        new_updates = _partial_update(updates, new_updates, params, filter_fn)
        return new_updates, state

    return optax.GradientTransformation(init_fn, update_fn)
def scaled_sgld(key: np.ndarray,
                schedule_fn: Callable = optax.constant_schedule(1.)):
    """Scaled SGLD with a schedule-controlled step size.

    Each update returns `-stepsize * updates` with noise scaled by
    `sqrt(2 * stepsize)` added via `add_noise`, where
    `stepsize = schedule_fn(count)`. (Presumably `add_noise` draws standard
    normal noise; verify against its definition.)

    Args:
        key: PRNG key used for the noise; it is split on every update.
        schedule_fn: maps the step count to a step size (constant 1 by
            default).

    Returns:
        An (init_fn, update_fn) tuple.
    """

    def init_fn(params):
        del params  # The state does not depend on the parameters.
        return ScaledSGLDState(count=0, key=key)

    def update_fn(updates, state, params=None):
        del params
        count, rng = state
        stepsize = schedule_fn(count)
        updates = jax.tree_util.tree_map(lambda g: -stepsize * g, updates)
        rng, subkey = random.split(rng)
        # TODO: either throw error when stepsize < 0 or put abs(stepsize)
        # under the square root.
        # `jnp.sqrt` (not `np.sqrt`) so a traced step size works under jit.
        noise_scale = jnp.sqrt(2 * stepsize)
        new_state = ScaledSGLDState(count=count + 1, key=rng)
        return add_noise(subkey, updates, noise_scale), new_state

    return optax.GradientTransformation(init_fn, update_fn)
def rexp_updater(
    exponent: float = 0.5,
    eps: float = 1e-8,
    eps_root: float = 0.0,
    moment: int = 2,
    use_accumulated_gradient: bool = False,
) -> optax.GradientTransformation:
    """Divide gradients by a power of an accumulated moment.

    Leaf-wise, writes
    `updates['output'] = g / ((v + eps_root) ** exponent + eps)`, where
    `v = updates['moments'][str(moment)]`. `g` is `updates['updates']`, or
    `updates['moments']['1']` when `use_accumulated_gradient` is True. The
    `updates` dict is modified in place and returned.

    Args:
        exponent: power applied to the moment in the denominator.
        eps: term added outside the power for numerical stability.
        eps_root: term added inside the power for numerical stability.
        moment: which entry of `updates['moments']` to divide by.
        use_accumulated_gradient: use the accumulated first moment instead of
            the raw gradients as the numerator.

    Returns:
        A stateless (init, update) transformation.
    """

    def init(params: optax.Params) -> optax.OptState:
        del params
        return None

    def update(
        updates: optax.Updates,
        state: optax.OptState,
        params: Optional[optax.Params] = None
    ) -> Tuple[optax.Updates, optax.OptState]:
        del params
        if use_accumulated_gradient:
            grads = updates['moments']['1']
        else:
            grads = updates['updates']
        updates['output'] = jax.tree_util.tree_map(
            lambda u, v: u / (jnp.power(v + eps_root, exponent) + eps),
            grads, updates['moments'][str(moment)])
        return updates, state

    return optax.GradientTransformation(init, update)
def split_complex(
        inner: optax.GradientTransformation) -> optax.GradientTransformation:
    """Wrap `inner` so that it only sees real-valued leaves.

    Params and updates are mapped leaf-wise through `_complex_to_real_pair`
    before reaching `inner`. The updates that `inner` produces are mapped back
    through `_real_pair_to_complex`, treating `_is_real_pair` nodes as leaves.
    Parameters that are real before `split_complex` are passed through
    unmodified.

    Args:
        inner: The inner transformation.

    Returns:
        An `optax.GradientTransformation` whose state is a
        `SplitComplexState` holding the inner state.
    """

    def as_real(tree):
        return jax.tree_map(_complex_to_real_pair, tree)

    def init_fn(params):
        return SplitComplexState(inner.init(as_real(params)))

    def update_fn(updates, state, params=None):
        real_updates, next_inner_state = inner.update(
            as_real(updates), state.inner_state, as_real(params))
        merged = jax.tree_map(
            _real_pair_to_complex, real_updates, is_leaf=_is_real_pair)
        return merged, SplitComplexState(next_inner_state)

    return optax.GradientTransformation(init_fn, update_fn)
def scale_by_variable_opt(multipliers):
    """Custom learning rates for different variables.

    Args:
        multipliers: a pytree with the same structure as `params`. Each leaf
            can be either a float or an array shape-compatible with the
            corresponding `params` element. These multiply the learning rate
            for each leaf.

    Returns:
        optax.GradientTransformation optimizer. Its `init_fn` raises
        AssertionError when `multipliers` and `params` differ in structure.
    """

    def init_fn(params):
        params_struct = jax.tree_util.tree_map(lambda _: None, params)
        multipliers_struct = jax.tree_util.tree_map(lambda _: None, multipliers)
        # Explicit raise instead of `assert` so the check survives `python -O`;
        # the exception type is unchanged for existing callers.
        if not params_struct == multipliers_struct:
            raise AssertionError('multipliers should have same struct as params')
        return None

    def update_fn(updates, _, params=None):
        del params  # Unused.
        scaled_updates = jax.tree_util.tree_map(
            lambda a, g: a * g, multipliers, updates)
        return scaled_updates, None

    return optax.GradientTransformation(init_fn, update_fn)
def ema_accumulator(decay: float = 0.999,
                    debias: bool = False) -> optax.GradientTransformation:
    """Accumulate an EMA of `updates['variables']` into `updates['moments']`.

    The state is a `(moments, count)` tuple. When `debias` is True, the stored
    moments are divided by `1 - decay**count` before being exposed in
    `updates['moments']`. The `updates` dict is modified in place and returned.
    """

    def init(params: optax.Params) -> optax.OptState:
        zero_moments = jax.tree_map(jnp.zeros_like, params)
        return zero_moments, jnp.array(0, dtype=jnp.int32)

    def update(
        updates: optax.Updates,
        state: optax.OptState,
        params: Optional[optax.Params] = None
    ) -> Tuple[optax.Updates, optax.OptState]:
        del params
        previous, step = state
        moments = jax.tree_map(
            lambda g, t: (1 - decay) * g + decay * t,
            updates['variables'], previous)
        step = step + jnp.array(1, dtype=jnp.int32)
        if debias:
            correction = jnp.array(1, dtype=jnp.int32) - decay**step
            updates['moments'] = jax.tree_map(
                lambda t: t / correction.astype(t.dtype), moments)
        else:
            updates['moments'] = moments
        return updates, (moments, step)

    return optax.GradientTransformation(init, update)
def add_weight_decay(
        weight_decay: float,
        exclude_names: Optional[List[str]] = None) -> optax.GradientTransformation:
    """Add parameter scaled by `weight_decay` to the `updates`.

    Same as optax.add_decayed_weights but can exclude parameters by name.

    Args:
        weight_decay: weight_decay coefficient.
        exclude_names: an optional list of names to exclude for weight_decay.
            ['b'] by default.

    Returns:
        An (init_fn, update_fn) tuple. `update_fn` raises ValueError when it
        is called without `params`.
    """

    def init_fn(_):
        return AddWeightDecayState()

    def update_fn(updates, state, params=None):
        if params is None:
            raise ValueError(
                'add_weight_decay requires `params` to be passed to update.')
        exclude = _weight_decay_exclude(exclude_names=exclude_names)
        # Split into excluded / included parts; only the included part decays.
        u_ex, u_in = hk.data_structures.partition(exclude, updates)
        _, p_in = hk.data_structures.partition(exclude, params)
        u_in = jax.tree_util.tree_map(
            lambda g, p: g + weight_decay * p, u_in, p_in)
        updates = hk.data_structures.merge(u_ex, u_in)
        return updates, state

    return optax.GradientTransformation(init_fn, update_fn)
def precondition_by_amsgrad(
    b2: float = 0.999,
    eps: float = 1e-8,
    eps_root: float = 0.0,
) -> optax.GradientTransformation:
    """Rescale updates according to the AMSGrad algorithm.

    References:
        [Reddi et al, 2018](https://arxiv.org/abs/1904.09237v1)

    Args:
        b2: decay rate for the exponentially weighted average of squared
            grads.
        eps: term added to the denominator to improve numerical stability.
        eps_root: term added to the denominator inside the square-root to
            improve numerical stability when backpropagating gradients through
            the rescaling.

    Returns:
        An (init_fn, update_fn) tuple.
    """

    def init_fn(params):
        return PreconditionBySecondMomentCoordinateWiseState(
            count=jnp.zeros([], jnp.int32),
            nu=jax.tree_util.tree_map(jnp.zeros_like, params))

    def update_fn(updates, state, params=None):
        del params
        nu = _update_moment(updates, state.nu, b2, 2)
        count = state.count + jnp.array(1, dtype=jnp.int32)
        # NOTE(review): the state stores the raw EMA `nu`, so this is the max
        # of the current and the previous EMA only, not the running maximum
        # max(nu_hat_{t-1}, nu_t) of Reddi et al. Confirm this is intended.
        nu_hat = jax.tree_util.tree_map(jnp.maximum, nu, state.nu)
        updates = jax.tree_util.tree_map(
            lambda m, v: m / (jnp.sqrt(v + eps_root) + eps), updates, nu_hat)
        return updates, PreconditionBySecondMomentCoordinateWiseState(
            count=count, nu=nu)

    return optax.GradientTransformation(init_fn, update_fn)
def nth_power(
        power: Union[int, Tuple[int, ...]] = 2) -> optax.GradientTransformation:
    """Create nth power(s) from gradients.

    For each requested power `p`, writes the leaf-wise power
    `updates['updates'] ** p` into `updates['variables'][str(int(p))]`. For
    `p == 1` the gradients are stored unchanged. The `updates` dict is
    modified in place and returned.

    Args:
        power: an integer, or an iterable of integers (integral floats such
            as `2.0` pass validation).

    Returns:
        A stateless (init, update) transformation.

    Raises:
        ValueError: if any requested power is not integral.
    """
    if not hasattr(power, '__iter__'):
        power = [power]
    # Materialize once: a generator would otherwise be exhausted by the
    # validation loop and yield no powers at update time.
    power = tuple(power)
    for p in power:
        if p != int(p):
            raise ValueError(
                f'Currently we only support integer orders; got {p}.')

    def init(params: optax.Params) -> optax.OptState:
        del params
        return None

    def update(
        updates: optax.Updates,
        state: optax.OptState,
        params: Optional[optax.Params] = None
    ) -> Tuple[optax.Updates, optax.OptState]:
        del params
        for p in power:
            key = str(int(p))
            if p == 1:
                updates['variables'][key] = updates['updates']
            else:
                # Bind `p` as a default to avoid late-binding closure issues.
                updates['variables'][key] = jax.tree_util.tree_map(
                    lambda x, p=p: x**p, updates['updates'])
        return updates, state

    return optax.GradientTransformation(init, update)
def scale_selected_parameters(regex, multiplier):
    """Creates an optimizer that multiplies a selection of weights by `multiplier`.

    Args:
        regex: The regex that matches the (flattened) parameter names whose
            learning rate should be scaled (matched with `re.match`, i.e. at
            the start of the name).
        multiplier: The scaling factor applied to matching parameters; all
            other parameters get 1.

    Returns:
        A chainable optimizer.
    """

    def init_fn(params):
        pattern = re.compile(regex)
        flatten_params = parameter_overview.flatten_dict(params)
        # Match each name once and reuse the result for logging.
        matched = {k for k in flatten_params if pattern.match(k)}
        multipliers = {
            k: multiplier if k in matched else 1. for k in flatten_params
        }
        logging.info(
            "Optimizer uses multiplier %f for weights: %s", multiplier,
            sorted(matched))
        multipliers = unflatten_dict(multipliers)
        return MultipliersState(multipliers)

    def update_fn(updates, state, params=None):
        del params
        multiplied_updates = jax.tree_util.tree_map(
            lambda m, update: jax.tree_util.tree_map(lambda u: u * m, update),
            state.multipliers, updates)
        return multiplied_updates, state

    return optax.GradientTransformation(init_fn, update_fn)
def decoupled_weight_decay(decay, step_size_fn):
    """Subtracts `decay * step_size_fn(count) * params` from the updates.

    Args:
        decay: the decay coefficient for weight decay.
        step_size_fn: a function that takes an update count as input and
            proposes the step_size to multiply the params by.

    Returns:
        An (init_fn, update_fn) tuple. The state records the update count and
        the last effective step size (`step_size_fn(count) * decay`).
        `update_fn` raises ValueError when it is called without `params`.
    """

    def init_fn(_):
        return DecoupledWeightDecayState(count=jnp.zeros([], jnp.int32),
                                         step_size=jnp.zeros([], jnp.float32))

    def update_fn(updates, state, params=None):
        if params is None:
            raise ValueError(
                'decoupled_weight_decay requires `params` to be passed to '
                'update.')
        step_size = step_size_fn(state.count) * decay
        updates = jax.tree_util.tree_map(
            lambda u, p: u - step_size * p, updates, params)
        # Saturating increment: stop at int32 max instead of wrapping around.
        max_int32_value = jnp.iinfo(jnp.int32).max
        new_count = jnp.where(state.count < max_int32_value,
                              state.count + 1, max_int32_value)
        new_state = DecoupledWeightDecayState(count=new_count,
                                              step_size=step_size)
        return updates, new_state

    return optax.GradientTransformation(init_fn, update_fn)
def twisted_adam():
    """Adam-like transformation with RMS preconditioning applied before momentum.

    Only the `'w'` entry of the updates is handled. Per step:
      1. `nu` is an EMA of `w**2` with `rms_decay`;
      2. `w` is divided by `sqrt(nu + eps_root) + eps`, then multiplied by
         `sqrt(1 - rms_decay**count)`;
      3. `trace` is an EMA of that result with `moment_decay`;
      4. the output is `trace / (1 - moment_decay**count)`.

    NOTE(review): `rms_decay`, `moment_decay`, `eps`, `eps_root` and `State`
    are not defined in this block; presumably they come from an enclosing or
    module scope. Confirm before reusing.
    """

    def init_fn(params):
        return State(nu=jax.tree_map(jnp.zeros_like, params),
                     trace=jax.tree_map(jnp.zeros_like, params),
                     count=jnp.zeros([], jnp.int32))

    def update_fn(updates, state, params=None):
        del params
        count = state.count + jnp.array(1, jnp.int32)
        grad = updates['w']

        second = (1 - rms_decay) * (grad**2) + rms_decay * state.nu['w']
        scaled = grad / (jax.lax.sqrt(second + eps_root) + eps)
        scaled = scaled * jnp.sqrt((1 - rms_decay**count))

        momentum = (1 - moment_decay) * scaled + moment_decay * state.trace['w']
        step = momentum / (1 - moment_decay**count)

        new_state = State(nu={'w': second}, count=count, trace={'w': momentum})
        return {'w': step}, new_state

    return optax.GradientTransformation(init_fn, update_fn)
def transform_chain(
        elements: List[str],
        hps: Optional[List[Dict[str, float]]] = None,
        masks: Optional[List[Any]] = None,
        learning_rate: Optional[float] = None) -> optax.GradientTransformation:
    """Utility function for chaining GradientTransformations by string name.

    Args:
        elements: list of transform names, looked up in `_transformations`.
        hps: list of keyword-argument dicts, one per transform.
        masks: list of masks, one per transform; a non-None mask wraps its
            transform in `optax.masked`.
        learning_rate: if given, `scale_by_learning_rate(learning_rate)` is
            appended to the chain.

    Returns:
        optax.GradientTransformation. If the incoming updates are a flax
        `FrozenDict`, the returned updates are frozen as well.

    Raises:
        ValueError: if `hps` or `masks` does not have one entry per element.
    """
    # A fresh dict per element (rather than `[{}] * n`) avoids aliasing.
    hps = hps or [{} for _ in elements]
    masks = masks or [None] * len(elements)
    if len(hps) != len(elements):
        raise ValueError('Number of hps must equal number of elements.')
    if len(masks) != len(elements):
        raise ValueError('Number of masks must equal number of elements.')

    transforms = [_transformations[el](**hp) for el, hp in zip(elements, hps)]
    for i, (transform, mask) in enumerate(zip(transforms, masks)):
        if mask is not None:
            transforms[i] = optax.masked(transform, mask)

    if learning_rate is not None:
        transforms += [scale_by_learning_rate(learning_rate)]

    init_fn, update_fn = optax.chain(*transforms)

    # NOTE(dsuo): We use plain dicts internally due to this issue
    # https://github.com/deepmind/optax/issues/160.
    def wrapped_init_fn(params):
        return init_fn(flax.core.unfreeze(params))

    def wrapped_update_fn(updates, state, params=None):
        new_updates, state = update_fn(
            flax.core.unfreeze(updates), state,
            None if params is None else flax.core.unfreeze(params))

        if isinstance(updates, flax.core.FrozenDict):
            new_updates = flax.core.freeze(new_updates)

        return new_updates, state

    return optax.GradientTransformation(wrapped_init_fn, wrapped_update_fn)
def ema(decay, debias=True):
    """Exponential moving average of `updates['w']`, optionally bias-corrected.

    The state is a dict `{'w': accumulator, 'count': step}`. The accumulator
    is initialized to `jnp.zeros((2,))` regardless of `params`.

    Args:
        decay: EMA decay rate.
        debias: if True, divide the returned average by `1 - decay**count`.

    Returns:
        An (init_fn, update_fn) tuple.
    """

    def init_fn(params):
        del params
        return {'w': jnp.zeros((2, )), 'count': 0}

    def update_fn(updates, state, params=None):
        del params
        count = state['count'] + 1
        w = (1 - decay) * updates['w'] + decay * state['w']
        if debias:
            update = {'w': w / (1 - decay**count)}
        else:
            update = {'w': w}
        # Return a new state dict instead of mutating the caller's one in
        # place (which silently altered the previous state object).
        new_state = dict(state)
        new_state['count'] = count
        new_state['w'] = w
        return update, new_state

    return optax.GradientTransformation(init_fn, update_fn)
def maybe_skip_gradient_update(
    inner,
    gradient_norm_skip_threshold,
):
    """Wrap an optimiser so that some gradient updates are skipped.

    An update is skipped when the gradients contain a NaN or Inf, or when
    their global norm is not strictly lower than
    `gradient_norm_skip_threshold`. A skipped step returns all-zero updates
    and leaves the inner state unchanged.

    Args:
        inner: Inner transformation to be wrapped.
        gradient_norm_skip_threshold: float; steps whose global gradient norm
            is greater than or equal to this value are skipped.

    Returns:
        New GradientTransformation.
    """

    def init(params):
        return MaybeSkipGradientUpdateState(inner_state=inner.init(params))

    def update(updates, state, params=None):
        inner_state = state.inner_state
        # Decide whether to apply this step (no clipping happens here).
        gradient_norm = optax.global_norm(updates)
        leaves = jax.tree_util.tree_leaves(updates)
        isfinite = jnp.all(
            jnp.array([jnp.all(jnp.isfinite(p)) for p in leaves]))
        islowerthan = gradient_norm < gradient_norm_skip_threshold

        def do_update(_):
            return inner.update(updates, inner_state, params)

        def reject_update(_):
            return (jax.tree_util.tree_map(jnp.zeros_like, updates), inner_state)

        # Operand passed positionally; the `operand=` keyword is deprecated.
        new_updates, new_inner_state = jax.lax.cond(
            jnp.logical_and(isfinite, islowerthan), do_update, reject_update,
            None)

        return new_updates, MaybeSkipGradientUpdateState(
            inner_state=new_inner_state)

    return optax.GradientTransformation(init=init, update=update)
def scale_by_nadam(b1: float = 0.9,
                   b2: float = 0.999,
                   eps: float = 1e-8,
                   eps_root: float = 0.0,
                   debias: bool = True) -> optax.GradientTransformation:
    """Rescale updates according to the NAdam algorithm.

    References:
        There seem to be multiple versions of NAdam. The original version is
        here https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch
        implementation also follows this).

        Current code implements a simpler version with no momentum decay and
        slightly different bias correction terms. The exact description can be
        found here https://arxiv.org/pdf/1910.05446.pdf (Table 1).

    Args:
        b1: decay rate for the exponentially weighted average of grads.
        b2: decay rate for the exponentially weighted average of squared
            grads.
        eps: term added to the denominator to improve numerical stability.
        eps_root: term added to the denominator inside the square-root to
            improve numerical stability when backpropagating gradients through
            the rescaling.
        debias: whether to use bias correction.

    Returns:
        An (init_fn, update_fn) tuple.
    """

    def init_fn(params):
        mu = jax.tree_util.tree_map(jnp.zeros_like, params)  # First moment
        nu = jax.tree_util.tree_map(jnp.zeros_like, params)  # Second moment
        return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu)

    def update_fn(updates, state, params=None):
        del params
        mu = _update_moment(updates, state.mu, b1, 1)
        nu = _update_moment(updates, state.nu, b2, 2)
        count = state.count + jnp.array(1, dtype=jnp.int32)
        # Nesterov look-ahead: apply the momentum update once more.
        mu_hat = _update_moment(updates, mu, b1, 1)
        mu_hat = _bias_correction(mu_hat, b1, count) if debias else mu_hat
        nu_hat = _bias_correction(nu, b2, count) if debias else nu
        updates = jax.tree_util.tree_map(
            lambda m, v: m / (jnp.sqrt(v + eps_root) + eps), mu_hat, nu_hat)
        return updates, ScaleByAdamState(count=count, mu=mu, nu=nu)

    return optax.GradientTransformation(init_fn, update_fn)
def preconditioner(
    grad_transformer,
    accumulator,
    updater,
    moment_creator_args,
    accumulator_args,
    updater_args,
) -> optax.GradientTransformation:
    """Compose a three-stage preconditioned update.

    `grad_transformer`, `accumulator` and `updater` are factories. Each is
    called with its keyword-argument dict. The resulting transformations run
    in that order on a shared dict with keys 'updates', 'variables', 'moments'
    and 'output', and the final 'output' entry is returned as the update. The
    state is a 3-tuple of the stage states.
    """
    stages = (
        grad_transformer(**moment_creator_args),
        accumulator(**accumulator_args),
        updater(**updater_args),
    )

    def init(params: optax.Params) -> optax.OptState:
        """`init` function."""
        transformer, accum, upd = stages
        transformer_state = transformer.init(params)
        # NOTE(dsuo): assumes params and updates have the same shape.
        probe = {'updates': params, 'variables': {}}
        probe, _ = transformer.update(probe, transformer_state, params)
        # NOTE(dsuo): assume accumulator only needs `gradients`.
        accumulator_state = accum.init(probe['variables'])
        updater_state = upd.init(params)
        return (transformer_state, accumulator_state, updater_state)

    def update(updates, state, params=None):
        """`update` function."""
        carry = {
            'updates': updates,
            'variables': {},
            'moments': {},
            'output': None
        }
        next_states = []
        for stage, stage_state in zip(stages, state):
            carry, next_state = stage.update(carry, stage_state, params)
            next_states.append(next_state)
        return carry['output'], tuple(next_states)

    return optax.GradientTransformation(init, update)
def __init__(self,
             learning_rate_fn: Union[float, int, optax.Schedule],
             normalize_fn: Optional[NormalizeFn] = None):
    """Build the SGD transformation chain.

    The chain normalizes each update leaf, scales by the learning-rate
    schedule, then negates (`optax.scale(-1.)`).

    Args:
        learning_rate_fn: a constant learning rate (int or float, converted
            to float) or an `optax.Schedule` mapping the step count to a rate.
        normalize_fn: optional function applied to every update leaf; the
            identity is used when it is omitted.
    """
    schedule = learning_rate_fn
    if isinstance(schedule, (float, int)):
        constant_lr = float(schedule)
        schedule = lambda _: constant_lr

    leaf_fn = normalize_fn or (lambda x: x)

    def update_fn(updates, state, params=None):
        del params
        return jax.tree_map(leaf_fn, updates), state

    normalize = optax.GradientTransformation(
        lambda _: optax.EmptyState(), update_fn)
    gradient_transformation = optax.chain(
        normalize,
        optax.scale_by_schedule(schedule),
        optax.scale(-1.))
    super(SGD, self).__init__(gradient_transformation)
def kitchen_sink(chains: List[optax.GradientTransformation],
                 scales: Optional[jnp.ndarray] = None,
                 combinator: Union[Callable[[Any, Any], Any], str] = 'sum',
                 combinator_args: Optional[Dict[str, float]] = None,
                 learning_rate: Optional[float] = None
                 ) -> optax.GradientTransformation:
    """Runs a list of GradientTransformations in parallel and combines them.

    Args:
        chains: list of optax.GradientTransformations (typically from
            transform_chain).
        scales: a (len(chains),)-shaped array; chain `i` is followed by
            `optax.scale(scales[i])`. Defaults to all ones.
        combinator: a callable that reduces a list of identical pytrees, or a
            name looked up in `_combinators` (unknown names fall back to
            `_sum_combinator`).
        combinator_args: a dictionary of keyword arguments to the combinator.
        learning_rate: if given, `scale_by_learning_rate(learning_rate)` is
            applied after combining.

    Returns:
        optax.GradientTransformation whose state is a list of chain states.

    Raises:
        ValueError: if `scales` does not have one entry per chain.
    """
    if isinstance(combinator, str):
        combinator = _combinators.get(combinator, _sum_combinator)
    combinator_args = combinator_args or {}

    if scales is None:
        scales = jnp.ones(len(chains))
    # `zip` below would otherwise silently drop chains or scales.
    if len(scales) != len(chains):
        raise ValueError('Number of scales must equal number of chains.')

    chains = [
        optax.chain(chain, optax.scale(scale))
        for chain, scale in zip(chains, scales)
    ]

    def init_fn(params):
        return [chain.init(params) for chain in chains]

    def update_fn(updates, state, params=None):
        result = [chain.update(updates, chain_state, params)
                  for chain, chain_state in zip(chains, state)]
        new_updates, new_state = list(zip(*result))
        # Keep the state a list, matching `init_fn`, so the pytree structure
        # is stable across steps (required e.g. for scan/while_loop carries).
        return combinator(*new_updates, **combinator_args), list(new_state)

    transform = optax.GradientTransformation(init_fn, update_fn)

    if learning_rate is not None:
        transform = optax.chain(transform,
                                scale_by_learning_rate(learning_rate))

    return transform
def precondition_by_yogi(b2: float = 0.999,
                         eps: float = 1e-8,
                         eps_root: float = 0.0,
                         initial_accumulator_value: float = 1e-6,
                         debias: bool = True) -> optax.GradientTransformation:
    """Preconditions updates according to the Yogi Preconditioner.

    References:
        [Zaheer et al, 2018](https://papers.nips.cc/paper/2018/hash/90365351ccc7437a1309dc64e4db32a3-Abstract.html) #pylint:disable=line-too-long

    Args:
        b2: decay rate for the exponentially weighted average of moments of
            grads.
        eps: Term added to the denominator to improve numerical stability.
            The default is changed to 1e-8. Optax Yogi's default is 1e-3.
        eps_root: term added to the denominator inside the square-root to
            improve numerical stability when backpropagating gradients through
            the rescaling.
        initial_accumulator_value: The starting value for accumulators. Only
            positive values are allowed (not validated here).
        debias: whether to use bias correction or not.

    Returns:
        An (init_fn, update_fn) tuple.
    """

    def init_fn(params):
        value_like = lambda p: jnp.full_like(p, initial_accumulator_value)
        nu = jax.tree_util.tree_map(value_like, params)  # Second central moment
        return PreconditionBySecondMomentCoordinateWiseState(
            count=jnp.zeros([], jnp.int32), nu=nu)

    def update_fn(updates, state, params=None):
        del params
        # Yogi: v_t = v_{t-1} - (1 - b2) * sign(v_{t-1} - g^2) * g^2.
        nu = jax.tree_util.tree_map(
            lambda g, v: v - (1 - b2) * jnp.sign(v - g**2) * (g**2),
            updates, state.nu)
        count = state.count + jnp.array(1, dtype=jnp.int32)
        nu_hat = _bias_correction(nu, b2, count) if debias else nu
        updates = jax.tree_util.tree_map(
            lambda u, v: u / (jnp.sqrt(v + eps_root) + eps), updates, nu_hat)
        return updates, PreconditionBySecondMomentCoordinateWiseState(
            count=count, nu=nu)

    return optax.GradientTransformation(init_fn, update_fn)
def precondition_by_rms(
    decay: float = 0.999,
    eps: float = 1e-8,
    eps_root: float = 0.0,
    debias: bool = False,
) -> optax.GradientTransformation:
    """Preconditions updates according to the RMS Preconditioner from Adam.

    References:
        [Kingma, Ba 2015] https://arxiv.org/pdf/1412.6980.pdf

    Args:
        decay: decay rate for exponentially weighted average of moments of
            grads.
        eps: Term added to the denominator to improve numerical stability.
            The default is kept to 1e-8 to match optax Adam implementation.
        eps_root: term added to the denominator inside the square-root to
            improve numerical stability when backpropagating gradients through
            the rescaling.
        debias: whether to use bias correction or not.

    Gotcha:
        Note that the usage of epsilon and defaults are different from optax's
        scale_by_rms. This matches optax's adam template.

    Returns:
        An (init_fn, update_fn) tuple.
    """

    def init_fn(params):
        return PreconditionBySecondMomentCoordinateWiseState(
            count=jnp.zeros([], jnp.int32),
            nu=jax.tree_util.tree_map(jnp.zeros_like, params))

    def update_fn(updates, state, params=None):
        del params
        nu = _update_moment(updates, state.nu, decay, 2)
        count = state.count + jnp.array(1, dtype=jnp.int32)
        nu_hat = _bias_correction(nu, decay, count) if debias else nu
        updates = jax.tree_util.tree_map(
            lambda u, v: u / (jnp.sqrt(v + eps_root) + eps), updates, nu_hat)
        return updates, PreconditionBySecondMomentCoordinateWiseState(
            count=count, nu=nu)

    return optax.GradientTransformation(init_fn, update_fn)
def scale_by_adam(
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-8,
    eps_root: float = 0,
    mu_dtype=None,
) -> optax.GradientTransformation:
    """Rescale updates according to the Adam algorithm.

    References:
        [Kingma et al, 2014](https://arxiv.org/abs/1412.6980)

    Args:
        b1: decay rate for the exponentially weighted average of grads.
        b2: decay rate for the exponentially weighted average of squared norm
            of grads.
        eps: term added to the denominator to improve numerical stability.
        eps_root: term added to the denominator inside the square-root to
            improve numerical stability when backpropagating gradients through
            the rescaling.
        mu_dtype: optional `dtype` to be used for the stored first-order
            accumulator; if `None` then the `dtype` is inferred from `params`
            and `updates`.

    Returns:
        An (init_fn, update_fn) tuple.
    """
    mu_dtype = utils.canonicalize_dtype(mu_dtype)

    def init_fn(params):
        mu = jax.tree_util.tree_map(
            lambda t: jnp.zeros_like(t, dtype=mu_dtype), params)
        nu = jax.tree_util.tree_map(jnp.zeros_like, params)
        return optax.ScaleByAdamState(
            count=jnp.zeros([], jnp.int32), mu=mu, nu=nu)

    def update_fn(updates, state, params=None):
        del params
        mu = _update_moment(updates, state.mu, b1, 1)
        nu = _update_moment_norm(updates, state.nu, b2, 2)
        count_inc = utils.safe_int32_increment(state.count)
        mu_hat = _bias_correction(mu, b1, count_inc)
        nu_hat = _bias_correction(nu, b2, count_inc)
        updates = jax.tree_util.tree_map(
            lambda m, v: m / (jnp.sqrt(v + eps_root) + eps), mu_hat, nu_hat)
        # Cast the *stored* first moment (not `mu_hat`) to `mu_dtype`. Mixing
        # `mu` with `updates` presumably promotes its dtype, so without this
        # cast the state would drift away from `mu_dtype` after the first step
        # while the update itself would be computed at reduced precision.
        mu = utils.cast_tree(mu, mu_dtype)
        return updates, optax.ScaleByAdamState(count=count_inc, mu=mu, nu=nu)

    return optax.GradientTransformation(init_fn, update_fn)
def amsgrad(b1: float = 0.9,
            b2: float = 0.999,
            eps: float = 1e-8,
            eps_root: float = 0.0):
    """AMSGrad built on the moment tracking of `optax.scale_by_adam`.

    `optax.scale_by_adam(b1=b1, b2=b2)` is used only to advance the moment
    estimates; its own output is discarded. The returned update is
    `mu / (sqrt(nu_hat + eps_root) + eps)`, where `nu_hat` is the leaf-wise
    max of the new and previous `nu`. `nu_hat` is written back as the state's
    `nu`, so the next second-moment EMA starts from the maximum. No bias
    correction is applied to the returned update.

    The default arguments reproduce the original zero-argument behaviour.

    Args:
        b1: decay rate for the first-moment EMA.
        b2: decay rate for the second-moment EMA.
        eps: term added to the denominator outside the square root.
        eps_root: term added inside the square root.

    Returns:
        An (init_fn, update_fn) tuple with `optax.ScaleByAdamState` state.
    """
    adam = optax.scale_by_adam(b1=b1, b2=b2)

    def init_fn(params):
        return adam.init(params)

    def update_fn(updates, state, params=None):
        prev_nu = state.nu
        _, state = adam.update(updates, state, params)
        curr_nu = state.nu
        nu_hat = jax.tree_util.tree_map(jnp.maximum, curr_nu, prev_nu)
        updates = jax.tree_util.tree_map(
            lambda m, v: m / (jnp.sqrt(v + eps_root) + eps), state.mu, nu_hat)
        return updates, optax.ScaleByAdamState(
            count=state.count, mu=state.mu, nu=nu_hat)

    return optax.GradientTransformation(init_fn, update_fn)
def scale_by_lars(
        momentum: float = 0.9,
        eta: float = 0.001,
        filter_fn: Optional[FilterFn] = None) -> optax.GradientTransformation:
    """Rescales updates according to the LARS algorithm.

    Does not include weight decay.

    References:
        [You et al, 2017](https://arxiv.org/abs/1708.03888)

    Args:
        momentum: momentum coefficient.
        eta: LARS coefficient.
        filter_fn: an optional filter function, forwarded to
            `_partial_update` to choose between adapted and raw updates.

    Returns:
        An (init_fn, update_fn) tuple. `update_fn` raises ValueError when it
        is called without `params`.
    """

    def init_fn(params: optax.Params) -> ScaleByLarsState:
        mu = jax.tree_util.tree_map(jnp.zeros_like, params)  # momentum
        return ScaleByLarsState(mu=mu)

    def update_fn(updates: optax.Updates,
                  state: ScaleByLarsState,
                  params: Optional[optax.Params] = None
                  ) -> Tuple[optax.Updates, ScaleByLarsState]:
        if params is None:
            raise ValueError(
                'scale_by_lars requires `params` to be passed to update.')

        def lars_adaptation(
            update: jnp.ndarray,
            param: jnp.ndarray,
        ) -> jnp.ndarray:
            param_norm = jnp.linalg.norm(param)
            update_norm = jnp.linalg.norm(update)
            # Trust ratio eta * ||p|| / ||u||; fall back to 1 when either
            # norm is zero.
            return update * jnp.where(
                param_norm > 0.,
                jnp.where(update_norm > 0,
                          (eta * param_norm / update_norm), 1.0), 1.0)

        adapted_updates = jax.tree_util.tree_map(
            lars_adaptation, updates, params)
        adapted_updates = _partial_update(updates, adapted_updates, params,
                                          filter_fn)
        mu = jax.tree_util.tree_map(
            lambda g, t: momentum * g + t, state.mu, adapted_updates)
        return mu, ScaleByLarsState(mu=mu)

    return optax.GradientTransformation(init_fn, update_fn)
def scale_by_amsgrad(
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-8,
    eps_root: float = 0.0,
    mu_dtype: Optional[Any] = None,
) -> optax.GradientTransformation:
    """Rescale updates according to the AMSGrad algorithm.

    References:
        [Reddi et al, 2018](https://arxiv.org/abs/1904.09237v1)

    Args:
        b1: decay rate for the exponentially weighted average of grads.
        b2: decay rate for the exponentially weighted average of squared
            grads.
        eps: term added to the denominator to improve numerical stability.
        eps_root: term added to the denominator inside the square-root to
            improve numerical stability when backpropagating gradients through
            the rescaling.
        mu_dtype: optional `dtype` to be used for the stored first-order
            accumulator; if `None` then the `dtype` is inferred from `params`
            and `updates`.

    Returns:
        An (init_fn, update_fn) tuple.
    """

    def init_fn(params):
        mu = jax.tree_util.tree_map(  # First moment
            lambda t: jnp.zeros_like(t, dtype=mu_dtype), params)
        nu = jax.tree_util.tree_map(jnp.zeros_like, params)  # Second moment
        return ScaleByAMSGradState(mu=mu, nu=nu)

    def update_fn(updates, state, params=None):
        del params
        mu = _update_moment(updates, state.mu, b1, 1)
        nu = _update_moment(updates, state.nu, b2, 2)
        # NOTE(review): `state.nu` holds the raw EMA, so this is the max of the
        # current and previous EMA only, not the running maximum of Reddi et
        # al. Confirm this is intended.
        nu_hat = jax.tree_util.tree_map(jnp.maximum, nu, state.nu)
        updates = jax.tree_util.tree_map(
            lambda m, v: m / (jnp.sqrt(v + eps_root) + eps), mu, nu_hat)
        # Keep the stored first moment in `mu_dtype` so the state dtype does
        # not drift from what `init_fn` created (no-op when mu_dtype is None).
        mu = utils.cast_tree(mu, mu_dtype)
        return updates, ScaleByAMSGradState(mu=mu, nu=nu)

    return optax.GradientTransformation(init_fn, update_fn)
def get_transformations_with_vectorized_repeat_prefix(
        tx: GeneralGradientTransformation,
        var_params: NestedParams) -> GeneralGradientTransformation:
    """Vectorizes a transformation on shape/sharding prefixes.

    Returns a transformation of the same kind as `tx`. For a
    `ShardedGradientTransformation` the result is sharded; otherwise `tx` must
    be an `optax.GradientTransformation`. Its init/update (and, when sharded,
    init_partition_spec) delegate to the corresponding
    `*_with_vectorized_repeat_prefix` helpers.
    """

    def _init(variables):
        return init_with_vectorized_repeat_prefix(tx, variables, var_params)

    def _update(updates, state, params=None):
        return update_with_vectorized_repeat_prefix(tx, updates, state, params,
                                                    var_params)

    if not isinstance(tx, ShardedGradientTransformation):
        assert isinstance(tx, optax.GradientTransformation)
        return optax.GradientTransformation(init=_init, update=_update)

    def _init_partition_spec(var_param_args):
        assert isinstance(tx, ShardedGradientTransformation)
        return init_partition_spec_with_vectorized_repeat_prefix(
            tx, var_param_args)

    return ShardedGradientTransformation(
        init=_init, update=_update, init_partition_spec=_init_partition_spec)
def distributed_shampoo(learning_rate,
                        block_size,
                        beta1=0.9,
                        beta2=0.999,
                        diagonal_epsilon=1e-10,
                        matrix_epsilon=1e-6,
                        weight_decay=0.0,
                        start_preconditioning_step=1,
                        preconditioning_compute_steps=1,
                        statistics_compute_steps=1,
                        best_effort_shape_interpretation=True,
                        graft_type=GraftingType.SGD,
                        nesterov=True,
                        exponent_override=0,
                        batch_axis_name=None,
                        mesh_axis_names=None,
                        num_devices_for_pjit=None,
                        shard_optimizer_states=False,
                        inverse_failure_threshold=0.1,
                        moving_average_for_momentum=False,
                        skip_preconditioning_dim_size_gt=4096,
                        clip_by_scaled_gradient_norm=None,
                        precision=lax.Precision.HIGHEST):
  """Distributed Shampoo optimizer.

  Distributed Shampoo is a second-order preconditioned method (concretely, a
  variant of full-matrix Adagrad), that provides significant convergence and
  wall-clock time improvements compared to conventional first-order methods,
  and that has been shown to scale to large state-of-the-art deep learning
  models.

  References:
    Scalable Second Order Optimization for Deep Learning,
    Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer

    Preprint: https://arxiv.org/abs/2002.09018

  Args:
    learning_rate: the step size used to update the parameters. May be a
      callable, in which case it is called with the step count.
    block_size: Block size for large layers (if > 0). Preconditioning compute
      operation is cubic in the dimension of the tensor. Block size allows us to
      chunk the layers into sub-layers of maximal dimension dictated by this
      value. Use 128 as default (increase if you have compute budget).
    beta1: momentum parameter.
    beta2: second moment averaging parameter. If 1.0, statistics are summed
      rather than averaged.
    diagonal_epsilon: epsilon for diagonal adagrad (only if layerwise grafting
      to AdaGrad or RMSProp is enabled).
    matrix_epsilon: epsilon to add to statistics before computing inverse pth
      root. If you are running in f32 precision for inverse pth root
      (recommended today) this can go up to 1e-6. If you have latest hardware
      with native f64 precision, set this up to 1e-12.
    weight_decay: Weight decay for regularization.
    start_preconditioning_step: When to start the Shampoo update; before this
      step the grafting (diagonal) update is used. This is because we don't
      have enough information to do a stable inverse.
    preconditioning_compute_steps: How often to compute preconditioner.
      Performance tuning params for controlling memory and compute requirements.
      Ideally set this and statistics_compute_steps params to 1.
    statistics_compute_steps: How often to compute statistics.
    best_effort_shape_interpretation: If there are some small dimensions,
      collapse them e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if
      block = 1024, [1, 2, 768, 1, 2048] --> [2, 768, 2048]
    graft_type: Grafting is a technique to fix the layerwise scale of Shampoo
      optimizer. This allows us to plugin the Shampoo optimizer into settings
      where SGD/AdaGrad is already well tuned. Available options are:
      GraftingType.SGD, GraftingType.ADAGRAD, GraftingType.RMSPROP and
      GraftingType.RMSPROP_NORMALIZED.
    nesterov: Nesterov momentum.
    exponent_override: Override the exponent used in matrix inverse. 0 means
      use the exponent chosen by the Preconditioner.
    batch_axis_name: labeled axis over pmap for data-parallel training; when
      set, inverse pth root computations are distributed over this axis.
    mesh_axis_names: Axis names for the mesh (used in pjit).
    num_devices_for_pjit: Number of devices to parallelize over when using
      pjit.
    shard_optimizer_states: Shard optimizer states to save memory in model
      parallel training.
    inverse_failure_threshold: numerics are hard and inverses fail sometimes;
      we determine that using this threshold. When the inverse error is NaN or
      at least this value, the previous preconditioner is kept.
    moving_average_for_momentum: If True, momentum is an exponential moving
      average (the new update is scaled by 1 - beta1); otherwise the new update
      is added unscaled to beta1 * momentum.
    skip_preconditioning_dim_size_gt: Skip if preconditioning dim size is
      greater than this value.
    clip_by_scaled_gradient_norm: Clip by scaled gradient norm (only useful
      when using RMSProp Grafting).
    precision: precision XLA related flag, the available options are: a)
      lax.Precision.DEFAULT (better step time, but not precise) b)
      lax.Precision.HIGH (increased precision, slower) c) lax.Precision.HIGHEST
      (best possible precision, slowest)

  Returns:
    a GradientTransformation.
  """

  def sharded_init_fn(params):
    """Initialise sharded state: stacked global matrices + per-param locals."""
    params_flat, treedef = jax.tree_flatten(params)
    # Find max size to pad to. Every statistic/preconditioner is padded to this
    # size so they can all be stacked into a single array.
    max_size = 0
    for param in params_flat:
      preconditioner = Preconditioner(param, block_size,
                                      best_effort_shape_interpretation)
      if not _skip_preconditioning(param):
        shapes = preconditioner.shapes_for_preconditioners()
        sizes = [s[0] for s in shapes]
        max_size = max(max(sizes), max_size)

    padded_statistics = []
    padded_preconditioners = []
    local_stats_flat = []
    for param in params_flat:
      preconditioner = Preconditioner(param, block_size,
                                      best_effort_shape_interpretation)
      sizes = []

      statistics = []
      preconditioners = []
      # Position of this parameter's first statistic in the stacked arrays.
      index_start = len(padded_statistics)
      if not _skip_preconditioning(param):
        shapes = preconditioner.shapes_for_preconditioners()
        sizes = [s[0] for s in shapes]
        statistics = [matrix_epsilon * jnp.eye(max_size) for _ in shapes]
        preconditioners = [jnp.eye(max_size) for _ in shapes]
        padded_statistics.extend(statistics)
        padded_preconditioners.extend(preconditioners)

      adagrad_statistics = []
      if graft_type != GraftingType.SGD:
        adagrad_statistics = jnp.zeros_like(param)
      local_stats_flat.append(
          LocalShardedParameterStats(adagrad_statistics, jnp.zeros_like(param),
                                     jnp.zeros_like(param), index_start, sizes))

    local_stats = jax.tree_unflatten(treedef, local_stats_flat)
    # Pad the statistics and preconditioner matrices to be a multiple of
    # num devices.
    # TODO(rohananil): Relax to only the size of the mesh axis where the dim
    # is split on.
    to_pad = -len(padded_statistics) % num_devices_for_pjit
    padded_statistics.extend([
        jnp.eye(max_size, dtype=padded_statistics[0].dtype)
        for _ in range(to_pad)
    ])
    padded_preconditioners.extend([
        jnp.eye(max_size, dtype=padded_statistics[0].dtype)
        for _ in range(to_pad)
    ])
    global_stats = GlobalShardedParameterStats(
        jnp.stack(padded_statistics), jnp.stack(padded_preconditioners))
    return ShampooState(
        count=jnp.zeros([], jnp.int32),
        stats=ShardedShampooStats(global_stats, local_stats))

  def sharded_update_fn(grads, state, params):
    """Transform the input gradient and update all statistics in sharded mode.

    Args:
      grads: the gradient tensors for the parameters.
      state: a named tuple containing the state of the optimizer
      params: the parameters that should be updated.

    Returns:
      A tuple containing the transformed updates and the new optimizer state.
    """
    params_flat, treedef = jax.tree_flatten(params)
    grads_flat = treedef.flatten_up_to(grads)

    global_stats = state.stats.global_stats
    local_stats_flat = treedef.flatten_up_to(state.stats.local_stats)
    stats_flat = [
        _convert_to_parameter_stats(global_stats, local_stat)
        for local_stat in local_stats_flat
    ]
    new_stats_flat = jax.tree_multimap(
        lambda g, s, p: _compute_stats(g, s, p, state.count), grads_flat,
        stats_flat, params_flat)

    exponents = []
    for stat, param in zip(new_stats_flat, params_flat):
      num_statistics = len(stat.statistics)
      if num_statistics > 0:
        preconditioner = Preconditioner(param, block_size,
                                        best_effort_shape_interpretation)
        exponent = (
            preconditioner.exponent_for_preconditioner()
            if exponent_override == 0 else exponent_override)
        exponents.extend([exponent] * num_statistics)

    # The gradient is transformed with the preconditioners carried in
    # `new_stats_flat`, which (presumably, via `_convert_to_parameter_stats`)
    # come from the incoming state; preconditioners computed below are stored
    # for later steps.
    outputs = jax.tree_multimap(
        lambda g, s, p: _transform_grad(g, s, p, state.count), grads_flat,
        new_stats_flat, params_flat)
    updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ())

    updates = jax.tree_unflatten(treedef, updates_flat)
    # Create new local_stats
    new_local_stats_flat = [
        _convert_from_parameter_stats(new_stat, local_stat)
        for new_stat, local_stat in zip(new_stats_flat, local_stats_flat)
    ]
    new_local_stats = jax.tree_unflatten(treedef, new_local_stats_flat)

    max_size = global_stats.statistics.shape[1]
    new_padded_statistics = []
    for stat in new_stats_flat:
      new_padded_statistics.extend(
          [pad_matrix(stat, max_size) for stat in stat.statistics])

    # Create global stats
    # TODO(rohananil): Preconditioner is not updated every step, so cost of
    # stack/pad can be obviated away.
    # Pad the statistics and preconditioner matrices to be a multiple of
    # num devices.
    # TODO(rohananil): Relax to only the size of the mesh axis where the dim
    # is split on.
    to_pad = -len(new_padded_statistics) % num_devices_for_pjit

    new_padded_statistics.extend([
        jnp.eye(max_size, dtype=new_padded_statistics[0].dtype)
        for _ in range(to_pad)
    ])
    exponents.extend([1 for _ in range(to_pad)])
    # Stacked once and used both for the inverse and for the new state.
    new_stacked_padded_statistics = jnp.stack(new_padded_statistics)

    def _matrix_inverse_pth_root_vmap(xs, ps):
      mi_pth_root = functools.partial(
          matrix_inverse_pth_root,
          ridge_epsilon=matrix_epsilon,
          precision=precision)
      preconditioners, errors = jax.vmap(mi_pth_root)(xs, ps)
      return preconditioners, errors

    def _internal_inverse_pth_root_all():
      # Invert the statistics accumulated this step (not the stale ones in
      # `global_stats`), matching the non-sharded `update_fn`.
      preconditioners, errors = _matrix_inverse_pth_root_vmap(
          new_stacked_padded_statistics, jnp.stack(exponents))
      return preconditioners, errors

    if preconditioning_compute_steps == 1:
      new_preconditioners, errors = _internal_inverse_pth_root_all()
    else:
      # Passing statistics instead of preconditioners as they are similarly
      # shaped tensors. Note statistics will be ignored as we are passing in
      # a large init value for error.
      preconditioners_init = global_stats.statistics
      errors_init = np.stack([inverse_failure_threshold] * len(exponents))
      init_state = [preconditioners_init, errors_init]
      perform_step = state.count % preconditioning_compute_steps == 0
      new_preconditioners, errors = efficient_cond(
          perform_step, _internal_inverse_pth_root_all, init_state)

    # One error per stacked matrix, broadcast over that matrix.
    errors = errors.reshape((-1, 1, 1))
    # 1.0 where the inverse failed (NaN or error above threshold): keep the
    # previous preconditioner there.
    predicate = jnp.logical_or(
        jnp.isnan(errors),
        errors >= inverse_failure_threshold).astype(new_preconditioners.dtype)
    # TODO(rohananil): Check for numerical instabilities.
    new_conditional_preconditioners = (
        predicate * global_stats.preconditioners +
        (1.0 - predicate) * new_preconditioners)
    new_global_stats = GlobalShardedParameterStats(
        new_stacked_padded_statistics, new_conditional_preconditioners)
    new_shampoo_state = ShampooState(
        count=state.count + 1,
        stats=ShardedShampooStats(new_global_stats, new_local_stats))
    return updates, new_shampoo_state

  def init_fn(params):
    """Initialise the optimiser's state."""

    def _init(param):
      preconditioner = Preconditioner(param, block_size,
                                      best_effort_shape_interpretation)
      statistics = []
      preconditioners = []
      if not _skip_preconditioning(param):
        shapes = preconditioner.shapes_for_preconditioners()
        statistics = [matrix_epsilon * jnp.eye(s[0]) for s in shapes]
        preconditioners = [jnp.eye(s[0]) for s in shapes]

      adagrad_statistics = []
      if graft_type != GraftingType.SGD:
        adagrad_statistics = jnp.zeros_like(param)
      return ParameterStats(adagrad_statistics, statistics, preconditioners,
                            jnp.zeros_like(param), jnp.zeros_like(param))

    return ShampooState(
        count=jnp.zeros([], jnp.int32), stats=jax.tree_map(_init, params))

  def _skip_preconditioning(param):
    """True for scalars and for params with any dim above the size limit."""
    return len(param.shape) < 1 or any(
        [s > skip_preconditioning_dim_size_gt for s in param.shape])

  def _compute_stats(grad, state, param, step):
    """Compute per-parameter statistics."""
    preconditioner = Preconditioner(param, block_size,
                                    best_effort_shape_interpretation)
    new_statistics = [[]] * len(state.statistics)
    # With beta2 == 1.0 both weights are 1.0 (plain summation); otherwise this
    # is an exponential moving average.
    w1 = beta2
    w2 = beta2 if beta2 == 1.0 else (1.0 - beta2)
    if not _skip_preconditioning(param):

      def compute_updated_statistics():
        new_stats = preconditioner.statistics_from_grad(grad)
        new_stats_accumulators = []
        for stat, stat_accumulator in zip(new_stats, state.statistics):
          new_stats_accumulators.append(w1 * stat_accumulator + w2 * stat)
        return new_stats_accumulators

      if statistics_compute_steps > 1:
        perform_step = step % statistics_compute_steps == 0
        init_state = state.statistics
        new_statistics = list(
            efficient_cond(perform_step, compute_updated_statistics,
                           init_state))
      else:
        new_statistics = compute_updated_statistics()
    return ParameterStats(state.diagonal_statistics, new_statistics,
                          state.preconditioners, state.diagonal_momentum,
                          state.momentum)

  def _compute_preconditioners(states, params, step):
    """Compute preconditioners for statistics."""
    statistics = []
    num_statistics_per_state = []
    original_shapes = []
    exponents = []
    max_size = 0
    prev_preconditioners = []
    for state, param in zip(states, params):
      num_statistics = len(state.statistics)
      num_statistics_per_state.append(num_statistics)
      original_shapes_for_state = []
      if num_statistics > 0:
        preconditioner = Preconditioner(param, block_size,
                                        best_effort_shape_interpretation)
        for statistic in state.statistics:
          exponents.append(preconditioner.exponent_for_preconditioner(
          ) if exponent_override == 0 else exponent_override)
          original_shapes_for_state.append(statistic.shape)
          max_size = max(max_size, statistic.shape[0])
        statistics.extend(state.statistics)
        prev_preconditioners.extend(state.preconditioners)
        original_shapes.extend(original_shapes_for_state)
    num_statistics = len(statistics)

    if not statistics:
      # No parameter is preconditioned, so there is nothing to invert. Without
      # this guard the non-batched path fails on `jnp.stack([])`.
      return states

    if batch_axis_name:
      num_devices = lax.psum(1, batch_axis_name)

      # Pad statistics and exponents to next multiple of num_devices.
      packed_statistics = [pad_matrix(stat, max_size) for stat in statistics]
      to_pad = -num_statistics % num_devices
      packed_statistics.extend([
          jnp.eye(max_size, dtype=packed_statistics[0].dtype)
          for _ in range(to_pad)
      ])
      exponents.extend([1 for _ in range(to_pad)])

      # Batch statistics and exponents so that the leading axis is
      # num_devices.
      def _batch(statistics, exponents, num_devices):
        assert len(statistics) == len(exponents)
        n = len(statistics)
        b = int(n / num_devices)
        batched_statistics = [
            jnp.stack(statistics[idx:idx + b]) for idx in range(0, n, b)
        ]
        batched_exponents = [
            jnp.stack(exponents[idx:idx + b]) for idx in range(0, n, b)
        ]
        return jnp.stack(batched_statistics), jnp.stack(batched_exponents)

      # Unbatch values across leading axis and return a list of elements.
      def _unbatch(batched_values):
        b1, b2 = batched_values.shape[0], batched_values.shape[1]
        results = []
        for v_array in jnp.split(
            batched_values, indices_or_sections=b1, axis=0):
          v_array = jnp.squeeze(v_array)
          # b2 = batches (number of preconditioner computation) per core.
          if b2 > 1:
            for v in jnp.split(v_array, indices_or_sections=b2, axis=0):
              results.append(jnp.squeeze(v))
          else:
            results.append(v_array)
        return results

      all_statistics, all_exponents = _batch(packed_statistics, exponents,
                                             num_devices)
    else:
      to_pad = -num_statistics % num_devices_for_pjit
      padded_statistics = [pad_matrix(stat, max_size) for stat in statistics]
      padded_statistics.extend([
          jnp.eye(max_size, dtype=padded_statistics[0].dtype)
          for _ in range(to_pad)
      ])
      exponents.extend([1 for _ in range(to_pad)])
      all_statistics = jnp.stack(padded_statistics)
      all_exponents = jnp.stack(exponents)

    def _matrix_inverse_pth_root_vmap(xs, ps):
      mi_pth_root = functools.partial(
          matrix_inverse_pth_root,
          ridge_epsilon=matrix_epsilon,
          precision=precision)
      preconditioners, errors = jax.vmap(mi_pth_root)(xs, ps)
      return preconditioners, errors

    def _matrix_inverse_pth_root_pjit(xs, ps):
      mesh_axis_names_tuple = tuple(mesh_axis_names)
      # Partition the concatenated statistics matrix across all cores.
      partitioned_xs, partitioned_ps = pjit.pjit(
          lambda x, y: (x, y),
          in_axis_resources=None,
          out_axis_resources=pjit.PartitionSpec(mesh_axis_names_tuple,))(xs, ps)
      # Run matrix inverse pth root on each shard.
      partitioned_preconditioners, partitioned_errors = (
          _matrix_inverse_pth_root_vmap(partitioned_xs, partitioned_ps))
      # Recombine the outputs at each core.
      preconditioners, errors = pjit.pjit(
          lambda x, y: (x, y),
          in_axis_resources=(pjit.PartitionSpec(mesh_axis_names_tuple,),
                             pjit.PartitionSpec(mesh_axis_names_tuple,)),
          out_axis_resources=(None, None))(partitioned_preconditioners,
                                           partitioned_errors)
      return preconditioners, errors

    if not batch_axis_name:

      def _internal_inverse_pth_root_all():
        preconditioners, errors = _matrix_inverse_pth_root_pjit(
            all_statistics, all_exponents)
        b1 = preconditioners.shape[0]

        def split(batched_values):
          return [
              jnp.squeeze(v) for v in jnp.split(
                  batched_values, indices_or_sections=b1, axis=0)
          ]

        return split(preconditioners), split(errors)

      if preconditioning_compute_steps == 1:
        preconditioners_flat, errors_flat = _internal_inverse_pth_root_all()
      else:
        # Passing statistics instead of preconditioners as they are similarly
        # shaped tensors. Note statistics will be ignored as we are passing in
        # a large init value for error.
        preconditioners_init = padded_statistics
        errors_init = [inverse_failure_threshold] * len(padded_statistics)
        init_state = [preconditioners_init, errors_init]
        perform_step = step % preconditioning_compute_steps == 0
        preconditioners_flat, errors_flat = efficient_cond(
            perform_step, _internal_inverse_pth_root_all, init_state)
    else:

      def _internal_inverse_pth_root_all():
        # Each replica inverts its own slice, then results are all-gathered.
        current_replica = lax.axis_index(batch_axis_name)
        preconditioners, errors = _matrix_inverse_pth_root_vmap(
            all_statistics[current_replica], all_exponents[current_replica])
        preconditioners = jax.lax.all_gather(preconditioners, batch_axis_name)
        errors = jax.lax.all_gather(errors, batch_axis_name)
        preconditioners_flat = _unbatch(preconditioners)
        errors_flat = _unbatch(errors)
        return preconditioners_flat, errors_flat

      if preconditioning_compute_steps == 1:
        preconditioners_flat, errors_flat = _internal_inverse_pth_root_all()
      else:
        # Passing statistics instead of preconditioners as they are similarly
        # shaped tensors. Note statistics will be ignored as we are passing in
        # a large init value for error.
        preconditioners_init = packed_statistics
        errors_init = ([inverse_failure_threshold] * len(packed_statistics))
        init_state = [preconditioners_init, errors_init]
        perform_step = step % preconditioning_compute_steps == 0
        preconditioners_flat, errors_flat = efficient_cond(
            perform_step, _internal_inverse_pth_root_all, init_state)

    def _skip(error):
      # Boolean predicate: `lax.cond` does not accept floating-point
      # predicates, so the condition must not be cast to `error.dtype`.
      return jnp.logical_or(
          jnp.isnan(error), error >= inverse_failure_threshold)

    def _select_preconditioner(error, new_p, old_p):
      # Keep the previous preconditioner when the new inverse failed.
      return lax.cond(
          _skip(error), lambda _: old_p, lambda _: new_p, operand=None)

    new_preconditioners_flat = []
    # zip stops at the unpadded length (original_shapes), dropping the padding.
    for p, shape, prev_p, error in zip(preconditioners_flat, original_shapes,
                                       prev_preconditioners, errors_flat):
      new_preconditioners_flat.append(
          _select_preconditioner(error, p[:shape[0], :shape[1]], prev_p))

    assert len(states) == len(num_statistics_per_state)
    assert len(new_preconditioners_flat) == num_statistics

    # Add back empty preconditioners so that we can set the optimizer state.
    preconditioners_for_states = []
    idx = 0
    for num_statistics, state in zip(num_statistics_per_state, states):
      if num_statistics == 0:
        preconditioners_for_states.append([])
      else:
        preconditioners_for_state = new_preconditioners_flat[idx:idx +
                                                             num_statistics]
        assert len(state.statistics) == len(preconditioners_for_state)
        preconditioners_for_states.append(preconditioners_for_state)
        idx += num_statistics
    new_states = []
    for state, new_preconditioners in zip(states, preconditioners_for_states):
      new_states.append(
          ParameterStats(state.diagonal_statistics, state.statistics,
                         new_preconditioners, state.diagonal_momentum,
                         state.momentum))

    return new_states

  def _transform_grad(grad, state, param, step):
    """Transform per-parameter gradients."""
    preconditioner = Preconditioner(param, block_size,
                                    best_effort_shape_interpretation)
    sgd_update = grad
    new_diagonal_statistics = state.diagonal_statistics
    if graft_type == GraftingType.ADAGRAD:
      new_diagonal_statistics = state.diagonal_statistics + jnp.square(grad)
      adagrad_update = grad / (
          jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon)
      grafting_update = adagrad_update
    elif (graft_type == GraftingType.RMSPROP or
          graft_type == GraftingType.RMSPROP_NORMALIZED):

      scaled_grad = grad
      if graft_type == GraftingType.RMSPROP_NORMALIZED:
        scaled_grad = grad / jnp.linalg.norm(grad)

      w1 = beta2
      w2 = beta2 if beta2 == 1.0 else (1.0 - beta2)

      new_diagonal_statistics = (
          w1 * state.diagonal_statistics + w2 * jnp.square(scaled_grad))
      rmsprop_update = scaled_grad / (
          jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon)

      if clip_by_scaled_gradient_norm:
        scaled_grad_norm = jnp.linalg.norm(rmsprop_update) / (
            jnp.sqrt(float(rmsprop_update.size)))
        clipping_denom = jnp.maximum(
            1., scaled_grad_norm / clip_by_scaled_gradient_norm)
        rmsprop_update /= clipping_denom

      grafting_update = rmsprop_update
    else:
      grafting_update = sgd_update

    precond_grad = grad
    if not _skip_preconditioning(param):
      precond_grad = preconditioner.preconditioned_grad(precond_grad,
                                                        state.preconditioners)
    else:
      precond_grad = grafting_update

    # Grafting: rescale the preconditioned direction to the norm of the
    # grafting update.
    grafting_update_norm = jnp.linalg.norm(grafting_update)
    precond_grad_norm = jnp.linalg.norm(precond_grad)

    multiplier = (grafting_update_norm / (precond_grad_norm + 1e-16))
    shampoo_update = precond_grad * multiplier

    shampoo_update_with_wd = shampoo_update
    grafting_update_with_wd = grafting_update
    if weight_decay != 0:
      shampoo_update_with_wd = shampoo_update + weight_decay * param
      grafting_update_with_wd = grafting_update + weight_decay * param

    w = (1.0 - beta1) if moving_average_for_momentum else 1.0
    shampoo_update_with_wd_momentum = (
        state.momentum * beta1 + w * shampoo_update_with_wd)
    grafting_update_with_wd_momentum = (
        state.diagonal_momentum * beta1 + w * grafting_update_with_wd)

    # 1.0 once preconditioning has started, else 0.0 (use grafting update).
    run_shampoo = (step >= start_preconditioning_step).astype(
        grafting_update_with_wd_momentum.dtype)

    momentum_update = (
        run_shampoo * shampoo_update_with_wd_momentum +
        (1.0 - run_shampoo) * grafting_update_with_wd_momentum)

    wd_update = (
        run_shampoo * shampoo_update_with_wd +
        (1.0 - run_shampoo) * grafting_update_with_wd)

    if nesterov:
      momentum_update = w * wd_update + beta1 * momentum_update

    lr = learning_rate
    if callable(learning_rate):
      lr = learning_rate(step)
    transformed_update = -1.0 * lr * momentum_update

    param_stats = ParameterStats(new_diagonal_statistics, state.statistics,
                                 state.preconditioners,
                                 grafting_update_with_wd_momentum,
                                 shampoo_update_with_wd_momentum)
    return transformed_update, param_stats

  def update_fn(grads, state, params):
    """Transform the input gradient and update all statistics.

    Args:
      grads: the gradient tensors for the parameters.
      state: a named tuple containing the state of the optimizer
      params: the parameters that should be updated.

    Returns:
      A tuple containing the transformed updates and the new optimizer state.
    """
    params_flat, treedef = jax.tree_flatten(params)
    stats_flat = treedef.flatten_up_to(state.stats)
    grads_flat = treedef.flatten_up_to(grads)

    new_stats_flat = jax.tree_multimap(
        lambda g, s, p: _compute_stats(g, s, p, state.count), grads_flat,
        stats_flat, params_flat)
    new_stats_flat = _compute_preconditioners(new_stats_flat, params_flat,
                                              state.count)

    outputs = jax.tree_multimap(
        lambda g, s, p: _transform_grad(g, s, p, state.count), grads_flat,
        new_stats_flat, params_flat)
    updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ())

    updates = jax.tree_unflatten(treedef, updates_flat)
    new_stats = jax.tree_unflatten(treedef, new_stats_flat)

    new_state = ShampooState(
        count=state.count + 1, stats=new_stats)
    return updates, new_state

  if shard_optimizer_states:
    return optax.GradientTransformation(sharded_init_fn, sharded_update_fn)
  else:
    return optax.GradientTransformation(init_fn, update_fn)