示例#1
0
def compile_extern():
    """Compile the LLVM pass plugins in ``extern/llvm`` and add them to ``kernel_opt_flags``.

    Returns immediately unless ``cc_type`` is ``"clang"`` and the platform is
    Linux.  Each file found in ``<jittor_path>/extern/llvm`` is compiled into
    a shared object under ``<cache_path>/llvm``.  The candidate flag sets in
    ``try_flags`` are tried in order; a set is accepted once clang can load
    the resulting plugin while compiling a trivial test program, and that same
    set is then reused, without re-testing, for the remaining files.  Every
    accepted plugin is appended to the global ``kernel_opt_flags`` as a
    ``-Xclang -load`` option.

    Raises:
        AssertionError: if the clang directory does not end with ``bin`` or
            does not contain ``llvm``, or if the LLVM include directory is
            missing.
    """
    # compile llvm passes
    if cc_type != "clang" or platform.system() != 'Linux':
        return
    global kernel_opt_flags
    cache_path_llvm = os.path.join(cache_path, "llvm")
    jittor_path_llvm = os.path.join(jittor_path, "extern", "llvm")
    # The LLVM headers are looked up relative to the clang executable:
    # <clang_dir>/../include.
    clang_dir = os.path.dirname(get_full_path_of_executable(cc_path))
    assert clang_dir.endswith(
        "bin") and "llvm" in clang_dir, f"Wrong clang_dir: {clang_dir}"
    llvm_include = os.path.abspath(os.path.join(clang_dir, "..", "include"))
    assert os.path.isdir(llvm_include), "LLVM include path not found"
    make_cache_dir(cache_path_llvm)
    files = os.listdir(jittor_path_llvm)
    # test_pass.cc is a trivial program used to check whether a freshly built
    # LLVM pass plugin can actually be loaded by clang.
    test_pass_path = os.path.join(cache_path_llvm, "test_pass.cc")
    with open(test_pass_path, 'w') as f:
        f.write("int main() {return 0;}")

    # -fno-rtti fix link error

    # -Wl,-znodelete fix segfault
    # https://github.com/sampsyo/llvm-pass-skeleton/issues/7#issuecomment-401834287

    # -D_GLIBCXX_USE_CXX11_ABI=0 fix undefined symbol: createPrinterPass
    # https://stackoverflow.com/questions/37366291/undefined-symbol-for-self-built-llvm-opt

    # Candidate flag sets, tried in order until one yields a loadable plugin.
    try_flags = [
        " -Wl,-znodelete -D_GLIBCXX_USE_CXX11_ABI=0 ",
        " -Wl,-znodelete ",
    ]
    # Index into try_flags of the first flag set that worked; -1 until found.
    found_flags_id = -1
    for fname in files:
        for i, flag in enumerate(try_flags):
            # Once a working flag set is known, skip every other candidate.
            if found_flags_id != -1 and found_flags_id != i:
                continue
            # The flag index is part of the file name, so builds made with
            # different flag sets do not overwrite each other.
            so_name = os.path.join(cache_path_llvm,
                                   os.path.splitext(fname)[0] + f".{i}.so")
            compile(cc_path,
                    f"{cc_flags} {opt_flags} {flag} -I'{llvm_include}'",
                    [os.path.join(jittor_path_llvm, fname)], so_name)
            # No working flag set yet: verify that clang can load this plugin.
            if found_flags_id == -1:
                try:
                    s = run_cmd(
                        f"{cc_path} {cc_flags} -Xclang -load -Xclang '{so_name}' {test_pass_path}",
                        cache_path_llvm,
                        print_error=False)
                except Exception as e:
                    # Loading failed; move on to the next candidate flag set.
                    LOG.v(f"Try flag {flag} failed: {e}")
                    continue
                found_flags_id = i
            # Register the plugin for kernel compilation and go to the next file.
            kernel_opt_flags += f" -Xclang -load -Xclang '{so_name}' "
            break
        else:
            # The inner loop ended without `break`: no flag set produced a
            # loadable plugin, so warn and stop processing further files.
            LOG.w("Clang is used, but LLVM pass plugin is unable to link.")
            break
    LOG.vv(f"Compile extern llvm passes: {str(files)}")
