def testLogNormalLogNormalKL(self):
  """LogNormal KL matches the underlying Normal KL, the closed form, and MC."""
  batch_size = 6
  loc_a = np.array([3.0] * batch_size)
  scale_a = np.array([1.0, 2.0, 3.0, 1.5, 2.5, 3.5])
  loc_b = np.array([-3.0] * batch_size)
  scale_b = np.array([0.5, 1.0, 1.5, 2.0, 2.5, 3.0])

  ln_a = tfd.LogNormal(loc=loc_a, scale=scale_a, validate_args=True)
  ln_b = tfd.LogNormal(loc=loc_b, scale=scale_b, validate_args=True)
  kl = tfd.kl_divergence(ln_a, ln_b)
  kl_val = self.evaluate(kl)

  # The KL between two log-normals equals the KL between the normals they
  # are built from (exp is a fixed bijection).
  normal_a = tfd.Normal(loc=loc_a, scale=scale_a, validate_args=True)
  normal_b = tfd.Normal(loc=loc_b, scale=scale_b, validate_args=True)
  kl_expected_from_normal = tfd.kl_divergence(normal_a, normal_b)

  # Closed-form KL between the two underlying normals.
  kl_expected_from_formula = (
      (loc_a - loc_b)**2 / (2 * scale_b**2)
      + 0.5 * ((scale_a**2 / scale_b**2) - 1
               - 2 * np.log(scale_a / scale_b)))

  # Monte-Carlo estimate of the same divergence.
  samples = ln_a.sample(int(2e5), seed=test_util.test_seed())
  kl_sample_ = self.evaluate(
      tf.reduce_mean(ln_a.log_prob(samples) - ln_b.log_prob(samples), axis=0))

  self.assertEqual(kl.shape, (batch_size,))
  self.assertAllClose(kl_val, kl_expected_from_normal)
  self.assertAllClose(kl_val, kl_expected_from_formula)
  self.assertAllClose(
      kl_expected_from_formula, kl_sample_, atol=0.0, rtol=1e-2)
def testKLRaises(self):
  """kl_divergence raises on mismatched or structurally different event shapes."""
  ind1 = tfd.Independent(
      distribution=tfd.Normal(
          loc=np.float32([-1., 1]), scale=np.float32([0.1, 0.5])),
      reinterpreted_batch_ndims=1,
      validate_args=True)
  ind2 = tfd.Independent(
      distribution=tfd.Normal(loc=np.float32(-1), scale=np.float32(0.5)),
      reinterpreted_batch_ndims=0,
      validate_args=True)
  # `assertRaisesRegexp` is a deprecated alias (removed in Python 3.12);
  # use `assertRaisesRegex`, matching the rest of the file.
  with self.assertRaisesRegex(ValueError, 'Event shapes do not match'):
    tfd.kl_divergence(ind1, ind2)

  ind1 = tfd.Independent(
      distribution=tfd.Normal(
          loc=np.float32([-1., 1]), scale=np.float32([0.1, 0.5])),
      reinterpreted_batch_ndims=1,
      validate_args=True)
  ind2 = tfd.Independent(
      distribution=tfd.MultivariateNormalDiag(
          loc=np.float32([-1., 1]), scale_diag=np.float32([0.1, 0.5])),
      reinterpreted_batch_ndims=0,
      validate_args=True)
  with self.assertRaisesRegex(NotImplementedError, 'different event shapes'):
    tfd.kl_divergence(ind1, ind2)
def testUniformUniformKLFinite(self):
  """KL between nested uniforms is log(|support B|) - log(|support A|)."""
  batch_size = 6
  a_low = -1.0 * np.arange(1, batch_size + 1)
  a_high = np.array([1.0] * batch_size)
  b_low = -2.0 * np.arange(1, batch_size + 1)
  b_high = np.array([2.0] * batch_size)
  dist_a = tfd.Uniform(low=a_low, high=a_high, validate_args=True)
  dist_b = tfd.Uniform(low=b_low, high=b_high, validate_args=True)

  # The support of `a` is contained in the support of `b`, so the KL is
  # finite and equals the log-ratio of interval lengths.
  true_kl = np.log(b_high - b_low) - np.log(a_high - a_low)
  kl = tfd.kl_divergence(dist_a, dist_b)

  # This is essentially an approximated integral from the direct definition
  # of KL divergence.
  samples = dist_a.sample(int(1e4), seed=test_util.test_seed())
  kl_sample = tf.reduce_mean(
      dist_a.log_prob(samples) - dist_b.log_prob(samples), axis=0)

  kl_, kl_sample_ = self.evaluate([kl, kl_sample])
  self.assertAllClose(true_kl, kl_, atol=2e-15)
  self.assertAllClose(true_kl, kl_sample_, atol=0.0, rtol=1e-1)

  # KL of a distribution with itself is exactly zero.
  zero_kl = tfd.kl_divergence(dist_a, dist_a)
  true_zero_kl_, zero_kl_ = self.evaluate([tf.zeros_like(true_kl), zero_kl])
  self.assertAllEqual(true_zero_kl_, zero_kl_)
