def comm(A, B):
    """Return the commutator ``[A, B] = A*B - B*A``.

    The two products are built with ``MatMul`` and subtracted; ``doit()`` is
    not called on the result.

    Parameters
    ----------
    A, B : sympy matrices or matrix expressions with identical ``shape``.

    Raises
    ------
    AssertionError
        If ``A.shape != B.shape``.
    """
    # The two literal pieces are joined with an explicit separating space so
    # the message reads "..., but A.shape=...".
    assert A.shape == B.shape, (
        "A,B must have the same shape, but "
        f"A.shape={A.shape} != B.shape={B.shape}")
    return MatMul(A, B) - MatMul(B, A)
def hadamard_or_mul(arg1, arg2):
    """Combine two matrix-like operands according to their shapes.

    - identical shapes: element-wise ``hadamard_product(arg1, arg2)``;
    - ``arg1`` column count equals ``arg2`` row count: ``MatMul(arg1, arg2).doit()``;
    - otherwise, equal row counts: ``MatMul(arg2.T, arg1).doit()``.

    Raises
    ------
    NotImplementedError
        If none of the shape rules applies; the message names both shapes.
    """
    shape1, shape2 = arg1.shape, arg2.shape
    if shape1 == shape2:
        return hadamard_product(arg1, arg2)
    if shape1[1] == shape2[0]:
        return MatMul(arg1, arg2).doit()
    if shape1[0] == shape2[0]:
        return MatMul(arg2.T, arg1).doit()
    raise NotImplementedError(
        f"cannot combine operands of shapes {shape1} and {shape2}")
def test_lm_sym_expanded():
    """lm_sym_expanded splits a linear matrix into coefficient*variable terms."""
    m = Matrix([[0, x], [3.4 * y, 3 * x - 4.5 * y + z]])
    c = Matrix([[1.2, 0], [0, 1.2]])

    # Expected coefficient matrix times variable, one per variable of ``m``.
    x_part = MatMul(Matrix([[0.0, 1.0], [0.0, 3.0]]), x)
    y_part = MatMul(Matrix([[0.0, 0.0], [3.4, -4.5]]), y)
    z_part = MatMul(Matrix([[0.0, 0.0], [0.0, 1.0]]), z)
    const_part = Matrix([[1.2, 0.0], [0.0, 1.2]])

    expected_full = MatAdd(x_part, y_part, z_part, const_part)
    assert expected_full == lm_sym_expanded(m + c, [x, y, z])

    expected_linear = MatAdd(x_part, y_part, z_part)
    assert expected_linear == lm_sym_expanded(m, [x, y, z])

    # A matrix free of the variables is handed back as-is.
    assert const_part == lm_sym_expanded(c, [x, y, z])
def test_matrix_expression_from_index_summation():
    """Check ``from_index_summation`` on products, sums, traces, bad ranges,
    nested sums and KroneckerDelta.

    ``W``, ``X``, ``Z``, ``k``, ``l`` and ``m`` are module-level names not
    shown here.
    """
    from sympy.abc import a, b, c, d
    A = MatrixSymbol("A", k, k)
    B = MatrixSymbol("B", k, k)
    C = MatrixSymbol("C", k, k)

    # Plain matrix products, including transposed index order.
    expr = Sum(W[a, b] * X[b, c] * Z[c, d], (b, 0, l - 1), (c, 0, m - 1))
    assert MatrixExpr.from_index_summation(expr, a) == W * X * Z
    expr = Sum(W.T[b, a] * X[b, c] * Z[c, d], (b, 0, l - 1), (c, 0, m - 1))
    assert MatrixExpr.from_index_summation(expr, a) == W * X * Z
    # MatrixSymbol.from_index_summation is called here too; it is expected to
    # resolve to the same MatrixExpr implementation.
    expr = Sum(A[b, a] * B[b, c] * C[c, d], (b, 0, k - 1), (c, 0, k - 1))
    assert MatrixSymbol.from_index_summation(expr, a) == A.T * B * C
    expr = Sum(A[b, a] * B[c, b] * C[c, d], (b, 0, k - 1), (c, 0, k - 1))
    assert MatrixSymbol.from_index_summation(expr, a) == A.T * B.T * C
    # Factor order inside the summand must not matter.
    expr = Sum(C[c, d] * A[b, a] * B[c, b], (b, 0, k - 1), (c, 0, k - 1))
    assert MatrixSymbol.from_index_summation(expr, a) == A.T * B.T * C
    # Sums of matrix elements.
    expr = Sum(A[a, b] + B[a, b], (a, 0, k - 1), (b, 0, k - 1))
    assert MatrixExpr.from_index_summation(expr, a) == A + B
    expr = Sum((A[a, b] + B[a, b]) * C[b, c], (b, 0, k - 1))
    assert MatrixExpr.from_index_summation(expr, a) == (A + B) * C
    expr = Sum((A[a, b] + B[b, a]) * C[b, c], (b, 0, k - 1))
    assert MatrixExpr.from_index_summation(expr, a) == (A + B.T) * C
    # Repeated factors.
    expr = Sum(A[a, b] * A[b, c] * A[c, d], (b, 0, k - 1), (c, 0, k - 1))
    assert MatrixExpr.from_index_summation(expr, a) == MatMul(A, A, A)
    expr = Sum(A[a, b] * A[b, c] * B[c, d], (b, 0, k - 1), (c, 0, k - 1))
    assert MatrixExpr.from_index_summation(expr, a) == MatMul(A, A, B)

    # Parse the trace of a matrix:

    expr = Sum(A[a, a], (a, 0, k - 1))
    assert MatrixExpr.from_index_summation(expr, None) == trace(A)
    expr = Sum(A[a, a] * B[b, c] * C[c, d], (a, 0, k - 1), (c, 0, k - 1))
    assert MatrixExpr.from_index_summation(expr, b) == trace(A) * B * C

    # Check wrong sum ranges (should raise an exception):

    ## Case 1: 0 to m instead of 0 to m-1
    expr = Sum(W[a, b] * X[b, c] * Z[c, d], (b, 0, l - 1), (c, 0, m))
    raises(ValueError, lambda: MatrixExpr.from_index_summation(expr, a))
    ## Case 2: 1 to m-1 instead of 0 to m-1
    expr = Sum(W[a, b] * X[b, c] * Z[c, d], (b, 0, l - 1), (c, 1, m - 1))
    raises(ValueError, lambda: MatrixExpr.from_index_summation(expr, a))

    # Parse nested sums:
    expr = Sum(A[a, b] * Sum(B[b, c] * C[c, d], (c, 0, k - 1)), (b, 0, k - 1))
    assert MatrixExpr.from_index_summation(expr, a) == A * B * C

    # Test Kronecker delta:
    expr = Sum(A[a, b] * KroneckerDelta(b, c) * B[c, d], (b, 0, k - 1), (c, 0, k - 1))
    assert MatrixExpr.from_index_summation(expr, a) == A * B
