Exemplo n.º 1
0
 def write_new_header(self, path_backend_header, header, arg_types):
     """Write the header (signature) file of a jitted function.

     All processes first meet on an MPI barrier; then only rank 0 logs the
     operation and overwrites ``path_backend_header`` with ``header``.
     ``arg_types`` is only used in the debug message.
     """
     mpi.barrier()
     if mpi.rank != 0:
         return
     message = (f"write {self.name_capitalized} signature in file "
                f"{path_backend_header} with types\n{arg_types}")
     logger.debug(message)
     with open(path_backend_header, "w") as header_file:
         header_file.write(header)
         header_file.flush()
Exemplo n.º 2
0
def jit_class(cls, jit_methods, backend):
    """Replace the jitted methods of a class by their compiled counterparts.

    1. create (when it has to be built) a Python file gathering the @jit
       functions and methods of the class; only rank 0 writes it;
    2. import that file as a module;
    3. replace each method listed in ``jit_methods`` by the module function
       ``__new_method__<class name>__<method name>``, wrapped with
       ``functools.wraps`` so it keeps the original function's metadata.

    Returns ``cls`` itself (modified in place), or ``cls`` untouched when
    ``has_to_replace`` is false.
    """
    if not has_to_replace:
        return cls

    cls_name = cls.__name__
    mod_name = cls.__module__
    module = sys.modules[mod_name]
    if mod_name == "__main__":
        # a script has no dotted module name: derive one from its file path
        mod_name = find_module_name_from_path(module.__file__)

    path_jit_class = mpi.Path(backend.jit.path_class)

    # 1. create a Python file with @jit functions and methods
    relative_dir = mod_name.replace(".", os.path.sep)
    python_path = path_jit_class / relative_dir / f"{cls_name}.py"

    if mpi.has_to_build(python_path, module.__file__):
        from transonic.justintime import _get_module_jit

        # depth_frame is presumably a call-stack depth: keep this call
        # directly in jit_class.
        mod = _get_module_jit(backend_name=backend.name, depth_frame=5)
        if mpi.rank == 0:
            python_path = mpi.PathSeq(python_path)
            deps_code = mod.info_analysis["codes_dependance_classes"][cls_name]
            class_code = backend.jit.produce_code_class(cls)
            write_if_has_to_write(python_path, deps_code + "\n" + class_code)
            python_path = mpi.Path(python_path)
        mpi.barrier()

    # 2. import the file
    python_mod_name = ".".join([path_jit_class.name, mod_name, cls_name])
    module = import_from_path(python_path, python_mod_name)

    # 3. replace the methods
    prefix = f"__new_method__{cls_name}__"
    for name_method, method in jit_methods.items():
        wrapped = method.func
        new_method = getattr(module, prefix + name_method)
        setattr(cls, name_method, functools.wraps(wrapped)(new_method))

    return cls
Exemplo n.º 3
0
    def block_until_avail(self, parallel=True):
        """Block until a new process can be started.

        Only rank 0 polls: every ``self.deltat`` seconds it drops from
        ``self.processes`` those whose ``is_alive_root()`` is false. The
        polling continues while the number of processes reaches the limit.
        All ranks then meet on a barrier without timeout.

        Parameters
        ----------
        parallel : bool
            If true, the limit is ``self.limit_nb_processes``; otherwise it
            is 1, i.e. wait until no process is left.
        """
        if mpi.rank == 0:
            limit = self.limit_nb_processes if parallel else 1
            while len(self.processes) >= limit:
                time.sleep(self.deltat)
                self.processes = list(
                    filter(lambda proc: proc.is_alive_root(), self.processes)
                )

        mpi.barrier(timeout=None)
