def _(iet):
    """
    Prepend the device setup code to ``iet``.

    The emitted code runs the language initialization (when the target
    language provides one) and then selects the device. The user-supplied
    device ID is used when it differs from -1. Otherwise, if ``iet`` takes
    an MPI communicator among its parameters, the device is ``rank % ngpus``;
    without a communicator no fallback is emitted.

    ``self`` is not a parameter here; it is resolved from the enclosing scope.

    Returns
    -------
    tuple
        The rebuilt ``iet`` and a dict carrying the new `DeviceID` under
        the ``'args'`` key.
    """
    # TODO: we need to pick the rank from `comm_shm`, not `comm`,
    # so that we have nranks == ngpus (as long as the user has launched
    # the right number of MPI processes per node given the available
    # number of GPUs per node)
    comm = next((p for p in iet.parameters if isinstance(p, MPICommObject)), None)

    devicetype = as_list(self.lang[self.platform])

    try:
        setup = [self.lang['init'](devicetype)]
    except TypeError:
        # Not all target languages need to be explicitly initialized
        setup = []

    deviceid = DeviceID()
    user_choice = CondNe(deviceid, -1)

    if comm is None:
        label = '%s setup' % self.lang['name']
        setup.append(Conditional(
            user_choice,
            self.lang['set-device']([deviceid] + devicetype)
        ))
    else:
        label = '%s+MPI setup' % self.lang['name']

        rank = Symbol(name='rank')
        declare_rank = LocalExpression(DummyEq(rank, 0))
        query_rank = Call('MPI_Comm_rank', [comm, Byref(rank)])

        ngpus = Symbol(name='ngpus')
        count_devices = LocalExpression(
            DummyEq(ngpus, self.lang['num-devices'](devicetype))
        )

        from_user = self.lang['set-device']([deviceid] + devicetype)
        from_rank = self.lang['set-device']([rank % ngpus] + devicetype)

        setup.append(Conditional(
            user_choice,
            from_user,
            List(body=[declare_rank, query_rank, count_devices, from_rank]),
        ))

    init = List(
        header=c.Comment('Begin of %s' % label),
        body=setup,
        footer=(c.Comment('End of %s' % label), c.Line())
    )

    return iet._rebuild(body=(init,) + iet.body), {'args': deviceid}
def _make_thread_activate(threads, sdata, sync_ops, sregistry):
    """
    Build the IET that hands a unit of work over to a thread.

    The emitted code spins while any SyncLock handle among ``sync_ops``
    differs from 2 or the flag of the selected ``sdata`` entry differs from 1.
    With more than one thread, each spin iteration advances the entry index
    round-robin. Once the loop exits, the dynamic fields are copied into the
    entry, the lock handles are reset to 0 and the entry flag is set to 2.

    Returns
    -------
    List
        The activation code, wrapped between blank lines and preceded by an
        "Activate" comment.
    """
    nthreads = threads.size
    single = nthreads == 1

    # With a single thread the entry index is known; otherwise a fresh
    # symbol is used to search for an available entry
    if single:
        slot = threads.index
    else:
        slot = Symbol(name=sregistry.make_name(prefix=threads.index.name))

    locks = [op for op in sync_ops if op.is_SyncLock]

    busy = [CondNe(lock.handle, 2) for lock in locks]
    busy.append(CondNe(FieldFromComposite(sdata._field_flag, sdata[slot]), 1))
    condition = Or(*busy)

    if single:
        steps = [While(condition)]
    else:
        advance = DummyExpr(slot, (slot + 1) % nthreads)
        steps = [DummyExpr(slot, 0), While(condition, advance)]

    for field in sdata.dynamic_fields:
        steps.append(DummyExpr(FieldFromComposite(field.name, sdata[slot]), field))
    for lock in locks:
        steps.append(DummyExpr(lock.handle, 0))
    steps.append(DummyExpr(FieldFromComposite(sdata._field_flag, sdata[slot]), 2))

    return List(
        header=[c.Line(), c.Comment("Activate `%s`" % threads.name)],
        body=steps,
        footer=c.Line()
    )
def _make_sendrecv(self, f, hse, key, **kwargs):
    """
    Build the ``sendrecv_<key>`` Callable exchanging a halo region of ``f``.

    The generated routine posts a non-blocking receive, gathers the data to
    be sent (skipped when the destination is ``MPI_PROC_NULL``), posts a
    non-blocking send, waits for both requests and finally scatters the
    received data (skipped when the source is ``MPI_PROC_NULL``).

    The buffers span all dimensions of ``f`` except ``hse.loc_indices``.
    ``kwargs`` is accepted but not used here.
    """
    comm = f.grid.distributor._obj_comm
    mpitype = dtype_to_mpitype(f.dtype)
    tag = Integer(13)

    buf_dims = [Dimension(name='buf_%s' % d.root)
                for d in f.dimensions if d not in hse.loc_indices]
    bufg, bufs = (
        Array(name=label, dimensions=buf_dims, dtype=f.dtype, padding=0, scope='heap')
        for label in ('bufg', 'bufs')
    )

    ofsg = [Symbol(name='og%s' % d.root) for d in f.dimensions]
    ofss = [Symbol(name='os%s' % d.root) for d in f.dimensions]

    fromrank = Symbol(name='fromrank')
    torank = Symbol(name='torank')

    # The `gather` is unnecessary if sending to MPI.PROC_NULL
    gather = Conditional(
        CondNe(torank, Macro('MPI_PROC_NULL')),
        Call('gather_%s' % key, [bufg] + list(bufg.shape) + [f] + ofsg)
    )
    # The `scatter` must be guarded as we must not alter the halo values along
    # the domain boundary, where the sender is actually MPI.PROC_NULL
    scatter = Conditional(
        CondNe(fromrank, Macro('MPI_PROC_NULL')),
        Call('scatter_%s' % key, [bufs] + list(bufs.shape) + [f] + ofss)
    )

    count = reduce(mul, bufs.shape, 1)
    rrecv = MPIRequestObject(name='rrecv')
    rsend = MPIRequestObject(name='rsend')
    recv = Call('MPI_Irecv',
                [bufs, count, Macro(mpitype), fromrank, tag, comm, rrecv])
    send = Call('MPI_Isend',
                [bufg, count, Macro(mpitype), torank, tag, comm, rsend])

    waitrecv = Call('MPI_Wait', [rrecv, Macro('MPI_STATUS_IGNORE')])
    waitsend = Call('MPI_Wait', [rsend, Macro('MPI_STATUS_IGNORE')])

    iet = List(body=[recv, gather, send, waitsend, waitrecv, scatter])

    parameters = [f]
    parameters += list(bufs.shape)
    parameters += ofsg + ofss
    parameters += [fromrank, torank, comm]

    return Callable('sendrecv_%s' % key, iet, 'void', parameters, ('static',))
