# File: matmul.py  (project: stackwell/sympy)
    def __new__(cls, *args, **kwargs):
        """Build a matrix product from ``args`` (older constructor variant).

        With no arguments a ``GenericIdentity`` is returned.  Identity factors
        are dropped before anything else.  With ``check=True`` (the default)
        adjacent matrix factors must be conformable, otherwise ``ShapeError``
        is raised.  If no matrix factor remains, the bare scalar coefficient
        is returned instead of a product.
        """
        check = kwargs.get('check', True)

        if not args:
            return GenericIdentity()

        # This must be removed aggressively in the constructor to avoid
        # TypeErrors from GenericIdentity().shape.  The identity is built
        # once rather than once per argument.
        identity = GenericIdentity()
        args = [sympify(arg) for arg in args if identity != arg]
        obj = Basic.__new__(cls, *args)
        factor, matrices = obj.as_coeff_matrices()
        if check:
            # Conformability: the columns of each matrix factor must match
            # the rows of the factor to its right.
            for left, right in zip(matrices, matrices[1:]):
                if left.cols != right.rows:
                    raise ShapeError(
                        "Matrices %s and %s are not aligned" % (left, right))
        if not matrices:
            # Should it be
            # return Basic.__neq__(cls, factor, GenericIdentity()) ?
            return factor
        return obj
def test_generic_identity():
    """GenericIdentity: equality, inverse, shape errors, MatMul, hashing."""
    I = GenericIdentity()
    A = MatrixSymbol("A", n, n)

    assert I == I
    assert I != A
    assert A != I

    assert I.is_Identity
    assert I**-1 == I

    raises(TypeError, lambda: I.shape)
    raises(TypeError, lambda: I.rows)
    raises(TypeError, lambda: I.cols)

    assert MatMul() == I
    assert MatMul(I, A) == MatMul(A)
    # Make sure it is hashable
    hash(I)
