def test_subclass_specialization(self):
    # The more specific (sub)class signature must win once it is registered.
    selector = OverloadSelector()
    self.assertTrue(issubclass(types.Sequence, types.Container))

    # Only the fully generic (Container, Container) overload exists.
    selector.append(1, (types.Container, types.Container,))
    list_of_bool = types.List(types.boolean)
    query = (list_of_bool, list_of_bool)
    self.assertEqual(selector.find(query), 1)

    # Narrowing the second argument to Sequence makes overload 2 preferred.
    selector.append(2, (types.Container, types.Sequence,))
    self.assertEqual(selector.find(query), 2)
def test_cache(self):
    # find() results are cached, and append() must invalidate that cache.
    selector = OverloadSelector()
    int32_query = (types.int32,)

    # A fresh selector has nothing cached.
    self.assertEqual(len(selector._cache), 0)

    selector.append(1, (types.Any,))
    # The first lookup populates exactly one cache entry ...
    self.assertEqual(selector.find(int32_query), 1)
    self.assertEqual(len(selector._cache), 1)

    # ... and registering a new overload clears it again.
    selector.append(2, (types.Integer,))
    self.assertEqual(len(selector._cache), 0)
    self.assertEqual(selector.find(int32_query), 2)
    self.assertEqual(len(selector._cache), 1)
def test_ambiguous_detection(self):
    selector = OverloadSelector()
    bool_bool = (types.boolean, types.boolean)
    bool_int = (types.boolean, types.int32)

    # Two unambiguous signatures: only (Any, Boolean) accepts bool_bool.
    selector.append(1, (types.Any, types.Boolean))
    selector.append(2, (types.Integer, types.Boolean))
    self.assertEqual(selector.find(bool_bool), 1)

    # Nothing accepts an int32 second argument yet.
    with self.assertRaises(NotImplementedError):
        selector.find(bool_int)

    # A fully generic fallback handles it without stealing bool_bool.
    selector.append(3, (types.Any, types.Any))
    self.assertEqual(selector.find(bool_int), 3)
    self.assertEqual(selector.find(bool_bool), 1)

    # (Boolean, Any) ties with (Any, Boolean) for bool_bool: ambiguous.
    selector.append(4, (types.Boolean, types.Any))
    with self.assertRaises(TypeError) as ctx:
        selector.find(bool_bool)
    self.assertIn('2 ambiguous signatures', str(ctx.exception))

    # An exact-type signature resolves the ambiguity.
    selector.append(5, (types.boolean, types.boolean))
    self.assertEqual(selector.find(bool_bool), 5)
def test_select_and_sort_1(self):
    selector = OverloadSelector()
    candidate_sigs = [
        (types.Any, types.Boolean),
        (types.Boolean, types.Integer),
        (types.Boolean, types.Any),
        (types.Boolean, types.Boolean),
    ]
    for impl_id, candidate in enumerate(candidate_sigs, start=1):
        selector.append(impl_id, candidate)

    # (Boolean, Integer) cannot accept a boolean second argument.
    compats = selector._select_compatible((types.boolean, types.boolean))
    self.assertEqual(len(compats), 3)

    ordered, scoring = selector._sort_signatures(compats)
    self.assertEqual(len(ordered), 3)
    self.assertEqual(len(scoring), 3)
    self.assertEqual(ordered[0], (types.Boolean, types.Boolean))

    # Lower score == more specific match.
    expected_scores = [
        ((types.Boolean, types.Boolean), 0),
        ((types.Boolean, types.Any), 1),
        ((types.Any, types.Boolean), 1),
    ]
    for sig_key, expected in expected_scores:
        self.assertEqual(scoring[sig_key], expected)
def test_select_and_sort_2(self):
    selector = OverloadSelector()
    # Registered from most generic to most specific.
    hierarchy = [types.Container, types.Sequence,
                 types.MutableSequence, types.List]
    for impl_id, klass in enumerate(hierarchy, start=1):
        selector.append(impl_id, (klass,))

    compats = selector._select_compatible((types.List,))
    self.assertEqual(len(compats), 4)

    ordered, scoring = selector._sort_signatures(compats)
    self.assertEqual(len(ordered), 4)
    self.assertEqual(len(scoring), 4)
    self.assertEqual(ordered[0], (types.List,))

    # The score grows with the distance from the exact List class.
    expected_scores = [
        (types.List, 0),
        (types.MutableSequence, 1),
        (types.Sequence, 2),
        (types.Container, 3),
    ]
    for klass, expected in expected_scores:
        self.assertEqual(scoring[(klass,)], expected)
