Exemple #1
0
    def test_multi_transform(self, use_fn):
        """Checks that `multi_transform` dispatches each leaf by its label.

        Labels come from `key[0]` of each leaf key, presumably the first
        character of the innermost dict key, giving 'a', 'b' or 'c'. Leaves
        labelled 'a' go through `scale(-1.0)`, 'b' through `ema(0.0)` (the
        reference below expects it to act as the identity), and 'c' through
        `scale(2.0)`.

        Args:
          use_fn: if True, pass the labelling function itself; otherwise pass
            the label pytree it produces for `params`.
        """
        raw = {'a1': 1., 'b1': 2., 'z1': {'a2': 3., 'z2': {'c1': 4.}}}
        params = jax.tree_map(jnp.asarray, raw)
        grads = jax.tree_map(lambda leaf: leaf / 10.0, params)

        transforms_by_label = {
            'a': transform.scale(-1.0),
            'b': transform.ema(0.0),  # stateful
            'c': transform.scale(2.0),
        }
        labeller = _map_keys_fn(lambda key, _: key[0])
        labels = labeller if use_fn else labeller(params)
        tx = combine.multi_transform(transforms_by_label, labels)
        update_fn = self.variant(tx.update)
        state = self.variant(tx.init)(params)

        # Reference behaviour for each label, applied leaf by leaf.
        reference_ops = {
            'a': lambda v: -v,
            'b': lambda v: v,
            'c': lambda v: 2.0 * v,
        }
        expected_fn = _map_keys_fn(lambda key, v: reference_ops[key[0]](v))

        updates, state = update_fn(grads, state, params)
        expected = expected_fn(grads)
        chex.assert_tree_all_close(updates, expected)

        # Second step: feed the previous outputs back in, this time without
        # params.
        expected = expected_fn(expected)
        updates, state = update_fn(updates, state)
        chex.assert_tree_all_close(updates, expected)
Exemple #2
0
    def test_keep_params_nonnegative(self):
        """Compares one step of plain SGD with SGD + `keep_params_nonnegative`.

        Every parameter array receives the gradients [500, -500, 0]. The
        unconstrained step produces negative entries; with the constraint
        appended, those entries end up at exactly zero while non-negative
        results are untouched. (The expected values imply LR == 0.01 —
        presumably the module-level constant.)
        """
        grad_row = [500., -500., 0.]
        grads = tuple(jnp.array(grad_row) for _ in range(3))
        params = (
            jnp.array([-1., -1., -1.]),
            jnp.array([1., 1., 1.]),
            jnp.array([0., 0., 0.]),
        )

        def run_one_step(*extra_transforms):
            # trace(decay=0) + scale(-LR) is vanilla SGD; extras run after it.
            optimiser = combine.chain(
                transform.trace(decay=0, nesterov=False),
                transform.scale(-LR),
                *extra_transforms)
            opt_state = optimiser.init(params)
            step, _ = optimiser.update(grads, opt_state, params)
            return update.apply_updates(params, step)

        # Vanilla SGD.
        chex.assert_tree_all_close(
            run_one_step(),
            (jnp.array([-6., 4., -1.]),
             jnp.array([-4., 6., 1.]),
             jnp.array([-5., 5., 0.])))

        # SGD that keeps parameters non-negative.
        chex.assert_tree_all_close(
            run_one_step(constrain.keep_params_nonnegative()),
            (jnp.array([0., 4., 0.]),
             jnp.array([0., 6., 1.]),
             jnp.array([0., 5., 0.])))
Exemple #3
0
def rmsprop(learning_rate: float,
            decay: float = 0.9,
            eps: float = 1e-8,
            centered: bool = False) -> GradientTransformation:
    """RMSProp: an RMS-style rescaling followed by a `-learning_rate` step.

    Args:
      learning_rate: fixed step size; rescaled updates are multiplied by
        `-learning_rate`.
      decay: decay rate forwarded to the rescaling transformation.
      eps: stability term forwarded to the rescaling transformation.
      centered: if True use `transform.scale_by_stddev`, otherwise
        `transform.scale_by_rms`.

    Returns:
      A `GradientTransformation` chaining the rescaling and the step.
    """
    rescale = (transform.scale_by_stddev if centered
               else transform.scale_by_rms)
    return combine.chain(
        rescale(decay=decay, eps=eps),
        transform.scale(-learning_rate),
    )
