def f():
  """Samples a standard normal target with the auto-batched NUTS kernel.

  Builds a zero-mean, unit-scale diagonal Gaussian over `event_size`
  dimensions and runs `num_steps` NUTS steps (no burn-in) starting from an
  all-zeros state of shape `[batch_size, event_size]`. `event_size`,
  `batch_size` and `num_steps` are free variables read from the enclosing
  scope (not visible here).

  Returns:
    A pair `(chain_state, leapfrogs_taken)`:
      - `chain_state`: the sampled states from `tfp.mcmc.sample_chain`, which
        mirror the single-element `current_state` list.
      - `leapfrogs_taken`: that field of the traced kernel results.
  """

  def target_log_prob_fn(event):
    # A `scale_diag` of ones is the identity scale. It replaces the deprecated
    # `scale_identity_multiplier=1.` and yields the same log density.
    return tfd.MultivariateNormalDiag(
        loc=tf.zeros(event_size),
        scale_diag=tf.ones(event_size)).log_prob(event)

  state = tf.zeros([batch_size, event_size])
  chain_state, extra = tfp.mcmc.sample_chain(
      num_results=num_steps,
      num_burnin_steps=0,
      current_state=[state],
      kernel=tfp.experimental.mcmc.NoUTurnSampler(
          target_log_prob_fn,
          step_size=[0.3],
          use_auto_batching=True,
          seed=1,
          # Backend with safety checks off and a single parallel while-loop
          # iteration.
          backend=tf_backend.TensorFlowBackend(
              safety_checks=False, while_parallel_iterations=1)),
      parallel_iterations=1)
  return chain_state, extra.leapfrogs_taken
import tensorflow.compat.v1 as tf from tensorflow_probability.python.experimental.auto_batching import numpy_backend from tensorflow_probability.python.experimental.auto_batching import test_programs from tensorflow_probability.python.experimental.auto_batching import tf_backend from tensorflow_probability.python.experimental.auto_batching import virtual_machine as vm from tensorflow_probability.python.internal import test_util flags.DEFINE_string( 'test_device', None, 'TensorFlow device on which to place operators under test') FLAGS = flags.FLAGS NP_BACKEND = numpy_backend.NumpyBackend() TF_BACKEND = tf_backend.TensorFlowBackend() TF_BACKEND_NO_ASSERTS = tf_backend.TensorFlowBackend(safety_checks=False) # This program always returns 2. def _constant_execute(inputs, backend): # Stack depth limit is 4 to accommodate the initial value and # two pushes to "answer". with tf.compat.v2.name_scope('constant_program'): return vm.execute(test_programs.constant_program(), [inputs], max_stack_depth=4, backend=backend) # This program returns n > 1 ? 2 : 0. def _single_if_execute(inputs, backend):
# Dependency imports

import numpy as np
import tensorflow.compat.v2 as tf

from tensorflow_probability.python.experimental.auto_batching import allocation_strategy
from tensorflow_probability.python.experimental.auto_batching import dsl
from tensorflow_probability.python.experimental.auto_batching import instructions
from tensorflow_probability.python.experimental.auto_batching import lowering
from tensorflow_probability.python.experimental.auto_batching import numpy_backend
from tensorflow_probability.python.experimental.auto_batching import tf_backend
from tensorflow_probability.python.experimental.auto_batching import type_inference
from tensorflow_probability.python.experimental.auto_batching import virtual_machine as vm
from tensorflow_probability.python.internal import test_util

TF_BACKEND = tf_backend.TensorFlowBackend()
NP_BACKEND = numpy_backend.NumpyBackend()


def _execute(prog, inputs, stack_depth, backend):
  """Executes `prog` on the virtual machine with one input.

  Args:
    prog: The program handed to `vm.execute`.
    inputs: The program's single input; it is wrapped in a one-element list.
    stack_depth: Passed through as `max_stack_depth`.
    backend: The backend handed to `vm.execute`.

  Returns:
    The result of `vm.execute`.
  """
  program_args = [inputs]
  return vm.execute(
      prog, program_args, max_stack_depth=stack_depth, backend=backend)


def fibonacci_program():
  ab = dsl.ProgramBuilder()

  def fib_type(arg_types):
    return arg_types[0]