Exemplo n.º 4
0
    def wait_for_all_extensions(self):
        """Wait until all compilation processes are done.

        Only rank 0 polls: every ``self.deltat`` seconds it drops from
        ``self.processes`` those whose ``is_alive_root()`` is false. It
        reports progress through a task of ``self.progress``. All ranks then
        meet on a barrier without timeout.
        """
        if mpi.rank == 0:
            # Count this instance's own processes (the ones waited on below),
            # not those of a module-level singleton.
            total = len(self.processes)
            task = self.progress.add_task("Wait for all extensions",
                                          total=total)

            while self.processes:
                time.sleep(self.deltat)
                self.processes = [
                    process for process in self.processes
                    if process.is_alive_root()
                ]
                self.progress.update(task,
                                     completed=total - len(self.processes))

        mpi.barrier(timeout=None)
Exemplo n.º 5
0
def test_not_transonified():
    """Functions of ``for_test_init`` must work without a backend directory.

    If the ``__<backend_default>__`` directory next to ``for_test_init.py``
    exists, rank 0 removes it. All ranks then synchronize, and the module is
    reloaded and its functions are called.
    """
    testing_dir = Path(__file__).parent.parent / "_transonic_testing"
    path_for_test = testing_dir / "for_test_init.py"
    path_output = path_for_test.parent / f"__{backend_default}__"

    must_remove = path_output.exists() and mpi.rank == 0
    if must_remove:
        rmtree(path_output)
    mpi.barrier()

    from _transonic_testing import for_test_init

    importlib.reload(for_test_init)

    from _transonic_testing.for_test_init import func, func1, check_class

    func(1, 3.14)
    func1(1.1, 2.2)
    check_class()
Exemplo n.º 6
0
    def test_transonified(self):
        """``for_test_init`` must use a backend file generated beforehand.

        The ``TRANSONIC_COMPILE_AT_IMPORT`` variable and any imported
        ``module_name`` entry are removed, and the test checks that
        ``has_to_compile_at_import()`` is false. An existing backend file is
        unlinked, rank 0 regenerates it with ``backend.make_backend_file``,
        and after a barrier the module is reloaded and checked as transpiled.
        """
        print(mpi.rank, "start test", flush=True)

        # Remove the variable if present (no-op otherwise).
        os.environ.pop("TRANSONIC_COMPILE_AT_IMPORT", None)
        # Forget a previously imported version of the module, if any.
        modules.pop(module_name, None)

        assert not has_to_compile_at_import()

        print(mpi.rank, "before if self.path_backend.exists()", flush=True)

        if self.path_backend.exists():

            print(mpi.rank, "before self.path_backend.unlink()", flush=True)

            self.path_backend.unlink()

        print(mpi.rank,
              "before make_backend_file(self.path_for_test)",
              flush=True)
        if mpi.rank == 0:
            backend.make_backend_file(self.path_for_test)

        print(mpi.rank, "after make_backend_file(self.path_for_test)",
              flush=True)
        mpi.barrier()

        from _transonic_testing import for_test_init

        importlib.reload(for_test_init)

        assert self.path_backend.exists()
        assert for_test_init.ts.is_transpiled

        for_test_init.func(1, 3.14)
        for_test_init.func1(1.1, 2.2)
        for_test_init.check_class()
Exemplo n.º 7
0
path_jit = backend.jit.path_base
path_jit_classes = backend.jit.path_class

# Directories that may hold files generated for this test module by previous
# runs (jitted functions and jitted classes).
path_jit_dir = path_jit / str_relative_path
path_classes_dir = path_jit_classes / str_relative_path
path_classes_dir1 = path_jit / path_jit_classes.name / str_relative_path

# Start from a clean state: rank 0 removes leftovers, then all ranks sync.
if mpi.rank == 0:
    # Check the directory that is actually removed (as for the two below),
    # not its parent ``path_jit``.
    if path_jit_dir.exists():
        rmtree(path_jit_dir, ignore_errors=True)
    if path_classes_dir.exists():
        rmtree(path_classes_dir)
    if path_classes_dir1.exists():
        rmtree(path_classes_dir1)

mpi.barrier()


