Пример #1
0
    def testGradientNumBodyCalls(self):
        """Counts how many times the target fn body runs inside the helper."""
        call_counter = collections.Counter()

        dtype = np.float32

        def fn(x, y):
            # Tally every invocation of the body.
            call_counter['body_calls'] += 1
            return x**2 + y**2

        # Build the two scalar arguments as tensors.
        args = [tf.convert_to_tensor(value=dtype(3)) for _ in range(2)]
        util.maybe_call_fn_and_grads(fn, args)
        # NOTE(review): under JAX or TF graph mode one trace appears to
        # suffice, while eager TF invokes the body twice — confirm against
        # `maybe_call_fn_and_grads`.
        if JAX_MODE or not tf.executing_eagerly():
            expected_num_calls = 1
        else:
            expected_num_calls = 2
        self.assertEqual(expected_num_calls, call_counter['body_calls'])
Пример #2
0
def _prepare_args(target_log_prob_fn,
                  state,
                  step_size,
                  momentum_distribution,
                  target_log_prob=None,
                  grads_target_log_prob=None,
                  maybe_expand=False,
                  state_gradients_are_stopped=False):
    """Helper which processes input args to meet list-like assumptions.

    Listifies `state` and `step_size`, optionally stops state gradients,
    fills in missing log prob/gradients, and builds a default momentum
    distribution over the state parts when none (or an empty list) is given.

    Returns:
      `[state_parts, step_sizes, momentum_distribution, target_log_prob,
      grads_target_log_prob]`; the first two are unwrapped from singleton
      lists unless `maybe_expand` is true or `state` was already list-like.
    """
    state_parts, _ = mcmc_util.prepare_state_parts(state, name='current_state')
    if state_gradients_are_stopped:
        state_parts = [tf.stop_gradient(x) for x in state_parts]
    # Only calls `target_log_prob_fn` when the value and/or grads are missing.
    target_log_prob, grads_target_log_prob = mcmc_util.maybe_call_fn_and_grads(
        target_log_prob_fn, state_parts, target_log_prob,
        grads_target_log_prob)
    step_sizes, _ = mcmc_util.prepare_state_parts(step_size,
                                                  dtype=target_log_prob.dtype,
                                                  name='step_size')

    # Default momentum distribution is None, but if `store_parameters_in_results`
    # is true, then `momentum_distribution` defaults to an empty list
    if momentum_distribution is None or isinstance(momentum_distribution,
                                                   list):
        batch_rank = ps.rank(target_log_prob)

        def _batched_isotropic_normal_like(state_part):
            # Dims beyond the batch rank are treated as event dims of a
            # standard-normal momentum for this state part.
            event_ndims = ps.rank(state_part) - batch_rank
            return independent.Independent(
                normal.Normal(ps.zeros_like(state_part, tf.float32), 1.),
                reinterpreted_batch_ndims=event_ndims)

        momentum_distribution = jds.JointDistributionSequential([
            _batched_isotropic_normal_like(state_part)
            for state_part in state_parts
        ])

    # The momentum will get "maybe listified" to zip with the state parts,
    # and this step makes sure that the momentum distribution will have the
    # same "maybe listified" underlying shape.
    if not mcmc_util.is_list_like(momentum_distribution.dtype):
        momentum_distribution = jds.JointDistributionSequential(
            [momentum_distribution])

    # A single step size is shared across all state parts.
    if len(step_sizes) == 1:
        step_sizes *= len(state_parts)
    if len(state_parts) != len(step_sizes):
        raise ValueError(
            'There should be exactly one `step_size` or it should '
            'have same length as `current_state`.')

    def maybe_flatten(x):
        # Undo the listification when the caller passed a bare (non-list) state.
        return x if maybe_expand or mcmc_util.is_list_like(state) else x[0]

    return [
        maybe_flatten(state_parts),
        maybe_flatten(step_sizes),
        momentum_distribution,
        target_log_prob,
        grads_target_log_prob,
    ]
Пример #3
0
 def _process_args(self, momentum_parts, state_parts, target,
                   target_grad_parts):
     """Sanitize inputs to `__call__`."""
     with tf.name_scope('process_args'):
         def _to_tensor_list(parts, name):
             # Normalize a list of part values to float32-hinted tensors.
             return [
                 tf.convert_to_tensor(p, dtype_hint=tf.float32, name=name)
                 for p in parts
             ]

         momentum_parts = _to_tensor_list(momentum_parts, 'momentum_parts')
         state_parts = _to_tensor_list(state_parts, 'state_parts')

         if target is not None and target_grad_parts is not None:
             target = tf.convert_to_tensor(target,
                                           dtype_hint=tf.float32,
                                           name='target')
             target_grad_parts = _to_tensor_list(target_grad_parts,
                                                 'target_grad_part')
         else:
             # Either piece missing: (re)compute both from the target fn.
             target, target_grad_parts = mcmc_util.maybe_call_fn_and_grads(
                 self.target_fn, state_parts)
         return momentum_parts, state_parts, target, target_grad_parts
Пример #4
0
def _prepare_args(target_log_prob_fn,
                  state,
                  step_size,
                  target_log_prob=None,
                  grads_target_log_prob=None,
                  maybe_expand=False,
                  state_gradients_are_stopped=False):
    """Helper which processes input args to meet list-like assumptions.

    Listifies `state` and `step_size`, optionally stops state gradients, and
    fills in the target log prob / gradients when not supplied.

    Returns:
      `[state_parts, step_sizes, target_log_prob, grads_target_log_prob]`;
      the first two are unwrapped from singleton lists unless `maybe_expand`
      is true or `state` was already list-like.
    """
    state_parts, _ = mcmc_util.prepare_state_parts(state, name='current_state')
    if state_gradients_are_stopped:
        state_parts = [tf.stop_gradient(x) for x in state_parts]
    # Only calls `target_log_prob_fn` when the value and/or grads are missing.
    target_log_prob, grads_target_log_prob = mcmc_util.maybe_call_fn_and_grads(
        target_log_prob_fn, state_parts, target_log_prob,
        grads_target_log_prob)
    step_sizes, _ = mcmc_util.prepare_state_parts(step_size,
                                                  dtype=target_log_prob.dtype,
                                                  name='step_size')
    # A single step size is shared across all state parts.
    if len(step_sizes) == 1:
        step_sizes *= len(state_parts)
    if len(state_parts) != len(step_sizes):
        raise ValueError(
            'There should be exactly one `step_size` or it should '
            'have same length as `current_state`.')

    def maybe_flatten(x):
        # Undo the listification when the caller passed a bare (non-list) state.
        return x if maybe_expand or mcmc_util.is_list_like(state) else x[0]

    return [
        maybe_flatten(state_parts),
        maybe_flatten(step_sizes),
        target_log_prob,
        grads_target_log_prob,
    ]
Пример #5
0
 def testGradientWorksForMultivariateNormalTriL(self):
     """Verifies non-None value and gradients for an MVN-TriL log prob."""
     # TODO(b/72831017): Remove this once bijector cacheing is fixed for
     # graph mode.
     if not tf.executing_eagerly():
         self.skipTest('Gradients get None values in graph mode.')
     dist = tfd.MultivariateNormalTriL(scale_tril=tf.eye(2))
     sample = dist.sample(seed=test_util.test_seed())
     value, gradients = util.maybe_call_fn_and_grads(dist.log_prob, sample)
     self.assertAllEqual(False, value is None)
     self.assertAllEqual([False], [g is None for g in gradients])
