def test_attack(threshold, arch, dataset, test_loader):
    """Run an untargeted Linf PGD attack on a StandardModel and log the success rate.

    :param threshold: Linf perturbation budget (eps) for the PGD attack.
    :param arch: architecture name forwarded to StandardModel.
    :param dataset: dataset name forwarded to StandardModel.
    :param test_loader: iterable of (image batch, label batch) pairs.
    :return: (target_model, adv_images, true_labels) where adv_images is a numpy
        array in (N, C, H, W) layout and true_labels has shape (N,).
    """
    target_model = StandardModel(dataset, arch, no_grad=False)
    # Keep model and data on the same device: the previous code moved only the
    # model conditionally but called img.cuda() unconditionally, which crashes
    # on CPU-only hosts.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        target_model = target_model.cuda()
    target_model.eval()
    attack = LinfPGDAttack(target_model, loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                           eps=threshold, nb_iter=30, eps_iter=0.01, rand_init=True,
                           clip_min=0.0, clip_max=1.0, targeted=False)
    all_count = 0
    success_count = 0
    all_adv_images = []
    all_true_labels = []
    for idx, (img, true_label) in enumerate(test_loader):
        if use_cuda:
            img = img.cuda()
            true_label = true_label.cuda()
        true_label = true_label.long()
        adv_image = attack.perturb(img, true_label)  # (N, C, H, W), float
        if adv_image is None:
            continue
        adv_label = target_model.forward(adv_image).max(1)[1].detach().cpu().numpy().astype(np.int32)
        all_count += len(img)
        true_label_np = true_label.detach().cpu().numpy().astype(np.int32)
        # The attack "succeeds" on a sample when the prediction differs from
        # the true label.
        success_count += len(np.where(true_label_np != adv_label)[0])
        all_adv_images.append(adv_image.cpu().detach().numpy())
        all_true_labels.append(true_label_np)
    # Guard against an empty loader so the rate is well-defined before logging.
    attack_success_rate = success_count / float(all_count) if all_count else 0.0
    log.info("Before train. Attack success rate is {:.3f}".format(attack_success_rate))
    # Images are returned in (N, C, H, W) layout; the NHWC transpose that an
    # old trailing comment referred to was never applied.
    return target_model, np.concatenate(all_adv_images, 0), np.concatenate(all_true_labels, 0)
def attack_pgd_transfer(self, model_attacker, clean_loader, epsilon=0.1, eps_iter=0.02, test='average', nb_iter=7):
    """
    Use adversarial samples generated against model_attacker to attack the current model.

    :param model_attacker: surrogate model that PGD is run against (via forward_adv).
    :param clean_loader: DataLoader of clean (data, target) batches.
    :param epsilon: Linf perturbation budget.
    :param eps_iter: per-iteration PGD step size.
    :param test: evaluation mode for the current model: 'last' -> run_cycles,
        'average' -> run_average, anything else -> a single forward pass.
    :param nb_iter: number of PGD iterations.
    :return: accuracy of self.model on the transferred adversarial samples.
    """
    self.model.eval()
    self.model.reset()
    model_attacker.eval()
    model_attacker.reset()
    # NOTE(review): clip range is [-1, 1], so inputs are presumably normalized
    # to that interval — confirm against the data pipeline.
    adversary = LinfPGDAttack(
        model_attacker.forward_adv, loss_fn=nn.CrossEntropyLoss(reduction="sum"),
        eps=epsilon, nb_iter=nb_iter, eps_iter=eps_iter, rand_init=True,
        clip_min=-1.0, clip_max=1.0, targeted=False)
    correct = 0
    for batch_idx, (data, target) in enumerate(clean_loader):
        data, target = data.to(self.device), target.to(self.device)
        # Reset any stateful components of both models before each batch.
        self.model.reset()
        model_attacker.reset()
        # Craft the attack with the surrogate's parameter grads disabled.
        with ctx_noparamgrad_and_eval(model_attacker):
            adv_images = adversary.perturb(data, target)
        if(test=='last'):
            output = self.model.run_cycles(adv_images)
        elif(test=='average'):
            output = self.model.run_average(adv_images)
        else:
            self.model.reset()
            output = self.model(adv_images)
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
    acc = correct / len(clean_loader.dataset)
    print('PGD attack Acc {:.3f}'.format(100. * acc))
    return acc
def test_pgd(args, model, device, test_loader, epsilon=0.063):
    """Report robust accuracy of ``model`` under an untargeted Linf PGD attack."""
    model.eval()
    model.reset()
    adversary = LinfPGDAttack(model.forward_adv,
                              loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                              eps=epsilon,
                              nb_iter=args.nb_iter,
                              eps_iter=args.eps_iter,
                              rand_init=True,
                              clip_min=-1.0,
                              clip_max=1.0,
                              targeted=False)
    n_correct = 0
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        model.reset()
        # Craft adversarial inputs with parameter gradients disabled.
        with ctx_noparamgrad_and_eval(model):
            perturbed = adversary.perturb(inputs, labels)
        logits = model.run_cycles(perturbed)
        predictions = logits.argmax(dim=1, keepdim=True)
        n_correct += predictions.eq(labels.view_as(predictions)).sum().item()
    acc = n_correct / len(test_loader.dataset)
    print('PGD attack Acc {:.3f}'.format(100. * acc))
    return acc
def __init__(self, args, model, nb_iter, loss_fn=None):
    """Build a PGD adversary matching ``args.attack_ball`` ('Linf' or 'L2').

    :param args: namespace providing attack_ball, epsilon, clip_min, clip_max.
    :param model: model to attack.
    :param nb_iter: number of PGD iterations.
    :param loss_fn: attack loss; defaults to sum-reduced cross entropy.
    :raises NotImplementedError: for any attack_ball other than 'Linf'/'L2'.
    """
    # The old default `loss_fn=nn.CrossEntropyLoss(reduction="sum")` built ONE
    # module instance at function-definition time, shared by every PGDAttack.
    # Construct a fresh (equivalent) default per call instead.
    if loss_fn is None:
        loss_fn = nn.CrossEntropyLoss(reduction="sum")
    super(PGDAttack, self).__init__(args, model, nb_iter, loss_fn)
    self.args = args
    self.model = model
    if args.attack_ball == 'Linf':
        self.adversary = LinfPGDAttack(self.model, loss_fn=loss_fn, eps=args.epsilon,
                                       nb_iter=nb_iter, eps_iter=0.01, rand_init=True,
                                       clip_min=args.clip_min, clip_max=args.clip_max,
                                       targeted=False)
    elif args.attack_ball == 'L2':
        self.adversary = L2PGDAttack(self.model, loss_fn=loss_fn, eps=args.epsilon,
                                     nb_iter=nb_iter, eps_iter=0.01, rand_init=True,
                                     clip_min=args.clip_min, clip_max=args.clip_max,
                                     targeted=False)
    else:
        raise NotImplementedError
def get_metric_eval(self):
    """Attack ``self.phi`` with untargeted and targeted (class 3) Linf PGD,
    plot a sample grid, and record mis-classification counts in
    ``self.metric_score``.

    NOTE(review): ``x_e`` and ``y_e`` are neither parameters nor locals here —
    presumably module-level globals set elsewhere; confirm, otherwise this
    raises NameError.
    """
    utr_score = []
    tr_score = []
    for i in range(1):
        ##TODO: Customise input parameters to methods like LinfPGDAttack
        adversary = LinfPGDAttack(
            self.phi, loss_fn=nn.CrossEntropyLoss(reduction="sum"), eps=0.10,
            nb_iter=40, eps_iter=0.01, rand_init=True, clip_min=0.0,
            clip_max=1.0, targeted=False)
        adv_untargeted = adversary.perturb(x_e, y_e)
        # Target every sample to class 3 for the targeted variant.
        target = torch.ones_like(y_e) * 3
        adversary.targeted = True
        adv_targeted = adversary.perturb(x_e, target)
        pred_cln = predict_from_logits(self.phi(x_e))
        pred_untargeted_adv = predict_from_logits(self.phi(adv_untargeted))
        pred_targeted_adv = predict_from_logits(self.phi(adv_targeted))
        # Count predictions flipped away from the clean prediction.
        utr_score.append(torch.sum(pred_cln != pred_untargeted_adv))
        tr_score.append(torch.sum(pred_cln != pred_targeted_adv))
    # Visualize the first 5 samples: clean / untargeted / targeted rows.
    batch_size = 5
    plt.figure(figsize=(10, 8))
    for ii in range(batch_size):
        plt.subplot(3, batch_size, ii + 1)
        _imshow(x_e[ii])
        plt.title("clean \n pred: {}".format(pred_cln[ii]))
        plt.subplot(3, batch_size, ii + 1 + batch_size)
        _imshow(adv_untargeted[ii])
        plt.title("untargeted \n adv \n pred: {}".format(
            pred_untargeted_adv[ii]))
        plt.subplot(3, batch_size, ii + 1 + batch_size * 2)
        _imshow(adv_targeted[ii])
        plt.title("targeted to 3 \n adv \n pred: {}".format(
            pred_targeted_adv[ii]))
    plt.tight_layout()
    plt.savefig(self.save_path + '.png')
    utr_score = np.array(utr_score)
    tr_score = np.array(tr_score)
    print('MisClassifcation on Untargetted Attack ', np.mean(utr_score), np.std(utr_score))
    print('MisClassifcation on Targeted Atttack', np.mean(tr_score), np.std(tr_score))
    self.metric_score['Untargetted Method'] = np.mean(utr_score)
    self.metric_score['Targetted Method'] = np.mean(tr_score)
    return
