Example 1
 def record_params(self, parameters_dict):
     """Record model parameters on the first run, or check them on replay.

     On the first run the converted parameters are pickled to
     ``{rid}-params.pkl`` under ``self.base_path``. On a later run the
     stored values are loaded back and compared statistically against the
     current ones; each mismatch increments the global ``has_error``
     counter and is logged via ``LOG.e``. After checking, the previously
     recorded values are copied back into the live parameters so both
     runs continue from identical weights.

     Disabled when the ``use_auto_diff`` environment variable is '0'.
     """
     if os.environ.get("use_auto_diff", '1') == '0':
         return
     rid = self.rid
     self.rid += 1
     global has_error
     # Skip BatchNorm bookkeeping entries; they are not weights to compare.
     pps = {}
     for k, v in parameters_dict.items():
         if k.endswith("num_batches_tracked"):
             continue
         pps[k] = v
     # Convert each framework-specific parameter to a comparable form
     # (presumably a numpy array — `convert` is defined elsewhere).
     ps = {name: convert(param) for name, param in pps.items()}
     fpath = os.path.join(self.base_path, f"{rid}-params.pkl")
     if os.path.isfile(fpath):
         # Checking run: load the previously recorded parameters.
         with open(fpath, 'rb') as f:
             prev_ps = pickle.load(f)
         if len(prev_ps) != len(ps):
             has_error += 1
             LOG.e(f"Params len not match {len(prev_ps)} != {len(ps)}")
         for k in ps:
             a = ps[k]
             if k not in prev_ps:
                 has_error += 1
                 LOG.e(f"prev param <{k}> not found.")
                 continue
             b = prev_ps[k]
             if a.shape != b.shape:
                 has_error += 1
                 LOG.e(
                     f"Params <{k}> shape not match {a.shape} != {b.shape}")
                 continue
             std_a, mean_a = a.std(), a.mean()
             std_b, mean_b = b.std(), b.mean()
             n = a.size
             # law of large number: the sample mean/std of n values
             # fluctuate with roughly these standard errors, so two runs
             # should agree within a few multiples of them. The 1e-6
             # keeps the tolerance positive for tiny tensors.
             std_mean_a = (std_a + std_b) / 2 / np.sqrt(n) + 1e-6
             std_std_a = (std_a + std_b) / 2 / np.sqrt((n - 1) / 2) + 1e-6
             x = 4  # tolerance: allow up to 4 standard errors of drift
             if np.abs(mean_a - mean_b) > x * std_mean_a:
                 has_error += 1
                 LOG.e(
                     f"param mean not match, mean_a:{mean_a}, mean_b:{mean_b}, acceptable range:({mean_a - x * std_mean_a}, {mean_a + x * std_mean_a}) name:{k} shape:{a.shape}"
                 )
             elif np.abs(std_a - std_b) > x * std_std_a:
                 has_error += 1
                 LOG.e(
                     f"param std not match, std_a:{std_a}, std_b:{std_b}, acceptable range:({std_a - x * std_std_a}, {std_a + x * std_std_a}) name:{k} shape:{a.shape}"
                 )
             else:
                 LOG.i(f"check param ok: <{k}>  shape:{a.shape}")
             # Overwrite the live parameter with the recorded value so the
             # two frameworks proceed from identical weights.
             var = pps[k]
             if hasattr(var, "copy_"):
                 # torch tensor path: copy the recorded numpy data in-place.
                 import torch
                 var.data.copy_(torch.from_numpy(b))
             else:
                 # non-torch variable path — assumes an .assign() method
                 # (e.g. a jittor Var); TODO confirm against callers.
                 var.assign(b)
     else:
         # Recording run: persist the converted parameters for later checks.
         with open(fpath, 'wb') as f:
             pickle.dump(ps, f)
         LOG.i(f"save params ok")
