Example #1
0
    def test_build_model(self, tmp_dir):
        """Train with EMA enabled, then rebuild the model from the config.

        Checks that ``GeneralizedRCNNTask.build_model``:

        * returns a model in training mode when called without ``eval_only``;
        * with ``eval_only=True`` and ``MODEL.WEIGHTS`` pointing at
          ``last.ckpt``, returns an eval-mode model whose state matches the
          trained ``task.model``;
        * when ``MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY`` is also set, returns
          an eval-mode model whose state matches ``task.ema_state`` instead.
        """
        cfg = self._get_cfg(tmp_dir)
        cfg.MODEL_EMA.ENABLED = True
        task = GeneralizedRCNNTask(cfg)
        trainer = self._get_trainer(tmp_dir)

        with EventStorage() as event_storage:
            task.storage = event_storage
            trainer.fit(task)

        # Building without eval_only: the result is expected to be trainable.
        fresh_model = GeneralizedRCNNTask.build_model(cfg)
        self.assertTrue(fresh_model.training)

        # NOTE(review): assumes the trainer from _get_trainer saves
        # ``last.ckpt`` into tmp_dir; confirm against that helper.
        last_ckpt = os.path.join(tmp_dir, "last.ckpt")

        # Regular (non-EMA) weights from the last checkpoint.
        with temp_defrost(cfg):
            cfg.MODEL.WEIGHTS = last_ckpt
            regular_model = GeneralizedRCNNTask.build_model(cfg, eval_only=True)
            self.assertFalse(regular_model.training)
            loaded_state = regular_model.state_dict()
            trained_state = task.model.state_dict()
            self.assertTrue(self._compare_state_dict(loaded_state, trained_state))

        # Same checkpoint, but ask for the EMA weights when evaluating.
        with temp_defrost(cfg):
            cfg.MODEL.WEIGHTS = last_ckpt
            cfg.MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY = True
            ema_model = GeneralizedRCNNTask.build_model(cfg, eval_only=True)
            self.assertFalse(ema_model.training)
            loaded_state = ema_model.state_dict()
            ema_state = task.ema_state.state_dict()
            self.assertTrue(self._compare_state_dict(loaded_state, ema_state))
