  def test_all_distributions_either_work_or_raise_error(self, dist_name, data):
    if not tf.executing_eagerly():
      self.skipTest('No need to test every distribution in graph mode.')
    if dist_name in PRECONDITIONING_FAILS_DISTS:
      self.skipTest('Known failure.')

    dist = data.draw(dhps.base_distributions(
        dist_name=dist_name,
        enable_vars=False,
        param_strategy_fn=_constrained_zeros_fn))
    try:
      b = tfp.experimental.bijectors.make_distribution_bijector(dist)
    except NotImplementedError:
      # Okay to fail as long as explicit error is raised.
      self.skipTest('Bijector not implemented.')
    self.assertDistributionIsApproximatelyStandardNormal(tfb.Invert(b)(dist))
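  # A minimal sketch (illustrative only, not part of the test suite) of the
  # preconditioning idea exercised above: `make_distribution_bijector` builds a
  # bijector `b` mapping a standard normal to (approximately) `dist`, so
  # pulling `dist` back through `tfb.Invert(b)` should be roughly standard
  # normal. The `tfd.Gamma` example distribution is an assumption chosen for
  # illustration:
  #
  #   dist = tfd.Gamma(concentration=2., rate=3.)
  #   b = tfp.experimental.bijectors.make_distribution_bijector(dist)
  #   z_dist = tfb.Invert(b)(dist)  # Approximately Normal(0., 1.).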

  def testDistribution(self, dist_name, data):
    dist = data.draw(
        dhps.base_distributions(
            dist_name=dist_name,
            enable_vars=False,
            # Unregularized MLEs can be numerically problematic, e.g., empirical
            # (co)variances can be singular. To avoid such numerical issues, we
            # sanity-check the MLE only for a fixed sample with assumed-sane
            # parameter values (zeros constrained to the parameter support).
            param_strategy_fn=_constrained_zeros_fn,
            batch_shape=data.draw(
                tfp_hps.shapes(min_ndims=0, max_ndims=2, max_side=5))))
    x, lp = self.evaluate(
        dist.experimental_sample_and_log_prob(
            10, seed=test_util.test_seed(sampler_type='stateless')))

    try:
      parameters = self.evaluate(
          type(dist)._maximum_likelihood_parameters(x))
    except NotImplementedError:
      self.skipTest('Fitting not implemented.')

    flat_params = tf.nest.flatten(parameters)
    lp_fn = lambda *flat_params: type(dist)(  # pylint: disable=g-long-lambda
        validate_args=True,
        **tf.nest.pack_sequence_as(parameters, flat_params)).log_prob(x)
    lp_mle, grads = self.evaluate(
        tfp_math.value_and_gradient(lp_fn, flat_params))

    # Likelihood of MLE params should be higher than of the original params.
    self.assertAllGreaterEqual(
        tf.reduce_sum(lp_mle, axis=0) - tf.reduce_sum(lp, axis=0), -1e-4)

    if dist_name not in MLE_AT_CONSTRAINT_BOUNDARY:
      # MLE parameters should be a critical point of the log prob.
      for g in grads:
        if np.any(np.isnan(g)):
          # Skip parameters with undefined or unstable gradients (e.g.,
          # Categorical `num_classes`).
          continue
        self.assertAllClose(tf.zeros_like(g), g, atol=1e-2)
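
  # A minimal sketch (illustrative only) of the MLE check above for one
  # concrete distribution; this assumes `tfd.Normal` implements the private
  # `_maximum_likelihood_parameters` hook that the test exercises:
  #
  #   x = tfd.Normal(loc=1., scale=2.).sample(
  #       1000, seed=test_util.test_seed(sampler_type='stateless'))
  #   params = tfd.Normal._maximum_likelihood_parameters(x)
  #   # For a Normal, the MLE is the empirical mean and standard deviation of
  #   # `x`, and the log-prob gradient with respect to `params` is ~zero.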

  def testDistributionWithVars(self, dist_name, data):
    dist = data.draw(
        dhps.base_distributions(dist_name=dist_name, enable_vars=True))
    self.evaluate([var.initializer for var in dist.variables])
    self.check_bad_loc_scale(dist)
    self.check_event_space_bijector_constrains(dist, data)
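
  # A minimal sketch (illustrative only) of the property that
  # `check_event_space_bijector_constrains` verifies: a distribution's default
  # event-space bijector maps unconstrained reals into the distribution's
  # support. `tfd.Beta` is an assumed example:
  #
  #   dist = tfd.Beta(2., 3.)
  #   bij = dist.experimental_default_event_space_bijector()
  #   y = bij.forward(tf.linspace(-3., 3., 5))  # All values lie in (0, 1).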