def update_fn(updates, state, params=None):
  """Multiplies every update leaf by the scheduled step size.

  The schedule is queried with the current (not yet incremented) count; the
  returned state carries that count advanced by one.
  """
  del params  # Not needed by this transformation.
  current_step = state.count
  scale = step_size_fn(current_step)

  def _apply_scale(leaf):
    # Materialize the step size in the leaf's dtype so the product keeps the
    # leaf's precision.
    return jnp.array(scale, dtype=leaf.dtype) * leaf

  scaled = jax.tree_map(_apply_scale, updates)
  next_state = ScaleByScheduleState(
      count=utils.safe_int32_increment(current_step))
  return scaled, next_state
def update_fn(updates, state, params=None):
  """Replaces `updates` with their exponential moving average.

  Args:
    updates: gradient-like tree folded into the running average.
    state: previous `EmaState` holding `count` and the running `ema`.
    params: unused.

  Returns:
    A pair `(ema_updates, new_state)`. `ema_updates` is the freshly updated
    moving average, bias-corrected when `debias` is set. `new_state` holds the
    incremented count and the average cast to `accumulator_dtype`.
  """
  del params
  # The output must be the moving average even when debiasing is disabled;
  # binding only `new_ema` here would pass the raw input straight through.
  updates = new_ema = _update_moment(updates, state.ema, decay, order=1)
  count_inc = utils.safe_int32_increment(state.count)
  if debias:
    updates = _bias_correction(new_ema, decay, count_inc)
  # Only the stored accumulator is cast; the returned updates keep the dtype
  # produced by the moment update.
  new_ema = utils.cast_tree(new_ema, accumulator_dtype)
  return updates, EmaState(count=count_inc, ema=new_ema)
def update_fn(updates, state, params=None):
  """Sums updates over windows of `k` steps and emits the sum once per window.

  On every step except the last of a window the emitted updates are zeros
  (the accumulator is multiplied by a False mask); on the last step the
  accumulated sum is emitted.
  """
  del params  # Unused.
  phase = state.count % k
  # At the first step of a window (phase == 0) the old accumulator is masked
  # out, which restarts the running sum.
  keep_previous = phase != 0
  accumulated = jax.tree_multimap(
      lambda grad, running: keep_previous * running + grad,
      updates, state.grad_acc)
  is_last = phase == (k - 1)
  emitted = jax.tree_map(lambda running: is_last * running, accumulated)
  next_count = utils.safe_int32_increment(state.count) % k
  return emitted, ApplyEvery(count=next_count, grad_acc=accumulated)
def update_fn(updates, state, params=None):
  """Rescales updates by bias-corrected first / second moment estimates.

  Both moments are advanced with `_update_moment` (orders 1 and 2), then
  bias-corrected at the incremented step count. The raw (uncorrected)
  moments are stored in the returned `ScaleByAdamState`.
  """
  del params
  first = _update_moment(updates, state.mu, b1, 1)
  second = _update_moment(updates, state.nu, b2, 2)
  step = utils.safe_int32_increment(state.count)
  first_hat = _bias_correction(first, b1, step)
  second_hat = _bias_correction(second, b2, step)

  def _direction(m, v):
    # eps_root is added inside the square root, eps outside it.
    return m / (jnp.sqrt(v + eps_root) + eps)

  new_updates = jax.tree_multimap(_direction, first_hat, second_hat)
  return new_updates, ScaleByAdamState(count=step, mu=first, nu=second)