Пример #6
0
  def testNoGradientsNiceError(self):
    """Checks that a `None` gradient raises an informative `ValueError`.

    `y` is wrapped in `tf.stop_gradient`, so its gradient is `None`; the
    helper is expected to raise with a message describing `fn_arg_list`.
    Uses `assertRaisesRegex` — the `assertRaisesRegexp` alias is deprecated
    and removed in Python 3.12.
    """
    dtype = np.float32

    def fn(x, y):
      return x**2 + tf.stop_gradient(y)**2

    fn_args = [dtype(3), dtype(3)]
    # Convert function input to a list of tensors
    fn_args = [
        tf.convert_to_tensor(value=arg, name='arg{}'.format(i))
        for i, arg in enumerate(fn_args)
    ]
    if tf.executing_eagerly():
      # NOTE(review): eager mode's message presumably cannot include the
      # `arg1` name — the expected regex omits it; confirm against the util.
      with self.assertRaisesRegex(
          ValueError, 'Encountered `None`.*\n.*fn_arg_list.*\n.*None'):
        util.maybe_call_fn_and_grads(fn, fn_args)
    else:
      with self.assertRaisesRegex(
          ValueError, 'Encountered `None`.*\n.*fn_arg_list.*arg1.*\n.*None'):
        util.maybe_call_fn_and_grads(fn, fn_args)
Пример #7
0
  def testGradientComputesCorrectly(self):
    """Checks value and gradients of x**2 + y**2 at (3, 3)."""
    dtype = np.float32

    def fn(x, y):
      return x**2 + y**2

    # Build the two scalar arguments as tensors.
    args = [tf.convert_to_tensor(value=dtype(3)) for _ in range(2)]
    result, gradients = util.maybe_call_fn_and_grads(fn, args)
    result_, gradients_ = self.evaluate([result, gradients])
    # fn(3, 3) == 18; both partial derivatives are 2 * 3 == 6.
    self.assertNear(18., result_, err=1e-5)
    for g in gradients_:
      self.assertAllClose(g, dtype(6), atol=0., rtol=1e-5)
Пример #8
0
 def bootstrap_results(self, init_state):
     """Creates the initial kernel results structure from `init_state`."""
     with tf.name_scope(
             mcmc_util.make_name(self.name, 'hmc', 'bootstrap_results')):
         init_state, _ = mcmc_util.prepare_state_parts(init_state)
         if self.state_gradients_are_stopped:
             init_state = [tf.stop_gradient(x) for x in init_state]
         # Initial target log prob and its gradients at the starting state.
         [
             init_target_log_prob,
             init_grads_target_log_prob,
         ] = mcmc_util.maybe_call_fn_and_grads(self.target_log_prob_fn,
                                               init_state)
         if self._store_parameters_in_results:
             # Kernel parameters (step size, leapfrog count) are carried
             # inside the results structure when so configured; presumably so
             # they can be updated between steps — confirm with `one_step`.
             return UncalibratedHamiltonianMonteCarloKernelResults(
                 log_acceptance_correction=tf.zeros_like(
                     init_target_log_prob),
                 target_log_prob=init_target_log_prob,
                 grads_target_log_prob=init_grads_target_log_prob,
                 # Momenta are zero-initialized placeholders matching the
                 # state structure.
                 initial_momentum=tf.nest.map_structure(
                     tf.zeros_like, init_state),
                 final_momentum=tf.nest.map_structure(
                     tf.zeros_like, init_state),
                 # TODO(b/142590314): Try to use the following code once we commit to
                 # a tensorization policy.
                 # step_size=mcmc_util.prepare_state_parts(
                 #    self.step_size,
                 #    dtype=init_target_log_prob.dtype,
                 #    name='step_size')[0],
                 step_size=tf.nest.map_structure(
                     lambda x: tf.convert_to_tensor(  # pylint: disable=g-long-lambda
                         x,
                         dtype=init_target_log_prob.dtype,
                         name='step_size'),
                     self.step_size),
                 num_leapfrog_steps=tf.convert_to_tensor(
                     self.num_leapfrog_steps,
                     dtype=tf.int32,
                     name='num_leapfrog_steps'))
         else:
             # Parameters are not stored: use empty placeholders instead.
             return UncalibratedHamiltonianMonteCarloKernelResults(
                 log_acceptance_correction=tf.zeros_like(
                     init_target_log_prob),
                 target_log_prob=init_target_log_prob,
                 grads_target_log_prob=init_grads_target_log_prob,
                 initial_momentum=tf.nest.map_structure(
                     tf.zeros_like, init_state),
                 final_momentum=tf.nest.map_structure(
                     tf.zeros_like, init_state),
                 step_size=[],
                 num_leapfrog_steps=[])
Пример #9
0
  def bootstrap_results(self, init_state):
    """Creates initial `previous_kernel_results` using a supplied `state`."""
    with tf.name_scope(self.name + '.bootstrap_results'):
      if not tf.nest.is_nested(init_state):
        init_state = [init_state]
      state_parts, _ = mcmc_util.prepare_state_parts(init_state,
                                                     name='current_state')
      # Initial target log prob and its gradients at the starting state.
      current_target_log_prob, current_grads_log_prob = mcmc_util.maybe_call_fn_and_grads(
          self.target_log_prob_fn, state_parts)
      # Confirm that the step size is compatible with the state parts.
      _ = _prepare_step_size(
          self.step_size, current_target_log_prob.dtype, len(init_state))
      momentum_distribution = self.momentum_distribution
      # Build a default momentum distribution over the state parts when the
      # caller did not supply one.
      if momentum_distribution is None:
        momentum_distribution = pu.make_momentum_distribution(
            state_parts, ps.shape(current_target_log_prob),
            shard_axis_names=self.experimental_shard_axis_names)
      momentum_distribution = pu.maybe_make_list_and_batch_broadcast(
          momentum_distribution, ps.shape(current_target_log_prob))
      # Deterministic momentum draw (fixed zero seed) used only to fill in
      # the initial `energy` field below.
      momentum_parts = momentum_distribution.sample(seed=samplers.zeros_seed())

      return PreconditionedNUTSKernelResults(
          target_log_prob=current_target_log_prob,
          grads_target_log_prob=current_grads_log_prob,
          step_size=tf.nest.map_structure(
              lambda x: tf.convert_to_tensor(  # pylint: disable=g-long-lambda
                  x,
                  dtype=current_target_log_prob.dtype,
                  name='step_size'),
              self.step_size),
          log_accept_ratio=tf.zeros_like(
              current_target_log_prob, name='log_accept_ratio'),
          leapfrogs_taken=tf.zeros_like(
              current_target_log_prob,
              dtype=TREE_COUNT_DTYPE,
              name='leapfrogs_taken'),
          is_accepted=tf.zeros_like(
              current_target_log_prob, dtype=tf.bool, name='is_accepted'),
          reach_max_depth=tf.zeros_like(
              current_target_log_prob, dtype=tf.bool, name='reach_max_depth'),
          has_divergence=tf.zeros_like(
              current_target_log_prob, dtype=tf.bool, name='has_divergence'),
          energy=compute_hamiltonian(current_target_log_prob, momentum_parts,
                                     momentum_distribution),
          momentum_distribution=momentum_distribution,
          # Allow room for one_step's seed.
          seed=samplers.zeros_seed(),
      )
