def sample_chain(
    num_results,
    current_state,
    previous_kernel_results=None,
    kernel=None,
    num_burnin_steps=0,
    num_steps_between_results=0,
    trace_fn=_trace_kernel_results,
    return_final_kernel_results=False,
    parallel_iterations=10,
    seed=None,
    name=None,
):
  """Implements Markov chain Monte Carlo via repeated `TransitionKernel` steps.

  This function samples from a Markov chain at `current_state` whose
  stationary distribution is governed by the supplied `TransitionKernel`
  instance (`kernel`).

  This function can sample from multiple chains, in parallel. (Whether or not
  there are multiple chains is dictated by the `kernel`.)

  The `current_state` can be represented as a single `Tensor` or a `list` of
  `Tensors` which collectively represent the current state.

  Since MCMC states are correlated, it is sometimes desirable to produce
  additional intermediate states, and then discard them, ending up with a set
  of states with decreased autocorrelation. See [Owen (2017)][1]. Such
  'thinning' is made possible by setting `num_steps_between_results > 0`. The
  chain then takes `num_steps_between_results` extra steps between the steps
  that make it into the results. The extra steps are never materialized, and
  thus do not increase memory requirements.

  In addition to returning the chain state, this function supports tracing of
  auxiliary variables used by the kernel. The traced values are selected by
  specifying `trace_fn`. By default, all kernel results are traced but in the
  future the default will be changed to no results being traced, so plan
  accordingly. Passing `trace_fn=None` disables tracing.

  Args:
    num_results: Integer number of Markov chain draws.
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s).
    previous_kernel_results: A `Tensor` or a nested collection of `Tensor`s
      representing internal calculations made within the previous call to this
      function (or as returned by `bootstrap_results`).
    kernel: An instance of `tfp.mcmc.TransitionKernel` which implements one
      step of the Markov chain. Required despite the `None` default.
    num_burnin_steps: Integer number of chain steps to take before starting to
      collect results.
      Default value: 0 (i.e., no burn-in).
    num_steps_between_results: Integer number of chain steps between collecting
      a result. Only one out of every `num_steps_between_results + 1` steps is
      included in the returned results. The number of returned chain states is
      still equal to `num_results`.
      Default value: 0 (i.e., no thinning).
    trace_fn: A callable that takes in the current chain state and the previous
      kernel results and returns a `Tensor` or a nested collection of `Tensor`s
      that is then traced along with the chain state. If `None`, nothing is
      traced.
    return_final_kernel_results: If `True`, then the final kernel results are
      returned alongside the chain state and the trace specified by the
      `trace_fn`.
    parallel_iterations: The number of iterations allowed to run in parallel. It
      must be a positive integer. See `tf.while_loop` for more details.
    seed: Optional, a seed for reproducible sampling.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'experimental_mcmc_sample_chain').

  Returns:
    checkpointable_states_and_trace: if `return_final_kernel_results` is
      `True`. The return value is an instance of
      `CheckpointableStatesAndTrace`.
    all_states: if `return_final_kernel_results` is `False` and `trace_fn` is
      `None`. The return value is a `Tensor` or Python list of `Tensor`s
      representing the state(s) of the Markov chain(s) at each result step.
      Has same shape as input `current_state` but with a prepended
      `num_results`-size dimension.
    states_and_trace: if `return_final_kernel_results` is `False` and
      `trace_fn` is not `None`. The return value is an instance of
      `StatesAndTrace`.

  Raises:
    ValueError: if `kernel` is `None`.

  #### References

  [1]: Art B. Owen. Statistically efficient thinning of a Markov chain sampler.
       _Technical Report_, 2017.
       http://statweb.stanford.edu/~owen/reports/bestthinning.pdf
  """
  # `kernel` has a `None` default only for signature compatibility; fail with
  # a clear message instead of an AttributeError on `kernel.is_calibrated`.
  if kernel is None:
    raise ValueError(
        '`kernel` must be a `TransitionKernel` instance; got `None`.')

  with tf.name_scope(name or 'experimental_mcmc_sample_chain'):
    if not kernel.is_calibrated:
      warnings.warn(
          'supplied `TransitionKernel` is not calibrated. Markov '
          'chain may not converge to intended target distribution.')

    if trace_fn is None:
      trace_fn = lambda *args: ()
      no_trace = True
    else:
      no_trace = False

    # Compare against the default callable itself rather than indexing
    # `sample_chain.__defaults__`, which silently breaks if a defaulted
    # parameter is ever added before `trace_fn`.
    if trace_fn is _trace_kernel_results:
      warnings.warn(
          'Tracing all kernel results by default is deprecated. Set '
          'the `trace_fn` argument to None (the future default '
          'value) or an explicit callback that traces the values '
          'you are interested in.')

    # `WithReductions` assumes all its reducers want to reduce over the
    # immediate inner results of its kernel results. However, we don't care
    # about the kernel results of `SampleDiscardingKernel`; hence, we evaluate
    # the `trace_fn` on a deeper level of inner results.
    def real_trace_fn(curr_state, kr):
      return curr_state, trace_fn(curr_state, kr.inner_results)

    trace_reducer = tracing_reducer.TracingReducer(
        trace_fn=real_trace_fn,
        size=num_results)

    # pylint: disable=unbalanced-tuple-unpacking
    trace_results, _, final_kernel_results = sample_fold(
        num_steps=num_results,
        current_state=current_state,
        previous_kernel_results=previous_kernel_results,
        kernel=kernel,
        reducer=trace_reducer,
        num_burnin_steps=num_burnin_steps,
        num_steps_between_results=num_steps_between_results,
        parallel_iterations=parallel_iterations,
        seed=seed,
        name=name,
    )

    # `real_trace_fn` traced `(state, user_trace)` pairs.
    all_states, trace = trace_results
    if return_final_kernel_results:
      return sample.CheckpointableStatesAndTrace(
          all_states=all_states,
          trace=trace,
          final_kernel_results=final_kernel_results)
    if no_trace:
      return all_states
    return sample.StatesAndTrace(all_states=all_states, trace=trace)
