def test_transform_on_transformed(self):
        """Check transformations applied to an already-transformed distribution.

        `batch_ndims_to_value(0)` and `expand_value_ndims(0)` must return the
        receiver itself. `expand_value_ndims(k)` on a
        `BatchToValueDistribution` must produce a distribution whose
        `base_distribution` is still the original `normal`, i.e. wrappers are
        collapsed rather than nested.
        """
        with self.test_session() as sess:
            normal = Normal(mean=tf.zeros([3, 4, 5]), logstd=0.)
            self.assertEqual(normal.value_ndims, 0)
            self.assertEqual(normal.get_batch_shape().as_list(), [3, 4, 5])
            self.assertEqual(list(sess.run(normal.batch_shape)), [3, 4, 5])

            # A zero-dimension transform is the identity.
            distrib = normal.batch_ndims_to_value(0)
            self.assertIs(distrib, normal)

            # First transform: move one batch dim into the value dims.
            distrib = normal.batch_ndims_to_value(1)
            self.assertIsInstance(distrib, BatchToValueDistribution)
            self.assertEqual(distrib.value_ndims, 1)
            self.assertEqual(distrib.get_batch_shape().as_list(), [3, 4])
            self.assertEqual(list(sess.run(distrib.batch_shape)), [3, 4])
            self.assertIs(distrib.base_distribution, normal)

            # Transform the transformed distribution: the new wrapper must
            # refer directly to `normal`, not to `distrib`.
            distrib2 = distrib.expand_value_ndims(1)
            self.assertIsInstance(distrib2, BatchToValueDistribution)
            self.assertEqual(distrib2.value_ndims, 2)
            self.assertEqual(distrib2.get_batch_shape().as_list(), [3])
            self.assertEqual(list(sess.run(distrib2.batch_shape)), [3])
            self.assertIs(distrib2.base_distribution, normal)

            # Zero-dimension expansion returns the receiver unchanged.
            distrib2 = distrib.expand_value_ndims(0)
            self.assertIs(distrib2, distrib)
            self.assertEqual(distrib2.value_ndims, 1)
            self.assertEqual(distrib.value_ndims, 1)
            self.assertEqual(distrib2.get_batch_shape().as_list(), [3, 4])
            self.assertEqual(list(sess.run(distrib2.batch_shape)), [3, 4])
            self.assertIs(distrib2.base_distribution, normal)
    def test_ndims_equals_zero_and_negative(self):
        """Zero `ndims` is a no-op; negative `ndims` raises ValueError."""
        base = Normal(mean=tf.zeros([3, 4]), logstd=0.)
        transforms = (base.batch_ndims_to_value, base.expand_value_ndims)

        # ndims == 0 must hand back the very same object.
        for transform in transforms:
            self.assertIs(transform(0), base)

        # ndims < 0 is rejected by both transformations.
        for transform in transforms:
            with pytest.raises(ValueError,
                               match='`ndims` must be non-negative integers'):
                transform(-1)
    def test_with_normal(self):
        """Compare a 1-dim `BatchToValueDistribution` against its base Normal.

        The wrapper's densities with `group_ndims=g` must match the base
        distribution's densities with `group_ndims=g + 1`, both for explicit
        inputs and for densities attached to samples.
        """
        mean = np.random.normal(size=[4, 5]).astype(np.float64)
        logstd = np.random.normal(size=mean.shape).astype(np.float64)
        x = np.random.normal(size=[3, 4, 5])

        with self.test_session() as sess:
            normal = Normal(mean=mean, logstd=logstd)
            wrapped = normal.batch_ndims_to_value(1)

            # Static properties of the wrapper.
            self.assertIsInstance(wrapped, BatchToValueDistribution)
            self.assertEqual(wrapped.value_ndims, 1)
            self.assertEqual(wrapped.get_batch_shape().as_list(), [4])
            self.assertEqual(list(sess.run(wrapped.batch_shape)), [4])
            self.assertEqual(wrapped.dtype, tf.float64)
            self.assertTrue(wrapped.is_continuous)
            self.assertTrue(wrapped.is_reparameterized)
            self.assertIs(wrapped.base_distribution, normal)

            # `log_prob` and `prob` follow the same pattern: group_ndims on the
            # wrapper is offset by one relative to the base distribution.
            for method in ('log_prob', 'prob'):
                ours = getattr(wrapped, method)
                reference = getattr(normal, method)
                plain = ours(x)
                grouped = ours(x, group_ndims=1)
                self.assertEqual(get_static_shape(plain), (3, 4))
                self.assertEqual(get_static_shape(grouped), (3, ))
                np.testing.assert_allclose(*sess.run(
                    [plain, reference(x, group_ndims=1)]))
                np.testing.assert_allclose(*sess.run(
                    [grouped, reference(x, group_ndims=2)]))

            # Densities attached to samples must agree with the base as well.
            drawn = [
                wrapped.sample(3, compute_density=False),
                wrapped.sample(3, compute_density=True, group_ndims=1),
            ]
            densities = [s.log_prob() for s in drawn]
            for density, shape in zip(densities, [(3, 4), (3, )]):
                self.assertEqual(get_static_shape(density), shape)
            for density, s, base_group in zip(densities, drawn, (1, 2)):
                np.testing.assert_allclose(*sess.run(
                    [density, normal.log_prob(s, group_ndims=base_group)]))