Example #1
0
def compile(cache_path, jittor_path):
    """Generate pyjt binding sources for every header plus an aggregator.

    Collects ``*.h`` files under ``jittor_path/src`` (``find -L`` follows
    symlinks) and ``cache_path/gen``, runs ``compile_single`` on each one to
    produce ``cache_path/gen/pyjt_<basename>.cc``, and finally writes
    ``cache_path/gen/pyjt_all.cc`` whose ``pyjt_def_all`` calls every
    generated ``pyjt_def_<basename>``.

    Args:
        cache_path: cache root that contains the ``gen/`` directory.
        jittor_path: jittor source root that contains ``src/``.

    Returns:
        List of generated ``.cc`` paths; ``pyjt_all.cc`` is the last entry.
    """
    # Match a literal ".h" suffix. An unescaped "." would accept any path
    # ending in "h" (e.g. "foo.sh"), and since find also lists directories,
    # such an entry could make open() below fail.
    headers1 = run_cmd(r'find -L src/ | grep "\.h$"', jittor_path).splitlines()
    headers2 = run_cmd(r'find gen/ | grep "\.h$"', cache_path).splitlines()
    headers = [ os.path.join(jittor_path, h) for h in headers1 ] + \
        [ os.path.join(cache_path, h) for h in headers2 ]
    basenames = []
    pyjt_names = []
    for h in headers:
        # var_holder.h is merged into jit_op_maker.h below, so it is never
        # compiled on its own; skip it before paying for a discarded read.
        if h.endswith("src/var_holder.h"):
            continue
        with open(h, 'r') as f:
            src = f.read()

        # jit_op_maker.h merge compile with var_holder.h
        if h.endswith("jit_op_maker.h"):
            with open(os.path.join(jittor_path, "src", "var_holder.h"),
                      "r") as f:
                src = f.read() + src
        basename = h.split("/")[-1].split(".")[0]
        fname = "pyjt_" + basename + ".cc"
        fname = os.path.join(cache_path, "gen", fname)
        check = compile_single(h, fname, src)

        # compile_single returns a falsy value when nothing was generated
        if not check: continue

        basenames.append(basename)
        pyjt_names.append(fname)

    code = f"""
    #include "pyjt/numpy.h"
    #include "pyjt/py_converter.h"
    #include "common.h"

    namespace jittor {{

    { " ".join([f"extern void pyjt_def_{n}(PyObject* m);" for n in basenames])}

    void pyjt_def_all(PyObject* m) {{
        numpy_init();
        { " ".join([f"pyjt_def_{n}(m);" for n in basenames])}
    }}

    }}
    """
    fname = os.path.join(cache_path, "gen", "pyjt_all.cc")
    LOG.vvv(("write to", fname))
    LOG.vvvv(code)
    with open(fname, "w") as f:
        f.write(code)
    pyjt_names.append(fname)
    return pyjt_names
Example #2
0
def generate_error_code_from_func_header(func_head, target_scope_name, name,
                                         dfs, basename, h, class_info):
    """Build the C++ ``<<`` stream tail reported when a binding call fails.

    ``func_head`` is the generated lambda signature, for example
    ``(PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject*``.
    Each ``PyObject*`` / ``PyObject**`` parameter it lists gets a printer
    expression so the error shows the types the caller actually passed.
    The message also points the user at ``help(jt.<help_name>)``.

    ``dfs``, ``basename`` and ``class_info`` are accepted but not used here.
    """
    lib_name = os.path.basename(h).split("_")[0]
    # TODO: fix/add var help
    scope = None if target_scope_name == "Var" else target_scope_name
    if not scope:
        help_name = name
    elif scope == "flags":
        help_name = "flags"
    else:
        help_name = "" + scope + '.' + name
    # ops from optional backend libraries live under jt.<lib>.
    if lib_name in ("mpi", "nccl", "cudnn", "curand", "cublas", "mkl"):
        help_name = lib_name + '.' + help_name
    help_cmd = f"help(jt.{help_name})"

    LOG.vvv("gen err from func_head", func_head)
    pieces = [
        f" << \"Wrong inputs arguments, Please refer to examples(e.g. {help_cmd}).\"",
        r' << "\n\nTypes of your inputs are:\n"',
    ]
    # parameter list = text between the leading '(' and the first ')'
    for raw_arg in func_head[1:].split(")")[0].split(","):
        arg = raw_arg.strip()
        if arg.startswith("PyObject* "):
            _, var = arg.split(' ')
            if var in ("args", "_args"):
                pieces.append(f" << PyTupleArgPrinter{{{var}, \"args\"}} ")
            elif var == "kw":
                pieces.append(f" << PyKwArgPrinter{{{var}}} ")
            else:
                pieces.append(f" << PyArgPrinter{{{var}, \"{var}\"}} ")
        elif arg.startswith("PyObject** "):
            _, var = arg.split(' ')
            # the fast-call printer is handed n and kw itself, so the
            # remaining parameters need no printers of their own
            pieces.append(f" << PyFastCallArgPrinter{{{var}, n, kw}} ")
            break
        else:
            LOG.vvv("Unhandled arg", arg)
    error_code = "".join(pieces)
    LOG.vvv("gen err from func_head", func_head, " -> ", error_code)
    return error_code
