def testDistribution(self, dist_name, data):
  """Checks `log_prob` dtypes for a drawn distribution in f32 and in f64.

  Skips the test when `dist_name` is in `WORKING_PRECISION_TEST_BLOCK_LIST`.
  Otherwise draws a distribution, takes 5 samples, and asserts that the
  samples are `float32` or `int32` and that `log_prob` of them is `float32`.
  The distribution is then converted with `tensor_to_f64` and `log_prob` of
  the converted samples is asserted to be `float64`.

  Args:
    dist_name: Name of the distribution to draw; forwarded to
      `dhps.distributions`.
    data: Hypothesis data object used to draw the distribution.
  """
  if dist_name in WORKING_PRECISION_TEST_BLOCK_LIST:
    self.skipTest('{} is blocked'.format(dist_name))

  # The filter is also handed to the strategy; presumably it applies to
  # distributions other than the top-level `dist_name` -- confirm in `dhps`.
  strategy = dhps.distributions(
      dist_name=dist_name,
      eligibility_filter=(
          lambda name: name not in WORKING_PRECISION_TEST_BLOCK_LIST),
      enable_vars=False,
      validate_args=False)
  dist = data.draw(strategy)
  drawn_params = self.evaluate_dict(dist.parameters)
  hp.note('Trying distribution {}'.format(drawn_params))

  seed = test_util.test_seed()
  # NOTE(review): the source formatting did not preserve the extent of this
  # `with` block; the dtype checks are kept inside it because `log_prob` is
  # also invoked here -- confirm against the original layout.
  with tfp_hps.no_tf_rank_errors(), kernel_hps.no_pd_errors():
    samples = dist.sample(5, seed=seed)
    self.assertIn(samples.dtype, [tf.float32, tf.int32])
    log_prob32 = dist.log_prob(samples)
    self.assertEqual(log_prob32.dtype, tf.float32)

  # Rebuild the distribution with every component mapped by `tensor_to_f64`.
  composite_dist = tfe.as_composite(dist)
  dist64 = tf.nest.map_structure(
      tensor_to_f64, composite_dist, expand_composites=True)
  with tfp_hps.no_tf_rank_errors(), kernel_hps.no_pd_errors():
    samples64 = tensor_to_f64(samples)
    result64 = dist64.log_prob(samples64)
  self.assertEqual(result64.dtype, tf.float64)
def testCompositeTensor(self, bijector_name, data):
  """Checks that a bijector survives a composite-tensor round trip.

  Draws a bijector, converts it with `experimental.as_composite`, flattens
  it with `tf.nest.flatten` and rebuilds it with `tf.nest.pack_sequence_as`.
  The rebuilt bijector must produce the same `forward` and `inverse`
  results as the original, and none of these steps may raise.

  Args:
    bijector_name: Name of the bijector to draw.
    data: Hypothesis data object used for all draws.
  """
  allowed = set(TF2_FRIENDLY_BIJECTORS) - set(COMPOSITE_TENSOR_IS_BROKEN)
  bijector, event_dim = self._draw_bijector(
      bijector_name, data, batch_shape=[], allowed_bijectors=allowed)

  # Round trip: bijector -> composite -> flat components -> rebuilt.
  composite_bij = experimental.as_composite(bijector)
  components = tf.nest.flatten(composite_bij, expand_composites=True)
  rebuilt = tf.nest.pack_sequence_as(
      composite_bij, components, expand_composites=True)

  num_draws = 3

  # Forward maps must agree before and after compositing.
  xs = self._draw_domain_tensor(
      bijector, data, event_dim, sample_shape=[num_draws])
  forward_pair = (bijector.forward(xs), rebuilt.forward(xs))
  self.assertAllClose(*self.evaluate(forward_pair))

  # Inverse maps must agree before and after compositing.
  ys = self._draw_codomain_tensor(
      bijector, data, event_dim, sample_shape=[num_draws])
  inverse_pair = (bijector.inverse(ys), rebuilt.inverse(ys))
  self.assertAllClose(*self.evaluate(inverse_pair))
def _test_sample_and_log_prob(self, dist_name, dist):
  """Compares sampling and `log_prob` before and after compositing `dist`.

  Samples from `dist` and from a copy rebuilt through
  `experimental.as_composite` / `tf.nest.flatten` /
  `tf.nest.pack_sequence_as`, using the same stateless seed, and asserts
  that the samples match. Then asserts that both objects give the same
  `log_prob` on each set of samples.

  Args:
    dist_name: Name of the distribution; not used in this method's body.
    dist: Distribution instance under test.
  """
  seed = test_util.test_seed(sampler_type='stateless')
  num_samples = 3

  # Draw from the original distribution first.
  sample1 = self.evaluate(dist.sample(num_samples, seed=seed))
  hp.note('Drew samples {}'.format(sample1))

  # Rebuild the distribution from its flattened composite form and draw
  # again with the same seed.
  composite_dist = experimental.as_composite(dist)
  components = tf.nest.flatten(composite_dist, expand_composites=True)
  unflat = tf.nest.pack_sequence_as(
      composite_dist, components, expand_composites=True)
  sample2 = self.evaluate(unflat.sample(num_samples, seed=seed))
  hp.note('Drew samples {}'.format(sample2))

  self.assertAllClose(sample1, sample2)

  # The log_probs must agree on the samples from before and from after the
  # round trip (even though those samples are supposed to be the same).
  for sample in (sample1, sample2):
    log_prob_pair = (unflat.log_prob(sample), dist.log_prob(sample))
    ct_lp, orig_lp = self.evaluate(log_prob_pair)
    self.assertAllClose(ct_lp, orig_lp)