def _make_thread_func(name, iet, root, threads, sregistry):
    """
    Wrap ``iet`` into a function suitable for execution by a pthread.

    Three objects are produced:

      * a `ThreadFunction` that unpacks its single ``void*`` argument and
        then keeps spinning while the shared flag differs from 0, executing
        ``iet`` each time the flag equals 2 and setting it back to 1 afterwards;
      * a Callable initializing the shared data structure with the known
        constant values, the thread ID and a flag equal to 1;
      * the `SharedData` itself.

    Returns
    -------
    tuple
        ``(tfunc, isdata, sdata)``.
    """
    # Create the SharedData, that is the data structure that will be used by the
    # main thread to pass information down to the child thread(s)
    required, parameters, dynamic_parameters = diff_parameters(iet, root)
    parameters = sorted(parameters, key=lambda i: i.is_Function)  # Allow casting

    sdata = SharedData(name=sregistry.make_name(prefix='sdata'),
                       npthreads=threads.size,
                       fields=required,
                       dynamic_fields=dynamic_parameters)
    sbase = sdata.symbolic_base
    sid = sdata.symbolic_id

    def field(fname):
        # Access to a member of the shared data structure through its base
        return FieldFromPointer(fname, sbase)

    # Create a Callable to initialize `sdata` with the known const values
    ibody = [DummyExpr(field(p._C_name), p._C_symbol) for p in parameters]
    ibody += [
        BlankLine,
        DummyExpr(field(sdata._field_id), sid),
        DummyExpr(field(sdata._field_flag), 1),
    ]
    isdata = Callable('init_%s' % sdata.dtype._type_.__name__, ibody, 'void',
                      parameters + [sdata, sid], 'static')

    # Prepend the SharedData fields available upon thread activation
    work = [DummyExpr(p, field(p.name)) for p in dynamic_parameters]
    work.append(iet)
    # Append the flag reset
    work.append(List(body=[BlankLine, DummyExpr(field(sdata._field_flag), 1)]))
    iet = List(body=work)

    # The thread has work to do when it receives the signal that all locks have
    # been set to 0 by the main thread
    iet = Conditional(CondEq(field(sdata._field_flag), 2), iet)

    # The thread keeps spinning until the alive flag is set to 0 by the main thread
    iet = While(CondNe(field(sdata._field_flag), 0), iet)

    # pthread functions expect exactly one argument, a void*, and must return void*
    tretval = 'void*'
    tparameter = VoidPointer('_%s' % sdata.name)

    # Unpack `sdata`
    unpack = [PointerCast(sdata, tparameter), BlankLine]
    for p in parameters:
        if p.is_AbstractFunction:
            unpack += [Dereference(p, sdata), PointerCast(p)]
        else:
            unpack.append(DummyExpr(p, field(p.name)))
    unpack += [DummyExpr(sid, field(sdata._field_id)), BlankLine]

    iet = List(body=unpack + [iet, BlankLine, Return(Macro('NULL'))])

    tfunc = ThreadFunction(name, iet, tretval, tparameter, 'static')

    return tfunc, isdata, sdata
def guard(clusters):
    """
    Return a new :class:`ClusterGroup` including new :class:`PartialCluster`s
    for each conditional expression encountered in ``clusters``.

    For each input cluster, the conditional dimensions appearing among the
    sub-iterators of its expressions are collected. One
    :class:`PartialCluster` is then emitted per subset of these dimensions:
    the dimensions in the subset are guarded by ``parent % factor == 0``,
    those in the subset at the mirrored position of the powerset by
    ``parent % factor != 0`` (presumably the complement; this relies on the
    ordering produced by ``powerset``). Expressions attached to a conditional
    dimension outside the subset are dropped, and each conditional dimension
    is replaced by ``IntDiv(parent, factor)`` in the surviving expressions.
    """
    processed = ClusterGroup()
    for cluster in clusters:
        # Find out what expressions in `cluster` should be guarded
        guarded = {}
        for expr in cluster.exprs:
            for iterators in expr.ispace.sub_iterators.values():
                for it in iterators:
                    if it.dim.is_Conditional:
                        guarded.setdefault(it.dim, []).append(expr)

        # Build conditional expressions to guard clusters
        holds = {d: CondEq(d.parent % d.factor, 0) for d in guarded}
        fails = {d: CondNe(d.parent % d.factor, 0) for d in guarded}

        # The same substitution applies to every retained expression
        rebased = {d: IntDiv(d.parent, d.factor) for d in guarded}

        # Expand with guarded clusters
        subsets = list(powerset(guarded))
        for dims, ndims in zip(subsets, reversed(subsets)):
            banned = flatten(v for k, v in guarded.items() if k not in dims)
            exprs = [e.xreplace(rebased) for e in cluster.exprs if e not in banned]

            guards = {d.parent: holds[d] for d in dims}
            guards.update((d.parent, fails[d]) for d in ndims)

            processed.append(PartialCluster(exprs, cluster.ispace, cluster.dspace,
                                            cluster.atomics, guards))

    return processed
