示例#1
0
def compile_extern():
    """Build clang LLVM pass plugins and append them to ``kernel_opt_flags``.

    Does nothing unless ``cc_type == "clang"``. Every file in
    ``<jittor_path>/extern/llvm`` is compiled into a shared object under
    ``<cache_path>/llvm``. Because a plugin can compile yet still fail to
    load into clang, the flag sets in ``try_flags`` are probed in order.
    Each candidate plugin is loaded with ``-Xclang -load`` while compiling a
    trivial ``test_pass.cc``. The first flag set that succeeds is then reused
    for all later files without probing again. Each built plugin is appended
    to the module-global ``kernel_opt_flags``.

    If no flag set yields a loadable plugin, a warning is logged and the
    remaining files are not built.

    Raises:
        AssertionError: if the clang executable is not inside an LLVM
            ``bin`` directory, or the sibling ``include`` directory is missing.
    """
    # compile llvm passes
    if cc_type != "clang":
        return
    global kernel_opt_flags
    cache_path_llvm = os.path.join(cache_path, "llvm")
    jittor_path_llvm = os.path.join(jittor_path, "extern", "llvm")
    clang_dir = os.path.dirname(get_full_path_of_executable(cc_path))
    assert clang_dir.endswith(
        "bin") and "llvm" in clang_dir, f"Wrong clang_dir: {clang_dir}"
    llvm_include = os.path.abspath(os.path.join(clang_dir, "..", "include"))
    assert os.path.isdir(llvm_include), "LLVM include path not found"
    make_cache_dir(cache_path_llvm)
    files = os.listdir(jittor_path_llvm)
    # test_pass.cc is a trivial program used to check that a pass plugin can
    # actually be loaded by clang (link problems only show up at load time).
    test_pass_path = os.path.join(cache_path_llvm, "test_pass.cc")
    with open(test_pass_path, 'w') as f:
        f.write("int main() {return 0;}")

    # -fno-rtti fix link error

    # -Wl,-znodelete fix segfault
    # https://github.com/sampsyo/llvm-pass-skeleton/issues/7#issuecomment-401834287

    # -D_GLIBCXX_USE_CXX11_ABI=0 fix undefined symbol: createPrinterPass
    # https://stackoverflow.com/questions/37366291/undefined-symbol-for-self-built-llvm-opt

    # try different flags, in order of preference
    try_flags = [
        " -Wl,-znodelete -D_GLIBCXX_USE_CXX11_ABI=0 ",
        " -Wl,-znodelete ",
    ]
    # Index into try_flags of the first working flag set; -1 = not found yet.
    found_flags_id = -1
    for fname in files:
        for i, flag in enumerate(try_flags):
            # Once a working flag set is known, only build with that one.
            if found_flags_id != -1 and found_flags_id != i:
                continue
            so_name = os.path.join(cache_path_llvm,
                                   os.path.splitext(fname)[0] + f".{i}.so")
            compile(cc_path,
                    f"{cc_flags} {opt_flags} {flag} -I'{llvm_include}'",
                    [os.path.join(jittor_path_llvm, fname)], so_name)
            # if not found available flags, we test it.
            if found_flags_id == -1:
                try:
                    # Both paths are single-quoted, like so_name and
                    # llvm_include above, so that paths containing spaces
                    # survive if run_cmd goes through a shell (presumably).
                    run_cmd(
                        f"{cc_path} {cc_flags} -Xclang -load -Xclang '{so_name}' '{test_pass_path}'",
                        cache_path_llvm,
                        print_error=False)
                except Exception as e:
                    # Plugin did not load with this flag set; try the next.
                    LOG.v(f"Try flag {flag} failed: {e}")
                    continue
                found_flags_id = i
            kernel_opt_flags += f" -Xclang -load -Xclang '{so_name}' "
            break
        else:
            # No flag set produced a loadable plugin: stop building plugins.
            LOG.w("Clang is used, but LLVM pass plugin is unable to link.")
            break
    LOG.vv(f"Compile extern llvm passes: {str(files)}")
示例#2
0
文件: lock.py 项目: lzhengning/jittor
 def unlock(self):
     """Release the lock held on ``self.handle`` and record that it is free.

     When ``fcntl`` is available the lock is dropped with ``flock(LOCK_UN)``.
     Otherwise the Win32 ``UnlockFileEx`` call is issued on the OS-level
     handle behind ``self.handle``.
     """
     if not fcntl:
         # win32file fallback: operate on the native handle of the file.
         os_handle = win32file._get_osfhandle(self.handle.fileno())
         win32file.UnlockFileEx(os_handle, 0, -0x10000, _OVERLAPPED)
     else:
         fcntl.flock(self.handle, fcntl.LOCK_UN)
     self.is_locked = False
     LOG.vv(f'UNLOCK PID: {os.getpid()}')
示例#3
0
文件: lock.py 项目: lzhengning/jittor
 def lock(self):
     """Take an exclusive lock on ``self.handle`` and record that it is held.

     When ``fcntl`` is available the lock is taken with ``flock(LOCK_EX)``.
     Otherwise the Win32 ``LockFileEx`` call is issued on the OS-level handle
     behind ``self.handle``.
     """
     if not fcntl:
         # win32file fallback: operate on the native handle of the file.
         os_handle = win32file._get_osfhandle(self.handle.fileno())
         # Flag value 2 is LOCKFILE_EXCLUSIVE_LOCK in the Win32 API.
         win32file.LockFileEx(os_handle, 2, 0, -0x10000, _OVERLAPPED)
     else:
         fcntl.flock(self.handle, fcntl.LOCK_EX)
     self.is_locked = True
     LOG.vv(f'LOCK PID: {os.getpid()}')
