示例#1
0
    def codegen(context, builder, sig, args):
        """Lower the dict constructor into a call to the native hashmap creator.

        Declares (or reuses) ``hashmap_create_<keypostfix>_to_<valpostfix>``,
        calls it with a pointer to the ``meminfo`` field of a fresh struct of
        type ``sig.return_type``, then sets that struct's ``data_ptr`` from
        the meminfo's data pointer and returns the struct value.

        Only ``sig.return_type`` is read; ``args`` is unused. Type postfixes,
        numeric flags, helper-function addresses and the key/value types are
        captured from the enclosing scope.
        """
        api_table = context.nrt.get_nrt_api(builder)
        intp_ll = context.get_value_type(types.intp)
        dict_struct = cgutils.create_struct_proxy(sig.return_type)(context,
                                                                   builder)

        # Prototype of the native constructor. The order here must match the
        # argument list handed to ``builder.call`` further down.
        i8 = lir.IntType(8)
        i64 = lir.IntType(64)
        param_types = [
            dict_struct.meminfo.type.as_pointer(),  # meminfo slot to fill
            i8.as_pointer(),  # NRT API function table
            i8,  # gen_key flag
            i8,  # gen_value flag
            intp_ll,  # hash function address
            intp_ll,  # equality function address
            intp_ll,  # key incref address
            intp_ll,  # key decref address
            intp_ll,  # value incref address
            intp_ll,  # value decref address
            i64,  # key size (ABI size of the key's LLVM type)
            i64,  # value size (ABI size of the value's LLVM type)
        ]
        create_fn = cgutils.get_or_insert_function(
            builder.module,
            lir.FunctionType(lir.VoidType(), param_types),
            name=f"hashmap_create_{key_type_postfix}_to_{value_type_postfix}",
        )

        # Flags are 1 when the corresponding type is *not* numeric.
        flag_key = context.get_constant(types.int8, types.int8(not key_numeric))
        flag_val = context.get_constant(types.int8, types.int8(not val_numeric))

        # Helper function addresses, baked in as intp constants.
        callbacks = [
            context.get_constant(types.intp, addr)
            for addr in (hash_func_addr, eq_func_addr,
                         key_incref_func_addr, key_decref_func_addr,
                         val_incref_func_addr, val_decref_func_addr)
        ]

        # ABI byte sizes of the key and value representations.
        sizes = [
            context.get_constant(
                types.int64,
                context.get_abi_sizeof(context.get_value_type(nb_ty)))
            for nb_ty in (dict_key_type, dict_val_type)
        ]

        builder.call(
            create_fn,
            [dict_struct._get_ptr_by_name('meminfo'), api_table,
             flag_key, flag_val, *callbacks, *sizes],
        )

        dict_struct.data_ptr = context.nrt.meminfo_data(builder,
                                                        dict_struct.meminfo)
        return dict_struct._getvalue()
示例#2
0
    def codegen(cgctx, builder, sig, args):
        """Build a C callback around the resolved ``fnty`` impl; return its address.

        Resolves ``fnty`` for argument types ``(ty, ty)``, where ``ty`` is
        ``sig.args[0]``. Then compiles a ``@cfunc`` with C signature
        ``int8(void*, void*)`` that dereferences both pointers and invokes
        that resolved implementation. Returns the cfunc's address as an
        ``intp`` constant.

        ``args`` is unused. ``fnty``, ``tyctx``, ``thing`` and ``another``
        are captured from the enclosing scope.
        """
        ty = sig.args[0]
        # trigger resolution to get a "custom_hash" impl based on the call type
        # "ty" and its literal value
        lsig = fnty.get_call_type(tyctx, (ty, ty), {})
        resolved = cgctx.get_function(fnty, lsig)

        # close over resolved function, this is to deal with python scoping
        # NOTE(review): the incoming ``sig`` is ignored; the implementation
        # resolved for ``(ty, ty)`` is called with whatever LLVM args arrive.
        # Confirm ``(thing, another)`` lower to the same types as ``(ty, ty)``.
        def resolved_codegen(cgctx, builder, sig, args):
            return resolved(builder, args)

        # A python function "wrapper" is made for the `@cfunc` arg, this calls
        # the jitted function "wrappee", which will be compiled as part of the
        # compilation chain for the cfunc. In turn the wrappee jitted function
        # has an intrinsic call which is holding reference to the resolved type
        # specialised custom_hash call above.
        @intrinsic
        def dispatcher(_ityctx, _a, _b):
            # Declares int8(thing, another); the argument types are ignored.
            return types.int8(thing, another), resolved_codegen

        @intrinsic
        def deref(_ityctx, _x):
            # Reinterpret the void* argument as a pointer to ``thing`` and load
            # the value. TODO: nrt awareness (the loaded value is not incref'd).
            # NOTE(review): this always yields ``thing``, including for ``bp``
            # in ``wrappee``, while ``dispatcher`` declares its second argument
            # as ``another``. Confirm the two types are meant to coincide.
            catchthing = thing  # plain alias of the captured ``thing``
            sig = catchthing(_x)

            def codegen(cgctx, builder, sig, args):
                # Pointer to the return type (``thing``) so the load is typed.
                toty = cgctx.get_value_type(sig.return_type).as_pointer()
                addressable = builder.bitcast(args[0], toty)
                zero_intpt = cgctx.get_constant(types.intp, 0)
                # Address of element 0, i.e. the same pointer.
                vref = builder.gep(addressable, [zero_intpt], inbounds=True)

                return builder.load(vref)

            return sig, codegen

        @njit
        def wrappee(ap, bp):
            # ap/bp are the raw void* arguments received by the cfunc.
            a = deref(ap)
            b = deref(bp)
            return dispatcher(a, b)

        def wrapper(a, b):
            return wrappee(a, b)

        callback = cfunc(types.int8(types.voidptr, types.voidptr))(wrapper)

        # bake in address as a int const
        # NOTE(review): only the raw address survives this function; confirm
        # the compiled cfunc is kept alive for as long as the generated code
        # may call through this address.
        address = callback.address
        return cgctx.get_constant(types.intp, address)
示例#3
0
 def dispatcher(_ityctx, _a, _b):
     """Return the signature ``int8(thing, another)`` with ``resolved_codegen``.

     The typing context and argument types passed in are not inspected;
     ``thing``, ``another`` and ``resolved_codegen`` come from the
     enclosing scope.
     """
     declared = types.int8(thing, another)
     return declared, resolved_codegen
示例#4
0
                jitted_func = njit(sig)(func)
                setattr(self, typed_name, jitted_func)

        return ncompiler


# Separate OpContainer instances for unary and binary operators.
GrB_UnaryOp, GrB_BinaryOp = OpContainer(), OpContainer()

##################################
# Useful collections of signatures
##################################
_unary_bool = [nt.boolean(nt.boolean)]
# One ``T(T)`` signature per integer width, unsigned listed before signed.
_unary_int = [
    int_ty(int_ty)
    for int_ty in (nt.uint8, nt.int8, nt.uint16, nt.int16,
                   nt.uint32, nt.int32, nt.uint64, nt.int64)
]
_unary_float = [flt_ty(flt_ty) for flt_ty in (nt.float32, nt.float64)]
_unary_all = _unary_bool + _unary_int + _unary_float

_binary_bool = [nt.boolean(nt.boolean, nt.boolean)]
_binary_int = [
    nt.uint8(nt.uint8, nt.uint8),
    nt.int8(nt.int8, nt.int8),
    nt.uint16(nt.uint16, nt.uint16),