def testUtilityClipGrads(self):
  g = (jnp.ones(2), (jnp.ones(3), jnp.ones(4)))
  norm = optimizers.l2_norm(g)

  # A threshold above the current norm leaves the gradients unchanged.
  ans = optimizers.clip_grads(g, 1.1 * norm)
  expected = g
  self.assertAllClose(ans, expected, check_dtypes=False)

  # A threshold below the current norm rescales the tree to that threshold.
  ans = optimizers.l2_norm(optimizers.clip_grads(g, 0.9 * norm))
  expected = 0.9 * norm
  self.assertAllClose(ans, expected, check_dtypes=False)
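# A minimal standalone sketch of the clipping rule the test above exercises:
# scale the whole gradient pytree by max_norm / ||g|| when the global norm
# exceeds max_norm, and leave it untouched otherwise. This mirrors the
# behavior of optimizers.clip_grads as tested; `clip_by_global_norm_sketch`
# is a hypothetical name, not a library function.
import jax
import jax.numpy as jnp

def clip_by_global_norm_sketch(grad_tree, max_norm):
  leaves = jax.tree_util.tree_leaves(grad_tree)
  norm = jnp.sqrt(sum(jnp.vdot(g, g) for g in leaves))
  scale = jnp.where(norm > max_norm, max_norm / norm, 1.0)
  return jax.tree_util.tree_map(lambda g: g * scale, grad_tree)

# Example: this pytree has global norm 3, so clipping to 1.0 rescales it.
g = (jnp.ones(2), (jnp.ones(3), jnp.ones(4)))
clipped = clip_by_global_norm_sketch(g, 1.0)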
def update(self, params, opt_state, batch: util.Transition):
  """Computes gradients on `batch`, applies one optimizer step, and logs norms."""
  # Differentiate the loss; `has_aux=True` threads the auxiliary logs through.
  (_, logs), grads = jax.value_and_grad(self._loss, has_aux=True)(params, batch)
  grad_norm_unclipped = optimizers.l2_norm(grads)
  # Transform the raw gradients into updates and apply them to the parameters.
  updates, updated_opt_state = self._opt.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  weight_norm = optimizers.l2_norm(params)
  logs.update({
      'grad_norm_unclipped': grad_norm_unclipped,
      'weight_norm': weight_norm,
  })
  return params, updated_opt_state, logs
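# A self-contained sketch of the same update pattern, assuming `self._opt` is
# an optax.GradientTransformation and `self._loss` returns a (loss, logs)
# pair. The quadratic `loss_fn` and the names below are illustrative only.
import jax
import jax.numpy as jnp
import optax
from jax.example_libraries import optimizers

def loss_fn(params, batch):
  loss = jnp.sum((params - batch) ** 2)
  return loss, {'loss': loss}  # aux logs, as has_aux=True expects

opt = optax.adam(1e-3)
params = jnp.zeros(3)
opt_state = opt.init(params)
batch = jnp.ones(3)

(_, logs), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, batch)
updates, opt_state = opt.update(grads, opt_state)
params = optax.apply_updates(params, updates)
logs['grad_norm_unclipped'] = optimizers.l2_norm(grads)
logs['weight_norm'] = optimizers.l2_norm(params)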
def testUtilityNorm(self):
  x0 = (jnp.ones(2), (jnp.ones(3), jnp.ones(4)))
  norm = optimizers.l2_norm(x0)
  # The pytree holds 2 + 3 + 4 = 9 ones, so its global norm is sqrt(9) = 3.
  expected = np.sqrt(np.sum(np.ones(2 + 3 + 4)**2))
  self.assertAllClose(norm, expected, check_dtypes=False)
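# A minimal sketch of the global-norm computation this test checks: flatten
# the pytree and take the l2 norm of all leaves as one vector. The name
# `l2_norm_sketch` is hypothetical; the test targets optimizers.l2_norm.
import jax
import jax.numpy as jnp

def l2_norm_sketch(tree):
  leaves = jax.tree_util.tree_leaves(tree)
  return jnp.sqrt(sum(jnp.sum(leaf ** 2) for leaf in leaves))

x0 = (jnp.ones(2), (jnp.ones(3), jnp.ones(4)))
assert jnp.allclose(l2_norm_sketch(x0), 3.0)  # sqrt(2 + 3 + 4)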