def test_MatrixNormal():
    """Check ``MatrixNormal``'s support set, density and argument validation.

    Three groups of checks are made:

    * a concrete 1x2 distribution: its set and its density, both symbolic
      in ``X`` and evaluated at ``[[7, 8]]``;
    * a fully symbolic ``n x d`` distribution: its density in ``Y``;
    * calls that must raise ``ValueError``: a non-matrix density argument
      and several rejected constructor argument combinations.
    """
    M = MatrixNormal('M', [[5, 6]], [4], [[2, 1], [1, 2]])
    assert M.pspace.distribution.set == MatrixSet(1, 2, S.Reals)

    X = MatrixSymbol('X', 1, 2)
    # Keep every matrix entry an exact SymPy rational.  A bare ``1/4`` is a
    # Python float (or 0 under integer division), which makes the structural
    # ``==`` below depend on Float-vs-Rational equality.
    term1 = exp(-Trace(Matrix([[S(2)/3, -S(1)/3], [-S(1)/3, S(2)/3]])*(
        Matrix([[-5], [-6]]) + X.T)*Matrix([[S(1)/4]])*(
        Matrix([[-5, -6]]) + X))/2)
    assert density(M)(X).doit() == term1/(24*pi)
    assert density(M)([[7, 8]]).doit() == exp(-S(1)/3)/(24*pi)

    # Symbolic shapes: location is n x d, the two scale matrices are n x n
    # and d x d.
    d, n = symbols('d n', positive=True, integer=True)
    SM2 = MatrixSymbol('SM2', d, d)
    SM1 = MatrixSymbol('SM1', n, n)
    LM = MatrixSymbol('LM', n, d)
    Y = MatrixSymbol('Y', n, d)
    M = MatrixNormal('M', LM, SM1, SM2)
    # NOTE(review): the leading factor 4 and the determinant exponents -d / -n
    # differ from the textbook matrix-normal normaliser
    # (2*pi)**(-n*d/2) * det(SM1)**(-d/2) * det(SM2)**(-n/2).  The expected
    # value presumably mirrors the current implementation -- confirm before
    # relying on it.
    exprd = 4*(2*pi)**(-d*n/2)*exp(
        -Trace(SM2**(-1)*(-LM.T + Y.T)*SM1**(-1)*(-LM + Y))/2
    )*Determinant(SM1)**(-d)*Determinant(SM2)**(-n)
    assert density(M)(Y).doit() == exprd

    # The density must reject an argument that is not matrix-like.
    raises(ValueError, lambda: density(M)(1))

    # Constructor calls that must be rejected.  One call per line so that a
    # failure points at the exact offending combination.
    raises(ValueError, lambda: MatrixNormal(
        'M', [1, 2], [[1, 0], [0, 1]], [[1, 0], [2, 1]]))
    raises(ValueError, lambda: MatrixNormal(
        'M', [1, 2], [[1, 0], [2, 1]], [[1, 0], [0, 1]]))
    raises(ValueError, lambda: MatrixNormal(
        'M', [1, 2], [[1, 0], [0, 1]], [[1, 0], [0, 1]]))
    raises(ValueError, lambda: MatrixNormal(
        'M', [1, 2], [[1, 0], [2]], [[1, 0], [0, 1]]))
    raises(ValueError, lambda: MatrixNormal(
        'M', [1, 2], [[1, 0], [2, 1]], [[1, 0], [0]]))
    raises(ValueError, lambda: MatrixNormal(
        'M', [[1, 2]], [[1, 0], [0, 1]], [[1, 0]]))
    raises(ValueError, lambda: MatrixNormal(
        'M', [[1, 2]], [1], [[1, 0]]))
def test_sample_seed():
    """Check that the ``seed`` argument of ``sample`` controls the draws.

    For each backend that can be imported, ten samples are drawn three
    times: twice with seed 0 and once with seed 1.  The two seed-0 runs must
    agree in every entry, and the seed-0 and seed-1 runs must differ in
    every entry.  A backend that raises ``NotImplementedError`` is passed
    over.
    """
    X = MatrixNormal('M', [[5, 6], [3, 4]], [[1, 0], [0, 1]], [[2, 1], [1, 2]])
    draw_indices = range(10)

    for lib in ('scipy', 'numpy', 'pymc3'):
        try:
            if not import_module(lib):
                # Backend not available: nothing to check for it.
                continue
            first_run = list(sample(X, numsamples=10, library=lib, seed=0))
            repeat_run = list(sample(X, numsamples=10, library=lib, seed=0))
            other_seed = list(sample(X, numsamples=10, library=lib, seed=1))
            # Same seed: identical draws.
            assert all((first_run[i] == repeat_run[i]).all()
                       for i in draw_indices)
            # Different seed: every entry of every draw differs.
            assert all((repeat_run[i] != other_seed[i]).all()
                       for i in draw_indices)
        except NotImplementedError:
            continue
