def one_step(self, current_state, previous_kernel_results):
  """Draws a random-walk proposal and evaluates its target log-probability.

  Args:
    current_state: A single value or a list-like of values. Every part is
      converted to a `Tensor` before being passed to `self.new_state_fn`.
    previous_kernel_results: Results from the previous step; only its
      `target_log_prob` member is read here (as a name-scope value).

  Returns:
    A two-element list `[next_state, kernel_results]`. `next_state` is a
    list iff `current_state` was list-like; `kernel_results` is an
    `UncalibratedRandomWalkResults` whose `log_acceptance_correction` is
    all zeros with the shape and base dtype of the new `target_log_prob`.
  """
  scope = mcmc_util.make_name(self.name, 'rwm', 'one_step')
  scope_inputs = [
      self.seed, current_state, previous_kernel_results.target_log_prob]
  with tf.name_scope(name=scope, values=scope_inputs):
    with tf.name_scope('initialize'):
      if mcmc_util.is_list_like(current_state):
        raw_parts = list(current_state)
      else:
        raw_parts = [current_state]
      state_parts = [
          tf.convert_to_tensor(part, name='current_state')
          for part in raw_parts
      ]

    # Derive a new seed for this step and remember it on the kernel.
    self._seed_stream = distributions_util.gen_new_seed(
        self._seed_stream, salt='rwm_kernel_proposal')
    proposed_parts = self.new_state_fn(state_parts, self._seed_stream)

    # The log-prob is evaluated here so that it is available to
    # MetropolisHastings.
    proposed_log_prob = self.target_log_prob_fn(*proposed_parts)

    if mcmc_util.is_list_like(current_state):
      next_state = proposed_parts
    else:
      next_state = proposed_parts[0]

    zero_correction = tf.zeros(
        shape=tf.shape(proposed_log_prob),
        dtype=proposed_log_prob.dtype.base_dtype)
    results = UncalibratedRandomWalkResults(
        log_acceptance_correction=zero_correction,
        target_log_prob=proposed_log_prob,
    )
    return [next_state, results]
def bootstrap_results(self, init_state):
  """Builds the initial `MetropolisHastingsKernelResults`.

  The inner kernel is bootstrapped first. Its results are used both as the
  accepted and as the proposed results, every entry of `is_accepted` is
  `True`, and the log acceptance ratio is zero.

  Args:
    init_state: Initial state; forwarded unchanged to
      `self.inner_kernel.bootstrap_results` and stored as `proposed_state`.

  Returns:
    kernel_results: A `MetropolisHastingsKernelResults` instance.

  Raises:
    ValueError: if the inner kernel's results do not pass
      `has_target_log_prob`.
  """
  scope = mcmc_util.make_name(self.name, 'mh', 'bootstrap_results')
  with tf.name_scope(name=scope, values=[init_state]):
    inner_results = self.inner_kernel.bootstrap_results(init_state)
    if not has_target_log_prob(inner_results):
      raise ValueError(
          '"target_log_prob" must be a member of `inner_kernel` results.')

    # `target_log_prob` serves only as a shape/dtype template below.
    template = inner_results.target_log_prob
    all_accepted = tf.ones_like(template, dtype=tf.bool)
    zero_ratio = tf.zeros_like(template)
    return MetropolisHastingsKernelResults(
        accepted_results=inner_results,
        is_accepted=all_accepted,
        log_accept_ratio=zero_ratio,
        proposed_state=init_state,
        proposed_results=inner_results,
        extra=[],
    )
def one_step(self, current_state, previous_kernel_results):
  """Draws a random-walk proposal and evaluates its target log-probability.

  Args:
    current_state: A single value or a list-like of values. Every part is
      converted to a `Tensor` before being passed to `self.new_state_fn`.
    previous_kernel_results: Results from the previous step; only its
      `target_log_prob` member is read here (as a name-scope value).

  Returns:
    A two-element list `[next_state, kernel_results]`. `next_state` is a
    list iff `current_state` was list-like; `kernel_results` is an
    `UncalibratedRandomWalkResults` whose `log_acceptance_correction` is
    all zeros with the shape and base dtype of the new `target_log_prob`.
  """
  state_is_list = mcmc_util.is_list_like(current_state)
  scope = mcmc_util.make_name(self.name, 'rwm', 'one_step')
  scope_inputs = [
      self.seed, current_state, previous_kernel_results.target_log_prob]
  with tf.name_scope(name=scope, values=scope_inputs):
    with tf.name_scope('initialize'):
      raw_parts = list(current_state) if state_is_list else [current_state]
      state_parts = []
      for part in raw_parts:
        state_parts.append(tf.convert_to_tensor(part, name='current_state'))

    # Each call to the seed stream yields the seed used for this proposal.
    proposed_parts = self.new_state_fn(state_parts, self._seed_stream())

    # The log-prob is evaluated here so that it is available to
    # MetropolisHastings.
    proposed_log_prob = self.target_log_prob_fn(*proposed_parts)

    next_state = proposed_parts if state_is_list else proposed_parts[0]
    results = UncalibratedRandomWalkResults(
        log_acceptance_correction=tf.zeros(
            shape=tf.shape(proposed_log_prob),
            dtype=proposed_log_prob.dtype.base_dtype),
        target_log_prob=proposed_log_prob,
    )
    return [next_state, results]
def bootstrap_results(self, init_state):
  """Creates the initial uncalibrated HMC kernel results.

  Args:
    init_state: A single value or a list-like of values; each part is
      converted to a `Tensor`.

  Returns:
    An `UncalibratedHamiltonianMonteCarloKernelResults` holding the target
    log-probability at `init_state`, its gradients with respect to every
    state part, and a zero `log_acceptance_correction`.
  """
  scope = mcmc_util.make_name(self.name, 'hmc', 'bootstrap_results')
  with tf.name_scope(name=scope, values=[init_state]):
    raw_parts = (
        init_state if mcmc_util.is_list_like(init_state) else [init_state])
    state_parts = [tf.convert_to_tensor(part) for part in raw_parts]

    log_prob = self.target_log_prob_fn(*state_parts)
    log_prob_grads = tf.gradients(log_prob, state_parts)

    return UncalibratedHamiltonianMonteCarloKernelResults(
        log_acceptance_correction=tf.zeros_like(log_prob),
        target_log_prob=log_prob,
        grads_target_log_prob=log_prob_grads,
    )
def bootstrap_results(self, init_state):
  """Creates the initial `SliceSamplerKernelResults`.

  Args:
    init_state: A single value or a list-like of values; each part is
      converted to a `Tensor`.

  Returns:
    A `SliceSamplerKernelResults` with the target log-probability at
    `init_state`. `direction` is a list of zeros shaped like each state
    part; `bounds_satisfied` (bool), `upper_bounds` and `lower_bounds` are
    zeros shaped like the target log-probability.
  """
  scope = mcmc_util.make_name(self.name, 'slice', 'bootstrap_results')
  with tf.name_scope(name=scope, values=[init_state]):
    raw_parts = (
        init_state if mcmc_util.is_list_like(init_state) else [init_state])
    state_parts = [tf.convert_to_tensor(part) for part in raw_parts]
    zero_direction = [tf.zeros_like(part) for part in state_parts]
    log_prob = self.target_log_prob_fn(*state_parts)  # pylint:disable=not-callable

    return SliceSamplerKernelResults(
        target_log_prob=log_prob,
        bounds_satisfied=tf.zeros(shape=tf.shape(log_prob), dtype=tf.bool),
        direction=zero_direction,
        upper_bounds=tf.zeros_like(log_prob),
        lower_bounds=tf.zeros_like(log_prob))
def bootstrap_results(self, init_state):
  """Returns an object with the same type as returned by `one_step`.

  Args:
    init_state: `Tensor` or Python `list` of `Tensor`s representing the
      initial state(s) of the Markov chain(s).

  Returns:
    kernel_results: A `ReplicaExchangeMCKernelResults`. This includes one
      copy of the state and one set of bootstrapped results per replica;
      the exchange-proposal fields are zeros shaped like the output of
      `self.exchange_proposed_fn`.
  """
  scope = mcmc_util.make_name(self.name, 'remc', 'bootstrap_results')
  with tf.name_scope(name=scope, values=[init_state]):
    replica_results = []
    for replica in range(self.num_replica):
      replica_results.append(
          self.replica_kernels[replica].bootstrap_results(init_state))

    state_is_list = mcmc_util.is_list_like(init_state)
    init_state_parts = list(init_state) if state_is_list else [init_state]

    # Every replica gets its own identity copy of each state part.
    copies_per_replica = [
        [tf.identity(part) for part in init_state_parts]
        for _ in range(self.num_replica)
    ]
    replica_states = [
        copies if state_is_list else copies[0]
        for copies in copies_per_replica
    ]

    next_replica_idx = tf.range(self.num_replica)

    # The proposal function is called only to obtain correctly shaped
    # tensors; their values are replaced by zeros.
    proposed, proposed_n = self.exchange_proposed_fn(
        self.num_replica, seed=self._seed_stream)
    exchange_proposed = tf.zeros_like(proposed)
    exchange_proposed_n = tf.zeros_like(proposed_n)

    return ReplicaExchangeMCKernelResults(
        replica_states=replica_states,
        replica_results=replica_results,
        next_replica_idx=next_replica_idx,
        exchange_proposed=exchange_proposed,
        exchange_proposed_n=exchange_proposed_n,
        sampled_replica_states=replica_states,
        sampled_replica_results=replica_results,
    )
