Example #1
0
 def test_flat_replace(self):
     """Replacement on a single-level `FakeResults` structure.

     Both `replace_innermost` and `replace_outermost` update `unique_core`.
     With `return_unused=True` they also return the kwargs that matched no
     field (`foo`); without it, such a kwarg raises `ValueError`.
     """
     flat = FakeResults(0)

     replaced_inner = unnest.replace_innermost(flat, unique_core=1)
     self.assertEqual(replaced_inner.unique_core, 1)
     self.assertEqual(
         unnest.replace_innermost(
             flat, return_unused=True, unique_core=1, foo=2),
         (FakeResults(1), {'foo': 2}))

     replaced_outer = unnest.replace_outermost(flat, unique_core=2)
     self.assertEqual(replaced_outer.unique_core, 2)
     self.assertEqual(
         unnest.replace_outermost(
             flat, return_unused=True, unique_core=2, foo=3),
         (FakeResults(2), {'foo': 3}))

     # `foo` matches no field, so without `return_unused` it is an error.
     with self.assertRaises(ValueError):
         unnest.replace_innermost(flat, unique_core=1, foo=1)
     with self.assertRaises(ValueError):
         unnest.replace_outermost(flat, unique_core=1, foo=1)
Example #2
0
 def test_atypical_nesting_replace(self):
   """Innermost and outermost replacement agree on this two-level structure.

   `unique_atypical_nesting` is set on the outer `FakeAtypicalNestingResults`
   and `unique_core` on the wrapped `FakeResults`. Both replace functions are
   expected to produce the same results. A kwarg matching no field (`foo`)
   raises `ValueError` unless `return_unused=True`.
   """

   def make(atypical, core):
     return FakeAtypicalNestingResults(
         unique_atypical_nesting=atypical,
         atypical_inner_results=FakeResults(core))

   start = make(0, 1)
   replacers = (unnest.replace_innermost, unnest.replace_outermost)

   for replace in replacers:
     self.assertEqual(replace(start, unique_atypical_nesting=2), make(2, 1))
     self.assertEqual(replace(start, unique_core=2), make(0, 2))
     self.assertEqual(
         replace(start, unique_atypical_nesting=2, unique_core=3),
         make(2, 3))
     self.assertEqual(
         replace(start, return_unused=True,
                 unique_atypical_nesting=2, unique_core=3, foo=4),
         (make(2, 3), {'foo': 4}))

   for replace in replacers:
     with self.assertRaises(ValueError):
       replace(start, unique_core=1, foo=1)
Example #3
0
def hmc_like_momentum_distribution_setter_fn(kernel_results, new_distribution):
    """Return `kernel_results` with `momentum_distribution` replaced.

    Exists so the momentum distribution can be adapted.
    """
    # unnest.replace_innermost prefers descending into `accepted_results`, so
    # this updates `accepted_results.momentum_distribution`.
    updates = {'momentum_distribution': new_distribution}
    return unnest.replace_innermost(kernel_results, **updates)
Example #4
0
    def bootstrap_results(self, init_state):
        """Bootstrap inner results and install a momentum distribution.

        Args:
          init_state: Initial chain state, passed unchanged to
            `inner_kernel.bootstrap_results`.

        Returns:
          `DiagonalMassMatrixAdaptationResults` holding the bootstrapped inner
          results (with the momentum distribution set) and the running-variance
          estimators.
        """
        scope_name = mcmc_util.make_name(
            self.name, 'diagonal_mass_matrix_adaptation', 'bootstrap_results')
        with tf.name_scope(scope_name):
            # Accept either a single RunningVariance or a collection of them.
            running_variances = self.initial_running_variance
            if isinstance(running_variances, sample_stats.RunningVariance):
                running_variances = [running_variances]

            diags = [rv.variance() for rv in running_variances]

            inner_results = self.inner_kernel.bootstrap_results(init_state)

            # Batch rank is taken from the innermost target_log_prob.
            target_log_prob = unnest.get_innermost(
                inner_results, 'target_log_prob')
            batch_ndims = ps.rank(target_log_prob)
            state_parts = tf.nest.flatten(init_state)
            momentum_distribution = _make_momentum_distribution(
                diags, state_parts, batch_ndims)
            inner_results = self.momentum_distribution_setter_fn(
                inner_results, momentum_distribution)

            # If the inner results carry `proposed_results`, give them the same
            # momentum distribution.
            proposed = unnest.get_innermost(
                inner_results, 'proposed_results', default=None)
            if proposed is not None:
                inner_results = unnest.replace_innermost(
                    inner_results,
                    proposed_results=proposed._replace(
                        momentum_distribution=momentum_distribution))

            return DiagonalMassMatrixAdaptationResults(
                inner_results=inner_results,
                running_variance=running_variances)
