def test_multi_transform(self, use_fn):
  params = {'a1': 1., 'b1': 2., 'z1': {'a2': 3., 'z2': {'c1': 4.}}}
  params = jax.tree_map(jnp.asarray, params)
  input_updates = jax.tree_map(lambda x: x / 10.0, params)
  tx_dict = {
      'a': transform.scale(-1.0),
      'b': transform.ema(0.0),  # stateful
      'c': transform.scale(2.0)
  }
  param_labels = _map_keys_fn(lambda k, _: k[0])
  if not use_fn:
    param_labels = param_labels(params)
  tx = combine.multi_transform(tx_dict, param_labels)
  update_fn = self.variant(tx.update)
  state = self.variant(tx.init)(params)

  correct_update_fn = _map_keys_fn(
      lambda k, v: {'a': -v, 'b': v, 'c': 2.0 * v}[k[0]])

  updates, state = update_fn(input_updates, state, params)
  correct_updates = correct_update_fn(input_updates)
  chex.assert_tree_all_close(updates, correct_updates)

  # Check repeated application, this time with no params.
  correct_updates = correct_update_fn(correct_updates)
  updates, state = update_fn(updates, state)
  chex.assert_tree_all_close(updates, correct_updates)
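# A minimal usage sketch of `combine.multi_transform`, hedged against the
# optax-style API exercised above; the explicit `param_labels` pytree and the
# `grads` name are illustrative, not part of the test. Each leaf's update is
# routed to the transform matching its label, so 'a'-labelled leaves are
# negated, 'b'-labelled leaves pass through the stateful zero-decay EMA
# unchanged, and 'c'-labelled leaves are doubled:
#
#   tx = combine.multi_transform(
#       {'a': transform.scale(-1.0),
#        'b': transform.ema(0.0),
#        'c': transform.scale(2.0)},
#       param_labels={'a1': 'a', 'b1': 'b',
#                     'z1': {'a2': 'a', 'z2': {'c1': 'c'}}})
#   state = tx.init(params)
#   updates, state = tx.update(grads, state, params)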
def test_ema(self):
  values = jnp.array([5.0, 7.0])
  decay = 0.9
  d = decay

  ema = transform.ema(decay=decay, debias=False)
  state = ema.init(values[0])  # init to zeroes

  transform_fn = self.variant(ema.update)
  mean, state = transform_fn(values[0], state)
  np.testing.assert_allclose(mean, (1 - d) * values[0], atol=1e-4)

  mean, state = transform_fn(values[1], state)
  np.testing.assert_allclose(
      mean, (1 - d) * (values[1] + d * values[0]), atol=1e-2)
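# Worked arithmetic for the non-debiased assertions above, assuming the EMA
# accumulator starts at zero and follows e_t = (1 - d) * x_t + d * e_{t-1}:
#
#   d = 0.9; x0, x1 = 5.0, 7.0
#   e1 = (1 - d) * x0            # 0.5  == (1 - d) * values[0]
#   e2 = (1 - d) * x1 + d * e1   # 1.15 == (1 - d) * (values[1] + d * values[0])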
def test_ema_debias(self):
  values = jnp.array([5.0, 7.0])
  decay = 0.9
  d = decay

  ema = transform.ema(decay=decay)
  state = ema.init(values[0])

  transform_fn = self.variant(ema.update)
  mean, state = transform_fn(values[0], state)
  np.testing.assert_allclose(mean, values[0], atol=1e-4)

  mean, state = transform_fn(values[1], state)
  np.testing.assert_allclose(
      mean,
      ((1 - d) * values[1] + d * (1 - d) * values[0]) / (1 - d**2),
      atol=1e-2)
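# Worked arithmetic for the debiased assertions above, assuming the same
# recurrence as in test_ema plus a bias correction that divides the raw EMA
# by (1 - d**t) after t updates:
#
#   e1 = (1 - d) * x0            # 0.5
#   e1 / (1 - d)                 # 5.0  == values[0] (first assertion)
#   e2 = (1 - d) * x1 + d * e1   # 1.15
#   e2 / (1 - d**2)              # ~6.05 (second assertion)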
def adafactor(
    learning_rate: Optional[ScalarOrSchedule] = None,
    min_dim_size_to_factor: int = 128,
    decay_rate: float = 0.8,
    decay_offset: int = 0,
    multiply_by_parameter_scale: bool = True,
    clipping_threshold: Optional[float] = 1.0,
    momentum: Optional[float] = None,
    dtype_momentum: Any = jnp.float32,
    weight_decay_rate: Optional[float] = None,
    eps: float = 1e-30,
    factored: bool = True,
    weight_decay_mask: MaskOrFn = None,
) -> base.GradientTransformation:
  """The Adafactor optimiser.

  Adafactor is an adaptive learning rate optimiser that focuses on fast
  training of large scale neural networks. It saves memory by using a factored
  estimate of the second order moments used to scale gradients.

  References:
    Shazeer and Stern, 2018: https://arxiv.org/abs/1804.04235

  Args:
    learning_rate: (float) a step size. Note: the natural scale for Adafactor's
      learning rate is markedly different from Adam's; one does not use the
      1/sqrt(hidden) correction for this optimiser with attention-based models.
    min_dim_size_to_factor: (int) only factor the statistics if two array
      dimensions have at least this size.
    decay_rate: (float) controls the second-moment exponential decay schedule.
    decay_offset: (int) for fine-tuning, one may set this to the starting step
      number of the fine-tuning phase.
    multiply_by_parameter_scale: (bool) if True, scale the learning rate by the
      parameter norm; if False, the provided learning rate is an absolute step
      size.
    clipping_threshold: (float >= 1) optional value; if None, clipping is
      disabled.
    momentum: (float) optional value between 0 and 1; enables momentum and uses
      extra memory if non-None. None by default.
    dtype_momentum: (dtype) dtype of the momentum buffers.
    weight_decay_rate: (float) optional rate at which to decay weights.
    eps: (float) regularization constant for the root mean squared gradient.
    factored: (bool) whether to use factored second-moment estimates.
    weight_decay_mask: a tree with the same structure as (or a prefix of) the
      params pytree, or a Callable that returns such a pytree given the
      params/updates. The leaves should be booleans: `True` for leaves/subtrees
      you want to apply the weight decay to, and `False` for those you want to
      skip.

  Returns:
    The corresponding `GradientTransformation`.
  """
  # The core of the algorithm is a procedure for rescaling gradients
  # by a factored estimate of the root mean squared gradients.
  # This reduces memory compared to algorithms such as Adam or RmsProp,
  # by not having to hold a separate estimate for each weight.
  tx = [
      factorized.scale_by_factored_rms(
          factored, decay_rate, decay_offset, min_dim_size_to_factor, eps)]
  # This basic rescaling is typically combined with one or more of the
  # following transformations (all can be disabled via adafactor's
  # constructor args).
  if clipping_threshold is not None:
    tx.append(clipping.clip_by_block_rms(clipping_threshold))
  if learning_rate is not None:
    tx.append(_scale_by_learning_rate(learning_rate, flip_sign=False))
  if multiply_by_parameter_scale:
    tx.append(transform.scale_by_param_block_rms())
  if momentum is not None:
    tx.append(
        transform.ema(momentum, debias=False,
                      accumulator_dtype=dtype_momentum))
  if weight_decay_rate is not None:
    tx.append(transform.add_decayed_weights(
        weight_decay_rate, mask=weight_decay_mask))
  # In gradient "descent" we follow the negative gradient.
  tx.append(transform.scale(-1))
  return combine.chain(*tx)
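# A minimal usage sketch (an assumption about the public optax package, not
# part of this module): the returned `GradientTransformation` is driven like
# any other optax optimiser, and `params` must be passed to `update` because
# `multiply_by_parameter_scale` defaults to True. `loss_fn` below is a
# hypothetical scalar-valued loss.
#
#   import jax
#   import jax.numpy as jnp
#   import optax
#
#   params = {'w': jnp.ones((1024, 1024))}
#   opt = optax.adafactor(learning_rate=1e-3)
#   opt_state = opt.init(params)
#   grads = jax.grad(loss_fn)(params)
#   updates, opt_state = opt.update(grads, opt_state, params)
#   params = optax.apply_updates(params, updates)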