def Resnext101_32x16d(num_classes, test=False):
    """Build a ResNeXt-101 32x16d WSL model with a fresh ``num_classes``-way linear head.

    :param num_classes: number of outputs of the replacement ``fc`` layer.
    :param test: when False (default), weights are loaded into the backbone before
        the head is replaced. When True, weight loading is skipped; presumably the
        caller then loads its own fine-tuned checkpoint (TODO confirm).
    :return: the model, with ``model.fc`` replaced by
        ``nn.Linear(model.fc.in_features, num_classes)``.
    """
    model = resnext101_32x16d_wsl()
    if not test:
        local_path = LOCAL_PRETRAINED['weights_resnext101_32x16d']
        if local_path is None:
            # No local weights file configured: fetch them from the registered URL.
            state_dict = load_state_dict_from_url(
                model_urls['weights_resnext101_32x16d'], progress=True)
        else:
            # map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only
            # host; load_state_dict copies the tensors into the model's own params.
            state_dict = torch.load(local_path, map_location='cpu')
        model.load_state_dict(state_dict)
    # Replace the original classifier head with one sized for this task.
    fc_features = model.fc.in_features
    model.fc = nn.Linear(fc_features, num_classes)
    return model
def pretrain_model(model_name, feature_extract=True, num_classes=4):
    """Build a ResNeXt-101 32x16d WSL model with a dropout + linear classifier head.

    :param model_name: only 'resnext101_32x16d' is recognised; any other name
        returns None.
    :param feature_extract: forwarded to ``set_params_requires_grad`` to choose
        which layers are trainable.
    :param num_classes: output size of the new classifier head. Defaults to 4,
        the value that used to be hard-coded here.
    :return: the configured model, or None for an unrecognised ``model_name``.
    """
    if model_name != 'resnext101_32x16d':
        return None

    # Load the pretrained backbone.
    model_init = models.resnext101_32x16d_wsl()
    # Set which layers train.
    # NOTE(review): initital_model in this file calls ``set_parameter_requires_grad``;
    # confirm that ``set_params_requires_grad`` is actually defined.
    set_params_requires_grad(model_init, feature_extract)

    # Swap the classifier for a dropout-regularised head with num_classes outputs.
    num_input = model_init.fc.in_features
    model_init.fc = nn.Sequential(
        nn.Dropout(0.2),
        nn.Linear(in_features=num_input, out_features=num_classes))
    return model_init
def initital_model(model_name, num_classes, feature_extract=True):
    """Initialise a classifier from one of the provided pre-trained models.

    :param model_name: name of the backbone, e.g. resnext101_32x16d / resnext101_32x8d.
    :param num_classes: number of image classes the new head should output.
    :param feature_extract: if True, freeze the feature-extraction layers and only
        optimise the fully-connected classifier (the flag is passed to
        ``set_parameter_requires_grad``).
    :return: the model with ``fc`` replaced by Dropout(0.2) + Linear(num_ftrs, num_classes).
        An unknown ``model_name`` prints a message and raises SystemExit.
    """
    import sys

    # Lambdas keep attribute lookup lazy, so only the requested constructor is
    # touched, exactly as with the original per-branch calls.
    builders = {
        # Facebook's pre-trained ResNeXt-101 (WSL).
        'resnext101_32x16d': lambda: models.resnext101_32x16d_wsl(),
        # NOTE(review): called with no arguments; whether this loads pretrained
        # weights depends on what ``models`` is - confirm.
        'resnext101_32x8d': lambda: models.resnext101_32x8d(),
    }
    if model_name not in builders:
        print('Invalid model name,exiting..')
        # sys.exit does not depend on the ``site`` module (unlike builtin exit()),
        # and raises the same SystemExit.
        sys.exit()

    model_ft = builders[model_name]()
    # Freeze (or not) the feature-extraction layers.
    set_parameter_requires_grad(model_ft, feature_extract)
    # Resize the classifier to num_classes outputs.
    num_ftrs = model_ft.fc.in_features
    model_ft.fc = nn.Sequential(
        nn.Dropout(0.2),  # guards against overfitting
        nn.Linear(in_features=num_ftrs, out_features=num_classes))
    return model_ft
# id -> name mapping for ImageNet-1k labels.
# Each line looks like: 0: 'tench, Tinca tinca', 丁鲷(鱼)
ImageNet_dict = dict()
# `with` guarantees the label file is closed once parsing finishes.
with codecs.open('data/ImageNet1k_label.txt', 'r', encoding='utf-8') as label_file:
    for line in label_file:
        line = line.strip()
        if not line:
            # Skip blank lines (e.g. a trailing newline) instead of crashing on them.
            continue
        # Split only on the first ':' so a name that itself contains ':' stays intact.
        _id, _name = line.split(":", 1)
        _name = _name.replace('\xa0', "")
        ImageNet_dict[int(_id)] = _name

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Pytorch facebookresearch_WS-Images_resnext predict device =', device)

# Build the model once at import time and switch it to inference mode.
# NOTE(review): whether resnext101_32x16d_wsl() loads pretrained weights depends
# on `models` - confirm.
model_ft = models.resnext101_32x16d_wsl()
model_ft.to(device)
model_ft.eval()  # put the model in eval mode


@app.route('/')
def hello():
    return "Hello World"


@app.route('/predict', methods=['POST'])
def predict():
    # Read the uploaded input data
    file = request.files['file']
    img_bytes = file.read()