Ejemplo n.º 1
0
    def test_prefix_mask(self, opt_builder):
        """Check masking when the mask is a prefix of the updates PyTree.

        A single mask leaf may govern a whole subtree of the updates (for
        instance ``'b'`` and ``'c'/'e'`` below). The expected result negates
        every update located under a ``True`` mask leaf and leaves all other
        updates untouched.

        Args:
          opt_builder: zero-argument callable returning the transformation
            that is wrapped by `wrappers.masked`.
        """
        prefix_mask = {'a': True, 'b': False, 'c': {'d': False, 'e': True}}
        raw_params = {
            'a': 1.,
            'b': {'f': 2.},
            'c': {'d': 3., 'e': ([4., 5.], 6.)},
        }
        params = jax.tree_map(jnp.asarray, raw_params)
        input_updates = jax.tree_map(lambda leaf: leaf / 10., params)

        def _negate_if_selected(selected, subtree):
            # `subtree` can be an entire sub-PyTree since the mask is a prefix.
            if selected:
                return jax.tree_map(lambda leaf: -leaf, subtree)
            return subtree

        def _expected_from(previous_updates):
            return jax.tree_multimap(_negate_if_selected, prefix_mask,
                                     previous_updates)

        expected = _expected_from(input_updates)

        init_fn, update_fn = wrappers.masked(opt_builder(), prefix_mask)
        update_fn = self.variant(update_fn)
        state = self.variant(init_fn)(params)
        updates, state = update_fn(input_updates, state, params)
        chex.assert_tree_all_close(updates, expected)

        # Second round: feed the produced updates back in, without params.
        expected = _expected_from(expected)
        updates, state = update_fn(updates, state)
        chex.assert_tree_all_close(updates, expected)
Ejemplo n.º 2
0
def add_decayed_weights(
    weight_decay: float = 0.0,
    mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None
) -> base.GradientTransformation:
    """Add the parameters, scaled by `weight_decay`, to the updates.

    Args:
      weight_decay: a scalar weight decay rate.
      mask: a tree with same structure as (or a prefix of) the params PyTree,
        or a Callable that returns such a pytree given the params/updates.
        The leaves should be booleans, `True` for leaves/subtrees you want to
        apply the transformation to, and `False` for those you want to skip.
        When `None`, the transformation is returned unmasked.

    Returns:
      An (init_fn, update_fn) tuple.

    Raises:
      ValueError: from `update_fn`, when it is called with `params=None`.
    """

    def init_fn(_):
        return AddDecayedWeightsState()

    def update_fn(updates, state, params):
        if params is None:
            raise ValueError(base.NO_PARAMS_MSG)

        def _add_decay(grad, param):
            return grad + weight_decay * param

        # NOTE(review): `jax.tree_multimap` has been removed from recent JAX
        # releases, where `jax.tree_util.tree_map` accepts several trees --
        # confirm the supported JAX version before switching.
        return jax.tree_multimap(_add_decay, updates, params), state

    decay_tx = base.GradientTransformation(init_fn, update_fn)
    if mask is None:
        return decay_tx
    # A mask was given: restrict the decay to the selected leaves/subtrees.
    # E.g. it is common to skip weight decay on bias units and batch stats.
    return wrappers.masked(decay_tx, mask)
Ejemplo n.º 3
0
  def test_update_requires_params(self):
    """Check masked additive weight decay, which reads params on every update.

    Wherever the mask is `True` the expected update is
    `update + weight_decay * param`; elsewhere the update passes through
    unchanged. The check is repeated after applying the first updates to the
    params.
    """
    weight_decay = 0.1
    params = {'a': 1, 'b': [2, 3], 'c': {'d': 4, 'e': (5, 6)}}
    mask = {'a': True,
            'b': [False, True],
            'c': {'d': True, 'e': (False, True)}}
    input_updates = jax.tree_util.tree_map(lambda x: x/10., params)

    def _expected(previous_updates, current_params):
      return jax.tree_util.tree_multimap(
          lambda m, u, p: u + weight_decay * p if m else u,
          mask, previous_updates, current_params)

    correct_updates = _expected(input_updates, params)

    masked_decay = wrappers.masked(
        transform.additive_weight_decay(weight_decay), mask)
    init_fn, update_fn = masked_decay
    update_fn = self.variant(update_fn)

    state = init_fn(params)
    updates, state = update_fn(input_updates, state, params)
    chex.assert_tree_all_close(updates, correct_updates)

    # Move the params along before the second round.
    params = update.apply_updates(params, updates)

    # Test repeated application
    new_correct_updates = _expected(correct_updates, params)
    updates, state = update_fn(correct_updates, state, params)
    chex.assert_tree_all_close(updates, new_correct_updates)