示例#2
0
def compile(cache_path, jittor_path):
    """Generate the pyjt binding sources and the ``pyjt_all.cc`` aggregator.

    Headers are collected from ``src/`` under ``jittor_path`` (following
    symlinks) and from ``gen/`` under ``cache_path``.  Each header is passed
    to ``compile_single`` together with the target file name
    ``<cache_path>/gen/pyjt_<basename>.cc``.  For every header where
    ``compile_single`` returns a truthy value, a ``pyjt_def_<basename>``
    declaration and call is emitted into ``<cache_path>/gen/pyjt_all.cc``.

    Args:
        cache_path: directory containing the ``gen`` folder that is scanned
            and written to.
        jittor_path: directory containing the ``src`` folder that is scanned.
    """
    # NOTE(review): the "." in the grep pattern is an unescaped regex
    # wildcard, so any path ending in "<any char>h" matches, not only "*.h".
    headers1 = run_cmd('find -L src/ | grep ".h$"', jittor_path).splitlines()
    headers2 = run_cmd('find gen/ | grep ".h$"', cache_path).splitlines()
    headers = [os.path.join(jittor_path, h) for h in headers1] + \
        [os.path.join(cache_path, h) for h in headers2]
    var_holder_path = os.path.join(jittor_path, "src", "var_holder.h")
    basenames = []
    for h in headers:
        # var_holder.h is never compiled on its own (it is merged into
        # jit_op_maker.h below), so skip it before touching the file.
        if h.endswith("src/var_holder.h"):
            continue
        with open(h, 'r') as f:
            src = f.read()

        # jit_op_maker.h is compiled together with var_holder.h: the content
        # of var_holder.h is prepended to it.
        if h.endswith("jit_op_maker.h"):
            with open(var_holder_path, "r") as f:
                src = f.read() + src
        basename = h.split("/")[-1].split(".")[0]
        fname = "pyjt_" + basename + ".cc"
        fname = os.path.join(cache_path, "gen", fname)
        check = compile_single(h, fname, src)

        # Only headers accepted by compile_single are registered.
        if not check:
            continue

        basenames.append(basename)

    code = f"""
    #include "pyjt/numpy.h"
    #include "pyjt/py_converter.h"
    #include "common.h"

    namespace jittor {{

    { " ".join([f"extern void pyjt_def_{n}(PyObject* m);" for n in basenames])}

    void pyjt_def_all(PyObject* m) {{
        numpy_init();
        { " ".join([f"pyjt_def_{n}(m);" for n in basenames])}
    }}

    }}
    """
    fname = os.path.join(cache_path, "gen", "pyjt_all.cc")
    LOG.vvv(("write to", fname))
    LOG.vvvv(code)
    with open(fname, "w") as f:
        f.write(code)
示例#3
0
def gen_jit_flags():
    """Scan the C++ sources for flag definitions and write ``gen/jit_flags.h``.

    Every ``*cc`` file found under ``src/`` in ``jittor_path`` is searched
    for ``DEFINE_FLAG(...)`` / ``DEFINE_FLAG_WITH_SETTER(...)`` macros.  For
    the first occurrence of each flag name a ``DECLARE_FLAG`` line and a
    pyjt getter/setter pair are generated (``int`` flags get an additional
    ``bool`` setter).  Flags whose name contains ``cuda`` (other than
    ``use_cuda``) are skipped when ``has_cuda`` is false.  The result is
    written to ``<cache_path>/gen/jit_flags.h``.
    """
    flag_pattern = re.compile("DEFINE_FLAG(_WITH_SETTER)?\\((.*?)\\);", re.DOTALL)
    cc_files = run_cmd('find -L src/ | grep "cc$"', jittor_path).splitlines()

    seen_names = set()
    declare_lines = []
    flags_defs = []

    for rel_name in cc_files:
        src_name = os.path.join(jittor_path, rel_name)
        with open(src_name) as f:
            content = f.read()
        # Each match is (optional "_WITH_SETTER" group, raw macro arguments).
        for _, raw_args in flag_pattern.findall(content):
            parts = raw_args.split(",")
            type, name = parts[0].strip(), parts[1].strip()
            if not has_cuda and "cuda" in name and name != "use_cuda":
                continue
            default = parts[2].strip()
            # Everything after the third argument is the doc text; it is
            # evaluated as a parenthesised Python expression.
            # NOTE(review): eval() executes text taken from the scanned
            # source files -- only safe while those files are trusted.
            doc = ",".join(parts[3:])
            doc = eval(f"({doc})")
            LOG.vv(f"Find define {name} from {src_name}")
            # Only the first definition of a flag name is emitted.
            if name in seen_names:
                continue
            seen_names.add(name)
            declare_lines.append(f"DECLARE_FLAG({type}, {name});")
            flags_defs.append(f"""
                /* {name}(type:{type}, default:{default}): {doc} */
                // @pyjt(__get__{name})
                {type} _get_{name}() {{ return {name}; }}
                // @pyjt(__set__{name})
                void _set_{name}({type} v) {{ set_{name}(v); }}
                {f'''// @pyjt(__set__{name})
                void _set_{name}(bool v) {{ set_{name}(v); }}
                ''' if type=="int" else ""}
            """)

    jit_declares = "\n    ".join(declare_lines)
    jit_src = f"""
    #include "utils/flags.h"

    namespace jittor {{
    
    {jit_declares}

    // @pyjt(flags)
    struct _Flags {{
        // @pyjt(__init__)
        _Flags() {{}}
        {"".join(flags_defs)}
    }};

    }} // jittor
    """
    LOG.vvvv(jit_src)
    header_path = os.path.join(cache_path, "gen", "jit_flags.h")
    with open(header_path, 'w') as f:
        f.write(jit_src)
