def _make_sendrecv(self, f, hse, key, **kwargs):
    """
    Build a Callable `sendrecv_<key>` performing a blocking halo exchange for
    the Function `f`: gather owned data into a send buffer, issue
    MPI_Isend/MPI_Irecv, wait for completion, then scatter the received data
    into `f`'s halo.
    """
    comm = f.grid.distributor._obj_comm

    # One buffer Dimension per Dimension of `f`, excluding the fixed
    # `hse.loc_indices`, which carry no halo
    buf_dims = [Dimension(name='buf_%s' % d.root) for d in f.dimensions
                if d not in hse.loc_indices]
    bufg = Array(name='bufg', dimensions=buf_dims, dtype=f.dtype,
                 padding=0, scope='heap')
    bufs = Array(name='bufs', dimensions=buf_dims, dtype=f.dtype,
                 padding=0, scope='heap')

    # Offsets into `f` for the gather (send side) and scatter (recv side)
    ofsg = [Symbol(name='og%s' % d.root) for d in f.dimensions]
    ofss = [Symbol(name='os%s' % d.root) for d in f.dimensions]

    fromrank = Symbol(name='fromrank')
    torank = Symbol(name='torank')

    gather = Call('gather_%s' % key, [bufg] + list(bufg.shape) + [f] + ofsg)
    scatter = Call('scatter_%s' % key, [bufs] + list(bufs.shape) + [f] + ofss)

    # The `gather` is unnecessary if sending to MPI.PROC_NULL
    gather = Conditional(CondNe(torank, Macro('MPI_PROC_NULL')), gather)
    # The `scatter` must be guarded as we must not alter the halo values along
    # the domain boundary, where the sender is actually MPI.PROC_NULL
    scatter = Conditional(CondNe(fromrank, Macro('MPI_PROC_NULL')), scatter)

    # Number of elements exchanged: product of the buffer extents
    count = reduce(mul, bufs.shape, 1)
    rrecv = MPIRequestObject(name='rrecv')
    rsend = MPIRequestObject(name='rsend')
    recv = Call('MPI_Irecv', [bufs, count, Macro(dtype_to_mpitype(f.dtype)),
                              fromrank, Integer(13), comm, rrecv])
    send = Call('MPI_Isend', [bufg, count, Macro(dtype_to_mpitype(f.dtype)),
                              torank, Integer(13), comm, rsend])

    waitrecv = Call('MPI_Wait', [rrecv, Macro('MPI_STATUS_IGNORE')])
    waitsend = Call('MPI_Wait', [rsend, Macro('MPI_STATUS_IGNORE')])

    # The Irecv is posted before the gather/Isend pair
    iet = List(body=[recv, gather, send, waitsend, waitrecv, scatter])

    parameters = ([f] + list(bufs.shape) + ofsg + ofss +
                  [fromrank, torank, comm])

    return Callable('sendrecv_%s' % key, iet, 'void', parameters, ('static', ))
def update_halo(f, fixed): """ Construct an IET performing a halo exchange for a :class:`TensorFunction`. """ # Requirements assert f.is_Function assert f.grid is not None distributor = f.grid.distributor nb = distributor._C_neighbours.obj comm = distributor._C_comm fixed = {d: Symbol(name="o%s" % d.root) for d in fixed} mapper = get_views(f, fixed) body = [] masks = [] for d in f.dimensions: if d in fixed: continue rpeer = FieldFromPointer("%sright" % d, nb) lpeer = FieldFromPointer("%sleft" % d, nb) # Sending to left, receiving from right lsizes, loffsets = mapper[(d, LEFT, OWNED)] rsizes, roffsets = mapper[(d, RIGHT, HALO)] assert lsizes == rsizes sizes = lsizes parameters = ([f] + list(f.symbolic_shape) + sizes + loffsets + roffsets + [rpeer, lpeer, comm]) call = Call('sendrecv_%s' % f.name, parameters) mask = Symbol(name='m%sl' % d) body.append(Conditional(mask, call)) masks.append(mask) # Sending to right, receiving from left rsizes, roffsets = mapper[(d, RIGHT, OWNED)] lsizes, loffsets = mapper[(d, LEFT, HALO)] assert rsizes == lsizes sizes = rsizes parameters = ([f] + list(f.symbolic_shape) + sizes + roffsets + loffsets + [lpeer, rpeer, comm]) call = Call('sendrecv_%s' % f.name, parameters) mask = Symbol(name='m%sr' % d) body.append(Conditional(mask, call)) masks.append(mask) iet = List(body=body) parameters = ([f] + masks + [comm, nb] + list(fixed.values()) + [d.symbolic_size for d in f.dimensions]) return Callable('halo_exchange_%s' % f.name, iet, 'void', parameters, ('static', ))
def _make_halowait(self, f, hse, key, msg=None):
    """
    Build a Callable `halowait<key>` that loops over the `ncomms` pending
    communications described by the `msg` array, waits on each send/recv
    request pair, and scatters the received data into `f`'s halo.
    """
    cast = cast_mapper[(f.dtype, '*')]

    fixed = {d: Symbol(name="o%s" % d.root) for d in hse.loc_indices}

    # Loop index over the entries of `msg`
    dim = Dimension(name='i')
    msgi = IndexedPointer(msg, dim)

    bufs = FieldFromComposite(msg._C_field_bufs, msgi)
    fromrank = FieldFromComposite(msg._C_field_from, msgi)

    sizes = [FieldFromComposite('%s[%d]' % (msg._C_field_sizes, i), msgi)
             for i in range(len(f._dist_dimensions))]
    ofss = [FieldFromComposite('%s[%d]' % (msg._C_field_ofss, i), msgi)
            for i in range(len(f._dist_dimensions))]
    # Fixed (loc_indices) Dimensions take their offset Symbol; the remaining
    # Dimensions consume the offsets stored in `msg`, in order
    ofss = [fixed.get(d) or ofss.pop(0) for d in f.dimensions]

    # The `scatter` must be guarded as we must not alter the halo values along
    # the domain boundary, where the sender is actually MPI.PROC_NULL
    scatter = Call('scatter%s' % key, [cast(bufs)] + sizes + [f] + ofss)
    scatter = Conditional(CondNe(fromrank, Macro('MPI_PROC_NULL')), scatter)

    rrecv = Byref(FieldFromComposite(msg._C_field_rrecv, msgi))
    waitrecv = Call('MPI_Wait', [rrecv, Macro('MPI_STATUS_IGNORE')])
    rsend = Byref(FieldFromComposite(msg._C_field_rsend, msgi))
    waitsend = Call('MPI_Wait', [rsend, Macro('MPI_STATUS_IGNORE')])

    # The -1 below is because an Iteration, by default, generates <=
    ncomms = Symbol(name='ncomms')
    iet = Iteration([waitsend, waitrecv, scatter], dim, ncomms - 1)

    parameters = ([f] + list(fixed.values()) + [msg, ncomms])

    return Callable('halowait%d' % key, iet, 'void', parameters, ('static',))
