Example #1
0
def test_preprocess_for_cse():
    """Check preprocess_for_cse against a table of optimization pairs.

    Each case pairs a list of ``(pre, post)`` optimization tuples with the
    result expected from preprocessing ``x``. The cases expect only the
    first element of each pair to be applied, in list order.
    """
    cases = [
        ([(opt1, None)], x + y),
        ([(None, opt1)], x),
        ([(None, None)], x),
        ([(opt1, opt2)], x + y),
        ([(opt1, None), (opt2, None)], (x + y)*z),
    ]
    for optimizations, expected in cases:
        assert cse_main.preprocess_for_cse(x, optimizations) == expected
Example #2
0
    def collect(self, exprs):
        """Preprocess and parse one expression or a collection of expressions.

        Every input expression is first passed through ``preprocess_for_cse``
        with ``self._optimizations``; the optimized forms are then handed to
        ``self._parse`` in input order.

        Returns the single parsed result when ``exprs`` is a ``sympy.Basic``,
        a ``sympy.Matrix`` of the same shape when ``exprs`` is a
        ``sympy.Matrix``, and a list of parsed results otherwise.
        """
        # NOTE(review): in SymPy, immutable matrices subclass Basic, so they
        # would take the single-expression path here -- confirm intended.
        is_single_expr = isinstance(exprs, sympy.Basic)
        batch = [exprs] if is_single_expr else exprs

        # All expressions are optimized before any of them is parsed.
        optimized = []
        for expr in batch:
            optimized.append(preprocess_for_cse(expr, self._optimizations))
        parsed = [self._parse(expr) for expr in optimized]

        if is_single_expr:
            return parsed[0]
        if isinstance(exprs, sympy.Matrix):
            # Rebuild a matrix with the caller's shape from the flat results.
            return sympy.Matrix(exprs.rows, exprs.cols, parsed)
        return parsed
Example #3
0
    def collect(self, exprs):
        """Optimize and parse a single expression or a batch of them.

        Each expression is rewritten by ``preprocess_for_cse`` using
        ``self._optimizations`` and the result is fed to ``self._parse``.

        A ``sympy.Basic`` argument yields one parsed result. A
        ``sympy.Matrix`` argument yields a ``sympy.Matrix`` with the same
        number of rows and columns. Any other iterable yields a list.
        """
        if not isinstance(exprs, sympy.Basic):
            # Collection input: optimize everything first, then parse, and
            # keep the original container around so a Matrix can be rebuilt.
            optimized = [preprocess_for_cse(item, self._optimizations)
                         for item in exprs]
            results = [self._parse(item) for item in optimized]
            if isinstance(exprs, sympy.Matrix):
                return sympy.Matrix(exprs.rows, exprs.cols, results)
            return results

        # A lone expression: optimize it, then parse it.
        return self._parse(preprocess_for_cse(exprs, self._optimizations))
Example #4
0
File: cse.py Project: rpep/fmmgen
def cse(exprs,
        symbols=None,
        optimizations=None,
        postprocess=None,
        order='canonical',
        ignore=(),
        light_ignore=()):
    """Perform common subexpression elimination on expressions.

    Matrices (dense or sparse, mutable or immutable) are flattened into
    ``Tuple`` objects before elimination and rebuilt afterwards with their
    original shape and mutability.

    Parameters
    ----------
    exprs : Basic, MatrixBase, or iterable of these
        The expression(s) to reduce. A single expression or matrix is
        wrapped in a list; any other iterable (including a one-shot
        iterator) is materialized into a list once.
    symbols : iterable of Symbol, optional
        Symbols used to name extracted subexpressions. Defaults to
        ``numbered_symbols(cls=Symbol)``.
    optimizations : list of (pre, post) pairs, 'basic', or None
        Passed to ``preprocess_for_cse`` / ``postprocess_for_cse``.
        ``'basic'`` selects ``basic_optimizations``; ``None`` means none.
    postprocess : callable, optional
        If given, the return value is
        ``postprocess(replacements, reduced_exprs)``.
    order : str
        Forwarded to ``opt_cse`` and ``tree_cse``.
    ignore, light_ignore :
        Forwarded unchanged to ``tree_cse``.

    Returns
    -------
    replacements : list of (Symbol, expr)
        Extracted subexpressions after post-processing.
    reduced_exprs : list
        The input expressions rewritten in terms of the replacement
        symbols, with matrices restored to their original types.
    """
    if isinstance(exprs, (Basic, MatrixBase)):
        exprs = [exprs]
    else:
        # The inputs are traversed twice (flattening below and matrix
        # restoration at the end); materialize so an iterator is not
        # exhausted before the second pass.
        exprs = list(exprs)

    original_exprs = exprs

    # Flatten matrices into Tuples so tree_cse sees plain expressions.
    # NOTE(review): `_mat` / `_smat` are private SymPy attributes (newer
    # SymPy releases deprecate them in favour of public accessors) --
    # confirm against the SymPy version this project pins.
    flat_exprs = []
    for e in original_exprs:
        if isinstance(e, (Matrix, ImmutableMatrix)):
            flat_exprs.append(Tuple(*e._mat))
        elif isinstance(e, (SparseMatrix, ImmutableSparseMatrix)):
            # Sparse matrices are carried as a Tuple of (key, value) items.
            flat_exprs.append(Tuple(*e._smat.items()))
        else:
            flat_exprs.append(e)

    if optimizations is None:
        optimizations = []
    elif optimizations == 'basic':
        optimizations = basic_optimizations

    # Preprocess the expressions to give us better optimization opportunities.
    reduced_exprs = [preprocess_for_cse(e, optimizations) for e in flat_exprs]

    if symbols is None:
        symbols = numbered_symbols(cls=Symbol)
    else:
        # In case we get passed an iterable with an __iter__ method instead of
        # an actual iterator.
        symbols = iter(symbols)

    # Find other optimization opportunities.
    opt_subs = opt_cse(reduced_exprs, order)

    # Main CSE algorithm.
    replacements, reduced_exprs = tree_cse(reduced_exprs, symbols, opt_subs,
                                           order, ignore, light_ignore)

    # Postprocess the expressions to return the expressions to canonical form.
    replacements = [(sym, postprocess_for_cse(subtree, optimizations))
                    for sym, subtree in replacements]
    reduced_exprs = [
        postprocess_for_cse(e, optimizations) for e in reduced_exprs
    ]

    # Rebuild matrices from the flattened Tuples, matching the input types.
    for i, e in enumerate(original_exprs):
        if isinstance(e, (Matrix, ImmutableMatrix)):
            reduced_exprs[i] = Matrix(e.rows, e.cols, reduced_exprs[i])
            if isinstance(e, ImmutableMatrix):
                reduced_exprs[i] = reduced_exprs[i].as_immutable()
        elif isinstance(e, (SparseMatrix, ImmutableSparseMatrix)):
            m = SparseMatrix(e.rows, e.cols, {})
            for k, v in reduced_exprs[i]:
                m[k] = v
            if isinstance(e, ImmutableSparseMatrix):
                m = m.as_immutable()
            reduced_exprs[i] = m

    if postprocess is None:
        return replacements, reduced_exprs

    return postprocess(replacements, reduced_exprs)