Exemplo n.º 1
0
    def test_entrypoint_tolerance(self):
        """A failing ``numba_extensions`` init hook must warn, not raise.

        Registers a fake entry point (on the "numba" distribution) whose init
        function raises ``ValueError``, runs
        ``numba.core.entrypoints.init_all()`` and checks that a warning naming
        the broken extension module was emitted and that the hook was called
        exactly once. The fake module and the fake entry point are both
        removed afterwards.
        """
        # loosely based on Pandas test from:
        #   https://github.com/pandas-dev/pandas/pull/27488

        # Mutable container so the nested function can record its calls.
        counters = {"init": 0}

        def init_function():
            counters["init"] += 1
            raise ValueError("broken")

        mod = types.ModuleType("_test_numba_bad_extension")
        mod.init_func = init_function

        # Set once the fake entry point has been registered, so that the
        # ``finally`` clause knows what it has to undo.
        extensions = None

        try:
            # will remove this module at the end of the test
            sys.modules[mod.__name__] = mod

            # We are registering an entry point using the "numba" package
            # ("distribution" in pkg_resources-speak) itself, though these are
            # normally registered by other packages.
            dist = "numba"
            entry_map = pkg_resources.get_entry_map(dist)
            my_entrypoint = pkg_resources.EntryPoint(
                "init",  # name of entry point
                mod.__name__,  # module with entry point object
                attrs=["init_func"],  # name of entry point object
                dist=pkg_resources.get_distribution(dist),
            )
            extensions = entry_map.setdefault("numba_extensions", {})
            extensions["init"] = my_entrypoint

            from numba.core import entrypoints

            # Allow reinitialization
            entrypoints._already_initialized = False

            with warnings.catch_warnings(record=True) as w:
                # Record every warning, even if an active filter would
                # otherwise ignore it or turn it into an error.
                warnings.simplefilter("always")
                entrypoints.init_all()

            bad_str = "Numba extension module '_test_numba_bad_extension'"
            self.assertTrue(
                any(bad_str in str(x) for x in w),
                "Expected warning message not found",
            )

            # was our init function called?
            self.assertEqual(counters["init"], 1)

        finally:
            # Unregister the fake entry point. The entry map is shared state
            # (init_all() above only sees our entry because of that), so
            # leaving it in place would make later init_all() calls run the
            # broken hook again.
            if extensions is not None and extensions.get("init") is my_entrypoint:
                del extensions["init"]
            # remove fake module
            sys.modules.pop(mod.__name__, None)
Exemplo n.º 2
0
    def test_entrypoint_handles_type_extensions(self):
        """Types registered by a ``numba_extensions`` hook are usable in njit.

        Registers a fake entry point (on the "numba" distribution) whose init
        function teaches Numba how to type, model, unbox and box
        ``_DummyClass``. It then compiles and calls an identity function on a
        ``_DummyClass`` instance and checks that the object survives the
        unbox/box round trip. The fake module and the fake entry point are
        removed afterwards.
        """
        # loosely based on Pandas test from:
        #   https://github.com/pandas-dev/pandas/pull/27488
        import numba

        def init_function():
            # This init function would normally just call a module init via
            # import or similar, for the sake of testing, inline registration
            # of how to handle the global "_DummyClass".
            class DummyType(numba.types.Type):
                def __init__(self):
                    super(DummyType, self).__init__(name='DummyType')

            @numba.extending.typeof_impl.register(_DummyClass)
            def typer_DummyClass(val, c):
                return DummyType()

            @numba.extending.register_model(DummyType)
            class DummyModel(numba.extending.models.StructModel):
                def __init__(self, dmm, fe_type):
                    members = [
                        ('value', numba.types.float64),
                    ]
                    super(DummyModel, self).__init__(dmm, fe_type, members)

            @numba.extending.unbox(DummyType)
            def unbox_dummy(typ, obj, c):
                value_obj = c.pyapi.object_getattr_string(obj, "value")
                dummy_struct_proxy = numba.core.cgutils.create_struct_proxy(
                    typ)
                dummy_struct = dummy_struct_proxy(c.context, c.builder)
                dummy_struct.value = c.pyapi.float_as_double(value_obj)
                c.pyapi.decref(value_obj)
                err_flag = c.pyapi.err_occurred()
                is_error = numba.core.cgutils.is_not_null(c.builder, err_flag)
                return numba.extending.NativeValue(dummy_struct._getvalue(),
                                                   is_error=is_error)

            @numba.extending.box(DummyType)
            def box_dummy(typ, val, c):
                dummy_struct_proxy = numba.core.cgutils.create_struct_proxy(
                    typ)
                # Wrap the native value being boxed. Without ``value=val``
                # the proxy would refer to a new struct unrelated to ``val``.
                dummy_struct = dummy_struct_proxy(c.context, c.builder,
                                                  value=val)
                value_obj = c.pyapi.float_from_double(dummy_struct.value)
                serialized_clazz = c.pyapi.serialize_object(_DummyClass)
                class_obj = c.pyapi.unserialize(serialized_clazz)
                res = c.pyapi.call_function_objargs(class_obj, (value_obj, ))
                c.pyapi.decref(value_obj)
                c.pyapi.decref(class_obj)
                return res

        mod = types.ModuleType("_test_numba_init_sequence")
        mod.init_func = init_function

        # Set once the fake entry point has been registered, so that the
        # ``finally`` clause knows what it has to undo.
        extensions = None

        try:
            # will remove this module at the end of the test
            sys.modules[mod.__name__] = mod

            # We are registering an entry point using the "numba" package
            # ("distribution" in pkg_resources-speak) itself, though these are
            # normally registered by other packages.
            dist = "numba"
            entry_map = pkg_resources.get_entry_map(dist)
            my_entrypoint = pkg_resources.EntryPoint(
                "init",  # name of entry point
                mod.__name__,  # module with entry point object
                attrs=['init_func'],  # name of entry point object
                dist=pkg_resources.get_distribution(dist))
            extensions = entry_map.setdefault('numba_extensions', {})
            extensions['init'] = my_entrypoint

            from numba.core import entrypoints

            # Allow reinitialization, so our init function runs even if the
            # extensions were already initialized earlier in this process.
            entrypoints._already_initialized = False

            @njit
            def foo(x):
                return x

            ival = _DummyClass(10)
            res = foo(ival)

            # The identity function must hand back an equivalent object;
            # this exercises both unbox_dummy and box_dummy.
            self.assertIsInstance(res, _DummyClass)
            self.assertEqual(res.value, ival.value)
        finally:
            # Unregister the fake entry point. The entry map is shared state,
            # so leaving it in place would re-run this init hook on later
            # init_all() calls.
            if extensions is not None and extensions.get('init') is my_entrypoint:
                del extensions['init']
            # remove fake module
            sys.modules.pop(mod.__name__, None)