示例#1
0
    def _schedule_expressions(self, clusters):
        """Wrap :class:`Expression` objects, already grouped in :class:`Cluster`
        objects, within nested :class:`Iteration` objects (representing loops),
        according to dimensions and stencils.

        :param clusters: A sequence of :class:`Cluster` objects. It is iterated
                         twice; each item must expose ``stencil``, ``trace``
                         and ``atomics``.
        :returns: A :class:`List` whose body is the sequence of scheduled nodes
                  (Iteration trees, plus Expressions that need no loop).
        """

        # Topologically sort Iterations
        ordering = partial_order([i.stencil.dimensions for i in clusters])
        # Put the parent of each buffered dimension right before it. ``pos``
        # indexes the snapshot taken before any insertion, hence it must be
        # shifted by the number of parents inserted so far; otherwise, from
        # the second insertion on, parents would land too early in the list.
        inserted = 0
        for pos, d in enumerate(list(ordering)):
            if d.is_Buffered:
                ordering.insert(pos + inserted, d.parent)
                inserted += 1

        # Build the Iteration/Expression tree
        processed = []
        schedule = OrderedDict()
        atomics = ()
        for i in clusters:
            # Build the Expression objects to be inserted within an Iteration tree
            expressions = [Expression(v, np.int32 if i.trace.is_index(k) else self.dtype)
                           for k, v in i.trace.items()]

            if not i.stencil.empty:
                root = None
                entries = i.stencil.entries

                # Reorder based on the globally-established loop ordering
                entries = sorted(entries, key=lambda e: ordering.index(e.dim))

                # Can I reuse any of the previously scheduled Iterations ?
                # Walk the stencil entries and the current schedule in lockstep;
                # stop at the first mismatch or at a dimension that the previous
                # cluster marked as atomic. ``root`` ends up being the innermost
                # reusable Iteration, ``index`` the number of reused levels.
                index = 0
                for j0, j1 in zip(entries, list(schedule)):
                    if j0 != j1 or j0.dim in atomics:
                        break
                    root = schedule[j1]
                    index += 1
                needed = entries[index:]

                # Build and insert the required Iterations
                iters = [Iteration([], j.dim, j.dim.size, offsets=j.ofs) for j in needed]
                body, tree = compose_nodes(iters + [expressions], retrieve=True)
                scheduling = OrderedDict(zip(needed, tree))
                if root is None:
                    # Nothing to reuse: start a brand new Iteration tree
                    processed.append(body)
                    schedule = scheduling
                else:
                    # Append the new subtree to the body of the reused Iteration
                    nodes = list(root.nodes) + [body]
                    mapper = {root: root._rebuild(nodes, **root.args_frozen)}
                    transformer = Transformer(mapper)
                    processed = list(transformer.visit(processed))
                    schedule = OrderedDict(list(schedule.items())[:index] +
                                           list(scheduling.items()))
                    # The visit may have rebuilt some scheduled Iterations;
                    # make the schedule point to the up-to-date objects
                    for k, v in list(schedule.items()):
                        schedule[k] = transformer.rebuilt.get(v, v)
            else:
                # No Iterations are needed
                processed.extend(expressions)

            # Track dimensions that cannot be fused at next stage
            atomics = i.atomics

        return List(body=processed)
示例#2
0
def dimension_sort(expr, key=None):
    """
    Topologically sort the :class:`Dimension`s in ``expr``, based on the order
    in which they are encountered when visiting ``expr``.

    :param expr: The :class:`sympy.Eq` from which the :class:`Dimension`s are
                 extracted. They can appear both as array indices or as free
                 symbols.
    :param key: A callable used as key to enforce a final ordering. If not
                provided, the dimensions for which ``is_Time`` is true are
                moved to the front, with the relative order otherwise
                preserved.
    :returns: The sorted list of :class:`Dimension`s.
    """
    # Get the Indexed dimensions, in appearance order
    constraints = [
        tuple(i.indices) for i in retrieve_indexed(expr, mode='all')
    ]
    # Replace each index expression with the Dimensions it contains, dropping
    # repetitions within the same access
    for i, constraint in enumerate(list(constraints)):
        normalized = []
        for j in constraint:
            found = [d for d in j.free_symbols if isinstance(d, Dimension)]
            normalized.extend([d for d in found if d not in normalized])
        constraints[i] = normalized
    ordering = partial_order(constraints)

    # Add any leftover free dimensions (not an index of an Indexed)
    dimensions = [i for i in expr.free_symbols if isinstance(i, Dimension)]
    dimensions = filter_sorted(dimensions,
                               key=attrgetter('name'))  # for determinism
    ordering.extend([i for i in dimensions if i not in ordering])

    # Add parent dimensions, each one right before its derived dimension
    # NOTE(review): the parent is inserted even if already in ``ordering``;
    # confirm that duplicates cannot occur in practice
    derived = [i for i in ordering if i.is_Derived]
    for i in derived:
        ordering.insert(ordering.index(i), i.parent)

    def time_first(dim):
        # ``False`` sorts before ``True``, and ``sorted`` is stable
        return not dim.is_Time

    # Honor the user-provided key, if any; it used to be silently ignored
    return sorted(ordering, key=time_first if key is None else key)
