Example 1
    def load(self, args, strict=True):
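        """Load pretrained VGG-16 weights, either downloaded from a model zoo
        URL or read from a local checkpoint given via args.extend; otherwise
        initialize the module with common.init_vgg."""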
        model_dir = os.path.join('..', 'models')
        os.makedirs(model_dir, exist_ok=True)
        if args.data_train.find('CIFAR') >= 0:
            if args.pretrained == 'download' or args.extend == 'download':
                url = ('https://cv.snu.ac.kr/'
                       'research/clustering_kernels/models/vgg16-89711a85.pt')

                state = model_zoo.load_url(url, model_dir=model_dir)
            elif args.extend:
                state = torch.load(args.extend)
            else:
                common.init_vgg(self)
                return
        elif args.data_train == 'ImageNet':
            if args.pretrained == 'download':
                # pick the batch-normalized VGG-16 checkpoint when norm layers are present
                if self.norm is not None:
                    url = 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth'
                else:
                    url = 'https://download.pytorch.org/models/vgg16-397923af.pth'
                state = model_zoo.load_url(url, model_dir=model_dir)
            else:
                common.init_vgg(self)
                return
        else:
            raise NotImplementedError('Unavailable dataset {}'.format(
                args.data_train))
        self.load_state_dict(state, strict=strict)
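A minimal sketch of driving this load() method, assuming the enclosing module is the repo's VGG wrapper; only the attribute names are taken from the code above, while the class name VGG and the option values are illustrative assumptions:

# Hypothetical driver for load(); attribute names mirror what the method reads.
from types import SimpleNamespace

args = SimpleNamespace(
    data_train='CIFAR10',    # selects the CIFAR branch
    pretrained='download',   # fetch vgg16-89711a85.pt into ../models
    extend='',               # no local checkpoint to extend from
)

model = VGG(args)            # assumed wrapper class that defines load()
model.load(args, strict=True)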
Example 2
    def __init__(self, args, conv3x3=common.default_conv, conv1x1=None):
        super(VGG_Basis, self).__init__()

        # we use batch normalization for VGG
        args = args[0]
        norm = common.default_norm
        bias = not args.no_bias
        n_basis = args.n_basis
        basis_size = args.basis_size

        configs = {
            'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
            'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
            '16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
            '19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
            'ef': [32, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 256, 256, 256, 'M', 256, 256, 256, 'M']
        }

        body_list = []
        in_channels = args.n_colors
        for i, v in enumerate(configs[args.vgg_type]):
            if v == 'M':
                body_list.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                t = 3 if args.vgg_decom_type == 'all' else 8
                if i <= t:
                    body_list.append(common.BasicBlock(in_channels, v, args.kernel_size, bias=bias,
                                                       conv3x3=conv3x3, norm=norm))
                else:
                    body_list.append(BasicBlock(in_channels, v, n_basis, basis_size, args.kernel_size, bias=bias,
                                                conv=conv3x3, norm=norm))
                in_channels = v

        # assert(args.data_train.find('CIFAR') >= 0)
        self.features = nn.Sequential(*body_list)
        if args.data_train.find('CIFAR') >= 0:
            n_classes = int(args.data_train[5:])
            self.classifier = nn.Linear(in_channels, n_classes)
        elif args.data_train == 'ImageNet':
            n_classes = 1000
            self.classifier = nn.Sequential(
                nn.Linear(512 * 7 * 7, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, n_classes),
            )

        if conv3x3 == common.default_conv:
            model_dir = os.path.join('..', 'models')
            os.makedirs(model_dir, exist_ok=True)
            if args.data_train.find('CIFAR') >= 0:
                if args.pretrained == 'download' or args.extend == 'download':
                    url = (
                        'https://cv.snu.ac.kr/'
                        'research/clustering_kernels/models/vgg16-89711a85.pt'
                    )

                    state = model_zoo.load_url(url, model_dir=model_dir)
                elif args.extend:
                    state = torch.load(args.extend)
                else:
                    common.init_vgg(self)
                    return
            elif args.data_train == 'ImageNet':
                if args.pretrained == 'download':
                    url = 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth'
                    state = model_zoo.load_url(url, model_dir=model_dir)
                else:
                    common.init_vgg(self)
                    return
            else:
                raise NotImplementedError('Unavailable dataset {}'.format(args.data_train))
            self.load_state_dict(state, strict=False)
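For reference, a hedged sketch of instantiating VGG_Basis directly; only the attribute names come from the constructor above, and the concrete values are assumptions:

# Hypothetical VGG_Basis instantiation; values are illustrative only.
from types import SimpleNamespace

args = SimpleNamespace(
    no_bias=False,           # keep conv biases
    n_basis=24,              # number of shared basis filters (assumed)
    basis_size=64,           # channels covered by each basis filter (assumed)
    kernel_size=3,
    n_colors=3,              # RGB input
    vgg_type='16',           # VGG-16 layer configuration
    vgg_decom_type='all',    # conv layers after config index 3 use the basis block
    data_train='CIFAR10',    # 10-way linear classifier on 512 features
    pretrained='',           # no download: weights come from common.init_vgg
    extend='',
)

model = VGG_Basis([args])    # the constructor unpacks args[0]

The next constructor, VGG_GROUP, follows the same layout but swaps the basis parameters for a single group_size.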
    def __init__(self, args, conv3x3=common.default_conv, conv1x1=None):
        super(VGG_GROUP, self).__init__()
        args = args[0]
        # we use batch normalization for VGG
        norm = common.default_norm
        bias = not args.no_bias
        group_size = args.group_size

        configs = {
            'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
            'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
            '16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
            '19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
            'ef': [32, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 256, 256, 256, 'M', 256, 256, 256, 'M']
        }

        body_list = []
        in_channels = args.n_colors
        for i, v in enumerate(configs[args.vgg_type]):
            if v == 'M':
                body_list.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                t = 3 if args.vgg_decom_type == 'all' else 8
                if i <= t:
                    body_list.append(
                        common.BasicBlock(in_channels,
                                          v,
                                          args.kernel_size,
                                          bias=bias,
                                          conv3x3=conv3x3,
                                          norm=norm))
                else:
                    body_list.append(
                        BasicBlock(in_channels,
                                   v,
                                   group_size,
                                   args.kernel_size,
                                   bias=bias,
                                   conv=conv3x3,
                                   norm=norm))
                in_channels = v

        # for CIFAR10 and CIFAR100 only
        assert args.data_train.find('CIFAR') >= 0
        n_classes = int(args.data_train[5:])

        self.features = nn.Sequential(*body_list)
        self.classifier = nn.Linear(in_channels, n_classes)

        if conv3x3 == common.default_conv:
            if args.pretrained == 'download' or args.extend == 'download':
                url = ('https://cv.snu.ac.kr/'
                       'research/clustering_kernels/models/vgg16-89711a85.pt')
                model_dir = os.path.join('..', 'models')
                os.makedirs(model_dir, exist_ok=True)
                state = torch.utils.model_zoo.load_url(url,
                                                       model_dir=model_dir)
            elif args.extend:
                state = torch.load(args.extend)
            else:
                common.init_vgg(self)
                return

            self.load_state_dict(state, strict=False)
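Finally, a similar sketch for VGG_GROUP, which the assert above restricts to CIFAR training sets; the values are again assumptions:

# Hypothetical VGG_GROUP instantiation for CIFAR-100.
from types import SimpleNamespace

args = SimpleNamespace(
    no_bias=False,
    group_size=16,           # channels per group in the grouped block (assumed)
    kernel_size=3,
    n_colors=3,
    vgg_type='16',
    vgg_decom_type='all',
    data_train='CIFAR100',   # parsed into n_classes = 100
    pretrained='',           # random init via common.init_vgg
    extend='',
)

model = VGG_GROUP([args])
print(sum(p.numel() for p in model.parameters()))  # quick parameter count check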