Ejemplo n.º 1
0
def test_arrayexpr_array_expr_zero_array():
    """Exercise the array operators on ``ZeroArray``/``ZeroMatrix`` operands."""
    zero_klmn = ZeroArray(k, l, m, n)
    zero_kmmn = ZeroArray(k, m, m, n)
    zero_mat_mn = ZeroMatrix(m, n)
    zero_mat_mm = ZeroMatrix(m, m)
    zero_mat_kk = ZeroMatrix(k, k)

    # Tensor products that include a zero operand.
    tensor_product_cases = [
        (zero_klmn, ZeroArray(k, k, k, k, k, l, m, n)),
        (zero_mat_mn, ZeroArray(k, k, k, k, m, n)),
    ]
    for zero_operand, expected in tensor_product_cases:
        assert _array_tensor_product(M, N, zero_operand) == expected

    # Contractions of zero operands.
    contraction_cases = [
        (zero_klmn, (3,), ZeroArray(k, l, m)),
        (zero_mat_mn, (1,), ZeroArray(m)),
        (zero_kmmn, (1, 2), ZeroArray(k, n)),
        (zero_mat_mm, (0, 1), 0),
    ]
    for zero_operand, axes, expected in contraction_cases:
        assert _array_contraction(zero_operand, axes) == expected

    # Diagonals of zero operands.
    assert _array_diagonal(zero_kmmn, (1, 2)) == ZeroArray(k, n, m)
    assert _array_diagonal(zero_mat_mm, (0, 1)) == ZeroArray(m)

    # Axis permutations of zero operands.
    assert _permute_dims(zero_klmn, [2, 1, 3, 0]) == ZeroArray(m, l, n, k)
    assert _permute_dims(zero_mat_mn, [1, 0]) == ZeroArray(n, m)

    # Sums in which one of the addends is zero.
    assert _array_add(zero_klmn) == zero_klmn
    assert _array_add(zero_mat_mn) == ZeroArray(m, n)
    product_ab = _array_tensor_product(
        MatrixSymbol("A", k, l), MatrixSymbol("B", m, n))
    assert _array_add(product_ab, zero_klmn) == product_ab
    product_cd = _array_tensor_product(
        MatrixSymbol("C", k, l), MatrixSymbol("D", m, n))
    assert _array_add(product_ab, zero_klmn, product_cd) == _array_add(
        product_ab, product_cd)
    assert _array_add(M, zero_mat_kk) == M
    assert _array_add(M, N, zero_mat_kk) == _array_add(M, N)
Ejemplo n.º 2
0
def test_arrayexpr_canonicalize_diagonal__permute_dims():
    """Compare diagonals of permuted tensor products with equivalent forms."""
    product_mqnp = _array_tensor_product(M, Q, N, P)
    permuted = _permute_dims(product_mqnp, [0, 1, 2, 4, 7, 6, 3, 5])
    actual = _array_diagonal(permuted, (2, 4, 5), (6, 7), (0, 3))
    expected = _array_diagonal(product_mqnp, (2, 6, 7), (3, 5), (0, 4))
    assert actual == expected

    product_mnpq = _array_tensor_product(M, N, P, Q)
    permuted = _permute_dims(product_mnpq, [0, 5, 2, 4, 1, 6, 3, 7])
    actual = _array_diagonal(permuted, (1, 2, 6), (3, 4))
    expected = _array_diagonal(
        _array_tensor_product(M, P, N, Q), (3, 4, 5), (1, 2))
    assert actual == expected
Ejemplo n.º 3
0
def test_arrayexpr_convert_index_to_array_support_function():
    """Check the ``(array, indices)`` pairs returned by ``_convert_indexed_to_array``."""
    convert = _convert_indexed_to_array
    transposition = Permutation([1, 0])

    # Plain elements, products, repeated indices and explicit sums.
    assert convert(M[i, j]) == (M, (i, j))
    assert convert(M[i, j] * N[k, l]) == (
        ArrayTensorProduct(M, N), (i, j, k, l))
    assert convert(M[i, j] * N[j, k]) == (
        ArrayDiagonal(ArrayTensorProduct(M, N), (1, 2)), (i, k, j))
    assert convert(Sum(M[i, j] * N[j, k], (j, 0, k - 1))) == (
        ArrayContraction(ArrayTensorProduct(M, N), (1, 2)), (i, k))

    # Additions, with and without swapped indices.
    assert convert(M[i, j] + N[i, j]) == (ArrayAdd(M, N), (i, j))
    assert convert(M[i, j] + N[j, i]) == (
        ArrayAdd(M, PermuteDims(N, transposition)), (i, j))
    assert convert(M[i, j] + M[j, i]) == (
        ArrayAdd(M, PermuteDims(M, transposition)), (i, j))

    # Element of a matrix product.
    product_element = (M * N * P)[i, j]
    assert convert(product_element) == (
        _array_contraction(ArrayTensorProduct(M, N, P), (1, 2), (3, 4)),
        (i, j))
    # Disregard summation in previous expression
    summand = product_element.function
    summand_array, summand_indices = convert(summand)
    assert summand_array == ArrayDiagonal(
        ArrayTensorProduct(M, N, P), (1, 2), (3, 4))
    assert str(summand_indices) == "(i, j, _i_1, _i_2)"

    # Kronecker deltas multiplying a single element.
    assert convert(KroneckerDelta(i, j) * M[i, k]) == (M, ({i, j}, k))
    assert convert(
        KroneckerDelta(i, j) * KroneckerDelta(j, k) * M[i, l]) == (
            M, ({i, j, k}, l))

    # Kronecker deltas multiplying a sum of products: both cases below expect
    # the same array, only the merged index set differs.
    expected_array = _array_diagonal(
        _array_add(
            ArrayTensorProduct(M, N),
            _permute_dims(ArrayTensorProduct(M, N),
                          Permutation(0, 2)(1, 3))),
        (1, 2))
    single_delta = KroneckerDelta(j, k) * (
        M[i, j] * N[k, l] + N[i, j] * M[k, l])
    assert convert(single_delta) == (
        expected_array, (i, l, frozenset({j, k})))
    double_delta = KroneckerDelta(j, m) * KroneckerDelta(m, k) * (
        M[i, j] * N[k, l] + N[i, j] * M[k, l])
    assert convert(double_delta) == (
        expected_array, (i, l, frozenset({j, m, k})))

    delta_chain = (KroneckerDelta(i, j) * KroneckerDelta(j, k)
                   * KroneckerDelta(k, m) * M[i, 0] * KroneckerDelta(m, n))
    assert convert(delta_chain) == (M, ({i, j, k, m, n}, 0))

    # The same index used twice on one matrix.
    assert convert(M[i, i]) == (ArrayDiagonal(M, (0, 1)), (i, ))
Ejemplo n.º 4
0
def test_arrayexpr_nested_array_elementwise_add():
    """Contraction and diagonal of a sum must equal the sum applied termwise."""
    product_mn = _array_tensor_product(M, N)
    product_nm = _array_tensor_product(N, M)
    axes = (1, 2)

    contracted_sum = _array_contraction(
        _array_add(_array_tensor_product(M, N), _array_tensor_product(N, M)),
        axes)
    sum_of_contractions = _array_add(
        _array_contraction(product_mn, axes),
        _array_contraction(product_nm, axes))
    assert contracted_sum == sum_of_contractions

    diagonal_of_sum = _array_diagonal(
        _array_add(_array_tensor_product(M, N), _array_tensor_product(N, M)),
        axes)
    sum_of_diagonals = _array_add(
        _array_diagonal(product_mn, axes),
        _array_diagonal(product_nm, axes))
    assert diagonal_of_sum == sum_of_diagonals