def __init__(
    self,
    units,
    activation=None,
    activity_regularizer=None,
    kernel_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(),
    kernel_posterior_tensor_fn=lambda d: d.sample(),
    kernel_prior_fn=tfp_layers_util.default_multivariate_normal_fn,
    kernel_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
    bias_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(is_singular=True),  # pylint: disable=line-too-long
    bias_posterior_tensor_fn=lambda d: d.sample(),
    bias_prior_fn=None,
    bias_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
    **kwargs):
  # pylint: disable=g-doc-args
  """Construct layer.

  Args:
    ${args}
  """
  # pylint: enable=g-doc-args
  super(_DenseVariational, self).__init__(
      activity_regularizer=activity_regularizer, **kwargs)
  self.units = units
  self.activation = tf.keras.activations.get(activation)
  # `tf.layers.InputSpec` no longer exists in TF2; use the Keras equivalent,
  # matching the other definition of this class in the file.
  self.input_spec = tf.keras.layers.InputSpec(min_ndim=2)
  # Callables describing the kernel/bias posteriors, priors and divergences,
  # consumed later when the layer builds its variables.
  self.kernel_posterior_fn = kernel_posterior_fn
  self.kernel_posterior_tensor_fn = kernel_posterior_tensor_fn
  self.kernel_prior_fn = kernel_prior_fn
  self.kernel_divergence_fn = kernel_divergence_fn
  self.bias_posterior_fn = bias_posterior_fn
  self.bias_posterior_tensor_fn = bias_posterior_tensor_fn
  self.bias_prior_fn = bias_prior_fn
  self.bias_divergence_fn = bias_divergence_fn
def __init__(
    self,
    units,
    activation=None,
    activity_regularizer=None,
    kernel_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(),
    kernel_posterior_tensor_fn=lambda d: d.sample(),
    kernel_prior_fn=tfp_layers_util.default_multivariate_normal_fn,
    kernel_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
    bias_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(is_singular=True),  # pylint: disable=line-too-long
    bias_posterior_tensor_fn=lambda d: d.sample(),
    bias_prior_fn=None,
    bias_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
    **kwargs):
  # pylint: disable=g-doc-args
  """Construct layer.

  Args:
    ${args}
  """
  # pylint: enable=g-doc-args
  super(_DenseVariational, self).__init__(
      activity_regularizer=activity_regularizer, **kwargs)
  self.units = units
  self.activation = tf.keras.activations.get(activation)
  self.input_spec = tf.keras.layers.InputSpec(min_ndim=2)
  # Store the bias-related callables for use when the layer is built.
  self.bias_posterior_fn = bias_posterior_fn
  self.bias_posterior_tensor_fn = bias_posterior_tensor_fn
  self.bias_prior_fn = bias_prior_fn
  self.bias_divergence_fn = bias_divergence_fn
  # Store the kernel-related callables for use when the layer is built.
  self.kernel_posterior_fn = kernel_posterior_fn
  self.kernel_posterior_tensor_fn = kernel_posterior_tensor_fn
  self.kernel_prior_fn = kernel_prior_fn
  self.kernel_divergence_fn = kernel_divergence_fn
def __init__(
    self,
    units,
    activation=None,
    activity_regularizer=None,
    client_weight=1.,
    trainable=True,
    kernel_posterior_fn=None,
    kernel_posterior_tensor_fn=(lambda d: d.sample()),
    kernel_prior_fn=None,
    kernel_divergence_fn=(
        lambda q, p, ignore: tfd.kl_divergence(q, p)),
    bias_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(
        is_singular=True),
    bias_posterior_tensor_fn=(lambda d: d.sample()),
    bias_prior_fn=None,
    bias_divergence_fn=(lambda q, p, ignore: tfd.kl_divergence(q, p)),
    **kwargs):
  """Construct a shared natural-parameter dense variational layer.

  Args mirror the parent layer; `kwargs` may additionally carry
  `untransformed_scale_initializer`, `loc_initializer` and
  `delta_percentile`, which are consumed here and never reach the parent.
  """
  # `kwargs.pop(key, None)` replaces the membership-test-then-pop dance and
  # the backslash line continuations of the original.
  self.untransformed_scale_initializer = kwargs.pop(
      'untransformed_scale_initializer', None)
  self.loc_initializer = kwargs.pop('loc_initializer', None)
  self.delta_percentile = kwargs.pop('delta_percentile', None)
  # Default to the natural-parameterization factories defined on this class.
  if kernel_posterior_fn is None:
    kernel_posterior_fn = self.renormalize_natural_mean_field_normal_fn
  if kernel_prior_fn is None:
    kernel_prior_fn = self.natural_tensor_multivariate_normal_fn
  super(DenseSharedNatural, self).__init__(
      units,
      activation=activation,
      activity_regularizer=activity_regularizer,
      trainable=trainable,
      kernel_posterior_fn=kernel_posterior_fn,
      kernel_posterior_tensor_fn=kernel_posterior_tensor_fn,
      kernel_prior_fn=kernel_prior_fn,
      kernel_divergence_fn=kernel_divergence_fn,
      bias_posterior_fn=bias_posterior_fn,
      bias_posterior_tensor_fn=bias_posterior_tensor_fn,
      bias_prior_fn=bias_prior_fn,
      bias_divergence_fn=bias_divergence_fn,
      **kwargs)
  self.client_weight = client_weight
  self.delta_function = tf.subtract
  # Sparsify deltas when a percentile is given, except for softmax layers.
  if self.delta_percentile and activation != 'softmax':
    self.delta_function = sparse_delta_function(self.delta_percentile)
    # Typo fixed: 'sparisfication' -> 'sparsification'.
    print(self, activation, 'using delta sparsification')
  self.apply_delta_function = tf.add
  # Client/server bookkeeping used by the federated update logic.
  self.client_variable_dict = {}
  self.client_center_variable_dict = {}
  self.server_variable_dict = {}
