def testTransformLogProbFn(self):
  """Transforming a positional log-prob fn remaps state and adds Jacobians."""

  def untransformed_fn(x, y):
    # Two independent normals; the trailing () follows the (log_prob, extra)
    # return contract used throughout fun_mcmc.
    lp_x = tfp.distributions.Normal(self._constant(0.), 1.).log_prob(x)
    lp_y = tfp.distributions.Normal(self._constant(1.), 1.).log_prob(y)
    return lp_x + lp_y, ()

  scale_bijectors = [
      tfp.bijectors.AffineScalar(scale=self._constant(2.)),
      tfp.bijectors.AffineScalar(scale=self._constant(3.)),
  ]

  transformed_log_prob_fn, transformed_init_state = (
      fun_mcmc.transform_log_prob_fn(
          untransformed_fn, scale_bijectors,
          [self._constant(2.), self._constant(3.)]))

  # The initial state is pulled back through the bijectors: 2/2 and 3/3.
  self.assertIsInstance(transformed_init_state, list)
  self.assertAllClose([1., 1.], transformed_init_state)

  tlp, (orig_space, _) = transformed_log_prob_fn(
      self._constant(1.), self._constant(1.))

  # Expected value: the untransformed log-prob at the forward-mapped point,
  # plus every bijector's log-det-Jacobian at the transformed point.
  jacobian_correction = sum(
      b.forward_log_det_jacobian(self._constant(1.), event_ndims=0)
      for b in scale_bijectors)
  expected_tlp = untransformed_fn(
      self._constant(2.), self._constant(3.))[0] + jacobian_correction

  self.assertAllClose([2., 3.], orig_space)
  self.assertAllClose(expected_tlp, tlp)
def testTransformLogProbFnKwargs(self):
  """Transforming a log-prob fn with dict-valued state works via kwargs."""

  def untransformed_fn(x, y):
    return tfd.Normal(0., 1.).log_prob(x) + tfd.Normal(1., 1.).log_prob(y), ()

  scale_bijectors = {
      'x': tfb.AffineScalar(scale=2.),
      'y': tfb.AffineScalar(scale=3.),
  }

  transformed_log_prob_fn, transformed_init_state = (
      fun_mcmc.transform_log_prob_fn(untransformed_fn, scale_bijectors, {
          'x': 2.,
          'y': 3.
      }))

  # Initial state is mapped through the inverse bijectors: 2/2 and 3/3.
  self.assertIsInstance(transformed_init_state, dict)
  self.assertAllClose({'x': 1., 'y': 1.}, transformed_init_state)

  tlp, (orig_space, _) = transformed_log_prob_fn(x=1., y=1.)

  # Expected value: log-prob at the forward-mapped point plus each bijector's
  # log-det-Jacobian evaluated at the transformed point.
  jacobian_correction = sum(
      b.forward_log_det_jacobian(1., event_ndims=0)
      for b in scale_bijectors.values())
  expected_tlp = untransformed_fn(x=2., y=3.)[0] + jacobian_correction

  self.assertAllClose({'x': 2., 'y': 3.}, orig_space)
  self.assertAllClose(expected_tlp, tlp)
