def _make_thread_activate(threads, sdata, sync_ops, sregistry):
    """
    Build the IET snippet that hands work over to a thread of `threads`.

    The generated code spins (``While``) as long as any SyncLock handle in
    `sync_ops` differs from 2 or the selected thread's flag field in `sdata`
    differs from 1. After that it copies every dynamic field of `sdata` into
    the selected thread's entry, resets the SyncLock handles to 0, and sets
    the thread's flag field to 2.

    With more than one thread, a fresh counter symbol starts at 0 and cycles
    modulo ``threads.size`` while waiting. With a single thread,
    ``threads.index`` is used directly.

    Returns
    -------
    List
        The activation code, preceded by a blank line and an
        "Activate `<threads.name>`" comment, followed by a blank line.
    """
    single_thread = threads.size == 1
    if single_thread:
        tid = threads.index
    else:
        tid = Symbol(name=sregistry.make_name(prefix=threads.index.name))

    locks = [s for s in sync_ops if s.is_SyncLock]

    waiting = [CondNe(s.handle, 2) for s in locks]
    waiting.append(CondNe(FieldFromComposite(sdata._field_flag, sdata[tid]), 1))
    condition = Or(*waiting)

    if single_thread:
        activation = [While(condition)]
    else:
        # Round-robin over the thread ids until one satisfies `condition`
        activation = [
            DummyExpr(tid, 0),
            While(condition, DummyExpr(tid, (tid + 1) % threads.size)),
        ]

    activation += [DummyExpr(FieldFromComposite(f.name, sdata[tid]), f)
                   for f in sdata.dynamic_fields]
    activation += [DummyExpr(s.handle, 0) for s in locks]
    activation.append(DummyExpr(FieldFromComposite(sdata._field_flag, sdata[tid]), 2))

    header = [c.Line(), c.Comment("Activate `%s`" % threads.name)]
    return List(header=header, body=activation, footer=c.Line())
def __init__(self, timer, lname, body):
    """
    Wrap `body` between ``START_TIMER``/``STOP_TIMER`` lines.

    Parameters
    ----------
    timer : object
        Its ``name`` is the second argument of ``STOP_TIMER``; stored as
        ``self._timer``.
    lname : str
        Section name used by both macros; stored as ``self._name``.
    body : object
        Forwarded unchanged to the parent constructor.
    """
    self._name = lname
    self._timer = timer
    start_line = c.Line('START_TIMER(%s)' % lname)
    stop_line = c.Line('STOP_TIMER(%s,%s)' % (lname, timer.name))
    super().__init__(header=start_line, body=body, footer=stop_line)
def _dump_storage(self, iet, storage):
    """
    Inject the allocation/deallocation code collected in `storage` into `iet`.

    `storage` maps IET nodes to entries. An Expression key is replaced by its
    entry verbatim. Any other key is rebuilt so that its body is wrapped in a
    ``List``:

    - header: the flattened ``allocs``, then one thread-parallel region per
      thread-id found in ``pallocs``, then a blank line (if anything);
    - footer: a blank line (if anything), then one thread-parallel region per
      thread-id found in ``pfrees``, then the flattened ``frees``.

    Returns the transformed IET.
    """
    def parallel_regions(pairs):
        # Group (tid, stmt) pairs by tid. Each group becomes
        # `<region header> { <tid> = <thread-num>; <stmts...> }`
        regions = []
        grouped = as_mapper(pairs, itemgetter(0), itemgetter(1))
        for tid, stmts in grouped.items():
            region_header = self.lang.Region._make_header(tid.symbolic_size)
            tid_init = c.Initializer(c.Value(tid._C_typedata, tid.name),
                                     self.lang['thread-num'])
            regions.append(c.Module((region_header, c.Block([tid_init] + stmts))))
        return regions

    mapper = {}
    for node, entry in storage.items():
        # Expr -> LocalExpr ?
        if node.is_Expression:
            mapper[node] = entry
            continue

        allocs = flatten(entry.allocs)
        allocs.extend(parallel_regions(entry.pallocs))
        if allocs:
            allocs.append(c.Line())

        frees = parallel_regions(entry.pfrees)
        frees.extend(flatten(entry.frees))
        if frees:
            frees.insert(0, c.Line())

        wrapped = List(header=allocs, body=node.body, footer=frees)
        mapper[node] = node._rebuild(body=wrapped, **node.args_frozen)

    return Transformer(mapper, nested=True).visit(iet)
def gen_op(op_name, attrs, inputs, outputs, output_shapes, kernel_class_name,
           kernel_header):
    """
    Generate the C++ source module for a TensorFlow op and its kernel.

    The module contains, in order:

    1. ``#include``s of the TensorFlow framework headers (op, shape
       inference, op kernel) and of `kernel_header`;
    2. a ``using namespace tensorflow`` statement;
    3. the op-registration macro, whose shape function is generated from
       `output_shapes`;
    4. the OpKernel class definition for `kernel_class_name`;
    5. the kernel-build macro.

    Returns
    -------
    cgen.Module
    """
    headers = (
        "tensorflow/core/framework/op.h",
        "tensorflow/core/framework/shape_inference.h",
        "tensorflow/core/framework/op_kernel.h",
        kernel_header,
    )
    contents = [c.Include(header, system=False) for header in headers]

    # Name space declarations
    contents.append(c.Line())
    contents.append(c.Statement("using namespace tensorflow"))
    contents.append(c.Line())

    # Registration macro, carrying the generated shape-inference function
    shape_fn = gen_shape_fn(output_shapes)
    reg_macro = gen_reg_op_macro_str(op_name, attrs, inputs, outputs, shape_fn)
    contents.append(c.Statement(reg_macro))
    contents.append(c.Line())

    contents.extend(gen_op_kernel_class_defn(kernel_class_name))
    contents.append(c.Line())

    kernel_build_macro = gen_kernel_build_macro(op_name, kernel_class_name)
    contents.append(c.Statement(kernel_build_macro))
    return c.Module(contents)
