def _prepare_args(target_log_prob_fn,
                  state,
                  step_size,
                  target_log_prob=None,
                  grads_target_log_prob=None,
                  maybe_expand=False):
  """Helper which processes input args to meet list-like assumptions.

  Args:
    target_log_prob_fn: Python callable handed, together with the state
      parts, to `mcmc_util.maybe_call_fn_and_grads`.
    state: `Tensor`-like value, or a list-like collection of them.
    step_size: Step size, or a list-like collection of them. A single step
      size is repeated once per state part.
    target_log_prob: Optional value forwarded to
      `mcmc_util.maybe_call_fn_and_grads`.
    grads_target_log_prob: Optional value forwarded to
      `mcmc_util.maybe_call_fn_and_grads`.
    maybe_expand: Python `bool`. When `True`, state parts and step sizes are
      returned as lists even if `state` is not list-like.

  Returns:
    `[state_parts, step_sizes, target_log_prob, grads_target_log_prob]`.
    When `state` is not list-like and `maybe_expand` is `False`, the first two
    entries are the single element rather than a one-element list.

  Raises:
    ValueError: if the number of step sizes is neither one nor the number of
      state parts.
  """
  state_is_list = mcmc_util.is_list_like(state)
  state_parts = [
      tf.convert_to_tensor(part, name='current_state')
      for part in (list(state) if state_is_list else [state])
  ]

  target_log_prob, grads_target_log_prob = mcmc_util.maybe_call_fn_and_grads(
      target_log_prob_fn, state_parts, target_log_prob, grads_target_log_prob)

  # Step sizes take the dtype of `target_log_prob`, so they are converted
  # only after it is available.
  if mcmc_util.is_list_like(step_size):
    raw_step_sizes = list(step_size)
  else:
    raw_step_sizes = [step_size]
  step_sizes = [
      tf.convert_to_tensor(ss, name='step_size', dtype=target_log_prob.dtype)
      for ss in raw_step_sizes
  ]
  if len(step_sizes) == 1:
    # Reuse the single step size for every state part.
    step_sizes = step_sizes * len(state_parts)
  if len(step_sizes) != len(state_parts):
    raise ValueError('There should be exactly one `step_size` or it should '
                     'have same length as `current_state`.')

  if maybe_expand or state_is_list:
    return [state_parts, step_sizes, target_log_prob, grads_target_log_prob]
  # Non-list input: unwrap the sole state part and step size.
  return [
      state_parts[0], step_sizes[0], target_log_prob, grads_target_log_prob
  ]
def _prepare_args(target_log_prob_fn,
                  volatility_fn,
                  state,
                  step_size,
                  target_log_prob=None,
                  grads_target_log_prob=None,
                  volatility=None,
                  grads_volatility_fn=None):
  """Helper which processes input args to meet list-like assumptions.

  Converts `state` to a list of `Tensor`s, obtains the target log-prob and
  volatility (and their gradients), normalizes `step_size` to one `Tensor`
  per state part, and derives the diffusion drift via `_get_drift`.

  Args:
    target_log_prob_fn: Python callable forwarded to
      `mcmc_util.maybe_call_fn_and_grads`.
    volatility_fn: Python callable forwarded to
      `_maybe_call_volatility_fn_and_grads`.
    state: `Tensor`-like value, or a list-like collection of them.
    step_size: Step size, or a list-like collection of them. A single step
      size is repeated once per state part.
    target_log_prob: Optional precomputed target log-prob (forwarded).
    grads_target_log_prob: Optional precomputed gradients (forwarded).
    volatility: Optional precomputed volatility (forwarded).
    grads_volatility_fn: Optional precomputed volatility gradients
      (forwarded).

  Returns:
    `[state_parts, step_sizes, target_log_prob, grads_target_log_prob,
    volatility_parts, grads_volatility, diffusion_drift_parts]`.

  Raises:
    ValueError: if the number of step sizes is neither one nor the number of
      state parts.
  """
  raw_states = list(state) if mcmc_util.is_list_like(state) else [state]
  state_parts = [
      tf.convert_to_tensor(part, name='current_state') for part in raw_states
  ]

  target_log_prob, grads_target_log_prob = mcmc_util.maybe_call_fn_and_grads(
      target_log_prob_fn, state_parts, target_log_prob, grads_target_log_prob)
  volatility_parts, grads_volatility = _maybe_call_volatility_fn_and_grads(
      volatility_fn, state_parts, volatility, grads_volatility_fn)

  step_size_list = (list(step_size) if mcmc_util.is_list_like(step_size)
                    else [step_size])
  step_sizes = [
      tf.convert_to_tensor(ss, name='step_size', dtype=target_log_prob.dtype)
      for ss in step_size_list
  ]
  if len(step_sizes) == 1:
    step_sizes = step_sizes * len(state_parts)
  if len(step_sizes) != len(state_parts):
    raise ValueError(
        'There should be exactly one `step_size` or it should '
        'have same length as `current_state`.')

  # `volatility_parts` needs the number of chains as a leading dimension, so
  # it is broadcast to the shape of each state part (each dimension of a
  # state part may carry its own volatility value).
  volatility_parts = _maybe_broadcast_volatility(volatility_parts, state_parts)
  diffusion_drift_parts = _get_drift(step_sizes, volatility_parts,
                                     grads_volatility, grads_target_log_prob)

  return [
      state_parts,
      step_sizes,
      target_log_prob,
      grads_target_log_prob,
      volatility_parts,
      grads_volatility,
      diffusion_drift_parts,
  ]
def testGradientComputesCorrectly(self):
  """Checks value and gradients of f(x, y) = x**2 + y**2 at (3, 3)."""
  dtype = np.float32

  def sum_of_squares(x, y):
    return x**2 + y**2

  # Gradients are taken with respect to these tensor inputs.
  inputs = [tf.convert_to_tensor(dtype(3)) for _ in range(2)]
  value, gradients = maybe_call_fn_and_grads(sum_of_squares, inputs)
  value_, gradients_ = self.evaluate([value, gradients])

  # f(3, 3) = 9 + 9 = 18; each partial derivative is 2 * 3 = 6.
  self.assertNear(18., value_, err=1e-5)
  for gradient_ in gradients_:
    self.assertAllClose(gradient_, dtype(6), atol=0., rtol=1e-5)
def testNoGradientsNiceError(self):
  """Checks that a `None` gradient produces an informative `ValueError`."""
  dtype = np.float32

  def fn(x, y):
    # `stop_gradient` cuts the dependence on `y`, so its gradient is `None`.
    return x**2 + tf.stop_gradient(y)**2

  fn_args = [dtype(3), dtype(3)]
  # Convert function input to a list of named tensors (`arg0`, `arg1`).
  fn_args = [
      tf.convert_to_tensor(arg, name='arg{}'.format(i))
      for i, arg in enumerate(fn_args)
  ]
  # The graph-mode message is expected to mention the tensor name `arg1`;
  # the eager-mode message is not.
  if tf.executing_eagerly():
    expected_regex = 'Encountered `None`.*\n.*fn_arg_list.*\n.*None'
  else:
    expected_regex = 'Encountered `None`.*\n.*fn_arg_list.*arg1.*\n.*None'
  # `assertRaisesRegexp` is a deprecated alias that was removed in
  # Python 3.12; `assertRaisesRegex` is the supported name.
  with self.assertRaisesRegex(ValueError, expected_regex):
    maybe_call_fn_and_grads(fn, fn_args)
