Example #1
0
def vgg13_bn(pretrained=False, **kwargs):
    """Build a batch-normalized VGG-13 network (layer configuration "B").

    Args:
        pretrained (bool): When True, the ImageNet weights referenced by
            ``model_urls['vgg13_bn']`` are downloaded and loaded into the model.
        **kwargs: Extra keyword arguments forwarded to ``VGG``. When
            ``pretrained`` is True, ``init_weights`` is forced to False.

    Returns:
        The constructed ``VGG`` model.
    """
    layers = make_layers(cfg['B'], batch_norm=True)
    if not pretrained:
        return VGG(layers, **kwargs)

    # The downloaded state dict replaces the weights, so VGG's own
    # initialization is turned off.
    kwargs['init_weights'] = False
    net = VGG(layers, **kwargs)
    net.load_state_dict(model_zoo.load_url(model_urls['vgg13_bn']))
    return net
Example #2
0
def vgg16(pretrained=False, model_path=None, **kwargs):
    """VGG 16-layer model (configuration "D")

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet.
        model_path (str, optional): Local state-dict file to load instead of
            downloading ``model_urls['vgg16']``. Only used when
            ``pretrained`` is True.
        **kwargs: Forwarded to ``VGG``. When ``pretrained`` is True,
            ``init_weights`` is forced to False.

    Returns:
        The constructed ``VGG`` model.
    """
    if pretrained:
        # Presumably skips VGG's own weight initialization, which the loaded
        # state dict overwrites anyway.
        kwargs['init_weights'] = False
    model = VGG(make_layers(cfg['D']), **kwargs)
    if pretrained:
        if model_path is not None:
            # map_location='cpu' lets a checkpoint saved from GPU tensors be
            # deserialized on a CPU-only host; load_state_dict then copies
            # the values into the model's own tensors.
            state_dict = torch.load(model_path, map_location='cpu')
        else:
            state_dict = model_zoo.load_url(model_urls['vgg16'])
        model.load_state_dict(state_dict)
    return model
Example #3
0
    def __init__(self, requires_grad=False, pretrained=True,
                 model_path='vgg16-397923af.pth'):
        """Build a VGG16 feature extractor whose layers are split into slices.

        Args:
            requires_grad (bool): If False, every parameter of this module is
                frozen (``requires_grad = False``) after construction.
            pretrained (bool): If True, load the state dict stored at
                ``model_path`` into the underlying VGG.
            model_path (str): State-dict file loaded when ``pretrained`` is
                True. Defaults to the file name this class always loaded,
                resolved relative to the current working directory.
        """
        super(vgg16, self).__init__()
        # Only disable VGG's init_weights when a checkpoint is about to
        # overwrite the weights; otherwise keep VGG's default behavior.
        vgg_kwargs = {'init_weights': False} if pretrained else {}
        vgg = VGG(make_layers([
            64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M',
            512, 512, 512, 'M'
        ]), **vgg_kwargs)
        if pretrained:
            # map_location='cpu' lets a checkpoint saved from GPU tensors be
            # deserialized on a CPU-only host; load_state_dict then copies the
            # values into the model's own tensors.
            vgg_weight = torch.load(model_path, map_location='cpu')
            vgg.load_state_dict(vgg_weight)
            del vgg_weight  # drop the checkpoint dict once copied into vgg

        # Per-channel constants viewed as (1, 3, 1, 1) so they broadcast over a
        # presumably (N, 3, H, W) input.
        # NOTE(review): they are created on CUDA (a GPU is required to build
        # this module) and are plain attributes rather than buffers, so
        # Module.to()/.cpu() will not move them.
        self.mean = torch.tensor([0.485, 0.456, 0.406]).cuda()
        self.mean = self.mean.view(1, 3, 1, 1)
        self.std = torch.tensor([0.229, 0.224, 0.225]).cuda()
        self.std = self.std.view(1, 3, 1, 1)

        vgg_pretrained_features = vgg.features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.N_slices = 5

        # Each of slice1..slice4 holds exactly one layer of vgg.features, and
        # slice5 stays empty (an empty Sequential returns its input unchanged).
        # NOTE(review): this file earlier carried a commented-out range split
        # (features[0:4], [4:9], [9:16], [16:23], [23:30] into slice1..slice5);
        # confirm the single-layer split is intentional before relying on
        # deeper features.
        self.slice1.add_module('0', vgg_pretrained_features[0])
        self.slice2.add_module('1', vgg_pretrained_features[1])
        self.slice3.add_module('2', vgg_pretrained_features[2])
        self.slice4.add_module('3', vgg_pretrained_features[3])

        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False