def compile_custom_ops(filenames, extra_flags=""):
    """Compile custom ops and import the generated extension module.

    filenames: path of op source files, filenames must be
        pairs of xxx_xxx_op.cc and xxx_xxx_op.h, and the
        type name of op must be XxxXxxOp.
    extra_flags: extra compile flags
    return: compiled ops (the ``ops`` attribute of the imported module)
    raises: AssertionError when op sources and headers do not pair up
    """
    # Function-scope imports keep this block independent of the file header.
    import hashlib
    import importlib

    srcs = {}
    headers = {}
    builds = []
    includes = []
    for name in filenames:
        name = os.path.realpath(name)
        if name.endswith((".cc", ".cpp", ".cu")):
            builds.append(name)
        bname = os.path.splitext(os.path.basename(name))[0]
        if bname.endswith("_op"):
            bname = bname[:-3]
            if name.endswith(".cc"):
                srcs[bname] = name
            elif name.endswith(".h"):
                includes.append(os.path.dirname(name))
                headers[bname] = name
    assert len(srcs) == len(headers), "Source and header names not match"
    for name in srcs:
        assert name in headers, f"Header of op {name} not found"

    gen_name = "gen_ops_" + "_".join(headers.keys())
    if len(gen_name) > 100:
        # The name is used as a file name, a Python module name and a C
        # identifier. The builtin hash() is randomized per process and may be
        # negative (yielding a '-'), so use a stable hex digest instead.
        digest = hashlib.sha256(gen_name.encode()).hexdigest()[:16]
        gen_name = gen_name[:80] + "___hash" + digest

    includes = set(includes)
    includes = "".join(map(lambda x: f" -I'{x}' ", includes))
    LOG.vvvv(f"Include flags:{includes}")

    op_extra_flags = includes + extra_flags
    gen_src = gen_jit_op_maker(headers.values(), export=gen_name,
                               extra_flags=op_extra_flags)
    make_cache_dir(os.path.join(cache_path, "custom_ops"))
    gen_src_fname = os.path.join(cache_path, "custom_ops", gen_name+".cc")
    gen_head_fname = os.path.join(cache_path, "custom_ops", gen_name+".h")
    gen_lib = os.path.join("custom_ops", gen_name+extension_suffix)
    with open(gen_head_fname, "w") as f:
        f.write(gen_src)
    pyjt_compiler.compile_single(gen_head_fname, gen_src_fname)
    # gen src initialize first
    builds.insert(0, gen_src_fname)
    LOG.vvv(f"Build custum ops lib:{gen_lib}")
    LOG.vvvv(f"Build sources:{builds}")
    compile(cc_path, cc_flags+opt_flags+includes+extra_flags, builds, gen_lib)

    # add python path and import
    LOG.vvv(f"Import custum ops lib:{gen_lib}")
    lib_path = os.path.join(cache_path, "custom_ops")
    if lib_path not in os.sys.path:
        os.sys.path.append(lib_path)
    with jit_utils.import_scope(os.RTLD_GLOBAL | os.RTLD_NOW | os.RTLD_DEEPBIND):
        # import_module returns the module directly; the previous
        # exec("import ...") + locals()[...] lookup stops working on
        # Python 3.13, where exec() no longer feeds later locals() calls.
        mod = importlib.import_module(gen_name)
    return mod.ops
