def _prepare_args(target_log_prob_fn, state, step_size,
                  target_log_prob=None, maybe_expand=False,
                  description='target_log_prob'):
  """Normalizes `state` and `step_size` into aligned lists of `Tensor`s.

  Args:
    target_log_prob_fn: Callable forwarded to `_maybe_call_fn` together with
      the converted state parts.
    state: A single value or a list-like of values; each is converted with
      `tf.convert_to_tensor`.
    step_size: A single value or a list-like of values. A single step size is
      repeated once per state part.
    target_log_prob: Optional precomputed value forwarded to `_maybe_call_fn`.
    maybe_expand: If truthy, the state parts and step sizes are always
      returned as lists, even when `state` was not list-like.
    description: Text forwarded to `_maybe_call_fn`.

  Returns:
    A three-element list `[state_parts, step_sizes, target_log_prob]`. The
    first two entries are unwrapped to their sole element when `state` is not
    list-like and `maybe_expand` is falsy.

  Raises:
    ValueError: if any state part lacks a fully defined static shape, or if
      the number of step sizes is neither one nor the number of state parts.
  """
  if mcmc_util.is_list_like(state):
    state_parts = list(state)
  else:
    state_parts = [state]
  state_parts = [
      tf.convert_to_tensor(part, name='current_state') for part in state_parts
  ]

  # Every state part must have a completely known static shape.
  if not np.all([part.shape.is_fully_defined() for part in state_parts]):
    raise ValueError('All static shapes must be fully defined.')

  target_log_prob = _maybe_call_fn(
      target_log_prob_fn, state_parts, target_log_prob, description)

  if mcmc_util.is_list_like(step_size):
    step_sizes = list(step_size)
  else:
    step_sizes = [step_size]
  step_sizes = [
      tf.convert_to_tensor(
          size, name='step_size', dtype=target_log_prob.dtype)
      for size in step_sizes
  ]

  # A lone step size is shared by all state parts.
  if len(step_sizes) == 1:
    step_sizes = step_sizes * len(state_parts)
  if len(state_parts) != len(step_sizes):
    raise ValueError('There should be exactly one `step_size` or it should '
                     'have same length as `current_state`.')

  def maybe_flatten(x):
    if maybe_expand or mcmc_util.is_list_like(state):
      return x
    return x[0]

  flat_state = maybe_flatten(state_parts)
  flat_step_sizes = maybe_flatten(step_sizes)
  return [flat_state, flat_step_sizes, target_log_prob]
def _fn(state_parts, seed):
  """Adds a normal perturbation to the input state.

  `scale` and `name` are read from the enclosing scope.

  Args:
    state_parts: A list of `Tensor`s of any shape and real dtype representing
      the state parts of the `current_state` of the Markov chain.
    seed: `int` or None. The random seed for this `Op`. If `None`, no seed is
      applied.
      Default value: `None`.

  Returns:
    perturbed_state_parts: A Python `list` of `Tensor`s. Has the same
      shape and type as the `state_parts`.

  Raises:
    ValueError: if `scale` does not broadcast with `state_parts`.
  """
  with tf.name_scope(name, 'random_walk_normal_fn',
                     values=[state_parts, scale, seed]):
    # Copy a list-like `scale` before broadcasting. The in-place `*=` below
    # would otherwise grow the caller's own list, so a later call with a
    # different number of state parts would fail the length check.
    scales = list(scale) if mcmc_util.is_list_like(scale) else [scale]
    if len(scales) == 1:
      scales *= len(state_parts)
    if len(state_parts) != len(scales):
      raise ValueError('`scale` must broadcast with `state_parts`.')
    seed_stream = distributions.SeedStream(seed, salt='RandomWalkNormalFn')
    # Each state part draws with a fresh seed from the stream.
    next_state_parts = [
        tf.random_normal(
            mean=state_part,
            stddev=scale_part,
            shape=tf.shape(state_part),
            dtype=state_part.dtype.base_dtype,
            seed=seed_stream())
        for scale_part, state_part in zip(scales, state_parts)]
    return next_state_parts
def one_step(self, current_state, previous_kernel_results):
  """Proposes a new state via `self.new_state_fn` and evaluates its log-prob.

  Args:
    current_state: A single value or list-like of values; each is converted
      with `tf.convert_to_tensor`.
    previous_kernel_results: Results object whose `target_log_prob` is passed
      as a value to the enclosing name scope.

  Returns:
    A two-element list `[next_state, kernel_results]`. `next_state` is a list
    when `current_state` is list-like and a single element otherwise;
    `kernel_results` is an `UncalibratedRandomWalkResults` holding the
    proposal's target log-prob and an all-zero `log_acceptance_correction`.
  """
  scope_name = mcmc_util.make_name(self.name, 'rwm', 'one_step')
  scope_values = [
      self.seed, current_state, previous_kernel_results.target_log_prob]
  with tf.name_scope(name=scope_name, values=scope_values):
    with tf.name_scope('initialize'):
      if mcmc_util.is_list_like(current_state):
        state_parts = list(current_state)
      else:
        state_parts = [current_state]
      state_parts = [
          tf.convert_to_tensor(part, name='current_state')
          for part in state_parts
      ]

    # Advance the stored seed so successive steps use different seeds.
    self._seed_stream = distributions_util.gen_new_seed(
        self._seed_stream, salt='rwm_kernel_proposal')
    proposed_parts = self.new_state_fn(state_parts, self._seed_stream)

    # The proposal's log-prob is computed here so a wrapping
    # MetropolisHastings kernel can read it from the results.
    proposed_log_prob = self.target_log_prob_fn(*proposed_parts)

    if mcmc_util.is_list_like(current_state):
      next_state = proposed_parts
    else:
      next_state = proposed_parts[0]

    zero_correction = tf.zeros(
        shape=tf.shape(proposed_log_prob),
        dtype=proposed_log_prob.dtype.base_dtype)
    return [
        next_state,
        UncalibratedRandomWalkResults(
            log_acceptance_correction=zero_correction,
            target_log_prob=proposed_log_prob,
        ),
    ]
def build_assign_op():
  """Builds op(s) adding `step_size * adjustment` to each step size variable.

  `step_size_var` and `adjustment` are read from the enclosing scope.

  Returns:
    A list of `assign_add` results when `step_size_var` is list-like,
    otherwise a single `assign_add` result.
  """
  def _scaled_increment(var):
    # `adjustment` is cast to the variable's dtype before scaling.
    return var.assign_add(var * tf.cast(adjustment, var.dtype))

  if not mcmc_util.is_list_like(step_size_var):
    return _scaled_increment(step_size_var)
  return [_scaled_increment(var) for var in step_size_var]
def inverse_transform_fn(bijector):
  """Returns a function applying each bijector's `inverse` to a state part.

  Args:
    bijector: A single bijector or a list-like of bijectors. A single bijector
      is wrapped in a one-element list.

  Returns:
    A function mapping a sequence of state parts to a list of
    `b.inverse(part)` values, pairing bijectors and parts with `zip`.
  """
  bijectors = bijector if is_list_like(bijector) else [bijector]

  def fn(state_parts):
    inverted = []
    for bij, part in zip(bijectors, state_parts):
      inverted.append(bij.inverse(part))
    return inverted

  return fn
def bootstrap_results(self, init_state):
  """Creates initial `UncalibratedRandomWalkResults` for `init_state`.

  Args:
    init_state: A single value or list-like of values; each is converted with
      `tf.convert_to_tensor`.

  Returns:
    An `UncalibratedRandomWalkResults` whose `target_log_prob` is
    `self.target_log_prob_fn(*init_state)` and whose
    `log_acceptance_correction` is zeros of the same shape.
  """
  with tf.name_scope(self.name, 'rwm_bootstrap_results', [init_state]):
    if mcmc_util.is_list_like(init_state):
      state_parts = init_state
    else:
      state_parts = [init_state]
    state_parts = [tf.convert_to_tensor(part) for part in state_parts]
    initial_log_prob = self.target_log_prob_fn(*state_parts)
    zero_correction = tf.zeros_like(initial_log_prob)
    return UncalibratedRandomWalkResults(
        log_acceptance_correction=zero_correction,
        target_log_prob=initial_log_prob)
def forward_transform_fn(bijector):
  """Returns a function applying each bijector's `forward` to a state part.

  Args:
    bijector: A single bijector or a list-like of bijectors. A single bijector
      is wrapped in a one-element list.

  Returns:
    A function mapping a sequence of transformed state parts to a list of
    `b.forward(part)` values, pairing bijectors and parts with `zip`.
  """
  bijectors = bijector if is_list_like(bijector) else [bijector]

  def fn(transformed_state_parts):
    forwarded = []
    for bij, part in zip(bijectors, transformed_state_parts):
      forwarded.append(bij.forward(part))
    return forwarded

  return fn
def one_step(self, current_state, previous_kernel_results):
  """Runs one iteration of the Transformed Kernel.

  The inner kernel is stepped from
  `previous_kernel_results.transformed_state`; `current_state` itself is not
  read by this method. The inner kernel's next state is then mapped through
  `self._forward_transform`.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s), _after_ application of
      `bijector.forward`.
    previous_kernel_results: `collections.namedtuple` with fields
      `transformed_state` and `inner_results`, from a previous call to this
      function (or from the `bootstrap_results` function).

  Returns:
    next_state: The forward-transformed next state: a list when the inner
      kernel returned a list-like state, otherwise a single element.
    kernel_results: `TransformedTransitionKernelResults` holding the inner
      kernel's (untransformed-space) next state and its results.
  """
  scope_name = make_name(self.name, 'transformed_kernel', 'one_step')
  with tf.name_scope(name=scope_name, values=[previous_kernel_results]):
    inner_next_state, inner_results = self._inner_kernel.one_step(
        previous_kernel_results.transformed_state,
        previous_kernel_results.inner_results)

    inner_state_is_list = is_list_like(inner_next_state)
    if inner_state_is_list:
      inner_parts = inner_next_state
    else:
      inner_parts = [inner_next_state]

    forwarded_parts = self._forward_transform(inner_parts)
    # Mirror the structure (list vs. single element) of the inner state.
    next_state = forwarded_parts if inner_state_is_list else forwarded_parts[0]

    return next_state, TransformedTransitionKernelResults(
        transformed_state=inner_next_state,
        inner_results=inner_results)
