Example #1
0
def test_postprocess_for_cse():
    """Check that postprocess_for_cse runs only the post half of each pair."""
    # Each case: (optimizations passed in, expected result for input x).
    expectations = (
        ([(opt1, None)], x),
        ([(None, opt1)], x + y),
        ([(None, None)], x),
        ([(opt1, opt2)], x * z),
        # Note the reverse order of application.
        ([(None, opt1), (None, opt2)], x * z + y),
    )
    for optimizations, expected in expectations:
        assert cse_main.postprocess_for_cse(x, optimizations) == expected
Example #2
0
def test_postprocess_for_cse():
    """postprocess_for_cse applies only the post-processing half of each pair."""
    postprocess = cse_main.postprocess_for_cse

    # A pre-processor on its own leaves the expression untouched.
    assert postprocess(x, [(opt1, None)]) == x
    # A post-processor on its own is applied.
    assert postprocess(x, [(None, opt1)]) == x + y
    # No processors at all: identity.
    assert postprocess(x, [(None, None)]) == x
    # With both slots filled, only the post-processor (opt2) takes effect.
    assert postprocess(x, [(opt1, opt2)]) == x * z
    # Post-processors run in reverse list order: opt2 first, then opt1.
    assert postprocess(x, [(None, opt1), (None, opt2)]) == x * z + y
Example #3
0
    def get(self, exprs=None, symbols=None):
        """Return the CSE form of *exprs* using the recorded subexpressions.

        Intermediate variables (the values of ``self._subexp_iv``) that are
        referenced more than once are bound to fresh symbols drawn from
        *symbols*. Those referenced only once are inlined back into the
        expression that uses them.

        Parameters
        ----------
        exprs : sympy.Basic, sympy.Matrix or iterable of expressions/matrices
            Expressions written in terms of the intermediate variables. A
            single ``sympy.Basic`` is wrapped in a list. Despite the default,
            ``None`` is not accepted: materializing it raises ``TypeError``.
        symbols : iterable of symbols, optional
            Names for the extracted subexpressions. Defaults to
            ``sympy.utilities.iterables.numbered_symbols()``.

        Returns
        -------
        tuple
            ``(replacements, out_exprs)``. ``replacements`` is a list of
            ``(symbol, subexpr)`` pairs in the order the symbols were
            allocated. ``out_exprs`` has one entry per input; it is reshaped
            into a ``sympy.Matrix`` when *exprs* was a ``sympy.Matrix``. Both
            are passed through ``postprocess_for_cse`` with
            ``self._optimizations``. If ``self._postprocess`` is set, its
            result on those two values is returned instead.
        """
        if symbols is None:
            symbols = sympy.utilities.iterables.numbered_symbols()
        else:
            # In case we get passed an iterable with an __iter__ method
            # instead of an actual iterator.
            symbols = iter(symbols)

        if isinstance(exprs, sympy.Basic):  # if only one expression is passed
            exprs = [exprs]

        # Invert the subexpression -> intermediate-variable mapping.
        ivar_se = {iv: se for se, iv in self._subexp_iv.items()}

        # The inputs are walked twice (the counting pass, then the rebuilding
        # pass), so materialize them once. Otherwise a one-shot iterator
        # would be exhausted by the first pass and yield no output. `exprs`
        # itself is kept for the Matrix check at the end.
        items = list(exprs)

        # Pass 1: find intermediate variables referenced more than once.
        used_ivs = set()
        repeated = set()

        def _find_repeated_subexprs(subexpr):
            # NOTE(review): this only descends into the definitions of
            # intermediate variables. An intermediate variable nested inside
            # a compound argument that is not itself an intermediate variable
            # is not counted (only a missed CSE opportunity; the output stays
            # correct). Confirm callers build expressions accordingly.
            if subexpr.is_Atom:
                symbs = [subexpr]
            else:
                symbs = subexpr.args
            for symb in symbs:
                if symb in ivar_se:
                    if symb not in used_ivs:
                        _find_repeated_subexprs(ivar_se[symb])
                        used_ivs.add(symb)
                    else:
                        repeated.add(symb)

        for expr in items:
            if expr.is_Matrix:
                # Visit every entry directly. applyfunc() is meant for
                # building a new matrix from the callback's return values,
                # and this callback only returns None.
                for entry in expr:
                    _find_repeated_subexprs(entry)
            else:
                _find_repeated_subexprs(expr)

        # Pass 2: rebuild each expression. Repeated intermediate variables
        # become fresh symbols; the rest are inlined.
        tmpivs_ivs = {}
        ordered_iv_se = collections.OrderedDict()

        def _get_subexprs(subexpr):
            if not subexpr.is_Atom:
                args = list(map(_get_subexprs, subexpr.args))
                return type(subexpr)(*args)
            symb = subexpr
            if symb not in ivar_se:
                return symb
            if symb in tmpivs_ivs:
                # Already bound to a symbol on an earlier visit.
                return tmpivs_ivs[symb]
            definition = ivar_se[symb]
            args = list(map(_get_subexprs, definition.args))
            rebuilt = type(definition)(*args)
            if symb not in repeated:
                return rebuilt
            ivar = next(symbols)
            ordered_iv_se[ivar] = rebuilt
            tmpivs_ivs[symb] = ivar
            return ivar

        out_exprs = [expr.applyfunc(_get_subexprs) if expr.is_Matrix
                     else _get_subexprs(expr)
                     for expr in items]

        # Postprocess the expressions to return the expressions to canonical
        # form.
        ordered_iv_se = collections.OrderedDict(
            (ivar, postprocess_for_cse(subexpr, self._optimizations))
            for ivar, subexpr in ordered_iv_se.items())
        out_exprs = [postprocess_for_cse(e, self._optimizations)
                     for e in out_exprs]

        if isinstance(exprs, sympy.Matrix):
            out_exprs = sympy.Matrix(exprs.rows, exprs.cols, out_exprs)
        replacements = list(ordered_iv_se.items())
        if self._postprocess is None:
            return replacements, out_exprs
        return self._postprocess(replacements, out_exprs)
