def doit(self, **hints):
    """Evaluate the diagonal-matrix wrapper around ``self._vector``.

    The cases are tried in this order:

    1. If the assumptions system reports the wrapped object as diagonal,
       it is returned unchanged.
    2. If it is an explicit matrix (``MatrixBase``), a matrix of the same
       type is built with the entries placed on the main diagonal.
    3. If it is a ``MatMul`` containing non-matrix factors, those factors
       are pulled out in front of a re-evaluated ``DiagMatrix`` of the
       remaining matrix factors.
    4. Otherwise a ``DiagMatrix`` is returned, after stripping one outer
       ``Transpose`` if present.

    ``hints`` is accepted but not consulted by this method.
    """
    from sympy.assumptions import ask, Q
    from sympy import Transpose, Mul, MatMul
    from sympy import MatrixBase, eye

    vec = self._vector

    # Covers shape (1, 1) and identity matrices, among others.
    if ask(Q.diagonal(vec)):
        return vec

    if isinstance(vec, MatrixBase):
        # Start from an identity of the longer dimension and overwrite the
        # diagonal with the vector entries.
        dense = eye(max(vec.shape))
        for k in range(dense.shape[0]):
            dense[k, k] = vec[k]
        return type(vec)(dense)

    if vec.is_MatMul:
        matrix_factors = [f for f in vec.args if f.is_Matrix]
        scalar_factors = [f for f in vec.args if f not in matrix_factors]
        if scalar_factors:
            coefficient = Mul.fromiter(scalar_factors)
            inner = MatMul.fromiter(matrix_factors).doit()
            return coefficient * DiagMatrix(inner).doit()

    if isinstance(vec, Transpose):
        vec = vec.arg
    return DiagMatrix(vec)
def convert_matrix_to_array(expr: MatrixExpr) -> Basic:
    """Translate a matrix expression into an array expression.

    The expression type selects the translation:

    * ``MatMul``: tensor product of the converted factors, contracted over
      the adjacent index pairs; non-matrix factors are gathered into a
      leading tensor product.
    * ``MatAdd``: ``ArrayAdd`` of the converted addends.
    * ``Transpose``: ``PermuteDims`` with permutation ``[1, 0]``.
    * ``Trace``: contraction of the first and last axes of the converted
      argument.
    * ``Mul``: tensor product of the converted factors.
    * ``Pow`` / ``MatPow``: repeated product when the exponent is known to
      be positive; ``MatPow`` with a non-integer exponent becomes an
      element-wise function application.
    * ``HadamardProduct`` / ``HadamardPower``: ``ArrayDiagonal`` over a
      tensor product.

    Anything else, and the powers that cannot be expanded, is returned
    unchanged.
    """
    if isinstance(expr, MatMul):
        matrix_factors = []
        other_factors = []
        for factor in expr.args:
            if isinstance(factor, MatrixExpr):
                matrix_factors.append(factor)
            else:
                other_factors.append(convert_matrix_to_array(factor))
        # Column index of each matrix factor is contracted with the row
        # index of the next one.
        contractions = [
            (2 * k + 1, 2 * k + 2) for k in range(len(matrix_factors) - 1)
        ]
        if other_factors:
            scalar = ArrayTensorProduct.fromiter(other_factors)
        else:
            scalar = S.One
        # A unit scalar is left out of the tensor product altogether.
        leading = () if scalar == 1 else (scalar,)
        converted = [convert_matrix_to_array(m) for m in matrix_factors]
        tprod = ArrayTensorProduct(*leading, *converted)
        return ArrayContraction(tprod, *contractions)

    if isinstance(expr, MatAdd):
        addends = [convert_matrix_to_array(term) for term in expr.args]
        return ArrayAdd(*addends)

    if isinstance(expr, Transpose):
        transposed = convert_matrix_to_array(expr.args[0])
        return PermuteDims(transposed, [1, 0])

    if isinstance(expr, Trace):
        inner_expr = convert_matrix_to_array(expr.arg)
        last_axis = len(inner_expr.shape) - 1
        return ArrayContraction(inner_expr, (0, last_axis))

    if isinstance(expr, Mul):
        return ArrayTensorProduct.fromiter(
            convert_matrix_to_array(factor) for factor in expr.args
        )

    if isinstance(expr, Pow):
        base = convert_matrix_to_array(expr.base)
        # `== True` is deliberate: the comparison may stay symbolic.
        if (expr.exp > 0) == True:
            return ArrayTensorProduct.fromiter(
                base for _ in range(expr.exp)
            )
        return expr

    if isinstance(expr, MatPow):
        base = convert_matrix_to_array(expr.base)
        if expr.exp.is_Integer != True:
            b = symbols("b", cls=Dummy)
            elementwise = Lambda(b, b**expr.exp)
            # NOTE(review): `base` is already converted and is converted
            # a second time here; kept exactly as in the original.
            return ArrayElementwiseApplyFunc(
                elementwise, convert_matrix_to_array(base)
            )
        if (expr.exp > 0) == True:
            repeated = MatMul.fromiter(base for _ in range(expr.exp))
            return convert_matrix_to_array(repeated)
        return expr

    if isinstance(expr, HadamardProduct):
        count = len(expr.args)
        tp = ArrayTensorProduct.fromiter(expr.args)
        # Diagonalize all row indices together and all column indices
        # together.
        row_axes = list(range(0, 2 * count, 2))
        col_axes = list(range(1, 2 * count, 2))
        return ArrayDiagonal(tp, row_axes, col_axes)

    if isinstance(expr, HadamardPower):
        base, exp = expr.args
        repeated = HadamardProduct.fromiter(base for _ in range(exp))
        return convert_matrix_to_array(repeated)

    return expr
