def check(self, name, pre_data, data):
    global has_error
    # pre_data missing: accept only if data is all zeros
    if pre_data is None and isinstance(data, np.ndarray):
        if (data == 0).all():
            LOG.i(f"name {name} is None")
        else:
            LOG.e(f"name {name} is non-zero")
        return
    if type(pre_data) != type(data):
        LOG.e(f"type not match, {pre_data.__class__.__name__}!={data.__class__.__name__}, name: {name}")
        has_error += 1
        return
    if isinstance(pre_data, (list, tuple)):
        if len(pre_data) != len(data):
            has_error += 1
            LOG.e(f"Name <{name}> len not match, {len(pre_data)} != {len(data)}")
        # recurse element-wise; pad the shorter side with "None"
        n = max(len(pre_data), len(data))
        for i in range(n):
            a = pre_data[i] if i < len(pre_data) else "None"
            b = data[i] if i < len(data) else "None"
            self.check(name + f".{i}", a, b)
    elif isinstance(pre_data, np.ndarray):
        # promote 0-d arrays so indexing below works uniformly
        if len(pre_data.shape) == 0:
            pre_data = np.array([pre_data])
        if len(data.shape) == 0:
            data = np.array([data])
        if pre_data.shape != data.shape:
            has_error += 1
            LOG.e(f"Ndarray shape <{name}> not match {pre_data.shape} != {data.shape}")
            return
        self.check_array(name, pre_data, data)
    elif isinstance(pre_data, dict):
        if len(pre_data) != len(data):
            has_error += 1
            LOG.w(f"Dict Name <{name}> len not match, {len(pre_data)} != {len(data)}")
        for k in pre_data:
            pv = pre_data[k]
            if k not in data:
                has_error += 1
                msg = f"Key <{k}> not in data, Name <{name}>"
                # missing ndarray keys are errors, other types only warnings
                if isinstance(pv, np.ndarray):
                    LOG.e(msg)
                else:
                    LOG.w(msg)
                continue
            self.check(name + f".{k}", pre_data[k], data[k])
    else:
        if pre_data != data:
            has_error += 1
            LOG.e(f"Type: {type(pre_data).__name__} Name <{name}> not match {pre_data} != {data}")
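# Usage sketch (illustrative, not part of the original source): `check`
# recurses through lists/tuples/dicts and builds dotted names for leaves,
# so ndarray leaves end up in `check_array`. Assuming a checker instance
# `c` with rtol/atol configured:
#
#   c.check("conv1",
#           {"w": np.ones((3, 3)), "b": np.zeros(3)},
#           {"w": np.ones((3, 3)), "b": np.zeros(3)})
#   # recurses as c.check("conv1.w", ...) and c.check("conv1.b", ...),
#   # comparing each ndarray pair element-wise via check_array.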
def compile_extern():
    # compile llvm passes
    if cc_type != "clang":
        return
    global kernel_opt_flags
    cache_path_llvm = os.path.join(cache_path, "llvm")
    jittor_path_llvm = os.path.join(jittor_path, "extern", "llvm")
    clang_dir = os.path.dirname(get_full_path_of_executable(cc_path))
    assert clang_dir.endswith("bin") and "llvm" in clang_dir, \
        f"Wrong clang_dir: {clang_dir}"
    llvm_include = os.path.abspath(os.path.join(clang_dir, "..", "include"))
    assert os.path.isdir(llvm_include), "LLVM include path not found"
    make_cache_dir(cache_path_llvm)
    files = os.listdir(jittor_path_llvm)
    # test_pass.cc is used to test the link problem of the llvm pass plugin
    test_pass_path = os.path.join(cache_path_llvm, "test_pass.cc")
    with open(test_pass_path, 'w') as f:
        f.write("int main() {return 0;}")
    # -fno-rtti fixes a link error
    # -Wl,-znodelete fixes a segfault
    # https://github.com/sampsyo/llvm-pass-skeleton/issues/7#issuecomment-401834287
    # -D_GLIBCXX_USE_CXX11_ABI=0 fixes undefined symbol: createPrinterPass
    # https://stackoverflow.com/questions/37366291/undefined-symbol-for-self-built-llvm-opt
    # try different flags
    try_flags = [
        " -Wl,-znodelete -D_GLIBCXX_USE_CXX11_ABI=0 ",
        " -Wl,-znodelete ",
    ]
    found_flags_id = -1
    for fname in files:
        for i, flag in enumerate(try_flags):
            if found_flags_id != -1 and found_flags_id != i:
                continue
            so_name = os.path.join(cache_path_llvm,
                                   os.path.splitext(fname)[0] + f".{i}.so")
            compile(cc_path,
                    f"{cc_flags} {opt_flags} {flag} -I'{llvm_include}'",
                    [os.path.join(jittor_path_llvm, fname)], so_name)
            # if no working flag has been found yet, test this one.
            if found_flags_id == -1:
                try:
                    s = run_cmd(
                        f"{cc_path} {cc_flags} -Xclang -load -Xclang '{so_name}' {test_pass_path}",
                        cache_path_llvm, print_error=False)
                except Exception as e:
                    LOG.v(f"Try flag {flag} failed: {e}")
                    continue
                found_flags_id = i
            kernel_opt_flags += f" -Xclang -load -Xclang '{so_name}' "
            break
        else:
            LOG.w("Clang is used, but LLVM pass plugin is unable to link.")
            break
    LOG.vv(f"Compile extern llvm passes: {str(files)}")
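# For reference (a sketch of the flag mechanics, not output of this
# function): each successfully linked pass plugin contributes one
# "-Xclang -load -Xclang <so>" pair to kernel_opt_flags, so a later
# kernel compile looks roughly like
#
#   clang++ <cc_flags> -Xclang -load -Xclang '<cache>/llvm/my_pass.0.so' kernel.cc
#
# where my_pass.0.so is hypothetical; the ".0" suffix records which entry
# of try_flags the plugin was built with.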
def check_array(self, name, a, b):
    rtol = self.rtol
    atol = self.atol
    global has_error
    err = np.abs(a - b)
    tol = atol + rtol * np.abs(b)
    # an element errs if it exceeds the tolerance or disagrees in sign
    is_error = np.logical_or(err > tol, (a >= -1e-5) != (b >= -1e-5))
    index = np.where(is_error)
    assert len(index) > 0
    if len(index[0]) == 0:
        return
    has_error += 1
    LOG.w(f"Ndarray <{name}> not match, shape:{a.shape}")
    i = tuple(i[0] for i in index)  # first mismatching element
    err_rate = is_error.mean()
    LOG.w(f"error index at [{i}], a({a[i]}) b({b[i]}) err({err[i]}) > tol({tol[i]}), "
          f"err_rate:{err_rate*100:.3f}% amean({a.mean()}) bmean({b.mean()}) "
          f"astd({a.std()}) bstd({b.std()}) ")
    if err_rate > 0.01:
        LOG.e("!"*10 + "Very HIGH err rate" + "!"*10)
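# Worked sketch of the tolerance test above (made-up values): with
# rtol=1e-3 and atol=1e-5,
#
#   a = np.array([1.0, 2.0]); b = np.array([1.0005, 2.1])
#   err = np.abs(a - b)            # [5.0e-4, 1.0e-1]
#   tol = 1e-5 + 1e-3 * np.abs(b)  # [1.0105e-3, 2.11e-3]
#   err > tol                      # [False, True]
#
# so only the second element is flagged; the (a>=-1e-5)!=(b>=-1e-5) term
# additionally flags pairs that disagree in sign even when within tolerance.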
def __iter__(self):
    if self.total_len is None:
        self.total_len = len(self)
    if self.shuffle == False:
        index_list = get_order_list(self.total_len)
    else:
        index_list = get_random_list(self.total_len)

    # scatter index_list to all mpi processes
    # scatter rule:
    #    batch 1    batch 2
    #   [........] [........] ...
    #    00011122   00011122
    # if the last batch is smaller than world_size,
    # pad it to world_size:
    #    last batch
    #   [.] -> [012]
    if jt.in_mpi:
        world_size = mpi.world_size()
        world_rank = mpi.world_rank()
        index_list = np.int32(index_list)
        mpi.broadcast(index_list, 0)

        assert self.batch_size >= world_size, \
            f"Batch size({self.batch_size}) is smaller than MPI world_size({world_size})"
        real_batch_size = (self.batch_size - 1) // world_size + 1
        if real_batch_size * world_size != self.batch_size:
            LOG.w("Batch size is not divisible by MPI world size, "
                  "the distributed version may differ from "
                  "the single-process version.")
        fix_batch = self.total_len // self.batch_size
        last_batch = self.total_len - fix_batch * self.batch_size
        fix_batch_l = index_list[0:fix_batch * self.batch_size] \
            .reshape(-1, self.batch_size)
        fix_batch_l = fix_batch_l[
            :, real_batch_size * world_rank:real_batch_size * (world_rank + 1)]
        real_batch_size = fix_batch_l.shape[1]
        fix_batch_l = fix_batch_l.flatten()
        if not self.drop_last and last_batch > 0:
            last_batch_l = index_list[-last_batch:]
            real_last_batch = (last_batch - 1) // world_size + 1
            l = real_last_batch * world_rank
            r = l + real_last_batch
            if r > last_batch:
                r = last_batch
            if l >= r:
                l = r - 1
            index_list = np.concatenate([fix_batch_l, last_batch_l[l:r]])
        else:
            index_list = fix_batch_l
        self.real_len = len(index_list)
        self.real_batch_size = real_batch_size
        assert self.total_len // self.batch_size == \
            self.real_len // self.real_batch_size
    else:
        self.real_len = self.total_len
        self.real_batch_size = self.batch_size

    self.batch_len = self.__batch_len__()

    if not hasattr(self, "workers") and self.num_workers:
        self._init_workers()

    if self.num_workers:
        self._stop_all_workers()
        self.index_list_numpy[:] = index_list
        gid_obj = self.gid.get_obj()
        gid_lock = self.gid.get_lock()
        with gid_lock:
            gid_obj.value = 0
            self.gidc.notify_all()
        start = time.time()
        self.batch_time = 0
        for i in range(self.batch_len):
            # try to avoid taking the lock first
            if gid_obj.value <= i:
                with gid_lock:
                    if gid_obj.value <= i:
                        if mp_log_v:
                            print("wait")
                        self.gidc.wait()
            now = time.time()
            self.wait_time = now - start
            start = now

            self.last_id = i
            worker_id = self.idmap[i]
            w = self.workers[worker_id]
            if mp_log_v:
                print(f"#{worker_id} {os.getpid()} recv buffer", w.buffer)
            batch = w.buffer.recv()
            now = time.time()
            self.recv_time = now - start
            start = now

            if mp_log_v:
                print(f"#{worker_id} {os.getpid()} recv", type(batch).__name__,
                      [type(b).__name__ for b in batch])
            batch = self.to_jittor(batch)
            now = time.time()
            self.to_jittor_time = now - start
            start = now

            yield batch

            now = time.time()
            self.batch_time = now - start
            start = now
    else:
        batch_data = []
        for idx in index_list:
            batch_data.append(self[int(idx)])
            if len(batch_data) == self.real_batch_size:
                batch_data = self.collate_batch(batch_data)
                batch_data = self.to_jittor(batch_data)
                yield batch_data
                batch_data = []

        # depends on drop_last
        if not self.drop_last and len(batch_data) > 0:
            batch_data = self.collate_batch(batch_data)
            batch_data = self.to_jittor(batch_data)
            yield batch_data
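# Worked sketch of the scatter rule above (hypothetical numbers):
# total_len=10, batch_size=8, world_size=3 gives
#
#   real_batch_size = (8-1)//3 + 1 = 3   # 3*3 != 8 -> divisibility warning
#   fix_batch = 10//8 = 1, last_batch = 10 - 1*8 = 2
#
# Rank 0 takes columns 0:3 of the full batch, rank 1 takes 3:6, and rank 2
# takes 6:8 (only 2 columns). For the trailing 2 samples,
# real_last_batch = (2-1)//3 + 1 = 1; rank 2 gets l=2 >= r=2, so l is
# clamped to r-1=1 and it re-reads the sample rank 1 already has (the
# "pad to world_size" case in the diagram).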
if has_cuda:
    nvcc_flags = convert_nvcc_flags(cc_flags)
    nvcc_version = list(jit_utils.get_int_version(nvcc_path))
    max_arch = 1000
    if nvcc_version < [11]:
        max_arch = 75
    elif nvcc_version < [11, 1]:
        max_arch = 80
    if len(flags.cuda_archs):
        min_arch = 30
        archs = []
        for arch in flags.cuda_archs:
            if arch < min_arch:
                LOG.w(f"CUDA arch({arch})<{min_arch} is not supported")
                continue
            if arch > max_arch:
                LOG.w(f"CUDA arch({arch})>{max_arch} will be backward-compatible")
                arch = max_arch
            archs.append(arch)
        flags.cuda_archs = archs
        nvcc_flags += f" -arch=compute_{min(archs)} "
        nvcc_flags += ''.join(map(lambda x: f' -code=sm_{x} ', archs))

flags.cc_path = cc_path
flags.cc_type = cc_type
flags.cc_flags = cc_flags + kernel_opt_flags
flags.nvcc_path = nvcc_path
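# Example of the generated arch flags (a sketch, assuming nvcc 11.0 so
# max_arch == 80, and flags.cuda_archs == [52, 61, 86]): 86 > 80 is
# clamped to 80 with a warning, and nvcc_flags gains
#
#   " -arch=compute_52  -code=sm_52  -code=sm_61  -code=sm_80 "
#
# i.e. PTX for the lowest requested arch plus SASS for each accepted arch.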
def _get_index_list(self):
    if self.total_len is None:
        self.total_len = len(self)
    # may be rewritten by the sampler
    total_len = self.total_len
    if self.sampler:
        index_list = list(self.sampler.__iter__())
        total_len = len(index_list)
        # check that it is not a batch sampler
        if len(index_list):
            assert not isinstance(index_list[0], (list, tuple)), \
                "Batch sampler not supported yet."
    elif self.shuffle == False:
        index_list = get_order_list(self.total_len)
    else:
        # use _shuffle_rng to generate a shuffle list that is
        # consistent across processes
        # index_list = get_random_list(self.total_len)
        index_list = self._shuffle_rng.permutation(range(self.total_len))

    # scatter index_list to all mpi processes
    # scatter rule:
    #    batch 1    batch 2
    #   [........] [........] ...
    #    00011122   00011122
    # if the last batch is smaller than world_size,
    # pad it to world_size:
    #    last batch
    #   [.] -> [012]
    if jt.in_mpi:
        world_size = mpi.world_size()
        world_rank = mpi.world_rank()
        index_list = np.int32(index_list)
        # TODO: mpi broadcast in subprocess has a bug, fix it
        # mpi.broadcast(index_list, 0)

        assert self.batch_size >= world_size, \
            f"Batch size({self.batch_size}) is smaller than MPI world_size({world_size})"
        real_batch_size = (self.batch_size - 1) // world_size + 1
        if real_batch_size * world_size != self.batch_size:
            LOG.w("Batch size is not divisible by MPI world size, "
                  "the distributed version may differ from "
                  "the single-process version.")
        fix_batch = total_len // self.batch_size
        last_batch = total_len - fix_batch * self.batch_size
        fix_batch_l = index_list[0:fix_batch * self.batch_size] \
            .reshape(-1, self.batch_size)
        fix_batch_l = fix_batch_l[
            :, real_batch_size * world_rank:real_batch_size * (world_rank + 1)]
        real_batch_size = fix_batch_l.shape[1]
        fix_batch_l = fix_batch_l.flatten()
        if not self.drop_last and last_batch > 0:
            last_batch_l = index_list[-last_batch:]
            real_last_batch = (last_batch - 1) // world_size + 1
            l = real_last_batch * world_rank
            r = l + real_last_batch
            if r > last_batch:
                r = last_batch
            if l >= r:
                l = r - 1
            index_list = np.concatenate([fix_batch_l, last_batch_l[l:r]])
        else:
            index_list = fix_batch_l
        self.real_len = len(index_list)
        self.real_batch_size = real_batch_size
        assert total_len // self.batch_size == \
            self.real_len // self.real_batch_size, \
            f"Number of batches({total_len // self.batch_size}!={self.real_len // self.real_batch_size}) " \
            f"do not match, total_len: {total_len}, batch_size: {self.batch_size}, " \
            f"real_len: {self.real_len}, real_batch_size: {self.real_batch_size}"
    else:
        self.real_len = self.total_len
        self.real_batch_size = self.batch_size

    self.batch_len = self.__batch_len__()
    return index_list
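# Sketch of the sampler path above (hypothetical sampler): any iterable of
# plain indices is accepted, while a batch sampler trips the assert:
#
#   class EvenSampler:
#       def __iter__(self):
#           return iter(range(0, 10, 2))
#
#   # with self.sampler = EvenSampler():
#   #   index_list = [0, 2, 4, 6, 8], total_len = 5
#   # a sampler yielding [0, 1]-style batches raises
#   # "Batch sampler not supported yet."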
def get_param_name(self, p):
    if id(p) not in self.param_name_map:
        LOG.w("Param name not found", p.shape, id(p))
        return "noname" + str(list(p.shape))
    return self.param_name_map[id(p)]