def test_grad_through_normalize():
    """Gradients must flow through NormalizeByChannelMeanStd.

    With mean=0 and std=1 the layer is the identity, so for
    loss = sum(x**2) the gradient w.r.t. the input is exactly 2*x.
    """
    x = torch.rand((2, 1, 28, 28))
    x.requires_grad_()
    # Identity normalization: (x - 0) / 1 == x.
    identity_norm = NormalizeByChannelMeanStd(torch.tensor((0., )),
                                              torch.tensor((1., )))
    loss = (identity_norm(x) ** 2).sum()
    loss.backward()
    assert torch_allclose(2 * x, x.grad)
class Transform(object):
    """Namespace of preprocessing pipelines for the ImageNet-stats setting.

    NOTE(review): the attribute initializers run at import time and
    `.cuda()` requires an available GPU -- confirm this module is only
    imported in GPU jobs.
    """

    # Inputs are assumed to be in [-1, 1]; Normalize((-1,...), (2,...))
    # maps them to [0, 1] via (x - (-1)) / 2, then the standard
    # IMAGENET_MEAN/IMAGENET_STD normalization is applied.
    classifier_training = transforms.Compose([
        transforms.Normalize((-1., -1., -1.), (2., 2., 2.)),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
    ])
    # NOTE(review): identical to classifier_training -- kept as a separate
    # attribute, presumably so the two pipelines can diverge later.
    classifier_testing = transforms.Compose([
        transforms.Normalize((-1., -1., -1.), (2., 2., 2.)),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
    ])
    # Default transform: PIL image -> float tensor in [0, 1].
    default = transforms.ToTensor()
    # No GAN deprocessing in this configuration (empty pipeline).
    gan_deprocess_layer = []
    # In-model normalization layer applied before the classifier.
    classifier_preprocess_layer = NormalizeByChannelMeanStd(MEAN, STD).cuda()
class Transform(object):
    """Namespace of CIFAR-10 preprocessing pipelines (GAN + classifier).

    NOTE(review): the `.cuda()` calls below execute at import time.
    """

    # GAN training expects inputs in [-1, 1]: ToTensor yields [0, 1],
    # then (x - 0.5) / 0.5 rescales to [-1, 1].
    gan_training = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    # Standard CIFAR-10 augmentation + normalization for classifier training.
    classifier_training = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD)
    ])
    classifier_testing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD)
    ])
    # transformations used during testing / generating attacks
    # we assume loaded images to be in [0, 1]
    default = transforms.ToTensor()
    # since g(z) is in [-1, 1], the output should be deprocessed into
    # [0, 1] via (x - (-1)) / 2.
    gan_deprocess_layer = NormalizeByChannelMeanStd([-1., -1., -1.],
                                                    [2., 2., 2.]).cuda()
    # In-model normalization applied before the classifier.
    classifier_preprocess_layer = NormalizeByChannelMeanStd(
        CIFAR10_MEAN, CIFAR10_STD).cuda()
def __init__(self, block, num_blocks, num_classes=10):
    """Build a PreActResNet with an input-normalization layer folded in.

    Args:
        block: residual block class; `block.expansion` scales the channel
            count of the final feature map.
        num_blocks: sequence of four ints -- number of blocks per stage.
        num_classes: size of each output head (default 10, CIFAR-10).
    """
    super(PreActResNet, self).__init__()
    self.in_planes = 64
    # Input normalization as part of the model (CIFAR-10 mean/std).
    # NOTE(review): attribute name 'nomalize' (sic) is kept as-is --
    # renaming it would break forward() references and state_dict keys.
    self.nomalize = NormalizeByChannelMeanStd(
        mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
    self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1,
                           bias=False)
    # Four stages; spatial resolution halves (stride 2) from stage 2 on.
    self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
    self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
    self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
    self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
    # Two classifier heads over the same features.
    # NOTE(review): the role of the second head ('linear_main') is not
    # visible here -- presumably used by a multi-head training scheme.
    self.linear = nn.Linear(512 * block.expansion, num_classes)
    self.linear_main = nn.Linear(512 * block.expansion, num_classes)
def get_models(args, train=True, as_ensemble=False, model_file=None,
               leaky_relu=False, dataset=None):
    """Build one input-normalized classifier, optionally loading weights.

    Args:
        args: namespace providing at least `arch`.
        train: put the model in train() mode if True, else eval().
        as_ensemble: unused here; kept for interface compatibility with
            the multi-model variant of this helper.
        model_file: optional path to a saved state_dict to load.
        leaky_relu: unused here (ResNet18 is built without it); kept for
            interface compatibility.
        dataset: "STL-10" or "CIFAR-10"; selects the normalization std.

    Returns:
        A CUDA ModelWrapper(model, normalizer) in the requested mode.

    Raises:
        ValueError: for an unknown dataset or unsupported architecture.
    """
    mean = torch.tensor([0.4914, 0.4822, 0.4465], dtype=torch.float32).cuda()
    if dataset == "STL-10":
        std = torch.tensor([0.2023, 0.1994, 0.2010],
                           dtype=torch.float32).cuda()
        # std = torch.tensor([0.2471, 0.2435, 0.2616], dtype=torch.float32).cuda()
    elif dataset == "CIFAR-10":
        std = torch.tensor([0.2023, 0.1994, 0.2010],
                           dtype=torch.float32).cuda()
    else:
        # Fail fast instead of the NameError the old code produced when
        # `std` was never assigned for an unrecognized dataset.
        raise ValueError('[{}] dataset is not supported...'.format(dataset))
    normalizer = NormalizeByChannelMeanStd(mean=mean, std=std)
    state_dict = None
    if model_file:
        state_dict = torch.load(model_file)
        if train:
            print('Loading pre-trained models...')
    if args.arch.lower() == 'resnet':
        model = ResNet18()  # depth=args.depth, leaky_relu=leaky_relu
    else:
        # Bug fix: the '{:s}' placeholder was previously emitted verbatim;
        # format the offending architecture name into the message.
        raise ValueError(
            '[{:s}] architecture is not supported yet...'.format(args.arch))
    # we include input normalization as a part of the model
    model = ModelWrapper(model, normalizer)
    if model_file:
        model.load_state_dict(state_dict)
    if train:
        model.train()
    else:
        model.eval()
    model = model.cuda()
    return model
def get_models(args, train=True, as_ensemble=False, model_file=None,
               leaky_relu=False):
    """Build args.model_num input-normalized classifiers (or an Ensemble).

    Args:
        args: namespace providing `arch`, `depth`, `model_num`.
        train: train() mode if True, else eval().
        as_ensemble: if True, wrap the models in an Ensemble (requires
            train=False) and return it instead of the list.
        model_file: optional checkpoint holding one state_dict per model,
            keyed by iteration index.
        leaky_relu: forwarded to the ResNet constructor.

    Returns:
        list of ModelWrapper models, or an Ensemble when as_ensemble=True.

    Raises:
        ValueError: for an unsupported architecture.
    """
    models = []
    mean = torch.tensor([0.4914, 0.4822, 0.4465], dtype=torch.float32).cuda()
    std = torch.tensor([0.2023, 0.1994, 0.2010], dtype=torch.float32).cuda()
    normalizer = NormalizeByChannelMeanStd(mean=mean, std=std)
    if model_file:
        state_dict = torch.load(model_file)
        if train:
            print('Loading pre-trained models...')
    iter_m = state_dict.keys() if model_file else range(args.model_num)
    for i in iter_m:
        if args.arch.lower() == 'resnet':
            model = ResNet(depth=args.depth, leaky_relu=leaky_relu)
        else:
            # Bug fix: fill the '{:s}' placeholder that was previously
            # printed literally in the error message.
            raise ValueError(
                '[{:s}] architecture is not supported yet...'.format(
                    args.arch))
        # we include input normalization as a part of the model
        model = ModelWrapper(model, normalizer)
        if model_file:
            model.load_state_dict(state_dict[i])
        if train:
            model.train()
        else:
            model.eval()
        model = model.cuda()
        models.append(model)
    if as_ensemble:
        assert not train, 'Must be in eval mode when getting models to form an ensemble'
        ensemble = Ensemble(models)
        ensemble.eval()
        return ensemble
    else:
        return models