def mifgsm_attack(max_count, model, train_loader, max_epsilon, learning_rate, iters=20, isnorm=False, num_classes=1000):
    """Run a targeted Linf PGD attack (target = true label + 3 mod num_classes)
    and return the fraction of adversarial samples still classified correctly.

    NOTE(review): despite the name, this uses LinfPGDAttack — the mifgsm
    adversary was previously commented out; confirm which is intended.

    :param max_count: stop once at least this many samples were attacked.
    :param model: classifier under attack (expected on CUDA).
    :param train_loader: loader of (image, label) batches.
    :param max_epsilon: perturbation budget in [0, 255] pixel units.
    :param learning_rate: per-step size in [0, 255] pixel units.
    :param iters: number of attack iterations.
    :param isnorm: True if inputs are ImageNet-normalized; budgets and clip
        bounds are rescaled into normalized space. Otherwise raw [0, 255].
    :param num_classes: class count used to wrap the shifted target label.
    :return: robust accuracy (correct / attacked count) as a numpy scalar.
    """
    if isnorm:
        # Convert the [0, 255] clip box into normalized coordinates.
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        mean = torch.tensor(mean).float().view(3, 1, 1)
        std = torch.tensor(std).float().view(3, 1, 1)
        mmax = ((torch.ones(3, 224, 224) - mean) / std).cuda()
        mmin = ((torch.zeros(3, 224, 224) - mean) / std).cuda()
        # 0.224 is the green-channel std, used as a scalar approximation for
        # rescaling the step size and budget.
        learning_rate = learning_rate / (255 * 0.224)
        max_epsilon = max_epsilon / (255 * 0.224)
    else:
        learning_rate = float(learning_rate)
        max_epsilon = float(max_epsilon)
        mmax = 255
        mmin = 0
    adversary = LinfPGDAttack(model, loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                              eps=max_epsilon, nb_iter=iters, eps_iter=learning_rate,
                              clip_min=mmin, clip_max=mmax, targeted=True)
    count = 0
    total_correct = 0
    for x, y in train_loader:
        x = x.cuda()
        y = y.cuda()
        # Targeted attack towards a label shifted by 3 classes.
        y1 = (y + 3) % num_classes
        count += len(x)
        ad_ex = adversary.perturb(x, y1)
        if not isnorm:
            # Quantize back to integer pixel values.
            ad_ex = torch.round(ad_ex)
        z1 = model(ad_ex).argmax(dim=1)
        # (The unused `diff = ad_ex - x` debug variable was removed.)
        total_correct += (z1 == y).sum()
        if count >= max_count:
            break
    return total_correct.cpu().numpy() / (count)
def init_models(self):
    """Set up adversarial evaluation: a frozen pretrained model, a Linf PGD
    adversary against it, and ``self.adv_eval_fn`` as a ready-to-call partial.

    (The unused locals ``model = self.model`` and ``teacher = self.teacher``
    were removed — nothing in this method used them.)
    """
    args = self.args
    # Adv eval
    # NOTE(review): eval() on a config-derived string — fine for trusted
    # configs, but worth replacing with an explicit registry lookup.
    eval_pretrained_model = eval('fe{}'.format(
        args.network))(pretrained=True).cuda().eval()
    # clip range [-2.2, 2.2] — presumably the bounds of ImageNet-normalized
    # inputs; confirm against the data pipeline.
    adversary = LinfPGDAttack(eval_pretrained_model, loss_fn=myloss, eps=args.B,
                              nb_iter=args.pgd_iter, eps_iter=0.01, rand_init=True,
                              clip_min=-2.2, clip_max=2.2, targeted=False)
    adveval_test_loader = torch.utils.data.DataLoader(
        self.test_loader.dataset, batch_size=8, shuffle=False,
        num_workers=8, pin_memory=False)
    self.adv_eval_fn = partial(
        advtest,
        loader=adveval_test_loader,
        adversary=adversary,
        args=args,
    )
def whitebox_pgd(args, image, target, model, normalize=None):
    """White-box untargeted Linf PGD on one image.

    Returns the model's logits on the adversarial image and the clamped
    difference between the clean and adversarial first samples.
    (``normalize`` is accepted for interface compatibility but unused.)
    """
    pgd = LinfPGDAttack(model,
                        loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                        eps=0.3,
                        nb_iter=40,
                        eps_iter=0.01,
                        rand_init=True,
                        clip_min=0.0,
                        clip_max=1.0,
                        targeted=False)
    perturbed = pgd.perturb(image, target)
    print("Target is %d" %(target))
    pred = model(perturbed)
    out = pred.max(1, keepdim=True)[1]  # index of the max log-probability
    print("Adv Target is %d" %(out))
    clean_first = image[0].detach()
    adv_first = perturbed[0].detach()
    if args.comet:
        plot_image_to_comet(args, clean_first, "clean.png")
        plot_image_to_comet(args, adv_first, "Adv.png")
    return pred, clamp(clean_first - adv_first, 0., 1.)
def pgd_attack(sdim, args):
    """Sweep Linf PGD over several epsilons and save rejection/distortion stats."""
    thresholds1, thresholds2 = extract_thresholds(sdim, args)
    results_dict = {'reject_rate1': [], 'reject_rate2': [], 'l2_distortion': []}
    for eps in (0.01, 0.02, 0.05, 0.1):
        adversary = LinfPGDAttack(sdim,
                                  loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                                  eps=eps, nb_iter=40, eps_iter=0.01,
                                  rand_init=True, clip_min=-1.0, clip_max=1.0,
                                  targeted=args.targeted)
        logger.info('epsilon = {:.4f}'.format(adversary.eps))
        l2_dist, rj_rate1, rj_rate2 = adv_eval_with_rejection(
            sdim, adversary, args, thresholds1, thresholds2)
        results_dict['reject_rate1'].append(rj_rate1)
        results_dict['reject_rate2'].append(rj_rate2)
        results_dict['l2_distortion'].append(l2_dist)
    torch.save(results_dict, '{}_results.pth'.format(args.attack))
def pgd_attack(sdim, args):
    """Sweep Linf PGD over pixel-budget epsilons and save accuracy/rejection stats.

    NOTE(review): ``clip_min``, ``clip_max`` and ``suffix_dict`` are not local —
    presumably module-level globals; confirm.
    """
    thresholds = extract_thresholds(sdim, args)
    metric_names = ('clean_acc', 'clean_error', 'reject_rate', 'left_acc', 'left_error')
    results_dict = {name: [] for name in metric_names}
    for eps in (2 / 255, 4 / 255, 6 / 255, 8 / 255, 10 / 255):
        adversary = LinfPGDAttack(sdim,
                                  loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                                  eps=eps, nb_iter=40, eps_iter=0.01,
                                  rand_init=True, clip_min=clip_min,
                                  clip_max=clip_max, targeted=False)
        logger.info('epsilon = {:.4f}'.format(adversary.eps))
        # adv_eval_with_rejection returns metrics in metric_names order.
        metrics = adv_eval_with_rejection(sdim, adversary, args, thresholds)
        for name, value in zip(metric_names, metrics):
            results_dict[name].append(value)
    torch.save(results_dict, 'adv_eval{}_results.pth'.format(suffix_dict[args.base_type]))
