def predict_then_update_loss_acc_meter(self, meter, data, target):
    """Run an eval-mode forward pass, fold loss/accuracy into `meter`.

    Returns the (loss, accuracy) pair that was recorded.
    """
    with torch.no_grad(), ctx_eval(self.model):
        logits = self.model(data)
        batch_acc = get_accuracy(predict_from_logits(logits), target)
        batch_loss = self.loss_fn(logits, target).item()
        update_loss_acc_meter(meter, batch_loss, batch_acc, len(data))
        return batch_loss, batch_acc
def whitebox_attack(model, args):
    """Run a white-box attack over the test set and save the adversarial data.

    Builds ``args.nb_restarts`` attacker instances from ``extract_attack(args)``
    and, when there is more than one, wraps them in ``ChooseBestAttack`` so the
    strongest restart is kept per sample.  Prints clean/robust accuracy plus
    detection statistics, then saves the adversarial batch to
    ``<model_path>advdata_<attack>_eps=<eps>_<n>restarts.pt``.
    """
    print("Using a white box attack")
    test_loader = get_test_loader(args.dataset, batch_size=args.batch_size)
    print("Model configuration")
    attack_class, attack_kwargs = extract_attack(args)
    prefix = "%s-%f" % (args.attack, args.eps)
    # One independent attacker per random restart.
    attackers = [
        attack_class(model, **attack_kwargs) for _ in range(args.nb_restarts)
    ]
    if len(attackers) > 1:
        attacker = ChooseBestAttack(model, attackers,
                                    targeted=attackers[0].targeted)
    else:
        attacker = attackers[0]
    adv, label, pred, advpred = attack_whole_dataset(attacker, test_loader)
    print(prefix, 'clean accuracy:', get_accuracy(pred, label))
    print(prefix, 'robust accuracy:', get_accuracy(advpred, label))
    # label.max() + 1 is used as an extra "attack detected" class id —
    # presumably emitted by a detector head; TODO confirm against the model.
    detection_TPR = (advpred == label.max() + 1).float().mean()
    detection_FPR = (pred == label.max() + 1).float().mean()
    # An attack "succeeds" unless the adversarial input is still classified
    # correctly or is flagged as detected.
    print(
        prefix, 'attack success rate:',
        1 - ((advpred == label) |
             (advpred == label.max() + 1)).float().mean())
    print(prefix, 'attack detection TPR:', detection_TPR)
    print(prefix, 'attack detection FPR:', detection_FPR)
    outfile = args.model_path + 'advdata_%s_eps=%f_%drestarts.pt' % (
        args.attack, args.eps, args.nb_restarts)
    torch.save(
        {
            'args': dict(vars(args)),
            'data': adv,
            'preds': advpred,
            'clean_preds': pred,
            'labels': label
        }, outfile)
def transfer_attack(model, args):
    """Evaluate ``model`` on adversarial examples crafted against another model.

    ``args.dataset`` must be a path loadable by ``torch.load`` containing a
    dictionary with:
        data:   (adversarially perturbed) data samples
        preds:  the predictions of the source model on the data
        labels: the true labels of the data
    Prints the target model's accuracy on the transferred examples and its
    agreement with the source model, then saves everything under
    ``logs/transfer_attack_outputs/<target>/<source>.pt``.
    """
    print('Running transfer attack...')
    print('source:', args.dataset)
    print('target:', args.model_path)
    source_data = torch.load(args.dataset)
    loader = DataLoader(source_data['data'],
                        batch_size=args.batch_size,
                        shuffle=False)
    preds = []
    # Inference only — disable autograd so no graph is built per batch.
    with torch.no_grad():
        for x_adv in loader:
            x_adv = x_adv.cuda()
            logits = model(x_adv)
            preds.append(logits.argmax(1))
    preds = torch.cat(preds)
    print('accuracy:', get_accuracy(preds, source_data['labels']))
    print('agreement:', get_accuracy(preds, source_data['preds']))
    outfile = "logs/transfer_attack_outputs/%s/%s.pt" % (os.path.basename(
        args.model_path).split('.')[0], os.path.basename(args.dataset))
    # exist_ok avoids the check-then-create race of exists()+makedirs().
    os.makedirs(os.path.dirname(outfile), exist_ok=True)
    torch.save(
        {
            # NOTE(review): 'sourc_attack_args' looks like a typo but is kept
            # as-is for compatibility with existing readers of these files.
            'sourc_attack_args': source_data['args'],
            'source_adv_data': source_data['data'],
            'source_preds': source_data['preds'],
            'target_preds': preds,
            'labels': source_data['labels']
        }, outfile)