Example #3
0
def compile_custom_ops(filenames,
                       extra_flags="",
                       return_module=False,
                       dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW
                       | os.RTLD_DEEPBIND,
                       gen_name_=""):
    """Compile custom ops

    filenames: path of op source files, filenames must be
        pairs of xxx_xxx_op.cc and xxx_xxx_op.h, and the
        type name of op must be XxxXxxOp. Any other .cc/.cpp/.cu file
        is compiled into the library too, and .h files containing
        "@pyjt" get Python bindings generated for them.
    extra_flags: extra compile flags
    return_module: return module rather than ops(default: False)
    dlopen_flags: flags handed to jit_utils.import_scope while the built
        library is imported
    gen_name_: explicit name for the generated module; when empty the
        name is derived from the op names
    return: compiled ops (the whole module if return_module is True)
    """
    import hashlib
    import importlib

    srcs = {}
    headers = {}
    builds = []
    includes = []
    pyjt_includes = []
    for name in filenames:
        name = os.path.realpath(name)
        if name.endswith((".cc", ".cpp", ".cu")):
            builds.append(name)
        if name.endswith(".h"):
            dirname = os.path.dirname(name)
            # headers kept in an "...inc" directory put it on the include path
            if dirname.endswith("inc"):
                includes.append(dirname)
            with open(name, "r") as f:
                if "@pyjt" in f.read():
                    pyjt_includes.append(name)
        bname = os.path.splitext(os.path.basename(name))[0]
        if bname.endswith("_op"):
            bname = bname[:-3]
            if name.endswith(".cc"):
                srcs[bname] = name
            elif name.endswith(".h"):
                includes.append(os.path.dirname(name))
                headers[bname] = name
    assert len(srcs) == len(headers), "Source and header names not match"
    for name in srcs:
        assert name in headers, f"Header of op {name} not found"
    gen_name = "gen_ops_" + "_".join(headers.keys())
    if gen_name_ != "":
        gen_name = gen_name_
    if len(gen_name) > 100:
        # Shorten with a stable digest. Builtin hash() of a str is salted per
        # process (PYTHONHASHSEED), so the name changed on every run, and it
        # can be negative, putting '-' into a name that is imported below.
        gen_name = gen_name[:80] + "___hash" + \
            hashlib.md5(gen_name.encode()).hexdigest()

    # dict.fromkeys de-duplicates while keeping first-seen order; iterating a
    # set of str varies between processes, which made the flags (and so the
    # compile command) differ from run to run.
    includes = "".join(f" -I'{x}' " for x in dict.fromkeys(includes))
    LOG.vvvv(f"Include flags:{includes}")

    op_extra_flags = includes + extra_flags

    gen_src = gen_jit_op_maker(headers.values(),
                               export=gen_name,
                               extra_flags=op_extra_flags)
    make_cache_dir(os.path.join(cache_path, "custom_ops"))
    gen_src_fname = os.path.join(cache_path, "custom_ops", gen_name + ".cc")
    gen_head_fname = os.path.join(cache_path, "custom_ops", gen_name + ".h")
    gen_lib = os.path.join("custom_ops", gen_name + extension_suffix)
    pyjt_compiler.compile_single(gen_head_fname, gen_src_fname, src=gen_src)
    # gen src initialize first
    builds.insert(0, gen_src_fname)

    def insert_anchor(gen_src, anchor_str, insert_str):
        # insert insert_str right after the first occurrence of anchor_str
        return gen_src.replace(anchor_str, anchor_str + insert_str, 1)

    for name in pyjt_includes:
        LOG.i("handle pyjt_include", name)
        bname = name.split("/")[-1].split(".")[0]
        gen_src_fname = os.path.join(cache_path, "custom_ops",
                                     gen_name + "_" + bname + ".cc")
        pyjt_compiler.compile_single(name, gen_src_fname)
        # keep the main generated source at index 0
        builds.insert(1, gen_src_fname)
        # declare pyjt_def_<bname> and call it from init_module
        gen_src = insert_anchor(gen_src, "namespace jittor {",
                                f"extern void pyjt_def_{bname}(PyObject* m);")
        gen_src = insert_anchor(
            gen_src, "init_module(PyModuleDef* mdef, PyObject* m) {",
            f"jittor::pyjt_def_{bname}(m);")

    with open(gen_head_fname, "w") as f:
        f.write(gen_src)

    LOG.vvv(f"Build custum ops lib:{gen_lib}")
    LOG.vvvv(f"Build sources:{builds}")
    compile(cc_path, extra_flags + cc_flags + opt_flags + includes, builds,
            gen_lib)

    # add python path and import
    LOG.vvv(f"Import custum ops lib:{gen_lib}")
    lib_path = os.path.join(cache_path, "custom_ops")
    if lib_path not in os.sys.path:
        os.sys.path.append(lib_path)
    # unlock scope when initialize
    with lock.unlock_scope():
        with jit_utils.import_scope(dlopen_flags):
            # import_module hands back the module object directly; fetching it
            # from locals() after exec("import ...") fails on Python 3.13+,
            # where a function's locals() is a fresh snapshot (PEP 667).
            mod = importlib.import_module(gen_name)
    if return_module:
        return mod
    return mod.ops
