def _(iet):
    """Prepend the target-language device setup to the body of `iet`.

    If `iet` takes an MPICommObject among its parameters, the generated code
    uses the user-supplied device id when it is not -1. Otherwise it falls back
    to `rank % ngpus`, where `rank` comes from `MPI_Comm_rank` and `ngpus`
    from the language's 'num-devices' call. Without an MPI communicator, the
    device is set only when the device id is not -1.

    Returns the rebuilt `iet` plus `{'args': <the DeviceID symbol>}`.
    """
    # TODO: we need to pick the rank from `comm_shm`, not `comm`,
    # so that we have nranks == ngpus (as long as the user has launched
    # the right number of MPI processes per node given the available
    # number of GPUs per node)
    mpi_comm = next((p for p in iet.parameters if isinstance(p, MPICommObject)),
                    None)

    target = as_list(self.lang[self.platform])
    try:
        setup = [self.lang['init'](target)]
    except TypeError:
        # Not all target languages need to be explicitly initialized
        setup = []

    devid = DeviceID()
    # Statement selecting the device explicitly requested by the user
    use_given_device = self.lang['set-device']([devid] + target)

    if mpi_comm is None:
        body = setup + [Conditional(CondNe(devid, -1), use_given_device)]
        header = c.Comment('Begin of %s setup' % self.lang['name'])
        footer = c.Comment('End of %s setup' % self.lang['name'])
    else:
        rank = Symbol(name='rank')
        ngpus = Symbol(name='ngpus')
        # Fallback: round-robin assignment of ranks to the available devices
        fallback = List(body=[
            LocalExpression(DummyEq(rank, 0)),
            Call('MPI_Comm_rank', [mpi_comm, Byref(rank)]),
            LocalExpression(DummyEq(ngpus, self.lang['num-devices'](target))),
            self.lang['set-device']([rank % ngpus] + target),
        ])
        body = setup + [Conditional(CondNe(devid, -1), use_given_device, fallback)]
        header = c.Comment('Begin of %s+MPI setup' % self.lang['name'])
        footer = c.Comment('End of %s+MPI setup' % self.lang['name'])

    setup_block = List(header=header, body=body, footer=(footer, c.Line()))
    return iet._rebuild(body=(setup_block, ) + iet.body), {'args': devid}
def __init__(self, prodder):
    """Initialise a node that calls `prodder` from OpenMP thread 0 only.

    The Conditional part guards a Call to `prodder` with
    `omp_get_thread_num() == 0`. The Prodder part copies the name, arguments
    and periodicity of `prodder`.
    """
    name = prodder.name
    args = prodder.arguments

    on_thread_zero = CondEq(Function('omp_get_thread_num')(), 0)
    Conditional.__init__(self, on_thread_zero, Call(name, args))

    Prodder.__init__(self, name, args, periodic=prodder.periodic)
def __init__(self, prodder):
    """Initialise a node that entraps thread 0 in a prodding loop.

    Only the thread whose number is 0 enters the guarded body. There it
    repeatedly evaluates the prodder function inside a `While` until the call
    returns a truthy value, i.e. until all communications have completed. The
    Prodder part copies the name, arguments and periodicity of `prodder`.
    """
    # Atomic-ize any single-thread Prodders in the parallel tree
    is_thread_zero = CondEq(Ompizer.lang['thread-num'], 0)

    # Spin on the prodder until it reports completion; the delegated thread
    # stays trapped for as long as it is required
    arg_names = [a.name for a in prodder.arguments]
    not_done = Not(DefFunction(prodder.name, arg_names))
    trap = List(header=c.Comment('Entrap thread until comms have completed'),
                body=While(not_done))

    Conditional.__init__(self, is_thread_zero, trap)
    Prodder.__init__(self, prodder.name, prodder.arguments, periodic=prodder.periodic)
