def main():
    """Adversarially train and/or evaluate a CIFAR-10 classifier.

    Configuration comes entirely from command-line args (``get_args``). The
    function sets up logging and checkpoint dirs, builds the data pipeline,
    the model, optional pretrained "adversary" models used to generate
    attacks, an SGD optimizer with a selectable LR schedule, then runs the
    train/test (and optional validation) loop, saving periodic and best
    checkpoints under ``args.fname``.
    """
    args = get_args()

    # Output directory for logs and checkpoints.
    if not os.path.exists(args.fname):
        os.makedirs(args.fname)

    # Log to both a file (eval.log or output.log) and stdout.
    logger = logging.getLogger(__name__)
    logging.basicConfig(
        format='[%(asctime)s] - %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S',
        level=logging.DEBUG,
        handlers=[
            logging.FileHandler(
                os.path.join(args.fname,
                             'eval.log' if args.eval else 'output.log')),
            logging.StreamHandler()
        ])

    logger.info(args)

    # Seed numpy and torch (CPU + current CUDA device) for reproducibility.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # Standard CIFAR-10 augmentation: random crop (after padding) + flip.
    transforms = [Crop(32, 32), FlipLR()]
    # transforms = [Crop(32, 32)]
    if args.cutout:
        transforms.append(Cutout(args.cutout_len, args.cutout_len))
    if args.val:
        # NOTE(review): bare except hides failures other than a missing file
        # (e.g. a corrupt pickle) — FileNotFoundError would be more precise.
        try:
            dataset = torch.load("cifar10_validation_split.pth")
        except:
            print(
                "Couldn't find a dataset with a validation split, did you run "
                "generate_validation.py?")
            return
        val_set = list(
            zip(transpose(dataset['val']['data'] / 255.),
                dataset['val']['labels']))
        val_batches = Batches(val_set,
                              args.batch_size,
                              shuffle=False,
                              num_workers=2)
    else:
        dataset = cifar10(args.data_dir)
    # Train images are padded by 4 so the random 32x32 crop gives translation
    # augmentation; pixel values rescaled to [0, 1].
    train_set = list(
        zip(transpose(pad(dataset['train']['data'], 4) / 255.),
            dataset['train']['labels']))
    train_set_x = Transform(train_set, transforms)
    train_batches = Batches(train_set_x,
                            args.batch_size,
                            shuffle=True,
                            set_random_choices=True,
                            num_workers=2)

    test_set = list(
        zip(transpose(dataset['test']['data'] / 255.),
            dataset['test']['labels']))
    test_batches = Batches(test_set,
                           args.batch_size,
                           shuffle=False,
                           num_workers=2)

    # Attack budgets are specified on the 0-255 pixel scale; convert to the
    # [0, 1] scale the data pipeline uses.
    trn_epsilon = (args.trn_epsilon / 255.)
    trn_pgd_alpha = (args.trn_pgd_alpha / 255.)
    tst_epsilon = (args.tst_epsilon / 255.)
    tst_pgd_alpha = (args.tst_pgd_alpha / 255.)

    # Model under training.
    if args.model == 'PreActResNet18':
        model = PreActResNet18()
    elif args.model == 'WideResNet':
        model = WideResNet(34,
                           10,
                           widen_factor=args.width_factor,
                           dropRate=0.0)
    elif args.model == 'DenseNet121':
        model = DenseNet121()
    elif args.model == 'ResNet18':
        model = ResNet18()
    else:
        raise ValueError("Unknown model")

    ### temp testing ###
    model = model.cuda()
    # model = nn.DataParallel(model).cuda()
    model.train()

    ##################################
    # load pretrained model if needed
    # Optional frozen "adversary" model used to craft training-time attacks
    # (transfer attacks). 'None' (the string) disables it.
    if args.trn_adv_models != 'None':
        if args.trn_adv_arch == 'PreActResNet18':
            trn_adv_model = PreActResNet18()
        elif args.trn_adv_arch == 'WideResNet':
            trn_adv_model = WideResNet(34,
                                       10,
                                       widen_factor=args.width_factor,
                                       dropRate=0.0)
        elif args.trn_adv_arch == 'DenseNet121':
            trn_adv_model = DenseNet121()
        elif args.trn_adv_arch == 'ResNet18':
            trn_adv_model = ResNet18()
        trn_adv_model = nn.DataParallel(trn_adv_model).cuda()
        trn_adv_model.load_state_dict(
            torch.load(
                os.path.join('./adv_models', args.trn_adv_models,
                             'model_best.pth'))['state_dict'])
        logger.info(f'loaded adv_model: {args.trn_adv_models}')
    else:
        trn_adv_model = None

    # Same idea for test-time attacks.
    # NOTE(review): unlike the train adv model, this one is loaded without
    # DataParallel and the checkpoint is read without ['state_dict'] — confirm
    # the two checkpoint formats really differ.
    if args.tst_adv_models != 'None':
        if args.tst_adv_arch == 'PreActResNet18':
            tst_adv_model = PreActResNet18()
        elif args.tst_adv_arch == 'WideResNet':
            tst_adv_model = WideResNet(34,
                                       10,
                                       widen_factor=args.width_factor,
                                       dropRate=0.0)
        elif args.tst_adv_arch == 'DenseNet121':
            tst_adv_model = DenseNet121()
        elif args.tst_adv_arch == 'ResNet18':
            tst_adv_model = ResNet18()
        ### temp testing ###
        tst_adv_model = tst_adv_model.cuda()
        tst_adv_model.load_state_dict(
            torch.load(
                os.path.join('./adv_models', args.tst_adv_models,
                             'model_best.pth')))
        # tst_adv_model = nn.DataParallel(tst_adv_model).cuda()
        # tst_adv_model.load_state_dict(torch.load(os.path.join('./adv_models',args.tst_adv_models, 'model_best.pth'))['state_dict'])
        logger.info(f'loaded adv_model: {args.tst_adv_models}')
    else:
        tst_adv_model = None
    ##################################

    # Optional manual L2 regularization: apply weight decay only to non-BN,
    # non-bias parameters.
    if args.l2:
        decay, no_decay = [], []
        for name, param in model.named_parameters():
            if 'bn' not in name and 'bias' not in name:
                decay.append(param)
            else:
                no_decay.append(param)
        params = [{
            'params': decay,
            'weight_decay': args.l2
        }, {
            'params': no_decay,
            'weight_decay': 0
        }]
    else:
        params = model.parameters()

    # NOTE(review): weight_decay=5e-4 here is the optimizer-wide default; when
    # args.l2 is set the per-group values above override it, but confirm the
    # intended interaction — the two settings look redundant.
    opt = torch.optim.SGD(params,
                          lr=args.lr_max,
                          momentum=0.9,
                          weight_decay=5e-4)

    criterion = nn.CrossEntropyLoss()

    # Persistent perturbation buffer for 'free' training / FGSM with
    # 'previous' init.
    # NOTE(review): in the training loop below, the 'free' case has no
    # matching branch in the attack dispatch, so this delta is never updated
    # there — presumably consumed by code not shown; confirm.
    if args.trn_attack == 'free':
        delta = torch.zeros(args.batch_size, 3, 32, 32).cuda()
        delta.requires_grad = True
    elif args.trn_attack == 'fgsm' and args.trn_fgsm_init == 'previous':
        delta = torch.zeros(args.batch_size, 3, 32, 32).cuda()
        delta.requires_grad = True

    # 'Free' adversarial training replays each minibatch trn_attack_iters
    # times, so fewer outer epochs are needed for the same number of updates.
    if args.trn_attack == 'free':
        epochs = int(math.ceil(args.epochs / args.trn_attack_iters))
    else:
        epochs = args.epochs

    # Learning-rate schedule, evaluated at fractional epochs.
    if args.lr_schedule == 'superconverge':
        lr_schedule = lambda t: np.interp([t], [
            0, args.epochs * 2 // 5, args.epochs
        ], [0, args.lr_max, 0])[0]
    elif args.lr_schedule == 'piecewise':

        def lr_schedule(t):
            if t / args.epochs < 0.5:
                return args.lr_max
            elif t / args.epochs < 0.75:
                return args.lr_max / 10.
            else:
                return args.lr_max / 100.
    elif args.lr_schedule == 'linear':
        lr_schedule = lambda t: np.interp([t], [
            0, args.epochs // 3, args.epochs * 2 // 3, args.epochs
        ], [args.lr_max, args.lr_max, args.lr_max / 10, args.lr_max / 100])[0]
    elif args.lr_schedule == 'onedrop':

        def lr_schedule(t):
            if t < args.lr_drop_epoch:
                return args.lr_max
            else:
                return args.lr_one_drop
    elif args.lr_schedule == 'multipledecay':

        def lr_schedule(t):
            return args.lr_max - (t //
                                  (args.epochs // 10)) * (args.lr_max / 10)
    elif args.lr_schedule == 'cosine':

        def lr_schedule(t):
            return args.lr_max * 0.5 * (1 + np.cos(t / args.epochs * np.pi))
    # NOTE(review): an unrecognized args.lr_schedule leaves lr_schedule
    # undefined and fails later with NameError inside the train loop.

    best_test_robust_acc = 0
    best_val_robust_acc = 0
    if args.resume:
        ### temp testing ###
        # Resume from the best checkpoint; args.resume doubles as the epoch
        # index to restart from.
        model.load_state_dict(
            torch.load(os.path.join(args.fname, 'model_best.pth')))
        start_epoch = args.resume
        # model.load_state_dict(torch.load(os.path.join(args.fname, f'model_{start_epoch-1}.pth')))
        # opt.load_state_dict(torch.load(os.path.join(args.fname, f'opt_{start_epoch-1}.pth')))
        # logger.info(f'Resuming at epoch {start_epoch}')

        # best_test_robust_acc = torch.load(os.path.join(args.fname, f'model_best.pth'))['test_robust_acc']
        if args.val:
            best_val_robust_acc = torch.load(
                os.path.join(args.fname, f'model_val.pth'))['val_robust_acc']
    else:
        start_epoch = 0

    if args.eval:
        if not args.resume:
            logger.info(
                "No model loaded to evaluate, specify with --resume FNAME")
            return
        logger.info("[Evaluation mode]")

    logger.info(
        'Epoch \t Train Time \t Test Time \t LR \t \t Train Loss \t Train Acc \t Train Robust Loss \t Train Robust Acc \t Test Loss \t Test Acc \t Test Robust Loss \t Test Robust Acc'
    )
    for epoch in range(start_epoch, epochs):
        model.train()
        start_time = time.time()
        train_loss = 0
        train_acc = 0
        train_robust_loss = 0
        train_robust_acc = 0
        train_n = 0
        for i, batch in enumerate(train_batches):
            # In eval mode skip training entirely (but keep the epoch loop so
            # the test pass below still runs).
            if args.eval:
                break
            X, y = batch['input'], batch['target']
            if args.mixup:
                X, y_a, y_b, lam = mixup_data(X, y, args.mixup_alpha)
                X, y_a, y_b = map(Variable, (X, y_a, y_b))
            # Per-iteration LR from the schedule at the fractional epoch.
            # NOTE(review): only param_groups[0] is updated — when args.l2
            # splits parameters into two groups, the no-decay group keeps its
            # initial lr_max forever; confirm this is intended.
            lr = lr_schedule(epoch + (i + 1) / len(train_batches))
            opt.param_groups[0].update(lr=lr)

            # Craft the training-time perturbation delta for this batch.
            if args.trn_attack == 'pgd':
                # Random initialization
                if args.mixup:
                    delta = attack_pgd(model,
                                       X,
                                       y,
                                       trn_epsilon,
                                       trn_pgd_alpha,
                                       args.trn_attack_iters,
                                       args.trn_restarts,
                                       args.trn_norm,
                                       mixup=True,
                                       y_a=y_a,
                                       y_b=y_b,
                                       lam=lam,
                                       adv_models=trn_adv_model)
                else:
                    delta = attack_pgd(model,
                                       X,
                                       y,
                                       trn_epsilon,
                                       trn_pgd_alpha,
                                       args.trn_attack_iters,
                                       args.trn_restarts,
                                       args.trn_norm,
                                       adv_models=trn_adv_model)
                delta = delta.detach()
            elif args.trn_attack == 'fgsm':
                # FGSM is PGD with a single step and a single restart.
                delta = attack_pgd(model,
                                   X,
                                   y,
                                   trn_epsilon,
                                   args.trn_fgsm_alpha * trn_epsilon,
                                   1,
                                   1,
                                   args.trn_norm,
                                   adv_models=trn_adv_model,
                                   rand_init=args.trn_fgsm_init)
                delta = delta.detach()
            # Standard training
            elif args.trn_attack == 'none':
                delta = torch.zeros_like(X)
            # The Momentum Iterative Attack
            elif args.trn_attack == 'tmim':
                if trn_adv_model is None:
                    adversary = MomentumIterativeAttack(
                        model,
                        nb_iter=args.trn_attack_iters,
                        eps=trn_epsilon,
                        loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                        eps_iter=trn_pgd_alpha,
                        clip_min=0,
                        clip_max=1,
                        targeted=False)
                else:
                    # The adversary model expects normalized inputs, so wrap
                    # it with a normalization layer for the raw-pixel attack.
                    # NOTE(review): this rebinds trn_adv_model every batch,
                    # nesting a new normalization wrapper each time — confirm.
                    trn_adv_model = nn.Sequential(
                        NormalizeByChannelMeanStd(CIFAR10_MEAN, CIFAR10_STD),
                        trn_adv_model)

                    adversary = MomentumIterativeAttack(
                        trn_adv_model,
                        nb_iter=args.trn_attack_iters,
                        eps=trn_epsilon,
                        loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                        eps_iter=trn_pgd_alpha,
                        clip_min=0,
                        clip_max=1,
                        targeted=False)
                data_adv = adversary.perturb(X, y)
                delta = data_adv - X
                delta = delta.detach()

            # Robust forward pass on the perturbed batch; delta is sliced to
            # X.size(0) because the preallocated buffer may be larger than the
            # final (short) batch.
            robust_output = model(
                normalize(
                    torch.clamp(X + delta[:X.size(0)],
                                min=lower_limit,
                                max=upper_limit)))
            if args.mixup:
                robust_loss = mixup_criterion(criterion, robust_output, y_a,
                                              y_b, lam)
            else:
                robust_loss = criterion(robust_output, y)

            # Optional L1 penalty on non-BN, non-bias weights.
            if args.l1:
                for name, param in model.named_parameters():
                    if 'bn' not in name and 'bias' not in name:
                        robust_loss += args.l1 * param.abs().sum()

            opt.zero_grad()
            robust_loss.backward()
            opt.step()

            # Clean forward pass for bookkeeping only (no gradient step).
            output = model(normalize(X))
            if args.mixup:
                loss = mixup_criterion(criterion, output, y_a, y_b, lam)
            else:
                loss = criterion(output, y)

            train_robust_loss += robust_loss.item() * y.size(0)
            train_robust_acc += (robust_output.max(1)[1] == y).sum().item()
            train_loss += loss.item() * y.size(0)
            train_acc += (output.max(1)[1] == y).sum().item()
            train_n += y.size(0)

        train_time = time.time()

        # Test pass: clean and robust accuracy under the test-time attack.
        model.eval()
        test_loss = 0
        test_acc = 0
        test_robust_loss = 0
        test_robust_acc = 0
        test_n = 0
        for i, batch in enumerate(test_batches):
            X, y = batch['input'], batch['target']

            # Random initialization
            if args.tst_attack == 'none':
                delta = torch.zeros_like(X)
            elif args.tst_attack == 'pgd':
                delta = attack_pgd(model,
                                   X,
                                   y,
                                   tst_epsilon,
                                   tst_pgd_alpha,
                                   args.tst_attack_iters,
                                   args.tst_restarts,
                                   args.tst_norm,
                                   adv_models=tst_adv_model,
                                   rand_init=args.tst_fgsm_init)
            elif args.tst_attack == 'fgsm':
                delta = attack_pgd(model,
                                   X,
                                   y,
                                   tst_epsilon,
                                   tst_epsilon,
                                   1,
                                   1,
                                   args.tst_norm,
                                   rand_init=args.tst_fgsm_init,
                                   adv_models=tst_adv_model)
            # The Momentum Iterative Attack
            elif args.tst_attack == 'tmim':
                if tst_adv_model is None:
                    adversary = MomentumIterativeAttack(
                        model,
                        nb_iter=args.tst_attack_iters,
                        eps=tst_epsilon,
                        loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                        eps_iter=tst_pgd_alpha,
                        clip_min=0,
                        clip_max=1,
                        targeted=False)
                else:
                    # NOTE(review): uses lowercase cifar10_mean/std and
                    # `device`, unlike the train branch's CIFAR10_MEAN/STD and
                    # .cuda() — both names must exist at module level; confirm.
                    tmp_model = nn.Sequential(
                        NormalizeByChannelMeanStd(cifar10_mean, cifar10_std),
                        tst_adv_model).to(device)

                    adversary = MomentumIterativeAttack(
                        tmp_model,
                        nb_iter=args.tst_attack_iters,
                        eps=tst_epsilon,
                        loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                        eps_iter=tst_pgd_alpha,
                        clip_min=0,
                        clip_max=1,
                        targeted=False)
                data_adv = adversary.perturb(X, y)
                delta = data_adv - X
            # elif args.tst_attack == 'pgd':
            #     if tst_adv_model is None:
            #         tmp_model = nn.Sequential(NormalizeByChannelMeanStd(cifar10_mean, cifar10_std), model).to(device)

            #         adversary = PGDAttack(tmp_model, nb_iter=args.tst_attack_iters,
            #                         eps = tst_epsilon,
            #                         loss_fn=nn.CrossEntropyLoss(reduction="sum"),
            #                         eps_iter=tst_pgd_alpha, clip_min = 0, clip_max = 1, targeted=False)
            #     else:
            #         tmp_model = nn.Sequential(NormalizeByChannelMeanStd(cifar10_mean, cifar10_std), tst_adv_model).to(device)

            #         adversary = PGDAttack(tmp_model, nb_iter=args.tst_attack_iters,
            #                         eps = tst_epsilon,
            #                         loss_fn=nn.CrossEntropyLoss(reduction="sum"),
            #                         eps_iter=tst_pgd_alpha, clip_min = 0, clip_max = 1, targeted=False)
            #     data_adv = adversary.perturb(X, y)
            #     delta = data_adv - X

            delta = delta.detach()

            robust_output = model(
                normalize(
                    torch.clamp(X + delta[:X.size(0)],
                                min=lower_limit,
                                max=upper_limit)))
            robust_loss = criterion(robust_output, y)

            output = model(normalize(X))
            loss = criterion(output, y)

            test_robust_loss += robust_loss.item() * y.size(0)
            test_robust_acc += (robust_output.max(1)[1] == y).sum().item()
            test_loss += loss.item() * y.size(0)
            test_acc += (output.max(1)[1] == y).sum().item()
            test_n += y.size(0)

        test_time = time.time()

        # Optional validation pass (only pgd/fgsm/none attacks supported here).
        if args.val:
            val_loss = 0
            val_acc = 0
            val_robust_loss = 0
            val_robust_acc = 0
            val_n = 0
            for i, batch in enumerate(val_batches):
                X, y = batch['input'], batch['target']

                # Random initialization
                if args.tst_attack == 'none':
                    delta = torch.zeros_like(X)
                elif args.tst_attack == 'pgd':
                    delta = attack_pgd(model,
                                       X,
                                       y,
                                       tst_epsilon,
                                       tst_pgd_alpha,
                                       args.tst_attack_iters,
                                       args.tst_restarts,
                                       args.tst_norm,
                                       early_stop=args.eval)
                elif args.tst_attack == 'fgsm':
                    delta = attack_pgd(model,
                                       X,
                                       y,
                                       tst_epsilon,
                                       tst_epsilon,
                                       1,
                                       1,
                                       args.tst_norm,
                                       early_stop=args.eval,
                                       rand_init=args.tst_fgsm_init)

                delta = delta.detach()

                robust_output = model(
                    normalize(
                        torch.clamp(X + delta[:X.size(0)],
                                    min=lower_limit,
                                    max=upper_limit)))
                robust_loss = criterion(robust_output, y)

                output = model(normalize(X))
                loss = criterion(output, y)

                val_robust_loss += robust_loss.item() * y.size(0)
                val_robust_acc += (robust_output.max(1)[1] == y).sum().item()
                val_loss += loss.item() * y.size(0)
                val_acc += (output.max(1)[1] == y).sum().item()
                val_n += y.size(0)

        if not args.eval:
            # Columns match the header logged before the epoch loop.
            logger.info(
                '%d \t %.1f \t \t %.1f \t \t %.4f \t %.4f \t %.4f \t %.4f \t \t %.4f \t \t %.4f \t %.4f \t %.4f \t \t %.4f',
                epoch, train_time - start_time, test_time - train_time, lr,
                train_loss / train_n, train_acc / train_n,
                train_robust_loss / train_n, train_robust_acc / train_n,
                test_loss / test_n, test_acc / test_n,
                test_robust_loss / test_n, test_robust_acc / test_n)

            if args.val:
                logger.info('validation %.4f \t %.4f \t %.4f \t %.4f',
                            val_loss / val_n, val_acc / val_n,
                            val_robust_loss / val_n, val_robust_acc / val_n)

                # Track the best model by validation robust accuracy.
                if val_robust_acc / val_n > best_val_robust_acc:
                    torch.save(
                        {
                            'state_dict': model.state_dict(),
                            'test_robust_acc': test_robust_acc / test_n,
                            'test_robust_loss': test_robust_loss / test_n,
                            'test_loss': test_loss / test_n,
                            'test_acc': test_acc / test_n,
                            'val_robust_acc': val_robust_acc / val_n,
                            'val_robust_loss': val_robust_loss / val_n,
                            'val_loss': val_loss / val_n,
                            'val_acc': val_acc / val_n,
                        }, os.path.join(args.fname, f'model_val.pth'))
                    best_val_robust_acc = val_robust_acc / val_n

            # save checkpoint
            if (epoch + 1) % args.chkpt_iters == 0 or epoch + 1 == epochs:
                torch.save(model.state_dict(),
                           os.path.join(args.fname, f'model_{epoch}.pth'))
                torch.save(opt.state_dict(),
                           os.path.join(args.fname, f'opt_{epoch}.pth'))

            # save best
            if test_robust_acc / test_n > best_test_robust_acc:
                torch.save(
                    {
                        'state_dict': model.state_dict(),
                        'test_robust_acc': test_robust_acc / test_n,
                        'test_robust_loss': test_robust_loss / test_n,
                        'test_loss': test_loss / test_n,
                        'test_acc': test_acc / test_n,
                    }, os.path.join(args.fname, f'model_best.pth'))
                best_test_robust_acc = test_robust_acc / test_n
        else:
            # Evaluation mode: log test metrics only (train columns are -1)
            # and stop after the first epoch.
            logger.info(
                '%d \t %.1f \t \t %.1f \t \t %.4f \t %.4f \t %.4f \t %.4f \t \t %.4f \t \t %.4f \t %.4f \t %.4f \t \t %.4f',
                epoch, train_time - start_time, test_time - train_time, -1, -1,
                -1, -1, -1, test_loss / test_n, test_acc / test_n,
                test_robust_loss / test_n, test_robust_acc / test_n)
            return
def MIM(model, X, y, num_iter=10):
    """Return an untargeted Momentum Iterative Method (MIM) perturbation.

    Args:
        model: classifier under attack (logits output).
        X: input batch; eps/clip values assume pixels scaled to [0, 1].
        y: ground-truth labels for the untargeted attack.
        num_iter: number of attack iterations (default 10).

    Returns:
        The adversarial perturbation ``x_adv - X`` (same shape as ``X``).
    """
    # Bug fix: nb_iter was hard-coded to 10, silently ignoring num_iter.
    adversary = MomentumIterativeAttack(
        model,
        loss_fn=nn.CrossEntropyLoss(reduction="sum"),
        eps=0.3,
        nb_iter=num_iter,
        decay_factor=1.0,
        eps_iter=0.003,
        clip_min=0.0,
        clip_max=1.0)
    return adversary.perturb(X, y) - X
adversary = L2PGDAttack(
    model, loss_fn=nn.CrossEntropyLoss(reduction="sum"), eps=0.15,
    nb_iter=40, eps_iter=0.01, rand_init=True, clip_min=0.0, clip_max=1.0,
    targeted=False)
'''

# LinfPGDAttack
'''
adversary = LinfPGDAttack(
    model, loss_fn=nn.CrossEntropyLoss(reduction="sum"), eps=0.15,
    nb_iter=40, eps_iter=0.01, rand_init=True, clip_min=0.0, clip_max=1.0,
    targeted=False)
'''

# generate untargeted adversarial samples
adv_untargeted = adversary.perturb(cln_data, true_label)

# generate targeted adversarial samples
target = torch.ones_like(true_label) * 3
adversary.targeted = True
adv_targeted = adversary.perturb(cln_data, target)

# Test the model on these samples
pred_cln = predict_from_logits(model(cln_data))
pred_untargeted_adv = predict_from_logits(model(adv_untargeted))
pred_targeted_adv = predict_from_logits(model(adv_targeted))

# Show the results
# Model performacne on clean, untargeted and targeted images
# ----------------------------------------------------------
plt.figure(figsize=(10, 8))
# Example no. 4 (scraper residue: "Esempio n. 4 / 0" — not executable code)
def model_test(model, data_loader, output_file_path, attack='mia', eps=8/255, nb_iter=3):
    """Evaluate clean and targeted-adversarial accuracy of `model`.

    For each batch, crafts targeted adversarial examples (target label is
    ``(y + 5) % 10``) with the chosen attack and accumulates clean accuracy,
    adversarial accuracy, and mean L2/Linf perturbation distances. Results
    are appended to ``output_file_path``.

    Args:
        model: classifier to evaluate (switched to eval mode here).
        data_loader: iterable yielding ``(data, target)`` CUDA-compatible batches.
        output_file_path: text file the summary line is appended to.
        attack: one of 'cw' (Carlini-Wagner L2, first 5 batches only),
            'mia' (Momentum Iterative Attack), or 'pgd' (Linf PGD).
        eps: attack budget (ignored by 'cw').
        nb_iter: iteration count for 'pgd'.

    Returns:
        Tuple ``(acc, acc_adv)`` of clean and adversarial accuracy in percent.

    NOTE(review): relies on module-level globals ``min_v``, ``max_v`` (clip
    range) and ``args.model_name`` — confirm they are defined at call time.
    """
    model.eval()

    test_loss, adv_loss, correct, correct_adv, nb_data, adv_l2dist, adv_linfdist = \
    0, 0, 0, 0, 0, 0.0, 0.0

    # Robustness fix: defined up front so the summary below does not raise
    # NameError when data_loader is empty.
    acc, acc_adv, time_consume = 0.0, 0.0, 0.0

    start_time = time.time()
    for batch_idx, (data, target) in enumerate(data_loader):
        print('i:', batch_idx)

        indx_target = target.clone()
        data_length = data.shape[0]
        nb_data += data_length

        data, target = data.cuda(), target.cuda()

        # Clean predictions (no gradients needed).
        with torch.no_grad():
            output = model(data)

        if attack == 'cw':
            # CW is expensive — only attack the first 5 batches.
            if batch_idx >= 5:
                break
            adversary = CarliniWagnerL2Attack(predict=model, num_classes=10, targeted=True,
                clip_min=min_v, clip_max=max_v, max_iterations=50)
        elif attack == 'mia':
            adversary = MomentumIterativeAttack(predict=model, targeted=True, eps=eps, nb_iter=40, eps_iter=0.01*(max_v-min_v),
                clip_min=min_v, clip_max=max_v )
        elif attack == 'pgd':
            adversary = LinfPGDAttack(predict=model, targeted=True, eps=eps, nb_iter=nb_iter, eps_iter=eps*1.25/nb_iter,
                clip_min=min_v, clip_max=max_v )
        else:
            # Bug fix: `raise 'unimplemented error'` raises a TypeError in
            # Python 3 (exceptions must derive from BaseException).
            raise ValueError(f'unimplemented attack: {attack!r}')

        print('target:', type(target), target.shape, target[0:20])
        # Fixed targeted attack label: shift each class by 5 (mod 10).
        # (A dead loop that zeroed out pred[i, target[i]] for an argmax-based
        # target selection was removed; its result was never used.)
        target_adv = (target + 5) % 10
        print('target_adv:', type(target_adv), target_adv.shape, target_adv[0:20])
        data_adv = adversary.perturb(data, target_adv)

        print('data_adv max:', torch.max(data_adv))
        print('data_adv min:', torch.min(data_adv))
        print('linf:', torch.max(torch.abs(data_adv-data)) )

        # Accumulate perturbation sizes for the mean-distance report.
        adv_l2dist += torch.norm((data-data_adv).view(data.size(0), -1), p=2, dim=-1).sum().item()
        adv_linfdist += torch.max((data-data_adv).view(data.size(0), -1).abs(), dim=-1)[0].sum().item()

        with torch.no_grad():
            output_adv = model(data_adv)

        # Adversarial accuracy is measured against the ORIGINAL labels.
        pred_adv = output_adv.data.max(1)[1]
        correct_adv += pred_adv.cpu().eq(indx_target).sum()

        pred = output.data.max(1)[1]  # index of the max log-probability
        correct += pred.cpu().eq(indx_target).sum()

        time_consume = time.time() - start_time
        print('time_consume:', time_consume)

        acc = float(100. * correct) / nb_data
        print('\tTest set: Accuracy: {}/{}({:.2f}%)'.format(
            correct, nb_data, acc))

        acc_adv = float(100. * correct_adv) / nb_data
        print('\tAdv set: Accuracy : {}/{}({:.2f}%)'.format(
            correct_adv, nb_data, acc_adv
        ))

    # Mean per-sample distances. NOTE(review): ZeroDivisionError if the
    # loader is empty — kept to preserve original failure mode for nb_data=0.
    adv_l2dist /= nb_data
    adv_linfdist /= nb_data
    print('\tAdv dist: L2: {:.8f} , Linf: {:.8f}'.format(adv_l2dist, adv_linfdist))

    with open(output_file_path, "a+") as output_file:
        output_file.write(args.model_name + '\n')
        info_string = 'attack: %s:\n acc: %.2f, acc_adv: %.2f, adv_l2dist: %.2f, adv_linfdist: %.2f, time_consume: %.2f' % (
            attack, acc, acc_adv, adv_l2dist, adv_linfdist, time_consume)
        output_file.write(info_string)

    return acc, acc_adv