Exemple #1
0
    def testDirichletDirichletKL(self):
        """KL(d1 || d2) agrees with the closed form and a Monte Carlo estimate.

        The analytic divergence is compared against (a) the closed-form
        Dirichlet KL computed with ``scipy.special`` and (b) a sample mean of
        ``log p1(x) - log p2(x)`` over 1e4 draws from d1. KL(d1 || d1) must
        be zero.
        """
        conc1 = np.array([[1., 2., 3., 1.5, 2.5, 3.5],
                          [1.5, 2.5, 3.5, 4.5, 5.5, 6.5]])
        conc2 = np.array([[0.5, 1., 1.5, 2., 2.5, 3.]])

        d1 = dirichlet_lib.Dirichlet(conc1)
        d2 = dirichlet_lib.Dirichlet(conc2)

        # Monte Carlo estimate of E_{x ~ d1}[log p1(x) - log p2(x)].
        draws = d1.sample(int(1e4), seed=0)
        log_ratio = d1.log_prob(draws) - d2.log_prob(draws)
        kl_sample = tf.reduce_mean(log_ratio, 0)
        kl_actual = kullback_leibler.kl_divergence(d1, d2)

        kl_sample_val = self.evaluate(kl_sample)
        kl_actual_val = self.evaluate(kl_actual)

        # One divergence per batch member (all but the event dimension).
        self.assertEqual(conc1.shape[:-1], kl_actual.shape)

        # The closed-form reference needs scipy; bail out when it is missing.
        if not special:
            return

        # Difference of log-normalizers of the two Dirichlets.
        log_norm_diff = (
            special.gammaln(np.sum(conc1, -1)) -
            special.gammaln(np.sum(conc2, -1)) -
            np.sum(special.gammaln(conc1) - special.gammaln(conc2), -1))
        # E_{d1}[log x_k] = digamma(a_k) - digamma(sum(a)).
        digamma_gap = (special.digamma(conc1) -
                       special.digamma(np.sum(conc1, -1, keepdims=True)))
        kl_expected = log_norm_diff + np.sum((conc1 - conc2) * digamma_gap, -1)

        self.assertAllClose(kl_expected, kl_actual_val, atol=0., rtol=1e-6)
        self.assertAllClose(kl_sample_val, kl_actual_val, atol=0., rtol=1e-1)

        # A distribution has zero divergence from itself.
        kl_self = self.evaluate(kullback_leibler.kl_divergence(d1, d1))
        self.assertAllClose(kl_self, np.zeros_like(kl_expected))
Exemple #2
0
 def testSimpleShapes(self):
     """A rank-1 concentration gives an empty batch shape and a size-3 event."""
     dist = dirichlet_lib.Dirichlet(np.random.rand(3))
     event_size = self.evaluate(dist.event_shape_tensor())
     batch_dims = self.evaluate(dist.batch_shape_tensor())
     self.assertEqual(3, event_size)
     self.assertAllEqual([], batch_dims)
     # Static shapes must agree with the dynamic ones above.
     self.assertEqual(tf.TensorShape([3]), dist.event_shape)
     self.assertEqual(tf.TensorShape([]), dist.batch_shape)
Exemple #3
0
 def testComplexShapes(self):
     """The trailing axis of the concentration is the event; the rest is batch."""
     dist = dirichlet_lib.Dirichlet(np.random.rand(3, 2, 2))
     event_size = self.evaluate(dist.event_shape_tensor())
     batch_dims = self.evaluate(dist.batch_shape_tensor())
     self.assertEqual(2, event_size)
     self.assertAllEqual([3, 2], batch_dims)
     # Static shapes must agree with the dynamic ones above.
     self.assertEqual(tf.TensorShape([2]), dist.event_shape)
     self.assertEqual(tf.TensorShape([3, 2]), dist.batch_shape)