Example 2
def check_cache_compile():
    """Compile the jit_utils core sources and ensure the extension is loaded.

    Publishes the source list through the module-level
    ``jit_utils_core_files``, exits when a rebuild happened while a stale
    extension is already imported, and otherwise imports the freshly built
    core and compiles once more to generate the cache key.
    """
    srcs = [
        "src/utils/cache_compile.cc",
        "src/utils/log.cc",
        "src/utils/tracer.cc",
        "src/utils/jit_utils.cc",
        "src/utils/str_utils.cc",
    ]
    if os.name == 'nt':
        # Windows toolchains expect backslash-separated paths.
        srcs = [s.replace('/', '\\') for s in srcs]
    global jit_utils_core_files
    jit_utils_core_files = srcs
    flags = cc_flags + f" {opt_flags} "
    target = jit_utils.cache_path + '/jit_utils_core' + extension_suffix
    recompile = compile(cc_path, flags, srcs, target, True)
    if recompile and jit_utils.cc:
        # A rebuilt core cannot replace the one already imported in-process.
        LOG.e("jit_utils updated, please rerun your command.")
        sys.exit(0)
    if not jit_utils.cc:
        with jit_utils.import_scope(import_flags):
            jit_utils.try_import_jit_utils_core()
        assert jit_utils.cc
        # recompile, generate cache key
        compile(cc_path, flags, srcs, target, True)
Example 3
 def check(self, name, pre_data, data):
     """Recursively compare a recorded value against a freshly computed one.

     Mismatches in type, length, shape, or value increment the global
     ``has_error`` counter and are logged under the dotted path ``name``.
     Lists/tuples and dicts are compared element-by-element via recursion;
     ndarrays are delegated to ``self.check_array``.
     """
     global has_error
     if type(pre_data) != type(data):
         has_error += 1
         LOG.e(
             f"type not match, {pre_data.__class__.__name__}!={data.__class__.__name__}, name: {name}"
         )
         return
     if isinstance(pre_data, (list, tuple)):
         len_a, len_b = len(pre_data), len(data)
         if len_a != len_b:
             has_error += 1
             LOG.e(
                 f"Name <{name}> len not match, {len(pre_data)} != {len(data)}"
             )
         # Walk the longer sequence; out-of-range slots become the string
         # "None" so the mismatch gets reported instead of raising.
         for idx in range(max(len_a, len_b)):
             elem_a = pre_data[idx] if idx < len_a else "None"
             elem_b = data[idx] if idx < len_b else "None"
             self.check(f"{name}.{idx}", elem_a, elem_b)
     elif isinstance(pre_data, np.ndarray):
         if pre_data.shape != data.shape:
             has_error += 1
             LOG.e(f"Ndarray shape <{name}> not match")
             return
         self.check_array(name, pre_data, data)
     elif isinstance(pre_data, dict):
         if len(pre_data) != len(data):
             has_error += 1
             LOG.w(
                 f"Dict Name <{name}> len not match, {len(pre_data)} != {len(data)}"
             )
         for key, pv in pre_data.items():
             if key not in data:
                 has_error += 1
                 msg = f"Key <{key}> not in data, Name <{name}>"
                 # A missing ndarray is an error; other missing values only warn.
                 (LOG.e if isinstance(pv, np.ndarray) else LOG.w)(msg)
                 continue
             self.check(f"{name}.{key}", pv, data[key])
     else:
         # Leaf value: plain equality.
         if pre_data != data:
             has_error += 1
             LOG.e(
                 f"Type: {type(pre_data).__name__} Name <{name}> not match {pre_data} != {data}"
             )
Example 4
def check_cache_compile():
    """Build the jit_utils core extension and make sure it is importable.

    Records the source list in the module-level ``jit_utils_core_files``,
    exits when a rebuild happened while a stale extension is loaded, and
    otherwise imports the newly built core and compiles again to generate
    the cache key.
    """
    global jit_utils_core_files
    jit_utils_core_files = [
        "src/utils/cache_compile.cc",
        "src/utils/log.cc",
        "src/utils/tracer.cc",
        "src/utils/jit_utils.cc",
    ]
    out_name = 'jit_utils_core' + extension_suffix
    recompile = compile(cc_path, cc_flags + f" {opt_flags} ",
                        jit_utils_core_files, out_name, True)
    if recompile and jit_utils.cc:
        # The already-imported extension is stale; a fresh process is needed.
        LOG.e("jit_utils updated, please restart jittor.")
        sys.exit(0)
    if not jit_utils.cc:
        with jit_utils.import_scope(import_flags):
            jit_utils.try_import_jit_utils_core()
        assert jit_utils.cc
        # recompile, generate cache key
        compile(cc_path, cc_flags + f" {opt_flags} ",
                jit_utils_core_files, out_name, True)
