def run(expr):
        """
        Recursively rebuild ``expr``, factoring out numeric coefficients.

        Return a 2-tuple (rebuilt expression, factorization candidates),
        where the candidates are sub-expressions that callers may later
        use for further factorization.

        NOTE(review): depends on the free names ``q_indexed``,
        ``collect_const``, ``collect``, ``flatten`` and ``aggressive``
        from an enclosing scope.
        """

        # Numbers and symbols are leaves and are candidates themselves
        if expr.is_Number or expr.is_Symbol:
            return expr, [expr]
        elif q_indexed(expr) or expr.is_Atom:
            # Indexed objects and any other atoms are leaves, not candidates
            return expr, []
        elif expr.is_Add:
            rebuilt, candidates = zip(*[run(arg) for arg in expr.args])

            # Partition the addends into those carrying a numeric factor ...
            w_numbers = [
                i for i in rebuilt if any(j.is_Number for j in i.args)
            ]
            # ... and those without one (membership uses sympy equality)
            wo_numbers = [i for i in rebuilt if i not in w_numbers]

            # Factor out common numeric coefficients, e.g. 3*a + 3*b -> 3*(a + b)
            w_numbers = collect_const(expr.func(*w_numbers))
            wo_numbers = expr.func(*wo_numbers)

            if aggressive is True and wo_numbers:
                # Aggressive mode: also collect each candidate sub-expression
                # within the coefficient-free part
                for i in flatten(candidates):
                    wo_numbers = collect(wo_numbers, i)

            # Note: no candidates are propagated upward from an Add
            rebuilt = expr.func(w_numbers, wo_numbers)
            return rebuilt, []
        elif expr.is_Mul:
            rebuilt, candidates = zip(*[run(arg) for arg in expr.args])
            # Factor numeric constants out of the product
            rebuilt = collect_const(expr.func(*rebuilt))
            return rebuilt, flatten(candidates)
        else:
            # Any other operation: rebuild the children, propagate candidates
            rebuilt, candidates = zip(*[run(arg) for arg in expr.args])
            return expr.func(*rebuilt), flatten(candidates)
 def run(expr):
     """
     Recursively rewrite ``expr``, replacing rule-matching sub-expressions.

     Return a 2-tuple (rebuilt expression, flag), where the flag tells the
     caller whether the rebuilt expression itself satisfies ``rule`` and
     could thus take part in a larger match.

     NOTE(review): depends on the free names ``q_indexed``, ``rule``,
     ``costmodel`` and ``replace`` from an enclosing scope.
     """
     if expr.is_Atom or q_indexed(expr):
         return expr, rule(expr)
     elif expr.is_Pow:
         # Only the base is rewritten; the exponent is carried through as-is
         base, flag = run(expr.base)
         return expr.func(base, expr.exp, evaluate=False), flag
     else:
         children = [run(a) for a in expr.args]
         # Children flagged by /rule/ ...
         matching = [a for a, flag in children if flag]
         # ... and the remaining ones (membership uses sympy equality)
         other = [a for a, _ in children if a not in matching]
         if matching:
             matched = expr.func(*matching, evaluate=False)
             if len(matching) == len(children) and rule(expr):
                 # Go look for longer expressions first
                 return matched, True
             elif rule(matched) and costmodel(matched):
                 # Replace what I can replace, then give up
                 rebuilt = expr.func(*(other + [replace(matched)]),
                                     evaluate=False)
                 return rebuilt, False
             else:
                 # Replace flagged children, then give up
                 replaced = [replace(e) for e in matching if costmodel(e)]
                 unreplaced = [e for e in matching if not costmodel(e)]
                 rebuilt = expr.func(*(other + replaced + unreplaced),
                                     evaluate=False)
                 return rebuilt, False
         # No child matched: rebuild unevaluated and report no match
         return expr.func(*other, evaluate=False), False
# Example 3
    def extract(cls, expr):
        """
        Compute the stencil of ``expr``.

        ``expr`` must be a sympy Equality. The result maps each Dimension
        accessed by ``expr`` to the set of integer offsets used along it.
        """
        assert expr.is_Equality

        # Collect all indexed objects appearing in /expr/, including those
        # nested within the indices of other indexed objects
        terminals = retrieve_terminals(expr, mode='all')
        indexeds = [i for i in terminals if q_indexed(i)]
        indexeds += flatten([retrieve_indexed(i) for i in e.indices]
                            for e in indexeds)

        # Enforce deterministic dimension ordering...
        dims = OrderedDict()
        for e in terminals:
            if isinstance(e, Dimension):
                dims[(e, )] = e
            elif q_indexed(e):
                # The Dimensions appearing in the indices of /e/, in order of
                # first appearance and without duplicates
                d = []
                for a in e.indices:
                    found = [
                        i for i in a.free_symbols if isinstance(i, Dimension)
                    ]
                    d.extend([i for i in found if i not in d])
                dims[tuple(d)] = e
        # ... giving higher priority to TimeData objects; time always go first
        dims = sorted(list(dims),
                      key=lambda i: not (isinstance(dims[i], Dimension) or
                                         dims[i].base.function.is_TimeData))
        stencil = Stencil([(i, set()) for i in partial_order(dims)])

        # Determine the points accessed along each dimension
        for e in indexeds:
            for a in e.indices:
                if isinstance(a, Dimension):
                    # A "bare" dimension, e.g. A[t]: offset 0
                    stencil[a].update([0])
                # Otherwise split the index into its dimension and integer
                # offsets, e.g. for A[t + 1]: d=t, off=[0, 1]
                d = None
                off = [0]
                for i in a.args:
                    if isinstance(i, Dimension):
                        d = i
                    elif i.is_integer:
                        off += [i]
                if d is not None:
                    stencil[d].update(off)

        return stencil
