def compile(cache_path, jittor_path):
    """Generate pyjt binding sources for every annotated jittor header.

    Scans the headers under ``<jittor_path>/src`` and ``<cache_path>/gen``,
    runs ``compile_single`` on each of them, and finally writes
    ``gen/pyjt_all.cc``, which declares and calls every generated
    ``pyjt_def_*`` registration function.

    Returns the list of generated ``.cc`` file paths (``pyjt_all.cc`` last).
    """
    src_header_list = run_cmd('find -L src/ | grep ".h$"', jittor_path).splitlines()
    gen_header_list = run_cmd('find gen/ | grep ".h$"', cache_path).splitlines()
    headers = [os.path.join(jittor_path, p) for p in src_header_list]
    headers += [os.path.join(cache_path, p) for p in gen_header_list]

    basenames = []
    pyjt_names = []
    for header in headers:
        with open(header, 'r') as f:
            src = f.read()

        # var_holder.h never gets its own binding file: its source is
        # prepended to jit_op_maker.h below and compiled together with it.
        if header.endswith("src/var_holder.h"):
            continue
        if header.endswith("jit_op_maker.h"):
            with open(os.path.join(jittor_path, "src", "var_holder.h"), "r") as f:
                src = f.read() + src

        stem = header.split("/")[-1].split(".")[0]
        out_cc = os.path.join(cache_path, "gen", "pyjt_" + stem + ".cc")
        # compile_single reports whether this header produced any binding;
        # headers that produced nothing are skipped.
        if not compile_single(header, out_cc, src):
            continue
        basenames.append(stem)
        pyjt_names.append(out_cc)

    # Emit the aggregate registration unit: one extern declaration and one
    # call per generated pyjt_def_* function.
    extern_decls = " ".join([f"extern void pyjt_def_{n}(PyObject* m);" for n in basenames])
    def_calls = " ".join([f"pyjt_def_{n}(m);" for n in basenames])
    code = f"""
#include "pyjt/numpy.h"
#include "pyjt/py_converter.h"
#include "common.h"

namespace jittor {{

{extern_decls}

void pyjt_def_all(PyObject* m) {{
    numpy_init();
    {def_calls}
}}

}}
"""
    all_cc = os.path.join(cache_path, "gen", "pyjt_all.cc")
    LOG.vvv(("write to", all_cc))
    LOG.vvvv(code)
    with open(all_cc, "w") as f:
        f.write(code)
    pyjt_names.append(all_cc)
    return pyjt_names
def generate_error_code_from_func_header(func_head, target_scope_name, name, dfs, basename, h, class_info):
    """Build the C++ ``<<``-stream snippet used in "wrong arguments" errors.

    func_head is a string like:
    (PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject*

    The returned snippet points the user at the relevant ``help(jt....)``
    command and prints the runtime types of the arguments actually passed.
    """
    # "mpi_xxx.h" -> "mpi": library prefix taken from the header file name.
    lib_name = os.path.basename(h).split("_")[0]
    # TODO: fix/add var help
    if target_scope_name == "Var":
        target_scope_name = None
    if not target_scope_name:
        help_name = name
    elif target_scope_name == "flags":
        help_name = "flags"
    else:
        help_name = target_scope_name + '.' + name
    if lib_name in ("mpi", "nccl", "cudnn", "curand", "cublas", "mkl"):
        help_name = lib_name + '.' + help_name
    help_cmd = f"help(jt.{help_name})"

    LOG.vvv("gen err from func_head", func_head)
    # Everything between '(' and ')' of the C++ signature, one entry per arg.
    params = func_head[1:].split(")")[0].split(",")
    fragments = [
        f" << \"Wrong inputs arguments, Please refer to examples(e.g. {help_cmd}).\"",
        r' << "\n\nTypes of your inputs are:\n"',
    ]
    for param in params:
        param = param.strip()
        if param.startswith("PyObject* "):
            _, pname = param.split(' ')
            if pname in ("args", "_args"):
                fragments.append(f" << PyTupleArgPrinter{{{pname}, \"args\"}} ")
            elif pname == "kw":
                fragments.append(f" << PyKwArgPrinter{{{pname}}} ")
            else:
                fragments.append(f" << PyArgPrinter{{{pname}, \"{pname}\"}} ")
        elif param.startswith("PyObject** "):
            # Fast-call signature: one printer consumes the whole arg vector.
            _, pname = param.split(' ')
            fragments.append(f" << PyFastCallArgPrinter{{{pname}, n, kw}} ")
            break
        else:
            LOG.vvv("Unhandled arg", param)
    error_code = "".join(fragments)
    LOG.vvv("gen err from func_head", func_head, " -> ", error_code)
    return error_code
