def test_matrix_expression_to_indices():
    """Round-trip matrix products through ``_entry`` and ``from_index_summation``."""
    i, j = symbols("i, j")
    i1, i2, i3 = symbols("i_1:4")

    def replace_dummies(expr):
        # Dummies never compare equal to ordinary symbols, so rename each one
        # to a plain Symbol with the same name before comparing.
        return expr.xreplace({d: Symbol(d.name) for d in expr.atoms(Dummy)})

    over_l_m = ((i1, 0, l - 1), (i2, 0, m - 1))
    wxz = Sum(W[i, i1]*X[i1, i2]*Z[i2, j], *over_l_m)
    wyz = Sum(W[i, i1]*Y[i1, i2]*Z[i2, j], *over_l_m)

    # (matrix expression, expected index form, extra from_index_summation args)
    cases = [
        (W*X*Z, wxz, ()),
        (Z.T*X.T*W.T,
         Sum(W[j, i2]*X[i2, i1]*Z[i1, i], (i1, 0, m - 1), (i2, 0, l - 1)),
         (i,)),
        (W*X*Z + W*Y*Z, wxz + wyz, ()),
        (W*(X + Y)*Z,
         Sum(W[i, i1]*(X[i1, i2] + Y[i1, i2])*Z[i2, j], *over_l_m),
         ()),
    ]
    for expr, expected, extra in cases:
        assert replace_dummies(expr._entry(i, j)) == expected
        assert MatrixExpr.from_index_summation(expr._entry(i, j), *extra) == expr

    expr = A * B**2 * A
def test_matrix_expression_to_indices():
    """Round-trip matrix products through ``_entry`` and ``from_index_summation``."""
    i, j = symbols("i, j")
    i1, i2, i3 = symbols("i_1:4")

    def replace_dummies(expr):
        # Dummies never compare equal to ordinary symbols, so rename each one
        # to a plain Symbol with the same name before comparing.
        return expr.xreplace({d: Symbol(d.name) for d in expr.atoms(Dummy)})

    over_l_m = ((i1, 0, l - 1), (i2, 0, m - 1))
    wxz = Sum(W[i, i1]*X[i1, i2]*Z[i2, j], *over_l_m)
    wyz = Sum(W[i, i1]*Y[i1, i2]*Z[i2, j], *over_l_m)

    # (matrix expression, expected index form, extra from_index_summation args)
    cases = [
        (W*X*Z, wxz, ()),
        (Z.T*X.T*W.T,
         Sum(W[j, i2]*X[i2, i1]*Z[i1, i], (i1, 0, m - 1), (i2, 0, l - 1)),
         (i,)),
        (W*X*Z + W*Y*Z, wxz + wyz, ()),
        (W*(X + Y)*Z,
         Sum(W[i, i1]*(X[i1, i2] + Y[i1, i2])*Z[i2, j], *over_l_m),
         ()),
    ]
    for expr, expected, extra in cases:
        assert replace_dummies(expr._entry(i, j)) == expected
        assert MatrixExpr.from_index_summation(expr._entry(i, j), *extra) == expr

    expr = A * B**2 * A
def test_matrix_expression_to_indices():
    """Round-trip matrix products through ``_entry`` and ``from_index_summation``."""
    i, j = symbols("i, j")
    i1, i2, i3 = symbols("i_1:4")

    def replace_dummies(expr):
        # Dummies never compare equal to ordinary symbols, so rename each one
        # to a plain Symbol with the same name before comparing.
        return expr.xreplace({d: Symbol(d.name) for d in expr.atoms(Dummy)})

    over_l_m = ((i1, 0, l - 1), (i2, 0, m - 1))
    wxz = Sum(W[i, i1]*X[i1, i2]*Z[i2, j], *over_l_m)
    wyz = Sum(W[i, i1]*Y[i1, i2]*Z[i2, j], *over_l_m)

    # (matrix expression, expected index form, extra from_index_summation args)
    cases = [
        (W*X*Z, wxz, ()),
        (Z.T*X.T*W.T,
         Sum(W[j, i2]*X[i2, i1]*Z[i1, i], (i1, 0, m - 1), (i2, 0, l - 1)),
         (i,)),
        (W*X*Z + W*Y*Z, wxz + wyz, ()),
        (2*W*X*Z + 3*W*Y*Z, 2*wxz + 3*wyz, ()),
        (W*(X + Y)*Z,
         Sum(W[i, i1]*(X[i1, i2] + Y[i1, i2])*Z[i2, j], *over_l_m),
         ()),
    ]
    for expr, expected, extra in cases:
        assert replace_dummies(expr._entry(i, j)) == expected
        assert MatrixExpr.from_index_summation(expr._entry(i, j), *extra) == expr

    expr = A * B**2 * A
    # assert replace_dummies(expr._entry(i, j)) == \
    #     Sum(A[i, i1]*B[i1, i2]*B[i2, i3]*A[i3, j], (i1, 0, 1), (i2, 0, 1), (i3, 0, 1))

    # Each sub-multiplication must receive its own dummy indices:
    expr = (X1*X2 + X2*X1)*X3
    assert replace_dummies(expr._entry(i, j)) == Sum(
        (Sum(X1[i, i2]*X2[i2, i1], (i2, 0, m - 1))
         + Sum(X1[i3, i1]*X2[i, i3], (i3, 0, m - 1)))*X3[i1, j],
        (i1, 0, m - 1))