def _make_wait(self, f, hse, key, msg=None):
    """
    Construct a Callable `wait_<key>` completing a previously started halo
    exchange: wait on the pending send/recv requests stored in `msg`, then
    scatter the received buffer into `f`'s halo region.
    """
    fromrank = Symbol(name='fromrank')
    ofss = [Symbol(name='os%s' % d.root) for d in f.dimensions]

    # Per-dimension buffer sizes, read off the `msg` struct
    sizes = []
    for i in range(len(f._dist_dimensions)):
        sizes.append(FieldFromPointer('%s[%d]' % (msg._C_field_sizes, i), msg))

    # Unpack the receive buffer into `f`. The scatter is skipped along the
    # domain boundary, where the sender is MPI.PROC_NULL and the halo values
    # must be left untouched
    bufs = FieldFromPointer(msg._C_field_bufs, msg)
    scatter = Conditional(
        CondNe(fromrank, Macro('MPI_PROC_NULL')),
        Call('scatter_%s' % key, [bufs] + sizes + [f] + ofss)
    )

    # Wait for completion of both the send and the receive
    waitsend = Call('MPI_Wait',
                    [Byref(FieldFromPointer(msg._C_field_rsend, msg)),
                     Macro('MPI_STATUS_IGNORE')])
    waitrecv = Call('MPI_Wait',
                    [Byref(FieldFromPointer(msg._C_field_rrecv, msg)),
                     Macro('MPI_STATUS_IGNORE')])

    body = List(body=[waitsend, waitrecv, scatter])
    params = [f] + ofss + [fromrank, msg]

    return Callable('wait_%s' % key, body, 'void', params, ('static', ))
def _make_sendrecv(self, f, hse, key, msg=None):
    """
    Build a SendRecv routine that starts (but does not complete) a halo
    exchange: gather into the `msg` send buffer and issue the non-blocking
    Isend/Irecv pair. Completion is handled elsewhere (see the wait routine).
    """
    comm = f.grid.distributor._obj_comm

    # Buffers live inside the `msg` struct, not as local Arrays
    bufg = FieldFromPointer(msg._C_field_bufg, msg)
    bufs = FieldFromPointer(msg._C_field_bufs, msg)

    ofsg = [Symbol(name='og%s' % d.root) for d in f.dimensions]

    fromrank = Symbol(name='fromrank')
    torank = Symbol(name='torank')

    sizes = [FieldFromPointer('%s[%d]' % (msg._C_field_sizes, i), msg)
             for i in range(len(f._dist_dimensions))]

    gather = Call('gather%s' % key, [bufg] + sizes + [f] + ofsg)
    # The `gather` is unnecessary if sending to MPI.PROC_NULL
    gather = Conditional(CondNe(torank, Macro('MPI_PROC_NULL')), gather)

    # Symbolic element count: product of the per-dimension sizes
    count = reduce(mul, sizes, 1)
    rrecv = Byref(FieldFromPointer(msg._C_field_rrecv, msg))
    rsend = Byref(FieldFromPointer(msg._C_field_rsend, msg))
    recv = IrecvCall([bufs, count, Macro(dtype_to_mpitype(f.dtype)),
                      fromrank, Integer(13), comm, rrecv])
    send = IsendCall([bufg, count, Macro(dtype_to_mpitype(f.dtype)),
                      torank, Integer(13), comm, rsend])

    # No waits here: this routine only initiates the communication
    iet = List(body=[recv, gather, send])

    parameters = ([f] + ofsg + [fromrank, torank, comm, msg])

    return SendRecv(key, iet, parameters, bufg, bufs)
def iet_make(stree):
    """
    Create an IET from a ScheduleTree.

    The tree is visited children-first; each node is lowered to its IET
    counterpart and accumulated in `queues`, keyed by parent node, until
    the root itself is reached, at which point the assembled body is
    returned.
    """
    nsections = 0
    queues = OrderedDict()
    for i in stree.visit():
        if i == stree:
            # We hit this handle at the very end of the visit
            return List(body=queues.pop(i))
        elif i.is_Exprs:
            exprs = [Increment(e) if e.is_Increment else Expression(e)
                     for e in i.exprs]
            body = ExpressionBundle(i.ispace, i.ops, i.traffic, body=exprs)
        elif i.is_Conditional:
            body = Conditional(i.guard, queues.pop(i))
        elif i.is_Iteration:
            # Order to ensure deterministic code generation
            uindices = sorted(i.sub_iterators, key=lambda d: d.name)
            # Generate Iteration
            body = Iteration(queues.pop(i), i.dim, i.limits, offsets=i.offsets,
                             direction=i.direction, properties=i.properties,
                             uindices=uindices)
        elif i.is_Section:
            body = Section('section%d' % nsections, body=queues.pop(i))
            nsections += 1
        elif i.is_Halo:
            body = HaloSpot(i.halo_scheme, body=queues.pop(i))
        queues.setdefault(i.parent, []).append(body)

    # `stree.visit()` yields the root last, so the loop must return above.
    # An explicit raise (rather than `assert False`) keeps this invariant
    # enforced even under `python -O`, where asserts are stripped and the
    # function would otherwise silently return None
    raise AssertionError("malformed ScheduleTree: root never visited")
