def check_split_latent_conditioning(self, merge_std):
  """Checks glow_ops.split with a prior conditioned on `cond_latents`.

  Runs `split` forward with `condition=True` and compares the split-off
  channels and their `eps` against closed-form expectations at
  initialization. It then runs `split` in reverse with the same `eps` and
  checks that the original input is reconstructed.

  Args:
    merge_std: Passed to `glow_ops.split` as its first positional argument
      (presumably the variable-scope name, since other tests pass a string
      there; confirm). It is also assigned to `hparams.level_scale`.
  """
  with tf.Graph().as_default():
    rng = np.random.RandomState(0)
    x_rand = rng.randn(12, 32, 32, 32).astype(np.float32)
    latent_rand = rng.randn(12, 32, 32, 16).astype(np.float32)
    x_t = tf.convert_to_tensor(x_rand)
    latent_t = tf.convert_to_tensor(latent_rand)
    hparams = glow.glow_hparams()
    hparams.level_scale = merge_std
    hparams.add_hparam("latent_dist_encoder", "pointwise")

    # Test initialization.
    # x2 ~ N(scale * latent, 1.0) where initial scale is 1.0, so the expected
    # split-off channels are x[..., 16:] and the expected eps is x2 - latent.
    exp_x2 = x_rand[:, :, :, 16:]
    exp_eps = x_rand[:, :, :, 16:] - latent_rand
    x_inv, _, eps, x2_t, _ = glow_ops.split(
        merge_std, x_t, cond_latents=latent_t, hparams=hparams,
        condition=True)

    # Test reversibility: reversing with the same eps must rebuild x_t.
    x_inv_inv, _, _ = glow_ops.split(
        merge_std, x_inv, cond_latents=latent_t, eps=eps, reverse=True,
        hparams=hparams, condition=True)

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      actual_eps, actual_x2, diff_np = sess.run([eps, x2_t, x_inv_inv - x_t])

    # assert_allclose reports the mismatch count and max abs/rel difference
    # on failure, unlike assertTrue(np.allclose(...)). rtol/atol reproduce the
    # thresholds the original np.allclose calls used.
    np.testing.assert_allclose(diff_np, 0.0, rtol=1e-5, atol=1e-5)
    np.testing.assert_allclose(actual_eps, exp_eps, rtol=1e-5, atol=1e-8)
    np.testing.assert_allclose(actual_x2, exp_x2, rtol=1e-5, atol=1e-8)
def test_split(self):
  """Round-trips a tensor through glow_ops.split and checks the output shape.

  NOTE(review): another `test_split` is defined right after this one. If both
  are in the same class, Python keeps only the later definition, so this one
  never runs. Consider deleting it or giving it a distinct name.
  """
  with tf.Graph().as_default():
    x = tf.random_uniform(shape=(16, 5, 5, 32))
    # Forward split yields 5 values and reverse yields 3, the same arity the
    # other split calls in this file unpack.
    x_inv, _, eps, _, _ = glow_ops.split("split", x)
    x_inv_inv, _, _ = glow_ops.split("split", x_inv, reverse=True, eps=eps)
    with tf.Session() as session:
      session.run(tf.global_variables_initializer())
      # Single run so x and x_inv_inv see the same random_uniform sample.
      x_inv_np, diff = session.run([x_inv, x - x_inv_inv])
    # The forward pass keeps 16 of the 32 input channels.
    self.assertEqual(x_inv_np.shape, (16, 5, 5, 16))
    # Reports max abs diff on failure, unlike assertTrue(np.allclose(...)).
    np.testing.assert_allclose(diff, 0.0, atol=1e-5)
def test_split(self):
  """Round-trips a tensor through glow_ops.split and checks output shapes.

  The forward pass on a 32-channel input must return `x_inv` and `z`, each
  with 16 channels. Reversing with the returned `eps` must reconstruct the
  input to within an absolute tolerance of 1e-5.
  """
  with tf.Graph().as_default():
    x = tf.random_uniform(shape=(16, 5, 5, 32))
    x_inv, _, eps, z, _ = glow_ops.split("split", x)
    x_inv_inv, _, _ = glow_ops.split("split", x_inv, reverse=True, eps=eps)
    with tf.Session() as session:
      session.run(tf.global_variables_initializer())
      # Fetch everything in one run so x, x_inv and z all derive from the
      # same random_uniform sample.
      x_inv_np, diff, z_np = session.run([x_inv, x - x_inv_inv, z])
    self.assertEqual(z_np.shape, (16, 5, 5, 16))
    self.assertEqual(x_inv_np.shape, (16, 5, 5, 16))
    # Reports max abs diff on failure, unlike assertTrue(np.allclose(...)).
    np.testing.assert_allclose(diff, 0.0, atol=1e-5)