def _make_halowait(self, f, hse, key, msg=None):
    """
    Build the ``halowait<key>`` Callable completing a halo exchange of ``f``.

    The generated routine loops over the ``ncomms`` entries of ``msg``; for
    each entry it waits for the send and receive requests and then scatters
    the received buffer into ``f`` (skipped when the source rank is
    ``MPI_PROC_NULL``).

    Offsets along ``hse.loc_indices`` are taken from dedicated symbols, which
    become parameters of the Callable; the remaining ones are read from the
    ``msg`` entry.
    """
    cast = cast_mapper[(f.dtype, '*')]

    fixed = {d: Symbol(name="o%s" % d.root) for d in hse.loc_indices}

    dim = Dimension(name='i')
    msgi = IndexedPointer(msg, dim)
    ndist = len(f._dist_dimensions)

    bufs = FieldFromComposite(msg._C_field_bufs, msgi)
    fromrank = FieldFromComposite(msg._C_field_from, msgi)

    sizes = [FieldFromComposite('%s[%d]' % (msg._C_field_sizes, i), msgi)
             for i in range(ndist)]
    # Offsets stored in `msg`, consumed in order by the non-fixed dimensions
    stored = [FieldFromComposite('%s[%d]' % (msg._C_field_ofss, i), msgi)
              for i in range(ndist)]
    ofss = [fixed.get(d) or stored.pop(0) for d in f.dimensions]

    # The `scatter` must be guarded as we must not alter the halo values along
    # the domain boundary, where the sender is actually MPI.PROC_NULL
    scatter = Conditional(
        CondNe(fromrank, Macro('MPI_PROC_NULL')),
        Call('scatter%s' % key, [cast(bufs)] + sizes + [f] + ofss)
    )

    rrecv = Byref(FieldFromComposite(msg._C_field_rrecv, msgi))
    waitrecv = Call('MPI_Wait', [rrecv, Macro('MPI_STATUS_IGNORE')])
    rsend = Byref(FieldFromComposite(msg._C_field_rsend, msgi))
    waitsend = Call('MPI_Wait', [rsend, Macro('MPI_STATUS_IGNORE')])

    # The -1 below is because an Iteration, by default, generates <=
    ncomms = Symbol(name='ncomms')
    iet = Iteration([waitsend, waitrecv, scatter], dim, ncomms - 1)

    parameters = [f]
    parameters.extend(fixed.values())
    parameters.extend([msg, ncomms])

    return Callable('halowait%d' % key, iet, 'void', parameters, ('static',))
def _make_sendrecv(self, f, hse, key, msg=None):
    """
    Build the `SendRecv` starting a halo exchange of ``f`` through ``msg``.

    The generated routine posts a non-blocking receive into the ``bufs``
    field of ``msg``, gathers the data to be sent into the ``bufg`` field
    (skipped when the destination is ``MPI_PROC_NULL``) and posts a
    non-blocking send. No wait is emitted here.

    ``hse`` is accepted but not used here.
    """
    comm = f.grid.distributor._obj_comm
    mpitype = dtype_to_mpitype(f.dtype)
    tag = Integer(13)

    def member(fname):
        # Access to a field of the message structure
        return FieldFromPointer(fname, msg)

    bufg = member(msg._C_field_bufg)
    bufs = member(msg._C_field_bufs)

    ofsg = [Symbol(name='og%s' % d.root) for d in f.dimensions]

    fromrank = Symbol(name='fromrank')
    torank = Symbol(name='torank')

    sizes = [member('%s[%d]' % (msg._C_field_sizes, i))
             for i in range(len(f._dist_dimensions))]

    # The `gather` is unnecessary if sending to MPI.PROC_NULL
    gather = Conditional(
        CondNe(torank, Macro('MPI_PROC_NULL')),
        Call('gather%s' % key, [bufg] + sizes + [f] + ofsg)
    )

    count = reduce(mul, sizes, 1)
    rrecv = Byref(member(msg._C_field_rrecv))
    rsend = Byref(member(msg._C_field_rsend))
    recv = IrecvCall([bufs, count, Macro(mpitype), fromrank, tag, comm, rrecv])
    send = IsendCall([bufg, count, Macro(mpitype), torank, tag, comm, rsend])

    iet = List(body=[recv, gather, send])

    parameters = [f] + ofsg + [fromrank, torank, comm, msg]

    return SendRecv(key, iet, parameters, bufg, bufs)
