Example #1
0
    def test_invert_flow(self):
        """Test `invert()` on a flow and on an already-inverted flow.

        Inverting ``QuadraticFlow(2., 5.)`` must yield an `InvertFlow` whose
        ``inverse_transform`` reproduces the values computed by
        ``quadratic_transform`` for the original flow. Inverting that result
        again must return the very same flow object. A `FlowDistribution`
        built on an inverted quadratic flow must give ``log_prob(x)`` equal to
        ``normal.log_prob(y) + log_det``, where ``(y, log_det)`` come from
        ``quadratic_transform`` applied to ``x``.
        """
        with self.test_session() as sess:
            # test invert a normal flow
            flow = QuadraticFlow(2., 5.)
            inv_flow = flow.invert()

            self.assertIsInstance(inv_flow, InvertFlow)
            self.assertEqual(inv_flow.x_value_ndims, 0)
            self.assertEqual(inv_flow.y_value_ndims, 0)
            self.assertFalse(inv_flow.require_batch_dims)

            # Reference values computed with the NumPy ops backend.
            test_x = np.arange(12, dtype=np.float32) + 1.
            test_y, test_log_det = quadratic_transform(npyops, test_x, 2., 5.)

            # `invert()` alone does not build the wrapped flow; the first
            # transform through `inv_flow` does.
            self.assertFalse(flow._has_built)
            # The inverse direction of the inverted flow must match the
            # reference (forward) quadratic transform of `flow`.
            y, log_det_y = inv_flow.inverse_transform(tf.constant(test_x))
            self.assertTrue(flow._has_built)

            np.testing.assert_allclose(sess.run(y), test_y)
            np.testing.assert_allclose(sess.run(log_det_y), test_log_det)
            invertible_flow_standard_check(self, inv_flow, sess, test_y)

            # test invert an InvertFlow: it must hand back the original object
            inv_inv_flow = inv_flow.invert()
            self.assertIs(inv_inv_flow, flow)

            # test use with FlowDistribution
            normal = Normal(mean=1., std=2.)
            inv_flow = QuadraticFlow(2., 5.).invert()
            distrib = FlowDistribution(normal, inv_flow)
            # This is a log-probability, not a log-determinant. The expected
            # value is the base log-density at y plus the reference log-det
            # term, consistent with the change-of-variables formula.
            distrib_log_prob = distrib.log_prob(test_x)
            np.testing.assert_allclose(*sess.run(
                [distrib_log_prob,
                 normal.log_prob(test_y) + test_log_det]))

    def test_with_normal(self):
        """Test `batch_ndims_to_value(1)` applied to a `Normal`.

        The wrapper must be a `BatchToValueDistribution` that moves one batch
        axis into the value (batch shape ``[4]``, ``value_ndims == 1``) and
        keeps the base distribution's dtype and flags. Its ``log_prob``,
        ``prob`` and sample log-densities must equal the base `Normal`
        evaluated with one extra ``group_ndims``.
        """
        loc = np.random.normal(size=[4, 5]).astype(np.float64)
        log_scale = np.random.normal(size=loc.shape).astype(np.float64)
        observed = np.random.normal(size=[3, 4, 5])

        with self.test_session() as sess:
            def assert_run_close(actual, expected):
                # Fetch both tensors in a single run so that any sampled
                # values they depend on are shared between the two sides.
                np.testing.assert_allclose(*sess.run([actual, expected]))

            def assert_static_shapes(*pairs):
                for tensor, wanted in pairs:
                    self.assertEqual(get_static_shape(tensor), wanted)

            base = Normal(mean=loc, logstd=log_scale)
            wrapped = base.batch_ndims_to_value(1)

            # Metadata of the wrapping distribution.
            self.assertIsInstance(wrapped, BatchToValueDistribution)
            self.assertEqual(wrapped.value_ndims, 1)
            self.assertEqual(wrapped.get_batch_shape().as_list(), [4])
            self.assertEqual(list(sess.run(wrapped.batch_shape)), [4])
            self.assertEqual(wrapped.dtype, tf.float64)
            self.assertTrue(wrapped.is_continuous)
            self.assertTrue(wrapped.is_reparameterized)
            self.assertIs(wrapped.base_distribution, base)

            # log_prob: one group_ndims on the wrapper == one more on the base.
            lp_plain = wrapped.log_prob(observed)
            lp_grouped = wrapped.log_prob(observed, group_ndims=1)
            assert_static_shapes((lp_plain, (3, 4)), (lp_grouped, (3,)))
            assert_run_close(lp_plain, base.log_prob(observed, group_ndims=1))
            assert_run_close(lp_grouped, base.log_prob(observed, group_ndims=2))

            # prob: same correspondence as for log_prob.
            p_plain = wrapped.prob(observed)
            p_grouped = wrapped.prob(observed, group_ndims=1)
            assert_static_shapes((p_plain, (3, 4)), (p_grouped, (3,)))
            assert_run_close(p_plain, base.prob(observed, group_ndims=1))
            assert_run_close(p_grouped, base.prob(observed, group_ndims=2))

            # Sample densities, with and without eager density computation.
            drawn = wrapped.sample(3, compute_density=False)
            drawn_grouped = wrapped.sample(
                3, compute_density=True, group_ndims=1)
            drawn_lp = drawn.log_prob()
            drawn_grouped_lp = drawn_grouped.log_prob()
            assert_static_shapes((drawn_lp, (3, 4)), (drawn_grouped_lp, (3,)))
            assert_run_close(drawn_lp, base.log_prob(drawn, group_ndims=1))
            assert_run_close(
                drawn_grouped_lp, base.log_prob(drawn_grouped, group_ndims=2))