def gen_jit_flags():
    """Generate ``gen/jit_flags.h`` from the DEFINE_FLAG macros in src/.

    Scans every ``.cc`` file found under ``src/`` (relative to
    ``jittor_path``) for ``DEFINE_FLAG(...)`` / ``DEFINE_FLAG_WITH_SETTER(...)``
    and emits, for each distinct flag name, a ``DECLARE_FLAG`` line plus
    pyjt-annotated getter/setter members of a ``_Flags`` struct. The result
    is written to ``<cache_path>/gen/jit_flags.h``.
    """
    source_files = run_cmd('find -L src/ | grep "cc$"', jittor_path).splitlines()
    flag_pattern = re.compile("DEFINE_FLAG(_WITH_SETTER)?\\((.*?)\\);", re.DOTALL)
    declare_lines = []
    accessor_blocks = []
    seen = set()
    for rel_path in source_files:
        src_name = os.path.join(jittor_path, rel_path)
        with open(src_name) as f:
            content = f.read()
        for _, raw_args in flag_pattern.findall(content):
            fields = raw_args.split(",")
            flag_type = fields[0].strip()
            name = fields[1].strip()
            # Without CUDA, skip cuda-related flags except use_cuda itself.
            if not has_cuda and "cuda" in name and name != "use_cuda":
                continue
            default = fields[2].strip()
            doc = ",".join(fields[3:])
            # NOTE(review): eval() turns the C string literal(s) taken from the
            # scanned source into one Python string; those sources are trusted.
            doc = eval(f"({doc})")
            LOG.vv(f"Find define {name} from {src_name}")
            if name in seen:
                continue
            seen.add(name)
            declare_lines.append(f"DECLARE_FLAG({flag_type}, {name});")
            # int flags additionally get a bool-typed setter overload.
            bool_setter = ""
            if flag_type == "int":
                bool_setter = f'''// @pyjt(__set__{name})
                void _set_{name}(bool v) {{ set_{name}(v); }}
                '''
            accessor_blocks.append(f"""
                /* {name}(type:{flag_type}, default:{default}): {doc} */
                // @pyjt(__get__{name})
                {flag_type} _get_{name}() {{ return {name}; }}
                // @pyjt(__set__{name})
                void _set_{name}({flag_type} v) {{ set_{name}(v); }}
                {bool_setter}
            """)
    jit_declares = "\n ".join(declare_lines)
    members = "".join(accessor_blocks)
    jit_src = f"""
    #include "utils/flags.h"

    namespace jittor {{

    {jit_declares}

    // @pyjt(flags)
    struct _Flags {{
        // @pyjt(__init__)
        _Flags() {{}}
        {members}
    }};

    }} // jittor
    """
    LOG.vvvv(jit_src)
    out_path = os.path.join(cache_path, "gen", "jit_flags.h")
    with open(out_path, 'w') as f:
        f.write(jit_src)
def compile_single(head_file_name, src_file_name, src=None):
    """Run ``compile_src`` on one header and write the generated source.

    head_file_name: path of the header; its file name up to the first "."
        is passed to ``compile_src`` as the base name.
    src_file_name: path the generated code is written to.
    src: header contents; when None the header file is read from disk.
    return: False when ``compile_src`` yields nothing (no file is written),
        True after the generated code has been written.
    """
    # os.path.basename also copes with platform-specific separators,
    # unlike splitting on "/".
    basename = os.path.basename(head_file_name).split(".")[0]
    if src is None:
        with open(head_file_name, 'r') as f:
            src = f.read()
    code = compile_src(src, head_file_name, basename)
    if not code:
        return False
    LOG.vvv("write to", src_file_name)
    LOG.vvvv(code)
    with open(src_file_name, 'w') as f:
        f.write(code)
    return True
def compile(cache_path, jittor_path):
    """Generate pyjt binding sources for all headers and a ``pyjt_all.cc``.

    Every ``*.h`` under ``<jittor_path>/src`` and ``<cache_path>/gen`` is
    handed to ``compile_single``; headers for which it reports success get a
    ``pyjt_<basename>.cc`` in ``<cache_path>/gen``. Finally ``pyjt_all.cc``
    is written, declaring and calling ``pyjt_def_<basename>`` for each.

    return: list of generated ``.cc`` paths, ``pyjt_all.cc`` last.
    """
    header_paths = glob.glob(jittor_path + "/src/**/*.h", recursive=True)
    header_paths = header_paths + glob.glob(cache_path + "/gen/**/*.h", recursive=True)
    gen_dir = os.path.join(cache_path, "gen")
    basenames = []
    pyjt_names = []
    for header in header_paths:
        with open(header, 'r') as f:
            src = f.read()
        header_file = os.path.basename(header)
        # jit_op_maker.h merge compile with var_holder.h
        if header_file == "var_holder.h":
            continue
        if header_file == "jit_op_maker.h":
            var_holder_path = os.path.join(jittor_path, "src", "var_holder.h")
            with open(var_holder_path, "r") as f:
                src = f.read() + src
        basename = header_file.split(".")[0]
        fname = os.path.join(gen_dir, "pyjt_" + basename + ".cc")
        if compile_single(header, fname, src):
            basenames.append(basename)
            pyjt_names.append(fname)

    extern_decls = " ".join(
        [f"extern void pyjt_def_{n}(PyObject* m);" for n in basenames])
    def_calls = " ".join([f"pyjt_def_{n}(m);" for n in basenames])
    code = f"""
    #include "pyjt/numpy.h"
    #include "pyjt/py_converter.h"
    #include "common.h"

    namespace jittor {{

    {extern_decls}

    void pyjt_def_all(PyObject* m) {{
        numpy_init();
        {def_calls}
    }}

    }}
    """
    fname = os.path.join(gen_dir, "pyjt_all.cc")
    LOG.vvv(("write to", fname))
    LOG.vvvv(code)
    with open(fname, "w") as f:
        f.write(code)
    pyjt_names.append(fname)
    return pyjt_names