def _make_wait(self, f, hse, key, msg=None):
    """
    Build the ``wait_<key>`` Callable completing a halo exchange of ``f``.

    The generated routine waits for the send and receive requests stored in
    ``msg`` and then scatters the received buffer into ``f`` (skipped when
    the source rank is ``MPI_PROC_NULL``).

    ``hse`` is accepted but not used here.
    """
    def member(fname):
        # Access to a field of the message structure
        return FieldFromPointer(fname, msg)

    bufs = member(msg._C_field_bufs)

    ofss = [Symbol(name='os%s' % d.root) for d in f.dimensions]

    fromrank = Symbol(name='fromrank')

    sizes = [member('%s[%d]' % (msg._C_field_sizes, i))
             for i in range(len(f._dist_dimensions))]

    # The `scatter` must be guarded as we must not alter the halo values along
    # the domain boundary, where the sender is actually MPI.PROC_NULL
    scatter = Conditional(
        CondNe(fromrank, Macro('MPI_PROC_NULL')),
        Call('scatter_%s' % key, [bufs] + sizes + [f] + ofss)
    )

    rrecv = Byref(member(msg._C_field_rrecv))
    waitrecv = Call('MPI_Wait', [rrecv, Macro('MPI_STATUS_IGNORE')])
    rsend = Byref(member(msg._C_field_rsend))
    waitsend = Call('MPI_Wait', [rsend, Macro('MPI_STATUS_IGNORE')])

    iet = List(body=[waitsend, waitrecv, scatter])

    parameters = [f] + ofss + [fromrank, msg]

    return Callable('wait_%s' % key, iet, 'void', parameters, ('static',))
def _(iet):
    """
    Prepend the OpenACC device setup code to ``iet``.

    The emitted code calls ``acc_init`` and then selects the device. The
    user-supplied device ID is used when it differs from -1. Otherwise, if
    ``iet`` takes an MPI communicator among its parameters, the device is
    ``rank % ngpus``; without a communicator no fallback is emitted.

    The comments delimiting the emitted code mention MPI only when an MPI
    communicator is actually found.

    Returns
    -------
    tuple
        The rebuilt ``iet`` and a dict carrying the new `DeviceID` under
        the ``'args'`` key.
    """
    # TODO: we need to pick the rank from `comm_shm`, not `comm`,
    # so that we have nranks == ngpus (as long as the user has launched
    # the right number of MPI processes per node given the available
    # number of GPUs per node)
    objcomm = None
    for i in iet.parameters:
        if isinstance(i, MPICommObject):
            objcomm = i
            break

    deviceid = DeviceID()
    device_nvidia = Macro('acc_device_nvidia')
    if objcomm is not None:
        rank = Symbol(name='rank')
        rank_decl = LocalExpression(DummyEq(rank, 0))
        rank_init = Call('MPI_Comm_rank', [objcomm, Byref(rank)])

        ngpus = Symbol(name='ngpus')
        call = DefFunction('acc_get_num_devices', device_nvidia)
        ngpus_init = LocalExpression(DummyEq(ngpus, call))

        asdn_then = Call('acc_set_device_num', [deviceid, device_nvidia])
        asdn_else = Call('acc_set_device_num', [rank % ngpus, device_nvidia])

        body = [
            Call('acc_init', [device_nvidia]),
            Conditional(
                CondNe(deviceid, -1),
                asdn_then,
                List(body=[rank_decl, rank_init, ngpus_init, asdn_else])
            )
        ]
        label = 'OpenACC+MPI setup'
    else:
        body = [
            Call('acc_init', [device_nvidia]),
            Conditional(
                CondNe(deviceid, -1),
                Call('acc_set_device_num', [deviceid, device_nvidia])
            )
        ]
        # No MPI involved: do not advertise it in the generated code
        label = 'OpenACC setup'

    init = List(header=c.Comment('Begin of %s' % label),
                body=body,
                footer=(c.Comment('End of %s' % label), c.Line()))
    iet = iet._rebuild(body=(init,) + iet.body)

    return iet, {'args': deviceid}
def _(iet):
    """
    Prepend the device selection to the body of ``iet``.

    The emitted code sets the device to ``self.deviceid`` when the latter
    differs from -1. ``self`` is not a parameter here; it is resolved from
    the enclosing scope.

    Returns
    -------
    tuple
        The rebuilt ``iet`` and an empty dict.
    """
    deviceid = self.deviceid
    target = [deviceid] + as_list(self.lang[self.platform])

    select = Conditional(CondNe(deviceid, -1), self.lang['set-device'](target))

    inner = iet.body
    inner = inner._rebuild(body=(select, BlankLine) + inner.body)

    return iet._rebuild(body=inner), {}
def _make_waitprefetch(self, iet, sync_ops, pieces, *args):
    """
    Prepend to ``iet`` the busy-waits for the prefetched data.

    One `BusyWait` is emitted per distinct ``(sdata, threads)`` pair
    registered in ``pieces.objs`` for the given ``sync_ops``; each spins
    while the flag of the corresponding ``sdata`` entry differs from 1.

    ``args`` is accepted but not used here.
    """
    flag = SharedData._field_flag

    tracked = filter_ordered(pieces.objs.get(s) for s in sync_ops)
    body = [BusyWait(CondNe(FieldFromComposite(flag, sdata[threads.index]), 1))
            for sdata, threads in tracked]
    body.extend([BlankLine, iet])

    return List(header=c.Comment("Wait for the arrival of prefetched data"),
                body=body)
