def test_multipart_bijector(self):
  """Checks the transformed target log-prob for a multipart state.

  A three-part joint prior (a Gamma draw named `scale`, a Uniform draw named
  `concentration`, and a `CholeskyLKJ` draw) is combined with a multivariate
  normal likelihood. An HMC kernel is wrapped in a
  `TransformedTransitionKernel` using the prior's default event-space
  bijector, and the inner kernel's target log-prob at the unconstrained
  initial point is compared with the constrained target log-prob minus the
  bijector's inverse log-det-Jacobian at the matching constrained state.
  """
  seeds = test_util.test_seed_stream()

  joint_prior = tfd.JointDistributionSequential([
      tfd.Gamma(1., 1.),
      lambda scale: tfd.Uniform(0., scale),
      lambda concentration: tfd.CholeskyLKJ(4, concentration),
  ], validate_args=True)

  def likelihood(corr):
    return tfd.MultivariateNormalTriL(scale_tril=corr)

  # Take the last part of a prior draw (the Cholesky factor) and sample one
  # observation from the likelihood it induces.
  prior_corr = joint_prior.sample(seed=seeds())[-1]
  observed = self.evaluate(likelihood(prior_corr).sample(seed=seeds()))

  event_bijector = joint_prior.experimental_default_event_space_bijector()

  def target_log_prob(scale, conc, corr):
    prior_lp = joint_prior.log_prob(scale, conc, corr)
    return prior_lp + likelihood(corr).log_prob(observed)

  hmc = tfp.mcmc.HamiltonianMonteCarlo(
      target_log_prob, num_leapfrog_steps=3, step_size=.5)
  transformed = tfp.mcmc.TransformedTransitionKernel(hmc, event_bijector)

  # One uniform random tensor per state part, shaped like the unconstrained
  # (pre-bijector) event of that part.
  unconstrained_shapes = event_bijector.inverse_event_shape(
      joint_prior.event_shape)
  unconstrained_init = self.evaluate(tuple(
      tf.random.uniform(shape, -2., 2., seed=seeds())
      for shape in unconstrained_shapes))

  constrained_state = event_bijector.forward(unconstrained_init)
  bootstrap = transformed.bootstrap_results(constrained_state)
  stepped_state, stepped_results = transformed.one_step(
      constrained_state, bootstrap, seed=seeds())
  # Run everything once so that bootstrapping and stepping are exercised.
  self.evaluate(
      (constrained_state, bootstrap, stepped_state, stepped_results))

  # Per-part event ranks: two scalars and one matrix.
  constrained_lp = target_log_prob(*constrained_state)
  ildj = event_bijector.inverse_log_det_jacobian(
      constrained_state, [0, 0, 2])
  expected = constrained_lp - ildj
  actual = transformed._inner_kernel.target_log_prob_fn(*unconstrained_init)  # pylint: disable=protected-access
  self.assertAllClose(expected, actual)
def _fn(kernel_size, bias_size, dtype=tf.float32):
  """Builds a Keras model that chains two distribution layers.

  The first layer ignores its input and produces a `tfd.CholeskyLKJ` with
  `dimension = kernel_size + bias_size`. The second layer uses its input as
  the `scale_tril` of a `tfd.MultivariateNormalTriL` whose `loc` is a zero
  vector of that same length.

  NOTE(review): `concentration` is not defined inside this function; it is
  presumably captured from an enclosing scope -- confirm against the caller.

  Args:
    kernel_size: Integer summed with `bias_size` to get the dimension.
    bias_size: Integer summed with `kernel_size` to get the dimension.
    dtype: dtype of the zero `loc` vector of the multivariate normal.

  Returns:
    A `tf.keras.Sequential` holding two `tfp.layers.DistributionLambda`
    layers.
  """
  total_size = kernel_size + bias_size

  def make_cholesky_lkj(t):
    del t  # The layer input does not parameterize this distribution.
    return tfd.CholeskyLKJ(dimension=total_size, concentration=concentration)

  def make_mvn(t):
    return tfd.MultivariateNormalTriL(
        loc=tf.zeros(total_size, dtype), scale_tril=t)

  return tf.keras.Sequential([
      tfp.layers.DistributionLambda(make_cholesky_lkj),
      tfp.layers.DistributionLambda(make_mvn),
  ])