def __add__(self, other):
    """Add block-wise when ``other`` is a BlockMatrix with the same block layout.

    The layouts match when both have the same number of blocks and each pair
    of corresponding blocks has equal shape; otherwise the generic
    ``MatrixExpr.__add__`` is used.
    """
    pairs_line_up = (
        isinstance(other, BlockMatrix)
        and len(self.args) == len(other.args)
        and all(mine.shape == theirs.shape
                for mine, theirs in zip(self.args, other.args))
    )
    if not pairs_line_up:
        return MatrixExpr.__add__(self, other)
    summed = [mine + theirs for mine, theirs in zip(self.args, other.args)]
    return self.func(*summed)
def __new__(cls, arg, **kwargs):
    """Build a transpose node, or the argument's own transpose if it has one.

    Unless ``evaluate=False`` is passed, ``arg._eval_transpose()`` is tried
    first and any non-None result is returned directly.
    """
    arg = _sympify(arg)
    shortcut = arg._eval_transpose() if kwargs.get('evaluate', True) else None
    if shortcut is None:
        return MatrixExpr.__new__(cls, arg, **kwargs)
    return shortcut
def test_matrix_expression_to_indices():
    """Round-trip matrix products through ``_entry`` and ``from_index_summation``."""
    i, j = symbols("i, j")
    i1, i2, i3 = symbols("i_1:4")

    def replace_dummies(expr):
        # Dummies never compare equal to ordinary symbols, so rename each one
        # to a plain Symbol with the same name before comparing.
        return expr.xreplace({d: Symbol(d.name) for d in expr.atoms(Dummy)})

    over_l_m = ((i1, 0, l - 1), (i2, 0, m - 1))
    wxz = Sum(W[i, i1]*X[i1, i2]*Z[i2, j], *over_l_m)
    wyz = Sum(W[i, i1]*Y[i1, i2]*Z[i2, j], *over_l_m)

    # (matrix expression, expected index form, extra from_index_summation args)
    cases = [
        (W*X*Z, wxz, ()),
        (Z.T*X.T*W.T,
         Sum(W[j, i2]*X[i2, i1]*Z[i1, i], (i1, 0, m - 1), (i2, 0, l - 1)),
         (i,)),
        (W*X*Z + W*Y*Z, wxz + wyz, ()),
        (2*W*X*Z + 3*W*Y*Z, 2*wxz + 3*wyz, ()),
        (W*(X + Y)*Z,
         Sum(W[i, i1]*(X[i1, i2] + Y[i1, i2])*Z[i2, j], *over_l_m),
         ()),
    ]
    for expr, expected, extra in cases:
        assert replace_dummies(expr._entry(i, j)) == expected
        assert MatrixExpr.from_index_summation(expr._entry(i, j), *extra) == expr

    expr = A * B**2 * A
    # assert replace_dummies(expr._entry(i, j)) == \
    #     Sum(A[i, i1]*B[i1, i2]*B[i2, i3]*A[i3, j], (i1, 0, 1), (i2, 0, 1), (i3, 0, 1))

    # Each sub-multiplication must receive its own dummy indices:
    expr = (X1*X2 + X2*X1)*X3
    assert replace_dummies(expr._entry(i, j)) == Sum(
        (Sum(X1[i, i2]*X2[i2, i1], (i2, 0, m - 1))
         + Sum(X1[i3, i1]*X2[i, i3], (i3, 0, m - 1)))*X3[i1, j],
        (i1, 0, m - 1))
def _subs(self, old, new, **hints):
    """Substitute ``new`` for a consecutive run of factors equal to ``old``.

    When ``old`` is a matrix product, the first offset at which its factors
    occur consecutively in ``self.args`` is replaced (by ``new``'s factors if
    ``new`` is a product, else by ``new`` itself) and the rebuilt product is
    simplified.  Every other case is delegated to ``MatrixExpr._subs``.
    """
    if old.is_MatMul:
        pattern = old.args
        width = len(pattern)
        factors = self.args
        for offset in range(len(factors) - width + 1):
            window = factors[offset:offset + width]
            if not all(mine == wanted for mine, wanted in zip(window, pattern)):
                continue
            middle = new.args if new.is_MatMul else (new, )
            rebuilt = factors[:offset] + middle + factors[offset + width:]
            return self.func(*rebuilt).simplify()
    return MatrixExpr._subs(self, old, new, **hints)
