def test_mean_x(self):
    alpha = (1.0, 2.0, 3.0, 4.0)
    xx = (2.0, 2.0, 2.0, 2.0)
    m = Dirichlet(alpha).mean_x(xx)
    self.assertEqual(m, 2.0)

    xx2 = (2.0, 2.0, 2.0, 2.0, 2.0)
    self.assertRaises(ValueError, Dirichlet(alpha).mean_x, xx2)

    alpha = (1.0, 1.0, 1.0, 1.0)
    xx = (2.0, 3.0, 4.0, 3.0)
    m = Dirichlet(alpha).mean_x(xx)
    self.assertEqual(m, 3.0)

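# A minimal sketch (not from the source) of the quantity the test above appears
# to exercise: assuming mean_x(xx) returns E[sum_i xx[i] * p[i]] for
# p ~ Dirichlet(alpha), linearity of expectation and E[p[i]] = alpha[i] / alpha_0
# reduce it to the alpha-weighted average of xx. The helper name is hypothetical.
def expected_mean_x(alpha, xx):
    alpha_0 = sum(alpha)
    return sum(a * x for a, x in zip(alpha, xx)) / alpha_0


# expected_mean_x((1.0, 2.0, 3.0, 4.0), (2.0, 2.0, 2.0, 2.0))  # -> 2.0
# expected_mean_x((1.0, 1.0, 1.0, 1.0), (2.0, 3.0, 4.0, 3.0))  # -> 3.0
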
def test_variance_x(self):
    alpha = (1.0, 1.0, 1.0, 1.0)
    xx = (2.0, 2.0, 2.0, 2.0)
    v = Dirichlet(alpha).variance_x(xx)
    self.assertAlmostEqual(v, 0.0)

    alpha = (1.0, 2.0, 3.0, 4.0)
    xx = (2.0, 0.0, 1.0, 10.0)
    v = Dirichlet(alpha).variance_x(xx)
    # print(v)
    # TODO: Don't actually know if this is correct

    xx2 = (2.0, 2.0, 2.0, 2.0, 2.0)
    self.assertRaises(ValueError, Dirichlet(alpha).variance_x, xx2)

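# A hedged cross-check (not from the source) for the TODO above: if variance_x(xx)
# is Var(sum_i xx[i] * p[i]) for p ~ Dirichlet(alpha), it should equal
# xx @ C @ xx, where C is the Dirichlet covariance matrix exercised in
# test_covariance below. The helper name is hypothetical.
def variance_x_from_covariance(d, xx):
    import numpy as np
    x = np.asarray(xx, dtype=np.float64)
    return float(x @ d.covariance() @ x)
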
def test_relative_entropy(self):
    alpha = (2.0, 10.0, 1.0, 1.0)
    d = Dirichlet(alpha)
    pvec = (0.1, 0.2, 0.3, 0.4)

    rent = d.mean_relative_entropy(pvec)
    vrent = d.variance_relative_entropy(pvec)
    low, high = d.interval_relative_entropy(pvec, 0.95)
    # print()
    # print('> ', rent, vrent, low, high)

    # This test can fail randomly: the precision from a few thousand samples
    # is low, so the sample count was increased from 1000 to 2000.
    samples = 2000
    sent = zeros((samples, ), float64)

    for s in range(samples):
        post = d.sample()
        e = -entropy(post)
        for k in range(4):
            e += -post[k] * log(pvec[k])
        sent[s] = e
    sent.sort()

    self.assertTrue(abs(sent.mean() - rent) < 4.0 * sqrt(vrent))
    self.assertAlmostEqual(sent.std(), sqrt(vrent), 1)
    self.assertTrue(abs(low - sent[int(samples * 0.025)]) < 0.2)
    self.assertTrue(abs(high - sent[int(samples * 0.975)]) < 0.2)

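# Sketch (not from the source) of the identity used in the sampling loop above:
# the relative entropy D(post || pvec) = sum_k post[k] * log(post[k] / pvec[k])
# rewrites as -H(post) - sum_k post[k] * log(pvec[k]), which is exactly what the
# loop accumulates. A direct check, assuming both vectors are strictly positive:
def relative_entropy(post, pvec):
    import numpy as np
    post = np.asarray(post, dtype=np.float64)
    pvec = np.asarray(pvec, dtype=np.float64)
    return float(np.sum(post * np.log(post / pvec)))
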
def test_covariance(self):
    alpha = ones((4, ), float64)
    d = Dirichlet(alpha)
    cv = d.covariance()
    self.assertEqual(cv.shape, (4, 4))
    self.assertAlmostEqual(cv[0, 0], 1.0 * (1.0 - 1.0 / 4.0) / (4.0 * 5.0))
    self.assertAlmostEqual(cv[0, 1], -1 / (4.0 * 4.0 * 5.0))

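# The expected values above match the standard Dirichlet covariance formulas
# (a sketch, not from the source): with alpha_0 = sum(alpha),
#   Var(p_i)      = alpha_i * (alpha_0 - alpha_i) / (alpha_0**2 * (alpha_0 + 1))
#   Cov(p_i, p_j) = -alpha_i * alpha_j / (alpha_0**2 * (alpha_0 + 1))   for i != j
# For alpha = (1, 1, 1, 1), alpha_0 = 4, these give 3/80 and -1/80, which are the
# literals in the assertions. The helper name is hypothetical.
def dirichlet_covariance_entry(alpha, i, j):
    alpha_0 = sum(alpha)
    if i == j:
        return alpha[i] * (alpha_0 - alpha[i]) / (alpha_0 ** 2 * (alpha_0 + 1.0))
    return -alpha[i] * alpha[j] / (alpha_0 ** 2 * (alpha_0 + 1.0))
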
def test_init(self):
    Dirichlet((1, 1, 1, 1))

def do_test(self, alpha, samples=1000):
    ent = zeros((samples, ), float64)
    # alpha = ones( ( K,), Float64 ) * A/K
    # pt = zeros( (len(alpha) ,), Float64)
    d = Dirichlet(alpha)
    for s in range(samples):
        p = d.sample()
        # print(p)
        # pt += p
        ent[s] = entropy(p)
    # print(pt/samples)

    m = mean(ent)
    v = var(ent)

    dm = d.mean_entropy()
    dv = d.variance_entropy()

    # print(alpha, ':', m, v, dm, dv)
    error = 4.0 * sqrt(v / samples)
    self.assertTrue(abs(m - dm) < error)
    self.assertTrue(abs(v - dv) < error)  # dodgy error estimate

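# Illustrative only (the actual invocations are not shown in this section): a
# helper like do_test would typically be driven from a test method with a few
# concentration vectors, both symmetric and asymmetric, to compare the sampled
# entropy statistics against mean_entropy() and variance_entropy(). The method
# name and alpha values below are hypothetical examples, not from the source.
def test_entropy_sampling(self):
    self.do_test((2.0, 1.0))
    self.do_test((8.0, 8.0, 8.0, 8.0))
    self.do_test((1.0, 2.0, 3.0, 4.0, 5.0))
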
def test_mean(self):
    alpha = ones((10, ), float64) * 23.0
    d = Dirichlet(alpha)
    m = d.mean()
    self.assertAlmostEqual(m[2], 1.0 / 10)
    self.assertAlmostEqual(sum(m), 1.0)