def update_params(self):
    r"""
    Build/initialize the sub-modules, then optionally load pretrained weights.

    ``_make_convs``, ``_make_nonlocal``, ``_initialize_conv`` and
    ``_initialize_nonlocal`` are called in that order (presumably building
    and initializing the conv / non-local sub-modules), followed by the
    parent class's ``update_params``.

    If ``self._hyper_params["pretrain_model_path"]`` is missing or empty, no
    weights are loaded. Otherwise the checkpoint is loaded onto the CPU and
    unwrapped from ``"model_state_dict"`` when it is nested under that key.
    It is then applied with ``strict=True``. On a key mismatch a warning is
    logged and loading is retried with ``strict=False``. The path and the
    file's md5sum are logged.

    Raises:
        Whatever ``torch.load`` raises for an unreadable checkpoint, and
        ``RuntimeError`` from the non-strict retry (e.g. a tensor shape
        mismatch, which ``strict=False`` does not tolerate).
    """
    self._make_convs()
    self._make_nonlocal()
    self._initialize_conv()
    self._initialize_nonlocal()
    super().update_params()

    model_path = self._hyper_params.get("pretrain_model_path", "")
    if model_path == "":
        return

    # Load onto the CPU: "gpu" is not a valid torch device type (CUDA is
    # "cuda"), so requesting it can only raise. load_state_dict copies the
    # tensors into the existing parameters wherever they already live.
    # NOTE: torch.load unpickles the file -- only use trusted checkpoints.
    state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    if "model_state_dict" in state_dict:
        # Some checkpoints nest the weights under this key; unwrap them.
        state_dict = state_dict["model_state_dict"]
    try:
        self.load_state_dict(state_dict, strict=True)
    except RuntimeError as err:
        # strict=True raises RuntimeError on missing/unexpected keys. Retry
        # leniently, but report the mismatch instead of hiding it.
        logger.warning(
            "Strict loading of {} failed, retrying with strict=False: {}".
            format(model_path, err))
        self.load_state_dict(state_dict, strict=False)
    logger.info("Pretrained weights loaded from {}".format(model_path))
    logger.info("Check md5sum of Pretrained weights: %s" %
                md5sum(model_path))
Beispiel #2
0
 def update_params(self):
     r"""
     Load pretrained ResNet-18 parameters if a checkpoint path is configured.

     Reads ``pretrain_model_path`` from ``self._hyper_params``. When it is
     missing or empty, nothing is loaded. Otherwise the checkpoint is loaded
     onto the CPU and unwrapped from ``"model_state_dict"`` if nested under
     that key. It is then applied with ``strict=False``, so keys that do not
     match this module are ignored rather than raising. The path and the
     file's md5sum are logged.
     """
     model_file = self._hyper_params.get("pretrain_model_path", "")
     if model_file != "":
         # NOTE: torch.load unpickles the file -- only use trusted checkpoints.
         state_dict = torch.load(model_file,
                                 map_location=torch.device("cpu"))
         if "model_state_dict" in state_dict:
             # Without unwrapping, strict=False would silently load nothing
             # from a checkpoint nested under this key.
             state_dict = state_dict["model_state_dict"]
         self.load_state_dict(state_dict, strict=False)
         logger.info("Load pretrained resnet-18 parameters from: %s" %
                     model_file)
         logger.info("Check md5sum of pretrained resnet-18 parameters: %s" %
                     md5sum(model_file))
Beispiel #3
0
 def update_params(self):
     r"""
     Restore pretrained parameters from ``pretrain_model_path``, if one is set.

     An empty (or absent) path means there is nothing to load. Otherwise the
     checkpoint is read onto the CPU. If it has a ``"model_state_dict"``
     entry, that entry is used in place of the whole object. The weights are
     handed to ``self.load_model_param``, and the class name, checkpoint path
     and md5sum of the file are logged.
     """
     ckpt_path = self._hyper_params.get("pretrain_model_path", "")
     if ckpt_path == "":
         return

     checkpoint = torch.load(ckpt_path, map_location=torch.device("cpu"))
     weights = (checkpoint["model_state_dict"]
                if "model_state_dict" in checkpoint else checkpoint)
     self.load_model_param(weights)

     digest = md5sum(ckpt_path)
     logger.info(
         "Load pretrained {} parameters from: {} whose md5sum is {}".format(
             self.__class__.__name__, ckpt_path, digest))