def bootstrap_results(self, init_state):
  """Creates the initial `SliceSamplerKernelResults`.

  Args:
    init_state: A single value or a list-like of values; each part is
      converted to a `Tensor`.

  Returns:
    A `SliceSamplerKernelResults` with the target log-probability at
    `init_state`. `direction` is a list of zeros shaped like each state
    part; `bounds_satisfied` (bool), `upper_bounds` and `lower_bounds` are
    zeros shaped like the target log-probability.
  """
  with tf.name_scope(
      name=mcmc_util.make_name(self.name, 'slice', 'bootstrap_results'),
      values=[init_state]):
    if mcmc_util.is_list_like(init_state):
      state_parts = [tf.convert_to_tensor(part) for part in init_state]
    else:
      state_parts = [tf.convert_to_tensor(init_state)]

    zero_direction = []
    for part in state_parts:
      zero_direction.append(tf.zeros_like(part))

    log_prob = self.target_log_prob_fn(*state_parts)  # pylint:disable=not-callable

    # Placeholder bounds: nothing has been sampled yet.
    unsatisfied = tf.zeros(shape=tf.shape(log_prob), dtype=tf.bool)
    zero_upper = tf.zeros_like(log_prob)
    zero_lower = tf.zeros_like(log_prob)
    return SliceSamplerKernelResults(
        target_log_prob=log_prob,
        bounds_satisfied=unsatisfied,
        direction=zero_direction,
        upper_bounds=zero_upper,
        lower_bounds=zero_lower,
    )
def one_step(self, current_state, previous_kernel_results):
  """Runs one iteration of the Transformed Kernel.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s), _after_ application of
      `bijector.forward`. It is not read by this method: the inner kernel
      is stepped from `previous_kernel_results.transformed_state` instead.
    previous_kernel_results: `collections.namedtuple` with members
      `transformed_state` and `inner_results`, as produced by a previous
      call to this function or by `bootstrap_results`.

  Returns:
    next_state: The result of `self._forward_transform` applied to the
      inner kernel's next state. A list iff the inner kernel's next state
      is list-like, otherwise a single element.
    kernel_results: `TransformedTransitionKernelResults` holding the inner
      kernel's (untransformed-by-forward) next state and its results.
  """
  scope = make_name(self.name, 'transformed_kernel', 'one_step')
  with tf.compat.v1.name_scope(name=scope, values=[previous_kernel_results]):
    inner_next_state, inner_results = self._inner_kernel.one_step(
        previous_kernel_results.transformed_state,
        previous_kernel_results.inner_results)

    if is_list_like(inner_next_state):
      next_state = self._forward_transform(inner_next_state)
    else:
      # Wrap the single part for the transform, then unwrap the result.
      next_state = self._forward_transform([inner_next_state])[0]

    results = TransformedTransitionKernelResults(
        transformed_state=inner_next_state,
        inner_results=inner_results)
    return next_state, results
def bootstrap_results(self, init_state):
  """Returns an object with the same type as returned by `one_step`.

  Args:
    init_state: `Tensor` or Python `list` of `Tensor`s representing the
      initial state(s) of the Markov chain(s).

  Returns:
    kernel_results: A `ReplicaExchangeMCKernelResults`. This includes one
      copy of the state and one set of bootstrapped results per replica;
      the exchange-proposal fields are zeros shaped like the output of
      `self.exchange_proposed_fn`.
  """
  with tf.name_scope(
      name=mcmc_util.make_name(self.name, 'remc', 'bootstrap_results'),
      values=[init_state]):
    replica_results = [
        self.replica_kernels[replica].bootstrap_results(init_state)
        for replica in range(self.num_replica)
    ]

    if mcmc_util.is_list_like(init_state):
      init_state_parts = list(init_state)
    else:
      init_state_parts = [init_state]

    # Build one identity copy of the state per replica, unwrapping the
    # single-part case so each entry mirrors the structure of `init_state`.
    replica_states = []
    for _ in range(self.num_replica):
      copied_parts = [tf.identity(part) for part in init_state_parts]
      if mcmc_util.is_list_like(init_state):
        replica_states.append(copied_parts)
      else:
        replica_states.append(copied_parts[0])

    next_replica_idx = tf.range(self.num_replica)

    # Only the shapes/dtypes of the proposal outputs are kept.
    shape_like, count_like = self.exchange_proposed_fn(
        self.num_replica, seed=self._seed_stream)
    exchange_proposed = tf.zeros_like(shape_like)
    exchange_proposed_n = tf.zeros_like(count_like)

    return ReplicaExchangeMCKernelResults(
        replica_states=replica_states,
        replica_results=replica_results,
        next_replica_idx=next_replica_idx,
        exchange_proposed=exchange_proposed,
        exchange_proposed_n=exchange_proposed_n,
        sampled_replica_states=replica_states,
        sampled_replica_results=replica_results,
    )
def one_step(self, current_state, previous_kernel_results):
  """Runs one iteration of the Transformed Kernel.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s), _after_ application of
      `bijector.forward`. It is not read by this method: the inner kernel
      is stepped from `previous_kernel_results.transformed_state` instead.
    previous_kernel_results: `collections.namedtuple` with members
      `transformed_state` and `inner_results`, as produced by a previous
      call to this function or by `bootstrap_results`.

  Returns:
    next_state: The result of `self._forward_transform` applied to the
      inner kernel's next state. A list iff the inner kernel's next state
      is list-like, otherwise a single element.
    kernel_results: `TransformedTransitionKernelResults` holding the inner
      kernel's next state and its results.
  """
  with tf.name_scope(
      name=make_name(self.name, 'transformed_kernel', 'one_step'),
      values=[previous_kernel_results]):
    step_output = self._inner_kernel.one_step(
        previous_kernel_results.transformed_state,
        previous_kernel_results.inner_results)
    inner_next_state, inner_results = step_output

    inner_is_list = is_list_like(inner_next_state)
    inner_parts = inner_next_state if inner_is_list else [inner_next_state]
    forward_parts = self._forward_transform(inner_parts)
    next_state = forward_parts if inner_is_list else forward_parts[0]

    return next_state, TransformedTransitionKernelResults(
        transformed_state=inner_next_state,
        inner_results=inner_results)
def bootstrap_results(self, init_state):
  """Creates the initial uncalibrated HMC kernel results.

  Args:
    init_state: A single value or a list-like of values. Each part is
      passed through `tf.stop_gradient` when
      `self.state_gradients_are_stopped` is truthy, and through
      `tf.convert_to_tensor` otherwise.

  Returns:
    An `UncalibratedHamiltonianMonteCarloKernelResults` holding the target
    log-probability and its gradients as returned by
    `mcmc_util.maybe_call_fn_and_grads`, plus a zero
    `log_acceptance_correction`.
  """
  scope = mcmc_util.make_name(self.name, 'hmc', 'bootstrap_results')
  with tf.name_scope(name=scope, values=[init_state]):
    raw_parts = (
        init_state if mcmc_util.is_list_like(init_state) else [init_state])
    as_tensor = (
        tf.stop_gradient if self.state_gradients_are_stopped
        else tf.convert_to_tensor)
    state_parts = [as_tensor(part) for part in raw_parts]

    log_prob, log_prob_grads = mcmc_util.maybe_call_fn_and_grads(
        self.target_log_prob_fn, state_parts)

    return UncalibratedHamiltonianMonteCarloKernelResults(
        log_acceptance_correction=tf.zeros_like(log_prob),
        target_log_prob=log_prob,
        grads_target_log_prob=log_prob_grads,
    )