def _leapfrog_step(current_momentums,
                   target_log_prob_fn,
                   current_state_parts,
                   step_sizes,
                   current_grads_target_log_prob,
                   name=None):
  """Applies one step of the leapfrog integrator.

  Performs a half-step momentum update, a full-step state update, re-evaluates
  `target_log_prob_fn` (and its gradients) at the new state, then performs a
  second half-step momentum update.

  Args:
    current_momentums: Python `list` of momentum `Tensor`s, one per state part.
    target_log_prob_fn: Python callable evaluated at the proposed state parts.
    current_state_parts: Python `list` of current state `Tensor`s.
    step_sizes: Python `list` of step sizes, one per state part.
    current_grads_target_log_prob: Python `list` of gradients of the target
      log-prob at `current_state_parts`.
    name: Python `str` name scope; defaults to '_leapfrog_step'.

  Returns:
    `[proposed_momentums, proposed_state_parts, proposed_target_log_prob,
    proposed_grads_target_log_prob]`.

  Raises:
    TypeError: if the proposed target log-prob does not have a floating dtype.
    ValueError: if any proposed gradient is `None`.
  """
  scope_values = [
      current_momentums, current_state_parts, step_sizes,
      current_grads_target_log_prob
  ]
  with tf.name_scope(name, '_leapfrog_step', scope_values):
    # Half-step momentum update with the gradients at the current state.
    half_step_momentums = [
        momentum + 0.5 * eps * grad
        for momentum, eps, grad in zip(current_momentums, step_sizes,
                                       current_grads_target_log_prob)
    ]
    # Full-step state update with the half-stepped momentum.
    proposed_state_parts = [
        position + eps * momentum
        for position, eps, momentum in zip(current_state_parts, step_sizes,
                                           half_step_momentums)
    ]

    proposed_target_log_prob, proposed_grads_target_log_prob = (
        mcmc_util.maybe_call_fn_and_grads(target_log_prob_fn,
                                          proposed_state_parts))
    if not proposed_target_log_prob.dtype.is_floating:
      raise TypeError('`target_log_prob_fn` must produce a `Tensor` '
                      'with `float` `dtype`.')
    if any(grad is None for grad in proposed_grads_target_log_prob):
      raise ValueError(
          'Encountered `None` gradient. Does your target `target_log_prob_fn` '
          'access all `tf.Variable`s via `tf.get_variable`?\n'
          ' current_state_parts: {}\n'
          ' proposed_state_parts: {}\n'
          ' proposed_grads_target_log_prob: {}'.format(
              current_state_parts,
              proposed_state_parts,
              proposed_grads_target_log_prob))

    # Second half-step momentum update with the gradients at the new state.
    proposed_momentums = [
        momentum + 0.5 * eps * grad
        for momentum, eps, grad in zip(half_step_momentums, step_sizes,
                                       proposed_grads_target_log_prob)
    ]
    return [
        proposed_momentums,
        proposed_state_parts,
        proposed_target_log_prob,
        proposed_grads_target_log_prob,
    ]
def bootstrap_results(self, init_state):
  """Creates the initial kernel results for `init_state`.

  Args:
    init_state: `Tensor`-like value, or a list-like collection of them.

  Returns:
    `UncalibratedHamiltonianMonteCarloKernelResults` holding the target
    log-prob and its gradients at `init_state`, and a
    `log_acceptance_correction` of zeros shaped like the target log-prob.
  """
  scope_name = mcmc_util.make_name(self.name, 'hmc', 'bootstrap_results')
  with tf.name_scope(name=scope_name, values=[init_state]):
    state_list = (init_state if mcmc_util.is_list_like(init_state)
                  else [init_state])
    state_parts = [tf.convert_to_tensor(part) for part in state_list]
    target_log_prob, grads_target_log_prob = (
        mcmc_util.maybe_call_fn_and_grads(self.target_log_prob_fn,
                                          state_parts))
    return UncalibratedHamiltonianMonteCarloKernelResults(
        log_acceptance_correction=tf.zeros_like(target_log_prob),
        target_log_prob=target_log_prob,
        grads_target_log_prob=grads_target_log_prob,
    )
def bootstrap_results(self, init_state):
  """Creates the initial kernel results for `init_state`.

  If `self.state_gradients_are_stopped` is truthy, each state part is passed
  through `tf.stop_gradient`; otherwise it is converted with
  `tf.convert_to_tensor`.

  Args:
    init_state: `Tensor`-like value, or a list-like collection of them.

  Returns:
    `UncalibratedHamiltonianMonteCarloKernelResults` holding the target
    log-prob and its gradients at `init_state`, and a
    `log_acceptance_correction` of zeros shaped like the target log-prob.
  """
  scope_name = mcmc_util.make_name(self.name, 'hmc', 'bootstrap_results')
  with tf.name_scope(name=scope_name, values=[init_state]):
    state_list = (init_state if mcmc_util.is_list_like(init_state)
                  else [init_state])
    to_tensor = (tf.stop_gradient if self.state_gradients_are_stopped
                 else tf.convert_to_tensor)
    state_parts = [to_tensor(part) for part in state_list]
    target_log_prob, grads_target_log_prob = (
        mcmc_util.maybe_call_fn_and_grads(self.target_log_prob_fn,
                                          state_parts))
    return UncalibratedHamiltonianMonteCarloKernelResults(
        log_acceptance_correction=tf.zeros_like(target_log_prob),
        target_log_prob=target_log_prob,
        grads_target_log_prob=grads_target_log_prob,
    )
def bootstrap_results(self, init_state):
  """Creates the initial kernel results for `init_state`.

  If `self.state_gradients_are_stopped` is truthy, each state part is passed
  through `tf.stop_gradient`; otherwise it is converted with
  `tf.convert_to_tensor`.

  Args:
    init_state: `Tensor`-like value, or a list-like collection of them.

  Returns:
    `UncalibratedHamiltonianMonteCarloKernelResults` holding the target
    log-prob and its gradients at `init_state`, and a zero
    `log_acceptance_correction`. When `self._store_parameters_in_results` is
    truthy, `step_size` holds `self.step_size` converted to the target
    log-prob dtype and `num_leapfrog_steps` holds an `int64` `Tensor`;
    otherwise both fields are empty lists.
  """
  scope_name = mcmc_util.make_name(self.name, 'hmc', 'bootstrap_results')
  with tf.name_scope(name=scope_name, values=[init_state]):
    state_list = (init_state if mcmc_util.is_list_like(init_state)
                  else [init_state])
    if self.state_gradients_are_stopped:
      state_parts = [tf.stop_gradient(part) for part in state_list]
    else:
      state_parts = [tf.convert_to_tensor(value=part) for part in state_list]
    target_log_prob, grads_target_log_prob = (
        mcmc_util.maybe_call_fn_and_grads(self.target_log_prob_fn,
                                          state_parts))

    if not self._store_parameters_in_results:
      # Parameters are not tracked; empty lists act as placeholders.
      return UncalibratedHamiltonianMonteCarloKernelResults(
          log_acceptance_correction=tf.zeros_like(target_log_prob),
          target_log_prob=target_log_prob,
          grads_target_log_prob=grads_target_log_prob,
          step_size=[],
          num_leapfrog_steps=[])

    def _step_size_to_tensor(value):
      # Step sizes share the dtype of the target log-prob.
      return tf.convert_to_tensor(
          value=value, dtype=target_log_prob.dtype, name='step_size')

    return UncalibratedHamiltonianMonteCarloKernelResults(
        log_acceptance_correction=tf.zeros_like(target_log_prob),
        target_log_prob=target_log_prob,
        grads_target_log_prob=grads_target_log_prob,
        step_size=tf.nest.map_structure(_step_size_to_tensor, self.step_size),
        num_leapfrog_steps=tf.convert_to_tensor(
            value=self.num_leapfrog_steps,
            dtype=tf.int64,
            name='num_leapfrog_steps'))