def __init__(
    self,
    filters,
    kernel_size,
    strides=1,
    padding='valid',
    client_weight=1.,
    data_format='channels_last',
    dilation_rate=1,
    activation=None,
    activity_regularizer=None,
    kernel_posterior_fn=None,
    kernel_posterior_tensor_fn=(lambda d: d.sample()),
    kernel_prior_fn=None,
    kernel_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
    bias_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(
        is_singular=True),
    bias_posterior_tensor_fn=lambda d: d.sample(),
    bias_prior_fn=None,
    bias_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
    **kwargs):
  """Construct a natural-parameter variational 1-D convolution layer."""
  # The optional initializer arrives through kwargs so the parent
  # constructor never sees it.
  self.untransformed_scale_initializer = kwargs.pop(
      'untransformed_scale_initializer', None)
  # Default to the natural-parameterization factories defined on this class.
  if kernel_posterior_fn is None:
    kernel_posterior_fn = self.renormalize_natural_mean_field_normal_fn
  if kernel_prior_fn is None:
    kernel_prior_fn = self.natural_tensor_multivariate_normal_fn
  super(Conv1DVirtualNatural, self).__init__(
      filters=filters,
      kernel_size=kernel_size,
      strides=strides,
      padding=padding,
      data_format=data_format,
      dilation_rate=dilation_rate,
      activation=tf.keras.activations.get(activation),
      activity_regularizer=activity_regularizer,
      kernel_posterior_fn=kernel_posterior_fn,
      kernel_posterior_tensor_fn=kernel_posterior_tensor_fn,
      kernel_prior_fn=kernel_prior_fn,
      kernel_divergence_fn=kernel_divergence_fn,
      bias_posterior_fn=bias_posterior_fn,
      bias_posterior_tensor_fn=bias_posterior_tensor_fn,
      bias_prior_fn=bias_prior_fn,
      bias_divergence_fn=bias_divergence_fn,
      **kwargs)
  self.client_weight = client_weight
  # Client/server bookkeeping used by the federated update logic.
  self.delta_function = tf.subtract
  self.apply_delta_function = tf.add
  self.client_variable_dict = {}
  self.client_center_variable_dict = {}
  self.server_variable_dict = {}
def testTransformedKLDifferentBijectorFails(self):
  """No KL is registered for transformed distributions with unequal bijectors."""
  d1 = self._cls()(
      tfd.Exponential(rate=0.25),
      bijector=tfb.Scale(scale=2.),
      validate_args=True)
  d2 = self._cls()(
      tfd.Exponential(rate=0.25),
      bijector=tfb.Scale(scale=3.),
      validate_args=True)
  # Same base distribution family, different bijectors -> not implemented.
  with self.assertRaisesRegex(NotImplementedError,
                              r'their bijectors are not equal'):
    tfd.kl_divergence(d1, d2)
def test_kl(self):
  """Masked KL: entries valid in both match the base KL; mismatches give NaN/0."""
  a = tfd.MultivariateNormalDiag([tf.range(3.)] * 4, tf.ones(3))
  b = tfd.MultivariateNormalDiag([tf.range(3.) + .5] * 4, tf.ones(3))
  kl = tfd.kl_divergence(
      tfd.Masked(a, tf.sequence_mask(3, 4)),
      tfd.Masked(b, tf.sequence_mask(2, 4)))
  kl2 = tfd.kl_divergence(
      tfd.Masked(a, tf.sequence_mask(2, 4)),
      tfd.Masked(b, tf.sequence_mask(3, 4)))
  # First two entries are valid under both masks.
  self.assertAllClose(a.kl_divergence(b)[:2], kl[:2])
  # Valid on one side only -> undefined, i.e. NaN.
  self.assertAllEqual(float('nan'), kl[2])
  # Masked on both sides -> contributes zero.
  self.assertAllEqual(0., kl[3])
  self.assertAllEqual(float('nan'), kl2[2])
def testKLBatchBroadcast(self):
  """KL broadcasts a batched MVN against an unbatched one."""
  batch_shape = [2]
  event_shape = [3]
  mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
  # Second distribution deliberately has no batch shape.
  mu_b, sigma_b = self._random_mu_and_sigma([], event_shape)
  mvn_a = tfd.MultivariateNormalTriL(
      loc=mu_a, scale_tril=np.linalg.cholesky(sigma_a), validate_args=True)
  mvn_b = tfd.MultivariateNormalTriL(
      loc=mu_b, scale_tril=np.linalg.cholesky(sigma_b), validate_args=True)

  kl = tfd.kl_divergence(mvn_a, mvn_b)
  self.assertEqual(batch_shape, kl.shape)

  kl_v = self.evaluate(kl)
  # Each batch member should match the direct non-batch computation.
  for i in range(2):
    expected_kl = _compute_non_batch_kl(
        mu_a[i, :], sigma_a[i, :, :], mu_b, sigma_b)
    self.assertAllClose(expected_kl, kl_v[i])
def testWeibullWeibullKL(self):
  """Closed-form Weibull-Weibull KL agrees with a Monte-Carlo estimate."""
  # `np.math` was a non-public alias of the stdlib `math` module and was
  # removed in NumPy >= 1.25; import the real thing locally instead.
  import math

  a_concentration = np.array([3.])
  a_scale = np.array([2.])
  b_concentration = np.array([6.])
  b_scale = np.array([4.])
  a = tfd.Weibull(
      concentration=a_concentration, scale=a_scale, validate_args=True)
  b = tfd.Weibull(
      concentration=b_concentration, scale=b_scale, validate_args=True)
  kl = tfd.kl_divergence(a, b)
  # Gamma(b/a + 1) evaluated via lgamma, as in the original expression.
  # `.item()` extracts the scalar, which is what math.lgamma expects.
  gamma_term = np.exp(
      math.lgamma((b_concentration / a_concentration + 1.).item()))
  expected_kl = (
      np.log(a_concentration / a_scale**a_concentration)
      - np.log(b_concentration / b_scale**b_concentration)
      + ((a_concentration - b_concentration)
         * (np.log(a_scale) - np.euler_gamma / a_concentration))
      + (a_scale / b_scale)**b_concentration * gamma_term
      - 1.)
  x = a.sample(int(1e5), seed=test_util.test_seed())
  kl_sample = tf.reduce_mean(a.log_prob(x) - b.log_prob(x), axis=0)
  kl_sample_val = self.evaluate(kl_sample)
  self.assertAllClose(expected_kl, kl_sample_val, atol=0.0, rtol=1e-2)
  self.assertAllClose(expected_kl, self.evaluate(kl))
