Ejemplo n.º 1
0
    def update(updates, state, params=None):
        """Forward the updates to the wrapped transform only when they are finite.

        Args:
            updates: pytree of update leaves; every leaf is tested with
                ``jnp.isfinite``.
            state: state object exposing ``inner_state``, ``notfinite_count``
                and ``total_notfinite``.
            params: optional value passed through unchanged to
                ``inner.update``.

        Returns:
            A ``(updates, ApplyIfFiniteState)`` pair. The updates come from
            ``inner.update`` when all leaves are finite or when the
            consecutive non-finite counter exceeds
            ``max_consecutive_errors`` (taken from the enclosing scope);
            otherwise they are zeros shaped like the input and the inner
            state is left untouched.
        """
        wrapped_state = state.inner_state
        leaves, _ = tree_flatten(updates)
        per_leaf_ok = [jnp.all(jnp.isfinite(leaf)) for leaf in leaves]
        all_finite = jnp.all(jnp.array(per_leaf_ok))

        # Consecutive counter: zero after a finite step, incremented otherwise.
        streak = jnp.where(
            all_finite,
            jnp.zeros([], jnp.int32),
            numerics.safe_int32_increment(state.notfinite_count))

        # Running total: unchanged on a finite step, incremented otherwise.
        running_total = jnp.where(
            all_finite,
            state.total_notfinite,
            numerics.safe_int32_increment(state.total_notfinite))

        def _accept(_):
            return inner.update(updates, wrapped_state, params)

        def _reject(_):
            zeroed = tree_map(jnp.zeros_like, updates)
            return (zeroed, wrapped_state)

        should_apply = jnp.logical_or(
            all_finite, streak > max_consecutive_errors)
        new_updates, next_inner_state = lax.cond(
            should_apply, _accept, _reject, operand=None)

        next_state = ApplyIfFiniteState(
            notfinite_count=streak,
            last_finite=all_finite,
            total_notfinite=running_total,
            inner_state=next_inner_state)
        return new_updates, next_state
Ejemplo n.º 2
0
 def update_fn(updates, state, params=None):
     """Multiply every update leaf by the step size for the current count.

     Args:
         updates: pytree of update leaves.
         state: state object exposing ``count``.
         params: unused; deleted immediately.

     Returns:
         A ``(updates, ScaleByScheduleState)`` pair where each leaf has been
         multiplied by ``step_size_fn(state.count)`` cast to the leaf's dtype,
         and the new state carries the incremented count.
     """
     del params
     factor = step_size_fn(state.count)

     def _rescale(leaf):
         # Cast the factor to the leaf dtype before multiplying.
         return jnp.array(factor, dtype=leaf.dtype) * leaf

     rescaled = jax.tree_map(_rescale, updates)
     next_count = numerics.safe_int32_increment(state.count)
     return rescaled, ScaleByScheduleState(count=next_count)
Ejemplo n.º 3
0
 def update_fn(updates, state, params=None):
     """Accumulate updates and emit the accumulated value every ``k`` calls.

     Args:
         updates: pytree of update leaves for this call.
         state: state object exposing ``count`` and ``grad_acc``.
         params: unused; deleted immediately.

     Returns:
         A ``(updates, ApplyEvery)`` pair. The emitted leaves are the
         accumulator multiplied by a flag that is true only when
         ``state.count % k == k - 1``; the new state stores the incremented
         count modulo ``k`` and the accumulator.
     """
     del params
     phase = state.count % k
     keep_previous = phase != 0  # false on the first call of each cycle
     is_last = phase == (k - 1)  # true on the final call of each cycle

     def _accumulate(g, ga):
         # The previous accumulator is multiplied by the keep flag, not masked.
         return keep_previous * ga + g

     def _gate(ga):
         return is_last * ga

     running = jax.tree_multimap(_accumulate, updates, state.grad_acc)
     emitted = jax.tree_map(_gate, running)
     next_count = numerics.safe_int32_increment(state.count) % k
     return emitted, ApplyEvery(count=next_count, grad_acc=running)
Ejemplo n.º 4
0
 def update_fn(updates, state, params=None):
     """Rescale updates using bias-corrected first and second moments.

     Args:
         updates: pytree of update leaves.
         state: state object exposing ``count``, ``mu`` and ``nu``.
         params: unused; deleted immediately.

     Returns:
         A ``(updates, ScaleByAdamState)`` pair. Each output leaf is
         ``m / (sqrt(v + eps_root) + eps)`` computed from the bias-corrected
         moments; the state stores the uncorrected moments and the
         incremented count.
     """
     del params
     step = numerics.safe_int32_increment(state.count)
     first_moment = _update_moment(updates, state.mu, b1, 1)
     second_moment = _update_moment(updates, state.nu, b2, 2)

     def _direction(m, v):
         denominator = jnp.sqrt(v + eps_root) + eps
         return m / denominator

     scaled = jax.tree_multimap(
         _direction,
         _bias_correction(first_moment, b1, step),
         _bias_correction(second_moment, b2, step))
     next_state = ScaleByAdamState(
         count=step, mu=first_moment, nu=second_moment)
     return scaled, next_state
