示例#1
0
 def testDirichletMode(self):
     """mode() matches (alpha - 1) / (sum(alpha) - K), K = len(alpha)."""
     with self.test_session():
         alpha = np.array([1.1, 2, 3])
         # Derive the event size from alpha instead of hard-coding 3, so the
         # expectation stays correct if alpha is edited.
         num_categories = alpha.size
         expected_mode = (alpha - 1) / (np.sum(alpha) - num_categories)
         dirichlet = dirichlet_lib.Dirichlet(alpha=alpha)
         mode = dirichlet.mode()
         self.assertEqual(mode.get_shape(), (num_categories, ))
         self.assertAllClose(mode.eval(), expected_mode)
示例#2
0
 def testDirichletMean(self):
     """mean() has shape (3,) and agrees with scipy's Dirichlet mean."""
     with self.test_session():
         concentration = [1., 2, 3]
         reference = stats.dirichlet.mean(concentration)
         dist = dirichlet_lib.Dirichlet(alpha=concentration)
         mean_tensor = dist.mean()
         self.assertEqual(mean_tensor.get_shape(), (3, ))
         self.assertAllClose(mean_tensor.eval(), reference)
示例#3
0
 def testDirichletEntropy(self):
     """entropy() is a scalar equal to scipy's Dirichlet entropy."""
     with self.test_session():
         concentration = [1., 2, 3]
         reference = stats.dirichlet.entropy(concentration)
         dist = dirichlet_lib.Dirichlet(alpha=concentration)
         entropy_tensor = dist.entropy()
         self.assertEqual(entropy_tensor.get_shape(), ())
         self.assertAllClose(entropy_tensor.eval(), reference)
 def testPdfAlphaStretchedInBroadcastWhenLowerRank(self):
     """A rank-1 alpha broadcasts against a batch of two x vectors."""
     with self.test_session():
         alpha = [1., 2]
         x = [[.5, .5], [.2, .8]]
         pdf = dirichlet_lib.Dirichlet(alpha).pdf(x)
         self.assertAllClose([1., 8. / 5], pdf.eval())
         # `(2)` is just the int 2; spell out the intended 1-D shape tuple.
         self.assertEqual((2,), pdf.get_shape())
示例#5
0
 def testPdfXStretchedInBroadcastWhenLowerRank(self):
     """A rank-1 x broadcasts against a batch of two alpha vectors."""
     with self.test_session():
         alpha = [[1., 2], [2., 3]]
         x = [.5, .5]
         pdf = dirichlet_lib.Dirichlet(alpha).prob(x)
         self.assertAllClose([1., 3. / 2], pdf.eval())
         # `(2)` is just the int 2; spell out the intended 1-D shape tuple.
         self.assertEqual((2,), pdf.get_shape())
示例#6
0
 def testDirichletModeInvalid(self):
     """With allow_nan_stats=False, mode() for alpha=[1, 2, 3] must raise.

     alpha[0] == 1 here; evaluating mode() is expected to fail with an op
     error whose message matches "Condition x < y".
     """
     with self.test_session():
         concentration = np.array([1., 2, 3])
         dist = dirichlet_lib.Dirichlet(
             alpha=concentration, allow_nan_stats=False)
         with self.assertRaisesOpError("Condition x < y.*"):
             dist.mode().eval()
示例#7
0
 def testPdfAlphaStretchedInBroadcastWhenSameRank(self):
     """An alpha with batch dim 1 broadcasts against two x vectors."""
     with self.test_session():
         alpha = [[1., 2]]
         x = [[.5, .5], [.3, .7]]
         dist = dirichlet_lib.Dirichlet(alpha)
         pdf = dist.prob(x)
         self.assertAllClose([1., 7. / 5], pdf.eval())
         # `(2)` is just the int 2; spell out the intended 1-D shape tuple.
         self.assertEqual((2,), pdf.get_shape())
示例#8
0
 def testPdfZeroBatches(self):
     """Unbatched alpha and x give a scalar density (1.0 at x=[.5, .5])."""
     with self.test_session():
         density = dirichlet_lib.Dirichlet([1., 2]).prob([.5, .5])
         self.assertAllClose(1., density.eval())
         self.assertEqual((), density.get_shape())
示例#9
0
 def testPdfZeroBatchesNontrivialX(self):
     """Unbatched density at x=[.3, .7] for alpha=[1, 2] is the scalar 7/5."""
     with self.test_session():
         density = dirichlet_lib.Dirichlet([1., 2]).prob([.3, .7])
         self.assertAllClose(7. / 5, density.eval())
         self.assertEqual((), density.get_shape())
