def testCategoricalCategoricalKL(self):

  def np_softmax(logits):
    exp_logits = np.exp(logits)
    return exp_logits / exp_logits.sum(axis=-1, keepdims=True)

  with self.test_session() as sess:
    for categories in [2, 4]:
      for batch_size in [1, 10]:
        a_logits = np.random.randn(batch_size, categories)
        b_logits = np.random.randn(batch_size, categories)

        a = categorical.Categorical(logits=a_logits)
        b = categorical.Categorical(logits=b_logits)

        kl = kullback_leibler.kl(a, b)
        kl_val = sess.run(kl)
        # Make sure KL(a||a) is 0.
        kl_same = sess.run(kullback_leibler.kl(a, a))

        prob_a = np_softmax(a_logits)
        prob_b = np_softmax(b_logits)
        kl_expected = np.sum(prob_a * (np.log(prob_a) - np.log(prob_b)),
                             axis=-1)

        self.assertEqual(kl.get_shape(), (batch_size,))
        self.assertAllClose(kl_val, kl_expected)
        self.assertAllClose(kl_same, np.zeros_like(kl_expected))
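# Illustrative standalone check (not part of the original test file) of the
# closed-form categorical KL the test expects:
#   KL(p || q) = sum_i p_i * (log p_i - log q_i),
# computed here from raw logits via a numerically stable log-softmax.
import numpy as np

def categorical_kl_from_logits(a_logits, b_logits):
  # Subtracting the per-row max before exponentiating avoids overflow.
  a = a_logits - a_logits.max(axis=-1, keepdims=True)
  b = b_logits - b_logits.max(axis=-1, keepdims=True)
  log_pa = a - np.log(np.exp(a).sum(axis=-1, keepdims=True))
  log_pb = b - np.log(np.exp(b).sum(axis=-1, keepdims=True))
  return np.sum(np.exp(log_pa) * (log_pa - log_pb), axis=-1)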
def testBetaBetaKL(self):
  with self.test_session() as sess:
    for shape in [(10,), (4, 5)]:
      a1 = 6.0 * np.random.random(size=shape) + 1e-4
      b1 = 6.0 * np.random.random(size=shape) + 1e-4
      a2 = 6.0 * np.random.random(size=shape) + 1e-4
      b2 = 6.0 * np.random.random(size=shape) + 1e-4
      # Take inverse softplus of values to test BetaWithSoftplusAB.
      a1_sp = np.log(np.exp(a1) - 1.0)
      b1_sp = np.log(np.exp(b1) - 1.0)
      a2_sp = np.log(np.exp(a2) - 1.0)
      b2_sp = np.log(np.exp(b2) - 1.0)

      d1 = beta_lib.Beta(a=a1, b=b1)
      d2 = beta_lib.Beta(a=a2, b=b2)
      d1_sp = beta_lib.BetaWithSoftplusAB(a=a1_sp, b=b1_sp)
      d2_sp = beta_lib.BetaWithSoftplusAB(a=a2_sp, b=b2_sp)

      kl_expected = (special.betaln(a2, b2) - special.betaln(a1, b1) +
                     (a1 - a2) * special.digamma(a1) +
                     (b1 - b2) * special.digamma(b1) +
                     (a2 - a1 + b2 - b1) * special.digamma(a1 + b1))

      for dist1 in [d1, d1_sp]:
        for dist2 in [d2, d2_sp]:
          kl = kullback_leibler.kl(dist1, dist2)
          kl_val = sess.run(kl)
          self.assertEqual(kl.get_shape(), shape)
          self.assertAllClose(kl_val, kl_expected)

      # Make sure KL(d1||d1) is 0.
      kl_same = sess.run(kullback_leibler.kl(d1, d1))
      self.assertAllClose(kl_same, np.zeros_like(kl_expected))
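# The Beta-Beta KL identity behind `kl_expected` above, restated as a
# standalone helper (illustrative sketch; assumes only SciPy):
#   KL(Beta(a1, b1) || Beta(a2, b2))
#     = ln B(a2, b2) - ln B(a1, b1) + (a1 - a2) * psi(a1)
#       + (b1 - b2) * psi(b1) + (a2 - a1 + b2 - b1) * psi(a1 + b1),
# where B is the beta function and psi is the digamma function.
from scipy import special

def beta_kl(a1, b1, a2, b2):
  return (special.betaln(a2, b2) - special.betaln(a1, b1) +
          (a1 - a2) * special.digamma(a1) +
          (b1 - b2) * special.digamma(b1) +
          (a2 - a1 + b2 - b1) * special.digamma(a1 + b1))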
def testCategoricalCategoricalKL(self):

  def np_softmax(logits):
    exp_logits = np.exp(logits)
    return exp_logits / exp_logits.sum(axis=-1, keepdims=True)

  with self.test_session() as sess:
    for categories in [2, 10]:
      for batch_size in [1, 2]:
        p_logits = self._rng.random_sample((batch_size, categories))
        q_logits = self._rng.random_sample((batch_size, categories))
        p = onehot_categorical.OneHotCategorical(logits=p_logits)
        q = onehot_categorical.OneHotCategorical(logits=q_logits)

        prob_p = np_softmax(p_logits)
        prob_q = np_softmax(q_logits)
        kl_expected = np.sum(prob_p * (np.log(prob_p) - np.log(prob_q)),
                             axis=-1)

        kl_actual = kullback_leibler.kl(p, q)
        kl_same = kullback_leibler.kl(p, p)
        x = p.sample(int(2e4), seed=0)
        x = math_ops.cast(x, dtype=dtypes.float32)
        # Compute empirical KL(p||q).
        kl_sample = math_ops.reduce_mean(p.log_prob(x) - q.log_prob(x), 0)

        [kl_sample_, kl_actual_, kl_same_] = sess.run(
            [kl_sample, kl_actual, kl_same])

        self.assertEqual(kl_actual.get_shape(), (batch_size,))
        self.assertAllClose(kl_same_, np.zeros_like(kl_expected))
        self.assertAllClose(kl_actual_, kl_expected, atol=0., rtol=1e-6)
        self.assertAllClose(kl_sample_, kl_expected, atol=1e-2, rtol=0.)
def testIndirectRegistration(self):

  class Sub1(normal.Normal):
    pass

  class Sub2(normal.Normal):
    pass

  class Sub11(Sub1):
    pass

  # pylint: disable=unused-argument,unused-variable
  @kullback_leibler.RegisterKL(Sub1, Sub1)
  def _kl11(a, b, name=None):
    return "sub1-1"

  @kullback_leibler.RegisterKL(Sub1, Sub2)
  def _kl12(a, b, name=None):
    return "sub1-2"

  @kullback_leibler.RegisterKL(Sub2, Sub1)
  def _kl21(a, b, name=None):
    return "sub2-1"

  # pylint: enable=unused-argument,unused-variable

  sub1 = Sub1(loc=0.0, scale=1.0)
  sub2 = Sub2(loc=0.0, scale=1.0)
  sub11 = Sub11(loc=0.0, scale=1.0)

  # Sub11 has no direct registration, so lookup falls back to its parent Sub1.
  self.assertEqual("sub1-1", kullback_leibler.kl(sub1, sub1, allow_nan=True))
  self.assertEqual("sub1-2", kullback_leibler.kl(sub1, sub2, allow_nan=True))
  self.assertEqual("sub2-1", kullback_leibler.kl(sub2, sub1, allow_nan=True))
  self.assertEqual("sub1-1", kullback_leibler.kl(sub11, sub11, allow_nan=True))
  self.assertEqual("sub1-1", kullback_leibler.kl(sub11, sub1, allow_nan=True))
  self.assertEqual("sub1-2", kullback_leibler.kl(sub11, sub2, allow_nan=True))
  self.assertEqual("sub2-1", kullback_leibler.kl(sub2, sub11, allow_nan=True))
  self.assertEqual("sub1-1", kullback_leibler.kl(sub1, sub11, allow_nan=True))
def testGammaGammaKL(self):
  alpha0 = np.array([3.])
  beta0 = np.array([1., 2., 3., 1.5, 2.5, 3.5])
  alpha1 = np.array([0.4])
  beta1 = np.array([0.5, 1., 1.5, 2., 2.5, 3.])

  # Build graph.
  with self.test_session() as sess:
    g0 = gamma_lib.Gamma(alpha=alpha0, beta=beta0)
    g1 = gamma_lib.Gamma(alpha=alpha1, beta=beta1)
    x = g0.sample(int(1e4), seed=0)
    kl_sample = math_ops.reduce_mean(g0.log_prob(x) - g1.log_prob(x), 0)
    kl_actual = kullback_leibler.kl(g0, g1)

    # Execute graph.
    [kl_sample_, kl_actual_] = sess.run([kl_sample, kl_actual])

  kl_expected = ((alpha0 - alpha1) * special.digamma(alpha0) +
                 special.gammaln(alpha1) - special.gammaln(alpha0) +
                 alpha1 * np.log(beta0) - alpha1 * np.log(beta1) +
                 alpha0 * (beta1 / beta0 - 1.))

  self.assertEqual(beta0.shape, kl_actual.get_shape())
  self.assertAllClose(kl_expected, kl_actual_, atol=0., rtol=1e-6)
  self.assertAllClose(kl_sample_, kl_actual_, atol=0., rtol=1e-2)
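# The rate-parameterized Gamma-Gamma KL that `kl_expected` encodes, as a
# standalone helper (illustrative; assumes only NumPy/SciPy):
#   KL(Gamma(a0, b0) || Gamma(a1, b1))
#     = (a0 - a1) * psi(a0) - ln Gamma(a0) + ln Gamma(a1)
#       + a1 * (ln b0 - ln b1) + a0 * (b1 / b0 - 1).
import numpy as np
from scipy import special

def gamma_kl(alpha0, beta0, alpha1, beta1):
  return ((alpha0 - alpha1) * special.digamma(alpha0) +
          special.gammaln(alpha1) - special.gammaln(alpha0) +
          alpha1 * (np.log(beta0) - np.log(beta1)) +
          alpha0 * (beta1 / beta0 - 1.))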
def testDefaultVariationalAndPrior(self):
  _, prior, variational, _, log_likelihood = mini_vae()
  elbo = vi.elbo(log_likelihood)
  expected_elbo = log_likelihood - kullback_leibler.kl(
      variational.distribution, prior)
  with self.test_session() as sess:
    sess.run(variables.global_variables_initializer())
    self.assertAllEqual(*sess.run([expected_elbo, elbo]))
def testExplicitVariationalAndPrior(self):
  with self.test_session() as sess:
    _, _, variational, _, log_likelihood = mini_vae()
    prior = normal.Normal(loc=3., scale=2.)
    elbo = vi.elbo(
        log_likelihood, variational_with_prior={variational: prior})
    expected_elbo = log_likelihood - kullback_leibler.kl(
        variational.distribution, prior)
    sess.run(variables.global_variables_initializer())
    self.assertAllEqual(*sess.run([expected_elbo, elbo]))
def testDomainErrorExceptions(self):

  class MyDistException(normal.Normal):
    pass

  # Register a KL that always returns NaN values.
  @kullback_leibler.RegisterKL(MyDistException, MyDistException)
  # pylint: disable=unused-argument,unused-variable
  def _kl(a, b, name=None):
    return array_ops.identity([float("nan")])

  # pylint: enable=unused-argument,unused-variable

  with self.test_session():
    a = MyDistException(loc=0.0, scale=1.0)
    kl = kullback_leibler.kl(a, a)
    with self.assertRaisesOpError(
        "KL calculation between .* and .* returned NaN values"):
      kl.eval()
    kl_ok = kullback_leibler.kl(a, a, allow_nan=True)
    self.assertAllEqual([float("nan")], kl_ok.eval())
def testRegistration(self):

  class MyDist(normal.Normal):
    pass

  # Register KL to a lambda that spits out the name parameter.
  @kullback_leibler.RegisterKL(MyDist, MyDist)
  def _kl(a, b, name=None):  # pylint: disable=unused-argument,unused-variable
    return name

  a = MyDist(loc=0.0, scale=1.0)
  # Run kl() with allow_nan=True because strings can't go through is_nan.
  self.assertEqual("OK", kullback_leibler.kl(a, a, allow_nan=True, name="OK"))
def testBernoulliBernoulliKL(self):
  with self.test_session() as sess:
    batch_size = 6
    a_p = np.array([0.5] * batch_size, dtype=np.float32)
    b_p = np.array([0.4] * batch_size, dtype=np.float32)

    a = bernoulli.Bernoulli(probs=a_p)
    b = bernoulli.Bernoulli(probs=b_p)

    kl = kullback_leibler.kl(a, b)
    kl_val = sess.run(kl)

    kl_expected = (a_p * np.log(a_p / b_p) +
                   (1. - a_p) * np.log((1. - a_p) / (1. - b_p)))

    self.assertEqual(kl.get_shape(), (batch_size,))
    self.assertAllClose(kl_val, kl_expected)
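# Worked arithmetic check of the Bernoulli KL formula used above
# (illustrative, not from the original tests):
#   KL(Ber(p) || Ber(q)) = p * log(p/q) + (1 - p) * log((1-p)/(1-q)).
# For p = 0.5, q = 0.4: 0.5*log(1.25) + 0.5*log(0.5/0.6) ~= 0.0204.
import numpy as np

def bernoulli_kl(p, q):
  return p * np.log(p / q) + (1. - p) * np.log((1. - p) / (1. - q))

assert np.isclose(bernoulli_kl(0.5, 0.4), 0.0204, atol=1e-4)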
def testNormalNormalKL(self):
  with self.test_session() as sess:
    batch_size = 6
    mu_a = np.array([3.0] * batch_size)
    sigma_a = np.array([1.0, 2.0, 3.0, 1.5, 2.5, 3.5])
    mu_b = np.array([-3.0] * batch_size)
    sigma_b = np.array([0.5, 1.0, 1.5, 2.0, 2.5, 3.0])

    n_a = normal_lib.Normal(loc=mu_a, scale=sigma_a)
    n_b = normal_lib.Normal(loc=mu_b, scale=sigma_b)

    kl = kullback_leibler.kl(n_a, n_b)
    kl_val = sess.run(kl)

    kl_expected = ((mu_a - mu_b)**2 / (2 * sigma_b**2) + 0.5 * (
        (sigma_a**2 / sigma_b**2) - 1 - 2 * np.log(sigma_a / sigma_b)))

    self.assertEqual(kl.get_shape(), (batch_size,))
    self.assertAllClose(kl_val, kl_expected)
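# The Normal-Normal KL from `kl_expected`, rearranged into its common
# textbook form (illustrative helper, not part of the test file):
#   KL(N(mu_a, s_a^2) || N(mu_b, s_b^2))
#     = log(s_b / s_a) + (s_a^2 + (mu_a - mu_b)^2) / (2 * s_b^2) - 1/2.
import numpy as np

def normal_kl(mu_a, sigma_a, mu_b, sigma_b):
  return (np.log(sigma_b / sigma_a) +
          (sigma_a**2 + (mu_a - mu_b)**2) / (2. * sigma_b**2) - 0.5)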
def _elbo(form, log_likelihood, log_joint, variational_with_prior,
          keep_batch_dim):
  """Internal implementation of ELBO. Users should use `elbo`.

  Args:
    form: ELBOForms constant. Controls how the ELBO is computed.
    log_likelihood: `Tensor` log p(x|Z).
    log_joint: `Tensor` log p(x, Z).
    variational_with_prior: `dict<StochasticTensor, Distribution>`, mapping
      variational distributions to prior distributions.
    keep_batch_dim: bool. Whether to keep the batch dimension when reducing
      the entropy/KL.

  Returns:
    ELBO `Tensor` with same shape and dtype as `log_likelihood`/`log_joint`.
  """
  ELBOForms.check_form(form)

  # Order of preference:
  # 1. Analytic KL: log_likelihood - KL(q||p)
  # 2. Analytic entropy: log_likelihood + log p(Z) + H[q], or log_joint + H[q]
  # 3. Sample: log_likelihood - (log q(Z) - log p(Z)) =
  #      log_likelihood + log p(Z) - log q(Z), or log_joint - log q(Z)

  def _reduce(val):
    if keep_batch_dim:
      return val
    else:
      return math_ops.reduce_sum(val)

  kl_terms = []
  entropy_terms = []
  prior_terms = []
  for q, z, p in [(qz.distribution, qz.value(), pz)
                  for qz, pz in variational_with_prior.items()]:
    # Analytic KL.
    kl = None
    if log_joint is None and form in {ELBOForms.default,
                                      ELBOForms.analytic_kl}:
      try:
        kl = kullback_leibler.kl(q, p)
        logging.info("Using analytic KL between q:%s, p:%s", q, p)
      except NotImplementedError as e:
        if form == ELBOForms.analytic_kl:
          raise e
    if kl is not None:
      kl_terms.append(-1. * _reduce(kl))
      continue

    # Analytic entropy.
    entropy = None
    if form in {ELBOForms.default, ELBOForms.analytic_entropy}:
      try:
        entropy = q.entropy()
        logging.info("Using analytic entropy for q:%s", q)
      except NotImplementedError as e:
        if form == ELBOForms.analytic_entropy:
          raise e
    if entropy is not None:
      entropy_terms.append(_reduce(entropy))
      if log_likelihood is not None:
        prior = p.log_prob(z)
        prior_terms.append(_reduce(prior))
      continue

    # Sample.
    if form in {ELBOForms.default, ELBOForms.sample}:
      entropy = -q.log_prob(z)
      entropy_terms.append(_reduce(entropy))
      if log_likelihood is not None:
        prior = p.log_prob(z)
        prior_terms.append(_reduce(prior))

  first_term = log_joint if log_joint is not None else log_likelihood
  return sum([first_term] + kl_terms + entropy_terms + prior_terms)
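# Monte-Carlo sketch (an illustrative assumption, not library code) of why the
# three forms in `_elbo` agree in expectation: with
#   KL(q||p) = E_q[log q(Z) - log p(Z)]  and  H[q] = -E_q[log q(Z)],
# the "analytic KL", "analytic entropy", and "sample" forms all subtract the
# same quantity from the likelihood term.
import numpy as np
from scipy import stats

rng = np.random.RandomState(0)
q = stats.norm(loc=0., scale=1.)   # variational distribution
p = stats.norm(loc=1., scale=2.)   # prior
z = q.rvs(size=200000, random_state=rng)

# Closed form: log(s_p/s_q) + (s_q^2 + (mu_q - mu_p)^2) / (2 s_p^2) - 1/2.
kl_analytic = np.log(2. / 1.) + (1.**2 + (0. - 1.)**2) / (2. * 2.**2) - 0.5
kl_sample = np.mean(q.logpdf(z) - p.logpdf(z))       # sample form
kl_entropy = -q.entropy() - np.mean(p.logpdf(z))     # analytic-entropy form

# All three are ~0.4431 up to Monte-Carlo error.
assert np.allclose([kl_sample, kl_entropy], kl_analytic, atol=5e-3)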