def _make_haloupdate(self, f, hse, key, msg=None):
    """
    Build the ``haloupdate<key>`` Callable starting a halo exchange of ``f``.

    The generated routine loops over the ``ncomms`` entries of ``msg``; for
    each entry it posts a non-blocking receive, gathers the data to be sent
    (skipped when the destination rank is ``MPI_PROC_NULL``) and posts a
    non-blocking send. No wait is emitted here.

    Offsets along ``hse.loc_indices`` are taken from dedicated symbols, which
    become parameters of the Callable; the remaining ones are read from the
    ``msg`` entry.
    """
    comm = f.grid.distributor._obj_comm
    mpitype = dtype_to_mpitype(f.dtype)
    tag = Integer(13)

    fixed = {d: Symbol(name="o%s" % d.root) for d in hse.loc_indices}

    dim = Dimension(name='i')
    msgi = IndexedPointer(msg, dim)
    ndist = len(f._dist_dimensions)

    def member(fname):
        # Access to a field of the i-th message
        return FieldFromComposite(fname, msgi)

    bufg = member(msg._C_field_bufg)
    bufs = member(msg._C_field_bufs)

    fromrank = member(msg._C_field_from)
    torank = member(msg._C_field_to)

    sizes = [member('%s[%d]' % (msg._C_field_sizes, i)) for i in range(ndist)]
    # Offsets stored in `msg`, consumed in order by the non-fixed dimensions
    stored = [member('%s[%d]' % (msg._C_field_ofsg, i)) for i in range(ndist)]
    ofsg = [fixed.get(d) or stored.pop(0) for d in f.dimensions]

    # The `gather` is unnecessary if sending to MPI.PROC_NULL
    gather = Conditional(
        CondNe(torank, Macro('MPI_PROC_NULL')),
        Call('gather_%s' % key, [bufg] + sizes + [f] + ofsg)
    )

    # Make Irecv/Isend
    count = reduce(mul, sizes, 1)
    rrecv = Byref(member(msg._C_field_rrecv))
    rsend = Byref(member(msg._C_field_rsend))
    recv = Call('MPI_Irecv',
                [bufs, count, Macro(mpitype), fromrank, tag, comm, rrecv])
    send = Call('MPI_Isend',
                [bufg, count, Macro(mpitype), torank, tag, comm, rsend])

    # The -1 below is because an Iteration, by default, generates <=
    ncomms = Symbol(name='ncomms')
    iet = Iteration([recv, gather, send], dim, ncomms - 1)

    parameters = [f, comm, msg, ncomms]
    parameters.extend(fixed.values())

    return Callable('haloupdate%d' % key, iet, 'void', parameters, ('static',))
def sendrecv(f, fixed):
    """
    Construct an IET performing a halo exchange along arbitrary
    dimension and side.

    The returned Callable, named ``sendrecv_<f.name>``, posts a non-blocking
    receive, gathers the data to be sent, posts a non-blocking send, waits
    for both requests and finally scatters the received data (skipped when
    the source rank is ``MPI_PROC_NULL``).

    The buffers span all dimensions of ``f`` except those in ``fixed``.
    """
    assert f.is_Function
    assert f.grid is not None

    comm = f.grid.distributor._C_comm
    mpitype = numpy_to_mpitypes(f.dtype)

    buf_dims = [Dimension(name='buf_%s' % d.root)
                for d in f.dimensions if d not in fixed]
    bufg, bufs = (
        Array(name=label, dimensions=buf_dims, dtype=f.dtype, scope='heap')
        for label in ('bufg', 'bufs')
    )

    dat = Array(name='dat',
                dimensions=[Dimension(name='dat_%s' % d.root) for d in f.dimensions],
                dtype=f.dtype, scope='external')

    ofsg = [Symbol(name='og%s' % d.root) for d in f.dimensions]
    ofss = [Symbol(name='os%s' % d.root) for d in f.dimensions]

    fromrank = Symbol(name='fromrank')
    torank = Symbol(name='torank')

    gather = Call('gather_%s' % f.name,
                  [bufg] + list(bufg.shape) + [dat] + list(dat.shape) + ofsg)
    # The scatter must be guarded as we must not alter the halo values along
    # the domain boundary, where the sender is actually MPI.PROC_NULL
    scatter = Conditional(
        CondNe(fromrank, Macro('MPI_PROC_NULL')),
        Call('scatter_%s' % f.name,
             [bufs] + list(bufs.shape) + [dat] + list(dat.shape) + ofss)
    )

    srecv = MPIStatusObject(name='srecv')
    rrecv = MPIRequestObject(name='rrecv')
    rsend = MPIRequestObject(name='rsend')

    count = reduce(mul, bufs.shape, 1)
    recv = Call('MPI_Irecv',
                [bufs, count, Macro(mpitype), fromrank, '13', comm, rrecv])
    send = Call('MPI_Isend',
                [bufg, count, Macro(mpitype), torank, '13', comm, rsend])

    waitrecv = Call('MPI_Wait', [rrecv, srecv])
    waitsend = Call('MPI_Wait', [rsend, Macro('MPI_STATUS_IGNORE')])

    exchange = List(body=[recv, gather, send, waitsend, waitrecv, scatter])
    iet = List(body=[ArrayCast(dat), iet_insert_C_decls(exchange)])

    parameters = [dat]
    parameters += list(dat.shape)
    parameters += list(bufs.shape)
    parameters += ofsg + ofss
    parameters += [fromrank, torank, comm]

    return Callable('sendrecv_%s' % f.name, iet, 'void', parameters, ('static',))