def _integrator_conserves_energy(self, x, independent_chain_ndims):
  """Asserts that 1000 leapfrog steps roughly conserve the energy.

  Draws a standard-normal momentum, runs `_leapfrog_integrator_one_step`
  1000 times inside `tf.while_loop` on `self._log_gamma_log_prob`, and checks
  that the energy `-log_prob + 0.5 * sum(momentum**2)` matches its initial
  value within a relative tolerance of 0.02.

  Args:
    x: Initial state `Tensor`.
    independent_chain_ndims: Number of leading dimensions of `x` that are not
      reduced over; dimensions from this index up to `rank(x)` are summed.
  """
  event_dims = tf.range(independent_chain_ndims, tf.rank(x))

  def target_log_prob_fn(state):
    return self._log_gamma_log_prob(state, event_dims)

  def energy(log_prob, momentum):
    # Negative log-prob plus quadratic kinetic energy over the event dims.
    return -log_prob + 0.5 * tf.reduce_sum(
        input_tensor=momentum**2., axis=event_dims)

  momentum_0 = tf.random.normal(tf.shape(input=x))
  log_prob_0, grad_0 = maybe_call_fn_and_grads(target_log_prob_fn, x)
  old_energy = energy(log_prob_0, momentum_0)

  # Step size is 0.1 divided by the number of event elements.
  event_size = np.prod(self.evaluate(x).shape[independent_chain_ndims:])
  step_size = tf.constant(0.1 / event_size, x.dtype)
  num_leapfrog_steps = tf.constant(1000, np.int32)

  def leapfrog_one_step(*loop_vars):
    return _leapfrog_integrator_one_step(
        target_log_prob_fn, independent_chain_ndims, [step_size], *loop_vars)

  [[momentum_1], _, log_prob_1, _] = tf.while_loop(
      cond=lambda *_: True,
      body=leapfrog_one_step,
      loop_vars=[
          [momentum_0],  # current_momentum_parts
          [x],  # current_state_parts
          log_prob_0,  # current_target_log_prob
          grad_0,  # current_target_log_prob_grad_parts
      ],
      maximum_iterations=num_leapfrog_steps)
  new_energy = energy(log_prob_1, momentum_1)

  old_energy_, new_energy_ = self.evaluate([old_energy, new_energy])
  mean_relative_change = (1. - new_energy_ / old_energy_).mean()
  tf.compat.v1.logging.vlog(
      1, 'average energy relative change: {}'.format(mean_relative_change))
  self.assertAllClose(old_energy_, new_energy_, atol=0., rtol=0.02)
def _integrator_conserves_energy(self, x, independent_chain_ndims):
  """Asserts that 1000 leapfrog steps roughly conserve the energy.

  Draws a standard-normal momentum, runs `_leapfrog_integrator_one_step`
  1000 times inside `tf.while_loop` on `self._log_gamma_log_prob`, and checks
  that the energy `-log_prob + 0.5 * sum(momentum**2)` matches its initial
  value within a relative tolerance of 0.02.

  Args:
    x: Initial state `Tensor`.
    independent_chain_ndims: Number of leading dimensions of `x` that are not
      reduced over; dimensions from this index up to `rank(x)` are summed.
  """
  event_dims = tf.range(independent_chain_ndims, tf.rank(x))

  def target_log_prob_fn(state):
    return self._log_gamma_log_prob(state, event_dims)

  def energy(log_prob, momentum):
    # Negative log-prob plus quadratic kinetic energy over the event dims.
    return -log_prob + 0.5 * tf.reduce_sum(momentum**2., axis=event_dims)

  momentum_0 = tf.random_normal(tf.shape(x))
  log_prob_0, grad_0 = maybe_call_fn_and_grads(target_log_prob_fn, x)
  old_energy = energy(log_prob_0, momentum_0)

  # Step size is 0.1 divided by the number of event elements.
  event_size = np.prod(self.evaluate(x).shape[independent_chain_ndims:])
  step_size = tf.constant(0.1 / event_size, x.dtype)
  num_leapfrog_steps = tf.constant(1000, np.int32)

  def leapfrog_one_step(*loop_vars):
    return _leapfrog_integrator_one_step(
        target_log_prob_fn, independent_chain_ndims, [step_size], *loop_vars)

  [[momentum_1], _, log_prob_1, _] = tf.while_loop(
      cond=lambda *_: True,
      body=leapfrog_one_step,
      loop_vars=[
          [momentum_0],  # current_momentum_parts
          [x],  # current_state_parts
          log_prob_0,  # current_target_log_prob
          grad_0,  # current_target_log_prob_grad_parts
      ],
      maximum_iterations=num_leapfrog_steps)
  new_energy = energy(log_prob_1, momentum_1)

  old_energy_, new_energy_ = self.evaluate([old_energy, new_energy])
  mean_relative_change = (1. - new_energy_ / old_energy_).mean()
  tf.logging.vlog(
      1, 'average energy relative change: {}'.format(mean_relative_change))
  self.assertAllClose(old_energy_, new_energy_, atol=0., rtol=0.02)
def _prepare_args(target_log_prob_fn,
                  state,
                  step_size,
                  target_log_prob=None,
                  grads_target_log_prob=None,
                  maybe_expand=False,
                  state_gradients_are_stopped=False):
  """Helper which processes input args to meet list-like assumptions.

  Args:
    target_log_prob_fn: Python callable handed, together with the state
      parts, to `mcmc_util.maybe_call_fn_and_grads`.
    state: `Tensor`-like value, or a list-like collection of them.
    step_size: Step size, or a list-like collection of them. A single step
      size is repeated once per state part.
    target_log_prob: Optional value forwarded to
      `mcmc_util.maybe_call_fn_and_grads`.
    grads_target_log_prob: Optional value forwarded to
      `mcmc_util.maybe_call_fn_and_grads`.
    maybe_expand: Python `bool`. When `True`, state parts and step sizes are
      returned as lists even if `state` is not list-like.
    state_gradients_are_stopped: Python `bool`. When `True`, each state part
      is passed through `tf.stop_gradient` after conversion.

  Returns:
    `[state_parts, step_sizes, target_log_prob, grads_target_log_prob]`.
    When `state` is not list-like and `maybe_expand` is `False`, the first two
    entries are the single element rather than a one-element list.

  Raises:
    ValueError: if the number of step sizes is neither one nor the number of
      state parts.
  """
  state_is_list = mcmc_util.is_list_like(state)
  state_parts = []
  for part in (list(state) if state_is_list else [state]):
    state_parts.append(tf.convert_to_tensor(part, name='current_state'))
  if state_gradients_are_stopped:
    state_parts = [tf.stop_gradient(part) for part in state_parts]

  target_log_prob, grads_target_log_prob = mcmc_util.maybe_call_fn_and_grads(
      target_log_prob_fn, state_parts, target_log_prob, grads_target_log_prob)

  # Step sizes take the dtype of `target_log_prob`.
  step_size_list = (list(step_size) if mcmc_util.is_list_like(step_size)
                    else [step_size])
  step_sizes = [
      tf.convert_to_tensor(ss, name='step_size', dtype=target_log_prob.dtype)
      for ss in step_size_list
  ]
  if len(step_sizes) == 1:
    step_sizes = step_sizes * len(state_parts)
  if len(step_sizes) != len(state_parts):
    raise ValueError('There should be exactly one `step_size` or it should '
                     'have same length as `current_state`.')

  keep_lists = maybe_expand or state_is_list
  return [
      state_parts if keep_lists else state_parts[0],
      step_sizes if keep_lists else step_sizes[0],
      target_log_prob,
      grads_target_log_prob,
  ]