def simplify(self, deep=False, **kwargs):
    """Collapse an axis-0 concatenation of contiguous pieces into one slice.

    Only concatenations along ``axis == 0`` are rewritten:

    * when ``shape[0]`` equals the number of arguments and every argument is
      a single-index ``Indexed`` element of one common base ``x`` whose index
      equals ``offset + position``, the result is
      ``x[offset:offset + len(self.args)]``;
    * otherwise, when every argument is a single-index ``Slice`` or
      ``Indexed`` piece of one common base and each piece starts where the
      previous one stopped, the result is ``x[start:stop]`` covering them all.

    In every other case ``self`` is returned unchanged.

    :param deep: if true, defer to ``MatrixExpr.simplify``.
    """
    if deep:
        return MatrixExpr.simplify(self, deep=deep, **kwargs)
    # Abutting slices only equal a single slice when stacked along axis 0.
    if self.axis != 0:
        return self

    count = len(self.args)
    if self.shape[0] == count:
        from sympy import Indexed
        base = offset = None
        for position, arg in enumerate(self.args):
            # Multi-index elements carry leading indices a plain slice of the
            # base would drop, so they are not merged.
            if not isinstance(arg, Indexed) or len(arg.indices) != 1:
                return self
            if base is None:
                base = arg.base
            elif arg.base != base:
                # pieces from different bases cannot form one slice
                return self
            diff = arg.indices[-1] - position
            if offset is None:
                offset = diff
            elif offset != diff:
                return self
        if base is None:
            return self
        # The elements are x[offset], ..., x[offset + count - 1].
        return base[offset:offset + count]

    base = start = stop = None
    for arg in self.args:
        # Any piece that is not drawn from a base cannot be merged; skipping
        # it would silently drop it from the result.
        if not (arg.is_Slice or arg.is_Indexed):
            return self
        if len(arg.indices) > 1:
            return self
        if base is None:
            base = arg.base
        elif base != arg.base:
            return self
        if arg.is_Slice:
            piece_start, piece_stop = arg.index
        else:
            piece_start = arg.index
            piece_stop = piece_start + 1
        if start is None:
            start = piece_start
        elif piece_start != stop:
            # a gap or overlap between consecutive pieces
            return self
        stop = piece_stop
    if base is None:
        return self
    return base[start:stop]
def test_matrix_expression_from_index_summation():
    """Parse explicit index summations back into matrix expressions."""
    from sympy.abc import a, b, c, d
    A = MatrixSymbol("A", k, k)
    B = MatrixSymbol("B", k, k)
    C = MatrixSymbol("C", k, k)
    w1 = MatrixSymbol("w1", k, 1)
    i0, i1, i2, i3, i4 = symbols("i0:5", cls=Dummy)

    def parse(expr, first_index, via=MatrixExpr):
        # ``via`` picks the class whose from_index_summation is invoked.
        return via.from_index_summation(expr, first_index)

    bc_lm = ((b, 0, l - 1), (c, 0, m - 1))
    bc_k = ((b, 0, k - 1), (c, 0, k - 1))

    # Chains of products, with and without transposed operands.
    assert parse(Sum(W[a, b]*X[b, c]*Z[c, d], *bc_lm), a) == W*X*Z
    assert parse(Sum(W.T[b, a]*X[b, c]*Z[c, d], *bc_lm), a) == W*X*Z
    assert parse(Sum(A[b, a]*B[b, c]*C[c, d], *bc_k), a, MatrixSymbol) == A.T*B*C
    assert parse(Sum(A[b, a]*B[c, b]*C[c, d], *bc_k), a, MatrixSymbol) == A.T*B.T*C
    assert parse(Sum(C[c, d]*A[b, a]*B[c, b], *bc_k), a, MatrixSymbol) == A.T*B.T*C

    # Sums of matrices, alone and inside products.
    assert parse(Sum(A[a, b] + B[a, b], (a, 0, k - 1), (b, 0, k - 1)), a) == A + B
    assert parse(Sum((A[a, b] + B[a, b])*C[b, c], (b, 0, k - 1)), a) == (A + B)*C
    assert parse(Sum((A[a, b] + B[b, a])*C[b, c], (b, 0, k - 1)), a) == (A + B.T)*C

    # Repeated factors collapse into powers.
    assert parse(Sum(A[a, b]*A[b, c]*A[c, d], *bc_k), a) == A**3
    assert parse(Sum(A[a, b]*A[b, c]*B[c, d], *bc_k), a) == A**2 * B

    # Traces.
    assert parse(Sum(A[a, a], (a, 0, k - 1)), None) == trace(A)
    with_trace = Sum(A[a, a]*B[b, c]*C[c, d], (a, 0, k - 1), (c, 0, k - 1))
    assert parse(with_trace, b) == trace(A)*B*C

    # Summation ranges that disagree with the matrix dimensions are rejected:
    # upper bound m instead of m - 1 ...
    bad_upper = Sum(W[a, b]*X[b, c]*Z[c, d], (b, 0, l - 1), (c, 0, m))
    raises(ValueError, lambda: parse(bad_upper, a))
    # ... and lower bound 1 instead of 0.
    bad_lower = Sum(W[a, b]*X[b, c]*Z[c, d], (b, 0, l - 1), (c, 1, m - 1))
    raises(ValueError, lambda: parse(bad_lower, a))

    # Nested sums.
    nested = Sum(A[a, b]*Sum(B[b, c]*C[c, d], (c, 0, k - 1)), (b, 0, k - 1))
    assert parse(nested, a) == A*B*C

    # Kronecker deltas.
    assert parse(Sum(A[a, b]*KroneckerDelta(b, c)*B[c, d], *bc_k), a) == A*B
    deltas = Sum(
        KroneckerDelta(i1, m)*KroneckerDelta(i2, n)*A[i, i1]*A[j, i2],
        (i1, 0, k - 1), (i2, 0, k - 1),
    )
    assert parse(deltas, m) == A.T*A[j, n]

    # Numbered (Dummy) indices.
    assert parse(Sum(A[i1, i2]*w1[i2, 0], (i2, 0, k - 1)), i1) == A*w1
    assert parse(Sum(A[i1, i2]*B[i2, 0], (i2, 0, k - 1)), i1) == MatrixElement(A*B, i1, 0)
