Ejemplo n.º 1
0
def get_squeezenet(param, pretrained=False, pretrained_path="./pretrained/"):
    ''' param['model_url']: download url
        param['model_version']: model_version
        param['file_name']: model file's name
        param['n_class']: how many classes to be classified
        param['img_size']: img_size, a tuple(height, width) or int
    '''

    if isinstance(param['img_size'], (tuple, list)):
        h, w = param['img_size'][0], param['img_size'][1]
    else:
        h = w = param['img_size']

    #先创建一个跟预训练模型一样结构的,方便导入权重
    model = SqueezeNet(param['model_version'], num_classes=1000)
    model.img_size = (h, w)

    # 导入预训练模型的权值,预训练模型必须放在pretrained_path里
    if pretrained:
        if os.path.exists(os.path.join(pretrained_path, param['file_name'])):
            model.load_state_dict(
                TorchLoad(os.path.join(pretrained_path, param['file_name'])))
            logging.info("Find local model file, load model from local !!")
            logging.info("找到本地下载的预训练模型!!载入权重!!")
        else:
            logging.info("pretrained 文件夹下没有,从网上下载 !!")
            model.load_state_dict(
                model_zoo.load_url(param['model_url'],
                                   model_dir=pretrained_path))
            logging.info("下载完毕!!载入权重!!")

    # 根据输入图像大小和类别数,自动调整
    model.adaptive_set_classifier(param['n_class'])

    return model
Ejemplo n.º 2
0
def get_inceptionresnetv2(param, pretrained = False, pretrained_path="./pretrained/"):

    r''' param['model_url']: download url
        param['file_name']: model file's name
        param['n_class']: how many classes to be classified
        param['img_size']: img_size, a tuple(height, width)
    '''

    if isinstance(param['img_size'], (tuple, list)):
        h, w = param['img_size'][0], param['img_size'][1]
    else:
        h = w = param['img_size']
    assert h>74 and w>74, 'image size should >= 75 !!!'

    #先创建一个跟预训练模型一样结构的,方便导入权重
    model = InceptionResNetV2(num_classes=1001)
    model.img_size = (h, w)

    # 导入预训练模型的权值,预训练模型必须放在pretrained_path里
    if pretrained:
        if os.path.exists(os.path.join(pretrained_path, param['file_name'])):
            model.load_state_dict(TorchLoad(os.path.join(pretrained_path, param['file_name'])))
            logging.info("Find local model file, load model from local !!")
            logging.info("找到本地下载的预训练模型!!载入权重!!")
        else:
            logging.info("pretrained 文件夹下没有,从网上下载 !!")
            model.load_state_dict(model_zoo.load_url(param['model_url'], model_dir = pretrained_path))
            logging.info("下载完毕!!载入权重!!")

    # 根据输入图像大小和类别数,自动调整
    model.adaptive_set_fc(param['n_class'])

    return model
Ejemplo n.º 3
0
def xception(n_class,
             img_size=(299, 299),
             pretrained=False,
             pretrained_path="./pretrained/"):

    if isinstance(img_size, (tuple, list)):
        h, w = img_size[0], img_size[1]
    else:
        h = w = img_size

    model = Xception()
    model.img_size = (h, w)

    if pretrained:
        if os.path.exists(
                os.path.join(pretrained_path, model_names['xception'])):
            state_dict = TorchLoad(
                os.path.join(pretrained_path, model_names['xception']))
            logging.info("Find local model file, load model from local !!")
            logging.info("找到本地下载的预训练模型!!直接载入!!")
            model.load_state_dict(state_dict)  #权重载入完毕
        else:
            logging.info("本地文件夹下没有,请从百度云下载 !!")

    # 灵活调整
    if n_class != 1000:
        model.adaptive_fc(n_class)

    return model
Ejemplo n.º 4
0
def load(path="checkpoint.pth"):
    checkpoint = TorchLoad(path)
    model = model_factory(**checkpoint)
    model.class_to_idx = checkpoint["class_to_idx"]
    model.load_state_dict(checkpoint["model"])
    optimizer = optim.Adam(model.classifier.parameters())
    optimizer.load_state_dict(checkpoint["optimizer"])

    return model, optimizer
Ejemplo n.º 5
0
def get_densenet(Net_param, n_class, pretrained=False, pretrained_path="./pretrained/"):
    '''
    Net_param:网络参数,只与网络类型有关 
                包含 模型url 模型文件名字 growth_rate block_config
    n_class:输出类别
    pretrained:是否使用预训练模型
    img_size: img_size
    '''

    if isinstance(Net_param['img_size'], (tuple, list)):
        h, w = Net_param['img_size'][0], Net_param['img_size'][1]
    else:
        h = w = Net_param['img_size']    

    model = DenseNet(num_init_features=Net_param['num_init_features'], 
        growth_rate=Net_param['growth_rate'], block_config=Net_param['block_config'])
    model.img_size = (h, w)

    if pretrained:
        pattern = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
        
        if os.path.exists(os.path.join(pretrained_path, Net_param['model_name'])):
            state_dict = TorchLoad(os.path.join("./pretrained", Net_param['model_name']))
            logging.info("Find local model file, load model from local !!")
            logging.info("找到本地下载的预训练模型!!直接载入!!")
        else:
            logging.info("pretrained 文件夹下没有,从网上下载 !!")
            state_dict = model_zoo.load_url(Net_param['url'], model_dir = pretrained_path)
            logging.info("下载完毕!!载入权重!!")

        # 导入进来
        for key in list(state_dict.keys()):
            res = pattern.match(key)
            if res:
                new_key = res.group(1) + res.group(2)
                state_dict[new_key] = state_dict[key]
                del state_dict[key]

        model.load_state_dict(state_dict) #权重载入完毕

    # 灵活调整
    if n_class!=1000:
        model.adaptive_set_fc(n_class)

    return model
Ejemplo n.º 6
0
def get_vgg(Net_cfg,
            Net_urls,
            file_name,
            n_class,
            pretrained=False,
            img_size=(224, 224),
            pretrained_path="./pretrained/"):
    '''
    Net_cfg:网络结构
    Net_urls:预训练模型的url
    file_name:预训练模型的名字
    n_class:输出类别
    pretrained:是否使用预训练模型

    param为字典,包含网络需要的参数
    param['img_height']: image's height, must be 32's multiple
    param['img_width']: image's weight, must be 32's multiple
    '''
    if isinstance(img_size, (tuple, list)):
        h, w = img_size[0], img_size[1]
    else:
        h = w = img_size

    param = {'img_height': h, 'img_width': w}
    check_param(param)

    model = vgg_Net(Net_cfg, param)  #先建立一个跟预训练模型一样的网络
    model.img_size = (h, w)

    if pretrained:
        if os.path.exists(os.path.join(pretrained_path, file_name)):
            model.load_state_dict(
                TorchLoad(os.path.join(pretrained_path, file_name)))
            logging.info("Find local model file, load model from local !!")
            logging.info("找到本地下载的预训练模型!!直接载入!!")
        else:
            logging.info("pretrained 文件夹下没有,从网上下载 !!")
            model.load_state_dict(
                model_zoo.load_url(Net_urls, model_dir=pretrained_path))
            logging.info("下载完毕!!载入权重!!")

    model.adjust_classifier(n_class)  #调整全连接层,迁移学习

    return model