Ejemplo n.º 1
0
 def test_is_reparemeterized(distrib_flag, sample_flag=None):
     normal = zd.Normal(mean=x, std=1., is_reparameterized=distrib_flag)
     distrib = ZhuSuanDistribution(normal)
     samples = distrib.sample(is_reparameterized=sample_flag)
     grads = tf.gradients(samples, x)
     if sample_flag is True or (sample_flag is None and distrib_flag):
         self.assertIsNotNone(grads[0])
     else:
         self.assertIsNone(grads[0])
Ejemplo n.º 2
0
    def test_sample(self):
        # sample re-parameterized samples from a non-reparameterized
        # distribution should cause an error
        with pytest.raises(RuntimeError, match='.* is not re-parameterized'):
            d = ZhuSuanDistribution(
                Mock(spec=zd.Normal,
                     wraps=zd.Normal(mean=0., std=1.),
                     is_reparameterized=False))
            self.assertFalse(d.is_reparameterized)
            _ = d.sample(is_reparameterized=True)

        # test sampling with default is_reparameterized = True
        samples = tf.constant(12345678.)
        d = ZhuSuanDistribution(
            Mock(spec=zd.Normal,
                 wraps=zd.Normal(mean=0., std=1.),
                 is_reparameterized=True,
                 sample=Mock(return_value=samples)))
        t = d.sample()
        self.assertIsInstance(t, StochasticTensor)
        self.assertIsNone(t.n_samples)
        self.assertEqual(t.group_ndims, 0)
        self.assertTrue(t.is_reparameterized)
        with self.test_session():
            np.testing.assert_equal(d.sample().eval(), samples.eval())

        # test sampling with default is_reparameterized = False
        self.assertFalse(
            ZhuSuanDistribution(
                Mock(spec=zd.Normal,
                     wraps=zd.Normal(mean=0., std=1.),
                     is_reparameterized=False,
                     sample=Mock(
                         return_value=samples))).sample().is_reparameterized)

        # test sampling with n_samples
        t = d.sample(n_samples=2)
        self.assertEqual(t.n_samples, 2)

        # test sampling with overrided is_reparameterized attribute
        t = d.sample(is_reparameterized=False)
        self.assertFalse(t.is_reparameterized)
Ejemplo n.º 3
0
 def test_proxied_props_and_methods(self):
     zs_distrib = zd.OnehotCategorical(tf.zeros([3, 4, 5]),
                                       dtype=tf.int64,
                                       group_ndims=1)
     distrib = ZhuSuanDistribution(zs_distrib)
     for attr in ['dtype', 'is_continuous', 'is_reparameterized']:
         self.assertEqual(getattr(zs_distrib, attr),
                          getattr(distrib, attr),
                          msg='Attribute `{}` does not equal'.format(attr))
     for meth in ['get_value_shape', 'get_batch_shape']:
         self.assertEqual(
             getattr(zs_distrib, meth)(),
             getattr(distrib, meth)(),
             msg='Output of method `{}` does not equal'.format(meth))
     with self.test_session():
         for attr in ['value_shape', 'batch_shape']:
             np.testing.assert_equal(
                 getattr(zs_distrib, attr).eval(),
                 getattr(distrib, attr).eval(),
                 err_msg='Value of attribute `{}` does not equal'.format(
                     attr))
Ejemplo n.º 4
0
 def test_repr(self):
     distrib = ZhuSuanDistribution(
         Mock(spec=zd.Distribution,
              __repr__=Mock(return_value='repr_output')))
     self.assertEqual(repr(distrib), 'Distribution(repr_output)')
Ejemplo n.º 5
0
 def test_type_error(self):
     with pytest.raises(TypeError,
                        match='`distribution` is not an instance of '
                        '`zhusuan.distributions.Distribution`'):
         _ = ZhuSuanDistribution(1)