Exemple #4
0
def sm3(learning_rate: float,
        momentum: float = 0.9) -> base.GradientTransformation:
    """The SM3 optimiser.

    SM3 (Square-root of Minima of Sums of Maxima of Squared-gradients Method)
    is a memory-efficient adaptive optimiser designed to decrease memory
    overhead when training very large models, such as the Transformer for
    machine translation, BERT for language modelling, and AmoebaNet-D for
    image classification. SM3: 1) applies to tensors of arbitrary dimensions
    and any predefined cover of the parameters; 2) adapts the learning rates
    in an adaptive and data-driven manner (like Adagrad and unlike
    Adafactor); and 3) comes with rigorous convergence guarantees in
    stochastic convex optimization settings.

    References:
      Anil et al, 2019: https://arxiv.org/abs/1901.11150

    Args:
      learning_rate: a fixed global scaling factor; the SM3-rescaled updates
        are multiplied by `-learning_rate`.
      momentum: the `decay` rate used by the momentum term, passed unchanged
        to `transform.scale_by_sm3`. NOTE(review): momentum is presumably
        disabled when this is `None` (despite the `float` annotation);
        confirm that `scale_by_sm3` accepts `None` before relying on it.

    Returns:
      the corresponding `GradientTransformation`.
    """
    return combine.chain(
        transform.scale_by_sm3(momentum),
        transform.scale(-learning_rate),
    )
Exemple #5
0
def sgd(learning_rate: float,
        momentum: float = 0.,
        nesterov: bool = False) -> GradientTransformation:
    """Stochastic gradient descent with optional (Nesterov) momentum.

    Args:
      learning_rate: fixed step size; updates are multiplied by
        `-learning_rate`.
      momentum: decay of the `transform.trace` accumulator; 0 gives vanilla
        SGD.
      nesterov: forwarded to `transform.trace`.

    Returns:
      A `GradientTransformation` chaining the momentum trace and the step.
    """
    momentum_tx = transform.trace(decay=momentum, nesterov=nesterov)
    step_tx = transform.scale(-learning_rate)
    return combine.chain(momentum_tx, step_tx)
Exemple #6
0
  def test_chain(self):
    """`combine.chain` must match applying each transformation by hand.

    Runs STEPS updates through a chained Adam -> trace -> scale pipeline
    and, separately, through the same transformations applied one after
    another, then checks that both parameter trajectories agree.
    """
    pipeline = [
        transform.scale_by_adam(),
        transform.trace(decay=0, nesterov=False),
        transform.scale(-LR),
    ]

    # Path 1: a single chained transformation.
    chain_params = self.init_params
    chained = combine.chain(*pipeline)
    chained_state = chained.init(chain_params)

    @self.variant
    def update_fn(updates, state):
      return chained.update(updates, state)

    for _ in range(STEPS):
      step, chained_state = update_fn(self.per_step_updates, chained_state)
      chain_params = update.apply_updates(chain_params, step)

    # Path 2: the same transformations applied in sequence by hand.
    manual_params = self.init_params
    manual_states = [tx.init(manual_params) for tx in pipeline]
    for _ in range(STEPS):
      step = self.per_step_updates
      next_states = []
      for tx, tx_state in zip(pipeline, manual_states):
        step, tx_state = tx.update(step, tx_state)
        next_states.append(tx_state)
      manual_params = update.apply_updates(manual_params, step)
      manual_states = next_states

    # Both paths must land on (numerically) the same parameters.
    chex.assert_tree_all_close(manual_params, chain_params, rtol=1e-4)
Exemple #7
0
def fromage(learning_rate: float,
            min_norm: float = 1e-6) -> GradientTransformation:
    """Fromage: trust-ratio scaling, a shrunken step, then weight shrinkage.

    Args:
      learning_rate: base step size.
      min_norm: forwarded to `transform.scale_by_trust_ratio`.

    Returns:
      A `GradientTransformation` chaining the three stages.
    """
    # Both the step and the decoupled weight term use 1 / sqrt(1 + lr**2).
    shrink = 1 / jnp.sqrt(1 + learning_rate**2)
    trust = transform.scale_by_trust_ratio(min_norm)
    step = transform.scale(-learning_rate * shrink)
    weight_term = transform.add_decayed_weights(shrink - 1)
    return combine.chain(trust, step, weight_term)
