def _(iet):
    """
    Prepend the device-selection prologue to the body of ``iet``.

    The prologue optionally initializes the target language, then sets the
    device to the runtime ``deviceid`` whenever this differs from -1. If an
    MPICommObject is found among ``iet.parameters``, an else-branch is added
    which selects the device as ``rank % ngpus``.

    Returns
    -------
    tuple
        The rebuilt ``iet`` and the metadata ``{'args': deviceid}``.
    """
    # TODO: we need to pick the rank from `comm_shm`, not `comm`,
    # so that we have nranks == ngpus (as long as the user has launched
    # the right number of MPI processes per node given the available
    # number of GPUs per node)

    # First MPI communicator among the parameters, if any
    objcomm = next((p for p in iet.parameters if isinstance(p, MPICommObject)),
                   None)

    devicetype = as_list(self.lang[self.platform])

    try:
        prologue = [self.lang['init'](devicetype)]
    except TypeError:
        # Not all target languages need to be explicitly initialized
        prologue = []

    deviceid = DeviceID()
    if objcomm is None:
        guard = Conditional(
            CondNe(deviceid, -1),
            self.lang['set-device']([deviceid] + devicetype)
        )

        header = c.Comment('Begin of %s setup' % self.lang['name'])
        footer = c.Comment('End of %s setup' % self.lang['name'])
    else:
        rank = Symbol(name='rank')
        rank_decl = LocalExpression(DummyEq(rank, 0))
        rank_init = Call('MPI_Comm_rank', [objcomm, Byref(rank)])

        ngpus = Symbol(name='ngpus')
        ngpus_init = LocalExpression(
            DummyEq(ngpus, self.lang['num-devices'](devicetype))
        )

        # Explicit device id vs. round-robin over the available devices
        explicit = self.lang['set-device']([deviceid] + devicetype)
        by_rank = self.lang['set-device']([rank % ngpus] + devicetype)

        guard = Conditional(
            CondNe(deviceid, -1),
            explicit,
            List(body=[rank_decl, rank_init, ngpus_init, by_rank]),
        )

        header = c.Comment('Begin of %s+MPI setup' % self.lang['name'])
        footer = c.Comment('End of %s+MPI setup' % self.lang['name'])

    init = List(header=header, body=prologue + [guard],
                footer=(footer, c.Line()))

    return iet._rebuild(body=(init,) + iet.body), {'args': deviceid}
def _make_guard(self, parregion, *args):
    """
    Wrap ``parregion`` in a Conditional protecting it from degenerate inputs.

    If none of the ParallelTrees within ``parregion`` is rooted in a
    ``self.DeviceIteration``, the parent implementation is used instead.

    Returns ``parregion`` itself if no guard is necessary, otherwise a List
    containing the guarding Conditional.
    """
    partrees = FindNodes(ParallelTree).visit(parregion)
    on_device = any(isinstance(pt.root, self.DeviceIteration) for pt in partrees)
    if not on_device:
        return super()._make_guard(parregion, *args)

    guards = []

    # There must be at least one iteration or potential crash
    if not parregion.is_Affine:
        outermost = retrieve_iteration_tree(parregion.root)[0]
        collapsed = outermost[:parregion.ncollapsed]
        guards.extend([it.symbolic_size > 0 for it in collapsed])

    # SparseFunctions may occasionally degenerate to zero-size arrays. In such
    # a case, a copy-in produces a `nil` pointer on the device. To fire up a
    # parallel loop we must ensure none of the SparseFunction pointers are `nil`
    sparse = [s for s in FindSymbols().visit(parregion) if s.is_SparseFunction]
    if sparse:
        sizes = [prod(f._C_get_field(FULL, d).size for d in f.dimensions)
                 for f in sparse]
        guards.extend([size > 0 for size in sizes])

    # Drop dynamically evaluated conditions (e.g. because the `symbolic_size`
    # is an integer value rather than a symbol). This avoids ugly and
    # unnecessary conditionals such as `if (true) { ...}`
    guards = [g for g in guards if g != true]

    # Nothing left to check at runtime
    if not guards:
        return parregion

    # Combine all of the surviving conditions
    return List(body=[Conditional(And(*guards), parregion)])