def _maybe_call_volatility_fn_and_grads(volatility_fn,
                                        state,
                                        volatility_fn_results=None,
                                        grads_volatility_fn=None):
  """Helper which computes `volatility_fn` results and grads, if needed.

  Args:
    volatility_fn: Python callable forwarded to
      `mcmc_util.maybe_call_fn_and_grads`.
    state: `Tensor`-like value, or a list-like collection of them.
    volatility_fn_results: Optional precomputed result of `volatility_fn`.
    grads_volatility_fn: Optional precomputed gradients. When `None`,
      gradients are computed here and transformed as described below.

  Returns:
    volatility_parts: Python `list` with one volatility per state part; a
      single volatility is repeated for every state part.
    grads_volatility_fn: If they were computed here, these are the gradients
      of `volatility_parts**2`, i.e. `2 * g * volatility`. A `None` gradient
      is replaced by zeros shaped like the matching state part. Precomputed
      gradients are returned as produced by
      `mcmc_util.maybe_call_fn_and_grads`.

  Raises:
    ValueError: if the number of volatility parts is neither one nor the
      number of state parts.
  """
  state_parts = list(state) if mcmc_util.is_list_like(state) else [state]
  # Must be read before the call below, which fills in `grads_volatility_fn`.
  needs_volatility_fn_gradients = grads_volatility_fn is None
  # `check_non_none_grads=False`: `None` gradients are tolerated here and
  # replaced by zeros further down.
  [
      volatility_fn_results,
      grads_volatility_fn,
  ] = mcmc_util.maybe_call_fn_and_grads(
      volatility_fn,
      state_parts,
      volatility_fn_results,
      grads_volatility_fn,
      check_non_none_grads=False)

  # Convert `volatility_fn_results` to a list.
  volatility_parts = (list(volatility_fn_results)
                      if mcmc_util.is_list_like(volatility_fn_results)
                      else [volatility_fn_results])
  if len(volatility_parts) == 1:
    volatility_parts *= len(state_parts)
  if len(volatility_parts) != len(state_parts):
    # Note the trailing space in the first literal: without it the implicit
    # concatenation produced "...or itshould have...".
    raise ValueError('There should be exactly one `volatility_parts` or it '
                     'should have same length as `current_state`.')

  # Chain rule: d(volatility**2) = 2 * volatility * d(volatility).
  if needs_volatility_fn_gradients:
    grads_volatility_fn = [
        2. * g * volatility if g is not None else tf.zeros_like(
            fn_arg, dtype=fn_arg.dtype.base_dtype)
        for g, volatility, fn_arg in zip(grads_volatility_fn,
                                         volatility_parts, state_parts)
    ]

  return volatility_parts, grads_volatility_fn
