Exemple #1
0
 def update_fn(updates, state, params=None):
   """Scale every update leaf by the step size scheduled for the current count.

   Args:
     updates: pytree of update arrays.
     state: current state; only `state.count` is read.
     params: unused, accepted for interface compatibility.

   Returns:
     A `(scaled_updates, new_state)` pair, where `new_state` carries the count
     advanced through `utils.safe_int32_increment`.
   """
   del params
   step_size = step_size_fn(state.count)
   # `jax.tree_util.tree_map` is the stable spelling; the top-level
   # `jax.tree_map` alias is deprecated and absent from recent JAX releases.
   # The step size is cast per leaf so each update keeps its own dtype.
   updates = jax.tree_util.tree_map(
       lambda g: jnp.array(step_size, dtype=g.dtype) * g, updates)
   return updates, ScaleByScheduleState(
       count=utils.safe_int32_increment(state.count))
Exemple #2
0
 def update_fn(updates, state, params=None):
     """Replace the updates with their exponential moving average.

     Args:
       updates: pytree of update arrays.
       state: current state; `state.ema` and `state.count` are read.
       params: unused, accepted for interface compatibility.

     Returns:
       A `(new_updates, new_state)` pair. `new_updates` is the moving average,
       bias-corrected when `debias` is set. `new_state` stores the
       (uncorrected) moving average cast via `utils.cast_tree` to
       `accumulator_dtype`, together with the incremented count.
     """
     del params
     new_ema = _update_moment(updates, state.ema, decay, order=1)
     count_inc = utils.safe_int32_increment(state.count)
     if debias:
         updates = _bias_correction(new_ema, decay, count_inc)
     else:
         # Previously the raw incoming updates were returned on this path,
         # which made the transformation a no-op whenever debias was off.
         updates = new_ema
     new_ema = utils.cast_tree(new_ema, accumulator_dtype)
     return updates, EmaState(count=count_inc, ema=new_ema)
Exemple #3
0
 def update_fn(updates, state, params=None):
   """Accumulate updates and emit the accumulated value once every `k` calls.

   Args:
     updates: pytree of update arrays.
     state: current state; `state.count` and `state.grad_acc` are read.
     params: unused, accepted for interface compatibility.

   Returns:
     A `(emitted_updates, new_state)` pair. The emitted updates equal the
     accumulator when `state.count % k == k - 1` and are multiplied by a
     false flag (i.e. zeroed) otherwise. The new state keeps the accumulator
     and the incremented count reduced modulo `k`.
   """
   del params
   c = state.count % k
   # At the start of a cycle (c == 0) the previous accumulator is multiplied
   # by a false flag, so accumulation restarts from the incoming updates.
   acc = c != 0
   # `jax.tree_util.tree_map` accepts several trees and replaces the removed
   # `jax.tree_multimap` as well as the deprecated `jax.tree_map` alias.
   grad_acc = jax.tree_util.tree_map(
       lambda g, ga: acc * ga + g, updates, state.grad_acc)
   emit = c == (k - 1)
   updates = jax.tree_util.tree_map(lambda ga: emit * ga, grad_acc)
   count_inc = utils.safe_int32_increment(state.count)
   return updates, ApplyEvery(count=count_inc % k, grad_acc=grad_acc)
Exemple #4
0
 def update_fn(updates, state, params=None):
   """Rescale updates using bias-corrected first and second moment estimates.

   Args:
     updates: pytree of update arrays.
     state: current state; `state.mu`, `state.nu` and `state.count` are read.
     params: unused, accepted for interface compatibility.

   Returns:
     A `(new_updates, new_state)` pair. Each new update leaf is
     `mu_hat / (sqrt(nu_hat + eps_root) + eps)`. The state stores the
     uncorrected moments and the incremented count.
   """
   del params
   mu = _update_moment(updates, state.mu, b1, 1)
   nu = _update_moment(updates, state.nu, b2, 2)
   count_inc = utils.safe_int32_increment(state.count)
   mu_hat = _bias_correction(mu, b1, count_inc)
   nu_hat = _bias_correction(nu, b2, count_inc)
   # `jax.tree_util.tree_map` handles multiple trees; `jax.tree_multimap`
   # has been removed from JAX.
   updates = jax.tree_util.tree_map(
       lambda m, v: m / (jnp.sqrt(v + eps_root) + eps), mu_hat, nu_hat)
   return updates, ScaleByAdamState(count=count_inc, mu=mu, nu=nu)