Exemple #8
0
def adam(learning_rate: float,
         b1: float = 0.9,
         b2: float = 0.999,
         eps: float = 1e-8) -> GradientTransformation:
    """Adam: `transform.scale_by_adam` followed by a `-learning_rate` step.

    Args:
      learning_rate: fixed step size.
      b1, b2, eps: forwarded unchanged to `transform.scale_by_adam`.

    Returns:
      A `GradientTransformation` chaining the two stages.
    """
    adam_scaling = transform.scale_by_adam(b1=b1, b2=b2, eps=eps)
    return combine.chain(adam_scaling, transform.scale(-learning_rate))
Exemple #9
0
def adagrad(learning_rate: float,
            initial_accumulator_value: float = 0.1,
            eps: float = 1e-7) -> GradientTransformation:
    """Adagrad: `transform.scale_by_rss` followed by a `-learning_rate` step.

    Args:
      learning_rate: fixed step size.
      initial_accumulator_value: forwarded to `transform.scale_by_rss`.
      eps: forwarded to `transform.scale_by_rss`.

    Returns:
      A `GradientTransformation` chaining the two stages.
    """
    rss_scaling = transform.scale_by_rss(
        initial_accumulator_value=initial_accumulator_value, eps=eps)
    step = transform.scale(-learning_rate)
    return combine.chain(rss_scaling, step)
Exemple #10
0
def noisy_sgd(learning_rate: float,
              eta: float = 0.01,
              gamma: float = 0.55,
              seed: int = 0) -> GradientTransformation:
    """Vanilla SGD with noise added after the learning-rate scaling.

    Args:
      learning_rate: fixed step size.
      eta, gamma, seed: forwarded unchanged to `transform.add_noise`.

    Returns:
      A `GradientTransformation` chaining trace, step and noise.
    """
    stages = (
        transform.trace(decay=0., nesterov=False),
        transform.scale(-learning_rate),
        transform.add_noise(eta, gamma, seed),
    )
    return combine.chain(*stages)
Exemple #11
0
def radam(learning_rate: float,
          b1: float = 0.9,
          b2: float = 0.999,
          eps: float = 1e-8,
          threshold: float = 5.0) -> GradientTransformation:
  """Rectified Adam: `scale_by_radam` followed by a `-learning_rate` step.

  Args:
    learning_rate: fixed step size.
    b1, b2, eps, threshold: forwarded unchanged to
      `transform.scale_by_radam`.

  Returns:
    A `GradientTransformation` chaining the two stages.
  """
  radam_scaling = transform.scale_by_radam(
      b1=b1, b2=b2, eps=eps, threshold=threshold)
  return combine.chain(radam_scaling, transform.scale(-learning_rate))
Exemple #12
0
def adamw(learning_rate: float,
          b1: float = 0.9,
          b2: float = 0.999,
          eps: float = 1e-8,
          weight_decay: float = 1e-4) -> GradientTransformation:
    """AdamW: Adam rescaling, additive weight decay, then the step.

    Args:
      learning_rate: fixed step size.
      b1, b2, eps: forwarded unchanged to `transform.scale_by_adam`.
      weight_decay: forwarded to `transform.additive_weight_decay`.

    Returns:
      A `GradientTransformation` chaining the three stages.
    """
    pipeline = [
        transform.scale_by_adam(b1=b1, b2=b2, eps=eps),
        # Added after the Adam rescaling, so the decay term is not rescaled.
        transform.additive_weight_decay(weight_decay),
        transform.scale(-learning_rate),
    ]
    return combine.chain(*pipeline)
Exemple #13
0
def lamb(learning_rate: float,
         b1: float = 0.9,
         b2: float = 0.999,
         eps: float = 1e-6,
         eps_root: float = 0.0,
         weight_decay: float = 0.) -> GradientTransformation:
  """LAMB: Adam rescaling + weight decay, normalised by a trust ratio.

  Args:
    learning_rate: fixed step size.
    b1, b2, eps, eps_root: forwarded unchanged to `transform.scale_by_adam`.
    weight_decay: forwarded to `transform.additive_weight_decay`.

  Returns:
    A `GradientTransformation` chaining the four stages.
  """
  adam_scaling = transform.scale_by_adam(
      b1=b1, b2=b2, eps=eps, eps_root=eps_root)
  decay = transform.additive_weight_decay(weight_decay)
  trust = transform.scale_by_trust_ratio()
  step = transform.scale(-learning_rate)
  return combine.chain(adam_scaling, decay, trust, step)
