Exemplo n.º 1
0
    def trace():
      """Runs four HMC steps from a zero state, discarding every traced value."""

      def hmc_step(state):
        # One HMC transition with a fixed step size and three integrator steps.
        return fun_mcmc.hamiltonian_monte_carlo(
            state,
            step_size=self._constant(0.1),
            num_integrator_steps=3,
            target_log_prob_fn=target_log_prob_fn,
            seed=_test_seed())

      initial_state = fun_mcmc.hamiltonian_monte_carlo_init(
          tf.zeros([1], dtype=self._dtype), target_log_prob_fn)

      # `trace_fn` returns an empty tuple, so nothing is accumulated.
      fun_mcmc.trace(
          state=initial_state,
          fn=hmc_step,
          num_steps=4,
          trace_fn=lambda *args: ())
Exemplo n.º 2
0
  def testRunningCovarianceMaxPoints(self):
    """Checks `running_covariance_step` when `window_size` is set.

    The data stream starts with `window_size` standard-normal points and is
    followed by ten times as many points from a shifted, rescaled normal.
    """
    window_size = 100
    rng = np.random.RandomState(_test_seed())
    # Draw the leading segment first so the random stream is consumed in the
    # same order as a single concatenated expression would consume it.
    leading = rng.randn(window_size, 2)
    trailing = np.array([1., 2.]) + np.array([2., 3.]) * rng.randn(
        window_size * 10, 2)
    data = self._constant(np.concatenate([leading, trailing], axis=0))

    def update(cov_state, position):
      # Feed one data point and expose the current estimates for tracing.
      cov_state, _ = fun_mcmc.running_covariance_step(
          cov_state, data[position], window_size=window_size)
      estimates = (cov_state.mean, cov_state.covariance)
      return (cov_state, position + 1), estimates

    initial_state = (fun_mcmc.running_covariance_init([2], data.dtype), 0)
    _, (mean, cov) = fun_mcmc.trace(
        state=initial_state,
        fn=update,
        num_steps=len(data),
    )

    # Up to window_size, we compute the running mean/variance exactly.
    last_exact = window_size - 1
    self.assertAllClose(
        np.mean(data[:window_size], axis=0), mean[last_exact])
    self.assertAllClose(
        _gen_cov(data[:window_size], axis=0), cov[last_exact])

    # After window_size, we're doing exponential moving average, and pick up the
    # mean/variance after the change in the distribution. Since the moving
    # average is computed only over ~window_size points, this test is rather
    # noisy.
    expected_mean = np.array([1., 2.])
    expected_cov = np.array([[4., 0.], [0., 9.]])
    self.assertAllClose(expected_mean, mean[-1], atol=0.2)
    self.assertAllClose(expected_cov, cov[-1], atol=1.)
Exemplo n.º 3
0
  def testRunningVarianceMaxPoints(self):
    """Checks `running_variance_step` when `window_size` is set.

    The scalar data stream starts with `window_size` standard-normal points
    and continues with ten times as many points distributed as `1 + 2 * N(0, 1)`.
    """
    window_size = 100
    rng = np.random.RandomState(_test_seed())
    # Draw the leading segment first to keep the random stream order stable.
    leading = rng.randn(window_size)
    trailing = 1. + 2. * rng.randn(window_size * 10)
    data = self._constant(np.concatenate([leading, trailing], axis=0))

    def update(var_state, position):
      # Feed one data point and expose the current estimates for tracing.
      var_state, _ = fun_mcmc.running_variance_step(
          var_state, data[position], window_size=window_size)
      estimates = (var_state.mean, var_state.variance)
      return (var_state, position + 1), estimates

    initial_state = (fun_mcmc.running_variance_init([], data.dtype), 0)
    _, (mean, var) = fun_mcmc.trace(
        state=initial_state,
        fn=update,
        num_steps=len(data),
    )

    # Up to window_size, we compute the running mean/variance exactly.
    last_exact = window_size - 1
    self.assertAllClose(np.mean(data[:window_size]), mean[last_exact])
    self.assertAllClose(np.var(data[:window_size]), var[last_exact])

    # After window_size, we're doing exponential moving average, and pick up the
    # mean/variance after the change in the distribution. Since the moving
    # average is computed only over ~window_size points, this test is rather
    # noisy.
    self.assertAllClose(1., mean[-1], atol=0.2)
    self.assertAllClose(4., var[-1], atol=0.8)
