def and_mask(agreement_threshold: float) -> optax.GradientTransformation:
  """Builds a gradient transformation that applies an AND-mask to updates.

  Every leaf of `updates` is reduced over its leading axis (axis 0): the
  result keeps the mean update only at positions where the absolute mean of
  the update signs along that axis reaches `agreement_threshold`, and is zero
  elsewhere. The kept values are rescaled by the inverse of the fraction of
  kept positions.

  NOTE(review): axis 0 is presumably a per-environment / per-example axis --
  confirm against the caller.

  Args:
    agreement_threshold: minimum absolute mean sign (along axis 0) required
      for a component to be kept.

  Returns:
    A `GradientTransformation` whose state is an empty `ANDMaskState`.
  """

  def init_fn(_):
    # Required by optax
    return ANDMaskState()

  def _mask_leaf(update):
    # Compute the masked gradient for a single parameter tensor.
    sign_agreement = jnp.abs(jnp.mean(jnp.sign(update), 0))
    mask = (sign_agreement >= agreement_threshold).astype(jnp.float32)
    avg_update = jnp.mean(update, 0)
    # Fraction of components that survive the mask; the 1e-10 guards against
    # division by zero when nothing survives.
    mask_t = mask.sum() / mask.size
    return mask * avg_update * (1. / (1e-10 + mask_t))

  def update_fn(updates, opt_state, params=None):
    del params  # Following optax code style
    # Apply the per-tensor masking to every leaf of the updates pytree.
    # `jax.tree_util.tree_map` is used instead of the `jax.tree_map` alias,
    # which newer JAX releases no longer provide.
    updates = jax.tree_util.tree_map(_mask_leaf, updates)
    return updates, opt_state

  return transform.GradientTransformation(init_fn, update_fn)
def masked(inner: transform.GradientTransformation,
           mask: Any) -> transform.GradientTransformation:
  """Mask updates so only a subset of them are computed.

  For example, it is common to skip weight decay for BatchNorm scale and all
  bias parameters. In many networks, these are the only parameters with only
  one dimension. So, you may mask these out as follows:

  ```
  mask = jax.tree_util.tree_map(lambda x: x.ndim != 1, params)
  custom_weight_decay = optax.masked(optax.add_decayed_weights(0.001), mask)
  ```

  For the `inner` transform, state will only be stored for the parameters that
  have a mask value of `True`.

  Args:
    inner: Inner transformation to mask.
    mask: A PyTree with the same structure as the parameters or is a prefix of
      the parameter PyTree. The leaves should be booleans which are `True` for
      leaves/subtrees you want to apply the transformation to, and `False` for
      those you want to skip.

  Returns:
    New GradientTransformation wrapping `inner`.
  """
  flat_mask, treedef = tree_flatten(mask)

  def _select(flat_tree):
    # Keep only the entries whose mask value is truthy.
    return [leaf for leaf, keep in zip(flat_tree, flat_mask) if keep]

  def init_fn(params):
    flat_params = treedef.flatten_up_to(params)
    return MaskedState(inner_state=inner.init(_select(flat_params)))

  def update_fn(updates, state, params=None):
    # Flatten then filter out updates/params not in the mask:
    flat_updates = treedef.flatten_up_to(updates)
    masked_updates = _select(flat_updates)

    # Compare against None explicitly: truth-testing a pytree raises for a
    # bare multi-element array and would silently drop an empty container.
    if params is not None:
      masked_params = _select(treedef.flatten_up_to(params))
    else:
      masked_params = None

    # Compute new updates
    new_masked_updates, new_inner_state = inner.update(
        masked_updates, state.inner_state, masked_params)

    # Incorporate new_masked_updates into flat_updates, then unflatten
    new_masked_updates = iter(new_masked_updates)
    for i, keep in enumerate(flat_mask):
      if keep:
        flat_updates[i] = next(new_masked_updates)

    new_updates = treedef.unflatten(flat_updates)
    return new_updates, MaskedState(inner_state=new_inner_state)

  return transform.GradientTransformation(init_fn, update_fn)
