def net_from_args(args, num_classes, IM_SIZE):
    """Build a network from parsed CLI arguments.

    :param args: namespace providing ``net_type`` and, depending on it,
        ``depth`` / ``widen_factor`` / ``dropout``
    :param num_classes: number of output classes
    :param IM_SIZE: input image side length
    :return: tuple ``(net, file_name)`` where ``file_name`` encodes the
        architecture for checkpoint naming
    :raises SystemExit: when ``args.net_type`` is not recognised
    """
    if args.net_type == 'PreActResNet18':
        net = PreActResNet18()
        file_name = 'PreActResNet18-'
    elif args.net_type == 'resnet':
        net = ResNet(args.depth, num_classes, IM_SIZE)
        file_name = 'resnet-' + str(args.depth)
    elif args.net_type == 'wide-resnet':
        net = Wide_ResNet(args.depth, args.widen_factor, args.dropout,
                          num_classes, IM_SIZE)
        file_name = 'wide-resnet-' + str(args.depth) + 'x' + str(args.widen_factor)
    elif args.net_type == 'SimpleNetMNIST':
        net = SimpleNetMNIST()
        file_name = 'SimpleNetMNIST'
    else:
        # BUG FIX: previously the error message was commented out and the
        # process exited with status 0 (success) on a wrong net type.
        # Report the problem and exit with a non-zero status instead.
        print('Error : Wrong net type')
        sys.exit(1)
    return net, file_name


# def load_net_cifar(model_loc):
#     """ Make a model
#     Network must be saved in the form model_name-depth, where this is a unique identifier
#     """
#     model_file = Path(model_loc).name
#     model_name = model_file.split('-')[0]
#     print('Loading model_file', model_file)
#     if (model_name == 'vggnet'):
#         model = VGG(int(model_file.split('-')[1]), 10)
#     elif (model_name == 'resnet'):
#         model = ResNet(int(model_file.split('-')[1]), 10)
#     # so ugly
#     elif (model_name == 'preact_resnet'):
#         if model_file.split('/')[-1].split('_')[2] == 'model':
#             model = PreActResNet(int(model_file.split('-')[1].split('_')[0]), 10)
#         else:
#             model = PResNetReg(int(model_file.split('-')[1]), float(model_file.split('-')[2]), 1, 10)
#     elif (model_name == 'wide'):
#         model = Wide_ResNet(model_file.split('-')[2][0:2], model_file.split('-')[2][2:4], 0, 10, 32)
#     # Dumb ones
#     elif (model_name == 'PResNetRegNoRelU'):
#         model = PResNetRegNoRelU(int(model_file.split('-')[1]), float(model_file.split('-')[2]), 1, 10)
#     else:
#         print(f'Error : {model_file} not found')
#         sys.exit(0)
#     model.load_state_dict(torch.load(model_loc)['state_dict'])
#     return model
def get_net(network: str, num_classes) -> torch.nn.Module:
    """Instantiate a CIFAR-style architecture by name.

    Returns ``None`` when *network* is not a recognised name (same contract
    as before: nothing is constructed unless its branch matches).
    """
    # Factories are lambdas so that a constructor name is only resolved
    # and invoked for the entry that is actually selected.
    factories = {
        'VGG16': lambda: VGG('VGG16', num_classes=num_classes),
        'ResNet34': lambda: ResNet34(num_classes=num_classes),
        'PreActResNet18': lambda: PreActResNet18(num_classes=num_classes),
        'GoogLeNet': lambda: GoogLeNet(num_classes=num_classes),
        'densenet_cifar': lambda: densenet_cifar(num_classes=num_classes),
        'ResNeXt29_2x64d': lambda: ResNeXt29_2x64d(num_classes=num_classes),
        'MobileNet': lambda: MobileNet(num_classes=num_classes),
        'MobileNetV2': lambda: MobileNetV2(num_classes=num_classes),
        'DPN92': lambda: DPN92(num_classes=num_classes),
        'ShuffleNetG2': lambda: ShuffleNetG2(num_classes=num_classes),
        'SENet18': lambda: SENet18(num_classes=num_classes),
        'ShuffleNetV2': lambda: ShuffleNetV2(1, num_classes=num_classes),
        'EfficientNetB0': lambda: EfficientNetB0(num_classes=num_classes),
    }
    factory = factories.get(network)
    return factory() if factory is not None else None
