def compile_extern():
    """Build the LLVM pass plugins in ``extern/llvm`` and register them.

    Only runs when the compiler is clang on Linux. Each source file in
    ``<jittor_path>/extern/llvm`` is compiled into a shared object under
    ``<cache_path>/llvm``. Several candidate link-flag sets are tried; the
    first one whose plugin can be loaded by clang (checked by compiling a
    trivial ``test_pass.cc`` with the plugin loaded) is reused for all
    remaining files. Every successfully built plugin is appended to the
    module-level ``kernel_opt_flags`` as ``-Xclang -load -Xclang '<so>'``.
    """
    global kernel_opt_flags
    # compile llvm passes
    if not (cc_type == "clang" and platform.system() == 'Linux'):
        return

    llvm_cache = os.path.join(cache_path, "llvm")
    llvm_src_dir = os.path.join(jittor_path, "extern", "llvm")
    clang_dir = os.path.dirname(get_full_path_of_executable(cc_path))
    assert clang_dir.endswith(
        "bin") and "llvm" in clang_dir, f"Wrong clang_dir: {clang_dir}"
    llvm_include = os.path.abspath(os.path.join(clang_dir, "..", "include"))
    assert os.path.isdir(llvm_include), "LLVM include path not found"
    make_cache_dir(llvm_cache)
    files = os.listdir(llvm_src_dir)

    # test_pass.cc is used for test link problem of llvm pass plugin
    test_pass_path = os.path.join(llvm_cache, "test_pass.cc")
    with open(test_pass_path, 'w') as f:
        f.write("int main() {return 0;}")

    # -fno-rtti fix link error

    # -Wl,-znodelete fix segfault
    # https://github.com/sampsyo/llvm-pass-skeleton/issues/7#issuecomment-401834287

    # -D_GLIBCXX_USE_CXX11_ABI=0 fix undefined symbol: createPrinterPass
    # https://stackoverflow.com/questions/37366291/undefined-symbol-for-self-built-llvm-opt

    # try different flags
    candidate_flags = [
        " -Wl,-znodelete -D_GLIBCXX_USE_CXX11_ABI=0 ",
        " -Wl,-znodelete ",
    ]
    # index into candidate_flags of the flag set known to work, -1 = unknown
    chosen = -1
    for src_file in files:
        stem = os.path.splitext(src_file)[0]
        for idx, link_flag in enumerate(candidate_flags):
            # once a working flag set is known, skip every other candidate
            if chosen not in (-1, idx):
                continue
            so_name = os.path.join(llvm_cache, f"{stem}.{idx}.so")
            compile(
                cc_path,
                f"{cc_flags} {opt_flags} {link_flag} -I'{llvm_include}'",
                [os.path.join(llvm_src_dir, src_file)],
                so_name,
            )
            # if not found available flags, we test it.
            if chosen == -1:
                try:
                    run_cmd(
                        f"{cc_path} {cc_flags} -Xclang -load -Xclang '{so_name}' {test_pass_path}",
                        llvm_cache,
                        print_error=False,
                    )
                except Exception as e:
                    LOG.v(f"Try flag {link_flag} failed: {e}")
                    continue
                chosen = idx
            kernel_opt_flags += f" -Xclang -load -Xclang '{so_name}' "
            break
        else:
            # no candidate flag set produced a loadable plugin: give up
            LOG.w("Clang is used, but LLVM pass plugin is unable to link.")
            break
    LOG.vv(f"Compile extern llvm passes: {str(files)}")