def compile(cache_path, jittor_path):
    """Generate pyjt binding sources for all headers and a ``pyjt_all.cc``.

    Headers are listed with ``find`` under ``src/`` (run in ``jittor_path``)
    and ``gen/`` (run in ``cache_path``). Each one is handed to
    ``compile_single``; headers for which it reports success get a
    ``pyjt_<basename>.cc`` in ``<cache_path>/gen``. Finally ``pyjt_all.cc``
    is written, declaring and calling ``pyjt_def_<basename>`` for each.

    return: list of generated ``.cc`` paths, ``pyjt_all.cc`` last.
    """
    # The dot is escaped so that only real "*.h" files match; an unescaped
    # ".h$" also matches names such as "build.sh" or "x.h.gch".
    headers1 = run_cmd('find -L src/ | grep "\\.h$"', jittor_path).splitlines()
    headers2 = run_cmd('find gen/ | grep "\\.h$"', cache_path).splitlines()
    headers = [ os.path.join(jittor_path, h) for h in headers1 ] + \
        [ os.path.join(cache_path, h) for h in headers2 ]
    basenames = []
    pyjt_names = []
    for h in headers:
        with open(h, 'r') as f:
            src = f.read()
        # jit_op_maker.h merge compile with var_holder.h
        if h.endswith("src/var_holder.h"):
            continue
        if h.endswith("jit_op_maker.h"):
            with open(os.path.join(jittor_path, "src", "var_holder.h"), "r") as f:
                src = f.read() + src
        basename = os.path.basename(h).split(".")[0]
        fname = "pyjt_" + basename + ".cc"
        fname = os.path.join(cache_path, "gen", fname)
        check = compile_single(h, fname, src)
        if not check:
            continue
        basenames.append(basename)
        pyjt_names.append(fname)

    code = f"""
    #include "pyjt/numpy.h"
    #include "pyjt/py_converter.h"
    #include "common.h"

    namespace jittor {{

    { " ".join([f"extern void pyjt_def_{n}(PyObject* m);" for n in basenames])}

    void pyjt_def_all(PyObject* m) {{
        numpy_init();
        { " ".join([f"pyjt_def_{n}(m);" for n in basenames])}
    }}

    }}
    """
    fname = os.path.join(cache_path, "gen", "pyjt_all.cc")
    LOG.vvv(("write to", fname))
    LOG.vvvv(code)
    with open(fname, "w") as f:
        f.write(code)
    pyjt_names.append(fname)
    return pyjt_names
def gen_jit_tests():
    """Generate ``gen/jit_tests.h`` from the JIT_TEST macros in src/.

    Scans every ``.cc`` file found under ``src/`` (relative to
    ``jittor_path``) for ``JIT_TEST(name)``; each name gets a ``JIT_TEST``
    declaration and a pyjt-annotated ``test_<name>`` wrapper inside a
    ``tests`` namespace. Duplicate test names trigger an assertion. The
    result is written to ``<cache_path>/gen/jit_tests.h``.
    """
    source_files = run_cmd('find -L src/ | grep "cc$"', jittor_path).splitlines()
    test_pattern = re.compile("JIT_TEST\\((.*?)\\)")
    seen_names = set()
    declare_lines = []
    wrapper_blocks = []
    for rel_path in source_files:
        src_name = os.path.join(jittor_path, rel_path)
        with open(src_name) as f:
            content = f.read()
        for name in test_pattern.findall(content):
            LOG.vv(f"Find test {name} from {src_name}")
            assert name not in seen_names, f"Conflict test name {name}"
            seen_names.add(name)
            declare_lines.append(f"JIT_TEST({name});")
            wrapper_blocks.append(f"""
                /* From {src_name} */
                // @pyjt({name})
                static inline void test_{name}() {{ jit_test_{name}(); }}
            """)

    jit_declares = "\n ".join(declare_lines)
    wrappers = "".join(wrapper_blocks)
    jit_src = f"""
    #pragma once
    #include "common.h"

    void expect_error(std::function<void()> func) {{
        try {{ func(); }}
        catch (...) {{ return; }}
        CHECK(0) << "Missing error";
    }}

    namespace jittor {{

    {jit_declares}

    // @pyjt(tests)
    // @attrs(submodule)
    namespace tests {{
        {wrappers}
    }}

    }} // jittor
    """
    LOG.vvvv(jit_src)
    out_path = os.path.join(cache_path, "gen", "jit_tests.h")
    with open(out_path, 'w') as f:
        f.write(jit_src)