def iet_build(stree):
    """
    Construct an Iteration/Expression tree (IET) from a ScheduleTree.

    The tree is visited children-first; each node is lowered to its IET
    counterpart and accumulated in `queues`, keyed by parent node, until
    the root itself is reached.
    """
    nsections = 0
    queues = OrderedDict()
    for i in stree.visit():
        if i == stree:
            # We hit this handle at the very end of the visit
            return List(body=queues.pop(i))
        elif i.is_Exprs:
            exprs = [Increment(e) if e.is_Increment else Expression(e)
                     for e in i.exprs]
            body = ExpressionBundle(i.ispace, i.ops, i.traffic, body=exprs)
        elif i.is_Conditional:
            body = Conditional(i.guard, queues.pop(i))
        elif i.is_Iteration:
            body = Iteration(queues.pop(i), i.dim, i.limits,
                             direction=i.direction, properties=i.properties,
                             uindices=i.sub_iterators)
        elif i.is_Section:
            body = Section('section%d' % nsections, body=queues.pop(i))
            nsections += 1
        elif i.is_Halo:
            body = HaloSpot(i.halo_scheme, body=queues.pop(i))
        queues.setdefault(i.parent, []).append(body)

    # `stree.visit()` yields the root last, so the loop must return above.
    # An explicit raise (rather than `assert False`) keeps this invariant
    # enforced even under `python -O`, where asserts are stripped and the
    # function would otherwise silently return None
    raise AssertionError("malformed ScheduleTree: root never visited")
def iet_make(stree):
    """
    Create an Iteration/Expression tree (IET) from a :class:`ScheduleTree`.

    The tree is visited children-first; each node is lowered to a list of
    IET nodes and accumulated in `queues`, keyed by parent node, until the
    root itself is reached.
    """
    nsections = 0
    queues = OrderedDict()
    for i in stree.visit():
        if i == stree:
            # We hit this handle at the very end of the visit
            return List(body=queues.pop(i))
        elif i.is_Exprs:
            exprs = [Expression(e) for e in i.exprs]
            body = [ExpressionBundle(i.shape, i.ops, i.traffic, body=exprs)]
        elif i.is_Conditional:
            body = [Conditional(i.guard, queues.pop(i))]
        elif i.is_Iteration:
            # Order to ensure deterministic code generation
            uindices = sorted(i.sub_iterators, key=lambda d: d.name)
            # Generate Iteration
            body = [Iteration(queues.pop(i), i.dim, i.dim.limits,
                              offsets=i.limits, direction=i.direction,
                              uindices=uindices)]
        elif i.is_Section:
            body = [Section('section%d' % nsections, body=queues.pop(i))]
            nsections += 1
        elif i.is_Halo:
            body = [HaloSpot(i.halo_scheme, body=queues.pop(i))]
        queues.setdefault(i.parent, []).extend(body)

    # `stree.visit()` yields the root last, so the loop must return above.
    # An explicit raise (rather than `assert False`) keeps this invariant
    # enforced even under `python -O`, where asserts are stripped and the
    # function would otherwise silently return None
    raise AssertionError("malformed ScheduleTree: root never visited")
def _(iet):
    """
    Prepend OpenACC+MPI device-setup code to `iet`: initialize the NVIDIA
    device and select which GPU each MPI rank uses, either from an explicit
    `deviceid` or round-robin by rank.
    """
    # TODO: we need to pick the rank from `comm_shm`, not `comm`,
    # so that we have nranks == ngpus (as long as the user has launched
    # the right number of MPI processes per node given the available
    # number of GPUs per node)

    # Look for an MPI communicator among the IET parameters; if absent,
    # the MPI-based device selection is skipped entirely
    objcomm = None
    for i in iet.parameters:
        if isinstance(i, MPICommObject):
            objcomm = i
            break

    deviceid = DeviceID()
    device_nvidia = Macro('acc_device_nvidia')
    if objcomm is not None:
        rank = Symbol(name='rank')
        rank_decl = LocalExpression(DummyEq(rank, 0))
        rank_init = Call('MPI_Comm_rank', [objcomm, Byref(rank)])

        ngpus = Symbol(name='ngpus')
        call = DefFunction('acc_get_num_devices', device_nvidia)
        ngpus_init = LocalExpression(DummyEq(ngpus, call))

        # If `deviceid` was given (!= -1), honour it; otherwise map
        # rank -> GPU round-robin via `rank % ngpus`
        asdn_then = Call('acc_set_device_num', [deviceid, device_nvidia])
        asdn_else = Call('acc_set_device_num', [rank % ngpus, device_nvidia])

        body = [Call('acc_init', [device_nvidia]),
                Conditional(
                    CondNe(deviceid, -1),
                    asdn_then,
                    List(body=[rank_decl, rank_init, ngpus_init, asdn_else])
                )]
    else:
        body = [Call('acc_init', [device_nvidia]),
                Conditional(
                    CondNe(deviceid, -1),
                    Call('acc_set_device_num', [deviceid, device_nvidia])
                )]

    init = List(header=c.Comment('Begin of OpenACC+MPI setup'),
                body=body,
                footer=(c.Comment('End of OpenACC+MPI setup'), c.Line()))
    iet = iet._rebuild(body=(init, ) + iet.body)

    return iet, {'args': deviceid}
