Example #1
    def test_d2go_runner_train_qat_hook_update_stat(self):
        """Check that the qat hook is used and updates stats"""
        @META_ARCH_REGISTRY.register()
        class MetaArchForTestQAT(MetaArchForTest):
            def prepare_for_quant(self, cfg):
                """Set the qconfig to updateable observers"""
                self.qconfig = updateable_symmetric_moving_avg_minmax_config
                return self

        def setup(tmp_dir):
            ds_name = create_local_dataset(tmp_dir, 5, 10, 10)
            runner = default_runner.Detectron2GoRunner()
            cfg = _get_cfg(runner, tmp_dir, ds_name)
            cfg.merge_from_list(
                (["MODEL.META_ARCHITECTURE", "MetaArchForTestQAT"] +
                 ["QUANTIZATION.QAT.ENABLED", "True"] +
                 ["QUANTIZATION.QAT.START_ITER", "0"] +
                 ["QUANTIZATION.QAT.ENABLE_OBSERVER_ITER", "0"]))
            return runner, cfg

        # check that the observer's min/max values have not changed (only its running stats have)
        with tempfile.TemporaryDirectory() as tmp_dir:
            runner, cfg = setup(tmp_dir)
            model = runner.build_model(cfg)
            runner.do_train(cfg, model, resume=True)
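            # After QAT preparation, conv.activation_post_process is the fake-quantize wrapper;
            # its own activation_post_process attribute is the underlying observer inspected below.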
            observer = model.conv.activation_post_process.activation_post_process
            self.assertEqual(observer.min_val, torch.tensor(float("inf")))
            self.assertEqual(observer.max_val, torch.tensor(float("-inf")))
            self.assertNotEqual(observer.max_stat, torch.tensor(float("inf")))

        # check that the observer's min/max values do not change if the update period is greater than max_iter
        with tempfile.TemporaryDirectory() as tmp_dir:
            runner, cfg = setup(tmp_dir)
            cfg.merge_from_list(([
                "QUANTIZATION.QAT.UPDATE_OBSERVER_STATS_PERIODICALLY", "True"
            ] + ["QUANTIZATION.QAT.UPDATE_OBSERVER_STATS_PERIOD", "10"]))
            model = runner.build_model(cfg)
            runner.do_train(cfg, model, resume=True)
            observer = model.conv.activation_post_process.activation_post_process
            self.assertEqual(observer.min_val, torch.tensor(float("inf")))
            self.assertEqual(observer.max_val, torch.tensor(float("-inf")))
            self.assertNotEqual(observer.max_stat, torch.tensor(float("inf")))

        # check that the observer's min/max values change if the update period is less than max_iter
        with tempfile.TemporaryDirectory() as tmp_dir:
            runner, cfg = setup(tmp_dir)
            cfg.merge_from_list(([
                "QUANTIZATION.QAT.UPDATE_OBSERVER_STATS_PERIODICALLY", "True"
            ] + ["QUANTIZATION.QAT.UPDATE_OBSERVER_STATS_PERIOD", "1"]))
            model = runner.build_model(cfg)
            runner.do_train(cfg, model, resume=True)
            observer = model.conv.activation_post_process.activation_post_process
            self.assertNotEqual(observer.min_val, torch.tensor(float("inf")))
            self.assertNotEqual(observer.max_val, torch.tensor(float("-inf")))
            self.assertNotEqual(observer.max_stat, torch.tensor(float("inf")))

        default_runner._close_all_tbx_writers()
Example #2
 def _run_train(cfg):
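     # Train a copy of the config with a DDP-wrapped model, check that the final
     # checkpoint contains the EMA state, and return the checkpoint path and EMA state.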
     cfg = copy.deepcopy(cfg)
     model = runner.build_model(cfg)
     model = DistributedDataParallel(model, broadcast_buffers=False)
     runner.do_train(cfg, model, True)
     final_model_path = os.path.join(tmp_dir, "model_final.pth")
     trained_weights = torch.load(final_model_path)
     self.assertIn("ema_state", trained_weights)
     default_runner._close_all_tbx_writers()
     return final_model_path, model.module.ema_state
Example #3
    def test_d2go_runner_test(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            ds_name = create_local_dataset(tmp_dir, 5, 10, 10)
            runner = default_runner.Detectron2GoRunner()
            cfg = _get_cfg(runner, tmp_dir, ds_name)

            model = runner.build_model(cfg)
            results = runner.do_test(cfg, model)
            self.assertEqual(results["default"][ds_name]["bbox"]["AP"], 10.0)
            default_runner._close_all_tbx_writers()
Example #4
    def test_d2go_runner_train(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            ds_name = create_local_dataset(tmp_dir, 5, 10, 10)
            runner = default_runner.Detectron2GoRunner()
            cfg = _get_cfg(runner, tmp_dir, ds_name)

            model = runner.build_model(cfg)
            runner.do_train(cfg, model, resume=True)
            final_model_path = os.path.join(tmp_dir, "model_final.pth")
            self.assertTrue(os.path.isfile(final_model_path))
            default_runner._close_all_tbx_writers()
Example #5
 def _run_test(cfg, final_path, gt_ema):
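     # Rebuild the model from the trained checkpoint in eval mode, verify the restored
     # EMA state matches the one produced by training, and compare default vs. EMA eval AP.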
     cfg = copy.deepcopy(cfg)
     cfg.MODEL.WEIGHTS = final_path
     model = runner.build_model(cfg, eval_only=True)
     self.assertGreater(len(model.ema_state.state), 0)
     self.assertEqual(len(model.ema_state.state), len(gt_ema.state))
     self.assertTrue(
         _compare_state_dict(model.ema_state.state_dict(),
                             gt_ema.state_dict()))
     results = runner.do_test(cfg, model)
     self.assertEqual(results["default"][ds_name]["bbox"]["AP"],
                      3.0)
     self.assertEqual(results["ema"][ds_name]["bbox"]["AP"], 9.0)
     default_runner._close_all_tbx_writers()
Example #6
    def test_build_model(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            ds_name = create_local_dataset(tmp_dir, 5, 10, 10)
            runner = oss_runner.DETRRunner()
            cfg = _get_cfg(runner, tmp_dir, ds_name)
            model = runner.build_model(cfg)
            dl = runner.build_detection_train_loader(cfg)
            batch = next(iter(dl))
            output = model(batch)
            self.assertIsInstance(output, dict)

            model.eval()
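            # In eval mode the model returns per-image predictions (a list) rather than a loss dict.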
            output = model(batch)
            self.assertIsInstance(output, list)
            default_runner._close_all_tbx_writers()
Example #7
    def test_d2go_runner_train_qat(self):
        """Make sure QAT runs"""
        @META_ARCH_REGISTRY.register()
        class MetaArchForTestQAT1(torch.nn.Module):
            def __init__(self, cfg):
                super().__init__()
                self.conv = torch.nn.Conv2d(3,
                                            4,
                                            kernel_size=3,
                                            stride=1,
                                            padding=1)

            @property
            def device(self):
                return self.conv.weight.device

            def forward(self, inputs):
                images = [x["image"] for x in inputs]
                images = ImageList.from_tensors(images, 1)

                ret = self.conv(images.tensor)
                losses = {"loss": ret.norm()}

                # run the same conv again (exercises a module that is reused within forward under QAT)
                ret1 = self.conv(images.tensor)
                losses["ret1"] = ret1.norm()

                return losses

        def setup(tmp_dir, backend):
            ds_name = create_local_dataset(tmp_dir, 5, 10, 10)
            runner = default_runner.Detectron2GoRunner()
            cfg = _get_cfg(runner, tmp_dir, ds_name)
            cfg.merge_from_list(
                (["MODEL.META_ARCHITECTURE", "MetaArchForTestQAT1"] +
                 ["QUANTIZATION.QAT.ENABLED", "True"] +
                 ["QUANTIZATION.QAT.START_ITER", "0"] +
                 ["QUANTIZATION.QAT.ENABLE_OBSERVER_ITER", "0"] +
                 ["QUANTIZATION.BACKEND", backend]))
            return runner, cfg

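        # Run QAT end to end with both PyTorch quantization backends
        # (fbgemm targets x86; qnnpack targets ARM/mobile).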
        for backend in ["fbgemm", "qnnpack"]:
            with tempfile.TemporaryDirectory() as tmp_dir:
                runner, cfg = setup(tmp_dir, backend=backend)
                model = runner.build_model(cfg)
                runner.do_train(cfg, model, resume=True)

            default_runner._close_all_tbx_writers()
Example #8
    def test_modeling_hook_runner(self):
        """Create model with modeling hook from runner"""
        runner = default_runner.Detectron2GoRunner()
        cfg = runner.get_default_cfg()
        cfg.MODEL.DEVICE = "cpu"
        cfg.MODEL.META_ARCHITECTURE = "TestArch"
        cfg.MODEL.MODELING_HOOKS = ["PlusOneHook", "TimesTwoHook"]
        model = runner.build_model(cfg)
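        # TestArch doubles its input (orig_model(2) == 4 below); the hooks wrap it in
        # registration order, so PlusOneHook then TimesTwoHook yields (2 * 2 + 1) * 2 == 10.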
        self.assertEqual(model(2), 10)

        self.assertTrue(hasattr(model, "_modeling_hooks"))
        self.assertTrue(hasattr(model, "unapply_modeling_hooks"))
        orig_model = model.unapply_modeling_hooks()
        self.assertIsInstance(orig_model, TestArch)
        self.assertEqual(orig_model(2), 4)

        default_runner._close_all_tbx_writers()
Example #9
    def test_d2go_runner_trainer_hooks(self):
        counts = 0

        @TRAINER_HOOKS_REGISTRY.register()
        def _check_hook_func(hooks):
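            # Functions in TRAINER_HOOKS_REGISTRY are called with the trainer's hook list;
            # record its length so the assertion below can verify hooks were registered.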
            nonlocal counts
            counts = len(hooks)
            print(hooks)

        with tempfile.TemporaryDirectory() as tmp_dir:
            ds_name = create_local_dataset(tmp_dir, 5, 10, 10)
            runner = default_runner.Detectron2GoRunner()
            cfg = _get_cfg(runner, tmp_dir, ds_name)
            model = runner.build_model(cfg)
            runner.do_train(cfg, model, resume=True)

            default_runner._close_all_tbx_writers()

        self.assertGreater(counts, 0)