def _windowed_adaptive_impl(n_draws,
                            joint_dist,
                            *,
                            kind,
                            n_chains,
                            proposal_kernel_kwargs,
                            num_adaptation_steps,
                            current_state,
                            dual_averaging_kwargs,
                            trace_fn,
                            return_final_kernel_results,
                            discard_tuning,
                            seed,
                            chain_axis_names,
                            **pins):
  """Runs windowed sampling using either HMC or NUTS as internal sampler.

  Args:
    n_draws: Number of draws to return after adaptation. When `discard_tuning`
      is `False`, `num_adaptation_steps` adaptation draws are returned too.
    joint_dist: Joint distribution to sample from (passed to `_setup_mcmc`).
    kind: Forwarded to `_do_sampling` to select the proposal kernel.
    n_chains: Number of chains; a Python `int` is wrapped into a one-element
      list and used as the leading batch shape of the momentum distribution.
    proposal_kernel_kwargs: `dict` of keyword arguments for the proposal
      kernel. If `'step_size'` is absent or `None`, an initial step size is
      computed (only allowed for a scalar batch shape); otherwise it is
      broadcast by the helper returned from `_setup_mcmc`. Not mutated.
    num_adaptation_steps: Number of adaptation (tuning) steps.
    current_state: Optional initial position, forwarded to `_setup_mcmc` as
      `init_position`.
    dual_averaging_kwargs: `dict` of keyword arguments for dual averaging step
      size adaptation. Its `'num_adaptation_steps'` entry is always replaced by
      `num_adaptation_steps` (with a warning if present). Not mutated.
    trace_fn: Trace callable, or `None` to trace nothing.
    return_final_kernel_results: Whether to also return final kernel results.
    discard_tuning: If `True`, adaptation draws are run as burn-in and not
      returned.
    seed: Seed for reproducible sampling.
    chain_axis_names: Default for the dual averaging
      `'experimental_reduce_chain_axis_names'` entry.
    **pins: Values used to pin (condition) `joint_dist`.

  Returns:
    A `CheckpointableStatesAndTrace` if `return_final_kernel_results`, else the
    untransformed draws if `trace_fn` is `None`, else a `StatesAndTrace`.

  Raises:
    ValueError: if no step size is supplied for a target with non-scalar batch
      shape.
  """
  # Work on private copies: keys are inserted/overwritten below and must not
  # leak into caller-owned dicts. Otherwise a reused `dual_averaging_kwargs`
  # would already contain 'num_adaptation_steps' on the next call (spurious
  # warning), and a reused `proposal_kernel_kwargs` would carry a stale
  # 'step_size' into the `get('step_size') is None` check.
  proposal_kernel_kwargs = dict(proposal_kernel_kwargs)
  dual_averaging_kwargs = dict(dual_averaging_kwargs)

  if trace_fn is None:
    trace_fn = lambda *args: ()
    no_trace = True
  else:
    no_trace = False

  if isinstance(n_chains, int):
    n_chains = [n_chains]

  if (tf.executing_eagerly() or
      not control_flow_util.GraphOrParentsInXlaContext(
          tf1.get_default_graph())):
    # A Tensor num_draws argument breaks XLA, which requires static TensorArray
    # trace_fn result allocation sizes.
    num_adaptation_steps = ps.convert_to_shape_tensor(num_adaptation_steps)

  if 'num_adaptation_steps' in dual_averaging_kwargs:
    warnings.warn(
        'Dual averaging adaptation will use the value specified in'
        ' the `num_adaptation_steps` argument for its construction,'
        ' hence there is no need to specify it in the'
        ' `dual_averaging_kwargs` argument.')

  # TODO(b/180011931): if num_adaptation_steps is small, this throws an error.
  dual_averaging_kwargs['num_adaptation_steps'] = num_adaptation_steps
  dual_averaging_kwargs.setdefault(
      'reduce_fn',
      functools.partial(
          generic_math.reduce_log_harmonic_mean_exp,
          # There is only one log_accept_prob per chain, and we reduce across
          # all chains, so typically the all_gather will be gathering scalars,
          # which should be relatively efficient.
          experimental_allow_all_gather=True))
  # By default, reduce over named axes for step size adaptation
  dual_averaging_kwargs.setdefault('experimental_reduce_chain_axis_names',
                                   chain_axis_names)

  setup_seed, sample_seed = samplers.split_seed(
      samplers.sanitize_seed(seed), n=2)
  (target_log_prob_fn, initial_transformed_position, bijector,
   step_broadcast, batch_shape, shard_axis_names) = _setup_mcmc(
       joint_dist,
       n_chains=n_chains,
       init_position=current_state,
       seed=setup_seed,
       **pins)

  if proposal_kernel_kwargs.get('step_size') is None:
    if batch_shape.shape != (0,):  # Scalar batch has a 0-vector shape.
      raise ValueError(
          'Batch target density must specify init_step_size. Got '
          f'batch shape {batch_shape} from joint {joint_dist}.')
    init_step_size = _get_step_size(initial_transformed_position,
                                    target_log_prob_fn)
  else:
    init_step_size = step_broadcast(proposal_kernel_kwargs['step_size'])

  proposal_kernel_kwargs.update({
      'target_log_prob_fn': target_log_prob_fn,
      'step_size': init_step_size,
      'momentum_distribution': _init_momentum(
          initial_transformed_position,
          batch_shape=ps.concat([n_chains, batch_shape], axis=0),
          shard_axis_names=shard_axis_names)
  })

  # Start the running variance (mass matrix estimate) at zero samples with
  # mean 0 and variance 1 for every state part.
  initial_running_variance = [
      sample_stats.RunningVariance.from_stats(  # pylint: disable=g-complex-comprehension
          num_samples=tf.zeros([], part.dtype),
          mean=tf.zeros_like(part),
          variance=tf.ones_like(part)) for part in initial_transformed_position
  ]
  # TODO(phandu): Consider splitting out warmup and post warmup phases
  # to avoid executing adaptation code during the post warmup phase.
  ret = _do_sampling(
      kind=kind,
      proposal_kernel_kwargs=proposal_kernel_kwargs,
      dual_averaging_kwargs=dual_averaging_kwargs,
      num_draws=n_draws if discard_tuning else n_draws + num_adaptation_steps,
      num_burnin_steps=num_adaptation_steps if discard_tuning else 0,
      initial_position=initial_transformed_position,
      initial_running_variance=initial_running_variance,
      bijector=bijector,
      trace_fn=trace_fn,
      return_final_kernel_results=return_final_kernel_results,
      chain_axis_names=chain_axis_names,
      shard_axis_names=shard_axis_names,
      seed=sample_seed)

  if return_final_kernel_results:
    draws, trace, fkr = ret
    return sample.CheckpointableStatesAndTrace(
        all_states=bijector.inverse(draws),
        trace=trace,
        final_kernel_results=fkr)
  draws, trace = ret
  if no_trace:
    return bijector.inverse(draws)
  return sample.StatesAndTrace(all_states=bijector.inverse(draws), trace=trace)