示例#10
0
 def testSimpleShapes(self):
     """A length-3 alpha gives event shape [3] and an empty batch shape."""
     with self.test_session():
         alpha = np.random.rand(3)
         dist = dirichlet_lib.Dirichlet(alpha)
         # event_shape_tensor() evaluates to a 1-D array. assertEqual(3, [3])
         # only passed through numpy's element-wise `==` truthiness, so
         # compare arrays explicitly, matching the batch-shape check below.
         self.assertAllEqual([3], dist.event_shape_tensor().eval())
         self.assertAllEqual([], dist.batch_shape_tensor().eval())
         self.assertEqual(tensor_shape.TensorShape([3]), dist.event_shape)
         self.assertEqual(tensor_shape.TensorShape([]), dist.batch_shape)
示例#11
0
 def testPdfUniformZeroBatches(self):
     """alpha of all ones (uniform on the simplex) gives density 2 for K=3."""
     with self.test_session():
         # Corresponds to a uniform distribution
         alpha = [1., 1, 1]
         x = [[.2, .5, .3], [.3, .4, .3]]
         dist = dirichlet_lib.Dirichlet(alpha)
         pdf = dist.prob(x)
         self.assertAllClose([2., 2.], pdf.eval())
         # `(2)` is just the int 2; spell out the intended 1-D shape tuple.
         self.assertEqual((2,), pdf.get_shape())
示例#12
0
    def testModeEnableAllowNanStats(self):
        """With allow_nan_stats=True and alpha=[1, 2, 3], mode() is all NaN."""
        with self.test_session():
            alpha = np.array([1., 2, 3])
            expected_mode = np.full_like(alpha, np.nan)
            dirichlet = dirichlet_lib.Dirichlet(concentration=alpha,
                                                allow_nan_stats=True)
            mode = dirichlet.mode()
            self.assertEqual(mode.get_shape(), [3])
            self.assertAllClose(mode.eval(), expected_mode)
示例#13
0
 def testComplexShapes(self):
     """alpha of shape [3, 2, 2] gives batch shape [3, 2], event shape [2]."""
     with self.test_session():
         alpha = np.random.rand(3, 2, 2)
         dist = dirichlet_lib.Dirichlet(alpha)
         # event_shape() evaluates to a 1-D array; compare element-wise rather
         # than relying on the truthiness of `2 == array([2])`.
         self.assertAllEqual([2], dist.event_shape().eval())
         self.assertAllEqual([3, 2], dist.batch_shape().eval())
         self.assertEqual(tensor_shape.TensorShape([2]),
                          dist.get_event_shape())
         self.assertEqual(tensor_shape.TensorShape([3, 2]),
                          dist.get_batch_shape())
 def testDirichletVariance(self):
     """variance() returns the K x K covariance matrix for a rank-1 alpha.

     The diagonal comes from scipy's marginal variances. The off-diagonal
     entries are -alpha_i * alpha_j / (alpha_0**2 * (alpha_0 + 1)), where
     alpha_0 = sum(alpha).
     """
     with self.test_session():
         alpha = [1., 2, 3]
         alpha_arr = np.asarray(alpha)
         num_categories = alpha_arr.size
         alpha_sum = np.sum(alpha_arr)
         denominator = alpha_sum**2 * (alpha_sum + 1)
         # Derive the off-diagonal numerators from alpha instead of
         # hand-computed constants. Zero the diagonal because it is supplied
         # by stats.dirichlet.var.
         off_diagonal = -np.outer(alpha_arr, alpha_arr)
         np.fill_diagonal(off_diagonal, 0.)
         expected_variance = np.diag(stats.dirichlet.var(alpha))
         expected_variance += off_diagonal / denominator
         dirichlet = dirichlet_lib.Dirichlet(alpha=alpha)
         variance = dirichlet.variance()
         self.assertEqual(variance.get_shape(),
                          (num_categories, num_categories))
         self.assertAllClose(variance.eval(), expected_variance)
示例#15
0
    def testDirichletModeEnableAllowNanStats(self):
        """With allow_nan_stats=True, mode() reports NaN for alpha[0] == 1.

        The remaining components are expected to follow
        (alpha - 1) / (sum(alpha) - K), with K = len(alpha).
        """
        with self.test_session():
            alpha = np.array([1., 2, 3])
            dirichlet = dirichlet_lib.Dirichlet(alpha=alpha,
                                                allow_nan_stats=True)
            # Derive the event size from alpha instead of hard-coding 3.
            num_categories = alpha.size
            expected_mode = (alpha - 1) / (np.sum(alpha) - num_categories)
            expected_mode[0] = np.nan

            mode = dirichlet.mode()
            self.assertEqual(mode.get_shape(), (num_categories, ))
            self.assertAllClose(mode.eval(), expected_mode)
