Exemplo n.º 1
0
def compile_custom_ops(filenames, extra_flags=""):
    """Compile custom ops and import the generated extension module.

    filenames: paths of op source files; ops must come in pairs of
        xxx_xxx_op.cc and xxx_xxx_op.h, and the type name of the op
        must be XxxXxxOp. Any .cc/.cpp/.cu file is added to the build.
    extra_flags: extra compile flags
    return: the ``ops`` attribute of the imported module
    raises: AssertionError when the *_op.cc and *_op.h files do not pair up
    """
    # Function-scope imports keep this block self-contained.
    import hashlib
    import importlib
    import sys

    srcs = {}
    headers = {}
    builds = []
    includes = []
    for name in filenames:
        name = os.path.realpath(name)
        if name.endswith((".cc", ".cpp", ".cu")):
            builds.append(name)
        bname = os.path.splitext(os.path.basename(name))[0]
        if not bname.endswith("_op"):
            continue
        bname = bname[:-3]
        if name.endswith(".cc"):
            srcs[bname] = name
        elif name.endswith(".h"):
            includes.append(os.path.dirname(name))
            headers[bname] = name
    assert len(srcs) == len(headers), "Source and header names not match"
    for name in srcs:
        assert name in headers, f"Header of op {name} not found"
    gen_name = "gen_ops_" + "_".join(headers.keys())
    if len(gen_name) > 100:
        # The builtin hash() of a str is randomized per process and may be
        # negative, which would change the name on every run and could put
        # a "-" into a name used as a module / C identifier. Use a stable
        # hexadecimal digest instead.
        digest = hashlib.sha256(gen_name.encode("utf8")).hexdigest()[:16]
        gen_name = gen_name[:80] + "___hash" + digest

    # dict.fromkeys removes duplicates while keeping first-seen order, so
    # the flag string is the same on every run (a set has no stable order).
    includes = "".join(f" -I'{x}' " for x in dict.fromkeys(includes))
    LOG.vvvv(f"Include flags:{includes}")

    op_extra_flags = includes + extra_flags

    gen_src = gen_jit_op_maker(headers.values(), export=gen_name, extra_flags=op_extra_flags)
    make_cache_dir(os.path.join(cache_path, "custom_ops"))
    gen_src_fname = os.path.join(cache_path, "custom_ops", gen_name+".cc")
    gen_head_fname = os.path.join(cache_path, "custom_ops", gen_name+".h")
    gen_lib = os.path.join("custom_ops", gen_name+extension_suffix)
    with open(gen_head_fname, "w") as f:
        f.write(gen_src)
    pyjt_compiler.compile_single(gen_head_fname, gen_src_fname)
    # gen src initialize first
    builds.insert(0, gen_src_fname)
    LOG.vvv(f"Build custum ops lib:{gen_lib}")
    LOG.vvvv(f"Build sources:{builds}")
    compile(cc_path, cc_flags+opt_flags+includes+extra_flags, builds, gen_lib)

    # add python path and import
    LOG.vvv(f"Import custum ops lib:{gen_lib}")
    lib_path = os.path.join(cache_path, "custom_ops")
    if lib_path not in sys.path:
        sys.path.append(lib_path)
    # The library was created after the interpreter started, so drop any
    # cached directory listings held by the import system's finders.
    importlib.invalidate_caches()
    with jit_utils.import_scope(os.RTLD_GLOBAL | os.RTLD_NOW | os.RTLD_DEEPBIND):
        # import_module returns the module directly; exec("import ...")
        # followed by locals()[...] depends on exec mutating the function's
        # locals, which is not guaranteed.
        module = importlib.import_module(gen_name)
    return module.ops