示例#4
0
def gen_jit_flags():
    """Generate ``<cache_path>/gen/jit_flags.h`` exposing every DEFINE_FLAG.

    Scans every ``.cc`` file under ``<jittor_path>/src`` for
    ``DEFINE_FLAG(type, name, default, doc)`` and
    ``DEFINE_FLAG_WITH_SETTER(...)`` invocations. It writes a C++ header that:

    * declares each flag;
    * wraps the flags in a ``_Flags`` struct with ``@pyjt`` getters and
      setters;
    * adds a ``bool`` setter overload for ``int`` flags.

    Without CUDA, flags whose name contains ``cuda`` (other than
    ``use_cuda``) are skipped. Only the first definition of each flag name
    is emitted.
    """
    all_src = run_cmd('find -L src/ | grep "cc$"', jittor_path).splitlines()
    jit_declares = []
    re_def = re.compile("DEFINE_FLAG(_WITH_SETTER)?\\((.*?)\\);", re.DOTALL)

    flags_defs = []
    # Flag names already emitted; later duplicates are ignored.
    seen = set()

    for src_name in all_src:
        src_name = os.path.join(jittor_path, src_name)
        # Decode as UTF-8 explicitly (as gen_jit_op_maker does) instead of
        # relying on the platform's locale-dependent default encoding.
        with open(src_name, encoding="utf-8") as f:
            src = f.read()
        defs = re_def.findall(src)
        for _, args in defs:
            args = args.split(",")
            flag_type = args[0].strip()
            name = args[1].strip()
            if not has_cuda and "cuda" in name and name != "use_cuda":
                continue
            default = args[2].strip()
            # Everything after the default is the doc string literal; it may
            # itself contain commas, so rejoin before evaluating it.
            # NOTE: eval runs on repository sources, not on external input.
            doc = ",".join(args[3:])
            doc = eval(f"({doc})")
            LOG.vv(f"Find define {name} from {src_name}")
            if name in seen:
                continue
            seen.add(name)
            jit_declares.append(f"DECLARE_FLAG({flag_type}, {name});")
            flags_defs.append(f"""
                /* {name}(type:{flag_type}, default:{default}): {doc} */
                // @pyjt(__get__{name})
                {flag_type} _get_{name}() {{ return {name}; }}
                // @pyjt(__set__{name})
                void _set_{name}({flag_type} v) {{ set_{name}(v); }}
                {f'''// @pyjt(__set__{name})
                void _set_{name}(bool v) {{ set_{name}(v); }}
                ''' if flag_type=="int" else ""}
            """)

    jit_declares = "\n    ".join(jit_declares)
    jit_src = f"""
    #include "utils/flags.h"

    namespace jittor {{
    
    {jit_declares}

    // @pyjt(flags)
    struct _Flags {{
        // @pyjt(__init__)
        _Flags() {{}}
        {"".join(flags_defs)}
    }};

    }} // jittor
    """
    LOG.vvvv(jit_src)
    with open(os.path.join(cache_path, "gen", "jit_flags.h"), 'w') as f:
        f.write(jit_src)
示例#5
0
def gen_jit_tests():
    """Generate ``<cache_path>/gen/jit_tests.h`` exposing every JIT_TEST.

    Scans every ``.cc`` file under ``<jittor_path>/src`` for ``JIT_TEST(name)``
    and emits a header that declares each test. The header wraps each one
    as ``test_<name>()`` inside a ``@pyjt(tests)`` submodule namespace. It
    also defines an ``expect_error`` helper.

    Raises:
        AssertionError: if two sources declare a test with the same name.
    """
    all_src = run_cmd('find -L src/ | grep "cc$"', jittor_path).splitlines()
    jit_declares = []
    re_def = re.compile("JIT_TEST\\((.*?)\\)")
    names = set()
    test_defs = []

    for src_name in all_src:
        src_name = os.path.join(jittor_path, src_name)
        # Decode as UTF-8 explicitly (as gen_jit_op_maker does) instead of
        # relying on the platform's locale-dependent default encoding.
        with open(src_name, encoding="utf-8") as f:
            src = f.read()
        defs = re_def.findall(src)
        for name in defs:
            LOG.vv(f"Find test {name} from {src_name}")
            assert name not in names, f"Conflict test name {name}"
            names.add(name)
            jit_declares.append(f"JIT_TEST({name});")
            test_defs.append(f"""
                /* From {src_name} */
                // @pyjt({name})
                static inline void test_{name}() {{ jit_test_{name}(); }} 
            """)

    jit_declares = "\n    ".join(jit_declares)
    jit_src = f"""
    #pragma once
    #include "common.h"

    void expect_error(std::function<void()> func) {{
        try {{ func(); }}
        catch (...) {{ return; }}
        CHECK(0) << "Missing error";
    }}

    namespace jittor {{
    
    {jit_declares}

    // @pyjt(tests)
    // @attrs(submodule)
    namespace tests {{
        {"".join(test_defs)}
    }}

    }} // jittor
    """
    LOG.vvvv(jit_src)
    with open(os.path.join(cache_path, "gen", "jit_tests.h"), 'w') as f:
        f.write(jit_src)
