Ejemplo n.º 1
0
 def test_log_prob_with_group_events_ndims(self):
     """Check that an explicit `group_event_ndims` override on
     `StochasticTensor.log_prob` matches the distribution's `log_prob`.

     The distribution is constructed with ``group_event_ndims=1``; both the
     override values 0 and 1 are compared.
     """
     with self.get_session():
         normal = Normal(np.asarray(0., dtype=np.float32),
                         np.asarray([1.0, 2.0, 3.0], dtype=np.float32),
                         group_event_ndims=1)
         # NOTE(review): this observation is float64 while the distribution
         # parameters are float32 -- presumably StochasticTensor converts it;
         # confirm.
         samples = np.asarray([[-1., 1., 2.], [0., 0., 0.]])
         tensor = StochasticTensor(normal, observed=samples)
         for ndims in (0, 1):
             actual = tensor.log_prob(group_event_ndims=ndims).eval()
             expected = normal.log_prob(
                 tensor, group_event_ndims=ndims).eval()
             np.testing.assert_allclose(actual, expected)
    def test_integrated(self):
        """Importance-sampled estimate of E_p[log p(x)] for p = N(0, 1).

        Samples are drawn from the proposal q = N(1, 2). With
        ``latent_axis=0`` the estimator is expected to reduce over the sample
        axis (presumably a mean -- confirm against `importance_sampling`). The
        result should approach the negative entropy of a standard normal,
        -0.5 * (log(2*pi) + 1), to within 1e-2.
        """
        with self.get_session(use_gpu=True) as sess:
            # The graph-level seed must be set before any random op is
            # created. The op creation order below is kept fixed, because
            # TF1 derives op-level seeds from it.
            tf.set_random_seed(1234)
            target = Normal(0., 1.)
            proposal = Normal(1., 2.)
            samples = proposal.sample_n(10000)
            log_target = target.log_prob(samples)
            log_proposal = proposal.log_prob(samples)

            # Entropy of N(0, 1); E_p[log p] is its negation.
            std_normal_entropy = 0.5 * (np.log(2 * np.pi) + 1)
            estimate = sess.run(
                importance_sampling(log_target, log_target, log_proposal,
                                    latent_axis=0))
            self.assertLess(np.abs(estimate + std_normal_entropy), 1e-2)
    def test_element_wise(self):
        """Without `latent_axis`, `importance_sampling` should return the
        per-sample weighted values log_p * exp(log_p - log_q)."""
        with self.get_session(use_gpu=True) as sess:
            tf.set_random_seed(1234)
            target = Normal(0., 1.)
            proposal = Normal(1., 2.)
            samples = proposal.sample_n(1000)
            log_target = target.log_prob(samples)
            log_proposal = proposal.log_prob(samples)

            weighted = importance_sampling(log_target, log_target,
                                           log_proposal)
            reference = log_target * tf.exp(log_target - log_proposal)
            # A single run is used so both tensors see the same samples.
            actual, wanted = sess.run([weighted, reference])
            np.testing.assert_almost_equal(actual, wanted)
Ejemplo n.º 4
0
 def test_prob_and_log_prob(self):
     """Check that `StochasticTensor.log_prob()` and `.prob()` equal the
     wrapped distribution's `log_prob` / `prob` evaluated on the observed
     array."""
     with self.get_session():
         normal = Normal(np.asarray(0., dtype=np.float32),
                         np.asarray([1.0, 2.0, 3.0], dtype=np.float32))
         values = np.arange(24, dtype=np.float32).reshape([4, 2, 3])
         tensor = StochasticTensor(normal, observed=values)
         method_pairs = (
             (tensor.log_prob, normal.log_prob),
             (tensor.prob, normal.prob),
         )
         for from_tensor, from_distrib in method_pairs:
             np.testing.assert_almost_equal(from_tensor().eval(),
                                            from_distrib(values).eval())