Example #1
0
    def testHandlesNanFromKinetic(self):
        """Checks the log-acceptance correction and its gradient for non-finite momenta.

        Builds a 4x4 grid of (current, proposed) momentum pairs drawn from
        [1, inf, -inf, nan]. Only the all-finite pair should yield a finite
        correction; every other pair must yield -inf (certain rejection), and
        gradients w.r.t. non-finite momenta must be nan.
        """
        x = [1, np.inf, -np.inf, np.nan]
        # meshgrid produces every (current, proposed) combination; each 4x4
        # grid is flattened to a [16, 1] column so each row is one "chain".
        momentums, proposed_momentums = [[np.reshape(self.dtype(x), [-1, 1])]
                                         for x in np.meshgrid(x, x)]
        num_chains = len(momentums[0])

        momentums = [tf.convert_to_tensor(momentums[0])]
        proposed_momentums = [tf.convert_to_tensor(proposed_momentums[0])]

        log_acceptance_correction = _compute_log_acceptance_correction(
            momentums, proposed_momentums, independent_chain_ndims=1)
        grads = tf.gradients(log_acceptance_correction, momentums)

        [actual_log_acceptance_correction,
         grads_] = self.evaluate([log_acceptance_correction, grads])

        # Ensure log_acceptance_correction is `-inf` in the non-finite cases
        # and finite (zero) for the single all-finite chain.
        expected_log_acceptance_correction = -(self.dtype([0] + [np.inf] *
                                                          (num_chains - 1)))
        self.assertAllEqual(expected_log_acceptance_correction,
                            actual_log_acceptance_correction)

        # Ensure gradient is finite for the all-finite column.
        # NOTE: `np.bool` was removed in NumPy 1.24; use the builtin `bool`.
        g = grads_[0].reshape([len(x), len(x)])[:, 0]
        self.assertAllEqual(np.ones_like(g).astype(bool), np.isfinite(g))

        # The remaining gradients are nan because the momentum was itself nan or
        # inf.
        g = grads_[0].reshape([len(x), len(x)])[:, 1:]
        self.assertAllEqual(np.ones_like(g).astype(bool), np.isnan(g))
Example #2
0
  def testHandlesNanFromKinetic(self):
    """Checks the log-acceptance correction and its gradient for non-finite momenta.

    Builds a 4x4 grid of (current, proposed) momentum pairs drawn from
    [1, inf, -inf, nan]. Only the all-finite pair should yield a finite
    correction; every other pair must yield -inf (certain rejection), and
    gradients w.r.t. non-finite momenta must be nan.
    """
    x = [1, np.inf, -np.inf, np.nan]
    # meshgrid produces every (current, proposed) combination; each 4x4
    # grid is flattened to a [16, 1] column so each row is one "chain".
    momentums, proposed_momentums = [
        [np.reshape(self.dtype(x), [-1, 1])]
        for x in np.meshgrid(x, x)]
    num_chains = len(momentums[0])

    momentums = [tf.convert_to_tensor(momentums[0])]
    proposed_momentums = [tf.convert_to_tensor(proposed_momentums[0])]

    log_acceptance_correction = _compute_log_acceptance_correction(
        momentums,
        proposed_momentums,
        independent_chain_ndims=1)
    grads = tf.gradients(log_acceptance_correction, momentums)

    [actual_log_acceptance_correction, grads_] = self.evaluate([
        log_acceptance_correction, grads])

    # Ensure log_acceptance_correction is `-inf` in the non-finite cases
    # and finite (zero) for the single all-finite chain.
    expected_log_acceptance_correction = -(
        self.dtype([0] + [np.inf]*(num_chains - 1)))
    self.assertAllEqual(expected_log_acceptance_correction,
                        actual_log_acceptance_correction)

    # Ensure gradient is finite for the all-finite column.
    # NOTE: `np.bool` was removed in NumPy 1.24; use the builtin `bool`.
    g = grads_[0].reshape([len(x), len(x)])[:, 0]
    self.assertAllEqual(np.ones_like(g).astype(bool), np.isfinite(g))

    # The remaining gradients are nan because the momentum was itself nan or
    # inf.
    g = grads_[0].reshape([len(x), len(x)])[:, 1:]
    self.assertAllEqual(np.ones_like(g).astype(bool), np.isnan(g))