def _make_haloupdate(self, f, hse, key, msg=None):
    """
    Build a Callable `haloupdate<key>` that loops over the `ncomms` entries
    of the `msg` array and, for each one, starts a non-blocking halo
    exchange: gather into the send buffer, then MPI_Irecv/MPI_Isend.
    """
    comm = f.grid.distributor._obj_comm

    fixed = {d: Symbol(name="o%s" % d.root) for d in hse.loc_indices}

    # Loop index over the entries of `msg`
    dim = Dimension(name='i')
    msgi = IndexedPointer(msg, dim)

    bufg = FieldFromComposite(msg._C_field_bufg, msgi)
    bufs = FieldFromComposite(msg._C_field_bufs, msgi)

    fromrank = FieldFromComposite(msg._C_field_from, msgi)
    torank = FieldFromComposite(msg._C_field_to, msgi)

    sizes = [FieldFromComposite('%s[%d]' % (msg._C_field_sizes, i), msgi)
             for i in range(len(f._dist_dimensions))]
    ofsg = [FieldFromComposite('%s[%d]' % (msg._C_field_ofsg, i), msgi)
            for i in range(len(f._dist_dimensions))]
    # Fixed (loc_indices) Dimensions take their offset Symbol; the remaining
    # Dimensions consume the offsets stored in `msg`, in order
    ofsg = [fixed.get(d) or ofsg.pop(0) for d in f.dimensions]

    # The `gather` is unnecessary if sending to MPI.PROC_NULL
    gather = Call('gather_%s' % key, [bufg] + sizes + [f] + ofsg)
    gather = Conditional(CondNe(torank, Macro('MPI_PROC_NULL')), gather)

    # Make Irecv/Isend
    count = reduce(mul, sizes, 1)
    rrecv = Byref(FieldFromComposite(msg._C_field_rrecv, msgi))
    rsend = Byref(FieldFromComposite(msg._C_field_rsend, msgi))
    recv = Call('MPI_Irecv', [bufs, count, Macro(dtype_to_mpitype(f.dtype)),
                              fromrank, Integer(13), comm, rrecv])
    send = Call('MPI_Isend', [bufg, count, Macro(dtype_to_mpitype(f.dtype)),
                              torank, Integer(13), comm, rsend])

    # The -1 below is because an Iteration, by default, generates <=
    ncomms = Symbol(name='ncomms')
    iet = Iteration([recv, gather, send], dim, ncomms - 1)

    parameters = ([f, comm, msg, ncomms]) + list(fixed.values())

    return Callable('haloupdate%d' % key, iet, 'void', parameters,
                    ('static', ))
def sendrecv(f, fixed):
    """Construct an IET performing a halo exchange along arbitrary dimension
    and side."""
    assert f.is_Function
    assert f.grid is not None
    comm = f.grid.distributor._C_comm

    # Send/recv buffers span all non-fixed Dimensions
    buf_dims = [Dimension(name='buf_%s' % d.root) for d in f.dimensions
                if d not in fixed]
    bufg = Array(name='bufg', dimensions=buf_dims, dtype=f.dtype,
                 scope='heap')
    bufs = Array(name='bufs', dimensions=buf_dims, dtype=f.dtype,
                 scope='heap')

    # `dat` is the externally-provided view of `f`'s data
    dat_dims = [Dimension(name='dat_%s' % d.root) for d in f.dimensions]
    dat = Array(name='dat', dimensions=dat_dims, dtype=f.dtype,
                scope='external')

    ofsg = [Symbol(name='og%s' % d.root) for d in f.dimensions]
    ofss = [Symbol(name='os%s' % d.root) for d in f.dimensions]

    fromrank = Symbol(name='fromrank')
    torank = Symbol(name='torank')

    parameters = [bufg] + list(bufg.shape) + [dat] + list(dat.shape) + ofsg
    gather = Call('gather_%s' % f.name, parameters)
    parameters = [bufs] + list(bufs.shape) + [dat] + list(dat.shape) + ofss
    scatter = Call('scatter_%s' % f.name, parameters)

    # The scatter must be guarded as we must not alter the halo values along
    # the domain boundary, where the sender is actually MPI.PROC_NULL
    scatter = Conditional(CondNe(fromrank, Macro('MPI_PROC_NULL')), scatter)

    srecv = MPIStatusObject(name='srecv')
    rrecv = MPIRequestObject(name='rrecv')
    rsend = MPIRequestObject(name='rsend')

    # NOTE(review): the tag is passed as the string '13' here, whereas sibling
    # routines use Integer(13) — confirm both lower to the same C token
    count = reduce(mul, bufs.shape, 1)
    recv = Call('MPI_Irecv', [bufs, count, Macro(numpy_to_mpitypes(f.dtype)),
                              fromrank, '13', comm, rrecv])
    send = Call('MPI_Isend', [bufg, count, Macro(numpy_to_mpitypes(f.dtype)),
                              torank, '13', comm, rsend])

    # NOTE(review): waitrecv takes a real status object while waitsend uses
    # MPI_STATUS_IGNORE — looks asymmetric; verify this is intentional
    waitrecv = Call('MPI_Wait', [rrecv, srecv])
    waitsend = Call('MPI_Wait', [rsend, Macro('MPI_STATUS_IGNORE')])

    iet = List(body=[recv, gather, send, waitsend, waitrecv, scatter])
    iet = List(body=[ArrayCast(dat), iet_insert_C_decls(iet)])
    parameters = ([dat] + list(dat.shape) + list(bufs.shape) + ofsg + ofss +
                  [fromrank, torank, comm])
    return Callable('sendrecv_%s' % f.name, iet, 'void', parameters,
                    ('static',))
def __make_tfunc(self, name, iet, root, threads):
    """
    Wrap `iet` into a thread function Callable: the thread spins on the
    SharedData flag, runs `iet` when activated (flag == 2), then resets the
    flag to 1. Returns the Callable and the SharedData object.
    """
    # Create the SharedData
    required = derive_parameters(iet)
    known = (root.parameters +
             tuple(i for i in required if i.is_Array and i._mem_shared))
    parameters, dynamic_parameters = split(required, lambda i: i in known)

    sdata = SharedData(name=self.sregistry.make_name(prefix='sdata'),
                       nthreads_std=threads.size, fields=dynamic_parameters)
    parameters.append(sdata)

    # Prepend the unwinded SharedData fields, available upon thread activation
    preactions = [DummyExpr(i, FieldFromPointer(i.name, sdata.symbolic_base))
                  for i in dynamic_parameters]
    preactions.append(DummyExpr(sdata.symbolic_id,
                                FieldFromPointer(sdata._field_id,
                                                 sdata.symbolic_base)))

    # Append the flag reset
    postactions = [List(body=[
        BlankLine,
        DummyExpr(FieldFromPointer(sdata._field_flag, sdata.symbolic_base), 1)
    ])]

    iet = List(body=preactions + [iet] + postactions)

    # The thread has work to do when it receives the signal that all locks have
    # been set to 0 by the main thread
    iet = Conditional(
        CondEq(FieldFromPointer(sdata._field_flag, sdata.symbolic_base), 2),
        iet)

    # The thread keeps spinning until the alive flag is set to 0 by the main
    # thread
    iet = While(
        CondNe(FieldFromPointer(sdata._field_flag, sdata.symbolic_base), 0),
        iet)

    return Callable(name, iet, 'void', parameters, 'static'), sdata
