Example #1
0
def _get_equal(context, module, datamodel, container_element_type):
    """
    Emit an equality function for the values described by ``datamodel``
    into ``module`` and return it.

    Two LLVM functions are obtained through
    ``cgutils.get_or_insert_function``; both names embed
    ``context.fndesc.mangled_name`` and ``container_element_type``:

    * ``..._equal.wrap`` follows the target calling convention.  Its body
      calls the ``operator.eq`` implementation looked up for a pair of
      ``datamodel.fe_type`` values and returns the result cast to int32.
    * ``..._equal`` takes two pointers to the data representation, loads
      both values, calls the wrapper and returns an ``i32``.

    The ``i32`` returned by ``..._equal`` is:

    * the comparison result cast from boolean to int32 when the call
      status is ok,
    * ``0`` when the status is ok and also flagged as none,
    * ``-1`` for any other status.

    The data model must report that it contains an NRT meminfo.
    """
    assert datamodel.contains_nrt_meminfo()

    elem_type = datamodel.fe_type
    elem_ptr_type = datamodel.get_data_type().as_pointer()

    wrapper_fnty = context.call_conv.get_function_type(
        types.int32, [elem_type, elem_type])
    pair_types = [elem_type, elem_type]

    mangled_name = context.fndesc.mangled_name

    # --- wrapper using the target calling convention -------------------
    wrapper_fn = cgutils.get_or_insert_function(
        module,
        wrapper_fnty,
        name='.numba_{}.{}_equal.wrap'.format(mangled_name,
                                              container_element_type))

    wrap_builder = Builder(wrapper_fn.append_basic_block())
    wrap_args = context.call_conv.decode_arguments(
        wrap_builder, pair_types, wrapper_fn)

    eq_sig = typing.signature(types.boolean, elem_type, elem_type)
    eq_fnty = context.typing_context.resolve_value_type(operator.eq)
    # The return value is ignored here.
    # NOTE(review): presumably this call is needed so that the
    # ``get_function`` lookup below can find the implementation -- confirm.
    eq_fnty.get_call_type(context.typing_context, eq_sig.args, {})
    eq_impl = context.get_function(eq_fnty, eq_sig)

    bool_result = eq_impl(wrap_builder, wrap_args)
    int_result = context.cast(
        wrap_builder, bool_result, types.boolean, types.int32)
    context.call_conv.return_value(wrap_builder, int_result)

    # --- pointer-based entry point ---------------------------------------
    equal_fnty = ir.FunctionType(ir.IntType(32),
                                 [elem_ptr_type, elem_ptr_type])
    equal_fn = cgutils.get_or_insert_function(
        module,
        equal_fnty,
        name='.numba_{}.{}_equal'.format(mangled_name,
                                         container_element_type),
    )
    builder = Builder(equal_fn.append_basic_block())

    # Load the left operand first, then the right one.
    operands = [
        datamodel.load_from_data_pointer(builder, equal_fn.args[position])
        for position in (0, 1)
    ]

    status, result = context.call_conv.call_function(
        builder, wrapper_fn, types.boolean, pair_types, operands)

    with builder.if_then(status.is_ok, likely=True):
        with builder.if_then(status.is_none):
            builder.ret(context.get_constant(types.int32, 0))
        builder.ret(
            context.cast(builder, result, types.boolean, types.int32))

    # Reached only when the status is not ok: report an error.
    builder.ret(context.get_constant(types.int32, -1))

    return equal_fn
Example #2
0
    def generate_kernel_wrapper(self, library, fname, argtypes, debug):
        """
        Build a kernel wrapper around the function named ``fname`` and add
        it to ``library``.

        A new "cuda.kernel.wrapper" module is created.  It receives:

        * a function named ``fname`` (no body is added here) whose
          parameters are the calling convention's return slot followed by
          the packed form of ``argtypes``;
        * the void wrapper, named by prepending the ``cudapy`` namespace to
          that function's name, which unpacks its arguments and calls the
          wrapped function;
        * seven int globals named after the wrapper (``__errcode__``,
          ``__tid{x,y,z}__`` and ``__ctaid{x,y,z}__``), each with a null
          initializer.

        When ``debug`` is true the wrapper also checks the call status.  If
        the status is not ok and is not flagged as a Python exception, the
        status code is handed to ``___numba_cas_hack`` together with the
        error-code global; when the value it returns equals the null
        constant, the thread and block indices are stored in the globals.

        The wrapper is marked as a CUDA kernel, the module is added to
        ``library``, the library is finalized, and the function looked up
        in ``library`` under the wrapper's name is returned.
        """
        packer = self.get_arg_packer(argtypes)
        llvm_argtys = list(packer.argument_types)
        wrapper_fnty = Type.function(Type.void(), llvm_argtys)
        module = self.create_module("cuda.kernel.wrapper")

        # The wrapped function takes the return slot first, then the
        # packed arguments.
        return_slot_ty = self.call_conv.get_return_type(types.pyobject)
        kernel_fnty = Type.function(Type.int(),
                                    [return_slot_ty] + llvm_argtys)
        kernel = module.add_function(kernel_fnty, name=fname)

        wrapper_name = itanium_mangler.prepend_namespace(kernel.name,
                                                         ns='cudapy')
        wrapper = module.add_function(wrapper_fnty, name=wrapper_name)
        builder = Builder(wrapper.append_basic_block(''))

        def make_error_global(suffix):
            # Int global named "<wrapper name><suffix>", null-initialized.
            var = module.add_global_variable(Type.int(),
                                             name=wrapper.name + suffix)
            var.initializer = Constant.null(var.type.pointee)
            return var

        errcode_gv = make_error_global("__errcode__")
        tid_gvs, ctaid_gvs = [], []
        # tid/ctaid globals are created interleaved, one axis at a time.
        for axis in 'xyz':
            tid_gvs.append(make_error_global("__tid%s__" % axis))
            ctaid_gvs.append(make_error_global("__ctaid%s__" % axis))

        call_values = packer.from_arguments(builder, wrapper.args)
        status, _ = self.call_conv.call_function(
            builder, kernel, types.void, argtypes, call_values)

        if debug:
            # Normal completion: nothing to record.
            with cgutils.if_likely(builder, status.is_ok):
                builder.ret_void()

            not_python_exc = builder.not_(status.is_python_exc)
            with builder.if_then(not_python_exc):
                # A user exception was raised.
                zero = Constant.null(errcode_gv.type.pointee)

                # Compare-and-swap on the error code, so that only the
                # first error is recorded and later ones do not overwrite
                # it.
                cas_fnty = lc.Type.function(
                    zero.type, [errcode_gv.type, zero.type, zero.type])
                cas_fn = module.add_function(cas_fnty,
                                             name="___numba_cas_hack")
                previous = builder.call(cas_fn,
                                        [errcode_gv, zero, status.code])
                swapped = builder.icmp(ICMP_EQ, previous, zero)

                # On a successful exchange, remember which thread failed.
                sreg = nvvmutils.SRegBuilder(builder)
                with builder.if_then(swapped):
                    for axis, slot in zip("xyz", tid_gvs):
                        builder.store(sreg.tid(axis), slot)

                    for axis, slot in zip("xyz", ctaid_gvs):
                        builder.store(sreg.ctaid(axis), slot)

        builder.ret_void()

        nvvm.set_cuda_kernel(wrapper)
        library.add_ir_module(module)
        library.finalize()
        return library.get_function(wrapper.name)