def compile(cache_path, jittor_path):
    """Generate pyjt binding sources for every header and an aggregate file.

    Headers are collected with ``find`` from ``<jittor_path>/src`` and
    ``<cache_path>/gen``. Each header is passed to ``compile_single`` with
    the output name ``<cache_path>/gen/pyjt_<basename>.cc``; headers for
    which it returns a truthy value are recorded. Finally
    ``<cache_path>/gen/pyjt_all.cc`` is written, declaring and calling one
    ``pyjt_def_<basename>`` per recorded header.

    Args:
        cache_path: cache directory containing the ``gen`` sub-directory.
        jittor_path: jittor package directory containing ``src``.
    """
    src_headers = run_cmd('find -L src/ | grep ".h$"', jittor_path).splitlines()
    gen_headers = run_cmd('find gen/ | grep ".h$"', cache_path).splitlines()
    headers = [os.path.join(jittor_path, rel) for rel in src_headers]
    headers.extend(os.path.join(cache_path, rel) for rel in gen_headers)

    basenames = []
    for header in headers:
        with open(header, 'r') as f:
            src = f.read()

        # jit_op_maker.h merge compile with var_holder.h
        if header.endswith("src/var_holder.h"):
            continue
        if header.endswith("jit_op_maker.h"):
            with open(os.path.join(jittor_path, "src", "var_holder.h"), "r") as f:
                src = f.read() + src

        basename = header.split("/")[-1].split(".")[0]
        out_name = os.path.join(cache_path, "gen", "pyjt_" + basename + ".cc")
        if compile_single(header, out_name, src):
            basenames.append(basename)

    extern_decls = " ".join([f"extern void pyjt_def_{n}(PyObject* m);" for n in basenames])
    def_calls = " ".join([f"pyjt_def_{n}(m);" for n in basenames])
    code = f"""
    #include "pyjt/numpy.h"
    #include "pyjt/py_converter.h"
    #include "common.h"

    namespace jittor {{

    { extern_decls}

    void pyjt_def_all(PyObject* m) {{
        numpy_init();
        { def_calls}
    }}

    }}
    """
    all_name = os.path.join(cache_path, "gen", "pyjt_all.cc")
    LOG.vvv(("write to", all_name))
    LOG.vvvv(code)
    with open(all_name, "w") as f:
        f.write(code)
def gen_jit_flags():
    """Scan the C++ sources for flag definitions and write ``gen/jit_flags.h``.

    Every ``DEFINE_FLAG(...)`` / ``DEFINE_FLAG_WITH_SETTER(...)`` found in
    the ``.cc`` files under ``<jittor_path>/src`` yields a ``DECLARE_FLAG``
    line plus a getter/setter pair annotated for pyjt inside a ``_Flags``
    struct. Flags whose name contains ``cuda`` (other than ``use_cuda``)
    are skipped when ``has_cuda`` is false; a flag name seen more than once
    is only emitted the first time.
    """
    source_files = run_cmd('find -L src/ | grep "cc$"', jittor_path).splitlines()
    flag_pattern = re.compile("DEFINE_FLAG(_WITH_SETTER)?\\((.*?)\\);", re.DOTALL)

    declares = []
    flags_defs = []
    seen = set()

    for rel_name in source_files:
        src_name = os.path.join(jittor_path, rel_name)
        with open(src_name) as f:
            src = f.read()
        for _, raw_args in flag_pattern.findall(src):
            parts = raw_args.split(",")
            flag_type = parts[0].strip()
            name = parts[1].strip()
            if not has_cuda and "cuda" in name and name != "use_cuda":
                continue
            default = parts[2].strip()
            # Everything after the third comma is the doc: one or more
            # adjacent C string literals, evaluated as a Python expression.
            # NOTE(review): eval() executes text taken from the source tree;
            # this is only acceptable while those files are trusted.
            doc = eval(f"({','.join(parts[3:])})")
            LOG.vv(f"Find define {name} from {src_name}")
            if name in seen:
                continue
            seen.add(name)
            declares.append(f"DECLARE_FLAG({flag_type}, {name});")

            # int flags get an additional setter overload taking bool
            bool_setter = ""
            if flag_type == "int":
                bool_setter = f"""// @pyjt(__set__{name})
                void _set_{name}(bool v) {{ set_{name}(v); }}
                """
            flags_defs.append(f"""
                /* {name}(type:{flag_type}, default:{default}): {doc} */
                // @pyjt(__get__{name})
                {flag_type} _get_{name}() {{ return {name}; }}
                // @pyjt(__set__{name})
                void _set_{name}({flag_type} v) {{ set_{name}(v); }}
                {bool_setter}
            """)

    jit_declares = "\n ".join(declares)
    jit_src = f"""
    #include "utils/flags.h"

    namespace jittor {{

    {jit_declares}

    // @pyjt(flags)
    struct _Flags {{
        // @pyjt(__init__)
        _Flags() {{}}
        {"".join(flags_defs)}
    }};

    }} // jittor
    """
    LOG.vvvv(jit_src)
    with open(os.path.join(cache_path, "gen", "jit_flags.h"), 'w') as f:
        f.write(jit_src)
