Ejemplo n.º 1
0
 def load_model(self, model_path: str, with_mask=True) -> None:
     """Load weights and masks from a checkpoint into ``self.model``.

     Args:
         model_path: Path of the checkpoint file handed to ``torch.load``.
         with_mask: Forwarded unchanged to
             ``model_utils.initialize_params``.
     """
     # Tensors are mapped onto this instance's device while loading.
     checkpoint = torch.load(model_path, map_location=self.device)
     weights = checkpoint["state_dict"]
     model_utils.initialize_params(self.model, weights, with_mask=with_mask)
     logger.info(f"Loaded the model from {model_path}")
Ejemplo n.º 2
0
    def _init_model(self, checkpoint_path: str) -> None:
        """Load checkpoint weights into ``self.model``.

        The checkpoint is read onto the CPU and its ``"test_acc"`` entry is
        stored in ``self.orig_acc``. If any key of the state dict contains
        ``"mask"`` the checkpoint is treated as pruned: a dummy pruning is
        applied to ``self.params_all`` before the weights are loaded, and
        afterwards the masks are stored in ``self.mask`` and the pruning
        reparameterization is removed again.

        Args:
            checkpoint_path: Path of the checkpoint file to read.
        """
        logger.info(f"Load weights from the checkpoint {checkpoint_path}")
        ckpt = torch.load(checkpoint_path, map_location=torch.device("cpu"))

        weights = ckpt["state_dict"]
        self.orig_acc = ckpt["test_acc"]

        # Pruned checkpoints are recognised by a "mask" entry in the state dict.
        has_masks = any("mask" in key for key in weights)

        if has_masks:
            logger.info("Dummy prunning to load pruned weights")
            model_utils.dummy_pruning(self.params_all)

        model_utils.initialize_params(self.model, weights)
        logger.info("Initialized weights")

        # Nothing more to do for a checkpoint without masks.
        if not has_masks:
            return

        logger.info(
            "Get masks and remove prunning reparameterization for prepare_qat"
        )
        self.mask = model_utils.get_masks(self.model)
        model_utils.remove_pruning_reparameterization(self.params_all)
Ejemplo n.º 3
0
    def _create_teacher(
        self, teacher_model_name: str, teacher_model_params: Dict[str, Any]
    ) -> nn.Module:
        """Build the teacher network and load its pretrained weights.

        The weight file is looked up under ``save/pretrained/<dir_name>/``;
        when it is not present it is fetched with
        ``model_utils.download_pretrained_model`` first.

        Args:
            teacher_model_name: Model name passed to ``model_utils.get_model``.
            teacher_model_params: Model parameters passed to
                ``model_utils.get_model``.

        Returns:
            The teacher module on ``self.device``, switched to eval mode.
        """
        # Instantiate the teacher and place it on the working device.
        teacher = model_utils.get_model(teacher_model_name, teacher_model_params)
        teacher = teacher.to(self.device)

        # Work out where the pretrained weights are expected on disk.
        model_info = model_utils.get_pretrained_model_info(teacher)
        model_name = model_info["dir_name"]
        file_name = model_info["file_name"]
        file_path = os.path.join("save", "pretrained", model_name, file_name)

        # Fetch the weights when there is no local copy.
        weights_missing = not os.path.isfile(file_path)
        if weights_missing:
            model_utils.download_pretrained_model(file_path, model_info["link"])
            logger.info(
                f"Pretrained teacher model({model_name}) doesn't exist in the path.\t"
                f"Download teacher model as {file_path}"
            )

        logger.info(f"Load teacher model: {file_path}")
        checkpoint = torch.load(file_path, map_location=self.device)
        model_utils.initialize_params(
            model=teacher, state_dict=checkpoint["state_dict"]
        )

        teacher = teacher.to(self.device)
        teacher.eval()
        return teacher
