コード例 #1
0
ファイル: context.py プロジェクト: stuartarchibald/numba
    def find_matching_getattr_template(self, typ, attr):
        """Find the first attribute template able to resolve ``typ.attr``.

        Templates obtained from ``self._get_attribute_templates(typ)`` are
        kept only if the target named in their ``metadata`` (``'generic'``
        when no metadata or no ``'target'`` entry is present) is one that the
        current local target inherits from.  The survivors are tried from the
        most specific target to the least specific one, i.e. in the order the
        target classes appear in the local target's ``__mro__``.

        Returns a dict with the keys ``'template'`` and ``'return_type'`` for
        the first template whose ``resolve(typ, attr)`` is not ``None``, or
        ``None`` if no template resolves the attribute.
        """
        from numba.core.target_extension import (target_registry,
                                                 get_local_target)

        candidates = list(self._get_attribute_templates(typ))

        # The target that is active for the current compilation.
        local_target = get_local_target(self)

        # Pair every applicable template with the target class it was
        # declared for; a template whose target is explicitly None is skipped.
        applicable = []
        for tmpl in candidates:
            declared = getattr(tmpl, "metadata", {}).get('target', 'generic')
            if declared is None:
                continue
            declared_cls = target_registry[declared]
            if local_target.inherits_from(declared_cls):
                applicable.append((tmpl, declared_cls))

        # Most specific target first. ``sorted`` is stable, so templates that
        # share a target keep their original relative order.
        ranked = sorted(
            applicable,
            key=lambda pair: local_target.__mro__.index(pair[1]),
        )

        for tmpl, _ in ranked:
            resolved = tmpl.resolve(typ, attr)
            if resolved is not None:
                return {'template': tmpl, 'return_type': resolved}
        return None
コード例 #2
0
ファイル: functions.py プロジェクト: guilhermeleobas/numba
    def get_call_type(self, context, args, kws):
        """Resolve the signature for a call with ``args`` and ``kws``.

        The templates in ``self.templates`` are ordered by
        ``utils.order_by_target_specificity`` for the current local target
        and tried in turn.  Each template is applied up to twice, once with
        the arguments as given and once with the arguments passed through
        ``_unlit_non_poison``; which attempt comes first depends on the
        template's ``prefer_literal`` flag.

        On the first non-``None`` signature the implementation key is stored
        in ``self._impl_keys`` and the signature is returned.  If every
        attempt fails, the collected failures are raised through
        ``failures.raise_error()``.  With new-style errors enabled, an
        exception that is not a ``NumbaError`` is re-raised immediately.

        Note that ``self._depth`` is incremented before the search and is
        only decremented again on the successful return path.
        """
        literal_first = [True, False]  # old behavior preferring literal
        literal_last = [False, True]  # new behavior preferring non-literal
        failures = _ResolutionFailures(context,
                                       self,
                                       args,
                                       kws,
                                       depth=self._depth)

        # get the order in which to try templates
        from numba.core.target_extension import get_local_target  # circular
        local_target = get_local_target(context)
        ordered = utils.order_by_target_specificity(local_target,
                                                    self.templates,
                                                    fnkey=self.key[0])

        self._depth += 1

        for template_cls in ordered:
            template = template_cls(context)
            # The template can override the default and prefer literal args
            if template.prefer_literal:
                attempts = literal_first
            else:
                attempts = literal_last

            for use_literal in attempts:
                try:
                    if use_literal:
                        call_args, call_kws = args, kws
                    else:
                        call_args = tuple(
                            [_unlit_non_poison(a) for a in args]
                        )
                        call_kws = {
                            name: _unlit_non_poison(value)
                            for name, value in kws.items()
                        }
                    sig = template.apply(call_args, call_kws)
                except Exception as e:
                    if (utils.use_new_style_errors()
                            and not isinstance(e, errors.NumbaError)):
                        raise e
                    # Record the failure and move on to the next attempt.
                    failures.add_error(template, False, e, use_literal)
                    continue

                if sig is not None:
                    self._impl_keys[sig.args] = template.get_impl_key(sig)
                    self._depth -= 1
                    return sig

                # The template applied cleanly but did not match.
                registered_sigs = getattr(template, 'cases', None)
                if registered_sigs is None:
                    msg = 'No match.'
                else:
                    listing = '\n'.join(" * {}".format(case)
                                        for case in registered_sigs)
                    msg = "No match for registered cases:\n%s" % listing
                failures.add_error(template, True, msg, use_literal)

        failures.raise_error()
