示例#1
0
def mobilenet_v2(pretrained='scratch', **kwargs):
    """Build a MobileNetV2 and optionally initialise it from pre-trained weights.

    Args:
        pretrained (str): one of
            'imagenet' -- copy weights from ``mobilenet_v2_imagenet`` via
                          ``model.load_pretrained``;
            'coco'     -- load 'mobilenet_v2_coco_80.pth' from
                          ``Path.models_dir()``;
            'scratch'  -- keep the freshly constructed weights.
        **kwargs: forwarded to the ``MobileNetV2`` constructor.

    Returns:
        The constructed ``MobileNetV2`` model.

    Raises:
        NotImplementedError: if ``pretrained`` is not one of the values above.
    """
    model = MobileNetV2(**kwargs)

    if pretrained == 'imagenet':
        print('loading pre-trained imagenet model')
        model_full = mobilenet_v2_imagenet.mobilenet_v2(pretrained=True)
        model.load_pretrained(model_full)
    elif pretrained == 'coco':
        print('loading pre-trained COCO model')
        # Load checkpoint onto CPU regardless of the device it was saved from.
        checkpoint = torch.load(os.path.join(Path.models_dir(),
                                             'mobilenet_v2_coco_80.pth'),
                                map_location=lambda storage, loc: storage)

        # Handle DataParallel: a model saved through nn.DataParallel has every
        # key prefixed with 'module.'. Strip only that leading prefix -- a
        # substring test/replace would also hit keys like 'submodule.weight'.
        prefix = 'module.'
        if any(k.startswith(prefix) for k in checkpoint):
            new_state_dict = OrderedDict(
                (k[len(prefix):] if k.startswith(prefix) else k, v)
                for k, v in checkpoint.items())
        else:
            new_state_dict = checkpoint

        # Load pre-trained COCO weights
        model.load_state_dict(new_state_dict)

    elif pretrained == 'scratch':
        print('using imagenet initialized from scratch')
    else:
        raise NotImplementedError(
            'select either imagenet, coco or scratch for pre-training')

    return model
示例#2
0
def Res_Deeplab(n_classes=21, pretrained=None):
    """Build an ``MS_Deeplab`` (Bottleneck) model, optionally from a checkpoint.

    Args:
        n_classes (int): number of output classes passed to ``MS_Deeplab``.
            When it differs from 21, checkpoint entries whose second
            dot-separated key component is 'layer5' are replaced by the
            freshly initialised tensors of the new model before loading.
        pretrained (str or None): 'voc', 'ms_coco', or None for no weights.

    Returns:
        The constructed model.

    Raises:
        ValueError: if ``pretrained`` is not None, 'voc' or 'ms_coco'.
    """
    checkpoints = {
        'voc': 'MS_DeepLab_resnet_trained_VOC.pth',
        'ms_coco': 'MS_DeepLab_resnet_pretrained_COCO_init.pth',
    }
    # Validate up front: an unknown value used to fall through both branches
    # and crash later with an UnboundLocalError on the checkpoint name.
    if pretrained is not None and pretrained not in checkpoints:
        raise ValueError(
            "pretrained must be None, 'voc' or 'ms_coco', got %r" % (pretrained,))

    model = MS_Deeplab(Bottleneck, n_classes)
    if pretrained is None:
        return model

    saved_state_dict = torch.load(
        os.path.join(Path.models_dir(), checkpoints[pretrained]),
        map_location=lambda storage, loc: storage)

    if n_classes != 21:
        # 'layer5' tensors are shaped for 21 classes in the checkpoint; keep
        # the new model's own tensors for them instead (presumably the
        # classifier head -- confirm against MS_Deeplab).
        model_state = model.state_dict()  # hoisted: invariant across the loop
        for key in saved_state_dict:
            parts = key.split('.')
            if len(parts) > 1 and parts[1] == 'layer5':
                saved_state_dict[key] = model_state[key]
    model.load_state_dict(saved_state_dict)
    return model