Exemplo n.º 4
0
    def testWrapTransitionKernel(self):
        """Checks that a wrapped `tfp.mcmc.TransitionKernel` works with `trace`."""

        class TestKernel(tfp.mcmc.TransitionKernel):
            """Toy kernel: adds one to each state part and to the results."""

            def one_step(self, current_state, previous_kernel_results):
                next_state = [part + 1 for part in current_state]
                return next_state, previous_kernel_results + 1

            def bootstrap_results(self, current_state):
                return sum(current_state)

            def is_calibrated(self):
                return True

        def wrapped_step(state, pkr):
            # A fresh kernel instance is constructed on every call.
            return util_tfp.transition_kernel_wrapper(state, pkr, TestKernel())

        initial_state = {'x': self._constant(0.), 'y': self._constant(1.)}
        initial_kr = 1.
        (final_state, final_kr), _ = fun_mcmc.trace(
            (initial_state, initial_kr),
            wrapped_step,
            2,
            trace_fn=lambda *args: (),
        )

        # Two steps add 2 to each state part and to the kernel results.
        expected_state = {'x': 2., 'y': 3.}
        self.assertAllEqual(expected_state,
                            util.map_tree(np.array, final_state))
        self.assertAllEqual(1. + 2., final_kr)
Exemplo n.º 5
0
  def testTraceMask(self, unroll):
    """Checks that `trace_mask` controls which extras get stacked."""

    def step(x):
      doubled, tripled = 2 * x, 3 * x
      return x + 1, (doubled, tripled)

    # Per-leaf mask: the first extra is stacked over steps, the second one is
    # reported as its final value only.
    final_x, (doubled_out, tripled_out) = fun_mcmc.trace(
        state=0,
        fn=step,
        num_steps=3,
        trace_mask=(True, False),
        unroll=unroll)

    self.assertAllEqual(3, final_x)
    self.assertAllEqual([0, 2, 4], doubled_out)
    self.assertAllEqual(6, tripled_out)

    # A single `False` mask applies to the whole extra structure.
    final_x, (doubled_out, tripled_out) = fun_mcmc.trace(
        state=0,
        fn=step,
        num_steps=3,
        trace_mask=False,
        unroll=unroll)

    self.assertAllEqual(3, final_x)
    self.assertAllEqual(4, doubled_out)
    self.assertAllEqual(6, tripled_out)
Exemplo n.º 6
0
  def testTraceTrace(self, unroll):
    """Checks that a `trace` call can be used as the step of another `trace`."""

    def inner_step(x):
      incremented = x + 1.
      return incremented, incremented

    def outer_step(x):
      # The inner loop is untraced (`trace_mask=False`), so it yields its final
      # state together with the final extra value.
      return fun_mcmc.trace(
          x, inner_step, 2, trace_mask=False, unroll=unroll)

    final_x, outer_trace = fun_mcmc.trace(0., outer_step, 2)
    self.assertAllEqual(4., final_x)
    self.assertAllEqual([2., 4.], outer_trace)
