def get_args_kwargs(sig):
    """
    Partition the parameters of a signature into positional and keyword
    groups.

    Parameters
    ----------
    sig
        A signature object exposing ``.parameters``, a mapping whose values
        carry ``.default`` and ``.kind`` attributes (e.g. the result of
        ``utils.pysignature``).

    Returns
    -------
    (args, kws, pos_arg)
        ``args`` is the list of parameters without a default value (this
        includes a ``*args`` parameter if there is one), ``kws`` is the list
        of parameters that have a default value, and ``pos_arg`` is the
        VAR_POSITIONAL parameter or ``None`` if the signature has none.

    Raises
    ------
    InternalError
        If the signature contains a VAR_KEYWORD (``**kwargs``) parameter.
    """
    param_cls = utils.pyParameter
    args = []
    kws = []
    pos_arg = None
    for x in sig.parameters.values():
        # ``empty`` is a sentinel, so test it by identity. Using ``==`` would
        # dispatch to the default value's own ``__eq__``, which for
        # array-like defaults yields a non-boolean result and breaks the
        # ``if``.
        if x.default is param_cls.empty:
            args.append(x)
            if x.kind == param_cls.VAR_POSITIONAL:
                pos_arg = x
            elif x.kind == param_cls.VAR_KEYWORD:
                msg = ("The use of VAR_KEYWORD (e.g. **kwargs) is "
                       "unsupported. (offending argument name is '%s')")
                raise InternalError(msg % x)
        else:
            kws.append(x)
    return args, kws, pos_arg
def _validate_sigs(self, typing_func, impl_func):
    """
    Check that the implementation function's signature is compatible with
    the typing function's signature.

    The typing signature is considered golden and must be adhered to by the
    implementation.

    Things that are valid:
      1. args match exactly
      2. kwargs match exactly in name and default value
      3. Use of *args in the same location by the same name in both typing
         and implementation signature
      4. Use of *args in the implementation signature to consume any number
         of arguments in the typing signature.

    Things that are invalid:
      5. Use of *args in the typing signature that is not replicated in the
         implementing signature
      6. Use of **kwargs

    Parameters
    ----------
    typing_func : callable
        The function whose signature is the reference.
    impl_func : callable
        The function whose signature is validated against the reference.

    Returns
    -------
    None

    Raises
    ------
    InternalError
        If the two signatures are incompatible according to the rules above.
    """
    typing_sig = utils.pysignature(typing_func)
    impl_sig = utils.pysignature(impl_func)

    def get_args_kwargs(sig):
        # Split a signature into (parameters without default, parameters
        # with default, the *args parameter or None); reject **kwargs.
        kws = []
        args = []
        pos_arg = None
        for x in sig.parameters.values():
            # ``empty`` is a sentinel: compare by identity so that a default
            # value with an unusual ``__eq__`` (e.g. array-like) cannot
            # break the check.
            if x.default is utils.pyParameter.empty:
                args.append(x)
                if x.kind == utils.pyParameter.VAR_POSITIONAL:
                    pos_arg = x
                elif x.kind == utils.pyParameter.VAR_KEYWORD:
                    msg = ("The use of VAR_KEYWORD (e.g. **kwargs) is "
                           "unsupported. (offending argument name is '%s')")
                    raise InternalError(msg % x)
            else:
                kws.append(x)
        return args, kws, pos_arg

    def gen_diff(typing, implementing):
        # Symmetric difference of the two parameter collections, rendered
        # for inclusion in an error message.
        diff = set(typing) ^ set(implementing)
        return "Difference: %s" % diff

    ty_args, ty_kws, ty_pos = get_args_kwargs(typing_sig)
    im_args, im_kws, im_pos = get_args_kwargs(impl_sig)

    sig_fmt = ("Typing signature: %s\n"
               "Implementation signature: %s")
    sig_str = sig_fmt % (typing_sig, impl_sig)

    err_prefix = "Typing and implementation arguments differ in "

    a = ty_args
    b = im_args
    if ty_pos:
        if not im_pos:
            # case 5. described above
            msg = ("VAR_POSITIONAL (e.g. *args) argument kind (offending "
                   "argument name is '%s') found in the typing function "
                   "signature, but is not in the implementing function "
                   "signature.\n%s") % (ty_pos, sig_str)
            raise InternalError(msg)
    else:
        if im_pos:
            # no *args in typing but there's a *args in the implementation
            # this is case 4. described above: only the named arguments
            # preceding *args in the implementation have to match.
            b = im_args[:im_args.index(im_pos)]
            if not b:
                # The implementation is of the form ``impl(*args, ...)``:
                # there is no named positional argument to line up against
                # the typing signature, *args consumes all of them.
                a = []
            else:
                try:
                    a = ty_args[:ty_args.index(b[-1]) + 1]
                except ValueError:
                    # there's no b[-1] arg name in the ty_args, something is
                    # very wrong, we can't work out a diff (*args consumes
                    # unknown quantity of args) so just report first error
                    specialized = "argument names.\n%s\nFirst difference: '%s'"
                    msg = err_prefix + specialized % (sig_str, b[-1])
                    raise InternalError(msg)

    if a != b:
        specialized = "argument names.\n%s\n%s" % (sig_str, gen_diff(a, b))
        raise InternalError(err_prefix + specialized)

    # ensure kwargs are the same
    ty = [x.name for x in ty_kws]
    im = [x.name for x in im_kws]
    if ty != im:
        specialized = "keyword argument names.\n%s\n%s"
        msg = err_prefix + specialized % (sig_str, gen_diff(ty_kws, im_kws))
        raise InternalError(msg)
    same = [x.default for x in ty_kws] == [x.default for x in im_kws]
    if not same:
        specialized = "keyword argument default values.\n%s\n%s"
        msg = err_prefix + specialized % (sig_str, gen_diff(ty_kws, im_kws))
        raise InternalError(msg)
