Example #1
def and_mask(agreement_threshold: float) -> optax.GradientTransformation:
  def init_fn(_):
    # Required by optax
    return ANDMaskState()

  def update_fn(updates, opt_state, params=None):

    def and_mask(update):
      # Compute the masked gradients for a single parameter tensor
      mask = jnp.abs(jnp.mean(jnp.sign(update), 0)) >= agreement_threshold
      mask = mask.astype(jnp.float32)
      avg_update = jnp.mean(update, 0)
      mask_t = mask.sum() / mask.size
      update = mask * avg_update * (1. / (1e-10 + mask_t))
      return update

    del params  # Following optax code style

    # Compute the masked gradients over all parameters: jax.tree_map applies
    # `and_mask` to each leaf of the updates pytree and returns a new pytree.
    updates = jax.tree_map(and_mask, updates)
    return updates, opt_state

  return transform.GradientTransformation(init_fn, update_fn)
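
A minimal usage sketch (not part of the original snippet). It assumes `ANDMaskState` is defined alongside `and_mask`, and that per-environment gradients are stacked along axis 0, e.g. via `jax.vmap`; the loss, data, and hyperparameter values are toy placeholders.

import jax
import jax.numpy as jnp
import optax

def loss_fn(params, batch):
  # Toy per-environment loss.
  preds = batch['x'] @ params['w']
  return jnp.mean((preds - batch['y']) ** 2)

params = {'w': jnp.zeros((3, 1))}
# Four environments with 8 examples each: gradients get a leading axis of size 4.
env_batches = {'x': jnp.ones((4, 8, 3)), 'y': jnp.ones((4, 8, 1))}
per_env_grads = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0))(params, env_batches)

tx = optax.chain(and_mask(agreement_threshold=0.3), optax.sgd(1e-2))
opt_state = tx.init(params)
updates, opt_state = tx.update(per_env_grads, opt_state)
params = optax.apply_updates(params, updates)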
Example #2
def masked(inner: transform.GradientTransformation,
           mask: Any) -> transform.GradientTransformation:
    """Mask updates so only a subset of them are computed.

  For example, it is common to skip weight decay for BatchNorm scale and all
  bias parameters. In many networks, these are the only parameters with exactly
  one dimension, so you can mask them out as follows:

  ```
  mask = jax.tree_util.tree_map(lambda x: x.ndim != 1, params)
  custom_weight_decay = optax.masked(optax.add_decayed_weights(0.001), mask)
  ```

  For the `inner` transform, state will only be stored for the parameters that
  have a mask value of `True`.

  Args:
    inner: Inner transformation to mask.
    mask: A PyTree with the same structure as the parameters or is a prefix of
      the parameter PyTree. The leaves should be booleans which are `True` for
      leaves/subtrees you want to apply the transformation to, and `False` for
      those you want to skip.

  Returns:
    New GradientTransformation wrapping `inner`.
  """
    flat_mask, treedef = tree_flatten(mask)

    def init_fn(params):
        flat_params = treedef.flatten_up_to(params)
        masked_params = [p for p, m in zip(flat_params, flat_mask) if m]
        return MaskedState(inner_state=inner.init(masked_params))

    def update_fn(updates, state, params=None):
        # Flatten then filter out updates/params not in the mask:
        flat_updates = treedef.flatten_up_to(updates)
        masked_updates = [g for g, m in zip(flat_updates, flat_mask) if m]

        if params is not None:
            flat_params = treedef.flatten_up_to(params)
            masked_params = [p for p, m in zip(flat_params, flat_mask) if m]
        else:
            masked_params = None

        # Compute new updates
        new_masked_updates, new_inner_state = inner.update(
            masked_updates, state.inner_state, masked_params)

        # Incorporate new_masked_updates into flat_updates, then unflatten
        new_masked_updates = iter(new_masked_updates)
        for i, m in enumerate(flat_mask):
            if m: flat_updates[i] = next(new_masked_updates)

        new_updates = treedef.unflatten(flat_updates)
        return new_updates, MaskedState(inner_state=new_inner_state)

    return transform.GradientTransformation(init_fn, update_fn)
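
A usage sketch (not in the original), assuming this is the public `optax.masked` wrapper; the parameters and gradients below are toy placeholders that combine the masked weight decay from the docstring with plain SGD.

import jax
import jax.numpy as jnp
import optax

params = {'w': jnp.ones((3, 4)), 'b': jnp.ones((4,))}

# Decay only parameters with more than one dimension (i.e. skip the bias).
mask = jax.tree_util.tree_map(lambda p: p.ndim != 1, params)
tx = optax.chain(
    optax.masked(optax.add_decayed_weights(1e-4), mask),
    optax.sgd(1e-2),
)

opt_state = tx.init(params)
grads = jax.tree_util.tree_map(jnp.ones_like, params)  # stand-in gradients
updates, opt_state = tx.update(grads, opt_state, params)  # weight decay needs params
params = optax.apply_updates(params, updates)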
Example #3
    def wrapped_transform(*args, **kwargs) -> transform.GradientTransformation:
        bound_arguments = inner_signature.bind(*args, **kwargs)
        bound_arguments.apply_defaults()

        sched_hps, numeric_hps, other_hps = {}, {}, {}
        for name, value in bound_arguments.arguments.items():
            if name in static_args or isinstance(value, bool):
                other_hps[name] = value
            elif callable(value):
                sched_hps[name] = value
            elif isinstance(value, (int, float, chex.Array)):
                numeric_hps[name] = value
            else:
                other_hps[name] = value

        def schedule_fn(count, dtype):
            return {
                k: _convert_floats(f(count), dtype)
                for k, f in sched_hps.items()
            }

        def init_fn(params):
            count = jnp.zeros([], jnp.int32)
            dtype = getattr(next(iter(jax.tree_leaves(params)), None), 'dtype',
                            None)
            hparams = {
                k: jnp.asarray(_convert_floats(v, dtype))
                for k, v in numeric_hps.items()
            }
            hparams.update(schedule_fn(count, dtype))
            return InjectHyperparamsState(  # pylint:disable=too-many-function-args
                count, hparams,
                inner_factory(**other_hps, **hparams).init(params))

        def update_fn(updates, state, params=None):
            count_inc = utils.safe_int32_increment(state.count)
            dtype = getattr(next(iter(jax.tree_leaves(updates)), None),
                            'dtype', None)
            hparams = {
                k: _convert_floats(v, dtype)
                for k, v in state.hyperparams.items()
            }
            hparams.update(schedule_fn(count_inc, dtype))
            updates, inner_state = inner_factory(**other_hps, **hparams).update(
                updates, state.inner_state, params)

            # pylint:disable=too-many-function-args
            return updates, InjectHyperparamsState(count_inc, hparams,
                                                   inner_state)
            # pylint:enable=too-many-function-args

        return transform.GradientTransformation(init_fn, update_fn)
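
This closure appears to be the factory wrapper produced by `optax.inject_hyperparams`. A small usage sketch of that public API (toy parameters, illustrative values):

import jax.numpy as jnp
import optax

# Wrap optax.sgd so that learning_rate is stored in the optimizer state
# (under opt_state.hyperparams) rather than baked into the update closure.
tx = optax.inject_hyperparams(optax.sgd)(learning_rate=0.1)

params = {'w': jnp.ones((3,))}
opt_state = tx.init(params)
print(opt_state.hyperparams['learning_rate'])  # -> 0.1

# The stored value can be changed between steps, e.g. for a manual warm restart.
opt_state.hyperparams['learning_rate'] = jnp.asarray(0.03)
grads = {'w': jnp.ones((3,))}
updates, opt_state = tx.update(grads, opt_state)

A schedule (a callable of the step count) can be passed instead of a constant; in that case `schedule_fn` recomputes the value from the counter on every update, so manual overrides of that entry do not persist.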
Example #4
def apply_if_finite(
        inner: transform.GradientTransformation,
        max_consecutive_errors: int) -> transform.GradientTransformation:
    """A function that wraps an optimiser to make it robust to a few NaNs or Infs.

  The purpose of this function is to prevent any optimisation from happening if
  the gradients contain NaNs or Infs. That is, when a NaN or Inf is detected in
  the gradients, the wrapped optimiser ignores that gradient update. If the NaNs
  or Infs persist after a given number of updates, the wrapped optimiser gives
  up and accepts the update.

  Args:
    inner: Inner transformation to be wrapped.
    max_consecutive_errors: Maximum number of consecutive gradient updates
      containing NaNs or Infs that the wrapped optimiser will ignore. After
      that many ignored updates, the optimiser will give up and accept the
      update.

  Returns:
    New GradientTransformation.
  """
    def init(params):
        return ApplyIfFiniteState(notfinite_count=jnp.zeros([], jnp.int64),
                                  last_finite=jnp.array(True, jnp.bool_),
                                  total_notfinite=jnp.zeros([], jnp.int64),
                                  inner_state=inner.init(params))

    def update(updates, state, params=None):
        inner_state = state.inner_state
        flat_updates = tree_flatten(updates)[0]
        isfinite = jnp.all(
            jnp.array([jnp.all(jnp.isfinite(p)) for p in flat_updates]))
        notfinite_count = jnp.where(isfinite, jnp.zeros([], jnp.int64),
                                    1 + state.notfinite_count)

        def do_update(_):
            return inner.update(updates, inner_state, params)

        def reject_update(_):
            return (tree_map(jnp.zeros_like, updates), inner_state)

        updates, new_inner_state = lax.cond(jnp.logical_or(
            isfinite, notfinite_count > max_consecutive_errors),
                                            do_update,
                                            reject_update,
                                            operand=None)

        return updates, ApplyIfFiniteState(
            notfinite_count=notfinite_count,
            last_finite=isfinite,
            total_notfinite=jnp.logical_not(isfinite) + state.total_notfinite,
            inner_state=new_inner_state)

    return transform.GradientTransformation(init=init, update=update)
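
A usage sketch (not part of the snippet), assuming this is `optax.apply_if_finite`; the NaN gradient below is fabricated to show a rejected step.

import jax.numpy as jnp
import optax

# Ignore up to 5 consecutive non-finite gradient updates before giving up.
tx = optax.apply_if_finite(optax.adam(1e-3), max_consecutive_errors=5)

params = {'w': jnp.ones((2,))}
opt_state = tx.init(params)

bad_grads = {'w': jnp.array([jnp.nan, 1.0])}
updates, opt_state = tx.update(bad_grads, opt_state, params)
# The rejected update is all zeros, so applying it leaves params unchanged.
params = optax.apply_updates(params, updates)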
Example #5
def differentially_private_aggregate(
        l2_norm_clip: float, noise_multiplier: float,
        seed: int) -> transform.GradientTransformation:
    """Aggregates gradients based on the DPSGD algorithm.

  WARNING: Unlike other transforms, `differentially_private_aggregate` expects
  the input updates to have a batch dimension in the 0th axis. That is, this
  function expects per-example gradients as input (which are easy to obtain in
  JAX using `jax.vmap`). It can still be composed with other transformations as
  long as it is the first in the chain.

  References:
    [Abadi et al, 2016](https://arxiv.org/abs/1607.00133)

  Args:
    l2_norm_clip: maximum L2 norm of the per-example gradients.
    noise_multiplier: ratio of standard deviation to the clipping norm.
    seed: initial seed used to create the `jax.random.PRNGKey`.

  Returns:
    A `GradientTransformation`.
  """
    noise_std = l2_norm_clip * noise_multiplier

    def init_fn(_):
        return DifferentiallyPrivateAggregateState(
            rng_key=jax.random.PRNGKey(seed))

    def update_fn(updates, state, params=None):
        del params
        grads_flat, grads_treedef = jax.tree_flatten(updates)
        bsize = grads_flat[0].shape[0]

        if any(g.ndim == 0 or bsize != g.shape[0] for g in grads_flat):
            raise ValueError(
                'Unlike other transforms, `differentially_private_aggregate` expects'
                ' `updates` to have a batch dimension in the 0th axis. That is, this'
                ' function expects per-example gradients as input.')

        new_key, *rngs = jax.random.split(state.rng_key, len(grads_flat) + 1)
        global_grad_norms = jax.vmap(transform.global_norm)(grads_flat)
        divisors = jnp.maximum(global_grad_norms / l2_norm_clip, 1.0)
        clipped = [(jnp.moveaxis(g, 0, -1) / divisors).sum(-1)
                   for g in grads_flat]
        noised = [
            (g + noise_std * jax.random.normal(r, g.shape, g.dtype)) / bsize
            for g, r in zip(clipped, rngs)
        ]
        return (jax.tree_unflatten(grads_treedef, noised),
                DifferentiallyPrivateAggregateState(rng_key=new_key))

    return transform.GradientTransformation(init_fn, update_fn)
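
A usage sketch (not in the original) showing how per-example gradients can be produced with `jax.vmap` and fed to this transformation as the first member of a chain; the model, data, and hyperparameter values are illustrative only.

import jax
import jax.numpy as jnp
import optax

def loss_fn(params, x, y):
  # Toy single-example loss.
  return jnp.mean((x @ params['w'] - y) ** 2)

params = {'w': jnp.zeros((3, 1))}
batch_x, batch_y = jnp.ones((16, 3)), jnp.ones((16, 1))

# Per-example gradients: vmap over the batch axis of the data only,
# so every gradient leaf gets a leading batch dimension of size 16.
per_example_grads = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))(
    params, batch_x, batch_y)

tx = optax.chain(
    optax.differentially_private_aggregate(
        l2_norm_clip=1.0, noise_multiplier=1.1, seed=0),
    optax.sgd(0.1),
)
opt_state = tx.init(params)
updates, opt_state = tx.update(per_example_grads, opt_state)
params = optax.apply_updates(params, updates)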
Example #6
def flatten(
    inner: transform.GradientTransformation
) -> transform.GradientTransformation:
    """Flattens parameters and gradients for init and update of inner transform.

  This can reduce the overhead of performing many calculations on lots of small
  variables, at the cost of slightly increased memory usage.

  Args:
    inner: Inner transformation to flatten inputs for.

  Returns:
    New GradientTransformation.
  """
    def _flatten(params):
        """Flattens and concatenates all tensors in params to a single vector."""
        params, _ = tree_flatten(params)
        return jnp.concatenate([jnp.reshape(param, [-1]) for param in params])

    def _unflatten(updates, flat):
        """Extracts tensors from flat, using the structure and shapes of params."""
        updates_flat, treedef = tree_flatten(updates)
        offsets = []
        for update in updates_flat:
            size = np.prod(update.shape)
            if offsets:
                offsets.append(size + offsets[-1])
            else:
                offsets.append(size)
        del offsets[-1]
        flat_split = jnp.split(flat, offsets)
        reshaped = [
            jnp.reshape(flat_update, update.shape)
            for flat_update, update in zip(flat_split, updates_flat)
        ]
        return tree_unflatten(treedef, reshaped)

    def init_fn(params):
        flat = _flatten(params)
        return inner.init(flat)

    def update_fn(updates, state, params=None):
        if params is not None:
            params = _flatten(params)
        updates_flat, state = inner.update(_flatten(updates), state, params)
        updates = _unflatten(updates, updates_flat)
        return updates, state

    return transform.GradientTransformation(init_fn, update_fn)
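
A usage sketch (not in the original), assuming this is `optax.flatten`; the parameters and gradients are toy placeholders.

import jax
import jax.numpy as jnp
import optax

# Run Adam on one concatenated vector instead of many small tensors.
params = {'w': jnp.ones((3, 4)), 'b': jnp.ones((4,))}
tx = optax.flatten(optax.adam(1e-3))

opt_state = tx.init(params)
grads = jax.tree_util.tree_map(jnp.ones_like, params)  # stand-in gradients
updates, opt_state = tx.update(grads, opt_state)
params = optax.apply_updates(params, updates)  # updates keep the original pytree structure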
Example #7
def test_optimizer(step_size: float) -> transform.GradientTransformation:
    """Fast optimizer for the lookahead tests."""

    # Use SGD for simplicity but add non-trivial optimizer state so that the
    # resetting behaviour of lookahead can be tested.
    def init_fn(params):
        aggregate_grads = jax.tree_map(jnp.zeros_like, params)
        return TestOptimizerState(aggregate_grads, is_reset=True)

    def update_fn(updates, state, params=None):
        del params  # unused by the test optimizer
        aggregate_grads = update.apply_updates(state.aggregate_grads, updates)
        updates = jax.tree_map(lambda u: step_size * u, updates)
        return updates, TestOptimizerState(aggregate_grads, is_reset=False)

    return transform.GradientTransformation(init_fn, update_fn)
Example #8
def maybe_update(
    inner: transform.GradientTransformation,
    should_update_fn: Callable[[Array], Array]
) -> transform.GradientTransformation:
    """Calls the inner update function only at certain steps.

  Creates a transformation wrapper which counts the number of times the `update`
  function has been called. This counter is passed to the `should_update_fn` to
  decide when to call the inner update function.

  When not calling the inner update function, the `updates` and the inner state
  are left untouched and just passed through. The step counter is increased
  regardless.

  Args:
    inner: the inner transformation.
    should_update_fn: this function takes in a step counter (array of shape []
      and dtype int64), and returns a boolean array of shape [].

  Returns:
    An `optax.GradientTransformation`.
  """
    def init_fn(params):
        return MaybeUpdateState(inner_state=inner.init(params),
                                step=jnp.zeros([], dtype=jnp.int64))

    def update_fn(updates, state, params=None):
        def do_update(_):
            return inner.update(updates, state.inner_state, params)

        def reject_update(_):
            return updates, state.inner_state

        updates, new_inner_state = lax.cond(should_update_fn(state.step),
                                            do_update,
                                            reject_update,
                                            operand=None)
        return updates, MaybeUpdateState(new_inner_state, state.step + 1)

    return transform.GradientTransformation(init_fn, update_fn)
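
A usage sketch (not in the original), assuming this is `optax.maybe_update`; `optax.scale` stands in for an arbitrary inner transformation and the period `k` is illustrative.

import jax.numpy as jnp
import optax

# Only let the inner transformation act on every k-th step; other steps pass
# the updates through unchanged while the step counter still increments.
k = 4
tx = optax.maybe_update(optax.scale(0.1), lambda step: step % k == 0)

params = {'w': jnp.ones((2,))}
opt_state = tx.init(params)
grads = {'w': jnp.ones((2,))}
updates, opt_state = tx.update(grads, opt_state)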
Example #9
def lookahead(fast_optimizer: transform.GradientTransformation,
              sync_period: int,
              slow_step_size: float,
              reset_state: bool = False) -> transform.GradientTransformation:
    """Lookahead optimizer.

  Performs steps with a fast optimizer and periodically updates a set of slow
  parameters. Optionally resets the fast optimizer state after synchronization
  by calling the init function of the fast optimizer.

  Updates returned by the lookahead optimizer should not be modified before they
  are applied, otherwise fast and slow parameters are not synchronized
  correctly.

  References:
    [Zhang et al, 2019](https://arxiv.org/pdf/1907.08610v1.pdf)

  Args:
    fast_optimizer: The optimizer to use in the inner loop of lookahead.
    sync_period: Number of fast optimizer steps to take before synchronizing
      parameters. Must be >= 1.
    slow_step_size: Step size of the slow parameter updates.
    reset_state: Whether to reset the optimizer state of the fast optimizer after
      each synchronization.

  Returns:
    A `GradientTransformation` with init and update functions. The updates
    passed to the update function should be calculated using the fast lookahead
    parameters only.
  """
    if sync_period < 1:
        raise ValueError('Synchronization period must be >= 1.')

    def init_fn(params: transform.Params) -> LookaheadState:
        try:
            fast_params = params.fast
        except AttributeError:
            # Allowing init_fn to be called with fast parameters reduces the
            # modifications necessary to adapt code to use lookahead in some cases.
            logging.warning(
                '`params` has no attribute `fast`. Continuing by assuming that '
                'only fast parameters were passed to lookahead init.')
            fast_params = params

        return LookaheadState(fast_state=fast_optimizer.init(fast_params),
                              steps_since_sync=jnp.zeros(shape=(),
                                                         dtype=jnp.int32))

    def update_fn(
            updates: transform.Updates, state: LookaheadState,
            params: LookaheadParams) -> Tuple[LookaheadParams, LookaheadState]:
        updates, fast_state = fast_optimizer.update(updates, state.fast_state,
                                                    params)

        sync_next = (state.steps_since_sync == sync_period - 1)
        updates = _lookahead_update(updates, sync_next, params, slow_step_size)
        if reset_state:
            # Jittable way of resetting the fast optimizer state if parameters will be
            # synchronized after this update step.
            initial_state = fast_optimizer.init(params.fast)
            fast_state = jax.tree_multimap(
                lambda current, init:
                (1 - sync_next) * current + sync_next * init, fast_state,
                initial_state)

        steps_since_sync = (state.steps_since_sync + 1) % sync_period
        return updates, LookaheadState(fast_state, steps_since_sync)

    return transform.GradientTransformation(init_fn, update_fn)
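
A usage sketch (not in the original), assuming this snippet matches `optax.lookahead` and that a `LookaheadParams` named tuple with `fast` and `slow` fields is available from the same library version; parameters and gradients are toy placeholders.

import jax.numpy as jnp
import optax

fast_params = {'w': jnp.ones((3,))}
# Lookahead tracks a fast and a slow copy of the parameters; start them equal.
params = optax.LookaheadParams(fast=fast_params, slow=fast_params)

tx = optax.lookahead(optax.sgd(0.1), sync_period=5, slow_step_size=0.5)
opt_state = tx.init(params)

grads = {'w': jnp.ones((3,))}  # gradients w.r.t. the fast parameters only
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)  # updates both fast and slow copies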
Example #10
    def gradient_transformation(self) -> transform.GradientTransformation:
        return transform.GradientTransformation(init=self.init,
                                                update=self.update)