def make_hmc_kernel_fn(target_log_prob_fn, init_state, scalings):
  """Builds a plain (untransformed) HamiltonianMonteCarlo kernel.

  The step size is derived from the spread of the current particles: each
  state part's standard deviation over axis 0 is handed, together with
  `scalings`, to `compute_hmc_step_size`.

  Args:
    target_log_prob_fn: Callable passed straight through to the HMC kernel.
    init_state: List of state-part `Tensor`s; axis 0 is the axis reduced over
      when measuring the spread.
    scalings: Scaling value(s) forwarded to `compute_hmc_step_size`.

  Returns:
    A `hmc.HamiltonianMonteCarlo` kernel. `num_leapfrog_steps` is not a
    parameter of this function; it is resolved from an enclosing scope.
  """
  with tf.name_scope('make_hmc_kernel_fn'):
    part_stddevs = []
    for part in init_state:
      # keepdims=True retains the reduced axis (presumably so the result
      # broadcasts against per-particle state inside the step-size helper).
      part_stddevs.append(tf.math.reduce_std(part, axis=0, keepdims=True))
    kernel_step_size = compute_hmc_step_size(
        scalings, part_stddevs, num_leapfrog_steps)
    return hmc.HamiltonianMonteCarlo(
        target_log_prob_fn=target_log_prob_fn,
        num_leapfrog_steps=num_leapfrog_steps,
        step_size=kernel_step_size)
def make_transform_hmc_kernel_fn(target_log_prob_fn, init_state, scalings):
  """Builds an HMC kernel wrapped in a TransformedTransitionKernel.

  Args:
    target_log_prob_fn: Callable handed to the inner HMC kernel.
    init_state: List of state-part `Tensor`s, paired element-wise with
      `unconstraining_bijectors`; axis 0 is reduced over when measuring the
      spread.
    scalings: Scaling value(s) forwarded to `compute_hmc_step_size`.

  Returns:
    A `transformed_kernel.TransformedTransitionKernel` around a
    `hmc.HamiltonianMonteCarlo` kernel. `num_leapfrog_steps` and
    `unconstraining_bijectors` are not parameters here; they are resolved
    from an enclosing scope.
  """
  with tf.name_scope('make_transformed_hmc_kernel_fn'):
    # TransformedTransitionKernel forwards `step_size` to the inner HMC kernel
    # untouched, so the spread has to be measured in the unconstrained space:
    # map every part through `bijector.inverse` before taking its std.
    part_stddevs = []
    for part, bijector in zip(init_state, unconstraining_bijectors):
      unconstrained_part = bijector.inverse(part)
      part_stddevs.append(
          tf.math.reduce_std(unconstrained_part, axis=0, keepdims=True))
    kernel_step_size = compute_hmc_step_size(
        scalings, part_stddevs, num_leapfrog_steps)
    inner = hmc.HamiltonianMonteCarlo(
        target_log_prob_fn=target_log_prob_fn,
        num_leapfrog_steps=num_leapfrog_steps,
        step_size=kernel_step_size)
    return transformed_kernel.TransformedTransitionKernel(
        inner, unconstraining_bijectors)