def __make_tfunc(self, name, iet, root, threads):
    """
    Wrap ``iet`` into a Callable to be executed by a thread.

    The parameters required by ``iet`` are split in two groups: those known
    to ``root`` (or being shared-memory Arrays) become parameters of the
    Callable; the remaining ones become fields of a new `SharedData`, which
    is itself appended to the parameters.

    The generated code keeps spinning while the shared flag differs from 0;
    each time the flag equals 2 it unpacks the `SharedData` fields, executes
    ``iet`` and sets the flag back to 1.

    Returns
    -------
    tuple
        The Callable and the `SharedData`.
    """
    # Create the SharedData
    required = derive_parameters(iet)
    shared_arrays = tuple(i for i in required if i.is_Array and i._mem_shared)
    known = root.parameters + shared_arrays
    parameters, dynamic_parameters = split(required, lambda i: i in known)

    sdata = SharedData(name=self.sregistry.make_name(prefix='sdata'),
                       nthreads_std=threads.size, fields=dynamic_parameters)
    parameters.append(sdata)

    def field(fname):
        # Access to a member of the shared data structure through its base
        return FieldFromPointer(fname, sdata.symbolic_base)

    # Prepend the unwinded SharedData fields, available upon thread activation
    work = [DummyExpr(p, field(p.name)) for p in dynamic_parameters]
    work.append(DummyExpr(sdata.symbolic_id, field(sdata._field_id)))
    work.append(iet)
    # Append the flag reset
    work.append(List(body=[BlankLine, DummyExpr(field(sdata._field_flag), 1)]))
    iet = List(body=work)

    # The thread has work to do when it receives the signal that all locks have
    # been set to 0 by the main thread
    iet = Conditional(CondEq(field(sdata._field_flag), 2), iet)

    # The thread keeps spinning until the alive flag is set to 0 by the main thread
    iet = While(CondNe(field(sdata._field_flag), 0), iet)

    return Callable(name, iet, 'void', parameters, 'static'), sdata
def _(iet):
    """
    Insert the device selection right before the `WhileAlive` loop of ``iet``.

    Exactly one `WhileAlive` node is expected in ``iet``. The emitted code
    sets the device to ``self.deviceid`` when the latter differs from -1.
    ``self`` is not a parameter here; it is resolved from the enclosing scope.

    Returns
    -------
    tuple
        The transformed ``iet`` and an empty dict.
    """
    found = FindNodes(WhileAlive).visit(iet)
    assert len(found) == 1
    alive = found.pop()

    deviceid = self.deviceid
    target = [deviceid] + as_list(self.lang[self.platform])

    select = Conditional(CondNe(deviceid, -1), self.lang['set-device'](target))

    replacement = List(body=[select, BlankLine, alive])

    return Transformer({alive: replacement}).visit(iet), {}
def _make_sendrecv(self, f, fixed, extra=None):
    """
    Build the ``sendrecv<ndim>d`` Callable exchanging a halo region of ``f``.

    The generated routine posts a non-blocking receive, gathers the data to
    be sent (skipped when the destination is ``MPI_PROC_NULL``), posts a
    non-blocking send, waits for both requests and finally scatters the
    received data (skipped when the source is ``MPI_PROC_NULL``).

    The buffers span all dimensions of ``f`` except those in ``fixed``.
    ``extra``, if given, is appended to the arguments of the gather and
    scatter calls as well as to the parameters of the Callable.
    """
    extra = extra or []

    comm = f.grid.distributor._obj_comm
    mpitype = dtype_to_mpitype(f.dtype)
    tag = Integer(13)

    buf_dims = [Dimension(name='buf_%s' % d.root)
                for d in f.dimensions if d not in fixed]
    bufg, bufs = (
        Array(name=label, dimensions=buf_dims, dtype=f.dtype, scope='heap')
        for label in ('bufg', 'bufs')
    )

    ofsg = [Symbol(name='og%s' % d.root) for d in f.dimensions]
    ofss = [Symbol(name='os%s' % d.root) for d in f.dimensions]

    fromrank = Symbol(name='fromrank')
    torank = Symbol(name='torank')

    # The `gather` is unnecessary if sending to MPI.PROC_NULL
    gather = Conditional(
        CondNe(torank, Macro('MPI_PROC_NULL')),
        Call('gather%dd' % f.ndim,
             [bufg] + list(bufg.shape) + [f] + ofsg + extra)
    )
    # The `scatter` must be guarded as we must not alter the halo values along
    # the domain boundary, where the sender is actually MPI.PROC_NULL
    scatter = Conditional(
        CondNe(fromrank, Macro('MPI_PROC_NULL')),
        Call('scatter%dd' % f.ndim,
             [bufs] + list(bufs.shape) + [f] + ofss + extra)
    )

    srecv = MPIStatusObject(name='srecv')
    ssend = MPIStatusObject(name='ssend')
    rrecv = MPIRequestObject(name='rrecv')
    rsend = MPIRequestObject(name='rsend')

    count = reduce(mul, bufs.shape, 1)
    recv = Call('MPI_Irecv',
                [bufs, count, Macro(mpitype), fromrank, tag, comm, rrecv])
    send = Call('MPI_Isend',
                [bufg, count, Macro(mpitype), torank, tag, comm, rsend])

    waitrecv = Call('MPI_Wait', [rrecv, srecv])
    waitsend = Call('MPI_Wait', [rsend, ssend])

    exchange = List(body=[recv, gather, send, waitsend, waitrecv, scatter])
    iet = List(body=iet_insert_C_decls(exchange))

    parameters = [f]
    parameters += list(bufs.shape)
    parameters += ofsg + ofss
    parameters += [fromrank, torank, comm]
    parameters += extra

    return Callable('sendrecv%dd' % f.ndim, iet, 'void', parameters, ('static',))