def refine_MatMul(expr, assumptions):
    """Collapse adjacent ``X * X.T`` (orthogonal) and ``X * X^H`` (unitary)
    factor pairs of a matrix product into identity matrices.

    Scalar (non-matrix) arguments are kept in front of the matrix factors.
    If ``expr`` has no matrix arguments it is returned unchanged.

    >>> from sympy import MatrixSymbol, Q, assuming, refine
    >>> X = MatrixSymbol('X', 2, 2)
    >>> expr = X * X.T
    >>> print(expr)
    X*X.T
    >>> with assuming(Q.orthogonal(X)):
    ...     print(refine(expr))
    I
    """
    newargs = []
    exprargs = []
    for args in expr.args:
        if args.is_Matrix:
            exprargs.append(args)
        else:
            newargs.append(args)

    # Nothing to refine (and exprargs[0] below would raise IndexError).
    if not exprargs:
        return expr

    last = exprargs[0]
    for arg in exprargs[1:]:
        if arg == last.T and ask(Q.orthogonal(arg), assumptions):
            last = Identity(arg.shape[0])
        # A unitary U satisfies U * U^H = I, where U^H is the conjugate
        # *transpose* (adjoint); the element-wise conjugate is not enough.
        elif arg == last.adjoint() and ask(Q.unitary(last), assumptions):
            last = Identity(arg.shape[0])
        else:
            newargs.append(last)
            last = arg
    newargs.append(last)

    return MatMul(*newargs)
def wavelet_filter_to_matrix_form(lifting_filter_parameters):
    """
    Convert a :py:class:`vc2_data_tables.LiftingFilterParameters` filter
    specification into :math:`z`-domain matrix form.

    Each lifting stage becomes a 2x2 matrix: an update stage puts its
    :math:`z`-transform in the top-right entry, a predict stage in the
    bottom-left. Stages are combined by left-multiplication, so the first
    stage ends up as the right-most factor of the returned ``MatMul``.

    Raises
    ------
    ValueError
        If a stage has a type other than ``StageType.update`` or
        ``StageType.predict``.
    AssertionError
        If the specification contains no stages.
    """
    matrix = None
    for stage in lifting_filter_parameters.stages:
        stage_type, z_transform = lifting_stage_to_z_transform(stage)
        if stage_type is StageType.update:
            this_matrix = Matrix([
                [1, z_transform],
                [0, 1],
            ])
        elif stage_type is StageType.predict:
            this_matrix = Matrix([
                [1, 0],
                [z_transform, 1],
            ])
        else:
            # Without this branch an unexpected stage type would silently
            # reuse the previous stage's matrix (or hit UnboundLocalError).
            raise ValueError(f"Unsupported lifting stage type: {stage_type!r}")

        matrix = this_matrix if matrix is None else MatMul(this_matrix, matrix)

    assert matrix is not None
    return matrix
def test_linalg_placeholder_multiple_mul():
    """A row vector times two placeholder variables parses to a triple MatMul."""
    latex_src = "\\begin{pmatrix}3&-1\\end{pmatrix}\\cdot\\variable{M}\\cdot\\variable{v}"
    row = Matrix([[3, -1]])
    square = Matrix([[1, 2], [3, 4]])
    column = Matrix([1, 2])
    expected = MatMul(row, square, column)
    variables = {
        'M': Matrix([[1, 2], [3, 4]]),
        'v': Matrix([1, 2]),
    }
    assert_equal(latex_src, expected, variables)
def doit(self, **hints):
    """Evaluate this diagonal-matrix construction as far as possible.

    - If the wrapped vector is already known to be diagonal, it is returned.
    - An explicit ``MatrixBase`` vector becomes an explicit square diagonal
      matrix of the same type, with side ``max(vector.shape)``.
    - Scalar factors of a ``MatMul`` vector are pulled out front.
    - Otherwise a ``DiagMatrix`` of the vector (with one ``Transpose``
      stripped) is returned.

    ``hints`` is accepted but not used.
    """
    from sympy.assumptions import ask, Q
    from sympy import Transpose, Mul, MatMul
    from sympy import MatrixBase, eye

    vec = self._vector

    # Covers (1, 1)-shaped inputs and identity matrices, among others.
    if ask(Q.diagonal(vec)):
        return vec

    if isinstance(vec, MatrixBase):
        size = max(vec.shape)
        diagonal = eye(size)
        for idx in range(size):
            diagonal[idx, idx] = vec[idx]
        return type(vec)(diagonal)

    if vec.is_MatMul:
        mats = [factor for factor in vec.args if factor.is_Matrix]
        coeffs = [factor for factor in vec.args if factor not in mats]
        if coeffs:
            coeff = Mul.fromiter(coeffs)
            return coeff * DiagMatrix(MatMul.fromiter(mats).doit()).doit()

    if isinstance(vec, Transpose):
        vec = vec.arg
    return DiagMatrix(vec)
def _a2m_mul(*args):
    """Multiply ``args`` as matrices, or contract them as array expressions.

    When any argument is a ``_CodegenArrayAbstract``, the tensor product of
    all arguments is contracted along index pairs (1, 2), (3, 4), ... —
    presumably each argument's column with the next one's row, assuming
    every argument has rank 2. Otherwise the evaluated ``MatMul`` is returned.
    """
    if any(isinstance(arg, _CodegenArrayAbstract) for arg in args):
        product = _array_tensor_product(*args)
        pairs = [(2 * pos - 1, 2 * pos) for pos in range(1, len(args))]
        return _array_contraction(product, *pairs)
    from sympy.matrices.expressions.matmul import MatMul
    return MatMul(*args).doit()
def _a2m_mul(*args):
    """Multiply ``args`` as matrices, or contract them as array expressions.

    If at least one argument is a ``_CodegenArrayAbstract``, returns an
    ``ArrayContraction`` of their ``ArrayTensorProduct`` over index pairs
    (1, 2), (3, 4), ... (assumes rank-2 arguments). Otherwise returns the
    evaluated ``MatMul``.
    """
    if any(isinstance(arg, _CodegenArrayAbstract) for arg in args):
        product = ArrayTensorProduct(*args)
        pairs = [(2 * pos - 1, 2 * pos) for pos in range(1, len(args))]
        return ArrayContraction(product, *pairs)
    from sympy import MatMul
    return MatMul(*args).doit()
