Example #1
0
def whitebox_attack(model, args):
    """Run a white-box attack over the test set and save the results.

    Builds ``args.nb_restarts`` attacker instances, selects the strongest
    attack per example when there is more than one restart, evaluates on
    the full test set, prints accuracy/detection statistics, and dumps the
    adversarial data to ``args.model_path``-prefixed file via torch.save.
    """
    print("Using a white box attack")
    test_loader = get_test_loader(args.dataset, batch_size=args.batch_size)
    print("Model configuration")

    attack_class, attack_kwargs = extract_attack(args)
    prefix = "%s-%f" % (args.attack, args.eps)

    # One attacker per random restart.
    restarts = [attack_class(model, **attack_kwargs)
                for _ in range(args.nb_restarts)]
    attacker = (ChooseBestAttack(model, restarts,
                                 targeted=restarts[0].targeted)
                if len(restarts) > 1 else restarts[0])

    adv, label, pred, advpred = attack_whole_dataset(attacker, test_loader)
    print(prefix, 'clean accuracy:', get_accuracy(pred, label))
    print(prefix, 'robust accuracy:', get_accuracy(advpred, label))
    # A prediction of label.max() + 1 plays the role of an "attack detected"
    # class: TPR counts detections on adversarial inputs, FPR on clean ones.
    detect_class = label.max() + 1
    detection_TPR = (advpred == detect_class).float().mean()
    detection_FPR = (pred == detect_class).float().mean()
    print(
        prefix, 'attack success rate:',
        1 - ((advpred == label) | (advpred == detect_class)).float().mean())
    print(prefix, 'attack detection TPR:', detection_TPR)
    print(prefix, 'attack detection FPR:', detection_FPR)

    outfile = args.model_path + 'advdata_%s_eps=%f_%drestarts.pt' % (
        args.attack, args.eps, args.nb_restarts)
    torch.save(
        {
            'args': dict(vars(args)),
            'data': adv,
            'preds': advpred,
            'clean_preds': pred,
            'labels': label
        }, outfile)