Exemplo n.º 2
0
def gen_jit_flags():
    """Scan the C++ sources for flag definitions and write gen/jit_flags.h.

    Every ``DEFINE_FLAG(...)`` / ``DEFINE_FLAG_WITH_SETTER(...)`` found in
    the files listed by the ``find`` command yields one ``DECLARE_FLAG``
    line plus getter/setter wrappers annotated for pyjt. A flag name is
    only emitted the first time it is seen. When ``has_cuda`` is falsy,
    flags whose name contains "cuda" (other than ``use_cuda``) are skipped.
    The result is written to ``<cache_path>/gen/jit_flags.h``.
    """
    cc_files = run_cmd('find -L src/ | grep "cc$"', jittor_path).splitlines()
    flag_pattern = re.compile("DEFINE_FLAG(_WITH_SETTER)?\\((.*?)\\);", re.DOTALL)

    declares = []
    accessor_defs = []
    seen = set()

    for rel_path in cc_files:
        src_name = os.path.join(jittor_path, rel_path)
        with open(src_name) as f:
            content = f.read()
        for _, raw_args in flag_pattern.findall(content):
            # Macro arguments are: type, name, default, doc...
            fields = raw_args.split(",")
            flag_type = fields[0].strip()
            name = fields[1].strip()
            if not has_cuda and "cuda" in name and name != "use_cuda":
                continue
            default = fields[2].strip()
            # The doc may itself contain commas, so re-join the remainder.
            # NOTE: eval is applied to text taken from the scanned sources.
            doc = eval("(" + ",".join(fields[3:]) + ")")
            LOG.vv(f"Find define {name} from {src_name}")
            if name in seen:
                continue
            seen.add(name)
            declares.append(f"DECLARE_FLAG({flag_type}, {name});")
            # int flags additionally get a setter overload taking bool.
            bool_setter = ""
            if flag_type == "int":
                bool_setter = f'''// @pyjt(__set__{name})
                void _set_{name}(bool v) {{ set_{name}(v); }}
                '''
            accessor_defs.append(f"""
                /* {name}(type:{flag_type}, default:{default}): {doc} */
                // @pyjt(__get__{name})
                {flag_type} _get_{name}() {{ return {name}; }}
                // @pyjt(__set__{name})
                void _set_{name}({flag_type} v) {{ set_{name}(v); }}
                {bool_setter}
            """)

    jit_declares = "\n    ".join(declares)
    accessors = "".join(accessor_defs)
    jit_src = f"""
    #include "utils/flags.h"

    namespace jittor {{
    
    {jit_declares}

    // @pyjt(flags)
    struct _Flags {{
        // @pyjt(__init__)
        _Flags() {{}}
        {accessors}
    }};

    }} // jittor
    """
    LOG.vvvv(jit_src)
    out_path = os.path.join(cache_path, "gen", "jit_flags.h")
    with open(out_path, 'w') as f:
        f.write(jit_src)
Exemplo n.º 3
0
def compile_single(head_file_name, src_file_name, src=None):
    """Generate the source for one header and write it to src_file_name.

    head_file_name: path of the header; its base name up to the first "."
        is passed to compile_src as the basename.
    src_file_name: path of the file the generated code is written to.
    src: header content; when None it is read from head_file_name.
    return: False when compile_src returns a falsy value (nothing is
        written in that case), True after the generated code was written.
    """
    basename = head_file_name.split("/")[-1].split(".")[0]
    # Identity comparison is the correct test for the None sentinel.
    if src is None:
        with open(head_file_name, 'r') as f:
            src = f.read()
    code = compile_src(src, head_file_name, basename)
    if not code:
        return False
    LOG.vvv("write to", src_file_name)
    LOG.vvvv(code)
    with open(src_file_name, 'w') as f:
        f.write(code)
    return True
Exemplo n.º 4
0
def compile(cache_path, jittor_path):
    """Run compile_single on every header and emit gen/pyjt_all.cc.

    Headers are collected recursively from ``<jittor_path>/src`` and
    ``<cache_path>/gen``. ``var_holder.h`` is not compiled on its own; its
    content is prepended to ``jit_op_maker.h`` instead. For each header
    where compile_single returns a truthy value, ``pyjt_<basename>.cc`` in
    ``<cache_path>/gen`` is recorded. Finally ``pyjt_all.cc`` is written,
    defining ``pyjt_def_all`` which calls every ``pyjt_def_<basename>``.

    return: list of recorded .cc paths, with pyjt_all.cc last.
    """
    gen_dir = os.path.join(cache_path, "gen")
    header_paths = glob.glob(jittor_path + "/src/**/*.h", recursive=True)
    header_paths = header_paths + glob.glob(cache_path + "/gen/**/*.h",
                                            recursive=True)
    # (basename, generated .cc path) for every header that produced output
    compiled = []
    for header in header_paths:
        with open(header, 'r') as f:
            src = f.read()

        header_base = os.path.basename(header)
        # jit_op_maker.h is compiled together with var_holder.h
        if header_base == "var_holder.h":
            continue
        if header_base == "jit_op_maker.h":
            var_holder_path = os.path.join(jittor_path, "src", "var_holder.h")
            with open(var_holder_path, "r") as f:
                src = f.read() + src
        basename = header_base.split(".")[0]
        out_name = os.path.join(gen_dir, "pyjt_" + basename + ".cc")
        if compile_single(header, out_name, src):
            compiled.append((basename, out_name))

    basenames = [base for base, _ in compiled]
    pyjt_names = [path for _, path in compiled]
    extern_decls = " ".join(
        f"extern void pyjt_def_{n}(PyObject* m);" for n in basenames)
    def_calls = " ".join(f"pyjt_def_{n}(m);" for n in basenames)

    code = f"""
    #include "pyjt/numpy.h"
    #include "pyjt/py_converter.h"
    #include "common.h"

    namespace jittor {{

    {extern_decls}

    void pyjt_def_all(PyObject* m) {{
        numpy_init();
        {def_calls}
    }}

    }}
    """
    all_fname = os.path.join(gen_dir, "pyjt_all.cc")
    LOG.vvv(("write to", all_fname))
    LOG.vvvv(code)
    with open(all_fname, "w") as f:
        f.write(code)
    pyjt_names.append(all_fname)
    return pyjt_names
