def _integrator_conserves_energy(self, x, independent_chain_ndims):
  """Asserts that a long leapfrog trajectory approximately conserves energy.

  A momentum `m` shaped like `x` is drawn from a standard normal, the system
  is integrated for 1000 leapfrog steps of size `0.1 / event_size` against
  `self._log_gamma_log_prob`, and the energy
  `-log_prob(x) + 0.5 * sum(m**2)` (summed over the event dimensions) before
  and after integration must agree to within a 2% relative tolerance.

  Args:
    x: Initial position. Dimensions `independent_chain_ndims:` are treated as
      event dimensions; the leading dimensions index independent chains.
      Converted to a `Tensor`; the momentum is drawn in the same dtype.
    independent_chain_ndims: Python `int`, number of leading dimensions of `x`
      that index independent chains.
  """
  # Convert once so `x.dtype` is a plain TF dtype usable for sampling below.
  x = tf.convert_to_tensor(x)
  event_dims = tf.range(independent_chain_ndims, tf.rank(x))

  def target_fn(position):
    return self._log_gamma_log_prob(position, event_dims)

  # Draw the momentum in `x`'s dtype. The float32 default would mix dtypes
  # with a float64 `x` in both the energy sum and the leapfrog updates.
  m = tf.random.normal(tf.shape(input=x), dtype=x.dtype)
  log_prob_0 = target_fn(x)
  old_energy = -log_prob_0 + 0.5 * tf.reduce_sum(
      input_tensor=m**2., axis=event_dims)

  # The step size is scaled by 1 / (number of event elements); presumably this
  # keeps the discretization error small as the dimension grows.
  event_size = np.prod(self.evaluate(x).shape[independent_chain_ndims:])

  integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
      target_fn, step_sizes=[0.1 / event_size], num_steps=1000)

  [[new_m], [_], log_prob_1, [_]] = integrator([m], [x])

  new_energy = -log_prob_1 + 0.5 * tf.reduce_sum(
      input_tensor=new_m**2., axis=event_dims)

  old_energy_, new_energy_ = self.evaluate([old_energy, new_energy])
  tf1.logging.vlog(
      1, 'average energy relative change: {}'.format(
          (1. - new_energy_ / old_energy_).mean()))
  self.assertAllClose(old_energy_, new_energy_, atol=0., rtol=0.02)
