Example #1
0
class ClassBuilder(object):
    """
    A jitclass builder for a mutable jitclass.  This will register
    typing and implementation hooks to the given typing and target contexts.
    """
    # Class-level state shared by every builder.  One registry holds the
    # generic method lowerings for all jitclasses.  ``implemented_methods``
    # records the method names that already have a lowering in it, so each
    # name is registered only once no matter how many jitclasses define it.
    class_impl_registry = imputils.Registry()
    implemented_methods = set()

    def __init__(self, class_type, typingctx, targetctx):
        """
        :param class_type: the jitclass type; its ``instance_type`` supplies
            the ``jit_methods`` / ``jit_static_methods`` to register.
        :param typingctx: the typing context (stored on the instance).
        :param targetctx: the target context that ``register()`` installs
            the implementation registry into.
        """
        self.class_type = class_type
        self.typingctx = typingctx
        self.targetctx = targetctx

    def register(self):
        """
        Register to the frontend and backend.

        Adds lowerings for the methods of ``self.class_type.instance_type``
        to the shared ``class_impl_registry``, then installs that registry
        in the target context.
        """
        # Register generic implementations for all jitclasses
        self._register_methods(self.class_impl_registry,
                               self.class_type.instance_type)
        # NOTE other registrations are done at the top-level
        # (see ctor_impl and attr_impl below)
        self.targetctx.install_registry(self.class_impl_registry)

    def _register_methods(self, registry, instance_type):
        """
        Register method implementations.

        This simply registers that the method names are valid methods.  Inside
        of imp() below we retrieve the actual method to run from the type of
        the receiver argument (i.e. self).

        Names already in ``implemented_methods`` are skipped.  A lowering
        registered for a name is keyed on ``types.ClassInstanceType`` in
        general, so it serves every jitclass that defines that name.
        """
        to_register = (list(instance_type.jit_methods)
                       + list(instance_type.jit_static_methods))
        for meth in to_register:
            # There's no way to retrieve the particular method name
            # inside the implementation function, so we have to register a
            # specific closure for each different name
            if meth not in self.implemented_methods:
                self._implement_method(registry, meth)
                self.implemented_methods.add(meth)

    def _implement_method(self, registry, attr):
        """
        Register the lowering for the method named ``attr``.

        At lowering time, the generated implementation looks ``attr`` up on
        the receiver's instance type, trying regular methods first and then
        static methods.  It calls the result through a ``types.Dispatcher``.

        ``__getitem__`` and ``__setitem__`` are also hooked up to
        ``operator.getitem`` / ``operator.setitem``.  All other names are
        lowered in ``registry`` as ordinary method calls.
        """
        # create a separate instance of imp method to avoid closure clashing
        def get_imp():
            def imp(context, builder, sig, args):
                instance_type = sig.args[0]

                if attr in instance_type.jit_methods:
                    method = instance_type.jit_methods[attr]
                elif attr in instance_type.jit_static_methods:
                    method = instance_type.jit_static_methods[attr]
                    # imp gets called as a method, where the first argument is
                    # self.  We drop this for a static method.
                    sig = sig.replace(args=sig.args[1:])
                    args = args[1:]
                else:
                    # Fail with a clear message.  Otherwise ``method`` would
                    # be unbound below and surface as an UnboundLocalError.
                    raise TypeError(
                        "%r has no method or static method named %r"
                        % (instance_type, attr))

                disp_type = types.Dispatcher(method)
                call = context.get_function(disp_type, sig)
                out = call(builder, args)
                _add_linking_libs(context, call)
                return imputils.impl_ret_new_ref(context, builder,
                                                 sig.return_type, out)

            return imp

        def _getsetitem_gen(getset):
            # Connect ``operator.<getset>`` to a jitclass's ``__<getset>__``
            # method.  This adds a typing template for the operator, which
            # only applies to instances defining the dunder, plus lowerings
            # for both the operator and the dunder name.
            _dunder_meth = "__%s__" % getset
            op = getattr(operator, getset)

            @templates.infer_global(op)
            class GetSetItem(templates.AbstractTemplate):
                def generic(self, args, kws):
                    instance = args[0]
                    if isinstance(instance, types.ClassInstanceType) and \
                            _dunder_meth in instance.jit_methods:
                        meth = instance.jit_methods[_dunder_meth]
                        disp_type = types.Dispatcher(meth)
                        sig = disp_type.get_call_type(self.context, args, kws)
                        return sig

            # lower both {g,s}etitem and __{g,s}etitem__ to catch the calls
            # from python and numba.  NOTE: these go through
            # imputils.lower_builtin, not the ``registry`` argument.
            imputils.lower_builtin((types.ClassInstanceType, _dunder_meth),
                                   types.ClassInstanceType,
                                   types.VarArg(types.Any))(get_imp())
            imputils.lower_builtin(op, types.ClassInstanceType,
                                   types.VarArg(types.Any))(get_imp())

        # Only the exact dunder names get the operator hooks.  The previous
        # ``attr.strip('_')`` check also matched ordinary methods such as
        # ``getitem`` or ``_setitem``.  Those methods then never got a
        # lowering under their own name.
        if attr in ("__getitem__", "__setitem__"):
            _getsetitem_gen(attr[2:-2])
        else:
            registry.lower(
                (types.ClassInstanceType, attr), types.ClassInstanceType,
                types.VarArg(types.Any))(get_imp())
Example #2
0
from . import register_external
from numba.core import imputils, typing  # noqa: E402
from numba.cuda import libdevicefuncs  # noqa: E402

# Typing: registry handed to register_external for each libdevice function.
typing_registry = typing.templates.Registry()

# Lowering: registry handed to register_external for each libdevice function.
lowering_registry = imputils.Registry()

# libdevicefuncs.functions maps a function name to (return type, arguments).
# Each argument type is passed on as the string form of the argument's
# ``ty``, with a trailing "*" when the argument is a pointer.
for fname, (retty, args) in libdevicefuncs.functions.items():
    argtys = tuple(f"{arg.ty}*" if arg.is_ptr else f"{arg.ty}"
                   for arg in args)
    register_external(fname, retty, argtys, __name__, globals(),
                      typing_registry, lowering_registry)