def test_matrix_expression_from_index_summation():
    """Parse explicit index summations back into matrix expressions."""
    from sympy.abc import a, b, c, d
    A = MatrixSymbol("A", k, k)
    B = MatrixSymbol("B", k, k)
    C = MatrixSymbol("C", k, k)

    def parse(expr, first_index, via=MatrixExpr):
        # ``via`` picks the class whose from_index_summation is invoked.
        return via.from_index_summation(expr, first_index)

    bc_lm = ((b, 0, l - 1), (c, 0, m - 1))
    bc_k = ((b, 0, k - 1), (c, 0, k - 1))

    # Chains of products, with and without transposed operands.
    assert parse(Sum(W[a, b]*X[b, c]*Z[c, d], *bc_lm), a) == W*X*Z
    assert parse(Sum(W.T[b, a]*X[b, c]*Z[c, d], *bc_lm), a) == W*X*Z
    assert parse(Sum(A[b, a]*B[b, c]*C[c, d], *bc_k), a, MatrixSymbol) == A.T*B*C
    assert parse(Sum(A[b, a]*B[c, b]*C[c, d], *bc_k), a, MatrixSymbol) == A.T*B.T*C
    assert parse(Sum(C[c, d]*A[b, a]*B[c, b], *bc_k), a, MatrixSymbol) == A.T*B.T*C

    # Sums of matrices, alone and inside products.
    assert parse(Sum(A[a, b] + B[a, b], (a, 0, k - 1), (b, 0, k - 1)), a) == A + B
    assert parse(Sum((A[a, b] + B[a, b])*C[b, c], (b, 0, k - 1)), a) == (A + B)*C
    assert parse(Sum((A[a, b] + B[b, a])*C[b, c], (b, 0, k - 1)), a) == (A + B.T)*C

    # Repeated factors stay as explicit products.
    assert parse(Sum(A[a, b]*A[b, c]*A[c, d], *bc_k), a) == MatMul(A, A, A)
    assert parse(Sum(A[a, b]*A[b, c]*B[c, d], *bc_k), a) == MatMul(A, A, B)

    # Traces.
    assert parse(Sum(A[a, a], (a, 0, k - 1)), None) == trace(A)
    with_trace = Sum(A[a, a]*B[b, c]*C[c, d], (a, 0, k - 1), (c, 0, k - 1))
    assert parse(with_trace, b) == trace(A)*B*C

    # Summation ranges that disagree with the matrix dimensions are rejected:
    # upper bound m instead of m - 1 ...
    bad_upper = Sum(W[a, b]*X[b, c]*Z[c, d], (b, 0, l - 1), (c, 0, m))
    raises(ValueError, lambda: parse(bad_upper, a))
    # ... and lower bound 1 instead of 0.
    bad_lower = Sum(W[a, b]*X[b, c]*Z[c, d], (b, 0, l - 1), (c, 1, m - 1))
    raises(ValueError, lambda: parse(bad_lower, a))

    # Nested sums.
    nested = Sum(A[a, b]*Sum(B[b, c]*C[c, d], (c, 0, k - 1)), (b, 0, k - 1))
    assert parse(nested, a) == A*B*C

    # Kronecker delta.
    assert parse(Sum(A[a, b]*KroneckerDelta(b, c)*B[c, d], *bc_k), a) == A*B