def forward_log_det_jacobian_fn(bijector):
  """Returns a function summing the bijectors' `forward_log_det_jacobian`s.

  Args:
    bijector: A single bijector or a list-like of bijectors. A single bijector
      is wrapped in a one-element list.

  Returns:
    A function of `(transformed_state_parts, event_ndims)` which zips the
    bijectors with both sequences and returns the sum of
    `b.forward_log_det_jacobian(part, event_ndims=e)` over the triples.
  """
  bijectors = bijector if is_list_like(bijector) else [bijector]

  def fn(transformed_state_parts, event_ndims):
    total = 0
    for bij, ndims, part in zip(
        bijectors, event_ndims, transformed_state_parts):
      total = total + bij.forward_log_det_jacobian(part, event_ndims=ndims)
    return total

  return fn
def _maybe_call_fn(fn, fn_arg_list, fn_result=None,
                   description='target_log_prob'):
  """Returns `fn_result`, computing it as `fn(*fn_arg_list)` when it is `None`.

  Args:
    fn: Callable invoked only when `fn_result` is `None`.
    fn_arg_list: A single argument or list-like of arguments for `fn`.
    fn_result: Optional precomputed result.
    description: Name used in the error message.

  Returns:
    The supplied or newly computed result.

  Raises:
    TypeError: if the result's `dtype` is not floating point.
  """
  if mcmc_util.is_list_like(fn_arg_list):
    call_args = list(fn_arg_list)
  else:
    call_args = [fn_arg_list]

  result = fn(*call_args) if fn_result is None else fn_result

  # The dtype check applies to supplied and computed results alike.
  if result.dtype.is_floating:
    return result
  raise TypeError('`{}` must be a `Tensor` with `float` `dtype`.'.format(
      description))
def _prepare_args(target_log_prob_fn, state, step_size,
                  target_log_prob=None, grads_target_log_prob=None,
                  maybe_expand=False, state_gradients_are_stopped=False):
  """Normalizes args into aligned lists and obtains log-prob and gradients.

  Args:
    target_log_prob_fn: Callable forwarded to
      `mcmc_util.maybe_call_fn_and_grads`.
    state: A single value or a list-like of values; each is converted with
      `tf.convert_to_tensor`.
    step_size: A single value or a list-like of values. A single step size is
      repeated once per state part.
    target_log_prob: Optional value forwarded to
      `mcmc_util.maybe_call_fn_and_grads`.
    grads_target_log_prob: Optional value forwarded to
      `mcmc_util.maybe_call_fn_and_grads`.
    maybe_expand: If truthy, state parts and step sizes are always returned as
      lists, even when `state` was not list-like.
    state_gradients_are_stopped: If truthy, each state part is wrapped in
      `tf.stop_gradient`.

  Returns:
    A four-element list
    `[state_parts, step_sizes, target_log_prob, grads_target_log_prob]`. The
    first two entries are unwrapped to their sole element when `state` is not
    list-like and `maybe_expand` is falsy.

  Raises:
    ValueError: if the number of step sizes is neither one nor the number of
      state parts.
  """
  if mcmc_util.is_list_like(state):
    state_parts = list(state)
  else:
    state_parts = [state]
  state_parts = [
      tf.convert_to_tensor(part, name='current_state') for part in state_parts
  ]
  if state_gradients_are_stopped:
    state_parts = [tf.stop_gradient(part) for part in state_parts]

  target_log_prob, grads_target_log_prob = mcmc_util.maybe_call_fn_and_grads(
      target_log_prob_fn,
      state_parts,
      target_log_prob,
      grads_target_log_prob)

  if mcmc_util.is_list_like(step_size):
    step_sizes = list(step_size)
  else:
    step_sizes = [step_size]
  step_sizes = [
      tf.convert_to_tensor(
          size, name='step_size', dtype=target_log_prob.dtype)
      for size in step_sizes
  ]

  # A lone step size is shared by all state parts.
  if len(step_sizes) == 1:
    step_sizes = step_sizes * len(state_parts)
  if len(state_parts) != len(step_sizes):
    raise ValueError('There should be exactly one `step_size` or it should '
                     'have same length as `current_state`.')

  def maybe_flatten(x):
    if maybe_expand or mcmc_util.is_list_like(state):
      return x
    return x[0]

  flat_state = maybe_flatten(state_parts)
  flat_step_sizes = maybe_flatten(step_sizes)
  return [
      flat_state,
      flat_step_sizes,
      target_log_prob,
      grads_target_log_prob,
  ]
def bootstrap_results(self, init_state):
  """Creates initial `SliceSamplerKernelResults` for `init_state`.

  Args:
    init_state: A single value or list-like of values; each is converted with
      `tf.convert_to_tensor`.

  Returns:
    A `SliceSamplerKernelResults` with `target_log_prob` evaluated at
    `init_state`, an all-`False` `bounds_satisfied`, an all-zero `direction`
    per state part, and all-zero `upper_bounds` and `lower_bounds` shaped like
    the target log-prob.
  """
  scope_name = mcmc_util.make_name(self.name, 'slice', 'bootstrap_results')
  with tf.name_scope(name=scope_name, values=[init_state]):
    if mcmc_util.is_list_like(init_state):
      state_parts = init_state
    else:
      state_parts = [init_state]
    state_parts = [tf.convert_to_tensor(part) for part in state_parts]
    zero_direction = [tf.zeros_like(part) for part in state_parts]
    initial_log_prob = self.target_log_prob_fn(*state_parts)  # pylint:disable=not-callable

    # `tf.zeros` with a bool dtype yields all-False.
    bounds_flags = tf.zeros(
        shape=tf.shape(initial_log_prob), dtype=tf.bool)
    zero_upper = tf.zeros_like(initial_log_prob)
    zero_lower = tf.zeros_like(initial_log_prob)
    return SliceSamplerKernelResults(
        target_log_prob=initial_log_prob,
        bounds_satisfied=bounds_flags,
        direction=zero_direction,
        upper_bounds=zero_upper,
        lower_bounds=zero_lower,
    )
def _loop_body(iter_, ais_weights, current_state, kernel_results):
  """Closure which implements the `tf.while_loop` body.

  Evaluates the proposal and target log-probs at `current_state`, adds their
  difference divided by `num_steps` to `ais_weights`, then advances the state
  with a kernel built for the iteration's convex-combined log-prob.
  `proposal_log_prob_fn`, `target_log_prob_fn`, `num_steps`, `make_kernel_fn`
  and `_make_convex_combined_log_prob_fn` are read from the enclosing scope.

  Returns:
    The list `[iter_ + 1, ais_weights, next_state, kernel_results]` where
    `kernel_results` is a new `AISResults`.
  """
  if mcmc_util.is_list_like(current_state):
    state_parts = current_state
  else:
    state_parts = [current_state]

  proposal_log_prob = proposal_log_prob_fn(*state_parts)
  target_log_prob = target_log_prob_fn(*state_parts)

  log_prob_gap = target_log_prob - proposal_log_prob
  ais_weights += log_prob_gap / tf.cast(num_steps, ais_weights.dtype)

  annealed_log_prob_fn = _make_convex_combined_log_prob_fn(iter_)
  annealed_kernel = make_kernel_fn(annealed_log_prob_fn)
  next_state, inner_results = annealed_kernel.one_step(
      current_state, kernel_results.inner_results)

  updated_results = AISResults(
      proposal_log_prob=proposal_log_prob,
      target_log_prob=target_log_prob,
      inner_results=inner_results,
  )
  return [iter_ + 1, ais_weights, next_state, updated_results]
def step_size_simple_update_fn(step_size_var, kernel_results):
  """Updates (list of) `step_size` using a standard adaptive MCMC procedure.

  The log of the mean acceptance ratio (each ratio capped at 1) is compared
  with `log(target_rate)`. Below the target, the step size is multiplied by
  `1 - decrement_multiplier / (1 + decrement_multiplier)`; otherwise by
  `1 + increment_multiplier`. `target_rate`, `decrement_multiplier`,
  `increment_multiplier`, `num_adaptation_steps` and `step_counter` are read
  from the enclosing scope.

  Args:
    step_size_var: (List of) `tf.Variable`s representing the per `state_part`
      HMC `step_size`.
    kernel_results: `collections.namedtuple` containing `Tensor`s
      representing values from most recent call to `one_step`. If `None`, the
      step size(s) are returned through `tf.identity` without modification.

  Returns:
    step_size_assign: (List of) `Tensor`(s) representing updated
      `step_size_var`(s).
  """
  step_sizes_are_list = mcmc_util.is_list_like(step_size_var)

  if kernel_results is None:
    if step_sizes_are_list:
      return [tf.identity(var) for var in step_size_var]
    return tf.identity(step_size_var)

  log_accept_ratio = kernel_results.log_accept_ratio
  num_ratios = tf.cast(tf.size(log_accept_ratio), log_accept_ratio.dtype)
  log_n = tf.log(num_ratios)
  # log(mean(min(1, accept_ratio))), computed in log space.
  capped_log_ratio = tf.minimum(log_accept_ratio, 0.)
  log_mean_accept_ratio = tf.reduce_logsumexp(capped_log_ratio) - log_n

  below_target = log_mean_accept_ratio < tf.cast(
      tf.log(target_rate), log_mean_accept_ratio.dtype)
  adjustment = tf.where(
      below_target,
      -decrement_multiplier / (1. + decrement_multiplier),
      increment_multiplier)

  def build_assign_op():
    def _scaled_increment(var):
      return var.assign_add(var * tf.cast(adjustment, var.dtype))

    if mcmc_util.is_list_like(step_size_var):
      return [_scaled_increment(var) for var in step_size_var]
    return _scaled_increment(step_size_var)

  if num_adaptation_steps is None:
    return build_assign_op()

  # Count this step before deciding whether adaptation is still active.
  with tf.control_dependencies([step_counter.assign_add(1)]):
    return tf.cond(step_counter < num_adaptation_steps,
                   build_assign_op,
                   lambda: step_size_var)
