Example no. 1
def check_cache_compile():
    files = [
        "src/utils/cache_compile.cc",
        "src/utils/log.cc",
        "src/utils/tracer.cc",
        "src/utils/jit_utils.cc",
        "src/utils/str_utils.cc",
    ]
    if os.name == 'nt':
        files = [x.replace('/', '\\') for x in files]
    global jit_utils_core_files
    jit_utils_core_files = files
    recompile = compile(
        cc_path, cc_flags + f" {opt_flags} ", files,
        jit_utils.cache_path + '/jit_utils_core' + extension_suffix, True)
    if recompile and jit_utils.cc:
        LOG.e("jit_utils updated, please rerun your command.")
        sys.exit(0)
    if not jit_utils.cc:
        with jit_utils.import_scope(import_flags):
            jit_utils.try_import_jit_utils_core()
        assert jit_utils.cc
        # recompile, generate cache key
        compile(cc_path, cc_flags + f" {opt_flags} ", files,
                jit_utils.cache_path + '/jit_utils_core' + extension_suffix,
                True)
Example no. 2
def calculate_md5(file_path, chunk_size=1024 * 1024):
    md5 = hashlib.md5()
    with open(file_path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            md5.update(chunk)
    md5 = md5.hexdigest()
    LOG.v(f"file {file_path} md5: {md5}")
    return md5
Example no. 3
def check_pybt(gdb_path, python_path):
    if gdb_path == '' or python_path == '':
        return False
    ret = sp.getoutput(f"{gdb_path} --batch {python_path} -ex 'help py-bt'")
    if 'python frame' in ret:
        LOG.v("py-bt found in gdb.")
        return True
    return False
Example no. 4
def env_or_try_find(name, bname):
    if name in os.environ:
        path = os.environ[name]
        if path != "":
            version = jit_utils.get_version(path)
            LOG.i(f"Found {bname}{version} at {path}")
        return path
    return try_find_exe(bname)
Example no. 5
 def unlock(self):
     if fcntl:
         fcntl.flock(self.handle, fcntl.LOCK_UN)
     else:
         hfile = win32file._get_osfhandle(self.handle.fileno())
         win32file.UnlockFileEx(hfile, 0, -0x10000, _OVERLAPPED)
     self.is_locked = False
     LOG.vv(f'UNLOCK PID: {os.getpid()}')
Example no. 6
 def lock(self):
     if fcntl:
         fcntl.flock(self.handle, fcntl.LOCK_EX)
     else:
         hfile = win32file._get_osfhandle(self.handle.fileno())
         win32file.LockFileEx(hfile, 2, 0, -0x10000, _OVERLAPPED)
     self.is_locked = True
     LOG.vv(f'LOCK PID: {os.getpid()}')
Example no. 7
    def display_worker_status(self):
        ''' Display dataset worker status. When dataset.num_workers > 0, it will display information like the following:

.. code-block:: console

        progress:479/5005
        batch(s): 0.302 wait(s):0.000
        recv(s): 0.069  to_jittor(s):0.021
        recv_raw_call: 6720.0
        last 10 workers: [6, 7, 3, 0, 2, 4, 7, 5, 6, 1]
        ID      wait(s) load(s) send(s) total
        #0      0.000   1.340   2.026   3.366   Buffer(free=0.000% l=462425368 r=462425368 size=536870912)
        #1      0.000   1.451   3.607   5.058   Buffer(free=0.000% l=462425368 r=462425368 size=536870912)
        #2      0.000   1.278   1.235   2.513   Buffer(free=0.000% l=462425368 r=462425368 size=536870912)
        #3      0.000   1.426   1.927   3.353   Buffer(free=0.000% l=462425368 r=462425368 size=536870912)
        #4      0.000   1.452   1.074   2.526   Buffer(free=0.000% l=462425368 r=462425368 size=536870912)
        #5      0.000   1.422   3.204   4.625   Buffer(free=0.000% l=462425368 r=462425368 size=536870912)
        #6      0.000   1.445   1.953   3.398   Buffer(free=0.000% l=462425368 r=462425368 size=536870912)
        #7      0.000   1.582   0.507   2.090   Buffer(free=0.000% l=308283552 r=308283552 size=536870912)

Meaning of the outputs:

* progress: dataset loading progress (current/total)
* batch: batch time, excluding data loading time
* wait: time the main process spends waiting for worker processes
* recv: time spent receiving batch data
* to_jittor: time spent converting batch data to jittor variables
* recv_raw_call: total number of underlying recv_raw calls
* last 10 workers: ids of the last 10 workers the main process loaded from.
* table columns:
    * ID: worker id
    * wait: worker wait time
    * open: worker image open time
    * load: worker load time
    * buffer: ring buffer status, such as free space, left index, right index, and total size (bytes).

Example::
  
  from jittor.dataset import Dataset
  class YourDataset(Dataset):
      pass
  dataset = YourDataset().set_attrs(num_workers=8)
  for x, y in dataset:
      dataset.display_worker_status()
        '''
        if not hasattr(self, "workers"):
            return
        msg = [""]
        msg.append(f"progress:{self.last_id}/{self.batch_len}")
        msg.append(f"batch(s): {self.batch_time:.3f}\twait(s):{self.wait_time:.3f}")
        msg.append(f"recv(s): {self.recv_time:.3f}\tto_jittor(s):{self.to_jittor_time:.3f}")
        msg.append(f"last 10 workers: {self.idmap[max(0, self.last_id-9):self.last_id+1]}")
        msg.append(f"ID\twait(s)\topen(s)\tload(s)\tsend(s)\ttotal(s)")
        for i in range(self.num_workers):
            w = self.workers[i]
            s = w.status
            msg.append(f"#{i}\t{s[0]:.3f}\t{s[4]:.3f}\t{s[1]:.3f}\t{s[2]:.3f}\t{s[3]:.3f}\t{w.buffer}")
        LOG.i('\n'.join(msg))
Example no. 8
def gen_jit_flags():
    all_src = run_cmd('find -L src/ | grep "cc$"', jittor_path).splitlines()
    jit_declares = []
    re_def = re.compile("DEFINE_FLAG(_WITH_SETTER)?\\((.*?)\\);", re.DOTALL)

    flags_defs = []
    visit = {}

    for src_name in all_src:
        src_name = os.path.join(jittor_path, src_name)
        with open(src_name) as f:
            src = f.read()
        defs = re_def.findall(src)
        for _, args in defs:
            args = args.split(",")
            type = args[0].strip()
            name = args[1].strip()
            if not has_cuda and "cuda" in name and name != "use_cuda":
                continue
            default = args[2].strip()
            doc = ",".join(args[3:])
            doc = eval(f"({doc})")
            LOG.vv(f"Find define {name} from {src_name}")
            if name in visit:
                continue
            visit[name] = 1
            jit_declares.append(f"DECLARE_FLAG({type}, {name});")
            flags_defs.append(f"""
                /* {name}(type:{type}, default:{default}): {doc} */
                // @pyjt(__get__{name})
                {type} _get_{name}() {{ return {name}; }}
                // @pyjt(__set__{name})
                void _set_{name}({type} v) {{ set_{name}(v); }}
                {f'''// @pyjt(__set__{name})
                void _set_{name}(bool v) {{ set_{name}(v); }}
                ''' if type=="int" else ""}
            """)

    jit_declares = "\n    ".join(jit_declares)
    jit_src = f"""
    #include "utils/flags.h"

    namespace jittor {{
    
    {jit_declares}

    // @pyjt(flags)
    struct _Flags {{
        // @pyjt(__init__)
        _Flags() {{}}
        {"".join(flags_defs)}
    }};

    }} // jittor
    """
    LOG.vvvv(jit_src)
    with open(os.path.join(cache_path, "gen", "jit_flags.h"), 'w') as f:
        f.write(jit_src)
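A minimal, self-contained sketch of the DEFINE_FLAG scanning step performed above; the sample C++ line is invented for illustration:

import re

# Capture the argument list of a DEFINE_FLAG(...) occurrence in C++ source
# text, as gen_jit_flags does (the sample source line below is made up).
re_def = re.compile("DEFINE_FLAG(_WITH_SETTER)?\\((.*?)\\);", re.DOTALL)
cpp_src = 'DEFINE_FLAG(int, use_cuda, 0, "Use CUDA or not");'
_, args = re_def.findall(cpp_src)[0]
parts = args.split(",")
flag_type, flag_name, default = [p.strip() for p in parts[:3]]
doc = eval("(" + ",".join(parts[3:]) + ")")
print(flag_type, flag_name, default, doc)  # int use_cuda 0 Use CUDA or not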
Example no. 9
def install(path):
    LOG.i("Installing MSVC...")
    filename = "msvc.zip"
    url = "https://cg.cs.tsinghua.edu.cn/jittor/assets/" + filename
    md5sum = "55f0c175fdf1419b124e0fc498b659d2"
    download_url_to_local(url, filename, path, md5sum)
    fullname = os.path.join(path, filename)
    import zipfile
    with zipfile.ZipFile(fullname, "r") as f:
        f.extractall(path)
Example no. 10
def compile_single(head_file_name, src_file_name, src=None):
    basename = head_file_name.split("/")[-1].split(".")[0]
    if src is None:
        with open(head_file_name, 'r') as f:
            src = f.read()
    code = compile_src(src, head_file_name, basename)
    if not code: return False
    LOG.vvv("write to", src_file_name)
    LOG.vvvv(code)
    with open(src_file_name, 'w') as f:
        f.write(code)
    return True
Example no. 11
 def __init__(self, base_name, rtol=5e-2, atol=1e-3):
     if os.environ.get("use_auto_diff", '1') == '0':
         return
     hook_rand()
     self.rid = 0
     self.base_name = base_name
     self.base_path = os.path.join(str(Path.home()), ".cache", "jittor", "auto_diff", base_name)
     os.makedirs(self.base_path, exist_ok=True)
     self.rtol = rtol
     self.atol = atol
     LOG.i("Use cache path:", self.base_path)
     LOG.i(f"rtol:{rtol} atol:{atol}")
Example no. 12
def compile(cache_path, jittor_path):
    headers1 = glob.glob(jittor_path + "/src/**/*.h", recursive=True)
    headers2 = glob.glob(cache_path + "/gen/**/*.h", recursive=True)
    headers = headers1 + headers2
    basenames = []
    pyjt_names = []
    for h in headers:
        with open(h, 'r') as f:
            src = f.read()

        bh = os.path.basename(h)
        # jit_op_maker.h is merge-compiled together with var_holder.h
        if bh == "var_holder.h": continue
        if bh == "jit_op_maker.h":
            with open(os.path.join(jittor_path, "src", "var_holder.h"),
                      "r") as f:
                src = f.read() + src
        basename = bh.split(".")[0]
        fname = "pyjt_" + basename + ".cc"
        fname = os.path.join(cache_path, "gen", fname)
        check = compile_single(h, fname, src)

        if not check: continue

        basenames.append(basename)
        pyjt_names.append(fname)

    code = f"""
    #include "pyjt/numpy.h"
    #include "pyjt/py_converter.h"
    #include "common.h"

    namespace jittor {{

    { " ".join([f"extern void pyjt_def_{n}(PyObject* m);" for n in basenames])}

    void pyjt_def_all(PyObject* m) {{
        numpy_init();
        { " ".join([f"pyjt_def_{n}(m);" for n in basenames])}
    }}

    }}
    """
    fname = os.path.join(cache_path, "gen", "pyjt_all.cc")
    LOG.vvv(("write to", fname))
    LOG.vvvv(code)
    with open(fname, "w") as f:
        f.write(code)
    pyjt_names.append(fname)
    return pyjt_names
Example no. 13
 def save_input(self, *data):
     '''
         for input, label in torch_dataloader:
             hook.save_input(data)
     '''
     if self.mode == "save":
         self.record_status["[input]"] += 1
         fpath = os.path.join(
             self.base_path, f"__input-{self.record_status['[input]']}.pkl")
         with open(fpath, 'wb') as f:
             pickle.dump(convert(data), f)
         LOG.i(f"save input: ok")
     else:
         raise RuntimeError("save_input is invalid in [check] mode")
Example no. 14
def compile(cache_path, jittor_path):
    headers1 = run_cmd('find -L src/ | grep ".h$"', jittor_path).splitlines()
    headers2 = run_cmd('find gen/ | grep ".h$"', cache_path).splitlines()
    headers = [ os.path.join(jittor_path, h) for h in headers1 ] + \
        [ os.path.join(cache_path, h) for h in headers2 ]
    basenames = []
    pyjt_names = []
    for h in headers:
        with open(h, 'r') as f:
            src = f.read()

        # jit_op_maker.h is merge-compiled together with var_holder.h
        if h.endswith("src/var_holder.h"): continue
        if h.endswith("jit_op_maker.h"):
            with open(os.path.join(jittor_path, "src", "var_holder.h"),
                      "r") as f:
                src = f.read() + src
        basename = h.split("/")[-1].split(".")[0]
        fname = "pyjt_" + basename + ".cc"
        fname = os.path.join(cache_path, "gen", fname)
        check = compile_single(h, fname, src)

        if not check: continue

        basenames.append(basename)
        pyjt_names.append(fname)

    code = f"""
    #include "pyjt/numpy.h"
    #include "pyjt/py_converter.h"
    #include "common.h"

    namespace jittor {{

    { " ".join([f"extern void pyjt_def_{n}(PyObject* m);" for n in basenames])}

    void pyjt_def_all(PyObject* m) {{
        numpy_init();
        { " ".join([f"pyjt_def_{n}(m);" for n in basenames])}
    }}

    }}
    """
    fname = os.path.join(cache_path, "gen", "pyjt_all.cc")
    LOG.vvv(("write to", fname))
    LOG.vvvv(code)
    with open(fname, "w") as f:
        f.write(code)
    pyjt_names.append(fname)
    return pyjt_names
Example no. 15
def compile_custom_ops(filenames, extra_flags=""):
    """Compile custom ops
    filenames: path of op source files, filenames must be
        pairs of xxx_xxx_op.cc and xxx_xxx_op.h, and the 
        type name of op must be XxxXxxOp.
    extra_flags: extra compile flags
    return: compiled ops
    """
    srcs = {}
    headers = {}
    builds = []
    includes = []
    for name in filenames:
        name = os.path.realpath(name)
        if name.endswith(".cc") or name.endswith(".cpp") or name.endswith(".cu"):
            builds.append(name)
        bname = os.path.basename(name)
        bname = os.path.splitext(bname)[0]
        if bname.endswith("_op"):
            bname = bname[:-3]
            if name.endswith(".cc"):
                srcs[bname] = name
            elif name.endswith(".h"):
                includes.append(os.path.dirname(name))
                headers[bname] = name
    assert len(srcs) == len(headers), "Source and header names do not match"
    for name in srcs:
        assert name in headers, f"Header of op {name} not found"
    gen_name = "gen_ops_" + "_".join(headers.keys())
    if len(gen_name) > 100:
        gen_name = gen_name[:80] + "___hash" + str(hash(gen_name))

    includes = set(includes)
    includes = "".join(map(lambda x: f" -I'{x}' ", includes))
    LOG.vvvv(f"Include flags:{includes}")

    op_extra_flags = includes + extra_flags

    gen_src = gen_jit_op_maker(headers.values(), export=gen_name, extra_flags=op_extra_flags)
    make_cache_dir(os.path.join(cache_path, "custom_ops"))
    gen_src_fname = os.path.join(cache_path, "custom_ops", gen_name+".cc")
    gen_head_fname = os.path.join(cache_path, "custom_ops", gen_name+".h")
    gen_lib = os.path.join("custom_ops", gen_name+extension_suffix)
    with open(gen_head_fname, "w") as f:
        f.write(gen_src)
    pyjt_compiler.compile_single(gen_head_fname, gen_src_fname)
    # generated source must be initialized first
    builds.insert(0, gen_src_fname)
    LOG.vvv(f"Build custum ops lib:{gen_lib}")
    LOG.vvvv(f"Build sources:{builds}")
    compile(cc_path, cc_flags+opt_flags+includes+extra_flags, builds, gen_lib)

    # add python path and import
    LOG.vvv(f"Import custum ops lib:{gen_lib}")
    lib_path = os.path.join(cache_path, "custom_ops")
    if lib_path not in os.sys.path:
        os.sys.path.append(lib_path)
    with jit_utils.import_scope(os.RTLD_GLOBAL | os.RTLD_NOW | os.RTLD_DEEPBIND):
        exec(f"import {gen_name}")
    return (locals()[gen_name]).ops
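A hedged, self-contained sketch of the file-naming convention described in the docstring; the paths below are hypothetical, and only the pairing logic is illustrated:

import os

# Each op ships as xxx_xxx_op.cc plus xxx_xxx_op.h; its key is the basename
# without the trailing "_op" (paths here are placeholders, not real files).
filenames = ["my_ops/my_add_op.cc", "my_ops/my_add_op.h"]
srcs, headers = {}, {}
for name in filenames:
    bname = os.path.splitext(os.path.basename(name))[0]
    if bname.endswith("_op"):
        key = bname[:-3]
        (srcs if name.endswith(".cc") else headers)[key] = name
print(srcs)     # {'my_add': 'my_ops/my_add_op.cc'}
print(headers)  # {'my_add': 'my_ops/my_add_op.h'}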
Example no. 16
def gen_jit_tests():
    all_src = run_cmd('find -L src/ | grep "cc$"', jittor_path).splitlines()
    jit_declares = []
    re_def = re.compile("JIT_TEST\\((.*?)\\)")
    names = set()
    test_defs = []

    for src_name in all_src:
        src_name = os.path.join(jittor_path, src_name)
        with open(src_name) as f:
            src = f.read()
        defs = re_def.findall(src)
        for name in defs:
            LOG.vv(f"Find test {name} from {src_name}")
            assert name not in names, f"Conflict test name {name}"
            names.add(name)
            jit_declares.append(f"JIT_TEST({name});")
            test_defs.append(f"""
                /* From {src_name} */
                // @pyjt({name})
                static inline void test_{name}() {{ jit_test_{name}(); }} 
            """)

    jit_declares = "\n    ".join(jit_declares)
    jit_src = f"""
    #pragma once
    #include "common.h"

    void expect_error(std::function<void()> func) {{
        try {{ func(); }}
        catch (...) {{ return; }}
        CHECK(0) << "Missing error";
    }}

    namespace jittor {{
    
    {jit_declares}

    // @pyjt(tests)
    // @attrs(submodule)
    namespace tests {{
        {"".join(test_defs)}
    }}

    }} // jittor
    """
    LOG.vvvv(jit_src)
    with open(os.path.join(cache_path, "gen", "jit_tests.h"), 'w') as f:
        f.write(jit_src)
Example no. 17
def install_cuda():
    cuda_driver_version = get_cuda_driver()
    if not cuda_driver_version:
        return None
    LOG.i("cuda_driver_version: ", cuda_driver_version)

    if cuda_driver_version >= [11, 2]:
        cuda_tgz = "cuda11.2_cudnn8_linux.tgz"
        md5 = "b93a1a5d19098e93450ee080509e9836"
    elif cuda_driver_version >= [
            11,
    ]:
        cuda_tgz = "cuda11.0_cudnn8_linux.tgz"
        md5 = "5dbdb43e35b4db8249027997720bf1ca"
    elif cuda_driver_version >= [10, 2]:
        cuda_tgz = "cuda10.2_cudnn7_linux.tgz"
        md5 = "40f0563e8eb176f53e55943f6d212ad7"
    elif cuda_driver_version >= [
            10,
    ]:
        cuda_tgz = "cuda10.0_cudnn7_linux.tgz"
        md5 = "f16d3ff63f081031d21faec3ec8b7dac"
    else:
        raise RuntimeError(
            f"Unsupport cuda driver version: {cuda_driver_version}")
    jtcuda_path = os.path.join(pathlib.Path.home(), ".cache", "jittor",
                               "jtcuda")
    nvcc_path = os.path.join(jtcuda_path, cuda_tgz[:-4], "bin", "nvcc")
    nvcc_lib_path = os.path.join(jtcuda_path, cuda_tgz[:-4], "lib64")
    sys.path.append(nvcc_lib_path)
    new_ld_path = os.environ.get("LD_LIBRARY_PATH", "") + ":" + nvcc_lib_path
    os.environ["LD_LIBRARY_PATH"] = new_ld_path

    if os.path.isfile(nvcc_path):
        return nvcc_path

    os.makedirs(jtcuda_path, exist_ok=True)
    cuda_tgz_path = os.path.join(jtcuda_path, cuda_tgz)
    download_url_to_local(
        "https://cg.cs.tsinghua.edu.cn/jittor/assets/" + cuda_tgz, cuda_tgz,
        jtcuda_path, md5)

    import tarfile
    with tarfile.open(cuda_tgz_path, "r") as tar:
        tar.extractall(cuda_tgz_path[:-4])

    assert os.path.isfile(nvcc_path)
    return nvcc_path
Example no. 18
def hook_rand():
    global rand_hooked
    if rand_hooked: return
    rand_hooked = True
    np.random.seed(0)
    if "torch" in sys.modules:
        LOG.i("Hook torch.rand")
        torch = sys.modules["torch"]
        torch.rand = hook_pt_rand
        torch.normal = hook_pt_normal
        torch.manual_seed(0)
    if "jittor" in sys.modules:
        jittor = sys.modules["jittor"]
        LOG.i("Hook jittor.random")
        jittor.random = hook_jt_rand
        jittor.seed(0)
Example no. 19
 def load_input(self):
     '''
         for fake_input, fake_label in jittor_dataset:
             input, label = hook.load_input()
             input = jt.array(input)
             label = jt.array(label)
     '''
     if self.mode == "check":
         self.record_status["[input]"] += 1
         fpath = os.path.join(
             self.base_path, f"__input-{self.record_status['[input]']}.pkl")
         with open(fpath, 'rb') as f:
             data = pickle.load(f)
         LOG.i(f"load input: ok")
         return data
     else:
         raise RuntimeError("load_input is invalid in [save] mode")
Example no. 20
 def __init__(self, root, transform=None):
     super().__init__()
     self.root = root
     self.transform = transform
     self.classes = sorted([d.name for d in os.scandir(root) if d.is_dir()])
     self.class_to_idx = {v:k for k,v in enumerate(self.classes)}
     self.imgs = []
     image_exts = set(('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'))
     
     for i, class_name in enumerate(self.classes):
         class_dir = os.path.join(root, class_name)
         for dname, _, fnames in sorted(os.walk(class_dir, followlinks=True)):
             for fname in sorted(fnames):
                 if os.path.splitext(fname)[-1].lower() in image_exts:
                     path = os.path.join(class_dir, fname)
                     self.imgs.append((path, i))
     LOG.i(f"Found {len(self.classes)} classes and {len(self.imgs)} images.")
     self.set_attrs(total_len=len(self.imgs))
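A small self-contained sketch (directory and file names invented) of the on-disk layout this constructor scans: one sub-directory per class, each containing image files:

import os, tempfile

# Build a throwaway layout of the expected form:
#   root/cat/a.jpg, root/cat/b.png, root/dog/c.jpg
root = tempfile.mkdtemp()
for cls, fname in [("cat", "a.jpg"), ("cat", "b.png"), ("dog", "c.jpg")]:
    os.makedirs(os.path.join(root, cls), exist_ok=True)
    open(os.path.join(root, cls, fname), "wb").close()

# With this layout the scan above yields classes ['cat', 'dog'],
# class_to_idx {'cat': 0, 'dog': 1}, and imgs as (path, class_index) pairs.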
Example no. 21
def check_cache_compile():
    files = [
        "src/utils/cache_compile.cc",
        "src/utils/log.cc",
        "src/utils/tracer.cc",
        "src/utils/jit_utils.cc",
    ]
    global jit_utils_core_files
    jit_utils_core_files = files
    recompile = compile(cc_path, cc_flags+f" {opt_flags} ", files, 'jit_utils_core'+extension_suffix, True)
    if recompile and jit_utils.cc:
        LOG.e("jit_utils updated, please restart jittor.")
        sys.exit(0)
    if not jit_utils.cc:
        with jit_utils.import_scope(import_flags):
            jit_utils.try_import_jit_utils_core()
        assert jit_utils.cc
        # recompile, generate cache key
        compile(cc_path, cc_flags+f" {opt_flags} ", files, 'jit_utils_core'+extension_suffix, True)
Example no. 22
def compile_extern():
    # compile llvm passes
    if cc_type != "clang":
        return
    global kernel_opt_flags
    cache_path_llvm = os.path.join(cache_path, "llvm")
    jittor_path_llvm = os.path.join(jittor_path, "extern", "llvm")
    clang_dir = os.path.dirname(get_full_path_of_executable(cc_path))
    assert clang_dir.endswith(
        "bin") and "llvm" in clang_dir, f"Wrong clang_dir: {clang_dir}"
    llvm_include = os.path.abspath(os.path.join(clang_dir, "..", "include"))
    assert os.path.isdir(llvm_include), "LLVM include path not found"
    make_cache_dir(cache_path_llvm)
    files = os.listdir(jittor_path_llvm)
    # test_pass.cc is used to test link problems of the llvm pass plugin
    test_pass_path = os.path.join(cache_path_llvm, "test_pass.cc")
    with open(test_pass_path, 'w') as f:
        f.write("int main() {return 0;}")

    # -fno-rtti fixes a link error

    # -Wl,-znodelete fixes a segfault
    # https://github.com/sampsyo/llvm-pass-skeleton/issues/7#issuecomment-401834287

    # -D_GLIBCXX_USE_CXX11_ABI=0 fixes undefined symbol: createPrinterPass
    # https://stackoverflow.com/questions/37366291/undefined-symbol-for-self-built-llvm-opt

    # try different flags
    try_flags = [
        " -Wl,-znodelete -D_GLIBCXX_USE_CXX11_ABI=0 ",
        " -Wl,-znodelete ",
    ]
    found_flags_id = -1
    for fname in files:
        for i, flag in enumerate(try_flags):
            if found_flags_id != -1 and found_flags_id != i:
                continue
            so_name = os.path.join(cache_path_llvm,
                                   os.path.splitext(fname)[0] + f".{i}.so")
            compile(cc_path,
                    f"{cc_flags} {opt_flags} {flag} -I'{llvm_include}'",
                    [os.path.join(jittor_path_llvm, fname)], so_name)
            # if no working flag has been found yet, test this one.
            if found_flags_id == -1:
                try:
                    s = run_cmd(
                        f"{cc_path} {cc_flags} -Xclang -load -Xclang '{so_name}' {test_pass_path}",
                        cache_path_llvm,
                        print_error=False)
                except Exception as e:
                    LOG.v(f"Try flag {flag} failed: {e}")
                    continue
                found_flags_id = i
            kernel_opt_flags += f" -Xclang -load -Xclang '{so_name}' "
            break
        else:
            LOG.w("Clang is used, but LLVM pass plugin is unable to link.")
            break
    LOG.vv(f"Compile extern llvm passes: {str(files)}")
Example no. 23
    def hook_module(self, mod, mod_name=""):
        if os.environ.get("use_auto_diff", '1') == '0':
            return
        if mod_name != "":
            mod_name = "<" + mod_name + ">"
        self.hooked_models[mod_name] = mod

        def forward_hook(self2, input, output, kw=None):
            ex_name = '[' + self2.__class__.__name__ + ']'
            if "relu" not in self2.__class__.__name__.lower():
                # do not test relu, because its input may be modified in place
                self.record(self2.__ad_mod_name__ + ".input", input, ex_name)
            self.record(self2.__ad_mod_name__ + ".output", output, ex_name)
            if kw is not None:
                self.record(self2.__ad_mod_name__ + ".kw", kw, ex_name)

        names = []
        for name, module in mod.named_modules():
            ns = name.split('.')
            skip = 0
            for n in ns:
                if n.startswith('_'):
                    skip = 1
            if skip:
                LOG.i("skip", name)
                continue
            name = mod_name + name
            module.__ad_mod_name__ = name
            names.append(name)
            module.register_forward_hook(forward_hook)
            mod_class_name = module.__class__.__name__.lower()
            # put dropout in eval mode and record dropout.p
            if "dropout" in mod_class_name:
                self.record(name + '.p', module.p, "[" + mod_class_name + "]")
                module.eval()
        ps = {mod_name + k: v for k, v in mod.state_dict().items()}
        self.record_params(ps, mod_name)
        self.record("module names", names)
Example no. 24
 def run():
     start_time = time.time()
     fop_num = 10000
     fop_input_num = (2, 3) # (i,j) -> [i,i+j] -> [2, 5]
     # fop_output_num = (1, 0) # [1,1]
     inner_op_num = (0, 3)
     fop_type_num = 63 # how many different fused op types
     input_queue_num = 15
     queue = [1.0]*(input_queue_num+1)
     x = get_xorshf96()
     rand = lambda x, l, r: l+((x())&r)
     ops = ["add", "subtract", "multiply", "divide"]
     get_op = lambda x: ops[(x())&3]
     for i in range(fop_num):
         prev = bc(queue[rand(x,0,input_queue_num)])
         y = get_xorshf96(x()&fop_type_num)
         inum = rand(y, *fop_input_num)
         q = [prev]
         for i in range(inum-1):
             n = bc(queue[rand(x,0,input_queue_num)])
             prev = jt.binary(prev, n, get_op(y))
             q.append(prev)
         innum = rand(y,*inner_op_num)
         for _ in range(innum):
             j = rand(y,0,len(q)-1)
             n = q[j]
             prev = jt.binary(prev, n, get_op(y))
             q[j] = prev
         prev = rd(prev)
         queue[rand(x,0,input_queue_num)] = prev
     a = jt.array(0.0)
     for x in queue:
         a += x
     LOG.i("build graph", time.time()-start_time, jt.liveness_info().values())
     start_time = time.time()
     a.sync()
     LOG.i("execute", time.time()-start_time)
Example no. 25
    def check_array(self, name, a, b):
        rtol = self.rtol
        atol = self.atol
        global has_error
        err = np.abs(a-b)
        tol = atol + rtol * np.abs(b)
        is_error = np.logical_or( err > tol, (a>=-1e-5)!=(b>=-1e-5))
        index = np.where(is_error)
        assert len(index)>0
        if len(index[0]) == 0:
            return

        has_error += 1
        LOG.w(f"Ndarray <{name}> not match, shape:{a.shape}")
        i = tuple( i[0] for i in index )
        err_rate = is_error.mean()
        LOG.w(f"error index at [{i}], a({a[i]}) b({b[i]}) err({err[i]}) > tol({tol[i]}), err_rate:{err_rate*100:.3f}% amean({a.mean()}) bmean({b.mean()}) astd({a.std()}) bstd({b.std()}) ")
        if err_rate > 0.01:
            LOG.e("!"*10+"Very HIGH err rate"+"!"*10)
Example no. 26
def generate_error_code_from_func_header(func_head, target_scope_name, name,
                                         dfs, basename, h, class_info):
    # func_head is a string like:
    # (PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject*
    lib_name = os.path.basename(h).split("_")[0]
    # TODO: fix/add var help
    if target_scope_name == "Var": target_scope_name = None
    if target_scope_name:
        if target_scope_name == "flags":
            help_name = "flags"
        else:
            help_name = "" + target_scope_name + '.' + name
    else:
        help_name = name
    if lib_name in ["mpi", "nccl", "cudnn", "curand", "cublas", "mkl"]:
        help_name = lib_name + '.' + help_name
    help_cmd = f"help(jt.{help_name})"

    LOG.vvv("gen err from func_head", func_head)
    args = func_head[1:].split(")")[0].split(",")
    error_code = f" << \"Wrong input arguments, please refer to examples (e.g. {help_cmd}).\""
    error_code += r' << "\n\nTypes of your inputs are:\n"'
    for arg in args:
        arg = arg.strip()
        if arg.startswith("PyObject* "):
            t, n = arg.split(' ')
            if n == "args" or n == "_args":
                error_code += f" << PyTupleArgPrinter{{{n}, \"args\"}} "
            elif n == "kw":
                error_code += f" << PyKwArgPrinter{{{n}}} "
            else:
                error_code += f" << PyArgPrinter{{{n}, \"{n}\"}} "
        elif arg.startswith("PyObject** "):
            t, n = arg.split(' ')
            error_code += f" << PyFastCallArgPrinter{{{n}, n, kw}} "
            break
        else:
            LOG.vvv("Unhandled arg", arg)
    LOG.vvv("gen err from func_head", func_head, " -> ", error_code)
    return error_code
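A standalone sketch (sample signature invented) of how the func_head string above is split into individual C parameter declarations before a printer is chosen for each:

# func_head follows the shape shown in the comment at the top of the
# function; the concrete signature below is only an illustration.
func_head = "(PyObject* self, PyObject** args, int64 n, PyObject* kw)"
args = [a.strip() for a in func_head[1:].split(")")[0].split(",")]
print(args)
# ['PyObject* self', 'PyObject** args', 'int64 n', 'PyObject* kw']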
Example no. 27
    def __init__(self, base_name, rtol=5e-2, atol=1e-3):
        if os.environ.get("use_auto_diff", '1') == '0':
            return
        hook_rand()
        self.rid = 0
        self.base_name = base_name
        self.base_path = os.path.join(str(Path.home()), ".cache", "jittor",
                                      "auto_diff", base_name)
        if not os.path.exists(self.base_path):
            os.makedirs(self.base_path, exist_ok=True)
            self.mode = 'save'
        else:
            self.mode = 'check'

        self.record_status = defaultdict(int)
        self.rtol = rtol
        self.atol = atol
        self.param_name_map = {}
        self.hooked_models = {}
        LOG.i(f"Jittor AutoDiff: [{self.mode}] mode")
        LOG.i("Use cache path:", self.base_path)
        LOG.i(f"rtol:{rtol} atol:{atol}")
Example no. 28
 def record(self, name, data, ex_name=""):
     if os.environ.get("use_auto_diff", '1') == '0':
         return
     rid = self.rid
     self.rid += 1
     fpath = os.path.join(self.base_path, f"{rid}.pkl")
     data = convert(data)
     if os.path.isfile(fpath):
         with open(fpath, 'rb') as f:
             pre_name, pre_data = pickle.load(f)
         if pre_name != name:
             global has_error
             has_error += 1
             LOG.e(f"The {rid} result name not match, {pre_name} != {name}")
             self.rid -= 1
             return
         LOG.i(f"check {rid}:<{ex_name}{name}> ...")
         self.check(ex_name + name, pre_data, data)
     else:
         with open(fpath, 'wb') as f:
             pickle.dump((name, data), f)
         LOG.i(f"save {rid}:<{name}> ok")
Example no. 29
    def record(self, name, data, ex_name=""):
        if os.environ.get("use_auto_diff", '1') == '0':
            return
        self.record_status[name] += 1
        fpath = os.path.join(self.base_path,
                             f"{name}-{self.record_status[name]}.pkl")
        data = convert(data)
        self.rid += 1

        if self.mode == 'check':
            if os.path.isfile(fpath):
                with open(fpath, 'rb') as f:
                    pre_name, pre_data = pickle.load(f)
                LOG.i(f"check {self.rid}:<{ex_name}{name}> ...")
                self.check(ex_name + name, pre_data, data)
            else:
                global has_error
                has_error += 1
                LOG.e(f"No previous result found: {name}")
                return
        else:
            with open(fpath, 'wb') as f:
                pickle.dump((name, data), f)
            LOG.i(f"save {self.rid}:<{name}> ok")
Example no. 30
def try_find_exe(*args):
    try:
        return find_exe(*args)
    except:
        LOG.v(f"{args[0]} not found.")
        return ""