def _windowed_adaptive_impl(n_draws,
                            joint_dist,
                            *,
                            kind,
                            n_chains,
                            proposal_kernel_kwargs,
                            num_adaptation_steps,
                            dual_averaging_kwargs,
                            trace_fn,
                            return_final_kernel_results,
                            discard_tuning,
                            seed,
                            **pins):
  """Runs windowed sampling using either HMC or NUTS as internal sampler.

  Adaptation runs as one fast window, four slow windows (each twice as long as
  the previous), and a final fast window, followed by `n_draws` sampling steps.

  Args:
    n_draws: Number of draws after adaptation.
    joint_dist: Joint distribution to sample from (passed to `_setup_mcmc`).
    kind: Forwarded to the window helpers and `_do_sampling` to select the
      proposal kernel.
    n_chains: Number of chains; used as the leading dimension of the initial
      step size.
    proposal_kernel_kwargs: `dict` of proposal kernel keyword arguments. The
      function fills in target, step size and momentum entries on a private
      copy; the caller's dict is not mutated.
    num_adaptation_steps: Number of adaptation steps, split into windows by
      `_get_window_sizes`.
    dual_averaging_kwargs: Keyword arguments for step size adaptation,
      forwarded to the window helpers.
    trace_fn: Trace callable, or `None` to trace nothing.
    return_final_kernel_results: Whether to also return final kernel results.
    discard_tuning: If `False`, adaptation draws/traces are prepended to the
      returned ones.
    seed: Seed for reproducible sampling.
    **pins: Values used to pin (condition) `joint_dist`.

  Returns:
    A `CheckpointableStatesAndTrace` if `return_final_kernel_results`, else the
    untransformed draws if `trace_fn` is `None`, else a `StatesAndTrace`.
  """
  # Work on a private copy: entries are inserted/overwritten as adaptation
  # progresses, and those tensors must not leak into a caller-owned dict.
  proposal_kernel_kwargs = dict(proposal_kernel_kwargs)

  if trace_fn is None:
    trace_fn = lambda *args: ()
    no_trace = True
  else:
    no_trace = False

  num_adaptation_steps = tf.convert_to_tensor(num_adaptation_steps)

  setup_seed, init_seed, seed = samplers.split_seed(
      samplers.sanitize_seed(seed), n=3)
  target_log_prob_fn, initial_transformed_position, bijector = _setup_mcmc(
      joint_dist, n_chains=n_chains, seed=setup_seed, **pins)

  first_window_size, slow_window_size, last_window_size = _get_window_sizes(
      num_adaptation_steps)
  # If we (over) optimistically assume good scaling, this will be near the
  # optimal step size, see Langmore, Ian, Michael Dikovsky, Scott Geraedts,
  # Peter Norgaard, and Rob Von Behren. 2019. “A Condition Number for
  # Hamiltonian Monte Carlo.” arXiv [stat.CO]. arXiv.
  # http://arxiv.org/abs/1905.09813.
  init_step_size = tf.cast(
      ps.shape(initial_transformed_position)[-1], tf.float32) ** -0.25

  all_draws = []
  all_traces = []
  proposal_kernel_kwargs.update({
      'target_log_prob_fn': target_log_prob_fn,
      'step_size': tf.fill([n_chains, 1], init_step_size),
      'momentum_distribution': _init_momentum(initial_transformed_position),
  })

  draws, trace, step_size, running_variance = _fast_window(
      kind=kind,
      proposal_kernel_kwargs=proposal_kernel_kwargs,
      dual_averaging_kwargs=dual_averaging_kwargs,
      num_draws=first_window_size,
      initial_position=initial_transformed_position,
      bijector=bijector,
      trace_fn=trace_fn,
      seed=init_seed)
  proposal_kernel_kwargs.update({'step_size': step_size})
  all_draws.append(draws)
  all_traces.append(trace)

  # Four slow windows; the fifth split seed is carried forward.
  *slow_seeds, seed = samplers.split_seed(seed, n=5)
  for idx, slow_seed in enumerate(slow_seeds):
    window_size = slow_window_size * (2**idx)
    # TODO(b/180011931): if num_adaptation_steps is small, this throws an error.
    (draws, trace, step_size, running_variance,
     momentum_distribution) = _slow_window(
         kind=kind,
         proposal_kernel_kwargs=proposal_kernel_kwargs,
         dual_averaging_kwargs=dual_averaging_kwargs,
         num_draws=window_size,
         initial_position=draws[-1],
         initial_running_variance=running_variance,
         bijector=bijector,
         trace_fn=trace_fn,
         seed=slow_seed)
    all_draws.append(draws)
    all_traces.append(trace)
    proposal_kernel_kwargs.update(
        {'step_size': step_size, 'momentum_distribution': momentum_distribution})

  fast_seed, sample_seed = samplers.split_seed(seed)
  draws, trace, step_size, _ = _fast_window(
      kind=kind,
      proposal_kernel_kwargs=proposal_kernel_kwargs,
      dual_averaging_kwargs=dual_averaging_kwargs,
      num_draws=last_window_size,
      initial_position=draws[-1],
      bijector=bijector,
      trace_fn=trace_fn,
      seed=fast_seed)
  proposal_kernel_kwargs.update({'step_size': step_size})
  all_draws.append(draws)
  all_traces.append(trace)

  ret = _do_sampling(
      kind=kind,
      proposal_kernel_kwargs=proposal_kernel_kwargs,
      num_draws=n_draws,
      initial_position=draws[-1],
      bijector=bijector,
      trace_fn=trace_fn,
      return_final_kernel_results=return_final_kernel_results,
      seed=sample_seed)

  if return_final_kernel_results:
    draws, trace, fkr = ret
  else:
    draws, trace = ret
    fkr = None

  if not discard_tuning:
    # Prepend the adaptation-window draws to the post-adaptation draws.
    all_draws.append(draws)
    all_traces.append(trace)
    draws = tf.concat(all_draws, axis=0)
  all_states = bijector.inverse(draws)
  # The trace is only concatenated when it is actually returned.
  if not discard_tuning and (return_final_kernel_results or not no_trace):
    trace = tf.nest.map_structure(lambda *s: tf.concat(s, axis=0),
                                  *all_traces, expand_composites=True)

  if return_final_kernel_results:
    return sample.CheckpointableStatesAndTrace(
        all_states=all_states, trace=trace, final_kernel_results=fkr)
  if no_trace:
    return all_states
  return sample.StatesAndTrace(all_states=all_states, trace=trace)