def install_cutt(root_folder):
    """Download, unpack and build cutt under ``root_folder``.

    If ``cutt-master.zip`` already exists but its md5 does not match, the
    archive and any previously extracted tree are removed. When
    ``cutt-master/bin/cutt_test`` is missing, the archive is downloaded,
    extracted and built with ``make``.

    Args:
        root_folder: directory holding the archive and the extracted tree.

    Returns:
        Path of the extracted ``cutt-master`` directory.
    """
    # Modified from: https://github.com/ap-hynninen/cutt
    # (same content as https://github.com/Jittor/cutt/archive/master.zip)
    url = "https://codeload.github.com/Jittor/cutt/zip/master"

    filename = "cutt-master.zip"
    fullname = os.path.join(root_folder, filename)
    dirname = os.path.join(root_folder, filename.replace(".zip", ""))
    true_md5 = "af5bc35eea1832a42c0e0011659b7209"

    if os.path.exists(fullname):
        md5 = run_cmd('md5sum ' + fullname).split()[0]
        if md5 != true_md5:
            os.remove(fullname)
            # The archive may be stale without ever having been extracted;
            # only remove the tree when it actually exists.
            if os.path.isdir(dirname):
                shutil.rmtree(dirname)

    if not os.path.isfile(os.path.join(dirname, "bin", "cutt_test")):
        LOG.i("Downloading cutt...")
        download_url_to_local(url, filename, root_folder, true_md5)

        import zipfile

        # The context manager closes the archive even if extraction fails.
        with zipfile.ZipFile(fullname) as zf:
            try:
                zf.extractall(path=root_folder)
            except RuntimeError as e:
                print(e)
                raise

        LOG.i("installing cutt...")
        arch_flag = ""
        if len(flags.cuda_archs):
            arch_flag = f" -arch=compute_{min(flags.cuda_archs)} "
            arch_flag += ''.join(
                f' -code=sm_{arch} ' for arch in flags.cuda_archs)
        run_cmd(f"make NVCC_GENCODE='{arch_flag}' nvcc_path='{nvcc_path}'",
                cwd=dirname)
    return dirname
def install_nccl(root_folder):
    """Download, unpack and build NCCL 2.8.4-1 under ``root_folder``.

    A cached ``nccl.tgz`` with a wrong md5 is deleted together with any
    extracted tree. When ``build/lib/libnccl.so`` is missing, the archive
    is downloaded; extraction and the build then only happen if at least
    one device is reported and the process runs inside MPI.

    Args:
        root_folder: directory holding the archive and the extracted tree.

    Returns:
        Path of the ``nccl-2.8.4-1`` directory, or ``None`` when the build
        is skipped because there is no device or MPI is not active.
    """
    # mirror of https://github.com/NVIDIA/nccl/archive/v2.8.4-1.tar.gz
    url = "https://codeload.github.com/NVIDIA/nccl/tar.gz/v2.8.4-1"

    filename = "nccl.tgz"
    fullname = os.path.join(root_folder, filename)
    dirname = os.path.join(root_folder, "nccl-2.8.4-1")
    true_md5 = "900666558c5bc43e0a5e84045b88a06f"

    if os.path.exists(fullname):
        actual_md5 = run_cmd('md5sum ' + fullname).split()[0]
        if actual_md5 != true_md5:
            os.remove(fullname)
            if os.path.isdir(dirname):
                shutil.rmtree(dirname)

    lib_path = os.path.join(dirname, "build", "lib", "libnccl.so")
    if os.path.isfile(lib_path):
        # already built
        return dirname

    LOG.i("Downloading nccl...")
    download_url_to_local(url, filename, root_folder, true_md5)

    if core.get_device_count() == 0:
        return
    if not inside_mpi():
        return

    import tarfile
    with tarfile.open(fullname, "r") as tar:
        tar.extractall(root_folder)

    LOG.i("installing nccl...")
    arch_flag = ""
    if len(flags.cuda_archs):
        code_flags = ''.join(f' -code=sm_{arch} ' for arch in flags.cuda_archs)
        arch_flag = f" -arch=compute_{min(flags.cuda_archs)} " + code_flags
    run_cmd(
        f"make -j8 src.build CUDA_HOME='{cuda_home}' NVCC_GENCODE='{arch_flag}' ",
        cwd=dirname)
    return dirname
