Ejemplo n.º 1
0
def freeze_expression(expr):
    """
    Rebuild ``expr`` bottom-up, replacing each :class:`sympy.Add`,
    :class:`sympy.Mul` and :class:`sympy.Pow` node with the module's ``Add``,
    ``Mul`` and ``Pow`` (the devito counterparts for Add/Mul), always built
    with ``evaluate=False`` so no simplification happens on reconstruction.

    Atoms and Indexed objects are returned as they are. Equalities get frozen
    operands: a :class:`FrozenExpr` equality is rebuilt through its own class
    (keeping any metadata it carries), any other becomes
    ``Eq(..., evaluate=False)``. Every remaining node is reconstructed via
    ``expr.func`` applied to its frozen arguments.
    """
    if expr.is_Atom or expr.is_Indexed:
        return expr

    # Every non-leaf case needs the frozen children, so compute them once
    frozen_args = [freeze_expression(arg) for arg in expr.args]

    if expr.is_Add:
        return Add(*frozen_args, evaluate=False)
    if expr.is_Mul:
        return Mul(*frozen_args, evaluate=False)
    if expr.is_Pow:
        return Pow(*frozen_args, evaluate=False)
    if expr.is_Equality:
        if isinstance(expr, FrozenExpr):
            # Rebuild through the original class so its metadata survives
            return expr.func(*frozen_args)
        return Eq(*frozen_args, evaluate=False)
    return expr.func(*frozen_args)
Ejemplo n.º 2
0
def freeze(expr):
    """
    Return a copy of ``expr`` in which Add, Mul and Pow nodes are rebuilt with
    the module's ``Add``/``Mul``/``Pow`` classes (presumably the FrozenExpr
    equivalents) using ``evaluate=False``, recursively.

    Leaves (atoms and Indexed objects) are returned unchanged. An equality
    whose type is a :class:`FrozenExpr` is rebuilt with its own class so that
    attached metadata is kept; other equalities become
    ``Eq(..., evaluate=False)``. Any other node is rebuilt via ``expr.func``.
    """
    is_leaf = expr.is_Atom or expr.is_Indexed
    if is_leaf:
        return expr

    new_args = [freeze(sub) for sub in expr.args]

    if expr.is_Add:
        result = Add(*new_args, evaluate=False)
    elif expr.is_Mul:
        result = Mul(*new_args, evaluate=False)
    elif expr.is_Pow:
        result = Pow(*new_args, evaluate=False)
    elif not expr.is_Equality:
        result = expr.func(*new_args)
    elif isinstance(expr, FrozenExpr):
        # Going through expr.func preserves the metadata of /expr/
        result = expr.func(*new_args)
    else:
        result = Eq(*new_args, evaluate=False)
    return result
Ejemplo n.º 3
0
def freeze_expression(expr):
    """
    Reconstruct ``expr`` turning all :class:`sympy.Mul` and :class:`sympy.Add`
    into, respectively, :class:`devito.Mul` and :class:`devito.Add`.

    Power nodes are rebuilt with their own class but with ``evaluate=False``,
    and equalities become ``Eq(..., evaluate=False)``. Atoms and Indexed
    objects are returned unchanged; any other node is rebuilt through
    ``expr.func`` applied to its recursively frozen arguments.
    """
    if expr.is_Atom or expr.is_Indexed:
        return expr
    elif expr.is_Add:
        rebuilt_args = [freeze_expression(e) for e in expr.args]
        return Add(*rebuilt_args, evaluate=False)
    elif expr.is_Mul:
        rebuilt_args = [freeze_expression(e) for e in expr.args]
        return Mul(*rebuilt_args, evaluate=False)
    elif expr.is_Pow:
        rebuilt_args = [freeze_expression(e) for e in expr.args]
        # Rebuilding a Pow with evaluation enabled may let SymPy re-evaluate
        # its (just frozen) base, e.g. distributing an integer exponent over
        # a Mul, which would undo the freezing. Keep it unevaluated.
        return expr.func(*rebuilt_args, evaluate=False)
    elif expr.is_Equality:
        rebuilt_args = [freeze_expression(e) for e in expr.args]
        # NOTE(review): Eq subclasses carrying metadata are rebuilt as a plain
        # Eq here and lose that metadata -- confirm this is acceptable.
        return Eq(*rebuilt_args, evaluate=False)
    else:
        return expr.func(*[freeze_expression(e) for e in expr.args])
