Example #1
        def trace():
            kernel = lambda state: fun_mcmc.hamiltonian_monte_carlo(
                state,
                step_size=0.1,
                num_integrator_steps=3,
                target_log_prob_fn=target_log_prob_fn,
                seed=tfp_test_util.test_seed())

            fun_mcmc.trace(state=fun_mcmc.HamiltonianMonteCarloState(
                tf.zeros([1])),
                           fn=kernel,
                           num_steps=4,
                           trace_fn=lambda *args: ())
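Every example here follows the same calling convention: `fn` maps a state to a `(new_state, extra)` pair, and `trace_fn(state, extra)` selects what to record at each step (the examples that omit `trace_fn` suggest the default records `extra`). For intuition, here is a pure-Python stand-in, a hypothetical sketch rather than the library's actual implementation, which stacks the traced values into tensors inside a compiled loop:

    def python_trace(state, fn, num_steps,
                     trace_fn=lambda state, extra: extra):
      traced = []
      for _ in range(num_steps):
        state, extra = fn(state)  # `fn` returns (new_state, extra).
        traced.append(trace_fn(state, extra))
      return state, traced

    # python_trace(0., lambda x: (x + 1., 2. * x), 3)
    #   == (3.0, [0.0, 2.0, 4.0])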
Example #2
        def trace():
            # pylint: disable=g-long-lambda
            kernel = lambda state: fun_mcmc.hamiltonian_monte_carlo(
                state,
                step_size=0.1,
                num_integrator_steps=3,
                target_log_prob_fn=target_log_prob_fn,
                seed=self._make_seed(tfp_test_util.test_seed()))

            fun_mcmc.trace(state=fun_mcmc.hamiltonian_monte_carlo_init(
                state=tf.zeros([1]), target_log_prob_fn=target_log_prob_fn),
                           fn=kernel,
                           num_steps=4,
                           trace_fn=lambda *args: ())
Example #3
    def trace():
      kernel = lambda state: fun_mcmc.hamiltonian_monte_carlo(
          state,
          step_size=self._constant(0.1),
          num_integrator_steps=3,
          target_log_prob_fn=target_log_prob_fn,
          seed=_test_seed())

      fun_mcmc.trace(
          state=fun_mcmc.hamiltonian_monte_carlo_init(
              tf.zeros([1], dtype=self._dtype), target_log_prob_fn),
          fn=kernel,
          num_steps=4,
          trace_fn=lambda *args: ())
Example #4
  def testRunningCovarianceMaxPoints(self):
    window_size = 100
    rng = np.random.RandomState(_test_seed())
    data = tf.convert_to_tensor(
        np.concatenate(
            [
                rng.randn(window_size, 2),
                np.array([1., 2.]) +
                np.array([2., 3.]) * rng.randn(window_size * 10, 2)
            ],
            axis=0,
        ))

    def kernel(rvs, idx):
      rvs, _ = fun_mcmc.running_covariance_step(
          rvs, data[idx], window_size=window_size)
      return (rvs, idx + 1), (rvs.mean, rvs.covariance)

    _, (mean, cov) = fun_mcmc.trace(
        state=(fun_mcmc.running_covariance_init([2], data.dtype), 0),
        fn=kernel,
        num_steps=len(data),
    )
    # Up to window_size, we compute the running mean/variance exactly.
    self.assertAllClose(
        np.mean(data[:window_size], axis=0), mean[window_size - 1])
    self.assertAllClose(
        _gen_cov(data[:window_size], axis=0), cov[window_size - 1])
    # After window_size, we're doing exponential moving average, and pick up the
    # mean/variance after the change in the distribution. Since the moving
    # average is computed only over ~window_size points, this test is rather
    # noisy.
    self.assertAllClose(np.array([1., 2.]), mean[-1], atol=0.2)
    self.assertAllClose(np.array([[4., 0.], [0., 9.]]), cov[-1], atol=1.)
Example #5
    def testTraceTrace(self):
        def fun(x):
            return fun_mcmc.trace(x, lambda x: (x + 1., ()), 2,
                                  lambda *args: ())

        x, _ = fun_mcmc.trace(0., fun, 2, lambda *args: ())
        self.assertAllEqual(4., x)
Example #6
def SanitizedAutoCorrelationMean(x,
                                 axis,
                                 reduce_axis,
                                 max_lags=None,
                                 **kwargs):
    """Averages SanitizedAutoCorrelation along `axis` across `reduce_axis`."""
    shape_arr = np.array(list(x.shape))
    axes = list(sorted(set(range(len(shape_arr))) - set([reduce_axis])))
    mean_shape = shape_arr[axes]
    if max_lags is not None:
        mean_shape[axis] = max_lags + 1
    mean_state = fun_mcmc.running_mean_init(mean_shape, x.dtype)
    # Transpose `reduce_axis` to the front so the trace loop can scan over it.
    new_order = list(range(len(shape_arr)))
    new_order[0] = new_order[reduce_axis]
    new_order[reduce_axis] = 0
    x = tf.transpose(x, new_order)
    x_arr = tf.TensorArray(x.dtype, x.shape[0]).unstack(x)
    # Accumulate a running mean of the per-slice autocorrelations; the state's
    # own `num_points` doubles as the read index.
    mean_state, _ = fun_mcmc.trace(
        state=mean_state,
        fn=lambda state: fun_mcmc.running_mean_step(  # pylint: disable=g-long-lambda
            state,
            SanitizedAutoCorrelation(x_arr.read(state.num_points),
                                     axis,
                                     max_lags=max_lags,
                                     **kwargs)),
        num_steps=x.shape[0],
        trace_fn=lambda *_: ())
    return mean_state.mean
Example #7
  def testWrapTransitionKernel(self):

    class TestKernel(tfp.mcmc.TransitionKernel):

      def one_step(self, current_state, previous_kernel_results):
        return [x + 1 for x in current_state], previous_kernel_results + 1

      def bootstrap_results(self, current_state):
        return sum(current_state)

      def is_calibrated(self):
        return True

    def kernel(state, pkr):
      return fun_mcmc.transition_kernel_wrapper(state, pkr, TestKernel())

    state = {'x': 0., 'y': 1.}
    kr = 1.
    (final_state, final_kr), _ = fun_mcmc.trace(
        (state, kr),
        kernel,
        2,
        trace_fn=lambda *args: (),
    )
    self.assertAllEqual({
        'x': 2.,
        'y': 3.
    }, util.map_tree(np.array, final_state))
    self.assertAllEqual(1. + 2., final_kr)
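The example above shows the interop path: `fun_mcmc.transition_kernel_wrapper` adapts an existing `tfp.mcmc.TransitionKernel` to the functional convention, packing the `(state, previous_kernel_results)` pair into the fun_mcmc state so that `fun_mcmc.trace` can drive the kernel's `one_step`.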
Example #8
  def testRunningVarianceMaxPoints(self):
    window_size = 100
    rng = np.random.RandomState(_test_seed())
    data = tf.convert_to_tensor(
        np.concatenate(
            [rng.randn(window_size), 1. + 2. * rng.randn(window_size * 10)],
            axis=0))

    def kernel(rvs, idx):
      rvs, _ = fun_mcmc.running_variance_step(
          rvs, data[idx], window_size=window_size)
      return (rvs, idx + 1), (rvs.mean, rvs.variance)

    _, (mean, var) = fun_mcmc.trace(
        state=(fun_mcmc.running_variance_init([], data.dtype), 0),
        fn=kernel,
        num_steps=len(data),
    )
    # Up to window_size, we compute the running mean/variance exactly.
    self.assertAllClose(np.mean(data[:window_size]), mean[window_size - 1])
    self.assertAllClose(np.var(data[:window_size]), var[window_size - 1])
    # After window_size, we're doing exponential moving average, and pick up the
    # mean/variance after the change in the distribution. Since the moving
    # average is computed only over ~window_size points, this test is rather
    # noisy.
    self.assertAllClose(1., mean[-1], atol=0.2)
    self.assertAllClose(4., var[-1], atol=0.8)
Example #9
  def testPreconditionedHMC(self):
    step_size = self._constant(0.2)
    num_steps = 2000
    num_leapfrog_steps = 10
    state = tf.ones([16, 2], dtype=self._dtype)

    base_mean = self._constant([1., 0])
    base_cov = self._constant([[1, 0.5], [0.5, 1]])

    bijector = tfp.bijectors.Softplus()
    base_dist = tfp.distributions.MultivariateNormalFullCovariance(
        loc=base_mean, covariance_matrix=base_cov)
    target_dist = bijector(base_dist)

    def orig_target_log_prob_fn(x):
      return target_dist.log_prob(x), ()

    target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
        orig_target_log_prob_fn, bijector, state)

    # pylint: disable=g-long-lambda
    def kernel(hmc_state, seed):
      if not self._is_on_jax:
        hmc_seed = _test_seed()
      else:
        hmc_seed, seed = util.split_seed(seed, 2)
      hmc_state, _ = fun_mcmc.hamiltonian_monte_carlo(
          hmc_state,
          step_size=step_size,
          num_integrator_steps=num_leapfrog_steps,
          target_log_prob_fn=target_log_prob_fn,
          seed=hmc_seed)
      return (hmc_state, seed), hmc_state.state_extra[0]

    if not self._is_on_jax:
      seed = _test_seed()
    else:
      seed = self._make_seed(_test_seed())

    # Subtle: Unlike TF, JAX needs a data dependency from the inputs to outputs
    # for the jit to do anything.
    _, chain = tf.function(lambda state, seed: fun_mcmc.trace(  # pylint: disable=g-long-lambda
        state=(fun_mcmc.hamiltonian_monte_carlo_init(state, target_log_prob_fn),
               seed),
        fn=kernel,
        num_steps=num_steps))(state, seed)
    # Discard the warmup samples.
    chain = chain[1000:]

    sample_mean = tf.reduce_mean(chain, axis=[0, 1])
    sample_cov = tfp.stats.covariance(chain, sample_axis=[0, 1])

    true_samples = target_dist.sample(4096, seed=self._make_seed(_test_seed()))

    true_mean = tf.reduce_mean(true_samples, axis=0)
    true_cov = tfp.stats.covariance(true_samples, sample_axis=0)

    self.assertAllClose(true_mean, sample_mean, rtol=0.1, atol=0.1)
    self.assertAllClose(true_cov, sample_cov, rtol=0.1, atol=0.1)
Example #10
  def testTraceTrace(self):

    def fun(x):
      return fun_mcmc.trace(x, lambda x: (x + 1., x + 1.), 2, trace_mask=False)

    x, trace = fun_mcmc.trace(0., fun, 2)
    self.assertAllEqual(4., x)
    self.assertAllEqual([2., 4.], trace)
Example #11
  def testTraceMask(self):

    def fun(x):
      return x + 1, (2 * x, 3 * x)

    x, (trace_1, trace_2) = fun_mcmc.trace(
        state=0, fn=fun, num_steps=3, trace_mask=(True, False))

    self.assertAllEqual(3, x)
    self.assertAllEqual([0, 2, 4], trace_1)
    self.assertAllEqual(6, trace_2)

    x, (trace_1, trace_2) = fun_mcmc.trace(
        state=0, fn=fun, num_steps=3, trace_mask=False)

    self.assertAllEqual(3, x)
    self.assertAllEqual(4, trace_1)
    self.assertAllEqual(6, trace_2)
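As the assertions above show, `trace_mask` mirrors the structure of the traced value: leaves marked `True` are stacked across all `num_steps` (`trace_1 == [0, 2, 4]`), while leaves marked `False` keep only the final step's value (`trace_2 == 6`); a bare `False` applies to every leaf.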
Example #12
  def testTraceSingle(self):

    def fun(x):
      return x + 1., 2 * x

    x, e_trace = fun_mcmc.trace(
        state=0., fn=fun, num_steps=5, trace_fn=lambda _, xp1: xp1)

    self.assertAllEqual(5., x)
    self.assertAllEqual([0., 2., 4., 6., 8.], e_trace)
Example #13
  def testTraceSingle(self):
    def fun(x):
      if x is None:
        x = 0.
      return x + 1., 2 * x

    x, e_trace = fun_mcmc.trace(
        state=None, fn=fun, num_steps=5, trace_fn=lambda _, xp1: xp1)

    self.assertAllEqual(5., x.numpy())
    self.assertAllEqual([0., 2., 4., 6., 8.], e_trace.numpy())
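Passing `state=None` defers initialization to `fn` itself: the kernel sees `None` on the first step only, substitutes `0.`, and thereafter receives the threaded state, reproducing the previous example's results.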
Example #14
  def testTraceNested(self):

    def fun(x, y):
      return (x + 1., y + 2.), ()

    (x, y), (x_trace, y_trace) = fun_mcmc.trace(
        state=(0., 0.), fn=fun, num_steps=5, trace_fn=lambda xy, _: xy)

    self.assertAllEqual(5., x)
    self.assertAllEqual(10., y)
    self.assertAllEqual([1., 2., 3., 4., 5.], x_trace)
    self.assertAllEqual([2., 4., 6., 8., 10.], y_trace)
Example #15
  def testBasicHMC(self):
    step_size = 0.2
    num_steps = 2000
    num_leapfrog_steps = 10
    state = tf.ones([16, 2])

    base_mean = tf.constant([2., 3.])
    base_scale = tf.constant([2., 0.5])

    def target_log_prob_fn(x):
      return -tf.reduce_sum(0.5 * tf.square(
          (x - base_mean) / base_scale), -1), ()

    def kernel(hmc_state, seed):
      if not self._is_on_jax:
        hmc_seed = _test_seed()
      else:
        hmc_seed, seed = util.split_seed(seed, 2)
      hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
          hmc_state,
          step_size=step_size,
          num_integrator_steps=num_leapfrog_steps,
          target_log_prob_fn=target_log_prob_fn,
          seed=hmc_seed)
      return (hmc_state, seed), hmc_extra

    if not self._is_on_jax:
      seed = _test_seed()
    else:
      seed = self._make_seed(_test_seed())

    # Subtle: Unlike TF, JAX needs a data dependency from the inputs to outputs
    # for the jit to do anything.
    _, chain = tf.function(lambda state, seed: fun_mcmc.trace(  # pylint: disable=g-long-lambda
        state=(fun_mcmc.hamiltonian_monte_carlo_init(state, target_log_prob_fn),
               seed),
        fn=kernel,
        num_steps=num_steps,
        trace_fn=lambda state, extra: state[0].state))(state, seed)
    # Discard the warmup samples.
    chain = chain[1000:]

    sample_mean = tf.reduce_mean(chain, axis=[0, 1])
    sample_var = tf.math.reduce_variance(chain, axis=[0, 1])

    true_samples = util.random_normal(
        shape=[4096, 2], dtype=tf.float32, seed=seed) * base_scale + base_mean

    true_mean = tf.reduce_mean(true_samples, axis=0)
    true_var = tf.math.reduce_variance(true_samples, axis=0)

    self.assertAllClose(true_mean, sample_mean, rtol=0.1, atol=0.1)
    self.assertAllClose(true_var, sample_var, rtol=0.1, atol=0.1)
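Tests like this one run under both backends, and the seed handling differs: under JAX the seed is split with `util.split_seed` and threaded through the loop state so that jit sees a data dependency from inputs to outputs (hence the comment above), while under TF a fresh stateful `_test_seed()` per step suffices.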
Example #16
    def computation(state, seed):
      bijector = tfp.bijectors.Softplus()
      base_dist = tfp.distributions.MultivariateNormalFullCovariance(
          loc=base_mean, covariance_matrix=base_cov)
      target_dist = bijector(base_dist)

      def orig_target_log_prob_fn(x):
        return target_dist.log_prob(x), ()

      target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
          orig_target_log_prob_fn, bijector, state)

      def kernel(hmc_state, step_size_state, step, seed):
        if not self._is_on_jax:
          hmc_seed = _test_seed()
        else:
          hmc_seed, seed = util.split_seed(seed, 2)
        hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
            hmc_state,
            step_size=tf.exp(step_size_state.state),
            num_integrator_steps=num_leapfrog_steps,
            target_log_prob_fn=target_log_prob_fn,
            seed=hmc_seed)

        rate = fun_mcmc.prefab._polynomial_decay(  # pylint: disable=protected-access
            step=step,
            step_size=self._constant(0.01),
            power=0.5,
            decay_steps=num_adapt_steps,
            final_step_size=0.)
        mean_p_accept = tf.reduce_mean(
            tf.exp(tf.minimum(self._constant(0.), hmc_extra.log_accept_ratio)))

        loss_fn = fun_mcmc.make_surrogate_loss_fn(
            lambda _: (0.9 - mean_p_accept, ()))
        step_size_state, _ = fun_mcmc.adam_step(
            step_size_state, loss_fn, learning_rate=rate)

        return ((hmc_state, step_size_state, step + 1, seed),
                (hmc_state.state_extra[0], hmc_extra.log_accept_ratio))

      _, (chain, log_accept_ratio_trace) = fun_mcmc.trace(
          state=(fun_mcmc.hamiltonian_monte_carlo_init(state,
                                                       target_log_prob_fn),
                 fun_mcmc.adam_init(tf.math.log(step_size)), 0, seed),
          fn=kernel,
          num_steps=num_adapt_steps + num_steps,
      )
      true_samples = target_dist.sample(
          4096, seed=self._make_seed(_test_seed()))
      return chain, log_accept_ratio_trace, true_samples
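The step size here is adapted while sampling: `make_surrogate_loss_fn` wraps the scalar `0.9 - mean_p_accept` as a surrogate loss, which `adam_step` then, in effect, uses as the gradient signal for the log step size, with the learning rate decaying polynomially to zero by `num_adapt_steps`, steering the chain toward a 90% acceptance rate.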
Example #17
  def testGradientDescent(self):
    def loss_fn(x, y):
      return tf.square(x - 1.) + tf.square(y - 2.), []

    _, [(x, y), loss] = fun_mcmc.trace(
        fun_mcmc.GradientDescentState([tf.zeros([]), tf.zeros([])]),
        lambda gd_state: fun_mcmc.gradient_descent_step(  # pylint: disable=g-long-lambda
            gd_state, loss_fn, learning_rate=0.01),
        num_steps=1000,
        trace_fn=lambda state, extra: [state.state, extra.loss])

    self.assertAllClose(1., x[-1], atol=1e-3)
    self.assertAllClose(2., y[-1], atol=1e-3)
    self.assertAllClose(0., loss[-1], atol=1e-3)
Example #18
    def testRandomWalkMetropolis(self):
        num_steps = 1000
        state = tf.ones([16], dtype=tf.int32)
        target_logits = tf.constant([1., 2., 3., 4.]) + 2.
        proposal_logits = tf.constant([4., 3., 2., 1.]) + 2.

        def target_log_prob_fn(x):
            return tf.gather(target_logits, x), ()

        def proposal_fn(x, seed):
            current_logits = tf.gather(proposal_logits, x)
            proposal = util.random_categorical(proposal_logits[tf.newaxis],
                                               x.shape[0], seed)[0]
            proposed_logits = tf.gather(proposal_logits, proposal)
            return tf.cast(proposal,
                           x.dtype), ((), proposed_logits - current_logits)

        def kernel(rwm_state, seed):
            if backend.get_backend() == backend.TENSORFLOW:
                rwm_seed = tfp_test_util.test_seed()
            else:
                rwm_seed, seed = util.split_seed(seed, 2)
            rwm_state, rwm_extra = fun_mcmc.random_walk_metropolis(
                rwm_state,
                target_log_prob_fn=target_log_prob_fn,
                proposal_fn=proposal_fn,
                seed=rwm_seed)
            return (rwm_state, seed), rwm_extra

        if backend.get_backend() == backend.TENSORFLOW:
            seed = tfp_test_util.test_seed()
        else:
            seed = self._make_seed(tfp_test_util.test_seed())

        # Subtle: Unlike TF, JAX needs a data dependency from the inputs to outputs
        # for the jit to do anything.
        _, chain = tf.function(lambda state, seed: fun_mcmc.trace(  # pylint: disable=g-long-lambda
            state=(fun_mcmc.random_walk_metropolis_init(
                state, target_log_prob_fn), seed),
            fn=kernel,
            num_steps=num_steps,
            trace_fn=lambda state, extra: state[0].state))(state, seed)
        # Discard the warmup samples.
        chain = chain[500:]

        sample_mean = tf.reduce_mean(tf.one_hot(chain, 4), axis=[0, 1])
        self.assertAllClose(tf.nn.softmax(target_logits),
                            sample_mean,
                            atol=0.1)
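Note the `proposal_fn` contract here: it returns the proposed state along with an `(extra, log_correction)` pair, where `proposed_logits - current_logits` is the log-density correction that `random_walk_metropolis` folds into the acceptance ratio for this asymmetric categorical proposal.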
Example #19
  def testPreconditionedHMC(self):
    step_size = 0.2
    num_steps = 2000
    num_leapfrog_steps = 10
    state = tf.ones([16, 2])

    base_mean = [1., 0]
    base_cov = [[1, 0.5], [0.5, 1]]

    bijector = tfb.Softplus()
    base_dist = tfd.MultivariateNormalFullCovariance(
        loc=base_mean, covariance_matrix=base_cov)
    target_dist = bijector(base_dist)

    def orig_target_log_prob_fn(x):
      return target_dist.log_prob(x), ()

    target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
        orig_target_log_prob_fn, bijector, state)

    # pylint: disable=g-long-lambda
    kernel = tf.function(lambda state: fun_mcmc.hamiltonian_monte_carlo(
        state,
        step_size=step_size,
        num_integrator_steps=num_leapfrog_steps,
        target_log_prob_fn=target_log_prob_fn,
        seed=_test_seed()))

    _, chain = fun_mcmc.trace(
        state=fun_mcmc.hamiltonian_monte_carlo_init(state, target_log_prob_fn),
        fn=kernel,
        num_steps=num_steps,
        trace_fn=lambda state, extra: state.state_extra[0])
    # Discard the warmup samples.
    chain = chain[1000:]

    sample_mean = tf.reduce_mean(chain, axis=[0, 1])
    sample_cov = tfp.stats.covariance(chain, sample_axis=[0, 1])

    true_samples = target_dist.sample(4096, seed=_test_seed())

    true_mean = tf.reduce_mean(true_samples, axis=0)
    true_cov = tfp.stats.covariance(true_samples, sample_axis=0)

    self.assertAllClose(true_mean, sample_mean, rtol=0.1, atol=0.1)
    self.assertAllClose(true_cov, sample_cov, rtol=0.1, atol=0.1)
Example #20
  def testAdam(self):

    def loss_fn(x, y):
      return tf.square(x - 1.) + tf.square(y - 2.), []

    _, [(x, y), loss] = fun_mcmc.trace(
        fun_mcmc.adam_init([self._constant(0.), self._constant(0.)]),
        lambda adam_state: fun_mcmc.adam_step(  # pylint: disable=g-long-lambda
            adam_state,
            loss_fn,
            learning_rate=self._constant(0.01)),
        num_steps=1000,
        trace_fn=lambda state, extra: [state.state, extra.loss])

    self.assertAllClose(1., x[-1], atol=1e-3)
    self.assertAllClose(2., y[-1], atol=1e-3)
    self.assertAllClose(0., loss[-1], atol=1e-3)
Example #21
    def computation(state):
      bijector = tfb.Softplus()
      base_dist = tfd.MultivariateNormalFullCovariance(
          loc=base_mean, covariance_matrix=base_cov)
      target_dist = bijector(base_dist)

      def orig_target_log_prob_fn(x):
        return target_dist.log_prob(x), ()

      target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
          orig_target_log_prob_fn, bijector, state)

      def kernel(hmc_state, step_size_state, step):
        hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
            hmc_state,
            step_size=tf.exp(step_size_state.state),
            num_integrator_steps=num_leapfrog_steps,
            target_log_prob_fn=target_log_prob_fn)

        rate = tf.compat.v1.train.polynomial_decay(
            0.01,
            global_step=step,
            power=0.5,
            decay_steps=num_adapt_steps,
            end_learning_rate=0.)
        mean_p_accept = tf.reduce_mean(
            tf.exp(tf.minimum(0., hmc_extra.log_accept_ratio)))

        loss_fn = fun_mcmc.make_surrogate_loss_fn(
            lambda _: (0.9 - mean_p_accept, ()))
        step_size_state, _ = fun_mcmc.adam_step(
            step_size_state, loss_fn, learning_rate=rate)

        return ((hmc_state, step_size_state, step + 1),
                (hmc_state.state_extra[0], hmc_extra.log_accept_ratio))

      _, (chain, log_accept_ratio_trace) = fun_mcmc.trace(
          state=(fun_mcmc.hamiltonian_monte_carlo_init(state,
                                                       target_log_prob_fn),
                 fun_mcmc.adam_init(tf.math.log(step_size)), 0),
          fn=kernel,
          num_steps=num_adapt_steps + num_steps,
      )
      true_samples = target_dist.sample(4096, seed=_test_seed())
      return chain, log_accept_ratio_trace, true_samples
Example #22
  def testRunningMean(self, shape, aggregation):
    rng = np.random.RandomState(_test_seed())
    data = tf.convert_to_tensor(rng.randn(*shape))

    def kernel(rms, idx):
      rms, _ = fun_mcmc.running_mean_step(rms, data[idx], axis=aggregation)
      return (rms, idx + 1), ()

    true_aggregation = (0,) + (() if aggregation is None else tuple(
        [a + 1 for a in util.flatten_tree(aggregation)]))
    true_mean = np.mean(data, true_aggregation)

    (rms, _), _ = fun_mcmc.trace(
        state=(fun_mcmc.running_mean_init(true_mean.shape, data.dtype), 0),
        fn=kernel,
        num_steps=len(data),
        trace_fn=lambda *args: ())

    self.assertAllClose(true_mean, rms.mean)
Example #23
        def computation(state):
            bijector = tfb.Softplus()
            base_dist = tfd.MultivariateNormalFullCovariance(
                loc=base_mean, covariance_matrix=base_cov)
            target_dist = bijector(base_dist)

            def orig_target_log_prob_fn(x):
                return target_dist.log_prob(x), ()

            target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
                orig_target_log_prob_fn, bijector, state)

            def kernel(hmc_state, step_size, step):
                hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
                    hmc_state,
                    step_size=step_size,
                    num_integrator_steps=num_leapfrog_steps,
                    target_log_prob_fn=target_log_prob_fn)

                rate = tf.compat.v1.train.polynomial_decay(
                    0.01,
                    global_step=step,
                    power=0.5,
                    decay_steps=num_adapt_steps,
                    end_learning_rate=0.)
                mean_p_accept = tf.reduce_mean(
                    tf.exp(tf.minimum(0., hmc_extra.log_accept_ratio)))
                step_size = fun_mcmc.sign_adaptation(step_size,
                                                     output=mean_p_accept,
                                                     set_point=0.9,
                                                     adaptation_rate=rate)

                return (hmc_state, step_size, step + 1), hmc_extra

            _, (chain, log_accept_ratio_trace) = fun_mcmc.trace(
                (fun_mcmc.HamiltonianMonteCarloState(state), step_size, 0),
                kernel,
                num_adapt_steps + num_steps,
                trace_fn=lambda state, extra:
                (state[0].state_extra[0], extra.log_accept_ratio))
            true_samples = target_dist.sample(4096,
                                              seed=tfp_test_util.test_seed())
            return chain, log_accept_ratio_trace, true_samples
Example #24
  def testPotentialScaleReduction(self, chain_shape, independent_chain_ndims):
    rng = np.random.RandomState(_test_seed())
    chain_means = rng.randn(*((1,) + chain_shape[1:])).astype(np.float32)
    chains = 0.4 * rng.randn(*chain_shape).astype(np.float32) + chain_means

    true_rhat = tfp.mcmc.potential_scale_reduction(
        chains, independent_chain_ndims=independent_chain_ndims)

    chains = tf.convert_to_tensor(chains)
    psrs, _ = fun_mcmc.trace(
        state=fun_mcmc.potential_scale_reduction_init(chain_shape[1:],
                                                      tf.float32),
        fn=lambda psrs: fun_mcmc.potential_scale_reduction_step(  # pylint: disable=g-long-lambda
            psrs, chains[psrs.num_points]),
        num_steps=chain_shape[0],
        trace_fn=lambda *_: ())

    running_rhat = fun_mcmc.potential_scale_reduction_extract(
        psrs, independent_chain_ndims=independent_chain_ndims)
    self.assertAllClose(true_rhat, running_rhat)
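The streaming estimator is checked directly against the batch `tfp.mcmc.potential_scale_reduction`: feeding one draw per step via `potential_scale_reduction_step`, indexed by `psrs.num_points`, reproduces the batch R-hat.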
Example #25
    def testRunningCovariance(self, shape, aggregation):
        data = tf.convert_to_tensor(np.random.randn(*shape))

        true_aggregation = (0, ) + (() if aggregation is None else tuple(
            [a + 1 for a in util.flatten_tree(aggregation)]))
        true_mean = np.mean(data, true_aggregation)
        true_cov = _gen_cov(data, true_aggregation)

        def kernel(rcs, idx):
            rcs, _ = fun_mcmc.running_covariance_step(rcs,
                                                      data[idx],
                                                      axis=aggregation)
            return (rcs, idx + 1), ()

        (rcs, _), _ = fun_mcmc.trace(state=(fun_mcmc.running_covariance_init(
            true_mean.shape, data[0].dtype), 0),
                                     fn=kernel,
                                     num_steps=len(data),
                                     trace_fn=lambda *args: ())
        self.assertAllClose(true_mean, rcs.mean)
        self.assertAllClose(true_cov, rcs.covariance)
Example #26
  def testRunningVariance(self, shape, aggregation):
    rng = np.random.RandomState(_test_seed())
    data = self._constant(rng.randn(*shape))

    true_aggregation = (0,) + (() if aggregation is None else tuple(
        [a + 1 for a in util.flatten_tree(aggregation)]))
    true_mean = np.mean(data, true_aggregation)
    true_var = np.var(data, true_aggregation)

    def kernel(rvs, idx):
      rvs, _ = fun_mcmc.running_variance_step(rvs, data[idx], axis=aggregation)
      return (rvs, idx + 1), ()

    (rvs, _), _ = fun_mcmc.trace(
        state=(fun_mcmc.running_variance_init(true_mean.shape,
                                              data[0].dtype), 0),
        fn=kernel,
        num_steps=len(data),
        trace_fn=lambda *args: ())
    self.assertAllClose(true_mean, rvs.mean)
    self.assertAllClose(true_var, rvs.variance)
Example #27
  def testSimpleDualAverages(self):

    def loss_fn(x, y):
      return tf.square(x - 1.) + tf.square(y - 2.), []

    def kernel(sda_state, rms_state):
      sda_state, _ = fun_mcmc.simple_dual_averages_step(sda_state, loss_fn, 1.)
      rms_state, _ = fun_mcmc.running_mean_step(rms_state, sda_state.state)
      return (sda_state, rms_state), rms_state.mean

    _, (x, y) = fun_mcmc.trace(
        (
            fun_mcmc.simple_dual_averages_init(
                [self._constant(0.), self._constant(0.)]),
            fun_mcmc.running_mean_init([[], []], [self._dtype, self._dtype]),
        ),
        kernel,
        num_steps=1000,
    )

    self.assertAllClose(1., x[-1], atol=1e-1)
    self.assertAllClose(2., y[-1], atol=1e-1)
Example #28
  def testRunningMeanMaxPoints(self):
    window_size = 100
    rng = np.random.RandomState(_test_seed())
    data = self._constant(
        np.concatenate(
            [rng.randn(window_size), 1. + 2. * rng.randn(window_size * 10)],
            axis=0))

    def kernel(rms, idx):
      rms, _ = fun_mcmc.running_mean_step(
          rms, data[idx], window_size=window_size)
      return (rms, idx + 1), rms.mean

    _, mean = fun_mcmc.trace(
        state=(fun_mcmc.running_mean_init([], data.dtype), 0),
        fn=kernel,
        num_steps=len(data),
    )
    # Up to window_size, we compute the running mean exactly.
    self.assertAllClose(np.mean(data[:window_size]), mean[window_size - 1])
    # After window_size, we're doing exponential moving average, and pick up the
    # mean after the change in the distribution. Since the moving average is
    # computed only over ~window_size points, this test is rather noisy.
    self.assertAllClose(1., mean[-1], atol=0.2)
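The behavior these `MaxPoints` tests rely on can be sketched in plain NumPy. Presumably (an assumption about the `window_size` semantics, consistent with the comments above), `running_mean_step` caps the effective sample count at `window_size`, so the exact running mean turns into an exponential moving average with gain `1 / window_size`:

    import numpy as np

    def capped_running_mean(xs, window_size):
      # Welford-style mean update with the count capped at window_size:
      # exact mean until the cap, EMA with gain 1/window_size afterwards.
      mean, num_points = 0., 0
      means = []
      for x in xs:
        num_points = min(num_points + 1, window_size)
        mean += (x - mean) / num_points
        means.append(mean)
      return np.array(means)

    xs = np.concatenate([np.zeros(100), np.ones(1000)])
    ms = capped_running_mean(xs, window_size=100)
    # ms[99] == 0. exactly; ms[-1] is within ~1e-4 of 1.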
Example #29
def trace_n(num_steps):
  return fun_mcmc.trace(0, lambda x: (x + 1, ()), num_steps)[0]