def avoid_denormals(iet, platform=None):
    """
    Prepend to the body of ``iet`` the nodes that will expand to C macros
    telling the CPU to flush denormal numbers in hardware.

    Denormals are normally flushed when using SSE-based instruction sets,
    except when compiling shared objects.

    Parameters
    ----------
    iet
        The input Iteration/Expression tree.
    platform : optional
        The target platform. Nothing is done unless 'sse' appears in
        ``platform.known_isas``; a platform without that attribute
        (including None) also results in no changes.

    Returns
    -------
    tuple
        The (possibly rebuilt) ``iet`` and a metadata dict; the latter lists
        the required 'includes' if the tree was modified, and is empty
        otherwise.
    """
    # There is unfortunately no known portable way of flushing denormal to zero.
    # See for example: https://stackoverflow.com/questions/59546406/\
    #                  a-robust-portable-way-to-set-flush-denormals-to-zero
    try:
        has_sse = 'sse' in platform.known_isas
    except AttributeError:
        has_sse = False

    if not has_sse or iet.is_ElementalFunction:
        return iet, {}

    flush = List(header=(
        cgen.Comment('Flush denormal numbers to zero in hardware'),
        cgen.Statement('_MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON)'),
        cgen.Statement('_MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON)'),
        cgen.Line(),
    ))

    new_body = iet.body._rebuild(body=(flush,) + iet.body.body)

    return iet._rebuild(body=new_body), {'includes': ('xmmintrin.h', 'pmmintrin.h')}
def _make_parregion(self, partree, parrays):
    """
    Build the parallel Region enclosing ``partree``.

    ``parrays`` maps each written Array/TempFunction to its pointer
    counterpart; it is updated in place with any newly created pointer.
    """
    needs_private = any(i.is_ParallelPrivate for i in partree.collapsed)
    if not needs_private:
        return self.Region(partree)

    # Vector-expand all written Arrays within `partree`, since at least
    # one of the parallelized Iterations requires thread-private Arrays
    # E.g. a(x, y) -> b(tid, x, y), where `tid` is the ThreadID Dimension
    vexpandeds = []
    for expr in FindNodes(Expression).visit(partree):
        written = expr.write
        if not written.is_Array and not written.is_TempFunction:
            continue
        if written not in parrays:
            # Create the pointer lazily, only the first time it's needed
            parrays[written] = written._make_pointer(dim=self.threadid)
        vexpandeds.append(VExpanded(written, parrays[written]))

    if vexpandeds:
        init = self.lang['thread-num'](retobj=self.threadid)
        prefix = List(body=[init] + vexpandeds + list(partree.prefix),
                      footer=c.Line())
        partree = partree._rebuild(prefix=prefix)

    return self.Region(partree)
def _dump_storage(self, iet, storage):
    """
    Rebuild ``iet`` by injecting the definitions collected in ``storage``.

    Each Expression key is replaced by its mapped value. Any other key is
    rebuilt with its body wrapped in a List carrying the allocations in the
    header and the deallocations in the footer.
    """

    def per_thread_modules(entries):
        # One Module per thread-id Dimension: the region header followed by
        # a Block initializing the thread id and running the grouped bodies
        modules = []
        grouped = as_mapper(entries, itemgetter(0), itemgetter(1))
        for tid, body in grouped.items():
            header = self.lang.Region._make_header(tid.symbolic_size)
            init = c.Initializer(c.Value(tid._C_typedata, tid.name),
                                 self.lang['thread-num'])
            modules.append(c.Module((header, c.Block([init] + body))))
        return modules

    mapper = {}
    for k, v in storage.items():
        # Expr -> LocalExpr ?
        if k.is_Expression:
            mapper[k] = v
            continue

        # allocs/pallocs
        allocs = flatten(v.allocs)
        allocs.extend(per_thread_modules(v.pallocs))
        if allocs:
            allocs.append(c.Line())

        # frees/pfrees
        frees = per_thread_modules(v.pfrees)
        frees.extend(flatten(v.frees))
        if frees:
            frees.insert(0, c.Line())

        wrapped = List(header=allocs, body=k.body, footer=frees)
        mapper[k] = k._rebuild(body=wrapped, **k.args_frozen)

    return Transformer(mapper, nested=True).visit(iet)
def _make_parregion(self, partree, parrays):
    """
    Build the parallel Region enclosing ``partree``.

    ``parrays`` maps each heap-allocated, local Array to its PointerArray
    counterpart; it is updated in place with any newly created PointerArray.
    """
    # Detect thread-private arrays on the heap and "map" them to shared
    # vector-expanded (one entry per thread) Arrays
    heap_globals = []
    for symbol in FindSymbols().visit(partree):
        if not (symbol.is_Array and symbol._mem_heap and symbol._mem_local):
            continue
        if symbol not in parrays:
            # Create the PointerArray lazily, so that no name is drawn from
            # the registry unless it's really needed
            parrays[symbol] = PointerArray(name=self.sregistry.make_name(),
                                           dimensions=(self.threadid,),
                                           array=symbol)
        heap_globals.append(HeapGlobal(symbol, parrays[symbol]))

    if heap_globals:
        tid = self.threadid
        init = c.Initializer(c.Value(tid._C_typedata, tid.name),
                             self.lang['thread-num'])
        prefix = List(header=init,
                      body=heap_globals + list(partree.prefix),
                      footer=c.Line())
        partree = partree._rebuild(prefix=prefix)

    return self.Region(partree)