def main():
    """Adversarial-training driver for CIFAR-10.

    Builds data pipelines, model, optimizer and LR schedule from CLI args,
    then runs adversarial training (PGD / FGSM / free / TMIM) with per-epoch
    evaluation and checkpointing.  With --eval it only evaluates a resumed
    checkpoint.

    NOTE(review): relies on module-level names not visible in this chunk
    (`normalize`, `lower_limit`, `upper_limit`, `cifar10_mean`,
    `cifar10_std`, `device`, `CIFAR10_MEAN`, `CIFAR10_STD`) -- confirm they
    are defined at import time.
    """
    args = get_args()
    if not os.path.exists(args.fname):
        os.makedirs(args.fname)
    # Log to a file inside the run directory and to stdout simultaneously.
    logger = logging.getLogger(__name__)
    logging.basicConfig(
        format='[%(asctime)s] - %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S',
        level=logging.DEBUG,
        handlers=[
            logging.FileHandler(
                os.path.join(args.fname,
                             'eval.log' if args.eval else 'output.log')),
            logging.StreamHandler()
        ])
    logger.info(args)
    # Seed numpy and torch (CPU + CUDA) for reproducibility.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    # Data augmentations.  NOTE: this local name shadows any imported
    # `transforms` module within this function.
    transforms = [Crop(32, 32), FlipLR()]
    # transforms = [Crop(32, 32)]
    if args.cutout:
        transforms.append(Cutout(args.cutout_len, args.cutout_len))
    if args.val:
        # Use a pre-generated train/val/test split.
        try:
            dataset = torch.load("cifar10_validation_split.pth")
        except:  # NOTE(review): bare except also hides non-IO errors
            print(
                "Couldn't find a dataset with a validation split, did you run "
                "generate_validation.py?")
            return
        val_set = list(
            zip(transpose(dataset['val']['data'] / 255.),
                dataset['val']['labels']))
        val_batches = Batches(val_set,
                              args.batch_size,
                              shuffle=False,
                              num_workers=2)
    else:
        dataset = cifar10(args.data_dir)
    # Training set: pad 4px for random crops, scale pixels to [0, 1].
    train_set = list(
        zip(transpose(pad(dataset['train']['data'], 4) / 255.),
            dataset['train']['labels']))
    train_set_x = Transform(train_set, transforms)
    train_batches = Batches(train_set_x,
                            args.batch_size,
                            shuffle=True,
                            set_random_choices=True,
                            num_workers=2)
    test_set = list(
        zip(transpose(dataset['test']['data'] / 255.),
            dataset['test']['labels']))
    test_batches = Batches(test_set,
                           args.batch_size,
                           shuffle=False,
                           num_workers=2)
    # Attack budgets are given on the 0-255 scale; convert to [0, 1].
    trn_epsilon = (args.trn_epsilon / 255.)
    trn_pgd_alpha = (args.trn_pgd_alpha / 255.)
    tst_epsilon = (args.tst_epsilon / 255.)
    tst_pgd_alpha = (args.tst_pgd_alpha / 255.)
    # ---- model selection ----
    if args.model == 'PreActResNet18':
        model = PreActResNet18()
    elif args.model == 'WideResNet':
        model = WideResNet(34,
                           10,
                           widen_factor=args.width_factor,
                           dropRate=0.0)
    elif args.model == 'DenseNet121':
        model = DenseNet121()
    elif args.model == 'ResNet18':
        model = ResNet18()
    else:
        raise ValueError("Unknown model")
    ### temp testing ###
    model = model.cuda()
    # model = nn.DataParallel(model).cuda()
    model.train()
    ##################################
    # load pretrained adversary model (if any) for training-time attacks
    if args.trn_adv_models != 'None':
        if args.trn_adv_arch == 'PreActResNet18':
            trn_adv_model = PreActResNet18()
        elif args.trn_adv_arch == 'WideResNet':
            trn_adv_model = WideResNet(34,
                                       10,
                                       widen_factor=args.width_factor,
                                       dropRate=0.0)
        elif args.trn_adv_arch == 'DenseNet121':
            trn_adv_model = DenseNet121()
        elif args.trn_adv_arch == 'ResNet18':
            trn_adv_model = ResNet18()
        trn_adv_model = nn.DataParallel(trn_adv_model).cuda()
        trn_adv_model.load_state_dict(
            torch.load(
                os.path.join('./adv_models', args.trn_adv_models,
                             'model_best.pth'))['state_dict'])
        logger.info(f'loaded adv_model: {args.trn_adv_models}')
    else:
        trn_adv_model = None
    # same for test-time attacks
    if args.tst_adv_models != 'None':
        if args.tst_adv_arch == 'PreActResNet18':
            tst_adv_model = PreActResNet18()
        elif args.tst_adv_arch == 'WideResNet':
            tst_adv_model = WideResNet(34,
                                       10,
                                       widen_factor=args.width_factor,
                                       dropRate=0.0)
        elif args.tst_adv_arch == 'DenseNet121':
            tst_adv_model = DenseNet121()
        elif args.tst_adv_arch == 'ResNet18':
            tst_adv_model = ResNet18()
        ### temp testing ###
        tst_adv_model = tst_adv_model.cuda()
        tst_adv_model.load_state_dict(
            torch.load(
                os.path.join('./adv_models', args.tst_adv_models,
                             'model_best.pth')))
        # tst_adv_model = nn.DataParallel(tst_adv_model).cuda()
        # tst_adv_model.load_state_dict(torch.load(os.path.join('./adv_models', args.tst_adv_models, 'model_best.pth'))['state_dict'])
        logger.info(f'loaded adv_model: {args.tst_adv_models}')
    else:
        tst_adv_model = None
    ##################################
    # Optional L2 regularization: decay only conv/linear weights, not
    # batch-norm params or biases.
    if args.l2:
        decay, no_decay = [], []
        for name, param in model.named_parameters():
            if 'bn' not in name and 'bias' not in name:
                decay.append(param)
            else:
                no_decay.append(param)
        params = [{
            'params': decay,
            'weight_decay': args.l2
        }, {
            'params': no_decay,
            'weight_decay': 0
        }]
    else:
        params = model.parameters()
    # NOTE(review): weight_decay=5e-4 here applies on top of the per-group
    # settings above when args.l2 is set -- confirm this is intended.
    opt = torch.optim.SGD(params,
                          lr=args.lr_max,
                          momentum=0.9,
                          weight_decay=5e-4)
    criterion = nn.CrossEntropyLoss()
    # 'free' / persistent-FGSM training reuse one delta buffer across batches.
    if args.trn_attack == 'free':
        delta = torch.zeros(args.batch_size, 3, 32, 32).cuda()
        delta.requires_grad = True
    elif args.trn_attack == 'fgsm' and args.trn_fgsm_init == 'previous':
        delta = torch.zeros(args.batch_size, 3, 32, 32).cuda()
        delta.requires_grad = True
    # 'free' training replays each batch trn_attack_iters times, so the
    # epoch count is shrunk to keep total passes comparable.
    if args.trn_attack == 'free':
        epochs = int(math.ceil(args.epochs / args.trn_attack_iters))
    else:
        epochs = args.epochs
    # ---- learning-rate schedules (t is a fractional epoch) ----
    if args.lr_schedule == 'superconverge':
        lr_schedule = lambda t: np.interp([t], [
            0, args.epochs * 2 // 5, args.epochs
        ], [0, args.lr_max, 0])[0]
    elif args.lr_schedule == 'piecewise':

        def lr_schedule(t):
            if t / args.epochs < 0.5:
                return args.lr_max
            elif t / args.epochs < 0.75:
                return args.lr_max / 10.
            else:
                return args.lr_max / 100.
    elif args.lr_schedule == 'linear':
        lr_schedule = lambda t: np.interp([t], [
            0, args.epochs // 3, args.epochs * 2 // 3, args.epochs
        ], [args.lr_max, args.lr_max, args.lr_max / 10, args.lr_max / 100])[0]
    elif args.lr_schedule == 'onedrop':

        def lr_schedule(t):
            if t < args.lr_drop_epoch:
                return args.lr_max
            else:
                return args.lr_one_drop
    elif args.lr_schedule == 'multipledecay':

        def lr_schedule(t):
            return args.lr_max - (t //
                                  (args.epochs // 10)) * (args.lr_max / 10)
    elif args.lr_schedule == 'cosine':

        def lr_schedule(t):
            return args.lr_max * 0.5 * (1 + np.cos(t / args.epochs * np.pi))
    best_test_robust_acc = 0
    best_val_robust_acc = 0
    if args.resume:
        ### temp testing ###
        model.load_state_dict(
            torch.load(os.path.join(args.fname, 'model_best.pth')))
        start_epoch = args.resume
        # model.load_state_dict(torch.load(os.path.join(args.fname, f'model_{start_epoch-1}.pth')))
        # opt.load_state_dict(torch.load(os.path.join(args.fname, f'opt_{start_epoch-1}.pth')))
        # logger.info(f'Resuming at epoch {start_epoch}')
        # best_test_robust_acc = torch.load(os.path.join(args.fname, f'model_best.pth'))['test_robust_acc']
        if args.val:
            best_val_robust_acc = torch.load(
                os.path.join(args.fname, f'model_val.pth'))['val_robust_acc']
    else:
        start_epoch = 0
    if args.eval:
        if not args.resume:
            logger.info(
                "No model loaded to evaluate, specify with --resume FNAME")
            return
        logger.info("[Evaluation mode]")
    logger.info(
        'Epoch \t Train Time \t Test Time \t LR \t \t Train Loss \t Train Acc \t Train Robust Loss \t Train Robust Acc \t Test Loss \t Test Acc \t Test Robust Loss \t Test Robust Acc'
    )
    for epoch in range(start_epoch, epochs):
        # ---- training phase (skipped entirely under --eval) ----
        model.train()
        start_time = time.time()
        train_loss = 0
        train_acc = 0
        train_robust_loss = 0
        train_robust_acc = 0
        train_n = 0
        for i, batch in enumerate(train_batches):
            if args.eval:
                break
            X, y = batch['input'], batch['target']
            if args.mixup:
                X, y_a, y_b, lam = mixup_data(X, y, args.mixup_alpha)
                X, y_a, y_b = map(Variable, (X, y_a, y_b))
            lr = lr_schedule(epoch + (i + 1) / len(train_batches))
            opt.param_groups[0].update(lr=lr)
            # ---- craft training-time adversarial perturbation `delta` ----
            if args.trn_attack == 'pgd':
                # Random initialization
                if args.mixup:
                    delta = attack_pgd(model, X, y, trn_epsilon,
                                       trn_pgd_alpha, args.trn_attack_iters,
                                       args.trn_restarts, args.trn_norm,
                                       mixup=True, y_a=y_a, y_b=y_b, lam=lam,
                                       adv_models=trn_adv_model)
                else:
                    delta = attack_pgd(model, X, y, trn_epsilon,
                                       trn_pgd_alpha, args.trn_attack_iters,
                                       args.trn_restarts, args.trn_norm,
                                       adv_models=trn_adv_model)
                delta = delta.detach()
            elif args.trn_attack == 'fgsm':
                delta = attack_pgd(model, X, y, trn_epsilon,
                                   args.trn_fgsm_alpha * trn_epsilon, 1, 1,
                                   args.trn_norm, adv_models=trn_adv_model,
                                   rand_init=args.trn_fgsm_init)
                delta = delta.detach()
            # Standard training
            elif args.trn_attack == 'none':
                delta = torch.zeros_like(X)
            # The Momentum Iterative Attack
            elif args.trn_attack == 'tmim':
                if trn_adv_model is None:
                    adversary = MomentumIterativeAttack(
                        model,
                        nb_iter=args.trn_attack_iters,
                        eps=trn_epsilon,
                        loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                        eps_iter=trn_pgd_alpha,
                        clip_min=0,
                        clip_max=1,
                        targeted=False)
                else:
                    # NOTE(review): this rebinds trn_adv_model to a new
                    # nn.Sequential on EVERY batch, nesting it one level
                    # deeper each time -- looks unintended; confirm.
                    trn_adv_model = nn.Sequential(
                        NormalizeByChannelMeanStd(CIFAR10_MEAN, CIFAR10_STD),
                        trn_adv_model)
                    adversary = MomentumIterativeAttack(
                        trn_adv_model,
                        nb_iter=args.trn_attack_iters,
                        eps=trn_epsilon,
                        loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                        eps_iter=trn_pgd_alpha,
                        clip_min=0,
                        clip_max=1,
                        targeted=False)
                data_adv = adversary.perturb(X, y)
                delta = data_adv - X
                delta = delta.detach()
            # Robust loss on perturbed inputs, clamped to the valid range.
            robust_output = model(
                normalize(
                    torch.clamp(X + delta[:X.size(0)],
                                min=lower_limit,
                                max=upper_limit)))
            if args.mixup:
                robust_loss = mixup_criterion(criterion, robust_output, y_a,
                                              y_b, lam)
            else:
                robust_loss = criterion(robust_output, y)
            # Optional L1 penalty on conv/linear weights.
            if args.l1:
                for name, param in model.named_parameters():
                    if 'bn' not in name and 'bias' not in name:
                        robust_loss += args.l1 * param.abs().sum()
            opt.zero_grad()
            robust_loss.backward()
            opt.step()
            # Clean-accuracy bookkeeping (no gradient step on this loss).
            output = model(normalize(X))
            if args.mixup:
                loss = mixup_criterion(criterion, output, y_a, y_b, lam)
            else:
                loss = criterion(output, y)
            train_robust_loss += robust_loss.item() * y.size(0)
            train_robust_acc += (robust_output.max(1)[1] == y).sum().item()
            train_loss += loss.item() * y.size(0)
            train_acc += (output.max(1)[1] == y).sum().item()
            train_n += y.size(0)
        train_time = time.time()
        # ---- evaluation on the test set ----
        model.eval()
        test_loss = 0
        test_acc = 0
        test_robust_loss = 0
        test_robust_acc = 0
        test_n = 0
        for i, batch in enumerate(test_batches):
            X, y = batch['input'], batch['target']
            # Random initialization
            if args.tst_attack == 'none':
                delta = torch.zeros_like(X)
            elif args.tst_attack == 'pgd':
                delta = attack_pgd(model, X, y, tst_epsilon, tst_pgd_alpha,
                                   args.tst_attack_iters, args.tst_restarts,
                                   args.tst_norm, adv_models=tst_adv_model,
                                   rand_init=args.tst_fgsm_init)
            elif args.tst_attack == 'fgsm':
                delta = attack_pgd(model, X, y, tst_epsilon, tst_epsilon, 1,
                                   1, args.tst_norm,
                                   rand_init=args.tst_fgsm_init,
                                   adv_models=tst_adv_model)
            # The Momentum Iterative Attack
            elif args.tst_attack == 'tmim':
                if tst_adv_model is None:
                    adversary = MomentumIterativeAttack(
                        model,
                        nb_iter=args.tst_attack_iters,
                        eps=tst_epsilon,
                        loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                        eps_iter=tst_pgd_alpha,
                        clip_min=0,
                        clip_max=1,
                        targeted=False)
                else:
                    # NOTE(review): `cifar10_mean`/`cifar10_std`/`device` are
                    # not defined in this chunk (elsewhere CIFAR10_MEAN/STD
                    # are used) -- confirm they exist at module level.
                    tmp_model = nn.Sequential(
                        NormalizeByChannelMeanStd(cifar10_mean, cifar10_std),
                        tst_adv_model).to(device)
                    adversary = MomentumIterativeAttack(
                        tmp_model,
                        nb_iter=args.tst_attack_iters,
                        eps=tst_epsilon,
                        loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                        eps_iter=tst_pgd_alpha,
                        clip_min=0,
                        clip_max=1,
                        targeted=False)
                data_adv = adversary.perturb(X, y)
                delta = data_adv - X
            # (A commented-out advertorch-PGD alternative for both branches
            # was removed here; see version control history.)
            delta = delta.detach()
            robust_output = model(
                normalize(
                    torch.clamp(X + delta[:X.size(0)],
                                min=lower_limit,
                                max=upper_limit)))
            robust_loss = criterion(robust_output, y)
            output = model(normalize(X))
            loss = criterion(output, y)
            test_robust_loss += robust_loss.item() * y.size(0)
            test_robust_acc += (robust_output.max(1)[1] == y).sum().item()
            test_loss += loss.item() * y.size(0)
            test_acc += (output.max(1)[1] == y).sum().item()
            test_n += y.size(0)
        test_time = time.time()
        # ---- optional evaluation on the held-out validation split ----
        if args.val:
            val_loss = 0
            val_acc = 0
            val_robust_loss = 0
            val_robust_acc = 0
            val_n = 0
            for i, batch in enumerate(val_batches):
                X, y = batch['input'], batch['target']
                # Random initialization
                if args.tst_attack == 'none':
                    delta = torch.zeros_like(X)
                elif args.tst_attack == 'pgd':
                    delta = attack_pgd(model, X, y, tst_epsilon,
                                       tst_pgd_alpha, args.tst_attack_iters,
                                       args.tst_restarts, args.tst_norm,
                                       early_stop=args.eval)
                elif args.tst_attack == 'fgsm':
                    delta = attack_pgd(model, X, y, tst_epsilon, tst_epsilon,
                                       1, 1, args.tst_norm,
                                       early_stop=args.eval,
                                       rand_init=args.tst_fgsm_init)
                delta = delta.detach()
                robust_output = model(
                    normalize(
                        torch.clamp(X + delta[:X.size(0)],
                                    min=lower_limit,
                                    max=upper_limit)))
                robust_loss = criterion(robust_output, y)
                output = model(normalize(X))
                loss = criterion(output, y)
                val_robust_loss += robust_loss.item() * y.size(0)
                val_robust_acc += (robust_output.max(1)[1] == y).sum().item()
                val_loss += loss.item() * y.size(0)
                val_acc += (output.max(1)[1] == y).sum().item()
                val_n += y.size(0)
        if not args.eval:
            logger.info(
                '%d \t %.1f \t \t %.1f \t \t %.4f \t %.4f \t %.4f \t %.4f \t \t %.4f \t \t %.4f \t %.4f \t %.4f \t \t %.4f',
                epoch, train_time - start_time, test_time - train_time, lr,
                train_loss / train_n, train_acc / train_n,
                train_robust_loss / train_n, train_robust_acc / train_n,
                test_loss / test_n, test_acc / test_n,
                test_robust_loss / test_n, test_robust_acc / test_n)
            if args.val:
                logger.info('validation %.4f \t %.4f \t %.4f \t %.4f',
                            val_loss / val_n, val_acc / val_n,
                            val_robust_loss / val_n, val_robust_acc / val_n)
                # track the checkpoint with the best validation robust acc
                if val_robust_acc / val_n > best_val_robust_acc:
                    torch.save(
                        {
                            'state_dict': model.state_dict(),
                            'test_robust_acc': test_robust_acc / test_n,
                            'test_robust_loss': test_robust_loss / test_n,
                            'test_loss': test_loss / test_n,
                            'test_acc': test_acc / test_n,
                            'val_robust_acc': val_robust_acc / val_n,
                            'val_robust_loss': val_robust_loss / val_n,
                            'val_loss': val_loss / val_n,
                            'val_acc': val_acc / val_n,
                        }, os.path.join(args.fname, f'model_val.pth'))
                    best_val_robust_acc = val_robust_acc / val_n
            # save checkpoint
            if (epoch + 1) % args.chkpt_iters == 0 or epoch + 1 == epochs:
                torch.save(model.state_dict(),
                           os.path.join(args.fname, f'model_{epoch}.pth'))
                torch.save(opt.state_dict(),
                           os.path.join(args.fname, f'opt_{epoch}.pth'))
            # save best
            if test_robust_acc / test_n > best_test_robust_acc:
                torch.save(
                    {
                        'state_dict': model.state_dict(),
                        'test_robust_acc': test_robust_acc / test_n,
                        'test_robust_loss': test_robust_loss / test_n,
                        'test_loss': test_loss / test_n,
                        'test_acc': test_acc / test_n,
                    }, os.path.join(args.fname, f'model_best.pth'))
                best_test_robust_acc = test_robust_acc / test_n
        else:
            # --eval mode: log test metrics only (train columns are -1),
            # then stop after the first pass.
            logger.info(
                '%d \t %.1f \t \t %.1f \t \t %.4f \t %.4f \t %.4f \t %.4f \t \t %.4f \t \t %.4f \t %.4f \t %.4f \t \t %.4f',
                epoch, train_time - start_time, test_time - train_time, -1,
                -1, -1, -1, -1, test_loss / test_n, test_acc / test_n,
                test_robust_loss / test_n, test_robust_acc / test_n)
            return
def main():
    """Standard (jigsaw) training driver on CIFAR-10.

    Trains `jigsaw_model` with cosine-annealed SGD, keeps the checkpoint
    with the best validation accuracy, then reports the test accuracy of
    that best checkpoint.

    NOTE(review): depends on module-level `parser`, `best_prec1`,
    `jigsaw_model`, `cosine_annealing`, `train`, `validate`,
    `save_checkpoint`, `setup_seed` -- defined elsewhere in this file.
    """
    global args, best_prec1
    args = parser.parse_args()
    print(args)
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)
    torch.cuda.set_device(int(args.gpu))
    setup_seed(args.seed)
    # Prepend ImageNet-stat normalization so raw [0, 1] tensors can be fed.
    model = jigsaw_model(args.class_number)
    normalize = NormalizeByChannelMeanStd(mean=[0.485, 0.456, 0.406],
                                          std=[0.229, 0.224, 0.225])
    model = nn.Sequential(normalize, model)
    model.cuda()
    cudnn.benchmark = True
    train_trans = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, 4),
        transforms.ToTensor()
    ])
    val_trans = transforms.Compose([transforms.ToTensor()])
    # dataset process: 90/10 random train/validation split of CIFAR-10
    train_dataset = datasets.CIFAR10(args.data,
                                     train=True,
                                     transform=train_trans,
                                     download=True)
    test_dataset = datasets.CIFAR10(args.data,
                                    train=False,
                                    transform=val_trans,
                                    download=True)
    valid_size = 0.1
    indices = list(range(len(train_dataset)))
    split = int(np.floor(valid_size * len(train_dataset)))
    np.random.shuffle(indices)
    train_idx, valid_idx = indices[split:], indices[:split]
    # NOTE(review): validation samples still receive the *training*
    # transform because both Subsets share train_dataset -- confirm intended.
    train_sampler = torch.utils.data.Subset(train_dataset, train_idx)
    valid_sampler = torch.utils.data.Subset(train_dataset, valid_idx)
    train_loader = torch.utils.data.DataLoader(train_sampler,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=2,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(valid_sampler,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=2,
                                             pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=2,
                                              pin_memory=True)
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    # Per-step cosine annealing from args.lr down to 1e-6.
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: cosine_annealing(
            step,
            args.epochs * len(train_loader),
            1,  # since lr_lambda computes multiplicative factor
            1e-6 / args.lr))
    print('std training')
    train_acc = []
    ta = []
    # generate the fixed jigsaw patch orderings (one per extra class)
    permutation = np.array(
        [np.random.permutation(16) for i in range(args.class_number - 1)])
    np.save('permutation.npy', permutation)
    if os.path.exists(args.save_dir) is not True:
        os.mkdir(args.save_dir)
    for epoch in range(args.epochs):
        print(optimizer.state_dict()['param_groups'][0]['lr'])
        acc, loss = train(train_loader, model, criterion, optimizer, epoch,
                          scheduler, permutation)
        # evaluate on validation set
        tacc, tloss = validate(val_loader, model, criterion, permutation)
        train_acc.append(acc)
        ta.append(tacc)
        # remember best prec@1 and save checkpoint
        is_best = tacc > best_prec1
        best_prec1 = max(tacc, best_prec1)
        if is_best:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                },
                is_best,
                filename=os.path.join(args.save_dir, 'best_model.pt'))
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            },
            is_best,
            filename=os.path.join(args.save_dir, 'model.pt'))
    # learning curves
    plt.plot(train_acc, label='train_acc')
    plt.plot(ta, label='TA')
    plt.legend()
    plt.savefig(os.path.join(args.save_dir, 'net_train.png'))
    plt.close()
    # reload the best checkpoint and report its test accuracy
    model_path = os.path.join(args.save_dir, 'best_model.pt')
    model.load_state_dict(torch.load(model_path)['state_dict'])
    print('testing result of ta best model')
    tacc, tloss = validate(test_loader, model, criterion, permutation)