class MatMul(MatrixExpr, Mul):
    """
    A product of matrix expressions

    Examples
    ========

    >>> from sympy import MatMul, MatrixSymbol
    >>> A = MatrixSymbol('A', 5, 4)
    >>> B = MatrixSymbol('B', 4, 3)
    >>> C = MatrixSymbol('C', 3, 6)
    >>> MatMul(A, B, C)
    A*B*C
    """
    is_MatMul = True

    identity = GenericIdentity()

    def __new__(cls, *args, **kwargs):
        """Build the product of ``args``.

        Identity factors are dropped first.  With ``check=True`` (the
        default) adjacent matrix factors must be conformable, otherwise
        ``ShapeError`` is raised.  If no matrix factor remains, the bare
        scalar coefficient is returned.
        """
        check = kwargs.get('check', True)

        if not args:
            return cls.identity

        # This must be removed aggressively in the constructor to avoid
        # TypeErrors from GenericIdentity().shape
        args = [sympify(arg) for arg in args if cls.identity != arg]
        obj = Basic.__new__(cls, *args)
        factor, matrices = obj.as_coeff_matrices()
        if check:
            # Conformability: columns of each factor == rows of the next.
            for left, right in zip(matrices, matrices[1:]):
                if left.cols != right.rows:
                    raise ShapeError(
                        "Matrices %s and %s are not aligned" % (left, right))
        if not matrices:
            # Should it be
            # return Basic.__neq__(cls, factor, GenericIdentity()) ?
            return factor
        return obj

    @property
    def shape(self):
        """``(rows of the first matrix factor, cols of the last one)``."""
        matrices = [arg for arg in self.args if arg.is_Matrix]
        return (matrices[0].rows, matrices[-1].cols)

    def _entry(self, i, j, expand=True, **kwargs):
        """Return the ``(i, j)`` entry of the product.

        The entry is ``coeff * Sum(M1[i, k1] * M2[k1, k2] * ... )`` over fresh
        dummy indices (taken from ``kwargs['dummy_generator']`` if given).
        ``doit()`` is applied only when ``expand`` is true and at least one
        summation bound is an integer; an ``ImmutableMatrix`` factor forces
        ``expand`` on.
        """
        from sympy import Dummy, Sum, Mul, ImmutableMatrix, Integer

        coeff, matrices = self.as_coeff_matrices()

        if len(matrices) == 1:  # situation like 2*X, matmul is just X
            return coeff * matrices[0][i, j]

        indices = [None]*(len(matrices) + 1)
        ind_ranges = [None]*(len(matrices) - 1)
        indices[0] = i
        indices[-1] = j

        def f():
            counter = 1
            while True:
                yield Dummy("i_%i" % counter)
                counter += 1

        dummy_generator = kwargs.get("dummy_generator", f())

        # A separate loop variable keeps the row index ``i`` unshadowed.
        for k in range(1, len(matrices)):
            indices[k] = next(dummy_generator)

        for k, arg in enumerate(matrices[:-1]):
            ind_ranges[k] = arg.shape[1] - 1
        matrices = [arg._entry(indices[k], indices[k+1],
                               dummy_generator=dummy_generator)
                    for k, arg in enumerate(matrices)]
        expr_in_sum = Mul.fromiter(matrices)
        if any(v.has(ImmutableMatrix) for v in matrices):
            expand = True
        result = coeff*Sum(
                expr_in_sum,
                *zip(indices[1:-1], [0]*len(ind_ranges), ind_ranges)
            )

        # Don't waste time in result.doit() if the sum bounds are symbolic
        if not any(isinstance(v, (Integer, int)) for v in ind_ranges):
            expand = False
        return result.doit() if expand else result

    def as_coeff_matrices(self):
        """Split the args into ``(scalar coefficient, list of matrices)``.

        Raises ``NotImplementedError`` for noncommutative scalars.
        """
        scalars = [x for x in self.args if not x.is_Matrix]
        matrices = [x for x in self.args if x.is_Matrix]
        coeff = Mul(*scalars)
        if coeff.is_commutative is False:
            raise NotImplementedError("noncommutative scalars in MatMul are not supported.")

        return coeff, matrices

    def as_coeff_mmul(self):
        """Return ``(coeff, MatMul of the matrix factors)``."""
        coeff, matrices = self.as_coeff_matrices()
        return coeff, MatMul(*matrices)

    def _eval_transpose(self):
        """Transposition of matrix multiplication.

        Notes
        =====

        The following rules are applied.

        Transposition for matrix multiplied with another matrix:
        `\\left(A B\\right)^{T} = B^{T} A^{T}`

        Transposition for matrix multiplied with scalar:
        `\\left(c A\\right)^{T} = c A^{T}`

        References
        ==========

        .. [1] https://en.wikipedia.org/wiki/Transpose
        """
        coeff, matrices = self.as_coeff_matrices()
        return MatMul(
            coeff, *[transpose(arg) for arg in matrices[::-1]]).doit()

    def _eval_adjoint(self):
        """``(A B)^H = B^H A^H``."""
        return MatMul(*[adjoint(arg) for arg in self.args[::-1]]).doit()

    def _eval_trace(self):
        """``trace(c M) = c trace(M)``; raise when the coefficient is 1."""
        factor, mmul = self.as_coeff_mmul()
        if factor != 1:
            from .trace import trace
            return factor * trace(mmul.doit())
        raise NotImplementedError("Can't simplify any further")

    def _eval_determinant(self):
        """``c**rows`` times the determinants of the ``only_squares`` groups."""
        from sympy.matrices.expressions.determinant import Determinant
        factor, matrices = self.as_coeff_matrices()
        square_matrices = only_squares(*matrices)
        return factor**self.rows * Mul(*list(map(Determinant, square_matrices)))

    def _eval_inverse(self):
        """Reversed product of inverses; unevaluated ``Inverse`` on ShapeError."""
        try:
            return MatMul(*[
                arg.inverse() if isinstance(arg, MatrixExpr) else arg**-1
                    for arg in self.args[::-1]]).doit()
        except ShapeError:
            from sympy.matrices.expressions.inverse import Inverse
            return Inverse(self)

    def doit(self, **kwargs):
        """Evaluate the factors (when ``deep``, default True) and canonicalize."""
        deep = kwargs.get('deep', True)
        if deep:
            args = [arg.doit(**kwargs) for arg in self.args]
        else:
            args = self.args
        # treat scalar*MatrixSymbol or scalar*MatPow separately
        expr = canonicalize(MatMul(*args))
        return expr

    # Needed for partial compatibility with Mul
    def args_cnc(self, **kwargs):
        """Return ``[commutative args, noncommutative args]``."""
        coeff_c = [x for x in self.args if x.is_commutative]
        coeff_nc = [x for x in self.args if not x.is_commutative]
        return [coeff_c, coeff_nc]

    def _eval_derivative_matrix_lines(self, x):
        """Derivative lines of the product with respect to ``x``.

        For every factor containing ``x``, the factor's own lines are
        extended on the left with the reversed product of the transposed
        preceding factors and on the right with the product of the
        following factors (an ``Identity`` of matching size when there are
        none).
        """
        from .transpose import Transpose
        with_x_ind = [i for i, arg in enumerate(self.args) if arg.has(x)]
        lines = []
        for ind in with_x_ind:
            left_args = self.args[:ind]
            right_args = self.args[ind+1:]

            if right_args:
                right_mat = MatMul.fromiter(right_args)
            else:
                right_mat = Identity(self.shape[1])
            if left_args:
                left_rev = MatMul.fromiter([Transpose(i).doit() if i.is_Matrix else i for i in reversed(left_args)])
            else:
                left_rev = Identity(self.shape[0])

            d = self.args[ind]._eval_derivative_matrix_lines(x)
            for i in d:
                i.append_first(left_rev)
                i.append_second(right_mat)
                lines.append(i)

        return lines