Ejemplo n.º 4
0
    def test_masked(self, opt_builder, use_fn):
        """Check that `wrappers.masked` only transforms the selected leaves.

        The expected result negates each update whose mask leaf is `True`
        and leaves the remaining updates untouched.

        Args:
          opt_builder: zero-argument callable returning the transformation
            that is wrapped by `wrappers.masked`.
          use_fn: when true, the mask is handed over as a callable returning
            the mask tree; otherwise the mask tree itself is passed.
        """
        mask = {
            'a': True,
            'b': [False, True],
            'c': {
                'd': True,
                'e': (False, True)
            }
        }
        # The parentheses matter: `lambda _: mask if use_fn else mask` makes
        # the conditional part of the lambda body, so a callable would be
        # passed regardless of `use_fn` and the plain-tree case would never
        # be exercised.
        mask_arg = (lambda _: mask) if use_fn else mask
        params = {'a': 1., 'b': [2., 3.], 'c': {'d': 4., 'e': (5., 6.)}}
        params = jax.tree_map(jnp.asarray, params)
        input_updates = jax.tree_map(lambda x: x / 10., params)

        # Negate the updates wherever the mask is True
        def masked_negate(updates):
            return jax.tree_multimap(lambda upd, m: -upd
                                     if m else upd, updates, mask)

        correct_updates = masked_negate(input_updates)

        init_fn, update_fn = wrappers.masked(opt_builder(), mask_arg)
        update_fn = self.variant(update_fn)
        state = self.variant(init_fn)(params)
        updates, state = update_fn(input_updates, state, params)
        chex.assert_tree_all_close(updates, correct_updates)

        # Check repeated application, this time with no params.
        correct_updates = masked_negate(correct_updates)
        updates, state = update_fn(updates, state)
        chex.assert_tree_all_close(updates, correct_updates)
Ejemplo n.º 5
0
 def update_fn(updates, state, params=None):
     """Run every group's transformation on the updates carrying its label.

     Labels come from calling `param_labels` on the updates when it is
     callable, otherwise `param_labels` is used directly. Each transformation
     is wrapped with `wrappers.masked` using `make_mask(labels, group)`, and
     the updates are threaded through the groups one after another.

     Args:
       updates: the incoming updates.
       state: state holding one inner state per group in `inner_states`.
       params: optional params, forwarded to each masked transformation.

     Returns:
       A tuple of the transformed updates and a new `MultiTransformState`.
     """
     if callable(param_labels):
         labels = param_labels(updates)
     else:
         labels = param_labels

     per_group_state = {}
     for group, tx in transforms.items():
         group_mask = make_mask(labels, group)
         group_tx = wrappers.masked(tx, group_mask)
         updates, per_group_state[group] = group_tx.update(
             updates, state.inner_states[group], params)
     return updates, MultiTransformState(per_group_state)
Ejemplo n.º 6
0
  def test_tree_mismatch_fails(self, extra_key_in_mask):
    """Check that `init` raises when mask and params structures disagree.

    Args:
      extra_key_in_mask: when true the surplus ``'extra'`` entry is added to
        the mask, otherwise it is added to the params.
    """
    params = {'a': 1, 'b': [2, 3], 'c': {'d': 4, 'e': (5, 6)}}
    mask = {'a': True,
            'b': [False, True],
            'c': {'d': True, 'e': (False, True)}}

    # Give exactly one of the two trees an entry the other one lacks.
    if not extra_key_in_mask:
      params['c']['extra'] = 7
    else:
      mask['c']['extra'] = True

    init_fn, _ = wrappers.masked(_build_sgd(), mask)
    with self.assertRaises(ValueError):
      init_fn(params)