def loop_tree_doubling(self, step_size, momentum_state_memory,
                       current_step_meta_info, iter_, initial_step_state,
                       initial_step_metastate):
  """Main loop for tree doubling.

  Runs one doubling iteration. It draws a random direction per batch member
  and grows a subtree from the matching end of the current trajectory via
  `self._build_sub_tree`. It then decides whether the proposal candidate moves
  into the new subtree, rewrites the trajectory's two endpoints, and
  re-evaluates the trajectory-level U-turn criterion.

  Args:
    step_size: List of per-state-part step sizes; negated per batch member
      when extending the trajectory backward.
    momentum_state_memory: Forwarded unchanged to `self._build_sub_tree`.
    current_step_meta_info: Forwarded to `self._build_sub_tree`; its
      `init_energy` shape is used as the batch shape here.
    iter_: Current doubling depth. `1 << iter_` (i.e. `2**iter_`) is passed to
      `_build_sub_tree` as the number of steps at this depth.
    initial_step_state: Nested structure whose leaves hold the trajectory's
      two ends stacked on axis 0: `v[0]` is extended when going backward and
      `v[1]` when going forward.
    initial_step_metastate: `TreeDoublingMetaState` from the previous
      iteration (candidate, acceptance flag, momentum sum, energy diff sum,
      continue/not-divergence flags, leapfrog count).

  Returns:
    `(iter_ + 1, new_step_state, new_step_metastate)`; presumably the loop
    variables of an enclosing `tf.while_loop` -- confirm at the call site.
  """
  with tf.name_scope('loop_tree_doubling'):
    batch_shape = prefer_static.shape(
        current_step_meta_info.init_energy)
    # Per-batch-member fair coin: True extends forward (positive step size,
    # starting from v[1]); False extends backward (negated step size, v[0]).
    direction = tf.cast(tf.random.uniform(shape=batch_shape,
                                          minval=0,
                                          maxval=2,
                                          dtype=tf.int32,
                                          seed=self._seed_stream()),
                        dtype=tf.bool)

    # The endpoint each batch member grows its new subtree from.
    tree_start_states = tf.nest.map_structure(
        lambda v: tf.where(  # pylint: disable=g-long-lambda
            _rightmost_expand_to_rank(
                direction, prefer_static.rank(v[1])),
            v[1], v[0]),
        initial_step_state)

    # `direction` broadcast to each state part's rank for the `tf.where`s.
    directions_expanded = [
        _rightmost_expand_to_rank(direction, prefer_static.rank(state))
        for state in tree_start_states.state
    ]

    # Integrating backward is done by negating the step size.
    integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
        self.target_log_prob_fn,
        step_sizes=[
            tf.where(d, ss, -ss)
            for d, ss in zip(directions_expanded, step_size)
        ],
        num_steps=self.unrolled_leapfrog_steps)

    [
        candidate_tree_state,
        tree_final_states,
        final_not_divergence,
        continue_tree_final,
        energy_diff_tree_sum,
        momentum_subtree_cumsum,
        leapfrogs_taken
    ] = self._build_sub_tree(
        directions_expanded,
        integrator,
        current_step_meta_info,
        # num_steps_at_this_depth = 2**iter_ = 1 << iter_
        tf.bitwise.left_shift(1, iter_),
        tree_start_states,
        initial_step_metastate.continue_tree,
        initial_step_metastate.not_divergence,
        momentum_state_memory)

    last_candidate_state = initial_step_metastate.candidate_state

    energy_diff_sum = (
        energy_diff_tree_sum + initial_step_metastate.energy_diff_sum)
    # Decide whether the candidate jumps into the new subtree. The weights are
    # log-weights when MULTINOMIAL_SAMPLE (combined with log_add_exp) and
    # integer counts otherwise. Either way the acceptance threshold is
    # log(w_new / w_old), i.e. accept with probability min(1, w_new / w_old).
    # A subtree that stopped early (`continue_tree_final` False) gets zero
    # weight (-inf in log space).
    if MULTINOMIAL_SAMPLE:
      tree_weight = tf.where(
          continue_tree_final,
          candidate_tree_state.weight,
          tf.constant(-np.inf, dtype=candidate_tree_state.weight.dtype))
      weight_sum = log_add_exp(tree_weight, last_candidate_state.weight)
      log_accept_thresh = tree_weight - last_candidate_state.weight
    else:
      tree_weight = tf.where(
          continue_tree_final,
          candidate_tree_state.weight,
          tf.zeros([], dtype=TREE_COUNT_DTYPE))
      weight_sum = tree_weight + last_candidate_state.weight
      log_accept_thresh = tf.math.log(
          tf.cast(tree_weight, tf.float32) /
          tf.cast(last_candidate_state.weight, tf.float32))
    # NaN arises from 0/0 or (-inf) - (-inf). A threshold of 0 means "accept",
    # but `choose_new_state` below still masks by `continue_tree_final`.
    log_accept_thresh = tf.where(
        tf.math.is_nan(log_accept_thresh),
        tf.zeros([], log_accept_thresh.dtype),
        log_accept_thresh)
    # log(1 - U) with U in [0, 1) is always finite and <= 0.
    u = tf.math.log1p(-tf.random.uniform(
        shape=batch_shape,
        dtype=log_accept_thresh.dtype,
        seed=self._seed_stream()))
    is_sample_accepted = u <= log_accept_thresh

    choose_new_state = is_sample_accepted & continue_tree_final

    # Per batch member, take the subtree's candidate where `choose_new_state`
    # is True and keep the previous candidate otherwise. The weight always
    # becomes the combined weight.
    new_candidate_state = TreeDoublingStateCandidate(
        state=[
            tf.where(  # pylint: disable=g-complex-comprehension
                _rightmost_expand_to_rank(
                    choose_new_state, prefer_static.rank(s0)),
                s0, s1)
            for s0, s1 in zip(candidate_tree_state.state,
                              last_candidate_state.state)
        ],
        target=tf.where(
            _rightmost_expand_to_rank(
                choose_new_state,
                prefer_static.rank(candidate_tree_state.target)),
            candidate_tree_state.target, last_candidate_state.target),
        target_grad_parts=[
            tf.where(  # pylint: disable=g-complex-comprehension
                _rightmost_expand_to_rank(
                    choose_new_state, prefer_static.rank(grad0)),
                grad0, grad1)
            for grad0, grad1 in zip(
                candidate_tree_state.target_grad_parts,
                last_candidate_state.target_grad_parts)
        ],
        energy=tf.where(
            _rightmost_expand_to_rank(
                choose_new_state,
                prefer_static.rank(candidate_tree_state.target)),
            candidate_tree_state.energy, last_candidate_state.energy),
        weight=weight_sum)

    # Re-assert the static shapes of the selected state and gradient parts
    # (presumably lost through `tf.where`; NOTE(review): confirm).
    for new_candidate_state_temp, old_candidate_state_temp in zip(
        new_candidate_state.state, last_candidate_state.state):
      tensorshape_util.set_shape(new_candidate_state_temp,
                                 old_candidate_state_temp.shape)

    for new_candidate_grad_temp, old_candidate_grad_temp in zip(
        new_candidate_state.target_grad_parts,
        last_candidate_state.target_grad_parts):
      tensorshape_util.set_shape(new_candidate_grad_temp,
                                 old_candidate_grad_temp.shape)

    # Update left right information of the trajectory, and check trajectory
    # level U turn
    # `tree_otherend_states` is the endpoint that was NOT extended.
    tree_otherend_states = tf.nest.map_structure(
        lambda v: tf.where(  # pylint: disable=g-long-lambda
            _rightmost_expand_to_rank(
                direction, prefer_static.rank(v[1])),
            v[0], v[1]),
        initial_step_state)

    # Re-stack as [left end, right end]. When extending forward the new
    # subtree end becomes index 1; when extending backward it becomes index 0.
    new_step_state = tf.nest.pack_sequence_as(initial_step_state, [
        tf.stack([  # pylint: disable=g-complex-comprehension
            tf.where(
                _rightmost_expand_to_rank(
                    direction, prefer_static.rank(l)),
                r, l),
            tf.where(
                _rightmost_expand_to_rank(
                    direction, prefer_static.rank(l)),
                l, r),
        ], axis=0)
        for l, r in zip(tf.nest.flatten(tree_final_states),
                        tf.nest.flatten(tree_otherend_states))
    ])

    # Running momentum sum over the whole trajectory, with static shapes
    # carried over from the previous sums.
    momentum_tree_cumsum = []
    for p0, p1 in zip(
        initial_step_metastate.momentum_sum, momentum_subtree_cumsum):
      momentum_part_temp = p0 + p1
      tensorshape_util.set_shape(momentum_part_temp, p0.shape)
      momentum_tree_cumsum.append(momentum_part_temp)

    for new_state_temp, old_state_temp in zip(
        tf.nest.flatten(new_step_state),
        tf.nest.flatten(initial_step_state)):
      tensorshape_util.set_shape(new_state_temp, old_state_temp.shape)

    # The generalized criterion uses the summed momentum; the classic one uses
    # the position difference between the right and left ends.
    if GENERALIZED_UTURN:
      state_diff = momentum_tree_cumsum
    else:
      state_diff = [s[1] - s[0] for s in new_step_state.state]

    no_u_turns_trajectory = has_not_u_turn(
        state_diff,
        [m[0] for m in new_step_state.momentum],
        [m[1] for m in new_step_state.momentum],
        log_prob_rank=prefer_static.rank_from_shape(batch_shape))

    new_step_metastate = TreeDoublingMetaState(
        candidate_state=new_candidate_state,
        is_accepted=choose_new_state | initial_step_metastate.is_accepted,
        momentum_sum=momentum_tree_cumsum,
        energy_diff_sum=energy_diff_sum,
        continue_tree=continue_tree_final & no_u_turns_trajectory,
        not_divergence=final_not_divergence,
        leapfrog_count=(initial_step_metastate.leapfrog_count +
                        leapfrogs_taken))

    return iter_ + 1, new_step_state, new_step_metastate