def gen_jit_tests():
    """Collect ``JIT_TEST(name)`` occurrences and write ``gen/jit_tests.h``.

    Scans the ``.cc`` files under ``<jittor_path>/src``. For every test
    name found, a ``JIT_TEST(name);`` declaration and a pyjt-annotated
    ``test_<name>`` wrapper calling ``jit_test_<name>()`` are generated.
    Asserts that no test name appears twice.
    """
    source_files = run_cmd('find -L src/ | grep "cc$"', jittor_path).splitlines()
    test_pattern = re.compile("JIT_TEST\\((.*?)\\)")
    names = set()
    declares = []
    test_defs = []

    for rel_name in source_files:
        src_name = os.path.join(jittor_path, rel_name)
        with open(src_name) as f:
            src = f.read()
        for name in test_pattern.findall(src):
            LOG.vv(f"Find test {name} from {src_name}")
            assert name not in names, f"Conflict test name {name}"
            names.add(name)
            declares.append(f"JIT_TEST({name});")
            test_defs.append(f"""
                /* From {src_name} */
                // @pyjt({name})
                static inline void test_{name}() {{ jit_test_{name}(); }}
            """)

    jit_declares = "\n ".join(declares)
    jit_src = f"""
    #pragma once
    #include "common.h"

    void expect_error(std::function<void()> func) {{
        try {{ func(); }}
        catch (...) {{ return; }}
        CHECK(0) << "Missing error";
    }}

    namespace jittor {{

    {jit_declares}

    // @pyjt(tests)
    // @attrs(submodule)
    namespace tests {{
        {"".join(test_defs)}
    }}

    }} // jittor
    """
    LOG.vvvv(jit_src)
    out_path = os.path.join(cache_path, "gen", "jit_tests.h")
    with open(out_path, 'w') as f:
        f.write(jit_src)
def do_compile(cmd):
    """Run a compile command, through the compile cache when it is loaded.

    Args:
        cmd: the full compiler command line.

    Returns:
        The result of ``jit_utils.cc.cache_compile`` when ``jit_utils.cc``
        is available; otherwise ``True`` after running ``cmd`` directly.
    """
    if not jit_utils.cc:
        # cache_compile module not built yet: invoke the compiler directly
        run_cmd(cmd)
        return True
    return jit_utils.cc.cache_compile(cmd, cache_path, jittor_path)
# Kernel flags: user-supplied `kernel_flags` env var, the common optimisation
# flags and OpenMP.
kernel_opt_flags = os.environ.get("kernel_flags", "") + opt_flags + " -fopenmp "

# Only add default optimisation levels when cc_flags carries no -O option.
if ' -O' not in cc_flags:
    opt_flags += " -O2 "
    kernel_opt_flags += " -Ofast "

# Link-time optimisation is opt-in via `enable_lto=1`; flags depend on the
# compiler family.
if os.environ.get("enable_lto") != "1":
    lto_flags = ""
elif cc_type == "icc":
    lto_flags = " -flto -ipo -ipo-c "
elif cc_type == "g++":
    lto_flags = " -flto -fuse-linker-plugin "
else:
    lto_flags = " -flto "

pybind_include = run_cmd(python_path + " -m pybind11 --includes")
LOG.i(f"pybind_include: {pybind_include}")
extension_suffix = run_cmd(py3_config_path + " --extension-suffix")
LOG.i(f"extension_suffix: {extension_suffix}")

# Create the cache directory layout.
make_cache_dir(cache_path)
make_cache_dir(os.path.join(cache_path, "jit"))
make_cache_dir(os.path.join(cache_path, "obj_files"))
make_cache_dir(os.path.join(cache_path, "gen"))

# build cache_compile
cc_flags += pybind_include + f" -I{jittor_path}/src "
check_cache_compile()
LOG.v(f"Get cache_compile: {jit_utils.cc}")
def check_clang_latest_supported_cpu():
    """Return the newest ``apple-a<N>`` CPU name supported by ``clang``.

    Runs ``clang --print-supported-cpus`` and picks the entry with the
    largest numeric suffix among the ``apple-a<N>`` names in its output.

    Returns:
        A string of the form ``apple-a<N>``.

    Raises:
        ValueError: if the output contains no ``apple-a<N>`` entry.
    """
    import re  # local import keeps this block self-contained

    output = run_cmd('clang --print-supported-cpus')
    # Match the numeric suffix directly instead of slicing at a fixed
    # offset, so extra text on a line cannot break the integer parse.
    generations = [int(num) for num in re.findall(r'apple-a(\d+)', output)]
    if not generations:
        raise ValueError(
            "clang --print-supported-cpus lists no 'apple-a<N>' CPU")
    return f'apple-a{max(generations)}'