def main():
    """Iterative-magnitude-pruning training driver.

    Picks the dataset/model from CLI args, then alternates full training
    with pruning for `args.pruning_times` rounds, rewinding weights per
    the chosen scheme (lottery-ticket 'lt' / pretrained 'pt').

    NOTE(review): uses module-level `parser`, `best_sa`, and helpers
    (`setup_seed`, `train`, `validate`, `check_sparsity`, pruning utils)
    defined elsewhere in this file.
    """
    global args, best_sa
    args = parser.parse_args()
    print(args)
    torch.cuda.set_device(int(args.gpu))
    os.makedirs(args.save_dir, exist_ok=True)
    if args.seed:
        setup_seed(args.seed)
    # number of output classes depends on the task/dataset
    if args.task == 'rotation':
        print('train for rotation classification')
        class_number = 4
    else:
        print('train for supervised classification')
        if args.dataset == 'cifar10':
            class_number = 10
        elif args.dataset == 'fmnist':
            class_number = 10
        elif args.dataset == 'cifar100':
            class_number = 100
        else:
            print('error dataset')
            assert 0  # NOTE(review): stripped under -O; a raise would be safer
    # prepare dataset
    if args.dataset == 'cifar10':
        print('training on cifar10 dataset')
        model = resnet18(num_classes=class_number)
        model.normalize = NormalizeByChannelMeanStd(
            mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])
        train_loader, val_loader, test_loader = cifar10_dataloaders(
            batch_size=args.batch_size, data_dir=args.data)
    elif args.dataset == 'cifar_10_10':
        print('training on cifar10 subset')
        model = resnet18(num_classes=class_number)
        model.normalize = NormalizeByChannelMeanStd(
            mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])
        train_loader, val_loader, test_loader = cifar10_subset_dataloaders(
            batch_size=args.batch_size, data_dir=args.data)
    elif args.dataset == 'cifar100':
        model = resnet18(num_classes=class_number)
        model.normalize = NormalizeByChannelMeanStd(
            mean=[0.5071, 0.4865, 0.4409], std=[0.2673, 0.2564, 0.2762])
        train_loader, val_loader, test_loader = cifar100_dataloaders(
            batch_size=args.batch_size, data_dir=args.data)
    elif args.dataset == 'fmnist':
        model = resnet18(num_classes=class_number)
        model.normalize = NormalizeByChannelMeanStd(mean=[0.2860],
                                                    std=[0.3530])
        # single-channel input conv for Fashion-MNIST
        model.conv1 = nn.Conv2d(1,
                                16,
                                kernel_size=3,
                                stride=1,
                                padding=1,
                                bias=False)
        train_loader, val_loader, test_loader = fashionmnist_dataloaders(
            batch_size=args.batch_size, data_dir=args.data)
    else:
        # NOTE(review): falls through with `model` unbound -> NameError below
        print('dataset not support')
    model.cuda()
    criterion = nn.CrossEntropyLoss()
    decreasing_lr = list(map(int, args.decreasing_lr.split(',')))
    if args.prune_type == 'lt':
        print('report lottery tickets setting')
        # snapshot of the initial weights for lottery-ticket rewinding
        initalization = deepcopy(model.state_dict())
    else:
        initalization = None
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=decreasing_lr,
                                                     gamma=0.1)
    if args.resume:
        print('resume from checkpoint')
        checkpoint = torch.load(args.resume,
                                map_location=torch.device('cuda:' +
                                                          str(args.gpu)))
        best_sa = checkpoint['best_sa']
        start_epoch = checkpoint['epoch']
        all_result = checkpoint['result']
        start_state = checkpoint['state']
        # re-apply the pruning mask before loading pruned weights
        if start_state > 0:
            current_mask = extract_mask(checkpoint['state_dict'])
            prune_model_custom(model, current_mask)
            check_sparsity(model)
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        initalization = checkpoint['init_weight']
        print('loading state:', start_state)
        print('loading from epoch: ', start_epoch, 'best_sa=', best_sa)
    else:
        all_result = {}
        all_result['train'] = []
        all_result['test_ta'] = []
        all_result['ta'] = []
        start_epoch = 0
        start_state = 0
    print('######################################## Start Standard Training Iterative Pruning ########################################')
    print(model.normalize)
    for state in range(start_state, args.pruning_times):
        print('******************************************')
        print('pruning state', state)
        print('******************************************')
        for epoch in range(start_epoch, args.epochs):
            print(optimizer.state_dict()['param_groups'][0]['lr'])
            check_sparsity(model)
            acc = train(train_loader, model, criterion, optimizer, epoch)
            # evaluate on validation set
            tacc = validate(val_loader, model, criterion)
            # evaluate on test set
            test_tacc = validate(test_loader, model, criterion)
            scheduler.step()
            all_result['train'].append(acc)
            all_result['ta'].append(tacc)
            all_result['test_ta'].append(test_tacc)
            # remember best prec@1 and save checkpoint
            is_best_sa = tacc > best_sa
            best_sa = max(tacc, best_sa)
            save_checkpoint(
                {
                    'state': state,
                    'result': all_result,
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_sa': best_sa,
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'init_weight': initalization
                },
                is_SA_best=is_best_sa,
                pruning=state,
                save_path=args.save_dir)
            # learning-curve snapshot each epoch
            plt.plot(all_result['train'], label='train_acc')
            plt.plot(all_result['ta'], label='val_acc')
            plt.plot(all_result['test_ta'], label='test_acc')
            plt.legend()
            plt.savefig(
                os.path.join(args.save_dir,
                             str(state) + 'net_train.png'))
            plt.close()
        #report result
        check_sparsity(model, True)
        print('report best SA={}'.format(best_sa))
        # reset bookkeeping for the next pruning round
        all_result = {}
        all_result['train'] = []
        all_result['test_ta'] = []
        all_result['ta'] = []
        best_sa = 0
        start_epoch = 0
        if args.prune_type == 'pt':
            print('report loading pretrained weight')
            initalization = torch.load(
                os.path.join(args.save_dir, '0model_SA_best.pth.tar'),
                map_location=torch.device('cuda:' +
                                          str(args.gpu)))['state_dict']
        #pruning_model(model, args.rate)
        #current_mask = extract_mask(model.state_dict())
        #remove_prune(model)
        #rewind weight to init
        pruning_model(model, args.rate)
        check_sparsity(model)
        # NOTE(review): the mask is read from an external checkpoint under
        # args.mask_path rather than from this model's own pruning --
        # confirm the checkpoint contents match the expected mask format.
        current_mask = torch.load(
            os.path.join(args.mask_path,
                         '{}checkpoint.pth.tar'.format(state +
                                                       1)))['state_dict']
        remove_prune(model)
        #rewind weight to init
        model.load_state_dict(initalization)
        prune_model_custom(model, current_mask)
        check_sparsity(model)
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=decreasing_lr, gamma=0.1)
]) trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train) trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2) testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2) classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') os.makedirs('logs', exist_ok=True) # Model MEAN = torch.Tensor([0.4914, 0.4822, 0.4465]) STD = torch.Tensor([0.2023, 0.1994, 0.2010]) norm = NormalizeByChannelMeanStd(mean=MEAN, std=STD) print('==> Building model..') net = torch.nn.DataParallel(nn.Sequential(norm, ResNet18_feat()).cuda()) cudnn.benchmark = True checkpoint = torch.load('./checkpoint/batch_adv0_grad0_lambda_1.0.t7') net.load_state_dict(checkpoint['net']) #freeze last layer # net.module[-1].linear.weight.requires_grad = False # net.module[-1].linear.bias.requires_grad = False if args.resume or args.test: # Load checkpoint. print('==> Resuming from checkpoint..')
def main():
    """Pruning driver trained on pre-computed image results.

    Like the standard IMP loop, but training data comes from a saved
    "results.pth" bundle instead of a DataLoader.

    NOTE(review): `best_sa` is read before any local assignment; it must be
    initialized at module level or this raises NameError on first use.
    """
    global args, best_sa
    args = parser.parse_args()
    print(args)
    torch.cuda.set_device(int(args.gpu))
    os.makedirs(args.save_dir, exist_ok=True)
    if args.seed:
        setup_seed(args.seed)
    _, val_loader, _ = cifar10_dataloaders()
    class_number = 10
    model = resnet18(num_classes=class_number)
    # prepare dataset
    img_results = torch.load("results.pth", map_location="cpu")
    model.normalize = NormalizeByChannelMeanStd(mean=[0.4914, 0.4822, 0.4465],
                                                std=[0.2470, 0.2435, 0.2616])
    model.cuda()
    criterion = nn.CrossEntropyLoss()
    # snapshot of the initial weights for rewinding after each pruning round
    initalization = deepcopy(model.state_dict())
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    all_result = {}
    all_result['train'] = []
    all_result['test_ta'] = []
    all_result['ta'] = []
    start_epoch = 0
    start_state = 0
    print(
        '######################################## Start Standard Training Iterative Pruning ########################################'
    )
    print(model.normalize)
    for state in range(start_state, args.pruning_times):
        print('******************************************')
        print('pruning state', state)
        print('******************************************')
        for epoch in range(start_epoch, args.epochs):
            print(optimizer.state_dict()['param_groups'][0]['lr'])
            check_sparsity(model)
            acc = train(img_results, model, criterion, optimizer, epoch)
            # evaluate on validation set
            tacc = validate(val_loader, model, criterion)
            # evaluate on test set
            all_result['train'].append(acc)
            all_result['ta'].append(tacc)
            # remember best prec@1 and save checkpoint
            is_best_sa = tacc > best_sa
            best_sa = max(tacc, best_sa)
            save_checkpoint(
                {
                    'state': state,
                    'result': all_result,
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_sa': best_sa,
                    'optimizer': optimizer.state_dict(),
                    'init_weight': initalization
                },
                is_SA_best=is_best_sa,
                pruning=state,
                save_path=args.save_dir)
            # NOTE(review): 'test_ta' is never appended in this variant, so
            # the third plot call draws an empty series -- confirm intended.
            plt.plot(all_result['train'], label='train_acc')
            plt.plot(all_result['ta'], label='val_acc')
            plt.plot(all_result['test_ta'], label='test_acc')
            plt.legend()
            plt.savefig(
                os.path.join(args.save_dir,
                             str(state) + 'net_train.png'))
            plt.close()
        #report result
        check_sparsity(model, True)
        print('report best SA={}'.format(best_sa))
        # reset bookkeeping for the next pruning round
        all_result = {}
        all_result['train'] = []
        all_result['test_ta'] = []
        all_result['ta'] = []
        best_sa = 0
        start_epoch = 0
        if args.prune_type == 'pt':
            print('report loading pretrained weight')
            initalization = torch.load(
                os.path.join(args.save_dir, '0model_SA_best.pth.tar'),
                map_location=torch.device('cuda:' +
                                          str(args.gpu)))['state_dict']
        # magnitude pruning, then rewind to the stored initialization
        pruning_model(model, args.rate)
        check_sparsity(model)
        current_mask = extract_mask(model.state_dict())
        remove_prune(model)
        #rewind weight to init
        model.load_state_dict(initalization)
        #pruning using custom mask
        prune_model_custom(model, current_mask)
        check_sparsity(model)
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
def init_advertorch(self, model, device, attack_params, dataset_params):
    """Build the advertorch attack configuration for *model*.

    Stores a normalization layer on `self.normalize`, chains it (and an
    optional BPDA-wrapped preprocessing step) in front of the classifier,
    and returns a dict with the attack class plus its parameters.

    Raises:
        ValueError: for an attack name other than 'pgd', 'cw', 'fgsm'.
    """
    self.normalize = NormalizeByChannelMeanStd(mean=dataset_params['mean'],
                                               std=dataset_params['std'])
    num_classes = dataset_params['num_classes']
    # Attacks operate on raw [0, 1] inputs, so normalization (and, under
    # BPDA, a differentiable substitute for the preprocessing defense)
    # is folded into the attacked model.
    if (attack_params['bpda'] == True):
        preprocess = attack_params['preprocess']
        bpda_wrapped = BPDAWrapper(preprocess,
                                   forwardsub=preprocess.back_approx)
        attack_model = nn.Sequential(self.normalize, bpda_wrapped,
                                     model).to(device)
    else:
        attack_model = nn.Sequential(self.normalize, model).to(device)
    attack_name = attack_params['attack'].lower()
    if (attack_name == 'pgd'):
        # Return attack dictionary
        return {
            'attack': advertorch.attacks.LinfPGDAttack,
            'iterations': attack_params['iterations'],
            'epsilon': attack_params['epsilon'],
            'stepsize': attack_params['stepsize'],
            'model': attack_model,
            'random': attack_params['random']
        }
    if (attack_name == 'cw'):
        # Return attack dictionary
        return {
            'attack': advertorch.attacks.CarliniWagnerL2Attack,
            'iterations': attack_params['iterations'],
            'epsilon': attack_params['epsilon'],
            'model': attack_model,
            'num_classes': num_classes
        }
    if (attack_name == 'fgsm'):
        # FGSM is single-step by definition.
        # Return attack dictionary
        return {
            'attack': advertorch.attacks.FGSM,
            'iterations': 1,
            'epsilon': attack_params['epsilon'],
            'model': attack_model,
            'num_classes': num_classes
        }
    # Raise rather than return an error code; see
    # https://stackoverflow.com/questions/2052390 for rationale.
    raise ValueError("Unsupported attack")