Exemple #5
0
 def update_fn(updates, state, params=None):
     """Rescale updates, choosing between a rectified and a plain momentum step.

     Args:
       updates: pytree of update arrays.
       state: current state; `state.mu`, `state.nu` and `state.count` are read.
       params: unused, accepted for interface compatibility.

     Returns:
       A `(new_updates, new_state)` pair. When the computed `ro` value is at
       least `threshold`, the result of `_radam_update` is used; otherwise the
       bias-corrected first moment is returned as the update. The state holds
       the uncorrected moments and the incremented count.
     """
     del params
     step = utils.safe_int32_increment(state.count)
     first_moment = _update_moment(updates, state.mu, b1, 1)
     second_moment = _update_moment(updates, state.nu, b2, 2)

     # ro = ro_inf - 2 * t * b2^t / (1 - b2^t), evaluated left to right.
     beta2_pow = b2**step
     numerator = 2 * step * beta2_pow
     denominator = 1 - beta2_pow
     rho = ro_inf - numerator / denominator

     first_hat = _bias_correction(first_moment, b1, step)
     second_hat = _bias_correction(second_moment, b2, step)

     def _momentum_only(_):
         # Fallback branch: ignore the operands and use the corrected mean.
         return first_hat

     operands = (rho, first_hat, second_hat)
     new_updates = jax.lax.cond(
         rho >= threshold, _radam_update, _momentum_only, operands)
     new_state = ScaleByAdamState(
         count=step, mu=first_moment, nu=second_moment)
     return new_updates, new_state
Exemple #6
0
 def update_fn(updates, state, params=None):
     """Add annealed Gaussian noise to every update leaf.

     The noise variance at the incremented count `t` is `eta / t**gamma`.

     Args:
       updates: pytree of update arrays.
       state: current state; `state.count` and `state.rng_key` are read.
       params: unused, accepted for interface compatibility.

     Returns:
       A `(noisy_updates, new_state)` pair. The new state carries the
       incremented count and a fresh RNG key split off from the old one.
     """
     del params
     # The `jax.tree_util` functions replace the top-level `jax.tree_*`
     # aliases, which are deprecated or removed in recent JAX releases.
     leaves, treedef = jax.tree_util.tree_flatten(updates)
     count_inc = utils.safe_int32_increment(state.count)
     variance = eta / count_inc**gamma
     # Unit-normal samples must be scaled by the standard deviation. Scaling
     # by the variance itself (as before) yields noise of variance
     # `variance**2` instead of `variance`.
     standard_deviation = variance**0.5
     # One key per leaf, plus one (index 0) carried forward in the state.
     all_keys = jax.random.split(state.rng_key, num=len(leaves) + 1)
     noise = jax.tree_util.tree_map(
         lambda g, k: jax.random.normal(k, shape=g.shape, dtype=g.dtype),
         updates, jax.tree_util.tree_unflatten(treedef, all_keys[1:]))
     updates = jax.tree_util.tree_map(
         lambda g, n: g + standard_deviation.astype(g.dtype) * n,
         updates, noise)
     return updates, AddNoiseState(count=count_inc, rng_key=all_keys[0])
