Example #1
    def testRegistrationFailures(self):
        class MyDist(normal.Normal):
            pass

        with self.assertRaisesRegexp(TypeError, "must be callable"):
            kullback_leibler.RegisterKL(MyDist, MyDist)("blah")

        # First registration is OK
        kullback_leibler.RegisterKL(MyDist, MyDist)(lambda a, b: None)

        # Second registration fails
        with self.assertRaisesRegexp(ValueError,
                                     "has already been registered"):
            kullback_leibler.RegisterKL(MyDist, MyDist)(lambda a, b: None)
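
The test above pins down the contract of RegisterKL: the decorated object must be callable, and registering the same (class, class) pair twice raises a ValueError. As a minimal standalone sketch of that behavior (illustrative only, not the TensorFlow implementation; the registry name and messages are assumptions beyond the error substrings checked by the test):

# Sketch of a KL registration decorator with the behavior exercised above:
# TypeError for a non-callable, ValueError for a repeated registration.
_kl_registry = {}


class RegisterKL(object):
    """Records a KL-divergence function for a pair of distribution classes."""

    def __init__(self, dist_cls_a, dist_cls_b):
        self._key = (dist_cls_a, dist_cls_b)

    def __call__(self, kl_fn):
        if not callable(kl_fn):
            raise TypeError("kl_fn must be callable, received: %s" % (kl_fn,))
        if self._key in _kl_registry:
            raise ValueError("KL(%s || %s) has already been registered"
                             % (self._key[0].__name__, self._key[1].__name__))
        _kl_registry[self._key] = kl_fn
        return kl_fn

Dispatch then amounts to looking up (type(d1), type(d2)) in the registry when a KL divergence is requested.
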
Example #2
        # tr[inv(Cb) Ca] = tr[inv(B)' inv(B) A A']
        #                = tr[inv(B) A A' inv(B)']
        #                = tr[(inv(B) A) (inv(B) A)']
        #                = sum_{ik} (inv(B) A)_{ik}^2
        # The second equality follows from the cyclic permutation property.
        b_inv_a = cov_b.sqrt_solve(cov_a.sqrt_to_dense())
        t = math_ops.reduce_sum(math_ops.square(b_inv_a),
                                reduction_indices=[-1, -2])
        q = cov_b.inv_quadratic_form_on_vectors(mu_b - mu_a)
        k = math_ops.cast(cov_a.vector_space_dimension(), mvn_a.dtype)
        one_half_l = cov_b.sqrt_log_det() - cov_a.sqrt_log_det()
        return 0.5 * (t + q - k) + one_half_l


# Register KL divergences.
kl_classes = [
    MultivariateNormalFull,
    MultivariateNormalCholesky,
    MultivariateNormalDiag,
    MultivariateNormalDiagPlusVDVT,
]

for mvn_aa in kl_classes:
    # Register the same-class (mvn_aa, mvn_aa) pair here; the inner loop below
    # skips that pair so it is not registered a second time.
    kullback_leibler.RegisterKL(mvn_aa, mvn_aa)(_kl_mvn_mvn_brute_force)
    for mvn_bb in kl_classes:
        if mvn_bb != mvn_aa:
            kullback_leibler.RegisterKL(mvn_aa,
                                        mvn_bb)(_kl_mvn_mvn_brute_force)
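
For reference, the return value of the function above assembles the standard closed-form KL divergence between two multivariate Gaussians. In my notation (not from the source), with C_a = A A' and C_b = B B' as in the trace comments, and assuming sqrt_log_det() returns half the log-determinant, as the name one_half_l suggests:

\mathrm{KL}\!\left(\mathcal{N}(\mu_a, C_a)\,\|\,\mathcal{N}(\mu_b, C_b)\right)
  = \frac{1}{2}\left[\operatorname{tr}\!\left(C_b^{-1} C_a\right)
    + (\mu_b - \mu_a)^\top C_b^{-1} (\mu_b - \mu_a) - k\right]
    + \frac{1}{2}\log\frac{\det C_b}{\det C_a}

The three bracketed terms correspond to t, q and k in the code, and the final term corresponds to one_half_l.
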
Example #3
    d1: instance of a Beta distribution object.
    d2: instance of a Beta distribution object.
    name: (optional) Name to use for created operations.
      Default is "kl_beta_beta".

  Returns:
    Batchwise KL(d1 || d2)
  """
  inputs = [d1.a, d1.b, d1.a_b_sum, d2.a_b_sum]
  with ops.name_scope(name, "kl_beta_beta", inputs):
    # ln(B(a', b') / B(a, b))
    log_betas = (math_ops.lgamma(d2.a) + math_ops.lgamma(d2.b)
                - math_ops.lgamma(d2.a_b_sum) + math_ops.lgamma(d1.a_b_sum)
                - math_ops.lgamma(d1.a) - math_ops.lgamma(d1.b))
    # (a - a')*psi(a) + (b - b')*psi(b) + (a' - a + b' - b)*psi(a + b)
    digammas = ((d1.a - d2.a)*math_ops.digamma(d1.a)
              + (d1.b - d2.b)*math_ops.digamma(d1.b)
              + (d2.a_b_sum - d1.a_b_sum)*math_ops.digamma(d1.a_b_sum))
    return log_betas + digammas


# Register KL divergences.
kl_classes = [
    Beta,
    BetaWithSoftplusAB,
]

for beta_aa in kl_classes:
  for beta_bb in kl_classes:
    kullback_leibler.RegisterKL(beta_aa, beta_bb)(_kl_beta_beta)
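
Written out, the two terms computed in the function above combine into the closed-form KL divergence between d1 = Beta(a, b) and d2 = Beta(a', b'), in the primed notation of the code comments:

\mathrm{KL}\!\left(\mathrm{Beta}(a, b)\,\|\,\mathrm{Beta}(a', b')\right)
  = \ln\frac{B(a', b')}{B(a, b)}
    + (a - a')\,\psi(a) + (b - b')\,\psi(b)
    + \big[(a' - a) + (b' - b)\big]\,\psi(a + b)

where B is the Beta function and \psi the digamma function; the first term is log_betas and the remaining terms are digammas.
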