Example #2
0
def get_data_loaders(cfg):
    """Build train/val/test data loaders for the dataset named in ``cfg``.

    MNIST gets no training-time augmentation; CIFAR10 uses the Madry et al.
    training transform. Any other dataset name raises ``ValueError``.
    """
    name = cfg.dataset.upper()
    if name == "MNIST":
        train_transform = None
    elif name == "CIFAR10":
        train_transform = get_madry_et_al_cifar10_train_transform()
    else:
        raise ValueError(cfg.dataset)

    train_loader, val_loader = get_train_val_loaders(
        name,
        train_size=cfg.train_size,
        val_size=cfg.val_size,
        train_batch_size=cfg.training_batch_size,
        val_batch_size=100,
        train_transform=train_transform,
    )
    test_loader = get_test_loader(name,
                                  test_size=cfg.test_size,
                                  batch_size=100)

    return train_loader, val_loader, test_loader
    # NOTE(review): fragment of a larger function -- its `def` line and the
    # tail of the `generate_adversaries(...)` call fall outside this excerpt.
    model.eval()

    print("model loaded")

    # Epsilons given on the [0, 255] pixel scale (value > 1) for CIFAR10/Linf
    # are rescaled to the [0, 1] scale used by the attack.
    if args.dataset.upper() == "CIFAR10" and args.norm == "Linf" \
            and args.eps > 1.:
        args.eps = round(args.eps / 255., 4)

    # Default PGD step size: eps/40 for MNIST under Linf, eps/4 otherwise.
    if args.eps_iter is None:
        if args.dataset.upper() == "MNIST" and args.norm == "Linf":
            args.eps_iter = args.eps / 40.
        else:
            args.eps_iter = args.eps / 4.

    test_loader = get_test_loader(args.dataset.upper(),
                                  test_size=args.test_size,
                                  batch_size=100)

    # Map the requested norm to a PGD attack class.
    if args.norm == "Linf":
        attack_class = LinfPGDAttack
    elif args.norm == "L2":
        attack_class = L2PGDAttack
    else:
        # Bare raise with no active exception -> RuntimeError on unknown norm.
        raise

    # One adversary per restart; call is truncated at the excerpt boundary.
    base_adversaries = generate_adversaries(attack_class,
                                            args.nb_restart,
                                            predict=model,
                                            eps=args.eps,
                                            nb_iter=args.nb_iter,
                                            eps_iter=args.eps_iter,
Example #4
0
    # NOTE(review): fragment of a larger function -- its `def` line falls
    # outside this excerpt.
    model.eval()

    print("model loaded")

    # Epsilons given on the [0, 255] pixel scale (value > 1) for CIFAR10/Linf
    # are rescaled to the [0, 1] scale used by the attack.
    if args.dataset.upper() == "CIFAR10" and args.norm == "Linf" \
            and args.eps > 1.:
        args.eps = round(args.eps / 255., 4)


    # Default PGD step size: eps/40 for MNIST under Linf, eps/4 otherwise.
    if args.eps_iter is None:
        if args.dataset.upper() == "MNIST" and args.norm == "Linf":
            args.eps_iter = args.eps / 40.
        else:
            args.eps_iter = args.eps / 4.

    test_loader = get_test_loader(args.dataset, batch_size=128) # model.test_dataloader()

    # Map the requested norm to an attack class; "none" uses a no-op adversary.
    if args.norm == "Linf":
        attack_class = LinfPGDAttack
    elif args.norm == "L2":
        attack_class = L2PGDAttack
    elif args.norm == "none":
        attack_class = NullAdversary
    else:
        # Bare raise with no active exception -> RuntimeError on unknown norm.
        raise
    
    # One adversary per restart; ChooseBestAttack aggregates them.
    base_adversaries = generate_adversaries(
        attack_class, args.nb_restart, predict=model, eps=args.eps,
        nb_iter=args.nb_iter, eps_iter=args.eps_iter, rand_init=True)

    adversary = ChooseBestAttack(model, base_adversaries)
def whitebox_attack(model, args):
    """Run a white-box attack and save the adversarial data to disk.

    Variant that can attack a slice of the CIFAR10 training set instead of
    the test set when ``args.use_train_data`` is set, and that keys output
    file names and the print prefix on ``args.conf`` (instead of
    ``args.eps``) for the "cwl2" attack.
    """
    outfile = args.model_path + 'advdata_%s_eps=%f_%drestarts.pt' % (
        args.attack, args.eps, args.nb_restarts)
    # if os.path.exists(outfile):
    #     return

    print("Using a white box attack")
    if args.use_train_data:
        # Attack the first 10k training examples instead of the test set.
        # NOTE(review): despite the name, `rand_idx` is NOT shuffled -- it is
        # simply the first 10000 indices in order.
        train_dataset, val_dataset, test_dataset, nclasses = get_cifar10_dataset(
            args.datafolder, [torchvision.transforms.ToTensor()] * 2)
        rand_idx = np.arange(len(train_dataset))[:10000]
        train_dataset = Subset(train_dataset, rand_idx)
        print(len(train_dataset))
        test_loader = DataLoader(train_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False)
    else:
        test_loader = get_test_loader(args.dataset, batch_size=args.batch_size)
    print("Model configuration")

    attack_class, attack_kwargs = extract_attack(args)
    # CW-L2 runs are identified by confidence, all others by epsilon.
    prefix = "%s-%f" % (args.attack,
                        args.conf if args.attack == "cwl2" else args.eps)
    # attacker = Attacker(model,test_loader, attack_class=attack_class, max_instances=args.max_instances,
    #                     clip_min=0., clip_max=1., targeted=False, binary_classification=args.binary_classification,
    #                     **attack_kwargs)
    # accuracy, confusion_matrix = attacker.eval()
    # print("Accuracy under attack : %f"%accuracy)
    # print('Confusion Matrix:')
    # print(np.diag(confusion_matrix))

    # One attacker per random restart; with several restarts, the strongest
    # attack per example is kept.
    attackers = [
        attack_class(model, **attack_kwargs) for i in range(args.nb_restarts)
    ]
    if len(attackers) > 1:
        attacker = ChooseBestAttack(model,
                                    attackers,
                                    targeted=attackers[0].targeted)
    else:
        attacker = attackers[0]
    adv, label, pred, advpred = attack_whole_dataset(attacker, test_loader)
    print(prefix, 'clean accuracy:', get_accuracy(pred, label))
    print(prefix, 'robust accuracy:', get_accuracy(advpred, label))
    # A prediction of label.max() + 1 plays the role of an "attack detected"
    # class: TPR counts detections on adversarial inputs, FPR on clean ones.
    detection_TPR = (advpred == label.max() + 1).float().mean()
    detection_FPR = (pred == label.max() + 1).float().mean()
    # Success = clean prediction was correct but adversarial prediction is not.
    print(prefix, 'attack success rate:',
          ((pred == label) & (advpred != label)).float().mean())
    print(prefix, 'attack detection TPR:', detection_TPR)
    print(prefix, 'attack detection FPR:', detection_FPR)

    outfile = args.model_path + 'advdata_%s_eps=%f_%drestarts' % (
        args.attack, args.conf if args.attack == "cwl2" else args.eps,
        args.nb_restarts)
    if args.use_train_data:
        outfile += '_trainset'
    outfile += '.pt'
    torch.save(
        {
            'args': dict(vars(args)),
            'data': adv,
            'preds': advpred,
            'clean_preds': pred,
            'labels': label
        }, outfile)
def main(args):
    """Evaluate an adversarial-denoising defense on precomputed attack data.

    Loads a saved model, optionally splits it at ``args.layer_idx`` into a
    feature extractor plus a head, wraps the head with an ensemble of
    ANPGDWrapper adversaries (one per epsilon in ``args.eps``, replicated
    ``args.nb_adversaries`` times), then runs both the saved adversarial
    examples (``args.data_file``) and clean test data through the wrapped
    model and pickles the collected predictions, logits and deltas.
    """
    model = torch.load(args.model, map_location=torch.device('cpu'))
    # Split the network: fe_model covers layers [0, layer_idx]; the remaining
    # layers become the model under evaluation (inputs are then features).
    if args.layer_idx is not None and not hasattr(model, 'layers'):
        model = get_layers(model, args.normalize_input)
        fe_model = model[:args.layer_idx+1].to(args.device).eval()
        model = model[args.layer_idx+1:]
        print(fe_model)
        print(model)
    elif args.normalize_input:
        model = NormalizationWrapper(model)

    # Build one adversary per epsilon; args.eps/args.eps_iter are temporarily
    # overwritten per iteration so extract_attack sees scalar values.
    args.attack = args.denoising_attack    
    adversaries = []
    eps = args.eps
    eps_iter = args.eps_iter
    for e in eps:            
        args.eps = e
        if eps_iter is None:
            # Default step size: one tenth of the epsilon.
            args.eps_iter = e / 10
        else:
            args.eps_iter = eps_iter
        attack_class,attack_kwargs = extract_attack(args)
        if args.uniform_anpgd_target:
            attack_kwargs['loss_fn'] = uniform_distribution_loss
        else:
            attack_kwargs['loss_fn'] = torch.nn.CrossEntropyLoss(reduction='sum')
        if args.layer_idx is not None:
            # Feature-space inputs are unbounded, so disable pixel clipping.
            attack_kwargs['clip_min'] = -np.inf
            attack_kwargs['clip_max'] = np.inf
        adversaries.append(attack_class(model, targeted=args.target_closest_boundary, **attack_kwargs))    

    # NOTE(review): max_eps is hard-coded; 20 is presumably an iteration
    # count for ANPGDWrapper -- confirm against its definition.
    max_eps = 0.08
    adversaries = [ANPGDWrapper(attack, max_eps, 20) for attack in adversaries]
    args.eps = eps  # restore the original epsilon list
    adversaries = adversaries*args.nb_adversaries

    # print(vars(adversaries[0]))    
    model = AdversarialDenoisingWrapper(model, 
                                        adversaries, 
                                        args.consensus_pc,
                                        args.target_closest_boundary)
    model = model.to(args.device)
    model = model.eval()
    
    # Load the saved attack outputs; stacking on dim=1 gives one slot per
    # data file (i.e. per attack) for every instance.
    data = [torch.load(df) for df in args.data_file]
    adv_data = [d['data'] for d in data]    
    adv_data = torch.stack(adv_data, dim=1).to(torch.device('cpu'))
    adv_preds = [d['preds'] for d in data]    
    adv_preds = torch.stack(adv_preds, dim=1).to(torch.device('cpu'))
    cln_preds = [d['clean_preds'] for d in data]
    cln_preds = torch.stack(cln_preds, dim=1).to(torch.device('cpu'))
    labels = [d['labels'] for d in data]    
    labels = torch.stack(labels, dim=1).to(torch.device('cpu'))    
    adv_loader = torch.utils.data.DataLoader(list(zip(adv_data, adv_preds, cln_preds, labels)), batch_size=args.batch_size, shuffle=False)
    clean_loader = get_test_loader(args.dataset.upper(), batch_size=args.batch_size)
    
    # adv, label, pred, advpred = attack_whole_dataset(attack_class(model.model, **attack_kwargs), clean_loader)
    # print('clean accuracy:',(pred == label).float().mean())
    # print('robust accuracy:',(advpred == label).float().mean())
    
    margins = []
    adv_logits = []
    cln_logits = []
    adv_deltas = []
    cln_deltas = []
    adv_preds = []
    cln_preds = []
    norm_inf_delta = []
    norm_2_delta = []
    labels = []
    
    # Iterate adversarial and clean loaders in lockstep; both must present
    # the instances in the same (unshuffled) order.
    t = tqdm(zip(adv_loader, clean_loader))
    for (xadv, advpred, clnpred, L), (x, y) in t:        
        # Sanity check: saved labels must match the clean loader's labels.
        if not (L == y.view(-1,1)).all():
            print('label mismatch')
            print(L)
            print(y)
            exit(0)
        # Keep only instances where the clean prediction was correct AND
        # every saved attack succeeded; skip batches with no such instance.
        if ((clnpred == y.view(-1,1)) & (advpred != y.view(-1,1))).all(1).any():
            idx_mask = ((clnpred == y.view(-1,1)) & (advpred != y.view(-1,1))).all(1)           
            xadv = xadv[idx_mask]
            x = x[idx_mask]
            y = y[idx_mask]
        else:
            print('skipping instance...')
            continue
        xadv = xadv.to(args.device)
        x = x.to(args.device)
        y = y.to(args.device)

        # Fold the per-attack dimension into the batch dimension for the
        # forward pass; unfolded again below.
        xadv_shape = xadv.shape
        xadv = xadv.view(xadv.shape[0]*xadv.shape[1], *(xadv.shape[2:]))
        if args.layer_idx is not None:
            # Map inputs into feature space before the denoising wrapper.
            x = fe_model(x)
            xadv = fe_model(xadv)

        adv_votes, preds = model(xadv) 
        adv_logit, adv_pred = torch.max(adv_votes, dim=1)
        # Restore the (batch, attack, ...) layout.
        adv_logit = adv_logit.view(xadv_shape[0],xadv_shape[1], *(adv_logit.shape[1:]))
        adv_pred = adv_pred.view(xadv_shape[0],xadv_shape[1], *(adv_pred.shape[1:]))

        # model.deltas holds the denoising perturbations from the last call.
        d = np.stack(model.deltas, axis=1)
        d = d.reshape(xadv_shape[0],xadv_shape[1], *(d.shape[1:]))
        adv_deltas.append(d)
        
        # Same forward pass on the clean inputs.
        adv_votes, preds = model(x)
        cln_logit, cln_pred = torch.max(adv_votes, dim=1)
        cln_deltas.append(np.stack(model.deltas, axis=1))
        adv_logits.append(adv_logit.detach().cpu().numpy())
        cln_logits.append(cln_logit.detach().cpu().numpy())
        adv_preds.append(adv_pred.detach().cpu().numpy())
        cln_preds.append(cln_pred.detach().cpu().numpy())
        labels.append(y.cpu().numpy())

        # Cap the evaluation at roughly 2500 instances.
        if len(labels) >= 2500/args.batch_size:
            break

    adv_preds = np.concatenate(adv_preds, axis=0)
    cln_preds = np.concatenate(cln_preds, axis=0)
    adv_deltas = np.concatenate(adv_deltas, axis=0)
    cln_deltas = np.concatenate(cln_deltas, axis=0)
    labels = np.concatenate(labels, axis=0)
    adv_logits = np.concatenate(adv_logits, axis=0)
    cln_logits = np.concatenate(cln_logits, axis=0)  

    # Output filename encodes the full evaluation configuration.
    outfile = '%s_advDenoising_%d-%s-eps=%s_consesus-pc=%.1f' % (args.model, len(adversaries), args.attack, '-'.join(["%.4f" % e for e in args.eps]), args.consensus_pc)
    if args.target_closest_boundary:
        outfile += '_targeted'
    if args.uniform_anpgd_target:
        outfile += '_uniformTargets'
    if args.layer_idx is not None:
        outfile += '_layer-%d' % args.layer_idx
    outfile += '_data.pkl'
    with open(outfile, 'wb') as f:
        pickle.dump([adv_preds,
                        cln_preds,
                        adv_deltas,
                        cln_deltas,
                        labels,
                        adv_logits,
                        cln_logits], f)

    print(adv_preds.shape, adv_logits.shape, adv_deltas.shape, cln_deltas.shape)
    exit()