def _make_reductions(self, partree, collapsed):
    """
    Handle the increments within ``partree`` when at least one of the
    ``collapsed`` Iterations is ParallelAtomic.

    Either a reduction clause is attached to the root of ``partree`` or each
    increment is wrapped with the language's 'atomic' header. If no
    collapsed Iteration is ParallelAtomic, ``partree`` is returned as is.
    """
    if not any(i.is_ParallelAtomic for i in collapsed):
        return partree

    # Collect expressions inducing reductions
    increments = [e for e in FindNodes(Expression).visit(partree)
                  if e.is_Increment and not e.is_ForeignExpression]
    reduction = [e.output for e in increments]

    all_affine = all(i.is_Affine for i in collapsed)
    if all_affine or not any(r.is_Indexed for r in reduction):
        # Introduce reduction clause
        root = partree.root
        mapper = {root: root._rebuild(reduction=reduction)}
    else:
        # Introduce one `omp atomic` pragma for each increment
        mapper = {e: List(header=self.lang['atomic'], body=e)
                  for e in increments}

    return Transformer(mapper).visit(partree)
def _make_parregion(self, partree, parrays):
    """
    Build the OpenMPRegion enclosing ``partree``.

    ``parrays`` maps each heap-allocated, local Array to its PointerArray
    counterpart; it is updated in place with any newly created PointerArray.
    """
    # Detect thread-private arrays on the heap and "map" them to shared
    # vector-expanded (one entry per thread) Arrays
    heap_globals = []
    for symbol in FindSymbols().visit(partree):
        if not (symbol.is_Array and symbol._mem_heap and symbol._mem_local):
            continue
        if symbol not in parrays:
            # Create the PointerArray lazily, so that no name is drawn from
            # the registry unless it's really needed
            parrays[symbol] = PointerArray(name=self.sregistry.make_name(),
                                           dimensions=(self.threadid,),
                                           array=symbol)
        heap_globals.append(Dereference(symbol, parrays[symbol]))

    if not heap_globals:
        return OpenMPRegion(partree, partree.nthreads)

    # The thread id goes first, then the dereferences, then the tree itself
    body = List(header=self._make_tid(self.threadid),
                body=heap_globals + [partree],
                footer=c.Line())

    return OpenMPRegion(body, partree.nthreads)
def _make_parregion(self, partree, parrays):
    """
    Build the parallel Region enclosing ``partree``.

    ``parrays`` maps each written Array to its PointerArray counterpart; it
    is updated in place with any newly created PointerArray.
    """
    needs_private = any(i.is_ParallelPrivate for i in partree.collapsed)
    if not needs_private:
        return self.Region(partree)

    # Vector-expand all written Arrays within `partree`, since at least
    # one of the parallelized Iterations requires thread-private Arrays
    # E.g. a(x, y) -> b(tid, x, y), where `tid` is the ThreadID Dimension
    vexpandeds = []
    for expr in FindNodes(Expression).visit(partree):
        written = expr.write
        if not written.is_Array:
            continue
        if written not in parrays:
            # Create the PointerArray lazily, so that no name is drawn from
            # the registry unless it's really needed
            parrays[written] = PointerArray(name=self.sregistry.make_name(),
                                            dimensions=(self.threadid,),
                                            array=written)
        vexpandeds.append(VExpanded(written, parrays[written]))

    if vexpandeds:
        tid = self.threadid
        init = c.Initializer(c.Value(tid._C_typedata, tid.name),
                             self.lang['thread-num'])
        prefix = List(header=init,
                      body=vexpandeds + list(partree.prefix),
                      footer=c.Line())
        partree = partree._rebuild(prefix=prefix)

    return self.Region(partree)
def place_casts(self, iet):
    """
    Create a new IET with the necessary type casts.

    Parameters
    ----------
    iet : Callable
        The input Iteration/Expression tree.

    Returns
    -------
    tuple
        The rebuilt ``iet`` and an empty metadata dict.
    """
    # Make the generated code less verbose by avoiding unnecessary casts:
    # a tensor symbol is cast only if it gets indexed or if it's an ArrayBasic
    indexed_names = {i.name for i in FindSymbols('indexeds').visit(iet)}
    need_cast = {
        f for f in FindSymbols().visit(iet)
        if f.is_Tensor and (f.name in indexed_names or f.is_ArrayBasic)
    }

    # Only the Callable parameters are cast, following the parameters order
    casts = tuple(PointerCast(p) for p in iet.parameters if p in need_cast)
    if casts:
        casts = (List(body=casts, footer=c.Line()),)

    return iet._rebuild(body=casts + iet.body), {}