def compile_custom_ops(
    filenames,
    extra_flags="",
    return_module=False,
    dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW | os.RTLD_DEEPBIND,
    gen_name_=""):
    """Compile custom ops
    filenames: path of op source files, filenames must be
        pairs of xxx_xxx_op.cc and xxx_xxx_op.h, and the
        type name of op must be XxxXxxOp.
    extra_flags: extra compile flags
    return_module: return module rather than ops(default: False)
    dlopen_flags: flags passed to jit_utils.import_scope when the
        generated extension module is imported
    gen_name_: if non-empty, overrides the auto-generated module name
    return: compiled ops, or the whole generated module when
        return_module is True
    """
    srcs = {}
    headers = {}
    builds = []
    includes = []
    pyjt_includes = []
    # Classify every input file: sources to build, include dirs, and
    # headers that carry @pyjt annotations (those need python bindings).
    for name in filenames:
        name = os.path.realpath(name)
        if name.endswith(".cc") or name.endswith(".cpp") or name.endswith(".cu"):
            builds.append(name)
        if name.endswith(".h"):
            dirname = os.path.dirname(name)
            # headers living in an "inc" directory are treated as a pure
            # include directory
            if dirname.endswith("inc"):
                includes.append(dirname)
            with open(name, "r") as f:
                if "@pyjt" in f.read():
                    pyjt_includes.append(name)
        bname = os.path.basename(name)
        bname = os.path.splitext(bname)[0]
        if bname.endswith("_op"):
            bname = bname[:-3]
            if name.endswith(".cc"):
                srcs[bname] = name
            elif name.endswith(".h"):
                includes.append(os.path.dirname(name))
                headers[bname] = name
    assert len(srcs) == len(headers), "Source and header names not match"
    for name in srcs:
        assert name in headers, f"Header of op {name} not found"
    gen_name = "gen_ops_" + "_".join(headers.keys())
    if gen_name_ != "":
        gen_name = gen_name_
    if len(gen_name) > 100:
        # keep generated file names short; a hash suffix preserves uniqueness
        gen_name = gen_name[:80] + "___hash" + str(hash(gen_name))

    includes = set(includes)
    includes = "".join(map(lambda x: f" -I'{x}' ", includes))
    LOG.vvvv(f"Include flags:{includes}")
    op_extra_flags = includes + extra_flags

    # Generate the op-maker source and run the pyjt compiler over it.
    gen_src = gen_jit_op_maker(headers.values(), export=gen_name, extra_flags=op_extra_flags)
    make_cache_dir(os.path.join(cache_path, "custom_ops"))
    gen_src_fname = os.path.join(cache_path, "custom_ops", gen_name + ".cc")
    gen_head_fname = os.path.join(cache_path, "custom_ops", gen_name + ".h")
    gen_lib = os.path.join("custom_ops", gen_name + extension_suffix)
    pyjt_compiler.compile_single(gen_head_fname, gen_src_fname, src=gen_src)
    # gen src initialize first
    builds.insert(0, gen_src_fname)

    def insert_anchor(gen_src, anchor_str, insert_str):
        # insert insert_str after anchor_str into gen_src
        return gen_src.replace(anchor_str, anchor_str + insert_str, 1)

    # For every @pyjt-annotated header, generate its binding file and splice
    # an extern declaration + registration call into the generated module.
    for name in pyjt_includes:
        LOG.i("handle pyjt_include", name)
        bname = name.split("/")[-1].split(".")[0]
        gen_src_fname = os.path.join(cache_path, "custom_ops", gen_name + "_" + bname + ".cc")
        pyjt_compiler.compile_single(name, gen_src_fname)
        builds.insert(1, gen_src_fname)
        gen_src = insert_anchor(gen_src,
            "namespace jittor {",
            f"extern void pyjt_def_{bname}(PyObject* m);")
        gen_src = insert_anchor(gen_src,
            "init_module(PyModuleDef* mdef, PyObject* m) {",
            f"jittor::pyjt_def_{bname}(m);")

    with open(gen_head_fname, "w") as f:
        f.write(gen_src)

    LOG.vvv(f"Build custum ops lib:{gen_lib}")
    LOG.vvvv(f"Build sources:{builds}")
    # NOTE(review): this calls the module-level compile() wrapper (cc_path,
    # flags, sources, output) defined elsewhere in the project — confirm.
    compile(cc_path, extra_flags + cc_flags + opt_flags + includes, builds, gen_lib)

    # add python path and import
    LOG.vvv(f"Import custum ops lib:{gen_lib}")
    lib_path = os.path.join(cache_path, "custom_ops")
    if lib_path not in os.sys.path:
        os.sys.path.append(lib_path)
    # unlock scope when initialize
    with lock.unlock_scope():
        with jit_utils.import_scope(dlopen_flags):
            exec(f"import {gen_name}")
    mod = locals()[gen_name]
    if return_module:
        return mod
    return mod.ops