def bootstrap_results(self, init_state):
  """Returns an object with the same type as returned by `one_step`.

  Every replica kernel is bootstrapped from the same `init_state`, and every
  replica receives its own `tf.identity` copy of each state part. The
  exchange bookkeeping fields are zeros shaped like the output of
  `self.exchange_proposed_fn`.

  Args:
    init_state: `Tensor` or Python `list` of `Tensor`s representing the
      state(s) of the Markov chain(s).

  Returns:
    kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
      `Tensor`s representing internal calculations made within this function.
      This includes replica states.
  """
  scope_name = mcmc_util.make_name(self.name, 'remc', 'bootstrap_results')
  with tf.name_scope(name=scope_name, values=[init_state]):
    replica_results = []
    for replica in range(self.num_replica):
      replica_results.append(
          self.replica_kernels[replica].bootstrap_results(init_state))

    state_is_list = mcmc_util.is_list_like(init_state)
    init_state_parts = list(init_state) if state_is_list else [init_state]

    replica_states = []
    for _ in range(self.num_replica):
      replica_states.append([tf.identity(part) for part in init_state_parts])
    # Restore the caller's structure (list vs. single `Tensor`) per replica.
    replica_states = [
        parts if mcmc_util.is_list_like(init_state) else parts[0]
        for parts in replica_states
    ]

    next_replica_idx = tf.range(self.num_replica)

    # Only the shapes/dtypes of the proposal outputs are used here.
    proposal = self.exchange_proposed_fn(
        self.num_replica, seed=self._seed_stream)
    [exchange_proposed, exchange_proposed_n] = proposal
    exchange_proposed = tf.zeros_like(exchange_proposed)
    exchange_proposed_n = tf.zeros_like(exchange_proposed_n)

    return ReplicaExchangeMCKernelResults(
        replica_states=replica_states,
        replica_results=replica_results,
        next_replica_idx=next_replica_idx,
        exchange_proposed=exchange_proposed,
        exchange_proposed_n=exchange_proposed_n,
        sampled_replica_states=replica_states,
        sampled_replica_results=replica_results,
    )
def bootstrap_results(self, init_state):
  """Creates initial `UncalibratedHamiltonianMonteCarloKernelResults`.

  Args:
    init_state: A single value or list-like of values. Each part is passed
      through `tf.stop_gradient` when `self.state_gradients_are_stopped` is
      truthy, and through `tf.convert_to_tensor` otherwise.

  Returns:
    An `UncalibratedHamiltonianMonteCarloKernelResults` holding the target
    log-prob and its gradients at `init_state`, with an all-zero
    `log_acceptance_correction`.
  """
  scope_name = mcmc_util.make_name(self.name, 'hmc', 'bootstrap_results')
  with tf.name_scope(name=scope_name, values=[init_state]):
    if mcmc_util.is_list_like(init_state):
      state_parts = init_state
    else:
      state_parts = [init_state]

    if self.state_gradients_are_stopped:
      prepare_part = tf.stop_gradient
    else:
      prepare_part = tf.convert_to_tensor
    state_parts = [prepare_part(part) for part in state_parts]

    log_prob_and_grads = mcmc_util.maybe_call_fn_and_grads(
        self.target_log_prob_fn, state_parts)
    [initial_log_prob, initial_grads] = log_prob_and_grads

    return UncalibratedHamiltonianMonteCarloKernelResults(
        log_acceptance_correction=tf.zeros_like(initial_log_prob),
        target_log_prob=initial_log_prob,
        grads_target_log_prob=initial_grads,
    )
def one_step(self, current_state, previous_kernel_results):
  """Runs one iteration of Hamiltonian Monte Carlo.

  When `self.step_size_update_fn` is set, the step is taken under a control
  dependency on the previous step-size assignment, and a new assignment is
  stored in the returned results' `extra` field.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s). The first `r` dimensions index
      independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`.
    previous_kernel_results: `collections.namedtuple` containing `Tensor`s
      representing values from previous calls to this function (or from the
      `bootstrap_results` function.)

  Returns:
    next_state: Tensor or Python list of `Tensor`s representing the state(s)
      of the Markov chain(s) after taking exactly one step. Has same type and
      shape as `current_state`.
    kernel_results: `collections.namedtuple` of internal calculations used to
      advance the chain.

  Raises:
    ValueError: if there isn't one `step_size` or a list with same length as
      `current_state`.
  """
  adapts_step_size = self.step_size_update_fn is not None

  if adapts_step_size:
    pending_assign = previous_kernel_results.extra.step_size_assign
    if not mcmc_util.is_list_like(pending_assign):
      pending_assign = [pending_assign]
  else:
    pending_assign = []

  # The step must observe the step size written by the previous iteration.
  with tf.control_dependencies(pending_assign):
    next_state, kernel_results = self._impl.one_step(
        current_state, previous_kernel_results)
    if adapts_step_size:
      new_assign = self.step_size_update_fn(  # pylint: disable=not-callable
          self.step_size, kernel_results)
      extra = HamiltonianMonteCarloExtraKernelResults(
          step_size_assign=new_assign)
      kernel_results = kernel_results._replace(extra=extra)
    return next_state, kernel_results
def _fn(state_parts, seed):
  """Adds a uniform perturbation to the input state.

  Each state part is replaced by a draw from
  `Uniform[state_part - scale_part, state_part + scale_part)`. `scale` and
  `name` are read from the enclosing scope.

  Args:
    state_parts: A list of `Tensor`s of any shape and real dtype representing
      the state parts of the `current_state` of the Markov chain.
    seed: `int` or None. The random seed for this `Op`. If `None`, no seed is
      applied.
      Default value: `None`.

  Returns:
    perturbed_state_parts: A Python `list` of `Tensor`s. Has the same
      shape and type as the `state_parts`.

  Raises:
    ValueError: if `scale` does not broadcast with `state_parts`.
  """
  with tf.name_scope(name, 'random_walk_uniform_fn',
                     values=[state_parts, scale, seed]):
    # Copy a list-like `scale` before broadcasting. The in-place `*=` below
    # would otherwise grow the caller's own list, so a later call with a
    # different number of state parts would fail the length check.
    scales = list(scale) if mcmc_util.is_list_like(scale) else [scale]
    if len(scales) == 1:
      scales *= len(state_parts)
    if len(state_parts) != len(scales):
      raise ValueError('`scale` must broadcast with `state_parts`.')
    next_state_parts = []
    for scale_part, state_part in zip(scales, state_parts):
      # Mutate seed with each use.
      seed = distributions_util.gen_new_seed(
          seed, salt='random_walk_uniform_fn')
      next_state_parts.append(tf.random_uniform(
          minval=state_part - scale_part,
          maxval=state_part + scale_part,
          shape=tf.shape(state_part),
          dtype=state_part.dtype.base_dtype,
          seed=seed))
    return next_state_parts
def one_step(current_state, previous_kernel_results):
  """Deterministic stand-in step: scales the state and halves the results.

  A list-like state has part `i` multiplied by `dtype(i + 2)`; a single state
  is multiplied by `2.`. Every field of `previous_kernel_results` is
  multiplied by `0.5`; `grads_target_log_prob` is handled element-wise as a
  list. `dtype` is read from the enclosing scope.

  Returns:
    A `(next_state, kernel_results)` pair, with `kernel_results` rebuilt as
    the same type as `previous_kernel_results`.
  """
  # Build the next state.
  if is_list_like(current_state):
    next_state = [
        tf.identity(part * dtype(factor), name='next_state')
        for factor, part in enumerate(current_state, start=2)
    ]
  else:
    next_state = tf.identity(2. * current_state, name='next_state')

  # Build the next kernel results, visiting fields in sorted order.
  halved_fields = {}
  for field_name in sorted(previous_kernel_results._fields):
    if field_name != 'grads_target_log_prob':
      field_value = getattr(previous_kernel_results, field_name, None)
      halved_fields[field_name] = tf.identity(
          0.5 * field_value, name=field_name)
      continue
    halved_fields['grads_target_log_prob'] = [
        tf.identity(0.5 * grad, name='grad_target_log_prob')
        for grad in previous_kernel_results.grads_target_log_prob
    ]
  results_type = type(previous_kernel_results)
  kernel_results = results_type(**halved_fields)

  return next_state, kernel_results
def maybe_flatten(x):
  """Returns `x` when lists are wanted, else its first element.

  `x` is returned unchanged if `maybe_expand` is truthy or `state` is
  list-like (both read from the enclosing scope).
  """
  if maybe_expand or mcmc_util.is_list_like(state):
    return x
  return x[0]
def maybe_flatten(x):
  """Returns `x` if `current_state` is list-like, else `x[0]`.

  `current_state` is read from the enclosing scope.
  """
  if mcmc_util.is_list_like(current_state):
    return x
  return x[0]
def one_step(self, current_state, previous_kernel_results):
  """Takes one step of the TransitionKernel.

  The inner kernel proposes a state. The proposal is accepted wherever
  `log(u) < log_accept_ratio` for `u ~ Uniform[0, 1)`; the log ratio is the
  proposed target log-prob minus the previously accepted one, plus the
  proposal's `log_acceptance_correction` when it is present.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s).
    previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
      `list` of `Tensor`s representing internal calculations made within the
      previous call to this function (or as returned by `bootstrap_results`).

  Returns:
    next_state: `Tensor` or Python `list` of `Tensor`s representing the
      next state(s) of the Markov chain(s).
    kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
      `Tensor`s representing internal calculations made within this function.

  Raises:
    ValueError: if `inner_kernel` results doesn't contain the member
      "target_log_prob".
  """
  scope_name = mcmc_util.make_name(self.name, 'mh', 'one_step')
  with tf.name_scope(name=scope_name,
                     values=[current_state, previous_kernel_results]):
    accepted_results = previous_kernel_results.accepted_results

    # Let the inner kernel make a proposal.
    inner_step = self.inner_kernel.one_step(current_state, accepted_results)
    [proposed_state, proposed_results] = inner_step

    if not (has_target_log_prob(proposed_results) and
            has_target_log_prob(previous_kernel_results.accepted_results)):
      raise ValueError('"target_log_prob" must be a member of '
                       '`inner_kernel` results.')

    # Collect the terms of log(acceptance_ratio).
    log_ratio_terms = [
        proposed_results.target_log_prob,
        -previous_kernel_results.accepted_results.target_log_prob,
    ]
    try:
      correction = proposed_results.log_acceptance_correction
      # An empty list-like correction contributes nothing to the sum.
      if not mcmc_util.is_list_like(correction) or correction:
        log_ratio_terms.append(correction)
    except AttributeError:
      warnings.warn('Supplied inner `TransitionKernel` does not have a '
                    '`log_acceptance_correction`. Assuming its value is `0.`')
    log_accept_ratio = mcmc_util.safe_sum(
        log_ratio_terms, name='compute_log_accept_ratio')

    # A proposal that lowers the likelihood is accepted at random, one that
    # raises it is always accepted: u < min(1, accept_ratio) with
    # u ~ Uniform[0, 1), which is log(u) < log_accept_ratio.
    uniform_draw = tf.random_uniform(
        shape=tf.shape(proposed_results.target_log_prob),
        dtype=proposed_results.target_log_prob.dtype.base_dtype,
        seed=self._seed_stream())
    log_uniform = tf.log(uniform_draw)
    is_accepted = log_uniform < log_accept_ratio

    next_state = mcmc_util.choose(
        is_accepted,
        proposed_state,
        current_state,
        name='choose_next_state')

    chosen_results = mcmc_util.choose(
        is_accepted,
        proposed_results,
        previous_kernel_results.accepted_results,
        name='choose_inner_results')
    kernel_results = MetropolisHastingsKernelResults(
        accepted_results=chosen_results,
        is_accepted=is_accepted,
        log_accept_ratio=log_accept_ratio,
        proposed_state=proposed_state,
        proposed_results=proposed_results,
        extra=[],
    )

    return next_state, kernel_results