def _generate_lib_outer_loop(self):
    """
    Build the OpenMP parallel outer loop and store it as
    ``self._components['LIB_OUTER_LOOP']``.

    Inside an ``omp parallel default(none) shared(...)`` region, each thread
    obtains ``_thread_start``/``_thread_end`` from ``get_thread_decomp`` over
    ``_N_LOCAL``. It then runs ``LIB_KERNEL_CALL`` for every
    ``LIB_PAIR_INDEX_0`` value in ``[_thread_start, _thread_end)``. The
    ``shared`` clause lists ``OMP_SHARED_SYMS``, comma-separated.
    """
    block = cgen.Block([self._components['LIB_KERNEL_CALL']])
    i = self._components['LIB_PAIR_INDEX_0']

    # `','.join` yields '' for no symbols, matching the old concat-and-trim
    shared = ','.join(self._components['OMP_SHARED_SYMS'])
    pragma = cgen.Pragma('omp parallel default(none) shared(' + shared + ')')

    parallel_region = cgen.Block((
        cgen.Value('int', '_thread_start'),
        cgen.Value('int', '_thread_end'),
        cgen.Line(
            'get_thread_decomp((int)_N_LOCAL, &_thread_start, &_thread_end);'
        ),
        cgen.For('int ' + i + '= _thread_start', i + '< _thread_end',
                 i + '++', block),
    ))

    loop = cgen.Module([
        cgen.Line('omp_set_num_threads(_NUM_THREADS);'),
        pragma,
        parallel_region,
    ])
    self._components['LIB_OUTER_LOOP'] = loop
def _initialize(iet):
    """
    Prepend to `iet` the setup code that initializes `devicenum`.

    If one of ``iet.parameters`` is an ``MPICommObject``, the setup declares
    ``rank`` (filled by ``MPI_Comm_rank``) and
    ``ngpus = omp_get_num_devices()``, then sets
    ``devicenum = rank % ngpus``. Otherwise it sets ``devicenum = 0``. The
    setup is framed by begin/end comments, followed by a blank line.

    Note: `devicenum` is not defined in this function; it is resolved from
    an enclosing scope.
    """
    comm = next((p for p in iet.parameters if isinstance(p, MPICommObject)), None)

    if comm is None:
        body = [LocalExpression(DummyEq(devicenum, 0))]
        init = List(header=c.Comment('Begin of OpenMP setup'),
                    body=body,
                    footer=(c.Comment('End of OpenMP setup'), c.Line()))
    else:
        rank = Symbol(name='rank')
        ngpus = Symbol(name='ngpus')
        num_devices = Function('omp_get_num_devices')()
        body = [
            LocalExpression(DummyEq(rank, 0)),
            Call('MPI_Comm_rank', [comm, Byref(rank)]),
            LocalExpression(DummyEq(ngpus, num_devices)),
            LocalExpression(DummyEq(devicenum, rank % ngpus)),
        ]
        init = List(header=c.Comment('Begin of OpenMP+MPI setup'),
                    body=body,
                    footer=(c.Comment('End of OpenMP+MPI setup'), c.Line()))

    return iet._rebuild(body=(init,) + iet.body)
def visit_Iteration(self, o):
    """
    Lower an Iteration node to a C ``for`` loop.

    The bounds ``o.limits[0]`` (min) and ``o.limits[1]`` (max) are
    inclusive, and ``o.limits[2]`` is the step. Forward loops use
    ``<=``/``+=``. Backward loops start at the max and use ``>=``/``-=``.

    Each entry of ``o.uindices`` is initialized to its ``symbolic_min`` in
    the loop header. In the increment clause it is updated by its
    ``symbolic_incr``: assigned for Modulo indices, added otherwise.
    Pragmas, if any, are emitted in front of the loop.
    """
    body = flatten(self._visit(i) for i in o.children)

    _min = o.limits[0]
    _max = o.limits[1]
    # Print the step through `ccode`, like the bounds. A symbolic step
    # printed with plain `%s` would use Python syntax (e.g. `**`), which
    # is not valid C.
    step = ccode(o.limits[2])

    # For backward direction flip loop bounds
    if o.direction == Backward:
        loop_init = 'int %s = %s' % (o.index, ccode(_max))
        loop_cond = '%s >= %s' % (o.index, ccode(_min))
        loop_inc = '%s -= %s' % (o.index, step)
    else:
        loop_init = 'int %s = %s' % (o.index, ccode(_min))
        loop_cond = '%s <= %s' % (o.index, ccode(_max))
        loop_inc = '%s += %s' % (o.index, step)

    # Append unbounded indices, if any
    if o.uindices:
        uinit = ['%s = %s' % (i.name, ccode(i.symbolic_min)) for i in o.uindices]
        loop_init = c.Line(', '.join([loop_init] + uinit))

        ustep = []
        for i in o.uindices:
            op = '=' if i.is_Modulo else '+='
            ustep.append('%s %s %s' % (i.name, op, ccode(i.symbolic_incr)))
        loop_inc = c.Line(', '.join([loop_inc] + ustep))

    # Create For header+body
    handle = c.For(loop_init, loop_cond, loop_inc, c.Block(body))

    # Attach pragmas, if any
    if o.pragmas:
        handle = c.Module(o.pragmas + (handle,))

    return handle