示例#4
0
def install_cutt(root_folder):
    """Download, unpack and build cutt inside ``root_folder``.

    An existing archive whose md5 does not match is deleted together with
    the extracted directory (if there is one).  The download, extraction and
    ``make`` steps only run when ``<dirname>/bin/cutt_test`` does not exist.

    Args:
        root_folder: directory that receives the archive and the extracted
            ``cutt-master`` source tree.

    Returns:
        Path of the ``cutt-master`` directory inside ``root_folder``.
    """
    # Modified from: https://github.com/ap-hynninen/cutt
    # Archive page: https://github.com/Jittor/cutt/archive/master.zip
    # (the codeload URL below is the one actually downloaded).
    url = "https://codeload.github.com/Jittor/cutt/zip/master"

    filename = "cutt-master.zip"
    fullname = os.path.join(root_folder, filename)
    dirname = os.path.join(root_folder, filename.replace(".zip", ""))
    true_md5 = "af5bc35eea1832a42c0e0011659b7209"

    if os.path.exists(fullname):
        md5 = run_cmd('md5sum ' + fullname).split()[0]
        if md5 != true_md5:
            # Stale or corrupted archive: drop it and anything extracted
            # from it.  The directory may not exist yet (e.g. a previous
            # download was interrupted), so only remove it when present.
            os.remove(fullname)
            if os.path.isdir(dirname):
                shutil.rmtree(dirname)
    if not os.path.isfile(os.path.join(dirname, "bin", "cutt_test")):
        LOG.i("Downloading cutt...")
        download_url_to_local(url, filename, root_folder, true_md5)

        import zipfile

        # The context manager closes the archive even when extraction fails.
        with zipfile.ZipFile(fullname) as zf:
            try:
                zf.extractall(path=root_folder)
            except RuntimeError as e:
                print(e)
                raise

        LOG.i("installing cutt...")
        # Build for the lowest configured virtual architecture and emit code
        # for every configured real architecture.
        arch_flag = ""
        if len(flags.cuda_archs):
            arch_flag = f" -arch=compute_{min(flags.cuda_archs)} "
            arch_flag += ''.join(
                f' -code=sm_{x} ' for x in flags.cuda_archs)
        run_cmd(f"make NVCC_GENCODE='{arch_flag}' nvcc_path='{nvcc_path}'",
                cwd=dirname)
    return dirname
