def load(self, args, strict=True):
    """Initialise or restore this VGG model's weights according to ``args``.

    CIFAR datasets (``'CIFAR'`` appears in ``args.data_train``):
      * ``args.pretrained == 'download'`` or ``args.extend == 'download'``
        downloads the clustering-kernels VGG16 checkpoint into ``../models``;
      * otherwise a truthy ``args.extend`` is treated as a local checkpoint
        path and loaded with ``torch.load``;
      * otherwise the weights are freshly initialised via ``common.init_vgg``.

    ``'ImageNet'``:
      * ``args.pretrained == 'download'`` downloads the torchvision VGG16
        weights (the ``vgg16_bn`` file when ``self.norm`` is not None,
        the plain ``vgg16`` file otherwise);
      * otherwise the weights are freshly initialised via ``common.init_vgg``.

    Args:
        args: option namespace; reads ``data_train``, ``pretrained`` and
            ``extend``.
        strict: forwarded to ``self.load_state_dict``.

    Raises:
        NotImplementedError: if ``args.data_train`` is neither a CIFAR
            dataset nor ``'ImageNet'``.
    """
    model_dir = os.path.join('..', 'models')
    if 'CIFAR' in args.data_train:
        if args.pretrained == 'download' or args.extend == 'download':
            url = ('https://cv.snu.ac.kr/'
                   'research/clustering_kernels/models/vgg16-89711a85.pt')
            # Create the cache directory only when something is downloaded.
            os.makedirs(model_dir, exist_ok=True)
            state = model_zoo.load_url(url, model_dir=model_dir)
        elif args.extend:
            # Map to CPU so GPU-saved checkpoints also load on CPU-only
            # hosts; load_state_dict copies the values into the parameters.
            state = torch.load(args.extend, map_location='cpu')
        else:
            common.init_vgg(self)
            return
    elif args.data_train == 'ImageNet':
        if args.pretrained == 'download':
            if self.norm is not None:
                url = 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth'
            else:
                url = 'https://download.pytorch.org/models/vgg16-397923af.pth'
            os.makedirs(model_dir, exist_ok=True)
            state = model_zoo.load_url(url, model_dir=model_dir)
        else:
            common.init_vgg(self)
            return
    else:
        # Unsupported dataset: fail before touching the file system.
        raise NotImplementedError('Unavailable dataset {}'.format(
            args.data_train))

    self.load_state_dict(state, strict=strict)
def __init__(self, args, conv3x3=common.default_conv, conv1x1=None):
    """Build a VGG network whose deeper conv layers use basis-decomposed blocks.

    Config entries with index <= 3 (``vgg_decom_type == 'all'``) or <= 8
    (otherwise) become ``common.BasicBlock``s. Later conv entries become
    ``BasicBlock(in, out, n_basis, basis_size, ...)``. ``'M'`` entries
    become 2x2 max-pooling layers.

    Args:
        args: a sequence whose first element is the option namespace. It
            reads ``no_bias``, ``n_basis``, ``basis_size``, ``n_colors``,
            ``vgg_type``, ``vgg_decom_type``, ``kernel_size``,
            ``data_train``, ``pretrained`` and ``extend``.
        conv3x3: convolution factory passed to every conv block. Weights
            are downloaded, loaded, or initialised via ``common.init_vgg``
            only when this is ``common.default_conv``.
        conv1x1: not used by this constructor.

    Raises:
        NotImplementedError: if ``args.data_train`` is neither a CIFAR
            dataset nor ``'ImageNet'``.
    """
    super(VGG_Basis, self).__init__()
    args = args[0]

    is_cifar = 'CIFAR' in args.data_train
    # Validate up front. Previously an unsupported dataset combined with a
    # non-default conv3x3 silently produced a model without a classifier.
    if not is_cifar and args.data_train != 'ImageNet':
        raise NotImplementedError('Unavailable dataset {}'.format(args.data_train))

    # we use batch normalization for VGG
    norm = common.default_norm
    bias = not args.no_bias
    n_basis = args.n_basis
    basis_size = args.basis_size
    configs = {
        'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
        'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M',
              512, 512, 'M'],
        '16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
               512, 512, 512, 'M', 512, 512, 512, 'M'],
        '19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M',
               512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
        'ef': [32, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
               256, 256, 256, 'M', 256, 256, 256, 'M'],
    }

    # Config indices up to this threshold stay plain conv blocks;
    # later conv entries use the basis-decomposed BasicBlock.
    plain_until = 3 if args.vgg_decom_type == 'all' else 8
    body_list = []
    in_channels = args.n_colors
    for i, v in enumerate(configs[args.vgg_type]):
        if v == 'M':
            body_list.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        if i <= plain_until:
            body_list.append(common.BasicBlock(
                in_channels, v, args.kernel_size,
                bias=bias, conv3x3=conv3x3, norm=norm))
        else:
            body_list.append(BasicBlock(
                in_channels, v, n_basis, basis_size, args.kernel_size,
                bias=bias, conv=conv3x3, norm=norm))
        in_channels = v

    self.features = nn.Sequential(*body_list)
    if is_cifar:
        # e.g. 'CIFAR10' -> 10, 'CIFAR100' -> 100
        n_classes = int(args.data_train[5:])
        self.classifier = nn.Linear(in_channels, n_classes)
    else:  # 'ImageNet' (validated above)
        n_classes = 1000
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, n_classes),
        )

    if conv3x3 == common.default_conv:
        model_dir = os.path.join('..', 'models')
        os.makedirs(model_dir, exist_ok=True)
        if is_cifar:
            if args.pretrained == 'download' or args.extend == 'download':
                url = (
                    'https://cv.snu.ac.kr/'
                    'research/clustering_kernels/models/vgg16-89711a85.pt'
                )
                state = model_zoo.load_url(url, model_dir=model_dir)
            elif args.extend:
                # Map to CPU so GPU-saved checkpoints also load on CPU-only hosts.
                state = torch.load(args.extend, map_location='cpu')
            else:
                common.init_vgg(self)
                return
        elif args.pretrained == 'download':
            url = 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth'
            state = model_zoo.load_url(url, model_dir=model_dir)
        else:
            common.init_vgg(self)
            return
        # Non-strict load; presumably because the decomposed layers' keys do
        # not all match the plain VGG checkpoints. TODO confirm.
        self.load_state_dict(state, False)