def bootstrap_results(self, init_state):
  """Creates the initial uncalibrated HMC kernel results.

  Args:
    init_state: A single value or a list-like of values. Each part is
      passed through `tf.stop_gradient` when
      `self.state_gradients_are_stopped` is truthy, and through
      `tf.convert_to_tensor` otherwise.

  Returns:
    An `UncalibratedHamiltonianMonteCarloKernelResults`. When
    `self._store_parameters_in_results` is truthy, `step_size` holds
    `self.step_size` converted (structure-wise) to tensors of the target
    log-prob dtype and `num_leapfrog_steps` holds an `int64` tensor;
    otherwise both fields are empty lists.
  """
  scope = mcmc_util.make_name(self.name, 'hmc', 'bootstrap_results')
  with tf.name_scope(name=scope, values=[init_state]):
    raw_parts = (
        init_state if mcmc_util.is_list_like(init_state) else [init_state])
    state_parts = []
    for part in raw_parts:
      if self.state_gradients_are_stopped:
        state_parts.append(tf.stop_gradient(part))
      else:
        state_parts.append(tf.convert_to_tensor(value=part))

    log_prob, log_prob_grads = mcmc_util.maybe_call_fn_and_grads(
        self.target_log_prob_fn, state_parts)
    zero_correction = tf.zeros_like(log_prob)

    if self._store_parameters_in_results:

      def _as_step_size(value):
        return tf.convert_to_tensor(
            value=value, dtype=log_prob.dtype, name='step_size')

      stored_step_size = tf.nest.map_structure(_as_step_size, self.step_size)
      stored_num_steps = tf.convert_to_tensor(
          value=self.num_leapfrog_steps,
          dtype=tf.int64,
          name='num_leapfrog_steps')
    else:
      # Parameters are not tracked in the results; use empty placeholders.
      stored_step_size = []
      stored_num_steps = []

    return UncalibratedHamiltonianMonteCarloKernelResults(
        log_acceptance_correction=zero_correction,
        target_log_prob=log_prob,
        grads_target_log_prob=log_prob_grads,
        step_size=stored_step_size,
        num_leapfrog_steps=stored_num_steps)
def bootstrap_results(self, init_state):
  """Creates the initial `UncalibratedLangevinKernelResults`.

  Args:
    init_state: A single value or a list-like of values; each part is
      converted to a `Tensor`.

  Returns:
    An `UncalibratedLangevinKernelResults` populated from `_prepare_args`
    evaluated at `init_state`, with a zero `log_acceptance_correction`.
    The `volatility` field is a list iff `init_state` is list-like,
    otherwise its first element.
  """
  scope = mcmc_util.make_name(self.name, 'mala', 'bootstrap_results')
  with tf.name_scope(name=scope, values=[init_state]):
    state_is_list = mcmc_util.is_list_like(init_state)
    raw_parts = list(init_state) if state_is_list else [init_state]
    state_parts = [tf.convert_to_tensor(part) for part in raw_parts]

    volatility_at_init = self.volatility_fn(*state_parts)  # pylint: disable=not-callable

    # The first two outputs (state parts and step sizes) are not needed.
    (
        _,
        _,
        log_prob,
        log_prob_grads,
        volatility_parts,
        volatility_grads,
        drift_parts,
    ) = _prepare_args(
        self.target_log_prob_fn,
        self.volatility_fn,
        state=state_parts,
        step_size=self.step_size,
        volatility=volatility_at_init,
        parallel_iterations=self.parallel_iterations)

    zero_correction = tf.zeros_like(log_prob)
    volatility = volatility_parts if state_is_list else volatility_parts[0]
    return UncalibratedLangevinKernelResults(
        log_acceptance_correction=zero_correction,
        target_log_prob=log_prob,
        grads_target_log_prob=log_prob_grads,
        volatility=volatility,
        grads_volatility=volatility_grads,
        diffusion_drift=drift_parts)
def bootstrap_results(self, init_state):
  """Returns an object with the same type as returned by `one_step`.

  Args:
    init_state: `Tensor` or Python `list` of `Tensor`s representing the
      initial state(s) of the Markov chain(s).

  Returns:
    kernel_results: A `ReplicaExchangeMCKernelResults`. This includes, per
      replica, the bootstrapped results of its kernel and a tensor-converted
      copy of the initial state; the `sampled_*` fields reuse the same
      objects.
  """
  scope = mcmc_util.make_name(self.name, 'remc', 'bootstrap_results')
  with tf.name_scope(name=scope, values=[init_state]):
    replica_results = []
    for replica in range(self.num_replica):
      kernel = self.replica_kernels[replica]
      replica_results.append(kernel.bootstrap_results(init_state))

    state_is_list = mcmc_util.is_list_like(init_state)
    init_state_parts = list(init_state) if state_is_list else [init_state]

    # Convert all state parts to tensors, once per replica.
    replica_states = []
    for _ in range(self.num_replica):
      converted = [tf.convert_to_tensor(part) for part in init_state_parts]
      replica_states.append(converted if state_is_list else converted[0])

    return ReplicaExchangeMCKernelResults(
        replica_states=replica_states,
        replica_results=replica_results,
        sampled_replica_states=replica_states,
        sampled_replica_results=replica_results,
    )
def one_step(self, current_state, previous_kernel_results):
  """Runs one uncalibrated HMC step: momentum draw plus leapfrog loop.

  Args:
    current_state: A single value or a list-like of values; normalised by
      `_prepare_args`.
    previous_kernel_results: Results from the previous step; its
      `target_log_prob` and `grads_target_log_prob` members are read.

  Returns:
    A two-element list `[next_state, kernel_results]`. `next_state` is a
    list iff `current_state` was list-like; `kernel_results` is an
    `UncalibratedHamiltonianMonteCarloKernelResults` built from the final
    leapfrog outputs and `_compute_log_acceptance_correction`.
  """
  scope = mcmc_util.make_name(self.name, 'hmc', 'one_step')
  scope_inputs = [
      self.step_size,
      self.num_leapfrog_steps,
      current_state,
      previous_kernel_results.target_log_prob,
      previous_kernel_results.grads_target_log_prob,
  ]
  with tf.name_scope(name=scope, values=scope_inputs):
    (
        state_parts,
        step_sizes,
        log_prob,
        log_prob_grads,
    ) = _prepare_args(
        self.target_log_prob_fn,
        current_state,
        self.step_size,
        previous_kernel_results.target_log_prob,
        previous_kernel_results.grads_target_log_prob,
        maybe_expand=True,
        state_gradients_are_stopped=self.state_gradients_are_stopped)

    chain_ndims = distributions_util.prefer_static_rank(log_prob)

    # One standard-normal momentum per state part, each with its own seed.
    initial_momentum = [
        tf.random_normal(
            shape=tf.shape(part),
            dtype=part.dtype.base_dtype,
            seed=self._seed_stream())
        for part in state_parts
    ]

    def _keep_integrating(step, *unused_loop_state):
      return step < self.num_leapfrog_steps

    def _integrate_once(step, *loop_state):
      """Advances the counter and performs a single leapfrog update."""
      next_step = step + 1
      updated = _leapfrog_integrator_one_step(
          target_log_prob_fn=self.target_log_prob_fn,
          independent_chain_ndims=chain_ndims,
          step_sizes=step_sizes,
          current_momentum_parts=loop_state[0],
          current_state_parts=loop_state[1],
          current_target_log_prob=loop_state[2],
          current_target_log_prob_grad_parts=loop_state[3],
          state_gradients_are_stopped=self.state_gradients_are_stopped)
      return [next_step] + list(updated)

    # Do leapfrog integration; the leading loop counter is dropped.
    loop_outputs = tf.while_loop(
        cond=_keep_integrating,
        body=_integrate_once,
        loop_vars=[
            tf.zeros([], tf.int32, name='iter'),
            initial_momentum,
            state_parts,
            log_prob,
            log_prob_grads,
        ])
    (
        final_momentum,
        final_state_parts,
        final_log_prob,
        final_log_prob_grads,
    ) = loop_outputs[1:]

    if mcmc_util.is_list_like(current_state):
      next_state = final_state_parts
    else:
      next_state = final_state_parts[0]

    correction = _compute_log_acceptance_correction(
        initial_momentum, final_momentum, chain_ndims)
    return [
        next_state,
        UncalibratedHamiltonianMonteCarloKernelResults(
            log_acceptance_correction=correction,
            target_log_prob=final_log_prob,
            grads_target_log_prob=final_log_prob_grads,
        ),
    ]