def convert_matrix_to_array(expr: MatrixExpr) -> Basic:
    """Convert a matrix expression into the equivalent array expression.

    Handled node types and their conversions:

    - ``MatMul``: a contraction of a tensor product.
    - ``MatAdd``: ``ArrayAdd``.
    - ``Transpose``: ``PermuteDims``.
    - ``Trace``: a self-contraction.
    - ``Mul`` and positive integer ``Pow``: tensor products.
    - ``MatPow`` and ``HadamardPower``: repeated products, or an
      element-wise function application when the exponent is not a
      positive integer.
    - ``HadamardProduct``: a diagonal of a tensor product.

    Any other node is returned unchanged.
    """
    if isinstance(expr, MatMul):
        args_nonmat = []
        args = []
        for arg in expr.args:
            if isinstance(arg, MatrixExpr):
                args.append(arg)
            else:
                args_nonmat.append(convert_matrix_to_array(arg))
        # Contract each matrix factor's column index with the next factor's
        # row index; the index arithmetic assumes scalar factors contribute
        # no indices.
        contractions = [(2 * i + 1, 2 * i + 2) for i in range(len(args) - 1)]
        scalar = ArrayTensorProduct.fromiter(
            args_nonmat) if args_nonmat else S.One
        if scalar == 1:
            tprod = ArrayTensorProduct(
                *[convert_matrix_to_array(arg) for arg in args])
        else:
            tprod = ArrayTensorProduct(
                scalar, *[convert_matrix_to_array(arg) for arg in args])
        return ArrayContraction(tprod, *contractions)
    elif isinstance(expr, MatAdd):
        return ArrayAdd(*[convert_matrix_to_array(arg) for arg in expr.args])
    elif isinstance(expr, Transpose):
        return PermuteDims(convert_matrix_to_array(expr.args[0]), [1, 0])
    elif isinstance(expr, Trace):
        inner_expr = convert_matrix_to_array(expr.arg)
        return ArrayContraction(inner_expr, (0, len(inner_expr.shape) - 1))
    elif isinstance(expr, Mul):
        return ArrayTensorProduct.fromiter(
            convert_matrix_to_array(i) for i in expr.args)
    elif isinstance(expr, Pow):
        base = convert_matrix_to_array(expr.base)
        # range() requires an integer; other powers are left unchanged
        # instead of raising TypeError.
        if expr.exp.is_Integer and (expr.exp > 0) == True:
            return ArrayTensorProduct.fromiter(base for i in range(expr.exp))
        else:
            return expr
    elif isinstance(expr, MatPow):
        base = convert_matrix_to_array(expr.base)
        if expr.exp.is_Integer != True:
            # NOTE(review): a non-integer matrix power is mapped to an
            # element-wise power here; confirm this is the intended meaning.
            b = symbols("b", cls=Dummy)
            return ArrayElementwiseApplyFunc(Lambda(b, b**expr.exp), base)
        elif (expr.exp > 0) == True:
            # Repeat the original matrix base: MatMul needs matrix arguments,
            # not the already-converted array form.
            return convert_matrix_to_array(
                MatMul.fromiter(expr.base for i in range(expr.exp)))
        else:
            return expr
    elif isinstance(expr, HadamardProduct):
        # Convert each factor so composite factors (e.g. A*B) become arrays.
        tp = ArrayTensorProduct.fromiter(
            convert_matrix_to_array(arg) for arg in expr.args)
        diag = [[2 * i for i in range(len(expr.args))],
                [2 * i + 1 for i in range(len(expr.args))]]
        return ArrayDiagonal(tp, *diag)
    elif isinstance(expr, HadamardPower):
        base, exp = expr.args
        if exp.is_Integer and (exp > 0) == True:
            return convert_matrix_to_array(
                HadamardProduct.fromiter(base for i in range(exp)))
        # A Hadamard power is element-wise, so any other exponent can be
        # applied entry by entry instead of crashing in range().
        d = symbols("d", cls=Dummy)
        return ArrayElementwiseApplyFunc(
            Lambda(d, d**exp), convert_matrix_to_array(base))
    else:
        return expr
def doit(self, **kwargs):
    """Rebuild this product as a ``MatMul`` and return its canonical form.

    With ``deep=True``, every factor is first evaluated with the same keyword
    arguments. Here ``deep`` defaults to False. NOTE(review): SymPy's usual
    ``doit`` convention defaults ``deep`` to True, so confirm this default is
    intentional.
    """
    factors = ([factor.doit(**kwargs) for factor in self.args]
               if kwargs.get('deep', False) else self.args)
    return canonicalize(MatMul(*factors))
def _cyclic_permute(expr): if expr.is_Trace and expr.arg.is_MatMul: prods = expr.arg.args newprods = [prods[-1], *prods[:-1]] return Trace(MatMul(*newprods)) else: print(expr) raise RuntimeError( "Only know how to cyclic permute products inside traces!")
def new_MM(self):
    """Create a random matrix-matrix multiplication exercise.

    Picks the dimensions and ``self.tag`` via ``__choose_problem_type__``,
    and stores the factors in ``self.A`` and ``self.x``. Their product is
    stored in ``self.answer``. Returns a ``Math`` object rendering
    ``A x = ?``.
    """
    init_printing()
    rows, cols, inner, self.tag = self.__choose_problem_type__()
    self.A = self.random_integer_matrix(rows, inner)
    self.x = self.random_integer_matrix(inner, cols)
    self.answer = self.A * self.x
    product_tex = latex(MatMul(self.A, self.x), mat_str="matrix")
    return Math(f"$${product_tex}=?$$")
def new_problem(self):
    """Create a random matrix-vector multiplication exercise.

    ``self.A`` gets 3-4 rows and 2-3 columns, using the bounds passed to
    ``random_integer``. ``self.x`` is a matching column vector, and
    ``self.answer`` holds their product. Returns a ``Math`` object rendering
    ``A x = ?``.
    """
    init_printing()
    rows = self.random_integer(3, 4)
    cols = self.random_integer(2, 3)
    self.A = self.random_integer_matrix(rows, cols)
    self.x = self.random_integer_matrix(cols, 1)
    self.answer = self.A * self.x
    product_tex = latex(MatMul(self.A, self.x), mat_str="matrix")
    return Math(f"$${product_tex}=?$$")
def test_wavelet_filter_to_matrix_form():
    """The Haar (no shift) filter maps to predict-matrix times update-matrix."""
    haar = tables.LIFTING_FILTERS[tables.WaveletFilters.haar_no_shift]
    actual = wavelet_filter_to_matrix_form(haar)
    predict_stage = Matrix([[1, 0], [1, 1]])
    update_stage = Matrix([[1, -Rational(1, 2)], [0, 1]])
    assert actual == MatMul(predict_stage, update_stage)
def lm_sym_expanded(linear_matrix, variables):
    """Return the matrix as a sum of coefficient matrices times variables.

    If ``linear_matrix`` mentions none of ``variables`` it is returned
    unchanged. Otherwise the result is a ``MatAdd`` of
    ``MatMul(ImmutableMatrix(coeff_i), variable_i)`` terms. The constant
    part is appended as an ``ImmutableMatrix`` when ``const.any()`` is truthy.
    """
    if not (S(linear_matrix).free_symbols & set(variables)):
        return linear_matrix

    coeffs, const = lm_sym_to_coeffs(linear_matrix, variables)
    terms = [MatMul(ImmutableMatrix(coeffs[idx]), var)
             for idx, var in enumerate(variables)]
    if const.any():
        terms.append(ImmutableMatrix(const))
    return MatAdd(*terms)
