Пример #1
0
def test_MatMul_postprocessor():
    """Check how ``Mul`` handles matrix operands.

    Covers products with zero, a scalar multiplied into an explicit matrix,
    and explicit matrices placed next to a ``MatrixSymbol``.  The scalar
    ``x`` is not defined in this function; it is expected to come from the
    enclosing module.
    """
    dense_zero = zeros(2)
    symbolic_zero = ZeroMatrix(2, 2)
    zero_left = Mul(0, dense_zero)
    zero_right = Mul(dense_zero, 0)
    # Both argument orders must agree, and the result may be either
    # flavour of zero matrix.
    assert zero_left == zero_right
    assert zero_right in [dense_zero, symbolic_zero]

    M = Matrix([[1, 2], [3, 4]])
    Mx = Matrix([[x, 2*x], [3*x, 4*x]])
    scalar_first = Mul(x, M)
    scalar_last = Mul(M, x)
    assert scalar_first == scalar_last == Mx

    A = MatrixSymbol("A", 2, 2)
    for left, right in ((A, M), (M, A)):
        assert Mul(left, right) == MatMul(left, right)

    # Scalars should be absorbed into constant matrices
    first, second, third = Mul(x, M, A), Mul(M, x, A), Mul(M, A, x)
    assert first == second == third == MatMul(Mx, A)
    first, second, third = Mul(x, A, M), Mul(A, x, M), Mul(A, M, x)
    assert first == second == third == MatMul(A, Mx)

    # Adjacent explicit matrices are expected to combine into a power,
    # but not across the symbolic matrix.
    assert Mul(M, M) == M**2
    assert Mul(A, M, M) == MatMul(A, M**2)
    assert Mul(M, M, A) == MatMul(M**2, A)
    assert Mul(M, A, M) == MatMul(M, A, M)

    mixed = Mul(A, x, M, M, x)
    assert mixed == MatMul(A, Mx**2)
Пример #2
0
def test_invariants():
    """Check that matrix expressions can be rebuilt from their own ``args``.

    The dimension symbols ``n``, ``m`` and ``l`` are not defined in this
    function; they are expected to come from the enclosing module.
    """
    A = MatrixSymbol('A', n, m)
    B = MatrixSymbol('B', m, l)
    X = MatrixSymbol('X', n, n)
    samples = [
        Identity(n),
        ZeroMatrix(m, n),
        A,
        MatMul(A, B),
        MatAdd(A, A),
        Transpose(A),
        Adjoint(A),
        Inverse(X),
        MatPow(X, 2),
        MatPow(X, -1),
        MatPow(X, 0),
    ]
    for expr in samples:
        # Re-invoking the class on the stored arguments must give an
        # equal object.
        rebuilt = expr.__class__(*expr.args)
        assert expr == rebuilt
Пример #3
0
def test_generic_identity():
    """Check the behaviour of ``GenericIdentity``.

    Covers equality, inversion, dimension queries and its use inside
    ``MatMul``.  The dimension symbol ``n`` is not defined in this function;
    it is expected to come from the enclosing module.
    """
    ident = GenericIdentity()
    A = MatrixSymbol("A", n, n)

    assert ident == ident
    assert ident != A
    assert A != ident

    assert ident.is_Identity
    assert ident**-1 == ident

    # Every dimension query is expected to fail with TypeError.
    for attr in ("shape", "rows", "cols"):
        raises(TypeError, lambda: getattr(ident, attr))

    # An empty product equals the generic identity, and the generic
    # identity drops out of a product.
    assert MatMul() == ident
    assert MatMul(ident, A) == MatMul(A)
    # Make sure it is hashable
    hash(ident)