示例#6
0
def gen_jit_op_maker(op_headers, export=False, extra_flags=""):
    """Generate the C++ source that exposes op constructors to Python.

    Each op header ``xxx_xxx_op.h`` declares class ``XxxXxxOp``. Every
    ``XxxXxxOp(...)`` constructor in it may be preceded by an optional
    ``/* doc */`` comment, a ``// @pybind(names)`` line and a
    ``// @attrs(...)`` line. For each constructor this emits:

    * a ``make_<op>`` function that builds the Op and returns its output
      ``VarPtr`` (``vector<VarPtr>`` for ``multiple_outputs`` ops);
    * a ``@pyjt``-annotated wrapper taking ``VarHolder*`` arguments;
    * an ``op_registe(...)`` call in ``initer()``. This registers the
      constructors plus the names of members declared with type ``Var``.

    The ``binary``, ``unary`` and ``reduce`` ops also get one alias wrapper
    per entry of the ``<func>_ops`` set in ``src/ops/<func>_op.cc``.

    Args:
        op_headers: header paths relative to ``jittor_path``. A leading
            ``src/`` is stripped in the generated ``#include`` lines.
        export: falsy for the built-in ops. Otherwise it is a module name
            (str), used for the namespace ``<export>_maker`` and for an
            appended ``PYJF_MODULE_INIT(<export>)`` block.
        extra_flags: string embedded verbatim as a C++ raw string and passed
            to every generated ``op_registe`` call.

    Returns:
        The generated C++ source as a string.
    """
    def add_src(cc_func_name, cc_args, op_name, op_args, src, pybind_name,
                py_args, jit_cc_src, doc_string, attrs):
        """Append generated C++ for one op constructor to ``jit_cc_src``.

        Always emits ``make_<cc_func_name>``. Unless ``pybind_name`` is
        ``'None'``, it then emits a ``@pyjt`` wrapper bound to each name in
        the comma-separated ``pybind_name``. If one of those names is the
        dunder of an operator listed in ``has_ir``, in-place (``i<name>``)
        and reversed-argument (``r<name>``) variants are emitted as well.

        Args:
            cc_func_name: suffix of the generated C++ function names.
            cc_args: C++ parameter declarations, e.g. ``"VarHolder* x"``.
            op_name: C++ Op class name (``XxxXxxOp``).
            op_args: argument expressions for the Op constructor (e.g.
                ``"x->var"``), one per entry of ``cc_args``.
            src: statement(s) setting the op's inputs. ``"->var"`` is
                stripped before the text is pasted into ``make_*``.
            pybind_name: comma-separated Python names, or ``'None'``.
            py_args: pybind-style ``"name"_a[=default]`` strings, one per
                entry of ``cc_args``. Defaults are copied onto ``cc_args``.
            jit_cc_src: list of generated source chunks; appended to.
            doc_string: text placed in a ``/* */`` comment on the wrapper.
            attrs: parsed ``@attrs``. ``multiple_outputs`` switches to
                ``vector<VarPtr>`` / ``vector<VarHolder*>`` return types.
        """
        # Operator names (without the surrounding "__") whose dunder binding
        # also gets an in-place "__i<op>__" variant below.
        has_ir = set([
            "add", "sub", "mul", "matmul", "truediv", "floordiv", "mod",
            "divmod", "pow", "lshift", "rshift", "and", "xor", "or"
        ])
        pybind_names = [s.strip() for s in pybind_name.split(",")]
        # make_* takes raw Var* while the Python-facing wrapper takes
        # VarHolder*; derive both argument lists from the inputs.
        cc_make_args = [arg.replace("VarHolder*", "Var*") for arg in cc_args]
        op_make_args = [arg.replace("->var", "") for arg in op_args]
        py_args = [arg.replace("Var*", "VarHolder*") for arg in py_args]
        # Rebuilt below: the expressions the wrapper forwards to make_*.
        op_args = []
        cc_args_with_default = []
        for i, arg in enumerate(cc_args):
            # Parameter name = last token of the declaration, minus "=...".
            pre_arg = arg.split()[-1].split('=')[0]
            op_arg = None
            if arg.startswith("VarHolder*"):
                op_arg = pre_arg + "->var"
            elif arg.startswith("vector<VarHolder*>"):
                op_arg = f"convert({pre_arg})"
            # Rvalue-reference parameters are forwarded with move().
            if "&&" in arg:
                if op_arg == None:
                    op_arg = "move(" + pre_arg + ")"
                op_make_args[i] = "move(" + pre_arg + ")"
            if op_arg == None: op_arg = pre_arg
            op_args.append(op_arg)
            # Copy a pybind default ('"x"_a=default') onto the C++ parameter
            # declaration so the wrapper signature carries it too.
            py_arg = py_args[i]
            if "_a=" not in py_arg:
                cc_args_with_default.append(arg)
                continue
            py_arg = py_arg.split("_a=")[1]
            cc_args_with_default.append(arg + "=" + py_arg)
        cc_args = cc_args_with_default
        # steps of Op creation:
        # 1. new op
        # 2. new output var (create_output in op constructor)
        # 3. take over op's output VarPtr from outputs_holder
        # 4. set op's output
        # 5. set op's input
        # 6. infer shape(op->init())
        # A forwarded op (NodeFlags::_forwarded) is deleted right away and its
        # output(s) returned without setting inputs or calling init().
        if "multiple_outputs" not in attrs:
            jit_cc_src.append(f"""
            VarPtr make_{cc_func_name}({", ".join(cc_make_args)}) {{
                Op* _op = new {op_name}({", ".join(op_make_args)});
                if (_op->outputs_holder.size() != 1) {{
                    delete _op;
                    LOGf << "Wrong output size of" << \"{op_name}\";
                }}
                if (_op->flags.get(NodeFlags::_forwarded)) {{
                    VarPtr output(move(_op->outputs_holder[0]));
                    delete _op;
                    return output;
                }}
                _op->outputs_holder[0]->set_inputs({{_op}});
                VarPtr output(move(_op->outputs_holder[0]));
                {src.replace("->var","")};
                _op->init();
                return output;
            }}
            """)
        else:
            jit_cc_src.append(f"""
            vector<VarPtr> make_{cc_func_name}({", ".join(cc_make_args)}) {{
                Op* _op = new {op_name}({", ".join(op_make_args)});
                if (_op->flags.get(NodeFlags::_forwarded)) {{
                    vector<VarPtr> outputs = move(_op->outputs_holder);
                    delete _op;
                    return outputs;
                }}
                vector<VarPtr> outputs = move(_op->outputs_holder);
                for (uint i=0; i<outputs.size(); i++)
                    outputs[i]->set_inputs({{_op}});
                {src.replace("->var","")};
                _op->init();
                return outputs;
            }}
            """)
        # 'None' requests make_* only, with no Python binding.
        if pybind_name == 'None':
            return
        # Dunder names bind as Var methods. Plain names bind as free
        # functions and, when the first parameter is a VarHolder*, also as
        # Var methods.
        pyjt_names = []
        for pybind_name in pybind_names:
            if pybind_name.startswith("__"):
                pyjt_names.append("Var." + pybind_name)
            else:
                pyjt_names.append(pybind_name)
                if len(cc_args) > 0 and cc_args[0].startswith("VarHolder* "):
                    pyjt_names.append("Var." + pybind_name)
        if "multiple_outputs" in attrs:
            jit_cc_src.append(f"""
            /*{doc_string}*/
            // @pyjt({",".join(pyjt_names)})
            vector<VarHolder*> {cc_func_name}({", ".join(cc_args)}) {{
                return make_vh_vector(make_{cc_func_name}({", ".join(op_args)}));
            }}
            """)
        else:
            jit_cc_src.append(f"""
            /*{doc_string}*/
            // @pyjt({",".join(pyjt_names)})
            VarHolder* {cc_func_name}({", ".join(cc_args)}) {{
                return new VarHolder(make_{cc_func_name}({", ".join(op_args)}));
            }}
            """)
        # If a bound dunder is an operator listed in has_ir, also emit an
        # in-place wrapper bound to Var.__i<op>__. It overwrites the first
        # argument with the result and returns it (return_self).
        need_ir_define = False
        ir_name = None
        for pybind_name in pybind_names:
            if pybind_name.startswith("__") and pybind_name[2:-2] in has_ir:
                need_ir_define = True
                assert ir_name is None
                ir_name = pybind_name[2:-2]
        if need_ir_define:
            assert len(cc_args) > 0 and cc_args[0].startswith("VarHolder* ")
            this = cc_args[0].split()[-1]
            jit_cc_src.append(f"""
            // @pyjt(Var.__i{ir_name}__)
            // @attrs(return_self)
            VarHolder* i{cc_func_name}({", ".join(cc_args)}) {{
                *{this} = make_{cc_func_name}({", ".join(op_args)});
                return {this};
            }}
            """)
            assert len(cc_args) > 1 and cc_args[1].startswith(
                "VarHolder* "), cc_args
            # Reversed-operand variant: the first two parameters are swapped
            # in the signature, while op_args are forwarded in original order.
            # NOTE(review): r_py_args is never used and this wrapper carries
            # no @pyjt annotation, so it is presumably not exposed to Python
            # from here. Confirm intent.
            r_cc_args = [cc_args[1], cc_args[0]] + cc_args[2:]
            r_py_args = [py_args[1], py_args[0]] + py_args[2:]
            jit_cc_src.append(f"""
            VarHolder* r{cc_func_name}({", ".join(r_cc_args)}) {{
                return new VarHolder(make_{cc_func_name}({", ".join(op_args)}));
            }}
            """)

    jit_cc_src = []
    jit_headers = ""
    initer = []
    # Optional "/* doc */" comment, then optional "// @pybind(names)" line.
    # Group 2 = doc text, group 4 = pybind names.
    pybind_reg = '(/\\*(.*?)\\*/\\s*)?(//\\s*@pybind\\(([^\\n]*)\\)\\s*)?'
    # ...followed by an optional "// @attrs(...)" line (group 6 = attrs text).
    pybind_attrs_reg = pybind_reg + '(//\\s*@attrs\\(([^\\n]*)\\)\\s*)?'
    for header in op_headers:
        # xxx_xxx_op
        name = os.path.basename(header)
        name = os.path.splitext(name)[0]
        # xxx_xxx
        assert name.endswith("_op")
        func_name = name[:-3]
        # XxxXxxOp
        name2 = map(lambda s: s[:1].upper() + s[1:], name.split('_'))
        name2 = "".join(name2)
        with open(os.path.join(jittor_path, header), encoding='utf8') as f:
            src = f.read()
        # XxxXxxOp(args); the "[^~]" prefix skips the destructor "~XxxXxxOp(".
        res = re.findall(pybind_attrs_reg + '[^~](' + name2 + "\\([^\\n]*\\))",
                         src, re.S)
        assert len(res) >= 1, "Wrong op args in " + header
        # registe op
        cc_name = os.path.join(jittor_path, header[:-2] + ".cc")
        constructors = []
        for i in range(len(res)):
            # Constructor overload i maps to make_<op> followed by i
            # underscores, matching the names add_src generates below.
            # (This rebinds `name`, which is not used afterwards as the header.)
            name = 'make_' + func_name + '_' * i
            constructors.append(f"{{ &typeid(&{name}), (void*)&{name} }}")
        constructors = ",".join(constructors)
        # Collect names of members declared on lines starting with "Var"
        # (e.g. "Var *x, *y;"), treating "*" and "," as separators.
        var_member_reg = r"\n\s*Var\b(.*);"
        var_member_match = re.findall(var_member_reg, src)
        var_member_match = " ".join(var_member_match)
        for c in "*,":
            var_member_match = var_member_match.replace(c, " ")
        var_member = var_member_match.split()
        LOG.vv("var_member_match " + var_member_match)
        LOG.vv("var_member " + str(var_member))
        var_member_src = [
            f"VAR_MEMBER_NAME_AND_OFFSET({name}, {name2})"
            for name in var_member
        ]
        var_member_src = ",".join(var_member_src)
        initer.append(
            f'\n        op_registe({{ "{func_name}", R"({cc_name})", extra_flags, {{{constructors}}}, {{{var_member_src}}} }});'
        )
        for hid, h_def in enumerate(res):
            h_def = list(h_def)
            # // @attrs(...)
            attrs = {}
            if h_def[4] != "":
                attrs = pyjt_compiler.parse_attrs(h_def[5])
            del h_def[4:6]
            # /* doc_string */
            # // @pybind(bind_name)
            # XxxXxxOp(args_def)
            doc_string = h_def[1].strip()
            h_def = h_def[2:]
            # Strip "XxxXxxOp(" and the closing ")" to get the parameter list.
            args_def = h_def[2][len(name2) + 1:-1]
            bind_name = h_def[1]
            if bind_name == "":
                bind_name = func_name
            # Parameter names: last token of each declaration, minus "=...".
            if args_def == "":
                args = []
            else:
                args = list(
                    map(lambda s: s.split()[-1].split('=')[0],
                        args_def.split(',')))
            # py_args: "arg"_a=default
            py_args = []
            new_args_def = []
            new_args = []
            # Var* arguments that become the op's inputs via set_inputs
            vh2v_src = []
            # vector<Var*> arguments, added as inputs via add_inputs
            more_src = []
            for arg, arg_def in zip(args, args_def.split(',')):
                py_arg = f'"{arg}"_a'
                if '=' in arg_def:
                    py_arg += "=" + arg_def.split('=')[-1]
                    arg_def = arg_def.split('=')[0]
                py_args.append(py_arg)
                # Type = declaration with the trailing " <name>" removed.
                arg_type = arg_def[:-(len(arg) + 1)].strip()
                if arg_type == "Var*":
                    new_args_def.append("VarHolder* " + arg)
                    vh2v_src.append(arg + "->var")
                    new_args.append(arg + "->var")
                elif arg_type.startswith("vector<Var*>"):
                    new_args_def.append(
                        arg_type.replace("Var", "VarHolder") + ' ' + arg)
                    new_args.append(arg)
                    more_src.append(f"_op->add_inputs({arg});")
                else:
                    new_args_def.append(arg_def)
                    new_args.append(arg)
            vh2v_src = "_op->set_inputs({" + ", ".join(vh2v_src) + "});" + \
                "".join(more_src)
            LOG.vvvv(f"Find op: {name2} args: {new_args}")
            # NOTE: the include is appended once per constructor overload, so
            # a header with several constructors is included repeatedly.
            if header.startswith("src/"):
                jit_headers += f"#include \"{header[4:]}\"\n"
            else:
                jit_headers += f"#include \"{header}\"\n"
            add_src(func_name + '_' * hid, new_args_def, name2, new_args,
                    vh2v_src, bind_name, py_args, jit_cc_src, doc_string,
                    attrs)
            if func_name in ["binary", "unary", "reduce"]:
                # Generate one alias per functor listed in the
                # "<func>_ops" set of src/ops/<func>_op.cc.
                with open(os.path.join(jittor_path,
                                       f"src/ops/{func_name}_op.cc"),
                          encoding="utf-8") as f:
                    src = f.read()
                src = src.split(f"unordered_set<string> {func_name}_ops = "
                                "{")[1].split("};")[0]
                res2 = re.findall(pybind_reg + "\"([a-z_A-Z0-9]*)\"", src,
                                  re.S)
                # Keep (pybind names, functor name); drop the doc groups.
                res2 = [(_[3], _[4]) for _ in res2]
                LOG.vvvv(f"All supported {func_name} ops: {res2}")
                # Drop the functor-selector argument (index 1 for reduce, the
                # last one otherwise); it is replaced by ns_<functor> below.
                if func_name == "reduce":
                    args_def = new_args_def[:1] + new_args_def[2:]
                    py_args_s = py_args[:1] + py_args[2:]
                else:
                    args_def = new_args_def[:-1]
                    py_args_s = py_args[:-1]
                # Unary functors up to and including "float64" get a trailing
                # "_" on their C++ name (presumably to avoid clashing with the
                # type names). "float64" must appear with no pybind name, or
                # index() raises ValueError.
                if func_name == "unary":
                    last_tid = res2.index(("", "float64"))
                # for each functor
                for tid, (bind_name, func_name2) in enumerate(res2):
                    # add _ for types
                    if func_name == "unary" and tid <= last_tid:
                        func_name3 = func_name2 + "_"
                    elif func_name == "reduce":
                        func_name4 = func_name2
                        func_name2 = "reduce_" + func_name2
                        func_name3 = func_name2
                    else:
                        func_name3 = func_name2
                    # Default Python name is the functor name ("reduce_<op>"
                    # for reduce).
                    if len(bind_name) == 0:
                        bind_name = func_name2
                    if func_name == "reduce":
                        args = new_args[:1] + [f'ns_{func_name4}'
                                               ] + new_args[2:]
                    else:
                        args = new_args[:-1] + [f'ns_{func_name2}']
                    add_src(func_name3 + '_' * hid, args_def, name2, args,
                            vh2v_src, bind_name, py_args_s, jit_cc_src,
                            doc_string, attrs)

    # Assemble the header. When `export` is set, the namespace becomes
    # <export>_maker and a PYJF_MODULE_INIT block for module <export> is
    # appended.
    jit_src = f"""
    #pragma once
    #include "pyjt/py_obj_holder.h"
    #include "var.h"
    #include "var_holder.h"
    #include "ops/op_register.h"
    {jit_headers}
    
    namespace jittor {{
    // fix make_array(py::array) undefine reference
    #pragma GCC visibility push(default)
    #define JIT_NAMESPACE {export+"_maker" if export else "jit_op_maker"}
    // @pyjt(ops)
    // @attrs(submodule{",core_name="+export if export else ""})
    namespace JIT_NAMESPACE {{
    {"".join(jit_cc_src)}

    void initer() {{
        string extra_flags = R"({extra_flags})";
        {"".join(initer)}
    }}
    int caller = (initer(), 0);
    
    }} // JIT_NAMESPACE
    }} // jittor
    {f'''
    namespace jittor {{
    extern void pyjt_def_{export}(PyObject*);
    }}

    static void init_module(PyModuleDef* mdef, PyObject* m) {{
        mdef->m_doc = "User defined custom ops";
        jittor::pyjt_def_{export}(m);
    }}
    PYJF_MODULE_INIT({export});

    ''' if export else ""}
    """
    return jit_src