def sc_ode_to_matrix(sc_ode, op_func_map, t):
    """
    Convert a set of semiclassical equations of motion to matrix form.

    Returns ``(A_eq, A, M, -C)``:

    - ``A`` is the column of operator functions, ordered by
      ``operator_sort_by_order``.
    - ``M, C`` come from ``linear_eq_to_matrix`` applied to the right-hand
      sides.
    - ``A_eq`` is the unevaluated equation ``-dA/dt = -C + M*A``.
    """
    ops = operator_sort_by_order(sc_ode.keys())
    A = Matrix([op_func_map[op] for op in ops])

    # Replace each operator function by a plain Symbol of the same name so
    # the right-hand sides can be handed to linear_eq_to_matrix.
    subs = [(op_func_map[op], Symbol(op_func_map[op].name)) for op in ops]
    eqns = [sc_ode[op].rhs.subs(subs) for op in ops]
    plain_symbols = tuple(zip(*subs))[1]
    M, C = linear_eq_to_matrix(eqns, plain_symbols)

    rhs = Add(-C, MatMul(M, A), evaluate=False)
    A_eq = Eq(-Derivative(A, t), rhs, evaluate=False)
    return A_eq, A, M, -C
def test_convert_between_synthesis_and_analysis():
    """Converting a synthesis filter yields its inverse analysis filter."""
    # Daubechies 9/7 is chosen because it uses every type of lifting
    # operation.
    synthesis = tables.LIFTING_FILTERS[tables.WaveletFilters.daubechies_9_7]
    analysis = convert_between_synthesis_and_analysis(synthesis)

    synthesis_matrix = wavelet_filter_to_matrix_form(synthesis)
    analysis_matrix = wavelet_filter_to_matrix_form(analysis)

    # Mutually inverse filters multiply out to the 2x2 identity.
    product = MatMul(synthesis_matrix, analysis_matrix).doit()
    assert simplify(product) == Matrix([[1, 0], [0, 1]])
def doit(self, **hints):
    """Evaluate this diagonalization as far as possible.

    - If the wrapped vector is already known to be diagonal, it is returned.
    - Scalar factors of a ``MatMul`` vector are pulled out front.
    - Otherwise a ``DiagonalizeVector`` of the vector (with one
      ``Transpose`` stripped) is returned.

    ``hints`` is accepted but not used.
    """
    from sympy.assumptions import ask, Q
    from sympy import Transpose, Mul, MatMul

    vec = self._vector

    # Covers (1, 1)-shaped inputs and identity matrices, among others.
    if ask(Q.diagonal(vec)):
        return vec

    if vec.is_MatMul:
        mats = [factor for factor in vec.args if factor.is_Matrix]
        coeffs = [factor for factor in vec.args if factor not in mats]
        if coeffs:
            coeff = Mul.fromiter(coeffs)
            return coeff * DiagonalizeVector(MatMul.fromiter(mats).doit()).doit()

    if isinstance(vec, Transpose):
        vec = vec.arg
    return DiagonalizeVector(vec)
def _find_trivial_matrices_rewrite(expr: ArrayTensorProduct):
    """Fold (1, 1)-shaped matrix factors of a tensor product into a MatMul.

    Returns ``(new_expr, removed)``. ``removed`` lists the flattened index
    positions that belonged to the dropped (1, 1) factors. If no MatMul with
    a column-vector factor is found, ``(expr, [])`` is returned unchanged.
    """
    # If there are matrices of trivial shape in the tensor product (i.e. shape
    # (1, 1)), try to check if there is a suitable non-trivial MatMul where the
    # expression can be inserted.

    # For example, if "a" has shape (1, 1) and "b" has shape (k, 1), the
    # expressions "_array_tensor_product(a, b*b.T)" can be rewritten as
    # "b*a*b.T"

    trivial_matrices = []
    # Position in expr.args of the MatMul that will absorb the trivial factors.
    pos: Optional[int] = None
    # That MatMul split after its first factor with a single column (first)
    # and the remaining factors (second).
    first: Optional[MatrixExpr] = None
    second: Optional[MatrixExpr] = None
    removed: List[int] = []
    # Running offset into the flattened index list: the sum of get_rank()
    # over the arguments visited so far.
    counter: int = 0
    args: List[Optional[Basic]] = [i for i in expr.args]
    for i, arg in enumerate(expr.args):
        if isinstance(arg, MatrixExpr):
            if arg.shape == (1, 1):
                trivial_matrices.append(arg)
                # Mark for removal; its two index slots are recorded.
                args[i] = None
                removed.extend([counter, counter + 1])
            elif pos is None and isinstance(arg, MatMul):
                margs = arg.args
                for j, e in enumerate(margs):
                    if isinstance(e, MatrixExpr) and e.shape[1] == 1:
                        pos = i
                        first = MatMul.fromiter(margs[:j + 1])
                        second = MatMul.fromiter(margs[j + 1:])
                        break
        counter += get_rank(arg)
    if pos is None:
        return expr, []
    # Insert the (1, 1) factors between the column-vector prefix and the rest.
    # NOTE(review): when trivial_matrices is empty this still rebuilds the
    # MatMul with an empty middle product — confirm that is intended.
    args[pos] = (first * MatMul.fromiter(i for i in trivial_matrices) * second).doit()
    return _array_tensor_product(*[i for i in args if i is not None]), removed
def only_squares(*matrices):
    """Factor matrices only if they are square.

    Walks ``matrices`` left to right, closing a group as soon as the current
    factor's column count (``shape[-1]``) equals the row count
    (``shape[-2]``) of the group's first factor. Single-factor groups are
    returned as-is; longer groups are multiplied with ``MatMul(...).doit()``.
    Only the last two axes of ``shape`` are inspected.

    Returns an empty list when called with no matrices.

    Raises
    ------
    RuntimeError
        If the first matrix's row count differs from the last matrix's
        column count.
    """
    # An empty chain has no factors to group (and matrices[0] would raise).
    if not matrices:
        return []
    if matrices[0].shape[-2] != matrices[-1].shape[-1]:
        raise RuntimeError("Invalid matrices being multiplied")
    out = []
    start = 0
    for i, M in enumerate(matrices):
        if M.shape[-1] == matrices[start].shape[-2]:
            group = matrices[start:i + 1]
            out.append(group[0] if len(group) == 1 else MatMul(*group).doit())
            start = i + 1
    return out
def semi_classical_eqm_matrix_form(sc_eqm):
    """
    Convert a set of semiclassical equations of motion to matrix form.

    Returns ``(Equality(-dA/dt, b + M*A), A, M, b)``:

    - ``A`` is the column of operator functions.
    - ``b`` holds each right-hand side with every function set to zero.
    - ``M[row, col]`` is the coefficient of function ``col`` in equation
      ``row``. It is found by zeroing all other functions and dividing by
      that function.
    """
    ops = operator_sort_by_order(sc_eqm.op_func_map.keys())
    funcs = [sc_eqm.op_func_map[op] for op in ops]
    A = Matrix(funcs)

    all_zero = {func: 0 for func in funcs}
    b = Matrix([[sc_eqm.sc_ode[op].rhs.subs(all_zero)] for op in ops])

    def coefficient(row, op_eq, op_var):
        # Keep only the op_var term of equation op_eq, then divide it out.
        target = sc_eqm.op_func_map[op_var]
        others_zero = {func: 0 for func in (set(funcs) - set([target]))}
        linear_part = (sc_eqm.sc_ode[op_eq].rhs - b[row]).subs(others_zero)
        return (linear_part / target).expand()

    M = Matrix([[coefficient(row, op_eq, op_var) for op_var in ops]
                for row, op_eq in enumerate(ops)])

    return Equality(-Derivative(A, sc_eqm.t), b + MatMul(M, A)), A, M, b
