def test_global_norm(self):
    """global_norm of a pytree equals the norm of its flattened leaves."""
    leaves = dict(
        a=jnp.array([2., 4.], dtype=jnp.float32),
        b=jnp.array([3., 5.], dtype=jnp.float32))
    flattened = jnp.array([2., 4., 3., 5.], dtype=jnp.float32)
    expected = jnp.sqrt(jnp.sum(flattened ** 2))
    np.testing.assert_array_equal(
        expected, linear_algebra.global_norm(leaves))
def update_fn(updates, state, params=None):
    """Rescales `updates` so their global norm never exceeds `max_norm`."""
    del params
    norm = linear_algebra.global_norm(updates)
    # TODO(b/163995078): revert back to the following (faster) implementation
    # once analysed how it affects backprop through update (e.g. meta-gradients)
    # g_norm = jnp.maximum(max_norm, g_norm)
    # updates = jax.tree_map(lambda t: (t / g_norm) * max_norm, updates)
    below_limit = norm < max_norm
    # Leaves are left untouched when already within the limit; otherwise each
    # leaf is scaled down by max_norm / norm.
    clip_leaf = lambda leaf: jnp.where(
        below_limit, leaf, (leaf / norm) * max_norm)
    clipped = jax.tree_map(clip_leaf, updates)
    return clipped, state
def test_clip_by_global_norm(self):
    """Repeated clipping hits each shrinking limit and stays numerically stable."""
    updates = self.per_step_updates
    for step in range(1, STEPS + 1):
        limit = 1. / step
        clipper = clipping.clip_by_global_norm(limit)
        # Check that the clipper actually works and global norm is <= max_norm
        updates, _ = clipper.update(updates, None)
        self.assertAlmostEqual(
            linear_algebra.global_norm(updates), limit, places=6)
        # Check that continuously clipping won't cause numerical issues.
        fresh_updates, _ = clipper.update(self.per_step_updates, None)
        chex.assert_tree_all_close(updates, fresh_updates)