Exemplo n.º 7
0
  def testPreconditionedHMC(self):
    """Checks HMC on a Softplus-transformed multivariate normal target.

    The sampler runs in the space produced by
    `fun_mcmc.transform_log_prob_fn`, and the traced (transformed) chain is
    compared against moments estimated from direct draws of the target
    distribution.
    """
    step_size = self._constant(0.2)
    num_steps = 2000
    num_leapfrog_steps = 10
    state = tf.ones([16, 2], dtype=self._dtype)

    base_mean = self._constant([1., 0])
    base_cov = self._constant([[1, 0.5], [0.5, 1]])

    bijector = tfp.bijectors.Softplus()
    base_dist = tfp.distributions.MultivariateNormalFullCovariance(
        loc=base_mean, covariance_matrix=base_cov)
    target_dist = bijector(base_dist)

    def orig_target_log_prob_fn(x):
      return target_dist.log_prob(x), ()

    target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
        orig_target_log_prob_fn, bijector, state)

    # pylint: disable=g-long-lambda
    def kernel(hmc_state, seed):
      hmc_seed, seed = util.split_seed(seed, 2)
      hmc_state, _ = fun_mcmc.hamiltonian_monte_carlo(
          hmc_state,
          step_size=step_size,
          num_integrator_steps=num_leapfrog_steps,
          target_log_prob_fn=target_log_prob_fn,
          seed=hmc_seed)
      # Trace the first element of `state_extra` rather than the raw HMC state.
      return (hmc_state, seed), hmc_state.state_extra[0]

    seed = self._make_seed(_test_seed())

    # Subtle: Unlike TF, JAX needs a data dependency from the inputs to outputs
    # for the jit to do anything.
    _, chain = tf.function(lambda state, seed: fun_mcmc.trace(  # pylint: disable=g-long-lambda
        state=(fun_mcmc.hamiltonian_monte_carlo_init(state, target_log_prob_fn),
               seed),
        fn=kernel,
        num_steps=num_steps))(state, seed)
    # Discard the warmup samples.
    chain = chain[1000:]

    sample_mean = tf.reduce_mean(chain, axis=[0, 1])
    sample_cov = tfp.stats.covariance(chain, sample_axis=[0, 1])

    true_samples = target_dist.sample(4096, seed=self._make_seed(_test_seed()))

    true_mean = tf.reduce_mean(true_samples, axis=0)
    # The reference covariance must come from the direct draws. Computing it
    # from `chain` would compare the chain with itself and always pass.
    true_cov = tfp.stats.covariance(true_samples, sample_axis=0)

    self.assertAllClose(true_mean, sample_mean, rtol=0.1, atol=0.1)
    self.assertAllClose(true_cov, sample_cov, rtol=0.1, atol=0.1)
Exemplo n.º 8
0
  def testTraceSingle(self, unroll):
    """Checks tracing of a single extra value selected by `trace_fn`."""

    def step(x):
      return x + 1., 2 * x

    def select_extra(_, xp1):
      # Ignore the state and keep only the extra output of `step`.
      return xp1

    final_x, extra_trace = fun_mcmc.trace(
        state=0.,
        fn=step,
        num_steps=5,
        trace_fn=select_extra,
        unroll=unroll)

    self.assertAllEqual(5., final_x)
    self.assertAllEqual([0., 2., 4., 6., 8.], extra_trace)
Exemplo n.º 9
0
    def computation(state, seed):
      """Runs HMC whose log step size is updated with Adam at every step.

      Args:
        state: Initial chain state, passed through
          `fun_mcmc.transform_log_prob_fn` together with the bijector.
        seed: Seed threaded through the kernel and split at every step.

      Returns:
        chain: Trace of `hmc_state.state_extra[0]`.
        log_accept_ratio_trace: Trace of the HMC `log_accept_ratio`.
        true_samples: 4096 direct draws from the target distribution.
      """
      bijector = tfp.bijectors.Softplus()
      target_dist = bijector(
          tfp.distributions.MultivariateNormalFullCovariance(
              loc=base_mean, covariance_matrix=base_cov))

      def orig_target_log_prob_fn(x):
        return target_dist.log_prob(x), ()

      target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
          orig_target_log_prob_fn, bijector, state)

      def adaptive_step(hmc_state, adam_state, step, seed):
        hmc_seed, seed = util.split_seed(seed, 2)
        # The optimizer state holds the log of the step size.
        current_step_size = tf.exp(adam_state.state)
        hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
            hmc_state,
            step_size=current_step_size,
            num_integrator_steps=num_leapfrog_steps,
            target_log_prob_fn=target_log_prob_fn,
            seed=hmc_seed)

        learning_rate = prefab._polynomial_decay(  # pylint: disable=protected-access
            step=step,
            step_size=self._constant(0.01),
            power=0.5,
            decay_steps=num_adapt_steps,
            final_step_size=0.)
        clipped_log_ratio = tf.minimum(
            self._constant(0.), hmc_extra.log_accept_ratio)
        mean_p_accept = tf.reduce_mean(tf.exp(clipped_log_ratio))

        # The surrogate loss is the gap between 0.9 and the mean acceptance
        # probability of this step.
        loss_fn = fun_mcmc.make_surrogate_loss_fn(
            lambda _: (0.9 - mean_p_accept, ()))
        adam_state, _ = fun_mcmc.adam_step(
            adam_state, loss_fn, learning_rate=learning_rate)

        next_state = (hmc_state, adam_state, step + 1, seed)
        traced = (hmc_state.state_extra[0], hmc_extra.log_accept_ratio)
        return next_state, traced

      initial_state = (
          fun_mcmc.hamiltonian_monte_carlo_init(state, target_log_prob_fn),
          fun_mcmc.adam_init(tf.math.log(step_size)),
          0,
          seed,
      )
      _, (chain, log_accept_ratio_trace) = fun_mcmc.trace(
          state=initial_state,
          fn=adaptive_step,
          num_steps=num_adapt_steps + num_steps,
      )
      true_samples = target_dist.sample(
          4096, seed=self._make_seed(_test_seed()))
      return chain, log_accept_ratio_trace, true_samples
