コード例 #1
0
ファイル: compiler.py プロジェクト: Exusial/jittor
def check_cache_compile():
    """Build the ``jit_utils_core`` extension and ensure it is loaded.

    The core utility sources are compiled into
    ``jit_utils.cache_path + '/jit_utils_core' + extension_suffix``.
    If that build reports a rebuild while a core is already loaded
    (``jit_utils.cc`` is truthy), an error is logged and the process exits
    via ``sys.exit(0)``, asking the user to rerun. If no core is loaded yet,
    it is imported under ``import_flags`` and the same build is run once
    more (per the original comment, to generate the cache key).

    Side effect: stores the source list in the module-level
    ``jit_utils_core_files``.
    """
    global jit_utils_core_files
    core_sources = [
        "src/utils/cache_compile.cc",
        "src/utils/log.cc",
        "src/utils/tracer.cc",
        "src/utils/jit_utils.cc",
        "src/utils/str_utils.cc",
    ]
    if os.name == 'nt':
        # On Windows, switch the source paths to backslash separators.
        core_sources = [src.replace('/', '\\') for src in core_sources]
    jit_utils_core_files = core_sources

    def build_core():
        # Every argument is evaluated at call time, matching the two inline
        # compile() calls this helper replaces.
        return compile(
            cc_path, cc_flags + f" {opt_flags} ", core_sources,
            jit_utils.cache_path + '/jit_utils_core' + extension_suffix,
            True)

    rebuilt = build_core()
    if rebuilt and jit_utils.cc:
        LOG.e("jit_utils updated, please rerun your command.")
        sys.exit(0)
    if jit_utils.cc:
        return
    with jit_utils.import_scope(import_flags):
        jit_utils.try_import_jit_utils_core()
    assert jit_utils.cc
    # recompile, generate cache key
    build_core()
コード例 #2
0
ファイル: compiler.py プロジェクト: zhengjn/jittor
def compile_custom_ops(filenames, extra_flags=""):
    """Compile custom ops into an extension module and return its ops.

    filenames: paths of op source files. Op files must come in pairs
        ``xxx_xxx_op.cc`` / ``xxx_xxx_op.h``, and the type name of each op
        must be ``XxxXxxOp``. Every ``.cc``/``.cpp``/``.cu`` file in the
        list is compiled into the library.
    extra_flags: extra compile flags
    return: the ``ops`` attribute of the compiled module

    Raises AssertionError when the ``_op.cc`` sources and ``_op.h``
    headers do not pair up.
    """
    import hashlib
    import importlib

    srcs = {}
    headers = {}
    builds = []
    includes = []
    for name in filenames:
        name = os.path.realpath(name)
        if name.endswith((".cc", ".cpp", ".cu")):
            builds.append(name)
        bname = os.path.splitext(os.path.basename(name))[0]
        if not bname.endswith("_op"):
            continue
        bname = bname[:-3]
        if name.endswith(".cc"):
            srcs[bname] = name
        elif name.endswith(".h"):
            includes.append(os.path.dirname(name))
            headers[bname] = name
    assert len(srcs) == len(headers), "Source and header names not match"
    for name in srcs:
        assert name in headers, f"Header of op {name} not found"
    gen_name = "gen_ops_" + "_".join(headers.keys())
    if len(gen_name) > 100:
        # Use a content digest rather than built-in hash(): str hashing is
        # randomized per process and may be negative. That produced a new
        # module name on every run (defeating the cache), and a "-" in the
        # name made it un-importable.
        gen_name = gen_name[:80] + "___hash" + hashlib.md5(
            gen_name.encode()).hexdigest()

    # De-duplicate include dirs while keeping first-seen order. A set's
    # iteration order for strings changes between runs, which made the
    # flag string (and thus the compile command) nondeterministic.
    includes = "".join(f" -I'{x}' " for x in dict.fromkeys(includes))
    LOG.vvvv(f"Include flags:{includes}")

    op_extra_flags = includes + extra_flags

    gen_src = gen_jit_op_maker(headers.values(), export=gen_name, extra_flags=op_extra_flags)
    make_cache_dir(os.path.join(cache_path, "custom_ops"))
    gen_src_fname = os.path.join(cache_path, "custom_ops", gen_name+".cc")
    gen_head_fname = os.path.join(cache_path, "custom_ops", gen_name+".h")
    gen_lib = os.path.join("custom_ops", gen_name+extension_suffix)
    with open(gen_head_fname, "w") as f:
        f.write(gen_src)
    pyjt_compiler.compile_single(gen_head_fname, gen_src_fname)
    # gen src initialize first: the generated source is placed first in the build list
    builds.insert(0, gen_src_fname)
    LOG.vvv(f"Build custum ops lib:{gen_lib}")
    LOG.vvvv(f"Build sources:{builds}")
    compile(cc_path, cc_flags+opt_flags+includes+extra_flags, builds, gen_lib)

    # add python path and import
    LOG.vvv(f"Import custum ops lib:{gen_lib}")
    lib_path = os.path.join(cache_path, "custom_ops")
    if lib_path not in os.sys.path:
        os.sys.path.append(lib_path)
    with jit_utils.import_scope(os.RTLD_GLOBAL | os.RTLD_NOW | os.RTLD_DEEPBIND):
        # importlib instead of exec("import ...") + locals(): exec cannot
        # reliably bind names into a function's locals (on Python 3.13,
        # locals() returns a fresh snapshot each call, per PEP 667).
        mod = importlib.import_module(gen_name)
    return mod.ops