Пример #10
0
def _prepare_args(target_log_prob_fn,
                  state,
                  step_size,
                  momentum_distribution,
                  target_log_prob=None,
                  grads_target_log_prob=None,
                  maybe_expand=False,
                  state_gradients_are_stopped=False,
                  experimental_shard_axis_names=None):
    """Helper which processes input args to meet list-like assumptions.

    Listifies `state` and `step_size`, optionally stops state gradients,
    fills in missing log prob/gradients, and builds a default momentum
    distribution matching the state parts when none is supplied.

    Returns:
      `[state_parts, step_sizes, momentum_distribution, target_log_prob,
      grads_target_log_prob]`; the first two are unwrapped from singleton
      lists unless `maybe_expand` is true or `state` was already list-like.
    """
    state_parts, _ = mcmc_util.prepare_state_parts(state, name='current_state')
    if state_gradients_are_stopped:
        state_parts = [tf.stop_gradient(x) for x in state_parts]
    # Only calls `target_log_prob_fn` when the value and/or grads are missing.
    target_log_prob, grads_target_log_prob = mcmc_util.maybe_call_fn_and_grads(
        target_log_prob_fn, state_parts, target_log_prob,
        grads_target_log_prob)
    step_sizes, _ = mcmc_util.prepare_state_parts(step_size,
                                                  dtype=target_log_prob.dtype,
                                                  name='step_size')

    # Default momentum distribution is None
    if momentum_distribution is None:
        momentum_distribution = pu.make_momentum_distribution(
            state_parts,
            ps.shape(target_log_prob),
            shard_axis_names=experimental_shard_axis_names)
    momentum_distribution = pu.maybe_make_list_and_batch_broadcast(
        momentum_distribution, ps.shape(target_log_prob))

    # A single step size is shared across all state parts.
    if len(step_sizes) == 1:
        step_sizes *= len(state_parts)
    if len(state_parts) != len(step_sizes):
        raise ValueError(
            'There should be exactly one `step_size` or it should '
            'have same length as `current_state`.')

    def maybe_flatten(x):
        # Undo the listification when the caller passed a bare (non-list) state.
        return x if maybe_expand or mcmc_util.is_list_like(state) else x[0]

    return [
        maybe_flatten(state_parts),
        maybe_flatten(step_sizes),
        momentum_distribution,
        target_log_prob,
        grads_target_log_prob,
    ]
Пример #11
0
def hmc_like_proposed_velocity_getter_fn(kernel_results):
    """Getter for `proposed_velocity` so it can be inspected."""
    momentum = unnest.get_innermost(kernel_results, 'final_momentum')
    state = unnest.get_innermost(kernel_results, 'proposed_state')
    momentum_dist = unnest.get_innermost(
        kernel_results, 'momentum_distribution', default=None)

    if momentum_dist is not None:
        # Velocity is the gradient of the negated momentum log prob,
        # evaluated at the final momentum.
        log_prob_fn = getattr(momentum_dist, '_log_prob_unnormalized',
                              momentum_dist.log_prob)

        def kinetic_energy_fn(*args):
            return -log_prob_fn(*args)

        _, velocity = mcmc_util.maybe_call_fn_and_grads(
            kinetic_energy_fn, momentum)
    else:
        # No momentum distribution recorded: the momentum itself serves as
        # the velocity.
        velocity = momentum

    # The computed velocity has the wrong structure when state is a scalar.
    return tf.nest.pack_sequence_as(state, tf.nest.flatten(velocity))
Пример #12
0
  def _integrator_conserves_energy(self, x, independent_chain_ndims):
    """Asserts leapfrog integration approximately conserves total energy.

    Runs many leapfrog steps against a log-gamma target and checks that the
    Hamiltonian (negative log prob plus quadratic kinetic energy) changes by
    at most ~2% relative.
    """
    event_dims = tf.range(independent_chain_ndims, tf.rank(x))

    # Random initial momentum with the same shape as the state.
    m = tf.random.normal(tf.shape(input=x))
    log_prob_0, grad_0 = maybe_call_fn_and_grads(
        lambda x: self._log_gamma_log_prob(x, event_dims),
        x)
    # H = -log p(x) + 0.5 * ||m||^2, reduced over the event dims.
    old_energy = -log_prob_0 + 0.5 * tf.reduce_sum(
        input_tensor=m**2., axis=event_dims)

    x_shape = self.evaluate(x).shape
    event_size = np.prod(x_shape[independent_chain_ndims:])
    # Step size shrinks with the event size; presumably to keep the
    # integration stable — confirm if changing.
    step_size = tf.constant(0.1 / event_size, x.dtype)
    hmc_lf_steps = tf.constant(1000, np.int32)

    def leapfrog_one_step(*args):
      # Loop vars are threaded through positionally via *args.
      return _leapfrog_integrator_one_step(
          lambda x: self._log_gamma_log_prob(x, event_dims),
          independent_chain_ndims,
          [step_size],
          *args)

    [[new_m], _, log_prob_1, _] = tf.while_loop(
        cond=lambda *args: True,
        body=leapfrog_one_step,
        loop_vars=[
            [m],         # current_momentum_parts
            [x],         # current_state_parts,
            log_prob_0,  # current_target_log_prob
            grad_0,      # current_target_log_prob_grad_parts
        ],
        maximum_iterations=hmc_lf_steps)

    new_energy = -log_prob_1 + 0.5 * tf.reduce_sum(
        input_tensor=new_m**2., axis=event_dims)

    old_energy_, new_energy_ = self.evaluate([old_energy, new_energy])
    tf.compat.v1.logging.vlog(
        1, 'average energy relative change: {}'.format(
            (1. - new_energy_ / old_energy_).mean()))
    self.assertAllClose(old_energy_, new_energy_, atol=0., rtol=0.02)
Пример #13
0
 def bootstrap_results(self, init_state):
     """Creates the initial kernel results structure from `init_state`."""
     with tf.compat.v2.name_scope(
             mcmc_util.make_name(self.name, 'hmc', 'bootstrap_results')):
         if not mcmc_util.is_list_like(init_state):
             init_state = [init_state]
         if self.state_gradients_are_stopped:
             init_state = [tf.stop_gradient(x) for x in init_state]
         else:
             init_state = [
                 tf.convert_to_tensor(value=x) for x in init_state
             ]
         # Initial target log prob and its gradients at the starting state.
         [
             init_target_log_prob,
             init_grads_target_log_prob,
         ] = mcmc_util.maybe_call_fn_and_grads(self.target_log_prob_fn,
                                               init_state)
         if self._store_parameters_in_results:
             # Kernel parameters (step size, leapfrog count) are carried
             # inside the results structure when so configured; presumably so
             # they can be updated between steps — confirm with `one_step`.
             return UncalibratedHamiltonianMonteCarloKernelResults(
                 log_acceptance_correction=tf.zeros_like(
                     init_target_log_prob),
                 target_log_prob=init_target_log_prob,
                 grads_target_log_prob=init_grads_target_log_prob,
                 step_size=tf.nest.map_structure(
                     lambda x: tf.convert_to_tensor(  # pylint: disable=g-long-lambda
                         value=x,
                         dtype=init_target_log_prob.dtype,
                         name='step_size'),
                     self.step_size),
                 num_leapfrog_steps=tf.convert_to_tensor(
                     value=self.num_leapfrog_steps,
                     dtype=tf.int32,
                     name='num_leapfrog_steps'))
         else:
             # Parameters are not stored: use empty placeholders instead.
             return UncalibratedHamiltonianMonteCarloKernelResults(
                 log_acceptance_correction=tf.zeros_like(
                     init_target_log_prob),
                 target_log_prob=init_target_log_prob,
                 grads_target_log_prob=init_grads_target_log_prob,
                 step_size=[],
                 num_leapfrog_steps=[])
