def test_multi_transform(self, use_fn):
  params = {'a1': 1., 'b1': 2., 'z1': {'a2': 3., 'z2': {'c1': 4.}}}
  params = jax.tree_map(jnp.asarray, params)
  input_updates = jax.tree_map(lambda x: x / 10.0, params)
  tx_dict = {'a': transform.scale(-1.0),
             'b': transform.ema(0.0),  # stateful
             'c': transform.scale(2.0)}
  param_labels = _map_keys_fn(lambda k, _: k[0])
  if not use_fn:
    param_labels = param_labels(params)
  tx = combine.multi_transform(tx_dict, param_labels)
  update_fn = self.variant(tx.update)
  state = self.variant(tx.init)(params)

  correct_update_fn = _map_keys_fn(
      lambda k, v: {'a': -v, 'b': v, 'c': 2.0 * v}[k[0]])

  updates, state = update_fn(input_updates, state, params)
  correct_updates = correct_update_fn(input_updates)
  chex.assert_tree_all_close(updates, correct_updates)

  # Check repeated application, this time with no params.
  correct_updates = correct_update_fn(correct_updates)
  updates, state = update_fn(updates, state)
  chex.assert_tree_all_close(updates, correct_updates)
def test_keep_params_nonnegative(self):
  grads = (jnp.array([500., -500., 0.]),
           jnp.array([500., -500., 0.]),
           jnp.array([500., -500., 0.]))

  params = (jnp.array([-1., -1., -1.]),
            jnp.array([1., 1., 1.]),
            jnp.array([0., 0., 0.]))

  # vanilla sgd
  opt = combine.chain(
      transform.trace(decay=0, nesterov=False), transform.scale(-LR))
  opt_state = opt.init(params)

  updates, _ = opt.update(grads, opt_state, params)
  new_params = update.apply_updates(params, updates)

  chex.assert_tree_all_close(new_params, (jnp.array([-6., 4., -1.]),
                                          jnp.array([-4., 6., 1.]),
                                          jnp.array([-5., 5., 0.])))

  # sgd with keeping parameters non-negative
  opt = combine.chain(
      transform.trace(decay=0, nesterov=False),
      transform.scale(-LR),
      constrain.keep_params_nonnegative())
  opt_state = opt.init(params)

  updates, _ = opt.update(grads, opt_state, params)
  new_params = update.apply_updates(params, updates)

  chex.assert_tree_all_close(new_params, (jnp.array([0., 4., 0.]),
                                          jnp.array([0., 6., 1.]),
                                          jnp.array([0., 5., 0.])))
def rmsprop(learning_rate: float,
            decay: float = 0.9,
            eps: float = 1e-8,
            centered: bool = False) -> GradientTransformation:
  if centered:
    return combine.chain(
        transform.scale_by_stddev(decay=decay, eps=eps),
        transform.scale(-learning_rate),
    )
  return combine.chain(
      transform.scale_by_rms(decay=decay, eps=eps),
      transform.scale(-learning_rate),
  )
def sm3(learning_rate: float,
        momentum: float = 0.9) -> base.GradientTransformation:
  """The SM3 optimiser.

  SM3 (Square-root of Minima of Sums of Maxima of Squared-gradients Method) is
  a memory-efficient adaptive optimiser designed to decrease memory overhead
  when training very large models, such as the Transformer for machine
  translation, BERT for language modelling, and AmoebaNet-D for image
  classification. SM3: 1) applies to tensors of arbitrary dimensions and any
  predefined cover of the parameters; 2) adapts the learning rates in an
  adaptive and data-driven manner (like Adagrad and unlike Adafactor); and
  3) comes with rigorous convergence guarantees in stochastic convex
  optimization settings.

  References:
    Anil et al, 2019: https://arxiv.org/abs/1901.11150

  Args:
    learning_rate: this is a fixed global scaling factor.
    momentum: the `decay` rate used by the momentum term (when it is set to
      `None`, momentum is not used at all).

  Returns:
    the corresponding `GradientTransformation`.
  """
  return combine.chain(
      transform.scale_by_sm3(momentum),
      transform.scale(-learning_rate),
  )
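# Usage sketch (illustrative, not part of the library). It drives the `sm3`
# alias above through a single optimisation step, assuming the module-level
# imports used throughout this file (`jax`, `jnp` for jax.numpy, and `update`
# for the apply_updates helper); the toy parameters and quadratic loss below
# are made up for illustration only.
def _sm3_usage_example():
  params = {'w': jnp.ones((3, 2)), 'b': jnp.zeros((2,))}
  optimiser = sm3(learning_rate=0.1, momentum=0.9)
  opt_state = optimiser.init(params)

  def loss(p):
    # Simple quadratic loss over all leaves of the parameter pytree.
    return jnp.sum(p['w'] ** 2) + jnp.sum(p['b'] ** 2)

  grads = jax.grad(loss)(params)
  updates, opt_state = optimiser.update(grads, opt_state, params)
  return update.apply_updates(params, updates)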
def sgd(learning_rate: float,
        momentum: float = 0.,
        nesterov: bool = False) -> GradientTransformation:
  return combine.chain(
      transform.trace(decay=momentum, nesterov=nesterov),
      transform.scale(-learning_rate),
  )
def test_chain(self):
  transformations = [
      transform.scale_by_adam(),
      transform.trace(decay=0, nesterov=False),
      transform.scale(-LR)]

  # Apply updates with chain.
  chain_params = self.init_params
  chained_transforms = combine.chain(*transformations)
  state = chained_transforms.init(chain_params)

  @self.variant
  def update_fn(updates, state):
    return chained_transforms.update(updates, state)

  for _ in range(STEPS):
    updates, state = update_fn(self.per_step_updates, state)
    chain_params = update.apply_updates(chain_params, updates)

  # Manually apply sequence of transformations.
  manual_params = self.init_params
  states = [t.init(manual_params) for t in transformations]
  for _ in range(STEPS):
    updates = self.per_step_updates
    new_states = []
    for t, s in zip(transformations, states):
      updates, state = t.update(updates, s)
      new_states.append(state)
    manual_params = update.apply_updates(manual_params, updates)
    states = new_states

  # Check equivalence.
  chex.assert_tree_all_close(manual_params, chain_params, rtol=1e-4)
def fromage(learning_rate: float,
            min_norm: float = 1e-6) -> GradientTransformation:
  mult = 1 / jnp.sqrt(1 + learning_rate**2)
  return combine.chain(
      transform.scale_by_trust_ratio(min_norm),
      transform.scale(-learning_rate * mult),
      transform.add_decayed_weights((mult - 1)),
  )
def adam(learning_rate: float,
         b1: float = 0.9,
         b2: float = 0.999,
         eps: float = 1e-8) -> GradientTransformation:
  return combine.chain(
      transform.scale_by_adam(b1=b1, b2=b2, eps=eps),
      transform.scale(-learning_rate),
  )
def adagrad(learning_rate: float,
            initial_accumulator_value: float = 0.1,
            eps: float = 1e-7) -> GradientTransformation:
  return combine.chain(
      transform.scale_by_rss(
          initial_accumulator_value=initial_accumulator_value, eps=eps),
      transform.scale(-learning_rate),
  )
def noisy_sgd(learning_rate: float,
              eta: float = 0.01,
              gamma: float = 0.55,
              seed: int = 0) -> GradientTransformation:
  return combine.chain(
      transform.trace(decay=0., nesterov=False),
      transform.scale(-learning_rate),
      transform.add_noise(eta, gamma, seed),
  )
def radam(learning_rate: float,
          b1: float = 0.9,
          b2: float = 0.999,
          eps: float = 1e-8,
          threshold: float = 5.0) -> GradientTransformation:
  return combine.chain(
      transform.scale_by_radam(b1=b1, b2=b2, eps=eps, threshold=threshold),
      transform.scale(-learning_rate),
  )
def adamw(learning_rate: float,
          b1: float = 0.9,
          b2: float = 0.999,
          eps: float = 1e-8,
          weight_decay: float = 1e-4) -> GradientTransformation:
  return combine.chain(
      transform.scale_by_adam(b1=b1, b2=b2, eps=eps),
      transform.additive_weight_decay(weight_decay),
      transform.scale(-learning_rate),
  )
def lamb(learning_rate: float,
         b1: float = 0.9,
         b2: float = 0.999,
         eps: float = 1e-6,
         eps_root: float = 0.0,
         weight_decay: float = 0.) -> GradientTransformation:
  return combine.chain(
      transform.scale_by_adam(b1=b1, b2=b2, eps=eps, eps_root=eps_root),
      transform.additive_weight_decay(weight_decay),
      transform.scale_by_trust_ratio(),
      transform.scale(-learning_rate),
  )
def test_scale(self):
  updates = self.per_step_updates
  for i in range(1, STEPS + 1):
    factor = 0.1 ** i
    rescaler = transform.scale(factor)
    # Apply rescaling.
    scaled_updates, _ = rescaler.update(updates, None)
    # Manually scale updates.
    def rescale(t):
      return t * factor  # pylint:disable=cell-var-from-loop
    manual_updates = jax.tree_map(rescale, updates)
    # Check the rescaled updates match.
    chex.assert_tree_all_close(scaled_updates, manual_updates)
def test_stateless_inner(self):
  params = jnp.zeros([])
  grads = jnp.ones([])

  def should_update(step):
    return step < MaybeUpdateTest.NUM_STEPS

  opt = wrappers.maybe_update(transform.scale(2.), should_update)
  state = opt.init(params)
  update_fn = self.variant(opt.update)
  for _ in range(MaybeUpdateTest.NUM_STEPS):
    updates, state = update_fn(grads, state)
    self.assertEqual(updates, 2.)
  # Further updates stop calling the inner optimiser.
  for _ in range(5):
    updates, state = update_fn(grads, state)
    self.assertEqual(updates, 1.)
def test_apply_every(self):
  # The frequency of the application of sgd.
  k = 4
  zero_update = (jnp.array([0., 0.]), jnp.array([0., 0.]))

  # optax sgd
  optax_sgd_params = self.init_params
  sgd = alias.sgd(LR, 0.0)
  state_sgd = sgd.init(optax_sgd_params)

  # optax sgd plus apply every
  optax_sgd_apply_every_params = self.init_params
  sgd_apply_every = combine.chain(
      transform.apply_every(k=k),
      transform.trace(decay=0, nesterov=False),
      transform.scale(-LR))
  state_sgd_apply_every = sgd_apply_every.init(optax_sgd_apply_every_params)
  transform_fn = self.variant(sgd_apply_every.update)

  for i in range(STEPS):
    # Apply a step of sgd.
    updates_sgd, state_sgd = sgd.update(self.per_step_updates, state_sgd)
    optax_sgd_params = update.apply_updates(optax_sgd_params, updates_sgd)

    # Apply a step of sgd_apply_every.
    updates_sgd_apply_every, state_sgd_apply_every = transform_fn(
        self.per_step_updates, state_sgd_apply_every)
    optax_sgd_apply_every_params = update.apply_updates(
        optax_sgd_apply_every_params, updates_sgd_apply_every)

    # Every k steps, check equivalence.
    if i % k == k - 1:
      chex.assert_tree_all_close(
          optax_sgd_apply_every_params, optax_sgd_params,
          atol=1e-6, rtol=1e-5)
    # Otherwise, check update is zero.
    else:
      chex.assert_tree_all_close(
          updates_sgd_apply_every, zero_update, atol=0.0, rtol=0.0)
def custom_optim(learning_rate, mask):
  return wrappers.masked(transform.scale(-learning_rate), mask)
def _scale_by_learning_rate(learning_rate: ScalarOrSchedule):
  if callable(learning_rate):
    return transform.scale_by_schedule(lambda count: -learning_rate(count))
  return transform.scale(-learning_rate)
class MaskedTest(chex.TestCase):
  """Tests for the masked wrapper."""

  @chex.all_variants
  @parameterized.named_parameters(
      ('scale', lambda: transform.scale(-1.), True),  # stateless test
      ('sgd', _build_sgd, False),  # stateful test
  )
  def test_masked(self, opt_builder, use_fn):
    mask = {'a': True,
            'b': [False, True],
            'c': {'d': True, 'e': (False, True)}}
    mask_arg = (lambda _: mask) if use_fn else mask
    params = {'a': 1., 'b': [2., 3.], 'c': {'d': 4., 'e': (5., 6.)}}
    params = jax.tree_map(jnp.asarray, params)
    input_updates = jax.tree_map(lambda x: x / 10., params)

    # Negate the updates wherever the mask is True.
    def masked_negate(updates):
      return jax.tree_multimap(
          lambda upd, m: -upd if m else upd, updates, mask)
    correct_updates = masked_negate(input_updates)

    init_fn, update_fn = wrappers.masked(opt_builder(), mask_arg)
    update_fn = self.variant(update_fn)
    state = self.variant(init_fn)(params)
    updates, state = update_fn(input_updates, state, params)
    chex.assert_tree_all_close(updates, correct_updates)

    # Check repeated application, this time with no params.
    correct_updates = masked_negate(correct_updates)
    updates, state = update_fn(updates, state)
    chex.assert_tree_all_close(updates, correct_updates)

  @chex.all_variants
  @parameterized.named_parameters(
      ('scale', lambda: transform.scale(-1.)),  # stateless test
      ('sgd', _build_sgd),  # stateful test
  )
  def test_prefix_mask(self, opt_builder):
    """Test when the mask is a prefix of the updates PyTree."""
    mask = {'a': True, 'b': False, 'c': {'d': False, 'e': True}}
    params = {'a': 1., 'b': {'f': 2.}, 'c': {'d': 3., 'e': ([4., 5.], 6.)}}
    params = jax.tree_map(jnp.asarray, params)
    input_updates = jax.tree_map(lambda x: x / 10., params)

    # Negate the updates wherever the mask (or mask parent) is True.
    def _masked_sgd_on_updates(m, upd):
      return jax.tree_map(lambda x: -x, upd) if m else upd
    correct_updates = jax.tree_multimap(
        _masked_sgd_on_updates, mask, input_updates)

    init_fn, update_fn = wrappers.masked(opt_builder(), mask)
    update_fn = self.variant(update_fn)
    state = self.variant(init_fn)(params)
    updates, state = update_fn(input_updates, state, params)
    chex.assert_tree_all_close(updates, correct_updates)

    # Check repeated application, this time with no params.
    correct_updates = jax.tree_multimap(
        _masked_sgd_on_updates, mask, correct_updates)
    updates, state = update_fn(updates, state)
    chex.assert_tree_all_close(updates, correct_updates)

  @chex.all_variants
  def test_update_requires_params(self):
    weight_decay = 0.1
    mask = {'a': True,
            'b': [False, True],
            'c': {'d': True, 'e': (False, True)}}
    params = {'a': 1., 'b': [2., 3.], 'c': {'d': 4., 'e': (5., 6.)}}
    params = jax.tree_map(jnp.asarray, params)
    input_updates = jax.tree_map(lambda x: x / 10., params)

    correct_updates = jax.tree_multimap(
        lambda m, u, p: u + weight_decay * p if m else u,
        mask, input_updates, params)

    init_fn, update_fn = wrappers.masked(
        transform.additive_weight_decay(weight_decay), mask)
    update_fn = self.variant(update_fn)
    state = self.variant(init_fn)(params)

    updates, state = update_fn(input_updates, state, params)
    chex.assert_tree_all_close(updates, correct_updates)

    params = update.apply_updates(params, updates)

    # Test repeated application.
    new_correct_updates = jax.tree_multimap(
        lambda m, u, p: u + weight_decay * p if m else u,
        mask, correct_updates, params)
    updates, state = update_fn(correct_updates, state, params)
    chex.assert_tree_all_close(updates, new_correct_updates)

  @parameterized.parameters(list, tuple, dict)
  def test_empty(self, container):
    init_fn, update_fn = wrappers.masked(_build_sgd(), container())
    update_fn(container(), init_fn(container()))

  @parameterized.parameters(
      (False, False), (False, True), (True, False), (True, True))
  def test_tree_mismatch_fails(self, extra_key_in_mask, use_fn):
    mask = {'a': True,
            'b': [False, True],
            'c': {'d': True, 'e': (False, True)}}
    mask_arg = (lambda _: mask) if use_fn else mask
    params = {'a': 1., 'b': [2., 3.], 'c': {'d': 4., 'e': (5., 6.)}}
    params = jax.tree_map(jnp.asarray, params)

    if extra_key_in_mask:
      mask['c']['extra'] = True
    else:
      params['c']['extra'] = 7

    init_fn = wrappers.masked(_build_sgd(), mask_arg)[0]
    with self.assertRaises(ValueError):
      init_fn(params)

  @chex.all_variants
  def test_mask_fn(self):
    params = {'a': jnp.ones((1, 2)), 'b': (jnp.ones((1,)), np.ones((1, 2, 3)))}
    mask_fn = lambda p: jax.tree_map(lambda x: x.ndim > 1, p)
    init_fn, update_fn = wrappers.masked(
        transform.add_decayed_weights(0.1), mask_fn)
    update_fn = self.variant(update_fn)

    state = self.variant(init_fn)(params)
    grads = jax.tree_map(lambda x: x * 2, params)
    updates, state = update_fn(grads, state, params)
    np.testing.assert_allclose(updates['a'], grads['a'] + 0.1 * params['a'])
    np.testing.assert_allclose(updates['b'][0], grads['b'][0])
    np.testing.assert_allclose(updates['b'][1],
                               grads['b'][1] + 0.1 * params['b'][1])
class MaskedTest(chex.TestCase):
  """Tests for the masked wrapper."""

  @chex.all_variants
  @parameterized.named_parameters(
      ('scale', lambda: transform.scale(-1.)),  # stateless test
      ('sgd', _build_sgd),  # stateful test
  )
  def test_masked(self, opt_builder):
    mask = {'a': True,
            'b': [False, True],
            'c': {'d': True, 'e': (False, True)}}
    params = {'a': 1, 'b': [2, 3], 'c': {'d': 4, 'e': (5, 6)}}
    input_updates = jax.tree_util.tree_map(lambda x: x / 10., params)

    # Negate the updates wherever the mask is True.
    def masked_negate(updates):
      return jax.tree_util.tree_multimap(
          lambda upd, m: -upd if m else upd, updates, mask)
    correct_updates = masked_negate(input_updates)

    init_fn, update_fn = wrappers.masked(opt_builder(), mask)
    update_fn = self.variant(update_fn)
    state = init_fn(params)
    updates, state = update_fn(input_updates, state, params)
    chex.assert_tree_all_close(updates, correct_updates)

    # Check repeated application, this time with no params.
    correct_updates = masked_negate(correct_updates)
    updates, state = update_fn(updates, state)
    chex.assert_tree_all_close(updates, correct_updates)

  @chex.all_variants
  @parameterized.named_parameters(
      ('scale', lambda: transform.scale(-1.)),  # stateless test
      ('sgd', _build_sgd),  # stateful test
  )
  def test_prefix_mask(self, opt_builder):
    """Test when the mask is a prefix of the updates PyTree."""
    mask = {'a': True, 'b': False, 'c': {'d': False, 'e': True}}
    params = {'a': 1, 'b': {'f': 2}, 'c': {'d': 3, 'e': ([4, 5], 6)}}
    input_updates = jax.tree_util.tree_map(lambda x: x / 10., params)

    # Negate the updates wherever the mask (or mask parent) is True.
    def _masked_sgd_on_updates(m, upd):
      return jax.tree_util.tree_map(lambda x: -x, upd) if m else upd
    correct_updates = jax.tree_util.tree_multimap(
        _masked_sgd_on_updates, mask, input_updates)

    init_fn, update_fn = wrappers.masked(opt_builder(), mask)
    update_fn = self.variant(update_fn)
    state = init_fn(params)
    updates, state = update_fn(input_updates, state, params)
    chex.assert_tree_all_close(updates, correct_updates)

    # Check repeated application, this time with no params.
    correct_updates = jax.tree_util.tree_multimap(
        _masked_sgd_on_updates, mask, correct_updates)
    updates, state = update_fn(updates, state)
    chex.assert_tree_all_close(updates, correct_updates)

  @chex.all_variants
  def test_update_requires_params(self):
    weight_decay = 0.1
    mask = {'a': True,
            'b': [False, True],
            'c': {'d': True, 'e': (False, True)}}
    params = {'a': 1, 'b': [2, 3], 'c': {'d': 4, 'e': (5, 6)}}
    input_updates = jax.tree_util.tree_map(lambda x: x / 10., params)

    correct_updates = jax.tree_util.tree_multimap(
        lambda m, u, p: u + weight_decay * p if m else u,
        mask, input_updates, params)

    init_fn, update_fn = wrappers.masked(
        transform.additive_weight_decay(weight_decay), mask)
    update_fn = self.variant(update_fn)
    state = init_fn(params)

    updates, state = update_fn(input_updates, state, params)
    chex.assert_tree_all_close(updates, correct_updates)

    params = update.apply_updates(params, updates)

    # Test repeated application.
    new_correct_updates = jax.tree_util.tree_multimap(
        lambda m, u, p: u + weight_decay * p if m else u,
        mask, correct_updates, params)
    updates, state = update_fn(correct_updates, state, params)
    chex.assert_tree_all_close(updates, new_correct_updates)

  @parameterized.parameters(list, tuple, dict)
  def test_empty(self, container):
    init_fn, update_fn = wrappers.masked(_build_sgd(), container())
    update_fn(container(), init_fn(container()))

  @parameterized.parameters(True, False)
  def test_tree_mismatch_fails(self, extra_key_in_mask):
    mask = {'a': True,
            'b': [False, True],
            'c': {'d': True, 'e': (False, True)}}
    params = {'a': 1, 'b': [2, 3], 'c': {'d': 4, 'e': (5, 6)}}

    if extra_key_in_mask:
      mask['c']['extra'] = True
    else:
      params['c']['extra'] = 7

    init_fn = wrappers.masked(_build_sgd(), mask)[0]
    with self.assertRaises(ValueError):
      init_fn(params)
def _scale_by_learning_rate(learning_rate: ScalarOrSchedule, flip_sign=True):
  m = -1 if flip_sign else 1
  if callable(learning_rate):
    return transform.scale_by_schedule(lambda count: m * learning_rate(count))
  return transform.scale(m * learning_rate)
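# Illustrative sketch (not part of the library) of how the helper above
# behaves: a callable learning rate is wrapped in `scale_by_schedule`, with
# the sign flip baked into the schedule, while a constant becomes a plain
# `scale`. The toy schedule below is made up for illustration only.
def _scale_by_learning_rate_example():
  constant_tx = _scale_by_learning_rate(0.1)            # equivalent to scale(-0.1)
  schedule_tx = _scale_by_learning_rate(
      lambda count: 0.1 / (1.0 + count))                # scale_by_schedule(-lr(count))
  no_flip_tx = _scale_by_learning_rate(0.1, flip_sign=False)  # scale(0.1)
  return constant_tx, schedule_tx, no_flip_tx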
def adafactor(
    learning_rate: Optional[ScalarOrSchedule] = None,
    min_dim_size_to_factor: int = 128,
    decay_rate: float = 0.8,
    decay_offset: int = 0,
    multiply_by_parameter_scale: bool = True,
    clipping_threshold: Optional[float] = 1.0,
    momentum: Optional[float] = None,
    dtype_momentum: Any = jnp.float32,
    weight_decay_rate: Optional[float] = None,
    eps: float = 1e-30,
    factored: bool = True,
    weight_decay_mask: MaskOrFn = None,
) -> base.GradientTransformation:
  """The Adafactor optimiser.

  Adafactor is an adaptive learning rate optimiser that focuses on fast
  training of large scale neural networks. It saves memory by using a factored
  estimate of the second order moments used to scale gradients.

  References:
    Shazeer and Stern, 2018: https://arxiv.org/abs/1804.04235

  Args:
    learning_rate: (float) a step size. Note: the natural scale for
      Adafactor's LR is markedly different from Adam; one doesn't use the
      1/sqrt(hidden) correction for this optimiser with attention-based models.
    min_dim_size_to_factor: (int) only factor the statistics if two array
      dimensions have at least this size.
    decay_rate: (float) controls second-moment exponential decay schedule.
    decay_offset: (int) for finetuning, one may set this to the starting step
      number of the finetuning phase.
    multiply_by_parameter_scale: (bool) if True, then scale learning_rate by
      parameter norm. If False, the provided learning_rate is an absolute step
      size.
    clipping_threshold: (float>=1) optional value; if None, clipping disabled.
    momentum: (float) optional value between 0 and 1, enables momentum and
      uses extra memory if non-None! None by default.
    dtype_momentum: (dtype) dtype of momentum buffers.
    weight_decay_rate: (float) optional rate at which to decay weights.
    eps: (float) regularization constant for root mean squared gradient.
    factored: (bool) whether to use factored second-moment estimates.
    weight_decay_mask: a tree with same structure as (or a prefix of) the
      params PyTree, or a Callable that returns such a pytree given the
      params/updates. The leaves should be booleans, `True` for leaves/subtrees
      you want to apply the transformation to, and `False` for those you want
      to skip.

  Returns:
    the corresponding `GradientTransformation`.
  """
  # The core of the algorithm is a procedure for rescaling gradients
  # by a factored estimate of the root mean squared gradients.
  # This reduces memory compared to algorithms such as Adam or RMSProp,
  # by not having to hold a separate estimate for each weight.
  tx = [
      factorized.scale_by_factored_rms(
          factored, decay_rate, decay_offset, min_dim_size_to_factor, eps)]
  # This basic rescaling is typically combined with one or more of the
  # following transformations (all can be disabled via adafactor's
  # constructor args).
  if clipping_threshold is not None:
    tx.append(clipping.clip_by_block_rms(clipping_threshold))
  if learning_rate is not None:
    tx.append(_scale_by_learning_rate(learning_rate, flip_sign=False))
  if multiply_by_parameter_scale:
    tx.append(transform.scale_by_param_block_rms())
  if momentum is not None:
    tx.append(transform.ema(
        momentum, debias=False, accumulator_dtype=dtype_momentum))
  if weight_decay_rate is not None:
    tx.append(transform.add_decayed_weights(
        weight_decay_rate, mask=weight_decay_mask))
  # In gradient "descent" we follow the negative gradient.
  tx.append(transform.scale(-1))
  return combine.chain(*tx)
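# Usage sketch (illustrative, not part of the library). With the arguments
# below, the alias above chains factored RMS scaling, block-RMS clipping,
# learning-rate scaling, parameter-scale multiplication, momentum, and the
# final sign flip. It assumes the module-level imports used throughout this
# file (`jax`, `jnp`, `update`); the parameter shapes and stand-in gradients
# are made up for illustration only.
def _adafactor_usage_example():
  params = {'embed': jnp.ones((256, 128)), 'bias': jnp.zeros((128,))}
  optimiser = adafactor(learning_rate=1e-3, momentum=0.9)
  opt_state = optimiser.init(params)
  # Stand-in for real gradients; note that `params` must be passed to
  # `update` because parameter-scale multiplication reads the parameters.
  grads = jax.tree_map(jnp.ones_like, params)
  updates, opt_state = optimiser.update(grads, opt_state, params)
  new_params = update.apply_updates(params, updates)
  return new_params, opt_state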