import numpy as np
import torch
import torch.nn.functional as F
# models, PGD, weights_init, and test_full are provided by this repo's own
# modules (import paths omitted here).


class AdvGAN_Attack:
    def __init__(self, device, model, model_num_labels, image_nc, box_min,
                 box_max, eps, pgd_iter, models_path, out_path, model_name,
                 writer, E_lr, defG_lr):
        output_nc = image_nc
        self.device = device
        self.model_num_labels = model_num_labels
        self.model = model
        self.input_nc = image_nc
        self.output_nc = output_nc
        self.box_min = box_min
        self.box_max = box_max
        self.eps = eps
        self.pgd_iter = pgd_iter
        self.models_path = models_path
        self.out_path = out_path
        self.model_name = model_name
        self.writer = writer
        self.E_lr = E_lr
        self.defG_lr = defG_lr
        self.en_input_nc = image_nc

        self.E = models.Encoder(image_nc).to(device)
        self.defG = models.Generator(adv=False).to(device)
        self.pgd = PGD(self.model, self.E, self.defG, self.device, self.eps)

        # initialize all weights
        self.E.apply(weights_init)
        self.defG.apply(weights_init)

        # initialize optimizers
        self.optimizer_E = torch.optim.Adam(self.E.parameters(), lr=self.E_lr)
        self.optimizer_defG = torch.optim.Adam(self.defG.parameters(),
                                               lr=self.defG_lr)

    # generate images for training
    def gen_images(self, x, labels):
        # pgd image
        pgd_images = self.pgd.perturb(x, labels, itr=1)

        # def(pgd) image
        def_pgd_images = self.defG(self.E(pgd_images)) + pgd_images
        def_pgd_images = torch.clamp(def_pgd_images, self.box_min,
                                     self.box_max)

        # def(nat) image
        def_images = self.defG(self.E(x)) + x
        def_images = torch.clamp(def_images, self.box_min, self.box_max)

        return pgd_images, def_pgd_images, def_images

    # performance tester
    def test(self):
        self.E.eval()
        self.defG.eval()
        test_full(self.device, self.model, self.E, self.defG, self.eps,
                  self.out_path, self.model_name, label_count=True,
                  save_img=True)
        self.E.train()
        self.defG.train()

    # train single batch
    def train_batch(self, x, labels):
        # optimize E and defG
        for i in range(1):
            # clear grad
            self.optimizer_E.zero_grad()
            self.optimizer_defG.zero_grad()

            pgd_images, def_pgd_images, def_images = self.gen_images(x, labels)

            # def(pgd) loss
            logits_def_pgd = self.model(def_pgd_images)
            loss_def_pgd = F.cross_entropy(logits_def_pgd, labels)

            # def(nat) loss
            logits_def = self.model(def_images)
            loss_def = F.cross_entropy(logits_def, labels)

            # backprop
            loss = loss_def + loss_def_pgd
            loss.backward()
            self.optimizer_E.step()
            self.optimizer_defG.step()

        # pgd performance check
        self.E.eval()
        self.defG.eval()

        pgd_acc_li = []
        pgd_nat_acc_li = []

        # obfuscated-gradients check: attack the bare model, then defend
        pgd_nat_images = self.pgd.perturb(x, labels, itr=0)
        pgd_nat_images = self.defG(self.E(pgd_nat_images)) + pgd_nat_images
        pgd_nat_images = torch.clamp(pgd_nat_images, self.box_min,
                                     self.box_max)

        pred = torch.argmax(self.model(pgd_images), 1)
        num_correct = torch.sum(pred == labels, 0)
        pgd_acc = num_correct.item() / len(labels)
        pgd_acc_li.append(pgd_acc)

        pred = torch.argmax(self.model(pgd_nat_images), 1)
        num_correct = torch.sum(pred == labels, 0)
        pgd_nat_acc = num_correct.item() / len(labels)
        pgd_nat_acc_li.append(pgd_nat_acc)

        self.E.train()
        self.defG.train()

        return pgd_acc_li, pgd_nat_acc_li, torch.sum(loss).item()

    # main training function
    def train(self, train_dataloader, epochs):
        for epoch in range(1, epochs + 1):
            if epoch == 50:
                self.optimizer_E = torch.optim.Adam(self.E.parameters(),
                                                    lr=self.E_lr / 10)
                self.optimizer_defG = torch.optim.Adam(self.defG.parameters(),
                                                       lr=self.defG_lr / 10)
            if epoch == 80:
                self.optimizer_E = torch.optim.Adam(self.E.parameters(),
                                                    lr=self.E_lr / 100)
                self.optimizer_defG = torch.optim.Adam(self.defG.parameters(),
                                                       lr=self.defG_lr / 100)

            loss_sum = 0
            pgd_acc_li_sum = []
            pgd_nat_acc_li_sum = []

            for i, data in enumerate(train_dataloader, start=0):
                images, labels = data
                images, labels = images.to(self.device), labels.to(self.device)

                pgd_acc_li_batch, pgd_nat_acc_li_batch, loss_batch = \
                    self.train_batch(images, labels)
                loss_sum += loss_batch
                pgd_acc_li_sum.append(pgd_acc_li_batch)
                pgd_nat_acc_li_sum.append(pgd_nat_acc_li_batch)

            # print statistics
            num_batch = len(train_dataloader)
            print("epoch %d:\nloss_E: %.5f" % (epoch, loss_sum / num_batch))

            pgd_acc_li_sum = np.mean(np.array(pgd_acc_li_sum), axis=0)
            for idx in range(len(self.pgd_iter)):
                print("pgd iter %d acc.: %.5f"
                      % (self.pgd_iter[idx], pgd_acc_li_sum[idx]))
            pgd_nat_acc_li_sum = np.mean(np.array(pgd_nat_acc_li_sum), axis=0)
            for idx in range(len(self.pgd_iter)):
                print("pgd nat iter %d acc.: %.5f"
                      % (self.pgd_iter[idx], pgd_nat_acc_li_sum[idx]))
            print()

            # write to tensorboard
            if self.writer:
                self.writer.add_scalar('loss', loss_sum / num_batch, epoch)
                for idx in range(len(self.pgd_iter)):
                    self.writer.add_scalar('pgd_acc_%d' % (self.pgd_iter[idx]),
                                           pgd_acc_li_sum[idx], epoch)
                    self.writer.add_scalar(
                        'pgd_nat_acc_%d' % (self.pgd_iter[idx]),
                        pgd_nat_acc_li_sum[idx], epoch)

            # save generator weights every 20 epochs
            if epoch % 20 == 0:
                E_file_name = (self.models_path + self.model_name
                               + 'E_epoch_' + str(epoch) + '.pth')
                defG_file_name = (self.models_path + self.model_name
                                  + 'defG_epoch_' + str(epoch) + '.pth')
                torch.save(self.E.state_dict(), E_file_name)
                torch.save(self.defG.state_dict(), defG_file_name)

        if self.writer:
            self.writer.close()

        # test performance
        self.test()
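# --- usage sketch (illustrative, not from the original repo) ---
# Driving the defense trainer end to end on MNIST. The dataset choice,
# epsilon, learning rates, and checkpoint path below are assumptions for the
# example; pgd_iter must contain a single entry here because train_batch
# records exactly one accuracy per list.
if __name__ == '__main__':
    import torchvision
    import torchvision.transforms as transforms
    from torch.utils.data import DataLoader

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataset = torchvision.datasets.MNIST('./dataset', train=True,
                                         transform=transforms.ToTensor(),
                                         download=True)
    dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

    # the target classifier is assumed to be pretrained and frozen
    target_model = MNIST_target_net().to(device)
    target_model.load_state_dict(torch.load('./models/target_model.pth'))
    target_model.eval()

    attack = AdvGAN_Attack(device, target_model, model_num_labels=10,
                           image_nc=1, box_min=0, box_max=1, eps=0.3,
                           pgd_iter=[1], models_path='./models/',
                           out_path='./out/', model_name='mnist_',
                           writer=None, E_lr=0.001, defG_lr=0.001)
    attack.train(dataloader, epochs=100)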
import os

import numpy as np
import torch
from torchvision.utils import make_grid, save_image
# PGD and normalized_eval come from this repo's own modules
# (import paths omitted here).


def tester(dataset, dataloader, device, target_model, E, defG, advG, eps,
           out_path, model_name, label_count=True, save_img=False):
    # load PGD
    pgd = PGD(target_model, E, defG, device, eps)

    num_correct_adv = 0
    num_correct_pgd = 0
    num_correct_def_adv = 0
    num_correct_def = 0
    num_correct_def_pgd = 0
    num_correct_def_pgd_nat = 0
    num_correct = 0

    test_img_full = []
    adv_img_full = []
    pgd_img_full = []
    pgd_nat_img_full = []
    def_img_full = []
    def_adv_img_full = []
    def_pgd_img_full = []
    def_pgd_nat_img_full = []

    pred_adv_full = []
    pred_pgd_full = []
    pred_def_pgd_full = []
    pred_def_pgd_nat_full = []

    for i, data in enumerate(dataloader, 0):
        # load images
        test_img, test_label = data
        test_img, test_label = test_img.to(device), test_label.to(device)

        # random target labels for the targeted adversarial generator
        target_labels = torch.randint_like(test_label, 0, 10)
        target_one_hot = torch.eye(10, device=device)[target_labels]
        target_one_hot = target_one_hot.view(-1, 10, 1, 1)

        # prep images
        x_encoded = E(test_img)

        adv_noise = advG(x_encoded, target_one_hot)
        adv_img = adv_noise * eps + test_img
        adv_img = torch.clamp(adv_img, 0, 1)

        def_adv_noise = defG(E(adv_img))
        def_adv_img = def_adv_noise + adv_img
        def_adv_img = torch.clamp(def_adv_img, 0, 1)

        def_noise = defG(E(test_img))
        def_img = def_noise + test_img
        def_img = torch.clamp(def_img, 0, 1)

        pgd_img = pgd.perturb(test_img, test_label, itr=1)
        def_pgd_noise = defG(E(pgd_img))
        def_pgd_img = def_pgd_noise + pgd_img
        def_pgd_img = torch.clamp(def_pgd_img, 0, 1)

        pgd_nat_img = pgd.perturb(test_img, test_label, itr=0)
        def_pgd_nat_noise = defG(E(pgd_nat_img))
        def_pgd_nat_img = def_pgd_nat_noise + pgd_nat_img
        def_pgd_nat_img = torch.clamp(def_pgd_nat_img, 0, 1)

        # calculate acc. (pgd accuracy is measured on the itr=0 image,
        # as in the original)
        pred = torch.argmax(normalized_eval(test_img, target_model), 1)
        pred_adv = torch.argmax(normalized_eval(adv_img, target_model), 1)
        pred_pgd = torch.argmax(normalized_eval(pgd_nat_img, target_model), 1)
        pred_def_adv = torch.argmax(
            normalized_eval(def_adv_img, target_model), 1)
        pred_def = torch.argmax(normalized_eval(def_img, target_model), 1)
        pred_def_pgd = torch.argmax(
            normalized_eval(def_pgd_img, target_model), 1)
        pred_def_pgd_nat = torch.argmax(
            normalized_eval(def_pgd_nat_img, target_model), 1)

        num_correct += torch.sum(pred == test_label, 0)
        num_correct_adv += torch.sum(pred_adv == test_label, 0)
        num_correct_pgd += torch.sum(pred_pgd == test_label, 0)
        num_correct_def_adv += torch.sum(pred_def_adv == test_label, 0)
        num_correct_def += torch.sum(pred_def == test_label, 0)
        num_correct_def_pgd += torch.sum(pred_def_pgd == test_label, 0)
        num_correct_def_pgd_nat += torch.sum(pred_def_pgd_nat == test_label, 0)

        if label_count:
            pred_adv_full.append(pred_adv)
            pred_pgd_full.append(pred_pgd)
            pred_def_pgd_full.append(pred_def_pgd)
            pred_def_pgd_nat_full.append(pred_def_pgd_nat)

        if save_img and i < 1:
            test_img_full.append(test_img)
            adv_img_full.append(adv_img)
            pgd_img_full.append(pgd_img)
            pgd_nat_img_full.append(pgd_nat_img)
            def_img_full.append(def_img)
            def_adv_img_full.append(def_adv_img)
            def_pgd_img_full.append(def_pgd_img)
            def_pgd_nat_img_full.append(def_pgd_nat_img)

    print('num_correct(nat): ', num_correct.item())
    print('num_correct(adv): ', num_correct_adv.item())
    print('num_correct(pgd): ', num_correct_pgd.item())
    print('num_correct(def(adv)): ', num_correct_def_adv.item())
    print('num_correct(def(nat)): ', num_correct_def.item())
    print('num_correct(def(pgd)): ', num_correct_def_pgd.item())
    print('num_correct(def(pgd_nat)): ', num_correct_def_pgd_nat.item())
    print()
    print('accuracy of nat imgs: %f' % (num_correct.item() / len(dataset)))
    print('accuracy of adv imgs: %f' % (num_correct_adv.item() / len(dataset)))
    print('accuracy of pgd imgs: %f' % (num_correct_pgd.item() / len(dataset)))
    print('accuracy of def(adv) imgs: %f'
          % (num_correct_def_adv.item() / len(dataset)))
    print('accuracy of def(nat) imgs: %f'
          % (num_correct_def.item() / len(dataset)))
    print('accuracy of def(pgd) imgs: %f'
          % (num_correct_def_pgd.item() / len(dataset)))
    print('accuracy of def(pgd_nat) imgs: %f'
          % (num_correct_def_pgd_nat.item() / len(dataset)))
    print()

    # l-inf distances, measured on the last batch
    l_inf = np.amax(np.abs(adv_img.cpu().detach().numpy()
                           - test_img.cpu().detach().numpy()))
    print('l-inf of adv imgs:%f' % (l_inf))
    l_inf = np.amax(np.abs(def_img.cpu().detach().numpy()
                           - test_img.cpu().detach().numpy()))
    print('l-inf of def imgs:%f' % (l_inf))
    l_inf = np.amax(np.abs(def_adv_img.cpu().detach().numpy()
                           - test_img.cpu().detach().numpy()))
    print('l-inf of def(adv) imgs:%f' % (l_inf))
    l_inf = np.amax(np.abs(def_pgd_img.cpu().detach().numpy()
                           - test_img.cpu().detach().numpy()))
    print('l-inf of def(pgd) imgs:%f' % (l_inf))
    l_inf = np.amax(np.abs(def_pgd_nat_img.cpu().detach().numpy()
                           - test_img.cpu().detach().numpy()))
    print('l-inf of def(pgd_nat) imgs:%f' % (l_inf))
    print()

    # mean l1 of each noise type, measured on the last batch; make sure the
    # noise output directory exists before saving
    os.makedirs('./out/noise', exist_ok=True)
    l_one = np.mean(np.abs(adv_noise.cpu().detach().numpy()))
    np.save('./out/noise/adv_noise.npy', adv_noise.cpu().detach().numpy())
    print('l-one of adv noise:%f' % (l_one))
    l_one = np.mean(np.abs(def_noise.cpu().detach().numpy()))
    np.save('./out/noise/def_noise.npy', def_noise.cpu().detach().numpy())
    print('l-one of def noise:%f' % (l_one))
    l_one = np.mean(np.abs(def_adv_noise.cpu().detach().numpy()))
    np.save('./out/noise/def_adv_noise.npy',
            def_adv_noise.cpu().detach().numpy())
    print('l-one of def(adv) noise:%f' % (l_one))
    l_one = np.mean(np.abs(def_pgd_noise.cpu().detach().numpy()))
    np.save('./out/noise/def_pgd_noise.npy',
            def_pgd_noise.cpu().detach().numpy())
    print('l-one of def(pgd) noise:%f' % (l_one))
    l_one = np.mean(np.abs(def_pgd_nat_noise.cpu().detach().numpy()))
    np.save('./out/noise/def_pgd_nat_noise.npy',
            def_pgd_nat_noise.cpu().detach().numpy())
    print('l-one of def(pgd_nat) noise:%f' % (l_one))
    print()

    if label_count:
        pred_adv_full = torch.cat(pred_adv_full)
        preds = pred_adv_full.cpu().detach().numpy()
        print('label counts in adv imgs:')
        print(np.unique(preds, return_counts=True))
        pred_pgd_full = torch.cat(pred_pgd_full)
        preds = pred_pgd_full.cpu().detach().numpy()
        print('label counts in pgd imgs:')
        print(np.unique(preds, return_counts=True))
        pred_def_pgd_full = torch.cat(pred_def_pgd_full)
        preds = pred_def_pgd_full.cpu().detach().numpy()
        print('label counts in def_pgd imgs:')
        print(np.unique(preds, return_counts=True))
        pred_def_pgd_nat_full = torch.cat(pred_def_pgd_nat_full)
        preds = pred_def_pgd_nat_full.cpu().detach().numpy()
        print('label counts in def_pgd_nat imgs:')
        print(np.unique(preds, return_counts=True))
        print()

    if save_img:
        test_img_full = torch.cat(test_img_full)
        adv_img_full = torch.cat(adv_img_full)
        pgd_img_full = torch.cat(pgd_img_full)
        pgd_nat_img_full = torch.cat(pgd_nat_img_full)
        def_img_full = torch.cat(def_img_full)
        def_adv_img_full = torch.cat(def_adv_img_full)
        def_pgd_img_full = torch.cat(def_pgd_img_full)
        def_pgd_nat_img_full = torch.cat(def_pgd_nat_img_full)

        test_grid = make_grid(test_img_full)
        adv_grid = make_grid(adv_img_full)
        pgd_grid = make_grid(pgd_img_full)
        pgd_nat_grid = make_grid(pgd_nat_img_full)
        def_grid = make_grid(def_img_full)
        def_adv_grid = make_grid(def_adv_img_full)
        def_pgd_grid = make_grid(def_pgd_img_full)
        def_pgd_nat_grid = make_grid(def_pgd_nat_img_full)

        if not os.path.exists(out_path + model_name):
            os.makedirs(out_path + model_name)
        save_image(test_grid, out_path + model_name + 'test_grid.png')
        save_image(adv_grid, out_path + model_name + 'adv_grid.png')
        save_image(pgd_grid, out_path + model_name + 'pgd_grid.png')
        save_image(pgd_nat_grid, out_path + model_name + 'pgd_nat_grid.png')
        save_image(def_grid, out_path + model_name + 'def_grid.png')
        save_image(def_adv_grid, out_path + model_name + 'def_adv_grid.png')
        save_image(def_pgd_grid, out_path + model_name + 'def_pgd_grid.png')
        save_image(def_pgd_nat_grid,
                   out_path + model_name + 'def_pgd_nat_grid.png')
        print('images saved')
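# --- reference sketch (not the repo's code) ---
# The PGD helper used above is defined elsewhere in the repo. Inferred from
# its call sites (PGD(model, E, defG, device, eps[, num_steps, step_size]) and
# perturb(x, labels, itr), where itr=0 attacks the bare classifier and
# itr >= 1 routes the attack through the encoder/defensive generator), a
# minimal l-inf PGD consistent with that interface could look like this:
import torch
import torch.nn.functional as F

class PGDSketch:
    def __init__(self, model, E, defG, device, eps=0.3, num_steps=7,
                 step_size=None):
        self.model, self.E, self.defG = model, E, defG
        self.device = device
        self.eps = eps
        self.num_steps = num_steps
        self.step_size = step_size if step_size is not None else eps / 4

    def perturb(self, x, labels, itr=0):
        # random start inside the l-inf ball around x
        x_adv = x + torch.empty_like(x).uniform_(-self.eps, self.eps)
        x_adv = torch.clamp(x_adv, 0, 1).detach()
        for _ in range(self.num_steps):
            x_adv.requires_grad_(True)
            out = x_adv
            # optionally pass through the defense itr times before the model
            for _ in range(itr):
                out = torch.clamp(self.defG(self.E(out)) + out, 0, 1)
            loss = F.cross_entropy(self.model(out), labels)
            grad = torch.autograd.grad(loss, x_adv)[0]
            # ascend the loss, then project back into the eps-ball around x
            x_adv = x_adv.detach() + self.step_size * grad.sign()
            x_adv = torch.min(torch.max(x_adv, x - self.eps), x + self.eps)
            x_adv = torch.clamp(x_adv, 0, 1)
        return x_adv.detach()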
# print parsed arguments
for arg in vars(args):
    print(arg, getattr(args, arg))

# define which device we are using
print("CUDA Available: ", torch.cuda.is_available())
device = torch.device("cuda" if (use_cuda and torch.cuda.is_available())
                      else "cpu")

mnist_dataset = torchvision.datasets.MNIST('./dataset', train=True,
                                           transform=transforms.ToTensor(),
                                           download=True)
train_dataloader = DataLoader(mnist_dataset, batch_size=args.batch_size,
                              shuffle=True, num_workers=1)

# train the target model
target_model = MNIST_target_net().to(device)
target_model.train()
opt_model = torch.optim.Adam(target_model.parameters(), lr=0.0001)
epochs = args.epochs
pgd = PGD(target_model, None, None, device, args.epsilon, 7, args.epsilon / 4)

for epoch in range(epochs):
    loss_epoch = 0
    if epoch == 20:
        opt_model = torch.optim.Adam(target_model.parameters(), lr=0.00001)
    num_corrects = 0
    adv_num_corrects = 0
    total = 0
    for i, data in enumerate(train_dataloader, 0):
        train_imgs, train_labels = data
        train_imgs, train_labels = (train_imgs.to(device),
                                    train_labels.to(device))
        logits_model = target_model(train_imgs)
        adv_imgs = pgd.perturb(train_imgs, train_labels, itr=0)
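        # --- inferred continuation (sketch) ---
        # The original fragment is truncated here. Based on the counters
        # initialized above (loss_epoch, num_corrects, adv_num_corrects,
        # total), a standard adversarial-training step is assumed: minimize
        # cross-entropy on both the natural and the PGD batch. This is a
        # plausible reconstruction, not the repo's verbatim code.
        adv_logits_model = target_model(adv_imgs)
        loss_model = (torch.nn.functional.cross_entropy(logits_model,
                                                        train_labels)
                      + torch.nn.functional.cross_entropy(adv_logits_model,
                                                          train_labels))
        opt_model.zero_grad()
        loss_model.backward()
        opt_model.step()
        loss_epoch += loss_model.item()

        # track natural and adversarial training accuracy
        num_corrects += torch.sum(
            torch.argmax(logits_model, 1) == train_labels).item()
        adv_num_corrects += torch.sum(
            torch.argmax(adv_logits_model, 1) == train_labels).item()
        total += train_labels.size(0)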
import numpy as np
import torch
import torch.nn.functional as F
# models, PGD, weights_init, normalized_eval, and test_full are provided by
# this repo's own modules (import paths omitted here).


class AdvGAN_Attack:
    def __init__(self, device, model, model_num_labels, image_nc, box_min,
                 box_max, eps, pgd_iter, models_path, out_path, model_name,
                 writer, E_lr, advG_lr, defG_lr):
        output_nc = image_nc
        self.device = device
        self.model_num_labels = model_num_labels
        self.model = model
        self.input_nc = image_nc
        self.output_nc = output_nc
        self.box_min = box_min
        self.box_max = box_max
        self.eps = eps
        self.pgd_iter = pgd_iter
        self.models_path = models_path
        self.out_path = out_path
        self.model_name = model_name
        self.writer = writer
        self.E_lr = E_lr
        self.advG_lr = advG_lr
        self.defG_lr = defG_lr
        self.en_input_nc = image_nc

        self.E = models.Encoder(image_nc).to(device)
        self.defG = models.Generator(adv=False).to(device)
        self.advG = models.Generator(y_dim=model_num_labels,
                                     adv=True).to(device)
        self.pgd = PGD(self.model, self.E, self.defG, self.device, self.eps,
                       step_size=self.eps / 4)

        # initialize all weights
        self.E.apply(weights_init)
        self.defG.apply(weights_init)
        self.advG.apply(weights_init)

        # initialize optimizers
        self.optimizer_E = torch.optim.Adam(self.E.parameters(), lr=self.E_lr)
        self.optimizer_defG = torch.optim.Adam(self.defG.parameters(),
                                               lr=self.defG_lr)
        self.optimizer_advG = torch.optim.Adam(self.advG.parameters(),
                                               lr=self.advG_lr)

    # generate images for training
    def gen_images(self, x, labels, adv=True, def_adv=True, def_nat=True):
        results = []

        # random target labels
        target_labels = torch.randint_like(labels, 0, self.model_num_labels)
        target_one_hot = torch.eye(self.model_num_labels,
                                   device=self.device)[target_labels]
        target_one_hot = target_one_hot.view(-1, self.model_num_labels, 1, 1)

        x_encoded = self.E(x)

        if adv or def_adv:
            # make adv image
            adv_images = self.advG(x_encoded, target_one_hot) * self.eps + x
            adv_images = torch.clamp(adv_images, self.box_min, self.box_max)
            results.append(adv_images)

        if def_adv:
            # make def(adv) image
            def_adv_images = self.defG(self.E(adv_images)) + adv_images
            def_adv_images = torch.clamp(def_adv_images, self.box_min,
                                         self.box_max)
            results.append(def_adv_images)

        if def_nat:
            # make def(nat) image
            def_images = self.defG(x_encoded) + x
            def_images = torch.clamp(def_images, self.box_min, self.box_max)
            results.append(def_images)

        results.append(target_labels)
        return results

    # performance tester
    def test(self):
        self.E.eval()
        self.defG.eval()
        test_full(self.device, self.model, self.E, self.defG, self.advG,
                  self.eps, self.out_path, self.model_name, label_count=True,
                  save_img=True)
        self.E.train()
        self.defG.train()

    # train single batch
    def train_batch(self, x, labels, batch_num):
        # optimize E
        for i in range(1):
            # clear grad
            self.optimizer_E.zero_grad()

            adv_images, def_adv_images, def_images, target_labels = \
                self.gen_images(x, labels)

            # adv loss
            logits_adv = normalized_eval(adv_images, self.model)
            loss_adv = F.cross_entropy(logits_adv, target_labels)

            # def(adv) loss
            logits_def_adv = normalized_eval(def_adv_images, self.model)
            loss_def_adv = F.cross_entropy(logits_def_adv, labels)

            # def(nat) loss
            logits_def = normalized_eval(def_images, self.model)
            loss_def = F.cross_entropy(logits_def, labels)

            # backprop
            loss_E = loss_adv + loss_def_adv + loss_def
            loss_E.backward()
            self.optimizer_E.step()

        # optimize advG
        for i in range(1):
            # clear grad
            self.optimizer_advG.zero_grad()

            adv_images, def_adv_images, target_labels = self.gen_images(
                x, labels, def_nat=False)

            # adv loss
            logits_adv = normalized_eval(adv_images, self.model)
            loss_adv = F.cross_entropy(logits_adv, target_labels)

            # def(adv) loss
            logits_def_adv = normalized_eval(def_adv_images, self.model)
            loss_def_adv = F.cross_entropy(logits_def_adv, target_labels)

            # backprop
            loss_advG = loss_adv + loss_def_adv
            loss_advG.backward()
            self.optimizer_advG.step()

        # optimize defG
        for i in range(1):
            # clear grad
            self.optimizer_defG.zero_grad()

            _, def_adv_images, def_images, _ = self.gen_images(x, labels,
                                                               adv=False)

            # def(adv) loss
            logits_def_adv = normalized_eval(def_adv_images, self.model)
            loss_def_adv = F.cross_entropy(logits_def_adv, labels)

            # def(nat) loss
            logits_def = normalized_eval(def_images, self.model)
            loss_def = F.cross_entropy(logits_def, labels)

            # backprop
            loss_defG = loss_def_adv + loss_def
            loss_defG.backward()
            self.optimizer_defG.step()

        if batch_num == 0:
            self.E.eval()
            self.advG.eval()
            self.defG.eval()

            # adv, def performance check
            adv_pred = torch.argmax(normalized_eval(adv_images, self.model), 1)
            adv_correct = torch.sum(adv_pred == labels, 0)
            adv_acc = adv_correct.item() / len(labels)

            def_adv_pred = torch.argmax(
                normalized_eval(def_adv_images, self.model), 1)
            def_adv_correct = torch.sum(def_adv_pred == labels, 0)
            def_adv_acc = def_adv_correct.item() / len(labels)

            def_pred = torch.argmax(normalized_eval(def_images, self.model), 1)
            def_correct = torch.sum(def_pred == labels, 0)
            def_acc = def_correct.item() / len(labels)

            nat_pred = torch.argmax(normalized_eval(x, self.model), 1)
            nat_correct = torch.sum(nat_pred == labels, 0)
            nat_acc = nat_correct.item() / len(labels)

            print('adv mean perturbation: %.5f'
                  % torch.abs(adv_images - x).mean().item())
            print('def_adv mean perturbation: %.5f'
                  % torch.abs(def_adv_images - x).mean().item())
            print('def mean perturbation: %.5f'
                  % torch.abs(def_images - x).mean().item())
            print()

            # pgd performance check
            pgd_acc_li = []
            pgd_nat_acc_li = []
            for itr in self.pgd_iter:
                pgd_img = self.pgd.perturb(x, labels, itr=itr)
                pgd_nat_img = self.pgd.perturb(x, labels, itr=0)

                for _ in range(itr):
                    pgd_img = self.defG(self.E(pgd_img)) + pgd_img
                    pgd_img = torch.clamp(pgd_img, self.box_min, self.box_max)

                # obfuscated-gradients check
                pgd_nat_img = self.defG(self.E(pgd_nat_img)) + pgd_nat_img
                pgd_nat_img = torch.clamp(pgd_nat_img, self.box_min,
                                          self.box_max)

                pred = torch.argmax(normalized_eval(pgd_img, self.model), 1)
                num_correct = torch.sum(pred == labels, 0)
                pgd_acc = num_correct.item() / len(labels)
                pgd_acc_li.append(pgd_acc)

                pred = torch.argmax(normalized_eval(pgd_nat_img, self.model), 1)
                num_correct = torch.sum(pred == labels, 0)
                pgd_nat_acc = num_correct.item() / len(labels)
                pgd_nat_acc_li.append(pgd_nat_acc)

            self.defG.train()
            self.advG.train()
            self.E.train()
        else:
            pgd_acc_li = None
            pgd_nat_acc_li = None
            adv_acc = None
            def_adv_acc = None
            def_acc = None
            nat_acc = None

        return pgd_acc_li, pgd_nat_acc_li, torch.sum(loss_E).item(), \
            torch.sum(loss_advG).item(), torch.sum(loss_defG).item(), \
            adv_acc, def_adv_acc, def_acc, nat_acc

    # main training function
    def train(self, train_dataloader, epochs):
        for epoch in range(1, epochs + 1):
            if epoch == 50:
                self.optimizer_E = torch.optim.Adam(self.E.parameters(),
                                                    lr=self.E_lr / 10)
                self.optimizer_defG = torch.optim.Adam(self.defG.parameters(),
                                                       lr=self.defG_lr / 10)
                self.optimizer_advG = torch.optim.Adam(self.advG.parameters(),
                                                       lr=self.advG_lr / 10)
            if epoch == 80:
                self.optimizer_E = torch.optim.Adam(self.E.parameters(),
                                                    lr=self.E_lr / 100)
                self.optimizer_defG = torch.optim.Adam(self.defG.parameters(),
                                                       lr=self.defG_lr / 100)
                self.optimizer_advG = torch.optim.Adam(self.advG.parameters(),
                                                       lr=self.advG_lr / 100)

            loss_E_sum = 0
            loss_defG_sum = 0
            loss_advG_sum = 0
            pgd_acc_li_sum = []
            pgd_nat_acc_li_sum = []
            nat_acc_sum = 0
            adv_acc_sum = 0
            def_adv_acc_sum = 0
            def_acc_sum = 0

            for i, data in enumerate(train_dataloader, start=0):
                images, labels = data
                images, labels = images.to(self.device), labels.to(self.device)

                pgd_acc_li_batch, pgd_nat_acc_li_batch, loss_E_batch, \
                    loss_advG_batch, loss_defG_batch, \
                    adv_acc, def_adv_acc, def_acc, nat_acc = \
                    self.train_batch(images, labels, i)
                loss_E_sum += loss_E_batch
                loss_advG_sum += loss_advG_batch
                loss_defG_sum += loss_defG_batch
                if pgd_acc_li_batch:
                    pgd_acc_li_sum.append(pgd_acc_li_batch)
                    pgd_nat_acc_li_sum.append(pgd_nat_acc_li_batch)
                    nat_acc_sum += nat_acc
                    adv_acc_sum += adv_acc
                    def_adv_acc_sum += def_adv_acc
                    def_acc_sum += def_acc

            # print statistics (the adv/def accuracies are sampled on the
            # first batch of each epoch only)
            num_batch = len(train_dataloader)
            print("epoch %d:\nloss_E: %.5f, loss_advG: %.5f, loss_defG: %.5f"
                  % (epoch, loss_E_sum / num_batch, loss_advG_sum / num_batch,
                     loss_defG_sum / num_batch))

            pgd_acc_li_sum = np.mean(np.array(pgd_acc_li_sum), axis=0)
            for idx in range(len(self.pgd_iter)):
                print("pgd iter %d acc.: %.5f"
                      % (self.pgd_iter[idx], pgd_acc_li_sum[idx]))
            pgd_nat_acc_li_sum = np.mean(np.array(pgd_nat_acc_li_sum), axis=0)
            for idx in range(len(self.pgd_iter)):
                print("pgd nat iter %d acc.: %.5f"
                      % (self.pgd_iter[idx], pgd_nat_acc_li_sum[idx]))
            print("nat acc.: %.5f" % (nat_acc_sum))
            print("adv acc.: %.5f" % (adv_acc_sum))
            print("def_adv acc.: %.5f" % (def_adv_acc_sum))
            print("def acc.: %.5f" % (def_acc_sum))
            print()

            # write to tensorboard
            if self.writer:
                self.writer.add_scalar('loss_E', loss_E_sum / num_batch, epoch)
                self.writer.add_scalar('loss_advG', loss_advG_sum / num_batch,
                                       epoch)
                self.writer.add_scalar('loss_defG', loss_defG_sum / num_batch,
                                       epoch)
                for idx in range(len(self.pgd_iter)):
                    self.writer.add_scalar('pgd_acc_%d' % (self.pgd_iter[idx]),
                                           pgd_acc_li_sum[idx], epoch)
                    self.writer.add_scalar(
                        'pgd_nat_acc_%d' % (self.pgd_iter[idx]),
                        pgd_nat_acc_li_sum[idx], epoch)
                self.writer.add_scalar('nat_acc', nat_acc_sum, epoch)
                self.writer.add_scalar('adv_acc', adv_acc_sum, epoch)
                self.writer.add_scalar('def_adv_acc', def_adv_acc_sum, epoch)
                self.writer.add_scalar('def_acc', def_acc_sum, epoch)

            # save generators every 20 epochs
            if epoch % 20 == 0:
                E_file_name = (self.models_path + self.model_name
                               + 'E_epoch_' + str(epoch) + '.pth')
                advG_file_name = (self.models_path + self.model_name
                                  + 'advG_epoch_' + str(epoch) + '.pth')
                defG_file_name = (self.models_path + self.model_name
                                  + 'defG_epoch_' + str(epoch) + '.pth')
                torch.save(self.E.state_dict(), E_file_name)
                torch.save(self.advG.state_dict(), advG_file_name)
                torch.save(self.defG.state_dict(), defG_file_name)

        if self.writer:
            self.writer.close()

        # test performance
        self.test()
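# normalized_eval is used throughout this trainer but is not defined in this
# section. A minimal sketch of the assumed behavior: the generators and PGD
# operate in [0, 1] pixel space, so inputs are normalized per channel only at
# evaluation time. The CIFAR-10 statistics below are an assumption for the
# example; the repo's own normalized_eval may differ.
def normalized_eval_sketch(x, model):
    mean = torch.tensor([0.4914, 0.4822, 0.4465],
                        device=x.device).view(1, 3, 1, 1)
    std = torch.tensor([0.2471, 0.2435, 0.2616],
                       device=x.device).view(1, 3, 1, 1)
    return model((x - mean) / std)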
import os

import numpy as np
import torch
from torchvision.utils import make_grid, save_image
# PGD comes from this repo's own modules (import path omitted here).


def tester(dataset, dataloader, device, target_model, E, defG, eps, out_path,
           model_name, label_count=True, save_img=False):
    # load PGD with the given epsilon
    pgd = PGD(target_model, E, defG, device, eps)

    num_correct_pgd = 0
    num_correct_def = 0
    num_correct_def_pgd = 0
    num_correct_def_pgd_nat = 0
    num_correct = 0

    test_img_full = []
    pgd_img_full = []
    pgd_nat_img_full = []
    def_img_full = []
    def_pgd_img_full = []
    def_pgd_nat_img_full = []

    pred_pgd_full = []
    pred_def_pgd_full = []
    pred_def_pgd_nat_full = []

    for i, data in enumerate(dataloader, 0):
        # load images
        test_img, test_label = data
        test_img, test_label = test_img.to(device), test_label.to(device)

        def_noise = defG(E(test_img))
        def_img = def_noise + test_img
        def_img = torch.clamp(def_img, 0, 1)

        pgd_img = pgd.perturb(test_img, test_label, itr=1)
        def_pgd_noise = defG(E(pgd_img))
        def_pgd_img = def_pgd_noise + pgd_img
        def_pgd_img = torch.clamp(def_pgd_img, 0, 1)

        pgd_nat_img = pgd.perturb(test_img, test_label, itr=0)
        def_pgd_nat_noise = defG(E(pgd_nat_img))
        def_pgd_nat_img = def_pgd_nat_noise + pgd_nat_img
        def_pgd_nat_img = torch.clamp(def_pgd_nat_img, 0, 1)

        # calculate acc. (pgd accuracy is measured on the itr=0 image,
        # as in the original)
        pred = torch.argmax(target_model(test_img), 1)
        pred_pgd = torch.argmax(target_model(pgd_nat_img), 1)
        pred_def = torch.argmax(target_model(def_img), 1)
        pred_def_pgd = torch.argmax(target_model(def_pgd_img), 1)
        pred_def_pgd_nat = torch.argmax(target_model(def_pgd_nat_img), 1)

        num_correct += torch.sum(pred == test_label, 0)
        num_correct_pgd += torch.sum(pred_pgd == test_label, 0)
        num_correct_def += torch.sum(pred_def == test_label, 0)
        num_correct_def_pgd += torch.sum(pred_def_pgd == test_label, 0)
        num_correct_def_pgd_nat += torch.sum(pred_def_pgd_nat == test_label, 0)

        if label_count:
            pred_pgd_full.append(pred_pgd)
            pred_def_pgd_full.append(pred_def_pgd)
            pred_def_pgd_nat_full.append(pred_def_pgd_nat)

        if save_img and i < 1:
            test_img_full.append(test_img)
            pgd_img_full.append(pgd_img)
            pgd_nat_img_full.append(pgd_nat_img)
            def_img_full.append(def_img)
            def_pgd_img_full.append(def_pgd_img)
            def_pgd_nat_img_full.append(def_pgd_nat_img)

    print('num_correct(nat): ', num_correct.item())
    print('num_correct(pgd): ', num_correct_pgd.item())
    print('num_correct(def(nat)): ', num_correct_def.item())
    print('num_correct(def(pgd)): ', num_correct_def_pgd.item())
    print('num_correct(def(pgd_nat)): ', num_correct_def_pgd_nat.item())
    print()
    print('accuracy of nat imgs: %f' % (num_correct.item() / len(dataset)))
    print('accuracy of pgd imgs: %f' % (num_correct_pgd.item() / len(dataset)))
    print('accuracy of def(nat) imgs: %f'
          % (num_correct_def.item() / len(dataset)))
    print('accuracy of def(pgd) imgs: %f'
          % (num_correct_def_pgd.item() / len(dataset)))
    print('accuracy of def(pgd_nat) imgs: %f'
          % (num_correct_def_pgd_nat.item() / len(dataset)))
    print()

    # l-inf distances, measured on the last batch
    l_inf = np.amax(np.abs(def_img.cpu().detach().numpy()
                           - test_img.cpu().detach().numpy()))
    print('l-inf of def imgs:%f' % (l_inf))
    l_inf = np.amax(np.abs(def_pgd_img.cpu().detach().numpy()
                           - test_img.cpu().detach().numpy()))
    print('l-inf of def(pgd) imgs:%f' % (l_inf))
    l_inf = np.amax(np.abs(def_pgd_nat_img.cpu().detach().numpy()
                           - test_img.cpu().detach().numpy()))
    print('l-inf of def(pgd_nat) imgs:%f' % (l_inf))
    print()

    if label_count:
        pred_pgd_full = torch.cat(pred_pgd_full)
        preds = pred_pgd_full.cpu().detach().numpy()
        print('label counts in pgd imgs:')
        print(np.unique(preds, return_counts=True))
        pred_def_pgd_full = torch.cat(pred_def_pgd_full)
        preds = pred_def_pgd_full.cpu().detach().numpy()
        print('label counts in def_pgd imgs:')
        print(np.unique(preds, return_counts=True))
        pred_def_pgd_nat_full = torch.cat(pred_def_pgd_nat_full)
        preds = pred_def_pgd_nat_full.cpu().detach().numpy()
        print('label counts in def_pgd_nat imgs:')
        print(np.unique(preds, return_counts=True))
        print()

    if save_img:
        test_img_full = torch.cat(test_img_full)
        pgd_img_full = torch.cat(pgd_img_full)
        pgd_nat_img_full = torch.cat(pgd_nat_img_full)
        def_img_full = torch.cat(def_img_full)
        def_pgd_img_full = torch.cat(def_pgd_img_full)
        def_pgd_nat_img_full = torch.cat(def_pgd_nat_img_full)

        test_grid = make_grid(test_img_full)
        pgd_grid = make_grid(pgd_img_full)
        pgd_nat_grid = make_grid(pgd_nat_img_full)
        def_grid = make_grid(def_img_full)
        def_pgd_grid = make_grid(def_pgd_img_full)
        def_pgd_nat_grid = make_grid(def_pgd_nat_img_full)

        if not os.path.exists(out_path + model_name):
            os.makedirs(out_path + model_name)
        save_image(test_grid, out_path + model_name + 'test_grid.png')
        save_image(pgd_grid, out_path + model_name + 'pgd_grid.png')
        save_image(pgd_nat_grid, out_path + model_name + 'pgd_nat_grid.png')
        save_image(def_grid, out_path + model_name + 'def_grid.png')
        save_image(def_pgd_grid, out_path + model_name + 'def_pgd_grid.png')
        save_image(def_pgd_nat_grid,
                   out_path + model_name + 'def_pgd_nat_grid.png')
        print('images saved')
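# --- usage sketch (illustrative, not from the original repo) ---
# Evaluating saved E/defG checkpoints with the tester above. The checkpoint
# filenames, epsilon, and the MNIST test-set wiring are assumptions for the
# example.
if __name__ == '__main__':
    import torchvision
    import torchvision.transforms as transforms
    from torch.utils.data import DataLoader

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    testset = torchvision.datasets.MNIST('./dataset', train=False,
                                         transform=transforms.ToTensor(),
                                         download=True)
    testloader = DataLoader(testset, batch_size=256, shuffle=False)

    target_model = MNIST_target_net().to(device)
    target_model.load_state_dict(torch.load('./models/target_model.pth'))
    target_model.eval()

    E = models.Encoder(1).to(device)
    E.load_state_dict(torch.load('./models/mnist_E_epoch_100.pth'))
    defG = models.Generator(adv=False).to(device)
    defG.load_state_dict(torch.load('./models/mnist_defG_epoch_100.pth'))
    E.eval()
    defG.eval()

    tester(testset, testloader, device, target_model, E, defG, eps=0.3,
           out_path='./out/', model_name='mnist_', label_count=True,
           save_img=True)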
trainset = torchvision.datasets.CIFAR10(root='./dataset/', train=True,
                                        download=True,
                                        transform=train_transform)
train_loader = torch.utils.data.DataLoader(trainset,
                                           batch_size=args.batch_size,
                                           shuffle=True, num_workers=8)
testset = torchvision.datasets.CIFAR10(root='./dataset/', train=False,
                                       download=True,
                                       transform=test_transform)
test_loader = torch.utils.data.DataLoader(testset, batch_size=args.test_bs,
                                          shuffle=False, num_workers=8)

# create model
net = WideResNet(args.layers, 10, args.widen_factor, dropRate=args.droprate)
net.rot_pred = nn.Linear(128, 4)

# create PGD adversary
adversary = PGD(epsilon=8. / 255., num_steps=10, step_size=2. / 255.)
adversary_test = PGD(epsilon=8. / 255., num_steps=20, step_size=1. / 255.,
                     attack_rotations=False)

start_epoch = 0

# restore model if desired
if args.load != '':
    for i in range(1000 - 1, -1, -1):
        model_name = os.path.join(args.load, 'cifar10' + '_' + 'wrn'
                                  + '_baseline_epoch_' + str(i) + '.pt')
        if os.path.isfile(model_name):
            net.load_state_dict(torch.load(model_name))
            print('Model restored! Epoch:', i)
            start_epoch = i + 1
            break
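# --- usage sketch (not part of the original fragment) ---
# The script is cut off before the training loop. The sketch below shows the
# usual next step for this setup; the adversary's call convention
# (adversary(net, x, y) -> adversarial batch) and the optimizer settings are
# hypothetical, inferred from common WideResNet adversarial-training scripts.
import torch.nn.functional as F

def train_epoch_sketch(net, adversary, loader, optimizer, device):
    net.train()
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        # craft adversarial examples against the current weights
        adv_images = adversary(net, images, labels)
        optimizer.zero_grad()
        # minimize cross-entropy on the adversarial batch
        loss = F.cross_entropy(net(adv_images), labels)
        loss.backward()
        optimizer.step()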