Exemple #14
0
 def test_scale(self):
   """`transform.scale(f)` must multiply every update leaf by `f`.

   Checks the factors 0.1 ** 1 through 0.1 ** STEPS.
   """
   updates = self.per_step_updates
   for power in range(1, STEPS + 1):
     factor = 0.1 ** power
     # Rescale through the transformation (its state is passed as None).
     scaled_updates, _ = transform.scale(factor).update(updates, None)
     # Expected result, leaf by leaf; `factor` is bound as a default
     # argument so the lambda does not late-bind the loop variable.
     expected = jax.tree_map(lambda leaf, f=factor: leaf * f, updates)
     chex.assert_tree_all_close(scaled_updates, expected)
Exemple #15
0
  def test_stateless_inner(self):
    """`maybe_update` gates a stateless inner transform on the step count.

    For the first `NUM_STEPS` calls the inner `scale(2.)` runs, doubling the
    unit gradient. Afterwards the inner transform is no longer called and
    the gradient comes back unchanged.
    """
    params = jnp.zeros([])
    grads = jnp.ones([])

    def still_updating(step):
      return step < MaybeUpdateTest.NUM_STEPS

    opt = wrappers.maybe_update(transform.scale(2.), still_updating)
    state = opt.init(params)
    update_fn = self.variant(opt.update)

    def expect_for(num_calls, expected):
      nonlocal state
      for _ in range(num_calls):
        updates, state = update_fn(grads, state)
        self.assertEqual(updates, expected)

    expect_for(MaybeUpdateTest.NUM_STEPS, 2.)
    # Further updates stop calling the inner optimiser.
    expect_for(5, 1.)
Exemple #16
0
    def test_apply_every(self):
        """`apply_every(k)` must release accumulated SGD steps every k calls.

        Runs plain `alias.sgd` alongside `apply_every(k) -> trace -> scale`.
        On every k-th step both parameter sets must agree; on all other steps
        the `apply_every` pipeline must emit an exactly-zero update.
        """
        k = 4  # frequency at which the accumulated sgd step is applied
        zero_update = (jnp.array([0., 0.]), jnp.array([0., 0.]))

        # Reference: optax sgd with momentum 0.
        ref_params = self.init_params
        ref_opt = alias.sgd(LR, 0.0)
        ref_state = ref_opt.init(ref_params)

        # Candidate: the same sgd, gated by apply_every.
        every_params = self.init_params
        every_opt = combine.chain(
            transform.apply_every(k=k),
            transform.trace(decay=0, nesterov=False),
            transform.scale(-LR))
        every_state = every_opt.init(every_params)
        every_update = self.variant(every_opt.update)

        for step in range(STEPS):
            ref_updates, ref_state = ref_opt.update(
                self.per_step_updates, ref_state)
            ref_params = update.apply_updates(ref_params, ref_updates)

            every_updates, every_state = every_update(
                self.per_step_updates, every_state)
            every_params = update.apply_updates(every_params, every_updates)

            if (step + 1) % k == 0:
                # Every k steps, the two parameter trajectories coincide.
                chex.assert_tree_all_close(every_params, ref_params,
                                           atol=1e-6, rtol=1e-5)
            else:
                # In between, the apply_every pipeline emits a zero update.
                chex.assert_tree_all_close(every_updates, zero_update,
                                           atol=0.0, rtol=0.0)
 def custom_optim(learning_rate, mask):
   """Scale updates by `-learning_rate`, restricted by `wrappers.masked`."""
   step = transform.scale(-learning_rate)
   return wrappers.masked(step, mask)