def _one_step(target_fn, step_sizes, get_velocity_parts,
              half_next_momentum_parts, state_parts, target,
              target_grad_parts):
    """Body of integrator while loop."""
    with tf.name_scope('leapfrog_integrate_one_step'):
        # Advance each state part by eps * velocity.
        velocities = get_velocity_parts(half_next_momentum_parts)
        next_state_parts = [
            x + tf.cast(eps, x.dtype) * tf.cast(v, x.dtype)  # pylint: disable=g-complex-comprehension
            for x, eps, v in zip(state_parts, step_sizes, velocities)
        ]
        # Re-evaluate the target and its gradients at the new state.
        next_target, next_target_grad_parts = mcmc_util.maybe_call_fn_and_grads(
            target_fn, next_state_parts)
        if any(grad is None for grad in next_target_grad_parts):
            raise ValueError('Encountered `None` gradient.\n'
                             '  state_parts: {}\n'
                             '  next_state_parts: {}\n'
                             '  next_target_grad_parts: {}'.format(
                                 state_parts, next_state_parts,
                                 next_target_grad_parts))

        # Carry over static shape info from the previous iteration's tensors.
        tensorshape_util.set_shape(next_target, target.shape)
        for new_grad, old_grad in zip(next_target_grad_parts,
                                      target_grad_parts):
            tensorshape_util.set_shape(new_grad, old_grad.shape)

        # Advance each momentum part by eps * gradient.
        next_half_next_momentum_parts = [
            p + tf.cast(eps, p.dtype) * tf.cast(g, p.dtype)  # pylint: disable=g-complex-comprehension
            for p, eps, g in zip(half_next_momentum_parts, step_sizes,
                                 next_target_grad_parts)
        ]

        return [
            next_half_next_momentum_parts,
            next_state_parts,
            next_target,
            next_target_grad_parts,
        ]
Пример #15
0
    def _one_step(self, half_next_momentum_parts, state_parts, target,
                  target_grad_parts):
        """Body of integrator while loop."""
        with tf.name_scope('leapfrog_integrate_one_step'):
            # Advance each state part by eps * momentum.
            next_state_parts = [
                x + tf.cast(eps, x.dtype) * tf.cast(v, x.dtype)  # pylint: disable=g-complex-comprehension
                for x, eps, v in zip(state_parts, self.step_sizes,
                                     half_next_momentum_parts)
            ]

            # Re-evaluate the target and its gradients at the new state.
            [next_target, next_target_grad_parts
             ] = mcmc_util.maybe_call_fn_and_grads(self.target_fn,
                                                   next_state_parts)
            if any(g is None for g in next_target_grad_parts):
                raise ValueError('Encountered `None` gradient.\n'
                                 '  state_parts: {}\n'
                                 '  next_state_parts: {}\n'
                                 '  next_target_grad_parts: {}'.format(
                                     state_parts, next_state_parts,
                                     next_target_grad_parts))

            # Carry over static shape info from the previous iteration's
            # tensors.
            next_target.set_shape(target.shape)
            for ng, g in zip(next_target_grad_parts, target_grad_parts):
                ng.set_shape(g.shape)

            # Advance each momentum part by eps * gradient.
            next_half_next_momentum_parts = [
                v + tf.cast(eps, v.dtype) * tf.cast(g, v.dtype)  # pylint: disable=g-complex-comprehension
                for v, eps, g in zip(half_next_momentum_parts, self.step_sizes,
                                     next_target_grad_parts)
            ]

            return [
                next_half_next_momentum_parts,
                next_state_parts,
                next_target,
                next_target_grad_parts,
            ]
Пример #16
0
def _prepare_args(target_log_prob_fn,
                  volatility_fn,
                  state,
                  step_size,
                  target_log_prob=None,
                  grads_target_log_prob=None,
                  volatility=None,
                  grads_volatility_fn=None,
                  diffusion_drift=None,
                  parallel_iterations=10):
    """Helper which processes input args to meet list-like assumptions.

    Listifies `state` and `step_size`, fills in any missing target log prob,
    gradients, volatility and volatility gradients, and computes the
    diffusion drift when not supplied.

    Returns:
      `[state_parts, step_sizes, target_log_prob, grads_target_log_prob,
      volatility_parts, grads_volatility, diffusion_drift_parts]`.

    Raises:
      ValueError: If `step_size` or `diffusion_drift` is neither a singleton
        nor the same length as the (listified) `state`.
    """
    state_parts = list(state) if mcmc_util.is_list_like(state) else [state]

    # Only calls the fns when the corresponding values/grads are missing.
    [
        target_log_prob,
        grads_target_log_prob,
    ] = mcmc_util.maybe_call_fn_and_grads(target_log_prob_fn, state_parts,
                                          target_log_prob,
                                          grads_target_log_prob)
    [
        volatility_parts,
        grads_volatility,
    ] = _maybe_call_volatility_fn_and_grads(volatility_fn, state_parts,
                                            volatility, grads_volatility_fn,
                                            ps.shape(target_log_prob),
                                            parallel_iterations)

    step_sizes = (list(step_size)
                  if mcmc_util.is_list_like(step_size) else [step_size])
    step_sizes = [
        tf.convert_to_tensor(value=s,
                             name='step_size',
                             dtype=target_log_prob.dtype) for s in step_sizes
    ]
    # A single step size is shared across all state parts.
    if len(step_sizes) == 1:
        step_sizes *= len(state_parts)
    if len(state_parts) != len(step_sizes):
        raise ValueError(
            'There should be exactly one `step_size` or it should '
            'have same length as `current_state`.')

    if diffusion_drift is None:
        diffusion_drift_parts = _get_drift(step_sizes, volatility_parts,
                                           grads_volatility,
                                           grads_target_log_prob)
    else:
        diffusion_drift_parts = (list(diffusion_drift)
                                 if mcmc_util.is_list_like(diffusion_drift)
                                 else [diffusion_drift])
        # Fix: validate against the listified `diffusion_drift_parts`;
        # `len(diffusion_drift)` raises (or miscounts) when a single
        # non-list drift was supplied.
        if len(state_parts) != len(diffusion_drift_parts):
            raise ValueError(
                'There should be exactly one `diffusion_drift` or it '
                'should have same length as list-like `current_state`.')

    return [
        state_parts,
        step_sizes,
        target_log_prob,
        grads_target_log_prob,
        volatility_parts,
        grads_volatility,
        diffusion_drift_parts,
    ]
