Ejemplo n.º 1
0
def _prepare_args(target_log_prob_fn,
                  state,
                  step_size,
                  target_log_prob=None,
                  grads_target_log_prob=None,
                  maybe_expand=False):
  """Normalizes kernel inputs into parallel lists of `Tensor`s.

  Args:
    target_log_prob_fn: Callable forwarded to
      `mcmc_util.maybe_call_fn_and_grads` together with the state parts.
    state: A single state or a list-like of state parts.
    step_size: A single step size or a list-like of step sizes. A single
      step size is repeated once per state part.
    target_log_prob: Optional value forwarded to
      `mcmc_util.maybe_call_fn_and_grads` (presumably a cached result that
      avoids re-evaluating `target_log_prob_fn` -- confirm against helper).
    grads_target_log_prob: Optional value forwarded alongside
      `target_log_prob`.
    maybe_expand: If true, the state and step sizes are always returned as
      lists, even when `state` was not list-like.

  Returns:
    A list `[state_parts, step_sizes, target_log_prob,
    grads_target_log_prob]`. The first two entries are lists when
    `maybe_expand` is true or `state` is list-like; otherwise each is the
    single first element.

  Raises:
    ValueError: if the number of step sizes is neither one nor equal to the
      number of state parts.
  """
  state_is_list_like = mcmc_util.is_list_like(state)
  raw_state_parts = list(state) if state_is_list_like else [state]
  state_parts = []
  for part in raw_state_parts:
    state_parts.append(tf.convert_to_tensor(part, name='current_state'))

  [
      target_log_prob,
      grads_target_log_prob,
  ] = mcmc_util.maybe_call_fn_and_grads(
      target_log_prob_fn, state_parts, target_log_prob, grads_target_log_prob)

  if mcmc_util.is_list_like(step_size):
    raw_step_sizes = list(step_size)
  else:
    raw_step_sizes = [step_size]
  # Step sizes take the dtype of the target log-prob.
  step_sizes = [
      tf.convert_to_tensor(
          raw, name='step_size', dtype=target_log_prob.dtype)
      for raw in raw_step_sizes]
  if len(step_sizes) == 1:
    # Share the single step size across every state part.
    step_sizes = step_sizes * len(state_parts)
  if len(step_sizes) != len(state_parts):
    raise ValueError('There should be exactly one `step_size` or it should '
                     'have same length as `current_state`.')

  if maybe_expand or state_is_list_like:
    return [state_parts, step_sizes, target_log_prob, grads_target_log_prob]
  # Non-list input and no forced expansion: unwrap the singleton lists.
  return [
      state_parts[0],
      step_sizes[0],
      target_log_prob,
      grads_target_log_prob,
  ]
Ejemplo n.º 2
0
def _prepare_args(target_log_prob_fn,
                  volatility_fn,
                  state,
                  step_size,
                  target_log_prob=None,
                  grads_target_log_prob=None,
                  volatility=None,
                  grads_volatility_fn=None):
    """Normalizes kernel inputs into lists and derives the drift term.

    Args:
      target_log_prob_fn: Callable forwarded to
        `mcmc_util.maybe_call_fn_and_grads` together with the state parts.
      volatility_fn: Callable forwarded to
        `_maybe_call_volatility_fn_and_grads`.
      state: A single state or a list-like of state parts.
      step_size: A single step size or a list-like of step sizes. A single
        step size is repeated once per state part.
      target_log_prob: Optional value forwarded to
        `mcmc_util.maybe_call_fn_and_grads`.
      grads_target_log_prob: Optional value forwarded alongside
        `target_log_prob`.
      volatility: Optional value forwarded to
        `_maybe_call_volatility_fn_and_grads`.
      grads_volatility_fn: Optional value forwarded alongside `volatility`.

    Returns:
      A list `[state_parts, step_sizes, target_log_prob,
      grads_target_log_prob, volatility_parts, grads_volatility,
      diffusion_drift_parts]`.

    Raises:
      ValueError: if the number of step sizes is neither one nor equal to
        the number of state parts.
    """
    if mcmc_util.is_list_like(state):
        state_parts = list(state)
    else:
        state_parts = [state]
    state_parts = [
        tf.convert_to_tensor(part, name='current_state')
        for part in state_parts
    ]

    target_log_prob, grads_target_log_prob = (
        mcmc_util.maybe_call_fn_and_grads(target_log_prob_fn, state_parts,
                                          target_log_prob,
                                          grads_target_log_prob))
    volatility_parts, grads_volatility = _maybe_call_volatility_fn_and_grads(
        volatility_fn, state_parts, volatility, grads_volatility_fn)

    if mcmc_util.is_list_like(step_size):
        raw_step_sizes = list(step_size)
    else:
        raw_step_sizes = [step_size]
    # Step sizes take the dtype of the target log-prob.
    step_sizes = [
        tf.convert_to_tensor(
            raw, name='step_size', dtype=target_log_prob.dtype)
        for raw in raw_step_sizes
    ]
    if len(step_sizes) == 1:
        # Share the single step size across every state part.
        step_sizes = step_sizes * len(state_parts)
    if len(step_sizes) != len(state_parts):
        raise ValueError(
            'There should be exactly one `step_size` or it should '
            'have same length as `current_state`.')

    # The shape of 'volatility_parts' needs to have the number of chains as a
    # leading dimension. For determinism we broadcast 'volatility_parts' to
    # the shape of `state_parts` since each dimension of `state_parts` could
    # have a different volatility value.
    volatility_parts = _maybe_broadcast_volatility(volatility_parts,
                                                   state_parts)

    diffusion_drift_parts = _get_drift(step_sizes, volatility_parts,
                                       grads_volatility,
                                       grads_target_log_prob)

    prepared = [state_parts, step_sizes]
    prepared += [target_log_prob, grads_target_log_prob]
    prepared += [volatility_parts, grads_volatility, diffusion_drift_parts]
    return prepared
Ejemplo n.º 3
0
  def testGradientComputesCorrectly(self):
    """Checks the value and gradients of `x**2 + y**2` at `(3, 3)`."""
    dtype = np.float32

    def sum_of_squares(x, y):
      return x**2 + y**2

    # The function inputs are handed over as a list of tensors.
    inputs = [tf.convert_to_tensor(dtype(3)) for _ in range(2)]
    value, gradients = maybe_call_fn_and_grads(sum_of_squares, inputs)
    value_, gradients_ = self.evaluate([value, gradients])

    # 3**2 + 3**2 == 18, and each partial derivative is 2 * 3 == 6.
    self.assertNear(18., value_, err=1e-5)
    expected_gradient = dtype(6)
    for gradient_ in gradients_:
      self.assertAllClose(gradient_, expected_gradient, atol=0., rtol=1e-5)
Ejemplo n.º 4
0
    def testNoGradientsNiceError(self):
        """Checks that a `None` gradient raises a descriptive `ValueError`.

        The second argument is wrapped in `tf.stop_gradient`. In graph mode
        the error message is additionally expected to mention `arg1`.
        """
        dtype = np.float32

        def fn(x, y):
            return x**2 + tf.stop_gradient(y)**2

        # The function inputs are handed over as a list of named tensors.
        fn_args = [
            tf.convert_to_tensor(dtype(3), name='arg{}'.format(index))
            for index in range(2)
        ]

        if tf.executing_eagerly():
            expected_message = (
                'Encountered `None`.*\n.*fn_arg_list.*\n.*None')
        else:
            expected_message = (
                'Encountered `None`.*\n.*fn_arg_list.*arg1.*\n.*None')

        with self.assertRaisesRegexp(ValueError, expected_message):
            maybe_call_fn_and_grads(fn, fn_args)
