def _do_sampling(*,
                 kind,
                 proposal_kernel_kwargs,
                 num_draws,
                 initial_position,
                 trace_fn,
                 bijector,
                 return_final_kernel_results,
                 seed):
  """Draw `num_draws` samples from the base (non-adaptive) HMC kernel.

  The user-supplied `trace_fn` is invoked with a fixed `False` flag, since
  the base kernel never performs step-size/mass-matrix adaptation.
  """
  base_kernel = _make_base_kernel(
      kind=kind, proposal_kernel_kwargs=proposal_kernel_kwargs)

  def _traced(state, pkr):
    # Third argument signals "currently adapting": always False here.
    return trace_fn(state, bijector, tf.constant(False), pkr)

  return sample.sample_chain(
      num_draws,
      initial_position,
      kernel=base_kernel,
      trace_fn=_traced,
      return_final_kernel_results=return_final_kernel_results,
      seed=seed)
def _do_sampling(*, kind, proposal_kernel_kwargs, dual_averaging_kwargs,
                 num_draws, num_burnin_steps, initial_position,
                 initial_running_variance, trace_fn, bijector,
                 return_final_kernel_results, seed, chain_axis_names,
                 shard_axis_names):
    """Sample from the windowed-adaptation HMC kernel."""
    adapt_kernel = make_windowed_adapt_kernel(
        kind=kind,
        proposal_kernel_kwargs=proposal_kernel_kwargs,
        dual_averaging_kwargs=dual_averaging_kwargs,
        initial_running_variance=initial_running_variance,
        chain_axis_names=chain_axis_names,
        shard_axis_names=shard_axis_names)

    def _traced(state, pkr):
        # We are "adapting" while the step counter is within the window.
        in_adaptation = (
            pkr.step <= dual_averaging_kwargs['num_adaptation_steps'])
        # Unwrap three layers of kernel-results nesting before tracing.
        inner = pkr.inner_results.inner_results.inner_results
        return trace_fn(state, bijector, in_adaptation, inner)

    return sample.sample_chain(
        num_draws,
        initial_position,
        kernel=adapt_kernel,
        num_burnin_steps=num_burnin_steps,
        trace_fn=_traced,
        return_final_kernel_results=return_final_kernel_results,
        seed=seed)
