コード例 #1
0
ファイル: compile_extern.py プロジェクト: wcq19941215/jittor
def setup_mpi():
    """Detect MPI and, when enabled, build and load jittor's MPI extension.

    Sets the module globals ``use_mpi``, ``has_mpi``, ``mpicc_path``, ``mpi``
    and ``mpi_ops``. MPI is used only when the ``use_mpi`` environment
    variable is ``"1"`` (the default) AND an ``mpicc`` executable is found
    (via the ``mpicc_path`` environment variable or a search for ``mpicc``).
    On return ``has_mpi`` and ``use_mpi`` are both True or both False.

    When MPI is used this also sets ``mpi_compile_flags``,
    ``mpi_link_flags`` and ``mpi_flags``, compiles every file under
    ``<jittor_path>/extern/mpi`` into the ``jittor_mpi_core`` module, and
    attaches every ``mpi_*`` op except ``mpi_test`` to ``core.Var``.
    """
    global mpi_ops, mpi, use_mpi
    global mpicc_path, has_mpi
    global mpi_compile_flags, mpi_link_flags, mpi_flags
    # Read the user's opt-out (e.g. ``use_mpi=0``); it is honoured below
    # instead of being overwritten once mpicc is found.
    use_mpi = os.environ.get("use_mpi", "1") == "1"
    mpi_ops = None
    mpi = None
    has_mpi = False
    # Normalise a missing result to "" so string concatenation below is safe.
    mpicc_path = env_or_try_find('mpicc_path', 'mpicc') or ""
    if not mpicc_path:
        LOG.i("mpicc not found, distribution disabled.")
        use_mpi = False
    if not use_mpi:
        return
    has_mpi = True

    mpi_compile_flags = run_cmd(mpicc_path + " --showme:compile")
    mpi_link_flags = run_cmd(mpicc_path + " --showme:link")
    mpi_flags = mpi_compile_flags + " " + mpi_link_flags
    LOG.v("mpi_flags: " + mpi_flags)

    # Every file under extern/mpi is passed as a source of the extension.
    mpi_src_dir = os.path.join(jittor_path, "extern", "mpi")
    mpi_src_files = [
        os.path.join(root, fname)
        for root, _, fnames in os.walk(mpi_src_dir)
        for fname in fnames
    ]

    # Extra compile flags exported for nccl. Note that ``mpi_flags`` was
    # built above, before this include path / -pthread adjustment.
    mpi_compile_flags += f" -I'{os.path.join(mpi_src_dir, 'inc')}' "
    mpi_compile_flags = mpi_compile_flags.replace("-pthread", "")

    mpi_version = get_version(mpicc_path)
    if mpi_version.startswith(("(1.", "(2.")):
        # MPI 1.x / 2.x need to be linked manually like this.
        manual_link(mpi_flags)
    # MPI 4.x cannot be loaded with deepbind: it needs to share the
    # 'environ' symbol, hence only RTLD_GLOBAL | RTLD_NOW.
    mpi = compile_custom_ops(mpi_src_files,
                             extra_flags=f" {mpi_flags} ",
                             return_module=True,
                             dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW,
                             gen_name_="jittor_mpi_core")
    mpi_ops = mpi.ops
    LOG.vv("Get mpi: " + str(mpi.__dict__.keys()))
    LOG.vv("Get mpi_ops: " + str(mpi_ops.__dict__.keys()))

    def _as_method(func):
        # Wrap the op in a plain Python function so that, once set on
        # core.Var, it binds as a method and receives the Var as its first
        # argument (presumably the compiled ops would not bind on their
        # own — NOTE(review): confirm).
        def method(self, *args, **kw):
            return func(self, *args, **kw)

        method.__doc__ = func.__doc__
        return method

    for name, op in mpi_ops.__dict__.items():
        if name.startswith("mpi_") and name != "mpi_test":
            setattr(core.Var, name, _as_method(op))
コード例 #2
0
def env_or_try_find(name, bname):
    """Locate executable *bname*, preferring the environment variable *name*.

    If *name* is present in the environment, its value is returned as-is,
    including an empty string (no search is done in that case). For a
    non-empty value, the tool's version is queried and logged. If *name* is
    absent, the result of ``try_find_exe(bname)`` is returned.
    """
    if name not in os.environ:
        return try_find_exe(bname)
    env_path = os.environ[name]
    if env_path:
        LOG.i(f"Found {bname}{jit_utils.get_version(env_path)} at {env_path}")
    return env_path
コード例 #3
0
ファイル: compile_extern.py プロジェクト: wcq19941215/jittor
def setup_cub():
    """Configure the CUB library for CUDA builds.

    Sets the global ``cub_home``. When the nvcc major version is below 11,
    CUB is installed into ``~/.cache/jittor/cub`` via ``install_cub``. Its
    returned path is passed as an ``-I`` include flag, and ``cub_home``
    becomes that path with a trailing ``/``. Otherwise ``cub_home`` stays
    ``""`` and no extra flag is used (presumably because newer toolkits
    bundle CUB — confirm). ``setup_cuda_lib("cub", link=False, ...)`` is
    called in both cases.
    """
    global cub_home
    from pathlib import Path

    cub_home = ""
    install_dir = os.path.join(str(Path.home()), ".cache", "jittor", "cub")
    # get_version wraps the version in one leading/trailing character
    # (e.g. "(10.2)"); strip them and keep the major component.
    major = int(get_version(nvcc_path)[1:-1].partition('.')[0])

    include_flag = ""
    if major < 11:
        cub_home = install_cub(install_dir)
        include_flag = f"-I{cub_home}"
        cub_home += "/"
    setup_cuda_lib("cub", link=False, extra_flags=include_flag)
コード例 #4
0
ファイル: compiler.py プロジェクト: Exusial/jittor
        try_find_exe('/opt/cuda/bin/nvcc')
# if system has no cuda, install jtcuda
if not nvcc_path:
    nvcc_path = install_cuda.install_cuda()
    if nvcc_path:
        nvcc_path = try_find_exe(nvcc_path)
if nvcc_path is None:
    nvcc_path = ""
gdb_path = env_or_try_find('gdb_path', 'gdb')
addr2line_path = try_find_exe('addr2line')
has_pybt = check_pybt(gdb_path, python_path)

if nvcc_path:
    # Build a CUDA key (nvcc version + detected compute capabilities) and
    # use it as a sub-directory of cache_path, so builds for different
    # toolkits/GPUs do not share a cache.
    cu = "cu"
    # get_version wraps the version in one leading/trailing character
    # (e.g. "(11.2)"); strip them.
    v = jit_utils.get_version(nvcc_path)[1:-1]
    nvcc_version = list(map(int, v.split('.')))
    cu += v
    try:
        r, s = sp.getstatusoutput(
            f"{sys.executable} -m jittor_utils.query_cuda_cc")
        if r == 0:
            # Unique, sorted compute-capability tokens from the query output.
            s = sorted(set(s.strip().split()))
            cu += "_sm_" + "_".join(s)
            if "cuda_arch" not in os.environ:
                # Export the detected capabilities space-separated. The
                # original joined the characters of the key string ``cu``.
                os.environ["cuda_arch"] = " ".join(s)
            cu = cu.replace(":", "").replace(" ", "")
    except Exception as e:
        # Best effort: without the query the key is just the nvcc version.
        LOG.v(f"query_cuda_cc failed: {e}")
    LOG.i("cuda key:", cu)
    cache_path = os.path.join(cache_path, cu)