def wrapped_transform(*args, **kwargs) -> transform.GradientTransformation:
  """Builds a transformation whose hyperparameters are stored in its state.

  The call arguments are bound to `inner_signature` (with defaults applied)
  and sorted into three groups:
    * schedules: callables, re-evaluated with the step count on every update;
    * numeric hyperparameters: `int`, `float` or array values, kept in the
      state as arrays;
    * everything else (names in `static_args`, booleans and any other value),
      passed to `inner_factory` unchanged.

  `inner_signature`, `static_args` and `inner_factory` come from the
  enclosing scope, which is not visible here.

  Args:
    *args: positional arguments bound against `inner_signature`.
    **kwargs: keyword arguments bound against `inner_signature`.

  Returns:
    A `GradientTransformation` whose state is an `InjectHyperparamsState`
    built from the step count, the current hyperparameters and the inner
    state.
  """
  bound_arguments = inner_signature.bind(*args, **kwargs)
  bound_arguments.apply_defaults()

  sched_hps, numeric_hps, other_hps = {}, {}, {}
  for name, value in bound_arguments.arguments.items():
    # The bool check comes first because `bool` is a subclass of `int` and
    # would otherwise be classified as a numeric hyperparameter.
    if name in static_args or isinstance(value, bool):
      other_hps[name] = value
    elif callable(value):
      sched_hps[name] = value
    elif isinstance(value, (int, float, chex.Array)):
      numeric_hps[name] = value
    else:
      other_hps[name] = value

  def _first_leaf_dtype(tree):
    # dtype of the first leaf of `tree`, or None when the tree has no leaves
    # or its first leaf has no `dtype` attribute. Uses `jax.tree_util`
    # because the top-level `jax.tree_leaves` alias is gone from newer JAX.
    leaves = jax.tree_util.tree_leaves(tree)
    return getattr(next(iter(leaves), None), 'dtype', None)

  def schedule_fn(count, dtype):
    # Evaluate every schedule at `count` and pass the result through
    # `_convert_floats` (defined elsewhere) together with `dtype`.
    return {
        k: _convert_floats(f(count), dtype) for k, f in sched_hps.items()
    }

  def init_fn(params):
    count = jnp.zeros([], jnp.int32)
    dtype = _first_leaf_dtype(params)
    hparams = {
        k: jnp.asarray(_convert_floats(v, dtype))
        for k, v in numeric_hps.items()
    }
    hparams.update(schedule_fn(count, dtype))
    return InjectHyperparamsState(  # pylint:disable=too-many-function-args
        count, hparams, inner_factory(**other_hps, **hparams).init(params))

  def update_fn(updates, state, params=None):
    count_inc = utils.safe_int32_increment(state.count)
    dtype = _first_leaf_dtype(updates)
    hparams = {
        k: _convert_floats(v, dtype) for k, v in state.hyperparams.items()
    }
    # Schedule values override the stored ones for the incremented count.
    hparams.update(schedule_fn(count_inc, dtype))
    updates, inner_state = inner_factory(**other_hps, **hparams).update(
        updates, state.inner_state, params)

    # pylint:disable=too-many-function-args
    return updates, InjectHyperparamsState(count_inc, hparams, inner_state)
    # pylint:enable=too-many-function-args

  return transform.GradientTransformation(init_fn, update_fn)
