def testGradientNumBodyCalls(self):
  """Counts how often the target fn body runs inside maybe_call_fn_and_grads."""
  call_counter = collections.Counter()
  dtype = np.float32

  def counted_quadratic(x, y):
    call_counter['body_calls'] += 1
    return x**2 + y**2

  # Convert function input to a list of tensors.
  fn_args = [tf.convert_to_tensor(value=dtype(3)) for _ in range(2)]
  util.maybe_call_fn_and_grads(counted_quadratic, fn_args)
  # Under JAX or TF graph mode the body traces once; TF eager evaluates the
  # fn a second time under the gradient tape.
  if JAX_MODE or not tf.executing_eagerly():
    expected_num_calls = 1
  else:
    expected_num_calls = 2
  self.assertEqual(expected_num_calls, call_counter['body_calls'])
def _prepare_args(target_log_prob_fn, state, step_size, momentum_distribution,
                  target_log_prob=None, grads_target_log_prob=None,
                  maybe_expand=False, state_gradients_are_stopped=False):
  """Helper which processes input args to meet list-like assumptions.

  Listifies `state` and `step_size`, (re)computes the target log-prob and its
  gradients if they were not supplied, and builds a default isotropic-normal
  momentum distribution when none was given.
  """
  # Normalize `state` to a list of Tensor parts.
  state_parts, _ = mcmc_util.prepare_state_parts(state, name='current_state')
  if state_gradients_are_stopped:
    state_parts = [tf.stop_gradient(x) for x in state_parts]
  # Only calls `target_log_prob_fn` if `target_log_prob`/grads were not given.
  target_log_prob, grads_target_log_prob = mcmc_util.maybe_call_fn_and_grads(
      target_log_prob_fn, state_parts, target_log_prob, grads_target_log_prob)
  step_sizes, _ = mcmc_util.prepare_state_parts(
      step_size, dtype=target_log_prob.dtype, name='step_size')
  # Default momentum distribution is None, but if `store_parameters_in_results`
  # is true, then `momentum_distribution` defaults to an empty list
  if momentum_distribution is None or isinstance(momentum_distribution, list):
    batch_rank = ps.rank(target_log_prob)

    def _batched_isotropic_normal_like(state_part):
      # Standard normal over the event dims of this state part; note the
      # momentum is always float32 here regardless of the state dtype.
      event_ndims = ps.rank(state_part) - batch_rank
      return independent.Independent(
          normal.Normal(ps.zeros_like(state_part, tf.float32), 1.),
          reinterpreted_batch_ndims=event_ndims)

    momentum_distribution = jds.JointDistributionSequential([
        _batched_isotropic_normal_like(state_part)
        for state_part in state_parts
    ])

  # The momentum will get "maybe listified" to zip with the state parts,
  # and this step makes sure that the momentum distribution will have the
  # same "maybe listified" underlying shape.
  if not mcmc_util.is_list_like(momentum_distribution.dtype):
    momentum_distribution = jds.JointDistributionSequential(
        [momentum_distribution])

  # Broadcast a single step size across all state parts.
  if len(step_sizes) == 1:
    step_sizes *= len(state_parts)
  if len(state_parts) != len(step_sizes):
    raise ValueError(
        'There should be exactly one `step_size` or it should '
        'have same length as `current_state`.')

  def maybe_flatten(x):
    # Unwrap singleton lists unless the caller passed a list-like state or
    # explicitly asked for list outputs.
    return x if maybe_expand or mcmc_util.is_list_like(state) else x[0]

  return [
      maybe_flatten(state_parts),
      maybe_flatten(step_sizes),
      momentum_distribution,
      target_log_prob,
      grads_target_log_prob,
  ]
def _process_args(self, momentum_parts, state_parts, target,
                  target_grad_parts):
  """Sanitize inputs to `__call__`."""
  with tf.name_scope('process_args'):

    def _as_tensor_list(parts, name):
      # Coerce every part to a Tensor, defaulting to float32.
      return [
          tf.convert_to_tensor(p, dtype_hint=tf.float32, name=name)
          for p in parts
      ]

    momentum_parts = _as_tensor_list(momentum_parts, 'momentum_parts')
    state_parts = _as_tensor_list(state_parts, 'state_parts')
    if target is None or target_grad_parts is None:
      # Recompute target and grads when either is missing.
      target, target_grad_parts = mcmc_util.maybe_call_fn_and_grads(
          self.target_fn, state_parts)
    else:
      target = tf.convert_to_tensor(
          target, dtype_hint=tf.float32, name='target')
      target_grad_parts = _as_tensor_list(target_grad_parts,
                                          'target_grad_part')
    return momentum_parts, state_parts, target, target_grad_parts
def _prepare_args(target_log_prob_fn, state, step_size,
                  target_log_prob=None, grads_target_log_prob=None,
                  maybe_expand=False, state_gradients_are_stopped=False):
  """Helper which processes input args to meet list-like assumptions."""
  # Normalize the state to a list of Tensor parts.
  state_parts, _ = mcmc_util.prepare_state_parts(state, name='current_state')
  if state_gradients_are_stopped:
    state_parts = [tf.stop_gradient(part) for part in state_parts]
  # Compute target log-prob and grads only if they were not supplied.
  target_log_prob, grads_target_log_prob = mcmc_util.maybe_call_fn_and_grads(
      target_log_prob_fn, state_parts, target_log_prob, grads_target_log_prob)
  step_sizes, _ = mcmc_util.prepare_state_parts(
      step_size, dtype=target_log_prob.dtype, name='step_size')
  # A single step size is broadcast across every state part.
  if len(step_sizes) == 1:
    step_sizes = step_sizes * len(state_parts)
  if len(step_sizes) != len(state_parts):
    raise ValueError(
        'There should be exactly one `step_size` or it should '
        'have same length as `current_state`.')

  def maybe_flatten(x):
    # Unwrap singleton lists unless list outputs were requested or given.
    if maybe_expand or mcmc_util.is_list_like(state):
      return x
    return x[0]

  return [
      maybe_flatten(state_parts),
      maybe_flatten(step_sizes),
      target_log_prob,
      grads_target_log_prob,
  ]
def testGradientWorksForMultivariateNormalTriL(self):
  """Grads of MVNTriL.log_prob are not None despite bijector caching."""
  # TODO(b/72831017): Remove this once bijector cacheing is fixed for
  # graph mode.
  if not tf.executing_eagerly():
    self.skipTest('Gradients get None values in graph mode.')
  dist = tfd.MultivariateNormalTriL(scale_tril=tf.eye(2))
  sample = dist.sample(seed=test_util.test_seed())
  fn_result, grads = util.maybe_call_fn_and_grads(dist.log_prob, sample)
  self.assertAllEqual(False, fn_result is None)
  self.assertAllEqual([False], [g is None for g in grads])