def generate_attack_samples(model, cln_data, true_label, nb_iter, eps_iter, eps=0.25):
    """Generate untargeted and per-class targeted Linf PGD samples.

    :param model: classifier under attack.
    :param cln_data: batch of clean inputs in [0, 1].
    :param true_label: ground-truth labels for ``cln_data``.
    :param nb_iter: number of PGD iterations.
    :param eps_iter: per-iteration step size.
    :param eps: Linf budget (default 0.25, the previously hard-coded value).
    :return: (adv_targeted_results, adv_target_labels, adv_untargeted) where the
        first two are lists over target classes 0..9.
    """
    adversary = LinfPGDAttack(
        model, loss_fn=nn.CrossEntropyLoss(reduction="sum"), eps=eps,
        nb_iter=nb_iter, eps_iter=eps_iter, rand_init=True, clip_min=0.0,
        clip_max=1.0, targeted=False)
    # Untargeted attack first, while adversary.targeted is still False.
    adv_untargeted = adversary.perturb(cln_data, true_label)
    adv_targeted_results = []
    adv_target_labels = []
    adversary.targeted = True
    for target_label in range(10):
        # range(10) guarantees 0 <= target_label <= 9; the old runtime assert
        # re-checking this (and wrongly allowing 10) was removed.
        target = torch.ones_like(true_label) * target_label
        adv_targeted = adversary.perturb(cln_data, target)
        adv_targeted_results.append(adv_targeted)
        adv_target_labels.append(target)
    return adv_targeted_results, adv_target_labels, adv_untargeted
def get_pgd_adversary(model, eps, num_iter, lr):
    """Factory for an untargeted Linf PGD adversary clipped to [0, 1]."""
    return LinfPGDAttack(model,
                         loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                         eps=eps,
                         nb_iter=num_iter,
                         eps_iter=lr,
                         rand_init=True,
                         clip_min=0.0,
                         clip_max=1.0,
                         targeted=False)
def create_adv_input(self, x, y, model):
    """Perturb a single sample with a default-parameter LinfPGDAttack.

    Returns (perturbed_data, flag): flag is 0 when the model's prediction on
    the perturbed input still matches the label, 1 when the attack flipped it.
    """
    from advertorch.attacks import LinfPGDAttack

    # Attack a private copy so the caller's model is untouched.
    victim = copy.deepcopy(model)

    data = torch.from_numpy(np.expand_dims(x, axis=0).astype(np.float32))
    target = torch.from_numpy(np.array([y]).astype(np.int64))
    data.requires_grad = True

    adversary = LinfPGDAttack(victim.forward)
    perturbed_data = adversary.perturb(data, target)

    output = victim.forward(perturbed_data)
    final_pred = output.max(1, keepdim=True)[1]  # index of max log-probability
    flag = 0 if final_pred.item() == target.item() else 1
    return perturbed_data, flag
def get_attackers(attack_type: str, cfg: dict, loss_fn: nn.Module, model: nn.Module, *args, **kwargs):
    """Build a dict of advertorch adversaries keyed by attack name.

    :param attack_type: one of 'gsa', 'linfpgd', 'singlepixel',
        'jacobiansaliencymap', or 'all' (all except JSMA — its 'all' branch
        is commented out below).
    :param cfg: flat config dict with per-attack 'adv_<name>_*' keys.
    :param loss_fn: loss used by the gradient-based attacks.
    :param model: model under attack.
    :return: dict mapping attack name -> adversary instance.
    """
    attackers = {}
    if attack_type == 'gsa' or attack_type == 'all':
        # FGSM: https://arxiv.org/abs/1412.6572 (the old comment's id was truncated)
        from advertorch.attacks import GradientSignAttack
        adversary = GradientSignAttack(model,
                                       loss_fn=loss_fn,
                                       eps=cfg['adv_gsa_eps'],
                                       clip_min=cfg['adv_gsa_clip_min'],
                                       clip_max=cfg['adv_gsa_clip_max'],
                                       targeted=cfg['adv_gsa_targeted'])
        attackers['gsa'] = adversary
    if attack_type == 'linfpgd' or attack_type == 'all':
        from advertorch.attacks import LinfPGDAttack
        adversary = LinfPGDAttack(model,
                                  loss_fn=loss_fn,
                                  eps=cfg['adv_linfpgd_eps'],
                                  nb_iter=cfg['adv_linfpgd_nb_iter'],
                                  eps_iter=cfg['adv_linfpgd_eps_iter'],
                                  rand_init=cfg['adv_linfpgd_rand_int'],
                                  clip_min=cfg['adv_linfpgd_clip_min'],
                                  clip_max=cfg['adv_linfpgd_clip_max'],
                                  targeted=cfg['adv_linfpgd_targeted'])
        attackers['linfpgd'] = adversary
    if attack_type == 'singlepixel' or attack_type == 'all':
        # https://arxiv.org/pdf/1612.06299.pdf
        from advertorch.attacks import SinglePixelAttack
        adversary = SinglePixelAttack(
            model,
            loss_fn=loss_fn,
            max_pixels=cfg['adv_singlepixel_max_pixel'],
            clip_min=cfg['adv_singlepixel_clip_min'],
            clip_max=cfg['adv_singlepixel_clip_max'],
            targeted=cfg['adv_singlepixel_targeted'])
        attackers['singlepixel'] = adversary
    # if attack_type=='jacobiansaliencymap' or attack_type=='all':
    if attack_type == 'jacobiansaliencymap':
        # https://arxiv.org/abs/1511.07528v1
        from advertorch.attacks import JacobianSaliencyMapAttack
        adversary = JacobianSaliencyMapAttack(
            model,
            num_classes=cfg['adv_jacobiansaliencymap_num_classes'],
            clip_min=cfg['adv_jacobiansaliencymap_clip_min'],
            clip_max=cfg['adv_jacobiansaliencymap_clip_max'],
            gamma=cfg['adv_jacobiansaliencymap_gamma'],
            theta=cfg['adv_jacobiansaliencymap_theta'])
        attackers['jacobiansaliencymap'] = adversary
    return attackers
def main():
    """Generate adversarial examples for a SubsetImageNet folder using
    momentum PGD or plain Linf PGD, optionally with the Skip Gradient Method.

    NOTE(review): relies on module-level ``args``, ``kwargs`` and ``device``.
    """
    # create dataloader
    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])
    data_set = SubsetImageNet(root=args.input_dir, transform=transform_test)
    data_loader = torch.utils.data.DataLoader(data_set, batch_size=args.batch_size,
                                              shuffle=False, **kwargs)
    # create models: prepend input normalization so the attack operates in [0, 1].
    net = pretrainedmodels.__dict__[args.arch](num_classes=1000, pretrained='imagenet')
    model = nn.Sequential(Normalize(mean=net.mean, std=net.std), net)
    model = model.to(device)
    model.eval()
    # create adversary attack: budgets are given in [0, 255] units and rescaled.
    epsilon = args.epsilon / 255.0
    if args.step_size < 0:
        # Negative step size means "spread the whole budget evenly over the steps".
        step_size = epsilon / args.num_steps
    else:
        step_size = args.step_size / 255.0
    # using our method - Skip Gradient Method (SGM): gamma < 1 registers
    # backward hooks on the supported architectures.
    if args.gamma < 1.0:
        if args.arch in ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152']:
            register_hook_for_resnet(model, arch=args.arch, gamma=args.gamma)
        elif args.arch in ['densenet121', 'densenet169', 'densenet201']:
            register_hook_for_densenet(model, arch=args.arch, gamma=args.gamma)
        else:
            raise ValueError('Current code only supports resnet/densenet. '
                             'You can extend this code to other architectures.')
    if args.momentum > 0.0:
        print('using PGD attack with momentum = {}'.format(args.momentum))
        adversary = MomentumIterativeAttack(predict=model,
                                            loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                                            eps=epsilon,
                                            nb_iter=args.num_steps,
                                            eps_iter=step_size,
                                            decay_factor=args.momentum,
                                            clip_min=0.0,
                                            clip_max=1.0,
                                            targeted=False)
    else:
        print('using linf PGD attack')
        adversary = LinfPGDAttack(predict=model,
                                  loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                                  eps=epsilon,
                                  nb_iter=args.num_steps,
                                  eps_iter=step_size,
                                  rand_init=False,
                                  clip_min=0.0,
                                  clip_max=1.0,
                                  targeted=False)
    generate_adversarial_example(model=model, data_loader=data_loader,
                                 adversary=adversary, img_path=data_set.img_path)
