Ejemplo n.º 1
0
class Dispatcher(_DispatcherBase):
    """
    Implementation of user-facing dispatcher objects (i.e. created using
    the @jit decorator).
    This is an abstract base class. Subclasses should define the targetdescr
    class attribute.
    """
    _fold_args = True
    # Maps the ``impl_kind`` constructor argument to the compiler class that
    # __init__ instantiates.
    _impl_kinds = {
        'direct': _FunctionCompiler,
        'generated': _GeneratedFunctionCompiler,
    }
    # A {uuid -> instance} mapping, for deserialization
    _memo = weakref.WeakValueDictionary()
    # hold refs to last N functions deserialized, retaining them in _memo
    # regardless of whether there is another reference
    _recent = collections.deque(maxlen=config.FUNCTION_CACHE_SIZE)
    __uuid = None
    __numba__ = 'py_func'

    def __init__(self, py_func, locals=None, targetoptions=None,
                 impl_kind='direct', pipeline_class=compiler.Compiler):
        """
        Parameters
        ----------
        py_func: function object to be compiled
        locals: dict, optional
            Mapping of local variable names to Numba types.  Used to override
            the types deduced by the type inference engine.  A new empty
            dict is used when omitted.
        targetoptions: dict, optional
            Target-specific config options.  A new empty dict is used when
            omitted.
        impl_kind: str
            Select the compiler mode for `@jit` and `@generated_jit`
        pipeline_class: type numba.compiler.CompilerBase
            The compiler pipeline type.
        """
        # Both dicts are stored on the instance below, so a mutable default
        # argument would be aliased between every dispatcher created without
        # explicit options.  Create fresh dicts instead.
        if locals is None:
            locals = {}
        if targetoptions is None:
            targetoptions = {}

        self.typingctx = self.targetdescr.typing_context
        self.targetctx = self.targetdescr.target_context

        pysig = utils.pysignature(py_func)
        arg_count = len(pysig.parameters)
        # Fallback is only allowed when 'nopython' was not requested.
        can_fallback = not targetoptions.get('nopython', False)

        _DispatcherBase.__init__(self, arg_count, py_func, pysig, can_fallback,
                                 exact_match_required=False)

        functools.update_wrapper(self, py_func)

        self.targetoptions = targetoptions
        self.locals = locals
        # Caching is off until enable_caching() swaps in a FunctionCache.
        self._cache = NullCache()
        compiler_class = self._impl_kinds[impl_kind]
        self._impl_kind = impl_kind
        self._compiler = compiler_class(py_func, self.targetdescr,
                                        targetoptions, locals, pipeline_class)
        # Per-signature counters updated by compile().
        self._cache_hits = collections.Counter()
        self._cache_misses = collections.Counter()

        self._type = types.Dispatcher(self)
        self.typingctx.insert_global(self, self._type)

    def dump(self, tab=''):
        """
        Print a header naming this dispatcher, then dump every overload
        with *tab* extended by two spaces, then print a footer.
        """
        cls_name = type(self).__name__
        func_name = self.py_func.__name__
        print(f'{tab}DUMP {cls_name}[{func_name}, '
              f'type code={self._type._code}]')
        for cres in self.overloads.values():
            cres.dump(tab=tab + '  ')
        print(f'{tab}END DUMP {cls_name}[{func_name}]')

    @property
    def _numba_type_(self):
        """Return a ``types.Dispatcher`` built from this instance."""
        return types.Dispatcher(self)

    def enable_caching(self):
        """Replace the current cache with a FunctionCache for ``py_func``."""
        self._cache = FunctionCache(self.py_func)

    def __get__(self, obj, objtype=None):
        '''Allow a JIT function to be bound as a method to an object'''
        if obj is None:  # Unbound method
            return self
        else:  # Bound method
            return pytypes.MethodType(self, obj)

    def __reduce__(self):
        """
        Reduce the instance for pickling.  This will serialize
        the original function as well the compilation options and
        compiled signatures, but not the compiled code itself.
        """
        # Signatures are only recorded when compilation is disabled; a
        # dispatcher that can still compile is serialized without them.
        if self._can_compile:
            sigs = []
        else:
            sigs = [cr.signature for cr in self.overloads.values()]
        globs = self._compiler.get_globals_for_reduction()
        return (serialize._rebuild_reduction,
                (self.__class__, str(self._uuid),
                 serialize._reduce_function(self.py_func, globs),
                 self.locals, self.targetoptions, self._impl_kind,
                 self._can_compile, sigs))

    @classmethod
    def _rebuild(cls, uuid, func_reduced, locals, targetoptions, impl_kind,
                 can_compile, sigs):
        """
        Rebuild a Dispatcher instance after it was __reduce__'d.

        If an instance with the same *uuid* is still registered in
        ``_memo`` it is returned as is; otherwise a new instance is
        created, registered under *uuid* and compiled for each of *sigs*.
        """
        try:
            return cls._memo[uuid]
        except KeyError:
            pass
        py_func = serialize._rebuild_function(*func_reduced)
        # NOTE: pipeline_class is not part of the reduced state, so the
        # rebuilt instance gets the constructor's default pipeline class.
        self = cls(py_func, locals, targetoptions, impl_kind)
        # Make sure this deserialization will be merged with subsequent ones
        self._set_uuid(uuid)
        for sig in sigs:
            self.compile(sig)
        # Restore the flag only after compiling, since compile() refuses to
        # run while compilation is disabled.
        self._can_compile = can_compile
        return self

    @property
    def _uuid(self):
        """
        An instance-specific UUID, to avoid multiple deserializations of
        a given instance.

        Note this is lazily-generated, for performance reasons.
        """
        u = self.__uuid
        if u is None:
            u = str(uuid.uuid1())
            self._set_uuid(u)
        return u

    def _set_uuid(self, u):
        """
        Assign the UUID *u* to this instance and register the instance in
        ``_memo`` (keyed by *u*) and in ``_recent``.  The UUID may only be
        set once.
        """
        assert self.__uuid is None
        self.__uuid = u
        self._memo[u] = self
        self._recent.append(self)

    @global_compiler_lock
    def compile(self, sig):
        """
        Compile the function for the signature *sig* and return the entry
        point of the resulting overload.

        If an overload already exists for the normalized argument types,
        its entry point is returned without recompiling.  Otherwise the
        cache is consulted, and the compiler is only invoked on a cache
        miss; the fresh result is then saved to the cache.

        Raises
        ------
        RuntimeError
            If compilation is disabled (``_can_compile`` is false).
        """
        if not self._can_compile:
            raise RuntimeError("compilation disabled")
        # Use counter to track recursion compilation depth
        with self._compiling_counter:
            args, return_type = sigutils.normalize_signature(sig)
            # Don't recompile if signature already exists
            existing = self.overloads.get(tuple(args))
            if existing is not None:
                return existing.entry_point
            # Try to load from disk cache
            cres = self._cache.load_overload(sig, self.targetctx)
            if cres is not None:
                self._cache_hits[sig] += 1
                # XXX fold this in add_overload()? (also see compiler.py)
                if not cres.objectmode and not cres.interpmode:
                    self.targetctx.insert_user_function(cres.entry_point,
                                                        cres.fndesc,
                                                        [cres.library])
                self.add_overload(cres)
                return cres.entry_point

            self._cache_misses[sig] += 1
            try:
                cres = self._compiler.compile(args, return_type)
            except errors.ForceLiteralArg as e:
                # Re-raise with a helper bound to the exception that folds
                # (args, kws) through this dispatcher's compiler.
                def folded(args, kws):
                    return self._compiler.fold_argument_types(args, kws)[1]
                raise e.bind_fold_arguments(folded)
            self.add_overload(cres)
            self._cache.save_overload(sig, cres)
            return cres.entry_point

    def get_compile_result(self, sig):
        """Compile (if needed) and return the compilation result with the
        given signature.
        """
        atypes = tuple(sig.args)
        if atypes not in self.overloads:
            self.compile(atypes)
        return self.overloads[atypes]

    def recompile(self):
        """
        Recompile all signatures afresh.
        """
        sigs = list(self.overloads)
        old_can_compile = self._can_compile
        # Ensure the old overloads are disposed of, including compiled
        # functions.
        self._make_finalizer()()
        self._reset_overloads()
        self._cache.flush()
        # Temporarily allow compilation; the previous setting is restored
        # even if a compile fails.
        self._can_compile = True
        try:
            for sig in sigs:
                self.compile(sig)
        finally:
            self._can_compile = old_can_compile

    @property
    def stats(self):
        """
        Return a ``_CompileStats`` holding the cache path and the cache
        hit/miss counters of this dispatcher.
        """
        return _CompileStats(
            cache_path=self._cache.cache_path,
            cache_hits=self._cache_hits,
            cache_misses=self._cache_misses,
        )

    def parallel_diagnostics(self, signature=None, level=1):
        """
        Print parallel diagnostic information for the given signature. If no
        signature is present it is printed for all known signatures. level is
        used to adjust the verbosity, level=1 (default) is minimal verbosity,
        and 2, 3, and 4 provide increasing levels of verbosity.

        Raises
        ------
        ValueError
            If an overload has no 'parfor_diagnostics' entry in its metadata.
        """
        def dump(sig):
            ol = self.overloads[sig]
            pfdiag = ol.metadata.get('parfor_diagnostics', None)
            if pfdiag is None:
                msg = "No parfors diagnostic available, is 'parallel=True' set?"
                raise ValueError(msg)
            pfdiag.dump(level)

        if signature is not None:
            dump(signature)
        else:
            # Plain loop: dump() is called for its side effect only.
            for sig in self.signatures:
                dump(sig)

    def get_metadata(self, signature=None):
        """
        Obtain the compilation metadata for a given signature.

        When *signature* is None, a dict mapping every known signature to
        its metadata is returned instead.
        """
        if signature is not None:
            return self.overloads[signature].metadata
        return {sig: self.overloads[sig].metadata for sig in self.signatures}

    def get_function_type(self):
        """Return unique function type of dispatcher when possible, otherwise
        return None.

        A Dispatcher instance has unique function type when it
        contains exactly one compilation result and its compilation
        has been disabled (via its disable_compile method).
        """
        if not self._can_compile and len(self.overloads) == 1:
            cres = tuple(self.overloads.values())[0]
            return types.FunctionType(cres.signature)
        return None
