Пример #1
0
def _(expr: ArrayTensorProduct, x: Expr):
    """Differentiate a tensor product with respect to ``x`` (product rule).

    Every factor is differentiated in turn while the remaining factors are
    kept as they are; factors whose derivative is zero contribute nothing.
    Each term's axes are reordered so that the ``x`` axes contributed by the
    differentiated factor come first, followed by the axes of the product in
    their original order.

    Returns ``S.Zero`` when no term survives, the single term when exactly
    one does, and their ``_array_add`` otherwise.
    """
    factors = expr.args
    terms = []
    for pos, factor in enumerate(factors):
        dfactor = array_derive(factor, x)
        if dfactor == 0:
            continue
        before = factors[:pos]
        after = factors[pos + 1:]
        shape_before = reduce(operator.add, map(get_shape, before), ())
        shape_after = reduce(operator.add, map(get_shape, after), ())
        term = _array_tensor_product(*before, dfactor, *after)
        # Axis-id boundaries in the target layout: [x | before | factor | after].
        start_before = len(get_shape(x))
        start_factor = start_before + len(shape_before)
        start_after = start_factor + len(get_shape(factor))
        end = start_after + len(shape_after)
        # Source layout is [before | x | factor | after]; map each source
        # position to its target id.
        axis_map = (
            list(range(start_before, start_factor))
            + list(range(start_before))
            + list(range(start_factor, end))
        )
        terms.append(_permute_dims(term, _af_invert(axis_map)))
    if not terms:
        return S.Zero
    if len(terms) == 1:
        return terms[0]
    return _array_add(*terms)
Пример #2
0
def _find_trivial_kronecker_products_broadcast(expr: ArrayTensorProduct):
    """Fold (1, 1)-shaped factors into neighbouring matrices via Kronecker products.

    Walks the factors of ``expr`` left to right:

    * a (1, 1) matrix following a matrix with no size-1 axis is merged into
      it with ``KroneckerProduct`` and its own axes are recorded as removed;
    * a factor with no size-1 axis following a (1, 1) factor is merged into
      that factor (broadcast), and the last two not-yet-removed axes before
      it are recorded as removed;
    * anything else is kept as a separate factor.

    Returns ``(new_tensor_product, removed_axes)``.
    """
    merged: List[Basic] = []
    removed = []
    dims_seen = 0
    for arg in expr.args:
        dims_seen += get_rank(arg)
        arg_shape = get_shape(arg)
        # Positions of this factor's axes in the full tensor product.
        arg_axes = list(range(dims_seen - len(arg_shape), dims_seen))
        has_prev = len(merged) > 0
        if (has_prev and arg_shape == (1, 1)
                and 1 not in get_shape(merged[-1])
                and isinstance(merged[-1], MatrixExpr)
                and isinstance(arg, MatrixExpr)):
            # KroneckerProduct object allows the trick of broadcasting:
            merged[-1] = KroneckerProduct(merged[-1], arg)
            removed.extend(arg_axes)
        elif (has_prev and 1 not in arg_shape
                and get_shape(merged[-1]) == (1, 1)):
            # Broadcast:
            merged[-1] = KroneckerProduct(merged[-1], arg)
            still_present = [
                ax for ax in range(min(arg_axes)) if ax not in removed
            ]
            removed.extend(still_present[-2:])
        else:
            merged.append(arg)
    return _array_tensor_product(*merged), removed
Пример #3
0
 def shape(self) -> Tuple[Union[sp.Basic, int], ...]:
     """Return the size of each axis of the sliced array.

     Every axis of the parent is paired with the corresponding entry of
     ``self.indices`` and passed to ``_compute_slice_size``. Axes beyond the
     last given index are paired with ``None`` (``zip_longest`` fill value).
     """
     axis_sizes = get_shape(self.parent)
     return tuple(
         _compute_slice_size(index, size)
         for index, size in zip_longest(self.indices, axis_sizes)
     )
Пример #4
0
def _(expr: ArrayContraction, x: Expr):
    """Differentiate a contraction by differentiating its operand.

    The derivative of the operand carries ``rank(x)`` extra leading axes, so
    every contracted axis index is shifted by that rank before the same
    contraction is re-applied.
    """
    derived = array_derive(expr.expr, x)
    offset = len(get_shape(x))
    shifted_groups = []
    for group in expr.contraction_indices:
        shifted_groups.append(tuple(axis + offset for axis in group))
    return _array_contraction(derived, *shifted_groups)
