def test_load_ema_weights(self, tmp_dir):
    """Check that model-EMA state survives a checkpoint round trip.

    Trains a ``GeneralizedRCNNTask`` with ``MODEL_EMA.ENABLED`` for a single
    step while a ``ModelCheckpoint(save_last=True)`` writes ``last.ckpt``.
    It then reloads the task from that file and asserts that:

    * the reloaded task's EMA state equals the trained task's EMA state, and
    * after applying the reloaded EMA state to the reloaded model, the model
      weights equal the trained task's EMA state.

    Args:
        tmp_dir: Scratch directory handed to ``self._get_cfg``. Presumably
            supplied by a decorator outside this view (TODO confirm).
    """
    cfg = self._get_cfg(tmp_dir)
    cfg.MODEL_EMA.ENABLED = True
    task = GeneralizedRCNNTask(cfg)
    # Reload from the directory the callback actually writes to, instead of
    # assuming cfg.OUTPUT_DIR happens to equal tmp_dir.
    ckpt_dir = task.cfg.OUTPUT_DIR
    checkpoint_callback = ModelCheckpoint(dirpath=ckpt_dir, save_last=True)
    trainer = pl.Trainer(
        max_steps=1,
        limit_train_batches=1,
        num_sanity_val_steps=0,
        callbacks=[checkpoint_callback],
    )
    with EventStorage() as storage:
        task.storage = storage
        trainer.fit(task)

    # load EMA weights from checkpoint (save_last=True names it "last.ckpt")
    task2 = GeneralizedRCNNTask.load_from_checkpoint(
        os.path.join(ckpt_dir, "last.ckpt")
    )
    self.assertTrue(
        self._compare_state_dict(
            task.ema_state.state_dict(), task2.ema_state.state_dict()
        )
    )

    # apply EMA weights to model
    task2.ema_state.apply_to(task2.model)
    self.assertTrue(
        self._compare_state_dict(
            task.ema_state.state_dict(), task2.model.state_dict()
        )
    )
def test_load_ema_weights(self, tmp_dir):
    """Round-trip EMA weights through the ``last.ckpt`` written during fit.

    Builds an EMA-enabled ``GeneralizedRCNNTask``, fits it with the trainer
    from ``self._get_trainer``, and restores a second task from
    ``<tmp_dir>/last.ckpt``. The restored EMA state must match the original
    task's EMA state. After the restored EMA state is applied to the restored
    model, that model's weights must also match the original task's EMA state.

    Args:
        tmp_dir: Scratch directory used for the config, the trainer and the
            checkpoint path. Presumably supplied by a decorator outside this
            view (TODO confirm).
    """
    config = self._get_cfg(tmp_dir)
    config.MODEL_EMA.ENABLED = True
    trained = GeneralizedRCNNTask(config)
    trainer = self._get_trainer(tmp_dir)
    with EventStorage() as event_storage:
        trained.storage = event_storage
        trainer.fit(trained)

    # Restore a second task from the checkpoint produced by the fit above.
    last_ckpt = os.path.join(tmp_dir, "last.ckpt")
    restored = GeneralizedRCNNTask.load_from_checkpoint(last_ckpt)

    # EMA state must be identical after the save/load cycle.
    ema_matches = self._compare_state_dict(
        trained.ema_state.state_dict(), restored.ema_state.state_dict()
    )
    self.assertTrue(ema_matches)

    # Pushing the restored EMA weights into the model must reproduce them.
    restored.ema_state.apply_to(restored.model)
    model_matches = self._compare_state_dict(
        trained.ema_state.state_dict(), restored.model.state_dict()
    )
    self.assertTrue(model_matches)