def visit_Operator(self, o):
    """
    Lower an Operator to a complete C translation unit (a ``c.Module``).

    The module contains, in order:

    - header lines from ``o._headers`` and non-system includes from
      ``o._includes``, followed by a blank line;
    - the parameters' type declarations, processed by ``filter_sorted`` keyed
      on ``tpname``, plus an ``extern "C"`` copy of the kernel signature for
      ``.cpp`` sources, each followed by a blank line;
    - prototypes of the local elemental functions;
    - the kernel itself, whose body ends in ``return 0``;
    - the bodies of the local elemental functions.
    """
    blankline = c.Line("")

    # Kernel signature and body
    body = flatten(self._visit(child) for child in o.children)
    signature = c.FunctionDeclaration(c.Value(o.retval, o.name),
                                      self._args_decl(o.parameters))
    kernel = c.FunctionBody(signature,
                            c.Block(body + [c.Statement("return 0")]))

    # Elemental functions: prototypes up front, definitions at the end
    esigns = []
    efuncs = [blankline]
    for func in o._func_table.values():
        if not func.local:
            continue
        esigns.append(c.FunctionDeclaration(
            c.Value(func.root.retval, func.root.name),
            self._args_decl(func.root.parameters)))
        efuncs += [func.root.ccode, blankline]

    # Header files, extra definitions, ...
    header = [c.Line(line) for line in o._headers]
    includes = [c.Include(inc, system=False) for inc in o._includes]
    includes.append(blankline)

    typedecls = [p._C_typedecl for p in o.parameters if p._C_typedecl is not None]
    cdefs = filter_sorted(typedecls, key=lambda decl: decl.tpname)
    if o._compiler.src_ext == 'cpp':
        cdefs += [c.Extern('C', signature)]
    cdefs = [item for cdef in cdefs for item in (cdef, blankline)]

    return c.Module(header + includes + cdefs + esigns + [blankline, kernel] + efuncs)
def ns(self, parent):
    """
    Open the scope described by ``self.namespace`` inside `parent`.

    A ``c.Collection`` holding the line ``namespace <namespace[0]>`` is
    always appended to `parent`. If ``self.namespace`` has a second entry, a
    ``c.Block`` starting with the line ``class <namespace[1]>`` is appended
    to that collection and returned. Otherwise `parent` itself is returned.

    NOTE(review): in the single-entry case, callers append further content
    to `parent` rather than to the namespace collection. Confirm this is
    intended.
    """
    namespace_scope = c.Collection([c.Line(f'namespace {self.namespace[0]}')])
    parent.append(namespace_scope)

    if len(self.namespace) <= 1:
        return parent

    class_scope = c.Block([c.Line(f'class {self.namespace[1]}')])
    namespace_scope.append(class_scope)
    return class_scope
def visit_Iteration(self, o):
    """
    Lower an Iteration node to a C ``for`` loop.

    The inclusive bounds are ``o.limits[0]`` and ``o.limits[1]``, each
    shifted by the matching entry of ``o.offsets`` when that offset is
    non-zero. ``o.limits[2]`` is the step. Forward loops use ``<=``/``+=``;
    Backward loops start at the upper bound and use ``>=``/``-=``.

    Each entry of ``o.uindices`` is initialized to its ``symbolic_start`` and
    assigned its ``symbolic_incr`` in the increment clause. Pragmas, if any,
    are emitted in front of the loop.
    """
    body = flatten(self.visit(i) for i in o.children)

    # Fold the offsets into the bounds. The addition already yields an
    # int or a symbolic expression that `ccode` prints directly. The
    # previous `eval(str(...))` round-trip resolved names against this
    # module's globals and did not catch SyntaxError.
    start = o.limits[0] + o.offsets[0] if o.offsets[0] != 0 else o.limits[0]
    end = o.limits[1] + o.offsets[1] if o.offsets[1] != 0 else o.limits[1]

    # Print the step through `ccode`, like the bounds, so that symbolic
    # steps are emitted in C syntax
    step = ccode(o.limits[2])

    # For backward direction flip loop bounds
    if o.direction == Backward:
        loop_init = 'int %s = %s' % (o.index, ccode(end))
        loop_cond = '%s >= %s' % (o.index, ccode(start))
        loop_inc = '%s -= %s' % (o.index, step)
    else:
        loop_init = 'int %s = %s' % (o.index, ccode(start))
        loop_cond = '%s <= %s' % (o.index, ccode(end))
        loop_inc = '%s += %s' % (o.index, step)

    # Append unbounded indices, if any
    if o.uindices:
        uinit = ['%s = %s' % (i.name, ccode(i.symbolic_start)) for i in o.uindices]
        loop_init = c.Line(', '.join([loop_init] + uinit))
        ustep = ['%s = %s' % (i.name, ccode(i.symbolic_incr)) for i in o.uindices]
        loop_inc = c.Line(', '.join([loop_inc] + ustep))

    # Create For header+body
    handle = c.For(loop_init, loop_cond, loop_inc, c.Block(body))

    # Attach pragmas, if any
    if o.pragmas:
        handle = c.Module(o.pragmas + (handle,))

    return handle