def testNoGradientsNiceError(self):
  """Checks the error message raised when some gradients come back `None`.

  `tf.stop_gradient(y)` guarantees the grad wrt the second argument is
  `None`, so `maybe_call_fn_and_grads` must raise with a message that names
  the offending arg (in graph mode, where tensors carry names).
  """
  dtype = np.float32

  def fn(x, y):
    return x**2 + tf.stop_gradient(y)**2

  fn_args = [dtype(3), dtype(3)]
  # Convert function input to a list of tensors
  fn_args = [
      tf.convert_to_tensor(value=arg, name='arg{}'.format(i))
      for i, arg in enumerate(fn_args)
  ]
  # Fix: `assertRaisesRegexp` is deprecated and removed in Python 3.12;
  # use `assertRaisesRegex`. The regexes are unchanged.
  if tf.executing_eagerly():
    # Eager tensors do not retain the `name=` above, so the message cannot
    # mention `arg1`.
    with self.assertRaisesRegex(
        ValueError, 'Encountered `None`.*\n.*fn_arg_list.*\n.*None'):
      util.maybe_call_fn_and_grads(fn, fn_args)
  else:
    with self.assertRaisesRegex(
        ValueError, 'Encountered `None`.*\n.*fn_arg_list.*arg1.*\n.*None'):
      util.maybe_call_fn_and_grads(fn, fn_args)
def testGradientComputesCorrectly(self):
  """Value and grads of x**2 + y**2 at (3, 3) are 18 and (6, 6)."""
  dtype = np.float32

  quadratic = lambda x, y: x**2 + y**2

  # Convert function input to a list of tensors
  args = [tf.convert_to_tensor(value=dtype(3)) for _ in range(2)]
  fn_result, grads = util.maybe_call_fn_and_grads(quadratic, args)
  fn_result_, grads_ = self.evaluate([fn_result, grads])
  self.assertNear(18., fn_result_, err=1e-5)
  for grad in grads_:
    self.assertAllClose(grad, dtype(6), atol=0., rtol=1e-5)
def bootstrap_results(self, init_state):
  """Creates initial `previous_kernel_results` from a supplied `init_state`."""
  with tf.name_scope(
      mcmc_util.make_name(self.name, 'hmc', 'bootstrap_results')):
    init_state, _ = mcmc_util.prepare_state_parts(init_state)
    if self.state_gradients_are_stopped:
      init_state = [tf.stop_gradient(x) for x in init_state]
    [
        init_target_log_prob,
        init_grads_target_log_prob,
    ] = mcmc_util.maybe_call_fn_and_grads(self.target_log_prob_fn, init_state)
    if self._store_parameters_in_results:
      return UncalibratedHamiltonianMonteCarloKernelResults(
          log_acceptance_correction=tf.zeros_like(init_target_log_prob),
          target_log_prob=init_target_log_prob,
          grads_target_log_prob=init_grads_target_log_prob,
          # Momenta are placeholders here; one_step fills in real samples.
          initial_momentum=tf.nest.map_structure(tf.zeros_like, init_state),
          final_momentum=tf.nest.map_structure(tf.zeros_like, init_state),
          # TODO(b/142590314): Try to use the following code once we commit to
          # a tensorization policy.
          # step_size=mcmc_util.prepare_state_parts(
          #     self.step_size,
          #     dtype=init_target_log_prob.dtype,
          #     name='step_size')[0],
          step_size=tf.nest.map_structure(
              lambda x: tf.convert_to_tensor(  # pylint: disable=g-long-lambda
                  x,
                  dtype=init_target_log_prob.dtype,
                  name='step_size'),
              self.step_size),
          num_leapfrog_steps=tf.convert_to_tensor(
              self.num_leapfrog_steps,
              dtype=tf.int32,
              name='num_leapfrog_steps'))
    else:
      # Parameters are not stored in results: leave step_size / leapfrog
      # slots empty.
      return UncalibratedHamiltonianMonteCarloKernelResults(
          log_acceptance_correction=tf.zeros_like(init_target_log_prob),
          target_log_prob=init_target_log_prob,
          grads_target_log_prob=init_grads_target_log_prob,
          initial_momentum=tf.nest.map_structure(tf.zeros_like, init_state),
          final_momentum=tf.nest.map_structure(tf.zeros_like, init_state),
          step_size=[],
          num_leapfrog_steps=[])
def bootstrap_results(self, init_state):
  """Creates initial `previous_kernel_results` using a supplied `state`."""
  with tf.name_scope(self.name + '.bootstrap_results'):
    if not tf.nest.is_nested(init_state):
      init_state = [init_state]
    state_parts, _ = mcmc_util.prepare_state_parts(init_state,
                                                   name='current_state')
    current_target_log_prob, current_grads_log_prob = mcmc_util.maybe_call_fn_and_grads(
        self.target_log_prob_fn, state_parts)
    # Confirm that the step size is compatible with the state parts.
    _ = _prepare_step_size(
        self.step_size, current_target_log_prob.dtype, len(init_state))
    momentum_distribution = self.momentum_distribution
    if momentum_distribution is None:
      # Default: standard-normal momentum matching the state structure.
      momentum_distribution = pu.make_momentum_distribution(
          state_parts, ps.shape(current_target_log_prob),
          shard_axis_names=self.experimental_shard_axis_names)
    momentum_distribution = pu.maybe_make_list_and_batch_broadcast(
        momentum_distribution, ps.shape(current_target_log_prob))
    # Sampled only to size the `energy` field; seed is a fixed zero-seed.
    momentum_parts = momentum_distribution.sample(seed=samplers.zeros_seed())
    return PreconditionedNUTSKernelResults(
        target_log_prob=current_target_log_prob,
        grads_target_log_prob=current_grads_log_prob,
        step_size=tf.nest.map_structure(
            lambda x: tf.convert_to_tensor(  # pylint: disable=g-long-lambda
                x, dtype=current_target_log_prob.dtype, name='step_size'),
            self.step_size),
        log_accept_ratio=tf.zeros_like(
            current_target_log_prob, name='log_accept_ratio'),
        leapfrogs_taken=tf.zeros_like(
            current_target_log_prob,
            dtype=TREE_COUNT_DTYPE,
            name='leapfrogs_taken'),
        is_accepted=tf.zeros_like(
            current_target_log_prob, dtype=tf.bool, name='is_accepted'),
        reach_max_depth=tf.zeros_like(
            current_target_log_prob, dtype=tf.bool, name='reach_max_depth'),
        has_divergence=tf.zeros_like(
            current_target_log_prob, dtype=tf.bool, name='has_divergence'),
        energy=compute_hamiltonian(current_target_log_prob, momentum_parts,
                                   momentum_distribution),
        momentum_distribution=momentum_distribution,
        # Allow room for one_step's seed.
        seed=samplers.zeros_seed(),
    )
