Code example #1
0
    def test_utils_abnormal_checker_wrapper(self):
        """Check that AbnormalLossCheckerWrapper is output-transparent and logs.

        The wrapped model must return exactly what the bare model returns for
        every loss value. Afterwards, exactly two ``.pth`` files must have been
        written by the FileWriter. Each must be a dict with a ``"data"`` entry,
        and together their ``"step"`` values must equal the expected steps.
        """
        plain_model = Model()
        loss_sequence = [5, 4, 3, 10, 9, 2, 5, 4]
        # Step indices the checker is expected to log. The losses there (10
        # and 5) are rises over the previous value; presumably that is what
        # the checker treats as abnormal -- confirm against AbnormalLossChecker.
        expected_steps = {3, 6}

        with tempfile.TemporaryDirectory() as out_dir:
            # First positional arg mirrors the `start_iter` argument used by
            # the runner code; -1 here is assumed to mean "before step 0".
            loss_checker = ac.AbnormalLossChecker(
                -1, writers=[ac.FileWriter(out_dir)]
            )
            wrapped = ac.AbnormalLossCheckerWrapper(plain_model, loss_checker)

            # Wrapped call first, then the bare model, same as step by step.
            for value in loss_sequence:
                self.assertEqual(wrapped(value), plain_model(value))

            dumped = glob.glob(f"{out_dir}/*.pth")
            self.assertEqual(len(dumped), 2)

            seen_steps = []
            for path in dumped:
                record = torch.load(path, map_location="cpu")
                self.assertIsInstance(record, dict)
                self.assertIn("data", record)
                seen_steps.append(record["step"])
            self.assertSetEqual(set(seen_steps), expected_steps)
Code example #2
0
File: default_runner.py  Project: Pandinosaurus/d2go
        def _get_model_with_abnormal_checker(model):
            """Optionally wrap ``model`` in an abnormal-loss checker.

            If ``cfg.ABNORMAL_CHECKER.ENABLED`` is falsy, ``model`` is returned
            unchanged. Otherwise a writer is obtained with ``_get_tbx_writer``
            for the directory returned by
            ``get_tensorboard_log_dir(cfg.OUTPUT_DIR)``. That writer is passed
            to ``abnormal_checker.get_writers``, and the result feeds an
            ``AbnormalLossChecker`` created with ``start_iter`` from the
            enclosing scope. The function returns an
            ``AbnormalLossCheckerWrapper`` around ``model``.
            """
            if cfg.ABNORMAL_CHECKER.ENABLED:
                log_dir = get_tensorboard_log_dir(cfg.OUTPUT_DIR)
                writer_list = abnormal_checker.get_writers(
                    cfg, _get_tbx_writer(log_dir)
                )
                loss_checker = abnormal_checker.AbnormalLossChecker(
                    start_iter, writer_list
                )
                return abnormal_checker.AbnormalLossCheckerWrapper(
                    model, loss_checker
                )
            return model
Code example #3
0
        def _get_model_with_abnormal_checker(model):
            """Return ``model``, wrapped by an abnormal-loss checker if enabled.

            When ``cfg.ABNORMAL_CHECKER.ENABLED`` is falsy, this is a
            pass-through. Otherwise the writer from ``self.get_tbx_writer(cfg)``
            (``self`` and ``cfg`` come from the enclosing scope) is turned into
            a writer list by ``abnormal_checker.get_writers``. An
            ``AbnormalLossChecker`` starting at ``start_iter`` is built from
            that list and attached to ``model`` via
            ``AbnormalLossCheckerWrapper``.
            """
            if cfg.ABNORMAL_CHECKER.ENABLED:
                tbx = self.get_tbx_writer(cfg)
                writer_list = abnormal_checker.get_writers(cfg, tbx)
                loss_checker = abnormal_checker.AbnormalLossChecker(
                    start_iter, writer_list
                )
                return abnormal_checker.AbnormalLossCheckerWrapper(
                    model, loss_checker
                )
            return model