def _windowed_adaptive_impl(n_draws,
                            joint_dist,
                            *,
                            kind,
                            n_chains,
                            proposal_kernel_kwargs,
                            num_adaptation_steps,
                            current_state,
                            dual_averaging_kwargs,
                            trace_fn,
                            return_final_kernel_results,
                            discard_tuning,
                            seed,
                            **pins):
  """Runs windowed sampling using either HMC or NUTS as internal sampler.

  Adaptation runs as one fast window, four slow windows (each twice as long as
  the previous), and a final fast window, followed by `n_draws` sampling steps.
  The window helpers are wrapped in `tf.function`s local to this call.

  Args:
    n_draws: Number of draws after adaptation.
    joint_dist: Joint distribution to sample from (passed to `_setup_mcmc`).
    kind: Forwarded to the window helpers and `_do_sampling` to select the
      proposal kernel.
    n_chains: Number of chains; prepended to the target batch shape for the
      momentum distribution.
    proposal_kernel_kwargs: `dict` of proposal kernel keyword arguments. If
      `'step_size'` is absent or `None`, an initial step size is computed (only
      allowed for a scalar batch shape). Not mutated: a private copy is used.
    num_adaptation_steps: Number of adaptation steps, split into windows by
      `_get_window_sizes`.
    current_state: Optional initial position, forwarded to `_setup_mcmc` as
      `init_position`.
    dual_averaging_kwargs: Keyword arguments for step size adaptation,
      forwarded to the window helpers.
    trace_fn: Trace callable, or `None` to trace nothing.
    return_final_kernel_results: Whether to also return final kernel results.
    discard_tuning: If `False`, adaptation draws/traces are prepended to the
      returned ones.
    seed: Seed for reproducible sampling.
    **pins: Values used to pin (condition) `joint_dist`.

  Returns:
    A `CheckpointableStatesAndTrace` if `return_final_kernel_results`, else the
    untransformed draws if `trace_fn` is `None`, else a `StatesAndTrace`.

  Raises:
    ValueError: if no step size is supplied for a target with non-scalar batch
      shape.
  """
  # Work on a private copy. Entries are overwritten as adaptation progresses.
  # Mutating the caller's dict would make a reused dict start the next call
  # from this run's adapted 'step_size' (bypassing the `is None` check below).
  proposal_kernel_kwargs = dict(proposal_kernel_kwargs)

  if trace_fn is None:
    trace_fn = lambda *args: ()
    no_trace = True
  else:
    no_trace = False

  if (tf.executing_eagerly() or
      not control_flow_util.GraphOrParentsInXlaContext(
          tf1.get_default_graph())):
    # A Tensor num_draws argument breaks XLA, which requires static TensorArray
    # trace_fn result allocation sizes.
    num_adaptation_steps = tf.convert_to_tensor(num_adaptation_steps)

  setup_seed, init_seed, seed = samplers.split_seed(
      samplers.sanitize_seed(seed), n=3)
  (target_log_prob_fn, initial_transformed_position, bijector,
   step_broadcast, batch_shape) = _setup_mcmc(
       joint_dist,
       n_chains=n_chains,
       init_position=current_state,
       seed=setup_seed,
       **pins)

  if proposal_kernel_kwargs.get('step_size') is None:
    if batch_shape.shape != (0,):  # Scalar batch has a 0-vector shape.
      raise ValueError('Batch target density must specify init_step_size. Got '
                       f'batch shape {batch_shape} from joint {joint_dist}.')
    init_step_size = _get_step_size(initial_transformed_position,
                                    target_log_prob_fn)
  else:
    init_step_size = step_broadcast(proposal_kernel_kwargs['step_size'])

  proposal_kernel_kwargs.update({
      'target_log_prob_fn': target_log_prob_fn,
      'step_size': init_step_size,
      'momentum_distribution': _init_momentum(
          initial_transformed_position,
          batch_shape=ps.concat([[n_chains], batch_shape], axis=0))})

  first_window_size, slow_window_size, last_window_size = _get_window_sizes(
      num_adaptation_steps)

  all_traces = []
  # Using tf.function here and on _slow_window_closure caches tracing
  # of _fast_window and _slow_window, respectively, within a single
  # call to windowed sampling. Why not annotate _fast_window and
  # _slow_window directly? Two reasons:
  # - Caching across calls to windowed sampling is probably futile,
  #   because the trace function and bijector will be different Python
  #   objects, preventing cache hits.
  # - The cache of a global tf.function sticks around for the lifetime
  #   of the Python process, potentially leaking memory.
  @tf.function(autograph=False)
  def _fast_window_closure(proposal_kernel_kwargs, window_size,
                           initial_position, seed):
    return _fast_window(
        kind=kind,
        proposal_kernel_kwargs=proposal_kernel_kwargs,
        dual_averaging_kwargs=dual_averaging_kwargs,
        num_draws=window_size,
        initial_position=initial_position,
        bijector=bijector,
        trace_fn=trace_fn,
        seed=seed)

  draws, trace, step_size, running_variances = _fast_window_closure(
      proposal_kernel_kwargs=proposal_kernel_kwargs,
      window_size=first_window_size,
      initial_position=initial_transformed_position,
      seed=init_seed)
  proposal_kernel_kwargs.update({'step_size': step_size})

  # One list of window chunks per state part.
  all_draws = [[d] for d in draws]
  all_traces.append(trace)
  # Four slow windows; the fifth split seed is carried forward.
  *slow_seeds, seed = samplers.split_seed(seed, n=5)

  @tf.function(autograph=False)
  def _slow_window_closure(proposal_kernel_kwargs, window_size,
                           initial_position, running_variances, seed):
    return _slow_window(
        kind=kind,
        proposal_kernel_kwargs=proposal_kernel_kwargs,
        dual_averaging_kwargs=dual_averaging_kwargs,
        num_draws=window_size,
        initial_position=initial_position,
        initial_running_variance=running_variances,
        bijector=bijector,
        trace_fn=trace_fn,
        seed=seed)

  for idx, slow_seed in enumerate(slow_seeds):
    window_size = slow_window_size * (2**idx)

    # TODO(b/180011931): if num_adaptation_steps is small, this throws an error.
    (draws, trace, step_size, running_variances, momentum_distribution
     ) = _slow_window_closure(
         proposal_kernel_kwargs=proposal_kernel_kwargs,
         window_size=window_size,
         initial_position=[d[-1] for d in draws],
         running_variances=running_variances,
         seed=slow_seed)
    for all_d, d in zip(all_draws, draws):
      all_d.append(d)
    all_traces.append(trace)
    proposal_kernel_kwargs.update(
        {'step_size': step_size,
         'momentum_distribution': momentum_distribution})

  fast_seed, sample_seed = samplers.split_seed(seed)
  draws, trace, step_size, _ = _fast_window_closure(
      proposal_kernel_kwargs=proposal_kernel_kwargs,
      window_size=last_window_size,
      initial_position=[d[-1] for d in draws],
      seed=fast_seed)
  proposal_kernel_kwargs.update({'step_size': step_size})
  for all_d, d in zip(all_draws, draws):
    all_d.append(d)
  all_traces.append(trace)

  ret = _do_sampling(
      kind=kind,
      proposal_kernel_kwargs=proposal_kernel_kwargs,
      num_draws=n_draws,
      initial_position=[d[-1] for d in draws],
      bijector=bijector,
      trace_fn=trace_fn,
      return_final_kernel_results=return_final_kernel_results,
      seed=sample_seed)

  if return_final_kernel_results:
    draws, trace, fkr = ret
  else:
    draws, trace = ret
    fkr = None

  if not discard_tuning:
    # Prepend the adaptation-window draws, part by part.
    for all_d, d in zip(all_draws, draws):
      all_d.append(d)
    all_traces.append(trace)
    draws = [tf.concat(d, axis=0) for d in all_draws]
  all_states = bijector.inverse(draws)
  # The trace is only concatenated when it is actually returned.
  if not discard_tuning and (return_final_kernel_results or not no_trace):
    trace = tf.nest.map_structure(lambda *s: tf.concat(s, axis=0),
                                  *all_traces, expand_composites=True)

  if return_final_kernel_results:
    return sample.CheckpointableStatesAndTrace(
        all_states=all_states, trace=trace, final_kernel_results=fkr)
  if no_trace:
    return all_states
  return sample.StatesAndTrace(all_states=all_states, trace=trace)