def _get_model(self, backbone):
    """Instantiate the requested backbone and move it to the configured device.

    Reads class count, image size, input channels and device from
    ``self.args``. Unknown backbone names fall back to a pretrained
    ResNet-50 (with a console warning).
    """
    args = self.args

    # torchvision-style ResNets: pretrained weights + class count only.
    plain_resnets = {
        'resnet18': resnet18,
        'resnet34': resnet34,
        'resnet50': resnet50,
        'resnet101': resnet101,
        'resnet152': resnet152,
    }
    # Pre-activation ResNets additionally take input geometry.
    preact_resnets = {
        'preact_resnet18': PreActResNet18,
        'preact_resnet34': PreActResNet34,
        'preact_resnet50': PreActResNet50,
        'preact_resnet101': PreActResNet101,
        'preact_resnet152': PreActResNet152,
    }
    densenets = {
        'densenet121': densenet121,
        'densenet161': densenet161,
        'densenet169': densenet169,
        'densenet201': densenet201,
    }

    if backbone in plain_resnets:
        model = plain_resnets[backbone](pretrained=True,
                                        num_classes=args.classnum)
    elif backbone in preact_resnets:
        model = preact_resnets[backbone](num_classes=args.classnum,
                                         input_size=args.image_size,
                                         input_dim=args.input_dim)
    elif backbone in densenets:
        model = densenets[backbone](num_classes=args.classnum,
                                    pretrained=True)
    elif backbone == 'mlp':
        model = MLPNet()
    elif backbone in ('cnn_small', 'CNN_SMALL'):
        model = CNN_small(args.classnum)
    elif backbone in ('cnn', 'CNN'):
        model = CNN(n_outputs=args.classnum,
                    input_channel=args.input_dim,
                    linear_num=args.linear_num)
    else:
        # NOTE(review): unlike the explicit 'resnet50' branch above, this
        # fallback also passes input_size -- confirm resnet50 accepts it.
        print("No matched backbone. Using ResNet50...")
        model = resnet50(pretrained=True, num_classes=args.classnum,
                         input_size=args.image_size)
    # Every branch previously ended in .to(device); hoisted here once.
    return model.to(args.device)
def select(self, model, args):
    """
    Selector utility to create models from model directory

    :param model: which model to select. Currently choices are:
        (cnn | resnet | preact_resnet | densenet | wresnet)
    :param args: namespace with ``dataset``, ``resdepth``, ``widen_factor``
        and the cnn hyper-parameters (activation, filters, strides, ...)
    :return: neural network to be trained
    :raises NotImplementedError: for an unknown model name
    """
    if model == 'cnn':
        net = SimpleModel(in_shape=self.in_shape,
                          activation=args.activation,
                          num_classes=self.num_classes,
                          filters=args.filters,
                          strides=args.strides,
                          kernel_sizes=args.kernel_sizes,
                          linear_widths=args.linear_widths,
                          use_batch_norm=args.use_batch_norm)
    else:
        assert (args.dataset != 'MNIST' and args.dataset != 'Fashion-MNIST'), \
            "Cannot use resnet or densenet for mnist style data"
        if model == 'resnet':
            assert args.resdepth in [18, 34, 50, 101, 152], \
                "Non-standard and unsupported resnet depth ({})".format(args.resdepth)
            if args.resdepth == 18:
                net = ResNet18(self.num_classes)
            elif args.resdepth == 34:
                net = ResNet34(self.num_classes)
            elif args.resdepth == 50:
                net = ResNet50(self.num_classes)
            elif args.resdepth == 101:
                net = ResNet101(self.num_classes)
            else:
                # BUG FIX: ResNet152 previously received no class count and
                # silently fell back to its default, unlike every other depth.
                net = ResNet152(self.num_classes)
        elif model == 'densenet':
            assert args.resdepth in [121, 161, 169, 201], \
                "Non-standard and unsupported densenet depth ({})".format(args.resdepth)
            if args.resdepth == 121:
                net = DenseNet121(
                    growth_rate=12, num_classes=self.num_classes
                )  # NB NOTE: growth rate controls cifar implementation
            elif args.resdepth == 161:
                net = DenseNet161(growth_rate=12, num_classes=self.num_classes)
            elif args.resdepth == 169:
                net = DenseNet169(growth_rate=12, num_classes=self.num_classes)
            else:
                net = DenseNet201(growth_rate=12, num_classes=self.num_classes)
        elif model == 'preact_resnet':
            assert args.resdepth in [18, 34, 50, 101, 152], \
                "Non-standard and unsupported preact resnet depth ({})".format(args.resdepth)
            if args.resdepth == 18:
                net = PreActResNet18(self.num_classes)
            elif args.resdepth == 34:
                net = PreActResNet34(self.num_classes)
            elif args.resdepth == 50:
                net = PreActResNet50(self.num_classes)
            elif args.resdepth == 101:
                net = PreActResNet101(self.num_classes)
            else:
                # BUG FIX: PreActResNet152 previously received no class count,
                # unlike every other depth in this chain.
                net = PreActResNet152(self.num_classes)
        elif model == 'wresnet':
            assert ((args.resdepth - 4) % 6 == 0), \
                "Wideresnet depth of {} not supported, must fulfill: (depth - 4) % 6 = 0".format(args.resdepth)
            net = WideResNet(depth=args.resdepth,
                             num_classes=self.num_classes,
                             widen_factor=args.widen_factor)
        else:
            raise NotImplementedError('Model {} not supported'.format(model))
    return net