def one_step(self, current_state, previous_kernel_results, seed=None):
  """Proposes the next chain state by leapfrog integration.

  Draws one normal momentum per state part, integrates with a
  `SimpleLeapfrogIntegrator`, and records the log acceptance correction in the
  returned kernel results. No accept/reject decision is made here.

  Args:
    current_state: `Tensor` or list of `Tensor`s holding the current state.
    previous_kernel_results: Kernel results produced by the previous step.
    seed: Optional seed; sanitized and stored in the new kernel results.

  Returns:
    `(next_state, kernel_results)`. `next_state` is a list when
    `current_state` is list-like, otherwise its single part.
  """
  with tf.name_scope(mcmc_util.make_name(self.name, 'hmc', 'one_step')):
    # Step size and trajectory length come from the kernel results when they
    # are stored there, otherwise from the kernel itself.
    params = (previous_kernel_results if self._store_parameters_in_results
              else self)
    step_size = params.step_size
    num_leapfrog_steps = params.num_leapfrog_steps

    (state_parts, step_sizes, target_log_prob,
     target_log_prob_grad_parts) = _prepare_args(
         self.target_log_prob_fn,
         current_state,
         step_size,
         previous_kernel_results.target_log_prob,
         previous_kernel_results.grads_target_log_prob,
         maybe_expand=True,
         state_gradients_are_stopped=self.state_gradients_are_stopped)

    # Keep the sanitized seed so it can be reported in the kernel results.
    seed = samplers.sanitize_seed(seed)
    part_seeds = distribute_lib.fold_in_axis_index(
        samplers.split_seed(seed, n=len(state_parts)),
        self.experimental_shard_axis_names)

    initial_momentum = [
        samplers.normal(
            shape=ps.shape(part),
            dtype=self._momentum_dtype or dtype_util.base_dtype(part.dtype),
            seed=part_seed)
        for part_seed, part in zip(part_seeds, state_parts)
    ]

    integrate = leapfrog_impl.SimpleLeapfrogIntegrator(
        self.target_log_prob_fn, step_sizes, num_leapfrog_steps)
    (final_momentum, next_state_parts, next_target_log_prob,
     next_grad_parts) = integrate(initial_momentum, state_parts,
                                  target_log_prob, target_log_prob_grad_parts)

    if self.state_gradients_are_stopped:
      next_state_parts = list(map(tf.stop_gradient, next_state_parts))

    correction = _compute_log_acceptance_correction(
        initial_momentum, final_momentum,
        ps.rank(target_log_prob),
        shard_axis_names=self.experimental_shard_axis_names)
    new_kernel_results = previous_kernel_results._replace(
        log_acceptance_correction=correction,
        target_log_prob=next_target_log_prob,
        grads_target_log_prob=next_grad_parts,
        initial_momentum=initial_momentum,
        final_momentum=final_momentum,
        seed=seed,
    )

    # Return the same structure (list vs. single part) the caller passed in.
    if mcmc_util.is_list_like(current_state):
      return next_state_parts, new_kernel_results
    return next_state_parts[0], new_kernel_results