示例#16
0
 def testPdfXProper(self):
     """With validate_args=True, prob() rejects x that is not on the simplex.

     Points must have strictly positive entries that sum to 1.
     """
     alpha = [[1., 2, 3]]
     with self.test_session():
         dist = dirichlet_lib.Dirichlet(alpha, validate_args=True)
         # Valid points must evaluate without error.
         dist.prob([.1, .3, .6]).eval()
         dist.prob([.2, .3, .5]).eval()
         # Each invalid input violates exactly one check. The next two sum to
         # 1 but contain a non-positive entry.
         with self.assertRaisesOpError("samples must be positive"):
             dist.prob([-1., 1.5, 0.5]).eval()
         with self.assertRaisesOpError("samples must be positive"):
             dist.prob([0., .1, .9]).eval()
         # Entries are positive but sum to 1.1 rather than 1.
         with self.assertRaisesOpError(
                 "sample last-dimension must sum to `1`"):
             dist.prob([.1, .2, .8]).eval()
示例#17
0
 def testPdfXProper(self):
     """With validate_args=True, prob() raises on x outside the simplex."""
     alpha = [[1., 2, 3]]
     with self.test_session():
         dist = dirichlet_lib.Dirichlet(alpha, validate_args=True)
         # Points on the simplex must evaluate cleanly.
         for valid_x in ([.1, .3, .6], [.2, .3, .5]):
             dist.prob(valid_x).eval()
         # [-1, 1, 1] can trip either the positivity check or the
         # "x < y" check, so accept whichever message is reported.
         with self.assertRaisesOpError(
                 "Condition x > 0.*|Condition x < y.*"):
             dist.prob([-1., 1, 1]).eval()
         # A zero entry fails the strict positivity check.
         with self.assertRaisesOpError("Condition x > 0.*"):
             dist.prob([0., .1, .9]).eval()
         # Positive entries summing to 1.1 fail the sum-to-one check.
         with self.assertRaisesOpError("Condition x ~= y.*"):
             dist.prob([.1, .2, .8]).eval()
示例#18
0
 def testDirichletSample(self):
     """Samples are positive and the first coordinate matches Beta(1, 2).

     For a two-component Dirichlet(alpha), the first coordinate is
     Beta(alpha[0], alpha[1]). The KS statistic of 100000 draws must be
     below 0.01.
     """
     with self.test_session():
         alpha = [1., 2]
         num_draws = 100000
         dirichlet = dirichlet_lib.Dirichlet(alpha)
         draws = dirichlet.sample(constant_op.constant(num_draws)).eval()
         self.assertEqual(draws.shape, (num_draws, 2))
         self.assertTrue(np.all(draws > 0.0))
         reference_cdf = stats.beta(a=1., b=2.).cdf
         ks_statistic = stats.kstest(draws[:, 0], reference_cdf)[0]
         self.assertLess(ks_statistic, 0.01)
示例#19
0
 def testCovarianceFromSampling(self):
     """Monte Carlo moments agree with the analytic ones.

     Draws 250k samples from a batch of two 3-component Dirichlets. The
     empirical mean, covariance, variance and stddev are compared against
     mean(), covariance(), variance() and stddev() using relative
     tolerances only.
     """
     alpha = np.array([[1., 2, 3], [2.5, 4, 0.01]], dtype=np.float32)
     with self.test_session() as session:
         # batch_shape=[2], event_shape=[3]
         dist = dirichlet_lib.Dirichlet(alpha)
         draws = dist.sample(int(250e3), seed=1)
         empirical_mean = math_ops.reduce_mean(draws, 0)
         deviations = draws - empirical_mean[None, ...]
         # Per-draw outer products (column vector x row vector), averaged
         # over the sample axis.
         outer_products = math_ops.matmul(deviations[..., None],
                                          deviations[..., None, :])
         empirical_cov = math_ops.reduce_mean(outer_products, 0)
         empirical_var = array_ops.matrix_diag_part(empirical_cov)
         empirical_stddev = math_ops.sqrt(empirical_var)
         fetched = session.run([
             empirical_mean,
             empirical_cov,
             empirical_var,
             empirical_stddev,
             dist.mean(),
             dist.covariance(),
             dist.variance(),
             dist.stddev(),
         ])
         (mean_mc, cov_mc, var_mc, stddev_mc,
          mean_exact, cov_exact, var_exact, stddev_exact) = fetched
         self.assertAllClose(mean_mc, mean_exact, atol=0., rtol=0.04)
         self.assertAllClose(cov_mc, cov_exact, atol=0., rtol=0.06)
         self.assertAllClose(var_mc, var_exact, atol=0., rtol=0.03)
         self.assertAllClose(stddev_mc, stddev_exact, atol=0., rtol=0.02)
示例#20
0
 def testAlphaProperty(self):
     """The alpha property keeps the constructor's values and [1, 3] shape."""
     concentration = [[1., 2, 3]]
     with self.test_session():
         alpha_tensor = dirichlet_lib.Dirichlet(concentration).alpha
         self.assertEqual([1, 3], alpha_tensor.get_shape())
         self.assertAllClose(concentration, alpha_tensor.eval())