Ejemplo n.º 5
0
    def testGradientComputesCorrectly(self):
        """Checks the value and gradients of `x**2 + y**2` at `(3, 3)`."""
        dtype = np.float32

        def sum_of_squares(x, y):
            return x**2 + y**2

        # The function inputs are handed over as a list of tensors.
        inputs = [tf.convert_to_tensor(dtype(3)) for _ in range(2)]
        value, gradients = maybe_call_fn_and_grads(sum_of_squares, inputs)
        value_, gradients_ = self.evaluate([value, gradients])

        # 3**2 + 3**2 == 18, and each partial derivative is 2 * 3 == 6.
        self.assertNear(18., value_, err=1e-5)
        expected_gradient = dtype(6)
        for gradient_ in gradients_:
            self.assertAllClose(
                gradient_, expected_gradient, atol=0., rtol=1e-5)
Ejemplo n.º 6
0
def _leapfrog_step(current_momentums,
                   target_log_prob_fn,
                   current_state_parts,
                   step_sizes,
                   current_grads_target_log_prob,
                   name=None):
    """Applies one step of the leapfrog integrator.

    The step is a half momentum update, a full state update, and a second
    half momentum update using the gradients at the proposed state.

    Args:
      current_momentums: List of momentum values, one per state part.
      target_log_prob_fn: Callable evaluated (with gradients) at the proposed
        state via `mcmc_util.maybe_call_fn_and_grads`.
      current_state_parts: List of current state parts.
      step_sizes: List of step sizes, one per state part.
      current_grads_target_log_prob: List of gradients of the target
        log-prob at the current state, one per state part.
      name: Optional name passed to `tf.name_scope`.

    Returns:
      A list `[proposed_momentums, proposed_state_parts,
      proposed_target_log_prob, proposed_grads_target_log_prob]`.

    Raises:
      TypeError: if the proposed target log-prob does not have a floating
        dtype.
      ValueError: if any gradient at the proposed state is `None`.
    """
    scope_values = [
        current_momentums, current_state_parts, step_sizes,
        current_grads_target_log_prob
    ]

    def half_momentum_update(momentums, grads):
        # Advance each momentum by half a step along its gradient.
        return [
            momentum + 0.5 * eps * grad
            for momentum, eps, grad in zip(momentums, step_sizes, grads)
        ]

    with tf.name_scope(name, '_leapfrog_step', scope_values):
        proposed_momentums = half_momentum_update(
            current_momentums, current_grads_target_log_prob)

        proposed_state_parts = []
        for part, eps, momentum in zip(current_state_parts, step_sizes,
                                       proposed_momentums):
            proposed_state_parts.append(part + eps * momentum)

        proposed_target_log_prob, proposed_grads_target_log_prob = (
            mcmc_util.maybe_call_fn_and_grads(target_log_prob_fn,
                                              proposed_state_parts))

        if not proposed_target_log_prob.dtype.is_floating:
            raise TypeError('`target_log_prob_fn` must produce a `Tensor` '
                            'with `float` `dtype`.')

        for grad in proposed_grads_target_log_prob:
            if grad is not None:
                continue
            raise ValueError(
                'Encountered `None` gradient. Does your target `target_log_prob_fn` '
                'access all `tf.Variable`s via `tf.get_variable`?\n'
                '  current_state_parts: {}\n'
                '  proposed_state_parts: {}\n'
                '  proposed_grads_target_log_prob: {}'.format(
                    current_state_parts, proposed_state_parts,
                    proposed_grads_target_log_prob))

        proposed_momentums = half_momentum_update(
            proposed_momentums, proposed_grads_target_log_prob)

        return [
            proposed_momentums,
            proposed_state_parts,
            proposed_target_log_prob,
            proposed_grads_target_log_prob,
        ]
Ejemplo n.º 7
0
 def bootstrap_results(self, init_state):
   """Builds the initial kernel results for `init_state`.

   Args:
     init_state: A single state or a list-like of state parts.

   Returns:
     An `UncalibratedHamiltonianMonteCarloKernelResults` holding the target
     log-prob and its gradients at `init_state`, with a zero
     `log_acceptance_correction` shaped like the target log-prob.
   """
   scope_name = mcmc_util.make_name(self.name, 'hmc', 'bootstrap_results')
   with tf.name_scope(name=scope_name, values=[init_state]):
     if mcmc_util.is_list_like(init_state):
       state_parts = init_state
     else:
       state_parts = [init_state]
     state_tensors = [tf.convert_to_tensor(part) for part in state_parts]

     log_prob, log_prob_grads = mcmc_util.maybe_call_fn_and_grads(
         self.target_log_prob_fn, state_tensors)

     return UncalibratedHamiltonianMonteCarloKernelResults(
         log_acceptance_correction=tf.zeros_like(log_prob),
         target_log_prob=log_prob,
         grads_target_log_prob=log_prob_grads,
     )
Ejemplo n.º 8
0
 def bootstrap_results(self, init_state):
   """Builds the initial kernel results for `init_state`.

   Args:
     init_state: A single state or a list-like of state parts.

   Returns:
     An `UncalibratedHamiltonianMonteCarloKernelResults` holding the target
     log-prob and its gradients at `init_state`, with a zero
     `log_acceptance_correction` shaped like the target log-prob.
   """
   scope_name = mcmc_util.make_name(self.name, 'hmc', 'bootstrap_results')
   with tf.name_scope(name=scope_name, values=[init_state]):
     if mcmc_util.is_list_like(init_state):
       state_parts = init_state
     else:
       state_parts = [init_state]

     # When state gradients are stopped, each part goes through
     # `tf.stop_gradient` instead of a plain tensor conversion.
     if self.state_gradients_are_stopped:
       to_tensor = tf.stop_gradient
     else:
       to_tensor = tf.convert_to_tensor
     state_tensors = [to_tensor(part) for part in state_parts]

     log_prob, log_prob_grads = mcmc_util.maybe_call_fn_and_grads(
         self.target_log_prob_fn, state_tensors)

     return UncalibratedHamiltonianMonteCarloKernelResults(
         log_acceptance_correction=tf.zeros_like(log_prob),
         target_log_prob=log_prob,
         grads_target_log_prob=log_prob_grads,
     )