def iet_make(stree):
    """
    Create an Iteration/Expression tree (IET) from a :class:`ScheduleTree`.

    The tree is visited children-first; each node is lowered to a list of
    IET nodes and accumulated in `queues`, keyed by parent node, until the
    root itself is reached.
    """
    nsections = 0
    queues = OrderedDict()
    for i in stree.visit():
        if i == stree:
            # We hit this handle at the very end of the visit
            return List(body=queues.pop(i))
        elif i.is_Exprs:
            exprs = [Expression(e) for e in i.exprs]
            body = [ExpressionBundle(i.shape, i.ops, i.traffic, body=exprs)]
        elif i.is_Conditional:
            body = [Conditional(i.guard, queues.pop(i))]
        elif i.is_Iteration:
            # Generate `uindices`
            uindices = []
            for d, offs in i.sub_iterators:
                # Sub-iterators cycle through `offs` modulo its length
                modulo = len(offs)
                for n, o in enumerate(filter_ordered(offs)):
                    value = (i.dim + o) % modulo
                    symbol = Scalar(name="%s%d" % (d.name, n), dtype=np.int32)
                    uindices.append(UnboundedIndex(symbol, value, value,
                                                   d, d + o))
            # Generate Iteration
            body = [Iteration(queues.pop(i), i.dim, i.dim.limits,
                              offsets=i.limits, direction=i.direction,
                              uindices=uindices)]
        elif i.is_Section:
            body = [Section('section%d' % nsections, body=queues.pop(i))]
            nsections += 1
        queues.setdefault(i.parent, []).extend(body)

    # `stree.visit()` yields the root last, so the loop must return above.
    # An explicit raise (rather than `assert False`) keeps this invariant
    # enforced even under `python -O`, where asserts are stripped and the
    # function would otherwise silently return None
    raise AssertionError("malformed ScheduleTree: root never visited")
def make_parallel(self, iet):
    """Transform ``iet`` by introducing shared-memory parallelism."""
    mapper = OrderedDict()
    for tree in retrieve_iteration_tree(iet):
        # Get the first omp-parallelizable Iteration in `tree`
        candidates = filter_iterations(tree, key=self.key, stop='asap')
        if not candidates:
            continue
        root = candidates[0]

        # Build the `omp-for` tree
        partree = self._make_parallel_tree(root, candidates)

        # Find out the thread-private and thread-shared variables
        private = [i for i in FindSymbols().visit(partree)
                   if i.is_Array and i._mem_stack]

        # Build the `omp-parallel` region
        private = sorted(set([i.name for i in private]))
        private = ('private(%s)' % ','.join(private)) if private else ''
        partree = Block(header=self.lang['par-region'](self.nthreads.name,
                                                       private),
                        body=partree)

        # Do not enter the parallel region if the step increment might be 0; this
        # would raise a `Floating point exception (core dumped)` in some OpenMP
        # implementation. Note that using an OpenMP `if` clause won't work
        if isinstance(root.step, Symbol):
            cond = Conditional(CondEq(root.step, 0),
                               Element(c.Statement('return')))
            partree = List(body=[cond, partree])

        mapper[root] = partree
    iet = Transformer(mapper).visit(iet)

    # `self.nthreads` becomes a runtime input only if something was parallelized
    return iet, {'input': [self.nthreads] if mapper else []}
def _make_fetchprefetch(self, iet, sync_ops, pieces, root):
    """
    Lower fetch/prefetch SyncOps into: (i) an `init_device` Callable run by
    the main thread to stream the initial data to the device; (ii) a
    `prefetch_host_to_device` ThreadFunction for asynchronous prefetching;
    and (iii) the glue IET (busy-wait, present clauses, activation).
    """
    fid = SharedData._field_id

    fetches = []
    prefetches = []
    presents = []
    for s in sync_ops:
        f = s.function
        dimensions = s.dimensions
        fc = s.fetch
        ifc = s.ifetch
        pfc = s.pfetch
        fcond = s.fcond
        pcond = s.pcond

        # Construct init IET
        imask = [(ifc, s.size) if d.root is s.dim.root else FULL
                 for d in dimensions]
        fetch = PragmaTransfer(self.lang._map_to, f, imask=imask)
        fetches.append(Conditional(fcond, fetch))

        # Construct present clauses
        imask = [(fc, s.size) if d.root is s.dim.root else FULL
                 for d in dimensions]
        presents.append(PragmaTransfer(self.lang._map_present, f,
                                       imask=imask))

        # Construct prefetch IET
        imask = [(pfc, s.size) if d.root is s.dim.root else FULL
                 for d in dimensions]
        prefetch = PragmaTransfer(self.lang._map_to_wait, f, imask=imask,
                                  queueid=fid)
        prefetches.append(Conditional(pcond, prefetch))

    # Turn init IET into a Callable
    functions = filter_ordered(s.function for s in sync_ops)
    name = self.sregistry.make_name(prefix='init_device')
    body = List(body=fetches)
    parameters = filter_sorted(functions + derive_parameters(body))
    func = Callable(name, body, 'void', parameters, 'static')
    pieces.funcs.append(func)

    # Perform initial fetch by the main thread
    pieces.init.append(List(
        header=c.Comment("Initialize data stream"),
        body=[Call(name, parameters), BlankLine]
    ))

    # Turn prefetch IET into a ThreadFunction
    name = self.sregistry.make_name(prefix='prefetch_host_to_device')
    body = List(header=c.Line(), body=prefetches)
    tctx = make_thread_ctx(name, body, root, None, sync_ops, self.sregistry)
    pieces.funcs.extend(tctx.funcs)

    # Glue together all the IET pieces, including the activation logic
    sdata = tctx.sdata
    threads = tctx.threads
    iet = List(body=[
        BlankLine,
        # Wait until the prefetch thread has flagged completion (flag == 1)
        BusyWait(CondNe(FieldFromComposite(sdata._field_flag,
                                           sdata[threads.index]), 1))
    ] + presents + [iet, tctx.activate])

    # Fire up the threads
    pieces.init.append(tctx.init)

    # Final wait before jumping back to Python land
    pieces.finalize.append(tctx.finalize)

    # Keep track of created objects
    pieces.objs.add(sync_ops, sdata, threads)

    return iet