def _find_trivial_matrices_rewrite(expr: ArrayTensorProduct):
    """Fold shape-(1, 1) matrices of a tensor product into a ``MatMul``.

    If the tensor product contains matrices of trivial shape ``(1, 1)``,
    look for a non-trivial ``MatMul`` argument into which they can be
    inserted. For example, if ``a`` has shape ``(1, 1)`` and ``b`` has shape
    ``(k, 1)``, ``_array_tensor_product(a, b*b.T)`` can be rewritten as
    ``b*a*b.T``.

    Returns a pair ``(new_expr, removed)``. ``removed`` lists the index
    positions that belonged to the trivial matrices. When no suitable
    ``MatMul`` is found, the result is ``(expr, [])``.
    """
    unit_matrices = []
    pos: Optional[int] = None
    first: Optional[MatrixExpr] = None
    second: Optional[MatrixExpr] = None
    removed: List[int] = []
    # Running index offset of the current argument within the product.
    offset: int = 0
    args: List[Optional[Basic]] = list(expr.args)

    for idx, arg in enumerate(expr.args):
        if isinstance(arg, MatrixExpr):
            if arg.shape == (1, 1):
                unit_matrices.append(arg)
                args[idx] = None
                removed.extend([offset, offset + 1])
            elif pos is None and isinstance(arg, MatMul):
                factors = arg.args
                # Split after the first factor that has a single column;
                # the trivial matrices are inserted at that point.
                for split, factor in enumerate(factors):
                    if isinstance(factor, MatrixExpr) and factor.shape[1] == 1:
                        pos = idx
                        first = MatMul.fromiter(factors[:split + 1])
                        second = MatMul.fromiter(factors[split + 1:])
                        break
        offset += get_rank(arg)

    if pos is None:
        return expr, []

    middle = MatMul.fromiter(unit_matrices)
    args[pos] = (first * middle * second).doit()
    kept = [a for a in args if a is not None]
    return _array_tensor_product(*kept), removed