Exemple #7
0
    def update_fn(grads, state, params):
        """Apply the factored second-moment gradient transformation.

        Args:
          grads: pytree of gradient arrays.
          state: current state; `state.count`, `state.v_row`, `state.v_col`
            and `state.v` are read.
          params: pytree of parameters matching `grads`; only the shape of
            each leaf is used.

        Returns:
          A `(updates, new_state)` pair, the state being rebuilt by
          `_to_state` from the incremented count and the per-leaf results.

        Raises:
          ValueError: if `params` is None.
        """
        if params is None:
            raise ValueError(base.NO_PARAMS_MSG)

        def _update(grad, v_row, v_col, v, param, step):
            """Compute the scaled update and new statistics for one leaf."""
            shape = param.shape
            decay_rate_t = _decay_rate_pow(step - step_offset, decay_rate)
            # Needed by both the factored and the unfactored branch.
            grad_sqr = numerics.abs_sq(grad) + epsilon

            # Placeholders for whichever statistics this leaf does not use.
            new_v_row = jnp.zeros((1, ))
            new_v_col = jnp.zeros((1, ))
            new_v = jnp.zeros((1, ))

            factored_dims = _factored_dims(shape, factored,
                                           min_dim_size_to_factor)
            if factored_dims is not None:
                # Scaled by factorized second moment statistics.
                d1, d0 = factored_dims
                new_v_row = (decay_rate_t * v_row +
                             (1. - decay_rate_t) * jnp.mean(grad_sqr, axis=d0))
                new_v_col = (decay_rate_t * v_col +
                             (1. - decay_rate_t) * jnp.mean(grad_sqr, axis=d1))
                # Axis d0 was reduced away in new_v_row, so d1 shifts down by
                # one position when it came after d0.
                reduced_d1 = d1 - 1 if d1 > d0 else d1
                row_col_mean = jnp.mean(new_v_row,
                                        axis=reduced_d1,
                                        keepdims=True)
                row_factor = (new_v_row / row_col_mean)**-0.5
                col_factor = (new_v_col)**-0.5
                update = (grad * jnp.expand_dims(row_factor, axis=d0) *
                          jnp.expand_dims(col_factor, axis=d1))
            else:
                # Unfactored: keep a full second-moment estimate.
                new_v = decay_rate_t * v + (1. - decay_rate_t) * grad_sqr
                update = grad * (new_v)**-0.5

            return _UpdateResult(update, new_v_row, new_v_col, new_v)

        # Transform grad and compute new per-parameter stats.
        # `jax.tree_util.tree_map` accepts multiple trees and replaces the
        # removed `jax.tree_multimap` / deprecated `jax.tree_map` aliases.
        output = jax.tree_util.tree_map(
            lambda *args: _update(*args, state.count),
            grads, state.v_row, state.v_col, state.v, params)

        # Unpack updates / stats and return.
        updates = jax.tree_util.tree_map(lambda o: o.update, output)
        return updates, _to_state(utils.safe_int32_increment(state.count),
                                  output)
Exemple #8
0
        def update_fn(updates, state, params=None):
            """Rebuild the inner transformation with current hyperparameters and apply it.

            Args:
              updates: pytree of update arrays.
              state: current state; `state.count`, `state.hyperparams` and
                `state.inner_state` are read.
              params: forwarded unchanged to the inner transformation.

            Returns:
              A `(new_updates, new_state)` pair, the state recording the
              incremented count, the hyperparameters used for this step and
              the inner transformation's new state.
            """
            count_inc = utils.safe_int32_increment(state.count)
            # dtype of the first update leaf, or None when there are no leaves
            # (or the leaf has no `dtype` attribute). `jax.tree_util.tree_leaves`
            # replaces the deprecated top-level `jax.tree_leaves` alias.
            dtype = getattr(
                next(iter(jax.tree_util.tree_leaves(updates)), None),
                'dtype', None)
            hparams = {
                k: _convert_floats(v, dtype)
                for k, v in state.hyperparams.items()
            }
            # Scheduled values override the stored ones for this step.
            hparams.update(schedule_fn(count_inc, dtype))
            updates, inner_state = inner_factory(**other_hps,
                                                 **hparams).update(
                                                     updates,
                                                     state.inner_state, params)

            # pylint:disable=too-many-function-args
            return updates, InjectHyperparamsState(count_inc, hparams,
                                                   inner_state)
Exemple #9
0
 def update_fn(updates, state, params=None):
     """Apply `_bias_correction` to the updates using the incremented count.

     Args:
       updates: pytree of update arrays.
       state: current state; only `state.count` is read.
       params: unused, accepted for interface compatibility.

     Returns:
       A `(corrected_updates, new_state)` pair, the state holding the
       incremented count.
     """
     del params
     next_count = utils.safe_int32_increment(state.count)
     next_state = BiasCorrectionState(count=next_count)
     corrected = _bias_correction(updates, decay, next_count)
     return corrected, next_state