示例#5
0
def install_nccl(root_folder):
    """Download, unpack and build NCCL 2.8.4-1 inside ``root_folder``.

    A cached archive with a wrong md5 is removed together with the extracted
    directory.  When ``build/lib/libnccl.so`` is missing, the archive is
    downloaded; extraction and the build are then only performed if at least
    one device is reported and the process runs inside MPI.

    Args:
        root_folder: directory that receives the archive and the source tree.

    Returns:
        Path of the ``nccl-2.8.4-1`` directory, or ``None`` when the build
        was skipped because no device is available or MPI is not in use.
    """
    # The second assignment overrides the first; only the codeload URL is used.
    url = "https://github.com/NVIDIA/nccl/archive/v2.8.4-1.tar.gz"
    url = "https://codeload.github.com/NVIDIA/nccl/tar.gz/v2.8.4-1"

    filename = "nccl.tgz"
    true_md5 = "900666558c5bc43e0a5e84045b88a06f"
    archive_path = os.path.join(root_folder, filename)
    dirname = os.path.join(root_folder, "nccl-2.8.4-1")
    built_library = os.path.join(dirname, "build", "lib", "libnccl.so")

    # Discard a cached archive (and its extracted tree) that fails the check.
    if os.path.exists(archive_path):
        cached_md5 = run_cmd('md5sum ' + archive_path).split()[0]
        if cached_md5 != true_md5:
            os.remove(archive_path)
            if os.path.isdir(dirname):
                shutil.rmtree(dirname)

    # Already built: nothing left to do.
    if os.path.isfile(built_library):
        return dirname

    LOG.i("Downloading nccl...")
    download_url_to_local(url, filename, root_folder, true_md5)

    # The archive is fetched regardless, but it is only built when a device
    # is present and the process runs inside MPI.
    if core.get_device_count() == 0 or not inside_mpi():
        return

    import tarfile
    with tarfile.open(archive_path, "r") as archive:
        archive.extractall(root_folder)

    LOG.i("installing nccl...")
    arch_flag = ""
    if len(flags.cuda_archs):
        arch_flag = f" -arch=compute_{min(flags.cuda_archs)} "
        arch_flag += ''.join(f' -code=sm_{x} ' for x in flags.cuda_archs)
    build_cmd = f"make -j8 src.build CUDA_HOME='{cuda_home}' NVCC_GENCODE='{arch_flag}' "
    run_cmd(build_cmd, cwd=dirname)
    return dirname
示例#6
0
def gen_jit_tests():
    """Collect ``JIT_TEST(...)`` declarations and write ``gen/jit_tests.h``.

    Every ``*cc`` file found under ``src/`` in ``jittor_path`` is searched
    for ``JIT_TEST(name)``.  Each name gets a ``JIT_TEST`` declaration and a
    pyjt-exposed ``test_<name>`` wrapper in the generated ``tests``
    namespace.  The header is written to ``<cache_path>/gen/jit_tests.h``.

    Raises:
        AssertionError: if the same test name is found more than once.
    """
    test_pattern = re.compile("JIT_TEST\\((.*?)\\)")
    cc_files = run_cmd('find -L src/ | grep "cc$"', jittor_path).splitlines()

    names = set()
    declare_lines = []
    test_defs = []

    for rel_name in cc_files:
        src_name = os.path.join(jittor_path, rel_name)
        with open(src_name) as f:
            content = f.read()
        for name in test_pattern.findall(content):
            LOG.vv(f"Find test {name} from {src_name}")
            # Test names must be unique across all scanned sources.
            assert name not in names, f"Conflict test name {name}"
            names.add(name)
            declare_lines.append(f"JIT_TEST({name});")
            test_defs.append(f"""
                /* From {src_name} */
                // @pyjt({name})
                static inline void test_{name}() {{ jit_test_{name}(); }} 
            """)

    jit_declares = "\n    ".join(declare_lines)
    jit_src = f"""
    #pragma once
    #include "common.h"

    void expect_error(std::function<void()> func) {{
        try {{ func(); }}
        catch (...) {{ return; }}
        CHECK(0) << "Missing error";
    }}

    namespace jittor {{
    
    {jit_declares}

    // @pyjt(tests)
    // @attrs(submodule)
    namespace tests {{
        {"".join(test_defs)}
    }}

    }} // jittor
    """
    LOG.vvvv(jit_src)
    header_path = os.path.join(cache_path, "gen", "jit_tests.h")
    with open(header_path, 'w') as f:
        f.write(jit_src)
示例#7
0
 def do_compile(cmd):
     """Run a compile command, through ``jit_utils.cc`` when it is available.

     Returns the result of ``jit_utils.cc.cache_compile`` when ``jit_utils.cc``
     is truthy; otherwise runs ``cmd`` directly with ``run_cmd`` and
     returns ``True``.
     """
     cache_compiler = jit_utils.cc
     if not cache_compiler:
         # No cache compiler yet: execute the command as-is.
         run_cmd(cmd)
         return True
     return cache_compiler.cache_compile(cmd, cache_path, jittor_path)