# Embedding width produced by the backbone: the MNIST ConvNet emits 500-d
# features, the PreActResNet18 path emits 512-d.
EMBEDDING_SIZE = 500 if dataset == 'mnist' else 512


def experiment_id(dataset, k, tau, nloglr, method):
    # Unique run identifier used for checkpoint / log naming.
    return 'baseline-resnet-%s-%s-k%d-t%d-b%d' % (dataset, method, k, tau, nloglr)


# tau is multiplied by 10 so it can be embedded in the id as an integer.
e_id = experiment_id(dataset, k, tau * 10, args.nloglr, method)
gpu = torch.device('cuda')
if dataset == 'mnist':
    h_phi = ConvNet().to(gpu)
else:
    # assumes non-mnist data is cifar10 (3 channels) or another
    # single-channel dataset -- TODO confirm against the dataset list
    h_phi = PreActResNet18(
        num_channels=3 if dataset == 'cifar10' else 1).to(gpu)
optimizer = torch.optim.SGD(h_phi.parameters(),
                            lr=LEARNING_RATE,
                            momentum=0.9,
                            weight_decay=5e-4)
# Linear classification head on top of the embedding (10 classes).
linear_layer = torch.nn.Linear(EMBEDDING_SIZE, 10).to(device=gpu)
ce_loss = torch.nn.CrossEntropyLoss()
batched_train = split.get_train_loader(NUM_TRAIN_QUERIES)


def train(epoch):
    # NOTE(review): this definition continues beyond the visible chunk.
    h_phi.train()
    to_average = []