def apply_if_finite(
    inner: transform.GradientTransformation,
    max_consecutive_errors: int
) -> transform.GradientTransformation:
  """Wraps an optimiser so that it tolerates a few NaN or Inf gradients.

  When a NaN or Inf is detected anywhere in the incoming updates, the wrapped
  optimiser skips the step: zero updates are emitted and the inner state is
  kept as it was. Once more than `max_consecutive_errors` non-finite updates
  have arrived in a row, the wrapper gives up and forwards the update to the
  inner optimiser anyway.

  Args:
    inner: Inner transformation to be wrapped.
    max_consecutive_errors: Maximum number of consecutive gradient updates
      containing NaNs or Infs that the wrapped optimiser will ignore. After
      that many ignored updates, the optimiser will give up and accept.

  Returns:
    New GradientTransformation.
  """

  def init(params):
    return ApplyIfFiniteState(
        notfinite_count=jnp.zeros([], jnp.int64),
        last_finite=jnp.array(True, jnp.bool_),
        total_notfinite=jnp.zeros([], jnp.int64),
        inner_state=inner.init(params))

  def update(updates, state, params=None):
    previous_inner = state.inner_state

    # One boolean per leaf, reduced to a single "everything is finite" flag.
    leaves, _ = tree_flatten(updates)
    leaf_is_finite = [jnp.all(jnp.isfinite(leaf)) for leaf in leaves]
    all_finite = jnp.all(jnp.array(leaf_is_finite))

    # Length of the current run of non-finite updates, this one included.
    bad_streak = jnp.where(
        all_finite, jnp.zeros([], jnp.int64), 1 + state.notfinite_count)

    def _accept(_):
      return inner.update(updates, previous_inner, params)

    def _skip(_):
      return (tree_map(jnp.zeros_like, updates), previous_inner)

    should_apply = jnp.logical_or(
        all_finite, bad_streak > max_consecutive_errors)
    new_updates, new_inner_state = lax.cond(
        should_apply, _accept, _skip, operand=None)

    new_state = ApplyIfFiniteState(
        notfinite_count=bad_streak,
        last_finite=all_finite,
        total_notfinite=jnp.logical_not(all_finite) + state.total_notfinite,
        inner_state=new_inner_state)
    return new_updates, new_state

  return transform.GradientTransformation(init=init, update=update)
def differentially_private_aggregate(
    l2_norm_clip: float,
    noise_multiplier: float,
    seed: int
) -> transform.GradientTransformation:
  """Aggregates gradients based on the DPSGD algorithm.

  WARNING: Unlike other transforms, `differentially_private_aggregate` expects
  the input updates to have a batch dimension in the 0th axis. That is, this
  function expects per-example gradients as input (which are easy to obtain in
  JAX using `jax.vmap`). It can still be composed with other transformations as
  long as it is the first in the chain.

  References:
    [Abadi et al, 2016](https://arxiv.org/abs/1607.00133)

  Args:
    l2_norm_clip: maximum L2 norm of the per-example gradients.
    noise_multiplier: ratio of standard deviation to the clipping norm.
    seed: initial seed used for the jax.random.PRNGKey

  Returns:
    A `GradientTransformation`.

  Raises:
    ValueError: (from the update function) if any leaf of `updates` is a
      scalar or if the leaves disagree on the size of axis 0.
  """
  noise_std = l2_norm_clip * noise_multiplier

  def init_fn(_):
    return DifferentiallyPrivateAggregateState(
        rng_key=jax.random.PRNGKey(seed))

  def update_fn(updates, state, params=None):
    del params
    grads_flat, grads_treedef = jax.tree_util.tree_flatten(updates)

    batch_dim_error = (
        'Unlike other transforms, `differentially_private_aggregate` expects'
        ' `updates` to have a batch dimension in the 0th axis. That is, this'
        ' function expects per-example gradients as input.')

    # Reject scalars before touching `shape[0]`: a 0-d first leaf would
    # otherwise fail with an opaque IndexError instead of this message.
    if any(g.ndim == 0 for g in grads_flat):
      raise ValueError(batch_dim_error)
    bsize = grads_flat[0].shape[0]
    if any(bsize != g.shape[0] for g in grads_flat):
      raise ValueError(batch_dim_error)

    # One fresh key for the next state plus one key per leaf for the noise.
    new_key, *rngs = jax.random.split(state.rng_key, len(grads_flat) + 1)

    # Per-example global norm, then scale down only the examples whose norm
    # exceeds the clip (divisor is never below 1).
    global_grad_norms = jax.vmap(transform.global_norm)(grads_flat)
    divisors = jnp.maximum(global_grad_norms / l2_norm_clip, 1.0)
    # Move the batch axis last so the per-example divisors broadcast, then
    # sum the clipped per-example gradients over the batch.
    clipped = [(jnp.moveaxis(g, 0, -1) / divisors).sum(-1) for g in grads_flat]
    noised = [
        (g + noise_std * jax.random.normal(r, g.shape, g.dtype)) / bsize
        for g, r in zip(clipped, rngs)
    ]
    return (jax.tree_util.tree_unflatten(grads_treedef, noised),
            DifferentiallyPrivateAggregateState(rng_key=new_key))

  return transform.GradientTransformation(init_fn, update_fn)