def init_pgd(self, model, test=False):
    """Build a Linf PGD adversary from spectral_args, using the test-time
    hyper-parameters when ``test`` is True, otherwise the train-time ones.

    NOTE(review): eps is divided by 255 but eps_iter is not — confirm the
    configured step size is already expressed in [0, 1] units.
    """
    cfg = self.spectral_args
    epsilon = cfg['test_epsilon'] if test else cfg['train_epsilon']
    step_size = cfg['test_eps_iter'] if test else cfg['eps_iter']
    n_steps = cfg['test_nb_iter'] if test else cfg['nb_iter']
    return LinfPGDAttack(model,
                         loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                         eps=epsilon / 255.0,
                         nb_iter=n_steps,
                         eps_iter=step_size,
                         rand_init=True,
                         clip_min=cfg['clip_min'],
                         clip_max=cfg['clip_max'],
                         targeted=False)
def init_adversary(self, model, test=False):
    """Create a Linf PGD adversary from spectral_args.

    NOTE(review): ``uniform_sample`` is not a stock advertorch argument —
    presumably this targets a patched LinfPGDAttack; confirm.
    """
    cfg = self.spectral_args
    # chose train or test epsilon
    epsilon = cfg['test_epsilon'] if test else cfg['train_epsilon']
    uniform_sample = False  # True if not test else False
    return LinfPGDAttack(model,
                         loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                         eps=epsilon,
                         nb_iter=cfg["nb_iter"],
                         eps_iter=cfg["eps_iter"],
                         rand_init=True,
                         clip_min=cfg['clip_min'],
                         clip_max=cfg['clip_max'],
                         targeted=False,
                         uniform_sample=uniform_sample)
def linfPGD_attack(model, hps):
    """Evaluate ``model`` under Linf PGD across a sweep of epsilons (incl. 0)."""
    print('============== LinfPGD Summary ===============')
    for eps in (0., 0.1, 0.2, 0.3, 0.4, 0.5):
        adversary = LinfPGDAttack(model,
                                  loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                                  eps=eps,
                                  nb_iter=40,
                                  eps_iter=0.01,
                                  rand_init=True,
                                  clip_min=-1.0,
                                  clip_max=1.0,
                                  targeted=hps.targeted)
        print('epsilon = {:.4f}'.format(adversary.eps))
        attack_run(model, adversary, hps)
    print('============== LinfPGD Summary ===============')
def make_adversary_dict(model, model_name, targetted=False):
    """Build a dict of Clean/PGD/FGSM/BIM adversaries for ``model``.

    Budgets and step sizes are divided by 0.3081 and clipped to
    [-0.4242, 2.8215] — presumably the MNIST std / normalized pixel range;
    confirm against the data pipeline.
    """
    if model_name == "capsnet":
        wrapped = Model_for_Adversary_Caps(model)
    else:
        wrapped = Model_for_Adversary_CNN(model)

    linf_eps = 0.3
    fgsm_step = 0.05
    bim_pgd_step = 0.01

    adversary_dict = {}
    adversary_dict['Clean'] = CleanAttack(clip_min=-0.4242, clip_max=2.8215)
    adversary_dict['PGD'] = LinfPGDAttack(
        wrapped,
        loss_fn=nn.CrossEntropyLoss(reduction="sum"),
        eps=(linf_eps / 0.3081),
        nb_iter=100,
        eps_iter=(bim_pgd_step / 0.3081),
        rand_init=True,
        clip_min=-0.4242,
        clip_max=2.8215,
        targeted=targetted)
    adversary_dict['FGSM'] = GradientSignAttack(
        wrapped,
        loss_fn=nn.CrossEntropyLoss(reduction="sum"),
        eps=(fgsm_step / 0.3081),
        clip_min=-0.4242,
        clip_max=2.8215,
        targeted=targetted)
    adversary_dict['BIM'] = LinfBasicIterativeAttack(
        wrapped,
        loss_fn=nn.CrossEntropyLoss(reduction="sum"),
        eps=(linf_eps / 0.3081),
        nb_iter=100,
        eps_iter=(bim_pgd_step / 0.3081),
        clip_min=-0.4242,
        clip_max=2.8215,
        targeted=targetted)
    return adversary_dict
def generate(datasetname, batch_size):
    """Craft adversarial training data for the guided denoiser: four attacks
    (Linf PGD, L2 PGD, FGSM, momentum PGD) per standard model, then save."""
    save_dir_path = "{}/data_adv_defense/guided_denoiser".format(PY_ROOT)
    os.makedirs(save_dir_path, exist_ok=True)
    set_log_file(save_dir_path + "/generate_{}.log".format(datasetname))
    data_loader = DataLoaderMaker.get_img_label_data_loader(datasetname, batch_size, is_train=True)
    attackers = []
    for model_name in MODELS_TRAIN_STANDARD[datasetname] + MODELS_TEST_STANDARD[datasetname]:
        model = StandardModel(datasetname, model_name, no_grad=False).cuda().eval()
        attackers.append(LinfPGDAttack(model, loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                                       eps=0.031372, nb_iter=30, eps_iter=0.01,
                                       rand_init=True, clip_min=0.0, clip_max=1.0,
                                       targeted=False))
        attackers.append(L2PGDAttack(model, loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                                     eps=4.6, nb_iter=30, clip_min=0.0, clip_max=1.0,
                                     targeted=False))
        attackers.append(FGSM(model, loss_fn=nn.CrossEntropyLoss(reduction="sum")))
        attackers.append(MomentumIterativeAttack(model, loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                                                 eps=0.031372, nb_iter=30, eps_iter=0.01,
                                                 clip_min=0.0, clip_max=1.0, targeted=False))
        log.info("Create model {} done!".format(model_name))
    generate_and_save_adv_examples(datasetname, data_loader, attackers, save_dir_path)
shuffle=False, num_workers=8, pin_memory=False) transferred_model = eval('{}_dropout'.format(args.network))( pretrained=False, dropout=args.dropout, num_classes=test_loader.dataset.num_classes).cuda() checkpoint = torch.load(args.checkpoint) transferred_model.load_state_dict(checkpoint['state_dict']) pretrained_model = eval('fe{}'.format( args.network))(pretrained=True).cuda().eval() adversary = LinfPGDAttack(pretrained_model, loss_fn=myloss, eps=args.B, nb_iter=args.pgd_iter, eps_iter=0.01, rand_init=True, clip_min=-2.2, clip_max=2.2, targeted=False) clean_top1, adv_top1, adv_sr = test(transferred_model, test_loader, adversary) print( 'Clean Top-1: {:.2f} | Adv Top-1: {:.2f} | Attack Success Rate: {:.2f}' .format(clean_top1, adv_top1, adv_sr))
"""Benchmark the Madry et al. MNIST model against a 100-step Linf PGD attack."""
import torch.nn as nn

from advertorch.attacks import LinfPGDAttack
from advertorch.attacks.utils import multiple_mini_batch_attack
from advertorch_examples.utils import get_mnist_test_loader
from madry_et_al_utils import get_madry_et_al_tf_model

model = get_madry_et_al_tf_model("MNIST")
loader = get_mnist_test_loader(batch_size=100)

# Standard MNIST robustness setting: eps=0.3, 100 steps of 0.01, no random start.
adversary = LinfPGDAttack(model,
                          loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                          eps=0.3,
                          nb_iter=100,
                          eps_iter=0.01,
                          rand_init=False,
                          clip_min=0.0,
                          clip_max=1.0,
                          targeted=False)

label, pred, advpred = multiple_mini_batch_attack(adversary, loader, device="cuda")

