def _do_sampling(*, kind, proposal_kernel_kwargs, num_draws, initial_position,
                 trace_fn, bijector, return_final_kernel_results, seed):
  """Run `sample.sample_chain` on the kernel built by `_make_base_kernel`.

  Args:
    kind: Forwarded to `_make_base_kernel`.
    proposal_kernel_kwargs: Forwarded to `_make_base_kernel`.
    num_draws: Number of results requested from `sample.sample_chain`.
    initial_position: Starting state for the chain.
    trace_fn: Called as `trace_fn(state, bijector, tf.constant(False), pkr)`
      on every step.
    bijector: Passed through unchanged to `trace_fn`.
    return_final_kernel_results: Forwarded to `sample.sample_chain`.
    seed: Forwarded to `sample.sample_chain`.

  Returns:
    Whatever `sample.sample_chain` returns for these arguments.
  """
  base_kernel = _make_base_kernel(
      kind=kind, proposal_kernel_kwargs=proposal_kernel_kwargs)

  def _trace_base(state, pkr):
    # The third argument is always False for this sampler; the kernel
    # results are handed over without unwrapping.
    return trace_fn(state, bijector, tf.constant(False), pkr)

  return sample.sample_chain(
      num_draws,
      initial_position,
      kernel=base_kernel,
      trace_fn=_trace_base,
      return_final_kernel_results=return_final_kernel_results,
      seed=seed)
def _do_sampling(*, kind, proposal_kernel_kwargs, dual_averaging_kwargs,
                 num_draws, num_burnin_steps, initial_position,
                 initial_running_variance, trace_fn, bijector,
                 return_final_kernel_results, seed, chain_axis_names,
                 shard_axis_names):
  """Run `sample.sample_chain` on the windowed-adaptation kernel.

  Args:
    kind: Forwarded to `make_windowed_adapt_kernel`.
    proposal_kernel_kwargs: Forwarded to `make_windowed_adapt_kernel`.
    dual_averaging_kwargs: Forwarded to `make_windowed_adapt_kernel`. Its
      'num_adaptation_steps' entry is also read on every traced step.
    num_draws: Number of results requested from `sample.sample_chain`.
    num_burnin_steps: Forwarded to `sample.sample_chain`.
    initial_position: Starting state for the chain.
    initial_running_variance: Forwarded to `make_windowed_adapt_kernel`.
    trace_fn: Called as `trace_fn(state, bijector, flag, inner)`, where
      `flag` is `pkr.step <= dual_averaging_kwargs['num_adaptation_steps']`
      and `inner` is `pkr.inner_results.inner_results.inner_results`.
    bijector: Passed through unchanged to `trace_fn`.
    return_final_kernel_results: Forwarded to `sample.sample_chain`.
    seed: Forwarded to `sample.sample_chain`.
    chain_axis_names: Forwarded to `make_windowed_adapt_kernel`.
    shard_axis_names: Forwarded to `make_windowed_adapt_kernel`.

  Returns:
    Whatever `sample.sample_chain` returns for these arguments.
  """
  windowed_kernel = make_windowed_adapt_kernel(
      kind=kind,
      proposal_kernel_kwargs=proposal_kernel_kwargs,
      dual_averaging_kwargs=dual_averaging_kwargs,
      initial_running_variance=initial_running_variance,
      chain_axis_names=chain_axis_names,
      shard_axis_names=shard_axis_names)

  def _trace_windowed(state, pkr):
    # The adaptation horizon is looked up on each call, not captured up front.
    within_adaptation = (
        pkr.step <= dual_averaging_kwargs['num_adaptation_steps'])
    # Three wrapper levels sit between these results and the ones traced.
    innermost = pkr.inner_results.inner_results.inner_results
    return trace_fn(state, bijector, within_adaptation, innermost)

  return sample.sample_chain(
      num_draws,
      initial_position,
      kernel=windowed_kernel,
      num_burnin_steps=num_burnin_steps,
      trace_fn=_trace_windowed,
      return_final_kernel_results=return_final_kernel_results,
      seed=seed)
