def testReproducibility(self):
  """Checks that streams built from the same seed and salt agree."""
  first, second, third = (
      seed_stream.SeedStream(seed=4, salt="salt") for _ in range(3))
  expected = [first() for _ in range(50)]
  # Every identically configured stream must replay the same 50 seeds.
  for other in (second, third):
    self.assertEqual(expected, [other() for _ in range(50)])
def testInitFromOtherSeedStream(self):
  """Checks construction of a `SeedStream` from another `SeedStream`."""
  base = seed_stream.SeedStream(seed=4, salt="salt")
  same_salt = seed_stream.SeedStream(base, salt="salt")
  other_salt = seed_stream.SeedStream(base, salt="another salt")

  def draw(stream):
    return [stream() for _ in range(50)]

  base_seeds = draw(base)
  same_salt_seeds = draw(same_salt)
  other_salt_seeds = draw(other_salt)

  # A stream built from `base` with the same salt is expected to replay
  # the sequence of `base` ...
  self.assertAllEqual(base_seeds, same_salt_seeds)
  # ... while one built with a different salt must share no seed with it.
  self.assertAllUnique(base_seeds + other_salt_seeds)
def testNestingRobustness(self):
  """Checks that streams seeded from generated seeds do not collide.

  SeedStreams started from seeds generated by a master stream should not
  collide with the master or with each other, even if the salts are the
  same.
  """
  master = seed_stream.SeedStream(seed=4, salt="salt")
  # Each child consumes one seed of the master as its own starting seed.
  children = [seed_stream.SeedStream(master(), salt="salt") for _ in range(2)]

  seeds = [master() for _ in range(50)]
  for child in children:
    seeds.extend(child() for _ in range(50))
  self.assertAllUnique(seeds)
def testNonRepetition(self):
  """Checks that a short stream contains no repeated seeds.

  The probability of repetitions in a short stream from a correct PRNG is
  negligible; this test catches bugs that prevent state updates.
  """
  stream = seed_stream.SeedStream(seed=4, salt="salt")
  seeds = [stream() for _ in range(50)]
  # De-duplicating must not drop anything if all seeds are distinct.
  self.assertEqual(sorted(seeds), sorted(set(seeds)))
def _choose_random_direction(current_state_parts, batch_rank, seed=None):
  """Chooses a random direction in the event space.

  Args:
    current_state_parts: Python list of `Tensor`s. One direction component
      with the same shape is drawn for every entry.
    batch_rank: Scalar integer. Axes `batch_rank` and above of every
      component are reduced when computing the normalizer (presumably the
      number of leading batch dimensions; confirm against the caller).
    seed: (Optional) seed forwarded to `SeedStream`. Default value: `None`.

  Returns:
    rnd_direction_parts: Python list of `float32` `Tensor`s, one per entry
      of `current_state_parts`, each divided by the square root of the
      squares summed over all components.
  """
  seed_gen = seed_stream.SeedStream(seed, salt='_choose_random_direction')

  # Chooses the random directions across each of the input components.
  # The dynamic shape (`tf.shape`) is used instead of `shape.as_list()`:
  # the static shape raises for unknown rank and contains `None` for
  # unknown dimensions, neither of which `tf.random_normal` can consume.
  # NOTE: the draw is always `float32`, whatever the dtype of the state.
  rnd_direction_parts = [
      tf.random_normal(tf.shape(current_state_part),
                       dtype=tf.float32,
                       seed=seed_gen())
      for current_state_part in current_state_parts
  ]

  # Sum squares over all of the input components. Note this takes all
  # components into account.
  sum_squares = sum(
      tf.reduce_sum(rnd_direction**2.,
                    axis=tf.range(batch_rank, tf.rank(rnd_direction)),
                    keepdims=True)
      for rnd_direction in rnd_direction_parts)

  # Normalizes the random direction fragments.
  rnd_direction_parts = [
      rnd_direction / tf.sqrt(sum_squares)
      for rnd_direction in rnd_direction_parts
  ]

  return rnd_direction_parts
def __init__(self,
             target_log_prob_fn,
             step_size,
             max_doublings,
             seed=None,
             name=None):
  """Initializes this transition kernel.

  Stores the constructor arguments and creates the seed stream used by the
  kernel; a constructor returns nothing.

  Args:
    target_log_prob_fn: Python callable which takes an argument like
      `current_state` (or `*current_state` if it's a list) and returns its
      (possibly unnormalized) log-density under the target distribution.
    step_size: Scalar or `tf.Tensor` with same dtype as and shape compatible
      with the chain state (`x_initial` in the slice-bound helpers). The
      size of the initial interval.
    max_doublings: Scalar positive int32 `tf.Tensor`. The maximum number of
      doublings to consider.
    seed: Python integer to seed the random number generator.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'slice_sampler_kernel').
  """
  self._seed_stream = seed_stream.SeedStream(seed, salt='slice_sampler')
  # Keep the raw constructor arguments together for later lookup.
  self._parameters = {
      'target_log_prob_fn': target_log_prob_fn,
      'step_size': step_size,
      'max_doublings': max_doublings,
      'seed': seed,
      'name': name,
  }
