def _schedule_expressions(self, clusters):
    """
    Wrap :class:`Expression` objects, already grouped in :class:`Cluster`
    objects, within nested :class:`Iteration` objects (representing loops),
    according to dimensions and stencils.

    :param clusters: A sequence of :class:`Cluster` objects (it is traversed
                     twice). Each one is read through its ``stencil``,
                     ``trace`` and ``atomics`` attributes.
    :returns: A :class:`List` whose body holds the scheduled Iteration trees
              and the loose Expressions, in the order they were produced.
    """
    # Topologically sort Iterations. The parent of each buffered dimension is
    # placed immediately before it; the list is built in a single pass so that
    # earlier insertions cannot shift the position of later ones.
    ordering = []
    for d in partial_order([c.stencil.dimensions for c in clusters]):
        if d.is_Buffered:
            ordering.append(d.parent)
        ordering.append(d)

    # Build the Iteration/Expression tree
    processed = []
    # Stencil entry -> Iteration currently scheduled for it, outermost first
    schedule = OrderedDict()
    # Dimensions of the previous cluster that cannot be shared with the next
    atomics = ()
    for cluster in clusters:
        # Build the Expression objects to be inserted within an Iteration tree;
        # values that the trace flags as indices are typed as np.int32
        expressions = [
            Expression(v, np.int32 if cluster.trace.is_index(k) else self.dtype)
            for k, v in cluster.trace.items()
        ]

        if cluster.stencil.empty:
            # No Iterations are needed
            processed.extend(expressions)
        else:
            # Reorder based on the globally-established loop ordering
            entries = sorted(cluster.stencil.entries,
                             key=lambda e: ordering.index(e.dim))

            # Can I reuse any of the previously scheduled Iterations ?
            # Walk the two sequences in lockstep and stop at the first
            # mismatch, or at a dimension that must not be fused.
            root = None
            index = 0
            for j0, j1 in zip(entries, schedule):
                if j0 != j1 or j0.dim in atomics:
                    break
                root = schedule[j1]
                index += 1
            needed = entries[index:]

            # Build and insert the required Iterations
            iters = [Iteration([], j.dim, j.dim.size, offsets=j.ofs)
                     for j in needed]
            body, tree = compose_nodes(iters + [expressions], retrieve=True)
            scheduling = OrderedDict(zip(needed, tree))
            if root is None:
                # Nothing could be reused: start a brand new tree
                processed.append(body)
                schedule = scheduling
            else:
                # Append the new sub-tree to the innermost reusable Iteration.
                # Nodes are rebuilt rather than mutated, so both the processed
                # trees and the schedule must be refreshed afterwards.
                nodes = list(root.nodes) + [body]
                mapper = {root: root._rebuild(nodes, **root.args_frozen)}
                transformer = Transformer(mapper)
                processed = list(transformer.visit(processed))
                merged = list(schedule.items())[:index] + list(scheduling.items())
                schedule = OrderedDict(
                    (k, transformer.rebuilt.get(v, v)) for k, v in merged
                )

        # Track dimensions that cannot be fused at next stage
        atomics = cluster.atomics

    return List(body=processed)
def dimension_sort(expr, key=None):
    """
    Topologically sort the :class:`Dimension`s in ``expr``, based on the order
    in which they are encountered when visiting ``expr``.

    :param expr: The :class:`sympy.Eq` from which the :class:`Dimension`s are
                 extracted. They can appear both as array indices or as free
                 symbols.
    :param key: A callable used as key to enforce a final ordering. If not
                provided, time dimensions are moved to the front, with the
                relative order of all other dimensions left untouched.
    :returns: A list of :class:`Dimension`s.
    """
    def _time_first(dim):
        # False sorts before True, so time dimensions come first
        return not dim.is_Time

    # Get the Indexed dimensions, in appearance order: one constraint per
    # Indexed, listing each Dimension found in its indices exactly once
    constraints = []
    for indexed in retrieve_indexed(expr, mode='all'):
        normalized = []
        for index in indexed.indices:
            found = [d for d in index.free_symbols if isinstance(d, Dimension)]
            normalized.extend([d for d in found if d not in normalized])
        constraints.append(normalized)
    ordering = partial_order(constraints)

    # Add any leftover free dimensions (not an Indexed's index)
    dimensions = [i for i in expr.free_symbols if isinstance(i, Dimension)]
    dimensions = filter_sorted(dimensions, key=attrgetter('name'))  # for determinism
    ordering.extend([i for i in dimensions if i not in ordering])

    # Add parent dimensions, each one right before its derived dimension.
    # NOTE(review): the insertion is unconditional, so a parent that is
    # already part of the ordering ends up listed twice -- confirm callers
    # tolerate that.
    derived = [i for i in ordering if i.is_Derived]
    for i in derived:
        ordering.insert(ordering.index(i), i.parent)

    # sorted() is stable, hence the ordering established above is retained
    # among elements sharing the same key
    return sorted(ordering, key=_time_first if key is None else key)
