예제 #1
0
    def test_stable_global_norm_avoids_overflow(self):
        """Check the stable global norm is finite where the plain one is inf.

        The input list mixes a tensor of ones, a tensor scaled by 1e19 and a
        ``None`` entry. The test expects ``tf.linalg.global_norm`` to report
        ``inf`` for this input while ``numerically_stable_global_norm`` does
        not.
        """
        inputs = [tf.ones([4]), tf.ones([4, 4]) * 1e19, None]

        # Build the "is infinite?" predicate for each norm implementation.
        naive_norm = tf.linalg.global_norm(inputs)
        naive_is_inf = tf.math.is_inf(naive_norm)
        stable_norm = numerically_stable_global_norm(inputs)
        stable_is_inf = tf.math.is_inf(stable_norm)

        with self.cached_session() as sess:
            # The plain norm is expected to blow up to inf ...
            self.assertTrue(sess.run(naive_is_inf))
            # ... while the stable variant must stay finite.
            self.assertFalse(sess.run(stable_is_inf))
예제 #2
0
    def test_stable_global_norm_unchanged(self):
        """Check the stable global norm matches ``tf.linalg.global_norm``.

        Builds six uniform random tensors in [-10, 10) with shapes
        ``[]``, ``[3]``, ``[3, 3]``, ... up to rank 5, and asserts that both
        norm implementations agree to within 1e-4.
        """
        num_samples = 10
        tolerance = 1e-4

        tf.compat.v1.set_random_seed(1234)

        # One tensor per rank 0..5, every dimension of size 3.
        random_inputs = []
        for rank in range(6):
            shape = [3] * rank
            random_inputs.append(
                tf.random.uniform(shape, minval=-10.0, maxval=10.0))

        reference_norm = tf.linalg.global_norm(random_inputs)
        stable_norm = numerically_stable_global_norm(random_inputs)

        with self.cached_session() as sess:
            # Repeat the comparison so more than one sample is spot-checked.
            for _ in range(num_samples):
                reference_np, stable_np = sess.run(
                    [reference_norm, stable_norm])
                self.assertNear(reference_np, stable_np, tolerance)