Example #1
0
 def test_is_reparemeterized(distrib_flag, sample_flag=None):
     """Check whether a gradient flows from a wrapped sample back to ``x``.

     Wraps ``zd.Normal(mean=x, ...)`` built with ``distrib_flag`` in a
     ``ZhuSuanDistribution`` and samples with ``sample_flag``. Asserts that
     the gradient w.r.t. ``x`` is not None when ``sample_flag is True``, or
     when ``sample_flag`` is None and ``distrib_flag`` is truthy; otherwise
     asserts it is None.

     NOTE(review): relies on ``x`` and ``self`` from the enclosing scope
     (not visible here).
     """
     wrapped = ZhuSuanDistribution(
         zd.Normal(mean=x, std=1., is_reparameterized=distrib_flag))
     drawn = wrapped.sample(is_reparameterized=sample_flag)
     grad_wrt_x = tf.gradients(drawn, x)[0]
     # An explicit ``sample_flag`` decides; ``None`` defers to the flag the
     # underlying ZhuSuan distribution was built with.
     expect_grad = (sample_flag is True
                    or (sample_flag is None and distrib_flag))
     if not expect_grad:
         self.assertIsNone(grad_wrt_x)
     else:
         self.assertIsNotNone(grad_wrt_x)
Example #2
0
    def test_sample(self):
        """Test ``ZhuSuanDistribution.sample`` against mocked ZhuSuan normals.

        Covers the RuntimeError raised when re-parameterized samples are
        requested from a non-re-parameterized distribution, the default
        ``is_reparameterized`` following the wrapped distribution's flag,
        the ``n_samples`` argument, and overriding ``is_reparameterized``
        for a single call.
        """
        # Sampling re-parameterized samples from a non-re-parameterized
        # distribution should raise. The wrapper is built (and its flag
        # checked) OUTSIDE ``pytest.raises`` so that only the ``sample`` call
        # can satisfy the expected RuntimeError; otherwise a matching error
        # raised during construction would make this check pass spuriously.
        d = ZhuSuanDistribution(
            Mock(spec=zd.Normal,
                 wraps=zd.Normal(mean=0., std=1.),
                 is_reparameterized=False))
        self.assertFalse(d.is_reparameterized)
        with pytest.raises(RuntimeError, match='.* is not re-parameterized'):
            _ = d.sample(is_reparameterized=True)

        # Sampling with the default is_reparameterized = True. The mocked
        # ``sample`` always returns this fixed tensor, so the output value
        # can be compared directly.
        samples = tf.constant(12345678.)
        d = ZhuSuanDistribution(
            Mock(spec=zd.Normal,
                 wraps=zd.Normal(mean=0., std=1.),
                 is_reparameterized=True,
                 sample=Mock(return_value=samples)))
        t = d.sample()
        self.assertIsInstance(t, StochasticTensor)
        self.assertIsNone(t.n_samples)
        self.assertEqual(t.group_ndims, 0)
        self.assertTrue(t.is_reparameterized)
        with self.test_session():
            np.testing.assert_equal(d.sample().eval(), samples.eval())

        # Sampling with the default is_reparameterized = False.
        self.assertFalse(
            ZhuSuanDistribution(
                Mock(spec=zd.Normal,
                     wraps=zd.Normal(mean=0., std=1.),
                     is_reparameterized=False,
                     sample=Mock(
                         return_value=samples))).sample().is_reparameterized)

        # Sampling with n_samples (``d`` is still the re-parameterized
        # distribution built above); the value is reflected on the result.
        t = d.sample(n_samples=2)
        self.assertEqual(t.n_samples, 2)

        # Sampling with an overridden is_reparameterized argument.
        t = d.sample(is_reparameterized=False)
        self.assertFalse(t.is_reparameterized)
Example #3
0
    def test_prob_and_log_prob(self):
        """Compare ``ZhuSuanDistribution`` densities with ZhuSuan's own.

        Exercises ``log_prob``/``prob`` with the default ``group_ndims``, a
        static and a dynamic (tensor) ``group_ndims``, a negative dynamic
        ``group_ndims`` (must fail), a wrapped distribution whose own
        ``group_ndims`` is not used by default, and ``sample``'s
        ``compute_density`` switch.
        """
        given = tf.reshape(tf.range(24, dtype=tf.float32), [2, 3, 4]) / 24.
        std_normal = zd.Normal(mean=tf.zeros([3, 4]), std=tf.ones([3, 4]))
        grouped_normal = zd.Normal(mean=tf.zeros([3, 4]),
                                   std=tf.ones([3, 4]),
                                   group_ndims=1)

        def check_density(actual, reference, rtol=1e-7, check_shape=True):
            """Assert ``actual`` matches ``reference`` (static shape, values).

            Must run inside ``self.test_session()`` because it calls
            ``eval``. ``rtol`` defaults to numpy's own ``assert_allclose``
            default, so omitting it means the stock tolerance.
            """
            if check_shape:
                self.assertEqual(actual.get_shape(), reference.get_shape())
            np.testing.assert_allclose(actual.eval(), reference.eval(),
                                       rtol=rtol)

        # Default group_ndims.
        wrapper = ZhuSuanDistribution(std_normal)
        with self.test_session():
            check_density(wrapper.log_prob(given),
                          std_normal.log_prob(given))
            check_density(wrapper.prob(given), std_normal.prob(given))

        # Static group_ndims passed per call.
        with self.test_session():
            check_density(wrapper.log_prob(given, group_ndims=1),
                          grouped_normal.log_prob(given))
            check_density(wrapper.prob(given, group_ndims=1),
                          grouped_normal.prob(given),
                          rtol=1e-5)

        # Dynamic (tensor) group_ndims. Static shapes are deliberately not
        # compared: the auxiliary asserts added to reduce_mean in our
        # log_prob make them differ from ZhuSuan's.
        dynamic_ndims = tf.constant(1, dtype=tf.int32)
        dynamic_normal = zd.Normal(mean=std_normal.mean,
                                   std=std_normal.std,
                                   group_ndims=dynamic_ndims)
        with self.test_session():
            check_density(wrapper.log_prob(given, group_ndims=dynamic_ndims),
                          dynamic_normal.log_prob(given),
                          check_shape=False)
            check_density(wrapper.prob(given, group_ndims=dynamic_ndims),
                          dynamic_normal.prob(given),
                          rtol=1e-5,
                          check_shape=False)

        # A negative dynamic group_ndims must be rejected by both methods.
        negative_ndims = tf.constant(-1, dtype=tf.int32)
        with self.test_session():
            for density in (wrapper.log_prob, wrapper.prob):
                with pytest.raises(Exception,
                                   match='group_ndims must be non-negative'):
                    _ = density(given, group_ndims=negative_ndims).eval()

        # The wrapped distribution's own group_ndims (1 here) is expected
        # NOT to be used by default: results match the ungrouped normal.
        overriding = ZhuSuanDistribution(grouped_normal)
        with self.test_session():
            check_density(overriding.log_prob(given),
                          std_normal.log_prob(given))
            check_density(overriding.prob(given),
                          std_normal.prob(given),
                          rtol=1e-5)

        # compute_density: only an explicit True attaches a log-prob.
        sampler = ZhuSuanDistribution(grouped_normal)
        drawn = sampler.sample()
        self.assertIsNone(drawn._self_log_prob)
        drawn = sampler.sample(compute_density=False)
        self.assertIsNone(drawn._self_log_prob)
        drawn = sampler.sample(compute_density=True)
        self.assertIsNotNone(drawn._self_log_prob)