Example #1
0
  def testListOfStatesWhereFirstPassesSecondFails(self):
    """Minimal example of the API with a list of two states; read it first.

    The first state consists of well-mixed chains and should pass the R-hat
    check. The second state's chains are offset from one another and should
    fail it.
    """
    num_draws = 1000

    # Passing state: two scalar chains of iid Normal(0, 1) draws.
    converged = rng.randn(num_draws, 2)

    # Failing state: three 4-variate chains of Normal(0, 1) draws. Each chain
    # is then shifted by its own constant (+1, -1, +2 via broadcasting over
    # the chain axis), so the chains disagree and the check should fail.
    per_chain_shift = np.array([[1.], [-1.], [2.]])  # shape (3, 1)
    diverged = rng.randn(num_draws, 3, 4) + per_chain_shift

    rhats = mcmc_diagnostics.potential_scale_reduction(
        chains_states=[converged, diverged], independent_chain_ndims=1)

    # A list of states in gives a list of R-hat results out.
    self.assertIsInstance(rhats, list)
    with self.test_session() as session:
      converged_rhat, diverged_rhat = session.run(rhats)

    # Scalar state -> scalar R-hat, within 2% of 1 (the check passes).
    self.assertAllEqual((), converged_rhat.shape)
    self.assertAllClose(1., converged_rhat, rtol=0.02)

    # 4-variate state -> four R-hat values, every one above 1.2 (it fails).
    self.assertAllEqual((4,), diverged_rhat.shape)
    self.assertAllEqual(
        np.ones(diverged_rhat.shape, dtype=bool), diverged_rhat > 1.2)
Example #2
0
    def check_results(self, state_, independent_chain_shape, should_pass):
        """Compute R-hat for `state_` and assert it passes or fails as expected.

        Args:
            state_: NumPy array of chain states. The assertions below treat
                its leading `1 + len(independent_chain_shape)` dimensions as
                the ones reduced away by `potential_scale_reduction`.
            independent_chain_shape: Only its length is used; it becomes
                `independent_chain_ndims`.
            should_pass: If true, every R-hat entry must be within a 2%
                relative tolerance of 1. Otherwise every entry must be
                strictly greater than 1.2.
        """
        chain_ndims = len(independent_chain_shape)
        # One sample dimension sits in front of the chain dimensions.
        reduced_ndims = 1 + chain_ndims
        # With `use_static_shape` off, the placeholder is given shape=None,
        # so the op must cope with an unknown static input shape.
        static_shape = state_.shape if self.use_static_shape else None
        with self.test_session():
            state = array_ops.placeholder_with_default(
                input=state_, shape=static_shape)
            rhat = mcmc_diagnostics.potential_scale_reduction(
                state, independent_chain_ndims=chain_ndims)

            if self.use_static_shape:
                # The static output shape should be the input shape with the
                # reduced leading dimensions dropped.
                self.assertAllEqual(state_.shape[reduced_ndims:], rhat.shape)

            rhat_ = rhat.eval()
            if should_pass:
                self.assertAllClose(
                    np.ones_like(rhat_), rhat_, atol=0, rtol=0.02)
                return
            self.assertAllEqual(
                np.ones(rhat_.shape, dtype=bool), rhat_ > 1.2)
  def check_results(self, state_, independent_chain_shape, should_pass):
    """Runs `potential_scale_reduction` on `state_` and checks the verdict.

    Args:
      state_: NumPy array of chain states. When the static shape is used,
        the result's shape is expected to equal `state_.shape` minus its
        first `1 + len(independent_chain_shape)` dimensions.
      independent_chain_shape: Sequence whose length is passed on as
        `independent_chain_ndims`; its values are not used.
      should_pass: Whether R-hat should indicate convergence. True requires
        every entry to be within 2% (relative) of 1; False requires every
        entry to exceed 1.2.
    """
    n_chain_dims = len(independent_chain_shape)
    if self.use_static_shape:
      placeholder_shape = state_.shape
    else:
      # shape=None hides the static shape from the op under test.
      placeholder_shape = None

    with self.test_session():
      state = array_ops.placeholder_with_default(
          input=state_, shape=placeholder_shape)
      rhat = mcmc_diagnostics.potential_scale_reduction(
          state, independent_chain_ndims=n_chain_dims)

      if self.use_static_shape:
        # Drop the single sample dimension plus the chain dimensions.
        expected_shape = state_.shape[1 + n_chain_dims:]
        self.assertAllEqual(expected_shape, rhat.shape)

      rhat_ = rhat.eval()
      if not should_pass:
        all_true = np.ones_like(rhat_).astype(bool)
        self.assertAllEqual(all_true, rhat_ > 1.2)
      else:
        self.assertAllClose(np.ones_like(rhat_), rhat_, atol=0, rtol=0.02)
Example #4
0
 def testIndependentNdimsLessThanOneRaises(self):
   """`independent_chain_ndims=0` must raise a ValueError.

   The error message is required to mention `independent_chain_ndims`.
   """
   # `assertRaisesRegexp` is a deprecated unittest alias that Python 3.12
   # removed. Use `assertRaisesRegex` when it exists, and fall back to the
   # old name only on interpreters that lack it (e.g. Python 2.7).
   assert_raises_regex = (getattr(self, "assertRaisesRegex", None) or
                          self.assertRaisesRegexp)
   with assert_raises_regex(ValueError, "independent_chain_ndims"):
     mcmc_diagnostics.potential_scale_reduction(
         rng.rand(2, 3, 4), independent_chain_ndims=0)