Ejemplo n.º 4
0
 def load_params(self, model_path: str, with_mask=True) -> None:
     """Load weights and masks, optimizer state and the stored accuracy.

     Reads the ``"state_dict"``, ``"optimizer"`` and ``"test_acc"`` entries
     of the checkpoint; the last one is stored in ``self.best_acc``.

     Args:
         model_path: Path of the checkpoint file handed to ``torch.load``.
         with_mask: Forwarded to ``model_utils.initialize_params`` for the
             model; the optimizer is always initialized with
             ``with_mask=False``.
     """
     checkpoint = torch.load(model_path, map_location=self.device)

     model_weights = checkpoint["state_dict"]
     model_utils.initialize_params(self.model, model_weights, with_mask=with_mask)

     optimizer_state = checkpoint["optimizer"]
     model_utils.initialize_params(self.optimizer, optimizer_state, with_mask=False)

     self.best_acc = checkpoint["test_acc"]
     logger.info(f"Loaded parameters from {model_path}")
Ejemplo n.º 5
0
    def run(self, resume_info_path: str = "") -> None:
        """Shrink the model, save it, reload it and report metrics each time.

        Steps: load the weights from ``self.checkpoint_path`` into
        ``self.model`` and evaluate it; build a fresh model from
        ``self.train_config`` and pass both to ``self.shrink_model``; evaluate
        the shrinked model and save it as ``shrinked_model.pth`` under
        ``self.dir_prefix``; load that file back and evaluate once more.

        Args:
            resume_info_path: Not used by this method.
        """
        # --- original model -------------------------------------------------
        logger.info(
            f"Initialize the model from the checkpoint {self.checkpoint_path}"
        )
        checkpoint = torch.load(self.checkpoint_path, map_location=self.device)
        model_utils.initialize_params(self.model, checkpoint["state_dict"])

        # The model has to run at least once before its size is measured,
        # due to the execution of forward hooks.
        _, acc = self.trainer.test_one_epoch()
        size = model_utils.get_model_size_mb(self.model)
        sparsity = model_utils.sparsity(self.params_all)
        logger.info(
            f"Original model's Acc: {acc['model_acc']:.2f}, Size: {size} MB, "
            f"Sparsity: {sparsity:.2f} %"
        )

        # --- shrinked model -------------------------------------------------
        config = self.train_config
        blank_model = model_utils.get_model(
            config["MODEL_NAME"], config["MODEL_PARAMS"]
        )
        blank_model = blank_model.to(self.device)
        shrinked_model = self.shrink_model(self.model, blank_model)

        _, acc = self.trainer.test_one_epoch_model(shrinked_model)
        size = model_utils.get_model_size_mb(shrinked_model)
        n_params = model_utils.count_model_params(shrinked_model)
        logger.info(
            f"Shrinked model's Acc: {acc['model_acc']:.2f}, Size: {size} MB, "
            f"Params: {(n_params * 1e-6):.2f} M"
        )

        # --- save, then reload from disk -----------------------------------
        shrinked_model_path = os.path.join(self.dir_prefix, "shrinked_model.pth")
        torch.save(shrinked_model, shrinked_model_path)
        logger.info(f"Saved shrinked model as {shrinked_model_path}")

        logger.info(f"Load a shrinked model from {shrinked_model_path}")
        # NOTE(review): this unpickles a whole module object rather than a
        # state dict; recent PyTorch releases changed the ``weights_only``
        # default of ``torch.load`` - confirm against the pinned version.
        loaded_model = torch.load(shrinked_model_path)
        loaded_model.eval()

        _, acc = self.trainer.test_one_epoch_model(loaded_model)
        size = model_utils.get_model_size_mb(loaded_model)
        n_params = model_utils.count_model_params(loaded_model)
        logger.info(
            f"Loaded model's Acc: {acc['model_acc']:.2f}, Size: {size} MB, "
            f"Params: {(n_params * 1e-6):.2f} M"
        )