Example #1
 def testStateParts(self):
   with self.test_session() as sess:
     dist_x = normal_lib.Normal(loc=self.dtype(0), scale=self.dtype(1))
     dist_y = independent_lib.Independent(
         gamma_lib.Gamma(concentration=self.dtype([1, 2]),
                         rate=self.dtype([0.5, 0.75])),
         reinterpreted_batch_ndims=1)
     def target_log_prob(x, y):
       return dist_x.log_prob(x) + dist_y.log_prob(y)
     x0 = [dist_x.sample(seed=1), dist_y.sample(seed=2)]
     samples, _ = hmc.sample_chain(
         num_results=int(2e3),
         target_log_prob_fn=target_log_prob,
         current_state=x0,
         step_size=0.85,
         num_leapfrog_steps=3,
         num_burnin_steps=int(250),
         seed=49)
     actual_means = [math_ops.reduce_mean(s, axis=0) for s in samples]
     actual_vars = [_reduce_variance(s, axis=0) for s in samples]
     expected_means = [dist_x.mean(), dist_y.mean()]
     expected_vars = [dist_x.variance(), dist_y.variance()]
     [
         actual_means_,
         actual_vars_,
         expected_means_,
         expected_vars_,
     ] = sess.run([
         actual_means,
         actual_vars,
         expected_means,
         expected_vars,
     ])
     self.assertAllClose(expected_means_, actual_means_, atol=0.05, rtol=0.16)
     self.assertAllClose(expected_vars_, actual_vars_, atol=0., rtol=0.30)
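Note: `_reduce_variance` is a module-level helper of the original test file and is not shown in these excerpts. A minimal sketch of a definition consistent with how it is called here (a biased sample variance along `axis`; `reduce_mean` keyword names vary slightly across TF 1.x releases):

def _reduce_variance(x, axis=None):
  # Biased sample variance: mean squared deviation from the sample mean.
  sample_mean = math_ops.reduce_mean(x, axis, keepdims=True)
  return math_ops.reduce_mean(
      math_ops.squared_difference(x, sample_mean), axis)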
Example #2
 def testStateParts(self):
   with self.test_session(graph=ops.Graph()) as sess:
     dist_x = normal_lib.Normal(loc=self.dtype(0), scale=self.dtype(1))
     dist_y = independent_lib.Independent(
         gamma_lib.Gamma(concentration=self.dtype([1, 2]),
                         rate=self.dtype([0.5, 0.75])),
         reinterpreted_batch_ndims=1)
     def target_log_prob(x, y):
       return dist_x.log_prob(x) + dist_y.log_prob(y)
     x0 = [dist_x.sample(seed=1), dist_y.sample(seed=2)]
     samples, _ = hmc.sample_chain(
         num_results=int(2e3),
         target_log_prob_fn=target_log_prob,
         current_state=x0,
         step_size=0.85,
         num_leapfrog_steps=3,
         num_burnin_steps=int(250),
         seed=49)
     actual_means = [math_ops.reduce_mean(s, axis=0) for s in samples]
     actual_vars = [_reduce_variance(s, axis=0) for s in samples]
     expected_means = [dist_x.mean(), dist_y.mean()]
     expected_vars = [dist_x.variance(), dist_y.variance()]
     [
         actual_means_,
         actual_vars_,
         expected_means_,
         expected_vars_,
     ] = sess.run([
         actual_means,
         actual_vars,
         expected_means,
         expected_vars,
     ])
     self.assertAllClose(expected_means_, actual_means_, atol=0.05, rtol=0.16)
     self.assertAllClose(expected_vars_, actual_vars_, atol=0., rtol=0.25)
Example #3
  def testSampleChainSeedReproducibleWorksCorrectly(self):
    with self.test_session(graph=ops.Graph()) as sess:
      num_results = 10
      independent_chain_ndims = 1

      def log_gamma_log_prob(x):
        event_dims = math_ops.range(independent_chain_ndims,
                                    array_ops.rank(x))
        return self._log_gamma_log_prob(x, event_dims)

      kwargs = dict(
          target_log_prob_fn=log_gamma_log_prob,
          current_state=np.random.rand(4, 3, 2),
          step_size=0.1,
          num_leapfrog_steps=2,
          num_burnin_steps=150,
          seed=52,
      )

      # dict(kwargs, ...) copies the shared kwargs with per-call overrides.
      samples0, kernel_results0 = hmc.sample_chain(**dict(
          kwargs,
          num_results=2 * num_results,
          num_steps_between_results=0))

      samples1, kernel_results1 = hmc.sample_chain(**dict(
          kwargs,
          num_results=num_results,
          num_steps_between_results=1))

      [
          samples0_,
          samples1_,
          target_log_prob0_,
          target_log_prob1_,
      ] = sess.run([
          samples0,
          samples1,
          kernel_results0.current_target_log_prob,
          kernel_results1.current_target_log_prob,
      ])
      self.assertAllClose(samples0_[::2], samples1_,
                          atol=1e-5, rtol=1e-5)
      self.assertAllClose(target_log_prob0_[::2], target_log_prob1_,
                          atol=1e-5, rtol=1e-5)
Example #4
  def _chain_gets_correct_expectations(self, x, independent_chain_ndims,
                                       sess, feed_dict=None):
    counter = collections.Counter()
    def log_gamma_log_prob(x):
      counter["target_calls"] += 1
      event_dims = math_ops.range(independent_chain_ndims,
                                  array_ops.rank(x))
      return self._log_gamma_log_prob(x, event_dims)

    num_results = array_ops.placeholder(
        np.int32, [], name="num_results")
    step_size = array_ops.placeholder(
        np.float32, [], name="step_size")
    num_leapfrog_steps = array_ops.placeholder(
        np.int32, [], name="num_leapfrog_steps")

    if feed_dict is None:
      feed_dict = {}
    feed_dict.update({num_results: 150,
                      step_size: 0.05,
                      num_leapfrog_steps: 2})

    samples, kernel_results = hmc.sample_chain(
        num_results=num_results,
        target_log_prob_fn=log_gamma_log_prob,
        current_state=x,
        step_size=step_size,
        num_leapfrog_steps=num_leapfrog_steps,
        num_burnin_steps=150,
        seed=42)

    self.assertAllEqual(dict(target_calls=2), counter)

    expected_x = (math_ops.digamma(self._shape_param)
                  - np.log(self._rate_param))

    expected_exp_x = self._shape_param / self._rate_param

    log_accept_ratio_, samples_, expected_x_ = sess.run(
        [kernel_results.log_accept_ratio, samples, expected_x],
        feed_dict)

    actual_x = samples_.mean()
    actual_exp_x = np.exp(samples_).mean()
    acceptance_probs = np.exp(np.minimum(log_accept_ratio_, 0.))

    logging_ops.vlog(1, "True      E[x, exp(x)]: {}\t{}".format(
        expected_x_, expected_exp_x))
    logging_ops.vlog(1, "Estimated E[x, exp(x)]: {}\t{}".format(
        actual_x, actual_exp_x))
    self.assertNear(actual_x, expected_x_, 2e-2)
    self.assertNear(actual_exp_x, expected_exp_x, 2e-2)
    self.assertAllEqual(np.ones_like(acceptance_probs, np.bool),
                        acceptance_probs > 0.5)
    self.assertAllEqual(np.ones_like(acceptance_probs, np.bool),
                        acceptance_probs <= 1.)
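The `target_calls == 2` assertion checks graph-construction behavior: the Python closure runs when the graph is built, not once per sample, and `sample_chain` evidently traces the target twice (plausibly once while initializing the kernel results and once while building the sampling loop body). A standalone sketch of the same trace-time counting idea, assuming TF 1.x graph mode:

import tensorflow as tf

counter = {"calls": 0}

def body(x):
  counter["calls"] += 1  # Runs at graph-construction time, not per iteration.
  return x + 1.

_ = tf.while_loop(lambda x: x < 10., body, [tf.constant(0.)])
assert counter["calls"] == 1  # Traced once, regardless of iteration count.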
Example #5
 def _testChainWorksDtype(self, dtype):
   states, kernel_results = hmc.sample_chain(
       num_results=10,
       target_log_prob_fn=lambda x: -math_ops.reduce_sum(x**2., axis=-1),
       current_state=np.zeros(5).astype(dtype),
       step_size=0.01,
       num_leapfrog_steps=10,
       seed=48)
   with self.test_session() as sess:
     states_, acceptance_probs_ = sess.run(
         [states, kernel_results.acceptance_probs])
   self.assertEqual(dtype, states_.dtype)
   self.assertEqual(dtype, acceptance_probs_.dtype)
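This is a helper rather than a test case; the suite presumably drives it once per dtype with thin wrappers along these lines (method names here are illustrative, not taken from the excerpts):

def testChainWorksIn16Bit(self):
  self._testChainWorksDtype(np.float16)

def testChainWorksIn32Bit(self):
  self._testChainWorksDtype(np.float32)

def testChainWorksIn64Bit(self):
  self._testChainWorksDtype(np.float64)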
Example #6
  def _chain_gets_correct_expectations(self, x, independent_chain_ndims,
                                       sess, feed_dict=None):
    def log_gamma_log_prob(x):
      event_dims = math_ops.range(independent_chain_ndims,
                                  array_ops.rank(x))
      return self._log_gamma_log_prob(x, event_dims)

    num_results = array_ops.placeholder(
        np.int32, [], name="num_results")
    step_size = array_ops.placeholder(
        np.float32, [], name="step_size")
    num_leapfrog_steps = array_ops.placeholder(
        np.int32, [], name="num_leapfrog_steps")

    if feed_dict is None:
      feed_dict = {}
    feed_dict.update({num_results: 150,
                      step_size: 0.1,
                      num_leapfrog_steps: 2})

    samples, kernel_results = hmc.sample_chain(
        num_results=num_results,
        target_log_prob_fn=log_gamma_log_prob,
        current_state=x,
        step_size=step_size,
        num_leapfrog_steps=num_leapfrog_steps,
        num_burnin_steps=150,
        seed=42)

    expected_x = (math_ops.digamma(self._shape_param)
                  - np.log(self._rate_param))

    expected_exp_x = self._shape_param / self._rate_param

    acceptance_probs_, samples_, expected_x_ = sess.run(
        [kernel_results.acceptance_probs, samples, expected_x],
        feed_dict)

    actual_x = samples_.mean()
    actual_exp_x = np.exp(samples_).mean()

    logging_ops.vlog(1, "True      E[x, exp(x)]: {}\t{}".format(
        expected_x_, expected_exp_x))
    logging_ops.vlog(1, "Estimated E[x, exp(x)]: {}\t{}".format(
        actual_x, actual_exp_x))
    self.assertNear(actual_x, expected_x_, 2e-2)
    self.assertNear(actual_exp_x, expected_exp_x, 2e-2)
    self.assertTrue((acceptance_probs_ > 0.5).all())
    self.assertTrue((acceptance_probs_ <= 1.0).all())
Example #7
 def testChainWorksCorrelatedMultivariate(self):
   dtype = np.float32
   true_mean = dtype([0, 0])
   true_cov = dtype([[1, 0.5],
                     [0.5, 1]])
   num_results = 2000
   counter = collections.Counter()
   with self.test_session(graph=ops.Graph()) as sess:
     def target_log_prob(x, y):
       counter["target_calls"] += 1
       # Corresponds to unnormalized MVN.
       # z = matmul(inv(chol(true_cov)), [x, y] - true_mean)
       z = array_ops.stack([x, y], axis=-1) - true_mean
       z = array_ops.squeeze(
           gen_linalg_ops.matrix_triangular_solve(
               np.linalg.cholesky(true_cov),
               z[..., array_ops.newaxis]),
           axis=-1)
       return -0.5 * math_ops.reduce_sum(z**2., axis=-1)
     states, _ = hmc.sample_chain(
         num_results=num_results,
         target_log_prob_fn=target_log_prob,
         current_state=[dtype(-2), dtype(2)],
         step_size=[0.5, 0.5],
         num_leapfrog_steps=2,
         num_burnin_steps=200,
         num_steps_between_results=1,
         seed=54)
     self.assertAllEqual(dict(target_calls=2), counter)
     states = array_ops.stack(states, axis=-1)
     self.assertEqual(num_results, states.shape[0].value)
     sample_mean = math_ops.reduce_mean(states, axis=0)
     x = states - sample_mean
     sample_cov = math_ops.matmul(x, x, transpose_a=True) / dtype(num_results)
     [sample_mean_, sample_cov_] = sess.run([
         sample_mean, sample_cov])
     self.assertAllClose(true_mean, sample_mean_,
                         atol=0.05, rtol=0.)
     self.assertAllClose(true_cov, sample_cov_,
                         atol=0., rtol=0.1)
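The triangular solve in `target_log_prob` is the standard way to get the MVN quadratic form: with `true_cov = chol @ chol.T`, the vector `z = inv(chol) @ (state - true_mean)` satisfies `sum(z**2) == (state - true_mean) @ inv(true_cov) @ (state - true_mean)`. A quick standalone NumPy check of that identity (not part of the test):

import numpy as np

true_cov = np.array([[1., 0.5], [0.5, 1.]])
chol = np.linalg.cholesky(true_cov)  # true_cov == chol @ chol.T
xy = np.array([-2., 2.])             # the deviation from the (zero) mean
z = np.linalg.solve(chol, xy)        # z = inv(chol) @ xy
assert np.allclose(z @ z, xy @ np.linalg.solve(true_cov, xy))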
Example #8
    def testKernelResultsUsingTruncatedDistribution(self):
        def log_prob(x):
            return array_ops.where(
                x >= 0.,
                -x - x**2,  # Non-constant gradient.
                array_ops.fill(x.shape, math_ops.cast(-np.inf, x.dtype)))

        # This log_prob has the property that it is likely to attract
        # the HMC flow toward, and below, zero...but for x <= 0,
        # log_prob(x) = -inf, which should result in rejection, as well
        # as a non-finite log_prob.  Thus, this distribution gives us an
        # opportunity to test the kernel results' ability to correctly
        # capture rejections due to finite AND non-finite reasons.
        # Why use a non-constant gradient?  This ensures the leapfrog
        # integrator will not be exact.

        num_results = 1000
        # A large step size will give rejections due to integration error, in
        # addition to rejections from stepping into the log_prob = -inf region.
        step_size = 0.1
        num_leapfrog_steps = 5
        num_chains = 2

        with self.test_session(graph=ops.Graph()) as sess:

            # Start multiple independent chains.
            initial_state = ops.convert_to_tensor([0.1] * num_chains)

            states, kernel_results = hmc.sample_chain(
                num_results=num_results,
                target_log_prob_fn=log_prob,
                current_state=initial_state,
                step_size=step_size,
                num_leapfrog_steps=num_leapfrog_steps,
                seed=42)

            states_, kernel_results_ = sess.run([states, kernel_results])
            pstates_ = kernel_results_.proposed_state

            neg_inf_mask = np.isneginf(
                kernel_results_.proposed_target_log_prob)

            # First:  Test that the mathematical properties of the above log prob
            # function in conjunction with HMC show up as expected in kernel_results_.

            # We better have log_prob = -inf some of the time.
            self.assertLess(0, neg_inf_mask.sum())
            # We better have some rejections due to something other than -inf.
            self.assertLess(neg_inf_mask.sum(),
                            (~kernel_results_.is_accepted).sum())
            # We better have accepted a decent amount, even near the end of
            # the chain, or else this HMC run just got stuck at some point.
            self.assertLess(
                0.1,
                kernel_results_.is_accepted[int(0.9 * num_results):].mean())
            # We better not have any NaNs in proposed state or log_prob.
            # We may have some NaN in grads, which involve multiplication/addition due
            # to gradient rules.  This is the known "NaN grad issue with tf.where."
            self.assertAllEqual(
                np.zeros_like(states_),
                np.isnan(kernel_results_.proposed_target_log_prob))
            self.assertAllEqual(np.zeros_like(states_), np.isnan(states_))
            # We better not have any +inf in states, grads, or log_prob.
            self.assertAllEqual(
                np.zeros_like(states_),
                np.isposinf(kernel_results_.proposed_target_log_prob))
            self.assertAllEqual(
                np.zeros_like(states_),
                np.isposinf(kernel_results_.proposed_grads_target_log_prob[0]))
            self.assertAllEqual(np.zeros_like(states_), np.isposinf(states_))

            # Second:  Test that kernel_results is congruent with itself and
            # acceptance/rejection of states.

            # Proposed state is negative iff proposed target log prob is -inf.
            np.testing.assert_array_less(pstates_[neg_inf_mask], 0.)
            np.testing.assert_array_less(0., pstates_[~neg_inf_mask])

            # Acceptance probs are zero whenever proposed state is negative.
            self.assertAllEqual(np.zeros_like(pstates_[neg_inf_mask]),
                                kernel_results_.acceptance_probs[neg_inf_mask])

            # The move is accepted ==> state = proposed state.
            self.assertAllEqual(
                states_[kernel_results_.is_accepted],
                pstates_[kernel_results_.is_accepted],
            )
            # The move was rejected <==> state[t] == state[t - 1].
            for t in range(1, num_results):
                for i in range(num_chains):
                    if kernel_results_.is_accepted[t, i]:
                        self.assertNotEqual(states_[t, i], states_[t - 1, i])
                    else:
                        self.assertEqual(states_[t, i], states_[t - 1, i])
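This example reads `kernel_results.acceptance_probs` directly; the next one targets a version of the kernel results that exposes `log_accept_ratio` instead, from which the probabilities are recovered with a clipped exponential (the same one-liner appears in the code below):

acceptance_probs = np.exp(np.minimum(kernel_results_.log_accept_ratio, 0.))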
Example #9
  def testKernelResultsUsingTruncatedDistribution(self):
    def log_prob(x):
      return array_ops.where(
          x >= 0.,
          -x - x**2,  # Non-constant gradient.
          array_ops.fill(x.shape, math_ops.cast(-np.inf, x.dtype)))
    # This log_prob has the property that it is likely to attract
    # the flow toward, and below, zero...but for x <= 0,
    # log_prob(x) = -inf, which should result in rejection, as well
    # as a non-finite log_prob.  Thus, this distribution gives us an
    # opportunity to test the kernel results' ability to correctly
    # capture rejections due to finite AND non-finite reasons.
    # Why use a non-constant gradient?  This ensures the leapfrog
    # integrator will not be exact.

    num_results = 1000
    # A large step size will give rejections due to integration error, in
    # addition to rejections from stepping into the log_prob = -inf region.
    step_size = 0.1
    num_leapfrog_steps = 5
    num_chains = 2

    with self.test_session(graph=ops.Graph()) as sess:

      # Start multiple independent chains.
      initial_state = ops.convert_to_tensor([0.1] * num_chains)

      states, kernel_results = hmc.sample_chain(
          num_results=num_results,
          target_log_prob_fn=log_prob,
          current_state=initial_state,
          step_size=step_size,
          num_leapfrog_steps=num_leapfrog_steps,
          seed=42)

      states_, kernel_results_ = sess.run([states, kernel_results])
      pstates_ = kernel_results_.proposed_state

      neg_inf_mask = np.isneginf(kernel_results_.proposed_target_log_prob)

      # First:  Test that the mathematical properties of the above log prob
      # function in conjunction with HMC show up as expected in kernel_results_.

      # We better have log_prob = -inf some of the time.
      self.assertLess(0, neg_inf_mask.sum())
      # We better have some rejections due to something other than -inf.
      self.assertLess(neg_inf_mask.sum(), (~kernel_results_.is_accepted).sum())
      # We better have accepted a decent amount, even near the end of the chain.
      self.assertLess(
          0.1, kernel_results_.is_accepted[int(0.9 * num_results):].mean())
      # We better not have any NaNs in states or log_prob.
      # We may have some NaN in grads, which involve multiplication/addition due
      # to gradient rules.  This is the known "NaN grad issue with tf.where."
      self.assertAllEqual(np.zeros_like(states_),
                          np.isnan(kernel_results_.proposed_target_log_prob))
      self.assertAllEqual(np.zeros_like(states_),
                          np.isnan(states_))
      # We better not have any +inf in states, grads, or log_prob.
      self.assertAllEqual(np.zeros_like(states_),
                          np.isposinf(kernel_results_.proposed_target_log_prob))
      self.assertAllEqual(
          np.zeros_like(states_),
          np.isposinf(kernel_results_.proposed_grads_target_log_prob[0]))
      self.assertAllEqual(np.zeros_like(states_),
                          np.isposinf(states_))

      # Second:  Test that kernel_results is congruent with itself and
      # acceptance/rejection of states.

      # Proposed state is negative iff proposed target log prob is -inf.
      np.testing.assert_array_less(pstates_[neg_inf_mask], 0.)
      np.testing.assert_array_less(0., pstates_[~neg_inf_mask])

      # Acceptance probs are zero whenever proposed state is negative.
      acceptance_probs = np.exp(np.minimum(
          kernel_results_.log_accept_ratio, 0.))
      self.assertAllEqual(
          np.zeros_like(pstates_[neg_inf_mask]),
          acceptance_probs[neg_inf_mask])

      # The move is accepted ==> state = proposed state.
      self.assertAllEqual(
          states_[kernel_results_.is_accepted],
          pstates_[kernel_results_.is_accepted],
      )
      # The move was rejected <==> state[t] == state[t - 1].
      for t in range(1, num_results):
        for i in range(num_chains):
          if kernel_results_.is_accepted[t, i]:
            self.assertNotEqual(states_[t, i], states_[t - 1, i])
          else:
            self.assertEqual(states_[t, i], states_[t - 1, i])
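The accept/reject bookkeeping loop above can also be expressed as one vectorized check; a sketch under the same assumptions (`states_` has shape `[num_results, num_chains]` and a rejected step repeats the previous state exactly):

moved = states_[1:] != states_[:-1]  # True wherever the chain changed state.
np.testing.assert_array_equal(moved, kernel_results_.is_accepted[1:])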