Ejemplo n.º 5
0
def _(expr: ArrayDiagonal):
    """
    Apply ``_remove_trivial_dims`` to the operand of an ``ArrayDiagonal`` and
    rebuild the diagonal on the axes that survive.

    Returns a pair ``(new_expression, removed)``. When the rebuilt expression
    is not an ``ArrayDiagonal``, ``removed`` is the sorted list of removed
    axes; otherwise it is whatever ``_combine_removed`` returns after
    ``_remove_diagonalized_identity_matrices`` has been applied.
    """
    operand = expr.expr
    newexpr, removed = _remove_trivial_dims(operand)
    rank = get_rank(operand)

    # shifts[j] counts the removed axes located strictly before axis j.
    removed_flags = [1 if axis in removed else 0 for axis in range(rank)]
    shifts = list(accumulate([0] + removed_flags))

    # For every diagonal, keep only the axes that were not removed.
    surviving_axes = {
        diag: tuple(axis for axis in diag if axis not in removed)
        for diag in expr.diagonal_indices
    }
    # A diagonal reduced to a single axis gives its axes back: they are
    # dropped from the list of removed axes.
    for old_diag, kept in surviving_axes.items():
        if len(kept) == 1:
            removed = [axis for axis in removed if axis not in old_diag]

    # Renumber the surviving axes to account for the removed ones.
    # If there are single axes to diagonalize remaining, it means that their
    # corresponding dimension has been removed, they no longer need
    # diagonalization, hence empty tuples are discarded at the end.
    shifted_diags = [
        tuple(axis - shifts[axis] for axis in kept)
        for kept in surviving_axes.values()
    ]

    removed = ArrayDiagonal._push_indices_up(expr.diagonal_indices, removed,
                                             rank)
    removed = sorted(set(removed))

    new_diag_indices = [diag for diag in shifted_diags if len(diag) > 0]
    if new_diag_indices:
        rebuilt = _array_diagonal(newexpr,
                                  *new_diag_indices,
                                  allow_trivial_diags=True)
    else:
        rebuilt = newexpr

    if not isinstance(rebuilt, ArrayDiagonal):
        return rebuilt, removed
    simplified, removed_identities = _remove_diagonalized_identity_matrices(
        rebuilt)
    return simplified, _combine_removed(-1, removed, removed_identities)
Ejemplo n.º 6
0
def _(expr: ArrayDiagonal):
    """
    Convert an ``ArrayDiagonal``: its operand goes through ``_array2matrix``,
    the diagonal is re-applied and Hadamard products are identified.

    If the outcome equals the input, the input is returned as is; otherwise
    the outcome is passed through ``_array2matrix`` once more.
    """
    converted_operand = _array2matrix(expr.expr)
    candidate = identify_hadamard_products(
        _array_diagonal(converted_operand, *expr.diagonal_indices))
    if isinstance(candidate, ArrayDiagonal):
        candidate = _array_diag2contr_diagmatrix(candidate)
    return expr if expr == candidate else _array2matrix(candidate)
Ejemplo n.º 7
0
def _(expr: ElementwiseApplyFunction, x: Expr):
    """
    Derivative of an elementwise-applied function with respect to ``x``.

    Both ``expr`` and ``x`` are asserted to have rank 2. The result is the
    tensor product of the elementwise derivative of the function with the
    derivative of the inner expression, diagonalized on axes (0, 4) and (1, 5).
    """
    assert get_rank(expr) == 2
    assert get_rank(x) == 2
    function_derivative = expr._get_function_fdiff()
    inner_derivative = array_derive(expr.expr, x)
    outer_factor = ElementwiseApplyFunction(function_derivative, expr.expr)
    product = _array_tensor_product(outer_factor, inner_derivative)
    return _array_diagonal(product, (0, 4), (1, 5))
Ejemplo n.º 8
0
def _array_diag2contr_diagmatrix(expr: ArrayDiagonal):
    """
    Rewrite two-axis diagonals of a tensor product as contractions with
    ``DiagMatrix`` objects where possible.

    Only diagonals joining exactly two axes are considered, and only when
    both operands they touch have rank 2. If, for one of the two operands,
    the axis *not* involved in the diagonal has dimension 1, that operand is
    replaced in the tensor product by a ``OneArray``, a ``DiagMatrix`` of it
    (or the operand itself when its other dimension is also 1) is appended to
    the product, and the diagonal is replaced by a contraction between the
    other operand and the appended factor.

    Parameters
    ==========

    expr : ArrayDiagonal
        The diagonal expression to rewrite.

    Returns
    =======

    The rewritten expression, or ``expr`` itself when its operand is not an
    ``ArrayTensorProduct``.
    """
    if not isinstance(expr.expr, ArrayTensorProduct):
        return expr

    args = list(expr.expr.args)
    diag_indices = list(expr.diagonal_indices)
    # Translate absolute axis positions into (operand, axis-in-operand) pairs.
    mapping = _get_mapping_from_subranks([_get_subrank(arg) for arg in args])
    tuple_links = [[mapping[j] for j in i] for i in diag_indices]
    contr_indices = []
    # Tracks which of the original operands were swapped for a OneArray.
    replaced = [False] * len(args)
    for i, (abs_pos, rel_pos) in enumerate(zip(diag_indices, tuple_links)):
        if len(abs_pos) != 2:
            continue
        (pos1_outer, pos1_inner), (pos2_outer, pos2_inner) = rel_pos
        arg1 = args[pos1_outer]
        arg2 = args[pos2_outer]
        if get_rank(arg1) != 2 or get_rank(arg2) != 2:
            # Drop the diagonal if it touches an operand already replaced.
            if replaced[pos1_outer] or replaced[pos2_outer]:
                diag_indices[i] = None
            continue
        # Axis of each operand that does not take part in the diagonal.
        pos1_in2 = 1 - pos1_inner
        pos2_in2 = 1 - pos2_inner
        if arg1.shape[pos1_in2] == 1:
            darg1 = DiagMatrix(arg1) if arg1.shape[pos1_inner] != 1 else arg1
            args.append(darg1)
            contr_indices.append(
                ((pos2_outer, pos2_inner), (len(args) - 1, pos1_inner)))
            diag_indices[i] = None
            args[pos1_outer] = OneArray(arg1.shape[pos1_in2])
            replaced[pos1_outer] = True
        elif arg2.shape[pos2_in2] == 1:
            darg2 = DiagMatrix(arg2) if arg2.shape[pos2_inner] != 1 else arg2
            args.append(darg2)
            contr_indices.append(
                ((pos1_outer, pos1_inner), (len(args) - 1, pos2_inner)))
            diag_indices[i] = None
            args[pos2_outer] = OneArray(arg2.shape[pos2_in2])
            replaced[pos2_outer] = True

    diag_indices_new = [i for i in diag_indices if i is not None]
    # Convert (operand, axis) pairs back to absolute positions in the
    # extended tensor product.
    cumul = list(accumulate([0] + [get_rank(arg) for arg in args]))
    contr_indices2 = [
        tuple(cumul[a] + b for a, b in i) for i in contr_indices
    ]
    tc = _array_contraction(_array_tensor_product(*args), *contr_indices2)
    return _array_diagonal(tc, *diag_indices_new)
