def _make_thread_func(name, iet, root, threads, sregistry):
    """
    Build the function executed by the child thread(s) around `iet`.

    Returns the 3-tuple ``(tfunc, isdata, sdata)``: the ThreadFunction wrapping
    `iet`, the Callable initializing the SharedData, and the SharedData itself.
    """
    # The SharedData is the data structure used by the main thread to pass
    # information down to the child thread(s)
    required, parameters, dynamic_parameters = diff_parameters(iet, root)
    parameters = sorted(parameters, key=lambda p: p.is_Function)  # Allow casting
    sdata = SharedData(
        name=sregistry.make_name(prefix='sdata'),
        npthreads=threads.size,
        fields=required,
        dynamic_fields=dynamic_parameters
    )

    sbase = sdata.symbolic_base
    sid = sdata.symbolic_id

    # The Callable that initializes `sdata` with the known const values; it
    # also stores the thread id and sets the flag to 1
    iname = 'init_%s' % sdata.dtype._type_.__name__
    init_body = [DummyExpr(FieldFromPointer(p._C_name, sbase), p._C_symbol)
                 for p in parameters]
    init_body += [
        BlankLine,
        DummyExpr(FieldFromPointer(sdata._field_id, sbase), sid),
        DummyExpr(FieldFromPointer(sdata._field_flag, sbase), 1),
    ]
    isdata = Callable(iname, init_body, 'void', parameters + [sdata, sid], 'static')

    # Upon thread activation, first load the dynamic SharedData fields ...
    loads = [DummyExpr(p, FieldFromPointer(p.name, sbase))
             for p in dynamic_parameters]
    # ... and, after the actual work, reset the flag
    reset = List(body=[
        BlankLine,
        DummyExpr(FieldFromPointer(sdata._field_flag, sbase), 1),
    ])
    iet = List(body=loads + [iet, reset])

    # The thread has work to do when it receives the signal that all locks have
    # been set to 0 by the main thread
    has_work = CondEq(FieldFromPointer(sdata._field_flag, sbase), 2)
    iet = Conditional(has_work, iet)

    # The thread keeps spinning until the alive flag is set to 0 by the main thread
    is_alive = CondNe(FieldFromPointer(sdata._field_flag, sbase), 0)
    iet = While(is_alive, iet)

    # pthread functions expect exactly one argument, a void*, and must return void*
    tretval = 'void*'
    tparameter = VoidPointer('_%s' % sdata.name)

    # Unpack `sdata`
    unpack = [PointerCast(sdata, tparameter), BlankLine]
    for p in parameters:
        if not p.is_AbstractFunction:
            unpack.append(DummyExpr(p, FieldFromPointer(p.name, sbase)))
            continue
        unpack.extend([Dereference(p, sdata), PointerCast(p)])
    unpack += [
        DummyExpr(sid, FieldFromPointer(sdata._field_id, sbase)),
        BlankLine,
    ]
    iet = List(body=unpack + [iet, BlankLine, Return(Macro('NULL'))])

    tfunc = ThreadFunction(name, iet, tretval, tparameter, 'static')

    return tfunc, isdata, sdata
def _make_thread_activate(threads, sdata, sync_ops, sregistry):
    """
    Build the List of nodes through which the main thread activates a thread
    of `threads`.

    The generated code spins while any SyncLock handle differs from 2 or the
    flag of `sdata[d]` differs from 1; then it copies the dynamic fields into
    `sdata[d]`, sets the SyncLock handles to 0 and sets the flag to 2.
    """
    single = threads.size == 1

    # With more than one thread a fresh index symbol is used, since the spin
    # loop below advances it
    if single:
        d = threads.index
    else:
        d = Symbol(name=sregistry.make_name(prefix=threads.index.name))

    sync_locks = [s for s in sync_ops if s.is_SyncLock]

    # Spin condition: one term per lock, plus the flag check
    terms = [CondNe(lock.handle, 2) for lock in sync_locks]
    terms.append(CondNe(FieldFromComposite(sdata._field_flag, sdata[d]), 1))
    condition = Or(*terms)

    if single:
        body = [While(condition)]
    else:
        # Cycle `d` over the threads while the condition holds
        advance = DummyExpr(d, (d + 1) % threads.size)
        body = [DummyExpr(d, 0), While(condition, advance)]

    # Pass the dynamic fields down to the selected thread
    for field in sdata.dynamic_fields:
        body.append(DummyExpr(FieldFromComposite(field.name, sdata[d]), field))
    # Set the locks to 0
    for lock in sync_locks:
        body.append(DummyExpr(lock.handle, 0))
    # Set the flag to 2
    body.append(DummyExpr(FieldFromComposite(sdata._field_flag, sdata[d]), 2))

    return List(
        header=[c.Line(), c.Comment("Activate `%s`" % threads.name)],
        body=body,
        footer=c.Line()
    )
