def test_all_distributions_either_work_or_raise_error(self, dist_name, data):
  """Checks `make_distribution_bijector` against one base distribution.

  The distribution named `dist_name` is drawn with parameters from
  `_constrained_zeros_fn`. Building its bijector must either raise
  `NotImplementedError` (the test is then skipped), or produce a bijector whose
  inverse, applied to the distribution, is approximately standard normal.

  Skipped outside eager mode and for names in `PRECONDITIONING_FAILS_DISTS`.
  """
  if not tf.executing_eagerly():
    self.skipTest('No need to test every distribution in graph mode.')
  if dist_name in PRECONDITIONING_FAILS_DISTS:
    self.skipTest('Known failure.')

  distribution = data.draw(
      dhps.base_distributions(
          dist_name=dist_name,
          enable_vars=False,
          param_strategy_fn=_constrained_zeros_fn))

  try:
    bijector = tfp.experimental.bijectors.make_distribution_bijector(
        distribution)
  except NotImplementedError:
    # Failing is acceptable here, provided the failure is an explicit error.
    self.skipTest('Bijector not implemented.')
  else:
    # Pushing the distribution through the inverse bijector should whiten it.
    whitened = tfb.Invert(bijector)(distribution)
    self.assertDistributionIsApproximatelyStandardNormal(whitened)
def testDistribution(self, dist_name, data):
  """Sanity-checks a distribution class's `_maximum_likelihood_parameters`.

  Steps:

  * Draw a distribution named `dist_name` with a random batch shape. The shape
    has rank 0-2 and sides up to 5.
  * Draw 10 samples `x` together with their log-probs `lp`.
  * Fit parameters with `type(dist)._maximum_likelihood_parameters(x)`.

  The test then asserts that:

  * the log-likelihood of `x`, summed over axis 0, under the fitted parameters
    is not lower than under the original parameters, with slack 1e-4; and
  * unless `dist_name` is in `MLE_AT_CONSTRAINT_BOUNDARY`, every gradient of
    that log-likelihood w.r.t. the fitted parameters is ~0 (atol 1e-2). In
    other words, the fit is a critical point. Gradients containing NaN are
    skipped.

  The test is skipped when fitting raises `NotImplementedError`.
  """
  dist = data.draw(
      dhps.base_distributions(
          dist_name=dist_name,
          enable_vars=False,
          # Unregularized MLEs can be numerically problematic, e.g., empirical
          # (co)variances can be singular. To avoid such numerical issues, we
          # sanity-check the MLE only for a fixed sample with assumed-sane
          # parameter values (zeros constrained to the parameter support).
          param_strategy_fn=_constrained_zeros_fn,
          batch_shape=data.draw(
              tfp_hps.shapes(min_ndims=0, max_ndims=2, max_side=5))))
  x, lp = self.evaluate(
      dist.experimental_sample_and_log_prob(
          10, seed=test_util.test_seed(sampler_type='stateless')))

  dist_cls = type(dist)
  try:
    parameters = self.evaluate(dist_cls._maximum_likelihood_parameters(x))
  except NotImplementedError:
    self.skipTest('Fitting not implemented.')

  flat_params = tf.nest.flatten(parameters)

  def lp_fn(*flat_values):
    """Log prob of `x` under a distribution built from `flat_values`."""
    # Repack the flat leaves into the structure of `parameters`, so that the
    # gradient is taken separately w.r.t. each leaf parameter.
    fitted_params = tf.nest.pack_sequence_as(parameters, flat_values)
    return dist_cls(validate_args=True, **fitted_params).log_prob(x)

  lp_mle, grads = self.evaluate(
      tfp_math.value_and_gradient(lp_fn, flat_params))

  # Likelihood of MLE params should be higher than of the original params.
  self.assertAllGreaterEqual(
      tf.reduce_sum(lp_mle, axis=0) - tf.reduce_sum(lp, axis=0), -1e-4)

  if dist_name not in MLE_AT_CONSTRAINT_BOUNDARY:
    # MLE parameters should be a critical point of the log prob.
    for g in grads:
      if np.any(np.isnan(g)):
        # Skip parameters with undefined or unstable gradients (e.g.,
        # Categorical `num_classes`).
        continue
      self.assertAllClose(tf.zeros_like(g), g, atol=1e-2)
def testDistributionWithVars(self, dist_name, data):
  """Runs shared checks on a distribution drawn with `enable_vars=True`.

  Before `check_bad_loc_scale` and `check_event_space_bijector_constrains`
  run, the test evaluates the `initializer` of every variable in
  `dist.variables`.
  """
  strategy = dhps.base_distributions(dist_name=dist_name, enable_vars=True)
  distribution = data.draw(strategy)

  initializers = [v.initializer for v in distribution.variables]
  self.evaluate(initializers)

  self.check_bad_loc_scale(distribution)
  self.check_event_space_bijector_constrains(distribution, data)