def testLog1p(self):
  """Checks that excess_wrong_bits flags a naive log(1 + x) but not log1p.

  At x = 1e-7 in float32, the naive formulation is expected to lose more than
  21 bits beyond what conditioning explains, while `tf.math.log1p` loses none.
  """
  def naive_log1p(x):
    # The (1 + x) - 1 round trip is deliberate extra indirection so that
    # Grappler in TF1 does not fold the expression away.
    round_tripped = (1. + x) - 1.
    return tf.math.log(round_tripped + 1.)

  tiny = tf.constant(1e-7, dtype=tf.float32)

  naive_bits = self.evaluate(nt.excess_wrong_bits(naive_log1p, tiny))
  self.assertGreater(naive_bits, 21)

  builtin_bits = self.evaluate(nt.excess_wrong_bits(tf.math.log1p, tiny))
  self.assertLessEqual(builtin_bits, 0)
def testLogProbAccuracy(self, dist_name, data):
  """Checks that `log_prob` of a drawn distribution is numerically accurate.

  Draws a distribution of family `dist_name` with empty batch shape from
  `data`, samples from it (discarding NaN floating-point samples), and asserts
  that `nt.excess_wrong_bits` of `log_prob` at that sample is below 20.
  Skipped under TF1.
  """
  self.skip_if_tf1()
  dist_strategy = dhps.distributions(
      dist_name=dist_name,
      # Accuracy tools can't handle batches (yet?)
      batch_shape=(),
      # Variables presumably do not affect the numerics
      enable_vars=False,
      # Checking that samples pass validations (including in 64-bit
      # arithmetic) is left for another test
      validate_args=False)
  dist = tfe.as_composite(data.draw(dist_strategy))

  seed = test_util.test_seed(sampler_type='stateless')
  with tfp_hps.no_tf_rank_errors():
    sample = dist.sample(seed=seed)
  if sample.dtype.is_floating:
    # Only floating dtypes can hold NaN; reject such draws rather than fail.
    sample_is_nan_free = tf.reduce_all(~tf.math.is_nan(sample))
    hp.assume(self.evaluate(sample_is_nan_free))
  hp.note('Testing on sample {}'.format(sample))

  # Flatten the distribution into its component tensors so the accuracy tool
  # can treat every parameter as an input of the function under test.
  flat_params = tf.nest.flatten(dist, expand_composites=True)

  def log_prob_of(tensors, x):
    rebuilt = tf.nest.pack_sequence_as(dist, tensors, expand_composites=True)
    return rebuilt.log_prob(x)

  with tfp_hps.finite_ground_truth_only():
    excess_bits = nt.excess_wrong_bits(log_prob_of, flat_params, sample)
  # TODO(axch): Lower the acceptable badness to 4, which corresponds
  # to slightly better accuracy than 1e-6 relative error for
  # well-conditioned functions.
  self.assertAllLess(excess_bits, 20)
def testCatastrophicCancellation(self):
  """Checks that the numerics tools detect catastrophic cancellation.

  `(1 + x) - 1` is mathematically the identity, which has condition number 1.
  At x = 1e-6 in float32, its relative error is asserted to exceed 4%. The
  diagnostics should therefore attribute almost none of that error to
  ill-conditioning and report many excess wrong bits.
  """
  def cancelling_identity(x):
    return (1. + x) - 1.

  tiny = tf.constant(1e-6, dtype=tf.float32)

  # The implementation's relative error is terrible...
  rel_err = self.evaluate(nt.relative_error_at(cancelling_identity, tiny))
  self.assertGreater(rel_err, 0.04)

  # ...even though the function itself is well-conditioned...
  cond_numbers = self.evaluate(
      nt.inputwise_condition_numbers(cancelling_identity, tiny))
  self.assertAllEqual(cond_numbers, [1.])

  # ...so the error attributable to poor conditioning is tiny...
  cond_err = self.evaluate(
      nt.error_due_to_ill_conditioning(cancelling_identity, tiny))
  self.assertLess(cond_err, 1e-13)

  # ...which leaves a lot of excess wrong bits.
  extra_bits = self.evaluate(nt.excess_wrong_bits(cancelling_identity, tiny))
  self.assertGreater(extra_bits, 19)