コード例 #1
0
def main():
    """Craft untargeted L-inf adversarial examples for the SubsetImageNet images.

    Configuration is read from the module-level ``args``; ``kwargs`` is forwarded
    to the DataLoader and ``device`` selects where the model runs. The classifier
    is an ImageNet-pretrained ``pretrainedmodels`` network with its input
    normalization prepended, so the attack works directly on the [0, 1] tensors
    produced by ``ToTensor``. When ``args.gamma < 1.0`` the Skip Gradient Method
    (SGM) hooks are installed first. ``args.momentum > 0`` selects a momentum
    iterative attack; otherwise plain L-inf PGD without random start is used.

    Raises:
        ValueError: if SGM is requested for an architecture other than the
            supported ResNet / DenseNet variants.
    """
    # Images stay in [0, 1]; mean/std normalization lives inside the model.
    dataset = SubsetImageNet(root=args.input_dir,
                             transform=transforms.Compose([transforms.ToTensor()]))
    loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size,
                                         shuffle=False, **kwargs)

    backbone = pretrainedmodels.__dict__[args.arch](num_classes=1000, pretrained='imagenet')
    model = nn.Sequential(Normalize(mean=backbone.mean, std=backbone.std), backbone).to(device)
    model.eval()

    # Budgets are given on the 0-255 scale; a negative step size means "eps / steps".
    epsilon = args.epsilon / 255.0
    step_size = (epsilon / args.num_steps) if args.step_size < 0 else (args.step_size / 255.0)

    # Skip Gradient Method (SGM): only ResNet / DenseNet hooks are implemented.
    if args.gamma < 1.0:
        resnets = ('resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152')
        densenets = ('densenet121', 'densenet169', 'densenet201')
        if args.arch in resnets:
            register_hook_for_resnet(model, arch=args.arch, gamma=args.gamma)
        elif args.arch in densenets:
            register_hook_for_densenet(model, arch=args.arch, gamma=args.gamma)
        else:
            raise ValueError('Current code only supports resnet/densenet. '
                             'You can extend this code to other architectures.')

    # Settings shared by both attack flavours.
    common = dict(predict=model,
                  loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                  eps=epsilon,
                  nb_iter=args.num_steps,
                  eps_iter=step_size,
                  clip_min=0.0,
                  clip_max=1.0,
                  targeted=False)
    if args.momentum > 0.0:
        print('using PGD attack with momentum = {}'.format(args.momentum))
        adversary = MomentumIterativeAttack(decay_factor=args.momentum, **common)
    else:
        print('using linf PGD attack')
        adversary = LinfPGDAttack(rand_init=False, **common)

    generate_adversarial_example(model=model, data_loader=loader,
                                 adversary=adversary, img_path=dataset.img_path)
コード例 #2
0
def generate(datasetname, batch_size):
    """Generate and save adversarial training data for the guided-denoiser defense.

    For every standard model listed for ``datasetname`` (training models first,
    then test models) four attackers are created: L-inf PGD, L2 PGD, FGSM and a
    momentum iterative attack. All of them are passed, together with a loader
    over the training split, to ``generate_and_save_adv_examples``. Output and
    the log file go to ``PY_ROOT/data_adv_defense/guided_denoiser``.

    Args:
        datasetname: key into ``MODELS_TRAIN_STANDARD`` / ``MODELS_TEST_STANDARD``.
        batch_size: batch size of the training-split data loader.
    """
    save_dir_path = "{}/data_adv_defense/guided_denoiser".format(PY_ROOT)
    os.makedirs(save_dir_path, exist_ok=True)
    set_log_file("{}/generate_{}.log".format(save_dir_path, datasetname))

    data_loader = DataLoaderMaker.get_img_label_data_loader(datasetname, batch_size, is_train=True)
    model_names = MODELS_TRAIN_STANDARD[datasetname] + MODELS_TEST_STANDARD[datasetname]

    attackers = []
    for name in model_names:
        # no_grad=False: the attacks need gradients w.r.t. the input.
        net = StandardModel(datasetname, name, no_grad=False).cuda().eval()
        attackers.extend([
            LinfPGDAttack(net, loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                          eps=0.031372, nb_iter=30, eps_iter=0.01, rand_init=True,
                          clip_min=0.0, clip_max=1.0, targeted=False),
            L2PGDAttack(net, loss_fn=nn.CrossEntropyLoss(reduction="sum"), eps=4.6,
                        nb_iter=30, clip_min=0.0, clip_max=1.0, targeted=False),
            FGSM(net, loss_fn=nn.CrossEntropyLoss(reduction="sum")),
            MomentumIterativeAttack(net, loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                                    eps=0.031372, nb_iter=30, eps_iter=0.01,
                                    clip_min=0.0, clip_max=1.0, targeted=False),
        ])
        log.info("Create model {} done!".format(name))

    generate_and_save_adv_examples(datasetname, data_loader, attackers, save_dir_path)
コード例 #3
0
def main():
    """Generate untargeted L-inf adversarial examples for the images in ``args.input_dir``.

    Builds an ImageNet-pretrained ``pretrainedmodels`` classifier with its input
    normalization prepended, loads the images at the network's expected input
    size, optionally installs the Skip Gradient Method (SGM) hooks, and runs
    either a momentum iterative attack (``args.momentum > 0``) or plain L-inf PGD
    without random start. The adversary is handed to
    ``generate_adversarial_example``. Settings come from the module-level
    ``args``; ``device`` selects where the model runs.

    Raises:
        ValueError: if SGM is requested (``args.gamma < 1.0``) for an
            architecture other than the supported ResNet / DenseNet variants.
    """
    # create model
    net = pretrainedmodels.__dict__[args.arch](num_classes=1000,
                                               pretrained='imagenet')
    # input_size is indexed as (channels, height, width).
    height, width = net.input_size[1], net.input_size[2]
    model = nn.Sequential(Normalize(mean=net.mean, std=net.std), net)
    model = model.to(device)
    # nn.Module defaults to training mode. Switch to eval so BatchNorm uses (and
    # does not update) its running statistics and dropout is off while the
    # attack queries the model.
    model.eval()

    # create dataloader
    data_loader, image_list = load_images(input_dir=args.input_dir,
                                          batch_size=args.batch_size,
                                          input_height=height,
                                          input_width=width)

    # create adversary; budgets are given on the 0-255 scale
    epsilon = args.epsilon / 255.0
    if args.step_size < 0:
        # Negative step size: spread the budget evenly over the iterations.
        step_size = epsilon / args.num_steps
    else:
        step_size = args.step_size / 255.0

    # using our method - Skip Gradient Method (SGM)
    if args.gamma < 1.0:
        if args.arch in [
                'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'
        ]:
            register_hook_for_resnet(model, arch=args.arch, gamma=args.gamma)
        elif args.arch in ['densenet121', 'densenet169', 'densenet201']:
            register_hook_for_densenet(model, arch=args.arch, gamma=args.gamma)
        else:
            raise ValueError(
                'Current code only supports resnet/densenet. '
                'You can extend this code to other architectures.')

    if args.momentum > 0.0:
        print('using PGD attack with momentum = {}'.format(args.momentum))
        adversary = MomentumIterativeAttack(
            predict=model,
            loss_fn=nn.CrossEntropyLoss(reduction="sum"),
            eps=epsilon,
            nb_iter=args.num_steps,
            eps_iter=step_size,
            decay_factor=args.momentum,
            clip_min=0.0,
            clip_max=1.0,
            targeted=False)
    else:
        print('using linf PGD attack')
        adversary = LinfPGDAttack(predict=model,
                                  loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                                  eps=epsilon,
                                  nb_iter=args.num_steps,
                                  eps_iter=step_size,
                                  rand_init=False,
                                  clip_min=0.0,
                                  clip_max=1.0,
                                  targeted=False)

    generate_adversarial_example(model=model,
                                 data_loader=data_loader,
                                 adversary=adversary,
                                 img_path=image_list)