def process(self, iet):
    """
    Lower every SyncSpot in `iet` into actual synchronization code.

    The SyncOps of each SyncSpot are grouped by type and dispatched, in a
    fixed type order, to the matching ``_make_*`` callback. WaitPrefetch ops
    are handled in a second pass, since they depend on objects produced by
    the first. Callbacks accumulate init/finalize code, efuncs and objects in
    a shared ``pieces`` record. The init code is prepended and the finalize
    code appended to ``iet.body``.

    Returns
    -------
    (iet, dict)
        The rewritten IET and its metadata: ``efuncs``, the ``pthread.h``
        include, and as ``args`` the thread sizes that are not integers.
        ``(iet, {})`` is returned unchanged when there are no SyncSpots.
    """
    sync_spots = FindNodes(SyncSpot).visit(iet)
    if not sync_spots:
        return iet, {}

    # The SyncOps are to be processed in the following order
    order = [WaitLock, WithLock, Delete, FetchUpdate, FetchPrefetch,
             PrefetchUpdate, WaitPrefetch]

    def key(s):
        return order.index(s)

    callbacks = {
        WaitLock: self._make_waitlock,
        WithLock: self._make_withlock,
        Delete: self._make_delete,
        FetchUpdate: self._make_fetchupdate,
        FetchPrefetch: self._make_fetchprefetch,
        PrefetchUpdate: self._make_prefetchupdate
    }
    postponed_callbacks = {WaitPrefetch: self._make_waitprefetch}
    all_callbacks = [callbacks, postponed_callbacks]

    pieces = namedtuple('Pieces', 'init finalize funcs objs')([], [], [], Objs())

    # The processing is a two-step procedure; first, we apply the `callbacks`;
    # then, the `postponed_callbacks`, as these depend on objects produced by the
    # `callbacks`
    subs = {}
    for cbks in all_callbacks:
        for n in sync_spots:
            mapper = as_mapper(n.sync_ops, lambda i: type(i))
            for _type in sorted(mapper, key=key):
                # Look the callback up *before* calling it, so that a
                # KeyError raised inside a callback propagates instead of
                # being mistaken for "no callback for this type in this pass"
                cbk = cbks.get(_type)
                if cbk is None:
                    continue
                subs[n] = cbk(subs.get(n, n), mapper[_type], pieces, iet)

    iet = Transformer(subs).visit(iet)

    # Add initialization and finalization code
    init = List(body=pieces.init, footer=c.Line())
    finalize = List(header=c.Line(), body=pieces.finalize)
    body = iet.body._rebuild(body=(init,) + iet.body.body + (finalize,))
    iet = iet._rebuild(body=body)

    return iet, {
        'efuncs': pieces.funcs,
        'includes': ['pthread.h'],
        'args': [i.size for i in pieces.objs.threads if not is_integer(i.size)]
    }
def _generate_lib_outer_loop(self):
    """
    Build the OpenMP-parallel triple cell loop and store it as
    ``self._components['LIB_OUTER_LOOP']``.

    The x/y/z cell indices (``LIB_CELL_CX/CY/CZ``) each run from
    ``N_CELL_PAD`` up to ``N_CELL_<AXIS> - N_CELL_PAD`` (exclusive). The
    innermost body is gather, inner loop, scatter. The pragma requests
    ``collapse(3)``, dynamic scheduling and a ``+`` reduction on a local
    copy (``_`` + ``EXEC_COUNT``). That copy is added to ``*EXEC_COUNT``
    after the loop.
    """
    block = cgen.Block([
        self._components['LIB_KERNEL_GATHER'],
        self._components['LIB_INNER_LOOP'],
        self._components['LIB_KERNEL_SCATTER'],
    ])

    cx = self._components['LIB_CELL_CX']
    cy = self._components['LIB_CELL_CY']
    cz = self._components['LIB_CELL_CZ']
    ncx = self._components['N_CELL_X']
    ncy = self._components['N_CELL_Y']
    ncz = self._components['N_CELL_Z']
    exec_count = self._components['EXEC_COUNT']
    red_exec_count = '_' + exec_count
    npad = self._components['N_CELL_PAD']

    # `','.join` yields '' for no symbols, matching the old concat-and-trim
    shared = ','.join(self._components['OMP_SHARED_SYMS'])

    pragma = cgen.Pragma(
        'omp parallel for default(none) reduction(+:' + red_exec_count +
        ') schedule(dynamic) collapse(3) ' + 'shared(' + shared + ')')
    if runtime.OMP_NUM_THREADS is None:
        # With no OMP_NUM_THREADS set at runtime, emit the pragma as a comment
        pragma = cgen.Comment(pragma)

    def cell_loop(var, ncell, body):
        # for (INT64 var=npad; var<ncell-npad; var++) body
        return cgen.For('INT64 ' + var + '=' + npad,
                        var + '<' + ncell + '-' + npad,
                        var + '++', body)

    z_loop = cell_loop(cz, ncz, block)
    y_loop = cell_loop(cy, ncy, cgen.Block((z_loop,)))
    x_loop = cell_loop(cx, ncx, cgen.Block([y_loop]))

    loop = cgen.Module([
        cgen.Line('omp_set_num_threads(_NUM_THREADS);'),
        cgen.Line('INT64 ' + red_exec_count + ' = 0;'),
        pragma,
        x_loop,
        cgen.Line('*' + exec_count + ' += ' + red_exec_count + ';'),
    ])
    self._components['LIB_OUTER_LOOP'] = loop
def left_contractions(self, pos):
    """Generates the code computing the left-contraction part of the
    optimization matrix for site nr. `pos`

    :param pos: Index of the site (should be `< len(X)`)
    :returns: List containing cgen Statements

    """
    if pos == 0:
        return [c.Statement('left_c[0] = 1')]

    result = self.copy_ltens_to_share(0)
    result += [c.Line()]

    gemv_template = ('dgemv(blasNoTranspose, x_shared, current_row + {offset:d}, '
                     '{dim_out:d}, {dim_in:d}, {target:})')
    first_step = gemv_template.format(offset=0, dim_out=self._ranks[0],
                                      dim_in=self._dims[0], target='left_c')
    guard = 'mid < %i' % self._meas
    # We need to check this every time and can't simply return since
    # otherwise __syncthreads crashes
    result += [c.If(guard, c.Statement(first_step))]

    for site in range(1, pos):
        result += self.copy_ltens_to_share(site)
        result += [c.Line()]
        rank_l = self._ranks[site - 1]
        rank_r = self._ranks[site]
        # Since we assume A to consist of product measurements
        step = c.Block([
            c.Statement(gemv_template.format(offset=sum(self._dims[:site]),
                                             dim_out=rank_l * rank_r,
                                             dim_in=self._dims[site],
                                             target='tmat_c')),
            c.Statement('dgemv(blasTranspose, tmat_c, left_c, {rank_l}, {rank_r}, buf_c)'
                        .format(rank_l=rank_l, rank_r=rank_r)),
            c.Statement('memcpy(left_c, buf_c, sizeof({ctype}) * {rank_r})'
                        .format(ctype=c.dtype_to_ctype(self._dtype), rank_r=rank_r)),
        ])
        result += [c.If(guard, step), c.Line()]

    return result