def repl(x):
    """Distribute a matrix product over the first ``MatAdd`` factor in it.

    ``P * (S1 + S2) * Q`` becomes ``P*S1*Q + P*S2*Q``; ``MatMul`` addends are
    flattened into the surrounding product. Any later ``MatAdd`` factors stay
    in the trailing part, so they are kept rather than dropped and can be
    expanded by a further application. If ``x`` has no ``MatAdd`` factor it
    is returned unchanged.
    """
    add = None
    pre, post = [], []
    for arg in x.args:
        if add is None and arg.is_MatAdd:
            add = arg
        elif add is None:
            pre.append(arg)
        else:
            post.append(arg)

    # Nothing to distribute (previously raised UnboundLocalError on `add`).
    if add is None:
        return x

    # Flatten MatMul addends so the rebuilt products do not end up with
    # nested parentheses.
    addends = [[*addend.args] if addend.is_MatMul else [addend]
               for addend in add.args]
    return MatAdd(*[MatMul(*pre, *addend, *post) for addend in addends])
def _normalize(self):
    """Return a normalized ``Trace`` of a matrix product.

    Transposition and the cyclic property of the trace are used so that the
    smallest factor (by ``default_sort_key``) comes first and is not a
    transpose. Factors are compared with any ``Transpose`` stripped, so equal
    traces such as ``Trace(A.T*B)`` and ``Trace(B.T*A)`` normalize to the
    same form. Traces of non-products are returned unchanged.
    """
    from sympy import MatMul, Transpose, default_sort_key
    trace_arg = self.arg
    if not isinstance(trace_arg, MatMul):
        return self

    def untransposed_key(index):
        # Compare factors by their untransposed form so that X and X.T
        # compete on equal footing for the leading position.
        factor = trace_arg.args[index]
        if isinstance(factor, Transpose):
            factor = factor.arg
        return default_sort_key(factor)

    indmin = min(range(len(trace_arg.args)), key=untransposed_key)
    if isinstance(trace_arg.args[indmin], Transpose):
        # trace(P) == trace(P.T): transpose the whole product so the leading
        # factor is no longer a transpose.
        trace_arg = Transpose(trace_arg).doit()
        if not isinstance(trace_arg, MatMul):
            return Trace(trace_arg)
        indmin = min(range(len(trace_arg.args)),
                     key=lambda x: default_sort_key(trace_arg.args[x]))
    # Cyclic rotation so the selected factor comes first.
    trace_arg = MatMul.fromiter(trace_arg.args[indmin:] + trace_arg.args[:indmin])
    return Trace(trace_arg)
def merge_explicit(matmul):
    """ Merge explicit MatrixBase arguments

    Neighbouring factors that are explicit matrices (``MatrixBase``) or
    ``Number`` instances are multiplied together. Symbolic factors separate
    runs. If no argument is a ``MatrixBase``, ``matmul`` is returned as-is.

    >>> from sympy import MatrixSymbol, eye, Matrix, MatMul, pprint
    >>> from sympy.matrices.expressions.matmul import merge_explicit
    >>> A = MatrixSymbol('A', 2, 2)
    >>> B = Matrix([[1, 1], [1, 1]])
    >>> C = Matrix([[1, 2], [3, 4]])
    >>> X = MatMul(A, B, C)
    >>> pprint(X)
      [1  1] [1  2]
    A*[    ]*[    ]
      [1  1] [3  4]
    >>> pprint(merge_explicit(X))
      [4  6]
    A*[    ]
      [4  6]

    >>> X = MatMul(B, A, C)
    >>> pprint(X)
    [1  1]   [1  2]
    [    ]*A*[    ]
    [1  1]   [3  4]
    >>> pprint(merge_explicit(X))
    [1  1]   [1  2]
    [    ]*A*[    ]
    [1  1]   [3  4]
    """
    if all(not isinstance(factor, MatrixBase) for factor in matmul.args):
        return matmul

    explicit_types = (MatrixBase, Number)
    merged = []
    acc = matmul.args[0]
    for factor in matmul.args[1:]:
        if isinstance(factor, explicit_types) and isinstance(acc, explicit_types):
            # Both sides are explicit: multiply them out right away.
            acc = acc * factor
        else:
            merged.append(acc)
            acc = factor
    merged.append(acc)
    return MatMul(*merged)