def main():
    """Adversarially train ResNet50 on CIFAR-10 and test the best checkpoints.

    Builds the model with input normalization folded in, splits the CIFAR-10
    training set 90%/10% into train/validation, trains for ``args.epochs``
    epochs while tracking both clean validation accuracy (TA) and adversarial
    validation accuracy (ATA), checkpoints the best model under each metric,
    and finally evaluates both best checkpoints on the test set.
    """
    global args, best_prec1, best_ata
    args = parser.parse_args()
    print(args)

    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)
    torch.cuda.set_device(int(args.gpu))
    setup_seed(args.seed)

    # Normalization lives inside the model so adversarial perturbations are
    # crafted in raw [0, 1] image space.
    model = ResNet50()
    normalize = NormalizeByChannelMeanStd(mean=[0.485, 0.456, 0.406],
                                          std=[0.229, 0.224, 0.225])
    model = nn.Sequential(normalize, model)
    model.cuda()
    cudnn.benchmark = True

    if args.pretrained_model:
        model_dict_pretrain = torch.load(
            args.pretrained_model,
            map_location=torch.device('cuda:' + str(args.gpu)))
        # strict=False: the pretrained dict may not cover the normalize layer.
        model.load_state_dict(model_dict_pretrain, strict=False)
        print('model loaded:', args.pretrained_model)

    # dataset process
    train_trans = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, 4),
        transforms.ToTensor()
    ])
    val_trans = transforms.Compose([transforms.ToTensor()])
    train_dataset = datasets.CIFAR10(args.data, train=True,
                                     transform=train_trans, download=True)
    test_dataset = datasets.CIFAR10(args.data, train=False,
                                    transform=val_trans, download=True)

    # Hold out 10% of the training set as a validation split.
    valid_size = 0.1
    indices = list(range(len(train_dataset)))
    split = int(np.floor(valid_size * len(train_dataset)))
    np.random.shuffle(indices)
    train_idx, valid_idx = indices[split:], indices[:split]
    train_sampler = torch.utils.data.Subset(train_dataset, train_idx)
    valid_sampler = torch.utils.data.Subset(train_dataset, valid_idx)

    train_loader = torch.utils.data.DataLoader(train_sampler,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=2,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(valid_sampler,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=2,
                                             pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=2,
                                              pin_memory=True)

    decreasing_lr = list(map(int, args.decreasing_lr.split(',')))
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=decreasing_lr,
                                                     gamma=0.1)
    print('adv training')

    train_acc = []
    ta = []   # clean validation accuracy per epoch
    ata = []  # adversarial validation accuracy per epoch
    if not os.path.exists(args.save_dir):  # idiomatic (was `is not True`)
        os.mkdir(args.save_dir)

    for epoch in range(args.epochs):
        print(optimizer.state_dict()['param_groups'][0]['lr'])
        acc, loss = train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set, clean then adversarial
        tacc, tloss = validate(val_loader, model, criterion)
        atacc, atloss = validate_adv(val_loader, model, criterion)
        scheduler.step()

        train_acc.append(acc)
        ta.append(tacc)
        ata.append(atacc)

        # remember best prec@1 (clean and adversarial) and save checkpoints
        is_best = tacc > best_prec1
        best_prec1 = max(tacc, best_prec1)
        ata_is_best = atacc > best_ata
        best_ata = max(atacc, best_ata)

        if is_best:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                },
                is_best,
                filename=os.path.join(args.save_dir, 'best_model.pt'))
        if ata_is_best:
            # BUGFIX: this call previously passed `is_best`; the ATA
            # checkpoint must be flagged by `ata_is_best` so it is marked
            # best when only the adversarial accuracy improved.
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    # NOTE(review): stores the clean-accuracy best, as in the
                    # original; only 'state_dict' is read back at test time.
                    'best_prec1': best_prec1,
                },
                ata_is_best,
                filename=os.path.join(args.save_dir, 'ata_best_model.pt'))
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            },
            is_best,
            filename=os.path.join(args.save_dir, 'model.pt'))

        # refresh the accuracy-curve plot every epoch
        plt.plot(train_acc, label='train_acc')
        plt.plot(ta, label='TA')
        plt.plot(ata, label='ATA')
        plt.legend()
        plt.savefig(os.path.join(args.save_dir, 'net_train.png'))
        plt.close()

    best_model_path = os.path.join(args.save_dir, 'ata_best_model.pt')
    print('start testing ATA best model')
    model.load_state_dict(torch.load(best_model_path)['state_dict'])
    tacc, tloss = validate(test_loader, model, criterion)
    atacc, atloss = validate_adv(test_loader, model, criterion)

    best_model_path = os.path.join(args.save_dir, 'best_model.pt')
    print('start testing TA best model')
    model.load_state_dict(torch.load(best_model_path)['state_dict'])
    tacc, tloss = validate(test_loader, model, criterion)
    atacc, atloss = validate_adv(test_loader, model, criterion)
