def test_seed_reproducibility(self):
  """Checks that equal seeds give equal results and extra steps change them.

  Two separate `RandomTransitionKernel` instances, started from different
  states but stepped the same number of times with the same seed, must end at
  the same value. That value must differ from the one obtained with one fewer
  step.
  """
  kernel_a = RandomTransitionKernel()
  kernel_b = RandomTransitionKernel()
  seed = samplers.sanitize_seed(test_util.test_seed())
  previous_t = step_kernel(
      num_steps=1,
      current_state=0,
      kernel=RandomTransitionKernel(),
      seed=seed,
  )
  for steps in range(2, 5):
    from_zero_t = step_kernel(
        num_steps=steps,
        current_state=0.,
        kernel=kernel_a,
        seed=seed,
    )
    # The different starting point is expected not to matter.
    from_one_t = step_kernel(
        num_steps=steps,
        current_state=1.,
        kernel=kernel_b,
        seed=seed,
    )
    previous, from_zero, from_one = self.evaluate(
        [previous_t, from_zero_t, from_one_t])
    self.assertEqual(from_zero, from_one)
    self.assertNotEqual(from_zero, previous)
    previous_t = from_zero_t
def one_step(self, current_state, previous_kernel_results, seed=None):
  """Advances the inner kernel until the next kept (non-discarded) state.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s).
    previous_kernel_results: `collections.namedtuple` containing `Tensor`s
      representing values from previous calls to this function (or from the
      `bootstrap_results` function).
    seed: Optional seed for reproducible sampling.

  Returns:
    new_chain_state: Newest non-discarded MCMC chain state drawn from the
      `inner_kernel`.
    kernel_results: `collections.namedtuple` of internal calculations used to
      advance the chain.
  """
  scope_name = mcmc_util.make_name(
      self.name, 'sample_discarding_kernel', 'one_step')
  with tf.name_scope(scope_name):
    counter = previous_kernel_results.call_counter
    # Take the draws that are discarded on this call, plus the one that is kept.
    steps_this_call = self._num_samples_to_skip(counter) + 1
    next_state, next_inner_results = sample.step_kernel(
        num_steps=steps_this_call,
        current_state=current_state,
        previous_kernel_results=previous_kernel_results.inner_results,
        kernel=self.inner_kernel,
        return_final_kernel_results=True,
        seed=seed,
        name=self.name)
    results = SampleDiscardingKernelResults(counter + 1, next_inner_results)
    return next_state, results
def test_initial_state(self):
  """Two steps of `TestTransitionKernel` starting from state 1 end at 3."""
  kernel = TestTransitionKernel()
  result = self.evaluate(
      step_kernel(num_steps=2, current_state=1, kernel=kernel))
  self.assertEqual(result, 3)
def test_simple_operation(self):
  """Checks the final state and both counters after two steps from state 0."""
  kernel = TestTransitionKernel()
  state_t, results_t = step_kernel(
      num_steps=2,
      current_state=0,
      kernel=kernel,
      return_final_kernel_results=True)
  state, results = self.evaluate([state_t, results_t])
  self.assertEqual(state, 2)
  self.assertEqual(results.counter_1, 2)
  self.assertEqual(results.counter_2, 4)
def test_defined_pkr(self):
  """Explicit `previous_kernel_results` are used in place of a bootstrap."""
  kernel = TestTransitionKernel()
  warm_start = TestTransitionKernelResults(
      tf.constant(2, dtype=tf.int32), tf.constant(3, dtype=tf.int32))
  state_t, results_t = step_kernel(
      num_steps=2,
      current_state=0,
      previous_kernel_results=warm_start,
      kernel=kernel,
      return_final_kernel_results=True)
  state, results = self.evaluate([state_t, results_t])
  self.assertEqual(state, 2)
  # Counters continue from the supplied values (2, 3).
  self.assertEqual(results.counter_1, 4)
  self.assertEqual(results.counter_2, 7)
