    def testLog1p(self):
        # Check that we can tell that log1p is a good idea

        def bad_log1p(x):
            # Extra indirection to fool Grappler in TF1
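            # In float32, 1. + x for the x = 1e-7 input used below rounds to
            # 1 + 2**-23, and the subsequent -1. and +1. are exact, so this
            # computes log(1 + 2**-23) ~= 1.19e-7 instead of ~1e-7: roughly
            # 19% relative error, i.e. nearly all 24 significand bits wrong.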
            return tf.math.log(((1. + x) - 1.) + 1.)

        small = tf.constant(1e-7, dtype=tf.float32)
        self.assertGreater(
            self.evaluate(nt.excess_wrong_bits(bad_log1p, small)), 21)

        self.assertLessEqual(
            self.evaluate(nt.excess_wrong_bits(tf.math.log1p, small)), 0)
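
    # `dist_name` and `data` below are presumably supplied by test decorators
    # (e.g. absl @parameterized plus hypothesis @hp.given(hps.data())), which
    # are not shown in this excerpt.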
    def testLogProbAccuracy(self, dist_name, data):
        self.skip_if_tf1()
        dist = tfe.as_composite(
            data.draw(
                dhps.distributions(
                    dist_name=dist_name,
                    # Accuracy tools can't handle batches (yet?)
                    batch_shape=(),
                    # Variables presumably do not affect the numerics
                    enable_vars=False,
                    # Checking that samples pass validations (including in 64-bit
                    # arithmetic) is left for another test
                    validate_args=False)))
        seed = test_util.test_seed(sampler_type='stateless')
        with tfp_hps.no_tf_rank_errors():
            sample = dist.sample(seed=seed)
        if sample.dtype.is_floating:
            hp.assume(self.evaluate(tf.reduce_all(~tf.math.is_nan(sample))))
        hp.note('Testing on sample {}'.format(sample))

        as_tensors = tf.nest.flatten(dist, expand_composites=True)
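        # `as_composite` above wraps `dist` as a CompositeTensor, which is
        # what makes this `tf.nest.flatten` / `pack_sequence_as` round trip
        # possible; presumably the point is to expose the distribution's
        # parameter Tensors as explicit inputs for the accuracy tool.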

        def log_prob_function(tensors, x):
            dist_ = tf.nest.pack_sequence_as(dist,
                                             tensors,
                                             expand_composites=True)
            return dist_.log_prob(x)

        with tfp_hps.finite_ground_truth_only():
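            # `finite_ground_truth_only` presumably discards samples whose
            # 64-bit reference log_prob is non-finite, so `excess_wrong_bits`
            # has a meaningful ground truth to compare against.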
            badness = nt.excess_wrong_bits(log_prob_function, as_tensors,
                                           sample)
        # TODO(axch): Lower the acceptable badness to 4, which corresponds
        # to slightly better accuracy than 1e-6 relative error for
        # well-conditioned functions.
        self.assertAllLess(badness, 20)

    def testCatastrophicCancellation(self):
        # Check that we can detect catastrophic cancellations

        def bad_identity(x):
            return (1. + x) - 1.

        # The relative error in this implementation of the identity
        # function is terrible.
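        # Concretely: float32 rounds 1. + 1e-6 to 1 + 8 * 2**-23, and the
        # subtraction is exact, so bad_identity returns ~9.537e-7, about
        # 4.6% away from the true value of 1e-6.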
        small = tf.constant(1e-6, dtype=tf.float32)
        self.assertGreater(
            self.evaluate(nt.relative_error_at(bad_identity, small)), 0.04)

        # But the function itself is well-conditioned
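        # (the condition number |x * f'(x) / f(x)| of the identity is
        # exactly 1 everywhere)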
        self.assertAllEqual(
            self.evaluate(nt.inputwise_condition_numbers(bad_identity, small)),
            [1.])

        # So the error due to poor conditioning is small
        self.assertLess(
            self.evaluate(nt.error_due_to_ill_conditioning(
                bad_identity, small)), 1e-13)

        # And we have a lot of excess wrong bits
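        # (with ~4.6% relative error and condition number 1, roughly
        # 24 + log2(0.046) ~= 19.6 of float32's 24 significand bits are
        # excess wrong bits)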
        self.assertGreater(
            self.evaluate(nt.excess_wrong_bits(bad_identity, small)), 19)
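
    def testWellBehavedIdentity(self):
        # A minimal companion sketch (not from the original suite): a direct
        # identity involves no cancellation, so, assuming the semantics of
        # `nt.excess_wrong_bits` exercised above, it should report no excess
        # wrong bits at all.
        small = tf.constant(1e-6, dtype=tf.float32)
        self.assertLessEqual(
            self.evaluate(nt.excess_wrong_bits(tf.identity, small)), 0)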