def create_overload_selector(self, kind):
    # Build an OverloadSelector populated with every builtin registration
    # of the requested ``kind``.
    selector = OverloadSelector()
    registrations = RegistryLoader(builtin_registry).new_registrations(kind)
    for implementation, signature in registrations:
        selector.append(implementation, signature)
    return selector
class _OverloadWrapper(object):
    """This class does all the work of assembling and registering wrapped
    split implementations.

    A typing template is attached via ``wrap_typing`` and lowering
    implementations via ``wrap_impl``; an ``@overload`` is registered for
    ``function`` whose generated intrinsic types the call with the template
    and selects the matching lowering lazily.
    """

    def __init__(self, function, typing_key=None):
        assert function is not None
        self._function = function
        self._typing_key = typing_key
        # Maps the argument tuple passed to ``wrap_impl`` to its lowering.
        self._BIND_TYPES = dict()
        # OverloadSelector built lazily by ``_assemble``.
        self._selector = None
        # Cloned typing template recorded by ``wrap_typing``.
        self._TYPER = None
        # run to register overload, the intrinsic sorts out the binding to the
        # registered impls at the point the overload is evaluated, i.e. this
        # is all lazy.
        self._build()

    def _resolve_key(self):
        """Return the typing key: ``typing_key`` if one was supplied,
        otherwise the wrapped function itself."""
        if self._typing_key is None:
            return self._function
        return self._typing_key

    @staticmethod
    def _identifier_safe(text):
        """Return ``text`` with every character that cannot appear inside a
        Python identifier replaced by ``_``, so it can be embedded in a
        generated ``def`` name."""
        return ''.join(c if ('_' + c).isidentifier() else '_' for c in text)

    def _stub_generator(self, nargs, body_func, kwargs=None):
        """This generates a function that takes "nargs" count of arguments
        and the presented kwargs, the "body_func" is the function that'll
        type the overloaded function and then work out which lowering to
        return.

        The returned function has the signature
        ``(tyctx, tmp0, ..., tmp<nargs-1>, <kwarg names...>)`` and simply
        returns ``body_func(tyctx)``.
        """
        def stub(tyctx):
            # body is supplied when the function is magic'd into life via glbls
            return body(tyctx)  # noqa: F821

        if kwargs is None:
            kwargs = {}
        extra_varnames = [f'tmp{x}' for x in range(nargs)]
        extra_varnames.extend(kwargs.keys())

        stub_code = stub.__code__
        new_varnames = (*stub_code.co_varnames, *extra_varnames)
        new_argcount = stub_code.co_argcount + len(extra_varnames)
        new_nlocals = stub_code.co_nlocals + len(extra_varnames)

        if hasattr(stub_code, 'replace'):
            # Python >= 3.8: replace() copies every other field, so this
            # keeps working where the CodeType constructor signature changed
            # (3.10 linetable, 3.11 qualname/exceptiontable, ...).
            new_code = stub_code.replace(
                co_argcount=new_argcount,
                co_nlocals=new_nlocals,
                co_varnames=new_varnames,
            )
        else:
            # Python < 3.8: no replace() and no co_posonlyargcount.
            new_code = pytypes.CodeType(
                new_argcount,
                stub_code.co_kwonlyargcount,
                new_nlocals,
                stub_code.co_stacksize,
                stub_code.co_flags,
                stub_code.co_code,
                stub_code.co_consts,
                stub_code.co_names,
                new_varnames,
                stub_code.co_filename,
                stub_code.co_name,
                stub_code.co_firstlineno,
                stub_code.co_lnotab,
                stub_code.co_freevars,
                stub_code.co_cellvars,
            )
        # get function; ``body`` resolves through these globals
        return pytypes.FunctionType(new_code, {'body': body_func})

    def wrap_typing(self):
        """
        Use this to replace @infer_global, it records the decorated function
        as a typer for the argument `concrete_function`.
        """
        key = self._resolve_key()

        def inner(typing_class):
            # Note that two templates could be used for the same function, to
            # avoid @infer_global etc the typing template is copied. This is to
            # ensure there's a 1:1 relationship between the typing templates and
            # their keys.
            clazz_dict = dict(typing_class.__dict__)
            clazz_dict['key'] = key
            cloned = type(f"cloned_template_for_{key}", typing_class.__bases__,
                          clazz_dict)
            self._TYPER = cloned
            _overload_glue.add_no_defer(key)
            self._build()
            return typing_class
        return inner

    def wrap_impl(self, *args):
        """
        Use this to replace @lower*, it records the decorated function as the
        lowering implementation
        """
        assert self._TYPER is not None

        def inner(lowerer):
            self._BIND_TYPES[args] = lowerer
            return lowerer
        return inner

    def _assemble(self):
        """Assembles the OverloadSelector definitions from the registered
        typing to lowering map.
        """
        from numba.core.base import OverloadSelector

        _overload_glue.flush_deferred_lowering(self._resolve_key())

        self._selector = OverloadSelector()
        msg = f"No entries in the typing->lowering map for {self._function}"
        assert self._BIND_TYPES, msg
        for sig, impl in self._BIND_TYPES.items():
            self._selector.append(impl, sig)

    def _build(self):
        from numba.core.extending import overload, intrinsic

        @overload(self._function, strict=False)
        def ol_generated(*ol_args, **ol_kwargs):
            def body(tyctx):
                if self._TYPER is None:
                    msg = f"No typer registered for {self._function}"
                    raise errors.InternalError(msg)
                typing = self._TYPER(tyctx)
                sig = typing.apply(ol_args, ol_kwargs)
                if sig is None:
                    # this follows convention of something not typeable
                    # returning None
                    return None
                if self._selector is None:
                    self._assemble()
                lowering = self._selector.find(sig.args)
                if lowering is None:
                    # Implicit string concatenation: msg must be a str.
                    msg = (f"Could not find implementation to lower {sig} for "
                           f"{self._function}")
                    raise errors.InternalError(msg)
                return sig, lowering

            stub = self._stub_generator(len(ol_args), body, ol_kwargs)
            intrin = intrinsic(stub)

            # This is horrible, need to generate a jit wrapper function that
            # walks the ol_kwargs into the intrin with a signature that
            # matches the lowering sig. The actual kwarg var names matter,
            # they have to match exactly.
            arg_str = ','.join([f'tmp{x}' for x in range(len(ol_args))])
            kws_str = ','.join(ol_kwargs.keys())
            call_str = ','.join([x for x in (arg_str, kws_str) if x])
            # NOTE: The jit_wrapper functions cannot take `*args`
            # albeit this an obvious choice for accepting an unknown number
            # of arguments. If this is done, `*args` ends up as a cascade of
            # Tuple assembling in the IR which ends up with literal
            # information being lost. As a result the _exact_ argument list
            # is generated to match the number of arguments and kwargs.
            # Name the function with something vaguely identifiable that is
            # always a valid identifier.
            name = self._identifier_safe(str(self._function))
            gen = textwrap.dedent(("""
            def jit_wrapper_{}({}):
                return intrin({})
            """)).format(name, call_str, call_str)
            local_ns = {}
            global_ns = {'intrin': intrin}
            exec(gen, global_ns, local_ns)
            return local_ns['jit_wrapper_{}'.format(name)]