示例#3
0
def dimension_sort(expr, key=None):
    """
    Topologically sort the :class:`Dimension`s in ``expr``, based on the order
    in which they are encountered when visiting ``expr``.

    :param expr: The :class:`sympy.Eq` from which the :class:`Dimension`s are
                 extracted. They can appear both as array indices or as free
                 symbols.
    :param key: A callable used as key to enforce a final ordering. If not
                provided, the dimensions for which ``is_Time`` is true are
                moved to the front, with the relative order otherwise
                preserved.
    :returns: The sorted list of :class:`Dimension`s.
    """
    # Get all Indexed dimensions, in the same order as they appear in /expr/
    constraints = []
    for i in retrieve_indexed(expr, mode='all'):
        constraint = []
        # Pair each actual index with the index declared by the function
        for ai, fi in zip(i.indices, i.base.function.indices):
            if ai.is_Number:
                # A numeric index carries no Dimension; use the declared one
                constraint.append(fi)
            else:
                constraint.extend([d for d in ai.free_symbols
                                   if isinstance(d, Dimension) and d not in constraint])
        constraints.append(tuple(constraint))
    ordering = partial_order(constraints)

    # Add any leftover free dimensions (not an index of an Indexed)
    dimensions = [i for i in expr.free_symbols if isinstance(i, Dimension)]
    dimensions = filter_sorted(dimensions, key=attrgetter('name'))  # for determinism
    ordering.extend([i for i in dimensions if i not in ordering])

    # Add parent dimensions, each one right before its derived dimension
    # NOTE(review): the parent is inserted even if already in ``ordering``;
    # confirm that duplicates cannot occur in practice
    derived = [i for i in ordering if i.is_Derived]
    for i in derived:
        ordering.insert(ordering.index(i), i.parent)

    def time_first(dim):
        # ``False`` sorts before ``True``, and ``sorted`` is stable
        return not dim.is_Time

    # Honor the user-provided key, if any; it used to be silently ignored
    return sorted(ordering, key=time_first if key is None else key)
示例#4
0
    def extract(cls, expr):
        """
        Build the stencil of ``expr``: for each :class:`Dimension`, the set of
        offsets collected from the indices of the indexed objects in ``expr``.

        :param expr: The equation to be analyzed; ``expr.is_Equality`` must
                     hold.
        :returns: The populated :class:`Stencil`.
        """
        assert expr.is_Equality

        # Gather the indexed objects in /expr/, plus whatever indexed objects
        # are retrieved from their indices
        terminals = retrieve_terminals(expr, mode='all')
        indexeds = [t for t in terminals if t.is_Indexed]
        indexeds.extend(flatten([retrieve_indexed(idx) for idx in obj.indices]
                                for obj in indexeds))

        # Map each tuple of dimensions to the terminal it was extracted from;
        # the insertion order is what makes the dimension ordering deterministic
        owners = OrderedDict()
        for t in terminals:
            if isinstance(t, Dimension):
                owners[(t, )] = t
            elif t.is_Indexed:
                found = []
                for idx in t.indices:
                    symbols = [s for s in idx.free_symbols if isinstance(s, Dimension)]
                    found.extend([s for s in symbols if s not in found])
                owners[tuple(found)] = t

        def low_priority(dims):
            # Bare dimensions and TimeFunction accesses go first (time always
            # goes first); ``False`` sorts before ``True``
            owner = owners[dims]
            if isinstance(owner, Dimension):
                return False
            return not owner.base.function.is_TimeFunction

        prioritized = sorted(owners, key=low_priority)
        stencil = Stencil([(d, set()) for d in partial_order(prioritized)])

        # Record the points accessed along each dimension
        for obj in indexeds:
            for idx in obj.indices:
                if isinstance(idx, Dimension):
                    stencil[idx].update([0])
                # Look for a dimension and integer terms among the arguments
                # of the index expression
                dim = None
                offsets = [0]
                for arg in idx.args:
                    if isinstance(arg, Dimension):
                        dim = arg
                    elif arg.is_integer:
                        offsets.append(arg)
                if dim is not None:
                    stencil[dim].update(offsets)

        return stencil
示例#5
0
    def extract(cls, expr):
        """
        Build the stencil of ``expr``: for each :class:`Dimension`, the set of
        offsets collected from the indices of the indexed objects in ``expr``.

        :param expr: The equation to be analyzed; ``expr.is_Equality`` must
                     hold.
        :returns: The populated :class:`Stencil`.
        """
        assert expr.is_Equality

        # Gather the indexed objects, left-hand side first, then right-hand
        # side, then whatever is retrieved from the indices of those
        accesses = list(retrieve_indexed(expr.lhs))
        accesses.extend(list(retrieve_indexed(expr.rhs)))
        accesses.extend(flatten([retrieve_indexed(sub) for sub in obj.indices]
                                for obj in accesses))

        # One tuple of dimensions per access, in a deterministic order
        constraints = []
        for obj in accesses:
            seen = []
            for sub in obj.indices:
                symbols = [s for s in sub.free_symbols if isinstance(s, Dimension)]
                seen.extend([s for s in symbols if s not in seen])
            constraints.append(tuple(seen))
        stencil = Stencil([(d, set()) for d in partial_order(constraints)])

        # Record the points accessed along each dimension
        for obj in accesses:
            for sub in obj.indices:
                if isinstance(sub, Dimension):
                    stencil[sub].update([0])
                # Look for a dimension and integer terms among the arguments
                # of the index expression
                dim = None
                offsets = [0]
                for arg in sub.args:
                    if isinstance(arg, Dimension):
                        dim = arg
                    elif arg.is_integer:
                        offsets.append(arg)
                if dim is not None:
                    stencil[dim].update(offsets)

        return stencil
示例#6
0
def test_partial_order(elements, expected):
    """Check that ``partial_order(elements)`` yields ``expected``."""
    assert partial_order(elements) == expected