def _prepare_args(target_log_prob_fn,
                  state,
                  step_size,
                  momentum_distribution,
                  target_log_prob=None,
                  grads_target_log_prob=None,
                  maybe_expand=False,
                  state_gradients_are_stopped=False):
  """Helper which processes input args to meet list-like assumptions.

  Args:
    target_log_prob_fn: Callable forwarded to
      `mcmc_util.maybe_call_fn_and_grads` together with the state parts.
    state: The current state; normalized to a list of parts via
      `mcmc_util.prepare_state_parts`.
    step_size: A single step size or one per state part. Converted using the
      dtype of `target_log_prob`.
    momentum_distribution: The momentum distribution, or `None`/a `list` to
      request the default (independent standard normals shaped like each
      state part).
    target_log_prob: Optional value forwarded to
      `mcmc_util.maybe_call_fn_and_grads`.
    grads_target_log_prob: Optional value forwarded to
      `mcmc_util.maybe_call_fn_and_grads`.
    maybe_expand: If truthy, state parts and step sizes are always returned as
      lists, even when `state` is not list-like.
    state_gradients_are_stopped: If truthy, `tf.stop_gradient` is applied to
      every state part.

  Returns:
    A list `[state_parts, step_sizes, momentum_distribution, target_log_prob,
    grads_target_log_prob]`. The first two entries are unwrapped to their
    single element when `state` is not list-like and `maybe_expand` is falsy.

  Raises:
    ValueError: If the number of step sizes is neither one nor the number of
      state parts.
  """
  state_parts, _ = mcmc_util.prepare_state_parts(state, name='current_state')
  if state_gradients_are_stopped:
    state_parts = [tf.stop_gradient(x) for x in state_parts]
  target_log_prob, grads_target_log_prob = mcmc_util.maybe_call_fn_and_grads(
      target_log_prob_fn, state_parts, target_log_prob, grads_target_log_prob)
  step_sizes, _ = mcmc_util.prepare_state_parts(
      step_size, dtype=target_log_prob.dtype, name='step_size')

  # Default momentum distribution is None, but if `store_parameters_in_results`
  # is true, then `momentum_distribution` defaults to an empty list.
  if momentum_distribution is None or isinstance(momentum_distribution, list):
    batch_rank = ps.rank(target_log_prob)

    def _batched_isotropic_normal_like(state_part):
      # Dimensions beyond the batch dimensions of `target_log_prob` are
      # treated as event dimensions.
      event_ndims = ps.rank(state_part) - batch_rank
      # Inherit the dtype of the state part instead of hard-coding float32, so
      # that the default momentum matches states of any dtype.
      return independent.Independent(
          normal.Normal(ps.zeros_like(state_part), 1.),
          reinterpreted_batch_ndims=event_ndims)

    momentum_distribution = jds.JointDistributionSequential([
        _batched_isotropic_normal_like(state_part)
        for state_part in state_parts
    ])

  # The momentum will get "maybe listified" to zip with the state parts,
  # and this step makes sure that the momentum distribution will have the
  # same "maybe listified" underlying shape.
  if not mcmc_util.is_list_like(momentum_distribution.dtype):
    momentum_distribution = jds.JointDistributionSequential(
        [momentum_distribution])

  # A single step size is shared across all state parts.
  if len(step_sizes) == 1:
    step_sizes *= len(state_parts)
  if len(state_parts) != len(step_sizes):
    raise ValueError(
        'There should be exactly one `step_size` or it should '
        'have same length as `current_state`.')

  def maybe_flatten(x):
    return x if maybe_expand or mcmc_util.is_list_like(state) else x[0]

  return [
      maybe_flatten(state_parts),
      maybe_flatten(step_sizes),
      momentum_distribution,
      target_log_prob,
      grads_target_log_prob,
  ]
