def test_other_properties(self):
        """Check the derived parameters exposed by `Normal`.

        Two instances are built: one from ``self.simple_params``, and one from
        the matching ``mean`` together with ``logstd=np.log(stddev)``. Each
        must report mean, stddev, logstd, var, logvar, precision and
        log_precision equal to the values computed directly with NumPy.
        """
        with self.get_session():
            mean, stddev = self.get_mean_stddev(**self.simple_params)
            log_stddev = np.log(stddev)
            variance = np.square(stddev)

            def assert_derived_params(dist):
                # Evaluate each property in the active session and compare it
                # against its closed-form value derived from mean/stddev.
                np.testing.assert_allclose(dist.mean.eval(), mean)
                np.testing.assert_allclose(dist.stddev.eval(), stddev)
                np.testing.assert_allclose(dist.logstd.eval(), log_stddev)
                np.testing.assert_allclose(dist.var.eval(), variance)
                np.testing.assert_allclose(dist.logvar.eval(),
                                           2. * log_stddev)
                np.testing.assert_allclose(dist.precision.eval(),
                                           1. / variance)
                np.testing.assert_allclose(dist.log_precision.eval(),
                                           -2. * log_stddev)

            # distribution built from the canned constructor arguments
            assert_derived_params(Normal(**self.simple_params))
            # same distribution, but parameterized through `logstd`
            assert_derived_params(Normal(mean, logstd=log_stddev))
    def test_construction_error(self):
        """Check that `Normal` rejects invalid constructor arguments."""
        exclusive_std_msg = ('One and only one of `stddev`, `logstd` should '
                             'be specified.')
        with self.get_session():
            # Only a mean, then three positional arguments (presumably both
            # stddev and logstd): each must be rejected.
            for bad_args in ((1.,), (1., 2., 3.)):
                with self.assertRaisesRegex(ValueError, exclusive_std_msg):
                    Normal(*bad_args)

            # Integer parameters are not accepted.
            with self.assertRaisesRegex(
                    TypeError,
                    'Normal distribution parameters must be float numbers'):
                Normal(1, 2)

            # A mean of shape (2,) cannot broadcast against a stddev of (3,).
            mismatched_mean = np.arange(2, dtype=np.float32)
            mismatched_stddev = np.arange(3, dtype=np.float32)
            with self.assertRaisesRegex(
                    ValueError,
                    '`mean` and `stddev`/`logstd` should be broadcastable'):
                Normal(mismatched_mean, mismatched_stddev)
Example #3
0
    def test_sgvb(self):
        """Check the shape and value of `sgvb` for several `latent_axis` values.

        The expected value is always the log-probs of the first list added
        together minus the log-probs of the second list. `latent_axis` decides
        which axes of that [1, 3] result are reduced.
        """
        with self.get_session():
            def observed_normal(mean, stddev, values):
                # A StochasticTensor over Normal(mean, stddev) with a fixed
                # 1 x 3 observation.
                return StochasticTensor(Normal(mean, stddev),
                                        observed=np.asarray([values]))

            a = observed_normal(0., 1., [0., 1., 2.])
            b = observed_normal(1., 2., [1., 2., 3.])
            c = observed_normal(2., 3., [2., 3., 4.])

            # no latent axis: the full [1, 3] element-wise result is kept
            bound = sgvb([a, b], [c])
            self.assertEqual(bound.get_shape().as_list(), [1, 3])
            actual = bound.eval()
            expected = (a.log_prob() + b.log_prob() - c.log_prob()).eval()
            np.testing.assert_almost_equal(actual, expected)

            # latent_axis=0: the leading axis (of size 1) is reduced away
            bound = sgvb([a], [b, c], latent_axis=0)
            self.assertEqual(bound.get_shape().as_list(), [3])
            actual = bound.eval()
            expected = (a.log_prob() - b.log_prob() -
                        c.log_prob()).eval().reshape([3])
            np.testing.assert_almost_equal(actual, expected)

            # latent_axis=[0, 1]: a scalar equal to the mean over all entries
            bound = sgvb([a], [b, c], latent_axis=[0, 1])
            self.assertEqual(bound.get_shape().as_list(), [])
            actual = bound.eval()
            expected = np.mean(
                (a.log_prob() - b.log_prob() - c.log_prob()).eval())
            np.testing.assert_almost_equal(actual, expected)