Пример #5
0
def _(expr: PermuteDims):
    """Convert a ``PermuteDims`` array expression into a matrix expression.

    * A plain axis swap ``[1, 0]`` becomes a matrix transpose.
    * For a permuted ``ArrayTensorProduct``, each factor whose axes stay in
      order is converted as-is, a rank-2 factor whose axes are swapped is
      transposed, rank-0 factors are collected as scalars, and any remaining
      inter-factor reordering is kept as an explicit ``_permute_dims``.
    * For a permuted ``ArrayContraction``, the contraction is converted first.
      If that yields a tensor product of matrices, the permutation is resolved
      into transposes/reordering of those matrices when possible.
    * Anything else is returned unchanged.

    Raises:
        NotImplementedError: if a tensor-product factor of rank > 2 has its
            own axes reordered.
    """
    if expr.permutation.array_form == [1, 0]:
        return _a2m_transpose(_array2matrix(expr.expr))
    elif isinstance(expr.expr, ArrayTensorProduct):
        ranks = expr.expr.subranks
        inv_permutation = expr.permutation**(-1)
        newrange = [inv_permutation(i) for i in range(sum(ranks))]
        # Split the permuted axis positions into one group per factor.
        newpos = []
        counter = 0
        for rank in ranks:
            newpos.append(newrange[counter:counter + rank])
            counter += rank
        newargs = []
        newperm = []
        scalars = []
        for pos, arg in zip(newpos, expr.expr.args):
            if len(pos) == 0:
                # Rank-0 factor: a scalar, placed in front of the matrices.
                scalars.append(_array2matrix(arg))
            elif pos == sorted(pos):
                newargs.append(_array2matrix(arg))
                newperm.extend(pos)
            elif len(pos) == 2:
                # Rank-2 factor with its two axes swapped: transpose it.
                newargs.append(_a2m_transpose(_array2matrix(arg)))
                newperm.extend(reversed(pos))
            else:
                raise NotImplementedError()
        return _permute_dims(_a2m_tensor_product(*scalars, *newargs),
                             _af_invert(newperm))
    elif isinstance(expr.expr, ArrayContraction):
        mat_mul_lines = _array2matrix(expr.expr)
        if not isinstance(mat_mul_lines, ArrayTensorProduct):
            flat_cyclic_form = [
                j for i in expr.permutation.cyclic_form for j in i
            ]
            expr_shape = get_shape(expr)
            if all(expr_shape[i] == 1 for i in flat_cyclic_form):
                # Only size-1 axes are moved: the permutation is a no-op.
                return mat_mul_lines
            # Non-trivial axes are moved: the permutation must be kept,
            # otherwise the result would have the wrong axis order.
            return _permute_dims(mat_mul_lines, expr.permutation)
        # TODO: this assumes that all arguments are matrices, it may not be the case:
        n_matrices = len(mat_mul_lines.args)
        permutation = Permutation(2 * n_matrices - 1) * expr.permutation
        permuted = [permutation(i) for i in range(2 * n_matrices)]
        args_array = [None] * n_matrices
        for i in range(n_matrices):
            p1 = permuted[2 * i]
            p2 = permuted[2 * i + 1]
            if p1 // 2 != p2 // 2:
                # The two axes come from different matrices: no simple
                # transpose/reorder exists, keep the permutation explicit.
                return _permute_dims(mat_mul_lines, permutation)
            pos = p1 // 2
            if p1 > p2:
                args_array[i] = _a2m_transpose(
                    mat_mul_lines.args[pos])  # type: ignore
            else:
                args_array[i] = mat_mul_lines.args[pos]  # type: ignore
        return _a2m_tensor_product(*args_array)
    else:
        return expr
