def test_update_requires_params(self):
    """Checks `wrappers.masked` around a transform whose update needs `params`.

    `transform.additive_weight_decay` is wrapped with a nested boolean mask.
    The expected result adds `weight_decay * param` to every update leaf whose
    mask entry is True and leaves the other leaves untouched. The check is run
    twice, the second time with the parameters already moved by the first
    update, to exercise repeated application with the same state.
    """
    weight_decay = 0.1
    mask = {'a': True, 'b': [False, True], 'c': {'d': True, 'e': (False, True)}}
    params = {'a': 1, 'b': [2, 3], 'c': {'d': 4, 'e': (5, 6)}}
    input_updates = jax.tree_util.tree_map(lambda x: x / 10., params)

    def expected_updates(updates, current_params):
        # `tree_map` accepts several trees of identical structure; the removed
        # `tree_multimap` alias is no longer needed for this.
        return jax.tree_util.tree_map(
            lambda m, u, p: u + weight_decay * p if m else u,
            mask, updates, current_params)

    correct_updates = expected_updates(input_updates, params)

    init_fn, update_fn = wrappers.masked(
        transform.additive_weight_decay(weight_decay), mask)
    update_fn = self.variant(update_fn)

    state = init_fn(params)
    updates, state = update_fn(input_updates, state, params)
    chex.assert_tree_all_close(updates, correct_updates)

    params = update.apply_updates(params, updates)

    # Test repeated application
    new_correct_updates = expected_updates(correct_updates, params)
    updates, state = update_fn(correct_updates, state, params)
    chex.assert_tree_all_close(updates, new_correct_updates)
def adamw(learning_rate: float,
          b1: float = 0.9,
          b2: float = 0.999,
          eps: float = 1e-8,
          weight_decay: float = 1e-4) -> GradientTransformation:
    """Builds AdamW as a chain of three gradient transformations.

    The chain applies, in order: `transform.scale_by_adam`,
    `transform.additive_weight_decay` and `transform.scale`.

    Args:
      learning_rate: step size; it is negated before being handed to
        `transform.scale`.
      b1: forwarded to `transform.scale_by_adam` as `b1`.
      b2: forwarded to `transform.scale_by_adam` as `b2`.
      eps: forwarded to `transform.scale_by_adam` as `eps`.
      weight_decay: forwarded to `transform.additive_weight_decay`.

    Returns:
      The result of `combine.chain` over the three steps.
    """
    adam_step = transform.scale_by_adam(b1=b1, b2=b2, eps=eps)
    decay_step = transform.additive_weight_decay(weight_decay)
    # The sign flip happens here rather than inside the scaling transform.
    lr_step = transform.scale(-learning_rate)
    return combine.chain(adam_step, decay_step, lr_step)
def adamw(learning_rate: ScalarOrSchedule,
          b1: float = 0.9,
          b2: float = 0.999,
          eps: float = 1e-8,
          eps_root: float = 0.0,
          weight_decay: float = 1e-4) -> GradientTransformation:
    """Builds AdamW as a chain of three gradient transformations.

    The chain applies, in order: `transform.scale_by_adam`,
    `transform.additive_weight_decay` and `_scale_by_learning_rate`.

    Args:
      learning_rate: a scalar or a schedule (per the annotation); passed
        unchanged to `_scale_by_learning_rate`.
      b1: forwarded to `transform.scale_by_adam` as `b1`.
      b2: forwarded to `transform.scale_by_adam` as `b2`.
      eps: forwarded to `transform.scale_by_adam` as `eps`.
      eps_root: forwarded to `transform.scale_by_adam` as `eps_root`.
      weight_decay: forwarded to `transform.additive_weight_decay`.

    Returns:
      The result of `combine.chain` over the three steps.
    """
    steps = [
        transform.scale_by_adam(b1=b1, b2=b2, eps=eps, eps_root=eps_root),
        transform.additive_weight_decay(weight_decay),
        _scale_by_learning_rate(learning_rate),
    ]
    return combine.chain(*steps)
def lamb(learning_rate: float,
         b1: float = 0.9,
         b2: float = 0.999,
         eps: float = 1e-6,
         eps_root: float = 0.0,
         weight_decay: float = 0.) -> GradientTransformation:
    """Builds the LAMB update rule as a chain of four gradient transformations.

    The chain applies, in order: `transform.scale_by_adam`,
    `transform.additive_weight_decay`, `transform.scale_by_trust_ratio` and
    `transform.scale`.

    Args:
      learning_rate: step size; it is negated before being handed to
        `transform.scale`.
      b1: forwarded to `transform.scale_by_adam` as `b1`.
      b2: forwarded to `transform.scale_by_adam` as `b2`.
      eps: forwarded to `transform.scale_by_adam` as `eps`.
      eps_root: forwarded to `transform.scale_by_adam` as `eps_root`.
      weight_decay: forwarded to `transform.additive_weight_decay`.

    Returns:
      The result of `combine.chain` over the four steps.
    """
    pipeline = (
        transform.scale_by_adam(b1=b1, b2=b2, eps=eps, eps_root=eps_root),
        transform.additive_weight_decay(weight_decay),
        transform.scale_by_trust_ratio(),
        # Negated here so the final step receives `-learning_rate`.
        transform.scale(-learning_rate),
    )
    return combine.chain(*pipeline)
def adamw(
    learning_rate: ScalarOrSchedule,
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-8,
    eps_root: float = 0.0,
    weight_decay: float = 1e-4
) -> base.GradientTransformation:
    """Adam combined with a weight decay step (AdamW).

    Weight decay regularises learning towards small weights, which leads to
    better generalisation. With SGD the same effect can be obtained through an
    additive L2 loss term, but L2 regularisation does not behave as intended
    for adaptive gradient algorithms such as Adam, so the decay is applied as
    its own step in the chain.

    WARNING: You may want to exclude some parameters (for instance BatchNorm
    scales or biases) from weight decay. `optax.masked` lets you assemble your
    own AdamW variant in which `additive_weight_decay` only touches a subset
    of `params`.

    References:
      Loshchilov et al, 2019: https://arxiv.org/abs/1711.05101

    Args:
      learning_rate: a scalar or a schedule (per the annotation); passed
        unchanged to `_scale_by_learning_rate`.
      b1: the exponential decay rate to track the first moment of past
        gradients.
      b2: the exponential decay rate to track the second moment of past
        gradients.
      eps: a small constant added to the denominator outside of the square
        root (as in the Adam paper) to avoid dividing by zero when rescaling.
      eps_root: (default `0`), a small constant added to the denominator
        inside the square root (as in RMSProp), to avoid dividing by zero when
        rescaling. Needed for instance when computing (meta-)gradients through
        Adam.
      weight_decay: strength of the weight decay regularization.

    Returns:
      the corresponding `GradientTransformation`.
    """
    adam_step = transform.scale_by_adam(
        b1=b1, b2=b2, eps=eps, eps_root=eps_root)
    decay_step = transform.additive_weight_decay(weight_decay)
    lr_step = _scale_by_learning_rate(learning_rate)
    return combine.chain(adam_step, decay_step, lr_step)