def testRegistrationFailures(self):
    """Check that `RegisterKL` rejects invalid and duplicate registrations.

    Exercises three cases for the `(MyDist, MyDist)` pair, in order:
    registering a non-callable raises `TypeError`, registering a callable
    succeeds, and registering a second callable for the same pair raises
    `ValueError`.
    """

    # Defined inside the test so the class object (and hence the pair being
    # registered) is new each time this test runs.
    class MyDist(normal.Normal):
        pass

    # `assertRaisesRegex` is the supported spelling; the `assertRaisesRegexp`
    # alias was deprecated in Python 3.2 and removed in Python 3.12.
    with self.assertRaisesRegex(TypeError, "must be callable"):
        kullback_leibler.RegisterKL(MyDist, MyDist)("blah")

    # First registration is OK
    kullback_leibler.RegisterKL(MyDist, MyDist)(lambda a, b: None)

    # Second registration fails
    with self.assertRaisesRegex(ValueError, "has already been registered"):
        kullback_leibler.RegisterKL(MyDist, MyDist)(lambda a, b: None)
# Example #2
0
        if not self.validate_args:
            return []
        assertions = []
        if is_init != tensor_util.is_ref(self.concentration):
            assertions.append(
                assert_util.assert_positive(
                    self.concentration,
                    message='Argument `concentration` must be positive.'))
        if self.rate is not None and is_init != tensor_util.is_ref(self.rate):
            assertions.append(
                assert_util.assert_positive(
                    self.rate, message='Argument `rate` must be positive.'))
        return assertions


# Register `gamma_lib.kl_gamma_gamma` as the KL function for the
# (ExpGamma, ExpGamma) pair. This runs at import time as a module-level side
# effect.
# NOTE(review): presumably the Gamma/Gamma KL can be reused because KL
# divergence is unchanged by the invertible log transform relating `ExpGamma`
# to `Gamma` -- confirm against the `ExpGamma` definition.
kullback_leibler.RegisterKL(ExpGamma, ExpGamma)(gamma_lib.kl_gamma_gamma)


# TODO(b/182603117): Remove `AutoCompositeTensor` subclass when
# `TransformedDistribution` is converted to `CompositeTensor`.
class ExpInverseGamma(transformed_distribution.TransformedDistribution,
                      distribution.AutoCompositeTensorDistribution):
    """ExpInverseGamma distribution.

  The `ExpInverseGamma` distribution is defined over the real numbers such that
  X ~ ExpInverseGamma(..) => exp(X) ~ InverseGamma(..).

  The distribution is logically equivalent to `tfb.Log()(tfd.InverseGamma(..))`,
  but can be sampled with much better precision.

  #### Mathematical Details