Exemplo n.º 5
0
def compile(cache_path, jittor_path):
    """Run compile_single on every header and emit gen/pyjt_all.cc.

    Headers are listed with ``find`` under ``src/`` (run with jittor_path)
    and ``gen/`` (run with cache_path). ``src/var_holder.h`` is not
    compiled on its own; its content is prepended to ``jit_op_maker.h``
    instead. For each header where compile_single returns a truthy value,
    ``pyjt_<basename>.cc`` in ``<cache_path>/gen`` is recorded. Finally
    ``pyjt_all.cc`` is written, defining ``pyjt_def_all`` which calls every
    ``pyjt_def_<basename>``.

    return: list of recorded .cc paths, with pyjt_all.cc last.
    """
    # The dot must be escaped: in a regex an unescaped "." matches any
    # character, so ".h$" also selects files that merely end in "h".
    headers1 = run_cmd('find -L src/ | grep "\\.h$"', jittor_path).splitlines()
    headers2 = run_cmd('find gen/ | grep "\\.h$"', cache_path).splitlines()
    headers = [os.path.join(jittor_path, h) for h in headers1] + \
        [os.path.join(cache_path, h) for h in headers2]
    basenames = []
    pyjt_names = []
    for h in headers:
        # jit_op_maker.h merge compile with var_holder.h
        # (skip before reading, the content would be discarded anyway)
        if h.endswith("src/var_holder.h"):
            continue
        with open(h, 'r') as f:
            src = f.read()
        if h.endswith("jit_op_maker.h"):
            with open(os.path.join(jittor_path, "src", "var_holder.h"),
                      "r") as f:
                src = f.read() + src
        basename = h.split("/")[-1].split(".")[0]
        fname = os.path.join(cache_path, "gen", "pyjt_" + basename + ".cc")
        if not compile_single(h, fname, src):
            continue

        basenames.append(basename)
        pyjt_names.append(fname)

    code = f"""
    #include "pyjt/numpy.h"
    #include "pyjt/py_converter.h"
    #include "common.h"

    namespace jittor {{

    { " ".join([f"extern void pyjt_def_{n}(PyObject* m);" for n in basenames])}

    void pyjt_def_all(PyObject* m) {{
        numpy_init();
        { " ".join([f"pyjt_def_{n}(m);" for n in basenames])}
    }}

    }}
    """
    fname = os.path.join(cache_path, "gen", "pyjt_all.cc")
    LOG.vvv(("write to", fname))
    LOG.vvvv(code)
    with open(fname, "w") as f:
        f.write(code)
    pyjt_names.append(fname)
    return pyjt_names
Exemplo n.º 6
0
def gen_jit_tests():
    """Scan the C++ sources for JIT_TEST(...) and write gen/jit_tests.h.

    Each test name found yields a ``JIT_TEST(name);`` declaration and a
    pyjt-annotated ``test_<name>`` wrapper that calls ``jit_test_<name>``.
    The result is written to ``<cache_path>/gen/jit_tests.h``.

    raises: AssertionError if the same test name is found more than once.
    """
    cc_files = run_cmd('find -L src/ | grep "cc$"', jittor_path).splitlines()
    test_pattern = re.compile("JIT_TEST\\((.*?)\\)")
    declares = []
    wrappers = []
    seen = set()

    for rel_path in cc_files:
        src_name = os.path.join(jittor_path, rel_path)
        with open(src_name) as f:
            content = f.read()
        for name in test_pattern.findall(content):
            LOG.vv(f"Find test {name} from {src_name}")
            assert name not in seen, f"Conflict test name {name}"
            seen.add(name)
            declares.append(f"JIT_TEST({name});")
            wrappers.append(f"""
                /* From {src_name} */
                // @pyjt({name})
                static inline void test_{name}() {{ jit_test_{name}(); }} 
            """)

    jit_declares = "\n    ".join(declares)
    wrapper_src = "".join(wrappers)
    jit_src = f"""
    #pragma once
    #include "common.h"

    void expect_error(std::function<void()> func) {{
        try {{ func(); }}
        catch (...) {{ return; }}
        CHECK(0) << "Missing error";
    }}

    namespace jittor {{
    
    {jit_declares}

    // @pyjt(tests)
    // @attrs(submodule)
    namespace tests {{
        {wrapper_src}
    }}

    }} // jittor
    """
    LOG.vvvv(jit_src)
    out_path = os.path.join(cache_path, "gen", "jit_tests.h")
    with open(out_path, 'w') as f:
        f.write(jit_src)