def doit(self, **hints):
    """Evaluate the ``DiagonalizeVector`` wrapper around ``self._vector``.

    * If the assumptions system reports the wrapped object as diagonal, it
      is returned unchanged.
    * If it is a ``MatMul`` containing non-matrix factors, those factors are
      moved in front of a re-evaluated ``DiagonalizeVector`` of the matrix
      factors.
    * Otherwise a ``DiagonalizeVector`` is returned, after stripping one
      outer ``Transpose`` if present.

    ``hints`` is accepted but not consulted by this method.
    """
    from sympy.assumptions import ask, Q
    from sympy import Transpose, Mul, MatMul

    operand = self._vector

    # Covers shape (1, 1) and identity matrices, among others.
    if ask(Q.diagonal(operand)):
        return operand

    if operand.is_MatMul:
        matrix_part = [term for term in operand.args if term.is_Matrix]
        scalar_part = [term for term in operand.args if term not in matrix_part]
        if scalar_part:
            prefactor = Mul.fromiter(scalar_part)
            product = MatMul.fromiter(matrix_part).doit()
            return prefactor * DiagonalizeVector(product).doit()

    if isinstance(operand, Transpose):
        operand = operand.arg
    return DiagonalizeVector(operand)
def _normalize(self):
    """Return a normalized form of the trace of a matrix product.

    Uses the transposition and cyclic properties of traces so that the
    factors of the matrix product are rotated to start at the smallest one
    (by ``default_sort_key``) and that first factor is not a transposition.
    When the argument is not a ``MatMul``, ``self`` is returned unchanged.
    """
    from sympy import MatMul, Transpose, default_sort_key

    def smallest_position(factors):
        # Index of the first factor with the minimal sort key.
        return min(
            range(len(factors)),
            key=lambda n: default_sort_key(factors[n]),
        )

    product = self.arg
    if not isinstance(product, MatMul):
        return self

    start = smallest_position(product.args)
    if isinstance(product.args[start], Transpose):
        # Transposing the product leaves the trace unchanged; the position
        # of the smallest factor is then recomputed.
        product = Transpose(product).doit()
        start = smallest_position(product.args)

    rotated = product.args[start:] + product.args[:start]
    return Trace(MatMul.fromiter(rotated))