def train_one_epoch(self):
    """Train the model for one epoch over ``self.loader``.

    Updates the clean loss/accuracy meter and the eps meter per batch,
    prints display meters every ``self.disp_interval`` batches, and stops
    early once ``self.steps`` reaches ``self.max_steps``.
    """
    _bgn_epoch = time.time()
    if self.verbose:
        print("Training epoch {}".format(self.epochs))
    self.model.train()
    self.model.to(self.device)
    self.reset_epoch_meters()
    self.reset_disp_meters()
    _train_time = 0.  # cumulative time spent inside train_one_batch only
    for batch_idx, (data, idx) in enumerate(self.loader):
        data, idx = data.to(self.device), idx.to(self.device)
        # Targets are looked up by sample index rather than yielded by the
        # loader — the loader yields (data, idx) pairs.
        target = self.loader.targets[idx]
        _bgn_train = time.time()
        clnoutput, clnloss, eps = self.train_one_batch(data, idx, target)
        _train_time = _train_time + (time.time() - _bgn_train)
        clnacc = get_accuracy(predict_from_logits(clnoutput), target)
        update_loss_acc_meter(self.cln_meter, clnloss.item(), clnacc,
                              len(data))
        # Track the mean per-sample eps produced for this batch.
        update_eps_meter(self.eps_meter, eps.mean().item(), len(data))
        if self.disp_interval is not None and \
                batch_idx % self.disp_interval == 0:
            self.print_disp_meters(batch_idx)
            self.reset_disp_meters()
        # Global step budget reached: stop training mid-epoch.
        if self.steps == self.max_steps:
            self.stop_training()
            break
    self.print_disp_meters(batch_idx)
    self.disp_eps_hist()
    self.epochs += 1
    self._adjust_lr_by_epochs()
    print("total epoch time", time.time() - _bgn_epoch)
    print("training total time", _train_time)
# NOTE(review): this fragment continues an if/elif chain whose head is not
# visible here.  A bare `raise` with no active exception raises
# RuntimeError at runtime — a ValueError with a message would be clearer;
# TODO confirm and fix together with the chain's head.
else:
    raise
# One adversary per random restart; ChooseBestAttack keeps, per sample,
# the restart that attacks most successfully.
base_adversaries = generate_adversaries(attack_class,
                                        args.nb_restart,
                                        predict=model,
                                        eps=args.eps,
                                        nb_iter=args.nb_iter,
                                        eps_iter=args.eps_iter,
                                        rand_init=True)
adversary = ChooseBestAttack(model, base_adversaries)
# Attack every batch in the loader; returns adversarial inputs, true
# labels, clean predictions and adversarial predictions.
adv, label, pred, advpred = attack_whole_dataset(adversary,
                                                 test_loader,
                                                 device=args.device)
print(get_accuracy(advpred, label))  # robust accuracy
print(get_accuracy(advpred, pred))   # agreement of adv and clean preds
# Persist adversarial data and labels next to the model checkpoint.
# NOTE(review): this uses args.model while similar code elsewhere in the
# file uses args.model_path — verify which attribute the argparser defines.
torch.save({"adv": adv},
           os.path.join(os.path.dirname(args.model),
                        "advdata_eps-{}.pt".format(args.eps)))
torch.save({
    "label": label,
    "pred": pred,
    "advpred": advpred
}, os.path.join(os.path.dirname(args.model),
                "advlabel_eps-{}.pt".format(args.eps)))
# Select the attack class matching the requested perturbation norm.
if args.norm == "Linf":
    attack_class = LinfPGDAttack
elif args.norm == "L2":
    attack_class = L2PGDAttack
elif args.norm == "none":
    attack_class = NullAdversary
else:
    # Was a bare `raise` (RuntimeError: no active exception to re-raise);
    # raise a descriptive error instead.
    raise ValueError("unsupported norm: %r" % (args.norm,))
# One adversary per random restart; ChooseBestAttack keeps, per sample,
# the restart that attacks most successfully.
base_adversaries = generate_adversaries(
    attack_class, args.nb_restart, predict=model, eps=args.eps,
    nb_iter=args.nb_iter, eps_iter=args.eps_iter, rand_init=True)
adversary = ChooseBestAttack(model, base_adversaries)
adv, label, pred, advpred = attack_whole_dataset(
    adversary, test_loader, device=args.device)
print('clean accuracy:', get_accuracy(pred, label))
print('robust accuracy:', get_accuracy(advpred, label))
print(get_accuracy(advpred, pred))  # agreement of adv and clean preds
# Persist adversarial data and labels next to the model checkpoint.
torch.save({"adv": adv}, os.path.join(
    os.path.dirname(args.model_path), "advdata_eps-{}.pt".format(args.eps)))
