def stree_schedule(clusters):
    """
    Arrange an iterable of Clusters into a ScheduleTree.

    The Clusters are visited in order. For each Cluster, the leading entries
    of its ``itintervals`` that match the sub-trees still tracked from the
    previously visited Clusters are reused, while a new NodeIteration is
    created for each of the remaining ones. Guards and syncs are turned into
    NodeConditional and NodeSync nodes; finally, a NodeExprs is created from
    the Cluster's expressions.

    Parameters
    ----------
    clusters : iterable of Cluster
        The Clusters to be arranged.

    Returns
    -------
    ScheduleTree
        The tree created at the beginning of the routine, populated while
        visiting ``clusters``.
    """
    root = ScheduleTree()

    # The Cluster visited at the previous iteration. It is only accessed
    # while scanning reusable sub-trees, which requires `subtrees` to be
    # non-empty, hence at least one Cluster to have been visited already
    last = None

    # Map each tracked entry of `itintervals` to the top and bottom nodes of
    # the sub-tree created for it
    subtrees = DefaultOrderedDict(lambda: Bunch(top=None, bottom=None))

    def decorate(cluster, dim, node):
        # Chain a NodeConditional and/or a NodeSync onto `node` if `cluster`
        # has a guard and/or a sync registered for `dim`; return the last
        # node created, or `node` itself if nothing was created
        if dim in cluster.guards:
            node = NodeConditional(cluster.guards[dim], node)
        if dim in cluster.syncs:
            node = NodeSync(cluster.syncs[dim], node)
        return node

    for cluster in clusters:
        # Add in any Conditionals and Syncs outside of the outermost Iteration
        node = decorate(cluster, None, root)

        # The tracked sub-trees are candidates for reuse only if nothing was
        # inserted between the root and the outermost Iteration
        candidates = list(subtrees) if node is root else []

        nreused = 0
        for interval, candidate in zip(cluster.itintervals, candidates):
            if interval != candidate:
                break
            nreused += 1

            dim = interval.dim
            entry = subtrees[interval]
            top = entry.top

            # The reused sub-trees might acquire new sub-iterators as well as
            # new properties
            top.ispace = IterationSpace.union(top.ispace,
                                              cluster.ispace.project([dim]))
            top.properties = normalize_properties(top.properties,
                                                  cluster.properties[dim])

            # Different guards or syncs cannot be further nested: restart
            # from the top of the sub-tree and stop reusing
            if (cluster.guards.get(dim) != last.guards.get(dim) or
                    cluster.syncs.get(dim) != last.syncs.get(dim)):
                node = decorate(cluster, dim, top)
                entry.bottom = node
                break

            node = entry.bottom

        # Nested sub-trees, instead, will not be used anymore
        for stale in candidates[nreused:]:
            subtrees.pop(stale)

        # Add in Iterations, Conditionals, and Syncs
        for interval in cluster.itintervals[nreused:]:
            dim = interval.dim
            node = NodeIteration(cluster.ispace.project([dim]), node,
                                 cluster.properties.get(dim))
            entry = subtrees[interval]
            entry.top = node
            node = decorate(cluster, dim, node)
            entry.bottom = node

        # Add in Expressions
        NodeExprs(cluster.exprs, cluster.ispace, cluster.dspace, cluster.ops,
                  cluster.traffic, node)

        # Prepare for next iteration
        last = cluster

    return root
def stree_schedule(clusters):
    """
    Arrange an iterable of Clusters into a ScheduleTree.

    The Clusters are visited in order. For each Cluster, the leading entries
    of its ``itintervals`` that match those of the previously visited Cluster
    are reused, provided that the guards and syncs also match; a new
    NodeIteration is created for each of the remaining ones. Guards and syncs
    are turned into NodeConditional and NodeSync nodes. Finally, the
    Cluster's expressions are grouped by their ``conditionals`` and one
    NodeExprs is created per group, placed under an additional
    NodeConditional when the group's ``conditionals`` is non-empty.

    Parameters
    ----------
    clusters : iterable of Cluster
        The Clusters to be arranged.

    Returns
    -------
    ScheduleTree
        The tree created at the beginning of the routine, populated while
        visiting ``clusters``.
    """
    root = ScheduleTree()

    # Stand-in for the "previous Cluster" when visiting the first Cluster
    last = Cluster(None)

    # Map each tracked entry of `itintervals` to the top and bottom nodes of
    # the sub-tree created for it
    subtrees = DefaultOrderedDict(lambda: Bunch(top=None, bottom=None))

    def same_metadata(c0, c1, dim):
        # True if `c0` and `c1` have the same guard and the same sync
        # registered for `dim` (or no guard/sync at all)
        return (c0.guards.get(dim) == c1.guards.get(dim) and
                c0.syncs.get(dim) == c1.syncs.get(dim))

    def decorate(cluster, dim, node):
        # Chain a NodeConditional and/or a NodeSync onto `node` if `cluster`
        # has a guard and/or a sync registered for `dim`; return the last
        # node created, or `node` itself if nothing was created
        if dim in cluster.guards:
            node = NodeConditional(cluster.guards[dim], node)
        if dim in cluster.syncs:
            node = NodeSync(cluster.syncs[dim], node)
        return node

    for cluster in clusters:
        nreused = 0

        # Reuse or add in any Conditionals and Syncs outside of the outermost
        # Iteration
        if same_metadata(cluster, last, None):
            try:
                # Restart from whatever sits right above the outermost
                # Iteration of the previous Cluster
                node = subtrees[last.itintervals[0]].top.parent
            except IndexError:
                # The previous Cluster has no `itintervals`
                node = root
            candidates = last.itintervals
        else:
            node = decorate(cluster, None, root)
            candidates = []

        for interval, candidate in zip(cluster.itintervals, candidates):
            if interval != candidate:
                break
            nreused += 1

            dim = interval.dim
            entry = subtrees[interval]
            top = entry.top

            # The reused sub-trees might acquire new sub-iterators as well as
            # new properties
            top.ispace = IterationSpace.union(top.ispace,
                                              cluster.ispace.project([dim]))
            top.properties = normalize_properties(top.properties,
                                                  cluster.properties[dim])

            # Different guards or SyncOps cannot further be nested: restart
            # from the top of the sub-tree and stop reusing
            if not same_metadata(cluster, last, dim):
                node = decorate(cluster, dim, top)
                entry.bottom = node
                break

            node = entry.bottom

        # Nested sub-trees, instead, will not be used anymore
        for stale in last.itintervals[nreused:]:
            subtrees.pop(stale)

        # Add in Iterations, Conditionals, and Syncs
        for interval in cluster.itintervals[nreused:]:
            dim = interval.dim
            node = NodeIteration(cluster.ispace.project([dim]), node,
                                 cluster.properties.get(dim, ()))
            entry = subtrees[interval]
            entry.top = node
            node = decorate(cluster, dim, node)
            entry.bottom = node

        # Add in Expressions, one NodeExprs per run of consecutive
        # expressions sharing the same `conditionals`
        for conditionals, group in groupby(cluster.exprs,
                                           key=lambda e: e.conditionals):
            # Materialize the group before doing anything else, since
            # `groupby` invalidates it as soon as the iteration advances
            exprs = list(group)

            # Indirect ConditionalDimensions induce expression-level guards
            parent = node
            if conditionals:
                guard = And(*conditionals.values(), evaluate=False)
                parent = NodeConditional(guard, node)

            NodeExprs(exprs, cluster.ispace, cluster.dspace, cluster.ops,
                      cluster.traffic, parent)

        # Prepare for next iteration
        last = cluster

    return root