def sample(self, num_steps, current_state, previous_kernel_results=None,
           seed=None):
  """Sample from the configured kernel.

  Args:
    num_steps: Integer number of Markov chain steps.
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s).
    previous_kernel_results: A `Tensor` or a nested collection of `Tensor`s.
      Warm-start for the auxiliary state needed by the given `kernel`.
      If not supplied, `step_kernel` will cold-start with
      `kernel.bootstrap_results`.
    seed: Optional seed for reproducible sampling, forwarded to
      `step_kernel`. Default value: `None` (unseeded, the previous behavior).

  Returns:
    outputs: A `KernelOutputs` object containing the states, trace, etc.
  """
  kernel = self.build(num_steps)
  state, results = sample.step_kernel(
      num_steps=num_steps,
      current_state=current_state,
      previous_kernel_results=previous_kernel_results,
      kernel=kernel,
      return_final_kernel_results=True,
      seed=seed)
  return kernel_outputs.KernelOutputs(kernel, state, results)
def sample_fold(
    num_steps,
    current_state,
    previous_kernel_results=None,
    kernel=None,
    reducer=None,
    previous_reducer_state=None,
    return_final_reducer_states=False,
    num_burnin_steps=0,
    num_steps_between_results=0,
    parallel_iterations=10,
    seed=None,
    name=None,
):
  """Computes the requested reductions over the `kernel`'s samples.

  Runs the given `kernel` for `num_steps` reducer steps and passes each kept
  sample to the `one_step` method(s) of the given `Reducer`(s). Memory use
  stays constant unless a `Reducer` itself builds a large structure.

  Internally, `kernel` is wrapped as
  `WithReductions(SampleDiscardingKernel(kernel))` to implement optional
  burn-in and thinning. The kernel results of those two wrapper kernels are
  not returned. Warm restarts are instead supported through
  `previous_kernel_results` (the results of `kernel` itself), together with
  `previous_reducer_state` and `return_final_reducer_states` (the reducers'
  running states).

  An arbitrary (possibly nested) collection of `reducer`s can be provided.
  The finalized statistic(s) are returned in an identical structure.

  This function can sample from and reduce over multiple chains, in parallel.
  Whether or not there are multiple chains is dictated by how the `kernel`
  treats its inputs. Typically, the shape of the independent chains is the
  shape of the result of the `target_log_prob_fn` used by the `kernel` when
  applied to the given `current_state`.

  Args:
    num_steps: Integer or scalar `Tensor` representing the number of `Reducer`
      steps.
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s).
    previous_kernel_results: A `Tensor` or a nested collection of `Tensor`s.
      Warm-start for the auxiliary state needed by the given `kernel`.
      If not supplied, `sample_fold` will cold-start with
      `kernel.bootstrap_results`.
    kernel: An instance of `tfp.mcmc.TransitionKernel` which implements one
      step of the Markov chain.
    reducer: A (possibly nested) structure of `Reducer`s to be evaluated
      on the `kernel`'s samples. If no reducers are given (`reducer=None`),
      then `None` will be returned in place of streaming calculations.
    previous_reducer_state: A (possibly nested) structure of running states
      corresponding to the structure in `reducer`. For resuming streaming
      reduction computations begun in a previous run.
    return_final_reducer_states: A Python `bool` giving whether to return
      resumable final reducer states.
    num_burnin_steps: Integer or scalar `Tensor` representing the number
      of chain steps to take before starting to collect results.
      Defaults to 0 (i.e., no burn-in).
    num_steps_between_results: Integer or scalar `Tensor` representing
      the number of chain steps between collecting a result. Only one out
      of every `num_steps_between_results + 1` steps is included in the
      returned results. Defaults to 0 (i.e., no thinning).
    parallel_iterations: The number of iterations allowed to run in parallel.
      It must be a positive integer. See `tf.while_loop` for more details.
    seed: Optional seed for reproducible sampling.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'mcmc_sample_fold').

  Returns:
    reduction_results: A (possibly nested) structure of finalized reducer
      statistics. The structure identically mimics that of `reducer`.
    end_state: The final state of the Markov chain(s).
    final_kernel_results: `collections.namedtuple` of internal calculations
      used to advance the supplied `kernel`. These results do not include
      the kernel results of `WithReductions` or `SampleDiscardingKernel`.
    final_reducer_states: A (possibly nested) structure of final running
      reducer states, returned only if `return_final_reducer_states` is
      `True`. Can be used to resume streaming reductions when continuing
      sampling.
  """
  with tf.name_scope(name or 'mcmc_sample_fold'):
    num_steps = tf.convert_to_tensor(
        num_steps, dtype=tf.int32, name='num_steps')
    current_state = tf.nest.map_structure(
        lambda x: tf.convert_to_tensor(x, name='current_state'),
        current_state)
    no_reducer_given = reducer is None
    if no_reducer_given:
      reducer = []

    # Build the onion: WithReductions(SampleDiscardingKernel(kernel)).
    discarding_kernel = sample_discarding_kernel.SampleDiscardingKernel(
        inner_kernel=kernel,
        num_burnin_steps=num_burnin_steps,
        num_steps_between_results=num_steps_between_results)
    folding_kernel = with_reductions.WithReductions(
        inner_kernel=discarding_kernel,
        reducer=reducer,
    )

    # Warm-start each onion layer, starting from the user kernel's results.
    if previous_kernel_results is None:
      user_pkr = kernel.bootstrap_results(current_state)
    else:
      user_pkr = previous_kernel_results
    discarding_pkr = discarding_kernel.bootstrap_results(
        current_state, user_pkr)
    folding_pkr = folding_kernel.bootstrap_results(
        current_state, discarding_pkr, previous_reducer_state)

    end_state, onion_results = exp_sample_lib.step_kernel(
        num_steps=num_steps,
        current_state=current_state,
        previous_kernel_results=folding_pkr,
        kernel=folding_kernel,
        return_final_kernel_results=True,
        parallel_iterations=parallel_iterations,
        seed=seed,
        name=name,
    )

    reducer_states = onion_results.streaming_calculations
    reduction_results = nest.map_structure_up_to(
        reducer,
        lambda r, s: r.finalize(s),
        reducer,
        reducer_states,
        check_types=False)
    if no_reducer_given:
      reduction_results = None

    # Peel off the WithReductions and SampleDiscardingKernel layers.
    user_results = onion_results.inner_results.inner_results
    outputs = (reduction_results, end_state, user_results)
    if return_final_reducer_states:
      outputs += (reducer_states,)
    return outputs