Ejemplo n.º 9
0
 def bootstrap_results(self, init_state):
     """Builds the initial kernel results for `init_state`.

     Args:
       init_state: A single state or a list-like of state parts.

     Returns:
       An `UncalibratedHamiltonianMonteCarloKernelResults` holding the target
       log-prob and its gradients at `init_state`, with a zero
       `log_acceptance_correction`. When `self._store_parameters_in_results`
       is true, `step_size` (converted to the target log-prob's dtype) and
       `num_leapfrog_steps` (converted to `tf.int64`) are stored as tensors;
       otherwise both fields are empty lists.
     """
     scope_name = mcmc_util.make_name(self.name, 'hmc', 'bootstrap_results')
     with tf.name_scope(name=scope_name, values=[init_state]):
         if mcmc_util.is_list_like(init_state):
             state_parts = init_state
         else:
             state_parts = [init_state]

         if self.state_gradients_are_stopped:
             state_tensors = [tf.stop_gradient(part) for part in state_parts]
         else:
             state_tensors = [
                 tf.convert_to_tensor(value=part) for part in state_parts
             ]

         log_prob, log_prob_grads = mcmc_util.maybe_call_fn_and_grads(
             self.target_log_prob_fn, state_tensors)
         zero_correction = tf.zeros_like(log_prob)

         if not self._store_parameters_in_results:
             return UncalibratedHamiltonianMonteCarloKernelResults(
                 log_acceptance_correction=zero_correction,
                 target_log_prob=log_prob,
                 grads_target_log_prob=log_prob_grads,
                 step_size=[],
                 num_leapfrog_steps=[])

         def as_step_size_tensor(value):
             # Step sizes take the dtype of the target log-prob.
             return tf.convert_to_tensor(
                 value=value, dtype=log_prob.dtype, name='step_size')

         step_size_tensors = tf.nest.map_structure(as_step_size_tensor,
                                                   self.step_size)
         num_leapfrog_steps_tensor = tf.convert_to_tensor(
             value=self.num_leapfrog_steps,
             dtype=tf.int64,
             name='num_leapfrog_steps')
         return UncalibratedHamiltonianMonteCarloKernelResults(
             log_acceptance_correction=zero_correction,
             target_log_prob=log_prob,
             grads_target_log_prob=log_prob_grads,
             step_size=step_size_tensors,
             num_leapfrog_steps=num_leapfrog_steps_tensor)
Ejemplo n.º 10
0
  def _integrator_conserves_energy(self, x, independent_chain_ndims):
    """Asserts that 1000 leapfrog steps approximately conserve energy.

    Energy is `-log_prob + 0.5 * sum(momentum**2)`, with the sum taken over
    the event dimensions. The energies before and after integration must
    agree to a relative tolerance of 2%.

    Args:
      x: Initial state `Tensor`.
      independent_chain_ndims: Number of leading dimensions of `x` that are
        excluded from the event dimensions.
    """
    event_dims = tf.range(independent_chain_ndims, tf.rank(x))

    def log_prob_fn(state):
      return self._log_gamma_log_prob(state, event_dims)

    def kinetic_energy(momentum):
      return 0.5 * tf.reduce_sum(input_tensor=momentum**2., axis=event_dims)

    initial_momentum = tf.random.normal(tf.shape(input=x))
    initial_log_prob, initial_grads = maybe_call_fn_and_grads(log_prob_fn, x)
    old_energy = -initial_log_prob + kinetic_energy(initial_momentum)

    # The step size shrinks with the number of event elements.
    static_shape = self.evaluate(x).shape
    event_size = np.prod(static_shape[independent_chain_ndims:])
    step_size = tf.constant(0.1 / event_size, x.dtype)
    hmc_lf_steps = tf.constant(1000, np.int32)

    def leapfrog_one_step(*loop_args):
      return _leapfrog_integrator_one_step(
          log_prob_fn,
          independent_chain_ndims,
          [step_size],
          *loop_args)

    # The loop condition is always true; `maximum_iterations` bounds it.
    loop_vars = [
        [initial_momentum],  # current_momentum_parts
        [x],                 # current_state_parts
        initial_log_prob,    # current_target_log_prob
        initial_grads,       # current_target_log_prob_grad_parts
    ]
    [[final_momentum], _, final_log_prob, _] = tf.while_loop(
        cond=lambda *loop_args: True,
        body=leapfrog_one_step,
        loop_vars=loop_vars,
        maximum_iterations=hmc_lf_steps)

    new_energy = -final_log_prob + kinetic_energy(final_momentum)

    old_energy_, new_energy_ = self.evaluate([old_energy, new_energy])
    relative_change = (1. - new_energy_ / old_energy_).mean()
    tf.compat.v1.logging.vlog(
        1, 'average energy relative change: {}'.format(relative_change))
    self.assertAllClose(old_energy_, new_energy_, atol=0., rtol=0.02)
Ejemplo n.º 11
0
  def _integrator_conserves_energy(self, x, independent_chain_ndims):
    """Asserts that 1000 leapfrog steps approximately conserve energy.

    Energy is `-log_prob + 0.5 * sum(momentum**2)`, with the sum taken over
    the event dimensions. The energies before and after integration must
    agree to a relative tolerance of 2%.

    Args:
      x: Initial state `Tensor`.
      independent_chain_ndims: Number of leading dimensions of `x` that are
        excluded from the event dimensions.
    """
    event_dims = tf.range(independent_chain_ndims, tf.rank(x))

    def log_prob_fn(state):
      return self._log_gamma_log_prob(state, event_dims)

    def kinetic_energy(momentum):
      return 0.5 * tf.reduce_sum(momentum**2., axis=event_dims)

    initial_momentum = tf.random_normal(tf.shape(x))
    initial_log_prob, initial_grads = maybe_call_fn_and_grads(log_prob_fn, x)
    old_energy = -initial_log_prob + kinetic_energy(initial_momentum)

    # The step size shrinks with the number of event elements.
    static_shape = self.evaluate(x).shape
    event_size = np.prod(static_shape[independent_chain_ndims:])
    step_size = tf.constant(0.1 / event_size, x.dtype)
    hmc_lf_steps = tf.constant(1000, np.int32)

    def leapfrog_one_step(*loop_args):
      return _leapfrog_integrator_one_step(
          log_prob_fn,
          independent_chain_ndims,
          [step_size],
          *loop_args)

    # The loop condition is always true; `maximum_iterations` bounds it.
    loop_vars = [
        [initial_momentum],  # current_momentum_parts
        [x],                 # current_state_parts
        initial_log_prob,    # current_target_log_prob
        initial_grads,       # current_target_log_prob_grad_parts
    ]
    [[final_momentum], _, final_log_prob, _] = tf.while_loop(
        cond=lambda *loop_args: True,
        body=leapfrog_one_step,
        loop_vars=loop_vars,
        maximum_iterations=hmc_lf_steps)

    new_energy = -final_log_prob + kinetic_energy(final_momentum)

    old_energy_, new_energy_ = self.evaluate([old_energy, new_energy])
    relative_change = (1. - new_energy_ / old_energy_).mean()
    tf.logging.vlog(
        1, 'average energy relative change: {}'.format(relative_change))
    self.assertAllClose(old_energy_, new_energy_, atol=0., rtol=0.02)
