def iet_make(stree):
    """
    Create an IET from a ScheduleTree.

    The nodes produced by ``stree.visit()`` are converted one at a time. Each
    converted node is queued under its ``parent``; when the parent is visited
    in turn, it pops that queue and uses it as its own body. The visit is
    expected to yield ``stree`` itself last, at which point everything queued
    under it is wrapped in a List and returned.

    Parameters
    ----------
    stree : ScheduleTree
        The root of the schedule tree to be converted.

    Returns
    -------
    List
        The root of the generated IET.

    Raises
    ------
    TypeError
        If the visit yields a node that is none of the supported kinds.
    AssertionError
        If the visit terminates without ever yielding ``stree``.
    KeyError
        If a non-leaf node is visited before any of its children was queued.
    """
    nsections = 0
    # Maps a node to the IET nodes generated so far from its children
    queues = OrderedDict()
    for i in stree.visit():
        if i == stree:
            # We hit this handle at the very end of the visit
            return List(body=queues.pop(i))

        elif i.is_Exprs:
            exprs = [Increment(e) if e.is_Increment else Expression(e)
                     for e in i.exprs]
            body = ExpressionBundle(i.ispace, i.ops, i.traffic, body=exprs)

        elif i.is_Conditional:
            body = Conditional(i.guard, queues.pop(i))

        elif i.is_Iteration:
            # Order to ensure deterministic code generation
            uindices = sorted(i.sub_iterators, key=lambda d: d.name)
            # Generate Iteration
            body = Iteration(queues.pop(i), i.dim, i.limits, offsets=i.offsets,
                             direction=i.direction, properties=i.properties,
                             uindices=uindices)

        elif i.is_Section:
            body = Section('section%d' % nsections, body=queues.pop(i))
            nsections += 1

        elif i.is_Halo:
            body = HaloSpot(i.halo_scheme, body=queues.pop(i))

        else:
            # Without this guard `body` would be unbound or, worse, would still
            # refer to the node generated at the previous step
            raise TypeError("Unexpected ScheduleTree node of type `%s`"
                            % type(i).__name__)

        queues.setdefault(i.parent, []).append(body)

    # Only reachable if the visit never yields `stree`. Raised explicitly, rather
    # than through `assert False`, so that the check survives `python -O`
    raise AssertionError("ScheduleTree visit ended without reaching the root")
def iet_build(stree):
    """
    Construct an Iteration/Expression tree (IET) from a ScheduleTree.

    The nodes produced by ``stree.visit()`` are converted one at a time. Each
    converted node is queued under its ``parent``; when the parent is visited
    in turn, it pops that queue and uses it as its own body. The visit is
    expected to yield ``stree`` itself last, at which point everything queued
    under it is wrapped in a List and returned.

    Parameters
    ----------
    stree : ScheduleTree
        The root of the schedule tree to be converted.

    Returns
    -------
    List
        The root of the generated IET.

    Raises
    ------
    TypeError
        If the visit yields a node that is none of the supported kinds.
    AssertionError
        If the visit terminates without ever yielding ``stree``.
    KeyError
        If a non-leaf node is visited before any of its children was queued.
    """
    nsections = 0
    # Maps a node to the IET nodes generated so far from its children
    queues = OrderedDict()
    for i in stree.visit():
        if i == stree:
            # We hit this handle at the very end of the visit
            return List(body=queues.pop(i))

        elif i.is_Exprs:
            exprs = [Increment(e) if e.is_Increment else Expression(e)
                     for e in i.exprs]
            body = ExpressionBundle(i.ispace, i.ops, i.traffic, body=exprs)

        elif i.is_Conditional:
            body = Conditional(i.guard, queues.pop(i))

        elif i.is_Iteration:
            body = Iteration(queues.pop(i), i.dim, i.limits,
                             direction=i.direction, properties=i.properties,
                             uindices=i.sub_iterators)

        elif i.is_Section:
            body = Section('section%d' % nsections, body=queues.pop(i))
            nsections += 1

        elif i.is_Halo:
            body = HaloSpot(i.halo_scheme, body=queues.pop(i))

        else:
            # Without this guard `body` would be unbound or, worse, would still
            # refer to the node generated at the previous step
            raise TypeError("Unexpected ScheduleTree node of type `%s`"
                            % type(i).__name__)

        queues.setdefault(i.parent, []).append(body)

    # Only reachable if the visit never yields `stree`. Raised explicitly, rather
    # than through `assert False`, so that the check survives `python -O`
    raise AssertionError("ScheduleTree visit ended without reaching the root")