def testPreconditionedHMC(self):
  """HMC on a bijector-transformed target recovers the target's moments.

  Runs HMC in the preconditioned (pre-bijector) space and checks that the
  post-warmup chain's mean/covariance match direct samples from the target.
  """
  step_size = self._constant(0.2)
  num_steps = 2000
  num_leapfrog_steps = 10
  state = tf.ones([16, 2], dtype=self._dtype)

  base_mean = self._constant([1., 0])
  base_cov = self._constant([[1, 0.5], [0.5, 1]])

  bijector = tfp.bijectors.Softplus()
  base_dist = tfp.distributions.MultivariateNormalFullCovariance(
      loc=base_mean, covariance_matrix=base_cov)
  target_dist = bijector(base_dist)

  def orig_target_log_prob_fn(x):
    return target_dist.log_prob(x), ()

  # Run the chain in the pre-bijector space; state_extra[0] carries the
  # original-space point.
  target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
      orig_target_log_prob_fn, bijector, state)

  # pylint: disable=g-long-lambda
  def kernel(hmc_state, seed):
    # On JAX, seeds must be split and threaded explicitly; on TF a fresh
    # stateful seed per step suffices.
    if not self._is_on_jax:
      hmc_seed = _test_seed()
    else:
      hmc_seed, seed = util.split_seed(seed, 2)
    hmc_state, _ = fun_mcmc.hamiltonian_monte_carlo(
        hmc_state,
        step_size=step_size,
        num_integrator_steps=num_leapfrog_steps,
        target_log_prob_fn=target_log_prob_fn,
        seed=hmc_seed)
    return (hmc_state, seed), hmc_state.state_extra[0]

  if not self._is_on_jax:
    seed = _test_seed()
  else:
    seed = self._make_seed(_test_seed())

  # Subtle: Unlike TF, JAX needs a data dependency from the inputs to outputs
  # for the jit to do anything.
  _, chain = tf.function(lambda state, seed: fun_mcmc.trace(  # pylint: disable=g-long-lambda
      state=(fun_mcmc.hamiltonian_monte_carlo_init(state, target_log_prob_fn),
             seed),
      fn=kernel,
      num_steps=num_steps))(state, seed)
  # Discard the warmup samples.
  chain = chain[1000:]

  sample_mean = tf.reduce_mean(chain, axis=[0, 1])
  sample_cov = tfp.stats.covariance(chain, sample_axis=[0, 1])

  true_samples = target_dist.sample(4096, seed=self._make_seed(_test_seed()))

  true_mean = tf.reduce_mean(true_samples, axis=0)
  # Fix: the reference covariance must come from the direct samples, not from
  # `chain` (which would compare the chain's covariance against itself).
  true_cov = tfp.stats.covariance(true_samples, sample_axis=0)

  self.assertAllClose(true_mean, sample_mean, rtol=0.1, atol=0.1)
  self.assertAllClose(true_cov, sample_cov, rtol=0.1, atol=0.1)
def computation(state, seed):
  """Runs step-size-adapted HMC on a Softplus-transformed MVN target.

  Relies on closure variables from the enclosing test (`base_mean`,
  `base_cov`, `num_leapfrog_steps`, `num_adapt_steps`, `num_steps`,
  `step_size`, `self`) — presumably set up by the caller; verify there.

  Returns:
    chain: traced original-space states from every step.
    log_accept_ratio_trace: per-step HMC log accept ratios.
    true_samples: direct samples from the target for comparison.
  """
  bijector = tfp.bijectors.Softplus()
  base_dist = tfp.distributions.MultivariateNormalFullCovariance(
      loc=base_mean, covariance_matrix=base_cov)
  target_dist = bijector(base_dist)

  def orig_target_log_prob_fn(x):
    return target_dist.log_prob(x), ()

  # Run the chain in the pre-bijector space; the transformed fn reports the
  # original-space point via state_extra[0].
  target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
      orig_target_log_prob_fn, bijector, state)

  def kernel(hmc_state, step_size_state, step, seed):
    # On JAX seeds must be split and threaded; on TF a fresh stateful seed
    # per step suffices.
    if not self._is_on_jax:
      hmc_seed = _test_seed()
    else:
      hmc_seed, seed = util.split_seed(seed, 2)
    # Step size is adapted in log-space, hence the exp here.
    hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
        hmc_state,
        step_size=tf.exp(step_size_state.state),
        num_integrator_steps=num_leapfrog_steps,
        target_log_prob_fn=target_log_prob_fn,
        seed=hmc_seed)

    # Learning rate decays polynomially to 0 over the adaptation window, so
    # the step size freezes after num_adapt_steps.
    rate = fun_mcmc.prefab._polynomial_decay(  # pylint: disable=protected-access
        step=step,
        step_size=self._constant(0.01),
        power=0.5,
        decay_steps=num_adapt_steps,
        final_step_size=0.)
    # Clip log accept ratios at 0 so acceptance probabilities are <= 1.
    mean_p_accept = tf.reduce_mean(
        tf.exp(tf.minimum(self._constant(0.), hmc_extra.log_accept_ratio)))

    # Surrogate loss drives mean acceptance toward the 0.9 set point via Adam
    # on the log step size.
    loss_fn = fun_mcmc.make_surrogate_loss_fn(
        lambda _: (0.9 - mean_p_accept, ()))
    step_size_state, _ = fun_mcmc.adam_step(
        step_size_state, loss_fn, learning_rate=rate)

    return ((hmc_state, step_size_state, step + 1, seed),
            (hmc_state.state_extra[0], hmc_extra.log_accept_ratio))

  # Trace original-space states and accept ratios over adaptation + sampling.
  _, (chain, log_accept_ratio_trace) = fun_mcmc.trace(
      state=(fun_mcmc.hamiltonian_monte_carlo_init(state, target_log_prob_fn),
             fun_mcmc.adam_init(tf.math.log(step_size)), 0, seed),
      fn=kernel,
      num_steps=num_adapt_steps + num_steps,
  )
  true_samples = target_dist.sample(
      4096, seed=self._make_seed(_test_seed()))
  return chain, log_accept_ratio_trace, true_samples
