def test_transform_on_transformed(self):
    """Check `expand_value_ndims` on an already transformed distribution.

    A `Normal` with batch shape ``[3, 4, 5]`` is first wrapped through
    ``batch_ndims_to_value(1)``; the wrapper is then expanded by one more
    dimension and by zero dimensions.  At each step the value ndims, the
    static batch shape and the dynamic batch shape are verified.
    """
    with self.test_session() as session:

        def check_shapes(dist, value_ndims, batch_shape):
            # Verifies value ndims, then the static batch shape, then the
            # batch shape evaluated through the session.
            self.assertEqual(dist.value_ndims, value_ndims)
            self.assertEqual(dist.get_batch_shape().as_list(), batch_shape)
            self.assertEqual(
                list(session.run(dist.batch_shape)), batch_shape)

        base = Normal(mean=tf.zeros([3, 4, 5]), logstd=0.)
        check_shapes(base, 0, [3, 4, 5])

        # Converting zero batch dimensions returns the very same object.
        self.assertIs(base.batch_ndims_to_value(0), base)

        # One batch dimension becomes a value dimension.
        wrapped = base.batch_ndims_to_value(1)
        self.assertIsInstance(wrapped, BatchToValueDistribution)
        check_shapes(wrapped, 1, [3, 4])
        self.assertIs(wrapped.base_distribution, base)

        # Expanding the wrapper by one more dimension.
        expanded = wrapped.expand_value_ndims(1)
        self.assertIsInstance(expanded, BatchToValueDistribution)
        check_shapes(expanded, 2, [3])
        # NOTE(review): this inspects `wrapped`, not `expanded` -- possibly
        # `expanded.base_distribution` was intended; confirm before changing.
        self.assertIs(wrapped.base_distribution, base)

        # Expanding by zero dimensions returns the wrapper itself.
        unchanged = wrapped.expand_value_ndims(0)
        self.assertIs(unchanged, wrapped)
        self.assertEqual(unchanged.value_ndims, 1)
        self.assertEqual(wrapped.value_ndims, 1)
        self.assertEqual(unchanged.get_batch_shape().as_list(), [3, 4])
        self.assertEqual(
            list(session.run(unchanged.batch_shape)), [3, 4])
        self.assertIs(wrapped.base_distribution, base)
def test_ndims_equals_zero_and_negative(self):
    """Check the `ndims == 0` and `ndims < 0` cases of both transforms.

    With ``ndims=0`` both ``batch_ndims_to_value`` and
    ``expand_value_ndims`` must hand back the distribution itself; with
    ``ndims=-1`` both must raise ``ValueError``.
    """
    base = Normal(mean=tf.zeros([3, 4]), logstd=0.)
    transforms = (base.batch_ndims_to_value, base.expand_value_ndims)

    # Zero dimensions: the same object is returned.
    for transform in transforms:
        self.assertIs(transform(0), base)

    # Negative dimensions: rejected with a ValueError.
    for transform in transforms:
        with pytest.raises(
                ValueError,
                match='`ndims` must be non-negative integers'):
            transform(-1)
def test_with_normal(self):
    """Compare a `batch_ndims_to_value(1)` wrapper against its `Normal`.

    The wrapper's ``log_prob`` / ``prob`` with ``group_ndims=k`` are
    checked against the base distribution's results with
    ``group_ndims=k + 1``, both for a given input and for samples drawn
    from the wrapper.
    """
    mu = np.random.normal(size=[4, 5]).astype(np.float64)
    log_sigma = np.random.normal(size=mu.shape).astype(np.float64)
    given = np.random.normal(size=[3, 4, 5])

    with self.test_session() as session:
        base = Normal(mean=mu, logstd=log_sigma)
        wrapped = base.batch_ndims_to_value(1)

        # Basic properties of the wrapper.
        self.assertIsInstance(wrapped, BatchToValueDistribution)
        self.assertEqual(wrapped.value_ndims, 1)
        self.assertEqual(wrapped.get_batch_shape().as_list(), [4])
        self.assertEqual(list(session.run(wrapped.batch_shape)), [4])
        self.assertEqual(wrapped.dtype, tf.float64)
        self.assertTrue(wrapped.is_continuous)
        self.assertTrue(wrapped.is_reparameterized)
        self.assertIs(wrapped.base_distribution, base)

        # log_prob of a given input.
        lp = wrapped.log_prob(given)
        lp_grouped = wrapped.log_prob(given, group_ndims=1)
        self.assertEqual(get_static_shape(lp), (3, 4))
        self.assertEqual(get_static_shape(lp_grouped), (3, ))
        actual, expected = session.run(
            [lp, base.log_prob(given, group_ndims=1)])
        np.testing.assert_allclose(actual, expected)
        actual, expected = session.run(
            [lp_grouped, base.log_prob(given, group_ndims=2)])
        np.testing.assert_allclose(actual, expected)

        # prob of a given input.
        p = wrapped.prob(given)
        p_grouped = wrapped.prob(given, group_ndims=1)
        self.assertEqual(get_static_shape(p), (3, 4))
        self.assertEqual(get_static_shape(p_grouped), (3, ))
        actual, expected = session.run(
            [p, base.prob(given, group_ndims=1)])
        np.testing.assert_allclose(actual, expected)
        actual, expected = session.run(
            [p_grouped, base.prob(given, group_ndims=2)])
        np.testing.assert_allclose(actual, expected)

        # log_prob attached to samples drawn from the wrapper.  Each pair
        # is fetched in a single run call, as in the original.
        drawn = wrapped.sample(3, compute_density=False)
        drawn_grouped = wrapped.sample(
            3, compute_density=True, group_ndims=1)
        lp = drawn.log_prob()
        lp_grouped = drawn_grouped.log_prob()
        self.assertEqual(get_static_shape(lp), (3, 4))
        self.assertEqual(get_static_shape(lp_grouped), (3, ))
        actual, expected = session.run(
            [lp, base.log_prob(drawn, group_ndims=1)])
        np.testing.assert_allclose(actual, expected)
        actual, expected = session.run(
            [lp_grouped, base.log_prob(drawn_grouped, group_ndims=2)])
        np.testing.assert_allclose(actual, expected)