def iet_make(stree):
    """
    Create an Iteration/Expression tree (IET) from a :class:`ScheduleTree`.

    The nodes produced by ``stree.visit()`` are converted one at a time. The
    IET nodes generated for a schedule node are appended to the queue of its
    ``parent``; when the parent is visited in turn, it pops that queue and
    uses it as its own body. The visit is expected to yield ``stree`` itself
    last, at which point everything queued under it is wrapped in a List and
    returned.

    Parameters
    ----------
    stree : ScheduleTree
        The root of the schedule tree to be converted.

    Returns
    -------
    List
        The root of the generated IET.

    Raises
    ------
    TypeError
        If the visit yields a node that is none of the supported kinds.
    AssertionError
        If the visit terminates without ever yielding ``stree``.
    KeyError
        If a non-leaf node is visited before any of its children was queued.
    """
    nsections = 0
    # Maps a node to the IET nodes generated so far from its children
    queues = OrderedDict()
    for i in stree.visit():
        if i == stree:
            # We hit this handle at the very end of the visit
            return List(body=queues.pop(i))

        elif i.is_Exprs:
            exprs = [Expression(e) for e in i.exprs]
            body = [ExpressionBundle(i.shape, i.ops, i.traffic, body=exprs)]

        elif i.is_Conditional:
            body = [Conditional(i.guard, queues.pop(i))]

        elif i.is_Iteration:
            # Order to ensure deterministic code generation
            uindices = sorted(i.sub_iterators, key=lambda d: d.name)
            # Generate Iteration
            body = [Iteration(queues.pop(i), i.dim, i.dim.limits, offsets=i.limits,
                              direction=i.direction, uindices=uindices)]

        elif i.is_Section:
            body = [Section('section%d' % nsections, body=queues.pop(i))]
            nsections += 1

        elif i.is_Halo:
            body = [HaloSpot(i.halo_scheme, body=queues.pop(i))]

        else:
            # Without this guard `body` would be unbound or, worse, would still
            # refer to the nodes generated at the previous step
            raise TypeError("Unexpected ScheduleTree node of type `%s`"
                            % type(i).__name__)

        queues.setdefault(i.parent, []).extend(body)

    # Only reachable if the visit never yields `stree`. Raised explicitly, rather
    # than through `assert False`, so that the check survives `python -O`
    raise AssertionError("ScheduleTree visit ended without reaching the root")
def _hoist_halospots(iet):
    """
    Move HaloSpots out of inner Iterations and around outer ones, whenever
    doing so does not break a data dependence.

    The pass works in three steps:

      * analysis: for each function in each HaloSpot, the enclosing Iterations
        are scanned from the outermost inwards, looking for the first one
        around which the halo exchange can be placed;
      * transformation: the selected Iterations get wrapped in a new HaloSpot,
        while the original HaloSpots are shrunk, or dropped if nothing is left
        in them;
      * clean up: a HaloSpot whose body is itself a HaloSpot is merged with it.

    Parameters
    ----------
    iet : Node
        The IET to be processed (presumably the root of a tree).

    Returns
    -------
    Node
        The IET produced by the transformation.
    """
    # Build all Scopes up front, so that the analysis only performs lookups
    scopes = {}
    for key, nodes in MapNodes().visit(iet).items():
        scopes[key] = Scope([node.expr for node in nodes])

    # Analysis
    residual = {}  # HaloSpot -> what remains of its HaloScheme
    hoisted = defaultdict(list)  # Iteration -> HaloSchemes to be put around it
    groups = MapNodes(Iteration, HaloSpot, 'groupby').visit(iet)
    for iterations, spots in groups.items():
        for spot in spots:
            residual[spot] = spot.halo_scheme

            for func in spot.fmapper:
                for pos, iteration in enumerate(iterations):
                    # All Dimensions defined by this Iteration and by the ones
                    # following it in `iterations`
                    maybe_hoistable = set()
                    for inner in iterations[pos:]:
                        maybe_hoistable.update(inner.dim._defines)
                    d_flow = scopes[iteration].d_flow.project(func)

                    # A dependence prevents hoisting if it is caused by one of
                    # the Dimensions above, unless its write is an increment
                    blocked = any(
                        (dep.cause & maybe_hoistable) and not dep.write.is_increment
                        for dep in d_flow
                    )
                    if blocked:
                        continue

                    residual[spot] = residual[spot].drop(func)
                    hoisted[iteration].append(spot.halo_scheme.project(func))
                    break

    # Post-process analysis
    mapper = {}
    for iteration, schemes in hoisted.items():
        mapper[iteration] = HaloSpot(HaloScheme.union(schemes), iteration._rebuild())
    for spot, scheme in residual.items():
        if scheme.is_void:
            # Nothing left to exchange here, only the body survives
            mapper[spot] = spot.body
        else:
            mapper[spot] = spot._rebuild(halo_scheme=scheme)

    # Transform the IET hoisting/dropping HaloSpots as according to the analysis
    iet = Transformer(mapper, nested=True).visit(iet)

    # Clean up: de-nest HaloSpots if necessary
    denested = {}
    for outer in FindNodes(HaloSpot).visit(iet):
        inner = outer.body
        if not inner.is_HaloSpot:
            continue
        merged = HaloScheme.union([outer.halo_scheme, inner.halo_scheme])
        denested[outer] = outer._rebuild(halo_scheme=merged, body=inner.body)

    return Transformer(denested, nested=True).visit(iet)