Ejemplo n.º 6
0
    def test_prob_and_log_prob(self):
        x = tf.reshape(tf.range(24, dtype=tf.float32), [2, 3, 4]) / 24.
        normal = zd.Normal(mean=tf.zeros([3, 4]), std=tf.ones([3, 4]))
        normal1 = zd.Normal(mean=tf.zeros([3, 4]),
                            std=tf.ones([3, 4]),
                            group_ndims=1)

        # test with default group_ndims
        distrib = ZhuSuanDistribution(normal)
        with self.test_session():
            self.assertEqual(
                distrib.log_prob(x).get_shape(),
                normal.log_prob(x).get_shape())
            self.assertEqual(
                distrib.prob(x).get_shape(),
                normal.prob(x).get_shape())
            np.testing.assert_allclose(
                distrib.log_prob(x).eval(),
                normal.log_prob(x).eval())
            np.testing.assert_allclose(
                distrib.prob(x).eval(),
                normal.prob(x).eval())

        # test with static group_ndims
        with self.test_session():
            self.assertEqual(
                distrib.log_prob(x, group_ndims=1).get_shape(),
                normal1.log_prob(x).get_shape())
            self.assertEqual(
                distrib.prob(x, group_ndims=1).get_shape(),
                normal1.prob(x).get_shape())
            np.testing.assert_allclose(
                distrib.log_prob(x, group_ndims=1).eval(),
                normal1.log_prob(x).eval())
            np.testing.assert_allclose(
                distrib.prob(x, group_ndims=1).eval(),
                normal1.prob(x).eval())

        # test with dynamic group_ndims
        group_ndims = tf.constant(1, dtype=tf.int32)
        normal1d = zd.Normal(mean=normal.mean,
                             std=normal.std,
                             group_ndims=group_ndims)
        with self.test_session():
            self.assertEqual(
                distrib.log_prob(x, group_ndims=group_ndims).get_shape(),
                normal1d.log_prob(x).get_shape())
            self.assertEqual(
                distrib.prob(x, group_ndims=group_ndims).get_shape(),
                normal1d.prob(x).get_shape())
            np.testing.assert_allclose(
                distrib.log_prob(x, group_ndims=group_ndims).eval(),
                normal1d.log_prob(x).eval())
            np.testing.assert_allclose(
                distrib.prob(x, group_ndims=group_ndims).eval(),
                normal1d.prob(x).eval())

        # test with bad dynamic group_ndims
        group_ndims = tf.constant(-1, dtype=tf.int32)
        with self.test_session():
            with pytest.raises(Exception,
                               match='group_ndims must be non-negative'):
                _ = distrib.log_prob(x, group_ndims=group_ndims).eval()
            with pytest.raises(Exception,
                               match='group_ndims must be non-negative'):
                _ = distrib.prob(x, group_ndims=group_ndims).eval()

        # test override the default group_ndims in ZhuSuan distribution
        distrib = ZhuSuanDistribution(normal1)
        with self.test_session():
            self.assertEqual(
                distrib.log_prob(x).get_shape(),
                normal.log_prob(x).get_shape())
            self.assertEqual(
                distrib.prob(x).get_shape(),
                normal.prob(x).get_shape())
            np.testing.assert_allclose(
                distrib.log_prob(x).eval(),
                normal.log_prob(x).eval())
            np.testing.assert_allclose(
                distrib.prob(x).eval(),
                normal.prob(x).eval())
Ejemplo n.º 7
0
 def test_repr(self):
     distrib = ZhuSuanDistribution(
         Mock(spec=zd.Normal,
              wraps=zd.Normal(mean=0., std=1.),
              __repr__=Mock(return_value='repr_output')))
     self.assertEqual(repr(distrib), 'Distribution(repr_output)')
