def __init__(self, attack_model, attack_method="FGSM", eps=0.3, data_type="test", rand_seed=0, rand_min=0, rand_max=1, loader_batch=128, for_trainning=False, atk_loss=None, quantize=False):
    """Attack every batch of the clean data set and wrap the result in a loader.

    Args:
        attack_model: Model handed to the attack object as the victim.
        attack_method: ``"FGSM"`` or ``"PGD"``.
        eps: Perturbation budget passed to the attack. For FGSM, a ``str``
            value switches to a per-sample scale drawn uniformly from
            ``[0, 1)`` that multiplies the perturbation returned by
            ``FGSM(..., getAtkpn=True)``.
        data_type: Forwarded to ``normalMnist``.
        rand_seed, rand_min, rand_max: Accepted for interface compatibility;
            not read by this constructor.
        loader_batch: Batch size for both the clean and the produced loader.
        for_trainning: When true the loader yields items built by
            ``train_dataSet(clean, labels, adversarial)``; otherwise by
            ``dataSet(adversarial, labels)``.
        atk_loss: Loss function forwarded to the attack.
        quantize: When true the adversarial data is truncated to 1/255 steps.

    Raises:
        ValueError: If ``attack_method`` is neither ``"FGSM"`` nor ``"PGD"``.
    """
    # Fail early: previously an unknown method left `batch_atk` unbound and
    # surfaced as a NameError inside the batch loop.
    if attack_method not in ("FGSM", "PGD"):
        raise ValueError("unsupported attack_method: {!r}".format(attack_method))

    normal_data = normalMnist(data_type=data_type, loader_batch=loader_batch)
    self.noarmal_data = normal_data.data
    self.labels = normal_data.labels

    adv_batches = []
    for batch_data, batch_labels in normal_data.loader:
        batch_data = batch_data.to(device)
        batch_labels = batch_labels.to(device)
        if attack_method == "FGSM":
            if isinstance(eps, str):
                # Random-strength FGSM: scale the raw perturbation per sample.
                batch_pn = FGSM(attack_model, loss_fn=atk_loss, getAtkpn=True).perturb(batch_data, batch_labels)
                eps_temp = torch.rand((len(batch_pn), 1, 1, 1)).to(device)
                batch_atk = torch.clamp(batch_data + eps_temp * batch_pn, min=0, max=1)
            else:
                batch_atk = FGSM(attack_model, loss_fn=atk_loss, eps=eps).perturb(batch_data, batch_labels)
        else:
            batch_atk = PGDAttack(attack_model, loss_fn=atk_loss, eps=eps).perturb(batch_data, batch_labels)
        adv_batches.append(batch_atk)

    # Concatenate once instead of re-copying the growing tensor every batch.
    if adv_batches:
        x_atk = torch.cat(adv_batches)
    else:
        x_atk = torch.tensor([]).to(device)
    self.data = x_atk.cpu()

    if quantize:
        self.data = (self.data * 255).type(torch.int) / 255.

    if for_trainning:
        self.loader = torch.utils.data.DataLoader(train_dataSet(self.noarmal_data, self.labels, self.data), batch_size=loader_batch)
    else:
        self.loader = torch.utils.data.DataLoader(dataSet(self.data, self.labels), batch_size=loader_batch)
def test(epoch, is_adv=False):
    """Evaluate the global ``net`` on ``test_loader``; checkpoint on a new best.

    Args:
        epoch: Epoch index, forwarded to ``checkpoint`` when accuracy improves.
        is_adv: When true, each batch is first perturbed by an untargeted PGD
            attack configured from ``args.eps`` and ``args.iter``.

    Returns:
        ``(mean_loss, accuracy_percent)``; the mean is taken over the number
        of batches processed (``(0., 0.)`` for an empty loader).
    """
    global is_training, best_acc
    is_training = False
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    num_batches = 0
    # Gradients stay enabled here: the PGD attack differentiates w.r.t. inputs.
    for batch_idx, (inputs, targets) in enumerate(test_loader):
        num_batches = batch_idx + 1
        if use_cuda:
            inputs, targets = inputs.requires_grad_().cuda(), targets.cuda()
        if is_adv:
            adversary = PGDAttack(net,
                                  loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                                  eps=args.eps,
                                  nb_iter=args.iter,
                                  eps_iter=args.eps / args.iter,
                                  rand_init=True,
                                  clip_min=0.0,
                                  clip_max=1.0,
                                  targeted=False)
            with ctx_noparamgrad_and_eval(net):
                inputs = adversary.perturb(inputs, targets)

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)

        test_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += predicted.eq(targets.data).cpu().sum().item()

        progress_bar(
            batch_idx, len(test_loader),
            'Loss: %.3f | Acc: %.3f%% (%d/%d)' %
            (test_loss / num_batches, 100. * correct / total, correct, total))

    # Save checkpoint.
    acc = 100. * correct / total if total else 0.
    if acc > best_acc:
        best_acc = acc
        checkpoint(acc, epoch)
    # Divide by the batch count (not the last index), matching the progress
    # bar; the old `test_loss / batch_idx` was off by one and divided by zero
    # for a single-batch loader.
    mean_loss = test_loss / num_batches if num_batches else 0.
    return (mean_loss, acc)