def testKLBatchBroadcast(self):
  """KL broadcasts a batched linear-operator MVN against an unbatched one."""
  batch_shape = [2]
  event_shape = [3]
  loc_a, scale_a = self._random_loc_and_scale(batch_shape, event_shape)
  # Second distribution deliberately has no batch shape.
  loc_b, scale_b = self._random_loc_and_scale([], event_shape)
  mvn_a = tfd.MultivariateNormalLinearOperator(
      loc=loc_a, scale=scale_a, validate_args=True)
  mvn_b = tfd.MultivariateNormalLinearOperator(
      loc=loc_b, scale=scale_b, validate_args=True)

  kl = tfd.kl_divergence(mvn_a, mvn_b)
  self.assertEqual(batch_shape, kl.shape)

  kl_v = self.evaluate(kl)
  # Densify the scale operators once, then check each batch member against
  # the direct non-batch computation.
  scale_a_dense = self.evaluate(scale_a.to_dense())
  scale_b_dense = self.evaluate(scale_b.to_dense())
  for i in range(2):
    expected_kl = self._compute_non_batch_kl(
        loc_a[i, :], scale_a_dense[i, :, :], loc_b, scale_b_dense)
    self.assertAllClose(expected_kl, kl_v[i])
def testNormalNormalKL(self):
  """Analytic Normal-Normal KL agrees with the formula and with sampling."""
  batch_size = 6
  loc_a = np.array([3.0] * batch_size)
  scale_a = np.array([1.0, 2.0, 3.0, 1.5, 2.5, 3.5])
  loc_b = np.array([-3.0] * batch_size)
  scale_b = np.array([0.5, 1.0, 1.5, 2.0, 2.5, 3.0])
  dist_a = tfd.Normal(loc=loc_a, scale=scale_a, validate_args=True)
  dist_b = tfd.Normal(loc=loc_b, scale=scale_b, validate_args=True)

  kl = tfd.kl_divergence(dist_a, dist_b)
  kl_val = self.evaluate(kl)

  # Closed-form KL between two univariate normals.
  kl_expected = (
      (loc_a - loc_b)**2 / (2 * scale_b**2)
      + 0.5 * ((scale_a**2 / scale_b**2) - 1
               - 2 * np.log(scale_a / scale_b)))

  samples = dist_a.sample(int(1e5), seed=test_util.test_seed())
  kl_samples_ = self.evaluate(
      dist_a.log_prob(samples) - dist_b.log_prob(samples))

  self.assertEqual(kl.shape, (batch_size,))
  self.assertAllClose(kl_val, kl_expected)
  self.assertAllMeansClose(
      kl_samples_, kl_expected, axis=0, atol=0.0, rtol=1e-2)
def testKLScalarToMultivariate(self):
  """KL of Independent wrappers equals the sum of component-wise normal KLs."""
  normal1 = tfd.Normal(
      loc=np.float32([-1., 1]), scale=np.float32([0.1, 0.5]))
  ind1 = tfd.Independent(
      distribution=normal1, reinterpreted_batch_ndims=1, validate_args=True)
  normal2 = tfd.Normal(
      loc=np.float32([-3., 3]), scale=np.float32([0.3, 0.3]))
  ind2 = tfd.Independent(
      distribution=normal2, reinterpreted_batch_ndims=1, validate_args=True)
  elementwise_kl = tfd.kl_divergence(normal1, normal2)
  ind_kl = tfd.kl_divergence(ind1, ind2)
  # Reinterpreting the batch dim as an event dim sums the per-element KLs.
  self.assertAllClose(
      self.evaluate(tf.reduce_sum(input_tensor=elementwise_kl, axis=-1)),
      self.evaluate(ind_kl))
def testContinuousBernoulliContinuousBernoulliKL(self):
  """KL between two continuous Bernoullis matches the closed-form expression."""
  batch_size = 6
  a_p = np.array([0.6] * batch_size, dtype=np.float32)
  b_p = np.array([0.4] * batch_size, dtype=np.float32)
  dist_a = tfd.ContinuousBernoulli(probs=a_p, validate_args=True)
  dist_b = tfd.ContinuousBernoulli(probs=b_p, validate_args=True)
  kl = tfd.kl_divergence(dist_a, dist_b)
  kl_val = self.evaluate(kl)
  # `mean` and `log_norm_const` are helpers defined elsewhere in this file;
  # the expression below is the analytic KL written in terms of them.
  log_odds_ratio = (
      np.log(a_p) + np.log1p(-b_p) - np.log(b_p) - np.log1p(-a_p))
  kl_expected = (
      mean(a_p) * log_odds_ratio
      + log_norm_const(a_p) - log_norm_const(b_p)
      + np.log1p(-a_p) - np.log1p(-b_p))
  self.assertEqual(kl.shape, (batch_size,))
  self.assertAllClose(kl_val, kl_expected)