def _prepare_args(target_log_prob_fn,
                  state,
                  step_size,
                  target_log_prob=None,
                  grads_target_log_prob=None,
                  maybe_expand=False,
                  state_gradients_are_stopped=False):
  """Normalizes state and step-size arguments into matching list form.

  Returns:
    A list `[state_parts, step_sizes, target_log_prob,
    grads_target_log_prob]`. The first two entries are unwrapped to their
    single element when `state` is not list-like and `maybe_expand` is falsy.

  Raises:
    ValueError: If the number of step sizes is neither one nor the number of
      state parts.
  """
  parts, _ = mcmc_util.prepare_state_parts(state, name='current_state')
  if state_gradients_are_stopped:
    parts = [tf.stop_gradient(part) for part in parts]

  [
      target_log_prob,
      grads_target_log_prob,
  ] = mcmc_util.maybe_call_fn_and_grads(
      target_log_prob_fn, parts, target_log_prob, grads_target_log_prob)

  sizes, _ = mcmc_util.prepare_state_parts(
      step_size, dtype=target_log_prob.dtype, name='step_size')

  # One step size is repeated so that it pairs with every state part.
  num_parts = len(parts)
  if len(sizes) == 1:
    sizes *= num_parts
  if num_parts != len(sizes):
    raise ValueError(
        'There should be exactly one `step_size` or it should '
        'have same length as `current_state`.')

  # Lists are kept when requested or when the caller supplied a list-like
  # state; otherwise the single element is returned directly.
  keep_as_list = maybe_expand or mcmc_util.is_list_like(state)
  if keep_as_list:
    out_state, out_step_size = parts, sizes
  else:
    out_state, out_step_size = parts[0], sizes[0]

  return [out_state, out_step_size, target_log_prob, grads_target_log_prob]
def _swap_then_retemper(x):
  """Applies `swap_tensor_fn` to each part of `x`, then scales by `it_ratio`.

  `swap_tensor_fn` and `it_ratio` are not parameters; they are resolved from
  the surrounding scope. The result has the same single/multi-part structure
  as `x`.
  """
  parts, is_multipart = mcmc_util.prepare_state_parts(x)
  # Expand the ratio (relative to the first part) so the product broadcasts.
  ratio = mcmc_util.left_justified_expand_dims_like(it_ratio, parts[0])
  retempered = [swap_tensor_fn(part) * ratio for part in parts]
  return retempered if is_multipart else retempered[0]
def _prepare_args(target_log_prob_fn,
                  state,
                  step_size,
                  momentum_distribution,
                  target_log_prob=None,
                  grads_target_log_prob=None,
                  maybe_expand=False,
                  state_gradients_are_stopped=False,
                  experimental_shard_axis_names=None):
  """Normalizes state, step-size and momentum arguments into list form.

  Returns:
    A list `[state_parts, step_sizes, momentum_distribution, target_log_prob,
    grads_target_log_prob]`. The first two entries are unwrapped to their
    single element when `state` is not list-like and `maybe_expand` is falsy.

  Raises:
    ValueError: If the number of step sizes is neither one nor the number of
      state parts.
  """
  parts, _ = mcmc_util.prepare_state_parts(state, name='current_state')
  if state_gradients_are_stopped:
    parts = [tf.stop_gradient(part) for part in parts]

  [
      target_log_prob,
      grads_target_log_prob,
  ] = mcmc_util.maybe_call_fn_and_grads(
      target_log_prob_fn, parts, target_log_prob, grads_target_log_prob)

  sizes, _ = mcmc_util.prepare_state_parts(
      step_size, dtype=target_log_prob.dtype, name='step_size')

  # When no momentum distribution is given, build the default one from the
  # state parts and the shape of the target log-prob.
  if momentum_distribution is None:
    momentum_distribution = pu.make_momentum_distribution(
        parts,
        ps.shape(target_log_prob),
        shard_axis_names=experimental_shard_axis_names)
  momentum_distribution = pu.maybe_make_list_and_batch_broadcast(
      momentum_distribution, ps.shape(target_log_prob))

  # One step size is repeated so that it pairs with every state part.
  num_parts = len(parts)
  if len(sizes) == 1:
    sizes *= num_parts
  if num_parts != len(sizes):
    raise ValueError(
        'There should be exactly one `step_size` or it should '
        'have same length as `current_state`.')

  keep_as_list = maybe_expand or mcmc_util.is_list_like(state)
  if keep_as_list:
    out_state, out_step_size = parts, sizes
  else:
    out_state, out_step_size = parts[0], sizes[0]

  return [
      out_state,
      out_step_size,
      momentum_distribution,
      target_log_prob,
      grads_target_log_prob,
  ]
