Пример #1
0
    def doit(self, **hints):
        """Evaluate the diagonal-matrix wrapper around ``self._vector``.

        The operand is handled by the first matching rule:

        * if ``ask(Q.diagonal(...))`` holds for it, it is returned unchanged;
        * an explicit matrix (``MatrixBase``) is expanded into an explicit
          matrix of the same type with its entries placed on the diagonal;
        * a ``MatMul`` containing non-matrix factors has those factors pulled
          out in front of the evaluated ``DiagMatrix`` of the matrix factors;
        * otherwise a ``DiagMatrix`` is returned, built on the argument of
          the operand when the operand is a ``Transpose``.

        ``hints`` is accepted but not consulted here.
        """
        from sympy.assumptions import ask, Q
        from sympy import Transpose, Mul, MatMul
        from sympy import MatrixBase, eye

        operand = self._vector

        # Covers shape (1, 1) and identity matrices, among others.
        if ask(Q.diagonal(operand)):
            return operand

        if isinstance(operand, MatrixBase):
            # Start from an identity matrix and overwrite its diagonal.
            dense = eye(max(operand.shape))
            for k in range(dense.shape[0]):
                dense[k, k] = operand[k]
            return type(operand)(dense)

        if operand.is_MatMul:
            matrix_factors = [f for f in operand.args if f.is_Matrix]
            scalar_factors = [
                f for f in operand.args if f not in matrix_factors
            ]
            if scalar_factors:
                coefficient = Mul.fromiter(scalar_factors)
                matrix_part = MatMul.fromiter(matrix_factors).doit()
                return coefficient * DiagMatrix(matrix_part).doit()

        if isinstance(operand, Transpose):
            return DiagMatrix(operand.arg)
        return DiagMatrix(operand)
Пример #2
0
def convert_matrix_to_array(expr: MatrixExpr) -> Basic:
    """Rewrite a matrix expression in terms of array-expression objects.

    The expression is dispatched on its type, in this order: ``MatMul``,
    ``MatAdd``, ``Transpose``, ``Trace``, ``Mul``, ``Pow``, ``MatPow``,
    ``HadamardProduct``, ``HadamardPower``.  Anything else, as well as
    ``Pow``/``MatPow`` whose exponent is not known to be positive, is
    returned unchanged.
    """
    if isinstance(expr, MatMul):
        matrix_factors = [f for f in expr.args if isinstance(f, MatrixExpr)]
        other_factors = [
            convert_matrix_to_array(f)
            for f in expr.args
            if not isinstance(f, MatrixExpr)
        ]
        # Contract the second axis of each matrix factor with the first
        # axis of the factor that follows it.
        contractions = [
            (2 * k + 1, 2 * k + 2) for k in range(len(matrix_factors) - 1)
        ]
        scalar = S.One
        if other_factors:
            scalar = ArrayTensorProduct.fromiter(other_factors)
        # A unit scalar is left out of the tensor product altogether.
        leading = [] if scalar == 1 else [scalar]
        converted = [convert_matrix_to_array(f) for f in matrix_factors]
        tprod = ArrayTensorProduct(*leading, *converted)
        return ArrayContraction(tprod, *contractions)

    if isinstance(expr, MatAdd):
        return ArrayAdd(*map(convert_matrix_to_array, expr.args))

    if isinstance(expr, Transpose):
        transposed = convert_matrix_to_array(expr.args[0])
        return PermuteDims(transposed, [1, 0])

    if isinstance(expr, Trace):
        inner_expr = convert_matrix_to_array(expr.arg)
        # Contract the first axis with the last one.
        last_axis = len(inner_expr.shape) - 1
        return ArrayContraction(inner_expr, (0, last_axis))

    if isinstance(expr, Mul):
        return ArrayTensorProduct.fromiter(
            map(convert_matrix_to_array, expr.args))

    if isinstance(expr, Pow):
        base = convert_matrix_to_array(expr.base)
        exponent_is_positive = (expr.exp > 0) == True
        if not exponent_is_positive:
            return expr
        return ArrayTensorProduct.fromiter(base for _ in range(expr.exp))

    if isinstance(expr, MatPow):
        base = convert_matrix_to_array(expr.base)
        if expr.exp.is_Integer != True:
            # Non-integer exponent: apply the power element by element.
            b = symbols("b", cls=Dummy)
            power = Lambda(b, b**expr.exp)
            return ArrayElementwiseApplyFunc(
                power, convert_matrix_to_array(base))
        exponent_is_positive = (expr.exp > 0) == True
        if not exponent_is_positive:
            return expr
        repeated = MatMul.fromiter(base for _ in range(expr.exp))
        return convert_matrix_to_array(repeated)

    if isinstance(expr, HadamardProduct):
        count = len(expr.args)
        tp = ArrayTensorProduct.fromiter(expr.args)
        # Diagonalise all the even axes together and all the odd axes
        # together.
        even_axes = list(range(0, 2 * count, 2))
        odd_axes = list(range(1, 2 * count, 2))
        return ArrayDiagonal(tp, even_axes, odd_axes)

    if isinstance(expr, HadamardPower):
        base, exp = expr.args
        repeated = HadamardProduct.fromiter(base for _ in range(exp))
        return convert_matrix_to_array(repeated)

    return expr
