def check(self, name, pre_data, data):
    global has_error
    if pre_data is None and isinstance(data, np.ndarray):
        if (data == 0).all():
            LOG.i(f"name {name} is None")
        else:
            LOG.e(f"name {name} is non-zero")
        return
    if type(pre_data) != type(data):
        LOG.e(f"type not match, {pre_data.__class__.__name__}!={data.__class__.__name__}, name: {name}")
        has_error += 1
        return
    if isinstance(pre_data, (list, tuple)):
        if len(pre_data) != len(data):
            has_error += 1
            LOG.e(f"Name <{name}> len not match, {len(pre_data)} != {len(data)}")
        n = max(len(pre_data), len(data))
        for i in range(n):
            a = pre_data[i] if i < len(pre_data) else "None"
            b = data[i] if i < len(data) else "None"
            self.check(name + f".{i}", a, b)
    elif isinstance(pre_data, np.ndarray):
        if len(pre_data.shape) == 0:
            pre_data = np.array([pre_data])
        if len(data.shape) == 0:
            data = np.array([data])
        if pre_data.shape != data.shape:
            has_error += 1
            LOG.e(f"Ndarray shape <{name}> not match {pre_data.shape} != {data.shape}")
            return
        self.check_array(name, pre_data, data)
    elif isinstance(pre_data, dict):
        if len(pre_data) != len(data):
            has_error += 1
            LOG.w(f"Dict Name <{name}> len not match, {len(pre_data)} != {len(data)}")
        for k in pre_data:
            pv = pre_data[k]
            if k not in data:
                has_error += 1
                msg = f"Key <{k}> not in data, Name <{name}>"
                if isinstance(pv, np.ndarray):
                    LOG.e(msg)
                else:
                    LOG.w(msg)
                continue
            self.check(name + f".{k}", pre_data[k], data[k])
    else:
        if pre_data != data:
            has_error += 1
            LOG.e(f"Type: {type(pre_data).__name__} Name <{name}> not match {pre_data} != {data}")
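# Behavior sketch (hedged, not part of the class): check recurses through
# lists/tuples/dicts and compares ndarray leaves via check_array; every
# mismatch bumps the global has_error counter. A hypothetical hook instance
# would report only the "b" leaf here:
#   hook.check("conv1",
#              {"w": np.zeros((3, 3)), "b": np.zeros(3)},   # saved run
#              {"w": np.zeros((3, 3)), "b": np.ones(3)})    # current run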
def record_params(self, parameters_dict):
    if os.environ.get("use_auto_diff", '1') == '0':
        return
    rid = self.rid
    self.rid += 1
    global has_error
    pps = {}
    for k, v in parameters_dict.items():
        if k.endswith("num_batches_tracked"):
            continue
        pps[k] = v
    ps = {name: convert(param) for name, param in pps.items()}
    fpath = os.path.join(self.base_path, f"{rid}-params.pkl")
    if os.path.isfile(fpath):
        with open(fpath, 'rb') as f:
            prev_ps = pickle.load(f)
        if len(prev_ps) != len(ps):
            has_error += 1
            LOG.e(f"Params len not match {len(prev_ps)} != {len(ps)}")
        for k in ps:
            a = ps[k]
            if k not in prev_ps:
                has_error += 1
                LOG.e(f"prev param <{k}> not found.")
                continue
            b = prev_ps[k]
            if a.shape != b.shape:
                has_error += 1
                LOG.e(f"Params <{k}> shape not match {a.shape} != {b.shape}")
                continue
            std_a, mean_a = a.std(), a.mean()
            std_b, mean_b = b.std(), b.mean()
            n = a.size
            # law of large numbers: estimate the sampling std of the mean and of the std
            std_mean_a = (std_a + std_b) / 2 / np.sqrt(n) + 1e-6
            std_std_a = (std_a + std_b) / 2 / np.sqrt((n - 1) / 2) + 1e-6
            x = 4
            if np.abs(mean_a - mean_b) > x * std_mean_a:
                has_error += 1
                LOG.e(f"param mean not match, mean_a:{mean_a}, mean_b:{mean_b}, acceptable range:({mean_a - x * std_mean_a}, {mean_a + x * std_mean_a}) name:{k} shape:{a.shape}")
            elif np.abs(std_a - std_b) > x * std_std_a:
                has_error += 1
                LOG.e(f"param std not match, std_a:{std_a}, std_b:{std_b}, acceptable range:({std_a - x * std_std_a}, {std_a + x * std_std_a}) name:{k} shape:{a.shape}")
            else:
                LOG.i(f"check param ok: <{k}> shape:{a.shape}")
            var = pps[k]
            if hasattr(var, "copy_"):
                # PyTorch tensor: copy the saved value in place
                import torch
                var.data.copy_(torch.from_numpy(b))
            else:
                # Jittor variable
                var.assign(b)
    else:
        with open(fpath, 'wb') as f:
            pickle.dump(ps, f)
        LOG.i(f"save params ok")
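# A worked sketch of the tolerance used above (standalone, not part of the
# hook): the standard error of the mean of n samples is roughly std/sqrt(n),
# and record_params accepts a 4-sigma band around it.
import numpy as np
a = np.random.randn(10000).astype(np.float32)
b = a + 1e-4 * np.random.randn(10000).astype(np.float32)
std_mean = (a.std() + b.std()) / 2 / np.sqrt(a.size) + 1e-6
assert abs(a.mean() - b.mean()) <= 4 * std_mean  # tiny drift stays inside the band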
def env_or_try_find(name, bname):
    if name in os.environ:
        path = os.environ[name]
        if path != "":
            version = jit_utils.get_version(path)
            LOG.i(f"Found {bname}{version} at {path}")
        return path
    return try_find_exe(bname)
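# Usage sketch (hypothetical variable value): an explicit environment
# override wins; otherwise the binary is searched for on PATH.
#   export cc_path=/usr/bin/g++-9      # optional pin
cc = env_or_try_find('cc_path', 'g++')  # assumes some g++ is discoverable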
def display_worker_status(self):
    ''' Display dataset worker status, when dataset.num_workers > 0, it will display the information below:

.. code-block:: console

        progress:479/5005
        batch(s): 0.302	wait(s):0.000
        recv(s): 0.069	to_jittor(s):0.021
        recv_raw_call: 6720.0
        last 10 workers: [6, 7, 3, 0, 2, 4, 7, 5, 6, 1]
        ID	wait(s)	load(s)	send(s)	total
        #0	0.000	1.340	2.026	3.366	Buffer(free=0.000% l=462425368 r=462425368 size=536870912)
        #1	0.000	1.451	3.607	5.058	Buffer(free=0.000% l=462425368 r=462425368 size=536870912)
        #2	0.000	1.278	1.235	2.513	Buffer(free=0.000% l=462425368 r=462425368 size=536870912)
        #3	0.000	1.426	1.927	3.353	Buffer(free=0.000% l=462425368 r=462425368 size=536870912)
        #4	0.000	1.452	1.074	2.526	Buffer(free=0.000% l=462425368 r=462425368 size=536870912)
        #5	0.000	1.422	3.204	4.625	Buffer(free=0.000% l=462425368 r=462425368 size=536870912)
        #6	0.000	1.445	1.953	3.398	Buffer(free=0.000% l=462425368 r=462425368 size=536870912)
        #7	0.000	1.582	0.507	2.090	Buffer(free=0.000% l=308283552 r=308283552 size=536870912)

Meaning of the outputs:

* progress: dataset loading progress (current/total)
* batch: batch time, excluding data loading time
* wait: time the main process waits for worker processes
* recv: time spent receiving batch data
* to_jittor: time spent converting batch data to jittor variables
* recv_raw_call: total number of underlying recv_raw calls
* last 10 workers: ids of the last 10 workers the main process loaded from
* table meaning
    * ID: worker id
    * wait: worker wait time
    * open: worker image open time
    * load: worker load time
    * buffer: ring buffer status: free space, left index, right index, total size (bytes)

Example::

    from jittor.dataset import Dataset

    class YourDataset(Dataset):
        pass

    dataset = YourDataset().set_attrs(num_workers=8)
    for x, y in dataset:
        dataset.display_worker_status()
    '''
    if not hasattr(self, "workers"):
        return
    msg = [""]
    msg.append(f"progress:{self.last_id}/{self.batch_len}")
    msg.append(f"batch(s): {self.batch_time:.3f}\twait(s):{self.wait_time:.3f}")
    msg.append(f"recv(s): {self.recv_time:.3f}\tto_jittor(s):{self.to_jittor_time:.3f}")
    msg.append(f"last 10 workers: {self.idmap[max(0, self.last_id-9):self.last_id+1]}")
    msg.append(f"ID\twait(s)\topen(s)\tload(s)\tsend(s)\ttotal(s)")
    for i in range(self.num_workers):
        w = self.workers[i]
        s = w.status
        msg.append(f"#{i}\t{s[0]:.3f}\t{s[4]:.3f}\t{s[1]:.3f}\t{s[2]:.3f}\t{s[3]:.3f}\t{w.buffer}")
    LOG.i('\n'.join(msg))
def install(path):
    LOG.i("Installing MSVC...")
    filename = "msvc.zip"
    url = "https://cg.cs.tsinghua.edu.cn/jittor/assets/" + filename
    md5sum = "55f0c175fdf1419b124e0fc498b659d2"
    download_url_to_local(url, filename, path, md5sum)
    fullname = os.path.join(path, filename)
    import zipfile
    with zipfile.ZipFile(fullname, "r") as f:
        f.extractall(path)
def __init__(self, base_name, rtol=5e-2, atol=1e-3):
    if os.environ.get("use_auto_diff", '1') == '0':
        return
    hook_rand()
    self.rid = 0
    self.base_name = base_name
    self.base_path = os.path.join(str(Path.home()), ".cache", "jittor",
                                  "auto_diff", base_name)
    os.makedirs(self.base_path, exist_ok=True)
    self.rtol = rtol
    self.atol = atol
    LOG.i("Use cache path:", self.base_path)
    LOG.i(f"rtol:{rtol} atol:{atol}")
def save_input(self, *data):
    '''
    for input, label in torch_dataloader:
        hook.save_input(input, label)
    '''
    if self.mode == "save":
        self.record_status["[input]"] += 1
        fpath = os.path.join(self.base_path,
                             f"__input-{self.record_status['[input]']}.pkl")
        with open(fpath, 'wb') as f:
            pickle.dump(convert(data), f)
        LOG.i(f"save input: ok")
    else:
        raise RuntimeError("save_input is invalid in [check] mode")
def install_cuda():
    cuda_driver_version = get_cuda_driver()
    if not cuda_driver_version:
        return None
    LOG.i("cuda_driver_version: ", cuda_driver_version)

    if cuda_driver_version >= [11, 2]:
        cuda_tgz = "cuda11.2_cudnn8_linux.tgz"
        md5 = "b93a1a5d19098e93450ee080509e9836"
    elif cuda_driver_version >= [11]:
        cuda_tgz = "cuda11.0_cudnn8_linux.tgz"
        md5 = "5dbdb43e35b4db8249027997720bf1ca"
    elif cuda_driver_version >= [10, 2]:
        cuda_tgz = "cuda10.2_cudnn7_linux.tgz"
        md5 = "40f0563e8eb176f53e55943f6d212ad7"
    elif cuda_driver_version >= [10]:
        cuda_tgz = "cuda10.0_cudnn7_linux.tgz"
        md5 = "f16d3ff63f081031d21faec3ec8b7dac"
    else:
        raise RuntimeError(f"Unsupported cuda driver version: {cuda_driver_version}")
    jtcuda_path = os.path.join(pathlib.Path.home(), ".cache", "jittor", "jtcuda")
    nvcc_path = os.path.join(jtcuda_path, cuda_tgz[:-4], "bin", "nvcc")
    nvcc_lib_path = os.path.join(jtcuda_path, cuda_tgz[:-4], "lib64")
    sys.path.append(nvcc_lib_path)
    new_ld_path = os.environ.get("LD_LIBRARY_PATH", "") + ":" + nvcc_lib_path
    os.environ["LD_LIBRARY_PATH"] = new_ld_path

    if os.path.isfile(nvcc_path):
        return nvcc_path

    os.makedirs(jtcuda_path, exist_ok=True)
    cuda_tgz_path = os.path.join(jtcuda_path, cuda_tgz)
    download_url_to_local("https://cg.cs.tsinghua.edu.cn/jittor/assets/" + cuda_tgz,
                          cuda_tgz, jtcuda_path, md5)
    import tarfile
    with tarfile.open(cuda_tgz_path, "r") as tar:
        tar.extractall(cuda_tgz_path[:-4])

    assert os.path.isfile(nvcc_path)
    return nvcc_path
def hook_rand():
    global rand_hooked
    if rand_hooked:
        return
    rand_hooked = True
    np.random.seed(0)
    if "torch" in sys.modules:
        LOG.i("Hook torch.rand")
        torch = sys.modules["torch"]
        torch.rand = hook_pt_rand
        torch.normal = hook_pt_normal
        torch.manual_seed(0)
    if "jittor" in sys.modules:
        jittor = sys.modules["jittor"]
        LOG.i("Hook jittor.random")
        jittor.random = hook_jt_rand
        jittor.seed(0)
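# Minimal determinism sketch: once hooked, NumPy draws from a fixed seed, so
# the save run and the check run of auto_diff see identical "random" data.
import numpy as np
hook_rand()                              # idempotent: hooks and seeds once
np.random.seed(0)
first = np.random.rand(4)
np.random.seed(0)                        # replaying the seed reproduces the stream
assert (first == np.random.rand(4)).all()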
def load_input(self):
    '''
    for fake_input, fake_label in jittor_dataset:
        input, label = hook.load_input()
        input = jt.array(input)
        label = jt.array(label)
    '''
    if self.mode == "check":
        self.record_status["[input]"] += 1
        fpath = os.path.join(self.base_path,
                             f"__input-{self.record_status['[input]']}.pkl")
        with open(fpath, 'rb') as f:
            data = pickle.load(f)
        LOG.i(f"load input: ok")
        return data
    else:
        raise RuntimeError("load_input is invalid in [save] mode")
def __init__(self, root, transform=None):
    super().__init__()
    self.root = root
    self.transform = transform
    self.classes = sorted([d.name for d in os.scandir(root) if d.is_dir()])
    self.class_to_idx = {v: k for k, v in enumerate(self.classes)}
    self.imgs = []
    image_exts = set(('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'))

    for i, class_name in enumerate(self.classes):
        class_dir = os.path.join(root, class_name)
        for dname, _, fnames in sorted(os.walk(class_dir, followlinks=True)):
            for fname in sorted(fnames):
                if os.path.splitext(fname)[-1].lower() in image_exts:
                    # join with the walked directory (not class_dir), so images
                    # in nested subfolders resolve to an existing path
                    path = os.path.join(dname, fname)
                    self.imgs.append((path, i))
    LOG.i(f"Found {len(self.classes)} classes and {len(self.imgs)} images.")
    self.set_attrs(total_len=len(self.imgs))
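# Usage sketch, assuming the surrounding class is the ImageFolder-style
# dataset (class name hypothetical) and root follows root/<class_name>/<image>:
import os, tempfile
root = tempfile.mkdtemp()
os.makedirs(os.path.join(root, "cat"))
open(os.path.join(root, "cat", "0.jpg"), "wb").close()
ds = ImageFolder(root)
assert ds.class_to_idx == {"cat": 0}
assert ds.imgs == [(os.path.join(root, "cat", "0.jpg"), 0)]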
def record(self, name, data, ex_name=""): if os.environ.get("use_auto_diff", '1') == '0': return rid = self.rid self.rid += 1 fpath = os.path.join(self.base_path, f"{rid}.pkl") data = convert(data) if os.path.isfile(fpath): with open(fpath, 'rb') as f: pre_name, pre_data = pickle.load(f) if pre_name != name: global has_error has_error += 1 LOG.e(f"The {rid} result name not match, {pre_name} != {name}") self.rid -= 1 return LOG.i(f"check {rid}:<{ex_name}{name}> ...") self.check(ex_name + name, pre_data, data) else: with open(fpath, 'wb') as f: pickle.dump((name, data), f) LOG.i(f"save {rid}:<{name}> ok")
def hook_module(self, mod, mod_name=""): if os.environ.get("use_auto_diff", '1') == '0': return if mod_name != "": mod_name = "<" + mod_name + ">" self.hooked_models[mod_name] = mod def forward_hook(self2, input, output, kw=None): ex_name = '[' + self2.__class__.__name__ + ']' if "relu" not in self2.__class__.__name__.lower(): # not test relu, because input may be inplaced self.record(self2.__ad_mod_name__ + ".input", input, ex_name) self.record(self2.__ad_mod_name__ + ".output", output, ex_name) if kw is not None: self.record(self2.__ad_mod_name__ + ".kw", kw, ex_name) names = [] for name, module in mod.named_modules(): ns = name.split('.') skip = 0 for n in ns: if n.startswith('_'): skip = 1 if skip: LOG.i("skip", name) continue name = mod_name + name module.__ad_mod_name__ = name names.append(name) module.register_forward_hook(forward_hook) mod_class_name = module.__class__.__name__.lower() # make dropout in eval mod and record dropout.p if "dropout" in mod_class_name: self.record(name + '.p', module.p, "[" + mod_class_name + "]") module.eval() ps = {mod_name + k: v for k, v in mod.state_dict().items()} self.record_params(ps, mod_name) self.record("module names", names)
def run():
    start_time = time.time()
    fop_num = 10000
    fop_input_num = (2, 3)  # (i, j) -> [i, i+j] -> [2, 5]
    # fop_output_num = (1, 0)  # [1, 1]
    inner_op_num = (0, 3)
    fop_type_num = 63  # how many different kinds of fused ops
    input_queue_num = 15
    queue = [1.0] * (input_queue_num + 1)
    x = get_xorshf96()
    rand = lambda x, l, r: l + ((x()) & r)
    ops = ["add", "subtract", "multiply", "divide"]
    get_op = lambda x: ops[(x()) & 3]
    for i in range(fop_num):
        prev = bc(queue[rand(x, 0, input_queue_num)])
        y = get_xorshf96(x() & fop_type_num)
        inum = rand(y, *fop_input_num)
        q = [prev]
        for _ in range(inum - 1):
            n = bc(queue[rand(x, 0, input_queue_num)])
            prev = jt.binary(prev, n, get_op(y))
            q.append(prev)
        innum = rand(y, *inner_op_num)
        for _ in range(innum):
            j = rand(y, 0, len(q) - 1)
            n = q[j]
            prev = jt.binary(prev, n, get_op(y))
            q[j] = prev
        prev = rd(prev)
        queue[rand(x, 0, input_queue_num)] = prev
    a = jt.array(0.0)
    for v in queue:
        a += v
    LOG.i("build graph", time.time() - start_time, jt.liveness_info().values())
    start_time = time.time()
    a.sync()
    LOG.i("execute", time.time() - start_time)
def record(self, name, data, ex_name=""): if os.environ.get("use_auto_diff", '1') == '0': return self.record_status[name] += 1 fpath = os.path.join(self.base_path, f"{name}-{self.record_status[name]}.pkl") data = convert(data) self.rid += 1 if self.mode == 'check': if os.path.isfile(fpath): with open(fpath, 'rb') as f: pre_name, pre_data = pickle.load(f) LOG.i(f"check {self.rid}:<{ex_name}{name}> ...") self.check(ex_name + name, pre_data, data) else: global has_error has_error += 1 LOG.e(f"No previous result found: {name}") return else: with open(fpath, 'wb') as f: pickle.dump((name, data), f) LOG.i(f"save {self.rid}:<{name}> ok")
def __init__(self, base_name, rtol=5e-2, atol=1e-3):
    if os.environ.get("use_auto_diff", '1') == '0':
        return
    hook_rand()
    self.rid = 0
    self.base_name = base_name
    self.base_path = os.path.join(str(Path.home()), ".cache", "jittor",
                                  "auto_diff", base_name)
    if not os.path.exists(self.base_path):
        os.makedirs(self.base_path, exist_ok=True)
        self.mode = 'save'
    else:
        self.mode = 'check'
    self.record_status = defaultdict(int)
    self.rtol = rtol
    self.atol = atol
    self.param_name_map = {}
    self.hooked_models = {}
    LOG.i(f"Jittor AutoDiff: [{self.mode}] mode")
    LOG.i("Use cache path:", self.base_path)
    LOG.i(f"rtol:{rtol} atol:{atol}")
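# Usage sketch of the two-phase workflow: the first run creates the cache
# directory and saves records; any later run with the same base_name finds it
# and checks against the saved values. Disable everything via the environment
# switch, e.g.:  use_auto_diff=0 python train.py
hook = Hook("mymodel")                   # run 1 -> [save] mode; run 2 -> [check] mode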
import_flags |= os.RTLD_DEEPBIND
# if cc_type=="icc":
#     # weird link problem, icc omp library may conflict and cause segfault
#     import_flags = os.RTLD_NOW | os.RTLD_GLOBAL
dlopen_flags = os.RTLD_NOW | os.RTLD_GLOBAL
if platform.system() == 'Linux':
    import_flags |= os.RTLD_DEEPBIND

with jit_utils.import_scope(import_flags):
    jit_utils.try_import_jit_utils_core()

jittor_path = find_jittor_path()
check_debug_flags()

sys.path.append(cache_path)

LOG.i(f"Jittor({__version__}) src: {jittor_path}")
LOG.i(f"{jit_utils.cc_type} at {jit_utils.cc_path}{jit_utils.get_version(jit_utils.cc_path)}")
LOG.i(f"cache_path: {cache_path}")

with jit_utils.import_scope(import_flags):
    jit_utils.try_import_jit_utils_core()

python_path = sys.executable
# sometimes python does not return the correct sys.executable;
# this happens when multiple python versions are installed
ex_python_path = python_path + '.' + str(sys.version_info.minor)
if os.path.isfile(ex_python_path):
    python_path = ex_python_path

py3_config_path = jit_utils.py3_config_path

# if jtcuda is already installed
def install_cuda():
    cuda_driver_version = get_cuda_driver()
    if not cuda_driver_version:
        return None
    LOG.i("cuda_driver_version: ", cuda_driver_version)
    if "JTCUDA_VERSION" in os.environ:
        cuda_driver_version = list(map(int, os.environ["JTCUDA_VERSION"].split(".")))
        LOG.i("JTCUDA_VERSION: ", cuda_driver_version)

    if os.name == 'nt':
        if cuda_driver_version >= [11, 4]:
            cuda_tgz = "cuda11.4_cudnn8_win.zip"
            md5 = "06eed370d0d44bb2cc57809343911187"
        elif cuda_driver_version >= [11, 2]:
            cuda_tgz = "cuda11.2_cudnn8_win.zip"
            md5 = "b5543822c21bc460c1a414af47754556"
        elif cuda_driver_version >= [11]:
            cuda_tgz = "cuda11.0_cudnn8_win.zip"
            md5 = "7a248df76ee5e79623236b0560f8d1fd"
        elif cuda_driver_version >= [10]:
            cuda_tgz = "cuda10.2_cudnn7_win.zip"
            md5 = "7dd9963833a91371299a2ba58779dd71"
        else:
            raise RuntimeError(f"Unsupported cuda driver version: {cuda_driver_version}, at least 10.2")
    else:
        if cuda_driver_version >= [11, 2]:
            cuda_tgz = "cuda11.2_cudnn8_linux.tgz"
            md5 = "b93a1a5d19098e93450ee080509e9836"
        elif cuda_driver_version >= [11]:
            cuda_tgz = "cuda11.0_cudnn8_linux.tgz"
            md5 = "5dbdb43e35b4db8249027997720bf1ca"
        elif cuda_driver_version >= [10, 2]:
            cuda_tgz = "cuda10.2_cudnn7_linux.tgz"
            md5 = "40f0563e8eb176f53e55943f6d212ad7"
        elif cuda_driver_version >= [10]:
            cuda_tgz = "cuda10.0_cudnn7_linux.tgz"
            md5 = "f16d3ff63f081031d21faec3ec8b7dac"
        else:
            raise RuntimeError(f"Unsupported cuda driver version: {cuda_driver_version}, at least 10.0")
    jtcuda_path = os.path.join(pathlib.Path.home(), ".cache", "jittor", "jtcuda")
    nvcc_path = os.path.join(jtcuda_path, cuda_tgz[:-4], "bin", "nvcc")
    if os.name == 'nt':
        nvcc_path += '.exe'
    nvcc_lib_path = os.path.join(jtcuda_path, cuda_tgz[:-4], "lib64")
    sys.path.append(nvcc_lib_path)
    new_ld_path = os.environ.get("LD_LIBRARY_PATH", "") + ":" + nvcc_lib_path
    os.environ["LD_LIBRARY_PATH"] = new_ld_path

    if os.path.isfile(nvcc_path):
        return nvcc_path

    os.makedirs(jtcuda_path, exist_ok=True)
    cuda_tgz_path = os.path.join(jtcuda_path, cuda_tgz)
    download_url_to_local("https://cg.cs.tsinghua.edu.cn/jittor/assets/" + cuda_tgz,
                          cuda_tgz, jtcuda_path, md5)

    if cuda_tgz.endswith(".zip"):
        import zipfile
        zf = zipfile.ZipFile(cuda_tgz_path)
        zf.extractall(path=cuda_tgz_path[:-4])
    else:
        import tarfile
        with tarfile.open(cuda_tgz_path, "r") as tar:
            tar.extractall(cuda_tgz_path[:-4])

    assert os.path.isfile(nvcc_path), nvcc_path
    return nvcc_path
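# Usage sketch: JTCUDA_VERSION forces a specific bundled toolkit regardless of
# the detected driver (version string here is hypothetical); without a CUDA
# driver the function simply returns None.
os.environ["JTCUDA_VERSION"] = "11.2"    # optional pin, must match an entry above
nvcc = install_cuda()                    # path to the cached nvcc, or None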
cc_flags = " " # os.RTLD_NOW | os.RTLD_GLOBAL cause segfault when import torch first import_flags = os.RTLD_NOW | os.RTLD_GLOBAL | os.RTLD_DEEPBIND # if cc_type=="icc": # # weird link problem, icc omp library may conflict and cause segfault # import_flags = os.RTLD_NOW | os.RTLD_GLOBAL dlopen_flags = os.RTLD_NOW | os.RTLD_GLOBAL | os.RTLD_DEEPBIND with jit_utils.import_scope(import_flags): jit_utils.try_import_jit_utils_core() jittor_path = find_jittor_path() check_debug_flags() sys.path.append(cache_path) LOG.i(f"Jittor({__version__}) src: {jittor_path}") LOG.i(f"cache_path: {cache_path}") with jit_utils.import_scope(import_flags): jit_utils.try_import_jit_utils_core() python_path = sys.executable py3_config_paths = [ sys.executable + "-config", os.path.dirname(sys.executable) + f"/python3.{sys.version_info.minor}-config", f"/usr/bin/python3.{sys.version_info.minor}-config", os.path.dirname(sys.executable) + "/python3-config", ] if "python_config_path" in os.environ: py3_config_paths.insert(0, os.environ["python_config_path"])
"") + opt_flags + " -fopenmp " if ' -O' not in cc_flags: opt_flags += " -O2 " kernel_opt_flags += " -Ofast " lto_flags = "" if os.environ.get("enable_lto") == "1": if cc_type == "icc": lto_flags = " -flto -ipo -ipo-c " elif cc_type == "g++": lto_flags = " -flto -fuse-linker-plugin " else: lto_flags = " -flto " pybind_include = run_cmd(python_path + " -m pybind11 --includes") LOG.i(f"pybind_include: {pybind_include}") extension_suffix = run_cmd(py3_config_path + " --extension-suffix") LOG.i(f"extension_suffix: {extension_suffix}") make_cache_dir(cache_path) make_cache_dir(os.path.join(cache_path, "jit")) make_cache_dir(os.path.join(cache_path, "obj_files")) make_cache_dir(os.path.join(cache_path, "gen")) # build cache_compile cc_flags += pybind_include cc_flags += f" -I{jittor_path}/src " check_cache_compile() LOG.v(f"Get cache_compile: {jit_utils.cc}") # check cuda
def compile_custom_ops(filenames, extra_flags="",
                       return_module=False,
                       dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW | os.RTLD_DEEPBIND,
                       gen_name_=""):
    """Compile custom ops
    filenames: paths of op source files; filenames must be pairs of xxx_xxx_op.cc and xxx_xxx_op.h, and the type name of the op must be XxxXxxOp.
    extra_flags: extra compile flags
    return_module: return the module rather than the ops (default: False)
    return: compiled ops
    """
    srcs = {}
    headers = {}
    builds = []
    includes = []
    pyjt_includes = []
    for name in filenames:
        name = os.path.realpath(name)
        if name.endswith(".cc") or name.endswith(".cpp") or name.endswith(".cu"):
            builds.append(name)
        if name.endswith(".h"):
            dirname = os.path.dirname(name)
            if dirname.endswith("inc"):
                includes.append(dirname)
            with open(name, "r") as f:
                if "@pyjt" in f.read():
                    pyjt_includes.append(name)
        bname = os.path.basename(name)
        bname = os.path.splitext(bname)[0]
        if bname.endswith("_op"):
            bname = bname[:-3]
            if name.endswith(".cc"):
                srcs[bname] = name
            elif name.endswith(".h"):
                includes.append(os.path.dirname(name))
                headers[bname] = name
    assert len(srcs) == len(headers), "Source and header names do not match"
    for name in srcs:
        assert name in headers, f"Header of op {name} not found"
    gen_name = "gen_ops_" + "_".join(headers.keys())
    if gen_name_ != "":
        gen_name = gen_name_
    if len(gen_name) > 100:
        gen_name = gen_name[:80] + "___hash" + str(hash(gen_name))

    includes = set(includes)
    includes = "".join(map(lambda x: f" -I'{x}' ", includes))
    LOG.vvvv(f"Include flags:{includes}")
    op_extra_flags = includes + extra_flags

    gen_src = gen_jit_op_maker(headers.values(), export=gen_name, extra_flags=op_extra_flags)
    make_cache_dir(os.path.join(cache_path, "custom_ops"))
    gen_src_fname = os.path.join(cache_path, "custom_ops", gen_name + ".cc")
    gen_head_fname = os.path.join(cache_path, "custom_ops", gen_name + ".h")
    gen_lib = os.path.join("custom_ops", gen_name + extension_suffix)
    pyjt_compiler.compile_single(gen_head_fname, gen_src_fname, src=gen_src)
    # generated src is initialized first
    builds.insert(0, gen_src_fname)

    def insert_anchor(gen_src, anchor_str, insert_str):
        # insert insert_str after anchor_str in gen_src
        return gen_src.replace(anchor_str, anchor_str + insert_str, 1)

    for name in pyjt_includes:
        LOG.i("handle pyjt_include", name)
        bname = name.split("/")[-1].split(".")[0]
        gen_src_fname = os.path.join(cache_path, "custom_ops", gen_name + "_" + bname + ".cc")
        pyjt_compiler.compile_single(name, gen_src_fname)
        builds.insert(1, gen_src_fname)
        gen_src = insert_anchor(gen_src,
                                "namespace jittor {",
                                f"extern void pyjt_def_{bname}(PyObject* m);")
        gen_src = insert_anchor(gen_src,
                                "init_module(PyModuleDef* mdef, PyObject* m) {",
                                f"jittor::pyjt_def_{bname}(m);")

    with open(gen_head_fname, "w") as f:
        f.write(gen_src)

    LOG.vvv(f"Build custom ops lib:{gen_lib}")
    LOG.vvvv(f"Build sources:{builds}")
    compile(cc_path, extra_flags + cc_flags + opt_flags + includes, builds, gen_lib)

    # add python path and import
    LOG.vvv(f"Import custom ops lib:{gen_lib}")
    lib_path = os.path.join(cache_path, "custom_ops")
    if lib_path not in os.sys.path:
        os.sys.path.append(lib_path)
    # unlock scope during initialization
    with lock.unlock_scope():
        with jit_utils.import_scope(dlopen_flags):
            exec(f"import {gen_name}")
    mod = locals()[gen_name]
    if return_module:
        return mod
    return mod.ops
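# Usage sketch, assuming a pair my_op.cc / my_op.h that defines MyOp as the
# docstring requires (the file names here are hypothetical):
#   ops = compile_custom_ops(["my_op.cc", "my_op.h"])
#   y = ops.my(...)                      # op name comes from the xxx_op file stem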
def make_cache_dir(cache_path):
    if not os.path.isdir(cache_path):
        LOG.i(f"Create cache dir: {cache_path}")
        os.mkdir(cache_path)  # single level only; callers create parents first
# import_flags = os.RTLD_NOW | os.RTLD_GLOBAL
dlopen_flags = os.RTLD_NOW | os.RTLD_GLOBAL
if platform.system() == 'Linux':
    import_flags |= os.RTLD_DEEPBIND

with jit_utils.import_scope(import_flags):
    jit_utils.try_import_jit_utils_core()

jittor_path = find_jittor_path()
if os.name == 'nt':
    # prevent windows recompile
    jittor_path = jittor_path.lower()
check_debug_flags()

sys.path.append(cache_path)

LOG.i(f"Jittor({__version__}) src: {jittor_path}")
LOG.i(f"{jit_utils.cc_type} at {jit_utils.cc_path}{jit_utils.get_version(jit_utils.cc_path)}")
LOG.i(f"cache_path: {cache_path}")

with jit_utils.import_scope(import_flags):
    jit_utils.try_import_jit_utils_core()

python_path = sys.executable
# sometimes python does not return the correct sys.executable;
# this happens when multiple python versions are installed
ex_python_path = python_path + '.' + str(sys.version_info.minor)
if os.path.isfile(ex_python_path):
    python_path = ex_python_path
if __name__ == "__main__":
    nvcc_path = install_cuda()
    LOG.i("nvcc is installed at ", nvcc_path)
class lock_scope(_base_scope):
    def __enter__(self):
        self.is_locked = jittor_lock.is_locked
        if not self.is_locked:
            jittor_lock.lock()

    def __exit__(self, *exc):
        if not self.is_locked:
            jittor_lock.unlock()


class unlock_scope(_base_scope):
    def __enter__(self):
        self.is_locked = jittor_lock.is_locked
        if self.is_locked:
            jittor_lock.unlock()

    def __exit__(self, *exc):
        if self.is_locked:
            jittor_lock.lock()


lock_path = os.path.abspath(os.path.join(cache_path, "../jittor.lock"))
if not os.path.exists(lock_path):
    LOG.i("Create lock file:", lock_path)
    try:
        os.mknod(lock_path)
    except:
        pass
jittor_lock = Lock(lock_path)
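# Usage sketch: serialize compilation across processes that share one cache
# directory. lock_scope only acquires if the lock is not already held, so
# nesting is safe; unlock_scope temporarily releases it around re-entrant work.
with lock_scope():
    pass  # e.g. compile(...) runs under the inter-process lock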