def sample_fold(
    num_steps,
    current_state,
    previous_kernel_results=None,
    kernel=None,
    reducer=None,
    num_burnin_steps=0,
    num_steps_between_results=0,
    parallel_iterations=10,
    seed=None,
    name=None,
):
  """Computes the requested reductions over the `kernel`'s samples.

  To wit, runs the given `kernel` for `num_steps` steps, and consumes
  the stream of samples with the given `Reducer`s' `one_step` method(s).
  This runs in constant memory (unless a given `Reducer` builds a
  large structure).

  The driver internally composes the correct onion of `WithReductions`
  and `SampleDiscardingKernel` to implement the requested optionally
  thinned reduction. However, the kernel results of those applied
  Transition Kernels will not be returned, and the reducers always start
  from a fresh state. Hence, if warm-restarting reductions is desired, one
  should manually build the Transition Kernel onion and use
  `tfp.experimental.mcmc.step_kernel`.

  An arbitrary collection of `reducer` can be provided, and the resulting
  finalized statistic(s) will be returned in an identical structure.

  Args:
    num_steps: Integer or scalar `Tensor` representing the number of `Reducer`
      steps.
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s).
    previous_kernel_results: A `Tensor` or a nested collection of `Tensor`s.
      Warm-start for the auxiliary state needed by the given `kernel` (for
      example, the `final_kernel_results` returned by a previous call). They
      are wrapped into the internal onion's results before sampling.
      If not supplied, `sample_fold` will cold-start with
      `kernel.bootstrap_results`.
    kernel: An instance of `tfp.mcmc.TransitionKernel` which implements one
      step of the Markov chain.
    reducer: A (possibly nested) structure of `Reducer`s to be evaluated
      on the `kernel`'s samples. If no reducers are given (`reducer=None`),
      then `None` will be returned in place of streaming calculations.
    num_burnin_steps: Integer or scalar `Tensor` representing the number
      of chain steps to take before starting to collect results.
      Defaults to 0 (i.e., no burn-in).
    num_steps_between_results: Integer or scalar `Tensor` representing
      the number of chain steps between collecting a result. Only one out
      of every `num_steps_between_results + 1` steps is included in the
      returned results. Defaults to 0 (i.e., no thinning).
    parallel_iterations: The number of iterations allowed to run in parallel.
      It must be a positive integer. See `tf.while_loop` for more details.
    seed: Optional seed for reproducible sampling.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'mcmc_sample_fold').

  Returns:
    reduction_results: A (possibly nested) structure of finalized reducer
      statistics. The structure identically mimics that of `reducer`.
    end_state: The final state of the Markov chain(s).
    final_kernel_results: `collections.namedtuple` of internal calculations
      used to advance the supplied `kernel`. These results do not include
      the kernel results of `WithReductions` or `SampleDiscardingKernel`.
  """
  with tf.name_scope(name or 'mcmc_sample_fold'):
    num_steps = tf.convert_to_tensor(
        num_steps, dtype=tf.int32, name='num_steps')
    current_state = tf.nest.map_structure(
        lambda x: tf.convert_to_tensor(x, name='current_state'),
        current_state)
    reducer_was_none = False
    if reducer is None:
      reducer = []
      reducer_was_none = True
    thinning_kernel = sample_discarding_kernel.SampleDiscardingKernel(
        inner_kernel=kernel,
        num_burnin_steps=num_burnin_steps,
        num_steps_between_results=num_steps_between_results)
    reduction_kernel = with_reductions.WithReductions(
        inner_kernel=thinning_kernel,
        reducer=reducer,
    )

    # `previous_kernel_results` describes the user's `kernel`, but
    # `step_kernel` below drives the whole onion, whose results nest the
    # user's results two levels down. Previously these were passed through
    # unwrapped, which mismatched the onion's structure. When they are
    # absent, `step_kernel` bootstraps the onion itself, as before.
    # NOTE(review): relies on `SampleDiscardingKernel.bootstrap_results` and
    # `WithReductions.bootstrap_results` accepting the inner kernel results
    # as a second argument -- confirm for this library version.
    reduction_pkr = None
    if previous_kernel_results is not None:
      thinning_pkr = thinning_kernel.bootstrap_results(
          current_state, previous_kernel_results)
      reduction_pkr = reduction_kernel.bootstrap_results(
          current_state, thinning_pkr)

    end_state, final_kernel_results = sample.step_kernel(
        num_steps=num_steps,
        current_state=current_state,
        previous_kernel_results=reduction_pkr,
        kernel=reduction_kernel,
        return_final_kernel_results=True,
        parallel_iterations=parallel_iterations,
        seed=seed,
        name=name,
    )
    reduction_results = nest.map_structure_up_to(
        reducer,
        lambda r, s: r.finalize(s),
        reducer,
        final_kernel_results.streaming_calculations,
        check_types=False)
    if reducer_was_none:
      reduction_results = None
    return (reduction_results,
            end_state,
            final_kernel_results.inner_results.inner_results)