def bootstrap_results(self, init_state=None, transformed_init_state=None):
  """Returns an object with the same type as returned by `one_step`.

  Unlike other `TransitionKernel`s,
  `TransformedTransitionKernel.bootstrap_results` can initialize the
  `TransformedTransitionKernelResults` either from an initial state, which
  is passed through `self._inverse_transform`, or directly from
  `transformed_init_state`, i.e., a `Tensor` or list of `Tensor`s which is
  interpreted as the already inverse-transformed state.

  Args:
    init_state: `Tensor` or Python `list` of `Tensor`s representing the
      initial state(s) of the Markov chain(s). Must specify `init_state` or
      `transformed_init_state` but not both.
    transformed_init_state: `Tensor` or Python `list` of `Tensor`s
      representing the transformed initial state(s) of the Markov
      chain(s). Must specify `init_state` or `transformed_init_state` but
      not both.

  Returns:
    kernel_results: A `TransformedTransitionKernelResults` holding the
      transformed state and the inner kernel's bootstrapped results.

  Raises:
    ValueError: if both or neither of `init_state` and
      `transformed_init_state` are given.

  #### Examples

  To use `transformed_init_state` in context of `tfp.mcmc.sample_chain`,
  you need to explicitly pass the `previous_kernel_results`, e.g.,

  ```python
  transformed_kernel = tfp.mcmc.TransformedTransitionKernel(...)
  init_state = ...  # Doesn't matter.
  transformed_init_state = ...  # Does matter.
  results, _ = tfp.mcmc.sample_chain(
      num_results=...,
      current_state=init_state,
      previous_kernel_results=transformed_kernel.bootstrap_results(
          transformed_init_state=transformed_init_state),
      kernel=transformed_kernel)
  ```
  """
  state_missing = init_state is None
  transformed_missing = transformed_init_state is None
  if state_missing == transformed_missing:
    raise ValueError('Must specify exactly one of `init_state` '
                     'or `transformed_init_state`.')

  scope = make_name(self.name, 'transformed_kernel', 'bootstrap_results')
  with tf.name_scope(name=scope,
                     values=[init_state, transformed_init_state]):
    if transformed_missing:
      # Derive the transformed state from the user-facing one.
      if is_list_like(init_state):
        transformed_init_state = self._inverse_transform(init_state)
      else:
        transformed_init_state = self._inverse_transform([init_state])[0]
    elif is_list_like(transformed_init_state):
      transformed_init_state = [
          tf.convert_to_tensor(value=part, name='transformed_init_state')
          for part in transformed_init_state
      ]
    else:
      transformed_init_state = tf.convert_to_tensor(
          value=transformed_init_state, name='transformed_init_state')

    inner_results = self._inner_kernel.bootstrap_results(
        transformed_init_state)
    return TransformedTransitionKernelResults(
        transformed_state=transformed_init_state,
        inner_results=inner_results)
def one_step(self, current_state, previous_kernel_results):
  """Runs one iteration of Slice Sampler.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s); normalised by
      `_prepare_args`.
    previous_kernel_results: `collections.namedtuple` from a previous call
      to this function (or from `bootstrap_results`); its `target_log_prob`
      member is read.

  Returns:
    A two-element list `[next_state, kernel_results]`. `next_state` is a
    list iff `current_state` was list-like; `kernel_results` is a
    `SliceSamplerKernelResults` filled with the outputs of `_sample_next`.

  Raises:
    ValueError, TypeError: no validation is done in this method itself;
      presumably raised by `_prepare_args` for mismatched `step_size`,
      non-static shapes or non-floating log-probs (confirm against that
      helper).
  """
  scope = mcmc_util.make_name(self.name, 'slice', 'one_step')
  scope_inputs = [
      self.step_size,
      self.max_doublings,
      self._seed_stream,
      current_state,
      previous_kernel_results.target_log_prob,
  ]
  with tf.name_scope(name=scope, values=scope_inputs):
    with tf.name_scope('initialize'):
      state_parts, step_sizes, log_prob = _prepare_args(
          self.target_log_prob_fn,
          current_state,
          self.step_size,
          previous_kernel_results.target_log_prob,
          maybe_expand=True)
      max_doublings = tf.convert_to_tensor(
          self.max_doublings, dtype=tf.int32, name='max_doublings')

    chain_ndims = distributions_util.prefer_static_rank(log_prob)

    sampled = _sample_next(
        self.target_log_prob_fn,
        state_parts,
        step_sizes,
        max_doublings,
        log_prob,
        chain_ndims,
        seed=self._seed_stream())
    (
        next_state_parts,
        next_log_prob,
        bounds_satisfied,
        direction,
        upper_bounds,
        lower_bounds,
    ) = sampled

    if mcmc_util.is_list_like(current_state):
      next_state = next_state_parts
    else:
      next_state = next_state_parts[0]

    results = SliceSamplerKernelResults(
        target_log_prob=next_log_prob,
        bounds_satisfied=bounds_satisfied,
        direction=direction,
        upper_bounds=upper_bounds,
        lower_bounds=lower_bounds)
    return [next_state, results]
def one_step(self, current_state, previous_kernel_results):
  """Runs one uncalibrated Langevin step via `_euler_method`.

  Args:
    current_state: A single value or a list-like of values; normalised by
      `_prepare_args`.
    previous_kernel_results: Results from the previous step; its
      `target_log_prob`, `grads_target_log_prob`, `volatility`,
      `grads_volatility` and `diffusion_drift` members are read.

  Returns:
    A two-element list `[next_state, kernel_results]`. `next_state` is a
    list iff `current_state` was list-like; `kernel_results` is an
    `UncalibratedLangevinKernelResults` evaluated at the proposed state.
    Its `log_acceptance_correction` is selected by `tf.cond` on
    `self.compute_acceptance`: the computed correction, or zeros.
  """
  scope = mcmc_util.make_name(self.name, 'mala', 'one_step')
  scope_inputs = [
      self.step_size,
      current_state,
      previous_kernel_results.target_log_prob,
      previous_kernel_results.grads_target_log_prob,
      previous_kernel_results.volatility,
      previous_kernel_results.diffusion_drift,
  ]
  with tf.name_scope(name=scope, values=scope_inputs):
    with tf.name_scope('initialize'):
      # Prepare input arguments to be passed to `_euler_method`; the two
      # gradient outputs are not needed for the current state.
      (
          state_parts,
          step_size_parts,
          log_prob,
          _,
          volatility_parts,
          _,
          drift_parts,
      ) = _prepare_args(
          self.target_log_prob_fn,
          self.volatility_fn,
          current_state,
          self.step_size,
          previous_kernel_results.target_log_prob,
          previous_kernel_results.grads_target_log_prob,
          previous_kernel_results.volatility,
          previous_kernel_results.grads_volatility,
          previous_kernel_results.diffusion_drift,
          self.parallel_iterations)

      noise_parts = [
          tf.random_normal(
              shape=tf.shape(part),
              dtype=part.dtype.base_dtype,
              seed=self._seed_stream())
          for part in state_parts
      ]

    # Number of independent chains run by the algorithm.
    chain_ndims = distribution_util.prefer_static_rank(log_prob)

    # Generate the next state of the algorithm using Euler-Maruyama method.
    proposed_parts = _euler_method(
        noise_parts, state_parts, drift_parts, step_size_parts,
        volatility_parts)

    # Quantities at the proposed state, consumed by
    # `_compute_log_acceptance_correction` and by the next `one_step` call.
    (
        _,
        _,
        proposed_log_prob,
        proposed_log_prob_grads,
        proposed_volatility_parts,
        proposed_volatility_grads,
        proposed_drift_parts,
    ) = _prepare_args(
        self.target_log_prob_fn,
        self.volatility_fn,
        proposed_parts,
        step_size_parts,
        parallel_iterations=self.parallel_iterations)

    state_is_list = mcmc_util.is_list_like(current_state)

    # Both candidates are built up front; `tf.cond` picks one of them.
    computed_correction = _compute_log_acceptance_correction(
        state_parts, proposed_parts, volatility_parts,
        proposed_volatility_parts, drift_parts, proposed_drift_parts,
        step_size_parts, chain_ndims)
    skipped_correction = tf.zeros_like(proposed_log_prob)
    correction = tf.cond(
        self.compute_acceptance,
        lambda: computed_correction,
        lambda: skipped_correction)

    next_state = proposed_parts if state_is_list else proposed_parts[0]
    next_volatility = (
        proposed_volatility_parts if state_is_list
        else proposed_volatility_parts[0])
    return [
        next_state,
        UncalibratedLangevinKernelResults(
            log_acceptance_correction=correction,
            target_log_prob=proposed_log_prob,
            grads_target_log_prob=proposed_log_prob_grads,
            volatility=next_volatility,
            grads_volatility=proposed_volatility_grads,
            diffusion_drift=proposed_drift_parts),
    ]