Example #3
0
    def generate_kernel_wrapper(self, library, fname, argtypes):
        """
        Build a kernel wrapper around the function named ``fname`` and add
        it to ``library``.

        A new "cuda.kernel.wrapper" module is created.  It receives:

        * a function named ``fname`` (no body is added here) whose
          parameters are the calling convention's return slot followed by
          the packed form of ``argtypes``;
        * the void wrapper, named ``"cudaPy_"`` plus that function's name,
          which unpacks its arguments and calls the wrapped function;
        * seven int globals named after the wrapper (``__errcode__``,
          ``__tid{x,y,z}__`` and ``__ctaid{x,y,z}__``), each with a null
          initializer.

        The wrapper always checks the call status.  If the status is not ok
        and is not flagged as a Python exception, the status code is handed
        to ``___numba_cas_hack`` together with the error-code global; when
        the value it returns equals the null constant, the thread and block
        indices are stored in the globals.

        The wrapper is marked as a CUDA kernel, the module is added to
        ``library``, the library is finalized, and the function looked up
        in ``library`` under the wrapper's name is returned.
        """
        packer = self.get_arg_packer(argtypes)
        llvm_argtys = list(packer.argument_types)
        wrapper_fnty = Type.function(Type.void(), llvm_argtys)
        module = self.create_module("cuda.kernel.wrapper")

        # The wrapped function takes the return slot first, then the
        # packed arguments.
        return_slot_ty = self.call_conv.get_return_type(types.pyobject)
        kernel_fnty = Type.function(Type.int(),
                                    [return_slot_ty] + llvm_argtys)
        kernel = module.add_function(kernel_fnty, name=fname)

        wrapper = module.add_function(wrapper_fnty,
                                      name="cudaPy_" + kernel.name)
        builder = Builder(wrapper.append_basic_block(''))

        def make_error_global(suffix):
            # Int global named "<wrapper name><suffix>", null-initialized.
            var = module.add_global_variable(Type.int(),
                                             name=wrapper.name + suffix)
            var.initializer = Constant.null(var.type.pointee)
            return var

        errcode_gv = make_error_global("__errcode__")
        tid_gvs, ctaid_gvs = [], []
        # tid/ctaid globals are created interleaved, one axis at a time.
        for axis in 'xyz':
            tid_gvs.append(make_error_global("__tid%s__" % axis))
            ctaid_gvs.append(make_error_global("__ctaid%s__" % axis))

        call_values = packer.from_arguments(builder, wrapper.args)
        status, _ = self.call_conv.call_function(
            builder, kernel, types.void, argtypes, call_values)

        # Normal completion: nothing to record.
        with cgutils.if_likely(builder, status.is_ok):
            builder.ret_void()

        not_python_exc = builder.not_(status.is_python_exc)
        with builder.if_then(not_python_exc):
            # A user exception was raised.
            zero = Constant.null(errcode_gv.type.pointee)

            # Compare-and-swap on the error code, so that only the first
            # error is recorded and later ones do not overwrite it.
            cas_fnty = lc.Type.function(
                zero.type, [errcode_gv.type, zero.type, zero.type])
            cas_fn = module.add_function(cas_fnty,
                                         name="___numba_cas_hack")
            previous = builder.call(cas_fn, [errcode_gv, zero, status.code])
            swapped = builder.icmp(ICMP_EQ, previous, zero)

            # On a successful exchange, remember which thread failed.
            sreg = nvvmutils.SRegBuilder(builder)
            with builder.if_then(swapped):
                for axis, slot in zip("xyz", tid_gvs):
                    builder.store(sreg.tid(axis), slot)

                for axis, slot in zip("xyz", ctaid_gvs):
                    builder.store(sreg.ctaid(axis), slot)

        builder.ret_void()

        nvvm.set_cuda_kernel(wrapper)
        library.add_ir_module(module)
        library.finalize()
        return library.get_function(wrapper.name)