Ejemplo n.º 12
0
def _prepare_args(target_log_prob_fn,
                  state,
                  step_size,
                  target_log_prob=None,
                  grads_target_log_prob=None,
                  maybe_expand=False,
                  state_gradients_are_stopped=False):
  """Normalizes kernel inputs into parallel lists of `Tensor`s.

  Args:
    target_log_prob_fn: Callable forwarded to
      `mcmc_util.maybe_call_fn_and_grads` together with the state parts.
    state: A single state or a list-like of state parts.
    step_size: A single step size or a list-like of step sizes. A single
      step size is repeated once per state part.
    target_log_prob: Optional value forwarded to
      `mcmc_util.maybe_call_fn_and_grads` (presumably a cached result that
      avoids re-evaluating `target_log_prob_fn` -- confirm against helper).
    grads_target_log_prob: Optional value forwarded alongside
      `target_log_prob`.
    maybe_expand: If true, the state and step sizes are always returned as
      lists, even when `state` was not list-like.
    state_gradients_are_stopped: If true, every state part is passed through
      `tf.stop_gradient` before the target log-prob is evaluated.

  Returns:
    A list `[state_parts, step_sizes, target_log_prob,
    grads_target_log_prob]`. The first two entries are lists when
    `maybe_expand` is true or `state` is list-like; otherwise each is the
    single first element.

  Raises:
    ValueError: if the number of step sizes is neither one nor equal to the
      number of state parts.
  """
  state_is_list_like = mcmc_util.is_list_like(state)
  raw_state_parts = list(state) if state_is_list_like else [state]
  state_parts = []
  for part in raw_state_parts:
    state_parts.append(tf.convert_to_tensor(part, name='current_state'))
  if state_gradients_are_stopped:
    state_parts = [tf.stop_gradient(part) for part in state_parts]

  [
      target_log_prob,
      grads_target_log_prob,
  ] = mcmc_util.maybe_call_fn_and_grads(
      target_log_prob_fn, state_parts, target_log_prob, grads_target_log_prob)

  if mcmc_util.is_list_like(step_size):
    raw_step_sizes = list(step_size)
  else:
    raw_step_sizes = [step_size]
  # Step sizes take the dtype of the target log-prob.
  step_sizes = [
      tf.convert_to_tensor(
          raw, name='step_size', dtype=target_log_prob.dtype)
      for raw in raw_step_sizes]
  if len(step_sizes) == 1:
    # Share the single step size across every state part.
    step_sizes = step_sizes * len(state_parts)
  if len(step_sizes) != len(state_parts):
    raise ValueError('There should be exactly one `step_size` or it should '
                     'have same length as `current_state`.')

  if maybe_expand or state_is_list_like:
    return [state_parts, step_sizes, target_log_prob, grads_target_log_prob]
  # Non-list input and no forced expansion: unwrap the singleton lists.
  return [
      state_parts[0],
      step_sizes[0],
      target_log_prob,
      grads_target_log_prob,
  ]
Ejemplo n.º 13
0
def _maybe_call_volatility_fn_and_grads(volatility_fn,
                                        state,
                                        volatility_fn_results=None,
                                        grads_volatility_fn=None):
    """Helper which computes `volatility_fn` results and grads, if needed.

    Args:
      volatility_fn: Callable forwarded to
        `mcmc_util.maybe_call_fn_and_grads` together with the state parts.
      state: A single state or a list-like of state parts.
      volatility_fn_results: Optional value forwarded to
        `mcmc_util.maybe_call_fn_and_grads`.
      grads_volatility_fn: Optional gradients. When supplied (not `None`)
        they are not rescaled by this function.

    Returns:
      volatility_parts: List of volatility values, one per state part. A
        single volatility is repeated once per state part.
      grads_volatility_fn: The gradients. When they were not supplied by the
        caller, each entry is `2 * g * volatility` (the gradient of the
        squared volatility), or zeros shaped like the matching state part
        where the gradient `g` is `None`.

    Raises:
      ValueError: if the number of volatility parts is neither one nor equal
        to the number of state parts.
    """
    state_parts = list(state) if mcmc_util.is_list_like(state) else [state]
    # Record this before `grads_volatility_fn` is overwritten below.
    needs_volatility_fn_gradients = grads_volatility_fn is None
    [
        volatility_fn_results,
        grads_volatility_fn,
    ] = mcmc_util.maybe_call_fn_and_grads(volatility_fn,
                                          state_parts,
                                          volatility_fn_results,
                                          grads_volatility_fn,
                                          check_non_none_grads=False)

    # Convert `volatility_fn_results` to a list
    volatility_parts = (list(volatility_fn_results)
                        if mcmc_util.is_list_like(volatility_fn_results) else
                        [volatility_fn_results])

    if len(volatility_parts) == 1:
        volatility_parts *= len(state_parts)

    if len(volatility_parts) != len(state_parts):
        # The trailing space after "it" keeps the two literals from fusing
        # into "itshould".
        raise ValueError('There should be exactly one `volatility_parts` or it '
                         'should have same length as `current_state`.')

    # Compute gradient of `volatility_parts ** 2`
    if needs_volatility_fn_gradients:
        grads_volatility_fn = [
            2. * g * volatility if g is not None else tf.zeros_like(
                fn_arg, dtype=fn_arg.dtype.base_dtype) for g, volatility,
            fn_arg in zip(grads_volatility_fn, volatility_parts, state_parts)
        ]

    return volatility_parts, grads_volatility_fn