def one_step(self, current_state, previous_kernel_results, seed=None):
  """Proposes the next chain state using the kernel's momentum distribution.

  The momentum is sampled from the distribution returned by `_prepare_args`.
  The kinetic energy is the negated log-density of that distribution (its
  `_log_prob_unnormalized` when present). The state is advanced with a
  `SimpleLeapfrogIntegrator`. No accept/reject decision is made here; the log
  acceptance correction is stored in the returned kernel results.

  Args:
    current_state: `Tensor` or list of `Tensor`s holding the current state.
    previous_kernel_results: Kernel results produced by the previous step.
    seed: Optional seed; sanitized, used for the momentum draw and stored in
      the new kernel results.

  Returns:
    `(next_state, kernel_results)`. `next_state` is a list when
    `current_state` is list-like, otherwise its single part.
  """
  with tf.name_scope(mcmc_util.make_name(self.name, 'hmc', 'one_step')):
    params = (previous_kernel_results if self._store_parameters_in_results
              else self)
    step_size = params.step_size
    num_leapfrog_steps = params.num_leapfrog_steps

    (state_parts, step_sizes, momentum_distribution, target_log_prob,
     target_log_prob_grad_parts) = _prepare_args(
         self.target_log_prob_fn,
         current_state,
         step_size,
         self.momentum_distribution,
         previous_kernel_results.target_log_prob,
         previous_kernel_results.grads_target_log_prob,
         maybe_expand=True,
         state_gradients_are_stopped=self.state_gradients_are_stopped)

    seed = samplers.sanitize_seed(seed)
    initial_momentum = momentum_distribution.sample(seed=seed)

    # Use the unnormalized log-density when the distribution exposes one.
    momentum_log_prob = getattr(momentum_distribution,
                                '_log_prob_unnormalized',
                                momentum_distribution.log_prob)

    def kinetic_energy_fn(*momentum_parts):
      return -momentum_log_prob(*momentum_parts)

    # Without a user-supplied momentum distribution, pass `None` and let the
    # integrator handle that case itself.
    leapfrog_kinetic_energy_fn = (
        None if self.momentum_distribution is None else kinetic_energy_fn)

    integrate = leapfrog_impl.SimpleLeapfrogIntegrator(
        self.target_log_prob_fn, step_sizes, num_leapfrog_steps)
    (final_momentum, next_state_parts, next_target_log_prob,
     next_grad_parts) = integrate(
         initial_momentum,
         state_parts,
         target=target_log_prob,
         target_grad_parts=target_log_prob_grad_parts,
         kinetic_energy_fn=leapfrog_kinetic_energy_fn)

    if self.state_gradients_are_stopped:
      next_state_parts = list(map(tf.stop_gradient, next_state_parts))

    correction = _compute_log_acceptance_correction(
        kinetic_energy_fn, initial_momentum, final_momentum)
    new_kernel_results = previous_kernel_results._replace(
        log_acceptance_correction=correction,
        target_log_prob=next_target_log_prob,
        grads_target_log_prob=next_grad_parts,
        initial_momentum=initial_momentum,
        final_momentum=final_momentum,
        seed=seed,
    )

    if mcmc_util.is_list_like(current_state):
      return next_state_parts, new_kernel_results
    return next_state_parts[0], new_kernel_results