def freeze_expression(expr):
    """
    Rebuild ``expr`` so that every :class:`sympy.Add` and :class:`sympy.Mul`
    node is replaced by the corresponding unevaluated :class:`devito.Add`
    and :class:`devito.Mul` counterpart.
    """
    # Leaves are returned untouched
    if expr.is_Atom or q_indexed(expr):
        return expr
    # Recur on the children once, then dispatch on the node type
    frozen = [freeze_expression(a) for a in expr.args]
    if expr.is_Add:
        return Add(*frozen, evaluate=False)
    if expr.is_Mul:
        return Mul(*frozen, evaluate=False)
    if expr.is_Equality:
        return Eq(*frozen, evaluate=False)
    return expr.func(*frozen)
# Example 5
def temporaries_graph(temporaries):
    """
    Build a dependency graph out of a list of :class:`sympy.Eq`.
    """

    # Wrap each equation into a Temporary; left-hand sides must be unique
    processed = [Temporary(*e.args) for e in temporaries]
    lhss = [e.lhs for e in processed]
    if len(set(lhss)) != len(lhss):
        raise DSEException("Found redundant node in the TemporariesGraph.")
    graph = TemporariesGraph(zip(lhss, processed))

    # Map each symbol to the graph nodes it identifies
    mapper = OrderedDict()
    for node in lhss:
        mapper.setdefault(as_symbol(node), []).append(node)

    # Populate reads/readby edges
    for target, temp in graph.items():
        # Scalars appearing on the right-hand side
        symbols = terminals(temp.rhs)

        # Tensors: also descend into the indices of indexed objects
        # (indirections such as A[B[i]] are not inspected)
        for item in list(symbols):
            if q_indexed(item):
                for index in item.indices:
                    symbols |= terminals(index)

        # Translate the collected symbols into actual graph nodes
        reads = set(flatten([mapper.get(as_symbol(i), []) for i in symbols]))

        # Record both directions of each dependency
        temp.reads.update(reads)
        for node in temp.reads:
            graph[node].readby.add(target)

    return graph
# Example 6
def collect(exprs):
    """
    Determine groups of aliasing expressions in ``exprs``.

    An expression A aliases an expression B if both A and B apply the same
    operations to the same input operands, with the possibility for indexed objects
    to index into locations at a fixed constant offset in each dimension.

    For example: ::

        exprs = (a[i+1] + b[i+1], a[i+1] + b[j+1], a[i] + c[i],
                 a[i+2] - b[i+2], a[i+2] + b[i], a[i-1] + b[i-1])

    The following expressions in ``exprs`` alias to ``a[i] + b[i]``: ::

        a[i+1] + b[i+1] : same operands and operations, distance along i = 1
        a[i-1] + b[i-1] : same operands and operations, distance along i = -1

    Whereas the following do not: ::

        a[i+1] + b[j+1] : because at least one index differs
        a[i] + c[i] : because at least one of the operands differs
        a[i+2] - b[i+2] : because at least one operation differs
        a[i+2] + b[i] : because distance along ``i`` differ (+2 and +0)

    Return a 2-tuple: a mapper from each candidate right-hand side to its
    aliasing group, and a mapper from each alias expression to an ``Alias``.
    """
    ExprData = namedtuple('ExprData', 'dimensions offsets')

    # Discard expressions:
    # - that surely won't alias to anything
    # - that are non-scalar
    candidates = OrderedDict()
    for expr in exprs:
        # NOTE(review): presumably this filters out non-scalar (indexed)
        # equations -- confirm against the callers
        if q_indexed(expr):
            continue
        indexeds = retrieve_indexed(expr.rhs, mode='all')
        # Only consider expressions with no indirect accesses (e.g. A[B[i]])
        if indexeds and not any(q_indirect(i) for i in indexeds):
            handle = calculate_offsets(indexeds)
            if handle:
                candidates[expr.rhs] = ExprData(*handle)

    aliases = OrderedDict()
    mapper = OrderedDict()
    unseen = list(candidates)
    while unseen:
        # Find aliasing expressions: same structure (/compare/) and offsets
        # that are a rigid translation of each other (/is_translated/)
        handle = unseen.pop(0)
        group = [handle]
        for e in list(unseen):
            if compare(handle, e) and\
                    is_translated(candidates[handle].offsets, candidates[e].offsets):
                group.append(e)
                unseen.remove(e)
        mapper.update([(i, group) for i in group])

        # Try creating a basis for the aliasing expressions' offsets
        offsets = [tuple(candidates[e].offsets) for e in group]
        try:
            COM, distances = calculate_COM(offsets)
        except DSEException:
            # Ignore these potential aliases and move on
            continue

        alias = create_alias(handle, COM)
        # In circumstances in which an expression has repeated coefficients, e.g.
        # ... + 0.025*a[...] + 0.025*b[...],
        # We may have found a common basis (i.e., same COM, same alias) at this point
        v = aliases.setdefault(alias, Alias(alias, candidates[handle].dimensions))
        v.extend(group, distances)

    # Heuristically attempt to relax the aliases offsets
    # to maximize the likelihood of loop fusion
    grouped = OrderedDict()
    for i in aliases.values():
        grouped.setdefault(i.dimensions, []).append(i)
    for dimensions, group in grouped.items():
        # An alias may only be relaxed if its anti-stencil is contained in
        # the group-wide ideal anti-stencil
        ideal_anti_stencil = Stencil.union(*[i.anti_stencil for i in group])
        for i in group:
            if i.anti_stencil.subtract(ideal_anti_stencil).empty:
                aliases[i.alias] = i.relax(ideal_anti_stencil)

    return mapper, aliases