def make_hmc_kernel_fn(target_log_prob_fn, init_state, scalings):
  """Generate an HMC kernel without transformation."""

  with tf.name_scope('make_hmc_kernel_fn'):
    # Use the per-dimension standard deviation of the initial state across
    # chains to set the step size scale.
    state_std = [
        tf.math.reduce_std(x, axis=0, keepdims=True)
        for x in init_state
    ]
    # `num_leapfrog_steps` is a free variable here; it is expected to be bound
    # by an enclosing scope (see the sketch after
    # `make_transform_hmc_kernel_fn` below).
    step_size = compute_hmc_step_size(scalings, state_std, num_leapfrog_steps)
    return hmc.HamiltonianMonteCarlo(
        target_log_prob_fn=target_log_prob_fn,
        num_leapfrog_steps=num_leapfrog_steps,
        step_size=step_size)
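

# `compute_hmc_step_size` is not included in this excerpt. The sketch below is
# an assumption about its behavior: per state part, scale the per-dimension
# standard deviation by the per-chain `scalings` and divide by the number of
# leapfrog steps, so that the total trajectory length tracks the posterior
# scale. Illustrative only; it assumes `scalings` broadcasts against each
# `std`.
def compute_hmc_step_size(scalings, state_std, num_leapfrog_steps):
  return [
      tf.cast(scalings, std.dtype) * std /
      tf.cast(num_leapfrog_steps, std.dtype)
      for std in state_std
  ]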
def make_transform_hmc_kernel_fn(target_log_prob_fn,
                                 init_state,
                                 scalings,
                                 seed=None):
  """Generate a transformed HMC kernel."""

  with tf.name_scope('make_transformed_hmc_kernel_fn'):
    seed = SeedStream(seed, salt='make_transformed_hmc_kernel_fn')
    # TransformedTransitionKernel doesn't modify the input step size, so we
    # need to pass step sizes that are already in unconstrained space.
    state_std = [
        bij.inverse(  # pylint: disable=g-complex-comprehension
            tf.math.reduce_std(bij.forward(x), axis=0, keepdims=True))
        for x, bij in zip(init_state, unconstraining_bijectors)
    ]
    step_size = compute_hmc_step_size(scalings, state_std,
                                      num_leapfrog_steps)
    return transformed_kernel.TransformedTransitionKernel(
        hmc.HamiltonianMonteCarlo(
            target_log_prob_fn=target_log_prob_fn,
            num_leapfrog_steps=num_leapfrog_steps,
            step_size=step_size,
            seed=seed), unconstraining_bijectors)
  def testSampleMarginals(self):
    # Verify that the marginals of the LKJ distribution are distributed
    # according to a (scaled) Beta distribution. The LKJ-distributed samples
    # are obtained by sampling a CholeskyLKJ distribution using HMC and the
    # CorrelationCholesky bijector.
    dim = 4
    concentration = np.array(2.5, dtype=np.float64)
    beta_concentration = np.array(.5 * dim + concentration - 1, np.float64)
    beta_dist = beta.Beta(
        concentration0=beta_concentration, concentration1=beta_concentration)
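
    # This parameterization follows the known LKJ marginal result
    # (Lewandowski, Kurowicka & Joe, 2009): for a dim x dim LKJ matrix with
    # concentration eta, each off-diagonal entry r satisfies
    # (r + 1) / 2 ~ Beta(alpha, alpha) with alpha = eta - 1 + dim / 2.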

    inner_kernel = hmc.HamiltonianMonteCarlo(
        target_log_prob_fn=cholesky_lkj.CholeskyLKJ(
            dimension=dim, concentration=concentration).log_prob,
        num_leapfrog_steps=3,
        step_size=0.3)

    kernel = transformed_kernel.TransformedTransitionKernel(
        inner_kernel=inner_kernel, bijector=tfb.CorrelationCholesky())
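
    # `tfb.CorrelationCholesky` maps unconstrained vectors in
    # R^(dim * (dim - 1) / 2) to Cholesky factors of correlation matrices, so
    # the transformed kernel runs HMC in unconstrained space while the target
    # density is evaluated on the constrained (Cholesky factor) space.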

    num_chains = 10
    num_total_samples = 30000

    # Make sure that we have enough samples to catch an incorrect sampler at
    # a small enough discrepancy.
    self.assertLess(
        self.evaluate(
            st.min_num_samples_for_dkwm_cdf_test(
                discrepancy=0.04, false_fail_rate=1e-9, false_pass_rate=1e-9)),
        num_total_samples)
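
    # For intuition: the DKWM inequality bounds
    # P(sup_x |F_n(x) - F(x)| > eps) <= 2 * exp(-2 * n * eps**2), so a
    # false-fail rate delta at discrepancy eps needs roughly
    # n >= log(2 / delta) / (2 * eps**2), about 6.7e3 samples here. The TFP
    # helper's exact bound also accounts for the false_pass_rate and is
    # larger, but still well below the 30000 samples drawn.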

    @tf.function  # Ensure that MCMC sampling is done efficiently.
    def sample_mcmc_chain():
      return sample.sample_chain(
          num_results=num_total_samples // num_chains,
          num_burnin_steps=1000,
          current_state=tf.eye(dim, batch_shape=[num_chains], dtype=tf.float64),
          trace_fn=lambda _, pkr: pkr.inner_results.is_accepted,
          kernel=kernel,
          seed=test_util.test_seed())

    # Draw samples from the HMC chains.
    chol_lkj_samples, is_accepted = self.evaluate(sample_mcmc_chain())

    # Ensure that the per-chain acceptance rate is high enough.
    self.assertAllGreater(np.mean(is_accepted, axis=0), 0.8)

    # Transform from Cholesky LKJ samples to LKJ samples.
    lkj_samples = tf.matmul(chol_lkj_samples, chol_lkj_samples, adjoint_b=True)
    lkj_samples = tf.reshape(lkj_samples, shape=[num_total_samples, dim, dim])

    # Only look at the entries strictly below the diagonal, which is achieved
    # by the OutputToUnconstrained bijector. Also scale the marginals from the
    # range [-1, 1] to [0, 1].
    scaled_lkj_samples = .5 * (OutputToUnconstrained().forward(lkj_samples) + 1)

    # Each of the off-diagonal marginals should be distributed according to a
    # Beta distribution.
    for i in range(dim * (dim - 1) // 2):
      self.evaluate(
          st.assert_true_cdf_equal_by_dkwm(
              scaled_lkj_samples[..., i],
              cdf=beta_dist.cdf,
              false_fail_rate=1e-9))