def windowed_adaptive_hmc(n_draws,
                          joint_dist,
                          *,
                          num_leapfrog_steps,
                          n_chains=64,
                          num_adaptation_steps=525,
                          target_accept_prob=0.75,
                          trace_fn=_default_trace_fn,
                          return_final_kernel_results=False,
                          discard_tuning=True,
                          seed=None,
                          **pins):
  """Adapt and sample from a joint distribution, conditioned on pins.

  This uses Hamiltonian Monte Carlo to do the sampling. Step size is tuned
  using a dual-averaging adaptation, and the kernel is conditioned using a
  diagonal mass matrix, which is estimated using expanding windows.

  Args:
    n_draws: int
      Number of draws after adaptation.
    joint_dist: `tfd.JointDistribution`
      A joint distribution to sample from.
    num_leapfrog_steps: int
      Number of leapfrog steps to use for the Hamiltonian Monte Carlo step.
    n_chains: int
      Number of independent chains to run MCMC with.
    num_adaptation_steps: int
      Number of draws used to adapt the step size and the diagonal mass
      matrix.
    target_accept_prob: float
      Target acceptance rate for the step size adaptation.
    trace_fn: Optional callable
      The trace function should accept the arguments
      `(state, bijector, is_adapting, phmc_kernel_results)`, where the `state`
      is an unconstrained, flattened float tensor, `bijector` is the
      `tfb.Bijector` that is used for unconstraining and flattening,
      `is_adapting` is a boolean to mark whether the draw is from an adaptation
      step, and `phmc_kernel_results` is the
      `UncalibratedPreconditionedHamiltonianMonteCarloKernelResults` from the
      `PreconditionedHamiltonianMonteCarlo` kernel. Note that
      `bijector.inverse(state)` will provide access to the current draw in the
      untransformed space, using the structure of the provided `joint_dist`.
      If `None`, nothing is traced.
    return_final_kernel_results:
      If `True`, then the final kernel results are returned alongside the
      chain state and the trace specified by the `trace_fn`.
    discard_tuning: bool
      Whether to drop the draws and traces produced during adaptation.
    seed: Optional, a seed for reproducible sampling.
    **pins:
      These are used to condition the provided joint distribution, and are
      passed directly to `joint_dist.experimental_pin(**pins)`.

  Returns:
    If `trace_fn` is `None` and `return_final_kernel_results` is `False`, a
    single structure of draws in the untransformed space. Otherwise, if
    `return_final_kernel_results` is `False`, a `StatesAndTrace` holding the
    draws and the trace. If `return_final_kernel_results` is `True`, a
    `CheckpointableStatesAndTrace` that also carries the final kernel results.
    If `discard_tuning` is `True`, the tensors in `draws` and `trace` will have
    length `n_draws`, otherwise, they will have length
    `n_draws + num_adaptation_steps`.
  """
  no_trace = trace_fn is None
  if no_trace:
    trace_fn = lambda *args: ()

  num_leapfrog_steps, target_accept_prob, num_adaptation_steps = [
      tf.convert_to_tensor(value)
      for value in (num_leapfrog_steps, target_accept_prob,
                    num_adaptation_steps)
  ]

  setup_seed, init_seed, remaining_seed = samplers.split_seed(
      samplers.sanitize_seed(seed), n=3)
  target_log_prob_fn, start_position, bijector = _setup_mcmc(
      joint_dist, n_chains=n_chains, seed=setup_seed, **pins)
  first_window_size, slow_window_size, last_window_size = _get_window_sizes(
      num_adaptation_steps)

  # If we (over) optimistically assume good scaling, this will be near the
  # optimal step size, see Langmore, Ian, Michael Dikovsky, Scott Geraedts,
  # Peter Norgaard, and Rob Von Behren. 2019. “A Condition Number for
  # Hamiltonian Monte Carlo.” arXiv [stat.CO]. arXiv.
  # http://arxiv.org/abs/1905.09813.
  num_dims = ps.shape(start_position)[-1]
  init_step_size = tf.cast(1. / num_dims**0.25, tf.float32)

  draw_chunks, trace_chunks = [], []

  def _record(chunk_draws, chunk_trace):
    # Keep every window's output so it can be returned when not discarding.
    draw_chunks.append(chunk_draws)
    trace_chunks.append(chunk_trace)

  # Initial fast window: step size adaptation only.
  first_step_size = tf.fill([n_chains, 1], init_step_size)
  first_momentum = _init_momentum(start_position)
  draws, trace, step_size, running_variance = _fast_window(
      target_log_prob_fn=target_log_prob_fn,
      num_leapfrog_steps=num_leapfrog_steps,
      num_draws=first_window_size,
      initial_position=start_position,
      initial_step_size=first_step_size,
      target_accept_prob=target_accept_prob,
      momentum_distribution=first_momentum,
      bijector=bijector,
      trace_fn=trace_fn,
      seed=init_seed)
  _record(draws, trace)

  # Four slow windows of doubling length; the fifth seed is carried forward.
  *slow_seeds, remaining_seed = samplers.split_seed(remaining_seed, n=5)
  for doubling, slow_seed in enumerate(slow_seeds):
    # TODO(b/180011931): if num_adaptation_steps is small, this throws an error.
    (draws, trace, step_size, running_variance,
     momentum_distribution) = _slow_window(
         target_log_prob_fn=target_log_prob_fn,
         num_leapfrog_steps=num_leapfrog_steps,
         num_draws=slow_window_size * (2**doubling),
         initial_position=draws[-1],
         initial_running_variance=running_variance,
         initial_step_size=step_size,
         target_accept_prob=target_accept_prob,
         bijector=bijector,
         trace_fn=trace_fn,
         seed=slow_seed)
    _record(draws, trace)

  # Closing fast window with the adapted mass matrix.
  fast_seed, sample_seed = samplers.split_seed(remaining_seed)
  draws, trace, step_size, running_variance = _fast_window(
      target_log_prob_fn=target_log_prob_fn,
      num_leapfrog_steps=num_leapfrog_steps,
      num_draws=last_window_size,
      initial_position=draws[-1],
      initial_step_size=step_size,
      target_accept_prob=target_accept_prob,
      momentum_distribution=momentum_distribution,
      bijector=bijector,
      trace_fn=trace_fn,
      seed=fast_seed)
  _record(draws, trace)

  sampling_result = _do_sampling(
      target_log_prob_fn=target_log_prob_fn,
      num_leapfrog_steps=num_leapfrog_steps,
      num_draws=n_draws,
      initial_position=draws[-1],
      step_size=step_size,
      momentum_distribution=momentum_distribution,
      bijector=bijector,
      trace_fn=trace_fn,
      return_final_kernel_results=return_final_kernel_results,
      seed=sample_seed)

  if return_final_kernel_results:
    draws, trace, final_results = sampling_result
  else:
    draws, trace = sampling_result
    final_results = None

  if not discard_tuning:
    _record(draws, trace)
    draws = tf.concat(draw_chunks, axis=0)
  states = bijector.inverse(draws)
  if not discard_tuning and (return_final_kernel_results or not no_trace):
    trace = tf.nest.map_structure(
        lambda *pieces: tf.concat(pieces, axis=0), *trace_chunks,
        expand_composites=True)

  if return_final_kernel_results:
    return sample.CheckpointableStatesAndTrace(
        all_states=states, trace=trace, final_kernel_results=final_results)
  if no_trace:
    return states
  return sample.StatesAndTrace(all_states=states, trace=trace)