def train(epoch):
    """Run one training epoch of the global ``net`` over ``train_loader``.

    When ``args.adv_train`` is set, every batch is replaced by its untargeted
    PGD adversarial counterpart (budget ``args.eps``, ``args.iter`` steps)
    before the optimisation step.

    Args:
        epoch: Epoch index, used only for the progress message.

    Returns:
        ``(mean_loss, accuracy_percent)``; the mean is taken over the number
        of batches processed (``(0., 0.)`` for an empty loader).
    """
    global is_training
    is_training = True
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    num_batches = 0
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        num_batches = batch_idx + 1
        if use_cuda:
            inputs, targets = inputs.cuda().requires_grad_(), targets.cuda()

        # generate adv img
        if args.adv_train:
            adversary = PGDAttack(net,
                                  loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                                  eps=args.eps,
                                  nb_iter=args.iter,
                                  eps_iter=args.eps / args.iter,
                                  rand_init=True,
                                  clip_min=0.0,
                                  clip_max=1.0,
                                  targeted=False)
            # The attack runs with parameter gradients off and the net in
            # eval mode; the context manager restores both afterwards.
            with ctx_noparamgrad_and_eval(net):
                inputs = adversary.perturb(inputs, targets)

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += predicted.eq(targets.data).cpu().sum().item()

        progress_bar(
            batch_idx, len(train_loader),
            'Loss: %.3f | Acc: %.3f%% (%d/%d)' %
            (train_loss / num_batches, 100. * correct / total, correct, total))

    # Divide by the batch count (not the last index), matching the progress
    # bar; the old `train_loss / batch_idx` was off by one and divided by zero
    # for a single-batch loader.
    mean_loss = train_loss / num_batches if num_batches else 0.
    accuracy = 100. * correct / total if total else 0.
    return (mean_loss, accuracy)