def test_matrix_expression_from_index_summation():
    """Parse explicit index summations back into matrix expressions."""
    from sympy.abc import a, b, c, d
    A = MatrixSymbol("A", k, k)
    B = MatrixSymbol("B", k, k)
    C = MatrixSymbol("C", k, k)
    w1 = MatrixSymbol("w1", k, 1)
    i0, i1, i2, i3, i4 = symbols("i0:5", cls=Dummy)

    def parse(expr, first_index, via=MatrixExpr):
        # ``via`` picks the class whose from_index_summation is invoked.
        return via.from_index_summation(expr, first_index)

    bc_lm = ((b, 0, l - 1), (c, 0, m - 1))
    bc_k = ((b, 0, k - 1), (c, 0, k - 1))

    # Chains of products, with and without transposed operands.
    assert parse(Sum(W[a, b]*X[b, c]*Z[c, d], *bc_lm), a) == W*X*Z
    assert parse(Sum(W.T[b, a]*X[b, c]*Z[c, d], *bc_lm), a) == W*X*Z
    assert parse(Sum(A[b, a]*B[b, c]*C[c, d], *bc_k), a, MatrixSymbol) == A.T*B*C
    assert parse(Sum(A[b, a]*B[c, b]*C[c, d], *bc_k), a, MatrixSymbol) == A.T*B.T*C
    assert parse(Sum(C[c, d]*A[b, a]*B[c, b], *bc_k), a, MatrixSymbol) == A.T*B.T*C

    # Sums of matrices, alone and inside products.
    assert parse(Sum(A[a, b] + B[a, b], (a, 0, k - 1), (b, 0, k - 1)), a) == A + B
    assert parse(Sum((A[a, b] + B[a, b])*C[b, c], (b, 0, k - 1)), a) == (A + B)*C
    assert parse(Sum((A[a, b] + B[b, a])*C[b, c], (b, 0, k - 1)), a) == (A + B.T)*C

    # Repeated factors stay as explicit products.
    assert parse(Sum(A[a, b]*A[b, c]*A[c, d], *bc_k), a) == MatMul(A, A, A)
    assert parse(Sum(A[a, b]*A[b, c]*B[c, d], *bc_k), a) == MatMul(A, A, B)

    # Traces.
    assert parse(Sum(A[a, a], (a, 0, k - 1)), None) == trace(A)
    with_trace = Sum(A[a, a]*B[b, c]*C[c, d], (a, 0, k - 1), (c, 0, k - 1))
    assert parse(with_trace, b) == trace(A)*B*C

    # Summation ranges that disagree with the matrix dimensions are rejected:
    # upper bound m instead of m - 1 ...
    bad_upper = Sum(W[a, b]*X[b, c]*Z[c, d], (b, 0, l - 1), (c, 0, m))
    raises(ValueError, lambda: parse(bad_upper, a))
    # ... and lower bound 1 instead of 0.
    bad_lower = Sum(W[a, b]*X[b, c]*Z[c, d], (b, 0, l - 1), (c, 1, m - 1))
    raises(ValueError, lambda: parse(bad_lower, a))

    # Nested sums.
    nested = Sum(A[a, b]*Sum(B[b, c]*C[c, d], (c, 0, k - 1)), (b, 0, k - 1))
    assert parse(nested, a) == A*B*C

    # Kronecker deltas.
    assert parse(Sum(A[a, b]*KroneckerDelta(b, c)*B[c, d], *bc_k), a) == A*B
    deltas = Sum(
        KroneckerDelta(i1, m)*KroneckerDelta(i2, n)*A[i, i1]*A[j, i2],
        (i1, 0, k - 1), (i2, 0, k - 1),
    )
    assert parse(deltas, m) == A.T*A[j, n]

    # Numbered (Dummy) indices.
    assert parse(Sum(A[i1, i2]*w1[i2, 0], (i2, 0, k - 1)), i1) == A*w1
    assert parse(Sum(A[i1, i2]*B[i2, 0], (i2, 0, k - 1)), i1) == MatrixElement(A*B, i1, 0)
def domain_defined(self, x):
    """Return the domain of ``x`` on which this expression is defined.

    Starts from ``MatrixExpr.domain_defined`` and combines it with ``&``
    with the domain each argument reports for ``x``.
    """
    combined = MatrixExpr.domain_defined(self, x)
    for per_arg in (operand.domain_defined(x) for operand in self.args):
        combined &= per_arg
    return combined
def domain_definition(self):
    """Return the condition under which this expression is defined.

    Starts from ``MatrixExpr.domain_definition`` and combines it with ``&``
    with the condition reported by each argument.
    """
    condition = MatrixExpr.domain_definition(self)
    for per_arg in (operand.domain_definition() for operand in self.args):
        condition &= per_arg
    return condition
