예제 #1
0
    def __new__(cls, *args):
        """Construct a matrix sum, simplifying where possible.

        Every argument is passed through ``matrixify`` and scalar zeros are
        dropped. All remaining arguments must be matrices of one common shape.

        Returns a ``ZeroMatrix`` when the sum vanishes. Returns a ``MatMul``
        when the underlying ``Add`` collapses into a product (e.g. ``A + A``).
        Otherwise returns the matrixified sum with any ``ZeroMatrix`` terms
        removed.

        Raises:
            ValueError: if scalars and matrices are mixed, or if no nonzero
                argument remains (so the result would have no shape).
            ShapeError: if two matrix arguments differ in shape.
        """
        # Scalar zeros contribute nothing to a matrix sum; drop them.
        args = [arg for arg in map(matrixify, args) if arg != 0]

        if not args:
            # Without a matrix term there is no shape to give the result.
            raise ValueError("MatAdd needs at least one nonzero matrix argument")

        if not all(arg.is_Matrix for arg in args):
            raise ValueError("Mix of Matrix and Scalar symbols")

        # Every term must have the same shape as the first one.
        A = args[0]
        for B in args[1:]:
            if A.shape != B.shape:
                raise ShapeError("Matrices %s and %s are not aligned" % (A, B))

        expr = Add.__new__(cls, *args)
        if expr == S.Zero:
            return ZeroMatrix(*A.shape)
        expr = matrixify(expr)

        # Add may collect like terms into a product, e.g. A + A -> 2*A.
        if expr.is_Mul:
            return MatMul(*expr.args)

        # Clear out ZeroMatrix terms.
        if expr.is_Add:
            nonzero = [M for M in expr.args if not M.is_ZeroMatrix]
            if not nonzero:  # Every term was a zero matrix.
                return ZeroMatrix(*A.shape)
            # Compare lengths: expr.args is a tuple and nonzero is a list,
            # so comparing them directly would never report equality.
            if len(nonzero) != len(expr.args):
                return MatAdd(*nonzero)  # Rebuild from the simpler terms.

        return expr
예제 #2
0
    def __new__(cls, *args):
        """Construct a matrix product, simplifying where possible.

        ``args`` may mix scalars and matrices. Consecutive matrix arguments
        must agree on their inner dimension.

        Returns:
            - a ``ZeroMatrix`` if any factor is zero;
            - a ``MatAdd`` if the underlying ``Mul`` evaluates to a sum;
            - the bare expression if it is not a product;
            - otherwise the product with redundant ``Identity`` factors
              removed.

        Raises:
            ShapeError: if adjacent matrix factors are not aligned.
            ValueError: if ``Mul`` folds the factors into a power whose
                exponent is not an integer.
        """
        matrices = [arg for arg in args if arg.is_Matrix]

        # Adjacent matrix factors must agree on the inner dimension.
        for A, B in zip(matrices, matrices[1:]):
            if A.cols != B.rows:
                raise ShapeError("Matrices %s and %s are not aligned" % (A, B))

        # A zero factor annihilates the product. Its shape comes from the
        # outermost matrices, so this shortcut needs at least one matrix.
        if matrices and any(arg.is_zero for arg in args):
            return ZeroMatrix(matrices[0].rows, matrices[-1].cols)

        expr = matrixify(Mul.__new__(cls, *args))
        if expr.is_Add:
            return MatAdd(*expr.args)
        if expr.is_Pow:
            # Mul folds repeated factors into a power (e.g. A*A -> A**2).
            # Unfold positive integer powers back into an explicit product.
            # Non-positive powers are left as-is: range() over them would
            # otherwise build an empty product.
            if not expr.exp.is_Integer:
                raise ValueError(
                    "Cannot expand matrix power %s with non-integer exponent"
                    % (expr,))
            if expr.exp.is_positive:
                expr = Basic.__new__(MatMul, *([expr.base] * int(expr.exp)))
        if not expr.is_Mul:
            return expr

        if any(arg.is_Matrix and arg.is_ZeroMatrix for arg in expr.args):
            return ZeroMatrix(*expr.shape)

        # Clear out Identity factors; they do not change the product.
        nonmats = [M for M in expr.args if not M.is_Matrix]  # scalars
        mats = [M for M in expr.args if M.is_Matrix]  # matrices
        if any(M.is_Identity for M in mats):
            newmats = [M for M in mats if not M.is_Identity]
            if not newmats:
                # Only identities were present: keep exactly one.
                newmats = [Identity(expr.rows)]
            # Rebuild only when something changed. A single Identity that is
            # already minimal would otherwise recurse forever.
            if mats != newmats:
                return MatMul(*(nonmats + newmats))

        return expr
예제 #3
0
파일: matmul.py 프로젝트: vperic/sympy
    def __new__(cls, *args):
        """Construct a matrix product, simplifying where possible.

        ``args`` may mix scalars and matrices. Consecutive matrix arguments
        must agree on their inner dimension.

        Returns:
            - a ``ZeroMatrix`` if any factor is zero;
            - a ``MatAdd`` if the underlying ``Mul`` evaluates to a sum;
            - the bare expression if it is not a product;
            - otherwise the product with redundant ``Identity`` factors
              removed.

        Raises:
            ShapeError: if adjacent matrix factors are not aligned.
            ValueError: if ``Mul`` folds the factors into a power whose
                exponent is not an integer.
        """
        matrices = [arg for arg in args if arg.is_Matrix]

        # Adjacent matrix factors must agree on the inner dimension.
        for A, B in zip(matrices, matrices[1:]):
            if A.cols != B.rows:
                raise ShapeError("Matrices %s and %s are not aligned" % (A, B))

        # A zero factor annihilates the product. Its shape comes from the
        # outermost matrices, so this shortcut needs at least one matrix.
        if matrices and any(arg.is_zero for arg in args):
            return ZeroMatrix(matrices[0].rows, matrices[-1].cols)

        expr = matrixify(Mul.__new__(cls, *args))
        if expr.is_Add:
            return MatAdd(*expr.args)
        if expr.is_Pow:
            # Mul folds repeated factors into a power (e.g. A*A -> A**2).
            # Unfold positive integer powers back into an explicit product.
            # Non-positive powers are left as-is: range() over them would
            # otherwise build an empty product.
            if not expr.exp.is_Integer:
                raise ValueError(
                    "Cannot expand matrix power %s with non-integer exponent"
                    % (expr,))
            if expr.exp.is_positive:
                expr = Basic.__new__(MatMul, *([expr.base] * int(expr.exp)))
        if not expr.is_Mul:
            return expr

        if any(arg.is_Matrix and arg.is_ZeroMatrix for arg in expr.args):
            return ZeroMatrix(*expr.shape)

        # Clear out Identity factors; they do not change the product.
        nonmats = [M for M in expr.args if not M.is_Matrix]  # scalars
        mats = [M for M in expr.args if M.is_Matrix]  # matrices
        if any(M.is_Identity for M in mats):
            newmats = [M for M in mats if not M.is_Identity]
            if not newmats:
                # Only identities were present: keep exactly one.
                newmats = [Identity(expr.rows)]
            # Rebuild only when something changed. A single Identity that is
            # already minimal would otherwise recurse forever.
            if mats != newmats:
                return MatMul(*(nonmats + newmats))

        return expr