Ejemplo n.º 7
0
    def init_fn(params):
        """Initialise one inner state per transformation group.

        Labels come from calling `param_labels` on the params when it is
        callable, otherwise `param_labels` is used directly.

        Args:
          params: the params the transformations are initialised for.

        Returns:
          A `MultiTransformState` holding, for each key of `transforms`, the
          state of that transformation masked by `make_mask(labels, group)`.

        Raises:
          ValueError: if some label leaf is not a key of `transforms`.
        """
        if callable(param_labels):
            labels = param_labels(params)
        else:
            labels = param_labels

        # Every label in use needs a transformation registered under it.
        label_set = set(jax.tree_leaves(labels))
        all_labels_known = label_set.issubset(transforms.keys())
        if not all_labels_known:
            raise ValueError(
                'Some parameters have no corresponding transformation.\n'
                f'Parameter labels: {list(sorted(label_set))} \n'
                f'Transforms keys: {list(sorted(transforms.keys()))} \n')

        inner_states = {}
        for group, tx in transforms.items():
            group_tx = wrappers.masked(tx, make_mask(labels, group))
            inner_states[group] = group_tx.init(params)
        return MultiTransformState(inner_states)
Ejemplo n.º 8
0
def lars(
    learning_rate: ScalarOrSchedule,
    weight_decay: float = 0.,
    weight_decay_mask: MaskOrFn = True,
    trust_coefficient: float = 0.001,
    eps: float = 0.,
    trust_ratio_mask: MaskOrFn = True,
    momentum: float = 0.9,
    nesterov: bool = False,
) -> base.GradientTransformation:
  """The LARS optimiser.

  LARS is a layer-wise adaptive optimiser introduced to help scale SGD to
  larger batch sizes. LARS later inspired the LAMB optimiser.

  The returned transformation chains, in this order: decayed weights
  (optionally masked), trust-ratio scaling (wrapped in `wrappers.masked`),
  learning-rate scaling and a momentum trace.

  References:
    You et al, 2017: https://arxiv.org/abs/1708.03888

  Args:
    learning_rate: a global scaling factor (annotated as `ScalarOrSchedule`).
    weight_decay (default `0.`): strength of the weight decay regularization.
    weight_decay_mask: a tree with same structure as (or a prefix of) the params
      PyTree, or a Callable that returns such a pytree given the params/updates.
      The leaves should be booleans, `True` for leaves/subtrees you want to
      apply the transformation to, and `False` for those you want to skip.
    trust_coefficient: a multiplier for the trust ratio.
    eps: optional additive constant in the trust ratio denominator.
    trust_ratio_mask: a tree with same structure as (or a prefix of) the params
      PyTree, or a Callable that returns such a pytree given the params/updates.
      The leaves should be booleans, `True` for leaves/subtrees you want to
      apply the transformation to, and `False` for those you want to skip.
    momentum: the decay rate for momentum.
    nesterov: whether to use Nesterov momentum.

  Returns:
    the corresponding `GradientTransformation`.
  """
  return combine.chain(
      transform.add_decayed_weights(weight_decay, mask=weight_decay_mask),
      wrappers.masked(
          inner=transform.scale_by_trust_ratio(
              trust_coefficient=trust_coefficient, eps=eps),
          mask=trust_ratio_mask),
      _scale_by_learning_rate(learning_rate),
      transform.trace(decay=momentum, nesterov=nesterov),
  )