def test_docstring_example_bernoulli(self):
  """MC estimate of Bernoulli-Bernoulli KL matches the analytic KL and grads."""
  num_draws = int(1e5)
  probs_p = tf.constant(0.4)
  probs_q = tf.constant(0.7)
  with tf.GradientTape(persistent=True) as tape:
    tape.watch(probs_p)
    tape.watch(probs_q)
    p = tfd.Bernoulli(probs=probs_p)
    q = tfd.Bernoulli(probs=probs_q)
    exact_kl_bernoulli_bernoulli = tfp.monte_carlo.expectation(
        f=lambda x: p.log_prob(x) - q.log_prob(x),
        samples=p.sample(num_draws, seed=42),
        log_prob=p.log_prob,
        use_reparametrization=(
            p.reparameterization_type == tfd.FULLY_REPARAMETERIZED))
    approx_kl_bernoulli_bernoulli = tfd.kl_divergence(p, q)
  [
      exact_kl_bernoulli_bernoulli_,
      approx_kl_bernoulli_bernoulli_,
  ] = self.evaluate([
      exact_kl_bernoulli_bernoulli,
      approx_kl_bernoulli_bernoulli,
  ])
  # Bernoulli is not reparameterizable, so the score-function trick is used.
  self.assertEqual(
      False, p.reparameterization_type == tfd.FULLY_REPARAMETERIZED)
  self.assertAllClose(
      exact_kl_bernoulli_bernoulli_,
      approx_kl_bernoulli_bernoulli_,
      rtol=0.01,
      atol=0.)
  # (Removed a leftover debug `print` of the two KL values.)
  # Compare gradients. (Not present in `docstring`.)
  gradp = lambda fp: tape.gradient(fp, probs_p)
  gradq = lambda fq: tape.gradient(fq, probs_q)
  [
      gradp_exact_kl_bernoulli_bernoulli_,
      gradq_exact_kl_bernoulli_bernoulli_,
      gradp_approx_kl_bernoulli_bernoulli_,
      gradq_approx_kl_bernoulli_bernoulli_,
  ] = self.evaluate([
      gradp(exact_kl_bernoulli_bernoulli),
      gradq(exact_kl_bernoulli_bernoulli),
      gradp(approx_kl_bernoulli_bernoulli),
      gradq(approx_kl_bernoulli_bernoulli),
  ])
  # Notice that variance (i.e., `rtol`) is higher when using score-trick.
  self.assertAllClose(
      gradp_exact_kl_bernoulli_bernoulli_,
      gradp_approx_kl_bernoulli_bernoulli_,
      rtol=0.05,
      atol=0.)
  self.assertAllClose(
      gradq_exact_kl_bernoulli_bernoulli_,
      gradq_approx_kl_bernoulli_bernoulli_,
      rtol=0.03,
      atol=0.)
def testKLIdentity(self):
  """Independent with 0 reinterpreted dims leaves the KL unchanged."""
  normal1 = tfd.Normal(
      loc=np.float32([-1., 1]), scale=np.float32([0.1, 0.5]))
  # Wrapping with reinterpreted_batch_ndims=0 is functionally a no-op and
  # doesn't change any outputs.
  ind1 = tfd.Independent(distribution=normal1, reinterpreted_batch_ndims=0)
  normal2 = tfd.Normal(
      loc=np.float32([-3., 3]), scale=np.float32([0.3, 0.3]))
  ind2 = tfd.Independent(distribution=normal2, reinterpreted_batch_ndims=0)
  self.assertAllClose(
      self.evaluate(tfd.kl_divergence(normal1, normal2)),
      self.evaluate(tfd.kl_divergence(ind1, ind2)))
def test_kl_divergence(self):
  """KL of a JointDistributionNamed is the sum of its per-component KLs."""
  joint_a = tfd.JointDistributionNamed(
      dict(e=tfd.Independent(tfd.Exponential(rate=[100, 120]), 1),
           x=tfd.Normal(loc=0, scale=2.)),
      validate_args=True)
  joint_b = tfd.JointDistributionNamed(
      dict(e=tfd.Independent(tfd.Exponential(rate=[10, 12]), 1),
           x=tfd.Normal(loc=1, scale=1.)),
      validate_args=True)
  self.assertEqual(joint_a.model.keys(), joint_b.model.keys())
  # Independence of components means the joint KL decomposes as a sum.
  expected_kl = sum(
      tfd.kl_divergence(joint_a.model[k], joint_b.model[k])
      for k in joint_a.model.keys())
  actual_kl = tfd.kl_divergence(joint_a, joint_b)
  other_actual_kl = joint_a.kl_divergence(joint_b)
  expected_kl_, actual_kl_, other_actual_kl_ = self.evaluate(
      [expected_kl, actual_kl, other_actual_kl])
  self.assertNear(expected_kl_, actual_kl_, err=1e-5)
  self.assertNear(expected_kl_, other_actual_kl_, err=1e-5)
def testKLMultivariateToMultivariate(self):
  """Independent over MVNDiag batches sums the per-component KLs."""
  # (1, 1, 2) batch of MVNDiag.
  mvn1 = tfd.MultivariateNormalDiag(
      loc=np.float32([[[[-1., 1, 3.], [2., 4., 3.]]]]),
      scale_diag=np.float32([[[0.2, 0.1, 5.], [2., 3., 4.]]]))
  ind1 = tfd.Independent(distribution=mvn1, reinterpreted_batch_ndims=2)
  # (1, 1, 2) batch of MVNDiag.
  mvn2 = tfd.MultivariateNormalDiag(
      loc=np.float32([[[[-2., 3, 2.], [1., 3., 2.]]]]),
      scale_diag=np.float32([[[0.1, 0.5, 3.], [1., 2., 1.]]]))
  ind2 = tfd.Independent(distribution=mvn2, reinterpreted_batch_ndims=2)
  component_kl = tfd.kl_divergence(mvn1, mvn2)
  ind_kl = tfd.kl_divergence(ind1, ind2)
  # Reinterpreting the last two batch dims sums the KL over those axes.
  self.assertAllClose(
      self.evaluate(tf.reduce_sum(input_tensor=component_kl, axis=[-1, -2])),
      self.evaluate(ind_kl))