class _OverloadWrapper(object):
    """This class does all the work of assembling and registering wrapped
    split implementations.

    A typing template is attached via ``wrap_typing`` and lowering
    implementations via ``wrap_impl``; an ``@overload`` is registered for
    ``function`` whose generated intrinsic types the call with the template
    and selects the matching lowering lazily.
    """

    def __init__(self, function, typing_key=None):
        assert function is not None
        self._function = function
        self._typing_key = typing_key
        # Maps the argument tuple passed to ``wrap_impl`` to its lowering.
        self._BIND_TYPES = dict()
        # OverloadSelector built lazily by ``_assemble``.
        self._selector = None
        # Cloned typing template recorded by ``wrap_typing``.
        self._TYPER = None
        # run to register overload, the intrinsic sorts out the binding to the
        # registered impls at the point the overload is evaluated, i.e. this
        # is all lazy.
        self._build()

    def _resolve_key(self):
        """Return the typing key: ``typing_key`` if one was supplied,
        otherwise the wrapped function itself."""
        if self._typing_key is None:
            return self._function
        return self._typing_key

    @staticmethod
    def _identifier_safe(text):
        """Return ``text`` with every character that cannot appear inside a
        Python identifier replaced by ``_``, so it can be embedded in a
        generated ``def`` name."""
        return ''.join(c if ('_' + c).isidentifier() else '_' for c in text)

    def _stub_generator(self, body_func, varnames):
        """This generates a function based on the argnames provided in
        "varnames", the "body_func" is the function that'll type the
        overloaded function and then work out which lowering to return.

        The returned function has the signature ``(tyctx, *varnames)`` (as
        explicit positional parameters) and returns ``body_func(tyctx)``.
        """
        def stub(tyctx):
            # body is supplied when the function is magic'd into life via glbls
            return body(tyctx)  # noqa: F821

        stub_code = stub.__code__
        new_varnames = (*stub_code.co_varnames, *varnames)
        new_argcount = len(new_varnames)
        new_nlocals = stub_code.co_nlocals + len(varnames)

        if hasattr(stub_code, 'replace'):
            # Python >= 3.8: replace() copies every other field, so this
            # keeps working where the CodeType constructor signature changed
            # (3.10 linetable, 3.11 qualname/exceptiontable, ...).
            new_code = stub_code.replace(
                co_argcount=new_argcount,
                co_nlocals=new_nlocals,
                co_varnames=new_varnames,
            )
        else:
            # Python < 3.8: no replace() and no co_posonlyargcount.
            new_code = pytypes.CodeType(
                new_argcount,
                stub_code.co_kwonlyargcount,
                new_nlocals,
                stub_code.co_stacksize,
                stub_code.co_flags,
                stub_code.co_code,
                stub_code.co_consts,
                stub_code.co_names,
                new_varnames,
                stub_code.co_filename,
                stub_code.co_name,
                stub_code.co_firstlineno,
                stub_code.co_lnotab,
                stub_code.co_freevars,
                stub_code.co_cellvars,
            )
        # get function; ``body`` resolves through these globals
        return pytypes.FunctionType(new_code, {'body': body_func})

    @staticmethod
    def _wrapper_arguments(sig, nargs, ol_kwargs):
        """Work out the argument lists for the generated jit wrapper.

        If ``sig.pysig`` is set (CallableTemplate), the wrapper must match
        the exact arg names and kwarg names/defaults so args can be passed
        by name. Otherwise (AbstractTemplate), it just needs ``nargs``
        positional arguments and there must be no kwargs.

        Returns ``(gen_var_names, call_str_specific, call_str,
        default_globals)``:

        * ``gen_var_names`` are the variable names used by the intrinsic stub;
        * ``call_str_specific`` is the wrapper's parameter list;
        * ``call_str`` is the argument list forwarded to the intrinsic;
        * ``default_globals`` maps the names referenced by default values in
          ``call_str_specific`` to the actual default objects. Passing the
          objects through globals avoids relying on ``str(default)`` being
          valid source, which fails e.g. for string defaults.
        """
        default_globals = {}
        if sig.pysig:
            # CallableTemplate, pysig is present so generate the exact thing
            # this is to permit calling with positional args specified by
            # name.
            pysig_params = sig.pysig.parameters
            gen_var_names = list(pysig_params.keys())
            buf = []
            for idx, (k, v) in enumerate(pysig_params.items()):
                if v.default is v.empty:
                    # no default ~= positional arg
                    buf.append(k)
                else:
                    # is kwarg, wire in the default object via globals
                    default_name = f'_ol_default_{idx}'
                    default_globals[default_name] = v.default
                    buf.append(f'{k}={default_name}')
            call_str_specific = ', '.join(buf)
            call_str = ', '.join(gen_var_names)
        else:
            # AbstractTemplate, need to bind 1:1 vars to the arg count.
            gen_var_names = [f'tmp{x}' for x in range(nargs)]
            # Everything is just passed by position, there should be no
            # kwargs.
            assert not ol_kwargs
            call_str_specific = ', '.join(gen_var_names)
            call_str = call_str_specific
        return gen_var_names, call_str_specific, call_str, default_globals

    def wrap_typing(self):
        """
        Use this to replace @infer_global, it records the decorated function
        as a typer for the argument `concrete_function`.
        """
        key = self._resolve_key()

        def inner(typing_class):
            # Note that two templates could be used for the same function, to
            # avoid @infer_global etc the typing template is copied. This is to
            # ensure there's a 1:1 relationship between the typing templates and
            # their keys.
            clazz_dict = dict(typing_class.__dict__)
            clazz_dict['key'] = key
            cloned = type(f"cloned_template_for_{key}", typing_class.__bases__,
                          clazz_dict)
            self._TYPER = cloned
            _overload_glue.add_no_defer(key)
            self._build()
            return typing_class
        return inner

    def wrap_impl(self, *args):
        """
        Use this to replace @lower*, it records the decorated function as the
        lowering implementation
        """
        assert self._TYPER is not None

        def inner(lowerer):
            self._BIND_TYPES[args] = lowerer
            return lowerer
        return inner

    def _assemble(self):
        """Assembles the OverloadSelector definitions from the registered
        typing to lowering map.
        """
        from numba.core.base import OverloadSelector

        _overload_glue.flush_deferred_lowering(self._resolve_key())

        self._selector = OverloadSelector()
        msg = f"No entries in the typing->lowering map for {self._function}"
        assert self._BIND_TYPES, msg
        for sig, impl in self._BIND_TYPES.items():
            self._selector.append(impl, sig)

    def _build(self):
        from numba.core.extending import overload, intrinsic

        @overload(self._function, strict=False,
                  jit_options={'forceinline': True})
        def ol_generated(*ol_args, **ol_kwargs):
            def body(tyctx):
                if self._TYPER is None:
                    msg = f"No typer registered for {self._function}"
                    raise errors.InternalError(msg)
                typing = self._TYPER(tyctx)
                sig = typing.apply(ol_args, ol_kwargs)
                if sig is None:
                    # this follows convention of something not typeable
                    # returning None
                    return None
                if self._selector is None:
                    self._assemble()
                lowering = self._selector.find(sig.args)
                if lowering is None:
                    # Implicit string concatenation: msg must be a str.
                    msg = (f"Could not find implementation to lower {sig} for "
                           f"{self._function}")
                    raise errors.InternalError(msg)
                return sig, lowering

            # Need a typing context now so as to get a signature and a binding
            # for the kwarg order.
            from numba.core.target_extension import (dispatcher_registry,
                                                     resolve_target_str,
                                                     current_target)
            disp = dispatcher_registry[resolve_target_str(current_target())]
            typing_context = disp.targetdescr.typing_context
            if self._TYPER is None:
                # Same failure as in ``body`` rather than calling None.
                msg = f"No typer registered for {self._function}"
                raise errors.InternalError(msg)
            typing = self._TYPER(typing_context)
            sig = typing.apply(ol_args, ol_kwargs)
            if not sig:
                # No signature is a typing error, there's no match, so report it
                raise errors.TypingError("No match")

            (gen_var_names, call_str_specific, call_str,
             default_globals) = self._wrapper_arguments(sig, len(ol_args),
                                                        ol_kwargs)

            stub = self._stub_generator(body, gen_var_names)
            intrin = intrinsic(stub)

            # NOTE: The jit_wrapper functions cannot take `*args`
            # albeit this an obvious choice for accepting an unknown number
            # of arguments. If this is done, `*args` ends up as a cascade of
            # Tuple assembling in the IR which ends up with literal
            # information being lost. As a result the _exact_ argument list
            # is generated to match the number of arguments and kwargs.
            # Name the function with something vaguely identifiable that is
            # always a valid identifier.
            name = self._identifier_safe(str(self._function))
            gen = textwrap.dedent(("""
            def jit_wrapper_{}({}):
                return intrin({})
            """)).format(name, call_str_specific, call_str)
            local_ns = {}
            global_ns = {'intrin': intrin}
            global_ns.update(default_globals)
            exec(gen, global_ns, local_ns)
            return local_ns['jit_wrapper_{}'.format(name)]