Ejemplo n.º 14
0
def _leapfrog_integrator_one_step(target_log_prob_fn,
                                  independent_chain_ndims,
                                  step_sizes,
                                  current_momentum_parts,
                                  current_state_parts,
                                  current_target_log_prob,
                                  current_target_log_prob_grad_parts,
                                  state_gradients_are_stopped=False,
                                  name=None):
    """Applies one step of the leapfrog integrator.

    Assumes a simple quadratic kinetic energy function:
    `0.5 ||momentum||**2`.

    #### Examples:

    ##### Simple quadratic potential.

    ```python
    import matplotlib.pyplot as plt
    %matplotlib inline
    import numpy as np
    import tensorflow as tf
    import tensorflow_probability as tfp
    from tensorflow_probability.python.mcmc.hmc import (
        _leapfrog_integrator_one_step)
    tfd = tfp.distributions

    dims = 10
    num_iter = int(1e3)
    dtype = np.float32

    position = tf.placeholder(np.float32)
    momentum = tf.placeholder(np.float32)

    target_log_prob_fn = tfd.MultivariateNormalDiag(
        loc=tf.zeros(dims, dtype)).log_prob

    def _leapfrog_one_step(*args):
      # Closure representing computation done during each leapfrog step.
      return _leapfrog_integrator_one_step(
          target_log_prob_fn=target_log_prob_fn,
          independent_chain_ndims=0,
          step_sizes=[0.1],
          current_momentum_parts=args[0],
          current_state_parts=args[1],
          current_target_log_prob=args[2],
          current_target_log_prob_grad_parts=args[3])

    # Do leapfrog integration.
    [
        [next_momentum],
        [next_position],
        next_target_log_prob,
        next_target_log_prob_grad_parts,
    ] = tf.while_loop(
        cond=lambda *args: True,
        body=_leapfrog_one_step,
        loop_vars=[
          [momentum],
          [position],
          target_log_prob_fn(position),
          tf.gradients(target_log_prob_fn(position), position),
        ],
        maximum_iterations=3)

    momentum_ = np.random.randn(dims).astype(dtype)
    position_ = np.random.randn(dims).astype(dtype)
    positions = np.zeros([num_iter, dims], dtype)

    with tf.Session() as sess:
      for i in range(num_iter):
        momentum_, position_ = sess.run(
            [next_momentum, next_position],
            feed_dict={position: position_, momentum: momentum_})
        positions[i] = position_

    plt.plot(positions[:, 0]);  # Sinusoidal.
    ```

    Args:
      target_log_prob_fn: Python callable which takes an argument like
        `*current_state_parts` and returns its (possibly unnormalized)
        log-density under the target distribution.
      independent_chain_ndims: Scalar `int` `Tensor` representing the number
        of leftmost `Tensor` dimensions which index independent chains. In
        this function it is only passed to the name scope.
      step_sizes: Python `list` of `Tensor`s representing the step size for
        the leapfrog integrator. Must broadcast with the shape of
        `current_state_parts`. Larger step sizes lead to faster progress, but
        too-large step sizes make rejection exponentially more likely. When
        possible, it's often helpful to match per-variable step sizes to the
        standard deviations of the target distribution in each variable.
      current_momentum_parts: Python `list` of `Tensor`s containing the
        value(s) of the momentum variable(s) to update.
      current_state_parts: Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s). The first
        `independent_chain_ndims` of the `Tensor`(s) index different chains.
      current_target_log_prob: `Tensor` representing the value of
        `target_log_prob_fn(*current_state_parts)`. In this function it is
        only passed to the name scope.
      current_target_log_prob_grad_parts: Python list of `Tensor`s
        representing gradient of `target_log_prob_fn(*current_state_parts)`
        wrt `current_state_parts`. Must have same shape as
        `current_state_parts`.
      state_gradients_are_stopped: Python `bool` indicating that the proposed
        new state be run through `tf.stop_gradient`. This is particularly
        useful when combining optimization over samples from the HMC chain.
        Default value: `False` (i.e., do not apply `stop_gradient`).
      name: Python `str` name prefixed to Ops created by this function.
        Default value: `None` (i.e., 'hmc_leapfrog_integrator_one_step').

    Returns:
      proposed_momentum_parts: Updated value of the momentum.
      proposed_state_parts: Python list of `Tensor`s representing the
        proposed state(s) of the Markov chain(s).
      proposed_target_log_prob: `Tensor` representing the value of
        `target_log_prob_fn` at the proposed state.
      proposed_target_log_prob_grad_parts: Gradient of
        `proposed_target_log_prob` wrt the proposed state.

    Raises:
      TypeError: if `not proposed_target_log_prob.dtype.is_floating`.
      ValueError: if any gradient at the proposed state is `None`.
    """
    # Note on per-variable step sizes:
    #
    # Using per-variable step sizes is equivalent to using the same step
    # size for all variables and adding a diagonal mass matrix in the
    # kinetic energy term of the Hamiltonian being integrated. This is
    # hinted at by Neal (2011) but not derived in detail there.
    #
    # Let x and v be position and momentum variables respectively.
    # Let g(x) be the gradient of `target_log_prob_fn(x)`.
    # Let S be a diagonal matrix of per-variable step sizes.
    # Let the Hamiltonian H(x, v) = -target_log_prob_fn(x) + 0.5 * ||v||**2.
    #
    # Using per-variable step sizes gives the updates
    # v'  = v  + 0.5 * matmul(S, g(x))
    # x'' = x  + matmul(S, v')
    # v'' = v' + 0.5 * matmul(S, g(x''))
    #
    # Let u = matmul(inv(S), v).
    # Multiplying v by inv(S) in the updates above gives the transformed
    # dynamics
    # u'  = matmul(inv(S), v')  = matmul(inv(S), v) + 0.5 * g(x)
    #                           = u + 0.5 * g(x)
    # x'' = x + matmul(S, v') = x + matmul(S**2, u')
    # u'' = matmul(inv(S), v'') = matmul(inv(S), v') + 0.5 * g(x'')
    #                           = u' + 0.5 * g(x'')
    #
    # These are exactly the leapfrog updates for the Hamiltonian
    # H'(x, u) = -target_log_prob_fn(x) + 0.5 * u^T S**2 u
    #          = -target_log_prob_fn(x) + 0.5 * ||v||**2 = H(x, v).
    #
    # To summarize:
    #
    # * Using per-variable step sizes implicitly simulates the dynamics
    #   of the Hamiltonian H' (which are energy-conserving in H'). We
    #   keep track of v instead of u, but the underlying dynamics are
    #   the same if we transform back.
    # * The value of the Hamiltonian H'(x, u) is the same as the value
    #   of the original Hamiltonian H(x, v) after we transform back from
    #   u to v.
    # * Sampling v ~ N(0, I) is equivalent to sampling u ~ N(0, S**-2).
    #
    # So using per-variable step sizes in HMC will give results that are
    # exactly identical to explicitly using a diagonal mass matrix.

    scope_values = [
        independent_chain_ndims, step_sizes, current_momentum_parts,
        current_state_parts, current_target_log_prob,
        current_target_log_prob_grad_parts
    ]

    def half_momentum_update(momentum_parts, grad_parts):
        # Advance each momentum part by half a step along its gradient.
        return [
            momentum + 0.5 * step * grad
            for momentum, step, grad in zip(momentum_parts, step_sizes,
                                            grad_parts)
        ]

    with tf.name_scope(name, 'hmc_leapfrog_integrator_one_step',
                       scope_values):

        # Step 1: Update momentum.
        proposed_momentum_parts = half_momentum_update(
            current_momentum_parts, current_target_log_prob_grad_parts)

        # Step 2: Update state.
        proposed_state_parts = []
        for state_part, step, momentum in zip(current_state_parts,
                                              step_sizes,
                                              proposed_momentum_parts):
            proposed_state_parts.append(state_part + step * momentum)

        if state_gradients_are_stopped:
            proposed_state_parts = [
                tf.stop_gradient(state_part)
                for state_part in proposed_state_parts
            ]

        # Step 3a: Re-evaluate target-log-prob (and grad) at proposed state.
        proposed_target_log_prob, proposed_target_log_prob_grad_parts = (
            mcmc_util.maybe_call_fn_and_grads(target_log_prob_fn,
                                              proposed_state_parts))

        if not proposed_target_log_prob.dtype.is_floating:
            raise TypeError('`target_log_prob_fn` must produce a `Tensor` '
                            'with `float` `dtype`.')

        for grad in proposed_target_log_prob_grad_parts:
            if grad is not None:
                continue
            raise ValueError(
                'Encountered `None` gradient. Does your target `target_log_prob_fn` '
                'access all `tf.Variable`s via `tf.get_variable`?\n'
                '  current_state_parts: {}\n'
                '  proposed_state_parts: {}\n'
                '  proposed_target_log_prob_grad_parts: {}'.format(
                    current_state_parts, proposed_state_parts,
                    proposed_target_log_prob_grad_parts))

        # Step 3b: Update momentum (again).
        proposed_momentum_parts = half_momentum_update(
            proposed_momentum_parts, proposed_target_log_prob_grad_parts)

        return [
            proposed_momentum_parts,
            proposed_state_parts,
            proposed_target_log_prob,
            proposed_target_log_prob_grad_parts,
        ]