コード例 #3
0
def check_cache_compile():
    """Build ``jit_utils_core`` and make sure the core extension is loaded.

    The core utility sources are compiled to
    ``'jit_utils_core' + extension_suffix``. When that build reports a
    rebuild while a core is already loaded (``jit_utils.cc`` is truthy),
    an error is logged and the process exits via ``sys.exit(0)``. When no
    core is loaded, it is imported under ``import_flags`` and the build is
    run again (per the original comment, to generate the cache key).

    Side effect: stores the source list in the module-level
    ``jit_utils_core_files``.
    """
    global jit_utils_core_files
    jit_utils_core_files = core_sources = [
        "src/utils/cache_compile.cc",
        "src/utils/log.cc",
        "src/utils/tracer.cc",
        "src/utils/jit_utils.cc",
    ]

    def build_core():
        # Flags are recomputed on each call, as in the original inline calls.
        return compile(cc_path, cc_flags + f" {opt_flags} ", core_sources,
                       'jit_utils_core' + extension_suffix, True)

    if build_core() and jit_utils.cc:
        LOG.e("jit_utils updated, please restart jittor.")
        sys.exit(0)
    if jit_utils.cc:
        return
    with jit_utils.import_scope(import_flags):
        jit_utils.try_import_jit_utils_core()
    assert jit_utils.cc
    # recompile, generate cache key
    build_core()
