예제 #1
0
파일: templates.py 프로젝트: MrSanZhi/numba
 def get_args_kwargs(sig):
     """Split the parameters of a signature by whether they have a default.

     Parameters
     ----------
     sig
         A signature object whose ``parameters`` mapping yields parameter
         objects exposing ``default`` and ``kind``.

     Returns
     -------
     tuple
         ``(args, kws, pos_arg)``: ``args`` is the list of parameters that
         have no default (in declaration order, including a ``*args``
         parameter if present), ``kws`` is the list of parameters that have
         a default, and ``pos_arg`` is the VAR_POSITIONAL parameter or
         ``None`` when there is none.

     Raises
     ------
     InternalError
         If the signature contains a VAR_KEYWORD (``**kwargs``) parameter.
     """
     empty = utils.pyParameter.empty
     var_positional = utils.pyParameter.VAR_POSITIONAL
     var_keyword = utils.pyParameter.VAR_KEYWORD

     kws = []
     args = []
     pos_arg = None
     for x in sig.parameters.values():
         # ``empty`` is a sentinel, so it is tested by identity. Using ``==``
         # would invoke the default value's own ``__eq__``, which for values
         # with element-wise equality does not produce a usable boolean.
         if x.default is empty:
             args.append(x)
             if x.kind == var_positional:
                 pos_arg = x
             elif x.kind == var_keyword:
                 msg = (
                     "The use of VAR_KEYWORD (e.g. **kwargs) is "
                     "unsupported. (offending argument name is '%s')")
                 raise InternalError(msg % x)
         else:
             kws.append(x)
     return args, kws, pos_arg
예제 #2
0
파일: templates.py 프로젝트: yikuide/numba
    def _validate_sigs(self, typing_func, impl_func):
        """Validate the implementation signature against the typing one.

        Parameters
        ----------
        typing_func
            Callable whose signature is treated as the reference.
        impl_func
            Callable whose signature must adhere to that of ``typing_func``.

        Raises
        ------
        InternalError
            If either signature uses ``**kwargs``, if the typing signature
            has ``*args`` but the implementation does not, or if the
            positional argument names, keyword argument names or keyword
            default values differ.
        """
        # check that the impl func and the typing func have the same signature!
        typing_sig = utils.pysignature(typing_func)
        impl_sig = utils.pysignature(impl_func)

        # the typing signature is considered golden and must be adhered to by
        # the implementation...
        # Things that are valid:
        # 1. args match exactly
        # 2. kwargs match exactly in name and default value
        # 3. Use of *args in the same location by the same name in both typing
        #    and implementation signature
        # 4. Use of *args in the implementation signature to consume any number
        #    of arguments in the typing signature.
        # Things that are invalid:
        # 5. Use of *args in the typing signature that is not replicated
        #    in the implementing signature
        # 6. Use of **kwargs

        def get_args_kwargs(sig):
            # Partition the parameters into those without a default (args),
            # those with a default (kws) and record the *args parameter.
            kws = []
            args = []
            pos_arg = None
            for x in sig.parameters.values():
                # ``empty`` is a sentinel: compare by identity so that the
                # default value's own ``__eq__`` is never invoked.
                if x.default is utils.pyParameter.empty:
                    args.append(x)
                    if x.kind == utils.pyParameter.VAR_POSITIONAL:
                        pos_arg = x
                    elif x.kind == utils.pyParameter.VAR_KEYWORD:
                        msg = (
                            "The use of VAR_KEYWORD (e.g. **kwargs) is "
                            "unsupported. (offending argument name is '%s')")
                        raise InternalError(msg % x)
                else:
                    kws.append(x)
            return args, kws, pos_arg

        ty_args, ty_kws, ty_pos = get_args_kwargs(typing_sig)
        im_args, im_kws, im_pos = get_args_kwargs(impl_sig)

        sig_fmt = ("Typing signature:         %s\n"
                   "Implementation signature: %s")
        sig_str = sig_fmt % (typing_sig, impl_sig)

        err_prefix = "Typing and implementation arguments differ in "

        # ``a`` and ``b`` are the typing / implementation positional
        # parameters that must compare equal.
        a = ty_args
        b = im_args
        if ty_pos:
            if not im_pos:
                # case 5. described above
                msg = ("VAR_POSITIONAL (e.g. *args) argument kind (offending "
                       "argument name is '%s') found in the typing function "
                       "signature, but is not in the implementing function "
                       "signature.\n%s") % (ty_pos, sig_str)
                raise InternalError(msg)
        else:
            if im_pos:
                # no *args in typing but there's a *args in the implementation
                # this is case 4. described above
                b = im_args[:im_args.index(im_pos)]
                if not b:
                    # *args is the first implementation parameter, so it
                    # consumes every positional typing argument and there
                    # are no named positional arguments left to compare.
                    a = []
                else:
                    try:
                        a = ty_args[:ty_args.index(b[-1]) + 1]
                    except ValueError:
                        # there's no b[-1] arg name in the ty_args, something
                        # is very wrong, we can't work out a diff (*args
                        # consumes unknown quantity of args) so just report
                        # first error
                        specialized = ("argument names.\n%s\n"
                                       "First difference: '%s'")
                        msg = err_prefix + specialized % (sig_str, b[-1])
                        raise InternalError(msg)

        def gen_diff(typing, implementing):
            # Report the parameters present in only one of the two groups.
            diff = set(typing) ^ set(implementing)
            return "Difference: %s" % diff

        if a != b:
            specialized = "argument names.\n%s\n%s" % (sig_str, gen_diff(a, b))
            raise InternalError(err_prefix + specialized)

        # ensure kwargs are the same
        ty = [x.name for x in ty_kws]
        im = [x.name for x in im_kws]
        if ty != im:
            specialized = "keyword argument names.\n%s\n%s"
            msg = err_prefix + specialized % (sig_str, gen_diff(
                ty_kws, im_kws))
            raise InternalError(msg)
        same = [x.default for x in ty_kws] == [x.default for x in im_kws]
        if not same:
            specialized = "keyword argument default values.\n%s\n%s"
            msg = err_prefix + specialized % (sig_str, gen_diff(
                ty_kws, im_kws))
            raise InternalError(msg)
