Exemplo n.º 1
0
    def test_utils_abnormal_checker_wrapper(self):
        """Check that the checker wrapper is transparent and logs to disk.

        Each loss is passed through the wrapped model and then the bare
        model; both results must compare equal. Afterwards the temporary
        directory must contain exactly two ``.pth`` files, each loading as a
        dict with a ``"data"`` entry, and their ``"step"`` values must be
        exactly 3 and 6.
        """
        model = Model()
        # Presumably the 0-based positions of the losses 10 and 5, which both
        # rise above the value before them -- confirm against the checker.
        expected_steps = {3, 6}

        with tempfile.TemporaryDirectory() as tmp_dir:
            file_writer = ac.FileWriter(tmp_dir)
            checker = ac.AbnormalLossChecker(-1, writers=[file_writer])
            wrapped = ac.AbnormalLossCheckerWrapper(model, checker)

            # The wrapper is called first, then the bare model, on the same
            # input.
            for loss in (5, 4, 3, 10, 9, 2, 5, 4):
                self.assertEqual(wrapped(loss), model(loss))

            log_paths = glob.glob(f"{tmp_dir}/*.pth")
            self.assertEqual(len(log_paths), 2)

            seen_steps = set()
            for path in log_paths:
                entry = torch.load(path, map_location="cpu")
                self.assertIsInstance(entry, dict)
                self.assertIn("data", entry)
                seen_steps.add(entry["step"])
            self.assertSetEqual(seen_steps, expected_steps)
Exemplo n.º 2
0
        def _get_model_with_abnormal_checker(model):
            """Return ``model``, wrapped with an abnormal-loss checker if enabled.

            When ``cfg.ABNORMAL_CHECKER.ENABLED`` is falsy the model is
            returned untouched. Otherwise writers are built from ``cfg`` and a
            tensorboard writer for the log directory derived from
            ``cfg.OUTPUT_DIR``, and the model is wrapped with a checker
            constructed from ``start_iter`` and those writers.
            """
            if cfg.ABNORMAL_CHECKER.ENABLED:
                log_dir = get_tensorboard_log_dir(cfg.OUTPUT_DIR)
                writers = abnormal_checker.get_writers(cfg, _get_tbx_writer(log_dir))
                loss_checker = abnormal_checker.AbnormalLossChecker(start_iter, writers)
                return abnormal_checker.AbnormalLossCheckerWrapper(model, loss_checker)
            return model
Exemplo n.º 3
0
        def _get_model_with_abnormal_checker(model):
            """Return ``model``, wrapped with an abnormal-loss checker if enabled.

            When ``cfg.ABNORMAL_CHECKER.ENABLED`` is falsy the model is
            returned untouched. Otherwise writers are built from ``cfg`` and
            the writer returned by ``self.get_tbx_writer(cfg)``, and the model
            is wrapped with a checker constructed from ``start_iter`` and
            those writers.
            """
            if cfg.ABNORMAL_CHECKER.ENABLED:
                writers = abnormal_checker.get_writers(cfg, self.get_tbx_writer(cfg))
                loss_checker = abnormal_checker.AbnormalLossChecker(start_iter, writers)
                return abnormal_checker.AbnormalLossCheckerWrapper(model, loss_checker)
            return model
Exemplo n.º 4
0
    def test_utils_abnormal_checker(self):
        """Check how often the checker invokes its writer for a loss sequence.

        A recording writer is registered with the checker, eight losses are
        submitted one at a time via ``check_step``, and the writer must have
        been called exactly twice.
        """
        writer_calls = []

        def _writer(all_data):
            # The payload is ignored; only the number of invocations matters.
            writer_calls.append(None)

        checker = ac.AbnormalLossChecker(-1, writers=[_writer])
        for value in (5, 4, 3, 10, 9, 2, 5, 4):
            checker.check_step({"loss": value})

        self.assertEqual(len(writer_calls), 2)