def _prepare_step_size(step_size, dtype, n_state_parts):
  """Returns `step_size` as a list with one entry per state part.

  Args:
    step_size: A single step size or a collection of them; normalized with
      `mcmc_util.prepare_state_parts` using `dtype`.
    dtype: The dtype passed through to `mcmc_util.prepare_state_parts`.
    n_state_parts: The number of state parts the step sizes must pair with.

  Returns:
    A list of `n_state_parts` step sizes.

  Raises:
    ValueError: If the number of step sizes is neither one nor
      `n_state_parts`.
  """
  sizes, _ = mcmc_util.prepare_state_parts(
      step_size, dtype=dtype, name='step_size')
  # A lone step size is shared by every state part.
  if len(sizes) == 1:
    sizes *= n_state_parts
  if len(sizes) != n_state_parts:
    raise ValueError('There should be exactly one `step_size` or it should '
                     'have same length as `current_state`.')
  return sizes
def bootstrap_results(self, init_state):
  """Builds initial `UncalibratedHamiltonianMonteCarloKernelResults`.

  The target log-prob and its gradients are evaluated at `init_state`; both
  momentum fields are filled with zeros shaped like the state parts. The
  `step_size` and `num_leapfrog_steps` fields hold tensors only when
  `self._store_parameters_in_results` is truthy, and empty lists otherwise.
  """
  with tf.name_scope(
      mcmc_util.make_name(self.name, 'hmc', 'bootstrap_results')):
    state_parts, _ = mcmc_util.prepare_state_parts(init_state)
    if self.state_gradients_are_stopped:
      state_parts = [tf.stop_gradient(part) for part in state_parts]

    log_prob, grads = mcmc_util.maybe_call_fn_and_grads(
        self.target_log_prob_fn, state_parts)

    # Fields that are identical regardless of whether parameters are stored.
    shared_fields = dict(
        log_acceptance_correction=tf.zeros_like(log_prob),
        target_log_prob=log_prob,
        grads_target_log_prob=grads,
        initial_momentum=tf.nest.map_structure(tf.zeros_like, state_parts),
        final_momentum=tf.nest.map_structure(tf.zeros_like, state_parts))

    if not self._store_parameters_in_results:
      return UncalibratedHamiltonianMonteCarloKernelResults(
          step_size=[], num_leapfrog_steps=[], **shared_fields)

    def _as_step_size_tensor(value):
      return tf.convert_to_tensor(
          value, dtype=log_prob.dtype, name='step_size')

    # TODO(b/142590314): Try to use the following code once we commit to
    # a tensorization policy.
    # step_size=mcmc_util.prepare_state_parts(
    #    self.step_size,
    #    dtype=init_target_log_prob.dtype,
    #    name='step_size')[0],
    step_size = tf.nest.map_structure(_as_step_size_tensor, self.step_size)
    num_leapfrog_steps = tf.convert_to_tensor(
        self.num_leapfrog_steps, dtype=tf.int32, name='num_leapfrog_steps')
    return UncalibratedHamiltonianMonteCarloKernelResults(
        step_size=step_size,
        num_leapfrog_steps=num_leapfrog_steps,
        **shared_fields)
def bootstrap_results(self, init_state):
  """Builds initial `PreconditionedNUTSKernelResults` from `init_state`.

  Raises:
    ValueError: From `_prepare_step_size`, if `self.step_size` is incompatible
      with the number of state parts.
  """
  with tf.name_scope(self.name + '.bootstrap_results'):
    if not tf.nest.is_nested(init_state):
      init_state = [init_state]
    parts, _ = mcmc_util.prepare_state_parts(init_state, name='current_state')
    log_prob, grads = mcmc_util.maybe_call_fn_and_grads(
        self.target_log_prob_fn, parts)

    # Called only for its validation; the prepared step sizes are discarded.
    _prepare_step_size(self.step_size, log_prob.dtype, len(init_state))

    momentum_dist = self.momentum_distribution
    if momentum_dist is None:
      momentum_dist = pu.make_momentum_distribution(
          parts,
          ps.shape(log_prob),
          shard_axis_names=self.experimental_shard_axis_names)
    momentum_dist = pu.maybe_make_list_and_batch_broadcast(
        momentum_dist, ps.shape(log_prob))
    momentum = momentum_dist.sample(seed=samplers.zeros_seed())

    def _step_size_tensor(value):
      return tf.convert_to_tensor(
          value, dtype=log_prob.dtype, name='step_size')

    def _zeros(name, dtype=None):
      # Zeros with the shape of the target log-prob.
      return tf.zeros_like(log_prob, dtype=dtype, name=name)

    return PreconditionedNUTSKernelResults(
        target_log_prob=log_prob,
        grads_target_log_prob=grads,
        step_size=tf.nest.map_structure(_step_size_tensor, self.step_size),
        log_accept_ratio=_zeros('log_accept_ratio'),
        leapfrogs_taken=_zeros('leapfrogs_taken', dtype=TREE_COUNT_DTYPE),
        is_accepted=_zeros('is_accepted', dtype=tf.bool),
        reach_max_depth=_zeros('reach_max_depth', dtype=tf.bool),
        has_divergence=_zeros('has_divergence', dtype=tf.bool),
        energy=compute_hamiltonian(log_prob, momentum, momentum_dist),
        momentum_distribution=momentum_dist,
        # Allow room for one_step's seed.
        seed=samplers.zeros_seed(),
    )