def _hoist_halospots(iet):
    """
    Move HaloSpots out of inner Iterations and around outer ones, whenever
    doing so does not break a data dependence.

    The pass works in three steps:

      * analysis: for each function in each HaloSpot, the enclosing Iterations
        are scanned from the outermost inwards, looking for the first one
        around which the halo exchange can be placed. An Iteration qualifies
        if every relevant dependence is cleared by at least one hoisting rule;
      * transformation: the selected Iterations get wrapped in a new HaloSpot,
        while the original HaloSpots are shrunk, or dropped if nothing is left
        in them;
      * clean up: a HaloSpot whose body is itself a HaloSpot is merged with it.

    Parameters
    ----------
    iet : Node
        The IET to be processed (presumably the root of a tree).

    Returns
    -------
    Node
        The IET produced by the transformation.
    """

    # Hoisting rules -- a rule returning True means that the given `dep` does
    # not prevent halo hoisting

    def dims_rule(dep, candidates, loc_dims):
        # E.g., `dep=W<f,[x]> -> R<f,[x-1]>` and `candidates=({time}, {x})` => False
        # E.g., `dep=W<f,[t1, x, y]> -> R<f,[t0, x-1, y+1]>`, `dep.cause={t,time}` and
        # `candidates=({x},)` => True
        distance_dims = set(dep.distance_mapper)
        if not all(c & distance_dims for c in candidates):
            return False
        if any(c & dep.cause for c in candidates):
            return False
        return not any(c & loc_dims for c in candidates)

    def increment_rule(dep, candidates, loc_dims):
        # An increment isn't a stopper to hoisting
        return dep.write.is_increment

    rules = (dims_rule, increment_rule)

    # Build all Scopes up front, so that the analysis only performs lookups
    scopes = {}
    for key, nodes in MapNodes().visit(iet).items():
        scopes[key] = Scope([node.expr for node in nodes])

    # Analysis
    residual = {}  # HaloSpot -> what remains of its HaloScheme
    hoisted = defaultdict(list)  # Iteration -> HaloSchemes to be put around it
    groups = MapNodes(Iteration, HaloSpot, 'groupby').visit(iet)
    for iterations, spots in groups.items():
        for spot in spots:
            residual[spot] = spot.halo_scheme

            for func, entry in spot.fmapper.items():
                loc_indices, _ = entry
                loc_dims = frozenset(q for d in loc_indices for q in d._defines)

                for pos, iteration in enumerate(iterations):
                    # Dimensions defined by this Iteration and by each of the
                    # ones following it in `iterations`, one set per Iteration
                    candidates = [inner.dim._defines for inner in iterations[pos:]]

                    # Hoistable only if each dependence passes at least one rule
                    hoistable = all(
                        any(rule(dep, candidates, loc_dims) for rule in rules)
                        for dep in scopes[iteration].d_flow.project(func)
                    )
                    if not hoistable:
                        continue

                    residual[spot] = residual[spot].drop(func)
                    hoisted[iteration].append(spot.halo_scheme.project(func))
                    break

    # Post-process analysis
    mapper = {}
    for iteration, schemes in hoisted.items():
        mapper[iteration] = HaloSpot(HaloScheme.union(schemes), iteration._rebuild())
    for spot, scheme in residual.items():
        if scheme.is_void:
            # Nothing left to exchange here, only the body survives
            mapper[spot] = spot.body
        else:
            mapper[spot] = spot._rebuild(halo_scheme=scheme)

    # Transform the IET hoisting/dropping HaloSpots as according to the analysis
    iet = Transformer(mapper, nested=True).visit(iet)

    # Clean up: de-nest HaloSpots if necessary
    denested = {}
    for outer in FindNodes(HaloSpot).visit(iet):
        inner = outer.body
        if not inner.is_HaloSpot:
            continue
        merged = HaloScheme.union([outer.halo_scheme, inner.halo_scheme])
        denested[outer] = outer._rebuild(halo_scheme=merged, body=inner.body)

    return Transformer(denested, nested=True).visit(iet)