clean_acc = 100. * (label == pred).sum().item() / len(label)
robust_acc = 100. * (label == advpred).sum().item() / len(label)
print("Accuracy: {:.2f}%, Robust Accuracy: {:.2f}%".format(clean_acc, robust_acc))
# Accuracy: 98.53%, Robust Accuracy: 92.51%
inputs, targets = inputs.to(device), targets.to(device) with torch.enable_grad(): adv = adversary.perturb(inputs, targets) outputs = net(adv) _, predicted = outputs.max(1) total += targets.size(0) correct += predicted.eq(targets).sum().item() iterator.set_description(str(predicted.eq(targets).sum().item()/targets.size(0))) # Save checkpoint. acc = 100.*correct/total print('Test Acc of ckpt.{}: {}'.format(args.epoch, acc)) print('==> Loading from checkpoint epoch {}..'.format(args.epoch)) assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!' checkpoint = torch.load('./checkpoint/ckpt.{}'.format(args.epoch)) net.load_state_dict(checkpoint['net']) net = net.to(device) net.eval() adversary = LinfPGDAttack( net, loss_fn=nn.CrossEntropyLoss(reduction="sum"), eps=args.epsilon, nb_iter=args.iteration, eps_iter=args.step_size, rand_init=True, clip_min=0.0, clip_max=1.0, targeted=False) test(adversary)
def main():
    """Generate adversarial examples with momentum PGD or plain Linf PGD,
    optionally applying the Skip Gradient Method (SGM) when args.gamma < 1.

    NOTE(review): relies on module-level ``args`` and ``device``.
    """
    # create model: prepend normalization so the attack operates on raw [0, 1] pixels.
    net = pretrainedmodels.__dict__[args.arch](num_classes=1000, pretrained='imagenet')
    height, width = net.input_size[1], net.input_size[2]
    model = nn.Sequential(Normalize(mean=net.mean, std=net.std), net)
    model = model.to(device)
    # create dataloader
    data_loader, image_list = load_images(input_dir=args.input_dir,
                                          batch_size=args.batch_size,
                                          input_height=height,
                                          input_width=width)
    # create adversary: budgets are given in [0, 255] units and rescaled.
    epsilon = args.epsilon / 255.0
    if args.step_size < 0:
        # Negative step size spreads the whole budget evenly over the steps.
        step_size = epsilon / args.num_steps
    else:
        step_size = args.step_size / 255.0
    # using our method - Skip Gradient Method (SGM)
    if args.gamma < 1.0:
        if args.arch in [
                'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'
        ]:
            register_hook_for_resnet(model, arch=args.arch, gamma=args.gamma)
        elif args.arch in ['densenet121', 'densenet169', 'densenet201']:
            register_hook_for_densenet(model, arch=args.arch, gamma=args.gamma)
        else:
            raise ValueError(
                'Current code only supports resnet/densenet. '
                'You can extend this code to other architectures.')
    if args.momentum > 0.0:
        print('using PGD attack with momentum = {}'.format(args.momentum))
        adversary = MomentumIterativeAttack(
            predict=model,
            loss_fn=nn.CrossEntropyLoss(reduction="sum"),
            eps=epsilon,
            nb_iter=args.num_steps,
            eps_iter=step_size,
            decay_factor=args.momentum,
            clip_min=0.0,
            clip_max=1.0,
            targeted=False)
    else:
        print('using linf PGD attack')
        adversary = LinfPGDAttack(predict=model,
                                  loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                                  eps=epsilon,
                                  nb_iter=args.num_steps,
                                  eps_iter=step_size,
                                  rand_init=False,
                                  clip_min=0.0,
                                  clip_max=1.0,
                                  targeted=False)
    generate_adversarial_example(model=model, data_loader=data_loader,
                                 adversary=adversary, img_path=image_list)