Exemplo n.º 7
0
def compile_custom_ops(filenames,
                       extra_flags="",
                       return_module=False,
                       dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW
                       | os.RTLD_DEEPBIND,
                       gen_name_=""):
    """Compile custom ops and import the generated extension module.

    filenames: paths of op source files; ops must come in pairs of
        xxx_xxx_op.cc and xxx_xxx_op.h, and the type name of the op
        must be XxxXxxOp. Any .cc/.cpp/.cu file is added to the build;
        any .h file containing "@pyjt" is additionally compiled with
        pyjt_compiler and hooked into the module init.
    extra_flags: extra compile flags
    return_module: return module rather than ops(default: False)
    dlopen_flags: flags passed to jit_utils.import_scope for the import
    gen_name_: when non-empty, used as the generated module name instead
        of the name derived from the op headers
    return: compiled ops, or the module when return_module is true
    raises: AssertionError when the *_op.cc and *_op.h files do not pair up
    """
    # Function-scope imports keep this block self-contained.
    import hashlib
    import importlib
    import sys

    srcs = {}
    headers = {}
    builds = []
    includes = []
    pyjt_includes = []
    for name in filenames:
        name = os.path.realpath(name)
        if name.endswith((".cc", ".cpp", ".cu")):
            builds.append(name)
        if name.endswith(".h"):
            dirname = os.path.dirname(name)
            if dirname.endswith("inc"):
                includes.append(dirname)
            with open(name, "r") as f:
                if "@pyjt" in f.read():
                    pyjt_includes.append(name)
        bname = os.path.splitext(os.path.basename(name))[0]
        if not bname.endswith("_op"):
            continue
        bname = bname[:-3]
        if name.endswith(".cc"):
            srcs[bname] = name
        elif name.endswith(".h"):
            includes.append(os.path.dirname(name))
            headers[bname] = name
    assert len(srcs) == len(headers), "Source and header names not match"
    for name in srcs:
        assert name in headers, f"Header of op {name} not found"
    gen_name = "gen_ops_" + "_".join(headers.keys())
    if gen_name_ != "":
        gen_name = gen_name_
    if len(gen_name) > 100:
        # The builtin hash() of a str is randomized per process and may be
        # negative, which would change the name on every run and could put
        # a "-" into a name used as a module / C identifier. Use a stable
        # hexadecimal digest instead.
        digest = hashlib.sha256(gen_name.encode("utf8")).hexdigest()[:16]
        gen_name = gen_name[:80] + "___hash" + digest

    # dict.fromkeys removes duplicates while keeping first-seen order, so
    # the flag string is the same on every run (a set has no stable order).
    includes = "".join(f" -I'{x}' " for x in dict.fromkeys(includes))
    LOG.vvvv(f"Include flags:{includes}")

    op_extra_flags = includes + extra_flags

    gen_src = gen_jit_op_maker(headers.values(),
                               export=gen_name,
                               extra_flags=op_extra_flags)
    make_cache_dir(os.path.join(cache_path, "custom_ops"))
    gen_src_fname = os.path.join(cache_path, "custom_ops", gen_name + ".cc")
    gen_head_fname = os.path.join(cache_path, "custom_ops", gen_name + ".h")
    gen_lib = os.path.join("custom_ops", gen_name + extension_suffix)
    pyjt_compiler.compile_single(gen_head_fname, gen_src_fname, src=gen_src)
    # gen src initialize first
    builds.insert(0, gen_src_fname)

    def insert_anchor(gen_src, anchor_str, insert_str):
        # insert insert_str after the first anchor_str in gen_src
        return gen_src.replace(anchor_str, anchor_str + insert_str, 1)

    for name in pyjt_includes:
        LOG.i("handle pyjt_include", name)
        bname = name.split("/")[-1].split(".")[0]
        gen_src_fname = os.path.join(cache_path, "custom_ops",
                                     gen_name + "_" + bname + ".cc")
        pyjt_compiler.compile_single(name, gen_src_fname)
        builds.insert(1, gen_src_fname)
        # Declare the extra pyjt definition and call it from init_module.
        gen_src = insert_anchor(gen_src, "namespace jittor {",
                                f"extern void pyjt_def_{bname}(PyObject* m);")
        gen_src = insert_anchor(
            gen_src, "init_module(PyModuleDef* mdef, PyObject* m) {",
            f"jittor::pyjt_def_{bname}(m);")

    with open(gen_head_fname, "w") as f:
        f.write(gen_src)

    LOG.vvv(f"Build custum ops lib:{gen_lib}")
    LOG.vvvv(f"Build sources:{builds}")
    compile(cc_path, extra_flags + cc_flags + opt_flags + includes, builds,
            gen_lib)

    # add python path and import
    LOG.vvv(f"Import custum ops lib:{gen_lib}")
    lib_path = os.path.join(cache_path, "custom_ops")
    if lib_path not in sys.path:
        sys.path.append(lib_path)
    # The library was created after the interpreter started, so drop any
    # cached directory listings held by the import system's finders.
    importlib.invalidate_caches()
    # unlock scope when initialize
    with lock.unlock_scope():
        with jit_utils.import_scope(dlopen_flags):
            # import_module returns the module directly; exec("import ...")
            # followed by locals()[...] depends on exec mutating the
            # function's locals, which is not guaranteed.
            mod = importlib.import_module(gen_name)
    if return_module:
        return mod
    return mod.ops