def iet_make(clusters): """ Create an Iteration/Expression tree (IET) given an iterable of :class:`Cluster`s. :param clusters: The iterable :class:`Cluster`s for which the IET is built. """ # {Iteration -> [c0, c1, ...]}, shared clusters shared = {} # The constructed IET processed = [] # {Interval -> Iteration}, carried from preceding cluster schedule = OrderedDict() # Build IET for cluster in clusters: body = [Expression(e) for e in cluster.exprs] if cluster.ispace.empty: # No Iterations are needed processed.extend(body) continue root = None itintervals = cluster.ispace.iteration_intervals # Can I reuse any of the previously scheduled Iterations ? index = 0 for i0, i1 in zip(itintervals, list(schedule)): if i0 != i1 or i0.dim in cluster.atomics: break root = schedule[i1] index += 1 needed = itintervals[index:] # Build Expressions if not needed: body = List(body=body) # Build Iterations scheduling = [] for i in reversed(needed): # Update IET and scheduling if i.dim in cluster.guards: # Must wrap within an if-then scope body = Conditional(cluster.guards[i.dim], body) # Adding (None, None) ensures that nested iterations won't # be reused by the next cluster scheduling.insert(0, (None, None)) iteration = Iteration(body, i.dim, i.dim.limits, offsets=i.limits, direction=i.direction) scheduling.insert(0, (i, iteration)) # Prepare for next dimension body = iteration # If /needed/ is != [], root.dim might be a guarded dimension for /cluster/ if root is not None and root.dim in cluster.guards: body = Conditional(cluster.guards[root.dim], body) # Update the current schedule if root is None: processed.append(body) else: nodes = list(root.nodes) + [body] transf = Transformer( {root: root._rebuild(nodes, **root.args_frozen)}) processed = list(transf.visit(processed)) scheduling = list(schedule.items())[:index] + list(scheduling) scheduling = [(k, transf.rebuilt.get(v, v)) for k, v in scheduling] shared = {transf.rebuilt.get(k, k): v for k, v in shared.items()} schedule = 
OrderedDict(scheduling) # Record that /cluster/ was used to build the iterations in /schedule/ shared.update( {i: shared.get(i, []) + [cluster] for i in schedule.values() if i}) iet = List(body=processed) # Add in unbounded indices, if needed mapper = {} for k, v in shared.items(): uindices = [] ispace = IterationSpace.merge(*[i.ispace.project([k.dim]) for i in v]) for j, offs in ispace.sub_iterators.get(k.dim, []): modulo = len(offs) for n, o in enumerate(filter_ordered(offs)): name = "%s%d" % (j.name, n) vname = Scalar(name=name, dtype=np.int32) value = (k.dim + o) % modulo uindices.append(UnboundedIndex(vname, value, value, j, j + o)) mapper[k] = k._rebuild(uindices=uindices) iet = NestedTransformer(mapper).visit(iet) return iet
def block4(exprs, iters, dims):
    """
    Build a non-perfect loop nest: an `i`-loop whose body is a `j`-loop
    guarded by the condition `i % 2 == 0`.
    """
    # Non-perfect loop nest due to conditional
    # for i
    #   if i % 2 == 0
    #     for j
    guard = Eq(Mod(dims['i'], 2), 0)
    inner = iters[1](exprs[0])
    return iters[0](Conditional(guard, inner))
def iet_make(clusters, dtype):
    """
    Create an Iteration/Expression tree (IET) given an iterable of
    :class:`Cluster`s.

    :param clusters: The iterable :class:`Cluster`s for which the IET is built.
    :param dtype: The data type of the scalar expressions.
    """
    processed = []
    schedule = OrderedDict()
    for cluster in clusters:
        if not cluster.ispace.empty:
            root = None
            intervals = cluster.ispace.intervals

            # Can I reuse any of the previously scheduled Iterations ?
            index = 0
            for i0, i1 in zip(intervals, list(schedule)):
                if i0 != i1 or i0.dim in cluster.atomics:
                    break
                root = schedule[i1]
                index += 1
            needed = intervals[index:]

            # Build Expressions; index expressions are lowered as int32
            body = [Expression(e, np.int32 if cluster.trace.is_index(e.lhs)
                               else dtype)
                    for e in cluster.exprs]
            if not needed:
                body = List(body=body)

            # Build Iterations, innermost first
            scheduling = []
            for i in reversed(needed):
                # Prepare any necessary unbounded index
                uindices = []
                for j, offs in cluster.ispace.sub_iterators.get(i.dim, []):
                    # Sub-iterators cycle through `offs` modulo its length
                    modulo = len(offs)
                    for n, o in enumerate(filter_ordered(offs)):
                        name = "%s%d" % (j.name, n)
                        vname = Scalar(name=name, dtype=np.int32)
                        value = (i.dim + o) % modulo
                        uindices.append(UnboundedIndex(vname, value, value,
                                                       j, j + o))

                # Retrieve the iteration direction
                direction = cluster.ispace.directions[i.dim]

                # Update IET and scheduling
                if i.dim in cluster.guards:
                    # Must wrap within an if-then scope
                    body = Conditional(cluster.guards[i.dim], body)
                    iteration = Iteration(body, i.dim, i.dim.limits,
                                          offsets=i.limits,
                                          direction=direction,
                                          uindices=uindices)
                    # Adding (None, None) ensures that nested iterations won't
                    # be reused by the next cluster
                    scheduling.extend([(None, None), (i, iteration)])
                else:
                    iteration = Iteration(body, i.dim, i.dim.limits,
                                          offsets=i.limits,
                                          direction=direction,
                                          uindices=uindices)
                    scheduling.append((i, iteration))

                # Prepare for next dimension
                body = iteration

            # If /needed/ is != [], root.dim might be a guarded dimension for
            # /cluster/
            if root is not None and root.dim in cluster.guards:
                body = Conditional(cluster.guards[root.dim], body)

            # Update the current schedule
            scheduling = OrderedDict(reversed(scheduling))
            if root is None:
                processed.append(body)
                schedule = scheduling
            else:
                # Attach `body` under the reused Iteration `root`, rebuilding
                # (and hence remapping) every node on the path
                nodes = list(root.nodes) + [body]
                mapper = {root: root._rebuild(nodes, **root.args_frozen)}
                transformer = Transformer(mapper)
                processed = list(transformer.visit(processed))
                schedule = OrderedDict(list(schedule.items())[:index] +
                                       list(scheduling.items()))
                for k, v in list(schedule.items()):
                    schedule[k] = transformer.rebuilt.get(v, v)
        else:
            # No Iterations are needed
            processed.extend([Expression(e, dtype) for e in cluster.exprs])

    return List(body=processed)
