def testCategoricalCategoricalKL(self):

  def np_softmax(logits):
    exp_logits = np.exp(logits)
    return exp_logits / exp_logits.sum(axis=-1, keepdims=True)

  with self.cached_session() as sess:
    for categories in [2, 4]:
      for batch_size in [1, 10]:
        a_logits = np.random.randn(batch_size, categories)
        b_logits = np.random.randn(batch_size, categories)

        a = categorical.Categorical(logits=a_logits)
        b = categorical.Categorical(logits=b_logits)

        kl = kullback_leibler.kl_divergence(a, b)
        kl_val = sess.run(kl)
        # Make sure KL(a||a) is 0
        kl_same = sess.run(kullback_leibler.kl_divergence(a, a))

        prob_a = np_softmax(a_logits)
        prob_b = np_softmax(b_logits)
        kl_expected = np.sum(
            prob_a * (np.log(prob_a) - np.log(prob_b)), axis=-1)

        self.assertEqual(kl.shape, (batch_size,))
        self.assertAllClose(kl_val, kl_expected)
        self.assertAllClose(kl_same, np.zeros_like(kl_expected))

def testDomainErrorExceptions(self):

  class MyDistException(normal.Normal):
    pass

  # Register KL to a lambda that spits out the name parameter
  @kullback_leibler.RegisterKL(MyDistException, MyDistException)
  # pylint: disable=unused-argument,unused-variable
  def _kl(a, b, name=None):
    return tf.identity([float("nan")])

  # pylint: disable=unused-argument,unused-variable

  with self.cached_session():
    a = MyDistException(loc=0.0, scale=1.0, allow_nan_stats=False)
    kl = kullback_leibler.kl_divergence(a, a, allow_nan_stats=False)
    with self.assertRaisesOpError(
        "KL calculation between .* and .* returned NaN values"):
      kl.eval()
    with self.assertRaisesOpError(
        "KL calculation between .* and .* returned NaN values"):
      a.kl_divergence(a).eval()
    a = MyDistException(loc=0.0, scale=1.0, allow_nan_stats=True)
    kl_ok = kullback_leibler.kl_divergence(a, a)
    self.assertAllEqual([float("nan")], kl_ok.eval())
    self_kl_ok = a.kl_divergence(a)
    self.assertAllEqual([float("nan")], self_kl_ok.eval())
    cross_ok = a.cross_entropy(a)
    self.assertAllEqual([float("nan")], cross_ok.eval())

def testDirichletDirichletKL(self):
  conc1 = np.array([[1., 2., 3., 1.5, 2.5, 3.5],
                    [1.5, 2.5, 3.5, 4.5, 5.5, 6.5]])
  conc2 = np.array([[0.5, 1., 1.5, 2., 2.5, 3.]])

  d1 = dirichlet_lib.Dirichlet(conc1)
  d2 = dirichlet_lib.Dirichlet(conc2)
  x = d1.sample(int(1e4), seed=0)
  kl_sample = tf.reduce_mean(d1.log_prob(x) - d2.log_prob(x), 0)
  kl_actual = kullback_leibler.kl_divergence(d1, d2)

  kl_sample_val = self.evaluate(kl_sample)
  kl_actual_val = self.evaluate(kl_actual)

  self.assertEqual(conc1.shape[:-1], kl_actual.shape)

  if not special:
    return

  kl_expected = (
      special.gammaln(np.sum(conc1, -1)) -
      special.gammaln(np.sum(conc2, -1)) -
      np.sum(special.gammaln(conc1) - special.gammaln(conc2), -1) +
      np.sum((conc1 - conc2) * (special.digamma(conc1) - special.digamma(
          np.sum(conc1, -1, keepdims=True))), -1))

  self.assertAllClose(kl_expected, kl_actual_val, atol=0., rtol=1e-6)
  self.assertAllClose(kl_sample_val, kl_actual_val, atol=0., rtol=1e-1)

  # Make sure KL(d1||d1) is 0
  kl_same = self.evaluate(kullback_leibler.kl_divergence(d1, d1))
  self.assertAllClose(kl_same, np.zeros_like(kl_expected))

def testDomainErrorExceptions(self):

  class MyDistException(normal.Normal):
    pass

  # Register KL to a lambda that spits out the name parameter
  @kullback_leibler.RegisterKL(MyDistException, MyDistException)
  # pylint: disable=unused-argument,unused-variable
  def _kl(a, b, name=None):
    return tf.identity([float("nan")])

  # pylint: disable=unused-argument,unused-variable

  a = MyDistException(loc=0.0, scale=1.0, allow_nan_stats=False)
  with self.assertRaisesOpError(
      "KL calculation between .* and .* returned NaN values"):
    kl = kullback_leibler.kl_divergence(a, a, allow_nan_stats=False)
    self.evaluate(kl)
  with self.assertRaisesOpError(
      "KL calculation between .* and .* returned NaN values"):
    self.evaluate(a.kl_divergence(a))
  a = MyDistException(loc=0.0, scale=1.0, allow_nan_stats=True)
  kl_ok = kullback_leibler.kl_divergence(a, a)
  self.assertAllEqual([float("nan")], self.evaluate(kl_ok))
  self_kl_ok = a.kl_divergence(a)
  self.assertAllEqual([float("nan")], self.evaluate(self_kl_ok))
  cross_ok = a.cross_entropy(a)
  self.assertAllEqual([float("nan")], self.evaluate(cross_ok))

def testRegistration(self):

  class MyDist(normal.Normal):
    pass

  # Register KL to a lambda that spits out the name parameter
  @kullback_leibler.RegisterKL(MyDist, MyDist)
  def _kl(a, b, name=None):  # pylint: disable=unused-argument,unused-variable
    return name

  a = MyDist(loc=0.0, scale=1.0)
  self.assertEqual("OK", kullback_leibler.kl_divergence(a, a, name="OK"))

  # Check that everything still works for an object `d` for which
  # `d.__class__` is a registered type but `type(d)` is not registered.
  w_a = getattr(wrapt, "ObjectProxy", lambda x: x)(a)
  self.assertTrue(wrapt is None or type(w_a) != type(a))  # pylint: disable=unidiomatic-typecheck
  self.assertEqual(a.__class__, w_a.__class__)
  self.assertEqual("OK", kullback_leibler.kl_divergence(a, w_a, name="OK"))
  self.assertEqual("OK", kullback_leibler.kl_divergence(w_a, a, name="OK"))
  self.assertEqual("OK", kullback_leibler.kl_divergence(w_a, w_a, name="OK"))

  # NOTE: `Distribution.kl_divergence(other, name)` does not pass the name
  # through to `kullback_leibler.kl_divergence`
  self.assertEqual("KullbackLeibler", a.kl_divergence(w_a, name="OK"))
  self.assertEqual("KullbackLeibler", w_a.kl_divergence(a, name="OK"))
  self.assertEqual("KullbackLeibler", w_a.kl_divergence(w_a, name="OK"))

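# A minimal usage sketch (not part of the test file above) of the pattern the
# test exercises: register an analytic KL for a pair of distribution classes
# with `RegisterKL`, then let `kl_divergence` dispatch on the registered types.
# Assumes TensorFlow Probability is importable as `tfp`; the `MyUniform`
# subclass and its KL formula are illustrative only.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions


class MyUniform(tfd.Uniform):
  pass


@tfd.RegisterKL(MyUniform, MyUniform)
def _kl_myuniform_myuniform(a, b, name=None):
  """KL(a || b) for uniforms; finite only when support(a) lies inside support(b)."""
  with tf.name_scope(name or "kl_myuniform_myuniform"):
    return tf.math.log((b.high - b.low) / (a.high - a.low))


p = MyUniform(low=0., high=1.)
q = MyUniform(low=-1., high=2.)
print(tfd.kl_divergence(p, q))  # log(3) ~= 1.0986
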
def testBetaBetaKL(self):
  for shape in [(10,), (4, 5)]:
    a1 = 6.0 * np.random.random(size=shape) + 1e-4
    b1 = 6.0 * np.random.random(size=shape) + 1e-4
    a2 = 6.0 * np.random.random(size=shape) + 1e-4
    b2 = 6.0 * np.random.random(size=shape) + 1e-4

    d1 = beta_lib.Beta(concentration1=a1, concentration0=b1)
    d2 = beta_lib.Beta(concentration1=a2, concentration0=b2)

    if not special:
      return
    kl_expected = (special.betaln(a2, b2) - special.betaln(a1, b1) +
                   (a1 - a2) * special.digamma(a1) +
                   (b1 - b2) * special.digamma(b1) +
                   (a2 - a1 + b2 - b1) * special.digamma(a1 + b1))

    kl = kullback_leibler.kl_divergence(d1, d2)
    kl_val = self.evaluate(kl)
    self.assertEqual(kl.shape, shape)
    self.assertAllClose(kl_val, kl_expected)

    # Make sure KL(d1||d1) is 0
    kl_same = self.evaluate(kullback_leibler.kl_divergence(d1, d1))
    self.assertAllClose(kl_same, np.zeros_like(kl_expected))

def _kl_independent(a, b, name="kl_independent"):
  """Batched KL divergence `KL(a || b)` for Independent distributions.

  We can leverage the fact that
  ```
  KL(Independent(a) || Independent(b)) = sum(KL(a || b))
  ```
  where the sum is over the `reinterpreted_batch_ndims`.

  Args:
    a: Instance of `Independent`.
    b: Instance of `Independent`.
    name: (optional) name to use for created ops. Default "kl_independent".

  Returns:
    Batchwise `KL(a || b)`.

  Raises:
    ValueError: If the event space for `a` and `b`, or their underlying
      distributions don't match.
  """
  p = a.distribution
  q = b.distribution

  # The KL between any two (non)-batched distributions is a scalar.
  # Given that the KL between two factored distributions is the sum, i.e.
  # KL(p1(x)p2(y) || q1(x)q2(y)) = KL(p1 || q1) + KL(p2 || q2), we compute
  # KL(p || q) and do a `reduce_sum` on the reinterpreted batch dimensions.
  if a.event_shape.is_fully_defined() and b.event_shape.is_fully_defined():
    if a.event_shape == b.event_shape:
      if p.event_shape == q.event_shape:
        num_reduce_dims = a.event_shape.ndims - p.event_shape.ndims
        reduce_dims = [-i - 1 for i in range(0, num_reduce_dims)]

        return tf.reduce_sum(
            input_tensor=kullback_leibler.kl_divergence(p, q, name=name),
            axis=reduce_dims)
      else:
        raise NotImplementedError("KL between Independents with different "
                                  "event shapes not supported.")
    else:
      raise ValueError("Event shapes do not match.")
  else:
    with tf.control_dependencies([
        tf.compat.v1.assert_equal(a.event_shape_tensor(),
                                  b.event_shape_tensor()),
        tf.compat.v1.assert_equal(p.event_shape_tensor(),
                                  q.event_shape_tensor())
    ]):
      num_reduce_dims = (
          prefer_static.rank_from_shape(a.event_shape_tensor, a.event_shape) -
          prefer_static.rank_from_shape(p.event_shape_tensor, p.event_shape))
      reduce_dims = prefer_static.range(-num_reduce_dims, 0, 1)
      return tf.reduce_sum(
          input_tensor=kullback_leibler.kl_divergence(p, q, name=name),
          axis=reduce_dims)

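# A minimal numeric check (not part of the library code above) of the identity
# the docstring relies on: for `Independent` wrappers, the KL is the sum of the
# component KLs over the reinterpreted batch dimensions. Assumes TensorFlow
# Probability is importable as `tfp`; the parameter values are illustrative.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

p = tfd.Normal(loc=[0., 1., 2.], scale=[1., 1., 2.])
q = tfd.Normal(loc=[0.5, 0.5, 0.5], scale=[1., 2., 1.])

ind_p = tfd.Independent(p, reinterpreted_batch_ndims=1)
ind_q = tfd.Independent(q, reinterpreted_batch_ndims=1)

kl_independent = tfd.kl_divergence(ind_p, ind_q)    # scalar
kl_summed = tf.reduce_sum(tfd.kl_divergence(p, q))  # sum of 3 component KLs
tf.debugging.assert_near(kl_independent, kl_summed)
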
def testBackwardsCompatibilityDeterministic(self):
  tfp_normal = normal.Normal(0.0, 1.0)
  tf_normal = tf.distributions.Normal(0.0, 1.0)
  tfp_deterministic = deterministic.Deterministic(0.0)

  kullback_leibler.kl_divergence(tfp_deterministic, tf_normal)
  tf.distributions.kl_divergence(tfp_deterministic, tf_normal)
  kullback_leibler.kl_divergence(tfp_deterministic, tfp_normal)
  tf.distributions.kl_divergence(tfp_deterministic, tfp_normal)

def _kl_independent(a, b, name="kl_independent"):
  """Batched KL divergence `KL(a || b)` for Independent distributions.

  We can leverage the fact that
  ```
  KL(Independent(a) || Independent(b)) = sum(KL(a || b))
  ```
  where the sum is over the `reinterpreted_batch_ndims`.

  Args:
    a: Instance of `Independent`.
    b: Instance of `Independent`.
    name: (optional) name to use for created ops. Default "kl_independent".

  Returns:
    Batchwise `KL(a || b)`.

  Raises:
    ValueError: If the event space for `a` and `b`, or their underlying
      distributions don't match.
  """
  p = a.distribution
  q = b.distribution

  # The KL between any two (non)-batched distributions is a scalar.
  # Given that the KL between two factored distributions is the sum, i.e.
  # KL(p1(x)p2(y) || q1(x)q2(y)) = KL(p1 || q1) + KL(p2 || q2), we compute
  # KL(p || q) and do a `reduce_sum` on the reinterpreted batch dimensions.
  if a.event_shape.is_fully_defined() and b.event_shape.is_fully_defined():
    if a.event_shape == b.event_shape:
      if p.event_shape == q.event_shape:
        num_reduce_dims = a.event_shape.ndims - p.event_shape.ndims
        reduce_dims = [-i - 1 for i in range(0, num_reduce_dims)]

        return tf.reduce_sum(
            kullback_leibler.kl_divergence(p, q, name=name), axis=reduce_dims)
      else:
        raise NotImplementedError("KL between Independents with different "
                                  "event shapes not supported.")
    else:
      raise ValueError("Event shapes do not match.")
  else:
    with tf.control_dependencies([
        tf.assert_equal(a.event_shape_tensor(), b.event_shape_tensor()),
        tf.assert_equal(p.event_shape_tensor(), q.event_shape_tensor())
    ]):
      num_reduce_dims = (
          tf.shape(a.event_shape_tensor())[0] -
          tf.shape(p.event_shape_tensor())[0])
      reduce_dims = tf.range(-num_reduce_dims, 0, 1)
      return tf.reduce_sum(
          kullback_leibler.kl_divergence(p, q, name=name), axis=reduce_dims)

def testMissing(self):

  class MyDist(distribution_lib.Distribution):

    def __init__(self):
      super(MyDist, self).__init__(
          dtype=tf.float32,
          reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
          validate_args=True,
          allow_nan_stats=True)

  with self.assertRaisesRegexp(
      NotImplementedError, "No KL(distribution_a || distribution_b)"):
    kullback_leibler.kl_divergence(MyDist(), MyDist())

def __init__(
    self,
    rank,
    filters,
    kernel_size,
    is_mc,
    strides=1,
    padding="valid",
    data_format="channels_last",
    dilation_rate=1,
    activation=None,
    activity_regularizer=None,
    kernel_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(),
    kernel_posterior_tensor_fn=lambda d: d.sample(),
    kernel_prior_fn=tfp_layers_util.default_multivariate_normal_fn,
    kernel_divergence_fn=(lambda q, p, ignore: kl_lib.kl_divergence(q, p)),
    bias_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(
        is_singular=True),
    bias_posterior_tensor_fn=lambda d: d.sample(),
    bias_prior_fn=None,
    bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
    **kwargs):
  super(_ConvVariational, self).__init__(
      activity_regularizer=activity_regularizer, **kwargs)
  self.rank = rank
  self.is_mc = is_mc
  self.filters = filters
  self.kernel_size = tf_layers_util.normalize_tuple(
      kernel_size, rank, "kernel_size")
  self.strides = tf_layers_util.normalize_tuple(strides, rank, "strides")
  self.padding = tf_layers_util.normalize_padding(padding)
  self.data_format = tf_layers_util.normalize_data_format(data_format)
  self.dilation_rate = tf_layers_util.normalize_tuple(
      dilation_rate, rank, "dilation_rate")
  self.activation = tf.keras.activations.get(activation)
  self.input_spec = tf.keras.layers.InputSpec(ndim=self.rank + 2)
  self.kernel_posterior_fn = kernel_posterior_fn
  self.kernel_posterior_tensor_fn = kernel_posterior_tensor_fn
  self.kernel_prior_fn = kernel_prior_fn
  self.kernel_divergence_fn = kernel_divergence_fn
  self.bias_posterior_fn = bias_posterior_fn
  self.bias_posterior_tensor_fn = bias_posterior_tensor_fn
  self.bias_prior_fn = bias_prior_fn
  self.bias_divergence_fn = bias_divergence_fn

def testGammaGammaKL(self):
  alpha0 = np.array([3.])
  beta0 = np.array([1., 2., 3., 1.5, 2.5, 3.5])
  alpha1 = np.array([0.4])
  beta1 = np.array([0.5, 1., 1.5, 2., 2.5, 3.])

  # Build graph.
  g0 = gamma_lib.Gamma(concentration=alpha0, rate=beta0)
  g1 = gamma_lib.Gamma(concentration=alpha1, rate=beta1)
  x = g0.sample(int(1e4), seed=0)
  kl_sample = tf.reduce_mean(g0.log_prob(x) - g1.log_prob(x), 0)
  kl_actual = kullback_leibler.kl_divergence(g0, g1)

  # Execute graph.
  [kl_sample_, kl_actual_] = self.evaluate([kl_sample, kl_actual])

  self.assertEqual(beta0.shape, kl_actual.shape)

  if not special:
    return
  kl_expected = ((alpha0 - alpha1) * special.digamma(alpha0)
                 + special.gammaln(alpha1)
                 - special.gammaln(alpha0)
                 + alpha1 * np.log(beta0)
                 - alpha1 * np.log(beta1)
                 + alpha0 * (beta1 / beta0 - 1.))

  self.assertAllClose(kl_expected, kl_actual_, atol=0., rtol=1e-6)
  self.assertAllClose(kl_sample_, kl_actual_, atol=0., rtol=1e-1)

def _kl_masked_masked(a, b, name=None):
  """KL divergence between Masked distributions."""
  with tf.name_scope(name or 'kl_masked_masked'):
    a_valid = tf.convert_to_tensor(a.validity_mask)
    b_valid = tf.convert_to_tensor(b.validity_mask)
    underlying_kl = kullback_leibler.kl_divergence(
        a.distribution, b.distribution)

    # The treatment for KL is as follows:
    # When both random variables are valid, the underlying KL applies.
    # When neither random variable is valid, the KL is 0., i.e.
    # `a log a - a log b = 0` because log a and log b are everywhere 0.
    # When exactly one is valid, we (a) raise an assertion error, if either
    # distribution's allow_nan_stats is set to False, or (b) return nan in
    # such positions.
    asserts = []
    if not (a.allow_nan_stats and b.allow_nan_stats):
      asserts.append(
          assert_util.assert_equal(
              a_valid, b_valid,
              message='KL is only valid for matching mask values'))
    with tf.control_dependencies(asserts):
      both_valid = (a_valid & b_valid)
      neither_valid = (~a_valid) & (~b_valid)
      dtype = underlying_kl.dtype
      return tf.where(
          both_valid, underlying_kl,
          tf.where(neither_valid, tf.zeros([], dtype), float('nan')))

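# A small illustrative sketch (assumption: a `Masked` distribution that takes a
# `validity_mask` argument, e.g. `tfp.distributions.Masked` in recent TFP
# releases) of the behaviour described in the comments above: matching valid
# positions use the underlying KL, matching invalid positions contribute 0.
import tensorflow_probability as tfp

tfd = tfp.distributions

mask = [True, False, True]
a = tfd.Masked(tfd.Normal(loc=[0., 0., 0.], scale=1.), validity_mask=mask)
b = tfd.Masked(tfd.Normal(loc=[1., 1., 1.], scale=1.), validity_mask=mask)

# Expected: [0.5, 0.0, 0.5] -- KL(N(0,1) || N(1,1)) = 0.5 where valid, 0 where not.
print(tfd.kl_divergence(a, b))
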
def _kl_sample(a, b, name="kl_sample"):
  """Batched KL divergence `KL(a || b)` for BatchStacker distributions.

  We can leverage the fact that:

  ```
  KL(BatchStacker(a) || BatchStacker(b)) = sum(KL(a || b))
  ```

  where the sum is over the `batch_stack` dims.

  Args:
    a: Instance of `BatchStacker` distribution.
    b: Instance of `BatchStacker` distribution.
    name: (optional) name to use for created ops.
      Default value: `"kl_sample"`.

  Returns:
    kldiv: Batchwise `KL(a || b)`.

  Raises:
    ValueError: If the `batch_stack` of `a` and `b` don't match.
  """
  assertions = []
  a_ss = tf.get_static_value(a.batch_stack)
  b_ss = tf.get_static_value(b.batch_stack)
  msg = "`a.batch_stack` must be identical to `b.batch_stack`."
  if a_ss is not None and b_ss is not None:
    if not np.array_equal(a_ss, b_ss):
      raise ValueError(msg)
  elif a.validate_args or b.validate_args:
    assertions.append(
        assert_util.assert_equal(a.batch_stack, b.batch_stack, message=msg))
  with tf.control_dependencies(assertions):
    return kullback_leibler.kl_divergence(
        a.distribution, b.distribution, name=name)

def _kl_joint_joint(d0, d1, name=None):
  """Calculate the KL divergence between two `JointDistributionSequential`s.

  Args:
    d0: instance of a `JointDistributionSequential` object.
    d1: instance of a `JointDistributionSequential` object.
    name: (optional) Name to use for created operations.
      Default value: `"kl_joint_joint"`.

  Returns:
    kl_joint_joint: `Tensor` The sum of KL divergences between elemental
      distributions of two joint distributions.

  Raises:
    ValueError: when joint distributions have a different number of elemental
      distributions.
    ValueError: when either joint distribution has a distribution with dynamic
      dependency, i.e., when either joint distribution is not a collection of
      independent distributions.
  """
  if len(d0.distribution_fn) != len(d1.distribution_fn):
    raise ValueError(
        'Can only compute KL divergence between JointDistributionSequential '
        'distributions with the same number of component distributions.')
  if (not all(a is None for a in d0._dist_fn_args) or  # pylint: disable=protected-access
      not all(a is None for a in d1._dist_fn_args)):  # pylint: disable=protected-access
    raise ValueError(
        'Can only compute KL divergence when all distributions are '
        'independent.')
  with tf.name_scope(name or 'kl_jointseq_jointseq'):
    return sum(
        kullback_leibler.kl_divergence(d0_, d1_)
        for d0_, d1_ in zip(d0.distribution_fn, d1.distribution_fn))

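# A minimal usage sketch (not from the file above): when both joints are plain
# lists of component distributions (no lambdas, i.e. no dynamic dependencies),
# the KL registered above is just the sum of the component KLs. Assumes
# TensorFlow Probability is importable as `tfp`; the parameters are illustrative.
import tensorflow_probability as tfp

tfd = tfp.distributions

d0 = tfd.JointDistributionSequential([tfd.Normal(0., 1.), tfd.Gamma(2., 2.)])
d1 = tfd.JointDistributionSequential([tfd.Normal(1., 1.), tfd.Gamma(3., 2.)])

kl = tfd.kl_divergence(d0, d1)
# Equivalent to:
#   tfd.kl_divergence(tfd.Normal(0., 1.), tfd.Normal(1., 1.)) +
#   tfd.kl_divergence(tfd.Gamma(2., 2.), tfd.Gamma(3., 2.))
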
def surrogate_posterior_kl_divergence_prior(self, name=None):
  """Compute `KL(surrogate inducing point posterior || prior)`.

  See [Hensman, 2013][1].

  Args:
    name: Python `str` name prefixed to Ops created by this class.
      Default value: 'surrogate_posterior_kl_divergence_prior'.

  Returns:
    kl: Scalar tensor representing the KL between the (surrogate/variational)
      posterior over inducing point function values, and the GP prior over
      the inducing point function values.

  #### References

  [1]: Hensman, J., Lawrence, N. "Gaussian Processes for Big Data", 2013
       https://arxiv.org/abs/1309.6835
  """
  with tf.name_scope(name or 'surrogate_posterior_kl_divergence_prior'):
    inducing_prior = gaussian_process.GaussianProcess(
        kernel=self._kernel,
        mean_fn=self._mean_fn,
        index_points=self._inducing_index_points,
        observation_noise_variance=self._observation_noise_variance)

    return kullback_leibler.kl_divergence(
        self._variational_inducing_observations_posterior, inducing_prior)

def testGammaGammaKL(self):
  alpha0 = np.array([3.])
  beta0 = np.array([1., 2., 3., 1.5, 2.5, 3.5])
  alpha1 = np.array([0.4])
  beta1 = np.array([0.5, 1., 1.5, 2., 2.5, 3.])

  # Build graph.
  g0 = gamma_lib.Gamma(concentration=alpha0, rate=beta0)
  g1 = gamma_lib.Gamma(concentration=alpha1, rate=beta1)
  x = g0.sample(int(1e4), seed=0)
  kl_sample = tf.reduce_mean(g0.log_prob(x) - g1.log_prob(x), 0)
  kl_actual = kullback_leibler.kl_divergence(g0, g1)

  # Execute graph.
  [kl_sample_, kl_actual_] = self.evaluate([kl_sample, kl_actual])

  self.assertEqual(beta0.shape, kl_actual.get_shape())

  if not special:
    return
  kl_expected = ((alpha0 - alpha1) * special.digamma(alpha0)
                 + special.gammaln(alpha1)
                 - special.gammaln(alpha0)
                 + alpha1 * np.log(beta0)
                 - alpha1 * np.log(beta1)
                 + alpha0 * (beta1 / beta0 - 1.))

  self.assertAllClose(kl_expected, kl_actual_, atol=0., rtol=1e-6)
  self.assertAllClose(kl_sample_, kl_actual_, atol=0., rtol=1e-1)

def __init__(
    self,
    rank,
    filters,
    kernel_size,
    is_mc,
    strides=1,
    padding="valid",
    data_format="channels_last",
    dilation_rate=1,
    activation=None,
    activity_regularizer=None,
    kernel_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(),
    kernel_posterior_tensor_fn=lambda d: d.sample(),
    kernel_prior_fn=tfp_layers_util.default_multivariate_normal_fn,
    kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
    bias_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(
        is_singular=True),
    bias_posterior_tensor_fn=lambda d: d.sample(),
    bias_prior_fn=None,
    bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
    **kwargs):
  super(_ConvReparameterization, self).__init__(
      rank=rank,
      filters=filters,
      kernel_size=kernel_size,
      strides=strides,
      padding=padding,
      is_mc=is_mc,
      data_format=data_format,
      dilation_rate=dilation_rate,
      activation=tf.keras.activations.get(activation),
      activity_regularizer=activity_regularizer,
      kernel_posterior_fn=kernel_posterior_fn,
      kernel_posterior_tensor_fn=kernel_posterior_tensor_fn,
      kernel_prior_fn=kernel_prior_fn,
      kernel_divergence_fn=kernel_divergence_fn,
      bias_posterior_fn=bias_posterior_fn,
      bias_posterior_tensor_fn=bias_posterior_tensor_fn,
      bias_prior_fn=bias_prior_fn,
      bias_divergence_fn=bias_divergence_fn,
      **kwargs)

def testBackwardsCompatibilityAliases(self):
  # Each element is a tuple(classes), tuple(args).
  aliases = [
      ((bernoulli.Bernoulli, tf.distributions.Bernoulli), (1.,)),
      ((beta.Beta, tf.distributions.Beta), (1., 1.)),
      ((categorical.Categorical, tf.distributions.Categorical),
       ([1.0, 1.0],)),
      ((dirichlet.Dirichlet, tf.distributions.Dirichlet), ([1.0],)),
      ((gamma.Gamma, tf.distributions.Gamma), (1.0, 1.0)),
      ((normal.Normal, tf.distributions.Normal), (1.0, 1.0)),
  ]
  for dists, args in aliases:
    for class0, class1 in itertools.permutations(dists):
      d0 = class0(*args)
      d1 = class1(*args)
      kullback_leibler.kl_divergence(d0, d1)
      tf.distributions.kl_divergence(d0, d1)

def _kl_independent(a: Batchwise, b: Batchwise, name='kl_batch_concatenation'):
  r"""Batched KL divergence `KL(a || b)` for concatenated distributions.

  This is just the concatenation of the component-wise KL divergences along
  the batch axis.
  """
  kls = []
  for d1, d2 in zip(a.distributions, b.distributions):
    kls.append(kullback_leibler.kl_divergence(d1, d2, name=name))
  return tf.concat(kls, axis=a.axis)

def add_kl_loss(self, posterior_dist, prior_dist):
  """Add KL divergence loss."""
  if self.kl_use_exact:
    self.add_loss(
        kl_lib.kl_divergence(posterior_dist, prior_dist) *
        self.kl_weight * self.kl_anneal)
  else:
    self.add_loss(
        self._kl_approximation(posterior_dist, prior_dist) *
        self.kl_weight * self.kl_anneal)

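# A self-contained sketch (illustrative, not the original class above; the
# `kl_weight`/`kl_anneal` names only mirror that method) of how an annealed,
# exact KL penalty is typically attached to a Keras model via `add_loss`.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions


class KLRegularizedEncoder(tf.keras.layers.Layer):
  """Encodes inputs to a diagonal Gaussian and adds an annealed KL penalty."""

  def __init__(self, latent_dim, kl_weight=1.0, **kwargs):
    super().__init__(**kwargs)
    self.kl_weight = kl_weight
    self.kl_anneal = tf.Variable(0., trainable=False)  # ramped up by a callback
    self.loc = tf.keras.layers.Dense(latent_dim)
    self.raw_scale = tf.keras.layers.Dense(latent_dim)
    self.prior = tfd.Independent(
        tfd.Normal(loc=tf.zeros(latent_dim), scale=1.),
        reinterpreted_batch_ndims=1)

  def call(self, x):
    posterior = tfd.Independent(
        tfd.Normal(loc=self.loc(x), scale=tf.nn.softplus(self.raw_scale(x))),
        reinterpreted_batch_ndims=1)
    # Exact KL, averaged over the batch, scaled by weight and anneal factor.
    self.add_loss(
        tf.reduce_mean(tfd.kl_divergence(posterior, self.prior)) *
        self.kl_weight * self.kl_anneal)
    return posterior.sample()
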
def __init__(
    self,
    units,
    activation=None,
    activity_regularizer=None,
    trainable=True,
    kernel_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(),
    kernel_posterior_tensor_fn=lambda d: d.sample(),
    kernel_prior_fn=tfp_layers_util.default_multivariate_normal_fn,
    kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
    bias_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(
        is_singular=True),
    bias_posterior_tensor_fn=lambda d: d.sample(),
    bias_prior_fn=None,
    bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
    seed=None,
    **kwargs):  # pylint: disable=g-doc-args
  """Construct layer.

  Args:
    ${args}
    seed: Python scalar `int` which initializes the random number
      generator. Default value: `None` (i.e., use global seed).
  """
  # pylint: enable=g-doc-args
  super(DenseFlipout, self).__init__(
      units=units,
      activation=activation,
      activity_regularizer=activity_regularizer,
      trainable=trainable,
      kernel_posterior_fn=kernel_posterior_fn,
      kernel_posterior_tensor_fn=kernel_posterior_tensor_fn,
      kernel_prior_fn=kernel_prior_fn,
      kernel_divergence_fn=kernel_divergence_fn,
      bias_posterior_fn=bias_posterior_fn,
      bias_posterior_tensor_fn=bias_posterior_tensor_fn,
      bias_prior_fn=bias_prior_fn,
      bias_divergence_fn=bias_divergence_fn,
      **kwargs)
  # Set additional attributes which do not exist in the parent class.
  self.seed = seed

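# A short usage sketch (not from the file above): the `kernel_divergence_fn`
# hook is commonly overridden to scale the exact KL by the number of training
# examples so the per-batch ELBO is weighted correctly. `NUM_TRAIN_EXAMPLES`
# is a placeholder value you would set for your own dataset.
import tensorflow as tf
import tensorflow_probability as tfp

NUM_TRAIN_EXAMPLES = 60000  # illustrative value

scaled_kl_fn = (
    lambda q, p, _: tfp.distributions.kl_divergence(q, p) /
    tf.cast(NUM_TRAIN_EXAMPLES, dtype=tf.float32))

layer = tfp.layers.DenseFlipout(
    units=10,
    activation=tf.nn.relu,
    kernel_divergence_fn=scaled_kl_fn)
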
def _kl_independent(a: CombinedDistribution,
                    b: CombinedDistribution,
                    name='kl_combined'):
  r"""Batched KL divergence `KL(a || b)` for CombinedDistribution distributions.

  This is just the sum of the KL divergences of the component distributions.
  """
  kl = 0.
  for d1, d2 in zip(a.distributions, b.distributions):
    kl += kullback_leibler.kl_divergence(d1, d2, name=name)
  return kl

def testRegistration(self):

  class MyDist(normal.Normal):
    pass

  # Register KL to a lambda that spits out the name parameter
  @kullback_leibler.RegisterKL(MyDist, MyDist)
  def _kl(a, b, name=None):  # pylint: disable=unused-argument,unused-variable
    return name

  a = MyDist(loc=0.0, scale=1.0)
  self.assertEqual("OK", kullback_leibler.kl_divergence(a, a, name="OK"))

def __init__(
    self,
    units,
    activation=None,
    activity_regularizer=None,
    trainable=True,
    kernel_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(),
    kernel_posterior_tensor_fn=lambda d: d.sample(),
    kernel_prior_fn=tfp_layers_util.default_multivariate_normal_fn,
    kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
    bias_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(
        is_singular=True),
    bias_posterior_tensor_fn=lambda d: d.sample(),
    bias_prior_fn=None,
    bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
    **kwargs):  # pylint: disable=g-doc-args
  """Construct layer.

  Args:
    ${args}
  """
  # pylint: enable=g-doc-args
  super(DenseLocalReparameterization, self).__init__(
      units=units,
      activation=activation,
      activity_regularizer=activity_regularizer,
      trainable=trainable,
      kernel_posterior_fn=kernel_posterior_fn,
      kernel_posterior_tensor_fn=kernel_posterior_tensor_fn,
      kernel_prior_fn=kernel_prior_fn,
      kernel_divergence_fn=kernel_divergence_fn,
      bias_posterior_fn=bias_posterior_fn,
      bias_posterior_tensor_fn=bias_posterior_tensor_fn,
      bias_prior_fn=bias_prior_fn,
      bias_divergence_fn=bias_divergence_fn,
      **kwargs)

def __init__(
    self,
    units,
    activation=None,
    activity_regularizer=None,
    kernel_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(),
    kernel_posterior_tensor_fn=lambda d: d.sample(),
    kernel_prior_fn=tfp_layers_util.default_multivariate_normal_fn,
    kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
    bias_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(
        is_singular=True),
    bias_posterior_tensor_fn=lambda d: d.sample(),
    bias_prior_fn=None,
    bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
    **kwargs):  # pylint: disable=g-doc-args
  """Construct layer.

  Args:
    ${args}
  """
  # pylint: enable=g-doc-args
  super(_DenseVariational, self).__init__(
      activity_regularizer=activity_regularizer, **kwargs)
  self.units = units
  self.activation = tf.keras.activations.get(activation)
  self.input_spec = tf.keras.layers.InputSpec(min_ndim=2)
  self.kernel_posterior_fn = kernel_posterior_fn
  self.kernel_posterior_tensor_fn = kernel_posterior_tensor_fn
  self.kernel_prior_fn = kernel_prior_fn
  self.kernel_divergence_fn = kernel_divergence_fn
  self.bias_posterior_fn = bias_posterior_fn
  self.bias_posterior_tensor_fn = bias_posterior_tensor_fn
  self.bias_prior_fn = bias_prior_fn
  self.bias_divergence_fn = bias_divergence_fn

def testBernoulliBernoulliKL(self):
  batch_size = 6
  a_p = np.array([0.5] * batch_size, dtype=np.float32)
  b_p = np.array([0.4] * batch_size, dtype=np.float32)

  a = bernoulli.Bernoulli(probs=a_p)
  b = bernoulli.Bernoulli(probs=b_p)

  kl = kullback_leibler.kl_divergence(a, b)
  kl_val = self.evaluate(kl)

  kl_expected = (a_p * np.log(a_p / b_p) +
                 (1. - a_p) * np.log((1. - a_p) / (1. - b_p)))

  self.assertEqual(kl.shape, (batch_size,))
  self.assertAllClose(kl_val, kl_expected)

def _kl_sample(a: BatchStacker, b: BatchStacker,
               name: str = "kl_sample") -> tf.Tensor:
  r"""Batched KL divergence :math:`KL(a || b)` for :class:`~.BatchStacker` distributions.

  We can leverage the fact that:

  .. math::

      KL(BatchStacker(a) || BatchStacker(b)) = \sum(KL(a || b))

  where the :math:`\sum` is over the ``batch_stack`` dims.

  Parameters
  ----------
  a : BatchStacker
      Instance of ``BatchStacker`` distribution.
  b : BatchStacker
      Instance of ``BatchStacker`` distribution.
  name : str
      Name to use for created ops.

  Returns
  -------
  kldiv : tf.Tensor
      Batchwise :math:`KL(a || b)`.

  Raises
  ------
  ValueError
      If the ``batch_stack`` of ``a`` and ``b`` don't match.
  """
  assertions = []
  a_ss = tf.get_static_value(a.batch_stack)
  b_ss = tf.get_static_value(b.batch_stack)
  msg = "`a.batch_stack` must be identical to `b.batch_stack`."
  if a_ss is not None and b_ss is not None:
    if not np.array_equal(a_ss, b_ss):
      raise ValueError(msg)
  elif a.validate_args or b.validate_args:
    assertions.append(
        assert_util.assert_equal(a.batch_stack, b.batch_stack, message=msg))
  with tf.control_dependencies(assertions):
    return kullback_leibler.kl_divergence(
        a.distribution, b.distribution, name=name)

def _kl_blockwise_blockwise(b0, b1, name=None):
  """Calculate the batched KL divergence KL(b0 || b1) with b0 and b1 Blockwise distributions.

  Args:
    b0: instance of a Blockwise distribution object.
    b1: instance of a Blockwise distribution object.
    name: (optional) Name to use for created operations. Default is
      "kl_blockwise_blockwise".

  Returns:
    kl_blockwise_blockwise: `Tensor`. The batchwise KL(b0 || b1).
  """
  if len(b0.distributions) != len(b1.distributions):
    raise ValueError(
        'Can only compute KL divergence between Blockwise distributions with '
        'the same number of component distributions.')

  # We also need to check that the event shapes match for each one.
  b0_event_sizes = [_event_size(d) for d in b0.distributions]
  b1_event_sizes = [_event_size(d) for d in b1.distributions]

  assertions = []
  message = ('Can only compute KL divergence between Blockwise distributions '
             'with the same pairwise event shapes.')

  if (all(isinstance(event_size, int) for event_size in b0_event_sizes) and
      all(isinstance(event_size, int) for event_size in b1_event_sizes)):
    if b0_event_sizes != b1_event_sizes:
      raise ValueError(message)
  else:
    if b0.validate_args or b1.validate_args:
      assertions.extend(
          assert_util.assert_equal(  # pylint: disable=g-complex-comprehension
              e1, e2, message=message)
          for e1, e2 in zip(b0_event_sizes, b1_event_sizes))

  with tf.compat.v2.name_scope(name or 'kl_blockwise_blockwise'):
    with tf.control_dependencies(assertions):
      return sum([
          kullback_leibler.kl_divergence(d1, d2)
          for d1, d2 in zip(b0.distributions, b1.distributions)
      ])

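# A minimal numeric sketch (not from the file above) of the behaviour
# implemented above: the KL between two `Blockwise` distributions with matching
# pairwise event sizes is the sum of the component KLs. Assumes TensorFlow
# Probability is importable as `tfp`; the parameters are illustrative.
import tensorflow_probability as tfp

tfd = tfp.distributions

b0 = tfd.Blockwise([tfd.Normal(0., 1.), tfd.Dirichlet([1., 1., 1.])])
b1 = tfd.Blockwise([tfd.Normal(1., 1.), tfd.Dirichlet([2., 2., 2.])])

kl = tfd.kl_divergence(b0, b1)
# Equals tfd.kl_divergence(Normal(0, 1), Normal(1, 1)) +
#        tfd.kl_divergence(Dirichlet([1, 1, 1]), Dirichlet([2, 2, 2])).
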
def testNormalNormalKL(self):
  batch_size = 6
  mu_a = np.array([3.0] * batch_size)
  sigma_a = np.array([1.0, 2.0, 3.0, 1.5, 2.5, 3.5])
  mu_b = np.array([-3.0] * batch_size)
  sigma_b = np.array([0.5, 1.0, 1.5, 2.0, 2.5, 3.0])

  n_a = normal_lib.Normal(loc=mu_a, scale=sigma_a)
  n_b = normal_lib.Normal(loc=mu_b, scale=sigma_b)

  kl = kullback_leibler.kl_divergence(n_a, n_b)
  kl_val = self.evaluate(kl)

  kl_expected = ((mu_a - mu_b)**2 / (2 * sigma_b**2) + 0.5 * (
      (sigma_a**2 / sigma_b**2) - 1 - 2 * np.log(sigma_a / sigma_b)))

  self.assertEqual(kl.shape, (batch_size,))
  self.assertAllClose(kl_val, kl_expected)

def _kl_logitnormal_logitnormal(a, b, name=None):
  """Calculate the batched KL divergence KL(a || b) with a and b LogitNormal.

  This is the same as the KL divergence between the underlying Normal
  distributions.

  Args:
    a: instance of a LogitNormal distribution object.
    b: instance of a LogitNormal distribution object.
    name: Name to use for created operations.
      Default value: `None` (i.e., `'kl_logitnormal_logitnormal'`).

  Returns:
    kl_div: Batchwise KL(a || b)
  """
  return kullback_leibler.kl_divergence(
      a.distribution,
      b.distribution,
      name=(name or 'kl_logitnormal_logitnormal'))

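# A minimal check (not from the file above) of the statement in the docstring:
# with the KL registered above, the KL between two LogitNormals equals the KL
# between their underlying Normals, since the sigmoid is a fixed invertible
# transformation. Assumes TensorFlow Probability is importable as `tfp`.
import tensorflow_probability as tfp

tfd = tfp.distributions

a = tfd.LogitNormal(loc=0., scale=1.)
b = tfd.LogitNormal(loc=1., scale=2.)

kl_logitnormal = tfd.kl_divergence(a, b)
kl_normal = tfd.kl_divergence(tfd.Normal(0., 1.), tfd.Normal(1., 2.))
# Both evaluate to the same value (~0.4431 nats).
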
def _kl_sample(a, b, name='kl_sample'):
  """Batched KL divergence `KL(a || b)` for Sample distributions.

  We can leverage the fact that:

  ```
  KL(Sample(a) || Sample(b)) = sum(KL(a || b))
  ```

  where the sum is over the `sample_shape` dims.

  Args:
    a: Instance of `Sample` distribution.
    b: Instance of `Sample` distribution.
    name: (optional) name to use for created ops.
      Default value: `"kl_sample"`.

  Returns:
    kldiv: Batchwise `KL(a || b)`.

  Raises:
    ValueError: If the `sample_shape` of `a` and `b` don't match.
  """
  assertions = []
  a_ss = tf.get_static_value(a.sample_shape)
  b_ss = tf.get_static_value(b.sample_shape)
  msg = '`a.sample_shape` must be identical to `b.sample_shape`.'
  if a_ss is not None and b_ss is not None:
    if not np.array_equal(a_ss, b_ss):
      raise ValueError(msg)
  elif a.validate_args or b.validate_args:
    assertions.append(
        assert_util.assert_equal(a.sample_shape, b.sample_shape, message=msg))
  with tf.control_dependencies(assertions):
    kl = kullback_leibler.kl_divergence(
        a.distribution, b.distribution, name=name)
    n = ps.reduce_prod(a.sample_shape)
    return tf.cast(x=n, dtype=kl.dtype) * kl

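# A minimal numeric sketch (not from the file above) of the scaling implemented
# above: for `Sample` wrappers the component KL is multiplied by the number of
# i.i.d. draws in `sample_shape`. Assumes TensorFlow Probability is importable
# as `tfp`; the parameters are illustrative.
import tensorflow_probability as tfp

tfd = tfp.distributions

a = tfd.Sample(tfd.Normal(0., 1.), sample_shape=[4])
b = tfd.Sample(tfd.Normal(1., 1.), sample_shape=[4])

kl = tfd.kl_divergence(a, b)
# KL(N(0,1) || N(1,1)) = 0.5, so `kl` evaluates to 4 * 0.5 = 2.0.
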
def _kl_independent(a, b, name='kl_independent'):
  r"""Batched KL divergence `KL(a || b)` for ConditionalTensor distributions.

  This will just ignore the concatenated tensor and return `kl_divergence`
  of the original distribution.

  Arguments:
    a: Instance of `ConditionalTensor`.
    b: Instance of `ConditionalTensor`.
    name: (optional) name to use for created ops. Default 'kl_independent'.

  Returns:
    Batchwise `KL(a || b)`.

  Raises:
    ValueError: If the event space for `a` and `b`, or their underlying
      distributions don't match.
  """
  return kullback_leibler.kl_divergence(
      a.distribution, b.distribution, name=name)

def _kl_divergence(self, other):
  return kullback_leibler.kl_divergence(
      self, other, allow_nan_stats=self.allow_nan_stats)

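# A small usage sketch (not from the file above): the convenience method above
# is what backs `Distribution.kl_divergence`, so the module-level function and
# the instance method agree; `cross_entropy` is then entropy plus KL. Assumes
# TensorFlow Probability is importable as `tfp`.
import tensorflow_probability as tfp

tfd = tfp.distributions

p = tfd.Normal(loc=0., scale=1.)
q = tfd.Normal(loc=1., scale=2.)

kl_fn = tfd.kl_divergence(p, q)  # module-level dispatch
kl_method = p.kl_divergence(q)   # instance method, same value
xent = p.cross_entropy(q)        # equals p.entropy() + kl_method
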
def testBackwardsCompatibilityFallback(self):
  tf_normal = tf.distributions.Normal(0.0, 1.0)
  kullback_leibler.kl_divergence(tf_normal, tf_normal)