Пример #3
0
def _find_trivial_matrices_rewrite(expr: ArrayTensorProduct):
    # Matrices of trivial shape, i.e. (1, 1), sitting in a tensor product can
    # sometimes be absorbed into a non-trivial MatMul found among the other
    # factors.
    #
    # Example: with "a" of shape (1, 1) and "b" of shape (k, 1), the
    # expression "_array_tensor_product(a, b*b.T)" can be rewritten as
    # "b*a*b.T".
    #
    # Returns the rewritten expression together with the list of axis
    # positions occupied by the absorbed (1, 1) factors; when no host MatMul
    # is found the input is returned with an empty list.

    absorbed = []
    host_pos: Optional[int] = None
    left: Optional[MatrixExpr] = None
    right: Optional[MatrixExpr] = None
    dropped_axes: List[int] = []
    axis_offset: int = 0
    remaining: List[Optional[Basic]] = list(expr.args)

    for idx, factor in enumerate(expr.args):
        if isinstance(factor, MatrixExpr):
            if factor.shape == (1, 1):
                absorbed.append(factor)
                remaining[idx] = None
                dropped_axes.extend([axis_offset, axis_offset + 1])
            elif host_pos is None and isinstance(factor, MatMul):
                pieces = factor.args
                # First factor of the product with a single column: the
                # (1, 1) matrices are inserted right after it.
                split = next(
                    (j for j, piece in enumerate(pieces)
                     if isinstance(piece, MatrixExpr) and piece.shape[1] == 1),
                    None,
                )
                if split is not None:
                    host_pos = idx
                    left = MatMul.fromiter(pieces[:split + 1])
                    right = MatMul.fromiter(pieces[split + 1:])
        axis_offset += get_rank(factor)

    if host_pos is None:
        return expr, []

    bridge = MatMul.fromiter(absorbed)
    remaining[host_pos] = (left * bridge * right).doit()
    kept = [factor for factor in remaining if factor is not None]
    return _array_tensor_product(*kept), dropped_axes
Пример #4
0
 def doit(self, **hints):
     """Evaluate the diagonalization of ``self._vector``.

     An operand for which ``ask(Q.diagonal(...))`` holds is returned as is.
     For a ``MatMul`` with non-matrix factors, those factors are moved in
     front of the evaluated ``DiagonalizeVector`` of the matrix factors.
     Otherwise a ``DiagonalizeVector`` is returned, built on the argument
     of the operand when the operand is a ``Transpose``.

     ``hints`` is accepted but not consulted here.
     """
     from sympy.assumptions import ask, Q
     from sympy import Transpose, Mul, MatMul

     operand = self._vector

     # Covers shape (1, 1) and identity matrices, among others.
     if ask(Q.diagonal(operand)):
         return operand

     if operand.is_MatMul:
         matrix_factors = [f for f in operand.args if f.is_Matrix]
         scalar_factors = [f for f in operand.args if f not in matrix_factors]
         if scalar_factors:
             coefficient = Mul.fromiter(scalar_factors)
             matrix_part = MatMul.fromiter(matrix_factors).doit()
             return coefficient*DiagonalizeVector(matrix_part).doit()

     inner = operand.arg if isinstance(operand, Transpose) else operand
     return DiagonalizeVector(inner)