コード例 #4
0
def compile_custom_ops(filenames,
                       extra_flags="",
                       return_module=False,
                       dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW
                       | os.RTLD_DEEPBIND,
                       gen_name_=""):
    """Compile custom ops into an extension module and import it.

    filenames: paths of op source files. Op files must come in pairs
        ``xxx_xxx_op.cc`` / ``xxx_xxx_op.h``, and the type name of each op
        must be ``XxxXxxOp``. Every ``.cc``/``.cpp``/``.cu`` file is
        compiled into the library. Headers whose text contains ``@pyjt``
        are additionally run through ``pyjt_compiler.compile_single``, and
        their ``pyjt_def_<name>`` hook is called from the module init.
        Header directories ending in ``inc`` are added to the include path.
    extra_flags: extra compile flags
    return_module: return module rather than ops (default: False)
    dlopen_flags: flags passed to ``jit_utils.import_scope`` while the
        compiled library is imported
    gen_name_: explicit module name; when non-empty it replaces the name
        derived from the op headers
    return: compiled ops, or the module itself when ``return_module``

    Raises AssertionError when the ``_op.cc`` sources and ``_op.h``
    headers do not pair up.
    """
    import hashlib
    import importlib

    srcs = {}
    headers = {}
    builds = []
    includes = []
    pyjt_includes = []
    for name in filenames:
        name = os.path.realpath(name)
        if name.endswith((".cc", ".cpp", ".cu")):
            builds.append(name)
        if name.endswith(".h"):
            dirname = os.path.dirname(name)
            if dirname.endswith("inc"):
                includes.append(dirname)
            with open(name, "r") as f:
                if "@pyjt" in f.read():
                    pyjt_includes.append(name)
        bname = os.path.splitext(os.path.basename(name))[0]
        if bname.endswith("_op"):
            bname = bname[:-3]
            if name.endswith(".cc"):
                srcs[bname] = name
            elif name.endswith(".h"):
                includes.append(os.path.dirname(name))
                headers[bname] = name
    assert len(srcs) == len(headers), "Source and header names not match"
    for name in srcs:
        assert name in headers, f"Header of op {name} not found"
    gen_name = "gen_ops_" + "_".join(headers.keys())
    if gen_name_ != "":
        gen_name = gen_name_
    if len(gen_name) > 100:
        # Use a content digest rather than built-in hash(): str hashing is
        # randomized per process and may be negative. That produced a new
        # module name on every run (defeating the cache), and a "-" in the
        # name made it un-importable.
        gen_name = gen_name[:80] + "___hash" + hashlib.md5(
            gen_name.encode()).hexdigest()

    # De-duplicate include dirs while keeping first-seen order. A set's
    # iteration order for strings changes between runs, which made the
    # flag string (and thus the compile command) nondeterministic.
    includes = "".join(f" -I'{x}' " for x in dict.fromkeys(includes))
    LOG.vvvv(f"Include flags:{includes}")

    op_extra_flags = includes + extra_flags

    gen_src = gen_jit_op_maker(headers.values(),
                               export=gen_name,
                               extra_flags=op_extra_flags)
    make_cache_dir(os.path.join(cache_path, "custom_ops"))
    gen_src_fname = os.path.join(cache_path, "custom_ops", gen_name + ".cc")
    gen_head_fname = os.path.join(cache_path, "custom_ops", gen_name + ".h")
    gen_lib = os.path.join("custom_ops", gen_name + extension_suffix)
    pyjt_compiler.compile_single(gen_head_fname, gen_src_fname, src=gen_src)
    # gen src initialize first: the generated source is placed first in the build list
    builds.insert(0, gen_src_fname)

    def insert_anchor(gen_src, anchor_str, insert_str):
        # insert insert_str after the first occurrence of anchor_str in gen_src
        return gen_src.replace(anchor_str, anchor_str + insert_str, 1)

    for name in pyjt_includes:
        LOG.i("handle pyjt_include", name)
        # os.path.basename instead of split("/"): realpath'd names use the
        # platform separator, which is not "/" on Windows.
        bname = os.path.basename(name).split(".")[0]
        pyjt_src_fname = os.path.join(cache_path, "custom_ops",
                                      gen_name + "_" + bname + ".cc")
        pyjt_compiler.compile_single(name, pyjt_src_fname)
        builds.insert(1, pyjt_src_fname)
        # Declare this header's binding hook and call it from the module init.
        gen_src = insert_anchor(gen_src, "namespace jittor {",
                                f"extern void pyjt_def_{bname}(PyObject* m);")
        gen_src = insert_anchor(
            gen_src, "init_module(PyModuleDef* mdef, PyObject* m) {",
            f"jittor::pyjt_def_{bname}(m);")

    with open(gen_head_fname, "w") as f:
        f.write(gen_src)

    LOG.vvv(f"Build custum ops lib:{gen_lib}")
    LOG.vvvv(f"Build sources:{builds}")
    compile(cc_path, extra_flags + cc_flags + opt_flags + includes, builds,
            gen_lib)

    # add python path and import
    LOG.vvv(f"Import custum ops lib:{gen_lib}")
    lib_path = os.path.join(cache_path, "custom_ops")
    if lib_path not in os.sys.path:
        os.sys.path.append(lib_path)
    # unlock scope when initialize
    with lock.unlock_scope():
        with jit_utils.import_scope(dlopen_flags):
            # importlib instead of exec("import ...") + locals(): exec cannot
            # reliably bind names into a function's locals (on Python 3.13,
            # locals() returns a fresh snapshot each call, per PEP 667).
            mod = importlib.import_module(gen_name)
    if return_module:
        return mod
    return mod.ops