Ejemplo n.º 9
0
def test_arrayexpr_array_diagonal():
    """Check index normalization and trivial diagonals in ``_array_diagonal``."""
    # Axes given in reverse order compare equal to the sorted form.
    assert _array_diagonal(M, (1, 0)) == _array_diagonal(M, (0, 1))

    product_mnp = _array_tensor_product(M, N, P)
    assert _array_diagonal(product_mnp, (4, 1), (2, 0)) == _array_diagonal(
        _array_tensor_product(M, N, P), (1, 4), (0, 2))

    # Single-axis diagonals, allowed explicitly, compare equal to a permutation.
    with_trivial = _array_diagonal(
        _array_tensor_product(M, N), (1, 2), (3,), allow_trivial_diags=True)
    assert with_trivial == _permute_dims(
        _array_diagonal(_array_tensor_product(M, N), (1, 2)), [0, 2, 1])

    rank9_symbol = ArraySymbol("Ax", shape=(1, 2, 3, 4, 3, 5, 6, 2, 7))
    with_trivial = _array_diagonal(
        rank9_symbol, (1, 7), (3,), (2, 4), (6,), allow_trivial_diags=True)
    assert with_trivial == _permute_dims(
        _array_diagonal(rank9_symbol, (1, 7), (2, 4)), [0, 2, 4, 5, 1, 6, 3])

    with_trivial = _array_diagonal(M, (0,), allow_trivial_diags=True)
    assert with_trivial == _permute_dims(M, [1, 0])

    # The same axis repeated inside one diagonal is rejected.
    raises(ValueError, lambda: _array_diagonal(M, (0, 0)))
Ejemplo n.º 10
0
def _(expr: ArrayElementwiseApplyFunc, x: Expr):
    """
    Derivative of an elementwise-applied function with respect to ``x``.

    The derivative of the inner expression is tensor-multiplied by the
    elementwise derivative of the function; each axis of the former that
    follows the ``rank(x)`` leading axes is then diagonalized with the
    matching axis of the latter.
    """
    function_derivative = expr._get_function_fdiff()
    inner = expr.expr
    inner_derivative = array_derive(inner, x)
    outer_factor = ArrayElementwiseApplyFunc(function_derivative, inner)
    product = _array_tensor_product(inner_derivative, outer_factor)
    rank_x = get_rank(x)
    rank_expr = get_rank(expr)
    paired_axes = [
        (rank_x + axis, rank_x + rank_expr + axis)
        for axis in range(rank_expr)
    ]
    return _array_diagonal(product, *paired_axes)
Ejemplo n.º 11
0
def test_arrayexpr_array_shape():
    """Check the ``shape`` of array expressions and the shape-mismatch errors."""
    assert _array_tensor_product(M, N, P, Q).shape == (k,) * 8

    rect = MatrixSymbol("Z", m, n)
    product = _array_tensor_product(M, rect)
    assert product.shape == (k, k, m, n)
    assert _array_contraction(product, (0, 1)).shape == (m, n)
    assert _array_diagonal(product, (0, 1)).shape == (m, n, k)
    assert _permute_dims(product, [2, 1, 3, 0]).shape == (m, k, n, k)

    other_product = _array_tensor_product(N, rect)
    assert _array_add(product, other_product).shape == (k, k, m, n)

    # Contraction along axes with discordant dimensions:
    raises(ValueError, lambda: _array_contraction(product, (1, 2)))
    # Also diagonal needs the same dimensions:
    raises(ValueError, lambda: _array_diagonal(product, (1, 2)))
    # Diagonal requires at least two axes to compute the diagonal:
    raises(ValueError, lambda: _array_diagonal(product, (1,)))
Ejemplo n.º 12
0
def test_arrayexpr_convert_array_to_matrix_diag2contraction_diagmatrix():
    """Check ``_array_diag2contr_diagmatrix`` on diagonals involving vectors."""

    def convert_keeping_shape(diag_expr):
        # Every conversion in this test must preserve the shape.
        converted = _array_diag2contr_diagmatrix(diag_expr)
        assert converted.shape == diag_expr.shape
        return converted

    converted = convert_keeping_shape(
        _array_diagonal(_array_tensor_product(M, a), (1, 2)))
    assert converted == _array_contraction(
        _array_tensor_product(M, OneArray(1), DiagMatrix(a)), (1, 3))

    raises(ValueError, lambda: _array_diagonal(_array_tensor_product(a, M),
                                               (1, 2)))

    converted = convert_keeping_shape(
        _array_diagonal(_array_tensor_product(a.T, M), (1, 2)))
    assert converted == _array_contraction(
        _array_tensor_product(OneArray(1), M, DiagMatrix(a.T)), (1, 4))

    converted = convert_keeping_shape(
        _array_diagonal(_array_tensor_product(a.T, M, N, b.T), (1, 2), (4, 7)))
    assert converted == _array_contraction(
        _array_tensor_product(OneArray(1), M, N, OneArray(1), DiagMatrix(a.T),
                              DiagMatrix(b.T)), (1, 7), (3, 9))

    converted = convert_keeping_shape(
        _array_diagonal(_array_tensor_product(a, M, N, b.T), (0, 2), (4, 7)))
    assert converted == _array_contraction(
        _array_tensor_product(OneArray(1), M, N, OneArray(1), DiagMatrix(a),
                              DiagMatrix(b.T)), (1, 6), (3, 9))

    converted = convert_keeping_shape(
        _array_diagonal(_array_tensor_product(a, M, N, b.T), (0, 4), (3, 7)))
    assert converted == _array_contraction(
        _array_tensor_product(OneArray(1), M, N, OneArray(1), DiagMatrix(a),
                              DiagMatrix(b.T)), (3, 6), (2, 9))

    # Column vector, transposed square matrix and a 1x1 identity: only the
    # shapes are checked here.
    identity_1 = Identity(1)
    column = MatrixSymbol("x", k, 1)
    square = MatrixSymbol("A", k, k)
    diag_expr = _array_diagonal(
        _array_tensor_product(column, square.T, identity_1), (0, 2))
    convert_keeping_shape(diag_expr)
    assert _array2matrix(diag_expr).shape == diag_expr.shape
Ejemplo n.º 13
0
def test_arrayexpr_canonicalize_diagonal_contraction():
    """Compare contractions of diagonals with the expected equivalent forms."""
    product = _array_tensor_product(M, N, P, Q)

    actual = _array_contraction(_array_diagonal(product, (1, 3, 4)), (0, 3))
    expected = _array_diagonal(
        _array_contraction(_array_tensor_product(M, N, P, Q), (0, 6)),
        (0, 2, 3))
    assert actual == expected

    actual = _array_contraction(
        _array_diagonal(product, (0, 1, 2, 3, 7)), (1, 2, 3))
    expected = _array_contraction(
        _array_tensor_product(M, N, P, Q), (0, 1, 2, 3, 5, 6, 7))
    assert actual == expected

    actual = _array_contraction(
        _array_diagonal(product, (0, 2, 6, 7)), (1, 2, 3))
    expected = _array_diagonal(
        _array_contraction(product, (3, 4, 5)), (0, 2, 3, 4))
    assert actual == expected

    diagonalized = _array_diagonal(
        _array_tensor_product(M, N, P, Q), (0, 3))
    actual = _array_contraction(diagonalized, (2, 1), (0, 4, 6, 5, 3))
    expected = _array_contraction(
        _array_tensor_product(M, N, P, Q), (0, 1, 3, 5, 6, 7), (2, 4))
    assert actual == expected
