def test_static_num_iters(self, iters):
  """Checks `smart_for_loop` body-call counts for statically known bounds.

  The same trip count is supplied as a Python int, an int64 constant and an
  int32 constant. The expected number of Python-level `body` invocations
  depends on the execution regime.
  """
  # The bounds are built here rather than through @parameterized, because
  # tf.constant calls made at decoration time would run outside the Eager mode
  # that @test_util.test_all_tf_execution_regimes creates.
  candidate_bounds = [
      iters,
      tf.constant(iters, dtype=tf.int64),
      tf.constant(iters, dtype=tf.int32),
  ]
  for num_iter in candidate_bounds:
    counter = collections.Counter()

    def body(x):
      counter['body_calls'] += 1
      return [x + 1]

    result = util.smart_for_loop(
        loop_num_iter=num_iter,
        body_fn=body,
        initial_loop_vars=[tf.constant(1)])

    if JAX_MODE:
      # JAX always traces loop bodies exactly once.
      expected_calls = 1
    elif tf.executing_eagerly():
      expected_calls = iters
    else:
      expected_calls = 1 if iters > 0 else 0
    self.assertEqual(expected_calls, counter['body_calls'])
    self.assertAllClose([iters + 1], self.evaluate(result))
def _trace_scan_fn(state_and_results, num_steps):
  """Applies `kernel.one_step` `num_steps` times to `state_and_results`.

  `kernel`, `parallel_iterations` and `mcmc_util` are resolved from the
  enclosing scope, which is not visible here.

  Args:
    state_and_results: Iterable of loop variables passed positionally to
      `kernel.one_step`. Presumably `(state, kernel_results)`; the unpacking
      below requires the loop to yield exactly two values (confirm against
      the caller).
    num_steps: Number of iterations forwarded to `smart_for_loop`.

  Returns:
    Tuple `(next_state, current_kernel_results)` produced by the final step.
  """
  loop_vars = list(state_and_results)
  stepped = mcmc_util.smart_for_loop(
      loop_num_iter=num_steps,
      body_fn=kernel.one_step,
      initial_loop_vars=loop_vars,
      parallel_iterations=parallel_iterations)
  next_state, current_kernel_results = stepped
  return next_state, current_kernel_results
def test_tf_while_loop(self):
  """Checks `smart_for_loop` with a bound that is not known statically."""
  iters = 10
  # placeholder_with_default hides the value from static inference.
  num_iter = tf1.placeholder_with_default(np.int64(iters), shape=())
  counter = collections.Counter()

  def body(x):
    counter['body_calls'] += 1
    return [x + 1]

  result = util.smart_for_loop(
      loop_num_iter=num_iter,
      body_fn=body,
      initial_loop_vars=[tf.constant(1)])

  # Eagerly the body is expected once per iteration; outside eager mode the
  # test expects exactly one (tracing) call.
  expected_calls = iters if tf.executing_eagerly() else 1
  self.assertEqual(expected_calls, counter['body_calls'])
  self.assertAllClose([11], self.evaluate(result))
def test_python_for_loop(self):
  """Checks that a constant int64 bound invokes `body` once per iteration."""
  # NOTE(review): another method named `test_python_for_loop` appears nearby;
  # if both are in the same class, the later definition shadows this one and
  # this test never runs. Confirm and rename one of them if so.
  num_iter = tf.constant(10, dtype=tf.int64)
  counter = collections.Counter()

  def body(x):
    counter['body_calls'] += 1
    return [x + 1]

  result = util.smart_for_loop(
      loop_num_iter=num_iter,
      body_fn=body,
      initial_loop_vars=[tf.constant(1)])

  self.assertEqual(10, counter['body_calls'])
  self.assertAllClose([11], self.evaluate(result))
def test_python_for_loop(self):
  """Checks one `body` call per iteration for Python and constant bounds."""
  # The bounds are built here rather than through @parameterized: tf.constant
  # calls made at decoration time would run outside the Eager mode that
  # @test_util.test_all_tf_execution_regimes creates, which TF rejects.
  bounds = [
      10,
      tf.constant(10, dtype=tf.int64),
      tf.constant(10, dtype=tf.int32),
  ]
  for num_iter in bounds:
    counter = collections.Counter()

    def body(x):
      counter['body_calls'] += 1
      return [x + 1]

    result = util.smart_for_loop(
        loop_num_iter=num_iter,
        body_fn=body,
        initial_loop_vars=[tf.constant(1)])

    self.assertEqual(10, counter['body_calls'])
    self.assertAllClose([11], self.evaluate(result))