@pytest.mark.skipif(backend_default == "numba", reason="Not supported by Numba")
def test_jit():
    """Call a jitted function, wait for the extensions, and call it again."""
    from _transonic_testing.for_test_justintime import func1

    array = np.arange(2)
    values = [1, 2]

    # A couple of calls, presumably triggering the background compilation.
    for _ in range(2):
        func1(array, values)
        sleep(0.1)

    wait_for_all_extensions()
    # Presumably served by the compiled extension from now on.
    func1(array, values)
Exemplo n.º 8
0
    def __call__(self, func):
        """Decorate ``func`` so that it is jit-compiled by the backend.

        - Returns ``func`` unchanged when ``has_to_replace`` is false.
        - Returns ``func`` unchanged, with a warning, when the accelerator
          cannot be imported.
        - Methods are wrapped in a ``TransonicTemporaryJITMethod``.

        Otherwise, the steps are:

        1. register ``self`` in ``mod.jit_functions``;
        2. write, if needed, the backend ``.py`` source of the function;
        3. import an already compiled extension if one matches the source
           hash;
        4. return ``type_collector``, a wrapper that calls the compiled
           function when available. On a ``TypeError`` it records the
           argument types, launches a (re)compilation and falls back to the
           Python function.

        File writes and directory scans happen on rank 0 only; their
        results are shared with ``mpi.bcast``.
        """
        if not has_to_replace:
            return func

        if is_method(func):
            return TransonicTemporaryJITMethod(func, self.native, self.xsimd,
                                               self.openmp)

        if not can_import_accelerator(self.backend.name):
            logger.warning(
                "Cannot accelerate a jitted function because "
                f"{self.backend.name_capitalized} is not importable.")
            return func

        func_name = func.__name__

        backend = self.backend
        mod = self.mod
        mod.jit_functions[func_name] = self
        module_name = mod.module_name

        # <path_base>/<module path>/ holds the backend files of this module
        path_jit = mpi.Path(backend.jit.path_base)
        path_backend = path_jit / module_name.replace(".", os.path.sep)

        if mpi.rank == 0:
            path_backend.mkdir(parents=True, exist_ok=True)
        mpi.barrier()

        # <func_name>.py, plus an optional header file (False if the
        # backend has no header suffix)
        path_backend = (path_backend / func_name).with_suffix(".py")
        if backend.suffix_header:
            path_backend_header = path_backend.with_suffix(
                backend.suffix_header)
        else:
            path_backend_header = False

        # (re)write the backend source if it is missing, or if it is
        # outdated with respect to mod.filename (dummy files excepted)
        if path_backend.exists():
            if not mod.is_dummy_file and has_to_build(path_backend,
                                                      mod.filename):
                has_to_write = True
            else:
                has_to_write = False
        else:
            has_to_write = True

        src = None

        if has_to_write:
            # make_backend_source may decide that writing is unnecessary
            src, has_to_write = backend.jit.make_backend_source(
                mod.info_analysis, func, path_backend)

            if has_to_write and mpi.rank == 0:
                logger.debug(f"write code in file {path_backend}")
                with open(path_backend, "w") as file:
                    file.write(src)
                    file.flush()

        # otherwise rank 0 reads the existing source (needed for the hash)
        if src is None and mpi.rank == 0:
            with open(path_backend) as file:
                src = file.read()

        hex_src = None
        name_mod = None
        if mpi.rank == 0:
            # hash from src (to produce the extension name)
            hex_src = make_hex(src)
            # dotted module name of path_backend relative to path_root
            name_mod = ".".join(path_backend.absolute().relative_to(
                path_root).with_suffix("").parts)

        hex_src = mpi.bcast(hex_src)
        name_mod = mpi.bcast(name_mod)

        def backenize_with_new_header(arg_types="no types"):
            """Add ``arg_types`` to the header and start a compilation.

            The new header is merged with the old one and written. The
            extension is named ``<func_name>_<hex_src>_<hex_header><suffix>``
            and stored in ``self.path_extension``. ``self.compiling`` and
            ``self.process`` are set from ``backend.compile_extension``.
            When nothing is compiling afterwards, the extension is imported
            right away into ``self.backend_func``.
            """

            header_object = backend.jit.make_new_header(func, arg_types)

            header_code = backend.jit.merge_old_and_new_header(
                path_backend_header, header_object, func)
            backend.jit.write_new_header(path_backend_header, header_code,
                                         arg_types)

            # compute the new path of the extension
            hex_header = make_hex(header_code)
            # disabled check that all ranks computed the same header hash:
            # if mpi.nb_proc > 1:
            #     hex_header0 = mpi.bcast(hex_header)
            #     assert hex_header0 == hex_header
            name_ext_file = (func_name + "_" + hex_src + "_" + hex_header +
                             backend.suffix_extension)
            self.path_extension = path_backend.with_name(name_ext_file)

            self.compiling, self.process = backend.compile_extension(
                path_backend,
                name_ext_file,
                native=self.native,
                xsimd=self.xsimd,
                openmp=self.openmp,
            )

            # for backend like numba
            if not self.compiling:
                backend_module = import_from_path(self.path_extension,
                                                  name_mod)
                assert backend.check_if_compiled(backend_module)
                self.backend_func = getattr(backend_module, func_name)

        # rank 0 looks for extensions already compiled from this source
        ext_files = None
        if mpi.rank == 0:
            glob_name_ext_file = (func_name + "_" + hex_src + "_*" +
                                  backend.suffix_extension)
            ext_files = list(
                mpi.PathSeq(path_backend).parent.glob(glob_name_ext_file))
        ext_files = mpi.bcast(ext_files)

        if not ext_files:
            if has_to_compile_at_import() and _COMPILE_JIT:
                backenize_with_new_header()
            # NOTE(review): a backend_func set just above by a non-compiling
            # backend is overwritten here; it is only rebuilt on first call.
            self.backend_func = None
        else:
            # use the extension with the most recent st_ctime
            path_ext = max(ext_files, key=lambda p: p.stat().st_ctime)
            backend_module = import_from_path(path_ext, name_mod)
            self.backend_func = getattr(backend_module, func_name)

        # this is the function that will be called by the user
        # NOTE(review): reads self.compiling even when backenize_with_new_header
        # was never called; presumably initialised elsewhere (e.g. __init__).
        @wraps(func)
        def type_collector(*args, **kwargs):

            # a finished compilation: import the new extension
            if self.compiling:
                if not self.process.is_alive(raise_if_error=True):
                    self.compiling = False
                    time.sleep(0.1)
                    backend_module = import_from_path(self.path_extension,
                                                      name_mod)
                    assert backend.check_if_compiled(backend_module)
                    self.backend_func = getattr(backend_module, func_name)

            # when backend_func is None, calling it raises a TypeError too
            try:
                return self.backend_func(*args, **kwargs)
            except TypeError as err:
                # need to compiled or recompile
                error = False
                if self.backend_func:
                    error = str(err)
                    if (error.startswith(
                            "Invalid call to pythranized function `")
                            and " (reshaped)" in error):
                        logger.error(
                            "It seems that a jitted Pythran function has been called "
                            'with a "reshaped" array which is not supported by Pythran.'
                        )
                        raise
                    logger.debug(error)

            # still compiling or jit compilation disabled: use pure Python
            if self.compiling or not _COMPILE_JIT:
                return func(*args, **kwargs)

            if (self.backend_func and error and error.startswith(
                    "Invalid call to pythranized function `")):
                logger.debug(error)
                logger.info(
                    f"{backend.name_capitalized} function `{func_name}` called with new types."
                )
                logger.debug(
                    "Transonic is going to recompute the function for the new types."
                )

            # record the types of this call and (re)compile for them
            arg_types = [
                backend.jit.compute_typename_from_object(arg)
                for arg in itertools.chain(args, kwargs.values())
            ]

            backenize_with_new_header(arg_types)
            return func(*args, **kwargs)

        return type_collector