Code Example #1
def bias_correction(
        decay: float = 0.9,
        accumulator_dtype: Optional[Any] = None
) -> optax.GradientTransformation:
    """Compute the Adam style bias correction.

  Args:
    decay: the decay rate for the exponential moving average.
    accumulator_dtype: optional `dtype` to be used for the accumulator; if `None`
      then the `dtype` is inferred from `params` and `updates`.

  Returns:
    An (init_fn, update_fn) tuple.
  """

    accumulator_dtype = utils.canonicalize_dtype(accumulator_dtype)

    def init_fn(params):
        del params
        return BiasCorrectionState(count=jnp.zeros([], jnp.int32))

    def update_fn(updates, state, params=None):
        del params
        count_inc = utils.safe_int32_increment(state.count)
        new_vals = _bias_correction(updates, decay, count_inc)
        return new_vals, BiasCorrectionState(count=count_inc)

    return optax.GradientTransformation(init_fn, update_fn)
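The `_bias_correction`, `BiasCorrectionState`, and `utils` helpers are imported from elsewhere and not shown here. As a hedged sketch, `_bias_correction` presumably applies the standard Adam debiasing, dividing the EMA by (1 - decay**count):

import jax

def _bias_correction(moment, decay, count):
    # Hypothetical reconstruction of the helper assumed above: standard
    # Adam-style debiasing of an exponential moving average.
    bias_correction = 1 - decay**count
    return jax.tree_map(lambda t: t / bias_correction.astype(t.dtype), moment)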
Code Example #2
def scale_by_adam(b1: float = 0.9,
                  b2: float = 0.999,
                  eps: float = 1e-8,
                  eps_root: float = 0.0,
                  debias: bool = True) -> optax.GradientTransformation:
    """Rescale updates according to the Adam algorithm.

  Args:
    b1: decay rate for the exponentially weighted average of grads.
    b2: decay rate for the exponentially weighted average of squared grads.
    eps: term added to the denominator to improve numerical stability.
    eps_root: term added to the denominator inside the square-root to improve
      numerical stability when backpropagating gradients through the rescaling.
    debias: whether to use bias correction.

  Returns:
    An (init_fn, update_fn) tuple.
  """
    def init_fn(params):
        mu = jax.tree_map(jnp.zeros_like, params)  # First moment
        nu = jax.tree_map(jnp.zeros_like, params)  # Second moment
        return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu)

    def update_fn(updates, state, params=None):
        del params
        mu = _update_moment(updates, state.mu, b1, 1)
        nu = _update_moment(updates, state.nu, b2, 2)
        count = state.count + jnp.array(1, dtype=jnp.int32)
        mu_hat = mu if not debias else _bias_correction(mu, b1, count)
        nu_hat = nu if not debias else _bias_correction(nu, b2, count)
        updates = jax.tree_multimap(
            lambda m, v: m / (jnp.sqrt(v + eps_root) + eps), mu_hat, nu_hat)
        return updates, ScaleByAdamState(count=count, mu=mu, nu=nu)

    return optax.GradientTransformation(init_fn, update_fn)
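As a usage sketch (not part of the original snippet), a transformation like this is normally chained with a learning-rate scale and applied via optax.apply_updates. The example below uses optax's built-in scale_by_adam so it runs standalone:

import jax
import jax.numpy as jnp
import optax

# Adam-like optimizer: rescale by Adam statistics, then by -learning_rate.
tx = optax.chain(optax.scale_by_adam(), optax.scale(-1e-3))

params = {'w': jnp.ones((3,)), 'b': jnp.zeros((3,))}
opt_state = tx.init(params)

grads = jax.tree_map(jnp.ones_like, params)  # placeholder gradients
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)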
Code Example #3
def adaptive_grad_clip(clip, eps=1e-3) -> optax.GradientTransformation:
    """Clip updates to be at most clipping * parameter_norm.

  References:
    [Brock, Smith, De, Simonyan 2021] High-Performance Large-Scale Image
    Recognition Without Normalization.

  Args:
    clip: Maximum allowed ratio of update norm to parameter norm.
    eps: epsilon term to prevent clipping of zero-initialized params.

  Returns:
    An (init_fn, update_fn) tuple.
  """
    def init_fn(_):
        return optax.ClipByGlobalNormState()

    def update_fn(updates, state, params):
        g_norm = jax.tree_map(unitwise_norm, updates)
        p_norm = jax.tree_map(unitwise_norm, params)
        # Maximum allowable norm
        max_norm = jax.tree_map(lambda x: clip * jnp.maximum(x, eps), p_norm)
        # If grad norm > clipping * param_norm, rescale
        updates = jax.tree_multimap(my_clip, g_norm, max_norm, updates)
        return updates, state

    return optax.GradientTransformation(init_fn, update_fn)
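The `unitwise_norm` and `my_clip` helpers are assumed by this snippet but not shown. A rough sketch of what adaptive gradient clipping typically does per leaf (rescale the update wherever its unit-wise norm exceeds the allowed maximum) might look like this; the exact axis convention is an assumption:

import jax.numpy as jnp

def unitwise_norm(x):
    # Hypothetical helper: norm over all but the leading dimension, keeping
    # dims so it broadcasts against `x`; scalars and vectors use their
    # global norm.
    if x.ndim <= 1:
        return jnp.sqrt(jnp.sum(x ** 2))
    axes = tuple(range(1, x.ndim))
    return jnp.sqrt(jnp.sum(x ** 2, axis=axes, keepdims=True))

def my_clip(g_norm, max_norm, grad):
    # Rescale `grad` wherever its unit-wise norm exceeds `max_norm`.
    clipped = grad * (max_norm / jnp.maximum(g_norm, 1e-6))
    return jnp.where(g_norm < max_norm, grad, clipped)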
Code Example #4
File: lars.py  Project: rwightman/efficientnet-jax
def add_weight_decay(
        weight_decay: float,
        filter_fn: Optional[FilterFn] = None) -> optax.GradientTransformation:
    """Adds a weight decay to the update.

    Args:
      weight_decay: weight_decay coefficient.
      filter_fn: an optional filter function.

    Returns:
      An (init_fn, update_fn) tuple.
    """

    def init_fn(_) -> AddWeightDecayState:
        return AddWeightDecayState()

    def update_fn(
            updates: optax.Updates,
            state: AddWeightDecayState,
            params: optax.Params,
    ) -> Tuple[optax.Updates, AddWeightDecayState]:
        new_updates = jax.tree_multimap(lambda g, p: g + weight_decay * p, updates, params)
        new_updates = _partial_update(updates, new_updates, params, filter_fn)
        return new_updates, state

    return optax.GradientTransformation(init_fn, update_fn)
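`_partial_update`, `AddWeightDecayState`, and `FilterFn` come from elsewhere in lars.py. A plausible sketch of `_partial_update`, assuming `filter_fn` is a per-leaf predicate on the parameter (the real signature may differ):

import jax
import jax.numpy as jnp

def _partial_update(updates, new_updates, params, filter_fn=None):
    # Keep `new_updates` where `filter_fn` accepts the parameter and fall
    # back to the original `updates` elsewhere.
    if filter_fn is None:
        return new_updates
    return jax.tree_map(
        lambda u, nu, p: jnp.where(filter_fn(p), nu, u),
        updates, new_updates, params)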
Code Example #5
def scaled_sgld(key: np.ndarray,
                schedule_fn: callable = optax.constant_schedule(1.)):
    """
    Scale SGLD the correct way, using a custom schedule for the stepsize.

        an (init_fn, update_fn) Tuple"""
    scaler = optax.scale_by_schedule(schedule_fn)

    def init_fn(params):
        return ScaledSGLDState(count=0, key=key)

    def update_fn(updates, state, params=None):
        """
        returns
        - stepsize * updates + np.sqrt(2 stepsize) * z,
        where z is standard normal.
        """
        count, key = state
        stepsize = schedule_fn(count)
        count += 1
        updates = jax.tree_map(lambda g: -stepsize * g, updates)
        key, subkey = random.split(key)
        # TODO: either throw error when stepsize < 0 or put np.abs(stepsize)
        # under the square root.
        return add_noise(subkey, updates,
                         np.sqrt(2 * stepsize)), ScaledSGLDState(count=count,
                                                                 key=key)

    return optax.GradientTransformation(init_fn, update_fn)
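`add_noise` and `ScaledSGLDState` are defined elsewhere. Presumably `add_noise(key, updates, scale)` adds independent Gaussian noise with standard deviation `scale` to every leaf, roughly:

import jax
from jax import random

def add_noise(key, updates, scale):
    # Hypothetical sketch: add scale * z to every leaf, z ~ N(0, I).
    leaves, treedef = jax.tree_flatten(updates)
    keys = random.split(key, len(leaves))
    noisy = [g + scale * random.normal(k, g.shape, g.dtype)
             for g, k in zip(leaves, keys)]
    return jax.tree_unflatten(treedef, noisy)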
Code Example #6
def rexp_updater(
    exponent: float = 0.5,
    eps: float = 1e-8,
    eps_root: float = 0.0,
    moment: int = 2,
    use_accumulated_gradient: bool = False,
) -> optax.GradientTransformation:
    """Apply an update function."""
    def init(params: optax.Params) -> optax.OptState:
        del params
        return None

    def update(
        updates: optax.Updates,
        state: optax.OptState,
        params: Optional[optax.Params] = None
    ) -> Tuple[optax.Updates, optax.OptState]:
        del params
        grads = updates['updates'] if not use_accumulated_gradient else updates[
            'moments']['1']

        updates['output'] = jax.tree_multimap(
            lambda u, v: u / (jnp.power(v + eps_root, exponent) + eps), grads,
            updates['moments'][str(moment)])

        return updates, state

    return optax.GradientTransformation(init, update)
Code Example #7
File: complex_valued.py  Project: rbktech/netket
def split_complex(inner: optax.GradientTransformation) -> optax.GradientTransformation:
    """Splits complex parameters into pairs of real parameters.

    The inner transformation processes real parameters and updates, and the
    pairs of transformed real updates are merged into complex updates.

    Parameters that are real before `split_complex` are passed through unmodified.

    Args:
      inner: The inner transformation.

    Returns:
      An `optax.GradientTransformation`.
    """

    def init_fn(params):
        params = jax.tree_map(_complex_to_real_pair, params)
        inner_state = inner.init(params)
        return SplitComplexState(inner_state)

    def update_fn(updates, state, params=None):
        inner_state = state.inner_state
        updates = jax.tree_map(_complex_to_real_pair, updates)
        params = jax.tree_map(_complex_to_real_pair, params)
        updates, inner_state = inner.update(updates, inner_state, params)
        updates = jax.tree_map(_real_pair_to_complex, updates, is_leaf=_is_real_pair)
        return updates, SplitComplexState(inner_state)

    return optax.GradientTransformation(init_fn, update_fn)
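The `_complex_to_real_pair` / `_real_pair_to_complex` helpers and the `_is_real_pair` predicate live elsewhere in complex_valued.py. A minimal sketch of the idea, splitting each complex leaf into a small container of real and imaginary parts and merging it back after the inner update (the container type is an assumption):

from typing import NamedTuple
import jax.numpy as jnp

class _RealPair(NamedTuple):
    real: jnp.ndarray
    imag: jnp.ndarray

def _complex_to_real_pair(x):
    # Complex leaves become a pytree node of two real leaves; real leaves
    # pass through unmodified.
    if jnp.iscomplexobj(x):
        return _RealPair(real=x.real, imag=x.imag)
    return x

def _real_pair_to_complex(x):
    return x.real + 1j * x.imag if isinstance(x, _RealPair) else x

def _is_real_pair(x):
    return isinstance(x, _RealPair)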
Code Example #8
def scale_by_variable_opt(multipliers):
    """Custom learning rates for different variables.

  Args:
    multipliers: a pytree, with the same structure as `params`. Each leaf can
      be either a float, or an array shape-compatible with the corresponding
      `params` element. These multiply the learning rate for each leaf.

  Returns:
    optax.GradientTransformation optimizer
  """
    def init_fn(params):
        params_struct = jax.tree_map(lambda _: None, params)
        multipliers_struct = jax.tree_map(lambda _: None, multipliers)
        assert params_struct == multipliers_struct, (
            'multipliers should have same struct as params')
        return None

    def update_fn(updates, _, params=None):
        del params  # Unused.
        scaled_updates = jax.tree_multimap(lambda a, g: a * g, multipliers,
                                           updates)
        return scaled_updates, None

    return optax.GradientTransformation(init_fn, update_fn)
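A brief usage sketch (assuming the scale_by_variable_opt definition above is in scope): per-leaf multipliers let one part of the model learn faster than another when chained with a base optimizer.

import jax.numpy as jnp
import optax

params = {'encoder': jnp.ones((4,)), 'head': jnp.ones((2,))}
multipliers = {'encoder': 1.0, 'head': 10.0}  # train the head 10x faster

tx = optax.chain(scale_by_variable_opt(multipliers),
                 optax.sgd(learning_rate=0.1))
opt_state = tx.init(params)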
Code Example #9
def ema_accumulator(decay: float = 0.999,
                    debias: bool = False) -> optax.GradientTransformation:
    """Create accumulator that computes EMA on all updates."""
    def init(params: optax.Params) -> optax.OptState:
        return (jax.tree_map(jnp.zeros_like,
                             params), jnp.array(0, dtype=jnp.int32))

    def update(
        updates: optax.Updates,
        state: optax.OptState,
        params: Optional[optax.Params] = None
    ) -> Tuple[optax.Updates, optax.OptState]:
        del params

        moments, count = state
        update_fn = lambda g, t: (1 - decay) * g + decay * t
        moments = jax.tree_map(update_fn, updates['variables'], moments)

        count = count + jnp.array(1, dtype=jnp.int32)
        beta = jnp.array(1, dtype=jnp.int32) - decay**count
        updates['moments'] = moments if not debias else jax.tree_map(
            lambda t: t / beta.astype(t.dtype), moments)

        return updates, (moments, count)

    return optax.GradientTransformation(init, update)
Code Example #10
def add_weight_decay(
    weight_decay: float,
    exclude_names: Optional[List[str]] = None) -> optax.GradientTransformation:
  """Add parameter scaled by `weight_decay` to the `updates`.

  Same as optax.add_decayed_weights but can exclude parameters by name.

  Args:
    weight_decay: weight_decay coefficient.
    exclude_names: an optional list of names to exclude for weight_decay. ['b']
      by default.

  Returns:
    An (init_fn, update_fn) tuple.
  """

  def init_fn(_):
    return AddWeightDecayState()

  def update_fn(updates, state, params):
    exclude = _weight_decay_exclude(exclude_names=exclude_names)

    u_ex, u_in = hk.data_structures.partition(exclude, updates)
    _, p_in = hk.data_structures.partition(exclude, params)
    u_in = jax.tree_multimap(lambda g, p: g + weight_decay * p, u_in, p_in)
    updates = hk.data_structures.merge(u_ex, u_in)
    return updates, state

  return optax.GradientTransformation(init_fn, update_fn)
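`_weight_decay_exclude` and `AddWeightDecayState` are defined elsewhere. Presumably the former builds a Haiku partition predicate matching the excluded parameter names (defaulting to ['b'], per the docstring); a hedged sketch:

from typing import List, NamedTuple, Optional

class AddWeightDecayState(NamedTuple):
  """Stateless transformation whose state is an empty namedtuple."""

def _weight_decay_exclude(exclude_names: Optional[List[str]] = None):
  # Hypothetical reconstruction: a predicate compatible with
  # hk.data_structures.partition(module_name, name, value).
  exclude_names = exclude_names or ['b']

  def exclude(module_name, name, value):
    del module_name, value
    return name in exclude_names

  return exclude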
Code Example #11
def precondition_by_amsgrad(
    b2: float = 0.999,
    eps: float = 1e-8,
    eps_root: float = 0.0,
) -> optax.GradientTransformation:
    """Rescale updates according to the AMSGrad algorithm.

  References:
    [Reddi et al, 2018](https://arxiv.org/abs/1904.09237v1)

  Args:
    b2: decay rate for the exponentially weighted average of squared grads.
    eps: term added to the denominator to improve numerical stability.
    eps_root: term added to the denominator inside the square-root to improve
      numerical stability when backpropagating gradients through the rescaling.

  Returns:
    An (init_fn, update_fn) tuple.
  """
    def init_fn(params):
        return PreconditionBySecondMomentCoordinateWiseState(
            count=jnp.zeros([], jnp.int32),
            nu=jax.tree_map(jnp.zeros_like, params))

    def update_fn(updates, state, params=None):
        del params
        nu = _update_moment(updates, state.nu, b2, 2)
        count = state.count + jnp.array(1, dtype=jnp.int32)
        nu_hat = jax.tree_multimap(jnp.maximum, nu, state.nu)
        updates = jax.tree_multimap(
            lambda m, v: m / (jnp.sqrt(v + eps_root) + eps), updates, nu_hat)
        return updates, PreconditionBySecondMomentCoordinateWiseState(
            count=count, nu=nu)

    return optax.GradientTransformation(init_fn, update_fn)
Code Example #12
def nth_power(
        power: Union[int, Tuple[int]] = 2) -> optax.GradientTransformation:
    """Create nth power(s) from gradients."""

    if not hasattr(power, '__iter__'):
        power = [power]

    for p in power:
        if p != int(p):
            raise ValueError(
                f'Currently we only support integer orders; got {p}.')

    def init(params: optax.Params) -> optax.OptState:
        del params
        return None

    def update(
        updates: optax.Updates,
        state: optax.OptState,
        params: Optional[optax.Params] = None
    ) -> Tuple[optax.Updates, optax.OptState]:
        del params

        for p in power:
            if p == 1:
                updates['variables'][str(int(p))] = updates['updates']
            else:
                gradients = jax.tree_map(lambda x: x**p, updates['updates'])  # pylint: disable=cell-var-from-loop
                updates['variables'][str(int(p))] = gradients

        return updates, state

    return optax.GradientTransformation(init, update)
Code Example #13
def scale_selected_parameters(regex, multiplier):
    """Creates an optimizer that multiply a selection of weights by `multiplier`.

  Args:
    regex: The regex that matches the (flattened) parameters whose learning rate
      should be scaled.
    multiplier: The scaling factor applied to matching parameters.

  Returns:
    A chainable optimizer.
  """
    def init_fn(params):
        flatten_params = parameter_overview.flatten_dict(params)
        multipliers = {
            k: multiplier if re.match(regex, k) else 1.
            for k, _ in flatten_params.items()
        }
        logging.info(
            "Optimizer uses multiplier %f for weights: %s", multiplier,
            sorted(
                [k for k, _ in flatten_params.items() if re.match(regex, k)]))
        multipliers = unflatten_dict(multipliers)
        return MultipliersState(multipliers)

    def update_fn(updates, state, params=None):
        del params
        multiplied_updates = jax.tree_multimap(
            lambda m, update: jax.tree_map(lambda u: u * m, update),
            state.multipliers, updates)
        return multiplied_updates, state

    return optax.GradientTransformation(init_fn, update_fn)
Code Example #14
def decoupled_weight_decay(decay, step_size_fn):
    """Adds decay * step_size_fn(count) to updates.

  Args:
    decay: the decay coefficient for weight decay.
    step_size_fn: a function that takes an update count as input and proposes
      the step_size to multiply the params by.

  Returns:
    An (init_fn, update_fn) tuple.
  """
    def init_fn(_):
        return DecoupledWeightDecayState(count=jnp.zeros([], jnp.int32),
                                         step_size=jnp.zeros([], jnp.float32))

    def update_fn(updates, state, params=None):
        step_size = step_size_fn(state.count) * decay
        updates = jax.tree_multimap(lambda u, p: u - step_size * p, updates,
                                    params)

        # does a _safe_int32_increment
        max_int32_value = jnp.iinfo(jnp.int32).max
        new_count = jnp.where(state.count < max_int32_value, state.count + 1,
                              max_int32_value)
        new_state = DecoupledWeightDecayState(count=new_count,
                                              step_size=step_size)

        return updates, new_state

    return optax.GradientTransformation(init_fn, update_fn)
Code Example #15
File: test_transform.py  Project: google/init2winit
        def twisted_adam():
            def init_fn(params):
                return State(nu=jax.tree_map(jnp.zeros_like, params),
                             trace=jax.tree_map(jnp.zeros_like, params),
                             count=jnp.zeros([], jnp.int32))

            def update_fn(updates, state, params=None):
                del params
                count = state.count + jnp.array(1, jnp.int32)
                nu = {
                    'w': (1 - rms_decay) * (updates['w']**2) +
                    rms_decay * state.nu['w']
                }
                updates = {
                    'w':
                    updates['w'] / (jax.lax.sqrt(nu['w'] + eps_root) + eps)
                }

                updates = {
                    'w': updates['w'] * jnp.sqrt((1 - rms_decay**count))
                }

                trace = {
                    'w': (1 - moment_decay) * updates['w'] +
                    moment_decay * state.trace['w']
                }
                updates = {'w': trace['w']}

                updates = {'w': updates['w'] / (1 - moment_decay**count)}

                return updates, State(nu=nu, count=count, trace=trace)

            return optax.GradientTransformation(init_fn, update_fn)
Code Example #16
def transform_chain(
    elements: List[str],
    hps: List[Dict[str, float]] = None,
    masks: List[Any] = None,
    learning_rate: float = None) -> optax.GradientTransformation:
  """Utility function for chaining GradientTransforms based on string names.

  Args:
    elements: list of transform strings.
    hps: list of dicts of args for each transform.
    masks: list of masks for each transform.
    learning_rate: learning rate that gets injected.

  Returns:
    An optax.GradientTransformation.
  """

  hps = hps or [{}] * len(elements)
  masks = masks or [None] * len(elements)
  transforms = []

  if len(hps) != len(elements):
    raise ValueError('Number of hps must equal number of elements.')

  if len(masks) != len(elements):
    raise ValueError('Number of masks must equal number of elements.')

  transforms = [_transformations[el](**hp) for el, hp in zip(elements, hps)]

  for i, (transform, mask) in enumerate(zip(transforms, masks)):
    if mask is not None:
      transforms[i] = optax.masked(transform, mask)

  if learning_rate is not None:
    transforms += [scale_by_learning_rate(learning_rate)]

  init_fn, update_fn = optax.chain(*transforms)

  # NOTE(dsuo): We use plain dicts internally due to this issue
  # https://github.com/deepmind/optax/issues/160.
  def wrapped_init_fn(params):
    return init_fn(flax.core.unfreeze(params))

  def wrapped_update_fn(updates, state, params=None):
    new_updates, state = update_fn(
        flax.core.unfreeze(updates), state,
        None if params is None else flax.core.unfreeze(params))

    if isinstance(updates, flax.core.FrozenDict):
      new_updates = flax.core.freeze(new_updates)

    return new_updates, state

  return optax.GradientTransformation(wrapped_init_fn, wrapped_update_fn)
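`_transformations` and `scale_by_learning_rate` are module-level objects not shown here. `_transformations` is presumably just a string-to-constructor registry; a minimal illustration of its shape (keys and entries here are assumptions, not the real registry):

import optax

_transformations = {
    'scale_by_adam': optax.scale_by_adam,
    'add_decayed_weights': optax.add_decayed_weights,
    'clip_by_global_norm': optax.clip_by_global_norm,
}

def scale_by_learning_rate(learning_rate):
  # Negative sign: gradient descent rather than ascent.
  return optax.scale(-learning_rate)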
Code Example #17
File: test_transform.py  Project: google/init2winit
        def ema(decay, debias=True):
            def init_fn(params):
                del params
                return {'w': jnp.zeros((2, )), 'count': 0}

            def update_fn(updates, state, params=None):
                del params
                state['count'] += 1
                state['w'] = ((1 - decay) * updates['w'] + decay * state['w'])
                if debias:
                    update = {'w': state['w'] / (1 - decay**state['count'])}
                else:
                    update = {'w': state['w']}
                return update, state

            return optax.GradientTransformation(init_fn, update_fn)
Code Example #18
File: optimizers.py  Project: afcarl/google-research
def maybe_skip_gradient_update(
    inner,
    gradient_norm_skip_threshold,
):
    """A function that wraps an optimiser to skip updates under some condition.

  The purpose of this function is to prevent any optimisation from happening
  if the gradients contain NaNs or Infs, or if their norm is higher than a
  certain threshold. That is, when a NaN or Inf is detected in the gradients,
  or when the norm of the gradient is higher than the threshold, the wrapped
  optimiser ignores that gradient update.

  Args:
    inner: Inner transformation to be wrapped.
    gradient_norm_skip_threshold: float, the gradient norm above which the
      update is skipped.

  Returns:
    New GradientTransformation.
  """
    def init(params):
        return MaybeSkipGradientUpdateState(inner_state=inner.init(params))

    def update(updates, state, params=None):
        inner_state = state.inner_state
        # Compute gradient norm and clip gradient if necessary
        gradient_norm = optax.global_norm(updates)
        flat_updates = jax.tree_flatten(updates)[0]
        isfinite = jnp.all(
            jnp.array([jnp.all(jnp.isfinite(p)) for p in flat_updates]))
        islowerthan = gradient_norm < gradient_norm_skip_threshold

        def do_update(_):
            return inner.update(updates, inner_state, params)

        def reject_update(_):
            return (jax.tree_map(jnp.zeros_like, updates), inner_state)

        updates, new_inner_state = jax.lax.cond(jnp.logical_and(
            isfinite, islowerthan),
                                                do_update,
                                                reject_update,
                                                operand=None)

        return updates, MaybeSkipGradientUpdateState(
            inner_state=new_inner_state)

    return optax.GradientTransformation(init=init, update=update)
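A short usage sketch (assuming the definition above and its `MaybeSkipGradientUpdateState` namedtuple are in scope): wrap any inner optimiser so that non-finite or oversized gradients leave the parameters untouched.

import jax.numpy as jnp
import optax

inner = optax.adam(learning_rate=1e-3)
tx = maybe_skip_gradient_update(inner, gradient_norm_skip_threshold=10.0)

params = {'w': jnp.zeros((3,))}
state = tx.init(params)

bad_grads = {'w': jnp.array([jnp.inf, 0.0, 0.0])}
updates, state = tx.update(bad_grads, state, params)
# `updates` is all zeros here, so applying it leaves `params` unchanged.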
Code Example #19
def scale_by_nadam(b1: float = 0.9,
                   b2: float = 0.999,
                   eps: float = 1e-8,
                   eps_root: float = 0.0,
                   debias: bool = True) -> optax.GradientTransformation:
    """Rescale updates according to the NAdam algorithm.

  References:
  There seem to be multiple versions of NAdam. The original version is here:
  https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the PyTorch
  implementation also follows this).

  The current code implements a simpler version with no momentum decay and
  slightly different bias correction terms. The exact description can be found
  in https://arxiv.org/pdf/1910.05446.pdf (Table 1).

  Args:
    b1: decay rate for the exponentially weighted average of grads.
    b2: decay rate for the exponentially weighted average of squared grads.
    eps: term added to the denominator to improve numerical stability.
    eps_root: term added to the denominator inside the square-root to improve
      numerical stability when backpropagating gradients through the rescaling.
    debias: whether to use bias correction.

  Returns:
    An (init_fn, update_fn) tuple.
  """
    def init_fn(params):
        mu = jax.tree_map(jnp.zeros_like, params)  # First moment
        nu = jax.tree_map(jnp.zeros_like, params)  # Second moment
        return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu)

    def update_fn(updates, state, params=None):
        del params
        mu = _update_moment(updates, state.mu, b1, 1)
        nu = _update_moment(updates, state.nu, b2, 2)
        count = state.count + jnp.array(1, dtype=jnp.int32)
        mu_hat = _update_moment(updates, mu, b1, 1)
        mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count)
        nu_hat = nu if not debias else _bias_correction(nu, b2, count)
        updates = jax.tree_multimap(
            lambda m, v: m / (jnp.sqrt(v + eps_root) + eps), mu_hat, nu_hat)
        return updates, ScaleByAdamState(count=count, mu=mu, nu=nu)

    return optax.GradientTransformation(init_fn, update_fn)
Code Example #20
def preconditioner(
    grad_transformer,
    accumulator,
    updater,
    moment_creator_args,
    accumulator_args,
    updater_args,
) -> optax.GradientTransformation:
    """Generic precondition update function."""

    grad_transformer = grad_transformer(**moment_creator_args)
    accumulator = accumulator(**accumulator_args)
    updater = updater(**updater_args)

    def init(params: optax.Params) -> optax.OptState:
        """`init` function."""
        grads_state = grad_transformer.init(params)

        # NOTE(dsuo): assumes params and updates have the same shape.
        updates = {'updates': params, 'variables': {}}
        grads, _ = grad_transformer.update(updates, grads_state, params)

        # NOTE(dsuo): assume accumulator only needs `gradients`.
        accumulator_state = accumulator.init(grads['variables'])

        updater_state = updater.init(params)
        return (grads_state, accumulator_state, updater_state)

    def update(updates, state, params=None):
        """`update` function."""
        updates = {
            'updates': updates,
            'variables': {},
            'moments': {},
            'output': None
        }
        new_state = []
        for s, transform in zip(state,
                                [grad_transformer, accumulator, updater]):
            updates, new_s = transform.update(updates, s, params)
            new_state.append(new_s)

        return updates['output'], tuple(new_state)

    return optax.GradientTransformation(init, update)
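A wiring sketch (assuming the nth_power, ema_accumulator, rexp_updater, and preconditioner examples above are in scope): composed this way, the three stages divide the gradient by the square root of an EMA of its elementwise square, i.e. an RMSProp-style preconditioner.

import jax.numpy as jnp
import optax

tx = preconditioner(
    grad_transformer=nth_power,
    accumulator=ema_accumulator,
    updater=rexp_updater,
    moment_creator_args={'power': 2},
    accumulator_args={'decay': 0.999},
    updater_args={'exponent': 0.5, 'moment': 2},
)

params = {'w': jnp.ones((3,))}
state = tx.init(params)
grads = {'w': jnp.full((3,), 0.5)}
updates, state = tx.update(grads, state, params)
# Chain with optax.scale(-learning_rate) to turn this into a full optimizer.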
Code Example #21
File: attacks.py  Project: yynst2/deepmind-research
    def __init__(self,
                 learning_rate_fn: Union[float, int, optax.Schedule],
                 normalize_fn: Optional[NormalizeFn] = None):
        # Accept schedules, as well as scalar values.
        if isinstance(learning_rate_fn, (float, int)):
            lr = float(learning_rate_fn)
            learning_rate_fn = lambda _: lr
        # Normalization.
        def update_fn(updates, state, params=None):
            del params
            updates = jax.tree_map(normalize_fn or (lambda x: x), updates)
            return updates, state

        gradient_transformation = optax.chain(
            optax.GradientTransformation(lambda _: optax.EmptyState(),
                                         update_fn),
            optax.scale_by_schedule(learning_rate_fn), optax.scale(-1.))
        super(SGD, self).__init__(gradient_transformation)
Code Example #22
def kitchen_sink(chains: List[optax.GradientTransformation],
                 scales: jnp.array = None,
                 combinator: Union[Callable[[Any, Any], Any], str] = 'sum',
                 combinator_args: Dict[str, float] = None,
                 learning_rate: float = None) -> optax.GradientTransformation:
  """Runs a list of GradientTransforms in parallel and combines.

  Args:
    chains: list of optax.GradientTransforms (typically from transform_chain).
    scales: a (len(chains),)-shaped jnp.array.
    combinator: a combinator that reduces a list of identical pytrees.
    combinator_args: a dictionary of keyword arguments to the combinator func.
    learning_rate: learning rate that gets injected.

  Returns:
    An optax.GradientTransformation.
  """
  if isinstance(combinator, str):
    combinator = _combinators.get(combinator, _sum_combinator)
  combinator_args = combinator_args or {}

  if scales is None:
    scales = jnp.ones(len(chains))

  chains = [
      optax.chain(chain, optax.scale(scale))
      for chain, scale in zip(chains, scales)
  ]

  def init_fn(params):
    return [chain.init(params) for chain in chains]

  def update_fn(updates, state, params=None):
    result = [chain.update(updates, chain_state, params)
              for chain, chain_state in zip(chains, state)]
    new_updates, new_state = list(zip(*result))
    return combinator(*new_updates, **combinator_args), new_state

  transform = optax.GradientTransformation(init_fn, update_fn)

  if learning_rate is not None:
    transform = optax.chain(transform, scale_by_learning_rate(learning_rate))

  return transform
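`_combinators`, `_sum_combinator`, and `scale_by_learning_rate` live elsewhere in the module. The default combinator presumably just sums the per-chain update pytrees leaf by leaf, e.g.:

import jax

def _sum_combinator(*updates):
  # Leaf-wise sum across an arbitrary number of identically structured pytrees.
  return jax.tree_multimap(lambda *leaves: sum(leaves), *updates)

_combinators = {'sum': _sum_combinator}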
Code Example #23
def precondition_by_yogi(b2: float = 0.999,
                         eps: float = 1e-8,
                         eps_root: float = 0.0,
                         initial_accumulator_value: float = 1e-6,
                         debias: bool = True) -> optax.GradientTransformation:
    """Preconditions updates according to the Yogi Preconditioner.

  References:
    [Zaheer et al, 2018](https://papers.nips.cc/paper/2018/hash/90365351ccc7437a1309dc64e4db32a3-Abstract.html) #pylint:disable=line-too-long

  Args:
    b2: decay rate for the exponentially weighted average of moments of grads.
    eps: Term added to the denominator to improve numerical stability.
      The default is changed to 1e-8. Optax Yogi's default is 1e-3.
    eps_root: term added to the denominator inside the square-root to improve
      numerical stability when backpropagating gradients through the rescaling.
    initial_accumulator_value: The starting value for accumulators.
      Only positive values are allowed.
    debias: whether to use bias correction or not

  Returns:
    An (init_fn, update_fn) tuple.
  """
    def init_fn(params):
        value_like = lambda p: jnp.full_like(p, initial_accumulator_value)
        nu = jax.tree_map(value_like, params)  # Second Central moment
        return PreconditionBySecondMomentCoordinateWiseState(count=jnp.zeros(
            [], jnp.int32),
                                                             nu=nu)

    def update_fn(updates, state, params=None):
        del params
        nu = jax.tree_multimap(
            lambda g, v: v - (1 - b2) * jnp.sign(v - g**2) * (g**2), updates,
            state.nu)
        count = state.count + jnp.array(1, dtype=jnp.int32)
        nu_hat = nu if not debias else _bias_correction(nu, b2, count)
        updates = jax.tree_multimap(
            lambda u, v: u / (jnp.sqrt(v + eps_root) + eps), updates, nu_hat)
        return updates, PreconditionBySecondMomentCoordinateWiseState(
            count=count, nu=nu)

    return optax.GradientTransformation(init_fn, update_fn)
Code Example #24
def precondition_by_rms(
    decay: float = 0.999,
    eps: float = 1e-8,
    eps_root: float = 0.0,
    debias: bool = False,
) -> optax.GradientTransformation:
    """Preconditions updates according to the RMS Preconditioner from Adam.

  References:
    [Kingma, Ba 2015] https://arxiv.org/pdf/1412.6980.pdf

  Args:
    decay: decay rate for exponentially weighted average of moments of grads.
    eps: Term added to the denominator to improve numerical stability.
      The default is kept to 1e-8 to match optax Adam implementation.
    eps_root: term added to the denominator inside the square-root to improve
      numerical stability when backpropagating gradients through the rescaling.
    debias: whether to use bias correction or not

  Gotcha:
    Note that the usage of epsilon and the defaults differ from optax's
    scale_by_rms. This matches optax's Adam template.

  Returns:
    An (init_fn, update_fn) tuple.
  """
    def init_fn(params):
        return PreconditionBySecondMomentCoordinateWiseState(
            count=jnp.zeros([], jnp.int32),
            nu=jax.tree_map(jnp.zeros_like, params))

    def update_fn(updates, state, params=None):
        del params
        nu = _update_moment(updates, state.nu, decay, 2)
        count = state.count + jnp.array(1, dtype=jnp.int32)
        nu_hat = nu if not debias else _bias_correction(nu, decay, count)
        updates = jax.tree_multimap(
            lambda u, v: u / (jnp.sqrt(v + eps_root) + eps), updates, nu_hat)
        return updates, PreconditionBySecondMomentCoordinateWiseState(
            count=count, nu=nu)

    return optax.GradientTransformation(init_fn, update_fn)
Code Example #25
File: transform.py  Project: rbktech/netket
def scale_by_adam(
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-8,
    eps_root: float = 0,
    mu_dtype=None,
) -> optax.GradientTransformation:
    """Rescale updates according to the Adam algorithm.
    References:
      [Kingma et al, 2014](https://arxiv.org/abs/1412.6980)
    Args:
      b1: decay rate for the exponentially weighted average of grads.
      b2: decay rate for the exponentially weighted average of squared norm of grads.
      eps: term added to the denominator to improve numerical stability.
      eps_root: term added to the denominator inside the square-root to improve
        numerical stability when backpropagating gradients through the rescaling.
      mu_dtype: optional `dtype` to be used for the first order accumulator; if
        `None` then the `dtype` is inferred from `params` and `updates`.
    Returns:
      An (init_fn, update_fn) tuple.
    """
    mu_dtype = utils.canonicalize_dtype(mu_dtype)

    def init_fn(params):
        mu = jax.tree_map(lambda t: jnp.zeros_like(t, dtype=mu_dtype), params)
        nu = jax.tree_map(jnp.zeros_like, params)
        return optax.ScaleByAdamState(count=jnp.zeros([], jnp.int32),
                                      mu=mu,
                                      nu=nu)

    def update_fn(updates, state, params=None):
        del params
        mu = _update_moment(updates, state.mu, b1, 1)
        nu = _update_moment_norm(updates, state.nu, b2, 2)
        count_inc = utils.safe_int32_increment(state.count)
        mu_hat = utils.cast_tree(_bias_correction(mu, b1, count_inc), mu_dtype)
        nu_hat = _bias_correction(nu, b2, count_inc)
        updates = jax.tree_map(lambda m, v: m / (jnp.sqrt(v + eps_root) + eps),
                               mu_hat, nu_hat)
        return updates, optax.ScaleByAdamState(count=count_inc, mu=mu, nu=nu)

    return optax.GradientTransformation(init_fn, update_fn)
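`_update_moment` and `_update_moment_norm` are helpers defined elsewhere in netket's transform.py. Presumably the `_norm` variant accumulates |g|**order so the second moment stays real even for complex-valued gradients; a sketch of both:

import jax
import jax.numpy as jnp

def _update_moment(updates, moments, decay, order):
    # EMA of the elementwise order-th power of the updates.
    return jax.tree_map(
        lambda g, t: (1 - decay) * (g ** order) + decay * t, updates, moments)

def _update_moment_norm(updates, moments, decay, order):
    # Same, but on |g|**order, which keeps the moment real-valued when the
    # gradients are complex.
    return jax.tree_map(
        lambda g, t: (1 - decay) * (jnp.abs(g) ** order) + decay * t,
        updates, moments)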
Code Example #26
File: test_transform.py  Project: google/init2winit
        def amsgrad():
            adam = optax.scale_by_adam()

            def init_fn(params):
                return adam.init(params)

            def update_fn(updates, state, params=None):
                prev_nu = state.nu
                _, state = adam.update(updates, state, params)
                curr_nu = state.nu
                nu_hat = jax.tree_multimap(jnp.maximum, curr_nu, prev_nu)
                updates = jax.tree_multimap(
                    lambda m, v: m / (jnp.sqrt(v + 0.0) + 1e-8), state.mu,
                    nu_hat)

                return updates, optax.ScaleByAdamState(count=state.count,
                                                       mu=state.mu,
                                                       nu=nu_hat)

            return optax.GradientTransformation(init_fn, update_fn)
Code Example #27
File: lars.py  Project: rwightman/efficientnet-jax
def scale_by_lars(
        momentum: float = 0.9,
        eta: float = 0.001,
        filter_fn: Optional[FilterFn] = None) -> optax.GradientTransformation:
    """Rescales updates according to the LARS algorithm.

    Does not include weight decay.
    References:
        [You et al, 2017](https://arxiv.org/abs/1708.03888)

    Args:
        momentum: momentum coefficient.
        eta: LARS coefficient.
        filter_fn: an optional filter function.

    Returns:
        An (init_fn, update_fn) tuple.
    """

    def init_fn(params: optax.Params) -> ScaleByLarsState:
        mu = jax.tree_multimap(jnp.zeros_like, params)  # momentum
        return ScaleByLarsState(mu=mu)

    def update_fn(updates: optax.Updates, state: ScaleByLarsState,
                  params: optax.Params) -> Tuple[optax.Updates, ScaleByLarsState]:
        def lars_adaptation(
                update: jnp.ndarray,
                param: jnp.ndarray,
        ) -> jnp.ndarray:
            param_norm = jnp.linalg.norm(param)
            update_norm = jnp.linalg.norm(update)
            return update * jnp.where(
                param_norm > 0.,
                jnp.where(update_norm > 0, (eta * param_norm / update_norm), 1.0), 1.0)

        adapted_updates = jax.tree_multimap(lars_adaptation, updates, params)
        adapted_updates = _partial_update(updates, adapted_updates, params, filter_fn)
        mu = jax.tree_multimap(lambda g, t: momentum * g + t, state.mu, adapted_updates)
        return mu, ScaleByLarsState(mu=mu)

    return optax.GradientTransformation(init_fn, update_fn)
Code Example #28
def scale_by_amsgrad(
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-8,
    eps_root: float = 0.0,
    mu_dtype: Optional[Any] = None,
) -> optax.GradientTransformation:
    """Rescale updates according to the AMSGrad algorithm.

  References:
    [Reddi et al, 2018](https://arxiv.org/abs/1904.09237v1)

  Args:
    b1: decay rate for the exponentially weighted average of grads.
    b2: decay rate for the exponentially weighted average of squared grads.
    eps: term added to the denominator to improve numerical stability.
    eps_root: term added to the denominator inside the square-root to improve
      numerical stability when backpropagating gradients through the rescaling.
    mu_dtype: optional `dtype` to be used for the first order accumulator; if
      `None` then the `dtype` is inferred from `params` and `updates`.

  Returns:
    An (init_fn, update_fn) tuple.
  """
    def init_fn(params):
        mu = jax.tree_map(  # First moment
            lambda t: jnp.zeros_like(t, dtype=mu_dtype), params)
        nu = jax.tree_map(jnp.zeros_like, params)  # Second moment
        return ScaleByAMSGradState(mu=mu, nu=nu)

    def update_fn(updates, state, params=None):
        del params
        mu = _update_moment(updates, state.mu, b1, 1)
        nu = _update_moment(updates, state.nu, b2, 2)
        nu_hat = jax.tree_multimap(jnp.maximum, nu, state.nu)
        updates = jax.tree_multimap(
            lambda m, v: m / (jnp.sqrt(v + eps_root) + eps), mu, nu_hat)
        return updates, ScaleByAMSGradState(mu=mu, nu=nu)

    return optax.GradientTransformation(init_fn, update_fn)
Code Example #29
def get_transformations_with_vectorized_repeat_prefix(
        tx: GeneralGradientTransformation,
        var_params: NestedParams) -> GeneralGradientTransformation:
    """Vectorizes a transformation on shape/sharding prefixes."""
    def _init(variables):
        return init_with_vectorized_repeat_prefix(tx, variables, var_params)

    def _update(updates, state, params=None):
        return update_with_vectorized_repeat_prefix(tx, updates, state, params,
                                                    var_params)

    def _init_partition_spec(var_param_args):
        assert isinstance(tx, ShardedGradientTransformation)
        return init_partition_spec_with_vectorized_repeat_prefix(
            tx, var_param_args)

    if isinstance(tx, ShardedGradientTransformation):
        return ShardedGradientTransformation(
            init=_init,
            update=_update,
            init_partition_spec=_init_partition_spec)
    else:
        assert isinstance(tx, optax.GradientTransformation)
        return optax.GradientTransformation(init=_init, update=_update)
Code Example #30
def distributed_shampoo(learning_rate,
                        block_size,
                        beta1=0.9,
                        beta2=0.999,
                        diagonal_epsilon=1e-10,
                        matrix_epsilon=1e-6,
                        weight_decay=0.0,
                        start_preconditioning_step=1,
                        preconditioning_compute_steps=1,
                        statistics_compute_steps=1,
                        best_effort_shape_interpretation=True,
                        graft_type=GraftingType.SGD,
                        nesterov=True,
                        exponent_override=0,
                        batch_axis_name=None,
                        mesh_axis_names=None,
                        num_devices_for_pjit=None,
                        shard_optimizer_states=False,
                        inverse_failure_threshold=0.1,
                        moving_average_for_momentum=False,
                        skip_preconditioning_dim_size_gt=4096,
                        clip_by_scaled_gradient_norm=None,
                        precision=lax.Precision.HIGHEST):
  """Distributed Shampoo optimizer.

  Distributed Shampoo is a second-order preconditioned method (concretely, a
  variant of full-matrix Adagrad), that provides significant convergence and
  wall-clock time improvements compared to conventional first-order methods,
  and that has been shown to scale to large state-of-the-art deep learning
  models.

  References:
    Scalable Second Order Optimization for Deep Learning,
    Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer

    Preprint: https://arxiv.org/abs/2002.09018

  Args:
    learning_rate: the step size used to update the parameters.
    block_size: Block size for large layers (if > 0). Preconditioning compute
      operation is cubic in the dimension of the tensor. Block size allows us to
      chunk the layers into sub-layers of maximal dimension dictated by this
      value. Use 128 as default (increase if you have compute budget).
    beta1: momentum parameter.
    beta2: second moment averaging parameter.
    diagonal_epsilon: epsilon for diagonal adagrad (only if layerwise grafting
      to AdaGrad is enabled).
    matrix_epsilon: epsilon to add to statistics before computing inverse pth
      root. If you are running in f32 precision for inverse pth root
      (recommended today) this can go up to 1e-6. If you have the latest
      hardware with native f64 precision, set this up to 1e-12.
    weight_decay: Weight decay for regularization.
    start_preconditioning_step: When to start the Shampoo update; before this
      step the diagonal update is used. This is because we don't have enough
      information to compute a stable inverse.
    preconditioning_compute_steps: How often to compute preconditioner.
      Performance tuning params for controlling memory and compute requirements.
      Ideally set this and statistics_compute_steps params to 1.
    statistics_compute_steps: How often to compute statistics.
    best_effort_shape_interpretation: If there are some small dimensions,
      collapse them e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if
      block = 1024, [1, 2, 768, 1, 2048] --> [2, 768, 2048]
    graft_type: Grafting is a technique to fix the layerwise scale of Shampoo
      optimizer. This allows us to plugin the Shampoo optimizer into settings
      where SGD/AdaGrad is already well tuned. Available options are:
        GraftingType.SGD and GraftingType.ADAGRAD.
    nesterov: Nesterov momentum.
    exponent_override: Override the exponent used in matrix inverse.
    batch_axis_name: labeled pmap axis used for data-parallel training with
      this optimizer.
    mesh_axis_names: Axis names for the mesh (used in pjit).
    num_devices_for_pjit: Number of devices to parallelize over when using pjit.
    shard_optimizer_states: Shard optimizer states to save memory in model
      parallel training.
    inverse_failure_threshold: numerics are hard and inverses fail sometimes; we
      determine that using this threshold.
    moving_average_for_momentum: Whether to use moving average for momentum
      instead of exponential moving average.
    skip_preconditioning_dim_size_gt: Skip if preconditioning dim size is
        greater than this value.
    clip_by_scaled_gradient_norm: Clip by scaled gradient norm (only useful
      when using RMSProp Grafting).
    precision: precision XLA related flag, the available options are: a)
      lax.Precision.DEFAULT (better step time, but not precise) b)
      lax.Precision.HIGH (increased precision, slower) c) lax.Precision.HIGHEST
      (best possible precision, slowest)

  Returns:
    a GradientTransformation.
  """

  def sharded_init_fn(params):
    params_flat, treedef = jax.tree_flatten(params)
    # Find max size to pad to.
    max_size = 0
    for param in params_flat:
      preconditioner = Preconditioner(param, block_size,
                                      best_effort_shape_interpretation)
      if not _skip_preconditioning(param):
        shapes = preconditioner.shapes_for_preconditioners()
        sizes = [s[0] for s in shapes]
        max_size = max(max(sizes), max_size)

    padded_statistics = []
    padded_preconditioners = []
    local_stats_flat = []
    for param in params_flat:
      preconditioner = Preconditioner(param, block_size,
                                      best_effort_shape_interpretation)
      shapes = preconditioner.shapes_for_preconditioners()
      sizes = []

      statistics = []
      preconditioners = []
      index_start = len(padded_statistics)
      if not _skip_preconditioning(param):
        sizes = [s[0] for s in shapes]
        shapes = preconditioner.shapes_for_preconditioners()
        statistics = [matrix_epsilon * jnp.eye(max_size) for s in shapes]
        preconditioners = [jnp.eye(max_size) for s in shapes]
        padded_statistics.extend(statistics)
        padded_preconditioners.extend(preconditioners)

      adagrad_statistics = []
      if graft_type != GraftingType.SGD:
        adagrad_statistics = jnp.zeros_like(param)
      local_stats_flat.append(
          LocalShardedParameterStats(adagrad_statistics, jnp.zeros_like(param),
                                     jnp.zeros_like(param), index_start, sizes))

    local_stats = jax.tree_unflatten(treedef, local_stats_flat)
    # Pad the statistics and preconditioner matrices to be a multiple of
    # num devices.
    # TODO(rohananil): Relax to only the size of the mesh axis where the dim
    # is split on.
    to_pad = -len(padded_statistics) % num_devices_for_pjit
    padded_statistics.extend([
        jnp.eye(max_size, dtype=padded_statistics[0].dtype)
        for _ in range(to_pad)
    ])
    padded_preconditioners.extend([
        jnp.eye(max_size, dtype=padded_statistics[0].dtype)
        for _ in range(to_pad)
    ])
    global_stats = GlobalShardedParameterStats(
        jnp.stack(padded_statistics), jnp.stack(padded_preconditioners))
    return ShampooState(
        count=jnp.zeros([], jnp.int32),
        stats=ShardedShampooStats(global_stats, local_stats))

  def sharded_update_fn(grads, state, params):
    """Transform the input gradient and update all statistics in sharded mode.

    Args:
      grads: the gradient tensors for the parameters.
      state: a named tuple containing the state of the optimizer
      params: the parameters that should be updated.

    Returns:
      A tuple containing the transformed updates and the new optimizer state.
    """
    params_flat, treedef = jax.tree_flatten(params)
    grads_flat = treedef.flatten_up_to(grads)

    global_stats = state.stats.global_stats
    local_stats_flat = treedef.flatten_up_to(state.stats.local_stats)
    stats_flat = [
        _convert_to_parameter_stats(global_stats, local_stat)
        for local_stat in local_stats_flat
    ]
    new_stats_flat = jax.tree_multimap(
        lambda g, s, p: _compute_stats(g, s, p, state.count), grads_flat,
        stats_flat, params_flat)

    exponents = []
    for stat, param in zip(new_stats_flat, params_flat):
      num_statistics = len(stat.statistics)
      if num_statistics > 0:
        preconditioner = Preconditioner(param, block_size,
                                        best_effort_shape_interpretation)
        exponent = (
            preconditioner.exponent_for_preconditioner()
            if exponent_override == 0 else exponent_override)
        exponents.extend([exponent] * num_statistics)

    outputs = jax.tree_multimap(
        lambda g, s, p: _transform_grad(g, s, p, state.count), grads_flat,
        new_stats_flat, params_flat)
    updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ())

    updates = jax.tree_unflatten(treedef, updates_flat)
    # Create new local_stats
    new_local_stats_flat = [
        _convert_from_parameter_stats(new_stat, local_stat)
        for new_stat, local_stat in zip(new_stats_flat, local_stats_flat)
    ]
    new_local_stats = jax.tree_unflatten(treedef, new_local_stats_flat)

    max_size = global_stats.statistics.shape[1]
    new_padded_statistics = []
    for stat in new_stats_flat:
      new_padded_statistics.extend(
          [pad_matrix(stat, max_size) for stat in stat.statistics])

    # Create global stats
    # TODO(rohananil): Preconditioner is not updated every step, so cost of
    # stack/pad can be obviated away.
    # Pad the statistics and preconditioner matrices to be a multiple of
    # num devices.
    # TODO(rohananil): Relax to only the size of the mesh axis where the dim
    # is split on.
    to_pad = -len(new_padded_statistics) % num_devices_for_pjit
    new_padded_statistics.extend([
        jnp.eye(max_size, dtype=new_padded_statistics[0].dtype)
        for _ in range(to_pad)
    ])
    exponents.extend([1 for _ in range(to_pad)])

    def _matrix_inverse_pth_root_vmap(xs, ps):
      mi_pth_root = functools.partial(
          matrix_inverse_pth_root,
          ridge_epsilon=matrix_epsilon,
          precision=precision)
      preconditioners, errors = jax.vmap(mi_pth_root)(xs, ps)
      return preconditioners, errors

    def _internal_inverse_pth_root_all():
      preconditioners, errors = _matrix_inverse_pth_root_vmap(
          global_stats.statistics, jnp.stack(exponents))
      return preconditioners, errors

    if preconditioning_compute_steps == 1:
      new_preconditioners, errors = _internal_inverse_pth_root_all()
    else:
      # Passing statistics instead of preconditioners as they are similarly
      # shaped tensors. Note statistics will be ignored as we are passing in
      # a large init value for error.
      preconditioners_init = global_stats.statistics
      errors_init = np.stack([inverse_failure_threshold] * len(exponents))
      init_state = [preconditioners_init, errors_init]
      perform_step = state.count % preconditioning_compute_steps == 0
      new_preconditioners, errors = efficient_cond(
          perform_step, _internal_inverse_pth_root_all, init_state)

    errors = errors.reshape((-1, 1, 1))
    predicate = jnp.logical_or(
        jnp.isnan(errors),
        errors >= inverse_failure_threshold).astype(new_preconditioners.dtype)
    # TODO(rohananil): Check for numerical instabilities.
    new_conditional_preconditioners = (
        predicate * global_stats.preconditioners +
        (1.0 - predicate) * new_preconditioners)
    new_global_stats = GlobalShardedParameterStats(
        jnp.stack(new_padded_statistics), new_conditional_preconditioners)
    new_shampoo_state = ShampooState(
        count=state.count + 1,
        stats=ShardedShampooStats(new_global_stats, new_local_stats))
    return updates, new_shampoo_state

  def init_fn(params):
    """Initialise the optimiser's state."""

    def _init(param):
      preconditioner = Preconditioner(param, block_size,
                                      best_effort_shape_interpretation)
      statistics = []
      preconditioners = []
      if not _skip_preconditioning(param):
        shapes = preconditioner.shapes_for_preconditioners()
        statistics = [matrix_epsilon * jnp.eye(s[0]) for s in shapes]
        preconditioners = [jnp.eye(s[0]) for s in shapes]

      adagrad_statistics = []
      if graft_type != GraftingType.SGD:
        adagrad_statistics = jnp.zeros_like(param)
      return ParameterStats(adagrad_statistics, statistics, preconditioners,
                            jnp.zeros_like(param), jnp.zeros_like(param))

    return ShampooState(
        count=jnp.zeros([], jnp.int32), stats=jax.tree_map(_init, params))

  def _skip_preconditioning(param):
    return len(param.shape) < 1 or any(
        [s > skip_preconditioning_dim_size_gt for s in param.shape])

  def _compute_stats(grad, state, param, step):
    """Compute per-parameter statistics."""
    preconditioner = Preconditioner(param, block_size,
                                    best_effort_shape_interpretation)
    new_statistics = [[]] * len(state.statistics)
    w1 = beta2
    w2 = beta2 if beta2 == 1.0 else (1.0 - beta2)
    if not _skip_preconditioning(param):

      def compute_updated_statistics():
        new_stats = preconditioner.statistics_from_grad(grad)
        new_stats_accumulators = []
        for stat, stat_accumulator in zip(new_stats, state.statistics):
          new_stats_accumulators.append(w1 * stat_accumulator + w2 * stat)
        return new_stats_accumulators

      if statistics_compute_steps > 1:
        perform_step = step % statistics_compute_steps == 0
        init_state = state.statistics
        new_statistics = list(
            efficient_cond(perform_step, compute_updated_statistics,
                           init_state))
      else:
        new_statistics = compute_updated_statistics()
    return ParameterStats(state.diagonal_statistics, new_statistics,
                          state.preconditioners, state.diagonal_momentum,
                          state.momentum)

  def _compute_preconditioners(states, params, step):
    """Compute preconditioners for statistics."""
    statistics = []
    num_statistics_per_state = []
    original_shapes = []
    exponents = []
    max_size = 0
    prev_preconditioners = []
    for state, param in zip(states, params):
      num_statistics = len(state.statistics)
      num_statistics_per_state.append(num_statistics)
      original_shapes_for_state = []
      if num_statistics > 0:
        preconditioner = Preconditioner(param, block_size,
                                        best_effort_shape_interpretation)
        for statistic in state.statistics:
          exponents.append(preconditioner.exponent_for_preconditioner(
          ) if exponent_override == 0 else exponent_override)
          original_shapes_for_state.append(statistic.shape)
          max_size = max(max_size, statistic.shape[0])
        statistics.extend(state.statistics)
        prev_preconditioners.extend(state.preconditioners)
        original_shapes.extend(original_shapes_for_state)
    num_statistics = len(statistics)

    if batch_axis_name:
      num_devices = lax.psum(1, batch_axis_name)

      # Pad statistics and exponents to next multiple of num_devices.
      packed_statistics = [pad_matrix(stat, max_size) for stat in statistics]
      to_pad = -num_statistics % num_devices
      packed_statistics.extend([
          jnp.eye(max_size, dtype=packed_statistics[0].dtype)
          for _ in range(to_pad)
      ])
      exponents.extend([1 for _ in range(to_pad)])

      if not packed_statistics:
        return states
      # Batch statistics and exponents so that the leading axis is
      # num_devices.
      def _batch(statistics, exponents, num_devices):
        assert len(statistics) == len(exponents)
        n = len(statistics)
        b = int(n / num_devices)
        batched_statistics = [
            jnp.stack(statistics[idx:idx + b]) for idx in range(0, n, b)
        ]
        batched_exponents = [
            jnp.stack(exponents[idx:idx + b]) for idx in range(0, n, b)
        ]
        return jnp.stack(batched_statistics), jnp.stack(batched_exponents)

      # Unbatch values across leading axis and return a list of elements.
      def _unbatch(batched_values):
        b1, b2 = batched_values.shape[0], batched_values.shape[1]
        results = []
        for v_array in jnp.split(
            batched_values, indices_or_sections=b1, axis=0):
          v_array = jnp.squeeze(v_array)
          # b2 = batches (number of preconditioner computation) per core.
          if b2 > 1:
            for v in jnp.split(v_array, indices_or_sections=b2, axis=0):
              results.append(jnp.squeeze(v))
          else:
            results.append(v_array)
        return results

      all_statistics, all_exponents = _batch(packed_statistics, exponents,
                                             num_devices)
    else:
      to_pad = -num_statistics % num_devices_for_pjit
      padded_statistics = [pad_matrix(stat, max_size) for stat in statistics]
      padded_statistics.extend([
          jnp.eye(max_size, dtype=padded_statistics[0].dtype)
          for _ in range(to_pad)
      ])
      exponents.extend([1 for _ in range(to_pad)])
      all_statistics = jnp.stack(padded_statistics)
      all_exponents = jnp.stack(exponents)

    def _matrix_inverse_pth_root_vmap(xs, ps):
      mi_pth_root = functools.partial(
          matrix_inverse_pth_root,
          ridge_epsilon=matrix_epsilon,
          precision=precision)
      preconditioners, errors = jax.vmap(mi_pth_root)(xs, ps)
      return preconditioners, errors

    def _matrix_inverse_pth_root_pjit(xs, ps):
      mesh_axis_names_tuple = tuple(mesh_axis_names)
      # Partition the concatenated statistics matrix across all cores.
      partitioned_xs, partitioned_ps = pjit.pjit(
          lambda x, y: (x, y),
          in_axis_resources=None,
          out_axis_resources=pjit.PartitionSpec(mesh_axis_names_tuple,))(xs, ps)
      # Run matrix inverse pth root on each shard.
      partitioned_preconditioners, partitioned_errors = _matrix_inverse_pth_root_vmap(
          partitioned_xs, partitioned_ps)
      # Recombine the outputs at each core.
      preconditioners, errors = pjit.pjit(
          lambda x, y: (x, y),
          in_axis_resources=(pjit.PartitionSpec(mesh_axis_names_tuple,),
                             pjit.PartitionSpec(mesh_axis_names_tuple,)),
          out_axis_resources=(None, None))(partitioned_preconditioners,
                                           partitioned_errors)
      return preconditioners, errors

    if not batch_axis_name:
      def _internal_inverse_pth_root_all():
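        """Compute inverse pth roots for all padded statistics via pjit."""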
        preconditioners, errors = _matrix_inverse_pth_root_pjit(
            all_statistics, all_exponents)
        b1 = preconditioners.shape[0]
        def split(batched_values):
          return [
              jnp.squeeze(v) for v in jnp.split(
                  batched_values, indices_or_sections=b1, axis=0)
          ]

        return split(preconditioners), split(errors)

      if preconditioning_compute_steps == 1:
        preconditioners_flat, errors_flat = _internal_inverse_pth_root_all()
      else:
        # Pass statistics instead of preconditioners, as they are similarly
        # shaped tensors. Note that the statistics will be ignored since we
        # pass in a large init value for the error.
        preconditioners_init = padded_statistics
        errors_init = [inverse_failure_threshold] * len(padded_statistics)
        init_state = [preconditioners_init, errors_init]
        perform_step = step % preconditioning_compute_steps == 0
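        # efficient_cond only runs the expensive inverse pth root computation
        # on preconditioning steps and falls back to init_state otherwise.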
        preconditioners_flat, errors_flat = efficient_cond(
            perform_step, _internal_inverse_pth_root_all, init_state)
    else:

      def _internal_inverse_pth_root_all():
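        """Compute this replica's shard of inverse pth roots and all-gather."""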
        current_replica = lax.axis_index(batch_axis_name)
        preconditioners, errors = _matrix_inverse_pth_root_vmap(
            all_statistics[current_replica], all_exponents[current_replica])
        preconditioners = jax.lax.all_gather(preconditioners, batch_axis_name)
        errors = jax.lax.all_gather(errors, batch_axis_name)
        preconditioners_flat = _unbatch(preconditioners)
        errors_flat = _unbatch(errors)
        return preconditioners_flat, errors_flat

      if preconditioning_compute_steps == 1:
        preconditioners_flat, errors_flat = _internal_inverse_pth_root_all()
      else:
        # Pass statistics instead of preconditioners, as they are similarly
        # shaped tensors. Note that the statistics will be ignored since we
        # pass in a large init value for the error.
        preconditioners_init = packed_statistics
        errors_init = [inverse_failure_threshold] * len(packed_statistics)
        init_state = [preconditioners_init, errors_init]
        perform_step = step % preconditioning_compute_steps == 0
        preconditioners_flat, errors_flat = efficient_cond(
            perform_step, _internal_inverse_pth_root_all, init_state)

    def _skip(error):
      condition = jnp.logical_or(
          jnp.isnan(error), error >= inverse_failure_threshold)
      return condition.astype(error.dtype)

    def _select_preconditioner(error, new_p, old_p):
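      # Keep the previous preconditioner whenever the inverse pth root failed
      # (NaN error or error above the failure threshold).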
      return lax.cond(
          _skip(error), lambda _: old_p, lambda _: new_p, operand=None)

    new_preconditioners_flat = []
    for p, shape, prev_p, error in zip(preconditioners_flat, original_shapes,
                                       prev_preconditioners, errors_flat):
      new_preconditioners_flat.append(
          _select_preconditioner(error, p[:shape[0], :shape[1]], prev_p))

    assert len(states) == len(num_statistics_per_state)
    assert len(new_preconditioners_flat) == num_statistics

    # Add back empty preconditioners so that we can set the optimizer state.
    preconditioners_for_states = []
    idx = 0
    for n_stats, state in zip(num_statistics_per_state, states):
      if n_stats == 0:
        preconditioners_for_states.append([])
      else:
        preconditioners_for_state = new_preconditioners_flat[idx:idx + n_stats]
        assert len(state.statistics) == len(preconditioners_for_state)
        preconditioners_for_states.append(preconditioners_for_state)
        idx += n_stats
    new_states = []
    for state, new_preconditioners in zip(states, preconditioners_for_states):
      new_states.append(
          ParameterStats(state.diagonal_statistics, state.statistics,
                         new_preconditioners, state.diagonal_momentum,
                         state.momentum))

    return new_states

  def _transform_grad(grad, state, param, step):
    """Transform per-parameter gradients."""
    preconditioner = Preconditioner(param, block_size,
                                    best_effort_shape_interpretation)
    sgd_update = grad
    new_diagonal_statistics = state.diagonal_statistics
    if graft_type == GraftingType.ADAGRAD:
      new_diagonal_statistics = state.diagonal_statistics + jnp.square(grad)
      adagrad_update = grad / (
          jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon)
      grafting_update = adagrad_update
    elif (graft_type == GraftingType.RMSPROP or
          graft_type == GraftingType.RMSPROP_NORMALIZED):

      scaled_grad = grad
      if graft_type == GraftingType.RMSPROP_NORMALIZED:
        scaled_grad = grad / jnp.linalg.norm(grad)

      w1 = beta2
      w2 = beta2 if beta2 == 1.0 else (1.0 - beta2)
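      # beta2 == 1.0 accumulates a plain sum of squared gradients (AdaGrad
      # style); otherwise this is an exponential moving average.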

      new_diagonal_statistics = (
          w1 * state.diagonal_statistics + w2 * jnp.square(scaled_grad))
      rmsprop_update = scaled_grad / (
          jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon)

      if clip_by_scaled_gradient_norm:
        scaled_grad_norm = jnp.linalg.norm(rmsprop_update) / (
            jnp.sqrt(float(rmsprop_update.size)))
        clipping_denom = jnp.maximum(
            1., scaled_grad_norm / clip_by_scaled_gradient_norm)
        rmsprop_update /= clipping_denom

      grafting_update = rmsprop_update
    else:
      grafting_update = sgd_update

    precond_grad = grad
    if not _skip_preconditioning(param):
      precond_grad = preconditioner.preconditioned_grad(precond_grad,
                                                        state.preconditioners)
    else:
      precond_grad = grafting_update

    grafting_update_norm = jnp.linalg.norm(grafting_update)
    precond_grad_norm = jnp.linalg.norm(precond_grad)

    multiplier = (grafting_update_norm / (precond_grad_norm + 1e-16))
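    # Grafting: keep the direction of the preconditioned update but rescale
    # it to the norm of the first-order grafting update.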
    shampoo_update = precond_grad * multiplier

    shampoo_update_with_wd = shampoo_update
    grafting_update_with_wd = grafting_update
    if weight_decay != 0:
      shampoo_update_with_wd = shampoo_update + weight_decay * param
      grafting_update_with_wd = grafting_update + weight_decay * param

    w = (1.0 - beta1) if moving_average_for_momentum else 1.0
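    # With moving_average_for_momentum, the momentum buffers are exponential
    # moving averages of the updates rather than plain accumulations.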
    shampoo_update_with_wd_momentum = (
        state.momentum * beta1 + w * shampoo_update_with_wd)
    grafting_update_with_wd_momentum = (
        state.diagonal_momentum * beta1 + w * grafting_update_with_wd)

    run_shampoo = (step >= start_preconditioning_step).astype(
        grafting_update_with_wd_momentum.dtype)
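    # Use the grafting update until start_preconditioning_step is reached,
    # then switch to the Shampoo update.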

    momentum_update = (
        run_shampoo * shampoo_update_with_wd_momentum +
        (1.0 - run_shampoo) * grafting_update_with_wd_momentum)

    wd_update = (
        run_shampoo * shampoo_update_with_wd +
        (1.0 - run_shampoo) * grafting_update_with_wd)

    if nesterov:
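      # Nesterov-style momentum: combine the current weight-decayed update
      # with the momentum-accumulated update.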
      momentum_update = w * wd_update + beta1 * momentum_update

    lr = learning_rate
    if callable(learning_rate):
      lr = learning_rate(step)
    transformed_update = -1.0 * lr * momentum_update

    param_stats = ParameterStats(new_diagonal_statistics, state.statistics,
                                 state.preconditioners,
                                 grafting_update_with_wd_momentum,
                                 shampoo_update_with_wd_momentum)
    return transformed_update, param_stats

  def update_fn(grads, state, params):
    """Transform the input gradient and update all statistics.

    Args:
      grads: the gradient tensors for the parameters.
      state: a named tuple containing the state of the optimizer.
      params: the parameters that should be updated.

    Returns:
      A tuple containing the transformed updates and the new optimizer state.
    """
    params_flat, treedef = jax.tree_flatten(params)
    stats_flat = treedef.flatten_up_to(state.stats)
    grads_flat = treedef.flatten_up_to(grads)

    new_stats_flat = jax.tree_multimap(
        lambda g, s, p: _compute_stats(g, s, p, state.count), grads_flat,
        stats_flat, params_flat)
    new_stats_flat = _compute_preconditioners(new_stats_flat, params_flat,
                                              state.count)
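    # Note: the inverse pth roots inside _compute_preconditioners are only
    # recomputed on preconditioning steps.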

    outputs = jax.tree_multimap(
        lambda g, s, p: _transform_grad(g, s, p, state.count), grads_flat,
        new_stats_flat, params_flat)
    updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ())

    updates = jax.tree_unflatten(treedef, updates_flat)
    new_stats = jax.tree_unflatten(treedef, new_stats_flat)

    new_state = ShampooState(count=state.count + 1, stats=new_stats)
    return updates, new_state

  if shard_optimizer_states:
    return optax.GradientTransformation(sharded_init_fn, sharded_update_fn)
  else:
    return optax.GradientTransformation(init_fn, update_fn)