def one_step(self, current_state, previous_kernel_results):
  """Proposes the next chain state by leapfrog integration.

  Draws one standard-normal momentum per state part (each seeded from
  `self._seed_stream()`), integrates with a `SimpleLeapfrogIntegrator`, and
  stores the log acceptance correction in the returned kernel results. No
  accept/reject decision is made here.

  Args:
    current_state: `Tensor` or list of `Tensor`s holding the current state.
    previous_kernel_results: Kernel results produced by the previous step.

  Returns:
    `(next_state, kernel_results)`. `next_state` is a list when
    `current_state` is list-like, otherwise its single part.
  """
  with tf.name_scope(mcmc_util.make_name(self.name, 'hmc', 'one_step')):
    params = (previous_kernel_results if self._store_parameters_in_results
              else self)
    step_size = params.step_size
    num_leapfrog_steps = params.num_leapfrog_steps

    (state_parts, step_sizes, target_log_prob,
     target_log_prob_grad_parts) = _prepare_args(
         self.target_log_prob_fn,
         current_state,
         step_size,
         previous_kernel_results.target_log_prob,
         previous_kernel_results.grads_target_log_prob,
         maybe_expand=True,
         state_gradients_are_stopped=self.state_gradients_are_stopped)

    initial_momentum = [
        tf.random.normal(
            shape=tf.shape(part),
            dtype=self._momentum_dtype or dtype_util.base_dtype(part.dtype),
            seed=self._seed_stream())
        for part in state_parts
    ]

    integrate = leapfrog_impl.SimpleLeapfrogIntegrator(
        self.target_log_prob_fn, step_sizes, num_leapfrog_steps)
    (final_momentum, next_state_parts, next_target_log_prob,
     next_grad_parts) = integrate(initial_momentum, state_parts,
                                  target_log_prob, target_log_prob_grad_parts)

    if self.state_gradients_are_stopped:
      next_state_parts = list(map(tf.stop_gradient, next_state_parts))

    correction = _compute_log_acceptance_correction(
        initial_momentum, final_momentum,
        prefer_static.rank(target_log_prob))
    new_kernel_results = previous_kernel_results._replace(
        log_acceptance_correction=correction,
        target_log_prob=next_target_log_prob,
        grads_target_log_prob=next_grad_parts,
    )

    if mcmc_util.is_list_like(current_state):
      return next_state_parts, new_kernel_results
    return next_state_parts[0], new_kernel_results