Пример #4
0
def test_MatMul():
    """Check shapes, zero/identity handling and scalar division of products.

    The dimension symbols ``n``, ``m`` and ``l`` are not defined in this
    function; they are expected to come from the enclosing module.
    """
    rect_a = MatrixSymbol('A', n, m)
    rect_b = MatrixSymbol('B', m, l)
    square_c = MatrixSymbol('C', n, n)

    scaled_product = 2 * rect_a * rect_b
    assert scaled_product.shape == (n, l)

    # A zero scalar factor turns the whole product into a zero matrix.
    zero_scaled = rect_a * 0 * rect_b
    assert zero_scaled == ZeroMatrix(n, l)

    # (m x l) times (n x m): the product is expected to be rejected.
    raises(ShapeError, lambda: rect_b * rect_a)
    scaled_a = 2 * rect_a
    assert scaled_a.shape == rect_a.shape

    with_zero_factor = MatMul(rect_a, ZeroMatrix(m, m), rect_b)
    assert with_zero_factor == ZeroMatrix(n, l)

    assert MatMul(square_c * Identity(n) * square_c.I) == Identity(n)

    # Dividing a matrix by a scalar works; dividing a scalar by a matrix
    # is expected to raise.
    assert rect_b / 2 == S.Half * rect_b
    raises(NotImplementedError, lambda: 2 / rect_b)

    square_a = MatrixSymbol('A', n, n)
    square_b = MatrixSymbol('B', n, n)
    assert MatMul(Identity(n), (square_a + square_b)).is_MatAdd
Пример #5
0
def test_subs():
    """Check ``subs`` on matrix symbols and matrix products.

    The dimension symbols ``n``, ``m`` and ``l`` are not defined in this
    function; they are expected to come from the enclosing module.
    """
    A = MatrixSymbol('A', n, m)
    B = MatrixSymbol('B', m, l)
    C = MatrixSymbol('C', m, l)
    product = A * B

    # Substituting a dimension symbol changes the shape.
    assert A.subs(n, m).shape == (m, m)
    assert product.subs(B, C) == A * C
    assert product.subs(l, n).is_square

    # Replace symbolic factors by explicit sparse and dense matrices.
    sparse = SparseMatrix([[1, 2], [3, 4]])
    dense = Matrix([[1, 2], [3, 4]])
    sym_c = MatrixSymbol('C', 2, 2)
    sym_d = MatrixSymbol('D', 2, 2)
    replacements = {sym_c: sparse, sym_d: dense}

    assert (sym_c * sym_d).subs(replacements) == MatMul(sparse, dense)
Пример #6
0
def test_MatMul():
    """Check shapes, zero/identity handling and scalar division of products.

    NOTE(review): ``raises`` is given code strings ("B*A", "2/B") here, so
    the local names ``A`` and ``B`` must stay in place up to those calls.
    This looks like an older SymPy testing API -- confirm against the
    targeted SymPy version before modernising.
    """
    n, m, l = symbols('n m l', integer=True)
    A = MatrixSymbol('A', n, m)
    B = MatrixSymbol('B', m, l)
    C = MatrixSymbol('C', n, n)

    scaled_product = 2 * A * B
    assert scaled_product.shape == (n, l)

    # A zero scalar factor turns the whole product into a zero matrix.
    zero_scaled = A * 0 * B
    assert zero_scaled == ZeroMatrix(n, l)

    raises(ShapeError, "B*A")
    scaled_a = 2 * A
    assert scaled_a.shape == A.shape

    with_zero_factor = MatMul(A, ZeroMatrix(m, m), B)
    assert with_zero_factor == ZeroMatrix(n, l)

    assert MatMul(C * Identity(n) * C.I) == Identity(n)

    assert B / 2 == S.Half * B
    raises(NotImplementedError, "2/B")

    square_a = MatrixSymbol('A', n, n)
    square_b = MatrixSymbol('B', n, n)
    assert MatMul(Identity(n), (square_a + square_b)).is_Add