示例#8
0
# Flags for compiling kernels: the user's "kernel_flags" environment variable,
# followed by the current opt_flags and OpenMP support.  opt_flags is captured
# here BEFORE the default "-O2" below is appended to it.
kernel_opt_flags = os.environ.get("kernel_flags",
                                  "") + opt_flags + " -fopenmp "

# Add default optimisation levels only when cc_flags does not already
# contain an " -O..." option.
if ' -O' not in cc_flags:
    opt_flags += " -O2 "
    kernel_opt_flags += " -Ofast "
# Link-time-optimisation flags; enabled only when the environment variable
# enable_lto is "1", with the spelling chosen per compiler type.
lto_flags = ""
if os.environ.get("enable_lto") == "1":
    if cc_type == "icc":
        lto_flags = " -flto -ipo -ipo-c "
    elif cc_type == "g++":
        lto_flags = " -flto -fuse-linker-plugin "
    else:
        lto_flags = " -flto "

# Ask pybind11 for its include flags and python-config for the extension
# module suffix.
pybind_include = run_cmd(python_path + " -m pybind11 --includes")
LOG.i(f"pybind_include: {pybind_include}")
extension_suffix = run_cmd(py3_config_path + " --extension-suffix")
LOG.i(f"extension_suffix: {extension_suffix}")

# Create the cache directory layout.
make_cache_dir(cache_path)
make_cache_dir(os.path.join(cache_path, "jit"))
make_cache_dir(os.path.join(cache_path, "obj_files"))
make_cache_dir(os.path.join(cache_path, "gen"))

# build cache_compile
# cc_flags gets the pybind11 and jittor "src" include paths before the call.
# NOTE(review): check_cache_compile() presumably populates jit_utils.cc, which
# is logged right after -- confirm against its definition.
cc_flags += pybind_include
cc_flags += f" -I{jittor_path}/src "
check_cache_compile()
LOG.v(f"Get cache_compile: {jit_utils.cc}")
示例#9
0
def check_clang_latest_supported_cpu():
    """Return the newest ``apple-a<N>`` CPU name that clang reports.

    Runs ``clang --print-supported-cpus`` and picks the ``apple-a<N>`` entry
    with the largest numeric suffix.

    Returns:
        The CPU name, e.g. ``'apple-a14'``.

    Raises:
        ValueError: if the clang output contains no ``apple-a<N>`` entry.
    """
    import re

    output = run_cmd('clang --print-supported-cpus')
    # Extract only the numeric suffix, so surrounding text on the same line
    # (whitespace, descriptions) cannot break the integer conversion.
    versions = [int(num) for num in re.findall(r'apple-a(\d+)', output)]
    if not versions:
        # Fail with a clear message instead of the opaque error max() raises
        # on an empty sequence.
        raise ValueError(
            "no 'apple-a<N>' CPU found in `clang --print-supported-cpus` output")
    return f'apple-a{max(versions)}'
示例#10
0
    data_gz_md5_path = os.path.join(cache_path, "data.md5")
    if os.path.isfile(data_gz_md5_path):
        with open(data_gz_md5_path, 'r') as f:
            target_md5 = f.read()
    data_o_path = os.path.join(cache_path, "data.o")
    if target_md5 != md5:
        data_s_path = os.path.join(cache_path, "data.cc")
        with open(data_s_path, "w") as f:
            f.write(data.decode("utf8"))
        dflags = (cc_flags+opt_flags)\
            .replace("-Wall", "") \
            .replace("-Werror", "") \
            .replace("-shared", "")
        vdp = os.path.join(jittor_path, "src", "utils", "vdp")
        run_cmd(
            fix_cl_flags(
                f"{cc_path} {dflags} -include \"{vdp}\" \"{data_s_path}\" -c -o \"{data_o_path}\""
            ))
        os.remove(data_s_path)
        with open(data_gz_md5_path, 'w') as f:
            f.write(md5)
    files.append(data_o_path)
    files = [f for f in files if "__data__" not in f]

cc_flags += f" -l\"jit_utils_core{lib_suffix}\" "
compile(cc_path, cc_flags + opt_flags, files, 'jittor_core' + extension_suffix)
cc_flags += f" -l\"jittor_core{lib_suffix}\" "

