Example #1
    def predict_then_update_loss_acc_meter(self, meter, data, target):
        with torch.no_grad(), ctx_eval(self.model):
            output = self.model(data)
        acc = get_accuracy(predict_from_logits(output), target)
        loss = self.loss_fn(output, target).item()
        update_loss_acc_meter(meter, loss, acc, len(data))
        return loss, acc
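The method above relies on helpers that are not shown here (ctx_eval, predict_from_logits, get_accuracy). Below is a minimal sketch of what the call sites appear to assume; the names match the usage above, but the bodies are assumptions rather than the repository's actual implementations.

import contextlib

import torch


@contextlib.contextmanager
def ctx_eval(model):
    # Temporarily put the model in eval mode, restoring its previous mode on exit.
    was_training = model.training
    model.eval()
    try:
        yield model
    finally:
        if was_training:
            model.train()


def predict_from_logits(logits):
    # Hard class predictions from a (batch, nclasses) tensor of logits.
    return logits.argmax(dim=1)


def get_accuracy(pred, target):
    # Fraction of positions where two 1-D label tensors agree.
    return (pred == target).float().mean().item()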
Example #2
def whitebox_attack(model, args):
    print("Using a white box attack")
    test_loader = get_test_loader(args.dataset, batch_size=args.batch_size)
    print("Model configuration")

    attack_class, attack_kwargs = extract_attack(args)
    prefix = "%s-%f" % (args.attack, args.eps)
    # attacker = Attacker(model,test_loader, attack_class=attack_class, max_instances=args.max_instances,
    #                     clip_min=0., clip_max=1., targeted=False, binary_classification=args.binary_classification,
    #                     **attack_kwargs)
    # accuracy, confusion_matrix = attacker.eval()
    # print("Accuracy under attack : %f"%accuracy)
    # print('Confusion Matrix:')
    # print(np.diag(confusion_matrix))

    attackers = [
        attack_class(model, **attack_kwargs) for _ in range(args.nb_restarts)
    ]
    if len(attackers) > 1:
        attacker = ChooseBestAttack(model,
                                    attackers,
                                    targeted=attackers[0].targeted)
    else:
        attacker = attackers[0]
    adv, label, pred, advpred = attack_whole_dataset(attacker, test_loader)
    print(prefix, 'clean accuracy:', get_accuracy(pred, label))
    print(prefix, 'robust accuracy:', get_accuracy(advpred, label))
    detection_TPR = (advpred == label.max() + 1).float().mean()
    detection_FPR = (pred == label.max() + 1).float().mean()
    print(
        prefix, 'attack success rate:',
        1 - ((advpred == label) | (advpred == label.max() + 1)).float().mean())
    print(prefix, 'attack detection TPR:', detection_TPR)
    print(prefix, 'attack detection FPR:', detection_FPR)

    outfile = args.model_path + 'advdata_%s_eps=%f_%drestarts.pt' % (
        args.attack, args.eps, args.nb_restarts)
    torch.save(
        {
            'args': dict(vars(args)),
            'data': adv,
            'preds': advpred,
            'clean_preds': pred,
            'labels': label
        }, outfile)

def transfer_attack(model, args):
    # args.dataset must be the path to a file loadable by torch.load that contains a dictionary:
    # {
    #   data: (adversarially perturbed) data samples,
    #   preds: the predictions of the source model on the data
    #   labels: the true labels of the data
    # }
    print('Running transfer attack...')
    print('source:', args.dataset)
    print('target:', args.model_path)

    source_data = torch.load(args.dataset)
    loader = DataLoader(source_data['data'],
                        batch_size=args.batch_size,
                        shuffle=False)

    preds = []
    for x_adv in loader:
        x_adv = x_adv.cuda()
        logits = model(x_adv)
        preds.append(logits.argmax(1))
    preds = torch.cat(preds)

    print('accuracy:', get_accuracy(preds, source_data['labels']))
    print('agreement:', get_accuracy(preds, source_data['preds']))

    outfile = "logs/transfer_attack_outputs/%s/%s.pt" % (os.path.basename(
        args.model_path).split('.')[0], os.path.basename(args.dataset))
    os.makedirs(os.path.dirname(outfile), exist_ok=True)

    torch.save(
        {
            'source_attack_args': source_data['args'],
            'source_adv_data': source_data['data'],
            'source_preds': source_data['preds'],
            'target_preds': preds,
            'labels': source_data['labels']
        }, outfile)
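Both functions above depend on attack_whole_dataset, which is not shown in these examples. The following is a rough sketch of the contract the call sites rely on, assuming an advertorch-style attack object exposing perturb and predict; the real helper may differ in details such as device handling.

import torch


def attack_whole_dataset(adversary, loader, device="cuda"):
    # Run the adversary over every batch and return stacked adversarial inputs,
    # true labels, clean predictions, and predictions on the adversarial inputs.
    advs, labels, preds, advpreds = [], [], [], []
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        x_adv = adversary.perturb(x, y)
        with torch.no_grad():
            pred = adversary.predict(x).argmax(1)
            advpred = adversary.predict(x_adv).argmax(1)
        advs.append(x_adv.detach().cpu())
        labels.append(y.cpu())
        preds.append(pred.cpu())
        advpreds.append(advpred.cpu())
    return torch.cat(advs), torch.cat(labels), torch.cat(preds), torch.cat(advpreds)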
Example #4
    def train_one_epoch(self):
        _bgn_epoch = time.time()
        if self.verbose:
            print("Training epoch {}".format(self.epochs))
        self.model.train()
        self.model.to(self.device)
        self.reset_epoch_meters()
        self.reset_disp_meters()

        _train_time = 0.

        for batch_idx, (data, idx) in enumerate(self.loader):
            data, idx = data.to(self.device), idx.to(self.device)
            target = self.loader.targets[idx]

            _bgn_train = time.time()
            clnoutput, clnloss, eps = self.train_one_batch(data, idx, target)
            _train_time = _train_time + (time.time() - _bgn_train)

            clnacc = get_accuracy(predict_from_logits(clnoutput), target)
            update_loss_acc_meter(self.cln_meter, clnloss.item(), clnacc,
                                  len(data))
            update_eps_meter(self.eps_meter, eps.mean().item(), len(data))

            if self.disp_interval is not None and \
                    batch_idx % self.disp_interval == 0:
                self.print_disp_meters(batch_idx)
                self.reset_disp_meters()

            if self.steps == self.max_steps:
                self.stop_training()
                break

        self.print_disp_meters(batch_idx)
        self.disp_eps_hist()
        self.epochs += 1
        self._adjust_lr_by_epochs()

        print("total epoch time", time.time() - _bgn_epoch)
        print("training total time", _train_time)
    else:
        raise ValueError("unrecognized attack configuration")

    base_adversaries = generate_adversaries(attack_class,
                                            args.nb_restart,
                                            predict=model,
                                            eps=args.eps,
                                            nb_iter=args.nb_iter,
                                            eps_iter=args.eps_iter,
                                            rand_init=True)

    adversary = ChooseBestAttack(model, base_adversaries)

    adv, label, pred, advpred = attack_whole_dataset(adversary,
                                                     test_loader,
                                                     device=args.device)

    print('robust accuracy:', get_accuracy(advpred, label))
    print('agreement (adv preds vs clean preds):', get_accuracy(advpred, pred))

    torch.save({"adv": adv},
               os.path.join(os.path.dirname(args.model),
                            "advdata_eps-{}.pt".format(args.eps)))
    torch.save({
        "label": label,
        "pred": pred,
        "advpred": advpred
    },
               os.path.join(os.path.dirname(args.model),
                            "advlabel_eps-{}.pt".format(args.eps)))
Example #6
    if args.norm == "Linf":
        attack_class = LinfPGDAttack
    elif args.norm == "L2":
        attack_class = L2PGDAttack
    elif args.norm == "none":
        attack_class = NullAdversary
    else:
        raise ValueError("unknown norm: %s" % args.norm)

    base_adversaries = generate_adversaries(
        attack_class, args.nb_restart, predict=model, eps=args.eps,
        nb_iter=args.nb_iter, eps_iter=args.eps_iter, rand_init=True)

    adversary = ChooseBestAttack(model, base_adversaries)


    adv, label, pred, advpred = attack_whole_dataset(
        adversary, test_loader, device=args.device)

    print('clean accuracy:', get_accuracy(pred, label))
    print('robust accuracy:', get_accuracy(advpred, label))
    print('agreement (adv preds vs clean preds):', get_accuracy(advpred, pred))

    torch.save({"adv": adv}, os.path.join(
        os.path.dirname(args.model_path), "advdata_eps-{}.pt".format(args.eps)))
    torch.save(
        {"label": label, "pred": pred, "advpred": advpred},
        os.path.join(os.path.dirname(args.model_path),
                     "advlabel_eps-{}.pt".format(args.eps)))
def whitebox_attack(model, args):
    outfile = args.model_path + 'advdata_%s_eps=%f_%drestarts.pt' % (
        args.attack, args.eps, args.nb_restarts)
    # if os.path.exists(outfile):
    #     return

    print("Using a white box attack")
    if args.use_train_data:
        train_dataset, val_dataset, test_dataset, nclasses = get_cifar10_dataset(
            args.datafolder, [torchvision.transforms.ToTensor()] * 2)
        idx = np.arange(len(train_dataset))[:10000]  # first 10,000 training samples (no shuffling)
        train_dataset = Subset(train_dataset, idx)
        print(len(train_dataset))
        test_loader = DataLoader(train_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False)
    else:
        test_loader = get_test_loader(args.dataset, batch_size=args.batch_size)
    print("Model configuration")

    attack_class, attack_kwargs = extract_attack(args)
    prefix = "%s-%f" % (args.attack,
                        args.conf if args.attack == "cwl2" else args.eps)
    # attacker = Attacker(model,test_loader, attack_class=attack_class, max_instances=args.max_instances,
    #                     clip_min=0., clip_max=1., targeted=False, binary_classification=args.binary_classification,
    #                     **attack_kwargs)
    # accuracy, confusion_matrix = attacker.eval()
    # print("Accuracy under attack : %f"%accuracy)
    # print('Confusion Matrix:')
    # print(np.diag(confusion_matrix))

    attackers = [
        attack_class(model, **attack_kwargs) for _ in range(args.nb_restarts)
    ]
    if len(attackers) > 1:
        attacker = ChooseBestAttack(model,
                                    attackers,
                                    targeted=attackers[0].targeted)
    else:
        attacker = attackers[0]
    adv, label, pred, advpred = attack_whole_dataset(attacker, test_loader)
    print(prefix, 'clean accuracy:', get_accuracy(pred, label))
    print(prefix, 'robust accuracy:', get_accuracy(advpred, label))
    detection_TPR = (advpred == label.max() + 1).float().mean()
    detection_FPR = (pred == label.max() + 1).float().mean()
    print(prefix, 'attack success rate:',
          ((pred == label) & (advpred != label)).float().mean())
    print(prefix, 'attack detection TPR:', detection_TPR)
    print(prefix, 'attack detection FPR:', detection_FPR)

    outfile = args.model_path + 'advdata_%s_eps=%f_%drestarts' % (
        args.attack, args.conf if args.attack == "cwl2" else args.eps,
        args.nb_restarts)
    if args.use_train_data:
        outfile += '_trainset'
    outfile += '.pt'
    torch.save(
        {
            'args': dict(vars(args)),
            'data': adv,
            'preds': advpred,
            'clean_preds': pred,
            'labels': label
        }, outfile)
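extract_attack maps args.attack to an attack class plus its keyword arguments, but its body is not part of these examples. Below is a hypothetical sketch using advertorch classes; the option string "pgd", the hard-coded num_classes=10 (CIFAR-10), and the exact forwarded fields are assumptions, although "cwl2", args.eps, args.conf, args.nb_iter, and args.eps_iter do appear in the code above.

from advertorch.attacks import CarliniWagnerL2Attack, LinfPGDAttack


def extract_attack(args):
    # Map the command-line attack name to (attack_class, attack_kwargs).
    if args.attack == "pgd":
        return LinfPGDAttack, dict(eps=args.eps, nb_iter=args.nb_iter,
                                   eps_iter=args.eps_iter, rand_init=True)
    if args.attack == "cwl2":
        return CarliniWagnerL2Attack, dict(num_classes=10, confidence=args.conf)
    raise ValueError("unknown attack: %s" % args.attack)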
Example #8
def train(model,
          train_dataset,
          test_dataset,
          nclasses,
          adversary,
          args,
          val_dataset=None,
          mLogger=None):
    # Fall back to the module-level logger when none is passed in, so the
    # logger.info(...) calls below never hit an undefined name (assumes `import logging` at module level).
    logger = mLogger if mLogger is not None else logging.getLogger(__name__)
    if val_dataset is None:
        new_train_size = int(0.8 * len(train_dataset))
        val_size = len(train_dataset) - new_train_size
        train_dataset, val_dataset = random_split(train_dataset,
                                                  [new_train_size, val_size])

    train_loader = DataLoader(train_dataset,
                              args.batch_size,
                              num_workers=(cpu_count()) // 2)
    val_loader = DataLoader(val_dataset,
                            args.batch_size,
                            shuffle=False,
                            num_workers=(cpu_count()) // 2)

    criterion = utils.loss_wrapper(args.C)

    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(get_trainable_params(model),
                                    lr=args.lr,
                                    weight_decay=5e-4,
                                    momentum=0.9,
                                    nesterov=True)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(get_trainable_params(model),
                                     lr=args.lr,
                                     weight_decay=1e-4)
    else:
        raise ValueError("unknown optimizer: %s" % args.optimizer)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', patience=args.patience, factor=0.2)

    test_loader = DataLoader(test_dataset,
                             args.batch_size,
                             shuffle=False,
                             num_workers=(cpu_count()) // 2)
    test_label_counts = utils.label_counts(test_loader, nclasses)
    test_correct = evaluate(model, test_loader, test_label_counts)
    test_acc = np.sum(test_correct) / np.sum(test_label_counts)
    print('test_accuracy:', test_acc)
    logger.info('test_accuracy = %0.3f' % test_acc)

    val_label_counts = utils.label_counts(val_loader, nclasses)
    bad_iters = 0
    for i in range(args.nepochs):
        epoch_loss = 0
        epoch_correct = 0
        epoch_count = 0
        t = tqdm(enumerate(train_loader))
        t.set_description('epoch#%d' % i)
        for j, batch in t:
            x, y = batch
            x = x.cuda()
            y = y.cuda()

            if args.gaussian_smoothing:
                eps = torch.normal(mean=0, std=args.sigma, size=x.shape).cuda()
                x += eps
            else:
                flips = np.random.binomial(1, 0.5, size=x.shape[0])
                flips = flips == 1
                x[flips] = adversary.perturb(x[flips], y[flips])

            train_loss, train_correct = train_on_batch(model, (x, y),
                                                       optimizer, criterion)
            epoch_loss += train_loss
            epoch_correct += train_correct
            epoch_count += x.shape[0]
            t.set_postfix(loss=epoch_loss / ((j + 1) * args.batch_size),
                          accuracy=epoch_correct / (epoch_count),
                          lr=optimizer.param_groups[0]['lr'])
        epoch_loss /= len(train_dataset)
        epoch_acc = epoch_correct / len(train_dataset)

        # val_correct = evaluate(model, val_loader, val_label_counts)
        # val_acc = np.mean(val_correct / val_label_counts)
        # print('val_accuracy:', val_acc, )

        adv, label, pred, advpred = attack_whole_dataset(adversary, val_loader)
        val_acc = get_accuracy(pred, label)
        adv_acc = get_accuracy(advpred, label)
        print('clean val accuracy:', val_acc)
        print('robust val accuracy:', adv_acc)

        if i == 0 or scheduler.is_better(val_acc, scheduler.best):
            with open(args.outfile, 'wb') as f:
                torch.save(model, f)
            bad_iters = 0
        else:
            bad_iters += 1
        if bad_iters >= 3 * args.patience:
            print('early stopping...')
            break
        scheduler.step(adv_acc)

        logger.info(
            'epoch#%d train_loss=%.3f train_acc=%.3f val_acc=%.3f lr=%.4f' %
            (i, epoch_loss, epoch_acc, val_acc,
             optimizer.param_groups[0]['lr']))

    test_loader = DataLoader(test_dataset,
                             args.batch_size,
                             shuffle=False,
                             num_workers=(cpu_count()) // 2)
    model = torch.load(args.outfile)
    test_label_counts = utils.label_counts(test_loader, nclasses)
    test_correct = evaluate(model, test_loader, test_label_counts)
    test_acc = np.sum(test_correct) / np.sum(test_label_counts)
    print('test_accuracy:', test_acc)
    logger.info('test_accuracy = %0.3f' % test_acc)

    adv, label, pred, advpred = attack_whole_dataset(adversary, test_loader)
    test_acc = get_accuracy(pred, label)
    print('clean test accuracy:', test_acc)
    print('robust test accuracy:', get_accuracy(advpred, label))
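The training loop in this example also depends on train_on_batch and evaluate, which are not shown. Below is a rough sketch consistent with how their return values are used above: train_on_batch returns the summed batch loss and the number of correct predictions, and evaluate returns per-class correct counts so that np.sum(correct) / np.sum(label_counts) gives overall accuracy. Both bodies are assumptions, not the repository's actual code.

import numpy as np
import torch


def train_on_batch(model, batch, optimizer, criterion):
    # One optimization step; returns (summed batch loss, number correct).
    model.train()
    x, y = batch
    optimizer.zero_grad()
    logits = model(x)
    loss = criterion(logits, y)
    loss.backward()
    optimizer.step()
    correct = (logits.argmax(1) == y).sum().item()
    return loss.item() * len(x), correct


def evaluate(model, loader, label_counts):
    # Per-class correct-prediction counts over the whole loader.
    model.eval()
    correct = np.zeros_like(label_counts, dtype=np.int64)
    with torch.no_grad():
        for x, y in loader:
            x, y = x.cuda(), y.cuda()
            ok = (model(x).argmax(1) == y).cpu().numpy()
            for cls, hit in zip(y.cpu().numpy(), ok):
                correct[cls] += int(hit)
    return correct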