Ejemplo n.º 14
0
def convert_matrix_to_array(expr: Basic) -> Basic:
    """
    Recursively rewrite a matrix expression using array-expression operators.

    The expression is dispatched on its type:

    - ``MatMul``: tensor product of the converted factors (non-matrix factors
      first), contracted on consecutive matrix axes.
    - ``MatAdd``: array addition of the converted terms.
    - ``Transpose``: permutation ``[1, 0]`` of the converted argument.
    - ``Trace``: contraction of the first and last axes of the converted
      argument.
    - ``Mul``: tensor product of the converted factors.
    - ``Pow``: repeated tensor product for a positive ``Integer`` exponent.
    - ``MatPow``: elementwise-applied power for non-``Integer`` exponents,
      repeated matrix product for positive ``Integer`` exponents.
    - ``HadamardProduct``: tensor product diagonalized on the row axes and on
      the column axes.
    - ``HadamardPower``: repeated Hadamard product for positive ``Integer``
      exponents, elementwise-applied power otherwise.
    - ``KroneckerProduct``: reshaped, permuted tensor product.

    Any other expression (and any power not covered above) is returned
    unchanged.
    """
    if isinstance(expr, MatMul):
        args_nonmat = []
        args = []
        for arg in expr.args:
            if isinstance(arg, MatrixExpr):
                args.append(arg)
            else:
                args_nonmat.append(convert_matrix_to_array(arg))
        # Contract the column axis of each matrix with the row axis of the next.
        contractions = [(2*i+1, 2*i+2) for i in range(len(args)-1)]
        scalar = _array_tensor_product(*args_nonmat) if args_nonmat else S.One
        if scalar == 1:
            tprod = _array_tensor_product(
                *[convert_matrix_to_array(arg) for arg in args])
        else:
            tprod = _array_tensor_product(
                scalar,
                *[convert_matrix_to_array(arg) for arg in args])
        return _array_contraction(
                tprod,
                *contractions
        )
    elif isinstance(expr, MatAdd):
        return _array_add(
                *[convert_matrix_to_array(arg) for arg in expr.args]
        )
    elif isinstance(expr, Transpose):
        return _permute_dims(
                convert_matrix_to_array(expr.args[0]), [1, 0]
        )
    elif isinstance(expr, Trace):
        inner_expr: MatrixExpr = convert_matrix_to_array(expr.arg) # type: ignore
        return _array_contraction(inner_expr, (0, len(inner_expr.shape) - 1))
    elif isinstance(expr, Mul):
        return _array_tensor_product(*[convert_matrix_to_array(i) for i in expr.args])
    elif isinstance(expr, Pow):
        base = convert_matrix_to_array(expr.base)
        # ``range`` below needs a concrete integer: a positive Rational or a
        # positive symbol would pass the sign test and then raise TypeError,
        # so such powers are returned unchanged instead.
        if isinstance(expr.exp, Integer) and (expr.exp > 0) == True:
            return _array_tensor_product(*[base for i in range(expr.exp)])
        else:
            return expr
    elif isinstance(expr, MatPow):
        base = convert_matrix_to_array(expr.base)
        if expr.exp.is_Integer != True:
            b = symbols("b", cls=Dummy)
            return ArrayElementwiseApplyFunc(Lambda(b, b**expr.exp), convert_matrix_to_array(base))
        elif (expr.exp > 0) == True:
            return convert_matrix_to_array(MatMul.fromiter(base for i in range(expr.exp)))
        else:
            return expr
    elif isinstance(expr, HadamardProduct):
        tp = _array_tensor_product(*[convert_matrix_to_array(arg) for arg in expr.args])
        # One diagonal over all row axes, one over all column axes.
        diag = [[2*i for i in range(len(expr.args))], [2*i+1 for i in range(len(expr.args))]]
        return _array_diagonal(tp, *diag)
    elif isinstance(expr, HadamardPower):
        base, exp = expr.args
        if isinstance(exp, Integer) and exp > 0:
            return convert_matrix_to_array(HadamardProduct.fromiter(base for i in range(exp)))
        else:
            d = Dummy("d")
            return ArrayElementwiseApplyFunc(Lambda(d, d**exp), base)
    elif isinstance(expr, KroneckerProduct):
        kp_args = [convert_matrix_to_array(arg) for arg in expr.args]
        # Bring all row axes first, then all column axes, before reshaping.
        permutation = [2*i for i in range(len(kp_args))] + [2*i + 1 for i in range(len(kp_args))]
        return Reshape(_permute_dims(_array_tensor_product(*kp_args), permutation), expr.shape)
    else:
        return expr
Ejemplo n.º 15
0
def _(expr: ArrayDiagonal, x: Expr):
    """
    Derivative of an ``ArrayDiagonal`` with respect to ``x``.

    The operand is derived first; the original diagonal axes are then
    shifted by the number of axes of ``x`` and re-applied to the result.
    """
    operand_derivative = array_derive(expr.expr, x)
    offset = len(get_shape(x))
    shifted_diagonals = [
        [axis + offset for axis in diagonal]
        for diagonal in expr.diagonal_indices
    ]
    return _array_diagonal(operand_derivative, *shifted_diagonals)
Ejemplo n.º 16
0
def tensordiagonal(array, *diagonal_axes):
    """
    Diagonalization of an array-like object on the specified axes.

    This is equivalent to multiplying the expression by Kronecker deltas
    uniting the axes.

    The diagonal indices are put at the end of the axes.

    Examples
    ========

    ``tensordiagonal`` acting on a 2-dimensional array by axes 0 and 1 is
    equivalent to the diagonal of the matrix:

    >>> from sympy import Array, tensordiagonal
    >>> from sympy import Matrix, eye
    >>> tensordiagonal(eye(3), (0, 1))
    [1, 1, 1]

    >>> from sympy.abc import a,b,c,d
    >>> m1 = Matrix([[a, b], [c, d]])
    >>> tensordiagonal(m1, [0, 1])
    [a, d]

    In case of higher dimensional arrays, the diagonalized dimensions are
    removed and appended as a single dimension at the end:

    >>> A = Array(range(18), (3, 2, 3))
    >>> A
    [[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]], [[12, 13, 14], [15, 16, 17]]]
    >>> tensordiagonal(A, (0, 2))
    [[0, 7, 14], [3, 10, 17]]
    >>> from sympy import permutedims
    >>> tensordiagonal(A, (0, 2)) == permutedims(Array([A[0, :, 0], A[1, :, 1], A[2, :, 2]]), [1, 0])
    True

    """
    if any(len(axes) <= 1 for axes in diagonal_axes):
        raise ValueError("need at least two axes to diagonalize")

    from sympy.tensor.array.expressions.array_expressions import (
        _ArrayExpr, _CodegenArrayAbstract, ArrayDiagonal, _array_diagonal)
    from sympy.matrices.expressions.matexpr import MatrixSymbol

    # Symbolic array expressions are delegated to the expression constructor.
    symbolic_types = (_ArrayExpr, _CodegenArrayAbstract, MatrixSymbol)
    if isinstance(array, symbolic_types):
        return _array_diagonal(array, *diagonal_axes)

    ArrayDiagonal._validate(array, *diagonal_axes)

    array, remaining_indices, remaining_shape, diagonal_deltas = _util_contraction_diagonal(
        array, *diagonal_axes)
    array_type = type(array)

    # Build the diagonalized array:
    #
    # - the outer loop runs over the combinatorial product of the absolute
    #   positions of the undiagonalized indices;
    # - for each of them, the inner comprehension runs over the product of the
    #   diagonal offsets and reads the element found at the sum of the outer
    #   position and the diagonal offsets.
    diagonal_shape = [len(deltas) for deltas in diagonal_deltas]
    diagonalized_array = []
    for outer_positions in itertools.product(*remaining_indices):
        base_position = sum(outer_positions)
        diagonal_entries = [
            array[array._get_tuple_index(base_position + sum(offsets))]
            for offsets in itertools.product(*diagonal_deltas)
        ]
        diagonalized_array.append(
            array_type(diagonal_entries).reshape(*diagonal_shape))

    return array_type(diagonalized_array, remaining_shape + diagonal_shape)
