Example #1
    def test_keep_params_nonnegative(self):
        grads = (jnp.array([500., -500., 0.]),
                 jnp.array([500., -500., 0.]),
                 jnp.array([500., -500., 0.]))

        params = (jnp.array([-1., -1., -1.]),
                  jnp.array([1., 1., 1.]),
                  jnp.array([0., 0., 0.]))

        # vanilla sgd
        opt = combine.chain(transform.trace(decay=0, nesterov=False),
                            transform.scale(-LR))
        opt_state = opt.init(params)

        updates, _ = opt.update(grads, opt_state, params)
        new_params = update.apply_updates(params, updates)

        chex.assert_tree_all_close(
            new_params,
            (jnp.array([-6., 4., -1.]),
             jnp.array([-4., 6., 1.]),
             jnp.array([-5., 5., 0.])))

        # sgd with keeping parameters non-negative
        opt = combine.chain(transform.trace(decay=0, nesterov=False),
                            transform.scale(-LR),
                            constrain.keep_params_nonnegative())
        opt_state = opt.init(params)

        updates, _ = opt.update(grads, opt_state, params)
        new_params = update.apply_updates(params, updates)

        chex.assert_tree_all_close(
            new_params,
            (jnp.array([0., 4., 0.]),
             jnp.array([0., 6., 1.]),
             jnp.array([0., 5., 0.])))
Example #2
def rmsprop(
    learning_rate: ScalarOrSchedule,
    decay: float = 0.9,
    eps: float = 1e-8,
    initial_scale: float = 0.,
    centered: bool = False,
    momentum: Optional[float] = None,
    nesterov: bool = False
) -> base.GradientTransformation:
  # pylint: disable=line-too-long
  """A flexible RMSProp optimiser.

  RMSProp is an SGD variant with learning rate adaptation. The `learning_rate`
  used for each weight is scaled by a suitable estimate of the magnitude of the
  gradients on previous steps. Several variants of RMSProp can be found
  in the literature. This alias provides an easy-to-configure RMSProp
  optimiser that can be used to switch between several of these variants.

  References:
    Tieleman and Hinton, 2012: http://www.cs.toronto.edu/~hinton/coursera/lecture6/lec6.pdf
    Graves, 2013: https://arxiv.org/abs/1308.0850

  Args:
    learning_rate: this is a fixed global scaling factor.
    decay: the decay used to track the magnitude of previous gradients.
    eps: a small numerical constant to avoid dividing by zero when rescaling.
    initial_scale: (default `0.`), initialisation of accumulators tracking the
      magnitude of previous updates. PyTorch uses `0`, TF1 uses `1`. When
      reproducing results from a paper, verify the value used by the authors.
    centered: (default `False`), whether the second moment or the variance of
      the past gradients is used to rescale the latest gradients.
    momentum: (default `None`), the `decay` rate used by the momentum term.
      When set to `None`, momentum is not used at all.
    nesterov: (default `False`), whether Nesterov momentum is used.

  Returns:
    the corresponding `GradientTransformation`.
  """
  # pylint: enable=line-too-long
  if centered:
    return combine.chain(
        transform.scale_by_stddev(
            decay=decay, eps=eps, initial_scale=initial_scale),
        _scale_by_learning_rate(learning_rate),
        (transform.trace(decay=momentum, nesterov=nesterov)
         if momentum is not None else base.identity())
    )
  return combine.chain(
      transform.scale_by_rms(
          decay=decay, eps=eps, initial_scale=initial_scale),
      _scale_by_learning_rate(learning_rate),
      (transform.trace(decay=momentum, nesterov=nesterov)
       if momentum is not None else base.identity())
  )
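
Because `momentum` and `centered` only decide which transformations enter the chain, this one alias covers several RMSProp flavours. A brief sketch, assuming the alias is exported as `optax.rmsprop` with the keyword arguments shown above:

import optax

# Plain RMSProp (Tieleman & Hinton, 2012): second-moment scaling only.
plain = optax.rmsprop(learning_rate=1e-3)

# Centered variant (Graves, 2013): rescales by an estimate of the variance.
centered = optax.rmsprop(learning_rate=1e-3, centered=True)

# Heavy-ball momentum applied after the learning-rate scaling.
with_momentum = optax.rmsprop(learning_rate=1e-3, momentum=0.9)

Each of these is an ordinary `GradientTransformation` with the usual `init`/`update` pair.
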
Example #3
def rmsprop(learning_rate: float,
            decay: float = 0.9,
            eps: float = 1e-8,
            centered: bool = False) -> GradientTransformation:
    if centered:
        return combine.chain(
            transform.scale_by_stddev(decay=decay, eps=eps),
            transform.scale(-learning_rate),
        )
    return combine.chain(
        transform.scale_by_rms(decay=decay, eps=eps),
        transform.scale(-learning_rate),
    )
Example #4
def adabelief(
    learning_rate: ScalarOrSchedule,
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-16,
    eps_root: float = 1e-16) -> base.GradientTransformation:
  """The AdaBelief optimiser.

  AdaBelief is an adaptive learning rate optimiser that focuses on fast
  convergence, generalisation, and stability. It adapts the step size depending
  on its "belief" in the gradient direction — the optimiser adaptively scales
  the step size by the difference between the predicted and observed gradients.
  AdaBelief is a modified version of Adam and contains the same number of
  parameters.

  References:
    Zhuang et al, 2020: https://arxiv.org/abs/2010.07468

  Args:
    learning_rate: this is a fixed global scaling factor.
    b1: the exponential decay rate to track the first moment of past gradients.
    b2: the exponential decay rate to track the second moment of past gradients.
    eps: term added to the denominator to improve numerical stability.
    eps_root: term added to the second moment of the prediction error to
      improve numerical stability. If backpropagating gradients through the
      gradient transformation (e.g. for meta-learning), this must be non-zero.

  Returns:
    the corresponding `GradientTransformation`.
  """
  return combine.chain(
      transform.scale_by_belief(b1=b1, b2=b2, eps=eps, eps_root=eps_root),
      _scale_by_learning_rate(learning_rate),
  )
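
A short end-to-end sketch on a toy quadratic, assuming the alias is exported as `optax.adabelief` and using `jax.grad` for the gradients:

import jax
import jax.numpy as jnp
import optax

def loss(params):
  return jnp.sum((params - 3.0) ** 2)

params = jnp.zeros(4)
opt = optax.adabelief(learning_rate=1e-1)
state = opt.init(params)

for _ in range(100):
  grads = jax.grad(loss)(params)
  updates, state = opt.update(grads, state)
  params = optax.apply_updates(params, updates)
# params should now be close to 3.0 in every coordinate.
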
Example #5
def yogi(learning_rate: ScalarOrSchedule,
         b1: float = 0.9,
         b2: float = 0.999,
         eps: float = 1e-3) -> base.GradientTransformation:
    """The Yogi optimiser.

  Yogi is an adaptive optimiser, which provides control in tuning the effective
  learning rate to prevent it from increasing. By doing so, it focuses on
  addressing the issues of convergence and generalisation in exponential moving
  average-based adaptive methods (such as Adam and RMSprop). Yogi is a
  modification of Adam and uses the same parameters.

  References:
    Zaheer et al, 2018: http://www.sanjivk.com/yogi_nips2018.pdf

  Args:
    learning_rate: this is a fixed global scaling factor.
    b1: the exponential decay rate to track the first moment of past gradients.
    b2: the exponential decay rate to track the second moment of past gradients.
    eps: a small constant applied to denominator outside of the square root
      (as in the Adam paper) to avoid dividing by zero when rescaling.

  Returns:
    the corresponding `GradientTransformation`.
  """
    return combine.chain(
        transform.scale_by_yogi(b1=b1, b2=b2, eps=eps),
        _scale_by_learning_rate(learning_rate),
    )
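
The `learning_rate` here is a `ScalarOrSchedule`, so a schedule (a callable from step count to scale) can be passed instead of a constant. A sketch, assuming `optax.exponential_decay` and `optax.yogi` are the exported names:

import optax

# Decay the learning rate by a factor of 10 every 1000 update steps.
schedule = optax.exponential_decay(
    init_value=1e-2, transition_steps=1000, decay_rate=0.1)

opt = optax.yogi(learning_rate=schedule)  # same init/update interface as before
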
Example #6
def fromage(learning_rate: float,
            min_norm: float = 1e-6) -> base.GradientTransformation:
    """The Frobenius matched gradient descent (Fromage) optimiser.

  Fromage is a learning algorithm that does not require learning rate tuning.
  The optimiser is based on modelling neural network gradients via deep relative
  trust (a distance function on deep neural networks). Fromage is similar to the
  LARS optimiser and can work on a range of standard neural network benchmarks,
  such as natural language Transformers and generative adversarial networks.

  References:
    Bernstein et al, 2020: https://arxiv.org/abs/2002.03432

  Args:
    learning_rate: this is a fixed global scaling factor.
    min_norm: a minimum value to which the norm of the gradient updates and
      the norm of the layer parameters are clipped, to avoid dividing by zero
      when computing the trust ratio (as in the LARS paper).

  Returns:
    the corresponding `GradientTransformation`.
  """
    mult = 1 / jnp.sqrt(1 + learning_rate**2)
    return combine.chain(
        transform.scale_by_trust_ratio(min_norm),
        _scale_by_learning_rate(learning_rate * mult),
        transform.add_decayed_weights((mult - 1)),
    )
Example #7
def noisy_sgd(learning_rate: ScalarOrSchedule,
              eta: float = 0.01,
              gamma: float = 0.55,
              seed: int = 0) -> base.GradientTransformation:
    r"""A variant of SGD with added noise.

  It has been found that adding noise to the gradients can improve
  both the training error and the generalisation error in very deep networks.

  References:
    [Neelakantan et al, 2014](https://arxiv.org/abs/1511.06807)

  Args:
    learning_rate: this is a fixed global scaling factor.
    eta: the initial variance for the gaussian noise added to gradients.
    gamma: a parameter controlling the annealing of noise over time,
      the variance decays according to `(1+t)^-\gamma`.
    seed: the seed for the pseudo-random generation process.

  Returns:
    the corresponding `GradientTransformation`.
  """
    return combine.chain(
        _scale_by_learning_rate(learning_rate),
        transform.add_noise(eta, gamma, seed),
    )
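
Per the docstring, the gaussian noise added at step t has variance eta / (1 + t)**gamma. A small stand-alone sketch of that schedule (plain Python, no optax calls), just to make the annealing concrete:

eta, gamma = 0.01, 0.55

def noise_variance(step):
  # Variance of the gaussian noise added to the gradients at `step`.
  return eta / (1 + step) ** gamma

print([round(noise_variance(t), 5) for t in (0, 10, 100, 1000)])
# Roughly [0.01, 0.00267, 0.00079, 0.00022]: the noise fades as training proceeds.
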
Example #8
def sgd(learning_rate: float,
        momentum: float = 0.,
        nesterov: bool = False) -> GradientTransformation:
    return combine.chain(
        transform.trace(decay=momentum, nesterov=nesterov),
        transform.scale(-learning_rate),
    )
Example #9
def adabelief(learning_rate: ScalarOrSchedule,
              b1: float = 0.9,
              b2: float = 0.999,
              eps: float = 1e-8) -> base.GradientTransformation:
    """The AdaBelief optimiser.

  AdaBelief is an adaptive learning rate optimiser that focuses on fast
  convergence, generalisation, and stability. It adapts the step size depending
  on its "belief" in the gradient direction — the optimiser adaptively scales
  the step size by the difference between the predicted and observed gradients.
  AdaBelief is a modified version of Adam and contains the same number of
  parameters.

  References:
    [Zhuang et al, 2020](https://arxiv.org/abs/2010.07468)

  Args:
    learning_rate: this is a fixed global scaling factor.
    b1: the exponential decay rate to track the first moment of past gradients.
    b2: the exponential decay rate to track the second moment of past gradients.
    eps: a small constant applied to denominator outside of the square root
      (as in the Adam paper) to avoid dividing by zero when rescaling.

  Returns:
    the corresponding `GradientTransformation`.
  """
    return combine.chain(
        transform.scale_by_belief(b1=b1, b2=b2, eps=eps),
        _scale_by_learning_rate(learning_rate),
    )
Example #10
def sgd(learning_rate: ScalarOrSchedule,
        momentum: float = 0.,
        nesterov: bool = False) -> GradientTransformation:
    return combine.chain(
        transform.trace(decay=momentum, nesterov=nesterov),
        _scale_by_learning_rate(learning_rate),
    )
Example #11
def sm3(learning_rate: float,
        momentum: float = 0.9) -> base.GradientTransformation:
    """The SM3 optimiser.

  SM3 (Square-root of Minima of Sums of Maxima of Squared-gradients Method) is a
  memory-efficient adaptive optimiser designed to decrease memory overhead when
  training very large models, such as the Transformer for machine translation,
  BERT for language modelling, and AmoebaNet-D for image classification. SM3: 1)
  applies to tensors of arbitrary dimensions and any predefined cover of the
  parameters; 2) adapts the learning rates in an adaptive and data-driven manner
  (like Adagrad and unlike Adafactor); and 3) comes with rigorous convergence
  guarantees in stochastic convex optimization settings.

  References:
    Anil et al, 2019: https://arxiv.org/abs/1901.11150

  Args:
    learning_rate: this is a fixed global scaling factor.
    momentum: the `decay` rate used by the momentum term.

  Returns:
    the corresponding `GradientTransformation`.
  """
    return combine.chain(
        transform.scale_by_sm3(momentum),
        transform.scale(-learning_rate),
    )
Example #12
  def test_chain(self):
    transformations = [
        transform.scale_by_adam(),
        transform.trace(decay=0, nesterov=False),
        transform.scale(-LR)]

    # Apply updates with chain.
    chain_params = self.init_params
    chained_transforms = combine.chain(*transformations)
    state = chained_transforms.init(chain_params)

    @self.variant
    def update_fn(updates, state):
      return chained_transforms.update(updates, state)

    for _ in range(STEPS):
      updates, state = update_fn(self.per_step_updates, state)
      chain_params = update.apply_updates(chain_params, updates)

    # Manually apply sequence of transformations.
    manual_params = self.init_params
    states = [t.init(manual_params) for t in transformations]
    for _ in range(STEPS):
      updates = self.per_step_updates
      new_states = []
      for t, s in zip(transformations, states):
        updates, state = t.update(updates, s)
        new_states.append(state)
      manual_params = update.apply_updates(manual_params, updates)
      states = new_states

    # Check equivalence.
    chex.assert_tree_all_close(manual_params, chain_params, rtol=1e-4)
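
The same composition applies to any sequence of gradient transformations, not just the three used in the test. For example, gradient clipping can be prepended to an Adam-style chain; a sketch, assuming the public names `optax.chain`, `optax.clip_by_global_norm`, `optax.scale_by_adam` and `optax.scale`:

import optax

# Clip first, then rescale with Adam statistics, then apply the step size.
opt = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8),
    optax.scale(-1e-3),
)
# `opt` is itself a GradientTransformation: use opt.init(params) and opt.update(...).
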
Example #13
def adagrad(learning_rate: ScalarOrSchedule,
            initial_accumulator_value: float = 0.1,
            eps: float = 1e-7) -> base.GradientTransformation:
    """The Adagrad optimizer.

  Adagrad is an algorithm for gradient-based optimisation that anneals the
  learning rate for each parameter during the course of training.

  WARNING: Adagrad's main limit is the monotonic accumulation of squared
  gradients in the denominator: since all terms are >0, the sum keeps growing
  during training and the learning rate eventually becomes vanishingly small.

  References:
    [Duchi et al, 2011](https://jmlr.org/papers/v12/duchi11a.html)

  Args:
    learning_rate: this is a fixed global scaling factor.
    initial_accumulator_value: initialisation for the accumulator.
    eps: a small constant applied to denominator inside of the square root
      (as in RMSProp) to avoid dividing by zero when rescaling.

  Returns:
    the corresponding `GradientTransformation`.
  """
    return combine.chain(
        transform.scale_by_rss(
            initial_accumulator_value=initial_accumulator_value, eps=eps),
        _scale_by_learning_rate(learning_rate),
    )
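
The WARNING can be made concrete: with a constant gradient the accumulator grows linearly, so the emitted update shrinks roughly like 1/sqrt(step). A small sketch, assuming the alias is exported as `optax.adagrad`:

import jax.numpy as jnp
import optax

params = jnp.array([1.0])
grads = jnp.array([1.0])   # constant gradient

opt = optax.adagrad(learning_rate=1.0)
state = opt.init(params)

for step in range(5):
  updates, state = opt.update(grads, state)
  print(step, float(updates[0]))  # the update magnitude keeps shrinking
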
Example #14
def sgd(
    learning_rate: ScalarOrSchedule,
    momentum: Optional[float] = None,
    nesterov: bool = False
) -> base.GradientTransformation:
  """A canonical Stochastic Gradient Descent optimiser.

  This implements stochastic gradient descent. It also includes support for
  momentum and Nesterov acceleration, as these are standard practice when
  using stochastic gradient descent to train deep neural networks.

  References:
    Sutskever et al, 2013: http://proceedings.mlr.press/v28/sutskever13.pdf

  Args:
    learning_rate: this is a fixed global scaling factor.
    momentum: (default `None`), the `decay` rate used by the momentum term.
      When set to `None`, momentum is not used at all.
    nesterov: (default `False`), whether Nesterov momentum is used.

  Returns:
    A `GradientTransformation`.
  """
  return combine.chain(
      (transform.trace(decay=momentum, nesterov=nesterov)
       if momentum is not None else base.identity()),
      _scale_by_learning_rate(learning_rate)
  )
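
A short training-loop sketch with this alias, assuming it is exported as `optax.sgd` and using `jax.grad` on a toy least-squares problem:

import jax
import jax.numpy as jnp
import optax

x = jnp.linspace(-1., 1., 32)
y = 2.0 * x + 0.5

def loss(params):
  w, b = params
  return jnp.mean((w * x + b - y) ** 2)

params = (jnp.array(0.), jnp.array(0.))
opt = optax.sgd(learning_rate=0.1, momentum=0.9, nesterov=True)
state = opt.init(params)

for _ in range(200):
  grads = jax.grad(loss)(params)
  updates, state = opt.update(grads, state)
  params = optax.apply_updates(params, updates)
# params should approach (2.0, 0.5).
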
Example #15
def adam(learning_rate: ScalarOrSchedule,
         b1: float = 0.9,
         b2: float = 0.999,
         eps: float = 1e-8,
         eps_root: float = 0.0) -> base.GradientTransformation:
    """The classic Adam optimiser.

  Adam is an SGD variant with learning rate adaptation. The `learning_rate`
  used for each weight is computed from estimates of first- and second-order
  moments of the gradients (using suitable exponential moving averages).

  References:
    [Kingma et al, 2014](https://arxiv.org/abs/1412.6980)

  Args:
    learning_rate: this is a fixed global scaling factor.
    b1: the exponential decay rate to track the first moment of past gradients.
    b2: the exponential decay rate to track the second moment of past gradients.
    eps: a small constant applied to denominator outside of the square root
      (as in the Adam paper) to avoid dividing by zero when rescaling.
    eps_root: (default `0`), a small constant applied to denominator inside the
      square root (as in RMSProp), to avoid dividing by zero when rescaling.
      This is needed for example when computing (meta-)gradients through Adam.

  Returns:
    the corresponding `GradientTransformation`.
  """
    return combine.chain(
        transform.scale_by_adam(b1=b1, b2=b2, eps=eps, eps_root=eps_root),
        _scale_by_learning_rate(learning_rate),
    )
Example #16
def adagrad(learning_rate: float,
            initial_accumulator_value: float = 0.1,
            eps: float = 1e-7) -> GradientTransformation:
    return combine.chain(
        transform.scale_by_rss(
            initial_accumulator_value=initial_accumulator_value, eps=eps),
        transform.scale(-learning_rate),
    )
Example #17
def adam(learning_rate: float,
         b1: float = 0.9,
         b2: float = 0.999,
         eps: float = 1e-8) -> GradientTransformation:
    return combine.chain(
        transform.scale_by_adam(b1=b1, b2=b2, eps=eps),
        transform.scale(-learning_rate),
    )
Example #18
def yogi(learning_rate: ScalarOrSchedule,
         b1: float = 0.9,
         b2: float = 0.999,
         eps: float = 1e-3) -> base.GradientTransformation:
    return combine.chain(
        transform.scale_by_yogi(b1=b1, b2=b2, eps=eps),
        _scale_by_learning_rate(learning_rate),
    )
Example #19
def fromage(learning_rate: float,
            min_norm: float = 1e-6) -> base.GradientTransformation:
    mult = 1 / jnp.sqrt(1 + learning_rate**2)
    return combine.chain(
        transform.scale_by_trust_ratio(min_norm),
        _scale_by_learning_rate(learning_rate * mult),
        transform.add_decayed_weights((mult - 1)),
    )
Example #20
def adabelief(learning_rate: ScalarOrSchedule,
              b1: float = 0.9,
              b2: float = 0.999,
              eps: float = 1e-8) -> GradientTransformation:
    return combine.chain(
        transform.scale_by_belief(b1=b1, b2=b2, eps=eps),
        _scale_by_learning_rate(learning_rate),
    )
Example #21
def noisy_sgd(learning_rate: float,
              eta: float = 0.01,
              gamma: float = 0.55,
              seed: int = 0) -> GradientTransformation:
    return combine.chain(
        transform.trace(decay=0., nesterov=False),
        transform.scale(-learning_rate),
        transform.add_noise(eta, gamma, seed),
    )
Example #22
def adam(learning_rate: ScalarOrSchedule,
         b1: float = 0.9,
         b2: float = 0.999,
         eps: float = 1e-8,
         eps_root: float = 0.0) -> GradientTransformation:
    return combine.chain(
        transform.scale_by_adam(b1=b1, b2=b2, eps=eps, eps_root=eps_root),
        _scale_by_learning_rate(learning_rate),
    )
Example #23
def radam(learning_rate: ScalarOrSchedule,
          b1: float = 0.9,
          b2: float = 0.999,
          eps: float = 1e-8,
          threshold: float = 5.0) -> GradientTransformation:
    return combine.chain(
        transform.scale_by_radam(b1=b1, b2=b2, eps=eps, threshold=threshold),
        _scale_by_learning_rate(learning_rate),
    )
Example #24
def adamw(learning_rate: float,
          b1: float = 0.9,
          b2: float = 0.999,
          eps: float = 1e-8,
          weight_decay: float = 1e-4) -> GradientTransformation:
    return combine.chain(
        transform.scale_by_adam(b1=b1, b2=b2, eps=eps),
        transform.additive_weight_decay(weight_decay),
        transform.scale(-learning_rate),
    )
Example #25
def adamw(learning_rate: ScalarOrSchedule,
          b1: float = 0.9,
          b2: float = 0.999,
          eps: float = 1e-8,
          eps_root: float = 0.0,
          weight_decay: float = 1e-4) -> GradientTransformation:
    return combine.chain(
        transform.scale_by_adam(b1=b1, b2=b2, eps=eps, eps_root=eps_root),
        transform.additive_weight_decay(weight_decay),
        _scale_by_learning_rate(learning_rate),
    )
Example #26
def lamb(learning_rate: float,
         b1: float = 0.9,
         b2: float = 0.999,
         eps: float = 1e-6,
         eps_root: float = 0.0,
         weight_decay: float = 0.) -> GradientTransformation:
  return combine.chain(
      transform.scale_by_adam(b1=b1, b2=b2, eps=eps, eps_root=eps_root),
      transform.additive_weight_decay(weight_decay),
      transform.scale_by_trust_ratio(),
      transform.scale(-learning_rate),
  )
Example #27
def lamb(learning_rate: ScalarOrSchedule,
         b1: float = 0.9,
         b2: float = 0.999,
         eps: float = 1e-6,
         eps_root: float = 0.0,
         weight_decay: float = 0.) -> base.GradientTransformation:
    return combine.chain(
        transform.scale_by_adam(b1=b1, b2=b2, eps=eps, eps_root=eps_root),
        transform.add_decayed_weights(weight_decay),
        transform.scale_by_trust_ratio(),
        _scale_by_learning_rate(learning_rate),
    )
Example #28
0
def adamw(
    learning_rate: ScalarOrSchedule,
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-8,
    eps_root: float = 0.0,
    mu_dtype: Optional[Any] = None,
    weight_decay: float = 1e-4,
    mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
) -> base.GradientTransformation:
  """Adam with weight decay regularization.

  AdamW uses weight decay to regularise learning towards small weights, as
  this leads to better generalisation. In SGD you can also use L2 regularisation
  to implement this as an additive loss term, however L2 regularization
  does not behave as intended for adaptive gradient algorithms such as Adam.

  WARNING: Sometimes you may want to skip weight decay for BatchNorm scale or
  for the bias parameters. You can use the `mask` argument below, or build your
  own variant with `optax.masked`, so that `add_decayed_weights` is applied
  only to a subset of `params`.

  References:
    Loshchilov et al, 2019: https://arxiv.org/abs/1711.05101

  Args:
    learning_rate: this is a fixed global scaling factor.
    b1: the exponential decay rate to track the first moment of past gradients.
    b2: the exponential decay rate to track the second moment of past gradients.
    eps: a small constant applied to denominator outside of the square root
      (as in the Adam paper) to avoid dividing by zero when rescaling.
    eps_root: (default `0`), a small constant applied to denominator inside the
      square root (as in RMSProp), to avoid dividing by zero when rescaling.
      This is needed for instance when computing (meta-)gradients through Adam.
    mu_dtype: optional `dtype` to be used for the first order accumulator; if
      `None` then the `dtype` is inferred from `params` and `updates`.
    weight_decay: strength of the weight decay regularization.
    mask: a tree with same structure as (or a prefix of) the params PyTree,
      or a Callable that returns such a pytree given the params/updates.
      The leaves should be booleans, `True` for leaves/subtrees you want to
      apply the weight decay to, and `False` for those you want to skip. Note
      that the Adam gradient transformations are applied to all parameters.

  Returns:
    the corresponding `GradientTransformation`.
  """
  return combine.chain(
      transform.scale_by_adam(
          b1=b1, b2=b2, eps=eps, eps_root=eps_root, mu_dtype=mu_dtype),
      transform.add_decayed_weights(weight_decay, mask),
      _scale_by_learning_rate(learning_rate),
  )
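
The `mask` argument is what realises the WARNING above without building a custom chain: pass a pytree (or a callable producing one) of booleans that selects which leaves receive weight decay. A sketch, assuming the alias is exported as `optax.adamw`; the parameter names 'w' and 'b' and the helper `decay_mask` are chosen here purely for illustration:

import jax.numpy as jnp
import optax

params = {'w': jnp.ones((3, 3)), 'b': jnp.zeros((3,))}

def decay_mask(params):
  # Decay only the weight matrix, not the bias (illustrative names).
  return {'w': True, 'b': False}

opt = optax.adamw(learning_rate=1e-3, weight_decay=1e-4, mask=decay_mask)
state = opt.init(params)
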
Example #29
def lars(
    learning_rate: ScalarOrSchedule,
    weight_decay: float = 0.,
    weight_decay_mask: MaskOrFn = True,
    trust_coefficient: float = 0.001,
    eps: float = 0.,
    trust_ratio_mask: MaskOrFn = True,
    momentum: float = 0.9,
    nesterov: bool = False,
) -> base.GradientTransformation:
  """The LARS optimiser.

  LARS is a layer-wise adaptive optimiser introduced to help scale SGD to
  larger batch sizes. LARS later inspired the LAMB optimiser.

  References:
    You et al, 2017: https://arxiv.org/abs/1708.03888

  Args:
    learning_rate: this is a fixed global scaling factor.
    weight_decay: (default `0.`), strength of the weight decay regularization.
    weight_decay_mask: a tree with same structure as (or a prefix of) the params
      PyTree, or a Callable that returns such a pytree given the params/updates.
      The leaves should be booleans, `True` for leaves/subtrees you want to
      apply the transformation to, and `False` for those you want to skip.
    trust_coefficient: a multiplier for the trust ratio.
    eps: optional additive constant in the trust ratio denominator.
    trust_ratio_mask: a tree with same structure as (or a prefix of) the params
      PyTree, or a Callable that returns such a pytree given the params/updates.
      The leaves should be booleans, `True` for leaves/subtrees you want to
      apply the transformation to, and `False` for those you want to skip.
    momentum: the decay rate for momentum.
    nesterov: whether to use Nesterov momentum.

  Returns:
    the corresponding `GradientTransformation`.
  """
  return combine.chain(
      transform.add_decayed_weights(weight_decay, mask=weight_decay_mask),
      wrappers.masked(
          inner=transform.scale_by_trust_ratio(
              trust_coefficient=trust_coefficient, eps=eps),
          mask=trust_ratio_mask),
      _scale_by_learning_rate(learning_rate),
      transform.trace(decay=momentum, nesterov=nesterov),
  )
Example #30
    def test_apply_every(self):
        # The frequency of the application of sgd
        k = 4
        zero_update = (jnp.array([0., 0.]), jnp.array([0., 0.]))

        # optax sgd
        optax_sgd_params = self.init_params
        sgd = alias.sgd(LR, 0.0)
        state_sgd = sgd.init(optax_sgd_params)

        # optax sgd plus apply every
        optax_sgd_apply_every_params = self.init_params
        sgd_apply_every = combine.chain(
            transform.apply_every(k=k),
            transform.trace(decay=0, nesterov=False),
            transform.scale(-LR))
        state_sgd_apply_every = sgd_apply_every.init(
            optax_sgd_apply_every_params)
        transform_fn = self.variant(sgd_apply_every.update)

        for i in range(STEPS):
            # Apply a step of sgd
            updates_sgd, state_sgd = sgd.update(self.per_step_updates,
                                                state_sgd)
            optax_sgd_params = update.apply_updates(optax_sgd_params,
                                                    updates_sgd)

            # Apply a step of sgd_apply_every
            updates_sgd_apply_every, state_sgd_apply_every = transform_fn(
                self.per_step_updates, state_sgd_apply_every)
            optax_sgd_apply_every_params = update.apply_updates(
                optax_sgd_apply_every_params, updates_sgd_apply_every)

            # Every k steps, check equivalence.
            if i % k == k - 1:
                chex.assert_tree_all_close(optax_sgd_apply_every_params,
                                           optax_sgd_params,
                                           atol=1e-6,
                                           rtol=1e-5)
            # Otherwise, check update is zero.
            else:
                chex.assert_tree_all_close(updates_sgd_apply_every,
                                           zero_update,
                                           atol=0.0,
                                           rtol=0.0)
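
The pattern under test, `apply_every(k)` at the front of the chain, accumulates gradients and only emits a non-zero update on every k-th call, a simple form of gradient accumulation. A sketch of the same composition through the public API, assuming `optax.apply_every`, `optax.trace` and `optax.scale` are the exported names:

import optax

k = 4  # accumulate gradients for 4 steps, then take one SGD step
accumulating_sgd = optax.chain(
    optax.apply_every(k=k),
    optax.trace(decay=0., nesterov=False),
    optax.scale(-1e-2),
)
# On 3 out of every 4 calls to update() the returned updates are all zeros;
# on the 4th call they equal the accumulated gradient sum, scaled by -1e-2.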