def setup_mpi():
    """Detect MPI and, when it is usable, build and load Jittor's MPI extension.

    Always (re)sets the module globals ``use_mpi``, ``has_mpi``,
    ``mpicc_path``, ``mpi`` and ``mpi_ops``. In these two cases MPI support
    stays off (``use_mpi``/``has_mpi`` False, ``mpi``/``mpi_ops`` None):

    * ``mpicc`` cannot be located;
    * the environment variable ``use_mpi`` is set to anything other than
      ``"1"``.

    Otherwise the function also:

    * sets ``mpi_compile_flags``, ``mpi_link_flags`` and ``mpi_flags`` from
      ``mpicc --showme:compile`` / ``--showme:link``;
    * compiles every file under ``<jittor_path>/extern/mpi`` into the
      ``jittor_mpi_core`` module;
    * attaches each ``mpi_*`` op (except ``mpi_test``) to ``core.Var`` so it
      can be called as a method.
    """
    global mpi_ops, mpi, use_mpi
    global mpicc_path, has_mpi
    global mpi_compile_flags, mpi_link_flags, mpi_flags
    # Read the user's opt-out first. It used to be overwritten
    # unconditionally by the mpicc check, so ``use_mpi=0`` had no effect.
    enabled_by_env = os.environ.get("use_mpi", "1") == "1"
    mpi_ops = None
    mpi = None
    use_mpi = False
    has_mpi = False
    # Resolve mpicc even when MPI is disabled, so ``mpicc_path`` is always
    # defined for any code that reads it later.
    mpicc_path = env_or_try_find('mpicc_path', 'mpicc')
    if not mpicc_path:
        # ``not`` also covers a None path, which previously slipped through
        # and crashed on the string concatenation below.
        LOG.i("mpicc not found, distribution disabled.")
        return
    if not enabled_by_env:
        LOG.i("use_mpi is not 1, distribution disabled.")
        return
    use_mpi = True
    has_mpi = True

    mpi_compile_flags = run_cmd(mpicc_path + " --showme:compile")
    mpi_link_flags = run_cmd(mpicc_path + " --showme:link")
    mpi_flags = mpi_compile_flags + " " + mpi_link_flags
    LOG.v("mpi_flags: " + mpi_flags)

    # find all source files
    mpi_src_dir = os.path.join(jittor_path, "extern", "mpi")
    mpi_src_files = [
        os.path.join(root, fname)
        for root, _, fnames in os.walk(mpi_src_dir)
        for fname in fnames
    ]

    # mpi compile flags add for nccl: expose the extension's own headers and
    # drop -pthread from the flags handed on to other builds.
    mpi_compile_flags += f" -I'{os.path.join(mpi_src_dir, 'inc')}' "
    mpi_compile_flags = mpi_compile_flags.replace("-pthread", "")

    mpi_version = get_version(mpicc_path)
    if mpi_version.startswith(("(1.", "(2.")):
        # mpi version 1.x need to link like this
        manual_link(mpi_flags)
    # mpi(4.x) cannot use deepbind, it need to
    # share the 'environ' symbol.
    mpi = compile_custom_ops(mpi_src_files,
        extra_flags=f" {mpi_flags} ",
        return_module=True,
        dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW,
        gen_name_="jittor_mpi_core")
    mpi_ops = mpi.ops
    LOG.vv("Get mpi: " + str(mpi.__dict__.keys()))
    LOG.vv("Get mpi_ops: " + str(mpi_ops.__dict__.keys()))

    def wrapper(func):
        # Wrap each op in a plain Python function. Presumably the compiled
        # ops are builtin callables that would not bind ``self`` when stored
        # on a class, whereas a Python function does. TODO confirm.
        def inner(self, *args, **kw):
            return func(self, *args, **kw)
        inner.__doc__ = func.__doc__
        return inner

    for k, func in mpi_ops.__dict__.items():
        if not k.startswith("mpi_") or k == "mpi_test":
            continue
        setattr(core.Var, k, wrapper(func))
def env_or_try_find(name, bname):
    """Return the path of the tool *bname*, preferring the env var *name*.

    If the environment variable *name* is set, its value is returned
    unchanged, even when it is the empty string. Callers such as
    ``setup_mpi`` treat ``""`` as "tool not available". A non-empty value is
    also probed with ``jit_utils.get_version`` and reported via ``LOG.i``.
    When *name* is not set at all, the lookup falls back to
    ``try_find_exe(bname)``.
    """
    override = os.environ.get(name)
    if override is None:
        return try_find_exe(bname)
    if override:
        found_version = jit_utils.get_version(override)
        LOG.i(f"Found {bname}{found_version} at {override}")
    return override
def setup_cub():
    """Register CUB as a header-only CUDA library for the current nvcc.

    When nvcc reports a major version below 11, a copy of CUB is installed
    under ``~/.cache/jittor/cub`` via ``install_cub``, and its directory is
    passed as an include flag. This is presumably because newer toolkits
    bundle CUB themselves. TODO confirm.

    Sets the global ``cub_home`` to that directory with a trailing ``/``, or
    to ``""`` when no separate copy is installed. In both cases it calls
    ``setup_cuda_lib("cub", link=False, ...)``.
    """
    global cub_home
    cub_home = ""
    from pathlib import Path
    cub_path = os.path.join(str(Path.home()), ".cache", "jittor", "cub")
    # get_version output looks like "(10.2)": strip the parentheses and
    # keep the major number.
    cuda_major = int(get_version(nvcc_path)[1:-1].split('.')[0])
    extra_flags = ""
    if cuda_major < 11:
        cub_home = install_cub(cub_path)
        # Quote the include dir so a path containing spaces (e.g. a home
        # directory) stays a single argument, assuming the flags end up on a
        # shell command line.
        extra_flags = f"-I'{cub_home}'"
        cub_home += "/"
    setup_cuda_lib("cub", link=False, extra_flags=extra_flags)
try_find_exe('/opt/cuda/bin/nvcc')
# NOTE(review): the line above continues an expression that starts before
# this chunk (presumably the ``nvcc_path = ... or`` lookup chain).

# if system has no cuda, install jtcuda
if not nvcc_path:
    nvcc_path = install_cuda.install_cuda()
    if nvcc_path:
        nvcc_path = try_find_exe(nvcc_path)
# Normalize a missing path (None) to "" so later code can rely on a str.
if nvcc_path is None:
    nvcc_path = ""

gdb_path = env_or_try_find('gdb_path', 'gdb')
addr2line_path = try_find_exe('addr2line')
has_pybt = check_pybt(gdb_path, python_path)

if nvcc_path:
    # gen cuda key for cache_path: "cu<version>[_sm_<archs>]", so caches
    # built for different toolkits / GPUs do not collide.
    cu = "cu"
    v = jit_utils.get_version(nvcc_path)[1:-1]
    nvcc_version = list(map(int, v.split('.')))
    cu += v
    try:
        r, s = sp.getstatusoutput(
            f"{sys.executable} -m jittor_utils.query_cuda_cc")
        if r == 0:
            s = sorted(set(s.strip().split()))
            cu += "_sm_" + "_".join(s)
            if "cuda_arch" not in os.environ:
                # Export the detected architectures themselves. Joining
                # ``cu`` here spread the cache-key string out one
                # character at a time.
                os.environ["cuda_arch"] = " ".join(s)
        cu = cu.replace(":", "").replace(" ", "")
    except Exception as e:
        # Best effort: keep the version-only cache key if the query fails.
        LOG.v(f"query_cuda_cc failed: {e}")
    LOG.i("cuda key:", cu)
    cache_path = os.path.join(cache_path, cu)