Ejemplo n.º 17
0
def _convert_indexed_to_array(expr):
    """
    Recursively convert an indexed expression into an array expression.

    Returns a pair ``(array_expr, indices)`` where ``indices`` is a tuple with
    one entry per remaining axis. Indices identified with each other through
    ``KroneckerDelta`` factors appear as a single ``frozenset`` entry.

    Handled expression types are ``Sum`` (contraction), ``Mul`` (tensor
    product, diagonalized on repeated indices), ``MatrixElement``,
    ``ArrayElement``, ``Indexed``, ``KroneckerDelta``, ``Add``, ``Pow`` with
    an integer exponent and ``Function``. Anything else is returned unchanged
    together with an empty index tuple.

    Raises
    ======

    ValueError
        If the limits of a ``Sum`` do not span the whole corresponding array
        dimension.
    NotImplementedError
        If ``expr`` is an ``IndexedBase``.
    """
    if isinstance(expr, Sum):
        function = expr.function
        summation_indices = expr.variables
        subexpr, subindices = _convert_indexed_to_array(function)
        # Map every index merged by a Kronecker delta to its frozenset.
        subindicessets = {
            j: i
            for i in subindices if isinstance(i, frozenset) for j in i
        }
        summation_indices = sorted(
            {subindicessets.get(i, i) for i in summation_indices},
            key=default_sort_key)
        # TODO: check that Kronecker delta is only contracted to one other element:
        kronecker_indices = set()
        if isinstance(function, Mul):
            for arg in function.args:
                if not isinstance(arg, KroneckerDelta):
                    continue
                arg_indices = sorted(set(arg.indices), key=default_sort_key)
                if len(arg_indices) == 2:
                    kronecker_indices.update(arg_indices)
        kronecker_indices = sorted(kronecker_indices, key=default_sort_key)
        # Check dimensional consistency:
        shape = get_shape(subexpr)
        if shape:
            for ind, istart, iend in expr.limits:
                i = _get_argindex(subindices, ind)
                if istart != 0 or iend + 1 != shape[i]:
                    raise ValueError(
                        "summation index and array dimension mismatch: %s" %
                        ind)
        contraction_indices = []
        subindices = list(subindices)
        if isinstance(subexpr, ArrayDiagonal):
            # The trailing indices belong to the diagonals: a summed diagonal
            # becomes a contraction, the remaining ones stay diagonals.
            diagonal_indices = list(subexpr.diagonal_indices)
            dindices = subindices[-len(diagonal_indices):]
            subindices = subindices[:-len(diagonal_indices)]
            for index in summation_indices:
                if index in dindices:
                    position = dindices.index(index)
                    contraction_indices.append(diagonal_indices[position])
                    diagonal_indices[position] = None
            diagonal_indices = [i for i in diagonal_indices if i is not None]
            if diagonal_indices:
                subexpr = _array_diagonal(subexpr.expr, *diagonal_indices)
            else:
                subexpr = subexpr.expr

        axes_contraction = defaultdict(list)
        for i, ind in enumerate(subindices):
            # Indices tied to a Kronecker delta are not contracted here.
            include = all(j not in kronecker_indices
                          for j in ind) if isinstance(
                              ind, frozenset) else ind not in kronecker_indices
            if ind in summation_indices and include:
                axes_contraction[ind].append(i)
                subindices[i] = None
        for k, v in axes_contraction.items():
            if any(i in kronecker_indices for i in k) if isinstance(
                    k, frozenset) else k in kronecker_indices:
                continue
            contraction_indices.append(tuple(v))
        free_indices = [i for i in subindices if i is not None]
        indices_ret = list(free_indices)
        # Order by first occurrence in ``free_indices``.
        indices_ret.sort(key=lambda x: free_indices.index(x))
        return _array_contraction(
            subexpr, *contraction_indices,
            free_indices=free_indices), tuple(indices_ret)
    if isinstance(expr, Mul):
        args, indices = zip(
            *[_convert_indexed_to_array(arg) for arg in expr.args])
        # Check if there are KroneckerDelta objects:
        kronecker_delta_repl = {}
        for arg in args:
            if not isinstance(arg, KroneckerDelta):
                continue
            # Diagonalize two indices:
            i, j = arg.indices
            kindices = set(arg.indices)
            if i in kronecker_delta_repl:
                kindices.update(kronecker_delta_repl[i])
            if j in kronecker_delta_repl:
                kindices.update(kronecker_delta_repl[j])
            kindices = frozenset(kindices)
            for index in kindices:
                kronecker_delta_repl[index] = kindices
        # Remove KroneckerDelta objects, their relations should be handled by
        # ArrayDiagonal:
        newargs = []
        newindices = []
        for arg, loc_indices in zip(args, indices):
            if isinstance(arg, KroneckerDelta):
                continue
            newargs.append(arg)
            newindices.append(loc_indices)
        # Replace every index by the frozenset it was merged into, if any.
        flattened_indices = [
            kronecker_delta_repl.get(j, j) for i in newindices for j in i
        ]
        diagonal_indices, ret_indices = _get_diagonal_indices(
            flattened_indices)
        tp = _array_tensor_product(*newargs)
        if diagonal_indices:
            return _array_diagonal(tp, *diagonal_indices), ret_indices
        else:
            return tp, ret_indices
    if isinstance(expr, MatrixElement):
        indices = expr.args[1:]
        diagonal_indices, ret_indices = _get_diagonal_indices(indices)
        if diagonal_indices:
            return _array_diagonal(expr.args[0],
                                   *diagonal_indices), ret_indices
        else:
            return expr.args[0], ret_indices
    if isinstance(expr, ArrayElement):
        indices = expr.indices
        diagonal_indices, ret_indices = _get_diagonal_indices(indices)
        if diagonal_indices:
            return _array_diagonal(expr.name, *diagonal_indices), ret_indices
        else:
            return expr.name, ret_indices
    if isinstance(expr, Indexed):
        indices = expr.indices
        diagonal_indices, ret_indices = _get_diagonal_indices(indices)
        if diagonal_indices:
            return _array_diagonal(expr.base, *diagonal_indices), ret_indices
        else:
            return expr.args[0], ret_indices
    if isinstance(expr, IndexedBase):
        raise NotImplementedError
    if isinstance(expr, KroneckerDelta):
        return expr, expr.indices
    if isinstance(expr, Add):
        args, indices = zip(
            *[_convert_indexed_to_array(arg) for arg in expr.args])
        args = list(args)
        # Check if all indices are compatible. Otherwise expand the dimensions:
        index0 = []
        shape0 = []
        for arg, arg_indices in zip(args, indices):
            arg_indices_set = set(arg_indices)
            arg_indices_missing = arg_indices_set.difference(index0)
            index0.extend([i for i in arg_indices if i in arg_indices_missing])
            arg_shape = get_shape(arg)
            shape0.extend([
                arg_shape[i] for i, e in enumerate(arg_indices)
                if e in arg_indices_missing
            ])
        for i, (arg, arg_indices) in enumerate(zip(args, indices)):
            if len(arg_indices) < len(index0):
                # Pad the term with a OneArray over the indices it lacks.
                missing_indices_pos = [
                    i for i, e in enumerate(index0) if e not in arg_indices
                ]
                missing_shape = [shape0[i] for i in missing_indices_pos]
                arg_indices = tuple(index0[j]
                                    for j in missing_indices_pos) + arg_indices
                args[i] = _array_tensor_product(OneArray(*missing_shape),
                                                args[i])
            permutation = Permutation([arg_indices.index(j) for j in index0])
            # Perform index permutations:
            args[i] = _permute_dims(args[i], permutation)
        return _array_add(*args), tuple(index0)
    if isinstance(expr, Pow):
        subexpr, subindices = _convert_indexed_to_array(expr.base)
        # A non-integer exponent falls through to the checks below.
        if isinstance(expr.exp, (int, Integer)):
            diags = zip(*[(2 * i, 2 * i + 1) for i in range(expr.exp)])
            arr = _array_diagonal(
                _array_tensor_product(*[subexpr for i in range(expr.exp)]),
                *diags)
            return arr, subindices
    if isinstance(expr, Function):
        subexpr, subindices = _convert_indexed_to_array(expr.args[0])
        return ArrayElementwiseApplyFunc(type(expr), subexpr), subindices
    return expr, ()
