Example #1
0
 def test_global_norm(self):
   """global_norm of a pytree equals the Euclidean norm of its flattened leaves."""
   tree = dict(
       a=jnp.array([2., 4.], dtype=jnp.float32),
       b=jnp.array([3., 5.], dtype=jnp.float32))
   all_values = jnp.array([2., 4., 3., 5.], dtype=jnp.float32)
   expected = jnp.sqrt(jnp.sum(all_values**2))
   np.testing.assert_array_equal(
       expected, linear_algebra.global_norm(tree))
Example #2
0
 def update_fn(updates, state, params=None):
   """Rescale `updates` so that their global norm does not exceed `max_norm`.

   Leaves whose tree already has global norm below `max_norm` are returned
   unchanged; otherwise every leaf is scaled by `max_norm / g_norm`.
   `state` is passed through untouched.
   """
   del params  # Unused by this transformation.
   g_norm = linear_algebra.global_norm(updates)
   # TODO(b/163995078): revert back to the following (faster) implementation
   # once analysed how it affects backprop through update (e.g. meta-gradients)
   # g_norm = jnp.maximum(max_norm, g_norm)
   # updates = jax.tree_map(lambda t: (t / g_norm) * max_norm, updates)
   within_bound = g_norm < max_norm

   def clip_leaf(t):
     # Keep the leaf as-is when the norm is within bound; otherwise project
     # the whole tree onto the ball of radius max_norm.
     return jnp.where(within_bound, t, (t / g_norm) * max_norm)

   return jax.tree_map(clip_leaf, updates), state
Example #3
0
 def test_clip_by_global_norm(self):
     updates = self.per_step_updates
     for i in range(1, STEPS + 1):
         clipper = clipping.clip_by_global_norm(1. / i)
         # Check that the clipper actually works and global norm is <= max_norm
         updates, _ = clipper.update(updates, None)
         self.assertAlmostEqual(linear_algebra.global_norm(updates),
                                1. / i,
                                places=6)
         # Check that continuously clipping won't cause numerical issues.
         updates_step, _ = clipper.update(self.per_step_updates, None)
         chex.assert_tree_all_close(updates, updates_step)