Example #1
    def testLogNormalLogNormalKL(self):
        batch_size = 6
        mu_a = np.array([3.0] * batch_size)
        sigma_a = np.array([1.0, 2.0, 3.0, 1.5, 2.5, 3.5])
        mu_b = np.array([-3.0] * batch_size)
        sigma_b = np.array([0.5, 1.0, 1.5, 2.0, 2.5, 3.0])

        ln_a = tfd.LogNormal(loc=mu_a, scale=sigma_a, validate_args=True)
        ln_b = tfd.LogNormal(loc=mu_b, scale=sigma_b, validate_args=True)

        kl = tfd.kl_divergence(ln_a, ln_b)
        kl_val = self.evaluate(kl)

        normal_a = tfd.Normal(loc=mu_a, scale=sigma_a, validate_args=True)
        normal_b = tfd.Normal(loc=mu_b, scale=sigma_b, validate_args=True)
        kl_expected_from_normal = tfd.kl_divergence(normal_a, normal_b)

        kl_expected_from_formula = (
            (mu_a - mu_b)**2 / (2 * sigma_b**2) + 0.5 *
            ((sigma_a**2 / sigma_b**2) - 1 - 2 * np.log(sigma_a / sigma_b)))

        x = ln_a.sample(int(2e5), seed=test_util.test_seed())
        kl_sample = tf.reduce_mean(ln_a.log_prob(x) - ln_b.log_prob(x), axis=0)
        kl_sample_ = self.evaluate(kl_sample)

        self.assertEqual(kl.shape, (batch_size, ))
        self.assertAllClose(kl_val, kl_expected_from_normal)
        self.assertAllClose(kl_val, kl_expected_from_formula)
        self.assertAllClose(kl_expected_from_formula,
                            kl_sample_,
                            atol=0.0,
                            rtol=1e-2)
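
The reference value above relies on the fact that the KL divergence between two LogNormals equals the KL between the underlying Normals, because exp is a shared invertible transformation. A minimal NumPy sketch of the closed form computed as kl_expected_from_formula (the helper name normal_kl is ours, not from the test):

    import numpy as np

    def normal_kl(mu_a, sigma_a, mu_b, sigma_b):
        # Elementwise KL[N(mu_a, sigma_a) || N(mu_b, sigma_b)].
        return ((mu_a - mu_b)**2 / (2. * sigma_b**2)
                + 0.5 * (sigma_a**2 / sigma_b**2 - 1.
                         - 2. * np.log(sigma_a / sigma_b)))
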
Example #2
    def testKLRaises(self):
        ind1 = tfd.Independent(distribution=tfd.Normal(
            loc=np.float32([-1., 1]), scale=np.float32([0.1, 0.5])),
                               reinterpreted_batch_ndims=1,
                               validate_args=True)
        ind2 = tfd.Independent(distribution=tfd.Normal(loc=np.float32(-1),
                                                       scale=np.float32(0.5)),
                               reinterpreted_batch_ndims=0,
                               validate_args=True)

        with self.assertRaisesRegex(ValueError, 'Event shapes do not match'):
            tfd.kl_divergence(ind1, ind2)

        ind1 = tfd.Independent(distribution=tfd.Normal(
            loc=np.float32([-1., 1]), scale=np.float32([0.1, 0.5])),
                               reinterpreted_batch_ndims=1,
                               validate_args=True)
        ind2 = tfd.Independent(distribution=tfd.MultivariateNormalDiag(
            loc=np.float32([-1., 1]), scale_diag=np.float32([0.1, 0.5])),
                               reinterpreted_batch_ndims=0,
                               validate_args=True)

        with self.assertRaisesRegex(NotImplementedError,
                                    'different event shapes'):
            tfd.kl_divergence(ind1, ind2)
    def testUniformUniformKLFinite(self):
        batch_size = 6

        a_low = -1.0 * np.arange(1, batch_size + 1)
        a_high = np.array([1.0] * batch_size)
        b_low = -2.0 * np.arange(1, batch_size + 1)
        b_high = np.array([2.0] * batch_size)
        a = tfd.Uniform(low=a_low, high=a_high, validate_args=True)
        b = tfd.Uniform(low=b_low, high=b_high, validate_args=True)

        true_kl = np.log(b_high - b_low) - np.log(a_high - a_low)

        kl = tfd.kl_divergence(a, b)

        # This is essentially an approximated integral from the direct definition
        # of KL divergence.
        x = a.sample(int(1e4), seed=test_util.test_seed())
        kl_sample = tf.reduce_mean(a.log_prob(x) - b.log_prob(x), axis=0)

        kl_, kl_sample_ = self.evaluate([kl, kl_sample])
        self.assertAllClose(true_kl, kl_, atol=2e-15)
        self.assertAllClose(true_kl, kl_sample_, atol=0.0, rtol=1e-1)

        zero_kl = tfd.kl_divergence(a, a)
        true_zero_kl_, zero_kl_ = self.evaluate(
            [tf.zeros_like(true_kl), zero_kl])
        self.assertAllEqual(true_zero_kl_, zero_kl_)
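
The closed form used for true_kl above only holds because the support of a is contained in the support of b; otherwise the divergence is infinite. A hedged NumPy sketch (the helper name uniform_kl is ours):

    import numpy as np

    def uniform_kl(a_low, a_high, b_low, b_high):
        # KL[Uniform(a_low, a_high) || Uniform(b_low, b_high)]: the log ratio of
        # interval lengths when a's support lies inside b's, infinite otherwise.
        inside = (b_low <= a_low) & (a_high <= b_high)
        return np.where(inside,
                        np.log(b_high - b_low) - np.log(a_high - a_low),
                        np.inf)
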
  def __init__(
      self,
      units,
      activation=None,
      activity_regularizer=None,
      kernel_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(),
      kernel_posterior_tensor_fn=lambda d: d.sample(),
      kernel_prior_fn=tfp_layers_util.default_multivariate_normal_fn,
      kernel_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
      bias_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(is_singular=True),  # pylint: disable=line-too-long
      bias_posterior_tensor_fn=lambda d: d.sample(),
      bias_prior_fn=None,
      bias_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
      **kwargs):
    # pylint: disable=g-doc-args
    """Construct layer.

    Args:
      ${args}
    """
    # pylint: enable=g-doc-args
    super(_DenseVariational, self).__init__(
        activity_regularizer=activity_regularizer,
        **kwargs)
    self.units = units
    self.activation = tf.keras.activations.get(activation)
    self.input_spec = tf.keras.layers.InputSpec(min_ndim=2)
    self.kernel_posterior_fn = kernel_posterior_fn
    self.kernel_posterior_tensor_fn = kernel_posterior_tensor_fn
    self.kernel_prior_fn = kernel_prior_fn
    self.kernel_divergence_fn = kernel_divergence_fn
    self.bias_posterior_fn = bias_posterior_fn
    self.bias_posterior_tensor_fn = bias_posterior_tensor_fn
    self.bias_prior_fn = bias_prior_fn
    self.bias_divergence_fn = bias_divergence_fn
  def __init__(
      self,
      units,
      activation=None,
      activity_regularizer=None,
      kernel_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(),
      kernel_posterior_tensor_fn=lambda d: d.sample(),
      kernel_prior_fn=tfp_layers_util.default_multivariate_normal_fn,
      kernel_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
      bias_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(is_singular=True),  # pylint: disable=line-too-long
      bias_posterior_tensor_fn=lambda d: d.sample(),
      bias_prior_fn=None,
      bias_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
      **kwargs):
    # pylint: disable=g-doc-args
    """Construct layer.

    Args:
      ${args}
    """
    # pylint: enable=g-doc-args
    super(_DenseVariational, self).__init__(
        activity_regularizer=activity_regularizer,
        **kwargs)
    self.units = units
    self.activation = tf.keras.activations.get(activation)
    self.input_spec = tf.keras.layers.InputSpec(min_ndim=2)
    self.kernel_posterior_fn = kernel_posterior_fn
    self.kernel_posterior_tensor_fn = kernel_posterior_tensor_fn
    self.kernel_prior_fn = kernel_prior_fn
    self.kernel_divergence_fn = kernel_divergence_fn
    self.bias_posterior_fn = bias_posterior_fn
    self.bias_posterior_tensor_fn = bias_posterior_tensor_fn
    self.bias_prior_fn = bias_prior_fn
    self.bias_divergence_fn = bias_divergence_fn
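
This constructor appears to be the shared base that public layers such as tfp.layers.DenseReparameterization, tfp.layers.DenseLocalReparameterization, and tfp.layers.DenseFlipout forward these arguments to (their __init__ methods further down call it via super). A common pattern is to pass a custom kernel_divergence_fn that rescales the KL penalty, e.g. by the dataset size, so that summing it over minibatches approximates the full-data ELBO. A sketch, with num_train_examples as an illustrative value:

    import tensorflow_probability as tfp
    tfd = tfp.distributions

    num_train_examples = 60000  # illustrative dataset size, not from the source
    scaled_kl_fn = (lambda q, p, ignore:
                    tfd.kl_divergence(q, p) / num_train_examples)

    layer = tfp.layers.DenseReparameterization(
        units=10, kernel_divergence_fn=scaled_kl_fn)
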
    def __init__(
            self, units,
            activation=None,
            activity_regularizer=None,
            client_weight=1.,
            trainable=True,
            kernel_posterior_fn=None,
            kernel_posterior_tensor_fn=(lambda d: d.sample()),
            kernel_prior_fn=None,
            kernel_divergence_fn=(
                    lambda q, p, ignore: tfd.kl_divergence(q, p)),
            bias_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(
                is_singular=True),
            bias_posterior_tensor_fn=(lambda d: d.sample()),
            bias_prior_fn=None,
            bias_divergence_fn=(lambda q, p, ignore: tfd.kl_divergence(q, p)),
            **kwargs):

        self.untransformed_scale_initializer = None
        if 'untransformed_scale_initializer' in kwargs:
            self.untransformed_scale_initializer = \
                kwargs.pop('untransformed_scale_initializer')
        self.loc_initializer = None
        if 'loc_initializer' in kwargs:
            self.loc_initializer = \
                kwargs.pop('loc_initializer')

        self.delta_percentile = kwargs.pop('delta_percentile', None)

        if kernel_posterior_fn is None:
            kernel_posterior_fn = self.renormalize_natural_mean_field_normal_fn
        if kernel_prior_fn is None:
            kernel_prior_fn = self.natural_tensor_multivariate_normal_fn

        super(DenseSharedNatural, self).\
            __init__(units,
                     activation=activation,
                     activity_regularizer=activity_regularizer,
                     trainable=trainable,
                     kernel_posterior_fn=kernel_posterior_fn,
                     kernel_posterior_tensor_fn=kernel_posterior_tensor_fn,
                     kernel_prior_fn=kernel_prior_fn,
                     kernel_divergence_fn=kernel_divergence_fn,
                     bias_posterior_fn=bias_posterior_fn,
                     bias_posterior_tensor_fn=bias_posterior_tensor_fn,
                     bias_prior_fn=bias_prior_fn,
                     bias_divergence_fn=bias_divergence_fn,
                     **kwargs)

        self.client_weight = client_weight
        self.delta_function = tf.subtract
        if self.delta_percentile and not activation == 'softmax':
            self.delta_function = sparse_delta_function(self.delta_percentile)
            print(self, activation, 'using delta sparsification')
        self.apply_delta_function = tf.add
        self.client_variable_dict = {}
        self.client_center_variable_dict = {}
        self.server_variable_dict = {}
    def __init__(
            self,
            filters,
            kernel_size,
            strides=1,
            padding='valid',
            client_weight=1.,
            data_format='channels_last',
            dilation_rate=1,
            activation=None,
            activity_regularizer=None,
            kernel_posterior_fn=None,
            kernel_posterior_tensor_fn=(lambda d: d.sample()),
            kernel_prior_fn=None,
            kernel_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
            bias_posterior_fn=
            tfp_layers_util.default_mean_field_normal_fn(is_singular=True),
            bias_posterior_tensor_fn=lambda d: d.sample(),
            bias_prior_fn=None,
            bias_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
            **kwargs):

        self.untransformed_scale_initializer = None
        if 'untransformed_scale_initializer' in kwargs:
            self.untransformed_scale_initializer = \
                kwargs.pop('untransformed_scale_initializer')

        if kernel_posterior_fn is None:
            kernel_posterior_fn = self.renormalize_natural_mean_field_normal_fn
        if kernel_prior_fn is None:
            kernel_prior_fn = self.natural_tensor_multivariate_normal_fn

        super(Conv1DVirtualNatural, self).__init__(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            data_format=data_format,
            dilation_rate=dilation_rate,
            activation=tf.keras.activations.get(activation),
            activity_regularizer=activity_regularizer,
            kernel_posterior_fn=kernel_posterior_fn,
            kernel_posterior_tensor_fn=kernel_posterior_tensor_fn,
            kernel_prior_fn=kernel_prior_fn,
            kernel_divergence_fn=kernel_divergence_fn,
            bias_posterior_fn=bias_posterior_fn,
            bias_posterior_tensor_fn=bias_posterior_tensor_fn,
            bias_prior_fn=bias_prior_fn,
            bias_divergence_fn=bias_divergence_fn,
            **kwargs)

        self.client_weight = client_weight
        self.delta_function = tf.subtract
        self.apply_delta_function = tf.add
        self.client_variable_dict = {}
        self.client_center_variable_dict = {}
        self.server_variable_dict = {}
 def testTransformedKLDifferentBijectorFails(self):
     d1 = self._cls()(tfd.Exponential(rate=0.25),
                      bijector=tfb.Scale(scale=2.),
                      validate_args=True)
     d2 = self._cls()(tfd.Exponential(rate=0.25),
                      bijector=tfb.Scale(scale=3.),
                      validate_args=True)
     with self.assertRaisesRegex(NotImplementedError,
                                 r'their bijectors are not equal'):
         tfd.kl_divergence(d1, d2)
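
When the bijectors compare equal, TFP's registered KL for pairs of TransformedDistributions falls back to the KL between the base distributions, since KL is invariant under a shared invertible transformation. A minimal sketch of that case (values are illustrative):

    import tensorflow_probability as tfp
    tfd, tfb = tfp.distributions, tfp.bijectors

    bij = tfb.Scale(scale=2.)
    base1 = tfd.Normal(loc=0., scale=1.)
    base2 = tfd.Normal(loc=1., scale=2.)
    d1 = tfd.TransformedDistribution(base1, bijector=bij)
    d2 = tfd.TransformedDistribution(base2, bijector=bij)
    # Same bijector on both sides, so this equals tfd.kl_divergence(base1, base2).
    kl = tfd.kl_divergence(d1, d2)
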
Example #9
 def test_kl(self):
   a = tfd.MultivariateNormalDiag([tf.range(3.)] * 4, tf.ones(3))
   b = tfd.MultivariateNormalDiag([tf.range(3.) + .5] * 4, tf.ones(3))
   kl = tfd.kl_divergence(tfd.Masked(a, tf.sequence_mask(3, 4)),
                          tfd.Masked(b, tf.sequence_mask(2, 4)))
   kl2 = tfd.kl_divergence(tfd.Masked(a, tf.sequence_mask(2, 4)),
                           tfd.Masked(b, tf.sequence_mask(3, 4)))
   self.assertAllClose(a.kl_divergence(b)[:2], kl[:2])
   self.assertAllEqual(float('nan'), kl[2])
   self.assertAllEqual(0., kl[3])
   self.assertAllEqual(float('nan'), kl2[2])
Example #10
    def testKLBatchBroadcast(self):
        batch_shape = [2]
        event_shape = [3]
        mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
        # No batch shape.
        mu_b, sigma_b = self._random_mu_and_sigma([], event_shape)
        mvn_a = tfd.MultivariateNormalTriL(
            loc=mu_a,
            scale_tril=np.linalg.cholesky(sigma_a),
            validate_args=True)
        mvn_b = tfd.MultivariateNormalTriL(
            loc=mu_b,
            scale_tril=np.linalg.cholesky(sigma_b),
            validate_args=True)

        kl = tfd.kl_divergence(mvn_a, mvn_b)
        self.assertEqual(batch_shape, kl.shape)

        kl_v = self.evaluate(kl)
        expected_kl_0 = _compute_non_batch_kl(mu_a[0, :], sigma_a[0, :, :],
                                              mu_b, sigma_b)
        expected_kl_1 = _compute_non_batch_kl(mu_a[1, :], sigma_a[1, :, :],
                                              mu_b, sigma_b)
        self.assertAllClose(expected_kl_0, kl_v[0])
        self.assertAllClose(expected_kl_1, kl_v[1])
    def testWeibullWeibullKL(self):
        a_concentration = np.array([3.])
        a_scale = np.array([2.])
        b_concentration = np.array([6.])
        b_scale = np.array([4.])

        a = tfd.Weibull(concentration=a_concentration,
                        scale=a_scale,
                        validate_args=True)
        b = tfd.Weibull(concentration=b_concentration,
                        scale=b_scale,
                        validate_args=True)

        kl = tfd.kl_divergence(a, b)
        expected_kl = (
            np.log(a_concentration / a_scale**a_concentration) -
            np.log(b_concentration / b_scale**b_concentration) +
            ((a_concentration - b_concentration) *
             (np.log(a_scale) - np.euler_gamma / a_concentration)) +
            ((a_scale / b_scale)**b_concentration *
             np.exp(scipy.special.gammaln(b_concentration / a_concentration + 1.))) -
            1.)

        x = a.sample(int(1e5), seed=test_util.test_seed())
        kl_sample = tf.reduce_mean(a.log_prob(x) - b.log_prob(x), axis=0)
        kl_sample_val = self.evaluate(kl_sample)

        self.assertAllClose(expected_kl, kl_sample_val, atol=0.0, rtol=1e-2)
        self.assertAllClose(expected_kl, self.evaluate(kl))
  def testKLBatchBroadcast(self):
    batch_shape = [2]
    event_shape = [3]
    loc_a, scale_a = self._random_loc_and_scale(batch_shape, event_shape)
    # No batch shape.
    loc_b, scale_b = self._random_loc_and_scale([], event_shape)
    mvn_a = tfd.MultivariateNormalLinearOperator(
        loc=loc_a, scale=scale_a, validate_args=True)
    mvn_b = tfd.MultivariateNormalLinearOperator(
        loc=loc_b, scale=scale_b, validate_args=True)

    kl = tfd.kl_divergence(mvn_a, mvn_b)
    self.assertEqual(batch_shape, kl.shape)

    kl_v = self.evaluate(kl)
    expected_kl_0 = self._compute_non_batch_kl(
        loc_a[0, :],
        self.evaluate(scale_a.to_dense())[0, :, :], loc_b,
        self.evaluate(scale_b.to_dense()))
    expected_kl_1 = self._compute_non_batch_kl(
        loc_a[1, :],
        self.evaluate(scale_a.to_dense())[1, :, :], loc_b,
        self.evaluate(scale_b.to_dense()))
    self.assertAllClose(expected_kl_0, kl_v[0])
    self.assertAllClose(expected_kl_1, kl_v[1])
Example #13
    def testNormalNormalKL(self):
        batch_size = 6
        mu_a = np.array([3.0] * batch_size)
        sigma_a = np.array([1.0, 2.0, 3.0, 1.5, 2.5, 3.5])
        mu_b = np.array([-3.0] * batch_size)
        sigma_b = np.array([0.5, 1.0, 1.5, 2.0, 2.5, 3.0])

        n_a = tfd.Normal(loc=mu_a, scale=sigma_a, validate_args=True)
        n_b = tfd.Normal(loc=mu_b, scale=sigma_b, validate_args=True)

        kl = tfd.kl_divergence(n_a, n_b)
        kl_val = self.evaluate(kl)

        kl_expected = ((mu_a - mu_b)**2 / (2 * sigma_b**2) + 0.5 * (
            (sigma_a**2 / sigma_b**2) - 1 - 2 * np.log(sigma_a / sigma_b)))

        x = n_a.sample(int(1e5), seed=test_util.test_seed())
        kl_samples = n_a.log_prob(x) - n_b.log_prob(x)
        kl_samples_ = self.evaluate(kl_samples)

        self.assertEqual(kl.shape, (batch_size, ))
        self.assertAllClose(kl_val, kl_expected)
        self.assertAllMeansClose(kl_samples_,
                                 kl_expected,
                                 axis=0,
                                 atol=0.0,
                                 rtol=1e-2)
Example #14
  def testKLScalarToMultivariate(self):
    normal1 = tfd.Normal(
        loc=np.float32([-1., 1]), scale=np.float32([0.1, 0.5]))
    ind1 = tfd.Independent(distribution=normal1, reinterpreted_batch_ndims=1,
                           validate_args=True)

    normal2 = tfd.Normal(
        loc=np.float32([-3., 3]), scale=np.float32([0.3, 0.3]))
    ind2 = tfd.Independent(distribution=normal2, reinterpreted_batch_ndims=1,
                           validate_args=True)

    normal_kl = tfd.kl_divergence(normal1, normal2)
    ind_kl = tfd.kl_divergence(ind1, ind2)
    self.assertAllClose(
        self.evaluate(tf.reduce_sum(input_tensor=normal_kl, axis=-1)),
        self.evaluate(ind_kl))
Example #15
  def testContinuousBernoulliContinuousBernoulliKL(self):
    batch_size = 6
    a_p = np.array([0.6] * batch_size, dtype=np.float32)
    b_p = np.array([0.4] * batch_size, dtype=np.float32)

    a = tfd.ContinuousBernoulli(probs=a_p, validate_args=True)
    b = tfd.ContinuousBernoulli(probs=b_p, validate_args=True)

    kl = tfd.kl_divergence(a, b)
    kl_val = self.evaluate(kl)

    kl_expected = (
        mean(a_p)
        * (
            np.log(a_p)
            + np.log1p(-b_p)
            - np.log(b_p)
            - np.log1p(-a_p)
        )
        + log_norm_const(a_p)
        - log_norm_const(b_p)
        + np.log1p(-a_p)
        - np.log1p(-b_p)
    )

    self.assertEqual(kl.shape, (batch_size,))
    self.assertAllClose(kl_val, kl_expected)
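
The helpers mean and log_norm_const used above come from elsewhere in the test module; for the continuous Bernoulli with density C(p) * p**x * (1 - p)**(1 - x) on [0, 1], they can be sketched as follows (valid for probs away from 0.5):

    import numpy as np

    def log_norm_const(p):
        # log C(p), the log normalizing constant of the continuous Bernoulli.
        return np.log(2. * np.arctanh(1. - 2. * p) / (1. - 2. * p))

    def mean(p):
        # Mean of the continuous Bernoulli distribution.
        return p / (2. * p - 1.) + 1. / (2. * np.arctanh(1. - 2. * p))
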
Example #16
  def test_docstring_example_bernoulli(self):
    num_draws = int(1e5)
    probs_p = tf.constant(0.4)
    probs_q = tf.constant(0.7)
    with tf.GradientTape(persistent=True) as tape:
      tape.watch(probs_p)
      tape.watch(probs_q)
      p = tfd.Bernoulli(probs=probs_p)
      q = tfd.Bernoulli(probs=probs_q)
      exact_kl_bernoulli_bernoulli = tfd.kl_divergence(p, q)
      approx_kl_bernoulli_bernoulli = tfp.monte_carlo.expectation(
          f=lambda x: p.log_prob(x) - q.log_prob(x),
          samples=p.sample(num_draws, seed=42),
          log_prob=p.log_prob,
          use_reparameterization=(
              p.reparameterization_type == tfd.FULLY_REPARAMETERIZED))
    [
        exact_kl_bernoulli_bernoulli_,
        approx_kl_bernoulli_bernoulli_,
    ] = self.evaluate([
        exact_kl_bernoulli_bernoulli,
        approx_kl_bernoulli_bernoulli,
    ])
    self.assertEqual(False,
                     p.reparameterization_type == tfd.FULLY_REPARAMETERIZED)
    self.assertAllClose(
        exact_kl_bernoulli_bernoulli_,
        approx_kl_bernoulli_bernoulli_,
        rtol=0.01,
        atol=0.)
    print(exact_kl_bernoulli_bernoulli_, approx_kl_bernoulli_bernoulli_)

    # Compare gradients. (Not present in `docstring`.)
    gradp = lambda fp: tape.gradient(fp, probs_p)
    gradq = lambda fq: tape.gradient(fq, probs_q)
    [
        gradp_exact_kl_bernoulli_bernoulli_,
        gradq_exact_kl_bernoulli_bernoulli_,
        gradp_approx_kl_bernoulli_bernoulli_,
        gradq_approx_kl_bernoulli_bernoulli_,
    ] = self.evaluate([
        gradp(exact_kl_bernoulli_bernoulli),
        gradq(exact_kl_bernoulli_bernoulli),
        gradp(approx_kl_bernoulli_bernoulli),
        gradq(approx_kl_bernoulli_bernoulli),
    ])
    # Notice that variance (i.e., `rtol`) is higher when using score-trick.
    self.assertAllClose(
        gradp_exact_kl_bernoulli_bernoulli_,
        gradp_approx_kl_bernoulli_bernoulli_,
        rtol=0.05,
        atol=0.)
    self.assertAllClose(
        gradq_exact_kl_bernoulli_bernoulli_,
        gradq_approx_kl_bernoulli_bernoulli_,
        rtol=0.03,
        atol=0.)
    def testKLIdentity(self):
        normal1 = tfd.Normal(loc=np.float32([-1., 1]),
                             scale=np.float32([0.1, 0.5]))
        # This is functionally just a wrapper around normal1,
        # and doesn't change any outputs.
        ind1 = tfd.Independent(distribution=normal1,
                               reinterpreted_batch_ndims=0)

        normal2 = tfd.Normal(loc=np.float32([-3., 3]),
                             scale=np.float32([0.3, 0.3]))
        # This is functionally just a wrapper around normal2,
        # and doesn't change any outputs.
        ind2 = tfd.Independent(distribution=normal2,
                               reinterpreted_batch_ndims=0)

        normal_kl = tfd.kl_divergence(normal1, normal2)
        ind_kl = tfd.kl_divergence(ind1, ind2)
        self.assertAllClose(self.evaluate(normal_kl), self.evaluate(ind_kl))
Example #18
 def test_kl_divergence(self):
   d0 = tfd.JointDistributionNamed(
       dict(e=tfd.Independent(tfd.Exponential(rate=[100, 120]), 1),
            x=tfd.Normal(loc=0, scale=2.)),
       validate_args=True)
   d1 = tfd.JointDistributionNamed(
       dict(e=tfd.Independent(tfd.Exponential(rate=[10, 12]), 1),
            x=tfd.Normal(loc=1, scale=1.)),
       validate_args=True)
   self.assertEqual(d0.model.keys(), d1.model.keys())
   expected_kl = sum(tfd.kl_divergence(d0.model[k], d1.model[k])
                     for k in d0.model.keys())
   actual_kl = tfd.kl_divergence(d0, d1)
   other_actual_kl = d0.kl_divergence(d1)
   expected_kl_, actual_kl_, other_actual_kl_ = self.evaluate([
       expected_kl, actual_kl, other_actual_kl])
   self.assertNear(expected_kl_, actual_kl_, err=1e-5)
   self.assertNear(expected_kl_, other_actual_kl_, err=1e-5)
    def testKLMultivariateToMultivariate(self):
        # (1, 1, 2) batch of MVNDiag
        mvn1 = tfd.MultivariateNormalDiag(
            loc=np.float32([[[[-1., 1, 3.], [2., 4., 3.]]]]),
            scale_diag=np.float32([[[0.2, 0.1, 5.], [2., 3., 4.]]]))
        ind1 = tfd.Independent(distribution=mvn1, reinterpreted_batch_ndims=2)

        # (1, 1, 2) batch of MVNDiag
        mvn2 = tfd.MultivariateNormalDiag(
            loc=np.float32([[[[-2., 3, 2.], [1., 3., 2.]]]]),
            scale_diag=np.float32([[[0.1, 0.5, 3.], [1., 2., 1.]]]))

        ind2 = tfd.Independent(distribution=mvn2, reinterpreted_batch_ndims=2)

        mvn_kl = tfd.kl_divergence(mvn1, mvn2)
        ind_kl = tfd.kl_divergence(ind1, ind2)
        self.assertAllClose(
            self.evaluate(tf.reduce_sum(input_tensor=mvn_kl, axis=[-1, -2])),
            self.evaluate(ind_kl))
Example #20
 def compute_kl_div(self):
     ''' compute KL( q(z|c) || r(z) ) '''
     prior = tfd.Normal(tf.zeros(self.latent_dim), tf.ones(self.latent_dim))
     posteriors = [
         tfd.Normal(mu, tf.math.sqrt(var)) for mu, var in zip(
             tf.unstack(self.z_means), tf.unstack(self.z_vars))
     ]
     kl_divs = [tfd.kl_divergence(post, prior) for post in posteriors]
     kl_div_sum = tf.reduce_sum(tf.stack(kl_divs))
     return kl_div_sum
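
Because the prior here is a standard Normal, each posterior term has the closed form KL[N(mu, sqrt(var)) || N(0, 1)] = 0.5 * (var + mu**2 - 1 - log(var)) per dimension, summed over the latent dimensions. A NumPy sketch of the same quantity (the helper name is ours):

    import numpy as np

    def kl_to_standard_normal(mu, var):
        # Sum over latent dimensions of KL[N(mu, sqrt(var)) || N(0, 1)].
        return np.sum(0.5 * (var + mu**2 - 1. - np.log(var)), axis=-1)
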
  def __init__(
      self,
      units,
      activation=None,
      activity_regularizer=None,
      trainable=True,
      kernel_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(),
      kernel_posterior_tensor_fn=lambda d: d.sample(),
      kernel_prior_fn=tfp_layers_util.default_multivariate_normal_fn,
      kernel_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
      bias_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(
          is_singular=True),
      bias_posterior_tensor_fn=lambda d: d.sample(),
      bias_prior_fn=None,
      bias_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
      seed=None,
      **kwargs):
    # pylint: disable=g-doc-args
    """Construct layer.

    Args:
      ${args}
      seed: Python scalar `int` which initializes the random number
        generator. Default value: `None` (i.e., use global seed).
    """
    # pylint: enable=g-doc-args
    super(DenseFlipout, self).__init__(
        units=units,
        activation=activation,
        activity_regularizer=activity_regularizer,
        trainable=trainable,
        kernel_posterior_fn=kernel_posterior_fn,
        kernel_posterior_tensor_fn=kernel_posterior_tensor_fn,
        kernel_prior_fn=kernel_prior_fn,
        kernel_divergence_fn=kernel_divergence_fn,
        bias_posterior_fn=bias_posterior_fn,
        bias_posterior_tensor_fn=bias_posterior_tensor_fn,
        bias_prior_fn=bias_prior_fn,
        bias_divergence_fn=bias_divergence_fn,
        **kwargs)
    # Set additional attributes which do not exist in the parent class.
    self.seed = seed
    def __init__(
            self,
            units,
            activation=None,
            activity_regularizer=None,
            trainable=True,
            kernel_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(),
            kernel_posterior_tensor_fn=lambda d: d.sample(),
            kernel_prior_fn=tfp_layers_util.default_multivariate_normal_fn,
            kernel_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
            bias_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(
                is_singular=True),
            bias_posterior_tensor_fn=lambda d: d.sample(),
            bias_prior_fn=None,
            bias_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
            seed=None,
            **kwargs):
        # pylint: disable=g-doc-args
        """Construct layer.

    Args:
      ${args}
      seed: Python scalar `int` which initializes the random number
        generator. Default value: `None` (i.e., use global seed).
    """
        # pylint: enable=g-doc-args
        super(DenseFlipout, self).__init__(
            units=units,
            activation=activation,
            activity_regularizer=activity_regularizer,
            trainable=trainable,
            kernel_posterior_fn=kernel_posterior_fn,
            kernel_posterior_tensor_fn=kernel_posterior_tensor_fn,
            kernel_prior_fn=kernel_prior_fn,
            kernel_divergence_fn=kernel_divergence_fn,
            bias_posterior_fn=bias_posterior_fn,
            bias_posterior_tensor_fn=bias_posterior_tensor_fn,
            bias_prior_fn=bias_prior_fn,
            bias_divergence_fn=bias_divergence_fn,
            **kwargs)
        # Set additional attributes which do not exist in the parent class.
        self.seed = seed
Example #23
 def test_kl_divergence(self):
   q_scale = 2.
   p = tfd.Sample(
       tfd.Independent(tfd.Normal(loc=tf.zeros([3, 2]), scale=1), 1), [5, 4],
       validate_args=True)
   q = tfd.Sample(
       tfd.Independent(tfd.Normal(loc=tf.zeros([3, 2]), scale=2.), 1), [5, 4],
       validate_args=True)
   actual_kl = tfd.kl_divergence(p, q)
   expected_kl = ((5 * 4) *
                  (0.5 * q_scale**-2. - 0.5 + np.log(q_scale)) *  # Actual KL.
                  np.ones([3]) * 2)  # Batch, events.
   self.assertAllClose(expected_kl, self.evaluate(actual_kl))
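
The expected value is just the scalar KL[N(0, 1) || N(0, q_scale)] replicated over the 5 * 4 sample shape and the 2 reinterpreted event dimensions, per batch member. A quick scalar check (a sketch, matching the comment in the code above):

    import numpy as np

    q_scale = 2.
    scalar_kl = 0.5 * q_scale**-2. - 0.5 + np.log(q_scale)  # KL[N(0, 1) || N(0, 2)]
    expected_per_batch_member = (5 * 4) * 2 * scalar_kl
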
Example #24
  def testKLTwoIdenticalDistributionsIsZero(self):
    batch_shape = [2]
    event_shape = [3]
    mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
    mvn_a = tfd.MultivariateNormalTriL(
        loc=mu_a, scale_tril=np.linalg.cholesky(sigma_a), validate_args=True)

    # Should be zero since KL(p || p) = 0.
    kl = tfd.kl_divergence(mvn_a, mvn_a)
    self.assertEqual(batch_shape, kl.shape)

    kl_v = self.evaluate(kl)
    self.assertAllClose(np.zeros(*batch_shape), kl_v)
Example #25
    def testWeibullGammaKLAgreeWeibullWeibull(self):
        a_concentration = np.array([3.])
        a_scale = np.array([2.])
        b_concentration = np.array([1.])
        b_rate = np.array([0.25])

        a = tfd.Weibull(concentration=a_concentration,
                        scale=a_scale,
                        validate_args=True)
        b = tfd.Gamma(concentration=b_concentration,
                      rate=b_rate,
                      validate_args=True)
        c = tfd.Weibull(concentration=b_concentration,
                        scale=1 / b_rate,
                        validate_args=True)

        kl_weibull_weibull = tfd.kl_divergence(a, c)
        kl_weibull_gamma = tfd.kl_divergence(a, b)

        self.assertAllClose(self.evaluate(kl_weibull_gamma),
                            self.evaluate(kl_weibull_weibull),
                            atol=0.0,
                            rtol=1e-6)
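
The agreement holds because a Gamma with concentration 1 and rate r is an Exponential with rate r, which is the same distribution as a Weibull with concentration 1 and scale 1 / r; so b and c above describe one distribution under two parameterizations.
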
    def __init__(
            self,
            units,
            activation=None,
            activity_regularizer=None,
            trainable=True,
            kernel_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(),
            kernel_posterior_tensor_fn=lambda d: d.sample(),
            kernel_prior_fn=tfp_layers_util.default_multivariate_normal_fn,
            kernel_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
            bias_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(
                is_singular=True),
            bias_posterior_tensor_fn=lambda d: d.sample(),
            bias_prior_fn=None,
            bias_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
            **kwargs):
        # pylint: disable=g-doc-args
        """Construct layer.

    Args:
      ${args}
    """
        # pylint: enable=g-doc-args
        super(DenseLocalReparameterization, self).__init__(
            units=units,
            activation=activation,
            activity_regularizer=activity_regularizer,
            trainable=trainable,
            kernel_posterior_fn=kernel_posterior_fn,
            kernel_posterior_tensor_fn=kernel_posterior_tensor_fn,
            kernel_prior_fn=kernel_prior_fn,
            kernel_divergence_fn=kernel_divergence_fn,
            bias_posterior_fn=bias_posterior_fn,
            bias_posterior_tensor_fn=bias_posterior_tensor_fn,
            bias_prior_fn=bias_prior_fn,
            bias_divergence_fn=bias_divergence_fn,
            **kwargs)
  def __init__(
      self,
      units,
      activation=None,
      activity_regularizer=None,
      trainable=True,
      kernel_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(),
      kernel_posterior_tensor_fn=lambda d: d.sample(),
      kernel_prior_fn=tfp_layers_util.default_multivariate_normal_fn,
      kernel_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
      bias_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(
          is_singular=True),
      bias_posterior_tensor_fn=lambda d: d.sample(),
      bias_prior_fn=None,
      bias_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
      **kwargs):
    # pylint: disable=g-doc-args
    """Construct layer.

    Args:
      ${args}
    """
    # pylint: enable=g-doc-args
    super(DenseLocalReparameterization, self).__init__(
        units=units,
        activation=activation,
        activity_regularizer=activity_regularizer,
        trainable=trainable,
        kernel_posterior_fn=kernel_posterior_fn,
        kernel_posterior_tensor_fn=kernel_posterior_tensor_fn,
        kernel_prior_fn=kernel_prior_fn,
        kernel_divergence_fn=kernel_divergence_fn,
        bias_posterior_fn=bias_posterior_fn,
        bias_posterior_tensor_fn=bias_posterior_tensor_fn,
        bias_prior_fn=bias_prior_fn,
        bias_divergence_fn=bias_divergence_fn,
        **kwargs)
Example #28
  def testBernoulliBernoulliKL(self):
    batch_size = 6
    a_p = np.array([0.6] * batch_size, dtype=np.float32)
    b_p = np.array([0.4] * batch_size, dtype=np.float32)

    a = tfd.Bernoulli(probs=a_p, validate_args=True)
    b = tfd.Bernoulli(probs=b_p, validate_args=True)

    kl = tfd.kl_divergence(a, b)
    kl_val = self.evaluate(kl)

    kl_expected = (a_p * np.log(a_p / b_p) + (1. - a_p) * np.log(
        (1. - a_p) / (1. - b_p)))

    self.assertEqual(kl.shape, (batch_size,))
    self.assertAllClose(kl_val, kl_expected)
Example #29
 def testBernoulliBernoulliKLWhenProbOneIsZero(self):
     # KL[a || b] = Pa * Log[Pa / Pb] + (1 - Pa) * Log[(1 - Pa) / (1 - Pb)]
     # This is defined iff (Pb = 0 ==> Pa = 0) AND (Pb = 1 ==> Pa = 1).
     a = tfd.Bernoulli(probs=[0., 0., 0.])
     b = tfd.Bernoulli(probs=[0.5, 1., 0.])
     kl_expected = [
         # The Pa term kills the entire first term.
         0 + 1 * np.log(1 / 0.5),
         # P[b = 0] = 0, but P[a = 0] != 0, so not absolutely continuous.
         # Some would argue that NaN would be more correct...
         np.inf,
         # P[b = 1] = 0, and P[a = 1] = 0, so absolute continuity holds.
         0 + 1 * np.log(1 / 1)
     ]
     self.assertAllClose(self.evaluate(tfd.kl_divergence(a, b)),
                         kl_expected)
Example #30
  def testKLNonBatch(self):
    batch_shape = []
    event_shape = [2]
    mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
    mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape)
    mvn_a = tfd.MultivariateNormalTriL(
        loc=mu_a, scale_tril=np.linalg.cholesky(sigma_a), validate_args=True)
    mvn_b = tfd.MultivariateNormalTriL(
        loc=mu_b, scale_tril=np.linalg.cholesky(sigma_b), validate_args=True)

    kl = tfd.kl_divergence(mvn_a, mvn_b)
    self.assertEqual(batch_shape, kl.shape)

    kl_v = self.evaluate(kl)
    expected_kl = _compute_non_batch_kl(mu_a, sigma_a, mu_b, sigma_b)
    self.assertAllClose(expected_kl, kl_v)
Example #31
    def test_docstring_example_normal(self):
        num_draws = int(1e5)
        mu_p = tf.constant(0.)
        mu_q = tf.constant(1.)
        with tf.GradientTape(persistent=True) as tape:
            tape.watch(mu_p)
            tape.watch(mu_q)
            p = tfd.Normal(loc=mu_p, scale=1.)
            q = tfd.Normal(loc=mu_q, scale=2.)
            exact_kl_normal_normal = tfd.kl_divergence(p, q)
            approx_kl_normal_normal = tfp.monte_carlo.expectation(
                f=lambda x: p.log_prob(x) - q.log_prob(x),
                samples=p.sample(num_draws, seed=42),
                log_prob=p.log_prob,
                use_reparameterization=(
                    p.reparameterization_type == tfd.FULLY_REPARAMETERIZED))
        [exact_kl_normal_normal_, approx_kl_normal_normal_
         ] = self.evaluate([exact_kl_normal_normal, approx_kl_normal_normal])
        self.assertEqual(
            True, p.reparameterization_type == tfd.FULLY_REPARAMETERIZED)
        self.assertAllClose(exact_kl_normal_normal_,
                            approx_kl_normal_normal_,
                            rtol=0.01,
                            atol=0.)

        # Compare gradients. (Not present in `docstring`.)
        gradp = lambda fp: tape.gradient(fp, mu_p)
        gradq = lambda fq: tape.gradient(fq, mu_q)
        [
            gradp_exact_kl_normal_normal_,
            gradq_exact_kl_normal_normal_,
            gradp_approx_kl_normal_normal_,
            gradq_approx_kl_normal_normal_,
        ] = self.evaluate([
            gradp(exact_kl_normal_normal),
            gradq(exact_kl_normal_normal),
            gradp(approx_kl_normal_normal),
            gradq(approx_kl_normal_normal),
        ])
        self.assertAllClose(gradp_exact_kl_normal_normal_,
                            gradp_approx_kl_normal_normal_,
                            rtol=0.01,
                            atol=0.)
        self.assertAllClose(gradq_exact_kl_normal_normal_,
                            gradq_approx_kl_normal_normal_,
                            rtol=0.01,
                            atol=0.)
    def __init__(self,
                 input_dim,
                 output_dim,
                 mask_zero=False,
                 input_length=None,
                 client_weight=1.,
                 trainable=True,
                 embeddings_initializer=tf.keras.initializers.RandomUniform(
                     -0.01, 0.01),
                 embedding_posterior_fn=None,
                 embedding_posterior_tensor_fn=(lambda d: d.sample()),
                 embedding_prior_fn=None,
                 embedding_divergence_fn=(
                         lambda q, p, ignore: tfd.kl_divergence(q, p)),
                 **kwargs
                 ):

        self.untransformed_scale_initializer = None
        if 'untransformed_scale_initializer' in kwargs:
            self.untransformed_scale_initializer = \
                kwargs.pop('untransformed_scale_initializer')

        if embedding_posterior_fn is None:
            embedding_posterior_fn = self.renormalize_natural_mean_field_normal_fn
        if embedding_prior_fn is None:
            embedding_prior_fn = self.natural_tensor_multivariate_normal_fn

        super(NaturalGaussianEmbedding, self).__init__(input_dim,
                                                       output_dim,
                                                       mask_zero=mask_zero,
                                                       input_length=input_length,
                                                       trainable=trainable,
                                                       embeddings_initializer=embeddings_initializer,
                                                       **kwargs)

        self.client_weight = client_weight
        self.delta_function = tf.subtract
        self.apply_delta_function = tf.add
        self.embedding_posterior_fn = embedding_posterior_fn
        self.embedding_prior_fn = embedding_prior_fn
        self.embedding_posterior_tensor_fn = embedding_posterior_tensor_fn
        self.embedding_divergence_fn = embedding_divergence_fn
        self.client_variable_dict = {}
        self.client_center_variable_dict = {}
        self.server_variable_dict = {}
Example #33
    def loss(self, features):
        encoder = self._make_encoder(self.latent_size, self.activation)
        decoder = self._make_decoder(self.latent_size, self.activation)
        latent_prior = self._make_prior()

        approx_posterior = encoder(features)
        approx_posterior_sample = approx_posterior.sample(self.n_samples)
        decoder_likelihood = decoder(approx_posterior_sample)
        distortion = -decoder_likelihood.log_prob(features)
        if self.analytic_kl:
            rate = tfd.kl_divergence(approx_posterior, latent_prior)
        else:
            rate = (approx_posterior.log_prob(approx_posterior_sample) -
                    latent_prior.log_prob(approx_posterior_sample))
        elbo_local = -(rate + distortion)

        elbo = tf.reduce_mean(elbo_local)
        loss = -elbo
        return loss
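
In the analytic_kl branch, tfd.kl_divergence only needs a registered KL between the posterior and prior families; for example, with a MultivariateNormalDiag posterior and a standard MultivariateNormalDiag prior the rate term is available in closed form. A minimal sketch (shapes and names are illustrative, not from the model):

    import tensorflow as tf
    import tensorflow_probability as tfp
    tfd = tfp.distributions

    latent_size = 4
    approx_posterior = tfd.MultivariateNormalDiag(
        loc=0.3 * tf.ones([8, latent_size]),
        scale_diag=0.7 * tf.ones([8, latent_size]))
    latent_prior = tfd.MultivariateNormalDiag(
        loc=tf.zeros([latent_size]), scale_diag=tf.ones([latent_size]))
    rate = tfd.kl_divergence(approx_posterior, latent_prior)  # shape [8]
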