def _prepare_args(target_log_prob_fn, state, step_size, momentum_distribution,
                  target_log_prob=None, grads_target_log_prob=None,
                  maybe_expand=False, state_gradients_are_stopped=False,
                  experimental_shard_axis_names=None):
  """Helper which processes input args to meet list-like assumptions."""
  state_parts, _ = mcmc_util.prepare_state_parts(state, name='current_state')
  if state_gradients_are_stopped:
    state_parts = [tf.stop_gradient(x) for x in state_parts]
  # Only calls `target_log_prob_fn` when log-prob/grads were not supplied.
  target_log_prob, grads_target_log_prob = mcmc_util.maybe_call_fn_and_grads(
      target_log_prob_fn, state_parts, target_log_prob, grads_target_log_prob)
  step_sizes, _ = mcmc_util.prepare_state_parts(
      step_size, dtype=target_log_prob.dtype, name='step_size')

  # Default momentum distribution is None
  if momentum_distribution is None:
    momentum_distribution = pu.make_momentum_distribution(
        state_parts, ps.shape(target_log_prob),
        shard_axis_names=experimental_shard_axis_names)
  # Ensure a list-structured, batch-broadcast momentum distribution so it
  # zips with the state parts.
  momentum_distribution = pu.maybe_make_list_and_batch_broadcast(
      momentum_distribution, ps.shape(target_log_prob))

  # Broadcast a single step size across all state parts.
  if len(step_sizes) == 1:
    step_sizes *= len(state_parts)
  if len(state_parts) != len(step_sizes):
    raise ValueError(
        'There should be exactly one `step_size` or it should '
        'have same length as `current_state`.')

  def maybe_flatten(x):
    # Unwrap singleton lists unless list outputs were requested or given.
    return x if maybe_expand or mcmc_util.is_list_like(state) else x[0]

  return [
      maybe_flatten(state_parts),
      maybe_flatten(step_sizes),
      momentum_distribution,
      target_log_prob,
      grads_target_log_prob,
  ]
def hmc_like_proposed_velocity_getter_fn(kernel_results):
  """Getter for `proposed_velocity` so it can be inspected."""
  final_momentum = unnest.get_innermost(kernel_results, 'final_momentum')
  proposed_state = unnest.get_innermost(kernel_results, 'proposed_state')
  momentum_distribution = unnest.get_innermost(
      kernel_results, 'momentum_distribution', default=None)
  if momentum_distribution is None:
    # Identity mass matrix: velocity coincides with momentum.
    proposed_velocity = final_momentum
  else:
    # Velocity is the gradient of the kinetic energy wrt momentum; prefer
    # the unnormalized log-prob when the distribution exposes one.
    momentum_log_prob = getattr(momentum_distribution,
                                '_log_prob_unnormalized',
                                momentum_distribution.log_prob)

    def kinetic_energy_fn(*args):
      return -momentum_log_prob(*args)

    _, proposed_velocity = mcmc_util.maybe_call_fn_and_grads(
        kinetic_energy_fn, final_momentum)
  # proposed_velocity has the wrong structure when state is a scalar.
  return tf.nest.pack_sequence_as(proposed_state,
                                  tf.nest.flatten(proposed_velocity))
def _integrator_conserves_energy(self, x, independent_chain_ndims):
  """Checks the leapfrog integrator approximately conserves the Hamiltonian.

  Runs 1000 leapfrog steps from `x` with random momentum and asserts the
  total energy changes by at most 2% (relative).
  """
  event_dims = tf.range(independent_chain_ndims, tf.rank(x))
  m = tf.random.normal(tf.shape(input=x))
  log_prob_0, grad_0 = maybe_call_fn_and_grads(
      lambda x: self._log_gamma_log_prob(x, event_dims), x)
  # Hamiltonian: negative log-prob plus quadratic kinetic energy.
  old_energy = -log_prob_0 + 0.5 * tf.reduce_sum(
      input_tensor=m**2., axis=event_dims)

  x_shape = self.evaluate(x).shape
  event_size = np.prod(x_shape[independent_chain_ndims:])
  # Step size shrinks with event size to keep the integrator stable.
  step_size = tf.constant(0.1 / event_size, x.dtype)
  hmc_lf_steps = tf.constant(1000, np.int32)

  def leapfrog_one_step(*args):
    return _leapfrog_integrator_one_step(
        lambda x: self._log_gamma_log_prob(x, event_dims),
        independent_chain_ndims,
        [step_size],
        *args)

  [[new_m], _, log_prob_1, _] = tf.while_loop(
      cond=lambda *args: True,
      body=leapfrog_one_step,
      loop_vars=[
          [m],  # current_momentum_parts
          [x],  # current_state_parts,
          log_prob_0,  # current_target_log_prob
          grad_0,  # current_target_log_prob_grad_parts
      ],
      maximum_iterations=hmc_lf_steps)

  new_energy = -log_prob_1 + 0.5 * tf.reduce_sum(
      input_tensor=new_m**2., axis=event_dims)

  old_energy_, new_energy_ = self.evaluate([old_energy, new_energy])
  tf.compat.v1.logging.vlog(
      1, 'average energy relative change: {}'.format(
          (1. - new_energy_ / old_energy_).mean()))
  self.assertAllClose(old_energy_, new_energy_, atol=0., rtol=0.02)
def bootstrap_results(self, init_state):
  """Creates initial `previous_kernel_results` from a supplied `init_state`."""
  with tf.compat.v2.name_scope(
      mcmc_util.make_name(self.name, 'hmc', 'bootstrap_results')):
    if not mcmc_util.is_list_like(init_state):
      init_state = [init_state]
    if self.state_gradients_are_stopped:
      init_state = [tf.stop_gradient(x) for x in init_state]
    else:
      init_state = [tf.convert_to_tensor(value=x) for x in init_state]
    [
        init_target_log_prob,
        init_grads_target_log_prob,
    ] = mcmc_util.maybe_call_fn_and_grads(self.target_log_prob_fn, init_state)
    if self._store_parameters_in_results:
      return UncalibratedHamiltonianMonteCarloKernelResults(
          log_acceptance_correction=tf.zeros_like(init_target_log_prob),
          target_log_prob=init_target_log_prob,
          grads_target_log_prob=init_grads_target_log_prob,
          # Step size is cast to the target log-prob dtype for each part.
          step_size=tf.nest.map_structure(
              lambda x: tf.convert_to_tensor(  # pylint: disable=g-long-lambda
                  value=x,
                  dtype=init_target_log_prob.dtype,
                  name='step_size'),
              self.step_size),
          num_leapfrog_steps=tf.convert_to_tensor(
              value=self.num_leapfrog_steps,
              dtype=tf.int32,
              name='num_leapfrog_steps'))
    else:
      # Parameters not stored in results: leave those slots empty.
      return UncalibratedHamiltonianMonteCarloKernelResults(
          log_acceptance_correction=tf.zeros_like(init_target_log_prob),
          target_log_prob=init_target_log_prob,
          grads_target_log_prob=init_grads_target_log_prob,
          step_size=[],
          num_leapfrog_steps=[])