def test_load_from_checkpoint(self) -> None:
    """Check that model weights survive an explicit ``save_checkpoint`` round trip.

    Inside a throw-away ``tempfile.TemporaryDirectory``, trains a
    ``GeneralizedRCNNTask`` for one step and saves ``test.ckpt`` with
    ``trainer.save_checkpoint``. It asserts the file exists, reloads the task
    from it, and asserts the reloaded model's ``state_dict`` equals the
    trained one.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        task = GeneralizedRCNNTask(self._get_cfg(tmp_dir))
        from stl.lightning.callbacks.model_checkpoint import ModelCheckpoint

        checkpoint_callback = ModelCheckpoint(
            directory=task.cfg.OUTPUT_DIR, has_user_data=False
        )
        # Register the callback via ``callbacks``. Passing a callback instance
        # to the ``checkpoint_callback`` Trainer argument was deprecated in
        # PyTorch Lightning 1.1 and removed in 1.3.
        trainer = pl.Trainer(
            max_steps=1,
            limit_train_batches=1,
            num_sanity_val_steps=0,
            callbacks=[checkpoint_callback],
        )
        with EventStorage() as storage:
            task.storage = storage
            trainer.fit(task)
            ckpt_path = os.path.join(tmp_dir, "test.ckpt")
            trainer.save_checkpoint(ckpt_path)
            self.assertTrue(os.path.exists(ckpt_path))

            # load model weights from checkpoint
            task2 = GeneralizedRCNNTask.load_from_checkpoint(ckpt_path)
            self.assertTrue(
                self._compare_state_dict(
                    task.model.state_dict(), task2.model.state_dict()
                )
            )
def test_load_from_checkpoint(self, tmp_dir) -> None:
    """Save a checkpoint explicitly after fitting and reload the model from it.

    Fits a ``GeneralizedRCNNTask`` with the trainer from
    ``self._get_trainer`` and writes ``<tmp_dir>/test.ckpt`` via
    ``trainer.save_checkpoint``. It asserts the file exists and that a task
    restored from it has the same model ``state_dict``.

    Args:
        tmp_dir: Scratch directory for the config, trainer and checkpoint.
            Presumably supplied by a decorator outside this view (TODO confirm).
    """
    original = GeneralizedRCNNTask(self._get_cfg(tmp_dir))
    trainer = self._get_trainer(tmp_dir)
    with EventStorage() as event_storage:
        original.storage = event_storage
        trainer.fit(original)

    saved_path = os.path.join(tmp_dir, "test.ckpt")
    trainer.save_checkpoint(saved_path)
    self.assertTrue(os.path.exists(saved_path))

    # Reload and verify the model weights are unchanged.
    reloaded = GeneralizedRCNNTask.load_from_checkpoint(saved_path)
    weights_match = self._compare_state_dict(
        original.model.state_dict(), reloaded.model.state_dict()
    )
    self.assertTrue(weights_match)
def test_load_from_checkpoint(self, tmp_dir) -> None:
    """Check that model weights survive an explicit ``save_checkpoint`` round trip.

    Trains a ``GeneralizedRCNNTask`` for one step with a ``ModelCheckpoint``
    rooted at ``cfg.OUTPUT_DIR``, then saves ``<tmp_dir>/test.ckpt`` via
    ``trainer.save_checkpoint``. It asserts the file exists and that the
    reloaded task's model ``state_dict`` equals the trained one.

    Args:
        tmp_dir: Scratch directory for the config and checkpoint. Presumably
            supplied by a decorator outside this view (TODO confirm).
    """
    task = GeneralizedRCNNTask(self._get_cfg(tmp_dir))
    checkpoint_callback = ModelCheckpoint(dirpath=task.cfg.OUTPUT_DIR)
    # Register the callback via ``callbacks``. Lightning versions that accept
    # ``dirpath`` deprecate or reject a callback instance passed as the
    # ``checkpoint_callback`` Trainer argument.
    trainer = pl.Trainer(
        max_steps=1,
        limit_train_batches=1,
        num_sanity_val_steps=0,
        callbacks=[checkpoint_callback],
    )
    with EventStorage() as storage:
        task.storage = storage
        trainer.fit(task)
        ckpt_path = os.path.join(tmp_dir, "test.ckpt")
        trainer.save_checkpoint(ckpt_path)
        self.assertTrue(os.path.exists(ckpt_path))

        # load model weights from checkpoint
        task2 = GeneralizedRCNNTask.load_from_checkpoint(ckpt_path)
        self.assertTrue(
            self._compare_state_dict(
                task.model.state_dict(), task2.model.state_dict()
            )
        )