Пример #5
0
 def _normalize(self):
     """Return a canonical form of a trace of a matrix product.

     Uses the transposition and cyclic properties of the trace so that the
     product starts with its smallest factor (by ``default_sort_key``) and
     that factor is not a transposition.  When the argument is not a
     ``MatMul`` the object itself is returned.
     """
     from sympy import MatMul, Transpose, default_sort_key

     product = self.arg
     if not isinstance(product, MatMul):
         return self

     def smallest_position(factors):
         # Position of the factor with the smallest sort key.
         return min(range(len(factors)),
                    key=lambda pos: default_sort_key(factors[pos]))

     pivot = smallest_position(product.args)
     if isinstance(product.args[pivot], Transpose):
         # Transpose the whole product, then look for the pivot again.
         product = Transpose(product).doit()
         pivot = smallest_position(product.args)

     factors = product.args
     # Cyclic rotation that brings the pivot factor to the front.
     rotated = MatMul.fromiter(factors[pivot:] + factors[:pivot])
     return Trace(rotated)
Пример #6
0
 def recurse_expr(expr, index_ranges={}):
     """Recursively split an indexed expression into ``(term, indices)`` pairs.

     Returns a list of 2-tuples.  A pair whose second item is ``None``
     carries a term without matrix indices; otherwise the first item is a
     ``MatrixElement`` and the second item is its pair of indices.

     ``index_ranges`` maps a summation index to the remaining items of its
     ``Sum`` limit (unpacked below as ``r1, r2``); it is filled in by the
     ``Sum`` branch and only read elsewhere.

     Raises ``ValueError`` when a summation range does not match
     ``(0, dimension - 1)`` of the matrix axis it indexes.
     """
     if expr.is_Mul:
         nonmatargs = []   # factors that carry no indices
         pos_arg = []      # indexed factors, in order of appearance
         pos_ind = []      # index tuple of each indexed factor
         dlinks = {}       # index -> (factor number, slot) of its last occurrence
         link_ind = []     # per factor: partner position of each slot, or None
         counter = 0
         args_ind = []
         # Flatten the pairs returned for every factor into a single list.
         for arg in expr.args:
             retvals = recurse_expr(arg, index_ranges)
             assert isinstance(retvals, list)
             if isinstance(retvals, list):
                 for i in retvals:
                     args_ind.append(i)
             else:
                 args_ind.append(retvals)
         for arg_symbol, arg_indices in args_ind:
             if arg_indices is None:
                 nonmatargs.append(arg_symbol)
                 continue
             # Keep only the matrix of a MatrixElement; its indices are in
             # `arg_indices` already.
             if isinstance(arg_symbol, MatrixElement):
                 arg_symbol = arg_symbol.args[0]
             pos_arg.append(arg_symbol)
             pos_ind.append(arg_indices)
             link_ind.append([None]*len(arg_indices))
             # An index seen before links this slot with the slot of its
             # previous occurrence, in both directions.
             for i, ind in enumerate(arg_indices):
                 if ind in dlinks:
                     other_i = dlinks[ind]
                     link_ind[counter][i] = other_i
                     link_ind[other_i[0]][other_i[1]] = (counter, i)
                 dlinks[ind] = (counter, i)
             counter += 1
         counter2 = 0      # number of factors visited so far
         lines = {}        # (first index, last index) -> chain of factors
         while counter2 < len(link_ind):
             # A slot still holding None is an unlinked index: start a new
             # chain from the first such slot.
             # NOTE(review): if no entry contains None, `line_start_index`
             # keeps its previous value (or is unbound on the first pass) -
             # confirm this cannot happen for the inputs used.
             for i, e in enumerate(link_ind):
                 if None in e:
                     line_start_index = (i, e.index(None))
                     break
             cur_ind_pos = line_start_index
             cur_line = []
             index1 = pos_ind[cur_ind_pos[0]][cur_ind_pos[1]]
             while True:
                 d, r = cur_ind_pos
                 # Factors equal to 1 are left out of the chain.
                 if pos_arg[d] != 1:
                     # Entering through the second slot means the factor is
                     # traversed backwards, hence the transpose.
                     if r % 2 == 1:
                         cur_line.append(transpose(pos_arg[d]))
                     else:
                         cur_line.append(pos_arg[d])
                 # Leave the factor through its other slot.
                 next_ind_pos = link_ind[d][1-r]
                 counter2 += 1
                 # Mark as visited, there will be no `None` anymore:
                 link_ind[d] = (-1, -1)
                 if next_ind_pos is None:
                     # Unlinked exit slot: the chain ends here.
                     index2 = pos_ind[d][1-r]
                     lines[(index1, index2)] = cur_line
                     break
                 cur_ind_pos = next_ind_pos
         # NOTE(review): `ret_indices` is not used afterwards in this block.
         ret_indices = list(j for i in lines for j in i)
         # A chain of length one stays a bare factor, any other becomes a MatMul.
         lines = {k: MatMul.fromiter(v) if len(v) != 1 else v[0] for k, v in lines.items()}
         return [(Mul.fromiter(nonmatargs), None)] + [
             (MatrixElement(a, i, j), (i, j)) for (i, j), a in lines.items()
         ]
     elif expr.is_Add:
         # NOTE(review): `index_ranges` is not forwarded to the addends, so
         # they are processed with the default value.
         res = [recurse_expr(i) for i in expr.args]
         d = collections.defaultdict(list)
         for res_addend in res:
             scalar = 1
             for elem, indices in res_addend:
                 # An index-free entry becomes the coefficient of the next
                 # indexed entry of the same addend.
                 if indices is None:
                     scalar = elem
                     continue
                 # Addends are grouped by their sorted index pair.
                 indices = tuple(sorted(indices, key=default_sort_key))
                 d[indices].append(scalar*remove_matelement(elem, *indices))
                 scalar = 1
         return [(MatrixElement(Add.fromiter(v), *k), k) for k, v in d.items()]
     elif isinstance(expr, KroneckerDelta):
         i1, i2 = expr.args
         # `dimensions` is not defined in this block; presumably it comes
         # from the enclosing scope - verify against the caller.
         if dimensions is not None:
             identity = Identity(dimensions[0])
         else:
             identity = S.One
         return [(MatrixElement(identity, i1, i2), (i1, i2))]
     elif isinstance(expr, MatrixElement):
         matrix_symbol, i1, i2 = expr.args
         # A summed index must run over the whole corresponding axis, i.e.
         # from 0 to dimension - 1.
         if i1 in index_ranges:
             r1, r2 = index_ranges[i1]
             if r1 != 0 or matrix_symbol.shape[0] != r2+1:
                 raise ValueError("index range mismatch: {0} vs. (0, {1})".format(
                     (r1, r2), matrix_symbol.shape[0]))
         if i2 in index_ranges:
             r1, r2 = index_ranges[i2]
             if r1 != 0 or matrix_symbol.shape[1] != r2+1:
                 raise ValueError("index range mismatch: {0} vs. (0, {1})".format(
                     (r1, r2), matrix_symbol.shape[1]))
         # The same summed index in both positions is a trace.
         if (i1 == i2) and (i1 in index_ranges):
             return [(trace(matrix_symbol), None)]
         return [(MatrixElement(matrix_symbol, i1, i2), (i1, i2))]
     elif isinstance(expr, Sum):
         # Recurse into the summand, recording each summation variable
         # together with the rest of its limit tuple.
         return recurse_expr(
             expr.args[0],
             index_ranges={i[0]: i[1:] for i in expr.args[1:]}
         )
     else:
         # Anything else is treated as a term without indices.
         return [(expr, None)]
 def recurse_expr(expr, index_ranges={}):
     """Recursively split an indexed expression into ``(term, indices)`` pairs.

     Returns a list of 2-tuples.  A pair whose second item is ``None``
     carries a term without matrix indices; otherwise the first item is a
     ``MatrixElement`` and the second item is its pair of indices.

     ``index_ranges`` maps a summation index to the remaining items of its
     ``Sum`` limit (unpacked below as ``r1, r2``); it is filled in by the
     ``Sum`` branch and only read elsewhere.

     Raises ``ValueError`` when a summation range does not match
     ``(0, dimension - 1)`` of the matrix axis it indexes.
     """
     if expr.is_Mul:
         nonmatargs = []   # factors that carry no indices
         pos_arg = []      # indexed factors, in order of appearance
         pos_ind = []      # index tuple of each indexed factor
         dlinks = {}       # index -> (factor number, slot) of its last occurrence
         link_ind = []     # per factor: partner position of each slot, or None
         counter = 0
         args_ind = []
         # Flatten the pairs returned for every factor into a single list.
         for arg in expr.args:
             retvals = recurse_expr(arg, index_ranges)
             assert isinstance(retvals, list)
             if isinstance(retvals, list):
                 for i in retvals:
                     args_ind.append(i)
             else:
                 args_ind.append(retvals)
         for arg_symbol, arg_indices in args_ind:
             if arg_indices is None:
                 nonmatargs.append(arg_symbol)
                 continue
             # Keep only the matrix of a MatrixElement; its indices are in
             # `arg_indices` already.
             if isinstance(arg_symbol, MatrixElement):
                 arg_symbol = arg_symbol.args[0]
             pos_arg.append(arg_symbol)
             pos_ind.append(arg_indices)
             link_ind.append([None]*len(arg_indices))
             # An index seen before links this slot with the slot of its
             # previous occurrence, in both directions.
             for i, ind in enumerate(arg_indices):
                 if ind in dlinks:
                     other_i = dlinks[ind]
                     link_ind[counter][i] = other_i
                     link_ind[other_i[0]][other_i[1]] = (counter, i)
                 dlinks[ind] = (counter, i)
             counter += 1
         counter2 = 0      # number of factors visited so far
         lines = {}        # (first index, last index) -> chain of factors
         while counter2 < len(link_ind):
             # A slot still holding None is an unlinked index: start a new
             # chain from the first such slot.
             # NOTE(review): if no entry contains None, `line_start_index`
             # keeps its previous value (or is unbound on the first pass) -
             # confirm this cannot happen for the inputs used.
             for i, e in enumerate(link_ind):
                 if None in e:
                     line_start_index = (i, e.index(None))
                     break
             cur_ind_pos = line_start_index
             cur_line = []
             index1 = pos_ind[cur_ind_pos[0]][cur_ind_pos[1]]
             while True:
                 d, r = cur_ind_pos
                 # Factors equal to 1 are left out of the chain.
                 if pos_arg[d] != 1:
                     # Entering through the second slot means the factor is
                     # traversed backwards, hence the transpose.
                     if r % 2 == 1:
                         cur_line.append(transpose(pos_arg[d]))
                     else:
                         cur_line.append(pos_arg[d])
                 # Leave the factor through its other slot.
                 next_ind_pos = link_ind[d][1-r]
                 counter2 += 1
                 # Mark as visited, there will be no `None` anymore:
                 link_ind[d] = (-1, -1)
                 if next_ind_pos is None:
                     # Unlinked exit slot: the chain ends here.
                     index2 = pos_ind[d][1-r]
                     lines[(index1, index2)] = cur_line
                     break
                 cur_ind_pos = next_ind_pos
         # NOTE(review): `ret_indices` is not used afterwards in this block.
         ret_indices = list(j for i in lines for j in i)
         # A chain of length one stays a bare factor, any other becomes a MatMul.
         lines = {k: MatMul.fromiter(v) if len(v) != 1 else v[0] for k, v in lines.items()}
         return [(Mul.fromiter(nonmatargs), None)] + [
             (MatrixElement(a, i, j), (i, j)) for (i, j), a in lines.items()
         ]
     elif expr.is_Add:
         # NOTE(review): `index_ranges` is not forwarded to the addends, so
         # they are processed with the default value.
         res = [recurse_expr(i) for i in expr.args]
         d = collections.defaultdict(list)
         for res_addend in res:
             scalar = 1
             for elem, indices in res_addend:
                 # An index-free entry becomes the coefficient of the next
                 # indexed entry of the same addend.
                 if indices is None:
                     scalar = elem
                     continue
                 # Addends are grouped by their sorted index pair.
                 indices = tuple(sorted(indices, key=default_sort_key))
                 d[indices].append(scalar*remove_matelement(elem, *indices))
                 scalar = 1
         return [(MatrixElement(Add.fromiter(v), *k), k) for k, v in d.items()]
     elif isinstance(expr, KroneckerDelta):
         i1, i2 = expr.args
         # `dimensions` is not defined in this block; presumably it comes
         # from the enclosing scope - verify against the caller.
         if dimensions is not None:
             identity = Identity(dimensions[0])
         else:
             identity = S.One
         return [(MatrixElement(identity, i1, i2), (i1, i2))]
     elif isinstance(expr, MatrixElement):
         matrix_symbol, i1, i2 = expr.args
         # A summed index must run over the whole corresponding axis, i.e.
         # from 0 to dimension - 1.
         if i1 in index_ranges:
             r1, r2 = index_ranges[i1]
             if r1 != 0 or matrix_symbol.shape[0] != r2+1:
                 raise ValueError("index range mismatch: {0} vs. (0, {1})".format(
                     (r1, r2), matrix_symbol.shape[0]))
         if i2 in index_ranges:
             r1, r2 = index_ranges[i2]
             if r1 != 0 or matrix_symbol.shape[1] != r2+1:
                 raise ValueError("index range mismatch: {0} vs. (0, {1})".format(
                     (r1, r2), matrix_symbol.shape[1]))
         # The same summed index in both positions is a trace.
         if (i1 == i2) and (i1 in index_ranges):
             return [(trace(matrix_symbol), None)]
         return [(MatrixElement(matrix_symbol, i1, i2), (i1, i2))]
     elif isinstance(expr, Sum):
         # Recurse into the summand, recording each summation variable
         # together with the rest of its limit tuple.
         return recurse_expr(
             expr.args[0],
             index_ranges={i[0]: i[1:] for i in expr.args[1:]}
         )
     else:
         # Anything else is treated as a term without indices.
         return [(expr, None)]