Exemplo n.º 10
0
  def testTraceNested(self, unroll):
    """Checks tracing when the state is a tuple of two values."""

    def step(x, y):
      next_state = (x + 1., y + 2.)
      return next_state, ()

    def select_state(xy, _):
      # Trace the (nested) state itself and ignore the empty extra.
      return xy

    (final_x, final_y), (x_trace, y_trace) = fun_mcmc.trace(
        state=(0., 0.),
        fn=step,
        num_steps=5,
        trace_fn=select_state,
        unroll=unroll)

    self.assertAllEqual(5., final_x)
    self.assertAllEqual(10., final_y)
    self.assertAllEqual([1., 2., 3., 4., 5.], x_trace)
    self.assertAllEqual([2., 4., 6., 8., 10.], y_trace)
Exemplo n.º 11
0
  def testBasicHMC(self, unroll):
    """Checks that HMC recovers the moments of an axis-aligned Gaussian."""
    step_size = self._constant(0.2)
    num_steps = 2000
    num_leapfrog_steps = 10
    state = tf.ones([16, 2], dtype=self._dtype)

    base_mean = self._constant([2., 3.])
    base_scale = self._constant([2., 0.5])

    def target_log_prob_fn(x):
      standardized = (x - base_mean) / base_scale
      return -tf.reduce_sum(0.5 * tf.square(standardized), -1), ()

    def hmc_step(hmc_state, seed):
      hmc_seed, seed = util.split_seed(seed, 2)
      hmc_state, _ = fun_mcmc.hamiltonian_monte_carlo(
          hmc_state,
          step_size=step_size,
          num_integrator_steps=num_leapfrog_steps,
          target_log_prob_fn=target_log_prob_fn,
          unroll_integrator=unroll,
          seed=hmc_seed)
      return (hmc_state, seed), hmc_state.state

    seed = self._make_seed(_test_seed())

    # Subtle: Unlike TF, JAX needs a data dependency from the inputs to outputs
    # for the jit to do anything.
    @tf.function
    def run_chain(state, seed):
      initial_state = (
          fun_mcmc.hamiltonian_monte_carlo_init(state, target_log_prob_fn),
          seed)
      return fun_mcmc.trace(
          state=initial_state, fn=hmc_step, num_steps=num_steps)

    _, chain = run_chain(state, seed)
    # Discard the warmup samples.
    chain = chain[1000:]

    sample_mean = tf.reduce_mean(chain, axis=[0, 1])
    sample_var = tf.math.reduce_variance(chain, axis=[0, 1])

    # Reference draws: standard normals scaled and shifted to the target.
    standard_normals = util.random_normal(
        shape=[4096, 2], dtype=self._dtype, seed=seed)
    true_samples = standard_normals * base_scale + base_mean

    true_mean = tf.reduce_mean(true_samples, axis=0)
    true_var = tf.math.reduce_variance(true_samples, axis=0)

    self.assertAllClose(true_mean, sample_mean, rtol=0.1, atol=0.1)
    self.assertAllClose(true_var, sample_var, rtol=0.1, atol=0.1)
