Esempio n. 1
0
 def testShapeAssertion(self):
   """Expect a rank-related ValueError from `factorised_kl_gaussian`.

   Two `MultivariateNormalDiag` Gaussians are built via
   `self._create_gaussian`, and their means and covariances are passed to
   `factorised_kl_gaussian` with `both_diagonal=True`. The call must raise a
   `ValueError` whose message matches the rank-mismatch pattern below.
   """
   # NOTE(review): presumably `_create_gaussian` returns a covariance whose
   # rank is not the one expected for the diagonal case, which is what
   # triggers the error -- confirm against the helper.
   diag_type = tfp.distributions.MultivariateNormalDiag
   _, mean_p, cov_p = self._create_gaussian(diag_type)
   _, mean_q, cov_q = self._create_gaussian(diag_type)
   expected_message = 'Shape (.*) must have rank [0-9]+'
   with self.assertRaisesRegexp(ValueError, expected_message):
     distribution_ops.factorised_kl_gaussian(
         mean_p, cov_p, mean_q, cov_q, both_diagonal=True)
Esempio n. 2
0
    def testConsistentGradientsFullCovariance(self):
        """Compare factorised-KL gradients against the reference KL gradients.

        Builds two `MultivariateNormalFullCovariance` Gaussians and checks:

        * the gradients of `kl_mean + kl_cov` with respect to both means and
          both covariances are close to those of
          `tfp.distributions.kl_divergence`;
        * `kl_mean` has no gradient with respect to the first covariance, and
          `kl_cov` has no gradient with respect to either mean.
        """
        full_cov_type = tfp.distributions.MultivariateNormalFullCovariance
        p, p_mean, p_cov = self._create_gaussian(full_cov_type)
        q, q_mean, q_cov = self._create_gaussian(full_cov_type)

        reference_kl = tfp.distributions.kl_divergence(p, q)
        kl_mean, kl_cov = distribution_ops.factorised_kl_gaussian(
            p_mean, p_cov, q_mean, q_cov, both_diagonal=False)

        # From here on, differentiate with respect to the covariance entries
        # stored in each distribution's `parameters` dict.
        p_cov = p.parameters['covariance_matrix']
        q_cov = q.parameters['covariance_matrix']
        wrt = [p_mean, q_mean, p_cov, q_cov]
        reference_grads = tf.gradients(reference_kl, wrt)
        factorised_grads = tf.gradients(kl_mean + kl_cov, wrt)

        # The two factorised terms must be disconnected from each other's
        # inputs: `tf.gradients` yields None where no gradient path exists.
        self.assertListEqual(tf.gradients(kl_mean, [p_cov]), [None])
        self.assertListEqual(
            tf.gradients(kl_cov, [p_mean, q_mean]), [None, None])

        with self.test_session() as session:
            reference_values, factorised_values = session.run(
                [reference_grads, factorised_grads])
            self.assertAllClose(reference_values, factorised_values)
Esempio n. 3
0
 def testFactorisedKLGaussian(self, dist1_type, dist2_type):
     """Check that `kl_mean + kl_cov` reproduces the reference KL divergence.

     Args:
       dist1_type: distribution class handed to `self._create_gaussian` to
         build the first Gaussian.
       dist2_type: distribution class handed to `self._create_gaussian` to
         build the second Gaussian.
     """
     p, p_mean, p_cov = self._create_gaussian(dist1_type)
     q, q_mean, q_cov = self._create_gaussian(dist2_type)

     # When both scales are diagonal, pass the `scale_diag` parameters in
     # place of the covariances returned by `_create_gaussian`.
     use_diagonal = _is_diagonal(p.scale) and _is_diagonal(q.scale)
     if use_diagonal:
         p_cov, q_cov = p.parameters['scale_diag'], q.parameters['scale_diag']

     reference_kl = tfp.distributions.kl_divergence(p, q)
     kl_mean, kl_cov = distribution_ops.factorised_kl_gaussian(
         p_mean, p_cov, q_mean, q_cov, both_diagonal=use_diagonal)

     with self.test_session() as session:
         session.run(tf.global_variables_initializer())
         expected, mean_term, cov_term = session.run(
             [reference_kl, kl_mean, kl_cov])
         recombined = mean_term + cov_term
         self.assertAllClose(expected, recombined, rtol=1e-4)