def flatten(
    inner: transform.GradientTransformation
) -> transform.GradientTransformation:
  """Flattens parameters and gradients for init and update of inner transform.

  This can reduce the overhead of performing many calculations on lots of small
  variables, at the cost of slightly increased memory usage.

  Args:
    inner: Inner transformation to flatten inputs for.

  Returns:
    New GradientTransformation.
  """

  def _flatten(params):
    """Flattens and concatenates all tensors in params to a single vector."""
    leaves, _ = tree_flatten(params)
    return jnp.concatenate([jnp.reshape(leaf, [-1]) for leaf in leaves])

  def _unflatten(updates, flat):
    """Extracts tensors from flat, using the structure and shapes of params."""
    updates_flat, treedef = tree_flatten(updates)

    # Split points are the cumulative sizes of every leaf but the last.
    # `np.prod(())` is the float 1.0, so the size is cast to `int` to keep
    # the offsets integral when a scalar leaf is present.
    offsets = []
    end = 0
    for update in updates_flat[:-1]:
      end += int(np.prod(update.shape))
      offsets.append(end)

    flat_split = jnp.split(flat, offsets)
    reshaped = [
        jnp.reshape(flat_update, update.shape)
        for flat_update, update in zip(flat_split, updates_flat)
    ]
    return tree_unflatten(treedef, reshaped)

  def init_fn(params):
    return inner.init(_flatten(params))

  def update_fn(updates, state, params=None):
    if params is not None:
      params = _flatten(params)
    updates_flat, state = inner.update(_flatten(updates), state, params)
    # `updates` still holds the original pytree and supplies the structure
    # and leaf shapes used to rebuild the result.
    updates = _unflatten(updates, updates_flat)
    return updates, state

  return transform.GradientTransformation(init_fn, update_fn)
def test_optimizer(step_size: float) -> transform.GradientTransformation:
  """Fast optimizer for the lookahead tests.

  Args:
    step_size: factor every incoming update is multiplied by.

  Returns:
    A `GradientTransformation` whose state is a `TestOptimizerState` holding
    the aggregated gradients and an `is_reset` flag (`True` only right after
    `init`).
  """

  # Use SGD for simplicity but add non-trivial optimizer state so that the
  # resetting behaviour of lookahead can be tested.
  def init_fn(params):
    aggregate_grads = jax.tree_util.tree_map(jnp.zeros_like, params)
    return TestOptimizerState(aggregate_grads, is_reset=True)

  def update_fn(updates, state, params=None):
    del params  # unused by the test optimizer
    # Fold the raw (unscaled) updates into the running aggregate via
    # `update.apply_updates` before scaling them.
    aggregate_grads = update.apply_updates(state.aggregate_grads, updates)
    # `jax.tree_util.tree_map` replaces the `jax.tree_map` alias, which is
    # no longer available in newer JAX releases.
    updates = jax.tree_util.tree_map(lambda u: step_size * u, updates)
    return updates, TestOptimizerState(aggregate_grads, is_reset=False)

  return transform.GradientTransformation(init_fn, update_fn)
