Example #1
    def test_multi_transform(self, use_fn):
        params = {'a1': 1., 'b1': 2., 'z1': {'a2': 3., 'z2': {'c1': 4.}}}
        params = jax.tree_map(jnp.asarray, params)
        input_updates = jax.tree_map(lambda x: x / 10.0, params)
        tx_dict = {
            'a': transform.scale(-1.0),
            'b': transform.ema(0.0),  # stateful
            'c': transform.scale(2.0)
        }
        # Label each leaf by the first character of its key ('a', 'b' or 'c').
        param_labels = _map_keys_fn(lambda k, _: k[0])
        if not use_fn:
            # Pass a concrete label pytree rather than a labelling function.
            param_labels = param_labels(params)
        tx = combine.multi_transform(tx_dict, param_labels)
        update_fn = self.variant(tx.update)
        state = self.variant(tx.init)(params)

        correct_update_fn = _map_keys_fn(lambda k, v: {
            'a': -v,
            'b': v,
            'c': 2.0 * v
        }[k[0]])

        updates, state = update_fn(input_updates, state, params)
        correct_updates = correct_update_fn(input_updates)
        chex.assert_tree_all_close(updates, correct_updates)

        # Check repeated application, this time with no params.
        correct_updates = correct_update_fn(correct_updates)
        updates, state = update_fn(updates, state)
        chex.assert_tree_all_close(updates, correct_updates)
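For context, `multi_transform` routes each parameter to the transformation whose label it carries, and accepts either a label pytree or a labelling function. Below is a minimal standalone sketch using the public `optax` API rather than the internal `transform`/`combine` modules of the test; the parameter names and learning rate are illustrative only.

import jax.numpy as jnp
import optax

params = {'w': jnp.ones((2,)), 'b': jnp.zeros((2,))}
grads = {'w': jnp.full((2,), 0.5), 'b': jnp.full((2,), 0.5)}

# Route 'w' through SGD and flip the sign of the updates for 'b'.
tx = optax.multi_transform(
    {'sgd': optax.sgd(0.1), 'flip': optax.scale(-1.0)},
    param_labels={'w': 'sgd', 'b': 'flip'})

state = tx.init(params)
updates, state = tx.update(grads, state, params)
new_params = optax.apply_updates(params, updates)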
Example #2
    def test_ema(self):
        values = jnp.array([5.0, 7.0])
        decay = 0.9
        d = decay

        ema = transform.ema(decay=decay, debias=False)
        state = ema.init(values[0])  # init to zeroes

        transform_fn = self.variant(ema.update)
        mean, state = transform_fn(values[0], state)
        np.testing.assert_allclose(mean, (1 - d) * values[0], atol=1e-4)

        mean, state = transform_fn(values[1], state)
        np.testing.assert_allclose(mean, (1 - d) * (values[1] + d * values[0]),
                                   atol=1e-2)
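The expected values above follow directly from the EMA recurrence m_t = (1 - d) * v_t + d * m_{t-1} with m_0 = 0. A quick plain-Python check of that arithmetic (no optax involved):

d = 0.9
v = [5.0, 7.0]

m = 0.0  # the accumulator starts at zero, matching ema.init
for value in v:
    m = (1 - d) * value + d * m

# After two updates: m == (1 - d) * (v[1] + d * v[0])  (about 1.15 here).
assert abs(m - (1 - d) * (v[1] + d * v[0])) < 1e-9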
Example #3
    def test_ema_debias(self):
        values = jnp.array([5.0, 7.0])
        decay = 0.9
        d = decay

        ema = transform.ema(decay=decay)
        state = ema.init(values[0])

        transform_fn = self.variant(ema.update)
        mean, state = transform_fn(values[0], state)
        np.testing.assert_allclose(mean, values[0], atol=1e-4)

        mean, state = transform_fn(values[1], state)
        np.testing.assert_allclose(
            mean,
            ((1 - d) * values[1] + d * (1 - d) * values[0]) / (1 - d**2),
            atol=1e-2)
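Debiasing divides the raw moment by 1 - d**t after t updates, which removes the bias introduced by the zero initialisation. The same plain-Python check, extended with that correction:

d = 0.9
v = [5.0, 7.0]

m, debiased = 0.0, 0.0
for t, value in enumerate(v, start=1):
    m = (1 - d) * value + d * m
    debiased = m / (1 - d**t)  # bias correction after t updates

# One update recovers v[0] exactly; two updates give
# (1 - d) * (v[1] + d * v[0]) / (1 - d**2), matching the test above.
assert abs(debiased - (1 - d) * (v[1] + d * v[0]) / (1 - d**2)) < 1e-9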
Example #4
def adafactor(
    learning_rate: Optional[ScalarOrSchedule] = None,
    min_dim_size_to_factor: int = 128,
    decay_rate: float = 0.8,
    decay_offset: int = 0,
    multiply_by_parameter_scale: bool = True,
    clipping_threshold: Optional[float] = 1.0,
    momentum: Optional[float] = None,
    dtype_momentum: Any = jnp.float32,
    weight_decay_rate: Optional[float] = None,
    eps: float = 1e-30,
    factored: bool = True,
    weight_decay_mask: MaskOrFn = None,
    ) -> base.GradientTransformation:
  """The Adafactor optimiser.

  Adafactor is an adaptive learning rate optimiser that focuses on fast
  training of large scale neural networks. It saves memory by using a factored
  estimate of the second order moments used to scale gradients.

  References:
    Shazeer and Stern, 2018: https://arxiv.org/abs/1804.04235

  Args:
      learning_rate: (float) a step size. Note: the natural scale for
        Adafactor's LR is markedly different from Adam's; one does not use the
        1/sqrt(hidden) correction for this optimiser with attention-based
        models.
      min_dim_size_to_factor: (int) only factor the statistics if two array
        dimensions have at least this size.
      decay_rate: (float) controls second-moment exponential decay schedule.
      decay_offset: (int) for finetuning, one may set this to the starting
        step number of the finetuning phase.
      multiply_by_parameter_scale: (bool) if True, then scale learning_rate by
        parameter norm; if False, the provided learning_rate is an absolute
        step size.
      clipping_threshold: (float>=1) optional value; if None, clipping disabled.
      momentum: (float) optional value between 0 and 1; enables
        momentum and uses extra memory if not None. None by default.
      dtype_momentum: (dtype) dtype of momentum buffers.
      weight_decay_rate: (float) optional rate at which to decay weights.
      eps: (float) regularization constant for root mean squared gradient.
      factored: (bool) whether to use factored second-moment estimates.
      weight_decay_mask: a tree with same structure as (or a prefix of)
        the params PyTree, or a Callable that returns such a pytree given
        the params/updates. The leaves should be booleans, `True`
        for leaves/subtrees you want to apply the transformation to,
        and `False` for those you want to skip.

  Returns:
    the corresponding `GradientTransformation`.
  """
  # The core of the algorithm is a procedure for rescaling gradients
  # by a factored estimate of the root mean squared gradients.
  # This reduces memory compared to algorithms such as Adam or RmsProp,
  # by not having to hold a separate estimate for each weight.
  tx = [
      factorized.scale_by_factored_rms(
          factored, decay_rate, decay_offset, min_dim_size_to_factor, eps)]
  # This basic rescaling is typically combined with one or more of the
  # following transformations (all can be disabled via adafactor's
  # constructor args).
  if clipping_threshold is not None:
    tx.append(clipping.clip_by_block_rms(clipping_threshold))
  if learning_rate is not None:
    tx.append(_scale_by_learning_rate(learning_rate, flip_sign=False))
  if multiply_by_parameter_scale:
    tx.append(transform.scale_by_param_block_rms())
  if momentum is not None:
    tx.append(
        transform.ema(momentum, debias=False, accumulator_dtype=dtype_momentum))
  if weight_decay_rate is not None:
    tx.append(transform.add_decayed_weights(
        weight_decay_rate, mask=weight_decay_mask))
  # In gradient "descent" we follow the negative gradient.
  tx.append(transform.scale(-1))
  return combine.chain(*tx)
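A minimal usage sketch of the resulting optimiser through the public `optax` API; the parameter shapes, toy loss, and constant learning rate below are illustrative only.

import jax
import jax.numpy as jnp
import optax

params = {'w': jnp.ones((512, 256)), 'b': jnp.zeros((256,))}

def loss_fn(p):
    # Toy quadratic loss, just to produce gradients.
    return jnp.sum(p['w'] ** 2) + jnp.sum(p['b'] ** 2)

opt = optax.adafactor(learning_rate=1e-3)
opt_state = opt.init(params)

grads = jax.grad(loss_fn)(params)
# params must be passed: multiply_by_parameter_scale is True by default.
updates, opt_state = opt.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)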