Exemple #4
0
    def testLogPdfOnBoundaryIsFiniteWhenAlphaIsOne(self):
        """log_prob is finite where x_k == 0 and concentration[k] == 1.

        For concentration[k] == 1 the factor x_k ** (concentration[k] - 1)
        is x_k ** 0, so a point with x_k == 0 must still have a finite
        log-density.
        """
        # Batch of 10 Dirichlets: concentration 3 everywhere except 1 on the
        # diagonal, i.e. batch member i has concentration 1 in dimension i.
        concentration = 3 * np.ones((10, 10)).astype(np.float32)
        np.fill_diagonal(concentration, 1.)
        # Row i is 0 in dimension i and 1/9 elsewhere, so each row sums to 1.
        x = 1 / 9. * np.ones((10, 10)).astype(np.float32)
        np.fill_diagonal(x, 0.)
        dist = dirichlet_lib.Dirichlet(concentration)
        log_prob = self.evaluate(dist.log_prob(x))
        # `np.bool` was removed in NumPy 1.24; the builtin `bool` is the
        # supported dtype spelling.
        self.assertAllEqual(np.ones_like(log_prob, dtype=bool),
                            np.isfinite(log_prob))

        # Test when concentration[k] = 1., and x is zero at various dimensions.
        dist = dirichlet_lib.Dirichlet(10 * [1.])
        log_prob = self.evaluate(dist.log_prob(x))
        self.assertAllEqual(np.ones_like(log_prob, dtype=bool),
                            np.isfinite(log_prob))
Exemple #5
0
 def testPdfAlphaStretchedInBroadcastWhenSameRank(self):
     """A [1, K] concentration broadcasts against a [2, K] batch of points.

     For concentration [1, 2] the density is Gamma(3) * x_2 = 2 * x_2.
     """
     alpha = [[1., 2]]
     x = [[.5, .5], [.3, .7]]
     dist = dirichlet_lib.Dirichlet(alpha)
     pdf = dist.prob(x)
     self.assertAllClose([1., 7. / 5], self.evaluate(pdf))
     # `(2,)` is the intended 1-tuple shape; `(2)` is just the int 2.
     self.assertEqual((2,), pdf.shape)
Exemple #6
0
 def testPdfZeroBatchesNontrivialX(self):
     """Unbatched Dirichlet([1, 2]) at x=[.3, .7] has density 2 * 0.7."""
     pdf = dirichlet_lib.Dirichlet([1., 2]).prob([.3, .7])
     self.assertAllClose(7. / 5, self.evaluate(pdf))
     # No batch dimensions: the result is a scalar.
     self.assertEqual((), pdf.shape)
Exemple #7
0
 def testPdfZeroBatches(self):
     """Unbatched Dirichlet([1, 2]) at x=[.5, .5] has density 2 * 0.5."""
     pdf = dirichlet_lib.Dirichlet([1., 2]).prob([.5, .5])
     self.assertAllClose(1., self.evaluate(pdf))
     # No batch dimensions: the result is a scalar.
     self.assertEqual((), pdf.shape)
Exemple #8
0
 def testMean(self):
     """mean() has event shape [3] and matches scipy's Dirichlet mean."""
     concentration = [1., 2, 3]
     dirichlet = dirichlet_lib.Dirichlet(concentration=concentration)
     self.assertEqual(dirichlet.mean().shape, [3])
     # The reference value needs scipy.stats; skip that part when missing.
     if stats:
         self.assertAllClose(self.evaluate(dirichlet.mean()),
                             stats.dirichlet.mean(concentration))
Exemple #9
0
 def testPdfUniformZeroBatches(self):
     """Dirichlet([1, 1, 1]) is uniform on the 2-simplex with density Gamma(3) = 2."""
     # Corresponds to a uniform distribution
     alpha = [1., 1, 1]
     x = [[.2, .5, .3], [.3, .4, .3]]
     dist = dirichlet_lib.Dirichlet(alpha)
     pdf = dist.prob(x)
     self.assertAllClose([2., 2.], self.evaluate(pdf))
     # `(2,)` is the intended 1-tuple shape; `(2)` is just the int 2.
     self.assertEqual((2,), pdf.shape)
Exemple #10
0
 def testEntropy(self):
     """entropy() is a scalar and matches scipy's Dirichlet entropy."""
     concentration = [1., 2, 3]
     dirichlet = dirichlet_lib.Dirichlet(concentration=concentration)
     self.assertEqual(dirichlet.entropy().shape, ())
     # The reference value needs scipy.stats; skip that part when missing.
     if stats:
         self.assertAllClose(self.evaluate(dirichlet.entropy()),
                             stats.dirichlet.entropy(concentration))
