def testWrapTransitionKernel(self):

    class TestKernel(tfp.mcmc.TransitionKernel):

      def one_step(self, current_state, previous_kernel_results):
        return [x + 1 for x in current_state], previous_kernel_results + 1

      def bootstrap_results(self, current_state):
        return sum(current_state)

      def is_calibrated(self):
        return True

    def kernel(state, pkr):
      return util_tfp.transition_kernel_wrapper(
          state, pkr, TestKernel())

    state = {'x': self._constant(0.), 'y': self._constant(1.)}
    kr = 1.
    (final_state, final_kr), _ = fun_mc.trace(
        (state, kr),
        kernel,
        2,
        trace_fn=lambda *args: (),
    )
    self.assertAllEqual({
        'x': 2.,
        'y': 3.
    }, util.map_tree(np.array, final_state))
    self.assertAllEqual(1. + 2., final_kr)
Example #2
    def testInteractiveIterationAxis1(self):
        def kernel(x):
            return x + 1, x

        state, trace = prefab.interactive_trace(
            0.,
            lambda x: fun_mc.trace(x, kernel, 5),
            20,
            iteration_axis=1,
            progress_bar_fn=None)

        self.assertAllClose(100., state)
        self.assertEqual([100], list(trace.shape))
        self.assertAllClose(99., trace[-1])
    def testPHMC(self):
        step_size = self._constant(0.2)
        num_steps = 2000
        num_leapfrog_steps = 10
        state = tf.ones([16, 2], dtype=self._dtype)

        base_mean = self._constant([2., 3.])
        base_scale = self._constant([2., 0.5])

        def target_log_prob_fn(x):
            return -tf.reduce_sum(
                0.5 * tf.square((x - base_mean) / base_scale), -1), ()

        def kernel(phmc_state, seed):
            phmc_seed, seed = util.split_seed(seed, 2)
            phmc_state, _ = prefab.persistent_hamiltonian_monte_carlo_step(
                phmc_state,
                step_size=step_size,
                num_integrator_steps=num_leapfrog_steps,
                target_log_prob_fn=target_log_prob_fn,
                noise_fraction=self._constant(0.5),
                mh_drift=self._constant(0.127),
                seed=phmc_seed)
            return (phmc_state, seed), phmc_state.state

        seed = self._make_seed(_test_seed())

        # Subtle: Unlike TF, JAX needs a data dependency from the inputs to outputs
        # for the jit to do anything.
        _, chain = tf.function(lambda state, seed: fun_mc.trace(  # pylint: disable=g-long-lambda
            state=(prefab.persistent_hamiltonian_monte_carlo_init(
                state, target_log_prob_fn), seed),
            fn=kernel,
            num_steps=num_steps))(state, seed)
        # Discard the warmup samples.
        chain = chain[1000:]

        sample_mean = tf.reduce_mean(chain, axis=[0, 1])
        sample_var = tf.math.reduce_variance(chain, axis=[0, 1])

        true_samples = util.random_normal(shape=[4096, 2],
                                          dtype=self._dtype,
                                          seed=seed) * base_scale + base_mean

        true_mean = tf.reduce_mean(true_samples, axis=0)
        true_var = tf.math.reduce_variance(true_samples, axis=0)

        self.assertAllClose(true_mean, sample_mean, rtol=0.1, atol=0.1)
        self.assertAllClose(true_var, sample_var, rtol=0.1, atol=0.1)
  def testSGAHMC(self):

    @tfd.JointDistributionCoroutine
    def model():
      x = yield Root(tfd.Normal(self._constant(0.), 1.))
      yield tfd.Sample(tfd.Normal(x, 1.), 2)

    def target_log_prob_fn(x):
      return model.log_prob(x), ()

    @tf.function
    def kernel(sga_hmc_state, step, seed):
      adapt = step < num_adapt_steps
      seed, hmc_seed = util.split_seed(seed, 2)

      sga_hmc_state, sga_hmc_extra = sga_hmc.stochastic_gradient_ascent_hmc_step(
          sga_hmc_state,
          scalar_step_size=0.1,
          target_log_prob_fn=target_log_prob_fn,
          criterion_fn=sga_hmc.chees_criterion,
          adapt=adapt,
          seed=hmc_seed,
      )

      return ((sga_hmc_state, step + 1, seed),
              sga_hmc_extra.trajectory_length_params.mean_trajectory_length())

    init_trajectory_length = self._constant(0.1)
    num_adapt_steps = 10
    _, trajectory_length = fun_mc.trace(
        (sga_hmc.stochastic_gradient_ascent_hmc_init(
            util.map_tree_up_to(
                model.dtype, lambda dtype, shape: tf.zeros(  # pylint: disable=g-long-lambda
                    (16,) + tuple(shape), dtype), model.dtype,
                model.event_shape),
            target_log_prob_fn,
            init_trajectory_length=init_trajectory_length), 0,
         self._make_seed(_test_seed())), kernel, num_adapt_steps + 2)

    # We expect it to increase as part of adaptation.
    self.assertAllGreater(trajectory_length[-1], init_trajectory_length)
    # After adaptation is done, the trajectory length should remain constant.
    self.assertAllClose(trajectory_length[-1], trajectory_length[-2])
Example #5
    def testStepSizeAdaptation(self):
        def log_accept_ratio_fn(step_size):
            return -step_size**2

        def kernel(ssa_state, seed):
            normal_seed, seed = util.split_seed(seed, 2)
            log_accept_ratio = (
                log_accept_ratio_fn(ssa_state.step_size()) +
                0.01 * util.random_normal([4], self._dtype, normal_seed))
            ssa_state, ssa_extra = prefab.step_size_adaptation_step(
                ssa_state, log_accept_ratio, num_adaptation_steps=100)
            return (ssa_state,
                    seed), (ssa_extra.accept_prob, ssa_state.step_size(),
                            ssa_state.step_size(num_adaptation_steps=100))

        seed = self._make_seed(_test_seed())

        _, (p_accept, step_size, rms_step_size) = fun_mc.trace(
            (prefab.step_size_adaptation_init(tf.constant(
                0.1, self._dtype)), seed), kernel, 200)

        self.assertAllClose(0.8, p_accept[100], atol=0.1)
        self.assertAllClose(step_size[100], step_size[150])
        self.assertAllClose(rms_step_size[100], rms_step_size[150])
Example #6
def interactive_trace(
    state: 'fun_mc.State',
    fn: 'fun_mc.TransitionOperator',
    num_steps: 'fun_mc.IntTensor',
    trace_mask: 'fun_mc.BooleanNest' = True,
    iteration_axis: int = 0,
    block_until_ready: 'bool' = True,
    progress_bar_fn: 'Callable[[Iterable[Any]], Iterable[Any]]' = (
        _tqdm_progress_bar_fn),
) -> 'Tuple[fun_mc.State, fun_mc.TensorNest]':
  """Wrapped around fun_mc.trace, suited for interactive work.

  This is accomplished through unrolling fun_mc.trace, as well as optionally
  using a progress bar (TQDM by default).

  Args:
    state: A nest of `Tensor`s or None.
    fn: A `TransitionOperator`.
    num_steps: Number of steps to run the function for. Must be greater than 1.
    trace_mask: A potentially shallow nest with boolean leaves applied to the
      `extra` return value of `fn`. This controls whether or not to actually
      trace the quantities in `extra`. For subtrees of `extra` where the mask
      leaf is `True`, those subtrees are traced (i.e. the corresponding subtrees
      in `traces` will contain an extra leading dimension equalling
      `num_steps`). For subtrees of `extra` where the mask leaf is `False`,
      those subtrees are merely propagated, and their corresponding subtrees in
      `traces` correspond to their final value.
    iteration_axis: Integer. Indicates the axis of the trace outputs that should
      be flattened with the first axis. This is most useful when `fn` is
      `trace`. E.g. if the trace has shape `[num_steps, 2, 5]` and
      `iteration_axis=2`, the trace outputs will be reshaped/transposed to
      `[2, 5 * num_steps]`. A value of 0 disables this operation.
    block_until_ready: Whether to wait for the computation to finish between
      steps. This results in smoother progress bars under, e.g., JAX.
    progress_bar_fn: A callable that will be called with an iterable with length
      `num_steps` and which returns another iterable with the same length. This
      will be advanced for every step taken. If None, no progress bar is
      shown. Default: `lambda it: tqdm.tqdm(it, leave=True)`.

  Returns:
    state: The final state returned by `fn`.
    traces: A nest with the same structure as the extra return value of `fn`,
      but with leaves replaced with stacked and unstacked values according to
      the `trace_mask`.
  """
  num_steps = tf.get_static_value(num_steps)
  if num_steps is None:
    raise ValueError(
        'Interactive tracing requires `num_steps` to be statically known.')

  if progress_bar_fn is None:
    pbar = None
  else:
    pbar = iter(progress_bar_fn(range(num_steps)))

  def fn_with_progress(state):
    state, extra = fun_mc.call_transition_operator(fn, state)
    if block_until_ready:
      state, extra = util.block_until_ready((state, extra))
    if pbar is not None:
      try:
        next(pbar)
      except StopIteration:
        pass
    return [state], extra

  [state], trace = fun_mc.trace(
      # Wrap the state in a singleton list to simplify implementation of
      # `fn_with_progress`.
      state=[state],
      fn=fn_with_progress,
      num_steps=num_steps,
      trace_mask=trace_mask,
      unroll=True,
  )

  if iteration_axis != 0:
    def fix_part(x):
      x = util.move_axis(x, 0, iteration_axis - 1)
      x = tf.reshape(
          x,
          tuple(x.shape[:iteration_axis - 1]) + (-1,) +
          tuple(x.shape[iteration_axis + 1:]))
      return x
    trace = util.map_tree(fix_part, trace)
  return state, trace
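
The docstring above describes `trace_mask`, but none of the examples on this page exercise it. Below is a minimal sketch, not taken from the library's tests: it assumes the same `fun_mc`/`prefab` imports as the examples above, and the `_square_kernel` helper is made up for illustration.

def _square_kernel(x):
  # Hypothetical transition operator: returns the new state plus an `extra`
  # tuple whose two elements are traced differently below.
  return x + 1., (x, x * x)

state, (xs, final_square) = prefab.interactive_trace(
    0.,
    _square_kernel,
    10,
    # Trace the first element of `extra` at every step; the second element is
    # merely propagated, so its slot in the trace holds only the final value.
    trace_mask=(True, False),
    progress_bar_fn=None)
# `xs` gains a leading dimension of length 10 (one entry per step), while
# `final_square` is just the last step's value.

Masking out a quantity this way keeps the traced output small while the underlying value is still computed on every step.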
Example #7
    def testAdaptiveHMC(self):
        num_chains = 16
        num_steps = 4000
        num_warmup_steps = num_steps // 2
        num_adapt_steps = int(0.8 * num_warmup_steps)

        # Setup the model and state constraints.
        model = tfp.distributions.JointDistributionSequential([
            tfp.distributions.Normal(loc=self._constant(0.), scale=1.),
            tfp.distributions.Independent(
                tfp.distributions.LogNormal(loc=self._constant([1., 1.]),
                                            scale=0.5), 1),
        ])
        bijector = [tfp.bijectors.Identity(), tfp.bijectors.Exp()]
        transform_fn = util_tfp.bijector_to_transform_fn(bijector,
                                                         model.dtype,
                                                         batch_ndims=1)

        def target_log_prob_fn(*x):
            return model.log_prob(x), ()

        # Start out at zeros (in the unconstrained space).
        state, _ = transform_fn(*[
            tf.zeros([num_chains] + list(e), dtype=self._dtype)
            for e in model.event_shape
        ])

        reparam_log_prob_fn, reparam_state = fun_mc.reparameterize_potential_fn(
            target_log_prob_fn, transform_fn, state)

        # Define the kernel.
        def kernel(adaptive_hmc_state, seed):
            hmc_seed, seed = util.split_seed(seed, 2)

            adaptive_hmc_state, adaptive_hmc_extra = (
                prefab.adaptive_hamiltonian_monte_carlo_step(
                    adaptive_hmc_state,
                    target_log_prob_fn=reparam_log_prob_fn,
                    num_adaptation_steps=num_adapt_steps,
                    seed=hmc_seed))

            return (adaptive_hmc_state, seed), (adaptive_hmc_extra.state,
                                                adaptive_hmc_extra.is_accepted,
                                                adaptive_hmc_extra.step_size)

        seed = self._make_seed(_test_seed())

        # Subtle: Unlike TF, JAX needs a data dependency from the inputs to outputs
        # for the jit to do anything.
        _, (state_chain, is_accepted_chain,
            _) = tf.function(lambda reparam_state, seed: fun_mc.trace(  # pylint: disable=g-long-lambda
                state=(prefab.adaptive_hamiltonian_monte_carlo_init(
                    reparam_state, reparam_log_prob_fn), seed),
                fn=kernel,
                num_steps=num_steps))(reparam_state, seed)

        # Discard the warmup samples.
        state_chain = [s[num_warmup_steps:] for s in state_chain]
        is_accepted_chain = is_accepted_chain[num_warmup_steps:]

        accept_rate = tf.reduce_mean(tf.cast(is_accepted_chain, tf.float32))
        rhat = tfp.mcmc.potential_scale_reduction(state_chain)
        sample_mean = [tf.reduce_mean(s, axis=[0, 1]) for s in state_chain]
        sample_var = [
            tf.math.reduce_variance(s, axis=[0, 1]) for s in state_chain
        ]

        self.assertAllAssertsNested(lambda rhat: self.assertAllLess(rhat, 1.1),
                                    rhat)
        self.assertAllClose(0.8, accept_rate, atol=0.05)
        self.assertAllClose(model.mean(), sample_mean, rtol=0.1, atol=0.1)
        self.assertAllClose(model.variance(), sample_var, rtol=0.1, atol=0.1)
Example #8
def inner(x):
    # Assumes a `kernel` transition operator is defined in the enclosing scope
    # and that its traced `extra` is a vector, so the per-step trace stacks
    # into a rank-2 tensor that can be transposed below.
    state, trace = fun_mc.trace(x, kernel, 5)
    trace = tf.transpose(trace, [1, 0])
    return state, trace