def test_conditional(self, fc):
    # Build an if-then-else IET node and compare its generated code against
    # the expected C source (string literal continues beyond this chunk)
    then_body = Expression(DummyEq(fc[x, y], fc[x, y] + 1))
    else_body = Expression(DummyEq(fc[x, y], fc[x, y] + 2))
    conditional = Conditional(x < 3, then_body, else_body)
    assert str(conditional) == """\
def test_conditional(self, fc, grid):
    # Build an if-then-else IET node and compare its generated code against
    # the expected C source (string literal continues beyond this chunk)
    x, y, _ = grid.dimensions
    then_body = Expression(DummyEq(fc[x, y], fc[x, y] + 1))
    else_body = Expression(DummyEq(fc[x, y], fc[x, y] + 2))
    conditional = Conditional(x < 3, then_body, else_body)
    assert str(conditional) == """\
def _make_fetchwaitprefetch(self, iet, sync_ops, pieces, root):
    """
    Lower fetch/prefetch SyncOps into: (i) an `init_device` Callable run by
    the main thread to stream the initial data to the device; (ii) a
    `prefetch_host_to_device` threaded Callable for asynchronous
    prefetching; and (iii) the glue IET (busy-wait, present clauses,
    thread activation).
    """
    threads = self.__make_threads()

    fetches = []
    prefetches = []
    presents = []
    for s in sync_ops:
        if s.direction is Forward:
            # Forward: initial fetch at the iteration minimum, prefetch one
            # step ahead; both guarded against running off the field end
            fc = s.fetch.subs(s.dim, s.dim.symbolic_min)
            fsize = s.function._C_get_field(FULL, s.dim).size
            fc_cond = fc + (s.size - 1) < fsize
            pfc = s.fetch + 1
            pfc_cond = pfc + (s.size - 1) < fsize
        else:
            # Backward: initial fetch at the iteration maximum, prefetch one
            # step behind; both guarded against negative indices
            fc = s.fetch.subs(s.dim, s.dim.symbolic_max)
            fc_cond = fc >= 0
            pfc = s.fetch - 1
            pfc_cond = pfc >= 0

        # Construct fetch IET
        imask = [(fc, s.size) if d.root is s.dim.root else FULL
                 for d in s.dimensions]
        fetch = List(header=self._P._map_to(s.function, imask))
        fetches.append(Conditional(fc_cond, fetch))

        # Construct present clauses
        imask = [(s.fetch, s.size) if d.root is s.dim.root else FULL
                 for d in s.dimensions]
        presents.extend(as_list(self._P._map_present(s.function, imask)))

        # Construct prefetch IET
        imask = [(pfc, s.size) if d.root is s.dim.root else FULL
                 for d in s.dimensions]
        prefetch = List(header=self._P._map_to_wait(s.function, imask,
                                                    SharedData._field_id))
        prefetches.append(Conditional(pfc_cond, prefetch))

    functions = filter_ordered(s.function for s in sync_ops)
    casts = [PointerCast(f) for f in functions]

    # Turn init IET into a Callable
    name = self.sregistry.make_name(prefix='init_device')
    body = List(body=casts + fetches)
    parameters = filter_sorted(functions + derive_parameters(body))
    func = Callable(name, body, 'void', parameters, 'static')
    pieces.funcs.append(func)

    # Perform initial fetch by the main thread
    pieces.init.append(List(
        header=c.Comment("Initialize data stream for `%s`" % threads.name),
        body=[Call(name, func.parameters), BlankLine]
    ))

    # Turn prefetch IET into a threaded Callable
    name = self.sregistry.make_name(prefix='prefetch_host_to_device')
    body = List(header=c.Line(), body=casts + prefetches)
    tfunc, sdata = self.__make_tfunc(name, body, root, threads)
    pieces.funcs.append(tfunc)

    # Glue together all the IET pieces, including the activation bits
    iet = List(body=[
        BlankLine,
        # Wait until the prefetch thread has flagged completion (flag == 1)
        BusyWait(CondNe(FieldFromComposite(sdata._field_flag,
                                           sdata[threads.index]), 1)),
        List(header=presents),
        iet,
        self.__make_activate_thread(threads, sdata, sync_ops)
    ])

    # Fire up the threads
    pieces.init.append(self.__make_init_threads(threads, sdata, tfunc,
                                                pieces))
    pieces.threads.append(threads)

    # Final wait before jumping back to Python land
    pieces.finalize.append(self.__make_finalize_threads(threads, sdata))

    return iet