def test_sample_pymc3():
    """Sample matrix distributions through PyMC3 and check their support.

    Every sample of each distribution must belong to that distribution's
    set.  Sampling a ``MatrixGamma`` must raise ``NotImplementedError``.
    The test is skipped when PyMC3 cannot be imported.

    NOTE(review): another ``test_sample_pymc3`` is defined later in this
    file.  If both are at module level the later one replaces this one --
    confirm which version is meant to run.
    """
    candidates = [
        MatrixNormal('M', [[5, 6], [3, 4]], [[1, 0], [0, 1]], [[2, 1], [1, 2]]),
        Wishart('W', 7, [[2, 1], [1, 2]]),
    ]
    draws = 3

    if not import_module('pymc3'):
        skip('PyMC3 is not installed. Abort tests for _sample_pymc3.')
        # skip() is expected to raise; the return only guarantees that the
        # checks below never run without the backend.
        return

    for X in candidates:
        drawn = sample(X, size=draws, library='pymc3')
        assert all(Matrix(sam) in X.pspace.distribution.set for sam in drawn)

    M = MatrixGamma('M', 1, 2, [[1, 0], [0, 1]])
    raises(NotImplementedError, lambda: sample(M, size=3))
def test_sample_scipy():
    """Sample matrix distributions with the default backend and check support.

    Every sample of each distribution must belong to that distribution's
    set.  Sampling a ``MatrixGamma`` must raise ``NotImplementedError``.
    The test is skipped when SciPy cannot be imported.

    NOTE(review): another ``test_sample_scipy`` is defined later in this
    file.  If both are at module level the later one replaces this one --
    confirm which version is meant to run.
    """
    candidates = [
        MatrixNormal('M', [[5, 6]], [4], [[2, 1], [1, 2]]),
        Wishart('W', 5, [[1, 0], [0, 1]]),
    ]
    draws = 5

    if not import_module('scipy'):
        skip('Scipy not installed. Abort tests for _sample_scipy.')
        # skip() is expected to raise; the return only guarantees that the
        # checks below never run without the backend.
        return

    for X in candidates:
        drawn = sample(X, size=draws)
        assert all(Matrix(sam) in X.pspace.distribution.set for sam in drawn)

    M = MatrixGamma('M', 1, 2, [[1, 0], [0, 1]])
    raises(NotImplementedError, lambda: sample(M, size=3))
def test_sample_pymc3():
    """Sample matrix distributions through PyMC3 and check their support.

    ``sample`` is consumed with ``next``; every element of the returned
    batch must belong to the distribution's set.  Sampling a
    ``MatrixGamma`` must raise ``NotImplementedError``.  ``UserWarning`` is
    ignored while the checks run.  The test is skipped when PyMC3 cannot be
    imported.

    NOTE(review): an earlier ``test_sample_pymc3`` appears above in this
    file.  If both are at module level this definition replaces it --
    confirm the earlier one can be removed.
    """
    candidates = [
        MatrixNormal('M', [[5, 6], [3, 4]], [[1, 0], [0, 1]], [[2, 1], [1, 2]]),
        Wishart('W', 7, [[2, 1], [1, 2]]),
    ]
    draws = 3

    if not import_module('pymc3'):
        skip('PyMC3 is not installed. Abort tests for _sample_pymc3.')
        # skip() is expected to raise; the return only guarantees that the
        # checks below never run without the backend.
        return

    # TODO: Restore tests once warnings are removed
    with ignore_warnings(UserWarning):
        for X in candidates:
            batch = next(sample(X, size=draws, library='pymc3'))
            assert all(Matrix(sam) in X.pspace.distribution.set
                       for sam in batch)

        M = MatrixGamma('M', 1, 2, [[1, 0], [0, 1]])
        raises(NotImplementedError, lambda: next(sample(M, size=3)))
def test_sample_scipy():
    """Sample matrix distributions with the default backend and check support.

    ``sample`` is consumed with ``next``; every element of the returned
    batch must belong to the distribution's set.  Sampling a
    ``MatrixGamma`` must raise ``NotImplementedError``.  ``UserWarning`` is
    ignored while the checks run.  The test is skipped when SciPy cannot be
    imported.

    NOTE(review): an earlier ``test_sample_scipy`` appears above in this
    file.  If both are at module level this definition replaces it --
    confirm the earlier one can be removed.
    """
    candidates = [
        MatrixNormal('M', [[5, 6]], [4], [[2, 1], [1, 2]]),
        Wishart('W', 5, [[1, 0], [0, 1]]),
    ]
    draws = 5

    if not import_module('scipy'):
        skip('Scipy not installed. Abort tests for _sample_scipy.')
        # skip() is expected to raise; the return only guarantees that the
        # checks below never run without the backend.
        return

    # TODO: Restore tests once warnings are removed
    with ignore_warnings(UserWarning):
        for X in candidates:
            batch = next(sample(X, size=draws))
            assert all(Matrix(sam) in X.pspace.distribution.set
                       for sam in batch)

        M = MatrixGamma('M', 1, 2, [[1, 0], [0, 1]])
        raises(NotImplementedError, lambda: next(sample(M, size=3)))