Example #4
0
def compile_src(src, h, basename):
    """Generate pyjt binding C++ source for one annotated header.

    Every match of the module-level ``reg`` pattern marks an ``@pyjt(...)``
    annotation; groups 1, 3 and 5 supply the doc text, the python names and
    the ``@attrs(...)`` text. The declaration following each annotation is
    classified by the first ``(``, ``{`` or ``;`` after it:

    * ``{`` -> a class scope, or a submodule scope when ``submodule`` is in
      its attrs;
    * ``;`` -> a property of the enclosing class;
    * ``(`` -> a function/method (free, class member or submodule member).

    The collected definitions are rendered into a C++ translation unit that
    defines ``pyjt_def_<basename>(PyObject* m)``.

    Args:
        src: header text to scan.
        h: header path, emitted verbatim in the generated ``#include``.
        basename: suffix of the generated ``pyjt_def_<basename>`` function.

    Returns:
        The generated C++ source string, or ``None`` when ``src`` contains
        no annotation or yields no bindable definition.
    """
    # Scan from offset 0. On a compiled pattern (which reg presumably is) the
    # second positional argument of finditer is ``pos``, not flags; passing
    # re.S (== 16) there silently skipped the first 16 characters. Flags
    # belong in re.compile().
    res = list(reg.finditer(src))
    if len(res)==0: return
    class_ranges = None
    class_name = None
    class_info = None
    submodule_name = None
    submodule_ranges = None
    submodule_info = None
    defs = []
    LOG.vv(("find in", h))

    def esplit(s):
        # "a, b, ,c" -> ["a", "b", "c"]; None -> []
        if s is None:
            return []
        return [ a.strip() for a in s.split(",") if len(a.strip()) ]

    def find_bc(i):
        # Return (a, b): a is the first '(', '{' or ';' at/after i. For ';'
        # b is a+1; otherwise b is the index of the bracket closing a, with
        # all of ({[ / )}] counted for nesting.
        while src[i] not in "({;":
            i += 1
        j = i+1
        if src[i]==';':
            return i, j
        presum = 1
        while True:
            if src[j] in "({[":
                presum += 1
            elif src[j] in ")}]":
                presum -= 1
                if presum==0:
                    s = src[i]+src[j]
                    assert s in ("()", "{}"), "braces not match "+s
                    return i, j
            j += 1

    for x in res:
        LOG.vvv((x, x.groups()))
        g = x.groups()
        doc = g[1]
        pyjt = g[3]
        attrs = g[5]
        attrs = parse_attrs(attrs)
        pynames = esplit(pyjt)
        end = x.end()
        # // @pyjt(DType)
        # struct DType {
        #              ^ --> a
        #     .....
        # } <--- b
        # or
        # // @pyjt(hash)
        # inline uint hash(const char* input)
        #                 ^ --> a           ^ --> b
        a, b = find_bc(end)
        is_property = 0
        if src[a] == ';':
            # This case
            # class XXX {
            #     // @pyjt(property)
            #     T property;
            # }
            is_property = 1
        if src[a] == '{':
            assert len(pynames)==1
            # The scope's C++ name is the last token before '{'. Slice up to
            # a (not a-1) so "struct Name{" without a space keeps its last char.
            if "submodule" in attrs:
                assert submodule_ranges is None
                submodule_ranges = (a, b)
                submodule_name = src[end:a].strip().split()[-1]
                submodule_info = {
                    "pynames": pynames,
                    "attrs": attrs
                }
                continue
            assert class_ranges is None
            class_ranges = (a, b)
            class_name = src[end:a].strip().split()[-1]
            class_info = {
                "pynames": pynames,
                "attrs": attrs
            }
            continue
        is_scope_def = False
        is_static = False
        scope_name = ""
        # a declaration located inside a recorded class/submodule body
        # belongs to that scope; submodule members are always static
        if class_ranges is not None:
            if class_ranges[0] < a < class_ranges[1]:
                is_scope_def = True
                scope_name = class_name
        if submodule_ranges is not None:
            if submodule_ranges[0] < a < submodule_ranges[1]:
                is_scope_def = True
                scope_name = submodule_name
                is_static = True
        dec = src[end:b+1].strip()
        arr = src[end:a].strip().split()
        func_name = arr[-1]

        is_constructor = False
        if is_scope_def and func_name==class_name:
            is_constructor = True

        args = []
        for arg in split_args(src[a+1:b]):
            if arg=="": continue
            default = ""
            if "=" in arg:
                # split on the first '=' only, so a default value may
                # itself contain '='
                arg, default = arg.split('=', 1)
            arg = arg.strip()
            name = arg.split(' ')[-1]
            tp = arg[:-len(name)]
            tp = tp.strip()
            prev_tp = tp
            # const string& ----> string
            if tp.startswith("const") and tp.endswith("&"):
                tp = tp[5:-1].strip()
            # T&& -> T
            if tp.endswith("&&"):
                tp = tp[:-2].strip()
            # ArrayArgs& -> ArrayArgs
            if tp.endswith("&"):
                tp = tp[:-1].strip()
            args.append((tp, name.strip(), default.strip(), prev_tp))
        # everything before the name, minus specifiers, is the return type
        return_t = ""
        for tok in arr[:-1]:
            if tok in ["", "inline", "constexpr"]: continue
            if tok == "static":
                is_static = True
                continue
            if return_t != "": return_t += " "
            return_t += tok

        if is_scope_def and class_info and "submodule" in class_info["attrs"]:
            is_static = True

        # rich comparisons share the single tp_richcompare slot; remember
        # which operator each one implements in attrs
        for pid, pyname in enumerate(pynames):
            for rname in [ "__lt__", "__le__", "__gt__",
                "__ge__", "__eq__", "__ne__"]:
                if pyname.endswith(rname):
                    attrs[rname] = 1
                    pynames[pid] = pyname.replace(rname, "__richcmp__")

        def_info = {
            "is_scope_def": is_scope_def,
            "is_constructor": is_constructor,
            "is_static": is_static,
            "is_property": is_property,
            "func_name": func_name,
            "args": args, # [(type,name,defaut), ...]
            "return_t": return_t, # return type
            "dec": dec, # full string of xxx(A a, B b)
            "pynames": pynames, # names in @pyjt(...)
            "attrs": attrs, # attrs in @attrs(...)
            "doc": doc,
            "scope_name": scope_name,
        }
        if is_property:
            # This case
            # class XXX {
            #     // @pyjt(property)
            #     T property;
            # }
            # A property becomes a getter def and a one-argument setter def.
            assert is_scope_def and not is_static
            def_info["is_property"] = 1
            def_info["pynames"] = ["__get__"+n for n in pynames]
            assert return_t != "void"
            defs.append(dict(def_info))
            def_info["pynames"] = ["__set__"+n for n in pynames]
            assert len(args) == 0
            def_info["args"] = [(def_info["return_t"], func_name, "", "")]
            def_info["return_t"] = "void"
            defs.append(dict(def_info))
            continue
        else:
            defs.append(def_info)
        LOG.vvv(json.dumps(def_info, indent=4))
    # deal with defs
    if len(defs) == 0: return
    # include_name = h[4:] # remove "src/" prefix
    include_name = h
    code = []
    class_defs_code = []
    class_getsets_code = []
    class_gets = OrderedDict()
    class_sets = OrderedDict()
    class_slots_code = []
    submodule_defs_code = []
    # group overloads by their fully qualified python name ("Scope.name")
    def_targets = OrderedDict()
    for df in defs:
        for name in df["pynames"]:
            if df["is_scope_def"] and '.' not in name:
                if df["scope_name"] == class_name:
                    name = class_info["pynames"][0] + '.' + name
                else:
                    name = submodule_info["pynames"][0] + '.' + name
            if name not in def_targets:
                def_targets[name] = []
            def_targets[name].append(df)
    for name in def_targets:
        dfs = def_targets[name]
        target_scope_name = None
        LOG.vv(name)
        if "." in name:
            target_scope_name, name = name.split(".")
        # array for each df:
        arr_func_quick_check_runable = []
        arr_func_args_convert = []
        arr_fill_with_default = []
        arr_func_call = []
        arr_has_return = []
        self_as_arg0 = False
        for df in dfs:
            self_as_arg0 = class_info and \
                target_scope_name == class_info["pynames"][0] and \
                df["scope_name"] == submodule_name \
                and not name.startswith("__")
            res = get_def_code(df, df["scope_name"], name, bool(self_as_arg0))
            arr_func_quick_check_runable.append(res[0])
            arr_func_args_convert.append(res[1])
            arr_fill_with_default.append(res[2])
            arr_func_call.append(res[3])
            arr_has_return.append(res[4])

        # Dunder names map onto type slots with fixed C signatures; each
        # func_fill defines `n`/`args` the way the per-overload code expects.
        slot_name = None
        func_cast = ""
        func_fill = ""
        if name == "__init__":
            slot_name = "tp_init"
            func_head = "(PyObject* self, PyObject* _args, PyObject* kw) -> int"
            func_fill = """
                int64 n = Py_SIZE(_args);
                auto args = (PyObject**)&PyTuple_GET_ITEM(_args, 0);
                (void)n, (void)args;
                // TODO: support kw
                CHECK(kw==0);
            """

        elif name == "__repr__":
            slot_name = "tp_repr"
            func_head = "(PyObject* self) -> PyObject*"
            func_fill = "int64 n = 0; (void)n;"

        elif name.startswith("__get__"):
            slot_name = "tp_gets"
            name = name[len("__get__"):]
            func_head = "(PyObject* self, void*) -> PyObject*"
            func_fill = "int64 n = 0; (void)n;"

        elif name.startswith("__set__"):
            slot_name = "tp_sets"
            name = name[len("__set__"):]
            func_head = "(PyObject* self, PyObject* arg, void*) -> int"
            func_fill = """
                int64 n=1;
                PyObject** args = &arg;
                (void)n, (void)args;
            """

        elif name == "__call__":
            slot_name = "tp_call"
            func_head = "(PyObject* self, PyObject* _args, PyObject* kw) -> PyObject*"
            func_fill = """
                int64 n = Py_SIZE(_args);
                auto args = (PyObject**)&PyTuple_GET_ITEM(_args, 0);
                (void)n, (void)args;
                // TODO: support kw
                CHECK(kw==0);
            """

        elif name == "__dealloc__":
            slot_name = "tp_dealloc"
            func_head = "(PyObject* self) -> void"
            func_fill = "int64 n = 0"

        elif name in binary_number_slots:
            slot_name = "tp_as_number->"+binary_number_slots[name]
            func_head = "(PyObject* self, PyObject* b) -> PyObject*"
            if name.endswith("pow__"):
                func_head = "(PyObject* self, PyObject* b, PyObject*) -> PyObject*"
            func_fill = """
                int64 n = 2;
                PyObject* args[] = {self, b};
                (void)n, (void)args;
            """

        elif name in unary_number_slots:
            slot_name = "tp_as_number->"+unary_number_slots[name]
            func_head = "(PyObject* self) -> PyObject*"
            func_fill = """
                int64 n = 1;
                PyObject* args[] = {self};
                (void)n, (void)args;
            """

        elif name == "__richcmp__":
            slot_name = "tp_richcompare"
            func_head = "(PyObject* self, PyObject* b, int op) -> PyObject*"
            func_fill = """
                int64 n = 2;
                PyObject* args[] = {self, b};
                (void)n, (void)args;
            """

        elif name == "__len__":
            slot_name = "tp_as_sequence->sq_length"
            func_head = "(PyObject* self) -> Py_ssize_t"
            func_fill = """
                int64 n = 0;  
                (void)n;
            """

        elif name == "__map_len__":
            slot_name = "tp_as_mapping->mp_length"
            func_head = "(PyObject* self) -> Py_ssize_t"
            func_fill = """
                int64 n = 0;  
                (void)n;
            """

        elif name == "__getitem__":
            slot_name = "tp_as_sequence->sq_item"
            func_head = "(PyObject* self, Py_ssize_t arg0) -> PyObject*"
            func_fill = f"""
                int64 n = 1;
                (void)n;
                if (arg0 >= GET_RAW_PTR({dfs[0]["scope_name"]},self)->size()) {{
                    PyErr_SetString(PyExc_IndexError, "");
                    return 0;
                }}
            """

        elif name == "__map_getitem__":
            slot_name = "tp_as_mapping->mp_subscript"
            func_head = "(PyObject* self, PyObject* arg0) -> PyObject*"
            func_fill = f"""
                int64 n = 1;
                PyObject* args[] = {{arg0}};
                (void)n;
            """

        elif name.startswith("__"):
            LOG.f(f"Not support slot {name}")
            continue

        else:
            func_head = "(PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject*"
            func_cast = f"(PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))"
            # if not return, return py_none
            arr_has_return = [ True for _ in arr_has_return ]

        arr_func_return = []
        doc_all = ""
        decs = "Declarations:\n"
        for did, has_return in enumerate(arr_has_return):
            df = dfs[did]
            func_call = arr_func_call[did]
            if df["doc"]:
                doc_all += "Document:\n"
                doc_all += df["doc"]
            doc_all += "\nDeclaration:\n"
            doc_all += df["dec"]
            decs += df["dec"]+'\n'
            if has_return:
                assert "-> int" not in func_head
                if "-> PyObject*" in func_head:
                    if "return_self" in df["attrs"]:
                        arr_func_return.append(
                            f"return (({func_call}), Py_INCREF(self), self)")
                    else:
                        arr_func_return.append(
                            f"return {get_pytype_map(df['return_t'],1)}(({func_call}))")
                    func_return_failed = "return nullptr"
                else:
                    arr_func_return.append(
                        f"return ({func_call});")
                    func_return_failed = "return -1"
            else:
                if "-> int" in func_head:
                    arr_func_return.append(f"return ({func_call},0)")
                    func_return_failed = "return -1"
                else:
                    assert "-> void" in func_head
                    arr_func_return.append(f"{func_call};return")
                    func_return_failed = "return"
        # One C++ lambda tries each overload in order; the first whose quick
        # check passes converts its args and returns.
        func = f"""
        {func_cast}[]{func_head} {{
            try {{
                {func_fill};
                uint64 arg_filled=0;
                (void)arg_filled;
                {"".join([f'''
                if ({arr_func_quick_check_runable[did]}) {{
                    {arr_func_args_convert[did]};
                    {arr_fill_with_default[did]};
                    {arr_func_return[did]};
                }}
                '''
                for did in range(len(arr_func_return))
                ])}
                LOGf << "Not a valid call";
            }} catch (const std::exception& e) {{
                PyErr_Format(PyExc_RuntimeError, "%s\\n%s",
                    e.what(),
                    R""({decs})""
                );
            }}
            {func_return_failed};
        }}
        """

        if slot_name:
            # tp_gets / tp_sets are collected into the PyGetSetDef table below
            if slot_name=="tp_gets":
                class_gets[name] = {
                    "func": func,
                    "doc": doc_all
                }
                continue
            if slot_name=="tp_sets":
                class_sets[name] = {
                    "func": func,
                    "doc": ""
                }
                continue
            class_slots_code.append(f"""
            tp.{slot_name} = {func};
            """)
            continue
        # `df` is the last overload of this target (left over from the loop
        # above); its scope/static flags decide METH_STATIC and placement.
        need_static = ""
        if df["is_scope_def"] and df["is_static"] and \
            df["scope_name"] == class_name and \
            "submodule" not in class_info["attrs"]:
            need_static = " | METH_STATIC"
        func = (f"""
        {{ R""({name})"",
        {func},
        METH_FASTCALL | METH_KEYWORDS{need_static},
        R""({doc_all})""
        }}""")
        if df["is_scope_def"]:
            if df["scope_name"] == class_name or \
                (class_info and \
                    target_scope_name == class_info["pynames"][0]):
                class_defs_code.append(func)
            else:
                submodule_defs_code.append(func)
        else:
            code.append(func)
    prop_names = list(set(class_gets.keys()).union(class_sets.keys()))
    prop_names = sorted(prop_names)
    for prop_name in prop_names:
        get_func = "NULL"
        set_func = "NULL"
        doc = ""
        if prop_name in class_gets:
            get_func = class_gets[prop_name]["func"]
            if class_gets[prop_name]["doc"]:
                doc += class_gets[prop_name]["doc"]
        if prop_name in class_sets:
            set_func = class_sets[prop_name]["func"]
            if class_sets[prop_name]["doc"]:
                doc += class_sets[prop_name]["doc"]
        class_getsets_code.append(f"""
            {{"{prop_name}", {get_func}, {set_func}, R""({doc})""}}
        """)
    # sentinel entries terminating each C table
    code.append("{0,0,0,0}")
    class_defs_code.append("{0,0,0,0}")
    class_getsets_code.append("{0,0,0,0}")
    submodule_defs_code.append("{0,0,0,0}")
    core_name = "jittor_core"
    if class_info and "attrs" in class_info and "core_name" in class_info["attrs"]:
        core_name = class_info["attrs"]["core_name"]
    if submodule_info and "attrs" in submodule_info and "core_name" in submodule_info["attrs"]:
        core_name = submodule_info["attrs"]["core_name"]
    has_map = class_name in ["VarHolder", "NanoVector"]
    has_seq = class_name == "NanoVector"
    # Generated submodule registration: check `sub` before using it, and
    # Py_INCREF it because PyImport_AddModule returns a borrowed reference
    # while PyModule_AddObject steals one on success.
    code = f"""
    #include "pyjt/py_converter.h"
    #include "common.h"
    #include "{include_name}"

    namespace jittor {{

    {
    "" if class_name is None else
    f"PyHeapTypeObject Pyjt{class_name};" if "heaptype" in class_info["attrs"] else
    f"PyTypeObject Pyjt{class_name};"
    }
    
    void pyjt_def_{basename}(PyObject* m) {{
        static PyMethodDef defs[] = {{
            {",".join(code)}
        }};
        ASSERT(PyModule_AddFunctions(m, defs)==0);
        {
        f'''
        static PyMethodDef class_defs[] = {{
            {",".join(class_defs_code)}
        }};
        static PyGetSetDef class_getsets[] = {{
            {",".join(class_getsets_code)}
        }};

        static PyNumberMethods number_methods = {{0}};
        {f"auto& htp =Pyjt{class_name}; auto& tp = htp.ht_type;"
        if "heaptype" in class_info["attrs"] else
        f"auto& tp = Pyjt{class_name};"}
        tp.tp_as_number = &number_methods;

        {f"static PyMappingMethods class_map_defs = {{0}};" if has_map else ""}
        {f"tp.tp_as_mapping = &class_map_defs;" if has_map else ""}

        {f"static PySequenceMethods class_seq_defs = {{0}};" if has_seq else ""}
        {f"tp.tp_as_sequence = &class_seq_defs;" if has_seq else ""}
        
        tp.tp_name = "{core_name}.{class_info["pynames"][0]}";
        tp.tp_basicsize = GET_OBJ_SIZE({class_name});
        tp.tp_new = PyType_GenericNew;
        tp.tp_flags = Py_TPFLAGS_DEFAULT;
        {"tp.tp_flags |= Py_TPFLAGS_HEAPTYPE; htp.ht_name = htp.ht_qualname = to_py_object<string>(tp.tp_name);"
        if "heaptype" in class_info["attrs"] else ""}
        tp.tp_methods = &class_defs[0];
        tp.tp_getset = &class_getsets[0];
        {"".join(class_slots_code)};
        ASSERT(0==PyType_Ready(&tp)) << (PyErr_Print(), 0);
        Py_INCREF(&tp);
        ASSERT(0==PyModule_AddObject(m, "{class_info["pynames"][0]}", (PyObject*)&tp));
        ''' if class_name is not None else ""
        }
        {f'''

        // sub module def
        static PyMethodDef submodule_defs[] = {{
            {",".join(submodule_defs_code)}
        }};
        auto sub = PyImport_AddModule("{core_name}.{submodule_info["pynames"][0]}");
        ASSERT(sub);
        ASSERT(PyModule_AddFunctions(sub, submodule_defs)==0);
        Py_INCREF(sub);
        ASSERT(0==PyModule_AddObject(m, "{submodule_info["pynames"][0]}", sub));
        ''' if submodule_name is not None else ""
        }

    }}

    }}
    """
    return code