def test_list_denesting():
    """
    A List whose body is a List absorbs the inner header and ends up with an
    empty body. ``_rebuild`` preserves this, with or without ``args_frozen``.
    """
    outer = List(header=cgen.Line('a'), body=List(header=cgen.Line('b')))
    rebuilt = outer._rebuild(body=List(header=cgen.Line('c')))
    assert len(outer.body) == 0
    assert len(rebuilt.body) == 0
    assert str(rebuilt) == "a\nb\nc"

    again = rebuilt._rebuild(rebuilt.body)
    assert len(again.body) == 0
    assert str(again) == str(rebuilt)

    frozen = again._rebuild(again.body, **again.args_frozen)
    assert len(frozen.body) == 0
    assert str(frozen) == str(again)
def right_contractions(self, pos):
    """Generates the code computing the right-contraction part of the
    optimization matrix for site nr. `pos`

    :param pos: Index of the site (should be `< len(X)`)
    :returns: List containing cgen Statements

    """
    last_site = self._sites - 1
    if pos == last_site:
        return [c.Statement('right_c[0] = 1')]

    result = self.copy_ltens_to_share(last_site)
    result += [c.Line()]

    gemv_template = ('dgemv(blasNoTranspose, x_shared, current_row + {offset:d}, '
                     '{dim_out:d}, {dim_in:d}, {target:})')
    first_step = gemv_template.format(offset=sum(self._dims[:-1]),
                                      dim_out=self._ranks[-1],
                                      dim_in=self._dims[-1],
                                      target='right_c')
    guard = 'mid < %i' % self._meas
    result += [c.If(guard, c.Statement(first_step))]

    # Sweep from the second-to-last site down to, but excluding, `pos`
    for site in range(last_site - 1, pos, -1):
        result += self.copy_ltens_to_share(site)
        result += [c.Line()]
        rank_l = self._ranks[site - 1]
        rank_r = self._ranks[site]
        # Since we assume A to consist of product measurements
        step = c.Block([
            c.Statement(gemv_template.format(offset=sum(self._dims[:site]),
                                             dim_out=rank_l * rank_r,
                                             dim_in=self._dims[site],
                                             target='tmat_c')),
            c.Statement('dgemv(blasNoTranspose, tmat_c, right_c, {rank_l}, {rank_r}, buf_c)'
                        .format(rank_l=rank_l, rank_r=rank_r)),
            c.Statement('memcpy(right_c, buf_c, sizeof({ctype}) * {rank_l})'
                        .format(ctype=c.dtype_to_ctype(self._dtype), rank_l=rank_l)),
        ])
        result += [c.If(guard, step), c.Line()]

    return result
def _make_delete(self, iet, sync_ops, *args):
    """
    Append deletion clauses for the Functions of `sync_ops` after `iet`.

    For each SyncOp, the index mask uses ``(op.fetch, op.size)`` for the
    dimension sharing a root with ``op.dim`` and ``FULL`` for every other
    dimension. The clause itself comes from ``self.lang._map_delete``.
    Extra positional arguments are ignored.

    Returns a List made of a blank line, `iet`, a blank line and the
    deletion clauses.
    """
    deletions = []
    for op in sync_ops:
        dims = op.dimensions
        fetch = op.fetch
        imask = [FULL if d.root is not op.dim.root else (fetch, op.size)
                 for d in dims]
        deletions.append(self.lang._map_delete(op.function, imask))

    # Glue together the new IET pieces
    return List(header=c.Line(), body=iet, footer=[c.Line()] + deletions)
def _generate_kernel_func(self):
    """
    Store in ``self._components['KERNEL_FUNC']`` an ``inline void``
    function named ``k_<kernel name>``. Its parameters are
    ``KERNEL_ARG_DECLS`` and its body is the raw kernel code.
    """
    return_and_name = cgen.Value("void", 'k_' + self._kernel.name)
    declaration = cgen.FunctionDeclaration(
        cgen.DeclSpecifier(return_and_name, 'inline'),
        self._components['KERNEL_ARG_DECLS'])
    body = cgen.Block([cgen.Line(self._kernel.code)])
    self._components['KERNEL_FUNC'] = cgen.FunctionBody(declaration, body)
def _generate_lib_outer_loop(self):
    """
    Build the OpenMP-parallel loop over local particle pairs and store it as
    ``self._components['LIB_OUTER_LOOP']``.

    ``LIB_PAIR_INDEX_0`` runs from 0 to ``_N_LOCAL`` (exclusive). The loop
    body is gather, inner loop, scatter. When ``runtime.OMP_NUM_THREADS`` is
    None, the pragma is emitted as a comment instead.
    """
    block = cgen.Block([
        self._components['LIB_KERNEL_GATHER'],
        self._components['LIB_INNER_LOOP'],
        self._components['LIB_KERNEL_SCATTER'],
    ])
    i = self._components['LIB_PAIR_INDEX_0']

    # `','.join` yields '' for no symbols, matching the old concat-and-trim
    shared = ','.join(self._components['OMP_SHARED_SYMS'])

    # NOTE(review): the `//` makes `default(shared) shared(...)` a C comment,
    # which is presumably stripped before the pragma is parsed. Confirm the
    # shared clause is meant to be disabled.
    pragma = cgen.Pragma('omp parallel for schedule(static) // default(shared) shared(' + shared + ')')
    if runtime.OMP_NUM_THREADS is None:
        pragma = cgen.Comment(pragma)

    loop = cgen.Module([
        cgen.Line('omp_set_num_threads(_NUM_THREADS);'),
        pragma,
        cgen.For('int ' + i + '=0', i + '<_N_LOCAL', i + '++', block),
    ])
    self._components['LIB_OUTER_LOOP'] = loop