def linearize_transfers(iet, sregistry):
    """
    Rebuild the PragmaTransfers in ``iet`` whose function has a PointerCast
    with a non-None ``flat``, dropping the imask entries from the first
    ``FULL`` onwards.

    ``sregistry`` is used to draw unique names for the offset symbols that
    may be introduced.
    """
    candidates = {cast.function
                  for cast in FindNodes(PointerCast).visit(iet)
                  if cast.flat is not None}

    mapper = {}
    for n in FindNodes(PragmaTransfer).visit(iet):
        if n.function not in candidates:
            continue

        try:
            imask0 = n.kwargs['imask']
        except KeyError:
            imask0 = []

        try:
            cut = imask0.index(FULL)
        except ValueError:
            cut = len(imask0)

        # Drop entries being flatten
        imask = imask0[:cut]

        # The NVC 21.2 compiler (as well as all previous and potentially some
        # future versions as well) suffers from a bug in the parsing of pragmas
        # using subarrays in data clauses. For example, the following pragma
        # excerpt `... copyin(a[0]:b[0])` leads to a compiler error, despite
        # being perfectly legal OpenACC code. The workaround consists of
        # generating `const int ofs = a[0]; ... copyin(n:b[0])`
        exprs = []
        if 0 < len(imask) < len(imask0):
            assert len(imask) == 1
            head = imask[0]
            try:
                start, size = head
            except TypeError:
                # Not a (start, size) pair
                start, size = head, 1

            # Spare the ugly generated code if unnecessary (occurs often)
            if start != 0:
                name = sregistry.make_name(prefix='%s_ofs' % n.function.name)
                wildcard = Wildcard(name=name, dtype=np.int32, is_const=True)

                symsect = PragmaLangBB._make_symbolic_sections_from_imask(
                    n.function, imask
                )
                assert len(symsect) == 1
                start, _ = symsect[0]
                exprs.append(DummyExpr(wildcard, start, init=True))

                imask = [(wildcard, size)]

        rebuilt = n._rebuild(imask=imask)

        # The offset definitions, if any, must precede the transfer
        mapper[n] = List(body=exprs + [rebuilt]) if exprs else rebuilt

    return Transformer(mapper).visit(iet)
def _make_guard(self, partree, collapsed):
    """
    Guard ``partree`` against zero step increments.

    For each of the ``collapsed`` Iterations whose step is a Symbol, a check
    ``step == 0`` is generated; if any of them holds, the generated code
    returns before reaching ``partree``.
    """
    # Do not enter the parallel region if the step increment is 0; this
    # would raise a `Floating point exception (core dumped)` in some OpenMP
    # implementations. Note that using an OpenMP `if` clause won't work
    zero_steps = [CondEq(it.step, 0) for it in collapsed
                  if isinstance(it.step, Symbol)]
    guard = Or(*zero_steps)

    if guard != False:  # noqa: `guard` may be a sympy.False which would be == False
        return List(body=[Conditional(guard, Return()), partree])

    return partree
def _make_atomic_incs(self, partree):
    """
    Wrap each increment within ``partree`` with the language's 'atomic'
    header, provided that ``partree`` is ParallelAtomic; otherwise return
    ``partree`` unchanged.
    """
    if partree.is_ParallelAtomic:
        # Introduce one `omp atomic` pragma for each increment
        increments = [e for e in FindNodes(Expression).visit(partree)
                      if e.is_Increment and not e.is_ForeignExpression]
        mapper = {e: List(header=self.lang['atomic'], body=e)
                  for e in increments}
        partree = Transformer(mapper).visit(partree)

    return partree
def __init__(self, prodder):
    """
    Initialize from ``prodder`` as both a Conditional and a Prodder.

    The Conditional executes, when the thread number equals 0, a While loop
    which keeps calling the function named ``prodder.name`` for as long as
    its negated return value holds.
    """
    # Atomic-ize any single-thread Prodders in the parallel tree
    is_delegate = CondEq(Ompizer.lang['thread-num'], 0)

    # Prod within a while loop until all communications have completed
    # In other words, the thread delegated to prodding is entrapped for as long
    # as it's required
    argnames = [arg.name for arg in prodder.arguments]
    pending = Not(DefFunction(prodder.name, argnames))
    entrap = List(header=c.Comment('Entrap thread until comms have completed'),
                  body=While(pending))

    Conditional.__init__(self, is_delegate, entrap)
    Prodder.__init__(self, prodder.name, prodder.arguments,
                     periodic=prodder.periodic)