Exemplo n.º 8
0
def gen_jit_op_maker(op_headers, export=False, extra_flags=""):
    """Generate the C++ header source that wraps op constructors.

    For every header in op_headers the constructors ``XxxXxxOp(...)`` are
    located with a regex, together with an optional ``/* doc */`` comment
    and optional ``// @pybind(...)`` / ``// @attrs(...)`` lines in front of
    them. For each constructor add_src emits a ``make_<name>`` factory and
    (unless the bind name is ``None``) a ``// @pyjt(...)`` annotated
    wrapper. Each op also gets one ``op_registe`` call inside the
    generated ``initer()``.

    op_headers: iterable of op header paths; each base name must end with
        ``_op``. Files are opened via os.path.join(jittor_path, header).
    export: when truthy, a string used to name the generated namespace
        (``<export>_maker``), the ``core_name`` attribute, and the module
        init section appended at the end; otherwise the namespace is
        ``jit_op_maker`` and no module init section is generated.
    extra_flags: string embedded into the generated ``initer()`` as the
        ``extra_flags`` raw string literal.
    return: the generated source as a string.
    """
    def add_src(cc_func_name, cc_args, op_name, op_args, src, pybind_name,
                py_args, jit_cc_src, doc_string, attrs):
        """Append the generated C++ for one op constructor to jit_cc_src.

        Always emits ``make_<cc_func_name>`` (returning ``VarPtr``, or
        ``vector<VarPtr>`` when "multiple_outputs" is in attrs). Unless
        pybind_name is the string 'None', also emits a ``// @pyjt(...)``
        wrapper named cc_func_name. When one of the bind names has the form
        ``__<op>__`` with ``<op>`` in has_ir, additionally emits
        ``i<cc_func_name>`` (bound as ``Var.__i<op>__``) and
        ``r<cc_func_name>`` whose first two parameters are swapped.
        """
        # operator names for which a Var.__i<op>__ binding is generated
        has_ir = set([
            "add", "sub", "mul", "matmul", "truediv", "floordiv", "mod",
            "divmod", "pow", "lshift", "rshift", "and", "xor", "or"
        ])
        pybind_names = [s.strip() for s in pybind_name.split(",")]
        # make_* factories take Var* where the wrappers take VarHolder*
        cc_make_args = [arg.replace("VarHolder*", "Var*") for arg in cc_args]
        op_make_args = [arg.replace("->var", "") for arg in op_args]
        py_args = [arg.replace("Var*", "VarHolder*") for arg in py_args]
        # op_args is rebuilt below from cc_args
        op_args = []
        cc_args_with_default = []
        for i, arg in enumerate(cc_args):
            # parameter name: last whitespace-separated token, default stripped
            pre_arg = arg.split()[-1].split('=')[0]
            op_arg = None
            if arg.startswith("VarHolder*"):
                op_arg = pre_arg + "->var"
            elif arg.startswith("vector<VarHolder*>"):
                op_arg = f"convert({pre_arg})"
            if "&&" in arg:
                if op_arg == None:
                    op_arg = "move(" + pre_arg + ")"
                op_make_args[i] = "move(" + pre_arg + ")"
            if op_arg == None: op_arg = pre_arg
            op_args.append(op_arg)
            py_arg = py_args[i]
            # py_args entries look like "name"_a or "name"_a=default;
            # a default, when present, is appended to the C++ parameter
            if "_a=" not in py_arg:
                cc_args_with_default.append(arg)
                continue
            py_arg = py_arg.split("_a=")[1]
            cc_args_with_default.append(arg + "=" + py_arg)
        cc_args = cc_args_with_default
        # steps of Op creation:
        # 1. new op
        # 2. new output var (create_output in op constructor)
        # 3. take over op's output VarPtr from outputs_holder
        # 4. set op's output
        # 5. set op's input
        # 6. infer shape(op->init())
        if "multiple_outputs" not in attrs:
            jit_cc_src.append(f"""
            VarPtr make_{cc_func_name}({", ".join(cc_make_args)}) {{
                Op* _op = new {op_name}({", ".join(op_make_args)});
                if (_op->outputs_holder.size() != 1) {{
                    delete _op;
                    LOGf << "Wrong output size of" << \"{op_name}\";
                }}
                if (_op->flags.get(NodeFlags::_forwarded)) {{
                    VarPtr output(move(_op->outputs_holder[0]));
                    delete _op;
                    return output;
                }}
                _op->outputs_holder[0]->set_inputs({{_op}});
                VarPtr output(move(_op->outputs_holder[0]));
                {src.replace("->var","")};
                _op->init();
                return output;
            }}
            """)
        else:
            jit_cc_src.append(f"""
            vector<VarPtr> make_{cc_func_name}({", ".join(cc_make_args)}) {{
                Op* _op = new {op_name}({", ".join(op_make_args)});
                if (_op->flags.get(NodeFlags::_forwarded)) {{
                    vector<VarPtr> outputs = move(_op->outputs_holder);
                    delete _op;
                    return outputs;
                }}
                vector<VarPtr> outputs = move(_op->outputs_holder);
                for (uint i=0; i<outputs.size(); i++)
                    outputs[i]->set_inputs({{_op}});
                {src.replace("->var","")};
                _op->init();
                return outputs;
            }}
            """)
        # a bind name of 'None' means: only the make_* factory is generated
        if pybind_name == 'None':
            return
        pyjt_names = []
        for pybind_name in pybind_names:
            if pybind_name.startswith("__"):
                pyjt_names.append("Var." + pybind_name)
            else:
                pyjt_names.append(pybind_name)
                # also bind as Var.<name> when the first parameter is a VarHolder*
                if len(cc_args) > 0 and cc_args[0].startswith("VarHolder* "):
                    pyjt_names.append("Var." + pybind_name)
        if "multiple_outputs" in attrs:
            jit_cc_src.append(f"""
            /*{doc_string}*/
            // @pyjt({",".join(pyjt_names)})
            vector<VarHolder*> {cc_func_name}({", ".join(cc_args)}) {{
                return make_vh_vector(make_{cc_func_name}({", ".join(op_args)}));
            }}
            """)
        else:
            jit_cc_src.append(f"""
            /*{doc_string}*/
            // @pyjt({",".join(pyjt_names)})
            VarHolder* {cc_func_name}({", ".join(cc_args)}) {{
                return new VarHolder(make_{cc_func_name}({", ".join(op_args)}));
            }}
            """)
        need_ir_define = False
        ir_name = None
        # at most one bind name of the form __<op>__ with <op> in has_ir
        # is allowed (enforced by the assert)
        for pybind_name in pybind_names:
            if pybind_name.startswith("__") and pybind_name[2:-2] in has_ir:
                need_ir_define = True
                assert ir_name is None
                ir_name = pybind_name[2:-2]
        if need_ir_define:
            assert len(cc_args) > 0 and cc_args[0].startswith("VarHolder* ")
            # name of the first parameter, assigned to and returned by i<name>
            this = cc_args[0].split()[-1]
            jit_cc_src.append(f"""
            // @pyjt(Var.__i{ir_name}__)
            // @attrs(return_self)
            VarHolder* i{cc_func_name}({", ".join(cc_args)}) {{
                *{this} = make_{cc_func_name}({", ".join(op_args)});
                return {this};
            }}
            """)
            assert len(cc_args) > 1 and cc_args[1].startswith(
                "VarHolder* "), cc_args
            # same parameters with the first two swapped
            r_cc_args = [cc_args[1], cc_args[0]] + cc_args[2:]
            # NOTE(review): r_py_args is computed but not used below
            r_py_args = [py_args[1], py_args[0]] + py_args[2:]
            jit_cc_src.append(f"""
            VarHolder* r{cc_func_name}({", ".join(r_cc_args)}) {{
                return new VarHolder(make_{cc_func_name}({", ".join(op_args)}));
            }}
            """)

    jit_cc_src = []
    jit_headers = ""
    initer = []
    # optional "/* doc */" comment followed by optional "// @pybind(names)"
    pybind_reg = '(/\\*(.*?)\\*/\\s*)?(//\\s*@pybind\\(([^\\n]*)\\)\\s*)?'
    # ... followed by optional "// @attrs(attrs)"
    pybind_attrs_reg = pybind_reg + '(//\\s*@attrs\\(([^\\n]*)\\)\\s*)?'
    for header in op_headers:
        # xxx_xxx_op
        name = os.path.basename(header)
        name = os.path.splitext(name)[0]
        # xxx_xxx
        assert name.endswith("_op")
        func_name = name[:-3]
        # XxxXxxOp
        name2 = map(lambda s: s[:1].upper() + s[1:], name.split('_'))
        name2 = "".join(name2)
        with open(os.path.join(jittor_path, header), encoding='utf8') as f:
            src = f.read()
        # XxxXxxOp(args)
        # ([^~] in front of the name excludes matches preceded by "~")
        res = re.findall(pybind_attrs_reg + '[^~](' + name2 + "\\([^\\n]*\\))",
                         src, re.S)
        assert len(res) >= 1, "Wrong op args in " + header
        # registe op
        cc_name = os.path.join(jittor_path, header[:-2] + ".cc")
        # one entry per matched constructor; the i-th one is named
        # make_<func_name> followed by i underscores
        constructors = []
        for i in range(len(res)):
            name = 'make_' + func_name + '_' * i
            constructors.append(f"{{ &typeid(&{name}), (void*)&{name} }}")
        constructors = ",".join(constructors)
        # collect the names declared on lines of the form "Var ...;"
        var_member_reg = r"\n\s*Var\b(.*);"
        var_member_match = re.findall(var_member_reg, src)
        var_member_match = " ".join(var_member_match)
        for c in "*,":
            var_member_match = var_member_match.replace(c, " ")
        var_member = var_member_match.split()
        LOG.vv("var_member_match " + var_member_match)
        LOG.vv("var_member " + str(var_member))
        var_member_src = [
            f"VAR_MEMBER_NAME_AND_OFFSET({name}, {name2})"
            for name in var_member
        ]
        var_member_src = ",".join(var_member_src)
        initer.append(
            f'\n        op_registe({{ "{func_name}", R"({cc_name})", extra_flags, {{{constructors}}}, {{{var_member_src}}} }});'
        )
        for hid, h_def in enumerate(res):
            h_def = list(h_def)
            # // @attrs(...)
            # h_def[4] is the whole attrs comment, h_def[5] its content
            attrs = {}
            if h_def[4] != "":
                attrs = pyjt_compiler.parse_attrs(h_def[5])
            del h_def[4:6]
            # /* doc_string */
            # // @pybind(bind_name)
            # XxxXxxOp(args_def)
            doc_string = h_def[1].strip()
            h_def = h_def[2:]
            # text between the parentheses of the constructor
            args_def = h_def[2][len(name2) + 1:-1]
            bind_name = h_def[1]
            if bind_name == "":
                bind_name = func_name
            if args_def == "":
                args = []
            else:
                # parameter names: last token of each comma-separated
                # piece, default value stripped
                args = list(
                    map(lambda s: s.split()[-1].split('=')[0],
                        args_def.split(',')))
            # py_args: "arg"_a=default
            py_args = []
            new_args_def = []
            new_args = []
            # source of convert VarHolder* to Var*
            vh2v_src = []
            more_src = []
            for arg, arg_def in zip(args, args_def.split(',')):
                py_arg = f'"{arg}"_a'
                if '=' in arg_def:
                    py_arg += "=" + arg_def.split('=')[-1]
                    arg_def = arg_def.split('=')[0]
                py_args.append(py_arg)
                # what remains of arg_def after cutting off the trailing name
                arg_type = arg_def[:-(len(arg) + 1)].strip()
                if arg_type == "Var*":
                    new_args_def.append("VarHolder* " + arg)
                    vh2v_src.append(arg + "->var")
                    new_args.append(arg + "->var")
                elif arg_type.startswith("vector<Var*>"):
                    new_args_def.append(
                        arg_type.replace("Var", "VarHolder") + ' ' + arg)
                    new_args.append(arg)
                    more_src.append(f"_op->add_inputs({arg});")
                else:
                    new_args_def.append(arg_def)
                    new_args.append(arg)
            # C++ statements that attach the inputs to the new op
            vh2v_src = "_op->set_inputs({" + ", ".join(vh2v_src) + "});" + \
                "".join(more_src)
            LOG.vvvv(f"Find op: {name2} args: {new_args}")
            # headers given as "src/..." are included without that prefix
            if header.startswith("src/"):
                jit_headers += f"#include \"{header[4:]}\"\n"
            else:
                jit_headers += f"#include \"{header}\"\n"
            add_src(func_name + '_' * hid, new_args_def, name2, new_args,
                    vh2v_src, bind_name, py_args, jit_cc_src, doc_string,
                    attrs)
            if func_name in ["binary", "unary", "reduce"]:
                # generate binary op alias
                with open(os.path.join(jittor_path,
                                       f"src/ops/{func_name}_op.cc"),
                          encoding="utf-8") as f:
                    src = f.read()
                # keep only the initializer list of "<func_name>_ops"
                src = src.split(f"unordered_set<string> {func_name}_ops = "
                                "{")[1].split("};")[0]
                res2 = re.findall(pybind_reg + "\"([a-z_A-Z0-9]*)\"", src,
                                  re.S)
                # remove /* doc_string */ pattern
                res2 = [(_[3], _[4]) for _ in res2]
                LOG.vvvv(f"All supported {func_name} ops: {res2}")
                # remove op args
                # (second parameter for reduce, last parameter otherwise)
                if func_name == "reduce":
                    args_def = new_args_def[:1] + new_args_def[2:]
                    py_args_s = py_args[:1] + py_args[2:]
                else:
                    args_def = new_args_def[:-1]
                    py_args_s = py_args[:-1]
                # find the last type id(float64)
                # add "_" suffix for all function
                if func_name == "unary":
                    last_tid = res2.index(("", "float64"))
                # for each functor
                for tid, (bind_name, func_name2) in enumerate(res2):
                    # add _ for types
                    if func_name == "unary" and tid <= last_tid:
                        func_name3 = func_name2 + "_"
                    elif func_name == "reduce":
                        func_name4 = func_name2
                        func_name2 = "reduce_" + func_name2
                        func_name3 = func_name2
                    else:
                        func_name3 = func_name2
                    if len(bind_name) == 0:
                        bind_name = func_name2
                    # the removed op argument is replaced by ns_<functor>
                    if func_name == "reduce":
                        args = new_args[:1] + [f'ns_{func_name4}'
                                               ] + new_args[2:]
                    else:
                        args = new_args[:-1] + [f'ns_{func_name2}']
                    add_src(func_name3 + '_' * hid, args_def, name2, args,
                            vh2v_src, bind_name, py_args_s, jit_cc_src,
                            doc_string, attrs)

    # assemble the final header; the trailing module init section is only
    # generated when export is truthy
    jit_src = f"""
    #pragma once
    #include "pyjt/py_obj_holder.h"
    #include "var.h"
    #include "var_holder.h"
    #include "ops/op_register.h"
    {jit_headers}
    
    namespace jittor {{
    // fix make_array(py::array) undefine reference
    #pragma GCC visibility push(default)
    #define JIT_NAMESPACE {export+"_maker" if export else "jit_op_maker"}
    // @pyjt(ops)
    // @attrs(submodule{",core_name="+export if export else ""})
    namespace JIT_NAMESPACE {{
    {"".join(jit_cc_src)}

    void initer() {{
        string extra_flags = R"({extra_flags})";
        {"".join(initer)}
    }}
    int caller = (initer(), 0);
    
    }} // JIT_NAMESPACE
    }} // jittor
    {f'''
    namespace jittor {{
    extern void pyjt_def_{export}(PyObject*);
    }}

    static void init_module(PyModuleDef* mdef, PyObject* m) {{
        mdef->m_doc = "User defined custom ops";
        jittor::pyjt_def_{export}(m);
    }}
    PYJF_MODULE_INIT({export});

    ''' if export else ""}
    """
    return jit_src