def bootstrap_results(self, init_state):
  """Extends the parent's bootstrap results with a momentum distribution.

  A default momentum distribution is created unless parameters are stored in
  the results and `self.momentum_distribution` is set, in which case the
  configured distribution is list-ified and batch-broadcast instead.
  """
  with tf.name_scope(
      mcmc_util.make_name(self.name, 'phmc', 'bootstrap_results')):
    parent_results = super(
        UncalibratedPreconditionedHamiltonianMonteCarlo,
        self).bootstrap_results(init_state)

    parts, _ = mcmc_util.prepare_state_parts(init_state, name='current_state')
    log_prob = self.target_log_prob_fn(*parts)

    use_configured = (self._store_parameters_in_results and
                      self.momentum_distribution is not None)
    if use_configured:
      momentum_dist = pu.maybe_make_list_and_batch_broadcast(
          self.momentum_distribution, ps.shape(log_prob))
    else:
      momentum_dist = pu.make_momentum_distribution(parts, ps.shape(log_prob))

    return UncalibratedPreconditionedHamiltonianMonteCarloKernelResults(
        **parent_results._asdict(),  # pylint: disable=protected-access
        momentum_distribution=momentum_dist)
def bootstrap_results(self, init_state):
  """Returns an object with the same type as returned by `one_step`.

  Args:
    init_state: `Tensor` or Python `list` of `Tensor`s representing the
      initial state(s) of the Markov chain(s).

  Returns:
    kernel_results: A `ReplicaExchangeMCKernelResults` holding the replica
      states, the inner kernel's bootstrap results, and the bookkeeping for a
      pretend "null swap".

  Raises:
    ValueError: If `self._state_includes_replicas` is set and the statically
      known leading dimensions of the initial state and of
      `inverse_temperatures` disagree.
    TypeError: If `make_kernel_fn` raises a `TypeError` whose message mentions
      'argument' (re-raised with a message about the removed `seed`
      parameter); other `TypeError`s propagate unchanged.
  """
  with tf.name_scope(
      mcmc_util.make_name(self.name, 'remc', 'bootstrap_results')):
    init_state, unused_is_multipart_state = mcmc_util.prepare_state_parts(
        init_state)

    inverse_temperatures = tf.convert_to_tensor(
        self.inverse_temperatures,
        name='inverse_temperatures')

    if self._state_includes_replicas:
      it_n_replica = inverse_temperatures.shape[0]
      state_n_replica = init_state[0].shape[0]
      # Only compare when both leading dimensions are statically known.
      if ((it_n_replica is not None) and (state_n_replica is not None) and
          (it_n_replica != state_n_replica)):
        # The format arguments follow the order the message names them in:
        # initial state first, inverse_temperatures second.
        raise ValueError(
            'Number of replicas implied by initial state ({}) must equal '
            'number of replicas implied by inverse_temperatures ({}), but '
            'did not'.format(state_n_replica, it_n_replica))

    # We will now replicate each of a possible batch of initial states, one
    # for each inverse_temperature. So if init_state=[x, y] of shapes [Sx, Sy]
    # then the new shape is [(T, Sx), (T, Sy)] where (a, b) means
    # concatenation and T=shape(inverse_temperature).
    num_replica = ps.size0(inverse_temperatures)
    replica_shape = ps.convert_to_shape_tensor([num_replica])

    if self._state_includes_replicas:
      replica_states = init_state
    else:
      replica_states = [
          tf.broadcast_to(  # pylint: disable=g-complex-comprehension
              x,
              ps.concat([replica_shape, ps.shape(x)], axis=0),
              name='replica_states')
          for x in init_state
      ]

    target_log_prob_for_inner_kernel = _make_replica_target_log_prob_fn(
        target_log_prob_fn=self.target_log_prob_fn,
        inverse_temperatures=inverse_temperatures,
        untempered_log_prob_fn=self.untempered_log_prob_fn,
        tempered_log_prob_fn=self.tempered_log_prob_fn,
    )
    # TODO(b/159636942): Clean up the helpful error msg after 2020-11-10.
    try:
      inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
          target_log_prob_for_inner_kernel)
    except TypeError as e:
      if 'argument' not in str(e):
        raise
      raise TypeError(
          '`ReplicaExchangeMC`s `make_kernel_fn` no longer receives a second '
          '(`seed`) argument. `TransitionKernel` instances now receive seeds '
          'via `one_step`.')

    replica_results = inner_kernel.bootstrap_results(replica_states)

    pre_swap_replica_target_log_prob = _get_field(
        replica_results, 'target_log_prob')

    replica_and_batch_shape = ps.shape(
        pre_swap_replica_target_log_prob)
    # Everything after the leading (replica) dimension.
    batch_shape = replica_and_batch_shape[1:]

    inverse_temperatures = bu.left_justified_broadcast_to(
        inverse_temperatures, replica_and_batch_shape)

    # Pretend we did a "null swap", which will always be accepted.
    swaps = bu.left_justified_broadcast_to(tf.range(num_replica),
                                           replica_and_batch_shape)
    # is_swap_accepted.shape = [n_replica, n_replica] + batch_shape.
    is_swap_accepted = distribution_util.rotate_transpose(tf.eye(
        num_replica, batch_shape=batch_shape, dtype=tf.bool), shift=2)

    return ReplicaExchangeMCKernelResults(
        post_swap_replica_states=replica_states,
        pre_swap_replica_results=replica_results,
        post_swap_replica_results=_set_swapped_fields_to_nan(
            replica_results),
        is_swap_proposed=is_swap_accepted,
        is_swap_accepted=is_swap_accepted,
        is_swap_proposed_adjacent=_sub_diag(is_swap_accepted),
        is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
        # The un-broadcast value supplied by the user is stored.
        inverse_temperatures=self.inverse_temperatures,
        swaps=swaps,
        step_count=tf.zeros(shape=(), dtype=tf.int32),
        seed=samplers.zeros_seed(),
        potential_energy=tf.zeros_like(
            pre_swap_replica_target_log_prob),
    )