Ejemplo n.º 15
0
def _prepare_args(target_log_prob_fn,
                  volatility_fn,
                  state,
                  step_size,
                  target_log_prob=None,
                  grads_target_log_prob=None,
                  volatility=None,
                  grads_volatility_fn=None,
                  diffusion_drift=None,
                  parallel_iterations=10):
    """Helper which processes input args to meet list-like assumptions.

    Args:
      target_log_prob_fn: Callable forwarded to
        `mcmc_util.maybe_call_fn_and_grads` together with the state parts.
      volatility_fn: Callable forwarded to
        `_maybe_call_volatility_fn_and_grads`.
      state: A single state or a list-like of state parts.
      step_size: A single step size or a list-like of step sizes. A single
        step size is repeated once per state part.
      target_log_prob: Optional value forwarded to
        `mcmc_util.maybe_call_fn_and_grads`.
      grads_target_log_prob: Optional value forwarded alongside
        `target_log_prob`.
      volatility: Optional value forwarded to
        `_maybe_call_volatility_fn_and_grads`.
      grads_volatility_fn: Optional value forwarded alongside `volatility`.
      diffusion_drift: Optional drift, either a single value or a list-like
        with one entry per state part. When `None`, the drift is computed
        with `_get_drift`.
      parallel_iterations: Value forwarded to
        `_maybe_call_volatility_fn_and_grads`.

    Returns:
      A list `[state_parts, step_sizes, target_log_prob,
      grads_target_log_prob, volatility_parts, grads_volatility,
      diffusion_drift_parts]`.

    Raises:
      ValueError: if the number of step sizes is neither one nor equal to
        the number of state parts.
      ValueError: if `diffusion_drift` is supplied and its number of parts
        differs from the number of state parts.
    """
    state_parts = list(state) if mcmc_util.is_list_like(state) else [state]

    [
        target_log_prob,
        grads_target_log_prob,
    ] = mcmc_util.maybe_call_fn_and_grads(target_log_prob_fn, state_parts,
                                          target_log_prob,
                                          grads_target_log_prob)
    [
        volatility_parts,
        grads_volatility,
    ] = _maybe_call_volatility_fn_and_grads(
        volatility_fn, state_parts, volatility, grads_volatility_fn,
        distribution_util.prefer_static_shape(target_log_prob),
        parallel_iterations)

    step_sizes = (list(step_size)
                  if mcmc_util.is_list_like(step_size) else [step_size])
    # Step sizes take the dtype of the target log-prob.
    step_sizes = [
        tf.convert_to_tensor(s, name='step_size', dtype=target_log_prob.dtype)
        for s in step_sizes
    ]
    if len(step_sizes) == 1:
        step_sizes *= len(state_parts)
    if len(state_parts) != len(step_sizes):
        raise ValueError(
            'There should be exactly one `step_size` or it should '
            'have same length as `current_state`.')

    if diffusion_drift is None:
        diffusion_drift_parts = _get_drift(step_sizes, volatility_parts,
                                           grads_volatility,
                                           grads_target_log_prob)
    else:
        diffusion_drift_parts = (list(diffusion_drift)
                                 if mcmc_util.is_list_like(diffusion_drift)
                                 else [diffusion_drift])
        # Compare against the normalized list: `len()` of a raw, non-list
        # `diffusion_drift` would measure the tensor itself, not the number
        # of parts.
        if len(state_parts) != len(diffusion_drift_parts):
            raise ValueError(
                'There should be exactly one `diffusion_drift` or it '
                'should have same length as list-like `current_state`.')

    return [
        state_parts,
        step_sizes,
        target_log_prob,
        grads_target_log_prob,
        volatility_parts,
        grads_volatility,
        diffusion_drift_parts,
    ]