def recurse_expr(expr, index_ranges={}):
    """Translate an index-summation expression into matrix-expression pieces.

    Returns a list of ``(element, indices)`` pairs. ``indices`` is ``None``
    for scalar (non-matrix) parts; otherwise it is the pair of free indices
    of a ``MatrixElement``. ``index_ranges`` maps a summation index to the
    ``(lower, upper)`` bounds taken from the enclosing ``Sum``. The mutable
    default is only read, never modified.

    NOTE(review): the body uses ``dimensions`` and ``remove_matelement``,
    which are not defined here; presumably they come from an enclosing
    function scope — confirm.
    """
    if expr.is_Mul:
        nonmatargs = []
        pos_arg = []
        pos_ind = []
        dlinks = {}
        link_ind = []
        counter = 0
        args_ind = []
        # Flatten the (element, indices) pairs produced for every factor.
        for arg in expr.args:
            retvals = recurse_expr(arg, index_ranges)
            # NOTE(review): after this assert the else-branch below is dead.
            assert isinstance(retvals, list)
            if isinstance(retvals, list):
                for i in retvals:
                    args_ind.append(i)
            else:
                args_ind.append(retvals)
        # Build the index graph: link_ind[p][s] is the (factor, slot) that
        # shares the index at slot s of factor p, or None if it is unshared.
        for arg_symbol, arg_indices in args_ind:
            if arg_indices is None:
                nonmatargs.append(arg_symbol)
                continue
            if isinstance(arg_symbol, MatrixElement):
                arg_symbol = arg_symbol.args[0]
            pos_arg.append(arg_symbol)
            pos_ind.append(arg_indices)
            link_ind.append([None]*len(arg_indices))
            for i, ind in enumerate(arg_indices):
                if ind in dlinks:
                    other_i = dlinks[ind]
                    link_ind[counter][i] = other_i
                    link_ind[other_i[0]][other_i[1]] = (counter, i)
                dlinks[ind] = (counter, i)
            counter += 1
        # Walk each chain of linked indices from a free index to the other
        # free index, collecting the factors of one matrix product ("line").
        counter2 = 0
        lines = {}
        while counter2 < len(link_ind):
            for i, e in enumerate(link_ind):
                if None in e:
                    line_start_index = (i, e.index(None))
                    break
            cur_ind_pos = line_start_index
            cur_line = []
            index1 = pos_ind[cur_ind_pos[0]][cur_ind_pos[1]]
            while True:
                d, r = cur_ind_pos
                # Factors equal to 1 (e.g. the scalar stand-in for a
                # KroneckerDelta) only link indices; they are not multiplied in.
                if pos_arg[d] != 1:
                    # Entering a factor through its second slot means it is
                    # traversed transposed.
                    if r % 2 == 1:
                        cur_line.append(transpose(pos_arg[d]))
                    else:
                        cur_line.append(pos_arg[d])
                next_ind_pos = link_ind[d][1-r]
                counter2 += 1
                # Mark as visited, there will be no `None` anymore:
                link_ind[d] = (-1, -1)
                if next_ind_pos is None:
                    index2 = pos_ind[d][1-r]
                    lines[(index1, index2)] = cur_line
                    break
                cur_ind_pos = next_ind_pos
        # NOTE(review): ret_indices is computed but never used.
        ret_indices = list(j for i in lines for j in i)
        lines = {k: MatMul.fromiter(v) if len(v) != 1 else v[0] for k, v in lines.items()}
        return [(Mul.fromiter(nonmatargs), None)] + [
            (MatrixElement(a, i, j), (i, j)) for (i, j), a in lines.items()
        ]
    elif expr.is_Add:
        # NOTE(review): index_ranges is not forwarded to the addends, so
        # their range checks are skipped — confirm this is intended.
        res = [recurse_expr(i) for i in expr.args]
        # Group addends by their sorted free-index pair; a scalar part
        # multiplies the matrix element that follows it.
        d = collections.defaultdict(list)
        for res_addend in res:
            scalar = 1
            for elem, indices in res_addend:
                if indices is None:
                    scalar = elem
                    continue
                indices = tuple(sorted(indices, key=default_sort_key))
                d[indices].append(scalar*remove_matelement(elem, *indices))
                scalar = 1
        return [(MatrixElement(Add.fromiter(v), *k), k) for k, v in d.items()]
    elif isinstance(expr, KroneckerDelta):
        # KroneckerDelta(i1, i2) acts as an identity element linking i1, i2.
        i1, i2 = expr.args
        if dimensions is not None:
            identity = Identity(dimensions[0])
        else:
            identity = S.One
        return [(MatrixElement(identity, i1, i2), (i1, i2))]
    elif isinstance(expr, MatrixElement):
        matrix_symbol, i1, i2 = expr.args
        # A summed index must run over the full dimension: 0 .. shape-1.
        if i1 in index_ranges:
            r1, r2 = index_ranges[i1]
            if r1 != 0 or matrix_symbol.shape[0] != r2+1:
                raise ValueError("index range mismatch: {0} vs. (0, {1})".format(
                    (r1, r2), matrix_symbol.shape[0]))
        if i2 in index_ranges:
            r1, r2 = index_ranges[i2]
            if r1 != 0 or matrix_symbol.shape[1] != r2+1:
                raise ValueError("index range mismatch: {0} vs. (0, {1})".format(
                    (r1, r2), matrix_symbol.shape[1]))
        # A repeated summed index (A[i, i]) is a trace.
        if (i1 == i2) and (i1 in index_ranges):
            return [(trace(matrix_symbol), None)]
        return [(MatrixElement(matrix_symbol, i1, i2), (i1, i2))]
    elif isinstance(expr, Sum):
        # The new mapping replaces (does not merge with) the outer ranges.
        return recurse_expr(
            expr.args[0],
            index_ranges={i[0]: i[1:] for i in expr.args[1:]}
        )
    else:
        return [(expr, None)]