Exemplo n.º 12
0
  def testGradientDescent(self):
    """Checks that gradient descent minimizes a two-variable quadratic."""

    def loss_fn(x, y):
      return tf.square(x - 1.) + tf.square(y - 2.), []

    def descend(gd_state):
      return fun_mcmc.gradient_descent_step(
          gd_state, loss_fn, learning_rate=self._constant(0.01))

    def record(state, extra):
      # Trace both the optimized variables and the loss.
      return [state.state, extra.loss]

    initial_state = fun_mcmc.GradientDescentState(
        [self._constant(0.), self._constant(0.)])
    _, [(x, y), loss] = fun_mcmc.trace(
        initial_state, descend, num_steps=1000, trace_fn=record)

    self.assertAllClose(1., x[-1], atol=1e-3)
    self.assertAllClose(2., y[-1], atol=1e-3)
    self.assertAllClose(0., loss[-1], atol=1e-3)
Exemplo n.º 13
0
  def testRunningMean(self, shape, aggregation):
    """Checks `running_mean_step` against `np.mean` over the same axes."""
    rng = np.random.RandomState(_test_seed())
    data = self._constant(rng.randn(*shape))

    def update(mean_state, position):
      mean_state, _ = fun_mcmc.running_mean_step(
          mean_state, data[position], axis=aggregation)
      return (mean_state, position + 1), ()

    # The leading axis is consumed by stepping through `data`; any requested
    # aggregation axes refer to a single element, so they shift by one.
    if aggregation is None:
      shifted_axes = ()
    else:
      shifted_axes = tuple(a + 1 for a in util.flatten_tree(aggregation))
    true_aggregation = (0,) + shifted_axes
    true_mean = np.mean(data, true_aggregation)

    initial_state = (
        fun_mcmc.running_mean_init(true_mean.shape, data.dtype), 0)
    (final_mean_state, _), _ = fun_mcmc.trace(
        state=initial_state,
        fn=update,
        num_steps=len(data),
        trace_fn=lambda *args: ())

    self.assertAllClose(true_mean, final_mean_state.mean)
Exemplo n.º 14
0
  def testRandomWalkMetropolis(self):
    """Checks random-walk Metropolis on a four-category discrete target."""
    num_steps = 1000
    state = tf.ones([16], dtype=tf.int32)
    target_logits = self._constant([1., 2., 3., 4.]) + 2.
    proposal_logits = self._constant([4., 3., 2., 1.]) + 2.

    def target_log_prob_fn(x):
      return tf.gather(target_logits, x), ()

    def proposal_fn(x, seed):
      current_logits = tf.gather(proposal_logits, x)
      # Proposals are drawn from `proposal_logits` regardless of `x`.
      draws = util.random_categorical(proposal_logits[tf.newaxis],
                                      x.shape[0], seed)
      proposal = draws[0]
      proposed_logits = tf.gather(proposal_logits, proposal)
      logit_difference = proposed_logits - current_logits
      return tf.cast(proposal, x.dtype), ((), logit_difference)

    def rwm_step(rwm_state, seed):
      rwm_seed, seed = util.split_seed(seed, 2)
      rwm_state, rwm_extra = fun_mcmc.random_walk_metropolis(
          rwm_state,
          target_log_prob_fn=target_log_prob_fn,
          proposal_fn=proposal_fn,
          seed=rwm_seed)
      return (rwm_state, seed), rwm_extra

    seed = self._make_seed(_test_seed())

    # Subtle: Unlike TF, JAX needs a data dependency from the inputs to outputs
    # for the jit to do anything.
    @tf.function
    def run_chain(state, seed):
      initial_state = (
          fun_mcmc.random_walk_metropolis_init(state, target_log_prob_fn),
          seed)
      return fun_mcmc.trace(
          state=initial_state,
          fn=rwm_step,
          num_steps=num_steps,
          trace_fn=lambda state, extra: state[0].state)

    _, chain = run_chain(state, seed)
    # Discard the warmup samples.
    chain = chain[500:]

    # Empirical category frequencies across steps and chains.
    sample_mean = tf.reduce_mean(tf.one_hot(chain, 4), axis=[0, 1])
    self.assertAllClose(tf.nn.softmax(target_logits), sample_mean, atol=0.11)