def place_definitions(self, iet):
    """
    Create a new IET with symbols allocated/deallocated in some memory space.

    Parameters
    ----------
    iet : Callable
        The input Iteration/Expression tree.

    Returns
    -------
    tuple
        The rebuilt ``iet`` and an empty metadata dict.
    """
    storage = Storage()

    for k, v in MapExprStmts().visit(iet).items():
        if k.is_Expression:
            if k.is_definition:
                site = v[-1] if v else iet
                self._alloc_scalar_on_low_lat_mem(site, k, storage)
                continue
            objs = [k.write]
        elif k.is_Call:
            objs = k.arguments
        else:
            # Neither an Expression nor a Call: there is nothing to place.
            # Without this branch `objs` would either be unbound (first
            # iteration) or still refer to the objects of the previously
            # visited statement, which would then be processed again with
            # the wrong `site`
            continue

        for i in objs:
            try:
                if i.is_LocalObject:
                    site = v[-1] if v else iet
                    self._alloc_object_on_low_lat_mem(site, i, storage)
                elif i.is_Array:
                    if i in iet.parameters:
                        # The Array is passed as a Callable argument
                        continue
                    elif i._mem_stack:
                        self._alloc_array_on_low_lat_mem(iet, i, storage)
                    else:
                        self._alloc_array_on_high_bw_mem(i, storage)
                elif i.is_Function:
                    self._map_function_on_high_bw_mem(i, storage)
            except AttributeError:
                # E.g., a generic SymPy expression
                pass

    # Introduce symbol definitions going in the low latency memory
    mapper = dict(storage._on_low_lat_mem)
    iet = Transformer(mapper, nested=True).visit(iet)

    # Introduce symbol definitions going in the high bandwidth memory
    if storage._on_high_bw_mem:
        # Each entry is a (declaration, allocation, deallocation) triplet
        decls, allocs, frees = zip(*storage._on_high_bw_mem)
        body = List(header=decls + allocs, body=iet.body, footer=frees)
        iet = iet._rebuild(body=body)

    return iet, {}
def _(iet):
    """
    Place the device selection right before the only WhileAlive in ``iet``.

    The device is set to ``self.deviceid`` whenever this differs from -1.

    Returns
    -------
    tuple
        The rebuilt ``iet`` and an empty metadata dict.
    """
    loops = FindNodes(WhileAlive).visit(iet)
    assert len(loops) == 1
    loop = loops[-1]

    deviceid = self.deviceid
    devicetype = as_list(self.lang[self.platform])

    set_device = self.lang['set-device']([deviceid] + devicetype)
    init = Conditional(CondNe(deviceid, -1), set_device)

    # Replace the loop with `init`, a blank line, and the loop itself
    guarded = List(body=[init, BlankLine, loop])
    iet = Transformer({loop: guarded}).visit(iet)

    return iet, {}
def avoid_denormals(iet):
    """
    Prepend to the body of ``iet`` the nodes that will expand to C macros
    telling the CPU to flush denormal numbers in hardware.

    Denormals are normally flushed when using SSE-based instruction sets,
    except when compiling shared objects.

    Returns
    -------
    tuple
        The (possibly rebuilt) ``iet`` and a metadata dict; the latter lists
        the required 'includes' if the tree was modified, and is empty if
        ``iet`` is an ElementalFunction.
    """
    if iet.is_ElementalFunction:
        return iet, {}

    flush = List(header=(
        cgen.Comment('Flush denormal numbers to zero in hardware'),
        cgen.Statement('_MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON)'),
        cgen.Statement('_MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON)'),
        cgen.Line(),
    ))

    new_body = iet.body._rebuild(body=(flush,) + iet.body.body)

    return iet._rebuild(body=new_body), {'includes': ('xmmintrin.h', 'pmmintrin.h')}
def _dump_storage(self, iet, storage):
    """
    Rebuild ``iet`` by injecting the definitions collected in ``storage``.

    The low latency memory definitions are applied through a nested
    Transformer; the high bandwidth memory ones are placed in the header
    (declarations and allocations) and footer (deallocations) of a List
    wrapping the body of ``iet``.
    """
    # Introduce symbol definitions going in the low latency memory
    low_lat = dict(storage._on_low_lat_mem)
    iet = Transformer(low_lat, nested=True).visit(iet)

    # Introduce symbol definitions going in the high bandwidth memory
    header = []
    footer = []
    for decl, alloc, free in storage._on_high_bw_mem:
        # The declaration is optional
        header += [alloc] if decl is None else [decl, alloc]
        footer.append(free)

    if not (header or footer):
        return iet

    wrapped = List(header=header, body=iet.body, footer=footer)
    return iet._rebuild(body=wrapped)