Пример #17
0
def _prepare_args(target_log_prob_fn,
                  state,
                  step_size,
                  momentum_distribution,
                  target_log_prob=None,
                  grads_target_log_prob=None,
                  maybe_expand=False,
                  state_gradients_are_stopped=False):
  """Helper which processes input args to meet list-like assumptions.

  Listifies `state` and `step_size`, optionally stops state gradients, fills
  in missing log prob/gradients, and builds/normalizes the momentum
  distribution to match the state parts.

  Returns:
    `[state_parts, step_sizes, momentum_distribution, target_log_prob,
    grads_target_log_prob]`; the first two are unwrapped from singleton
    lists unless `maybe_expand` is true or `state` was already list-like.
  """
  state_parts, _ = mcmc_util.prepare_state_parts(state, name='current_state')
  if state_gradients_are_stopped:
    state_parts = [tf.stop_gradient(x) for x in state_parts]
  # Only calls `target_log_prob_fn` when the value and/or grads are missing.
  target_log_prob, grads_target_log_prob = mcmc_util.maybe_call_fn_and_grads(
      target_log_prob_fn, state_parts, target_log_prob, grads_target_log_prob)
  step_sizes, _ = mcmc_util.prepare_state_parts(
      step_size, dtype=target_log_prob.dtype, name='step_size')

  # Default momentum distribution is None, but if `store_parameters_in_results`
  # is true, then `momentum_distribution` defaults to DefaultStandardNormal().
  if (momentum_distribution is None or
      isinstance(momentum_distribution, DefaultStandardNormal)):
    batch_rank = ps.rank(target_log_prob)
    def _batched_isotropic_normal_like(state_part):
      # Dims beyond the batch rank become the sample shape of a
      # standard-normal momentum for this state part.
      return sample.Sample(
          normal.Normal(ps.zeros([], dtype=state_part.dtype), 1.),
          ps.shape(state_part)[batch_rank:])

    momentum_distribution = jds.JointDistributionSequential(
        [_batched_isotropic_normal_like(state_part)
         for state_part in state_parts])

  # The momentum will get "maybe listified" to zip with the state parts,
  # and this step makes sure that the momentum distribution will have the
  # same "maybe listified" underlying shape.
  if not mcmc_util.is_list_like(momentum_distribution.dtype):
    momentum_distribution = jds.JointDistributionSequential(
        [momentum_distribution])

  # If all underlying distributions are independent, we can offer some help.
  # This code will also trigger for the output of the two blocks above.
  if (isinstance(momentum_distribution, jds.JointDistributionSequential) and
      not any(callable(dist_fn) for dist_fn in momentum_distribution.model)):
    batch_shape = ps.shape(target_log_prob)
    momentum_distribution = momentum_distribution.copy(model=[
        batch_broadcast.BatchBroadcast(md, to_shape=batch_shape)
        for md in momentum_distribution.model
    ])

  # A single step size is shared across all state parts.
  if len(step_sizes) == 1:
    step_sizes *= len(state_parts)
  if len(state_parts) != len(step_sizes):
    raise ValueError('There should be exactly one `step_size` or it should '
                     'have same length as `current_state`.')
  def maybe_flatten(x):
    # Undo the listification when the caller passed a bare (non-list) state.
    return x if maybe_expand or mcmc_util.is_list_like(state) else x[0]
  return [
      maybe_flatten(state_parts),
      maybe_flatten(step_sizes),
      momentum_distribution,
      target_log_prob,
      grads_target_log_prob,
  ]
