def write_new_header(self, path_backend_header, header, arg_types):
    """Write ``header`` into the file ``path_backend_header``.

    Every rank first calls ``mpi.barrier()``; afterwards only rank 0 logs
    the operation and writes the file.

    Parameters
    ----------
    path_backend_header :
        Path of the header file, opened in text write mode.
    header :
        Text written to the file.
    arg_types :
        Only used in the debug log message.
    """
    mpi.barrier()

    # Ranks other than 0 have nothing to write.
    if mpi.rank != 0:
        return

    logger.debug(
        f"write {self.name_capitalized} signature in file "
        f"{path_backend_header} with types\n{arg_types}"
    )
    with open(path_backend_header, "w") as header_file:
        header_file.write(header)
        header_file.flush()
def jit_class(cls, jit_methods, backend):
    """Replace the jit methods of a class by generated ones.

    The work is done in three steps:

    1. write a Python file containing the @jit functions and methods,
    2. import this file,
    3. rebind the methods on the class.

    Parameters
    ----------
    cls :
        The class to modify (modified in place and returned).
    jit_methods :
        Mapping from method name to an object exposing the original
        function as ``.func``.
    backend :
        Backend object; ``backend.name``, ``backend.jit.path_class`` and
        ``backend.jit.produce_code_class`` are used here.

    Returns
    -------
    The same class object. It is returned untouched when the module-level
    flag ``has_to_replace`` is false.
    """
    if not has_to_replace:
        return cls

    class_name = cls.__name__
    module_name = cls.__module__
    source_module = sys.modules[module_name]

    # A class defined in a script is reported as living in "__main__";
    # derive a module name from the file path in that case.
    if module_name == "__main__":
        module_name = find_module_name_from_path(source_module.__file__)

    jit_class_root = mpi.Path(backend.jit.path_class)

    # 1. create a Python file with @jit functions and methods
    generated_path = (
        jit_class_root
        / module_name.replace(".", os.path.sep)
        / f"{class_name}.py"
    )

    if mpi.has_to_build(generated_path, source_module.__file__):
        from transonic.justintime import _get_module_jit

        # NOTE: must stay a direct call from this function, since the helper
        # is given an explicit frame depth.
        jit_module = _get_module_jit(backend_name=backend.name, depth_frame=5)

        # Only rank 0 writes the file; the others wait at the barrier.
        if mpi.rank == 0:
            generated_path = mpi.PathSeq(generated_path)
            dependencies_code = jit_module.info_analysis[
                "codes_dependance_classes"
            ][class_name]
            class_code = (
                dependencies_code + "\n" + backend.jit.produce_code_class(cls)
            )
            write_if_has_to_write(generated_path, class_code)
            generated_path = mpi.Path(generated_path)
        mpi.barrier()

    # 2. import the file
    generated_module_name = ".".join(
        (jit_class_root.name, module_name, class_name)
    )
    generated_module = import_from_path(generated_path, generated_module_name)

    # 3. replace the methods
    for method_name, jit_method in jit_methods.items():
        original_func = jit_method.func
        replacement = getattr(
            generated_module, f"__new_method__{class_name}__{method_name}"
        )
        setattr(cls, method_name, functools.wraps(original_func)(replacement))

    return cls
def block_until_avail(self, parallel=True):
    """Block until fewer processes than the allowed limit are tracked.

    On rank 0, the list ``self.processes`` is repeatedly filtered (keeping
    the processes for which ``is_alive_root()`` is true) until its length is
    below the limit: ``self.limit_nb_processes`` when ``parallel`` is true,
    1 otherwise. All ranks then call ``mpi.barrier(timeout=None)``.
    """
    if mpi.rank == 0:
        max_running = self.limit_nb_processes if parallel else 1

        while len(self.processes) >= max_running:
            time.sleep(self.deltat)
            still_running = []
            for running in self.processes:
                if running.is_alive_root():
                    still_running.append(running)
            self.processes = still_running

    mpi.barrier(timeout=None)