def select(self, model, path_fc=False, upsample='pixel'):
    """Build the requested architecture from instance configuration.

    :param model: one of (cnn | resnet | densenet | preact_resnet | wresnet)
    :param path_fc: forwarded to PreActResNet10 only
    :param upsample: forwarded to PreActResNet10 only
    :return: neural network to be trained
    :raises NotImplementedError: for an unknown model name
    """
    if model == 'cnn':
        net = SimpleModel(
            in_shape=self.in_shape,
            activation=self.activation,
            num_classes=self.num_classes,
            filters=self.filters,
        )
    else:
        assert (self.dataset != 'MNIST' and self.dataset != 'Fashion-MNIST'
                ), "Cannot use resnet or densenet for mnist style data"
        if model == 'resnet':
            assert self.resdepth in [
                18, 34, 50, 101, 152
            ], "Non-standard and unsupported resnet depth ({})".format(
                self.resdepth)
            if self.resdepth == 18:
                net = ResNet18()
            elif self.resdepth == 34:
                net = ResNet34()
            elif self.resdepth == 50:
                net = ResNet50()
            elif self.resdepth == 101:
                net = ResNet101()
            else:
                net = ResNet152()
        elif model == 'densenet':
            assert self.resdepth in [
                121, 161, 169, 201
            ], "Non-standard and unsupported densenet depth ({})".format(
                self.resdepth)
            if self.resdepth == 121:
                net = DenseNet121()
            elif self.resdepth == 161:
                net = DenseNet161()
            elif self.resdepth == 169:
                net = DenseNet169()
            else:
                net = DenseNet201()
        elif model == 'preact_resnet':
            assert self.resdepth in [
                10, 18, 34, 50, 101, 152
            ], "Non-standard and unsupported preact resnet depth ({})".format(
                self.resdepth)
            if self.resdepth == 10:
                net = PreActResNet10(path_fc=path_fc,
                                     num_classes=self.num_classes,
                                     upsample=upsample)
            elif self.resdepth == 18:
                net = PreActResNet18()
            elif self.resdepth == 34:
                net = PreActResNet34()
            elif self.resdepth == 50:
                net = PreActResNet50()
            elif self.resdepth == 101:
                net = PreActResNet101()
            else:
                net = PreActResNet152()
        elif model == 'wresnet':
            assert (
                (self.resdepth - 4) % 6 == 0
            ), "Wideresnet depth of {} not supported, must fulfill: (depth - 4) % 6 = 0".format(
                self.resdepth)
            net = WideResNet(depth=self.resdepth,
                             num_classes=self.num_classes,
                             widen_factor=self.widen_factor)
        else:
            # BUG FIX: an unknown model previously fell through every branch
            # and crashed with UnboundLocalError at `return net`. Fail
            # explicitly instead (mirrors the sibling selector in this file).
            raise NotImplementedError('Model {} not supported'.format(model))
    return net
def get_model(model_name, dataset_name):
    """Resolve (model_name, dataset_name) to an instantiated network.

    :param model_name: either the special "jocor_model" or a prefix-matched
        architecture name (e.g. "resnet18", "vgg16", ...)
    :param dataset_name: one of mnist | cifar10 | cifar100 | tiny-imagenet
        | clothing1m; determines channel count and class count
    :return: the constructed network (wrapped in DataParallel for clothing1m)
    :raises NameError: on an unknown dataset or model name
    """
    if dataset_name == "mnist":
        grayscale = True
        num_classes = 10
    elif dataset_name == "cifar10":
        grayscale = False
        num_classes = 10
    elif dataset_name == "cifar100":
        grayscale = False
        num_classes = 100
    elif dataset_name == "tiny-imagenet":
        grayscale = False
        num_classes = 200
    elif dataset_name == "clothing1m":
        grayscale = False
        num_classes = 14
    else:
        raise NameError("Invalid dataset")

    if model_name == "jocor_model":
        # jocor uses a dataset-specific architecture.
        if dataset_name == "mnist":
            net = MLPNet()
        elif dataset_name == "cifar10":
            net = CNN(n_outputs=10)
        elif dataset_name == "cifar100":
            net = CNN(n_outputs=100)
        elif dataset_name == "tiny-imagenet":
            net = PreActResNet18(num_classes)
        elif dataset_name == "clothing1m":
            model = getattr(resnet, "resnet18")
            net = model(grayscale, num_classes)
            net = torch.nn.DataParallel(net, device_ids=[0, 1, 2, 3])
        return net
    # Prefix dispatch; order matters (e.g. "inceptionv3" before "inception",
    # "mobilenetv2" before "mobilenet", "shufflenetv2" before "shufflenet").
    elif model_name.startswith("attention"):
        model = getattr(attention, model_name)
    elif model_name.startswith("densenet"):
        model = getattr(densenet, model_name)
    elif model_name.startswith("googlenet"):
        model = getattr(googlenet, model_name)
    elif model_name.startswith("inceptionv3"):
        model = getattr(inceptionv3, model_name)
    elif model_name.startswith("inception"):
        model = getattr(inceptionv4, model_name)
    elif model_name.startswith("mobilenetv2"):
        model = getattr(mobilenetv2, model_name)
    elif model_name.startswith("mobilenet"):
        model = getattr(mobilenet, model_name)
    elif model_name.startswith("nasnet"):
        model = getattr(nasnet, model_name)
    elif model_name.startswith("preactresnet"):
        model = getattr(preactresnet, model_name)
    elif model_name.startswith("resnet"):
        model = getattr(resnet, model_name)
    elif model_name.startswith("resnext"):
        model = getattr(resnext, model_name)
    elif model_name.startswith("rir"):
        model = getattr(rir, model_name)
    elif model_name.startswith("seresnet"):
        model = getattr(senet, model_name)
    elif model_name.startswith("shufflenetv2"):
        model = getattr(shufflenetv2, model_name)
    elif model_name.startswith("shufflenet"):
        model = getattr(shufflenet, model_name)
    elif model_name.startswith("squeezenet"):
        model = getattr(squeezenet, model_name)
    elif model_name.startswith("vgg"):
        model = getattr(vgg, model_name)
    elif model_name.startswith("xception"):
        model = getattr(xception, model_name)
    else:
        raise NameError("Invalid model")

    net = model(grayscale, num_classes)
    if dataset_name == "clothing1m":
        net = torch.nn.DataParallel(net, device_ids=[0, 1, 2, 3])
    # BUG FIX: previously returned a fresh `model(grayscale, num_classes)`,
    # constructing the network a second time and discarding the
    # DataParallel wrapper built above for clothing1m.
    return net
