Ejemplo n.º 1
0
    def test_prob_and_log_prob(self):
        """Check that `ZhuSuanDistribution` reproduces ZhuSuan densities.

        The wrapper's ``log_prob`` / ``prob`` must match a reference
        ``zd.Normal`` in static shape and in value for the default
        ``group_ndims``, a Python-int ``group_ndims`` and a tensor
        ``group_ndims``; a negative tensor ``group_ndims`` must raise.
        """
        x = tf.reshape(tf.range(24, dtype=tf.float32), [2, 3, 4]) / 24.
        normal = zd.Normal(mean=tf.zeros([3, 4]), std=tf.ones([3, 4]))
        normal1 = zd.Normal(mean=tf.zeros([3, 4]),
                            std=tf.ones([3, 4]),
                            group_ndims=1)

        def assert_same_density(wrapped, reference, **kwargs):
            """Assert `wrapped` agrees with `reference` on ``x``.

            ``kwargs`` (e.g. ``group_ndims``) are forwarded to the wrapper
            only; the reference carries its own ``group_ndims``.  Must be
            called inside a session so that ``.eval()`` works.
            """
            log_p = wrapped.log_prob(x, **kwargs)
            ref_log_p = reference.log_prob(x)
            p = wrapped.prob(x, **kwargs)
            ref_p = reference.prob(x)
            self.assertEqual(log_p.get_shape(), ref_log_p.get_shape())
            self.assertEqual(p.get_shape(), ref_p.get_shape())
            np.testing.assert_allclose(log_p.eval(), ref_log_p.eval())
            # `prob` is float32 and the wrapper may compute it along a
            # different path than the reference; numpy's default rtol
            # (1e-7) is below float32 resolution, so allow a small
            # relative error.
            np.testing.assert_allclose(p.eval(), ref_p.eval(), rtol=1e-5)

        # default group_ndims: the wrapper must behave like `normal` itself
        distrib = ZhuSuanDistribution(normal)
        with self.test_session():
            assert_same_density(distrib, normal)

        # static (Python int) group_ndims
        with self.test_session():
            assert_same_density(distrib, normal1, group_ndims=1)

        # dynamic (tensor) group_ndims
        group_ndims = tf.constant(1, dtype=tf.int32)
        normal1d = zd.Normal(mean=normal.mean,
                             std=normal.std,
                             group_ndims=group_ndims)
        with self.test_session():
            assert_same_density(distrib, normal1d, group_ndims=group_ndims)

        # a negative dynamic group_ndims must raise for both densities
        group_ndims = tf.constant(-1, dtype=tf.int32)
        with self.test_session():
            for density in (distrib.log_prob, distrib.prob):
                with pytest.raises(Exception,
                                   match='group_ndims must be non-negative'):
                    _ = density(x, group_ndims=group_ndims).eval()

        # the wrapper's own default group_ndims overrides the group_ndims=1
        # of the wrapped ZhuSuan distribution, so results must match the
        # ungrouped `normal`
        distrib = ZhuSuanDistribution(normal1)
        with self.test_session():
            assert_same_density(distrib, normal)
Ejemplo n.º 2
0
    def test_prob_and_log_prob(self):
        """Check that `ZhuSuanDistribution` reproduces ZhuSuan densities.

        The wrapper's ``log_prob`` / ``prob`` must match a reference
        ``zd.Normal`` in value (and, where meaningful, in static shape) for
        the default ``group_ndims``, a Python-int ``group_ndims`` and a
        tensor ``group_ndims``; a negative tensor ``group_ndims`` must raise;
        and ``sample(compute_density=...)`` must control whether the sample
        carries a precomputed log-density.
        """
        x = tf.reshape(tf.range(24, dtype=tf.float32), [2, 3, 4]) / 24.
        normal = zd.Normal(mean=tf.zeros([3, 4]), std=tf.ones([3, 4]))
        normal1 = zd.Normal(mean=tf.zeros([3, 4]),
                            std=tf.ones([3, 4]),
                            group_ndims=1)

        def assert_same_density(wrapped, reference, check_static_shape=True,
                                **kwargs):
            """Assert `wrapped` agrees with `reference` on ``x``.

            ``kwargs`` (e.g. ``group_ndims``) are forwarded to the wrapper
            only; the reference carries its own ``group_ndims``.  Static
            shapes are compared only when `check_static_shape` is True.
            Must be called inside a session so that ``.eval()`` works.
            """
            log_p = wrapped.log_prob(x, **kwargs)
            ref_log_p = reference.log_prob(x)
            p = wrapped.prob(x, **kwargs)
            ref_p = reference.prob(x)
            if check_static_shape:
                self.assertEqual(log_p.get_shape(), ref_log_p.get_shape())
                self.assertEqual(p.get_shape(), ref_p.get_shape())
            np.testing.assert_allclose(log_p.eval(), ref_log_p.eval())
            # `prob` is float32 and the wrapper may compute it along a
            # different path than the reference; numpy's default rtol
            # (1e-7) is below float32 resolution, so allow a small
            # relative error.
            np.testing.assert_allclose(p.eval(), ref_p.eval(), rtol=1e-5)

        # default group_ndims: the wrapper must behave like `normal` itself
        distrib = ZhuSuanDistribution(normal)
        with self.test_session():
            assert_same_density(distrib, normal)

        # static (Python int) group_ndims
        with self.test_session():
            assert_same_density(distrib, normal1, group_ndims=1)

        # dynamic (tensor) group_ndims
        group_ndims = tf.constant(1, dtype=tf.int32)
        normal1d = zd.Normal(mean=normal.mean,
                             std=normal.std,
                             group_ndims=group_ndims)
        with self.test_session():
            # Our log_prob adds auxiliary asserts to reduce_mean, so its
            # static shape differs from ZhuSuan's here; only the evaluated
            # values (which assert_allclose also shape-checks) are compared.
            assert_same_density(distrib, normal1d, check_static_shape=False,
                                group_ndims=group_ndims)

        # a negative dynamic group_ndims must raise for both densities
        group_ndims = tf.constant(-1, dtype=tf.int32)
        with self.test_session():
            for density in (distrib.log_prob, distrib.prob):
                with pytest.raises(Exception,
                                   match='group_ndims must be non-negative'):
                    _ = density(x, group_ndims=group_ndims).eval()

        # the wrapper's own default group_ndims overrides the group_ndims=1
        # of the wrapped ZhuSuan distribution, so results must match the
        # ungrouped `normal`
        distrib = ZhuSuanDistribution(normal1)
        with self.test_session():
            assert_same_density(distrib, normal)

        # compute_density: a sample carries a precomputed log-density only
        # when compute_density=True is passed explicitly
        t = distrib.sample()
        self.assertIsNone(t._self_log_prob)
        t = distrib.sample(compute_density=False)
        self.assertIsNone(t._self_log_prob)
        t = distrib.sample(compute_density=True)
        self.assertIsNotNone(t._self_log_prob)