Exemplo n.º 15
0
  def testSimpleDualAverages(self):
    """Checks that the running mean of dual-averages iterates nears the optimum."""

    def loss_fn(x, y):
      return tf.square(x - 1.) + tf.square(y - 2.), []

    def update(sda_state, mean_state):
      sda_state, _ = fun_mcmc.simple_dual_averages_step(sda_state, loss_fn, 1.)
      # Average the iterates; the running mean is what gets traced.
      mean_state, _ = fun_mcmc.running_mean_step(mean_state, sda_state.state)
      return (sda_state, mean_state), mean_state.mean

    initial_sda_state = fun_mcmc.simple_dual_averages_init(
        [self._constant(0.), self._constant(0.)])
    initial_mean_state = fun_mcmc.running_mean_init(
        [[], []], [self._dtype, self._dtype])

    _, (x, y) = fun_mcmc.trace(
        (initial_sda_state, initial_mean_state),
        update,
        num_steps=1000,
    )

    self.assertAllClose(1., x[-1], atol=1e-1)
    self.assertAllClose(2., y[-1], atol=1e-1)
Exemplo n.º 16
0
  def testPotentialScaleReduction(self, chain_shape, independent_chain_ndims):
    """Compares the running R-hat to `tfp.mcmc.potential_scale_reduction`.

    Currently skipped unconditionally; see the linked issue.
    """
    self.skipTest('https://github.com/tensorflow/probability/issues/1054')
    # TODO(siege): Update fun_mcmc.potential_scale_reduction.
    rng = np.random.RandomState(_test_seed())
    means_shape = (1,) + chain_shape[1:]
    chain_means = rng.randn(*means_shape).astype(np.float32)
    noise = rng.randn(*chain_shape).astype(np.float32)
    chains = 0.4 * noise + chain_means

    true_rhat = tfp.mcmc.potential_scale_reduction(
        chains, independent_chain_ndims=independent_chain_ndims)

    chains = self._constant(chains)

    def update(psrs):
      # The number of points seen so far doubles as the index of the next one.
      return fun_mcmc.potential_scale_reduction_step(
          psrs, chains[psrs.num_points])

    initial_state = fun_mcmc.potential_scale_reduction_init(
        chain_shape[1:], self._dtype)
    psrs, _ = fun_mcmc.trace(
        state=initial_state,
        fn=update,
        num_steps=chain_shape[0],
        trace_fn=lambda *_: ())

    running_rhat = fun_mcmc.potential_scale_reduction_extract(
        psrs, independent_chain_ndims=independent_chain_ndims)
    self.assertAllClose(true_rhat, running_rhat)
Exemplo n.º 17
0
 def fun(x):
   """Runs two untraced `x + 1` steps via `fun_mcmc.trace` and returns its result."""

   def increment(x):
     incremented = x + 1.
     return incremented, incremented

   return fun_mcmc.trace(x, increment, 2, trace_mask=False, unroll=unroll)
