Example #1
0
 def _trace_scan_fn(seed_state_and_results, num_steps):
     """Runs `_seeded_one_step` `num_steps` times via `loop_util.smart_for_loop`.

     Args:
       seed_state_and_results: Iterable of the three loop variables
         (seed, chain state, kernel results) to start the loop from.
       num_steps: Number of loop iterations to run.

     Returns:
       Tuple `(seed, next_state, current_kernel_results)` holding the loop
       variables after the final iteration.
     """
     initial_vars = list(seed_state_and_results)
     final_vars = loop_util.smart_for_loop(
         loop_num_iter=num_steps,
         body_fn=_seeded_one_step,
         initial_loop_vars=initial_vars,
         parallel_iterations=parallel_iterations)
     # Unpacking (rather than returning `final_vars` directly) keeps the
     # three-element check and the tuple return type.
     seed, next_state, current_kernel_results = final_vars
     return seed, next_state, current_kernel_results
Example #2
0
    def test_static_num_iters(self, iters):
        """Checks `smart_for_loop` when the iteration count is static.

        The count is supplied three ways: as a Python int, and as int64 and
        int32 `tf.constant`s. How many times the body is invoked depends on the
        execution regime (JAX, TF eager, TF graph). The loop result must always
        be `iters + 1`.
        """
        # The loop counts are built here rather than via @parameterized because
        # the tf.constants would otherwise be executed outside the Eager mode
        # that @test_util.test_all_tf_execution_regimes creates.
        candidate_counts = (
            iters,
            tf.constant(iters, dtype=tf.int64),
            tf.constant(iters, dtype=tf.int32),
        )
        for num_iter in candidate_counts:
            calls = collections.Counter()

            def body(x):
                calls['body_calls'] += 1
                return [x + 1]

            result = loop_util.smart_for_loop(
                loop_num_iter=num_iter,
                body_fn=body,
                initial_loop_vars=[tf.constant(1)])

            if JAX_MODE:
                # JAX always traces loop bodies exactly once.
                expected_calls = 1
            elif tf.executing_eagerly():
                expected_calls = iters
            else:
                expected_calls = 1 if iters > 0 else 0
            self.assertEqual(expected_calls, calls['body_calls'])
            self.assertAllClose([iters + 1], self.evaluate(result))
Example #3
0
  def test_placeholder_num_iters(self):
    """Checks `smart_for_loop` when the iteration count is a placeholder.

    Under TF eager execution (and not JAX) the body runs once per iteration.
    In every other regime it is invoked exactly once. The loop result must
    be `iters + 1`.
    """
    iters = 10
    n = tf1.placeholder_with_default(np.int64(iters), shape=())
    counter = collections.Counter()

    def body(x):
      counter['body_calls'] += 1
      return [x + 1]

    result = loop_util.smart_for_loop(
        loop_num_iter=n, body_fn=body, initial_loop_vars=[tf.constant(1)])
    if tf.executing_eagerly() and not JAX_MODE:
      # Eager TF executes the Python body once per iteration.
      self.assertEqual(iters, counter['body_calls'])
    else:
      # Graph-mode TF and JAX (which always traces loops) call the body once.
      self.assertEqual(1, counter['body_calls'])
    # Derived from `iters` (start value 1, incremented `iters` times) rather
    # than hard-coded, matching the sibling loop tests.
    self.assertAllClose([iters + 1], self.evaluate(result))
Example #4
0
  def test_unroll_threshold(self):
    """Checks `smart_for_loop` with `unroll_threshold` equal to the count.

    Outside JAX the body is expected to be invoked `iters` times. Under JAX
    it is traced exactly once. The loop result must be `iters + 1`.
    """
    iters = 50
    tally = collections.Counter()

    def body(x):
      tally['body_calls'] += 1
      return [x + 1]

    result = loop_util.smart_for_loop(
        loop_num_iter=iters,
        body_fn=body,
        initial_loop_vars=[tf.constant(1)],
        unroll_threshold=iters)
    # JAX always traces loop bodies exactly once.
    expected_calls = 1 if JAX_MODE else iters
    self.assertEqual(expected_calls, tally['body_calls'])
    self.assertAllClose([iters + 1], self.evaluate(result))
