コード例 #1
0
    def test_load_ema_weights(self, tmp_dir):
        """Verify that EMA state survives a checkpoint round trip.

        Fits the task for a single step with EMA enabled, reloads the task
        from ``last.ckpt`` and compares EMA state dicts, then applies the
        reloaded EMA state to the reloaded model and compares again.
        """
        cfg = self._get_cfg(tmp_dir)
        cfg.MODEL_EMA.ENABLED = True
        task = GeneralizedRCNNTask(cfg)

        # save_last=True is requested so that a "last.ckpt" can be loaded
        # below (assumes OUTPUT_DIR resolves to tmp_dir -- TODO confirm).
        trainer = pl.Trainer(
            max_steps=1,
            limit_train_batches=1,
            num_sanity_val_steps=0,
            callbacks=[
                ModelCheckpoint(dirpath=task.cfg.OUTPUT_DIR, save_last=True),
            ],
        )

        with EventStorage() as event_storage:
            task.storage = event_storage
            trainer.fit(task)

        # Restore a second task from the checkpoint; its EMA state must match.
        last_ckpt = os.path.join(tmp_dir, "last.ckpt")
        restored = GeneralizedRCNNTask.load_from_checkpoint(last_ckpt)
        ema_matches = self._compare_state_dict(
            task.ema_state.state_dict(),
            restored.ema_state.state_dict(),
        )
        self.assertTrue(ema_matches)

        # After applying the EMA state, the model weights must match it too.
        restored.ema_state.apply_to(restored.model)
        model_matches = self._compare_state_dict(
            task.ema_state.state_dict(),
            restored.model.state_dict(),
        )
        self.assertTrue(model_matches)
コード例 #2
0
    def test_load_ema_weights(self, tmp_dir):
        """Verify that EMA state is restored from ``last.ckpt``.

        The task is fitted with the trainer returned by ``_get_trainer``,
        reloaded from the checkpoint, and its EMA state (and the model after
        applying that state) is compared with the original EMA state.
        """
        cfg = self._get_cfg(tmp_dir)
        cfg.MODEL_EMA.ENABLED = True
        task = GeneralizedRCNNTask(cfg)
        trainer = self._get_trainer(tmp_dir)

        with EventStorage() as event_storage:
            task.storage = event_storage
            trainer.fit(task)

        # Reload the task from the checkpoint and compare the EMA states.
        last_ckpt = os.path.join(tmp_dir, "last.ckpt")
        restored = GeneralizedRCNNTask.load_from_checkpoint(last_ckpt)
        ema_matches = self._compare_state_dict(
            task.ema_state.state_dict(), restored.ema_state.state_dict()
        )
        self.assertTrue(ema_matches)

        # Copy the EMA state into the model, then compare model weights.
        restored.ema_state.apply_to(restored.model)
        model_matches = self._compare_state_dict(
            task.ema_state.state_dict(), restored.model.state_dict()
        )
        self.assertTrue(model_matches)