def main():
    """Entry point: parse CLI arguments, build dataloaders, network, memory
    bank and losses, then run the invariance-propagation training loop,
    checkpointing and kNN-validating every epoch.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--stage', default='train', type=str)
    parser.add_argument('--gpus', default='0,1,2,3', type=str)
    parser.add_argument('--max_epoch', default=200, type=int)
    parser.add_argument('--lr_decay_steps', default='160,190,200', type=str)
    parser.add_argument('--exp', default='', type=str)
    parser.add_argument('--res_path', default='', type=str)
    parser.add_argument('--resume_path', default='', type=str)
    parser.add_argument('--pretrain_path', default='', type=str)
    parser.add_argument('--dataset', default='imagenet', type=str)
    parser.add_argument('--lr', default=0.03, type=float)
    parser.add_argument('--lr_decay_rate', default=0.1, type=float)
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--weight_decay', default=5e-4, type=float)
    parser.add_argument('--n_workers', default=32, type=int)
    parser.add_argument('--n_background', default=4096, type=int)
    parser.add_argument('--t', default=0.07, type=float)
    parser.add_argument('--m', default=0.5, type=float)
    parser.add_argument('--dropout', action='store_true')
    parser.add_argument('--blur', action='store_true')
    parser.add_argument('--cos', action='store_true')
    parser.add_argument('--network', default='resnet18', type=str)
    parser.add_argument('--mix', action='store_true')
    parser.add_argument('--not_hardpos', action='store_true')
    parser.add_argument('--InvP', type=int, default=1)
    parser.add_argument('--ramp_up', default='binary', type=str)
    parser.add_argument('--lam_inv', default=0.6, type=float)
    parser.add_argument('--lam_mix', default=1.0, type=float)
    # for cifar10 the best diffusion_layer is 3 and cifar100 is 4;
    # for imagenet only diffusion_layer = 3 has been tested
    parser.add_argument('--diffusion_layer', default=3, type=int)
    parser.add_argument('--K_nearst', default=4, type=int)
    # for cifar10 the best n_pos is 20, for cifar100 the best is 10 or 20
    parser.add_argument('--n_pos', default=50, type=int)
    # exclusive best to be 0
    parser.add_argument('--exclusive', default=1, type=int)
    parser.add_argument('--nonlinearhead', default=0, type=int)

    global args
    args = parser.parse_args()

    exp_identifier = get_expidentifier([
        'mix', 'network', 'lam_inv', 'lam_mix', 'diffusion_layer',
        'K_nearst', 'n_pos', 'exclusive', 'max_epoch', 'ramp_up',
        'nonlinearhead', 't', 'weight_decay'
    ], args)
    if not args.InvP:
        exp_identifier = 'hard'
    args.exp = os.path.join(args.exp, exp_identifier)

    # exist_ok avoids the check-then-create race of the old
    # `if not exists: makedirs` pattern
    os.makedirs(args.exp, exist_ok=True)
    for sub in ('runs', 'models', 'logs'):
        os.makedirs(os.path.join(args.exp, sub), exist_ok=True)

    logger = getLogger(args.exp)

    device_ids = [int(g) for g in args.gpus.split(',')]
    # BUG FIX: the original string was 'cuda: 0' -- the embedded space makes
    # it an invalid torch device string
    device = torch.device('cuda:0')

    if args.dataset.startswith('cifar'):
        train_loader, val_loader, train_ordered_labels, train_dataset, val_dataset = cifar.get_dataloader(args)
    elif args.dataset.startswith('imagenet'):
        train_loader, val_loader, train_ordered_labels, train_dataset, val_dataset = imagenet.get_instance_dataloader(args)
    elif args.dataset == 'svhn':
        train_loader, val_loader, train_ordered_labels, train_dataset, val_dataset = svhn.get_dataloader(args)

    # create model (CONSISTENCY FIX: the first branch was a bare `if` while
    # the rest were `elif`; same selection behavior, one coherent chain)
    if args.network == 'alexnet':
        network = alexnet(128)
    elif args.network == 'alexnet_cifar':
        network = AlexNet_cifar(128)
    elif args.network == 'resnet18_cifar':
        network = ResNet18_cifar(128,
                                 dropout=args.dropout,
                                 non_linear_head=args.nonlinearhead)
    elif args.network == 'resnet50_cifar':
        network = ResNet50_cifar(128, dropout=args.dropout)
    elif args.network == 'wide_resnet28':
        network = WideResNetInstance(28, 2)
    elif args.network == 'resnet18':
        network = resnet18(non_linear_head=args.nonlinearhead)
    elif args.network == 'pre-resnet18':
        network = PreActResNet18(128)
    elif args.network == 'resnet50':
        network = resnet50(non_linear_head=args.nonlinearhead)
    elif args.network == 'pre-resnet50':
        network = PreActResNet50(128)
    network = nn.DataParallel(network, device_ids=device_ids)
    network.to(device)

    # create optimizer; pre-act resnets exclude BN params from weight decay
    if args.network == 'pre-resnet18' or args.network == 'pre-resnet50':
        logging.info(
            colorful(
                'Exclude bns from weight decay, copied from LocalAggregation proposed by Zhuang et al [ICCV 2019]'
            ))
        parameters = exclude_bn_weight_bias_from_weight_decay(
            network, weight_decay=args.weight_decay)
    else:
        parameters = network.parameters()
    optimizer = torch.optim.SGD(
        parameters,
        lr=args.lr,
        momentum=0.9,
        weight_decay=args.weight_decay,
    )
    cudnn.benchmark = True

    # create memory_bank
    global writer
    writer = SummaryWriter(comment='InvariancePropagation',
                           logdir=os.path.join(args.exp, 'runs'))
    memory_bank = objective.MemoryBank_v1(len(train_dataset),
                                          train_ordered_labels,
                                          writer,
                                          device,
                                          m=args.m)

    # create criterion
    criterionA = objective.InvariancePropagationLoss(
        args.t,
        diffusion_layer=args.diffusion_layer,
        k=args.K_nearst,
        n_pos=args.n_pos,
        exclusive=args.exclusive,
        InvP=args.InvP,
        hard_pos=(not args.not_hardpos))
    criterionB = objective.MixPointLoss(args.t)

    if args.ramp_up == 'binary':
        ramp_up = lambda i_epoch: objective.BinaryRampUp(i_epoch, 30)
    elif args.ramp_up == 'gaussian':
        ramp_up = lambda i_epoch: objective.GaussianRampUp(i_epoch, 30, 5)
    elif args.ramp_up == 'zero':
        ramp_up = lambda i_epoch: 1

    logging.info(beautify(args))

    start_epoch = 0
    if args.pretrain_path != '' and args.pretrain_path != 'none':
        logging.info('loading pretrained file from {}'.format(
            args.pretrain_path))
        checkpoint = torch.load(args.pretrain_path)
        network.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        _memory_bank = checkpoint['memory_banks']
        try:
            memory_bank.neigh = checkpoint['neigh']
        except KeyError:
            # ROBUSTNESS FIX: was a bare `except:` that swallowed every
            # error; only a missing 'neigh' key is expected and recoverable
            logging.info(
                colorful(
                    'The Pretrained Path has No NEIGH and require a epoch to re-calculate'
                ))
        memory_bank.points = _memory_bank
        start_epoch = checkpoint['epoch']
    else:
        initialize_memorybank(network, train_loader, device, memory_bank)

    logging.info('start training')
    best_acc = 0.0
    # ROBUSTNESS FIX: initialised so the KeyboardInterrupt handler below can
    # always report an epoch, even when interrupted before the first iteration
    i_epoch = start_epoch
    try:
        for i_epoch in range(start_epoch, args.max_epoch):
            adjust_learning_rate(args.lr,
                                 args.lr_decay_steps,
                                 optimizer,
                                 i_epoch,
                                 lr_decay_rate=args.lr_decay_rate,
                                 cos=args.cos,
                                 max_epoch=args.max_epoch)
            train(i_epoch, network, criterionA, criterionB, optimizer,
                  train_loader, device, memory_bank, ramp_up)

            save_name = 'checkpoint.pth'
            checkpoint = {
                'epoch': i_epoch + 1,
                'state_dict': network.state_dict(),
                'optimizer': optimizer.state_dict(),
                'memory_banks': memory_bank.points,
                'neigh': memory_bank.neigh,
            }
            torch.save(checkpoint, os.path.join(args.exp, 'models', save_name))

            # scheduler.step()
            # validate(network, memory_bank, val_loader, train_ordered_labels, device)
            acc = kNN(i_epoch,
                      network,
                      memory_bank,
                      val_loader,
                      train_ordered_labels,
                      K=200,
                      sigma=0.07)
            if acc >= best_acc:
                best_acc = acc
                torch.save(checkpoint,
                           os.path.join(args.exp, 'models', 'best.pth'))
            if i_epoch in [30, 60, 120, 160, 200, 400, 600]:
                torch.save(
                    checkpoint,
                    os.path.join(args.exp, 'models',
                                 '{}.pth'.format(i_epoch + 1)))
            args.y_best_acc = best_acc
            logging.info(
                colorful('[Epoch: {}] val acc: {:.4f}'.format(i_epoch, acc)))
            logging.info(
                colorful('[Epoch: {}] best acc: {:.4f}'.format(
                    i_epoch, best_acc)))
            writer.add_scalar('acc', acc, i_epoch + 1)

            with torch.no_grad():
                for name, param in network.named_parameters():
                    if 'bn' not in name:
                        writer.add_histogram(name, param, i_epoch)
            # cluster
    except KeyboardInterrupt as e:
        logging.info('KeyboardInterrupt at {} Epochs'.format(i_epoch))
        save_result(args)
        exit()
    save_result(args)
# labels = labels.type_as(torch.LongTensor()).view(-1) - 1 images = Variable(images, requires_grad=False).cuda() labels = Variable(labels, requires_grad=False).cuda() pred, _ = cnn(images) test_loss += loss_func(pred, labels).data[0] pred = torch.max(pred.data, 1)[1] total += labels.size(0) correct += (pred == labels.data).sum() val_acc = correct / total val_loss = test_loss / total cnn.train() return val_acc, val_loss if args.model == 'resnet18': cnn = PreActResNet18(channels=num_channels, num_classes=num_classes) elif args.model == 'resnet34': cnn = PreActResNet34(channels=num_channels, num_classes=num_classes) elif args.model == 'resnet50': cnn = PreActResNet50(channels=num_channels, num_classes=num_classes) elif args.model == 'resnet101': cnn = PreActResNet101(channels=num_channels, num_classes=num_classes) elif args.model == 'resnet152': cnn = PreActResNet152(channels=num_channels, num_classes=num_classes) elif args.model == 'vgg': cnn = VGG(depth=16, num_classes=num_classes, channels=num_channels) elif args.model == 'wideresnet': if args.dataset == 'svhn': cnn = Wide_ResNet(depth=16, num_classes=num_classes, widen_factor=8, dropout_rate=args.dropout_rate) else: cnn = Wide_ResNet(depth=28, num_classes=num_classes, widen_factor=10, dropout_rate=args.dropout_rate)