def test_prob_and_log_prob(self):
    """Check that ``ZhuSuanDistribution`` reproduces ZhuSuan densities.

    The wrapper's ``log_prob`` / ``prob`` (static shape and evaluated
    value) are compared against a ``zd.Normal`` built with the matching
    ``group_ndims``, for the default, static, dynamic and invalid
    ``group_ndims`` cases. The test also checks that the wrapper overrides
    the ``group_ndims`` of the wrapped ZhuSuan distribution.
    """
    x = tf.reshape(tf.range(24, dtype=tf.float32), [2, 3, 4]) / 24.
    normal = zd.Normal(mean=tf.zeros([3, 4]), std=tf.ones([3, 4]))
    normal1 = zd.Normal(mean=tf.zeros([3, 4]), std=tf.ones([3, 4]),
                        group_ndims=1)

    def assert_same_density(distrib, expected, **kwargs):
        # Compare `distrib` (called with `kwargs`, e.g. `group_ndims`)
        # against `expected` (called without arguments) on `x`.
        with self.test_session():
            self.assertEqual(
                distrib.log_prob(x, **kwargs).get_shape(),
                expected.log_prob(x).get_shape())
            self.assertEqual(
                distrib.prob(x, **kwargs).get_shape(),
                expected.prob(x).get_shape())
            np.testing.assert_allclose(
                distrib.log_prob(x, **kwargs).eval(),
                expected.log_prob(x).eval())
            # `x` and the Normal parameters are float32. The default
            # rtol=1e-7 is below float32 epsilon (~1.2e-7), so allow a few
            # ulps of slack on the exponentiated density.
            np.testing.assert_allclose(
                distrib.prob(x, **kwargs).eval(),
                expected.prob(x).eval(), rtol=1e-5)

    # test with default group_ndims
    distrib = ZhuSuanDistribution(normal)
    assert_same_density(distrib, normal)

    # test with static group_ndims
    assert_same_density(distrib, normal1, group_ndims=1)

    # test with dynamic group_ndims
    group_ndims = tf.constant(1, dtype=tf.int32)
    normal1d = zd.Normal(mean=normal.mean, std=normal.std,
                         group_ndims=group_ndims)
    assert_same_density(distrib, normal1d, group_ndims=group_ndims)

    # test with bad dynamic group_ndims: the error surfaces at eval time
    group_ndims = tf.constant(-1, dtype=tf.int32)
    with self.test_session():
        with pytest.raises(Exception,
                           match='group_ndims must be non-negative'):
            _ = distrib.log_prob(x, group_ndims=group_ndims).eval()
        with pytest.raises(Exception,
                           match='group_ndims must be non-negative'):
            _ = distrib.prob(x, group_ndims=group_ndims).eval()

    # test override the default group_ndims in ZhuSuan distribution:
    # wrapping `normal1` (group_ndims=1) must still behave like `normal`
    # (group_ndims=0) when no group_ndims is passed to the wrapper.
    distrib = ZhuSuanDistribution(normal1)
    assert_same_density(distrib, normal)
def test_prob_and_log_prob(self):
    """Verify ``ZhuSuanDistribution`` densities against ``zd.Normal``.

    The cases covered are:

    * default, static and dynamic ``group_ndims``;
    * a negative dynamic ``group_ndims``, which must raise;
    * overriding the wrapped distribution's own ``group_ndims``;
    * whether ``sample`` attaches a log-density depending on
      ``compute_density``.
    """
    x = tf.reshape(tf.range(24, dtype=tf.float32), [2, 3, 4]) / 24.
    normal = zd.Normal(mean=tf.zeros([3, 4]), std=tf.ones([3, 4]))
    normal1 = zd.Normal(mean=tf.zeros([3, 4]), std=tf.ones([3, 4]),
                        group_ndims=1)

    def check_density(wrapped, reference, prob_rtol=1e-7,
                      compare_shapes=True, **call_kwargs):
        # `call_kwargs` go only to the wrapper. `prob_rtol=1e-7` equals
        # numpy's default tolerance for `assert_allclose`.
        with self.test_session():
            if compare_shapes:
                self.assertEqual(
                    wrapped.log_prob(x, **call_kwargs).get_shape(),
                    reference.log_prob(x).get_shape())
                self.assertEqual(
                    wrapped.prob(x, **call_kwargs).get_shape(),
                    reference.prob(x).get_shape())
            np.testing.assert_allclose(
                wrapped.log_prob(x, **call_kwargs).eval(),
                reference.log_prob(x).eval())
            np.testing.assert_allclose(
                wrapped.prob(x, **call_kwargs).eval(),
                reference.prob(x).eval(), rtol=prob_rtol)

    # default group_ndims
    distrib = ZhuSuanDistribution(normal)
    check_density(distrib, normal)

    # static group_ndims
    check_density(distrib, normal1, prob_rtol=1e-5, group_ndims=1)

    # dynamic group_ndims.  Static shapes are deliberately not compared:
    # the auxiliary asserts added to reduce_mean in our log_prob make the
    # wrapper's static shape differ from the ZhuSuan one here.
    dynamic_ndims = tf.constant(1, dtype=tf.int32)
    normal1d = zd.Normal(mean=normal.mean, std=normal.std,
                         group_ndims=dynamic_ndims)
    check_density(distrib, normal1d, prob_rtol=1e-5,
                  compare_shapes=False, group_ndims=dynamic_ndims)

    # a negative dynamic group_ndims must fail when evaluated
    negative_ndims = tf.constant(-1, dtype=tf.int32)
    with self.test_session():
        for density_fn in (distrib.log_prob, distrib.prob):
            with pytest.raises(Exception,
                               match='group_ndims must be non-negative'):
                _ = density_fn(x, group_ndims=negative_ndims).eval()

    # the wrapper ignores normal1's own group_ndims=1 by default
    distrib = ZhuSuanDistribution(normal1)
    check_density(distrib, normal, prob_rtol=1e-5)

    # compute_density: only an explicit True attaches a log-density
    distrib = ZhuSuanDistribution(normal1)
    self.assertIsNone(distrib.sample()._self_log_prob)
    self.assertIsNone(
        distrib.sample(compute_density=False)._self_log_prob)
    self.assertIsNotNone(
        distrib.sample(compute_density=True)._self_log_prob)