def bootstrap_results(self, init_state):
  """Returns an object with the same type as returned by `one_step`.

  Args:
    init_state: `Tensor` or Python `list` of `Tensor`s representing the
      initial state(s) of the Markov chain(s).

  Returns:
    kernel_results: A `ReplicaExchangeMCKernelResults` holding the replica
      states, the inner kernel's bootstrap results, and the bookkeeping for a
      pretend "null swap".

  Raises:
    ValueError: If `self._state_includes_replicas` is set and the statically
      known leading dimensions of the initial state and of
      `inverse_temperatures` disagree.
  """
  with tf.name_scope(
      mcmc_util.make_name(self.name, 'remc', 'bootstrap_results')):
    init_state, unused_is_multipart_state = mcmc_util.prepare_state_parts(
        init_state)

    inverse_temperatures = tf.convert_to_tensor(
        self.inverse_temperatures,
        name='inverse_temperatures')

    if self._state_includes_replicas:
      it_n_replica = inverse_temperatures.shape[0]
      state_n_replica = init_state[0].shape[0]
      # Only compare when both leading dimensions are statically known.
      if ((it_n_replica is not None) and (state_n_replica is not None) and
          (it_n_replica != state_n_replica)):
        # The format arguments follow the order the message names them in:
        # initial state first, inverse_temperatures second.
        raise ValueError(
            'Number of replicas implied by initial state ({}) must equal '
            'number of replicas implied by inverse_temperatures ({}), but '
            'did not'.format(state_n_replica, it_n_replica))

    # We will now replicate each of a possible batch of initial states, one
    # for each inverse_temperature. So if init_state=[x, y] of shapes [Sx, Sy]
    # then the new shape is [(T, Sx), (T, Sy)] where (a, b) means
    # concatenation and T=shape(inverse_temperature).
    num_replica = ps.size0(inverse_temperatures)
    replica_shape = tf.convert_to_tensor([num_replica])

    if self._state_includes_replicas:
      replica_states = init_state
    else:
      replica_states = [
          tf.broadcast_to(  # pylint: disable=g-complex-comprehension
              x,
              ps.concat([replica_shape, ps.shape(x)], axis=0),
              name='replica_states')
          for x in init_state
      ]

    target_log_prob_for_inner_kernel = _make_replica_target_log_prob_fn(
        self.target_log_prob_fn, inverse_temperatures)
    # Seed handling complexity is due to users possibly expecting an old-style
    # stateful seed to be passed to `self.make_kernel_fn`.
    # In other words:
    # - We try `make_kernel_fn` without a seed first; this is the future. The
    #   kernel will receive a seed later, as part of `one_step`.
    # - If the user code doesn't like that (Python complains about a missing
    #   required argument), we fall back to the previous behavior and warn.
    try:
      inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
          target_log_prob_for_inner_kernel)
    except TypeError as e:
      if 'argument' not in str(e):
        raise
      warnings.warn(
          'The second (`seed`) argument to `ReplicaExchangeMC`s '
          '`make_kernel_fn` is deprecated. `TransitionKernel` instances now '
          'receive seeds via `bootstrap_results` and `one_step`. This '
          'fallback may become an error 2020-09-20.')
      inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
          target_log_prob_for_inner_kernel, self._seed_stream())

    replica_results = inner_kernel.bootstrap_results(replica_states)

    pre_swap_replica_target_log_prob = _get_field(
        replica_results, 'target_log_prob')

    replica_and_batch_shape = ps.shape(
        pre_swap_replica_target_log_prob)
    # Everything after the leading (replica) dimension.
    batch_shape = replica_and_batch_shape[1:]

    inverse_temperatures = mcmc_util.left_justified_broadcast_to(
        inverse_temperatures, replica_and_batch_shape)

    # Pretend we did a "null swap", which will always be accepted.
    swaps = mcmc_util.left_justified_broadcast_to(
        tf.range(num_replica), replica_and_batch_shape)
    # is_swap_accepted.shape = [n_replica, n_replica] + batch_shape.
    is_swap_accepted = distribution_util.rotate_transpose(tf.eye(
        num_replica, batch_shape=batch_shape, dtype=tf.bool), shift=2)

    # The same temperatures are passed as both temperature arguments and the
    # swap function is the identity, consistent with the "null swap" above.
    post_swap_replica_results = _make_post_swap_replica_results(
        replica_results,
        inverse_temperatures,
        inverse_temperatures,
        is_swap_accepted[0],
        lambda x: x,
    )

    return ReplicaExchangeMCKernelResults(
        post_swap_replica_states=replica_states,
        pre_swap_replica_results=replica_results,
        post_swap_replica_results=post_swap_replica_results,
        is_swap_proposed=is_swap_accepted,
        is_swap_accepted=is_swap_accepted,
        is_swap_proposed_adjacent=_sub_diag(is_swap_accepted),
        is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
        # The un-broadcast value supplied by the user is stored.
        inverse_temperatures=self.inverse_temperatures,
        swaps=swaps,
        step_count=tf.zeros(shape=(), dtype=tf.int32),
        seed=samplers.zeros_seed(),
    )