data_gz_md5_path = os.path.join(cache_path, "data.md5") if os.path.isfile(data_gz_md5_path): with open(data_gz_md5_path, 'r') as f: target_md5 = f.read() data_o_path = os.path.join(cache_path, "data.o") if target_md5 != md5: data_s_path = os.path.join(cache_path, "data.cc") with open(data_s_path, "w") as f: f.write(data.decode("utf8")) dflags = (cc_flags+opt_flags)\ .replace("-Wall", "") \ .replace("-Werror", "") \ .replace("-shared", "") vdp = os.path.join(jittor_path, "src", "utils", "vdp") run_cmd( fix_cl_flags( f"{cc_path} {dflags} -include \"{vdp}\" \"{data_s_path}\" -c -o \"{data_o_path}\"" )) os.remove(data_s_path) with open(data_gz_md5_path, 'w') as f: f.write(md5) files.append(data_o_path) files = [f for f in files if "__data__" not in f] cc_flags += f" -l\"jit_utils_core{lib_suffix}\" " compile(cc_path, cc_flags + opt_flags, files, 'jittor_core' + extension_suffix) cc_flags += f" -l\"jittor_core{lib_suffix}\" " # TODO: move to compile_extern.py # compile_extern() with jit_utils.import_scope(import_flags):
else: kernel_opt_flags = kernel_opt_flags + " -fopenmp " if ' -O' not in cc_flags: opt_flags += " -O2 " kernel_opt_flags += " -Ofast " lto_flags = "" if os.environ.get("enable_lto") == "1": if cc_type == "icc": lto_flags = " -flto -ipo -ipo-c " elif cc_type == "g++": lto_flags = " -flto -fuse-linker-plugin " else: lto_flags = " -flto " py_include = run_cmd(py3_config_path+" --includes") LOG.i(f"py_include: {py_include}") extension_suffix = run_cmd(py3_config_path+" --extension-suffix") LOG.i(f"extension_suffix: {extension_suffix}") make_cache_dir(cache_path) make_cache_dir(os.path.join(cache_path, "jit")) make_cache_dir(os.path.join(cache_path, "obj_files")) make_cache_dir(os.path.join(cache_path, "gen")) ck_path = os.path.join(cache_path, "checkpoints") make_cache_dir(ck_path) # build cache_compile cc_flags += f" -I{jittor_path}/src " cc_flags += py_include check_cache_compile()
for d in dirs: fname = os.path.join(d, name) if os.path.isfile(fname): return fname LOG.f(f"file {name} not found in {dirs}") if __name__ == "__main__": help_msg = f"Usage: {sys.executable} -m jittor_utils.config --include-flags|--link-flags|--cxx-flags|--cxx-example|--help" if len(sys.argv) <= 1: print(help_msg) sys.exit(1) s = "" # base should be something like python3.7m python3.8 base = jittor_utils.run_cmd(jittor_utils.py3_config_path + " --includes").split()[0] base = "python3" + base.split("python3")[-1] for arg in sys.argv[1:]: if arg == "--include-flags": s += jittor_utils.run_cmd(jittor_utils.py3_config_path + " --includes") s += " -I" + os.path.abspath( os.path.join(os.path.dirname(__file__), "..", "jittor", "src")) s += " " elif arg == "--libs-flags": libext = { 'Linux': 'so', 'Darwin': 'dylib', 'Windows': 'DLL', }[platform.system()] ldflags = jittor_utils.run_cmd(jittor_utils.py3_config_path +
s = "" # base should be something like python3.7m python3.8 base = jittor_utils.get_py3_include_path().split()[0] base = "python3" + base.split("python3")[-1] for arg in sys.argv[1:]: if arg == "--include-flags": s += jittor_utils.get_py3_include_path() s += " -I"+os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "jittor", "src")) s += " " elif arg == "--libs-flags": libext = { 'Linux': 'so', 'Darwin': 'dylib', 'Windows': 'DLL', }[platform.system()] ldflags = jittor_utils.run_cmd(jittor_utils.get_py3_config_path() + " --ldflags") libpaths = [l[2:] for l in ldflags.split(' ') if l.startswith("-L")] for libbase in libpaths: libpath = os.path.join(libbase, f"lib{base}.{libext}") if os.path.isfile(libpath): s += f" -L{libbase} -l{base} -ldl " break else: raise RuntimeError("Python dynamic library not found") if os.name == 'nt': s = s.replace('-ldl', '') elif arg == "--cxx-flags": s += " --std=c++17 -fPIC " elif arg == "--cxx-example": cc_src = ''' // please compile with: g++ a.cc $(python3 -m jittor_utils.config --include-flags --libs-flags --cxx-flags) -o a.out && ./a.out
def test_console(self):
    """Build and run the C++ example emitted by ``jittor_utils.config``.

    Writes the ``--cxx-example`` output to ``tmp.cc`` in the cache
    directory, compiles it with the flags reported by the same module,
    runs the binary and checks its output for the expected markers.
    """
    workdir = jt.flags.cache_path
    config_cmd = f"{sys.executable} -m jittor_utils.config"

    # step 1: dump the example source
    run_cmd(f"{config_cmd} --cxx-example > tmp.cc", workdir)

    # step 2: compile with the advertised flags and execute the result
    build_and_run = (
        f"{jt.flags.cc_path} tmp.cc "
        f"$({config_cmd} --include-flags --libs-flags --cxx-flags)"
        " -o tmp.out && ./tmp.out"
    )
    s = run_cmd(build_and_run, workdir)
    print(s)
    assert "jt.Var" in s
    assert "pred.shape 2 1000" in s