def test_generic_identity():
    """MatMul's identity is a GenericIdentity, distinct from the scalar 1."""
    identity = MatMul.identity
    assert identity == GenericIdentity()
    assert identity != S.One
class MatMul(MatrixExpr):
    """
    A product of matrix expressions

    Examples
    ========

    >>> from sympy import MatMul, MatrixSymbol
    >>> A = MatrixSymbol('A', 5, 4)
    >>> B = MatrixSymbol('B', 4, 3)
    >>> C = MatrixSymbol('C', 3, 6)
    >>> MatMul(A, B, C)
    A @ B @ C
    """
    precedence = 45
    # NOTE(review): a matrix product is generally not commutative; confirm
    # this flag is intentional in this fork.
    is_commutative = True

    identity = GenericIdentity()

    def __new__(cls, *args, **kwargs):
        """Build a matrix product (fork variant).

        Visible behaviour: no arguments gives ``cls.identity``; a single
        argument is returned unchanged (not sympified); nested ``MatMul``
        arguments are meant to be flattened; ``append`` merges adjacent equal
        bases into ``MatPow`` powers; identity factors are filtered out; and
        scalar coefficients collected in ``coeffs`` are multiplied back in
        front with ``Mul``.

        NOTE(review): several lines of this method are missing from this copy
        (empty ``if`` bodies and absent ``else:`` clauses, flagged below), so
        it does not parse as shown.  Restore them from version control rather
        than guessing.
        """
        #         check = kwargs.get('check', True)
        check = kwargs.get('check', False)

        if not args:
            return cls.identity

        # NOTE(review): returned without sympify, unlike the multi-arg path.
        if len(args) == 1:
            return args[0]

        # This must be removed aggressively in the constructor to avoid
        # TypeErrors from GenericIdentity().shape
        # NOTE(review): no filtering happens here any more; identity factors
        # are removed from ``matrices`` further below.
        args = list(map(sympify, args))

        if any(arg.is_MatMul for arg in args):

            # Flatten nested MatMul arguments into one argument list.
            def generator():
                for arg in args:
                    if arg.is_MatMul:
                        yield from arg.args
                        # NOTE(review): an ``else:`` appears to be missing
                        # before this yield; as shown, non-MatMul args are
                        # dropped and every MatMul is yielded after its args.
                        yield arg

            args = [*generator()]

        coeffs = []
        matrices = []

        # Push ``mat`` onto ``matrices``, merging it with the previous factor
        # when both share a base (X**a @ X**b, X**a @ X, X @ X).
        def append(mat):
            if matrices:
                last = matrices[-1]
                if last.is_MatPow:
                    if mat.is_MatPow:
                        if last.base == mat.base:
                            matrices[-1] = last.func(last.base,
                                                     last.exp + mat.exp)
                    elif last.base == mat:
                        matrices[-1] = last.func(last.base, last.exp + 1)
                elif last == mat:
                    # NOTE(review): lines appear to be missing around here;
                    # as shown, X @ X becomes X**2 only when X is its own
                    # inverse, and no path ever calls matrices.append(mat).
                    if mat._eval_inverse() == last:
                        matrices[-1] = MatPow(last, 2)


        for arg in args:
            if not arg.is_Mul:
                # NOTE(review): the body of this branch is missing.

            coeff = []
            matrix = []
            for t in arg.args:
                if t.shape:
                    # NOTE(review): the body of this branch (and presumably
                    # an ``else:`` collecting scalars into ``coeff``) is
                    # missing.
            if coeff:
                # NOTE(review): the body of this branch is missing.

        # NOTE(review): raises IndexError when args[0] is a scalar
        # (empty shape).
        if not matrices:
            return Identity(args[0].shape[-1])
        matrices = [*filter(lambda X: not X.is_Identity, matrices)]

        if len(matrices) == 1:
            mat = matrices.pop()

            # NOTE(review): an ``else:`` appears to be missing before this
            # line; as shown, the popped matrix is overwritten.
            mat = Basic.__new__(cls, *matrices)
            factor, matrices = mat.as_coeff_matrices()
            if check:
                # NOTE(review): the body of this branch (shape validation)
                # is missing.
            if not matrices:
                # Should it be
                # return Basic.__neq__(cls, factor, GenericIdentity()) ?
                mat = factor

        if coeffs:
            mat = Mul(*coeffs) * mat

        return mat

    def argmax_shape(self):
        """Return the index of the first argument with the highest rank.

        Rank is ``len(arg.shape)``; ties resolve to the earliest argument.
        Raises ``ValueError`` when there are no arguments.
        """
        ranks = [len(arg.shape) for arg in self.args]
        # Builtin max keeps the first maximal index, like numpy.argmax, but
        # does not make numpy a runtime requirement of ``shape``.
        return max(range(len(ranks)), key=ranks.__getitem__)

    def shape(self):
        dimension = self.args[self.argmax_shape()].shape
        dimension = dimension[:-2]

        m = self.args[0]
        if len(m.shape) == 1:
            if len(self.args[1].shape) == 1:
                assert m.shape[0] == self.args[1].shape[0]
                assert m.shape[0] == self.args[1].shape[
                    -2], "self.args[0].shape = %s, self.args[1].shape = %s" % (
                        self.args[0].shape, self.args[1].shape)
            assert len(m.shape) >= 2
            dimension += (m.shape[-2], )

        last_shape = self.args[-1].shape
        if len(last_shape) > 1:
            dimension += (last_shape[-1], )

        return dimension