def test_transformer_wrap(exprs, block1, block2, block3):
    """Basic transformer test that wraps an expression in comments"""
    opening = '// This is the opening comment'
    closing = '// This is the closing comment'

    def wrap(node):
        return Block(c.Line(opening), node, c.Line(closing))

    transformer = Transformer({exprs[0]: wrap(exprs[0])})
    for original in (block1, block2, block3):
        newcode = str(transformer.visit(original).ccode)
        # `s.count('\n') + 1` equals `len(s.split('\n'))`
        old_lines = str(original.ccode).count('\n') + 1
        new_lines = newcode.count('\n') + 1
        assert new_lines >= old_lines + 2
        assert opening in newcode
        assert closing in newcode
        assert "a[i] = a[i] + b[i] + 5.0F;" in newcode
def _generate_lib_inner_loop_block(self):
    """
    Build the j-cell gather code and store it as
    ``self._components['J_GATHER']``.

    For every ParticleDat in ``self._dat_dict``, regardless of access mode,
    a ``DSLStrideGather`` is emitted from the dat into its temporary from
    ``PARTICLE_DAT_PARTITION.jdict``. The gathers are followed by a
    ``<CCC_1>++;`` line, all wrapped in ``CELL_LIST_ITER`` over cell
    ``LIB_CELL_INDEX_1`` (presumably iterating that cell's particles with
    index ``_tmp_jgpx``). ``CCC_1`` is declared as ``INT64`` initialized to 0
    just before the iteration.
    """
    cj = self._components['LIB_CELL_INDEX_1']
    j_gather = cgen.Module([
        cgen.Comment('#### Pre kernel j gather ####'),
    ])

    src_sym = '_tmp_jgpx'
    dst_sym = self._components['CCC_1']

    # One strided gather per ParticleDat
    inner_l = []
    for symbol, entry in self._dat_dict.items():
        obj = entry[0]
        if issubclass(type(obj), data.ParticleDat):
            tsym = self._components['PARTICLE_DAT_PARTITION'].jdict[symbol]
            inner_l.append(
                DSLStrideGather(symbol, tsym, obj.ncomp, src_sym, dst_sym,
                                self._components['CCC_MAX']))
    inner_l.append(cgen.Line(dst_sym + '++;'))

    inner = cgen.Module(inner_l)
    g = self._components['CELL_LIST_ITER'](src_sym, cj, inner)
    j_gather.append(cgen.Initializer(cgen.Value('INT64', dst_sym), '0'))
    j_gather.append(g)
    self._components['J_GATHER'] = j_gather
def visit_Operator(self, o, mode='all'):
    """
    Lower an Operator to a complete C translation unit (a ``c.Module``).

    The module contains, in order:

    - ``#define``s from ``o._headers``, then a blank line;
    - the includes from ``self._operator_includes(o)``, then a blank line;
    - the type declarations from ``self._operator_typedecls(o, mode)``. When
      `mode` is 'all' or 'public' and the source extension is cpp or cu, an
      ``extern "C"`` copy of the kernel signature is added. Each declaration
      is followed by a blank line;
    - prototypes of the local elemental functions, with their prefixes;
    - the kernel, whose body ends with a blank line and ``return 0``;
    - the bodies of the local elemental functions.

    Note: `blankline` is not defined here; it comes from an enclosing scope.
    """
    # Kernel signature and body
    body = flatten(self._visit(child) for child in o.children)
    signature = c.FunctionDeclaration(c.Value(o.retval, o.name),
                                      self._args_decl(o.parameters))
    kernel = c.FunctionBody(signature,
                            c.Block(body + [c.Line(), c.Statement("return 0")]))

    # Elemental functions
    esigns = []
    efuncs = [blankline]
    for func in o._func_table.values():
        if not func.local:
            continue
        qualifiers = ' '.join(func.root.prefix + (func.root.retval,))
        esigns.append(c.FunctionDeclaration(c.Value(qualifiers, func.root.name),
                                            self._args_decl(func.root.parameters)))
        efuncs += [self._visit(func.root), blankline]

    # Definitions
    headers = [c.Define(*define) for define in o._headers] + [blankline]

    # Header files
    includes = self._operator_includes(o) + [blankline]

    # Type declarations
    typedecls = self._operator_typedecls(o, mode)
    wants_extern_c = (mode in ('all', 'public')
                      and o._compiler.src_ext in ('cpp', 'cu'))
    if wants_extern_c:
        typedecls.append(c.Extern('C', signature))
    typedecls = [item for decl in typedecls for item in (decl, blankline)]

    return c.Module(headers + includes + typedecls + esigns
                    + [blankline, kernel] + efuncs)
def _generate_kernel_scatter(self):
    """
    Build the post-kernel scatter code and store it as
    ``self._components['LIB_KERNEL_SCATTER']``.

    For every ParticleDat in ``self._dat_dict`` whose access mode has
    ``write`` set, a ``DSLStrideScatter`` is emitted between the dat and its
    temporary from ``PARTICLE_DAT_PARTITION.idict`` (presumably copying the
    temporary back into the dat). The scatters are followed by a ``_shpx++;``
    line, all wrapped in ``CELL_LIST_ITER`` over cell ``LIB_CELL_INDEX_0``
    with index ``_sgpx``. ``_shpx`` is declared as ``INT64`` initialized to 0
    just before the iteration.
    """
    kernel_scatter = cgen.Module(
        [cgen.Comment('#### Post kernel scatter ####')])
    ci = self._components['LIB_CELL_INDEX_0']
    src_sym = '_sgpx'
    dst_sym = '_shpx'

    # One strided scatter per ParticleDat the kernel writes
    inner_l = []
    for symbol, entry in self._dat_dict.items():
        obj = entry[0]
        mode = entry[1]
        if issubclass(type(obj), data.ParticleDat) and mode.write:
            tsym = self._components['PARTICLE_DAT_PARTITION'].idict[symbol]
            inner_l.append(
                DSLStrideScatter(tsym, symbol, obj.ncomp, dst_sym, src_sym,
                                 self._components['CCC_MAX']))
    inner_l.append(cgen.Line(dst_sym + '++;'))

    inner = cgen.Module(inner_l)
    g = self._components['CELL_LIST_ITER'](src_sym, ci, inner)
    kernel_scatter.append(
        cgen.Initializer(cgen.Value('INT64', dst_sym), '0'))
    kernel_scatter.append(g)
    self._components['LIB_KERNEL_SCATTER'] = kernel_scatter