def expand(self, free_symbol=None, deep=True, **_):
    """Expand the two-factor matrix product ``A @ B`` held in ``self.args``.

    Returns one of: a sum of pairwise piece products (both factors are
    ``Concatenate`` objects with the same number of pieces), a
    ``Concatenate`` of per-piece products, a transposed expansion, an
    element-wise ``LAMBDA``/``Sum`` expression, a ``Concatenate`` of sums,
    a scalar ``Sum``, or ``self`` when no rule applies.

    :param free_symbol: passed through (as ``free_symbol``) to
        ``generate_int_limit`` when fresh index variables are created.
    :param deep: if false, defer entirely to ``MatrixExpr.expand``.
    """
    if not deep:
        return MatrixExpr.expand(self)
    from sympy.concrete.expr_with_limits import LAMBDA
    from sympy.concrete.summations import Sum
    # Only the binary case is handled here; longer products use the generic expand.
    if len(self.args) > 2:
        return MatrixExpr.expand(self)
    A, B = self.args
    # Matrix powers on the left are left untouched.
    if A.is_MatPow:
        return self
    if A.is_Concatenate:
        if B.is_Concatenate and len(A.shape) == 1:
            # Pair the pieces of A and B and add up the piece-wise products.
            if len(A.args) == len(B.args):
                sgm = None
                for a, b in zip(A.args, B.args):
                    if a.shape:
                        product = a @ b
                        if product.is_MatMul and len(product.args) == 2:
                            product = product.expand()
                    else:
                        # scalar piece: ordinary multiplication
                        product = a * b
                    if sgm is None:
                        sgm = product
                    else:
                        sgm += product
                return sgm
            else:
                return self
        else:
            # Multiply every piece of A by B and re-concatenate.
            args = [self.func(arg, B).simplify() for arg in A.args]
            # ``deep`` is always true here (the ``not deep`` case returned above).
            if deep:
                args = [
                    arg.expand(deep=True) if arg.is_MatMul else arg
                    for arg in args
                ]
            return A.func(*args)
    if A.is_Transpose:
        if B.is_Transpose:
            # A.arg.T @ B.arg.T == (B.arg @ A.arg).T
            return (B.arg @ A.arg).expand().T
        if A.arg.is_Concatenate and B.is_Concatenate:
            A_T = A.arg
            if len(A_T.args) == len(B.args):
                B_T = B._eval_transpose()
                if B_T is not None:
                    # A @ B = A.T.T @ B.T.T = (B.T @ A.T).T
                    return (B_T @ A_T).expand().T
                sgm = None
                for a, b in zip(A_T.args, B.args):
                    if len(a.shape) == 1 and len(b.shape) == 1:
                        n = a.shape[0]
                        if b.shape[0] == n:
                            # two equal-length vectors: an n x n LAMBDA of
                            # a[i] * b[j] over fresh integer indices i, j
                            i = a.generate_free_symbol(b.free_symbols, integer=True)
                            j = a.generate_free_symbol(b.free_symbols | {i}, integer=True)
                            product = LAMBDA[j:n, i:n](a[i] * b[j]).simplify()
                        else:
                            return self
                    else:
                        if not b.shape:
                            product = a * b
                        elif a.rows == b.shape[0]:
                            product = (a.T @ b).simplify()
                            if product.is_MatMul and len(
                                    product.args) == 2:
                                X = product.args[1]
                                if X.is_Transpose and X.arg.is_Concatenate:
                                    product = product.expand()
                        else:
                            return self
                    if sgm is None:
                        sgm = product
                    else:
                        sgm += product
                return sgm
            # NOTE(review): nesting of this ``return self`` was reconstructed from
            # collapsed source; confirm it belongs to this Concatenate sub-case
            # rather than to the whole ``A.is_Transpose`` branch.
            return self
    # A Concatenate on the right only: no rule here.
    if B.is_Concatenate:
        return self
    if B.is_Transpose and B.arg.is_Concatenate:
        # A @ B.arg.T == (B.arg @ A.T).T
        return (B.arg @ A.T).expand().T
    if A.is_MatProduct:
        return self
    # Extra keyword arguments for every generate_int_limit call below.
    kwargs = {'free_symbol': free_symbol, 'generator': self}

    def generate_k_limit(A, B, excludes=None, **kwargs):
        # Create the limit of the contracted index k.  It is generated from A
        # when A is a LAMBDA or B is not, otherwise from B; the other factor's
        # free symbols (together with ``excludes``) are excluded.
        # Note: ``excludes |= ...`` updates the caller's set in place.
        if A.is_LAMBDA or not B.is_LAMBDA:
            if excludes:
                excludes |= B.free_symbols
            else:
                excludes = B.free_symbols
            return A.generate_int_limit(0, excludes, **kwargs)
        if excludes:
            excludes |= A.free_symbols
        else:
            excludes = A.free_symbols
        return B.generate_int_limit(0 if len(B.shape) == 1 else 1, excludes, **kwargs)

    if len(A.shape) > 1:
        # NOTE(review): the integer passed to generate_int_limit appears to count
        # axes from the last one (1 -> A's rows, 0 -> B's columns); confirm.
        i_limit = A.generate_int_limit(1, **kwargs)
        i, *_ = i_limit
        if len(B.shape) > 1:
            # matrix @ matrix: LAMBDA over (j, i) of Sum_k A[i, k] * B[k, j]
            j_limit = B.generate_int_limit(0, {i}, **kwargs)
            j, *_ = j_limit
            k_limit = generate_k_limit(A, B, {i, j}, **kwargs)
            k, *_ = k_limit
            assert i != k and k != j and i != j
            return LAMBDA(
                Sum(A[i, k] * B[k, j], k_limit).simplify(), j_limit, i_limit).simplify()
        else:
            # matrix @ vector: LAMBDA over i of Sum_k A[i, k] * B[k]
            k_limit = generate_k_limit(A, B, {i}, **kwargs)
            k, *_ = k_limit
            assert i != k
            return LAMBDA(
                Sum(A[i, k] * B[k], k_limit).simplify(), i_limit).simplify()
    else:
        # print('A.shape =', A.shape)
        if len(B.shape) > 1:
            if B.shape[-1].is_Integer:
                # vector @ matrix with an integer column count: one Sum per column
                k_limit = generate_k_limit(A, B, **kwargs)
                k, *_ = k_limit
                args = []
                for j in range(B.shape[-1]):
                    args.append(Sum(A[k] * B[k, j], k_limit).simplify())
                return Concatenate(*args)
            else:
                # print('B.shape =', B.shape)
                # vector @ matrix: LAMBDA over j of Sum_k A[k] * B[k, j]
                j_limit = B.generate_int_limit(0, **kwargs)
                j, *_ = j_limit
                k_limit = generate_k_limit(A, B, {j}, **kwargs)
                k, *_ = k_limit
                assert k != j
                return LAMBDA(
                    Sum(A[k] * B[k, j], k_limit).simplify(), j_limit).simplify()
        # vector @ vector: the scalar Sum_k A[k] * B[k]
        k_limit = generate_k_limit(A, B, **kwargs)
        k, *_ = k_limit
        return Sum(A[k] * B[k], k_limit).simplify()
