Example #1
  def testRunningVarianceMaxPoints(self):
    window_size = 100
    rng = np.random.RandomState(_test_seed())
    data = self._constant(
        np.concatenate(
            [rng.randn(window_size), 1. + 2. * rng.randn(window_size * 10)],
            axis=0))

    def kernel(rvs, idx):
      rvs, _ = fun_mcmc.running_variance_step(
          rvs, data[idx], window_size=window_size)
      return (rvs, idx + 1), (rvs.mean, rvs.variance)

    _, (mean, var) = fun_mcmc.trace(
        state=(fun_mcmc.running_variance_init([], data.dtype), 0),
        fn=kernel,
        num_steps=len(data),
    )
    # Up to window_size, we compute the running mean/variance exactly.
    self.assertAllClose(np.mean(data[:window_size]), mean[window_size - 1])
    self.assertAllClose(np.var(data[:window_size]), var[window_size - 1])
    # After window_size, we switch to an exponential moving average, which
    # picks up the mean/variance after the change in the distribution. Since
    # the moving average is computed only over ~window_size points, this test
    # is rather noisy.
    self.assertAllClose(1., mean[-1], atol=0.2)
    self.assertAllClose(4., var[-1], atol=0.8)
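
The capped-count update that produces this exponential-moving-average behavior can be sketched in a few lines of NumPy. This is a simplified stand-in for illustration, not fun_mcmc's actual running_variance_step implementation, and the helper name is hypothetical.

import numpy as np

def capped_running_variance_step(mean, var, num_points, x, window_size):
  # Once num_points saturates at window_size, the 1/num_points weight stops
  # shrinking, so the exact Welford-style running statistics turn into an
  # exponential moving average over roughly the last window_size points.
  num_points = min(num_points + 1, window_size)
  delta = x - mean
  mean = mean + delta / num_points
  var = var + (delta * (x - mean) - var) / num_points
  return mean, var, num_points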
Example #2
def adaptive_hamiltonian_monte_carlo_init(
    state: 'fun_mc.TensorNest',
    target_log_prob_fn: 'fun_mc.PotentialFn',
    step_size: 'fun_mc.FloatTensor' = 1e-2,
    initial_mean: 'fun_mc.FloatNest' = 0.,
    initial_scale: 'fun_mc.FloatNest' = 1.,
    scale_smoothing_steps: 'fun_mc.IntTensor' = 10,
) -> 'AdaptiveHamiltonianMonteCarloState':
    """Initializes `AdaptiveHamiltonianMonteCarloState`.

  Args:
    state: Initial state of the chain.
    target_log_prob_fn: Target log prob fn.
    step_size: Initial scalar step size.
    initial_mean: Initial mean for computing the running variance estimate. Must
      broadcast structurally and tensor-wise with state.
    initial_scale: Initial scale for computing the running variance estimate.
      Must broadcast structurally and tensor-wise with state.
    scale_smoothing_steps: How much weight to assign to the `initial_mean` and
      `initial_scale`. Increase this to stabilize early adaptation.

  Returns:
    adaptive_hmc_state: State of the `adaptive_hamiltonian_monte_carlo_step`
      `TransitionOperator`.
  """
  hmc_state = fun_mc.hamiltonian_monte_carlo_init(state, target_log_prob_fn)
  dtype = util.flatten_tree(hmc_state.state)[0].dtype
  chain_ndims = len(hmc_state.target_log_prob.shape)
  running_var_state = fun_mc.running_variance_init(
      shape=util.map_tree(lambda s: s.shape[chain_ndims:], hmc_state.state),
      dtype=util.map_tree(lambda s: s.dtype, hmc_state.state),
  )
  initial_mean = fun_mc.maybe_broadcast_structure(initial_mean, state)
  initial_scale = fun_mc.maybe_broadcast_structure(initial_scale, state)

  # It's important to add some smoothing here, as initial updates can be very
  # different than the stationary distribution.
  # TODO(siege): Add a pseudo-update functionality to avoid fiddling with the
  # internals here.
  running_var_state = running_var_state._replace(
      num_points=util.map_tree(
          lambda p: (  # pylint: disable=g-long-lambda
              int(np.prod(hmc_state.target_log_prob.shape)) * tf.cast(
                  scale_smoothing_steps, p.dtype)),
          running_var_state.num_points),
      mean=util.map_tree(  # pylint: disable=g-long-lambda
          lambda m, init_m: tf.ones_like(m) * init_m, running_var_state.mean,
          initial_mean),
      variance=util.map_tree(  # pylint: disable=g-long-lambda
          lambda v, init_s: tf.ones_like(v) * init_s**2,
          running_var_state.variance, initial_scale),
  )
  log_step_size_opt_state = fun_mc.adam_init(
      tf.math.log(tf.convert_to_tensor(step_size, dtype=dtype)))

  return AdaptiveHamiltonianMonteCarloState(
      hmc_state=hmc_state,
      running_var_state=running_var_state,
      log_step_size_opt_state=log_step_size_opt_state,
      step=tf.zeros([], tf.int32))
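
A minimal usage sketch for the initializer above, assuming the usual fun_mc / TensorFlow imports are in scope. The standard-normal target_log_prob_fn and the chain shape below are illustrative assumptions, not part of the original example.

import tensorflow as tf

def target_log_prob_fn(x):
  # fun_mc potential fns return a (value, extra) pair; standard-normal target.
  return -0.5 * tf.reduce_sum(tf.square(x), axis=-1), ()

state = tf.zeros([16, 3])  # 16 chains over a 3-dimensional state.
adaptive_hmc_state = adaptive_hamiltonian_monte_carlo_init(
    state=state,
    target_log_prob_fn=target_log_prob_fn,
    step_size=1e-2,
)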
Example #3
  def testRunningVariance(self, shape, aggregation):
    rng = np.random.RandomState(_test_seed())
    data = self._constant(rng.randn(*shape))

    true_aggregation = (0,) + (() if aggregation is None else tuple(
        [a + 1 for a in util.flatten_tree(aggregation)]))
    true_mean = np.mean(data, true_aggregation)
    true_var = np.var(data, true_aggregation)

    def kernel(rvs, idx):
      rvs, _ = fun_mcmc.running_variance_step(rvs, data[idx], axis=aggregation)
      return (rvs, idx + 1), ()

    (rvs, _), _ = fun_mcmc.trace(
        state=(fun_mcmc.running_variance_init(true_mean.shape,
                                              data[0].dtype), 0),
        fn=kernel,
        num_steps=len(data),
        trace_fn=lambda *args: ())
    self.assertAllClose(true_mean, rvs.mean)
    self.assertAllClose(true_var, rvs.variance)
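
For reference, the true_aggregation above simply prepends the step axis and shifts each aggregated axis by one, so the streaming estimate is compared against plain NumPy statistics taken over all of the aggregated samples. A small restatement with an assumed shape and aggregation, purely for illustration:

import numpy as np

data = np.random.randn(10, 4, 3)   # 10 steps of [4, 3] batches (assumed shape)
aggregation = (0,)                  # axes aggregated within each data[idx]
true_axes = (0,) + tuple(a + 1 for a in aggregation)  # prepend the step axis
true_mean = np.mean(data, true_axes)  # shape (3,), compared against rvs.mean
true_var = np.var(data, true_axes)    # compared against rvs.variance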