Пример #6
0
def _(expr: ArrayAdd):
    """Remove trivial (size-1) dimensions from every addend of an ``ArrayAdd``.

    Returns ``(new_expr, removed)`` where ``removed`` lists the dropped axis
    positions. The simplification is applied only when all simplified
    addends end up with the same shape. Otherwise, or when there are no
    addends, the original expression is returned with nothing removed.
    """
    rec = [_remove_trivial_dims(arg) for arg in expr.args]
    if not rec:
        # ``zip(*[])`` cannot be unpacked into two names, so this guard must
        # come before the unpacking below.
        return expr, []
    newargs, removed = zip(*rec)
    if len({get_shape(i) for i in newargs}) > 1:
        return expr, []
    # NOTE(review): only the first addend's removed axes are reported;
    # presumably they agree across addends, but this is not checked.
    return _a2m_add(*newargs), removed[0]
Пример #7
0
 def __new__(
     cls,
     parent: sp.Expr,
     indices: Tuple[Union[sp.Basic, int, slice], ...],
 ) -> "ArraySlice":
     """Create a slice of ``parent`` with every index normalized against its axis.

     ``slice`` objects are turned into an ``sp.Tuple`` via ``normalize``.
     Other indices go through ``_normalize_index`` and are sympified.
     Processing stops at the first ``None`` index, which is also the fill
     value ``zip_longest`` supplies once fewer indices than axes are given.
     """
     axis_sizes = get_shape(parent)
     normalized = []
     for index, size in zip_longest(indices, axis_sizes):
         if index is None:
             break
         normalized.append(
             sp.Tuple(*normalize(index, size))
             if isinstance(index, slice)
             else _sympify(_normalize_index(index, size))
         )
     return sp.Expr.__new__(cls, parent, sp.Tuple(*normalized))
Пример #8
0
 def __new__(cls, parent: sp.Expr, indices: Iterable) -> "ArrayElement":
     """Create the element ``parent[indices]``.

     Args:
         parent: the array expression being indexed.
         indices: the index values. Any iterable is accepted, including a
             one-shot iterator or generator.

     When ``parent`` has a known (non-empty) shape, each index is passed
     through ``_normalize_index`` together with its axis size (presumably
     wrapping negative indices; see that helper).

     Raises:
         ValueError: if an index is definitely ``>=`` its axis size.
             Symbolic comparisons that do not evaluate to ``True`` pass.
         IndexError: if more indices are given than ``parent`` has
             dimensions.
     """
     # Materialize once: ``indices`` is read several times below, and a
     # generator would otherwise be exhausted after the first pass.
     indices = tuple(indices)
     sympified_indices = sp.Tuple(*map(_sympify, indices))
     parent_shape = get_shape(parent)
     if any(
         (i >= s) == True  # noqa: E712
         for i, s in zip(sympified_indices, parent_shape)
     ):
         raise ValueError("shape is out of bounds")
     if len(parent_shape):
         if len(sympified_indices) > len(parent_shape):
             raise IndexError(
                 f"Too many indices for {cls.__name__}: parent"
                 f" {type(parent).__name__} is"
                 f" {len(parent_shape)}-dimensional, but"
                 f" {len(sympified_indices)} indices were given"
             )
         normalized_indices = [
             _normalize_index(i, axis_size)
             for i, axis_size in zip(indices, parent_shape)
         ]
     else:
         normalized_indices = list(indices)
     return sp.Expr.__new__(cls, parent, sp.Tuple(*normalized_indices))