def expand(self, var=None, deep=True, **_):
    """Expand this matrix product using block structure where possible.

    Products of more than two factors are expanded left to right.  For two
    factors ``A @ B`` the result is one of: a sum of pairwise block products
    (both factors are ``BlockMatrix`` objects with the same number of
    blocks), a ``BlockMatrix`` of per-block products, a transposed
    expansion, an explicit ``Matrix``/``BlockMatrix`` of sums (``A`` a
    vector and ``B`` with an integer number of columns), or ``self`` when no
    rule applies.

    :param var: passed through (as ``var``) to ``generate_int_limit`` when a
        fresh summation index is created.
    :param deep: if false, defer entirely to ``MatrixExpr.expand``.
    """
    if not deep:
        return MatrixExpr.expand(self)
    from sympy.concrete.expr_with_limits import Lamda
    from sympy.concrete.summations import Sum
    if len(self.args) > 2:
        # Expand the leading factors, then fold in the last one.
        matmul = self.func(*self.args[:-1]).expand(
            var=var, deep=deep) @ self.args[-1]
        # When the leading factors do not expand, ``@`` can rebuild this very
        # product (MatMul canonicalisation flattens nested products); expanding
        # it again would recurse forever, so recurse only on a changed product.
        if matmul.is_MatMul and matmul != self:
            matmul = matmul.expand(var=var, deep=deep)
        return matmul
    A, B = self.args
    # Matrix powers on the left are left untouched.
    if A.is_MatPow:
        return self
    if A.is_BlockMatrix:
        if B.is_BlockMatrix and len(A.shape) == 1:
            # Pair the blocks of A and B and add up the block-wise products.
            if len(A.args) != len(B.args):
                return self
            sgm = None
            for a, b in zip(A.args, B.args):
                if a.shape:
                    product = a @ b
                    if product.is_MatMul and len(product.args) == 2:
                        product = product.expand()
                else:
                    # scalar block: ordinary multiplication
                    product = a * b
                if sgm is None:
                    sgm = product
                else:
                    sgm += product
            return sgm
        # Multiply every block of A by B, expand the resulting binary
        # products (``deep`` is always true here), and re-assemble.
        args = [self.func(arg, B).simplify() for arg in A.args]
        args = [arg.expand(deep=True) if arg.is_MatMul else arg for arg in args]
        return A.func(*args)
    if A.is_Transpose:
        if B.is_Transpose:
            # A.arg.T @ B.arg.T == (B.arg @ A.arg).T
            return (B.arg @ A.arg).expand().T
        if A.arg.is_BlockMatrix and B.is_BlockMatrix:
            A_T = A.arg
            if len(A_T.args) == len(B.args):
                B_T = B._eval_transpose()
                if B_T is not None:
                    # A @ B = A.T.T @ B.T.T = (B.T @ A.T).T
                    return (B_T @ A_T).expand().T
                sgm = None
                for a, b in zip(A_T.args, B.args):
                    if len(a.shape) == 1 and len(b.shape) == 1:
                        n = a.shape[0]
                        if b.shape[0] != n:
                            return self
                        # two equal-length vectors: an n x n Lamda of
                        # a[i] * b[j] over fresh integer indices i, j
                        i = a.generate_var(b.free_symbols, integer=True)
                        j = a.generate_var(b.free_symbols | {i}, integer=True)
                        product = Lamda[j:n, i:n](a[i] * b[j]).simplify()
                    elif not b.shape:
                        product = a * b
                    elif a.rows == b.shape[0]:
                        product = (a.T @ b).simplify()
                        if product.is_MatMul and len(product.args) == 2:
                            X = product.args[1]
                            if X.is_Transpose and X.arg.is_BlockMatrix:
                                product = product.expand()
                    else:
                        return self
                    if sgm is None:
                        sgm = product
                    else:
                        sgm += product
                return sgm
            # NOTE(review): nesting of this ``return self`` was reconstructed from
            # collapsed source; confirm it belongs to this BlockMatrix sub-case
            # rather than to the whole ``A.is_Transpose`` branch.
            return self
    # A BlockMatrix on the right only: no rule here.
    if B.is_BlockMatrix:
        return self
    if B.is_Transpose and B.arg.is_BlockMatrix:
        # A @ B.arg.T == (B.arg @ A.T).T
        return (B.arg @ A.T).expand().T
    if A.is_MatProduct:
        return self
    # Extra keyword arguments for every generate_int_limit call below.
    kwargs = {'var': var, 'generator': self}

    def generate_k_limit(A, B, excludes=None, **kwargs):
        # Create the limit of the contracted index k.  It is generated from A
        # when A is a Lamda or B is not, otherwise from B; the other factor's
        # free symbols (together with ``excludes``) are excluded.
        if A.is_Lamda or not B.is_Lamda:
            if excludes:
                excludes |= B.free_symbols
            else:
                excludes = B.free_symbols
            return A.generate_int_limit(0, excludes, **kwargs)
        if excludes:
            excludes |= A.free_symbols
        else:
            excludes = A.free_symbols
        return B.generate_int_limit(0 if len(B.shape) == 1 else 1, excludes, **kwargs)

    if len(A.shape) == 1 and len(B.shape) > 1 and B.shape[-1].is_Integer:
        # vector @ matrix with an integer column count: one Sum per column.
        k_limit = generate_k_limit(A, B, **kwargs)
        k, *_ = k_limit
        columns = range(B.shape[-1])
        if A.shape[0].is_Integer:
            # Fully concrete sizes: evaluate the sums into an explicit Matrix.
            args = [Sum(A[k] * B[k, j], k_limit).doit() for j in columns]
            from sympy import Matrix
            return Matrix(tuple(args))
        args = [Sum(A[k] * B[k, j], k_limit).simplify() for j in columns]
        return BlockMatrix(*args)
    return self