Example #4
0
 def test_prob_and_log_prob(self):
     """`StochasticTensor.log_prob`/`prob` must match its distribution's.

     Both are compared against the distribution's own methods applied to the
     same observed array.
     """
     with self.get_session():
         mean = np.asarray(0., dtype=np.float32)
         stddev = np.asarray([1.0, 2.0, 3.0], dtype=np.float32)
         distrib = Normal(mean, stddev)
         observed = np.arange(24, dtype=np.float32).reshape([4, 2, 3])
         t = StochasticTensor(distrib, observed=observed)
         for method_name in ('log_prob', 'prob'):
             np.testing.assert_almost_equal(
                 getattr(t, method_name)().eval(),
                 getattr(distrib, method_name)(observed).eval())
Example #5
0
 def test_prob_with_group_events_ndims(self):
     """`StochasticTensor.prob` must honour an explicit `group_event_ndims`."""
     with self.get_session():
         distrib = Normal(
             np.asarray(0., dtype=np.float32),
             np.asarray([1.0, 2.0, 3.0], dtype=np.float32),
             group_event_ndims=1,
         )
         observed = np.asarray([[-1., 1., 2.], [0., 0., 0.]])
         t = StochasticTensor(distrib, observed=observed)
         # The distribution was constructed with group_event_ndims=1; pass
         # both 0 and 1 explicitly and compare against the distribution
         # evaluated on the StochasticTensor itself.
         for ndims in (0, 1):
             np.testing.assert_allclose(
                 t.prob(group_event_ndims=ndims).eval(),
                 distrib.prob(t, group_event_ndims=ndims).eval())
Example #6
0
    def test_attributes_from_distribution(self):
        """A `StochasticTensor` must mirror selected flags of its distribution."""
        with self.get_session():
            distrib = Normal(0., 1.)
            t = StochasticTensor(distrib, tf.constant(0.))

            mirrored_flags = ('is_continuous', 'is_reparameterized')
            for flag in mirrored_flags:
                from_distrib = getattr(distrib, flag)
                from_tensor = getattr(t, flag)
                self.assertEqual(from_distrib, from_tensor,
                                 msg='attribute %r mismatch.' % flag)
Example #7
0
    def test_session_run_issue_49(self):
        """Regression test: sampling a tensor-parameterized `Normal` must run.

        Guards the fix for https://github.com/thu-ml/zhusuan/issues/49.
        """
        param_shape = [1, 2]
        sample = Normal(mean=tf.zeros(param_shape),
                        logstd=tf.zeros(param_shape)).sample()

        with self.get_session() as sess:
            sess.run(tf.global_variables_initializer())
            # Only checks that evaluation succeeds; the value is discarded.
            sess.run(sample)