def _fast_window(*, kind, proposal_kernel_kwargs, dual_averaging_kwargs,
                 num_draws, initial_position, bijector, trace_fn, seed):
  """Sample using just step size adaptation.

  Runs `num_draws` steps of the kernel built by `make_fast_adapt_kernel`. The
  dual-averaging schedule is set to adapt for all `num_draws` steps. The last
  half of the draws is then summarized as a `RunningVariance`.

  Args:
    kind: Forwarded to `make_fast_adapt_kernel`.
    proposal_kernel_kwargs: Forwarded to `make_fast_adapt_kernel`.
    dual_averaging_kwargs: Mapping of dual-averaging options. A copy is made
      and its 'num_adaptation_steps' is set to `num_draws`. The caller's
      mapping is left unmodified.
    num_draws: Number of steps to run.
    initial_position: Starting state passed to `sample.sample_chain`.
    bijector: Passed through unchanged to `trace_fn`.
    trace_fn: Called as `trace_fn(state, bijector, tf.constant(True),
      pkr.inner_results)` on every step.
    seed: Forwarded to `sample.sample_chain`.

  Returns:
    Tuple `(draws, trace, step_size, weighted_running_variance)`:
    - `draws` and `trace` come from `sample.sample_chain`.
    - `step_size` is the outermost 'step_size' found in the final kernel
      results.
    - `weighted_running_variance` is a `sample_stats.RunningVariance` built
      from the moments of the last half of `draws`.
  """
  # Copy before overriding. Updating the caller's dict in place would leak
  # this window's 'num_adaptation_steps' into every later use of that dict.
  # For example, a later `setdefault('num_adaptation_steps', ...)` on it
  # would silently become a no-op.
  dual_averaging_kwargs = dict(dual_averaging_kwargs)
  dual_averaging_kwargs['num_adaptation_steps'] = num_draws
  fast_kernel = make_fast_adapt_kernel(
      kind=kind,
      proposal_kernel_kwargs=proposal_kernel_kwargs,
      dual_averaging_kwargs=dual_averaging_kwargs)

  def _trace_fast(state, pkr):
    return trace_fn(state, bijector, tf.constant(True), pkr.inner_results)

  # Warnings emitted while sampling this window are deliberately silenced.
  with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    draws, trace, fkr = sample.sample_chain(
        num_draws,
        initial_position,
        kernel=fast_kernel,
        return_final_kernel_results=True,
        trace_fn=_trace_fast,
        seed=seed)

  # Compute moments over axes 0 and 1 (presumably draws and chains). Because
  # `-num_draws // 2` floors, the slice keeps ceil(num_draws / 2) draws.
  draw_and_chain_axes = [0, 1]
  prev_mean, prev_var = tf.nn.moments(
      draws[-num_draws // 2:], axes=draw_and_chain_axes)
  # NOTE(review): the pseudo-count is num_draws / 2 even when the slice above
  # keeps ceil(num_draws / 2) draws, and the chain count is not included.
  # Confirm this weighting is intended.
  num_samples = tf.cast(
      num_draws / 2,
      dtype=dtype_util.common_dtype([prev_mean, prev_var], tf.float32))
  weighted_running_variance = sample_stats.RunningVariance.from_stats(
      num_samples=num_samples, mean=prev_mean, variance=prev_var)

  step_size = unnest.get_outermost(fkr, 'step_size')
  return draws, trace, step_size, weighted_running_variance
def sample_mcmc_chain():
  """Run `sample.sample_chain` with the enclosing `kernel`, tracing acceptance.

  Reads `num_total_samples`, `num_chains`, `dim`, `kernel` and `test_util`
  from the enclosing scope.

  Returns:
    Whatever `sample.sample_chain` returns. The trace is
    `pkr.inner_results.is_accepted` for each step.
  """
  # Integer share of the total sample budget for each chain.
  results_per_chain = num_total_samples // num_chains
  # A batch of `num_chains` float64 identity matrices as the starting state.
  start_state = tf.eye(dim, batch_shape=[num_chains], dtype=tf.float64)

  def _accepted(_, pkr):
    return pkr.inner_results.is_accepted

  return sample.sample_chain(
      num_results=results_per_chain,
      num_burnin_steps=1000,
      current_state=start_state,
      trace_fn=_accepted,
      kernel=kernel,
      seed=test_util.test_seed())
def _slow_window(*, kind, proposal_kernel_kwargs, dual_averaging_kwargs,
                 num_draws, initial_position, initial_running_variance,
                 bijector, trace_fn, seed):
  """Run one slow window: step-size plus mass-matrix adaptation.

  Args:
    kind: Forwarded to `make_slow_adapt_kernel`.
    proposal_kernel_kwargs: Forwarded to `make_slow_adapt_kernel`.
    dual_averaging_kwargs: Mapping of dual-averaging options. It is copied,
      and 'num_adaptation_steps' defaults to `num_draws` only if absent.
    num_draws: Number of steps to run.
    initial_position: Starting state passed to `sample.sample_chain`.
    initial_running_variance: Forwarded to `make_slow_adapt_kernel`.
    bijector: Passed through unchanged to `trace_fn`.
    trace_fn: Called as `trace_fn(state, bijector, tf.constant(True),
      pkr.inner_results.inner_results)` on every step.
    seed: Forwarded to `sample.sample_chain`.

  Returns:
    Tuple `(draws, trace, step_size, weighted_running_variance,
    momentum_distribution)`. `weighted_running_variance` is a list with one
    `sample_stats.RunningVariance` per element of `draws`. `step_size` and
    `momentum_distribution` are the outermost matching fields of the final
    kernel results.
  """
  # Private copy: the caller's mapping is never modified. An explicit
  # 'num_adaptation_steps' from the caller takes precedence over num_draws.
  da_kwargs = dict(dual_averaging_kwargs)
  da_kwargs.setdefault('num_adaptation_steps', num_draws)
  slow_kernel = make_slow_adapt_kernel(
      kind=kind,
      proposal_kernel_kwargs=proposal_kernel_kwargs,
      dual_averaging_kwargs=da_kwargs,
      initial_running_variance=initial_running_variance)

  def _trace_slow(state, pkr):
    return trace_fn(state, bijector, tf.constant(True),
                    pkr.inner_results.inner_results)

  # Warnings emitted while sampling this window are deliberately silenced.
  with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    draws, trace, fkr = sample.sample_chain(
        num_draws,
        initial_position,
        kernel=slow_kernel,
        return_final_kernel_results=True,
        trace_fn=_trace_slow,
        seed=seed)

  # Axes 0 and 1 are presumably the draw and chain axes.
  axes = [0, 1]

  def _running_variance_for(part):
    # `-num_draws // 2` floors, so the slice keeps ceil(num_draws / 2) draws,
    # while the pseudo-count below is num_draws / 2.
    part_mean, part_var = tf.nn.moments(part[-num_draws // 2:], axes=axes)
    pseudo_count = tf.cast(
        num_draws / 2,
        dtype=dtype_util.common_dtype([part_mean, part_var], tf.float32))
    return sample_stats.RunningVariance.from_stats(
        num_samples=pseudo_count, mean=part_mean, variance=part_var)

  weighted_running_variance = [
      _running_variance_for(part) for part in list(draws)
  ]
  step_size = unnest.get_outermost(fkr, 'step_size')
  momentum_distribution = unnest.get_outermost(fkr, 'momentum_distribution')
  return (draws, trace, step_size, weighted_running_variance,
          momentum_distribution)
def _do_sampling(*, target_log_prob_fn, num_leapfrog_steps, num_draws,
                 initial_position, step_size, momentum_distribution, trace_fn,
                 bijector, return_final_kernel_results, seed):
  """Run `sample.sample_chain` on the kernel built by `_make_base_kernel`.

  Args:
    target_log_prob_fn: Forwarded to `_make_base_kernel`.
    num_leapfrog_steps: Forwarded to `_make_base_kernel`.
    num_draws: Number of results requested from `sample.sample_chain`.
    initial_position: Starting state for the chain.
    step_size: Forwarded to `_make_base_kernel`.
    momentum_distribution: Forwarded to `_make_base_kernel`.
    trace_fn: Called as `trace_fn(state, bijector, tf.constant(False), pkr)`
      on every step.
    bijector: Passed through unchanged to `trace_fn`.
    return_final_kernel_results: Forwarded to `sample.sample_chain`.
    seed: Forwarded to `sample.sample_chain`.

  Returns:
    Whatever `sample.sample_chain` returns for these arguments.
  """
  hmc_kernel = _make_base_kernel(
      target_log_prob_fn=target_log_prob_fn,
      step_size=step_size,
      num_leapfrog_steps=num_leapfrog_steps,
      momentum_distribution=momentum_distribution)

  def _trace_base(state, pkr):
    # The third argument is always False for this sampler.
    return trace_fn(state, bijector, tf.constant(False), pkr)

  return sample.sample_chain(
      num_draws,
      initial_position,
      kernel=hmc_kernel,
      trace_fn=_trace_base,
      return_final_kernel_results=return_final_kernel_results,
      seed=seed)
def _slow_window(*, target_log_prob_fn, num_leapfrog_steps, num_draws,
                 initial_position, initial_running_variance, initial_step_size,
                 target_accept_prob, bijector, trace_fn, seed):
  """Run one slow window: step-size plus mass-matrix adaptation.

  Args:
    target_log_prob_fn: Forwarded to `make_slow_adapt_kernel`.
    num_leapfrog_steps: Forwarded to `make_slow_adapt_kernel`.
    num_draws: Number of steps to run. Also used as `num_adaptation_steps`.
    initial_position: Starting state passed to `sample.sample_chain`.
    initial_running_variance: Forwarded to `make_slow_adapt_kernel`.
    initial_step_size: Forwarded to `make_slow_adapt_kernel`.
    target_accept_prob: Forwarded to `make_slow_adapt_kernel`.
    bijector: Passed through unchanged to `trace_fn`.
    trace_fn: Called as `trace_fn(state, bijector, tf.constant(True),
      pkr.inner_results.inner_results)` on every step.
    seed: Forwarded to `sample.sample_chain`.

  Returns:
    Tuple `(draws, trace, step_size, weighted_running_variance,
    momentum_distribution)`. `weighted_running_variance` is a single
    `sample_stats.RunningVariance` built from the last half of `draws`.
  """
  slow_kernel = make_slow_adapt_kernel(
      target_log_prob_fn=target_log_prob_fn,
      initial_running_variance=initial_running_variance,
      initial_step_size=initial_step_size,
      num_leapfrog_steps=num_leapfrog_steps,
      num_adaptation_steps=num_draws,
      target_accept_prob=target_accept_prob)

  def _trace_slow(state, pkr):
    return trace_fn(state, bijector, tf.constant(True),
                    pkr.inner_results.inner_results)

  # Warnings emitted while sampling this window are deliberately silenced.
  with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    draws, trace, fkr = sample.sample_chain(
        num_draws,
        initial_position,
        kernel=slow_kernel,
        return_final_kernel_results=True,
        trace_fn=_trace_slow,
        seed=seed)

  # `-num_draws // 2` floors, so this keeps ceil(num_draws / 2) draws. The
  # moments are taken over axes 0 and 1 (presumably draws and chains).
  tail = draws[-num_draws // 2:]
  tail_mean, tail_var = tf.nn.moments(tail, axes=[0, 1])
  pseudo_count = tf.cast(
      num_draws / 2,
      dtype=dtype_util.common_dtype([tail_mean, tail_var], tf.float32))
  weighted_running_variance = sample_stats.RunningVariance.from_stats(
      num_samples=pseudo_count, mean=tail_mean, variance=tail_var)

  step_size = unnest.get_outermost(fkr, 'step_size')
  momentum_distribution = unnest.get_outermost(fkr, 'momentum_distribution')
  return (draws, trace, step_size, weighted_running_variance,
          momentum_distribution)