def _one_step(target_fn, step_sizes, get_velocity_parts,
              half_next_momentum_parts, state_parts, target,
              target_grad_parts):
  """Body of integrator while loop."""
  with tf.name_scope('leapfrog_integrate_one_step'):
    velocity_parts = get_velocity_parts(half_next_momentum_parts)
    # Drift each state part by its step size times the velocity.
    next_state_parts = [
        x + tf.cast(eps, x.dtype) * tf.cast(v, x.dtype)  # pylint: disable=g-complex-comprehension
        for x, eps, v in zip(state_parts, step_sizes, velocity_parts)
    ]
    [next_target, next_target_grad_parts
    ] = mcmc_util.maybe_call_fn_and_grads(target_fn, next_state_parts)
    if any(grad is None for grad in next_target_grad_parts):
      raise ValueError('Encountered `None` gradient.\n'
                       ' state_parts: {}\n'
                       ' next_state_parts: {}\n'
                       ' next_target_grad_parts: {}'.format(
                           state_parts, next_state_parts,
                           next_target_grad_parts))
    # Re-assert static shapes lost inside the while loop.
    tensorshape_util.set_shape(next_target, target.shape)
    for new_grad, old_grad in zip(next_target_grad_parts, target_grad_parts):
      tensorshape_util.set_shape(new_grad, old_grad.shape)
    # Kick the momentum with the gradient at the new state.
    next_half_next_momentum_parts = [
        v + tf.cast(eps, v.dtype) * tf.cast(g, v.dtype)  # pylint: disable=g-complex-comprehension
        for v, eps, g in zip(half_next_momentum_parts, step_sizes,
                             next_target_grad_parts)
    ]
    return [
        next_half_next_momentum_parts,
        next_state_parts,
        next_target,
        next_target_grad_parts,
    ]
def _one_step(self, half_next_momentum_parts, state_parts, target,
              target_grad_parts):
  """Body of integrator while loop."""
  with tf.name_scope('leapfrog_integrate_one_step'):
    # Drift: move each state part by step_size * momentum.
    next_state_parts = [
        x + tf.cast(eps, x.dtype) * tf.cast(v, x.dtype)  # pylint: disable=g-complex-comprehension
        for x, eps, v in zip(state_parts, self.step_sizes,
                             half_next_momentum_parts)
    ]
    [next_target, next_target_grad_parts
    ] = mcmc_util.maybe_call_fn_and_grads(self.target_fn, next_state_parts)
    if any(g is None for g in next_target_grad_parts):
      raise ValueError('Encountered `None` gradient.\n'
                       ' state_parts: {}\n'
                       ' next_state_parts: {}\n'
                       ' next_target_grad_parts: {}'.format(
                           state_parts, next_state_parts,
                           next_target_grad_parts))
    # Re-assert static shapes lost inside the while loop.
    next_target.set_shape(target.shape)
    for ng, g in zip(next_target_grad_parts, target_grad_parts):
      ng.set_shape(g.shape)
    # Kick: update momentum with the gradient at the new state.
    next_half_next_momentum_parts = [
        v + tf.cast(eps, v.dtype) * tf.cast(g, v.dtype)  # pylint: disable=g-complex-comprehension
        for v, eps, g in zip(half_next_momentum_parts, self.step_sizes,
                             next_target_grad_parts)
    ]
    return [
        next_half_next_momentum_parts,
        next_state_parts,
        next_target,
        next_target_grad_parts,
    ]
def _prepare_args(target_log_prob_fn, volatility_fn, state, step_size,
                  target_log_prob=None, grads_target_log_prob=None,
                  volatility=None, grads_volatility_fn=None,
                  diffusion_drift=None, parallel_iterations=10):
  """Helper which processes input args to meet list-like assumptions.

  Listifies `state`/`step_size`/`diffusion_drift`, (re)computes the target
  log-prob, volatility and their gradients when not supplied, and derives
  the default diffusion drift from the volatility when none is given.
  """
  state_parts = list(state) if mcmc_util.is_list_like(state) else [state]
  [
      target_log_prob,
      grads_target_log_prob,
  ] = mcmc_util.maybe_call_fn_and_grads(target_log_prob_fn, state_parts,
                                        target_log_prob,
                                        grads_target_log_prob)
  [
      volatility_parts,
      grads_volatility,
  ] = _maybe_call_volatility_fn_and_grads(volatility_fn, state_parts,
                                          volatility, grads_volatility_fn,
                                          ps.shape(target_log_prob),
                                          parallel_iterations)

  step_sizes = (list(step_size) if mcmc_util.is_list_like(step_size)
                else [step_size])
  step_sizes = [
      tf.convert_to_tensor(value=s, name='step_size',
                           dtype=target_log_prob.dtype)
      for s in step_sizes
  ]
  # A single step size is broadcast to every state part.
  if len(step_sizes) == 1:
    step_sizes *= len(state_parts)
  if len(state_parts) != len(step_sizes):
    raise ValueError(
        'There should be exactly one `step_size` or it should '
        'have same length as `current_state`.')

  if diffusion_drift is None:
    diffusion_drift_parts = _get_drift(step_sizes, volatility_parts,
                                       grads_volatility,
                                       grads_target_log_prob)
  else:
    diffusion_drift_parts = (list(diffusion_drift)
                             if mcmc_util.is_list_like(diffusion_drift)
                             else [diffusion_drift])
    # BUG FIX: the length check previously used `len(diffusion_drift)`,
    # which fails (or counts wrongly) for a scalar/non-list drift Tensor;
    # the listified `diffusion_drift_parts` is the object to validate.
    if len(state_parts) != len(diffusion_drift_parts):
      raise ValueError(
          'There should be exactly one `diffusion_drift` or it '
          'should have same length as list-like `current_state`.')

  return [
      state_parts,
      step_sizes,
      target_log_prob,
      grads_target_log_prob,
      volatility_parts,
      grads_volatility,
      diffusion_drift_parts,
  ]