def bootstrap_results(self, init_state):
  """Returns an object with the same type as returned by `one_step`.

  Args:
    init_state: `Tensor` or Python `list` of `Tensor`s representing the
      initial state(s) of the Markov chain(s).

  Returns:
    kernel_results: A `ReplicaExchangeMCKernelResults` holding the replicated
      states, the inner kernel's bootstrap results, and the bookkeeping for a
      pretend "null swap".
  """
  with tf.name_scope(
      mcmc_util.make_name(self.name, 'remc', 'bootstrap_results')):
    state_parts, _ = mcmc_util.prepare_state_parts(init_state)

    inv_temps = tf.convert_to_tensor(
        self.inverse_temperatures, name='inverse_temperatures')

    # Every state part gets a new leading dimension with one entry per
    # inverse temperature: a part of shape S becomes shape (T, S), where
    # T=shape(inverse_temperature).
    num_replica = prefer_static.size0(inv_temps)
    replica_shape = tf.convert_to_tensor([num_replica])

    replica_states = []
    for part in state_parts:
      full_shape = prefer_static.concat(
          [replica_shape, prefer_static.shape(part)], axis=0)
      replica_states.append(
          tf.broadcast_to(part, full_shape, name='replica_states'))

    replica_log_prob_fn = _make_replica_target_log_prob_fn(
        self.target_log_prob_fn, inv_temps)
    inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
        replica_log_prob_fn, self._seed_stream())
    replica_results = inner_kernel.bootstrap_results(replica_states)

    replica_log_prob = _get_field(replica_results, 'target_log_prob')
    replica_and_batch_shape = prefer_static.shape(replica_log_prob)
    # Everything after the leading (replica) dimension.
    batch_shape = replica_and_batch_shape[1:]

    inv_temps = mcmc_util.left_justified_broadcast_to(
        inv_temps, replica_and_batch_shape)

    # Pretend we did a "null swap", which will always be accepted.
    swaps = mcmc_util.left_justified_broadcast_to(
        tf.range(num_replica), replica_and_batch_shape)

    # is_swap_accepted.shape = [n_replica, n_replica] + batch_shape.
    identity = tf.eye(num_replica, batch_shape=batch_shape, dtype=tf.bool)
    is_swap_accepted = distribution_util.rotate_transpose(identity, shift=2)

    post_swap_replica_results = _make_post_swap_replica_results(
        replica_results,
        inv_temps,
        inv_temps,
        is_swap_accepted[0],
        lambda x: x,
    )

    return ReplicaExchangeMCKernelResults(
        post_swap_replica_states=replica_states,
        pre_swap_replica_results=replica_results,
        post_swap_replica_results=post_swap_replica_results,
        is_swap_proposed=is_swap_accepted,
        is_swap_accepted=is_swap_accepted,
        is_swap_proposed_adjacent=_sub_diag(is_swap_accepted),
        is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
        # The un-broadcast value supplied by the user is stored.
        inverse_temperatures=self.inverse_temperatures,
        swaps=swaps,
    )