def train_epoch(self, model: nn.Module, train_loader: DataLoader,
                val_clean_loader: DataLoader, val_triggered_loader: DataLoader,
                epoch_num: int, use_amp: bool = False):
    """
    Runs one epoch of (optionally adversarial, optionally AMP) training on the
    specified model, then evaluates on clean and triggered validation sets.

    :param model: the model to train for one epoch
    :param train_loader: a DataLoader object pointing to the training dataset
    :param val_clean_loader: a DataLoader object pointing to the validation dataset that is clean
    :param val_triggered_loader: a DataLoader object pointing to the validation dataset that is triggered
    :param epoch_num: the epoch number that is being trained
    :param use_amp: if True, uses automated mixed precision for FP16 training.
    :return: (EpochTrainStatistics, EpochValidationStatistics)
    """
    # Probability of Adversarial attack to occur in each iteration
    attack_prob = self.optimizer_cfg.training_cfg.adv_training_ratio
    pid = os.getpid()
    train_dataset_len = len(train_loader.dataset)
    loop = tqdm(train_loader, disable=self.optimizer_cfg.reporting_cfg.disable_progress_bar)

    scaler = None
    if use_amp:
        scaler = torch.cuda.amp.GradScaler()

    train_n_correct, train_n_total = None, None

    # Define parameters of the adversarial attack.
    # NOTE(review): eps_iter = 2*eps/iters lets PGD traverse the whole
    # Linf ball within the iteration budget — confirm this is intended.
    attack_eps = float(self.optimizer_cfg.training_cfg.adv_training_eps)
    attack_iterations = int(self.optimizer_cfg.training_cfg.adv_training_iterations)
    eps_iter = (2.0 * attack_eps) / float(attack_iterations)
    attack = LinfPGDAttack(
        predict=model,
        loss_fn=nn.CrossEntropyLoss(reduction="sum"),
        eps=attack_eps,
        nb_iter=attack_iterations,
        eps_iter=eps_iter)

    sum_batchmean_train_loss = 0
    running_train_acc = 0
    num_batches = len(train_loader)
    model.train()
    for batch_idx, (x, y_truth) in enumerate(loop):
        x = x.to(self.device)
        y_truth = y_truth.to(self.device)

        # put network into training mode & zero out previous gradient computations
        self.optimizer.zero_grad()

        # get predictions based on input & weights learned so far
        if use_amp:
            with torch.cuda.amp.autocast():
                # add adversarial noise via l_inf PGD attack
                # only apply attack to attack_prob of the batches
                if attack_prob and np.random.rand() <= attack_prob:
                    with ctx_noparamgrad_and_eval(model):
                        x = attack.perturb(x, y_truth)
                y_hat = model(x)
                # compute metrics
                batch_train_loss = self._eval_loss_function(y_hat, y_truth)
        else:
            # add adversarial noise vis lin PGD attack
            if attack_prob and np.random.rand() <= attack_prob:
                with ctx_noparamgrad_and_eval(model):
                    x = attack.perturb(x, y_truth)
            y_hat = model(x)
            batch_train_loss = self._eval_loss_function(y_hat, y_truth)

        sum_batchmean_train_loss += batch_train_loss.item()

        running_train_acc, train_n_total, train_n_correct = default_optimizer._running_eval_acc(
            y_hat, y_truth,
            n_total=train_n_total,
            n_correct=train_n_correct,
            soft_to_hard_fn=self.soft_to_hard_fn,
            soft_to_hard_fn_kwargs=self.soft_to_hard_fn_kwargs)

        # compute gradient
        if use_amp:
            # Scales loss. Calls backward() on scaled loss to create scaled gradients.
            # Backward passes under autocast are not recommended.
            # Backward ops run in the same dtype autocast chose for corresponding forward ops.
            scaler.scale(batch_train_loss).backward()
        else:
            if np.isnan(sum_batchmean_train_loss) or np.isnan(running_train_acc):
                default_optimizer._save_nandata(x, y_hat, y_truth, batch_train_loss,
                                                sum_batchmean_train_loss, running_train_acc,
                                                train_n_total, train_n_correct, model)
            batch_train_loss.backward()

        # perform gradient clipping if configured
        if self.optimizer_cfg.training_cfg.clip_grad:
            if use_amp:
                # Unscales the gradients of optimizer's assigned params in-place
                scaler.unscale_(self.optimizer)
            if self.optimizer_cfg.training_cfg.clip_type == 'norm':
                # clip_grad_norm_ modifies gradients in place
                # see: https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html
                torch_clip_grad.clip_grad_norm_(model.parameters(),
                                                self.optimizer_cfg.training_cfg.clip_val,
                                                **self.optimizer_cfg.training_cfg.clip_kwargs)
            elif self.optimizer_cfg.training_cfg.clip_type == 'val':
                # clip_grad_val_ modifies gradients in place
                # see: https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html
                torch_clip_grad.clip_grad_value_(
                    model.parameters(), self.optimizer_cfg.training_cfg.clip_val)
            else:
                msg = "Unknown clipping type for gradient clipping!"
                logger.error(msg)
                raise ValueError(msg)

        if use_amp:
            # scaler.step() first unscales the gradients of the optimizer's assigned params.
            # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
            # otherwise, optimizer.step() is skipped.
            scaler.step(self.optimizer)
            # Updates the scale for next iteration.
            scaler.update()
        else:
            self.optimizer.step()

        # report batch statistics to tensorflow
        if self.tb_writer:
            try:
                batch_num = int(epoch_num * num_batches + batch_idx)
                self.tb_writer.add_scalar(self.optimizer_cfg.reporting_cfg.experiment_name +
                                          '-train_loss', batch_train_loss.item(),
                                          global_step=batch_num)
                self.tb_writer.add_scalar(self.optimizer_cfg.reporting_cfg.experiment_name +
                                          '-running_train_acc', running_train_acc,
                                          global_step=batch_num)
            except:  # TODO: catch specific expcetions
                pass

        loop.set_description('Epoch {}/{}'.format(epoch_num + 1, self.num_epochs))
        loop.set_postfix(avg_train_loss=batch_train_loss.item())

        if batch_idx % self.num_batches_per_logmsg == 0:
            logger.info('{}\tTrain Epoch: {} [{}/{} ({:.0f}%)]\tTrainLoss: {:.6f}\tTrainAcc: {:.6f}'.format(
                pid, epoch_num, batch_idx * len(x), train_dataset_len,
                100. * batch_idx / num_batches, batch_train_loss.item(), running_train_acc))

    train_stats = EpochTrainStatistics(running_train_acc,
                                       sum_batchmean_train_loss / float(num_batches))

    # if we have validation data, we compute on the validation dataset
    num_val_batches_clean = len(val_clean_loader)
    if num_val_batches_clean > 0:
        logger.info('Running Validation on Clean Data')
        running_val_clean_acc, _, _, val_clean_loss = \
            default_optimizer._eval_acc(val_clean_loader, model, self.device,
                                        self.soft_to_hard_fn, self.soft_to_hard_fn_kwargs,
                                        self._eval_loss_function)
    else:
        logger.info("No dataset computed for validation on clean dataset!")
        running_val_clean_acc = None
        val_clean_loss = None

    num_val_batches_triggered = len(val_triggered_loader)
    if num_val_batches_triggered > 0:
        logger.info('Running Validation on Triggered Data')
        running_val_triggered_acc, _, _, val_triggered_loss = \
            default_optimizer._eval_acc(val_triggered_loader, model, self.device,
                                        self.soft_to_hard_fn, self.soft_to_hard_fn_kwargs,
                                        self._eval_loss_function)
    else:
        logger.info("No dataset computed for validation on triggered dataset!")
        running_val_triggered_acc = None
        val_triggered_loss = None

    validation_stats = EpochValidationStatistics(running_val_clean_acc, val_clean_loss,
                                                 running_val_triggered_acc, val_triggered_loss)

    if num_val_batches_clean > 0:
        logger.info('{}\tTrain Epoch: {} \tCleanValLoss: {:.6f}\tCleanValAcc: {:.6f}'.format(
            pid, epoch_num, val_clean_loss, running_val_clean_acc))
    if num_val_batches_triggered > 0:
        logger.info('{}\tTrain Epoch: {} \tTriggeredValLoss: {:.6f}\tTriggeredValAcc: {:.6f}'.format(
            pid, epoch_num, val_triggered_loss, running_val_triggered_acc))

    if self.tb_writer:
        try:
            batch_num = int((epoch_num + 1) * num_batches)
            if num_val_batches_clean > 0:
                self.tb_writer.add_scalar(self.optimizer_cfg.reporting_cfg.experiment_name +
                                          '-clean-val-loss', val_clean_loss,
                                          global_step=batch_num)
                self.tb_writer.add_scalar(self.optimizer_cfg.reporting_cfg.experiment_name +
                                          '-clean-val_acc', running_val_clean_acc,
                                          global_step=batch_num)
            if num_val_batches_triggered > 0:
                self.tb_writer.add_scalar(self.optimizer_cfg.reporting_cfg.experiment_name +
                                          '-triggered-val-loss', val_triggered_loss,
                                          global_step=batch_num)
                self.tb_writer.add_scalar(self.optimizer_cfg.reporting_cfg.experiment_name +
                                          '-triggered-val_acc', running_val_triggered_acc,
                                          global_step=batch_num)
        except:
            pass

    # update the lr-scheduler if necessary
    if self.lr_scheduler is not None:
        if self.optimizer_cfg.training_cfg.lr_scheduler_call_arg is None:
            self.lr_scheduler.step()
        elif self.optimizer_cfg.training_cfg.lr_scheduler_call_arg.lower() == 'val_acc':
            val_acc = validation_stats.get_val_acc()
            if val_acc is not None:
                self.lr_scheduler.step(val_acc)
            else:
                msg = "val_clean_acc not defined b/c validation dataset is not defined! Ignoring LR step!"
                logger.warning(msg)
        elif self.optimizer_cfg.training_cfg.lr_scheduler_call_arg.lower() == 'val_loss':
            val_loss = validation_stats.get_val_loss()
            if val_loss is not None:
                self.lr_scheduler.step(val_loss)
            else:
                msg = "val_clean_loss not defined b/c validation dataset is not defined! Ignoring LR step!"
                logger.warning(msg)
        else:
            msg = "Unknown mode for calling lr_scheduler!"
            logger.error(msg)
            raise ValueError(msg)

    return train_stats, validation_stats
