def check_cache_compile():
    """Build the jit_utils_core helper extension and ensure it is importable.

    If a rebuild happened while a stale core is already loaded in this
    process, the user is asked to rerun; otherwise the freshly built core
    is imported and compiled once more so the cache key gets generated.
    """
    # Sources that make up the jit_utils_core extension module.
    srcs = [
        "src/utils/cache_compile.cc",
        "src/utils/log.cc",
        "src/utils/tracer.cc",
        "src/utils/jit_utils.cc",
        "src/utils/str_utils.cc",
    ]
    if os.name == 'nt':
        # Windows toolchains expect backslash-separated paths.
        srcs = [s.replace('/', '\\') for s in srcs]
    global jit_utils_core_files
    jit_utils_core_files = srcs

    # The same invocation is reused for the rebuild pass below.
    flags = cc_flags + f" {opt_flags} "
    target = jit_utils.cache_path + '/jit_utils_core' + extension_suffix
    recompiled = compile(cc_path, flags, srcs, target, True)

    if recompiled and jit_utils.cc:
        # The already-loaded core is stale; only a fresh process can pick
        # up the new build.
        LOG.e("jit_utils updated, please rerun your command.")
        sys.exit(0)
    if not jit_utils.cc:
        with jit_utils.import_scope(import_flags):
            jit_utils.try_import_jit_utils_core()
        assert jit_utils.cc
        # recompile, generate cache key
        compile(cc_path, flags, srcs, target, True)
def compile_custom_ops(filenames, extra_flags=""):
    """Compile custom ops

    filenames: path of op source files, filenames must be
        pairs of xxx_xxx_op.cc and xxx_xxx_op.h, and the
        type name of op must be XxxXxxOp.
    extra_flags: extra compile flags
    return: compiled ops (the ``.ops`` attribute of the generated module)
    """
    srcs = {}       # op base name -> .cc path
    headers = {}    # op base name -> .h path
    builds = []     # every source file handed to the compiler
    includes = []   # directories added as -I flags
    for name in filenames:
        name = os.path.realpath(name)
        # Any C/C++/CUDA source is compiled, whether or not it is an op pair.
        if name.endswith(".cc") or name.endswith(".cpp") or name.endswith(".cu"):
            builds.append(name)
        bname = os.path.basename(name)
        bname = os.path.splitext(bname)[0]
        if bname.endswith("_op"):
            # Strip the "_op" suffix; srcs/headers are keyed by the bare name.
            bname = bname[:-3]
            if name.endswith(".cc"):
                srcs[bname] = name
            elif name.endswith(".h"):
                includes.append(os.path.dirname(name))
                headers[bname] = name
    assert len(srcs) == len(headers), "Source and header names not match"
    for name in srcs:
        assert name in headers, f"Header of op {name} not found"
    # Module name is derived from all op names; shorten it when it would
    # produce an unreasonably long filename.
    gen_name = "gen_ops_" + "_".join(headers.keys())
    if len(gen_name) > 100:
        # NOTE(review): str(hash(...)) depends on PYTHONHASHSEED, so the
        # truncated name can differ between interpreter runs — confirm this
        # matches the cache's expectations.
        gen_name = gen_name[:80] + "___hash" + str(hash(gen_name))
    # Deduplicate include dirs and render them as -I compiler flags.
    includes = set(includes)
    includes = "".join(map(lambda x: f" -I'{x}' ", includes))
    LOG.vvvv(f"Include flags:{includes}")
    op_extra_flags = includes + extra_flags
    # Generate the C++ glue source that registers every op with jittor.
    gen_src = gen_jit_op_maker(headers.values(), export=gen_name, extra_flags=op_extra_flags)
    make_cache_dir(os.path.join(cache_path, "custom_ops"))
    gen_src_fname = os.path.join(cache_path, "custom_ops", gen_name + ".cc")
    gen_head_fname = os.path.join(cache_path, "custom_ops", gen_name + ".h")
    gen_lib = os.path.join("custom_ops", gen_name + extension_suffix)
    with open(gen_head_fname, "w") as f:
        f.write(gen_src)
    # pyjt turns the generated header into the pybind-style .cc source.
    pyjt_compiler.compile_single(gen_head_fname, gen_src_fname)
    # gen src initialize first
    builds.insert(0, gen_src_fname)
    LOG.vvv(f"Build custum ops lib:{gen_lib}")
    LOG.vvvv(f"Build sources:{builds}")
    compile(cc_path, cc_flags + opt_flags + includes + extra_flags, builds, gen_lib)
    # add python path and import
    LOG.vvv(f"Import custum ops lib:{gen_lib}")
    lib_path = os.path.join(cache_path, "custom_ops")
    if lib_path not in os.sys.path:
        os.sys.path.append(lib_path)
    # Import the freshly built extension under the given dlopen flags; the
    # module name is only known at runtime, hence exec + locals() lookup.
    with jit_utils.import_scope(os.RTLD_GLOBAL | os.RTLD_NOW | os.RTLD_DEEPBIND):
        exec(f"import {gen_name}")
    return (locals()[gen_name]).ops