Example #5
0
def step_kernel(
    num_steps,
    current_state,
    previous_kernel_results=None,
    kernel=None,
    return_final_kernel_results=False,
    parallel_iterations=10,
    seed=None,
    name=None,
):
    """Takes `num_steps` repeated `TransitionKernel` steps from `current_state`.

    This is meant to be a minimal driver for executing `TransitionKernel`s; for
    something more featureful, see `sample_chain`.

    Args:
      num_steps: Integer number of Markov chain steps.
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s).
      previous_kernel_results: A `Tensor` or a nested collection of `Tensor`s.
        Warm-start for the auxiliary state needed by the given `kernel`.
        If not supplied, `step_kernel` will cold-start with
        `kernel.bootstrap_results`.
      kernel: An instance of `tfp.mcmc.TransitionKernel` which implements one
        step of the Markov chain. Required, despite the `None` default.
      return_final_kernel_results: If `True`, then the final kernel results are
        returned alongside the chain state after `num_steps` steps are taken.
        This can be useful to inspect the final auxiliary state, or for a later
        warm restart.
      parallel_iterations: The number of iterations allowed to run in parallel.
        It must be a positive integer. See `tf.while_loop` for more details.
      seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
      name: Python `str` name prefixed to Ops created by this function.
        Default value: `None` (i.e., 'mcmc_step_kernel').

    Returns:
      next_state: Markov chain state after `num_steps` steps are taken, of
        identical type as `current_state`.
      final_kernel_results: kernel results, as supplied by `kernel.one_step`
        after `num_steps` steps are taken. This is only returned if
        `return_final_kernel_results` is `True`.

    Raises:
      ValueError: if `kernel` is `None`.
    """
    # `kernel` must default to None to keep the positional signature, but it is
    # mandatory; fail with a clear message instead of an AttributeError below.
    if kernel is None:
        raise ValueError(
            '`kernel` must be a `TransitionKernel` instance; got `None`.')

    # Remember whether the caller supplied a seed before it is sanitized into a
    # concrete value; unseeded kernels are then stepped without a `seed` kwarg.
    is_seeded = seed is not None
    seed = samplers.sanitize_seed(seed, salt='experimental.mcmc.step_kernel')

    if not kernel.is_calibrated:
        warnings.warn(
            'supplied `TransitionKernel` is not calibrated. Markov '
            'chain may not converge to intended target distribution.')
    with tf.name_scope(name or 'mcmc_step_kernel'):
        num_steps = ps.convert_to_shape_tensor(num_steps,
                                               dtype_hint=tf.int32,
                                               name='num_steps')
        current_state = tf.nest.map_structure(
            lambda x: tf.convert_to_tensor(x, name='current_state'),
            current_state)
        if previous_kernel_results is None:
            previous_kernel_results = kernel.bootstrap_results(current_state)

        def _seeded_one_step(seed, *state_and_results):
            """Loop body: one kernel step, threading the seed as a loop var.

            When seeded, the incoming seed is split into one seed for this step
            and one passed along to the next iteration. When unseeded, the seed
            is passed along unchanged and `one_step` gets no `seed` argument.
            """
            step_seed, passalong_seed = (samplers.split_seed(seed)
                                         if is_seeded else (None, seed))
            one_step_kwargs = {'seed': step_seed} if is_seeded else {}
            return [passalong_seed] + list(
                kernel.one_step(*state_and_results, **one_step_kwargs))

        # Loop variables are (seed, state, kernel results); the final seed is
        # discarded.
        _, next_state, final_kernel_results = loop_util.smart_for_loop(
            loop_num_iter=num_steps,
            body_fn=_seeded_one_step,
            initial_loop_vars=[seed, current_state, previous_kernel_results],
            parallel_iterations=parallel_iterations)

        # Return semantics are simple enough to not warrant the use of named
        # tuples as in `sample_chain`.
        if return_final_kernel_results:
            return next_state, final_kernel_results
        return next_state