def recurse_expr(expr, index_ranges=None):
    """Split an indexed expression into ``(expression, indices)`` pairs.

    Parameters
    ----------
    expr
        The expression to analyze.
    index_ranges
        Optional mapping ``index -> (lower, upper)`` describing summation
        ranges. It is only read, never modified. ``None`` (the default) is
        treated as an empty mapping.

    Returns
    -------
    list
        A list of ``(expression, indices)`` tuples. ``indices`` is ``None``
        for parts that carry no matrix indices and an ``(i, j)`` pair
        otherwise.

    Raises
    ------
    ValueError
        If a summed index of a ``MatrixElement`` has a range that does not
        match ``(0, dim - 1)`` for the corresponding matrix dimension.

    Notes
    -----
    ``dimensions`` and ``remove_matelement`` are not defined in this block;
    presumably they come from the enclosing scope -- confirm.
    """
    if index_ranges is None:
        index_ranges = {}

    if expr.is_Mul:
        nonmatargs = []
        pos_arg = []    # matrix operand of every indexed factor
        pos_ind = []    # index tuple of every indexed factor
        dlinks = {}     # index -> last (factor, slot) where it was seen
        link_ind = []   # per factor: linked (factor, slot) per slot, or None
        counter = 0
        args_ind = []
        for arg in expr.args:
            retvals = recurse_expr(arg, index_ranges)
            assert isinstance(retvals, list)
            args_ind.extend(retvals)
        for arg_symbol, arg_indices in args_ind:
            if arg_indices is None:
                nonmatargs.append(arg_symbol)
                continue
            if isinstance(arg_symbol, MatrixElement):
                arg_symbol = arg_symbol.args[0]
            pos_arg.append(arg_symbol)
            pos_ind.append(arg_indices)
            link_ind.append([None]*len(arg_indices))
            for i, ind in enumerate(arg_indices):
                if ind in dlinks:
                    # Connect this slot with the previous occurrence of the
                    # same index, in both directions.
                    other_i = dlinks[ind]
                    link_ind[counter][i] = other_i
                    link_ind[other_i[0]][other_i[1]] = (counter, i)
                dlinks[ind] = (counter, i)
            counter += 1
        counter2 = 0
        lines = {}
        while counter2 < len(link_ind):
            # A line starts at a factor that still has an unlinked slot.
            for i, e in enumerate(link_ind):
                if None in e:
                    line_start_index = (i, e.index(None))
                    break
            cur_ind_pos = line_start_index
            cur_line = []
            index1 = pos_ind[cur_ind_pos[0]][cur_ind_pos[1]]
            while True:
                d, r = cur_ind_pos
                if pos_arg[d] != 1:
                    # A factor entered through its second slot is
                    # transposed.
                    if r % 2 == 1:
                        cur_line.append(transpose(pos_arg[d]))
                    else:
                        cur_line.append(pos_arg[d])
                next_ind_pos = link_ind[d][1-r]
                counter2 += 1
                # Mark as visited, there will be no `None` anymore:
                link_ind[d] = (-1, -1)
                if next_ind_pos is None:
                    index2 = pos_ind[d][1-r]
                    lines[(index1, index2)] = cur_line
                    break
                cur_ind_pos = next_ind_pos
        lines = {
            k: MatMul.fromiter(v) if len(v) != 1 else v[0]
            for k, v in lines.items()
        }
        return [(Mul.fromiter(nonmatargs), None)] + [
            (MatrixElement(a, i, j), (i, j)) for (i, j), a in lines.items()
        ]
    elif expr.is_Add:
        # NOTE(review): index_ranges is not forwarded to the addends, as in
        # the original code -- confirm whether that is intended.
        res = [recurse_expr(i) for i in expr.args]
        d = collections.defaultdict(list)
        for res_addend in res:
            scalar = 1
            for elem, indices in res_addend:
                if indices is None:
                    # Scalar coefficient, applied to the next indexed term.
                    scalar = elem
                    continue
                indices = tuple(sorted(indices, key=default_sort_key))
                d[indices].append(scalar*remove_matelement(elem, *indices))
                scalar = 1
        return [(MatrixElement(Add.fromiter(v), *k), k) for k, v in d.items()]
    elif isinstance(expr, KroneckerDelta):
        i1, i2 = expr.args
        if dimensions is not None:
            identity = Identity(dimensions[0])
        else:
            identity = S.One
        return [(MatrixElement(identity, i1, i2), (i1, i2))]
    elif isinstance(expr, MatrixElement):
        matrix_symbol, i1, i2 = expr.args
        if i1 in index_ranges:
            r1, r2 = index_ranges[i1]
            if r1 != 0 or matrix_symbol.shape[0] != r2+1:
                raise ValueError("index range mismatch: {0} vs. (0, {1})".format(
                    (r1, r2), matrix_symbol.shape[0]))
        if i2 in index_ranges:
            r1, r2 = index_ranges[i2]
            if r1 != 0 or matrix_symbol.shape[1] != r2+1:
                raise ValueError("index range mismatch: {0} vs. (0, {1})".format(
                    (r1, r2), matrix_symbol.shape[1]))
        if (i1 == i2) and (i1 in index_ranges):
            # A summed diagonal element is a trace.
            return [(trace(matrix_symbol), None)]
        return [(MatrixElement(matrix_symbol, i1, i2), (i1, i2))]
    elif isinstance(expr, Sum):
        return recurse_expr(
            expr.args[0],
            index_ranges={i[0]: i[1:] for i in expr.args[1:]}
        )
    else:
        return [(expr, None)]
