def get_squeezenet(param, pretrained=False, pretrained_path="./pretrained/"): ''' param['model_url']: download url param['model_version']: model_version param['file_name']: model file's name param['n_class']: how many classes to be classified param['img_size']: img_size, a tuple(height, width) or int ''' if isinstance(param['img_size'], (tuple, list)): h, w = param['img_size'][0], param['img_size'][1] else: h = w = param['img_size'] #先创建一个跟预训练模型一样结构的,方便导入权重 model = SqueezeNet(param['model_version'], num_classes=1000) model.img_size = (h, w) # 导入预训练模型的权值,预训练模型必须放在pretrained_path里 if pretrained: if os.path.exists(os.path.join(pretrained_path, param['file_name'])): model.load_state_dict( TorchLoad(os.path.join(pretrained_path, param['file_name']))) logging.info("Find local model file, load model from local !!") logging.info("找到本地下载的预训练模型!!载入权重!!") else: logging.info("pretrained 文件夹下没有,从网上下载 !!") model.load_state_dict( model_zoo.load_url(param['model_url'], model_dir=pretrained_path)) logging.info("下载完毕!!载入权重!!") # 根据输入图像大小和类别数,自动调整 model.adaptive_set_classifier(param['n_class']) return model
def get_inceptionresnetv2(param, pretrained = False, pretrained_path="./pretrained/"): r''' param['model_url']: download url param['file_name']: model file's name param['n_class']: how many classes to be classified param['img_size']: img_size, a tuple(height, width) ''' if isinstance(param['img_size'], (tuple, list)): h, w = param['img_size'][0], param['img_size'][1] else: h = w = param['img_size'] assert h>74 and w>74, 'image size should >= 75 !!!' #先创建一个跟预训练模型一样结构的,方便导入权重 model = InceptionResNetV2(num_classes=1001) model.img_size = (h, w) # 导入预训练模型的权值,预训练模型必须放在pretrained_path里 if pretrained: if os.path.exists(os.path.join(pretrained_path, param['file_name'])): model.load_state_dict(TorchLoad(os.path.join(pretrained_path, param['file_name']))) logging.info("Find local model file, load model from local !!") logging.info("找到本地下载的预训练模型!!载入权重!!") else: logging.info("pretrained 文件夹下没有,从网上下载 !!") model.load_state_dict(model_zoo.load_url(param['model_url'], model_dir = pretrained_path)) logging.info("下载完毕!!载入权重!!") # 根据输入图像大小和类别数,自动调整 model.adaptive_set_fc(param['n_class']) return model
def xception(n_class, img_size=(299, 299), pretrained=False, pretrained_path="./pretrained/"): if isinstance(img_size, (tuple, list)): h, w = img_size[0], img_size[1] else: h = w = img_size model = Xception() model.img_size = (h, w) if pretrained: if os.path.exists( os.path.join(pretrained_path, model_names['xception'])): state_dict = TorchLoad( os.path.join(pretrained_path, model_names['xception'])) logging.info("Find local model file, load model from local !!") logging.info("找到本地下载的预训练模型!!直接载入!!") model.load_state_dict(state_dict) #权重载入完毕 else: logging.info("本地文件夹下没有,请从百度云下载 !!") # 灵活调整 if n_class != 1000: model.adaptive_fc(n_class) return model
def load(path="checkpoint.pth"): checkpoint = TorchLoad(path) model = model_factory(**checkpoint) model.class_to_idx = checkpoint["class_to_idx"] model.load_state_dict(checkpoint["model"]) optimizer = optim.Adam(model.classifier.parameters()) optimizer.load_state_dict(checkpoint["optimizer"]) return model, optimizer
def get_densenet(Net_param, n_class, pretrained=False, pretrained_path="./pretrained/"): ''' Net_param:网络参数,只与网络类型有关 包含 模型url 模型文件名字 growth_rate block_config n_class:输出类别 pretrained:是否使用预训练模型 img_size: img_size ''' if isinstance(Net_param['img_size'], (tuple, list)): h, w = Net_param['img_size'][0], Net_param['img_size'][1] else: h = w = Net_param['img_size'] model = DenseNet(num_init_features=Net_param['num_init_features'], growth_rate=Net_param['growth_rate'], block_config=Net_param['block_config']) model.img_size = (h, w) if pretrained: pattern = re.compile( r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') if os.path.exists(os.path.join(pretrained_path, Net_param['model_name'])): state_dict = TorchLoad(os.path.join("./pretrained", Net_param['model_name'])) logging.info("Find local model file, load model from local !!") logging.info("找到本地下载的预训练模型!!直接载入!!") else: logging.info("pretrained 文件夹下没有,从网上下载 !!") state_dict = model_zoo.load_url(Net_param['url'], model_dir = pretrained_path) logging.info("下载完毕!!载入权重!!") # 导入进来 for key in list(state_dict.keys()): res = pattern.match(key) if res: new_key = res.group(1) + res.group(2) state_dict[new_key] = state_dict[key] del state_dict[key] model.load_state_dict(state_dict) #权重载入完毕 # 灵活调整 if n_class!=1000: model.adaptive_set_fc(n_class) return model
def get_vgg(Net_cfg, Net_urls, file_name, n_class, pretrained=False, img_size=(224, 224), pretrained_path="./pretrained/"): ''' Net_cfg:网络结构 Net_urls:预训练模型的url file_name:预训练模型的名字 n_class:输出类别 pretrained:是否使用预训练模型 param为字典,包含网络需要的参数 param['img_height']: image's height, must be 32's multiple param['img_width']: image's weight, must be 32's multiple ''' if isinstance(img_size, (tuple, list)): h, w = img_size[0], img_size[1] else: h = w = img_size param = {'img_height': h, 'img_width': w} check_param(param) model = vgg_Net(Net_cfg, param) #先建立一个跟预训练模型一样的网络 model.img_size = (h, w) if pretrained: if os.path.exists(os.path.join(pretrained_path, file_name)): model.load_state_dict( TorchLoad(os.path.join(pretrained_path, file_name))) logging.info("Find local model file, load model from local !!") logging.info("找到本地下载的预训练模型!!直接载入!!") else: logging.info("pretrained 文件夹下没有,从网上下载 !!") model.load_state_dict( model_zoo.load_url(Net_urls, model_dir=pretrained_path)) logging.info("下载完毕!!载入权重!!") model.adjust_classifier(n_class) #调整全连接层,迁移学习 return model