def one_step(self, current_state, previous_kernel_results):
  """Runs one uncalibrated HMC step: momentum draw plus leapfrog loop.

  When `self._store_parameters_in_results` is truthy, `step_size` and
  `num_leapfrog_steps` are read from `previous_kernel_results`, so values
  placed in the results are the ones actually used for integration.
  Otherwise the kernel's own `step_size` / `num_leapfrog_steps` are used.

  Args:
    current_state: A single value or a list-like of values; normalised by
      `_prepare_args`.
    previous_kernel_results: Results from the previous step. Its
      `target_log_prob` and `grads_target_log_prob` members are always
      read; `step_size` and `num_leapfrog_steps` are read only when
      parameters are stored in the results.

  Returns:
    next_state: A list iff `current_state` was list-like, otherwise the
      single state part after the leapfrog integration.
    kernel_results: `previous_kernel_results` with
      `log_acceptance_correction`, `target_log_prob` and
      `grads_target_log_prob` replaced by the values for the new state.
  """
  with tf.name_scope(
      name=mcmc_util.make_name(self.name, 'hmc', 'one_step'),
      values=[self.step_size,
              self.num_leapfrog_steps,
              current_state,
              previous_kernel_results.target_log_prob,
              previous_kernel_results.grads_target_log_prob]):
    if self._store_parameters_in_results:
      step_size = previous_kernel_results.step_size
      num_leapfrog_steps = previous_kernel_results.num_leapfrog_steps
    else:
      step_size = self.step_size
      num_leapfrog_steps = self.num_leapfrog_steps

    [
        current_state_parts,
        step_sizes,
        current_target_log_prob,
        current_target_log_prob_grad_parts,
    ] = _prepare_args(
        self.target_log_prob_fn,
        current_state,
        step_size,
        previous_kernel_results.target_log_prob,
        previous_kernel_results.grads_target_log_prob,
        maybe_expand=True,
        state_gradients_are_stopped=self.state_gradients_are_stopped)

    independent_chain_ndims = distribution_util.prefer_static_rank(
        current_target_log_prob)

    # One standard-normal momentum per state part, each with its own seed.
    current_momentum_parts = [
        tf.random.normal(
            shape=tf.shape(input=x),
            dtype=x.dtype.base_dtype,
            seed=self._seed_stream())
        for x in current_state_parts
    ]

    def _leapfrog_one_step(*args):
      """Closure representing computation done during each leapfrog step."""
      return _leapfrog_integrator_one_step(
          target_log_prob_fn=self.target_log_prob_fn,
          independent_chain_ndims=independent_chain_ndims,
          step_sizes=step_sizes,
          current_momentum_parts=args[0],
          current_state_parts=args[1],
          current_target_log_prob=args[2],
          current_target_log_prob_grad_parts=args[3],
          state_gradients_are_stopped=self.state_gradients_are_stopped)

    # Convert the value selected above. The previous code converted
    # `self.num_leapfrog_steps` here, which silently discarded the value
    # taken from `previous_kernel_results`. The dtype matches the int64
    # loop counter below.
    num_leapfrog_steps = tf.convert_to_tensor(
        value=num_leapfrog_steps,
        dtype=tf.int64,
        name='num_leapfrog_steps')

    # Run the integrator; the leading loop counter is dropped from the
    # outputs.
    [
        next_momentum_parts,
        next_state_parts,
        next_target_log_prob,
        next_target_log_prob_grad_parts,
    ] = tf.while_loop(
        cond=lambda i, *args: i < num_leapfrog_steps,
        body=lambda i, *args: [i + 1] + list(_leapfrog_one_step(*args)),
        loop_vars=[
            tf.zeros([], tf.int64, name='iter'),
            current_momentum_parts,
            current_state_parts,
            current_target_log_prob,
            current_target_log_prob_grad_parts
        ])[1:]

    def maybe_flatten(x):
      return x if mcmc_util.is_list_like(current_state) else x[0]

    # `_replace` keeps any other fields (e.g. stored parameters) unchanged.
    new_kernel_results = previous_kernel_results._replace(
        log_acceptance_correction=_compute_log_acceptance_correction(
            current_momentum_parts, next_momentum_parts,
            independent_chain_ndims),
        target_log_prob=next_target_log_prob,
        grads_target_log_prob=next_target_log_prob_grad_parts,
    )

    return maybe_flatten(next_state_parts), new_kernel_results
def one_step(self, current_state, previous_kernel_results):
  """Takes one step of the TransitionKernel.

  Every replica kernel is advanced once from its own stored state, then a
  sequence of proposed pairwise exchanges is evaluated inside a
  `tf.while_loop`, and finally states/results are re-assigned to replicas
  according to the resulting index permutation.

  Note: this method reassigns `self._seed_stream` (via
  `distributions_util.gen_new_seed`) both before the exchange proposal and
  inside the loop body.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s). Only used as a name-scope
      value; each replica is stepped from
      `previous_kernel_results.replica_states[i]`.
    previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
      `list` of `Tensor`s representing internal calculations made within the
      previous call to this function (or as returned by `bootstrap_results`).

  Returns:
    next_state: `tf.identity` of the state assigned to replica 0 after the
      exchanges.
    kernel_results: A `ReplicaExchangeMCKernelResults`. This includes the
      post-exchange replica states/results, the replica index permutation,
      the proposed exchanges, and the pre-exchange sampled states/results.
  """
  with tf.name_scope(name=mcmc_util.make_name(self.name, 'remc', 'one_step'),
                     values=[current_state, previous_kernel_results]):
    # Advance every replica with its own kernel, state and results.
    sampled_replica_states, sampled_replica_results = zip(*[
        rk.one_step(previous_kernel_results.replica_states[i],
                    previous_kernel_results.replica_results[i])
        for i, rk in enumerate(self.replica_kernels)
    ])
    sampled_replica_states = list(sampled_replica_states)
    sampled_replica_results = list(sampled_replica_results)

    # Divide each replica's `target_log_prob` by its inverse temperature.
    # The field lives either directly on the results or, when absent there,
    # on the nested `accepted_results`.
    # NOTE(review): presumably this recovers the untempered log-prob --
    # confirm that replica kernels scale their target by
    # `inverse_temperatures[i]`.
    sampled_replica_results_modified = [
        srr._replace(target_log_prob=srr.target_log_prob /
                     self.inverse_temperatures[i])
        if 'target_log_prob' in srr._fields else srr._replace(
            accepted_results=srr.accepted_results._replace(
                target_log_prob=srr.accepted_results.target_log_prob /
                self.inverse_temperatures[i]))
        for i, srr in enumerate(sampled_replica_results)
    ]

    # Collect the rescaled log-probs and stack them along the last axis.
    sampled_replica_ratios = [
        srr.target_log_prob if 'target_log_prob' in srr._fields else
        srr.accepted_results.target_log_prob
        for i, srr in enumerate(sampled_replica_results_modified)
    ]
    sampled_replica_ratios = tf.stack(sampled_replica_ratios, axis=-1)

    # Start from the identity assignment of replicas.
    next_replica_idx = tf.range(self.num_replica)
    self._seed_stream = distributions_util.gen_new_seed(
        self._seed_stream, salt='replica_exchange_one_step')
    exchange_proposed, exchange_proposed_n = self.exchange_proposed_fn(
        self.num_replica, seed=self._seed_stream)
    i = tf.constant(0)

    def cond(i, next_replica_idx):  # pylint: disable=unused-argument
      # Iterate over the first `exchange_proposed_n` proposed pairs.
      return tf.less(i, exchange_proposed_n)

    def body(i, next_replica_idx):
      """`tf.while_loop` body."""
      # Log acceptance ratio for swapping the i-th proposed pair: difference
      # of the two (rescaled) log-probs times the difference of the two
      # inverse temperatures.
      ratio = (sampled_replica_ratios[next_replica_idx[
          exchange_proposed[i, 0]]] - sampled_replica_ratios[
              next_replica_idx[exchange_proposed[i, 1]]])
      ratio *= (self.inverse_temperatures[exchange_proposed[i, 1]] -
                self.inverse_temperatures[exchange_proposed[i, 0]])
      self._seed_stream = distributions_util.gen_new_seed(
          self._seed_stream, salt='replica_exchange_one_step')
      log_uniform = tf.log(
          tf.random_uniform(shape=tf.shape(ratio),
                            dtype=ratio.dtype.base_dtype,
                            seed=self._seed_stream))
      exchange = log_uniform < ratio
      # Dense vector that, when added to `next_replica_idx`, swaps the two
      # entries at the proposed positions and leaves the others unchanged.
      exchange_op = tf.sparse_to_dense(
          [exchange_proposed[i, 0], exchange_proposed[i, 1]],
          [self.num_replica], [
              next_replica_idx[exchange_proposed[i, 1]] -
              next_replica_idx[exchange_proposed[i, 0]],
              next_replica_idx[exchange_proposed[i, 0]] -
              next_replica_idx[exchange_proposed[i, 1]]
          ])
      next_replica_idx = tf.cond(
          exchange, lambda: next_replica_idx + exchange_op,
          lambda: next_replica_idx)
      return [i + 1, next_replica_idx]

    # Keep only the final index assignment (drop the loop counter).
    next_replica_idx = tf.while_loop(cond, body,
                                     loop_vars=[i, next_replica_idx])[1]

    def _prep(list_):
      # For each replica position i, select `list_[j]` where
      # `next_replica_idx[i] == j`, using an exclusive `tf.case`.
      return list(
          tf.case(
              {
                  tf.equal(next_replica_idx[i], j): _stateful_lambda(
                      list_[j])
                  for j in range(self.num_replica)
              },
              exclusive=True) for i in range(self.num_replica))

    next_replica_states = _prep(sampled_replica_states)
    next_replica_results = _prep(sampled_replica_results_modified)

    # Multiply `target_log_prob` by the inverse temperature of the replica
    # position each result now occupies (inverse of the division above).
    next_replica_results = [
        nrr._replace(target_log_prob=nrr.target_log_prob *
                     self.inverse_temperatures[i])
        if 'target_log_prob' in nrr._fields else nrr._replace(
            accepted_results=nrr.accepted_results._replace(
                target_log_prob=nrr.accepted_results.target_log_prob *
                self.inverse_temperatures[i]))
        for i, nrr in enumerate(next_replica_results)
    ]

    # The state reported to the caller is the one held by replica 0.
    next_state = tf.identity(next_replica_states[0])
    kernel_results = ReplicaExchangeMCKernelResults(
        replica_states=next_replica_states,
        replica_results=next_replica_results,
        next_replica_idx=next_replica_idx,
        exchange_proposed=exchange_proposed,
        exchange_proposed_n=exchange_proposed_n,
        sampled_replica_states=sampled_replica_states,
        sampled_replica_results=sampled_replica_results,
    )

    return next_state, kernel_results