Exemple #11
0
 def testDirichletFullyReparameterized(self):
     """Samples carry a gradient back to the concentration parameter."""
     alpha = tf.constant([1.0, 2.0, 3.0])
     with backprop.GradientTape() as tape:
         tape.watch(alpha)
         samples = dirichlet_lib.Dirichlet(alpha).sample(100)
     # A None gradient would mean sampling is not differentiable in alpha.
     self.assertIsNotNone(tape.gradient(samples, alpha))
Exemple #12
0
    def testModeEnableAllowNanStats(self):
        """With allow_nan_stats=True an undefined mode is reported as all-NaN.

        The concentration contains a 1., and this test expects mode() to be
        NaN in every component for that input.
        """
        concentration = np.array([1., 2, 3])
        dirichlet = dirichlet_lib.Dirichlet(
            concentration=concentration, allow_nan_stats=True)
        nan_mode = np.full_like(concentration, np.nan)

        self.assertEqual(dirichlet.mode().shape, [3])
        self.assertAllClose(self.evaluate(dirichlet.mode()), nan_mode)
Exemple #13
0
 def testPdfXProper(self):
     """With validate_args=True, prob() rejects points outside the open simplex."""
     dist = dirichlet_lib.Dirichlet([[1., 2, 3]], validate_args=True)
     # Valid points: strictly positive and summing to one.
     for good_x in ([.1, .3, .6], [.2, .3, .5]):
         self.evaluate(dist.prob(good_x))
     # A negative entry and a zero entry both fail the positivity check
     # (each of these points still sums to one).
     for nonpositive_x in ([-1., 1.5, 0.5], [0., .1, .9]):
         with self.assertRaisesOpError("samples must be positive"):
             self.evaluate(dist.prob(nonpositive_x))
     # Positive entries whose sum is not one.
     with self.assertRaisesOpError("sample last-dimension must sum to `1`"):
         self.evaluate(dist.prob([.1, .2, .8]))
Exemple #14
0
 def testVariance(self):
     """covariance() matches the closed-form Dirichlet covariance.

     The diagonal comes from scipy's ``stats.dirichlet.var``. Off-diagonal
     entries are -alpha_i * alpha_j / (a0**2 * (a0 + 1)) with a0 = sum(alpha).
     """
     alpha = [1., 2, 3]
     denominator = np.sum(alpha)**2 * (np.sum(alpha) + 1)
     dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
     self.assertEqual(dirichlet.covariance().shape, (3, 3))
     if not stats:
         return
     expected_covariance = np.diag(stats.dirichlet.var(alpha))
     # -alpha_i * alpha_j for i != j. Build an explicit array: dividing a
     # plain list only works when `denominator` happens to be a NumPy scalar.
     off_diagonal = np.array([[0., -2, -3], [-2, 0, -6], [-3, -6, 0]])
     expected_covariance += off_diagonal / denominator
     self.assertAllClose(self.evaluate(dirichlet.covariance()),
                         expected_covariance)
Exemple #15
0
 def testSample(self):
     """Samples have shape [n, 2], are positive, and have Beta(1, 2) marginals."""
     dirichlet = dirichlet_lib.Dirichlet([1., 2])
     num_draws = tf.constant(100000)
     sample_values = self.evaluate(dirichlet.sample(num_draws))
     self.assertEqual(sample_values.shape, (100000, 2))
     self.assertTrue((sample_values > 0.0).all())
     # The distributional check needs scipy.stats; skip when missing.
     if not stats:
         return
     # Beta is a univariate distribution: the first coordinate of
     # Dirichlet([a1, a2]) follows Beta(a1, a2).
     ks_statistic = stats.kstest(sample_values[:, 0],
                                 stats.beta(a=1., b=2.).cdf)[0]
     self.assertLess(ks_statistic, 0.01)
