Exemplo n.º 1
0
 def testReproducibility(self):
   """Streams built from the same seed and salt emit identical sequences."""
   num_draws = 50
   streams = [seed_stream.SeedStream(seed=4, salt="salt") for _ in range(3)]
   reference = streams[0]
   expected = [reference() for _ in range(num_draws)]
   for other in streams[1:]:
     self.assertEqual(expected, [other() for _ in range(num_draws)])
Exemplo n.º 2
0
 def testInitFromOtherSeedStream(self):
   """A stream built from another stream reproduces it iff the salt matches.

   Same salt: the child yields exactly the parent's draws. Different salt:
   the child's draws do not collide with the parent's.
   """
   num_draws = 50

   def take(strm):
     return [strm() for _ in range(num_draws)]

   parent = seed_stream.SeedStream(seed=4, salt="salt")
   same_salt_child = seed_stream.SeedStream(parent, salt="salt")
   other_salt_child = seed_stream.SeedStream(parent, salt="another salt")
   parent_out = take(parent)
   same_salt_out = take(same_salt_child)
   other_salt_out = take(other_salt_child)
   self.assertAllEqual(parent_out, same_salt_out)
   self.assertAllUnique(parent_out + other_salt_out)
Exemplo n.º 3
0
 def testNestingRobustness(self):
   # Streams seeded with values drawn from a master stream must not collide
   # with the master or with each other, even when they share a salt.
   master = seed_stream.SeedStream(seed=4, salt="salt")
   # Each child consumes one draw from the master, in creation order.
   children = [seed_stream.SeedStream(master(), salt="salt") for _ in range(2)]
   combined = [master() for _ in range(50)]
   for child in children:
     combined += [child() for _ in range(50)]
   self.assertAllUnique(combined)
Exemplo n.º 4
0
 def testNonRepetition(self):
   # The probability of repetitions in a short stream from a correct
   # PRNG is negligible; this test catches bugs that prevent state
   # updates.
   strm = seed_stream.SeedStream(seed=4, salt="salt")
   output = [strm() for _ in range(50)]
   # Deduplicating must not shrink the sorted sequence. `sorted` accepts the
   # set directly, so no intermediate list is needed.
   self.assertEqual(sorted(output), sorted(set(output)))
Exemplo n.º 5
0
def _choose_random_direction(current_state_parts, batch_rank, seed=None):
    """Chooses a random direction in the event space.

    Draws a standard-normal tensor for every state part, then rescales all
    parts jointly. After rescaling, for each batch member, the direction
    (taken over every part and every non-batch dimension) has unit Euclidean
    norm.

    Args:
      current_state_parts: Python `list` of `Tensor`s, one per state part.
        The parts are assumed to share a single floating dtype (their
        per-part sums of squares are added together).
      batch_rank: Number of leading dimensions treated as batch dimensions.
        All remaining dimensions of each part are reduced when computing the
        norm.
      seed: Optional seed passed to `SeedStream`. A separate op seed is
        derived for each part.

    Returns:
      rnd_direction_parts: Python `list` of `Tensor`s with the same shapes
        and base dtypes as `current_state_parts`.
    """
    seed_gen = seed_stream.SeedStream(seed, salt='_choose_random_direction')
    # Chooses the random directions across each of the input components.
    # `tf.shape` (instead of `.shape.as_list()`) supports parts whose static
    # shape is only partially known. Matching each part's dtype lets the
    # direction be combined with e.g. float64 states; TF does not implicitly
    # cast between float dtypes, so a hard-coded float32 would fail there.
    rnd_direction_parts = [
        tf.random_normal(tf.shape(current_state_part),
                         dtype=current_state_part.dtype.base_dtype,
                         seed=seed_gen())
        for current_state_part in current_state_parts
    ]

    # Sum squares over all of the input components. Note this takes all
    # components into account. `keepdims=True` keeps the result broadcastable
    # against every part, even when the parts have different event ranks.
    sum_squares = sum(
        tf.reduce_sum(rnd_direction**2.,
                      axis=tf.range(batch_rank, tf.rank(rnd_direction)),
                      keepdims=True) for rnd_direction in rnd_direction_parts)

    # Normalizes the random direction fragments by the shared per-batch norm,
    # computed once instead of once per part.
    norm = tf.sqrt(sum_squares)
    rnd_direction_parts = [
        rnd_direction / norm
        for rnd_direction in rnd_direction_parts
    ]

    return rnd_direction_parts