def test_bijector_shapes(self):
  """Checks event-shape propagation through `Sample`'s default bijector.

  Two cases are covered: a `Uniform` base, where forward and inverse event
  shapes come back unchanged, and a `CholeskyLKJ` base, where each trailing
  4x4 matrix corresponds to a length-6 vector on the unconstrained side.
  """
  # Shape-preserving case.
  unif = tfd.Sample(tfd.Uniform(tf.zeros([5]), 1.), 2)
  unif_bij = unif.experimental_default_event_space_bijector()
  self.assertEqual((2,), unif.event_shape)
  for shape in ((2,), (5, 2), (3, 5, 2)):
    self.assertEqual(shape, unif_bij.inverse_event_shape(shape))
    self.assertEqual(shape, unif_bij.forward_event_shape(shape))

  # Shape-changing case: matrix events on one side, vectors on the other.
  lkj = tfd.Sample(tfd.CholeskyLKJ(4, concentration=tf.ones([5])), 2)
  lkj_bij = lkj.experimental_default_event_space_bijector()
  self.assertEqual((2, 4, 4), lkj.event_shape)
  vec_size = (4 * 3) // 2
  for leading in ((5, 2), (3, 5, 2)):
    matrix_shape = leading + (4, 4)
    vector_shape = leading + (vec_size,)
    self.assertEqual(vector_shape, lkj_bij.inverse_event_shape(matrix_shape))
    self.assertEqual(matrix_shape, lkj_bij.forward_event_shape(vector_shape))
