Example #1
    def bootstrap_results(self, init_state):
        with tf.name_scope(
                mcmc_util.make_name(self.name,
                                    'diagonal_mass_matrix_adaptation',
                                    'bootstrap_results')):
            if isinstance(self.initial_running_variance,
                          sample_stats.RunningVariance):
                variance_parts = [self.initial_running_variance]
            else:
                variance_parts = self.initial_running_variance

            diags = [
                variance_part.variance() for variance_part in variance_parts
            ]

            # Step inner results.
            inner_results = self.inner_kernel.bootstrap_results(init_state)
            # Set the momentum.
            batch_ndims = ps.rank(
                unnest.get_innermost(inner_results, 'target_log_prob'))
            init_state_parts = tf.nest.flatten(init_state)
            momentum_distribution = _make_momentum_distribution(
                diags, init_state_parts, batch_ndims)
            inner_results = self.momentum_distribution_setter_fn(
                inner_results, momentum_distribution)
            proposed = unnest.get_innermost(inner_results,
                                            'proposed_results',
                                            default=None)
            if proposed is not None:
                proposed = proposed._replace(
                    momentum_distribution=momentum_distribution)
                inner_results = unnest.replace_innermost(
                    inner_results, proposed_results=proposed)
            return DiagonalMassMatrixAdaptationResults(
                inner_results=inner_results, running_variance=variance_parts)
Example #2
 def test_atypical_nesting(self):
     results = FakeAtypicalNestingResults(
         unique_atypical_nesting=0, atypical_inner_results=FakeResults(1))
     self.assertTrue(unnest.has_nested(results, 'unique_atypical_nesting'))
     self.assertTrue(unnest.has_nested(results, 'unique_core'))
     self.assertFalse(hasattr(results, 'unique_core'))
     self.assertFalse(unnest.has_nested(results, 'foo'))
     self.assertEqual(
         results.unique_atypical_nesting,
         unnest.get_innermost(results, 'unique_atypical_nesting'))
     self.assertEqual(
         results.unique_atypical_nesting,
         unnest.get_outermost(results, 'unique_atypical_nesting'))
     self.assertEqual(results.atypical_inner_results.unique_core,
                      unnest.get_innermost(results, 'unique_core'))
     self.assertEqual(results.atypical_inner_results.unique_core,
                      unnest.get_outermost(results, 'unique_core'))
     self.assertRaises(AttributeError,
                       lambda: unnest.get_innermost(results, 'foo'))
     self.assertRaises(AttributeError,
                       lambda: unnest.get_outermost(results, 'foo'))
     self.assertIs(unnest.get_innermost(results, 'foo', SINGLETON),
                   SINGLETON)
     self.assertIs(unnest.get_outermost(results, 'foo', SINGLETON),
                   SINGLETON)
Example #3
def hmc_like_proposed_velocity_getter_fn(kernel_results):
  """Getter for `proposed_velocity` so it can be inspected."""
  # TODO(b/169898033): This only works with the standard kinetic energy term.
  proposed_velocity = unnest.get_innermost(kernel_results, 'final_momentum')
  proposed_state = unnest.get_innermost(kernel_results, 'proposed_state')
  # proposed_velocity has the wrong structure when state is a scalar.
  return tf.nest.pack_sequence_as(proposed_state,
                                  tf.nest.flatten(proposed_velocity))
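
The repacking in the last line matters because `final_momentum` is typically carried as a list of parts even when the state is a single tensor. A minimal standalone sketch (not part of the TFP source) of what `tf.nest.pack_sequence_as` does here:

import tensorflow as tf

# The state is a bare tensor (a single leaf), but the momentum is kept
# as a one-element list of parts.
proposed_state = tf.constant(0.5)
proposed_velocity = [tf.constant(1.25)]

# Flatten the list, then rebuild it with the state's structure, yielding
# a bare tensor whose structure matches `proposed_state`.
matched = tf.nest.pack_sequence_as(proposed_state,
                                   tf.nest.flatten(proposed_velocity))
assert not isinstance(matched, list)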
Example #4
 def test_change_core_params(self):
     builder = kernel_builder.KernelBuilder.make(lambda x: x)
     builder = builder.hmc(num_leapfrog_steps=3)
     kernel = builder.build()
     found_steps = unnest.get_innermost(kernel, 'num_leapfrog_steps')
     self.assertEqual(found_steps, 3)
     builder = builder.hmc(num_leapfrog_steps=5)
     kernel = builder.build()
     found_steps = unnest.get_innermost(kernel, 'num_leapfrog_steps')
     self.assertEqual(found_steps, 5)
Example #5
 def trace_fn(_, pkr):
   energy_diff = pkr.inner_results.inner_results.inner_results.log_accept_ratio
   has_divergence = jnp.abs(energy_diff) > max_energy_diff
   return (
       unnest.get_innermost(pkr, 'target_log_prob'),
       unnest.get_innermost(pkr, 'num_leapfrog_steps'),
       has_divergence,
       energy_diff,
       pkr.inner_results.inner_results.inner_results.log_accept_ratio,
       pkr.inner_results.inner_results.max_trajectory_length,
       unnest.get_innermost(pkr, 'step_size'),
   )
Example #6
 def trace_fn(_, pkr):
     return {
         'step_size': unnest.get_innermost(pkr, 'step_size'),
         'mean_trajectory_length':
             unnest.get_innermost(pkr, 'max_trajectory_length') / 2.,
         'principal_component':
             unnest.get_innermost(pkr, 'ema_principal_component'),
         'variance': unnest.get_innermost(pkr, 'ema_variance'),
         'num_leapfrog_steps': unnest.get_innermost(pkr, 'num_leapfrog_steps'),
     }
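
For context, trace functions like this are handed to the sampling driver, which calls them once per step with the chain state and the (possibly deeply wrapped) kernel results. A hedged usage sketch with a simple adaptive kernel (the model and parameters here are illustrative, not taken from the example above):

import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.internal import unnest

kernel = tfp.mcmc.SimpleStepSizeAdaptation(
    tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=lambda x: -x**2 / 2.,
        step_size=0.1,
        num_leapfrog_steps=3),
    num_adaptation_steps=50)

# `unnest.get_innermost` digs through the adaptation wrapper to reach the
# HMC results that actually carry `step_size`.
samples, step_sizes = tfp.mcmc.sample_chain(
    num_results=100,
    current_state=tf.zeros([]),
    kernel=kernel,
    trace_fn=lambda _, pkr: unnest.get_innermost(pkr, 'step_size'),
    seed=42)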
Example #7
 def test_flat(self):
   results = FakeResults(0)
   self.assertTrue(unnest.has_nested(results, 'unique_core'))
   self.assertFalse(unnest.has_nested(results, 'foo'))
   self.assertEqual(results.unique_core,
                    unnest.get_innermost(results, 'unique_core'))
   self.assertEqual(results.unique_core,
                    unnest.get_outermost(results, 'unique_core'))
   self.assertRaises(
       AttributeError, lambda: unnest.get_innermost(results, 'foo'))
   self.assertRaises(
       AttributeError, lambda: unnest.get_outermost(results, 'foo'))
   self.assertIs(unnest.get_innermost(results, 'foo', SINGLETON), SINGLETON)
   self.assertIs(unnest.get_outermost(results, 'foo', SINGLETON), SINGLETON)
Example #8
def _get_field(kernel_results, field_name):
    try:
        return unnest.get_innermost(kernel_results, field_name)
    except AttributeError:
        msg = _kernel_result_not_implemented_message_template.format(
            kernel_results, field_name)
        raise REMCFieldNotFoundError(msg)
Example #9
    def test_supply_single_step_size(self):
        stream = test_util.test_seed_stream()

        jd_model = tfd.JointDistributionNamed({
            'a': tfd.Normal(0., 1.),
            'b': tfd.MultivariateNormalDiag(
                loc=tf.zeros(3), scale_diag=tf.constant([1., 2., 3.]))
        })

        init_step_size = 1.
        _, traced_step_size = self.evaluate(
            tfp.experimental.mcmc.windowed_adaptive_hmc(
                1,
                jd_model,
                num_adaptation_steps=25,
                n_chains=20,
                init_step_size=init_step_size,
                num_leapfrog_steps=5,
                discard_tuning=False,
                trace_fn=lambda *args: unnest.get_innermost(
                    args[-1], 'step_size'),
                seed=stream()))

        self.assertEqual((25 + 1, ), traced_step_size.shape)
        self.assertAllClose(1., traced_step_size[0])
Example #10
    def test_supply_full_step_size(self):
        stream = test_util.test_seed_stream()

        jd_model = tfd.JointDistributionNamed({
            'a': tfd.Normal(0., 1.),
            'b': tfd.MultivariateNormalDiag(
                loc=tf.zeros(3), scale_diag=tf.constant([1., 2., 3.]))
        })

        init_step_size = {
            'a': tf.reshape(tf.linspace(1., 2., 20), (20, 1)),
            'b': tf.reshape(tf.linspace(1., 2., 60), (20, 3))
        }

        _, actual_step_size = tfp.experimental.mcmc.windowed_adaptive_hmc(
            1,
            jd_model,
            num_adaptation_steps=100,
            n_chains=20,
            init_step_size=init_step_size,
            num_leapfrog_steps=5,
            discard_tuning=False,
            trace_fn=lambda *args: unnest.get_innermost(args[-1], 'step_size'),
            seed=stream(),
        )

        # Gets a newaxis because step size needs to have an event dimension.
        self.assertAllCloseNested([init_step_size['a'], init_step_size['b']],
                                  [j[0] for j in actual_step_size])
Example #11
    def test_supply_single_step_size(self):
        stream = test_util.test_seed_stream()

        jd_model = tfd.JointDistributionNamed({
            'a': tfd.Normal(0., 1.),
            'b': tfd.MultivariateNormalDiag(
                loc=tf.zeros(3), scale_diag=tf.constant([1., 2., 3.]))
        })

        init_step_size = 1.
        _, actual_step_size = tfp.experimental.mcmc.windowed_adaptive_hmc(
            1,
            jd_model,
            num_adaptation_steps=100,
            n_chains=20,
            init_step_size=init_step_size,
            num_leapfrog_steps=5,
            discard_tuning=False,
            trace_fn=lambda *args: unnest.get_innermost(args[-1], 'step_size'),
            seed=stream(),
        )

        actual_step = [j[0] for j in actual_step_size]
        expected_step = [1., 1.]
        self.assertAllCloseNested(expected_step, actual_step)
Example #12
def hmc_like_log_accept_prob_getter_fn(kernel_results):
  """Getter for the log acceptance probability so it can be inspected."""
  log_accept_ratio = unnest.get_innermost(kernel_results, 'log_accept_ratio')
  safe_accept_ratio = tf.where(
      tf.math.is_finite(log_accept_ratio),
      log_accept_ratio,
      tf.constant(-np.inf, dtype=log_accept_ratio.dtype))
  return tf.minimum(safe_accept_ratio, 0.)
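
A small numeric sketch (standalone) of the clamping above: non-finite log accept ratios are replaced with -inf, and the result is capped at zero so that exponentiating it always yields a valid acceptance probability in [0, 1].

import numpy as np
import tensorflow as tf

log_accept_ratio = tf.constant([0.5, -2.0, np.nan, np.inf])
# NaN and inf entries are treated as "never accept".
safe = tf.where(tf.math.is_finite(log_accept_ratio),
                log_accept_ratio,
                tf.constant(-np.inf, dtype=log_accept_ratio.dtype))
print(tf.minimum(safe, 0.).numpy())  # [  0.  -2. -inf -inf]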
Example #13
    def bootstrap_results(self, init_state):
        with tf.name_scope(
                mcmc_util.make_name(self.name,
                                    'diagonal_mass_matrix_adaptation',
                                    'bootstrap_results')):
            # Step inner results.
            inner_results = self.inner_kernel.bootstrap_results(init_state)

            # Bootstrap the results.
            results = self._bootstrap_from_inner_results(
                init_state, inner_results)
            if self.num_estimation_steps is not None:
                # We only update the momentum at the end of adaptation phase,
                # so we do not need to set the momentum here.
                return results

            # Set the momentum.
            diags = [
                variance_part.variance()
                for variance_part in results.running_variance
            ]
            inner_results = results.inner_results
            batch_shape = ps.shape(
                unnest.get_innermost(inner_results, 'target_log_prob'))
            init_state_parts = tf.nest.flatten(init_state)
            momentum_distribution = preconditioning_utils.make_momentum_distribution(
                init_state_parts,
                batch_shape,
                diags,
                shard_axis_names=self.experimental_shard_axis_names)
            inner_results = self.momentum_distribution_setter_fn(
                inner_results, momentum_distribution)
            proposed = unnest.get_innermost(inner_results,
                                            'proposed_results',
                                            default=None)
            if proposed is not None:
                proposed = proposed._replace(
                    momentum_distribution=momentum_distribution)
                inner_results = unnest.replace_innermost(
                    inner_results, proposed_results=proposed)
            results = results._replace(inner_results=inner_results)
            return results
Example #14
def hmc_like_proposed_velocity_getter_fn(kernel_results):
    """Getter for `proposed_velocity` so it can be inspected."""
    final_momentum = unnest.get_innermost(kernel_results, 'final_momentum')
    proposed_state = unnest.get_innermost(kernel_results, 'proposed_state')

    momentum_distribution = unnest.get_innermost(kernel_results,
                                                 'momentum_distribution',
                                                 default=None)
    if momentum_distribution is None:
        proposed_velocity = final_momentum
    else:
        momentum_log_prob = getattr(momentum_distribution,
                                    '_log_prob_unnormalized',
                                    momentum_distribution.log_prob)
        kinetic_energy_fn = lambda *args: -momentum_log_prob(*args)
        _, proposed_velocity = mcmc_util.maybe_call_fn_and_grads(
            kinetic_energy_fn, final_momentum)
    # proposed_velocity has the wrong structure when state is a scalar.
    return tf.nest.pack_sequence_as(proposed_state,
                                    tf.nest.flatten(proposed_velocity))
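
The gradient trick above relies on velocity being the derivative of the kinetic energy with respect to momentum. A minimal sketch (assuming the standard Gaussian kinetic energy, and using GradientTape rather than TFP's internal mcmc_util helper):

import tensorflow as tf
import tensorflow_probability as tfp

momentum = tf.constant([0.3, -1.2])
dist = tfp.distributions.MultivariateNormalDiag(loc=tf.zeros(2))

with tf.GradientTape() as tape:
  tape.watch(momentum)
  # K(p) = -log N(p; 0, I) = 0.5 * p^T p + const.
  kinetic_energy = -dist.log_prob(momentum)

# For an identity mass matrix, dK/dp recovers the momentum itself; with a
# preconditioner it is M^{-1} p, i.e. the velocity.
velocity = tape.gradient(kinetic_energy, momentum)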
Example #15
 def test_deeply_nested(self):
   results = _build_deeply_nested(0, 1, 2, 3, 4)
   self.assertTrue(unnest.has_nested(results, 'unique_nesting'))
   self.assertTrue(unnest.has_nested(results, 'unique_atypical_nesting'))
   self.assertTrue(unnest.has_nested(results, 'unique_core'))
   self.assertFalse(hasattr(results, 'unique_core'))
   self.assertFalse(unnest.has_nested(results, 'foo'))
   self.assertEqual(unnest.get_innermost(results, 'unique_nesting'), 3)
   self.assertEqual(unnest.get_outermost(results, 'unique_nesting'), 0)
   self.assertEqual(
       unnest.get_innermost(results, 'unique_atypical_nesting'), 2)
   self.assertEqual(
       unnest.get_outermost(results, 'unique_atypical_nesting'), 1)
   self.assertEqual(unnest.get_innermost(results, 'unique_core'), 4)
   self.assertEqual(unnest.get_outermost(results, 'unique_core'), 4)
   self.assertRaises(
       AttributeError, lambda: unnest.get_innermost(results, 'foo'))
   self.assertRaises(
       AttributeError, lambda: unnest.get_outermost(results, 'foo'))
   self.assertIs(unnest.get_innermost(results, 'foo', SINGLETON), SINGLETON)
   self.assertIs(unnest.get_outermost(results, 'foo', SINGLETON), SINGLETON)
Example #16
def trace_results_fn(_, results):
    """Packs results into a dictionary"""
    results_dict = {}
    root_results = results.inner_results

    step_size = tf.convert_to_tensor(
        unnest.get_outermost(root_results[0], "step_size")
    )

    results_dict["hmc"] = {
        "is_accepted": unnest.get_innermost(root_results[0], "is_accepted"),
        "target_log_prob": unnest.get_innermost(
            root_results[0], "target_log_prob"
        ),
        "step_size": step_size,
    }

    def get_move_results(results):
        return {
            "is_accepted": results.is_accepted,
            "target_log_prob": results.accepted_results.target_log_prob,
            "proposed_delta": tf.stack(
                [
                    results.accepted_results.m,
                    results.accepted_results.t,
                    results.accepted_results.delta_t,
                    results.accepted_results.x_star,
                ]
            ),
        }

    res1 = root_results[1].inner_results
    results_dict["move/S->E"] = get_move_results(res1[0])
    results_dict["move/E->I"] = get_move_results(res1[1])
    results_dict["occult/S->E"] = get_move_results(res1[2])
    results_dict["occult/E->I"] = get_move_results(res1[3])

    return results_dict
Example #17
 def sample_trace_fn(_, pkr):
   return (
       unnest.get_innermost(pkr, 'target_log_prob'),
       unnest.get_innermost(pkr, 'leapfrogs_taken'),
       unnest.get_innermost(pkr, 'has_divergence'),
       unnest.get_innermost(pkr, 'energy'),
       unnest.get_innermost(pkr, 'log_accept_ratio'),
       unnest.get_innermost(pkr, 'reach_max_depth'),
   )
Example #18
def default_nuts_trace_fn(state, bijector, is_adapting, pkr):
    """Trace function for `windowed_adaptive_nuts` providing standard diagnostics.

    Specifically, these match up with a number of diagnostics used by ArviZ [1],
    to make diagnostics and analysis easier. The names used follow those used
    in TensorFlow Probability, and will need to be mapped to those used in the
    ArviZ schema.

    References:
      [1]: Kumar, R., Carroll, C., Hartikainen, A., & Martin, O. (2019). ArviZ a
        unified library for exploratory analysis of Bayesian models in Python.
        Journal of Open Source Software, 4(33), 1143.

    Args:
      state: tf.Tensor. Current sampler state, flattened and unconstrained.
      bijector: tfb.Bijector. This can be used to untransform the state to
        something with the same shape as will be returned.
      is_adapting: bool. Whether this is an adapting step, or may be treated
        as a valid MCMC draw.
      pkr: UncalibratedPreconditionedHamiltonianMonteCarloKernelResults.
        Kernel results from this iteration.

    Returns:
      dict with sampler statistics.
    """
    del state, bijector  # Unused

    energy_diff = unnest.get_innermost(pkr, 'log_accept_ratio')
    return {
        'step_size': unnest.get_innermost(pkr, 'step_size'),
        'tune': is_adapting,
        'target_log_prob': unnest.get_innermost(pkr, 'target_log_prob'),
        'diverging': unnest.get_innermost(pkr, 'has_divergence'),
        'accept_ratio': tf.minimum(tf.ones_like(energy_diff),
                                   tf.exp(energy_diff)),
        'variance_scaling':
            unnest.get_innermost(pkr, 'momentum_distribution').variance(),
        'n_steps': unnest.get_innermost(pkr, 'leapfrogs_taken'),
        'is_accepted': unnest.get_innermost(pkr, 'is_accepted'),
    }
Example #19
    def test_sequential_step_size(self):
        stream = test_util.test_seed_stream()

        jd_model = tfd.JointDistributionSequential(
            [tfd.HalfNormal(scale=1., name=f'dist_{idx}') for idx in range(4)])
        init_step_size = [1., 2., 3.]
        _, actual_step_size = tfp.experimental.mcmc.windowed_adaptive_nuts(
            1,
            jd_model,
            num_adaptation_steps=100,
            n_chains=20,
            init_step_size=init_step_size,
            discard_tuning=False,
            trace_fn=lambda *args: unnest.get_innermost(args[-1], 'step_size'),
            dist_3=tf.constant(1.),
            seed=stream(),
        )

        self.assertAllCloseNested(init_step_size,
                                  [j[0] for j in actual_step_size])
Example #20
    def testTooFewChains(self, use_static_shape):
        state = tf.constant([[0.1, 0.2]])

        def tlp_fn(x):
            return tf1.placeholder_with_default(
                tf.reduce_sum(x, -1),
                shape=[1] if use_static_shape else [None])

        kernel = tfp.experimental.mcmc.SNAPERHamiltonianMonteCarlo(
            tlp_fn,
            step_size=0.1,
            num_adaptation_steps=2,
            num_mala_steps=100,
            validate_args=True,
        )

        with self.assertRaisesRegex(Exception,
                                    'SNAPERHMC requires at least 2 chains'):
            self.evaluate(
                unnest.get_innermost(kernel.bootstrap_results(state),
                                     'target_log_prob'))
Example #21
def hmc_like_step_size_getter_fn(kernel_results):
    """Getter for `step_size` so it can be inspected."""
    return unnest.get_innermost(kernel_results, 'step_size')
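
Getters like this one work because `unnest` knows how to walk wrapped kernel results. A toy sketch (using TFP's internal `unnest` module, as all the examples on this page do) of the traversal: fields that hold nested results are followed recursively, `get_innermost` returns the deepest match and `get_outermost` the shallowest:

import collections

from tensorflow_probability.python.internal import unnest

OuterResults = collections.namedtuple('OuterResults',
                                      ['inner_results', 'step_size'])
InnerResults = collections.namedtuple('InnerResults', ['step_size'])

results = OuterResults(inner_results=InnerResults(step_size=0.1),
                       step_size=0.5)
assert unnest.get_innermost(results, 'step_size') == 0.1  # deepest wins
assert unnest.get_outermost(results, 'step_size') == 0.5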
Example #22
    def one_step(self, current_state, previous_kernel_results, seed=None):
        with tf.name_scope(
                mcmc_util.make_name(self.name,
                                    'diagonal_mass_matrix_adaptation',
                                    'one_step')):
            variance_parts = previous_kernel_results.running_variance
            diags = [
                variance_part.variance() for variance_part in variance_parts
            ]
            # Set the momentum.
            batch_ndims = ps.rank(
                unnest.get_innermost(previous_kernel_results,
                                     'target_log_prob'))
            state_parts = tf.nest.flatten(current_state)
            new_momentum_distribution = _make_momentum_distribution(
                diags, state_parts, batch_ndims)
            inner_results = self.momentum_distribution_setter_fn(
                previous_kernel_results.inner_results,
                new_momentum_distribution)

            # Step the inner kernel.
            inner_kwargs = {} if seed is None else dict(seed=seed)
            new_state, new_inner_results = self.inner_kernel.one_step(
                current_state, inner_results, **inner_kwargs)
            new_state_parts = tf.nest.flatten(new_state)
            new_variance_parts = []
            for variance_part, diag, state_part in zip(variance_parts, diags,
                                                       new_state_parts):
                # Compute new variance for each variance part, accounting for partial
                # batching of the variance calculation across chains (ie, some, all, or
                # none of the chains may share the estimated mass matrix).
                #
                # For example, say
                #
                # state_part has shape       [2, 3, 4] + [5, 6]  (batch + event)
                # variance_part has shape          [4] + [5, 6]
                # log_prob has shape         [2, 3, 4]
                #
                # i.e., we have a batch of chains of shape [2, 3, 4], and 4 mass
                # matrices, each being shared across a [2, 3]-batch of chains. Note this
                # division is inferred from the shapes of the state part, the log_prob,
                # and the user-provided initial running variances.
                #
                # Until RunningVariance supports rank > 1 chunking, we need to flatten
                # the states that go into updating the variance estimates. In the above
                # example, `state_part` will be reshaped to `[6, 4, 5, 6]`, and
                # fed to `RunningVariance.update(state_part, axis=0)`, recording
                # 6 new observations in the running variance calculation.
                # `RunningVariance.variance()` will then be of shape `[4, 5, 6]`, and
                # the resulting momentum distribution will have batch shape of
                # `[2, 3, 4]` and event_shape of `[5, 6]`, matching the state_part.
                state_rank = ps.rank(state_part)
                variance_rank = ps.rank(diag)
                num_reduce_dims = state_rank - variance_rank

                state_part_shape = ps.shape(state_part)
                # This reshape adds a 1 when reduce_dims==0, and collapses all the lead
                # dimensions to a single one otherwise.
                reshaped_state = ps.reshape(
                    state_part,
                    ps.concat(
                        [[ps.reduce_prod(state_part_shape[:num_reduce_dims])],
                         state_part_shape[num_reduce_dims:]],
                        axis=0))

                # The `axis=0` here removes the leading dimension we got from the
                # reshape above, so the new_variance_parts have the correct shape again.
                new_variance_parts.append(
                    variance_part.update(reshaped_state, axis=0))

            new_kernel_results = previous_kernel_results._replace(
                inner_results=new_inner_results,
                running_variance=new_variance_parts)

            return new_state, new_kernel_results
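
A standalone numeric sketch of the flattening described in the long comment above, with the same hypothetical shapes: a [2, 3, 4] batch of chains with [5, 6] events and 4 shared variance estimates gets its shared [2, 3] chain dimensions collapsed into 6 leading "observations":

import tensorflow as tf
from tensorflow_probability.python.internal import prefer_static as ps

state_part = tf.zeros([2, 3, 4, 5, 6])  # batch [2, 3, 4] + event [5, 6]
diag = tf.ones([4, 5, 6])               # one variance per group of chains

num_reduce_dims = ps.rank(state_part) - ps.rank(diag)  # 2
state_part_shape = ps.shape(state_part)
# Collapse the leading [2, 3] dimensions into a single axis of size 6.
reshaped_state = ps.reshape(
    state_part,
    ps.concat([[ps.reduce_prod(state_part_shape[:num_reduce_dims])],
               state_part_shape[num_reduce_dims:]], axis=0))
print(reshaped_state.shape)  # (6, 4, 5, 6)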
Example #23
def hmc_like_proposed_state_getter_fn(kernel_results):
  """Getter for `proposed_state` so it can be inspected."""
  return unnest.get_innermost(kernel_results, 'proposed_state')
Example #24
def hmc_like_num_leapfrog_steps_getter_fn(kernel_results):
  """Getter for `num_leapfrog_steps` so it can be inspected."""
  return unnest.get_innermost(kernel_results, 'num_leapfrog_steps')
Example #25
def hmc_like_momentum_distribution_getter_fn(kernel_results):
    """Getter for `momentum_distribution` so it can be updated."""
    return unnest.get_innermost(kernel_results, 'momentum_distribution')
Example #26
    def one_step(self, current_state, previous_kernel_results, seed=None):
        with tf.name_scope(
                mcmc_util.make_name(self.name,
                                    'snaper_hamiltonian_monte_carlo',
                                    'one_step')):
            inner_results = previous_kernel_results.inner_results

            batch_shape = ps.shape(
                unnest.get_innermost(previous_kernel_results,
                                     'target_log_prob'))
            reduce_axes = ps.range(0, ps.size(batch_shape))
            step = inner_results.step
            state_ema_points = previous_kernel_results.state_ema_points

            kernel = self._make_kernel(
                batch_shape=batch_shape,
                step=step,
                state_ema_points=state_ema_points,
                state=current_state,
                mean=previous_kernel_results.ema_mean,
                variance=previous_kernel_results.ema_variance,
                principal_component=(
                    previous_kernel_results.ema_principal_component),
            )

            inner_results = unnest.replace_innermost(
                inner_results,
                momentum_distribution=(
                    kernel.inner_kernel.parameters['momentum_distribution']),
            )

            seed = samplers.sanitize_seed(seed)
            state_parts, inner_results = kernel.one_step(
                tf.nest.flatten(current_state),
                inner_results,
                seed=seed,
            )

            state = tf.nest.pack_sequence_as(current_state, state_parts)

            state_ema_points, ema_mean, ema_variance = self._update_state_ema(
                reduce_axes=reduce_axes,
                state=state,
                step=step,
                state_ema_points=state_ema_points,
                ema_mean=previous_kernel_results.ema_mean,
                ema_variance=previous_kernel_results.ema_variance,
            )

            (principal_component_ema_points,
             ema_principal_component) = self._update_principal_component_ema(
                 reduce_axes=reduce_axes,
                 state=state,
                 step=step,
                 principal_component_ema_points=(
                     previous_kernel_results.principal_component_ema_points),
                 ema_principal_component=(
                     previous_kernel_results.ema_principal_component),
             )

            kernel_results = previous_kernel_results._replace(
                inner_results=inner_results,
                ema_mean=ema_mean,
                ema_variance=ema_variance,
                state_ema_points=state_ema_points,
                ema_principal_component=ema_principal_component,
                principal_component_ema_points=principal_component_ema_points,
                seed=seed,
            )

            return state, kernel_results