示例#3
0
def resnet26(pretrained=False, remote=True):
    """Constructs a ResNet-26 dense-labelling model (``ResNet_C5``).

    Args:
        pretrained (bool): If True, load ImageNet weights into a
            ``resnet.ResNet`` classifier and transfer them with
            ``model.load_pretrained``.
        remote (bool): If True, download the checkpoint from
            ``model_urls['resnet26']``; otherwise read 'resnet26.pth' from
            ``Path.models_dir()``.

    Returns:
        The constructed ``ResNet_C5`` model.
    """
    model = ResNet_C5(Bottleneck, [2, 2, 2, 2])

    if pretrained:

        # Define ResNet26 ImageNet
        model_IN = resnet.ResNet(block=Bottleneck,
                                 layers=[2, 2, 2, 2],
                                 num_classes=1000)

        # Load checkpoint (tensors mapped to CPU)
        if remote:
            checkpoint = load_state_dict_from_url(model_urls['resnet26'],
                                                  map_location='cpu',
                                                  progress=True)
        else:
            checkpoint = torch.load(os.path.join(Path.models_dir(),
                                                 'resnet26.pth'),
                                    map_location=lambda storage, loc: storage)
        # The weights are stored under the 'model_state' key of the checkpoint.
        checkpoint = checkpoint['model_state']

        # Handle DataParallel: strip only a leading 'module.' prefix; a
        # substring test/replace would also mangle keys like 'submodule.x'.
        prefix = 'module.'
        if any(k.startswith(prefix) for k in checkpoint):
            new_state_dict = OrderedDict(
                (k[len(prefix):] if k.startswith(prefix) else k, v)
                for k, v in checkpoint.items())
        else:
            new_state_dict = checkpoint

        # Load pre-trained IN model
        model_IN.load_state_dict(new_state_dict)

        # Load weights to dense-labelling network
        model.load_pretrained(model_IN)

    return model
示例#4
0
def se_mobilenet_v2(pretrained=False,
                    features=False,
                    n_class=1000,
                    last_channel=1280,
                    remote=True):
    """Build an SE-MobileNet-v2 classifier or feature extractor.

    Args:
        pretrained (bool): If True, load the ImageNet checkpoint
            'se_mobilenet_v2_1280' into the model.
        features (bool): If True, build ``SEMobileNetV2Features`` instead of
            ``SEMobileNetV2``.
        n_class (int): forwarded to the model constructor.
        last_channel (int): forwarded to the model constructor.
        remote (bool): If True, download the checkpoint from ``model_urls``;
            otherwise read 'se_mobilenet_v2_1280.pth' from
            ``Path.models_dir()``.

    Returns:
        The constructed model.
    """
    if not features:
        model = SEMobileNetV2(n_class=n_class, last_channel=last_channel)
    else:
        model = SEMobileNetV2Features(n_class=n_class,
                                      last_channel=last_channel)

    if pretrained:
        print('Loading Imagenet pre-trained SE-MobileNet-v2')

        # Load checkpoint (tensors mapped to CPU)
        if remote:
            checkpoint = load_state_dict_from_url(
                model_urls['se_mobilenet_v2_1280'],
                map_location='cpu',
                progress=True)
        else:
            checkpoint = torch.load(os.path.join(Path.models_dir(),
                                                 'se_mobilenet_v2_1280.pth'),
                                    map_location=lambda storage, loc: storage)
        # The weights are stored under the 'model_state' key of the checkpoint.
        checkpoint = checkpoint['model_state']

        # Handle DataParallel: strip only a leading 'module.' prefix; a
        # substring test/replace would also mangle keys like 'submodule.x'.
        prefix = 'module.'
        if any(k.startswith(prefix) for k in checkpoint):
            new_state_dict = OrderedDict(
                (k[len(prefix):] if k.startswith(prefix) else k, v)
                for k, v in checkpoint.items())
        else:
            new_state_dict = checkpoint

        # Load pre-trained IN model
        model.load_state_dict(new_state_dict)

    return model
示例#5
0
def get_state_dict_se(model_name, remote=True):
    """Load the state dict stored in a checkpoint named ``model_name``.

    Args:
        model_name (str): key into ``model_urls`` (remote) or base name of
            '<model_name>.pth' under ``Path.models_dir()`` (local).
        remote (bool): If True, download via ``load_state_dict_from_url``;
            otherwise load the local file.

    Returns:
        The checkpoint's 'model_state' mapping, with any leading 'module.'
        (DataParallel) prefix removed from its keys.
    """
    # Load checkpoint (tensors mapped to CPU)
    if remote:
        checkpoint = load_state_dict_from_url(model_urls[model_name],
                                              map_location='cpu',
                                              progress=True)
    else:
        checkpoint = torch.load(os.path.join(Path.models_dir(),
                                             model_name + '.pth'),
                                map_location=lambda storage, loc: storage)
    checkpoint = checkpoint['model_state']

    # Handle DataParallel: strip only a leading 'module.' prefix; a
    # substring test/replace would also mangle keys like 'submodule.x'.
    prefix = 'module.'
    if any(k.startswith(prefix) for k in checkpoint):
        return OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in checkpoint.items())
    return checkpoint