Example #5
0
def _update_field(kernel_results, field_name, value):
    """Return `kernel_results` with the innermost `field_name` set to `value`.

    Args:
      kernel_results: Kernel results structure to update.
      field_name: Name of the field to replace.
      value: New value for that field.

    Returns:
      The updated kernel results produced by `unnest.replace_innermost`.

    Raises:
      REMCFieldNotFoundError: If `unnest.replace_innermost` raises
        `ValueError`, e.g. when no field named `field_name` is found. The
        original `ValueError` is chained as `__cause__`.
    """
    try:
        return unnest.replace_innermost(kernel_results, **{field_name: value})
    except ValueError as err:
        msg = _kernel_result_not_implemented_message_template.format(
            kernel_results, field_name)
        # Chain explicitly so the underlying unnest error stays visible.
        raise REMCFieldNotFoundError(msg) from err
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Runs one iteration of NeuTra.

        Args:
          current_state: `Tensor` or Python `list` of `Tensor`s representing
            the current state(s) of the Markov chain(s). The first `r`
            dimensions index independent chains,
            `r = tf.rank(target_log_prob_fn(*current_state))`.
          previous_kernel_results: `collections.namedtuple` containing
            `Tensor`s representing values from previous calls to this function
            (or from the `bootstrap_results` function).
          seed: Optional seed for reproducible sampling.

        Returns:
          next_state: `Tensor` or Python list of `Tensor`s representing the
            state(s) of the Markov chain(s) after exactly one step. Same type
            and shape as `current_state`.
          kernel_results: `collections.namedtuple` of internal calculations
            used to advance the chain.
        """
        # Derive the leapfrog step count from the step size stored in the
        # previous results and write it into the innermost results.
        num_steps = self._num_leapfrog_steps(
            previous_kernel_results.new_step_size)
        updated_results = unnest.replace_innermost(
            previous_kernel_results, num_leapfrog_steps=num_steps)

        # The wrapped kernel steps on the flattened state; the result is
        # unflattened before returning.
        flat_next_state, kernel_results = self._kernel.one_step(
            self._flatten_state(current_state), updated_results, seed=seed)
        return self._unflatten_state(flat_next_state), kernel_results
Example #7
0
  def test_sets_kinetic_energy(self):
    """A manually set momentum distribution drives the kernel's leapfrog step.

    After `momentum_distribution` is overridden with the target distribution,
    one kernel step (one leapfrog step) must match a manual
    `SimpleLeapfrogIntegrator` step with kinetic energy `-target.log_prob`.
    """
    target = tfd.MultivariateNormalDiag(scale_diag=tf.constant([0.1, 10.]))
    eps = 0.1
    hmc = tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo(
        target_log_prob_fn=target.log_prob,
        step_size=eps,
        num_leapfrog_steps=1,
        store_parameters_in_results=True)
    x0 = tf.constant([0.1, 0.1])
    bootstrapped = hmc.bootstrap_results(x0)
    bootstrapped = unnest.replace_innermost(
        bootstrapped, momentum_distribution=target)

    # Kernel step; values are evaluated here for graph-mode consistency.
    _, stepped = hmc.one_step(x0, bootstrapped, seed=test_util.test_seed())
    proposal = stepped.proposed_results
    (start_momentum,
     start_grads,
     kernel_state,
     kernel_momentum,
     kernel_log_prob,
     kernel_grads) = self.evaluate([
         proposal.initial_momentum,
         stepped.accepted_results.grads_target_log_prob,
         stepped.proposed_state,
         proposal.final_momentum,
         proposal.target_log_prob,
         proposal.grads_target_log_prob])

    # Reference step from the same initial momentum, evaluated immediately.
    integrator = tfp.mcmc.internal.leapfrog_integrator.SimpleLeapfrogIntegrator(
        target_fn=target.log_prob,
        step_sizes=[eps],
        num_steps=1)
    (ref_momentum,
     ref_state,
     ref_log_prob,
     ref_grads) = self.evaluate(integrator(
         momentum_parts=start_momentum,
         state_parts=[x0],
         target=target.log_prob(x0),
         target_grad_parts=start_grads,
         kinetic_energy_fn=lambda momentum: -target.log_prob(momentum)))

    comparisons = [
        (kernel_state, ref_state[0]),
        (kernel_momentum, ref_momentum),
        (kernel_log_prob, ref_log_prob),
        (kernel_grads, ref_grads),
    ]
    for actual, expected in comparisons:
      self.assertAllClose(actual, expected)
Example #8
0
 def test_deeply_nested_replace(self):
     """Innermost/outermost replacement on a multi-level nested structure.

     Going by the expectations below (assumption: confirm against
     `_build_deeply_nested`), the helper's positional args 0 and 3 are
     `unique_nesting`, 1 and 2 are `unique_atypical_nesting`, and 4 is
     `unique_core`.
     """
     results = _build_deeply_nested(0, 1, 2, 3, 4)

     # Use assertEqual, not assertTrue. assertTrue(first, second) treats
     # `second` as the failure message and only checks that `first` is
     # truthy, so the expected structures were never compared.
     self.assertEqual(
         unnest.replace_innermost(results, unique_nesting=5),
         _build_deeply_nested(0, 1, 2, 5, 4))
     self.assertEqual(
         unnest.replace_outermost(results, unique_nesting=5),
         _build_deeply_nested(5, 1, 2, 3, 4))
     self.assertEqual(
         unnest.replace_innermost(results, unique_atypical_nesting=5),
         _build_deeply_nested(0, 1, 5, 3, 4))
     self.assertEqual(
         unnest.replace_outermost(results, unique_atypical_nesting=5),
         _build_deeply_nested(0, 5, 2, 3, 4))
     self.assertEqual(
         unnest.replace_innermost(results, unique_core=5),
         _build_deeply_nested(0, 1, 2, 3, 5))
     self.assertEqual(
         unnest.replace_outermost(results, unique_core=5),
         _build_deeply_nested(0, 1, 2, 3, 5))
     self.assertEqual(
         unnest.replace_innermost(results,
                                  unique_nesting=5,
                                  unique_atypical_nesting=6,
                                  unique_core=7),
         _build_deeply_nested(0, 1, 6, 5, 7))
     self.assertEqual(
         unnest.replace_innermost(results,
                                  return_unused=True,
                                  unique_nesting=5,
                                  unique_atypical_nesting=6,
                                  unique_core=7,
                                  foo=8),
         (_build_deeply_nested(0, 1, 6, 5, 7), {'foo': 8}))
     self.assertEqual(
         unnest.replace_outermost(results,
                                  unique_nesting=5,
                                  unique_atypical_nesting=6,
                                  unique_core=7),
         _build_deeply_nested(5, 6, 2, 3, 7))
     self.assertEqual(
         unnest.replace_outermost(results,
                                  return_unused=True,
                                  unique_nesting=5,
                                  unique_atypical_nesting=6,
                                  unique_core=7,
                                  foo=8),
         (_build_deeply_nested(5, 6, 2, 3, 7), {'foo': 8}))

     # `foo` matches no field, so without `return_unused` it is an error.
     with self.assertRaises(ValueError):
         unnest.replace_innermost(results, unique_core=1, foo=1)
     with self.assertRaises(ValueError):
         unnest.replace_outermost(results, unique_core=1, foo=1)
Example #9
0
    def bootstrap_results(self, init_state):
        """Bootstrap results and, if no estimation phase, set the momentum.

        Args:
          init_state: Initial chain state, passed to
            `inner_kernel.bootstrap_results`.

        Returns:
          The results from `_bootstrap_from_inner_results`. When
          `num_estimation_steps` is None, their `inner_results` also carry a
          momentum distribution built from the running variances. It is
          mirrored into `proposed_results` when present.
        """
        scope_name = mcmc_util.make_name(
            self.name, 'diagonal_mass_matrix_adaptation', 'bootstrap_results')
        with tf.name_scope(scope_name):
            bootstrapped = self._bootstrap_from_inner_results(
                init_state, self.inner_kernel.bootstrap_results(init_state))

            # With `num_estimation_steps` set, the momentum is only updated at
            # the end of the adaptation phase, so it is not set here.
            if self.num_estimation_steps is not None:
                return bootstrapped

            diags = [rv.variance() for rv in bootstrapped.running_variance]
            inner = bootstrapped.inner_results
            batch_shape = ps.shape(
                unnest.get_innermost(inner, 'target_log_prob'))
            momentum_distribution = (
                preconditioning_utils.make_momentum_distribution(
                    tf.nest.flatten(init_state),
                    batch_shape,
                    diags,
                    shard_axis_names=self.experimental_shard_axis_names))
            inner = self.momentum_distribution_setter_fn(
                inner, momentum_distribution)

            # Keep `proposed_results` (if any) on the same distribution.
            proposed = unnest.get_innermost(
                inner, 'proposed_results', default=None)
            if proposed is not None:
                inner = unnest.replace_innermost(
                    inner,
                    proposed_results=proposed._replace(
                        momentum_distribution=momentum_distribution))
            return bootstrapped._replace(inner_results=inner)
Example #10
0
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Takes one SNAPER HMC step and updates the running EMA estimates.

        Builds a kernel from the EMA mean, variance and principal component
        stored in `previous_kernel_results` and steps it once. It then folds
        the new state into the state EMA and the principal-component EMA.

        Args:
          current_state: Current chain state. It is flattened with
            `tf.nest.flatten` for the built kernel, and the result is packed
            back into the same structure.
          previous_kernel_results: Results from the previous call (or from
            `bootstrap_results`).
          seed: Optional seed. It is sanitized with `samplers.sanitize_seed`,
            used for the step, and stored in the returned results.

        Returns:
          state: New state with the same structure as `current_state`.
          kernel_results: `previous_kernel_results` with updated inner
            results, EMA statistics and seed.
        """
        with tf.name_scope(
                mcmc_util.make_name(self.name,
                                    'snaper_hamiltonian_monte_carlo',
                                    'one_step')):
            inner_results = previous_kernel_results.inner_results

            # Batch shape is the shape of the innermost target_log_prob.
            # reduce_axes spans all of its axes; presumably the EMA updates
            # below reduce over these (TODO confirm in _update_*_ema).
            batch_shape = ps.shape(
                unnest.get_innermost(previous_kernel_results,
                                     'target_log_prob'))
            reduce_axes = ps.range(0, ps.size(batch_shape))
            # `step` is read before stepping; the EMA updates use this value.
            step = inner_results.step
            state_ema_points = previous_kernel_results.state_ema_points

            # Build the kernel for this step from the current EMA estimates.
            kernel = self._make_kernel(
                batch_shape=batch_shape,
                step=step,
                state_ema_points=state_ema_points,
                state=current_state,
                mean=previous_kernel_results.ema_mean,
                variance=previous_kernel_results.ema_variance,
                principal_component=previous_kernel_results.
                ema_principal_component,
            )

            # Copy the momentum distribution the new kernel was built with
            # into the inner results before passing them to kernel.one_step.
            # NOTE(review): the next line accesses no protected member; the
            # pylint suppression looks stale.
            inner_results = unnest.replace_innermost(
                inner_results,
                momentum_distribution=(
                    kernel.inner_kernel.parameters['momentum_distribution']),  # pylint: disable=protected-access
            )

            seed = samplers.sanitize_seed(seed)
            state_parts, inner_results = kernel.one_step(
                tf.nest.flatten(current_state),
                inner_results,
                seed=seed,
            )

            state = tf.nest.pack_sequence_as(current_state, state_parts)

            # Fold the new state into the running mean/variance EMA.
            state_ema_points, ema_mean, ema_variance = self._update_state_ema(
                reduce_axes=reduce_axes,
                state=state,
                step=step,
                state_ema_points=state_ema_points,
                ema_mean=previous_kernel_results.ema_mean,
                ema_variance=previous_kernel_results.ema_variance,
            )

            # Fold the new state into the principal-component EMA.
            (principal_component_ema_points,
             ema_principal_component) = self._update_principal_component_ema(
                 reduce_axes=reduce_axes,
                 state=state,
                 step=step,
                 principal_component_ema_points=(
                     previous_kernel_results.principal_component_ema_points),
                 ema_principal_component=(
                     previous_kernel_results.ema_principal_component),
             )

            kernel_results = previous_kernel_results._replace(
                inner_results=inner_results,
                ema_mean=ema_mean,
                ema_variance=ema_variance,
                state_ema_points=state_ema_points,
                ema_principal_component=ema_principal_component,
                principal_component_ema_points=principal_component_ema_points,
                seed=seed,
            )

            return state, kernel_results
Example #11
0
def hmc_like_step_size_setter_fn(kernel_results, new_step_size):
    """Return `kernel_results` with its innermost `step_size` replaced.

    Exists so `step_size` can be adapted.
    """
    updated = unnest.replace_innermost(
        kernel_results,
        step_size=new_step_size,
    )
    return updated
def hmc_like_num_leapfrog_steps_setter_fn(
        kernel_results, new_num_leapfrog_steps):
    """Return `kernel_results` with its innermost `num_leapfrog_steps` replaced.

    Exists so `num_leapfrog_steps` can be adapted.
    """
    updated = unnest.replace_innermost(
        kernel_results,
        num_leapfrog_steps=new_num_leapfrog_steps,
    )
    return updated