Пример #18
0
def _leapfrog_integrator_one_step(
    target_log_prob_fn,
    independent_chain_ndims,
    step_sizes,
    current_momentum_parts,
    current_state_parts,
    current_target_log_prob,
    current_target_log_prob_grad_parts,
    state_gradients_are_stopped=False,
    name=None):
  """Applies one step of the leapfrog integrator.

  Assumes a simple quadratic kinetic energy function: `0.5 ||momentum||**2`.

  #### Examples:

  ##### Simple quadratic potential.

  (Note: this example uses TF1 graph-mode APIs — `tf.placeholder`,
  `tf.Session` — and Python 2's `xrange`.)

  ```python
  import matplotlib.pyplot as plt
  %matplotlib inline
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp
  from tensorflow_probability.python.mcmc.hmc import _leapfrog_integrator_one_step  # pylint: disable=line-too-long
  tfd = tfp.distributions

  dims = 10
  num_iter = int(1e3)
  dtype = np.float32

  position = tf.placeholder(np.float32)
  momentum = tf.placeholder(np.float32)

  target_log_prob_fn = tfd.MultivariateNormalDiag(
      loc=tf.zeros(dims, dtype)).log_prob

  def _leapfrog_one_step(*args):
    # Closure representing computation done during each leapfrog step.
    return _leapfrog_integrator_one_step(
        target_log_prob_fn=target_log_prob_fn,
        independent_chain_ndims=0,
        step_sizes=[0.1],
        current_momentum_parts=args[0],
        current_state_parts=args[1],
        current_target_log_prob=args[2],
        current_target_log_prob_grad_parts=args[3])

  # Do leapfrog integration.
  [
      [next_momentum],
      [next_position],
      next_target_log_prob,
      next_target_log_prob_grad_parts,
  ] = tf.while_loop(
      cond=lambda *args: True,
      body=_leapfrog_one_step,
      loop_vars=[
        [momentum],
        [position],
        target_log_prob_fn(position),
        tf.gradients(target_log_prob_fn(position), position),
      ],
      maximum_iterations=3)

  momentum_ = np.random.randn(dims).astype(dtype)
  position_ = np.random.randn(dims).astype(dtype)
  positions = np.zeros([num_iter, dims], dtype)

  with tf.Session() as sess:
    for i in xrange(num_iter):
      position_, momentum_ = sess.run(
          [next_momentum, next_position],
          feed_dict={position: position_, momentum: momentum_})
      positions[i] = position_

  plt.plot(positions[:, 0]);  # Sinusoidal.
  ```

  Args:
    target_log_prob_fn: Python callable which takes an argument like
      `*current_state_parts` and returns its (possibly unnormalized) log-density
      under the target distribution.
    independent_chain_ndims: Scalar `int` `Tensor` representing the number of
      leftmost `Tensor` dimensions which index independent chains.
    step_sizes: Python `list` of `Tensor`s representing the step size for the
      leapfrog integrator. Must broadcast with the shape of
      `current_state_parts`.  Larger step sizes lead to faster progress, but
      too-large step sizes make rejection exponentially more likely. When
      possible, it's often helpful to match per-variable step sizes to the
      standard deviations of the target distribution in each variable.
    current_momentum_parts: Python `list` of `Tensor`s containing the value(s)
      of the momentum variable(s) to update; zipped elementwise with
      `current_state_parts` and `step_sizes`.
    current_state_parts: Python `list` of `Tensor`s representing the current
      state(s) of the Markov chain(s). The first `independent_chain_ndims` of
      the `Tensor`(s) index different chains.
    current_target_log_prob: `Tensor` representing the value of
      `target_log_prob_fn(*current_state_parts)`. The only reason to specify
      this argument is to reduce TF graph size.
    current_target_log_prob_grad_parts: Python list of `Tensor`s representing
      gradient of `target_log_prob_fn(*current_state_parts`) wrt
      `current_state_parts`. Must have same shape as `current_state_parts`. The
      only reason to specify this argument is to reduce TF graph size.
    state_gradients_are_stopped: Python `bool` indicating that the proposed new
      state be run through `tf.stop_gradient`. This is particularly useful when
      combining optimization over samples from the HMC chain.
      Default value: `False` (i.e., do not apply `stop_gradient`).
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'hmc_leapfrog_integrator').

  Returns:
    proposed_momentum_parts: Updated value of the momentum.
    proposed_state_parts: Tensor or Python list of `Tensor`s representing the
      state(s) of the Markov chain(s) at each result step. Has same shape as
      input `current_state_parts`.
    proposed_target_log_prob: `Tensor` representing the value of
      `target_log_prob_fn` at `next_state`.
    proposed_target_log_prob_grad_parts: Gradient of `proposed_target_log_prob`
      wrt `next_state`.

  Raises:
    ValueError: if `len(momentum_parts) != len(state_parts)`.
    ValueError: if `len(state_parts) != len(step_sizes)`.
    ValueError: if `len(state_parts) != len(grads_target_log_prob)`.
    TypeError: if `not target_log_prob.dtype.is_floating`.
  """
  # Note on per-variable step sizes:
  #
  # Using per-variable step sizes is equivalent to using the same step
  # size for all variables and adding a diagonal mass matrix in the
  # kinetic energy term of the Hamiltonian being integrated. This is
  # hinted at by Neal (2011) but not derived in detail there.
  #
  # Let x and v be position and momentum variables respectively.
  # Let g(x) be the gradient of `target_log_prob_fn(x)`.
  # Let S be a diagonal matrix of per-variable step sizes.
  # Let the Hamiltonian H(x, v) = -target_log_prob_fn(x) + 0.5 * ||v||**2.
  #
  # Using per-variable step sizes gives the updates
  # v'  = v  + 0.5 * matmul(S, g(x))
  # x'' = x  + matmul(S, v')
  # v'' = v' + 0.5 * matmul(S, g(x''))
  #
  # Let u = matmul(inv(S), v).
  # Multiplying v by inv(S) in the updates above gives the transformed dynamics
  # u'  = matmul(inv(S), v')  = matmul(inv(S), v) + 0.5 * g(x)
  #                           = u + 0.5 * g(x)
  # x'' = x + matmul(S, v') = x + matmul(S**2, u')
  # u'' = matmul(inv(S), v'') = matmul(inv(S), v') + 0.5 * g(x'')
  #                           = u' + 0.5 * g(x'')
  #
  # These are exactly the leapfrog updates for the Hamiltonian
  # H'(x, u) = -target_log_prob_fn(x) + 0.5 * u^T S**2 u
  #          = -target_log_prob_fn(x) + 0.5 * ||v||**2 = H(x, v).
  #
  # To summarize:
  #
  # * Using per-variable step sizes implicitly simulates the dynamics
  #   of the Hamiltonian H' (which are energy-conserving in H'). We
  #   keep track of v instead of u, but the underlying dynamics are
  #   the same if we transform back.
  # * The value of the Hamiltonian H'(x, u) is the same as the value
  #   of the original Hamiltonian H(x, v) after we transform back from
  #   u to v.
  # * Sampling v ~ N(0, I) is equivalent to sampling u ~ N(0, S**-2).
  #
  # So using per-variable step sizes in HMC will give results that are
  # exactly identical to explicitly using a diagonal mass matrix.

  with tf.compat.v1.name_scope(name, 'hmc_leapfrog_integrator_one_step', [
      independent_chain_ndims, step_sizes, current_momentum_parts,
      current_state_parts, current_target_log_prob,
      current_target_log_prob_grad_parts
  ]):

    # Step 1: Update momentum (first half step):
    #   v' = v + (eps / 2) * grad_log_prob(x).
    proposed_momentum_parts = [
        v + 0.5 * tf.cast(eps, v.dtype) * g
        for v, eps, g
        in zip(current_momentum_parts,
               step_sizes,
               current_target_log_prob_grad_parts)]

    # Step 2: Update state (full step): x'' = x + eps * v'.
    proposed_state_parts = [
        x + tf.cast(eps, v.dtype) * v
        for x, eps, v
        in zip(current_state_parts,
               step_sizes,
               proposed_momentum_parts)]

    if state_gradients_are_stopped:
      proposed_state_parts = [tf.stop_gradient(x) for x in proposed_state_parts]

    # Step 3a: Re-evaluate target-log-prob (and grad) at proposed state.
    [
        proposed_target_log_prob,
        proposed_target_log_prob_grad_parts,
    ] = mcmc_util.maybe_call_fn_and_grads(
        target_log_prob_fn,
        proposed_state_parts)

    if not proposed_target_log_prob.dtype.is_floating:
      raise TypeError('`target_log_prob_fn` must produce a `Tensor` '
                      'with `float` `dtype`.')

    if any(g is None for g in proposed_target_log_prob_grad_parts):
      raise ValueError(
          'Encountered `None` gradient. Does your target `target_log_prob_fn` '
          'access all `tf.Variable`s via `tf.get_variable`?\n'
          '  current_state_parts: {}\n'
          '  proposed_state_parts: {}\n'
          '  proposed_target_log_prob_grad_parts: {}'.format(
              current_state_parts,
              proposed_state_parts,
              proposed_target_log_prob_grad_parts))

    # Step 3b: Update momentum (second half step), using the gradient at the
    # proposed state: v'' = v' + (eps / 2) * grad_log_prob(x'').
    proposed_momentum_parts = [
        v + 0.5 * tf.cast(eps, v.dtype) * g
        for v, eps, g
        in zip(proposed_momentum_parts,
               step_sizes,
               proposed_target_log_prob_grad_parts)]

    return [
        proposed_momentum_parts,
        proposed_state_parts,
        proposed_target_log_prob,
        proposed_target_log_prob_grad_parts,
    ]
Пример #19
0
 def testGradientWorksDespiteBijectorCaching(self):
     """Gradients through a LogNormal log-prob should not be `None`."""
     point = tf.constant(2.)
     # Build the distribution inside the lambda on purpose: the test is about
     # gradients surviving the bijector's internal caching.
     fn_result, grads = util.maybe_call_fn_and_grads(
         lambda value: tfd.LogNormal(loc=0., scale=1.).log_prob(value), point)
     result_is_missing = fn_result is None
     grads_are_missing = [g is None for g in grads]
     self.assertAllEqual(False, result_is_missing)
     self.assertAllEqual([False], grads_are_missing)
