Example #1
0
def _sample_next(target_log_prob_fn,
                 current_state_parts,
                 step_sizes,
                 max_doublings,
                 current_target_log_prob,
                 batch_rank,
                 seed=None,
                 name=None):
    """Applies a single iteration of slice sampling update.

    Applies hit and run style slice sampling. Chooses a uniform random direction
    on the unit sphere in the event space. Applies the one dimensional slice
    sampling update along that direction.

    Args:
      target_log_prob_fn: Python callable which takes an argument like
        `*current_state_parts` and returns its (possibly unnormalized)
        log-density under the target distribution.
      current_state_parts: Python `list` of `Tensor`s representing the current
        state(s) of the Markov chain(s). The first `batch_rank` dimensions of
        each `Tensor` index different chains. The rank of each part may be
        known statically or only at graph run time.
      step_sizes: Python `list` of `Tensor`s, one per state part. Provides a
        measure of the width of the density along each axis and is used to
        find the slice bounds. Must broadcast with the shape of
        `current_state_parts`.
      max_doublings: Integer number of doublings to allow while locating the
        slice boundaries.
      current_target_log_prob: `Tensor` representing the value of
        `target_log_prob_fn(*current_state_parts)`. The only reason to specify
        this argument is to reduce TF graph size; within this function it is
        only used to give the initial step length `alpha` its (batch) shape.
      batch_rank: Integer. The number of leading axes in the state that
        correspond to independent batches.
      seed: Python integer to seed random number generators.
      name: Python `str` name prefixed to Ops created by this function.
        Default value: `None` (i.e., 'sample_next').

    Returns:
      A Python `list` holding, in order:
      proposed_state_parts: Python `list` of `Tensor`s representing the
        state(s) of the Markov chain(s) after the update. Has same shape as
        input `current_state_parts`.
      proposed_target_log_prob: `Tensor` representing the value of
        `target_log_prob_fn` at `proposed_state_parts`.
      bounds_satisfied: Boolean `Tensor` of the same shape as the log density.
        True indicates that an interval containing the slice for that batch
        member was found successfully.
      direction: Python `list` of `Tensor`s representing the direction along
        which the slice was sampled. Has the same shape and dtype(s) as
        `current_state_parts`.
      upper_bounds: `Tensor` of batch shape and the dtype of the input state.
        The upper bounds of the slices along the sampling direction.
      lower_bounds: `Tensor` of batch shape and the dtype of the input state.
        The lower bounds of the slices along the sampling direction.
    """
    with tf.name_scope(name, 'sample_next', [
            current_state_parts, step_sizes, max_doublings,
            current_target_log_prob, batch_rank
    ]):
        # First step: Choose a random direction.
        # Direction is a list of tensors. The i'th tensor should have the same
        # shape as the i'th state part.
        # NOTE(review): the same `seed` is also handed to
        # `ssu.slice_sampler_one_dim` below; confirm the two consumers do not
        # end up drawing correlated random numbers.
        direction = _choose_random_direction(current_state_parts,
                                             batch_rank=batch_rank,
                                             seed=seed)

        # Interpolates the step sizes for the chosen direction.
        # Applies an ellipsoidal interpolation to compute the step size for
        # the chosen direction. Suppose we are given step sizes for each
        # direction. Label these s_1, s_2, ... s_k. These are the step sizes to
        # use if moving in a direction parallel to one of the axes. Consider an
        # ellipsoid which intercepts the i'th axis at s_i. The step size for a
        # direction specified by the unit vector (n_1, n_2 ...n_k) is then
        # defined as the intersection of the line through this vector with
        # this ellipsoid.
        #
        # One can show that the length of the vector from the origin to the
        # intersection point is given by:
        # 1 / sqrt(n_1^2 / s_1^2  + n_2^2 / s_2^2  + ...).
        #
        # Proof:
        # The equation of the ellipsoid is:
        # Sum_i [x_i^2 / s_i^2 ] = 1. Let n be a unit direction vector. Points
        # along the line given by n may be parameterized as alpha*n where alpha
        # is the distance along the vector. Plugging this into the equation for
        # the ellipsoid, we get:
        # alpha^2 ( n_1^2 / s_1^2 + n_2^2 / s_2^2 + ...) = 1
        # so alpha = \sqrt { \frac{1} { ( n_1^2 / s_1^2 + n_2^2 / s_2^2 + ...) } }
        #
        # The per-part sums run over axes `batch_rank` onwards (the event
        # axes), so each component keeps only the batch axes.
        reduce_axes = [
            tf.range(batch_rank, tf.rank(dirn_part)) for dirn_part in direction
        ]
        components = [
            tf.reduce_sum((dirn_part / part_step_size)**2, axis=axes)
            for part_step_size, dirn_part, axes in zip(step_sizes, direction,
                                                       reduce_axes)
        ]
        # One interpolated step size per batch member.
        step_size = tf.rsqrt(tf.add_n(components))

        # Computes the rank of a tensor. Uses the static rank if possible and
        # falls back to the dynamic `tf.rank` otherwise.
        def _get_rank(x):
            return (len(x.shape.as_list())
                    if x.shape.dims is not None else tf.rank(x))

        # Loop-invariant; computed once rather than on every projection.
        state_part_ranks = [_get_rank(part) for part in current_state_parts]

        def _step_along_direction(alpha):
            """Maps the step length `alpha` to a full point in state space.

            Computes x_0 + alpha * direction where x_0 is the current state and
            direction is the direction chosen above.

            Args:
              alpha: A tensor of shape equal to the batch dimensions of
                `current_state_parts`.

            Returns:
              state_parts: Python `list` of `Tensor`s representing the state(s)
                of the Markov chain(s) for the given alpha along the chosen
                direction. Has the same shape as `current_state_parts`.
            """
            # Presumably `_right_pad` appends singleton axes to `alpha` up to
            # `final_rank` so it broadcasts against each state part -- the
            # helper is defined elsewhere; confirm.
            padded_alphas = [
                _right_pad(alpha, final_rank=part_rank)
                for part_rank in state_part_ranks
            ]

            state_parts = [
                state_part + padded_alpha * direction_part
                for state_part, direction_part, padded_alpha in zip(
                    current_state_parts, direction, padded_alphas)
            ]
            return state_parts

        def projected_target_log_prob_fn(alpha):
            """The target log density projected along the chosen direction.

            Args:
              alpha: A tensor of shape equal to the batch dimensions of
                `current_state_parts`.

            Returns:
              Target log density evaluated at x_0 + alpha * direction where x_0
              is the current state and direction is the direction chosen above.
              Has the same shape as `alpha`.
            """
            return target_log_prob_fn(*_step_along_direction(alpha))

        # alpha = 0 corresponds to the current state; the tensor takes its
        # shape from the log density and its dtype from the state.
        alpha_init = tf.zeros_like(
            current_target_log_prob,
            dtype=current_state_parts[0].dtype.base_dtype)
        [
            next_alpha, next_target_log_prob, bounds_satisfied, upper_bounds,
            lower_bounds
        ] = ssu.slice_sampler_one_dim(projected_target_log_prob_fn,
                                      x_initial=alpha_init,
                                      max_doublings=max_doublings,
                                      step_size=step_size,
                                      seed=seed)
        return [
            _step_along_direction(next_alpha), next_target_log_prob,
            bounds_satisfied, direction, upper_bounds, lower_bounds
        ]