def _prepare_args(target_log_prob_fn, state, step_size, momentum_distribution,
                  target_log_prob=None, grads_target_log_prob=None,
                  maybe_expand=False, state_gradients_are_stopped=False):
  """Helper which processes input args to meet list-like assumptions."""
  state_parts, _ = mcmc_util.prepare_state_parts(state, name='current_state')
  if state_gradients_are_stopped:
    state_parts = [tf.stop_gradient(x) for x in state_parts]
  # Only calls `target_log_prob_fn` when log-prob/grads were not supplied.
  target_log_prob, grads_target_log_prob = mcmc_util.maybe_call_fn_and_grads(
      target_log_prob_fn, state_parts, target_log_prob, grads_target_log_prob)
  step_sizes, _ = mcmc_util.prepare_state_parts(
      step_size, dtype=target_log_prob.dtype, name='step_size')
  # Default momentum distribution is None, but if `store_parameters_in_results`
  # is true, then `momentum_distribution` defaults to DefaultStandardNormal().
  if (momentum_distribution is None or
      isinstance(momentum_distribution, DefaultStandardNormal)):
    batch_rank = ps.rank(target_log_prob)

    def _batched_isotropic_normal_like(state_part):
      # Standard normal over this state part's event shape, matching dtype.
      return sample.Sample(
          normal.Normal(ps.zeros([], dtype=state_part.dtype), 1.),
          ps.shape(state_part)[batch_rank:])

    momentum_distribution = jds.JointDistributionSequential(
        [_batched_isotropic_normal_like(state_part)
         for state_part in state_parts])

  # The momentum will get "maybe listified" to zip with the state parts,
  # and this step makes sure that the momentum distribution will have the
  # same "maybe listified" underlying shape.
  if not mcmc_util.is_list_like(momentum_distribution.dtype):
    momentum_distribution = jds.JointDistributionSequential(
        [momentum_distribution])

  # If all underlying distributions are independent, we can offer some help.
  # This code will also trigger for the output of the two blocks above.
  if (isinstance(momentum_distribution, jds.JointDistributionSequential) and
      not any(callable(dist_fn) for dist_fn in momentum_distribution.model)):
    batch_shape = ps.shape(target_log_prob)
    momentum_distribution = momentum_distribution.copy(model=[
        batch_broadcast.BatchBroadcast(md, to_shape=batch_shape)
        for md in momentum_distribution.model
    ])

  # Broadcast a single step size across all state parts.
  if len(step_sizes) == 1:
    step_sizes *= len(state_parts)
  if len(state_parts) != len(step_sizes):
    raise ValueError('There should be exactly one `step_size` or it should '
                     'have same length as `current_state`.')

  def maybe_flatten(x):
    # Unwrap singleton lists unless list outputs were requested or given.
    return x if maybe_expand or mcmc_util.is_list_like(state) else x[0]

  return [
      maybe_flatten(state_parts),
      maybe_flatten(step_sizes),
      momentum_distribution,
      target_log_prob,
      grads_target_log_prob,
  ]
def _leapfrog_integrator_one_step(
    target_log_prob_fn,
    independent_chain_ndims,
    step_sizes,
    current_momentum_parts,
    current_state_parts,
    current_target_log_prob,
    current_target_log_prob_grad_parts,
    state_gradients_are_stopped=False,
    name=None):
  """Applies `num_leapfrog_steps` of the leapfrog integrator.

  Assumes a simple quadratic kinetic energy function: `0.5 ||momentum||**2`.

  #### Examples:

  ##### Simple quadratic potential.

  ```python
  import matplotlib.pyplot as plt
  %matplotlib inline
  import numpy as np
  import tensorflow as tf
  from tensorflow_probability.python.mcmc.hmc import _leapfrog_integrator_one_step  # pylint: disable=line-too-long
  tfd = tfp.distributions

  dims = 10
  num_iter = int(1e3)
  dtype = np.float32

  position = tf.placeholder(np.float32)
  momentum = tf.placeholder(np.float32)

  target_log_prob_fn = tfd.MultivariateNormalDiag(
      loc=tf.zeros(dims, dtype)).log_prob

  def _leapfrog_one_step(*args):
    # Closure representing computation done during each leapfrog step.
    return _leapfrog_integrator_one_step(
        target_log_prob_fn=target_log_prob_fn,
        independent_chain_ndims=0,
        step_sizes=[0.1],
        current_momentum_parts=args[0],
        current_state_parts=args[1],
        current_target_log_prob=args[2],
        current_target_log_prob_grad_parts=args[3])

  # Do leapfrog integration.
  [
      [next_momentum],
      [next_position],
      next_target_log_prob,
      next_target_log_prob_grad_parts,
  ] = tf.while_loop(
      cond=lambda *args: True,
      body=_leapfrog_one_step,
      loop_vars=[
          [momentum],
          [position],
          target_log_prob_fn(position),
          tf.gradients(target_log_prob_fn(position), position),
      ],
      maximum_iterations=3)

  momentum_ = np.random.randn(dims).astype(dtype)
  position_ = np.random.randn(dims).astype(dtype)
  positions = np.zeros([num_iter, dims], dtype)

  with tf.Session() as sess:
    for i in xrange(num_iter):
      position_, momentum_ = sess.run(
          [next_momentum, next_position],
          feed_dict={position: position_, momentum: momentum_})
      positions[i] = position_

  plt.plot(positions[:, 0]);  # Sinusoidal.
  ```

  Args:
    target_log_prob_fn: Python callable which takes an argument like
      `*current_state_parts` and returns its (possibly unnormalized)
      log-density under the target distribution.
    independent_chain_ndims: Scalar `int` `Tensor` representing the number of
      leftmost `Tensor` dimensions which index independent chains.
    step_sizes: Python `list` of `Tensor`s representing the step size for the
      leapfrog integrator. Must broadcast with the shape of
      `current_state_parts`. Larger step sizes lead to faster progress, but
      too-large step sizes make rejection exponentially more likely. When
      possible, it's often helpful to match per-variable step sizes to the
      standard deviations of the target distribution in each variable.
    current_momentum_parts: Tensor containing the value(s) of the momentum
      variable(s) to update.
    current_state_parts: Python `list` of `Tensor`s representing the current
      state(s) of the Markov chain(s). The first `independent_chain_ndims` of
      the `Tensor`(s) index different chains.
    current_target_log_prob: `Tensor` representing the value of
      `target_log_prob_fn(*current_state_parts)`. The only reason to specify
      this argument is to reduce TF graph size.
    current_target_log_prob_grad_parts: Python list of `Tensor`s representing
      gradient of `target_log_prob_fn(*current_state_parts`) wrt
      `current_state_parts`. Must have same shape as `current_state_parts`.
      The only reason to specify this argument is to reduce TF graph size.
    state_gradients_are_stopped: Python `bool` indicating that the proposed
      new state be run through `tf.stop_gradient`. This is particularly useful
      when combining optimization over samples from the HMC chain. Default
      value: `False` (i.e., do not apply `stop_gradient`).
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'hmc_leapfrog_integrator').

  Returns:
    proposed_momentum_parts: Updated value of the momentum.
    proposed_state_parts: Tensor or Python list of `Tensor`s representing the
      state(s) of the Markov chain(s) at each result step. Has same shape as
      input `current_state_parts`.
    proposed_target_log_prob: `Tensor` representing the value of
      `target_log_prob_fn` at `next_state`.
    proposed_target_log_prob_grad_parts: Gradient of
      `proposed_target_log_prob` wrt `next_state`.

  Raises:
    ValueError: if `len(momentum_parts) != len(state_parts)`.
    ValueError: if `len(state_parts) != len(step_sizes)`.
    ValueError: if `len(state_parts) != len(grads_target_log_prob)`.
    TypeError: if `not target_log_prob.dtype.is_floating`.
  """
  # Note on per-variable step sizes:
  #
  # Using per-variable step sizes is equivalent to using the same step
  # size for all variables and adding a diagonal mass matrix in the
  # kinetic energy term of the Hamiltonian being integrated. This is
  # hinted at by Neal (2011) but not derived in detail there.
  #
  # Let x and v be position and momentum variables respectively.
  # Let g(x) be the gradient of `target_log_prob_fn(x)`.
  # Let S be a diagonal matrix of per-variable step sizes.
  # Let the Hamiltonian H(x, v) = -target_log_prob_fn(x) + 0.5 * ||v||**2.
  #
  # Using per-variable step sizes gives the updates
  #   v' = v + 0.5 * matmul(S, g(x))
  #   x'' = x + matmul(S, v')
  #   v'' = v' + 0.5 * matmul(S, g(x''))
  #
  # Let u = matmul(inv(S), v).
  # Multiplying v by inv(S) in the updates above gives the transformed
  # dynamics
  #   u' = matmul(inv(S), v') = matmul(inv(S), v) + 0.5 * g(x)
  #      = u + 0.5 * g(x)
  #   x'' = x + matmul(S, v') = x + matmul(S**2, u')
  #   u'' = matmul(inv(S), v'') = matmul(inv(S), v') + 0.5 * g(x'')
  #       = u' + 0.5 * g(x'')
  #
  # These are exactly the leapfrog updates for the Hamiltonian
  #   H'(x, u) = -target_log_prob_fn(x) + 0.5 * u^T S**2 u
  #            = -target_log_prob_fn(x) + 0.5 * ||v||**2 = H(x, v).
  #
  # To summarize:
  #
  # * Using per-variable step sizes implicitly simulates the dynamics
  #   of the Hamiltonian H' (which are energy-conserving in H'). We
  #   keep track of v instead of u, but the underlying dynamics are
  #   the same if we transform back.
  # * The value of the Hamiltonian H'(x, u) is the same as the value
  #   of the original Hamiltonian H(x, v) after we transform back from
  #   u to v.
  # * Sampling v ~ N(0, I) is equivalent to sampling u ~ N(0, S**-2).
  #
  # So using per-variable step sizes in HMC will give results that are
  # exactly identical to explicitly using a diagonal mass matrix.
  with tf.compat.v1.name_scope(name, 'hmc_leapfrog_integrator_one_step', [
      independent_chain_ndims, step_sizes, current_momentum_parts,
      current_state_parts, current_target_log_prob,
      current_target_log_prob_grad_parts
  ]):
    # Step 1: Update momentum.
    proposed_momentum_parts = [
        v + 0.5 * tf.cast(eps, v.dtype) * g
        for v, eps, g in zip(current_momentum_parts, step_sizes,
                             current_target_log_prob_grad_parts)]

    # Step 2: Update state.
    proposed_state_parts = [
        x + tf.cast(eps, v.dtype) * v
        for x, eps, v in zip(current_state_parts, step_sizes,
                             proposed_momentum_parts)]

    if state_gradients_are_stopped:
      proposed_state_parts = [
          tf.stop_gradient(x) for x in proposed_state_parts]

    # Step 3a: Re-evaluate target-log-prob (and grad) at proposed state.
    [
        proposed_target_log_prob,
        proposed_target_log_prob_grad_parts,
    ] = mcmc_util.maybe_call_fn_and_grads(
        target_log_prob_fn, proposed_state_parts)

    if not proposed_target_log_prob.dtype.is_floating:
      raise TypeError('`target_log_prob_fn` must produce a `Tensor` '
                      'with `float` `dtype`.')

    if any(g is None for g in proposed_target_log_prob_grad_parts):
      raise ValueError(
          'Encountered `None` gradient. Does your target `target_log_prob_fn` '
          'access all `tf.Variable`s via `tf.get_variable`?\n'
          ' current_state_parts: {}\n'
          ' proposed_state_parts: {}\n'
          ' proposed_target_log_prob_grad_parts: {}'.format(
              current_state_parts, proposed_state_parts,
              proposed_target_log_prob_grad_parts))

    # Step 3b: Update momentum (again).
    proposed_momentum_parts = [
        v + 0.5 * tf.cast(eps, v.dtype) * g
        for v, eps, g in zip(proposed_momentum_parts, step_sizes,
                             proposed_target_log_prob_grad_parts)]

    return [
        proposed_momentum_parts,
        proposed_state_parts,
        proposed_target_log_prob,
        proposed_target_log_prob_grad_parts,
    ]
