Code Example #1
File: util.py  Project: jeffhsu3/mesh-transformer-jax
import jax
import jax.numpy as jnp
from optax import ClipByGlobalNormState, GradientTransformation, global_norm


def clip_by_global_norm(max_norm) -> GradientTransformation:
    """Clip updates using their global norm.

    References:
      [Pascanu et al, 2012](https://arxiv.org/abs/1211.5063)

    Args:
      max_norm: the maximum global norm for an update.

    Returns:
      An (init_fn, update_fn) tuple.
    """

    def init_fn(_):
        return ClipByGlobalNormState()

    def update_fn(updates, state, params=None):
        del params
        g_norm = global_norm(updates)
        # Leave updates unchanged while their global norm is below max_norm;
        # otherwise rescale every leaf so the global norm equals max_norm.
        trigger = g_norm < max_norm
        updates = jax.tree_util.tree_map(
            lambda t: jnp.where(trigger, t, (t / g_norm) * max_norm), updates)
        return updates, state

    return GradientTransformation(init_fn, update_fn)
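
A minimal usage sketch (not from the project; the toy parameters, gradients,
and the chained optax.scale step are assumptions for illustration):

import jax.numpy as jnp
import optax

# Hypothetical toy parameters and deliberately large gradients.
params = {"w": jnp.ones((3, 3)), "b": jnp.zeros(3)}
grads = {"w": jnp.full((3, 3), 10.0), "b": jnp.full((3,), 10.0)}

# Chain the clipping transformation with a learning-rate scale.
tx = optax.chain(clip_by_global_norm(1.0), optax.scale(-0.1))
opt_state = tx.init(params)
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)

Here the global norm of the gradients far exceeds 1.0, so the clipping branch
fires and the updates are rescaled before the learning-rate scale is applied.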
Code Example #2
File: sgmcmc.py  Project: kokizzu/google-research
import jax
import jax.numpy as jnp
from optax import GradientTransformation

# OptaxSGLDState, get_identity_preconditioner, and tree_utils are
# project-local helpers defined elsewhere in this repository.


def sgld_gradient_update(step_size_fn,
                         seed,
                         momentum_decay=0.,
                         preconditioner=None):
    """Optax implementation of the SGLD optimizer.

  If momentum_decay is set to zero, we get the SGLD method [1]. Otherwise,
  we get the underdamped SGLD (SGHMC) method [2].

  Args:
    step_size_fn: a function taking training step as input and producing the
      step size as output.
    seed: int, random seed.
    momentum_decay: float, momentum decay parameter (default: 0).
    preconditioner: Preconditioner, an object representing the preconditioner
      or None; if None, identity preconditioner is used (default: None).  [1]
        "Bayesian Learning via Stochastic Gradient Langevin Dynamics" Max
        Welling, Yee Whye Teh; ICML 2011  [2] "Stochastic Gradient Hamiltonian
        Monte Carlo" Tianqi Chen, Emily B. Fox, Carlos Guestrin; ICML 2014
  """

    if preconditioner is None:
        preconditioner = get_identity_preconditioner()

    def init_fn(params):
        return OptaxSGLDState(
            count=jnp.zeros([], jnp.int32),
            rng_key=jax.random.PRNGKey(seed),
            momentum=jax.tree_util.tree_map(jnp.zeros_like, params),
            preconditioner_state=preconditioner.init(params))

    def update_fn(gradient, state, params=None):
        del params
        lr = step_size_fn(state.count)
        lr_sqrt = jnp.sqrt(lr)
        # Standard deviation of the injected Gaussian noise; with
        # momentum_decay = 0 this reduces to the plain SGLD noise scale.
        noise_std = jnp.sqrt(2 * (1 - momentum_decay))

        preconditioner_state = preconditioner.update_preconditioner(
            gradient, state.preconditioner_state)

        noise, new_key = tree_utils.normal_like_tree(gradient, state.rng_key)
        noise = preconditioner.multiply_by_m_sqrt(noise, preconditioner_state)

        def update_momentum(m, g, n):
            return momentum_decay * m + g * lr_sqrt + n * noise_std

        momentum = jax.tree_util.tree_map(update_momentum, state.momentum,
                                          gradient, noise)
        updates = preconditioner.multiply_by_m_inv(momentum,
                                                   preconditioner_state)
        updates = jax.tree_util.tree_map(lambda m: m * lr_sqrt, updates)
        return updates, OptaxSGLDState(
            count=state.count + 1,
            rng_key=new_key,
            momentum=momentum,
            preconditioner_state=preconditioner_state)

    return GradientTransformation(init_fn, update_fn)
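
A hedged usage sketch (the constant step-size schedule, toy parameters, and
toy gradient function are assumptions, not part of the project):

import jax
import jax.numpy as jnp
import optax

params = {"w": jnp.zeros(2)}                        # hypothetical parameters
grad_fn = jax.grad(lambda p: jnp.sum(p["w"] ** 2))  # toy gradient function

tx = sgld_gradient_update(step_size_fn=lambda step: 1e-3,
                          seed=0,
                          momentum_decay=0.9)
opt_state = tx.init(params)
for _ in range(10):
    updates, opt_state = tx.update(grad_fn(params), opt_state)
    params = optax.apply_updates(params, updates)

Whether the gradient should be taken of the loss or of the log-posterior
depends on the sign convention of the surrounding training loop.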
Code Example #3
File: util.py  Project: jeffhsu3/mesh-transformer-jax
import jax
from optax import GradientTransformation

# AdditiveWeightDecayState is a project-local (empty) state container
# defined alongside this function in util.py.


def additive_weight_decay(weight_decay: float = 0.0) -> GradientTransformation:
    """Add parameter scaled by `weight_decay`, to all parameters with more than one dim (i.e. exclude ln, bias etc)

    Args:
      weight_decay: a scalar weight decay rate.

    Returns:
      An (init_fn, update_fn) tuple.
    """

    def init_fn(_):
        return AdditiveWeightDecayState()

    def update_fn(updates, state, params):
        # Decay only parameters with more than one dimension; biases and
        # LayerNorm parameters (ndim == 1) are skipped. `jax.tree_multimap`
        # has been removed from JAX; `jax.tree_util.tree_map` accepts
        # multiple trees.
        updates = jax.tree_util.tree_map(
            lambda g, p: g + weight_decay * p * (g.ndim > 1),
            updates, params)
        return updates, state

    return GradientTransformation(init_fn, update_fn)
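
As above, a small hedged usage sketch (values invented). Note that this
transformation's update_fn requires params, so they must be passed at update
time:

import jax.numpy as jnp
import optax

params = {"w": jnp.ones((2, 2)), "b": jnp.ones(2)}  # hypothetical
grads = {"w": jnp.zeros((2, 2)), "b": jnp.zeros(2)}

tx = optax.chain(additive_weight_decay(1e-4), optax.scale(-0.1))
opt_state = tx.init(params)
updates, opt_state = tx.update(grads, opt_state, params)  # params required
params = optax.apply_updates(params, updates)
# Only "w" (ndim > 1) is decayed; the one-dimensional "b" is left untouched.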