def generate_kernel_wrapper(self, library, fname, argtypes, debug):
    """
    Build a CUDA kernel entry point around the compiled function ``fname``.

    A new "cuda.kernel.wrapper" module is created.  It holds:

    * a declaration of ``fname`` that returns an int status and takes the
      calling-convention return slot followed by the argument types
      produced by ``self.get_arg_packer(argtypes)``;
    * a void kernel whose name is ``fname`` with the ``cudapy`` namespace
      prepended.  The kernel converts its own arguments and calls ``fname``.

    The module is then added to ``library``, the library is finalized, and
    the kernel function looked up from ``library`` is returned.

    When ``debug`` is true, the kernel also reports errors through int
    globals.  A non-OK status that is not a Python exception is written to
    the ``<kernel>__errcode__`` global.  The recording thread's ``tid`` and
    ``ctaid`` special-register values are written to the
    ``__tid{x,y,z}__`` / ``__ctaid{x,y,z}__`` globals.
    """
    packer = self.get_arg_packer(argtypes)
    lowered_argtys = list(packer.argument_types)
    kernel_fnty = Type.function(Type.void(), lowered_argtys)

    module = self.create_module("cuda.kernel.wrapper")

    # The wrapped function returns an int status and receives the
    # calling-convention return slot ahead of the user arguments.
    status_ty = Type.int()
    callee_params = [self.call_conv.get_return_type(types.pyobject)]
    callee_params.extend(lowered_argtys)
    callee_fnty = Type.function(status_ty, callee_params)
    callee = module.add_function(callee_fnty, name=fname)

    kernel_name = itanium_mangler.prepend_namespace(callee.name, ns='cudapy')
    kernel = module.add_function(kernel_fnty, name=kernel_name)
    irb = Builder(kernel.append_basic_block(''))

    def zeroed_int_global(suffix):
        # Error-reporting slot named after the kernel, initialised to zero.
        slot = module.add_global_variable(Type.int(),
                                          name=kernel.name + suffix)
        slot.initializer = Constant.null(slot.type.pointee)
        return slot

    errcode_gv = zeroed_int_global("__errcode__")
    # Globals are defined interleaved per axis: tid then ctaid for x, y, z.
    per_axis = [(zeroed_int_global("__tid%s__" % axis),
                 zeroed_int_global("__ctaid%s__" % axis))
                for axis in 'xyz']
    tid_gvs = [pair[0] for pair in per_axis]
    ctaid_gvs = [pair[1] for pair in per_axis]

    call_args = packer.from_arguments(irb, kernel.args)
    status, _ = self.call_conv.call_function(
        irb, callee, types.void, argtypes, call_args)

    if debug:
        # Fast path: a successful call returns immediately.
        with cgutils.if_likely(irb, status.is_ok):
            irb.ret_void()

        # Only statuses that are not Python exceptions (i.e. user
        # exceptions) are recorded.
        is_user_error = irb.not_(status.is_python_exc)
        with irb.if_then(is_user_error):
            zero = Constant.null(errcode_gv.type.pointee)
            int_ty = zero.type

            # Stand-in for an atomic compare-and-swap, so that only the
            # first error is kept.  NOTE(review): presumably rewritten into
            # a real cmpxchg later in the pipeline -- confirm.
            cas_fnty = lc.Type.function(int_ty,
                                        [errcode_gv.type, int_ty, int_ty])
            cas_fn = module.add_function(cas_fnty, name="___numba_cas_hack")
            previous = irb.call(cas_fn, [errcode_gv, zero, status.code])
            won = irb.icmp(ICMP_EQ, previous, zero)

            # The thread whose exchange succeeded also saves its indices.
            sreg = nvvmutils.SRegBuilder(irb)
            with irb.if_then(won):
                for read_reg, slots in ((sreg.tid, tid_gvs),
                                        (sreg.ctaid, ctaid_gvs)):
                    for axis, slot in zip('xyz', slots):
                        irb.store(read_reg(axis), slot)

    irb.ret_void()

    nvvm.set_cuda_kernel(kernel)
    library.add_ir_module(module)
    library.finalize()
    return library.get_function(kernel.name)