def test_matrix_derivatives_of_traces():
    """Check derivatives of trace expressions against the numbered
    "Cookbook" examples 99-127.

    ``A``, ``B``, ``C``, ``X``, ``i``, ``j``, ``k``, ``m`` and ``n`` are
    module-level names not shown here.
    """

    expr = Trace(A) * A
    assert expr.diff(A) == Derivative(Trace(A) * A, A)
    assert expr[i, j].diff(
        A[m, n]).doit() == (KDelta(i, m) * KDelta(j, n) * Trace(A) +
                            KDelta(m, n) * A[i, j])

    ## First order:

    # Cookbook example 99:
    expr = Trace(X)
    assert expr.diff(X) == Identity(k)
    assert expr.rewrite(Sum).diff(X[m, n]).doit() == KDelta(m, n)

    # Cookbook example 100:
    expr = Trace(X * A)
    assert expr.diff(X) == A.T
    assert expr.rewrite(Sum).diff(X[m, n]).doit() == A[n, m]

    # Cookbook example 101:
    expr = Trace(A * X * B)
    assert expr.diff(X) == A.T * B.T
    assert expr.rewrite(Sum).diff(X[m, n]).doit().dummy_eq((A.T * B.T)[m, n])

    # Cookbook example 102:
    expr = Trace(A * X.T * B)
    assert expr.diff(X) == B * A

    # Cookbook example 103:
    expr = Trace(X.T * A)
    assert expr.diff(X) == A

    # Cookbook example 104:
    expr = Trace(A * X.T)
    assert expr.diff(X) == A

    # Cookbook example 105:
    # TODO: TensorProduct is not supported
    #expr = Trace(TensorProduct(A, X))
    #assert expr.diff(X) == Trace(A)*Identity(k)

    ## Second order:

    # Cookbook example 106:
    expr = Trace(X**2)
    assert expr.diff(X) == 2 * X.T

    # Cookbook example 107:
    expr = Trace(X**2 * B)
    assert expr.diff(X) == (X * B + B * X).T
    expr = Trace(MatMul(X, X, B))
    assert expr.diff(X) == (X * B + B * X).T

    # Cookbook example 108:
    expr = Trace(X.T * B * X)
    assert expr.diff(X) == B * X + B.T * X

    # Cookbook example 109:
    expr = Trace(B * X * X.T)
    assert expr.diff(X) == B * X + B.T * X

    # Cookbook example 110:
    expr = Trace(X * X.T * B)
    assert expr.diff(X) == B * X + B.T * X

    # Cookbook example 111:
    expr = Trace(X * B * X.T)
    assert expr.diff(X) == X * B.T + X * B

    # Cookbook example 112:
    expr = Trace(B * X.T * X)
    assert expr.diff(X) == X * B.T + X * B

    # Cookbook example 113:
    expr = Trace(X.T * X * B)
    assert expr.diff(X) == X * B.T + X * B

    # Cookbook example 114:
    expr = Trace(A * X * B * X)
    assert expr.diff(X) == A.T * X.T * B.T + B.T * X.T * A.T

    # Cookbook example 115:
    expr = Trace(X.T * X)
    assert expr.diff(X) == 2 * X
    expr = Trace(X * X.T)
    assert expr.diff(X) == 2 * X

    # Cookbook example 116:
    expr = Trace(B.T * X.T * C * X * B)
    assert expr.diff(X) == C.T * X * B * B.T + C * X * B * B.T

    # Cookbook example 117:
    expr = Trace(X.T * B * X * C)
    assert expr.diff(X) == B * X * C + B.T * X * C.T

    # Cookbook example 118:
    expr = Trace(A * X * B * X.T * C)
    assert expr.diff(X) == A.T * C.T * X * B.T + C * A * X * B

    # Cookbook example 119:
    expr = Trace((A * X * B + C) * (A * X * B + C).T)
    assert expr.diff(X) == 2 * A.T * (A * X * B + C) * B.T

    # Cookbook example 120:
    # TODO: no support for TensorProduct.
    # expr = Trace(TensorProduct(X, X))
    # expr = Trace(X)*Trace(X)
    # expr.diff(X) == 2*Trace(X)*Identity(k)

    # Higher Order

    # Cookbook example 121:
    # Only constructed; the derivative check is disabled below.
    expr = Trace(X**k)
    #assert expr.diff(X) == k*(X**(k-1)).T

    # Cookbook example 122:
    # Only constructed; the derivative check is disabled below.
    expr = Trace(A * X**k)
    #assert expr.diff(X) == # Needs indices

    # Cookbook example 123:
    expr = Trace(B.T * X.T * C * X * X.T * C * X * B)
    assert expr.diff(
        X
    ) == C * X * X.T * C * X * B * B.T + C.T * X * B * B.T * X.T * C.T * X + C * X * B * B.T * X.T * C * X + C.T * X * X.T * C.T * X * B * B.T

    # Other

    # Cookbook example 124:
    expr = Trace(A * X**(-1) * B)
    assert expr.diff(X) == -Inverse(X).T * A.T * B.T * Inverse(X).T

    # Cookbook example 125:
    expr = Trace(Inverse(X.T * C * X) * A)
    # Warning: result in the cookbook is equivalent if B and C are symmetric:
    assert expr.diff(X) == -X.inv().T * A.T * X.inv() * C.inv().T * X.inv(
    ).T - X.inv().T * A * X.inv() * C.inv() * X.inv().T

    # Cookbook example 126:
    expr = Trace((X.T * C * X).inv() * (X.T * B * X))
    assert expr.diff(X) == -2 * C * X * (X.T * C * X).inv() * X.T * B * X * (
        X.T * C * X).inv() + 2 * B * X * (X.T * C * X).inv()

    # Cookbook example 127:
    expr = Trace((A + X.T * C * X).inv() * (X.T * B * X))
    # Warning: result in the cookbook is equivalent if B and C are symmetric:
    assert expr.diff(X) == B * X * Inverse(A + X.T * C * X) - C * X * Inverse(
        A + X.T * C * X) * X.T * B * X * Inverse(A + X.T * C * X) - C.T * X * Inverse(
        A.T + (C * X).T * X) * X.T * B.T * X * Inverse(
        A.T + (C * X).T * X) + B.T * X * Inverse(A.T + (C * X).T * X)
def __neg__(self):
    """Return ``-self``: a ``MatMul`` with a -1 factor, evaluated via ``doit``."""
    minus_self = MatMul(S.NegativeOne, self)
    return minus_self.doit()
def recurse_expr(expr, index_ranges={}):
    """Translate an index-summation expression into matrix-expression pieces.

    Returns a list of ``(element, indices)`` pairs. ``indices`` is ``None``
    for scalar (non-matrix) parts; otherwise it is the pair of free indices
    of a ``MatrixElement``. ``index_ranges`` maps a summation index to the
    ``(lower, upper)`` bounds taken from the enclosing ``Sum``. The mutable
    default is only read, never modified.

    NOTE(review): the body uses ``dimensions`` and ``remove_matelement``,
    which are not defined here; presumably they come from an enclosing
    function scope — confirm.
    """
    if expr.is_Mul:
        nonmatargs = []
        pos_arg = []
        pos_ind = []
        dlinks = {}
        link_ind = []
        counter = 0
        args_ind = []
        # Flatten the (element, indices) pairs produced for every factor.
        for arg in expr.args:
            retvals = recurse_expr(arg, index_ranges)
            # NOTE(review): after this assert the else-branch below is dead.
            assert isinstance(retvals, list)
            if isinstance(retvals, list):
                for i in retvals:
                    args_ind.append(i)
            else:
                args_ind.append(retvals)
        # Build the index graph: link_ind[p][s] is the (factor, slot) that
        # shares the index at slot s of factor p, or None if it is unshared.
        for arg_symbol, arg_indices in args_ind:
            if arg_indices is None:
                nonmatargs.append(arg_symbol)
                continue
            if isinstance(arg_symbol, MatrixElement):
                arg_symbol = arg_symbol.args[0]
            pos_arg.append(arg_symbol)
            pos_ind.append(arg_indices)
            link_ind.append([None]*len(arg_indices))
            for i, ind in enumerate(arg_indices):
                if ind in dlinks:
                    other_i = dlinks[ind]
                    link_ind[counter][i] = other_i
                    link_ind[other_i[0]][other_i[1]] = (counter, i)
                dlinks[ind] = (counter, i)
            counter += 1
        # Walk each chain of linked indices from a free index to the other
        # free index, collecting the factors of one matrix product ("line").
        counter2 = 0
        lines = {}
        while counter2 < len(link_ind):
            for i, e in enumerate(link_ind):
                if None in e:
                    line_start_index = (i, e.index(None))
                    break
            cur_ind_pos = line_start_index
            cur_line = []
            index1 = pos_ind[cur_ind_pos[0]][cur_ind_pos[1]]
            while True:
                d, r = cur_ind_pos
                # Factors equal to 1 (e.g. the scalar stand-in for a
                # KroneckerDelta) only link indices; they are not multiplied in.
                if pos_arg[d] != 1:
                    # Entering a factor through its second slot means it is
                    # traversed transposed.
                    if r % 2 == 1:
                        cur_line.append(transpose(pos_arg[d]))
                    else:
                        cur_line.append(pos_arg[d])
                next_ind_pos = link_ind[d][1-r]
                counter2 += 1
                # Mark as visited, there will be no `None` anymore:
                link_ind[d] = (-1, -1)
                if next_ind_pos is None:
                    index2 = pos_ind[d][1-r]
                    lines[(index1, index2)] = cur_line
                    break
                cur_ind_pos = next_ind_pos
        # NOTE(review): ret_indices is computed but never used.
        ret_indices = list(j for i in lines for j in i)
        lines = {k: MatMul.fromiter(v) if len(v) != 1 else v[0] for k, v in lines.items()}
        return [(Mul.fromiter(nonmatargs), None)] + [
            (MatrixElement(a, i, j), (i, j)) for (i, j), a in lines.items()
        ]
    elif expr.is_Add:
        # NOTE(review): index_ranges is not forwarded to the addends, so
        # their range checks are skipped — confirm this is intended.
        res = [recurse_expr(i) for i in expr.args]
        # Group addends by their sorted free-index pair; a scalar part
        # multiplies the matrix element that follows it.
        d = collections.defaultdict(list)
        for res_addend in res:
            scalar = 1
            for elem, indices in res_addend:
                if indices is None:
                    scalar = elem
                    continue
                indices = tuple(sorted(indices, key=default_sort_key))
                d[indices].append(scalar*remove_matelement(elem, *indices))
                scalar = 1
        return [(MatrixElement(Add.fromiter(v), *k), k) for k, v in d.items()]
    elif isinstance(expr, KroneckerDelta):
        # KroneckerDelta(i1, i2) acts as an identity element linking i1, i2.
        i1, i2 = expr.args
        if dimensions is not None:
            identity = Identity(dimensions[0])
        else:
            identity = S.One
        return [(MatrixElement(identity, i1, i2), (i1, i2))]
    elif isinstance(expr, MatrixElement):
        matrix_symbol, i1, i2 = expr.args
        # A summed index must run over the full dimension: 0 .. shape-1.
        if i1 in index_ranges:
            r1, r2 = index_ranges[i1]
            if r1 != 0 or matrix_symbol.shape[0] != r2+1:
                raise ValueError("index range mismatch: {0} vs. (0, {1})".format(
                    (r1, r2), matrix_symbol.shape[0]))
        if i2 in index_ranges:
            r1, r2 = index_ranges[i2]
            if r1 != 0 or matrix_symbol.shape[1] != r2+1:
                raise ValueError("index range mismatch: {0} vs. (0, {1})".format(
                    (r1, r2), matrix_symbol.shape[1]))
        # A repeated summed index (A[i, i]) is a trace.
        if (i1 == i2) and (i1 in index_ranges):
            return [(trace(matrix_symbol), None)]
        return [(MatrixElement(matrix_symbol, i1, i2), (i1, i2))]
    elif isinstance(expr, Sum):
        # The new mapping replaces (does not merge with) the outer ranges.
        return recurse_expr(
            expr.args[0],
            index_ranges={i[0]: i[1:] for i in expr.args[1:]}
        )
    else:
        return [(expr, None)]