Example #4
0
File: cse.py  Project: rpep/fmmgen
def cse(exprs,
        symbols=None,
        optimizations=None,
        postprocess=None,
        order='canonical',
        ignore=(),
        light_ignore=()):
    """Perform common subexpression elimination on one or more expressions.

    Parameters
    ----------
    exprs : Basic, MatrixBase, or iterable of them
        The expressions to reduce. A single expression or matrix is wrapped
        in a list. Dense and sparse matrices are flattened for the CSE pass
        and rebuilt (with the same mutability) in the result.
    symbols : iterable, optional
        Source of names for the extracted subexpressions. Defaults to
        ``numbered_symbols(cls=Symbol)``.
    optimizations : list of (preprocessor, postprocessor) pairs, 'basic', or None
        ``None`` means no optimizations; ``'basic'`` selects
        ``basic_optimizations``. Passed to ``preprocess_for_cse`` and
        ``postprocess_for_cse``.
    postprocess : callable, optional
        If given, the return value is
        ``postprocess(replacements, reduced_exprs)``.
    order : str
        Forwarded to ``opt_cse`` and ``tree_cse``.
    ignore, light_ignore
        Forwarded to ``tree_cse``. NOTE(review): ``light_ignore`` is a
        project-specific extension; confirm its meaning in ``tree_cse``.

    Returns
    -------
    tuple
        ``(replacements, reduced_exprs)``: a list of ``(symbol, subtree)``
        pairs and a list of reduced expressions, one per input.
    """
    if isinstance(exprs, (Basic, MatrixBase)):
        exprs = [exprs]

    # Materialize once: the inputs are walked twice (flattening here and
    # rebuilding matrices at the end), so a one-shot iterator must not be
    # consumed by the first pass.
    originals = list(exprs)

    flattened = []
    for e in originals:
        if isinstance(e, (Matrix, ImmutableMatrix)):
            # Row-major entries (same as the deprecated private ``_mat``).
            flattened.append(Tuple(*e))
        elif isinstance(e, (SparseMatrix, ImmutableSparseMatrix)):
            # ((row, col), value) pairs for the stored entries, replacing the
            # deprecated private ``_smat`` dict.
            flattened.append(Tuple(*[((r, c), v) for r, c, v in e.row_list()]))
        else:
            flattened.append(e)

    if optimizations is None:
        optimizations = []
    elif optimizations == 'basic':
        optimizations = basic_optimizations

    # Preprocess the expressions to give us better optimization opportunities.
    reduced_exprs = [preprocess_for_cse(e, optimizations) for e in flattened]

    if symbols is None:
        symbols = numbered_symbols(cls=Symbol)
    else:
        # In case we get passed an iterable with an __iter__ method instead of
        # an actual iterator.
        symbols = iter(symbols)

    # Find other optimization opportunities.
    opt_subs = opt_cse(reduced_exprs, order)

    # Main CSE algorithm.
    replacements, reduced_exprs = tree_cse(reduced_exprs, symbols, opt_subs,
                                           order, ignore, light_ignore)

    # Postprocess the expressions to return the expressions to canonical form.
    replacements = [(sym, postprocess_for_cse(subtree, optimizations))
                    for sym, subtree in replacements]
    reduced_exprs = [postprocess_for_cse(e, optimizations)
                     for e in reduced_exprs]

    # Get the matrices back, preserving dense/sparse and mutability.
    for i, e in enumerate(originals):
        if isinstance(e, (Matrix, ImmutableMatrix)):
            reduced_exprs[i] = Matrix(e.rows, e.cols, reduced_exprs[i])
            if isinstance(e, ImmutableMatrix):
                reduced_exprs[i] = reduced_exprs[i].as_immutable()
        elif isinstance(e, (SparseMatrix, ImmutableSparseMatrix)):
            m = SparseMatrix(e.rows, e.cols, {})
            for k, v in reduced_exprs[i]:
                m[k] = v
            if isinstance(e, ImmutableSparseMatrix):
                m = m.as_immutable()
            reduced_exprs[i] = m

    if postprocess is None:
        return replacements, reduced_exprs

    return postprocess(replacements, reduced_exprs)