Beispiel #4
0
    def update_params(self):
        r"""
        Optionally load pretrained GoogLeNet weights, then cache crop settings.

        If ``self._hyper_params`` holds a non-empty ``pretrain_model_path``,
        the checkpoint is loaded onto the CPU and unwrapped from
        ``"model_state_dict"`` when nested under that key. It is then applied
        with ``strict=False``, so missing or unexpected keys are ignored. The
        path and its md5sum are logged.

        Afterwards ``crop_pad`` and ``pruned`` are copied from the
        hyper-parameters onto the instance. Both keys are required (plain
        indexing raises ``KeyError`` if absent).
        """
        model_file = self._hyper_params.get("pretrain_model_path", "")
        if model_file != "":
            # Load onto the CPU: "gpu" is not a valid torch device type (CUDA
            # is "cuda"), so requesting it can only raise. load_state_dict
            # copies the tensors into the existing parameters wherever they
            # already live.
            # NOTE: torch.load unpickles the file -- only use trusted
            # checkpoints.
            state_dict = torch.load(model_file,
                                    map_location=torch.device("cpu"))
            if "model_state_dict" in state_dict:
                # Some checkpoints nest the weights under this key.
                state_dict = state_dict["model_state_dict"]
            self.load_state_dict(state_dict, strict=False)
            logger.info("Load pretrained GoogLeNet parameters from: %s" %
                        model_file)
            logger.info("Check md5sum of pretrained GoogLeNet parameters: %s" %
                        md5sum(model_file))

        self.crop_pad = self._hyper_params['crop_pad']
        self.pruned = self._hyper_params['pruned']
 def update_params(self):
     r"""
     Build the ShuffleNetV2 x1.0 network and optionally load pretrained weights.

     The network is created by ``_shufflenetv2`` and stored in
     ``self._model``. Every entry of ``self._hyper_params`` is forwarded to
     it as a keyword argument.

     If ``pretrain_model_path`` is missing or empty, nothing is loaded.
     Otherwise the checkpoint is loaded onto the CPU and unwrapped from
     ``"model_state_dict"`` if nested under that key. It is then applied to
     ``self._model`` with ``strict=False``, so non-matching keys are
     ignored. The path and the file's md5sum are logged.
     """
     arch = "shufflenetv2_x1.0"
     kwargs = self._hyper_params
     # build module
     # NOTE(review): the positional arguments presumably follow torchvision's
     # _shufflenetv2(arch, pretrained, progress, stages_repeats,
     # stages_out_channels) -- confirm against the local definition.
     # "fused_channls" must match the parameter name exactly as spelled there.
     self._model = _shufflenetv2(arch,
                                 False,
                                 True, [4, 8, 4], [24, 116, 232, 464, 1024],
                                 fused_channls=[116, 232, 464],
                                 **kwargs)
     model_file = kwargs.get("pretrain_model_path", "")
     if model_file != "":
         # NOTE: torch.load unpickles the file -- only use trusted checkpoints.
         state_dict = torch.load(model_file,
                                 map_location=torch.device("cpu"))
         if "model_state_dict" in state_dict:
             # Some checkpoints nest the weights under this key.
             state_dict = state_dict["model_state_dict"]
         self._model.load_state_dict(state_dict, strict=False)
         logger.info("Load pretrained ShuffleNet parameters from: %s" %
                     model_file)
         logger.info(
             "Check md5sum of pretrained ShuffleNet parameters: %s" %
             md5sum(model_file))
    def update_params(self):
        r"""
        Update the model using hyper-parameters supplied by the config file.

        When ``pretrain_model_path`` is configured with a non-empty value, the
        checkpoint at that location is loaded onto the CPU. If the loaded
        object contains a ``"model_state_dict"`` entry, that entry is used as
        the weights; otherwise the loaded object itself is. The weights are
        passed to ``self.load_model_param``. A log line then records the
        class name, the checkpoint path and its md5sum.

        Returns:
            None
        """
        # Location of the pretrained checkpoint; "" means "do not load".
        weights_path = self._hyper_params.get("pretrain_model_path", "")
        if weights_path == "":
            return

        # Map every stored tensor to the CPU while loading.
        loaded = torch.load(weights_path, map_location=torch.device("cpu"))
        nested_key = "model_state_dict"
        if nested_key in loaded:
            loaded = loaded[nested_key]
        self.load_model_param(loaded)

        owner = self.__class__.__name__
        logger.info(
            "Load pretrained {} parameters from: {} whose md5sum is {}".format(
                owner, weights_path, md5sum(weights_path)))