コード例 #4
0
def _build_cifar_model(arch, width_factor):
    """Instantiate a CIFAR-10 classifier by architecture name.

    Args:
        arch: one of 'PreActResNet18', 'WideResNet', 'DenseNet121', 'ResNet18'.
        width_factor: widen factor, only used for 'WideResNet' (depth 34, 10 classes).

    Raises:
        ValueError: if ``arch`` is not a supported architecture name.
    """
    if arch == 'PreActResNet18':
        return PreActResNet18()
    if arch == 'WideResNet':
        return WideResNet(34, 10, widen_factor=width_factor, dropRate=0.0)
    if arch == 'DenseNet121':
        return DenseNet121()
    if arch == 'ResNet18':
        return ResNet18()
    raise ValueError("Unknown model: {}".format(arch))


def _unwrap_state_dict(checkpoint):
    """Return the model weights from a loaded checkpoint.

    ``model_best.pth`` / ``model_val.pth`` are saved by ``main`` as a dict with a
    ``'state_dict'`` entry plus metrics, while ``model_{epoch}.pth`` holds a bare
    state dict; both forms are accepted.
    """
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        return checkpoint['state_dict']
    return checkpoint


def _make_lr_schedule(args):
    """Return a function mapping a (fractional) epoch ``t`` to a learning rate.

    Raises:
        ValueError: if ``args.lr_schedule`` is not a known schedule name.
    """
    if args.lr_schedule == 'superconverge':
        return lambda t: np.interp([t], [
            0, args.epochs * 2 // 5, args.epochs
        ], [0, args.lr_max, 0])[0]
    if args.lr_schedule == 'piecewise':

        def schedule(t):
            if t / args.epochs < 0.5:
                return args.lr_max
            if t / args.epochs < 0.75:
                return args.lr_max / 10.
            return args.lr_max / 100.

        return schedule
    if args.lr_schedule == 'linear':
        return lambda t: np.interp([t], [
            0, args.epochs // 3, args.epochs * 2 // 3, args.epochs
        ], [args.lr_max, args.lr_max, args.lr_max / 10, args.lr_max / 100])[0]
    if args.lr_schedule == 'onedrop':
        return lambda t: args.lr_max if t < args.lr_drop_epoch else args.lr_one_drop
    if args.lr_schedule == 'multipledecay':
        # max(..., 1): with fewer than 10 epochs, epochs // 10 == 0 would divide by zero.
        decay_every = max(args.epochs // 10, 1)
        return lambda t: args.lr_max - (t // decay_every) * (args.lr_max / 10)
    if args.lr_schedule == 'cosine':
        return lambda t: args.lr_max * 0.5 * (1 + np.cos(t / args.epochs * np.pi))
    raise ValueError("Unknown lr_schedule: {}".format(args.lr_schedule))


class _Normalized(nn.Module):
    """Wrap ``net`` so it takes [0, 1] images and applies this script's ``normalize`` first.

    Training and evaluation always call ``model(normalize(x))``; attacks that work
    in [0, 1] pixel space must differentiate through the same preprocessing.
    """

    def __init__(self, net):
        super().__init__()
        self.net = net

    def forward(self, x):
        return self.net(normalize(x))


def _momentum_attack_delta(predictor, X, y, eps, eps_iter, nb_iter):
    """Return the detached perturbation found by an untargeted momentum iterative attack.

    ``predictor`` must accept raw [0, 1] inputs (normalization included), since
    the attack clips adversarial examples to that range.
    """
    adversary = MomentumIterativeAttack(
        predictor,
        nb_iter=nb_iter,
        eps=eps,
        loss_fn=nn.CrossEntropyLoss(reduction="sum"),
        eps_iter=eps_iter,
        clip_min=0,
        clip_max=1,
        targeted=False)
    return (adversary.perturb(X, y) - X).detach()


def main():
    """Adversarially train (or, with ``--eval``, evaluate) a CIFAR-10 classifier.

    Each training epoch perturbs every batch with the ``--trn_attack`` attack
    ('pgd', 'fgsm', 'tmim' or 'none'), optionally crafted on a separately loaded
    source model (``--trn_adv_models``), and takes an SGD step on the perturbed
    batch. After every epoch the model is evaluated on the test split (and on the
    validation split with ``--val``) under ``--tst_attack``. Periodic,
    best-test-robust and best-validation-robust checkpoints go to ``args.fname``.
    """
    args = get_args()

    os.makedirs(args.fname, exist_ok=True)

    logger = logging.getLogger(__name__)
    logging.basicConfig(
        format='[%(asctime)s] - %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S',
        level=logging.DEBUG,
        handlers=[
            logging.FileHandler(
                os.path.join(args.fname,
                             'eval.log' if args.eval else 'output.log')),
            logging.StreamHandler()
        ])

    logger.info(args)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # Training augmentation: 32x32 crop of the 4-pixel-padded images, flip,
    # and optionally Cutout.
    train_transforms = [Crop(32, 32), FlipLR()]
    if args.cutout:
        train_transforms.append(Cutout(args.cutout_len, args.cutout_len))
    if args.val:
        # The split file supplies 'train', 'val' and 'test' entries.
        try:
            dataset = torch.load("cifar10_validation_split.pth")
        except FileNotFoundError:
            print(
                "Couldn't find a dataset with a validation split, did you run "
                "generate_validation.py?")
            return
        val_set = list(
            zip(transpose(dataset['val']['data'] / 255.),
                dataset['val']['labels']))
        val_batches = Batches(val_set,
                              args.batch_size,
                              shuffle=False,
                              num_workers=2)
    else:
        dataset = cifar10(args.data_dir)
    train_set = list(
        zip(transpose(pad(dataset['train']['data'], 4) / 255.),
            dataset['train']['labels']))
    train_set_x = Transform(train_set, train_transforms)
    train_batches = Batches(train_set_x,
                            args.batch_size,
                            shuffle=True,
                            set_random_choices=True,
                            num_workers=2)

    test_set = list(
        zip(transpose(dataset['test']['data'] / 255.),
            dataset['test']['labels']))
    test_batches = Batches(test_set,
                           args.batch_size,
                           shuffle=False,
                           num_workers=2)

    # Attack budgets are given on the 0-255 scale.
    trn_epsilon = (args.trn_epsilon / 255.)
    trn_pgd_alpha = (args.trn_pgd_alpha / 255.)
    tst_epsilon = (args.tst_epsilon / 255.)
    tst_pgd_alpha = (args.tst_pgd_alpha / 255.)

    # Single-GPU model (not wrapped in DataParallel).
    model = _build_cifar_model(args.model, args.width_factor)
    model = model.cuda()
    model.train()

    # Optional source models used to craft transfer attacks.
    if args.trn_adv_models != 'None':
        trn_adv_model = nn.DataParallel(
            _build_cifar_model(args.trn_adv_arch, args.width_factor)).cuda()
        trn_adv_model.load_state_dict(
            _unwrap_state_dict(
                torch.load(
                    os.path.join('./adv_models', args.trn_adv_models,
                                 'model_best.pth'))))
        logger.info(f'loaded adv_model: {args.trn_adv_models}')
    else:
        trn_adv_model = None

    if args.tst_adv_models != 'None':
        # Loaded as a bare model (no DataParallel), unlike trn_adv_model above.
        tst_adv_model = _build_cifar_model(args.tst_adv_arch,
                                           args.width_factor).cuda()
        tst_adv_model.load_state_dict(
            _unwrap_state_dict(
                torch.load(
                    os.path.join('./adv_models', args.tst_adv_models,
                                 'model_best.pth'))))
        logger.info(f'loaded adv_model: {args.tst_adv_models}')
    else:
        tst_adv_model = None

    if args.l2:
        # Apply the L2 penalty to everything except BatchNorm and bias parameters.
        decay, no_decay = [], []
        for name, param in model.named_parameters():
            if 'bn' not in name and 'bias' not in name:
                decay.append(param)
            else:
                no_decay.append(param)
        params = [{
            'params': decay,
            'weight_decay': args.l2
        }, {
            'params': no_decay,
            'weight_decay': 0
        }]
    else:
        params = model.parameters()

    opt = torch.optim.SGD(params,
                          lr=args.lr_max,
                          momentum=0.9,
                          weight_decay=5e-4)

    criterion = nn.CrossEntropyLoss()

    if args.trn_attack == 'free' or (args.trn_attack == 'fgsm'
                                     and args.trn_fgsm_init == 'previous'):
        delta = torch.zeros(args.batch_size, 3, 32, 32).cuda()
        delta.requires_grad = True
    if args.trn_attack == 'free':
        # NOTE(review): the training loop has no 'free' branch, so no per-batch
        # perturbation is ever computed for it. TODO: implement free adversarial
        # training or drop the option.
        logger.warning("trn_attack='free' is not implemented in the training "
                       "loop; no per-batch perturbation is computed")

    if args.trn_attack == 'free':
        epochs = int(math.ceil(args.epochs / args.trn_attack_iters))
    else:
        epochs = args.epochs

    lr_schedule = _make_lr_schedule(args)

    best_test_robust_acc = 0
    best_val_robust_acc = 0
    if args.resume:
        checkpoint = torch.load(os.path.join(args.fname, 'model_best.pth'))
        model.load_state_dict(_unwrap_state_dict(checkpoint))
        # Restore the best-so-far score so a resumed run does not overwrite
        # model_best.pth with a worse model on its first epoch.
        if isinstance(checkpoint, dict) and 'test_robust_acc' in checkpoint:
            best_test_robust_acc = checkpoint['test_robust_acc']
        start_epoch = args.resume
        if args.val:
            best_val_robust_acc = torch.load(
                os.path.join(args.fname, 'model_val.pth'))['val_robust_acc']
    else:
        start_epoch = 0

    if args.eval:
        if not args.resume:
            logger.info(
                "No model loaded to evaluate, specify with --resume FNAME")
            return
        logger.info("[Evaluation mode]")

    logger.info(
        'Epoch \t Train Time \t Test Time \t LR \t \t Train Loss \t Train Acc \t Train Robust Loss \t Train Robust Acc \t Test Loss \t Test Acc \t Test Robust Loss \t Test Robust Acc'
    )
    for epoch in range(start_epoch, epochs):
        model.train()
        start_time = time.time()
        train_loss = 0
        train_acc = 0
        train_robust_loss = 0
        train_robust_acc = 0
        train_n = 0
        for i, batch in enumerate(train_batches):
            if args.eval:
                break
            X, y = batch['input'], batch['target']
            if args.mixup:
                X, y_a, y_b, lam = mixup_data(X, y, args.mixup_alpha)
                X, y_a, y_b = map(Variable, (X, y_a, y_b))
            lr = lr_schedule(epoch + (i + 1) / len(train_batches))
            # Update every param group: with --l2 there are separate decay /
            # no-decay groups and all of them must follow the schedule.
            for group in opt.param_groups:
                group['lr'] = lr

            if args.trn_attack == 'pgd':
                mixup_kwargs = (dict(mixup=True, y_a=y_a, y_b=y_b, lam=lam)
                                if args.mixup else {})
                delta = attack_pgd(model,
                                   X,
                                   y,
                                   trn_epsilon,
                                   trn_pgd_alpha,
                                   args.trn_attack_iters,
                                   args.trn_restarts,
                                   args.trn_norm,
                                   adv_models=trn_adv_model,
                                   **mixup_kwargs)
                delta = delta.detach()
            elif args.trn_attack == 'fgsm':
                delta = attack_pgd(model,
                                   X,
                                   y,
                                   trn_epsilon,
                                   args.trn_fgsm_alpha * trn_epsilon,
                                   1,
                                   1,
                                   args.trn_norm,
                                   adv_models=trn_adv_model,
                                   rand_init=args.trn_fgsm_init)
                delta = delta.detach()
            # Standard training
            elif args.trn_attack == 'none':
                delta = torch.zeros_like(X)
            # The Momentum Iterative Attack
            elif args.trn_attack == 'tmim':
                # The attack works on [0, 1] images, so the attacked network
                # must include normalization. The wrapper is a local so the
                # source model is not re-wrapped on every batch.
                if trn_adv_model is None:
                    predictor = _Normalized(model)
                else:
                    predictor = nn.Sequential(
                        NormalizeByChannelMeanStd(CIFAR10_MEAN, CIFAR10_STD),
                        trn_adv_model).cuda()
                delta = _momentum_attack_delta(predictor, X, y, trn_epsilon,
                                               trn_pgd_alpha,
                                               args.trn_attack_iters)

            robust_output = model(
                normalize(
                    torch.clamp(X + delta[:X.size(0)],
                                min=lower_limit,
                                max=upper_limit)))
            if args.mixup:
                robust_loss = mixup_criterion(criterion, robust_output, y_a,
                                              y_b, lam)
            else:
                robust_loss = criterion(robust_output, y)

            if args.l1:
                for name, param in model.named_parameters():
                    if 'bn' not in name and 'bias' not in name:
                        robust_loss += args.l1 * param.abs().sum()

            opt.zero_grad()
            robust_loss.backward()
            opt.step()

            # Clean loss/accuracy are only logged (computed after the update).
            output = model(normalize(X))
            if args.mixup:
                loss = mixup_criterion(criterion, output, y_a, y_b, lam)
            else:
                loss = criterion(output, y)

            train_robust_loss += robust_loss.item() * y.size(0)
            train_robust_acc += (robust_output.max(1)[1] == y).sum().item()
            train_loss += loss.item() * y.size(0)
            train_acc += (output.max(1)[1] == y).sum().item()
            train_n += y.size(0)

        train_time = time.time()

        model.eval()
        test_loss = 0
        test_acc = 0
        test_robust_loss = 0
        test_robust_acc = 0
        test_n = 0
        for i, batch in enumerate(test_batches):
            X, y = batch['input'], batch['target']

            if args.tst_attack == 'none':
                delta = torch.zeros_like(X)
            elif args.tst_attack == 'pgd':
                delta = attack_pgd(model,
                                   X,
                                   y,
                                   tst_epsilon,
                                   tst_pgd_alpha,
                                   args.tst_attack_iters,
                                   args.tst_restarts,
                                   args.tst_norm,
                                   adv_models=tst_adv_model,
                                   rand_init=args.tst_fgsm_init)
            elif args.tst_attack == 'fgsm':
                delta = attack_pgd(model,
                                   X,
                                   y,
                                   tst_epsilon,
                                   tst_epsilon,
                                   1,
                                   1,
                                   args.tst_norm,
                                   rand_init=args.tst_fgsm_init,
                                   adv_models=tst_adv_model)
            # The Momentum Iterative Attack
            elif args.tst_attack == 'tmim':
                # As in training: the attacked network must include normalization.
                if tst_adv_model is None:
                    predictor = _Normalized(model)
                else:
                    predictor = nn.Sequential(
                        NormalizeByChannelMeanStd(cifar10_mean, cifar10_std),
                        tst_adv_model).to(device)
                delta = _momentum_attack_delta(predictor, X, y, tst_epsilon,
                                               tst_pgd_alpha,
                                               args.tst_attack_iters)

            delta = delta.detach()

            robust_output = model(
                normalize(
                    torch.clamp(X + delta[:X.size(0)],
                                min=lower_limit,
                                max=upper_limit)))
            robust_loss = criterion(robust_output, y)

            output = model(normalize(X))
            loss = criterion(output, y)

            test_robust_loss += robust_loss.item() * y.size(0)
            test_robust_acc += (robust_output.max(1)[1] == y).sum().item()
            test_loss += loss.item() * y.size(0)
            test_acc += (output.max(1)[1] == y).sum().item()
            test_n += y.size(0)

        test_time = time.time()

        if args.val:
            val_loss = 0
            val_acc = 0
            val_robust_loss = 0
            val_robust_acc = 0
            val_n = 0
            for i, batch in enumerate(val_batches):
                X, y = batch['input'], batch['target']

                # Validation attacks are white-box (no separate source model).
                if args.tst_attack == 'none':
                    delta = torch.zeros_like(X)
                elif args.tst_attack == 'pgd':
                    delta = attack_pgd(model,
                                       X,
                                       y,
                                       tst_epsilon,
                                       tst_pgd_alpha,
                                       args.tst_attack_iters,
                                       args.tst_restarts,
                                       args.tst_norm,
                                       early_stop=args.eval)
                elif args.tst_attack == 'fgsm':
                    delta = attack_pgd(model,
                                       X,
                                       y,
                                       tst_epsilon,
                                       tst_epsilon,
                                       1,
                                       1,
                                       args.tst_norm,
                                       early_stop=args.eval,
                                       rand_init=args.tst_fgsm_init)
                elif args.tst_attack == 'tmim':
                    # Without this branch the last test-batch delta was reused.
                    delta = _momentum_attack_delta(_Normalized(model), X, y,
                                                   tst_epsilon, tst_pgd_alpha,
                                                   args.tst_attack_iters)

                delta = delta.detach()

                robust_output = model(
                    normalize(
                        torch.clamp(X + delta[:X.size(0)],
                                    min=lower_limit,
                                    max=upper_limit)))
                robust_loss = criterion(robust_output, y)

                output = model(normalize(X))
                loss = criterion(output, y)

                val_robust_loss += robust_loss.item() * y.size(0)
                val_robust_acc += (robust_output.max(1)[1] == y).sum().item()
                val_loss += loss.item() * y.size(0)
                val_acc += (output.max(1)[1] == y).sum().item()
                val_n += y.size(0)

        if not args.eval:
            logger.info(
                '%d \t %.1f \t \t %.1f \t \t %.4f \t %.4f \t %.4f \t %.4f \t \t %.4f \t \t %.4f \t %.4f \t %.4f \t \t %.4f',
                epoch, train_time - start_time, test_time - train_time, lr,
                train_loss / train_n, train_acc / train_n,
                train_robust_loss / train_n, train_robust_acc / train_n,
                test_loss / test_n, test_acc / test_n,
                test_robust_loss / test_n, test_robust_acc / test_n)

            if args.val:
                logger.info('validation %.4f \t %.4f \t %.4f \t %.4f',
                            val_loss / val_n, val_acc / val_n,
                            val_robust_loss / val_n, val_robust_acc / val_n)

                if val_robust_acc / val_n > best_val_robust_acc:
                    torch.save(
                        {
                            'state_dict': model.state_dict(),
                            'test_robust_acc': test_robust_acc / test_n,
                            'test_robust_loss': test_robust_loss / test_n,
                            'test_loss': test_loss / test_n,
                            'test_acc': test_acc / test_n,
                            'val_robust_acc': val_robust_acc / val_n,
                            'val_robust_loss': val_robust_loss / val_n,
                            'val_loss': val_loss / val_n,
                            'val_acc': val_acc / val_n,
                        }, os.path.join(args.fname, 'model_val.pth'))
                    best_val_robust_acc = val_robust_acc / val_n

            # save checkpoint
            if (epoch + 1) % args.chkpt_iters == 0 or epoch + 1 == epochs:
                torch.save(model.state_dict(),
                           os.path.join(args.fname, f'model_{epoch}.pth'))
                torch.save(opt.state_dict(),
                           os.path.join(args.fname, f'opt_{epoch}.pth'))

            # save best
            if test_robust_acc / test_n > best_test_robust_acc:
                torch.save(
                    {
                        'state_dict': model.state_dict(),
                        'test_robust_acc': test_robust_acc / test_n,
                        'test_robust_loss': test_robust_loss / test_n,
                        'test_loss': test_loss / test_n,
                        'test_acc': test_acc / test_n,
                    }, os.path.join(args.fname, 'model_best.pth'))
                best_test_robust_acc = test_robust_acc / test_n
        else:
            # Evaluation mode: log the test metrics once (train columns as -1).
            logger.info(
                '%d \t %.1f \t \t %.1f \t \t %.4f \t %.4f \t %.4f \t %.4f \t \t %.4f \t \t %.4f \t %.4f \t %.4f \t \t %.4f',
                epoch, train_time - start_time, test_time - train_time, -1, -1,
                -1, -1, -1, test_loss / test_n, test_acc / test_n,
                test_robust_loss / test_n, test_robust_acc / test_n)
            return
コード例 #5
0
         loss_fn=nn.CrossEntropyLoss(reduction="sum"),
         clip_min=0.0,
         clip_max=1.0,
         eps=0.007,
         targeted=False)
 # elif args.attack_method == "JSMA":
 #     adversary =JacobianSaliencyMapAttack(
 #         model,num_classes=args.num_classes,
 #         clip_min=0.0, clip_max=1.0,gamma=0.145,theta=1)
 elif args.attack_method == "Momentum":
     adversary = MomentumIterativeAttack(
         model,
         loss_fn=nn.CrossEntropyLoss(reduction="sum"),
         eps=args.epsilon,
         nb_iter=40,
         decay_factor=1.0,
         eps_iter=1.0,
         clip_min=0.0,
         clip_max=1.0,
         targeted=False,
         ord=np.inf)
 elif args.attack_method == "STA":
     adversary = SpatialTransformAttack(
         model,
         num_classes=args.num_classes,
         loss_fn=nn.CrossEntropyLoss(reduction="sum"),
         initial_const=0.05,
         max_iterations=1000,
         search_steps=1,
         confidence=0,
         clip_min=0.0,
コード例 #6
0
ファイル: data_generator.py プロジェクト: leeyegy/TRADES
def _get_test_adv(attack_method, epsilon):
    """Generate adversarial examples for the whole CIFAR-10 test loader.

    The checkpoint to attack is selected from command-line options that this
    function parses itself from ``sys.argv`` (``--dataset`` plus the network
    options consumed by ``getNetwork``; ``--seed`` seeds torch). The parsed
    ``--attack_method`` / ``--epsilon`` options are NOT used; the function
    arguments take their place.

    Args:
        attack_method: "PGD", "FGSM", "Momentum" or "STA".
        epsilon: L-inf budget used by "PGD" and "Momentum". "FGSM" uses a
            fixed eps of 0.007 and "STA" takes no perturbation budget.

    Returns:
        ``(test_adv, test_true_target)``: numpy arrays reshaped to
        ``(N, 3, 32, 32)`` and ``(N,)``.

    Raises:
        ValueError: if ``attack_method`` is not one of the names above.
    """
    # define parameters
    parser = argparse.ArgumentParser(description='Train MNIST')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--mode', default="adv", help="cln | adv")
    parser.add_argument('--sigma', default=75, type=int, help='noise level')
    parser.add_argument('--train_batch_size', default=50, type=int)
    parser.add_argument('--test_batch_size', default=1000, type=int)
    parser.add_argument('--log_interval', default=200, type=int)
    parser.add_argument('--result_dir', default='results', type=str, help='directory of test dataset')
    parser.add_argument('--monitor', default=False, type=bool, help='if monitor the training process')
    parser.add_argument('--start_save', default=90, type=int,
                        help='the threshold epoch which will start to save imgs data using in testing')

    # attack
    parser.add_argument("--attack_method", default="PGD", type=str,
                        choices=['FGSM', 'PGD', 'Momentum', 'STA'])

    parser.add_argument('--epsilon', type=float, default=8 / 255, help='if pd_block is used')

    parser.add_argument('--dataset', default='cifar10', type=str, help='dataset = [cifar10/MNIST]')

    # net
    parser.add_argument('--net_type', default='wide-resnet', type=str, help='model')
    parser.add_argument('--depth', default=28, type=int, help='depth of model')
    parser.add_argument('--widen_factor', default=10, type=int, help='width of model')
    parser.add_argument('--dropout', default=0.3, type=float, help='dropout_rate')
    parser.add_argument('--num_classes', default=10, type=int)
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Wrapped CIFAR-10 test loader; shuffle=False keeps outputs aligned with a fixed order.
    test_loader = get_handled_cifar10_test_loader(num_workers=4, shuffle=False, batch_size=50)

    # Load the network model from checkpoint.
    print('| Resuming from checkpoint...')
    assert os.path.isdir('checkpoint'), 'Error: No checkpoint directory found!'
    _, file_name = getNetwork(args)
    # os.sep provides the platform-specific path separator
    checkpoint = torch.load('./checkpoint/' + args.dataset + os.sep + file_name + '.t7')
    model = checkpoint['net']
    model = model.to(device)

    # Build the requested attack.
    from advertorch.attacks import LinfPGDAttack
    if attack_method == "PGD":
        adversary = LinfPGDAttack(
            model, loss_fn=nn.CrossEntropyLoss(reduction="sum"), eps=epsilon,
            nb_iter=40, eps_iter=0.01, rand_init=True, clip_min=0.0, clip_max=1.0,
            targeted=False)
    elif attack_method == "FGSM":
        # Deliberately ignores `epsilon`: a first test without a perturbation-range
        # limit; for FGSM, eps plays the role of the usual per-step eps_iter.
        adversary = GradientSignAttack(
            model, loss_fn=nn.CrossEntropyLoss(reduction="sum"),
            clip_min=0.0, clip_max=1.0, eps=0.007, targeted=False)
    elif attack_method == "Momentum":
        adversary = MomentumIterativeAttack(
            model, loss_fn=nn.CrossEntropyLoss(reduction="sum"), eps=epsilon,
            nb_iter=40, decay_factor=1.0, eps_iter=1.0, clip_min=0.0, clip_max=1.0,
            targeted=False, ord=np.inf)
    elif attack_method == "STA":
        # First test without a perturbation-range limit.
        adversary = SpatialTransformAttack(
            model, num_classes=args.num_classes, loss_fn=nn.CrossEntropyLoss(reduction="sum"),
            initial_const=0.05, max_iterations=1000, search_steps=1, confidence=0, clip_min=0.0, clip_max=1.0,
            targeted=False, abort_early=True)
    else:
        raise ValueError("unsupported attack_method: {!r}".format(attack_method))

    # Perturb every test batch and collect the results on the CPU.
    adv_batches = []
    target_batches = []
    for clndata, target in test_loader:
        print("clndata:{}".format(clndata.size()))
        clndata, target = clndata.to(device), target.to(device)
        with ctx_noparamgrad_and_eval(model):
            advdata = adversary.perturb(clndata, target)
            adv_batches.append(advdata.detach().cpu().numpy())
        target_batches.append(target.cpu().numpy())

    # np.concatenate copes with a smaller final batch; np.asarray on a ragged
    # list of batches would not produce a reshapeable numeric array.
    if adv_batches:
        test_adv = np.reshape(np.concatenate(adv_batches, axis=0), [-1, 3, 32, 32])
        test_true_target = np.reshape(np.concatenate(target_batches, axis=0), [-1])
    else:
        test_adv = np.zeros([0, 3, 32, 32])
        test_true_target = np.zeros([0])
    print("test_adv.shape:{}".format(test_adv.shape))
    print("test_true_target.shape:{}".format(test_true_target.shape))
    del model

    return test_adv, test_true_target
コード例 #7
0
    index += 1
    if index == 23:
        break

bs, c, h, w = np.shape(cln_data)
# print true labels
print(true_label)
cln_data = cln_data.view(-1, c, h, w)
cln_data, true_label = cln_data.to(device), true_label.to(device)

# MomentumIterativateAttack
adversary = MomentumIterativeAttack(
    model,
    loss_fn=nn.CrossEntropyLoss(reduction="sum"),
    eps=0.2,
    nb_iter=40,
    eps_iter=0.01,
    clip_min=0.0,
    clip_max=1.0,
    targeted=False)

# L2PGDAttack
'''
adversary = L2PGDAttack(
    model, loss_fn=nn.CrossEntropyLoss(reduction="sum"), eps=0.15,
    nb_iter=40, eps_iter=0.01, rand_init=True, clip_min=0.0, clip_max=1.0,
    targeted=False)
'''

# LinfPGDAttack
'''
コード例 #8
0
def _generate_adv_file(attack_method, num_classes, epsilon, set_size):
    """Attack a stored Tiny-ImageNet test split and save the adversarial copy.

    Reads datasets ``data`` and ``true_target`` from
    ``data/test_tiny_ImageNet_<set_size>.h5``, perturbs them with the model
    loaded from ``checkpoint/resnet50_epoch_22.pth`` (on CUDA), and writes
    ``data/test_tiny_ImageNet_<set_size>_adv_<attack_method>_<epsilon>.h5``
    with the same two dataset names.

    Args:
        attack_method: "PGD", "FGSM", "Momentum", "STA", "DeepFool" or "CW".
        num_classes: number of classes, forwarded to the STA and CW attacks.
        epsilon: perturbation budget forwarded to the attack (STA ignores it).
        set_size: number of samples; selects the input file and sizes the
            output arrays (images are stored as ``(set_size, 3, 64, 64)``).

    Raises:
        ValueError: if ``attack_method`` is not one of the names above.
    """
    # Loader batch size; also used to place each batch in the output arrays.
    batch_size = 50

    # load model
    model = torch.load(os.path.join("checkpoint", "resnet50_epoch_22.pth"))
    model = model.cuda()

    # define attack
    if attack_method == "PGD":
        adversary = LinfPGDAttack(model,
                                  loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                                  eps=epsilon,
                                  nb_iter=20,
                                  eps_iter=0.01,
                                  rand_init=True,
                                  clip_min=0.0,
                                  clip_max=1.0,
                                  targeted=False)
    elif attack_method == "FGSM":
        adversary = GradientSignAttack(
            model,
            loss_fn=nn.CrossEntropyLoss(reduction="sum"),
            clip_min=0.0,
            clip_max=1.0,
            eps=epsilon,
            targeted=False)
    elif attack_method == "Momentum":
        adversary = MomentumIterativeAttack(
            model,
            loss_fn=nn.CrossEntropyLoss(reduction="sum"),
            eps=epsilon,
            nb_iter=20,
            decay_factor=1.0,
            eps_iter=1.0,
            clip_min=0.0,
            clip_max=1.0,
            targeted=False,
            ord=np.inf)
    elif attack_method == "STA":
        adversary = SpatialTransformAttack(
            model,
            num_classes=num_classes,
            loss_fn=nn.CrossEntropyLoss(reduction="sum"),
            initial_const=0.05,
            max_iterations=500,
            search_steps=1,
            confidence=0,
            clip_min=0.0,
            clip_max=1.0,
            targeted=False,
            abort_early=True)
    elif attack_method == "DeepFool":
        adversary = DeepFool(model,
                             max_iter=20,
                             clip_max=1.0,
                             clip_min=0.0,
                             epsilon=epsilon)
    elif attack_method == "CW":
        adversary = CarliniWagnerL2Attack(
            model,
            # Use the parameter rather than a global `args`, like the STA branch.
            num_classes=num_classes,
            epsilon=epsilon,
            loss_fn=nn.CrossEntropyLoss(reduction="sum"),
            max_iterations=20,
            confidence=0,
            clip_min=0.0,
            clip_max=1.0,
            targeted=False,
            abort_early=True)
    else:
        raise ValueError("unsupported attack_method: {!r}".format(attack_method))

    # Load the clean split; the context manager closes the read handle.
    with h5py.File("data/test_tiny_ImageNet_" + str(set_size) + ".h5",
                   "r") as h5_store:
        data = h5_store['data'][:]
        target = h5_store['true_target'][:]
    test_dataset = ImageNetDataset(torch.from_numpy(data), torch.from_numpy(target))
    # NOTE: drop_last=True means that if the dataset size is not a multiple of
    # batch_size, the trailing rows of the output arrays stay zero.
    test_loader = DataLoader(dataset=test_dataset,
                             num_workers=4,
                             drop_last=True,
                             batch_size=batch_size,
                             shuffle=False)

    torch.manual_seed(0)
    test_adv = np.zeros([set_size, 3, 64, 64])
    test_true_target = np.zeros([set_size])

    # perturb
    for batch_idx, (clndata, target) in enumerate(test_loader):
        print("{}/{}".format(batch_idx, set_size // batch_size))
        clndata, target = clndata.cuda().float(), target.cuda().long()
        start = batch_idx * batch_size
        stop = start + clndata.shape[0]
        with ctx_noparamgrad_and_eval(model):
            advdata = adversary.perturb(clndata, target)
            test_adv[start:stop, :, :, :] = advdata.detach().cpu().numpy()
        test_true_target[start:stop] = target.cpu().numpy()

    print("test_adv.shape:{}".format(test_adv.shape))
    print("test_true_target.shape:{}".format(test_true_target.shape))
    del model

    out_path = ("data/test_tiny_ImageNet_" + str(set_size) + "_adv_" +
                str(attack_method) + "_" + str(epsilon) + ".h5")
    with h5py.File(out_path, 'w') as h5_out:
        h5_out.create_dataset('data', data=test_adv)
        h5_out.create_dataset('true_target', data=test_true_target)
コード例 #9
0
ファイル: test_adv.py プロジェクト: CEA-LIST/adv-sat

epsilon = args.eps
epsilon = epsilon / 255.
ddn = False
if args.attack == 'PGD':
    adversary = PGDAttack(lambda x: wrapper(x, pcl=pcl),
                          eps=epsilon,
                          eps_iter=epsilon / 4,
                          nb_iter=10,
                          ord=norm,
                          rand_init=True)
elif args.attack == 'MIFGSM':
    adversary = MomentumIterativeAttack(
        lambda x: wrapper(normalize(x), pcl=pcl),
        eps=epsilon,
        eps_iter=epsilon / 10,
        ord=norm,
        nb_iter=10)
elif args.attack == 'FGSM':
    adversary = GradientSignAttack(lambda x: wrapper(x, pcl=pcl), eps=epsilon)
    # adversary = PGDAttack(lambda x: wrapper(x, pcl=pcl), eps=epsilon, eps_iter=epsilon, nb_iter=1, ord=norm, rand_init=False)
elif args.attack == 'CW':
    adversary = CarliniWagnerL2Attack(lambda x: wrapper(x, pcl=pcl),
                                      10,
                                      binary_search_steps=2,
                                      max_iterations=500,
                                      initial_const=1e-1)
elif args.attack == 'DDN':
    adversary = DDN(steps=100, device=device)
    ddn = True
else:
コード例 #10
0
def MIM(model, X, y, num_iter=10):
    """Return the untargeted momentum-iterative (MI-FGSM) perturbation for ``X``.

    Args:
        model: classifier passed to ``MomentumIterativeAttack`` as ``predict``.
        X: clean inputs; the attack clips adversarial values to [0, 1].
        y: true labels (the attack is untargeted).
        num_iter: number of attack iterations. It was previously ignored in
            favour of a hard-coded 10; the default keeps that behaviour.

    Returns:
        The perturbation ``adversarial - X``, not the adversarial inputs.
    """
    adversary = MomentumIterativeAttack(
        model,
        loss_fn=nn.CrossEntropyLoss(reduction="sum"),
        eps=0.3,
        nb_iter=num_iter,
        decay_factor=1.0,
        eps_iter=0.003,
        clip_min=0.0,
        clip_max=1.0)
    adv_untargeted = adversary.perturb(X, y) - X
    return adv_untargeted
コード例 #11
0
ファイル: other_attack_test.py プロジェクト: ksouvik52/ATMC
def model_test(model, data_loader, output_file_path, attack='mia', eps=8/255, nb_iter=3):
    """Measure clean accuracy and accuracy under a targeted attack.

    Every sample is attacked towards class ``(y + 5) % 10``, so 10 classes are
    assumed. Per-batch diagnostics and running accuracies are printed, and a
    summary record is appended to ``output_file_path``.

    Relies on names defined elsewhere in the module: ``min_v`` / ``max_v``
    (clip range given to the attacks) and ``args.model_name``. Batches are
    moved with ``.cuda()``, so a CUDA device is required.

    Args:
        model: classifier whose output's argmax over dim 1 is the prediction.
        data_loader: iterable of ``(data, target)`` batches.
        output_file_path: file the summary record is appended to.
        attack: 'cw' (Carlini-Wagner L2, first 5 batches only),
            'mia' (momentum iterative, 40 iterations) or 'pgd' (L-inf PGD).
        eps: L-inf budget for 'mia' and 'pgd'.
        nb_iter: number of PGD iterations.

    Returns:
        ``(acc, acc_adv)``: clean and adversarial accuracy in percent
        (0.0 each if the loader yields no batches).

    Raises:
        NotImplementedError: if ``attack`` is not one of the names above.
    """
    model.eval()

    # The attack depends only on the arguments, so build it once.
    if attack == 'cw':
        adversary = CarliniWagnerL2Attack(predict=model, num_classes=10, targeted=True,
                                          clip_min=min_v, clip_max=max_v, max_iterations=50)
    elif attack == 'mia':
        adversary = MomentumIterativeAttack(predict=model, targeted=True, eps=eps, nb_iter=40,
                                            eps_iter=0.01 * (max_v - min_v),
                                            clip_min=min_v, clip_max=max_v)
    elif attack == 'pgd':
        adversary = LinfPGDAttack(predict=model, targeted=True, eps=eps, nb_iter=nb_iter,
                                  eps_iter=eps * 1.25 / nb_iter,
                                  clip_min=min_v, clip_max=max_v)
    else:
        # Raising a bare string is itself a TypeError in Python 3.
        raise NotImplementedError('unimplemented error: attack={!r}'.format(attack))

    correct, correct_adv, nb_data, adv_l2dist, adv_linfdist = 0, 0, 0, 0.0, 0.0
    # Defined up front so an empty loader does not leave them unbound.
    acc, acc_adv, time_consume = 0.0, 0.0, 0.0

    start_time = time.time()
    for batch_idx, (data, target) in enumerate(data_loader):
        print('i:', batch_idx)
        # CW is slow: evaluate only the first 5 batches. Stop before this
        # batch is counted in nb_data so the averaged distances stay correct.
        if attack == 'cw' and batch_idx >= 5:
            break

        indx_target = target.clone()
        nb_data += data.shape[0]

        data, target = data.cuda(), target.cuda()

        with torch.no_grad():
            output = model(data)
        print('pred:', type(output), output.shape)
        print('target:', type(target), target.shape, target[0:20])

        # Targeted attack: shift every label by 5 (mod 10).
        target_adv = (target + 5) % 10
        print('target_adv:', type(target_adv), target_adv.shape, target_adv[0:20])
        data_adv = adversary.perturb(data, target_adv)

        print('data_adv max:', torch.max(data_adv))
        print('data_adv min:', torch.min(data_adv))
        print('linf:', torch.max(torch.abs(data_adv - data)))

        # Per-sample perturbation, flattened so norms are taken per example.
        diff = (data - data_adv).view(data.size(0), -1)
        adv_l2dist += torch.norm(diff, p=2, dim=-1).sum().item()
        adv_linfdist += torch.max(diff.abs(), dim=-1)[0].sum().item()

        with torch.no_grad():
            output_adv = model(data_adv)

        pred_adv = output_adv.data.max(1)[1]
        correct_adv += pred_adv.cpu().eq(indx_target).sum()

        pred = output.data.max(1)[1]  # get the index of the max log-probability
        correct += pred.cpu().eq(indx_target).sum()

        time_consume = time.time() - start_time
        print('time_consume:', time_consume)

        acc = float(100. * correct) / nb_data
        print('\tTest set: Accuracy: {}/{}({:.2f}%)'.format(
            correct, nb_data, acc))

        acc_adv = float(100. * correct_adv) / nb_data
        print('\tAdv set: Accuracy : {}/{}({:.2f}%)'.format(
            correct_adv, nb_data, acc_adv
        ))

    if nb_data:
        adv_l2dist /= nb_data
        adv_linfdist /= nb_data
    print('\tAdv dist: L2: {:.8f} , Linf: {:.8f}'.format(adv_l2dist, adv_linfdist))

    with open(output_file_path, "a+") as output_file:
        output_file.write(args.model_name + '\n')
        # Trailing newline so the next appended record starts on its own line.
        info_string = 'attack: %s:\n acc: %.2f, acc_adv: %.2f, adv_l2dist: %.2f, adv_linfdist: %.2f, time_consume: %.2f\n' % (
            attack, acc, acc_adv, adv_l2dist, adv_linfdist, time_consume)
        output_file.write(info_string)

    return acc, acc_adv
コード例 #12
0
def _get_test_adv(attack_method, epsilon):
    """Generate adversarial examples for the whole CIFAR-10 test loader.

    Uses the module-level ``args`` (``seed``, ``dataset``, ``num_classes`` and
    the network options consumed by ``getNetwork``) to pick the checkpoint.

    Args:
        attack_method: "PGD", "FGSM", "Momentum" or "STA".
        epsilon: perturbation budget for "PGD", "FGSM" and "Momentum";
            "STA" takes no perturbation budget.

    Returns:
        ``(test_adv, test_true_target)``: numpy arrays reshaped to
        ``(N, 3, 32, 32)`` and ``(N,)``.

    Raises:
        ValueError: if ``attack_method`` is not one of the names above.
    """
    torch.manual_seed(args.seed)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Wrapped CIFAR-10 test loader; shuffle=False keeps outputs aligned with a fixed order.
    test_loader = get_handled_cifar10_test_loader(num_workers=4,
                                                  shuffle=False,
                                                  batch_size=50)

    # Load the network model from checkpoint.
    print('| Resuming from checkpoint...')
    assert os.path.isdir('checkpoint'), 'Error: No checkpoint directory found!'
    _, file_name = getNetwork(args)
    # os.sep provides the platform-specific path separator
    checkpoint = torch.load('./checkpoint/' + args.dataset + os.sep +
                            file_name + '.t7')
    model = checkpoint['net']
    model = model.to(device)

    # Build the requested attack.
    from advertorch.attacks import LinfPGDAttack
    if attack_method == "PGD":
        adversary = LinfPGDAttack(model,
                                  loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                                  eps=epsilon,
                                  nb_iter=20,
                                  eps_iter=0.01,
                                  rand_init=True,
                                  clip_min=0.0,
                                  clip_max=1.0,
                                  targeted=False)
    elif attack_method == "FGSM":
        # Single step without a separate perturbation-range limit: for FGSM,
        # eps plays the role of the usual per-step eps_iter.
        adversary = GradientSignAttack(
            model,
            loss_fn=nn.CrossEntropyLoss(reduction="sum"),
            clip_min=0.0,
            clip_max=1.0,
            eps=epsilon,
            targeted=False)
    elif attack_method == "Momentum":
        adversary = MomentumIterativeAttack(
            model,
            loss_fn=nn.CrossEntropyLoss(reduction="sum"),
            eps=epsilon,
            nb_iter=20,
            decay_factor=1.0,
            eps_iter=1.0,
            clip_min=0.0,
            clip_max=1.0,
            targeted=False,
            ord=np.inf)
    elif attack_method == "STA":
        # First test without a perturbation-range limit.
        adversary = SpatialTransformAttack(
            model,
            num_classes=args.num_classes,
            loss_fn=nn.CrossEntropyLoss(reduction="sum"),
            initial_const=0.05,
            max_iterations=1000,
            search_steps=1,
            confidence=0,
            clip_min=0.0,
            clip_max=1.0,
            targeted=False,
            abort_early=True)
    else:
        raise ValueError("unsupported attack_method: {!r}".format(attack_method))

    # Perturb every test batch and collect the results on the CPU.
    adv_batches = []
    target_batches = []
    for clndata, target in test_loader:
        print("clndata:{}".format(clndata.size()))
        clndata, target = clndata.to(device), target.to(device)
        with ctx_noparamgrad_and_eval(model):
            advdata = adversary.perturb(clndata, target)
            adv_batches.append(advdata.detach().cpu().numpy())
        target_batches.append(target.cpu().numpy())

    # np.concatenate copes with a smaller final batch; np.asarray on a ragged
    # list of batches would not produce a reshapeable numeric array.
    if adv_batches:
        test_adv = np.reshape(np.concatenate(adv_batches, axis=0), [-1, 3, 32, 32])
        test_true_target = np.reshape(np.concatenate(target_batches, axis=0), [-1])
    else:
        test_adv = np.zeros([0, 3, 32, 32])
        test_true_target = np.zeros([0])
    print("test_adv.shape:{}".format(test_adv.shape))
    print("test_true_target.shape:{}".format(test_true_target.shape))
    del model

    return test_adv, test_true_target
コード例 #13
0
def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='cifar')
    parser.add_argument('--start', type=int, default=0)
    parser.add_argument('--end', type=int, default=100)
    parser.add_argument('--n_iter', type=int, default=1000)
    parser.add_argument('--query_step', type=int, default=1)
    parser.add_argument('--transfer', action='store_true')
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--sweep', action='store_true')
    parser.add_argument("--wandb", action="store_true", default=False, help='Use wandb for logging')
    parser.add_argument('--ensemble_adv_trained', action='store_true')
    parser.add_argument('--gamma', type=float, default=0.5)
    parser.add_argument('--batch_size', type=int, default=256, metavar='S')
    parser.add_argument('--test_batch_size', type=int, default=128, metavar='S')
    parser.add_argument('--train_set', default='test',
                        choices=['train_and_test','test','train'],
                        help='add the test set in the training set')
    parser.add_argument('--modelIn', type=str,
                        default='../pretrained_classifiers/cifar/res18/model_0.pt')
    parser.add_argument('--robust_model_path', type=str,
                        default="../madry_challenge_models/mnist/adv_trained/mnist_lenet5_advtrained.pt")
    parser.add_argument('--dir_test_models', type=str,
                        default="../",
                        help="The path to the directory containing the classifier models for evaluation.")
    parser.add_argument("--max_test_model", type=int, default=2,
                    help="The maximum number of pretrained classifiers to use for testing.")
    parser.add_argument('--train_on_madry', default=False, action='store_true',
                        help='Train using Madry tf grad')
    parser.add_argument('--momentum', default=0.0, type=float)
    parser.add_argument('--train_on_list', default=False, action='store_true',
                        help='train on a list of classifiers')
    parser.add_argument('--attack_ball', type=str, default="Linf",
                        choices= ['L2','Linf'])
    parser.add_argument('--source_arch', default="res18",
                        help="The architecture we want to attack on CIFAR.")
    parser.add_argument('--adv_models', nargs='*', help='path to adv model(s)')
    parser.add_argument('--target_arch', default=None,
                        help="The architecture we want to blackbox transfer to on CIFAR.")
    parser.add_argument('--epsilon', type=float, default=0.03125, metavar='M',
                        help='Epsilon for Delta (default: 0.1)')
    parser.add_argument('--num_test_samples', default=None, type=int,
                        help="The number of samples used to train and test the attacker.")
    parser.add_argument('--split', type=int, default=None,
                        help="Which subsplit to use.")
    parser.add_argument('--step-size', default=2, type=float, help='perturb step size')
    parser.add_argument('--train_with_critic_path', type=str, default=None,
                        help='Train generator with saved critic model')
    parser.add_argument('--namestr', type=str, default='SGM', \
            help='additional info in output filename to describe experiments')

    args = parser.parse_args()
    args.dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train_loader, test_loader, split_train_loader, split_test_loader = create_loaders(args,
            root='../data', split=args.split)
    if os.path.isfile("../settings.json"):
        with open('../settings.json') as f:
            data = json.load(f)
        args.wandb_apikey = data.get("wandbapikey")

    if args.wandb:
        os.environ['WANDB_API_KEY'] = args.wandb_apikey
        wandb.init(project='NoBox-sweeps', name='AutoAttack-{}'.format(args.dataset))

    model, adv_models, l_test_classif_paths, model_type = data_and_model_setup(args)
    model.to(args.dev)
    model.eval()

    if args.step_size < 0:
        step_size = args.epsilon / args.n_iter
    else:
        step_size = args.step_size / 255.0

    # using our method - Skip Gradient Method (SGM)
    if args.gamma < 1.0:
        if args.source_arch in ['res18', 'res34', 'res50', 'res101', 'res152',
                                'wide_resnet']:
            register_hook_for_resnet(model, arch=args.source_arch, gamma=args.gamma)
        elif args.source_arch in ['dense121', 'dens169', 'dense201']:
            register_hook_for_densenet(model, arch=args.source_arch, gamma=args.gamma)
        else:
            raise ValueError('Current code only supports resnet/densenet. '
                             'You can extend this code to other architectures.')

    if args.momentum > 0.0:
        print('using PGD attack with momentum = {}'.format(args.momentum))
        attacker = MomentumIterativeAttack(predict=model,
                                           loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                                           eps=args.epsilon,
                                           nb_iter=args.n_iter,
                                           eps_iter=step_size,
                                           decay_factor=args.momentum,
                                           clip_min=0.0, clip_max=1.0,
                                           targeted=False)
    else:
        print('using Linf PGD attack')
        attacker = LinfPGDAttack(predict=model,
                                 loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                                 eps=args.epsilon, nb_iter=args.n_iter,
                                 eps_iter=step_size, rand_init=False,
                                 clip_min=0.0, clip_max=1.0, targeted=False)


    eval_helpers = [model, model_type, adv_models, l_test_classif_paths, test_loader]
    total_fool_rate = eval(args, attacker, "SGM-Attack", eval_helpers)