def test_dirichlet(self):
    """Compare seeded dirichlet draws against stored reference values.

    Seeds the generator with ``self.seed`` / ``self.brng``, draws a
    ``(3, 2)`` batch from a two-component Dirichlet distribution and checks
    the result against recorded values to 10 decimal places.
    """
    rnd.seed(self.seed, self.brng)
    concentration = np.array([51.72840233779265162, 39.74494232180943953])
    samples = rnd.dirichlet(concentration, size=(3, 2))
    # Reference draws for this seed/brng pair; the result has shape
    # (3, 2, 2) because the component axis (len(concentration) == 2) is
    # appended after ``size``. Each innermost pair sums to 1.
    expected = np.array([
        [[0.6332947001908874, 0.36670529980911254],
         [0.5376828907571894, 0.4623171092428107]],
        [[0.6835615930093024, 0.3164384069906976],
         [0.5452378139016114, 0.45476218609838875]],
        [[0.6498494402738553, 0.3501505597261446],
         [0.5622024400324822, 0.43779755996751785]],
    ])
    np.testing.assert_array_almost_equal(samples, expected, decimal=10)
def test_dirichlet_size(self):
    """Check that dirichlet accepts the supported ``size`` spellings (gh-3173).

    Integer scalars (NumPy and Python integer types), lists, tuples and
    integer arrays are all expected to work as ``size``; the component axis
    of length ``len(p)`` is appended to the requested shape. A float size
    must be rejected with ``TypeError``.
    """
    # gh-3173
    p = np.array([51.72840233779265162, 39.74494232180943953])
    # Exercise several distinct integer scalar types rather than repeating
    # a single one.
    for scalar_size in (np.uint32(1), np.int32(1), np.int64(1), 1):
        assert_equal(rnd.dirichlet(p, scalar_size).shape, (1, 2))
    assert_equal(rnd.dirichlet(p, [2, 2]).shape, (2, 2, 2))
    assert_equal(rnd.dirichlet(p, (2, 2)).shape, (2, 2, 2))
    assert_equal(rnd.dirichlet(p, np.array((2, 2))).shape, (2, 2, 2))
    # Use the builtin ``float``: the ``np.float`` alias was deprecated in
    # NumPy 1.20 and removed in 1.24, where it raises AttributeError before
    # the TypeError check can run.
    assert_raises(TypeError, rnd.dirichlet, p, float(1))