def bootstrap_results(self, init_state=None, transformed_init_state=None):
  """Returns an object with the same type as returned by `one_step`.

  Unlike other `TransitionKernel`s,
  `TransformedTransitionKernel.bootstrap_results` can initialize the
  `TransformedTransitionKernelResults` either from an initial state, which is
  passed through `self._inverse_transform`, or directly from
  `transformed_init_state`, i.e., a `Tensor` or list of `Tensor`s which is
  interpreted as the already inverse-transformed state.

  Args:
    init_state: `Tensor` or Python `list` of `Tensor`s representing the
      state(s) of the Markov chain(s). Must specify `init_state` or
      `transformed_init_state` but not both.
    transformed_init_state: `Tensor` or Python `list` of `Tensor`s
      representing the state(s) of the Markov chain(s). Must specify
      `init_state` or `transformed_init_state` but not both.

  Returns:
    kernel_results: A `TransformedTransitionKernelResults` holding the
      transformed state and the inner kernel's bootstrap results for it.

  Raises:
    ValueError: if both or neither of `init_state` and
      `transformed_init_state` are specified.

  #### Examples

  To use `transformed_init_state` in context of
  `tfp.mcmc.sample_chain`, you need to explicitly pass the
  `previous_kernel_results`, e.g.,

  ```python
  transformed_kernel = tfp.mcmc.TransformedTransitionKernel(...)
  init_state = ...  # Doesn't matter.
  transformed_init_state = ...  # Does matter.
  results, _ = tfp.mcmc.sample_chain(
      num_results=...,
      current_state=init_state,
      previous_kernel_results=transformed_kernel.bootstrap_results(
          transformed_init_state=transformed_init_state),
      kernel=transformed_kernel)
  ```
  """
  have_init = init_state is not None
  have_transformed = transformed_init_state is not None
  if have_init == have_transformed:
    raise ValueError('Must specify exactly one of `init_state` '
                     'or `transformed_init_state`.')

  scope_name = make_name(
      self.name, 'transformed_kernel', 'bootstrap_results')
  with tf.compat.v1.name_scope(
      name=scope_name, values=[init_state, transformed_init_state]):
    if transformed_init_state is not None:
      # Use the supplied transformed state, converted to `Tensor`(s).
      if is_list_like(transformed_init_state):
        transformed_init_state = [
            tf.convert_to_tensor(value=part, name='transformed_init_state')
            for part in transformed_init_state
        ]
      else:
        transformed_init_state = tf.convert_to_tensor(
            value=transformed_init_state, name='transformed_init_state')
    else:
      # Derive the transformed state from `init_state`.
      if is_list_like(init_state):
        init_state_parts = init_state
      else:
        init_state_parts = [init_state]
      transformed_parts = self._inverse_transform(init_state_parts)
      if is_list_like(init_state):
        transformed_init_state = transformed_parts
      else:
        transformed_init_state = transformed_parts[0]

    inner_results = self._inner_kernel.bootstrap_results(
        transformed_init_state)
    return TransformedTransitionKernelResults(
        transformed_state=transformed_init_state,
        inner_results=inner_results)
def maybe_flatten(x):
  """Returns `x` when lists are wanted, else its first element.

  `x` is returned unchanged if `maybe_expand` is truthy or `state` is
  list-like (both read from the enclosing scope).
  """
  keep_as_list = maybe_expand or mcmc_util.is_list_like(state)
  if keep_as_list:
    return x
  return x[0]
def sample_chain(num_results, current_state, previous_kernel_results=None,
                 kernel=None, num_burnin_steps=0, num_steps_between_results=0,
                 parallel_iterations=10, name=None):
  """Implements Markov chain Monte Carlo via repeated `TransitionKernel` steps.

  This function samples from a Markov chain at `current_state` and whose
  stationary distribution is governed by the supplied `TransitionKernel`
  instance (`kernel`).

  This function can sample from multiple chains, in parallel. (Whether or not
  there are multiple chains is dictated by the `kernel`.)

  The `current_state` can be represented as a single `Tensor` or a `list` of
  `Tensors` which collectively represent the current state.

  Since MCMC states are correlated, it is sometimes desirable to produce
  additional intermediate states, and then discard them, ending up with a set
  of states with decreased autocorrelation.  See [Owen (2017)][1]. Such
  "thinning" is made possible by setting `num_steps_between_results > 0`. The
  chain then takes `num_steps_between_results` extra steps between the steps
  that make it into the results. The extra steps are never materialized (in
  calls to `sess.run`), and thus do not increase memory requirements.

  Warning: when setting a `seed` in the `kernel`, ensure that `sample_chain`'s
  `parallel_iterations=1`, otherwise results will not be reproducible.

  Args:
    num_results: Integer number of Markov chain draws.
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s).
    previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
      `list` of `Tensor`s representing internal calculations made within the
      previous call to this function (or as returned by `bootstrap_results`).
      If `None`, `kernel.bootstrap_results(current_state)` is used.
    kernel: An instance of `tfp.mcmc.TransitionKernel` which implements one
      step of the Markov chain. Although it defaults to `None` (because it
      follows an optional argument), it is required.
    num_burnin_steps: Integer number of chain steps to take before starting to
      collect results.
      Default value: 0 (i.e., no burn-in).
    num_steps_between_results: Integer number of chain steps between
      collecting a result. Only one out of every
      `num_steps_between_results + 1` steps is included in the returned
      results.  The number of returned chain states is still equal to
      `num_results`.
      Default value: 0 (i.e., no thinning).
    parallel_iterations: The number of iterations allowed to run in parallel.
      It must be a positive integer. See `tf.while_loop` for more details.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., "mcmc_sample_chain").

  Returns:
    next_states: Tensor or Python list of `Tensor`s representing the
      state(s) of the Markov chain(s) at each result step. Has same shape as
      input `current_state` but with a prepended `num_results`-size dimension.
    kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
      `Tensor`s representing internal calculations made within this function.

  Raises:
    ValueError: if `kernel` is `None`.

  #### Examples

  ##### Sample from a diagonal-variance Gaussian.

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  def make_likelihood(true_variances):
    return tfd.MultivariateNormalDiag(
        scale_diag=tf.sqrt(true_variances))

  dims = 10
  dtype = np.float32
  true_variances = tf.linspace(dtype(1), dtype(3), dims)
  likelihood = make_likelihood(true_variances)

  states, kernel_results = tfp.mcmc.sample_chain(
      num_results=1000,
      current_state=tf.zeros(dims),
      kernel=tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=likelihood.log_prob,
        step_size=0.5,
        num_leapfrog_steps=2),
      num_burnin_steps=500)

  # Compute sample stats.
  sample_mean = tf.reduce_mean(states, axis=0)
  sample_var = tf.reduce_mean(
      tf.squared_difference(states, sample_mean),
      axis=0)
  ```

  ##### Sampling from factor-analysis posteriors with known factors.

  I.e.,

  ```none
  for i=1..n:
    w[i] ~ Normal(0, eye(d))            # prior
    x[i] ~ Normal(loc=matmul(w[i], F))  # likelihood
  ```

  where `F` denotes factors.

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  def make_prior(dims, dtype):
    return tfd.MultivariateNormalDiag(
        loc=tf.zeros(dims, dtype))

  def make_likelihood(weights, factors):
    return tfd.MultivariateNormalDiag(
        loc=tf.tensordot(weights, factors, axes=[[0], [-1]]))

  # Setup data.
  num_weights = 10
  num_factors = 4
  num_chains = 100
  dtype = np.float32

  prior = make_prior(num_weights, dtype)
  weights = prior.sample(num_chains)
  factors = np.random.randn(num_factors, num_weights).astype(dtype)
  x = make_likelihood(weights, factors).sample(num_chains)

  def target_log_prob(w):
    # Target joint is: `f(w) = p(w, x | factors)`.
    return prior.log_prob(w) + make_likelihood(w, factors).log_prob(x)

  # Get `num_results` samples from `num_chains` independent chains.
  chains_states, kernels_results = tfp.mcmc.sample_chain(
      num_results=1000,
      current_state=tf.zeros([num_chains, num_weights], dtype),
      kernel=tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=target_log_prob,
        step_size=0.1,
        num_leapfrog_steps=2),
      num_burnin_steps=500)

  # Compute sample stats.
  sample_mean = tf.reduce_mean(chains_states, axis=[0, 1])
  sample_var = tf.reduce_mean(
      tf.squared_difference(chains_states, sample_mean),
      axis=[0, 1])
  ```

  #### References

  [1]: Art B. Owen. Statistically efficient thinning of a Markov chain
       sampler. _Technical Report_, 2017.
       http://statweb.stanford.edu/~owen/reports/bestthinning.pdf
  """
  # `kernel` only has a default because it follows an optional parameter.
  # Fail with a clear message instead of an AttributeError on `None`.
  if kernel is None:
    raise ValueError("`kernel` must be specified.")

  if not kernel.is_calibrated:
    warnings.warn(
        "Supplied `TransitionKernel` is not calibrated. Markov "
        "chain may not converge to intended target distribution.")

  with tf.name_scope(
      name, "mcmc_sample_chain",
      [num_results, num_burnin_steps, num_steps_between_results]):
    num_results = tf.convert_to_tensor(
        num_results,
        dtype=tf.int32,
        name="num_results")
    num_burnin_steps = tf.convert_to_tensor(
        num_burnin_steps,
        dtype=tf.int64,
        name="num_burnin_steps")
    num_steps_between_results = tf.convert_to_tensor(
        num_steps_between_results,
        dtype=tf.int64,
        name="num_steps_between_results")

    if mcmc_util.is_list_like(current_state):
      current_state = [
          tf.convert_to_tensor(s, name="current_state")
          for s in current_state
      ]
    else:
      current_state = tf.convert_to_tensor(
          current_state, name="current_state")

    def _scan_body(args_list, num_steps):
      """Closure which implements `tf.scan` body."""
      next_state, current_kernel_results = mcmc_util.smart_for_loop(
          loop_num_iter=num_steps,
          body_fn=kernel.one_step,
          initial_loop_vars=args_list,
          parallel_iterations=parallel_iterations)
      return [next_state, current_kernel_results]

    if previous_kernel_results is None:
      previous_kernel_results = kernel.bootstrap_results(current_state)

    # `elems` holds the number of kernel steps to take for each result: the
    # first entry absorbs the burn-in, the rest absorb the thinning steps.
    return tf.scan(
        fn=_scan_body,
        elems=tf.one_hot(
            indices=0,
            depth=num_results,
            on_value=1 + num_burnin_steps,
            off_value=1 + num_steps_between_results,
            dtype=tf.int64),  # num_steps
        initializer=[current_state, previous_kernel_results],
        parallel_iterations=parallel_iterations)
