def test_invert_flow(self):
    """Check ``Flow.invert()`` and the ``InvertFlow`` it produces.

    Three scenarios are covered:

    * inverting an ordinary ``QuadraticFlow`` and running its inverse
      transform against a numpy reference (``quadratic_transform``);
    * inverting the resulting ``InvertFlow`` again, which must hand back
      the very same original flow object;
    * plugging an inverted flow into ``FlowDistribution`` and comparing its
      ``log_prob`` with base log-prob plus the reference log-determinant.
    """
    with self.test_session() as sess:
        # -- scenario 1: invert an ordinary flow --
        forward = QuadraticFlow(2., 5.)
        inverted = forward.invert()
        self.assertIsInstance(inverted, InvertFlow)
        self.assertEqual(inverted.x_value_ndims, 0)
        self.assertEqual(inverted.y_value_ndims, 0)
        self.assertFalse(inverted.require_batch_dims)

        # numpy reference values for the quadratic transform of x_np
        x_np = np.arange(12, dtype=np.float32) + 1.
        expected_y, expected_log_det = quadratic_transform(
            npyops, x_np, 2., 5.)

        # The wrapped flow must not be built until the inverted flow is
        # actually used; the inverted flow's inverse_transform should then
        # reproduce the numpy reference values.
        self.assertFalse(forward._has_built)
        y, log_det_y = inverted.inverse_transform(tf.constant(x_np))
        self.assertTrue(forward._has_built)
        np.testing.assert_allclose(sess.run(y), expected_y)
        np.testing.assert_allclose(sess.run(log_det_y), expected_log_det)
        invertible_flow_standard_check(self, inverted, sess, expected_y)

        # -- scenario 2: inverting an InvertFlow yields the original --
        self.assertIs(inverted.invert(), forward)

        # -- scenario 3: an inverted flow inside FlowDistribution --
        normal = Normal(mean=1., std=2.)
        distrib = FlowDistribution(normal, QuadraticFlow(2., 5.).invert())
        distrib_log_prob = distrib.log_prob(x_np)
        expected_log_prob = normal.log_prob(expected_y) + expected_log_det
        np.testing.assert_allclose(
            *sess.run([distrib_log_prob, expected_log_prob]))
def test_with_normal(self):
    """Check ``batch_ndims_to_value(1)`` applied to a ``Normal``.

    After moving one batch dimension into the value dimensions, the wrapper
    evaluated with ``group_ndims=k`` must agree with the base ``Normal``
    evaluated with ``group_ndims=k + 1``. This is checked for ``log_prob``
    and ``prob`` on a fixed input, and for the densities attached to
    samples drawn from the wrapper.
    """
    mean = np.random.normal(size=[4, 5]).astype(np.float64)
    logstd = np.random.normal(size=mean.shape).astype(np.float64)
    x = np.random.normal(size=[3, 4, 5])

    with self.test_session() as sess:
        def check_close(actual, expected):
            # Evaluate both tensors in ONE run so that any random sample
            # they share takes the same value on both sides.
            np.testing.assert_allclose(*sess.run([actual, expected]))

        normal = Normal(mean=mean, logstd=logstd)
        distrib = normal.batch_ndims_to_value(1)

        # static properties of the wrapper
        self.assertIsInstance(distrib, BatchToValueDistribution)
        self.assertEqual(distrib.value_ndims, 1)
        self.assertEqual(distrib.get_batch_shape().as_list(), [4])
        self.assertEqual(list(sess.run(distrib.batch_shape)), [4])
        self.assertEqual(distrib.dtype, tf.float64)
        self.assertTrue(distrib.is_continuous)
        self.assertTrue(distrib.is_reparameterized)
        self.assertIs(distrib.base_distribution, normal)

        # densities on the fixed input x: first log_prob, then prob
        method_pairs = ((distrib.log_prob, normal.log_prob),
                        (distrib.prob, normal.prob))
        for wrapped_fn, base_fn in method_pairs:
            ungrouped = wrapped_fn(x)
            grouped = wrapped_fn(x, group_ndims=1)
            self.assertEqual(get_static_shape(ungrouped), (3, 4))
            self.assertEqual(get_static_shape(grouped), (3, ))
            check_close(ungrouped, base_fn(x, group_ndims=1))
            check_close(grouped, base_fn(x, group_ndims=2))

        # densities attached to samples, compared with the base Normal
        # evaluated on the very same sample tensors
        plain_sample = distrib.sample(3, compute_density=False)
        grouped_sample = distrib.sample(
            3, compute_density=True, group_ndims=1)
        plain_log_p = plain_sample.log_prob()
        grouped_log_p = grouped_sample.log_prob()
        self.assertEqual(get_static_shape(plain_log_p), (3, 4))
        self.assertEqual(get_static_shape(grouped_log_p), (3, ))
        check_close(plain_log_p,
                    normal.log_prob(plain_sample, group_ndims=1))
        check_close(grouped_log_p,
                    normal.log_prob(grouped_sample, group_ndims=2))