def test_bijector_cholesky_lkj(self):
  """Compares `Sample`'s default bijector against an `Independent` reference.

  For several sample/batch shape combinations this test:
    * round-trips a draw through `inverse` then `forward`;
    * checks that forward and inverse log-det-Jacobians agree with those of
      the default bijector of an `Independent(CholeskyLKJ)` whose
      concentration carries the sample dims as trailing batch dims;
    * checks sliced-vs-duplicated inputs along the sample dim.

  Each section builds its own `d`/`b` and matching `d2`/`b2`, and the
  `event_ndims=len([...])` literals spell out the trailing dims of the
  tensor being reduced.
  """
  # Let's try with a shape-shifting underlying bijector vec=>mat.
  d = tfd.Sample(tfd.CholeskyLKJ(4, concentration=tf.ones([5])), 2)
  b = d.experimental_default_event_space_bijector()
  y = self.evaluate(d.sample(seed=test_util.test_seed()))
  # NOTE(review): `+ 0` presumably yields a fresh tensor so that `forward`
  # is recomputed rather than served from a bijector cache -- confirm.
  x = b.inverse(y) + 0
  self.assertAllClose(y, b.forward(x))
  y = self.evaluate(d.sample(7, seed=test_util.test_seed()))
  x = b.inverse(y) + 0
  self.assertAllClose(y, b.forward(x))
  d2 = tfd.Independent(tfd.CholeskyLKJ(4, concentration=tf.ones([5, 2])),
                       reinterpreted_batch_ndims=1)
  b2 = d2.experimental_default_event_space_bijector()
  self.assertAllClose(
      b2.forward_log_det_jacobian(x, event_ndims=len([2, 6])),
      b.forward_log_det_jacobian(x, event_ndims=len([2, 6])))
  # Keep only the first entry along the sample dim, then duplicate it so the
  # duplicated tensor has the full sample dim again.
  x_sliced = x[..., :1, :]
  x_bcast = tf.concat([x_sliced, x_sliced], axis=-2)
  self.assertAllClose(
      b2.forward_log_det_jacobian(x_bcast, event_ndims=len([2, 6])),
      b.forward_log_det_jacobian(x_sliced, event_ndims=len([1, 6])))
  # Should this test pass? Right now, Independent's default bijector does not
  # broadcast the LDJ to match underlying batch shape.
  # self.assertAllClose(
  #     b2.forward_log_det_jacobian(x_sliced, event_ndims=len([1, 6])),
  #     b.forward_log_det_jacobian(x_bcast, event_ndims=len([2, 6])))

  self.assertAllClose(
      b2.inverse_log_det_jacobian(y, event_ndims=len([2, 4, 4])),
      b.inverse_log_det_jacobian(y, event_ndims=len([2, 4, 4])))
  y_sliced = y[..., :1, :, :]
  y_bcast = tf.concat([y_sliced, y_sliced], axis=-3)
  self.assertAllClose(
      b2.inverse_log_det_jacobian(y_bcast, event_ndims=len([2, 4, 4])),
      b.inverse_log_det_jacobian(y_sliced, event_ndims=len([1, 4, 4])))
  # Should this test pass? Right now, Independent's default bijector does not
  # broadcast the LDJ to match underlying batch shape.
  # self.assertAllClose(
  #     b2.inverse_log_det_jacobian(y_sliced, event_ndims=len([1, 4, 4])),
  #     b.inverse_log_det_jacobian(y_bcast, event_ndims=len([2, 4, 4])))

  # Forward LDJ at the sliced input should be the negated inverse LDJ at the
  # duplicated output.
  self.assertAllClose(
      b.forward_log_det_jacobian(x_sliced, event_ndims=len([2, 6])),
      -b.inverse_log_det_jacobian(y_bcast, event_ndims=len([2, 4, 4])),
      rtol=1e-5)

  # Now, with another sample shape.
  d = tfd.Sample(tfd.CholeskyLKJ(4, concentration=tf.ones([5])), [2, 7])
  b = d.experimental_default_event_space_bijector()
  y = self.evaluate(d.sample(11, seed=test_util.test_seed()))
  x = b.inverse(y) + 0
  self.assertAllClose(y, b.forward(x))
  d2 = tfd.Independent(tfd.CholeskyLKJ(4, concentration=tf.ones([5, 2, 7])),
                       reinterpreted_batch_ndims=2)
  b2 = d2.experimental_default_event_space_bijector()
  self.assertAllClose(
      b2.forward_log_det_jacobian(x, event_ndims=len([2, 7, 6])),
      b.forward_log_det_jacobian(x, event_ndims=len([2, 7, 6])))
  self.assertAllClose(
      b2.inverse_log_det_jacobian(y, event_ndims=len([2, 7, 4, 4])),
      b.inverse_log_det_jacobian(y, event_ndims=len([2, 7, 4, 4])))

  # Now, with another batch shape.
  d = tfd.Sample(tfd.CholeskyLKJ(4, concentration=tf.ones([5, 7])), 2)
  b = d.experimental_default_event_space_bijector()
  y = self.evaluate(d.sample(11, seed=test_util.test_seed()))
  x = b.inverse(y) + 0
  self.assertAllClose(y, b.forward(x))
  # Rebuild the `Independent` reference for this batch shape instead of
  # reusing the one from the previous section: the batch is now [5, 7] and
  # the single sample dim (2) becomes the trailing concentration dim.
  d2 = tfd.Independent(tfd.CholeskyLKJ(4, concentration=tf.ones([5, 7, 2])),
                       reinterpreted_batch_ndims=1)
  b2 = d2.experimental_default_event_space_bijector()
  self.assertAllClose(
      b2.forward_log_det_jacobian(x, event_ndims=len([2, 6])),
      b.forward_log_det_jacobian(x, event_ndims=len([2, 6])))
  self.assertAllClose(
      b2.inverse_log_det_jacobian(y, event_ndims=len([2, 4, 4])),
      b.inverse_log_det_jacobian(y, event_ndims=len([2, 4, 4])))

  # Also verify it properly handles an extra "event" dim.
  self.assertAllClose(
      b2.forward_log_det_jacobian(x, event_ndims=len([7, 2, 6])),
      b.forward_log_det_jacobian(x, event_ndims=len([7, 2, 6])))
  self.assertAllClose(
      b2.inverse_log_det_jacobian(y, event_ndims=len([7, 2, 4, 4])),
      b.inverse_log_det_jacobian(y, event_ndims=len([7, 2, 4, 4])))