def stree_schedule(clusters):
    """
    Arrange an iterable of Clusters into a ScheduleTree.

    Consecutive Clusters whose leading iteration intervals coincide share the
    corresponding NodeIteration sub-trees. Conditionals (guards) and Syncs
    are inserted by re-parenting the last child of the node they belong to
    (``v.last.parent = node``); hence a node nested below a guarded or synced
    dimension must be private to the current Cluster, which is why sharing
    stops at any such dimension.

    Parameters
    ----------
    clusters : iterable of Cluster
        The Clusters to be scheduled, in order.

    Returns
    -------
    ScheduleTree
        The root of the constructed schedule tree.
    """
    stree = ScheduleTree()

    # Key under which metadata outside of the outermost Iteration is stored
    outer = UniteratedInterval.dim

    mapper = OrderedDict()
    for c in clusters:
        pointers = list(mapper)

        # A Conditional/Sync outside of the outermost Iteration adopts the
        # root's last child, so nothing can be reused in such a case
        if outer in c.guards or outer in c.syncs:
            reusable = []
        else:
            reusable = pointers

        # Find out if any of the existing nodes can be reused
        index = 0
        root = stree
        for it0, it1 in zip(c.itintervals, reusable):
            if it0 != it1:
                break
            root = mapper[it0]
            index += 1
            # The node of a guarded/synced dimension is reused, but nothing
            # nested within it, since its last child will be re-parented
            # under the new Conditional/Sync
            if it0.dim in c.guards or it0.dim in c.syncs:
                break

        # The reused sub-trees might acquire some new sub-iterators
        for i in pointers[:index]:
            mapper[i].ispace = IterationSpace.union(mapper[i].ispace,
                                                    c.ispace.project([i.dim]))

        # Nested sub-trees, instead, will not be used anymore
        for i in pointers[index:]:
            mapper.pop(i)

        # Add in Iterations
        for i in c.itintervals[index:]:
            root = NodeIteration(c.ispace.project([i.dim]), root,
                                 c.properties.get(i.dim))
            mapper[i] = root

        # Add in Expressions
        NodeExprs(c.exprs, c.ispace, c.dspace, c.ops, c.traffic, root)

        # Add in Conditionals and Syncs, which chop down the reuse tree:
        # everything nested below the first guarded/synced level is dropped
        # from `mapper`, so later Clusters cannot reuse it
        drop = False
        for k, v in [(UniteratedInterval, stree)] + list(mapper.items()):
            if drop:
                mapper.pop(k)
            if k.dim in c.syncs:
                node = NodeSync(c.syncs[k.dim])
                v.last.parent = node
                node.parent = v
                drop = True
            if k.dim in c.guards:
                node = NodeConditional(c.guards[k.dim])
                v.last.parent = node
                node.parent = v
                drop = True

    return stree
def st_schedule(clusters):
    """
    Arrange an iterable of :class:`Cluster`s into a :class:`ScheduleTree`.

    Consecutive Clusters whose leading iteration intervals coincide share the
    corresponding NodeIteration sub-trees, except across dimensions listed in
    ``c.atomics``. A guard at a dimension re-parents that node's last child
    under a new NodeConditional, so anything nested within it becomes private
    to the current Cluster and is removed from the reuse map.

    Parameters
    ----------
    clusters : iterable of Cluster
        The Clusters to be scheduled, in order.

    Returns
    -------
    ScheduleTree
        The root of the constructed schedule tree.
    """
    stree = ScheduleTree()

    mapper = OrderedDict()
    for c in clusters:
        pointers = list(mapper)

        # Find out if any of the existing nodes can be reused
        index = 0
        root = stree
        for it0, it1 in zip(c.itintervals, pointers):
            if it0 != it1 or it0.dim in c.atomics:
                break
            root = mapper[it0]
            index += 1
            # Nothing nested within a guarded dimension can be shared
            if it0.dim in c.guards:
                break

        # The reused sub-trees might acquire some new sub-iterators
        for i in pointers[:index]:
            mapper[i].ispace = IterationSpace.merge(mapper[i].ispace,
                                                    c.ispace.project([i.dim]))

        # Later sub-trees, instead, will not be used anymore
        for i in pointers[index:]:
            mapper.pop(i)

        # Add in Iterations
        for i in c.itintervals[index:]:
            root = NodeIteration(c.ispace.project([i.dim]), root)
            mapper[i] = root

        # Add in Expressions
        NodeExprs(c.exprs, c.ispace, c.dspace, c.shape, c.ops, c.traffic, root)

        # Add in Conditionals. Whatever ends up nested within a Conditional
        # belongs to this Cluster only, so it must not be reused by later
        # Clusters (otherwise their expressions would inherit this guard)
        drop = False
        for k, v in list(mapper.items()):
            if drop:
                mapper.pop(k)
            if k.dim in c.guards:
                node = NodeConditional(c.guards[k.dim])
                v.last.parent = node
                node.parent = v
                drop = True

    return stree