def loop_tree_doubling(self, step_size, log_slice_sample, init_energy,
                       momentum_state_memory, iter_, initial_step_state,
                       initial_step_metastate):
  """Main loop for tree doubling.

  Runs one doubling iteration. It draws a random direction per batch member
  and grows a subtree from the corresponding trajectory end via
  `self._build_sub_tree`. It then decides whether the proposal candidate moves
  into the new subtree, writes the new subtree end back into the trajectory
  endpoints, and re-evaluates the trajectory-level U-turn criterion.

  Args:
    step_size: List of per-state-part step sizes; negated per batch member
      when extending backward.
    log_slice_sample: Forwarded to `self._build_sub_tree`.
    init_energy: Forwarded to `self._build_sub_tree`; its element count is
      used as the (flattened) batch size here.
    momentum_state_memory: Forwarded unchanged to `self._build_sub_tree`.
    iter_: Current doubling depth; `1 << iter_` (i.e. `2**iter_`) is passed to
      `_build_sub_tree` as the number of steps at this depth.
    initial_step_state: Nested structure whose leaves are indexed
      `[end, batch_member, ...]`; end 0 is extended backward and end 1
      forward (see the `tf.where` alternative noted below).
    initial_step_metastate: `TreeDoublingMetaState` from the previous
      iteration.

  Returns:
    `(iter_ + 1, new_step_state, new_step_metastate)`; presumably the loop
    variables of an enclosing `tf.while_loop` -- confirm at the call site.
  """
  with tf.name_scope('loop_tree_doubling'):
    batch_size = prefer_static.size(init_energy)
    # Per-batch-member fair coin: True extends forward, False backward.
    direction = tf.cast(
        tf.random.uniform(shape=[batch_size], minval=0, maxval=2,
                          dtype=tf.int32, seed=self._seed_stream()),
        dtype=tf.bool)

    # Rows of (end index, batch index). `gather_nd` with these picks
    # v[direction_b, b] for every batch member b, and `tensor_scatter_nd_update`
    # below writes back to the same slots.
    left_right_index = tf.concat([
        tf.cast(direction, tf.int32)[..., tf.newaxis],
        tf.range(batch_size, dtype=tf.int32)[..., tf.newaxis]
    ], axis=1)

    tree_start_states = tf.nest.map_structure(
        # Alternatively: `lambda v: tf.where(direction, v[1], v[0])`
        lambda v: tf.gather_nd(v, left_right_index),
        initial_step_state)

    directions_expanded = [
        _expand_dims_under_batch_dim(direction, prefer_static.rank(state))
        for state in tree_start_states.state
    ]

    # Integrating backward is done by negating the step size. (The
    # comprehension's `direction` is local to it and does not rebind the
    # outer `direction`.)
    integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
        self.target_log_prob_fn,
        step_sizes=[
            tf.where(direction, ss, -ss)
            for direction, ss in zip(directions_expanded, step_size)
        ],
        num_steps=self.unrolled_leapfrog_steps)

    [
        candidate_tree_state,
        tree_final_states,
        final_not_divergence,
        continue_tree_final,
        energy_diff_tree_sum,
        leapfrogs_taken,
    ] = self._build_sub_tree(
        directions_expanded,
        integrator,
        log_slice_sample,
        init_energy,
        # num_steps_at_this_depth = 2**iter_ = 1 << iter_
        tf.bitwise.left_shift(1, iter_),
        tree_start_states,
        initial_step_metastate.continue_tree,
        initial_step_metastate.not_divergence,
        momentum_state_memory)

    last_candidate_state = initial_step_metastate.candidate_state
    tree_weight = candidate_tree_state.weight
    # Accept the subtree's candidate with probability min(1, w_new / w_old).
    # Weights are log-weights when MULTINOMIAL_SAMPLE and counts otherwise.
    if MULTINOMIAL_SAMPLE:
      weight_sum = log_add_exp(tree_weight, last_candidate_state.weight)
      log_accept_thresh = tree_weight - last_candidate_state.weight
    else:
      weight_sum = tree_weight + last_candidate_state.weight
      log_accept_thresh = tf.math.log(
          tf.cast(tree_weight, tf.float32) /
          tf.cast(last_candidate_state.weight, tf.float32))
    # NaN (e.g. 0/0) becomes threshold 0, i.e. accept; `choose_new_state`
    # below still masks by `continue_tree_final`.
    log_accept_thresh = tf.where(
        tf.math.is_nan(log_accept_thresh),
        tf.zeros([], log_accept_thresh.dtype),
        log_accept_thresh)
    # log(1 - U) with U in [0, 1) is always finite and <= 0.
    u = tf.math.log1p(-tf.random.uniform(
        shape=[batch_size],
        dtype=log_accept_thresh.dtype,
        seed=self._seed_stream()))
    is_sample_accepted = u <= log_accept_thresh

    choose_new_state = is_sample_accepted & continue_tree_final

    new_candidate_state = TreeDoublingStateCandidate(
        state=[
            tf.where(  # pylint: disable=g-complex-comprehension
                _expand_dims_under_batch_dim(
                    choose_new_state, prefer_static.rank(s0)),
                s0, s1)
            for s0, s1 in zip(candidate_tree_state.state,
                              last_candidate_state.state)
        ],
        target=tf.where(choose_new_state,
                        candidate_tree_state.target,
                        last_candidate_state.target),
        target_grad_parts=[
            tf.where(  # pylint: disable=g-complex-comprehension
                _expand_dims_under_batch_dim(
                    choose_new_state, prefer_static.rank(grad0)),
                grad0, grad1)
            for grad0, grad1 in zip(
                candidate_tree_state.target_grad_parts,
                last_candidate_state.target_grad_parts)
        ],
        weight=weight_sum)

    # Update left right information of the trajectory, and check trajectory
    # level U turn

    # Alternative approach (unused, kept for reference)
    # left_right_mask = tf.transpose(
    #     tf.tile(tf.one_hot(tf.cast(direction, tf.int32), 2),
    #             [1, initial_step_metastate.candidate_state[0].shape[-1], 1]),
    #     [2, 0, 1])
    # trajactory_state_left_right = tf.where(
    #     tf.equal(left_right_mask, 0.),
    #     trajactory_state_left_right,
    #     tf.tile(tree_final_states[1][0][tf.newaxis, ...], [2, 1, 1]))

    # Overwrite only the extended end of each batch member's trajectory with
    # the new subtree's final state; the other end is left untouched.
    new_step_state = tf.nest.pack_sequence_as(initial_step_state, [
        # Alternative approach:
        # tf.where(tf.equal(left_right_mask, 0.),
        #          v,
        #          tf.tile(r[tf.newaxis],
        #                  tf.concat([[2], tf.ones_like(tf.shape(r))], 0)))
        tf.tensor_scatter_nd_update(v, left_right_index, r)
        for v, r in zip(tf.nest.flatten(initial_step_state),
                        tf.nest.flatten(tree_final_states))
    ])

    no_u_turns_trajectory = has_not_u_turn(
        [s[0] for s in new_step_state.state],
        [m[0] for m in new_step_state.momentum],
        [s[1] for s in new_step_state.state],
        [m[1] for m in new_step_state.momentum])

    new_step_metastate = TreeDoublingMetaState(
        candidate_state=new_candidate_state,
        is_accepted=choose_new_state | initial_step_metastate.is_accepted,
        energy_diff_sum=(
            energy_diff_tree_sum + initial_step_metastate.energy_diff_sum),
        continue_tree=continue_tree_final & no_u_turns_trajectory,
        not_divergence=final_not_divergence,
        leapfrog_count=(initial_step_metastate.leapfrog_count +
                        leapfrogs_taken))

    return iter_ + 1, new_step_state, new_step_metastate