def testLogProbAccuracy(self, dist_name, data):
  """Bounds the `excess_wrong_bits` of `log_prob` for a drawn distribution.

  Draws an unbatched distribution without variables, takes one sample, and
  passes a `log_prob` function of the distribution's flattened components
  to `nt.excess_wrong_bits`. The reported badness must be below 20.

  Args:
    dist_name: Name of the distribution to draw.
    data: Hypothesis data object used to draw the distribution.
  """
  self.skip_if_tf1()

  strategy = dhps.distributions(
      dist_name=dist_name,
      # Accuracy tools can't handle batches (yet?)
      batch_shape=(),
      # Variables presumably do not affect the numerics
      enable_vars=False,
      # Checking that samples pass validations (including in 64-bit
      # arithmetic) is left for another test
      validate_args=False)
  dist = tfe.as_composite(data.draw(strategy))

  seed = test_util.test_seed(sampler_type='stateless')
  with tfp_hps.no_tf_rank_errors():
    sample = dist.sample(seed=seed)
  if sample.dtype.is_floating:
    # Reject examples whose floating-point sample contains a NaN.
    nan_free = tf.reduce_all(~tf.math.is_nan(sample))
    hp.assume(self.evaluate(nan_free))
  hp.note('Testing on sample {}'.format(sample))

  as_tensors = tf.nest.flatten(dist, expand_composites=True)

  def log_prob_function(tensors, x):
    # Rebuild the distribution from flat components so that `log_prob` is
    # expressed as a function of those components.
    return tf.nest.pack_sequence_as(
        dist, tensors, expand_composites=True).log_prob(x)

  with tfp_hps.finite_ground_truth_only():
    badness = nt.excess_wrong_bits(log_prob_function, as_tensors, sample)
  # TODO(axch): Lower the acceptable badness to 4, which corresponds
  # to slightly better accuracy than 1e-6 relative error for
  # well-conditioned functions.
  self.assertAllLess(badness, 20)
def testCompositeTensor(self, bijector_name, data):
  """Checks composite-tensor behavior of a drawn bijector.

  Verifies that the (possibly converted) bijector is a `CompositeTensor`,
  that flattening and re-packing it preserves `forward` and `inverse`
  results, that it can be passed as an argument to a `tf.function`, and
  that gradients of `forward` with respect to the input and the floating
  trainable variables are checked by `assert_no_none_grad`.

  Args:
    bijector_name: Name of the bijector to draw; also the key into
      `COMPOSITE_TENSOR_RTOL` and `COMPOSITE_TENSOR_ATOL`.
    data: Hypothesis data object used for all draws.
  """
  allowed = set(TF2_FRIENDLY_BIJECTORS) - set(COMPOSITE_TENSOR_IS_BROKEN)
  bijector, event_dim = self._draw_bijector(
      bijector_name, data,
      batch_shape=[],
      validate_args=True,
      allowed_bijectors=allowed)

  # TODO(b/182603117): Remove "if" condition and s/composite_bij/bijector
  # when AutoCT is enabled for meta-bijectors and LinearOperator.
  composite_bij = bijector
  if type(bijector).__name__ in AUTO_COMPOSITE_TENSOR_IS_BROKEN:
    composite_bij = experimental.as_composite(bijector)

  if not tf.executing_eagerly():

    def _resolve_deferred(component):
      # Outside eager mode, replace `DeferredTensor` components with
      # concrete tensors; every other component passes through untouched.
      if isinstance(component, DeferredTensor):
        return tf.convert_to_tensor(component)
      return component

    composite_bij = tf.nest.map_structure(
        _resolve_deferred, composite_bij, expand_composites=True)

  self.assertIsInstance(composite_bij, tf.__internal__.CompositeTensor)
  components = tf.nest.flatten(composite_bij, expand_composites=True)
  unflat = tf.nest.pack_sequence_as(
      composite_bij, components, expand_composites=True)

  num_draws = 3

  # Forward maps must agree before and after compositing.
  xs = self._draw_domain_tensor(
      bijector, data, event_dim, sample_shape=[num_draws])
  before_ys = bijector.forward(xs)
  after_ys = unflat.forward(xs)
  self.assertAllClose(*self.evaluate((before_ys, after_ys)))

  # Inverse maps must agree before and after compositing.
  ys = self._draw_codomain_tensor(
      bijector, data, event_dim, sample_shape=[num_draws])
  inverse_pair = (bijector.inverse(ys), unflat.inverse(ys))
  self.assertAllClose(*self.evaluate(inverse_pair))

  # The composite bijector must be usable as an input to tf.function.
  traced_forward = tf.function(lambda b: b.forward(xs))
  traced_ys = traced_forward(composite_bij)
  self.assertAllClose(
      before_ys,
      traced_ys,
      rtol=COMPOSITE_TENSOR_RTOL[bijector_name],
      atol=COMPOSITE_TENSOR_ATOL[bijector_name])

  # Forward mapping: Check differentiation through forward mapping with
  # respect to the input and parameter variables. Also check that any
  # variables are not referenced overmuch.
  grad_xs = self._draw_domain_tensor(bijector, data, event_dim)
  float_vars = [
      v for v in composite_bij.trainable_variables if v.dtype.is_floating]
  wrt_vars = [grad_xs] + float_vars
  with tf.GradientTape() as tape:
    tape.watch(wrt_vars)
    # TODO(b/73073515): Fix graph mode gradients with bijector caching.
    grad_ys = bijector.forward(grad_xs + 0)
  grads = tape.gradient(grad_ys, wrt_vars)
  assert_no_none_grad(bijector, 'forward', wrt_vars, grads)