def test_matrix_expression_from_index_summation():
    """Parse explicit index summations back into matrix expressions."""
    from sympy.abc import a, b, c, d
    A = MatrixSymbol("A", k, k)
    B = MatrixSymbol("B", k, k)
    C = MatrixSymbol("C", k, k)

    def parse(expr, first_index, via=MatrixExpr):
        # ``via`` picks the class whose from_index_summation is invoked.
        return via.from_index_summation(expr, first_index)

    bc_lm = ((b, 0, l - 1), (c, 0, m - 1))
    bc_k = ((b, 0, k - 1), (c, 0, k - 1))

    # Chains of products, with and without transposed operands.
    assert parse(Sum(W[a, b]*X[b, c]*Z[c, d], *bc_lm), a) == W*X*Z
    assert parse(Sum(W.T[b, a]*X[b, c]*Z[c, d], *bc_lm), a) == W*X*Z
    assert parse(Sum(A[b, a]*B[b, c]*C[c, d], *bc_k), a, MatrixSymbol) == A.T*B*C
    assert parse(Sum(A[b, a]*B[c, b]*C[c, d], *bc_k), a, MatrixSymbol) == A.T*B.T*C
    assert parse(Sum(C[c, d]*A[b, a]*B[c, b], *bc_k), a, MatrixSymbol) == A.T*B.T*C

    # Sums of matrices, alone and inside products.
    assert parse(Sum(A[a, b] + B[a, b], (a, 0, k - 1), (b, 0, k - 1)), a) == A + B
    assert parse(Sum((A[a, b] + B[a, b])*C[b, c], (b, 0, k - 1)), a) == (A + B)*C
    assert parse(Sum((A[a, b] + B[b, a])*C[b, c], (b, 0, k - 1)), a) == (A + B.T)*C

    # Repeated factors stay as explicit products.
    assert parse(Sum(A[a, b]*A[b, c]*A[c, d], *bc_k), a) == MatMul(A, A, A)
    assert parse(Sum(A[a, b]*A[b, c]*B[c, d], *bc_k), a) == MatMul(A, A, B)

    # Traces.
    assert parse(Sum(A[a, a], (a, 0, k - 1)), None) == trace(A)
    with_trace = Sum(A[a, a]*B[b, c]*C[c, d], (a, 0, k - 1), (c, 0, k - 1))
    assert parse(with_trace, b) == trace(A)*B*C

    # Summation ranges that disagree with the matrix dimensions are rejected:
    # upper bound m instead of m - 1 ...
    bad_upper = Sum(W[a, b]*X[b, c]*Z[c, d], (b, 0, l - 1), (c, 0, m))
    raises(ValueError, lambda: parse(bad_upper, a))
    # ... and lower bound 1 instead of 0.
    bad_lower = Sum(W[a, b]*X[b, c]*Z[c, d], (b, 0, l - 1), (c, 1, m - 1))
    raises(ValueError, lambda: parse(bad_lower, a))

    # Nested sums.
    nested = Sum(A[a, b]*Sum(B[b, c]*C[c, d], (c, 0, k - 1)), (b, 0, k - 1))
    assert parse(nested, a) == A*B*C

    # Kronecker delta.
    assert parse(Sum(A[a, b]*KroneckerDelta(b, c)*B[c, d], *bc_k), a) == A*B
def __rmul__(self, other):
    """Compute ``other * self``.

    A left operand without a non-empty ``shape`` is treated as a scalar and
    multiplied into every element of ``self.args`` (rebuilt with
    ``self.func``); anything else is handed to ``MatrixExpr.__rmul__``.
    """
    # Plain Python numbers (e.g. ``2 * M``) arrive here unsympified and have
    # no ``shape`` attribute; treat them as scalars instead of raising
    # AttributeError.
    if not getattr(other, 'shape', ()):
        return self.func(*(other * arg for arg in self.args))
    return MatrixExpr.__rmul__(self, other)
def _eval_domain_defined(self, x, **_):
    """Return the domain of ``x`` on which this expression is defined.

    Starts from ``MatrixExpr._eval_domain_defined`` and combines it with
    ``&`` with the ``domain_defined(x)`` of each argument.
    """
    combined = MatrixExpr._eval_domain_defined(self, x)
    for per_arg in (operand.domain_defined(x) for operand in self.args):
        combined &= per_arg
    return combined