def compute_kl_div(self):
  """Compute KL( q(z|c) || r(z) ), summed over all posteriors."""
  prior = tfd.Normal(tf.zeros(self.latent_dim), tf.ones(self.latent_dim))
  # One KL term per (mean, variance) pair; variances are converted to
  # standard deviations for the Normal constructor.
  kl_terms = [
      tfd.kl_divergence(tfd.Normal(mu, tf.math.sqrt(var)), prior)
      for mu, var in zip(tf.unstack(self.z_means), tf.unstack(self.z_vars))
  ]
  return tf.reduce_sum(tf.stack(kl_terms))
def __init__(
    self,
    units,
    activation=None,
    activity_regularizer=None,
    trainable=True,
    kernel_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(),
    kernel_posterior_tensor_fn=lambda d: d.sample(),
    kernel_prior_fn=tfp_layers_util.default_multivariate_normal_fn,
    kernel_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
    bias_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(
        is_singular=True),
    bias_posterior_tensor_fn=lambda d: d.sample(),
    bias_prior_fn=None,
    bias_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
    seed=None,
    **kwargs):
  # pylint: disable=g-doc-args
  """Construct layer.

  Args:
    ${args}
    seed: Python scalar `int` which initializes the random number
      generator. Default value: `None` (i.e., use global seed).
  """
  # pylint: enable=g-doc-args
  # All shared configuration is delegated to the parent variational layer;
  # this subclass only adds a seed for the Flipout perturbations.
  super(DenseFlipout, self).__init__(
      units=units,
      activation=activation,
      activity_regularizer=activity_regularizer,
      trainable=trainable,
      bias_posterior_fn=bias_posterior_fn,
      bias_posterior_tensor_fn=bias_posterior_tensor_fn,
      bias_prior_fn=bias_prior_fn,
      bias_divergence_fn=bias_divergence_fn,
      kernel_posterior_fn=kernel_posterior_fn,
      kernel_posterior_tensor_fn=kernel_posterior_tensor_fn,
      kernel_prior_fn=kernel_prior_fn,
      kernel_divergence_fn=kernel_divergence_fn,
      **kwargs)
  # Set additional attributes which do not exist in the parent class.
  self.seed = seed
def __init__(
    self,
    units,
    activation=None,
    activity_regularizer=None,
    trainable=True,
    kernel_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(),
    kernel_posterior_tensor_fn=lambda d: d.sample(),
    kernel_prior_fn=tfp_layers_util.default_multivariate_normal_fn,
    kernel_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
    bias_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(
        is_singular=True),
    bias_posterior_tensor_fn=lambda d: d.sample(),
    bias_prior_fn=None,
    bias_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
    seed=None,
    **kwargs):
  # pylint: disable=g-doc-args
  """Construct layer.

  Args:
    ${args}
    seed: Python scalar `int` which initializes the random number
      generator. Default value: `None` (i.e., use global seed).
  """
  # pylint: enable=g-doc-args
  # Everything except `seed` is forwarded to the parent variational layer.
  super(DenseFlipout, self).__init__(
      units=units,
      activation=activation,
      activity_regularizer=activity_regularizer,
      trainable=trainable,
      kernel_posterior_fn=kernel_posterior_fn,
      kernel_posterior_tensor_fn=kernel_posterior_tensor_fn,
      kernel_prior_fn=kernel_prior_fn,
      kernel_divergence_fn=kernel_divergence_fn,
      bias_posterior_fn=bias_posterior_fn,
      bias_posterior_tensor_fn=bias_posterior_tensor_fn,
      bias_prior_fn=bias_prior_fn,
      bias_divergence_fn=bias_divergence_fn,
      **kwargs)
  # Set additional attributes which do not exist in the parent class.
  self.seed = seed
def test_kl_divergence(self):
  """Sample-distribution KL scales the base KL by the number of sample events."""
  q_scale = 2.
  p = tfd.Sample(
      tfd.Independent(tfd.Normal(loc=tf.zeros([3, 2]), scale=1), 1),
      [5, 4],
      validate_args=True)
  q = tfd.Sample(
      tfd.Independent(tfd.Normal(loc=tf.zeros([3, 2]), scale=2.), 1),
      [5, 4],
      validate_args=True)
  actual_kl = tfd.kl_divergence(p, q)
  # Per-element KL of N(0,1) vs N(0,q_scale), times the 5*4 sample events,
  # times the 2 reinterpreted event elements, over a batch of 3.
  per_element_kl = 0.5 * q_scale**-2. - 0.5 + np.log(q_scale)
  expected_kl = (5 * 4) * per_element_kl * np.ones([3]) * 2
  self.assertAllClose(expected_kl, self.evaluate(actual_kl))
def testKLTwoIdenticalDistributionsIsZero(self):
  """KL of a batched MVN with itself is exactly zero."""
  batch_shape = [2]
  event_shape = [3]
  mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
  mvn_a = tfd.MultivariateNormalTriL(
      loc=mu_a,
      scale_tril=np.linalg.cholesky(sigma_a),
      validate_args=True)
  # Should be zero since KL(p || p) = 0.
  kl = tfd.kl_divergence(mvn_a, mvn_a)
  self.assertEqual(batch_shape, kl.shape)
  kl_v = self.evaluate(kl)
  self.assertAllClose(np.zeros(*batch_shape), kl_v)