Exemplo n.º 18
0
  def testRunningApproximateAutoCovariance(self, state_shape, event_ndims,
                                           aggregation):
    """Compares the running auto-covariance of an HMC chain to batch estimates."""
    # HMC is the source of the chain; the running autocovariance is updated
    # alongside every sampling step.
    step_size = 0.2
    num_steps = 1000
    num_leapfrog_steps = 10
    max_lags = 300

    state = tf.zeros(state_shape, dtype=self._dtype)

    def target_log_prob_fn(x):
      lp = -0.5 * tf.square(x)
      if event_ndims is not None:
        # Sum the log-probability over the trailing axis.
        lp = tf.reduce_sum(lp, -1)
      return lp, ()

    def sample_and_accumulate(hmc_state, raac_state, seed):
      hmc_seed, seed = util.split_seed(seed, 2)
      hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
          hmc_state,
          step_size=step_size,
          num_integrator_steps=num_leapfrog_steps,
          target_log_prob_fn=target_log_prob_fn,
          seed=hmc_seed)
      raac_state, _ = fun_mcmc.running_approximate_auto_covariance_step(
          raac_state, hmc_state.state, axis=aggregation)
      return (hmc_state, raac_state, seed), hmc_extra

    seed = self._make_seed(_test_seed())

    # Subtle: Unlike TF, JAX needs a data dependency from the inputs to outputs
    # for the jit to do anything.
    @tf.function
    def run_chain(state, seed):
      initial_state = (
          fun_mcmc.hamiltonian_monte_carlo_init(state, target_log_prob_fn),
          fun_mcmc.running_approximate_auto_covariance_init(
              max_lags=max_lags,
              state_shape=state_shape,
              dtype=state.dtype,
              axis=aggregation),
          seed,
      )
      return fun_mcmc.trace(
          state=initial_state,
          fn=sample_and_accumulate,
          num_steps=num_steps,
          trace_fn=lambda state, extra: state[0].state)

    (_, raac_state, _), chain = run_chain(state, seed)

    # The chain has an extra leading (step) axis, so per-state aggregation axes
    # shift by one.
    if aggregation is None:
      shifted_axes = []
    else:
      shifted_axes = [a + 1 for a in util.flatten_tree(aggregation)]
    true_aggregation = (0,) + tuple(shifted_axes)

    chain_array = np.array(chain)
    true_variance = np.array(
        tf.math.reduce_variance(chain_array, true_aggregation))
    true_autocov = np.array(
        tfp.stats.auto_correlation(chain_array, axis=0, max_lags=max_lags))
    if aggregation is not None:
      true_autocov = tf.reduce_mean(true_autocov, shifted_axes)

    lag_zero = raac_state.auto_covariance[0]
    self.assertAllClose(true_variance, lag_zero, 1e-5)
    # Normalize by the lag-zero term to compare against auto-correlation.
    self.assertAllClose(
        true_autocov,
        raac_state.auto_covariance / raac_state.auto_covariance[0],
        atol=0.1)
Exemplo n.º 19
0
    def testAdaptiveHMC(self):
        """Checks `prefab.adaptive_hamiltonian_monte_carlo_step` on a joint model."""
        num_chains = 16
        num_steps = 4000
        num_warmup_steps = num_steps // 2
        num_adapt_steps = int(0.8 * num_warmup_steps)

        # Setup the model and state constraints.
        normal_part = tfp.distributions.Normal(
            loc=self._constant(0.), scale=1.)
        log_normal_part = tfp.distributions.Independent(
            tfp.distributions.LogNormal(
                loc=self._constant([1., 1.]), scale=0.5), 1)
        model = tfp.distributions.JointDistributionSequential(
            [normal_part, log_normal_part])
        bijector = [tfp.bijectors.Identity(), tfp.bijectors.Exp()]
        transform_fn = util_tfp.bijector_to_transform_fn(
            bijector, model.dtype, batch_ndims=1)

        def target_log_prob_fn(*x):
            return model.log_prob(x), ()

        # Start out at zeros (in the unconstrained space).
        unconstrained_zeros = [
            tf.zeros([num_chains] + list(e), dtype=self._dtype)
            for e in model.event_shape
        ]
        state, _ = transform_fn(*unconstrained_zeros)

        reparam_log_prob_fn, reparam_state = (
            fun_mcmc.reparameterize_potential_fn(
                target_log_prob_fn, transform_fn, state))

        # Define the kernel.
        def adaptive_step(adaptive_hmc_state, seed):
            hmc_seed, seed = util.split_seed(seed, 2)

            adaptive_hmc_state, adaptive_hmc_extra = (
                prefab.adaptive_hamiltonian_monte_carlo_step(
                    adaptive_hmc_state,
                    target_log_prob_fn=reparam_log_prob_fn,
                    num_adaptation_steps=num_adapt_steps,
                    seed=hmc_seed))

            traced = (adaptive_hmc_extra.state,
                      adaptive_hmc_extra.is_accepted,
                      adaptive_hmc_extra.step_size)
            return (adaptive_hmc_state, seed), traced

        seed = self._make_seed(_test_seed())

        # Subtle: Unlike TF, JAX needs a data dependency from the inputs to outputs
        # for the jit to do anything.
        @tf.function
        def run_chain(reparam_state, seed):
            initial_state = (
                prefab.adaptive_hamiltonian_monte_carlo_init(
                    reparam_state, reparam_log_prob_fn),
                seed)
            return fun_mcmc.trace(
                state=initial_state, fn=adaptive_step, num_steps=num_steps)

        _, (state_chain, is_accepted_chain, _) = run_chain(reparam_state, seed)

        # Discard the warmup samples.
        state_chain = [part[num_warmup_steps:] for part in state_chain]
        is_accepted_chain = is_accepted_chain[num_warmup_steps:]

        accept_rate = tf.reduce_mean(tf.cast(is_accepted_chain, tf.float32))
        rhat = tfp.mcmc.potential_scale_reduction(state_chain)
        sample_mean = [
            tf.reduce_mean(part, axis=[0, 1]) for part in state_chain
        ]
        sample_var = [
            tf.math.reduce_variance(part, axis=[0, 1]) for part in state_chain
        ]

        self.assertAllAssertsNested(
            lambda rhat: self.assertAllLess(rhat, 1.1), rhat)
        self.assertAllClose(0.8, accept_rate, atol=0.05)
        self.assertAllClose(model.mean(), sample_mean, rtol=0.1, atol=0.1)
        self.assertAllClose(model.variance(), sample_var, rtol=0.1, atol=0.1)