def one_step(self, current_state, previous_kernel_results):
  """Steps every replica once, then applies the proposed replica exchanges.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s). Only used as a name-scope
      value here; each replica is stepped from the states stored in
      `previous_kernel_results`.
    previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
      `list` of `Tensor`s representing internal calculations made within the
      previous call to this function (or as returned by `bootstrap_results`).
      Its `replica_states` and `replica_results` fields are read.

  Returns:
    next_state: `tf.identity` of the post-exchange state held by replica 0.
    kernel_results: `ReplicaExchangeMCKernelResults` holding the
      post-exchange replica states/results, the final replica index
      permutation, the proposed exchanges, and the pre-exchange sampled
      states/results. This includes replica states.
  """
  scope_name = mcmc_util.make_name(self.name, 'remc', 'one_step')
  with tf.name_scope(name=scope_name,
                     values=[current_state, previous_kernel_results]):
    # Advance each replica from its own stored state and kernel results.
    stepped = [
        kernel.one_step(previous_kernel_results.replica_states[r],
                        previous_kernel_results.replica_results[r])
        for r, kernel in enumerate(self.replica_kernels)
    ]
    stepped_states, stepped_results = zip(*stepped)
    sampled_replica_states = list(stepped_states)
    sampled_replica_results = list(stepped_results)

    # Divide every replica's `target_log_prob` by its inverse temperature.
    # The field lives either on the results themselves or on the nested
    # `accepted_results`.
    rescaled_results = []
    for r, results in enumerate(sampled_replica_results):
      if 'target_log_prob' in results._fields:
        rescaled = results._replace(
            target_log_prob=(results.target_log_prob /
                             self.inverse_temperatures[r]))
      else:
        accepted = results.accepted_results
        rescaled = results._replace(
            accepted_results=accepted._replace(
                target_log_prob=(accepted.target_log_prob /
                                 self.inverse_temperatures[r])))
      rescaled_results.append(rescaled)

    # Collect the rescaled log-probs, one entry per replica.
    replica_log_probs = []
    for results in rescaled_results:
      if 'target_log_prob' in results._fields:
        replica_log_probs.append(results.target_log_prob)
      else:
        replica_log_probs.append(results.accepted_results.target_log_prob)
    sampled_replica_ratios = tf.stack(replica_log_probs, axis=-1)

    initial_replica_idx = tf.range(self.num_replica)
    self._seed_stream = distributions_util.gen_new_seed(
        self._seed_stream, salt='replica_exchange_one_step')
    exchange_proposed, exchange_proposed_n = self.exchange_proposed_fn(
        self.num_replica, seed=self._seed_stream)
    loop_counter = tf.constant(0)

    def _more_exchanges(step, replica_idx):  # pylint: disable=unused-argument
      return tf.less(step, exchange_proposed_n)

    def _try_exchange(step, replica_idx):
      """`tf.while_loop` body: accepts or rejects proposed exchange `step`."""
      first = exchange_proposed[step, 0]
      second = exchange_proposed[step, 1]
      holder_first = replica_idx[first]
      holder_second = replica_idx[second]

      ratio = (sampled_replica_ratios[holder_first] -
               sampled_replica_ratios[holder_second])
      ratio *= (self.inverse_temperatures[second] -
                self.inverse_temperatures[first])

      self._seed_stream = distributions_util.gen_new_seed(
          self._seed_stream, salt='replica_exchange_one_step')
      log_uniform = tf.log(
          tf.random_uniform(
              shape=tf.shape(ratio),
              dtype=ratio.dtype.base_dtype,
              seed=self._seed_stream))
      exchange = log_uniform < ratio

      # Adding this vector to `replica_idx` swaps the entries at positions
      # `first` and `second` and leaves every other entry unchanged.
      swap_delta = tf.sparse_to_dense(
          [first, second],
          [self.num_replica],
          [holder_second - holder_first, holder_first - holder_second])
      updated_idx = tf.cond(exchange,
                            lambda: replica_idx + swap_delta,
                            lambda: replica_idx)
      return [step + 1, updated_idx]

    final_replica_idx = tf.while_loop(
        _more_exchanges,
        _try_exchange,
        loop_vars=[loop_counter, initial_replica_idx])[1]

    def _permute(per_replica):
      """Reorders `per_replica` according to `final_replica_idx`."""

      def _select(slot):
        branches = {
            tf.equal(final_replica_idx[slot], j):
                _stateful_lambda(per_replica[j])
            for j in range(self.num_replica)
        }
        return tf.case(branches, exclusive=True)

      return [_select(slot) for slot in range(self.num_replica)]

    next_replica_states = _permute(sampled_replica_states)
    permuted_results = _permute(rescaled_results)

    # Multiply `target_log_prob` by the inverse temperature of the slot each
    # set of results now occupies.
    next_replica_results = []
    for r, results in enumerate(permuted_results):
      if 'target_log_prob' in results._fields:
        retempered = results._replace(
            target_log_prob=(results.target_log_prob *
                             self.inverse_temperatures[r]))
      else:
        accepted = results.accepted_results
        retempered = results._replace(
            accepted_results=accepted._replace(
                target_log_prob=(accepted.target_log_prob *
                                 self.inverse_temperatures[r])))
      next_replica_results.append(retempered)

    next_state = tf.identity(next_replica_states[0])
    kernel_results = ReplicaExchangeMCKernelResults(
        replica_states=next_replica_states,
        replica_results=next_replica_results,
        next_replica_idx=final_replica_idx,
        exchange_proposed=exchange_proposed,
        exchange_proposed_n=exchange_proposed_n,
        sampled_replica_states=sampled_replica_states,
        sampled_replica_results=sampled_replica_results,
    )
    return next_state, kernel_results
def one_step(self, current_state, previous_kernel_results):
  """Draws fresh momentums and runs the leapfrog integrator once.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s).
    previous_kernel_results: Kernel results from the previous step; its
      `target_log_prob` and `grads_target_log_prob` fields are read.

  Returns:
    A two-element `list`: the proposed next state (a `list` if
    `current_state` is list-like, otherwise a single element), and an
    `UncalibratedHamiltonianMonteCarloKernelResults` with the log acceptance
    correction, the new `target_log_prob` and its gradients.
  """
  scope_name = mcmc_util.make_name(self.name, 'hmc', 'one_step')
  with tf.name_scope(
      name=scope_name,
      values=[
          self.step_size,
          self.num_leapfrog_steps,
          self._seed_stream,
          current_state,
          previous_kernel_results.target_log_prob,
          previous_kernel_results.grads_target_log_prob,
      ]):
    with tf.name_scope('initialize'):
      (current_state_parts,
       step_sizes,
       current_target_log_prob,
       current_grads_target_log_prob) = _prepare_args(
           self.target_log_prob_fn,
           current_state,
           self.step_size,
           previous_kernel_results.target_log_prob,
           previous_kernel_results.grads_target_log_prob,
           maybe_expand=True)

      current_momentums = []
      for state_part in current_state_parts:
        # Note:
        # - We mutate seed state so subsequent calls are not correlated.
        # - We mutate seed BEFORE using it just in case users supplied the
        #   same seed to an outer kernel, e.g., `MetropolisHastings`.
        self._seed_stream = distributions_util.gen_new_seed(
            self._seed_stream, salt='hmc_kernel_momentums')
        momentum = tf.random_normal(
            shape=tf.shape(state_part),
            dtype=state_part.dtype.base_dtype,
            seed=self._seed_stream)
        current_momentums.append(momentum)

      num_leapfrog_steps = tf.convert_to_tensor(
          self.num_leapfrog_steps,
          dtype=tf.int32,
          name='num_leapfrog_steps')

      independent_chain_ndims = distributions_util.prefer_static_rank(
          current_target_log_prob)

    (next_momentums,
     next_state_parts,
     next_target_log_prob,
     next_grads_target_log_prob) = _leapfrog_integrator(
         current_momentums,
         self.target_log_prob_fn,
         current_state_parts,
         step_sizes,
         num_leapfrog_steps,
         current_target_log_prob,
         current_grads_target_log_prob)

    # Hand back the same structure the caller passed in: a list stays a list,
    # a single state is unwrapped.
    if mcmc_util.is_list_like(current_state):
      next_state = next_state_parts
    else:
      next_state = next_state_parts[0]

    proposal_results = UncalibratedHamiltonianMonteCarloKernelResults(
        log_acceptance_correction=_compute_log_acceptance_correction(
            current_momentums, next_momentums, independent_chain_ndims),
        target_log_prob=next_target_log_prob,
        grads_target_log_prob=next_grads_target_log_prob,
    )
    return [next_state, proposal_results]