示例#6
0
def resnet26(pretrained=False, features=False, remote=True, **kwargs):
    """Constructs a ResNet-26 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        features (bool): If True, build ``ResNetFeatures`` instead of
            ``ResNet``.
        remote (bool): If True, download the checkpoint from
            ``model_urls['resnet26']``; otherwise read 'resnet26.pth' from
            ``Path.models_dir()``.
        **kwargs: forwarded to the model constructor.

    Returns:
        The constructed model.
    """
    if not features:
        model = ResNet(Bottleneck, [2, 2, 2, 2], **kwargs)
    else:
        model = ResNetFeatures(Bottleneck, [2, 2, 2, 2], **kwargs)

    if pretrained:
        print('Loading resnet26 Imagenet')

        # Load checkpoint (tensors mapped to CPU)
        if remote:
            checkpoint = load_state_dict_from_url(model_urls['resnet26'],
                                                  map_location='cpu',
                                                  progress=True)
        else:
            checkpoint = torch.load(os.path.join(Path.models_dir(),
                                                 'resnet26.pth'),
                                    map_location=lambda storage, loc: storage)
        # The weights are stored under the 'model_state' key of the checkpoint.
        checkpoint = checkpoint['model_state']

        # Handle DataParallel: strip only a leading 'module.' prefix; a
        # substring test/replace would also mangle keys like 'submodule.x'.
        prefix = 'module.'
        if any(k.startswith(prefix) for k in checkpoint):
            new_state_dict = OrderedDict(
                (k[len(prefix):] if k.startswith(prefix) else k, v)
                for k, v in checkpoint.items())
        else:
            new_state_dict = checkpoint

        # Load pre-trained IN model
        model.load_state_dict(new_state_dict)

    return model
示例#7
0
def mobilenet_v2(pretrained=False,
                 features=False,
                 n_class=1000,
                 last_channel=1280,
                 remote=True):
    """Build a MobileNetV2 classifier or feature extractor.

    Args:
        pretrained (bool): If truthy, load weights into the model with
            ``load_state_dict``.
        features (bool): If truthy, build ``MobileNetV2Features`` instead of
            ``MobileNetV2``.
        n_class (int): forwarded to the model constructor.
        last_channel (int): forwarded to the model constructor.
        remote (bool): If truthy, fetch ``model_urls['mobilenet_v2_1280']``;
            otherwise read 'mobilenet_v2.pth' from ``Path.models_dir()``.

    Returns:
        The constructed model.
    """
    model_cls = MobileNetV2Features if features else MobileNetV2
    net = model_cls(n_class=n_class, last_channel=last_channel)

    # Nothing more to do for a randomly initialised model.
    if not pretrained:
        return net

    # Unlike sibling loaders, this checkpoint is used as the state dict
    # directly (no 'model_state' key, no DataParallel prefix handling).
    if remote:
        weights = load_state_dict_from_url(
            model_urls['mobilenet_v2_1280'],
            map_location='cpu',
            progress=True)
    else:
        local_file = os.path.join(Path.models_dir(), 'mobilenet_v2.pth')
        weights = torch.load(local_file, map_location='cpu')

    net.load_state_dict(weights)
    return net
示例#8
0
def resnext_101_32x4d():
    """Load 'resnext_101_32x4d.pth' into ``resnext_101_32x4d_model`` and return it.

    Returns:
        ``resnext_101_32x4d_model`` with the checkpoint's weights loaded.
    """
    # NOTE(review): resnext_101_32x4d_model is a module-level name, so this
    # appears to return (and reload weights into) the same object on every
    # call -- confirm callers do not expect a fresh instance.
    model = resnext_101_32x4d_model
    weights_path = os.path.join(Path.models_dir(), 'resnext_101_32x4d.pth')
    # Map tensors to CPU like the other loaders in this file, so a checkpoint
    # saved from GPU still loads on a CPU-only machine; load_state_dict copies
    # the values into the model's existing parameters.
    state_dict = torch.load(weights_path,
                            map_location=lambda storage, loc: storage)
    model.load_state_dict(state_dict)
    return model