def test_conditional(fc, grid):
    # Build an if-then-else IET node and compare its generated code against
    # the expected C source (string literal continues beyond this chunk)
    x, y, _ = grid.dimensions
    then_body = DummyExpr(fc[x, y], fc[x, y] + 1)
    else_body = DummyExpr(fc[x, y], fc[x, y] + 2)
    conditional = Conditional(x < 3, then_body, else_body)
    assert str(conditional) == """\
def _make_fetchwaitprefetch(self, iet, sync_ops, pieces, root):
    """
    Lower fetch/prefetch SyncOps into: (i) an `init_device` Callable run by
    the main thread to stream the initial data to the device; (ii) a
    `prefetch_host_to_device` ThreadFunction for asynchronous prefetching;
    and (iii) the glue IET (busy-wait, present clauses, activation).
    """
    fetches = []
    prefetches = []
    presents = []
    for s in sync_ops:
        if s.direction is Forward:
            # Forward: initial fetch at the iteration minimum, prefetch one
            # step ahead; validity of each index is delegated to `next_cbk`
            fc = s.fetch.subs(s.dim, s.dim.symbolic_min)
            pfc = s.fetch + 1
            fc_cond = s.next_cbk(s.dim.symbolic_min)
            pfc_cond = s.next_cbk(s.dim + 1)
        else:
            # Backward: initial fetch at the iteration maximum, prefetch one
            # step behind
            fc = s.fetch.subs(s.dim, s.dim.symbolic_max)
            pfc = s.fetch - 1
            fc_cond = s.next_cbk(s.dim.symbolic_max)
            pfc_cond = s.next_cbk(s.dim - 1)

        # Construct init IET
        imask = [(fc, s.size) if d.root is s.dim.root else FULL
                 for d in s.dimensions]
        fetch = PragmaList(self.lang._map_to(s.function, imask),
                           {s.function} | fc.free_symbols)
        fetches.append(Conditional(fc_cond, fetch))

        # Construct present clauses
        imask = [(s.fetch, s.size) if d.root is s.dim.root else FULL
                 for d in s.dimensions]
        presents.extend(as_list(self.lang._map_present(s.function, imask)))

        # Construct prefetch IET
        imask = [(pfc, s.size) if d.root is s.dim.root else FULL
                 for d in s.dimensions]
        prefetch = PragmaList(self.lang._map_to_wait(s.function, imask,
                                                     SharedData._field_id),
                              {s.function} | pfc.free_symbols)
        prefetches.append(Conditional(pfc_cond, prefetch))

    # Turn init IET into a Callable
    functions = filter_ordered(s.function for s in sync_ops)
    name = self.sregistry.make_name(prefix='init_device')
    body = List(body=fetches)
    parameters = filter_sorted(functions + derive_parameters(body))
    func = Callable(name, body, 'void', parameters, 'static')
    pieces.funcs.append(func)

    # Perform initial fetch by the main thread
    pieces.init.append(List(
        header=c.Comment("Initialize data stream"),
        body=[Call(name, parameters), BlankLine]
    ))

    # Turn prefetch IET into a ThreadFunction
    name = self.sregistry.make_name(prefix='prefetch_host_to_device')
    body = List(header=c.Line(), body=prefetches)
    tctx = make_thread_ctx(name, body, root, None, sync_ops, self.sregistry)
    pieces.funcs.extend(tctx.funcs)

    # Glue together all the IET pieces, including the activation logic
    sdata = tctx.sdata
    threads = tctx.threads
    iet = List(body=[
        BlankLine,
        # Wait until the prefetch thread has flagged completion (flag == 1)
        BusyWait(CondNe(FieldFromComposite(sdata._field_flag,
                                           sdata[threads.index]), 1)),
        List(header=presents),
        iet,
        tctx.activate
    ])

    # Fire up the threads
    pieces.init.append(tctx.init)
    pieces.threads.append(threads)

    # Final wait before jumping back to Python land
    pieces.finalize.append(tctx.finalize)

    return iet
def _make_sendrecv(self, f, fixed, extra=None): extra = extra or [] comm = f.grid.distributor._obj_comm buf_dims = [ Dimension(name='buf_%s' % d.root) for d in f.dimensions if d not in fixed ] bufg = Array(name='bufg', dimensions=buf_dims, dtype=f.dtype, scope='heap') bufs = Array(name='bufs', dimensions=buf_dims, dtype=f.dtype, scope='heap') ofsg = [Symbol(name='og%s' % d.root) for d in f.dimensions] ofss = [Symbol(name='os%s' % d.root) for d in f.dimensions] fromrank = Symbol(name='fromrank') torank = Symbol(name='torank') args = [bufg] + list(bufg.shape) + [f] + ofsg + extra gather = Call('gather%dd' % f.ndim, args) args = [bufs] + list(bufs.shape) + [f] + ofss + extra scatter = Call('scatter%dd' % f.ndim, args) # The `gather` is unnecessary if sending to MPI.PROC_NULL gather = Conditional(CondNe(torank, Macro('MPI_PROC_NULL')), gather) # The `scatter` must be guarded as we must not alter the halo values along # the domain boundary, where the sender is actually MPI.PROC_NULL scatter = Conditional(CondNe(fromrank, Macro('MPI_PROC_NULL')), scatter) srecv = MPIStatusObject(name='srecv') ssend = MPIStatusObject(name='ssend') rrecv = MPIRequestObject(name='rrecv') rsend = MPIRequestObject(name='rsend') count = reduce(mul, bufs.shape, 1) recv = Call('MPI_Irecv', [ bufs, count, Macro(dtype_to_mpitype(f.dtype)), fromrank, Integer(13), comm, rrecv ]) send = Call('MPI_Isend', [ bufg, count, Macro(dtype_to_mpitype(f.dtype)), torank, Integer(13), comm, rsend ]) waitrecv = Call('MPI_Wait', [rrecv, srecv]) waitsend = Call('MPI_Wait', [rsend, ssend]) iet = List(body=[recv, gather, send, waitsend, waitrecv, scatter]) iet = List(body=iet_insert_C_decls(iet)) parameters = ([f] + list(bufs.shape) + ofsg + ofss + [fromrank, torank, comm] + extra) return Callable('sendrecv%dd' % f.ndim, iet, 'void', parameters, ('static', ))