def test_Trace_MatAdd_doit():
    """Trace of a MatAdd mixing explicit and symbolic terms evaluates on doit()."""
    # See issue sympy/sympy#9028
    dense = ImmutableMatrix([[1, 2, 3]] * 3)
    sym = MatrixSymbol('Y', 3, 3)
    total = MatAdd(dense, 2 * dense, sym, -3 * sym)

    wrapped = Trace(total)
    # Constructing the Trace must leave its argument untouched.
    assert wrapped.arg == total

    # The explicit terms contribute 3 * (1 + 2 + 3) = 18; the symbolic
    # ones collapse to -2 * Trace(Y).
    evaluated = Trace(total).doit()
    assert evaluated == 18 - 2 * Trace(sym)
def test_invariants():
    """Each matrix expression equals its class re-applied to its own args."""
    # n, m and l are taken from module scope (presumably symbolic
    # dimensions -- confirm against the top of the file).
    A = MatrixSymbol('A', n, m)
    B = MatrixSymbol('B', m, l)
    X = MatrixSymbol('X', n, n)

    expressions = [
        Identity(n),
        ZeroMatrix(m, n),
        MatMul(A, B),
        MatAdd(A, A),
        Transpose(A),
        Adjoint(A),
        Inverse(X),
        MatPow(X, 2),
        MatPow(X, -1),
        MatPow(X, 0),
    ]

    for expr in expressions:
        rebuilt = expr.__class__(*expr.args)
        assert expr == rebuilt
def test_Trace_doit_deep_False():
    """With deep=False, Trace.doit() leaves compound arguments unevaluated."""
    mat = Matrix([[1, 2], [3, 4]])

    # An explicit matrix is still traced directly: 1 + 4 == 5.
    assert Trace(mat).doit(deep=False) == 5

    # For compound expressions the result must keep the original
    # argument intact.
    power = MatPow(mat, 2)
    shallow = Trace(power).doit(deep=False)
    assert shallow.arg == power

    summed = MatAdd(mat, 2 * mat)
    shallow = Trace(summed).doit(deep=False)
    assert shallow.arg == summed

    product = MatMul(mat, 2 * mat)
    shallow = Trace(product).doit(deep=False)
    assert shallow.arg == product
def test_doit_args():
    """MatAdd.doit() evaluates nested matrix expressions among its args."""
    first = ImmutableMatrix([[1, 2], [3, 4]])
    second = ImmutableMatrix([[2, 3], [4, 5]])

    with_power = MatAdd(first, MatPow(second, 2))
    assert with_power.doit() == first + second**2

    with_product = MatAdd(first, MatMul(first, second))
    assert with_product.doit() == first + first * second

    doubled = MatAdd(first, first)
    assert doubled.doit(deep=False) == 2 * first

    # X and Y are taken from module scope (presumably MatrixSymbols --
    # confirm against the top of the file).
    mixed = MatAdd(first, X, MatMul(first, second), Y,
                   MatAdd(2 * first, second))
    expected = MatAdd(X, Y, 3 * first + first * second + second)
    assert mixed.doit() == expected
def test_matadd_simplify():
    """simplify() reaches into the explicit-matrix entries of a MatAdd."""
    A = MatrixSymbol('A', 1, 1)

    # x is taken from module scope.
    unsimplified = MatAdd(A, ImmutableMatrix([[sin(x)**2 + cos(x)**2]]))
    expected = MatAdd(A, ImmutableMatrix([[1]]))

    assert simplify(unsimplified) == expected
def test_Trace_A_plus_B():
    """Trace of a sum splits into the sum of traces once evaluated."""
    # A and B are taken from module scope.
    total = A + B
    split = Trace(A) + Trace(B)

    assert trace(total) == split
    # The unevaluated Trace keeps the sum as its argument.
    assert Trace(total).arg == MatAdd(A, B)
    assert Trace(total).doit() == split
def test_matadd_of_matrices():
    """MatAdd of explicit matrices evaluates to a single ImmutableMatrix."""
    summed = MatAdd(eye(2), 4 * eye(2), eye(2))
    expected = ImmutableMatrix(6 * eye(2))
    assert summed.doit() == expected
def test_matadd_sympify():
    """The first argument stored by MatAdd is a Basic instance."""
    summed = MatAdd(eye(1), eye(1))
    first_arg = summed.args[0]
    assert isinstance(first_arg, Basic)
def test_matadd():
    """Adding X and eye(1) raises ShapeError unless check=False is passed."""
    # X is taken from module scope.
    with pytest.raises(ShapeError):
        X + eye(1)

    # With check=False the same operands must construct without raising.
    MatAdd(X, eye(1), check=False)
def test_sort_key():
    """doit() on MatAdd(Y, X) yields args in the order (X, Y)."""
    # X and Y are taken from module scope.
    evaluated = MatAdd(Y, X).doit()
    assert evaluated.args == (X, Y)
def test_doit_nested_MatrixExpr():
    """MatPow.doit() evaluates a nested MatMul or MatAdd base first."""
    left = ImmutableMatrix([[1, 2], [3, 4]])
    right = ImmutableMatrix([[2, 3], [4, 5]])

    squared_product = MatPow(MatMul(left, right), 2)
    assert squared_product.doit() == (left * right)**2

    squared_sum = MatPow(MatAdd(left, right), 2)
    assert squared_sum.doit() == (left + right)**2