def stree_schedule(clusters):
    """
    Arrange an iterable of Clusters into a ScheduleTree.

    Consecutive Clusters share the NodeIteration sub-trees of their common
    leading iteration intervals, provided that the guards and syncs attached
    to those levels -- and to the level outside of the outermost Iteration
    (key ``None``) -- coincide with those of the previous Cluster.

    Parameters
    ----------
    clusters : iterable of Cluster
        The Clusters to be scheduled, in order.

    Returns
    -------
    ScheduleTree
        The root of the constructed schedule tree.
    """
    stree = ScheduleTree()

    prev = None
    # Maps each iteration interval of the previous Cluster (outermost first)
    # to its NodeIteration (`top`) and to the deepest node among the
    # NodeIteration and its attached Conditional/Sync (`bottom`)
    mapper = DefaultOrderedDict(lambda: Bunch(top=None, bottom=None))

    def reuse_metadata(c0, c1, d):
        # Sub-trees may be shared only if the guards and syncs at `d` coincide
        return (c1 is not None and
                c0.guards.get(d) == c1.guards.get(d) and
                c0.syncs.get(d) == c1.syncs.get(d))

    def attach_metadata(cluster, d, tip):
        if d in cluster.guards:
            tip = NodeConditional(cluster.guards[d], tip)
        if d in cluster.syncs:
            tip = NodeSync(cluster.syncs[d], tip)
        return tip

    for c in clusters:
        pointers = list(mapper)

        # Reuse or add in any Conditionals and Syncs outside of the outermost
        # Iteration. Reusing an Iteration living below different outer
        # metadata would wrongly guard/sync this Cluster's expressions
        if pointers and reuse_metadata(c, prev, None):
            tip = mapper[pointers[0]].top.parent
            maybe_reusable = pointers
        else:
            tip = attach_metadata(c, None, stree)
            maybe_reusable = []

        index = 0
        for it0, it1 in zip(c.itintervals, maybe_reusable):
            if it0 != it1:
                break
            index += 1

            d = it0.dim

            # The reused sub-trees might acquire new sub-iterators as well as
            # new properties
            mapper[it0].top.ispace = IterationSpace.union(
                mapper[it0].top.ispace, c.ispace.project([d]))
            mapper[it0].top.properties = normalize_properties(
                mapper[it0].top.properties, c.properties[it0.dim])

            # Different guards or syncops cannot be further nested
            if not reuse_metadata(c, prev, d):
                tip = mapper[it0].top
                tip = attach_metadata(c, d, tip)
                mapper[it0].bottom = tip
                break
            else:
                tip = mapper[it0].bottom

        # Nested sub-trees, instead, will not be used anymore. This must
        # happen even when no reuse was attempted, or stale entries would
        # survive and be "reused" by later Clusters
        for it in pointers[index:]:
            mapper.pop(it)

        # Add in Iterations, Conditionals, and Syncs
        for it in c.itintervals[index:]:
            d = it.dim
            tip = NodeIteration(c.ispace.project([d]), tip, c.properties.get(d))
            mapper[it].top = tip
            tip = attach_metadata(c, d, tip)
            mapper[it].bottom = tip

        # Add in Expressions
        NodeExprs(c.exprs, c.ispace, c.dspace, c.ops, c.traffic, tip)

        # Prepare for next iteration
        prev = c

    return stree