コード例 #3
0
    def test_load_from_checkpoint(self) -> None:
        """Verify that model weights survive a manual checkpoint round trip.

        Fits the task for a single step inside a temporary directory, saves
        ``test.ckpt`` through the trainer, reloads the task from it and
        compares the two model state dicts.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            task = GeneralizedRCNNTask(self._get_cfg(tmp_dir))
            from stl.lightning.callbacks.model_checkpoint import ModelCheckpoint

            trainer = pl.Trainer(
                max_steps=1,
                limit_train_batches=1,
                num_sanity_val_steps=0,
                checkpoint_callback=ModelCheckpoint(
                    directory=task.cfg.OUTPUT_DIR, has_user_data=False
                ),
            )

            with EventStorage() as event_storage:
                task.storage = event_storage
                trainer.fit(task)

                saved_path = os.path.join(tmp_dir, "test.ckpt")
                trainer.save_checkpoint(saved_path)
                self.assertTrue(os.path.exists(saved_path))

                # Reload from the checkpoint and compare model weights.
                restored = GeneralizedRCNNTask.load_from_checkpoint(saved_path)
                weights_match = self._compare_state_dict(
                    task.model.state_dict(), restored.model.state_dict()
                )
                self.assertTrue(weights_match)
コード例 #4
0
    def test_qat(self, tmp_dir):
        """Run one training step with QAT enabled on a quantizable test arch.

        Registers a test meta-architecture that defines ``prepare_for_quant``
        and ``prepare_for_quant_convert`` hooks for its ``avgpool`` submodule,
        trains for one step with the QAT callback, and then checks (1) which
        custom attributes remain on the prepared ``avgpool`` and (2) that a
        model built from ``last.ckpt`` with ``eval_only=True`` has a
        ``torch.fx.GraphModule`` as ``avgpool``.
        """
        @META_ARCH_REGISTRY.register()
        class QuantizableDetMetaArchForTest(mah.DetMetaArchForTest):
            # Passed to both prepare_qat_fx and convert_fx below; lists the
            # attribute name that the assertions later expect to be kept.
            custom_config_dict = {"preserved_attributes": ["preserved_attr"]}

            def __init__(self, cfg):
                super().__init__(cfg)
                # One attribute listed in custom_config_dict, one that is not.
                self.avgpool.preserved_attr = "foo"
                self.avgpool.not_preserved_attr = "bar"

            def prepare_for_quant(self, cfg):
                # Replace avgpool with the result of prepare_qat_fx and
                # return self.
                example_inputs = (torch.rand(1, 3, 3, 3), )
                self.avgpool = prepare_qat_fx(
                    self.avgpool,
                    {
                        "":
                        set_backend_and_create_qconfig(cfg,
                                                       is_train=self.training)
                    },
                    example_inputs,
                    self.custom_config_dict,
                )
                return self

            def prepare_for_quant_convert(self, cfg):
                # Replace avgpool with the result of convert_fx and
                # return self.
                self.avgpool = convert_fx(
                    self.avgpool,
                    convert_custom_config_dict=self.custom_config_dict)
                return self

        cfg = self._get_cfg(tmp_dir)
        # Select the meta-architecture registered above by its class name.
        cfg.MODEL.META_ARCHITECTURE = "QuantizableDetMetaArchForTest"
        cfg.QUANTIZATION.QAT.ENABLED = True
        task = GeneralizedRCNNTask(cfg)

        # save_last=True is requested so "last.ckpt" can be loaded below
        # (assumes OUTPUT_DIR resolves to tmp_dir -- TODO confirm).
        callbacks = [
            QuantizationAwareTraining.from_config(cfg),
            ModelCheckpoint(dirpath=task.cfg.OUTPUT_DIR, save_last=True),
        ]
        trainer = pl.Trainer(
            max_steps=1,
            limit_train_batches=1,
            num_sanity_val_steps=0,
            callbacks=callbacks,
            logger=False,
        )
        with EventStorage() as storage:
            task.storage = storage
            trainer.fit(task)
        # Only the attribute listed under "preserved_attributes" is expected
        # to remain on the prepared avgpool module.
        prepared_avgpool = task._prepared.model.avgpool
        self.assertEqual(prepared_avgpool.preserved_attr, "foo")
        self.assertFalse(hasattr(prepared_avgpool, "not_preserved_attr"))

        # Build an eval-only model from the saved checkpoint and check that
        # its avgpool is an FX GraphModule.
        with temp_defrost(cfg):
            cfg.MODEL.WEIGHTS = os.path.join(tmp_dir, "last.ckpt")
            model = GeneralizedRCNNTask.build_model(cfg, eval_only=True)
            self.assertTrue(isinstance(model.avgpool, torch.fx.GraphModule))
コード例 #5
0
    def test_load_from_checkpoint(self, tmp_dir) -> None:
        """Verify that model weights survive a manual checkpoint round trip.

        Fits the task with the trainer returned by ``_get_trainer``, saves
        ``test.ckpt`` via the trainer, reloads the task from it and compares
        the two model state dicts.
        """
        cfg = self._get_cfg(tmp_dir)
        task = GeneralizedRCNNTask(cfg)
        trainer = self._get_trainer(tmp_dir)

        with EventStorage() as event_storage:
            task.storage = event_storage
            trainer.fit(task)

            saved_path = os.path.join(tmp_dir, "test.ckpt")
            trainer.save_checkpoint(saved_path)
            self.assertTrue(os.path.exists(saved_path))

            # Reload from the checkpoint and compare model weights.
            restored = GeneralizedRCNNTask.load_from_checkpoint(saved_path)
            weights_match = self._compare_state_dict(
                task.model.state_dict(), restored.model.state_dict()
            )
            self.assertTrue(weights_match)
コード例 #6
0
    def test_train_ema(self, tmp_dir):
        """Check the EMA state after fitting against a hand-computed blend.

        With ``MODEL_EMA.DECAY = 0.7`` the expected state is
        ``0.7 * initial + 0.3 * trained`` for every state-dict entry.
        """
        cfg = self._get_cfg(tmp_dir)
        cfg.MODEL_EMA.ENABLED = True
        cfg.MODEL_EMA.DECAY = 0.7
        task = GeneralizedRCNNTask(cfg)

        # Snapshot the weights before training; updated in place below to
        # become the expected EMA state.
        expected_ema = deepcopy(task.model.state_dict())

        trainer = self._get_trainer(tmp_dir)
        with EventStorage() as event_storage:
            task.storage = event_storage
            trainer.fit(task)

        trained_state = task.model.state_dict()
        for name in trained_state:
            blended = expected_ema[name] * 0.7 + 0.3 * trained_state[name]
            expected_ema[name].copy_(blended)

        ema_matches = self._compare_state_dict(
            expected_ema, task.ema_state.state_dict()
        )
        self.assertTrue(ema_matches)
コード例 #7
0
def do_test(trainer: pl.Trainer, task: GeneralizedRCNNTask):
    """Runs ``trainer.test`` on the task inside a fresh ``EventStorage``.

    Args:
        trainer: PyTorch Lightning trainer.
        task: Lightning module instance.

    """
    with EventStorage() as storage:
        # Attach the storage to the task before the test run starts.
        task.storage = storage
        trainer.test(task)
コード例 #8
0
    def test_build_model(self, tmp_dir):
        """Exercise ``GeneralizedRCNNTask.build_model`` in three modes.

        After a fit with EMA enabled, checks an untrained build, an eval-only
        build loading the regular weights from ``last.ckpt``, and an
        eval-only build with ``USE_EMA_WEIGHTS_FOR_EVAL_ONLY`` set.
        """
        cfg = self._get_cfg(tmp_dir)
        cfg.MODEL_EMA.ENABLED = True
        task = GeneralizedRCNNTask(cfg)
        trainer = self._get_trainer(tmp_dir)

        with EventStorage() as event_storage:
            task.storage = event_storage
            trainer.fit(task)

        last_ckpt = os.path.join(tmp_dir, "last.ckpt")

        # A model built without eval_only is expected to be in training mode.
        untrained = GeneralizedRCNNTask.build_model(cfg)
        self.assertTrue(untrained.training)

        # Eval-only build with regular weights: must match the trained model.
        with temp_defrost(cfg):
            cfg.MODEL.WEIGHTS = last_ckpt
            eval_model = GeneralizedRCNNTask.build_model(cfg, eval_only=True)
            self.assertFalse(eval_model.training)
            weights_match = self._compare_state_dict(
                eval_model.state_dict(), task.model.state_dict()
            )
            self.assertTrue(weights_match)

        # Eval-only build with EMA weights: must match the task's EMA state.
        with temp_defrost(cfg):
            cfg.MODEL.WEIGHTS = last_ckpt
            cfg.MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY = True
            ema_model = GeneralizedRCNNTask.build_model(cfg, eval_only=True)
            self.assertFalse(ema_model.training)
            ema_match = self._compare_state_dict(
                ema_model.state_dict(), task.ema_state.state_dict()
            )
            self.assertTrue(ema_match)
コード例 #9
0
    def test_meta_arch_training_step(self, tmp_dir):
        """Fit a task whose meta-architecture defines its own training_step.

        The registered test architecture asserts that it receives a batch,
        an optimizer and a ``manual_backward`` argument, steps the optimizer
        and returns a fixed loss dict.
        """

        @META_ARCH_REGISTRY.register()
        class DetMetaArchForWithTrainingStep(mah.DetMetaArchForTest):
            def training_step(self, batch, batch_idx, opt, manual_backward):
                for required in (batch, opt, manual_backward):
                    assert required
                # Stepping the optimizer is what advances progress tracking:
                # it is reflected in the Trainer's global_step property,
                # which decides when to stop training when the loop bounds
                # are given as Trainer(max_steps=N).
                opt.step()
                return dict(total_loss=0.4)

        cfg = self._get_cfg(tmp_dir)
        cfg.MODEL.META_ARCHITECTURE = "DetMetaArchForWithTrainingStep"
        task = GeneralizedRCNNTask(cfg)
        trainer = self._get_trainer(tmp_dir)

        with EventStorage() as event_storage:
            task.storage = event_storage
            trainer.fit(task)
コード例 #10
0
    def test_load_from_checkpoint(self, tmp_dir) -> None:
        """Verify that model weights survive a manual checkpoint round trip.

        Fits the task for one step with an explicit ``ModelCheckpoint``,
        saves ``test.ckpt`` via the trainer, reloads the task from it and
        compares the two model state dicts.
        """
        cfg = self._get_cfg(tmp_dir)
        task = GeneralizedRCNNTask(cfg)

        trainer = pl.Trainer(
            max_steps=1,
            limit_train_batches=1,
            num_sanity_val_steps=0,
            checkpoint_callback=ModelCheckpoint(dirpath=task.cfg.OUTPUT_DIR),
        )

        with EventStorage() as event_storage:
            task.storage = event_storage
            trainer.fit(task)

            saved_path = os.path.join(tmp_dir, "test.ckpt")
            trainer.save_checkpoint(saved_path)
            self.assertTrue(os.path.exists(saved_path))

            # Reload from the checkpoint and compare model weights.
            restored = GeneralizedRCNNTask.load_from_checkpoint(saved_path)
            weights_match = self._compare_state_dict(
                task.model.state_dict(), restored.model.state_dict()
            )
            self.assertTrue(weights_match)
コード例 #11
0
    def test_train_ema(self):
        """Check the EMA state after fitting against a hand-computed blend.

        Runs inside a temporary directory. With ``MODEL_EMA.DECAY = 0.7`` the
        expected state is ``0.7 * initial + 0.3 * trained`` for every
        state-dict entry.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            cfg = self._get_cfg(tmp_dir)
            cfg.MODEL_EMA.ENABLED = True
            cfg.MODEL_EMA.DECAY = 0.7
            task = GeneralizedRCNNTask(cfg)

            # Snapshot the weights before training; updated in place below
            # to become the expected EMA state.
            expected_ema = deepcopy(task.model.state_dict())

            trainer_kwargs = dict(
                max_steps=1,
                limit_train_batches=1,
                num_sanity_val_steps=0,
            )
            trainer = pl.Trainer(**trainer_kwargs)

            with EventStorage() as event_storage:
                task.storage = event_storage
                trainer.fit(task)

            trained_state = task.model.state_dict()
            for name in trained_state:
                blended = expected_ema[name] * 0.7 + 0.3 * trained_state[name]
                expected_ema[name].copy_(blended)

            ema_matches = self._compare_state_dict(
                expected_ema, task.ema_state.state_dict()
            )
            self.assertTrue(ema_matches)
