    def test_update_classy_model(self):
        """
        Tests that update_classy_model successfully updates a model from a
        checkpointed state
        """
        config = get_fast_test_task_config()
        task = build_task(config)
        trainer = LocalTrainer()
        trainer.train(task)
        for reset_heads in [False, True]:
            task_2 = build_task(config)
            # prepare task_2 for the right device
            task_2.prepare()
            update_classy_model(
                task_2.model,
                task.model.get_classy_state(deep_copy=True),
                reset_heads,
            )
            self._compare_model_state(
                task.model.get_classy_state(),
                task_2.model.get_classy_state(),
                check_heads=not reset_heads,
            )
            if reset_heads:
                # the model head states should be different
                with self.assertRaises(Exception):
                    self._compare_model_state(
                        task.model.get_classy_state(),
                        task_2.model.get_classy_state(),
                        check_heads=True,
                    )
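
For context, the call being exercised above is a single three-argument function, update_classy_model(model, model_state, reset_heads), which returns a truthy value on success (as the later examples assert). Below is a minimal standalone sketch of that pattern using the same helpers as the test; the import paths are an assumption and may differ between Classy Vision versions, and get_fast_test_task_config is the test-only config helper referenced above.

# Minimal sketch, assuming these import paths (they may differ by version).
from classy_vision.generic.util import update_classy_model
from classy_vision.tasks import build_task
from classy_vision.trainer import LocalTrainer

config = get_fast_test_task_config()  # same test helper as in the snippet above
source_task = build_task(config)
LocalTrainer().train(source_task)

# Build a fresh task and copy the trained weights into its model,
# re-initializing (resetting) the heads.
target_task = build_task(config)
target_task.prepare()
success = update_classy_model(
    target_task.model,
    source_task.model.get_classy_state(deep_copy=True),
    True,  # reset_heads
)
assert success, "state update failed"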
Example #2
    def prepare(
        self,
        num_dataloader_workers: int = 0,
        pin_memory: bool = False,
        use_gpu: bool = False,
        dataloader_mp_context=None,
    ) -> None:
        assert (self.pretrained_checkpoint
                is not None), "Need a pretrained checkpoint for fine tuning"
        super().prepare(num_dataloader_workers, pin_memory, use_gpu,
                        dataloader_mp_context)
        if self.checkpoint is None:
            # no checkpoint exists, load the model's state from the pretrained
            # checkpoint
            state_load_success = update_classy_model(
                self.base_model,
                self.pretrained_checkpoint["classy_state_dict"]["base_model"],
                self.reset_heads,
            )
            assert (
                state_load_success
            ), "Update classy state from pretrained checkpoint was unsuccessful."

        if self.freeze_trunk:
            # do not track gradients for all the parameters in the model except
            # for the parameters in the heads
            for param in self.base_model.parameters():
                param.requires_grad = False
            for heads in self.base_model.get_heads().values():
                for h in heads.values():
                    for param in h.parameters():
                        param.requires_grad = True
Example #3
    def prepare(self) -> None:
        assert (
            self.pretrained_checkpoint is not None
        ), "Need a pretrained checkpoint for fine tuning"
        super().prepare()
        if self.checkpoint is None:
            # no checkpoint exists, load the model's state from the pretrained
            # checkpoint
            state_load_success = update_classy_model(
                self.base_model,
                self.pretrained_checkpoint["classy_state_dict"]["base_model"],
                self.reset_heads,
            )
            assert (
                state_load_success
            ), "Update classy state from pretrained checkpoint was unsuccessful."

        if self.freeze_trunk:
            # do not track gradients for all the parameters in the model except
            # for the parameters in the heads
            for param in self.base_model.parameters():
                param.requires_grad = False
            for heads in self.base_model.get_heads().values():
                for h in heads:
                    for param in h.parameters():
                        param.requires_grad = True
            # re-create ddp model
            self.distributed_model = None
            self.init_distributed_data_parallel_model()
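
The freeze_trunk branch in both prepare() variants uses the same generic PyTorch pattern: disable requires_grad for every parameter, then re-enable it only for the head parameters. The snippet below is a toy illustration of that pattern with a hypothetical stand-in model, not the Classy Vision base_model itself.

import torch.nn as nn

# Hypothetical stand-in for base_model: a small trunk followed by a head.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 10))
head = model[2]

# Disable gradients for every parameter, then re-enable them only for the
# head, so fine-tuning updates the head while the trunk stays frozen.
for param in model.parameters():
    param.requires_grad = False
for param in head.parameters():
    param.requires_grad = True

# Only the head's weight and bias remain trainable.
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)  # ['2.weight', '2.bias']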