def testPreconditionedHMC(self):
  """HMC on a bijector-transformed target recovers the target's moments.

  Runs HMC in the preconditioned (pre-bijector) space and checks that the
  post-warmup chain's mean/covariance match direct samples from the target.
  """
  step_size = 0.2
  num_steps = 2000
  num_leapfrog_steps = 10
  state = tf.ones([16, 2])

  base_mean = [1., 0]
  base_cov = [[1, 0.5], [0.5, 1]]

  bijector = tfb.Softplus()
  base_dist = tfd.MultivariateNormalFullCovariance(
      loc=base_mean, covariance_matrix=base_cov)
  target_dist = bijector(base_dist)

  def orig_target_log_prob_fn(x):
    return target_dist.log_prob(x), ()

  # Run the chain in the pre-bijector space; state_extra[0] carries the
  # original-space point.
  target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
      orig_target_log_prob_fn, bijector, state)

  # pylint: disable=g-long-lambda
  kernel = tf.function(lambda state: fun_mcmc.hamiltonian_monte_carlo(
      state,
      step_size=step_size,
      num_integrator_steps=num_leapfrog_steps,
      target_log_prob_fn=target_log_prob_fn,
      seed=_test_seed()))

  _, chain = fun_mcmc.trace(
      state=fun_mcmc.hamiltonian_monte_carlo_init(state, target_log_prob_fn),
      fn=kernel,
      num_steps=num_steps,
      trace_fn=lambda state, extra: state.state_extra[0])
  # Discard the warmup samples.
  chain = chain[1000:]

  sample_mean = tf.reduce_mean(chain, axis=[0, 1])
  sample_cov = tfp.stats.covariance(chain, sample_axis=[0, 1])

  true_samples = target_dist.sample(4096, seed=_test_seed())

  true_mean = tf.reduce_mean(true_samples, axis=0)
  # Fix: the reference covariance must come from the direct samples, not from
  # `chain` (which would compare the chain's covariance against itself).
  true_cov = tfp.stats.covariance(true_samples, sample_axis=0)

  self.assertAllClose(true_mean, sample_mean, rtol=0.1, atol=0.1)
  self.assertAllClose(true_cov, sample_cov, rtol=0.1, atol=0.1)