def generate_kernel_wrapper(self, library, fname, argtypes, debug):
    """
    Build a CUDA kernel entry point around the compiled function ``fname``.

    A new "cuda.kernel.wrapper" module is created.  It holds:

    * a declaration of ``fname`` that returns an int status and takes the
      calling-convention return slot followed by the argument types
      produced by ``self.get_arg_packer(argtypes)``;
    * a void kernel whose name is ``fname`` with the ``cudapy`` namespace
      prepended.  The kernel converts its own arguments and calls ``fname``.

    The module is then added to ``library``, the library is finalized, and
    the kernel function looked up from ``library`` is returned.

    When ``debug`` is true, the kernel also reports errors through int
    globals.  A non-OK status that is not a Python exception is written to
    the ``<kernel>__errcode__`` global.  The recording thread's ``tid`` and
    ``ctaid`` special-register values are written to the
    ``__tid{x,y,z}__`` / ``__ctaid{x,y,z}__`` globals.
    """
    packer = self.get_arg_packer(argtypes)
    lowered_argtys = list(packer.argument_types)
    kernel_fnty = Type.function(Type.void(), lowered_argtys)

    module = self.create_module("cuda.kernel.wrapper")

    # The wrapped function returns an int status and receives the
    # calling-convention return slot ahead of the user arguments.
    status_ty = Type.int()
    callee_params = [self.call_conv.get_return_type(types.pyobject)]
    callee_params.extend(lowered_argtys)
    callee_fnty = Type.function(status_ty, callee_params)
    callee = module.add_function(callee_fnty, name=fname)

    kernel_name = itanium_mangler.prepend_namespace(callee.name, ns='cudapy')
    kernel = module.add_function(kernel_fnty, name=kernel_name)
    irb = Builder(kernel.append_basic_block(''))

    def zeroed_int_global(suffix):
        # Error-reporting slot named after the kernel, initialised to zero.
        slot = module.add_global_variable(Type.int(),
                                          name=kernel.name + suffix)
        slot.initializer = Constant.null(slot.type.pointee)
        return slot

    errcode_gv = zeroed_int_global("__errcode__")
    # Globals are defined interleaved per axis: tid then ctaid for x, y, z.
    per_axis = [(zeroed_int_global("__tid%s__" % axis),
                 zeroed_int_global("__ctaid%s__" % axis))
                for axis in 'xyz']
    tid_gvs = [pair[0] for pair in per_axis]
    ctaid_gvs = [pair[1] for pair in per_axis]

    call_args = packer.from_arguments(irb, kernel.args)
    status, _ = self.call_conv.call_function(
        irb, callee, types.void, argtypes, call_args)

    if debug:
        # Fast path: a successful call returns immediately.
        with cgutils.if_likely(irb, status.is_ok):
            irb.ret_void()

        # Only statuses that are not Python exceptions (i.e. user
        # exceptions) are recorded.
        is_user_error = irb.not_(status.is_python_exc)
        with irb.if_then(is_user_error):
            zero = Constant.null(errcode_gv.type.pointee)
            int_ty = zero.type

            # Stand-in for an atomic compare-and-swap, so that only the
            # first error is kept.  NOTE(review): presumably rewritten into
            # a real cmpxchg later in the pipeline -- confirm.
            cas_fnty = lc.Type.function(int_ty,
                                        [errcode_gv.type, int_ty, int_ty])
            cas_fn = module.add_function(cas_fnty, name="___numba_cas_hack")
            previous = irb.call(cas_fn, [errcode_gv, zero, status.code])
            won = irb.icmp(ICMP_EQ, previous, zero)

            # The thread whose exchange succeeded also saves its indices.
            sreg = nvvmutils.SRegBuilder(irb)
            with irb.if_then(won):
                for read_reg, slots in ((sreg.tid, tid_gvs),
                                        (sreg.ctaid, ctaid_gvs)):
                    for axis, slot in zip('xyz', slots):
                        irb.store(read_reg(axis), slot)

    irb.ret_void()

    nvvm.set_cuda_kernel(kernel)
    library.add_ir_module(module)
    library.finalize()
    return library.get_function(kernel.name)