def _leapfrog_integrator_one_step(target_log_prob_fn,
                                  independent_chain_ndims,
                                  step_sizes,
                                  current_momentum_parts,
                                  current_state_parts,
                                  current_target_log_prob,
                                  current_target_log_prob_grad_parts,
                                  state_gradients_are_stopped=False,
                                  name=None):
  """Applies one step of the leapfrog integrator.

  Assumes a simple quadratic kinetic energy function: `0.5 ||momentum||**2`.

  #### Examples:

  ##### Simple quadratic potential.

  ```python
  import matplotlib.pyplot as plt
  %matplotlib inline
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp
  from tensorflow_probability.python.mcmc.hmc import (
      _leapfrog_integrator_one_step)
  tfd = tfp.distributions

  dims = 10
  num_iter = int(1e3)
  dtype = np.float32

  position = tf.placeholder(np.float32)
  momentum = tf.placeholder(np.float32)

  target_log_prob_fn = tfd.MultivariateNormalDiag(
      loc=tf.zeros(dims, dtype)).log_prob

  def _leapfrog_one_step(*args):
    # Closure representing computation done during each leapfrog step.
    return _leapfrog_integrator_one_step(
        target_log_prob_fn=target_log_prob_fn,
        independent_chain_ndims=0,
        step_sizes=[0.1],
        current_momentum_parts=args[0],
        current_state_parts=args[1],
        current_target_log_prob=args[2],
        current_target_log_prob_grad_parts=args[3])

  # Do leapfrog integration.
  [
      [next_momentum],
      [next_position],
      next_target_log_prob,
      next_target_log_prob_grad_parts,
  ] = tf.while_loop(
      cond=lambda *args: True,
      body=_leapfrog_one_step,
      loop_vars=[
        [momentum],
        [position],
        target_log_prob_fn(position),
        tf.gradients(target_log_prob_fn(position), position),
      ],
      maximum_iterations=3)

  momentum_ = np.random.randn(dims).astype(dtype)
  position_ = np.random.randn(dims).astype(dtype)
  positions = np.zeros([num_iter, dims], dtype)

  with tf.Session() as sess:
    for i in range(num_iter):
      position_, momentum_ = sess.run(
          [next_position, next_momentum],
          feed_dict={position: position_, momentum: momentum_})
      positions[i] = position_

  plt.plot(positions[:, 0]);  # Sinusoidal.
  ```

  Args:
    target_log_prob_fn: Python callable which takes an argument like
      `*current_state_parts` and returns its (possibly unnormalized) log-density
      under the target distribution.
    independent_chain_ndims: Scalar `int` `Tensor` representing the number of
      leftmost `Tensor` dimensions which index independent chains.
    step_sizes: Python `list` of `Tensor`s representing the step size for the
      leapfrog integrator; one per state part. Must broadcast with the shape
      of `current_state_parts`. Larger step sizes lead to faster progress,
      but too-large step sizes make rejection exponentially more likely. When
      possible, it's often helpful to match per-variable step sizes to the
      standard deviations of the target distribution in each variable.
    current_momentum_parts: Python `list` of `Tensor`s containing the
      value(s) of the momentum variable(s) to update; one per state part.
    current_state_parts: Python `list` of `Tensor`s representing the current
      state(s) of the Markov chain(s). The first `independent_chain_ndims` of
      the `Tensor`(s) index different chains.
    current_target_log_prob: `Tensor` representing the value of
      `target_log_prob_fn(*current_state_parts)`. The only reason to specify
      this argument is to reduce TF graph size.
    current_target_log_prob_grad_parts: Python list of `Tensor`s representing
      gradient of `target_log_prob_fn(*current_state_parts`) wrt
      `current_state_parts`. Must have same shape as `current_state_parts`. The
      only reason to specify this argument is to reduce TF graph size.
    state_gradients_are_stopped: Python `bool` indicating that the proposed new
      state be run through `tf.stop_gradient`. This is particularly useful when
      combining optimization over samples from the HMC chain.
      Default value: `False` (i.e., do not apply `stop_gradient`).
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'hmc_leapfrog_integrator_one_step').

  Returns:
    proposed_momentum_parts: Updated value of the momentum.
    proposed_state_parts: Tensor or Python list of `Tensor`s representing the
      state(s) of the Markov chain(s) at each result step. Has same shape as
      input `current_state_parts`.
    proposed_target_log_prob: `Tensor` representing the value of
      `target_log_prob_fn` at `next_state`.
    proposed_target_log_prob_grad_parts: Gradient of `proposed_target_log_prob`
      wrt `next_state`.

  Raises:
    ValueError: if `len(current_momentum_parts) != len(current_state_parts)`.
    ValueError: if `len(step_sizes) != len(current_state_parts)`.
    ValueError: if
      `len(current_target_log_prob_grad_parts) != len(current_state_parts)`.
    ValueError: if any gradient of `target_log_prob_fn` at the proposed state
      is `None`.
    TypeError: if the proposed target log-prob does not have a floating dtype.
  """
  # Validate lengths up front: `zip` below would otherwise silently truncate
  # to the shortest list and drop state parts.
  num_state_parts = len(current_state_parts)
  for arg_name, parts in [
      ('current_momentum_parts', current_momentum_parts),
      ('step_sizes', step_sizes),
      ('current_target_log_prob_grad_parts',
       current_target_log_prob_grad_parts),
  ]:
    if len(parts) != num_state_parts:
      raise ValueError(
          '`{}` must have the same length as `current_state_parts` '
          '({} != {}).'.format(arg_name, len(parts), num_state_parts))

  # Note on per-variable step sizes:
  #
  # Using per-variable step sizes is equivalent to using the same step
  # size for all variables and adding a diagonal mass matrix in the
  # kinetic energy term of the Hamiltonian being integrated. This is
  # hinted at by Neal (2011) but not derived in detail there.
  #
  # Let x and v be position and momentum variables respectively.
  # Let g(x) be the gradient of `target_log_prob_fn(x)`.
  # Let S be a diagonal matrix of per-variable step sizes.
  # Let the Hamiltonian H(x, v) = -target_log_prob_fn(x) + 0.5 * ||v||**2.
  #
  # Using per-variable step sizes gives the updates
  # v'  = v  + 0.5 * matmul(S, g(x))
  # x'' = x  + matmul(S, v')
  # v'' = v' + 0.5 * matmul(S, g(x''))
  #
  # Let u = matmul(inv(S), v).
  # Multiplying v by inv(S) in the updates above gives the transformed dynamics
  # u'  = matmul(inv(S), v')  = matmul(inv(S), v) + 0.5 * g(x)
  #                           = u + 0.5 * g(x)
  # x'' = x + matmul(S, v') = x + matmul(S**2, u')
  # u'' = matmul(inv(S), v'') = matmul(inv(S), v') + 0.5 * g(x'')
  #                           = u' + 0.5 * g(x'')
  #
  # These are exactly the leapfrog updates for the Hamiltonian
  # H'(x, u) = -target_log_prob_fn(x) + 0.5 * u^T S**2 u
  #          = -target_log_prob_fn(x) + 0.5 * ||v||**2 = H(x, v).
  #
  # To summarize:
  #
  # * Using per-variable step sizes implicitly simulates the dynamics
  #   of the Hamiltonian H' (which are energy-conserving in H'). We
  #   keep track of v instead of u, but the underlying dynamics are
  #   the same if we transform back.
  # * The value of the Hamiltonian H'(x, u) is the same as the value
  #   of the original Hamiltonian H(x, v) after we transform back from
  #   u to v.
  # * Sampling v ~ N(0, I) is equivalent to sampling u ~ N(0, S**-2).
  #
  # So using per-variable step sizes in HMC will give results that are
  # exactly identical to explicitly using a diagonal mass matrix.

  with tf.name_scope(name, 'hmc_leapfrog_integrator_one_step', [
      independent_chain_ndims, step_sizes, current_momentum_parts,
      current_state_parts, current_target_log_prob,
      current_target_log_prob_grad_parts
  ]):

    # Step 1: Update momentum.
    proposed_momentum_parts = [
        v + 0.5 * eps * g
        for v, eps, g in zip(current_momentum_parts, step_sizes,
                             current_target_log_prob_grad_parts)
    ]

    # Step 2: Update state.
    proposed_state_parts = [
        x + eps * v
        for x, eps, v in zip(current_state_parts, step_sizes,
                             proposed_momentum_parts)
    ]

    if state_gradients_are_stopped:
      proposed_state_parts = [
          tf.stop_gradient(x) for x in proposed_state_parts
      ]

    # Step 3a: Re-evaluate target-log-prob (and grad) at proposed state.
    [
        proposed_target_log_prob,
        proposed_target_log_prob_grad_parts,
    ] = mcmc_util.maybe_call_fn_and_grads(target_log_prob_fn,
                                          proposed_state_parts)

    if not proposed_target_log_prob.dtype.is_floating:
      raise TypeError('`target_log_prob_fn` must produce a `Tensor` '
                      'with `float` `dtype`.')

    if any(g is None for g in proposed_target_log_prob_grad_parts):
      raise ValueError(
          'Encountered `None` gradient. Does your target `target_log_prob_fn` '
          'access all `tf.Variable`s via `tf.get_variable`?\n'
          ' current_state_parts: {}\n'
          ' proposed_state_parts: {}\n'
          ' proposed_target_log_prob_grad_parts: {}'.format(
              current_state_parts,
              proposed_state_parts,
              proposed_target_log_prob_grad_parts))

    # Step 3b: Update momentum (again).
    proposed_momentum_parts = [
        v + 0.5 * eps * g
        for v, eps, g in zip(proposed_momentum_parts, step_sizes,
                             proposed_target_log_prob_grad_parts)
    ]

    return [
        proposed_momentum_parts,
        proposed_state_parts,
        proposed_target_log_prob,
        proposed_target_log_prob_grad_parts,
    ]