def _make_thread_init(threads, tfunc, isdata, sdata, sregistry):
    """
    Build the List of nodes that initializes `sdata` and creates the pthreads.

    For each thread, `isdata` is called and then `pthread_create` is issued
    with `tfunc` as start routine and the thread's `sdata` entry as argument.
    When `threads.size` is not 1, the two calls are wrapped in an Iteration
    over `threads.index`.
    """
    d = threads.index

    # A unique identifier for each created pthread
    pthreadid = d + threads.base_id

    # Initialize `sdata`: start from the `isdata` parameters and overwrite the
    # trailing three with the per-thread values
    # NOTE(review): the `_make_thread_func` visible in this file builds
    # `isdata` with only two trailing parameters (`sdata`, `sid`) -- confirm
    # that `isdata.parameters` really ends with (sdata, id, deviceid)
    arguments = list(isdata.parameters)
    arguments[-3] = sdata.symbolic_base + d
    arguments[-2] = pthreadid
    arguments[-1] = sregistry.deviceid
    init_call = Call(isdata.name, arguments)

    # Create pthreads
    create_args = (
        threads.symbolic_base + d,
        Macro('NULL'),
        Call(tfunc.name, [], is_indirect=True),
        sdata.symbolic_base + d,
    )
    create_call = Call('pthread_create', create_args)

    calls = [init_call, create_call]
    if threads.size == 1:
        body = calls
    else:
        body = Iteration(calls, d, threads.size - 1)

    return List(
        header=c.Comment("Fire up and initialize `%s`" % threads.name),
        body=body
    )
def _specialize_iet(self, iet, **kwargs):
    """
    Specialize `iet` for the OPS backend.

    Each group of affine trees found in `iet` is handed to `opsit`; the
    returned kernel is appended to `self._header_functions`, the returned
    declarations are collected, and the trees are replaced in `iet` by the
    returned par-loop call block.

    Side effects: appends ``'ops_seq.h'`` to `self._includes`, and appends
    an ``OPS_<dims>D`` define to `self._headers` when at least one group of
    affine trees was processed.

    Returns a List whose body is the `ops_init` call, the collected
    declarations followed by the transformed `iet`, then the
    `ops_timing_output` and `ops_exit` calls. `kwargs` is accepted but unused.
    """
    self._includes.append('ops_seq.h')

    ops_init = Call("ops_init", [0, 0, 2])
    ops_timing = Call("ops_timing_output", [FunctionPointer("stdout")])
    ops_exit = Call("ops_exit")

    mapper = {}
    global_declarations = []
    dims = None
    for n, (_section, trees) in enumerate(find_affine_trees(iet).items()):
        callable_kernel, declarations, par_loop_call_block, dims = opsit(trees, n)
        global_declarations.extend(declarations)
        self._header_functions.append(callable_kernel)

        # The first tree's root is replaced by the par-loop call block; any
        # other root not yet in the mapper maps to None
        mapper[trees[0].root] = par_loop_call_block
        mapper.update({i.root: mapper.get(i.root) for i in trees})  # Drop trees

    # Without any affine tree `dims` is still None; emitting the define would
    # produce the meaningless `#define OPS_NoneD`
    if dims is not None:
        self._headers.append('#define OPS_%sD' % dims)
    warning("The OPS backend is still work-in-progress")

    global_declarations.append(Transformer(mapper).visit(iet))

    return List(body=[ops_init, *global_declarations, ops_timing, ops_exit])
def make_efunc(name, iet, dynamic_parameters=None, retval='void', prefix='static'):
    """
    Create an ElementalFunction from (a sequence of) perfectly nested Iterations.

    The returned ElementalFunction is built from `name`, the processed `iet`,
    `retval`, the derived parameters, `prefix` and `dynamic_parameters`.
    """
    symbols = FindSymbols().visit(iet)

    # Arrays are by definition (vector) temporaries, so if they are written
    # within `iet`, they can also be declared and allocated within the `efunc`
    local = []
    for expr in FindNodes(Expression).visit(iet):
        if expr.write.is_Array:
            local.append(expr.write)

    # Any other tensor symbol is treated as external
    external = [s for s in symbols if s.is_Tensor and s not in local]

    # Insert array casts
    iet = List(body=[ArrayCast(s) for s in external] + [iet])

    # Insert declarations
    iet = iet_insert_C_decls(iet, external)

    # The Callable parameters, that is everything derived from `iet` except
    # the local Arrays
    params = [p for p in derive_parameters(iet) if p not in local]

    return ElementalFunction(name, iet, retval, params, prefix, dynamic_parameters)
def _make_thread_finalize(threads, sdata):
    """
    Build the List of nodes that waits for the completion of `threads`.

    For each thread, the generated code spins while the flag of `sdata[d]`
    equals 2, then sets the flag to 0 and calls `pthread_join`. When
    `threads.size` is not 1, these nodes are wrapped in an Iteration over
    `threads.index`.
    """
    d = threads.index

    shutdown = [
        # Spin as long as the flag is 2
        While(CondEq(FieldFromComposite(sdata._field_flag, sdata[d]), 2)),
        # Set the flag to 0
        DummyExpr(FieldFromComposite(sdata._field_flag, sdata[d]), 0),
        Call('pthread_join', (threads[d], Macro('NULL'))),
    ]

    if threads.size == 1:
        body = shutdown
    else:
        body = Iteration(shutdown, d, threads.size - 1)

    return List(
        header=c.Comment("Wait for completion of `%s`" % threads.name),
        body=body
    )
def h_ccode(self):
    """
    Return the result of visiting, with `CGen`, a List wrapping the header
    functions stored in `self._header_functions`.
    """
    functions = List(body=self._header_functions)
    generator = CGen()
    return generator.visit(functions)