Exemple #16
0
    def testCovarianceFromSampling(self):
        """Analytic mean/covariance/variance/stddev agree with sample estimates.

        Draws 250k samples from a batch of two Dirichlets (batch_shape=[2],
        event_shape=[3]) and compares moment estimates to the analytic values
        with absolute tolerances.
        """
        alpha = np.array([[1., 2, 3], [2.5, 4, 0.01]], dtype=np.float32)
        dist = dirichlet_lib.Dirichlet(
            alpha)  # batch_shape=[2], event_shape=[3]
        x = dist.sample(int(250e3), seed=1)
        sample_mean = tf.reduce_mean(x, 0)
        x_centered = x - sample_mean[None, ...]
        # Average of per-sample outer products -> [batch, event, event].
        sample_cov = tf.reduce_mean(
            tf.matmul(x_centered[..., None], x_centered[..., None, :]), 0)
        # `tf.matrix_diag_part` is TF1-only; `tf.linalg.diag_part` works in
        # both TF1.x and TF2.
        sample_var = tf.linalg.diag_part(sample_cov)
        sample_stddev = tf.sqrt(sample_var)

        [
            sample_mean_,
            sample_cov_,
            sample_var_,
            sample_stddev_,
            analytic_mean,
            analytic_cov,
            analytic_var,
            analytic_stddev,
        ] = self.evaluate([
            sample_mean,
            sample_cov,
            sample_var,
            sample_stddev,
            dist.mean(),
            dist.covariance(),
            dist.variance(),
            dist.stddev(),
        ])

        self.assertAllClose(sample_mean_, analytic_mean, atol=0.04, rtol=0.)
        self.assertAllClose(sample_cov_, analytic_cov, atol=0.06, rtol=0.)
        self.assertAllClose(sample_var_, analytic_var, atol=0.03, rtol=0.)
        self.assertAllClose(sample_stddev_,
                            analytic_stddev,
                            atol=0.02,
                            rtol=0.)
Exemple #17
0
 def testModeInvalid(self):
     """With allow_nan_stats=False, mode() raises for this concentration.

     The concentration contains a 1.; the test expects an op error whose
     message matches "Condition x < y.*".
     """
     dirichlet = dirichlet_lib.Dirichlet(
         concentration=np.array([1., 2, 3]), allow_nan_stats=False)
     with self.assertRaisesOpError("Condition x < y.*"):
         self.evaluate(dirichlet.mode())
Exemple #18
0
 def testPdfAlphaStretchedInBroadcastWhenLowerRank(self):
     """A rank-1 concentration broadcasts against a [2, K] batch of points.

     For concentration [1, 2] the density is Gamma(3) * x_2 = 2 * x_2.
     """
     alpha = [1., 2]
     x = [[.5, .5], [.2, .8]]
     pdf = dirichlet_lib.Dirichlet(alpha).prob(x)
     self.assertAllClose([1., 8. / 5], self.evaluate(pdf))
     # `(2,)` is the intended 1-tuple shape; `(2)` is just the int 2.
     self.assertEqual((2,), pdf.shape)
Exemple #19
0
 def testConcentrationProperty(self):
     """The concentration property round-trips the constructor argument."""
     alpha = [[1., 2, 3]]
     concentration = dirichlet_lib.Dirichlet(alpha).concentration
     self.assertEqual([1, 3], concentration.shape)
     self.assertAllClose(alpha, self.evaluate(concentration))
Exemple #20
0
 def testPdfXStretchedInBroadcastWhenLowerRank(self):
     """A single point broadcasts against a batch of two concentrations.

     Dirichlet([1, 2]) at [.5, .5] gives 2 * .5 = 1; Dirichlet([2, 3]) gives
     Gamma(5) / (Gamma(2) * Gamma(3)) * .5 * .5**2 = 12 / 8 = 3 / 2.
     """
     alpha = [[1., 2], [2., 3]]
     x = [.5, .5]
     pdf = dirichlet_lib.Dirichlet(alpha).prob(x)
     self.assertAllClose([1., 3. / 2], self.evaluate(pdf))
     # `(2,)` is the intended 1-tuple shape; `(2)` is just the int 2.
     self.assertEqual((2,), pdf.shape)
Exemple #21
0
 def testMode(self):
     """mode() equals (alpha - 1) / (sum(alpha) - K) when every alpha > 1."""
     concentration = np.array([1.1, 2, 3])
     # K is the event size (number of components).
     expected_mode = (concentration - 1) / (np.sum(concentration) - concentration.size)
     dirichlet = dirichlet_lib.Dirichlet(concentration=concentration)
     self.assertEqual(dirichlet.mode().shape, [3])
     self.assertAllClose(self.evaluate(dirichlet.mode()), expected_mode)