예제 #3
0
    def generic(self, args, kws):
        """
        Resolve the signature of the intrinsic for the given arguments.

        The definition function is called with the typing context followed
        by ``args`` and ``kws``. If it returns ``None`` so does this method;
        otherwise the resulting signature is cached, the implementation is
        registered for lowering and the signature is returned.

        Raises ``InternalError`` if the local target does not inherit from
        the target class declared in the metadata.
        """
        from numba.core.target_extension import (get_local_target,
                                                 resolve_target_str,
                                                 dispatcher_registry)
        from numba.core.imputils import builtin_registry

        key = (self.context, args, tuple(kws.items()))

        # Class of the target declared for this function ('generic' when the
        # metadata does not name one) and the target currently in use.
        hw_clazz = resolve_target_str(self.metadata.get('target', 'generic'))
        target_hw = get_local_target(self.context)

        # The local target must derive from the declared one, else bail.
        if not target_hw.inherits_from(hw_clazz):
            msg = (f"Intrinsic being resolved on a target from which it does "
                   f"not inherit. Local target is {target_hw}, declared "
                   f"target class is {hw_clazz}.")
            raise InternalError(msg)

        target_ctx = dispatcher_registry[target_hw].targetdescr.target_context

        # Workaround: the CUDA target "borrows" lowering implementations from
        # specific CPU registries. An impl defined via @intrinsic that CUDA
        # also uses therefore has to live in a registry CUDA borrows from.
        # Hence: prefer the CPU builtin_registry whenever the target context
        # knows about it; otherwise the target's registries are assumed not
        # to be bound to any other target and any one of them will do.
        #
        # NOTE: this needs revisiting once targets describe themselves solely
        # through the extension APIs; builtin_registry should then no longer
        # be modified and only the fallback choice should remain.

        # The target may have been swapped (e.g. cuda borrowing cpu), so
        # refresh the context to populate its registries first.
        target_ctx.refresh()
        registry = (
            builtin_registry
            if builtin_registry in target_ctx._registries
            else next(iter(target_ctx._registries))
        )
        register_lowering = registry.lower

        # Serve a previously resolved signature if there is one.
        try:
            return self._impl_cache[key]
        except KeyError:
            pass

        outcome = self._definition_func(self.context, *args, **kws)
        if outcome is None:
            return None
        sig, imp = outcome

        # The user-facing python signature omits the leading context argument.
        full_pysig = utils.pysignature(self._definition_func)
        user_params = list(full_pysig.parameters.values())[1:]
        sig = sig.replace(pysig=full_pysig.replace(parameters=user_params))

        self._impl_cache[key] = sig
        self._overload_cache[sig.args] = imp
        # register the lowering
        register_lowering(imp, *sig.args)(imp)
        return sig