def record_params(self, parameters_dict):
    if os.environ.get("use_auto_diff", '1') == '0':
        return
    rid = self.rid
    self.rid += 1
    global has_error
    # drop BatchNorm bookkeeping buffers, which are not real parameters
    pps = {}
    for k, v in parameters_dict.items():
        if k.endswith("num_batches_tracked"):
            continue
        pps[k] = v
    ps = {name: convert(param) for name, param in pps.items()}
    fpath = os.path.join(self.base_path, f"{rid}-params.pkl")
    if os.path.isfile(fpath):
        with open(fpath, 'rb') as f:
            prev_ps = pickle.load(f)
        if len(prev_ps) != len(ps):
            has_error += 1
            LOG.e(f"Params len not match {len(prev_ps)} != {len(ps)}")
        for k in ps:
            a = ps[k]
            if k not in prev_ps:
                has_error += 1
                LOG.e(f"prev param <{k}> not found.")
                continue
            b = prev_ps[k]
            if a.shape != b.shape:
                has_error += 1
                LOG.e(f"Params <{k}> shape not match {a.shape} != {b.shape}")
                continue
            std_a, mean_a = a.std(), a.mean()
            std_b, mean_b = b.std(), b.mean()
            n = a.size
            # law of large numbers: the sample mean of n values fluctuates on
            # the scale std / sqrt(n), the sample std on std / sqrt((n-1)/2)
            std_mean_a = (std_a + std_b) / 2 / np.sqrt(n) + 1e-6
            std_std_a = (std_a + std_b) / 2 / np.sqrt((n - 1) / 2) + 1e-6
            x = 4
            if np.abs(mean_a - mean_b) > x * std_mean_a:
                has_error += 1
                LOG.e(
                    f"param mean not match, mean_a:{mean_a}, mean_b:{mean_b}, "
                    f"acceptable range:({mean_a - x * std_mean_a}, "
                    f"{mean_a + x * std_mean_a}) name:{k} shape:{a.shape}")
            elif np.abs(std_a - std_b) > x * std_std_a:
                has_error += 1
                LOG.e(
                    f"param std not match, std_a:{std_a}, std_b:{std_b}, "
                    f"acceptable range:({std_a - x * std_std_a}, "
                    f"{std_a + x * std_std_a}) name:{k} shape:{a.shape}")
            else:
                LOG.i(f"check param ok: <{k}> shape:{a.shape}")
            # overwrite the live parameter with the saved value so both runs
            # continue from identical weights
            var = pps[k]
            if hasattr(var, "copy_"):  # PyTorch tensor
                import torch
                var.data.copy_(torch.from_numpy(b))
            else:  # Jittor var
                var.assign(b)
    else:
        with open(fpath, 'wb') as f:
            pickle.dump(ps, f)
        LOG.i("save params ok")

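# The mean/std tolerance above is worth a standalone illustration. This is a
# minimal sketch (demo_means_match and all values are hypothetical, not part
# of the original code): by the law of large numbers the sample mean of n
# i.i.d. values fluctuates on the scale std / sqrt(n), so two tensors drawn
# from the same distribution should have means within a few such units.
import numpy as np

def demo_means_match(a, b, x=4):
    n = a.size
    std_mean = (a.std() + b.std()) / 2 / np.sqrt(n) + 1e-6
    return np.abs(a.mean() - b.mean()) <= x * std_mean

rng = np.random.default_rng(0)
w1 = rng.normal(0.0, 1.0, size=10000)
w2 = rng.normal(0.0, 1.0, size=10000)  # same distribution: should match
w3 = rng.normal(0.5, 1.0, size=10000)  # shifted mean: should not match
print(demo_means_match(w1, w2))  # True (with overwhelming probability)
print(demo_means_match(w1, w3))  # False
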
def check_cache_compile():
    files = [
        "src/utils/cache_compile.cc",
        "src/utils/log.cc",
        "src/utils/tracer.cc",
        "src/utils/jit_utils.cc",
        "src/utils/str_utils.cc",
    ]
    if os.name == 'nt':
        files = [x.replace('/', '\\') for x in files]
    global jit_utils_core_files
    jit_utils_core_files = files
    recompile = compile(cc_path, cc_flags + f" {opt_flags} ", files,
                        jit_utils.cache_path + '/jit_utils_core' + extension_suffix,
                        True)
    if recompile and jit_utils.cc:
        LOG.e("jit_utils updated, please rerun your command.")
        sys.exit(0)
    if not jit_utils.cc:
        with jit_utils.import_scope(import_flags):
            jit_utils.try_import_jit_utils_core()
        assert jit_utils.cc
        # recompile, generate cache key
        compile(cc_path, cc_flags + f" {opt_flags} ", files,
                jit_utils.cache_path + '/jit_utils_core' + extension_suffix,
                True)

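# For intuition, a standalone sketch of the staleness test a cache-compile
# helper performs (hypothetical names, not the real compile() above, which
# does more work, e.g. the second call regenerates a cache key): rebuild
# whenever any source file is newer than the compiled output.
import os

def needs_recompile(sources, output):
    if not os.path.isfile(output):
        return True
    out_mtime = os.path.getmtime(output)
    return any(os.path.getmtime(s) > out_mtime for s in sources)
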
def check(self, name, pre_data, data):
    global has_error
    if type(pre_data) != type(data):
        LOG.e(
            f"type not match, {pre_data.__class__.__name__}"
            f"!={data.__class__.__name__}, name: {name}")
        has_error += 1
        return
    if isinstance(pre_data, (list, tuple)):
        if len(pre_data) != len(data):
            has_error += 1
            LOG.e(f"Name <{name}> len not match, "
                  f"{len(pre_data)} != {len(data)}")
        # walk both sequences in lockstep, padding the shorter one so the
        # extra elements are still reported
        n = max(len(pre_data), len(data))
        for i in range(n):
            a = pre_data[i] if i < len(pre_data) else "None"
            b = data[i] if i < len(data) else "None"
            self.check(name + f".{i}", a, b)
    elif isinstance(pre_data, np.ndarray):
        if pre_data.shape != data.shape:
            has_error += 1
            LOG.e(f"Ndarray shape <{name}> not match")
            return
        self.check_array(name, pre_data, data)
    elif isinstance(pre_data, dict):
        if len(pre_data) != len(data):
            has_error += 1
            LOG.w(f"Dict Name <{name}> len not match, "
                  f"{len(pre_data)} != {len(data)}")
        for k in pre_data:
            pv = pre_data[k]
            if k not in data:
                has_error += 1
                msg = f"Key <{k}> not in data, Name <{name}>"
                # a missing ndarray is an error; other missing values only
                # warrant a warning
                if isinstance(pv, np.ndarray):
                    LOG.e(msg)
                else:
                    LOG.w(msg)
                continue
            self.check(name + f".{k}", pre_data[k], data[k])
    else:
        if pre_data != data:
            has_error += 1
            LOG.e(f"Type: {type(pre_data).__name__} Name <{name}> "
                  f"not match {pre_data} != {data}")

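# Hypothetical usage sketch (the `checker` instance and the data below are
# illustrative, not from the original code): check() recurses through nested
# containers and builds a dotted path, so the log names the exact leaf that
# diverged.
#
#   pre = {"loss": 0.5, "grads": [np.ones(3), np.zeros(3)]}
#   cur = {"loss": 0.5, "grads": [np.ones(3), np.ones(3)]}
#   checker.check("step0", pre, cur)
#   # -> reports a mismatch at "step0.grads.1" and increments has_error
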
def check_cache_compile():
    files = [
        "src/utils/cache_compile.cc",
        "src/utils/log.cc",
        "src/utils/tracer.cc",
        "src/utils/jit_utils.cc",
    ]
    global jit_utils_core_files
    jit_utils_core_files = files
    recompile = compile(cc_path, cc_flags + f" {opt_flags} ", files,
                        'jit_utils_core' + extension_suffix, True)
    if recompile and jit_utils.cc:
        LOG.e("jit_utils updated, please restart jittor.")
        sys.exit(0)
    if not jit_utils.cc:
        with jit_utils.import_scope(import_flags):
            jit_utils.try_import_jit_utils_core()
        assert jit_utils.cc
        # recompile, generate cache key
        compile(cc_path, cc_flags + f" {opt_flags} ", files,
                'jit_utils_core' + extension_suffix, True)

def check_array(self, name, a, b):
    rtol = self.rtol
    atol = self.atol
    global has_error
    err = np.abs(a - b)
    tol = atol + rtol * np.abs(b)
    # an element fails if it exceeds the mixed tolerance, or if a and b
    # disagree in sign (with a small slack around zero)
    is_error = np.logical_or(err > tol, (a >= -1e-5) != (b >= -1e-5))
    index = np.where(is_error)
    # np.where returns one index array per dimension; a 0-d input would
    # break the element report below
    assert len(index) > 0
    if len(index[0]) == 0:
        return
    has_error += 1
    LOG.w(f"Ndarray <{name}> not match, shape:{a.shape}")
    # report the first failing element
    i = tuple(i[0] for i in index)
    err_rate = is_error.mean()
    LOG.w(f"error index at [{i}], a({a[i]}) b({b[i]}) err({err[i]}) > "
          f"tol({tol[i]}), err_rate:{err_rate*100:.3f}% "
          f"amean({a.mean()}) bmean({b.mean()}) "
          f"astd({a.std()}) bstd({b.std()})")
    if err_rate > 0.01:
        LOG.e("!" * 10 + "Very HIGH err rate" + "!" * 10)

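# A standalone numeric illustration of the elementwise test above (values
# and the loose atol/rtol are hypothetical, chosen for the demo). The second
# element fails the magnitude tolerance; the third passes it but flips sign
# across the near-zero boundary, which the second term catches.
import numpy as np

a = np.array([1.000, 1.200, 0.0005])
b = np.array([1.001, 1.000, -0.0005])
atol, rtol = 1e-3, 0.1
err = np.abs(a - b)
tol = atol + rtol * np.abs(b)
sign_flip = (a >= -1e-5) != (b >= -1e-5)
print(err > tol)    # [False  True False]  magnitude failures
print(sign_flip)    # [False False  True]  sign disagreement
print(np.logical_or(err > tol, sign_flip))  # combined error mask
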
def record(self, name, data, ex_name=""):
    if os.environ.get("use_auto_diff", '1') == '0':
        return
    rid = self.rid
    self.rid += 1
    fpath = os.path.join(self.base_path, f"{rid}.pkl")
    data = convert(data)
    if os.path.isfile(fpath):
        with open(fpath, 'rb') as f:
            pre_name, pre_data = pickle.load(f)
        if pre_name != name:
            global has_error
            has_error += 1
            LOG.e(f"The {rid} result name not match, {pre_name} != {name}")
            # roll the counter back so the next record reuses this slot
            self.rid -= 1
            return
        LOG.i(f"check {rid}:<{ex_name}{name}> ...")
        self.check(ex_name + name, pre_data, data)
    else:
        with open(fpath, 'wb') as f:
            pickle.dump((name, data), f)
        LOG.i(f"save {rid}:<{name}> ok")

def record(self, name, data, ex_name=""):
    if os.environ.get("use_auto_diff", '1') == '0':
        return
    # record_status counts how many times each name has been recorded, so
    # repeated records of the same tensor get distinct files
    self.record_status[name] += 1
    fpath = os.path.join(self.base_path,
                         f"{name}-{self.record_status[name]}.pkl")
    data = convert(data)
    self.rid += 1
    if self.mode == 'check':
        if os.path.isfile(fpath):
            with open(fpath, 'rb') as f:
                pre_name, pre_data = pickle.load(f)
            LOG.i(f"check {self.rid}:<{ex_name}{name}> ...")
            self.check(ex_name + name, pre_data, data)
        else:
            global has_error
            has_error += 1
            LOG.e(f"No previous result found: {name}")
            return
    else:
        with open(fpath, 'wb') as f:
            pickle.dump((name, data), f)
        LOG.i(f"save {self.rid}:<{name}> ok")

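# Hypothetical two-process usage of the save/check protocol above (the
# `Checker` name, constructor arguments, and paths are assumptions based on
# the code, not the original API):
#
#   # run 1, reference implementation: dump every tensor
#   checker = Checker(base_path="/tmp/autodiff", mode='save')
#   checker.record("conv1_out", output)   # -> conv1_out-1.pkl
#
#   # run 2, ported implementation: compare against the dumps
#   checker = Checker(base_path="/tmp/autodiff", mode='check')
#   checker.record("conv1_out", output)   # checks against conv1_out-1.pkl
#
# Keying files by name plus a per-name call counter (unlike the earlier
# record(), which used a single global sequence id) keeps the check robust
# when the two runs record values in a different order.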