def check_cache_compile():
    files = [
        "src/utils/cache_compile.cc",
        "src/utils/log.cc",
        "src/utils/tracer.cc",
        "src/utils/jit_utils.cc",
        "src/utils/str_utils.cc",
    ]
    if os.name == 'nt':
        files = [x.replace('/', '\\') for x in files]
    global jit_utils_core_files
    jit_utils_core_files = files
    recompile = compile(cc_path, cc_flags + f" {opt_flags} ", files,
                        jit_utils.cache_path + '/jit_utils_core' + extension_suffix, True)
    if recompile and jit_utils.cc:
        LOG.e("jit_utils updated, please rerun your command.")
        sys.exit(0)
    if not jit_utils.cc:
        with jit_utils.import_scope(import_flags):
            jit_utils.try_import_jit_utils_core()
        assert jit_utils.cc
        # recompile, generate cache key
        compile(cc_path, cc_flags + f" {opt_flags} ", files,
                jit_utils.cache_path + '/jit_utils_core' + extension_suffix, True)
def calculate_md5(file_path, chunk_size=1024 * 1024):
    md5 = hashlib.md5()
    with open(file_path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            md5.update(chunk)
    md5 = md5.hexdigest()
    LOG.v(f"file {file_path} md5: {md5}")
    return md5
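# Illustrative helper (not from the original source; the path and checksum below
# are placeholders): calculate_md5 is typically used to verify a downloaded
# archive against a published checksum before unpacking it.
def _verify_checksum_example(path="/tmp/example.zip",
                             expected_md5="0123456789abcdef0123456789abcdef"):
    if calculate_md5(path) != expected_md5:
        raise RuntimeError(f"checksum mismatch for {path}, please re-download")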
def check_pybt(gdb_path, python_path):
    if gdb_path == '' or python_path == '':
        return False
    ret = sp.getoutput(f"{gdb_path} --batch {python_path} -ex 'help py-bt'")
    if 'python frame' in ret:
        LOG.v("py-bt found in gdb.")
        return True
    return False
def env_or_try_find(name, bname):
    if name in os.environ:
        path = os.environ[name]
        if path != "":
            version = jit_utils.get_version(path)
            LOG.i(f"Found {bname}{version} at {path}")
        return path
    return try_find_exe(bname)
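# Illustrative call (hedged example; the environment variable and binary names
# below are placeholders): if the variable is set and non-empty its value wins,
# otherwise try_find_exe() searches the PATH.
#
#   gdb_path = env_or_try_find('gdb_path', 'gdb')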
def unlock(self):
    if fcntl:
        fcntl.flock(self.handle, fcntl.LOCK_UN)
    else:
        hfile = win32file._get_osfhandle(self.handle.fileno())
        win32file.UnlockFileEx(hfile, 0, -0x10000, _OVERLAPPED)
    self.is_locked = False
    LOG.vv(f'UNLOCK PID: {os.getpid()}')
def lock(self):
    if fcntl:
        fcntl.flock(self.handle, fcntl.LOCK_EX)
    else:
        hfile = win32file._get_osfhandle(self.handle.fileno())
        win32file.LockFileEx(hfile, 2, 0, -0x10000, _OVERLAPPED)
    self.is_locked = True
    LOG.vv(f'LOCK PID: {os.getpid()}')
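# Hedged usage sketch (the class name and lock-file path are placeholders, not
# the original API): lock()/unlock() above are methods of an object holding an
# open file handle (fcntl.flock on POSIX, win32file.LockFileEx on Windows), and
# are typically wrapped around a critical section such as a shared cache write.
#
#   lk = FileLock("/tmp/example.lock")   # hypothetical wrapper exposing lock()/unlock()
#   lk.lock()
#   try:
#       ...                              # touch the shared resource
#   finally:
#       lk.unlock()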
def display_worker_status(self):
    ''' Display dataset worker status. When dataset.num_workers > 0, it will display information like below:

    .. code-block:: console

        progress:479/5005
        batch(s): 0.302  wait(s):0.000
        recv(s): 0.069   to_jittor(s):0.021
        recv_raw_call: 6720.0
        last 10 workers: [6, 7, 3, 0, 2, 4, 7, 5, 6, 1]
        ID  wait(s)  load(s)  send(s)  total
        #0  0.000    1.340    2.026    3.366  Buffer(free=0.000% l=462425368 r=462425368 size=536870912)
        #1  0.000    1.451    3.607    5.058  Buffer(free=0.000% l=462425368 r=462425368 size=536870912)
        #2  0.000    1.278    1.235    2.513  Buffer(free=0.000% l=462425368 r=462425368 size=536870912)
        #3  0.000    1.426    1.927    3.353  Buffer(free=0.000% l=462425368 r=462425368 size=536870912)
        #4  0.000    1.452    1.074    2.526  Buffer(free=0.000% l=462425368 r=462425368 size=536870912)
        #5  0.000    1.422    3.204    4.625  Buffer(free=0.000% l=462425368 r=462425368 size=536870912)
        #6  0.000    1.445    1.953    3.398  Buffer(free=0.000% l=462425368 r=462425368 size=536870912)
        #7  0.000    1.582    0.507    2.090  Buffer(free=0.000% l=308283552 r=308283552 size=536870912)

    Meaning of the outputs:

    * progress: dataset loading progress (current/total)
    * batch: batch time, excluding data loading time
    * wait: time the main process spent waiting for worker processes
    * recv: time spent receiving batch data
    * to_jittor: time spent converting batch data to jittor variables
    * recv_raw_call: total number of underlying recv_raw calls
    * last 10 workers: ids of the last 10 workers the main process loaded from
    * table columns

      * ID: worker id
      * wait: worker wait time
      * open: worker image open time
      * load: worker load time
      * buffer: ring buffer status: free space, left index, right index, total size (bytes)

    Example::

        from jittor.dataset import Dataset

        class YourDataset(Dataset):
            pass

        dataset = YourDataset().set_attrs(num_workers=8)
        for x, y in dataset:
            dataset.display_worker_status()
    '''
    if not hasattr(self, "workers"):
        return
    msg = [""]
    msg.append(f"progress:{self.last_id}/{self.batch_len}")
    msg.append(f"batch(s): {self.batch_time:.3f}\twait(s):{self.wait_time:.3f}")
    msg.append(f"recv(s): {self.recv_time:.3f}\tto_jittor(s):{self.to_jittor_time:.3f}")
    msg.append(f"last 10 workers: {self.idmap[max(0, self.last_id-9):self.last_id+1]}")
    msg.append(f"ID\twait(s)\topen(s)\tload(s)\tsend(s)\ttotal(s)")
    for i in range(self.num_workers):
        w = self.workers[i]
        s = w.status
        msg.append(f"#{i}\t{s[0]:.3f}\t{s[4]:.3f}\t{s[1]:.3f}\t{s[2]:.3f}\t{s[3]:.3f}\t{w.buffer}")
    LOG.i('\n'.join(msg))
def gen_jit_flags():
    all_src = run_cmd('find -L src/ | grep "cc$"', jittor_path).splitlines()
    jit_declares = []
    re_def = re.compile("DEFINE_FLAG(_WITH_SETTER)?\\((.*?)\\);", re.DOTALL)
    flags_defs = []
    visit = {}

    for src_name in all_src:
        src_name = os.path.join(jittor_path, src_name)
        with open(src_name) as f:
            src = f.read()
        defs = re_def.findall(src)
        for _, args in defs:
            args = args.split(",")
            type = args[0].strip()
            name = args[1].strip()
            if not has_cuda and "cuda" in name and name != "use_cuda":
                continue
            default = args[2].strip()
            doc = ",".join(args[3:])
            doc = eval(f"({doc})")
            LOG.vv(f"Find define {name} from {src_name}")
            if name in visit:
                continue
            visit[name] = 1
            jit_declares.append(f"DECLARE_FLAG({type}, {name});")
            flags_defs.append(f"""
    /* {name}(type:{type}, default:{default}): {doc} */
    // @pyjt(__get__{name})
    {type} _get_{name}() {{ return {name}; }}
    // @pyjt(__set__{name})
    void _set_{name}({type} v) {{ set_{name}(v); }}
    {f'''// @pyjt(__set__{name})
    void _set_{name}(bool v) {{ set_{name}(v); }}
    ''' if type=="int" else ""}
    """)

    jit_declares = "\n    ".join(jit_declares)
    jit_src = f"""
    #include "utils/flags.h"

    namespace jittor {{

    {jit_declares}

    // @pyjt(flags)
    struct _Flags {{
        // @pyjt(__init__)
        _Flags() {{}}
        {"".join(flags_defs)}
    }};

    }} // jittor
    """
    LOG.vvvv(jit_src)
    with open(os.path.join(cache_path, "gen", "jit_flags.h"), 'w') as f:
        f.write(jit_src)
def install(path):
    LOG.i("Installing MSVC...")
    filename = "msvc.zip"
    url = "https://cg.cs.tsinghua.edu.cn/jittor/assets/" + filename
    md5sum = "55f0c175fdf1419b124e0fc498b659d2"
    download_url_to_local(url, filename, path, md5sum)
    fullname = os.path.join(path, filename)
    import zipfile
    with zipfile.ZipFile(fullname, "r") as f:
        f.extractall(path)
def compile_single(head_file_name, src_file_name, src=None):
    basename = head_file_name.split("/")[-1].split(".")[0]
    if src is None:
        with open(head_file_name, 'r') as f:
            src = f.read()
    code = compile_src(src, head_file_name, basename)
    if not code:
        return False
    LOG.vvv("write to", src_file_name)
    LOG.vvvv(code)
    with open(src_file_name, 'w') as f:
        f.write(code)
    return True
def __init__(self, base_name, rtol=5e-2, atol=1e-3):
    if os.environ.get("use_auto_diff", '1') == '0':
        return
    hook_rand()
    self.rid = 0
    self.base_name = base_name
    self.base_path = os.path.join(str(Path.home()), ".cache", "jittor", "auto_diff", base_name)
    os.makedirs(self.base_path, exist_ok=True)
    self.rtol = rtol
    self.atol = atol
    LOG.i("Use cache path:", self.base_path)
    LOG.i(f"rtol:{rtol} atol:{atol}")
def compile(cache_path, jittor_path):
    headers1 = glob.glob(jittor_path + "/src/**/*.h", recursive=True)
    headers2 = glob.glob(cache_path + "/gen/**/*.h", recursive=True)
    headers = headers1 + headers2
    basenames = []
    pyjt_names = []
    for h in headers:
        with open(h, 'r') as f:
            src = f.read()

        bh = os.path.basename(h)
        # jit_op_maker.h merge compile with var_holder.h
        if bh == "var_holder.h":
            continue
        if bh == "jit_op_maker.h":
            with open(os.path.join(jittor_path, "src", "var_holder.h"), "r") as f:
                src = f.read() + src
        basename = bh.split(".")[0]
        fname = "pyjt_" + basename + ".cc"
        fname = os.path.join(cache_path, "gen", fname)
        check = compile_single(h, fname, src)

        if not check:
            continue

        basenames.append(basename)
        pyjt_names.append(fname)

    code = f"""
    #include "pyjt/numpy.h"
    #include "pyjt/py_converter.h"
    #include "common.h"

    namespace jittor {{

    { " ".join([f"extern void pyjt_def_{n}(PyObject* m);" for n in basenames]) }

    void pyjt_def_all(PyObject* m) {{
        numpy_init();
        { " ".join([f"pyjt_def_{n}(m);" for n in basenames]) }
    }}

    }}
    """
    fname = os.path.join(cache_path, "gen", "pyjt_all.cc")
    LOG.vvv(("write to", fname))
    LOG.vvvv(code)
    with open(fname, "w") as f:
        f.write(code)
    pyjt_names.append(fname)
    return pyjt_names
def save_input(self, *data):
    '''
    Example::

        for input, label in torch_dataloader:
            hook.save_input(input, label)
    '''
    if self.mode == "save":
        self.record_status["[input]"] += 1
        fpath = os.path.join(
            self.base_path,
            f"__input-{self.record_status['[input]']}.pkl")
        with open(fpath, 'wb') as f:
            pickle.dump(convert(data), f)
        LOG.i(f"save input: ok")
    else:
        raise RuntimeError("save_input is invalid in [check] mode")
def compile(cache_path, jittor_path):
    headers1 = run_cmd('find -L src/ | grep ".h$"', jittor_path).splitlines()
    headers2 = run_cmd('find gen/ | grep ".h$"', cache_path).splitlines()
    headers = [os.path.join(jittor_path, h) for h in headers1] + \
              [os.path.join(cache_path, h) for h in headers2]
    basenames = []
    pyjt_names = []
    for h in headers:
        with open(h, 'r') as f:
            src = f.read()

        # jit_op_maker.h merge compile with var_holder.h
        if h.endswith("src/var_holder.h"):
            continue
        if h.endswith("jit_op_maker.h"):
            with open(os.path.join(jittor_path, "src", "var_holder.h"), "r") as f:
                src = f.read() + src
        basename = h.split("/")[-1].split(".")[0]
        fname = "pyjt_" + basename + ".cc"
        fname = os.path.join(cache_path, "gen", fname)
        check = compile_single(h, fname, src)

        if not check:
            continue

        basenames.append(basename)
        pyjt_names.append(fname)

    code = f"""
    #include "pyjt/numpy.h"
    #include "pyjt/py_converter.h"
    #include "common.h"

    namespace jittor {{

    { " ".join([f"extern void pyjt_def_{n}(PyObject* m);" for n in basenames]) }

    void pyjt_def_all(PyObject* m) {{
        numpy_init();
        { " ".join([f"pyjt_def_{n}(m);" for n in basenames]) }
    }}

    }}
    """
    fname = os.path.join(cache_path, "gen", "pyjt_all.cc")
    LOG.vvv(("write to", fname))
    LOG.vvvv(code)
    with open(fname, "w") as f:
        f.write(code)
    pyjt_names.append(fname)
    return pyjt_names
def compile_custom_ops(filenames, extra_flags=""):
    """Compile custom ops

    filenames: paths of op source files; filenames must be pairs of
        xxx_xxx_op.cc and xxx_xxx_op.h, and the type name of the op must
        be XxxXxxOp.
    extra_flags: extra compile flags
    return: compiled ops
    """
    srcs = {}
    headers = {}
    builds = []
    includes = []
    for name in filenames:
        name = os.path.realpath(name)
        if name.endswith(".cc") or name.endswith(".cpp") or name.endswith(".cu"):
            builds.append(name)
        bname = os.path.basename(name)
        bname = os.path.splitext(bname)[0]
        if bname.endswith("_op"):
            bname = bname[:-3]
            if name.endswith(".cc"):
                srcs[bname] = name
            elif name.endswith(".h"):
                includes.append(os.path.dirname(name))
                headers[bname] = name
    assert len(srcs) == len(headers), "Source and header names not match"
    for name in srcs:
        assert name in headers, f"Header of op {name} not found"
    gen_name = "gen_ops_" + "_".join(headers.keys())
    if len(gen_name) > 100:
        gen_name = gen_name[:80] + "___hash" + str(hash(gen_name))

    includes = set(includes)
    includes = "".join(map(lambda x: f" -I'{x}' ", includes))
    LOG.vvvv(f"Include flags:{includes}")
    op_extra_flags = includes + extra_flags

    gen_src = gen_jit_op_maker(headers.values(), export=gen_name, extra_flags=op_extra_flags)
    make_cache_dir(os.path.join(cache_path, "custom_ops"))
    gen_src_fname = os.path.join(cache_path, "custom_ops", gen_name + ".cc")
    gen_head_fname = os.path.join(cache_path, "custom_ops", gen_name + ".h")
    gen_lib = os.path.join("custom_ops", gen_name + extension_suffix)
    with open(gen_head_fname, "w") as f:
        f.write(gen_src)
    pyjt_compiler.compile_single(gen_head_fname, gen_src_fname)

    # gen src initialize first
    builds.insert(0, gen_src_fname)
    LOG.vvv(f"Build custom ops lib:{gen_lib}")
    LOG.vvvv(f"Build sources:{builds}")
    compile(cc_path, cc_flags + opt_flags + includes + extra_flags, builds, gen_lib)

    # add python path and import
    LOG.vvv(f"Import custom ops lib:{gen_lib}")
    lib_path = os.path.join(cache_path, "custom_ops")
    if lib_path not in os.sys.path:
        os.sys.path.append(lib_path)
    with jit_utils.import_scope(os.RTLD_GLOBAL | os.RTLD_NOW | os.RTLD_DEEPBIND):
        exec(f"import {gen_name}")
    return (locals()[gen_name]).ops
def gen_jit_tests():
    all_src = run_cmd('find -L src/ | grep "cc$"', jittor_path).splitlines()
    jit_declares = []
    re_def = re.compile("JIT_TEST\\((.*?)\\)")
    names = set()
    test_defs = []

    for src_name in all_src:
        src_name = os.path.join(jittor_path, src_name)
        with open(src_name) as f:
            src = f.read()
        defs = re_def.findall(src)
        for name in defs:
            LOG.vv(f"Find test {name} from {src_name}")
            assert name not in names, f"Conflict test name {name}"
            names.add(name)
            jit_declares.append(f"JIT_TEST({name});")
            test_defs.append(f"""
            /* From {src_name} */
            // @pyjt({name})
            static inline void test_{name}() {{ jit_test_{name}(); }}
            """)

    jit_declares = "\n    ".join(jit_declares)
    jit_src = f"""
    #pragma once
    #include "common.h"

    void expect_error(std::function<void()> func) {{
        try {{ func(); }}
        catch (...) {{ return; }}
        CHECK(0) << "Missing error";
    }}

    namespace jittor {{

    {jit_declares}

    // @pyjt(tests)
    // @attrs(submodule)
    namespace tests {{
        {"".join(test_defs)}
    }}

    }} // jittor
    """
    LOG.vvvv(jit_src)
    with open(os.path.join(cache_path, "gen", "jit_tests.h"), 'w') as f:
        f.write(jit_src)
def install_cuda():
    cuda_driver_version = get_cuda_driver()
    if not cuda_driver_version:
        return None
    LOG.i("cuda_driver_version: ", cuda_driver_version)

    if cuda_driver_version >= [11, 2]:
        cuda_tgz = "cuda11.2_cudnn8_linux.tgz"
        md5 = "b93a1a5d19098e93450ee080509e9836"
    elif cuda_driver_version >= [11]:
        cuda_tgz = "cuda11.0_cudnn8_linux.tgz"
        md5 = "5dbdb43e35b4db8249027997720bf1ca"
    elif cuda_driver_version >= [10, 2]:
        cuda_tgz = "cuda10.2_cudnn7_linux.tgz"
        md5 = "40f0563e8eb176f53e55943f6d212ad7"
    elif cuda_driver_version >= [10]:
        cuda_tgz = "cuda10.0_cudnn7_linux.tgz"
        md5 = "f16d3ff63f081031d21faec3ec8b7dac"
    else:
        raise RuntimeError(
            f"Unsupported cuda driver version: {cuda_driver_version}")

    jtcuda_path = os.path.join(pathlib.Path.home(), ".cache", "jittor", "jtcuda")
    nvcc_path = os.path.join(jtcuda_path, cuda_tgz[:-4], "bin", "nvcc")
    nvcc_lib_path = os.path.join(jtcuda_path, cuda_tgz[:-4], "lib64")
    sys.path.append(nvcc_lib_path)
    new_ld_path = os.environ.get("LD_LIBRARY_PATH", "") + ":" + nvcc_lib_path
    os.environ["LD_LIBRARY_PATH"] = new_ld_path

    if os.path.isfile(nvcc_path):
        return nvcc_path

    os.makedirs(jtcuda_path, exist_ok=True)
    cuda_tgz_path = os.path.join(jtcuda_path, cuda_tgz)
    download_url_to_local(
        "https://cg.cs.tsinghua.edu.cn/jittor/assets/" + cuda_tgz,
        cuda_tgz, jtcuda_path, md5)

    import tarfile
    with tarfile.open(cuda_tgz_path, "r") as tar:
        tar.extractall(cuda_tgz_path[:-4])

    assert os.path.isfile(nvcc_path)
    return nvcc_path
def hook_rand():
    global rand_hooked
    if rand_hooked:
        return
    rand_hooked = True
    np.random.seed(0)
    if "torch" in sys.modules:
        LOG.i("Hook torch.rand")
        torch = sys.modules["torch"]
        torch.rand = hook_pt_rand
        torch.normal = hook_pt_normal
        torch.manual_seed(0)
    if "jittor" in sys.modules:
        jittor = sys.modules["jittor"]
        LOG.i("Hook jittor.random")
        jittor.random = hook_jt_rand
        jittor.seed(0)
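# Hedged usage note (not from the original source; the builder names below are
# placeholders): hook_rand() is meant to be called once, before any model is
# constructed, so that torch and jittor draw from the same seeded generators
# and their random initializations can be compared element-wise.
#
#   hook_rand()
#   torch_model = build_torch_model()     # placeholder builder
#   jittor_model = build_jittor_model()   # placeholder builder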
def load_input(self):
    '''
    Example::

        for fake_input, fake_label in jittor_dataset:
            input, label = hook.load_input()
            input = jt.array(input)
            label = jt.array(label)
    '''
    if self.mode == "check":
        self.record_status["[input]"] += 1
        fpath = os.path.join(
            self.base_path,
            f"__input-{self.record_status['[input]']}.pkl")
        with open(fpath, 'rb') as f:
            data = pickle.load(f)
        LOG.i(f"load input: ok")
        return data
    else:
        raise RuntimeError("load_input is invalid in [save] mode")
def __init__(self, root, transform=None):
    super().__init__()
    self.root = root
    self.transform = transform
    self.classes = sorted([d.name for d in os.scandir(root) if d.is_dir()])
    self.class_to_idx = {v: k for k, v in enumerate(self.classes)}
    self.imgs = []
    image_exts = set(('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'))

    for i, class_name in enumerate(self.classes):
        class_dir = os.path.join(root, class_name)
        for dname, _, fnames in sorted(os.walk(class_dir, followlinks=True)):
            for fname in sorted(fnames):
                if os.path.splitext(fname)[-1].lower() in image_exts:
                    path = os.path.join(class_dir, fname)
                    self.imgs.append((path, i))
    LOG.i(f"Found {len(self.classes)} classes and {len(self.imgs)} images.")
    self.set_attrs(total_len=len(self.imgs))
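# Illustrative usage (hedged; the class name ImageFolder and the paths below
# are placeholders inferred from the constructor above): the loader expects one
# sub-directory per class under `root`, e.g.
#
#   root/cat/001.jpg
#   root/cat/002.png
#   root/dog/123.jpg
#
# and labels each image by the alphabetical index of its class directory.
#
#   dataset = ImageFolder("/path/to/root").set_attrs(batch_size=32, shuffle=True)
#   for images, labels in dataset:
#       ...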
def check_cache_compile():
    files = [
        "src/utils/cache_compile.cc",
        "src/utils/log.cc",
        "src/utils/tracer.cc",
        "src/utils/jit_utils.cc",
    ]
    global jit_utils_core_files
    jit_utils_core_files = files
    recompile = compile(cc_path, cc_flags + f" {opt_flags} ", files,
                        'jit_utils_core' + extension_suffix, True)
    if recompile and jit_utils.cc:
        LOG.e("jit_utils updated, please restart jittor.")
        sys.exit(0)
    if not jit_utils.cc:
        with jit_utils.import_scope(import_flags):
            jit_utils.try_import_jit_utils_core()
        assert jit_utils.cc
        # recompile, generate cache key
        compile(cc_path, cc_flags + f" {opt_flags} ", files,
                'jit_utils_core' + extension_suffix, True)
def compile_extern():
    # compile llvm passes
    if cc_type != "clang":
        return
    global kernel_opt_flags
    cache_path_llvm = os.path.join(cache_path, "llvm")
    jittor_path_llvm = os.path.join(jittor_path, "extern", "llvm")
    clang_dir = os.path.dirname(get_full_path_of_executable(cc_path))
    assert clang_dir.endswith("bin") and "llvm" in clang_dir, f"Wrong clang_dir: {clang_dir}"
    llvm_include = os.path.abspath(os.path.join(clang_dir, "..", "include"))
    assert os.path.isdir(llvm_include), "LLVM include path not found"
    make_cache_dir(cache_path_llvm)
    files = os.listdir(jittor_path_llvm)

    # test_pass.cc is used to test the link problem of the llvm pass plugin
    test_pass_path = os.path.join(cache_path_llvm, "test_pass.cc")
    with open(test_pass_path, 'w') as f:
        f.write("int main() {return 0;}")

    # -fno-rtti fixes a link error
    # -Wl,-znodelete fixes a segfault
    # https://github.com/sampsyo/llvm-pass-skeleton/issues/7#issuecomment-401834287
    # -D_GLIBCXX_USE_CXX11_ABI=0 fixes undefined symbol: createPrinterPass
    # https://stackoverflow.com/questions/37366291/undefined-symbol-for-self-built-llvm-opt
    # try different flags
    try_flags = [
        " -Wl,-znodelete -D_GLIBCXX_USE_CXX11_ABI=0 ",
        " -Wl,-znodelete ",
    ]
    found_flags_id = -1
    for fname in files:
        for i, flag in enumerate(try_flags):
            if found_flags_id != -1 and found_flags_id != i:
                continue
            so_name = os.path.join(cache_path_llvm, os.path.splitext(fname)[0] + f".{i}.so")
            compile(cc_path, f"{cc_flags} {opt_flags} {flag} -I'{llvm_include}'",
                    [os.path.join(jittor_path_llvm, fname)], so_name)
            # if no working flag has been found yet, test this one
            if found_flags_id == -1:
                try:
                    s = run_cmd(
                        f"{cc_path} {cc_flags} -Xclang -load -Xclang '{so_name}' {test_pass_path}",
                        cache_path_llvm, print_error=False)
                except Exception as e:
                    LOG.v(f"Try flag {flag} failed: {e}")
                    continue
                found_flags_id = i
            kernel_opt_flags += f" -Xclang -load -Xclang '{so_name}' "
            break
        else:
            LOG.w("Clang is used, but LLVM pass plugin is unable to link.")
            break
    LOG.vv(f"Compile extern llvm passes: {str(files)}")
def hook_module(self, mod, mod_name=""):
    if os.environ.get("use_auto_diff", '1') == '0':
        return
    if mod_name != "":
        mod_name = "<" + mod_name + ">"
    self.hooked_models[mod_name] = mod

    def forward_hook(self2, input, output, kw=None):
        ex_name = '[' + self2.__class__.__name__ + ']'
        if "relu" not in self2.__class__.__name__.lower():
            # do not test relu, because its input may be modified in place
            self.record(self2.__ad_mod_name__ + ".input", input, ex_name)
        self.record(self2.__ad_mod_name__ + ".output", output, ex_name)
        if kw is not None:
            self.record(self2.__ad_mod_name__ + ".kw", kw, ex_name)

    names = []
    for name, module in mod.named_modules():
        ns = name.split('.')
        skip = 0
        for n in ns:
            if n.startswith('_'):
                skip = 1
        if skip:
            LOG.i("skip", name)
            continue
        name = mod_name + name
        module.__ad_mod_name__ = name
        names.append(name)
        module.register_forward_hook(forward_hook)
        mod_class_name = module.__class__.__name__.lower()
        # put dropout in eval mode and record dropout.p
        if "dropout" in mod_class_name:
            self.record(name + '.p', module.p, "[" + mod_class_name + "]")
            module.eval()
    ps = {mod_name + k: v for k, v in mod.state_dict().items()}
    self.record_params(ps, mod_name)
    self.record("module names", names)
def run():
    start_time = time.time()
    fop_num = 10000
    fop_input_num = (2, 3)  # (i,j) -> [i,i+j] -> [2, 5]
    # fop_output_num = (1, 0)  # [1,1]
    inner_op_num = (0, 3)
    fop_type_num = 63  # how many different fused op types
    input_queue_num = 15
    queue = [1.0] * (input_queue_num + 1)
    x = get_xorshf96()
    rand = lambda x, l, r: l + ((x()) & r)
    ops = ["add", "subtract", "multiply", "divide"]
    get_op = lambda x: ops[(x()) & 3]

    for i in range(fop_num):
        prev = bc(queue[rand(x, 0, input_queue_num)])
        y = get_xorshf96(x() & fop_type_num)
        inum = rand(y, *fop_input_num)
        q = [prev]
        for i in range(inum - 1):
            n = bc(queue[rand(x, 0, input_queue_num)])
            prev = jt.binary(prev, n, get_op(y))
            q.append(prev)
        innum = rand(y, *inner_op_num)
        for _ in range(innum):
            j = rand(y, 0, len(q) - 1)
            n = q[j]
            prev = jt.binary(prev, n, get_op(y))
            q[j] = prev
        prev = rd(prev)
        queue[rand(x, 0, input_queue_num)] = prev

    a = jt.array(0.0)
    for x in queue:
        a += x
    LOG.i("build graph", time.time() - start_time, jt.liveness_info().values())
    start_time = time.time()
    a.sync()
    LOG.i("execute", time.time() - start_time)
def check_array(self, name, a, b):
    rtol = self.rtol
    atol = self.atol
    global has_error
    err = np.abs(a - b)
    tol = atol + rtol * np.abs(b)
    is_error = np.logical_or(err > tol, (a >= -1e-5) != (b >= -1e-5))
    index = np.where(is_error)
    assert len(index) > 0
    if len(index[0]) == 0:
        return

    has_error += 1
    LOG.w(f"Ndarray <{name}> not match, shape:{a.shape}")
    i = tuple(i[0] for i in index)
    err_rate = is_error.mean()
    LOG.w(f"error index at [{i}], a({a[i]}) b({b[i]}) err({err[i]}) > tol({tol[i]}), "
          f"err_rate:{err_rate*100:.3f}% amean({a.mean()}) bmean({b.mean()}) "
          f"astd({a.std()}) bstd({b.std()}) ")
    if err_rate > 0.01:
        LOG.e("!"*10 + "Very HIGH err rate" + "!"*10)
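# Worked example of the tolerance rule used above (illustrative values, not from
# the original source): an element is flagged when |a - b| > atol + rtol*|b|, or
# when a and b fall on different sides of (approximately) zero.
import numpy as np

_a = np.array([1.000, 2.000, -0.5])
_b = np.array([1.001, 2.500,  0.5])
_atol, _rtol = 1e-3, 5e-2
_err = np.abs(_a - _b)                    # [0.001, 0.5, 1.0]
_tol = _atol + _rtol * np.abs(_b)         # [0.051, 0.126, 0.026]
_is_error = np.logical_or(_err > _tol, (_a >= -1e-5) != (_b >= -1e-5))
# -> [False, True, True]: only the first pair is within tolerance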
def generate_error_code_from_func_header(func_head, target_scope_name, name, dfs, basename, h, class_info):
    # func_head is a string like:
    # (PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject*
    lib_name = os.path.basename(h).split("_")[0]
    # TODO: fix/add var help
    if target_scope_name == "Var":
        target_scope_name = None
    if target_scope_name:
        if target_scope_name == "flags":
            help_name = "flags"
        else:
            help_name = "" + target_scope_name + '.' + name
    else:
        help_name = name
    if lib_name in ["mpi", "nccl", "cudnn", "curand", "cublas", "mkl"]:
        help_name = lib_name + '.' + help_name
    help_cmd = f"help(jt.{help_name})"

    LOG.vvv("gen err from func_head", func_head)
    args = func_head[1:].split(")")[0].split(",")
    error_code = f" << \"Wrong inputs arguments, Please refer to examples(e.g. {help_cmd}).\""
    error_code += r' << "\n\nTypes of your inputs are:\n"'
    for arg in args:
        arg = arg.strip()
        if arg.startswith("PyObject* "):
            t, n = arg.split(' ')
            if n == "args" or n == "_args":
                error_code += f" << PyTupleArgPrinter{{{n}, \"args\"}} "
            elif n == "kw":
                error_code += f" << PyKwArgPrinter{{{n}}} "
            else:
                error_code += f" << PyArgPrinter{{{n}, \"{n}\"}} "
        elif arg.startswith("PyObject** "):
            t, n = arg.split(' ')
            error_code += f" << PyFastCallArgPrinter{{{n}, n, kw}} "
            break
        else:
            LOG.vvv("Unhandled arg", arg)
    LOG.vvv("gen err from func_head", func_head, " -> ", error_code)
    return error_code
def __init__(self, base_name, rtol=5e-2, atol=1e-3):
    if os.environ.get("use_auto_diff", '1') == '0':
        return
    hook_rand()
    self.rid = 0
    self.base_name = base_name
    self.base_path = os.path.join(str(Path.home()), ".cache", "jittor", "auto_diff", base_name)
    if not os.path.exists(self.base_path):
        os.makedirs(self.base_path, exist_ok=True)
        self.mode = 'save'
    else:
        self.mode = 'check'
    self.record_status = defaultdict(int)
    self.rtol = rtol
    self.atol = atol
    self.param_name_map = {}
    self.hooked_models = {}
    LOG.i(f"Jittor AutoDiff: [{self.mode}] mode")
    LOG.i("Use cache path:", self.base_path)
    LOG.i(f"rtol:{rtol} atol:{atol}")
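# Hedged usage sketch (the class name Hook, the model, and the record name are
# placeholders for illustration): the mode is chosen automatically by the
# directory check above. The first run (cache directory absent) records tensors
# under ~/.cache/jittor/auto_diff/<base_name>; a second run with the same
# base_name starts in [check] mode and compares against the recorded values.
#
#   hook = Hook("example_model")        # run 1: save, run 2: check
#   hook.hook_module(model)
#   out = model(inputs)
#   hook.record("output", out)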
def record(self, name, data, ex_name=""):
    if os.environ.get("use_auto_diff", '1') == '0':
        return
    rid = self.rid
    self.rid += 1
    fpath = os.path.join(self.base_path, f"{rid}.pkl")
    data = convert(data)
    if os.path.isfile(fpath):
        with open(fpath, 'rb') as f:
            pre_name, pre_data = pickle.load(f)
        if pre_name != name:
            global has_error
            has_error += 1
            LOG.e(f"The {rid} result name not match, {pre_name} != {name}")
            self.rid -= 1
            return
        LOG.i(f"check {rid}:<{ex_name}{name}> ...")
        self.check(ex_name + name, pre_data, data)
    else:
        with open(fpath, 'wb') as f:
            pickle.dump((name, data), f)
        LOG.i(f"save {rid}:<{name}> ok")
def record(self, name, data, ex_name=""):
    if os.environ.get("use_auto_diff", '1') == '0':
        return
    self.record_status[name] += 1
    fpath = os.path.join(self.base_path, f"{name}-{self.record_status[name]}.pkl")
    data = convert(data)
    self.rid += 1

    if self.mode == 'check':
        if os.path.isfile(fpath):
            with open(fpath, 'rb') as f:
                pre_name, pre_data = pickle.load(f)
            LOG.i(f"check {self.rid}:<{ex_name}{name}> ...")
            self.check(ex_name + name, pre_data, data)
        else:
            global has_error
            has_error += 1
            LOG.e(f"No previous result found: {name}")
            return
    else:
        with open(fpath, 'wb') as f:
            pickle.dump((name, data), f)
        LOG.i(f"save {self.rid}:<{name}> ok")
def try_find_exe(*args):
    try:
        return find_exe(*args)
    except:
        LOG.v(f"{args[0]} not found.")
        return ""