def main():
    """Adversarially train the selfie model (plus patch network P) and test it.

    Builds the chosen dataset with a 90%/10% train/validation split, trains
    ``selfie_model`` and ``P`` jointly with SGD + cosine-annealed LR,
    checkpoints the best model under clean (TA) and adversarial (ATA)
    validation accuracy, and finally evaluates both best checkpoints on the
    test set.
    """
    global args, best_prec1, ata_best_prec1
    args = parser.parse_args()
    print(args)
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)
    torch.cuda.set_device(int(args.gpu))
    setup_seed(args.seed)

    # Data Preprocess
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    data_transforms = {
        'train':
        transforms.Compose([
            transforms.ToPILImage(),
            transforms.Pad(2),
            transforms.RandomCrop(32),
            # transforms.RandomHorizontalFlip(),
            # transforms.RandomRotation(5),
            transforms.ToTensor(),
            # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]),
        'val':
        transforms.Compose([
            transforms.ToPILImage(),
            #transforms.Pad(2),
            #transforms.RandomCrop(32),
            transforms.ToTensor(),
            # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
    }
    if args.dataset == 'cifar':
        train_dataset = CifarDataset(args.data, True, data_transforms['train'],
                                     args.percent)
        test_dataset = CifarDataset(args.data, False, data_transforms['val'],
                                    1)
    elif args.dataset == 'imagenet':
        train_dataset = ImageNetDataset(args.data, True,
                                        data_transforms['train'], args.percent)
        test_dataset = ImageNetDataset(args.data, False,
                                       data_transforms['val'], 1)
    elif args.dataset == 'imagenet224':
        # This branch normalizes inside the transform instead of in the model.
        data_transforms = {
            'train':
            transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225]),
            ]),
            'val':
            transforms.Compose([
                transforms.ToPILImage(),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225]),
            ])
        }
        # NOTE(review): both datasets load the 'train' split here; the test
        # dataset looks like it should load 'val' -- confirm intent.
        train_dataset = datasets.ImageNet(args.data, 'train', True,
                                          data_transforms['train'])
        test_dataset = datasets.ImageNet(args.data, 'train', True,
                                         data_transforms['val'])

    # Hold out 10% of the training set for validation.
    valid_size = 0.1
    indices = list(range(len(train_dataset)))
    split = int(np.floor(valid_size * len(train_dataset)))
    np.random.shuffle(indices)
    train_idx, valid_idx = indices[split:], indices[:split]
    train_sampler = torch.utils.data.Subset(train_dataset, train_idx)
    valid_sampler = torch.utils.data.Subset(train_dataset, valid_idx)
    train_loader = torch.utils.data.DataLoader(train_sampler,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(valid_sampler,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    # define model
    n_split = 4  # passed to get_selfie_model; presumably the patch-split
                 # factor for the selfie pretext task -- TODO confirm
    selfie_model = get_selfie_model(n_split)
    selfie_model = selfie_model.cuda()
    P = get_P_model()
    # Fold normalization into P so it consumes raw [0, 1] images.
    normalize = NormalizeByChannelMeanStd(mean=[0.485, 0.456, 0.406],
                                          std=[0.229, 0.224, 0.225])
    P = nn.Sequential(normalize, P)
    P = P.cuda()

    #define optimizer and scheduler
    # Both networks are optimized jointly with the same LR/weight decay.
    params_list = [
        {'params': selfie_model.parameters(), 'lr': args.lr,
         'weight_decay': args.weight_decay},
    ]
    params_list.append({'params': P.parameters(), 'lr': args.lr,
                        'weight_decay': args.weight_decay})
    optimizer = torch.optim.SGD(params_list, lr=args.lr, momentum=0.9,
                                weight_decay=args.weight_decay, nesterov=True)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: cosine_annealing(
            step,
            args.epochs * len(train_loader),
            1,  # since lr_lambda computes multiplicative factor
            1e-7 / args.lr))

    print("Training model.")
    step = 0
    if os.path.exists(args.modeldir) is not True:
        os.mkdir(args.modeldir)
    stats_ = stats(args.modeldir, args.start_epoch)
    if args.epochs > 0:
        # order of patches: one fixed random 16-patch permutation per
        # evaluation image (400 of them), dumped to disk for reproducibility.
        all_seq = [np.random.permutation(16) for ind in range(400)]
        pickle.dump(
            all_seq,
            open(os.path.join(args.modeldir, 'img_test_seq.pkl'), 'wb'))
        # all_seq=pickle.load(open(os.path.join(args.modeldir, 'img_test_seq.pkl'),'rb'))
    # NOTE(review): if args.epochs == 0, `all_seq` is never defined and the
    # evaluation below raises NameError -- confirm the script is never run
    # with epochs == 0 (or hoist all_seq creation out of the `if`).
    print("Begin selfie training...")
    for epoch in range(args.start_epoch, args.epochs):
        print("The learning rate is {}".format(
            optimizer.param_groups[0]['lr']))
        trainObj, top1 = train_selfie_adv(train_loader, selfie_model, P,
                                          criterion, optimizer, epoch,
                                          scheduler)
        valObj, prec1 = val_selfie(val_loader, selfie_model, P, criterion,
                                   all_seq)
        adv_valObj, adv_prec1 = val_pgd_selfie(val_loader, selfie_model, P,
                                               criterion, all_seq)
        stats_._update(trainObj, top1, valObj, prec1, adv_valObj, adv_prec1)

        # Track best clean (TA) and adversarial (ATA) validation accuracy.
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        ata_is_best = adv_prec1 > ata_best_prec1
        ata_best_prec1 = max(adv_prec1, ata_best_prec1)
        if is_best:
            torch.save(
                {
                    'epoch': epoch,
                    'P_state': P.state_dict(),
                    'selfie_state': selfie_model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_prec1': best_prec1,
                },
                os.path.join(args.modeldir,
                             'adv_selfie_TA_model_best.pth.tar'))
        if ata_is_best:
            torch.save(
                {
                    'epoch': epoch,
                    'P_state': P.state_dict(),
                    'selfie_state': selfie_model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    # NOTE(review): stores the clean-accuracy best rather than
                    # ata_best_prec1 -- looks like a copy-paste; only the
                    # *_state entries are read back below.
                    'best_prec1': best_prec1,
                },
                os.path.join(args.modeldir,
                             'adv_selfie_ATA_model_best.pth.tar'))
        # Rolling checkpoint of the latest epoch, regardless of accuracy.
        torch.save(
            {
                'epoch': epoch,
                'P_state': P.state_dict(),
                'selfie_state': selfie_model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'best_prec1': best_prec1,
            }, os.path.join(args.modeldir, 'adv_selfie_checkpoint.pth.tar'))
        plot_curve(stats_, args.modeldir, True)
        data = stats_
        sio.savemat(os.path.join(args.modeldir, 'stats.mat'), {'data': data})

    print("testing ATA best selfie model from checkpoint...")
    model_path = os.path.join(args.modeldir,
                              'adv_selfie_ATA_model_best.pth.tar')
    model_loaded = torch.load(model_path)
    P.load_state_dict(model_loaded['P_state'])
    selfie_model.load_state_dict(model_loaded['selfie_state'])
    print("Best ATA selfie model loaded! ")
    valObj, prec1 = val_selfie(test_loader, selfie_model, P, criterion,
                               all_seq)
    adv_valObj, adv_prec1 = val_pgd_selfie(test_loader, selfie_model, P,
                                           criterion, all_seq)

    print("testing TA best selfie model from checkpoint...")
    model_path = os.path.join(args.modeldir,
                              'adv_selfie_TA_model_best.pth.tar')
    model_loaded = torch.load(model_path)
    P.load_state_dict(model_loaded['P_state'])
    selfie_model.load_state_dict(model_loaded['selfie_state'])
    print("Best TA selfie model loaded! ")
    valObj, prec1 = val_selfie(test_loader, selfie_model, P, criterion,
                               all_seq)
    adv_valObj, adv_prec1 = val_pgd_selfie(test_loader, selfie_model, P,
                                           criterion, all_seq)
def __init__(
    self,
    clip_values,
    step=200,
    means=None,
    stds=None,
    gan_ckpt='styleGAN.pt',
    encoder_ckpt='encoder.pt',
    optimize_noise=True,
    use_noise_regularize=False,
    use_lpips=False,
    apply_fit=False,
    apply_predict=True,
    mse=500,
    lr_rampup=0.05,
    lr_rampdown=0.05,
    noise=0.05,
    noise_ramp=0.75,
    noise_regularize=1e5,
    lr=0.1
):
    """Set up the InvGAN defense: generator, encoder, and inversion params.

    Loads a StyleGAN generator and an encoder from (possibly S3-hosted)
    checkpoints, moves both to CUDA in eval mode, and estimates the latent
    standard deviation from 10k random style samples for the later
    GAN-inversion optimization.

    Args:
        clip_values: valid input value range, stored as-is.
        step: number of inversion optimization steps.
        means/stds: per-channel normalization; default to identity
            (0 mean, 1 std) and must each have exactly 3 values.
        gan_ckpt/encoder_ckpt: checkpoint names resolved through
            maybe_download_weights_from_s3.
        optimize_noise/use_noise_regularize/use_lpips: inversion options.
        apply_fit/apply_predict: where this preprocessor is applied.
        mse, lr_rampup, lr_rampdown, noise, noise_ramp, noise_regularize,
        lr: optimization hyper-parameters, stored for use by the inversion
            loop elsewhere in the class.

    Raises:
        ValueError: if means or stds does not have exactly 3 values.
    """
    super(InvGAN, self).__init__()
    #print("invgan")
    #pdb.set_trace()
    self._apply_fit = apply_fit
    self._apply_predict = apply_predict
    # setup normalization parameters
    if means is None:
        means = (0.0, 0.0, 0.0)  # identity operation
    if len(means) != 3:
        raise ValueError("means must have 3 values, one per channel")
    self.means = means
    if stds is None:
        stds = (1.0, 1.0, 1.0)  # identity operation
    if len(stds) != 3:
        raise ValueError("stds must have 3 values, one per channel")
    self.stds = stds
    self.clip_values = clip_values
    # setup optimization parameters (consumed by the inversion routine)
    self.optimize_noise = optimize_noise
    self.use_noise_regularize = use_noise_regularize
    self.use_lpips = use_lpips
    self.step = step
    self.mse = mse
    self.lr = lr
    self.lr_rampup = lr_rampup
    self.lr_rampdown = lr_rampdown
    self.noise = noise
    self.noise_ramp = noise_ramp
    self.noise_regularize = noise_regularize
    # setup generator
    # Generator(256, 512, 8): presumably image size, latent dim, and MLP
    # depth of the StyleGAN style network -- TODO confirm against Generator.
    self.generator = Generator(256, 512, 8)
    #self.generator.load_state_dict(torch.load(gan_ckpt)['g_ema'])
    self.generator.load_state_dict(
        torch.load(maybe_download_weights_from_s3(gan_ckpt))['g_ema'])
    self.generator.eval()
    self.generator.cuda()
    # Maps generator output from [-1, 1] back into [0, 1]:
    # (x - (-1)) / 2.
    self.deprocess_layer = NormalizeByChannelMeanStd([-1., -1., -1.],
                                                     [2., 2., 2.]).cuda()
    # setup encoder
    self.encoder = Encoder()
    #self.encoder.load_state_dict(torch.load(encoder_ckpt)['netE'])
    self.encoder.load_state_dict(
        torch.load(maybe_download_weights_from_s3(encoder_ckpt))['netE'])
    self.encoder.eval()
    self.encoder.cuda()
    # setup loss (perceptual loss only when requested)
    if use_lpips:
        self.lpips = PerceptualLoss().cuda()
    # estimate latent code statistics: sample 10k style vectors and compute
    # the scalar std of the style distribution around its mean.
    n_mean_latent = 10000
    with torch.no_grad():
        noise_sample = torch.randn(n_mean_latent, 512, device='cuda')
        latent_out = self.generator.style(noise_sample)
        latent_mean = latent_out.mean(0)
        self.latent_std = ((latent_out - latent_mean).pow(2).sum() /
                           n_mean_latent) ** 0.5
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
             groups=1, width_per_group=64, replace_stride_with_dilation=None,
             norm_layer=None):
    """Construct a ResNet adapted for small (CIFAR-size) inputs.

    Differences from the stock torchvision ResNet: a built-in
    NormalizeByChannelMeanStd input layer (CIFAR-10 statistics), a 3x3
    stride-1 stem convolution instead of 7x7 stride-2, and an identity in
    place of the stem max-pool.

    Args:
        block: residual block class (BasicBlock or Bottleneck).
        layers: number of blocks in each of the four stages.
        num_classes: size of the final linear layer.
        zero_init_residual: zero-init the last BN of each residual branch.
        groups/width_per_group: grouped-convolution configuration.
        replace_stride_with_dilation: per-stage flag, stride -> dilation.
        norm_layer: normalization layer class; defaults to BatchNorm2d.

    Raises:
        ValueError: if replace_stride_with_dilation is not None or a
            3-element sequence.
    """
    super(ResNet, self).__init__()
    if norm_layer is None:
        norm_layer = nn.BatchNorm2d
    self._norm_layer = norm_layer
    self.inplanes = 64
    self.dilation = 1
    if replace_stride_with_dilation is None:
        # each element in the tuple indicates if we should replace
        # the 2x2 stride with a dilated convolution instead
        replace_stride_with_dilation = [False, False, False]
    if len(replace_stride_with_dilation) != 3:
        raise ValueError("replace_stride_with_dilation should be None "
                         "or a 3-element tuple, got {}".format(
                             replace_stride_with_dilation))
    self.groups = groups
    self.base_width = width_per_group
    # Input normalization baked into the model (CIFAR-10 mean/std), so the
    # network accepts raw [0, 1] images.
    self.normalize = NormalizeByChannelMeanStd(
        mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])
    # CIFAR stem: 3x3 stride-1 conv, no max-pool (nn.Identity below).
    self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1,
                           padding=1, bias=False)
    self.bn1 = norm_layer(self.inplanes)
    self.relu = nn.ReLU(inplace=True)
    self.maxpool = nn.Identity()
    self.layer1 = self._make_layer(block, 64, layers[0])
    self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                   dilate=replace_stride_with_dilation[0])
    self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                   dilate=replace_stride_with_dilation[1])
    self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                   dilate=replace_stride_with_dilation[2])
    self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
    self.fc = nn.Linear(512 * block.expansion, num_classes)

    for m in self.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out',
                                    nonlinearity='relu')
        elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

    # Zero-initialize the last BN in each residual branch,
    # so that the residual branch starts with zeros, and each residual block behaves like an identity.
    # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
    if zero_init_residual:
        for m in self.modules():
            if isinstance(m, Bottleneck):
                nn.init.constant_(m.bn3.weight, 0)
            elif isinstance(m, BasicBlock):
                nn.init.constant_(m.bn2.weight, 0)