def test_load_ema_weights(self, tmp_dir):
        """Train one step with EMA enabled and check EMA state survives reload.

        A ``ModelCheckpoint`` with ``save_last=True`` is attached so that a
        ``last.ckpt`` file is written into ``task.cfg.OUTPUT_DIR``. A second
        task is restored from that file and the test asserts that:

        * its EMA state dict matches the trained task's EMA state dict, and
        * after applying the restored EMA weights to the restored model, the
          model's state dict matches the trained task's EMA state dict.
        """
        cfg = self._get_cfg(tmp_dir)
        cfg.MODEL_EMA.ENABLED = True
        task = GeneralizedRCNNTask(cfg)
        checkpoint_callback = ModelCheckpoint(
            dirpath=task.cfg.OUTPUT_DIR, save_last=True
        )

        trainer = pl.Trainer(
            max_steps=1,
            limit_train_batches=1,
            num_sanity_val_steps=0,
            callbacks=[checkpoint_callback],
        )

        with EventStorage() as storage:
            task.storage = storage
            trainer.fit(task)

        # Load from the directory the callback was told to write to
        # (``dirpath`` above) instead of assuming OUTPUT_DIR == tmp_dir.
        last_ckpt_path = os.path.join(task.cfg.OUTPUT_DIR, "last.ckpt")
        self.assertTrue(os.path.exists(last_ckpt_path))

        # load EMA weights from checkpoint
        task2 = GeneralizedRCNNTask.load_from_checkpoint(last_ckpt_path)
        self.assertTrue(
            self._compare_state_dict(
                task.ema_state.state_dict(), task2.ema_state.state_dict()
            )
        )

        # apply EMA weights to model
        task2.ema_state.apply_to(task2.model)
        self.assertTrue(
            self._compare_state_dict(
                task.ema_state.state_dict(), task2.model.state_dict()
            )
        )
    def test_load_ema_weights(self, tmp_dir):
        """Verify EMA state is persisted in ``last.ckpt`` and can be re-applied.

        The task is trained with ``MODEL_EMA.ENABLED`` set, using the trainer
        returned by ``self._get_trainer(tmp_dir)``. A fresh task is then
        restored from ``<tmp_dir>/last.ckpt``. Its EMA state must equal the
        trained EMA state, both directly and after being applied onto the
        restored model.
        """
        config = self._get_cfg(tmp_dir)
        config.MODEL_EMA.ENABLED = True
        trained = GeneralizedRCNNTask(config)
        trainer = self._get_trainer(tmp_dir)

        with EventStorage() as event_storage:
            trained.storage = event_storage
            trainer.fit(trained)

        # Rebuild a task purely from the last checkpoint on disk.
        last_ckpt = os.path.join(tmp_dir, "last.ckpt")
        restored = GeneralizedRCNNTask.load_from_checkpoint(last_ckpt)
        ema_matches = self._compare_state_dict(
            trained.ema_state.state_dict(), restored.ema_state.state_dict()
        )
        self.assertTrue(ema_matches)

        # Push the restored EMA weights into the restored model and compare.
        restored.ema_state.apply_to(restored.model)
        model_matches = self._compare_state_dict(
            trained.ema_state.state_dict(), restored.model.state_dict()
        )
        self.assertTrue(model_matches)
    def test_load_from_checkpoint(self) -> None:
        """Round-trip model weights through an explicit ``save_checkpoint``.

        Trains for a single step inside a temporary directory and saves a
        checkpoint to ``<tmp_dir>/test.ckpt``. It then reloads that checkpoint
        with ``GeneralizedRCNNTask.load_from_checkpoint`` and asserts the
        restored model's state dict matches the trained model's.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            task = GeneralizedRCNNTask(self._get_cfg(tmp_dir))
            from stl.lightning.callbacks.model_checkpoint import ModelCheckpoint

            checkpoint_callback = ModelCheckpoint(
                directory=task.cfg.OUTPUT_DIR, has_user_data=False
            )
            # Register the callback via ``callbacks``: passing a callback
            # instance through ``Trainer(checkpoint_callback=...)`` is
            # deprecated in PyTorch Lightning and unsupported in later releases.
            trainer = pl.Trainer(
                max_steps=1,
                limit_train_batches=1,
                num_sanity_val_steps=0,
                callbacks=[checkpoint_callback],
            )
            with EventStorage() as storage:
                task.storage = storage
                trainer.fit(task)
                ckpt_path = os.path.join(tmp_dir, "test.ckpt")
                trainer.save_checkpoint(ckpt_path)
                self.assertTrue(os.path.exists(ckpt_path))

                # load model weights from checkpoint
                task2 = GeneralizedRCNNTask.load_from_checkpoint(ckpt_path)
                self.assertTrue(
                    self._compare_state_dict(
                        task.model.state_dict(), task2.model.state_dict()
                    )
                )
Exemple #4
0
    def test_load_from_checkpoint(self, tmp_dir) -> None:
        """Check that model weights round-trip through ``save_checkpoint``.

        Fits a task with the trainer from ``self._get_trainer(tmp_dir)`` and
        writes an explicit checkpoint to ``<tmp_dir>/test.ckpt``. It reloads
        that file with ``GeneralizedRCNNTask.load_from_checkpoint`` and
        asserts the restored model's state dict equals the trained one.
        """
        original = GeneralizedRCNNTask(self._get_cfg(tmp_dir))
        trainer = self._get_trainer(tmp_dir)
        target = os.path.join(tmp_dir, "test.ckpt")

        with EventStorage() as event_storage:
            original.storage = event_storage
            trainer.fit(original)
            trainer.save_checkpoint(target)
            self.assertTrue(os.path.exists(target))

            # Rebuild a fresh task purely from the saved file.
            reloaded = GeneralizedRCNNTask.load_from_checkpoint(target)
            weights_match = self._compare_state_dict(
                original.model.state_dict(), reloaded.model.state_dict()
            )
            self.assertTrue(weights_match)
    def test_load_from_checkpoint(self, tmp_dir) -> None:
        """Round-trip model weights through an explicit ``save_checkpoint``.

        Trains for a single step with a ``ModelCheckpoint`` writing to
        ``task.cfg.OUTPUT_DIR`` and saves a checkpoint to
        ``<tmp_dir>/test.ckpt``. It reloads that file with
        ``GeneralizedRCNNTask.load_from_checkpoint`` and asserts the restored
        model's state dict matches the trained model's.
        """
        task = GeneralizedRCNNTask(self._get_cfg(tmp_dir))

        checkpoint_callback = ModelCheckpoint(dirpath=task.cfg.OUTPUT_DIR)
        # Register the callback via ``callbacks``: passing a callback
        # instance through ``Trainer(checkpoint_callback=...)`` is deprecated
        # in PyTorch Lightning and unsupported in later releases.
        trainer = pl.Trainer(
            max_steps=1,
            limit_train_batches=1,
            num_sanity_val_steps=0,
            callbacks=[checkpoint_callback],
        )
        with EventStorage() as storage:
            task.storage = storage
            trainer.fit(task)
            ckpt_path = os.path.join(tmp_dir, "test.ckpt")
            trainer.save_checkpoint(ckpt_path)
            self.assertTrue(os.path.exists(ckpt_path))

            # load model weights from checkpoint
            task2 = GeneralizedRCNNTask.load_from_checkpoint(ckpt_path)
            self.assertTrue(
                self._compare_state_dict(
                    task.model.state_dict(), task2.model.state_dict()
                )
            )