def compile_src(src, h, basename):
    """Generate the C++ source of ``pyjt_def_<basename>`` from one header.

    src: full text of the C++ header file
    h: header path, used verbatim in the generated ``#include`` line
    basename: header file name stem, used to name the registration function

    Scans ``src`` with the module-level regex ``reg`` for ``@pyjt``
    annotations, collects at most one class, at most one submodule, and any
    number of function/method/property definitions, then emits a C++
    translation unit that registers all of them on a python module.
    Returns the generated source string, or None when nothing is annotated.
    """
    res = list(reg.finditer(src, re.S))
    if len(res)==0: return
    # state for the (at most one) annotated class and submodule seen so far
    class_ranges = None
    class_name = None
    class_info = None
    submodule_name = None
    submodule_ranges = None
    submodule_info = None
    defs = []
    LOG.vv(("find in", h))
    for x in res:
        LOG.vvv((x, x.groups()))
        g = x.groups()
        doc = g[1]
        pyjt = g[3]
        attrs = g[5]
        # split a comma-separated annotation body into stripped tokens
        esplit = lambda x: [] if x==None else \
            [ a.strip() for a in x.split(",") if len(a.strip()) ]
        attrs = parse_attrs(attrs)
        pynames = esplit(pyjt)
        end = x.end()
        def find_bc(i):
            # Find the first '(', '{' or ';' at or after i and the position
            # of its matching close; ';' (a member declaration) has nothing
            # to match, so (i, i+1) is returned immediately.
            while src[i] not in "({;":
                i += 1
            j = i+1
            if src[i]==';':
                return i, j
            presum = 1
            while True:
                if src[j] in "({[":
                    presum += 1
                elif src[j] in ")}]":
                    presum -= 1
                if presum==0:
                    s = src[i]+src[j]
                    assert s in ("()","{}","()"), "braces not match "+s
                    return i, j
                j += 1
        # // @pyjt(DType)
        # struct DType {
        #              ^ --> a
        # .....
        # } <--- b
        # or
        # // @pyjt(hash)
        # inline uint hash(const char* input)
        #                 ^ --> a           ^ --> b
        a, b = find_bc(end)
        is_property = 0
        if src[a] == ';':
            # This case
            # class XXX {
            #     // @pyjt(property)
            #     T property;
            # }
            is_property = 1
        if src[a] == '{':
            # a brace right after the annotation: class/struct or submodule
            assert len(pynames)==1
            if "submodule" in attrs:
                assert submodule_ranges==None
                submodule_ranges = (a, b)
                submodule_name = src[end:a-1].strip().split()[-1]
                submodule_info = {
                    "pynames": pynames,
                    "attrs": attrs
                }
                continue
            assert class_ranges==None
            class_ranges = (a, b)
            class_name = src[end:a-1].strip().split()[-1]
            class_info = {
                "pynames": pynames,
                "attrs": attrs
            }
            continue
        # otherwise: a function, method, or property definition
        is_scope_def = False
        is_static = False
        scope_name = ""
        # a definition inside the class/submodule byte range belongs to it
        if class_ranges != None:
            if class_ranges[0] < a and a < class_ranges[1]:
                is_scope_def = True
                scope_name = class_name
        if submodule_ranges != None:
            if submodule_ranges[0] < a and a < submodule_ranges[1]:
                is_scope_def = True
                scope_name = submodule_name
                is_static = True
        dec = src[end:b+1].strip()
        arr = src[end:a].strip().split()
        func_name = arr[-1]
        is_constructor = False
        if is_scope_def and func_name==class_name:
            is_constructor = True
        # parse the argument list into (type, name, default, original type)
        args = []
        for arg in split_args(src[a+1:b]):
            if arg=="": continue
            default = ""
            if "=" in arg:
                arg, default = arg.split('=')
                default = default
            arg = arg.strip()
            name = arg.split(' ')[-1]
            tp = arg[:-len(name)]
            tp = tp.strip()
            prev_tp = tp
            # const string& ----> string
            if tp.startswith("const") and tp.endswith("&"):
                tp = tp[5:-1].strip()
            # T&& -> T
            if tp.endswith("&&"):
                tp = tp[:-2].strip()
            # ArrayArgs& -> ArrayArgs
            if tp.endswith("&"):
                tp = tp[:-1].strip()
            args.append((tp, name.strip(), default.strip(), prev_tp))
        # everything before the function name, minus qualifiers, is the
        # return type (note: this loop reuses the name `a`)
        return_t = ""
        for a in arr[:-1]:
            if a in ["", "inline", "constexpr"]: continue
            if a == "static":
                is_static = True
                continue
            if return_t != "": return_t += " "
            return_t += a
        if is_scope_def and class_info and "submodule" in class_info["attrs"]:
            is_static = True
        # rich-comparison names are all routed through one __richcmp__ slot
        for pid, pyname in enumerate(pynames):
            for rname in [
                "__lt__", "__le__", "__gt__",
                "__ge__", "__eq__", "__ne__"]:
                if pyname.endswith(rname):
                    attrs[rname] = 1
                    pynames[pid] = pyname.replace(rname, "__richcmp__")
        def_info = {
            "is_scope_def": is_scope_def,
            "is_constructor": is_constructor,
            "is_static": is_static,
            "is_property": is_property,
            "func_name": func_name,
            "args": args, # [(type,name,defaut), ...]
            "return_t": return_t, # return type
            "dec": dec, # full string of xxx(A a, B b)
            "pynames": pynames, # names in @pyjt(...)
            "attrs": attrs, # attrs in @attrs(...)
            "doc": doc,
            "scope_name": scope_name,
        }
        if is_property:
            # This case
            # class XXX {
            #     // @pyjt(property)
            #     T property;
            # }
            # a property expands into one getter def and one setter def
            assert is_scope_def and not is_static
            def_info["is_property"] = 1
            def_info["pynames"] = ["__get__"+n for n in pynames]
            assert return_t != "void"
            defs.append(dict(def_info))
            def_info["pynames"] = ["__set__"+n for n in pynames]
            assert len(args) == 0
            def_info["args"] = [(def_info["return_t"], func_name, "", "")]
            def_info["return_t"] = "void"
            defs.append(dict(def_info))
            continue
        else:
            defs.append(def_info)
        LOG.vvv(json.dumps(def_info, indent=4))
    # deal with defs
    if len(defs) == 0: return
    # include_name = h[4:] # remove "src/" prefix
    include_name = h
    code = []
    class_defs_code = []
    class_getsets_code = []
    class_gets = OrderedDict()
    class_sets = OrderedDict()
    class_slots_code = []
    submodule_defs_code = []
    # group overloads by fully qualified python name ("Scope.name" or "name")
    def_targets = OrderedDict()
    for df in defs:
        for name in df["pynames"]:
            if df["is_scope_def"] and '.' not in name:
                if df["scope_name"] == class_name:
                    name = class_info["pynames"][0] + '.' + name
                else:
                    name = submodule_info["pynames"][0] + '.' + name
            if name not in def_targets:
                def_targets[name] = []
            def_targets[name].append(df)
    for name in def_targets:
        dfs = def_targets[name]
        target_scope_name = None
        LOG.vv(name)
        if "." in name:
            target_scope_name, name = name.split(".")
        # array for each df:
        arr_func_quick_check_runable = []
        arr_func_args_convert = []
        arr_fill_with_default = []
        arr_func_call = []
        arr_has_return = []
        self_as_arg0 = False
        for df in dfs:
            # a submodule function exposed under the class name receives
            # self as its first argument
            self_as_arg0 = class_info and \
                target_scope_name == class_info["pynames"][0] and \
                df["scope_name"] == submodule_name \
                and not name.startswith("__")
            res = get_def_code(df, df["scope_name"], name, bool(self_as_arg0))
            arr_func_quick_check_runable.append(res[0])
            arr_func_args_convert.append(res[1])
            arr_fill_with_default.append(res[2])
            arr_func_call.append(res[3])
            arr_has_return.append(res[4])

        # Map special python names onto CPython type slots. Each branch
        # fixes the C function signature (func_head) and the prologue
        # (func_fill) that exposes the call arguments as (args, n).
        slot_name = None
        func_cast = ""
        func_fill = ""
        if name == "__init__":
            slot_name = "tp_init"
            func_head = "(PyObject* self, PyObject* _args, PyObject* kw) -> int"
            func_fill = """
                int64 n = Py_SIZE(_args);
                auto args = (PyObject**)&PyTuple_GET_ITEM(_args, 0);
                (void)n, (void)args;
                // TODO: support kw
                CHECK(kw==0);
            """
        elif name == "__repr__":
            slot_name = "tp_repr"
            func_head = "(PyObject* self) -> PyObject*"
            func_fill = "int64 n = 0; (void)n;"
        elif name.startswith("__get__"):
            slot_name = "tp_gets"
            name = name[len("__get__"):]
            func_head = "(PyObject* self, void*) -> PyObject*"
            func_fill = "int64 n = 0; (void)n;"
        elif name.startswith("__set__"):
            slot_name = "tp_sets"
            name = name[len("__set__"):]
            func_head = "(PyObject* self, PyObject* arg, void*) -> int"
            func_fill = """
                int64 n=1;
                PyObject** args = &arg;
                (void)n, (void)args;
            """
        elif name == "__call__":
            slot_name = "tp_call"
            func_head = "(PyObject* self, PyObject* _args, PyObject* kw) -> PyObject*"
            func_fill = """
                int64 n = Py_SIZE(_args);
                auto args = (PyObject**)&PyTuple_GET_ITEM(_args, 0);
                (void)n, (void)args;
                // TODO: support kw
                CHECK(kw==0);
            """
        elif name == "__dealloc__":
            slot_name = "tp_dealloc"
            func_head = "(PyObject* self) -> void"
            func_fill = "int64 n = 0"
        elif name in binary_number_slots:
            slot_name = "tp_as_number->"+binary_number_slots[name]
            func_head = "(PyObject* self, PyObject* b) -> PyObject*"
            if name.endswith("pow__"):
                # nb_power takes a third (modulo) argument
                func_head = "(PyObject* self, PyObject* b, PyObject*) -> PyObject*"
            func_fill = """
                int64 n = 2;
                PyObject* args[] = {self, b};
                (void)n, (void)args;
            """
        elif name in unary_number_slots:
            slot_name = "tp_as_number->"+unary_number_slots[name]
            func_head = "(PyObject* self) -> PyObject*"
            func_fill = """
                int64 n = 1;
                PyObject* args[] = {self};
                (void)n, (void)args;
            """
        elif name == "__richcmp__":
            slot_name = "tp_richcompare"
            func_head = "(PyObject* self, PyObject* b, int op) -> PyObject*"
            func_fill = """
                int64 n = 2;
                PyObject* args[] = {self, b};
                (void)n, (void)args;
            """
        elif name == "__len__":
            slot_name = "tp_as_sequence->sq_length"
            func_head = "(PyObject* self) -> Py_ssize_t"
            func_fill = """
                int64 n = 0;
                (void)n;
            """
        elif name == "__map_len__":
            slot_name = "tp_as_mapping->mp_length"
            func_head = "(PyObject* self) -> Py_ssize_t"
            func_fill = """
                int64 n = 0;
                (void)n;
            """
        elif name == "__getitem__":
            slot_name = "tp_as_sequence->sq_item"
            func_head = "(PyObject* self, Py_ssize_t arg0) -> PyObject*"
            func_fill = f"""
                int64 n = 1;
                (void)n;
                if (arg0 >= GET_RAW_PTR({dfs[0]["scope_name"]},self)->size()) {{
                    PyErr_SetString(PyExc_IndexError, "");
                    return 0;
                }}
            """
        elif name == "__map_getitem__":
            slot_name = "tp_as_mapping->mp_subscript"
            func_head = "(PyObject* self, PyObject* arg0) -> PyObject*"
            func_fill = f"""
                int64 n = 1;
                PyObject* args[] = {{arg0}};
                (void)n;
            """
        elif name.startswith("__"):
            LOG.f(f"Not support slot {name}")
            continue
        else:
            # ordinary method/function: fastcall calling convention
            func_head = "(PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject*"
            func_cast = f"(PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))"
            # if not return, return py_none
            arr_has_return = [ True for _ in arr_has_return ]

        # build one "return ..." statement per overload, plus the statement
        # executed when no overload matched (func_return_failed)
        arr_func_return = []
        doc_all = ""
        decs = "Declarations:\n"
        for did, has_return in enumerate(arr_has_return):
            df = dfs[did]
            func_call = arr_func_call[did]
            if df["doc"]:
                doc_all += "Document:\n"
                doc_all += df["doc"]
            doc_all += "\nDeclaration:\n"
            doc_all += df["dec"]
            decs += df["dec"]+'\n'
            if has_return:
                assert "-> int" not in func_head
                if "-> PyObject*" in func_head:
                    if "return_self" in df["attrs"]:
                        arr_func_return.append(
                            f"return (({func_call}), Py_INCREF(self), self)")
                    else:
                        arr_func_return.append(
                            f"return {get_pytype_map(df['return_t'],1)}(({func_call}))")
                    func_return_failed = "return nullptr"
                else:
                    arr_func_return.append(
                        f"return ({func_call});")
                    func_return_failed = "return -1"
            else:
                if "-> int" in func_head:
                    arr_func_return.append(f"return ({func_call},0)")
                    func_return_failed = "return -1"
                else:
                    assert "-> void" in func_head
                    arr_func_return.append(f"{func_call};return")
                    func_return_failed = "return"

        # assemble the C++ lambda: try each overload's quick check in order,
        # convert arguments, call, and translate C++ exceptions into a
        # python RuntimeError carrying all candidate declarations
        func = f"""
        {func_cast}[]{func_head} {{
            try {{
                {func_fill};
                uint64 arg_filled=0;
                (void)arg_filled;
                {"".join([f'''
                if ({arr_func_quick_check_runable[did]}) {{
                    {arr_func_args_convert[did]};
                    {arr_fill_with_default[did]};
                    {arr_func_return[did]};
                }}
                ''' for did in range(len(arr_func_return))])}
                LOGf << "Not a valid call";
            }} catch (const std::exception& e) {{
                PyErr_Format(PyExc_RuntimeError, "%s\\n%s",
                    e.what(),
                    R""({decs})""
                );
            }}
            {func_return_failed};
        }}
        """
        if slot_name:
            if slot_name=="tp_gets":
                class_gets[name] = {
                    "func": func,
                    "doc": doc_all
                }
                continue
            if slot_name=="tp_sets":
                class_sets[name] = {
                    "func": func,
                    "doc": ""
                }
                continue
            class_slots_code.append(f"""
        tp.{slot_name} = {func};
        """)
            continue
        # non-slot entry: wrap the lambda in a PyMethodDef initializer
        # (note: df is the last overload from the loop above)
        need_static = ""
        if df["is_scope_def"] and df["is_static"] and \
            df["scope_name"] == class_name and \
            "submodule" not in class_info["attrs"]:
            need_static = " | METH_STATIC"
        func = (f"""
        {{ R""({name})"",
        {func},
        METH_FASTCALL | METH_KEYWORDS{need_static},
        R""({doc_all})""
        }}""")
        if df["is_scope_def"]:
            if df["scope_name"] == class_name or \
                (class_info and \
                    target_scope_name == class_info["pynames"][0]):
                class_defs_code.append(func)
            else:
                submodule_defs_code.append(func)
        else:
            code.append(func)
    # merge getters and setters into PyGetSetDef entries, sorted by name
    prop_names = list(set(class_gets.keys()).union(class_sets.keys()))
    prop_names = sorted(prop_names)
    for prop_name in prop_names:
        get_func = "NULL"
        set_func = "NULL"
        doc = ""
        if prop_name in class_gets:
            get_func = class_gets[prop_name]["func"]
            if class_gets[prop_name]["doc"]:
                doc += class_gets[prop_name]["doc"]
        if prop_name in class_sets:
            set_func = class_sets[prop_name]["func"]
            if class_sets[prop_name]["doc"]:
                doc += class_sets[prop_name]["doc"]
        class_getsets_code.append(f"""
        {{"{prop_name}", {get_func}, {set_func}, R""({doc})""}}
        """)
    # zero-terminate the C arrays
    code.append("{0,0,0,0}")
    class_defs_code.append("{0,0,0,0}")
    class_getsets_code.append("{0,0,0,0}")
    submodule_defs_code.append("{0,0,0,0}")
    core_name = "jittor_core"
    if class_info and "attrs" in class_info and "core_name" in class_info["attrs"]:
        core_name = class_info["attrs"]["core_name"]
    if submodule_info and "attrs" in submodule_info and "core_name" in submodule_info["attrs"]:
        core_name = submodule_info["attrs"]["core_name"]
    # NOTE(review): only these two classes get mapping/sequence protocol
    # tables — hard-coded here rather than derived from the header.
    has_map = class_name in ["VarHolder", "NanoVector"]
    has_seq = class_name == "NanoVector"
    # emit the final translation unit; `code` is rebound from the list of
    # PyMethodDef entries to the full generated source string
    code = f"""
    #include "pyjt/py_converter.h"
    #include "common.h"
    #include "{include_name}"

    namespace jittor {{

    {
        "" if class_name is None else
        f"PyHeapTypeObject Pyjt{class_name};" if "heaptype" in class_info["attrs"] else
        f"PyTypeObject Pyjt{class_name};"
    }

    void pyjt_def_{basename}(PyObject* m) {{
        static PyMethodDef defs[] = {{
            {",".join(code)}
        }};
        ASSERT(PyModule_AddFunctions(m, defs)==0);
        {
        f'''
        static PyMethodDef class_defs[] = {{
            {",".join(class_defs_code)}
        }};
        static PyGetSetDef class_getsets[] = {{
            {",".join(class_getsets_code)}
        }};

        static PyNumberMethods number_methods = {{0}};
        {f"auto& htp =Pyjt{class_name}; auto& tp = htp.ht_type;" if "heaptype" in class_info["attrs"] else f"auto& tp = Pyjt{class_name};"}
        tp.tp_as_number = &number_methods;

        {f"static PyMappingMethods class_map_defs = {{0}};" if has_map else ""}
        {f"tp.tp_as_mapping = &class_map_defs;" if has_map else ""}

        {f"static PySequenceMethods class_seq_defs = {{0}};" if has_seq else ""}
        {f"tp.tp_as_sequence = &class_seq_defs;" if has_seq else ""}

        tp.tp_name = "{core_name}.{class_info["pynames"][0]}";
        tp.tp_basicsize = GET_OBJ_SIZE({class_name});
        tp.tp_new = PyType_GenericNew;
        tp.tp_flags = Py_TPFLAGS_DEFAULT;
        {"tp.tp_flags |= Py_TPFLAGS_HEAPTYPE; htp.ht_name = htp.ht_qualname = to_py_object<string>(tp.tp_name);" if "heaptype" in class_info["attrs"] else ""}
        tp.tp_methods = &class_defs[0];
        tp.tp_getset = &class_getsets[0];
        {"".join(class_slots_code)};
        ASSERT(0==PyType_Ready(&tp)) << (PyErr_Print(), 0);
        Py_INCREF(&tp);
        ASSERT(0==PyModule_AddObject(m, "{class_info["pynames"][0]}", (PyObject*)&tp));
        ''' if class_name is not None else ""
        }
        {f'''
        // sub module def
        static PyMethodDef submodule_defs[] = {{
            {",".join(submodule_defs_code)}
        }};
        auto sub = PyImport_AddModule("{core_name}.{submodule_info["pynames"][0]}");
        ASSERT(PyModule_AddFunctions(sub, submodule_defs)==0);
        ASSERT(sub);
        ASSERT(0==PyModule_AddObject(m, "{submodule_info["pynames"][0]}", sub));
        ''' if submodule_name is not None else ""
        }
    }}

    }}
    """
    return code