def test_unroll_threshold(self):
  """Checks `body` call counts when `unroll_threshold` equals the trip count."""
  iters = 50
  counter = collections.Counter()

  def body(x):
    counter['body_calls'] += 1
    return [x + 1]

  result = util.smart_for_loop(
      loop_num_iter=iters,
      body_fn=body,
      initial_loop_vars=[tf.constant(1)],
      unroll_threshold=iters)

  # JAX always traces loop bodies exactly once; otherwise the test expects
  # one Python-level call per iteration.
  expected_calls = 1 if JAX_MODE else iters
  self.assertEqual(expected_calls, counter['body_calls'])
  self.assertAllClose([iters + 1], self.evaluate(result))
def step_kernel(
    num_steps,
    current_state,
    previous_kernel_results=None,
    kernel=None,
    return_final_kernel_results=False,
    parallel_iterations=10,
    seed=None,
    name=None,
):
  """Takes `num_steps` repeated `TransitionKernel` steps from `current_state`.

  This is meant to be a minimal driver for executing `TransitionKernel`s; for
  something more featureful, see `sample_chain`.

  Args:
    num_steps: Integer number of Markov chain steps.
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s).
    previous_kernel_results: A `Tensor` or a nested collection of `Tensor`s.
      Warm-start for the auxiliary state needed by the given `kernel`.
      If not supplied, `step_kernel` will cold-start with
      `kernel.bootstrap_results`.
    kernel: An instance of `tfp.mcmc.TransitionKernel` which implements one
      step of the Markov chain. Required despite the `None` default.
    return_final_kernel_results: If `True`, then the final kernel results are
      returned alongside the chain state after `num_steps` steps are taken.
      This can be useful to inspect the final auxiliary state, or for a later
      warm restart.
    parallel_iterations: The number of iterations allowed to run in parallel.
      It must be a positive integer. See `tf.while_loop` for more details.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'mcmc_step_kernel').

  Returns:
    next_state: Markov chain state after `num_steps` steps are taken, of
      identical type as `current_state`.
    final_kernel_results: kernel results, as supplied by `kernel.one_step` after
      `num_steps` steps are taken. This is only returned if
      `return_final_kernel_results` is `True`.
  """
  # Record whether the caller asked for seeded stepping before the seed is
  # sanitized, since sanitization replaces `None` with a concrete seed.
  is_seeded = seed is not None
  seed = samplers.sanitize_seed(seed, salt='experimental.mcmc.step_kernel')

  if not kernel.is_calibrated:
    warnings.warn('supplied `TransitionKernel` is not calibrated. Markov '
                  'chain may not converge to intended target distribution.')

  with tf.name_scope(name or 'mcmc_step_kernel'):
    num_steps = tf.convert_to_tensor(
        num_steps, dtype=tf.int32, name='num_steps')
    current_state = tf.nest.map_structure(
        lambda x: tf.convert_to_tensor(x, name='current_state'),
        current_state)
    if previous_kernel_results is None:
      previous_kernel_results = kernel.bootstrap_results(current_state)

    def _seeded_one_step(seed, *state_and_results):
      """One loop iteration: threads the seed and calls `kernel.one_step`."""
      # When seeded, split off a fresh per-step seed and carry the remainder to
      # the next iteration. When unseeded, the seed is passed along unchanged
      # and `one_step` is called without a `seed` argument.
      step_seed, passalong_seed = (samplers.split_seed(seed)
                                   if is_seeded else (None, seed))
      one_step_kwargs = dict(seed=step_seed) if is_seeded else {}
      return [passalong_seed] + list(
          kernel.one_step(*state_and_results, **one_step_kwargs))

    # The seed is carried as the first loop variable so each iteration can
    # derive its own step seed.
    _, next_state, final_kernel_results = mcmc_util.smart_for_loop(
        loop_num_iter=num_steps,
        body_fn=_seeded_one_step,
        initial_loop_vars=[seed, current_state, previous_kernel_results],
        parallel_iterations=parallel_iterations)

  # return semantics are simple enough to not warrant the use of named tuples
  # as in `sample_chain`
  if return_final_kernel_results:
    return next_state, final_kernel_results
  else:
    return next_state