def generic(self, args, kws):
    """
    Type the intrinsic by the arguments.

    Returns the signature produced by the definition function (with the
    leading context parameter stripped from its ``pysig``), or ``None`` if
    the definition function returns ``None``. A successful result is cached
    per ``(context, args, kws)`` and the lowering implementation is
    registered with the selected registry.

    Raises ``InternalError`` if the local target does not inherit from the
    target class declared in the metadata.
    """
    from numba.core.target_extension import (get_local_target,
                                             resolve_target_str,
                                             dispatcher_registry)
    from numba.core.imputils import builtin_registry

    cache_key = self.context, args, tuple(kws.items())

    # Resolve the class of the target the function declares itself for
    # ('generic' when nothing is declared) and the target currently in use.
    hw_clazz = resolve_target_str(self.metadata.get('target', 'generic'))
    target_hw = get_local_target(self.context)

    # Bail out unless the declared target class is in the MRO of the local
    # target.
    if not target_hw.inherits_from(hw_clazz):
        msg = (f"Intrinsic being resolved on a target from which it does "
               f"not inherit. Local target is {target_hw}, declared "
               f"target class is {hw_clazz}.")
        raise InternalError(msg)

    target_ctx = dispatcher_registry[target_hw].targetdescr.target_context

    # Everything below up to the registry selection is a workaround.
    #
    # Targets should not have to care which registry lowering
    # implementations are put in, however the CUDA target "borrows"
    # implementations from the CPU out of specific registries. So if an impl
    # is defined via @intrinsic (e.g. the numba.*unsafe modules) _AND_ CUDA
    # uses that same impl, the registry used must be one CUDA borrows from.
    # Hence: use the CPU builtin_registry when the target context knows about
    # it (i.e. the target installed it); otherwise assume the target's
    # registries are not bound to any other target, in which case any of
    # them will do as a home for lowering impls.
    #
    # NOTE: This will need fixing again once targets describe their
    # implementation solely through the extension APIs. At that point the
    # builtin_registry should hold _just_ the stack allocated implementations
    # and low level target invariant things and should not be modified any
    # further; it should then be acceptable to drop the builtin_registry
    # branch and keep only the "pick any registry" branch.

    # The target may have been swapped (e.g. cuda borrowing cpu), refresh to
    # populate.
    target_ctx.refresh()
    if builtin_registry not in target_ctx._registries:
        # Pick a registry in which to install intrinsics
        registry = next(iter(target_ctx._registries))
    else:
        registry = builtin_registry
    lower_builtin = registry.lower

    try:
        cached_sig = self._impl_cache[cache_key]
    except KeyError:
        pass
    else:
        return cached_sig

    result = self._definition_func(self.context, *args, **kws)
    if result is None:
        return None
    sig, imp = result

    # The user function's first parameter is the context, leave it out of
    # the python signature attached to the result.
    pysig = utils.pysignature(self._definition_func)
    user_parameters = list(pysig.parameters.values())[1:]
    sig = sig.replace(pysig=pysig.replace(parameters=user_parameters))

    self._impl_cache[cache_key] = sig
    self._overload_cache[sig.args] = imp
    # register the lowering
    lower_builtin(imp, *sig.args)(imp)
    return sig