Exemplo n.º 1
0
def Resnext101_32x16d(num_classes, test=False):
    """Build a ResNeXt-101 32x16d WSL model with a fresh classification head.

    :param num_classes: number of outputs of the replacement ``fc`` layer.
    :param test: when True, skip loading the pretrained weights (presumably
        because a fine-tuned checkpoint is loaded afterwards — confirm with
        callers).
    :return: the model, with ``fc`` replaced by
        ``nn.Linear(fc.in_features, num_classes)``.
    """
    model = resnext101_32x16d_wsl()
    if not test:
        local_path = LOCAL_PRETRAINED['weights_resnext101_32x16d']
        if local_path is None:
            # No local copy configured: fetch the weights from the known URL.
            state_dict = load_state_dict_from_url(
                model_urls['weights_resnext101_32x16d'], progress=True)
        else:
            # map_location='cpu' lets a checkpoint saved on a GPU be loaded on
            # a CPU-only host; load_state_dict then copies the tensors into
            # the model's own parameters.
            state_dict = torch.load(local_path, map_location='cpu')
        model.load_state_dict(state_dict)
    fc_features = model.fc.in_features
    model.fc = nn.Linear(fc_features, num_classes)
    return model
Exemplo n.º 2
0
def pretrain_model(model_name, feature_extract=True, num_classes=4, dropout=0.2):
    """Build a pretrained backbone with a replaced classification head.

    :param model_name: backbone name; only ``'resnext101_32x16d'`` is supported.
    :param feature_extract: forwarded to ``set_params_requires_grad``
        (presumably True freezes the backbone layers — confirm against that
        helper).
    :param num_classes: number of outputs of the new head. The default of 4
        is the value that was previously hard-coded.
    :param dropout: dropout probability applied before the new linear layer.
    :return: the modified model, or None when ``model_name`` is not supported.
    """
    if model_name != 'resnext101_32x16d':
        # Unsupported names yield None, as before.
        return None

    # Load the pretrained backbone, then choose which layers will train.
    model_init = models.resnext101_32x16d_wsl()
    set_params_requires_grad(model_init, feature_extract)

    # Swap the final classifier for Dropout + Linear with num_classes outputs.
    num_input = model_init.fc.in_features
    model_init.fc = nn.Sequential(
        nn.Dropout(dropout),
        nn.Linear(in_features=num_input, out_features=num_classes))
    return model_init
Exemplo n.º 3
0
def _replace_fc_head(model, num_classes, dropout=0.2):
    """Replace ``model.fc`` with Dropout followed by a Linear classifier.

    :param model: network exposing an ``fc`` layer with ``in_features``.
    :param num_classes: number of outputs of the new classifier.
    :param dropout: dropout probability applied before the classifier, to
        reduce overfitting.
    :return: the same ``model`` object, modified in place.
    """
    num_ftrs = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Dropout(dropout),
        nn.Linear(in_features=num_ftrs, out_features=num_classes))
    return model


def initital_model(model_name, num_classes, feature_extract=True):
    """Initialise a classifier from one of the supported pre-trained backbones.

    :param model_name: backbone name, either ``'resnext101_32x16d'`` or
        ``'resnext101_32x8d'``.
    :param num_classes: number of image classes the new head predicts.
    :param feature_extract: passed to ``set_parameter_requires_grad``. Per the
        original docs, True freezes the feature-extraction layers so that only
        the fully-connected classifier is optimised.
    :return: the backbone, with its ``fc`` replaced by Dropout(0.2) + Linear.
    :raises SystemExit: with status 1, after printing a message, when
        ``model_name`` is not supported.
    """
    # The constructors are wrapped in lambdas so ``models`` is only accessed
    # after the requested name has been validated.
    constructors = {
        # Facebook WSL ResNeXt-101 32x16d; 1000-class head by default.
        'resnext101_32x16d': lambda: models.resnext101_32x16d_wsl(),
        # NOTE(review): no pretrained flag is passed here; confirm whether
        # this constructor actually loads pretrained weights.
        'resnext101_32x8d': lambda: models.resnext101_32x8d(),
    }
    build = constructors.get(model_name)
    if build is None:
        print('Invalid model name,exiting..')
        # exit() is an interactive-shell helper injected by the ``site``
        # module, and it exited with status 0 on an error. Raising SystemExit
        # directly keeps the exception type callers may catch while reporting
        # failure.
        raise SystemExit(1)

    model_ft = build()
    set_parameter_requires_grad(model_ft, feature_extract)
    return _replace_fc_head(model_ft, num_classes)
# Map ImageNet class id -> label text, read from a ':'-separated label file.
# Lines look like: 0: 'tench, Tinca tinca',  丁鲷(鱼)
# (English names followed by a Chinese label).
ImageNet_dict = {}
# A context manager guarantees the label file is closed after reading.
with open('data/ImageNet1k_label.txt', 'r', encoding='utf-8') as label_file:
    for line in label_file:
        line = line.strip()
        if not line:
            # Skip blank lines (e.g. a trailing newline) instead of failing
            # on int('').
            continue
        # Split on the first ':' only, so any ':' inside the name is kept.
        _id, _, _name = line.partition(":")
        # Strip the non-breaking spaces (U+00A0) found in the label text.
        _name = _name.replace('\xa0', "")
        ImageNet_dict[int(_id)] = _name

# Prefer the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print('Pytorch facebookresearch_WS-Images_resnext predict device =', device)

# Load the model once at import time, move it to the chosen device, and put
# it in inference mode (this affects layers such as Dropout/BatchNorm).
model_ft = models.resnext101_32x16d_wsl()
model_ft.to(device)
model_ft.eval()

@app.route('/')
def hello():
    """Root endpoint: respond with a fixed plain greeting string."""
    greeting = "Hello World"
    return greeting


@app.route('/predict', methods=['POST'])
def predict():
    # 获取输入数据
    file = request.files['file']

    img_bytes = file.read()