Ejemplo n.º 16
0
def _leapfrog_integrator_one_step(
    target_log_prob_fn,
    independent_chain_ndims,
    step_sizes,
    current_momentum_parts,
    current_state_parts,
    current_target_log_prob,
    current_target_log_prob_grad_parts,
    state_gradients_are_stopped=False,
    name=None):
  """Applies `num_leapfrog_steps` of the leapfrog integrator.

  Assumes a simple quadratic kinetic energy function: `0.5 ||momentum||**2`.

  #### Examples:

  ##### Simple quadratic potential.

  ```python
  import matplotlib.pyplot as plt
  %matplotlib inline
  import numpy as np
  import tensorflow as tf
  from tensorflow_probability.python.mcmc.hmc import _leapfrog_integrator
  tfd = tfp.distributions

  dims = 10
  num_iter = int(1e3)
  dtype = np.float32

  position = tf.placeholder(np.float32)
  momentum = tf.placeholder(np.float32)

  target_log_prob_fn = tfd.MultivariateNormalDiag(
      loc=tf.zeros(dims, dtype)).log_prob

  def _leapfrog_one_step(*args):
    # Closure representing computation done during each leapfrog step.
    return _leapfrog_integrator_one_step(
        target_log_prob_fn=target_log_prob_fn,
        independent_chain_ndims=0,
        step_sizes=[0.1],
        current_momentum_parts=args[0],
        current_state_parts=args[1],
        current_target_log_prob=args[2],
        current_target_log_prob_grad_parts=args[3])

  # Do leapfrog integration.
  [
      [next_momentum],
      [next_position],
      next_target_log_prob,
      next_target_log_prob_grad_parts,
  ] = tf.while_loop(
      cond=lambda *args: True,
      body=_leapfrog_one_step,
      loop_vars=[
        [momentum],
        [position],
        target_log_prob_fn(position),
        tf.gradients(target_log_prob_fn(position), position),
      ],
      maximum_iterations=3)

  momentum_ = np.random.randn(dims).astype(dtype)
  position_ = np.random.randn(dims).astype(dtype)
  positions = np.zeros([num_iter, dims], dtype)

  with tf.Session() as sess:
    for i in xrange(num_iter):
      position_, momentum_ = sess.run(
          [next_momentum, next_position],
          feed_dict={position: position_, momentum: momentum_})
      positions[i] = position_

  plt.plot(positions[:, 0]);  # Sinusoidal.
  ```

  Args:
    target_log_prob_fn: Python callable which takes an argument like
      `*current_state_parts` and returns its (possibly unnormalized) log-density
      under the target distribution.
    independent_chain_ndims: Scalar `int` `Tensor` representing the number of
      leftmost `Tensor` dimensions which index independent chains.
    step_sizes: Python `list` of `Tensor`s representing the step size for the
      leapfrog integrator. Must broadcast with the shape of
      `current_state_parts`.  Larger step sizes lead to faster progress, but
      too-large step sizes make rejection exponentially more likely. When
      possible, it's often helpful to match per-variable step sizes to the
      standard deviations of the target distribution in each variable.
    current_momentum_parts: Tensor containing the value(s) of the momentum
      variable(s) to update.
    current_state_parts: Python `list` of `Tensor`s representing the current
      state(s) of the Markov chain(s). The first `independent_chain_ndims` of
      the `Tensor`(s) index different chains.
    current_target_log_prob: `Tensor` representing the value of
      `target_log_prob_fn(*current_state_parts)`. The only reason to specify
      this argument is to reduce TF graph size.
    current_target_log_prob_grad_parts: Python list of `Tensor`s representing
      gradient of `target_log_prob_fn(*current_state_parts`) wrt
      `current_state_parts`. Must have same shape as `current_state_parts`. The
      only reason to specify this argument is to reduce TF graph size.
    state_gradients_are_stopped: Python `bool` indicating that the proposed new
      state be run through `tf.stop_gradient`. This is particularly useful when
      combining optimization over samples from the HMC chain.
      Default value: `False` (i.e., do not apply `stop_gradient`).
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'hmc_leapfrog_integrator').

  Returns:
    proposed_momentum_parts: Updated value of the momentum.
    proposed_state_parts: Tensor or Python list of `Tensor`s representing the
      state(s) of the Markov chain(s) at each result step. Has same shape as
      input `current_state_parts`.
    proposed_target_log_prob: `Tensor` representing the value of
      `target_log_prob_fn` at `next_state`.
    proposed_target_log_prob_grad_parts: Gradient of `proposed_target_log_prob`
      wrt `next_state`.

  Raises:
    ValueError: if `len(momentum_parts) != len(state_parts)`.
    ValueError: if `len(state_parts) != len(step_sizes)`.
    ValueError: if `len(state_parts) != len(grads_target_log_prob)`.
    TypeError: if `not target_log_prob.dtype.is_floating`.
  """
  # Note on per-variable step sizes:
  #
  # Using per-variable step sizes is equivalent to using the same step
  # size for all variables and adding a diagonal mass matrix in the
  # kinetic energy term of the Hamiltonian being integrated. This is
  # hinted at by Neal (2011) but not derived in detail there.
  #
  # Let x and v be position and momentum variables respectively.
  # Let g(x) be the gradient of `target_log_prob_fn(x)`.
  # Let S be a diagonal matrix of per-variable step sizes.
  # Let the Hamiltonian H(x, v) = -target_log_prob_fn(x) + 0.5 * ||v||**2.
  #
  # Using per-variable step sizes gives the updates
  # v'  = v  + 0.5 * matmul(S, g(x))
  # x'' = x  + matmul(S, v')
  # v'' = v' + 0.5 * matmul(S, g(x''))
  #
  # Let u = matmul(inv(S), v).
  # Multiplying v by inv(S) in the updates above gives the transformed dynamics
  # u'  = matmul(inv(S), v')  = matmul(inv(S), v) + 0.5 * g(x)
  #                           = u + 0.5 * g(x)
  # x'' = x + matmul(S, v') = x + matmul(S**2, u')
  # u'' = matmul(inv(S), v'') = matmul(inv(S), v') + 0.5 * g(x'')
  #                           = u' + 0.5 * g(x'')
  #
  # These are exactly the leapfrog updates for the Hamiltonian
  # H'(x, u) = -target_log_prob_fn(x) + 0.5 * u^T S**2 u
  #          = -target_log_prob_fn(x) + 0.5 * ||v||**2 = H(x, v).
  #
  # To summarize:
  #
  # * Using per-variable step sizes implicitly simulates the dynamics
  #   of the Hamiltonian H' (which are energy-conserving in H'). We
  #   keep track of v instead of u, but the underlying dynamics are
  #   the same if we transform back.
  # * The value of the Hamiltonian H'(x, u) is the same as the value
  #   of the original Hamiltonian H(x, v) after we transform back from
  #   u to v.
  # * Sampling v ~ N(0, I) is equivalent to sampling u ~ N(0, S**-2).
  #
  # So using per-variable step sizes in HMC will give results that are
  # exactly identical to explicitly using a diagonal mass matrix.

  with tf.name_scope(
      name, 'hmc_leapfrog_integrator_one_step',
      [independent_chain_ndims, step_sizes,
       current_momentum_parts, current_state_parts,
       current_target_log_prob, current_target_log_prob_grad_parts]):

    # Step 1: Update momentum.
    proposed_momentum_parts = [
        v + 0.5 * tf.cast(eps, v.dtype) * g
        for v, eps, g
        in zip(current_momentum_parts,
               step_sizes,
               current_target_log_prob_grad_parts)]

    # Step 2: Update state.
    proposed_state_parts = [
        x + tf.cast(eps, v.dtype) * v
        for x, eps, v
        in zip(current_state_parts,
               step_sizes,
               proposed_momentum_parts)]

    if state_gradients_are_stopped:
      proposed_state_parts = [tf.stop_gradient(x) for x in proposed_state_parts]

    # Step 3a: Re-evaluate target-log-prob (and grad) at proposed state.
    [
        proposed_target_log_prob,
        proposed_target_log_prob_grad_parts,
    ] = mcmc_util.maybe_call_fn_and_grads(
        target_log_prob_fn,
        proposed_state_parts)

    if not proposed_target_log_prob.dtype.is_floating:
      raise TypeError('`target_log_prob_fn` must produce a `Tensor` '
                      'with `float` `dtype`.')

    if any(g is None for g in proposed_target_log_prob_grad_parts):
      raise ValueError(
          'Encountered `None` gradient. Does your target `target_log_prob_fn` '
          'access all `tf.Variable`s via `tf.get_variable`?\n'
          '  current_state_parts: {}\n'
          '  proposed_state_parts: {}\n'
          '  proposed_target_log_prob_grad_parts: {}'.format(
              current_state_parts,
              proposed_state_parts,
              proposed_target_log_prob_grad_parts))

    # Step 3b: Update momentum (again).
    proposed_momentum_parts = [
        v + 0.5 * tf.cast(eps, v.dtype) * g
        for v, eps, g
        in zip(proposed_momentum_parts,
               step_sizes,
               proposed_target_log_prob_grad_parts)]

    return [
        proposed_momentum_parts,
        proposed_state_parts,
        proposed_target_log_prob,
        proposed_target_log_prob_grad_parts,
    ]