def __init__(self, body, private=None):
    """
    Initialize the region from ``body``, which must consist of exactly one
    node: either a ParallelTree or a List wrapping a single ParallelTree.

    In the latter case, the List header is moved into the prefix of the
    ParallelTree. ``private`` is forwarded to ``OmpRegion._make_header``.
    """
    # Normalize and sanity-check input. A bit ugly, but it makes everything
    # much simpler to manage and reconstruct
    nodes = as_tuple(body)
    assert len(nodes) == 1
    node = nodes[0]
    assert node.is_List

    if isinstance(node, ParallelTree):
        partree = node
    elif node.is_List:
        assert len(node.body) == 1 and isinstance(node.body[0], ParallelTree)
        assert len(node.footer) == 0
        wrapped = node.body[0]
        prefix = List(header=node.header, body=wrapped.prefix)
        partree = wrapped._rebuild(prefix=prefix)

    header = OmpRegion._make_header(partree.nthreads, private)
    super().__init__(header=header, body=partree)
def unfold_blocked_tree(iet):
    """
    Unfold nested IterationFolds.

    Examples
    --------
    Given a section of Iteration/Expression tree as below: ::

        for i = 1 to N-1  // folded
          for j = 1 to N-1  // folded
            foo1()

    Assuming a fold with offset 1 in both /i/ and /j/ and body ``foo2()``,
    create: ::

        for i = 1 to N-1
          for j = 1 to N-1
            foo1()
        for i = 2 to N-2
          for j = 2 to N-2
            foo2()
    """

    def unfold(nest):
        # Unfold each IterationFold, then regroup the results tree by tree
        unfolded = list(zip(*[fold.unfold() for fold in nest]))
        return List(body=optimize_unfolded_tree(unfolded[:-1], unfolded[-1]))

    # Search the unfolding candidates
    candidates = []
    for tree in retrieve_iteration_tree(iet):
        folds = tuple(i for i in tree if i.is_IterationFold)
        if not folds:
            continue
        # Sanity check
        assert IsPerfectIteration().visit(folds[0])
        candidates.append(folds)

    # Perform unfolding
    mapper = {}
    for nest in candidates:
        mapper[nest[0]] = unfold(nest)

    # Insert the unfolded Iterations in the Iteration/Expression tree
    return Transformer(mapper).visit(iet)
def make_blocking(self, iet):
    """
    Apply loop blocking to PARALLEL Iteration trees.

    Each blocked nest is moved into a separate Callable, which is then
    called once for each combination of iteration ranges.

    Returns
    -------
    tuple
        The rebuilt ``iet`` and a metadata dict with the constructed
        BlockDimensions ('dimensions'), the generated Callables ('efuncs')
        and the steps of the BlockDimensions ('args').
    """
    # Make sure loop blocking will span as many Iterations as possible
    iet = fold_blockable_tree(iet, self.blockinner)

    mapper = {}
    efuncs = []
    block_dims = []
    for tree in retrieve_iteration_tree(iet):
        # Is the Iteration tree blockable ?
        iterations = filter_iterations(tree, lambda i: i.is_Tilable)
        if not self.blockinner:
            iterations = iterations[:-1]
        if len(iterations) <= 1:
            continue
        root = iterations[0]
        if not IsPerfectIteration().visit(root):
            # Don't know how to block non-perfect Iteration nests
            continue

        # Apply hierarchical loop blocking to `tree`
        outer = []  # Outermost level of blocking
        inner = [[] for _ in range(1, self.nlevels)]  # Inner levels of blocking
        intra = []  # Within the smallest block
        for it in iterations:
            template = "%s%d_blk%s" % (it.dim.name, self.nblocked, '%d')
            properties = (PARALLEL, AFFINE) if it.is_Affine else (PARALLEL,)

            # Build Iteration across the outermost blocks
            d = BlockDimension(it.dim, name=template % 0)
            outer.append(Iteration([], d, d.symbolic_max, properties=properties))

            # Build Iteration across all inner blocks, one per inner level
            for n, level in enumerate(inner, 1):
                di = BlockDimension(d, name=template % n)
                level.append(Iteration([], di, limits=(d, d+d.step-1, di.step),
                                       properties=properties))
                d = di

            # Build Iteration within the smallest block
            intra.append(it._rebuild([], limits=(d, d+d.step-1, 1),
                                     offsets=(0, 0)))
        inner = flatten(inner)
        blocks = outer + inner

        # Track all constructed BlockDimensions
        block_dims.extend(b.dim for b in blocks)

        # Construct the blocked tree
        blocked = compose_nodes(blocks + intra + [iterations[-1].nodes])
        blocked = unfold_blocked_tree(blocked)

        # Promote to a separate Callable
        dynamic_parameters = flatten((l0.dim, l0.step) for l0 in outer)
        dynamic_parameters.extend([li.step for li in inner])
        efunc = make_efunc("bf%d" % self.nblocked, blocked, dynamic_parameters)
        efuncs.append(efunc)

        # Compute the iteration ranges: for each blocked Iteration, the
        # region spanned by whole blocks and the remainder region
        ranges = []
        for it, l0 in zip(iterations, outer):
            maxb = it.symbolic_max - (it.symbolic_size % l0.step)
            main = (it.symbolic_min, maxb, l0.step)
            remainder = (maxb + 1, it.symbolic_max, it.symbolic_max - maxb)
            ranges.append((main, remainder))

        # Build Calls to the `efunc`
        calls = []
        for p in product(*ranges):
            dynamic_args_mapper = {}
            for l0, (m, M, b) in zip(outer, p):
                dynamic_args_mapper[l0.dim] = (m, M)
                dynamic_args_mapper[l0.step] = (b,)
                for li in inner:
                    if li.dim.root is l0.dim.root:
                        value = li.step if b is l0.step else b
                        dynamic_args_mapper[li.step] = (value,)
            calls.append(List(body=efunc.make_call(dynamic_args_mapper)))

        mapper[root] = List(body=calls)

        # Next blockable nest, use different (unique) variable/function names
        self.nblocked += 1

    iet = Transformer(mapper).visit(iet)

    # Force-unfold if some folded Iterations haven't been blocked in the end
    iet = unfold_blocked_tree(iet)

    metadata = {
        'dimensions': block_dims,
        'efuncs': efuncs,
        'args': [d.step for d in block_dims],
    }

    return iet, metadata