torch.save(
    {"label": label, "pred": pred, "advpred": advpred},
    os.path.join(os.path.dirname(args.model_path),
                 "advlabel_eps-{}.pt".format(args.eps)))
def whitebox_attack(model, args):
    """Run a white-box attack and save the adversarial data to disk.

    Attacks either the first 10k training samples (``args.use_train_data``)
    or the test set.  Builds ``args.nb_restarts`` attacker instances and,
    when there is more than one, wraps them in ``ChooseBestAttack`` so the
    strongest restart is kept per sample.  The attack "budget" reported and
    used in the output filename is ``args.conf`` for cwl2, ``args.eps``
    otherwise.
    """
    print("Using a white box attack")
    if args.use_train_data:
        train_dataset, val_dataset, test_dataset, nclasses = get_cifar10_dataset(
            args.datafolder, [torchvision.transforms.ToTensor()] * 2)
        # NOTE(review): named rand_idx but deterministic — this simply takes
        # the first 10000 samples; confirm whether shuffling was intended.
        rand_idx = np.arange(len(train_dataset))[:10000]
        train_dataset = Subset(train_dataset, rand_idx)
        print(len(train_dataset))
        test_loader = DataLoader(train_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False)
    else:
        test_loader = get_test_loader(args.dataset,
                                      batch_size=args.batch_size)
    print("Model configuration")
    attack_class, attack_kwargs = extract_attack(args)
    # cwl2 is parameterized by confidence, all other attacks by eps;
    # computed once and reused for both the printed prefix and the filename
    # (previously the filename expression was duplicated, with a stale
    # eps-only copy at the top of the function).
    budget = args.conf if args.attack == "cwl2" else args.eps
    prefix = "%s-%f" % (args.attack, budget)
    # One independent attacker per random restart.
    attackers = [
        attack_class(model, **attack_kwargs) for _ in range(args.nb_restarts)
    ]
    if len(attackers) > 1:
        attacker = ChooseBestAttack(model, attackers,
                                    targeted=attackers[0].targeted)
    else:
        attacker = attackers[0]
    adv, label, pred, advpred = attack_whole_dataset(attacker, test_loader)
    print(prefix, 'clean accuracy:', get_accuracy(pred, label))
    print(prefix, 'robust accuracy:', get_accuracy(advpred, label))
    # label.max() + 1 is used as an extra "attack detected" class id —
    # presumably emitted by a detector head; TODO confirm against the model.
    detection_TPR = (advpred == label.max() + 1).float().mean()
    detection_FPR = (pred == label.max() + 1).float().mean()
    # Success = sample was classified correctly before the attack and is
    # misclassified after it.
    print(prefix, 'attack success rate:',
          ((pred == label) & (advpred != label)).float().mean())
    print(prefix, 'attack detection TPR:', detection_TPR)
    print(prefix, 'attack detection FPR:', detection_FPR)
    outfile = args.model_path + 'advdata_%s_eps=%f_%drestarts' % (
        args.attack, budget, args.nb_restarts)
    if args.use_train_data:
        outfile += '_trainset'
    outfile += '.pt'
    torch.save(
        {
            'args': dict(vars(args)),
            'data': adv,
            'preds': advpred,
            'clean_preds': pred,
            'labels': label
        }, outfile)