Пример #7
0
def test_matexpr_subs():
    """Check ``subs`` on matrix symbols, products and matrix elements.

    The dimension symbols ``n``, ``m`` and ``l`` are not defined in this
    function; they are expected to come from the enclosing module.
    """
    A = MatrixSymbol('A', n, m)
    B = MatrixSymbol('B', m, l)
    C = MatrixSymbol('C', m, l)
    product = A * B

    assert A.subs(n, m).shape == (m, m)
    assert product.subs(B, C) == A * C
    assert product.subs(l, n).is_square

    W = MatrixSymbol("W", 3, 3)
    X = MatrixSymbol("X", 2, 2)
    Y = MatrixSymbol("Y", 1, 2)
    Z = MatrixSymbol("Z", n, 2)
    # no restrictions on Symbol replacement
    assert X.subs(X, Y) == Y
    # it might be better to just change the name
    new_name = Str('y')
    renamed = X.subs(Str("X"), new_name)
    assert renamed.args == (new_name, 2, 2)

    # For a matrix element, the replacement is accepted only when the
    # index is still valid for the new shape.
    in_range_cases = (
        # it's ok to introduce a wider matrix
        (lambda: X[1, 1].subs(X, W), W[1, 1]),
        # Y is 1,2 and [0, 1] is in range, so the subs succeeds
        (lambda: X[0, 1].subs(X, Y), Y[0, 1]),
        # the symbolic row count n accepts any index in the first position
        (lambda: W[2, 1].subs(W, Z), Z[2, 1]),
    )
    for substitute, expected in in_range_cases:
        assert substitute() == expected

    out_of_range_cases = (
        # X is 2,2; Y is 1,2 and Y[1, 1] is out of range
        lambda: X[1, 1].subs(X, Y),
        # Z has two columns, so the second index 2 is rejected
        lambda: W[2, 2].subs(W, Z),
        # any matrix should raise if invalid
        lambda: W[2, 2].subs(W, zeros(2)),
    )
    for invalid in out_of_range_cases:
        raises(IndexError, invalid)

    sparse = SparseMatrix([[1, 2], [3, 4]])
    dense = Matrix([[1, 2], [3, 4]])
    sym_c = MatrixSymbol('C', 2, 2)
    sym_d = MatrixSymbol('D', 2, 2)
    replacements = {sym_c: sparse, sym_d: dense}

    assert (sym_c * sym_d).subs(replacements) == MatMul(sparse, dense)
Пример #8
0
def test_matmul_simplify():
    """Check that ``simplify`` reaches the entries of an explicit matrix factor.

    The scalar ``x`` is not defined in this function; it is expected to come
    from the enclosing module.
    """
    A = MatrixSymbol('A', 1, 1)
    trig_entry = sin(x)**2 + cos(x)**2
    simplified = simplify(MatMul(A, ImmutableMatrix([[trig_entry]])))
    expected = MatMul(A, ImmutableMatrix([[1]]))
    assert simplified == expected