Ejemplo n.º 5
0
    def update_fn(updates, state, params=None):
      """Rebuild the inner transform with current hyperparameters and apply it.

      Args:
        updates: pytree of update leaves.
        state: state object exposing ``count``, ``hyperparams`` (a mapping)
          and ``inner_state``.
        params: optional value forwarded to the inner transform's ``update``.

      Returns:
        A ``(updates, InjectHyperparamsState)`` pair built from the
        incremented count, the hyperparameters used for this call and the
        inner transform's new state.
      """
      step = numerics.safe_int32_increment(state.count)

      # dtype of the first leaf; None when there are no leaves or the first
      # leaf has no ``dtype`` attribute.
      first_leaf = next(iter(jax.tree_leaves(updates)), None)
      dtype = getattr(first_leaf, 'dtype', None)

      hparams = {}
      for name, value in state.hyperparams.items():
        hparams[name] = _convert_floats(value, dtype)
      # Values returned by schedule_fn override stored ones with the same key.
      hparams.update(schedule_fn(step, dtype))

      inner_transform = inner_factory(**other_hps, **hparams)
      new_updates, new_inner_state = inner_transform.update(
          updates, state.inner_state, params)

      # pylint:disable=too-many-function-args
      return new_updates, InjectHyperparamsState(step, hparams, new_inner_state)
Ejemplo n.º 6
0
 def update_fn(updates, state, params=None):
     """Rescale updates, choosing between ``_radam_update`` and plain momentum.

     Args:
         updates: pytree of update leaves.
         state: state object exposing ``count``, ``mu`` and ``nu``.
         params: unused; deleted immediately.

     Returns:
         A ``(updates, ScaleByAdamState)`` pair. When ``ro >= threshold`` the
         updates come from ``_radam_update((ro, mu_hat, nu_hat))``; otherwise
         they are the bias-corrected first moment. The state stores the
         uncorrected moments and the incremented count.
     """
     del params
     step = numerics.safe_int32_increment(state.count)
     first_moment = _update_moment(updates, state.mu, b1, 1)
     second_moment = _update_moment(updates, state.nu, b2, 2)

     beta2_pow = b2**step
     ro = ro_inf - 2 * step * beta2_pow / (1 - beta2_pow)

     first_hat = _bias_correction(first_moment, b1, step)
     second_hat = _bias_correction(second_moment, b2, step)

     def _momentum_only(_):
         # Fallback branch: ignore the operand and use the corrected mean.
         return first_hat

     rescaled = jax.lax.cond(
         ro >= threshold,
         _radam_update,
         _momentum_only,
         (ro, first_hat, second_hat))
     next_state = ScaleByAdamState(
         count=step, mu=first_moment, nu=second_moment)
     return rescaled, next_state
Ejemplo n.º 7
0
 def final_step(args):
     """Call the wrapped optimizer on ``acc_grads`` and reset the accumulation.

     Args:
         args: unused; deleted immediately. ``acc_grads``, ``state``,
             ``params`` and ``self`` are read from the enclosing scope.

     Returns:
         A ``(updates, MultiStepsState)`` pair: the updates returned by
         ``self._opt.update``, and a state with ``mini_step`` set to an int32
         zero, ``gradient_step`` incremented, the new inner optimizer state,
         and ``acc_grads`` replaced by ``_zeros_tree_like(acc_grads)``.
     """
     del args
     emitted, inner_after = self._opt.update(
         acc_grads, state.inner_opt_state, params=params)
     next_gradient_step = numerics.safe_int32_increment(state.gradient_step)
     cleared_grads = _zeros_tree_like(acc_grads)
     return emitted, MultiStepsState(
         mini_step=jnp.zeros([], dtype=jnp.int32),
         gradient_step=next_gradient_step,
         inner_opt_state=inner_after,
         acc_grads=cleared_grads)
