Esempio n. 1
0
    def test_actnorm_invertibility(self):
        """Check that actnorm followed by its inverse reproduces the input.

        Applies ``glow.actnorm`` forward (``inverse=False``) and then inverse
        (``inverse=True``), both under the variable name ``"actnorm"`` with
        ``init=False``. Asserts that:

        * every intermediate has shape (BATCH_SIZE, TARGET_LENGTH, N_CHANNELS);
        * the round trip recovers ``x`` to within an absolute 1e-5;
        * the forward and inverse log-abs-determinants sum to 0 within 1e-5.
        """
        name = "actnorm"
        x, x_mask, _ = self.get_data()

        # Both calls share `name`. Presumably glow.actnorm reuses the same
        # variables under that scope so the inverse undoes the forward pass
        # (TODO confirm reuse semantics in glow.actnorm).
        x_inv, logabsdet = glow.actnorm(
            name, x, x_mask, inverse=False, init=False)
        x_inv_inv, logabsdet_inv = glow.actnorm(
            name, x_inv, x_mask, inverse=True, init=False)
        self.evaluate(tf.global_variables_initializer())
        # Fetch every tensor needed by the assertions in one evaluate call.
        # x_mask is not re-fetched because nothing below reads it.
        x, x_inv, x_inv_inv, logabsdet, logabsdet_inv = self.evaluate(
            [x, x_inv, x_inv_inv, logabsdet, logabsdet_inv])
        diff = x - x_inv_inv
        logabsdet_sum = logabsdet + logabsdet_inv

        expected_shape = (BATCH_SIZE, TARGET_LENGTH, N_CHANNELS)
        self.assertEqual(x.shape, expected_shape)
        self.assertEqual(x_inv.shape, expected_shape)
        self.assertEqual(x_inv_inv.shape, expected_shape)
        # Unlike assertTrue(np.allclose(...)), assert_allclose reports which
        # elements mismatch on failure. The target is 0, so the rtol term is
        # zero and the criterion is |value| <= atol, exactly as before.
        np.testing.assert_allclose(diff, 0.0, atol=1e-5)
        np.testing.assert_allclose(logabsdet_sum, 0.0, atol=1e-5)
  def test_actnorm(self):
    """Check that data-dependent actnorm init normalizes unmasked positions.

    Draws inputs from a normal distribution (mean 50, stddev 10) and passes
    them through ``glow.actnorm`` with ``init=True``. Asserts that:

    * the output has shape (BATCH_SIZE, TARGET_LENGTH, N_CHANNELS);
    * ``logabsdet`` has shape (BATCH_SIZE,);
    * the mean and variance of the positions kept by ``x_mask`` are 0 and 1
      to within an absolute 1e-5.
    """
    _, x_mask, _ = self.get_data()
    x = tf.random_normal((BATCH_SIZE, TARGET_LENGTH, N_CHANNELS),
                         mean=50.0, stddev=10.0, dtype=DTYPE)
    x_act, logabsdet = glow.actnorm(
        "actnorm", x, x_mask, inverse=False, init=True)

    # Drop masked-out positions, then take moments over axis 0 of the result.
    # Assuming x_mask has shape (BATCH_SIZE, TARGET_LENGTH), the moments are
    # per channel (TODO confirm x_mask shape from get_data).
    x_act_nopad = tf.boolean_mask(x_act, x_mask)
    x_mean, x_var = tf.nn.moments(x_act_nopad, axes=[0])
    self.evaluate(tf.global_variables_initializer())
    # x is random. Fetching all dependents in one evaluate call keeps them
    # consistent with one another. x itself is not fetched because the
    # assertions below never read it.
    x_act, logabsdet, x_mean, x_var = self.evaluate(
        [x_act, logabsdet, x_mean, x_var])

    self.assertEqual(x_act.shape, (BATCH_SIZE, TARGET_LENGTH, N_CHANNELS))
    self.assertEqual(logabsdet.shape, (BATCH_SIZE,))
    # assert_allclose reports mismatching elements on failure. For the 0.0
    # target the criterion is unchanged. For the 1.0 target, pass rtol=1e-5
    # to match np.allclose's default tolerance exactly.
    np.testing.assert_allclose(x_mean, 0.0, atol=1e-5)
    np.testing.assert_allclose(x_var, 1.0, rtol=1e-5, atol=1e-5)