Example #1
    def testCategoricalCategoricalKL(self):
        def np_softmax(logits):
            exp_logits = np.exp(logits)
            return exp_logits / exp_logits.sum(axis=-1, keepdims=True)

        with self.test_session() as sess:
            for categories in [2, 4]:
                for batch_size in [1, 10]:
                    a_logits = np.random.randn(batch_size, categories)
                    b_logits = np.random.randn(batch_size, categories)

                    a = categorical.Categorical(logits=a_logits)
                    b = categorical.Categorical(logits=b_logits)

                    kl = kullback_leibler.kl(a, b)
                    kl_val = sess.run(kl)
                    # Make sure KL(a||a) is 0
                    kl_same = sess.run(kullback_leibler.kl(a, a))

                    prob_a = np_softmax(a_logits)
                    prob_b = np_softmax(b_logits)
                    kl_expected = np.sum(prob_a *
                                         (np.log(prob_a) - np.log(prob_b)),
                                         axis=-1)

                    self.assertEqual(kl.get_shape(), (batch_size, ))
                    self.assertAllClose(kl_val, kl_expected)
                    self.assertAllClose(kl_same, np.zeros_like(kl_expected))
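Note: kl_expected above encodes the closed-form KL divergence between two categorical distributions with probability vectors p = softmax(a_logits) and q = softmax(b_logits); in LaTeX notation this is roughly

    KL(p \,\|\, q) = \sum_{i} p_i \left( \log p_i - \log q_i \right)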
Example #2
  def testCategoricalCategoricalKL(self):

    def np_softmax(logits):
      exp_logits = np.exp(logits)
      return exp_logits / exp_logits.sum(axis=-1, keepdims=True)

    with self.test_session() as sess:
      for categories in [2, 4]:
        for batch_size in [1, 10]:
          a_logits = np.random.randn(batch_size, categories)
          b_logits = np.random.randn(batch_size, categories)

          a = categorical.Categorical(logits=a_logits)
          b = categorical.Categorical(logits=b_logits)

          kl = kullback_leibler.kl(a, b)
          kl_val = sess.run(kl)
          # Make sure KL(a||a) is 0
          kl_same = sess.run(kullback_leibler.kl(a, a))

          prob_a = np_softmax(a_logits)
          prob_b = np_softmax(b_logits)
          kl_expected = np.sum(prob_a * (np.log(prob_a) - np.log(prob_b)),
                               axis=-1)

          self.assertEqual(kl.get_shape(), (batch_size,))
          self.assertAllClose(kl_val, kl_expected)
          self.assertAllClose(kl_same, np.zeros_like(kl_expected))
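Example #3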
    def testBetaBetaKL(self):
        with self.test_session() as sess:
            for shape in [(10, ), (4, 5)]:
                a1 = 6.0 * np.random.random(size=shape) + 1e-4
                b1 = 6.0 * np.random.random(size=shape) + 1e-4
                a2 = 6.0 * np.random.random(size=shape) + 1e-4
                b2 = 6.0 * np.random.random(size=shape) + 1e-4
                # Take inverse softplus of values to test BetaWithSoftplusAB
                a1_sp = np.log(np.exp(a1) - 1.0)
                b1_sp = np.log(np.exp(b1) - 1.0)
                a2_sp = np.log(np.exp(a2) - 1.0)
                b2_sp = np.log(np.exp(b2) - 1.0)

                d1 = beta_lib.Beta(a=a1, b=b1)
                d2 = beta_lib.Beta(a=a2, b=b2)
                d1_sp = beta_lib.BetaWithSoftplusAB(a=a1_sp, b=b1_sp)
                d2_sp = beta_lib.BetaWithSoftplusAB(a=a2_sp, b=b2_sp)

                kl_expected = (special.betaln(a2, b2) -
                               special.betaln(a1, b1) +
                               (a1 - a2) * special.digamma(a1) +
                               (b1 - b2) * special.digamma(b1) +
                               (a2 - a1 + b2 - b1) * special.digamma(a1 + b1))

                for dist1 in [d1, d1_sp]:
                    for dist2 in [d2, d2_sp]:
                        kl = kullback_leibler.kl(dist1, dist2)
                        kl_val = sess.run(kl)
                        self.assertEqual(kl.get_shape(), shape)
                        self.assertAllClose(kl_val, kl_expected)

                # Make sure KL(d1||d1) is 0
                kl_same = sess.run(kullback_leibler.kl(d1, d1))
                self.assertAllClose(kl_same, np.zeros_like(kl_expected))
Example #4
  def testBetaBetaKL(self):
    with self.test_session() as sess:
      for shape in [(10,), (4, 5)]:
        a1 = 6.0 * np.random.random(size=shape) + 1e-4
        b1 = 6.0 * np.random.random(size=shape) + 1e-4
        a2 = 6.0 * np.random.random(size=shape) + 1e-4
        b2 = 6.0 * np.random.random(size=shape) + 1e-4
        # Take inverse softplus of values to test BetaWithSoftplusAB
        a1_sp = np.log(np.exp(a1) - 1.0)
        b1_sp = np.log(np.exp(b1) - 1.0)
        a2_sp = np.log(np.exp(a2) - 1.0)
        b2_sp = np.log(np.exp(b2) - 1.0)

        d1 = beta_lib.Beta(a=a1, b=b1)
        d2 = beta_lib.Beta(a=a2, b=b2)
        d1_sp = beta_lib.BetaWithSoftplusAB(a=a1_sp, b=b1_sp)
        d2_sp = beta_lib.BetaWithSoftplusAB(a=a2_sp, b=b2_sp)

        kl_expected = (special.betaln(a2, b2) - special.betaln(a1, b1) +
                       (a1 - a2) * special.digamma(a1) +
                       (b1 - b2) * special.digamma(b1) +
                       (a2 - a1 + b2 - b1) * special.digamma(a1 + b1))

        for dist1 in [d1, d1_sp]:
          for dist2 in [d2, d2_sp]:
            kl = kullback_leibler.kl(dist1, dist2)
            kl_val = sess.run(kl)
            self.assertEqual(kl.get_shape(), shape)
            self.assertAllClose(kl_val, kl_expected)

        # Make sure KL(d1||d1) is 0
        kl_same = sess.run(kullback_leibler.kl(d1, d1))
        self.assertAllClose(kl_same, np.zeros_like(kl_expected))
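Note: the kl_expected expression in the two Beta tests above matches the standard closed-form KL divergence between Beta(a1, b1) and Beta(a2, b2), written here with the Beta function B and the digamma function \psi:

    KL = \ln B(a_2, b_2) - \ln B(a_1, b_1) + (a_1 - a_2)\,\psi(a_1) + (b_1 - b_2)\,\psi(b_1) + (a_2 - a_1 + b_2 - b_1)\,\psi(a_1 + b_1)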
Example #5
  def testCategoricalCategoricalKL(self):
    def np_softmax(logits):
      exp_logits = np.exp(logits)
      return exp_logits / exp_logits.sum(axis=-1, keepdims=True)

    with self.test_session() as sess:
      for categories in [2, 10]:
        for batch_size in [1, 2]:
          p_logits = self._rng.random_sample((batch_size, categories))
          q_logits = self._rng.random_sample((batch_size, categories))
          p = onehot_categorical.OneHotCategorical(logits=p_logits)
          q = onehot_categorical.OneHotCategorical(logits=q_logits)
          prob_p = np_softmax(p_logits)
          prob_q = np_softmax(q_logits)
          kl_expected = np.sum(
              prob_p * (np.log(prob_p) - np.log(prob_q)), axis=-1)

          kl_actual = kullback_leibler.kl(p, q)
          kl_same = kullback_leibler.kl(p, p)
          x = p.sample(int(2e4), seed=0)
          x = math_ops.cast(x, dtype=dtypes.float32)
          # Compute empirical KL(p||q).
          kl_sample = math_ops.reduce_mean(p.log_prob(x) - q.log_prob(x), 0)

          [kl_sample_, kl_actual_, kl_same_] = sess.run([kl_sample, kl_actual,
                                                         kl_same])
          self.assertEqual(kl_actual.get_shape(), (batch_size,))
          self.assertAllClose(kl_same_, np.zeros_like(kl_expected))
          self.assertAllClose(kl_actual_, kl_expected, atol=0., rtol=1e-6)
          self.assertAllClose(kl_sample_, kl_expected, atol=1e-2, rtol=0.)
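Note: besides the analytic value, this test also checks a Monte Carlo estimate of the divergence; kl_sample averages log-probability differences over draws from p, i.e. approximately

    KL(p \,\|\, q) = \mathbb{E}_{x \sim p}[\log p(x) - \log q(x)] \approx \frac{1}{N} \sum_{n=1}^{N} \left( \log p(x_n) - \log q(x_n) \right)

which is why the looser tolerance atol=1e-2 is applied to kl_sample_.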
Example #6
    def testIndirectRegistration(self):
        class Sub1(normal.Normal):
            pass

        class Sub2(normal.Normal):
            pass

        class Sub11(Sub1):
            pass

        # pylint: disable=unused-argument,unused-variable
        @kullback_leibler.RegisterKL(Sub1, Sub1)
        def _kl11(a, b, name=None):
            return "sub1-1"

        @kullback_leibler.RegisterKL(Sub1, Sub2)
        def _kl12(a, b, name=None):
            return "sub1-2"

        @kullback_leibler.RegisterKL(Sub2, Sub1)
        def _kl21(a, b, name=None):
            return "sub2-1"

        # pylint: enable=unused-argument,unused-variable

        sub1 = Sub1(loc=0.0, scale=1.0)
        sub2 = Sub2(loc=0.0, scale=1.0)
        sub11 = Sub11(loc=0.0, scale=1.0)

        self.assertEqual("sub1-1",
                         kullback_leibler.kl(sub1, sub1, allow_nan=True))
        self.assertEqual("sub1-2",
                         kullback_leibler.kl(sub1, sub2, allow_nan=True))
        self.assertEqual("sub2-1",
                         kullback_leibler.kl(sub2, sub1, allow_nan=True))
        self.assertEqual("sub1-1",
                         kullback_leibler.kl(sub11, sub11, allow_nan=True))
        self.assertEqual("sub1-1",
                         kullback_leibler.kl(sub11, sub1, allow_nan=True))
        self.assertEqual("sub1-2",
                         kullback_leibler.kl(sub11, sub2, allow_nan=True))
        self.assertEqual("sub1-1",
                         kullback_leibler.kl(sub11, sub1, allow_nan=True))
        self.assertEqual("sub1-2",
                         kullback_leibler.kl(sub11, sub2, allow_nan=True))
        self.assertEqual("sub2-1",
                         kullback_leibler.kl(sub2, sub11, allow_nan=True))
        self.assertEqual("sub1-1",
                         kullback_leibler.kl(sub1, sub11, allow_nan=True))
Example #7
  def testGammaGammaKL(self):
    alpha0 = np.array([3.])
    beta0 = np.array([1., 2., 3., 1.5, 2.5, 3.5])

    alpha1 = np.array([0.4])
    beta1 = np.array([0.5, 1., 1.5, 2., 2.5, 3.])

    # Build graph.
    with self.test_session() as sess:
      g0 = gamma_lib.Gamma(alpha=alpha0, beta=beta0)
      g1 = gamma_lib.Gamma(alpha=alpha1, beta=beta1)
      x = g0.sample(int(1e4), seed=0)
      kl_sample = math_ops.reduce_mean(g0.log_prob(x) - g1.log_prob(x), 0)
      kl_actual = kullback_leibler.kl(g0, g1)

    # Execute graph.
    [kl_sample_, kl_actual_] = sess.run([kl_sample, kl_actual])

    kl_expected = ((alpha0 - alpha1) * special.digamma(alpha0)
                   + special.gammaln(alpha1)
                   - special.gammaln(alpha0)
                   + alpha1 * np.log(beta0)
                   - alpha1 * np.log(beta1)
                   + alpha0 * (beta1 / beta0 - 1.))

    self.assertEqual(beta0.shape, kl_actual.get_shape())
    self.assertAllClose(kl_expected, kl_actual_, atol=0., rtol=1e-6)
    self.assertAllClose(kl_sample_, kl_actual_, atol=0., rtol=1e-2)
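Note: kl_expected here is the closed-form KL divergence between two Gamma distributions with shape \alpha and rate \beta, equivalent to

    KL = (\alpha_0 - \alpha_1)\,\psi(\alpha_0) - \ln\Gamma(\alpha_0) + \ln\Gamma(\alpha_1) + \alpha_1(\ln\beta_0 - \ln\beta_1) + \alpha_0\,\frac{\beta_1 - \beta_0}{\beta_0}

and the sampled estimate kl_sample_ is compared against the analytic result with a loose relative tolerance.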
Example #8
    def testGammaGammaKL(self):
        alpha0 = np.array([3.])
        beta0 = np.array([1., 2., 3., 1.5, 2.5, 3.5])

        alpha1 = np.array([0.4])
        beta1 = np.array([0.5, 1., 1.5, 2., 2.5, 3.])

        # Build graph.
        with self.test_session() as sess:
            g0 = gamma_lib.Gamma(alpha=alpha0, beta=beta0)
            g1 = gamma_lib.Gamma(alpha=alpha1, beta=beta1)
            x = g0.sample(int(1e4), seed=0)
            kl_sample = math_ops.reduce_mean(
                g0.log_prob(x) - g1.log_prob(x), 0)
            kl_actual = kullback_leibler.kl(g0, g1)

        # Execute graph.
        [kl_sample_, kl_actual_] = sess.run([kl_sample, kl_actual])

        kl_expected = ((alpha0 - alpha1) * special.digamma(alpha0) +
                       special.gammaln(alpha1) - special.gammaln(alpha0) +
                       alpha1 * np.log(beta0) - alpha1 * np.log(beta1) +
                       alpha0 * (beta1 / beta0 - 1.))

        self.assertEqual(beta0.shape, kl_actual.get_shape())
        self.assertAllClose(kl_expected, kl_actual_, atol=0., rtol=1e-6)
        self.assertAllClose(kl_sample_, kl_actual_, atol=0., rtol=1e-2)
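Example #9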
    def testDefaultVariationalAndPrior(self):
        _, prior, variational, _, log_likelihood = mini_vae()
        elbo = vi.elbo(log_likelihood)
        expected_elbo = log_likelihood - kullback_leibler.kl(
            variational.distribution, prior)
        with self.test_session() as sess:
            sess.run(variables.global_variables_initializer())
            self.assertAllEqual(*sess.run([expected_elbo, elbo]))
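Example #10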
  def testDefaultVariationalAndPrior(self):
    _, prior, variational, _, log_likelihood = mini_vae()
    elbo = vi.elbo(log_likelihood)
    expected_elbo = log_likelihood - kullback_leibler.kl(
        variational.distribution, prior)
    with self.test_session() as sess:
      sess.run(variables.global_variables_initializer())
      self.assertAllEqual(*sess.run([expected_elbo, elbo]))
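Example #11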
  def testExplicitVariationalAndPrior(self):
    with self.test_session() as sess:
      _, _, variational, _, log_likelihood = mini_vae()
      prior = normal.Normal(mu=3., sigma=2.)
      elbo = vi.elbo(
          log_likelihood, variational_with_prior={variational: prior})
      expected_elbo = log_likelihood - kullback_leibler.kl(
          variational.distribution, prior)
      sess.run(variables.global_variables_initializer())
      self.assertAllEqual(*sess.run([expected_elbo, elbo]))
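Example #12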
    def testExplicitVariationalAndPrior(self):
        with self.test_session() as sess:
            _, _, variational, _, log_likelihood = mini_vae()
            prior = normal.Normal(loc=3., scale=2.)
            elbo = vi.elbo(log_likelihood,
                           variational_with_prior={variational: prior})
            expected_elbo = log_likelihood - kullback_leibler.kl(
                variational.distribution, prior)
            sess.run(variables.global_variables_initializer())
            self.assertAllEqual(*sess.run([expected_elbo, elbo]))
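Note: these four vi.elbo tests use a single latent variable, so with the default ELBO form the value being checked reduces to the familiar evidence lower bound,

    ELBO = \mathbb{E}_{q(Z)}[\log p(x \mid Z)] - KL(q(Z) \,\|\, p(Z))

which is exactly what expected_elbo computes via kullback_leibler.kl.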
Example #13
    def testDomainErrorExceptions(self):
        class MyDistException(normal.Normal):
            pass

        # Register a KL implementation that deliberately returns NaN values
        @kullback_leibler.RegisterKL(MyDistException, MyDistException)
        # pylint: disable=unused-argument,unused-variable
        def _kl(a, b, name=None):
            return array_ops.identity([float("nan")])

        # pylint: enable=unused-argument,unused-variable

        with self.test_session():
            a = MyDistException(loc=0.0, scale=1.0)
            kl = kullback_leibler.kl(a, a)
            with self.assertRaisesOpError(
                    "KL calculation between .* and .* returned NaN values"):
                kl.eval()
            kl_ok = kullback_leibler.kl(a, a, allow_nan=True)
            self.assertAllEqual([float("nan")], kl_ok.eval())
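Example #14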
  def testIndirectRegistration(self):

    class Sub1(normal.Normal):
      pass

    class Sub2(normal.Normal):
      pass

    class Sub11(Sub1):
      pass

    # pylint: disable=unused-argument,unused-variable
    @kullback_leibler.RegisterKL(Sub1, Sub1)
    def _kl11(a, b, name=None):
      return "sub1-1"

    @kullback_leibler.RegisterKL(Sub1, Sub2)
    def _kl12(a, b, name=None):
      return "sub1-2"

    @kullback_leibler.RegisterKL(Sub2, Sub1)
    def _kl21(a, b, name=None):
      return "sub2-1"

    # pylint: enable=unused-argument,unused-variable

    sub1 = Sub1(loc=0.0, scale=1.0)
    sub2 = Sub2(loc=0.0, scale=1.0)
    sub11 = Sub11(loc=0.0, scale=1.0)

    self.assertEqual("sub1-1", kullback_leibler.kl(sub1, sub1, allow_nan=True))
    self.assertEqual("sub1-2", kullback_leibler.kl(sub1, sub2, allow_nan=True))
    self.assertEqual("sub2-1", kullback_leibler.kl(sub2, sub1, allow_nan=True))
    self.assertEqual(
        "sub1-1", kullback_leibler.kl(sub11, sub11, allow_nan=True))
    self.assertEqual("sub1-1", kullback_leibler.kl(sub11, sub1, allow_nan=True))
    self.assertEqual("sub1-2", kullback_leibler.kl(sub11, sub2, allow_nan=True))
    self.assertEqual("sub1-1", kullback_leibler.kl(sub11, sub1, allow_nan=True))
    self.assertEqual("sub1-2", kullback_leibler.kl(sub11, sub2, allow_nan=True))
    self.assertEqual("sub2-1", kullback_leibler.kl(sub2, sub11, allow_nan=True))
    self.assertEqual("sub1-1", kullback_leibler.kl(sub1, sub11, allow_nan=True))
Example #15
  def testRegistration(self):

    class MyDist(normal.Normal):
      pass

    # Register KL to a lambda that spits out the name parameter
    @kullback_leibler.RegisterKL(MyDist, MyDist)
    def _kl(a, b, name=None):  # pylint: disable=unused-argument,unused-variable
      return name

    a = MyDist(loc=0.0, scale=1.0)
    self.assertEqual("OK", kullback_leibler.kl(a, a, name="OK"))
  def testDomainErrorExceptions(self):

    class MyDistException(normal.Normal):
      pass

    # Register a KL implementation that deliberately returns NaN values
    @kullback_leibler.RegisterKL(MyDistException, MyDistException)
    # pylint: disable=unused-argument,unused-variable
    def _kl(a, b, name=None):
      return array_ops.identity([float("nan")])

    # pylint: enable=unused-argument,unused-variable

    with self.test_session():
      a = MyDistException(loc=0.0, scale=1.0)
      kl = kullback_leibler.kl(a, a)
      with self.assertRaisesOpError(
          "KL calculation between .* and .* returned NaN values"):
        kl.eval()
      kl_ok = kullback_leibler.kl(a, a, allow_nan=True)
      self.assertAllEqual([float("nan")], kl_ok.eval())
Example #17
    def testCategoricalCategoricalKL(self):
        def np_softmax(logits):
            exp_logits = np.exp(logits)
            return exp_logits / exp_logits.sum(axis=-1, keepdims=True)

        with self.test_session() as sess:
            for categories in [2, 10]:
                for batch_size in [1, 2]:
                    p_logits = self._rng.random_sample(
                        (batch_size, categories))
                    q_logits = self._rng.random_sample(
                        (batch_size, categories))
                    p = onehot_categorical.OneHotCategorical(logits=p_logits)
                    q = onehot_categorical.OneHotCategorical(logits=q_logits)
                    prob_p = np_softmax(p_logits)
                    prob_q = np_softmax(q_logits)
                    kl_expected = np.sum(prob_p *
                                         (np.log(prob_p) - np.log(prob_q)),
                                         axis=-1)

                    kl_actual = kullback_leibler.kl(p, q)
                    kl_same = kullback_leibler.kl(p, p)
                    x = p.sample(int(2e4), seed=0)
                    x = math_ops.cast(x, dtype=dtypes.float32)
                    # Compute empirical KL(p||q).
                    kl_sample = math_ops.reduce_mean(
                        p.log_prob(x) - q.log_prob(x), 0)

                    [kl_sample_, kl_actual_,
                     kl_same_] = sess.run([kl_sample, kl_actual, kl_same])
                    self.assertEqual(kl_actual.get_shape(), (batch_size, ))
                    self.assertAllClose(kl_same_, np.zeros_like(kl_expected))
                    self.assertAllClose(kl_actual_,
                                        kl_expected,
                                        atol=0.,
                                        rtol=1e-6)
                    self.assertAllClose(kl_sample_,
                                        kl_expected,
                                        atol=1e-2,
                                        rtol=0.)
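Example #18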
  def testRegistration(self):

    class MyDist(normal.Normal):
      pass

    # Register KL to a lambda that spits out the name parameter
    @kullback_leibler.RegisterKL(MyDist, MyDist)
    def _kl(a, b, name=None):  # pylint: disable=unused-argument,unused-variable
      return name

    a = MyDist(loc=0.0, scale=1.0)
    # Run kl() with allow_nan=True because strings can't go through is_nan.
    self.assertEqual("OK", kullback_leibler.kl(a, a, allow_nan=True, name="OK"))
Example #19
    def testRegistration(self):
        class MyDist(normal.Normal):
            pass

        # Register KL to a lambda that spits out the name parameter
        @kullback_leibler.RegisterKL(MyDist, MyDist)
        def _kl(a, b, name=None):  # pylint: disable=unused-argument,unused-variable
            return name

        a = MyDist(loc=0.0, scale=1.0)
        # Run kl() with allow_nan=True because strings can't go through is_nan.
        self.assertEqual("OK",
                         kullback_leibler.kl(a, a, allow_nan=True, name="OK"))
Example #20
    def testBernoulliBernoulliKL(self):
        with self.test_session() as sess:
            batch_size = 6
            a_p = np.array([0.5] * batch_size, dtype=np.float32)
            b_p = np.array([0.4] * batch_size, dtype=np.float32)

            a = bernoulli.Bernoulli(p=a_p)
            b = bernoulli.Bernoulli(p=b_p)

            kl = kullback_leibler.kl(a, b)
            kl_val = sess.run(kl)

            kl_expected = (a_p * np.log(a_p / b_p) + (1. - a_p) * np.log(
                (1. - a_p) / (1. - b_p)))

            self.assertEqual(kl.get_shape(), (batch_size, ))
            self.assertAllClose(kl_val, kl_expected)
Example #21
  def testBernoulliBernoulliKL(self):
    with self.test_session() as sess:
      batch_size = 6
      a_p = np.array([0.5] * batch_size, dtype=np.float32)
      b_p = np.array([0.4] * batch_size, dtype=np.float32)

      a = bernoulli.Bernoulli(probs=a_p)
      b = bernoulli.Bernoulli(probs=b_p)

      kl = kullback_leibler.kl(a, b)
      kl_val = sess.run(kl)

      kl_expected = (a_p * np.log(a_p / b_p) + (1. - a_p) * np.log(
          (1. - a_p) / (1. - b_p)))

      self.assertEqual(kl.get_shape(), (batch_size,))
      self.assertAllClose(kl_val, kl_expected)
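Note: for two Bernoulli distributions with success probabilities p_a and p_b, kl_expected is the closed-form divergence

    KL = p_a \log\frac{p_a}{p_b} + (1 - p_a) \log\frac{1 - p_a}{1 - p_b}

The two Bernoulli variants differ mainly in the constructor argument name (p= versus probs=), reflecting different library versions.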
Example #22
  def testNormalNormalKL(self):
    with self.test_session() as sess:
      batch_size = 6
      mu_a = np.array([3.0] * batch_size)
      sigma_a = np.array([1.0, 2.0, 3.0, 1.5, 2.5, 3.5])
      mu_b = np.array([-3.0] * batch_size)
      sigma_b = np.array([0.5, 1.0, 1.5, 2.0, 2.5, 3.0])

      n_a = normal_lib.Normal(loc=mu_a, scale=sigma_a)
      n_b = normal_lib.Normal(loc=mu_b, scale=sigma_b)

      kl = kullback_leibler.kl(n_a, n_b)
      kl_val = sess.run(kl)

      kl_expected = ((mu_a - mu_b)**2 / (2 * sigma_b**2) + 0.5 * (
          (sigma_a**2 / sigma_b**2) - 1 - 2 * np.log(sigma_a / sigma_b)))

      self.assertEqual(kl.get_shape(), (batch_size,))
      self.assertAllClose(kl_val, kl_expected)
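Note: kl_expected matches the closed-form KL divergence between two univariate normals,

    KL\big(\mathcal{N}(\mu_a, \sigma_a^2) \,\|\, \mathcal{N}(\mu_b, \sigma_b^2)\big) = \frac{(\mu_a - \mu_b)^2}{2\sigma_b^2} + \frac{1}{2}\left(\frac{\sigma_a^2}{\sigma_b^2} - 1 - 2\ln\frac{\sigma_a}{\sigma_b}\right)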
Example #23
  def testNormalNormalKL(self):
    with self.test_session() as sess:
      batch_size = 6
      mu_a = np.array([3.0] * batch_size)
      sigma_a = np.array([1.0, 2.0, 3.0, 1.5, 2.5, 3.5])
      mu_b = np.array([-3.0] * batch_size)
      sigma_b = np.array([0.5, 1.0, 1.5, 2.0, 2.5, 3.0])

      n_a = normal_lib.Normal(loc=mu_a, scale=sigma_a)
      n_b = normal_lib.Normal(loc=mu_b, scale=sigma_b)

      kl = kullback_leibler.kl(n_a, n_b)
      kl_val = sess.run(kl)

      kl_expected = ((mu_a - mu_b)**2 / (2 * sigma_b**2) + 0.5 * (
          (sigma_a**2 / sigma_b**2) - 1 - 2 * np.log(sigma_a / sigma_b)))

      self.assertEqual(kl.get_shape(), (batch_size,))
      self.assertAllClose(kl_val, kl_expected)
Example #24
def _elbo(form, log_likelihood, log_joint, variational_with_prior,
          keep_batch_dim):
  """Internal implementation of ELBO. Users should use `elbo`.

  Args:
    form: ELBOForms constant. Controls how the ELBO is computed.
    log_likelihood: `Tensor` log p(x|Z).
    log_joint: `Tensor` log p(x, Z).
    variational_with_prior: `dict<StochasticTensor, Distribution>`, mapping of
      variational distributions to prior distributions.
    keep_batch_dim: bool. Whether to keep the batch dimension when reducing
      the entropy/KL.

  Returns:
    ELBO `Tensor` with same shape and dtype as `log_likelihood`/`log_joint`.
  """
  ELBOForms.check_form(form)

  # Order of preference
  # 1. Analytic KL: log_likelihood - KL(q||p)
  # 2. Analytic entropy: log_likelihood + log p(Z) + H[q], or log_joint + H[q]
  # 3. Sample: log_likelihood - (log q(Z) - log p(Z)) =
  #            log_likelihood + log p(Z) - log q(Z), or log_joint - log q(Z)

  def _reduce(val):
    if keep_batch_dim:
      return val
    else:
      return math_ops.reduce_sum(val)

  kl_terms = []
  entropy_terms = []
  prior_terms = []
  for q, z, p in [(qz.distribution, qz.value(), pz)
                  for qz, pz in variational_with_prior.items()]:
    # Analytic KL
    kl = None
    if log_joint is None and form in {ELBOForms.default, ELBOForms.analytic_kl}:
      try:
        kl = kullback_leibler.kl(q, p)
        logging.info("Using analytic KL between q:%s, p:%s", q, p)
      except NotImplementedError as e:
        if form == ELBOForms.analytic_kl:
          raise e
    if kl is not None:
      kl_terms.append(-1. * _reduce(kl))
      continue

    # Analytic entropy
    entropy = None
    if form in {ELBOForms.default, ELBOForms.analytic_entropy}:
      try:
        entropy = q.entropy()
        logging.info("Using analytic entropy for q:%s", q)
      except NotImplementedError as e:
        if form == ELBOForms.analytic_entropy:
          raise e
    if entropy is not None:
      entropy_terms.append(_reduce(entropy))
      if log_likelihood is not None:
        prior = p.log_prob(z)
        prior_terms.append(_reduce(prior))
      continue

    # Sample
    if form in {ELBOForms.default, ELBOForms.sample}:
      entropy = -q.log_prob(z)
      entropy_terms.append(_reduce(entropy))
      if log_likelihood is not None:
        prior = p.log_prob(z)
        prior_terms.append(_reduce(prior))

  first_term = log_joint if log_joint is not None else log_likelihood
  return sum([first_term] + kl_terms + entropy_terms + prior_terms)
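Note: the three branches above correspond to three algebraically equivalent ways of writing the ELBO, tried in order of preference; in LaTeX form, roughly

    \mathcal{L} = \mathbb{E}_q[\log p(x \mid Z)] - KL(q \,\|\, p)
                = \mathbb{E}_q[\log p(x \mid Z)] + \mathbb{E}_q[\log p(Z)] + H[q]
                = \mathbb{E}_q[\log p(x \mid Z)] + \mathbb{E}_q[\log p(Z)] - \mathbb{E}_q[\log q(Z)]

where the first form uses an analytic KL, the second an analytic entropy H[q], and the third falls back to single-sample estimates of both terms (or the log_joint variants when log p(x, Z) is supplied directly).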