Example #8
0
 def test_convert_to_tensor_if_dynamic(self):
     """`convert_to_tensor_if_dynamic` converts only dynamic tensor-likes.

     Placeholders, variables and `StochasticTensor` objects must become a
     `tf.Tensor`. Every other value must be returned unchanged (the very same
     object).
     """
     # Build the graph fixtures in a private graph. `tf.get_variable('v', ...)`
     # raises ValueError when a variable named 'v' already exists in the
     # current default graph, e.g. one created by another test or by a re-run
     # of this one.
     with tf.Graph().as_default():
         dynamic_values = [
             tf.placeholder(tf.int32, ()),
             tf.get_variable('v', shape=(), dtype=tf.int32),
             StochasticTensor(Normal(0., 1.), 1.)
         ]
         for v in dynamic_values:
             self.assertIsInstance(convert_to_tensor_if_dynamic(v), tf.Tensor)
     for v in [1, 1.0, object(), (), [], {}, np.array([1, 2, 3])]:
         self.assertIs(convert_to_tensor_if_dynamic(v), v)
    def test_basic(self):
        """`gather_log_lower_bound` returns one log-prob tensor per input.

        Each result must be a `tf.Tensor` of shape [16]. It must equal the
        sample's own `distribution.log_prob`, evaluated with
        group_event_ndims=1 for `a` (as constructed) and 0 for `b`.
        """
        with self.get_session() as sess:
            a = Normal(0., np.asarray([1., 2., 3.]),
                       group_event_ndims=1).sample_n(16)
            b = Normal(1., 2.).sample_n(16)
            a_prob, b_prob = gather_log_lower_bound([a, b])

            checks = ((a_prob, a, 1), (b_prob, b, 0))
            for prob, sample, ndims in checks:
                self.assertIsInstance(prob, tf.Tensor)
                self.assertEqual(prob.get_shape().as_list(), [16])
                got, want = sess.run(
                    [prob,
                     sample.distribution.log_prob(
                         sample, group_event_ndims=ndims)])
                np.testing.assert_almost_equal(got, want)
Example #10
0
 def test_is_dynamic_tensor_like(self):
     """`is_dynamic_tensor_like` accepts only dynamic tensor-like objects.

     Placeholders, variables and `StochasticTensor` objects must be reported
     as dynamic. Python scalars, containers, plain objects and NumPy arrays
     must not be.
     """
     # Build the graph fixtures in a private graph. `tf.get_variable('v', ...)`
     # raises ValueError when a variable named 'v' already exists in the
     # current default graph, e.g. one created by another test or by a re-run
     # of this one.
     with tf.Graph().as_default():
         dynamic_values = [
             tf.placeholder(tf.int32, ()),
             tf.get_variable('v', shape=(), dtype=tf.int32),
             StochasticTensor(Normal(0., 1.), 1.)
         ]
         for v in dynamic_values:
             self.assertTrue(
                 is_dynamic_tensor_like(v),
                 msg='%r should be interpreted as a dynamic tensor.' % (v, ))
     for v in [1, 1.0, object(), (), [], {}, np.array([1, 2, 3])]:
         self.assertFalse(
             is_dynamic_tensor_like(v),
             msg='%r should not be interpreted as a dynamic tensor.' %
             (v, ))
    def test_integrated(self):
        """Importance-sampled E[log N(x; 0, 1)] must match its closed form.

        Samples are drawn from N(1, 2) and reweighted towards N(0, 1). The
        estimate reduced over the sample axis must lie within 1e-2 of
        -0.5 * (log(2 * pi) + 1).
        """
        with self.get_session(use_gpu=True) as sess:
            tf.set_random_seed(1234)
            target, proposal = Normal(0., 1.), Normal(1., 2.)
            samples = proposal.sample_n(10000)
            log_target = target.log_prob(samples)
            log_proposal = proposal.log_prob(samples)

            # test integrated equality
            estimate = sess.run(importance_sampling(
                log_target, log_target, log_proposal, latent_axis=0))
            offset = 0.5 * (np.log(2 * np.pi) + 1)
            self.assertLess(np.abs(estimate + offset), 1e-2)
    def test_element_wise(self):
        """Without `latent_axis`, `importance_sampling` works per sample.

        The result must equal ``log_p * exp(log_p - log_q)`` element-wise.
        """
        with self.get_session(use_gpu=True) as sess:
            tf.set_random_seed(1234)
            target, proposal = Normal(0., 1.), Normal(1., 2.)
            samples = proposal.sample_n(1000)
            log_p = target.log_prob(samples)
            log_q = proposal.log_prob(samples)

            # test element-wise equality
            estimate_op = importance_sampling(log_p, log_p, log_q)
            reference_op = log_p * tf.exp(log_p - log_q)
            output, expected = sess.run([estimate_op, reference_op])
            np.testing.assert_almost_equal(output, expected)