def test_arrayexpr_convert_array_to_diagonalized_vector():
    """Check conversion of array expressions into matrix expressions.

    Exercises ``convert_array_to_matrix``, ``_array_diag2contr_diagmatrix``
    and ``split_multiple_contractions`` on tensor products, contractions and
    diagonals.  The operands ``a``, ``b``, ``x``, ``A``, ``B``, ``C``, ``I``
    and ``I1`` are not defined in this function; they are expected to come
    from the enclosing module.
    """

    # Check matrix recognition over trivial dimensions:

    cg = ArrayTensorProduct(a, b)
    assert convert_array_to_matrix(cg) == a * b.T

    cg = ArrayTensorProduct(I1, a, b)
    assert convert_array_to_matrix(cg) == a * b.T

    # Recognize trace inside a tensor product:

    cg = ArrayContraction(ArrayTensorProduct(A, B, C), (0, 3), (1, 2))
    assert convert_array_to_matrix(cg) == Trace(A * B) * C

    # Transform diagonal operator to contraction:

    cg = ArrayDiagonal(ArrayTensorProduct(A, a), (1, 2))
    assert _array_diag2contr_diagmatrix(cg) == ArrayContraction(ArrayTensorProduct(A, OneArray(1), DiagMatrix(a)), (1, 3))
    assert convert_array_to_matrix(cg) == A * DiagMatrix(a)

    cg = ArrayDiagonal(ArrayTensorProduct(a, b), (0, 2))
    assert _array_diag2contr_diagmatrix(cg) == PermuteDims(
        ArrayContraction(ArrayTensorProduct(DiagMatrix(a), OneArray(1), b), (0, 3)), [1, 2, 0]
    )
    assert convert_array_to_matrix(cg) == b.T * DiagMatrix(a)

    cg = ArrayDiagonal(ArrayTensorProduct(A, a), (0, 2))
    assert _array_diag2contr_diagmatrix(cg) == ArrayContraction(ArrayTensorProduct(A, OneArray(1), DiagMatrix(a)), (0, 3))
    assert convert_array_to_matrix(cg) == A.T * DiagMatrix(a)

    cg = ArrayDiagonal(ArrayTensorProduct(I, x, I1), (0, 2), (3, 5))
    assert _array_diag2contr_diagmatrix(cg) == ArrayContraction(ArrayTensorProduct(I, OneArray(1), I1, DiagMatrix(x)), (0, 5))
    assert convert_array_to_matrix(cg) == DiagMatrix(x)

    cg = ArrayDiagonal(ArrayTensorProduct(I, x, A, B), (1, 2), (5, 6))
    assert _array_diag2contr_diagmatrix(cg) == ArrayDiagonal(ArrayContraction(ArrayTensorProduct(I, OneArray(1), A, B, DiagMatrix(x)), (1, 7)), (5, 6))
    # TODO: this is returning a wrong result:
    # convert_array_to_matrix(cg)

    cg = ArrayDiagonal(ArrayTensorProduct(I1, a, b), (1, 3, 5))
    assert convert_array_to_matrix(cg) == a*b.T

    cg = ArrayDiagonal(ArrayTensorProduct(I1, a, b), (1, 3))
    assert _array_diag2contr_diagmatrix(cg) == ArrayContraction(ArrayTensorProduct(OneArray(1), a, b, I1), (2, 6))
    assert convert_array_to_matrix(cg) == a*b.T

    cg = ArrayDiagonal(ArrayTensorProduct(x, I1), (1, 2))
    assert isinstance(cg, ArrayDiagonal)
    assert cg.diagonal_indices == ((1, 2),)
    assert convert_array_to_matrix(cg) == x

    cg = ArrayDiagonal(ArrayTensorProduct(x, I), (0, 2))
    assert _array_diag2contr_diagmatrix(cg) == ArrayContraction(ArrayTensorProduct(OneArray(1), I, DiagMatrix(x)), (1, 3))
    assert convert_array_to_matrix(cg).doit() == DiagMatrix(x)

    # A diagonal over a single index is expected to be rejected.
    raises(ValueError, lambda: ArrayDiagonal(x, (1,)))

    # Ignore identity matrices with contractions:

    cg = ArrayContraction(ArrayTensorProduct(I, A, I, I), (0, 2), (1, 3), (5, 7))
    assert cg.split_multiple_contractions() == cg
    assert convert_array_to_matrix(cg) == Trace(A) * I

    cg = ArrayContraction(ArrayTensorProduct(Trace(A) * I, I, I), (1, 5), (3, 4))
    assert cg.split_multiple_contractions() == cg
    assert convert_array_to_matrix(cg).doit() == Trace(A) * I

    # Add DiagMatrix when required:

    cg = ArrayContraction(ArrayTensorProduct(A, a), (1, 2))
    assert cg.split_multiple_contractions() == cg
    assert convert_array_to_matrix(cg) == A * a

    cg = ArrayContraction(ArrayTensorProduct(A, a, B), (1, 2, 4))
    assert cg.split_multiple_contractions() == ArrayContraction(ArrayTensorProduct(A, DiagMatrix(a), OneArray(1), B), (1, 2), (3, 5))
    assert convert_array_to_matrix(cg) == A * DiagMatrix(a) * B

    cg = ArrayContraction(ArrayTensorProduct(A, a, B), (0, 2, 4))
    assert cg.split_multiple_contractions() == ArrayContraction(ArrayTensorProduct(A, DiagMatrix(a), OneArray(1), B), (0, 2), (3, 5))
    assert convert_array_to_matrix(cg) == A.T * DiagMatrix(a) * B

    cg = ArrayContraction(ArrayTensorProduct(A, a, b, a.T, B), (0, 2, 4, 7, 9))
    assert cg.split_multiple_contractions() == ArrayContraction(ArrayTensorProduct(A, DiagMatrix(a), OneArray(1),
                                                DiagMatrix(b), OneArray(1), DiagMatrix(a), OneArray(1), B),
                                               (0, 2), (3, 5), (6, 9), (8, 12))
    assert convert_array_to_matrix(cg) == A.T * DiagMatrix(a) * DiagMatrix(b) * DiagMatrix(a) * B.T

    cg = ArrayContraction(ArrayTensorProduct(I1, I1, I1), (1, 2, 4))
    assert cg.split_multiple_contractions() == ArrayContraction(ArrayTensorProduct(I1, I1, OneArray(1), I1), (1, 2), (3, 5))
    assert convert_array_to_matrix(cg) == 1

    cg = ArrayContraction(ArrayTensorProduct(I, I, I, I, A), (1, 2, 8), (5, 6, 9))
    assert convert_array_to_matrix(cg.split_multiple_contractions()).doit() == A

    cg = ArrayContraction(ArrayTensorProduct(A, a, C, a, B), (1, 2, 4), (5, 6, 8))
    expected = ArrayContraction(ArrayTensorProduct(A, DiagMatrix(a), OneArray(1), C, DiagMatrix(a), OneArray(1), B), (1, 3), (2, 5), (6, 7), (8, 10))
    assert cg.split_multiple_contractions() == expected
    assert convert_array_to_matrix(cg) == A * DiagMatrix(a) * C * DiagMatrix(a) * B

    # These results are compared with dummy_eq rather than ==.
    cg = ArrayContraction(ArrayTensorProduct(a, I1, b, I1, (a.T*b).applyfunc(cos)), (1, 2, 8), (5, 6, 9))
    expected = ArrayContraction(ArrayTensorProduct(a, I1, OneArray(1), b, I1, OneArray(1), (a.T*b).applyfunc(cos)),
                                (1, 3), (2, 10), (6, 8), (7, 11))
    assert cg.split_multiple_contractions().dummy_eq(expected)
    assert convert_array_to_matrix(cg).doit().dummy_eq(MatMul(a, (a.T * b).applyfunc(cos), b.T))