def computation(state):
  """Runs step-size-adapted HMC on a Softplus-transformed MVN target.

  Relies on closure variables from the enclosing test (`base_mean`,
  `base_cov`, `num_leapfrog_steps`, `num_adapt_steps`, `num_steps`,
  `step_size`) — presumably set up by the caller; verify there.

  Returns:
    chain: traced original-space states from every step.
    log_accept_ratio_trace: per-step HMC log accept ratios.
    true_samples: direct samples from the target for comparison.
  """
  bijector = tfb.Softplus()
  base_dist = tfd.MultivariateNormalFullCovariance(
      loc=base_mean, covariance_matrix=base_cov)
  target_dist = bijector(base_dist)

  def orig_target_log_prob_fn(x):
    return target_dist.log_prob(x), ()

  # Run the chain in the pre-bijector space; the transformed fn reports the
  # original-space point via state_extra[0].
  target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
      orig_target_log_prob_fn, bijector, state)

  def kernel(hmc_state, step_size_state, step):
    # Step size is adapted in log-space, hence the exp here.
    hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
        hmc_state,
        step_size=tf.exp(step_size_state.state),
        num_integrator_steps=num_leapfrog_steps,
        target_log_prob_fn=target_log_prob_fn)

    # Learning rate decays polynomially to 0 over the adaptation window, so
    # the step size freezes after num_adapt_steps.
    rate = tf.compat.v1.train.polynomial_decay(
        0.01,
        global_step=step,
        power=0.5,
        decay_steps=num_adapt_steps,
        end_learning_rate=0.)
    # Clip log accept ratios at 0 so acceptance probabilities are <= 1.
    mean_p_accept = tf.reduce_mean(
        tf.exp(tf.minimum(0., hmc_extra.log_accept_ratio)))

    # Surrogate loss drives mean acceptance toward the 0.9 set point via Adam
    # on the log step size.
    loss_fn = fun_mcmc.make_surrogate_loss_fn(
        lambda _: (0.9 - mean_p_accept, ()))
    step_size_state, _ = fun_mcmc.adam_step(
        step_size_state, loss_fn, learning_rate=rate)

    return ((hmc_state, step_size_state, step + 1),
            (hmc_state.state_extra[0], hmc_extra.log_accept_ratio))

  # Trace original-space states and accept ratios over adaptation + sampling.
  _, (chain, log_accept_ratio_trace) = fun_mcmc.trace(
      state=(fun_mcmc.hamiltonian_monte_carlo_init(state, target_log_prob_fn),
             fun_mcmc.adam_init(tf.math.log(step_size)), 0),
      fn=kernel,
      num_steps=num_adapt_steps + num_steps,
  )
  true_samples = target_dist.sample(4096, seed=_test_seed())
  return chain, log_accept_ratio_trace, true_samples
def testTransformLogProbFn(self):
  """Transforming a positional log-prob fn remaps state and adds Jacobians."""

  def untransformed_fn(x, y):
    return tfd.Normal(0., 1.).log_prob(x) + tfd.Normal(1., 1.).log_prob(y), ()

  scale_bijectors = [tfb.AffineScalar(scale=2.), tfb.AffineScalar(scale=3.)]

  transformed_log_prob_fn, transformed_init_state = (
      fun_mcmc.transform_log_prob_fn(untransformed_fn, scale_bijectors,
                                     [2., 3.]))

  # The initial state is pulled back through the bijectors: 2/2 and 3/3.
  self.assertIsInstance(transformed_init_state, list)
  self.assertAllClose([1., 1.], transformed_init_state)

  tlp, (orig_space, _) = transformed_log_prob_fn(1., 1.)

  # Expected value: the untransformed log-prob at the forward-mapped point,
  # plus every bijector's log-det-Jacobian at the transformed point.
  jacobian_correction = sum(
      b.forward_log_det_jacobian(1., event_ndims=0) for b in scale_bijectors)
  expected_tlp = untransformed_fn(2., 3.)[0] + jacobian_correction

  self.assertAllClose([2., 3.], orig_space)
  self.assertAllClose(expected_tlp, tlp)