def dimension_sort(expr, key=None):
    """
    Topologically sort the :class:`Dimension`s in ``expr``, based on the order
    in which they are encountered when visiting ``expr``.

    :param expr: The :class:`sympy.Eq` from which the :class:`Dimension`s are
                 extracted. They can appear both as array indices or as free
                 symbols.
    :param key: A callable used as key to enforce a final ordering. If not
                provided, time dimensions are moved to the front, with the
                relative order of all other dimensions left untouched.
    :returns: A list of :class:`Dimension`s.
    """
    def _time_first(dim):
        # False sorts before True, so time dimensions come first
        return not dim.is_Time

    # Get all Indexed dimensions, in the same order as they appear in /expr/
    constraints = []
    for indexed in retrieve_indexed(expr, mode='all'):
        constraint = []
        # Pair each access index with the index declared by the underlying
        # function at the same position
        for ai, fi in zip(indexed.indices, indexed.base.function.indices):
            if ai.is_Number:
                # A purely numeric access carries no Dimension: fall back to
                # the function's own index for this position
                constraint.append(fi)
            else:
                constraint.extend([d for d in ai.free_symbols
                                   if isinstance(d, Dimension) and d not in constraint])
        constraints.append(tuple(constraint))
    ordering = partial_order(constraints)

    # Add any leftover free dimensions (not an Indexed's index)
    dimensions = [i for i in expr.free_symbols if isinstance(i, Dimension)]
    dimensions = filter_sorted(dimensions, key=attrgetter('name'))  # for determinism
    ordering.extend([i for i in dimensions if i not in ordering])

    # Add parent dimensions, each one right before its derived dimension.
    # NOTE(review): the insertion is unconditional, so a parent that is
    # already part of the ordering ends up listed twice -- confirm callers
    # tolerate that.
    derived = [i for i in ordering if i.is_Derived]
    for i in derived:
        ordering.insert(ordering.index(i), i.parent)

    # sorted() is stable, hence the ordering established above is retained
    # among elements sharing the same key
    return sorted(ordering, key=_time_first if key is None else key)
def extract(cls, expr):
    """
    Compute the stencil of ``expr``.

    :param expr: The equation to analyze; it must satisfy ``is_Equality``.
    :returns: A :class:`Stencil` associating each dimension with the set of
              offsets at which it is accessed in ``expr``.
    """
    assert expr.is_Equality

    # Collect all indexed objects appearing in /expr/, including those
    # nested within the indices of other indexed objects
    terminals = retrieve_terminals(expr, mode='all')
    accesses = [t for t in terminals if t.is_Indexed]
    nested = flatten([retrieve_indexed(index) for index in access.indices]
                     for access in accesses)
    accesses.extend(nested)

    # Enforce deterministic dimension ordering: map each tuple of dimensions,
    # in order of appearance, to the terminal it was extracted from
    owners = OrderedDict()
    for terminal in terminals:
        if isinstance(terminal, Dimension):
            owners[(terminal,)] = terminal
            continue
        if not terminal.is_Indexed:
            continue
        group = []
        for index in terminal.indices:
            for symbol in index.free_symbols:
                if isinstance(symbol, Dimension) and symbol not in group:
                    group.append(symbol)
        owners[tuple(group)] = terminal

    def low_priority(group):
        # Tuples stemming from a bare Dimension or from a TimeFunction get
        # the lower key, so that they are sorted first
        owner = owners[group]
        if isinstance(owner, Dimension):
            return False
        return not owner.base.function.is_TimeFunction

    # Higher priority goes to TimeFunction objects; time always goes first
    groups = sorted(owners, key=low_priority)
    stencil = Stencil([(dim, set()) for dim in partial_order(groups)])

    # Determine the points accessed along each dimension
    for access in accesses:
        for index in access.indices:
            if isinstance(index, Dimension):
                # A bare dimension is an access at offset 0
                stencil[index].update([0])
            args = index.args
            dim_args = [arg for arg in args if isinstance(arg, Dimension)]
            offsets = [0] + [arg for arg in args
                             if not isinstance(arg, Dimension) and arg.is_integer]
            if dim_args:
                # With several dimensions among the arguments, the offsets are
                # attributed to the last one
                stencil[dim_args[-1]].update(offsets)

    return stencil
def extract(cls, expr):
    """
    Compute the stencil of ``expr``.

    :param expr: The equation to analyze; it must satisfy ``is_Equality``.
    :returns: A :class:`Stencil` associating each dimension with the set of
              offsets at which it is accessed in ``expr``.
    """
    assert expr.is_Equality

    # Collect all indexed objects appearing in /expr/: left-hand side first,
    # then right-hand side, then whatever is nested within their indices
    accesses = list(retrieve_indexed(expr.lhs)) + list(retrieve_indexed(expr.rhs))
    nested = flatten([retrieve_indexed(index) for index in access.indices]
                     for access in accesses)
    accesses.extend(nested)

    # Enforce deterministic ordering: one tuple of dimensions per indexed
    # object, each dimension listed once, in order of appearance
    groups = []
    for access in accesses:
        group = []
        for index in access.indices:
            group += [symbol for symbol in index.free_symbols
                      if isinstance(symbol, Dimension) and symbol not in group]
        groups.append(tuple(group))
    stencil = Stencil([(dim, set()) for dim in partial_order(groups)])

    # Determine the points accessed along each dimension
    for access in accesses:
        for index in access.indices:
            if isinstance(index, Dimension):
                # A bare dimension is an access at offset 0
                stencil[index].update([0])
            dim, offsets = None, [0]
            for arg in index.args:
                if isinstance(arg, Dimension):
                    dim = arg
                elif arg.is_integer:
                    offsets.append(arg)
            if dim is None:
                # No dimension among the arguments: nothing to record
                continue
            stencil[dim].update(offsets)

    return stencil
def test_partial_order(elements, expected):
    """Check that ``partial_order(elements)`` yields exactly ``expected``."""
    assert partial_order(elements) == expected