예제 #1
0
 def test_dirichlet(self):
     """Regression-check seeded ``rnd.dirichlet`` output against fixed values.

     Reseeds with ``self.seed`` / ``self.brng``, draws a (3, 2) batch of
     samples from a two-component Dirichlet distribution, and compares the
     result to hard-coded reference values to 10 decimal places.
     """
     # Reference output; building it before seeding does not touch the RNG.
     expected = np.array([
         [[0.6332947001908874, 0.36670529980911254],
          [0.5376828907571894, 0.4623171092428107]],
         [[0.6835615930093024, 0.3164384069906976],
          [0.5452378139016114, 0.45476218609838875]],
         [[0.6498494402738553, 0.3501505597261446],
          [0.5622024400324822, 0.43779755996751785]],
     ])
     concentration = np.array([51.72840233779265162, 39.74494232180943953])

     rnd.seed(self.seed, self.brng)
     sample = rnd.dirichlet(concentration, size=(3, 2))

     np.testing.assert_array_almost_equal(sample, expected, decimal=10)
예제 #2
0
    def test_dirichlet_size(self):
        """Check that ``rnd.dirichlet`` honours integer-like ``size`` values.

        Regression test for gh-3173. A scalar integer ``size`` (NumPy
        integer scalar or plain ``int``) yields shape ``(size, len(p))``;
        a list, tuple, or integer array ``size`` yields ``(*size, len(p))``;
        a float ``size`` must raise ``TypeError``.
        """
        # gh-3173
        p = np.array([51.72840233779265162, 39.74494232180943953])

        # Several distinct integer-like spellings of a scalar size of 1.
        assert_equal(rnd.dirichlet(p, np.uint32(1)).shape, (1, 2))
        assert_equal(rnd.dirichlet(p, np.int64(1)).shape, (1, 2))
        assert_equal(rnd.dirichlet(p, 1).shape, (1, 2))

        # Sequence-like sizes are prepended to the component axis.
        assert_equal(rnd.dirichlet(p, [2, 2]).shape, (2, 2, 2))
        assert_equal(rnd.dirichlet(p, (2, 2)).shape, (2, 2, 2))
        assert_equal(rnd.dirichlet(p, np.array((2, 2))).shape, (2, 2, 2))

        # A float size is not a valid shape. ``np.float`` was only an alias of
        # the builtin and no longer exists in NumPy >= 1.24, so use ``float``.
        assert_raises(TypeError, rnd.dirichlet, p, float(1))