Ejemplo n.º 1
0
 def _trace_scan_fn(state_and_results, num_steps):
     """Apply `kernel.one_step` `num_steps` times via `mcmc_util.smart_for_loop`.

     Args:
       state_and_results: Iterable whose items seed the loop variables; it is
         materialized into a list before being handed to the loop helper.
       num_steps: Number of loop iterations to run.

     Returns:
       A `(next_state, kernel_results)` tuple unpacked from the loop output.
     """
     loop_vars = list(state_and_results)
     loop_output = mcmc_util.smart_for_loop(
         loop_num_iter=num_steps,
         body_fn=kernel.one_step,
         initial_loop_vars=loop_vars,
         parallel_iterations=parallel_iterations)
     # The loop output must contain exactly two items: state and kernel results.
     new_state, new_kernel_results = loop_output
     return new_state, new_kernel_results
Ejemplo n.º 2
0
 def _scan_body(args_list, num_steps):
     """Closure used as the body function of a `tf.scan`.

     Repeatedly applies `kernel.one_step`, `num_steps` times, starting from
     `args_list`, by delegating to `mcmc_util.smart_for_loop`.

     Args:
       args_list: Initial loop variables passed straight to the loop helper.
       num_steps: Number of loop iterations to run.

     Returns:
       A two-element list `[next_state, kernel_results]`.
     """
     loop_output = mcmc_util.smart_for_loop(
         loop_num_iter=num_steps,
         body_fn=kernel.one_step,
         initial_loop_vars=args_list,
         parallel_iterations=parallel_iterations)
     state_out, results_out = loop_output
     return [state_out, results_out]
Ejemplo n.º 3
0
 def _scan_body(args_list, num_steps):
   """Closure used as the body function of a `tf.scan`.

   Runs `kernel.one_step` `num_steps` times starting from `args_list`, using
   `mcmc_util.smart_for_loop` to drive the iteration.

   Args:
     args_list: Initial loop variables handed to the loop helper unchanged.
     num_steps: Number of loop iterations to run.

   Returns:
     A two-element list `[next_state, kernel_results]`.
   """
   loop_output = mcmc_util.smart_for_loop(
       loop_num_iter=num_steps,
       body_fn=kernel.one_step,
       initial_loop_vars=args_list,
       parallel_iterations=parallel_iterations,
   )
   state_out, results_out = loop_output
   return [state_out, results_out]
Ejemplo n.º 4
0
    def test_python_for_loop(self):
        """A constant trip count should make `smart_for_loop` call the body per step.

        The iteration count is a `tf.constant`. The test expects the loop body to
        be invoked once for each of the 10 iterations, and the final loop
        variable to be 1 + 10 == 11.
        """
        num_iters = tf.constant(10, dtype=tf.int64)
        calls = collections.Counter()

        def increment(value):
            # Count each invocation so the test can see how often it ran.
            calls['body_calls'] += 1
            return [value + 1]

        loop_out = smart_for_loop(
            loop_num_iter=num_iters,
            body_fn=increment,
            initial_loop_vars=[tf.constant(1)])
        self.assertEqual(10, calls['body_calls'])
        self.assertAllClose([11], self.evaluate(loop_out))
Ejemplo n.º 5
0
    def test_tf_while_loop(self):
        """A trip count with no static value should make the body be built only once.

        The iteration count comes from `tf.placeholder_with_default`. The test
        name indicates this should take the `tf.while_loop` path, and the test
        expects the Python body to be invoked exactly once. The evaluated
        result should still be 1 + 10 == 11.
        """
        num_iters = tf.placeholder_with_default(input=np.int64(10), shape=())
        calls = collections.Counter()

        def increment(value):
            # Count invocations: a graph loop should trace this a single time.
            calls['body_calls'] += 1
            return [value + 1]

        loop_out = smart_for_loop(
            loop_num_iter=num_iters,
            body_fn=increment,
            initial_loop_vars=[tf.constant(1)])
        self.assertEqual(1, calls['body_calls'])
        self.assertAllClose([11], self.evaluate(loop_out))