def _make_parregion(self, partree, parrays):
    """
    Wrap `partree` in an ``OpenMPRegion`` with ``partree.nthreads`` threads.

    Arrays used in `partree` that are both ``_mem_heap`` and ``_mem_local``
    (thread-private arrays on the heap) are mapped to shared, vector-expanded
    PointerArrays with one entry per ``self.threadid``. These are created on
    demand and cached in `parrays`. If there are any, a ``Dereference`` for
    each is placed, with the thread-id setup from ``self._make_tid`` as
    header, in front of `partree`.
    """
    arrays = [s for s in FindSymbols().visit(partree) if s.is_Array]
    heap_private = [a for a in arrays if a._mem_heap and a._mem_local]

    heap_globals = []
    for array in heap_private:
        if array in parrays:
            pointer = parrays[array]
        else:
            pointer = parrays.setdefault(
                array,
                PointerArray(name=self.sregistry.make_name(),
                             dimensions=(self.threadid,),
                             array=array))
        heap_globals.append(Dereference(array, pointer))

    if not heap_globals:
        return OpenMPRegion(partree, partree.nthreads)

    body = List(header=self._make_tid(self.threadid),
                body=heap_globals + [partree],
                footer=c.Line())
    return OpenMPRegion(body, partree.nthreads)
def _generate_kernel_call(self):
    """
    Build the kernel-call code and store it as
    ``self._components['LIB_KERNEL_CALL']``.

    The call is ``k_<kernel name>(<args>);``. The arguments are, in order,
    the keys of ``self._kernel.static_args`` (if not None), then the
    ``kernel_arg`` of each dat's ``PARTICLE_DAT_C`` entry in
    ``self._dat_dict`` order, joined by commas. For each dat, its
    ``kernel_create_j_arg`` code is emitted before the call. As a side
    effect, its ``kernel_create_i_arg`` / ``kernel_create_i_scatter`` are
    appended to ``KERNEL_GATHER`` / ``KERNEL_SCATTER``.
    """
    kernel_call = cgen.Module(
        [cgen.Comment('#### Kernel call arguments ####')])
    kernel_call_symbols = []

    if self._kernel.static_args is not None:
        kernel_call_symbols.extend(name for name, _ in self._kernel.static_args.items())

    for symbol, _ in self._dat_dict.items():
        g = self._components['PARTICLE_DAT_C'][symbol]
        kernel_call_symbols.append(g.kernel_arg)
        kernel_call.append(g.kernel_create_j_arg)
        self._components['KERNEL_GATHER'] += g.kernel_create_i_arg
        self._components['KERNEL_SCATTER'] += g.kernel_create_i_scatter

    kernel_call.append(cgen.Comment('#### Kernel call ####'))
    # `','.join` yields '' for no arguments, matching the old concat-and-trim
    kernel_call_symbols_s = ','.join(kernel_call_symbols)
    kernel_call.append(
        cgen.Line('k_' + self._kernel.name + '(' + kernel_call_symbols_s + ');'))
    self._components['LIB_KERNEL_CALL'] = kernel_call
def _(iet):
    """
    Prepend device-setup code to `iet`.

    The code first runs the language's ``init`` call, if the language
    provides one that accepts the device type. It then selects the device
    with ``set-device``:

    - with an ``MPICommObject`` parameter, use ``deviceid`` when it is not
      -1, else ``rank % ngpus``, where ``rank`` comes from ``MPI_Comm_rank``
      and ``ngpus`` from the language's ``num-devices`` call;
    - without one, call ``set-device`` with ``deviceid`` only when it is not
      -1.

    The setup is framed by begin/end comments and a trailing blank line.

    Returns
    -------
    (iet, dict)
        The rebuilt IET and ``{'args': deviceid}``.
    """
    # TODO: we need to pick the rank from `comm_shm`, not `comm`,
    # so that we have nranks == ngpus (as long as the user has launched
    # the right number of MPI processes per node given the available
    # number of GPUs per node)
    objcomm = next((p for p in iet.parameters if isinstance(p, MPICommObject)), None)

    devicetype = as_list(self.lang[self.platform])

    try:
        lang_init = [self.lang['init'](devicetype)]
    except TypeError:
        # Not all target languages need to be explicitly initialized
        lang_init = []

    deviceid = DeviceID()
    has_device = CondNe(deviceid, -1)

    if objcomm is None:
        body = lang_init + [
            Conditional(has_device, self.lang['set-device']([deviceid] + devicetype))
        ]
        header = c.Comment('Begin of %s setup' % self.lang['name'])
        footer = c.Comment('End of %s setup' % self.lang['name'])
    else:
        rank = Symbol(name='rank')
        rank_decl = LocalExpression(DummyEq(rank, 0))
        rank_init = Call('MPI_Comm_rank', [objcomm, Byref(rank)])

        ngpus = Symbol(name='ngpus')
        num_devices = self.lang['num-devices'](devicetype)
        ngpus_init = LocalExpression(DummyEq(ngpus, num_devices))

        set_given = self.lang['set-device']([deviceid] + devicetype)
        set_by_rank = self.lang['set-device']([rank % ngpus] + devicetype)

        fallback = List(body=[rank_decl, rank_init, ngpus_init, set_by_rank])
        body = lang_init + [Conditional(has_device, set_given, fallback)]
        header = c.Comment('Begin of %s+MPI setup' % self.lang['name'])
        footer = c.Comment('End of %s+MPI setup' % self.lang['name'])

    init = List(header=header, body=body, footer=(footer, c.Line()))
    iet = iet._rebuild(body=(init,) + iet.body)

    return iet, {'args': deviceid}
