Esempio n. 1
0
def sympy_key(expr):
    """Return a sort key for a SymPy expression that is already sympified.

    The key is the 2-tuple ``(count_ops(expr), default_sort_key(expr))``.
    Used as an ascending sort key, expressions with fewer operations come
    first, and ``default_sort_key`` breaks ties between equal counts.
    """
    ops = count_ops(expr)
    canonical = default_sort_key(expr)
    return (ops, canonical)
Esempio n. 2
0
 def _normalize(self):
     """Return a normalized form of this trace of a matrix product.

     When ``self.arg`` is a ``MatMul``, the factor with the smallest
     ``default_sort_key`` is located. If that factor is a ``Transpose``,
     the whole product is transposed (via ``Transpose(...).doit()``) and
     the smallest factor is located again. The factors are then rotated
     cyclically so the smallest one comes first, and a new ``Trace`` is
     built from the rotated product. Any other argument leaves the trace
     untouched and ``self`` is returned.
     """
     from sympy import MatMul, Transpose, default_sort_key

     def position_of_min(product):
         # Index of the factor whose canonical sort key is smallest.
         factors = product.args
         return min(range(len(factors)),
                    key=lambda k: default_sort_key(factors[k]))

     arg = self.arg
     if not isinstance(arg, MatMul):
         return self

     start = position_of_min(arg)
     if isinstance(arg.args[start], Transpose):
         # The trace is unchanged by transposition, so try the transposed
         # product to avoid leading with a transposed factor.
         arg = Transpose(arg).doit()
         start = position_of_min(arg)
     # Cyclic rotation does not change the trace of a product.
     rotated = arg.args[start:] + arg.args[:start]
     return Trace(MatMul.fromiter(rotated))
Esempio n. 3
0
 def _sort_fully_contracted_args(cls, expr, contraction_indices):
     """Move fully contracted tensor-product factors to the front.

     A factor of ``expr`` is "fully contracted" when every one of its flat
     index positions appears in ``contraction_indices``. Such factors are
     placed first, ordered by ``default_sort_key``. All other factors follow
     in their original relative order, because ``sorted`` is stable. The
     contraction indices are remapped to the new flat positions and
     normalized with ``_sort_contraction_indices``.

     Returns a ``(ArrayTensorProduct, contraction_indices)`` pair. If
     ``expr.shape`` is ``None``, the inputs are returned unchanged.
     """
     if expr.shape is None:
         return expr, contraction_indices
     offsets = list(accumulate([0] + expr.subranks))
     nargs = len(expr.args)
     # index_blocks[k] holds the flat index positions owned by factor k.
     index_blocks = [list(range(offsets[k], offsets[k + 1])) for k in range(nargs)]
     contracted = {idx for group in contraction_indices for idx in group}
     is_full = [all(idx in contracted for idx in index_blocks[k]) for k in range(nargs)]

     def order_key(k):
         # Fully contracted factors sort first, canonically; the rest tie.
         if is_full[k]:
             return (0, default_sort_key(expr.args[k]))
         return (1,)

     permutation = sorted(range(nargs), key=order_key)
     reordered_args = [expr.args[k] for k in permutation]
     # new flat position -> old flat position; inverted to get old -> new.
     flat_positions = [idx for k in permutation for idx in index_blocks[k]]
     old_to_new = _af_invert(flat_positions)
     remapped = [tuple(old_to_new[idx] for idx in group) for group in contraction_indices]
     remapped = _sort_contraction_indices(remapped)
     return ArrayTensorProduct(*reordered_args), remapped
Esempio n. 4
0
    def sort_args_by_name(self):
        """
        Sort arguments in the tensor product so that their order is lexicographical.

        The factors of the contracted ``ArrayTensorProduct`` are ordered by
        ``default_sort_key`` (ties keep their original order), and the
        contraction indices are remapped so each one follows its factor to
        the new position. If ``self.expr`` is not an ``ArrayTensorProduct``,
        ``self`` is returned unchanged.

        Examples
        ========

        >>> from sympy.tensor.array.expressions.conv_matrix_to_array import convert_matrix_to_array
        >>> from sympy import MatrixSymbol
        >>> from sympy.abc import N
        >>> A = MatrixSymbol("A", N, N)
        >>> B = MatrixSymbol("B", N, N)
        >>> C = MatrixSymbol("C", N, N)
        >>> D = MatrixSymbol("D", N, N)

        >>> cg = convert_matrix_to_array(C*D*A*B)
        >>> cg
        ArrayContraction(ArrayTensorProduct(A, D, C, B), (0, 3), (1, 6), (2, 5))
        >>> cg.sort_args_by_name()
        ArrayContraction(ArrayTensorProduct(A, D, B, C), (0, 3), (1, 4), (2, 7))
        """
        expr = self.expr
        if not isinstance(expr, ArrayTensorProduct):
            return self
        args = expr.args
        sorted_data = sorted(enumerate(args), key=lambda x: default_sort_key(x[1]))
        pos_sorted, args_sorted = zip(*sorted_data)
        # pos_sorted[new] is the old position of the factor now at ``new``.
        # Inverting that permutation in a single pass gives the old -> new
        # map, without an O(n) ``tuple.index`` scan for every argument.
        reordering_map = {old: new for new, old in enumerate(pos_sorted)}
        contraction_tuples = self._get_contraction_tuples()
        # In each pair, the first entry is an argument position and is
        # remapped; the second entry is kept as-is.
        contraction_tuples = [[(reordering_map[j], k) for j, k in i] for i in contraction_tuples]
        c_tp = ArrayTensorProduct(*args_sorted)
        new_contr_indices = self._contraction_tuples_to_contraction_indices(
                c_tp,
                contraction_tuples
        )
        return ArrayContraction(c_tp, *new_contr_indices)
Esempio n. 5
0
 def eqn_simplicity(eqn):
     """Return a sort key for ranking an equation by simplicity.

     The primary component is the number of free symbols in ``eqn.lhs``.
     Ties are broken by ``default_sort_key`` of the whole equation.
     """
     symbol_count = len(eqn.lhs.free_symbols)
     tiebreak = default_sort_key(eqn)
     return (symbol_count, tiebreak)
Esempio n. 6
0
 def get_arg_key(x):
     """Sort key for the ``x``-th factor of ``trace_arg``, ignoring transposition.

     A ``Transpose`` factor is keyed by its inner ``.arg``, so a matrix and
     its transpose receive the same ``default_sort_key``. ``trace_arg`` and
     ``Transpose`` are looked up in an enclosing scope not shown here.
     """
     factor = trace_arg.args[x]
     inner = factor.arg if isinstance(factor, Transpose) else factor
     return default_sort_key(inner)
Esempio n. 7
0
def _entity_sort_key(entity):
    """Return SymPy's canonical ``default_sort_key`` of ``entity.value``."""
    key_fn = sympy.default_sort_key
    return key_fn(entity.value)