class AccBB(PragmaLangBB):

    """
    Building blocks for the OpenACC target language.

    ``mapper`` associates a key to either a ready-to-use object or a callable
    producing one. The entries of ``CBB.mapper`` are merged in afterwards, so
    they take precedence in case of clashing keys.
    """

    mapper = {
        # Misc
        'name': 'OpenACC',
        'header': 'openacc.h',
        # Platform mapping
        AMDGPUX: Macro('acc_device_radeon'),
        NVIDIAX: Macro('acc_device_nvidia'),
        # Runtime library
        'init': lambda args:
            Call('acc_init', args),
        'num-devices': lambda args:
            DefFunction('acc_get_num_devices', args),
        'set-device': lambda args:
            Call('acc_set_device_num', args),
        # Pragmas
        'atomic': c.Pragma('acc atomic update'),
        'map-enter-to': lambda i, j:
            c.Pragma('acc enter data copyin(%s%s)' % (i, j)),
        'map-enter-to-wait': lambda i, j, k:
            (c.Pragma('acc enter data copyin(%s%s) async(%s)' % (i, j, k)),
             c.Pragma('acc wait(%s)' % k)),
        'map-enter-alloc': lambda i, j:
            c.Pragma('acc enter data create(%s%s)' % (i, j)),
        'map-present': lambda i, j:
            c.Pragma('acc data present(%s%s)' % (i, j)),
        'map-wait': lambda i:
            c.Pragma('acc wait(%s)' % i),
        'map-update': lambda i, j:
            c.Pragma('acc exit data copyout(%s%s)' % (i, j)),
        'map-update-host': lambda i, j:
            c.Pragma('acc update self(%s%s)' % (i, j)),
        'map-update-host-async': lambda i, j, k:
            c.Pragma('acc update self(%s%s) async(%s)' % (i, j, k)),
        'map-update-device': lambda i, j:
            c.Pragma('acc update device(%s%s)' % (i, j)),
        'map-update-device-async': lambda i, j, k:
            c.Pragma('acc update device(%s%s) async(%s)' % (i, j, k)),
        'map-release': lambda i, j, k:
            c.Pragma('acc exit data delete(%s%s)%s' % (i, j, k)),
        'map-exit-delete': lambda i, j, k:
            c.Pragma('acc exit data delete(%s%s)%s' % (i, j, k)),
        'memcpy-to-device': lambda i, j, k:
            Call('acc_memcpy_to_device', [i, j, k]),
        'memcpy-to-device-wait': lambda i, j, k, l:
            List(body=[
                Call('acc_memcpy_to_device_async', [i, j, k, l]),
                Call('acc_wait', [l])
            ]),
        'device-get':
            Call('acc_get_device_num'),
        'device-alloc': lambda i, *a, retobj=None:
            Call('acc_malloc', (i,), retobj=retobj, cast=True),
        'device-free': lambda i, *a:
            Call('acc_free', (i,))
    }
    mapper.update(CBB.mapper)

    Region = OmpRegion
    HostIteration = OmpIteration  # Host parallelism still goes via OpenMP
    DeviceIteration = DeviceAccIteration

    @classmethod
    def _map_to_wait(cls, f, imask=None, queueid=None):
        """Asynchronous copy-in of ``f`` on ``queueid``, followed by a wait."""
        sections = cls._make_sections_from_imask(f, imask)
        make = cls.mapper['map-enter-to-wait']
        return make(f.name, sections, queueid)

    @classmethod
    def _map_present(cls, f, imask=None):
        """The `data present` pragma for ``f``."""
        sections = cls._make_sections_from_imask(f, imask)
        make = cls.mapper['map-present']
        return make(f.name, sections)

    @classmethod
    def _map_delete(cls, f, imask=None, devicerm=None):
        """
        The `exit data delete` pragma for ``f``, made conditional on
        ``devicerm`` through an `if` clause when ``devicerm`` is given.
        """
        sections = cls._make_sections_from_imask(f, imask)
        cond = '' if devicerm is None else ' if(%s)' % devicerm.name
        make = cls.mapper['map-exit-delete']
        return make(f.name, sections, cond)

    @classmethod
    def _map_update_host_async(cls, f, imask=None, queueid=None):
        """The asynchronous `update self` pragma for ``f`` on ``queueid``."""
        sections = cls._make_sections_from_imask(f, imask)
        make = cls.mapper['map-update-host-async']
        return make(f.name, sections, queueid)

    @classmethod
    def _map_update_device_async(cls, f, imask=None, queueid=None):
        """The asynchronous `update device` pragma for ``f`` on ``queueid``."""
        sections = cls._make_sections_from_imask(f, imask)
        make = cls.mapper['map-update-device-async']
        return make(f.name, sections, queueid)