# Example #3 (score: 0)
def _fast_window(*, kind, proposal_kernel_kwargs, dual_averaging_kwargs,
                 num_draws, initial_position, bijector, trace_fn, seed):
    """Sample using just step size adaptation.

    Runs `num_draws` steps with dual-averaging step-size adaptation only,
    then estimates a running variance from the second half of the draws to
    seed subsequent (slow-window) mass-matrix adaptation.

    Returns:
      Tuple of `(draws, trace, step_size, weighted_running_variance)`.
    """
    # Copy before mutating so the caller's dict is not clobbered
    # (the sibling `_slow_window` already makes this copy).
    dual_averaging_kwargs = dict(dual_averaging_kwargs)
    dual_averaging_kwargs.update({'num_adaptation_steps': num_draws})
    kernel = make_fast_adapt_kernel(
        kind=kind,
        proposal_kernel_kwargs=proposal_kernel_kwargs,
        dual_averaging_kwargs=dual_averaging_kwargs)
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        draws, trace, fkr = sample.sample_chain(
            num_draws,
            initial_position,
            kernel=kernel,
            return_final_kernel_results=True,
            # pylint: disable=g-long-lambda
            trace_fn=lambda state, pkr: trace_fn(
                state, bijector, tf.constant(True), pkr.inner_results),
            seed=seed)
        # pylint: enable=g-long-lambda

    # Moments over the draw and chain axes of the second half of the draws.
    draw_and_chain_axes = [0, 1]
    prev_mean, prev_var = tf.nn.moments(draws[-num_draws // 2:],
                                        axes=draw_and_chain_axes)

    num_samples = tf.cast(num_draws / 2,
                          dtype=dtype_util.common_dtype([prev_mean, prev_var],
                                                        tf.float32))
    weighted_running_variance = sample_stats.RunningVariance.from_stats(
        num_samples=num_samples, mean=prev_mean, variance=prev_var)

    step_size = unnest.get_outermost(fkr, 'step_size')
    return draws, trace, step_size, weighted_running_variance
 def sample_mcmc_chain():
   """Run the chain and trace per-step acceptance.

   NOTE(review): closes over `num_total_samples`, `num_chains`, `dim`,
   `kernel`, and `test_util` from the enclosing scope -- confirm against
   the surrounding test before reuse.
   """
   return sample.sample_chain(
       # Split the total sample budget evenly across chains.
       num_results=num_total_samples // num_chains,
       num_burnin_steps=1000,
       # One identity matrix per chain as the initial state.
       current_state=tf.eye(dim, batch_shape=[num_chains], dtype=tf.float64),
       # Trace only whether each proposal was accepted.
       trace_fn=lambda _, pkr: pkr.inner_results.is_accepted,
       kernel=kernel,
       seed=test_util.test_seed())
# Example #5 (score: 0)
def _slow_window(*,
                 kind,
                 proposal_kernel_kwargs,
                 dual_averaging_kwargs,
                 num_draws,
                 initial_position,
                 initial_running_variance,
                 bijector,
                 trace_fn,
                 seed):
  """Sample using both step size and mass matrix adaptation."""
  # Work on a copy so the caller's kwargs are left untouched.
  dual_averaging_kwargs = dict(dual_averaging_kwargs)
  dual_averaging_kwargs.setdefault('num_adaptation_steps', num_draws)
  adapt_kernel = make_slow_adapt_kernel(
      kind=kind,
      proposal_kernel_kwargs=proposal_kernel_kwargs,
      dual_averaging_kwargs=dual_averaging_kwargs,
      initial_running_variance=initial_running_variance)

  def _traced(state, pkr):
    # Every draw within a slow window happens during adaptation.
    return trace_fn(state, bijector, tf.constant(True),
                    pkr.inner_results.inner_results)

  with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    draws, trace, fkr = sample.sample_chain(
        num_draws,
        initial_position,
        kernel=adapt_kernel,
        return_final_kernel_results=True,
        trace_fn=_traced,
        seed=seed)

  def _running_variance(part):
    # Moments of the second half of the draws, over draw and chain axes,
    # seed the running variance for the next window.
    mean, variance = tf.nn.moments(part[-num_draws // 2:], axes=[0, 1])
    count = tf.cast(
        num_draws / 2,
        dtype=dtype_util.common_dtype([mean, variance], tf.float32))
    return sample_stats.RunningVariance.from_stats(
        num_samples=count, mean=mean, variance=variance)

  weighted_running_variance = [_running_variance(d) for d in list(draws)]

  step_size = unnest.get_outermost(fkr, 'step_size')
  momentum_distribution = unnest.get_outermost(fkr, 'momentum_distribution')

  return draws, trace, step_size, weighted_running_variance, momentum_distribution
# Example #6 (score: 0)
def _do_sampling(*, target_log_prob_fn, num_leapfrog_steps, num_draws,
                 initial_position, step_size, momentum_distribution, trace_fn,
                 bijector, return_final_kernel_results, seed):
    """Sample from base HMC kernel."""
    base_kernel = _make_base_kernel(
        target_log_prob_fn=target_log_prob_fn,
        step_size=step_size,
        num_leapfrog_steps=num_leapfrog_steps,
        momentum_distribution=momentum_distribution)

    def _traced(state, pkr):
        # The base kernel never adapts, so the flag is fixed at False.
        return trace_fn(state, bijector, tf.constant(False), pkr)

    return sample.sample_chain(
        num_draws,
        initial_position,
        kernel=base_kernel,
        trace_fn=_traced,
        return_final_kernel_results=return_final_kernel_results,
        seed=seed)
# Example #7 (score: 0)
def _slow_window(*, target_log_prob_fn, num_leapfrog_steps, num_draws,
                 initial_position, initial_running_variance, initial_step_size,
                 target_accept_prob, bijector, trace_fn, seed):
    """Sample using both step size and mass matrix adaptation."""
    adapt_kernel = make_slow_adapt_kernel(
        target_log_prob_fn=target_log_prob_fn,
        initial_running_variance=initial_running_variance,
        initial_step_size=initial_step_size,
        num_leapfrog_steps=num_leapfrog_steps,
        num_adaptation_steps=num_draws,
        target_accept_prob=target_accept_prob)

    def _traced(state, pkr):
        # Every draw in a slow window happens during adaptation.
        return trace_fn(state, bijector, tf.constant(True),
                        pkr.inner_results.inner_results)

    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        draws, trace, fkr = sample.sample_chain(
            num_draws,
            initial_position,
            kernel=adapt_kernel,
            return_final_kernel_results=True,
            trace_fn=_traced,
            seed=seed)

    # Moments of the second half of the draws, over the draw and chain axes,
    # seed the running variance for the next adaptation window.
    second_half = draws[-num_draws // 2:]
    prev_mean, prev_var = tf.nn.moments(second_half, axes=[0, 1])
    common_dtype = dtype_util.common_dtype([prev_mean, prev_var], tf.float32)
    weighted_running_variance = sample_stats.RunningVariance.from_stats(
        num_samples=tf.cast(num_draws / 2, dtype=common_dtype),
        mean=prev_mean,
        variance=prev_var)

    step_size = unnest.get_outermost(fkr, 'step_size')
    momentum_distribution = unnest.get_outermost(fkr, 'momentum_distribution')

    return draws, trace, step_size, weighted_running_variance, momentum_distribution