Exemplo n.º 9
0
        # nvcc warning is noise
        nvcc_flags += " -w "
        nvcc_flags += f" -I'{os.path.join(jittor_path, 'extern/cuda/inc')}' "
        if os.environ.get("cuda_debug", "0") == "1":
            nvcc_flags += " -G "
        return nvcc_flags

    nvcc_flags = convert_nvcc_flags(nvcc_flags)

# build core
gen_jit_flags()
gen_jit_tests()
# List the paths under src/ops/ that match the grep pattern "op.h$" and
# generate the op maker source from them.
op_headers = run_cmd('find -L src/ops/ | grep "op.h$"',
                     jittor_path).splitlines()
jit_src = gen_jit_op_maker(op_headers)
LOG.vvvv(jit_src)
# Store the generated source as <cache_path>/gen/jit_op_maker.h.
with open(os.path.join(cache_path, "gen", "jit_op_maker.h"), 'w') as f:
    f.write(jit_src)
# Add cache_path as an include directory to the compiler flags.
cc_flags += f' -I{cache_path} '
# gen pyjt
pyjt_gen_src = pyjt_compiler.compile(cache_path, jittor_path)

# initialize order:
# 1. registers
# 2. generate source
# 3. op_utils
# 4. other
files2 = pyjt_gen_src
# All paths under src matching "cc$"
# (presumably relative to jittor_path; confirm against run_cmd)
files4 = run_cmd('find -L src | grep "cc$"', jittor_path).splitlines()
at_beginning = [
    "src/ops/op_utils.cc",
Exemplo n.º 10
0
def run_cmd(cmd):
    """Run a shell command with os.system and fail on a non-zero status.

    cmd: the command line passed to os.system.
    raises: AssertionError when os.system returns a non-zero value.
    """
    LOG.vvvv(f"Run cmd: {cmd}")
    # Run the command outside of an ``assert`` statement: assert
    # expressions are removed under ``python -O``, which would silently
    # skip the command itself.
    status = os.system(cmd)
    if status != 0:
        # AssertionError is kept so existing callers see the same exception.
        raise AssertionError(f"Run cmd failed: {cmd}")