Exemplo n.º 6
0
    def __init__(self,
                 target_log_prob_fn,
                 step_size,
                 max_doublings,
                 seed=None,
                 name=None):
        """Initializes this transition kernel.

        Args:
          target_log_prob_fn: Python callable which takes an argument like
            `current_state` (or `*current_state` if it's a list) and returns
            its (possibly unnormalized) log-density under the target
            distribution.
          step_size: Scalar or `tf.Tensor` with the same dtype as, and a shape
            compatible with, the chain state. The size of the initial
            interval.
          max_doublings: Scalar positive int32 `tf.Tensor`. The maximum number
            of doublings to consider.
          seed: Python integer to seed the random number generator.
            Default value: `None`.
          name: Python `str` name prefixed to Ops created by this function.
            Default value: `None` (i.e., 'slice_sampler_kernel').
        """
        # Salted so the seeds this kernel draws differ from those of other
        # SeedStreams created from the same user-provided `seed`.
        self._seed_stream = seed_stream.SeedStream(seed, salt='slice_sampler')
        # Record the constructor arguments exactly as given.
        self._parameters = dict(target_log_prob_fn=target_log_prob_fn,
                                step_size=step_size,
                                max_doublings=max_doublings,
                                seed=seed,
                                name=name)
Exemplo n.º 7
0
def _sample_with_shrinkage(x_initial, target_log_prob, log_slice_heights,
                           step_size, lower_bounds, upper_bounds, seed=None,
                           name=None):
  """Samples from the slice by applying shrinkage for rejected points.

  Implements the one dimensional slice sampling algorithm of Neal (2003), with a
  doubling algorithm (Neal 2003 P715 Fig. 4), which doubles the size of the
  interval at each iteration and shrinkage (Neal 2003 P716 Fig. 5), which
  reduces the width of the slice when a selected point is rejected, by setting
  the relevant bound to that value. Randomly sampled points are checked for
  two criteria: that they lie within the slice and that they pass the
  acceptability check (Neal 2003 P717 Fig. 6), which tests that the new state
  could have generated the previous one.

  All chains are processed together inside a single `tf.while_loop`. The loop
  runs until every chain has accepted a point; there is no cap on the number
  of iterations.

  Args:
    x_initial: A tensor of any shape. The initial positions of the chains. This
      function assumes that all the dimensions of `x_initial` are batch
      dimensions (i.e. the event shape is `[]`).
    target_log_prob: Callable accepting a tensor like `x_initial` and returning
      a tensor containing the log density at that point of the same shape.
    log_slice_heights: Tensor of the same shape and dtype as the return value
      of `target_log_prob` when applied to `x_initial`. The log of the height of
      the chosen slice.
    step_size: A tensor of shape and dtype compatible with `x_initial`. The min
      interval size in the doubling algorithm.
    lower_bounds: Tensor of same shape and dtype as `x_initial`. Slice lower
      bounds for each chain.
    upper_bounds: Tensor of same shape and dtype as `x_initial`. Slice upper
      bounds for each chain.
    seed: (Optional) positive int. The random seed. If None, no seed is set.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'sample_with_shrinkage').

  Returns:
    x_proposed: A tensor of the same shape and dtype as `x_initial`. The next
      proposed state of the chain.
  """
  with tf.name_scope(name, 'sample_with_shrinkage',
                     [x_initial, log_slice_heights, step_size, lower_bounds,
                      upper_bounds]):
    # Salted so its seeds differ from other SeedStreams built from `seed`.
    seed_gen = seed_stream.SeedStream(seed, salt='_sample_with_shrinkage')
    # Keeps track of whether an acceptable sample has been found for the chain.
    found = tf.zeros_like(x_initial, dtype=tf.bool)
    # Keep looping while at least one chain has not accepted a point yet.
    cond = lambda found, *ignored_args: ~tf.reduce_all(found)
    x_next = tf.identity(x_initial)
    x_initial_shape = tf.shape(x_initial)
    x_initial_dtype = x_initial.dtype.base_dtype
    def _body(found, left, right, x_next):
      """Runs one shrinkage iteration for every chain not yet accepted."""
      # Uniform draws in [0, 1) (the `tf.random_uniform` default range) that
      # place each proposal inside the current interval [left, right).
      # NOTE(review): `seed_gen()` runs when the loop body is built, so every
      # iteration presumably shares one op-level seed; the stateful op still
      # produces fresh values per iteration.
      proportions = tf.random_uniform(x_initial_shape, dtype=x_initial_dtype,
                                      seed=seed_gen())
      # Chains that already accepted keep their accepted value.
      x_proposed = tf.where(~found, left + proportions * (right - left), x_next)
      accept_res = _test_acceptance(x_initial, target_log_prob=target_log_prob,
                                    decided=found,
                                    log_slice_heights=log_slice_heights,
                                    x_proposed=x_proposed, step_size=step_size,
                                    lower_bounds=left, upper_bounds=right)
      # A proposal lies inside the slice when its log-density exceeds the
      # slice height.
      boundary_test = log_slice_heights < target_log_prob(x_proposed)
      can_accept = boundary_test & accept_res
      next_found = found | can_accept
      # Note that it might seem that we are moving the left and right end points
      # even if the point has been accepted (which is contrary to the stated
      # algorithm in Neal). However, this does not matter because the endpoints
      # for points that have been already accepted are not used again so it
      # doesn't matter what we do with them.
      # Shrinkage: a proposal below the starting point becomes the new left
      # bound; otherwise it becomes the new right bound.
      next_left = tf.where(x_proposed < x_initial, x_proposed, left)
      next_right = tf.where(x_proposed >= x_initial, x_proposed, right)
      return next_found, next_left, next_right, x_proposed

    # Only the final `x_next` loop variable (the accepted points) is returned.
    return tf.while_loop(cond, body=_body, loop_vars=(found, lower_bounds,
                                                      upper_bounds, x_next))[-1]