def __init__(self, args, conv3x3=common.default_conv, conv1x1=None):
    """Build a CIFAR VGG network whose deeper conv layers use grouped blocks.

    Config entries with index <= 3 (``vgg_decom_type == 'all'``) or <= 8
    (otherwise) become ``common.BasicBlock``s. Later conv entries become
    ``BasicBlock(in, out, group_size, ...)``. ``'M'`` entries become 2x2
    max-pooling layers. The classifier is a single ``nn.Linear``.

    Args:
        args: a sequence whose first element is the option namespace. It
            reads ``no_bias``, ``group_size``, ``n_colors``, ``vgg_type``,
            ``vgg_decom_type``, ``kernel_size``, ``data_train``,
            ``pretrained`` and ``extend``.
        conv3x3: convolution factory passed to every conv block. Weights
            are downloaded, loaded, or initialised via ``common.init_vgg``
            only when this is ``common.default_conv``.
        conv1x1: not used by this constructor.

    Raises:
        NotImplementedError: if ``args.data_train`` is not a CIFAR dataset.
    """
    super(VGG_GROUP, self).__init__()
    args = args[0]

    # for CIFAR10 and CIFAR100 only. Raise explicitly instead of asserting,
    # so the check survives ``python -O`` and fails before layers are built.
    if 'CIFAR' not in args.data_train:
        raise NotImplementedError('Unavailable dataset {}'.format(args.data_train))
    # e.g. 'CIFAR10' -> 10, 'CIFAR100' -> 100
    n_classes = int(args.data_train[5:])

    # we use batch normalization for VGG
    norm = common.default_norm
    bias = not args.no_bias
    group_size = args.group_size
    configs = {
        'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
        'B': [
            64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M',
            512, 512, 'M'
        ],
        '16': [
            64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
            512, 512, 512, 'M', 512, 512, 512, 'M'
        ],
        '19': [
            64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M',
            512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'
        ],
        'ef': [
            32, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
            256, 256, 256, 'M', 256, 256, 256, 'M'
        ],
    }

    # Config indices up to this threshold stay plain conv blocks;
    # later conv entries use the grouped BasicBlock.
    plain_until = 3 if args.vgg_decom_type == 'all' else 8
    body_list = []
    in_channels = args.n_colors
    for i, v in enumerate(configs[args.vgg_type]):
        if v == 'M':
            body_list.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        if i <= plain_until:
            body_list.append(
                common.BasicBlock(in_channels, v, args.kernel_size,
                                  bias=bias, conv3x3=conv3x3, norm=norm))
        else:
            body_list.append(
                BasicBlock(in_channels, v, group_size, args.kernel_size,
                           bias=bias, conv=conv3x3, norm=norm))
        in_channels = v

    self.features = nn.Sequential(*body_list)
    self.classifier = nn.Linear(in_channels, n_classes)

    if conv3x3 == common.default_conv:
        if args.pretrained == 'download' or args.extend == 'download':
            url = ('https://cv.snu.ac.kr/'
                   'research/clustering_kernels/models/vgg16-89711a85.pt')
            model_dir = os.path.join('..', 'models')
            os.makedirs(model_dir, exist_ok=True)
            state = torch.utils.model_zoo.load_url(url, model_dir=model_dir)
        elif args.extend:
            # Map to CPU so GPU-saved checkpoints also load on CPU-only hosts.
            state = torch.load(args.extend, map_location='cpu')
        else:
            common.init_vgg(self)
            return
        self.load_state_dict(state, strict=False)