Пример #1
0
    def test_stable_global_norm_avoids_overflow(self):
        """Check the stable global norm is finite where `global_norm` is inf.

        The input list holds a tensor of ones, a tensor of ones scaled by
        1e19, and a `None` entry.
        """
        unit_tensor = array_ops.ones([4])
        # NOTE(review): 1e19 presumably pushes the sum of squares past the
        # float32 range -- confirm the default dtype of `array_ops.ones`.
        huge_tensor = array_ops.ones([4, 4]) * 1e19
        inputs = [unit_tensor, huge_tensor, None]

        plain_norm = clip_ops.global_norm(inputs)
        plain_overflows = math_ops.is_inf(plain_norm)
        stable_norm = tfgan_losses._numerically_stable_global_norm(inputs)
        stable_overflows = math_ops.is_inf(stable_norm)

        with self.test_session(use_gpu=True):
            # The plain norm must be inf; the stable variant must not be.
            self.assertTrue(plain_overflows.eval())
            self.assertFalse(stable_overflows.eval())
Пример #2
0
  def test_stable_global_norm_avoids_overflow(self):
    """The stable global norm must not be inf when `global_norm` is inf."""
    # A tensor of ones, a tensor of ones scaled by 1e19, and a `None` entry.
    inputs = [
        array_ops.ones([4]),
        array_ops.ones([4, 4]) * 1e19,
        None,
    ]

    naive_norm = clip_ops.global_norm(inputs)
    naive_is_inf = math_ops.is_inf(naive_norm)
    robust_norm = tfgan_losses._numerically_stable_global_norm(inputs)
    robust_is_inf = math_ops.is_inf(robust_norm)

    with self.test_session(use_gpu=True):
      self.assertTrue(naive_is_inf.eval())
      self.assertFalse(robust_is_inf.eval())
Пример #3
0
  def test_stable_global_norm_unchanged(self):
    """Verify the stable global norm agrees with `clip_ops.global_norm`.

    Both norms are computed over the same list of random tensors and must
    be within 1e-5 of each other on each of ten session runs.
    """
    random_seed.set_random_seed(1234)

    # One tensor per rank from 0 through 5, every dimension of size 3.
    samples = []
    for rank in range(6):
      samples.append(random_ops.random_uniform([3] * rank, -10.0, 10.0))

    reference = clip_ops.global_norm(samples)
    stable = tfgan_losses._numerically_stable_global_norm(samples)
    fetches = [reference, stable]

    with self.test_session(use_gpu=True) as session:
      # Spot check closeness on more than one sample.
      for _ in range(10):
        reference_value, stable_value = session.run(fetches)
        self.assertNear(reference_value, stable_value, 1e-5)
Пример #4
0
  def test_stable_global_norm_unchanged(self):
    """Preconditioning must leave the global norm value unchanged."""
    random_seed.set_random_seed(1234)

    # Shapes [], [3], [3, 3], ... up to five dimensions of size 3.
    shapes = [[3] * ndims for ndims in range(6)]
    random_inputs = [
        random_ops.random_uniform(shape, -10.0, 10.0) for shape in shapes
    ]
    plain = clip_ops.global_norm(random_inputs)
    preconditioned = tfgan_losses._numerically_stable_global_norm(
        random_inputs)

    tolerance = 1e-5
    num_checks = 10  # Compare on more than one sample.
    with self.test_session(use_gpu=True) as session:
      for _ in range(num_checks):
        results = session.run([plain, preconditioned])
        plain_np, preconditioned_np = results
        self.assertNear(plain_np, preconditioned_np, tolerance)