def one_step(self, current_state, previous_kernel_results):
  """Takes one step of the TransitionKernel.

  Each replica kernel takes one step from its own previous state, proposed
  exchanges between replicas are applied, and each replica kernel's results
  are re-bootstrapped from its post-exchange state. Replica 0's state is
  returned as the next state.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s).
    previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
      `list` of `Tensor`s representing internal calculations made within the
      previous call to this function (or as returned by `bootstrap_results`).

  Returns:
    next_state: `Tensor` or Python `list` of `Tensor`s representing the
      next state(s) of the Markov chain(s).
    kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
      `Tensor`s representing internal calculations made within this function.
      This includes replica states.
  """
  # The kind, and even the number, of exchanges varies from call to call, so
  # exchanges have to be carried out dynamically (in while loops).
  scope_name = mcmc_util.make_name(self.name, 'remc', 'one_step')
  with tf.name_scope(name=scope_name,
                     values=[current_state, previous_kernel_results]):
    # Step every replica to obtain pre-exchange states and results.
    replica_steps = [
        kernel.one_step(previous_kernel_results.replica_states[i],
                        previous_kernel_results.replica_results[i])
        for i, kernel in enumerate(self.replica_kernels)
    ]
    sampled_replica_states, sampled_replica_results = zip(*replica_steps)
    sampled_replica_states = list(sampled_replica_states)
    sampled_replica_results = list(sampled_replica_results)

    states_are_lists = mcmc_util.is_list_like(sampled_replica_states[0])
    if not states_are_lists:
      # Work on lists of parts internally; unwrapped again further down.
      sampled_replica_states = [[s] for s in sampled_replica_states]
    num_state_parts = len(sampled_replica_states[0])

    dtype = sampled_replica_states[0][0].dtype

    # States go into TensorArrays because they are read and written at a
    # `Tensor` index `i`, which Python lists cannot do.
    # old_states[k][i] is the (old) state part k of replica i; `k` is known
    # statically, `i` is a `Tensor`.
    old_states = []
    for k in range(num_state_parts):
      old_states.append(
          tf.TensorArray(
              dtype,
              size=self.num_replica,
              dynamic_size=False,
              clear_after_read=False,
              tensor_array_name='old_states',
              # Part k has the same shape for every replica; use replica 0.
              element_shape=sampled_replica_states[0][k].shape))
    for k in range(num_state_parts):
      part_array = old_states[k]
      for i in range(self.num_replica):
        part_array = part_array.write(i, sampled_replica_states[i][k])
      old_states[k] = part_array

    exchange_proposed = self.exchange_proposed_fn(
        self.num_replica, seed=self._seed_stream())
    exchange_proposed_n = tf.shape(exchange_proposed)[0]

    exchanged_states = self._get_exchanged_states(
        old_states, exchange_proposed, exchange_proposed_n,
        sampled_replica_states, sampled_replica_results)

    # Replicas absent from every proposed exchange keep their old states.
    all_replica_idx = tf.range(self.num_replica)
    proposed_idx = tf.reshape(exchange_proposed, [-1])
    no_exchange_proposed, _ = tf.setdiff1d(all_replica_idx, proposed_idx)

    exchanged_states = self._insert_old_states_where_no_exchange_was_proposed(
        no_exchange_proposed, old_states, exchanged_states)

    next_replica_states = [
        [exchanged_states[k].read(i) for k in range(num_state_parts)]
        for i in range(self.num_replica)
    ]

    if not states_are_lists:
      next_replica_states = [parts[0] for parts in next_replica_states]
      sampled_replica_states = [parts[0] for parts in sampled_replica_states]

    # After the exchange step each replica is treated as starting anew, so
    # its kernel results are bootstrapped from the post-exchange state.
    next_replica_results = []
    for kernel, state in zip(self.replica_kernels, next_replica_states):
      next_replica_results.append(kernel.bootstrap_results(state))

    next_state = next_replica_states[0]  # Replica 0 is the returned state(s).

    kernel_results = ReplicaExchangeMCKernelResults(
        replica_states=next_replica_states,
        replica_results=next_replica_results,
        sampled_replica_states=sampled_replica_states,
        sampled_replica_results=sampled_replica_results,
    )

    return next_state, kernel_results
def sample_annealed_importance_chain(num_steps,
                                     proposal_log_prob_fn,
                                     target_log_prob_fn,
                                     current_state,
                                     make_kernel_fn,
                                     parallel_iterations=10,
                                     name=None):
  """Runs annealed importance sampling (AIS) to estimate normalizing constants.

  This function uses an MCMC transition operator (e.g., Hamiltonian Monte
  Carlo) to sample from a series of distributions that slowly interpolates
  between an initial "proposal" distribution:

  `exp(proposal_log_prob_fn(x) - proposal_log_normalizer)`

  and the target distribution:

  `exp(target_log_prob_fn(x) - target_log_normalizer)`,

  accumulating importance weights along the way. The product of these
  importance weights gives an unbiased estimate of the ratio of the
  normalizing constants of the initial distribution and the target
  distribution:

  `E[exp(ais_weights)] = exp(target_log_normalizer - proposal_log_normalizer)`.

  Note: When running in graph mode, `proposal_log_prob_fn` and
  `target_log_prob_fn` are called exactly three times (although this may be
  reduced to two times in the future).

  Args:
    num_steps: Integer number of Markov chain updates to run. More
      iterations means more expense, but smoother annealing between q
      and p, which in turn means exponentially lower variance for the
      normalizing constant estimator.
    proposal_log_prob_fn: Python callable that returns the log density of
      the initial distribution.
    target_log_prob_fn: Python callable which takes an argument like
      `current_state` (or `*current_state` if it's a list) and returns its
      (possibly unnormalized) log-density under the target distribution.
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s). The first `r` dimensions
      index independent chains,
      `r = tf.rank(target_log_prob_fn(*current_state))`.
    make_kernel_fn: Python `callable` which returns a
      `TransitionKernel`-like object. Must take one argument representing
      the `TransitionKernel`'s `target_log_prob_fn`. The
      `target_log_prob_fn` argument represents the `TransitionKernel`'s
      target log distribution. Note: `sample_annealed_importance_chain`
      creates a new `target_log_prob_fn` which is an interpolation between
      the supplied `target_log_prob_fn` and `proposal_log_prob_fn`; it is
      this interpolated function which is used as an argument to
      `make_kernel_fn`.
    parallel_iterations: The number of iterations allowed to run in
      parallel. It must be a positive integer. See `tf.while_loop` for
      more details.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., "sample_annealed_importance_chain").

  Returns:
    next_state: `Tensor` or Python list of `Tensor`s representing the
      state(s) of the Markov chain(s) at the final iteration. Has same
      shape as input `current_state`.
    ais_weights: Tensor with the estimated weight(s). Has shape matching
      `target_log_prob_fn(current_state)`.
    kernel_results: `collections.namedtuple` of internal calculations used
      to advance the chain.

  #### Examples

  ##### Estimate the normalizing constant of a log-gamma distribution.

  ```python
  tfd = tfp.distributions

  # Run 100 AIS chains in parallel
  num_chains = 100
  dims = 20
  dtype = np.float32

  proposal = tfd.MultivariateNormalDiag(
     loc=tf.zeros([dims], dtype=dtype))

  target = tfd.TransformedDistribution(
    distribution=tfd.Gamma(concentration=dtype(2),
                           rate=dtype(3)),
    bijector=tfp.bijectors.Invert(tfp.bijectors.Exp()),
    event_shape=[dims])

  chains_state, ais_weights, kernels_results = (
      tfp.mcmc.sample_annealed_importance_chain(
          num_steps=1000,
          proposal_log_prob_fn=proposal.log_prob,
          target_log_prob_fn=target.log_prob,
          current_state=proposal.sample(num_chains),
          make_kernel_fn=lambda tlp_fn: tfp.mcmc.HamiltonianMonteCarlo(
            target_log_prob_fn=tlp_fn,
            step_size=0.2,
            num_leapfrog_steps=2)))

  log_estimated_normalizer = (tf.reduce_logsumexp(ais_weights)
                              - np.log(num_chains))
  log_true_normalizer = tf.lgamma(2.) - 2. * tf.log(3.)
  ```

  ##### Estimate marginal likelihood of a Bayesian regression model.

  ```python
  tfd = tfp.distributions

  def make_prior(dims, dtype):
    return tfd.MultivariateNormalDiag(
        loc=tf.zeros(dims, dtype))

  def make_likelihood(weights, x):
    return tfd.MultivariateNormalDiag(
        loc=tf.tensordot(weights, x, axes=[[0], [-1]]))

  # Run 100 AIS chains in parallel
  num_chains = 100
  dims = 10
  dtype = np.float32

  # Make training data.
  x = np.random.randn(num_chains, dims).astype(dtype)
  true_weights = np.random.randn(dims).astype(dtype)
  y = np.dot(x, true_weights) + np.random.randn(num_chains)

  # Setup model.
  prior = make_prior(dims, dtype)
  def target_log_prob_fn(weights):
    return prior.log_prob(weights) + make_likelihood(weights, x).log_prob(y)

  proposal = tfd.MultivariateNormalDiag(
      loc=tf.zeros(dims, dtype))

  weight_samples, ais_weights, kernel_results = (
      tfp.mcmc.sample_annealed_importance_chain(
        num_steps=1000,
        proposal_log_prob_fn=proposal.log_prob,
        target_log_prob_fn=target_log_prob_fn,
        current_state=tf.zeros([num_chains, dims], dtype),
        make_kernel_fn=lambda tlp_fn: tfp.mcmc.HamiltonianMonteCarlo(
          target_log_prob_fn=tlp_fn,
          step_size=0.1,
          num_leapfrog_steps=2)))
  log_normalizer_estimate = (tf.reduce_logsumexp(ais_weights)
                             - np.log(num_chains))
  ```

  """
  with tf.compat.v1.name_scope(name, "sample_annealed_importance_chain",
                               [num_steps, current_state]):
    num_steps = tf.convert_to_tensor(
        value=num_steps, dtype=tf.int32, name="num_steps")
    if not mcmc_util.is_list_like(current_state):
      current_state = tf.convert_to_tensor(
          value=current_state, name="current_state")
    else:
      current_state = [
          tf.convert_to_tensor(value=part, name="current_state")
          for part in current_state
      ]

    def _make_convex_combined_log_prob_fn(iter_):
      """Builds the annealed log-prob used at iteration `iter_`."""
      def _fn(*args):
        proposal_lp = tf.identity(
            proposal_log_prob_fn(*args), name="proposal_log_prob")
        target_lp = tf.identity(
            target_log_prob_fn(*args), name="target_log_prob")
        float_dtype = proposal_lp.dtype.base_dtype
        # The mixing weight grows linearly and reaches 1 (pure target) on
        # the final iteration, iter_ == num_steps - 1.
        beta = (tf.cast(iter_ + 1, float_dtype) /
                tf.cast(num_steps, float_dtype))
        return tf.identity(beta * target_lp + (1. - beta) * proposal_lp,
                           name="convex_combined_log_prob")
      return _fn

    def _loop_body(iter_, ais_weights, current_state, kernel_results):
      """Closure which implements `tf.while_loop` body."""
      if mcmc_util.is_list_like(current_state):
        state_parts = current_state
      else:
        state_parts = [current_state]
      proposal_log_prob = proposal_log_prob_fn(*state_parts)
      target_log_prob = target_log_prob_fn(*state_parts)
      # Each step adds 1/num_steps of the target-vs-proposal log ratio,
      # evaluated at the state entering this step.
      weight_increment = ((target_log_prob - proposal_log_prob) /
                          tf.cast(num_steps, ais_weights.dtype))
      ais_weights = ais_weights + weight_increment
      kernel = make_kernel_fn(_make_convex_combined_log_prob_fn(iter_))
      next_state, inner_results = kernel.one_step(
          current_state, kernel_results.inner_results)
      updated_results = AISResults(
          proposal_log_prob=proposal_log_prob,
          target_log_prob=target_log_prob,
          inner_results=inner_results,
      )
      return [iter_ + 1, ais_weights, next_state, updated_results]

    def _bootstrap_results(init_state):
      """Creates first version of `previous_kernel_results`."""
      kernel = make_kernel_fn(_make_convex_combined_log_prob_fn(iter_=0))
      inner_results = kernel.bootstrap_results(init_state)

      combined_log_prob = inner_results.accepted_results.target_log_prob
      np_dtype = combined_log_prob.dtype.as_numpy_dtype
      fill_shape = tf.shape(input=combined_log_prob)
      # No step has been taken yet, so both log-prob slots are NaN
      # placeholders with the right shape and dtype.
      nan_proposal = tf.fill(fill_shape, np_dtype(np.nan),
                             name="bootstrap_proposal_log_prob")
      nan_target = tf.fill(fill_shape, np_dtype(np.nan),
                           name="target_target_log_prob")

      return AISResults(
          proposal_log_prob=nan_proposal,
          target_log_prob=nan_target,
          inner_results=inner_results,
      )

    previous_kernel_results = _bootstrap_results(current_state)
    inner_results = previous_kernel_results.inner_results

    proposed_lp = inner_results.proposed_results.target_log_prob
    accepted_lp = inner_results.accepted_results.target_log_prob
    initial_weights = tf.zeros(
        shape=tf.broadcast_dynamic_shape(
            tf.shape(input=proposed_lp),
            tf.shape(input=accepted_lp)),
        dtype=proposed_lp.dtype.base_dtype)

    loop_outputs = tf.while_loop(
        cond=lambda iter_, *args: iter_ < num_steps,
        body=_loop_body,
        loop_vars=[
            np.int32(0),  # iter_
            initial_weights,
            current_state,
            previous_kernel_results,
        ],
        parallel_iterations=parallel_iterations)
    [_, ais_weights, current_state, kernel_results] = loop_outputs

    return [current_state, ais_weights, kernel_results]