#         matrices = [arg for arg in self.args if arg.is_Matrix]
#         return (matrices[0].rows, matrices[-1].cols)

    def _entry(self, i, j=None, expand=True, **kwargs):
        """Return an entry, row or column of the product.

        * ``j is None``: ``i`` selects a row of the result.  For a rank-1
          first factor, it is multiplied with column ``i`` of the product of
          the remaining factors.
        * ``i`` a slice with both bounds ``None``: column ``j`` of the result.
        * ``expand`` true: the ``(i, j)`` entry as an explicit ``Sum`` over
          dummy indices.  ``doit()`` is called unless none of the summation
          bounds is an integer.

        NOTE(review): several lines are missing from this copy (flagged
        below), so it does not parse as shown.
        """
        if j is None:
            if len(self.args[0].shape) == 1:
                return self.args[0] @ self.func(*self.args[1:])[:, i]
            return self.args[0][i] @ self.func(*self.args[1:])
        if isinstance(i, slice):
            start, stop = i.start, i.stop
            if start is None:
                if stop is None:
                    return self.func(*self.args[:-1]) @ self.args[-1][:, j]
                start = 0
            if stop is None:
                stop = self.shape[0]
            # NOTE(review): the statement that uses start/stop appears to be
            # missing; as shown, control falls through to ``if expand``.

        if expand:
            # NOTE(review): ``Mul`` (used below) is not imported here; it must
            # come from module scope.
            from sympy import Dummy, Sum, ImmutableMatrix, Integer

            coeff, matrices = self.as_coeff_matrices()

            if len(matrices) == 1:  # situation like 2*X, matmul is just X
                return coeff * matrices[0][i, j]

            indices = [None] * (len(matrices) + 1)
            ind_ranges = [None] * (len(matrices) - 1)
            indices[0] = i
            indices[-1] = j

            # Endless supply of fresh summation dummies i_1, i_2, ...
            def f():
                counter = 1
                while True:
                    yield Dummy("i_%i" % counter)
                    counter += 1

            dummy_generator = kwargs.get("dummy_generator", f())

            for i in range(1, len(matrices)):
                indices[i] = next(dummy_generator)

            for i, arg in enumerate(matrices[:-1]):
                ind_ranges[i] = arg.shape[1] - 1
            # NOTE(review): the list below is truncated (its element
            # expression and closing bracket are missing).
            matrices = [
                           indices[i + 1],
                for i, arg in enumerate(matrices)
            expr_in_sum = Mul.fromiter(matrices)
            if any(v.has(ImmutableMatrix) for v in matrices):
                expand = True
            # NOTE(review): ``expr_in_sum`` appears to be missing as the first
            # argument of Sum below.
            result = coeff * Sum(
                *zip(indices[1:-1], [0] * len(ind_ranges), ind_ranges))

            # Don't waste time in result.doit() if the sum bounds are symbolic
            if not any(isinstance(v, (Integer, int)) for v in ind_ranges):
                expand = False
            return result.doit() if expand else result
            # NOTE(review): unreachable as shown; an ``else:`` for
            # ``if expand`` appears to be missing before this line.
            return self._entry(i)[:, j]

    def as_coeff_matrices(self):
        """Split ``self.args`` into a scalar coefficient and matrix factors.

        Arguments with an empty ``shape`` count as scalars and are multiplied
        together with ``Mul``.  The remaining arguments are returned in their
        original order as the matrix factors.
        """
        scalars = [x for x in self.args if not x.shape]
        matrices = [x for x in self.args if x.shape]
        return Mul(*scalars), matrices

    def as_coeff_mmul(self):
        """Return ``(scalar coefficient, MatMul of the matrix factors)``."""
        scalar, factors = self.as_coeff_matrices()
        product = MatMul(*factors)
        return scalar, product

    def _eval_adjoint(self):
        """Adjoint of a product: adjoints of the factors in reverse order."""
        reversed_adjoints = [adjoint(factor) for factor in reversed(self.args)]
        return MatMul(*reversed_adjoints).doit()

    def _eval_trace(self):
        factor, mmul = self.as_coeff_mmul()
        if factor != 1:
            from .trace import trace
            return factor * trace(mmul.doit())
            raise NotImplementedError("Can't simplify any further")

    def _eval_determinant(self):
        """Determinant of the product.

        Returns ``coeff**self.rows`` times the product of ``det`` over the
        groups returned by ``only_squares``.
        """
        from sympy import det
        coeff, matrices = self.as_coeff_matrices()
        square_blocks = only_squares(*matrices)
        return coeff**self.rows * Mul(*[det(block) for block in square_blocks])

    def _eval_inverse(self):
            return MatMul(*[arg.inverse() for arg in self.args[::-1]]).doit()
        except ShapeError:
            from sympy.matrices.expressions.inverse import Inverse
            return Inverse(self)

    def doit(self, **kwargs):
        """Evaluate the product.

        When ``deep`` is true (default here: ``False``), each factor is
        evaluated first.  The rebuilt product is then passed through
        ``canonicalize``.
        """
        deep = kwargs.get('deep', False)
        if deep:
            args = [arg.doit(**kwargs) for arg in self.args]
        else:
            args = self.args
        # treat scalar*MatrixSymbol or scalar*MatPow separately
        expr = canonicalize(MatMul(*args))
        return expr

    # Needed for partial compatibility with Mul
    def args_cnc(self, **kwargs):
        """Return ``[commutative args, noncommutative args]``, order kept."""
        commutative = []
        noncommutative = []
        for factor in self.args:
            if factor.is_commutative:
                commutative.append(factor)
            else:
                noncommutative.append(factor)
        return [commutative, noncommutative]

    def _eval_derivative_matrix_lines(self, x):
        """Derivative lines of the product with respect to ``x``.

        For every factor containing ``x``, that factor's own lines are
        extended on the left with the reversed product of the transposed
        preceding factors and on the right with the product of the following
        factors.  An ``Identity`` of matching size is used when a side is
        empty.
        """
        from .transpose import Transpose
        with_x_ind = [i for i, arg in enumerate(self.args) if arg.has(x)]
        lines = []
        for ind in with_x_ind:
            left_args = self.args[:ind]
            right_args = self.args[ind + 1:]

            if right_args:
                right_mat = MatMul.fromiter(right_args)
            else:
                right_mat = Identity(self.shape[1])
            if left_args:
                left_rev = MatMul.fromiter([
                    Transpose(i).doit() if i.is_Matrix else i
                    for i in reversed(left_args)
                ])
            else:
                left_rev = Identity(self.shape[0])

            d = self.args[ind]._eval_derivative_matrix_lines(x)
            for i in d:
                i.append_first(left_rev)
                i.append_second(right_mat)
                lines.append(i)

        return lines

    def simplify(self, **_):
        """Simplify the product.

        Tries, in order:

        1. ``any_zeros``, returned if it changes anything;
        2. exactly two ``exp`` factors where the first has lower rank,
           rewritten as ``Sum(exp(a + b.T))``;
        3. ``simplifyProduct``.
        """
        zeroed = any_zeros(self)
        if zeroed != self:
            return zeroed

        from sympy import exp
        factors = self.args
        if (len(factors) == 2
                and all(isinstance(factor, exp) for factor in factors)
                and len(factors[0].shape) < len(factors[1].shape)):
            from sympy.concrete import summations
            return summations.Sum(exp(factors[0].arg + factors[1].arg.T))

        # simplifyProduct returns self when nothing is absorbed, so its
        # result is the answer in both cases.
        return self.simplifyProduct()

    def simplifyProduct(self):
        """Let a ``MatProduct`` factor absorb its neighbours.

        For each ``MatProduct`` factor in turn, first try absorbing the
        factors before it (returning ``absorbed @ after``), then the factors
        after it (returning ``before @ absorbed``).  Returns ``self`` when no
        factor absorbs anything.
        """
        from sympy.concrete.products import MatProduct

        factors = self.args
        for index, factor in enumerate(factors):
            if not isinstance(factor, MatProduct):
                continue
            head = self.func(*factors[:index])
            tail = self.func(*factors[index + 1:])

            absorbed = factor.try_absorb_forward(head)
            if absorbed:
                return absorbed @ tail

            absorbed = factor.try_absorb_backward(tail)
            if absorbed:
                return head @ absorbed

        return self

    def expand(self, free_symbol=None, deep=True, **_):
        """Expand a two-factor product ``A @ B`` into explicit form.

        With ``deep`` false, or with more than two factors, this defers to
        ``MatrixExpr.expand``.  Otherwise ``A @ B`` is rewritten block-wise
        for ``Concatenate`` operands, via transposes for ``Transpose``
        operands, or in general as ``LAMBDA``/``Sum`` over
        ``A[i, k] * B[k, j]``.

        NOTE(review): many lines are missing from this copy (truncated calls,
        unclosed brackets and absent ``else:`` clauses, flagged below), so it
        does not parse as shown.
        """
        if not deep:
            return MatrixExpr.expand(self)
        from sympy.concrete.expr_with_limits import LAMBDA
        from sympy.concrete.summations import Sum
        if len(self.args) > 2:
            return MatrixExpr.expand(self)

        A, B = self.args
        if A.is_MatPow:
            return self
        if A.is_Concatenate:
            if B.is_Concatenate and len(A.shape) == 1:
                if len(A.args) == len(B.args):
                    sgm = None
                    for a, b in zip(A.args, B.args):
                        if a.shape:
                            product = a @ b
                            if product.is_MatMul and len(product.args) == 2:
                                product = product.expand()
                            # NOTE(review): an ``else:`` appears missing here;
                            # as shown, a @ b is always overwritten by a * b.
                            product = a * b
                        if sgm is None:
                            sgm = product
                            # NOTE(review): an ``else:`` appears missing here;
                            # as shown, the first product is added twice.
                            sgm += product
                    return sgm
                    # NOTE(review): unreachable; an ``else:`` appears missing.
                    return self
                # NOTE(review): an ``else:`` for the outer ``if`` appears
                # missing before this line.
                args = [self.func(arg, B).simplify() for arg in A.args]
                if deep:
                    # NOTE(review): the list below is never closed.
                    args = [
                        arg.expand(deep=True) if arg.is_MatMul else arg
                        for arg in args
                return A.func(*args)

        if A.is_Transpose:
            if B.is_Transpose:
                return (B.arg @ A.arg).expand().T
            if A.arg.is_Concatenate and B.is_Concatenate:
                A_T = A.arg
                if len(A_T.args) == len(B.args):
                    B_T = B._eval_transpose()
                    if B_T is not None:
                        # A @ B = A.T.T @ B.T.T = (B.T @ A.T).T
                        return (B_T @ A_T).expand().T

                    sgm = None
                    for a, b in zip(A_T.args, B.args):
                        if len(a.shape) == 1 and len(b.shape) == 1:
                            n = a.shape[0]
                            if b.shape[0] == n:
                                # NOTE(review): both generate_free_symbol calls
                                # below are truncated.
                                i = a.generate_free_symbol(b.free_symbols,
                                j = a.generate_free_symbol(b.free_symbols
                                                           | {i},
                                product = LAMBDA[j:n,
                                                 i:n](a[i] * b[j]).simplify()
                                # NOTE(review): an ``else:`` appears missing.
                                return self
                            if not b.shape:
                                product = a * b
                            elif a.rows == b.shape[0]:
                                product = (a.T @ b).simplify()
                                if product.is_MatMul and len(
                                        product.args) == 2:
                                    X = product.args[1]
                                    if X.is_Transpose and X.arg.is_Concatenate:
                                        product = product.expand()
                                # NOTE(review): an ``else:`` appears missing.
                                return self
                        if sgm is None:
                            sgm = product
                            # NOTE(review): an ``else:`` appears missing here.
                            sgm += product
                    return sgm
            return self

        if B.is_Concatenate:
            return self

        if B.is_Transpose and B.arg.is_Concatenate:
            return (B.arg @ A.T).expand().T

        if A.is_MatProduct:
            return self

        kwargs = {'free_symbol': free_symbol, 'generator': self}

        # Build the summation limit for the contracted index k, taken from
        # whichever operand is used for the limit, while avoiding
        # ``excludes`` and the other operand's free symbols.
        def generate_k_limit(A, B, excludes=None, **kwargs):
            if A.is_LAMBDA or not B.is_LAMBDA:
                if excludes:
                    excludes |= B.free_symbols
                    # NOTE(review): an ``else:`` appears missing; as shown,
                    # the union above is overwritten.
                    excludes = B.free_symbols

                return A.generate_int_limit(0, excludes, **kwargs)

            if excludes:
                excludes |= A.free_symbols
                # NOTE(review): an ``else:`` appears missing here as well.
                excludes = A.free_symbols

            return B.generate_int_limit(0 if len(B.shape) == 1 else 1,
                                        excludes, **kwargs)

        if len(A.shape) > 1:
            i_limit = A.generate_int_limit(1, **kwargs)
            i, *_ = i_limit
            if len(B.shape) > 1:
                j_limit = B.generate_int_limit(0, {i}, **kwargs)
                j, *_ = j_limit

                k_limit = generate_k_limit(A, B, {i, j}, **kwargs)
                k, *_ = k_limit

                assert i != k and k != j and i != j
                # NOTE(review): this call is truncated, and an ``else:``
                # appears missing after it.
                return LAMBDA(
                    Sum(A[i, k] * B[k, j], k_limit).simplify(), j_limit,
                k_limit = generate_k_limit(A, B, {i}, **kwargs)
                k, *_ = k_limit

                assert i != k
                # NOTE(review): this call is truncated, and an ``else:`` for
                # ``if len(A.shape) > 1`` appears missing after it.
                return LAMBDA(
                    Sum(A[i, k] * B[k], k_limit).simplify(),
            #             print('A.shape =', A.shape)
            if len(B.shape) > 1:
                if B.shape[-1].is_Integer:
                    k_limit = generate_k_limit(A, B, **kwargs)
                    k, *_ = k_limit

                    args = []
                    for j in range(B.shape[-1]):
                        args.append(Sum(A[k] * B[k, j], k_limit).simplify())
                    return Concatenate(*args)
                    # NOTE(review): an ``else:`` appears missing here.
                    #                     print('B.shape =', B.shape)
                    j_limit = B.generate_int_limit(0, **kwargs)
                    j, *_ = j_limit

                    k_limit = generate_k_limit(A, B, {j}, **kwargs)
                    k, *_ = k_limit

                    assert k != j
                    # NOTE(review): this call is truncated, and an ``else:``
                    # appears missing after it.
                    return LAMBDA(
                        Sum(A[k] * B[k, j], k_limit).simplify(),
            k_limit = generate_k_limit(A, B, **kwargs)
            k, *_ = k_limit
            return Sum(A[k] * B[k], k_limit).simplify()

    def _eval_is_integer(self):
        for elem in self.args:
            is_integer = elem.is_integer
            if is_integer:
            return is_integer
        return True

    @property
    def domain(self):
        """Set of values the product can take.

        The scalar set is ``Interval(-oo, oo, integer=self.is_integer)``.
        For a non-scalar shape it is lifted with
        ``CartesianSpace(interval, *shape)``.
        """
        from sympy import Interval, oo
        from sympy.sets.sets import CartesianSpace
        shape = self.shape
        interval = Interval(-oo, oo, integer=self.is_integer)
        if shape:
            return CartesianSpace(interval, *shape)
        return interval

    def _sympystr(self, p):
        """Render the product for ``str()``, joining factors with `` @ ``.

        When the scalar coefficient is a negative number, a ``-`` prefix is
        emitted and the precedence of the negated expression is used for
        parenthesization.  Otherwise, the product's own precedence is used.
        """
        from sympy.core.mul import _keep_coeff
        from sympy.printing.precedence import precedence
        c, m = self.as_coeff_mmul()
        if c.is_number and c < 0:
            expr = _keep_coeff(-c, m)
            sign = "-"
            level = precedence(expr)
        else:
            sign = ""
            level = precedence(self)

        # NOTE(review): the factors joined here are still ``self.args``; if
        # the negative coefficient is one of them, it is printed in addition
        # to the "-" prefix.  Confirm whether ``expr``'s factors were meant.
        return sign + ' @ '.join(
            p.parenthesize(arg, level) for arg in self.args)

    def _latex(self, p):
        """Render the product for LaTeX, joining factors with ``\\times``.

        For a negative coefficient a leading ``- `` is emitted.  A leading
        ``-1`` factor is dropped; otherwise the first factor is negated.
        """
        from sympy import MatMul

        from sympy.printing.precedence import precedence_traditional
        parens = lambda x: p.parenthesize(x, precedence_traditional(self),
                                          False)

        args = list(self.args)

        if isinstance(self, MatMul) and self._coeff_isneg():
            if args[0] == -1:
                args = args[1:]
            else:
                args[0] = -args[0]
            return '- ' + r' \times '.join(map(parens, args))
        else:
            return r' \times '.join(map(parens, args))

    def as_ordered_factors(self, **_):
        """Report the whole product as one factor; options are ignored."""
        factors = [self]
        return factors

    def _eval_is_extended_real(self):
        """Combine the factors' ``is_extended_real`` flags with ``fuzzy_and``."""
        flags = (factor.is_extended_real for factor in self.args)
        return fuzzy_and(flags)

    def _eval_is_extended_positive(self):
        return fuzzy_and(arg.is_extended_positive for arg in self.args)

    def _eval_is_extended_negative(self):
        return fuzzy_and(arg.is_extended_negative for arg in self.args)

    def _eval_is_finite(self):
        """Combine the factors' ``is_finite`` flags with ``fuzzy_and``."""
        flags = (factor.is_finite for factor in self.args)
        return fuzzy_and(flags)

    def class_key(cls):
        return 3, 0, cls.__name__

    def atomic_dtype(self):
        dtype = None
        for arg in self.args:
            _dtype = arg.atomic_dtype
            if dtype is None or dtype in _dtype:
                dtype = _dtype
        return dtype

    def domain_defined(self, x):
        """Domain of ``x`` on which every factor is defined.

        Returns ``S.UniversalSet`` when ``x`` is set-valued.  Otherwise,
        ``x.domain`` is intersected with each factor's ``domain_defined(x)``.
        """
        from sympy import S
        if x.atomic_dtype.is_set:
            return S.UniversalSet

        result = x.domain
        for factor in self.args:
            result &= factor.domain_defined(x)
        return result

    def domain_definition(self):
        """Conjunction of the generic matrix condition and every factor's
        own domain definition."""
        condition = MatrixExpr.domain_definition(self)
        for factor in self.args:
            condition &= factor.domain_definition()
        return condition

    def _eval_transpose(self):
        """Transposition of matrix multiplication.


        The following rules are applied.

        Transposition for matrix multiplied with another matrix:
        `\\left(A B\\right)^{T} = B^{T} A^{T}`

        Transposition for matrix multiplied with scalar:
        `\\left(c A\\right)^{T} = c A^{T}`


        .. [1] https://en.wikipedia.org/wiki/Transpose

        return self.func(*(arg.T for arg in self.args[::-1]))

    def _subs(self, old, new, **hints):
        if old.is_MatMul:
            args = old.args
            for i in range(len(self.args) - len(args) + 1):
                if all(self.args[j] == args[j - i]
                       for j in range(i, i + len(args))):
                    return self.func(*self.args[:i] +
                                     (new.args if new.is_MatMul else (new, )) +
                                     self.args[i + len(args):]).simplify()
        return MatrixExpr._subs(self, old, new, **hints)

    def distribute(self):
        """Distribute the product over its first ``Sum``/``Integral``/``Plus``.

        * ``Sum``/``Integral`` factor: the product is moved inside, with
          ``powsimp`` applied to the new summand.
        * ``Plus`` factor at position ``i > 0``: the product of the preceding
          factors is distributed over the terms, and any following factors
          are multiplied on the right.
        * A ``Plus`` as the first factor, or no such factor: ``self`` is
          returned unchanged.
        """
        for i, arg in enumerate(self.args):
            if arg.is_Sum or arg.is_Integral:
                args = [*self.args]
                args[i] = arg.function
                function = self.func(*args).powsimp()
                return arg.func(function, *arg.limits)
            if arg.is_Plus:
                args = [*self.args]
                if i > 0:
                    left = arg.func(*(self.func(*args[:i]) @ a
                                      for a in arg.args))
                    right = args[i + 1:]
                    if right:
                        return left @ self.func(*right)
                    else:
                        return left
                else:
                    return self
        return self