Пример #20
0
  def _loop_build_sub_tree(self,
                           directions,
                           integrator,
                           current_step_meta_info,
                           iter_,
                           energy_diff_sum_previous,
                           momentum_cumsum_previous,
                           leapfrogs_taken,
                           prev_tree_state,
                           candidate_tree_state,
                           continue_tree_previous,
                           not_divergent_previous,
                           velocity_state_memory,
                           momentum_distribution,
                           seed):
    """Base case in tree doubling.

    Takes one leapfrog step from the current edge of the trajectory, decides
    via progressive sampling (multinomial or slice, per `MULTINOMIAL_SAMPLE`)
    whether the new point replaces the candidate sample, records the new
    velocity/state into the memory swap used by the U-turn check, and updates
    the divergence and energy-diff accumulators.
    """
    acceptance_seed, next_seed = samplers.split_seed(seed)
    with tf.name_scope('loop_build_sub_tree'):
      # Take one leapfrog step in the direction v and check divergence
      kinetic_energy_fn = get_kinetic_energy_fn(momentum_distribution)
      [
          next_momentum_parts,
          next_state_parts,
          next_target,
          next_target_grad_parts
      ] = integrator(prev_tree_state.momentum,
                     prev_tree_state.state,
                     prev_tree_state.target,
                     prev_tree_state.target_grad_parts,
                     kinetic_energy_fn=kinetic_energy_fn)
      # The velocity parts are the gradients of the kinetic energy with
      # respect to the (new) momentum parts.
      _, next_velocity_parts = mcmc_util.maybe_call_fn_and_grads(
          kinetic_energy_fn, next_momentum_parts)

      next_tree_state = TreeDoublingState(
          momentum=next_momentum_parts,
          velocity=next_velocity_parts,
          state=next_state_parts,
          target=next_target,
          target_grad_parts=next_target_grad_parts)
      momentum_cumsum = [p0 + p1 for p0, p1 in zip(momentum_cumsum_previous,
                                                   next_momentum_parts)]
      # If the tree has not terminated previously, count this leapfrog.
      leapfrogs_taken = tf.where(
          continue_tree_previous, leapfrogs_taken + 1, leapfrogs_taken)

      write_instruction = current_step_meta_info.write_instruction
      read_instruction = current_step_meta_info.read_instruction
      init_energy = current_step_meta_info.init_energy

      if GENERALIZED_UTURN:
        state_to_write = momentum_cumsum_previous
        state_to_check = momentum_cumsum
      else:
        state_to_write = next_state_parts
        state_to_check = next_state_parts

      batch_shape = ps.shape(next_target)
      has_not_u_turn_init = ps.ones(batch_shape, dtype=tf.bool)

      # Check the new point against the stored trajectory states for U-turns.
      read_index = read_instruction.gather([iter_])[0]
      no_u_turns_within_tree = has_not_u_turn_at_all_index(  # pylint: disable=g-long-lambda
          read_index,
          directions,
          velocity_state_memory,
          next_velocity_parts,
          state_to_check,
          has_not_u_turn_init,
          log_prob_rank=ps.rank(next_target),
          shard_axis_names=self.experimental_shard_axis_names)

      # Get index to write state into memory swap
      write_index = write_instruction.gather([iter_])
      velocity_state_memory = VelocityStateSwap(
          velocity_swap=[
              _safe_tensor_scatter_nd_update(old, [write_index], [new])
              for old, new in zip(velocity_state_memory.velocity_swap,
                                  next_velocity_parts)
          ],
          state_swap=[
              _safe_tensor_scatter_nd_update(old, [write_index], [new])
              for old, new in zip(velocity_state_memory.state_swap,
                                  state_to_write)
          ])

      # Map NaN energies to -inf so the corresponding proposal is flagged as
      # divergent (and therefore never accepted) below.
      energy = compute_hamiltonian(next_target, next_momentum_parts,
                                   momentum_distribution)
      current_energy = tf.where(tf.math.is_nan(energy),
                                tf.constant(-np.inf, dtype=energy.dtype),
                                energy)
      energy_diff = current_energy - init_energy

      if MULTINOMIAL_SAMPLE:
        not_divergent = -energy_diff < self.max_energy_diff
        weight_sum = log_add_exp(candidate_tree_state.weight, energy_diff)
        log_accept_thresh = energy_diff - weight_sum
      else:
        log_slice_sample = current_step_meta_info.log_slice_sample
        not_divergent = log_slice_sample - energy_diff < self.max_energy_diff
        # Uniform sampling on the trajectory within the subtree across valid
        # samples.
        is_valid = log_slice_sample <= energy_diff
        weight_sum = tf.where(is_valid,
                              candidate_tree_state.weight + 1,
                              candidate_tree_state.weight)
        log_accept_thresh = tf.where(
            is_valid,
            -tf.math.log(tf.cast(weight_sum, dtype=tf.float32)),
            tf.constant(-np.inf, dtype=tf.float32))
      # log(1 - U) for U ~ Uniform[0, 1): a log-uniform draw for the
      # progressive-sampling accept test below.
      u = tf.math.log1p(-samplers.uniform(
          shape=batch_shape,
          dtype=log_accept_thresh.dtype,
          seed=acceptance_seed))
      is_sample_accepted = u <= log_accept_thresh

      next_candidate_tree_state = TreeDoublingStateCandidate(
          state=[
              bu.where_left_justified_mask(is_sample_accepted, s0, s1)
              for s0, s1 in zip(next_state_parts, candidate_tree_state.state)
          ],
          target=bu.where_left_justified_mask(
              is_sample_accepted, next_target, candidate_tree_state.target),
          target_grad_parts=[
              bu.where_left_justified_mask(is_sample_accepted, grad0, grad1)
              for grad0, grad1 in zip(next_target_grad_parts,
                                      candidate_tree_state.target_grad_parts)
          ],
          energy=bu.where_left_justified_mask(
              is_sample_accepted,
              current_energy,
              candidate_tree_state.energy),
          weight=weight_sum)

      continue_tree = not_divergent & continue_tree_previous
      continue_tree_next = no_u_turns_within_tree & continue_tree

      # Only batch members whose tree was still growing contribute a new
      # divergence verdict; terminated members keep `True` here.
      not_divergent_tokeep = tf.where(
          continue_tree_previous,
          not_divergent,
          ps.ones(batch_shape, dtype=tf.bool))

      # min(1., exp(energy_diff)).
      exp_energy_diff = tf.math.exp(tf.minimum(energy_diff, 0.))
      energy_diff_sum = tf.where(continue_tree,
                                 energy_diff_sum_previous + exp_energy_diff,
                                 energy_diff_sum_previous)

      return (
          iter_ + 1,
          next_seed,
          energy_diff_sum,
          momentum_cumsum,
          leapfrogs_taken,
          next_tree_state,
          next_candidate_tree_state,
          continue_tree_next,
          not_divergent_previous & not_divergent_tokeep,
          velocity_state_memory,
      )
Пример #21
0
    def bootstrap_results(self, init_state):
        """Creates initial `previous_kernel_results` using a supplied `state`."""
        with tf.name_scope(self.name + '.bootstrap_results'):
            if not tf.nest.is_nested(init_state):
                init_state = [init_state]
            # Broadcast a single step size so there is one per state part.
            step_size = self.step_size
            if len(step_size) == 1:
                step_size = step_size * len(init_state)
            if len(step_size) != len(init_state):
                raise ValueError(
                    'Expected either one step size or {} (size of '
                    '`init_state`), but found {}'.format(
                        len(init_state), len(step_size)))
            flat_state, _ = mcmc_util.prepare_state_parts(
                init_state, name='current_state')
            [target_log_prob,
             target_log_prob_grads] = mcmc_util.maybe_call_fn_and_grads(
                 self.target_log_prob_fn, flat_state)
            batch_shape = ps.shape(target_log_prob)
            # Default to a standard-normal-like momentum distribution, then
            # broadcast it to the batch shape of the target log prob.
            momentum_dist = self.momentum_distribution
            if momentum_dist is None:
                momentum_dist = pu.make_momentum_distribution(
                    flat_state, batch_shape)
            momentum_dist = pu.maybe_make_list_and_batch_broadcast(
                momentum_dist, batch_shape)
            momentum_parts = momentum_dist.sample()

            num_slots = max(self._write_instruction) + 1

            def _alloc(tensors):
                """Zero buffers shaped like `tensors`, one slot per write."""
                return [
                    ps.zeros(
                        ps.concat([[num_slots], ps.shape(t)], axis=0),
                        dtype=t.dtype)
                    for t in tensors
                ]

            velocity_state_memory = VelocityStateSwap(
                velocity_swap=_alloc(momentum_parts),
                state_swap=_alloc(init_state))

            return PreconditionedNUTSKernelResults(
                target_log_prob=target_log_prob,
                grads_target_log_prob=target_log_prob_grads,
                velocity_state_memory=velocity_state_memory,
                step_size=step_size,
                log_accept_ratio=tf.zeros_like(
                    target_log_prob, name='log_accept_ratio'),
                leapfrogs_taken=tf.zeros_like(
                    target_log_prob, dtype=TREE_COUNT_DTYPE,
                    name='leapfrogs_taken'),
                is_accepted=tf.zeros_like(
                    target_log_prob, dtype=tf.bool, name='is_accepted'),
                reach_max_depth=tf.zeros_like(
                    target_log_prob, dtype=tf.bool, name='reach_max_depth'),
                has_divergence=tf.zeros_like(
                    target_log_prob, dtype=tf.bool, name='has_divergence'),
                energy=compute_hamiltonian(
                    target_log_prob, momentum_parts, momentum_dist),
                momentum_distribution=momentum_dist,
                # Allow room for one_step's seed.
                seed=samplers.zeros_seed(),
            )