def build_assign_op():
  """Builds the op(s) that scale the step size by `1 + adjustment`.

  Returns:
    A single `assign_add` result when `step_size_var` is not list-like,
    otherwise a list with one `assign_add` result per step-size variable.
  """
  if not mcmc_util.is_list_like(step_size_var):
    return step_size_var.assign_add(step_size_var * adjustment)
  # One update per variable; each is incremented by its own scaled value.
  return [ss.assign_add(ss * adjustment) for ss in step_size_var]
def sample_chain(
    num_results,
    current_state,
    previous_kernel_results=None,
    kernel=None,
    num_burnin_steps=0,
    num_steps_between_results=0,
    parallel_iterations=10,
    name=None):
  """Implements Markov chain Monte Carlo via repeated `TransitionKernel` steps.

  This function samples from a Markov chain at `current_state` and whose
  stationary distribution is governed by the supplied `TransitionKernel`
  instance (`kernel`).

  This function can sample from multiple chains, in parallel. (Whether or
  not there are multiple chains is dictated by the `kernel`.)

  The `current_state` can be represented as a single `Tensor` or a `list`
  of `Tensors` which collectively represent the current state.

  Since MCMC states are correlated, it is sometimes desirable to produce
  additional intermediate states, and then discard them, ending up with a
  set of states with decreased autocorrelation. See [Owen (2017)][1]. Such
  "thinning" is made possible by setting `num_steps_between_results > 0`.
  The chain then takes `num_steps_between_results` extra steps between the
  steps that make it into the results. The extra steps are never
  materialized (in calls to `sess.run`), and thus do not increase memory
  requirements.

  Warning: when setting a `seed` in the `kernel`, ensure that
  `sample_chain`'s `parallel_iterations=1`, otherwise results will not be
  reproducible.

  Args:
    num_results: Integer number of Markov chain draws.
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s).
    previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
      `list` of `Tensor`s representing internal calculations made within
      the previous call to this function (or as returned by
      `bootstrap_results`). When `None`, it is created with
      `kernel.bootstrap_results(current_state)`.
    kernel: An instance of `tfp.mcmc.TransitionKernel` which implements one
      step of the Markov chain. Required; the `None` default exists only
      because of the argument ordering.
    num_burnin_steps: Integer number of chain steps to take before starting
      to collect results.
      Default value: 0 (i.e., no burn-in).
    num_steps_between_results: Integer number of chain steps between
      collecting a result. Only one out of every
      `num_steps_between_results + 1` steps is included in the returned
      results. The number of returned chain states is still equal to
      `num_results`.
      Default value: 0 (i.e., no thinning).
    parallel_iterations: The number of iterations allowed to run in
      parallel. It must be a positive integer. See `tf.while_loop` for
      more details.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., "mcmc_sample_chain").

  Returns:
    next_states: Tensor or Python list of `Tensor`s representing the
      state(s) of the Markov chain(s) at each result step. Has same shape
      as input `current_state` but with a prepended `num_results`-size
      dimension.
    kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
      `Tensor`s representing internal calculations made within this
      function and used to advance the chain.

  Raises:
    ValueError: if `kernel` is `None`.

  #### Examples

  ##### Sample from a diagonal-variance Gaussian.

  ```python
  import tensorflow as tf
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  def make_likelihood(true_variances):
    return tfd.MultivariateNormalDiag(
        scale_diag=tf.sqrt(true_variances))

  dims = 10
  dtype = np.float32
  true_variances = tf.linspace(dtype(1), dtype(3), dims)
  likelihood = make_likelihood(true_variances)

  states, kernel_results = tfp.mcmc.sample_chain(
      num_results=1000,
      current_state=tf.zeros(dims),
      kernel=tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=likelihood.log_prob,
        step_size=0.5,
        num_leapfrog_steps=2),
      num_burnin_steps=500)

  # Compute sample stats.
  sample_mean = tf.reduce_mean(states, axis=0)
  sample_var = tf.reduce_mean(
      tf.squared_difference(states, sample_mean),
      axis=0)
  ```

  ##### Sampling from factor-analysis posteriors with known factors.

  I.e.,

  ```none
  for i=1..n:
    w[i] ~ Normal(0, eye(d))            # prior
    x[i] ~ Normal(loc=matmul(w[i], F))  # likelihood
  ```

  where `F` denotes factors.

  ```python
  import tensorflow as tf
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  def make_prior(dims, dtype):
    return tfd.MultivariateNormalDiag(
        loc=tf.zeros(dims, dtype))

  def make_likelihood(weights, factors):
    return tfd.MultivariateNormalDiag(
        loc=tf.tensordot(weights, factors, axes=[[0], [-1]]))

  # Setup data.
  num_weights = 10
  num_factors = 4
  num_chains = 100
  dtype = np.float32

  prior = make_prior(num_weights, dtype)
  weights = prior.sample(num_chains)
  factors = np.random.randn(num_factors, num_weights).astype(dtype)
  x = make_likelihood(weights, factors).sample(num_chains)

  def target_log_prob(w):
    # Target joint is: `f(w) = p(w, x | factors)`.
    return prior.log_prob(w) + make_likelihood(w, factors).log_prob(x)

  # Get `num_results` samples from `num_chains` independent chains.
  chains_states, kernels_results = tfp.mcmc.sample_chain(
      num_results=1000,
      current_state=tf.zeros([num_chains, num_weights], dtype),
      kernel=tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=target_log_prob,
        step_size=0.1,
        num_leapfrog_steps=2),
      num_burnin_steps=500)

  # Compute sample stats.
  sample_mean = tf.reduce_mean(chains_states, axis=[0, 1])
  sample_var = tf.reduce_mean(
      tf.squared_difference(chains_states, sample_mean),
      axis=[0, 1])
  ```

  #### References

  [1]: Art B. Owen. Statistically efficient thinning of a Markov chain
       sampler. _Technical Report_, 2017.
       http://statweb.stanford.edu/~owen/reports/bestthinning.pdf
  """
  # Fail early with a clear message instead of an AttributeError on
  # `None.is_calibrated` below.
  if kernel is None:
    raise ValueError(
        "`kernel` must be a `TransitionKernel` instance; got `None`.")

  if not kernel.is_calibrated:
    warnings.warn("Supplied `TransitionKernel` is not calibrated. Markov "
                  "chain may not converge to intended target distribution.")
  with tf.name_scope(
      name, "mcmc_sample_chain",
      [num_results, num_burnin_steps, num_steps_between_results]):
    num_results = tf.convert_to_tensor(
        num_results,
        dtype=tf.int32,
        name="num_results")
    num_burnin_steps = tf.convert_to_tensor(
        num_burnin_steps,
        dtype=tf.int32,
        name="num_burnin_steps")
    num_steps_between_results = tf.convert_to_tensor(
        num_steps_between_results,
        dtype=tf.int32,
        name="num_steps_between_results")

    if mcmc_util.is_list_like(current_state):
      current_state = [tf.convert_to_tensor(s, name="current_state")
                       for s in current_state]
    else:
      current_state = tf.convert_to_tensor(current_state, name="current_state")

    def _scan_body(args_list, num_steps):
      """Closure which implements `tf.scan` body."""
      current_state, previous_kernel_results = args_list
      # Take `num_steps` kernel steps; only the final state and results
      # of this inner loop are emitted by the scan.
      return tf.while_loop(
          cond=lambda it_, *args: it_ < num_steps,
          body=lambda it_, cs, pkr: [it_ + 1] + list(kernel.one_step(cs, pkr)),
          loop_vars=[
              np.int32(0),  # it_
              current_state,
              previous_kernel_results,
          ],
          parallel_iterations=parallel_iterations)[1:]  # Lop off `it_`.

    if previous_kernel_results is None:
      previous_kernel_results = kernel.bootstrap_results(current_state)
    # `elems` is the per-result step count: the first result takes
    # `1 + num_burnin_steps` steps, every later one takes
    # `1 + num_steps_between_results` steps.
    return tf.scan(
        fn=_scan_body,
        elems=tf.one_hot(indices=0,
                         depth=num_results,
                         on_value=1 + num_burnin_steps,
                         off_value=1 + num_steps_between_results,
                         dtype=tf.int32),  # num_steps
        initializer=[current_state, previous_kernel_results],
        parallel_iterations=parallel_iterations)
