def test_full_norm_deals_with_empty_input_gracefully(self):
     """full_norm of a missing or empty gradient collection is 0."""
     # None, an empty list and an empty tuple must all be accepted and
     # yield a norm of exactly 0.
     for empty_parts in (None, [], ()):
         self.assertEqual(0, full_norm(empty_parts))
 def assert_clipping_results(self, gradient_parts, clipped_gradient_parts, clip_threshold):
     """Assert that ``clipped_gradient_parts`` is a valid clipping of ``gradient_parts``.

     Checks that clipping preserved the gradient direction, that the clipped
     norm does not exceed ``clip_threshold``, and that it is not larger than
     the original norm.

     Args:
         gradient_parts: the unclipped gradient parts.
         clipped_gradient_parts: the result of clipping ``gradient_parts``.
         clip_threshold: the norm bound that was passed to the clipping call.
     """
     self.assert_gradient_direction(gradient_parts, clipped_gradient_parts)
     norm_clipped = full_norm(clipped_gradient_parts)
     norm = full_norm(gradient_parts)
     # NOTE(review): exact float comparison; if clip_gradient rescales to
     # exactly clip_threshold, rounding could push the norm marginally above
     # it — consider a small tolerance if this ever flakes.
     self.assertLessEqual(norm_clipped, clip_threshold)
     # Clipping may only shrink a gradient, never grow it.
     self.assertLessEqual(norm_clipped, norm)
 def assert_gradient_direction(self, expected_gradient_parts, actual_gradient_parts):
     """Assert two gradient collections have the same structure and direction.

     Both collections must have the same number of parts with matching
     shapes. Each part is divided by the full norm of its own collection
     before comparison, so the check ignores overall scale.
     """
     self.assertEqual(len(expected_gradient_parts), len(actual_gradient_parts))
     expected_scale = full_norm(expected_gradient_parts)
     actual_scale = full_norm(actual_gradient_parts)
     paired_parts = zip(expected_gradient_parts, actual_gradient_parts)
     for expected_part, actual_part in paired_parts:
         self.assertEqual(expected_part.shape, actual_part.shape)
         directions_match = jnp.allclose(
             expected_part / expected_scale,
             actual_part / actual_scale,
         )
         self.assertTrue(directions_match)
 def test_full_norm_on_jax_tree(self):
     """full_norm handles nested tuples of arrays, including an empty subtree."""
     gradient_tree = (
         jnp.ones(shape=(17, 2, 3)),
         jnp.ones(shape=(2, 54)),
         (jnp.ones(shape=(2, 3)), jnp.ones(shape=(3, 4, 5))),
         (),
     )
     norm = full_norm(gradient_tree)
     # Every entry is 1, so the squared norm equals the total element count
     # (102 + 108 + 6 + 60 = 276); the empty tuple contributes nothing.
     # sqrt(276) ~= 16.613247.
     element_count = 17 * 2 * 3 + 2 * 54 + 2 * 3 + 3 * 4 * 5
     expected_norm = jnp.sqrt(float(element_count))
     self.assertTrue(jnp.allclose(expected_norm, norm))
 def test_clip_gradient_gives_input_when_threshold_exceeds_norm(self):
     """A threshold above the gradient norm leaves the gradient unchanged."""
     # Twice the full norm is safely above it, so no clipping should occur.
     generous_threshold = 2 * full_norm(self.gradient_parts)
     result_parts = clip_gradient(self.gradient_parts, generous_threshold)
     self.assert_array_tuple_close(self.gradient_parts, result_parts)
 def test_clip_gradient_clips_when_threshold_is_less_than_norm(self):
     """A threshold below the gradient norm produces a valid clipped gradient."""
     # A tenth of the full norm forces clip_gradient to actually rescale.
     tight_threshold = 0.1 * full_norm(self.gradient_parts)
     result_parts = clip_gradient(self.gradient_parts, tight_threshold)
     self.assert_clipping_results(self.gradient_parts, result_parts, tight_threshold)
 def test_full_norm_is_correct(self):
     """full_norm equals sqrt of the summed squares over all gradient parts."""
     actual_norm = full_norm(self.gradient_parts)
     # Reference value: sum each part's squared entries, then add them up.
     squared_part_norms = [jnp.sum(jnp.square(part)) for part in self.gradient_parts]
     reference_norm = jnp.sqrt(jnp.sum(jnp.array(squared_part_norms)))
     self.assertTrue(jnp.allclose(reference_norm, actual_norm))
 def test_normalize_gradient(self):
     """normalize_gradient keeps the gradient direction and yields unit norm."""
     unit_parts = normalize_gradient(self.gradient_parts)
     self.assert_gradient_direction(self.gradient_parts, unit_parts)
     self.assertTrue(jnp.allclose(1., full_norm(unit_parts)))