Пример #9
0
 def do_convert(self, expr, indices):
     """Rewrite the array expression ``expr`` as a scalar expression at ``indices``.

     Dispatches on the type of ``expr`` and recurses into its operands.
     ``indices`` presumably holds one index per axis of ``expr``.

     * Contractions become ``Sum`` objects over fresh ``Dummy`` indices.
       These are named from ``self.count_dummies``, which is incremented as
       a side effect.
     * Anything not handled explicitly falls back to
       ``_get_array_element_or_slice(expr, indices)``.
     """
     if isinstance(expr, ArrayTensorProduct):
         # Split the index list into consecutive groups (one per factor, of
         # length equal to the factor's rank) and multiply the entries.
         cumul = list(accumulate([0] + [get_rank(arg) for arg in expr.args]))
         indices_grp = [indices[cumul[i]:cumul[i+1]] for i in range(len(expr.args))]
         return Mul.fromiter(self.do_convert(arg, ind) for arg, ind in zip(expr.args, indices_grp))
     if isinstance(expr, ArrayContraction):
         # Each contraction group shares one fresh dummy index summed over
         # 0..dim-1, where dim is the size of the group's first axis.
         new_indices = [None for i in range(get_rank(expr.expr))]
         limits = []
         bottom_shape = get_shape(expr.expr)
         for contraction_index_grp in expr.contraction_indices:
             d = Dummy(f"d{self.count_dummies}")
             self.count_dummies += 1
             dim = bottom_shape[contraction_index_grp[0]]
             limits.append((d, 0, dim-1))
             for i in contraction_index_grp:
                 new_indices[i] = d
         # The remaining (free) axes take the caller's indices in order.
         j = 0
         for i in range(len(new_indices)):
             if new_indices[i] is None:
                 new_indices[i] = indices[j]
                 j += 1
         newexpr = self.do_convert(expr.expr, new_indices)
         return Sum(newexpr, *limits)
     if isinstance(expr, ArrayDiagonal):
         # ``_push_indices_down`` (defined elsewhere) appears to map each axis
         # of ``expr`` to its position(s) in the operand. An iterable entry
         # means several operand axes share the same index (the diagonal).
         new_indices = [None for i in range(get_rank(expr.expr))]
         ind_pos = expr._push_indices_down(expr.diagonal_indices, list(range(len(indices))), get_rank(expr))
         for i, index in zip(ind_pos, indices):
             if isinstance(i, collections.abc.Iterable):
                 for j in i:
                     new_indices[j] = index
             else:
                 new_indices[i] = index
         newexpr = self.do_convert(expr.expr, new_indices)
         return newexpr
     if isinstance(expr, PermuteDims):
         permuted_indices = _apply_permutation_to_list(expr.permutation, indices)
         return self.do_convert(expr.expr, permuted_indices)
     if isinstance(expr, ArrayAdd):
         return Add.fromiter(self.do_convert(arg, indices) for arg in expr.args)
     if isinstance(expr, _ArrayExpr):
         return expr.__getitem__(tuple(indices))
     if isinstance(expr, ArrayElementwiseApplyFunc):
         return expr.function(self.do_convert(expr.expr, indices))
     if isinstance(expr, Reshape):
         # Flatten the output indices into one linear offset (row-major: the
         # last index has stride 1), then unravel that offset against the
         # operand's shape.
         shape_up = expr.shape
         shape_down = get_shape(expr.expr)
         cumul = list(accumulate([1] + list(reversed(shape_up)), operator.mul))
         one_index = Add.fromiter(i*s for i, s in zip(reversed(indices), cumul))
         dest_indices = [None for _ in shape_down]
         # Walk the operand's axes from last to first; ``c`` is the product
         # of the sizes already visited (the current axis's stride).
         c = 1
         for i, e in enumerate(reversed(shape_down)):
             if c == 1:
                 if i == len(shape_down) - 1:
                     dest_indices[i] = one_index
                 else:
                     dest_indices[i] = one_index % e
             elif i == len(shape_down) - 1:
                 # Outermost axis: no modulo needed.
                 dest_indices[i] = one_index // c
             else:
                 dest_indices[i] = one_index // c % e
             c *= e
         dest_indices.reverse()
         return self.do_convert(expr.expr, dest_indices)
     return _get_array_element_or_slice(expr, indices)
