def test_utils_abnormal_checker_wrapper(self):
    """The wrapper must be transparent (same outputs as the wrapped model)
    while FileWriter persists one ``.pth`` record per abnormal loss step.
    """
    model = Model()
    with tempfile.TemporaryDirectory() as tmp_dir:
        checker = ac.AbnormalLossChecker(-1, writers=[ac.FileWriter(tmp_dir)])
        wrapped = ac.AbnormalLossCheckerWrapper(model, checker)

        # Feed a fixed loss sequence; the wrapper must not alter outputs.
        for loss in [5, 4, 3, 10, 9, 2, 5, 4]:
            self.assertEqual(wrapped(loss), model(loss))

        # Steps 3 and 6 are expected to be flagged as abnormal, so exactly
        # two log files should have been written.
        GT_INVALID_INDICES = [3, 6]
        saved_files = glob.glob(f"{tmp_dir}/*.pth")
        self.assertEqual(len(saved_files), 2)

        logged_indices = []
        for path in saved_files:
            record = torch.load(path, map_location="cpu")
            # Each record is a dict carrying at least "data" and "step".
            self.assertIsInstance(record, dict)
            self.assertIn("data", record)
            logged_indices.append(record["step"])

        self.assertSetEqual(set(logged_indices), set(GT_INVALID_INDICES))
def _get_model_with_abnormal_checker(model):
    """Wrap ``model`` with an abnormal-loss checker when enabled in config.

    Returns the model unchanged when ``cfg.ABNORMAL_CHECKER.ENABLED`` is
    false; otherwise returns an ``AbnormalLossCheckerWrapper`` around it.
    NOTE(review): ``cfg``, ``start_iter``, ``_get_tbx_writer`` and
    ``get_tensorboard_log_dir`` are closed over from an enclosing scope
    not visible in this chunk — confirm against the full file.
    """
    if not cfg.ABNORMAL_CHECKER.ENABLED:
        return model
    tensorboard_writer = _get_tbx_writer(get_tensorboard_log_dir(cfg.OUTPUT_DIR))
    checker = abnormal_checker.AbnormalLossChecker(
        start_iter, abnormal_checker.get_writers(cfg, tensorboard_writer)
    )
    return abnormal_checker.AbnormalLossCheckerWrapper(model, checker)
def _get_model_with_abnormal_checker(model):
    """Wrap ``model`` with an abnormal-loss checker when enabled in config.

    Returns the model as-is when the checker is disabled; otherwise builds
    an ``AbnormalLossChecker`` (fed by the configured writers plus the
    tensorboard writer) and returns the wrapped model.
    NOTE(review): ``cfg``, ``start_iter`` and ``self`` are closed over
    from an enclosing scope not visible in this chunk — confirm against
    the full file.
    """
    if not cfg.ABNORMAL_CHECKER.ENABLED:
        return model
    writers = abnormal_checker.get_writers(cfg, self.get_tbx_writer(cfg))
    return abnormal_checker.AbnormalLossCheckerWrapper(
        model, abnormal_checker.AbnormalLossChecker(start_iter, writers)
    )