cudnn.benchmark = True net, _, _, _ = build_model(args) _, teloader = prepare_test_data(args) _, trloader = prepare_train_data(args) optimizer = optim.SGD(list(net.parameters()), lr=args.lr, momentum=0.9, weight_decay=5e-4) scheduler = torch.optim.lr_scheduler.MultiStepLR( optimizer, [args.milestone_1, args.milestone_2], gamma=0.1, last_epoch=-1) criterion = nn.CrossEntropyLoss().cuda() all_err_cls = [] print('Running...') print('Error (%)\t\ttest') adversary = LinfPGDAttack( net, loss_fn=nn.CrossEntropyLoss().cuda(), eps=16/255, nb_iter=7, eps_iter=4/255, rand_init=True, clip_min=-1.0, clip_max=1.0, targeted=False) if args.alp: criterion_alp = nn.MSELoss().cuda() for epoch in range(1, args.nepoch+1): net.train() for batch_idx, (inputs, labels) in enumerate(trloader): inputs_cls, labels_cls = inputs.cuda(), labels.cuda() optimizer.zero_grad() with ctx_noparamgrad_and_eval(net): inputs_adv = adversary.perturb(inputs_cls, labels_cls) if args.weight == 0:
def train(self, train_loader, test_loader, adv_models, l_test_classif_paths, l_train_classif=None, eval_fn=None):
    """Alternating training of the attack generator (self.G) and critic (self.predict).

    Each epoch: run up to ``args.max_iter`` generator updates per batch, then one
    critic update; periodically evaluate transferability and checkpoint the
    generator whenever the evaluated accuracy improves (lower is better for an
    attacker).

    Parameters
    ----------
    train_loader, test_loader : data loaders; which one the generator trains on
        is selected by ``args.train_set``.
    adv_models : adversarially-trained models handed to the external evaluator.
    l_test_classif_paths : checkpoint paths for the transfer-eval target models.
    l_train_classif : optional dict of training-time target classifiers,
        forwarded to ``self.gen_update``.
    eval_fn : optional; if None, ``eval.eval`` is used, otherwise the
        ``eval_attacker`` path runs.

    Returns
    -------
    (adv_out, adv_inputs) : predictions and adversarial inputs from the LAST
        processed batch only.
        NOTE(review): raises NameError if the loader is empty — these names are
        bound inside the batch loop.
    """
    args = self.args
    pgd_adversary = None
    # Where to checkpoint the best generator.
    if args.save_model is None:
        filename = 'saved_models/generator.pt'
    else:
        filename = os.path.join(args.save_model, 'generator.pt')
    dirname = os.path.dirname(filename)
    if not os.path.exists(dirname):
        os.makedirs(dirname)
    gen_opt = optim.Adam(self.G.parameters(), lr=args.lr, betas=(args.momentum, .99))
    if args.lr_model is None:
        args.lr_model = args.lr
    model_opt = self.get_optim(self.predict.parameters(), lr=args.lr_model, betas=(args.momentum, .99))
    # Optionally wrap both optimizers for extragradient updates (min-max game).
    if args.extragradient:
        gen_opt = Extragradient(gen_opt, self.G.parameters())
        if not args.fixed_critic:
            model_opt = Extragradient(model_opt, self.predict.parameters())
    # NOTE(review): duplicate of the lr_model default applied above; harmless.
    if args.lr_model is None:
        args.lr_model = args.lr
    # Optional PGD attack applied to the critic during its update.
    if args.pgd_on_critic:
        pgd_adversary = LinfPGDAttack(
            self.predict, loss_fn=torch.nn.CrossEntropyLoss(reduction="sum"),
            eps=args.epsilon, nb_iter=40, eps_iter=0.01, rand_init=True,
            clip_min=0.0, clip_max=1.0, targeted=False)
    # Choose Attack Loss
    perturb_loss_func = kwargs_perturb_loss[args.perturb_loss]
    pick_prob = self.pick_prob_start
    ''' Training Phase '''
    # Accuracy of targets on our adversarial examples — lower is better, so
    # start from 100 (percent) and checkpoint on decreases.
    best_acc = 100
    for epoch in range(0, args.attack_epochs):
        adv_correct = 0
        clean_correct = 0
        num_clean = 0
        # Select which split the generator trains on this run.
        if args.train_set == 'test':
            print("Training generator on the test set")
            loader = test_loader
        elif args.train_set == 'train_and_test':
            print("Training generator on the test set and the train set")
            loader = concat_dataset(args, train_loader, test_loader)
        else:
            print("Training generator on the train set")
            loader = train_loader
        # NOTE(review): len(list(loader)) materializes every batch just to
        # count them — expensive for large loaders.
        train_itr = tqdm(enumerate(loader), total=len(list(loader)))
        # Anneal the probability of using the robust critic over epochs.
        if self.robust_train_flag:
            use_robust_critic = self.pick_rob_prob.sample()
            pick_prob = min(1.0, pick_prob + self.anneal_rate)
            self.pick_rob_prob = Bernoulli(torch.tensor([1 - pick_prob]))
            print("Using Robust Critic this Epoch with Prob: %f" % (1 - pick_prob))
        for batch_idx, (data, target) in train_itr:
            x, target = data.to(args.dev), target.to(args.dev)
            num_unperturbed = 10  # NOTE(review): appears unused in this block
            iter_count = 0
            # Sentinel initial values so logging works even before first update.
            loss_perturb = 20
            loss_model = 0
            loss_misclassify = -10
            # Epsilon annealing schedule for the perturbation budget.
            anneal_eps = 1 + args.anneal_eps ** (epoch + 1)
            # Inner loop: several generator steps per critic step.
            for i in range(args.max_iter):
                if args.not_use_labels:
                    # erase the current target to provide
                    adv_inputs, kl_div = self.perturb(x, compute_kl=True, anneal_eps=anneal_eps)
                else:
                    adv_inputs, kl_div = self.perturb(x, compute_kl=True, anneal_eps=anneal_eps, target=target)
                # Optim step for the generator
                loss_misclassify, loss_gen, loss_perturb = self.gen_update(
                    args, epoch, batch_idx, x, target, adv_inputs, l_train_classif,
                    kl_div, perturb_loss_func, gen_opt)
                iter_count = iter_count + 1
                # NOTE(review): redundant with the range() bound above.
                if iter_count > args.max_iter:
                    break
            # Track grads through x only when a gradient penalty is used.
            if args.gp_coeff > 0.:
                x = autograd.Variable(x, requires_grad=True)
            # Fresh adversarial batch (no KL) for the critic step, kept in [0, 1].
            adv_inputs = torch.clamp(self.perturb(x, anneal_eps=anneal_eps), min=0., max=1.0)
            adv_pred = self.predict(adv_inputs)
            adv_out = adv_pred.max(1, keepdim=True)[1]
            # Optim step for the classifier
            loss_model, clean_out, target_clean = self.critic_update(
                args, epoch, train_loader, batch_idx, x, target, adv_pred,
                model_opt, pgd_adversary)
            adv_correct += adv_out.eq(target.unsqueeze(1).data).sum()
            clean_correct += clean_out.eq(target_clean.unsqueeze(1).data).sum()
            num_clean += target_clean.shape[0]
            if args.wandb:
                nobox_wandb(args, epoch, x, target, adv_inputs, adv_out,
                            adv_correct, clean_correct, loss_misclassify,
                            loss_model, loss_perturb, loss_gen, train_loader)
        # Per-epoch summary (values from the last batch for the losses).
        print(f'\nTrain: Epoch:{epoch} Loss: {loss_model:.4f}, Gen Loss :{loss_gen:.4f}, '
              f'Missclassify Loss :{loss_misclassify:.4f} '
              f'Clean. Acc: {clean_correct}/{num_clean} '
              f'({100. * clean_correct.cpu().numpy()/num_clean:.0f}%) '
              f'Perturb Loss {loss_perturb:.4f} Adv. Acc: {adv_correct}/{len(loader.dataset)} '
              f'({100. * adv_correct.cpu().numpy()/len(loader.dataset):.0f}%)\n')
        if (epoch + 1) % args.eval_freq == 0:
            if eval_fn is None:
                with torch.no_grad():
                    mean_acc, std_acc, all_acc = eval.eval(
                        args, self, test_loader, l_test_classif_paths,
                        logger=self.logger, epoch=epoch)
            else:
                # Pick which architecture family to evaluate against.
                if args.target_arch is not None:
                    model_type = args.target_arch
                elif args.source_arch == "adv" or (args.source_arch == "ens_adv" and args.dataset == "mnist"):
                    model_type = [args.model_type]
                else:
                    model_type = [args.adv_models[0]]
                eval_helpers = [self.predict, model_type, adv_models, l_test_classif_paths, test_loader]
                mean_acc = eval_attacker(args, self, "AEG", eval_helpers, args.num_eval_samples)
                # eval_fn(list(l_train_classif.values())[0])
            # Lower target accuracy == stronger attack -> checkpoint generator.
            if mean_acc < best_acc:
                best_acc = mean_acc
                try:
                    torch.save({"model": self.G.state_dict(), "args": args}, filename)
                # NOTE(review): bare except deliberately best-effort, but it
                # hides the failure reason — consider logging the exception.
                except:
                    print("Warning: Failed to save model !")
    return adv_out, adv_inputs
# (continuation of a DataLoader(...) call whose opening lies outside this chunk)
batch_size=100, shuffle=False, num_workers=4)
# Build ResNet-18, move to the target device, wrap for multi-GPU.
net = ResNet18()
net = net.to(device)
net = torch.nn.DataParallel(net)
cudnn.benchmark = True  # fixed input sizes -> let cuDNN autotune conv kernels
# Restore pretrained weights from ./checkpoint/<file_name>.
checkpoint = torch.load('./checkpoint/' + file_name)
net.load_state_dict(checkpoint['net'])
# PGD evaluation attack: eps=0.0314 (~8/255), step 0.00784 (~2/255), 7 iters,
# random start, inputs assumed in [0, 1].
adversary = LinfPGDAttack(net, loss_fn=nn.CrossEntropyLoss(), eps=0.0314, nb_iter=7, eps_iter=0.00784, rand_init=True, clip_min=0.0, clip_max=1.0, targeted=False)
criterion = nn.CrossEntropyLoss()
def test():
    # Evaluate `net` on benign and PGD-adversarial inputs.
    # (function body continues beyond this chunk)
    print('\n[ Test Start ]')
    net.eval()
    benign_loss = 0
    adv_loss = 0
    benign_correct = 0
    adv_correct = 0
    total = 0