Пример #10
0
def _convert_indexed_to_array(expr):
    """Convert an indexed (component-wise) expression into an array expression.

    Recursively handles ``Sum``, ``Mul``, ``MatrixElement``, ``ArrayElement``,
    ``Indexed``, ``KroneckerDelta``, ``Add``, integer ``Pow`` and
    ``Function`` nodes.

    Returns:
        A pair ``(array_expr, indices)``. ``indices`` lists the free indices
        labelling the axes of ``array_expr`` in order. An entry may be a
        ``frozenset`` of indices that were identified with each other
        (e.g. through ``KroneckerDelta``). Unrecognised expressions are
        returned unchanged together with ``()``.

    Raises:
        NotImplementedError: for a bare ``IndexedBase``.
        ValueError: if a summation range does not span the whole array axis.
    """
    if isinstance(expr, Sum):
        function = expr.function
        summation_indices = expr.variables
        subexpr, subindices = _convert_indexed_to_array(function)
        # Map every index that belongs to a frozenset group to that group, so
        # summing over any member refers to the whole group.
        subindicessets = {
            j: i
            for i in subindices if isinstance(i, frozenset) for j in i
        }
        summation_indices = sorted(set(
            [subindicessets.get(i, i) for i in summation_indices]),
                                   key=default_sort_key)
        # TODO: check that Kronecker delta is only contracted to one other element:
        # Collect the indices linked by two-index KroneckerDelta factors;
        # contractions over these are skipped below.
        kronecker_indices = set([])
        if isinstance(function, Mul):
            for arg in function.args:
                if not isinstance(arg, KroneckerDelta):
                    continue
                arg_indices = sorted(set(arg.indices), key=default_sort_key)
                if len(arg_indices) == 2:
                    kronecker_indices.update(arg_indices)
        kronecker_indices = sorted(kronecker_indices, key=default_sort_key)
        # Check dimensional consistency:
        shape = get_shape(subexpr)
        if shape:
            for ind, istart, iend in expr.limits:
                i = _get_argindex(subindices, ind)
                if istart != 0 or iend + 1 != shape[i]:
                    raise ValueError(
                        "summation index and array dimension mismatch: %s" %
                        ind)
        contraction_indices = []
        subindices = list(subindices)
        if isinstance(subexpr, ArrayDiagonal):
            # Summing over a diagonal index turns that diagonal into a
            # contraction. The diagonal indices are taken from the tail of
            # ``subindices`` (assumes the diagonal axes are placed last).
            diagonal_indices = list(subexpr.diagonal_indices)
            dindices = subindices[-len(diagonal_indices):]
            subindices = subindices[:-len(diagonal_indices)]
            for index in summation_indices:
                if index in dindices:
                    position = dindices.index(index)
                    contraction_indices.append(diagonal_indices[position])
                    diagonal_indices[position] = None
            diagonal_indices = [i for i in diagonal_indices if i is not None]
            # NOTE(review): this loop has no effect (its body is ``pass``).
            for i, ind in enumerate(subindices):
                if ind in summation_indices:
                    pass
            if diagonal_indices:
                subexpr = _array_diagonal(subexpr.expr, *diagonal_indices)
            else:
                subexpr = subexpr.expr

        # Group axes by summation index: every group becomes one contraction,
        # and the contracted axes are removed from the free indices.
        axes_contraction = defaultdict(list)
        for i, ind in enumerate(subindices):
            include = all(j not in kronecker_indices
                          for j in ind) if isinstance(
                              ind, frozenset) else ind not in kronecker_indices
            if ind in summation_indices and include:
                axes_contraction[ind].append(i)
                subindices[i] = None
        for k, v in axes_contraction.items():
            if any(i in kronecker_indices for i in k) if isinstance(
                    k, frozenset) else k in kronecker_indices:
                continue
            contraction_indices.append(tuple(v))
        free_indices = [i for i in subindices if i is not None]
        indices_ret = list(free_indices)
        # NOTE(review): sorting a copy of ``free_indices`` by position in
        # ``free_indices`` leaves its order unchanged.
        indices_ret.sort(key=lambda x: free_indices.index(x))
        return _array_contraction(
            subexpr, *contraction_indices,
            free_indices=free_indices), tuple(indices_ret)
    if isinstance(expr, Mul):
        args, indices = zip(
            *[_convert_indexed_to_array(arg) for arg in expr.args])
        # Check if there are KroneckerDelta objects:
        # Merge the index pairs of all KroneckerDelta factors into frozenset
        # groups of mutually identified indices.
        kronecker_delta_repl = {}
        for arg in args:
            if not isinstance(arg, KroneckerDelta):
                continue
            # Diagonalize two indices:
            i, j = arg.indices
            kindices = set(arg.indices)
            if i in kronecker_delta_repl:
                kindices.update(kronecker_delta_repl[i])
            if j in kronecker_delta_repl:
                kindices.update(kronecker_delta_repl[j])
            kindices = frozenset(kindices)
            for index in kindices:
                kronecker_delta_repl[index] = kindices
        # Remove KroneckerDelta objects, their relations should be handled by
        # ArrayDiagonal:
        newargs = []
        newindices = []
        for arg, loc_indices in zip(args, indices):
            if isinstance(arg, KroneckerDelta):
                continue
            newargs.append(arg)
            newindices.append(loc_indices)
        # Repeated (or KroneckerDelta-linked) indices across the factors
        # become diagonal constraints on their tensor product.
        flattened_indices = [
            kronecker_delta_repl.get(j, j) for i in newindices for j in i
        ]
        diagonal_indices, ret_indices = _get_diagonal_indices(
            flattened_indices)
        tp = _array_tensor_product(*newargs)
        if diagonal_indices:
            return _array_diagonal(tp, *diagonal_indices), ret_indices
        else:
            return tp, ret_indices
    if isinstance(expr, MatrixElement):
        indices = expr.args[1:]
        diagonal_indices, ret_indices = _get_diagonal_indices(indices)
        if diagonal_indices:
            return _array_diagonal(expr.args[0],
                                   *diagonal_indices), ret_indices
        else:
            return expr.args[0], ret_indices
    if isinstance(expr, ArrayElement):
        indices = expr.indices
        diagonal_indices, ret_indices = _get_diagonal_indices(indices)
        if diagonal_indices:
            return _array_diagonal(expr.name, *diagonal_indices), ret_indices
        else:
            return expr.name, ret_indices
    if isinstance(expr, Indexed):
        indices = expr.indices
        diagonal_indices, ret_indices = _get_diagonal_indices(indices)
        if diagonal_indices:
            return _array_diagonal(expr.base, *diagonal_indices), ret_indices
        else:
            return expr.args[0], ret_indices
    if isinstance(expr, IndexedBase):
        raise NotImplementedError
    if isinstance(expr, KroneckerDelta):
        # Kept as-is here; a surrounding Mul turns it into a diagonal.
        return expr, expr.indices
    if isinstance(expr, Add):
        args, indices = zip(
            *[_convert_indexed_to_array(arg) for arg in expr.args])
        args = list(args)
        # Check if all indices are compatible. Otherwise expand the dimensions:
        # ``index0``/``shape0`` accumulate the union of all free indices (in
        # first-seen order) with their axis sizes.
        index0 = []
        shape0 = []
        for arg, arg_indices in zip(args, indices):
            arg_indices_set = set(arg_indices)
            arg_indices_missing = arg_indices_set.difference(index0)
            index0.extend([i for i in arg_indices if i in arg_indices_missing])
            arg_shape = get_shape(arg)
            shape0.extend([
                arg_shape[i] for i, e in enumerate(arg_indices)
                if e in arg_indices_missing
            ])
        for i, (arg, arg_indices) in enumerate(zip(args, indices)):
            if len(arg_indices) < len(index0):
                # Broadcast over the missing indices by prepending a OneArray
                # factor of the corresponding sizes.
                missing_indices_pos = [
                    i for i, e in enumerate(index0) if e not in arg_indices
                ]
                missing_shape = [shape0[i] for i in missing_indices_pos]
                arg_indices = tuple(index0[j]
                                    for j in missing_indices_pos) + arg_indices
                args[i] = _array_tensor_product(OneArray(*missing_shape),
                                                args[i])
            permutation = Permutation([arg_indices.index(j) for j in index0])
            # Perform index permutations:
            args[i] = _permute_dims(args[i], permutation)
        return _array_add(*args), tuple(index0)
    if isinstance(expr, Pow):
        subexpr, subindices = _convert_indexed_to_array(expr.base)
        if isinstance(expr.exp, (int, Integer)):
            # base**n -> n-fold tensor product with axis 0 of every copy on
            # one diagonal and axis 1 of every copy on another.
            # NOTE(review): this pairing assumes a rank-2 base.
            diags = zip(*[(2 * i, 2 * i + 1) for i in range(expr.exp)])
            arr = _array_diagonal(
                _array_tensor_product(*[subexpr for i in range(expr.exp)]),
                *diags)
            return arr, subindices
        # NOTE(review): non-integer exponents fall through to the checks below.
    if isinstance(expr, Function):
        # Element-wise function application on the converted first argument.
        subexpr, subindices = _convert_indexed_to_array(expr.args[0])
        return ArrayElementwiseApplyFunc(type(expr), subexpr), subindices
    return expr, ()