Ejemplo n.º 2
0
class Dispatcher(serialize.ReduceMixin, _MemoMixin, _DispatcherBase):
    """
    Implementation of user-facing dispatcher objects (i.e. created using
    the @jit decorator).
    This is an abstract base class. Subclasses should define the targetdescr
    class attribute.
    """
    _fold_args = True
    # Maps the ``impl_kind`` constructor argument to the compiler class that
    # __init__ instantiates.
    _impl_kinds = {
        'direct': _FunctionCompiler,
        'generated': _GeneratedFunctionCompiler,
    }

    __numba__ = 'py_func'

    def __init__(self,
                 py_func,
                 locals=None,
                 targetoptions=None,
                 impl_kind='direct',
                 pipeline_class=compiler.Compiler):
        """
        Parameters
        ----------
        py_func: function object to be compiled
        locals: dict, optional
            Mapping of local variable names to Numba types.  Used to override
            the types deduced by the type inference engine.  A new empty
            dict is used when omitted.
        targetoptions: dict, optional
            Target-specific config options.  A new empty dict is used when
            omitted.
        impl_kind: str
            Select the compiler mode for `@jit` and `@generated_jit`
        pipeline_class: type numba.compiler.CompilerBase
            The compiler pipeline type.
        """
        # Both dicts are stored on the instance below, so a mutable default
        # argument would be aliased between every dispatcher created without
        # explicit options.  Create fresh dicts instead.
        if locals is None:
            locals = {}
        if targetoptions is None:
            targetoptions = {}

        self.typingctx = self.targetdescr.typing_context
        self.targetctx = self.targetdescr.target_context

        pysig = utils.pysignature(py_func)
        arg_count = len(pysig.parameters)
        # Fallback is only allowed when 'nopython' was not requested.
        can_fallback = not targetoptions.get('nopython', False)

        _DispatcherBase.__init__(self,
                                 arg_count,
                                 py_func,
                                 pysig,
                                 can_fallback,
                                 exact_match_required=False)

        functools.update_wrapper(self, py_func)

        self.targetoptions = targetoptions
        self.locals = locals
        # Caching is off until enable_caching() swaps in a FunctionCache.
        self._cache = NullCache()
        compiler_class = self._impl_kinds[impl_kind]
        self._impl_kind = impl_kind
        self._compiler = compiler_class(py_func, self.targetdescr,
                                        targetoptions, locals, pipeline_class)
        # Per-signature counters updated by compile().
        self._cache_hits = collections.Counter()
        self._cache_misses = collections.Counter()

        self._type = types.Dispatcher(self)
        self.typingctx.insert_global(self, self._type)

        # Remember target restriction
        self._required_target_backend = targetoptions.get('target_backend')

    def dump(self, tab=''):
        """
        Print a header naming this dispatcher, then dump every overload
        with *tab* extended by two spaces, then print a footer.
        """
        cls_name = type(self).__name__
        func_name = self.py_func.__name__
        print(f'{tab}DUMP {cls_name}[{func_name}'
              f', type code={self._type._code}]')
        for cres in self.overloads.values():
            cres.dump(tab=tab + '  ')
        print(f'{tab}END DUMP {cls_name}[{func_name}]')

    @property
    def _numba_type_(self):
        """Return a ``types.Dispatcher`` built from this instance."""
        return types.Dispatcher(self)

    def enable_caching(self):
        """Replace the current cache with a FunctionCache for ``py_func``."""
        self._cache = FunctionCache(self.py_func)

    def __get__(self, obj, objtype=None):
        '''Allow a JIT function to be bound as a method to an object'''
        if obj is None:  # Unbound method
            return self
        else:  # Bound method
            return pytypes.MethodType(self, obj)

    def _reduce_states(self):
        """
        Reduce the instance for pickling.  This will serialize
        the original function as well the compilation options and
        compiled signatures, but not the compiled code itself.

        NOTE: part of ReduceMixin protocol
        """
        # Signatures are only recorded when compilation is disabled; a
        # dispatcher that can still compile is serialized without them.
        if self._can_compile:
            sigs = []
        else:
            sigs = [cr.signature for cr in self.overloads.values()]

        return dict(
            uuid=str(self._uuid),
            py_func=self.py_func,
            locals=self.locals,
            targetoptions=self.targetoptions,
            impl_kind=self._impl_kind,
            can_compile=self._can_compile,
            sigs=sigs,
        )

    @classmethod
    def _rebuild(cls, uuid, py_func, locals, targetoptions, impl_kind,
                 can_compile, sigs):
        """
        Rebuild a Dispatcher instance after it was __reduce__'d.

        If an instance with the same *uuid* is still registered in
        ``_memo`` it is returned as is; otherwise a new instance is
        created, registered under *uuid* and compiled for each of *sigs*.

        NOTE: part of ReduceMixin protocol
        """
        try:
            return cls._memo[uuid]
        except KeyError:
            pass
        # NOTE: pipeline_class is not part of the reduced state, so the
        # rebuilt instance gets the constructor's default pipeline class.
        self = cls(py_func, locals, targetoptions, impl_kind)
        # Make sure this deserialization will be merged with subsequent ones
        self._set_uuid(uuid)
        for sig in sigs:
            self.compile(sig)
        # Restore the flag only after compiling, since compile() refuses to
        # run while compilation is disabled.
        self._can_compile = can_compile
        return self

    def compile(self, sig):
        """
        Compile the function for the signature *sig* and return the entry
        point of the resulting overload.

        When a target is registered in ``TargetConfigurationStack`` and it
        resolves to a different dispatcher, the call is forwarded to that
        dispatcher.  Otherwise, if an overload already exists for the
        normalized argument types its entry point is returned; failing
        that the cache is consulted, and the compiler is only invoked on a
        cache miss, after which the result is saved to the cache.

        Raises
        ------
        RuntimeError
            If compilation is disabled (``_can_compile`` is false).
        """
        disp = self._get_dispatcher_for_current_target()
        if disp is not self:
            return disp.compile(sig)

        with ExitStack() as scope:
            # Read by the timer callbacks below through the closure; it
            # stays None until a compile result is available.
            cres = None

            def cb_compiler(dur):
                # Record the duration only if a compile result exists.
                if cres is not None:
                    self._callback_add_compiler_timer(dur, cres)

            def cb_llvm(dur):
                # Record the duration only if a compile result exists.
                if cres is not None:
                    self._callback_add_llvm_timer(dur, cres)

            scope.enter_context(
                ev.install_timer("numba:compiler_lock", cb_compiler))
            scope.enter_context(ev.install_timer("numba:llvm_lock", cb_llvm))
            scope.enter_context(global_compiler_lock)

            if not self._can_compile:
                raise RuntimeError("compilation disabled")
            # Use counter to track recursion compilation depth
            with self._compiling_counter:
                args, return_type = sigutils.normalize_signature(sig)
                # Don't recompile if signature already exists
                existing = self.overloads.get(tuple(args))
                if existing is not None:
                    return existing.entry_point
                # Try to load from disk cache
                cres = self._cache.load_overload(sig, self.targetctx)
                if cres is not None:
                    self._cache_hits[sig] += 1
                    # XXX fold this in add_overload()? (also see compiler.py)
                    if not cres.objectmode:
                        self.targetctx.insert_user_function(
                            cres.entry_point, cres.fndesc, [cres.library])
                    self.add_overload(cres)
                    return cres.entry_point

                self._cache_misses[sig] += 1
                ev_details = dict(
                    dispatcher=self,
                    args=args,
                    return_type=return_type,
                )
                with ev.trigger_event("numba:compile", data=ev_details):
                    try:
                        cres = self._compiler.compile(args, return_type)
                    except errors.ForceLiteralArg as e:
                        # Re-raise with a helper bound to the exception that
                        # folds (args, kws) through this dispatcher's
                        # compiler.
                        def folded(args, kws):
                            return self._compiler.fold_argument_types(
                                args, kws)[1]

                        raise e.bind_fold_arguments(folded)
                    self.add_overload(cres)
                self._cache.save_overload(sig, cres)
                return cres.entry_point

    def get_compile_result(self, sig):
        """Compile (if needed) and return the compilation result with the
        given signature.
        """
        atypes = tuple(sig.args)
        if atypes not in self.overloads:
            self.compile(atypes)
        return self.overloads[atypes]

    def recompile(self):
        """
        Recompile all signatures afresh.
        """
        sigs = list(self.overloads)
        old_can_compile = self._can_compile
        # Ensure the old overloads are disposed of,
        # including compiled functions.
        self._make_finalizer()()
        self._reset_overloads()
        self._cache.flush()
        # Temporarily allow compilation; the previous setting is restored
        # even if a compile fails.
        self._can_compile = True
        try:
            for sig in sigs:
                self.compile(sig)
        finally:
            self._can_compile = old_can_compile

    @property
    def stats(self):
        """
        Return a ``_CompileStats`` holding the cache path and the cache
        hit/miss counters of this dispatcher.
        """
        return _CompileStats(
            cache_path=self._cache.cache_path,
            cache_hits=self._cache_hits,
            cache_misses=self._cache_misses,
        )

    def parallel_diagnostics(self, signature=None, level=1):
        """
        Print parallel diagnostic information for the given signature. If no
        signature is present it is printed for all known signatures. level is
        used to adjust the verbosity, level=1 (default) is minimal verbosity,
        and 2, 3, and 4 provide increasing levels of verbosity.

        Raises
        ------
        ValueError
            If an overload has no 'parfor_diagnostics' entry in its metadata.
        """
        def dump(sig):
            ol = self.overloads[sig]
            pfdiag = ol.metadata.get('parfor_diagnostics', None)
            if pfdiag is None:
                msg = "No parfors diagnostic available, is 'parallel=True' set?"
                raise ValueError(msg)
            pfdiag.dump(level)

        if signature is not None:
            dump(signature)
        else:
            # Plain loop: dump() is called for its side effect only.
            for sig in self.signatures:
                dump(sig)

    def get_metadata(self, signature=None):
        """
        Obtain the compilation metadata for a given signature.

        When *signature* is None, a dict mapping every known signature to
        its metadata is returned instead.
        """
        if signature is not None:
            return self.overloads[signature].metadata
        return {sig: self.overloads[sig].metadata for sig in self.signatures}

    def get_function_type(self):
        """Return unique function type of dispatcher when possible, otherwise
        return None.

        A Dispatcher instance has unique function type when it
        contains exactly one compilation result and its compilation
        has been disabled (via its disable_compile method).
        """
        if not self._can_compile and len(self.overloads) == 1:
            cres = tuple(self.overloads.values())[0]
            return types.FunctionType(cres.signature)
        return None

    def _get_retarget_dispatcher(self):
        """Returns a dispatcher for the retarget request.

        The current entry of ``TargetConfigurationStack`` is asked to check
        compatibility with this dispatcher and then to retarget it.
        """
        # Check TLS target configuration
        tc = TargetConfigurationStack()
        retarget = tc.get()
        retarget.check_compatible(self)
        disp = retarget.retarget(self)
        return disp

    def _get_dispatcher_for_current_target(self):
        """Returns a dispatcher for the current target registered in
        `TargetConfigurationStack`. `self` is returned if no target is
        specified.
        """
        tc = TargetConfigurationStack()
        if tc:
            return self._get_retarget_dispatcher()
        else:
            return self

    def _call_tls_target(self, *args, **kwargs):
        """This is called when the C dispatcher logic sees a retarget request.
        """
        disp = self._get_retarget_dispatcher()
        # Call the new dispatcher
        return disp(*args, **kwargs)