def test_adver(net, tar_net, attack, target):
    """Craft adversarial examples on ``net`` and measure transfer to ``tar_net``.

    First prints the clean accuracy of ``net`` on the global ``testloader``,
    then attacks the samples selected below and prints the attack success
    rate on ``tar_net`` plus the mean L2 distance of the perturbations.

    The per-sample ``if`` on a tensor comparison means ``testloader`` must
    yield batches of size 1.

    Args:
        net: Model the adversary differentiates through.
        tar_net: Model whose predictions on the adversarial inputs are scored.
        attack: One of ``'BIM'``, ``'PGD'``, ``'FGSM'``, ``'CW'``.
        target: Truthy for a targeted evaluation (a random class is drawn per
            sample and only samples not already predicted as that class are
            attacked); falsy for untargeted (only correctly classified
            samples are attacked).

    Raises:
        ValueError: If ``attack`` is not one of the supported names.
    """
    num_classes = 10
    net.eval()
    tar_net.eval()

    # NOTE(review): the adversary's `targeted` flag comes from the global
    # `opt.target`, while sample selection uses the `target` argument —
    # callers are presumably expected to pass the same value; confirm.
    if attack == 'BIM':
        adversary = LinfBasicIterativeAttack(
            net,
            loss_fn=nn.CrossEntropyLoss(reduction="sum"),
            eps=0.25,
            nb_iter=120,
            eps_iter=0.02,
            clip_min=0.0,
            clip_max=1.0,
            targeted=opt.target)
    elif attack == 'PGD':
        # Targeted PGD gets more iterations than untargeted.
        adversary = PGDAttack(
            net,
            loss_fn=nn.CrossEntropyLoss(reduction="sum"),
            eps=0.25,
            nb_iter=11 if opt.target else 6,
            eps_iter=0.03,
            clip_min=0.0,
            clip_max=1.0,
            targeted=opt.target)
    elif attack == 'FGSM':
        adversary = GradientSignAttack(
            net,
            loss_fn=nn.CrossEntropyLoss(reduction="sum"),
            eps=0.26,
            targeted=opt.target)
    elif attack == 'CW':
        adversary = CarliniWagnerL2Attack(
            net,
            num_classes=num_classes,
            learning_rate=0.45,
            binary_search_steps=10,
            max_iterations=12,
            targeted=opt.target)
    else:
        # Previously an unknown name left `adversary` unbound.
        raise ValueError("unsupported attack: {!r}".format(attack))

    # ----------------------------------
    # Obtain the accuracy of the model
    # ----------------------------------
    with torch.no_grad():
        correct_netD = 0
        total = 0
        net.eval()
        for inputs, labels in testloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = net(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct_netD += (predicted == labels).sum().item()
        clean_acc = 100. * correct_netD / total if total else 0.
        print('Accuracy of the network on netD: %.2f %%' % clean_acc)

    # ----------------------------------
    # Obtain the attack success rate of the model
    # ----------------------------------
    correct = 0
    total = 0
    tar_net.eval()
    total_L2_distance = 0.0
    for inputs, labels in testloader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        with torch.no_grad():
            outputs = tar_net(inputs)
        _, predicted = torch.max(outputs.data, 1)

        if target:
            # Randomly choose the target label. `high` is exclusive, so it
            # must be num_classes for the last class to be reachable.
            labels = torch.randint(0, num_classes, (1, )).to(device)
            # Attack only images not already classified as the target.
            attack_this = predicted != labels
        else:
            # Attack only images that are classified correctly.
            attack_this = predicted == labels
        if not attack_this:
            continue

        adv_inputs_ori = adversary.perturb(inputs, labels)
        total_L2_distance += (torch.norm(adv_inputs_ori - inputs)).item()
        with torch.no_grad():
            outputs = tar_net(adv_inputs_ori)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        # Nothing qualified for the attack; the rates below are undefined.
        print('No samples were attacked; attack success rate is undefined.')
        return

    if target:
        print('Attack success rate: %.2f %%' % (100. * correct / total))
    else:
        print('Attack success rate: %.2f %%' %
              (100.0 - 100. * correct / total))
    print('l2 distance:  %.4f ' % (total_L2_distance / total))
# The loaded checkpoint's 'truncation.truncation' entry is overwritten with
# the value currently held by `gan` before the state dict is loaded.
var_name = 'truncation.truncation'
state_dict[var_name] = gan.state_dict()[var_name]
gan.load_state_dict(state_dict)

# reduce memory consumption
# Only the synthesis sub-module is kept, frozen, and moved to the GPU.
gan = gan.synthesis
for p in gan.parameters():
    p.requires_grad_(False)
gan.cuda()
gan.eval()
# `model` feeds the output of `gan` into `net`.
model = torch.nn.Sequential(gan, net)
pgd_iters = [10, 50]
# NOTE(review): nb_iter is fixed at 50 here although the loop below iterates
# over pgd_iters — presumably nb_iter is overridden per iteration further
# down; confirm against the rest of the loop body.
test_attacker = PGDAttack(predict=net, eps=8.0 / 255, eps_iter=2.0 / 255, nb_iter=50, clip_min=-1.0, clip_max=1.0)

# set dataset, dataloader
dataset = get_dataset(cfg)
transform = get_transform(cfg)
testset = dataset(root=testset_cfg.path, train=False)
testloader = DataLoader(testset, batch_size=10, num_workers=4, shuffle=False)

# One evaluation pass per PGD iteration count, each with fresh meters.
for nb_iter in pgd_iters:
    acc_clean = AverageMeter()
    acc_adv_image = AverageMeter()
    acc_adv_latent = AverageMeter()
    progress_bar = tqdm(testloader)
def main():
    """Evaluate a trained model on clean and PGD-adversarial validation data.

    Reads all settings from ``Parser(train=False).get()``, builds the model
    selected by ``opt.arch``, loads ``opt.weight_path`` (falling back to
    ``convert_model_from_parallel`` when direct loading fails), then prints
    loss / top-1 / top-5 for both the standard and the adversarial loader.

    Raises:
        NotImplementedError: For an unsupported ``opt.arch`` or ``opt.attack``.
    """
    opt = Parser(train=False).get()

    # dataset and data loader
    _, val_loader, adv_val_loader, _, num_classes = load_dataset(
        opt.dataset, opt.batch_size, opt.data_root, False, 0.0,
        opt.num_val_samples, workers=4)

    # model
    if opt.arch == 'lenet':
        model = LeNet(num_classes)
    elif opt.arch == 'resnet':
        model = ResNetv2_20(num_classes)
    else:
        raise NotImplementedError(
            'unsupported arch: {!r}'.format(opt.arch))

    # move model to device
    model.to(opt.device)

    # load trained weight
    try:
        model.load_state_dict(torch.load(opt.weight_path))
    except Exception:
        # Same fallback as before, but no longer a bare `except:` so that
        # KeyboardInterrupt / SystemExit are not swallowed.
        model_weight = convert_model_from_parallel(opt.weight_path)
        model.load_state_dict(model_weight)

    # criterion
    criterion = nn.CrossEntropyLoss()

    # advertorch attacker
    if opt.attack == 'pgd':
        # eps and eps_iter are divided by 255 before being passed on.
        attacker = PGDAttack(model,
                             loss_fn=criterion,
                             eps=opt.eps / 255,
                             nb_iter=opt.num_steps,
                             eps_iter=opt.eps_iter / 255,
                             rand_init=True,
                             clip_min=opt.clip_min,
                             clip_max=opt.clip_max,
                             ord=np.inf,
                             targeted=False)
    else:
        raise NotImplementedError(
            'unsupported attack: {!r}'.format(opt.attack))

    # trainer
    trainer = Trainer(opt, model, criterion, attacker)
    trainer.print_freq = -1

    # validation
    val_losses, val_acc1s, val_acc5s = trainer.validate(val_loader)
    aval_losses, aval_acc1s, aval_acc5s = trainer.adv_validate(adv_val_loader)

    print('[model] {}'.format(opt.weight_path))
    print('[standard]\n'
          'loss: {:.4f} | acc1: {:.2f}% | acc5: {:.2f}%'
          '\n[adversarial]\n'
          'loss: {:.4f} | acc1: {:.2f}% | acc5: {:.2f}%'.format(
              val_losses['val'].avg, val_acc1s['val'].avg,
              val_acc5s['val'].avg, aval_losses['aval'].avg,
              aval_acc1s['aval'].avg, aval_acc5s['aval'].avg))
# load classifier predict = get_classifier(cfg, cfg.classifier) state_dict = torch.load(cfg.classifier.ckpt) try: predict.load_state_dict(state_dict) except: predict.load_state_dict(state_dict["state_dict"]) for p in predict.parameters(): p.requires_grad_(False) predict = torch.nn.Sequential(transform.classifier_preprocess_layer, predict).cuda() predict.eval() # create attacker attacker = PGDAttack(predict=predict, eps=args.eps / 255.0, eps_iter=1 / 255.0, nb_iter=args.eps + 4) total = 0 correct_clean = 0 correct_adv = 0 correct_def = 0 for i, (images, labels) in enumerate(progress_bar): if i < start_ind or i >= end_ind: continue images, labels = images.cuda(), labels.cuda() result_path = os.path.join(result_dir, 'batch_{:04d}.pt'.format(i)) if os.path.isfile(result_path): result_dict = torch.load(result_path) images_adv = result_dict['input'].cuda()
print(f'\tloaded model at {model_path}') except: print(f'CANNOT LOAD MODEL AT: {model_path}') continue net.cuda() net.eval() # instantiate adversary if args.attack_type == 'PGD': adversary = PGDAttack( predict=net, loss_fn=F.cross_entropy, eps=attack_configs['eps'], nb_iter=attack_configs['nb_iter'], eps_iter=attack_configs['eps_iter'], rand_init=True, clip_min=0., clip_max=1., ord=attack_configs['ord'], targeted=False) for h in range(n_repeat): print(f'\tAttacking {seed}. ({h+1}/{n_repeat})') test() net.cpu() print('') # get adversarial test and standard accuracy adv_ci = utils.mean_confidence_interval(adv_test_accuracies) std_ci = utils.mean_confidence_interval(test_accuracies)
def train_Ours(args, train_loader, val_loader, knownclass, Encoder, Decoder, NorClsfier, SSDClsfier, summary_writer, saver):
    """Adversarially train encoder, decoder and the two classifier heads.

    Each step attacks the clean batch through ``Encoder -> NorClsfier`` and
    the rotated batch through ``Encoder -> SSDClsfier``, then minimises a
    weighted sum of the two classification losses and the reconstruction
    loss between ``Decoder(Encoder(adv))`` and the clean batch.

    Every ``args.val_epoch`` epochs the adversarial validation accuracy and
    loss are printed and written to ``summary_writer``; every
    ``args.model_save_epoch`` epochs snapshots are saved; final weights are
    saved after the last epoch.

    Args:
        args: Namespace providing ``manual_seed``, ``parallel_train``, ``lr``,
            ``adv``, ``adv_iter``, ``n_epoch``, the loss weights, logging and
            saving intervals, and the path components for snapshots.
        train_loader: Yields ``(orig, label, rot_orig, rot_label)``.
        val_loader: Yields ``(images, label)``.
        knownclass: Passed to ``lab_conv`` to convert labels.
        Encoder, Decoder, NorClsfier, SSDClsfier: Networks to train.
        summary_writer: Receives validation scalars and image grids.
        saver: Receives periodic loss printouts.

    Raises:
        ValueError: If ``args.adv`` is neither ``'PGDattack'`` nor
            ``'FGSMattack'``.
    """
    init_random_seed(args.manual_seed)

    criterionCls = nn.CrossEntropyLoss()
    criterionRec = nn.MSELoss()

    if args.parallel_train:
        Encoder = DataParallel(Encoder)
        Decoder = DataParallel(Decoder)
        NorClsfier = DataParallel(NorClsfier)
        SSDClsfier = DataParallel(SSDClsfier)

    optimizer = optim.Adam(
        list(Encoder.parameters()) + list(NorClsfier.parameters()) +
        list(SSDClsfier.parameters()) + list(Decoder.parameters()),
        lr=args.lr)

    # Compare by value: `is` against a string literal tests object identity
    # and can be False for an equal string coming from the command line.
    if args.adv == 'PGDattack':
        print("**********Defense PGD Attack**********")
        from advertorch.attacks import PGDAttack
        nor_adversary = PGDAttack(predict1=Encoder, predict2=NorClsfier, nb_iter=args.adv_iter)
        rot_adversary = PGDAttack(predict1=Encoder, predict2=SSDClsfier, nb_iter=args.adv_iter)
    elif args.adv == 'FGSMattack':
        print("**********Defense FGSM Attack**********")
        from advertorch.attacks import GradientSignAttack
        nor_adversary = GradientSignAttack(predict1=Encoder, predict2=NorClsfier)
        rot_adversary = GradientSignAttack(predict1=Encoder, predict2=SSDClsfier)
    else:
        raise ValueError("unsupported args.adv: {!r}".format(args.adv))

    # Bound up front so the final save works even when no periodic snapshot
    # was taken (previously the name only existed inside that branch).
    model_save_path = os.path.join(args.results_path, args.training_type, 'snapshots',
                                   args.datasetname + '-' + args.split, args.denoisemean,
                                   args.adv + str(args.adv_iter))

    # ----------
    #  Training
    # ----------
    for epoch in range(args.n_epoch):

        Encoder.train()
        Decoder.train()
        NorClsfier.train()
        SSDClsfier.train()

        for steps, (orig, label, rot_orig, rot_label) in enumerate(train_loader):

            label = lab_conv(knownclass, label)
            orig, label = orig.cuda(), label.long().cuda()
            rot_orig, rot_label = rot_orig.cuda(), rot_label.long().cuda()

            # Craft both adversarial batches with the networks temporarily in
            # eval mode and parameter gradients disabled.
            with ctx_noparamgrad_and_eval(Encoder), \
                    ctx_noparamgrad_and_eval(NorClsfier), \
                    ctx_noparamgrad_and_eval(SSDClsfier):
                adv = nor_adversary.perturb(orig, label)
                rot_adv = rot_adversary.perturb(rot_orig, rot_label)

            latent_feat = Encoder(adv)
            norpred = NorClsfier(latent_feat)
            norlossCls = criterionCls(norpred, label)

            recon = Decoder(latent_feat)
            lossRec = criterionRec(recon, orig)

            ssdpred = SSDClsfier(Encoder(rot_adv))
            rotlossCls = criterionCls(ssdpred, rot_label)

            loss = args.norClsWgt * norlossCls + args.rotClsWgt * rotlossCls + args.RecWgt * lossRec

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            #============ print the log info ============#
            if (steps + 1) % args.log_step == 0:
                errors = OrderedDict([
                    ('loss', loss.item()),
                    ('norlossCls', norlossCls.item()),
                    ('lossRec', lossRec.item()),
                    ('rotlossCls', rotlossCls.item()),
                ])
                saver.print_current_errors((epoch + 1), (steps + 1), errors)

        # evaluate performance on validation set periodically
        if (epoch + 1) % args.val_epoch == 0:

            # switch model to evaluation mode
            Encoder.eval()
            NorClsfier.eval()

            running_corrects = 0.0
            epoch_size = 0.0
            val_loss_list = []

            # calculate accuracy on validation set
            for images, label in val_loader:

                label = lab_conv(knownclass, label)
                images, label = images.cuda(), label.long().cuda()

                adv = nor_adversary.perturb(images, label)

                with torch.no_grad():
                    logits = NorClsfier(Encoder(adv))
                    _, preds = torch.max(logits, 1)
                    running_corrects += torch.sum(preds == label.data)
                    epoch_size += images.size(0)

                    val_loss = criterionCls(logits, label)
                    val_loss_list.append(val_loss.item())

            val_loss_mean = sum(val_loss_list) / len(val_loss_list)
            val_acc = running_corrects.double() / epoch_size
            print('Val Acc: {:.4f}, Val Loss: {:.4f}'.format(
                val_acc, val_loss_mean))

            # Log the same mean loss that is printed above; the previous code
            # logged only the last batch's loss under 'Val Loss'.
            valinfo = {
                'Val Acc': val_acc.item(),
                'Val Loss': val_loss_mean,
            }
            for tag, value in valinfo.items():
                summary_writer.add_scalar(tag, value, (epoch + 1))

            # Grids are built from the last training batch of this epoch.
            orig_show = vutils.make_grid(orig, normalize=True, scale_each=True)
            recon_show = vutils.make_grid(recon, normalize=True, scale_each=True)

            summary_writer.add_image('Ori_Image', orig_show, (epoch + 1))
            summary_writer.add_image('Rec_Image', recon_show, (epoch + 1))

        if (epoch + 1) % args.model_save_epoch == 0:
            mkdir(model_save_path)
            torch.save(
                Encoder.state_dict(),
                os.path.join(model_save_path, "Encoder-{}.pt".format(epoch + 1)))
            torch.save(
                NorClsfier.state_dict(),
                os.path.join(model_save_path, "NorClsfier-{}.pt".format(epoch + 1)))
            torch.save(
                Decoder.state_dict(),
                os.path.join(model_save_path, "Decoder-{}.pt".format(epoch + 1)))

    mkdir(model_save_path)
    torch.save(Encoder.state_dict(), os.path.join(model_save_path, "Encoder-final.pt"))
    torch.save(NorClsfier.state_dict(), os.path.join(model_save_path, "NorClsfier-final.pt"))
    torch.save(Decoder.state_dict(), os.path.join(model_save_path, "Decoder-final.pt"))
# Pick the attack by norm: L2 PGD with random start when args.L2PGD is set,
# otherwise the default PGDAttack without random start. Both clip to
# [0, 255], i.e. inputs are on the 0-255 scale here.
if args.L2PGD:
    adversary = L2PGDAttack(net, loss_fn=csl, eps=args.eps, nb_iter=10, eps_iter=25., rand_init=True, clip_min=0.0, clip_max=255., targeted=False)
else:
    adversary = PGDAttack(net, loss_fn=csl, eps=args.eps, nb_iter=10, eps_iter=1.75, rand_init=False, clip_min=0.0, clip_max=255., targeted=False)

net.eval()
correct = 0
total = 0
# NOTE(review): mid-script import of project helpers; `get_test` is not used
# in the visible lines — presumably consumed by the evaluation loop below.
from dataload_test import load_test_list, get_test
test_num = load_test_list()
iters = 100
print(iters)
# instantiate model, adversary, optimizer net = instantiate_model(args) net.cuda() print( f'\nTraining {args.model_name} model with {train_types[args.train_type]}.' ) # instantiate adversary if args.train_type == 'at': print("\tTraining against:", attack_configs) adversary = PGDAttack(predict=net, loss_fn=F.cross_entropy, eps=attack_configs['eps'], nb_iter=attack_configs['nb_iter'], eps_iter=attack_configs['eps_iter'], rand_init=True, clip_min=0., clip_max=1., ord=attack_configs['ord'], targeted=False) elif args.train_type == 'jr': print( f'\tTraining with {"approximated" if bool(args.jr_approx) else "full"} Jacobian norm.' ) optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.wd, nesterov=args.nesterov)
def precalc_weibull(args, dataloader_train, knownclass, Encoder, NorClsfier):
    """Fit one Weibull model per class on adversarial training logits.

    Every training batch is attacked, passed through
    ``NorClsfier(Encoder(.))`` and the logits of correctly classified samples
    are grouped by label. For each class the mean logit vector is computed
    and a ``libmr.MR`` object is fitted (``fit_high``) to the Euclidean
    distances between that class's logits and its mean.

    Args:
        args: Namespace providing ``adv`` (``'PGDattack'`` or
            ``'FGSMattack'``) and ``adv_iter``.
        dataloader_train: Yields 4-tuples whose first two items are images
            and labels.
        knownclass: Passed to ``lab_conv`` to convert labels.
        Encoder, NorClsfier: Networks producing the logits.

    Returns:
        ``(activation_vectors, mean_activation_vectors, weibulls)``, three
        dicts keyed by class label.

    Raises:
        ValueError: If ``args.adv`` names an unsupported attack.
    """
    # First generate pre-softmax 'activation vectors' for all training examples
    print(
        "Weibull: computing features for all correctly-classified training data"
    )
    activation_vectors = {}

    # Compare by value: `is` against a string literal tests object identity
    # and can be False for an equal string coming from the command line.
    if args.adv == 'PGDattack':
        from advertorch.attacks import PGDAttack
        adversary = PGDAttack(predict1=Encoder, predict2=NorClsfier, nb_iter=args.adv_iter)
    elif args.adv == 'FGSMattack':
        from advertorch.attacks import FGSM
        adversary = FGSM(predict1=Encoder, predict2=NorClsfier)
    else:
        raise ValueError("unsupported args.adv: {!r}".format(args.adv))

    for images, labels, _, _ in dataloader_train:

        labels = lab_conv(knownclass, labels)
        images, labels = images.cuda(), labels.long().cuda()

        print("**********Conduct Attack**********")
        advimg = adversary.perturb(images, labels)
        with torch.no_grad():
            logits = NorClsfier(Encoder(advimg))

        # Move everything to the host once instead of testing a device
        # tensor element per sample.
        correctly_labeled = (logits.data.max(1)[1] == labels).cpu().numpy()
        labels_np = labels.cpu().numpy()
        logits_np = logits.data.cpu().numpy()
        for i, label in enumerate(labels_np):
            if not correctly_labeled[i]:
                continue
            # If correctly labeled, add this to the list of activation_vectors for this class
            activation_vectors.setdefault(label, []).append(logits_np[i])
    print("Computed activation_vectors for {} known classes".format(
        len(activation_vectors)))
    for class_idx in activation_vectors:
        print("Class {}: {} images".format(class_idx, len(activation_vectors[class_idx])))

    # Compute a mean activation vector for each class
    print("Weibull computing mean activation vectors...")
    mean_activation_vectors = {
        class_idx: np.array(vectors).mean(axis=0)
        for class_idx, vectors in activation_vectors.items()
    }

    # Initialize one libMR Weibull object for each class
    print("Fitting Weibull to distance distribution of each class")
    weibulls = {}
    for class_idx, vectors in activation_vectors.items():
        mav = mean_activation_vectors[class_idx]
        distances = [np.linalg.norm(v - mav) for v in vectors]
        mr = libmr.MR()
        # The tail cannot be larger than the number of samples available.
        tail_size = min(len(distances), WEIBULL_TAIL_SIZE)
        mr.fit_high(distances, tail_size)
        weibulls[class_idx] = mr
        print("Weibull params for class {}: {}".format(class_idx, mr.get_params()))

    return activation_vectors, mean_activation_vectors, weibulls
def openset_weibull(args, dataloader_test, knownclass, Encoder, NorClsfier, activation_vectors, mean_activation_vectors, weibulls, mode='openset'):
    """Score adversarial test samples with per-class Weibull-weighted logits.

    Each batch is attacked (with labels in ``'closeset'`` mode, without them
    otherwise) and passed through ``NorClsfier(Encoder(.))``. For every logit
    vector and every class in ``activation_vectors`` the weight is
    ``1 - w_score(distance to the class mean)``; classes without a fitted
    model keep weight 1. The returned score per sample is
    ``-log(sum(exp(logits * weights)))``.

    Args:
        args: Namespace providing ``adv`` (``'PGDattack'`` or
            ``'FGSMattack'``) and ``adv_iter``.
        dataloader_test: Yields ``(images, labels)``.
        knownclass: Passed to ``lab_conv``; its length sizes the weight rows.
        Encoder, NorClsfier: Networks producing the logits.
        activation_vectors: Dict whose keys are the classes with a Weibull fit.
        mean_activation_vectors: Per-class mean logit vectors.
        weibulls: Per-class objects exposing ``w_score``.
        mode: ``'closeset'`` additionally computes and returns the accuracy.

    Returns:
        In ``'closeset'`` mode ``(accuracy, scores)``; otherwise ``scores``.

    Raises:
        ValueError: If ``args.adv`` names an unsupported attack.
    """
    # Compare by value throughout: `is` against a string literal tests object
    # identity and can be False for an equal string built at runtime.
    closeset = mode == 'closeset'

    # Apply Weibull score to every logit
    weibull_scores = []
    logits = []
    classes = activation_vectors.keys()

    running_corrects = 0.0
    epoch_size = 0.0

    if args.adv == 'PGDattack':
        from advertorch.attacks import PGDAttack
        adversary = PGDAttack(predict1=Encoder, predict2=NorClsfier, nb_iter=args.adv_iter)
    elif args.adv == 'FGSMattack':
        from advertorch.attacks import FGSM
        adversary = FGSM(predict1=Encoder, predict2=NorClsfier)
    else:
        raise ValueError("unsupported args.adv: {!r}".format(args.adv))

    for steps, (images, labels) in enumerate(dataloader_test):

        labels = lab_conv(knownclass, labels)
        images, labels = images.cuda(), labels.long().cuda()

        print("Calculate weibull_scores in step {}/{}".format(
            steps, len(dataloader_test)))
        print("**********Conduct Attack**********")
        if closeset:
            advimg = adversary.perturb(images, labels)
        else:
            advimg = adversary.perturb(images)

        with torch.no_grad():
            batch_logits_torch = NorClsfier(Encoder(advimg))

        batch_logits = batch_logits_torch.data.cpu().numpy()

        for activation_vector in batch_logits:
            weibull_row = np.ones(len(knownclass))
            for class_idx in classes:
                mav = mean_activation_vectors[class_idx]
                dist = np.linalg.norm(activation_vector - mav)
                weibull_row[class_idx] = 1 - weibulls[class_idx].w_score(dist)
            weibull_scores.append(weibull_row)
            logits.append(activation_vector)

        if closeset:
            _, preds = torch.max(batch_logits_torch, 1)
            # statistics
            running_corrects += torch.sum(preds == labels.data)
            epoch_size += images.size(0)

    if closeset:
        running_corrects = running_corrects.double() / epoch_size
        print('Test Acc: {:.4f}'.format(running_corrects))

    weibull_scores = np.array(weibull_scores)
    logits = np.array(logits)

    openmax_scores = -np.log(np.sum(np.exp(logits * weibull_scores), axis=1))

    if closeset:
        return running_corrects, np.array(openmax_scores)
    return np.array(openmax_scores)
return img_n def unnormalize(img, mean=mean, std=std): img_u = img * std img_u = img_u + mean return img_u epsilon = args.eps epsilon = epsilon / 255. ddn = False if args.attack == 'PGD': adversary = PGDAttack(lambda x: wrapper(x, pcl=pcl), eps=epsilon, eps_iter=epsilon / 4, nb_iter=10, ord=norm, rand_init=True) elif args.attack == 'MIFGSM': adversary = MomentumIterativeAttack( lambda x: wrapper(normalize(x), pcl=pcl), eps=epsilon, eps_iter=epsilon / 10, ord=norm, nb_iter=10) elif args.attack == 'FGSM': adversary = GradientSignAttack(lambda x: wrapper(x, pcl=pcl), eps=epsilon) # adversary = PGDAttack(lambda x: wrapper(x, pcl=pcl), eps=epsilon, eps_iter=epsilon, nb_iter=1, ord=norm, rand_init=False) elif args.attack == 'CW': adversary = CarliniWagnerL2Attack(lambda x: wrapper(x, pcl=pcl), 10,
state_dict = torch.load(gan_path)
# The loaded checkpoint's 'truncation.truncation' entry is overwritten with
# the value currently held by `gan` before the state dict is loaded.
var_name = 'truncation.truncation'
state_dict[var_name] = gan.state_dict()[var_name]
gan.load_state_dict(state_dict)
# Only the synthesis sub-module is kept, with its parameters frozen.
gan = gan.synthesis
for p in gan.parameters():
    p.requires_grad_(False)
gan = move_to_device(gan, cfg)
# `model` feeds the output of `gan` into `net`.
model = torch.nn.Sequential(gan, net)

# The image attacker targets `net` directly, the latent attacker targets the
# gan -> net pipeline.
image_attacker = get_attack(cfg.image_attack, net)
latent_attacker = get_attack(cfg.latent_attack, model)

# Evaluation attacker: same eps / eps_iter as the configured image attack,
# but with a fixed 50 iterations and clipping to [-1, 1].
test_attacker = PGDAttack(predict=net, eps=cfg.image_attack.args.eps, eps_iter=cfg.image_attack.args.eps_iter, nb_iter=50, clip_min=-1.0, clip_max=1.0)

# set dataset, dataloader
dataset = get_dataset(cfg)
transform = get_transform(cfg)
trainset = dataset(root=trainset_cfg.path, train=True)
# NOTE(review): the test set is built with `test_dataset`, not the `dataset`
# returned above — presumably defined earlier in the file; confirm.
testset = test_dataset(root=testset_cfg.path, train=False)

# Samplers stay None unless distributed mode is enabled.
train_sampler = None
test_sampler = None
if cfg.distributed:
    train_sampler = DistributedSampler(trainset)
    test_sampler = DistributedSampler(testset)
else: start_epoch = 0 best_acc = 0 # ############################################################################# e = 8. epsilon = e/255. max_iter= int(min(e+4, 1.25*e)) def normalize(img, mean=mean, std=std): img_n = img - mean img_n = img_n / std return img_n adversary = PGDAttack(lambda x: net(x), eps=epsilon, nb_iter=7, ord=np.inf, eps_iter=epsilon/4.) writer = SummaryWriter(comment=tensorboard_comment) for epoch in range(start_epoch+1, nb_epoch+1): if epoch >= args.epoch_adv: train_acc, train_loss = train(epoch, net, train_loader, optimizer, criterion_da, args, adv_training=True, epsilon=args.eps_train/255., alpha=args.alpha_train/255., num_iter=args.num_iter) else: train_acc, train_loss = train(epoch, net, train_loader, optimizer, criterion_class, args, adv_training=False) net.eval() val_acc, val_loss = test(net, val_loader, criterion_class, args) # adv_acc, adv_loss, _, _ = adv_test(net, val_loader, criterion_class, adversary, epsilon, args, store_imgs=False) # writer.add_scalar('adv_acc', adv_acc, epoch) writer.add_scalar('train_acc', train_acc, epoch)