def compile_custom_ops(
        filenames,
        extra_flags="",
        return_module=False,
        dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW | os.RTLD_DEEPBIND,
        gen_name_=""):
    """Compile custom ops and import the generated extension module.

    filenames: path of op source files, filenames must be
        pairs of xxx_xxx_op.cc and xxx_xxx_op.h, and the
        type name of op must be XxxXxxOp. Headers containing "@pyjt"
        are additionally compiled with pyjt_compiler and registered
        in the generated module.
    extra_flags: extra compile flags
    return_module: return module rather than ops(default: False)
    dlopen_flags: flags passed to jit_utils.import_scope for the import
    gen_name_: explicit name of the generated module; when empty the
        name is derived from the op header names
    return: compiled ops, or the module when return_module is True
    raises: AssertionError when op sources and headers do not pair up
    """
    # Function-scope imports keep this block independent of the file header.
    import hashlib
    import importlib

    srcs = {}
    headers = {}
    builds = []
    includes = []
    pyjt_includes = []
    for name in filenames:
        name = os.path.realpath(name)
        if name.endswith((".cc", ".cpp", ".cu")):
            builds.append(name)
        if name.endswith(".h"):
            dirname = os.path.dirname(name)
            if dirname.endswith("inc"):
                includes.append(dirname)
            with open(name, "r") as f:
                if "@pyjt" in f.read():
                    pyjt_includes.append(name)
        bname = os.path.splitext(os.path.basename(name))[0]
        if bname.endswith("_op"):
            bname = bname[:-3]
            if name.endswith(".cc"):
                srcs[bname] = name
            elif name.endswith(".h"):
                includes.append(os.path.dirname(name))
                headers[bname] = name
    assert len(srcs) == len(headers), "Source and header names not match"
    for name in srcs:
        assert name in headers, f"Header of op {name} not found"

    gen_name = "gen_ops_" + "_".join(headers.keys())
    if gen_name_ != "":
        gen_name = gen_name_
    if len(gen_name) > 100:
        # The name is used as a file name, a Python module name and a C
        # identifier. The builtin hash() is randomized per process and may be
        # negative (yielding a '-'), so use a stable hex digest instead.
        digest = hashlib.sha256(gen_name.encode()).hexdigest()[:16]
        gen_name = gen_name[:80] + "___hash" + digest

    includes = set(includes)
    includes = "".join(map(lambda x: f" -I'{x}' ", includes))
    LOG.vvvv(f"Include flags:{includes}")

    op_extra_flags = includes + extra_flags
    gen_src = gen_jit_op_maker(headers.values(), export=gen_name,
                               extra_flags=op_extra_flags)
    make_cache_dir(os.path.join(cache_path, "custom_ops"))
    gen_src_fname = os.path.join(cache_path, "custom_ops", gen_name + ".cc")
    gen_head_fname = os.path.join(cache_path, "custom_ops", gen_name + ".h")
    gen_lib = os.path.join("custom_ops", gen_name + extension_suffix)
    pyjt_compiler.compile_single(gen_head_fname, gen_src_fname, src=gen_src)
    # gen src initialize first
    builds.insert(0, gen_src_fname)

    def insert_anchor(gen_src, anchor_str, insert_str):
        # insert insert_str after anchor_str into gen_src
        return gen_src.replace(anchor_str, anchor_str + insert_str, 1)

    for name in pyjt_includes:
        LOG.i("handle pyjt_include", name)
        bname = os.path.basename(name).split(".")[0]
        # A separate variable avoids clobbering gen_src_fname from above.
        pyjt_src_fname = os.path.join(cache_path, "custom_ops",
                                      gen_name + "_" + bname + ".cc")
        pyjt_compiler.compile_single(name, pyjt_src_fname)
        builds.insert(1, pyjt_src_fname)
        # Declare the extra pyjt_def_* function and call it at module init.
        gen_src = insert_anchor(gen_src, "namespace jittor {",
                                f"extern void pyjt_def_{bname}(PyObject* m);")
        gen_src = insert_anchor(
            gen_src, "init_module(PyModuleDef* mdef, PyObject* m) {",
            f"jittor::pyjt_def_{bname}(m);")

    with open(gen_head_fname, "w") as f:
        f.write(gen_src)

    LOG.vvv(f"Build custum ops lib:{gen_lib}")
    LOG.vvvv(f"Build sources:{builds}")
    compile(cc_path, extra_flags + cc_flags + opt_flags + includes, builds,
            gen_lib)

    # add python path and import
    LOG.vvv(f"Import custum ops lib:{gen_lib}")
    lib_path = os.path.join(cache_path, "custom_ops")
    if lib_path not in os.sys.path:
        os.sys.path.append(lib_path)
    # unlock scope when initialize
    with lock.unlock_scope():
        with jit_utils.import_scope(dlopen_flags):
            # import_module returns the module directly; the previous
            # exec("import ...") + locals()[...] lookup stops working on
            # Python 3.13, where exec() no longer feeds later locals() calls.
            mod = importlib.import_module(gen_name)
    if return_module:
        return mod
    return mod.ops