Ejemplo n.º 17
0
def _leapfrog_integrator(current_momentums,
                         target_log_prob_fn,
                         current_state_parts,
                         step_sizes,
                         num_leapfrog_steps,
                         current_target_log_prob,
                         current_grads_target_log_prob,
                         name=None):
    """Applies `num_leapfrog_steps` of the leapfrog integrator.

  Assumes a simple quadratic kinetic energy function: `0.5 ||momentum||**2`.

  #### Examples:

  ##### Simple quadratic potential.

  ```python
  import matplotlib.pyplot as plt
  %matplotlib inline
  import numpy as np
  import tensorflow as tf
  from tensorflow_probability.python.mcmc.hmc import _leapfrog_integrator
  tfd = tf.contrib.distributions

  dims = 10
  num_iter = int(1e3)
  dtype = np.float32

  position = tf.placeholder(np.float32)
  momentum = tf.placeholder(np.float32)

  target_log_prob_fn = tfd.MultivariateNormalDiag(
      loc=tf.zeros(dims, dtype)).log_prob
  current_target_log_prob = target_log_prob_fn(position)
  current_grads_target_log_prob = tf.gradients(
      current_target_log_prob, position)

  [
      next_momentums,
      next_positions,
  ] = _leapfrog_integrator(
      current_momentums=[momentum],
      target_log_prob_fn=tfd.MultivariateNormalDiag(
          loc=tf.zeros(dims, dtype)).log_prob,
      current_state_parts=[position],
      step_sizes=[0.1],
      num_leapfrog_steps=3,
      current_target_log_prob=current_target_log_prob,
      current_grads_target_log_prob=current_grads_target_log_prob)[:2]

  momentum_ = np.random.randn(dims).astype(dtype)
  position_ = np.random.randn(dims).astype(dtype)
  positions = np.zeros([num_iter, dims], dtype)

  with tf.Session() as sess:
    for i in xrange(num_iter):
      position_, momentum_ = sess.run(
          [next_momentums[0], next_positions[0]],
          feed_dict={position: position_, momentum: momentum_})
      positions[i] = position_

  plt.plot(positions[:, 0]);  # Sinusoidal.
  ```

  Args:
    current_momentums: Tensor containing the value(s) of the momentum
      variable(s) to update.
    target_log_prob_fn: Python callable which takes an argument like
      `*current_state_parts` and returns its (possibly unnormalized) log-density
      under the target distribution.
    current_state_parts: Python `list` of `Tensor`s representing the current
      state(s) of the Markov chain(s). The first `independent_chain_ndims` of
      the `Tensor`(s) index different chains.
    step_sizes: Python `list` of `Tensor`s representing the step size for the
      leapfrog integrator. Must broadcast with the shape of
      `current_state_parts`.  Larger step sizes lead to faster progress, but
      too-large step sizes make rejection exponentially more likely. When
      possible, it's often helpful to match per-variable step sizes to the
      standard deviations of the target distribution in each variable.
    num_leapfrog_steps: Integer number of steps to run the leapfrog integrator
      for. Total progress per HMC step is roughly proportional to `step_size *
      num_leapfrog_steps`.
    current_target_log_prob: `Tensor` representing the value of
      `target_log_prob_fn(*current_state_parts)`. The only reason to specify
      this argument is to reduce TF graph size.
    current_grads_target_log_prob: Python list of `Tensor`s representing
      gradient of `target_log_prob_fn(*current_state_parts`) wrt
      `current_state_parts`. Must have same shape as `current_state_parts`. The
      only reason to specify this argument is to reduce TF graph size.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'hmc_leapfrog_integrator').

  Returns:
    proposed_momentums: Updated value of the momentum.
    proposed_state_parts: Tensor or Python list of `Tensor`s representing the
      state(s) of the Markov chain(s) at each result step. Has same shape as
      input `current_state_parts`.
    proposed_target_log_prob: `Tensor` representing the value of
      `target_log_prob_fn` at `next_state`.
    proposed_grads_target_log_prob: Gradient of `proposed_target_log_prob` wrt
      `next_state`.

  Raises:
    ValueError: if `len(momentums) != len(state_parts)`.
    ValueError: if `len(state_parts) != len(step_sizes)`.
    ValueError: if `len(state_parts) != len(grads_target_log_prob)`.
    TypeError: if `not target_log_prob.dtype.is_floating`.
  """

    # Note on per-variable step sizes:
    #
    # Using per-variable step sizes is equivalent to using the same step
    # size for all variables and adding a diagonal mass matrix in the
    # kinetic energy term of the Hamiltonian being integrated. This is
    # hinted at by Neal (2011) but not derived in detail there.
    #
    # Let x and v be position and momentum variables respectively.
    # Let g(x) be the gradient of `target_log_prob_fn(x)`.
    # Let S be a diagonal matrix of per-variable step sizes.
    # Let the Hamiltonian H(x, v) = -target_log_prob_fn(x) + 0.5 * ||v||**2.
    #
    # Using per-variable step sizes gives the updates
    # v'  = v  + 0.5 * matmul(S, g(x))
    # x'' = x  + matmul(S, v')
    # v'' = v' + 0.5 * matmul(S, g(x''))
    #
    # Let u = matmul(inv(S), v).
    # Multiplying v by inv(S) in the updates above gives the transformed dynamics
    # u'  = matmul(inv(S), v')  = matmul(inv(S), v) + 0.5 * g(x)
    #                           = u + 0.5 * g(x)
    # x'' = x + matmul(S, v') = x + matmul(S**2, u')
    # u'' = matmul(inv(S), v'') = matmul(inv(S), v') + 0.5 * g(x'')
    #                           = u' + 0.5 * g(x'')
    #
    # These are exactly the leapfrog updates for the Hamiltonian
    # H'(x, u) = -target_log_prob_fn(x) + 0.5 * u^T S**2 u
    #          = -target_log_prob_fn(x) + 0.5 * ||v||**2 = H(x, v).
    #
    # To summarize:
    #
    # * Using per-variable step sizes implicitly simulates the dynamics
    #   of the Hamiltonian H' (which are energy-conserving in H'). We
    #   keep track of v instead of u, but the underlying dynamics are
    #   the same if we transform back.
    # * The value of the Hamiltonian H'(x, u) is the same as the value
    #   of the original Hamiltonian H(x, v) after we transform back from
    #   u to v.
    # * Sampling v ~ N(0, I) is equivalent to sampling u ~ N(0, S**-2).
    #
    # So using per-variable step sizes in HMC will give results that are
    # exactly identical to explicitly using a diagonal mass matrix.
    def _loop_body(
            step,
            current_momentums,
            current_state_parts,
            ignore_current_target_log_prob,  # pylint: disable=unused-argument
            current_grads_target_log_prob):
        return [step + 1] + list(
            _leapfrog_step(current_momentums, target_log_prob_fn,
                           current_state_parts, step_sizes,
                           current_grads_target_log_prob))

    with tf.name_scope(name, 'hmc_leapfrog_integrator', [
            current_momentums, current_state_parts, step_sizes,
            num_leapfrog_steps, current_target_log_prob,
            current_grads_target_log_prob
    ]):
        if len(current_momentums) != len(current_state_parts):
            raise ValueError(
                '`momentums` must be in one-to-one correspondence '
                'with `state_parts`')
        num_leapfrog_steps = tf.convert_to_tensor(num_leapfrog_steps,
                                                  name='num_leapfrog_steps')
        [
            current_target_log_prob,
            current_grads_target_log_prob,
        ] = mcmc_util.maybe_call_fn_and_grads(target_log_prob_fn,
                                              current_state_parts,
                                              current_target_log_prob,
                                              current_grads_target_log_prob)
        return tf.while_loop(
            cond=lambda iter_, *args: iter_ < num_leapfrog_steps,
            body=_loop_body,
            loop_vars=[
                np.int32(0),  # iter_
                current_momentums,
                current_state_parts,
                current_target_log_prob,
                current_grads_target_log_prob,
            ],
            back_prop=False)[1:]  # Lop-off "iter_".
Ejemplo n.º 18
0
 def testGradientWorksDespiteBijectorCaching(self):
     """Checks that a LogNormal log-prob yields a value and a gradient.

     The target callable builds a new `tfd.LogNormal` on every invocation.
     """
     # NOTE(review): the test name suggests this guards against bijector
     # caching dropping gradients -- confirm against the helper under test.
     def _log_prob(value):
         return tfd.LogNormal(loc=0., scale=1.).log_prob(value)

     point = tf.constant(2.)
     value, gradients = maybe_call_fn_and_grads(_log_prob, point)

     # Neither the function result nor its single gradient may be missing.
     self.assertAllEqual(False, value is None)
     gradient_is_missing = [grad is None for grad in gradients]
     self.assertAllEqual([False], gradient_is_missing)
Ejemplo n.º 19
0
 def testGradientWorksDespiteBijectorCaching(self):
     """Checks that a shared LogNormal's log-prob yields a value and gradient.

     A single `tfd.LogNormal` instance is created once and reused by the
     target callable.
     """
     log_normal = tfd.LogNormal(loc=0., scale=1.)
     point = tf.constant(2.)

     def _log_prob(value):
         # Kept as an explicit wrapper (not the bound method) to mirror how
         # the target function is handed to the helper.
         return log_normal.log_prob(value)

     value, gradients = maybe_call_fn_and_grads(_log_prob, point)

     # Neither the function result nor its single gradient may be missing.
     self.assertAllEqual(False, value is None)
     gradient_is_missing = [grad is None for grad in gradients]
     self.assertAllEqual([False], gradient_is_missing)