def one_step(self, current_state, previous_kernel_results):
  """Proposes a state with the inner kernel, then accepts or rejects it.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s).
    previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
      `list` of `Tensor`s representing internal calculations made within
      the previous call to this function (or as returned by
      `bootstrap_results`). Its `accepted_results` field is what the inner
      kernel is stepped from.

  Returns:
    next_state: `Tensor` or Python `list` of `Tensor`s representing the
      next state(s) of the Markov chain(s).
    kernel_results: A `MetropolisHastingsKernelResults` with the chosen
      inner results, the acceptance mask, the log acceptance ratio, and
      the proposed state and results.

  Raises:
    ValueError: if `inner_kernel` results doesn't contain the member
      "target_log_prob".
  """
  with tf.name_scope(
      name=mcmc_util.make_name(self.name, 'mh', 'one_step'),
      values=[current_state, previous_kernel_results]):
    previous_accepted = previous_kernel_results.accepted_results

    # Take one inner step.
    proposed_state, proposed_results = self.inner_kernel.one_step(
        current_state, previous_accepted)

    both_have_log_prob = (has_target_log_prob(proposed_results) and
                          has_target_log_prob(previous_accepted))
    if not both_have_log_prob:
      raise ValueError('"target_log_prob" must be a member of '
                       '`inner_kernel` results.')

    # Terms of log(acceptance_ratio): proposed minus previous log-prob,
    # plus the inner kernel's correction when it supplies one.
    to_sum = [
        proposed_results.target_log_prob,
        -previous_accepted.target_log_prob,
    ]
    try:
      # A non-list correction is always added; a list-like one is added
      # only when it is non-empty.
      if (not mcmc_util.is_list_like(
          proposed_results.log_acceptance_correction)
          or proposed_results.log_acceptance_correction):
        to_sum.append(proposed_results.log_acceptance_correction)
    except AttributeError:
      warnings.warn(
          'Supplied inner `TransitionKernel` does not have a '
          '`log_acceptance_correction`. Assuming its value is `0.`')
    log_accept_ratio = mcmc_util.safe_sum(
        to_sum, name='compute_log_accept_ratio')

    # Accept when u < min(1, accept_ratio) with u ~ Uniform[0, 1), which
    # is the same as log(u) < log_accept_ratio. A proposal that raises the
    # likelihood is always accepted; one that lowers it is accepted at
    # random.
    proposed_log_prob = proposed_results.target_log_prob
    uniform_draw = tf.random_uniform(
        shape=tf.shape(proposed_log_prob),
        dtype=proposed_log_prob.dtype.base_dtype,
        seed=self._seed_stream())
    log_uniform = tf.log(uniform_draw)
    is_accepted = log_uniform < log_accept_ratio

    next_state = mcmc_util.choose(
        is_accepted,
        proposed_state,
        current_state,
        name='choose_next_state')

    chosen_inner_results = mcmc_util.choose(
        is_accepted,
        proposed_results,
        previous_accepted,
        name='choose_inner_results')

    return next_state, MetropolisHastingsKernelResults(
        accepted_results=chosen_inner_results,
        is_accepted=is_accepted,
        log_accept_ratio=log_accept_ratio,
        proposed_state=proposed_state,
        proposed_results=proposed_results,
        extra=[],
    )
def bootstrap_results(self, init_state=None, transformed_init_state=None):
  """Returns an object with the same type as returned by `one_step`.

  Unlike other `TransitionKernel`s,
  `TransformedTransitionKernel.bootstrap_results` can initialize the
  `TransformedTransitionKernelResults` in two ways: from an initial state,
  which requires computing `bijector.inverse(init_state)`, or directly
  from `transformed_init_state`, i.e., a `Tensor` or list of `Tensor`s
  which is interpreted as the `bijector.inverse` transformed state.

  Args:
    init_state: `Tensor` or Python `list` of `Tensor`s representing the
      state(s) of the Markov chain(s). Must specify `init_state` or
      `transformed_init_state` but not both.
    transformed_init_state: `Tensor` or Python `list` of `Tensor`s
      representing the already-transformed state(s) of the Markov
      chain(s). Must specify `init_state` or `transformed_init_state` but
      not both.

  Returns:
    kernel_results: A `TransformedTransitionKernelResults` holding the
      transformed state and the inner kernel's bootstrap results for it.

  Raises:
    ValueError: if both or neither of `init_state` and
      `transformed_init_state` are specified.

  #### Examples

  To use `transformed_init_state` in context of
  `tfp.mcmc.sample_chain`, you need to explicitly pass the
  `previous_kernel_results`, e.g.,

  ```python
  transformed_kernel = tfp.mcmc.TransformedTransitionKernel(...)
  init_state = ...        # Doesn't matter.
  transformed_init_state = ... # Does matter.
  results, _ = tfp.mcmc.sample_chain(
      num_results=...,
      current_state=init_state,
      previous_kernel_results=transformed_kernel.bootstrap_results(
          transformed_init_state=transformed_init_state),
      kernel=transformed_kernel)
  ```
  """
  have_init = init_state is not None
  have_transformed = transformed_init_state is not None
  if have_init == have_transformed:
    raise ValueError('Must specify exactly one of `init_state` '
                     'or `transformed_init_state`.')
  with tf.name_scope(
      name=make_name(self.name, 'transformed_kernel', 'bootstrap_results'),
      values=[init_state, transformed_init_state]):
    if not have_transformed:
      # Only `init_state` was given: map it into the transformed space,
      # returning a list or a single value to match the input.
      state_is_list = is_list_like(init_state)
      init_state_parts = init_state if state_is_list else [init_state]
      transformed_parts = self._inverse_transform(init_state_parts)
      if state_is_list:
        transformed_init_state = transformed_parts
      else:
        transformed_init_state = transformed_parts[0]
    elif is_list_like(transformed_init_state):
      transformed_init_state = [
          tf.convert_to_tensor(part, name='transformed_init_state')
          for part in transformed_init_state
      ]
    else:
      transformed_init_state = tf.convert_to_tensor(
          transformed_init_state, name='transformed_init_state')
    inner_bootstrap = self._inner_kernel.bootstrap_results(
        transformed_init_state)
    return TransformedTransitionKernelResults(
        transformed_state=transformed_init_state,
        inner_results=inner_bootstrap)
def _prepare_args(target_log_prob_fn,
                  volatility_fn,
                  state,
                  step_size,
                  target_log_prob=None,
                  grads_target_log_prob=None,
                  volatility=None,
                  grads_volatility_fn=None,
                  diffusion_drift=None,
                  parallel_iterations=10):
  """Helper which processes input args to meet list-like assumptions.

  Args:
    target_log_prob_fn: Callable forwarded, together with the state parts,
      to `mcmc_util.maybe_call_fn_and_grads`.
    volatility_fn: Callable forwarded to
      `_maybe_call_volatility_fn_and_grads`.
    state: A single state or a list-like of state parts; a single state is
      wrapped in a one-element list.
    step_size: A single step size or a list-like of step sizes. Each is
      converted to a `Tensor` with the dtype of `target_log_prob`. A single
      step size is repeated for every state part.
    target_log_prob: Optional precomputed value, forwarded to
      `mcmc_util.maybe_call_fn_and_grads`.
    grads_target_log_prob: Optional precomputed gradients, forwarded to
      `mcmc_util.maybe_call_fn_and_grads`.
    volatility: Optional precomputed volatility, forwarded to
      `_maybe_call_volatility_fn_and_grads`.
    grads_volatility_fn: Optional precomputed volatility gradients,
      forwarded to `_maybe_call_volatility_fn_and_grads`.
    diffusion_drift: Optional drift, either a single value or a list-like
      with one entry per state part. When `None`, the drift is computed
      with `_get_drift`.
    parallel_iterations: Forwarded to
      `_maybe_call_volatility_fn_and_grads`.

  Returns:
    A list `[state_parts, step_sizes, target_log_prob,
    grads_target_log_prob, volatility_parts, grads_volatility,
    diffusion_drift_parts]`.

  Raises:
    ValueError: if the number of step sizes (after repeating a single one)
      or the number of supplied drift parts differs from the number of
      state parts.
  """
  state_parts = list(state) if mcmc_util.is_list_like(state) else [state]

  [
      target_log_prob,
      grads_target_log_prob,
  ] = mcmc_util.maybe_call_fn_and_grads(target_log_prob_fn, state_parts,
                                        target_log_prob,
                                        grads_target_log_prob)
  [
      volatility_parts,
      grads_volatility,
  ] = _maybe_call_volatility_fn_and_grads(
      volatility_fn, state_parts, volatility, grads_volatility_fn,
      distribution_util.prefer_static_shape(target_log_prob),
      parallel_iterations)

  step_sizes = (list(step_size) if mcmc_util.is_list_like(step_size)
                else [step_size])
  step_sizes = [
      tf.convert_to_tensor(
          value=s, name='step_size', dtype=target_log_prob.dtype)
      for s in step_sizes
  ]
  # A single step size applies to every state part.
  if len(step_sizes) == 1:
    step_sizes *= len(state_parts)
  if len(state_parts) != len(step_sizes):
    raise ValueError(
        'There should be exactly one `step_size` or it should '
        'have same length as `current_state`.')

  if diffusion_drift is None:
    diffusion_drift_parts = _get_drift(step_sizes, volatility_parts,
                                       grads_volatility,
                                       grads_target_log_prob)
  else:
    diffusion_drift_parts = (list(diffusion_drift)
                             if mcmc_util.is_list_like(diffusion_drift)
                             else [diffusion_drift])
    # Compare the normalized list, not the raw argument: `len()` of a
    # single (non-list) drift would be its leading dimension or an error.
    if len(state_parts) != len(diffusion_drift_parts):
      raise ValueError(
          'There should be exactly one `diffusion_drift` or it '
          'should have same length as list-like `current_state`.')

  return [
      state_parts,
      step_sizes,
      target_log_prob,
      grads_target_log_prob,
      volatility_parts,
      grads_volatility,
      diffusion_drift_parts,
  ]
