예제 #1
0
def compile_src(src, h, basename):
    """Generate CPython binding code for the ``@pyjt``-annotated declarations of a header.

    ``src`` is scanned with the module-level pattern ``reg``. Each match is an
    annotation comment (per the examples below: ``// @pyjt(...)``, optionally
    followed by ``// @attrs(...)`` and optionally preceded by a doc comment);
    the groups used are 1 (doc), 3 (pyjt names) and 5 (attrs). The C++
    declaration following each annotation is classified as:

    * a class -- the declaration opens with ``{`` (at most one per header);
    * a submodule -- opens with ``{`` and ``submodule`` is in its attrs
      (at most one per header);
    * a property -- the declaration ends at ``;``; it is expanded into a
      ``__get__``/``__set__`` pair;
    * otherwise a function/method whose argument list and return type are
      parsed from the text.

    Declarations that share a Python name are merged into one overloaded
    wrapper (each overload is tried in order); dunder names are mapped onto
    ``PyTypeObject`` slots, and unsupported dunders are reported via ``LOG.f``.

    Args:
        src: text of the C++ header.
        h: header path; emitted verbatim in the generated ``#include "..."``.
        basename: suffix of the generated ``pyjt_def_<basename>(PyObject* m)``.

    Returns:
        The generated C++ source as a string, or ``None`` when ``src`` contains
        no annotations or no function/property definitions.
    """
    # NOTE(review): `reg` is presumably a compiled pattern, whose finditer()
    # signature is (string, pos, endpos). The former `reg.finditer(src, re.S)`
    # therefore passed re.S (== 16) as the *start position* and silently
    # skipped the first 16 characters; regex flags belong in re.compile().
    res = list(reg.finditer(src))
    if not res: return
    class_ranges = None
    class_name = None
    class_info = None
    submodule_name = None
    submodule_ranges = None
    submodule_info = None
    defs = []
    LOG.vv(("find in", h))

    def esplit(text):
        """Split a comma-separated list, dropping blank items; ``None`` -> ``[]``."""
        if text is None:
            return []
        return [a.strip() for a in text.split(",") if len(a.strip())]

    def find_bc(i):
        """Locate the bracketed span of the declaration that starts at ``src[i]``.

        Skips forward to the first ``(``, ``{`` or ``;``. For ``;`` returns
        ``(i, i + 1)``; otherwise returns the indices of the opening bracket and
        of its matching closing bracket (``(``, ``{`` and ``[`` all nest).
        """
        while src[i] not in "({;":
            i += 1
        j = i+1
        if src[i]==';':
            return i, j
        presum = 1
        while True:
            if src[j] in "({[":
                presum += 1
            elif src[j] in ")}]":
                presum -= 1
                if presum==0:
                    s = src[i]+src[j]
                    # The opener is always '(' or '{', so these are the only valid pairs.
                    assert s in ("()","{}"), "braces not match "+s
                    return i, j
            j += 1

    for x in res:
        LOG.vvv((x, x.groups()))
        g = x.groups()
        doc = g[1]
        pyjt = g[3]
        attrs = g[5]
        attrs = parse_attrs(attrs)
        pynames = esplit(pyjt)
        end = x.end()
        # // @pyjt(DType)
        # struct DType {
        #              ^ --> a
        #     .....
        # } <--- b
        # or
        # // @pyjt(hash)
        # inline uint hash(const char* input)
        #                 ^ --> a           ^ --> b
        a, b = find_bc(end)
        is_property = 0
        if src[a] == ';':
            # This case
            # class XXX {
            #     // @pyjt(property)
            #     T property;
            # }
            is_property = 1
        if src[a] == '{':
            assert len(pynames)==1
            if "submodule" in attrs:
                assert submodule_ranges is None
                submodule_ranges = (a, b)
                submodule_name = src[end:a-1].strip().split()[-1]
                submodule_info = {
                    "pynames": pynames,
                    "attrs": attrs
                }
                continue
            assert class_ranges is None
            class_ranges = (a, b)
            class_name = src[end:a-1].strip().split()[-1]
            class_info = {
                "pynames": pynames,
                "attrs": attrs
            }
            continue
        # A declaration is scoped if its opening bracket lies strictly inside
        # the class or submodule body found earlier; submodule members are static.
        is_scope_def = False
        is_static = False
        scope_name = ""
        if class_ranges is not None:
            if class_ranges[0] < a and a < class_ranges[1]:
                is_scope_def = True
                scope_name = class_name
        if submodule_ranges is not None:
            if submodule_ranges[0] < a and a < submodule_ranges[1]:
                is_scope_def = True
                scope_name = submodule_name
                is_static = True
        dec = src[end:b+1].strip()
        arr = src[end:a].strip().split()
        func_name = arr[-1]

        is_constructor = False
        if is_scope_def and func_name==class_name:
            is_constructor = True

        args = []
        for arg in split_args(src[a+1:b]):
            if arg=="": continue
            default = ""
            if "=" in arg:
                # Split on the first '=' only, so default values that contain
                # '=' themselves (e.g. `x==y`, `"k=v"`) don't break unpacking.
                arg, default = arg.split('=', 1)
            arg = arg.strip()
            name = arg.split(' ')[-1]
            tp = arg[:-len(name)]
            tp = tp.strip()
            prev_tp = tp
            # const string& ----> string
            if tp.startswith("const") and tp.endswith("&"):
                tp = tp[5:-1].strip()
            # T&& -> T
            if tp.endswith("&&"):
                tp = tp[:-2].strip()
            # ArrayArgs& -> ArrayArgs
            if tp.endswith("&"):
                tp = tp[:-1].strip()
            args.append((tp, name.strip(), default.strip(), prev_tp))
        # Everything before the function name, minus specifiers, is the return type.
        return_t = ""
        for tok in arr[:-1]:
            if tok in ["", "inline", "constexpr"]: continue
            if tok == "static":
                is_static = True
                continue
            if return_t != "": return_t += " "
            return_t += tok

        if is_scope_def and class_info and "submodule" in class_info["attrs"]:
            is_static = True

        # Rich comparisons all share the single tp_richcompare slot; remember
        # which operator was requested in attrs.
        for pid, pyname in enumerate(pynames):
            for rname in [ "__lt__", "__le__", "__gt__",
                "__ge__", "__eq__", "__ne__"]:
                if pyname.endswith(rname):
                    attrs[rname] = 1
                    pynames[pid] = pyname.replace(rname, "__richcmp__")

        def_info = {
            "is_scope_def": is_scope_def,
            "is_constructor": is_constructor,
            "is_static": is_static,
            "is_property": is_property,
            "func_name": func_name,
            "args": args, # [(type,name,defaut), ...]
            "return_t": return_t, # return type
            "dec": dec, # full string of xxx(A a, B b)
            "pynames": pynames, # names in @pyjt(...)
            "attrs": attrs, # attrs in @attrs(...)
            "doc": doc,
            "scope_name": scope_name,
        }
        if is_property:
            # This case
            # class XXX {
            #     // @pyjt(property)
            #     T property;
            # }
            assert is_scope_def and not is_static
            def_info["is_property"] = 1
            def_info["pynames"] = ["__get__"+n for n in pynames]
            assert return_t != "void"
            defs.append(dict(def_info))
            def_info["pynames"] = ["__set__"+n for n in pynames]
            assert len(args) == 0
            def_info["args"] = [(def_info["return_t"], func_name, "", "")]
            def_info["return_t"] = "void"
            defs.append(dict(def_info))
            continue
        else:
            defs.append(def_info)
        LOG.vvv(json.dumps(def_info, indent=4))
    # deal with defs
    if len(defs) == 0: return
    # include_name = h[4:] # remove "src/" prefix
    include_name = h
    code = []
    class_defs_code = []
    class_getsets_code = []
    class_gets = OrderedDict()
    class_sets = OrderedDict()
    class_slots_code = []
    submodule_defs_code = []
    # Group definitions by fully-qualified Python name ("Scope.name" for
    # scoped defs); every group becomes one overloaded wrapper.
    def_targets = OrderedDict()
    for df in defs:
        for name in df["pynames"]:
            if df["is_scope_def"] and '.' not in name:
                if df["scope_name"] == class_name:
                    name = class_info["pynames"][0] + '.' + name
                else:
                    name = submodule_info["pynames"][0] + '.' + name
            if name not in def_targets:
                def_targets[name] = []
            def_targets[name].append(df)
    for name in def_targets:
        dfs = def_targets[name]
        target_scope_name = None
        LOG.vv(name)
        if "." in name:
            target_scope_name, name = name.split(".")
        # array for each df:
        arr_func_quick_check_runable = []
        arr_func_args_convert = []
        arr_fill_with_default = []
        arr_func_call = []
        arr_has_return = []
        self_as_arg0 = False
        for df in dfs:
            # A submodule function exposed as a method of the class receives
            # `self` as its first argument (dunder slots excluded).
            self_as_arg0 = class_info and \
                target_scope_name == class_info["pynames"][0] and \
                df["scope_name"] == submodule_name \
                and not name.startswith("__")
            parts = get_def_code(df, df["scope_name"], name, bool(self_as_arg0))
            arr_func_quick_check_runable.append(parts[0])
            arr_func_args_convert.append(parts[1])
            arr_fill_with_default.append(parts[2])
            arr_func_call.append(parts[3])
            arr_has_return.append(parts[4])
            
        slot_name = None
        func_cast = ""
        func_fill = ""
        if name == "__init__":
            slot_name = "tp_init"
            func_head = "(PyObject* self, PyObject* _args, PyObject* kw) -> int"
            func_fill = """
                int64 n = Py_SIZE(_args);
                auto args = (PyObject**)&PyTuple_GET_ITEM(_args, 0);
                (void)n, (void)args;
                // TODO: support kw
                CHECK(kw==0);
            """

        elif name == "__repr__":
            slot_name = "tp_repr"
            func_head = "(PyObject* self) -> PyObject*"
            func_fill = "int64 n = 0; (void)n;"

        elif name.startswith("__get__"):
            slot_name = "tp_gets"
            name = name[len("__get__"):]
            func_head = "(PyObject* self, void*) -> PyObject*"
            func_fill = "int64 n = 0; (void)n;"

        elif name.startswith("__set__"):
            slot_name = "tp_sets"
            name = name[len("__set__"):]
            func_head = "(PyObject* self, PyObject* arg, void*) -> int"
            func_fill = """
                int64 n=1;
                PyObject** args = &arg;
                (void)n, (void)args;
            """

        elif name == "__call__":
            slot_name = "tp_call"
            func_head = "(PyObject* self, PyObject* _args, PyObject* kw) -> PyObject*"
            func_fill = """
                int64 n = Py_SIZE(_args);
                auto args = (PyObject**)&PyTuple_GET_ITEM(_args, 0);
                (void)n, (void)args;
                // TODO: support kw
                CHECK(kw==0);
            """

        elif name == "__dealloc__":
            slot_name = "tp_dealloc"
            func_head = "(PyObject* self) -> void"
            # (void)n silences unused-variable warnings, as in the other slots.
            func_fill = "int64 n = 0; (void)n;"
        
        elif name in binary_number_slots:
            slot_name = "tp_as_number->"+binary_number_slots[name]
            func_head = "(PyObject* self, PyObject* b) -> PyObject*"
            if name.endswith("pow__"):
                func_head = "(PyObject* self, PyObject* b, PyObject*) -> PyObject*"
            func_fill = """
                int64 n = 2;
                PyObject* args[] = {self, b};
                (void)n, (void)args;
            """
        
        elif name in unary_number_slots:
            slot_name = "tp_as_number->"+unary_number_slots[name]
            func_head = "(PyObject* self) -> PyObject*"
            func_fill = """
                int64 n = 1;
                PyObject* args[] = {self};
                (void)n, (void)args;
            """
        
        elif name == "__richcmp__":
            slot_name = "tp_richcompare"
            func_head = "(PyObject* self, PyObject* b, int op) -> PyObject*"
            func_fill = """
                int64 n = 2;
                PyObject* args[] = {self, b};
                (void)n, (void)args;
            """

        elif name == "__len__":
            slot_name = "tp_as_sequence->sq_length"
            func_head = "(PyObject* self) -> Py_ssize_t"
            func_fill = """
                int64 n = 0;  
                (void)n;
            """

        elif name == "__map_len__":
            slot_name = "tp_as_mapping->mp_length"
            func_head = "(PyObject* self) -> Py_ssize_t"
            func_fill = """
                int64 n = 0;  
                (void)n;
            """

        elif name == "__getitem__":
            slot_name = "tp_as_sequence->sq_item"
            func_head = "(PyObject* self, Py_ssize_t arg0) -> PyObject*"
            func_fill = f"""
                int64 n = 1;
                (void)n;
                if (arg0 >= GET_RAW_PTR({dfs[0]["scope_name"]},self)->size()) {{
                    PyErr_SetString(PyExc_IndexError, "");
                    return 0;
                }}
            """

        elif name == "__map_getitem__":
            slot_name = "tp_as_mapping->mp_subscript"
            func_head = "(PyObject* self, PyObject* arg0) -> PyObject*"
            func_fill = f"""
                int64 n = 1;
                PyObject* args[] = {{arg0}};
                (void)n;
            """

        elif name.startswith("__"):
            LOG.f(f"Not support slot {name}")
            continue

        else:
            func_head = "(PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject*"
            func_cast = f"(PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))"
            # if not return, return py_none
            arr_has_return = [ True for _ in arr_has_return ]

        arr_func_return = []
        doc_all = ""
        decs = "Declarations:\n"
        for did, has_return in enumerate(arr_has_return):
            df = dfs[did]
            func_call = arr_func_call[did]
            if df["doc"]:
                doc_all += "Document:\n"
                doc_all += df["doc"]
            doc_all += "\nDeclaration:\n"
            doc_all += df["dec"]
            decs += df["dec"]+'\n'
            if has_return:
                assert "-> int" not in func_head
                if "-> PyObject*" in func_head:
                    if "return_self" in df["attrs"]:
                        arr_func_return.append(
                            f"return (({func_call}), Py_INCREF(self), self)")
                    else:
                        arr_func_return.append(
                            f"return {get_pytype_map(df['return_t'],1)}(({func_call}))")
                    func_return_failed = "return nullptr"
                else:
                    arr_func_return.append(
                        f"return ({func_call});")
                    func_return_failed = "return -1"
            else:
                if "-> int" in func_head:
                    arr_func_return.append(f"return ({func_call},0)")
                    func_return_failed = "return -1"
                else:
                    assert "-> void" in func_head
                    arr_func_return.append(f"{func_call};return")
                    func_return_failed = "return"
        # The wrapper tries each overload in order; the first whose quick
        # check passes converts its arguments and returns.
        func = f"""
        {func_cast}[]{func_head} {{
            try {{
                {func_fill};
                uint64 arg_filled=0;
                (void)arg_filled;
                {"".join([f'''
                if ({arr_func_quick_check_runable[did]}) {{
                    {arr_func_args_convert[did]};
                    {arr_fill_with_default[did]};
                    {arr_func_return[did]};
                }}
                '''
                for did in range(len(arr_func_return))
                ])}
                LOGf << "Not a valid call";
            }} catch (const std::exception& e) {{
                PyErr_Format(PyExc_RuntimeError, "%s\\n%s",
                    e.what(),
                    R""({decs})""
                );
            }}
            {func_return_failed};
        }}
        """

        if slot_name:
            if slot_name=="tp_gets":
                class_gets[name] = {
                    "func": func,
                    "doc": doc_all
                }
                continue
            if slot_name=="tp_sets":
                class_sets[name] = {
                    "func": func,
                    "doc": ""
                }
                continue
            class_slots_code.append(f"""
            tp.{slot_name} = {func};
            """)
            continue
        # NOTE: `df` here is the last definition of this overload group.
        need_static = ""
        if df["is_scope_def"] and df["is_static"] and \
            df["scope_name"] == class_name and \
            "submodule" not in class_info["attrs"]:
            need_static = " | METH_STATIC"
        func = (f"""
        {{ R""({name})"",
        {func},
        METH_FASTCALL | METH_KEYWORDS{need_static},
        R""({doc_all})""
        }}""")
        if df["is_scope_def"]:
            if df["scope_name"] == class_name or \
                (class_info and \
                    target_scope_name == class_info["pynames"][0]):
                class_defs_code.append(func)
            else:
                submodule_defs_code.append(func)
        else:
            code.append(func)
    prop_names = list(set(class_gets.keys()).union(class_sets.keys()))
    prop_names = sorted(prop_names)
    for prop_name in prop_names:
        get_func = "NULL"
        set_func = "NULL"
        doc = ""
        if prop_name in class_gets:
            get_func = class_gets[prop_name]["func"]
            if class_gets[prop_name]["doc"]:
                doc += class_gets[prop_name]["doc"]
        if prop_name in class_sets:
            set_func = class_sets[prop_name]["func"]
            if class_sets[prop_name]["doc"]:
                doc += class_sets[prop_name]["doc"]
        class_getsets_code.append(f"""
            {{"{prop_name}", {get_func}, {set_func}, R""({doc})""}}
        """)
    # Sentinel entries terminating the C method/getset tables.
    code.append("{0,0,0,0}")
    class_defs_code.append("{0,0,0,0}")
    class_getsets_code.append("{0,0,0,0}")
    submodule_defs_code.append("{0,0,0,0}")
    core_name = "jittor_core"
    if class_info and "attrs" in class_info and "core_name" in class_info["attrs"]:
        core_name = class_info["attrs"]["core_name"]
    if submodule_info and "attrs" in submodule_info and "core_name" in submodule_info["attrs"]:
        core_name = submodule_info["attrs"]["core_name"]
    has_map = class_name in ["VarHolder", "NanoVector"]
    has_seq = class_name == "NanoVector"
    # Submodule registration: check `sub` for NULL *before* using it, and
    # INCREF it because PyImport_AddModule returns a borrowed reference while
    # PyModule_AddObject steals one (mirrors Py_INCREF(&tp) for the class).
    code = f"""
    #include "pyjt/py_converter.h"
    #include "common.h"
    #include "{include_name}"

    namespace jittor {{

    {
    "" if class_name is None else
    f"PyHeapTypeObject Pyjt{class_name};" if "heaptype" in class_info["attrs"] else
    f"PyTypeObject Pyjt{class_name};"
    }
    
    void pyjt_def_{basename}(PyObject* m) {{
        static PyMethodDef defs[] = {{
            {",".join(code)}
        }};
        ASSERT(PyModule_AddFunctions(m, defs)==0);
        {
        f'''
        static PyMethodDef class_defs[] = {{
            {",".join(class_defs_code)}
        }};
        static PyGetSetDef class_getsets[] = {{
            {",".join(class_getsets_code)}
        }};

        static PyNumberMethods number_methods = {{0}};
        {f"auto& htp =Pyjt{class_name}; auto& tp = htp.ht_type;"
        if "heaptype" in class_info["attrs"] else
        f"auto& tp = Pyjt{class_name};"}
        tp.tp_as_number = &number_methods;

        {f"static PyMappingMethods class_map_defs = {{0}};" if has_map else ""}
        {f"tp.tp_as_mapping = &class_map_defs;" if has_map else ""}

        {f"static PySequenceMethods class_seq_defs = {{0}};" if has_seq else ""}
        {f"tp.tp_as_sequence = &class_seq_defs;" if has_seq else ""}
        
        tp.tp_name = "{core_name}.{class_info["pynames"][0]}";
        tp.tp_basicsize = GET_OBJ_SIZE({class_name});
        tp.tp_new = PyType_GenericNew;
        tp.tp_flags = Py_TPFLAGS_DEFAULT;
        {"tp.tp_flags |= Py_TPFLAGS_HEAPTYPE; htp.ht_name = htp.ht_qualname = to_py_object<string>(tp.tp_name);"
        if "heaptype" in class_info["attrs"] else ""}
        tp.tp_methods = &class_defs[0];
        tp.tp_getset = &class_getsets[0];
        {"".join(class_slots_code)};
        ASSERT(0==PyType_Ready(&tp)) << (PyErr_Print(), 0);
        Py_INCREF(&tp);
        ASSERT(0==PyModule_AddObject(m, "{class_info["pynames"][0]}", (PyObject*)&tp));
        ''' if class_name is not None else ""
        }
        {f'''

        // sub module def
        static PyMethodDef submodule_defs[] = {{
            {",".join(submodule_defs_code)}
        }};
        auto sub = PyImport_AddModule("{core_name}.{submodule_info["pynames"][0]}");
        ASSERT(sub);
        ASSERT(PyModule_AddFunctions(sub, submodule_defs)==0);
        Py_INCREF(sub);
        ASSERT(0==PyModule_AddObject(m, "{submodule_info["pynames"][0]}", sub));
        ''' if submodule_name is not None else ""
        }

    }}

    }}
    """
    return code
예제 #2
0
def search_file(dirs, name):
    """Return the first ``os.path.join(d, name)`` that is an existing regular file.

    Directories in *dirs* are probed in order and the first hit wins. If no
    directory contains *name*, the failure is reported through ``LOG.f``
    (and the function falls through, returning ``None`` if that call returns).
    """
    candidates = (os.path.join(d, name) for d in dirs)
    hit = next((path for path in candidates if os.path.isfile(path)), None)
    if hit is not None:
        return hit
    LOG.f(f"file {name} not found in {dirs}")