def _sample_with_shrinkage(x_initial,
                           target_log_prob,
                           log_slice_heights,
                           step_size,
                           lower_bounds,
                           upper_bounds,
                           seed=None,
                           name=None):
  """Samples from the slice by applying shrinkage for rejected points.

  Implements the one dimensional slice sampling algorithm of Neal (2003),
  with a doubling algorithm (Neal 2003 P715 Fig. 4), which doubles the size
  of the interval at each iteration, and shrinkage (Neal 2003 P716 Fig. 5),
  which reduces the width of the slice when a selected point is rejected, by
  setting the relevant bound to that value. Randomly sampled points are
  checked for two criteria: that they lie within the slice and that they
  pass the acceptability check (Neal 2003 P717 Fig. 6), which tests that the
  new state could have generated the previous one.

  Args:
    x_initial: A tensor of any shape. The initial positions of the chains.
      This function assumes that all the dimensions of `x_initial` are batch
      dimensions (i.e. the event shape is `[]`).
    target_log_prob: Callable accepting a tensor like `x_initial` and
      returning a tensor containing the log density at that point of the
      same shape.
    log_slice_heights: Tensor of the same shape and dtype as the return
      value of `target_log_prob` when applied to `x_initial`. The log of the
      height of the chosen slice.
    step_size: A tensor of shape and dtype compatible with `x_initial`. The
      min interval size in the doubling algorithm.
    lower_bounds: Tensor of same shape and dtype as `x_initial`. Slice lower
      bounds for each chain.
    upper_bounds: Tensor of same shape and dtype as `x_initial`. Slice upper
      bounds for each chain.
    seed: (Optional) positive int. The random seed. If None, no seed is set.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'sample_with_shrinkage').

  Returns:
    x_proposed: A tensor of the same shape and dtype as `x_initial`. The
      next proposed state of the chain.
  """
  with tf.name_scope(name, 'sample_with_shrinkage',
                     [x_initial, log_slice_heights, step_size, lower_bounds,
                      upper_bounds]):
    seed_gen = seed_stream.SeedStream(seed, salt='_sample_with_shrinkage')
    # Per-chain flag: has an acceptable sample been found yet?
    settled = tf.zeros_like(x_initial, dtype=tf.bool)
    x_start = tf.identity(x_initial)
    sample_shape = tf.shape(x_initial)
    sample_dtype = x_initial.dtype.base_dtype

    def _any_unsettled(settled, *unused_loop_vars):
      """Keeps looping while at least one chain lacks an accepted point."""
      return ~tf.reduce_all(settled)

    def _shrink_step(settled, left, right, x_kept):
      """Proposes a point per chain and shrinks the interval around it."""
      fractions = tf.random_uniform(sample_shape,
                                    dtype=sample_dtype,
                                    seed=seed_gen())
      # Chains that are already settled keep their previous value.
      x_proposed = tf.where(~settled,
                            left + fractions * (right - left),
                            x_kept)
      passes_acceptance = _test_acceptance(
          x_initial,
          target_log_prob=target_log_prob,
          decided=settled,
          log_slice_heights=log_slice_heights,
          x_proposed=x_proposed,
          step_size=step_size,
          lower_bounds=left,
          upper_bounds=right)
      inside_slice = log_slice_heights < target_log_prob(x_proposed)
      now_settled = settled | (inside_slice & passes_acceptance)
      # Note that it might seem that we are moving the left and right end
      # points even if the point has been accepted (which is contrary to the
      # stated algorithm in Neal). However, this does not matter because the
      # endpoints for points that have been already accepted are not used
      # again so it doesn't matter what we do with them.
      shrunk_left = tf.where(x_proposed < x_initial, x_proposed, left)
      shrunk_right = tf.where(x_proposed >= x_initial, x_proposed, right)
      return now_settled, shrunk_left, shrunk_right, x_proposed

    loop_results = tf.while_loop(
        _any_unsettled,
        body=_shrink_step,
        loop_vars=(settled, lower_bounds, upper_bounds, x_start))
    # The last loop variable carries the proposed state.
    return loop_results[-1]