Example 5
    def check_array(self, name, a, b):
        """Element-wise compare two numpy arrays of identical shape.

        An element mismatches when its absolute difference exceeds
        ``atol + rtol * |b|`` or when the two values disagree in sign
        (with 1e-5 slack around zero). On any mismatch the global
        ``has_error`` counter is incremented and the first offending
        element plus summary statistics are logged; an error rate above
        1% is additionally flagged as severe.
        """
        rtol = self.rtol
        atol = self.atol
        global has_error
        err = np.abs(a - b)
        tol = atol + rtol * np.abs(b)
        # A sign flip counts as an error even when within magnitude tolerance.
        is_error = np.logical_or(err > tol, (a >= -1e-5) != (b >= -1e-5))
        index = np.where(is_error)
        # np.where returns one index array per dimension, so the original
        # `assert len(index) > 0` was always true (and stripped under -O);
        # emptiness is determined by the index arrays themselves.
        if len(index[0]) == 0:
            return

        has_error += 1
        LOG.w(f"Ndarray <{name}> not match, shape:{a.shape}")
        # Coordinates of the first mismatching element.
        first = tuple(idx[0] for idx in index)
        err_rate = is_error.mean()
        LOG.w(f"error index at [{first}], a({a[first]}) b({b[first]}) err({err[first]}) > tol({tol[first]}), err_rate:{err_rate*100:.3f}% amean({a.mean()}) bmean({b.mean()}) astd({a.std()}) bstd({b.std()}) ")
        if err_rate > 0.01:
            LOG.e("!"*10+"Very HIGH err rate"+"!"*10)
Example 6
 def record(self, name, data, ex_name=""):
     """Save a named result on the first run, or check it on replay.

     The current record id ``self.rid`` selects the pickle file. On a
     name mismatch the id bump is rolled back so later records can
     realign. Disabled when the ``use_auto_diff`` env var is '0'.
     """
     if os.environ.get("use_auto_diff", '1') == '0':
         return
     rid = self.rid
     self.rid += 1
     fpath = os.path.join(self.base_path, f"{rid}.pkl")
     data = convert(data)
     if not os.path.isfile(fpath):
         # Recording run: persist (name, data) for a later checking run.
         with open(fpath, 'wb') as f:
             pickle.dump((name, data), f)
         LOG.i(f"save {rid}:<{name}> ok")
         return
     # Checking run: load the stored result and compare.
     with open(fpath, 'rb') as f:
         pre_name, pre_data = pickle.load(f)
     if pre_name != name:
         global has_error
         has_error += 1
         LOG.e(f"The {rid} result name not match, {pre_name} != {name}")
         # Undo the id bump so subsequent records stay aligned.
         self.rid -= 1
         return
     LOG.i(f"check {rid}:<{ex_name}{name}> ...")
     self.check(ex_name + name, pre_data, data)
Example 7
    def record(self, name, data, ex_name=""):
        """Save or verify a named result, keyed by name and occurrence count.

        Each call bumps ``self.record_status[name]`` so repeated records
        of the same name map to distinct files. In 'check' mode the stored
        value is compared against ``data``; in any other mode it is
        written out. Disabled when the ``use_auto_diff`` env var is '0'.
        """
        if os.environ.get("use_auto_diff", '1') == '0':
            return
        self.record_status[name] += 1
        occurrence = self.record_status[name]
        fpath = os.path.join(self.base_path, f"{name}-{occurrence}.pkl")
        data = convert(data)
        self.rid += 1

        if self.mode != 'check':
            # Recording run: persist the converted data.
            with open(fpath, 'wb') as f:
                pickle.dump((name, data), f)
            LOG.i(f"save {self.rid}:<{name}> ok")
            return
        if not os.path.isfile(fpath):
            global has_error
            has_error += 1
            LOG.e(f"No previous result found: {name}")
            return
        with open(fpath, 'rb') as f:
            pre_name, pre_data = pickle.load(f)
        LOG.i(f"check {self.rid}:<{ex_name}{name}> ...")
        self.check(ex_name + name, pre_data, data)