コード例 #3
0
ファイル: context.py プロジェクト: vishalbelsare/numba
    def find_matching_getattr_template(self, typ, attr):
        """Find the first attribute template able to resolve ``typ.attr``.

        The templates from ``self._get_attribute_templates(typ)`` are ordered
        by ``order_by_target_specificity`` for the current local target and
        tried in that order.

        Returns a dict with the keys ``'template'`` and ``'return_type'`` for
        the first template whose ``resolve(typ, attr)`` is not ``None``, or
        ``None`` if no template resolves the attribute.
        """
        candidates = list(self._get_attribute_templates(typ))

        # Work out the order in which the templates should be tried.
        from numba.core.target_extension import get_local_target  # circular
        local_target = get_local_target(self)
        ranked = order_by_target_specificity(local_target, candidates,
                                             fnkey=attr)

        for tmpl in ranked:
            resolved = tmpl.resolve(typ, attr)
            if resolved is None:
                continue
            return {'template': tmpl, 'return_type': resolved}
        return None
コード例 #4
0
ファイル: templates.py プロジェクト: zhaijf1992/numba
    def _get_jit_decorator(self):
        """Gets a jit decorator suitable for the current target"""

        jitter_str = self.metadata.get('target', None)
        if jitter_str is None:
            from numba import jit
            # There is no target requested, use default, this preserves
            # original behaviour
            jitter = lambda *args, **kwargs: jit(*args, nopython=True, **kwargs)
        else:
            from numba.core.target_extension import (target_registry,
                                                     get_local_target,
                                                     jit_registry)

            # target has been requested, see what it is...
            jitter = jit_registry.get(jitter_str, None)

            if jitter is None:
                # No JIT known for target string, see if something is
                # registered for the string and report if not.
                target_class = target_registry.get(jitter_str, None)
                if target_class is None:
                    msg = ("Unknown target '{}', has it been ",
                           "registered?")
                    raise ValueError(msg.format(jitter_str))

                target_hw = get_local_target(self.context)

                # check that the requested target is in the hierarchy for the
                # current frame's target.
                if not issubclass(target_hw, target_class):
                    msg = "No overloads exist for the requested target: {}."

                jitter = jit_registry[target_hw]

        if jitter is None:
            raise ValueError("Cannot find a suitable jit decorator")

        return jitter
コード例 #5
0
    def get_call_type(self, context, args, kws):
        """Resolve the signature for a call with ``args`` and ``kws``.

        Templates in ``self.templates`` are kept only if the target named in
        their ``metadata`` (``'generic'`` when no ``'target'`` entry is
        present) is one that the current local target inherits from.  They
        are then tried from the most specific target to the least specific
        one, following the local target's ``__mro__``.

        Each template is applied up to twice, once with the arguments as
        given and once with the arguments passed through
        ``_unlit_non_poison``; which attempt comes first depends on the
        template's ``prefer_literal`` flag.  On the first non-``None``
        signature the implementation key is stored in ``self._impl_keys`` and
        the signature is returned.  If every attempt fails, the collected
        failures are raised through ``failures.raise_error()``.

        Raises ``errors.UnsupportedError`` if no template is usable for the
        current target.
        """
        from numba.core.target_extension import (target_registry,
                                                 get_local_target)

        literal_first = [True, False]    # old behavior preferring literal
        literal_last = [False, True]     # new behavior preferring non-literal
        failures = _ResolutionFailures(context, self, args, kws,
                                       depth=self._depth)

        # the target that is active for the current compilation
        target_hw = get_local_target(context)

        # Pair every applicable template with the target class it was
        # declared for; a template whose target is explicitly None is skipped.
        applicable = []
        for template_cls in self.templates:
            # ? Need to do something about this next line
            declared = template_cls.metadata.get('target', 'generic')
            if declared is None:
                continue
            declared_cls = target_registry[declared]
            if target_hw.inherits_from(declared_cls):
                applicable.append((template_cls, declared_cls))

        # Most specific target first; the sort is stable, so templates that
        # share a target keep their original relative order.
        ranked = sorted(applicable,
                        key=lambda pair: target_hw.__mro__.index(pair[1]))
        order = [template_cls for template_cls, _ in ranked]

        if not order:
            msg = (f"Function resolution cannot find any matches for function"
                   f" '{self.key[0]}' for the current target: '{target_hw}'.")
            raise errors.UnsupportedError(msg)

        self._depth += 1

        for template_cls in order:
            template = template_cls(context)
            # The template can override the default and prefer literal args
            if template.prefer_literal:
                attempts = literal_first
            else:
                attempts = literal_last

            for use_literal in attempts:
                try:
                    if use_literal:
                        call_args, call_kws = args, kws
                    else:
                        call_args = tuple(
                            [_unlit_non_poison(a) for a in args]
                        )
                        call_kws = {name: _unlit_non_poison(value)
                                    for name, value in kws.items()}
                    sig = template.apply(call_args, call_kws)
                except Exception as e:
                    # Record the failure and move on to the next attempt.
                    failures.add_error(template, False, e, use_literal)
                    continue

                if sig is not None:
                    self._impl_keys[sig.args] = template.get_impl_key(sig)
                    self._depth -= 1
                    return sig

                # The template applied cleanly but did not match.
                registered_sigs = getattr(template, 'cases', None)
                if registered_sigs is None:
                    msg = 'No match.'
                else:
                    listing = '\n'.join(" * {}".format(case)
                                        for case in registered_sigs)
                    msg = "No match for registered cases:\n%s" % listing
                failures.add_error(template, True, msg, use_literal)

        failures.raise_error()