def run_kernel(
    kernel,
    num_results,
    current_state,
    previous_kernel_results=None,
    reducer=(),
    previous_reducer_state=None,
    trace_fn=_trace_everything,
    parallel_iterations=10,
    seed=None,
    name=None,
):
  """Runs a Markov chain defined by the given `TransitionKernel`.

  This is meant as a (more) helpful frontend to the low-level
  `TransitionKernel`-based MCMC API, supporting several main features:

  - Running a batch of multiple independent chains using SIMD parallelism
  - Tracing the history of the chains, or not tracing it to save memory
  - Computing reductions over chain history, whether it is also traced or not
  - Warm (re-)start, including auxiliary state

  This function samples from a Markov chain at `current_state` whose
  stationary distribution is governed by the supplied `TransitionKernel`
  instance (`kernel`).

  The `current_state` can be represented as a single `Tensor` or a `list` of
  `Tensors` which collectively represent the current state.

  This function can sample from multiple chains, in parallel. Whether or not
  there are multiple chains is dictated by how the `kernel` treats its inputs.
  Typically, the shape of the independent chains is the shape of the result of
  the `target_log_prob_fn` used by the `kernel` when applied to the given
  `current_state`.

  This function can compute reductions over the samples in tandem with
  sampling, for example to return summary statistics without materializing all
  the samples. To request reductions, pass a `Reducer` object, or a nested
  structure of `Reducer` objects, as the `reducer=` argument.

  In addition to the chain state, this function supports tracing of auxiliary
  variables used by the kernel, as well as intermediate values of any supplied
  reductions. The traced values are selected by specifying `trace_fn`. The
  `trace_fn` must be a callable accepting three arguments: the chain state,
  the kernel_results of the `kernel`, and the current results of the
  reductions, if any are supplied. The return value of `trace_fn` (which may
  be a `Tensor` or a nested structure of `Tensor`s) is accumulated, such that
  each `Tensor` gains a new outmost dimension representing time in the chain
  history.

  Since MCMC states are correlated, it is sometimes desirable to produce
  additional intermediate states, and then discard them, ending up with a set
  of states with decreased autocorrelation. See [Owen (2017)][1].
  `run_kernel` has no thinning or burn-in option of its own. To thin, wrap
  `kernel` in a thinning transition kernel (e.g. a `ThinningKernel`) before
  passing it in.

  Args:
    kernel: An instance of `tfp.mcmc.TransitionKernel` which implements one
      step of the Markov chain.
    num_results: Integer number of (non-discarded) Markov chain draws to
      compute.
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      initial state(s) of the Markov chain(s).
    previous_kernel_results: A `Tensor` or a nested collection of `Tensor`s
      representing internal calculations made within the previous call to this
      function (or as returned by `bootstrap_results`).
    reducer: A (possibly nested) structure of `Reducer`s to be evaluated
      on the `kernel`'s samples. If no reducers are given (`reducer=()`, the
      default, or `reducer=None`), their states will not be passed to any
      supplied `trace_fn`.
    previous_reducer_state: A (possibly nested) structure of running states
      corresponding to the structure in `reducer`. For resuming streaming
      reduction computations begun in a previous run.
    trace_fn: A callable that takes in the current chain state, the current
      auxiliary kernel state, and the current result of any reducers, and
      returns a `Tensor` or a nested collection of `Tensor`s that is then
      traced. If `None`, nothing is traced.
    parallel_iterations: The number of iterations allowed to run in parallel.
      It must be a positive integer. See `tf.while_loop` for more details.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'mcmc_run_kernel').

  Returns:
    result: A `RunKernelResults` instance containing information about the
      sampling run. Main fields are `trace`, the history of outputs of
      `trace_fn`, and `reduction_results`, the final outputs of all supplied
      `Reducer`s. See `RunKernelResults` for contents of other fields.

  #### References

  [1]: Art B. Owen. Statistically efficient thinning of a Markov chain
       sampler. _Technical Report_, 2017.
       http://statweb.stanford.edu/~owen/reports/bestthinning.pdf
  """
  # Features omitted for simplicity:
  # - Can only warm start either all the reducers or none of them, not
  #   piecemeal.
  #
  # Defects admitted for simplicity:
  # - All reducers are finalized internally at every step, whether the user
  #   wished to trace them or not.  We expect graph mode TF to avoid that unused
  #   computation, but eager mode will not.
  # - The user is not given the opportunity to trace the running state of
  #   reducers.  For example, the user cannot trace the sum and count of a
  #   running mean, only the running mean itself.  Arguably this is a feature,
  #   because the sum and count can be considered implementation details, the
  #   hiding of which is the purpose of the `finalize` method.
  with tf.name_scope(name or 'mcmc_run_kernel'):
    if not kernel.is_calibrated:
      warnings.warn(
          'supplied `TransitionKernel` is not calibrated. Markov '
          'chain may not converge to intended target distribution.')

    if trace_fn is None:
      trace_fn = lambda *args: ()

    # `None` is the documented "no reducers" spelling. Left as-is, it would
    # presumably be treated as a single leaf by `nest.map_structure_up_to`
    # below and have `.finalize` called on it, so normalize it to the
    # empty-structure default.
    if reducer is None:
      reducer = ()

    # Form kernel onion
    reduction_kernel = with_reductions.WithReductions(
        inner_kernel=kernel,
        reducer=reducer)

    # User trace function should be called with
    # - current chain state
    # - kernel results structure of the passed-in kernel
    # - if there were any reducers, their intermediate results
    #
    # `WithReductions` will show the TracingReducer the intermediate state as
    # the kernel results of the onion named `reduction_kernel` above.  This
    # wrapper converts from that to what the user-supplied trace function needs
    # to see.
    def internal_trace_fn(curr_state, kr):
      if reducer:
        def fin(reducer, red_state):
          return reducer.finalize(red_state)
        # Extra level of list will be unwrapped by *reduction_args, below.
        reduction_args = [
            nest.map_structure_up_to(
                reducer, fin, reducer, kr.reduction_results)
        ]
      else:
        reduction_args = []
      return trace_fn(curr_state, kr.inner_results, *reduction_args)

    trace_reducer = tracing_reducer.TracingReducer(
        trace_fn=internal_trace_fn,
        size=num_results)
    tracing_kernel = with_reductions.WithReductions(
        inner_kernel=reduction_kernel,
        reducer=trace_reducer,
    )

    # Bootstrap corresponding warm start
    if previous_kernel_results is None:
      previous_kernel_results = kernel.bootstrap_results(current_state)
    reduction_pkr = reduction_kernel.bootstrap_results(
        current_state, previous_kernel_results, previous_reducer_state)
    tracing_pkr = tracing_kernel.bootstrap_results(
        current_state, reduction_pkr)

    # pylint: disable=unbalanced-tuple-unpacking
    final_state, tracing_kernel_results = exp_sample_lib.step_kernel(
        num_steps=num_results,
        current_state=current_state,
        previous_kernel_results=tracing_pkr,
        kernel=tracing_kernel,
        return_final_kernel_results=True,
        parallel_iterations=parallel_iterations,
        seed=seed,
        name=name,
    )

    trace = trace_reducer.finalize(
        tracing_kernel_results.reduction_results)

    reduction_kernel_results = tracing_kernel_results.inner_results
    reduction_results = nest.map_structure_up_to(
        reducer,
        lambda r, s: r.finalize(s),
        reducer,
        reduction_kernel_results.reduction_results,
        check_types=False)

    user_kernel_results = reduction_kernel_results.inner_results

    resume_kwargs = {
        'current_state': final_state,
        'previous_kernel_results': user_kernel_results,
        'kernel': kernel,
        'reducer': reducer,
        'previous_reducer_state': reduction_kernel_results.reduction_results,
    }

    return RunKernelResults(
        trace=trace,
        reduction_results=reduction_results,
        final_state=final_state,
        final_kernel_results=user_kernel_results,
        resume_kwargs=resume_kwargs)