def one_step(self, current_state, previous_kernel_results):
  """Draws momentum and integrates `num_leapfrog_steps` leapfrog steps.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s).
    previous_kernel_results: Kernel results from the previous step; its
      `target_log_prob` and `grads_target_log_prob` fields are read.

  Returns:
    A two-element `list`: the proposed next state (a `list` if
    `current_state` is list-like, otherwise a single element), and an
    `UncalibratedHamiltonianMonteCarloKernelResults` with the log acceptance
    correction, the new `target_log_prob` and its gradients.
  """
  with tf.name_scope(
      name=mcmc_util.make_name(self.name, 'hmc', 'one_step'),
      values=[
          self.step_size,
          self.num_leapfrog_steps,
          current_state,
          previous_kernel_results.target_log_prob,
          previous_kernel_results.grads_target_log_prob,
      ]):
    (state_parts,
     step_sizes,
     start_target_log_prob,
     start_grad_parts) = _prepare_args(
         self.target_log_prob_fn,
         current_state,
         self.step_size,
         previous_kernel_results.target_log_prob,
         previous_kernel_results.grads_target_log_prob,
         maybe_expand=True,
         state_gradients_are_stopped=self.state_gradients_are_stopped)

    independent_chain_ndims = distribution_util.prefer_static_rank(
        start_target_log_prob)

    # One normal draw per state part, each with its own seed from the stream.
    start_momentum_parts = [
        tf.random_normal(
            shape=tf.shape(part),
            dtype=part.dtype.base_dtype,
            seed=self._seed_stream())
        for part in state_parts
    ]

    num_leapfrog_steps = tf.convert_to_tensor(
        self.num_leapfrog_steps, dtype=tf.int64, name='num_leapfrog_steps')

    def _not_done(step, *unused_loop_state):
      """`tf.while_loop` condition: stop after `num_leapfrog_steps`."""
      return step < num_leapfrog_steps

    def _advance(step, momentum_parts, parts, log_prob, log_prob_grad_parts):
      """`tf.while_loop` body: one leapfrog step plus counter increment."""
      next_step = step + 1
      integrated = _leapfrog_integrator_one_step(
          target_log_prob_fn=self.target_log_prob_fn,
          independent_chain_ndims=independent_chain_ndims,
          step_sizes=step_sizes,
          current_momentum_parts=momentum_parts,
          current_state_parts=parts,
          current_target_log_prob=log_prob,
          current_target_log_prob_grad_parts=log_prob_grad_parts,
          state_gradients_are_stopped=self.state_gradients_are_stopped)
      return [next_step] + list(integrated)

    loop_outputs = tf.while_loop(
        cond=_not_done,
        body=_advance,
        loop_vars=[
            tf.zeros([], tf.int64, name='iter'),
            start_momentum_parts,
            state_parts,
            start_target_log_prob,
            start_grad_parts,
        ])
    # Drop the iteration counter; keep the integrated quantities.
    (next_momentum_parts,
     next_state_parts,
     next_target_log_prob,
     next_target_log_prob_grad_parts) = loop_outputs[1:]

    # Hand back the same structure the caller passed in.
    if mcmc_util.is_list_like(current_state):
      next_state = next_state_parts
    else:
      next_state = next_state_parts[0]

    proposal_results = UncalibratedHamiltonianMonteCarloKernelResults(
        log_acceptance_correction=_compute_log_acceptance_correction(
            start_momentum_parts, next_momentum_parts,
            independent_chain_ndims),
        target_log_prob=next_target_log_prob,
        grads_target_log_prob=next_target_log_prob_grad_parts,
    )
    return [next_state, proposal_results]
def bootstrap_results(self, init_state=None, transformed_init_state=None):
  """Returns an object with the same type as returned by `one_step`.

  Unlike other `TransitionKernel`s,
  `TransformedTransitionKernel.bootstrap_results` can initialize the
  `TransformedTransitionKernelResults` from either an initial state, which
  requires computing `bijector.inverse(init_state)`, or directly from
  `transformed_init_state`, i.e., a `Tensor` or list of `Tensor`s which is
  interpreted as the `bijector.inverse` transformed state.

  Args:
    init_state: `Tensor` or Python `list` of `Tensor`s representing the
      state(s) of the Markov chain(s). Must specify `init_state` or
      `transformed_init_state` but not both.
    transformed_init_state: `Tensor` or Python `list` of `Tensor`s
      representing the transformed state(s) of the Markov chain(s). Must
      specify `init_state` or `transformed_init_state` but not both.

  Returns:
    kernel_results: A `TransformedTransitionKernelResults` holding the
      transformed state and the inner kernel's bootstrapped results for it.

  Raises:
    ValueError: if both or neither of `init_state` and
      `transformed_init_state` are specified.

  #### Examples

  To use `transformed_init_state` in context of
  `tfp.mcmc.sample_chain`, you need to explicitly pass the
  `previous_kernel_results`, e.g.,

  ```python
  transformed_kernel = tfp.mcmc.TransformedTransitionKernel(...)
  init_state = ...  # Doesn't matter.
  transformed_init_state = ...  # Does matter.
  results, _ = tfp.mcmc.sample_chain(
      num_results=...,
      current_state=init_state,
      previous_kernel_results=transformed_kernel.bootstrap_results(
          transformed_init_state=transformed_init_state),
      kernel=transformed_kernel)
  ```
  """
  has_init_state = init_state is not None
  has_transformed_init_state = transformed_init_state is not None
  if has_init_state == has_transformed_init_state:
    raise ValueError('Must specify exactly one of `init_state` '
                     'or `transformed_init_state`.')

  scope_name = make_name(
      self.name, 'transformed_kernel', 'bootstrap_results')
  with tf.name_scope(name=scope_name,
                     values=[init_state, transformed_init_state]):
    if has_transformed_init_state:
      # Caller supplied the transformed state directly; just tensor-ify it.
      if is_list_like(transformed_init_state):
        transformed_init_state = [
            tf.convert_to_tensor(part, name='transformed_init_state')
            for part in transformed_init_state
        ]
      else:
        transformed_init_state = tf.convert_to_tensor(
            transformed_init_state, name='transformed_init_state')
    else:
      # Map the user-space state through the inverse transform, preserving
      # whether the caller passed a list or a single state.
      init_is_list = is_list_like(init_state)
      init_state_parts = init_state if init_is_list else [init_state]
      transformed_parts = self._inverse_transform(init_state_parts)
      if init_is_list:
        transformed_init_state = transformed_parts
      else:
        transformed_init_state = transformed_parts[0]

    inner_results = self._inner_kernel.bootstrap_results(
        transformed_init_state)
    return TransformedTransitionKernelResults(
        transformed_state=transformed_init_state,
        inner_results=inner_results)