Exemple #18
0
def _scale_by_learning_rate(learning_rate: ScalarOrSchedule, flip_sign=True):
    """Return a transformation that multiplies updates by the learning rate.

    Args:
      learning_rate: either a scalar, or a callable mapping the step count to
        a scalar (handled via `transform.scale_by_schedule`).
      flip_sign: if True (the default, identical to the previous behaviour)
        the learning rate is negated so the result follows the negative
        gradient; pass False to scale by the learning rate as-is (e.g. when a
        later stage applies the sign flip itself).

    Returns:
      `transform.scale_by_schedule(...)` for callables, otherwise
      `transform.scale(...)`.
    """
    sign = -1 if flip_sign else 1
    if callable(learning_rate):
        return transform.scale_by_schedule(
            lambda count: sign * learning_rate(count))
    return transform.scale(sign * learning_rate)
Exemple #19
0
class MaskedTest(chex.TestCase):
    """Tests for the masked wrapper."""
    @chex.all_variants
    @parameterized.named_parameters(
        # Names must be unique: absl rejects duplicate parameterized names.
        ('scale', lambda: transform.scale(-1.), False),  # stateless test
        ('sgd', _build_sgd, False),  # stateful test
        ('scale_w_mask_fn', lambda: transform.scale(-1.), True),  # stateless
        ('sgd_w_mask_fn', _build_sgd, True),  # stateful test
    )
    def test_masked(self, opt_builder, use_fn):
        """Updates are negated exactly where the mask is True.

        Args:
          opt_builder: zero-argument callable returning the inner transform.
          use_fn: pass the mask as a callable (True) or as a pytree (False).
        """
        mask = {
            'a': True,
            'b': [False, True],
            'c': {
                'd': True,
                'e': (False, True)
            }
        }
        # Parenthesised on purpose: `lambda _: mask if use_fn else mask`
        # parses as a lambda whose body is the conditional, so it would be
        # callable even when `use_fn` is False.
        mask_arg = (lambda _: mask) if use_fn else mask
        params = {'a': 1., 'b': [2., 3.], 'c': {'d': 4., 'e': (5., 6.)}}
        params = jax.tree_map(jnp.asarray, params)
        input_updates = jax.tree_map(lambda x: x / 10., params)

        # Negate the updates wherever the mask is True. `jax.tree_map` takes
        # several trees; `jax.tree_multimap` was removed from JAX.
        def masked_negate(updates):
            return jax.tree_map(lambda upd, m: -upd
                                if m else upd, updates, mask)

        correct_updates = masked_negate(input_updates)

        init_fn, update_fn = wrappers.masked(opt_builder(), mask_arg)
        update_fn = self.variant(update_fn)
        state = self.variant(init_fn)(params)
        updates, state = update_fn(input_updates, state, params)
        chex.assert_tree_all_close(updates, correct_updates)

        # Check repeated application, this time with no params.
        correct_updates = masked_negate(correct_updates)
        updates, state = update_fn(updates, state)
        chex.assert_tree_all_close(updates, correct_updates)

    @chex.all_variants
    @parameterized.named_parameters(
        ('scale', lambda: transform.scale(-1.)),  # stateless test
        ('sgd', _build_sgd),  # stateful test
    )
    def test_prefix_mask(self, opt_builder):
        """Test when the mask is a prefix of the updates PyTree."""
        mask = {'a': True, 'b': False, 'c': {'d': False, 'e': True}}
        params = {'a': 1., 'b': {'f': 2.}, 'c': {'d': 3., 'e': ([4., 5.], 6.)}}
        params = jax.tree_map(jnp.asarray, params)
        input_updates = jax.tree_map(lambda x: x / 10., params)

        # Negate the updates wherever the mask (or mask parent) is True
        def _masked_sgd_on_updates(m, upd):
            return jax.tree_map(lambda x: -x, upd) if m else upd

        # `mask` is the first tree, so `input_updates` only needs to have it
        # as a prefix.
        correct_updates = jax.tree_map(_masked_sgd_on_updates, mask,
                                       input_updates)

        init_fn, update_fn = wrappers.masked(opt_builder(), mask)
        update_fn = self.variant(update_fn)
        state = self.variant(init_fn)(params)
        updates, state = update_fn(input_updates, state, params)
        chex.assert_tree_all_close(updates, correct_updates)

        # Check repeated application, this time with no params.
        correct_updates = jax.tree_map(_masked_sgd_on_updates, mask,
                                       correct_updates)
        updates, state = update_fn(updates, state)
        chex.assert_tree_all_close(updates, correct_updates)

    @chex.all_variants
    def test_update_requires_params(self):
        """A masked transform that reads params only touches masked leaves."""
        weight_decay = 0.1
        mask = {
            'a': True,
            'b': [False, True],
            'c': {
                'd': True,
                'e': (False, True)
            }
        }
        params = {'a': 1., 'b': [2., 3.], 'c': {'d': 4., 'e': (5., 6.)}}
        params = jax.tree_map(jnp.asarray, params)
        input_updates = jax.tree_map(lambda x: x / 10., params)

        correct_updates = jax.tree_map(
            lambda m, u, p: u + weight_decay * p
            if m else u, mask, input_updates, params)

        init_fn, update_fn = wrappers.masked(
            transform.additive_weight_decay(weight_decay), mask)
        update_fn = self.variant(update_fn)

        state = self.variant(init_fn)(params)
        updates, state = update_fn(input_updates, state, params)
        chex.assert_tree_all_close(updates, correct_updates)

        params = update.apply_updates(params, updates)

        # Test repeated application
        new_correct_updates = jax.tree_map(
            lambda m, u, p: u + weight_decay * p
            if m else u, mask, correct_updates, params)
        updates, state = update_fn(correct_updates, state, params)
        chex.assert_tree_all_close(updates, new_correct_updates)

    @parameterized.parameters(list, tuple, dict)
    def test_empty(self, container):
        """Empty containers must be accepted as params and as the mask."""
        init_fn, update_fn = wrappers.masked(_build_sgd(), container())
        update_fn(container(), init_fn(container()))

    @parameterized.parameters((False, False), (False, True), (True, False),
                              (True, True))
    def test_tree_mismatch_fails(self, extra_key_in_mask, use_fn):
        """A mask whose structure differs from params must raise ValueError."""
        mask = {
            'a': True,
            'b': [False, True],
            'c': {
                'd': True,
                'e': (False, True)
            }
        }
        # Parenthesised so that `use_fn=False` really passes the pytree.
        mask_arg = (lambda _: mask) if use_fn else mask
        params = {'a': 1., 'b': [2., 3.], 'c': {'d': 4., 'e': (5., 6.)}}
        params = jax.tree_map(jnp.asarray, params)

        # Mutated in place, so both mask_arg forms observe the change.
        if extra_key_in_mask:
            mask['c']['extra'] = True
        else:
            params['c']['extra'] = 7

        init_fn = wrappers.masked(_build_sgd(), mask_arg)[0]
        with self.assertRaises(ValueError):
            init_fn(params)

    @chex.all_variants
    def test_mask_fn(self):
        """A callable mask is evaluated on params (here: leaves with ndim > 1)."""
        params = {
            'a': jnp.ones((1, 2)),
            'b': (jnp.ones((1, )), np.ones((1, 2, 3)))
        }
        mask_fn = lambda p: jax.tree_map(lambda x: x.ndim > 1, p)
        init_fn, update_fn = wrappers.masked(
            transform.add_decayed_weights(0.1), mask_fn)
        update_fn = self.variant(update_fn)

        state = self.variant(init_fn)(params)
        grads = jax.tree_map(lambda x: x * 2, params)
        updates, state = update_fn(grads, state, params)
        np.testing.assert_allclose(updates['a'],
                                   grads['a'] + 0.1 * params['a'])
        np.testing.assert_allclose(updates['b'][0], grads['b'][0])
        np.testing.assert_allclose(updates['b'][1],
                                   grads['b'][1] + 0.1 * params['b'][1])