def _make_fetchprefetch(self, iet, sync_ops, pieces, root):
    """
    Wrap ``iet`` with the data streaming logic described by ``sync_ops``.

    Three kinds of transfers are built for each sync operation: the initial
    fetch, the "present" clause and the prefetch. The initial fetches are
    placed in a new ``init_device`` Callable, called from the init code; the
    prefetches in a thread function created via ``make_thread_ctx``.

    The returned IET busy-waits while the thread flag differs from 1, then
    emits the present clauses, ``iet`` itself and the thread activation.

    ``pieces`` is updated in place with the new functions, the init and
    finalize code, and the created ``(sdata, threads)`` objects.
    """
    fid = SharedData._field_id

    def window(op, start):
        # Transfer `op.size` items from `start` along the streamed dimension;
        # all other dimensions are transferred in full
        return [(start, op.size) if d.root is op.dim.root else FULL
                for d in op.dimensions]

    fetches = []
    prefetches = []
    presents = []
    for s in sync_ops:
        f = s.function

        # Construct init IET
        fetch = PragmaTransfer(self.lang._map_to, f, imask=window(s, s.ifetch))
        fetches.append(Conditional(s.fcond, fetch))

        # Construct present clauses
        presents.append(
            PragmaTransfer(self.lang._map_present, f, imask=window(s, s.fetch))
        )

        # Construct prefetch IET
        prefetch = PragmaTransfer(self.lang._map_to_wait, f,
                                  imask=window(s, s.pfetch), queueid=fid)
        prefetches.append(Conditional(s.pcond, prefetch))

    # Turn init IET into a Callable
    functions = filter_ordered(s.function for s in sync_ops)
    init_name = self.sregistry.make_name(prefix='init_device')
    init_body = List(body=fetches)
    parameters = filter_sorted(functions + derive_parameters(init_body))
    pieces.funcs.append(
        Callable(init_name, init_body, 'void', parameters, 'static')
    )

    # Perform initial fetch by the main thread
    pieces.init.append(List(
        header=c.Comment("Initialize data stream"),
        body=[Call(init_name, parameters), BlankLine]
    ))

    # Turn prefetch IET into a ThreadFunction
    thread_name = self.sregistry.make_name(prefix='prefetch_host_to_device')
    thread_body = List(header=c.Line(), body=prefetches)
    tctx = make_thread_ctx(thread_name, thread_body, root, None, sync_ops,
                           self.sregistry)
    pieces.funcs.extend(tctx.funcs)

    # Glue together all the IET pieces, including the activation logic
    sdata = tctx.sdata
    threads = tctx.threads
    ready = BusyWait(
        CondNe(FieldFromComposite(sdata._field_flag, sdata[threads.index]), 1)
    )
    glued = [BlankLine, ready]
    glued.extend(presents)
    glued.extend([iet, tctx.activate])
    iet = List(body=glued)

    # Fire up the threads
    pieces.init.append(tctx.init)

    # Final wait before jumping back to Python land
    pieces.finalize.append(tctx.finalize)

    # Keep track of created objects
    pieces.objs.add(sync_ops, sdata, threads)

    return iet
def _make_fetchwaitprefetch(self, iet, sync_ops, pieces, root):
    """
    Wrap ``iet`` with the data streaming logic described by ``sync_ops``.

    Three kinds of transfers are built for each sync operation: the initial
    fetch (at the first iteration for Forward operations, at the last one
    otherwise), the "present" clause and the prefetch of the next slot. The
    fetch and the prefetch are guarded so that the transferred slots stay
    below the field size (Forward) or at or above 0 (otherwise).

    The initial fetches are placed in a new ``init_device`` Callable, called
    from the init code; the prefetches in a threaded Callable.

    The returned IET busy-waits while the thread flag differs from 1, then
    emits the present clauses, ``iet`` itself and the thread activation.

    ``pieces`` is updated in place with the new functions, the init and
    finalize code, and the created threads.
    """
    threads = self.__make_threads()

    def window(op, start):
        # Transfer `op.size` items from `start` along the streamed dimension;
        # all other dimensions are transferred in full
        return [(start, op.size) if d.root is op.dim.root else FULL
                for d in op.dimensions]

    fetches = []
    prefetches = []
    presents = []
    for s in sync_ops:
        if s.direction is Forward:
            fsize = s.function._C_get_field(FULL, s.dim).size
            fc = s.fetch.subs(s.dim, s.dim.symbolic_min)
            pfc = s.fetch + 1
            fc_cond = fc + (s.size - 1) < fsize
            pfc_cond = pfc + (s.size - 1) < fsize
        else:
            fc = s.fetch.subs(s.dim, s.dim.symbolic_max)
            pfc = s.fetch - 1
            fc_cond = fc >= 0
            pfc_cond = pfc >= 0

        # Construct fetch IET
        fetch = List(header=self._P._map_to(s.function, window(s, fc)))
        fetches.append(Conditional(fc_cond, fetch))

        # Construct present clauses
        present = self._P._map_present(s.function, window(s, s.fetch))
        presents.extend(as_list(present))

        # Construct prefetch IET
        prefetch = List(header=self._P._map_to_wait(s.function, window(s, pfc),
                                                    SharedData._field_id))
        prefetches.append(Conditional(pfc_cond, prefetch))

    functions = filter_ordered(s.function for s in sync_ops)
    casts = [PointerCast(f) for f in functions]

    # Turn init IET into a Callable
    init_name = self.sregistry.make_name(prefix='init_device')
    init_body = List(body=casts + fetches)
    parameters = filter_sorted(functions + derive_parameters(init_body))
    func = Callable(init_name, init_body, 'void', parameters, 'static')
    pieces.funcs.append(func)

    # Perform initial fetch by the main thread
    pieces.init.append(List(
        header=c.Comment("Initialize data stream for `%s`" % threads.name),
        body=[Call(init_name, func.parameters), BlankLine]
    ))

    # Turn prefetch IET into a threaded Callable
    thread_name = self.sregistry.make_name(prefix='prefetch_host_to_device')
    thread_body = List(header=c.Line(), body=casts + prefetches)
    tfunc, sdata = self.__make_tfunc(thread_name, thread_body, root, threads)
    pieces.funcs.append(tfunc)

    # Glue together all the IET pieces, including the activation bits
    ready = BusyWait(
        CondNe(FieldFromComposite(sdata._field_flag, sdata[threads.index]), 1)
    )
    activate = self.__make_activate_thread(threads, sdata, sync_ops)
    iet = List(body=[BlankLine, ready, List(header=presents), iet, activate])

    # Fire up the threads
    pieces.init.append(self.__make_init_threads(threads, sdata, tfunc, pieces))
    pieces.threads.append(threads)

    # Final wait before jumping back to Python land
    pieces.finalize.append(self.__make_finalize_threads(threads, sdata))

    return iet