def wait_for_all_extensions(self):
    """Wait until all compilation processes are done

    On rank 0, a progress task is registered on ``self.progress`` and the
    list ``self.processes`` is polled every ``self.deltat`` seconds, keeping
    only the processes for which ``is_alive_root()`` is true, until the list
    is empty. All ranks then call ``mpi.barrier(timeout=None)``.
    """
    if mpi.rank == 0:
        # Use this instance's own process list for the total, so that the
        # "completed" count below is computed from the same list.
        total = len(self.processes)
        task = self.progress.add_task("Wait for all extensions", total=total)

        while self.processes:
            time.sleep(self.deltat)
            self.processes = [
                process for process in self.processes if process.is_alive_root()
            ]
            self.progress.update(task, completed=total - len(self.processes))

    mpi.barrier(timeout=None)
def test_not_transonified():
    """Run the functions of ``for_test_init`` after removing its output dir.

    The directory ``__<backend_default>__`` next to the test module is
    deleted by rank 0 when it exists; the module is then reloaded and its
    functions are called.
    """
    parent_dir = Path(__file__).parent.parent
    module_path = parent_dir / "_transonic_testing/for_test_init.py"
    output_dir = module_path.parent / f"__{backend_default}__"

    if output_dir.exists():
        # Only one rank removes the directory.
        if mpi.rank == 0:
            rmtree(output_dir)
    mpi.barrier()

    from _transonic_testing import for_test_init

    importlib.reload(for_test_init)

    from _transonic_testing.for_test_init import func, func1, check_class

    func(1, 3.14)
    func1(1.1, 2.2)
    check_class()
def test_transonified(self):
    """Regenerate the backend file and check the reloaded module uses it.

    The environment variable ``TRANSONIC_COMPILE_AT_IMPORT`` and the cached
    module entry are dropped, the backend file is removed if present and
    recreated by rank 0, then ``for_test_init`` is reloaded and exercised.
    """

    def trace(message):
        # Progress print prefixed by the MPI rank, flushed immediately.
        print(mpi.rank, message, flush=1)

    trace("start test")

    # Missing entries are fine in both cases.
    os.environ.pop("TRANSONIC_COMPILE_AT_IMPORT", None)
    modules.pop(module_name, None)

    assert not has_to_compile_at_import()

    trace("before if self.path_backend.exists()")
    if self.path_backend.exists():
        trace("before self.path_backend.unlink()")
        self.path_backend.unlink()

    trace("before make_backend_file(self.path_for_test)")
    if mpi.rank == 0:
        backend.make_backend_file(self.path_for_test)
    trace("after make_backend_file(self.path_for_test)")

    mpi.barrier()

    from _transonic_testing import for_test_init

    importlib.reload(for_test_init)

    assert self.path_backend.exists()
    assert for_test_init.ts.is_transpiled

    for_test_init.func(1, 3.14)
    for_test_init.func1(1.1, 2.2)
    for_test_init.check_class()
# Directories used by the JIT machinery for this test module.
path_jit = backend.jit.path_base
path_jit_classes = backend.jit.path_class
path_jit_dir = path_jit / str_relative_path
path_classes_dir = path_jit_classes / str_relative_path
path_classes_dir1 = path_jit / path_jit_classes.name / str_relative_path

# Rank 0 removes the directories left by previous runs; every rank then
# meets at the barrier before the tests start.
if mpi.rank == 0:
    if path_jit.exists():
        rmtree(path_jit_dir, ignore_errors=True)
    for _leftover_dir in (path_classes_dir, path_classes_dir1):
        if _leftover_dir.exists():
            rmtree(_leftover_dir)
    # Do not leave the loop variable in the module namespace.
    del _leftover_dir
mpi.barrier()


