def test_update_classy_model(self):
    """
    Tests that update_classy_model successfully updates a model from a checkpoint
    """
    config = get_fast_test_task_config()
    task = build_task(config)
    trainer = LocalTrainer()
    trainer.train(task)

    for reset_heads in [False, True]:
        task_2 = build_task(config)
        # prepare task_2 for the right device
        task_2.prepare()
        update_classy_model(
            task_2.model,
            task.model.get_classy_state(deep_copy=True),
            reset_heads,
        )
        self._compare_model_state(
            task.model.get_classy_state(),
            task_2.model.get_classy_state(),
            check_heads=not reset_heads,
        )
        if reset_heads:
            # the model head states should be different
            with self.assertRaises(Exception):
                self._compare_model_state(
                    task.model.get_classy_state(),
                    task_2.model.get_classy_state(),
                    check_heads=True,
                )
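# A minimal usage sketch (not part of the original test) showing how
# update_classy_model can warm-start a freshly built task's model from a
# trained one, mirroring the test above. The import paths are assumptions
# based on classy_vision's public modules; get_fast_test_task_config is the
# same test config helper used in the test.
def example_warm_start_from_trained_task():
    from classy_vision.generic.util import update_classy_model
    from classy_vision.tasks import build_task
    from classy_vision.trainer import LocalTrainer

    config = get_fast_test_task_config()
    source_task = build_task(config)
    LocalTrainer().train(source_task)

    target_task = build_task(config)
    target_task.prepare()
    # Copy the trained weights into the new model; passing reset_heads=True
    # keeps the target task's freshly initialized heads instead of
    # overwriting them with the source task's heads.
    update_classy_model(
        target_task.model,
        source_task.model.get_classy_state(deep_copy=True),
        True,  # reset_heads
    )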
def prepare(
    self,
    num_dataloader_workers: int = 0,
    pin_memory: bool = False,
    use_gpu: bool = False,
    dataloader_mp_context=None,
) -> None:
    assert (
        self.pretrained_checkpoint is not None
    ), "Need a pretrained checkpoint for fine tuning"

    super().prepare(
        num_dataloader_workers, pin_memory, use_gpu, dataloader_mp_context
    )
    if self.checkpoint is None:
        # no checkpoint exists, load the model's state from the pretrained
        # checkpoint
        state_load_success = update_classy_model(
            self.base_model,
            self.pretrained_checkpoint["classy_state_dict"]["base_model"],
            self.reset_heads,
        )
        assert (
            state_load_success
        ), "Update classy state from pretrained checkpoint was unsuccessful."

    if self.freeze_trunk:
        # do not track gradients for any of the parameters in the model
        # except for the parameters in the heads
        for param in self.base_model.parameters():
            param.requires_grad = False
        for heads in self.base_model.get_heads().values():
            for h in heads.values():
                for param in h.parameters():
                    param.requires_grad = True
def prepare(self) -> None:
    assert (
        self.pretrained_checkpoint is not None
    ), "Need a pretrained checkpoint for fine tuning"

    super().prepare()
    if self.checkpoint is None:
        # no checkpoint exists, load the model's state from the pretrained
        # checkpoint
        state_load_success = update_classy_model(
            self.base_model,
            self.pretrained_checkpoint["classy_state_dict"]["base_model"],
            self.reset_heads,
        )
        assert (
            state_load_success
        ), "Update classy state from pretrained checkpoint was unsuccessful."

    if self.freeze_trunk:
        # do not track gradients for any of the parameters in the model
        # except for the parameters in the heads
        for param in self.base_model.parameters():
            param.requires_grad = False
        for heads in self.base_model.get_heads().values():
            for h in heads:
                for param in h.parameters():
                    param.requires_grad = True

    # re-create ddp model
    self.distributed_model = None
    self.init_distributed_data_parallel_model()
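# A standalone sketch of the freeze-trunk pattern used in prepare() above,
# written in plain PyTorch rather than the ClassyVision API; the trunk/head
# modules here are hypothetical. It freezes every parameter and then
# re-enables gradients only for the head, so fine tuning updates the head alone.
import torch.nn as nn


def freeze_trunk_example() -> nn.Module:
    trunk = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
    head = nn.Linear(32, 10)
    model = nn.Sequential(trunk, head)

    # Freeze everything first...
    for param in model.parameters():
        param.requires_grad = False
    # ...then unfreeze just the head parameters.
    for param in head.parameters():
        param.requires_grad = True
    return model


# Note: any DistributedDataParallel wrapper should be (re)built only after the
# requires_grad flags are final, since DDP registers gradient hooks for the
# parameter set it sees at construction time; prepare() above handles this by
# resetting distributed_model and calling init_distributed_data_parallel_model().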