def _prepare_args(target_log_prob_fn,
                  volatility_fn,
                  state,
                  step_size,
                  target_log_prob=None,
                  grads_target_log_prob=None,
                  volatility=None,
                  grads_volatility_fn=None,
                  diffusion_drift=None,
                  parallel_iterations=10):
  """Helper which processes input args to meet list-like assumptions.

  Args:
    target_log_prob_fn: Python callable forwarded to
      `mcmc_util.maybe_call_fn_and_grads`.
    volatility_fn: Python callable forwarded to
      `_maybe_call_volatility_fn_and_grads`.
    state: `Tensor`-like value, or a list-like collection of them.
    step_size: Step size, or a list-like collection of them. A single step
      size is repeated once per state part.
    target_log_prob: Optional precomputed target log-prob (forwarded).
    grads_target_log_prob: Optional precomputed gradients (forwarded).
    volatility: Optional precomputed volatility (forwarded).
    grads_volatility_fn: Optional precomputed volatility gradients
      (forwarded).
    diffusion_drift: Optional precomputed drift, or a list-like collection of
      them with one entry per state part. When `None`, the drift is computed
      with `_get_drift`.
    parallel_iterations: Forwarded to `_maybe_call_volatility_fn_and_grads`.

  Returns:
    `[state_parts, step_sizes, target_log_prob, grads_target_log_prob,
    volatility_parts, grads_volatility, diffusion_drift_parts]`.

  Raises:
    ValueError: if the number of step sizes is neither one nor the number of
      state parts.
    ValueError: if `diffusion_drift` is given and the number of drift parts
      differs from the number of state parts.
  """
  state_parts = list(state) if mcmc_util.is_list_like(state) else [state]

  [
      target_log_prob,
      grads_target_log_prob,
  ] = mcmc_util.maybe_call_fn_and_grads(target_log_prob_fn,
                                        state_parts,
                                        target_log_prob,
                                        grads_target_log_prob)
  [
      volatility_parts,
      grads_volatility,
  ] = _maybe_call_volatility_fn_and_grads(
      volatility_fn,
      state_parts,
      volatility,
      grads_volatility_fn,
      distribution_util.prefer_static_shape(target_log_prob),
      parallel_iterations)

  step_sizes = (list(step_size) if mcmc_util.is_list_like(step_size)
                else [step_size])
  step_sizes = [
      tf.convert_to_tensor(s, name='step_size', dtype=target_log_prob.dtype)
      for s in step_sizes
  ]
  if len(step_sizes) == 1:
    step_sizes *= len(state_parts)
  if len(state_parts) != len(step_sizes):
    raise ValueError(
        'There should be exactly one `step_size` or it should '
        'have same length as `current_state`.')

  if diffusion_drift is None:
    diffusion_drift_parts = _get_drift(step_sizes, volatility_parts,
                                       grads_volatility,
                                       grads_target_log_prob)
  else:
    diffusion_drift_parts = (list(diffusion_drift)
                             if mcmc_util.is_list_like(diffusion_drift)
                             else [diffusion_drift])
    # Compare against the normalized list, not the raw argument: a single
    # (non-list) drift `Tensor` has no meaningful `len` here (it would be its
    # leading dimension, or a `TypeError` for symbolic tensors).
    if len(state_parts) != len(diffusion_drift_parts):
      raise ValueError(
          'There should be exactly one `diffusion_drift` or it '
          'should have same length as list-like `current_state`.')

  return [
      state_parts,
      step_sizes,
      target_log_prob,
      grads_target_log_prob,
      volatility_parts,
      grads_volatility,
      diffusion_drift_parts,
  ]
def _leapfrog_integrator_one_step(
    target_log_prob_fn,
    independent_chain_ndims,
    step_sizes,
    current_momentum_parts,
    current_state_parts,
    current_target_log_prob,
    current_target_log_prob_grad_parts,
    state_gradients_are_stopped=False,
    name=None):
  """Applies one step of the leapfrog integrator.

  Assumes a simple quadratic kinetic energy function: `0.5 ||momentum||**2`.
  Each step size is cast to the dtype of the corresponding momentum part
  before use.

  #### Examples:

  ##### Simple quadratic potential.

  ```python
  import matplotlib.pyplot as plt
  %matplotlib inline
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp
  from tensorflow_probability.python.mcmc.hmc import (
      _leapfrog_integrator_one_step)
  tfd = tfp.distributions

  dims = 10
  num_iter = int(1e3)
  dtype = np.float32

  position = tf.placeholder(np.float32)
  momentum = tf.placeholder(np.float32)

  target_log_prob_fn = tfd.MultivariateNormalDiag(
      loc=tf.zeros(dims, dtype)).log_prob

  def _leapfrog_one_step(*args):
    # Closure representing computation done during each leapfrog step.
    return _leapfrog_integrator_one_step(
        target_log_prob_fn=target_log_prob_fn,
        independent_chain_ndims=0,
        step_sizes=[0.1],
        current_momentum_parts=args[0],
        current_state_parts=args[1],
        current_target_log_prob=args[2],
        current_target_log_prob_grad_parts=args[3])

  # Do leapfrog integration.
  [
      [next_momentum],
      [next_position],
      next_target_log_prob,
      next_target_log_prob_grad_parts,
  ] = tf.while_loop(
      cond=lambda *args: True,
      body=_leapfrog_one_step,
      loop_vars=[
        [momentum],
        [position],
        target_log_prob_fn(position),
        tf.gradients(target_log_prob_fn(position), position),
      ],
      maximum_iterations=3)

  momentum_ = np.random.randn(dims).astype(dtype)
  position_ = np.random.randn(dims).astype(dtype)
  positions = np.zeros([num_iter, dims], dtype)

  with tf.Session() as sess:
    for i in range(num_iter):
      position_, momentum_ = sess.run(
          [next_position, next_momentum],
          feed_dict={position: position_, momentum: momentum_})
      positions[i] = position_

  plt.plot(positions[:, 0]);  # Sinusoidal.
  ```

  Args:
    target_log_prob_fn: Python callable which takes an argument like
      `*current_state_parts` and returns its (possibly unnormalized) log-density
      under the target distribution.
    independent_chain_ndims: Scalar `int` `Tensor` representing the number of
      leftmost `Tensor` dimensions which index independent chains.
    step_sizes: Python `list` of `Tensor`s representing the step size for the
      leapfrog integrator; one per state part. Must broadcast with the shape
      of `current_state_parts`. Larger step sizes lead to faster progress,
      but too-large step sizes make rejection exponentially more likely. When
      possible, it's often helpful to match per-variable step sizes to the
      standard deviations of the target distribution in each variable.
    current_momentum_parts: Python `list` of `Tensor`s containing the
      value(s) of the momentum variable(s) to update; one per state part.
    current_state_parts: Python `list` of `Tensor`s representing the current
      state(s) of the Markov chain(s). The first `independent_chain_ndims` of
      the `Tensor`(s) index different chains.
    current_target_log_prob: `Tensor` representing the value of
      `target_log_prob_fn(*current_state_parts)`. The only reason to specify
      this argument is to reduce TF graph size.
    current_target_log_prob_grad_parts: Python list of `Tensor`s representing
      gradient of `target_log_prob_fn(*current_state_parts`) wrt
      `current_state_parts`. Must have same shape as `current_state_parts`. The
      only reason to specify this argument is to reduce TF graph size.
    state_gradients_are_stopped: Python `bool` indicating that the proposed new
      state be run through `tf.stop_gradient`. This is particularly useful when
      combining optimization over samples from the HMC chain.
      Default value: `False` (i.e., do not apply `stop_gradient`).
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'hmc_leapfrog_integrator_one_step').

  Returns:
    proposed_momentum_parts: Updated value of the momentum.
    proposed_state_parts: Tensor or Python list of `Tensor`s representing the
      state(s) of the Markov chain(s) at each result step. Has same shape as
      input `current_state_parts`.
    proposed_target_log_prob: `Tensor` representing the value of
      `target_log_prob_fn` at `next_state`.
    proposed_target_log_prob_grad_parts: Gradient of `proposed_target_log_prob`
      wrt `next_state`.

  Raises:
    ValueError: if `len(current_momentum_parts) != len(current_state_parts)`.
    ValueError: if `len(step_sizes) != len(current_state_parts)`.
    ValueError: if
      `len(current_target_log_prob_grad_parts) != len(current_state_parts)`.
    ValueError: if any gradient of `target_log_prob_fn` at the proposed state
      is `None`.
    TypeError: if the proposed target log-prob does not have a floating dtype.
  """
  # Validate lengths up front: `zip` below would otherwise silently truncate
  # to the shortest list and drop state parts.
  num_state_parts = len(current_state_parts)
  for arg_name, parts in [
      ('current_momentum_parts', current_momentum_parts),
      ('step_sizes', step_sizes),
      ('current_target_log_prob_grad_parts',
       current_target_log_prob_grad_parts),
  ]:
    if len(parts) != num_state_parts:
      raise ValueError(
          '`{}` must have the same length as `current_state_parts` '
          '({} != {}).'.format(arg_name, len(parts), num_state_parts))

  # Note on per-variable step sizes:
  #
  # Using per-variable step sizes is equivalent to using the same step
  # size for all variables and adding a diagonal mass matrix in the
  # kinetic energy term of the Hamiltonian being integrated. This is
  # hinted at by Neal (2011) but not derived in detail there.
  #
  # Let x and v be position and momentum variables respectively.
  # Let g(x) be the gradient of `target_log_prob_fn(x)`.
  # Let S be a diagonal matrix of per-variable step sizes.
  # Let the Hamiltonian H(x, v) = -target_log_prob_fn(x) + 0.5 * ||v||**2.
  #
  # Using per-variable step sizes gives the updates
  # v'  = v  + 0.5 * matmul(S, g(x))
  # x'' = x  + matmul(S, v')
  # v'' = v' + 0.5 * matmul(S, g(x''))
  #
  # Let u = matmul(inv(S), v).
  # Multiplying v by inv(S) in the updates above gives the transformed dynamics
  # u'  = matmul(inv(S), v')  = matmul(inv(S), v) + 0.5 * g(x)
  #                           = u + 0.5 * g(x)
  # x'' = x + matmul(S, v') = x + matmul(S**2, u')
  # u'' = matmul(inv(S), v'') = matmul(inv(S), v') + 0.5 * g(x'')
  #                           = u' + 0.5 * g(x'')
  #
  # These are exactly the leapfrog updates for the Hamiltonian
  # H'(x, u) = -target_log_prob_fn(x) + 0.5 * u^T S**2 u
  #          = -target_log_prob_fn(x) + 0.5 * ||v||**2 = H(x, v).
  #
  # To summarize:
  #
  # * Using per-variable step sizes implicitly simulates the dynamics
  #   of the Hamiltonian H' (which are energy-conserving in H'). We
  #   keep track of v instead of u, but the underlying dynamics are
  #   the same if we transform back.
  # * The value of the Hamiltonian H'(x, u) is the same as the value
  #   of the original Hamiltonian H(x, v) after we transform back from
  #   u to v.
  # * Sampling v ~ N(0, I) is equivalent to sampling u ~ N(0, S**-2).
  #
  # So using per-variable step sizes in HMC will give results that are
  # exactly identical to explicitly using a diagonal mass matrix.

  with tf.name_scope(
      name, 'hmc_leapfrog_integrator_one_step',
      [independent_chain_ndims, step_sizes,
       current_momentum_parts, current_state_parts,
       current_target_log_prob, current_target_log_prob_grad_parts]):

    # Step 1: Update momentum.
    proposed_momentum_parts = [
        v + 0.5 * tf.cast(eps, v.dtype) * g
        for v, eps, g
        in zip(current_momentum_parts,
               step_sizes,
               current_target_log_prob_grad_parts)]

    # Step 2: Update state.
    proposed_state_parts = [
        x + tf.cast(eps, v.dtype) * v
        for x, eps, v
        in zip(current_state_parts,
               step_sizes,
               proposed_momentum_parts)]

    if state_gradients_are_stopped:
      proposed_state_parts = [tf.stop_gradient(x) for x in proposed_state_parts]

    # Step 3a: Re-evaluate target-log-prob (and grad) at proposed state.
    [
        proposed_target_log_prob,
        proposed_target_log_prob_grad_parts,
    ] = mcmc_util.maybe_call_fn_and_grads(
        target_log_prob_fn,
        proposed_state_parts)

    if not proposed_target_log_prob.dtype.is_floating:
      raise TypeError('`target_log_prob_fn` must produce a `Tensor` '
                      'with `float` `dtype`.')

    if any(g is None for g in proposed_target_log_prob_grad_parts):
      raise ValueError(
          'Encountered `None` gradient. Does your target `target_log_prob_fn` '
          'access all `tf.Variable`s via `tf.get_variable`?\n'
          ' current_state_parts: {}\n'
          ' proposed_state_parts: {}\n'
          ' proposed_target_log_prob_grad_parts: {}'.format(
              current_state_parts,
              proposed_state_parts,
              proposed_target_log_prob_grad_parts))

    # Step 3b: Update momentum (again).
    proposed_momentum_parts = [
        v + 0.5 * tf.cast(eps, v.dtype) * g
        for v, eps, g
        in zip(proposed_momentum_parts,
               step_sizes,
               proposed_target_log_prob_grad_parts)]

    return [
        proposed_momentum_parts,
        proposed_state_parts,
        proposed_target_log_prob,
        proposed_target_log_prob_grad_parts,
    ]