@pytest.mark.skipif(backend_default == "numba", reason="Not supported by Numba")
def test_jit():
    """Call ``func1`` before and after waiting for the extensions.

    ``func1`` is called twice, each call followed by a short sleep and
    ``wait_for_all_extensions()``, and then one last time.
    """
    from _transonic_testing.for_test_justintime import func1

    array_arg = np.arange(2)
    list_arg = [1, 2]

    for _ in range(2):
        func1(array_arg, list_arg)
        sleep(0.1)
        wait_for_all_extensions()

    func1(array_arg, list_arg)
def __call__(self, func):
    """Decorate ``func`` so that calls go through a compiled backend function.

    Parameters
    ----------
    func :
        The Python function (or method) being decorated.

    Returns
    -------
    - ``func`` itself when ``has_to_replace`` is false or when the backend
      is reported as not importable,
    - a ``TransonicTemporaryJITMethod`` when ``func`` is a method,
    - otherwise the wrapper ``type_collector``, which calls
      ``self.backend_func`` and falls back to ``func``.
    """
    if not has_to_replace:
        return func

    if is_method(func):
        return TransonicTemporaryJITMethod(func, self.native, self.xsimd, self.openmp)

    if not can_import_accelerator(self.backend.name):
        logger.warning(
            "Cannot accelerate a jitted function because "
            f"{self.backend.name_capitalized} is not importable.")
        return func

    func_name = func.__name__
    backend = self.backend
    mod = self.mod
    # Register this decorator object under the function name.
    mod.jit_functions[func_name] = self
    module_name = mod.module_name

    # Directory holding the backend files of this module; created by
    # rank 0 only, all ranks then synchronize at the barrier.
    path_jit = mpi.Path(backend.jit.path_base)
    path_backend = path_jit / module_name.replace(".", os.path.sep)
    if mpi.rank == 0:
        path_backend.mkdir(parents=True, exist_ok=True)
    mpi.barrier()

    # From here on, path_backend is the path of the backend source file.
    path_backend = (path_backend / func_name).with_suffix(".py")
    if backend.suffix_header:
        path_backend_header = path_backend.with_suffix(
            backend.suffix_header)
    else:
        # No header file for this backend.
        path_backend_header = False

    # The source has to be (re)written when the file does not exist, or
    # when it exists, the module file is not a dummy file and
    # has_to_build() says so.
    if path_backend.exists():
        if not mod.is_dummy_file and has_to_build(path_backend, mod.filename):
            has_to_write = True
        else:
            has_to_write = False
    else:
        has_to_write = True

    src = None
    if has_to_write:
        # make_backend_source returns the source and an updated flag.
        src, has_to_write = backend.jit.make_backend_source(
            mod.info_analysis, func, path_backend)
        if has_to_write and mpi.rank == 0:
            logger.debug(f"write code in file {path_backend}")
            with open(path_backend, "w") as file:
                file.write(src)
                file.flush()

    # No source produced above: rank 0 reads the existing file.
    if src is None and mpi.rank == 0:
        with open(path_backend) as file:
            src = file.read()

    hex_src = None
    name_mod = None
    if mpi.rank == 0:
        # hash from src (to produce the extension name)
        hex_src = make_hex(src)
        # dotted name built from the path relative to path_root
        name_mod = ".".join(path_backend.absolute().relative_to(
            path_root).with_suffix("").parts)
    # Values are computed on rank 0 only and passed through mpi.bcast
    # (presumably a broadcast from rank 0 -- confirm in the mpi module).
    hex_src = mpi.bcast(hex_src)
    name_mod = mpi.bcast(name_mod)

    def backenize_with_new_header(arg_types="no types"):
        """Write a header for ``arg_types`` and start building the extension.

        Sets ``self.path_extension``, ``self.compiling`` and
        ``self.process``; when nothing is compiling, also imports the
        module and sets ``self.backend_func``.
        """
        header_object = backend.jit.make_new_header(func, arg_types)
        header_code = backend.jit.merge_old_and_new_header(
            path_backend_header, header_object, func)
        backend.jit.write_new_header(path_backend_header, header_code, arg_types)

        # compute the new path of the extension
        hex_header = make_hex(header_code)
        # Disabled consistency check between ranks:
        # if mpi.nb_proc > 1:
        #     hex_header0 = mpi.bcast(hex_header)
        #     assert hex_header0 == hex_header
        name_ext_file = (func_name + "_" + hex_src + "_" + hex_header + backend.suffix_extension)
        self.path_extension = path_backend.with_name(name_ext_file)

        self.compiling, self.process = backend.compile_extension(
            path_backend, name_ext_file,
            native=self.native,
            xsimd=self.xsimd,
            openmp=self.openmp,
        )

        # for backend like numba
        # (compile_extension reported that nothing is compiling, so the
        # module is imported right away)
        if not self.compiling:
            backend_module = import_from_path(self.path_extension, name_mod)
            assert backend.check_if_compiled(backend_module)
            self.backend_func = getattr(backend_module, func_name)

    # Look for extension files matching the current source hash, whatever
    # the header hash (the "_*" part of the glob pattern).
    ext_files = None
    if mpi.rank == 0:
        glob_name_ext_file = (func_name + "_" + hex_src + "_*" + backend.suffix_extension)
        ext_files = list(
            mpi.PathSeq(path_backend).parent.glob(glob_name_ext_file))
    ext_files = mpi.bcast(ext_files)

    if not ext_files:
        if has_to_compile_at_import() and _COMPILE_JIT:
            backenize_with_new_header()
        # NOTE(review): this assignment also runs after
        # backenize_with_new_header(), which may have set
        # self.backend_func -- indentation reconstructed, confirm.
        self.backend_func = None
    else:
        # Use the file with the largest st_ctime.
        path_ext = max(ext_files, key=lambda p: p.stat().st_ctime)
        backend_module = import_from_path(path_ext, name_mod)
        self.backend_func = getattr(backend_module, func_name)

    # this is the function that will be called by the user
    @wraps(func)
    def type_collector(*args, **kwargs):
        """Call the backend function, falling back to ``func``.

        A ``TypeError`` raised by the call to ``self.backend_func`` starts
        a new build for the types of the given arguments (unless a build
        is already running or ``_COMPILE_JIT`` is false); the Python
        function ``func`` is used for the current call.
        """
        if self.compiling:
            # The build has finished: import the new extension.
            if not self.process.is_alive(raise_if_error=True):
                self.compiling = False
                time.sleep(0.1)
                backend_module = import_from_path(self.path_extension, name_mod)
                assert backend.check_if_compiled(backend_module)
                self.backend_func = getattr(backend_module, func_name)

        try:
            # When self.backend_func is None, this call itself raises
            # TypeError and is handled below.
            return self.backend_func(*args, **kwargs)
        except TypeError as err:
            # need to compiled or recompile
            error = False
            if self.backend_func:
                error = str(err)
                if (error.startswith(
                        "Invalid call to pythranized function `")
                        and " (reshaped)" in error):
                    logger.error(
                        "It seems that a jitted Pythran function has been called "
                        'with a "reshaped" array which is not supported by Pythran.'
                    )
                    raise
                logger.debug(error)

        # A build is already running, or JIT compilation is disabled:
        # just run the Python function.
        if self.compiling or not _COMPILE_JIT:
            return func(*args, **kwargs)

        if (self.backend_func and error and error.startswith(
                "Invalid call to pythranized function `")):
            logger.debug(error)
            logger.info(
                f"{backend.name_capitalized} function `{func_name}` called with new types."
            )
            logger.debug(
                "Transonic is going to recompute the function for the new types."
            )

        # Type names of all positional and keyword argument values.
        arg_types = [
            backend.jit.compute_typename_from_object(arg)
            for arg in itertools.chain(args, kwargs.values())
        ]

        backenize_with_new_header(arg_types)
        return func(*args, **kwargs)

    return type_collector