def test_load_ema_weights(self, tmp_dir):
    cfg = self._get_cfg(tmp_dir)
    cfg.MODEL_EMA.ENABLED = True
    task = GeneralizedRCNNTask(cfg)
    checkpoint_callback = ModelCheckpoint(dirpath=task.cfg.OUTPUT_DIR, save_last=True)
    trainer = pl.Trainer(
        max_steps=1,
        limit_train_batches=1,
        num_sanity_val_steps=0,
        callbacks=[checkpoint_callback],
    )
    with EventStorage() as storage:
        task.storage = storage
        trainer.fit(task)

    # load EMA weights from checkpoint
    task2 = GeneralizedRCNNTask.load_from_checkpoint(os.path.join(tmp_dir, "last.ckpt"))
    self.assertTrue(
        self._compare_state_dict(
            task.ema_state.state_dict(), task2.ema_state.state_dict()
        )
    )

    # apply EMA weights to model
    task2.ema_state.apply_to(task2.model)
    self.assertTrue(
        self._compare_state_dict(task.ema_state.state_dict(), task2.model.state_dict())
    )
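# NOTE: the tests in this file call `self._compare_state_dict`, which is not
# shown here. Below is a minimal sketch of such a helper, assuming it compares
# key sets and then each tensor with torch.allclose; the real helper on the
# test case class may differ in tolerances and edge-case handling.
def _compare_state_dict(self, sd1, sd2, rtol=1e-5, atol=1e-8):
    # keys views compare as sets, so ordering differences are tolerated
    if sd1.keys() != sd2.keys():
        return False
    return all(torch.allclose(sd1[k], sd2[k], rtol=rtol, atol=atol) for k in sd1)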
def test_build_model(self, tmp_dir):
    cfg = self._get_cfg(tmp_dir)
    cfg.MODEL_EMA.ENABLED = True
    task = GeneralizedRCNNTask(cfg)
    trainer = self._get_trainer(tmp_dir)
    with EventStorage() as storage:
        task.storage = storage
        trainer.fit(task)

    # test building untrained model
    model = GeneralizedRCNNTask.build_model(cfg)
    self.assertTrue(model.training)

    # test loading regular weights
    with temp_defrost(cfg):
        cfg.MODEL.WEIGHTS = os.path.join(tmp_dir, "last.ckpt")
        model = GeneralizedRCNNTask.build_model(cfg, eval_only=True)
        self.assertFalse(model.training)
        self.assertTrue(
            self._compare_state_dict(model.state_dict(), task.model.state_dict())
        )

    # test loading EMA weights
    with temp_defrost(cfg):
        cfg.MODEL.WEIGHTS = os.path.join(tmp_dir, "last.ckpt")
        cfg.MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY = True
        model = GeneralizedRCNNTask.build_model(cfg, eval_only=True)
        self.assertFalse(model.training)
        self.assertTrue(
            self._compare_state_dict(model.state_dict(), task.ema_state.state_dict())
        )
def test_load_ema_weights(self, tmp_dir):
    cfg = self._get_cfg(tmp_dir)
    cfg.MODEL_EMA.ENABLED = True
    task = GeneralizedRCNNTask(cfg)
    trainer = self._get_trainer(tmp_dir)
    with EventStorage() as storage:
        task.storage = storage
        trainer.fit(task)

    # load EMA weights from checkpoint
    task2 = GeneralizedRCNNTask.load_from_checkpoint(
        os.path.join(tmp_dir, "last.ckpt")
    )
    self.assertTrue(
        self._compare_state_dict(
            task.ema_state.state_dict(), task2.ema_state.state_dict()
        )
    )

    # apply EMA weights to model
    task2.ema_state.apply_to(task2.model)
    self.assertTrue(
        self._compare_state_dict(
            task.ema_state.state_dict(), task2.model.state_dict()
        )
    )
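# NOTE: several tests use a `self._get_trainer(tmp_dir)` fixture that is not
# shown. A sketch of what it plausibly returns, inferred from how the tests use
# it (a single training step that leaves `last.ckpt` under `tmp_dir`); the
# exact flags are assumptions.
def _get_trainer(self, output_dir: str) -> pl.Trainer:
    checkpoint_callback = ModelCheckpoint(dirpath=output_dir, save_last=True)
    return pl.Trainer(
        max_steps=1,
        limit_train_batches=1,
        num_sanity_val_steps=0,
        callbacks=[checkpoint_callback],
        logger=False,
    )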
def test_load_from_checkpoint(self) -> None:
    with tempfile.TemporaryDirectory() as tmp_dir:
        task = GeneralizedRCNNTask(self._get_cfg(tmp_dir))

        from stl.lightning.callbacks.model_checkpoint import ModelCheckpoint

        checkpoint_callback = ModelCheckpoint(
            directory=task.cfg.OUTPUT_DIR, has_user_data=False
        )
        params = {
            "max_steps": 1,
            "limit_train_batches": 1,
            "num_sanity_val_steps": 0,
            "checkpoint_callback": checkpoint_callback,
        }
        trainer = pl.Trainer(**params)
        with EventStorage() as storage:
            task.storage = storage
            trainer.fit(task)
            ckpt_path = os.path.join(tmp_dir, "test.ckpt")
            trainer.save_checkpoint(ckpt_path)
            self.assertTrue(os.path.exists(ckpt_path))

            # load model weights from checkpoint
            task2 = GeneralizedRCNNTask.load_from_checkpoint(ckpt_path)
            self.assertTrue(
                self._compare_state_dict(
                    task.model.state_dict(), task2.model.state_dict()
                )
            )
def test_qat(self, tmp_dir):
    @META_ARCH_REGISTRY.register()
    class QuantizableDetMetaArchForTest(mah.DetMetaArchForTest):
        custom_config_dict = {"preserved_attributes": ["preserved_attr"]}

        def __init__(self, cfg):
            super().__init__(cfg)
            self.avgpool.preserved_attr = "foo"
            self.avgpool.not_preserved_attr = "bar"

        def prepare_for_quant(self, cfg):
            example_inputs = (torch.rand(1, 3, 3, 3),)
            self.avgpool = prepare_qat_fx(
                self.avgpool,
                {"": set_backend_and_create_qconfig(cfg, is_train=self.training)},
                example_inputs,
                self.custom_config_dict,
            )
            return self

        def prepare_for_quant_convert(self, cfg):
            self.avgpool = convert_fx(
                self.avgpool, convert_custom_config_dict=self.custom_config_dict
            )
            return self

    cfg = self._get_cfg(tmp_dir)
    cfg.MODEL.META_ARCHITECTURE = "QuantizableDetMetaArchForTest"
    cfg.QUANTIZATION.QAT.ENABLED = True
    task = GeneralizedRCNNTask(cfg)

    callbacks = [
        QuantizationAwareTraining.from_config(cfg),
        ModelCheckpoint(dirpath=task.cfg.OUTPUT_DIR, save_last=True),
    ]
    trainer = pl.Trainer(
        max_steps=1,
        limit_train_batches=1,
        num_sanity_val_steps=0,
        callbacks=callbacks,
        logger=False,
    )
    with EventStorage() as storage:
        task.storage = storage
        trainer.fit(task)
    prepared_avgpool = task._prepared.model.avgpool
    self.assertEqual(prepared_avgpool.preserved_attr, "foo")
    self.assertFalse(hasattr(prepared_avgpool, "not_preserved_attr"))

    with temp_defrost(cfg):
        cfg.MODEL.WEIGHTS = os.path.join(tmp_dir, "last.ckpt")
        model = GeneralizedRCNNTask.build_model(cfg, eval_only=True)
        self.assertTrue(isinstance(model.avgpool, torch.fx.GraphModule))
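# NOTE: for reference, a self-contained sketch of the FX graph mode QAT round
# trip the test above exercises (prepare_qat_fx -> forward passes -> convert_fx),
# using a stock torch.ao qconfig mapping instead of d2go's
# set_backend_and_create_qconfig; the tiny model and backend choice are
# assumptions for illustration only.
from torch.ao.quantization import get_default_qat_qconfig_mapping
from torch.ao.quantization.quantize_fx import convert_fx, prepare_qat_fx


def qat_round_trip_example():
    # prepare_qat_fx requires the model to be in training mode
    model = torch.nn.Sequential(torch.nn.Conv2d(3, 4, 1), torch.nn.ReLU()).train()
    example_inputs = (torch.rand(1, 3, 3, 3),)
    qconfig_mapping = get_default_qat_qconfig_mapping("fbgemm")
    prepared = prepare_qat_fx(model, qconfig_mapping, example_inputs)
    prepared(*example_inputs)  # forward passes let fake-quant observers record stats
    quantized = convert_fx(prepared.eval())  # a GraphModule with quantized ops
    return quantized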
def do_test(trainer: pl.Trainer, task: GeneralizedRCNNTask):
    """Runs the evaluation with a pre-trained model.

    Args:
        trainer: PyTorch Lightning trainer.
        task: Lightning module instance.
    """
    with EventStorage() as storage:
        task.storage = storage
        trainer.test(task)
def test_load_from_checkpoint(self, tmp_dir) -> None:
    task = GeneralizedRCNNTask(self._get_cfg(tmp_dir))
    trainer = self._get_trainer(tmp_dir)
    with EventStorage() as storage:
        task.storage = storage
        trainer.fit(task)
        ckpt_path = os.path.join(tmp_dir, "test.ckpt")
        trainer.save_checkpoint(ckpt_path)
        self.assertTrue(os.path.exists(ckpt_path))

        # load model weights from checkpoint
        task2 = GeneralizedRCNNTask.load_from_checkpoint(ckpt_path)
        self.assertTrue(
            self._compare_state_dict(
                task.model.state_dict(), task2.model.state_dict()
            )
        )
def test_train_ema(self, tmp_dir):
    cfg = self._get_cfg(tmp_dir)
    cfg.MODEL_EMA.ENABLED = True
    cfg.MODEL_EMA.DECAY = 0.7
    task = GeneralizedRCNNTask(cfg)
    init_state = deepcopy(task.model.state_dict())

    trainer = self._get_trainer(tmp_dir)
    with EventStorage() as storage:
        task.storage = storage
        trainer.fit(task)

    for k, v in task.model.state_dict().items():
        init_state[k].copy_(init_state[k] * 0.7 + 0.3 * v)
    self.assertTrue(self._compare_state_dict(init_state, task.ema_state.state_dict()))
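# NOTE: the loop above encodes the EMA update rule for a single training step,
#     ema_param = decay * ema_param + (1 - decay) * model_param
# with the EMA state initialized from the pre-training weights (decay = 0.7, so
# the blend is 0.7 / 0.3). A plain-Python sketch of that update, assuming the
# same tensor-to-tensor layout as a state dict; d2go's EMAState may differ in
# details:
def ema_update(ema_state_dict, model_state_dict, decay=0.7):
    with torch.no_grad():
        for k, v in model_state_dict.items():
            ema_state_dict[k].copy_(decay * ema_state_dict[k] + (1.0 - decay) * v)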
def test_build_model(self, tmp_dir):
    cfg = self._get_cfg(tmp_dir)
    cfg.MODEL_EMA.ENABLED = True
    task = GeneralizedRCNNTask(cfg)
    checkpoint_callback = ModelCheckpoint(dirpath=task.cfg.OUTPUT_DIR, save_last=True)
    trainer = pl.Trainer(
        max_steps=1,
        limit_train_batches=1,
        num_sanity_val_steps=0,
        callbacks=[checkpoint_callback],
    )
    with EventStorage() as storage:
        task.storage = storage
        trainer.fit(task)

    # test building untrained model
    model = GeneralizedRCNNTask.build_model(cfg)
    self.assertTrue(model.training)

    # test loading regular weights
    with temp_defrost(cfg):
        cfg.MODEL.WEIGHTS = os.path.join(tmp_dir, "last.ckpt")
        model = GeneralizedRCNNTask.build_model(cfg, eval_only=True)
        self.assertFalse(model.training)
        self.assertTrue(
            self._compare_state_dict(model.state_dict(), task.model.state_dict())
        )

    # test loading EMA weights
    with temp_defrost(cfg):
        cfg.MODEL.WEIGHTS = os.path.join(tmp_dir, "last.ckpt")
        cfg.MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY = True
        model = GeneralizedRCNNTask.build_model(cfg, eval_only=True)
        self.assertFalse(model.training)
        self.assertTrue(
            self._compare_state_dict(model.state_dict(), task.ema_state.state_dict())
        )
def test_meta_arch_training_step(self, tmp_dir):
    @META_ARCH_REGISTRY.register()
    class DetMetaArchForWithTrainingStep(mah.DetMetaArchForTest):
        def training_step(self, batch, batch_idx, opt, manual_backward):
            assert batch
            assert opt
            assert manual_backward
            # We step the optimizer for progress tracking to occur.
            # This is reflected in the Trainer's global_step property,
            # which is used to determine when to stop training
            # when specifying the loop bounds with Trainer(max_steps=N).
            opt.step()
            return {"total_loss": 0.4}

    cfg = self._get_cfg(tmp_dir)
    cfg.MODEL.META_ARCHITECTURE = "DetMetaArchForWithTrainingStep"
    task = GeneralizedRCNNTask(cfg)
    trainer = self._get_trainer(tmp_dir)
    with EventStorage() as storage:
        task.storage = storage
        trainer.fit(task)
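# NOTE: a minimal sketch of how a LightningModule can delegate to a meta-arch
# `training_step` hook with the signature used above. GeneralizedRCNNTask's
# real dispatch logic is more involved; this outline and the DelegatingTask
# name are assumptions for illustration.
class DelegatingTask(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model
        # the meta-arch drives backward/step itself, so disable Lightning's
        # automatic optimization
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        return self.model.training_step(batch, batch_idx, opt, self.manual_backward)

    def configure_optimizers(self):
        return torch.optim.SGD(self.model.parameters(), lr=0.1)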
def test_load_from_checkpoint(self, tmp_dir) -> None:
    task = GeneralizedRCNNTask(self._get_cfg(tmp_dir))
    checkpoint_callback = ModelCheckpoint(dirpath=task.cfg.OUTPUT_DIR)
    params = {
        "max_steps": 1,
        "limit_train_batches": 1,
        "num_sanity_val_steps": 0,
        # recent Lightning versions take callback instances via `callbacks`;
        # the old `checkpoint_callback` argument no longer accepts one
        "callbacks": [checkpoint_callback],
    }
    trainer = pl.Trainer(**params)
    with EventStorage() as storage:
        task.storage = storage
        trainer.fit(task)
        ckpt_path = os.path.join(tmp_dir, "test.ckpt")
        trainer.save_checkpoint(ckpt_path)
        self.assertTrue(os.path.exists(ckpt_path))

        # load model weights from checkpoint
        task2 = GeneralizedRCNNTask.load_from_checkpoint(ckpt_path)
        self.assertTrue(
            self._compare_state_dict(
                task.model.state_dict(), task2.model.state_dict()
            )
        )
def test_train_ema(self):
    with tempfile.TemporaryDirectory() as tmp_dir:
        cfg = self._get_cfg(tmp_dir)
        cfg.MODEL_EMA.ENABLED = True
        cfg.MODEL_EMA.DECAY = 0.7
        task = GeneralizedRCNNTask(cfg)
        init_state = deepcopy(task.model.state_dict())
        trainer = pl.Trainer(
            max_steps=1,
            limit_train_batches=1,
            num_sanity_val_steps=0,
        )
        with EventStorage() as storage:
            task.storage = storage
            trainer.fit(task)

        for k, v in task.model.state_dict().items():
            init_state[k].copy_(init_state[k] * 0.7 + 0.3 * v)
        self.assertTrue(
            self._compare_state_dict(init_state, task.ema_state.state_dict())
        )
def do_train(
    cfg: CfgNode, trainer: pl.Trainer, task: GeneralizedRCNNTask
) -> Dict[str, str]:
    """Runs the training loop with given trainer and task.

    Args:
        cfg: The normalized ConfigNode for this D2Go Task.
        trainer: PyTorch Lightning trainer.
        task: Lightning module instance.

    Returns:
        A map of model name to trained model config path.
    """
    with EventStorage() as storage:
        task.storage = storage
        trainer.fit(task)
        final_ckpt = os.path.join(cfg.OUTPUT_DIR, FINAL_MODEL_CKPT)
        trainer.save_checkpoint(final_ckpt)  # for validation monitor

        trained_cfg = cfg.clone()
        with temp_defrost(trained_cfg):
            trained_cfg.MODEL.WEIGHTS = final_ckpt
        model_configs = dump_trained_model_configs(
            cfg.OUTPUT_DIR, {"model_final": trained_cfg}
        )
    return model_configs
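# NOTE: a hypothetical end-to-end use of do_train/do_test. The `main` name and
# the Trainer construction are assumptions, not the project's actual entry
# point; a real driver would build the trainer from cfg and CLI arguments.
def main(cfg: CfgNode) -> Dict[str, str]:
    task = GeneralizedRCNNTask(cfg)
    trainer = pl.Trainer(
        max_steps=cfg.SOLVER.MAX_ITER,
        default_root_dir=cfg.OUTPUT_DIR,
    )
    model_configs = do_train(cfg, trainer, task)  # train and dump final configs
    do_test(trainer, task)  # evaluate the trained weights
    return model_configs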