def bootstrap_results(self, init_state):
  """Creates initial `previous_kernel_results` using a supplied `state`.

  Args:
    init_state: The initial state; wrapped in a list when it is not already a
      nested structure.

  Returns:
    A `PreconditionedNUTSKernelResults` with the target log-prob and gradients
    evaluated at `init_state`, zero-filled velocity/state buffers and
    diagnostics, and the (possibly default) momentum distribution.

  Raises:
    ValueError: If `self.step_size` has neither one entry nor one entry per
      part of `init_state`.
  """
  with tf.name_scope(self.name + '.bootstrap_results'):
    if not tf.nest.is_nested(init_state):
      init_state = [init_state]
    # Padding the step_size so it is compatible with the states.
    step_size = self.step_size
    if len(step_size) == 1:
      step_size = step_size * len(init_state)
    if len(step_size) != len(init_state):
      raise ValueError(
          'Expected either one step size or {} (size of '
          '`init_state`), but found {}'.format(
              len(init_state), len(step_size)))

    state_parts, _ = mcmc_util.prepare_state_parts(
        init_state, name='current_state')
    current_target_log_prob, current_grads_log_prob = (
        mcmc_util.maybe_call_fn_and_grads(
            self.target_log_prob_fn, state_parts))

    momentum_distribution = self.momentum_distribution
    if momentum_distribution is None:
      momentum_distribution = pu.make_momentum_distribution(
          state_parts, ps.shape(current_target_log_prob))
    momentum_distribution = pu.maybe_make_list_and_batch_broadcast(
        momentum_distribution, ps.shape(current_target_log_prob))
    # Use a fixed seed so that bootstrapping is deterministic. The sample only
    # supplies shapes/dtypes for the buffers below and the initial `energy`.
    momentum_parts = momentum_distribution.sample(seed=samplers.zeros_seed())

    def _init(shape_and_dtype):
      """Allocate TensorArray for storing state and velocity."""
      # One zero buffer per (shape, dtype) pair, with a leading dimension of
      # `max(self._write_instruction) + 1`.
      return [  # pylint: disable=g-complex-comprehension
          ps.zeros(
              ps.concat([[max(self._write_instruction) + 1], s], axis=0),
              dtype=d) for (s, d) in shape_and_dtype
      ]

    get_shapes_and_dtypes = lambda x: [  # pylint: disable=g-long-lambda
        (ps.shape(x_), x_.dtype) for x_ in x
    ]
    velocity_state_memory = VelocityStateSwap(
        velocity_swap=_init(get_shapes_and_dtypes(momentum_parts)),
        state_swap=_init(get_shapes_and_dtypes(init_state)))

    return PreconditionedNUTSKernelResults(
        target_log_prob=current_target_log_prob,
        grads_target_log_prob=current_grads_log_prob,
        velocity_state_memory=velocity_state_memory,
        step_size=step_size,
        log_accept_ratio=tf.zeros_like(current_target_log_prob,
                                       name='log_accept_ratio'),
        leapfrogs_taken=tf.zeros_like(current_target_log_prob,
                                      dtype=TREE_COUNT_DTYPE,
                                      name='leapfrogs_taken'),
        is_accepted=tf.zeros_like(current_target_log_prob,
                                  dtype=tf.bool,
                                  name='is_accepted'),
        reach_max_depth=tf.zeros_like(current_target_log_prob,
                                      dtype=tf.bool,
                                      name='reach_max_depth'),
        has_divergence=tf.zeros_like(current_target_log_prob,
                                     dtype=tf.bool,
                                     name='has_divergence'),
        energy=compute_hamiltonian(current_target_log_prob, momentum_parts,
                                   momentum_distribution),
        momentum_distribution=momentum_distribution,
        # Allow room for one_step's seed.
        seed=samplers.zeros_seed(),
    )