def _make_fetchwaitprefetch(self, iet, sync_ops, pieces, root):
    """
    Wrap ``iet`` with the data streaming logic described by ``sync_ops``.

    Three kinds of transfers are built for each sync operation: the initial
    fetch (at the first iteration for Forward operations, at the last one
    otherwise), the "present" clause and the prefetch of the next slot. The
    guards of the fetch and the prefetch are obtained from the ``next_cbk``
    callback of the sync operation.

    The initial fetches are placed in a new ``init_device`` Callable, called
    from the init code; the prefetches in a thread function created via
    ``make_thread_ctx``.

    The returned IET busy-waits while the thread flag differs from 1, then
    emits the present clauses, ``iet`` itself and the thread activation.

    ``pieces`` is updated in place with the new functions, the init and
    finalize code, and the created threads.
    """
    def window(op, start):
        # Transfer `op.size` items from `start` along the streamed dimension;
        # all other dimensions are transferred in full
        return [(start, op.size) if d.root is op.dim.root else FULL
                for d in op.dimensions]

    fetches = []
    prefetches = []
    presents = []
    for s in sync_ops:
        if s.direction is Forward:
            fc = s.fetch.subs(s.dim, s.dim.symbolic_min)
            pfc = s.fetch + 1
            fc_cond = s.next_cbk(s.dim.symbolic_min)
            pfc_cond = s.next_cbk(s.dim + 1)
        else:
            fc = s.fetch.subs(s.dim, s.dim.symbolic_max)
            pfc = s.fetch - 1
            fc_cond = s.next_cbk(s.dim.symbolic_max)
            pfc_cond = s.next_cbk(s.dim - 1)

        # Construct init IET
        fetch = PragmaList(self.lang._map_to(s.function, window(s, fc)),
                           {s.function} | fc.free_symbols)
        fetches.append(Conditional(fc_cond, fetch))

        # Construct present clauses
        present = self.lang._map_present(s.function, window(s, s.fetch))
        presents.extend(as_list(present))

        # Construct prefetch IET
        prefetch = PragmaList(
            self.lang._map_to_wait(s.function, window(s, pfc),
                                   SharedData._field_id),
            {s.function} | pfc.free_symbols
        )
        prefetches.append(Conditional(pfc_cond, prefetch))

    # Turn init IET into a Callable
    functions = filter_ordered(s.function for s in sync_ops)
    init_name = self.sregistry.make_name(prefix='init_device')
    init_body = List(body=fetches)
    parameters = filter_sorted(functions + derive_parameters(init_body))
    pieces.funcs.append(
        Callable(init_name, init_body, 'void', parameters, 'static')
    )

    # Perform initial fetch by the main thread
    pieces.init.append(List(
        header=c.Comment("Initialize data stream"),
        body=[Call(init_name, parameters), BlankLine]
    ))

    # Turn prefetch IET into a ThreadFunction
    thread_name = self.sregistry.make_name(prefix='prefetch_host_to_device')
    thread_body = List(header=c.Line(), body=prefetches)
    tctx = make_thread_ctx(thread_name, thread_body, root, None, sync_ops,
                           self.sregistry)
    pieces.funcs.extend(tctx.funcs)

    # Glue together all the IET pieces, including the activation logic
    sdata = tctx.sdata
    threads = tctx.threads
    ready = BusyWait(
        CondNe(FieldFromComposite(sdata._field_flag, sdata[threads.index]), 1)
    )
    iet = List(body=[BlankLine, ready, List(header=presents), iet, tctx.activate])

    # Fire up the threads
    pieces.init.append(tctx.init)
    pieces.threads.append(threads)

    # Final wait before jumping back to Python land
    pieces.finalize.append(tctx.finalize)

    return iet