def computation(state):
  """Runs sign-adaptation HMC on a Softplus-transformed MVN target.

  Relies on closure variables from the enclosing test (`base_mean`,
  `base_cov`, `num_leapfrog_steps`, `num_adapt_steps`, `num_steps`,
  `step_size`) — presumably set up by the caller; verify there.

  Returns:
    chain: traced original-space states from every step.
    log_accept_ratio_trace: per-step HMC log accept ratios.
    true_samples: direct samples from the target for comparison.
  """
  bijector = tfb.Softplus()
  base_dist = tfd.MultivariateNormalFullCovariance(
      loc=base_mean, covariance_matrix=base_cov)
  target_dist = bijector(base_dist)

  def orig_target_log_prob_fn(x):
    return target_dist.log_prob(x), ()

  # Run the chain in the pre-bijector space; the transformed fn reports the
  # original-space point via state_extra[0].
  target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
      orig_target_log_prob_fn, bijector, state)

  def kernel(hmc_state, step_size, step):
    hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
        hmc_state,
        step_size=step_size,
        num_integrator_steps=num_leapfrog_steps,
        target_log_prob_fn=target_log_prob_fn)

    # Adaptation rate decays polynomially to 0 over the adaptation window, so
    # the step size freezes after num_adapt_steps.
    rate = tf.compat.v1.train.polynomial_decay(
        0.01,
        global_step=step,
        power=0.5,
        decay_steps=num_adapt_steps,
        end_learning_rate=0.)
    # Clip log accept ratios at 0 so acceptance probabilities are <= 1.
    mean_p_accept = tf.reduce_mean(
        tf.exp(tf.minimum(0., hmc_extra.log_accept_ratio)))
    # Nudge the step size up/down based on the sign of the deviation from the
    # 0.9 acceptance set point.
    step_size = fun_mcmc.sign_adaptation(step_size,
                                         output=mean_p_accept,
                                         set_point=0.9,
                                         adaptation_rate=rate)

    return (hmc_state, step_size, step + 1), hmc_extra

  # Trace original-space states and accept ratios over adaptation + sampling.
  _, (chain, log_accept_ratio_trace) = fun_mcmc.trace(
      (fun_mcmc.HamiltonianMonteCarloState(state), step_size, 0),
      kernel,
      num_adapt_steps + num_steps,
      trace_fn=lambda state, extra: (state[0].state_extra[0],
                                     extra.log_accept_ratio))
  true_samples = target_dist.sample(4096, seed=tfp_test_util.test_seed())
  return chain, log_accept_ratio_trace, true_samples
def testTransformLogProbFnKwargs(self):
  """Transforming a log-prob fn with dict-valued state works via kwargs."""

  def untransformed_fn(x, y):
    lp_x = tfp.distributions.Normal(self._constant(0.), 1.).log_prob(x)
    lp_y = tfp.distributions.Normal(self._constant(1.), 1.).log_prob(y)
    return lp_x + lp_y, ()

  scale_bijectors = {
      'x': tfp.bijectors.AffineScalar(scale=self._constant(2.)),
      'y': tfp.bijectors.AffineScalar(scale=self._constant(3.))
  }

  transformed_log_prob_fn, transformed_init_state = (
      fun_mcmc.transform_log_prob_fn(untransformed_fn, scale_bijectors, {
          'x': self._constant(2.),
          'y': self._constant(3.),
      }))

  # Initial state is mapped through the inverse bijectors: 2/2 and 3/3.
  self.assertIsInstance(transformed_init_state, dict)
  self.assertAllClose({
      'x': self._constant(1.),
      'y': self._constant(1.),
  }, transformed_init_state)

  tlp, (orig_space, _) = transformed_log_prob_fn(
      x=self._constant(1.), y=self._constant(1.))

  # Expected value: log-prob at the forward-mapped point plus each bijector's
  # log-det-Jacobian evaluated at the transformed point.
  jacobian_correction = sum(
      b.forward_log_det_jacobian(self._constant(1.), event_ndims=0)
      for b in scale_bijectors.values())
  expected_tlp = untransformed_fn(
      x=self._constant(2.), y=self._constant(3.))[0] + jacobian_correction

  self.assertAllClose({
      'x': self._constant(2.),
      'y': self._constant(3.)
  }, orig_space)
  self.assertAllClose(expected_tlp, tlp)