def sample_chain_with_burnin(
    num_results,
    current_state,
    previous_kernel_results=None,
    kernel=None,
    num_burnin_steps=0,
    num_steps_between_results=0,
    trace_fn=_trace_current_state,
    parallel_iterations=10,
    seed=None,
    name=None,
):
  """Implements Markov chain Monte Carlo via repeated `TransitionKernel` steps.

  This function samples from a Markov chain at `current_state` whose
  stationary distribution is governed by the supplied `TransitionKernel`
  instance (`kernel`).

  This function can sample from multiple chains, in parallel. (Whether or not
  there are multiple chains is dictated by the `kernel`.)

  The `current_state` can be represented as a single `Tensor` or a `list` of
  `Tensors` which collectively represent the current state.

  Sampling happens in two phases. First, the chain is advanced
  `num_burnin_steps` times without recording anything. Then `num_results`
  draws are recorded.

  Since MCMC states are correlated, it is sometimes desirable to produce
  additional intermediate states, and then discard them, ending up with a set
  of states with decreased autocorrelation. See [Owen (2017)][1]. Such
  'thinning' is made possible by setting `num_steps_between_results > 0`. The
  chain then takes `num_steps_between_results` extra steps between the steps
  that make it into the results.

  The recorded values are whatever `trace_fn` returns. By default, all chain
  states but no kernel results are traced.

  Args:
    num_results: Integer number of Markov chain draws.
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s).
    previous_kernel_results: A `Tensor` or a nested collection of `Tensor`s
      representing internal calculations made within the previous call to this
      function (or as returned by `bootstrap_results`).
    kernel: An instance of `tfp.mcmc.TransitionKernel` which implements one
      step of the Markov chain.
    num_burnin_steps: Integer number of chain steps to take before starting to
      collect results.
      Default value: 0 (i.e., no burn-in).
    num_steps_between_results: Integer number of chain steps between collecting
      a result. Only one out of every `num_steps_between_results + 1` steps is
      included in the returned results. The number of returned chain states is
      still equal to `num_results`. Default value: 0 (i.e., no thinning).
    trace_fn: A callable that takes in the current chain state and the
      corresponding kernel results and returns a `Tensor` or a nested
      collection of `Tensor`s; its outputs are what gets traced. If `None`,
      nothing is traced.
    parallel_iterations: The number of iterations allowed to run in parallel.
      It must be a positive integer. See `tf.while_loop` for more details.
    seed: Optional, a seed for reproducible sampling.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e.,
      'experimental_mcmc_sample_chain_with_burnin').

  Returns:
    result: A `RunKernelResults` instance containing information about the
      sampling run. Main field is `trace`, the history of outputs of
      `trace_fn`. See `RunKernelResults` for contents of other fields. The
      reducer-related entries are removed from `resume_kwargs`, since no
      reducers are used.

  #### References

  [1]: Art B. Owen. Statistically efficient thinning of a Markov chain
       sampler. _Technical Report_, 2017.
       http://statweb.stanford.edu/~owen/reports/bestthinning.pdf
  """
  with tf.name_scope(name or 'experimental_mcmc_sample_chain_with_burnin'):
    if not kernel.is_calibrated:
      warnings.warn(
          'supplied `TransitionKernel` is not calibrated. Markov '
          'chain may not converge to intended target distribution.')

    if trace_fn is None:
      trace_fn = lambda *args: ()

    burnin_seed, sampling_seed = random.split_seed(seed, n=2)

    # Phase 1: advance the chain through burn-in without recording anything.
    burned_in_state, burned_in_results = exp_sample_lib.step_kernel(
        num_steps=num_burnin_steps,
        current_state=current_state,
        previous_kernel_results=previous_kernel_results,
        kernel=kernel,
        return_final_kernel_results=True,
        parallel_iterations=parallel_iterations,
        seed=burnin_seed,
        name='burnin')

    # Phase 2: record `num_results` thinned draws. ThinningKernel does not wrap
    # the kernel_results structure, so the burn-in results can be passed on
    # directly.
    thinned_kernel = thinning_kernel.ThinningKernel(
        kernel,
        num_steps_to_skip=num_steps_between_results)
    results = run.run_kernel(
        num_results=num_results,
        current_state=burned_in_state,
        previous_kernel_results=burned_in_results,
        kernel=thinned_kernel,
        trace_fn=trace_fn,
        parallel_iterations=parallel_iterations,
        seed=sampling_seed,
        name='sampling')

    # No reducers were requested, so drop the reducer-related resume arguments.
    for unused_key in ('reducer', 'previous_reducer_state'):
      del results.resume_kwargs[unused_key]
    return results