def slice_bounds_by_doubling(x_initial,
                             target_log_prob,
                             log_slice_heights,
                             max_doublings,
                             step_size,
                             seed=None,
                             name=None):
  """Returns the bounds of the slice at each stage of doubling procedure.

  Precomputes the x coordinates of the left (L) and right (R) endpoints of
  the interval `I` produced in the "doubling" algorithm [Neal 2003][1] P713.
  Note that we simultaneously compute all possible doubling values for each
  chain, for the reason that at small-medium densities, the gains from
  parallel evaluation might cause a speed-up, but this will be benchmarked
  against the while loop implementation.

  Args:
    x_initial: `tf.Tensor` of any shape and any real dtype consumable by
      `target_log_prob`. The initial points.
    target_log_prob: A callable taking a `tf.Tensor` of shape and dtype as
      `x_initial` and returning a tensor of the same shape. The log density
      of the target distribution.
    log_slice_heights: `tf.Tensor` with the same shape as `x_initial` and
      the same dtype as returned by `target_log_prob`. The log of the height
      of the slice for each chain. The values must be bounded above by
      `target_log_prob(x_initial)`.
    max_doublings: Scalar positive int32 `tf.Tensor`. The maximum number of
      doublings to consider.
    step_size: Scalar or `tf.Tensor` with same dtype as and shape compatible
      with `x_initial`. The size of the initial interval. Non-tensor values
      are converted to the dtype of `x_initial`.
    seed: (Optional) positive int. The random seed. If None, no seed is set.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'slice_bounds_by_doubling').

  Returns:
    upper_bounds: A tensor of same shape and dtype as `x_initial`. Slice
      upper bounds for each chain.
    lower_bounds: A tensor of same shape and dtype as `x_initial`. Slice
      lower bounds for each chain.
    both_ok: A tensor of shape `x_initial` and boolean dtype. True for a
      chain when, for at least one doubling stage, both the left and the
      right endpoint lie outside of the slice.

  #### References

  [1]: Radford M. Neal. Slice Sampling. The Annals of Statistics. 2003, Vol
       31, No. 3 , 705-767.
       https://projecteuclid.org/download/pdf_1/euclid.aos/1056562461
  """
  with tf.name_scope(name, 'slice_bounds_by_doubling',
                     [x_initial, log_slice_heights, max_doublings, step_size]):
    seed_gen = seed_stream.SeedStream(seed, salt='slice_bounds_by_doubling')
    x_initial = tf.convert_to_tensor(x_initial)
    # `step_size.dtype` is read below, so a Python scalar must be turned
    # into a tensor first; tensors of the matching dtype pass through as is.
    step_size = tf.convert_to_tensor(step_size,
                                     dtype=x_initial.dtype.base_dtype)
    batch_shape = tf.shape(x_initial)
    dtype = step_size.dtype.base_dtype
    # Place the initial interval of width `step_size` at a uniformly random
    # offset such that it contains `x_initial`.
    left_endpoints = x_initial + step_size * tf.random_uniform(
        batch_shape, minval=-1.0, maxval=0.0, dtype=dtype, seed=seed_gen())

    # Compute the increments by which we need to step the upper and lower
    # bounds part of the doubling procedure.
    left_increments, widths = _left_doubling_increments(
        batch_shape, max_doublings, step_size, seed=seed_gen())
    # The left and right end points. Shape (max_doublings+1,) + batch_shape.
    left_endpoints -= left_increments
    right_endpoints = left_endpoints + widths

    # Test if these end points lie outside of the slice.
    # Checks if the end points of the slice are outside the graph of the pdf.
    left_ep_values = tf.map_fn(target_log_prob, left_endpoints)
    right_ep_values = tf.map_fn(target_log_prob, right_endpoints)
    left_ok = left_ep_values < log_slice_heights
    right_ok = right_ep_values < log_slice_heights
    both_ok = left_ok & right_ok

    # Flatten the batch dimensions so each column is one chain.
    both_ok_f = tf.reshape(both_ok, [max_doublings + 1, -1])

    # Per-chain index of the doubling stage to use; the selection rule lives
    # in `_find_best_interval_idx`.
    best_interval_idx = _find_best_interval_idx(tf.to_int32(both_ok_f))

    # Formats the above index as required to use with gather_nd.
    point_index_gather = tf.stack(
        [best_interval_idx, tf.range(tf.size(best_interval_idx))],
        axis=1,
        name='point_index_gather')
    left_ep_f = tf.reshape(left_endpoints, [max_doublings + 1, -1])
    right_ep_f = tf.reshape(right_endpoints, [max_doublings + 1, -1])
    # The x values of the upper and lower bounds of the slices for each
    # chain.
    lower_bounds = tf.reshape(tf.gather_nd(left_ep_f, point_index_gather),
                              batch_shape)
    upper_bounds = tf.reshape(tf.gather_nd(right_ep_f, point_index_gather),
                              batch_shape)
    # Collapse the doubling axis: did any stage bracket the slice?
    both_ok = tf.reduce_any(both_ok, axis=0)
    return upper_bounds, lower_bounds, both_ok
def testSaltedDistinctness(self):
  """Checks that equal seeds with different salts share no output."""
  salted = seed_stream.SeedStream(seed=4, salt="salt")
  resalted = seed_stream.SeedStream(seed=4, salt="another salt")
  seeds = [salted() for _ in range(50)]
  seeds += [resalted() for _ in range(50)]
  self.assertAllUnique(seeds)