def as_coeff_mmul(self):
    """Split into ``(coefficient, matrix product)``; the coefficient is always 1."""
    coeff = 1
    product = MatMul(self)
    return coeff, product
def recurse_expr(expr, index_ranges={}):
    """Translate an index-summation expression into a matrix expression.

    This variant returns a single ``(expression, indices)`` tuple.
    ``indices`` is ``None`` for scalars, otherwise the ``(first, last)``
    free index pair. ``index_ranges`` maps a summation index to the
    ``(lower, upper)`` bounds from the enclosing ``Sum``. The mutable default
    is only read, never modified.
    """
    if expr.is_Mul:
        nonmatargs = []
        matargs = []
        pos_arg = []
        pos_ind = []
        dlinks = {}
        link_ind = []
        counter = 0
        # Flat slot layout: slot 2*k is the row index of the k-th matrix
        # factor and 2*k+1 its column index. link_ind[s] is the slot sharing
        # the same index, or None when that index is free.
        for arg in expr.args:
            arg_symbol, arg_indices = recurse_expr(arg, index_ranges)
            if arg_indices is None:
                nonmatargs.append(arg_symbol)
                continue
            i1, i2 = arg_indices
            pos_arg.append(arg_symbol)
            pos_ind.append(i1)
            pos_ind.append(i2)
            link_ind.extend([None, None])
            if i1 in dlinks:
                other_i1 = dlinks[i1]
                link_ind[2*counter] = other_i1
                link_ind[other_i1] = 2*counter
            if i2 in dlinks:
                other_i2 = dlinks[i2]
                link_ind[2*counter + 1] = other_i2
                link_ind[other_i2] = 2*counter + 1
            dlinks[i1] = 2*counter
            dlinks[i2] = 2*counter + 1
            counter += 1
        # Start at the first free slot. list.index raises ValueError if every
        # index is contracted.
        # NOTE(review): only one chain is followed, so factors not on it are
        # dropped — confirm callers never pass more than one chain.
        cur_ind_pos = link_ind.index(None)
        first_index = pos_ind[cur_ind_pos]
        while True:
            d = cur_ind_pos // 2
            r = cur_ind_pos % 2
            # Entering a factor through its column slot means it appears
            # transposed in the product.
            if r == 1:
                matargs.append(transpose(pos_arg[d]))
            else:
                matargs.append(pos_arg[d])
            next_ind_pos = link_ind[2*d + 1 - r]
            if next_ind_pos is None:
                last_index = pos_ind[2*d + 1 - r]
                break
            cur_ind_pos = next_ind_pos
        return Mul.fromiter(nonmatargs)*MatMul.fromiter(matargs), (first_index, last_index)
    elif expr.is_Add:
        res = [recurse_expr(i) for i in expr.args]
        # Orient each addend so its index pair is in default_sort_key order,
        # transposing the matrix when the pair had to be swapped.
        res = [
            ((transpose(i), (j[1], j[0]))
                if default_sort_key(j[0]) > default_sort_key(j[1])
            else (i, j))
            for (i, j) in res
        ]
        addends, last_indices = zip(*res)
        last_indices = list(set(last_indices))
        if len(last_indices) > 1:
            # NOTE(review): debug print left in before raising.
            print(last_indices)
            raise ValueError("incompatible summation")
        return MatAdd.fromiter(addends), last_indices[0]
    elif isinstance(expr, KroneckerDelta):
        # KroneckerDelta only links its two indices; it contributes factor 1.
        i1, i2 = expr.args
        return S.One, (i1, i2)
    elif isinstance(expr, MatrixElement):
        matrix_symbol, i1, i2 = expr.args
        # A summed index must run over the full dimension: 0 .. shape-1.
        if i1 in index_ranges:
            r1, r2 = index_ranges[i1]
            if r1 != 0 or matrix_symbol.shape[0] != r2+1:
                raise ValueError("index range mismatch: {0} vs. (0, {1})".format(
                    (r1, r2), matrix_symbol.shape[0]))
        if i2 in index_ranges:
            r1, r2 = index_ranges[i2]
            if r1 != 0 or matrix_symbol.shape[1] != r2+1:
                raise ValueError("index range mismatch: {0} vs. (0, {1})".format(
                    (r1, r2), matrix_symbol.shape[1]))
        # A repeated summed index (A[i, i]) is a trace.
        if (i1 == i2) and (i1 in index_ranges):
            return trace(matrix_symbol), None
        return matrix_symbol, (i1, i2)
    elif isinstance(expr, Sum):
        # The new mapping replaces (does not merge with) the outer ranges.
        return recurse_expr(
            expr.args[0],
            index_ranges={i[0]: i[1:] for i in expr.args[1:]}
        )
    else:
        return expr, None