コード例 #5
0
    is_debug = 0
    if os.environ.get("debug") == "1":
        is_debug = 1
        global cc_flags
        cc_flags += " -g -DNODE_MEMCHECK "


# Base compiler flags. NOTE(review): presumably extended later by
# check_debug_flags() (called below) when debugging is enabled — confirm.
cc_flags = " "
# os.RTLD_NOW | os.RTLD_GLOBAL alone causes a segfault when torch is imported
# first, so RTLD_DEEPBIND is added to the dlopen flags below.
import_flags = os.RTLD_NOW | os.RTLD_GLOBAL | os.RTLD_DEEPBIND
# if cc_type=="icc":
#     # weird link problem, icc omp library may conflict and cause segfault
#     import_flags = os.RTLD_NOW | os.RTLD_GLOBAL
# Same flag combination as import_flags, kept under a separate name.
# NOTE(review): its consumers are outside this view.
dlopen_flags = os.RTLD_NOW | os.RTLD_GLOBAL | os.RTLD_DEEPBIND

# First attempt to load the jit_utils core extension.
with jit_utils.import_scope(import_flags):
    jit_utils.try_import_jit_utils_core()

jittor_path = find_jittor_path()
check_debug_flags()

# Make modules placed in cache_path importable.
sys.path.append(cache_path)

# Retry the core import now that cache_path is on sys.path.
# NOTE(review): presumably so a core built into cache_path can be found — confirm.
with jit_utils.import_scope(import_flags):
    jit_utils.try_import_jit_utils_core()

python_path = sys.executable
py3_config_paths = [
    sys.executable + "-config",
    os.path.dirname(sys.executable) +
    f"/python3.{sys.version_info.minor}-config",
コード例 #6
0
import os
from pathlib import Path
from collections import defaultdict
import pickle
import numpy as np
import jittor_utils
from jittor_utils import LOG
import sys

# Try to load the jit_utils core extension at import time.
# NOTE(review): import_scope presumably applies these dlopen flags
# (global symbols, immediate binding) during the import — confirm.
with jittor_utils.import_scope(os.RTLD_GLOBAL | os.RTLD_NOW):
    jittor_utils.try_import_jit_utils_core()

# Module-level error flag, initialised to 0.
# NOTE(review): its readers/writers are outside this view.
has_error = 0


def convert(data):
    """Recursively convert tensor-like values to NumPy arrays.

    - Objects with a ``numpy`` attribute whose class name contains
      ``"Var"`` are converted with ``data.numpy()``.
    - Other objects with a ``numpy`` attribute are converted with
      ``data.detach().cpu().numpy()`` (presumably torch-style tensors).
    - Tuples, lists and dicts are rebuilt with every element/value
      converted; tuples stay tuples, lists stay lists, dicts become plain
      dicts with the same keys.
    - Anything else, including ``np.ndarray``, is returned unchanged.
    """
    if hasattr(data, "numpy"):
        if "Var" not in data.__class__.__name__:
            return data.detach().cpu().numpy()
        return data.numpy()
    if isinstance(data, tuple):
        return tuple(map(convert, data))
    if isinstance(data, list):
        return list(map(convert, data))
    if isinstance(data, dict):
        return {key: convert(data[key]) for key in data}
    # Leaves (NumPy arrays, scalars, strings, None, ...) pass through.
    return data