def clip_by_global_norm(max_norm) -> GradientTransformation:
  """Clip updates using their global norm.

  References:
    [Pascanu et al, 2012](https://arxiv.org/abs/1211.5063)

  Args:
    max_norm: the maximum global norm for an update.

  Returns:
    An (init_fn, update_fn) tuple.
  """

  def init_fn(_):
    return ClipByGlobalNormState()

  def update_fn(updates, state, params=None):
    del params
    g_norm = global_norm(updates)
    trigger = g_norm < max_norm
    updates = jax.tree_map(
        lambda t: jnp.where(trigger, t, (t / g_norm) * max_norm), updates)
    return updates, state

  return GradientTransformation(init_fn, update_fn)
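
# Illustrative usage sketch (not part of the library code above): the transform
# is returned as an (init_fn, update_fn) pair, so it can be unpacked and
# applied directly. The `params`/`grads` pytrees below are hypothetical; the
# gradients deliberately exceed `max_norm` so the rescaling branch of
# `jnp.where` fires.
def _clip_by_global_norm_example():
  params = {'w': jnp.ones((3, 2)), 'b': jnp.zeros((2,))}
  grads = jax.tree_map(lambda p: 10.0 * jnp.ones_like(p), params)
  init_fn, update_fn = clip_by_global_norm(max_norm=1.0)
  state = init_fn(params)
  # Clipping keeps no statistics, so `state` is passed through unchanged.
  clipped, state = update_fn(grads, state)
  return clipped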
def sgld_gradient_update(step_size_fn,
                         seed,
                         momentum_decay=0.,
                         preconditioner=None):
  """Optax implementation of the SGLD optimizer.

  If momentum_decay is set to zero, we get the SGLD method [1]. Otherwise,
  we get the underdamped SGLD (SGHMC) method [2].

  Args:
    step_size_fn: a function taking training step as input and producing the
      step size as output.
    seed: int, random seed.
    momentum_decay: float, momentum decay parameter (default: 0).
    preconditioner: Preconditioner, an object representing the preconditioner
      or None; if None, identity preconditioner is used (default: None).

  [1] "Bayesian Learning via Stochastic Gradient Langevin Dynamics"
      Max Welling, Yee Whye Teh; ICML 2011
  [2] "Stochastic Gradient Hamiltonian Monte Carlo"
      Tianqi Chen, Emily B. Fox, Carlos Guestrin; ICML 2014
  """
  if preconditioner is None:
    preconditioner = get_identity_preconditioner()

  def init_fn(params):
    return OptaxSGLDState(
        count=jnp.zeros([], jnp.int32),
        rng_key=jax.random.PRNGKey(seed),
        momentum=jax.tree_map(jnp.zeros_like, params),
        preconditioner_state=preconditioner.init(params))

  def update_fn(gradient, state, params=None):
    del params
    lr = step_size_fn(state.count)
    lr_sqrt = jnp.sqrt(lr)
    noise_std = jnp.sqrt(2 * (1 - momentum_decay))

    preconditioner_state = preconditioner.update_preconditioner(
        gradient, state.preconditioner_state)

    noise, new_key = tree_utils.normal_like_tree(gradient, state.rng_key)
    noise = preconditioner.multiply_by_m_sqrt(noise, preconditioner_state)

    def update_momentum(m, g, n):
      return momentum_decay * m + g * lr_sqrt + n * noise_std

    momentum = jax.tree_map(update_momentum, state.momentum, gradient, noise)
    updates = preconditioner.multiply_by_m_inv(momentum, preconditioner_state)
    updates = jax.tree_map(lambda m: m * lr_sqrt, updates)
    return updates, OptaxSGLDState(
        count=state.count + 1,
        rng_key=new_key,
        momentum=momentum,
        preconditioner_state=preconditioner_state)

  return GradientTransformation(init_fn, update_fn)
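
# Illustrative wiring sketch (not part of the library code above): one way to
# drive the SGLD transform for a single step. `log_prob_fn`, `params`, and
# `batch` are hypothetical placeholders; `log_prob_fn` is assumed to return an
# (unnormalised) log-posterior, whose gradient SGLD follows directly, which is
# why the updates are simply added to the parameters. A constant step size and
# zero momentum_decay give vanilla SGLD.
def _sgld_example_step(log_prob_fn, params, batch):
  init_fn, update_fn = sgld_gradient_update(step_size_fn=lambda step: 1e-5,
                                            seed=0)
  opt_state = init_fn(params)
  grads = jax.grad(log_prob_fn)(params, batch)
  updates, opt_state = update_fn(grads, opt_state)
  new_params = jax.tree_map(lambda p, u: p + u, params, updates)
  return new_params, opt_state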
def additive_weight_decay(weight_decay: float = 0.0) -> GradientTransformation:
  """Adds the parameter scaled by `weight_decay` to each update.

  Only parameters with more than one dimension are decayed (i.e. LayerNorm
  scales, biases, etc. are excluded).

  Args:
    weight_decay: a scalar weight decay rate.

  Returns:
    An (init_fn, update_fn) tuple.
  """

  def init_fn(_):
    return AdditiveWeightDecayState()

  def update_fn(updates, state, params):
    # `(len(g.shape) > 1)` evaluates to 0 for one-dimensional leaves, so only
    # multi-dimensional parameters receive the weight decay term.
    updates = jax.tree_map(
        lambda g, p: g + weight_decay * p * (len(g.shape) > 1),
        updates, params)
    return updates, state

  return GradientTransformation(init_fn, update_fn)
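
# Illustrative usage sketch (not part of the library code above): this
# transform needs the current parameters, so they must be passed to
# `update_fn`. The pytrees below are hypothetical; the one-dimensional 'b'
# leaf is left untouched, while the two-dimensional 'w' leaf receives the
# weight_decay * param term.
def _additive_weight_decay_example():
  params = {'w': jnp.ones((4, 4)), 'b': jnp.zeros((4,))}
  grads = jax.tree_map(jnp.ones_like, params)
  init_fn, update_fn = additive_weight_decay(weight_decay=1e-4)
  state = init_fn(params)
  decayed, state = update_fn(grads, state, params)
  return decayed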