Ejemplo n.º 1
0
 def test_draw(self):
     d = Dirichlet(np.array([1., 2., 1.]))
     self.assertEqual(d.draw().shape, (1, 3))
     sample = d.draw(1000)
     self.assertTrue(
         np.allclose(np.mean(sample, axis=0), [0.25, 0.5, 0.25], 1e-1, 1e-1)
     )
Ejemplo n.º 2
0
 def test_var(self):
     d = Dirichlet(np.array([1, 2, 3]))
     var = np.array(
         [[5, -2, -3],
          [-2, 8, -6],
          [-3, -6, 9]]
     ) / 36 / 7
     self.assertTrue(np.allclose(var, d.var))
Ejemplo n.º 3
0
 def test_bayes(self):
     mu = Dirichlet(concentration=np.ones(3))
     model = Categorical(prob=mu)
     self.assertEqual(
         repr(model),
         "Categorical(prob=\nDirichlet(concentration=[ 1.  1.  1.])\n)")
     model.bayes(np.array([[1., 0., 0.], [1., 0., 0.], [0., 1., 0.]]))
     self.assertEqual(
         repr(model),
         "Categorical(prob=\nDirichlet(concentration=[ 3.  2.  1.])\n)")
Ejemplo n.º 4
0
 def test_init(self):
     d = Dirichlet(np.ones(3))
     self.assertTrue((d.concentration == 1).all())
     self.assertEqual(d.size, 3)
     self.assertEqual(d.ndim, 1)
     self.assertEqual(d.shape, (3,))
Ejemplo n.º 5
0
 def test_pdf(self):
     d = Dirichlet(np.ones(4))
     self.assertTrue((d.pdf(np.random.uniform(size=(5, 4))) == 6).all())
Ejemplo n.º 6
0
 def test_mean(self):
     d = Dirichlet(np.ones(4))
     self.assertTrue((d.mean == 0.25).all())
Ejemplo n.º 7
0
 def test_repr(self):
     d = Dirichlet(np.ones(3))
     self.assertEqual(repr(d), "Dirichlet(concentration=[ 1.  1.  1.])")
Ejemplo n.º 8
0
 def test_init(self):
     d = Dirichlet(np.ones(3))
     self.assertTrue((d.concentration == 1).all())
     self.assertEqual(d.n_classes, 3)