Ejemplo n.º 18
0
def test_arrayexpr_array_flatten():
    """Check that nested array expressions compare equal to their flat form.

    The names ``cg1``..``cg4``, ``expr1`` and ``expr2`` are rebound several
    times below; each assertion refers to the most recent binding.
    """

    # Flatten nested ArrayTensorProduct objects:
    expr1 = _array_tensor_product(M, N)
    expr2 = _array_tensor_product(P, Q)
    expr = _array_tensor_product(expr1, expr2)
    assert expr == _array_tensor_product(M, N, P, Q)
    assert expr.args == (M, N, P, Q)

    # Flatten mixed ArrayTensorProduct and ArrayContraction objects:
    cg1 = _array_contraction(expr1, (1, 2))
    cg2 = _array_contraction(expr2, (0, 3))

    expr = _array_tensor_product(cg1, cg2)
    assert expr == _array_contraction(_array_tensor_product(M, N, P, Q), (1, 2), (4, 7))

    expr = _array_tensor_product(M, cg1)
    assert expr == _array_contraction(_array_tensor_product(M, M, N), (3, 4))

    # Flatten nested ArrayContraction objects:
    cgnested = _array_contraction(cg1, (0, 1))
    assert cgnested == _array_contraction(_array_tensor_product(M, N), (0, 3), (1, 2))

    cgnested = _array_contraction(_array_tensor_product(cg1, cg2), (0, 3))
    assert cgnested == _array_contraction(_array_tensor_product(M, N, P, Q), (0, 6), (1, 2), (4, 7))

    cg3 = _array_contraction(_array_tensor_product(M, N, P, Q), (1, 3), (2, 4))
    cgnested = _array_contraction(cg3, (0, 1))
    assert cgnested == _array_contraction(_array_tensor_product(M, N, P, Q), (0, 5), (1, 3), (2, 4))

    cgnested = _array_contraction(cg3, (0, 3), (1, 2))
    assert cgnested == _array_contraction(_array_tensor_product(M, N, P, Q), (0, 7), (1, 3), (2, 4), (5, 6))

    cg4 = _array_contraction(_array_tensor_product(M, N, P, Q), (1, 5), (3, 7))
    cgnested = _array_contraction(cg4, (0, 1))
    assert cgnested == _array_contraction(_array_tensor_product(M, N, P, Q), (0, 2), (1, 5), (3, 7))

    cgnested = _array_contraction(cg4, (0, 1), (2, 3))
    assert cgnested == _array_contraction(_array_tensor_product(M, N, P, Q), (0, 2), (1, 5), (3, 7), (4, 6))

    # A diagonal with no axes must give back the contraction unchanged:
    cg = _array_diagonal(cg4)
    assert cg == cg4
    assert isinstance(cg, type(cg4))

    # Flatten nested ArrayDiagonal objects:
    cg1 = _array_diagonal(expr1, (1, 2))
    cg2 = _array_diagonal(expr2, (0, 3))
    cg3 = _array_diagonal(_array_tensor_product(M, N, P, Q), (1, 3), (2, 4))
    cg4 = _array_diagonal(_array_tensor_product(M, N, P, Q), (1, 5), (3, 7))

    cgnested = _array_diagonal(cg1, (0, 1))
    assert cgnested == _array_diagonal(_array_tensor_product(M, N), (1, 2), (0, 3))

    cgnested = _array_diagonal(cg3, (1, 2))
    assert cgnested == _array_diagonal(_array_tensor_product(M, N, P, Q), (1, 3), (2, 4), (5, 6))

    cgnested = _array_diagonal(cg4, (1, 2))
    assert cgnested == _array_diagonal(_array_tensor_product(M, N, P, Q), (1, 5), (3, 7), (2, 4))

    # Flatten nested additions:
    cg = _array_add(M, N)
    cg2 = _array_add(cg, P)
    assert isinstance(cg2, ArrayAdd)
    assert cg2.args == (M, N, P)
    assert cg2.shape == (k, k)

    # Tensor products whose factors are diagonals or contractions:
    expr = _array_tensor_product(_array_diagonal(X, (0, 1)), _array_diagonal(A, (0, 1)))
    assert expr == _array_diagonal(_array_tensor_product(X, A), (0, 1), (2, 3))

    expr1 = _array_diagonal(_array_tensor_product(X, A), (1, 2))
    expr2 = _array_tensor_product(expr1, a)
    assert expr2 == _permute_dims(_array_diagonal(_array_tensor_product(X, A, a), (1, 2)), [0, 1, 4, 2, 3])

    expr1 = _array_contraction(_array_tensor_product(X, A), (1, 2))
    expr2 = _array_tensor_product(expr1, a)
    assert isinstance(expr2, ArrayContraction)
    assert isinstance(expr2.expr, ArrayTensorProduct)

    cg = _array_tensor_product(_array_diagonal(_array_tensor_product(A, X, Y), (0, 3), (1, 5)), a, b)
    assert cg == _permute_dims(_array_diagonal(_array_tensor_product(A, X, Y, a, b), (0, 3), (1, 5)), [0, 1, 6, 7, 2, 3, 4, 5])
def test_convert_array_to_hadamard_products():
    """Check conversions between Hadamard-product expressions and arrays.

    Matrix expressions are sent through ``convert_matrix_to_array`` and back
    through ``convert_array_to_matrix``; expressions built directly with
    ``_array_diagonal`` are passed to ``convert_array_to_matrix`` only.
    """

    def round_trip(matrix_expr):
        # Matrix expression -> array expression -> matrix expression.
        return convert_array_to_matrix(convert_matrix_to_array(matrix_expr))

    # Expressions that must come back unchanged from the round trip.
    source = HadamardProduct(M, N)
    assert round_trip(source) == source

    source = HadamardProduct(M, N)*P
    assert round_trip(source) == source

    source = Q*HadamardProduct(M, N)*P
    assert round_trip(source) == source

    source = Q*HadamardProduct(M, N.T)*P
    assert round_trip(source) == source

    source = HadamardProduct(M, N)*HadamardProduct(Q, P)
    assert source == round_trip(source)

    source = P.T*HadamardProduct(M, N)*HadamardProduct(Q, P)
    assert source == round_trip(source)

    # An ArrayDiagonal that should be converted.
    diagonal = _array_diagonal(_array_tensor_product(M, N, Q), (1, 3), (0, 2, 4))
    converted = convert_array_to_matrix(diagonal)
    transposed_product = HadamardProduct(M.T, N.T)
    target = PermuteDims(
        _array_diagonal(_array_tensor_product(transposed_product, Q), (1, 2)),
        [1, 0, 2],
    )
    assert target == converted

    # Special case that should come back as the very same expression.
    diagonal = _array_diagonal(_array_tensor_product(HadamardProduct(M, N), Q), (0, 2))
    assert convert_array_to_matrix(diagonal) == diagonal

    # Hadamard products inside traces.
    assert round_trip(Trace(HadamardProduct(M, N))) == Trace(HadamardProduct(M.T, N.T))
    assert round_trip(Trace(A*HadamardProduct(M, N))) == Trace(HadamardProduct(M, N)*A)
    assert round_trip(Trace(HadamardProduct(A, M)*N)) == Trace(HadamardProduct(M.T, N)*A)

    # Diagonals that must be returned untouched rather than turned into
    # Hadamard products.
    untouched = _array_diagonal(_array_tensor_product(M, N), (0, 1, 2, 3))
    assert convert_array_to_matrix(untouched) == untouched

    untouched = _array_diagonal(_array_tensor_product(A), (0, 1))
    assert convert_array_to_matrix(untouched) == untouched

    # Diagonals that, by contrast, are expected to convert: to a Hadamard
    # product of three factors, or to a DiagMatrix.
    triple = _array_tensor_product(M, N, P)
    assert convert_array_to_matrix(_array_diagonal(triple, (0, 2, 4), (1, 3, 5))) == HadamardProduct(M, N, P)
    assert convert_array_to_matrix(_array_diagonal(triple, (0, 3, 4), (1, 2, 5))) == HadamardProduct(M, P, N.T)

    diagonal = _array_diagonal(_array_tensor_product(I, I1, x), (1, 4), (3, 5))
    assert convert_array_to_matrix(diagonal) == DiagMatrix(x)