def train(model,
          train_dataset,
          test_dataset,
          nclasses,
          adversary,
          args,
          val_dataset=None,
          mLogger=None):
    """Adversarially train ``model``, checkpointing on best clean val accuracy.

    When ``val_dataset`` is None, 20% of the training set is split off for
    validation.  Each batch is either Gaussian-smoothed
    (``args.gaussian_smoothing``) or half-replaced with adversarial examples
    from ``adversary``.  The LR scheduler steps on robust validation
    accuracy; the best checkpoint (by clean val accuracy) is reloaded and
    evaluated on the test set at the end.
    """
    print(mLogger)
    # NOTE(review): if mLogger is None, `logger` is never bound here —
    # presumably a module-level `logger` exists; verify, else this raises
    # NameError at the first logger.info call.
    if mLogger is not None:
        logger = mLogger
    if val_dataset is None:
        # Hold out 20% of the training data for validation.
        new_train_size = int(0.8 * len(train_dataset))
        val_size = len(train_dataset) - new_train_size
        train_dataset, val_dataset = random_split(
            train_dataset, [new_train_size, val_size])
    train_loader = DataLoader(train_dataset,
                              args.batch_size,
                              num_workers=(cpu_count()) // 2)
    val_loader = DataLoader(val_dataset,
                            args.batch_size,
                            shuffle=False,
                            num_workers=(cpu_count()) // 2)
    criterion = utils.loss_wrapper(args.C)
    # NOTE(review): independent `if`s with no final else — any other
    # args.optimizer value leaves `optimizer` unbound (NameError below);
    # consider elif + `raise ValueError`.
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(get_trainable_params(model),
                                    lr=args.lr,
                                    weight_decay=5e-4,
                                    momentum=0.9,
                                    nesterov=True)
    if args.optimizer == 'adam':
        optimizer = torch.optim.Adam(get_trainable_params(model),
                                     lr=args.lr,
                                     weight_decay=1e-4)
    # mode='max': higher (robust) validation accuracy is better.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', patience=args.patience, factor=0.2)
    test_loader = DataLoader(test_dataset,
                             args.batch_size,
                             shuffle=False,
                             num_workers=(cpu_count()) // 2)
    # Baseline test accuracy before any training.
    test_label_counts = utils.label_counts(test_loader, nclasses)
    test_correct = evaluate(model, test_loader, test_label_counts)
    test_acc = np.sum(test_correct) / np.sum(test_label_counts)
    print('test_accuracy:', test_acc)
    logger.info('test_accuracy = %0.3f' % test_acc)
    val_label_counts = utils.label_counts(val_loader, nclasses)
    bad_iters = 0  # consecutive epochs without val-accuracy improvement
    for i in range(args.nepochs):
        epoch_loss = 0
        epoch_correct = 0
        epoch_count = 0
        t = tqdm(enumerate(train_loader))
        t.set_description('epoch#%d' % i)
        for j, batch in t:
            x, y = batch
            x = x.cuda()
            y = y.cuda()
            if args.gaussian_smoothing:
                # Randomized-smoothing style training: add Gaussian noise.
                eps = torch.normal(mean=0, std=args.sigma,
                                   size=x.shape).cuda()
                x += eps
            else:
                # Replace a random ~half of the batch with adversarial
                # examples (coin flip per sample).
                flips = np.random.binomial(1, 0.5, size=x.shape[0])
                flips = flips == 1
                x[flips] = adversary.perturb(x[flips], y[flips])
            train_loss, train_correct = train_on_batch(
                model, (x, y), optimizer, criterion)
            epoch_loss += train_loss
            epoch_correct += train_correct
            epoch_count += x.shape[0]
            t.set_postfix(loss=epoch_loss / ((j + 1) * args.batch_size),
                          accuracy=epoch_correct / (epoch_count),
                          lr=optimizer.param_groups[0]['lr'])
        epoch_loss /= len(train_dataset)
        epoch_acc = epoch_correct / len(train_dataset)
        # val_correct = evaluate(model, val_loader, val_label_counts)
        # val_acc = np.mean(val_correct / val_label_counts)
        # print('val_accuracy:', val_acc, )
        # Validation: attack the val set to get both clean and robust acc.
        adv, label, pred, advpred = attack_whole_dataset(
            adversary, val_loader)
        val_acc = get_accuracy(pred, label)
        adv_acc = get_accuracy(advpred, label)
        print('clean val accuracy:', val_acc)
        print('robust val accuracy:', adv_acc)
        # Checkpoint when clean val accuracy beats the scheduler's best.
        if i == 0 or scheduler.is_better(val_acc, scheduler.best):
            with open(args.outfile, 'wb') as f:
                torch.save(model, f)
            bad_iters = 0
        else:
            bad_iters += 1
            # Early stop after 3x the LR-scheduler patience without
            # improvement.
            if bad_iters >= 3 * args.patience:
                print('early stopping...')
                break
        # NOTE: the scheduler steps on *robust* accuracy, while
        # checkpointing/early stopping use *clean* accuracy.
        scheduler.step(adv_acc)
        logger.info(
            'epoch#%d train_loss=%.3f train_acc=%.3f val_acc=%.3f lr=%.4f' %
            (i, epoch_loss, epoch_acc, val_acc,
             optimizer.param_groups[0]['lr']))
    # Reload the best checkpoint and report final clean/robust test accuracy.
    test_loader = DataLoader(test_dataset,
                             args.batch_size,
                             shuffle=False,
                             num_workers=(cpu_count()) // 2)
    model = torch.load(args.outfile)
    test_label_counts = utils.label_counts(test_loader, nclasses)
    test_correct = evaluate(model, test_loader, test_label_counts)
    test_acc = np.sum(test_correct) / np.sum(test_label_counts)
    print('test_accuracy:', test_acc)
    logger.info('test_accuracy = %0.3f' % test_acc)
    adv, label, pred, advpred = attack_whole_dataset(adversary, test_loader)
    test_acc = get_accuracy(pred, label)
    print('clean test accuracy:', test_acc)
    print('robust test accuracy:', get_accuracy(advpred, label))