def make_transform_hmc_kernel_fn(target_log_prob_fn,
                                 init_state,
                                 scalings,
                                 seed=None):
  """Generate a transformed HMC kernel.

  Args:
    target_log_prob_fn: Callable handed to the inner HMC kernel.
    init_state: List of state-part `Tensor`s, paired element-wise with
      `unconstraining_bijectors`; axis 0 is reduced over when measuring the
      spread.
    scalings: Scaling value(s) forwarded to `compute_hmc_step_size`.
    seed: Optional seed. It is wrapped in a `SeedStream` salted with
      'make_transformed_hmc_kernel_fn' and passed to the HMC kernel.

  Returns:
    A `transformed_kernel.TransformedTransitionKernel` around a
    `hmc.HamiltonianMonteCarlo` kernel. `num_leapfrog_steps` and
    `unconstraining_bijectors` are not parameters here; they are resolved
    from an enclosing scope.
  """
  with tf.name_scope('make_transformed_hmc_kernel_fn'):
    seed = SeedStream(seed, salt='make_transformed_hmc_kernel_fn')
    # TransformedTransitionKernel doesn't modify the input step size, so the
    # step size must already be expressed in the unconstrained space that the
    # inner HMC kernel moves in. Map each state part into that space with
    # `bij.inverse` first, then measure its spread across particles.
    # (Previously this computed bij.inverse(reduce_std(bij.forward(x))), which
    # transforms the state the wrong way and then applies a bijector to a
    # standard deviation, e.g. log(std) for Exp, which can be negative.)
    state_std = [
        tf.math.reduce_std(bij.inverse(x), axis=0, keepdims=True)
        for x, bij in zip(init_state, unconstraining_bijectors)
    ]
    step_size = compute_hmc_step_size(scalings, state_std, num_leapfrog_steps)
    return transformed_kernel.TransformedTransitionKernel(
        hmc.HamiltonianMonteCarlo(
            target_log_prob_fn=target_log_prob_fn,
            num_leapfrog_steps=num_leapfrog_steps,
            step_size=step_size,
            seed=seed),
        unconstraining_bijectors)
def testSampleMarginals(self):
  # Checks that the off-diagonal marginals of LKJ samples follow a (rescaled)
  # Beta distribution. LKJ draws are obtained by running HMC on a CholeskyLKJ
  # target through the CorrelationCholesky bijector and then forming L @ L^T.
  dim = 4
  concentration = np.array(2.5, dtype=np.float64)
  marginal_conc = np.array(.5 * dim + concentration - 1, np.float64)
  marginal_beta = beta.Beta(
      concentration0=marginal_conc, concentration1=marginal_conc)

  chol_lkj_target = cholesky_lkj.CholeskyLKJ(
      dimension=dim, concentration=concentration)
  hmc_kernel = hmc.HamiltonianMonteCarlo(
      target_log_prob_fn=chol_lkj_target.log_prob,
      num_leapfrog_steps=3,
      step_size=0.3)
  kernel = transformed_kernel.TransformedTransitionKernel(
      inner_kernel=hmc_kernel, bijector=tfb.CorrelationCholesky())

  num_chains = 10
  num_total_samples = 30000

  # The DKWM check below must have enough samples to catch a wrong sampler
  # at a discrepancy of 0.04 with the requested error rates.
  required_samples = self.evaluate(
      st.min_num_samples_for_dkwm_cdf_test(
          discrepancy=0.04, false_fail_rate=1e-9, false_pass_rate=1e-9))
  self.assertLess(required_samples, num_total_samples)

  def _accepted(_, kernel_results):
    return kernel_results.inner_results.is_accepted

  @tf.function  # Compile the chain so MCMC sampling runs efficiently.
  def run_chains():
    return sample.sample_chain(
        num_results=num_total_samples // num_chains,
        num_burnin_steps=1000,
        current_state=tf.eye(dim, batch_shape=[num_chains], dtype=tf.float64),
        trace_fn=_accepted,
        kernel=kernel,
        seed=test_util.test_seed())

  chol_draws, accepted = self.evaluate(run_chains())

  # Every chain should accept more than 80% of its proposals.
  self.assertAllGreater(np.mean(accepted, axis=0), 0.8)

  # Rebuild correlation matrices from their Cholesky factors, then merge the
  # results and chains axes into one sample axis.
  corr_draws = tf.reshape(
      tf.matmul(chol_draws, chol_draws, adjoint_b=True),
      shape=[num_total_samples, dim, dim])

  # OutputToUnconstrained keeps only the entries strictly below the diagonal;
  # the affine map then rescales them from [-1, 1] to [0, 1].
  unit_interval_draws = .5 * (OutputToUnconstrained().forward(corr_draws) + 1)

  # Every off-diagonal marginal must match the Beta CDF.
  num_off_diagonal = dim * (dim - 1) // 2
  for idx in range(num_off_diagonal):
    self.evaluate(
        st.assert_true_cdf_equal_by_dkwm(
            unit_interval_draws[..., idx],
            cdf=marginal_beta.cdf,
            false_fail_rate=1e-9))