Ejemplo n.º 8
0
    def test_prob_and_log_prob(self):
        x = tf.reshape(tf.range(24, dtype=tf.float32), [2, 3, 4]) / 24.
        normal = zd.Normal(mean=tf.zeros([3, 4]), std=tf.ones([3, 4]))
        normal1 = zd.Normal(mean=tf.zeros([3, 4]),
                            std=tf.ones([3, 4]),
                            group_ndims=1)

        # test with default group_ndims
        distrib = ZhuSuanDistribution(normal)
        with self.test_session():
            self.assertEqual(
                distrib.log_prob(x).get_shape(),
                normal.log_prob(x).get_shape())
            self.assertEqual(
                distrib.prob(x).get_shape(),
                normal.prob(x).get_shape())
            np.testing.assert_allclose(
                distrib.log_prob(x).eval(),
                normal.log_prob(x).eval())
            np.testing.assert_allclose(
                distrib.prob(x).eval(),
                normal.prob(x).eval())

        # test with static group_ndims
        with self.test_session():
            self.assertEqual(
                distrib.log_prob(x, group_ndims=1).get_shape(),
                normal1.log_prob(x).get_shape())
            self.assertEqual(
                distrib.prob(x, group_ndims=1).get_shape(),
                normal1.prob(x).get_shape())
            np.testing.assert_allclose(
                distrib.log_prob(x, group_ndims=1).eval(),
                normal1.log_prob(x).eval())
            np.testing.assert_allclose(distrib.prob(x, group_ndims=1).eval(),
                                       normal1.prob(x).eval(),
                                       rtol=1e-5)

        # test with dynamic group_ndims
        group_ndims = tf.constant(1, dtype=tf.int32)
        normal1d = zd.Normal(mean=normal.mean,
                             std=normal.std,
                             group_ndims=group_ndims)
        with self.test_session():
            # Note: Because we added auxiliary asserts to reduce_mean in our
            # log_prob, the following two static shapes will not be equal.
            #
            # self.assertEqual(
            #     distrib.log_prob(x, group_ndims=group_ndims).get_shape(),
            #     normal1d.log_prob(x).get_shape()
            # )
            # self.assertEqual(
            #     distrib.prob(x, group_ndims=group_ndims).get_shape(),
            #     normal1d.prob(x).get_shape()
            # )
            np.testing.assert_allclose(
                distrib.log_prob(x, group_ndims=group_ndims).eval(),
                normal1d.log_prob(x).eval())
            np.testing.assert_allclose(distrib.prob(
                x, group_ndims=group_ndims).eval(),
                                       normal1d.prob(x).eval(),
                                       rtol=1e-5)

        # test with bad dynamic group_ndims
        group_ndims = tf.constant(-1, dtype=tf.int32)
        with self.test_session():
            with pytest.raises(Exception,
                               match='group_ndims must be non-negative'):
                _ = distrib.log_prob(x, group_ndims=group_ndims).eval()
            with pytest.raises(Exception,
                               match='group_ndims must be non-negative'):
                _ = distrib.prob(x, group_ndims=group_ndims).eval()

        # test override the default group_ndims in ZhuSuan distribution
        distrib = ZhuSuanDistribution(normal1)
        with self.test_session():
            self.assertEqual(
                distrib.log_prob(x).get_shape(),
                normal.log_prob(x).get_shape())
            self.assertEqual(
                distrib.prob(x).get_shape(),
                normal.prob(x).get_shape())
            np.testing.assert_allclose(
                distrib.log_prob(x).eval(),
                normal.log_prob(x).eval())
            np.testing.assert_allclose(distrib.prob(x).eval(),
                                       normal.prob(x).eval(),
                                       rtol=1e-5)

        # test compute_density
        distrib = ZhuSuanDistribution(normal1)
        t = distrib.sample()
        self.assertIsNone(t._self_log_prob)
        t = distrib.sample(compute_density=False)
        self.assertIsNone(t._self_log_prob)
        t = distrib.sample(compute_density=True)
        self.assertIsNotNone(t._self_log_prob)