Exemple #20
0
class MaskedTest(chex.TestCase):
  """Tests for the masked wrapper."""

  @chex.all_variants
  @parameterized.named_parameters(
      ('scale', lambda: transform.scale(-1.)),  # stateless test
      ('sgd', _build_sgd),  # stateful test
  )
  def test_masked(self, opt_builder):
    """Updates are negated exactly where the mask is True."""
    mask = {'a': True,
            'b': [False, True],
            'c': {'d': True, 'e': (False, True)}}
    params = {'a': 1, 'b': [2, 3], 'c': {'d': 4, 'e': (5, 6)}}
    input_updates = jax.tree_util.tree_map(lambda x: x/10., params)

    # Negate the updates wherever the mask is True. `tree_map` takes extra
    # trees directly; `tree_multimap` was removed from JAX.
    def masked_negate(updates):
      return jax.tree_util.tree_map(
          lambda upd, m: -upd if m else upd, updates, mask)
    correct_updates = masked_negate(input_updates)

    init_fn, update_fn = wrappers.masked(opt_builder(), mask)
    update_fn = self.variant(update_fn)
    state = init_fn(params)
    updates, state = update_fn(input_updates, state, params)
    chex.assert_tree_all_close(updates, correct_updates)

    # Check repeated application, this time with no params.
    correct_updates = masked_negate(correct_updates)
    updates, state = update_fn(updates, state)
    chex.assert_tree_all_close(updates, correct_updates)

  @chex.all_variants
  @parameterized.named_parameters(
      ('scale', lambda: transform.scale(-1.)),  # stateless test
      ('sgd', _build_sgd),  # stateful test
  )
  def test_prefix_mask(self, opt_builder):
    """Test when the mask is a prefix of the updates PyTree."""
    mask = {'a': True, 'b': False, 'c': {'d': False, 'e': True}}
    params = {'a': 1, 'b': {'f': 2}, 'c': {'d': 3, 'e': ([4, 5], 6)}}
    input_updates = jax.tree_util.tree_map(lambda x: x/10., params)

    # Negate the updates wherever the mask (or mask parent) is True
    def _masked_sgd_on_updates(m, upd):
      return jax.tree_util.tree_map(lambda x: -x, upd) if m else upd
    # `mask` is the first tree, so `input_updates` only needs it as a prefix.
    correct_updates = jax.tree_util.tree_map(
        _masked_sgd_on_updates, mask, input_updates)

    init_fn, update_fn = wrappers.masked(opt_builder(), mask)
    update_fn = self.variant(update_fn)
    state = init_fn(params)
    updates, state = update_fn(input_updates, state, params)
    chex.assert_tree_all_close(updates, correct_updates)

    # Check repeated application, this time with no params.
    correct_updates = jax.tree_util.tree_map(
        _masked_sgd_on_updates, mask, correct_updates)
    updates, state = update_fn(updates, state)
    chex.assert_tree_all_close(updates, correct_updates)

  @chex.all_variants
  def test_update_requires_params(self):
    """A masked transform that reads params only touches masked leaves."""
    weight_decay = 0.1
    mask = {'a': True,
            'b': [False, True],
            'c': {'d': True, 'e': (False, True)}}
    params = {'a': 1, 'b': [2, 3], 'c': {'d': 4, 'e': (5, 6)}}
    input_updates = jax.tree_util.tree_map(lambda x: x/10., params)

    correct_updates = jax.tree_util.tree_map(
        lambda m, u, p: u + weight_decay * p if m else u,
        mask, input_updates, params)

    init_fn, update_fn = wrappers.masked(
        transform.additive_weight_decay(weight_decay), mask)
    update_fn = self.variant(update_fn)

    state = init_fn(params)
    updates, state = update_fn(input_updates, state, params)
    chex.assert_tree_all_close(updates, correct_updates)

    params = update.apply_updates(params, updates)

    # Test repeated application
    new_correct_updates = jax.tree_util.tree_map(
        lambda m, u, p: u + weight_decay * p if m else u,
        mask, correct_updates, params)
    updates, state = update_fn(correct_updates, state, params)
    chex.assert_tree_all_close(updates, new_correct_updates)

  @parameterized.parameters(list, tuple, dict)
  def test_empty(self, container):
    """Empty containers must be accepted as params and as the mask."""
    init_fn, update_fn = wrappers.masked(_build_sgd(), container())
    update_fn(container(), init_fn(container()))

  @parameterized.parameters(True, False)
  def test_tree_mismatch_fails(self, extra_key_in_mask):
    """A mask whose structure differs from params must raise ValueError."""
    mask = {'a': True,
            'b': [False, True],
            'c': {'d': True, 'e': (False, True)}}
    params = {'a': 1, 'b': [2, 3], 'c': {'d': 4, 'e': (5, 6)}}

    if extra_key_in_mask:
      mask['c']['extra'] = True
    else:
      params['c']['extra'] = 7

    init_fn = wrappers.masked(_build_sgd(), mask)[0]
    with self.assertRaises(ValueError):
      init_fn(params)