def _loop_build_sub_tree(
    self, direction, log_slice_sample, iter_, prev_tree_state,
    candidate_tree_state, continue_tree_previous, trace_arrays):
  """Base case in tree doubling.

  Takes one integrator call (`self.unrolled_leapfrog_steps` leapfrog steps)
  from `prev_tree_state` in `direction` and does four things:
    1. records the new state/momentum in `trace_arrays`;
    2. runs the within-subtree U-turn check on alternate iterations;
    3. updates the subtree's proposal candidate, uniformly over valid points;
    4. checks for divergence.

  Args:
    direction: Boolean per batch member; True uses `+step_size`, False
      `-step_size`.
    log_slice_sample: Compared against the new point's energy to decide
      validity and divergence.
    iter_: 0-based step index within the current subtree.
    prev_tree_state: `TreeDoublingState` to integrate from.
    candidate_tree_state: Current `TreeDoublingStateCandidate`; its `weight`
      counts valid points seen so far.
    continue_tree_previous: Boolean per batch member, AND-ed into the result.
    trace_arrays: `TraceArrays` of saved momenta/states used by the odd-step
      U-turn check.

  Returns:
    `(iter_ + 1, next_tree_state, next_candidate_tree_state,
    continue_tree_next, trace_arrays)`.
  """
  with tf.name_scope('loop_build_sub_tree'):
    # Take one leapfrog step in the direction v and check divergence
    directions_expanded = [
        _expand_dims_under_batch_dim(direction, prefer_static.rank(state))
        for state in prev_tree_state.state]
    integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
        self.target_log_prob_fn,
        step_sizes=[tf.where(direction, ss, -ss)
                    for direction, ss in zip(
                        directions_expanded, self.step_size)],
        num_steps=self.unrolled_leapfrog_steps)
    [
        next_momentum_parts,
        next_state_parts,
        next_target,
        next_target_grad_parts
    ] = integrator(prev_tree_state.momentum,
                   prev_tree_state.state,
                   prev_tree_state.target,
                   prev_tree_state.target_grad_parts)

    next_tree_state = TreeDoublingState(
        momentum=next_momentum_parts,
        state=next_state_parts,
        target=next_target,
        target_grad_parts=next_target_grad_parts)

    # Save state and momentum at odd step, check U turn at even step.
    # Note that here we also write to a Placeholder at even step to avoid
    # using tf.cond
    # (Step parity in the two lines above is 1-based. In terms of the 0-based
    # `iter_`: when `iter_` is even the new point is written to the slot
    # chosen by `self.write_instruction`. When `iter_` is odd it is written
    # to the dummy slot `self.max_tree_depth` and the U-turn check below runs.)
    index = iter_ // 2
    if USE_RAGGED_TENSOR:
      write_index_ = self.write_instruction[index]
    else:
      write_index_ = tf.switch_case(index, self.write_instruction)

    write_index = tf.where(tf.equal(iter_ % 2, 0),
                           write_index_,
                           self.max_tree_depth)

    if USE_TENSORARRAY:
      trace_arrays = TraceArrays(
          momentum_swap=[
              old.write(write_index, new) for old, new in
              zip(trace_arrays.momentum_swap, next_momentum_parts)],
          state_swap=[
              old.write(write_index, new) for old, new in
              zip(trace_arrays.state_swap, next_state_parts)])
    else:
      trace_arrays = TraceArrays(
          momentum_swap=[
              tf.tensor_scatter_nd_update(old, [[write_index]], [new])
              for old, new in zip(
                  trace_arrays.momentum_swap, next_momentum_parts)],
          state_swap=[
              tf.tensor_scatter_nd_update(old, [[write_index]], [new])
              for old, new in zip(
                  trace_arrays.state_swap, next_state_parts)])

    batch_size = prefer_static.size(next_target)
    # On even `iter_` no U-turn check is done: every batch member passes.
    has_not_u_turn_at_even_step = tf.ones([batch_size], dtype=tf.bool)

    if USE_RAGGED_TENSOR:
      no_u_turns_within_tree = tf.cond(
          tf.equal(iter_ % 2, 0),
          lambda: has_not_u_turn_at_even_step,
          lambda: has_not_u_turn_at_odd_step(  # pylint: disable=g-long-lambda
              self.read_instruction, iter_ // 2, directions_expanded,
              trace_arrays, next_momentum_parts, next_state_parts))
    else:
      # Without ragged tensors, `iter_ // 2` selects one of a set of
      # branches, each closing over a static Python int index.
      f = lambda int_iter: has_not_u_turn_at_odd_step(  # pylint: disable=g-long-lambda
          self.read_instruction, int_iter, directions_expanded, trace_arrays,
          next_momentum_parts, next_state_parts)
      branch_excution = {x: functools.partial(f, x)
                         for x in range(len(self.read_instruction))}
      no_u_turns_within_tree = tf.cond(
          tf.equal(iter_ % 2, 0),
          lambda: has_not_u_turn_at_even_step,
          lambda: tf.switch_case(iter_ // 2, branch_excution))

    # NOTE(review): presumably `compute_hamiltonian` returns the log joint
    # density (target minus kinetic energy), which would make the comparison
    # below the slice-sampling validity test -- confirm.
    energy = compute_hamiltonian(next_target, next_momentum_parts)
    valid_candidate = log_slice_sample <= energy
    # Uniform sampling on the trajectory within the subtree
    # (count-based: the new valid point replaces the candidate with
    # probability 1 / (number of valid points so far)).
    sample_weight = tf.cast(valid_candidate, TREE_COUNT_DTYPE)
    weight_sum = candidate_tree_state.weight + sample_weight
    log_accept_thresh = tf.math.log(
        tf.cast(sample_weight, tf.float32) /
        tf.cast(weight_sum, tf.float32))
    # 0/0 (no valid point yet) gives NaN; map it to 0 so the point is taken.
    # Its weight stays 0.
    log_accept_thresh = tf.where(
        tf.math.is_nan(log_accept_thresh),
        tf.zeros([], log_accept_thresh.dtype),
        log_accept_thresh)
    # log(1 - U) with U in [0, 1) is always finite and <= 0.
    u = tf.math.log1p(-tf.random.uniform(
        shape=[batch_size], dtype=tf.float32,
        seed=self._seed_stream()))
    is_sample_accepted = u <= log_accept_thresh

    next_candidate_tree_state = TreeDoublingStateCandidate(
        state=[
            tf.where(  # pylint: disable=g-complex-comprehension
                _expand_dims_under_batch_dim(
                    is_sample_accepted, prefer_static.rank(s0)),
                s0, s1)
            for s0, s1 in zip(next_state_parts, candidate_tree_state.state)
        ],
        target=tf.where(is_sample_accepted,
                        next_target,
                        candidate_tree_state.target),
        target_grad_parts=[
            tf.where(  # pylint: disable=g-complex-comprehension
                _expand_dims_under_batch_dim(
                    is_sample_accepted, prefer_static.rank(grad0)),
                grad0, grad1)
            for grad0, grad1 in zip(next_target_grad_parts,
                                    candidate_tree_state.target_grad_parts)
        ],
        weight=weight_sum)

    # Divergence: the slice level exceeds the new point's energy by at least
    # `self.max_energy_diff`.
    not_divergent = log_slice_sample - energy < self.max_energy_diff

    continue_tree = not_divergent & no_u_turns_within_tree
    continue_tree_next = continue_tree_previous & continue_tree

    return (
        iter_ + 1,
        next_tree_state,
        next_candidate_tree_state,
        continue_tree_next,
        trace_arrays,
    )
def one_step(self, current_state, previous_kernel_results, seed=None):
  """Proposes the next chain state by leapfrog integration.

  Draws one normal momentum per state part, integrates with a
  `SimpleLeapfrogIntegrator`, and records the log acceptance correction in the
  returned kernel results. No accept/reject decision is made here.

  Args:
    current_state: `Tensor` or list of `Tensor`s holding the current state.
    previous_kernel_results: Kernel results produced by the previous step.
    seed: Optional seed. When `None`, a seed is drawn from
      `self._seed_stream()`, and a deprecation warning is issued if the kernel
      was constructed with a seed.

  Returns:
    `(next_state, kernel_results)`. `next_state` is a list when
    `current_state` is list-like, otherwise its single part.
  """
  with tf.name_scope(mcmc_util.make_name(self.name, 'hmc', 'one_step')):
    params = (previous_kernel_results if self._store_parameters_in_results
              else self)
    step_size = params.step_size
    num_leapfrog_steps = params.num_leapfrog_steps

    (state_parts, step_sizes, target_log_prob,
     target_log_prob_grad_parts) = _prepare_args(
         self.target_log_prob_fn,
         current_state,
         step_size,
         previous_kernel_results.target_log_prob,
         previous_kernel_results.grads_target_log_prob,
         maybe_expand=True,
         state_gradients_are_stopped=self.state_gradients_are_stopped)

    # TODO(b/159636942): Clean up after 2020-09-20.
    if seed is None:
      # Fall back to the constructor-seeded stream, warning if that seed was
      # explicitly set.
      if self._seed_stream.original_seed is not None:
        warnings.warn(mcmc_util.SEED_CTOR_ARG_DEPRECATION_MSG)
      seed = self._seed_stream()
    seed = samplers.sanitize_seed(seed)

    part_seeds = samplers.split_seed(seed, n=len(state_parts))
    initial_momentum = [
        samplers.normal(
            shape=tf.shape(part),
            dtype=self._momentum_dtype or dtype_util.base_dtype(part.dtype),
            seed=part_seed)
        for part_seed, part in zip(part_seeds, state_parts)
    ]

    integrate = leapfrog_impl.SimpleLeapfrogIntegrator(
        self.target_log_prob_fn, step_sizes, num_leapfrog_steps)
    (final_momentum, next_state_parts, next_target_log_prob,
     next_grad_parts) = integrate(initial_momentum, state_parts,
                                  target_log_prob, target_log_prob_grad_parts)

    if self.state_gradients_are_stopped:
      next_state_parts = list(map(tf.stop_gradient, next_state_parts))

    correction = _compute_log_acceptance_correction(
        initial_momentum, final_momentum,
        prefer_static.rank(target_log_prob))
    new_kernel_results = previous_kernel_results._replace(
        log_acceptance_correction=correction,
        target_log_prob=next_target_log_prob,
        grads_target_log_prob=next_grad_parts,
        initial_momentum=initial_momentum,
        final_momentum=final_momentum,
        seed=seed,
    )

    if mcmc_util.is_list_like(current_state):
      return next_state_parts, new_kernel_results
    return next_state_parts[0], new_kernel_results