def testWeibullGammaKLAgreeWeibullWeibull(self):
  """Weibull-Gamma KL agrees with Weibull-Weibull for an exponential target."""
  a_concentration = np.array([3.])
  a_scale = np.array([2.])
  b_concentration = np.array([1.])
  b_rate = np.array([0.25])
  weibull_a = tfd.Weibull(
      concentration=a_concentration, scale=a_scale, validate_args=True)
  gamma_b = tfd.Gamma(
      concentration=b_concentration, rate=b_rate, validate_args=True)
  # With concentration 1, the Gamma coincides with a Weibull of
  # concentration 1 and scale 1/rate, so both KLs should match.
  weibull_b = tfd.Weibull(
      concentration=b_concentration, scale=1 / b_rate, validate_args=True)
  kl_weibull_weibull = tfd.kl_divergence(weibull_a, weibull_b)
  kl_weibull_gamma = tfd.kl_divergence(weibull_a, gamma_b)
  self.assertAllClose(
      self.evaluate(kl_weibull_gamma),
      self.evaluate(kl_weibull_weibull),
      atol=0.0,
      rtol=1e-6)
def __init__(
    self,
    units,
    activation=None,
    activity_regularizer=None,
    trainable=True,
    kernel_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(),
    kernel_posterior_tensor_fn=lambda d: d.sample(),
    kernel_prior_fn=tfp_layers_util.default_multivariate_normal_fn,
    kernel_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
    bias_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(
        is_singular=True),
    bias_posterior_tensor_fn=lambda d: d.sample(),
    bias_prior_fn=None,
    bias_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
    **kwargs):
  # pylint: disable=g-doc-args
  """Construct layer.

  Args:
    ${args}
  """
  # pylint: enable=g-doc-args
  # This subclass adds no state of its own: every argument is forwarded to
  # the parent variational dense layer.
  super(DenseLocalReparameterization, self).__init__(
      units=units,
      activation=activation,
      activity_regularizer=activity_regularizer,
      trainable=trainable,
      bias_posterior_fn=bias_posterior_fn,
      bias_posterior_tensor_fn=bias_posterior_tensor_fn,
      bias_prior_fn=bias_prior_fn,
      bias_divergence_fn=bias_divergence_fn,
      kernel_posterior_fn=kernel_posterior_fn,
      kernel_posterior_tensor_fn=kernel_posterior_tensor_fn,
      kernel_prior_fn=kernel_prior_fn,
      kernel_divergence_fn=kernel_divergence_fn,
      **kwargs)
def __init__(
    self,
    units,
    activation=None,
    activity_regularizer=None,
    trainable=True,
    kernel_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(),
    kernel_posterior_tensor_fn=lambda d: d.sample(),
    kernel_prior_fn=tfp_layers_util.default_multivariate_normal_fn,
    kernel_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
    bias_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(
        is_singular=True),
    bias_posterior_tensor_fn=lambda d: d.sample(),
    bias_prior_fn=None,
    bias_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
    **kwargs):
  # pylint: disable=g-doc-args
  """Construct layer.

  Args:
    ${args}
  """
  # pylint: enable=g-doc-args
  # This subclass keeps no state of its own; all arguments are forwarded
  # to the parent variational dense layer.
  super(DenseLocalReparameterization, self).__init__(
      units=units,
      activation=activation,
      activity_regularizer=activity_regularizer,
      trainable=trainable,
      kernel_posterior_fn=kernel_posterior_fn,
      kernel_posterior_tensor_fn=kernel_posterior_tensor_fn,
      kernel_prior_fn=kernel_prior_fn,
      kernel_divergence_fn=kernel_divergence_fn,
      bias_posterior_fn=bias_posterior_fn,
      bias_posterior_tensor_fn=bias_posterior_tensor_fn,
      bias_prior_fn=bias_prior_fn,
      bias_divergence_fn=bias_divergence_fn,
      **kwargs)
def testBernoulliBernoulliKL(self):
  """Bernoulli-Bernoulli KL matches the two-term textbook formula."""
  batch_size = 6
  a_p = np.array([0.6] * batch_size, dtype=np.float32)
  b_p = np.array([0.4] * batch_size, dtype=np.float32)
  dist_a = tfd.Bernoulli(probs=a_p, validate_args=True)
  dist_b = tfd.Bernoulli(probs=b_p, validate_args=True)
  kl = tfd.kl_divergence(dist_a, dist_b)
  kl_val = self.evaluate(kl)
  # KL[a || b] = Pa log(Pa/Pb) + (1-Pa) log((1-Pa)/(1-Pb)).
  success_term = a_p * np.log(a_p / b_p)
  failure_term = (1. - a_p) * np.log((1. - a_p) / (1. - b_p))
  self.assertEqual(kl.shape, (batch_size,))
  self.assertAllClose(kl_val, success_term + failure_term)
def testBernoulliBernoulliKLWhenProbOneIsZero(self):
  """Degenerate probabilities: KL is finite only under absolute continuity."""
  # KL[a || b] = Pa * Log[Pa / Pb] + (1 - Pa) * Log[(1 - Pa) / (1 - Pb)]
  # This is defined iff (Pb = 0 ==> Pa = 0) AND (Pb = 1 ==> Pa = 1).
  dist_a = tfd.Bernoulli(probs=[0., 0., 0.])
  dist_b = tfd.Bernoulli(probs=[0.5, 1., 0.])
  kl_expected = [
      # The Pa term kills the entire first term.
      0 + 1 * np.log(1 / 0.5),
      # P[b = 0] = 0, but P[a = 0] != 0, so not absolutely continuous.
      # Some would argue that NaN would be more correct...
      np.inf,
      # P[b = 1] = 0, and P[a = 1] = 0, so absolute continuity holds.
      0 + 1 * np.log(1 / 1),
  ]
  self.assertAllClose(
      self.evaluate(tfd.kl_divergence(dist_a, dist_b)), kl_expected)
