def main():
    args = get_args()

    if not os.path.exists(args.fname):
        os.makedirs(args.fname)

    logger = logging.getLogger(__name__)
    logging.basicConfig(
        format='[%(asctime)s] - %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S',
        level=logging.DEBUG,
        handlers=[
            logging.FileHandler(
                os.path.join(args.fname,
                             'eval.log' if args.eval else 'output.log')),
            logging.StreamHandler()
        ])

    logger.info(args)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    transforms = [Crop(32, 32), FlipLR()]
    # transforms = [Crop(32, 32)]
    if args.cutout:
        transforms.append(Cutout(args.cutout_len, args.cutout_len))
    if args.val:
        try:
            dataset = torch.load("cifar10_validation_split.pth")
        except FileNotFoundError:
            print(
                "Couldn't find a dataset with a validation split, did you run "
                "generate_validation.py?")
            return
        val_set = list(
            zip(transpose(dataset['val']['data'] / 255.),
                dataset['val']['labels']))
        val_batches = Batches(val_set,
                              args.batch_size,
                              shuffle=False,
                              num_workers=2)
    else:
        dataset = cifar10(args.data_dir)
    train_set = list(
        zip(transpose(pad(dataset['train']['data'], 4) / 255.),
            dataset['train']['labels']))
    train_set_x = Transform(train_set, transforms)
    train_batches = Batches(train_set_x,
                            args.batch_size,
                            shuffle=True,
                            set_random_choices=True,
                            num_workers=2)

    test_set = list(
        zip(transpose(dataset['test']['data'] / 255.),
            dataset['test']['labels']))
    test_batches = Batches(test_set,
                           args.batch_size,
                           shuffle=False,
                           num_workers=2)

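    # Perturbation budgets and step sizes are given on the command line in pixel
    # units (0-255); convert them to the [0, 1] scale of the input data.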
    trn_epsilon = (args.trn_epsilon / 255.)
    trn_pgd_alpha = (args.trn_pgd_alpha / 255.)
    tst_epsilon = (args.tst_epsilon / 255.)
    tst_pgd_alpha = (args.tst_pgd_alpha / 255.)

    if args.model == 'PreActResNet18':
        model = PreActResNet18()
    elif args.model == 'WideResNet':
        model = WideResNet(34,
                           10,
                           widen_factor=args.width_factor,
                           dropRate=0.0)
    elif args.model == 'DenseNet121':
        model = DenseNet121()
    elif args.model == 'ResNet18':
        model = ResNet18()
    else:
        raise ValueError("Unknown model")

    ### temp testing ###
    model = model.cuda()
    # model = nn.DataParallel(model).cuda()
    model.train()

    ##################################
    # load pretrained model if needed
    if args.trn_adv_models != 'None':
        if args.trn_adv_arch == 'PreActResNet18':
            trn_adv_model = PreActResNet18()
        elif args.trn_adv_arch == 'WideResNet':
            trn_adv_model = WideResNet(34,
                                       10,
                                       widen_factor=args.width_factor,
                                       dropRate=0.0)
        elif args.trn_adv_arch == 'DenseNet121':
            trn_adv_model = DenseNet121()
        elif args.trn_adv_arch == 'ResNet18':
            trn_adv_model = ResNet18()
        else:
            raise ValueError("Unknown trn_adv_arch")
        trn_adv_model = nn.DataParallel(trn_adv_model).cuda()
        trn_adv_model.load_state_dict(
            torch.load(
                os.path.join('./adv_models', args.trn_adv_models,
                             'model_best.pth'))['state_dict'])
        logger.info(f'loaded adv_model: {args.trn_adv_models}')
    else:
        trn_adv_model = None

    if args.tst_adv_models != 'None':
        if args.tst_adv_arch == 'PreActResNet18':
            tst_adv_model = PreActResNet18()
        elif args.tst_adv_arch == 'WideResNet':
            tst_adv_model = WideResNet(34,
                                       10,
                                       widen_factor=args.width_factor,
                                       dropRate=0.0)
        elif args.tst_adv_arch == 'DenseNet121':
            tst_adv_model = DenseNet121()
        elif args.tst_adv_arch == 'ResNet18':
            tst_adv_model = ResNet18()
        else:
            raise ValueError("Unknown tst_adv_arch")
        ### temp testing ###
        tst_adv_model = tst_adv_model.cuda()
        tst_adv_model.load_state_dict(
            torch.load(
                os.path.join('./adv_models', args.tst_adv_models,
                             'model_best.pth')))
        # tst_adv_model = nn.DataParallel(tst_adv_model).cuda()
        # tst_adv_model.load_state_dict(torch.load(os.path.join('./adv_models',args.tst_adv_models, 'model_best.pth'))['state_dict'])
        logger.info(f'loaded adv_model: {args.tst_adv_models}')
    else:
        tst_adv_model = None
    ##################################

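    # When an explicit L2 coefficient is given, build separate parameter groups so
    # that weight decay is skipped for BatchNorm and bias parameters.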
    if args.l2:
        decay, no_decay = [], []
        for name, param in model.named_parameters():
            if 'bn' not in name and 'bias' not in name:
                decay.append(param)
            else:
                no_decay.append(param)
        params = [{
            'params': decay,
            'weight_decay': args.l2
        }, {
            'params': no_decay,
            'weight_decay': 0
        }]
    else:
        params = model.parameters()

    opt = torch.optim.SGD(params,
                          lr=args.lr_max,
                          momentum=0.9,
                          weight_decay=5e-4)

    criterion = nn.CrossEntropyLoss()

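    # For 'free' adversarial training and FGSM with the 'previous' initialization,
    # a persistent perturbation tensor is allocated up front.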
    if args.trn_attack == 'free':
        delta = torch.zeros(args.batch_size, 3, 32, 32).cuda()
        delta.requires_grad = True
    elif args.trn_attack == 'fgsm' and args.trn_fgsm_init == 'previous':
        delta = torch.zeros(args.batch_size, 3, 32, 32).cuda()
        delta.requires_grad = True

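    # 'Free' training replays each minibatch trn_attack_iters times, so the number
    # of passes over the data is reduced to keep the total cost comparable.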
    if args.trn_attack == 'free':
        epochs = int(math.ceil(args.epochs / args.trn_attack_iters))
    else:
        epochs = args.epochs

    if args.lr_schedule == 'superconverge':
        lr_schedule = lambda t: np.interp([t], [
            0, args.epochs * 2 // 5, args.epochs
        ], [0, args.lr_max, 0])[0]
    elif args.lr_schedule == 'piecewise':

        def lr_schedule(t):
            if t / args.epochs < 0.5:
                return args.lr_max
            elif t / args.epochs < 0.75:
                return args.lr_max / 10.
            else:
                return args.lr_max / 100.
    elif args.lr_schedule == 'linear':
        lr_schedule = lambda t: np.interp([t], [
            0, args.epochs // 3, args.epochs * 2 // 3, args.epochs
        ], [args.lr_max, args.lr_max, args.lr_max / 10, args.lr_max / 100])[0]
    elif args.lr_schedule == 'onedrop':

        def lr_schedule(t):
            if t < args.lr_drop_epoch:
                return args.lr_max
            else:
                return args.lr_one_drop
    elif args.lr_schedule == 'multipledecay':

        def lr_schedule(t):
            return args.lr_max - (t //
                                  (args.epochs // 10)) * (args.lr_max / 10)
    elif args.lr_schedule == 'cosine':

        def lr_schedule(t):
            return args.lr_max * 0.5 * (1 + np.cos(t / args.epochs * np.pi))

    best_test_robust_acc = 0
    best_val_robust_acc = 0
    if args.resume:
        ### temp testing ###
        model.load_state_dict(
            torch.load(os.path.join(args.fname, 'model_best.pth')))
        start_epoch = args.resume
        # model.load_state_dict(torch.load(os.path.join(args.fname, f'model_{start_epoch-1}.pth')))
        # opt.load_state_dict(torch.load(os.path.join(args.fname, f'opt_{start_epoch-1}.pth')))
        # logger.info(f'Resuming at epoch {start_epoch}')

        # best_test_robust_acc = torch.load(os.path.join(args.fname, f'model_best.pth'))['test_robust_acc']
        if args.val:
            best_val_robust_acc = torch.load(
                os.path.join(args.fname, f'model_val.pth'))['val_robust_acc']
    else:
        start_epoch = 0

    if args.eval:
        if not args.resume:
            logger.info(
                "No model loaded to evaluate, specify with --resume FNAME")
            return
        logger.info("[Evaluation mode]")

    logger.info(
        'Epoch \t Train Time \t Test Time \t LR \t \t Train Loss \t Train Acc \t Train Robust Loss \t Train Robust Acc \t Test Loss \t Test Acc \t Test Robust Loss \t Test Robust Acc'
    )
    for epoch in range(start_epoch, epochs):
        model.train()
        start_time = time.time()
        train_loss = 0
        train_acc = 0
        train_robust_loss = 0
        train_robust_acc = 0
        train_n = 0
        for i, batch in enumerate(train_batches):
            if args.eval:
                break
            X, y = batch['input'], batch['target']
            if args.mixup:
                X, y_a, y_b, lam = mixup_data(X, y, args.mixup_alpha)
                X, y_a, y_b = map(Variable, (X, y_a, y_b))
            lr = lr_schedule(epoch + (i + 1) / len(train_batches))
            opt.param_groups[0].update(lr=lr)

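            # Craft the training perturbation with the configured attack
            # (optionally transferred from a pretrained trn_adv_model).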
            if args.trn_attack == 'pgd':
                # Random initialization
                if args.mixup:
                    delta = attack_pgd(model,
                                       X,
                                       y,
                                       trn_epsilon,
                                       trn_pgd_alpha,
                                       args.trn_attack_iters,
                                       args.trn_restarts,
                                       args.trn_norm,
                                       mixup=True,
                                       y_a=y_a,
                                       y_b=y_b,
                                       lam=lam,
                                       adv_models=trn_adv_model)
                else:
                    delta = attack_pgd(model,
                                       X,
                                       y,
                                       trn_epsilon,
                                       trn_pgd_alpha,
                                       args.trn_attack_iters,
                                       args.trn_restarts,
                                       args.trn_norm,
                                       adv_models=trn_adv_model)
                delta = delta.detach()
            elif args.trn_attack == 'fgsm':
                delta = attack_pgd(model,
                                   X,
                                   y,
                                   trn_epsilon,
                                   args.trn_fgsm_alpha * trn_epsilon,
                                   1,
                                   1,
                                   args.trn_norm,
                                   adv_models=trn_adv_model,
                                   rand_init=args.trn_fgsm_init)
                delta = delta.detach()
            # Standard training
            elif args.trn_attack == 'none':
                delta = torch.zeros_like(X)
            # The Momentum Iterative Attack
            elif args.trn_attack == 'tmim':
                if trn_adv_model is None:
                    adversary = MomentumIterativeAttack(
                        model,
                        nb_iter=args.trn_attack_iters,
                        eps=trn_epsilon,
                        loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                        eps_iter=trn_pgd_alpha,
                        clip_min=0,
                        clip_max=1,
                        targeted=False)
                else:
                    # Wrap the transfer model with input normalization in a local
                    # variable so trn_adv_model is not re-wrapped on every batch.
                    tmp_model = nn.Sequential(
                        NormalizeByChannelMeanStd(CIFAR10_MEAN, CIFAR10_STD),
                        trn_adv_model)

                    adversary = MomentumIterativeAttack(
                        tmp_model,
                        nb_iter=args.trn_attack_iters,
                        eps=trn_epsilon,
                        loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                        eps_iter=trn_pgd_alpha,
                        clip_min=0,
                        clip_max=1,
                        targeted=False)
                data_adv = adversary.perturb(X, y)
                delta = data_adv - X
                delta = delta.detach()

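            # The parameter update uses the robust (adversarial) loss; the clean
            # forward pass further below is only used for logging statistics.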
            robust_output = model(
                normalize(
                    torch.clamp(X + delta[:X.size(0)],
                                min=lower_limit,
                                max=upper_limit)))
            if args.mixup:
                robust_loss = mixup_criterion(criterion, robust_output, y_a,
                                              y_b, lam)
            else:
                robust_loss = criterion(robust_output, y)

            if args.l1:
                for name, param in model.named_parameters():
                    if 'bn' not in name and 'bias' not in name:
                        robust_loss += args.l1 * param.abs().sum()

            opt.zero_grad()
            robust_loss.backward()
            opt.step()

            output = model(normalize(X))
            if args.mixup:
                loss = mixup_criterion(criterion, output, y_a, y_b, lam)
            else:
                loss = criterion(output, y)

            train_robust_loss += robust_loss.item() * y.size(0)
            train_robust_acc += (robust_output.max(1)[1] == y).sum().item()
            train_loss += loss.item() * y.size(0)
            train_acc += (output.max(1)[1] == y).sum().item()
            train_n += y.size(0)

        train_time = time.time()

        model.eval()
        test_loss = 0
        test_acc = 0
        test_robust_loss = 0
        test_robust_acc = 0
        test_n = 0
        for i, batch in enumerate(test_batches):
            X, y = batch['input'], batch['target']

            # Craft the test-time perturbation with the configured attack
            # (optionally transferred from tst_adv_model).
            if args.tst_attack == 'none':
                delta = torch.zeros_like(X)
            elif args.tst_attack == 'pgd':
                delta = attack_pgd(model,
                                   X,
                                   y,
                                   tst_epsilon,
                                   tst_pgd_alpha,
                                   args.tst_attack_iters,
                                   args.tst_restarts,
                                   args.tst_norm,
                                   adv_models=tst_adv_model,
                                   rand_init=args.tst_fgsm_init)
            elif args.tst_attack == 'fgsm':
                delta = attack_pgd(model,
                                   X,
                                   y,
                                   tst_epsilon,
                                   tst_epsilon,
                                   1,
                                   1,
                                   args.tst_norm,
                                   rand_init=args.tst_fgsm_init,
                                   adv_models=tst_adv_model)
            # The Momentum Iterative Attack
            elif args.tst_attack == 'tmim':
                if tst_adv_model is None:
                    adversary = MomentumIterativeAttack(
                        model,
                        nb_iter=args.tst_attack_iters,
                        eps=tst_epsilon,
                        loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                        eps_iter=tst_pgd_alpha,
                        clip_min=0,
                        clip_max=1,
                        targeted=False)
                else:
                    tmp_model = nn.Sequential(
                        NormalizeByChannelMeanStd(cifar10_mean, cifar10_std),
                        tst_adv_model).to(device)

                    adversary = MomentumIterativeAttack(
                        tmp_model,
                        nb_iter=args.tst_attack_iters,
                        eps=tst_epsilon,
                        loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                        eps_iter=tst_pgd_alpha,
                        clip_min=0,
                        clip_max=1,
                        targeted=False)
                data_adv = adversary.perturb(X, y)
                delta = data_adv - X
            # elif args.tst_attack == 'pgd':
            #     if tst_adv_model is None:
            #         tmp_model = nn.Sequential(NormalizeByChannelMeanStd(cifar10_mean, cifar10_std), model).to(device)

            #         adversary = PGDAttack(tmp_model, nb_iter=args.tst_attack_iters,
            #                         eps = tst_epsilon,
            #                         loss_fn=nn.CrossEntropyLoss(reduction="sum"),
            #                         eps_iter=tst_pgd_alpha, clip_min = 0, clip_max = 1, targeted=False)
            #     else:
            #         tmp_model = nn.Sequential(NormalizeByChannelMeanStd(cifar10_mean, cifar10_std), tst_adv_model).to(device)

            #         adversary = PGDAttack(tmp_model, nb_iter=args.tst_attack_iters,
            #                         eps = tst_epsilon,
            #                         loss_fn=nn.CrossEntropyLoss(reduction="sum"),
            #                         eps_iter=tst_pgd_alpha, clip_min = 0, clip_max = 1, targeted=False)
            #     data_adv = adversary.perturb(X, y)
            #     delta = data_adv - X

            delta = delta.detach()

            robust_output = model(
                normalize(
                    torch.clamp(X + delta[:X.size(0)],
                                min=lower_limit,
                                max=upper_limit)))
            robust_loss = criterion(robust_output, y)

            output = model(normalize(X))
            loss = criterion(output, y)

            test_robust_loss += robust_loss.item() * y.size(0)
            test_robust_acc += (robust_output.max(1)[1] == y).sum().item()
            test_loss += loss.item() * y.size(0)
            test_acc += (output.max(1)[1] == y).sum().item()
            test_n += y.size(0)

        test_time = time.time()

        if args.val:
            val_loss = 0
            val_acc = 0
            val_robust_loss = 0
            val_robust_acc = 0
            val_n = 0
            for i, batch in enumerate(val_batches):
                X, y = batch['input'], batch['target']

                # Craft the validation-time perturbation with the configured attack.
                if args.tst_attack == 'none':
                    delta = torch.zeros_like(X)
                elif args.tst_attack == 'pgd':
                    delta = attack_pgd(model,
                                       X,
                                       y,
                                       tst_epsilon,
                                       tst_pgd_alpha,
                                       args.tst_attack_iters,
                                       args.tst_restarts,
                                       args.tst_norm,
                                       early_stop=args.eval)
                elif args.tst_attack == 'fgsm':
                    delta = attack_pgd(model,
                                       X,
                                       y,
                                       tst_epsilon,
                                       tst_epsilon,
                                       1,
                                       1,
                                       args.tst_norm,
                                       early_stop=args.eval,
                                       rand_init=args.tst_fgsm_init)

                delta = delta.detach()

                robust_output = model(
                    normalize(
                        torch.clamp(X + delta[:X.size(0)],
                                    min=lower_limit,
                                    max=upper_limit)))
                robust_loss = criterion(robust_output, y)

                output = model(normalize(X))
                loss = criterion(output, y)

                val_robust_loss += robust_loss.item() * y.size(0)
                val_robust_acc += (robust_output.max(1)[1] == y).sum().item()
                val_loss += loss.item() * y.size(0)
                val_acc += (output.max(1)[1] == y).sum().item()
                val_n += y.size(0)

        if not args.eval:
            logger.info(
                '%d \t %.1f \t \t %.1f \t \t %.4f \t %.4f \t %.4f \t %.4f \t \t %.4f \t \t %.4f \t %.4f \t %.4f \t \t %.4f',
                epoch, train_time - start_time, test_time - train_time, lr,
                train_loss / train_n, train_acc / train_n,
                train_robust_loss / train_n, train_robust_acc / train_n,
                test_loss / test_n, test_acc / test_n,
                test_robust_loss / test_n, test_robust_acc / test_n)

            if args.val:
                logger.info('validation %.4f \t %.4f \t %.4f \t %.4f',
                            val_loss / val_n, val_acc / val_n,
                            val_robust_loss / val_n, val_robust_acc / val_n)

                if val_robust_acc / val_n > best_val_robust_acc:
                    torch.save(
                        {
                            'state_dict': model.state_dict(),
                            'test_robust_acc': test_robust_acc / test_n,
                            'test_robust_loss': test_robust_loss / test_n,
                            'test_loss': test_loss / test_n,
                            'test_acc': test_acc / test_n,
                            'val_robust_acc': val_robust_acc / val_n,
                            'val_robust_loss': val_robust_loss / val_n,
                            'val_loss': val_loss / val_n,
                            'val_acc': val_acc / val_n,
                        }, os.path.join(args.fname, f'model_val.pth'))
                    best_val_robust_acc = val_robust_acc / val_n

            # save checkpoint
            if (epoch + 1) % args.chkpt_iters == 0 or epoch + 1 == epochs:
                torch.save(model.state_dict(),
                           os.path.join(args.fname, f'model_{epoch}.pth'))
                torch.save(opt.state_dict(),
                           os.path.join(args.fname, f'opt_{epoch}.pth'))

            # save best
            if test_robust_acc / test_n > best_test_robust_acc:
                torch.save(
                    {
                        'state_dict': model.state_dict(),
                        'test_robust_acc': test_robust_acc / test_n,
                        'test_robust_loss': test_robust_loss / test_n,
                        'test_loss': test_loss / test_n,
                        'test_acc': test_acc / test_n,
                    }, os.path.join(args.fname, f'model_best.pth'))
                best_test_robust_acc = test_robust_acc / test_n
        else:
            logger.info(
                '%d \t %.1f \t \t %.1f \t \t %.4f \t %.4f \t %.4f \t %.4f \t \t %.4f \t \t %.4f \t %.4f \t %.4f \t \t %.4f',
                epoch, train_time - start_time, test_time - train_time, -1, -1,
                -1, -1, -1, test_loss / test_n, test_acc / test_n,
                test_robust_loss / test_n, test_robust_acc / test_n)
            return
def main():
    args = get_args()
    if args.awp_gamma <= 0.0:
        args.awp_warmup = np.inf

    if not os.path.exists(args.fname):
        os.makedirs(args.fname)

    logger = logging.getLogger(__name__)
    logging.basicConfig(format='[%(asctime)s] - %(message)s',
                        datefmt='%Y/%m/%d %H:%M:%S',
                        level=logging.DEBUG,
                        handlers=[
                            logging.FileHandler(
                                os.path.join(args.fname, 'output.log')),
                            logging.StreamHandler()
                        ])

    logger.info(args)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    num_workers = 2
    train_dataset = datasets.CIFAR100(args.data_dir,
                                      train=True,
                                      transform=train_transform,
                                      download=True)
    test_dataset = datasets.CIFAR100(args.data_dir,
                                     train=False,
                                     transform=test_transform,
                                     download=True)
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=num_workers,
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        pin_memory=True,
        num_workers=2,
    )

    epsilon = (args.epsilon / 255.)
    pgd_alpha = (args.pgd_alpha / 255.)

    if args.model == 'PreActResNet18':
        model = PreActResNet18(num_classes=100)
        proxy = PreActResNet18(num_classes=100)
    elif args.model == 'WideResNet':
        model = WideResNet(34,
                           100,
                           widen_factor=args.width_factor,
                           dropRate=0.0)
        proxy = WideResNet(34,
                           100,
                           widen_factor=args.width_factor,
                           dropRate=0.0)
    else:
        raise ValueError("Unknown model")

    model = model.cuda()
    proxy = proxy.cuda()

    if args.l2:
        decay, no_decay = [], []
        for name, param in model.named_parameters():
            if 'bn' not in name and 'bias' not in name:
                decay.append(param)
            else:
                no_decay.append(param)
        params = [{
            'params': decay,
            'weight_decay': args.l2
        }, {
            'params': no_decay,
            'weight_decay': 0
        }]
    else:
        params = model.parameters()

    opt = torch.optim.SGD(params,
                          lr=args.lr_max,
                          momentum=0.9,
                          weight_decay=5e-4)
    proxy_opt = torch.optim.SGD(proxy.parameters(), lr=0.01)
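    # AdvWeightPerturb uses the proxy network and its optimizer to search for an
    # adversarial perturbation of the model weights (AWP); the perturbation is
    # applied before the robust update and restored again afterwards.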
    awp_adversary = AdvWeightPerturb(model=model,
                                     proxy=proxy,
                                     proxy_optim=proxy_opt,
                                     gamma=args.awp_gamma)

    criterion = nn.CrossEntropyLoss()

    epochs = args.epochs

    if args.lr_schedule == 'superconverge':
        lr_schedule = lambda t: np.interp([t], [
            0, args.epochs * 2 // 5, args.epochs
        ], [0, args.lr_max, 0])[0]
        # lr_schedule = lambda t: np.interp([t], [0, args.epochs], [0, args.lr_max])[0]
    elif args.lr_schedule == 'piecewise':

        def lr_schedule(t):
            if t / args.epochs < 0.5:
                return args.lr_max
            elif t / args.epochs < 0.75:
                return args.lr_max / 10.
            else:
                return args.lr_max / 100.

    best_test_robust_acc = 0
    if args.resume:
        start_epoch = args.resume
        model.load_state_dict(
            torch.load(os.path.join(args.fname, f'model_{start_epoch-1}.pth')))
        opt.load_state_dict(
            torch.load(os.path.join(args.fname, f'opt_{start_epoch-1}.pth')))
        logger.info(f'Resuming at epoch {start_epoch}')

        if os.path.exists(os.path.join(args.fname, f'model_best.pth')):
            best_test_robust_acc = torch.load(
                os.path.join(args.fname, f'model_best.pth'))['test_robust_acc']
    else:
        start_epoch = 0

    logger.info(
        'Epoch \t Train Time \t Test Time \t LR \t \t Train Loss \t Train Acc \t Train Robust Loss \t Train Robust Acc \t Test Loss \t Test Acc \t Test Robust Loss \t Test Robust Acc'
    )
    for epoch in range(start_epoch, epochs):
        model.train()
        start_time = time.time()
        train_loss = 0
        train_acc = 0
        train_robust_loss = 0
        train_robust_acc = 0
        train_n = 0
        for i, (X, y) in enumerate(train_loader):
            X, y = X.cuda(), y.cuda()
            if args.mixup:
                X, y_a, y_b, lam = mixup_data(X, y, args.mixup_alpha)
                X, y_a, y_b = map(Variable, (X, y_a, y_b))
            lr = lr_schedule(epoch + (i + 1) / len(train_loader))
            opt.param_groups[0].update(lr=lr)

            if args.attack == 'pgd':
                # Random initialization
                if args.mixup:
                    delta = attack_pgd(model,
                                       X,
                                       y,
                                       epsilon,
                                       pgd_alpha,
                                       args.attack_iters,
                                       args.restarts,
                                       args.norm,
                                       mixup=True,
                                       y_a=y_a,
                                       y_b=y_b,
                                       lam=lam)
                else:
                    delta = attack_pgd(model, X, y, epsilon, pgd_alpha,
                                       args.attack_iters, args.restarts,
                                       args.norm)
                delta = delta.detach()
            # Standard training
            elif args.attack == 'none':
                delta = torch.zeros_like(X)
            X_adv = normalize(
                torch.clamp(X + delta[:X.size(0)],
                            min=lower_limit,
                            max=upper_limit))

            model.train()
            # calculate adversarial weight perturbation and perturb it
            if epoch >= args.awp_warmup:
                # not compatible to mixup currently.
                assert (not args.mixup)
                awp = awp_adversary.calc_awp(inputs_adv=X_adv, targets=y)
                awp_adversary.perturb(awp)

            robust_output = model(X_adv)
            if args.mixup:
                robust_loss = mixup_criterion(criterion, robust_output, y_a,
                                              y_b, lam)
            else:
                robust_loss = criterion(robust_output, y)

            if args.l1:
                for name, param in model.named_parameters():
                    if 'bn' not in name and 'bias' not in name:
                        robust_loss += args.l1 * param.abs().sum()

            opt.zero_grad()
            robust_loss.backward()
            opt.step()

            if epoch >= args.awp_warmup:
                awp_adversary.restore(awp)

            output = model(normalize(X))
            if args.mixup:
                loss = mixup_criterion(criterion, output, y_a, y_b, lam)
            else:
                loss = criterion(output, y)

            train_robust_loss += robust_loss.item() * y.size(0)
            train_robust_acc += (robust_output.max(1)[1] == y).sum().item()
            train_loss += loss.item() * y.size(0)
            train_acc += (output.max(1)[1] == y).sum().item()
            train_n += y.size(0)

        train_time = time.time()

        model.eval()
        test_loss = 0
        test_acc = 0
        test_robust_loss = 0
        test_robust_acc = 0
        test_n = 0
        for i, (X, y) in enumerate(test_loader):
            X, y = X.cuda(), y.cuda()

            # Evaluate with PGD unless the attack is 'none'.
            if args.attack == 'none':
                delta = torch.zeros_like(X)
            else:
                delta = attack_pgd(model, X, y, epsilon, pgd_alpha,
                                   args.attack_iters_test, args.restarts,
                                   args.norm)
            delta = delta.detach()

            robust_output = model(
                normalize(
                    torch.clamp(X + delta[:X.size(0)],
                                min=lower_limit,
                                max=upper_limit)))
            robust_loss = criterion(robust_output, y)

            output = model(normalize(X))
            loss = criterion(output, y)

            test_robust_loss += robust_loss.item() * y.size(0)
            test_robust_acc += (robust_output.max(1)[1] == y).sum().item()
            test_loss += loss.item() * y.size(0)
            test_acc += (output.max(1)[1] == y).sum().item()
            test_n += y.size(0)

        test_time = time.time()
        logger.info(
            '%d \t %.1f \t \t %.1f \t \t %.4f \t %.4f \t %.4f \t %.4f \t \t %.4f \t \t %.4f \t %.4f \t %.4f \t \t %.4f',
            epoch, train_time - start_time, test_time - train_time, lr,
            train_loss / train_n, train_acc / train_n,
            train_robust_loss / train_n, train_robust_acc / train_n,
            test_loss / test_n, test_acc / test_n, test_robust_loss / test_n,
            test_robust_acc / test_n)

        if (epoch + 1) % args.chkpt_iters == 0 or epoch + 1 == epochs:
            torch.save(model.state_dict(),
                       os.path.join(args.fname, f'model_{epoch}.pth'))
            torch.save(opt.state_dict(),
                       os.path.join(args.fname, f'opt_{epoch}.pth'))

        # save best
        if test_robust_acc / test_n > best_test_robust_acc:
            torch.save(
                {
                    'state_dict': model.state_dict(),
                    'test_robust_acc': test_robust_acc / test_n,
                    'test_robust_loss': test_robust_loss / test_n,
                    'test_loss': test_loss / test_n,
                    'test_acc': test_acc / test_n,
                }, os.path.join(args.fname, f'model_best.pth'))
            best_test_robust_acc = test_robust_acc / test_n
def main():
    args = get_args()

    if not os.path.exists(args.fname):
        os.makedirs(args.fname)

    logger = logging.getLogger(__name__)
    logging.basicConfig(
        format='[%(asctime)s] - %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S',
        level=logging.DEBUG,
        handlers=[
            logging.FileHandler(
                os.path.join(args.fname,
                             'eval.log' if args.eval else 'output.log')),
            logging.StreamHandler()
        ])

    logger.info(args)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    transforms = [Crop(32, 32), FlipLR()]
    dataset = cifar10(args.data_dir)
    train_set = list(
        zip(transpose(pad(dataset['train']['data'], 4) / 255.),
            dataset['train']['labels']))
    train_set_x = Transform(train_set, transforms)
    train_batches = Batches(train_set_x,
                            args.batch_size,
                            shuffle=True,
                            set_random_choices=True,
                            num_workers=2)

    test_set = list(
        zip(transpose(dataset['test']['data'] / 255.),
            dataset['test']['labels']))
    test_batches = Batches(test_set,
                           args.batch_size,
                           shuffle=False,
                           num_workers=2)

    epsilon = (args.epsilon / 255.)
    pgd_alpha = (args.pgd_alpha / 255.)

    if args.model == 'PreActResNet18':
        model = PreActResNet18()
    elif args.model == 'WideResNet':
        model = WideResNet(34,
                           10,
                           widen_factor=args.width_factor,
                           dropRate=0.0)
    else:
        raise ValueError("Unknown model")

    model = nn.DataParallel(model).cuda()
    model.train()

    params = model.parameters()

    opt = torch.optim.SGD(params,
                          lr=args.lr_max,
                          momentum=0.9,
                          weight_decay=5e-4)
    criterion = nn.CrossEntropyLoss()
    epochs = args.epochs

    if args.lr_schedule == 'cyclic':
        lr_schedule = lambda t: np.interp(
            [t], [0, args.epochs // 2, args.epochs], [0, args.lr_max, 0])[0]
    elif args.lr_schedule == 'onedrop':

        def lr_schedule(t):
            if t < args.lr_drop:
                return args.lr_max
            else:
                return args.lr_max / 10.

    best_test_robust_acc = 0
    best_val_robust_acc = 0
    if args.resume:
        start_epoch = args.resume
        model.load_state_dict(
            torch.load(os.path.join(args.fname, f'model_{start_epoch-1}.pth')))
        opt.load_state_dict(
            torch.load(os.path.join(args.fname, f'opt_{start_epoch-1}.pth')))
        logger.info(f'Resuming at epoch {start_epoch}')

        best_test_robust_acc = torch.load(
            os.path.join(args.fname, f'model_best.pth'))['test_robust_acc']
    else:
        start_epoch = 0

    if args.eval:
        if not args.resume:
            logger.info(
                "No model loaded to evaluate, specify with --resume FNAME")
            return
        logger.info("[Evaluation mode]")

    logger.info(
        'Epoch \t Train Time \t Test Time \t LR \t \t Train Robust Loss \t Train Robust Acc \t Test Loss \t Test Acc \t Test Robust Loss \t Test Robust Acc \t Defence Mean \t Attack Mean'
    )
    for epoch in range(start_epoch, epochs):
        model.train()
        start_time = time.time()
        train_robust_loss = 0
        train_robust_acc = 0
        train_n = 0
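        # Mean absolute perturbation of the training attack, logged later in pixel units.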
        defence_mean = 0
        for i, batch in enumerate(train_batches):
            if args.eval:
                break
            X, y = batch['input'], batch['target']
            lr = lr_schedule(epoch + (i + 1) / len(train_batches))
            opt.param_groups[0].update(lr=lr)

            if args.attack == 'fgsm':
                delta = attack_pgd(model,
                                   X,
                                   y,
                                   epsilon,
                                   args.fgsm_alpha * epsilon,
                                   1,
                                   1,
                                   'l_inf',
                                   fgsm_init=args.fgsm_init)
            elif args.attack == 'multigrad':
                delta = multi_grad(model, X, y, epsilon,
                                   args.fgsm_alpha * epsilon, args.multi_samps,
                                   args.multi_th, args.multi_parallel)
            elif args.attack == 'zerograd':
                delta = zero_grad(model, X, y, epsilon,
                                  args.fgsm_alpha * epsilon, args.zero_qval,
                                  args.zero_iters, args.fgsm_init)
            elif args.attack == 'pgd':
                delta = attack_pgd(model, X, y, epsilon, pgd_alpha, 10, 1,
                                   'l_inf')
            delta = delta.detach()
            robust_output = model(
                normalize(
                    torch.clamp(X + delta[:X.size(0)],
                                min=lower_limit,
                                max=upper_limit)))
            robust_loss = criterion(robust_output, y)

            opt.zero_grad()
            robust_loss.backward()
            opt.step()

            train_robust_loss += robust_loss.item() * y.size(0)
            train_robust_acc += (robust_output.max(1)[1] == y).sum().item()
            train_n += y.size(0)
            defence_mean += torch.mean(torch.abs(delta)).item() * y.size(0)

        train_time = time.time()

        model.eval()
        test_loss = 0
        test_acc = 0
        test_robust_loss = 0
        test_robust_acc = 0
        test_n = 0
        attack_mean = 0
        for i, batch in enumerate(test_batches):
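            # Except on the final epoch (or with --full_test), evaluate only ~10% of
            # the test batches to keep per-epoch evaluation cheap.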
            if epoch + 1 != epochs and not args.full_test and i > len(test_batches) / 10:
                break
            X, y = batch['input'], batch['target']

            delta = attack_pgd(model,
                               X,
                               y,
                               epsilon,
                               pgd_alpha,
                               args.test_iters,
                               args.test_restarts,
                               'l_inf',
                               early_stop=args.eval)
            delta = delta.detach()

            robust_output = model(
                normalize(
                    torch.clamp(X + delta[:X.size(0)],
                                min=lower_limit,
                                max=upper_limit)))
            robust_loss = criterion(robust_output, y)

            output = model(normalize(X))
            loss = criterion(output, y)

            test_robust_loss += robust_loss.item() * y.size(0)
            test_robust_acc += (robust_output.max(1)[1] == y).sum().item()
            test_loss += loss.item() * y.size(0)
            test_acc += (output.max(1)[1] == y).sum().item()
            test_n += y.size(0)
            attack_mean += torch.mean(torch.abs(delta)).item() * y.size(0)

        test_time = time.time()

        if not args.eval:
            logger.info(
                '%d \t %.1f \t \t %.1f \t \t %.4f \t %.4f \t \t %.4f \t \t %.4f \t %.4f \t %.4f \t \t %.4f \t \t %.4f \t %.4f',
                epoch, train_time - start_time, test_time - train_time, lr,
                train_robust_loss / train_n, train_robust_acc / train_n,
                test_loss / test_n, test_acc / test_n,
                test_robust_loss / test_n, test_robust_acc / test_n,
                defence_mean * 255 / train_n, attack_mean * 255 / test_n)

            # save checkpoint
            if (epoch + 1) % args.chkpt_iters == 0 or epoch + 1 == epochs:
                torch.save(model.state_dict(),
                           os.path.join(args.fname, f'model_{epoch}.pth'))
                torch.save(opt.state_dict(),
                           os.path.join(args.fname, f'opt_{epoch}.pth'))

            # save best
            if test_robust_acc / test_n > best_test_robust_acc:
                torch.save(
                    {
                        'state_dict': model.state_dict(),
                        'test_robust_acc': test_robust_acc / test_n,
                        'test_robust_loss': test_robust_loss / test_n,
                        'test_loss': test_loss / test_n,
                        'test_acc': test_acc / test_n,
                    }, os.path.join(args.fname, f'model_best.pth'))
                best_test_robust_acc = test_robust_acc / test_n
        else:
            logger.info(
                '%d \t %.1f \t \t %.1f \t \t %.4f \t %.4f \t \t %.4f \t \t %.4f \t %.4f \t %.4f \t \t %.4f \t \t %.4f \t %.4f',
                epoch, train_time - start_time, test_time - train_time, -1, -1,
                -1, test_loss / test_n, test_acc / test_n,
                test_robust_loss / test_n, test_robust_acc / test_n, -1,
                attack_mean * 255 / test_n)

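        # Final (or eval-mode) robustness check: PGD with 50 iterations and 10
        # restarts over the full test set.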
        if args.eval or epoch + 1 == epochs:
            start_test_time = time.time()
            test_loss = 0
            test_acc = 0
            test_robust_loss = 0
            test_robust_acc = 0
            test_n = 0
            for i, batch in enumerate(test_batches):
                X, y = batch['input'], batch['target']

                delta = attack_pgd(model,
                                   X,
                                   y,
                                   epsilon,
                                   pgd_alpha,
                                   50,
                                   10,
                                   'l_inf',
                                   early_stop=True)
                delta = delta.detach()

                robust_output = model(
                    normalize(
                        torch.clamp(X + delta[:X.size(0)],
                                    min=lower_limit,
                                    max=upper_limit)))
                robust_loss = criterion(robust_output, y)

                output = model(normalize(X))
                loss = criterion(output, y)

                test_robust_loss += robust_loss.item() * y.size(0)
                test_robust_acc += (robust_output.max(1)[1] == y).sum().item()
                test_loss += loss.item() * y.size(0)
                test_acc += (output.max(1)[1] == y).sum().item()
                test_n += y.size(0)

            logger.info(
                'PGD50 \t time: %.1f,\t clean loss: %.4f,\t clean acc: %.4f,\t robust loss: %.4f,\t robust acc: %.4f',
                time.time() - start_test_time, test_loss / test_n,
                test_acc / test_n, test_robust_loss / test_n,
                test_robust_acc / test_n)
            return
class TransClassifier():
    def __init__(self, num_trans, args):
        self.n_trans = num_trans
        self.args = args
        self.netWRN = WideResNet(self.args.depth, num_trans,
                                 self.args.widen_factor).cuda()
        self.optimizer = torch.optim.Adam(self.netWRN.parameters())

    def fit_trans_classifier(self, x_train, x_test, y_test):
        print("Training")
        self.netWRN.train()
        bs = self.args.batch_size
        N, sh, sw, nc = x_train.shape
        n_rots = self.n_trans
        m = self.args.m
        celoss = torch.nn.CrossEntropyLoss()
        ndf = 256

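        # x_train is organised in consecutive groups of n_rots transformed views of
        # the same image; the permutation below shuffles groups while keeping each
        # group contiguous, and labels are simply the transformation indices.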
        for epoch in range(self.args.epochs):
            self.netWRN.train()  # switch back to train mode after eval() at the end of the previous epoch
            rp = np.random.permutation(N // n_rots)
            rp = np.concatenate(
                [np.arange(n_rots) + rp[i] * n_rots for i in range(len(rp))])
            assert len(rp) == N
            all_zs = torch.zeros((len(x_train), ndf)).cuda()
            diffs_all = []

            for i in range(0, len(x_train), bs):
                batch_range = min(bs, len(x_train) - i)
                idx = np.arange(batch_range) + i
                xs = torch.from_numpy(x_train[rp[idx]]).float().cuda()
                zs_tc, zs_ce = self.netWRN(xs)

                all_zs[idx] = zs_tc
                train_labels = torch.from_numpy(
                    np.tile(np.arange(n_rots),
                            batch_range // n_rots)).long().cuda()
                zs = torch.reshape(zs_tc, (batch_range // n_rots, n_rots, ndf))

                means = zs.mean(0).unsqueeze(0)
                diffs = -(
                    (zs.unsqueeze(2).detach().cpu().numpy() -
                     means.unsqueeze(1).detach().cpu().numpy())**2).sum(-1)
                diffs_all.append(
                    torch.diagonal(torch.tensor(diffs), dim1=1, dim2=2))

                tc = tc_loss(zs, m)
                ce = celoss(zs_ce, train_labels)
                if self.args.reg:
                    loss = ce + self.args.lmbda * tc + 10 * (zs * zs).mean()
                else:
                    loss = ce + self.args.lmbda * tc
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            self.netWRN.eval()
            all_zs = torch.reshape(all_zs, (N // n_rots, n_rots, ndf))
            means = all_zs.mean(0, keepdim=True)

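            # Score each test sample by the distance of its embeddings to the
            # per-transformation centers estimated from the training set.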
            with torch.no_grad():
                batch_size = bs
                val_probs_rots = np.zeros((len(y_test), self.n_trans))
                for i in range(0, len(x_test), batch_size):
                    batch_range = min(batch_size, len(x_test) - i)
                    idx = np.arange(batch_range) + i
                    xs = torch.from_numpy(x_test[idx]).float().cuda()

                    zs, fs = self.netWRN(xs)
                    zs = torch.reshape(zs,
                                       (batch_range // n_rots, n_rots, ndf))

                    diffs = ((zs.unsqueeze(2) - means)**2).sum(-1)
                    diffs_eps = self.args.eps * torch.ones_like(diffs)
                    diffs = torch.max(diffs, diffs_eps)
                    logp_sz = torch.nn.functional.log_softmax(-diffs, dim=2)

                    zs_reidx = np.arange(batch_range // n_rots) + i // n_rots
                    val_probs_rots[zs_reidx] = -torch.diagonal(
                        logp_sz, 0, 1, 2).cpu().data.numpy()

                val_probs_rots = val_probs_rots.sum(1)
                print("Epoch:", epoch, ", AUC: ",
                      roc_auc_score(y_test, -val_probs_rots))
def main():
    global args, best_prec1, exp_dir

    best_prec1 = 0
    args = parser.parse_args()
    print(args.lr_decay_at)
    assert args.normalization in ['GCN_ZCA', 'GCN'], \
        'normalization {} unknown'.format(args.normalization)

    global zca
    if 'ZCA' in args.normalization:
        zca_params = torch.load('./data/cifar-10-batches-py/zca_params.pth')
        zca = ZCA(zca_params)
    else:
        zca = None

    exp_dir = os.path.join('experiments', args.name)
    if os.path.exists(exp_dir):
        print("an experiment with the same name already exists...")
        #return
    else:
        os.makedirs(exp_dir)

    # DATA SETTINGS
    global dataset_train, dataset_test
    if args.dataset == 'cifar10':
        import cifar
        dataset_train = cifar.CIFAR10(args, train=True)
        dataset_test = cifar.CIFAR10(args, train=False)
    if args.UDA:
        # loader for UDA
        dataset_train_uda = cifar.CIFAR10(args, True, True)
        uda_loader = torch.utils.data.DataLoader(
            dataset_train_uda,
            batch_size=args.batch_size_unsup,
            shuffle=True,
            num_workers=args.num_workers,
            pin_memory=True)
        iter_uda = iter(uda_loader)
    else:
        iter_uda = None

    train_loader, test_loader = initialize_loader()

    # MODEL SETTINGS
    if args.arch == 'WRN-28-2':
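        # WRN-28-2: 10 output classes for CIFAR-10, otherwise 100.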
        model = WideResNet(28, [100, 10][int(args.dataset == 'cifar10')],
                           2,
                           dropRate=args.dropout_rate)
        model = torch.nn.DataParallel(model.cuda())
    else:
        raise NotImplementedError('arch {} is not implemented'.format(
            args.arch))
    if args.optimizer == 'Adam':
        print("use Adam optimizer")
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.l2_reg)
    elif args.optimizer == 'SGD':
        print("use SGD optimizer")
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=0.9,
                                    weight_decay=args.l2_reg,
                                    nesterov=args.nesterov)

    if args.lr_decay == 'cosine':
        print("use cosine lr scheduler")
        global scheduler
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, args.max_iter, eta_min=args.final_lr)

    global batch_time, losses_sup, losses_unsup, top1, losses_l1
    batch_time = AverageMeter()
    losses_sup = AverageMeter()
    losses_unsup = AverageMeter()
    top1 = AverageMeter()
    losses_l1 = AverageMeter()
    t = time.time()
    model.train()
    iter_sup = iter(train_loader)
    for train_iter in range(args.max_iter):
        # TRAIN
        lr = adjust_learning_rate(optimizer, train_iter + 1)
        train(model,
              iter_sup,
              optimizer,
              train_iter,
              data_iterator_uda=iter_uda)

        # LOGGING
        if (train_iter + 1) % args.print_freq == 0:
            print('ITER: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'L1 Loss {l1_loss.val:.4f} ({l1_loss.avg:.4f})\t'
                  'Unsup Loss {unsup_loss.val:.4f} ({unsup_loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Learning rate {2} TSA th {3}'.format(
                      train_iter,
                      args.max_iter,
                      lr,
                      TSA_th(train_iter),
                      batch_time=batch_time,
                      loss=losses_sup,
                      l1_loss=losses_l1,
                      unsup_loss=losses_unsup,
                      top1=top1))

        if (train_iter +
                1) % args.eval_iter == 0 or train_iter + 1 == args.max_iter:
            # EVAL
            print("evaluation at iter {}".format(train_iter))
            prec1 = test(model, test_loader)

            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'iter': train_iter + 1,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)
            print("* Best accuracy: {}".format(best_prec1))
            eval_interval_time = time.time() - t
            t = time.time()
            print("total {} sec for {} iterations".format(
                eval_interval_time, args.eval_iter))
            seconds_remaining = eval_interval_time / float(
                args.eval_iter) * (args.max_iter - train_iter)
            print("{}:{}:{} remaining".format(
                int(seconds_remaining // 3600),
                int((seconds_remaining % 3600) // 60),
                int(seconds_remaining % 60)))
            model.train()
            batch_time = AverageMeter()
            losses_sup = AverageMeter()
            losses_unsup = AverageMeter()
            top1 = AverageMeter()
            losses_l1 = AverageMeter()
            iter_sup = iter(train_loader)
            if iter_uda is not None:
                iter_uda = iter(uda_loader)
def main():
    args = get_args()
    if args.fname == 'auto':
        names = get_auto_fname(args)
        args.fname = 'noise_predictor/' + names
    else:
        args.fname = 'noise_predictor/' + args.fname

    if not os.path.exists(args.fname):
        os.makedirs(args.fname)

    logger = logging.getLogger(__name__)
    logging.basicConfig(
        format='[%(asctime)s] - %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S',
        level=logging.DEBUG,
        handlers=[
            logging.FileHandler(
                os.path.join(args.fname,
                             'eval.log' if args.eval else 'output.log')),
            logging.StreamHandler()
        ])

    logger.info(args)

    # Set seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    shuffle = True

    train_adv_images = None
    train_adv_labels = None
    test_adv_images = None
    test_adv_labels = None

    print("Attacks")
    attacks = args.list.split("_")
    print(attacks)

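    # Load the precomputed adversarial examples for each attack and label every
    # example with the index of the attack that produced it.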
    for i in range(len(attacks)):
        _adv_dir = "adv_examples/{}/".format(attacks[i])
        train_path = _adv_dir + "train.pth"
        test_path = _adv_dir + "test.pth"

        adv_train_data = torch.load(train_path)
        adv_test_data = torch.load(test_path)

        if i == 0:
            train_adv_images = adv_train_data["adv"]
            test_adv_images = adv_test_data["adv"]
            train_adv_labels = [i] * len(adv_train_data["label"])
            test_adv_labels = [i] * len(adv_test_data["label"])
        else:
            train_adv_images = np.concatenate(
                (train_adv_images, adv_train_data["adv"]))
            test_adv_images = np.concatenate(
                (test_adv_images, adv_test_data["adv"]))
            train_adv_labels = np.concatenate(
                (train_adv_labels, [i] * len(adv_train_data["label"])))
            test_adv_labels = np.concatenate(
                (test_adv_labels, [i] * len(adv_test_data["label"])))

    train_adv_set = list(zip(train_adv_images, train_adv_labels))

    print("")
    print("Train Adv Attack Data: ", attacks)
    print("Len: ", len(train_adv_set))
    print("")

    train_adv_batches = Batches(train_adv_set,
                                args.batch_size,
                                shuffle=shuffle,
                                set_random_choices=False,
                                num_workers=4)

    test_adv_set = list(zip(test_adv_images, test_adv_labels))

    test_adv_batches = Batches(test_adv_set,
                               args.batch_size,
                               shuffle=False,
                               num_workers=4)

    # Set perturbations
    epsilon = (args.epsilon / 255.)
    test_epsilon = (args.test_epsilon / 255.)
    pgd_alpha = (args.pgd_alpha / 255.)
    test_pgd_alpha = (args.test_pgd_alpha / 255.)
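    # Perturbation budgets are given on the 0-255 pixel scale above and rescaled
    # here to the [0, 1] range used by the data loaders.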

    # Set models
    model = None
    if args.model == "resnet18":
        model = resnet18(num_classes=args.num_classes)
    elif args.model == "resnet20":
        model = resnet20()
    elif args.model == "vgg16bn":
        model = vgg16_bn(num_classes=args.num_classes)
    elif args.model == "densenet121":
        model = densenet121(pretrained=True)
    elif args.model == "googlenet":
        model = googlenet(pretrained=True)
    elif args.model == "inceptionv3":
        model = inception_v3(num_classes=3)
    elif args.model == 'WideResNet':
        model = WideResNet(34,
                           10,
                           widen_factor=10,
                           dropRate=0.0,
                           normalize=args.use_FNandWN,
                           activation=args.activation,
                           softplus_beta=args.softplus_beta)
    elif args.model == 'WideResNet_20':
        model = WideResNet(34,
                           10,
                           widen_factor=20,
                           dropRate=0.0,
                           normalize=args.use_FNandWN,
                           activation=args.activation,
                           softplus_beta=args.softplus_beta)
    else:
        raise ValueError("Unknown model")

    # Set training hyperparameters
    if args.l2:
        decay, no_decay = [], []
        for name, param in model.named_parameters():
            if 'bn' not in name and 'bias' not in name:
                decay.append(param)
            else:
                no_decay.append(param)
        params = [{
            'params': decay,
            'weight_decay': args.l2
        }, {
            'params': no_decay,
            'weight_decay': 0
        }]
    else:
        params = model.parameters()
    if args.lr_schedule == 'cyclic':
        opt = torch.optim.Adam(params,
                               lr=args.lr_max,
                               betas=(0.9, 0.999),
                               eps=1e-08,
                               weight_decay=args.weight_decay)
    else:
        if args.optimizer == 'momentum':
            opt = torch.optim.SGD(params,
                                  lr=args.lr_max,
                                  momentum=0.9,
                                  weight_decay=args.weight_decay)
        elif args.optimizer == 'Nesterov':
            opt = torch.optim.SGD(params,
                                  lr=args.lr_max,
                                  momentum=0.9,
                                  weight_decay=args.weight_decay,
                                  nesterov=True)
        elif args.optimizer == 'SGD_GC':
            opt = SGD_GC(params,
                         lr=args.lr_max,
                         momentum=0.9,
                         weight_decay=args.weight_decay)
        elif args.optimizer == 'SGD_GCC':
            opt = SGD_GCC(params,
                          lr=args.lr_max,
                          momentum=0.9,
                          weight_decay=args.weight_decay)
        elif args.optimizer == 'Adam':
            opt = torch.optim.Adam(params,
                                   lr=args.lr_max,
                                   betas=(0.9, 0.999),
                                   eps=1e-08,
                                   weight_decay=args.weight_decay)
        elif args.optimizer == 'AdamW':
            opt = torch.optim.AdamW(params,
                                    lr=args.lr_max,
                                    betas=(0.9, 0.999),
                                    eps=1e-08,
                                    weight_decay=args.weight_decay)
        else:
            raise ValueError("Unknown optimizer")

    # Cross-entropy (mean)
    if args.labelsmooth:
        criterion = LabelSmoothingLoss(smoothing=args.labelsmoothvalue)
    else:
        criterion = nn.CrossEntropyLoss()

    epochs = args.epochs

    # Set lr schedule
    if args.lr_schedule == 'superconverge':
        lr_schedule = lambda t: np.interp([t], [
            0, args.epochs * 2 // 5, args.epochs
        ], [0, args.lr_max, 0])[0]
    elif args.lr_schedule == 'piecewise':

        def lr_schedule(t, warm_up_lr=args.warmup_lr):
            if t < 100:
                if warm_up_lr and t < args.warmup_lr_epoch:
                    return (t + 1.) / args.warmup_lr_epoch * args.lr_max
                else:
                    return args.lr_max
            if args.lrdecay == 'lineardecay':
                if t < 105:
                    return args.lr_max * 0.02 * (105 - t)
                else:
                    return 0.
            elif args.lrdecay == 'intenselr':
                if t < 102:
                    return args.lr_max / 10.
                else:
                    return args.lr_max / 100.
            elif args.lrdecay == 'looselr':
                if t < 150:
                    return args.lr_max / 10.
                else:
                    return args.lr_max / 100.
            elif args.lrdecay == 'base':
                if t < 105:
                    return args.lr_max / 10.
                else:
                    return args.lr_max / 100.
    elif args.lr_schedule == 'linear':
        lr_schedule = lambda t: np.interp([t], [
            0, args.epochs // 3, args.epochs * 2 // 3, args.epochs
        ], [args.lr_max, args.lr_max, args.lr_max / 10, args.lr_max / 100])[0]
    elif args.lr_schedule == 'onedrop':

        def lr_schedule(t):
            if t < args.lr_drop_epoch:
                return args.lr_max
            else:
                return args.lr_one_drop
    elif args.lr_schedule == 'multipledecay':

        def lr_schedule(t):
            return args.lr_max - (t //
                                  (args.epochs // 10)) * (args.lr_max / 10)
    elif args.lr_schedule == 'cosine':

        def lr_schedule(t):
            return args.lr_max * 0.5 * (1 + np.cos(t / args.epochs * np.pi))
    elif args.lr_schedule == 'cyclic':

        def lr_schedule(t, stepsize=18, min_lr=1e-5, max_lr=args.lr_max):

            # Scaler: we can adapt this if we do not want the triangular CLR
            scaler = lambda x: 1.

            # Additional function to see where on the cycle we are
            cycle = math.floor(1 + t / (2 * stepsize))
            x = abs(t / stepsize - 2 * cycle + 1)
            relative = max(0, (1 - x)) * scaler(cycle)

            return min_lr + (max_lr - min_lr) * relative

    #### Use stronger adversarial attacks when the lr is decayed ####
    def eps_alpha_schedule(
            t,
            warm_up_eps=args.warmup_eps,
            if_use_stronger_adv=args.use_stronger_adv,
            stronger_index=args.stronger_index):  # Schedule number 0
        if stronger_index == 0:
            epsilon_s = [epsilon * 1.5, epsilon * 2]
            pgd_alpha_s = [pgd_alpha, pgd_alpha]
        elif stronger_index == 1:
            epsilon_s = [epsilon * 1.5, epsilon * 2]
            pgd_alpha_s = [pgd_alpha * 1.25, pgd_alpha * 1.5]
        elif stronger_index == 2:
            epsilon_s = [epsilon * 2, epsilon * 2.5]
            pgd_alpha_s = [pgd_alpha * 1.5, pgd_alpha * 2]
        else:
            raise ValueError('Undefined stronger index')

        if if_use_stronger_adv:
            if t < 100:
                if t < args.warmup_eps_epoch and warm_up_eps:
                    return (
                        t + 1.
                    ) / args.warmup_eps_epoch * epsilon, pgd_alpha, args.restarts
                else:
                    return epsilon, pgd_alpha, args.restarts
            elif t < 105:
                return epsilon_s[0], pgd_alpha_s[0], args.restarts
            else:
                return epsilon_s[1], pgd_alpha_s[1], args.restarts
        else:
            if t < args.warmup_eps_epoch and warm_up_eps:
                return (
                    t + 1.
                ) / args.warmup_eps_epoch * epsilon, pgd_alpha, args.restarts
            else:
                return epsilon, pgd_alpha, args.restarts

    #### Set the counter for the early stop of PGD ####
    def early_stop_counter_schedule(t):
        if t < args.earlystopPGDepoch1:
            return 1
        elif t < args.earlystopPGDepoch2:
            return 2
        else:
            return 3

    best_test_adv_acc = 0
    best_val_robust_acc = 0
    if args.resume:
        start_epoch = args.resume
        model.load_state_dict(
            torch.load(os.path.join(args.fname, f'model_{start_epoch-1}.pth')))
        opt.load_state_dict(
            torch.load(os.path.join(args.fname, f'opt_{start_epoch-1}.pth')))
        logger.info(f'Resuming at epoch {start_epoch}')

        best_test_adv_acc = torch.load(
            os.path.join(args.fname, f'model_best.pth'))['test_adv_acc']
        if args.val:
            best_val_robust_acc = torch.load(
                os.path.join(args.fname, f'model_val.pth'))['val_robust_acc']
    else:
        start_epoch = 0

    if args.eval:
        if not args.resume:
            logger.info(
                "No model loaded to evaluate, specify with --resume FNAME")
            return
        logger.info("[Evaluation mode]")

    model.cuda()

    model.train()

    logger.info('Epoch \t Train Robust Acc \t Test Robust Acc')

    # Records per epoch for savetxt
    train_loss_record = []
    train_acc_record = []
    train_robust_loss_record = []
    train_robust_acc_record = []
    train_grad_record = []

    test_loss_record = []
    test_acc_record = []
    test_adv_loss_record = []
    test_adv_acc_record = []
    test_grad_record = []

    for epoch in range(start_epoch, epochs):
        model.train()
        start_time = time.time()

        train_robust_loss = 0
        train_robust_acc = 0
        train_n = 0
        train_grad = 0

        record_iter = torch.tensor([])

        for i, adv_batch in enumerate(train_adv_batches):
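            # Train directly on the stored adversarial examples; no attack is
            # generated on the fly, so each step is a plain supervised update.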
            if args.eval:
                break
            adv_input = normalize(adv_batch['input'])
            adv_y = adv_batch['target']
            adv_input.requires_grad = True
            robust_output = model(adv_input)

            robust_loss = criterion(robust_output, adv_y)

            if args.l1:
                for name, param in model.named_parameters():
                    if 'bn' not in name and 'bias' not in name:
                        robust_loss += args.l1 * param.abs().sum()

            opt.zero_grad()
            robust_loss.backward()
            opt.step()

            # Record the statistics for this batch
            train_robust_loss += robust_loss.item() * adv_y.size(0)
            train_robust_acc += (robust_output.max(1)[1] == adv_y).sum().item()
            train_n += adv_y.size(0)

        train_time = time.time()
        if args.earlystopPGD:
            print('Iter mean: ',
                  record_iter.mean().item(), ' Iter std:  ',
                  record_iter.std().item())

        # Evaluate on test data
        model.eval()

        test_adv_loss = 0
        test_adv_acc = 0
        test_adv_n = 0

        for i, batch in enumerate(test_adv_batches):
            adv_input = normalize(batch['input'])
            y = batch['target']

            robust_output = model(adv_input)
            robust_loss = criterion(robust_output, y)

            test_adv_loss += robust_loss.item() * y.size(0)
            test_adv_acc += (robust_output.max(1)[1] == y).sum().item()
            test_adv_n += y.size(0)

        test_time = time.time()

        logger.info('%d \t %.4f \t\t %.4f', epoch + 1,
                    train_robust_acc / train_n, test_adv_acc / test_adv_n)

        # Save results
        train_robust_loss_record.append(train_robust_loss / train_n)
        train_robust_acc_record.append(train_robust_acc / train_n)

        np.savetxt(args.fname + '/train_robust_loss_record.txt',
                   np.array(train_robust_loss_record))
        np.savetxt(args.fname + '/train_robust_acc_record.txt',
                   np.array(train_robust_acc_record))

        test_adv_loss_record.append(test_adv_loss / test_adv_n)
        test_adv_acc_record.append(test_adv_acc / test_adv_n)

        np.savetxt(args.fname + '/test_adv_loss_record.txt',
                   np.array(test_adv_loss_record))
        np.savetxt(args.fname + '/test_adv_acc_record.txt',
                   np.array(test_adv_acc_record))

        # save checkpoint
        if epoch > 99 or (epoch +
                          1) % args.chkpt_iters == 0 or epoch + 1 == epochs:
            torch.save(model.state_dict(),
                       os.path.join(args.fname, f'model_{epoch}.pth'))
            torch.save(opt.state_dict(),
                       os.path.join(args.fname, f'opt_{epoch}.pth'))

        # save best
        if test_adv_acc / test_adv_n > best_test_adv_acc:
            torch.save(
                {
                    'state_dict': model.state_dict(),
                    'test_adv_acc': test_adv_acc / test_adv_n,
                    'test_adv_loss': test_adv_loss / test_adv_n,
                }, os.path.join(args.fname, f'model_best.pth'))
            best_test_adv_acc = test_adv_acc / test_adv_n
def main():
    args = get_args()
    if args.fname == 'auto':
        names = get_auto_fname(args)
        args.fname = '../../trained_models/' + names
    else:
        args.fname = '../../trained_models/' + args.fname

    if not os.path.exists(args.fname):
        os.makedirs(args.fname)

    logger = logging.getLogger(__name__)
    logging.basicConfig(
        format='[%(asctime)s] - %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S',
        level=logging.DEBUG,
        handlers=[
            logging.FileHandler(
                os.path.join(args.fname,
                             'eval.log' if args.eval else 'output.log')),
            logging.StreamHandler()
        ])

    logger.info(args)

    # Set seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # setup data loader
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])
    train_set = torchvision.datasets.CIFAR10(root='../data',
                                             train=True,
                                             download=True,
                                             transform=transform_train)
    test_set = torchvision.datasets.CIFAR10(root='../data',
                                            train=False,
                                            download=True,
                                            transform=transform_test)

    if args.attack == "all":
        train_data = np.array(train_set.data) / 255.
        train_data = transpose(train_data).astype(np.float32)

        train_labels = np.array(train_set.targets)

        oversampled_train_data = np.tile(train_data, (11, 1, 1, 1))
        oversampled_train_labels = np.tile(train_labels, (11))
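        # (11 copies: one per attack in ATTACK_LIST below, so the clean batches can
        # later be zipped one-to-one with the concatenated adversarial batches)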

        train_set = list(
            zip(torch.from_numpy(oversampled_train_data),
                torch.from_numpy(oversampled_train_labels)))

    elif args.attack == "combine":
        train_data = np.array(train_set.data) / 255.
        train_data = transpose(train_data).astype(np.float32)

        train_labels = np.array(train_set.targets)

        oversampled_train_data = train_data.copy()
        oversampled_train_labels = train_labels.copy()

        logger.info("Attacks")
        attacks = args.list.split("_")
        logger.info(attacks)

        oversampled_train_data = np.tile(train_data, (len(attacks), 1, 1, 1))
        oversampled_train_labels = np.tile(train_labels, (len(attacks)))

        train_set = list(
            zip(torch.from_numpy(oversampled_train_data),
                torch.from_numpy(oversampled_train_labels)))
    else:
        train_data = np.array(train_set.data) / 255.
        train_data = transpose(train_data).astype(np.float32)

        train_labels = np.array(train_set.targets)

        train_set = list(
            zip(torch.from_numpy(train_data), torch.from_numpy(train_labels)))

    test_data = np.array(test_set.data) / 255.
    test_data = transpose(test_data).astype(np.float32)
    test_labels = np.array(test_set.targets)

    test_set = list(
        zip(torch.from_numpy(test_data), torch.from_numpy(test_labels)))

    if args.sample != 100:
        n = len(train_set)
        n_sample = int(n * args.sample / 100)

        np.random.shuffle(train_set)
        train_set = train_set[:n_sample]

    print("")
    print("Train Original Data: ")
    print("Len: ", len(train_set))
    print("")

    shuffle = False

    train_batches = Batches(train_set, args.batch_size, shuffle=shuffle)
    test_batches = Batches(test_set, args.batch_size, shuffle=False)

    train_adv_images = None
    train_adv_labels = None
    test_adv_images = None
    test_adv_labels = None

    adv_dir = "adv_examples/{}/".format(args.attack)
    train_path = adv_dir + "train.pth"
    test_path = adv_dir + "test.pth"

    #     ATTACK_LIST = ["autoattack", "autopgd", "bim", "cw", "deepfool", "fgsm", "newtonfool", "pgd", "pixelattack", "spatialtransformation", "squareattack"]
    ATTACK_LIST = [
        "pixelattack", "spatialtransformation", "squareattack", "fgsm",
        "deepfool", "bim", "cw", "pgd", "autoattack", "autopgd", "newtonfool"
    ]

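    # Two on-disk formats are handled below: attacks in TOOLBOX_ADV_ATTACK_LIST
    # store a dict of "adv"/"label" numpy arrays, while ffgsm/mifgsm/tpgd store a
    # (images, labels) tensor tuple that needs .numpy() conversion.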
    if args.attack in TOOLBOX_ADV_ATTACK_LIST:
        adv_train_data = torch.load(train_path)
        train_adv_images = adv_train_data["adv"]
        train_adv_labels = adv_train_data["label"]
        adv_test_data = torch.load(test_path)
        test_adv_images = adv_test_data["adv"]
        test_adv_labels = adv_test_data["label"]
    elif args.attack in ["ffgsm", "mifgsm", "tpgd"]:
        adv_data = {}
        adv_data["adv"], adv_data["label"] = torch.load(train_path)
        train_adv_images = adv_data["adv"].numpy()
        train_adv_labels = adv_data["label"].numpy()
        adv_data = {}
        adv_data["adv"], adv_data["label"] = torch.load(test_path)
        test_adv_images = adv_data["adv"].numpy()
        test_adv_labels = adv_data["label"].numpy()
    elif args.attack == "all":

        for i in range(len(ATTACK_LIST)):
            _adv_dir = "adv_examples/{}/".format(ATTACK_LIST[i])
            train_path = _adv_dir + "train.pth"
            test_path = _adv_dir + "test.pth"

            adv_train_data = torch.load(train_path)
            adv_test_data = torch.load(test_path)

            if i == 0:
                train_adv_images = adv_train_data["adv"]
                train_adv_labels = adv_train_data["label"]
                test_adv_images = adv_test_data["adv"]
                test_adv_labels = adv_test_data["label"]
            else:
                #                 print(train_adv_images.shape)
                #                 print(adv_train_data["adv"].shape)
                train_adv_images = np.concatenate(
                    (train_adv_images, adv_train_data["adv"]))
                train_adv_labels = np.concatenate(
                    (train_adv_labels, adv_train_data["label"]))
                test_adv_images = np.concatenate(
                    (test_adv_images, adv_test_data["adv"]))
                test_adv_labels = np.concatenate(
                    (test_adv_labels, adv_test_data["label"]))
    elif args.attack == "combine":

        print("Attacks")
        attacks = args.list.split("_")
        print(attacks)

        if args.balanced is None:
            for i in range(len(attacks)):
                _adv_dir = "adv_examples/{}/".format(attacks[i])
                train_path = _adv_dir + "train.pth"
                test_path = _adv_dir + "test.pth"

                adv_train_data = torch.load(train_path)
                adv_test_data = torch.load(test_path)

                if i == 0:
                    train_adv_images = adv_train_data["adv"]
                    train_adv_labels = adv_train_data["label"]
                    test_adv_images = adv_test_data["adv"]
                    test_adv_labels = adv_test_data["label"]
                else:
                    #                 print(train_adv_images.shape)
                    #                 print(adv_train_data["adv"].shape)
                    train_adv_images = np.concatenate(
                        (train_adv_images, adv_train_data["adv"]))
                    train_adv_labels = np.concatenate(
                        (train_adv_labels, adv_train_data["label"]))
                    test_adv_images = np.concatenate(
                        (test_adv_images, adv_test_data["adv"]))
                    test_adv_labels = np.concatenate(
                        (test_adv_labels, adv_test_data["label"]))
        else:
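            # --balanced gives integer proportions per attack (e.g. "1_2_1");
            # normalize them and resample each attack's examples so the combined
            # adversarial training set still totals 50,000 images.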
            proportion_str = args.balanced.split("_")
            proportion = [int(x) for x in proportion_str]
            sum_proportion = sum(proportion)
            proportion = [float(x) / float(sum_proportion) for x in proportion]
            sum_samples = 0

            for i in range(len(attacks)):
                _adv_dir = "adv_examples/{}/".format(attacks[i])
                train_path = _adv_dir + "train.pth"
                test_path = _adv_dir + "test.pth"

                adv_train_data = torch.load(train_path)
                adv_test_data = torch.load(test_path)

                random_state = 0
                total = 50000  # total adversarial training examples across all attacks
                if i != len(attacks) - 1:
                    n_samples = int(proportion[i] * total)
                    sum_samples += n_samples
                else:
                    n_samples = total - sum_samples
                print("Samples drawn for {}: {}".format(attacks[i], n_samples))

                if i == 0:
                    train_adv_images = resample(adv_train_data["adv"],
                                                n_samples=n_samples,
                                                random_state=random_state)
                    train_adv_labels = resample(adv_train_data["label"],
                                                n_samples=n_samples,
                                                random_state=random_state)
                    test_adv_images = resample(adv_test_data["adv"],
                                               n_samples=n_samples,
                                               random_state=random_state)
                    test_adv_labels = resample(adv_test_data["label"],
                                               n_samples=n_samples,
                                               random_state=random_state)
                else:
                    train_adv_images = np.concatenate(
                        (train_adv_images,
                         resample(adv_train_data["adv"],
                                  n_samples=n_samples,
                                  random_state=random_state)))
                    train_adv_labels = np.concatenate(
                        (train_adv_labels,
                         resample(adv_train_data["label"],
                                  n_samples=n_samples,
                                  random_state=random_state)))
                    test_adv_images = np.concatenate(
                        (test_adv_images,
                         resample(adv_test_data["adv"],
                                  n_samples=n_samples,
                                  random_state=random_state)))
                    test_adv_labels = np.concatenate(
                        (test_adv_labels,
                         resample(adv_test_data["label"],
                                  n_samples=n_samples,
                                  random_state=random_state)))

    else:
        raise ValueError("Unknown adversarial data")

    train_adv_set = list(zip(train_adv_images, train_adv_labels))

    if args.sample != 100:
        n = len(train_adv_set)
        n_sample = int(n * args.sample / 100)

        np.random.shuffle(train_adv_set)
        train_adv_set = train_adv_set[:n_sample]

    print("")
    print("Train Adv Attack Data: ", args.attack)
    print("Len: ", len(train_adv_set))
    print("")

    train_adv_batches = Batches(train_adv_set,
                                args.batch_size,
                                shuffle=shuffle,
                                set_random_choices=False,
                                num_workers=4)

    test_adv_set = list(zip(test_adv_images, test_adv_labels))

    test_adv_batches = Batches(test_adv_set,
                               args.batch_size,
                               shuffle=False,
                               num_workers=4)

    # Set perturbations
    epsilon = (args.epsilon / 255.)
    test_epsilon = (args.test_epsilon / 255.)
    pgd_alpha = (args.pgd_alpha / 255.)
    test_pgd_alpha = (args.test_pgd_alpha / 255.)

    # Set models
    model = None
    if args.model == "resnet18":
        model = resnet18(pretrained=True)
    elif args.model == "resnet20":
        model = resnet20()
    elif args.model == "vgg16bn":
        model = vgg16_bn(pretrained=True)
    elif args.model == "densenet121":
        model = densenet121(pretrained=True)
    elif args.model == "googlenet":
        model = googlenet(pretrained=True)
    elif args.model == "inceptionv3":
        model = inception_v3(pretrained=True)
    elif args.model == 'WideResNet':
        model = WideResNet(34,
                           10,
                           widen_factor=10,
                           dropRate=0.0,
                           normalize=args.use_FNandWN,
                           activation=args.activation,
                           softplus_beta=args.softplus_beta)
    elif args.model == 'WideResNet_20':
        model = WideResNet(34,
                           10,
                           widen_factor=20,
                           dropRate=0.0,
                           normalize=args.use_FNandWN,
                           activation=args.activation,
                           softplus_beta=args.softplus_beta)
    else:
        raise ValueError("Unknown model")

    # Set training hyperparameters
    if args.l2:
        decay, no_decay = [], []
        for name, param in model.named_parameters():
            if 'bn' not in name and 'bias' not in name:
                decay.append(param)
            else:
                no_decay.append(param)
        params = [{
            'params': decay,
            'weight_decay': args.l2
        }, {
            'params': no_decay,
            'weight_decay': 0
        }]
    else:
        params = model.parameters()
    if args.lr_schedule == 'cyclic':
        opt = torch.optim.Adam(params,
                               lr=args.lr_max,
                               betas=(0.9, 0.999),
                               eps=1e-08,
                               weight_decay=args.weight_decay)
    else:
        if args.optimizer == 'momentum':
            opt = torch.optim.SGD(params,
                                  lr=args.lr_max,
                                  momentum=0.9,
                                  weight_decay=args.weight_decay)
        elif args.optimizer == 'Nesterov':
            opt = torch.optim.SGD(params,
                                  lr=args.lr_max,
                                  momentum=0.9,
                                  weight_decay=args.weight_decay,
                                  nesterov=True)
        elif args.optimizer == 'SGD_GC':
            opt = SGD_GC(params,
                         lr=args.lr_max,
                         momentum=0.9,
                         weight_decay=args.weight_decay)
        elif args.optimizer == 'SGD_GCC':
            opt = SGD_GCC(params,
                          lr=args.lr_max,
                          momentum=0.9,
                          weight_decay=args.weight_decay)
        elif args.optimizer == 'Adam':
            opt = torch.optim.Adam(params,
                                   lr=args.lr_max,
                                   betas=(0.9, 0.999),
                                   eps=1e-08,
                                   weight_decay=args.weight_decay)
        elif args.optimizer == 'AdamW':
            opt = torch.optim.AdamW(params,
                                    lr=args.lr_max,
                                    betas=(0.9, 0.999),
                                    eps=1e-08,
                                    weight_decay=args.weight_decay)
        else:
            raise ValueError("Unknown optimizer")

    # Cross-entropy (mean)
    if args.labelsmooth:
        criterion = LabelSmoothingLoss(smoothing=args.labelsmoothvalue)
    else:
        criterion = nn.CrossEntropyLoss()

    # If we use freeAT or fastAT with previous init
    if args.attack == 'free':
        delta = torch.zeros(args.batch_size, 3, 32, 32).cuda()
        delta.requires_grad = True
    elif args.attack == 'fgsm' and args.fgsm_init == 'previous':
        delta = torch.zeros(args.batch_size, 3, 32, 32).cuda()
        delta.requires_grad = True

    if args.attack == 'free':
        epochs = int(math.ceil(args.epochs / args.attack_iters))
    else:
        epochs = args.epochs

    # Set lr schedule
    if args.lr_schedule == 'superconverge':
        lr_schedule = lambda t: np.interp([t], [
            0, args.epochs * 2 // 5, args.epochs
        ], [0, args.lr_max, 0])[0]
    elif args.lr_schedule == 'piecewise':

        def lr_schedule(t, warm_up_lr=args.warmup_lr):
            if t < 100:
                if warm_up_lr and t < args.warmup_lr_epoch:
                    return (t + 1.) / args.warmup_lr_epoch * args.lr_max
                else:
                    return args.lr_max
            if args.lrdecay == 'lineardecay':
                if t < 105:
                    return args.lr_max * 0.02 * (105 - t)
                else:
                    return 0.
            elif args.lrdecay == 'intenselr':
                if t < 102:
                    return args.lr_max / 10.
                else:
                    return args.lr_max / 100.
            elif args.lrdecay == 'looselr':
                if t < 150:
                    return args.lr_max / 10.
                else:
                    return args.lr_max / 100.
            elif args.lrdecay == 'base':
                if t < 105:
                    return args.lr_max / 10.
                else:
                    return args.lr_max / 100.
    elif args.lr_schedule == 'linear':
        lr_schedule = lambda t: np.interp([t], [
            0, args.epochs // 3, args.epochs * 2 // 3, args.epochs
        ], [args.lr_max, args.lr_max, args.lr_max / 10, args.lr_max / 100])[0]
    elif args.lr_schedule == 'onedrop':

        def lr_schedule(t):
            if t < args.lr_drop_epoch:
                return args.lr_max
            else:
                return args.lr_one_drop
    elif args.lr_schedule == 'multipledecay':

        def lr_schedule(t):
            return args.lr_max - (t //
                                  (args.epochs // 10)) * (args.lr_max / 10)
    elif args.lr_schedule == 'cosine':

        def lr_schedule(t):
            return args.lr_max * 0.5 * (1 + np.cos(t / args.epochs * np.pi))
    elif args.lr_schedule == 'cyclic':

        def lr_schedule(t, stepsize=18, min_lr=1e-5, max_lr=args.lr_max):

            # Scaler: we can adapt this if we do not want the triangular CLR
            scaler = lambda x: 1.

            # Additional function to see where on the cycle we are
            cycle = math.floor(1 + t / (2 * stepsize))
            x = abs(t / stepsize - 2 * cycle + 1)
            relative = max(0, (1 - x)) * scaler(cycle)

            return min_lr + (max_lr - min_lr) * relative

    #### Use stronger adversarial attacks when the lr is decayed ####
    def eps_alpha_schedule(
            t,
            warm_up_eps=args.warmup_eps,
            if_use_stronger_adv=args.use_stronger_adv,
            stronger_index=args.stronger_index):  # Schedule number 0
        if stronger_index == 0:
            epsilon_s = [epsilon * 1.5, epsilon * 2]
            pgd_alpha_s = [pgd_alpha, pgd_alpha]
        elif stronger_index == 1:
            epsilon_s = [epsilon * 1.5, epsilon * 2]
            pgd_alpha_s = [pgd_alpha * 1.25, pgd_alpha * 1.5]
        elif stronger_index == 2:
            epsilon_s = [epsilon * 2, epsilon * 2.5]
            pgd_alpha_s = [pgd_alpha * 1.5, pgd_alpha * 2]
        else:
            raise ValueError('Undefined stronger index')

        if if_use_stronger_adv:
            if t < 100:
                if t < args.warmup_eps_epoch and warm_up_eps:
                    return (
                        t + 1.
                    ) / args.warmup_eps_epoch * epsilon, pgd_alpha, args.restarts
                else:
                    return epsilon, pgd_alpha, args.restarts
            elif t < 105:
                return epsilon_s[0], pgd_alpha_s[0], args.restarts
            else:
                return epsilon_s[1], pgd_alpha_s[1], args.restarts
        else:
            if t < args.warmup_eps_epoch and warm_up_eps:
                return (
                    t + 1.
                ) / args.warmup_eps_epoch * epsilon, pgd_alpha, args.restarts
            else:
                return epsilon, pgd_alpha, args.restarts

    #### Set the counter for the early stop of PGD ####
    def early_stop_counter_schedule(t):
        if t < args.earlystopPGDepoch1:
            return 1
        elif t < args.earlystopPGDepoch2:
            return 2
        else:
            return 3

    best_test_adv_acc = 0
    best_val_robust_acc = 0
    if args.resume:
        start_epoch = args.resume
        model.load_state_dict(
            torch.load(os.path.join(args.fname, f'model_{start_epoch-1}.pth')))
        opt.load_state_dict(
            torch.load(os.path.join(args.fname, f'opt_{start_epoch-1}.pth')))
        logger.info(f'Resuming at epoch {start_epoch}')

        best_test_adv_acc = torch.load(
            os.path.join(args.fname, f'model_best.pth'))['test_adv_acc']
        if args.val:
            best_val_robust_acc = torch.load(
                os.path.join(args.fname, f'model_val.pth'))['val_robust_acc']
    else:
        start_epoch = 0

    if args.eval:
        if not args.resume:
            logger.info(
                "No model loaded to evaluate, specify with --resume FNAME")
            return
        logger.info("[Evaluation mode]")

    model.cuda()
    model.eval()

    # Evaluate on original test data
    test_acc = 0
    test_n = 0

    for i, batch in enumerate(test_batches):
        X, y = batch['input'], batch['target']

        clean_input = normalize(X)
        output = model(clean_input)

        test_acc += (output.max(1)[1] == y).sum().item()
        test_n += y.size(0)

    logger.info('Initial Accuracy on Original Test Data: %.4f (Test Acc)',
                test_acc / test_n)

    test_adv_acc = 0
    test_adv_n = 0

    for i, batch in enumerate(test_adv_batches):
        adv_input = normalize(batch['input'])
        y = batch['target']

        robust_output = model(adv_input)
        test_adv_acc += (robust_output.max(1)[1] == y).sum().item()
        test_adv_n += y.size(0)

    logger.info(
        'Initial Accuracy on Adversarial Test Data: %.4f (Test Robust Acc)',
        test_adv_acc / test_adv_n)

    model.train()

    logger.info(
        'Epoch \t Train Acc \t Train Robust Acc \t Test Acc \t Test Robust Acc'
    )

    # Records per epoch for savetxt
    train_loss_record = []
    train_acc_record = []
    train_robust_loss_record = []
    train_robust_acc_record = []
    train_grad_record = []

    test_loss_record = []
    test_acc_record = []
    test_adv_loss_record = []
    test_adv_acc_record = []
    test_grad_record = []

    for epoch in range(start_epoch, epochs):
        model.train()
        start_time = time.time()

        train_loss = 0
        train_acc = 0
        train_robust_loss = 0
        train_robust_acc = 0
        train_n = 0
        train_grad = 0

        record_iter = torch.tensor([])

        for i, (batch,
                adv_batch) in enumerate(zip(train_batches, train_adv_batches)):
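            # Clean and adversarial batches are drawn in lockstep: the robust loss
            # below uses the stored adversarial images, the clean loss the originals.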
            if args.eval:
                break
            X, y = batch['input'], batch['target']

            adv_input = normalize(adv_batch['input'])
            adv_y = adv_batch['target']
            adv_input.requires_grad = True
            robust_output = model(adv_input)

            # Training losses
            if args.mixup:
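                # NOTE: mixup needs y_a, y_b and lam from mixup_data(X, y, ...),
                # which this script never calls, so --mixup is not usable as-is.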
                clean_input = normalize(X)
                clean_input.requires_grad = True
                output = model(clean_input)
                robust_loss = mixup_criterion(criterion, robust_output, y_a,
                                              y_b, lam)

            elif args.mixture:
                clean_input = normalize(X)
                clean_input.requires_grad = True
                output = model(clean_input)
                robust_loss = args.mixture_alpha * criterion(
                    robust_output,
                    adv_y) + (1 - args.mixture_alpha) * criterion(output, y)

            else:
                clean_input = normalize(X)
                clean_input.requires_grad = True
                output = model(clean_input)
                if args.focalloss:
                    criterion_nonreduct = nn.CrossEntropyLoss(reduction='none')
                    robust_confidence = F.softmax(robust_output,
                                                  dim=1)[:, adv_y].detach()
                    robust_loss = (criterion_nonreduct(robust_output, adv_y) *
                                   ((1. - robust_confidence)**
                                    args.focallosslambda)).mean()

                elif args.use_DLRloss:
                    beta_ = 0.8 * epoch / args.epochs
                    robust_loss = (1. - beta_) * F.cross_entropy(
                        robust_output, adv_y) + beta_ * dlr_loss(
                            robust_output, adv_y)

                elif args.use_CWloss:
                    beta_ = 0.8 * epoch / args.epochs
                    robust_loss = (1. - beta_) * F.cross_entropy(
                        robust_output, adv_y) + beta_ * CW_loss(
                            robust_output, adv_y)

                elif args.use_FNandWN:
                    #print('use FN and WN with margin')
                    robust_loss = criterion(
                        args.s_FN * robust_output -
                        onehot_target_withmargin_HE, adv_y)

                else:
                    robust_loss = criterion(robust_output, adv_y)

            if args.l1:
                for name, param in model.named_parameters():
                    if 'bn' not in name and 'bias' not in name:
                        robust_loss += args.l1 * param.abs().sum()

            opt.zero_grad()
            robust_loss.backward()
            opt.step()

            clean_input = normalize(X)
            clean_input.requires_grad = True
            output = model(clean_input)
            if args.mixup:
                loss = mixup_criterion(criterion, output, y_a, y_b, lam)
            else:
                loss = criterion(output, y)

            # Get the gradient norm values
            input_grads = torch.autograd.grad(loss,
                                              clean_input,
                                              create_graph=False)[0]

            # Record the statistics for this batch
            train_robust_loss += robust_loss.item() * adv_y.size(0)
            train_robust_acc += (robust_output.max(1)[1] == adv_y).sum().item()
            train_loss += loss.item() * y.size(0)
            train_acc += (output.max(1)[1] == y).sum().item()
            train_n += y.size(0)
            train_grad += input_grads.abs().sum()

        train_time = time.time()
        if args.earlystopPGD:
            print('Iter mean: ',
                  record_iter.mean().item(), ' Iter std:  ',
                  record_iter.std().item())

        # Evaluate on test data
        model.eval()
        test_loss = 0
        test_acc = 0
        test_n = 0

        test_adv_loss = 0
        test_adv_acc = 0
        test_adv_n = 0

        for i, batch in enumerate(test_batches):
            X, y = batch['input'], batch['target']

            clean_input = normalize(X)
            output = model(clean_input)
            loss = criterion(output, y)

            test_loss += loss.item() * y.size(0)
            test_acc += (output.max(1)[1] == y).sum().item()
            test_n += y.size(0)

        for i, batch in enumerate(test_adv_batches):
            adv_input = normalize(batch['input'])
            y = batch['target']

            robust_output = model(adv_input)
            robust_loss = criterion(robust_output, y)

            test_adv_loss += robust_loss.item() * y.size(0)
            test_adv_acc += (robust_output.max(1)[1] == y).sum().item()
            test_adv_n += y.size(0)

        test_time = time.time()

        logger.info('%d \t %.4f \t %.4f \t\t %.4f \t %.4f', epoch + 1,
                    train_acc / train_n, train_robust_acc / train_n,
                    test_acc / test_n, test_adv_acc / test_adv_n)

        # Save results
        train_loss_record.append(train_loss / train_n)
        train_acc_record.append(train_acc / train_n)
        train_robust_loss_record.append(train_robust_loss / train_n)
        train_robust_acc_record.append(train_robust_acc / train_n)

        np.savetxt(args.fname + '/train_loss_record.txt',
                   np.array(train_loss_record))
        np.savetxt(args.fname + '/train_acc_record.txt',
                   np.array(train_acc_record))
        np.savetxt(args.fname + '/train_robust_loss_record.txt',
                   np.array(train_robust_loss_record))
        np.savetxt(args.fname + '/train_robust_acc_record.txt',
                   np.array(train_robust_acc_record))

        test_loss_record.append(test_loss / test_n)
        test_acc_record.append(test_acc / test_n)
        test_adv_loss_record.append(test_adv_loss / test_adv_n)
        test_adv_acc_record.append(test_adv_acc / test_adv_n)

        np.savetxt(args.fname + '/test_loss_record.txt',
                   np.array(test_loss_record))
        np.savetxt(args.fname + '/test_acc_record.txt',
                   np.array(test_acc_record))
        np.savetxt(args.fname + '/test_adv_loss_record.txt',
                   np.array(test_adv_loss_record))
        np.savetxt(args.fname + '/test_adv_acc_record.txt',
                   np.array(test_adv_acc_record))

        # save checkpoint
        if epoch > 99 or (epoch +
                          1) % args.chkpt_iters == 0 or epoch + 1 == epochs:
            torch.save(model.state_dict(),
                       os.path.join(args.fname, f'model_{epoch}.pth'))
            torch.save(opt.state_dict(),
                       os.path.join(args.fname, f'opt_{epoch}.pth'))

        # save best
        if test_adv_acc / test_adv_n > best_test_adv_acc:
            torch.save(
                {
                    'state_dict': model.state_dict(),
                    'test_adv_acc': test_adv_acc / test_adv_n,
                    'test_adv_loss': test_adv_loss / test_adv_n,
                    'test_loss': test_loss / test_n,
                    'test_acc': test_acc / test_n,
                }, os.path.join(args.fname, f'model_best.pth'))
            best_test_adv_acc = test_adv_acc / test_adv_n
Example #8
def main():
    args = get_args()
    if args.awp_gamma <= 0.0:
        args.awp_warmup = np.infty

    if not os.path.exists(args.fname):
        os.makedirs(args.fname)

    logger = logging.getLogger(__name__)
    logging.basicConfig(
        format='[%(asctime)s] - %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S',
        level=logging.DEBUG,
        handlers=[
            logging.FileHandler(os.path.join(args.fname, 'eval.log' if args.eval else 'output.log')),
            logging.StreamHandler()
        ])

    logger.info(args)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    transforms = [Crop(32, 32), FlipLR()]
    if args.cutout:
        transforms.append(Cutout(args.cutout_len, args.cutout_len))
    if args.val:
        try:
            dataset = torch.load("cifar10_validation_split.pth")
        except:
            print("Couldn't find a dataset with a validation split, did you run "
                  "generate_validation.py?")
            return
        val_set = list(zip(transpose(dataset['val']['data']/255.), dataset['val']['labels']))
        val_batches = Batches(val_set, args.batch_size, shuffle=False, num_workers=2)
    else:
        dataset = cifar10(args.data_dir)
    train_set = list(zip(transpose(pad(dataset['train']['data'], 4)/255.),
        dataset['train']['labels']))
    train_set_x = Transform(train_set, transforms)
    train_batches = Batches(train_set_x, args.batch_size, shuffle=True, set_random_choices=True, num_workers=2)

    test_set = list(zip(transpose(dataset['test']['data']/255.), dataset['test']['labels']))
    test_batches = Batches(test_set, args.batch_size_test, shuffle=False, num_workers=2)

    epsilon = (args.epsilon / 255.)
    pgd_alpha = (args.pgd_alpha / 255.)

    if args.model == 'PreActResNet18':
        model = PreActResNet18()
        proxy = PreActResNet18()
    elif args.model == 'WideResNet':
        model = WideResNet(34, 10, widen_factor=args.width_factor, dropRate=0.0)
        proxy = WideResNet(34, 10, widen_factor=args.width_factor, dropRate=0.0)
    else:
        raise ValueError("Unknown model")

    model = nn.DataParallel(model).cuda()
    proxy = nn.DataParallel(proxy).cuda()

    if args.l2:
        decay, no_decay = [], []
        for name,param in model.named_parameters():
            if 'bn' not in name and 'bias' not in name:
                decay.append(param)
            else:
                no_decay.append(param)
        params = [{'params':decay, 'weight_decay':args.l2},
                  {'params':no_decay, 'weight_decay': 0 }]
    else:
        params = model.parameters()

    opt = torch.optim.SGD(params, lr=args.lr_max, momentum=0.9, weight_decay=5e-4)
    proxy_opt = torch.optim.SGD(proxy.parameters(), lr=0.01)
    awp_adversary = AdvWeightPerturb(model=model, proxy=proxy, proxy_optim=proxy_opt, gamma=args.awp_gamma)
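    # AdvWeightPerturb (AWP): the proxy model is used to find an adversarial
    # perturbation of the weights, which is applied before the robust update
    # (calc_awp/perturb below) and undone afterwards (restore).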

    criterion = nn.CrossEntropyLoss()

    if args.attack == 'free':
        delta = torch.zeros(args.batch_size, 3, 32, 32).cuda()
        delta.requires_grad = True
    elif args.attack == 'fgsm' and args.fgsm_init == 'previous':
        delta = torch.zeros(args.batch_size, 3, 32, 32).cuda()
        delta.requires_grad = True

    if args.attack == 'free':
        epochs = int(math.ceil(args.epochs / args.attack_iters))
    else:
        epochs = args.epochs
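    # "Free" adversarial training replays each minibatch attack_iters times, so the
    # epoch count is scaled down to keep the total number of updates comparable.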

    if args.lr_schedule == 'superconverge':
        lr_schedule = lambda t: np.interp([t], [0, args.epochs * 2 // 5, args.epochs], [0, args.lr_max, 0])[0]
    elif args.lr_schedule == 'piecewise':
        def lr_schedule(t):
            if t / args.epochs < 0.5:
                return args.lr_max
            elif t / args.epochs < 0.75:
                return args.lr_max / 10.
            else:
                return args.lr_max / 100.
    elif args.lr_schedule == 'linear':
        lr_schedule = lambda t: np.interp([t], [0, args.epochs // 3, args.epochs * 2 // 3, args.epochs], [args.lr_max, args.lr_max, args.lr_max / 10, args.lr_max / 100])[0]
    elif args.lr_schedule == 'onedrop':
        def lr_schedule(t):
            if t < args.lr_drop_epoch:
                return args.lr_max
            else:
                return args.lr_one_drop
    elif args.lr_schedule == 'multipledecay':
        def lr_schedule(t):
            return args.lr_max - (t//(args.epochs//10))*(args.lr_max/10)
    elif args.lr_schedule == 'cosine':
        def lr_schedule(t):
            return args.lr_max * 0.5 * (1 + np.cos(t / args.epochs * np.pi))
    elif args.lr_schedule == 'cyclic':
        lr_schedule = lambda t: np.interp([t], [0, 0.4 * args.epochs, args.epochs], [0, args.lr_max, 0])[0]

    best_test_robust_acc = 0
    best_val_robust_acc = 0
    if args.resume:
        start_epoch = args.resume
        model.load_state_dict(torch.load(os.path.join(args.fname, f'model_{start_epoch-1}.pth')))
        opt.load_state_dict(torch.load(os.path.join(args.fname, f'opt_{start_epoch-1}.pth')))
        logger.info(f'Resuming at epoch {start_epoch}')

        if os.path.exists(os.path.join(args.fname, f'model_best.pth')):
            best_test_robust_acc = torch.load(os.path.join(args.fname, f'model_best.pth'))['test_robust_acc']
        if args.val:
            best_val_robust_acc = torch.load(os.path.join(args.fname, f'model_val.pth'))['val_robust_acc']
    else:
        start_epoch = 0

    if args.eval:
        if not args.resume:
            logger.info("No model loaded to evaluate, specify with --resume FNAME")
            return
        logger.info("[Evaluation mode]")

    logger.info('Epoch \t Train Time \t Test Time \t LR \t \t Train Loss \t Train Acc \t Train Robust Loss \t Train Robust Acc \t Test Loss \t Test Acc \t Test Robust Loss \t Test Robust Acc')
    for epoch in range(start_epoch, epochs):
        start_time = time.time()
        train_loss = 0
        train_acc = 0
        train_robust_loss = 0
        train_robust_acc = 0
        train_n = 0
        for i, batch in enumerate(train_batches):
            if args.eval:
                break
            X, y = batch['input'], batch['target']
            if args.mixup:
                X, y_a, y_b, lam = mixup_data(X, y, args.mixup_alpha)
                X, y_a, y_b = map(Variable, (X, y_a, y_b))
            lr = lr_schedule(epoch + (i + 1) / len(train_batches))
            opt.param_groups[0].update(lr=lr)
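            # (the schedule is evaluated at a fractional epoch index, so the LR is
            # updated every iteration rather than once per epoch)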

            if args.attack == 'pgd':
                # Random initialization
                if args.mixup:
                    delta = attack_pgd(model, X, y, epsilon, pgd_alpha, args.attack_iters, args.restarts, args.norm, mixup=True, y_a=y_a, y_b=y_b, lam=lam)
                else:
                    delta = attack_pgd(model, X, y, epsilon, pgd_alpha, args.attack_iters, args.restarts, args.norm)
                delta = delta.detach()
            elif args.attack == 'fgsm':
                delta = attack_pgd(model, X, y, epsilon, args.fgsm_alpha*epsilon, 1, 1, args.norm)
            # Standard training
            elif args.attack == 'none':
                delta = torch.zeros_like(X)
            X_adv = normalize(torch.clamp(X + delta[:X.size(0)], min=lower_limit, max=upper_limit))

            model.train()
            # calculate adversarial weight perturbation and perturb it
            if epoch >= args.awp_warmup:
                # not compatible with mixup currently.
                assert (not args.mixup)
                awp = awp_adversary.calc_awp(inputs_adv=X_adv,
                                             targets=y)
                awp_adversary.perturb(awp)

            robust_output = model(X_adv)
            if args.mixup:
                robust_loss = mixup_criterion(criterion, robust_output, y_a, y_b, lam)
            else:
                robust_loss = criterion(robust_output, y)

            if args.l1:
                for name,param in model.named_parameters():
                    if 'bn' not in name and 'bias' not in name:
                        robust_loss += args.l1*param.abs().sum()

            opt.zero_grad()
            robust_loss.backward()
            opt.step()

            if epoch >= args.awp_warmup:
                awp_adversary.restore(awp)

            output = model(normalize(X))
            if args.mixup:
                loss = mixup_criterion(criterion, output, y_a, y_b, lam)
            else:
                loss = criterion(output, y)

            train_robust_loss += robust_loss.item() * y.size(0)
            train_robust_acc += (robust_output.max(1)[1] == y).sum().item()
            train_loss += loss.item() * y.size(0)
            train_acc += (output.max(1)[1] == y).sum().item()
            train_n += y.size(0)

        train_time = time.time()

        model.eval()
        test_loss = 0
        test_acc = 0
        test_robust_loss = 0
        test_robust_acc = 0
        test_n = 0
        for i, batch in enumerate(test_batches):
            X, y = batch['input'], batch['target']

            # Random initialization
            if args.attack == 'none':
                delta = torch.zeros_like(X)
            else:
                delta = attack_pgd(model, X, y, epsilon, pgd_alpha, args.attack_iters_test, args.restarts, args.norm, early_stop=args.eval)
            delta = delta.detach()

            robust_output = model(normalize(torch.clamp(X + delta[:X.size(0)], min=lower_limit, max=upper_limit)))
            robust_loss = criterion(robust_output, y)

            output = model(normalize(X))
            loss = criterion(output, y)

            test_robust_loss += robust_loss.item() * y.size(0)
            test_robust_acc += (robust_output.max(1)[1] == y).sum().item()
            test_loss += loss.item() * y.size(0)
            test_acc += (output.max(1)[1] == y).sum().item()
            test_n += y.size(0)

        test_time = time.time()

        if args.val:
            val_loss = 0
            val_acc = 0
            val_robust_loss = 0
            val_robust_acc = 0
            val_n = 0
            for i, batch in enumerate(val_batches):
                X, y = batch['input'], batch['target']

                # Random initialization
                if args.attack == 'none':
                    delta = torch.zeros_like(X)
                else:
                    delta = attack_pgd(model, X, y, epsilon, pgd_alpha, args.attack_iters_test, args.restarts, args.norm, early_stop=args.eval)
                delta = delta.detach()

                robust_output = model(normalize(torch.clamp(X + delta[:X.size(0)], min=lower_limit, max=upper_limit)))
                robust_loss = criterion(robust_output, y)

                output = model(normalize(X))
                loss = criterion(output, y)

                val_robust_loss += robust_loss.item() * y.size(0)
                val_robust_acc += (robust_output.max(1)[1] == y).sum().item()
                val_loss += loss.item() * y.size(0)
                val_acc += (output.max(1)[1] == y).sum().item()
                val_n += y.size(0)

        if not args.eval:
            logger.info('%d \t %.1f \t \t %.1f \t \t %.4f \t %.4f \t %.4f \t %.4f \t \t %.4f \t \t %.4f \t %.4f \t %.4f \t \t %.4f',
                epoch, train_time - start_time, test_time - train_time, lr,
                train_loss/train_n, train_acc/train_n, train_robust_loss/train_n, train_robust_acc/train_n,
                test_loss/test_n, test_acc/test_n, test_robust_loss/test_n, test_robust_acc/test_n)

            if args.val:
                logger.info('validation %.4f \t %.4f \t %.4f \t %.4f',
                    val_loss/val_n, val_acc/val_n, val_robust_loss/val_n, val_robust_acc/val_n)

                if val_robust_acc/val_n > best_val_robust_acc:
                    torch.save({
                            'state_dict':model.state_dict(),
                            'test_robust_acc':test_robust_acc/test_n,
                            'test_robust_loss':test_robust_loss/test_n,
                            'test_loss':test_loss/test_n,
                            'test_acc':test_acc/test_n,
                            'val_robust_acc':val_robust_acc/val_n,
                            'val_robust_loss':val_robust_loss/val_n,
                            'val_loss':val_loss/val_n,
                            'val_acc':val_acc/val_n,
                        }, os.path.join(args.fname, f'model_val.pth'))
                    best_val_robust_acc = val_robust_acc/val_n

            # save checkpoint
            if (epoch+1) % args.chkpt_iters == 0 or epoch+1 == epochs:
                torch.save(model.state_dict(), os.path.join(args.fname, f'model_{epoch}.pth'))
                torch.save(opt.state_dict(), os.path.join(args.fname, f'opt_{epoch}.pth'))

            # save best
            if test_robust_acc/test_n > best_test_robust_acc:
                torch.save({
                        'state_dict':model.state_dict(),
                        'test_robust_acc':test_robust_acc/test_n,
                        'test_robust_loss':test_robust_loss/test_n,
                        'test_loss':test_loss/test_n,
                        'test_acc':test_acc/test_n,
                    }, os.path.join(args.fname, f'model_best.pth'))
                best_test_robust_acc = test_robust_acc/test_n
        else:
            logger.info('%d \t %.1f \t \t %.1f \t \t %.4f \t %.4f \t %.4f \t %.4f \t \t %.4f \t \t %.4f \t %.4f \t %.4f \t \t %.4f',
                epoch, train_time - start_time, test_time - train_time, -1,
                -1, -1, -1, -1,
                test_loss/test_n, test_acc/test_n, test_robust_loss/test_n, test_robust_acc/test_n)
            return
Example #9
def main():
    args = get_args()

    if not os.path.exists(args.fname):
        os.makedirs(args.fname)

    logger = logging.getLogger(__name__)
    logging.basicConfig(
        format='[%(asctime)s] - %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S',
        level=logging.DEBUG,
        handlers=[
            logging.FileHandler(os.path.join(args.fname, 'eval.log' if args.eval else 'output.log')),
            logging.StreamHandler()
        ])

    logger.info(args)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    train_transform = transforms.Compose([
        transforms.ToTensor(),
        # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    num_workers = 2
    train_dataset = datasets.SVHN(
        args.data_dir, split='train', transform=train_transform, download=True)
    test_dataset = datasets.SVHN(
        args.data_dir, split='test', transform=test_transform, download=True)
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=num_workers,
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        pin_memory=True,
        num_workers=num_workers,
    )

    epsilon = (args.epsilon / 255.)
    pgd_alpha = (args.pgd_alpha / 255.)

    # model = models_dict[args.architecture]().cuda()
    # model.apply(initialize_weights)
    if args.model == 'PreActResNet18':
        model = PreActResNet18()
    elif args.model == 'WideResNet':
        model = WideResNet(34, 10, widen_factor=args.width_factor, dropRate=0.0)
    else:
        raise ValueError("Unknown model")

    model = model.cuda()
    model.train()

    if args.l2:
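        # Apply weight decay only to conv/linear weights; BatchNorm and bias parameters go in the no-decay group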
        decay, no_decay = [], []
        for name,param in model.named_parameters():
            if 'bn' not in name and 'bias' not in name:
                decay.append(param)
            else:
                no_decay.append(param)
        params = [{'params':decay, 'weight_decay':args.l2},
                  {'params':no_decay, 'weight_decay': 0 }]
    else:
        params = model.parameters()

    opt = torch.optim.SGD(params, lr=args.lr_max, momentum=0.9, weight_decay=5e-4)

    criterion = nn.CrossEntropyLoss()

    epochs = args.epochs

    if args.lr_schedule == 'superconverge':
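        # One-cycle schedule: lr ramps linearly to lr_max at 40% of training, then back down to 0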
        lr_schedule = lambda t: np.interp([t], [0, args.epochs * 2 // 5, args.epochs], [0, args.lr_max, 0])[0]
        # lr_schedule = lambda t: np.interp([t], [0, args.epochs], [0, args.lr_max])[0]
    elif args.lr_schedule == 'piecewise':
        def lr_schedule(t):
            if t / args.epochs < 0.5:
                return args.lr_max
            elif t / args.epochs < 0.75:
                return args.lr_max / 10.
            else:
                return args.lr_max / 100.

    if args.resume:
        start_epoch = args.resume
        model.load_state_dict(torch.load(os.path.join(args.fname, f'model_{start_epoch-1}.pth')))
        opt.load_state_dict(torch.load(os.path.join(args.fname, f'opt_{start_epoch-1}.pth')))
        logger.info(f'Resuming at epoch {start_epoch}')
    else:
        start_epoch = 0


    if args.eval:
        if not args.resume:
            logger.info("No model loaded to evaluate, specify with --resume FNAME")
            return
        logger.info("[Evaluation mode]")

    logger.info('Epoch \t Train Time \t Test Time \t LR \t \t Train Loss \t Train Acc \t Train Robust Loss \t Train Robust Acc \t Test Loss \t Test Acc \t Test Robust Loss \t Test Robust Acc')
    for epoch in range(start_epoch, epochs):
        model.train()
        start_time = time.time()
        train_loss = 0
        train_acc = 0
        train_robust_loss = 0
        train_robust_acc = 0
        train_n = 0
        for i, (X, y) in enumerate(train_loader):
            if args.eval:
                break
            X, y = X.cuda(), y.cuda()
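            # Optional mixup: mix the inputs and keep both label sets (y_a, y_b) plus the mixing weight lam for mixup_criterion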
            if args.mixup:
                X, y_a, y_b, lam = mixup_data(X, y, args.mixup_alpha)
                X, y_a, y_b = map(Variable, (X, y_a, y_b))
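            # The schedule takes a fractional epoch, so the lr is refreshed every iteration; only param_groups[0] is updated, so with args.l2 the no-decay group keeps its initial lr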
            lr = lr_schedule(epoch + (i + 1) / len(train_loader))
            opt.param_groups[0].update(lr=lr)

            if args.attack == 'pgd':
                # Random initialization
                if args.mixup:
                    delta = attack_pgd(model, X, y, epsilon, pgd_alpha, args.attack_iters, args.restarts, args.norm, mixup=True, y_a=y_a, y_b=y_b, lam=lam)
                else:
                    delta = attack_pgd(model, X, y, epsilon, pgd_alpha, args.attack_iters, args.restarts, args.norm)
                delta = delta.detach()

            # Standard training
            elif args.attack == 'none':
                delta = torch.zeros_like(X)

            robust_output = model(normalize(torch.clamp(X + delta[:X.size(0)], min=lower_limit, max=upper_limit)))
            if args.mixup:
                robust_loss = mixup_criterion(criterion, robust_output, y_a, y_b, lam)
            else:
                robust_loss = criterion(robust_output, y)

            if args.l1:
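                # L1 penalty on conv/linear weights (BatchNorm and bias parameters excluded)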
                for name,param in model.named_parameters():
                    if 'bn' not in name and 'bias' not in name:
                        robust_loss += args.l1*param.abs().sum()

            opt.zero_grad()
            robust_loss.backward()
            opt.step()

            output = model(normalize(X))
            if args.mixup:
                loss = mixup_criterion(criterion, output, y_a, y_b, lam)
            else:
                loss = criterion(output, y)

            train_robust_loss += robust_loss.item() * y.size(0)
            train_robust_acc += (robust_output.max(1)[1] == y).sum().item()
            train_loss += loss.item() * y.size(0)
            train_acc += (output.max(1)[1] == y).sum().item()
            train_n += y.size(0)

        train_time = time.time()

        model.eval()
        test_loss = 0
        test_acc = 0
        test_robust_loss = 0
        test_robust_acc = 0
        test_n = 0
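        # Test-set evaluation; PGD here uses the training-strength attack_iters, with early stopping only in eval mode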
        for i, (X, y) in enumerate(test_loader):
            X, y = X.cuda(), y.cuda()

            # Perturbation for the robust pass: zero when the attack is 'none', PGD otherwise
            if args.attack == 'none':
                delta = torch.zeros_like(X)
            else:
                delta = attack_pgd(model, X, y, epsilon, pgd_alpha, args.attack_iters, args.restarts, args.norm, early_stop=args.eval)
            delta = delta.detach()

            robust_output = model(normalize(torch.clamp(X + delta[:X.size(0)], min=lower_limit, max=upper_limit)))
            robust_loss = criterion(robust_output, y)

            output = model(normalize(X))
            loss = criterion(output, y)

            test_robust_loss += robust_loss.item() * y.size(0)
            test_robust_acc += (robust_output.max(1)[1] == y).sum().item()
            test_loss += loss.item() * y.size(0)
            test_acc += (output.max(1)[1] == y).sum().item()
            test_n += y.size(0)

        test_time = time.time()
        if not args.eval: 
            logger.info('%d \t %.1f \t \t %.1f \t \t %.4f \t %.4f \t %.4f \t %.4f \t \t %.4f \t \t %.4f \t %.4f \t %.4f \t \t %.4f',
                epoch, train_time - start_time, test_time - train_time, lr,
                train_loss/train_n, train_acc/train_n, train_robust_loss/train_n, train_robust_acc/train_n,
                test_loss/test_n, test_acc/test_n, test_robust_loss/test_n, test_robust_acc/test_n)
            
            if (epoch+1) % args.chkpt_iters == 0 or epoch+1 == epochs:
                torch.save(model.state_dict(), os.path.join(args.fname, f'model_{epoch}.pth'))
                torch.save(opt.state_dict(), os.path.join(args.fname, f'opt_{epoch}.pth'))
        else:
            logger.info('%d \t %.1f \t \t %.1f \t \t %.4f \t %.4f \t %.4f \t %.4f \t \t %.4f \t \t %.4f \t %.4f \t %.4f \t \t %.4f',
                epoch, train_time - start_time, test_time - train_time, -1,
                -1, -1, -1, -1,
                test_loss/test_n, test_acc/test_n, test_robust_loss/test_n, test_robust_acc/test_n)
            return
Example #10
def main():
    args = get_args()
    if args.awp_gamma <= 0.0:
        args.awp_warmup = np.inf
    fname = None
    if args.sample == 100:
        fname = args.output_dir + "/default/" + args.attack + "/"
    else:
        fname = args.output_dir + "/" + str(
            args.sample) + "sample/" + args.attack + "/"

    if args.attack == "combine":
        fname += args.list + "/"

    if not os.path.exists(fname):
        os.makedirs(fname)

    logger = logging.getLogger(__name__)
    logging.basicConfig(
        format='[%(asctime)s] - %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S',
        level=logging.DEBUG,
        handlers=[
            logging.FileHandler(
                os.path.join(fname,
                             'eval.log' if args.eval else 'output.log')),
            logging.StreamHandler()
        ])

    logger.info(args)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    shuffle = False
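    # Kept False, presumably so the clean and precomputed adversarial loaders stay in step when zipped during training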

    # setup data loader
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])
    train_set = torchvision.datasets.CIFAR10(root='../data',
                                             train=True,
                                             download=True,
                                             transform=transform_train)
    test_set = torchvision.datasets.CIFAR10(root='../data',
                                            train=False,
                                            download=True,
                                            transform=transform_test)

    if args.attack == "all":
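        # Tile the clean training data once per attack (11 attacks) so it matches the size of the concatenated adversarial sets loaded below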
        train_data = np.array(train_set.data) / 255.
        train_data = transpose(train_data).astype(np.float32)

        train_labels = np.array(train_set.targets)

        oversampled_train_data = np.tile(train_data, (11, 1, 1, 1))
        oversampled_train_labels = np.tile(train_labels, (11))

        train_set = list(
            zip(torch.from_numpy(oversampled_train_data),
                torch.from_numpy(oversampled_train_labels)))

    elif args.attack == "combine":
        train_data = np.array(train_set.data) / 255.
        train_data = transpose(train_data).astype(np.float32)

        train_labels = np.array(train_set.targets)

        oversampled_train_data = train_data.copy()
        oversampled_train_labels = train_labels.copy()

        logger.info("Attacks")
        attacks = args.list.split("_")
        logger.info(attacks)

        oversampled_train_data = np.tile(train_data, (len(attacks), 1, 1, 1))
        oversampled_train_labels = np.tile(train_labels, (len(attacks)))

        train_set = list(
            zip(torch.from_numpy(oversampled_train_data),
                torch.from_numpy(oversampled_train_labels)))
    else:
        train_data = np.array(train_set.data) / 255.
        train_data = transpose(train_data).astype(np.float32)

        train_labels = np.array(train_set.targets)

        train_set = list(
            zip(torch.from_numpy(train_data), torch.from_numpy(train_labels)))

    test_data = np.array(test_set.data) / 255.
    test_data = transpose(test_data).astype(np.float32)
    test_labels = np.array(test_set.targets)

    test_set = list(
        zip(torch.from_numpy(test_data), torch.from_numpy(test_labels)))

    if args.sample != 100:
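        # Keep only args.sample percent of the clean training examples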
        n = len(train_set)
        n_sample = int(n * args.sample / 100)

        np.random.shuffle(train_set)
        train_set = train_set[:n_sample]

    train_batches = Batches(train_set,
                            args.batch_size,
                            shuffle=shuffle,
                            set_random_choices=False,
                            num_workers=2)

    test_batches = Batches(test_set,
                           args.batch_size,
                           shuffle=False,
                           num_workers=0)

    train_adv_images = None
    train_adv_labels = None
    test_adv_images = None
    test_adv_labels = None

    adv_dir = "adv_examples/{}/".format(args.attack)
    train_path = adv_dir + "train.pth"
    test_path = adv_dir + "test.pth"

    if args.attack in TOOLBOX_ADV_ATTACK_LIST:
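        # Toolbox-style attacks store dicts with 'adv' images and 'label' arrays for both splits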
        adv_train_data = torch.load(train_path)
        train_adv_images = adv_train_data["adv"]
        train_adv_labels = adv_train_data["label"]
        adv_test_data = torch.load(test_path)
        test_adv_images = adv_test_data["adv"]
        test_adv_labels = adv_test_data["label"]
    elif args.attack in ["ffgsm", "mifgsm", "tpgd"]:
        adv_data = {}
        adv_data["adv"], adv_data["label"] = torch.load(train_path)
        train_adv_images = adv_data["adv"].numpy()
        train_adv_labels = adv_data["label"].numpy()
        adv_data = {}
        adv_data["adv"], adv_data["label"] = torch.load(test_path)
        test_adv_images = adv_data["adv"].numpy()
        test_adv_labels = adv_data["label"].numpy()
    elif args.attack == "combine":
        print("Attacks")
        attacks = args.list.split("_")
        print(attacks)

        for i in range(len(attacks)):
            _adv_dir = "adv_examples/{}/".format(attacks[i])
            train_path = _adv_dir + "train.pth"
            test_path = _adv_dir + "test.pth"

            adv_train_data = torch.load(train_path)
            adv_test_data = torch.load(test_path)

            if i == 0:
                train_adv_images = adv_train_data["adv"]
                train_adv_labels = adv_train_data["label"]
                test_adv_images = adv_test_data["adv"]
                test_adv_labels = adv_test_data["label"]
            else:
                train_adv_images = np.concatenate(
                    (train_adv_images, adv_train_data["adv"]))
                train_adv_labels = np.concatenate(
                    (train_adv_labels, adv_train_data["label"]))
                test_adv_images = np.concatenate(
                    (test_adv_images, adv_test_data["adv"]))
                test_adv_labels = np.concatenate(
                    (test_adv_labels, adv_test_data["label"]))

    elif args.attack == "all":
        print("Loading attacks")
        ATTACK_LIST = [
            "autoattack", "autopgd", "bim", "cw", "deepfool", "fgsm",
            "newtonfool", "pgd", "pixelattack", "spatialtransformation",
            "squareattack"
        ]
        for i in range(len(ATTACK_LIST)):
            print("Attack: ", ATTACK_LIST[i])
            _adv_dir = "adv_examples/{}/".format(ATTACK_LIST[i])
            train_path = _adv_dir + "train.pth"
            test_path = _adv_dir + "test.pth"

            adv_train_data = torch.load(train_path)
            adv_test_data = torch.load(test_path)

            if i == 0:
                train_adv_images = adv_train_data["adv"]
                train_adv_labels = adv_train_data["label"]
                test_adv_images = adv_test_data["adv"]
                test_adv_labels = adv_test_data["label"]
            else:
                train_adv_images = np.concatenate(
                    (train_adv_images, adv_train_data["adv"]))
                train_adv_labels = np.concatenate(
                    (train_adv_labels, adv_train_data["label"]))
                test_adv_images = np.concatenate(
                    (test_adv_images, adv_test_data["adv"]))
                test_adv_labels = np.concatenate(
                    (test_adv_labels, adv_test_data["label"]))
    else:
        raise ValueError("Unknown adversarial data")

    print("")
    print("Train Adv Attack Data: ", args.attack)
    print("Dataset shape: ", train_adv_images.shape)
    print("Dataset type: ", type(train_adv_images))
    print("Label shape: ", len(train_adv_labels))
    print("")

    train_adv_set = list(zip(train_adv_images, train_adv_labels))

    if args.sample != 100:
        n = len(train_adv_set)
        n_sample = int(n * args.sample / 100)

        np.random.shuffle(train_adv_set)
        train_adv_set = train_adv_set[:n_sample]

    train_adv_batches = Batches(train_adv_set,
                                args.batch_size,
                                shuffle=shuffle,
                                set_random_choices=False,
                                num_workers=0)

    test_adv_set = list(zip(test_adv_images, test_adv_labels))

    test_adv_batches = Batches(test_adv_set,
                               args.batch_size,
                               shuffle=False,
                               num_workers=0)

    epsilon = (args.epsilon / 255.)
    pgd_alpha = (args.pgd_alpha / 255.)

    if args.model == "ResNet18":
        model = resnet18(pretrained=True)
        proxy = resnet18(pretrained=True)
    elif args.model == 'PreActResNet18':
        model = PreActResNet18()
        proxy = PreActResNet18()
    elif args.model == 'WideResNet':
        model = WideResNet(34,
                           10,
                           widen_factor=args.width_factor,
                           dropRate=0.0)
        proxy = WideResNet(34,
                           10,
                           widen_factor=args.width_factor,
                           dropRate=0.0)
    else:
        raise ValueError("Unknown model")

    model = nn.DataParallel(model).cuda()
    proxy = nn.DataParallel(proxy).cuda()

    if args.l2:
        decay, no_decay = [], []
        for name, param in model.named_parameters():
            if 'bn' not in name and 'bias' not in name:
                decay.append(param)
            else:
                no_decay.append(param)
        params = [{
            'params': decay,
            'weight_decay': args.l2
        }, {
            'params': no_decay,
            'weight_decay': 0
        }]
    else:
        params = model.parameters()

    opt = torch.optim.SGD(params,
                          lr=args.lr_max,
                          momentum=0.9,
                          weight_decay=5e-4)
    proxy_opt = torch.optim.SGD(proxy.parameters(), lr=0.01)
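    # AWP: the proxy model/optimizer compute an adversarial weight perturbation that is applied before each gradient step and restored afterwards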
    awp_adversary = AdvWeightPerturb(model=model,
                                     proxy=proxy,
                                     proxy_optim=proxy_opt,
                                     gamma=args.awp_gamma)

    criterion = nn.CrossEntropyLoss()

    if args.attack == 'free':
        delta = torch.zeros(args.batch_size, 3, 32, 32).cuda()
        delta.requires_grad = True
    elif args.attack == 'fgsm' and args.fgsm_init == 'previous':
        delta = torch.zeros(args.batch_size, 3, 32, 32).cuda()
        delta.requires_grad = True

    if args.attack == 'free':
        epochs = int(math.ceil(args.epochs / args.attack_iters))
    else:
        epochs = args.epochs

    if args.lr_schedule == 'superconverge':
        lr_schedule = lambda t: np.interp([t], [
            0, args.epochs * 2 // 5, args.epochs
        ], [0, args.lr_max, 0])[0]
    elif args.lr_schedule == 'piecewise':

        def lr_schedule(t):
            if t / args.epochs < 0.5:
                return args.lr_max
            elif t / args.epochs < 0.75:
                return args.lr_max / 10.
            else:
                return args.lr_max / 100.
    elif args.lr_schedule == 'linear':
        lr_schedule = lambda t: np.interp([t], [
            0, args.epochs // 3, args.epochs * 2 // 3, args.epochs
        ], [args.lr_max, args.lr_max, args.lr_max / 10, args.lr_max / 100])[0]
    elif args.lr_schedule == 'onedrop':

        def lr_schedule(t):
            if t < args.lr_drop_epoch:
                return args.lr_max
            else:
                return args.lr_one_drop
    elif args.lr_schedule == 'multipledecay':

        def lr_schedule(t):
            return args.lr_max - (t //
                                  (args.epochs // 10)) * (args.lr_max / 10)
    elif args.lr_schedule == 'cosine':

        def lr_schedule(t):
            return args.lr_max * 0.5 * (1 + np.cos(t / args.epochs * np.pi))
    elif args.lr_schedule == 'cyclic':
        lr_schedule = lambda t: np.interp(
            [t], [0, 0.4 * args.epochs, args.epochs], [0, args.lr_max, 0])[0]

    best_test_robust_acc = 0
    best_val_robust_acc = 0
    if args.resume:
        start_epoch = args.resume
        model.load_state_dict(
            torch.load(os.path.join(fname, f'model_{start_epoch-1}.pth')))
        opt.load_state_dict(
            torch.load(os.path.join(fname, f'opt_{start_epoch-1}.pth')))
        logger.info(f'Resuming at epoch {start_epoch}')

        if os.path.exists(os.path.join(fname, f'model_best.pth')):
            best_test_robust_acc = torch.load(
                os.path.join(fname, f'model_best.pth'))['test_robust_acc']
        if args.val:
            best_val_robust_acc = torch.load(
                os.path.join(args.output_dir,
                             f'model_val.pth'))['val_robust_acc']
    else:
        start_epoch = 0

    if args.eval:
        if not args.resume:
            logger.info(
                "No model loaded to evaluate, specify with --resume FNAME")
            return
        logger.info("[Evaluation mode]")

    model.cuda()
    model.eval()

    # Evaluate on original test data
    test_acc = 0
    test_n = 0

    for i, batch in enumerate(test_batches):
        X, y = batch['input'], batch['target']

        clean_input = normalize(X)
        output = model(clean_input)

        test_acc += (output.max(1)[1] == y).sum().item()
        test_n += y.size(0)

    logger.info('Initial Accuracy on Original Test Data: %.4f (Test Acc)',
                test_acc / test_n)

    test_adv_acc = 0
    test_adv_n = 0

    for i, batch in enumerate(test_adv_batches):
        adv_input = normalize(batch['input'])
        y = batch['target']

        robust_output = model(adv_input)
        test_adv_acc += (robust_output.max(1)[1] == y).sum().item()
        test_adv_n += y.size(0)

    logger.info(
        'Initial Accuracy on Adversarial Test Data: %.4f (Test Robust Acc)',
        test_adv_acc / test_adv_n)

    model.train()

    logger.info(
        'Epoch \t Train Acc \t Train Robust Acc \t Test Acc \t Test Robust Acc'
    )
    for epoch in range(start_epoch, epochs):
        start_time = time.time()
        train_loss = 0
        train_acc = 0
        train_robust_loss = 0
        train_robust_acc = 0
        train_n = 0
        for i, (batch,
                adv_batch) in enumerate(zip(train_batches, train_adv_batches)):
            if args.eval:
                break
            X, y = batch['input'], batch['target']
            if args.mixup:
                X, y_a, y_b, lam = mixup_data(X, y, args.mixup_alpha)
                X, y_a, y_b = map(Variable, (X, y_a, y_b))
            lr = lr_schedule(epoch + (i + 1) / len(train_batches))
            opt.param_groups[0].update(lr=lr)

            X_adv = normalize(adv_batch["input"])
            y_adv = adv_batch["target"]

            model.train()
            # calculate the adversarial weight perturbation and apply it to the model
            if epoch >= args.awp_warmup:
                # not compatible with mixup currently.
                assert (not args.mixup)
                awp = awp_adversary.calc_awp(inputs_adv=X_adv, targets=y_adv)
                awp_adversary.perturb(awp)

            robust_output = model(X_adv)
            if args.mixup:
                robust_loss = mixup_criterion(criterion, robust_output, y_a,
                                              y_b, lam)
            else:
                robust_loss = criterion(robust_output, y_adv)

            if args.l1:
                for name, param in model.named_parameters():
                    if 'bn' not in name and 'bias' not in name:
                        robust_loss += args.l1 * param.abs().sum()

            opt.zero_grad()
            robust_loss.backward()
            opt.step()

            if epoch >= args.awp_warmup:
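                # Undo the weight perturbation so later forward passes use the unperturbed weights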
                awp_adversary.restore(awp)

            output = model(normalize(X))
            if args.mixup:
                loss = mixup_criterion(criterion, output, y_a, y_b, lam)
            else:
                loss = criterion(output, y)

            train_robust_loss += robust_loss.item() * y_adv.size(0)
            train_robust_acc += (robust_output.max(1)[1] == y_adv).sum().item()
            train_loss += loss.item() * y.size(0)
            train_acc += (output.max(1)[1] == y).sum().item()
            train_n += y.size(0)

        train_time = time.time()

        model.eval()
        test_loss = 0
        test_acc = 0
        test_n = 0

        test_robust_loss = 0
        test_robust_acc = 0
        test_robust_n = 0

        for i, batch in enumerate(test_batches):
            X, y = normalize(batch['input']), batch['target']

            output = model(X)
            loss = criterion(output, y)

            test_loss += loss.item() * y.size(0)
            test_acc += (output.max(1)[1] == y).sum().item()
            test_n += y.size(0)

        for i, adv_batch in enumerate(test_adv_batches):
            X_adv, y_adv = normalize(adv_batch["input"]), adv_batch["target"]

            robust_output = model(X_adv)
            robust_loss = criterion(robust_output, y_adv)

            test_robust_loss += robust_loss.item() * y_adv.size(0)
            test_robust_acc += (robust_output.max(1)[1] == y_adv).sum().item()
            test_robust_n += y_adv.size(0)

        test_time = time.time()

        logger.info('%d \t %.3f \t\t %.3f \t\t\t %.3f \t\t %.3f', epoch,
                    train_acc / train_n, train_robust_acc / train_n,
                    test_acc / test_n, test_robust_acc / test_robust_n)

        # save checkpoint
        if (epoch + 1) % args.chkpt_iters == 0 or epoch + 1 == epochs:
            torch.save(model.state_dict(),
                       os.path.join(fname, f'model_{epoch}.pth'))
            torch.save(opt.state_dict(), os.path.join(fname,
                                                      f'opt_{epoch}.pth'))

        # save best
        if test_robust_acc / test_robust_n > best_test_robust_acc:
            torch.save(
                {
                    'state_dict': model.state_dict(),
                    'test_robust_acc': test_robust_acc / test_robust_n,
                    'test_robust_loss': test_robust_loss / test_robust_n,
                    'test_loss': test_loss / test_n,
                    'test_acc': test_acc / test_n,
                }, os.path.join(fname, f'model_best.pth'))
            best_test_robust_acc = test_robust_acc / test_robust_n
Example #11
def main():
    args = get_args()
    #     names = get_auto_fname(args)
    #     args.fname = args.fname + names

    args.fname += args.train_adversarial + "/"

    if not os.path.exists(args.fname):
        os.makedirs(args.fname)


#     eval_dir = args.fname + '/eval/' + args.test_adversarial + "/"

    eval_dir = args.fname + "eval/"
    if args.model_epoch != -1:
        eval_dir += str(args.model_epoch) + "/"
    else:
        eval_dir += "best/"
    eval_dir += args.test_adversarial + "/"

    if not os.path.exists(eval_dir):
        print("Make dirs: ", eval_dir)
        os.makedirs(eval_dir)

    logger = logging.getLogger(__name__)
    logging.basicConfig(format='[%(asctime)s] - %(message)s',
                        datefmt='%Y/%m/%d %H:%M:%S',
                        level=logging.DEBUG,
                        handlers=[
                            logging.FileHandler(
                                os.path.join(eval_dir, "output.log")),
                            logging.StreamHandler()
                        ])

    logger.info(args)

    # Set seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # Prepare data
    dataset = cifar10(args.data_dir)

    x_test = (dataset['test']['data'] / 255.)
    x_test = transpose(x_test).astype(np.float32)
    y_test = dataset['test']['labels']

    test_set = list(zip(x_test, y_test))
    test_batches = Batches(test_set,
                           args.batch_size,
                           shuffle=False,
                           num_workers=4)

    print("Train Adv Attack Data: ", args.train_adversarial)

    adv_dir = "adv_examples/{}/".format(args.test_adversarial)
    train_path = adv_dir + "train.pth"
    test_path = adv_dir + "test.pth"

    ATTACK_LIST = [
        "autoattack", "autopgd", "bim", "cw", "deepfool", "fgsm", "newtonfool",
        "pgd", "pixelattack", "spatialtransformation", "squareattack"
    ]

    if args.test_adversarial in TOOLBOX_ADV_ATTACK_LIST:
        adv_train_data = torch.load(train_path)
        test_adv_images_on_train = adv_train_data["adv"]
        test_adv_labels_on_train = adv_train_data["label"]
        adv_test_data = torch.load(test_path)
        test_adv_images_on_test = adv_test_data["adv"]
        test_adv_labels_on_test = adv_test_data["label"]
    elif args.test_adversarial == "all":
        for i in range(len(ATTACK_LIST)):
            _adv_dir = "adv_examples/{}/".format(ATTACK_LIST[i])
            train_path = _adv_dir + "train.pth"
            test_path = _adv_dir + "test.pth"

            adv_train_data = torch.load(train_path)
            adv_test_data = torch.load(test_path)

            if i == 0:
                train_adv_images = adv_train_data["adv"]
                train_adv_labels = adv_train_data["label"]
                test_adv_images = adv_test_data["adv"]
                test_adv_labels = adv_test_data["label"]
            else:
                #                 print(train_adv_images.shape)
                #                 print(adv_train_data["adv"].shape)
                train_adv_images = np.concatenate(
                    (train_adv_images, adv_train_data["adv"]))
                train_adv_labels = np.concatenate(
                    (train_adv_labels, adv_train_data["label"]))
                test_adv_images = np.concatenate(
                    (test_adv_images, adv_test_data["adv"]))
                test_adv_labels = np.concatenate(
                    (test_adv_labels, adv_test_data["label"]))

            test_adv_images_on_train = train_adv_images
            test_adv_labels_on_train = train_adv_labels
            test_adv_images_on_test = test_adv_images
            test_adv_labels_on_test = test_adv_labels

    elif args.test_adversarial in ["ffgsm", "mifgsm", "tpgd"]:
        adv_data = {}
        adv_data["adv"], adv_data["label"] = torch.load(train_path)
        test_adv_images_on_train = adv_data["adv"].numpy()
        test_adv_labels_on_train = adv_data["label"].numpy()
        adv_data = {}
        adv_data["adv"], adv_data["label"] = torch.load(test_path)
        test_adv_images_on_test = adv_data["adv"].numpy()
        test_adv_labels_on_test = adv_data["label"].numpy()
    else:
        raise ValueError("Unknown adversarial data")

    print("Test Adv Attack Data: ", args.test_adversarial)

    test_adv_on_train_set = list(
        zip(test_adv_images_on_train, test_adv_labels_on_train))

    test_adv_on_train_batches = Batches(test_adv_on_train_set,
                                        args.batch_size,
                                        shuffle=False,
                                        set_random_choices=False,
                                        num_workers=4)

    test_adv_on_test_set = list(
        zip(test_adv_images_on_test, test_adv_labels_on_test))

    test_adv_on_test_batches = Batches(test_adv_on_test_set,
                                       args.batch_size,
                                       shuffle=False,
                                       num_workers=4)

    # Set perturbations
    epsilon = (args.epsilon / 255.)
    test_epsilon = (args.test_epsilon / 255.)
    pgd_alpha = (args.pgd_alpha / 255.)
    test_pgd_alpha = (args.test_pgd_alpha / 255.)

    # Set models
    model = None
    if args.model == "resnet18":
        model = resnet18(pretrained=True)
    elif args.model == "resnet20":
        model = resnet20()
    elif args.model == "vgg16bn":
        model = vgg16_bn(pretrained=True)
    elif args.model == "densenet121":
        model = densenet121(pretrained=True)
    elif args.model == "googlenet":
        model = googlenet(pretrained=True)
    elif args.model == "inceptionv3":
        model = inception_v3(pretrained=True)
    elif args.model == 'WideResNet':
        model = WideResNet(34,
                           10,
                           widen_factor=10,
                           dropRate=0.0,
                           normalize=args.use_FNandWN,
                           activation=args.activation,
                           softplus_beta=args.softplus_beta)
    elif args.model == 'WideResNet_20':
        model = WideResNet(34,
                           10,
                           widen_factor=20,
                           dropRate=0.0,
                           normalize=args.use_FNandWN,
                           activation=args.activation,
                           softplus_beta=args.softplus_beta)
    else:
        raise ValueError("Unknown model")

    model.cuda()
    model.train()

    # Set training hyperparameters
    if args.l2:
        decay, no_decay = [], []
        for name, param in model.named_parameters():
            if 'bn' not in name and 'bias' not in name:
                decay.append(param)
            else:
                no_decay.append(param)
        params = [{
            'params': decay,
            'weight_decay': args.l2
        }, {
            'params': no_decay,
            'weight_decay': 0
        }]
    else:
        params = model.parameters()
    if args.lr_schedule == 'cyclic':
        opt = torch.optim.Adam(params,
                               lr=args.lr_max,
                               betas=(0.9, 0.999),
                               eps=1e-08,
                               weight_decay=args.weight_decay)
    else:
        if args.optimizer == 'momentum':
            opt = torch.optim.SGD(params,
                                  lr=args.lr_max,
                                  momentum=0.9,
                                  weight_decay=args.weight_decay)
        elif args.optimizer == 'Nesterov':
            opt = torch.optim.SGD(params,
                                  lr=args.lr_max,
                                  momentum=0.9,
                                  weight_decay=args.weight_decay,
                                  nesterov=True)
        elif args.optimizer == 'SGD_GC':
            opt = SGD_GC(params,
                         lr=args.lr_max,
                         momentum=0.9,
                         weight_decay=args.weight_decay)
        elif args.optimizer == 'SGD_GCC':
            opt = SGD_GCC(params,
                          lr=args.lr_max,
                          momentum=0.9,
                          weight_decay=args.weight_decay)
        elif args.optimizer == 'Adam':
            opt = torch.optim.Adam(params,
                                   lr=args.lr_max,
                                   betas=(0.9, 0.999),
                                   eps=1e-08,
                                   weight_decay=args.weight_decay)
        elif args.optimizer == 'AdamW':
            opt = torch.optim.AdamW(params,
                                    lr=args.lr_max,
                                    betas=(0.9, 0.999),
                                    eps=1e-08,
                                    weight_decay=args.weight_decay)

    # Cross-entropy (mean)
    if args.labelsmooth:
        criterion = LabelSmoothingLoss(smoothing=args.labelsmoothvalue)
    else:
        criterion = nn.CrossEntropyLoss()

    # If we use freeAT or fastAT with previous init
    if args.train_adversarial == 'free':
        delta = torch.zeros(args.batch_size, 3, 32, 32).cuda()
        delta.requires_grad = True
    elif args.train_adversarial == 'fgsm' and args.fgsm_init == 'previous':
        delta = torch.zeros(args.batch_size, 3, 32, 32).cuda()
        delta.requires_grad = True

    if args.train_adversarial == 'free':
        epochs = int(math.ceil(args.epochs / args.train_adversarial_iters))
    else:
        epochs = args.epochs

    # Set lr schedule
    if args.lr_schedule == 'superconverge':
        lr_schedule = lambda t: np.interp([t], [
            0, args.epochs * 2 // 5, args.epochs
        ], [0, args.lr_max, 0])[0]
    elif args.lr_schedule == 'piecewise':

        def lr_schedule(t, warm_up_lr=args.warmup_lr):
            if t < 100:
                if warm_up_lr and t < args.warmup_lr_epoch:
                    return (t + 1.) / args.warmup_lr_epoch * args.lr_max
                else:
                    return args.lr_max
            if args.lrdecay == 'lineardecay':
                if t < 105:
                    return args.lr_max * 0.02 * (105 - t)
                else:
                    return 0.
            elif args.lrdecay == 'intenselr':
                if t < 102:
                    return args.lr_max / 10.
                else:
                    return args.lr_max / 100.
            elif args.lrdecay == 'looselr':
                if t < 150:
                    return args.lr_max / 10.
                else:
                    return args.lr_max / 100.
            elif args.lrdecay == 'base':
                if t < 105:
                    return args.lr_max / 10.
                else:
                    return args.lr_max / 100.
    elif args.lr_schedule == 'linear':
        lr_schedule = lambda t: np.interp([t], [
            0, args.epochs // 3, args.epochs * 2 // 3, args.epochs
        ], [args.lr_max, args.lr_max, args.lr_max / 10, args.lr_max / 100])[0]
    elif args.lr_schedule == 'onedrop':

        def lr_schedule(t):
            if t < args.lr_drop_epoch:
                return args.lr_max
            else:
                return args.lr_one_drop
    elif args.lr_schedule == 'multipledecay':

        def lr_schedule(t):
            return args.lr_max - (t //
                                  (args.epochs // 10)) * (args.lr_max / 10)
    elif args.lr_schedule == 'cosine':

        def lr_schedule(t):
            return args.lr_max * 0.5 * (1 + np.cos(t / args.epochs * np.pi))
    elif args.lr_schedule == 'cyclic':

        def lr_schedule(t, stepsize=18, min_lr=1e-5, max_lr=args.lr_max):

            # Scaler: we can adapt this if we do not want the triangular CLR
            scaler = lambda x: 1.

            # Additional function to see where on the cycle we are
            cycle = math.floor(1 + t / (2 * stepsize))
            x = abs(t / stepsize - 2 * cycle + 1)
            relative = max(0, (1 - x)) * scaler(cycle)

            return min_lr + (max_lr - min_lr) * relative

    #### Set stronger adv attacks when decaying the lr ####
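    # Returns (epsilon, pgd_alpha, restarts) for epoch t: linear epsilon warmup early on and, when use_stronger_adv is set, larger epsilon/alpha after epochs 100 and 105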
    def eps_alpha_schedule(
            t,
            warm_up_eps=args.warmup_eps,
            if_use_stronger_adv=args.use_stronger_adv,
            stronger_index=args.stronger_index):  # Schedule number 0
        if stronger_index == 0:
            epsilon_s = [epsilon * 1.5, epsilon * 2]
            pgd_alpha_s = [pgd_alpha, pgd_alpha]
        elif stronger_index == 1:
            epsilon_s = [epsilon * 1.5, epsilon * 2]
            pgd_alpha_s = [pgd_alpha * 1.25, pgd_alpha * 1.5]
        elif stronger_index == 2:
            epsilon_s = [epsilon * 2, epsilon * 2.5]
            pgd_alpha_s = [pgd_alpha * 1.5, pgd_alpha * 2]
        else:
            print('Undefined stronger index')

        if if_use_stronger_adv:
            if t < 100:
                if t < args.warmup_eps_epoch and warm_up_eps:
                    return (
                        t + 1.
                    ) / args.warmup_eps_epoch * epsilon, pgd_alpha, args.restarts
                else:
                    return epsilon, pgd_alpha, args.restarts
            elif t < 105:
                return epsilon_s[0], pgd_alpha_s[0], args.restarts
            else:
                return epsilon_s[1], pgd_alpha_s[1], args.restarts
        else:
            if t < args.warmup_eps_epoch and warm_up_eps:
                return (
                    t + 1.
                ) / args.warmup_eps_epoch * epsilon, pgd_alpha, args.restarts
            else:
                return epsilon, pgd_alpha, args.restarts

    #### Set the counter for the early stop of PGD ####
    def early_stop_counter_schedule(t):
        if t < args.earlystopPGDepoch1:
            return 1
        elif t < args.earlystopPGDepoch2:
            return 2
        else:
            return 3

    best_test_adv_acc = 0
    best_val_robust_acc = 0
    if args.resume:
        start_epoch = args.resume
        model.load_state_dict(
            torch.load(os.path.join(args.fname, f'model_{start_epoch-1}.pth')))
        opt.load_state_dict(
            torch.load(os.path.join(args.fname, f'opt_{start_epoch-1}.pth')))
        logger.info(f'Resuming at epoch {start_epoch}')

        best_test_adv_acc = torch.load(
            os.path.join(args.fname, f'model_best.pth'))['test_adv_acc']
        if args.val:
            best_val_robust_acc = torch.load(
                os.path.join(args.fname, f'model_val.pth'))['val_robust_acc']
    else:
        start_epoch = 0

    if args.model_epoch == -1:
        if args.train_adversarial == "original":
            logger.info(f'Run using the original model')
        else:
            logger.info(f'Run using the best model')
            model.load_state_dict(
                torch.load(os.path.join(args.fname,
                                        f'model_best.pth'))["state_dict"])
    else:
        model.load_state_dict(
            torch.load(
                os.path.join(args.fname,
                             'model_' + str(args.model_epoch) + '.pth')))

    if args.eval:
        logger.info("[Evaluation mode]")

    # Evaluate on test data
    model.eval()
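    # Three passes: the clean test set, precomputed adversarial examples crafted from the test set, and precomputed adversarial examples crafted from the training set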

    test_loss = 0
    test_acc = 0
    test_n = 0
    y_original = np.array([])
    y_original_pred = np.array([])

    test_adv_test_loss = 0
    test_adv_test_acc = 0
    test_adv_test_n = 0
    y_adv = np.array([])
    y_adv_pred = np.array([])

    test_adv_train_loss = 0
    test_adv_train_acc = 0
    test_adv_train_n = 0

    for i, batch in enumerate(test_batches):
        X, y = batch['input'], batch['target']

        clean_input = normalize(X)
        output = model(clean_input)
        loss = criterion(output, y)

        test_loss += loss.item() * y.size(0)
        test_acc += (output.max(1)[1] == y).sum().item()
        test_n += y.size(0)

        y_original = np.append(y_original, y.cpu().numpy())
        y_original_pred = np.append(y_original_pred,
                                    output.max(1)[1].cpu().numpy())

    for i, batch in enumerate(test_adv_on_test_batches):
        adv_input = normalize(batch['input'])
        y = batch['target']

        cross_robust_output = model(adv_input)
        cross_robust_loss = criterion(cross_robust_output, y)

        test_adv_test_loss += cross_robust_loss.item() * y.size(0)
        test_adv_test_acc += (cross_robust_output.max(1)[1] == y).sum().item()
        test_adv_test_n += y.size(0)

        y_adv = np.append(y_adv, y.cpu().numpy())
        y_adv_pred = np.append(y_adv_pred,
                               cross_robust_output.max(1)[1].cpu().numpy())

    for i, batch in enumerate(test_adv_on_train_batches):
        adv_input = normalize(batch['input'])
        y = batch['target']

        cross_robust_output = model(adv_input)
        cross_robust_loss = criterion(cross_robust_output, y)

        test_adv_train_loss += cross_robust_loss.item() * y.size(0)
        test_adv_train_acc += (cross_robust_output.max(1)[1] == y).sum().item()
        test_adv_train_n += y.size(0)

    test_time = time.time()

    logger.info(
        "Test Acc \tTest Robust Acc on Test \tTest Robust Acc on Train")
    logger.info('%.4f \t\t %.4f \t\t %.4f', test_acc / test_n,
                test_adv_test_acc / test_adv_test_n,
                test_adv_train_acc / test_adv_train_n)

    y_original = y_original.astype(int)
    y_original_pred = y_original_pred.astype(int)

    y_adv = y_adv.astype(int)
    y_adv_pred = y_adv_pred.astype(int)
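    # Cast labels/predictions to integer class indices, then dump them to text files for offline analysis (e.g. confusion matrices)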

    logger.info("Y_original")
    logger.info(y_original)
    np.savetxt(os.path.join(eval_dir, "Y_original.txt"), y_original, fmt='%i')

    logger.info("Y_original_pred")
    logger.info(y_original_pred)
    np.savetxt(os.path.join(eval_dir, "Y_original_pred.txt"),
               y_original_pred,
               fmt='%i')

    logger.info("Y_adv")
    logger.info(y_adv)
    np.savetxt(os.path.join(eval_dir, "Y_adv.txt"), y_adv, fmt='%i')

    logger.info("Y_adv_pred")
    logger.info(y_adv_pred)
    np.savetxt(os.path.join(eval_dir, "Y_adv_pred.txt"), y_adv_pred, fmt='%i')