Example #5
0
    def get(self, exprs=None, symbols=None):
        """Return the CSE form of *exprs* using the recorded subexpressions.

        Intermediate variables (the values of ``self._subexp_iv``) that are
        referenced more than once are bound to fresh symbols drawn from
        *symbols*. Those referenced only once are inlined back into the
        expression that uses them.

        Parameters
        ----------
        exprs : sympy.Basic, sympy.Matrix or iterable of expressions/matrices
            Expressions written in terms of the intermediate variables. A
            single ``sympy.Basic`` is wrapped in a list. Despite the default,
            ``None`` is not accepted: materializing it raises ``TypeError``.
        symbols : iterable of symbols, optional
            Names for the extracted subexpressions. Defaults to
            ``sympy.utilities.iterables.numbered_symbols()``.

        Returns
        -------
        tuple
            ``(replacements, out_exprs)``. ``replacements`` is a list of
            ``(symbol, subexpr)`` pairs in the order the symbols were
            allocated. ``out_exprs`` has one entry per input; it is reshaped
            into a ``sympy.Matrix`` when *exprs* was a ``sympy.Matrix``. Both
            are passed through ``postprocess_for_cse`` with
            ``self._optimizations``. If ``self._postprocess`` is set, its
            result on those two values is returned instead.
        """
        if symbols is None:
            symbols = sympy.utilities.iterables.numbered_symbols()
        else:
            # In case we get passed an iterable with an __iter__ method
            # instead of an actual iterator.
            symbols = iter(symbols)

        if isinstance(exprs, sympy.Basic):  # if only one expression is passed
            exprs = [exprs]

        # Invert the subexpression -> intermediate-variable mapping.
        ivar_se = {iv: se for se, iv in self._subexp_iv.items()}

        # The inputs are walked twice (the counting pass, then the rebuilding
        # pass), so materialize them once. Otherwise a one-shot iterator
        # would be exhausted by the first pass and yield no output. `exprs`
        # itself is kept for the Matrix check at the end.
        items = list(exprs)

        # Pass 1: find intermediate variables referenced more than once.
        used_ivs = set()
        repeated = set()

        def _find_repeated_subexprs(subexpr):
            # NOTE(review): this only descends into the definitions of
            # intermediate variables. An intermediate variable nested inside
            # a compound argument that is not itself an intermediate variable
            # is not counted (only a missed CSE opportunity; the output stays
            # correct). Confirm callers build expressions accordingly.
            if subexpr.is_Atom:
                symbs = [subexpr]
            else:
                symbs = subexpr.args
            for symb in symbs:
                if symb in ivar_se:
                    if symb not in used_ivs:
                        _find_repeated_subexprs(ivar_se[symb])
                        used_ivs.add(symb)
                    else:
                        repeated.add(symb)

        for expr in items:
            if expr.is_Matrix:
                # Visit every entry directly. applyfunc() is meant for
                # building a new matrix from the callback's return values,
                # and this callback only returns None.
                for entry in expr:
                    _find_repeated_subexprs(entry)
            else:
                _find_repeated_subexprs(expr)

        # Pass 2: rebuild each expression. Repeated intermediate variables
        # become fresh symbols; the rest are inlined.
        tmpivs_ivs = {}
        ordered_iv_se = collections.OrderedDict()

        def _get_subexprs(subexpr):
            if not subexpr.is_Atom:
                args = list(map(_get_subexprs, subexpr.args))
                return type(subexpr)(*args)
            symb = subexpr
            if symb not in ivar_se:
                return symb
            if symb in tmpivs_ivs:
                # Already bound to a symbol on an earlier visit.
                return tmpivs_ivs[symb]
            definition = ivar_se[symb]
            args = list(map(_get_subexprs, definition.args))
            rebuilt = type(definition)(*args)
            if symb not in repeated:
                return rebuilt
            ivar = next(symbols)
            ordered_iv_se[ivar] = rebuilt
            tmpivs_ivs[symb] = ivar
            return ivar

        out_exprs = [
            expr.applyfunc(_get_subexprs) if expr.is_Matrix
            else _get_subexprs(expr)
            for expr in items
        ]

        # Postprocess the expressions to return the expressions to canonical
        # form.
        ordered_iv_se = collections.OrderedDict(
            (ivar, postprocess_for_cse(subexpr, self._optimizations))
            for ivar, subexpr in ordered_iv_se.items()
        )
        out_exprs = [
            postprocess_for_cse(e, self._optimizations) for e in out_exprs
        ]

        if isinstance(exprs, sympy.Matrix):
            out_exprs = sympy.Matrix(exprs.rows, exprs.cols, out_exprs)
        replacements = list(ordered_iv_se.items())
        if self._postprocess is None:
            return replacements, out_exprs
        return self._postprocess(replacements, out_exprs)