def update_fn(updates, state, params=None):
  """Rectified-Adam style update.

  Computes `ro` from `ro_inf` and `b2**step`. When `ro >= threshold` the
  update is produced by `_radam_update((ro, mu_hat, nu_hat))`; otherwise the
  bias-corrected first moment is returned unchanged.
  """
  del params
  first = _update_moment(updates, state.mu, b1, 1)
  second = _update_moment(updates, state.nu, b2, 2)
  step = utils.safe_int32_increment(state.count)
  decay_pow = b2**step
  rho = ro_inf - 2 * step * decay_pow / (1 - decay_pow)
  first_hat = _bias_correction(first, b1, step)
  second_hat = _bias_correction(second, b2, step)
  operands = (rho, first_hat, second_hat)
  # The false branch ignores its operand and returns the captured first_hat.
  new_updates = jax.lax.cond(
      rho >= threshold,
      _radam_update,
      lambda _: first_hat,
      operands)
  return new_updates, ScaleByAdamState(count=step, mu=first, nu=second)
def update_fn(updates, state, params=None):
  """Adds annealed Gaussian noise to every update leaf.

  With `t = state.count + 1`, the intended noise variance is
  `eta / t**gamma`. Standard-normal samples are therefore scaled by the
  square root of that value (the standard deviation).

  Args:
    updates: tree of updates to perturb.
    state: previous `AddNoiseState` holding `count` and `rng_key`.
    params: unused.

  Returns:
    `(noisy_updates, new_state)`. `new_state` has the incremented count and
    a fresh key split off `state.rng_key`.
  """
  del params
  num_vars = len(jax.tree_leaves(updates))
  treedef = jax.tree_structure(updates)
  count_inc = utils.safe_int32_increment(state.count)
  variance = eta / count_inc**gamma
  # A unit normal must be scaled by the standard deviation, not the variance,
  # for the noise to actually have variance `eta / t**gamma`.
  standard_deviation = variance ** 0.5
  # One key per leaf, plus one (index 0) carried forward in the new state.
  all_keys = jax.random.split(state.rng_key, num=num_vars + 1)
  noise = jax.tree_multimap(
      lambda g, k: jax.random.normal(k, shape=g.shape, dtype=g.dtype),
      updates, jax.tree_unflatten(treedef, all_keys[1:]))
  updates = jax.tree_multimap(
      lambda g, n: g + standard_deviation.astype(g.dtype) * n,
      updates, noise)
  return updates, AddNoiseState(count=count_inc, rng_key=all_keys[0])