def stree_schedule(clusters):
    """
    Arrange an iterable of Clusters into a ScheduleTree.

    Consecutive Clusters share the NodeIteration sub-trees of their common
    leading iteration intervals, provided that the guards and syncs attached
    to those levels -- and to the level outside of the outermost Iteration
    (key ``None``) -- coincide with those of the previous Cluster. Within a
    Cluster, consecutive expressions with the same non-empty ``conditionals``
    are additionally placed under a NodeConditional built from the
    conjunction of those conditions.

    Parameters
    ----------
    clusters : iterable of Cluster
        The Clusters to be scheduled, in order.

    Returns
    -------
    ScheduleTree
        The root of the constructed schedule tree.
    """
    stree = ScheduleTree()

    prev = Cluster(None)
    # Maps each iteration interval of the previous Cluster (outermost first)
    # to its NodeIteration (`top`) and to the deepest node among the
    # NodeIteration and its attached Conditional/Sync (`bottom`)
    mapper = DefaultOrderedDict(lambda: Bunch(top=None, bottom=None))

    def reuse_metadata(c0, c1, d):
        return (c0.guards.get(d) == c1.guards.get(d) and
                c0.syncs.get(d) == c1.syncs.get(d))

    def attach_metadata(cluster, d, tip):
        if d in cluster.guards:
            tip = NodeConditional(cluster.guards[d], tip)
        if d in cluster.syncs:
            tip = NodeSync(cluster.syncs[d], tip)
        return tip

    for c in clusters:
        index = 0

        # Reuse or add in any Conditionals and Syncs outside of the outermost Iteration
        if not reuse_metadata(c, prev, None):
            tip = attach_metadata(c, None, stree)
            maybe_reusable = []
        else:
            try:
                tip = mapper[prev.itintervals[index]].top.parent
            except IndexError:
                # `prev` has no Iterations, so there is no tracked node from
                # which its outer Conditionals/Syncs could be retrieved; build
                # fresh ones rather than silently dropping `c`'s metadata
                tip = attach_metadata(c, None, stree)
            maybe_reusable = prev.itintervals

        for it0, it1 in zip(c.itintervals, maybe_reusable):
            if it0 != it1:
                break
            index += 1

            d = it0.dim

            # The reused sub-trees might acquire new sub-iterators as well as
            # new properties
            mapper[it0].top.ispace = IterationSpace.union(
                mapper[it0].top.ispace, c.ispace.project([d]))
            mapper[it0].top.properties = normalize_properties(
                mapper[it0].top.properties, c.properties[d])

            # Different guards or SyncOps cannot further be nested
            if not reuse_metadata(c, prev, d):
                tip = mapper[it0].top
                tip = attach_metadata(c, d, tip)
                mapper[it0].bottom = tip
                break
            else:
                tip = mapper[it0].bottom

        # Nested sub-trees, instead, will not be used anymore
        for it in prev.itintervals[index:]:
            mapper.pop(it)

        # Add in Iterations, Conditionals, and Syncs
        for it in c.itintervals[index:]:
            d = it.dim
            tip = NodeIteration(c.ispace.project([d]), tip, c.properties.get(d, ()))
            mapper[it].top = tip
            tip = attach_metadata(c, d, tip)
            mapper[it].bottom = tip

        # Add in Expressions; `groupby` only merges *consecutive* expressions
        # sharing the same conditionals, which preserves their original order
        for conditionals, g in groupby(c.exprs, key=lambda e: e.conditionals):
            exprs = list(g)

            # Indirect ConditionalDimensions induce expression-level guards
            if conditionals:
                guard = And(*conditionals.values(), evaluate=False)
                parent = NodeConditional(guard, tip)
            else:
                parent = tip

            NodeExprs(exprs, c.ispace, c.dspace, c.ops, c.traffic, parent)

        # Prepare for next iteration
        prev = c

    return stree