def generate(self):
    """Return the source code of the module as a ``cgen.Module``.

    The module's contents are ``self.preamble``, then a blank line, then
    ``self.body``. Both must be lists, since they are concatenated with a
    list. Nothing is yielded; the whole module is returned at once.
    """
    return cgen.Module(self.preamble + [cgen.Line()] + self.body)
def _make_parregion(self, partree, parrays):
    """
    Return ``self.Region(partree)``, vector-expanding written Arrays if needed.

    When at least one collapsed Iteration of `partree` is ParallelPrivate,
    every Array written by an Expression inside `partree` is vector-expanded
    through a PointerArray over ``self.threadid``. The PointerArray is
    created on demand and cached in `parrays`. The thread-id initializer and
    the resulting ``VExpanded`` nodes are prepended to ``partree.prefix``.
    """
    if not any(i.is_ParallelPrivate for i in partree.collapsed):
        return self.Region(partree)

    # Vector-expand all written Arrays within `partree`, since at least
    # one of the parallelized Iterations requires thread-private Arrays
    # E.g. a(x, y) -> b(tid, x, y), where `tid` is the ThreadID Dimension
    exprs = FindNodes(Expression).visit(partree)
    # An Array may be written by several Expressions. Expand it only once,
    # otherwise one VExpanded per writing Expression would be emitted.
    # `dict.fromkeys` de-duplicates while keeping first-seen order.
    warrays = list(dict.fromkeys(i.write for i in exprs if i.write.is_Array))

    vexpandeds = []
    for i in warrays:
        if i in parrays:
            pi = parrays[i]
        else:
            pi = parrays.setdefault(i, PointerArray(name=self.sregistry.make_name(),
                                                    dimensions=(self.threadid,),
                                                    array=i))
        vexpandeds.append(VExpanded(i, pi))

    if vexpandeds:
        init = c.Initializer(c.Value(self.threadid._C_typedata, self.threadid.name),
                             self.lang['thread-num'])
        prefix = List(header=init,
                      body=vexpandeds + list(partree.prefix),
                      footer=c.Line())
        partree = partree._rebuild(prefix=prefix)

    return self.Region(partree)
def _make_parregion(self, partree, parrays):
    """
    Return ``self.Region(partree)``, vector-expanding written Arrays if needed.

    When at least one collapsed Iteration of `partree` is ParallelPrivate,
    every Array or TempFunction written by an Expression inside `partree` is
    vector-expanded through ``i._make_pointer(dim=self.threadid)``. The
    pointer is created on demand and cached in `parrays`. The thread-num
    call (bound to ``self.threadid``) and the resulting ``VExpanded`` nodes
    are prepended to ``partree.prefix``.
    """
    if not any(i.is_ParallelPrivate for i in partree.collapsed):
        return self.Region(partree)

    # Vector-expand all written Arrays within `partree`, since at least
    # one of the parallelized Iterations requires thread-private Arrays
    # E.g. a(x, y) -> b(tid, x, y), where `tid` is the ThreadID Dimension
    # A Function may be written by several Expressions. Expand it only once,
    # otherwise one VExpanded per writing Expression would be emitted.
    # `dict.fromkeys` de-duplicates while keeping first-seen order.
    writes = dict.fromkeys(n.write for n in FindNodes(Expression).visit(partree))

    vexpandeds = []
    for i in writes:
        if not (i.is_Array or i.is_TempFunction):
            continue
        elif i in parrays:
            pi = parrays[i]
        else:
            pi = parrays.setdefault(i, i._make_pointer(dim=self.threadid))
        vexpandeds.append(VExpanded(i, pi))

    if vexpandeds:
        init = self.lang['thread-num'](retobj=self.threadid)
        prefix = List(body=[init] + vexpandeds + list(partree.prefix),
                      footer=c.Line())
        partree = partree._rebuild(prefix=prefix)

    return self.Region(partree)
def place_casts(self, iet):
    """
    Create a new IET with the necessary type casts.

    A ``PointerCast`` is emitted for each parameter of `iet` that is a
    Tensor and is either accessed through an Indexed somewhere in `iet` or
    is an ArrayBasic. The casts are prepended to the body, followed by a
    blank line.

    Parameters
    ----------
    iet : Callable
        The input Iteration/Expression tree.

    Returns
    -------
    (iet, dict)
        The (possibly rebuilt) IET and an empty metadata dict.
    """
    symbols = FindSymbols().visit(iet)

    # Make the generated code less verbose by avoiding unnecessary casts
    indexed_names = {s.name for s in FindSymbols('indexeds').visit(iet)}
    need_cast = {s for s in symbols
                 if s.is_Tensor and (s.name in indexed_names or s.is_ArrayBasic)}

    casts = tuple(PointerCast(p) for p in iet.parameters if p in need_cast)
    if not casts:
        return iet, {}

    cast_block = List(body=casts, footer=c.Line())
    return iet._rebuild(body=(cast_block,) + iet.body), {}
def generate(self):
    """
    Return the generated source as a single string.

    The string is ``self.includes``, then a blank line, then
    ``self.objects``, each item rendered with ``str`` and joined by newlines.
    """
    parts = [*self.includes, c.Line(), *self.objects]
    return "\n".join(map(str, parts))