def gen_jit_op_maker(op_headers, export=False, extra_flags=""):
    """Generate the C++ header source that exposes op constructors to Python.

    For every op header, the op constructors ``XxxXxxOp(args)`` are located
    with a regular expression, and for each one a ``make_<op>`` factory, a
    pyjt-annotated wrapper and an ``op_registe`` call are emitted.

    op_headers: iterable of op header paths; each file name must end with
        ``_op`` before the extension (asserted). Files are opened via
        ``os.path.join(jittor_path, header)``.
    export: when falsy the code is placed in namespace ``jit_op_maker``;
        otherwise in ``<export>_maker`` and a module-init section for
        ``<export>`` is appended.
    extra_flags: embedded as a raw string literal in the generated
        ``initer()`` function.
    return: the generated source as a string.
    """
    def add_src(cc_func_name, cc_args, op_name, op_args, src, pybind_name,
                py_args, jit_cc_src, doc_string, attrs):
        # Appends the generated C++ for one constructor to jit_cc_src:
        # a make_* factory, and (unless pybind_name is 'None') the wrapper(s).
        # Names (between the double underscores) that also get an in-place
        # "i" variant and an "r" variant generated below.
        has_ir = set([
            "add", "sub", "mul", "matmul", "truediv", "floordiv", "mod",
            "divmod", "pow", "lshift", "rshift", "and", "xor", "or"
        ])
        pybind_names = [s.strip() for s in pybind_name.split(",")]
        cc_make_args = [arg.replace("VarHolder*", "Var*") for arg in cc_args]
        op_make_args = [arg.replace("->var", "") for arg in op_args]
        py_args = [arg.replace("Var*", "VarHolder*") for arg in py_args]
        # op_args is rebuilt: expressions passed from the wrapper to make_*.
        op_args = []
        cc_args_with_default = []
        for i, arg in enumerate(cc_args):
            # bare parameter name: last word, without any "=default"
            pre_arg = arg.split()[-1].split('=')[0]
            op_arg = None
            if arg.startswith("VarHolder*"):
                op_arg = pre_arg + "->var"
            elif arg.startswith("vector<VarHolder*>"):
                op_arg = f"convert({pre_arg})"
            if "&&" in arg:
                # rvalue-reference parameters are forwarded with move()
                if op_arg == None:
                    op_arg = "move(" + pre_arg + ")"
                op_make_args[i] = "move(" + pre_arg + ")"
            if op_arg == None:
                op_arg = pre_arg
            op_args.append(op_arg)
            py_arg = py_args[i]
            if "_a=" not in py_arg:
                cc_args_with_default.append(arg)
                continue
            # re-attach the default value recorded in py_args
            py_arg = py_arg.split("_a=")[1]
            cc_args_with_default.append(arg + "=" + py_arg)
        cc_args = cc_args_with_default
        # steps of Op creation:
        # 1. new op
        # 2. new output var (create_output in op constructor)
        # 3. take over op's output VarPtr from outputs_holder
        # 4. set op's output
        # 5. set op's input
        # 6. infer shape(op->init())
        if "multiple_outputs" not in attrs:
            jit_cc_src.append(f"""
            VarPtr make_{cc_func_name}({", ".join(cc_make_args)}) {{
                Op* _op = new {op_name}({", ".join(op_make_args)});
                if (_op->outputs_holder.size() != 1) {{
                    delete _op;
                    LOGf << "Wrong output size of" << \"{op_name}\";
                }}
                if (_op->flags.get(NodeFlags::_forwarded)) {{
                    VarPtr output(move(_op->outputs_holder[0]));
                    delete _op;
                    return output;
                }}
                _op->outputs_holder[0]->set_inputs({{_op}});
                VarPtr output(move(_op->outputs_holder[0]));
                {src.replace("->var","")};
                _op->init();
                return output;
            }}
            """)
        else:
            jit_cc_src.append(f"""
            vector<VarPtr> make_{cc_func_name}({", ".join(cc_make_args)}) {{
                Op* _op = new {op_name}({", ".join(op_make_args)});
                if (_op->flags.get(NodeFlags::_forwarded)) {{
                    vector<VarPtr> outputs = move(_op->outputs_holder);
                    delete _op;
                    return outputs;
                }}
                vector<VarPtr> outputs = move(_op->outputs_holder);
                for (uint i=0; i<outputs.size(); i++)
                    outputs[i]->set_inputs({{_op}});
                {src.replace("->var","")};
                _op->init();
                return outputs;
            }}
            """)
        # a bind name of 'None' means: factory only, no Python wrapper
        if pybind_name == 'None':
            return
        pyjt_names = []
        for pybind_name in pybind_names:
            if pybind_name.startswith("__"):
                pyjt_names.append("Var." + pybind_name)
            else:
                pyjt_names.append(pybind_name)
                # also bind as a Var method when the first arg is a VarHolder*
                if len(cc_args) > 0 and cc_args[0].startswith("VarHolder* "):
                    pyjt_names.append("Var." + pybind_name)
        if "multiple_outputs" in attrs:
            jit_cc_src.append(f"""
            /*{doc_string}*/
            // @pyjt({",".join(pyjt_names)})
            vector<VarHolder*> {cc_func_name}({", ".join(cc_args)}) {{
                return make_vh_vector(make_{cc_func_name}({", ".join(op_args)}));
            }}
            """)
        else:
            jit_cc_src.append(f"""
            /*{doc_string}*/
            // @pyjt({",".join(pyjt_names)})
            VarHolder* {cc_func_name}({", ".join(cc_args)}) {{
                return new VarHolder(make_{cc_func_name}({", ".join(op_args)}));
            }}
            """)
        need_ir_define = False
        ir_name = None
        for pybind_name in pybind_names:
            if pybind_name.startswith("__") and pybind_name[2:-2] in has_ir:
                need_ir_define = True
                # at most one bind name may require the i/r variants
                assert ir_name is None
                ir_name = pybind_name[2:-2]
        if need_ir_define:
            assert len(cc_args) > 0 and cc_args[0].startswith("VarHolder* ")
            this = cc_args[0].split()[-1]
            jit_cc_src.append(f"""
            // @pyjt(Var.__i{ir_name}__)
            // @attrs(return_self)
            VarHolder* i{cc_func_name}({", ".join(cc_args)}) {{
                *{this} = make_{cc_func_name}({", ".join(op_args)});
                return {this};
            }}
            """)
            assert len(cc_args) > 1 and cc_args[1].startswith(
                "VarHolder* "), cc_args
            # "r" variant: the first two parameters are swapped
            r_cc_args = [cc_args[1], cc_args[0]] + cc_args[2:]
            # NOTE(review): r_py_args is computed but not used below.
            r_py_args = [py_args[1], py_args[0]] + py_args[2:]
            jit_cc_src.append(f"""
            VarHolder* r{cc_func_name}({", ".join(r_cc_args)}) {{
                return new VarHolder(make_{cc_func_name}({", ".join(op_args)}));
            }}
            """)

    jit_cc_src = []
    jit_headers = ""
    initer = []
    # optional "/* doc */" and optional "// @pybind(names)" before a match
    pybind_reg = '(/\\*(.*?)\\*/\\s*)?(//\\s*@pybind\\(([^\\n]*)\\)\\s*)?'
    # ... followed by an optional "// @attrs(...)" line
    pybind_attrs_reg = pybind_reg + '(//\\s*@attrs\\(([^\\n]*)\\)\\s*)?'
    for header in op_headers:
        # xxx_xxx_op
        name = os.path.basename(header)
        name = os.path.splitext(name)[0]
        # xxx_xxx
        assert name.endswith("_op")
        func_name = name[:-3]
        # XxxXxxOp
        name2 = map(lambda s: s[:1].upper() + s[1:], name.split('_'))
        name2 = "".join(name2)
        with open(os.path.join(jittor_path, header), encoding='utf8') as f:
            src = f.read()
        # XxxXxxOp(args)
        # "[^~]" excludes the destructor ~XxxXxxOp()
        res = re.findall(pybind_attrs_reg + '[^~](' + name2 + "\\([^\\n]*\\))",
                         src, re.S)
        assert len(res) >= 1, "Wrong op args in " + header
        # registe op
        cc_name = os.path.join(jittor_path, header[:-2] + ".cc")
        constructors = []
        # the i-th constructor overload is named make_<func_name> + "_" * i
        for i in range(len(res)):
            name = 'make_' + func_name + '_' * i
            constructors.append(f"{{ &typeid(&{name}), (void*)&{name} }}")
        constructors = ",".join(constructors)
        # collect names declared on lines of the form "Var ... ;"
        var_member_reg = r"\n\s*Var\b(.*);"
        var_member_match = re.findall(var_member_reg, src)
        var_member_match = " ".join(var_member_match)
        for c in "*,":
            var_member_match = var_member_match.replace(c, " ")
        var_member = var_member_match.split()
        LOG.vv("var_member_match " + var_member_match)
        LOG.vv("var_member " + str(var_member))
        var_member_src = [
            f"VAR_MEMBER_NAME_AND_OFFSET({name}, {name2})"
            for name in var_member
        ]
        var_member_src = ",".join(var_member_src)
        initer.append(
            f'\n op_registe({{ "{func_name}", R"({cc_name})", extra_flags, {{{constructors}}}, {{{var_member_src}}} }});'
        )
        for hid, h_def in enumerate(res):
            h_def = list(h_def)
            # // @attrs(...)
            attrs = {}
            if h_def[4] != "":
                attrs = pyjt_compiler.parse_attrs(h_def[5])
            del h_def[4:6]
            # /* doc_string */
            # // @pybind(bind_name)
            # XxxXxxOp(args_def)
            doc_string = h_def[1].strip()
            h_def = h_def[2:]
            # strip "XxxXxxOp(" and the trailing ")"
            args_def = h_def[2][len(name2) + 1:-1]
            bind_name = h_def[1]
            if bind_name == "":
                bind_name = func_name
            if args_def == "":
                args = []
            else:
                args = list(
                    map(lambda s: s.split()[-1].split('=')[0],
                        args_def.split(',')))
            # py_args: "arg"_a=default
            py_args = []
            new_args_def = []
            new_args = []
            # source of convert VarHolder* to Var*
            vh2v_src = []
            more_src = []
            for arg, arg_def in zip(args, args_def.split(',')):
                py_arg = f'"{arg}"_a'
                if '=' in arg_def:
                    py_arg += "=" + arg_def.split('=')[-1]
                    arg_def = arg_def.split('=')[0]
                py_args.append(py_arg)
                # the type is the declaration minus the trailing name
                arg_type = arg_def[:-(len(arg) + 1)].strip()
                if arg_type == "Var*":
                    new_args_def.append("VarHolder* " + arg)
                    vh2v_src.append(arg + "->var")
                    new_args.append(arg + "->var")
                elif arg_type.startswith("vector<Var*>"):
                    new_args_def.append(
                        arg_type.replace("Var", "VarHolder") + ' ' + arg)
                    new_args.append(arg)
                    more_src.append(f"_op->add_inputs({arg});")
                else:
                    new_args_def.append(arg_def)
                    new_args.append(arg)
            vh2v_src = "_op->set_inputs({" + ", ".join(vh2v_src) + "});" + \
                "".join(more_src)
            LOG.vvvv(f"Find op: {name2} args: {new_args}")
            # include paths are written without a leading "src/"
            if header.startswith("src/"):
                jit_headers += f"#include \"{header[4:]}\"\n"
            else:
                jit_headers += f"#include \"{header}\"\n"
            add_src(func_name + '_' * hid, new_args_def, name2, new_args,
                    vh2v_src, bind_name, py_args, jit_cc_src, doc_string,
                    attrs)
            if func_name in ["binary", "unary", "reduce"]:
                # generate binary op alias
                with open(os.path.join(jittor_path,
                                       f"src/ops/{func_name}_op.cc"),
                          encoding="utf-8") as f:
                    src = f.read()
                # keep only the body of the "<func_name>_ops = { ... };" set
                src = src.split(f"unordered_set<string> {func_name}_ops = "
                                "{")[1].split("};")[0]
                res2 = re.findall(pybind_reg + "\"([a-z_A-Z0-9]*)\"", src,
                                  re.S)
                # remove /* doc_string */ pattern
                res2 = [(_[3], _[4]) for _ in res2]
                LOG.vvvv(f"All supported {func_name} ops: {res2}")
                # remove op args
                if func_name == "reduce":
                    args_def = new_args_def[:1] + new_args_def[2:]
                    py_args_s = py_args[:1] + py_args[2:]
                else:
                    args_def = new_args_def[:-1]
                    py_args_s = py_args[:-1]
                # find the last type id(float64)
                # add "_" suffix for all function
                if func_name == "unary":
                    last_tid = res2.index(("", "float64"))
                # for each functor
                for tid, (bind_name, func_name2) in enumerate(res2):
                    # add _ for types
                    if func_name == "unary" and tid <= last_tid:
                        func_name3 = func_name2 + "_"
                    elif func_name == "reduce":
                        func_name4 = func_name2
                        func_name2 = "reduce_" + func_name2
                        func_name3 = func_name2
                    else:
                        func_name3 = func_name2
                    if len(bind_name) == 0:
                        bind_name = func_name2
                    # the removed op argument is replaced by ns_<functor>
                    if func_name == "reduce":
                        args = new_args[:1] + [f'ns_{func_name4}'
                                               ] + new_args[2:]
                    else:
                        args = new_args[:-1] + [f'ns_{func_name2}']
                    add_src(func_name3 + '_' * hid, args_def, name2, args,
                            vh2v_src, bind_name, py_args_s, jit_cc_src,
                            doc_string, attrs)

    jit_src = f"""
    #pragma once
    #include "pyjt/py_obj_holder.h"
    #include "var.h"
    #include "var_holder.h"
    #include "ops/op_register.h"
    {jit_headers}

    namespace jittor {{
    // fix make_array(py::array) undefine reference
    #pragma GCC visibility push(default)
    #define JIT_NAMESPACE {export+"_maker" if export else "jit_op_maker"}
    // @pyjt(ops)
    // @attrs(submodule{",core_name="+export if export else ""})
    namespace JIT_NAMESPACE {{
    {"".join(jit_cc_src)}

    void initer() {{
        string extra_flags = R"({extra_flags})";
        {"".join(initer)}
    }}
    int caller = (initer(), 0);

    }} // JIT_NAMESPACE
    }} // jittor
    {f'''
    namespace jittor {{
    extern void pyjt_def_{export}(PyObject*);
    }}

    static void init_module(PyModuleDef* mdef, PyObject* m) {{
        mdef->m_doc = "User defined custom ops";
        jittor::pyjt_def_{export}(m);
    }}
    PYJF_MODULE_INIT({export});

    ''' if export else ""}
    """
    return jit_src