Пример #10
0
def test_recognize_diagonalized_vectors():
    """Check recognition of matrix expressions in codegen array expressions.

    Exercises ``recognize_matrix_expression``, ``transform_to_product`` and
    ``split_multiple_contractions`` on ``CodegenArrayTensorProduct``,
    ``CodegenArrayContraction`` and ``CodegenArrayDiagonal`` objects built
    from the matrix symbols defined below.  The dimension symbol ``k`` is not
    defined in this function; it is expected to come from the enclosing
    module.
    """

    a = MatrixSymbol("a", k, 1)
    b = MatrixSymbol("b", k, 1)
    A = MatrixSymbol("A", k, k)
    B = MatrixSymbol("B", k, k)
    C = MatrixSymbol("C", k, k)
    X = MatrixSymbol("X", k, k)
    x = MatrixSymbol("x", k, 1)
    I1 = Identity(1)
    I = Identity(k)

    # Check matrix recognition over trivial dimensions:

    cg = CodegenArrayTensorProduct(a, b)
    assert recognize_matrix_expression(cg) == a * b.T

    cg = CodegenArrayTensorProduct(I1, a, b)
    assert recognize_matrix_expression(cg) == a * I1 * b.T

    # Recognize trace inside a tensor product:

    cg = CodegenArrayContraction(CodegenArrayTensorProduct(A, B, C), (0, 3),
                                 (1, 2))
    assert recognize_matrix_expression(cg) == Trace(A * B) * C

    # Transform diagonal operator to contraction:

    cg = CodegenArrayDiagonal(CodegenArrayTensorProduct(A, a), (1, 2))
    assert cg.transform_to_product() == CodegenArrayContraction(
        CodegenArrayTensorProduct(A, DiagMatrix(a)), (1, 2))
    assert recognize_matrix_expression(cg) == A * DiagMatrix(a)

    cg = CodegenArrayDiagonal(CodegenArrayTensorProduct(a, b), (0, 2))
    assert cg.transform_to_product() == CodegenArrayContraction(
        CodegenArrayTensorProduct(DiagMatrix(a), b), (0, 2))
    assert recognize_matrix_expression(cg).doit() == DiagMatrix(a) * b

    cg = CodegenArrayDiagonal(CodegenArrayTensorProduct(A, a), (0, 2))
    assert cg.transform_to_product() == CodegenArrayContraction(
        CodegenArrayTensorProduct(A, DiagMatrix(a)), (0, 2))
    assert recognize_matrix_expression(cg) == A.T * DiagMatrix(a)

    cg = CodegenArrayDiagonal(CodegenArrayTensorProduct(I, x, I1), (0, 2),
                              (3, 5))
    assert cg.transform_to_product() == CodegenArrayContraction(
        CodegenArrayTensorProduct(I, DiagMatrix(x), I1), (0, 2))

    cg = CodegenArrayDiagonal(CodegenArrayTensorProduct(I, x, A, B), (1, 2),
                              (5, 6))
    assert cg.transform_to_product() == CodegenArrayDiagonal(
        CodegenArrayContraction(
            CodegenArrayTensorProduct(I, DiagMatrix(x), A, B), (1, 2)), (3, 4))

    cg = CodegenArrayDiagonal(CodegenArrayTensorProduct(x, I1), (1, 2))
    assert isinstance(cg, CodegenArrayDiagonal)
    assert cg.diagonal_indices == ((1, 2), )
    assert recognize_matrix_expression(cg) == x

    cg = CodegenArrayDiagonal(CodegenArrayTensorProduct(x, I), (0, 2))
    assert cg.transform_to_product() == CodegenArrayContraction(
        CodegenArrayTensorProduct(DiagMatrix(x), I), (0, 2))
    assert recognize_matrix_expression(cg).doit() == DiagMatrix(x)

    # A diagonal over a single index is expected to give back its argument.
    cg = CodegenArrayDiagonal(x, (1, ))
    assert cg == x

    # Ignore identity matrices with contractions:

    cg = CodegenArrayContraction(CodegenArrayTensorProduct(I, A, I, I), (0, 2),
                                 (1, 3), (5, 7))
    assert cg.split_multiple_contractions() == cg
    assert recognize_matrix_expression(cg) == Trace(A) * I

    cg = CodegenArrayContraction(CodegenArrayTensorProduct(Trace(A) * I, I, I),
                                 (1, 5), (3, 4))
    assert cg.split_multiple_contractions() == cg
    assert recognize_matrix_expression(cg).doit() == Trace(A) * I

    # Add DiagMatrix when required:

    cg = CodegenArrayContraction(CodegenArrayTensorProduct(A, a), (1, 2))
    assert cg.split_multiple_contractions() == cg
    assert recognize_matrix_expression(cg) == A * a

    cg = CodegenArrayContraction(CodegenArrayTensorProduct(A, a, B), (1, 2, 4))
    assert cg.split_multiple_contractions() == CodegenArrayContraction(
        CodegenArrayTensorProduct(A, DiagMatrix(a), B), (1, 2), (3, 4))
    assert recognize_matrix_expression(cg) == A * DiagMatrix(a) * B

    cg = CodegenArrayContraction(CodegenArrayTensorProduct(A, a, B), (0, 2, 4))
    assert cg.split_multiple_contractions() == CodegenArrayContraction(
        CodegenArrayTensorProduct(A, DiagMatrix(a), B), (0, 2), (3, 4))
    assert recognize_matrix_expression(cg) == A.T * DiagMatrix(a) * B

    cg = CodegenArrayContraction(CodegenArrayTensorProduct(A, a, b, a.T, B),
                                 (0, 2, 4, 7, 9))
    assert cg.split_multiple_contractions() == CodegenArrayContraction(
        CodegenArrayTensorProduct(A, DiagMatrix(a), DiagMatrix(b),
                                  DiagMatrix(a), B), (0, 2), (3, 4), (5, 7),
        (6, 9))
    assert recognize_matrix_expression(
        cg).doit() == A.T * DiagMatrix(a) * DiagMatrix(b) * DiagMatrix(a) * B.T

    cg = CodegenArrayContraction(CodegenArrayTensorProduct(I1, I1, I1),
                                 (1, 2, 4))
    assert cg.split_multiple_contractions() == CodegenArrayContraction(
        CodegenArrayTensorProduct(I1, I1, I1), (1, 2), (3, 4))
    assert recognize_matrix_expression(cg).doit() == Identity(1)

    cg = CodegenArrayContraction(CodegenArrayTensorProduct(I, I, I, I, A),
                                 (1, 2, 8), (5, 6, 9))
    assert recognize_matrix_expression(
        cg.split_multiple_contractions()).doit() == A

    cg = CodegenArrayContraction(CodegenArrayTensorProduct(A, a, C, a, B),
                                 (1, 2, 4), (5, 6, 8))
    assert cg.split_multiple_contractions() == CodegenArrayContraction(
        CodegenArrayTensorProduct(A, DiagMatrix(a), C, DiagMatrix(a), B),
        (1, 2), (3, 4), (5, 6), (7, 8))
    assert recognize_matrix_expression(
        cg) == A * DiagMatrix(a) * C * DiagMatrix(a) * B

    # The remaining cases are compared with dummy_eq rather than ==.
    cg = CodegenArrayContraction(
        CodegenArrayTensorProduct(a, I1, b, I1, (a.T * b).applyfunc(cos)),
        (1, 2, 8), (5, 6, 9))
    assert cg.split_multiple_contractions().dummy_eq(
        CodegenArrayContraction(
            CodegenArrayTensorProduct(a, I1, b, I1, (a.T * b).applyfunc(cos)),
            (1, 2), (3, 8), (5, 6), (7, 9)))
    assert recognize_matrix_expression(cg).dummy_eq(
        MatMul(a, I1, (a.T * b).applyfunc(cos), Transpose(I1), b.T))

    cg = CodegenArrayContraction(
        CodegenArrayTensorProduct(A.T, a, b, b.T, (A * X * b).applyfunc(cos)),
        (1, 2, 8), (5, 6, 9))
    assert cg.split_multiple_contractions().dummy_eq(
        CodegenArrayContraction(
            CodegenArrayTensorProduct(A.T, DiagMatrix(a), b, b.T,
                                      (A * X * b).applyfunc(cos)), (1, 2),
            (3, 8), (5, 6, 9)))
    # NOTE(review): matrix recognition is not checked for this case; the
    # assertion below is disabled.
    # assert recognize_matrix_expression(cg)

    # Check no overlap of lines:
    # (here the split is expected to leave the contraction unchanged)

    cg = CodegenArrayContraction(CodegenArrayTensorProduct(A, a, C, a, B),
                                 (1, 2, 4), (5, 6, 8), (3, 7))
    assert cg.split_multiple_contractions() == cg

    cg = CodegenArrayContraction(CodegenArrayTensorProduct(a, b, A), (0, 2, 4),
                                 (1, 3))
    assert cg.split_multiple_contractions() == cg