def _leapfrog_integrator(current_momentums,
                         target_log_prob_fn,
                         current_state_parts,
                         step_sizes,
                         num_leapfrog_steps,
                         current_target_log_prob=None,
                         current_grads_target_log_prob=None,
                         name=None):
  """Applies `num_leapfrog_steps` of the leapfrog integrator.

  Assumes a simple quadratic kinetic energy function: `0.5 ||momentum||**2`.

  #### Examples:

  ##### Simple quadratic potential.

  ```python
  import matplotlib.pyplot as plt
  %matplotlib inline
  import numpy as np
  import tensorflow as tf
  from tensorflow_probability.python.mcmc.hmc import _leapfrog_integrator
  tfd = tf.contrib.distributions

  dims = 10
  num_iter = int(1e3)
  dtype = np.float32

  position = tf.placeholder(np.float32)
  momentum = tf.placeholder(np.float32)

  target_log_prob_fn = tfd.MultivariateNormalDiag(
      loc=tf.zeros(dims, dtype)).log_prob

  current_target_log_prob = target_log_prob_fn(position)
  current_grads_target_log_prob = tf.gradients(
      current_target_log_prob, position)

  [
      next_momentums,
      next_positions,
  ] = _leapfrog_integrator(
      current_momentums=[momentum],
      target_log_prob_fn=target_log_prob_fn,
      current_state_parts=[position],
      step_sizes=[0.1],
      num_leapfrog_steps=3,
      current_target_log_prob=current_target_log_prob,
      current_grads_target_log_prob=current_grads_target_log_prob)[:2]

  momentum_ = np.random.randn(dims).astype(dtype)
  position_ = np.random.randn(dims).astype(dtype)

  positions = np.zeros([num_iter, dims], dtype)
  with tf.Session() as sess:
    for i in range(num_iter):
      # Unpack in the same order as the fetch list.
      momentum_, position_ = sess.run(
          [next_momentums[0], next_positions[0]],
          feed_dict={position: position_, momentum: momentum_})
      positions[i] = position_

  plt.plot(positions[:, 0]);  # Sinusoidal.
  ```

  Args:
    current_momentums: Python `list` of `Tensor`s containing the value(s) of
      the momentum variable(s) to update. Must be in one-to-one
      correspondence with `current_state_parts`.
    target_log_prob_fn: Python callable which takes an argument like
      `*current_state_parts` and returns its (possibly unnormalized) log-density
      under the target distribution.
    current_state_parts: Python `list` of `Tensor`s representing the current
      state(s) of the Markov chain(s). The first `independent_chain_ndims` of
      the `Tensor`(s) index different chains.
    step_sizes: Python `list` of `Tensor`s representing the step size for the
      leapfrog integrator. Must broadcast with the shape of
      `current_state_parts`. Larger step sizes lead to faster progress, but
      too-large step sizes make rejection exponentially more likely. When
      possible, it's often helpful to match per-variable step sizes to the
      standard deviations of the target distribution in each variable.
    num_leapfrog_steps: Integer number of steps to run the leapfrog integrator
      for. Total progress per HMC step is roughly proportional to
      `step_size * num_leapfrog_steps`.
    current_target_log_prob: Optional `Tensor` representing the value of
      `target_log_prob_fn(*current_state_parts)`. When `None`, it is obtained
      via `mcmc_util.maybe_call_fn_and_grads`. The only reason to specify this
      argument is to reduce TF graph size.
      Default value: `None`.
    current_grads_target_log_prob: Optional Python list of `Tensor`s
      representing gradient of `target_log_prob_fn(*current_state_parts)` wrt
      `current_state_parts`. Must have same shape as `current_state_parts`.
      When `None`, it is obtained via `mcmc_util.maybe_call_fn_and_grads`. The
      only reason to specify this argument is to reduce TF graph size.
      Default value: `None`.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'hmc_leapfrog_integrator').

  Returns:
    proposed_momentums: Updated value of the momentum.
    proposed_state_parts: Tensor or Python list of `Tensor`s representing the
      state(s) of the Markov chain(s) at each result step. Has same shape as
      input `current_state_parts`.
    proposed_target_log_prob: `Tensor` representing the value of
      `target_log_prob_fn` at `proposed_state_parts`.
    proposed_grads_target_log_prob: Gradient of `proposed_target_log_prob` wrt
      `proposed_state_parts`.

  Raises:
    ValueError: if `len(current_momentums) != len(current_state_parts)`
      (checked here).
    ValueError: if `len(state_parts) != len(step_sizes)` or
      `len(state_parts) != len(grads_target_log_prob)`.
      NOTE(review): these checks are not in this function; presumably they
      are raised by `_leapfrog_step` / `mcmc_util` — confirm.
    TypeError: if `not target_log_prob.dtype.is_floating` (presumably raised
      by a helper; not checked in this function).
  """
  # Note on per-variable step sizes:
  #
  # Using per-variable step sizes is equivalent to using the same step
  # size for all variables and adding a diagonal mass matrix in the
  # kinetic energy term of the Hamiltonian being integrated. This is
  # hinted at by Neal (2011) but not derived in detail there.
  #
  # Let x and v be position and momentum variables respectively.
  # Let g(x) be the gradient of `target_log_prob_fn(x)`.
  # Let S be a diagonal matrix of per-variable step sizes.
  # Let the Hamiltonian H(x, v) = -target_log_prob_fn(x) + 0.5 * ||v||**2.
  #
  # Using per-variable step sizes gives the updates
  # v'  = v  + 0.5 * matmul(S, g(x))
  # x'' = x  + matmul(S, v')
  # v'' = v' + 0.5 * matmul(S, g(x''))
  #
  # Let u = matmul(inv(S), v).
  # Multiplying v by inv(S) in the updates above gives the transformed dynamics
  # u'  = matmul(inv(S), v') = matmul(inv(S), v) + 0.5 * g(x)
  #                          = u + 0.5 * g(x)
  # x'' = x + matmul(S, v') = x + matmul(S**2, u')
  # u'' = matmul(inv(S), v'') = matmul(inv(S), v') + 0.5 * g(x'')
  #                           = u' + 0.5 * g(x'')
  #
  # These are exactly the leapfrog updates for the Hamiltonian
  # H'(x, u) = -target_log_prob_fn(x) + 0.5 * u^T S**2 u
  #          = -target_log_prob_fn(x) + 0.5 * ||v||**2 = H(x, v).
  #
  # To summarize:
  #
  # * Using per-variable step sizes implicitly simulates the dynamics
  #   of the Hamiltonian H' (which are energy-conserving in H'). We
  #   keep track of v instead of u, but the underlying dynamics are
  #   the same if we transform back.
  # * The value of the Hamiltonian H'(x, u) is the same as the value
  #   of the original Hamiltonian H(x, v) after we transform back from
  #   u to v.
  # * Sampling v ~ N(0, I) is equivalent to sampling u ~ N(0, S**-2).
  #
  # So using per-variable step sizes in HMC will give results that are
  # exactly identical to explicitly using a diagonal mass matrix.

  def _loop_body(step,
                 current_momentums,
                 current_state_parts,
                 ignore_current_target_log_prob,  # pylint: disable=unused-argument
                 current_grads_target_log_prob):
    # The incoming log-prob is not passed to `_leapfrog_step`; it only rides
    # along in `loop_vars` so the loop state keeps a fixed structure, and a
    # fresh value comes back as one of `_leapfrog_step`'s outputs.
    return [step + 1] + list(_leapfrog_step(current_momentums,
                                            target_log_prob_fn,
                                            current_state_parts,
                                            step_sizes,
                                            current_grads_target_log_prob))

  with tf.name_scope(name, 'hmc_leapfrog_integrator', [
      current_momentums, current_state_parts, step_sizes, num_leapfrog_steps,
      current_target_log_prob, current_grads_target_log_prob]):
    if len(current_momentums) != len(current_state_parts):
      raise ValueError(
          '`momentums` must be in one-to-one correspondence '
          'with `state_parts`')
    num_leapfrog_steps = tf.convert_to_tensor(num_leapfrog_steps,
                                              name='num_leapfrog_steps')
    # Fills in the log-prob and/or its gradients when the caller did not
    # supply them (both arguments default to `None`).
    [
        current_target_log_prob,
        current_grads_target_log_prob,
    ] = mcmc_util.maybe_call_fn_and_grads(target_log_prob_fn,
                                          current_state_parts,
                                          current_target_log_prob,
                                          current_grads_target_log_prob)
    return tf.while_loop(
        cond=lambda iter_, *args: iter_ < num_leapfrog_steps,
        body=_loop_body,
        loop_vars=[
            np.int32(0),  # iter_
            current_momentums,
            current_state_parts,
            current_target_log_prob,
            current_grads_target_log_prob,
        ],
        back_prop=False)[1:]  # Lop-off "iter_".
def testGradientWorksDespiteBijectorCaching(self):
  """Checks that a LogNormal log-prob yields both a value and a gradient.

  The LogNormal is constructed inside the callable, so every invocation of
  the callable builds its own distribution.
  """
  point = tf.constant(2.)

  def log_prob_fn(value):
    return tfd.LogNormal(loc=0., scale=1.).log_prob(value)

  result, gradients = maybe_call_fn_and_grads(log_prob_fn, point)
  self.assertAllEqual(False, result is None)
  # Exactly one gradient is expected (one state part), and it must exist.
  self.assertAllEqual([False], [grad is None for grad in gradients])
def testGradientWorksDespiteBijectorCaching(self):
  """Checks value and gradient for the log-prob of a shared LogNormal.

  Unlike a per-call construction, the distribution here is built once and
  reused by the callable across invocations.
  """
  dist = tfd.LogNormal(loc=0., scale=1.)
  point = tf.constant(2.)

  def log_prob_fn(value):
    return dist.log_prob(value)

  result, gradients = maybe_call_fn_and_grads(log_prob_fn, point)
  self.assertAllEqual(False, result is None)
  # A single state part is passed, so exactly one non-None gradient.
  self.assertAllEqual([False], [grad is None for grad in gradients])