Exemplo n.º 20
0
def interactive_trace(
    state: 'fun_mc.State',
    fn: 'fun_mc.TransitionOperator',
    num_steps: 'fun_mc.IntTensor',
    trace_mask: 'fun_mc.BooleanNest' = True,
    block_until_ready: 'bool' = True,
    progress_bar_fn: 'Callable[[Iterable[Any]], Iterable[Any]]' = (
        _tqdm_progress_bar_fn),
) -> 'Tuple[fun_mc.State, fun_mc.TensorNest]':
    """Wrapper around `fun_mc.trace`, suited for interactive work.

    The trace is run unrolled, which lets an optional progress bar (TQDM by
    default) be advanced after every step.

    Args:
      state: A nest of `Tensor`s or None.
      fn: A `TransitionOperator`.
      num_steps: Number of steps to run the function for. Must be statically
        known, and is documented as needing to be greater than 1.
      trace_mask: A potentially shallow nest with boolean leaves applied to
        the `extra` return value of `fn`. This controls whether or not to
        actually trace the quantities in `extra`. For subtrees of `extra`
        where the mask leaf is `True`, those subtrees are traced (i.e. the
        corresponding subtrees in `traces` will contain an extra leading
        dimension equalling `num_steps`). For subtrees of `extra` where the
        mask leaf is `False`, those subtrees are merely propagated, and their
        corresponding subtrees in `traces` correspond to their final value.
      block_until_ready: Whether to wait for the computation to finish between
        steps. This results in smoother progress bars under, e.g., JAX.
      progress_bar_fn: A callable that will be called with an iterable with
        length `num_steps` and which returns another iterable with the same
        length. This will be advanced for every step taken. If None, no
        progress bar is shown. Default: `lambda it: tqdm.tqdm(it, leave=True)`.

    Returns:
      state: The final state returned by `fn`.
      traces: A nest with the same structure as the extra return value of
        `fn`, but with leaves replaced with stacked and unstacked values
        according to the `trace_mask`.

    Raises:
      ValueError: If `num_steps` is not statically known.
    """
    static_num_steps = tf.get_static_value(num_steps)
    if static_num_steps is None:
        raise ValueError(
            'Interactive tracing requires `num_steps` to be statically known.')

    progress = (None if progress_bar_fn is None else iter(
        progress_bar_fn(range(static_num_steps))))

    def step_and_report(inner_state):
        inner_state, extra = fun_mc.call_transition_operator(fn, inner_state)
        if block_until_ready:
            inner_state, extra = util.block_until_ready((inner_state, extra))
        if progress is not None:
            # Advance the bar by one; an exhausted iterator is silently ignored.
            next(progress, None)
        return [inner_state], extra

    # Wrap the state in a singleton list to simplify implementation of
    # `step_and_report`.
    [final_state], traces = fun_mc.trace(
        state=[state],
        fn=step_and_report,
        num_steps=static_num_steps,
        trace_mask=trace_mask,
        unroll=True,
    )
    return final_state, traces
Exemplo n.º 21
0
 def trace_n(num_steps):
   """Returns the final state after `num_steps` `x + 1` steps starting at 0."""

   def increment(x):
     return x + 1, ()

   result = fun_mcmc.trace(0, increment, num_steps)
   return result[0]