Пример #11
0
def test_MatMul_kind():
    """Check the ``kind`` of a scalar-times-matrix ``MatMul``.

    ``comm_x`` is not defined in this function; it is expected to come from
    the enclosing module.
    """
    M = Matrix([[1, 2], [3, 4]])
    for scalar in (2, comm_x):
        scaled = MatMul(scalar, M)
        # Identity comparison, exactly as in the plain two-assert form.
        assert scaled.kind is MatrixKind(NumberKind)
Пример #12
0
def test_matrixify():
    """Check ``matrixify`` on a scalar sum and on a ``Mul`` of matrix symbols."""
    rows, inner, cols = symbols('n m l')
    left = MatrixSymbol('A', rows, inner)
    right = MatrixSymbol('B', inner, cols)

    # A purely scalar expression is expected to come back unchanged.
    scalar_sum = rows + inner
    assert matrixify(scalar_sum) == rows + inner

    # A Mul of matrix symbols is expected to become a MatMul.
    generic_product = Mul(left, right)
    assert matrixify(generic_product) == MatMul(left, right)
Пример #13
0
def test_recognize_diagonalized_vectors():
    """Check splitting of multi-index contractions using ``DiagonalizeVector``.

    Exercises ``split_multiple_contractions`` and
    ``recognize_matrix_expression`` on ``CodegenArrayContraction`` objects
    built from the matrix symbols defined below.  The dimension symbol ``k``
    is not defined in this function; it is expected to come from the
    enclosing module.
    """

    a = MatrixSymbol("a", k, 1)
    b = MatrixSymbol("b", k, 1)
    A = MatrixSymbol("A", k, k)
    B = MatrixSymbol("B", k, k)
    C = MatrixSymbol("C", k, k)

    # A contraction over only two indices is expected to be left unchanged.
    cg = CodegenArrayContraction(CodegenArrayTensorProduct(A, a), (1, 2))
    assert cg.split_multiple_contractions() == cg
    assert recognize_matrix_expression(cg) == A * a

    # Contractions over more than two indices: the expected results wrap
    # the vector operands in DiagonalizeVector.
    cg = CodegenArrayContraction(CodegenArrayTensorProduct(A, a, B), (1, 2, 4))
    assert cg.split_multiple_contractions() == CodegenArrayContraction(
        CodegenArrayTensorProduct(A, DiagonalizeVector(a), B), (1, 2), (3, 4))
    assert recognize_matrix_expression(cg) == A * DiagonalizeVector(a) * B

    cg = CodegenArrayContraction(CodegenArrayTensorProduct(A, a, B), (0, 2, 4))
    assert cg.split_multiple_contractions() == CodegenArrayContraction(
        CodegenArrayTensorProduct(A, DiagonalizeVector(a), B), (0, 2), (3, 4))
    assert recognize_matrix_expression(cg) == A.T * DiagonalizeVector(a) * B

    cg = CodegenArrayContraction(CodegenArrayTensorProduct(A, a, b, a.T, B),
                                 (0, 2, 4, 7, 9))
    assert cg.split_multiple_contractions() == CodegenArrayContraction(
        CodegenArrayTensorProduct(A,
                                  DiagonalizeVector(a), DiagonalizeVector(b),
                                  DiagonalizeVector(a), B), (0, 2), (3, 4),
        (5, 7), (6, 9))
    assert recognize_matrix_expression(cg).doit() == A.T * DiagonalizeVector(
        a) * DiagonalizeVector(b) * DiagonalizeVector(a) * B.T

    I1 = Identity(1)
    cg = CodegenArrayContraction(CodegenArrayTensorProduct(I1, I1, I1),
                                 (1, 2, 4))
    assert cg.split_multiple_contractions() == CodegenArrayContraction(
        CodegenArrayTensorProduct(I1, I1, I1), (1, 2), (3, 4))
    assert recognize_matrix_expression(cg).doit() == Identity(1)

    I = Identity(k)
    cg = CodegenArrayContraction(CodegenArrayTensorProduct(I, I, I, I, A),
                                 (1, 2, 8), (5, 6, 9))
    assert recognize_matrix_expression(
        cg.split_multiple_contractions()).doit() == A

    cg = CodegenArrayContraction(CodegenArrayTensorProduct(A, a, C, a, B),
                                 (1, 2, 4), (5, 6, 8))
    assert cg.split_multiple_contractions() == CodegenArrayContraction(
        CodegenArrayTensorProduct(A, DiagonalizeVector(a), C,
                                  DiagonalizeVector(a), B), (1, 2), (3, 4),
        (5, 6), (7, 8))
    assert recognize_matrix_expression(
        cg) == A * DiagonalizeVector(a) * C * DiagonalizeVector(a) * B

    cg = CodegenArrayContraction(
        CodegenArrayTensorProduct(a, I1, b, I1, (a.T * b).applyfunc(cos)),
        (1, 2, 8), (5, 6, 9))
    assert cg.split_multiple_contractions() == CodegenArrayContraction(
        CodegenArrayTensorProduct(a, I1, b, I1, (a.T * b).applyfunc(cos)),
        (1, 2), (3, 8), (5, 6), (7, 9))
    assert recognize_matrix_expression(cg) == MatMul(a, I1,
                                                     (a.T * b).applyfunc(cos),
                                                     Transpose(I1), b.T)

    # Check no overlap of lines:
    # (here the split is expected to raise ValueError)

    cg = CodegenArrayContraction(CodegenArrayTensorProduct(A, a, C, a, B),
                                 (1, 2, 4), (5, 6, 8), (3, 7))
    raises(ValueError, lambda: cg.split_multiple_contractions())

    cg = CodegenArrayContraction(CodegenArrayTensorProduct(a, b, A), (0, 2, 4),
                                 (1, 3))
    raises(ValueError, lambda: cg.split_multiple_contractions())