def place_definitions(self, iet, **kwargs):
    """
    Create a new IET with symbols allocated/deallocated in some memory space.

    Parameters
    ----------
    iet : Callable
        The input Iteration/Expression tree.
    **kwargs
        If an 'efuncs' entry is supplied, the Functions read/written by the
        Expressions within these are mapped on the high bandwidth memory
        (unless ``iet`` is an ElementalFunction).

    Returns
    -------
    tuple
        The rebuilt ``iet`` and an empty metadata dict.
    """
    storage = Storage()

    # Collect and declare symbols
    for k, v in MapExprStmts().visit(iet).items():
        if k.is_Expression:
            if k.is_definition:
                site = v[-1] if v else iet
                self._alloc_scalar_on_low_lat_mem(site, k, storage)
                continue
            objs = [k.write]
        elif k.is_Call:
            objs = k.arguments
        else:
            # Neither an Expression nor a Call: there is nothing to place.
            # Without this branch `objs` would either be unbound (first
            # iteration) or still refer to the objects of the previously
            # visited statement, which would then be processed again with
            # the wrong `site`
            continue

        for i in objs:
            try:
                if i.is_LocalObject:
                    site = v[-1] if v else iet
                    self._alloc_object_on_low_lat_mem(site, i, storage)
                elif i.is_Array:
                    if i in iet.parameters:
                        # The Array is passed as a Callable argument
                        continue
                    elif i._mem_stack:
                        self._alloc_array_on_low_lat_mem(iet, i, storage)
                    else:
                        self._alloc_array_on_high_bw_mem(i, storage)
            except AttributeError:
                # E.g., a generic SymPy expression
                pass

    # Place symbols in a memory space
    if not iet.is_ElementalFunction:
        writes = set()
        reads = set()
        for efunc in kwargs.get('efuncs', []):
            for i in FindNodes(Expression).visit(efunc):
                if i.write.is_Function:
                    writes.add(i.write)
                # A Function that gets written is never treated as read-only
                reads = (reads | {r for r in i.reads if r.is_Function}) - writes

        for i in filter_sorted(writes):
            self._map_function_on_high_bw_mem(i, storage)
        for i in filter_sorted(reads):
            self._map_function_on_high_bw_mem(i, storage, read_only=True)

    # Introduce symbol definitions going in the low latency memory
    mapper = dict(storage._on_low_lat_mem)
    iet = Transformer(mapper, nested=True).visit(iet)

    # Introduce symbol definitions going in the high bandwidth memory
    header = []
    footer = []
    for decl, alloc, free in storage._on_high_bw_mem:
        # The declaration is optional
        if decl is None:
            header.append(alloc)
        else:
            header.extend([decl, alloc])
        footer.append(free)
    if header or footer:
        body = List(header=header, body=iet.body, footer=footer)
        iet = iet._rebuild(body=body)

    return iet, {}