Example #2
0
    def test_qat(self, tmp_dir):
        """Run one QAT training step on a toy meta-arch and check the results.

        Registers ``QuantizableDetMetaArchForTest``, whose ``avgpool`` submodule
        is FX-prepared for quantization-aware training and later converted.
        After training, checks that the attribute listed in
        ``preserved_attributes`` survives FX preparation while an unlisted one
        does not. It then checks that rebuilding the model in eval-only mode
        from ``last.ckpt`` yields a ``torch.fx.GraphModule`` for ``avgpool``.
        """

        @META_ARCH_REGISTRY.register()
        class QuantizableDetMetaArchForTest(mah.DetMetaArchForTest):
            # Passed to both prepare_qat_fx and convert_fx. Only the attribute
            # named here is expected to survive preparation (asserted below).
            custom_config_dict = {"preserved_attributes": ["preserved_attr"]}

            def __init__(self, cfg):
                super().__init__(cfg)
                # One attribute that should be preserved and one that should
                # be lost when avgpool is replaced by the prepared module.
                self.avgpool.preserved_attr = "foo"
                self.avgpool.not_preserved_attr = "bar"

            def prepare_for_quant(self, cfg):
                """Replace ``avgpool`` with its QAT-prepared FX module."""
                example_inputs = (torch.rand(1, 3, 3, 3), )
                self.avgpool = prepare_qat_fx(
                    self.avgpool,
                    {
                        # The empty key applies this qconfig to the whole
                        # avgpool module.
                        "":
                        set_backend_and_create_qconfig(cfg,
                                                       is_train=self.training)
                    },
                    example_inputs,
                    self.custom_config_dict,
                )
                return self

            def prepare_for_quant_convert(self, cfg):
                """Convert the prepared ``avgpool`` into its quantized form."""
                # NOTE(review): newer torch releases name this kwarg
                # ``convert_custom_config``; confirm against the pinned torch.
                self.avgpool = convert_fx(
                    self.avgpool,
                    convert_custom_config_dict=self.custom_config_dict)
                return self

        cfg = self._get_cfg(tmp_dir)
        cfg.MODEL.META_ARCHITECTURE = "QuantizableDetMetaArchForTest"
        cfg.QUANTIZATION.QAT.ENABLED = True
        task = GeneralizedRCNNTask(cfg)

        callbacks = [
            QuantizationAwareTraining.from_config(cfg),
            ModelCheckpoint(dirpath=task.cfg.OUTPUT_DIR, save_last=True),
        ]
        trainer = pl.Trainer(
            max_steps=1,
            limit_train_batches=1,
            num_sanity_val_steps=0,
            callbacks=callbacks,
            logger=False,
        )
        with EventStorage() as storage:
            task.storage = storage
            trainer.fit(task)
        prepared_avgpool = task._prepared.model.avgpool
        self.assertEqual(prepared_avgpool.preserved_attr, "foo")
        self.assertFalse(hasattr(prepared_avgpool, "not_preserved_attr"))

        with temp_defrost(cfg):
            cfg.MODEL.WEIGHTS = os.path.join(tmp_dir, "last.ckpt")
            model = GeneralizedRCNNTask.build_model(cfg, eval_only=True)
            # assertIsInstance reports the actual type on failure, unlike
            # assertTrue(isinstance(...)).
            self.assertIsInstance(model.avgpool, torch.fx.GraphModule)

    def test_build_model(self, tmp_dir):
        """Train one step with EMA enabled, then rebuild the model from config.

        Checks three uses of ``GeneralizedRCNNTask.build_model``:

        * without ``eval_only``: the returned model is in training mode;
        * ``eval_only=True`` with ``MODEL.WEIGHTS`` set to ``last.ckpt``: the
          model is in eval mode and its state matches the trained
          ``task.model``;
        * additionally ``MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY=True``: the
          model is in eval mode and its state matches ``task.ema_state``.
        """
        cfg = self._get_cfg(tmp_dir)
        cfg.MODEL_EMA.ENABLED = True
        task = GeneralizedRCNNTask(cfg)
        # save_last=True makes the callback write ``last.ckpt`` into
        # OUTPUT_DIR. NOTE(review): the loads below read it from tmp_dir, so
        # this presumes _get_cfg sets OUTPUT_DIR to tmp_dir; confirm there.
        checkpoint_callback = ModelCheckpoint(dirpath=task.cfg.OUTPUT_DIR,
                                              save_last=True)

        trainer = pl.Trainer(
            max_steps=1,
            limit_train_batches=1,
            num_sanity_val_steps=0,
            callbacks=[checkpoint_callback],
            # Without this, Lightning's default logger writes lightning_logs/
            # under the current working directory, leaking test artifacts
            # outside tmp_dir. This matches the Trainer setup used in test_qat.
            logger=False,
        )

        with EventStorage() as storage:
            task.storage = storage
            trainer.fit(task)

        # test building untrained model
        model = GeneralizedRCNNTask.build_model(cfg)
        self.assertTrue(model.training)

        last_ckpt = os.path.join(tmp_dir, "last.ckpt")

        # test loading regular weights
        with temp_defrost(cfg):
            cfg.MODEL.WEIGHTS = last_ckpt
            model = GeneralizedRCNNTask.build_model(cfg, eval_only=True)
            self.assertFalse(model.training)
            self.assertTrue(
                self._compare_state_dict(model.state_dict(),
                                         task.model.state_dict()))

        # test loading EMA weights
        with temp_defrost(cfg):
            cfg.MODEL.WEIGHTS = last_ckpt
            cfg.MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY = True
            model = GeneralizedRCNNTask.build_model(cfg, eval_only=True)
            self.assertFalse(model.training)
            self.assertTrue(
                self._compare_state_dict(model.state_dict(),
                                         task.ema_state.state_dict()))