def clusterize(exprs):
    """Group a sequence of :class:`ir.Eq`s into one or more :class:`Cluster`s."""
    # Dependency graph: map each expression to its "trace" -- the expressions
    # it depends on (in program order), itself, then the expressions depending
    # on it -- together with its iteration space
    graph = OrderedDict()
    for n, expr in enumerate(exprs):
        preds = [other for other in exprs[:n] if Scope([other, expr]).has_dep]
        succs = [other for other in exprs[n + 1:] if Scope([expr, other]).has_dep]
        graph[expr] = Bunch(trace=preds + [expr] + succs, ispace=expr.ispace)

    # Fixed-point iteration: coerce each iteration space against those of the
    # expressions in its trace, propagating any change to the neighbours
    worklist = list(graph)
    while worklist:
        node = worklist.pop(0)
        v = graph[node]
        coerced = v.ispace.intersection(*[graph[e].ispace for e in v.trace])
        if coerced != v.ispace:
            # Something changed -- the neighbours must be revisited
            v.ispace = coerced
            worklist.extend([e for e in v.trace if e not in worklist])

    # One PartialCluster per tensor expression, carrying along the scalar
    # expressions that precede it in its trace
    clusters = ClusterGroup()
    for expr, v in graph.items():
        if expr.is_Tensor:
            scalars = [e for e in v.trace[:v.trace.index(expr)] if e.is_Scalar]
            clusters.append(PartialCluster(scalars + [expr], v.ispace))

    # Fuse PartialClusters together where possible
    return groupby(clusters).finalize()
def linearize(graph, **kwargs):
    """
    Flatten n-dimensional Indexeds into 1-dimensional Indexeds with a suitable
    index access function, e.g. `a[i, j]` -> `a[i*n + j]`. The row-major
    layout of the underlying Function objects is honored.
    """
    # Memoize generated statements/callbacks so that duplicated code is not
    # emitted for the same Function
    memo = defaultdict(lambda: Bunch(stmts0=[], stmts1=[], cbk=None))

    linearization(graph, cache=memo, **kwargs)
def stree_schedule(clusters):
    """
    Arrange an iterable of Clusters into a ScheduleTree.

    Consecutive Clusters sharing a prefix of iteration intervals (with
    compatible guards and syncs) are scheduled under the same sub-tree;
    `mapper` tracks, for each reusable interval, the top and bottom nodes
    of that interval's sub-tree.
    """
    stree = ScheduleTree()
    prev = None
    # For each (reusable) iteration interval, remember the topmost and
    # bottommost tree nodes built for it
    mapper = DefaultOrderedDict(lambda: Bunch(top=None, bottom=None))

    def attach_metadata(cluster, d, tip):
        # Wrap `tip` with the Conditional and/or Sync nodes the cluster
        # carries for dimension `d` (d=None means "outside any Iteration")
        if d in cluster.guards:
            tip = NodeConditional(cluster.guards[d], tip)
        if d in cluster.syncs:
            tip = NodeSync(cluster.syncs[d], tip)
        return tip

    for c in clusters:
        # Add in any Conditionals and Syncs outside of the outermost Iteration
        tip = attach_metadata(c, None, stree)

        # Sub-trees are only reusable if no outer metadata was attached
        if tip is stree:
            pointers = list(mapper)
        else:
            pointers = []

        # Walk the shared prefix of iteration intervals, reusing sub-trees
        index = 0
        for it0, it1 in zip(c.itintervals, pointers):
            if it0 != it1:
                break
            index += 1

            d = it0.dim

            # The reused sub-trees might acquire new sub-iterators as well as
            # new properties
            mapper[it0].top.ispace = IterationSpace.union(mapper[it0].top.ispace,
                                                          c.ispace.project([d]))
            mapper[it0].top.properties = normalize_properties(mapper[it0].top.properties,
                                                              c.properties[it0.dim])

            # Different guards or syncops cannot be further nested
            # NOTE(review): `prev` is None on the first cluster; this is only
            # reached when `pointers` is non-empty, i.e. after at least one
            # iteration has populated `mapper` and set `prev` -- verify
            if c.guards.get(d) != prev.guards.get(d) or \
               c.syncs.get(d) != prev.syncs.get(d):
                tip = mapper[it0].top
                tip = attach_metadata(c, d, tip)
                mapper[it0].bottom = tip
                break
            else:
                tip = mapper[it0].bottom

        # Nested sub-trees, instead, will not be used anymore
        for it in pointers[index:]:
            mapper.pop(it)

        # Add in Iterations, Conditionals, and Syncs
        for it in c.itintervals[index:]:
            d = it.dim
            # NOTE(review): `c.properties.get(d)` may yield None here;
            # presumably NodeIteration tolerates that -- confirm
            tip = NodeIteration(c.ispace.project([d]), tip, c.properties.get(d))
            mapper[it].top = tip
            tip = attach_metadata(c, d, tip)
            mapper[it].bottom = tip

        # Add in Expressions
        NodeExprs(c.exprs, c.ispace, c.dspace, c.ops, c.traffic, tip)

        # Prepare for next iteration
        prev = c

    return stree
    @cached_property
    def accesses(self):
        # All reads plus the write (if any) of every tracked expression,
        # flattened into a single tuple
        return tuple(flatten(as_tuple(i.reads) + as_tuple(i.write)
                             for i in self.mapper.values()))

    @cached_property
    def is_read(self):
        # True if at least one tracked expression reads this Function
        return any(av.reads for av in self.mapper.values())

    @cached_property
    def lastwrite(self):
        # The last expression (in insertion order) performing a write, or None
        # NOTE(review): relies on `self.mapper.items()` being reversible,
        # i.e. an order-preserving mapping -- confirm
        for e, av in reversed(self.mapper.items()):
            if av.write is not None:
                return e
        return None