def testGradientWorksDespiteBijectorCaching(self):
  """Grads of LogNormal.log_prob are not None despite bijector caching."""
  x = tf.constant(2.)
  log_prob_fn = lambda x_: tfd.LogNormal(loc=0., scale=1.).log_prob(x_)
  fn_result, grads = util.maybe_call_fn_and_grads(log_prob_fn, x)
  self.assertAllEqual(False, fn_result is None)
  self.assertAllEqual([False], [g is None for g in grads])
def _loop_build_sub_tree(self, directions, integrator,
                         current_step_meta_info, iter_,
                         energy_diff_sum_previous, momentum_cumsum_previous,
                         leapfrogs_taken, prev_tree_state,
                         candidate_tree_state, continue_tree_previous,
                         not_divergent_previous, velocity_state_memory,
                         momentum_distribution, seed):
  """Base case in tree doubling."""
  acceptance_seed, next_seed = samplers.split_seed(seed)
  with tf.name_scope('loop_build_sub_tree'):
    # Take one leapfrog step in the direction v and check divergence
    kinetic_energy_fn = get_kinetic_energy_fn(momentum_distribution)
    [
        next_momentum_parts,
        next_state_parts,
        next_target,
        next_target_grad_parts
    ] = integrator(prev_tree_state.momentum,
                   prev_tree_state.state,
                   prev_tree_state.target,
                   prev_tree_state.target_grad_parts,
                   kinetic_energy_fn=kinetic_energy_fn)
    # Velocity = grad of kinetic energy wrt momentum (mass-matrix product).
    _, next_velocity_parts = mcmc_util.maybe_call_fn_and_grads(
        kinetic_energy_fn, next_momentum_parts)

    next_tree_state = TreeDoublingState(
        momentum=next_momentum_parts,
        velocity=next_velocity_parts,
        state=next_state_parts,
        target=next_target,
        target_grad_parts=next_target_grad_parts)
    momentum_cumsum = [p0 + p1 for p0, p1 in zip(momentum_cumsum_previous,
                                                 next_momentum_parts)]
    # If the tree have not yet terminated previously, we count this leapfrog.
    leapfrogs_taken = tf.where(
        continue_tree_previous, leapfrogs_taken + 1, leapfrogs_taken)

    write_instruction = current_step_meta_info.write_instruction
    read_instruction = current_step_meta_info.read_instruction
    init_energy = current_step_meta_info.init_energy

    # Generalized U-turn uses the cumulative momentum sum instead of the raw
    # state for the turning criterion.
    if GENERALIZED_UTURN:
      state_to_write = momentum_cumsum_previous
      state_to_check = momentum_cumsum
    else:
      state_to_write = next_state_parts
      state_to_check = next_state_parts

    batch_shape = ps.shape(next_target)
    has_not_u_turn_init = ps.ones(batch_shape, dtype=tf.bool)

    read_index = read_instruction.gather([iter_])[0]
    no_u_turns_within_tree = has_not_u_turn_at_all_index(  # pylint: disable=g-long-lambda
        read_index,
        directions,
        velocity_state_memory,
        next_velocity_parts,
        state_to_check,
        has_not_u_turn_init,
        log_prob_rank=ps.rank(next_target),
        shard_axis_names=self.experimental_shard_axis_names)

    # Get index to write state into memory swap
    write_index = write_instruction.gather([iter_])
    velocity_state_memory = VelocityStateSwap(
        velocity_swap=[
            _safe_tensor_scatter_nd_update(old, [write_index], [new])
            for old, new in zip(velocity_state_memory.velocity_swap,
                                next_velocity_parts)
        ],
        state_swap=[
            _safe_tensor_scatter_nd_update(old, [write_index], [new])
            for old, new in zip(velocity_state_memory.state_swap,
                                state_to_write)
        ])

    energy = compute_hamiltonian(next_target, next_momentum_parts,
                                 momentum_distribution)
    # NaN energy is treated as -inf so the step is counted as divergent.
    current_energy = tf.where(tf.math.is_nan(energy),
                              tf.constant(-np.inf, dtype=energy.dtype),
                              energy)
    energy_diff = current_energy - init_energy

    if MULTINOMIAL_SAMPLE:
      not_divergent = -energy_diff < self.max_energy_diff
      weight_sum = log_add_exp(candidate_tree_state.weight, energy_diff)
      log_accept_thresh = energy_diff - weight_sum
    else:
      log_slice_sample = current_step_meta_info.log_slice_sample
      not_divergent = log_slice_sample - energy_diff < self.max_energy_diff
      # Uniform sampling on the trajectory within the subtree across valid
      # samples.
      is_valid = log_slice_sample <= energy_diff
      weight_sum = tf.where(is_valid, candidate_tree_state.weight + 1,
                            candidate_tree_state.weight)
      log_accept_thresh = tf.where(
          is_valid,
          -tf.math.log(tf.cast(weight_sum, dtype=tf.float32)),
          tf.constant(-np.inf, dtype=tf.float32))
    # log of a Uniform(0, 1) sample, for the acceptance comparison.
    u = tf.math.log1p(-samplers.uniform(
        shape=batch_shape,
        dtype=log_accept_thresh.dtype,
        seed=acceptance_seed))
    is_sample_accepted = u <= log_accept_thresh

    next_candidate_tree_state = TreeDoublingStateCandidate(
        state=[
            bu.where_left_justified_mask(is_sample_accepted, s0, s1)
            for s0, s1 in zip(next_state_parts, candidate_tree_state.state)
        ],
        target=bu.where_left_justified_mask(
            is_sample_accepted, next_target, candidate_tree_state.target),
        target_grad_parts=[
            bu.where_left_justified_mask(is_sample_accepted, grad0, grad1)
            for grad0, grad1 in zip(next_target_grad_parts,
                                    candidate_tree_state.target_grad_parts)
        ],
        energy=bu.where_left_justified_mask(
            is_sample_accepted, current_energy, candidate_tree_state.energy),
        weight=weight_sum)

    continue_tree = not_divergent & continue_tree_previous
    continue_tree_next = no_u_turns_within_tree & continue_tree

    # Keep the divergence flag only for steps that were actually taken.
    not_divergent_tokeep = tf.where(
        continue_tree_previous,
        not_divergent,
        ps.ones(batch_shape, dtype=tf.bool))

    # min(1., exp(energy_diff)).
    exp_energy_diff = tf.math.exp(tf.minimum(energy_diff, 0.))
    energy_diff_sum = tf.where(continue_tree,
                               energy_diff_sum_previous + exp_energy_diff,
                               energy_diff_sum_previous)

    return (
        iter_ + 1,
        next_seed,
        energy_diff_sum,
        momentum_cumsum,
        leapfrogs_taken,
        next_tree_state,
        next_candidate_tree_state,
        continue_tree_next,
        not_divergent_previous & not_divergent_tokeep,
        velocity_state_memory,
    )
