def compile_src(src, h, basename):
    """Generate the pyjt CPython binding source for one C++ header.

    Every ``// @pyjt(name[, name...])`` marker in ``src`` (optionally with a
    doc comment captured by ``reg`` and an ``// @attrs(...)`` line) binds the
    C++ declaration that follows it:

    * a braced body (``struct X { ... }``) becomes the exported class, or the
      exported submodule when ``@attrs(submodule)`` is given -- at most one
      of each per header (enforced by asserts);
    * ``T name;`` becomes a property with a generated getter and setter;
    * anything else is parsed as a function / method declaration.

    Args:
        src: Text of the header file.
        h: Header path, emitted verbatim in the generated ``#include``.
        basename: Suffix of the generated ``pyjt_def_{basename}`` function.

    Returns:
        The generated C++ source as a ``str``, or ``None`` when ``src`` has
        no ``@pyjt`` marker or no bindable definition.
    """
    # ``reg`` is a compiled pattern, so the second positional argument of
    # ``Pattern.finditer`` is a start *position*, not regex flags. The old
    # call passed ``re.S`` (== 16) there, which silently skipped the first
    # 16 characters of the header without setting any flag.
    matches = list(reg.finditer(src))
    if not matches:
        return None

    class_ranges = None
    class_name = None
    class_info = None
    submodule_name = None
    submodule_ranges = None
    submodule_info = None
    defs = []

    def esplit(text):
        # Split a comma-separated @pyjt(...) list, dropping empty entries.
        if text is None:
            return []
        return [part.strip() for part in text.split(",") if len(part.strip())]

    def find_bc(i):
        # Return (a, b): ``a`` is the first '(', '{' or ';' at or after ``i``;
        # ``b`` is the index of its matching closer, or ``a + 1`` for ';'.
        #   // @pyjt(DType)          // @pyjt(hash)
        #   struct DType {           inline uint hash(const char* input)
        #                ^ a                         ^ a                ^ b
        #   ...
        #   }  <- b
        while src[i] not in "({;":
            i += 1
        j = i + 1
        if src[i] == ';':
            return i, j
        depth = 1
        while True:
            if src[j] in "({[":
                depth += 1
            elif src[j] in ")}]":
                depth -= 1
                if depth == 0:
                    s = src[i] + src[j]
                    assert s in ("()", "{}"), "braces not match " + s
                    return i, j
            j += 1

    LOG.vv(("find in", h))
    for x in matches:
        LOG.vvv((x, x.groups()))
        g = x.groups()
        doc = g[1]
        pyjt = g[3]
        attrs = parse_attrs(g[5])
        pynames = esplit(pyjt)
        end = x.end()

        a, b = find_bc(end)
        # class XXX {
        #     // @pyjt(property)
        #     T property;
        # }
        is_property = 1 if src[a] == ';' else 0

        if src[a] == '{':
            # Braced body: the submodule (with @attrs(submodule)) or the class.
            assert len(pynames) == 1
            if "submodule" in attrs:
                assert submodule_ranges is None
                submodule_ranges = (a, b)
                submodule_name = src[end:a-1].strip().split()[-1]
                submodule_info = {"pynames": pynames, "attrs": attrs}
                continue
            assert class_ranges is None
            class_ranges = (a, b)
            class_name = src[end:a-1].strip().split()[-1]
            class_info = {"pynames": pynames, "attrs": attrs}
            continue

        # A declaration is "scoped" when it lies inside the class or submodule
        # body seen earlier in the header.
        is_scope_def = False
        is_static = False
        scope_name = ""
        if class_ranges is not None and class_ranges[0] < a < class_ranges[1]:
            is_scope_def = True
            scope_name = class_name
        if submodule_ranges is not None and \
                submodule_ranges[0] < a < submodule_ranges[1]:
            is_scope_def = True
            scope_name = submodule_name
            is_static = True

        dec = src[end:b+1].strip()
        arr = src[end:a].strip().split()
        func_name = arr[-1]
        is_constructor = is_scope_def and func_name == class_name

        args = []
        for arg in split_args(src[a+1:b]):
            if arg == "":
                continue
            default = ""
            if "=" in arg:
                # Split on the first '=' only: the default value may itself
                # contain '=' (e.g. inside a string literal).
                arg, default = arg.split('=', 1)
            arg = arg.strip()
            name = arg.split(' ')[-1]
            tp = arg[:-len(name)].strip()
            prev_tp = tp
            # const string& ----> string
            if tp.startswith("const") and tp.endswith("&"):
                tp = tp[5:-1].strip()
            # T&& -> T
            if tp.endswith("&&"):
                tp = tp[:-2].strip()
            # ArrayArgs& -> ArrayArgs
            if tp.endswith("&"):
                tp = tp[:-1].strip()
            args.append((tp, name.strip(), default.strip(), prev_tp))

        # Every token before the function name except specifiers forms the
        # return type; a `static` specifier marks the binding static.
        return_t_parts = []
        for tok in arr[:-1]:
            if tok in ("", "inline", "constexpr"):
                continue
            if tok == "static":
                is_static = True
                continue
            return_t_parts.append(tok)
        return_t = " ".join(return_t_parts)

        if is_scope_def and class_info and "submodule" in class_info["attrs"]:
            is_static = True

        # All comparison dunders share the tp_richcompare slot; record which
        # ones were requested in ``attrs``.
        for pid, pyname in enumerate(pynames):
            for rname in ("__lt__", "__le__", "__gt__", "__ge__", "__eq__", "__ne__"):
                if pyname.endswith(rname):
                    attrs[rname] = 1
                    pynames[pid] = pyname.replace(rname, "__richcmp__")

        def_info = {
            "is_scope_def": is_scope_def,
            "is_constructor": is_constructor,
            "is_static": is_static,
            "is_property": is_property,
            "func_name": func_name,
            "args": args,  # [(type, name, default, original type), ...]
            "return_t": return_t,  # return type
            "dec": dec,  # full string of xxx(A a, B b)
            "pynames": pynames,  # names in @pyjt(...)
            "attrs": attrs,  # attrs in @attrs(...)
            "doc": doc,
            "scope_name": scope_name,
        }
        if is_property:
            # A property expands into a getter def and a setter def whose
            # single argument has the property's type.
            assert is_scope_def and not is_static
            def_info["pynames"] = ["__get__" + n for n in pynames]
            assert return_t != "void"
            defs.append(dict(def_info))
            def_info["pynames"] = ["__set__" + n for n in pynames]
            assert len(args) == 0
            def_info["args"] = [(def_info["return_t"], func_name, "", "")]
            def_info["return_t"] = "void"
            defs.append(dict(def_info))
            continue
        defs.append(def_info)
        LOG.vvv(json.dumps(def_info, indent=4))

    if not defs:
        return None

    include_name = h
    func_defs_code = []
    class_defs_code = []
    class_getsets_code = []
    class_gets = OrderedDict()
    class_sets = OrderedDict()
    class_slots_code = []
    submodule_defs_code = []

    # Group overloads by their python-visible target ("Scope.name" for
    # scoped defs); all overloads of a target share one generated function.
    def_targets = OrderedDict()
    for df in defs:
        for name in df["pynames"]:
            if df["is_scope_def"] and '.' not in name:
                if df["scope_name"] == class_name:
                    name = class_info["pynames"][0] + '.' + name
                else:
                    name = submodule_info["pynames"][0] + '.' + name
            def_targets.setdefault(name, []).append(df)

    for target, dfs in def_targets.items():
        LOG.vv(target)
        name = target
        target_scope_name = None
        if "." in name:
            target_scope_name, name = name.split(".")

        # Per-overload code fragments produced by get_def_code.
        arr_func_quick_check_runable = []
        arr_func_args_convert = []
        arr_fill_with_default = []
        arr_func_call = []
        arr_has_return = []
        for df in dfs:
            self_as_arg0 = (class_info
                            and target_scope_name == class_info["pynames"][0]
                            and df["scope_name"] == submodule_name
                            and not name.startswith("__"))
            parts = get_def_code(df, df["scope_name"], name, bool(self_as_arg0))
            arr_func_quick_check_runable.append(parts[0])
            arr_func_args_convert.append(parts[1])
            arr_fill_with_default.append(parts[2])
            arr_func_call.append(parts[3])
            arr_has_return.append(parts[4])

        # Choose the CPython slot, C++ signature and argument prologue.
        slot_name = None
        func_cast = ""
        func_fill = ""
        if name == "__init__":
            slot_name = "tp_init"
            func_head = "(PyObject* self, PyObject* _args, PyObject* kw) -> int"
            func_fill = """
                int64 n = Py_SIZE(_args);
                auto args = (PyObject**)&PyTuple_GET_ITEM(_args, 0);
                (void)n, (void)args;
                // TODO: support kw
                CHECK(kw==0);
            """
        elif name == "__repr__":
            slot_name = "tp_repr"
            func_head = "(PyObject* self) -> PyObject*"
            func_fill = "int64 n = 0; (void)n;"
        elif name.startswith("__get__"):
            slot_name = "tp_gets"
            name = name[len("__get__"):]
            func_head = "(PyObject* self, void*) -> PyObject*"
            func_fill = "int64 n = 0; (void)n;"
        elif name.startswith("__set__"):
            slot_name = "tp_sets"
            name = name[len("__set__"):]
            func_head = "(PyObject* self, PyObject* arg, void*) -> int"
            func_fill = """
                int64 n=1;
                PyObject** args = &arg;
                (void)n, (void)args;
            """
        elif name == "__call__":
            slot_name = "tp_call"
            func_head = "(PyObject* self, PyObject* _args, PyObject* kw) -> PyObject*"
            func_fill = """
                int64 n = Py_SIZE(_args);
                auto args = (PyObject**)&PyTuple_GET_ITEM(_args, 0);
                (void)n, (void)args;
                // TODO: support kw
                CHECK(kw==0);
            """
        elif name == "__dealloc__":
            slot_name = "tp_dealloc"
            func_head = "(PyObject* self) -> void"
            func_fill = "int64 n = 0"
        elif name in binary_number_slots:
            slot_name = "tp_as_number->" + binary_number_slots[name]
            func_head = "(PyObject* self, PyObject* b) -> PyObject*"
            if name.endswith("pow__"):
                # ternaryfunc: the third (modulo) argument is ignored
                func_head = "(PyObject* self, PyObject* b, PyObject*) -> PyObject*"
            func_fill = """
                int64 n = 2;
                PyObject* args[] = {self, b};
                (void)n, (void)args;
            """
        elif name in unary_number_slots:
            slot_name = "tp_as_number->" + unary_number_slots[name]
            func_head = "(PyObject* self) -> PyObject*"
            func_fill = """
                int64 n = 1;
                PyObject* args[] = {self};
                (void)n, (void)args;
            """
        elif name == "__richcmp__":
            slot_name = "tp_richcompare"
            func_head = "(PyObject* self, PyObject* b, int op) -> PyObject*"
            func_fill = """
                int64 n = 2;
                PyObject* args[] = {self, b};
                (void)n, (void)args;
            """
        elif name == "__len__":
            slot_name = "tp_as_sequence->sq_length"
            func_head = "(PyObject* self) -> Py_ssize_t"
            func_fill = """
                int64 n = 0;
                (void)n;
            """
        elif name == "__map_len__":
            slot_name = "tp_as_mapping->mp_length"
            func_head = "(PyObject* self) -> Py_ssize_t"
            func_fill = """
                int64 n = 0;
                (void)n;
            """
        elif name == "__getitem__":
            slot_name = "tp_as_sequence->sq_item"
            func_head = "(PyObject* self, Py_ssize_t arg0) -> PyObject*"
            item_scope = dfs[0]["scope_name"]
            func_fill = f"""
                int64 n = 1;
                (void)n;
                if (arg0 >= GET_RAW_PTR({item_scope},self)->size()) {{
                    PyErr_SetString(PyExc_IndexError, "");
                    return 0;
                }}
            """
        elif name == "__map_getitem__":
            slot_name = "tp_as_mapping->mp_subscript"
            func_head = "(PyObject* self, PyObject* arg0) -> PyObject*"
            func_fill = """
                int64 n = 1;
                PyObject* args[] = {arg0};
                (void)n;
            """
        elif name.startswith("__"):
            LOG.f(f"Not support slot {name}")
            continue
        else:
            func_head = "(PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject*"
            func_cast = "(PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))"
            # if not return, return py_none
            arr_has_return = [True] * len(arr_has_return)

        # Build the per-overload return statements and the docstring.
        arr_func_return = []
        doc_all = ""
        decs = "Declarations:\n"
        for did, has_return in enumerate(arr_has_return):
            df = dfs[did]
            func_call = arr_func_call[did]
            if df["doc"]:
                doc_all += "Document:\n"
                doc_all += df["doc"]
            doc_all += "\nDeclaration:\n"
            doc_all += df["dec"]
            decs += df["dec"] + '\n'
            if has_return:
                assert "-> int" not in func_head
                if "-> PyObject*" in func_head:
                    if "return_self" in df["attrs"]:
                        arr_func_return.append(
                            f"return (({func_call}), Py_INCREF(self), self)")
                    else:
                        arr_func_return.append(
                            f"return {get_pytype_map(df['return_t'],1)}(({func_call}))")
                    func_return_failed = "return nullptr"
                else:
                    arr_func_return.append(f"return ({func_call});")
                    func_return_failed = "return -1"
            elif "-> int" in func_head:
                arr_func_return.append(f"return ({func_call},0)")
                func_return_failed = "return -1"
            else:
                assert "-> void" in func_head
                arr_func_return.append(f"{func_call};return")
                func_return_failed = "return"

        # One guarded branch per overload; the first runnable one wins.
        branches = "".join(
            f"""
                if ({arr_func_quick_check_runable[did]}) {{
                    {arr_func_args_convert[did]};
                    {arr_fill_with_default[did]};
                    {arr_func_return[did]};
                }}
            """
            for did in range(len(arr_func_return))
        )
        func = f"""
        {func_cast}[]{func_head} {{
            try {{
                {func_fill};
                uint64 arg_filled=0;
                (void)arg_filled;
                {branches}
                LOGf << "Not a valid call";
            }} catch (const std::exception& e) {{
                PyErr_Format(PyExc_RuntimeError, "%s\\n%s",
                    e.what(),
                    R""({decs})""
                );
            }}
            {func_return_failed};
        }}
        """

        if slot_name:
            if slot_name == "tp_gets":
                class_gets[name] = {"func": func, "doc": doc_all}
            elif slot_name == "tp_sets":
                class_sets[name] = {"func": func, "doc": ""}
            else:
                class_slots_code.append(f"""
            tp.{slot_name} = {func};
            """)
            continue

        # Registration flags come from the last overload of this target.
        df = dfs[-1]
        need_static = ""
        if df["is_scope_def"] and df["is_static"] and \
                df["scope_name"] == class_name and \
                "submodule" not in class_info["attrs"]:
            need_static = " | METH_STATIC"
        func = f"""
        {{ R""({name})"",
        {func},
        METH_FASTCALL | METH_KEYWORDS{need_static},
        R""({doc_all})""
        }}"""
        if df["is_scope_def"]:
            if df["scope_name"] == class_name or \
                    (class_info and target_scope_name == class_info["pynames"][0]):
                class_defs_code.append(func)
            else:
                submodule_defs_code.append(func)
        else:
            func_defs_code.append(func)

    # Merge getters and setters into one PyGetSetDef entry per property.
    for prop_name in sorted(set(class_gets) | set(class_sets)):
        get_func = "NULL"
        set_func = "NULL"
        doc = ""
        if prop_name in class_gets:
            get_func = class_gets[prop_name]["func"]
            doc += class_gets[prop_name]["doc"]
        if prop_name in class_sets:
            set_func = class_sets[prop_name]["func"]
            doc += class_sets[prop_name]["doc"]
        class_getsets_code.append(f"""
        {{"{prop_name}", {get_func}, {set_func}, R""({doc})""}}
        """)

    # Zero-filled terminators required at the end of each C table.
    sentinel = "{0,0,0,0}"
    func_defs_code.append(sentinel)
    class_defs_code.append(sentinel)
    class_getsets_code.append(sentinel)
    submodule_defs_code.append(sentinel)

    core_name = "jittor_core"
    if class_info and "attrs" in class_info and "core_name" in class_info["attrs"]:
        core_name = class_info["attrs"]["core_name"]
    if submodule_info and "attrs" in submodule_info and "core_name" in submodule_info["attrs"]:
        core_name = submodule_info["attrs"]["core_name"]

    has_map = class_name in ["VarHolder", "NanoVector"]
    has_seq = class_name == "NanoVector"

    if class_name is None:
        class_decl = ""
        class_block = ""
    else:
        class_pyname = class_info["pynames"][0]
        if "heaptype" in class_info["attrs"]:
            class_decl = f"PyHeapTypeObject Pyjt{class_name};"
            tp_ref = f"auto& htp =Pyjt{class_name}; auto& tp = htp.ht_type;"
            heaptype_setup = ("tp.tp_flags |= Py_TPFLAGS_HEAPTYPE; "
                              "htp.ht_name = htp.ht_qualname = "
                              "to_py_object<string>(tp.tp_name);")
        else:
            class_decl = f"PyTypeObject Pyjt{class_name};"
            tp_ref = f"auto& tp = Pyjt{class_name};"
            heaptype_setup = ""
        map_decl = "static PyMappingMethods class_map_defs = {0};" if has_map else ""
        map_assign = "tp.tp_as_mapping = &class_map_defs;" if has_map else ""
        seq_decl = "static PySequenceMethods class_seq_defs = {0};" if has_seq else ""
        seq_assign = "tp.tp_as_sequence = &class_seq_defs;" if has_seq else ""
        class_defs_src = ",".join(class_defs_code)
        class_getsets_src = ",".join(class_getsets_code)
        class_slots_src = "".join(class_slots_code)
        class_block = f"""
    static PyMethodDef class_defs[] = {{
        {class_defs_src}
    }};
    static PyGetSetDef class_getsets[] = {{
        {class_getsets_src}
    }};

    static PyNumberMethods number_methods = {{0}};
    {tp_ref}
    tp.tp_as_number = &number_methods;
    {map_decl}
    {map_assign}
    {seq_decl}
    {seq_assign}

    tp.tp_name = "{core_name}.{class_pyname}";
    tp.tp_basicsize = GET_OBJ_SIZE({class_name});
    tp.tp_new = PyType_GenericNew;
    tp.tp_flags = Py_TPFLAGS_DEFAULT;
    {heaptype_setup}
    tp.tp_methods = &class_defs[0];
    tp.tp_getset = &class_getsets[0];
    {class_slots_src};
    ASSERT(0==PyType_Ready(&tp)) << (PyErr_Print(), 0);
    Py_INCREF(&tp);
    ASSERT(0==PyModule_AddObject(m, "{class_pyname}", (PyObject*)&tp));
    """

    if submodule_name is None:
        submodule_block = ""
    else:
        # Check ``sub`` before using it. PyImport_AddModule returns a
        # borrowed reference while PyModule_AddObject steals one on
        # success, so take a reference first (as done for the class type).
        sub_pyname = submodule_info["pynames"][0]
        submodule_defs_src = ",".join(submodule_defs_code)
        submodule_block = f"""
    // sub module def
    static PyMethodDef submodule_defs[] = {{
        {submodule_defs_src}
    }};
    auto sub = PyImport_AddModule("{core_name}.{sub_pyname}");
    ASSERT(sub);
    ASSERT(PyModule_AddFunctions(sub, submodule_defs)==0);
    Py_INCREF(sub);
    ASSERT(0==PyModule_AddObject(m, "{sub_pyname}", sub));
    """

    func_defs_src = ",".join(func_defs_code)
    code = f"""
#include "pyjt/py_converter.h"
#include "common.h"
#include "{include_name}"

namespace jittor {{

{class_decl}

void pyjt_def_{basename}(PyObject* m) {{
    static PyMethodDef defs[] = {{
        {func_defs_src}
    }};
    ASSERT(PyModule_AddFunctions(m, defs)==0);
    {class_block}
    {submodule_block}
}}

}}
"""
    return code
def search_file(dirs, name):
    """Locate ``name`` in the first directory of ``dirs`` that contains it.

    Directories are probed in the given order. The joined path of the first
    directory holding a regular file called ``name`` is returned. When no
    directory matches, ``LOG.f`` is called with a "not found" message
    (presumably fatal -- confirm against LOG); if that call returns, the
    result is ``None``.
    """
    candidates = (os.path.join(folder, name) for folder in dirs)
    found = next(filter(os.path.isfile, candidates), None)
    if found is not None:
        return found
    LOG.f(f"file {name} not found in {dirs}")