Пример #22
0
  def one_step(self, current_state, previous_kernel_results, seed=None):
    """Takes one NUTS step by iteratively doubling the trajectory tree.

    Args:
      current_state: `Tensor` or Python list of `Tensor`s representing the
        current state(s) of the Markov chain(s).
      previous_kernel_results: `PreconditionedNUTSKernelResults` from the
        previous `one_step` (or `bootstrap_results`) call.
      seed: PRNG seed; see `samplers.sanitize_seed` for acceptable values.

    Returns:
      result_state: `Tensor` or Python list of `Tensor`s (matching the
        structure of `current_state`) holding the chain's next position.
      kernel_results: updated `PreconditionedNUTSKernelResults`.
    """
    seed = samplers.sanitize_seed(seed)  # Retain for diagnostics.
    start_trajectory_seed, loop_seed = samplers.split_seed(seed)

    with tf.name_scope(self.name + '.one_step'):
      unwrap_state_list = not tf.nest.is_nested(current_state)
      if unwrap_state_list:
        current_state = [current_state]
      momentum_distribution = previous_kernel_results.momentum_distribution

      current_target_log_prob = previous_kernel_results.target_log_prob
      [init_momentum,
       init_energy,
       log_slice_sample] = self._start_trajectory_batched(
           current_state,
           current_target_log_prob,
           momentum_distribution=momentum_distribution,
           seed=start_trajectory_seed)

      def _copy(v):
        # Tiles `v` along a new leading axis of size 2 (presumably one copy
        # per trajectory direction used in tree doubling — confirm against
        # `_loop_tree_doubling`).
        return v * ps.ones(
            ps.pad(
                [2], paddings=[[0, ps.rank(v)]], constant_values=1),
            dtype=v.dtype)
      # Velocity = grads of the kinetic energy at the initial momentum.
      _, init_velocity = mcmc_util.maybe_call_fn_and_grads(
          get_kinetic_energy_fn(momentum_distribution),
          [m + 0 for m in init_momentum])  # Breaks cache.

      initial_state = TreeDoublingState(
          momentum=init_momentum,
          velocity=init_velocity,
          state=current_state,
          target=current_target_log_prob,
          target_grad_parts=previous_kernel_results.grads_target_log_prob)
      initial_step_state = tf.nest.map_structure(_copy, initial_state)

      if MULTINOMIAL_SAMPLE:
        init_weight = tf.zeros_like(init_energy)  # log(exp(H0 - H0))
      else:
        init_weight = tf.ones_like(init_energy, dtype=TREE_COUNT_DTYPE)

      candidate_state = TreeDoublingStateCandidate(
          state=current_state,
          target=current_target_log_prob,
          target_grad_parts=previous_kernel_results.grads_target_log_prob,
          energy=init_energy,
          weight=init_weight)

      initial_step_metastate = TreeDoublingMetaState(
          candidate_state=candidate_state,
          is_accepted=tf.zeros_like(init_energy, dtype=tf.bool),
          momentum_sum=init_momentum,
          energy_diff_sum=tf.zeros_like(init_energy),
          leapfrog_count=tf.zeros_like(init_energy, dtype=TREE_COUNT_DTYPE),
          continue_tree=tf.ones_like(init_energy, dtype=tf.bool),
          not_divergence=tf.ones_like(init_energy, dtype=tf.bool))

      # Convert the write/read instruction into TensorArray so that it is
      # compatible with XLA.
      write_instruction = tf.TensorArray(
          TREE_COUNT_DTYPE,
          size=len(self._write_instruction),
          clear_after_read=False).unstack(self._write_instruction)
      read_instruction = tf.TensorArray(
          tf.int32,
          size=len(self._read_instruction),
          clear_after_read=False).unstack(self._read_instruction)

      current_step_meta_info = OneStepMetaInfo(
          log_slice_sample=log_slice_sample,
          init_energy=init_energy,
          write_instruction=write_instruction,
          read_instruction=read_instruction
          )

      velocity_state_memory = VelocityStateSwap(
          velocity_swap=self.init_velocity_state_memory(init_momentum),
          state_swap=self.init_velocity_state_memory(current_state))

      step_size = _prepare_step_size(
          previous_kernel_results.step_size,
          current_target_log_prob.dtype,
          len(current_state))
      # Double the tree until max depth is hit or every batch member stops.
      _, _, _, new_step_metastate = tf.while_loop(
          cond=lambda iter_, seed, state, metastate: (  # pylint: disable=g-long-lambda
              (iter_ < self.max_tree_depth) &
              tf.reduce_any(metastate.continue_tree)),
          body=lambda iter_, seed, state, metastate: self._loop_tree_doubling(  # pylint: disable=g-long-lambda
              step_size,
              velocity_state_memory,
              current_step_meta_info,
              iter_,
              state,
              metastate,
              momentum_distribution,
              seed),
          loop_vars=(
              tf.zeros([], dtype=tf.int32, name='iter'),
              loop_seed,
              initial_step_state,
              initial_step_metastate),
          parallel_iterations=self.parallel_iterations,
      )

      kernel_results = PreconditionedNUTSKernelResults(
          target_log_prob=new_step_metastate.candidate_state.target,
          grads_target_log_prob=(
              new_step_metastate.candidate_state.target_grad_parts),
          step_size=previous_kernel_results.step_size,
          # Log of the mean per-leapfrog acceptance probability
          # min(1, exp(energy_diff)), accumulated during tree building.
          log_accept_ratio=tf.math.log(
              new_step_metastate.energy_diff_sum /
              tf.cast(new_step_metastate.leapfrog_count,
                      dtype=new_step_metastate.energy_diff_sum.dtype)),
          leapfrogs_taken=(
              new_step_metastate.leapfrog_count * self.unrolled_leapfrog_steps
          ),
          is_accepted=new_step_metastate.is_accepted,
          reach_max_depth=new_step_metastate.continue_tree,
          has_divergence=~new_step_metastate.not_divergence,
          energy=new_step_metastate.candidate_state.energy,
          momentum_distribution=momentum_distribution,
          seed=seed,
      )

      result_state = new_step_metastate.candidate_state.state
      if unwrap_state_list:
        result_state = result_state[0]

      return result_state, kernel_results
 def get_velocity_parts(half_next_momentum_parts):
     """Maps momentum parts to velocity parts.

     `maybe_call_fn_and_grads` returns `(value, grads)`; the gradients of
     `kinetic_energy_fn` w.r.t. the momentum serve as the velocity.
     """
     energy_and_grads = mcmc_util.maybe_call_fn_and_grads(
         kinetic_energy_fn, half_next_momentum_parts)
     return energy_and_grads[1]