def check_cache_compile():
    """Compile the jit_utils_core extension and make sure it is loaded.

    Exits with a restart request when a rebuild invalidates the core that
    is already imported in this process.
    """
    global jit_utils_core_files
    jit_utils_core_files = [
        "src/utils/cache_compile.cc",
        "src/utils/log.cc",
        "src/utils/tracer.cc",
        "src/utils/jit_utils.cc",
    ]

    def run_compile():
        # Single place for the build invocation: used both for the initial
        # build and for the cache-key-generating rebuild below.
        return compile(cc_path, cc_flags + f" {opt_flags} ",
                       jit_utils_core_files,
                       'jit_utils_core' + extension_suffix, True)

    recompiled = run_compile()
    if recompiled and jit_utils.cc:
        # A stale core is already loaded; a restart is required to use
        # the new build.
        LOG.e("jit_utils updated, please restart jittor.")
        sys.exit(0)
    if not jit_utils.cc:
        with jit_utils.import_scope(import_flags):
            jit_utils.try_import_jit_utils_core()
        assert jit_utils.cc
        # recompile, generate cache key
        run_compile()
def compile_custom_ops(filenames, extra_flags="",
                       return_module=False,
                       dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW | os.RTLD_DEEPBIND,
                       gen_name_=""):
    """Compile custom ops

    filenames: path of op source files, filenames must be
        pairs of xxx_xxx_op.cc and xxx_xxx_op.h, and the
        type name of op must be XxxXxxOp.
    extra_flags: extra compile flags
    return_module: return module rather than ops(default: False)
    dlopen_flags: flags passed to the dynamic loader when importing the
        built extension
    gen_name_: override the auto-derived generated-module name
    return: compiled ops, or the whole module when return_module is True
    """
    srcs = {}           # op base name -> .cc path
    headers = {}        # op base name -> .h path
    builds = []         # every source file handed to the compiler
    includes = []       # directories added as -I flags
    pyjt_includes = []  # headers containing "@pyjt" that need extra glue
    for name in filenames:
        name = os.path.realpath(name)
        # Any C/C++/CUDA source is compiled, whether or not it is an op pair.
        if name.endswith(".cc") or name.endswith(".cpp") or name.endswith(".cu"):
            builds.append(name)
        if name.endswith(".h"):
            dirname = os.path.dirname(name)
            # Headers living in an "inc" directory contribute an include path.
            if dirname.endswith("inc"):
                includes.append(dirname)
            # Headers annotated with @pyjt get their own generated binding
            # source (see the pyjt_includes loop below).
            with open(name, "r") as f:
                if "@pyjt" in f.read():
                    pyjt_includes.append(name)
        bname = os.path.basename(name)
        bname = os.path.splitext(bname)[0]
        if bname.endswith("_op"):
            # Strip the "_op" suffix; srcs/headers are keyed by the bare name.
            bname = bname[:-3]
            if name.endswith(".cc"):
                srcs[bname] = name
            elif name.endswith(".h"):
                includes.append(os.path.dirname(name))
                headers[bname] = name
    assert len(srcs) == len(headers), "Source and header names not match"
    for name in srcs:
        assert name in headers, f"Header of op {name} not found"
    # Module name is derived from all op names unless the caller forces one.
    gen_name = "gen_ops_" + "_".join(headers.keys())
    if gen_name_ != "":
        gen_name = gen_name_
    if len(gen_name) > 100:
        # NOTE(review): str(hash(...)) depends on PYTHONHASHSEED, so the
        # truncated name can differ between interpreter runs — confirm this
        # matches the cache's expectations.
        gen_name = gen_name[:80] + "___hash" + str(hash(gen_name))
    # Deduplicate include dirs and render them as -I compiler flags.
    includes = set(includes)
    includes = "".join(map(lambda x: f" -I'{x}' ", includes))
    LOG.vvvv(f"Include flags:{includes}")
    op_extra_flags = includes + extra_flags
    # Generate the C++ glue source that registers every op with jittor.
    gen_src = gen_jit_op_maker(headers.values(), export=gen_name, extra_flags=op_extra_flags)
    make_cache_dir(os.path.join(cache_path, "custom_ops"))
    gen_src_fname = os.path.join(cache_path, "custom_ops", gen_name + ".cc")
    gen_head_fname = os.path.join(cache_path, "custom_ops", gen_name + ".h")
    gen_lib = os.path.join("custom_ops", gen_name + extension_suffix)
    pyjt_compiler.compile_single(gen_head_fname, gen_src_fname, src=gen_src)
    # gen src initialize first
    builds.insert(0, gen_src_fname)

    def insert_anchor(gen_src, anchor_str, insert_str):
        # insert insert_str after anchor_str into gen_src
        return gen_src.replace(anchor_str, anchor_str + insert_str, 1)

    # For every @pyjt header, compile its bindings and splice declaration +
    # initialization calls into the generated module source.
    for name in pyjt_includes:
        LOG.i("handle pyjt_include", name)
        bname = name.split("/")[-1].split(".")[0]
        gen_src_fname = os.path.join(cache_path, "custom_ops",
                                     gen_name + "_" + bname + ".cc")
        pyjt_compiler.compile_single(name, gen_src_fname)
        builds.insert(1, gen_src_fname)
        gen_src = insert_anchor(
            gen_src,
            "namespace jittor {",
            f"extern void pyjt_def_{bname}(PyObject* m);")
        gen_src = insert_anchor(
            gen_src,
            "init_module(PyModuleDef* mdef, PyObject* m) {",
            f"jittor::pyjt_def_{bname}(m);")

    # Write the (possibly patched) generated source out as the header.
    with open(gen_head_fname, "w") as f:
        f.write(gen_src)
    LOG.vvv(f"Build custum ops lib:{gen_lib}")
    LOG.vvvv(f"Build sources:{builds}")
    compile(cc_path, extra_flags + cc_flags + opt_flags + includes, builds, gen_lib)
    # add python path and import
    LOG.vvv(f"Import custum ops lib:{gen_lib}")
    lib_path = os.path.join(cache_path, "custom_ops")
    if lib_path not in os.sys.path:
        os.sys.path.append(lib_path)
    # unlock scope when initialize
    with lock.unlock_scope():
        # Module name is only known at runtime, hence exec + locals() lookup.
        with jit_utils.import_scope(dlopen_flags):
            exec(f"import {gen_name}")
    mod = locals()[gen_name]
    if return_module:
        return mod
    return mod.ops