def bootstrap_results(self, init_state):
  """Creates initial `previous_kernel_results` using a supplied `state`."""
  with tf.name_scope(self.name + '.bootstrap_results'):
    if not tf.nest.is_nested(init_state):
      init_state = [init_state]

    # Broadcast a single step size across all state parts, then verify the
    # counts line up.
    step_size = self.step_size
    if len(step_size) == 1:
      step_size = step_size * len(init_state)
    if len(step_size) != len(init_state):
      raise ValueError(
          f'Expected either one step size or {len(init_state)} (size of '
          f'`init_state`), but found {len(step_size)}')

    state_parts, _ = mcmc_util.prepare_state_parts(
        init_state, name='current_state')
    target_log_prob, grads_target_log_prob = (
        mcmc_util.maybe_call_fn_and_grads(
            self.target_log_prob_fn, state_parts))

    # Build a default momentum distribution when none was supplied, then
    # normalize it to the list-like, batch-broadcast form the kernel expects.
    momentum_distribution = self.momentum_distribution
    if momentum_distribution is None:
      momentum_distribution = pu.make_momentum_distribution(
          state_parts, ps.shape(target_log_prob))
    momentum_distribution = pu.maybe_make_list_and_batch_broadcast(
        momentum_distribution, ps.shape(target_log_prob))
    momentum_parts = momentum_distribution.sample()

    def _alloc(shapes_and_dtypes):
      """Allocates zeroed memory with one slot per write instruction."""
      num_slots = max(self._write_instruction) + 1
      return [  # pylint: disable=g-complex-comprehension
          ps.zeros(ps.concat([[num_slots], shape], axis=0), dtype=dtype)
          for shape, dtype in shapes_and_dtypes
      ]

    def _shapes_and_dtypes(parts):
      """Pairs each part with its (shape, dtype) for buffer allocation."""
      return [(ps.shape(part), part.dtype) for part in parts]

    velocity_state_memory = VelocityStateSwap(
        velocity_swap=_alloc(_shapes_and_dtypes(momentum_parts)),
        state_swap=_alloc(_shapes_and_dtypes(init_state)))

    # Boolean diagnostics all share the target's batch shape.
    def _false_like(name):
      return tf.zeros_like(target_log_prob, dtype=tf.bool, name=name)

    return PreconditionedNUTSKernelResults(
        target_log_prob=target_log_prob,
        grads_target_log_prob=grads_target_log_prob,
        velocity_state_memory=velocity_state_memory,
        step_size=step_size,
        log_accept_ratio=tf.zeros_like(
            target_log_prob, name='log_accept_ratio'),
        leapfrogs_taken=tf.zeros_like(
            target_log_prob, dtype=TREE_COUNT_DTYPE, name='leapfrogs_taken'),
        is_accepted=_false_like('is_accepted'),
        reach_max_depth=_false_like('reach_max_depth'),
        has_divergence=_false_like('has_divergence'),
        energy=compute_hamiltonian(
            target_log_prob, momentum_parts, momentum_distribution),
        momentum_distribution=momentum_distribution,
        # Allow room for one_step's seed.
        seed=samplers.zeros_seed(),
    )