def one_step(self, current_state, previous_kernel_results):
  """Steps every replica once, exchanges states, and re-bootstraps results.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s). Only used as a name-scope
      value here; each replica is stepped from the states stored in
      `previous_kernel_results`.
    previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
      `list` of `Tensor`s representing internal calculations made within the
      previous call to this function (or as returned by `bootstrap_results`).
      Its `replica_states` and `replica_results` fields are read.

  Returns:
    next_state: The post-exchange state(s) of replica 0.
    kernel_results: `ReplicaExchangeMCKernelResults` holding the
      post-exchange replica states, freshly bootstrapped replica results, and
      the pre-exchange sampled states/results. This includes replica states.
  """
  # Key difficulty:  The type of exchanges differs from one call to the
  # next...even the number of exchanges can differ.
  # As a result, exchanges must happen dynamically, in while loops.
  scope_name = mcmc_util.make_name(self.name, 'remc', 'one_step')
  with tf.name_scope(name=scope_name,
                     values=[current_state, previous_kernel_results]):
    # Each replica does `one_step` to get pre-exchange states/KernelResults.
    stepped = [
        kernel.one_step(previous_kernel_results.replica_states[replica],
                        previous_kernel_results.replica_results[replica])
        for replica, kernel in enumerate(self.replica_kernels)
    ]
    stepped_states, stepped_results = zip(*stepped)
    sampled_replica_states = list(stepped_states)
    sampled_replica_results = list(stepped_results)

    # Work with list-of-parts states throughout; unwrap again at the end.
    states_are_lists = mcmc_util.is_list_like(sampled_replica_states[0])
    if not states_are_lists:
      sampled_replica_states = [[state] for state in sampled_replica_states]
    num_state_parts = len(sampled_replica_states[0])

    dtype = sampled_replica_states[0][0].dtype

    # Must put states into TensorArrays.  Why?  We will read/write states
    # dynamically with Tensor index `i`, and you cannot do this with lists.
    # old_states[k][i] is Tensor of (old) state part k, for replica i.
    # The `k` will be known statically, and `i` is a Tensor.
    old_states = [
        tf.TensorArray(
            dtype,
            size=self.num_replica,
            dynamic_size=False,
            clear_after_read=False,
            tensor_array_name='old_states',
            # State part k has same shape, regardless of replica.  So use 0.
            element_shape=sampled_replica_states[0][part].shape)
        for part in range(num_state_parts)
    ]
    for part in range(num_state_parts):
      part_array = old_states[part]
      for replica in range(self.num_replica):
        part_array = part_array.write(
            replica, sampled_replica_states[replica][part])
      old_states[part] = part_array

    exchange_proposed = self.exchange_proposed_fn(
        self.num_replica, seed=self._seed_stream())
    exchange_proposed_n = tf.shape(exchange_proposed)[0]

    exchanged_states = self._get_exchanged_states(
        old_states, exchange_proposed, exchange_proposed_n,
        sampled_replica_states, sampled_replica_results)

    # Replica indices that appear in no proposed exchange.
    all_replicas = tf.range(self.num_replica)
    proposed_replicas = tf.reshape(exchange_proposed, [-1])
    no_exchange_proposed, _ = tf.setdiff1d(all_replicas, proposed_replicas)

    exchanged_states = self._insert_old_states_where_no_exchange_was_proposed(
        no_exchange_proposed, old_states, exchanged_states)

    # Read the per-part arrays back out as one list of parts per replica.
    next_replica_states = [
        [exchanged_states[part].read(replica)
         for part in range(num_state_parts)]
        for replica in range(self.num_replica)
    ]

    if not states_are_lists:
      next_replica_states = [parts[0] for parts in next_replica_states]
      sampled_replica_states = [parts[0] for parts in sampled_replica_states]

    # Now that states are/aren't exchanged, bootstrap next kernel_results.
    # The viewpoint is that after each exchange, we are starting anew.
    next_replica_results = [
        kernel.bootstrap_results(state)
        for kernel, state in zip(self.replica_kernels, next_replica_states)
    ]

    # Replica 0 is the returned state(s).
    next_state = next_replica_states[0]

    kernel_results = ReplicaExchangeMCKernelResults(
        replica_states=next_replica_states,
        replica_results=next_replica_results,
        sampled_replica_states=sampled_replica_states,
        sampled_replica_results=sampled_replica_results,
    )
    return next_state, kernel_results
def one_step(self, current_state, previous_kernel_results):
  """Takes one inner-kernel step and accepts or rejects the proposal.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s).
    previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
      `list` of `Tensor`s representing internal calculations made within the
      previous call to this function (or as returned by `bootstrap_results`).
      Its `accepted_results` field is read.

  Returns:
    next_state: `Tensor` or Python `list` of `Tensor`s representing the
      next state(s) of the Markov chain(s): the proposal where it was
      accepted, `current_state` elsewhere.
    kernel_results: `MetropolisHastingsKernelResults` describing the
      accept/reject decision, the log acceptance ratio and the proposal.

  Raises:
    ValueError: if the inner kernel's proposed results or the previously
      accepted results don't contain the member "target_log_prob".
  """
  with tf.name_scope(
      name=mcmc_util.make_name(self.name, 'mh', 'one_step'),
      values=[current_state, previous_kernel_results]):
    prior_accepted_results = previous_kernel_results.accepted_results

    # Take one inner step.
    proposed_state, proposed_results = self.inner_kernel.one_step(
        current_state, prior_accepted_results)

    if not (has_target_log_prob(proposed_results) and
            has_target_log_prob(prior_accepted_results)):
      raise ValueError('"target_log_prob" must be a member of '
                       '`inner_kernel` results.')

    # Compute log(acceptance_ratio).
    log_ratio_terms = [
        proposed_results.target_log_prob,
        -prior_accepted_results.target_log_prob,
    ]
    try:
      correction = proposed_results.log_acceptance_correction
      # Skip the correction only when it is an empty list-like.
      if not mcmc_util.is_list_like(correction) or correction:
        log_ratio_terms.append(correction)
    except AttributeError:
      warnings.warn('Supplied inner `TransitionKernel` does not have a '
                    '`log_acceptance_correction`. Assuming its value is `0.`')
    log_accept_ratio = mcmc_util.safe_sum(
        log_ratio_terms, name='compute_log_accept_ratio')

    # If proposed state reduces likelihood: randomly accept.
    # If proposed state increases likelihood: always accept.
    # I.e., u < min(1, accept_ratio),  where u ~ Uniform[0,1)
    #       ==> log(u) < log_accept_ratio
    proposed_log_prob = proposed_results.target_log_prob
    uniform_draw = tf.random_uniform(
        shape=tf.shape(proposed_log_prob),
        dtype=proposed_log_prob.dtype.base_dtype,
        seed=self._seed_stream())
    log_uniform = tf.log(uniform_draw)
    is_accepted = log_uniform < log_accept_ratio

    next_state = mcmc_util.choose(
        is_accepted,
        proposed_state,
        current_state,
        name='choose_next_state')

    chosen_inner_results = mcmc_util.choose(
        is_accepted,
        proposed_results,
        prior_accepted_results,
        name='choose_inner_results')

    kernel_results = MetropolisHastingsKernelResults(
        accepted_results=chosen_inner_results,
        is_accepted=is_accepted,
        log_accept_ratio=log_accept_ratio,
        proposed_state=proposed_state,
        proposed_results=proposed_results,
        extra=[],
    )
    return next_state, kernel_results
def one_step(self, current_state, previous_kernel_results):
  """Runs one iteration of Slice Sampler.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s). The first `r` dimensions
      index independent chains,
      `r = tf.rank(target_log_prob_fn(*current_state))`.
    previous_kernel_results: `collections.namedtuple` containing `Tensor`s
      representing values from previous calls to this function (or from the
      `bootstrap_results` function.) Its `target_log_prob` field is read.

  Returns:
    next_state: Tensor or Python list of `Tensor`s representing the state(s)
      of the Markov chain(s) after taking exactly one step. Has same type
      and shape as `current_state`.
    kernel_results: `SliceSamplerKernelResults` with the new
      `target_log_prob`, `bounds_satisfied`, `direction`, `upper_bounds` and
      `lower_bounds`.

  Raises:
    ValueError: if there isn't one `step_size` or a list with same length as
      `current_state`.
    TypeError: if `not target_log_prob.dtype.is_floating`.
  """
  scope_name = mcmc_util.make_name(self.name, 'slice', 'one_step')
  with tf.name_scope(
      name=scope_name,
      values=[
          self.step_size,
          self.max_doublings,
          self._seed_stream,
          current_state,
          previous_kernel_results.target_log_prob,
      ]):
    with tf.name_scope('initialize'):
      (current_state_parts,
       step_sizes,
       current_target_log_prob) = _prepare_args(
           self.target_log_prob_fn,
           current_state,
           self.step_size,
           previous_kernel_results.target_log_prob,
           maybe_expand=True)

      max_doublings = tf.convert_to_tensor(
          self.max_doublings, dtype=tf.int32, name='max_doublings')

    independent_chain_ndims = distributions_util.prefer_static_rank(
        current_target_log_prob)

    (next_state_parts,
     next_target_log_prob,
     bounds_satisfied,
     direction,
     upper_bounds,
     lower_bounds) = _sample_next(
         self.target_log_prob_fn,
         current_state_parts,
         step_sizes,
         max_doublings,
         current_target_log_prob,
         independent_chain_ndims,
         seed=self._seed_stream())

    # Hand back the same structure the caller passed in: a list stays a list,
    # a single state is unwrapped.
    if mcmc_util.is_list_like(current_state):
      next_state = next_state_parts
    else:
      next_state = next_state_parts[0]

    sampler_results = SliceSamplerKernelResults(
        target_log_prob=next_target_log_prob,
        bounds_satisfied=bounds_satisfied,
        direction=direction,
        upper_bounds=upper_bounds,
        lower_bounds=lower_bounds)
    return [next_state, sampler_results]