Ejemplo n.º 9
0
    def test_mask_fn(self):
        """Check that `wrappers.masked` accepts a mask given as a callable.

        The mask function selects leaves with more than one dimension, so the
        0.1 weight decay is expected on ``'a'`` and ``'b'[1]`` while the
        one-dimensional ``'b'[0]`` must come back unchanged.
        """
        decay = 0.1
        params = {
            'a': jnp.ones((1, 2)),
            'b': (jnp.ones((1, )), np.ones((1, 2, 3)))
        }

        def select_multi_dim(tree):
            return jax.tree_map(lambda leaf: leaf.ndim > 1, tree)

        init_fn, update_fn = wrappers.masked(
            transform.add_decayed_weights(decay), select_multi_dim)
        update_fn = self.variant(update_fn)

        state = self.variant(init_fn)(params)
        grads = jax.tree_map(lambda leaf: leaf * 2, params)
        updates, state = update_fn(grads, state, params)

        # Selected leaves: gradient plus decayed parameter.
        np.testing.assert_allclose(updates['a'],
                                   grads['a'] + decay * params['a'])
        # Skipped leaf: gradient passes through as is.
        np.testing.assert_allclose(updates['b'][0], grads['b'][0])
        np.testing.assert_allclose(updates['b'][1],
                                   grads['b'][1] + decay * params['b'][1])
Ejemplo n.º 10
0
    def test_tree_mismatch_fails(self, extra_key_in_mask, use_fn):
        """Check that `init` raises when mask and params structures disagree.

        Args:
          extra_key_in_mask: when true the surplus ``'extra'`` entry is added
            to the mask, otherwise it is added to the params.
          use_fn: when true, the mask is handed over as a callable returning
            the mask tree; otherwise the mask tree itself is passed.
        """
        mask = {
            'a': True,
            'b': [False, True],
            'c': {
                'd': True,
                'e': (False, True)
            }
        }
        # The parentheses matter: `lambda _: mask if use_fn else mask` makes
        # the conditional part of the lambda body, so a callable would be
        # passed regardless of `use_fn`. The same `mask` object is referenced
        # in both cases, so the in-place mutation below is still seen.
        mask_arg = (lambda _: mask) if use_fn else mask
        params = {'a': 1., 'b': [2., 3.], 'c': {'d': 4., 'e': (5., 6.)}}
        params = jax.tree_map(jnp.asarray, params)

        if extra_key_in_mask:
            mask['c']['extra'] = True
        else:
            params['c']['extra'] = 7

        init_fn = wrappers.masked(_build_sgd(), mask_arg)[0]
        with self.assertRaises(ValueError):
            init_fn(params)
Ejemplo n.º 11
0
  def test_masked(self, opt_builder):
    """Check that `wrappers.masked` only transforms the selected leaves.

    The expected result negates each update whose mask leaf is `True` and
    leaves the remaining updates untouched.

    Args:
      opt_builder: zero-argument callable returning the transformation that
        is wrapped by `wrappers.masked`.
    """
    params = {'a': 1, 'b': [2, 3], 'c': {'d': 4, 'e': (5, 6)}}
    mask = {'a': True,
            'b': [False, True],
            'c': {'d': True, 'e': (False, True)}}
    input_updates = jax.tree_util.tree_map(lambda x: x/10., params)

    def _flip_where_masked(tree):
      def _flip(upd, selected):
        return -upd if selected else upd
      return jax.tree_util.tree_multimap(_flip, tree, mask)

    expected = _flip_where_masked(input_updates)

    init_fn, update_fn = wrappers.masked(opt_builder(), mask)
    update_fn = self.variant(update_fn)
    state = init_fn(params)
    updates, state = update_fn(input_updates, state, params)
    chex.assert_tree_all_close(updates, expected)

    # Second round: feed the produced updates back in, without params.
    expected = _flip_where_masked(expected)
    updates, state = update_fn(updates, state)
    chex.assert_tree_all_close(updates, expected)
Ejemplo n.º 12
0
 def custom_optim(learning_rate, mask):
     """Return `transform.scale(-learning_rate)` wrapped in `wrappers.masked`.

     Args:
       learning_rate: step size; its negation is passed to `transform.scale`.
       mask: mask forwarded unchanged to `wrappers.masked`.
     """
     negative_scale = transform.scale(-learning_rate)
     return wrappers.masked(negative_scale, mask)
Ejemplo n.º 13
0
 def test_empty(self, container):
   """Run init and update of a masked transformation on empty containers.

   Args:
     container: zero-argument callable producing the empty mask, params and
       updates used here.
   """
   init_fn, update_fn = wrappers.masked(_build_sgd(), container())
   empty_updates = container()
   state = init_fn(container())
   update_fn(empty_updates, state)