def _prepare_args(target_log_prob_fn,
                  state,
                  step_size,
                  momentum_distribution,
                  target_log_prob=None,
                  grads_target_log_prob=None,
                  maybe_expand=False,
                  state_gradients_are_stopped=False):
  """Normalizes state, step-size and momentum arguments into list form.

  Returns:
    A list `[state_parts, step_sizes, momentum_distribution, target_log_prob,
    grads_target_log_prob]`. The first two entries are unwrapped to their
    single element when `state` is not list-like and `maybe_expand` is falsy.

  Raises:
    ValueError: If the number of step sizes is neither one nor the number of
      state parts.
  """
  parts, _ = mcmc_util.prepare_state_parts(state, name='current_state')
  if state_gradients_are_stopped:
    parts = [tf.stop_gradient(part) for part in parts]

  [
      target_log_prob,
      grads_target_log_prob,
  ] = mcmc_util.maybe_call_fn_and_grads(
      target_log_prob_fn, parts, target_log_prob, grads_target_log_prob)

  sizes, _ = mcmc_util.prepare_state_parts(
      step_size, dtype=target_log_prob.dtype, name='step_size')

  # `None` is the usual default; when `store_parameters_in_results` is true
  # the default is a `DefaultStandardNormal` instance instead. Either way a
  # standard normal is built per state part, with the state part's dtype and
  # its non-batch shape as the sample shape.
  wants_default = (
      momentum_distribution is None or
      isinstance(momentum_distribution, DefaultStandardNormal))
  if wants_default:
    batch_rank = ps.rank(target_log_prob)
    momentum_distribution = jds.JointDistributionSequential([
        sample.Sample(
            normal.Normal(ps.zeros([], dtype=part.dtype), 1.),
            ps.shape(part)[batch_rank:])
        for part in parts
    ])

  # The momentum is "maybe listified" to zip with the state parts, so wrap a
  # non-list-like distribution to give it the same underlying structure.
  if not mcmc_util.is_list_like(momentum_distribution.dtype):
    momentum_distribution = jds.JointDistributionSequential(
        [momentum_distribution])

  # When no component is a callable (i.e. all are independent), broadcast
  # every component to the batch shape. This also covers the distributions
  # produced by the two blocks above.
  if (isinstance(momentum_distribution, jds.JointDistributionSequential) and
      all(not callable(dist_fn) for dist_fn in momentum_distribution.model)):
    batch_shape = ps.shape(target_log_prob)
    broadcast_model = [
        batch_broadcast.BatchBroadcast(component, to_shape=batch_shape)
        for component in momentum_distribution.model
    ]
    momentum_distribution = momentum_distribution.copy(model=broadcast_model)

  # One step size is repeated so that it pairs with every state part.
  num_parts = len(parts)
  if len(sizes) == 1:
    sizes *= num_parts
  if num_parts != len(sizes):
    raise ValueError('There should be exactly one `step_size` or it should '
                     'have same length as `current_state`.')

  keep_as_list = maybe_expand or mcmc_util.is_list_like(state)
  if keep_as_list:
    out_state, out_step_size = parts, sizes
  else:
    out_state, out_step_size = parts[0], sizes[0]

  return [
      out_state,
      out_step_size,
      momentum_distribution,
      target_log_prob,
      grads_target_log_prob,
  ]