def testKLNonBatch(self):
  """Scalar-batch MVN KL agrees with the direct non-batch computation."""
  batch_shape = []
  event_shape = [2]
  mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
  mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape)
  mvn_a = tfd.MultivariateNormalTriL(
      loc=mu_a, scale_tril=np.linalg.cholesky(sigma_a), validate_args=True)
  mvn_b = tfd.MultivariateNormalTriL(
      loc=mu_b, scale_tril=np.linalg.cholesky(sigma_b), validate_args=True)
  kl = tfd.kl_divergence(mvn_a, mvn_b)
  self.assertEqual(batch_shape, kl.shape)
  expected_kl = _compute_non_batch_kl(mu_a, sigma_a, mu_b, sigma_b)
  self.assertAllClose(expected_kl, self.evaluate(kl))
def test_docstring_example_normal(self):
  """Exact Normal-Normal KL matches its Monte-Carlo estimate and gradients."""
  num_draws = int(1e5)
  mu_p = tf.constant(0.)
  mu_q = tf.constant(1.)
  with tf.GradientTape(persistent=True) as tape:
    tape.watch(mu_p)
    tape.watch(mu_q)
    p = tfd.Normal(loc=mu_p, scale=1.)
    q = tfd.Normal(loc=mu_q, scale=2.)
    exact_kl_normal_normal = tfd.kl_divergence(p, q)
    approx_kl_normal_normal = tfp.monte_carlo.expectation(
        f=lambda x: p.log_prob(x) - q.log_prob(x),
        samples=p.sample(num_draws, seed=42),
        log_prob=p.log_prob,
        use_reparameterization=(
            p.reparameterization_type == tfd.FULLY_REPARAMETERIZED))
  exact_kl_normal_normal_, approx_kl_normal_normal_ = self.evaluate(
      [exact_kl_normal_normal, approx_kl_normal_normal])
  self.assertEqual(
      True, p.reparameterization_type == tfd.FULLY_REPARAMETERIZED)
  self.assertAllClose(
      exact_kl_normal_normal_, approx_kl_normal_normal_, rtol=0.01, atol=0.)
  # Compare gradients. (Not present in `docstring`.)
  grad_wrt_p = lambda f: tape.gradient(f, mu_p)
  grad_wrt_q = lambda f: tape.gradient(f, mu_q)
  [
      gradp_exact_, gradq_exact_, gradp_approx_, gradq_approx_,
  ] = self.evaluate([
      grad_wrt_p(exact_kl_normal_normal),
      grad_wrt_q(exact_kl_normal_normal),
      grad_wrt_p(approx_kl_normal_normal),
      grad_wrt_q(approx_kl_normal_normal),
  ])
  self.assertAllClose(gradp_exact_, gradp_approx_, rtol=0.01, atol=0.)
  self.assertAllClose(gradq_exact_, gradq_approx_, rtol=0.01, atol=0.)
def __init__(self,
             input_dim,
             output_dim,
             mask_zero=False,
             input_length=None,
             client_weight=1.,
             trainable=True,
             embeddings_initializer=tf.keras.initializers.RandomUniform(
                 -0.01, 0.01),
             embedding_posterior_fn=None,
             embedding_posterior_tensor_fn=(lambda d: d.sample()),
             embedding_prior_fn=None,
             embedding_divergence_fn=(
                 lambda q, p, ignore: tfd.kl_divergence(q, p)),
             **kwargs):
  """Construct a Gaussian embedding layer with natural-parameter posteriors."""
  # The optional initializer is passed via kwargs so the Keras parent
  # constructor never sees it.
  self.untransformed_scale_initializer = kwargs.pop(
      'untransformed_scale_initializer', None)
  # Default to the natural-parameterization factories defined on this class.
  if embedding_posterior_fn is None:
    embedding_posterior_fn = self.renormalize_natural_mean_field_normal_fn
  if embedding_prior_fn is None:
    embedding_prior_fn = self.natural_tensor_multivariate_normal_fn
  super(NaturalGaussianEmbedding, self).__init__(
      input_dim,
      output_dim,
      mask_zero=mask_zero,
      input_length=input_length,
      trainable=trainable,
      embeddings_initializer=embeddings_initializer,
      **kwargs)
  self.client_weight = client_weight
  # Federated deltas are plain element-wise subtract/add.
  self.delta_function = tf.subtract
  self.apply_delta_function = tf.add
  self.embedding_posterior_fn = embedding_posterior_fn
  self.embedding_prior_fn = embedding_prior_fn
  self.embedding_posterior_tensor_fn = embedding_posterior_tensor_fn
  self.embedding_divergence_fn = embedding_divergence_fn
  self.client_variable_dict = {}
  self.client_center_variable_dict = {}
  self.server_variable_dict = {}
def loss(self, features):
  """Return the negative ELBO, with an analytic or Monte-Carlo KL term."""
  encoder = self._make_encoder(self.latent_size, self.activation)
  decoder = self._make_decoder(self.latent_size, self.activation)
  latent_prior = self._make_prior()

  approx_posterior = encoder(features)
  posterior_sample = approx_posterior.sample(self.n_samples)
  # Reconstruction term (negative log-likelihood of the data).
  distortion = -decoder(posterior_sample).log_prob(features)

  if self.analytic_kl:
    # Closed-form KL(q(z|x) || p(z)).
    rate = tfd.kl_divergence(approx_posterior, latent_prior)
  else:
    # Monte-Carlo estimate of the same KL from the posterior samples.
    rate = (approx_posterior.log_prob(posterior_sample)
            - latent_prior.log_prob(posterior_sample))

  elbo = tf.reduce_mean(-(rate + distortion))
  return -elbo