示例#7
0
    if at_beginning[i] not in files4:
        continue
    files4.remove(at_beginning[i])
    files4.insert(i, at_beginning[i])
# Move each entry of at_last that is present to the end of files4, in
# at_last's order (list is mutated in place).
for v in at_last:
    if v not in files4:
        continue
    files4.remove(v)
    files4.append(v)
# Files whose name contains "register" are pulled out of files4 and placed
# first in the overall compile order, ahead of files2 and the rest of files4.
registers = [name for name in files4 if "register" in name]
for name in registers:
    files4.remove(name)
files = registers + files2 + files4
# Exclude the jit_utils core files from this list. list.remove raises
# ValueError if one of them is not present.
for file in jit_utils_core_files:
    files.remove(file)
LOG.vv("compile order:", files)

# Manually load the OpenMP runtime for the active compiler with
# RTLD_NOW | RTLD_GLOBAL, so its symbols are resolved immediately and made
# available to libraries loaded afterwards.
# if cc_type=="icc":
#     os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# NOTE(review): a cc_type other than clang/icc/g++ raises KeyError here.
libname = {"clang": "omp", "icc": "iomp5", "g++": "gomp"}[cc_type]
libname = ctypes.util.find_library(libname)
assert libname is not None, "openmp library not found"
ctypes.CDLL(libname, os.RTLD_NOW | os.RTLD_GLOBAL)