コード例 #12
0
def do_train(cfg: CfgNode, trainer: pl.Trainer,
             task: GeneralizedRCNNTask) -> Dict[str, str]:
    """Fits the task, saves a final checkpoint and dumps the trained config.

    Args:
        cfg: The normalized ConfigNode for this D2Go Task.
        trainer: PyTorch Lightning trainer.
        task: Lightning module instance.

    Returns:
        The value returned by ``dump_trained_model_configs`` for the
        ``"model_final"`` entry (a map of model name to trained model
        config path).
    """
    with EventStorage() as event_storage:
        task.storage = event_storage
        trainer.fit(task)

        ckpt_path = os.path.join(cfg.OUTPUT_DIR, FINAL_MODEL_CKPT)
        trainer.save_checkpoint(ckpt_path)  # for validation monitor

        # Point a clone of the config at the final checkpoint; the caller's
        # cfg object itself is not modified here.
        cfg_with_weights = cfg.clone()
        with temp_defrost(cfg_with_weights):
            cfg_with_weights.MODEL.WEIGHTS = ckpt_path

        model_configs = dump_trained_model_configs(
            cfg.OUTPUT_DIR,
            {"model_final": cfg_with_weights},
        )
    return model_configs
コード例 #13
0
    def test_build_model(self, tmp_dir):
        """Exercise ``GeneralizedRCNNTask.build_model`` in three modes.

        After a one-step fit with EMA enabled and ``save_last=True``, checks
        an untrained build, an eval-only build loading the regular weights
        from ``last.ckpt``, and an eval-only build with
        ``USE_EMA_WEIGHTS_FOR_EVAL_ONLY`` set.
        """
        cfg = self._get_cfg(tmp_dir)
        cfg.MODEL_EMA.ENABLED = True
        task = GeneralizedRCNNTask(cfg)

        trainer = pl.Trainer(
            max_steps=1,
            limit_train_batches=1,
            num_sanity_val_steps=0,
            callbacks=[
                ModelCheckpoint(dirpath=task.cfg.OUTPUT_DIR, save_last=True),
            ],
        )

        with EventStorage() as event_storage:
            task.storage = event_storage
            trainer.fit(task)

        last_ckpt = os.path.join(tmp_dir, "last.ckpt")

        # A model built without eval_only is expected to be in training mode.
        untrained = GeneralizedRCNNTask.build_model(cfg)
        self.assertTrue(untrained.training)

        # Eval-only build with regular weights: must match the trained model.
        with temp_defrost(cfg):
            cfg.MODEL.WEIGHTS = last_ckpt
            eval_model = GeneralizedRCNNTask.build_model(cfg, eval_only=True)
            self.assertFalse(eval_model.training)
            weights_match = self._compare_state_dict(
                eval_model.state_dict(), task.model.state_dict()
            )
            self.assertTrue(weights_match)

        # Eval-only build with EMA weights: must match the task's EMA state.
        with temp_defrost(cfg):
            cfg.MODEL.WEIGHTS = last_ckpt
            cfg.MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY = True
            ema_model = GeneralizedRCNNTask.build_model(cfg, eval_only=True)
            self.assertFalse(ema_model.training)
            ema_match = self._compare_state_dict(
                ema_model.state_dict(), task.ema_state.state_dict()
            )
            self.assertTrue(ema_match)