def _sample_next(target_log_prob_fn,
                 current_state_parts,
                 step_sizes,
                 max_doublings,
                 current_target_log_prob,
                 batch_rank,
                 seed=None,
                 name=None):
  """Applies a single iteration of slice sampling update.

  Applies hit and run style slice sampling. Chooses a uniform random direction
  on the unit sphere in the event space. Applies the one dimensional slice
  sampling update along that direction.

  Args:
    target_log_prob_fn: Python callable which takes an argument like
      `*current_state_parts` and returns its (possibly unnormalized) log-density
      under the target distribution.
    current_state_parts: Python `list` of `Tensor`s representing the current
      state(s) of the Markov chain(s). The first `batch_rank` dimensions of
      each `Tensor` index different chains.
    step_sizes: Python `list` of `Tensor`s. Provides a measure of the width of
      the density along each axis. Used to find the slice bounds. Must
      broadcast with the shape of `current_state_parts`.
    max_doublings: Integer number of doublings to allow while locating the slice
      boundaries.
    current_target_log_prob: `Tensor` representing the value of
      `target_log_prob_fn(*current_state_parts)`. In this function only its
      shape is used: it fixes the shape of the initial (all-zero) position
      along the sampling direction.
    batch_rank: Integer. The number of axes in the state that correspond to
      independent batches.
    seed: Python integer to seed random number generators.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'sample_next').

  Returns:
    A Python `list` with the following six elements, in order:
    proposed_state_parts: Python `list` of `Tensor`s representing the state(s)
      of the Markov chain(s) after the update. Has same shape as input
      `current_state_parts`.
    proposed_target_log_prob: `Tensor` representing the value of
      `target_log_prob_fn` at `proposed_state_parts`.
    bounds_satisfied: Boolean `Tensor` of the same shape as the log density.
      True indicates that an interval containing the slice for that batch
      was found successfully.
    direction: Python `list` of `Tensor`s representing the direction along
      which the slice was sampled. Has the same shape and dtype(s) as
      `current_state_parts`.
    upper_bounds: `Tensor` of batch shape and the dtype of the input state. The
      upper bounds of the slices along the sampling direction.
    lower_bounds: `Tensor` of batch shape and the dtype of the input state. The
      lower bounds of the slices along the sampling direction.
  """
  with tf.name_scope(name, 'sample_next', [
      current_state_parts, step_sizes, max_doublings, current_target_log_prob,
      batch_rank
  ]):
    # First step: Choose a random direction.
    # Direction is a list of tensors. The i'th tensor should have the same shape
    # as the i'th state part.
    # NOTE(review): the same `seed` is also passed to
    # `ssu.slice_sampler_one_dim` below. Unless one of the callees salts it,
    # the two random draws may come from correlated streams -- confirm.
    direction = _choose_random_direction(current_state_parts,
                                         batch_rank=batch_rank,
                                         seed=seed)

    # Interpolates the step sizes for the chosen direction.
    # Applies an ellipsoidal interpolation to compute the step size for the
    # chosen direction. Suppose we are given step sizes for each direction.
    # Label these s_1, s_2, ... s_k. These are the step sizes to use if moving
    # in a direction parallel to one of the axes. Consider an ellipsoid which
    # intercepts the i'th axis at s_i. The step size for a direction specified
    # by the unit vector (n_1, n_2 ... n_k) is then defined as the distance
    # from the origin to the intersection of the line through this vector with
    # this ellipsoid.
    #
    # One can show that this distance is given by:
    #   1 / sqrt(n_1^2 / s_1^2 + n_2^2 / s_2^2 + ...).
    #
    # Proof:
    # The equation of the ellipsoid is:
    #   Sum_i [x_i^2 / s_i^2] = 1.
    # Let n be a unit direction vector. Points along the line given by n may
    # be parameterized as alpha * n where alpha is the distance along the
    # vector. Plugging this into the equation for the ellipsoid, we get:
    #   alpha^2 (n_1^2 / s_1^2 + n_2^2 / s_2^2 + ...) = 1
    # so alpha = 1 / sqrt(n_1^2 / s_1^2 + n_2^2 / s_2^2 + ...).
    #
    # Each component sums (n_i / s_i)^2 over the event (non-batch) axes of one
    # state part, leaving a batch-shaped tensor.
    components = [
        tf.reduce_sum((dirn_part / step_size_part)**2,
                      axis=tf.range(batch_rank, tf.rank(dirn_part)))
        for step_size_part, dirn_part in zip(step_sizes, direction)
    ]
    step_size = tf.rsqrt(tf.add_n(components))

    # Computes the rank of a tensor. Uses the static rank if possible, falling
    # back to the dynamic `tf.rank` when the static rank is unknown.
    def _get_rank(x):
      return (len(x.shape.as_list()) if x.shape.dims is not None
              else tf.rank(x))

    # Computed once here so `_step_along_direction` does not redo it per call.
    state_part_ranks = [_get_rank(part) for part in current_state_parts]

    def _step_along_direction(alpha):
      """Maps a batch of scalar offsets alpha to full states.

      Computes x_0 + alpha * direction where x_0 is the current state and
      direction is the direction chosen above.

      Args:
        alpha: A tensor of shape equal to the batch dimensions of
          `current_state_parts`.

      Returns:
        state_parts: Python `list` of `Tensor`s representing the state(s) of
          the Markov chain(s) for a given alpha and a given chosen direction.
          Has the same shape as `current_state_parts`.
      """
      # Right-pad alpha so it broadcasts against each state part's event axes.
      padded_alphas = [
          _right_pad(alpha, final_rank=part_rank)
          for part_rank in state_part_ranks
      ]
      return [
          state_part + padded_alpha * direction_part
          for state_part, direction_part, padded_alpha in zip(
              current_state_parts, direction, padded_alphas)
      ]

    def projected_target_log_prob_fn(alpha):
      """The target log density projected along the chosen direction.

      Args:
        alpha: A tensor of shape equal to the batch dimensions of
          `current_state_parts`.

      Returns:
        Target log density evaluated at x_0 + alpha * direction where x_0 is the
        current state and direction is the direction chosen above. Has the same
        shape as `alpha`.
      """
      return target_log_prob_fn(*_step_along_direction(alpha))

    # The one dimensional sampler starts at alpha = 0, i.e. the current state.
    alpha_init = tf.zeros_like(
        current_target_log_prob, dtype=current_state_parts[0].dtype.base_dtype)

    [
        next_alpha,
        next_target_log_prob,
        bounds_satisfied,
        upper_bounds,
        lower_bounds
    ] = ssu.slice_sampler_one_dim(projected_target_log_prob_fn,
                                  x_initial=alpha_init,
                                  max_doublings=max_doublings,
                                  step_size=step_size,
                                  seed=seed)
    return [
        _step_along_direction(next_alpha),
        next_target_log_prob,
        bounds_satisfied,
        direction,
        upper_bounds,
        lower_bounds
    ]