# If a version file ships with jittor, build a key of the form
# "<version>-<cc_type>-<cuda|cpu>.o". It is not consumed within the lines
# visible here; presumably the code that follows uses it (see the TODO).
version_file = os.path.join(jittor_path, "version")
if os.path.isfile(version_file):
    with open(version_file, 'r') as f:
        version = f.read().strip()
    key = f"{version}-{cc_type}-{'cuda' if has_cuda else 'cpu'}.o"
    # TODO: open the website
示例#8
0
def compile_src(src, h, basename):
    res = list(reg.finditer(src, re.S))
    if len(res)==0: return
    class_ranges = None
    class_name = None
    class_info = None
    submodule_name = None
    submodule_ranges = None
    submodule_info = None
    defs = []
    LOG.vv(("find in", h))
    for x in res:
        LOG.vvv((x, x.groups()))
        g = x.groups()
        doc = g[1]
        pyjt = g[3]
        attrs = g[5]
        esplit = lambda x: [] if x==None else \
            [ a.strip() for a in x.split(",") if len(a.strip()) ]
        attrs = parse_attrs(attrs)
        pynames = esplit(pyjt)
        end = x.end()
        def find_bc(i):
            while src[i] not in "({;":
                i += 1
            j = i+1
            if src[i]==';':
                return i, j
            presum = 1
            while True:
                if src[j] in "({[":
                    presum += 1
                elif src[j] in ")}]":
                    presum -= 1
                    if presum==0:
                        s = src[i]+src[j]
                        assert s in ("()","{}","()"), "braces not match "+s
                        return i, j
                j += 1
        # // @pyjt(DType)
        # struct DType {
        #              ^ --> a
        #     .....
        # } <--- b
        # or
        # // @pyjt(hash)
        # inline uint hash(const char* input)
        #                 ^ --> a           ^ --> b
        a, b = find_bc(end)
        is_property = 0
        if src[a] == ';':
            # This case
            # class XXX {
            #     // @pyjt(property)
            #     T property;
            # }
            is_property = 1
        if src[a] == '{':
            assert len(pynames)==1
            if "submodule" in attrs:
                assert submodule_ranges==None
                submodule_ranges = (a, b)
                submodule_name = src[end:a-1].strip().split()[-1]
                submodule_info = {
                    "pynames": pynames,
                    "attrs": attrs
                }
                continue
            assert class_ranges==None
            class_ranges = (a, b)
            class_name = src[end:a-1].strip().split()[-1]
            class_info = {
                "pynames": pynames,
                "attrs": attrs
            }
            continue
        is_scope_def = False
        is_static = False
        scope_name = ""
        if class_ranges != None:
            if class_ranges[0] < a and a < class_ranges[1]:
                is_scope_def = True
                scope_name = class_name
        if submodule_ranges != None:
            if submodule_ranges[0] < a and a < submodule_ranges[1]:
                is_scope_def = True
                scope_name = submodule_name
                is_static = True
        dec = src[end:b+1].strip()
        arr = src[end:a].strip().split()
        func_name = arr[-1]

        is_constructor = False
        if is_scope_def and func_name==class_name:
            is_constructor = True

        args = []
        for arg in split_args(src[a+1:b]):
            if arg=="": continue
            default = ""
            if "=" in arg:
                arg, default = arg.split('=')
                default = default
            arg = arg.strip()
            name = arg.split(' ')[-1]
            tp = arg[:-len(name)]
            tp = tp.strip()
            prev_tp = tp
            # const string& ----> string
            if tp.startswith("const") and tp.endswith("&"):
                tp = tp[5:-1].strip()
            # T&& -> T
            if tp.endswith("&&"):
                tp = tp[:-2].strip()
            # ArrayArgs& -> ArrayArgs
            if tp.endswith("&"):
                tp = tp[:-1].strip()
            args.append((tp, name.strip(), default.strip(), prev_tp))
        return_t = ""
        for a in arr[:-1]:
            if a in ["", "inline", "constexpr"]: continue
            if a == "static":
                is_static = True
                continue
            if return_t != "": return_t += " "
            return_t += a

        if is_scope_def and class_info and "submodule" in class_info["attrs"]:
            is_static = True

        for pid, pyname in enumerate(pynames):
            for rname in [ "__lt__", "__le__", "__gt__",
                "__ge__", "__eq__", "__ne__"]:
                if pyname.endswith(rname):
                    attrs[rname] = 1
                    pynames[pid] = pyname.replace(rname, "__richcmp__")

        def_info = {
            "is_scope_def": is_scope_def,
            "is_constructor": is_constructor,
            "is_static": is_static,
            "is_property": is_property,
            "func_name": func_name,
            "args": args, # [(type,name,defaut), ...]
            "return_t": return_t, # return type
            "dec": dec, # full string of xxx(A a, B b)
            "pynames": pynames, # names in @pyjt(...)
            "attrs": attrs, # attrs in @attrs(...)
            "doc": doc,
            "scope_name": scope_name,
        }
        if is_property:
            # This case
            # class XXX {
            #     // @pyjt(property)
            #     T property;
            # }
            assert is_scope_def and not is_static
            def_info["is_property"] = 1
            def_info["pynames"] = ["__get__"+n for n in pynames]
            assert return_t != "void"
            defs.append(dict(def_info))
            def_info["pynames"] = ["__set__"+n for n in pynames]
            assert len(args) == 0
            def_info["args"] = [(def_info["return_t"], func_name, "", "")]
            def_info["return_t"] = "void"
            defs.append(dict(def_info))
            continue
        else:
            defs.append(def_info)
        LOG.vvv(json.dumps(def_info, indent=4))
    # deal with defs
    if len(defs) == 0: return
    # include_name = h[4:] # remove "src/" prefix
    include_name = h
    code = []
    class_defs_code = []
    class_getsets_code = []
    class_gets = OrderedDict()
    class_sets = OrderedDict()
    class_slots_code = []
    submodule_defs_code = []
    def_targets = OrderedDict()
    for df in defs:
        for name in df["pynames"]:
            if df["is_scope_def"] and '.' not in name:
                if df["scope_name"] == class_name:
                    name = class_info["pynames"][0] + '.' + name
                else:
                    name = submodule_info["pynames"][0] + '.' + name
            if name not in def_targets:
                def_targets[name] = []
            def_targets[name].append(df)
    for name in def_targets:
        dfs = def_targets[name]
        target_scope_name = None
        LOG.vv(name)
        if "." in name:
            target_scope_name, name = name.split(".")
        # array for each df:
        arr_func_quick_check_runable = []
        arr_func_args_convert = []
        arr_fill_with_default = []
        arr_func_call = []
        arr_has_return = []
        self_as_arg0 = False
        for df in dfs:
            self_as_arg0 = class_info and \
                target_scope_name == class_info["pynames"][0] and \
                df["scope_name"] == submodule_name \
                and not name.startswith("__")
            res = get_def_code(df, df["scope_name"], name, bool(self_as_arg0))
            arr_func_quick_check_runable.append(res[0])
            arr_func_args_convert.append(res[1])
            arr_fill_with_default.append(res[2])
            arr_func_call.append(res[3])
            arr_has_return.append(res[4])
            
        slot_name = None
        func_cast = ""
        func_fill = ""
        if name == "__init__":
            slot_name = "tp_init"
            func_head = "(PyObject* self, PyObject* _args, PyObject* kw) -> int"
            func_fill = """
                int64 n = Py_SIZE(_args);
                auto args = (PyObject**)&PyTuple_GET_ITEM(_args, 0);
                (void)n, (void)args;
                // TODO: support kw
                CHECK(kw==0);
            """

        elif name == "__repr__":
            slot_name = "tp_repr"
            func_head = "(PyObject* self) -> PyObject*"
            func_fill = "int64 n = 0; (void)n;"

        elif name.startswith("__get__"):
            slot_name = "tp_gets"
            name = name[len("__get__"):]
            func_head = "(PyObject* self, void*) -> PyObject*"
            func_fill = "int64 n = 0; (void)n;"

        elif name.startswith("__set__"):
            slot_name = "tp_sets"
            name = name[len("__set__"):]
            func_head = "(PyObject* self, PyObject* arg, void*) -> int"
            func_fill = """
                int64 n=1;
                PyObject** args = &arg;
                (void)n, (void)args;
            """

        elif name == "__call__":
            slot_name = "tp_call"
            func_head = "(PyObject* self, PyObject* _args, PyObject* kw) -> PyObject*"
            func_fill = """
                int64 n = Py_SIZE(_args);
                auto args = (PyObject**)&PyTuple_GET_ITEM(_args, 0);
                (void)n, (void)args;
                // TODO: support kw
                CHECK(kw==0);
            """

        elif name == "__dealloc__":
            slot_name = "tp_dealloc"
            func_head = "(PyObject* self) -> void"
            func_fill = "int64 n = 0"
        
        elif name in binary_number_slots:
            slot_name = "tp_as_number->"+binary_number_slots[name]
            func_head = "(PyObject* self, PyObject* b) -> PyObject*"
            if name.endswith("pow__"):
                func_head = "(PyObject* self, PyObject* b, PyObject*) -> PyObject*"
            func_fill = """
                int64 n = 2;
                PyObject* args[] = {self, b};
                (void)n, (void)args;
            """
        
        elif name in unary_number_slots:
            slot_name = "tp_as_number->"+unary_number_slots[name]
            func_head = "(PyObject* self) -> PyObject*"
            func_fill = """
                int64 n = 1;
                PyObject* args[] = {self};
                (void)n, (void)args;
            """
        
        elif name == "__richcmp__":
            slot_name = "tp_richcompare"
            func_head = "(PyObject* self, PyObject* b, int op) -> PyObject*"
            func_fill = """
                int64 n = 2;
                PyObject* args[] = {self, b};
                (void)n, (void)args;
            """

        elif name == "__len__":
            slot_name = "tp_as_sequence->sq_length"
            func_head = "(PyObject* self) -> Py_ssize_t"
            func_fill = """
                int64 n = 0;  
                (void)n;
            """

        elif name == "__map_len__":
            slot_name = "tp_as_mapping->mp_length"
            func_head = "(PyObject* self) -> Py_ssize_t"
            func_fill = """
                int64 n = 0;  
                (void)n;
            """

        elif name == "__getitem__":
            slot_name = "tp_as_sequence->sq_item"
            func_head = "(PyObject* self, Py_ssize_t arg0) -> PyObject*"
            func_fill = f"""
                int64 n = 1;
                (void)n;
                if (arg0 >= GET_RAW_PTR({dfs[0]["scope_name"]},self)->size()) {{
                    PyErr_SetString(PyExc_IndexError, "");
                    return 0;
                }}
            """

        elif name == "__map_getitem__":
            slot_name = "tp_as_mapping->mp_subscript"
            func_head = "(PyObject* self, PyObject* arg0) -> PyObject*"
            func_fill = f"""
                int64 n = 1;
                PyObject* args[] = {{arg0}};
                (void)n;
            """

        elif name.startswith("__"):
            LOG.f(f"Not support slot {name}")
            continue

        else:
            func_head = "(PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject*"
            func_cast = f"(PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))"
            # if not return, return py_none
            arr_has_return = [ True for _ in arr_has_return ]

        arr_func_return = []
        doc_all = ""
        decs = "Declarations:\n"
        for did, has_return in enumerate(arr_has_return):
            df = dfs[did]
            func_call = arr_func_call[did]
            if df["doc"]:
                doc_all += "Document:\n"
                doc_all += df["doc"]
            doc_all += "\nDeclaration:\n"
            doc_all += df["dec"]
            decs += df["dec"]+'\n'
            if has_return:
                assert "-> int" not in func_head
                if "-> PyObject*" in func_head:
                    if "return_self" in df["attrs"]:
                        arr_func_return.append(
                            f"return (({func_call}), Py_INCREF(self), self)")
                    else:
                        arr_func_return.append(
                            f"return {get_pytype_map(df['return_t'],1)}(({func_call}))")
                    func_return_failed = "return nullptr"
                else:
                    arr_func_return.append(
                        f"return ({func_call});")
                    func_return_failed = "return -1"
            else:
                if "-> int" in func_head:
                    arr_func_return.append(f"return ({func_call},0)")
                    func_return_failed = "return -1"
                else:
                    assert "-> void" in func_head
                    arr_func_return.append(f"{func_call};return")
                    func_return_failed = "return"
        func = f"""
        {func_cast}[]{func_head} {{
            try {{
                {func_fill};
                uint64 arg_filled=0;
                (void)arg_filled;
                {"".join([f'''
                if ({arr_func_quick_check_runable[did]}) {{
                    {arr_func_args_convert[did]};
                    {arr_fill_with_default[did]};
                    {arr_func_return[did]};
                }}
                '''
                for did in range(len(arr_func_return))
                ])}
                LOGf << "Not a valid call";
            }} catch (const std::exception& e) {{
                PyErr_Format(PyExc_RuntimeError, "%s\\n%s",
                    e.what(),
                    R""({decs})""
                );
            }}
            {func_return_failed};
        }}
        """

        if slot_name:
            if slot_name=="tp_gets":
                class_gets[name] = {
                    "func": func,
                    "doc": doc_all
                }
                continue
            if slot_name=="tp_sets":
                class_sets[name] = {
                    "func": func,
                    "doc": ""
                }
                continue
            class_slots_code.append(f"""
            tp.{slot_name} = {func};
            """)
            continue
        need_static = ""
        if df["is_scope_def"] and df["is_static"] and \
            df["scope_name"] == class_name and \
            "submodule" not in class_info["attrs"]:
            need_static = " | METH_STATIC"
        func = (f"""
        {{ R""({name})"",
        {func},
        METH_FASTCALL | METH_KEYWORDS{need_static},
        R""({doc_all})""
        }}""")
        if df["is_scope_def"]:
            if df["scope_name"] == class_name or \
                (class_info and \
                    target_scope_name == class_info["pynames"][0]):
                class_defs_code.append(func)
            else:
                submodule_defs_code.append(func)
        else:
            code.append(func)
    prop_names = list(set(class_gets.keys()).union(class_sets.keys()))
    prop_names = sorted(prop_names)
    for prop_name in prop_names:
        get_func = "NULL"
        set_func = "NULL"
        doc = ""
        if prop_name in class_gets:
            get_func = class_gets[prop_name]["func"]
            if class_gets[prop_name]["doc"]:
                doc += class_gets[prop_name]["doc"]
        if prop_name in class_sets:
            set_func = class_sets[prop_name]["func"]
            if class_sets[prop_name]["doc"]:
                doc += class_sets[prop_name]["doc"]
        class_getsets_code.append(f"""
            {{"{prop_name}", {get_func}, {set_func}, R""({doc})""}}
        """)
    code.append("{0,0,0,0}")
    class_defs_code.append("{0,0,0,0}")
    class_getsets_code.append("{0,0,0,0}")
    submodule_defs_code.append("{0,0,0,0}")
    core_name = "jittor_core"
    if class_info and "attrs" in class_info and "core_name" in class_info["attrs"]:
        core_name = class_info["attrs"]["core_name"]
    if submodule_info and "attrs" in submodule_info and "core_name" in submodule_info["attrs"]:
        core_name = submodule_info["attrs"]["core_name"]
    has_map = class_name in ["VarHolder", "NanoVector"]
    has_seq = class_name == "NanoVector"
    code = f"""
    #include "pyjt/py_converter.h"
    #include "common.h"
    #include "{include_name}"

    namespace jittor {{

    {
    "" if class_name is None else
    f"PyHeapTypeObject Pyjt{class_name};" if "heaptype" in class_info["attrs"] else
    f"PyTypeObject Pyjt{class_name};"
    }
    
    void pyjt_def_{basename}(PyObject* m) {{
        static PyMethodDef defs[] = {{
            {",".join(code)}
        }};
        ASSERT(PyModule_AddFunctions(m, defs)==0);
        {
        f'''
        static PyMethodDef class_defs[] = {{
            {",".join(class_defs_code)}
        }};
        static PyGetSetDef class_getsets[] = {{
            {",".join(class_getsets_code)}
        }};

        static PyNumberMethods number_methods = {{0}};
        {f"auto& htp =Pyjt{class_name}; auto& tp = htp.ht_type;"
        if "heaptype" in class_info["attrs"] else
        f"auto& tp = Pyjt{class_name};"}
        tp.tp_as_number = &number_methods;

        {f"static PyMappingMethods class_map_defs = {{0}};" if has_map else ""}
        {f"tp.tp_as_mapping = &class_map_defs;" if has_map else ""}

        {f"static PySequenceMethods class_seq_defs = {{0}};" if has_seq else ""}
        {f"tp.tp_as_sequence = &class_seq_defs;" if has_seq else ""}
        
        tp.tp_name = "{core_name}.{class_info["pynames"][0]}";
        tp.tp_basicsize = GET_OBJ_SIZE({class_name});
        tp.tp_new = PyType_GenericNew;
        tp.tp_flags = Py_TPFLAGS_DEFAULT;
        {"tp.tp_flags |= Py_TPFLAGS_HEAPTYPE; htp.ht_name = htp.ht_qualname = to_py_object<string>(tp.tp_name);"
        if "heaptype" in class_info["attrs"] else ""}
        tp.tp_methods = &class_defs[0];
        tp.tp_getset = &class_getsets[0];
        {"".join(class_slots_code)};
        ASSERT(0==PyType_Ready(&tp)) << (PyErr_Print(), 0);
        Py_INCREF(&tp);
        ASSERT(0==PyModule_AddObject(m, "{class_info["pynames"][0]}", (PyObject*)&tp));
        ''' if class_name is not None else ""
        }
        {f'''

        // sub module def
        static PyMethodDef submodule_defs[] = {{
            {",".join(submodule_defs_code)}
        }};
        auto sub = PyImport_AddModule("{core_name}.{submodule_info["pynames"][0]}");
        ASSERT(PyModule_AddFunctions(sub, submodule_defs)==0);
        ASSERT(sub);
        ASSERT(0==PyModule_AddObject(m, "{submodule_info["pynames"][0]}", sub));
        ''' if submodule_name is not None else ""
        }

    }}

    }}
    """
    return code
示例#9
0
 def unlock(self):
     """Release the flock held on ``self.handle`` and log the release.

     Uses ``fcntl.LOCK_UN`` on the stored handle, then marks the object as
     no longer locked and emits a verbose log line tagged with this PID.
     """
     handle = self.handle
     fcntl.flock(handle, fcntl.LOCK_UN)
     self.is_locked = False
     pid = os.getpid()
     LOG.vv(f'UNLOCK PID: {pid}')
示例#10
0
 def lock(self):
     """Acquire an exclusive flock on ``self.handle`` and log the acquisition.

     ``fcntl.LOCK_EX`` is requested without ``LOCK_NB``, so the call waits
     until the lock is granted. Afterwards the object is marked as locked
     and a verbose log line tagged with this PID is emitted.
     """
     handle = self.handle
     fcntl.flock(handle, fcntl.LOCK_EX)
     self.is_locked = True
     pid = os.getpid()
     LOG.vv(f'LOCK PID: {pid}')