# TODO: move to compile_extern.py
# compile_extern()

with jit_utils.import_scope(import_flags):
示例#11
0
else:
    kernel_opt_flags = kernel_opt_flags + " -fopenmp "

if ' -O' not in cc_flags:
    opt_flags += " -O2 "
    kernel_opt_flags += " -Ofast "
lto_flags = ""
if os.environ.get("enable_lto") == "1":
    if cc_type == "icc":
        lto_flags = " -flto -ipo -ipo-c "
    elif cc_type == "g++":
        lto_flags = " -flto -fuse-linker-plugin "
    else:
        lto_flags = " -flto "

py_include = run_cmd(py3_config_path+" --includes")
LOG.i(f"py_include: {py_include}")
extension_suffix = run_cmd(py3_config_path+" --extension-suffix")
LOG.i(f"extension_suffix: {extension_suffix}")

make_cache_dir(cache_path)
make_cache_dir(os.path.join(cache_path, "jit"))
make_cache_dir(os.path.join(cache_path, "obj_files"))
make_cache_dir(os.path.join(cache_path, "gen"))
ck_path = os.path.join(cache_path, "checkpoints")
make_cache_dir(ck_path)

# build cache_compile
cc_flags += f" -I{jittor_path}/src "
cc_flags += py_include
check_cache_compile()
示例#12
0
    for d in dirs:
        fname = os.path.join(d, name)
        if os.path.isfile(fname):
            return fname
    LOG.f(f"file {name} not found in {dirs}")