def compile_custom_ops(filenames, extra_flags=""):
    """Compile custom ops
    filenames: path of op source files, filenames must be
        pairs of xxx_xxx_op.cc and xxx_xxx_op.h, and the
        type name of op must be XxxXxxOp.
    extra_flags: extra compile flags
    return: compiled ops
    """
    # NOTE(review): a richer compile_custom_ops (with return_module /
    # dlopen_flags / gen_name_ parameters) also appears earlier in this
    # source. If both definitions live in the same module the later one
    # silently wins — confirm these come from different files/versions.
    srcs = {}
    headers = {}
    builds = []
    includes = []
    # classify inputs: sources to compile, and op header/source pairs
    for name in filenames:
        name = os.path.realpath(name)
        if name.endswith(".cc") or name.endswith(".cpp") or name.endswith(".cu"):
            builds.append(name)
        bname = os.path.basename(name)
        bname = os.path.splitext(bname)[0]
        if bname.endswith("_op"):
            bname = bname[:-3]
            if name.endswith(".cc"):
                srcs[bname] = name
            elif name.endswith(".h"):
                includes.append(os.path.dirname(name))
                headers[bname] = name
    assert len(srcs) == len(headers), "Source and header names not match"
    for name in srcs:
        assert name in headers, f"Header of op {name} not found"
    gen_name = "gen_ops_" + "_".join(headers.keys())
    if len(gen_name) > 100:
        # keep generated file names short; hash suffix preserves uniqueness
        gen_name = gen_name[:80] + "___hash" + str(hash(gen_name))

    includes = set(includes)
    includes = "".join(map(lambda x: f" -I'{x}' ", includes))
    LOG.vvvv(f"Include flags:{includes}")
    op_extra_flags = includes + extra_flags

    # generate the op-maker source, write the header, and run the pyjt
    # compiler to produce the binding .cc file
    gen_src = gen_jit_op_maker(headers.values(), export=gen_name, extra_flags=op_extra_flags)
    make_cache_dir(os.path.join(cache_path, "custom_ops"))
    gen_src_fname = os.path.join(cache_path, "custom_ops", gen_name + ".cc")
    gen_head_fname = os.path.join(cache_path, "custom_ops", gen_name + ".h")
    gen_lib = os.path.join("custom_ops", gen_name + extension_suffix)
    with open(gen_head_fname, "w") as f:
        f.write(gen_src)
    pyjt_compiler.compile_single(gen_head_fname, gen_src_fname)

    # gen src initialize first
    builds.insert(0, gen_src_fname)
    LOG.vvv(f"Build custum ops lib:{gen_lib}")
    LOG.vvvv(f"Build sources:{builds}")
    # NOTE(review): this calls the module-level compile() wrapper (cc_path,
    # flags, sources, output) defined elsewhere in the project — confirm.
    compile(cc_path, cc_flags + opt_flags + includes + extra_flags, builds, gen_lib)

    # add python path and import
    LOG.vvv(f"Import custum ops lib:{gen_lib}")
    lib_path = os.path.join(cache_path, "custom_ops")
    if lib_path not in os.sys.path:
        os.sys.path.append(lib_path)
    with jit_utils.import_scope(os.RTLD_GLOBAL | os.RTLD_NOW | os.RTLD_DEEPBIND):
        exec(f"import {gen_name}")
    return (locals()[gen_name]).ops