# Factory for the per-expression access record: the reads performed and the
# (single) write, if any
AccessTuple = lambda: Bunch(reads=[], write=None)


class AccessMapper(OrderedDict):

    """
    Map each Function accessed by `expressions` to an AccessValue describing,
    expression by expression, the reads and the write of that Function.
    """

    def __init__(self, expressions):
        # Function -> (expression -> Bunch(reads, write))
        mapper = DefaultOrderedDict(lambda: DefaultOrderedDict(AccessTuple))
        for e in expressions:
            # All Function-carrying objects on the right-hand side are reads
            for i in retrieve_function_carriers(e.rhs):
                mapper[i.function][e].reads.append(i)
            # The left-hand side is the write
            mapper[e.lhs.function][e].write = e.lhs

        super().__init__([(f, AccessValue(f, mapper[f])) for f in mapper])
def stree_schedule(clusters):
    """
    Arrange an iterable of Clusters into a ScheduleTree.

    Consecutive Clusters sharing a prefix of iteration intervals, with
    identical guards and syncs along that prefix, are scheduled under the
    same sub-tree; `mapper` tracks, for each reusable interval, the top and
    bottom nodes of that interval's sub-tree.
    """
    stree = ScheduleTree()
    # Sentinel "previous cluster" so the first iteration needs no special case
    prev = Cluster(None)
    # For each (reusable) iteration interval, remember the topmost and
    # bottommost tree nodes built for it
    mapper = DefaultOrderedDict(lambda: Bunch(top=None, bottom=None))

    def reuse_metadata(c0, c1, d):
        # True if the two clusters carry the same guards and syncs along `d`,
        # hence their sub-trees may be fused
        return (c0.guards.get(d) == c1.guards.get(d) and
                c0.syncs.get(d) == c1.syncs.get(d))

    def attach_metadata(cluster, d, tip):
        # Wrap `tip` with the Conditional and/or Sync nodes the cluster
        # carries for dimension `d` (d=None means "outside any Iteration")
        if d in cluster.guards:
            tip = NodeConditional(cluster.guards[d], tip)
        if d in cluster.syncs:
            tip = NodeSync(cluster.syncs[d], tip)
        return tip

    for c in clusters:
        index = 0

        # Reuse or add in any Conditionals and Syncs outside of the outermost Iteration
        if not reuse_metadata(c, prev, None):
            tip = attach_metadata(c, None, stree)
            maybe_reusable = []
        else:
            try:
                tip = mapper[prev.itintervals[index]].top.parent
            except IndexError:
                # `prev` had no Iterations at all
                tip = stree
            maybe_reusable = prev.itintervals

        # Walk the shared prefix of iteration intervals, reusing sub-trees
        for it0, it1 in zip(c.itintervals, maybe_reusable):
            if it0 != it1:
                break
            index += 1

            d = it0.dim

            # The reused sub-trees might acquire new sub-iterators as well as
            # new properties
            mapper[it0].top.ispace = IterationSpace.union(mapper[it0].top.ispace,
                                                          c.ispace.project([d]))
            mapper[it0].top.properties = normalize_properties(mapper[it0].top.properties,
                                                              c.properties[it0.dim])

            # Different guards or SyncOps cannot further be nested
            if not reuse_metadata(c, prev, d):
                tip = mapper[it0].top
                tip = attach_metadata(c, d, tip)
                mapper[it0].bottom = tip
                break
            else:
                tip = mapper[it0].bottom

        # Nested sub-trees, instead, will not be used anymore
        for it in prev.itintervals[index:]:
            mapper.pop(it)

        # Add in Iterations, Conditionals, and Syncs
        for it in c.itintervals[index:]:
            d = it.dim
            tip = NodeIteration(c.ispace.project([d]), tip, c.properties.get(d, ()))
            mapper[it].top = tip
            tip = attach_metadata(c, d, tip)
            mapper[it].bottom = tip

        # Add in Expressions, grouped by their expression-level conditionals
        exprs = []
        for conditionals, g in groupby(c.exprs, key=lambda e: e.conditionals):
            exprs = list(g)

            # Indirect ConditionalDimensions induce expression-level guards
            if conditionals:
                # Build a single unevaluated conjunction of all the conditions
                guard = And(*conditionals.values(), evaluate=False)
                parent = NodeConditional(guard, tip)
            else:
                parent = tip

            NodeExprs(exprs, c.ispace, c.dspace, c.ops, c.traffic, parent)

        # Prepare for next iteration
        prev = c

    return stree