Exemple #21
0
def _scale_by_learning_rate(learning_rate: ScalarOrSchedule, flip_sign=True):
  """Build a transformation that multiplies updates by the learning rate.

  Args:
    learning_rate: a scalar, or a callable of the step count returning one.
    flip_sign: negate the learning rate when True (the default).

  Returns:
    `transform.scale` for scalars, `transform.scale_by_schedule` for
    callables.
  """
  sign = -1 if flip_sign else 1
  if not callable(learning_rate):
    return transform.scale(sign * learning_rate)
  return transform.scale_by_schedule(lambda count: sign * learning_rate(count))
Exemple #22
0
def adafactor(
    learning_rate: Optional[ScalarOrSchedule] = None,
    min_dim_size_to_factor: int = 128,
    decay_rate: float = 0.8,
    decay_offset: int = 0,
    multiply_by_parameter_scale: bool = True,
    clipping_threshold: Optional[float] = 1.0,
    momentum: Optional[float] = None,
    dtype_momentum: Any = jnp.float32,
    weight_decay_rate: Optional[float] = None,
    eps: float = 1e-30,
    factored: bool = True,
    weight_decay_mask: MaskOrFn = None,
    ) -> base.GradientTransformation:
  """The Adafactor optimiser.

  Adafactor is an adaptive learning rate optimiser that focuses on fast
  training of large scale neural networks. It saves memory by using a factored
  estimate of the second order moments used to scale gradients.

  References:
    Shazeer and Stern, 2018: https://arxiv.org/abs/1804.04235

  Args:
    learning_rate: (float) a step size. Note: the natural scale for
      Adafactor's LR is markedly different from Adam, one doesn't use the
      1/sqrt(hidden) correction for this optim with attention-based models.
    min_dim_size_to_factor: (int) only factor the statistics if two array
      dimensions have at least this size.
    decay_rate: (float) controls second-moment exponential decay schedule.
    decay_offset: (int) for finetuning, one may set this to the starting
      step number of the finetuning phase.
    multiply_by_parameter_scale: (bool) if True, then scale learning_rate by
      parameter norm. if False, provided learning_rate is absolute step size.
    clipping_threshold: (float>=1) optional value; if None, clipping disabled.
    momentum: (float) optional value between 0 and 1, enables
      momentum and uses extra memory if non-None! None by default.
    dtype_momentum: (dtype) dtype of momentum buffers.
    weight_decay_rate: (float) optional rate at which to decay weights.
    eps: (float) regularization constant for root mean squared gradient.
    factored: (bool) whether to use factored second-moment estimates.
    weight_decay_mask: a tree with same structure as (or a prefix of)
      the params PyTree, or a Callable that returns such a pytree given
      the params/updates. The leaves should be booleans, `True`
      for leaves/subtrees you want to apply the transformation to,
      and `False` for those you want to skip.

  Returns:
    the corresponding `GradientTransformation`.
  """
  # The core of the algorithm is a procedure for rescaling gradients
  # by a factored estimate of the root mean squared gradients.
  # This reduces memory compared to algorithms such as Adam or RmsProp,
  # by not having to hold a separate estimate for each weight.
  tx = [
      factorized.scale_by_factored_rms(
          factored, decay_rate, decay_offset, min_dim_size_to_factor, eps)]
  # This basic rescaling is typically combined with one or more of the following
  # transformation (all can be disabled via adafactor's constructor args).
  if clipping_threshold is not None:
    tx.append(clipping.clip_by_block_rms(clipping_threshold))
  if learning_rate is not None:
    # Sign is left positive here; the final `scale(-1)` below flips it once
    # for the whole chain (including momentum and weight decay).
    tx.append(_scale_by_learning_rate(learning_rate, flip_sign=False))
  if multiply_by_parameter_scale:
    tx.append(transform.scale_by_param_block_rms())
  if momentum is not None:
    tx.append(
        transform.ema(momentum, debias=False, accumulator_dtype=dtype_momentum))
  if weight_decay_rate is not None:
    tx.append(transform.add_decayed_weights(
        weight_decay_rate, mask=weight_decay_mask))
  # In gradient "descent" we follow the negative gradient.
  tx.append(transform.scale(-1))
  return combine.chain(*tx)