def maybe_update(
    inner: transform.GradientTransformation,
    should_update_fn: Callable[[Array], Array]
) -> transform.GradientTransformation:
  """Runs the inner update function only on selected steps.

  The returned wrapper keeps a counter of how many times its `update`
  function has been called and hands that counter to `should_update_fn`,
  which decides whether the inner update runs on this call. On steps where it
  does not run, both the `updates` and the inner state are passed through
  unchanged. The counter advances on every call either way.

  Args:
    inner: the inner transformation.
    should_update_fn: this function takes in a step counter (array of shape []
      and dtype int64), and returns a boolean array of shape [].

  Returns:
    An `optax.GradientTransformation`.
  """

  def init_fn(params):
    return MaybeUpdateState(
        inner_state=inner.init(params),
        step=jnp.zeros([], dtype=jnp.int64))

  def update_fn(updates, state, params=None):

    def _run_inner(_):
      return inner.update(updates, state.inner_state, params)

    def _pass_through(_):
      return updates, state.inner_state

    run_now = should_update_fn(state.step)
    new_updates, new_inner_state = lax.cond(
        run_now, _run_inner, _pass_through, operand=None)
    next_step = state.step + 1
    return new_updates, MaybeUpdateState(new_inner_state, next_step)

  return transform.GradientTransformation(init_fn, update_fn)
def lookahead(fast_optimizer: transform.GradientTransformation,
              sync_period: int,
              slow_step_size: float,
              reset_state: bool = False) -> transform.GradientTransformation:
  """Lookahead optimizer.

  Performs steps with a fast optimizer and periodically updates a set of slow
  parameters. Optionally resets the fast optimizer state after synchronization
  by calling the init function of the fast optimizer.

  Updates returned by the lookahead optimizer should not be modified before they
  are applied, otherwise fast and slow parameters are not synchronized
  correctly.

  References:
    [Zhang et al, 2019](https://arxiv.org/pdf/1907.08610v1.pdf)

  Args:
    fast_optimizer: The optimizer to use in the inner loop of lookahead.
    sync_period: Number of fast optimizer steps to take before synchronizing
      parameters. Must be >= 1.
    slow_step_size: Step size of the slow parameter updates.
    reset_state: Whether to reset the optimizer state of the fast optimizer
      after each synchronization.

  Returns:
    A `GradientTransformation` with init and update functions. The updates
    passed to the update function should be calculated using the fast lookahead
    parameters only.

  Raises:
    ValueError: if `sync_period` is smaller than 1.
  """
  if sync_period < 1:
    raise ValueError('Synchronization period must be >= 1.')

  def init_fn(params: transform.Params) -> LookaheadState:
    try:
      fast_params = params.fast
    except AttributeError:
      # Allowing init_fn to be called with fast parameters reduces the
      # modifications necessary to adapt code to use lookahead in some cases.
      logging.warning(
          '`params` has no attribute `fast`. Continuing by assuming that '
          'only fast parameters were passed to lookahead init.')
      fast_params = params

    return LookaheadState(
        fast_state=fast_optimizer.init(fast_params),
        steps_since_sync=jnp.zeros(shape=(), dtype=jnp.int32))

  def update_fn(
      updates: transform.Updates,
      state: LookaheadState,
      params: LookaheadParams) -> Tuple[LookaheadParams, LookaheadState]:
    # The fast optimizer was initialised on `params.fast` and `updates` are
    # computed from the fast parameters only, so it must be given the fast
    # parameters here as well, not the whole fast/slow pair.
    updates, fast_state = fast_optimizer.update(
        updates, state.fast_state, params.fast)

    # True when the step being produced is the last one of the sync period.
    sync_next = (state.steps_since_sync == sync_period - 1)
    updates = _lookahead_update(updates, sync_next, params, slow_step_size)

    if reset_state:
      # Jittable way of resetting the fast optimizer state if parameters will be
      # synchronized after this update step. Selecting with `jnp.where`
      # (rather than blending arithmetically) keeps each leaf's dtype and
      # stops NaN/Inf in the discarded state from leaking into the result.
      initial_state = fast_optimizer.init(params.fast)
      fast_state = jax.tree_util.tree_map(
          lambda current, init: jnp.where(sync_next, init, current),
          fast_state, initial_state)

    steps_since_sync = (state.steps_since_sync + 1) % sync_period
    return updates, LookaheadState(fast_state, steps_since_sync)

  return transform.GradientTransformation(init_fn, update_fn)
def gradient_transformation(self) -> transform.GradientTransformation:
  """Packages this object's `init` and `update` as a GradientTransformation."""
  init_fn, update_fn = self.init, self.update
  return transform.GradientTransformation(init=init_fn, update=update_fn)