Пример #8
0
 def recurse_expr(expr, index_ranges=None):
     """Recursively turn an indexed expression into a matrix expression.

     Returns a 2-tuple ``(term, indices)``.  ``indices`` is ``None`` for a
     term without free matrix indices, otherwise it is the pair
     ``(first_index, last_index)`` of the returned matrix expression.

     ``index_ranges`` maps a summation index to the remaining items of its
     ``Sum`` limit (unpacked below as ``r1, r2``).  It defaults to an empty
     mapping and is only read here.

     Raises ``ValueError`` when the addends of a sum do not share the same
     index pair ("incompatible summation") or when a summation range does
     not match ``(0, dimension - 1)`` of the matrix axis it indexes.
     """
     # Use a fresh mapping per call instead of a shared mutable default.
     if index_ranges is None:
         index_ranges = {}
     if expr.is_Mul:
         nonmatargs = []   # factors without indices
         matargs = []      # matrix factors, in multiplication order
         pos_arg = []      # indexed factors, in order of appearance
         pos_ind = []      # flat list of indices, two slots per indexed factor
         dlinks = {}       # index -> flat slot of its last occurrence
         link_ind = []     # flat list: partner slot of each slot, or None
         counter = 0
         for arg in expr.args:
             arg_symbol, arg_indices = recurse_expr(arg, index_ranges)
             if arg_indices is None:
                 nonmatargs.append(arg_symbol)
                 continue
             i1, i2 = arg_indices
             pos_arg.append(arg_symbol)
             pos_ind.append(i1)
             pos_ind.append(i2)
             link_ind.extend([None, None])
             # An index seen before links this slot with the slot of its
             # previous occurrence, in both directions.
             if i1 in dlinks:
                 other_i1 = dlinks[i1]
                 link_ind[2*counter] = other_i1
                 link_ind[other_i1] = 2*counter
             if i2 in dlinks:
                 other_i2 = dlinks[i2]
                 link_ind[2*counter + 1] = other_i2
                 link_ind[other_i2] = 2*counter + 1
             dlinks[i1] = 2*counter
             dlinks[i2] = 2*counter + 1
             counter += 1
         # Start the chain at the first unlinked slot.  `list.index` raises
         # ValueError when every slot is linked.
         cur_ind_pos = link_ind.index(None)
         first_index = pos_ind[cur_ind_pos]
         while True:
             d = cur_ind_pos // 2   # factor number
             r = cur_ind_pos % 2    # slot through which the factor is entered
             # Entering through the second slot means the factor is
             # traversed backwards, hence the transpose.
             if r == 1:
                 matargs.append(transpose(pos_arg[d]))
             else:
                 matargs.append(pos_arg[d])
             # Leave the factor through its other slot.
             next_ind_pos = link_ind[2*d + 1 - r]
             if next_ind_pos is None:
                 last_index = pos_ind[2*d + 1 - r]
                 break
             cur_ind_pos = next_ind_pos
         return Mul.fromiter(nonmatargs)*MatMul.fromiter(matargs), (first_index, last_index)
     elif expr.is_Add:
         # NOTE(review): `index_ranges` is not forwarded to the addends, and
         # an addend returned with `None` indices would fail on `j[0]`
         # below - confirm against the callers.
         res = [recurse_expr(i) for i in expr.args]
         # Bring every addend to the same index order, transposing the
         # ones whose indices are in descending sort order.
         res = [
             ((transpose(i), (j[1], j[0]))
                 if default_sort_key(j[0]) > default_sort_key(j[1])
             else (i, j))
             for (i, j) in res
         ]
         addends, last_indices = zip(*res)
         last_indices = list(set(last_indices))
         if len(last_indices) > 1:
             raise ValueError("incompatible summation")
         return MatAdd.fromiter(addends), last_indices[0]
     elif isinstance(expr, KroneckerDelta):
         i1, i2 = expr.args
         return S.One, (i1, i2)
     elif isinstance(expr, MatrixElement):
         matrix_symbol, i1, i2 = expr.args
         # A summed index must run over the whole corresponding axis, i.e.
         # from 0 to dimension - 1.
         if i1 in index_ranges:
             r1, r2 = index_ranges[i1]
             if r1 != 0 or matrix_symbol.shape[0] != r2+1:
                 raise ValueError("index range mismatch: {0} vs. (0, {1})".format(
                     (r1, r2), matrix_symbol.shape[0]))
         if i2 in index_ranges:
             r1, r2 = index_ranges[i2]
             if r1 != 0 or matrix_symbol.shape[1] != r2+1:
                 raise ValueError("index range mismatch: {0} vs. (0, {1})".format(
                     (r1, r2), matrix_symbol.shape[1]))
         # The same summed index in both positions is a trace.
         if (i1 == i2) and (i1 in index_ranges):
             return trace(matrix_symbol), None
         return matrix_symbol, (i1, i2)
     elif isinstance(expr, Sum):
         # Recurse into the summand, recording each summation variable
         # together with the rest of its limit tuple.
         return recurse_expr(
             expr.args[0],
             index_ranges={i[0]: i[1:] for i in expr.args[1:]}
         )
     else:
         # Anything else is treated as a term without indices.
         return expr, None