def sample_cases(sdim, args):
    """Case study: log-likelihood of one test image under noise and attack.

    Takes the first sample of the (deterministic, unshuffled) test loader,
    perturbs it with uniform/Gaussian noise and L-inf PGD at several radii,
    saves each variant as a PNG, and stores the per-variant true-class
    log-likelihoods (as scored by ``sdim``) to
    attack_logs/case_study/sample_likelihood_dict.pt.
    """
    sdim.eval()
    # Kept for parity with the (currently disabled) CW attack variants.
    n_classes = args.get(args.dataset).n_classes
    data_dir = hydra.utils.to_absolute_path(args.data_dir)
    dataset = get_dataset(data_name=args.dataset, data_dir=data_dir,
                          train=False, crop_flip=False)
    test_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False)

    x, y = next(iter(test_loader))
    x, y = x.to(args.device), y.long().to(args.device)

    def score_and_save(img, label, name):
        # True-class log-likelihood plus a PNG dump of the input.
        with torch.no_grad():
            log_lik = sdim(img)
        save_image(img, f'{name}.png', normalize=True)
        return log_lik[:, label].item()

    scores = {'original': score_and_save(x, y, 'original')}

    eps_2, eps_4, eps_8 = 2 / 255, 4 / 255, 8 / 255

    # Random perturbations, clamped back into the valid [0, 1] image range.
    # Draw order (uniform-4, gaussian-4, uniform-8, gaussian-8) matches the
    # previous implementation so RNG consumption is unchanged.
    noisy = {}
    for eps, tag in ((eps_4, '4'), (eps_8, '8')):
        noisy['uniform_' + tag] = (
            x + torch.FloatTensor(x.size()).uniform_(-eps, eps).to(args.device)
        ).clamp_(0., 1.)
        # NOTE(review): unit-variance noise clamped to +/-eps is almost always
        # saturated at +/-eps — possibly meant to be scaled before clamping.
        noisy['gaussian_' + tag] = (
            x + torch.randn(x.size()).clamp_(-eps, eps).to(args.device)
        ).clamp_(0., 1.)
    for key in ('uniform_4', 'uniform_8', 'gaussian_4', 'gaussian_8'):
        scores[key] = score_and_save(noisy[key], y, key)

    # L-inf PGD at three radii: craft every adversarial example first, then
    # score them, so the side-effect ordering mirrors the original code.
    # (CW attack variants were previously explored here and removed.)
    adv = {}
    for key, eps in (('pgd_2', eps_2), ('pgd_4', eps_4), ('pgd_8', eps_8)):
        adversary = LinfPGDAttack(
            sdim, loss_fn=nn.CrossEntropyLoss(reduction="sum"), eps=eps,
            nb_iter=40, eps_iter=0.01, rand_init=True,
            clip_min=-1.0, clip_max=1.0, targeted=False)
        adv[key] = adversary.perturb(x, y)
    for key in ('pgd_2', 'pgd_4', 'pgd_8'):
        scores[key] = score_and_save(adv[key], y, key)

    print(scores)
    save_dir = hydra.utils.to_absolute_path('attack_logs/case_study')
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)
    torch.save(scores, os.path.join(save_dir, 'sample_likelihood_dict.pt'))
def train_adv(args, model, device, train_loader, optimizer, scheduler, epoch,
              cycles, mse_parameter=1.0, clean_parameter=1.0, clean='supclean'):
    """One epoch of adversarial training for a feedback ("cycles") network.

    Each batch: craft PGD adversarial examples, run clean+adversarial inputs
    through the model, then iterate ``cycles`` feedback passes, accumulating
    cross-entropy terms and MSE reconstruction terms between adversarial
    reconstructions and clean features.

    Parameters
    ----------
    args : namespace with eps/nb_iter/eps_iter, dataset, res_parameter,
        grad_clip, log_interval.
    model : feedback network exposing reset()/forward(first=..., inter=...,
        step='backward') — project-specific contract.
    cycles : number of feedback iterations per batch.
    mse_parameter, clean_parameter : weights on the reconstruction and
        clean-CE loss terms.
    clean : any string containing 'no' disables the clean-CE term.

    Returns
    -------
    (train_loss, acc) : mean per-batch loss (Python float) and clean-logits
        training accuracy.
    """
    model.train()
    correct = 0
    train_loss = 0.0
    model.reset()
    # Training-time PGD attack; clip range [-1, 1] implies inputs normalized
    # to that range.
    adversary = LinfPGDAttack(model, loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                              eps=args.eps, nb_iter=args.nb_iter,
                              eps_iter=args.eps_iter, rand_init=True,
                              clip_min=-1.0, clip_max=1.0, targeted=False)
    print(len(train_loader))
    for batch_idx, (images, targets) in enumerate(train_loader):
        optimizer.zero_grad()
        # NOTE(review): hard-coded .cuda() ignores the `device` parameter —
        # confirm whether CPU runs are ever expected here.
        images = images.cuda()
        targets = targets.cuda()
        model.reset()
        # Craft adversarial inputs without accumulating grads on model params.
        with ctx_noparamgrad_and_eval(model):
            adv_images = adversary.perturb(images, targets)
        # Single forward over [clean; adversarial] concatenated along batch.
        images_all = torch.cat((images, adv_images), 0)
        # Reset the model latent variables
        model.reset()
        if (args.dataset == 'cifar10'):
            logits, orig_feature_all, block1_all, block2_all, block3_all = model(
                images_all, first=True, inter=True)
        elif (args.dataset == 'fashion'):
            logits, orig_feature_all, block1_all, block2_all = model(
                images_all, first=True, inter=True)
        ff_prev = orig_feature_all
        # f1 the original feature of clean images
        orig_feature, _ = torch.split(orig_feature_all, images.size(0))
        block1_clean, _ = torch.split(block1_all, images.size(0))
        block2_clean, _ = torch.split(block2_all, images.size(0))
        if (args.dataset == 'cifar10'):
            block3_clean, _ = torch.split(block3_all, images.size(0))
        logits_clean, logits_adv = torch.split(logits, images.size(0))
        # Cross-entropy terms are normalized by (cycles + 1) so the total CE
        # weight is independent of the number of feedback passes.
        if not ('no' in clean):
            loss = (clean_parameter * F.cross_entropy(logits_clean, targets)
                    + F.cross_entropy(logits_adv, targets)) / (2 * (cycles + 1))
        else:
            loss = F.cross_entropy(logits_adv, targets) / (cycles + 1)
        for i_cycle in range(cycles):
            # Backward (generative) pass reconstructs intermediate features.
            if (args.dataset == 'cifar10'):
                recon, block1_recon, block2_recon, block3_recon = model(
                    logits, step='backward', inter_recon=True)
            elif (args.dataset == 'fashion'):
                recon, block1_recon, block2_recon = model(
                    logits, step='backward', inter_recon=True)
            recon_clean, recon_adv = torch.split(recon, images.size(0))
            recon_block1_clean, recon_block1_adv = torch.split(
                block1_recon, images.size(0))
            recon_block2_clean, recon_block2_adv = torch.split(
                block2_recon, images.size(0))
            # Pull adversarial reconstructions toward the CLEAN features.
            if (args.dataset == 'cifar10'):
                recon_block3_clean, recon_block3_adv = torch.split(
                    block3_recon, images.size(0))
                loss += (F.mse_loss(recon_adv, orig_feature)
                         + F.mse_loss(recon_block1_adv, block1_clean)
                         + F.mse_loss(recon_block2_adv, block2_clean)
                         + F.mse_loss(recon_block3_adv, block3_clean)
                         ) * mse_parameter / (4 * cycles)
            elif (args.dataset == 'fashion'):
                loss += (F.mse_loss(recon_adv, orig_feature)
                         + F.mse_loss(recon_block1_adv, block1_clean)
                         + F.mse_loss(recon_block2_adv, block2_clean)
                         ) * mse_parameter / (3 * cycles)
            # feedforward: damped residual update of the first-layer feature.
            ff_current = ff_prev + args.res_parameter * (recon - ff_prev)
            logits = model(ff_current, first=False)
            ff_prev = ff_current
            logits_clean, logits_adv = torch.split(logits, images.size(0))
            if not ('no' in clean):
                loss += (clean_parameter * F.cross_entropy(logits_clean, targets)
                         + F.cross_entropy(logits_adv, targets)) / (2 * (cycles + 1))
            else:
                loss += F.cross_entropy(logits_adv, targets) / (cycles + 1)
        pred = logits_clean.argmax(
            dim=1, keepdim=True)  # get the index of the max log-probability
        correct += pred.eq(targets.view_as(pred)).sum().item()
        loss.backward()
        if (args.grad_clip):
            nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
        scheduler.step()
        # FIX: accumulate a Python float, not the loss tensor — accumulating
        # the tensor keeps it (and device memory) alive across the epoch and
        # returned a 0-d CUDA tensor instead of a scalar.
        train_loss += loss.item()
        if batch_idx % args.log_interval == 0:
            # FIX: progress counter was `batch_idx * len(images[0])`, i.e. the
            # CHANNEL count of one image — use the batch size for samples seen.
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(images), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
    train_loss /= len(train_loader)
    acc = correct / len(train_loader.dataset)
    return train_loss, acc