def main2(args):
    best_prec1 = 0.0
    torch.backends.cudnn.deterministic = not args.cudaNoise
    torch.manual_seed(time.time())

    if args.init != "None":
        args.name = "lrnet_%s" % args.init
    if args.tensorboard:
        configure(f"runs/{args.name}")

    dstype = nondigits(args.dataset)
    if dstype == "cifar":
        means = [125.3, 123.0, 113.9]
        stds = [63.0, 62.1, 66.7]
    elif dstype == "imgnet":
        means = [123.3, 118.1, 108.0]
        stds = [54.1, 52.6, 53.2]
    else:
        # CINIC-10 builds its own Normalize below; fall back to the CIFAR
        # statistics here so `normalize` is always defined.
        means = [125.3, 123.0, 113.9]
        stds = [63.0, 62.1, 66.7]
    normalize = transforms.Normalize(
        mean=[x / 255.0 for x in means],
        std=[x / 255.0 for x in stds],
    )

    writer = SummaryWriter(log_dir="runs/%s" % args.name, comment=str(args))
    args.classes = onlydigits(args.dataset)

    if args.augment:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: F.pad(
                x.unsqueeze(0), (4, 4, 4, 4), mode="reflect").squeeze()),
            transforms.ToPILImage(),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        transform_train = transforms.Compose([transforms.ToTensor(), normalize])
    if args.cutout:
        transform_train.transforms.append(
            Cutout(n_holes=args.n_holes, length=args.length))
    transform_test = transforms.Compose([transforms.ToTensor(), normalize])

    kwargs = {"num_workers": 1, "pin_memory": True}
    assert dstype in ["cifar", "cinic", "imgnet"]
    if dstype == "cifar":
        train_loader = torch.utils.data.DataLoader(
            datasets.__dict__[args.dataset.upper()](
                "../data", train=True, download=True, transform=transform_train),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs,
        )
        val_loader = torch.utils.data.DataLoader(
            datasets.__dict__[args.dataset.upper()](
                "../data", train=False, transform=transform_test),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs,
        )
    elif dstype == "cinic":
        cinic_directory = "%s/cinic10" % args.dir
        cinic_mean = [0.47889522, 0.47227842, 0.43047404]
        cinic_std = [0.24205776, 0.23828046, 0.25874835]
        cinic_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=cinic_mean, std=cinic_std),
        ])
        train_loader = torch.utils.data.DataLoader(
            torchvision.datasets.ImageFolder(cinic_directory + '/train',
                                             transform=cinic_transform),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs,
        )
        print("Using CINIC10 dataset")
        val_loader = torch.utils.data.DataLoader(
            torchvision.datasets.ImageFolder(cinic_directory + '/valid',
                                             transform=cinic_transform),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs,
        )
    elif dstype == "imgnet":
        print("Using converted imagenet")
        train_loader = torch.utils.data.DataLoader(
            IMGNET(args.dir, train=True, transform=transform_train,
                   target_transform=None, classes=args.classes),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs,
        )
        val_loader = torch.utils.data.DataLoader(
            IMGNET(args.dir, train=False, transform=transform_test,
                   target_transform=None, classes=args.classes),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs,
        )
    else:
        print("Error matching dataset %s" % dstype)

    if args.prune:
        pruner_state = getPruneMask(args)
        if pruner_state is None:
            print("Failed to prune network, aborting")
            return None

    if args.arch.lower() == "constnet":
        model = WideResNet(
            args.layers,
            args.classes,
            args.widen_factor,
            droprate=args.droprate,
            use_bn=args.batchnorm,
            use_fixup=args.fixup,
            varnet=args.varnet,
            noise=args.noise,
            lrelu=args.lrelu,
            sigmaW=args.sigmaW,
            init=args.init,
            dropl1=args.dropl1,
        )
    elif args.arch.lower() == "leakynet":
        model = LRNet(
            args.layers,
            args.classes,
            args.widen_factor,
            droprate=args.droprate,
            use_bn=args.batchnorm,
            use_fixup=args.fixup,
            varnet=args.varnet,
            noise=args.noise,
            lrelu=args.lrelu,
            sigmaW=args.sigmaW,
            init=args.init,
        )
    else:
        print("arch %s is not supported" % args.arch)
        return None

    param_num = sum(p.data.nelement() for p in model.parameters())
    print(f"Number of model parameters: {param_num}")

    if torch.cuda.device_count() > 1:
        # args.device is expected to look like "0-3": first and last GPU index.
        start = int(args.device[0])
        end = int(args.device[2]) + 1
        torch.cuda.set_device(start)
        dev_list = ["cuda:%d" % i for i in range(start, end)]
        model = torch.nn.DataParallel(model, device_ids=dev_list)
    model = model.cuda()

    if args.freeze > 0:
        # Freeze the first `args.freeze` blocks (counted via their 'scale'
        # parameters), starting at block `args.freeze_start`.
        cnt = 0
        for name, param in model.named_parameters():
            if intersection(['scale'], name.split('.')):
                cnt = cnt + 1
                if cnt == args.freeze:
                    break
            if cnt >= args.freeze_start:
                if not intersection(['conv_res', 'fc'], name.split('.')):
                    param.requires_grad = False
                    print("Freezing Block: %s" % name)
    elif args.freeze < 0:
        # A negative value freezes the last |args.freeze| blocks instead.
        cnt = 0
        for name, param in model.named_parameters():
            if intersection(['scale'], name.split('.')):
                cnt = cnt + 1
            if cnt > args.layers - 3 + args.freeze - 1:
                if not intersection(['conv_res', 'fc'], name.split('.')):
                    param.requires_grad = False
                    print("Freezing Block: %s" % name)

    if args.res_freeze > 0:
        cnt = 0
        for name, param in model.named_parameters():
            if intersection(['conv_res'], name.split('.')):
                cnt = cnt + 1
                if cnt > args.res_freeze_start:
                    param.requires_grad = False
                    print("Freezing Block: %s" % name)
                if cnt >= args.res_freeze:
                    break
    elif args.res_freeze < 0:
        cnt = 0
        for name, param in model.named_parameters():
            if intersection(['conv_res'], name.split('.')):
                cnt = cnt + 1
                if cnt > 3 + args.res_freeze:
                    param.requires_grad = False
                    print("Freezing Block: %s" % name)

    if args.prune:
        if args.prune_epoch >= 100:
            weightsFile = "runs/%s-net/checkpoint.pth.tar" % args.prune
        else:
            weightsFile = "runs/%s-net/model_epoch_%d.pth.tar" % (
                args.prune, args.prune_epoch)
        if os.path.isfile(weightsFile):
            print(f"=> loading checkpoint {weightsFile}")
            checkpoint = torch.load(weightsFile)
            model.load_state_dict(checkpoint["state_dict"])
            print(f"=> loaded checkpoint '{weightsFile}' (epoch {checkpoint['epoch']})")
        elif args.prune_epoch == 0:
            print("=> No source data, restarting network from scratch")
        else:
            print(f"=> no checkpoint found at {weightsFile}, aborting...")
            return None
    elif args.resume:
        tarfile = "runs/%s-net/checkpoint.pth.tar" % args.resume
        if os.path.isfile(tarfile):
            print(f"=> loading checkpoint {args.resume}")
            checkpoint = torch.load(tarfile)
            args.start_epoch = checkpoint["epoch"]
            best_prec1 = checkpoint["best_prec1"]
            model.load_state_dict(checkpoint["state_dict"])
            print(f"=> loaded checkpoint '{tarfile}' (epoch {checkpoint['epoch']})")
        else:
            print(f"=> no checkpoint found at {tarfile}, aborting...")
            return None

    cudnn.benchmark = True
    criterion = nn.CrossEntropyLoss().cuda()
    if args.optimizer.lower() == 'sgd':
        optimizer = torch.optim.SGD(
            model.parameters(),
            args.lr,
            momentum=args.momentum,
            nesterov=args.nesterov,
            weight_decay=args.weight_decay,
        )
    elif args.optimizer.lower() == 'radam':
        optimizer = RAdam(
            model.parameters(),
            lr=args.lr,
            betas=(args.beta1, args.beta2),
            weight_decay=args.weight_decay,
        )
    else:
        print("optimizer %s is not supported" % args.optimizer)
        return None

    if args.prune and pruner_state is not None:
        cutoff_retrain = prunhild.cutoff.LocalRatioCutoff(args.cutoff)
        params_retrain = get_params_for_pruning(args, model)
        pruner_retrain = prunhild.pruner.CutoffPruner(params_retrain,
                                                      cutoff_retrain)
        pruner_retrain.load_state_dict(pruner_state)
        pruner_retrain.prune(update_state=False)
        pruned_weights_count = count_pruned_weights(params_retrain, args.cutoff)
        params_left = param_num - pruned_weights_count
        print("Pruned %d weights, New model size: %d/%d (%d%%)" %
              (pruned_weights_count, params_left, param_num,
               int(100 * params_left / param_num)))
    else:
        pruner_retrain = None

    if args.eval:
        best_prec1 = validate(args, val_loader, model, criterion, 0, None)
    else:
        if args.varnet:
            # Checkpoint the untrained network (epoch 0) before training starts.
            save_checkpoint(
                args,
                {
                    "epoch": 0,
                    "state_dict": model.state_dict(),
                    "best_prec1": 0.0,
                },
                True,
            )
        best_prec1 = 0.0
        turns_above_50 = 0
        for epoch in range(args.start_epoch, args.epochs):
            adjust_learning_rate(args, optimizer, epoch + 1)
            train(args, train_loader, model, criterion, optimizer, epoch,
                  pruner_retrain, writer)
            prec1 = validate(args, val_loader, model, criterion, epoch, writer)
            correlation.measure_correlation(model, epoch, writer=writer)

            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            if args.savenet:
                save_checkpoint(
                    args,
                    {
                        "epoch": epoch + 1,
                        "state_dict": model.state_dict(),
                        "best_prec1": best_prec1,
                    },
                    is_best,
                )
            if args.symmetry_break and prec1 > 50.0:
                turns_above_50 += 1
                if turns_above_50 > 3:
                    # Accuracy has stayed above 50% for more than three
                    # epochs; report the epoch at which symmetry broke.
                    return epoch

    writer.close()
    print("Best accuracy: ", best_prec1)
    return best_prec1
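
# --- Hypothetical helper sketches (not part of this file) ---------------- #
# main2() relies on small utilities defined elsewhere in the repo. The
# minimal implementations below are inferred purely from their call sites
# (e.g. nondigits("cifar100") -> "cifar", onlydigits("cifar100") -> 100,
# intersection(['scale'], name.split('.')) -> truthy on a shared element);
# the real versions may differ.

def nondigits_sketch(s):
    # Strip the digit suffix from a dataset name: "cifar100" -> "cifar".
    return "".join(c for c in s if not c.isdigit())


def onlydigits_sketch(s):
    # Keep only the digits of a dataset name: "cifar100" -> 100 (0 if none).
    digits = "".join(c for c in s if c.isdigit())
    return int(digits) if digits else 0


def intersection_sketch(a, b):
    # Elements common to both lists; used above to detect parameter names
    # containing tokens like 'scale' or 'conv_res'.
    return [x for x in a if x in b]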

def main(args):
    save_folder = '%s_%s' % (args.dataset, args.affix)
    log_folder = os.path.join(args.log_root, save_folder)
    model_folder = os.path.join(args.model_root, save_folder)
    makedirs(log_folder)
    makedirs(model_folder)
    setattr(args, 'log_folder', log_folder)
    setattr(args, 'model_folder', model_folder)

    logger = create_logger(log_folder, args.todo, 'info')
    print_args(args, logger)

    # Using a WideResNet model
    model = WideResNet(depth=34, num_classes=10, widen_factor=1, dropRate=0.0)
    flop, param = get_model_infos(model, (1, 3, 32, 32))
    logger.info('Model Info: FLOP = {:.2f} M, Params = {:.2f} MB'.format(
        flop, param))

    # Inputs are plain ToTensor() outputs in [0, 1]; with mean 0 and std 1
    # the valid attack box is simply (0.0, 1.0).
    mean = [0]
    std = [1]
    inputs_box = (min((0 - m) / s for m, s in zip(mean, std)),
                  max((1 - m) / s for m, s in zip(mean, std)))
    attack = carlini_wagner_L2.L2Adversary(targeted=False,
                                           confidence=0.0,
                                           search_steps=10,
                                           box=inputs_box,
                                           optimizer_lr=5e-4)

    if torch.cuda.is_available():
        model.cuda()

    trainer = Trainer(args, logger, attack)

    if args.todo == 'train':
        transform_train = tv.transforms.Compose([
            tv.transforms.ToTensor(),
            tv.transforms.Lambda(lambda x: F.pad(
                x.unsqueeze(0), (4, 4, 4, 4), mode='constant', value=0).squeeze()),
            tv.transforms.ToPILImage(),
            tv.transforms.RandomCrop(32),
            tv.transforms.RandomHorizontalFlip(),
            tv.transforms.ToTensor(),
        ])
        tr_dataset = tv.datasets.CIFAR10(args.data_root,
                                         train=True,
                                         transform=transform_train,
                                         download=True)
        tr_loader = DataLoader(tr_dataset,
                               batch_size=args.batch_size,
                               shuffle=True,
                               num_workers=4)
        # evaluation during training
        te_dataset = tv.datasets.CIFAR10(args.data_root,
                                         train=False,
                                         transform=tv.transforms.ToTensor(),
                                         download=True)
        te_loader = DataLoader(te_dataset,
                               batch_size=args.batch_size,
                               shuffle=False,
                               num_workers=4)
        trainer.train(model, tr_loader, te_loader, args.adv_train)
    elif args.todo == 'test':
        pass
    elif args.todo == 'cw_test':
        # Note: the CW evaluation model uses widen_factor=2, unlike the
        # widen_factor=1 model constructed above.
        model = WideResNet(depth=34, num_classes=10, widen_factor=2, dropRate=0.0)
        print(model)
        model.load_state_dict(
            torch.load(args.cw_attack_modelpath,
                       map_location=lambda storage, loc: storage))
        model.cuda()
        te_dataset = tv.datasets.CIFAR10(args.data_root,
                                         train=False,
                                         transform=tv.transforms.ToTensor(),
                                         download=True)
        te_loader = DataLoader(te_dataset,
                               batch_size=args.batch_size,
                               shuffle=False,
                               num_workers=4)
        cw_attack_test(model, te_loader)
    else:
        raise NotImplementedError
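
# Worked example for the inputs_box computation in main() above. With the
# identity normalization actually used there (mean 0, std 1) the box is
# exactly (0.0, 1.0), matching ToTensor()'s output range. The CIFAR-style
# statistics below are illustrative assumptions, not values from this repo.
def _inputs_box_demo():
    mean = [0.4914, 0.4822, 0.4465]
    std = [0.2470, 0.2435, 0.2616]
    # Smallest and largest attainable pixel values after normalization:
    return (min((0 - m) / s for m, s in zip(mean, std)),   # about -1.99
            max((1 - m) / s for m, s in zip(mean, std)))   # about  2.13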

def getPruneMask(args):
    baseTar = "runs/%s-net/checkpoint.pth.tar" % args.prune
    if os.path.isfile(baseTar):
        classes = onlydigits(args.prune_classes)
        if classes == 0:
            classes = args.classes
        fullModel = WideResNet(
            args.layers,
            classes,
            args.widen_factor,
            droprate=args.droprate,
            use_bn=args.batchnorm,
            use_fixup=args.fixup,
            varnet=args.varnet,
            noise=args.noise,
            lrelu=args.lrelu,
            sigmaW=args.sigmaW,
        )
        if torch.cuda.device_count() > 1:
            start = int(args.device[0])
            end = int(args.device[2]) + 1
            torch.cuda.set_device(start)
            dev_list = ["cuda:%d" % i for i in range(start, end)]
            fullModel = torch.nn.DataParallel(fullModel, device_ids=dev_list)
        fullModel = fullModel.cuda()

        print(f"=> loading checkpoint {baseTar}")
        checkpoint = torch.load(baseTar)
        fullModel.load_state_dict(checkpoint["state_dict"])

        # --------------------------- #
        # --- Pruning Setup Start --- #
        cutoff = prunhild.cutoff.LocalRatioCutoff(args.cutoff)
        # don't prune the final bias weights
        params = get_params_for_pruning(args, fullModel)
        print(params)
        pruner = prunhild.pruner.CutoffPruner(params, cutoff, prune_online=True)
        pruner.prune()
        print(f"=> loaded checkpoint '{baseTar}' (epoch {checkpoint['epoch']})")

        if torch.cuda.device_count() > 1:
            start = int(args.device[0])
            end = int(args.device[2]) + 1
            for i in range(start, end):
                torch.cuda.set_device(i)
                torch.cuda.empty_cache()

        mask = pruner.state_dict()
        if args.randomize_mask:
            mask = randomize_mask(mask, args.cutoff)
        return mask
    else:
        print(f"=> no checkpoint found at {baseTar}")
        return None
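
# Hypothetical sketch of randomize_mask(), which is defined elsewhere in the
# repo. Assuming the pruner state maps parameter keys to 0/1 mask tensors,
# randomizing shuffles each mask so the sparsity ratio is preserved while the
# surviving positions become random (a common lottery-ticket control). The
# real implementation may differ; `cutoff` is unused in this sketch because
# shuffling already preserves the pruned fraction.
def randomize_mask_sketch(state, cutoff):
    for key, mask in state.items():
        if torch.is_tensor(mask):
            flat = mask.view(-1)
            perm = torch.randperm(flat.numel(), device=flat.device)
            state[key] = flat[perm].view_as(mask)
    return state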

def main():
    global args, best_prec1
    args = parser.parse_args()
    if args.tensorboard:
        configure(f"runs/{args.name}")

    normalize = transforms.Normalize(
        mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
        std=[x / 255.0 for x in [63.0, 62.1, 66.7]],
    )

    if args.augment:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: F.pad(
                x.unsqueeze(0), (4, 4, 4, 4), mode="reflect").squeeze()),
            transforms.ToPILImage(),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        transform_train = transforms.Compose([transforms.ToTensor(), normalize])
    if args.cutout:
        transform_train.transforms.append(
            Cutout(n_holes=args.n_holes, length=args.length))
    transform_test = transforms.Compose([transforms.ToTensor(), normalize])

    kwargs = {"num_workers": 1, "pin_memory": True}
    assert args.dataset in ("cifar10", "cifar100")
    train_loader = torch.utils.data.DataLoader(
        datasets.__dict__[args.dataset.upper()](
            "../data", train=True, download=True, transform=transform_train),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs,
    )
    val_loader = torch.utils.data.DataLoader(
        datasets.__dict__[args.dataset.upper()](
            "../data", train=False, transform=transform_test),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs,
    )

    model = WideResNet(
        args.layers,
        10 if args.dataset == "cifar10" else 100,
        args.widen_factor,
        droprate=args.droprate,
        use_bn=args.batchnorm,
        use_fixup=args.fixup,
    )
    param_num = sum(p.data.nelement() for p in model.parameters())
    print(f"Number of model parameters: {param_num}")

    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
    model = model.cuda()

    if args.resume:
        if os.path.isfile(args.resume):
            print(f"=> loading checkpoint {args.resume}")
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint["epoch"]
            best_prec1 = checkpoint["best_prec1"]
            model.load_state_dict(checkpoint["state_dict"])
            print(f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})")
        else:
            print(f"=> no checkpoint found at {args.resume}")

    cudnn.benchmark = True
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(
        model.parameters(),
        args.lr,
        momentum=args.momentum,
        nesterov=args.nesterov,
        weight_decay=args.weight_decay,
    )

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch + 1)
        train(train_loader, model, criterion, optimizer, epoch)
        prec1 = validate(val_loader, model, criterion, epoch)

        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                "epoch": epoch + 1,
                "state_dict": model.state_dict(),
                "best_prec1": best_prec1,
            },
            is_best,
        )
    print("Best accuracy: ", best_prec1)
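
# Hypothetical sketch of adjust_learning_rate(), which is defined elsewhere
# in the repo. WideResNet recipes commonly multiply the base LR by 0.2 at
# fixed milestones; the milestones and factor here are assumptions, not the
# repo's actual schedule.
def adjust_learning_rate_sketch(optimizer, epoch, base_lr=0.1):
    lr = base_lr * (0.2 ** sum(epoch >= m for m in (60, 120, 160)))
    for group in optimizer.param_groups:
        group["lr"] = lr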