def _sample_next(target_log_prob_fn,
                 current_state_parts,
                 step_sizes,
                 max_doublings,
                 current_target_log_prob,
                 batch_rank,
                 seed=None,
                 name=None):
  """Applies a single iteration of slice sampling update.

  Applies hit and run style slice sampling. Chooses a uniform random direction
  on the unit sphere in the event space. Applies the one dimensional slice
  sampling update along that direction.

  Args:
    target_log_prob_fn: Python callable which takes an argument like
      `*current_state_parts` and returns its (possibly unnormalized) log-density
      under the target distribution.
    current_state_parts: Python `list` of `Tensor`s representing the current
      state(s) of the Markov chain(s). The first `batch_rank` dimensions of
      each `Tensor` index different chains. Must have fully defined static
      shape.
    step_sizes: Python `list` of `Tensor`s. Provides a measure of the width of
      the density along each axis. Used to find the slice bounds. Must
      broadcast with the shape of `current_state_parts`.
    max_doublings: Integer number of doublings to allow while locating the slice
      boundaries.
    current_target_log_prob: `Tensor` representing the value of
      `target_log_prob_fn(*current_state_parts)`. In this function only its
      shape is used: it fixes the shape of the initial (all-zero) position
      along the sampling direction.
    batch_rank: Integer. The number of axes in the state that correspond to
      independent batches.
    seed: Python integer to seed random number generators.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'sample_next').

  Returns:
    A Python `list` with the following six elements, in order:
    proposed_state_parts: Python `list` of `Tensor`s representing the state(s)
      of the Markov chain(s) after the update. Has same shape as input
      `current_state_parts`.
    proposed_target_log_prob: `Tensor` representing the value of
      `target_log_prob_fn` at `proposed_state_parts`.
    bounds_satisfied: Boolean `Tensor` of the same shape as the log density.
      True indicates that an interval containing the slice for that batch
      was found successfully.
    direction: Python `list` of `Tensor`s representing the direction along
      which the slice was sampled. Has the same shape and dtype(s) as
      `current_state_parts`.
    upper_bounds: `Tensor` of batch shape and the dtype of the input state. The
      upper bounds of the slices along the sampling direction.
    lower_bounds: `Tensor` of batch shape and the dtype of the input state. The
      lower bounds of the slices along the sampling direction.
  """
  with tf.name_scope(
      name, 'sample_next',
      [current_state_parts, step_sizes, max_doublings,
       current_target_log_prob, batch_rank]):
    # First step: Choose a random direction.
    # Direction is a list of tensors. The i'th tensor should have the same shape
    # as the i'th state part.
    # NOTE(review): the same `seed` is also passed to
    # `ssu.slice_sampler_one_dim` below. Unless one of the callees salts it,
    # the two random draws may come from correlated streams -- confirm.
    direction = _choose_random_direction(current_state_parts,
                                         batch_rank=batch_rank,
                                         seed=seed)

    # Interpolates the step sizes for the chosen direction.
    # Applies an ellipsoidal interpolation to compute the step size for the
    # chosen direction. Suppose we are given step sizes for each direction.
    # Label these s_1, s_2, ... s_k. These are the step sizes to use if moving
    # in a direction parallel to one of the axes. Consider an ellipsoid which
    # intercepts the i'th axis at s_i. The step size for a direction specified
    # by the unit vector (n_1, n_2 ... n_k) is then defined as the distance
    # from the origin to the intersection of the line through this vector with
    # this ellipsoid.
    #
    # One can show that this distance is given by:
    #   1 / sqrt(n_1^2 / s_1^2 + n_2^2 / s_2^2 + ...).
    #
    # Proof:
    # The equation of the ellipsoid is:
    #   Sum_i [x_i^2 / s_i^2] = 1.
    # Let n be a unit direction vector. Points along the line given by n may
    # be parameterized as alpha * n where alpha is the distance along the
    # vector. Plugging this into the equation for the ellipsoid, we get:
    #   alpha^2 (n_1^2 / s_1^2 + n_2^2 / s_2^2 + ...) = 1
    # so alpha = 1 / sqrt(n_1^2 / s_1^2 + n_2^2 / s_2^2 + ...).
    #
    # Each component sums (n_i / s_i)^2 over the event (non-batch) axes of one
    # state part, leaving a batch-shaped tensor.
    components = [
        tf.reduce_sum((dirn_part / step_size_part) ** 2,
                      axis=tf.range(batch_rank, tf.rank(dirn_part)))
        for step_size_part, dirn_part in zip(step_sizes, direction)
    ]
    step_size = tf.rsqrt(tf.add_n(components))

    # Static ranks of the state parts; computed once instead of on every call
    # to `_step_along_direction`. Requires each part's static rank to be known.
    state_part_ranks = [len(part.shape) for part in current_state_parts]

    def _step_along_direction(alpha):
      """Maps a batch of scalar offsets alpha to full states.

      Computes x_0 + alpha * direction where x_0 is the current state and
      direction is the direction chosen above.

      Args:
        alpha: A tensor of shape equal to the batch dimensions of
          `current_state_parts`.

      Returns:
        state_parts: Python `list` of `Tensor`s representing the state(s) of
          the Markov chain(s) for a given alpha and a given chosen direction.
          Has the same shape as `current_state_parts`.
      """
      # Right-pad alpha so it broadcasts against each state part's event axes.
      padded_alphas = [
          _right_pad_with_static_shape(alpha, final_rank=part_rank)
          for part_rank in state_part_ranks
      ]
      return [
          state_part + padded_alpha * direction_part
          for state_part, direction_part, padded_alpha in zip(
              current_state_parts, direction, padded_alphas)
      ]

    def projected_target_log_prob_fn(alpha):
      """The target log density projected along the chosen direction.

      Args:
        alpha: A tensor of shape equal to the batch dimensions of
          `current_state_parts`.

      Returns:
        Target log density evaluated at x_0 + alpha * direction where x_0 is the
        current state and direction is the direction chosen above. Has the same
        shape as `alpha`.
      """
      return target_log_prob_fn(*_step_along_direction(alpha))

    # The one dimensional sampler starts at alpha = 0, i.e. the current state.
    alpha_init = tf.zeros_like(current_target_log_prob,
                               dtype=current_state_parts[0].dtype.base_dtype)

    [
        next_alpha,
        next_target_log_prob,
        bounds_satisfied,
        upper_bounds,
        lower_bounds
    ] = ssu.slice_sampler_one_dim(projected_target_log_prob_fn,
                                  x_initial=alpha_init,
                                  max_doublings=max_doublings,
                                  step_size=step_size,
                                  seed=seed)
    return [
        _step_along_direction(next_alpha),
        next_target_log_prob,
        bounds_satisfied,
        direction,
        upper_bounds,
        lower_bounds
    ]