Ejemplo n.º 4
0
def xreplace_constrained(exprs,
                         make,
                         rule=None,
                         costmodel=lambda e: True,
                         repeat=False):
    """
    Unlike ``xreplace``, which replaces all objects specified in a mapper,
    this function replaces all objects satisfying two criteria: ::

        * The "matching rule" -- a function returning True if a node within ``expr``
            satisfies a given property, and as such should be replaced;
        * A "cost model" -- a function triggering replacement only if a certain
            cost (e.g., operation count) is exceeded. This function is optional.

    Note that there is not necessarily a relationship between the set of nodes
    for which the matching rule returns True and those nodes passing the cost
    model check. It might happen for example that, given the expression ``a + b``,
    all of ``a``, ``b``, and ``a + b`` satisfy the matching rule, but only
    ``a + b`` satisfies the cost model.

    :param exprs: The target SymPy equation, or a collection of SymPy equations
                  (each must have ``is_Equality``; only the right-hand sides
                  are searched).
    :param make: Either a mapper M: K -> V, indicating how to replace an expression
                 in K with a symbol in V, or a function with internal state that,
                 when called, returns unique symbols.
    :param rule: The matching rule (a lambda function). May be left unspecified if
                 ``make`` is a mapper.
    :param costmodel: The cost model (a lambda function, optional).
    :param repeat: Repeatedly apply ``xreplace`` until no more replacements are
                   possible (optional, defaults to False).
    :returns: A 2-tuple ``(found + rebuilt, found)``. ``found`` is the list of
              ``Eq(temporary, extracted_expression)``, one per temporary
              created through a callable ``make`` (empty if ``make`` is a
              mapper); ``rebuilt`` holds the input equations with the
              replacements applied, in input order.
    """
    found = OrderedDict()
    rebuilt = []

    # Define /replace()/ based on the user-provided /make/
    if isinstance(make, dict):
        rule = rule if rule is not None else (lambda i: i in make)
        replace = lambda i: make[i]
    else:
        assert callable(make) and callable(rule)

        def replace(expr):
            # Identical expressions share the same temporary
            temporary = found.get(expr)
            if temporary is None:
                temporary = make()
                found[expr] = temporary
            return temporary

    def run(expr):
        # Return (rebuilt_expr, flag). A True flag means rebuilt_expr entirely
        # satisfies the matching rule, so the caller may try to absorb it into
        # a larger matching expression instead of replacing it right away.
        if expr.is_Atom or expr.is_Indexed:
            return expr, rule(expr)
        elif expr.is_Pow:
            base, flag = run(expr.base)
            if flag and costmodel(base):
                return expr.func(replace(base), expr.exp,
                                 evaluate=False), False
            else:
                return expr.func(base, expr.exp, evaluate=False), flag
        else:
            children = [run(a) for a in expr.args]
            # Partition on the flag itself; filtering by membership in
            # `matching` would drop any non-flagged child that compares equal
            # to a flagged one (and costs O(n^2) comparisons)
            matching = [a for a, flag in children if flag]
            other = [a for a, flag in children if not flag]
            if matching:
                matched = expr.func(*matching, evaluate=False)
                if len(matching) == len(children) and rule(expr):
                    # Go look for longer expressions first
                    return matched, True
                elif rule(matched) and costmodel(matched):
                    # Replace what I can replace, then give up
                    new_expr = expr.func(*(other + [replace(matched)]),
                                         evaluate=False)
                    return new_expr, False
                else:
                    # Replace flagged children, then give up
                    replaced = []
                    unreplaced = []
                    for e in matching:
                        if costmodel(e):
                            replaced.append(replace(e))
                        else:
                            unreplaced.append(e)
                    new_expr = expr.func(*(other + replaced + unreplaced),
                                         evaluate=False)
                    return new_expr, False
            return expr.func(*other, evaluate=False), False

    # Process the provided expressions
    for expr in as_tuple(exprs):
        assert expr.is_Equality
        root = expr.rhs

        while True:
            ret, flag = run(root)
            if isinstance(make, dict) and root.is_Atom and flag:
                rebuilt.append(
                    expr.func(expr.lhs, replace(root), evaluate=False))
                break
            elif repeat and ret != root:
                root = ret
            else:
                rebuilt.append(expr.func(expr.lhs, ret, evaluate=False))
                break

    # Post-process the output
    found = [Eq(v, k) for k, v in found.items()]

    return found + rebuilt, found