def _make_guard(self, parregion, *args):
    """Wrap a device parallel region in a safety Conditional, if needed.

    Regions containing no ParallelTree rooted in a `self.DeviceIteration` are
    delegated to the parent class implementation. Otherwise the region is only
    entered if:

    * for non-affine regions, the first `ncollapsed` Iterations of the first
      iteration tree each have a positive `symbolic_size`;
    * every SparseFunction referenced in the region has a positive size.

    Conditions that are statically `true` are discarded. If none remain,
    `parregion` is returned unchanged. Otherwise the return value is a List
    holding a Conditional on the conjunction of the remaining conditions.
    """
    offloaded = any(isinstance(t.root, self.DeviceIteration)
                    for t in FindNodes(ParallelTree).visit(parregion))
    if not offloaded:
        return super()._make_guard(parregion, *args)

    checks = []

    # There must be at least one iteration, or the launch may crash
    if not parregion.is_Affine:
        first_tree = retrieve_iteration_tree(parregion.root)[0]
        checks += [it.symbolic_size > 0 for it in first_tree[:parregion.ncollapsed]]

    # A SparseFunction may degenerate to a zero-size array, whose copy-in
    # yields a `nil` device pointer; the parallel loop may only fire when no
    # SparseFunction pointer is `nil`
    sparse_funcs = [s for s in FindSymbols().visit(parregion) if s.is_SparseFunction]
    for f in sparse_funcs:
        checks.append(prod(f._C_get_field(FULL, d).size for d in f.dimensions) > 0)

    # Conditions that reduce statically to `true` (e.g. when `symbolic_size`
    # is an integer rather than a symbol) are dropped to avoid emitting
    # pointless code such as `if (true) { ... }`
    checks = [chk for chk in checks if chk != true]

    if not checks:
        return parregion
    return List(body=[Conditional(And(*checks), parregion)])
def _make_guard(self, partree, collapsed):
    """Prepend an early `return` to `partree` if any symbolic step is zero.

    Entering the parallel region with a zero step increment would raise a
    `Floating point exception (core dumped)` in some OpenMP implementations.
    An OpenMP `if` clause is not a viable alternative.

    Only Iterations in `collapsed` whose `step` is a Symbol are checked. If
    there are none, `partree` is returned untouched.
    """
    symbolic_steps = [it.step for it in collapsed if isinstance(it.step, Symbol)]
    zero_step = Or(*[CondEq(s, 0) for s in symbolic_steps])

    # With no symbolic steps `Or()` yields sympy's `false`, which == False
    if zero_step != False:  # noqa
        return List(body=[Conditional(zero_step, Return()), partree])
    return partree
def _(iet):
    """Insert a device-selection statement at the top of `iet`'s inner body.

    The device is set to `self.deviceid` only when that id is not -1. A
    BlankLine separates it from the pre-existing statements. Returns the
    rebuilt `iet` and an empty dict.
    """
    target = as_list(self.lang[self.platform])
    devid = self.deviceid

    select_device = Conditional(CondNe(devid, -1),
                                self.lang['set-device']([devid] + target))

    inner = iet.body
    new_inner = inner._rebuild(body=(select_device, BlankLine) + inner.body)

    return iet._rebuild(body=new_inner), {}
def _(iet):
    """Select the device right before the single WhileAlive loop in `iet`.

    The WhileAlive node is replaced by a List containing a device-selection
    Conditional (active only when `self.deviceid` is not -1), a BlankLine and
    the original WhileAlive. Returns the transformed `iet` and an empty dict.
    """
    loops = FindNodes(WhileAlive).visit(iet)
    assert len(loops) == 1
    loop = loops.pop()

    target = as_list(self.lang[self.platform])
    devid = self.deviceid
    select_device = Conditional(
        CondNe(devid, -1),
        self.lang['set-device']([devid] + target)
    )

    replacement = List(body=[select_device, BlankLine, loop])
    return Transformer({loop: replacement}).visit(iet), {}