if __name__ == "__main__":
    help_msg = f"Usage: {sys.executable} -m jittor_utils.config --include-flags|--link-flags|--cxx-flags|--cxx-example|--help"
    if len(sys.argv) <= 1:
        print(help_msg)
        sys.exit(1)

    s = ""
    # base should be something like python3.7m python3.8
    base = jittor_utils.run_cmd(jittor_utils.py3_config_path +
                                " --includes").split()[0]
    base = "python3" + base.split("python3")[-1]
    for arg in sys.argv[1:]:
        if arg == "--include-flags":
            s += jittor_utils.run_cmd(jittor_utils.py3_config_path +
                                      " --includes")
            s += " -I" + os.path.abspath(
                os.path.join(os.path.dirname(__file__), "..", "jittor", "src"))
            s += " "
        elif arg == "--libs-flags":
            libext = {
                'Linux': 'so',
                'Darwin': 'dylib',
                'Windows': 'DLL',
            }[platform.system()]
            ldflags = jittor_utils.run_cmd(jittor_utils.py3_config_path +
示例#13
0
    s = ""
    # base should be something like python3.7m python3.8
    base = jittor_utils.get_py3_include_path().split()[0]
    base = "python3" + base.split("python3")[-1]
    for arg in sys.argv[1:]:
        if arg == "--include-flags":
            s += jittor_utils.get_py3_include_path()
            s += " -I"+os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "jittor", "src"))
            s += " "
        elif arg == "--libs-flags":
            libext = {
                'Linux': 'so',
                'Darwin': 'dylib',
                'Windows': 'DLL',
            }[platform.system()]
            ldflags = jittor_utils.run_cmd(jittor_utils.get_py3_config_path() + " --ldflags")
            libpaths = [l[2:] for l in ldflags.split(' ') if l.startswith("-L")]
            for libbase in libpaths:
                libpath = os.path.join(libbase, f"lib{base}.{libext}")
                if os.path.isfile(libpath):
                    s += f" -L{libbase} -l{base} -ldl "
                    break
            else:
                raise RuntimeError("Python dynamic library not found")
            if os.name == 'nt':
                s = s.replace('-ldl', '')
        elif arg == "--cxx-flags":
            s += " --std=c++17 -fPIC "
        elif arg == "--cxx-example":
            cc_src = '''
// please compile with: g++ a.cc $(python3 -m jittor_utils.config --include-flags --libs-flags --cxx-flags) -o a.out && ./a.out
 def test_console(self):
     """Build and run the C++ example emitted by ``jittor_utils.config``.

     The example source is written to ``tmp.cc`` in the cache directory,
     compiled with the flags reported by ``jittor_utils.config`` and run;
     its output must mention ``jt.Var`` and ``pred.shape 2 1000``.
     """
     workdir = jt.flags.cache_path
     # Step 1: dump the example source next to the other cached files.
     run_cmd(f"{sys.executable} -m jittor_utils.config --cxx-example > tmp.cc", workdir)
     # Step 2: compile it with the advertised flags and run the binary.
     build_and_run = f"{jt.flags.cc_path} tmp.cc $({sys.executable} -m jittor_utils.config --include-flags --libs-flags --cxx-flags) -o tmp.out && ./tmp.out"
     output = run_cmd(build_and_run, workdir)
     print(output)
     for expected in ("jt.Var", "pred.shape 2 1000"):
         assert expected in output
     assert "pred.shape 2 1000" in s
示例#15
0
def setup_mpi():
    """Detect MPI, compile the bundled MPI ops and attach them to ``core.Var``.

    Sets the module globals ``use_mpi``, ``has_mpi``, ``mpicc_path``, ``mpi``
    and ``mpi_ops``.  MPI is used only when all of the following hold: the
    ``use_mpi`` environment variable is ``"1"`` (the default), ``mpicc`` is
    found, and ``inside_mpi()`` is true.  Otherwise the function returns early
    with ``mpi`` and ``mpi_ops`` left as ``None``.

    When MPI is used, the globals ``mpi_compile_flags``, ``mpi_link_flags``
    and ``mpi_flags`` are filled from ``mpicc --showme``, every file under
    ``<jittor_path>/extern/mpi`` is compiled with ``compile_custom_ops``, and
    each ``mpi_*`` op (except ``mpi_test``) is installed as a method on
    ``core.Var``.
    """
    global mpi_ops, mpi, use_mpi
    global mpicc_path, has_mpi
    global mpi_compile_flags, mpi_link_flags, mpi_flags
    use_mpi = os.environ.get("use_mpi", "1") == "1"
    mpi_ops = None
    mpi = None
    has_mpi = False
    mpicc_path = env_or_try_find('mpicc_path', 'mpicc')
    if mpicc_path == "":
        LOG.i("mpicc not found, distribution disabled.")
        use_mpi = False
    else:
        # Finding mpicc only records that MPI is available.  ``use_mpi``
        # keeps the value taken from the environment, so ``use_mpi=0`` really
        # disables MPI instead of being overwritten with True.
        has_mpi = True
    if not inside_mpi():
        use_mpi = False
    if not use_mpi:
        return

    mpi_compile_flags = run_cmd(mpicc_path + " --showme:compile")
    mpi_link_flags = run_cmd(mpicc_path + " --showme:link")
    # mpi_flags is fixed here, before mpi_compile_flags is adjusted below.
    mpi_flags = mpi_compile_flags + " " + mpi_link_flags
    LOG.v("mpi_flags: " + mpi_flags)

    # find all source files
    mpi_src_dir = os.path.join(jittor_path, "extern", "mpi")
    mpi_src_files = [
        os.path.join(root, fname)
        for root, _, fnames in os.walk(mpi_src_dir)
        for fname in fnames
    ]

    # mpi compile flags add for nccl
    mpi_compile_flags += f" -I'{os.path.join(mpi_src_dir, 'inc')}' "
    mpi_compile_flags = mpi_compile_flags.replace("-pthread", "")

    mpi_version = get_version(mpicc_path)
    if mpi_version.startswith(("(1.", "(2.")):
        # mpi version 1.x need to link like this
        manual_link(mpi_flags)
    # mpi(4.x) cannot use deepbind, it need to
    # share the 'environ' symbol.
    mpi = compile_custom_ops(mpi_src_files,
                             extra_flags=f" {mpi_flags} ",
                             return_module=True,
                             dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW)
    mpi_ops = mpi.ops
    LOG.vv("Get mpi: " + str(mpi.__dict__.keys()))
    LOG.vv("Get mpi_ops: " + str(mpi_ops.__dict__.keys()))

    def warper(func):
        # Wrap the op in a plain Python function (keeping its docstring)
        # before it is assigned as an attribute of core.Var.
        def inner(self, *args, **kw):
            return func(self, *args, **kw)

        inner.__doc__ = func.__doc__
        return inner

    for k, op in mpi_ops.__dict__.items():
        if not k.startswith("mpi_"):
            continue
        if k == "mpi_test":
            continue
        setattr(core.Var, k, warper(op))