コード例 #6
0
ファイル: templates.py プロジェクト: zhaijf1992/numba
    def generic(self, args, kws):
        """
        Type the intrinsic by the arguments.

        Checks that the local target inherits from the target declared in
        ``self.metadata`` (``'generic'`` by default), picks a registry of the
        local target's context for the lowering, and then calls
        ``self._definition_func`` with the typing context and the arguments.

        Returns the resulting signature, with the leading context parameter
        dropped from its ``pysig``, or ``None`` if the definition function
        returns ``None``.  Signatures are cached per
        ``(context, args, kws items)``.

        Raises ``InternalError`` if the local target does not inherit from
        the declared target.
        """
        from numba.core.target_extension import (get_local_target,
                                                 resolve_target_str,
                                                 dispatcher_registry)
        from numba.core.imputils import builtin_registry

        cache_key = self.context, args, tuple(kws.items())

        # Class of the target declared by the function.
        hw_clazz = resolve_target_str(self.metadata.get('target', 'generic'))
        # The target that is active locally.
        target_hw = get_local_target(self.context)
        # Bail out unless the local target derives from the declared one.
        if not target_hw.inherits_from(hw_clazz):
            msg = (f"Intrinsic being resolved on a target from which it does "
                   f"not inherit. Local target is {target_hw}, declared "
                   f"target class is {hw_clazz}.")
            raise InternalError(msg)

        tgtctx = dispatcher_registry[target_hw].targetdescr.target_context

        # This is all workarounds...
        # Targets should not have to care which registry their lowering
        # implementations are registered in, but the CUDA target "borrows"
        # implementations from the CPU out of specific registries. So if an
        # impl is defined via @intrinsic (e.g. the numba.*unsafe modules)
        # _AND_ CUDA uses the same impl, the registry used here must be one
        # that CUDA borrows from. Hence: use the CPU builtin_registry if the
        # target context knows it (i.e. the target installed it); otherwise
        # assume the target's registries are not bound to any other target,
        # in which case any of them will do for lowering impls.
        #
        # NOTE: This will need fixing again once targets describe their
        # implementation solely through the extension APIs. At that point the
        # builtin_registry should contain _just_ the stack allocated
        # implementations and low level target invariant things and should
        # not be modified further; it should then be acceptable to drop the
        # builtin_registry case and always pick one of the target's own
        # registries.

        # In case the target has swapped, e.g. cuda borrowing cpu, refresh to
        # populate.
        tgtctx.refresh()
        reg = (
            builtin_registry
            if builtin_registry in tgtctx._registries
            # Pick a registry in which to install intrinsics
            else next(iter(tgtctx._registries))
        )
        lower_builtin = reg.lower

        try:
            return self._impl_cache[cache_key]
        except KeyError:
            pass

        outcome = self._definition_func(self.context, *args, **kws)
        if outcome is None:
            return None
        sig, imp = outcome

        # omit context argument from user function
        pysig = utils.pysignature(self._definition_func)
        user_params = list(pysig.parameters.values())[1:]
        sig = sig.replace(pysig=pysig.replace(parameters=user_params))

        self._impl_cache[cache_key] = sig
        self._overload_cache[sig.args] = imp
        # register the lowering
        lower_builtin(imp, *sig.args)(imp)
        return sig