def _windowed_adaptive_impl(n_draws,
                            joint_dist,
                            *,
                            kind,
                            n_chains,
                            proposal_kernel_kwargs,
                            num_adaptation_steps,
                            current_state,
                            dual_averaging_kwargs,
                            trace_fn,
                            return_final_kernel_results,
                            discard_tuning,
                            seed,
                            **pins):
  """Runs windowed sampling using either HMC or NUTS as internal sampler.

  Adaptation runs as one fast window, four slow windows (each twice as long as
  the previous), and a final fast window, followed by `n_draws` sampling steps.

  Args:
    n_draws: Number of draws after adaptation.
    joint_dist: Joint distribution to sample from (passed to `_setup_mcmc`).
    kind: Forwarded to the window helpers and `_do_sampling` to select the
      proposal kernel.
    n_chains: Number of chains; used as the leading dimension of the initial
      step size and as the momentum batch shape.
    proposal_kernel_kwargs: `dict` of proposal kernel keyword arguments. The
      function fills in target, step size and momentum entries on a private
      copy; the caller's dict is not mutated.
    num_adaptation_steps: Number of adaptation steps, split into windows by
      `_get_window_sizes`.
    current_state: Optional initial position, forwarded to `_setup_mcmc` as
      `init_position`.
    dual_averaging_kwargs: Keyword arguments for step size adaptation,
      forwarded to the window helpers.
    trace_fn: Trace callable, or `None` to trace nothing.
    return_final_kernel_results: Whether to also return final kernel results.
    discard_tuning: If `False`, adaptation draws/traces are prepended to the
      returned ones.
    seed: Seed for reproducible sampling.
    **pins: Values used to pin (condition) `joint_dist`.

  Returns:
    A `CheckpointableStatesAndTrace` if `return_final_kernel_results`, else the
    untransformed draws if `trace_fn` is `None`, else a `StatesAndTrace`.
  """
  # Work on a private copy: entries are inserted/overwritten as adaptation
  # progresses, and those tensors must not leak into a caller-owned dict.
  proposal_kernel_kwargs = dict(proposal_kernel_kwargs)

  if trace_fn is None:
    trace_fn = lambda *args: ()
    no_trace = True
  else:
    no_trace = False

  if (tf.executing_eagerly() or
      not control_flow_util.GraphOrParentsInXlaContext(
          tf1.get_default_graph())):
    # A Tensor num_draws argument breaks XLA, which requires static TensorArray
    # trace_fn result allocation sizes.
    num_adaptation_steps = tf.convert_to_tensor(num_adaptation_steps)

  setup_seed, init_seed, seed = samplers.split_seed(
      samplers.sanitize_seed(seed), n=3)
  target_log_prob_fn, initial_transformed_position, bijector = _setup_mcmc(
      joint_dist,
      n_chains=n_chains,
      init_position=current_state,
      seed=setup_seed,
      **pins)

  first_window_size, slow_window_size, last_window_size = _get_window_sizes(
      num_adaptation_steps)
  # If we (over) optimistically assume good scaling, this will be near the
  # optimal step size, see Langmore, Ian, Michael Dikovsky, Scott Geraedts,
  # Peter Norgaard, and Rob Von Behren. 2019. “A Condition Number for
  # Hamiltonian Monte Carlo.” arXiv [stat.CO]. arXiv.
  # http://arxiv.org/abs/1905.09813.
  # Dimension is summed over all state parts: 0.5 * (total dim) ** -0.25.
  init_step_size = 0.5 * sum([
      tf.cast(ps.shape(state_part)[-1], tf.float32)
      for state_part in initial_transformed_position])**-0.25

  proposal_kernel_kwargs.update({
      'target_log_prob_fn': target_log_prob_fn,
      'step_size': tf.fill([n_chains, 1], init_step_size, name='step_size'),
      'momentum_distribution': _init_momentum(initial_transformed_position,
                                              batch_shape=[n_chains]),
  })

  all_traces = []
  draws, trace, step_size, running_variances = _fast_window(
      kind=kind,
      proposal_kernel_kwargs=proposal_kernel_kwargs,
      dual_averaging_kwargs=dual_averaging_kwargs,
      num_draws=first_window_size,
      initial_position=initial_transformed_position,
      bijector=bijector,
      trace_fn=trace_fn,
      seed=init_seed)
  proposal_kernel_kwargs.update({'step_size': step_size})

  # One list of window chunks per state part.
  all_draws = [[d] for d in draws]
  all_traces.append(trace)
  # Four slow windows; the fifth split seed is carried forward.
  *slow_seeds, seed = samplers.split_seed(seed, n=5)
  for idx, slow_seed in enumerate(slow_seeds):
    window_size = slow_window_size * (2**idx)

    # TODO(b/180011931): if num_adaptation_steps is small, this throws an error.
    (draws, trace, step_size, running_variances,
     momentum_distribution) = _slow_window(
         kind=kind,
         proposal_kernel_kwargs=proposal_kernel_kwargs,
         dual_averaging_kwargs=dual_averaging_kwargs,
         num_draws=window_size,
         initial_position=[d[-1] for d in draws],
         initial_running_variance=running_variances,
         bijector=bijector,
         trace_fn=trace_fn,
         seed=slow_seed)
    for all_d, d in zip(all_draws, draws):
      all_d.append(d)
    all_traces.append(trace)
    proposal_kernel_kwargs.update(
        {'step_size': step_size,
         'momentum_distribution': momentum_distribution})

  fast_seed, sample_seed = samplers.split_seed(seed)
  draws, trace, step_size, _ = _fast_window(
      kind=kind,
      proposal_kernel_kwargs=proposal_kernel_kwargs,
      dual_averaging_kwargs=dual_averaging_kwargs,
      num_draws=last_window_size,
      initial_position=[d[-1] for d in draws],
      bijector=bijector,
      trace_fn=trace_fn,
      seed=fast_seed)
  proposal_kernel_kwargs.update({'step_size': step_size})
  for all_d, d in zip(all_draws, draws):
    all_d.append(d)
  all_traces.append(trace)

  ret = _do_sampling(
      kind=kind,
      proposal_kernel_kwargs=proposal_kernel_kwargs,
      num_draws=n_draws,
      initial_position=[d[-1] for d in draws],
      bijector=bijector,
      trace_fn=trace_fn,
      return_final_kernel_results=return_final_kernel_results,
      seed=sample_seed)

  if return_final_kernel_results:
    draws, trace, fkr = ret
  else:
    draws, trace = ret
    fkr = None

  if not discard_tuning:
    # Prepend the adaptation-window draws, part by part.
    for all_d, d in zip(all_draws, draws):
      all_d.append(d)
    all_traces.append(trace)
    draws = [tf.concat(d, axis=0) for d in all_draws]
  all_states = bijector.inverse(draws)
  # The trace is only concatenated when it is actually returned.
  if not discard_tuning and (return_final_kernel_results or not no_trace):
    trace = tf.nest.map_structure(lambda *s: tf.concat(s, axis=0),
                                  *all_traces, expand_composites=True)

  if return_final_kernel_results:
    return sample.CheckpointableStatesAndTrace(
        all_states=all_states, trace=trace, final_kernel_results=fkr)
  if no_trace:
    return all_states
  return sample.StatesAndTrace(all_states=all_states, trace=trace)