예제 #1
0
  def check_split_latent_conditioning(self, merge_std):
    """Checks `glow_ops.split` conditioned on a latent via a pointwise encoder.

    At initialization the split is expected to return the second half of the
    channels (`x2`) unchanged and `eps == x2 - latent`. Running the split in
    reverse with the same `eps` and conditioning latent must reconstruct the
    input up to 1e-5.

    Args:
      merge_std: assigned to `hparams.level_scale` and also used as the
        variable-scope name passed to `glow_ops.split`. NOTE(review):
        presumably one of the level_scale modes glow_ops supports -- confirm
        against the calling test methods.
    """
    with tf.Graph().as_default():
      rng = np.random.RandomState(0)
      # The split factors out the second half of the channels (x2), so the
      # conditioning latent carries half as many channels as the input.
      num_channels = 32
      half = num_channels // 2
      x_rand = rng.randn(12, 32, 32, num_channels).astype(np.float32)
      latent_rand = rng.randn(12, 32, 32, half).astype(np.float32)
      x_t = tf.convert_to_tensor(x_rand)
      latent_t = tf.convert_to_tensor(latent_rand)
      hparams = glow.glow_hparams()
      hparams.level_scale = merge_std
      hparams.add_hparam("latent_dist_encoder", "pointwise")

      # Test initialization.
      # x2 ~ N(scale * latent, 1.0) where the initial scale is 1.0, so eps is
      # expected to equal x2 - latent.
      exp_x2 = x_rand[:, :, :, half:]
      exp_eps = exp_x2 - latent_rand
      x_inv, _, eps, x2_t, _ = glow_ops.split(
          merge_std, x_t, cond_latents=latent_t, hparams=hparams,
          condition=True)
      # Test reversibility.
      x_inv_inv, _, _ = glow_ops.split(
          merge_std, x_inv, cond_latents=latent_t, eps=eps, reverse=True,
          hparams=hparams, condition=True)
      with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        actual_eps, actual_x2, diff_np = sess.run([eps, x2_t, x_inv_inv - x_t])
        # assert_allclose reports the mismatching elements on failure, unlike
        # assertTrue(np.allclose(...)). The rtol/atol values reproduce the
        # tolerances np.allclose applied.
        np.testing.assert_allclose(diff_np, 0.0, atol=1e-5)
        np.testing.assert_allclose(actual_eps, exp_eps, rtol=1e-5, atol=1e-8)
        np.testing.assert_allclose(actual_x2, exp_x2, rtol=1e-5, atol=1e-8)
예제 #2
0
  def check_split_latent_conditioning(self, merge_std):
    """Checks `glow_ops.split` conditioned on a latent via a pointwise encoder.

    At initialization the split is expected to return the second half of the
    channels (`x2`) unchanged and `eps == x2 - latent`. Running the split in
    reverse with the same `eps` and conditioning latent must reconstruct the
    input up to 1e-5.

    Args:
      merge_std: assigned to `hparams.level_scale` and also used as the
        variable-scope name passed to `glow_ops.split`. NOTE(review):
        presumably one of the level_scale modes glow_ops supports -- confirm
        against the calling test methods.
    """
    with tf.Graph().as_default():
      rng = np.random.RandomState(0)
      # The split factors out the second half of the channels (x2), so the
      # conditioning latent carries half as many channels as the input.
      num_channels = 32
      half = num_channels // 2
      x_rand = rng.randn(12, 32, 32, num_channels).astype(np.float32)
      latent_rand = rng.randn(12, 32, 32, half).astype(np.float32)
      x_t = tf.convert_to_tensor(x_rand)
      latent_t = tf.convert_to_tensor(latent_rand)
      hparams = glow.glow_hparams()
      hparams.level_scale = merge_std
      hparams.add_hparam("latent_dist_encoder", "pointwise")

      # Test initialization.
      # x2 ~ N(scale * latent, 1.0) where the initial scale is 1.0, so eps is
      # expected to equal x2 - latent.
      exp_x2 = x_rand[:, :, :, half:]
      exp_eps = exp_x2 - latent_rand
      x_inv, _, eps, x2_t, _ = glow_ops.split(
          merge_std, x_t, cond_latents=latent_t, hparams=hparams,
          condition=True)
      # Test reversibility.
      x_inv_inv, _, _ = glow_ops.split(
          merge_std, x_inv, cond_latents=latent_t, eps=eps, reverse=True,
          hparams=hparams, condition=True)
      with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        actual_eps, actual_x2, diff_np = sess.run([eps, x2_t, x_inv_inv - x_t])
        # assert_allclose reports the mismatching elements on failure, unlike
        # assertTrue(np.allclose(...)). The rtol/atol values reproduce the
        # tolerances np.allclose applied.
        np.testing.assert_allclose(diff_np, 0.0, atol=1e-5)
        np.testing.assert_allclose(actual_eps, exp_eps, rtol=1e-5, atol=1e-8)
        np.testing.assert_allclose(actual_x2, exp_x2, rtol=1e-5, atol=1e-8)
예제 #3
0
 def test_split(self):
   """Checks that split halves the channels and that reversing it recovers x.

   The forward pass maps a 32-channel input to a 16-channel `x_inv`; running
   the split in reverse with the returned `eps` must reconstruct the input up
   to 1e-5.
   """
   with tf.Graph().as_default():
     # Seeded so that a failure can be reproduced with the same input.
     x = tf.random_uniform(shape=(16, 5, 5, 32), seed=0)
     x_inv, _, eps = glow_ops.split("split", x)
     x_inv_inv = glow_ops.split("split", x_inv, reverse=True, eps=eps)
     with tf.Session() as session:
       session.run(tf.global_variables_initializer())
       # Both fetches come from a single run, so they share one draw of x.
       x_inv_np, diff = session.run([x_inv, x - x_inv_inv])
       self.assertEqual(x_inv_np.shape, (16, 5, 5, 16))
       # Reports the offending elements on failure, unlike
       # assertTrue(np.allclose(...)); same |diff| <= 1e-5 criterion.
       np.testing.assert_allclose(diff, 0.0, atol=1e-5)
예제 #4
0
 def test_split(self):
   """Checks split's output shapes and that the reverse pass recovers x.

   For a 32-channel input, both `x_inv` and `z` must carry 16 channels, and
   running the split in reverse with the returned `eps` must reconstruct the
   input up to 1e-5.
   """
   with tf.Graph().as_default():
     # Seeded so that a failure can be reproduced with the same input.
     x = tf.random_uniform(shape=(16, 5, 5, 32), seed=0)
     x_inv, _, eps, z, _ = glow_ops.split("split", x)
     x_inv_inv, _, _ = glow_ops.split("split", x_inv, reverse=True, eps=eps)
     with tf.Session() as session:
       session.run(tf.global_variables_initializer())
       # All fetches come from a single run, so they share one draw of x.
       x_inv_np, diff, z_np = session.run([x_inv, x - x_inv_inv, z])
       self.assertEqual(z_np.shape, (16, 5, 5, 16))
       self.assertEqual(x_inv_np.shape, (16, 5, 5, 16))
       # Reports the offending elements on failure, unlike
       # assertTrue(np.allclose(...)); same |diff| <= 1e-5 criterion.
       np.testing.assert_allclose(diff, 0.0, atol=1e-5)