def setup_mpi():
    """Detect MPI, build the MPI custom ops and attach them to ``core.Var``.

    Sets the module globals ``use_mpi``, ``has_mpi``, ``mpicc_path``,
    ``mpi`` and ``mpi_ops``. MPI is used only when all of these hold:
    the ``use_mpi`` environment variable is ``"1"`` (the default),
    ``mpicc`` is found, and ``inside_mpi()`` is true. In that case the
    sources under ``<jittor_path>/extern/mpi`` are compiled with the flags
    reported by ``mpicc --showme`` and every op named ``mpi_*`` (except
    ``mpi_test``) is installed as a method on ``core.Var``. The globals
    ``mpi_compile_flags``, ``mpi_link_flags`` and ``mpi_flags`` are set
    as well.
    """
    global mpi_ops, mpi, use_mpi
    global mpicc_path, has_mpi
    global mpi_compile_flags, mpi_link_flags, mpi_flags

    use_mpi = os.environ.get("use_mpi", "1") == "1"
    mpi_ops = None
    mpi = None
    has_mpi = False
    mpicc_path = env_or_try_find('mpicc_path', 'mpicc')
    if mpicc_path == "":
        LOG.i("mpicc not found, distribution disabled.")
        use_mpi = False
    else:
        # Keep the value read from the environment: previously it was
        # overwritten with True here, so `use_mpi=0` was silently ignored.
        has_mpi = True

    if not inside_mpi():
        use_mpi = False

    if not use_mpi:
        return

    mpi_compile_flags = run_cmd(mpicc_path + " --showme:compile")
    mpi_link_flags = run_cmd(mpicc_path + " --showme:link")
    mpi_flags = mpi_compile_flags + " " + mpi_link_flags
    LOG.v("mpi_flags: " + mpi_flags)

    # find all source files
    mpi_src_dir = os.path.join(jittor_path, "extern", "mpi")
    mpi_src_files = [
        os.path.join(root, fname)
        for root, _, fnames in os.walk(mpi_src_dir)
        for fname in fnames
    ]

    # mpi compile flags add for nccl
    # (mpi_flags above was built before this adjustment and is unaffected)
    mpi_compile_flags += f" -I'{os.path.join(mpi_src_dir, 'inc')}' "
    mpi_compile_flags = mpi_compile_flags.replace("-pthread", "")

    mpi_version = get_version(mpicc_path)
    if mpi_version.startswith(("(1.", "(2.")):
        # mpi version 1.x need to link like this
        manual_link(mpi_flags)
    # mpi(4.x) cannot use deepbind, it need to
    # share the 'environ' symbol.
    mpi = compile_custom_ops(mpi_src_files,
                             extra_flags=f" {mpi_flags} ",
                             return_module=True,
                             dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW)
    mpi_ops = mpi.ops
    LOG.vv("Get mpi: " + str(mpi.__dict__.keys()))
    LOG.vv("Get mpi_ops: " + str(mpi_ops.__dict__.keys()))

    def warper(func):
        # Wrap the op in a plain Python function that forwards all
        # arguments and carries over the op's docstring.
        def inner(self, *args, **kw):
            return func(self, *args, **kw)
        inner.__doc__ = func.__doc__
        return inner

    for k in mpi_ops.__dict__:
        if not k.startswith("mpi_"):
            continue
        if k == "mpi_test":
            continue
        setattr(core.Var, k, warper(mpi_ops.__dict__[k]))