Ejemplo n.º 5
0
def yreplace(exprs,
             make,
             rule=None,
             costmodel=lambda e: True,
             repeat=False,
             eager=False):
    """
    Unlike SymPy's ``xreplace``, which performs structural replacement based on a mapper,
    ``yreplace`` applies replacements using two callbacks:

        * The "matching rule" -- a boolean function telling whether an expression
          honors a certain property.
        * The "cost model" -- a boolean function telling whether an expression exceeds
          a certain (e.g., operation count) cost.

    Parameters
    ----------
    exprs : expr-like or list of expr-like
        One or more equations searched for replacements (each must have
        ``is_Equality``; only the right-hand sides are searched).
    make : dict or callable
        Either a mapper of substitution rules (just like in ``xreplace``), or
        a callable returning unique symbols each time it is called.
    rule : callable, optional
        The matching rule (see above). Unnecessary if ``make`` is a dict.
    costmodel : callable, optional
        The cost model (see above).
    repeat : bool, optional
        If True, repeatedly apply ``xreplace`` until no more replacements are
        possible. Defaults to False.
    eager : bool, optional
        If True, replaces an expression ``e`` as soon as the condition
        ``rule(e) and costmodel(e)`` is True. Otherwise, the search continues
        for larger, more expensive expressions. Defaults to False.

    Returns
    -------
    tuple
        ``(found + rebuilt, found)``, where ``found`` is the list of
        ``Eq(temporary, extracted_expression)``, one per temporary created
        through a callable ``make`` (empty if ``make`` is a dict), and
        ``rebuilt`` holds the input equations with replacements applied.

    Notes
    -----
    In general, there is no relationship between the set of expressions for which
    the matching rule gives True and the set of expressions passing the cost test.
    For example, in the expression `a + b` all of `a`, `b` and `a+b` may satisfy
    the matching rule, whereas only `a+b` satisfy the cost test. Likewise, an
    expression may pass the cost test, but not satisfy the matching rule.
    """
    found = OrderedDict()
    rebuilt = []

    # Define `replace()` based on the user-provided `make`
    if isinstance(make, dict):
        rule = rule if rule is not None else (lambda i: i in make)
        replace = lambda i: make[i]
    else:
        assert callable(make) and callable(rule)

        def replace(expr):
            # Identical expressions share the same temporary
            temporary = found.get(expr)
            if temporary is None:
                temporary = make()
                found[expr] = temporary
            return temporary

    def run(expr):
        # Return (rebuilt_expr, flag). A True flag means rebuilt_expr entirely
        # satisfies the matching rule, so the caller may try to absorb it into
        # a larger matching expression instead of replacing it right away.
        if expr.is_Atom or expr.is_Indexed:
            return expr, rule(expr)
        elif expr.is_Pow:
            base, flag = run(expr.base)
            if flag and costmodel(base):
                return expr.func(replace(base), expr.exp,
                                 evaluate=False), False
            elif flag and costmodel(expr):
                return replace(expr), False
            else:
                return expr.func(base, expr.exp, evaluate=False), rule(expr)
        else:
            children = [run(a) for a in expr.args]
            # Partition on the flag itself; filtering by membership in
            # `matching` would drop any non-flagged child that compares equal
            # to a flagged one (and costs O(n^2) comparisons)
            matching = [a for a, flag in children if flag]
            other = [a for a, flag in children if not flag]

            if not matching:
                return expr.func(*other, evaluate=False), False

            if eager is False:
                matched = expr.func(*matching, evaluate=False)
                if len(matching) == len(children) and rule(expr):
                    # Go look for larger expressions first
                    return matched, True
                elif rule(matched) and costmodel(matched):
                    # E.g.: a*b*c*d -> a*r0
                    new_expr = expr.func(*(other + [replace(matched)]),
                                         evaluate=False)
                    return new_expr, False
                else:
                    # E.g.: a*b*c*d -> a*r0*r1*r2
                    replaced = []
                    unreplaced = []
                    for e in matching:
                        if costmodel(e):
                            replaced.append(replace(e))
                        else:
                            unreplaced.append(e)
                    new_expr = expr.func(*(other + replaced + unreplaced),
                                         evaluate=False)
                    return new_expr, False
            else:
                replaceable, unreplaced = split(matching, costmodel)
                if replaceable:
                    # E.g.: a*b*c*d -> a*r0*r1*r2
                    replaced = [replace(e) for e in replaceable]
                    new_expr = expr.func(*(other + replaced + unreplaced),
                                         evaluate=False)
                    return new_expr, False
                matched = expr.func(*matching, evaluate=False)
                if rule(matched) and costmodel(matched):
                    if len(matching) == len(children):
                        # E.g.: a*b*c*d -> r0
                        return replace(matched), False
                    else:
                        # E.g.: a*b*c*d -> a*r0
                        new_expr = expr.func(*(other + [replace(matched)]),
                                             evaluate=False)
                        return new_expr, False
                elif len(matching) == len(children) and rule(expr):
                    # Go look for larger expressions
                    return matched, True
                else:
                    # E.g.: a*b*c*d; a,b,a*b replaceable but not satisfying the cost
                    # model, hence giving up as c,d,c*d aren't replaceable
                    return expr.func(*(matching + other),
                                     evaluate=False), False

    # Process the provided expressions
    for expr in as_tuple(exprs):
        assert expr.is_Equality
        root = expr.rhs

        while True:
            ret, flag = run(root)
            if flag and costmodel(ret):
                if expr.lhs.function.is_Array:
                    rebuilt.append(expr)
                else:
                    # Have to replace the whole expr.rhs
                    rebuilt.append(
                        expr.func(expr.lhs, replace(root), evaluate=False))
                break
            elif repeat and ret != root:
                root = ret
            else:
                rebuilt.append(expr.func(expr.lhs, ret, evaluate=False))
                break

    # Post-process the output
    found = [Eq(v, k) for k, v in found.items()]

    return found + rebuilt, found