is_debug = 0 if os.environ.get("debug") == "1": is_debug = 1 global cc_flags cc_flags += " -g -DNODE_MEMCHECK " cc_flags = " " # os.RTLD_NOW | os.RTLD_GLOBAL cause segfault when import torch first import_flags = os.RTLD_NOW | os.RTLD_GLOBAL | os.RTLD_DEEPBIND # if cc_type=="icc": # # weird link problem, icc omp library may conflict and cause segfault # import_flags = os.RTLD_NOW | os.RTLD_GLOBAL dlopen_flags = os.RTLD_NOW | os.RTLD_GLOBAL | os.RTLD_DEEPBIND with jit_utils.import_scope(import_flags): jit_utils.try_import_jit_utils_core() jittor_path = find_jittor_path() check_debug_flags() sys.path.append(cache_path) with jit_utils.import_scope(import_flags): jit_utils.try_import_jit_utils_core() python_path = sys.executable py3_config_paths = [ sys.executable + "-config", os.path.dirname(sys.executable) + f"/python3.{sys.version_info.minor}-config",
import os
from pathlib import Path
from collections import defaultdict
import pickle
import numpy as np
import jittor_utils
from jittor_utils import LOG
import sys

# Load the jit_utils core with global symbol visibility before anything else.
with jittor_utils.import_scope(os.RTLD_GLOBAL | os.RTLD_NOW):
    jittor_utils.try_import_jit_utils_core()

has_error = 0


def convert(data):
    """Recursively replace tensor-like objects inside ``data`` with numpy arrays.

    Jittor Vars are converted via ``.numpy()``; any other object exposing a
    ``numpy`` attribute is treated as a torch-style tensor and converted via
    ``.detach().cpu().numpy()``. Containers (tuple/list/dict) are converted
    element-wise; numpy arrays and everything else pass through unchanged.
    """
    if hasattr(data, "numpy"):
        # "Var" in the class name marks a jittor Var; other tensor-likes go
        # through the detach/cpu path.
        is_var = "Var" in data.__class__.__name__
        return data.numpy() if is_var else data.detach().cpu().numpy()
    if isinstance(data, np.ndarray):
        return data
    if isinstance(data, dict):
        return {key: convert(data[key]) for key in data}
    if isinstance(data, (tuple, list)):
        converted = [convert(item) for item in data]
        return tuple(converted) if isinstance(data, tuple) else converted
    return data