def _sample_next(target_log_prob_fn,
                 current_state_parts,
                 step_sizes,
                 max_doublings,
                 current_target_log_prob,
                 batch_rank,
                 seed=None,
                 name=None):
  """Applies a single iteration of slice sampling update.

  Applies hit and run style slice sampling. Chooses a uniform random direction
  on the unit sphere in the event space. Applies the one dimensional slice
  sampling update along that direction.

  Args:
    target_log_prob_fn: Python callable which takes an argument like
      `*current_state_parts` and returns its (possibly unnormalized) log-density
      under the target distribution.
    current_state_parts: Python `list` of `Tensor`s representing the current
      state(s) of the Markov chain(s). The first `batch_rank` dimensions of
      each `Tensor` index different chains. Must have fully defined static
      shape.
    step_sizes: Python `list` of `Tensor`s, one per state part. Provides a
      measure of the width of the density along each axis and is used to find
      the slice bounds. Must broadcast with the shape of
      `current_state_parts`.
    max_doublings: Integer number of doublings to allow while locating the slice
      boundaries.
    current_target_log_prob: `Tensor` representing the value of
      `target_log_prob_fn(*current_state_parts)`. The only reason to specify
      this argument is to reduce TF graph size; within this function it is
      only used to give the initial step length `alpha` its (batch) shape.
    batch_rank: Integer. The number of leading axes in the state that
      correspond to independent batches.
    seed: Python integer to seed random number generators.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'sample_next').

  Returns:
    A Python `list` holding, in order:
    proposed_state_parts: Python `list` of `Tensor`s representing the
      state(s) of the Markov chain(s) after the update. Has same shape as
      input `current_state_parts`.
    proposed_target_log_prob: `Tensor` representing the value of
      `target_log_prob_fn` at `proposed_state_parts`.
    bounds_satisfied: Boolean `Tensor` of the same shape as the log density.
      True indicates that an interval containing the slice for that batch
      member was found successfully.
    direction: Python `list` of `Tensor`s representing the direction along
      which the slice was sampled. Has the same shape and dtype(s) as
      `current_state_parts`.
    upper_bounds: `Tensor` of batch shape and the dtype of the input state. The
      upper bounds of the slices along the sampling direction.
    lower_bounds: `Tensor` of batch shape and the dtype of the input state. The
      lower bounds of the slices along the sampling direction.
  """
  with tf.name_scope(
      name, 'sample_next',
      [current_state_parts, step_sizes, max_doublings, current_target_log_prob,
       batch_rank]):
    # First step: Choose a random direction.
    # Direction is a list of tensors. The i'th tensor should have the same shape
    # as the i'th state part.
    # NOTE(review): the same `seed` is also handed to
    # `ssu.slice_sampler_one_dim` below; confirm the two consumers do not end
    # up drawing correlated random numbers.
    direction = _choose_random_direction(current_state_parts,
                                         batch_rank=batch_rank,
                                         seed=seed)

    # Interpolates the step sizes for the chosen direction.
    # Applies an ellipsoidal interpolation to compute the step size for
    # the chosen direction. Suppose we are given step sizes for each direction.
    # Label these s_1, s_2, ... s_k. These are the step sizes to use if moving
    # in a direction parallel to one of the axes. Consider an ellipsoid which
    # intercepts the i'th axis at s_i. The step size for a direction specified
    # by the unit vector (n_1, n_2 ...n_k) is then defined as the intersection
    # of the line through this vector with this ellipsoid.
    #
    # One can show that the length of the vector from the origin to the
    # intersection point is given by:
    # 1 / sqrt(n_1^2 / s_1^2  + n_2^2 / s_2^2  + ...).
    #
    # Proof:
    # The equation of the ellipsoid is:
    # Sum_i [x_i^2 / s_i^2 ] = 1. Let n be a unit direction vector. Points
    # along the line given by n may be parameterized as alpha*n where alpha is
    # the distance along the vector. Plugging this into the equation for the
    # ellipsoid, we get:
    # alpha^2 ( n_1^2 / s_1^2 + n_2^2 / s_2^2 + ...) = 1
    # so alpha = \sqrt { \frac{1} { ( n_1^2 / s_1^2 + n_2^2 / s_2^2 + ...) } }
    #
    # The per-part sums run over axes `batch_rank` onwards (the event axes), so
    # each component keeps only the batch axes.
    reduce_axes = [tf.range(batch_rank, tf.rank(dirn_part))
                   for dirn_part in direction]

    components = [
        tf.reduce_sum((dirn_part / part_step_size) ** 2, axis=axes)
        for part_step_size, dirn_part, axes
        in zip(step_sizes, direction, reduce_axes)]
    # One interpolated step size per batch member.
    step_size = tf.rsqrt(tf.add_n(components))

    # Static ranks of the state parts (Python ints). They are loop-invariant,
    # and `_step_along_direction` is called both through
    # `projected_target_log_prob_fn` and for the final state, so compute once.
    state_part_ranks = [len(current_state_part.shape)
                        for current_state_part in current_state_parts]

    def _step_along_direction(alpha):
      """Maps the step length `alpha` to a full point in state space.

      Computes x_0 + alpha * direction where x_0 is the current state and
      direction is the direction chosen above.

      Args:
        alpha: A tensor of shape equal to the batch dimensions of
          `current_state_parts`.

      Returns:
        state_parts: Python `list` of `Tensor`s representing the state(s) of
          the Markov chain(s) for the given alpha along the chosen direction.
          Has the same shape as `current_state_parts`.
      """
      # Presumably `_right_pad_with_static_shape` appends singleton axes to
      # `alpha` up to `final_rank` so it broadcasts against each state part --
      # the helper is defined elsewhere; confirm.
      padded_alphas = [
          _right_pad_with_static_shape(alpha, final_rank=part_rank)
          for part_rank in state_part_ranks]

      state_parts = [
          state_part + padded_alpha * direction_part
          for state_part, direction_part, padded_alpha
          in zip(current_state_parts, direction, padded_alphas)]
      return state_parts

    def projected_target_log_prob_fn(alpha):
      """The target log density projected along the chosen direction.

      Args:
        alpha: A tensor of shape equal to the batch dimensions of
          `current_state_parts`.

      Returns:
        Target log density evaluated at x_0 + alpha * direction where x_0 is the
        current state and direction is the direction chosen above. Has the same
        shape as `alpha`.
      """
      return target_log_prob_fn(*_step_along_direction(alpha))

    # alpha = 0 corresponds to the current state; the tensor takes its shape
    # from the log density and its dtype from the state.
    alpha_init = tf.zeros_like(current_target_log_prob,
                               dtype=current_state_parts[0].dtype.base_dtype)
    [
        next_alpha,
        next_target_log_prob,
        bounds_satisfied,
        upper_bounds,
        lower_bounds
    ] = ssu.slice_sampler_one_dim(projected_target_log_prob_fn,
                                  x_initial=alpha_init,
                                  max_doublings=max_doublings,
                                  step_size=step_size, seed=seed)
    return [
        _step_along_direction(next_alpha),
        next_target_log_prob,
        bounds_satisfied,
        direction,
        upper_bounds,
        lower_bounds
    ]