# nvcc warning is noise nvcc_flags += " -w " nvcc_flags += f" -I'{os.path.join(jittor_path, 'extern/cuda/inc')}' " if os.environ.get("cuda_debug", "0") == "1": nvcc_flags += " -G " return nvcc_flags nvcc_flags = convert_nvcc_flags(nvcc_flags) # build core gen_jit_flags() gen_jit_tests() op_headers = run_cmd('find -L src/ops/ | grep "op.h$"', jittor_path).splitlines() jit_src = gen_jit_op_maker(op_headers) LOG.vvvv(jit_src) with open(os.path.join(cache_path, "gen", "jit_op_maker.h"), 'w') as f: f.write(jit_src) cc_flags += f' -I{cache_path} ' # gen pyjt pyjt_gen_src = pyjt_compiler.compile(cache_path, jittor_path) # initialize order: # 1. registers # 2. generate source # 3. op_utils # 4. other files2 = pyjt_gen_src files4 = run_cmd('find -L src | grep "cc$"', jittor_path).splitlines() at_beginning = [ "src/ops/op_utils.cc",
def run_cmd(cmd):
    """Run a shell command and fail loudly when it exits non-zero.

    cmd: command line passed to ``os.system``.
    raises: AssertionError when the command's status is not 0.
    """
    LOG.vvvv(f"Run cmd: {cmd}")
    # Run the command outside of an ``assert`` statement: asserts are removed
    # under ``python -O``, which would silently skip the command itself.
    status = os.system(cmd)
    if status != 0:
        # AssertionError is kept so existing callers see the same exception.
        raise AssertionError(f"Run cmd failed: {cmd}")