Ejemplo n.º 8
0
 def update_fn(updates, state, params=None):
   """Rescale updates by a second moment built from the prediction error.

   Args:
     updates: pytree of update leaves.
     state: state object exposing ``count``, ``mu`` and ``nu``.
     params: unused; deleted immediately.

   Returns:
     A ``(updates, ScaleByBeliefState)`` pair. Each output leaf is
     ``m / (sqrt(v) + eps)`` from the bias-corrected moments; the stored
     ``nu`` already includes the added ``eps_root``.
   """
   del params
   step = numerics.safe_int32_increment(state.count)
   first_moment = _update_moment(updates, state.mu, b1, 1)

   def _residual(g, m):
     return g-m

   # The error is measured against the previous first moment (state.mu),
   # not the freshly updated one.
   residuals = jax.tree_multimap(_residual, updates, state.mu)
   second_moment = _update_moment_per_elem_norm(residuals, state.nu, b2, 2)
   second_moment = jax.tree_map(lambda v: v + eps_root, second_moment)

   def _direction(m, v):
     return m / (jnp.sqrt(v) + eps)

   rescaled = jax.tree_multimap(
       _direction,
       _bias_correction(first_moment, b1, step),
       _bias_correction(second_moment, b2, step))
   next_state = ScaleByBeliefState(
       count=step, mu=first_moment, nu=second_moment)
   return rescaled, next_state
Ejemplo n.º 9
0
 def update_fn(updates, state, params=None):
     """Add scaled normal noise to every update leaf.

     Args:
         updates: pytree of update leaves.
         state: state object exposing ``count`` and ``rng_key``.
         params: unused; deleted immediately.

     Returns:
         A ``(updates, AddNoiseState)`` pair. Each leaf receives
         ``eta / count**gamma`` (cast to the leaf dtype) times a standard
         normal sample of the leaf's shape, where ``count`` is the
         incremented step. The new state keeps the first key produced by
         the split.
     """
     del params
     step = numerics.safe_int32_increment(state.count)
     # This value multiplies the noise sample directly.
     noise_scale = eta / step**gamma

     structure = jax.tree_structure(updates)
     leaf_count = len(jax.tree_leaves(updates))
     # One key per leaf, plus one carried over to the next call.
     keys = jax.random.split(state.rng_key, num=leaf_count + 1)
     carry_key = keys[0]
     key_tree = jax.tree_unflatten(structure, keys[1:])

     def _sample(g, key):
         return jax.random.normal(key, shape=g.shape, dtype=g.dtype)

     def _perturb(g, n):
         return g + noise_scale.astype(g.dtype) * n

     noise = jax.tree_multimap(_sample, updates, key_tree)
     noisy = jax.tree_multimap(_perturb, updates, noise)
     return noisy, AddNoiseState(count=step, rng_key=carry_key)
Ejemplo n.º 10
0
    def update_fn(updates, state, params=None):
        """Apply the inner transform only when ``should_update_fn`` says so.

        Args:
            updates: pytree of update leaves.
            state: state object exposing ``inner_state`` and ``step``.
            params: optional value forwarded to ``inner.update``.

        Returns:
            A ``(updates, MaybeUpdateState)`` pair. When
            ``should_update_fn(state.step)`` is true the result comes from
            ``inner.update``; otherwise the incoming updates and inner state
            are returned unchanged. The step is incremented in both cases.
        """
        current_inner = state.inner_state

        def _apply(_):
            return inner.update(updates, current_inner, params)

        def _skip(_):
            return updates, current_inner

        new_updates, next_inner = lax.cond(
            should_update_fn(state.step), _apply, _skip, operand=None)
        next_step = numerics.safe_int32_increment(state.step)
        return new_updates, MaybeUpdateState(next_inner, next_step)
Ejemplo n.º 11
0
 def mid_step(args):
     """Return zero updates and advance the mini-step counter.

     Args:
         args: unused; deleted immediately. ``acc_grads``, ``state``,
             ``params`` and ``self`` are read from the enclosing scope.

     Returns:
         A ``(updates, MultiStepsState)`` pair. The updates are zeros whose
         shapes and dtypes match what ``self._opt.update`` would return, as
         reported by ``jax.eval_shape``. The state keeps ``gradient_step``
         and the inner optimizer state, stores ``acc_grads`` and increments
         ``mini_step``.
     """
     del args
     abstract_updates, _ = jax.eval_shape(
         self._opt.update, acc_grads, state.inner_opt_state, params=params)

     def _zeros_for(spec):
         return jnp.zeros(spec.shape, spec.dtype)

     placeholder = jax.tree_map(_zeros_for, abstract_updates)
     next_mini_step = numerics.safe_int32_increment(state.mini_step)
     return placeholder, MultiStepsState(
         mini_step=next_mini_step,
         gradient_step=state.gradient_step,
         inner_opt_state=state.inner_opt_state,
         acc_grads=acc_grads)