def test_MatrixNormal():
    """Check MatrixNormal's sample space, densities, and argument validation."""
    # Concrete case: 1x2 location, 1x1 row scale [4], 2x2 column scale.
    M = MatrixNormal('M', [[5, 6]], [4], [[2, 1], [1, 2]])
    assert M.pspace.distribution.set == MatrixSet(1, 2, S.Reals)
    X = MatrixSymbol('X', 1, 2)
    # inverse of the column scale [[2, 1], [1, 2]]
    col_scale_inv = Matrix([[S(2)/3, -S(1)/3], [-S(1)/3, S(2)/3]])
    term1 = exp(-Trace(col_scale_inv*(Matrix([[-5], [-6]]) + X.T)
                       *Matrix([[1/4]])*(Matrix([[-5, -6]]) + X))/2)
    # NOTE(review): the textbook matrix-normal normaliser is
    # (2*pi)**(n*p/2) * det(U)**(p/2) * det(V)**(n/2), i.e. 8*sqrt(3)*pi here;
    # the 24*pi below (and the factor 4 / exponents -d, -n in exprd further
    # down) appear to pin the implementation's current output -- confirm.
    assert density(M)(X).doit() == term1/(24*pi)
    assert density(M)([[7, 8]]).doit() == exp(-S(1)/3)/(24*pi)

    # Symbolic case: n x d location with n x n row and d x d column scales.
    d, n = symbols('d n', positive=True, integer=True)
    col_scale = MatrixSymbol('SM2', d, d)
    row_scale = MatrixSymbol('SM1', n, n)
    location = MatrixSymbol('LM', n, d)
    Y = MatrixSymbol('Y', n, d)
    M = MatrixNormal('M', location, row_scale, col_scale)
    exprd = 4*(2*pi)**(-d*n/2)*exp(-Trace(col_scale**(-1)*(-location.T + Y.T)
        *row_scale**(-1)*(-location + Y))/2)*Determinant(row_scale)**(-d)\
        *Determinant(col_scale)**(-n)
    assert density(M)(Y).doit() == exprd
    raises(ValueError, lambda: density(M)(1))

    # (location, scale_matrix_1, scale_matrix_2) triples MatrixNormal must
    # reject: non-symmetric scales, ragged rows, and incompatible shapes.
    invalid_params = [
        ([1, 2], [[1, 0], [0, 1]], [[1, 0], [2, 1]]),
        ([1, 2], [[1, 0], [2, 1]], [[1, 0], [0, 1]]),
        ([1, 2], [[1, 0], [0, 1]], [[1, 0], [0, 1]]),
        ([1, 2], [[1, 0], [2]], [[1, 0], [0, 1]]),
        ([1, 2], [[1, 0], [2, 1]], [[1, 0], [0]]),
        ([[1, 2]], [[1, 0], [0, 1]], [[1, 0]]),
        ([[1, 2]], [1], [[1, 0]]),
    ]
    for params in invalid_params:
        # raises() invokes the lambda immediately, so late binding is safe
        raises(ValueError, lambda: MatrixNormal('M', *params))
def test_sample_seed():
    """Seeded MatrixNormal sampling is reproducible for each backend.

    For every sampling library that can be imported, two runs with the same
    seed must agree element-wise, and a run with a different seed must differ
    in every entry. Libraries that are not installed, or whose sampling raises
    NotImplementedError, are skipped.
    """
    X = MatrixNormal('M', [[5, 6], [3, 4]], [[1, 0], [0, 1]], [[2, 1], [1, 2]])
    numsamples = 10

    libraries = ['scipy', 'numpy', 'pymc3']
    for lib in libraries:
        if not import_module(lib):
            continue
        try:
            s0 = list(sample(X, numsamples=numsamples, library=lib, seed=0))
            s1 = list(sample(X, numsamples=numsamples, library=lib, seed=0))
            s2 = list(sample(X, numsamples=numsamples, library=lib, seed=1))
        except NotImplementedError:
            # this backend cannot sample MatrixNormal; nothing to compare
            continue
        # zip() truncates silently, so pin the lengths explicitly
        assert len(s0) == len(s1) == len(s2) == numsamples
        assert all((a == b).all() for a, b in zip(s0, s1))
        assert all((b != c).all() for b, c in zip(s1, s2))
def test_sample_pymc3():
    """Draw samples from matrix distributions through the PyMC3 backend.

    Every sampled matrix must lie in its distribution's set, and sampling a
    distribution the PyMC3 backend does not support (MatrixGamma) must raise
    NotImplementedError. Skipped when PyMC3 is not installed.
    """
    distribs_pymc3 = [
        MatrixNormal('M', [[5, 6], [3, 4]], [[1, 0], [0, 1]], [[2, 1], [1, 2]]),
        Wishart('W', 7, [[2, 1], [1, 2]])
    ]
    size = 3
    pymc3 = import_module('pymc3')
    if not pymc3:
        skip('PyMC3 is not installed. Abort tests for _sample_pymc3.')
    else:
        # TODO: restore the unwrapped tests once sampling stops emitting
        # UserWarning.
        with ignore_warnings(UserWarning):
            for X in distribs_pymc3:
                # sample() returns an iterator; next() takes its first draw
                samps = next(sample(X, size=size, library='pymc3'))
                for sam in samps:
                    assert Matrix(sam) in X.pspace.distribution.set
            M = MatrixGamma('M', 1, 2, [[1, 0], [0, 1]])
            # Request pymc3 explicitly. Without it this re-checks the default
            # backend (already covered by the SciPy test) and presumably
            # depends on that backend being installed.
            raises(NotImplementedError,
                   lambda: next(sample(M, size=3, library='pymc3')))
def test_sample_scipy():
    """Draw samples from matrix distributions through the default backend.

    Each sampled matrix must belong to its distribution's set, and a
    distribution without sampling support (MatrixGamma) must raise
    NotImplementedError. Skipped when SciPy is not installed.
    """
    distribs_scipy = [
        MatrixNormal('M', [[5, 6]], [4], [[2, 1], [1, 2]]),
        Wishart('W', 5, [[1, 0], [0, 1]])
    ]

    size = 5
    scipy = import_module('scipy')
    if scipy:
        # TODO: drop ignore_warnings once sampling no longer emits UserWarning
        with ignore_warnings(UserWarning):
            for dist in distribs_scipy:
                drawn = next(sample(dist, size=size))
                support = dist.pspace.distribution.set
                assert all(Matrix(mat) in support for mat in drawn)
            unsupported = MatrixGamma('M', 1, 2, [[1, 0], [0, 1]])
            raises(NotImplementedError,
                   lambda: next(sample(unsupported, size=3)))
    else:
        skip('Scipy not installed. Abort tests for _sample_scipy.')