def test_prefix_mask(self, opt_builder):
  """Test when the mask is a prefix of the updates PyTree."""
  mask = {'a': True, 'b': False, 'c': {'d': False, 'e': True}}
  params = {'a': 1., 'b': {'f': 2.}, 'c': {'d': 3., 'e': ([4., 5.], 6.)}}
  params = jax.tree_map(jnp.asarray, params)
  input_updates = jax.tree_map(lambda x: x / 10., params)

  # Negate the updates wherever the mask (or mask parent) is True
  def _masked_sgd_on_updates(m, upd):
    return jax.tree_map(lambda x: -x, upd) if m else upd
  correct_updates = jax.tree_multimap(
      _masked_sgd_on_updates, mask, input_updates)

  init_fn, update_fn = wrappers.masked(opt_builder(), mask)
  update_fn = self.variant(update_fn)
  state = self.variant(init_fn)(params)
  updates, state = update_fn(input_updates, state, params)
  chex.assert_tree_all_close(updates, correct_updates)

  # Check repeated application, this time with no params.
  correct_updates = jax.tree_multimap(
      _masked_sgd_on_updates, mask, correct_updates)
  updates, state = update_fn(updates, state)
  chex.assert_tree_all_close(updates, correct_updates)
def add_decayed_weights(
    weight_decay: float = 0.0,
    mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None
) -> base.GradientTransformation:
  """Add the params scaled by `weight_decay` to the updates.

  Args:
    weight_decay: a scalar weight decay rate.
    mask: a tree with same structure as (or a prefix of) the params PyTree,
      or a Callable that returns such a pytree given the params/updates.
      The leaves should be booleans, `True` for leaves/subtrees you want to
      apply the transformation to, and `False` for those you want to skip.

  Returns:
    An (init_fn, update_fn) tuple.
  """

  def init_fn(_):
    return AddDecayedWeightsState()

  def update_fn(updates, state, params):
    if params is None:
      raise ValueError(base.NO_PARAMS_MSG)
    updates = jax.tree_multimap(
        lambda g, p: g + weight_decay * p, updates, params)
    return updates, state

  # If mask is not `None`, apply mask to the gradient transformation.
  # E.g. it is common to skip weight decay on bias units and batch stats.
  if mask is not None:
    return wrappers.masked(
        base.GradientTransformation(init_fn, update_fn), mask)
  return base.GradientTransformation(init_fn, update_fn)
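# Usage sketch (illustrative, not from the tests below): a boolean mask
# restricts weight decay to the weight kernel while the bias passes through.
# Assumes the public `optax` namespace, which re-exports this function.
import jax
import jax.numpy as jnp
import optax

params = {'w': jnp.ones((3, 2)), 'b': jnp.zeros((2,))}
decay = optax.add_decayed_weights(1e-4, mask={'w': True, 'b': False})
state = decay.init(params)
grads = jax.tree_map(jnp.ones_like, params)
# updates['w'] == grads['w'] + 1e-4 * params['w']; updates['b'] is unchanged.
updates, state = decay.update(grads, state, params)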
def test_update_requires_params(self):
  weight_decay = 0.1
  mask = {'a': True,
          'b': [False, True],
          'c': {'d': True, 'e': (False, True)}}
  params = {'a': 1, 'b': [2, 3], 'c': {'d': 4, 'e': (5, 6)}}
  input_updates = jax.tree_util.tree_map(lambda x: x / 10., params)
  correct_updates = jax.tree_util.tree_multimap(
      lambda m, u, p: u + weight_decay * p if m else u,
      mask, input_updates, params)

  init_fn, update_fn = wrappers.masked(
      transform.additive_weight_decay(weight_decay), mask)
  update_fn = self.variant(update_fn)

  state = init_fn(params)
  updates, state = update_fn(input_updates, state, params)
  chex.assert_tree_all_close(updates, correct_updates)

  params = update.apply_updates(params, updates)

  # Test repeated application
  new_correct_updates = jax.tree_util.tree_multimap(
      lambda m, u, p: u + weight_decay * p if m else u,
      mask, correct_updates, params)
  updates, state = update_fn(correct_updates, state, params)
  chex.assert_tree_all_close(updates, new_correct_updates)
def test_masked(self, opt_builder, use_fn):
  mask = {'a': True,
          'b': [False, True],
          'c': {'d': True, 'e': (False, True)}}
  # Parenthesised so the `use_fn` case passes a callable and the other case
  # passes the raw tree; the unparenthesised form always returned a callable.
  mask_arg = (lambda _: mask) if use_fn else mask
  params = {'a': 1., 'b': [2., 3.], 'c': {'d': 4., 'e': (5., 6.)}}
  params = jax.tree_map(jnp.asarray, params)
  input_updates = jax.tree_map(lambda x: x / 10., params)

  # Negate the updates wherever the mask is True
  def masked_negate(updates):
    return jax.tree_multimap(
        lambda upd, m: -upd if m else upd, updates, mask)
  correct_updates = masked_negate(input_updates)

  init_fn, update_fn = wrappers.masked(opt_builder(), mask_arg)
  update_fn = self.variant(update_fn)
  state = self.variant(init_fn)(params)
  updates, state = update_fn(input_updates, state, params)
  chex.assert_tree_all_close(updates, correct_updates)

  # Check repeated application, this time with no params.
  correct_updates = masked_negate(correct_updates)
  updates, state = update_fn(updates, state)
  chex.assert_tree_all_close(updates, correct_updates)
def update_fn(updates, state, params=None):
  labels = param_labels(updates) if callable(param_labels) else param_labels
  new_inner_state = {}
  for group, tx in transforms.items():
    masked_tx = wrappers.masked(tx, make_mask(labels, group))
    updates, new_inner_state[group] = masked_tx.update(
        updates, state.inner_states[group], params)
  return updates, MultiTransformState(new_inner_state)
def test_tree_mismatch_fails(self, extra_key_in_mask):
  mask = {'a': True,
          'b': [False, True],
          'c': {'d': True, 'e': (False, True)}}
  params = {'a': 1, 'b': [2, 3], 'c': {'d': 4, 'e': (5, 6)}}

  if extra_key_in_mask:
    mask['c']['extra'] = True
  else:
    params['c']['extra'] = 7

  init_fn = wrappers.masked(_build_sgd(), mask)[0]
  with self.assertRaises(ValueError):
    init_fn(params)
def init_fn(params):
  labels = param_labels(params) if callable(param_labels) else param_labels

  label_set = set(jax.tree_leaves(labels))
  if not label_set.issubset(transforms.keys()):
    raise ValueError(
        'Some parameters have no corresponding transformation.\n'
        f'Parameter labels: {sorted(label_set)}\n'
        f'Transforms keys: {sorted(transforms.keys())}\n')

  inner_states = {
      group: wrappers.masked(tx, make_mask(labels, group)).init(params)
      for group, tx in transforms.items()
  }
  return MultiTransformState(inner_states)
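# Taken together, `init_fn` and `update_fn` above implement the labelled-
# partition pattern behind `multi_transform`: each label selects a masked
# sub-optimiser. A minimal sketch with illustrative labels and sub-optimisers,
# assuming the public `optax.multi_transform` wrapper built from these
# functions:
import jax
import jax.numpy as jnp
import optax

params = {'w': jnp.ones((3, 2)), 'b': jnp.zeros((2,))}
# Every label must have a matching transform, or `init_fn` raises the
# ValueError shown above.
labels = {'w': 'sgd', 'b': 'adam'}
tx = optax.multi_transform(
    {'sgd': optax.sgd(1e-2), 'adam': optax.adam(1e-3)}, labels)
state = tx.init(params)
grads = jax.tree_map(jnp.ones_like, params)
updates, state = tx.update(grads, state, params)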
def lars(
    learning_rate: ScalarOrSchedule,
    weight_decay: float = 0.,
    weight_decay_mask: MaskOrFn = True,
    trust_coefficient: float = 0.001,
    eps: float = 0.,
    trust_ratio_mask: MaskOrFn = True,
    momentum: float = 0.9,
    nesterov: bool = False,
) -> base.GradientTransformation:
  """The LARS optimiser.

  LARS is a layer-wise adaptive optimiser introduced to help scale SGD to
  larger batch sizes. LARS later inspired the LAMB optimiser.

  References:
    You et al, 2017: https://arxiv.org/abs/1708.03888

  Args:
    learning_rate: this is a fixed global scaling factor.
    weight_decay (default `0.`): strength of the weight decay regularization.
    weight_decay_mask: a tree with same structure as (or a prefix of) the
      params PyTree, or a Callable that returns such a pytree given the
      params/updates. The leaves should be booleans, `True` for
      leaves/subtrees you want to apply the transformation to, and `False`
      for those you want to skip.
    trust_coefficient: a multiplier for the trust ratio.
    eps: optional additive constant in the trust ratio denominator.
    trust_ratio_mask: a tree with same structure as (or a prefix of) the
      params PyTree, or a Callable that returns such a pytree given the
      params/updates. The leaves should be booleans, `True` for
      leaves/subtrees you want to apply the transformation to, and `False`
      for those you want to skip.
    momentum: the decay rate for momentum.
    nesterov: whether to use Nesterov momentum.

  Returns:
    the corresponding `GradientTransformation`.
  """
  return combine.chain(
      transform.add_decayed_weights(weight_decay, mask=weight_decay_mask),
      wrappers.masked(
          inner=transform.scale_by_trust_ratio(
              trust_coefficient=trust_coefficient, eps=eps),
          mask=trust_ratio_mask),
      _scale_by_learning_rate(learning_rate),
      transform.trace(decay=momentum, nesterov=nesterov),
  )
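# Usage sketch for `lars` (hyperparameters and the parameter tree are
# illustrative): a single ndim-based mask is reused to exempt 1-D parameters
# from both weight decay and the trust-ratio scaling.
import jax
import jax.numpy as jnp
import optax

params = {'w': jnp.ones((3, 2)), 'b': jnp.zeros((2,))}
mask = jax.tree_map(lambda x: x.ndim > 1, params)
opt = optax.lars(learning_rate=1.0, weight_decay=1e-4,
                 weight_decay_mask=mask, trust_ratio_mask=mask)
state = opt.init(params)
grads = jax.tree_map(jnp.ones_like, params)
updates, state = opt.update(grads, state, params)
new_params = optax.apply_updates(params, updates)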
def test_mask_fn(self):
  params = {'a': jnp.ones((1, 2)), 'b': (jnp.ones((1,)), np.ones((1, 2, 3)))}
  mask_fn = lambda p: jax.tree_map(lambda x: x.ndim > 1, p)
  init_fn, update_fn = wrappers.masked(
      transform.add_decayed_weights(0.1), mask_fn)
  update_fn = self.variant(update_fn)

  state = self.variant(init_fn)(params)
  grads = jax.tree_map(lambda x: x * 2, params)
  updates, state = update_fn(grads, state, params)
  np.testing.assert_allclose(updates['a'], grads['a'] + 0.1 * params['a'])
  np.testing.assert_allclose(updates['b'][0], grads['b'][0])
  np.testing.assert_allclose(updates['b'][1],
                             grads['b'][1] + 0.1 * params['b'][1])
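# The callable-mask pattern tested above (decay only parameters with more
# than one dimension) is a common way to exempt biases and normalisation
# scales without hand-writing a mask tree. A sketch of that pattern in a
# full update chain (names and rates are illustrative):
import jax
import optax

def decay_mask_fn(params):
  # True for weight matrices/kernels, False for 1-D params such as biases.
  return jax.tree_map(lambda x: x.ndim > 1, params)

tx = optax.chain(
    optax.scale_by_adam(),
    optax.add_decayed_weights(1e-4, mask=decay_mask_fn),
    optax.scale(-1e-3),
)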
def test_tree_mismatch_fails(self, extra_key_in_mask, use_fn):
  mask = {'a': True,
          'b': [False, True],
          'c': {'d': True, 'e': (False, True)}}
  # Parenthesised so both the callable and raw-tree cases are exercised.
  mask_arg = (lambda _: mask) if use_fn else mask
  params = {'a': 1., 'b': [2., 3.], 'c': {'d': 4., 'e': (5., 6.)}}
  params = jax.tree_map(jnp.asarray, params)

  if extra_key_in_mask:
    mask['c']['extra'] = True
  else:
    params['c']['extra'] = 7

  init_fn = wrappers.masked(_build_sgd(), mask_arg)[0]
  with self.assertRaises(ValueError):
    init_fn(params)
def test_masked(self, opt_builder):
  mask = {'a': True,
          'b': [False, True],
          'c': {'d': True, 'e': (False, True)}}
  params = {'a': 1, 'b': [2, 3], 'c': {'d': 4, 'e': (5, 6)}}
  input_updates = jax.tree_util.tree_map(lambda x: x / 10., params)

  # Negate the updates wherever the mask is True
  def masked_negate(updates):
    return jax.tree_util.tree_multimap(
        lambda upd, m: -upd if m else upd, updates, mask)
  correct_updates = masked_negate(input_updates)

  init_fn, update_fn = wrappers.masked(opt_builder(), mask)
  update_fn = self.variant(update_fn)
  state = init_fn(params)
  updates, state = update_fn(input_updates, state, params)
  chex.assert_tree_all_close(updates, correct_updates)

  # Check repeated application, this time with no params.
  correct_updates = masked_negate(correct_updates)
  updates, state = update_fn(updates, state)
  chex.assert_tree_all_close(updates, correct_updates)
def custom_optim(learning_rate, mask):
  return wrappers.masked(transform.scale(-learning_rate), mask)
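# `custom_optim` composes `masked` with a plain scaling step, so the result
# is itself a GradientTransformation. Illustrative use (the mask and learning
# rate are made up):
import jax.numpy as jnp

params = {'a': jnp.array(1.), 'b': jnp.array(2.)}
opt = custom_optim(0.1, {'a': True, 'b': False})
state = opt.init(params)
# Only 'a' is scaled by -0.1; 'b' passes through unchanged.
updates, state = opt.update(params, state)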
def test_empty(self, container):
  init_fn, update_fn = wrappers.masked(_build_sgd(), container())
  update_fn(container(), init_fn(container()))