def update_fn(grads, state, params):
  """Apply gradient transformation.

  Rescales each gradient leaf by running second-moment statistics. For
  leaves where `_factored_dims` returns two axes, row and column statistics
  are kept. For all other leaves, a full-size statistic is kept.

  Args:
    grads: tree of gradients. It must have the same tree structure as
      `state.v_row`, `state.v_col`, `state.v` and `params`, because they are
      mapped together with `jax.tree_multimap`.
    state: previous optimizer state; `count`, `v_row`, `v_col` and `v` are
      read.
    params: current parameters. Only their `.shape` is read.

  Returns:
    `(updates, new_state)`, where `new_state` is built by `_to_state` from the
    incremented count and the per-leaf results.

  Raises:
    ValueError: if `params` is None.
  """
  if params is None:
    raise ValueError(base.NO_PARAMS_MSG)

  def _update(grad, v_row, v_col, v, param, step):
    # Per-leaf update: returns an `_UpdateResult` bundling the rescaled
    # gradient with the new row / column / full statistics.
    shape = param.shape
    # NOTE(review): `_decay_rate_pow` is defined elsewhere; presumably yields
    # the step-dependent decay rate -- confirm.
    decay_rate_t = _decay_rate_pow(step - step_offset, decay_rate)

    # Scaled by factorized second moment statistics.
    # Statistics not produced by the branch taken below remain (1,)-shaped
    # zero placeholders, so every leaf returns all three fields.
    new_v_row = jnp.zeros((1, ))
    new_v_col = jnp.zeros((1, ))
    new_v = jnp.zeros((1, ))

    factored_dims = _factored_dims(shape, factored, min_dim_size_to_factor)
    if factored_dims is not None:
      # Factored path: keep moving averages of (abs_sq(grad) + epsilon)
      # reduced over axis d0 (row stats) and over axis d1 (column stats)
      # instead of a full-size accumulator.
      d1, d0 = factored_dims
      grad_sqr = numerics.abs_sq(grad) + epsilon
      new_v_row = (decay_rate_t * v_row +
                   (1. - decay_rate_t) * jnp.mean(grad_sqr, axis=d0))
      new_v_col = (decay_rate_t * v_col +
                   (1. - decay_rate_t) * jnp.mean(grad_sqr, axis=d1))
      # new_v_row lost axis d0 in the mean above, so d1's index shifts down
      # by one when d1 came after d0.
      reduced_d1 = d1 - 1 if d1 > d0 else d1
      row_col_mean = jnp.mean(new_v_row, axis=reduced_d1, keepdims=True)
      row_factor = (new_v_row / row_col_mean)**-0.5
      col_factor = (new_v_col)**-0.5
      # Re-insert the reduced axes so both factors broadcast against grad.
      update = (grad * jnp.expand_dims(row_factor, axis=d0) *
                jnp.expand_dims(col_factor, axis=d1))
    else:
      # Unfactored path: full-size moving average of abs_sq(grad) + epsilon.
      grad_sqr = numerics.abs_sq(grad) + epsilon
      new_v = decay_rate_t * v + (1. - decay_rate_t) * grad_sqr
      update = grad * (new_v)**-0.5

    return _UpdateResult(update, new_v_row, new_v_col, new_v)

  # Transform grad and compute new per-parameter stats.
  # The same (pre-increment) `state.count` is passed as `step` to every leaf.
  output = jax.tree_multimap(lambda *args: _update(*args, state.count),
                             grads, state.v_row, state.v_col, state.v, params)

  # Unpack updates / stats and return.
  # NOTE(review): mapping `o.update` over `output` only works if
  # `_UpdateResult` is treated as a pytree leaf (not a registered node);
  # confirm where it is defined.
  updates = jax.tree_map(lambda o: o.update, output)
  return updates, _to_state(utils.safe_int32_increment(state.count), output)
def update_fn(updates, state, params=None):
  """Refreshes scheduled hyperparameters, then delegates to the inner update.

  Numeric hyperparameters from `state.hyperparams` are converted to the
  dtype of the first update leaf (or left dtype-less if there are no
  leaves). Scheduled entries are then overridden by `schedule_fn` evaluated
  at the *current* step count, i.e. before incrementing. This matches how a
  step-size schedule is queried with `state.count`, so the first update uses
  the schedule's starting value instead of skipping it.
  (NOTE(review): assumes `init` starts the count at 0.)

  Args:
    updates: tree of updates passed to the inner transformation.
    state: previous `InjectHyperparamsState`.
    params: forwarded unchanged to the inner transformation.

  Returns:
    `(updates, new_state)`, where `new_state` carries the incremented count,
    the hyperparameters used for this step, and the inner state.
  """
  dtype = getattr(next(iter(jax.tree_leaves(updates)), None), 'dtype', None)
  hparams = {
      k: _convert_floats(v, dtype) for k, v in state.hyperparams.items()
  }
  # Evaluate schedules at the current step; using the incremented count
  # shifted every schedule by one step.
  hparams.update(schedule_fn(state.count, dtype))
  updates, inner_state = inner_factory(**other_hps, **hparams).update(
      updates, state.inner_state, params)
  count_inc = utils.safe_int32_increment(state.count)
  return updates, InjectHyperparamsState(count_inc, hparams, inner_state)  # pylint:disable=too-many-function-args
def update_fn(updates, state, params=None):
  """Applies `_bias_correction` to `updates` at the next step count.

  Returns the corrected tree together with a `BiasCorrectionState` whose
  count has been advanced by one.
  """
  del params  # Unused.
  next_count = utils.safe_int32_increment(state.count)
  corrected = _bias_correction(updates, decay, next_count)
  next_state = BiasCorrectionState(count=next_count)
  return corrected, next_state