Пример #11
0
def _(expr: ArrayTensorProduct):
    """Drop trivial (size-1) dimensions from a tensor product where possible.

    Returns ``(new_expr, removed)``. ``removed`` is the sorted list of axis
    positions, numbered as in ``expr``, that were dropped.
    """
    # Recognize expressions like [x, y] with shape (k, 1, k, 1) as `x*y.T`.
    # The matrix expression has to be equivalent to the tensor product of the
    # matrices, with trivial dimensions (i.e. dim=1) dropped.
    # That is, add contractions over trivial dimensions:

    removed = []
    newargs = []
    # cumul[i] is the position of the first axis of expr.args[i].
    cumul = list(accumulate([0] + [get_rank(arg) for arg in expr.args]))
    # ``pending`` is the free dimension k of the last factor still waiting
    # to be paired; ``prev_i`` is that factor's position in expr.args.
    pending = None
    prev_i = None
    for i, arg in enumerate(expr.args):
        current_range = list(range(cumul[i], cumul[i + 1]))
        if isinstance(arg, OneArray):
            removed.extend(current_range)
            continue
        if not isinstance(arg, (MatrixExpr, MatrixCommon)):
            # Non-matrix factor: simplify it recursively.
            rarg, rem = _remove_trivial_dims(arg)
            removed.extend(rem)
            newargs.append(rarg)
            continue
        elif getattr(arg, "is_Identity", False):
            if arg.shape == (1, 1):
                # Ignore identity matrices of shape (1, 1) - they are equivalent to scalar 1.
                removed.extend(current_range)
                continue
            k = arg.shape[0]
            if pending == k:
                # OK, there is already a pending factor of dimension k: skip
                # this identity and mark its axes as removed.
                removed.extend(current_range)
                continue
            elif pending is None:
                newargs.append(arg)
                pending = k
                prev_i = i
            else:
                pending = k
                prev_i = i
                newargs.append(arg)
        elif arg.shape == (1, 1):
            arg, _ = _remove_trivial_dims(arg)
            # Matrix is equivalent to scalar:
            # Fold it into the previous factor when that factor has a size-1
            # axis; otherwise keep it as a separate factor.
            if len(newargs) == 0:
                newargs.append(arg)
            elif 1 in get_shape(newargs[-1]):
                if newargs[-1].shape[1] == 1:
                    newargs[-1] = newargs[-1] * arg
                else:
                    newargs[-1] = arg * newargs[-1]
                removed.extend(current_range)
            else:
                newargs.append(arg)
        elif 1 in arg.shape:
            # Row or column vector of free dimension k.
            k = [i for i in arg.shape if i != 1][0]
            if pending is None:
                pending = k
                prev_i = i
                newargs.append(arg)
            elif pending == k:
                prev = newargs[-1]
                if prev.is_Identity:
                    removed.extend([cumul[prev_i], cumul[prev_i] + 1])
                    newargs[-1] = arg
                    prev_i = i
                    continue
                # Orient ``prev`` as a column and ``arg`` as a row so that
                # ``prev * arg`` is their outer product; record the size-1
                # axis of each as removed.
                if prev.shape[0] == 1:
                    d1 = cumul[prev_i]
                    prev = _a2m_transpose(prev)
                else:
                    d1 = cumul[prev_i] + 1
                if arg.shape[1] == 1:
                    d2 = cumul[i] + 1
                    arg = _a2m_transpose(arg)
                else:
                    d2 = cumul[i]
                newargs[-1] = prev * arg
                pending = None
                removed.extend([d1, d2])
            else:
                newargs.append(arg)
                pending = k
                prev_i = i
        else:
            newargs.append(arg)
            pending = None
    return _a2m_tensor_product(*newargs), sorted(removed)
Пример #12
0
def _(expr: ArrayDiagonal, x: Expr):
    """Differentiate a diagonal by differentiating its operand.

    The operand's derivative carries ``rank(x)`` extra leading axes, so each
    diagonal axis group is shifted by that rank before re-applying the
    diagonal.
    """
    derived = array_derive(expr.expr, x)
    offset = len(get_shape(x))
    return _array_diagonal(
        derived,
        *([axis + offset for axis in group] for group in expr.diagonal_indices),
    )
Пример #13
0
def _(expr: Reshape, x: Expr):
    """Differentiate a reshape: reshape the operand's derivative.

    The target shape is the shape of ``x`` followed by the reshape's shape.
    """
    derivative = array_derive(expr.expr, x)
    target_shape = get_shape(x) + expr.shape
    return Reshape(derivative, target_shape)