def sample_annealed_importance_chain(
    num_steps,
    proposal_log_prob_fn,
    target_log_prob_fn,
    current_state,
    make_kernel_fn,
    parallel_iterations=10,
    name=None):
  """Runs annealed importance sampling (AIS) to estimate normalizing constants.

  This function uses the `TransitionKernel` built by `make_kernel_fn`
  (e.g., Hamiltonian Monte Carlo) to sample from a series of distributions
  that slowly interpolates between an initial "proposal" distribution:

  `exp(proposal_log_prob_fn(x) - proposal_log_normalizer)`

  and the target distribution:

  `exp(target_log_prob_fn(x) - target_log_normalizer)`,

  accumulating importance weights along the way. The product of these
  importance weights gives an unbiased estimate of the ratio of the
  normalizing constants of the initial distribution and the target
  distribution:

  `E[exp(ais_weights)] = exp(target_log_normalizer - proposal_log_normalizer)`.

  Note: `proposal_log_prob_fn` and `target_log_prob_fn` are called exactly
  three times (although this may be reduced to two times, in the future).

  Args:
    num_steps: Integer number of Markov chain updates to run. More
      iterations means more expense, but smoother annealing between q
      and p, which in turn means exponentially lower variance for the
      normalizing constant estimator.
    proposal_log_prob_fn: Python callable that returns the log density of
      the initial distribution.
    target_log_prob_fn: Python callable which takes an argument like
      `current_state` (or `*current_state` if it's a list) and returns its
      (possibly unnormalized) log-density under the target distribution.
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s). The first `r` dimensions
      index independent chains,
      `r = tf.rank(target_log_prob_fn(*current_state))`.
    make_kernel_fn: Python `callable` which returns a
      `TransitionKernel`-like object. Must take one argument representing
      the `TransitionKernel`'s `target_log_prob_fn`. The
      `target_log_prob_fn` argument represents the `TransitionKernel`'s
      target log distribution. Note: `sample_annealed_importance_chain`
      creates a new `target_log_prob_fn` which is an interpolation between
      the supplied `target_log_prob_fn` and `proposal_log_prob_fn`; it is
      this interpolated function which is used as an argument to
      `make_kernel_fn`.
    parallel_iterations: The number of iterations allowed to run in
      parallel. It must be a positive integer. See `tf.while_loop` for
      more details.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., "sample_annealed_importance_chain").

  Returns:
    next_state: `Tensor` or Python list of `Tensor`s representing the
      state(s) of the Markov chain(s) at the final iteration. Has same
      shape as input `current_state`.
    ais_weights: Tensor with the estimated weight(s). Has shape matching
      `target_log_prob_fn(current_state)`.
    kernel_results: `collections.namedtuple` of internal calculations used
      to advance the chain.

  #### Examples

  ##### Estimate the normalizing constant of a log-gamma distribution.

  ```python
  tfd = tfp.distributions

  # Run 100 AIS chains in parallel
  num_chains = 100
  dims = 20
  dtype = np.float32

  proposal = tfd.MultivariateNormalDiag(
     loc=tf.zeros([dims], dtype=dtype))

  target = tfd.TransformedDistribution(
    distribution=tfd.Gamma(concentration=dtype(2),
                           rate=dtype(3)),
    bijector=tfp.bijectors.Invert(tfp.bijectors.Exp()),
    event_shape=[dims])

  chains_state, ais_weights, kernels_results = (
      tfp.mcmc.sample_annealed_importance_chain(
          num_steps=1000,
          proposal_log_prob_fn=proposal.log_prob,
          target_log_prob_fn=target.log_prob,
          current_state=proposal.sample(num_chains),
          make_kernel_fn=lambda tlp_fn: tfp.mcmc.HamiltonianMonteCarlo(
            target_log_prob_fn=tlp_fn,
            step_size=0.2,
            num_leapfrog_steps=2)))

  log_estimated_normalizer = (tf.reduce_logsumexp(ais_weights)
                              - np.log(num_chains))
  log_true_normalizer = tf.lgamma(2.) - 2. * tf.log(3.)
  ```

  ##### Estimate marginal likelihood of a Bayesian regression model.

  ```python
  tfd = tfp.distributions

  def make_prior(dims, dtype):
    return tfd.MultivariateNormalDiag(
        loc=tf.zeros(dims, dtype))

  def make_likelihood(weights, x):
    return tfd.MultivariateNormalDiag(
        loc=tf.tensordot(weights, x, axes=[[0], [-1]]))

  # Run 100 AIS chains in parallel
  num_chains = 100
  dims = 10
  dtype = np.float32

  # Make training data.
  x = np.random.randn(num_chains, dims).astype(dtype)
  true_weights = np.random.randn(dims).astype(dtype)
  y = np.dot(x, true_weights) + np.random.randn(num_chains)

  # Setup model.
  prior = make_prior(dims, dtype)
  def target_log_prob_fn(weights):
    return prior.log_prob(weights) + make_likelihood(weights, x).log_prob(y)

  proposal = tfd.MultivariateNormalDiag(
      loc=tf.zeros(dims, dtype))

  weight_samples, ais_weights, kernel_results = (
      tfp.mcmc.sample_annealed_importance_chain(
        num_steps=1000,
        proposal_log_prob_fn=proposal.log_prob,
        target_log_prob_fn=target_log_prob_fn,
        current_state=tf.zeros([num_chains, dims], dtype),
        make_kernel_fn=lambda tlp_fn: tfp.mcmc.HamiltonianMonteCarlo(
          target_log_prob_fn=tlp_fn,
          step_size=0.1,
          num_leapfrog_steps=2)))
  log_normalizer_estimate = (tf.reduce_logsumexp(ais_weights)
                             - np.log(num_chains))
  ```

  """
  with tf.name_scope(
      name, "sample_annealed_importance_chain",
      [num_steps, current_state]):
    num_steps = tf.convert_to_tensor(
        num_steps,
        dtype=tf.int32,
        name="num_steps")
    if not mcmc_util.is_list_like(current_state):
      current_state = tf.convert_to_tensor(
          current_state, name="current_state")
    else:
      current_state = [tf.convert_to_tensor(part, name="current_state")
                       for part in current_state]

    def _make_convex_combined_log_prob_fn(iter_):
      """Builds the annealed log-prob used at iteration `iter_`."""
      def _fn(*args):
        proposal_lp = tf.identity(
            proposal_log_prob_fn(*args), name="proposal_log_prob")
        target_lp = tf.identity(
            target_log_prob_fn(*args), name="target_log_prob")
        float_dtype = proposal_lp.dtype.base_dtype
        # The mixing weight grows linearly and reaches 1 (pure target) on
        # the final iteration, iter_ == num_steps - 1.
        beta = (tf.cast(iter_ + 1, float_dtype) /
                tf.cast(num_steps, float_dtype))
        return tf.identity(beta * target_lp + (1. - beta) * proposal_lp,
                           name="convex_combined_log_prob")
      return _fn

    def _loop_body(iter_, ais_weights, current_state, kernel_results):
      """Closure which implements `tf.while_loop` body."""
      if mcmc_util.is_list_like(current_state):
        state_parts = current_state
      else:
        state_parts = [current_state]
      proposal_log_prob = proposal_log_prob_fn(*state_parts)
      target_log_prob = target_log_prob_fn(*state_parts)
      # Each step adds 1/num_steps of the target-vs-proposal log ratio,
      # evaluated at the state entering this step.
      weight_increment = ((target_log_prob - proposal_log_prob) /
                          tf.cast(num_steps, ais_weights.dtype))
      ais_weights = ais_weights + weight_increment
      kernel = make_kernel_fn(_make_convex_combined_log_prob_fn(iter_))
      next_state, inner_results = kernel.one_step(
          current_state, kernel_results.inner_results)
      updated_results = AISResults(
          proposal_log_prob=proposal_log_prob,
          target_log_prob=target_log_prob,
          inner_results=inner_results,
      )
      return [iter_ + 1, ais_weights, next_state, updated_results]

    def _bootstrap_results(init_state):
      """Creates first version of `previous_kernel_results`."""
      kernel = make_kernel_fn(_make_convex_combined_log_prob_fn(iter_=0))
      inner_results = kernel.bootstrap_results(init_state)

      combined_log_prob = inner_results.accepted_results.target_log_prob
      np_dtype = combined_log_prob.dtype.as_numpy_dtype
      fill_shape = tf.shape(combined_log_prob)
      # No step has been taken yet, so both log-prob slots are NaN
      # placeholders with the right shape and dtype.
      nan_proposal = tf.fill(fill_shape, np_dtype(np.nan),
                             name="bootstrap_proposal_log_prob")
      nan_target = tf.fill(fill_shape, np_dtype(np.nan),
                           name="target_target_log_prob")

      return AISResults(
          proposal_log_prob=nan_proposal,
          target_log_prob=nan_target,
          inner_results=inner_results,
      )

    previous_kernel_results = _bootstrap_results(current_state)
    inner_results = previous_kernel_results.inner_results

    proposed_lp = inner_results.proposed_results.target_log_prob
    accepted_lp = inner_results.accepted_results.target_log_prob
    initial_weights = tf.zeros(
        shape=tf.broadcast_dynamic_shape(
            tf.shape(proposed_lp),
            tf.shape(accepted_lp)),
        dtype=proposed_lp.dtype.base_dtype)

    loop_outputs = tf.while_loop(
        cond=lambda iter_, *args: iter_ < num_steps,
        body=_loop_body,
        loop_vars=[
            np.int32(0),  # iter_
            initial_weights,
            current_state,
            previous_kernel_results,
        ],
        parallel_iterations=parallel_iterations)
    [_, ais_weights, current_state, kernel_results] = loop_outputs

    return [current_state, ais_weights, kernel_results]
def maybe_flatten(x):
  """Returns `x` unchanged, or `x[0]` when `current_state` is not list-like."""
  if mcmc_util.is_list_like(current_state):
    return x
  # The enclosing scope's state was a single value, so unwrap the
  # one-element container.
  return x[0]