Example #5
0
def compile_custom_ops(filenames, extra_flags=""):
    """Compile custom ops

    filenames: path of op source files, filenames must be
        pairs of xxx_xxx_op.cc and xxx_xxx_op.h, and the
        type name of op must be XxxXxxOp. Any other .cc/.cpp/.cu file
        is compiled into the library too.
    extra_flags: extra compile flags
    return: compiled ops
    """
    import hashlib
    import importlib

    srcs = {}
    headers = {}
    builds = []
    includes = []
    for name in filenames:
        name = os.path.realpath(name)
        if name.endswith((".cc", ".cpp", ".cu")):
            builds.append(name)
        bname = os.path.splitext(os.path.basename(name))[0]
        if bname.endswith("_op"):
            bname = bname[:-3]
            if name.endswith(".cc"):
                srcs[bname] = name
            elif name.endswith(".h"):
                includes.append(os.path.dirname(name))
                headers[bname] = name
    assert len(srcs) == len(headers), "Source and header names not match"
    for name in srcs:
        assert name in headers, f"Header of op {name} not found"
    gen_name = "gen_ops_" + "_".join(headers.keys())
    if len(gen_name) > 100:
        # Shorten with a stable digest. Builtin hash() of a str is salted per
        # process (PYTHONHASHSEED), so the name changed on every run, and it
        # can be negative, putting '-' into a name that is imported below.
        gen_name = gen_name[:80] + "___hash" + \
            hashlib.md5(gen_name.encode()).hexdigest()

    # dict.fromkeys de-duplicates while keeping first-seen order; iterating a
    # set of str varies between processes, which made the flags (and so the
    # compile command) differ from run to run.
    includes = "".join(f" -I'{x}' " for x in dict.fromkeys(includes))
    LOG.vvvv(f"Include flags:{includes}")

    op_extra_flags = includes + extra_flags

    gen_src = gen_jit_op_maker(headers.values(),
                               export=gen_name,
                               extra_flags=op_extra_flags)
    make_cache_dir(os.path.join(cache_path, "custom_ops"))
    gen_src_fname = os.path.join(cache_path, "custom_ops", gen_name + ".cc")
    gen_head_fname = os.path.join(cache_path, "custom_ops", gen_name + ".h")
    gen_lib = os.path.join("custom_ops", gen_name + extension_suffix)
    # the generated op maker is written as a header, then pyjt bindings
    # for it are generated into the .cc
    with open(gen_head_fname, "w") as f:
        f.write(gen_src)
    pyjt_compiler.compile_single(gen_head_fname, gen_src_fname)
    # gen src initialize first
    builds.insert(0, gen_src_fname)
    LOG.vvv(f"Build custum ops lib:{gen_lib}")
    LOG.vvvv(f"Build sources:{builds}")
    compile(cc_path, cc_flags + opt_flags + includes + extra_flags, builds,
            gen_lib)

    # add python path and import
    LOG.vvv(f"Import custum ops lib:{gen_lib}")
    lib_path = os.path.join(cache_path, "custom_ops")
    if lib_path not in os.sys.path:
        os.sys.path.append(lib_path)
    with jit_utils.import_scope(os.RTLD_GLOBAL | os.RTLD_NOW
                                | os.RTLD_DEEPBIND):
        # import_module hands back the module object directly; fetching it
        # from locals() after exec("import ...") fails on Python 3.13+,
        # where a function's locals() is a fresh snapshot (PEP 667).
        mod = importlib.import_module(gen_name)
    return mod.ops