def one_step(self, current_state, previous_kernel_results, seed=None):
  """Runs one iteration of the preconditioned No-U-Turn sampler.

  Samples a fresh momentum for the current state, then repeatedly doubles
  the trajectory via `self._loop_tree_doubling` inside a `tf.while_loop`
  until `max_tree_depth` is reached or no batch member wants to continue.

  Args:
    current_state: `Tensor` or list of `Tensor`s holding the current chain
      state. A bare `Tensor` is wrapped in a list here and unwrapped again
      before returning.
    previous_kernel_results: `PreconditionedNUTSKernelResults` from
      `bootstrap_results` or a prior `one_step` call; supplies the target
      log-prob, its gradients, the step size and the momentum distribution.
    seed: Optional PRNG seed; sanitized via `samplers.sanitize_seed`.

  Returns:
    A two-tuple of the next chain state (same structure as `current_state`)
    and the updated `PreconditionedNUTSKernelResults`.
  """
  seed = samplers.sanitize_seed(seed)  # Retain for diagnostics.
  start_trajectory_seed, loop_seed = samplers.split_seed(seed)

  with tf.name_scope(self.name + '.one_step'):
    unwrap_state_list = not tf.nest.is_nested(current_state)
    if unwrap_state_list:
      current_state = [current_state]

    momentum_distribution = previous_kernel_results.momentum_distribution

    current_target_log_prob = previous_kernel_results.target_log_prob
    [init_momentum, init_energy, log_slice_sample
    ] = self._start_trajectory_batched(
        current_state,
        current_target_log_prob,
        momentum_distribution=momentum_distribution,
        seed=start_trajectory_seed)

    def _copy(v):
      # Tiles `v` along a new leading axis of size 2 by multiplying with
      # ones of shape [2, 1, ..., 1] — presumably one copy per trajectory
      # end of the doubling tree; confirm against TreeDoublingState usage.
      return v * ps.ones(
          ps.pad(
              [2], paddings=[[0, ps.rank(v)]], constant_values=1),
          dtype=v.dtype)

    # Velocity is the gradients component of the kinetic-energy call; the
    # `m + 0` copy forces fresh tensors.
    _, init_velocity = mcmc_util.maybe_call_fn_and_grads(
        get_kinetic_energy_fn(momentum_distribution),
        [m + 0 for m in init_momentum])  # Breaks cache.

    initial_state = TreeDoublingState(
        momentum=init_momentum,
        velocity=init_velocity,
        state=current_state,
        target=current_target_log_prob,
        target_grad_parts=previous_kernel_results.grads_target_log_prob)
    # Every field duplicated with a leading axis of 2 (see `_copy`).
    initial_step_state = tf.nest.map_structure(_copy, initial_state)

    if MULTINOMIAL_SAMPLE:
      init_weight = tf.zeros_like(init_energy)  # log(exp(H0 - H0))
    else:
      init_weight = tf.ones_like(init_energy, dtype=TREE_COUNT_DTYPE)

    candidate_state = TreeDoublingStateCandidate(
        state=current_state,
        target=current_target_log_prob,
        target_grad_parts=previous_kernel_results.grads_target_log_prob,
        energy=init_energy,
        weight=init_weight)

    initial_step_metastate = TreeDoublingMetaState(
        candidate_state=candidate_state,
        is_accepted=tf.zeros_like(init_energy, dtype=tf.bool),
        momentum_sum=init_momentum,
        energy_diff_sum=tf.zeros_like(init_energy),
        leapfrog_count=tf.zeros_like(init_energy, dtype=TREE_COUNT_DTYPE),
        continue_tree=tf.ones_like(init_energy, dtype=tf.bool),
        not_divergence=tf.ones_like(init_energy, dtype=tf.bool))

    # Convert the write/read instruction into TensorArray so that it is
    # compatible with XLA.
    write_instruction = tf.TensorArray(
        TREE_COUNT_DTYPE,
        size=len(self._write_instruction),
        clear_after_read=False).unstack(self._write_instruction)
    read_instruction = tf.TensorArray(
        tf.int32,
        size=len(self._read_instruction),
        clear_after_read=False).unstack(self._read_instruction)

    current_step_meta_info = OneStepMetaInfo(
        log_slice_sample=log_slice_sample,
        init_energy=init_energy,
        write_instruction=write_instruction,
        read_instruction=read_instruction
        )

    velocity_state_memory = VelocityStateSwap(
        velocity_swap=self.init_velocity_state_memory(init_momentum),
        state_swap=self.init_velocity_state_memory(current_state))

    step_size = _prepare_step_size(
        previous_kernel_results.step_size,
        current_target_log_prob.dtype,
        len(current_state))
    # Tree-doubling loop: stops at max depth or once no batch member's
    # `continue_tree` flag is still set.
    _, _, _, new_step_metastate = tf.while_loop(
        cond=lambda iter_, seed, state, metastate: (  # pylint: disable=g-long-lambda
            (iter_ < self.max_tree_depth) &
            tf.reduce_any(metastate.continue_tree)),
        body=lambda iter_, seed, state, metastate: self._loop_tree_doubling(  # pylint: disable=g-long-lambda
            step_size,
            velocity_state_memory,
            current_step_meta_info,
            iter_,
            state,
            metastate,
            momentum_distribution,
            seed),
        loop_vars=(
            tf.zeros([], dtype=tf.int32, name='iter'),
            loop_seed,
            initial_step_state,
            initial_step_metastate),
        parallel_iterations=self.parallel_iterations,
    )

    kernel_results = PreconditionedNUTSKernelResults(
        target_log_prob=new_step_metastate.candidate_state.target,
        grads_target_log_prob=(
            new_step_metastate.candidate_state.target_grad_parts),
        step_size=previous_kernel_results.step_size,
        # Mean of min(1, exp(energy_diff)) over leapfrog steps, in log space.
        log_accept_ratio=tf.math.log(
            new_step_metastate.energy_diff_sum /
            tf.cast(new_step_metastate.leapfrog_count,
                    dtype=new_step_metastate.energy_diff_sum.dtype)),
        leapfrogs_taken=(
            new_step_metastate.leapfrog_count * self.unrolled_leapfrog_steps
        ),
        is_accepted=new_step_metastate.is_accepted,
        # Still wanting to continue at loop exit means the depth cap bound us.
        reach_max_depth=new_step_metastate.continue_tree,
        has_divergence=~new_step_metastate.not_divergence,
        energy=new_step_metastate.candidate_state.energy,
        momentum_distribution=momentum_distribution,
        seed=seed,
    )

    result_state = new_step_metastate.candidate_state.state
    if unwrap_state_list:
      # Mirror the input: bare Tensor in, bare Tensor out.
      result_state = result_state[0]
    return result_state, kernel_results
def get_velocity_parts(half_next_momentum_parts):
  """Returns the velocity for each momentum part.

  The velocity is obtained as the gradients component of
  `mcmc_util.maybe_call_fn_and_grads` applied to the kinetic energy.
  """
  value_and_grads = mcmc_util.maybe_call_fn_and_grads(
      kinetic_energy_fn, half_next_momentum_parts)
  return value_and_grads[1]