def test_arrayexpr_convert_array_to_matrix_remove_trivial_dims():
    """Exercise ``_remove_trivial_dims`` on the different array expressions.

    Tensor products, sums, permutations, diagonals and contractions are
    passed in; each check compares either the whole returned pair or only
    its second item (the list of removed axes) with the expected value.
    """
    # Tensor product:
    outer_ab = a * b.T
    assert _remove_trivial_dims(_array_tensor_product(a, b)) == (outer_ab, [1, 3])
    assert _remove_trivial_dims(_array_tensor_product(a.T, b)) == (outer_ab, [0, 3])
    assert _remove_trivial_dims(_array_tensor_product(a, b.T)) == (outer_ab, [1, 2])
    assert _remove_trivial_dims(_array_tensor_product(a.T, b.T)) == (outer_ab, [0, 2])

    product = _array_tensor_product(I, a.T, b.T)
    assert _remove_trivial_dims(product) == (_array_tensor_product(I, a * b.T), [2, 4])

    # Products expected to come back as they are, with nothing removed.
    product = _array_tensor_product(a.T, I, b.T)
    assert _remove_trivial_dims(product) == (_array_tensor_product(a.T, I, b.T), [])

    product = _array_tensor_product(a, I)
    assert _remove_trivial_dims(product) == (_array_tensor_product(a, I), [])

    product = _array_tensor_product(I, a)
    assert _remove_trivial_dims(product) == (_array_tensor_product(I, a), [])

    product = _array_tensor_product(a.T, b.T, c, d)
    assert _remove_trivial_dims(product) == (
        _array_tensor_product(a * b.T, c * d.T),
        [0, 2, 5, 7],
    )

    product = _array_tensor_product(a.T, I, b.T, c, d, I)
    assert _remove_trivial_dims(product) == (
        _array_tensor_product(a.T, I, b*c.T, d, I),
        [4, 7],
    )

    # Addition:
    total = ArrayAdd(_array_tensor_product(a, b), _array_tensor_product(c, d))
    assert _remove_trivial_dims(total) == (a * b.T + c * d.T, [1, 3])

    # Permute dims:
    permuted = PermuteDims(_array_tensor_product(a, b), Permutation(3)(1, 2))
    assert _remove_trivial_dims(permuted) == (a * b.T, [2, 3])

    permuted = PermuteDims(_array_tensor_product(a, I, b), Permutation(5)(1, 2, 3, 4))
    assert _remove_trivial_dims(permuted) == (permuted, [])

    permuted = PermuteDims(_array_tensor_product(I, b, a), Permutation(5)(1, 2, 4, 5, 3))
    assert _remove_trivial_dims(permuted) == (
        PermuteDims(_array_tensor_product(I, b * a.T), [0, 2, 3, 1]),
        [4, 5],
    )

    # Diagonal:
    diagonal = _array_diagonal(_array_tensor_product(M, a), (1, 2))
    assert _remove_trivial_dims(diagonal) == (diagonal, [])

    # Contraction:
    contraction = _array_contraction(_array_tensor_product(M, a), (1, 2))
    assert _remove_trivial_dims(contraction) == (contraction, [])

    # A few more cases to test the removal and shift of nested removed axes
    # with array contractions and array diagonals. Each row holds the
    # builder, its index groups, and the expected list of removed axes.
    base = _array_tensor_product(
        OneMatrix(1, 1),
        M,
        x,
        OneMatrix(1, 1),
        Identity(1),
    )
    nested_cases = (
        (_array_contraction, ((1, 8),), [0, 5, 6, 7]),
        (_array_contraction, ((1, 8), (3, 4)), [0, 3, 4, 5]),
        (_array_diagonal, ((1, 8),), [0, 5, 6, 7, 8]),
        (_array_diagonal, ((1, 8), (3, 4)), [0, 3, 4, 5, 6]),
    )
    for build, index_groups, expected_removed in nested_cases:
        _, removed = _remove_trivial_dims(build(base, *index_groups))
        assert removed == expected_removed

    inner = _array_contraction(_array_tensor_product(A, x, I, I1), (1, 2, 5))
    _, removed = _remove_trivial_dims(_array_diagonal(inner, (1, 4)))
    assert removed == [2, 3]

    inner = _array_tensor_product(
        PermuteDims(_array_tensor_product(x, I1), Permutation(1, 2, 3)),
        (x.T*x).applyfunc(sqrt),
    )
    _, removed = _remove_trivial_dims(_array_diagonal(inner, (2, 4), (3, 5)))
    assert removed == [1, 2]

    # Contractions with identity matrices: the expected result drops the
    # identity factor and wraps what is left in a PermuteDims.
    contraction = _array_contraction(_array_tensor_product(A, B, C, M, I), (1, 8))
    result, removed = _remove_trivial_dims(contraction)
    assert result == PermuteDims(_array_tensor_product(A, B, C, M), [0, 2, 3, 4, 5, 6, 7, 1])
    assert removed == []

    contraction = _array_contraction(_array_tensor_product(A, B, C, M, I), (1, 8), (3, 4))
    result, removed = _remove_trivial_dims(contraction)
    assert result == PermuteDims(
        _array_contraction(_array_tensor_product(A, B, C, M), (3, 4)),
        [0, 2, 3, 4, 5, 1],
    )
    assert removed == []

    # Trivial matrices are sometimes inserted into MatMul expressions:
    product = _array_tensor_product(b*b.T, a.T*a)
    result, removed = _remove_trivial_dims(product)
    assert result == b*a.T*a*b.T
    assert removed == [2, 3]

    array_symbol = ArraySymbol("X", (3, 2, k))
    product = _array_tensor_product(M, array_symbol, b.T*c, a*a.T, b*b.T, c.T*d)
    result, removed = _remove_trivial_dims(product)
    assert result == _array_tensor_product(M, array_symbol, a*b.T*c*c.T*d*a.T, b*b.T)
    assert removed == [5, 6, 11, 12]

    # Diagonals over identity matrices:
    diagonal = _array_diagonal(_array_tensor_product(I, I1, x), (1, 4), (3, 5))
    assert _remove_trivial_dims(diagonal) == (
        PermuteDims(_array_diagonal(_array_tensor_product(I, x), (1, 2)), Permutation(1, 2)),
        [1],
    )

    diagonal = _array_diagonal(_array_tensor_product(x, I, y), (0, 2))
    assert _remove_trivial_dims(diagonal) == (
        PermuteDims(_array_tensor_product(DiagMatrix(x), y), [1, 2, 3, 0]),
        [0],
    )

    diagonal = _array_diagonal(_array_tensor_product(x, I, y), (0, 2), (3, 4))
    assert _remove_trivial_dims(diagonal) == (diagonal, [])