def recurse_expr(expr, index_ranges=None):
    """Convert an indexed expression into a matrix expression and indices.

    Parameters
    ----------
    expr
        The expression to analyze.
    index_ranges
        Optional mapping ``index -> (lower, upper)`` describing summation
        ranges. It is only read, never modified. ``None`` (the default) is
        treated as an empty mapping.

    Returns
    -------
    tuple
        ``(expression, indices)``. ``indices`` is ``None`` when the result
        carries no free matrix indices, otherwise a ``(first, last)`` pair.

    Raises
    ------
    ValueError
        If the addends of a sum do not share the same index pair
        ("incompatible summation"), or if a summed index of a
        ``MatrixElement`` has a range that does not match ``(0, dim - 1)``
        for the corresponding matrix dimension.
    """
    if index_ranges is None:
        index_ranges = {}

    if expr.is_Mul:
        nonmatargs = []
        matargs = []
        pos_arg = []    # matrix operand of every indexed factor
        pos_ind = []    # flat list: two indices per indexed factor
        dlinks = {}     # index -> last flat position where it was seen
        link_ind = []   # flat position -> linked flat position, or None
        counter = 0
        for arg in expr.args:
            arg_symbol, arg_indices = recurse_expr(arg, index_ranges)
            if arg_indices is None:
                nonmatargs.append(arg_symbol)
                continue
            i1, i2 = arg_indices
            pos_arg.append(arg_symbol)
            pos_ind.append(i1)
            pos_ind.append(i2)
            link_ind.extend([None, None])
            # Connect each slot with the previous occurrence of the same
            # index, in both directions.
            if i1 in dlinks:
                other_i1 = dlinks[i1]
                link_ind[2*counter] = other_i1
                link_ind[other_i1] = 2*counter
            if i2 in dlinks:
                other_i2 = dlinks[i2]
                link_ind[2*counter + 1] = other_i2
                link_ind[other_i2] = 2*counter + 1
            dlinks[i1] = 2*counter
            dlinks[i2] = 2*counter + 1
            counter += 1
        # The chain starts at the first slot that has no partner.
        cur_ind_pos = link_ind.index(None)
        first_index = pos_ind[cur_ind_pos]
        while True:
            d = cur_ind_pos // 2
            r = cur_ind_pos % 2
            # A factor entered through its second slot is transposed.
            if r == 1:
                matargs.append(transpose(pos_arg[d]))
            else:
                matargs.append(pos_arg[d])
            next_ind_pos = link_ind[2*d + 1 - r]
            if next_ind_pos is None:
                last_index = pos_ind[2*d + 1 - r]
                break
            cur_ind_pos = next_ind_pos
        return (
            Mul.fromiter(nonmatargs)*MatMul.fromiter(matargs),
            (first_index, last_index),
        )
    elif expr.is_Add:
        res = [recurse_expr(i) for i in expr.args]
        # Put every addend's index pair into sorted order, transposing the
        # addend when its indices have to be swapped.
        res = [
            (
                (transpose(i), (j[1], j[0]))
                if default_sort_key(j[0]) > default_sort_key(j[1])
                else (i, j)
            )
            for (i, j) in res
        ]
        addends, last_indices = zip(*res)
        last_indices = list(set(last_indices))
        if len(last_indices) > 1:
            raise ValueError("incompatible summation")
        return MatAdd.fromiter(addends), last_indices[0]
    elif isinstance(expr, KroneckerDelta):
        i1, i2 = expr.args
        return S.One, (i1, i2)
    elif isinstance(expr, MatrixElement):
        matrix_symbol, i1, i2 = expr.args
        if i1 in index_ranges:
            r1, r2 = index_ranges[i1]
            if r1 != 0 or matrix_symbol.shape[0] != r2+1:
                raise ValueError("index range mismatch: {0} vs. (0, {1})".format(
                    (r1, r2), matrix_symbol.shape[0]))
        if i2 in index_ranges:
            r1, r2 = index_ranges[i2]
            if r1 != 0 or matrix_symbol.shape[1] != r2+1:
                raise ValueError("index range mismatch: {0} vs. (0, {1})".format(
                    (r1, r2), matrix_symbol.shape[1]))
        if (i1 == i2) and (i1 in index_ranges):
            # A summed diagonal element is a trace.
            return trace(matrix_symbol), None
        return matrix_symbol, (i1, i2)
    elif isinstance(expr, Sum):
        return recurse_expr(
            expr.args[0],
            index_ranges={i[0]: i[1:] for i in expr.args[1:]}
        )
    else:
        return expr, None