Exemplo n.º 8
0
def slice_bounds_by_doubling(x_initial,
                             target_log_prob,
                             log_slice_heights,
                             max_doublings,
                             step_size,
                             seed=None,
                             name=None):
  """Returns the bounds of the slice at each stage of doubling procedure.

  Precomputes the x coordinates of the left (L) and right (R) endpoints of the
  interval `I` produced in the "doubling" algorithm [Neal 2003][1] P713. Note
  that we simultaneously compute all possible doubling values for each chain,
  for the reason that at small-medium densities, the gains from parallel
  evaluation might cause a speed-up, but this will be benchmarked against the
  while loop implementation.

  Args:
    x_initial: `tf.Tensor` of any shape and any real dtype consumable by
      `target_log_prob`. The initial points.
    target_log_prob: A callable taking a `tf.Tensor` of shape and dtype as
      `x_initial` and returning a tensor of the same shape. The log density of
      the target distribution.
    log_slice_heights: `tf.Tensor` with the same shape as `x_initial` and the
      same dtype as returned by `target_log_prob`. The log of the height of the
      slice for each chain. The values must be bounded above by
      `target_log_prob(x_initial)`.
    max_doublings: Scalar positive int32 `tf.Tensor`. The maximum number of
      doublings to consider.
    step_size: `tf.Tensor` (or a value convertible to one, such as a Python
      float) with shape compatible with `x_initial`. It is converted to the
      base dtype of `x_initial`. The size of the initial interval.
    seed: (Optional) positive int. The random seed. If None, no seed is set.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'slice_bounds_by_doubling').

  Returns:
    upper_bounds: A tensor of same shape and dtype as `x_initial`. Slice upper
      bounds for each chain.
    lower_bounds: A tensor of same shape and dtype as `x_initial`. Slice lower
      bounds for each chain.
    both_ok: A tensor of shape `x_initial` and boolean dtype. True for a chain
      when, at one or more doubling stages, both the lower and upper
      endpoints lie outside of the slice.

  #### References

  [1]: Radford M. Neal. Slice Sampling. The Annals of Statistics. 2003, Vol 31,
       No. 3 , 705-767.
       https://projecteuclid.org/download/pdf_1/euclid.aos/1056562461
  """
  with tf.name_scope(name, 'slice_bounds_by_doubling',
                     [x_initial, log_slice_heights, max_doublings, step_size]):
    seed_gen = seed_stream.SeedStream(seed, salt='slice_bounds_by_doubling')
    x_initial = tf.convert_to_tensor(x_initial)
    # Convert `step_size` as well: reading `.dtype.base_dtype` from a Python
    # float (or numpy value) would raise AttributeError. Matching `x_initial`'s
    # dtype is required anyway by the arithmetic below.
    step_size = tf.convert_to_tensor(step_size,
                                     dtype=x_initial.dtype.base_dtype,
                                     name='step_size')
    batch_shape = tf.shape(x_initial)
    dtype = step_size.dtype.base_dtype
    # x_initial - step_size * U with U ~ Uniform[0, 1): for a positive
    # step_size, the initial interval [left, left + step_size] is placed at
    # random so that it contains x_initial.
    left_endpoints = x_initial + step_size * tf.random_uniform(batch_shape,
                                                               minval=-1.0,
                                                               maxval=0.0,
                                                               dtype=dtype,
                                                               seed=seed_gen())

    # Compute the increments by which we need to step the upper and lower bounds
    # part of the doubling procedure.
    left_increments, widths = _left_doubling_increments(
        batch_shape, max_doublings, step_size, seed=seed_gen())
    # The left and right end points. Shape (max_doublings+1,) + batch_shape.
    left_endpoints -= left_increments
    right_endpoints = left_endpoints + widths

    # Test if these end points lie outside of the slice.
    # Checks if the end points of the slice are outside the graph of the pdf.
    left_ep_values = tf.map_fn(target_log_prob, left_endpoints)
    right_ep_values = tf.map_fn(target_log_prob, right_endpoints)
    left_ok = left_ep_values < log_slice_heights
    right_ok = right_ep_values < log_slice_heights
    both_ok = left_ok & right_ok

    # Flatten the batch dimensions: one row per doubling stage, one column per
    # chain.
    both_ok_f = tf.reshape(both_ok, [max_doublings + 1, -1])

    # NOTE(review): presumably selects, per chain, which doubling stage's
    # interval to use; see `_find_best_interval_idx`.
    best_interval_idx = _find_best_interval_idx(tf.to_int32(both_ok_f))

    # Formats the above index as required to use with gather_nd:
    # (stage index, flattened chain index) pairs.
    point_index_gather = tf.stack([best_interval_idx,
                                   tf.range(tf.size(best_interval_idx))],
                                  axis=1, name='point_index_gather')
    left_ep_f = tf.reshape(left_endpoints, [max_doublings + 1, -1])
    right_ep_f = tf.reshape(right_endpoints, [max_doublings + 1, -1])
    # The x values of the upper and lower bounds of the slices for each chain.
    lower_bounds = tf.reshape(tf.gather_nd(left_ep_f, point_index_gather),
                              batch_shape)
    upper_bounds = tf.reshape(tf.gather_nd(right_ep_f, point_index_gather),
                              batch_shape)
    # Reduce over the doubling-stage axis.
    both_ok = tf.reduce_any(both_ok, axis=0)
    return upper_bounds, lower_bounds, both_ok
Exemplo n.º 9
0
 def testSaltedDistinctness(self):
   """Same seed, different salts: the 100 combined draws are all unique."""
   streams = [seed_stream.SeedStream(seed=4, salt=salt)
              for salt in ("salt", "another salt")]
   draws = []
   for strm in streams:
     draws.extend(strm() for _ in range(50))
   self.assertAllUnique(draws)