def test_arrayexpr_convert_array_to_diagonalized_vector():
    """Check matrix recognition of diagonals and multi-index contractions.

    Covers ``convert_array_to_matrix``, ``_array_diag2contr_diagmatrix`` and
    ``split_multiple_contractions``; many of the expected forms contain
    ``DiagMatrix`` and ``OneArray(1)`` factors.
    """
    unit = OneArray(1)
    diag_a = DiagMatrix(a)
    diag_x = DiagMatrix(x)

    # Check matrix recognition over trivial dimensions:
    assert convert_array_to_matrix(_array_tensor_product(a, b)) == a * b.T
    assert convert_array_to_matrix(_array_tensor_product(I1, a, b)) == a * b.T

    # Recognize trace inside a tensor product:
    traced = _array_contraction(_array_tensor_product(A, B, C), (0, 3), (1, 2))
    assert convert_array_to_matrix(traced) == Trace(A * B) * C

    # Transform diagonal operator to contraction:
    diagonal = _array_diagonal(_array_tensor_product(A, a), (1, 2))
    rewritten = _array_contraction(_array_tensor_product(A, unit, diag_a), (1, 3))
    assert _array_diag2contr_diagmatrix(diagonal) == rewritten
    assert convert_array_to_matrix(diagonal) == A * diag_a

    diagonal = _array_diagonal(_array_tensor_product(a, b), (0, 2))
    rewritten = _permute_dims(
        _array_contraction(_array_tensor_product(diag_a, unit, b), (0, 3)),
        [1, 2, 0],
    )
    assert _array_diag2contr_diagmatrix(diagonal) == rewritten
    assert convert_array_to_matrix(diagonal) == b.T * diag_a

    diagonal = _array_diagonal(_array_tensor_product(A, a), (0, 2))
    rewritten = _array_contraction(_array_tensor_product(A, unit, diag_a), (0, 3))
    assert _array_diag2contr_diagmatrix(diagonal) == rewritten
    assert convert_array_to_matrix(diagonal) == A.T * diag_a

    diagonal = _array_diagonal(_array_tensor_product(I, x, I1), (0, 2), (3, 5))
    rewritten = _array_contraction(_array_tensor_product(I, unit, I1, diag_x), (0, 5))
    assert _array_diag2contr_diagmatrix(diagonal) == rewritten
    assert convert_array_to_matrix(diagonal) == diag_x

    diagonal = _array_diagonal(_array_tensor_product(I, x, A, B), (1, 2), (5, 6))
    rewritten = _array_diagonal(
        _array_contraction(_array_tensor_product(I, unit, A, B, diag_x), (1, 7)),
        (5, 6),
    )
    assert _array_diag2contr_diagmatrix(diagonal) == rewritten
    # TODO: this is returning a wrong result:
    # convert_array_to_matrix(diagonal)

    diagonal = _array_diagonal(_array_tensor_product(I1, a, b), (1, 3, 5))
    assert convert_array_to_matrix(diagonal) == a*b.T

    diagonal = _array_diagonal(_array_tensor_product(I1, a, b), (1, 3))
    rewritten = _array_contraction(_array_tensor_product(unit, a, b, I1), (2, 6))
    assert _array_diag2contr_diagmatrix(diagonal) == rewritten
    assert convert_array_to_matrix(diagonal) == a*b.T

    diagonal = _array_diagonal(_array_tensor_product(x, I1), (1, 2))
    assert isinstance(diagonal, ArrayDiagonal)
    assert diagonal.diagonal_indices == ((1, 2),)
    assert convert_array_to_matrix(diagonal) == x

    diagonal = _array_diagonal(_array_tensor_product(x, I), (0, 2))
    rewritten = _array_contraction(_array_tensor_product(unit, I, diag_x), (1, 3))
    assert _array_diag2contr_diagmatrix(diagonal) == rewritten
    assert convert_array_to_matrix(diagonal).doit() == diag_x

    with raises(ValueError):
        _array_diagonal(x, (1,))

    # Ignore identity matrices with contractions:
    contraction = _array_contraction(_array_tensor_product(I, A, I, I), (0, 2), (1, 3), (5, 7))
    assert contraction.split_multiple_contractions() == contraction
    assert convert_array_to_matrix(contraction) == Trace(A) * I

    contraction = _array_contraction(_array_tensor_product(Trace(A) * I, I, I), (1, 5), (3, 4))
    assert contraction.split_multiple_contractions() == contraction
    assert convert_array_to_matrix(contraction).doit() == Trace(A) * I

    # Add DiagMatrix when required:
    contraction = _array_contraction(_array_tensor_product(A, a), (1, 2))
    assert contraction.split_multiple_contractions() == contraction
    assert convert_array_to_matrix(contraction) == A * a

    contraction = _array_contraction(_array_tensor_product(A, a, B), (1, 2, 4))
    split_form = _array_contraction(_array_tensor_product(A, diag_a, unit, B), (1, 2), (3, 5))
    assert contraction.split_multiple_contractions() == split_form
    assert convert_array_to_matrix(contraction) == A * diag_a * B

    contraction = _array_contraction(_array_tensor_product(A, a, B), (0, 2, 4))
    split_form = _array_contraction(_array_tensor_product(A, diag_a, unit, B), (0, 2), (3, 5))
    assert contraction.split_multiple_contractions() == split_form
    assert convert_array_to_matrix(contraction) == A.T * diag_a * B

    contraction = _array_contraction(_array_tensor_product(A, a, b, a.T, B), (0, 2, 4, 7, 9))
    split_form = _array_contraction(
        _array_tensor_product(A, diag_a, unit, DiagMatrix(b), unit, diag_a, unit, B),
        (0, 2), (3, 5), (6, 9), (8, 12),
    )
    assert contraction.split_multiple_contractions() == split_form
    assert convert_array_to_matrix(contraction) == A.T * diag_a * DiagMatrix(b) * diag_a * B.T

    contraction = _array_contraction(_array_tensor_product(I1, I1, I1), (1, 2, 4))
    split_form = _array_contraction(_array_tensor_product(I1, I1, unit, I1), (1, 2), (3, 5))
    assert contraction.split_multiple_contractions() == split_form
    assert convert_array_to_matrix(contraction) == 1

    contraction = _array_contraction(_array_tensor_product(I, I, I, I, A), (1, 2, 8), (5, 6, 9))
    assert convert_array_to_matrix(contraction.split_multiple_contractions()).doit() == A

    contraction = _array_contraction(_array_tensor_product(A, a, C, a, B), (1, 2, 4), (5, 6, 8))
    split_form = _array_contraction(
        _array_tensor_product(A, diag_a, unit, C, diag_a, unit, B),
        (1, 3), (2, 5), (6, 7), (8, 10),
    )
    assert contraction.split_multiple_contractions() == split_form
    assert convert_array_to_matrix(contraction) == A * diag_a * C * diag_a * B

    # Each ``applyfunc(cos)`` expression below is built separately on purpose
    # and the results are compared with ``dummy_eq`` instead of ``==``.
    contraction = _array_contraction(
        _array_tensor_product(a, I1, b, I1, (a.T*b).applyfunc(cos)),
        (1, 2, 8), (5, 6, 9),
    )
    split_form = _array_contraction(
        _array_tensor_product(a, I1, unit, b, I1, unit, (a.T*b).applyfunc(cos)),
        (1, 3), (2, 10), (6, 8), (7, 11),
    )
    assert contraction.split_multiple_contractions().dummy_eq(split_form)
    assert convert_array_to_matrix(contraction).doit().dummy_eq(MatMul(a, (a.T * b).applyfunc(cos), b.T))