Example #1
def build_model():
    if args.model_type == 'wide':
        model = WideResNet(args.layers,
                           10 if args.dataset == 'cifar10' else 100,
                           args.widen_factor,
                           dropRate=args.droprate)
    else:
        model = ResNet32(10 if args.dataset == 'cifar10' else 100)
    # weights_init(model)

    # print('Number of model parameters: {}'.format(
    #     sum([p.data.nelement() for p in model.params()])))

    if torch.cuda.is_available():
        model.cuda()
        torch.backends.cudnn.benchmark = True

    return model
def main():
    args = get_args()

    if not os.path.exists(args.fname):
        os.makedirs(args.fname)

    logger = logging.getLogger(__name__)
    logging.basicConfig(
        format='[%(asctime)s] - %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S',
        level=logging.DEBUG,
        handlers=[
            logging.FileHandler(
                os.path.join(args.fname,
                             'eval.log' if args.eval else 'output.log')),
            logging.StreamHandler()
        ])

    logger.info(args)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    transforms = [Crop(32, 32), FlipLR()]
    # transforms = [Crop(32, 32)]
    if args.cutout:
        transforms.append(Cutout(args.cutout_len, args.cutout_len))
    if args.val:
        try:
            dataset = torch.load("cifar10_validation_split.pth")
        except FileNotFoundError:
            print(
                "Couldn't find a dataset with a validation split, did you run "
                "generate_validation.py?")
            return
        val_set = list(
            zip(transpose(dataset['val']['data'] / 255.),
                dataset['val']['labels']))
        val_batches = Batches(val_set,
                              args.batch_size,
                              shuffle=False,
                              num_workers=2)
    else:
        dataset = cifar10(args.data_dir)
    train_set = list(
        zip(transpose(pad(dataset['train']['data'], 4) / 255.),
            dataset['train']['labels']))
    train_set_x = Transform(train_set, transforms)
    train_batches = Batches(train_set_x,
                            args.batch_size,
                            shuffle=True,
                            set_random_choices=True,
                            num_workers=2)

    test_set = list(
        zip(transpose(dataset['test']['data'] / 255.),
            dataset['test']['labels']))
    test_batches = Batches(test_set,
                           args.batch_size,
                           shuffle=False,
                           num_workers=2)

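    # Scale the attack budgets and PGD step sizes from pixel units (0-255) to the [0, 1] input range.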
    trn_epsilon = (args.trn_epsilon / 255.)
    trn_pgd_alpha = (args.trn_pgd_alpha / 255.)
    tst_epsilon = (args.tst_epsilon / 255.)
    tst_pgd_alpha = (args.tst_pgd_alpha / 255.)

    if args.model == 'PreActResNet18':
        model = PreActResNet18()
    elif args.model == 'WideResNet':
        model = WideResNet(34,
                           10,
                           widen_factor=args.width_factor,
                           dropRate=0.0)
    elif args.model == 'DenseNet121':
        model = DenseNet121()
    elif args.model == 'ResNet18':
        model = ResNet18()
    else:
        raise ValueError("Unknown model")

    ### temp testing ###
    model = model.cuda()
    # model = nn.DataParallel(model).cuda()
    model.train()

    ##################################
    # load pretrained model if needed
    if args.trn_adv_models != 'None':
        if args.trn_adv_arch == 'PreActResNet18':
            trn_adv_model = PreActResNet18()
        elif args.trn_adv_arch == 'WideResNet':
            trn_adv_model = WideResNet(34,
                                       10,
                                       widen_factor=args.width_factor,
                                       dropRate=0.0)
        elif args.trn_adv_arch == 'DenseNet121':
            trn_adv_model = DenseNet121()
        elif args.trn_adv_arch == 'ResNet18':
            trn_adv_model = ResNet18()
        trn_adv_model = nn.DataParallel(trn_adv_model).cuda()
        trn_adv_model.load_state_dict(
            torch.load(
                os.path.join('./adv_models', args.trn_adv_models,
                             'model_best.pth'))['state_dict'])
        logger.info(f'loaded adv_model: {args.trn_adv_models}')
    else:
        trn_adv_model = None

    if args.tst_adv_models != 'None':
        if args.tst_adv_arch == 'PreActResNet18':
            tst_adv_model = PreActResNet18()
        elif args.tst_adv_arch == 'WideResNet':
            tst_adv_model = WideResNet(34,
                                       10,
                                       widen_factor=args.width_factor,
                                       dropRate=0.0)
        elif args.tst_adv_arch == 'DenseNet121':
            tst_adv_model = DenseNet121()
        elif args.tst_adv_arch == 'ResNet18':
            tst_adv_model = ResNet18()
        ### temp testing ###
        tst_adv_model = tst_adv_model.cuda()
        tst_adv_model.load_state_dict(
            torch.load(
                os.path.join('./adv_models', args.tst_adv_models,
                             'model_best.pth')))
        # tst_adv_model = nn.DataParallel(tst_adv_model).cuda()
        # tst_adv_model.load_state_dict(torch.load(os.path.join('./adv_models',args.tst_adv_models, 'model_best.pth'))['state_dict'])
        logger.info(f'loaded adv_model: {args.tst_adv_models}')
    else:
        tst_adv_model = None
    ##################################

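    # With --l2, apply weight decay only to parameters that are neither BatchNorm nor bias terms.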
    if args.l2:
        decay, no_decay = [], []
        for name, param in model.named_parameters():
            if 'bn' not in name and 'bias' not in name:
                decay.append(param)
            else:
                no_decay.append(param)
        params = [{
            'params': decay,
            'weight_decay': args.l2
        }, {
            'params': no_decay,
            'weight_decay': 0
        }]
    else:
        params = model.parameters()

    opt = torch.optim.SGD(params,
                          lr=args.lr_max,
                          momentum=0.9,
                          weight_decay=5e-4)

    criterion = nn.CrossEntropyLoss()

    if args.trn_attack == 'free' or (args.trn_attack == 'fgsm'
                                     and args.trn_fgsm_init == 'previous'):
        delta = torch.zeros(args.batch_size, 3, 32, 32).cuda()
        delta.requires_grad = True

    if args.trn_attack == 'free':
        epochs = int(math.ceil(args.epochs / args.trn_attack_iters))
    else:
        epochs = args.epochs

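    # Select the learning-rate schedule; each option maps a (fractional) epoch t to a learning rate.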
    if args.lr_schedule == 'superconverge':
        lr_schedule = lambda t: np.interp([t], [
            0, args.epochs * 2 // 5, args.epochs
        ], [0, args.lr_max, 0])[0]
    elif args.lr_schedule == 'piecewise':

        def lr_schedule(t):
            if t / args.epochs < 0.5:
                return args.lr_max
            elif t / args.epochs < 0.75:
                return args.lr_max / 10.
            else:
                return args.lr_max / 100.
    elif args.lr_schedule == 'linear':
        lr_schedule = lambda t: np.interp([t], [
            0, args.epochs // 3, args.epochs * 2 // 3, args.epochs
        ], [args.lr_max, args.lr_max, args.lr_max / 10, args.lr_max / 100])[0]
    elif args.lr_schedule == 'onedrop':

        def lr_schedule(t):
            if t < args.lr_drop_epoch:
                return args.lr_max
            else:
                return args.lr_one_drop
    elif args.lr_schedule == 'multipledecay':

        def lr_schedule(t):
            return args.lr_max - (t //
                                  (args.epochs // 10)) * (args.lr_max / 10)
    elif args.lr_schedule == 'cosine':

        def lr_schedule(t):
            return args.lr_max * 0.5 * (1 + np.cos(t / args.epochs * np.pi))
    else:
        raise ValueError("Unknown lr_schedule")

    best_test_robust_acc = 0
    best_val_robust_acc = 0
    if args.resume:
        ### temp testing ###
        model.load_state_dict(
            torch.load(os.path.join(args.fname, 'model_best.pth')))
        start_epoch = args.resume
        # model.load_state_dict(torch.load(os.path.join(args.fname, f'model_{start_epoch-1}.pth')))
        # opt.load_state_dict(torch.load(os.path.join(args.fname, f'opt_{start_epoch-1}.pth')))
        # logger.info(f'Resuming at epoch {start_epoch}')

        # best_test_robust_acc = torch.load(os.path.join(args.fname, f'model_best.pth'))['test_robust_acc']
        if args.val:
            best_val_robust_acc = torch.load(
                os.path.join(args.fname, f'model_val.pth'))['val_robust_acc']
    else:
        start_epoch = 0

    if args.eval:
        if not args.resume:
            logger.info(
                "No model loaded to evaluate, specify with --resume FNAME")
            return
        logger.info("[Evaluation mode]")

    logger.info(
        'Epoch \t Train Time \t Test Time \t LR \t \t Train Loss \t Train Acc \t Train Robust Loss \t Train Robust Acc \t Test Loss \t Test Acc \t Test Robust Loss \t Test Robust Acc'
    )
    for epoch in range(start_epoch, epochs):
        model.train()
        start_time = time.time()
        train_loss = 0
        train_acc = 0
        train_robust_loss = 0
        train_robust_acc = 0
        train_n = 0
        for i, batch in enumerate(train_batches):
            if args.eval:
                break
            X, y = batch['input'], batch['target']
            if args.mixup:
                X, y_a, y_b, lam = mixup_data(X, y, args.mixup_alpha)
                X, y_a, y_b = map(Variable, (X, y_a, y_b))
            lr = lr_schedule(epoch + (i + 1) / len(train_batches))
            opt.param_groups[0].update(lr=lr)

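            # Build the adversarial perturbation delta for this training batch with the configured attack.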
            if args.trn_attack == 'pgd':
                # Random initialization
                if args.mixup:
                    delta = attack_pgd(model,
                                       X,
                                       y,
                                       trn_epsilon,
                                       trn_pgd_alpha,
                                       args.trn_attack_iters,
                                       args.trn_restarts,
                                       args.trn_norm,
                                       mixup=True,
                                       y_a=y_a,
                                       y_b=y_b,
                                       lam=lam,
                                       adv_models=trn_adv_model)
                else:
                    delta = attack_pgd(model,
                                       X,
                                       y,
                                       trn_epsilon,
                                       trn_pgd_alpha,
                                       args.trn_attack_iters,
                                       args.trn_restarts,
                                       args.trn_norm,
                                       adv_models=trn_adv_model)
                delta = delta.detach()
            elif args.trn_attack == 'fgsm':
                delta = attack_pgd(model,
                                   X,
                                   y,
                                   trn_epsilon,
                                   args.trn_fgsm_alpha * trn_epsilon,
                                   1,
                                   1,
                                   args.trn_norm,
                                   adv_models=trn_adv_model,
                                   rand_init=args.trn_fgsm_init)
                delta = delta.detach()
            # Standard training
            elif args.trn_attack == 'none':
                delta = torch.zeros_like(X)
            # The Momentum Iterative Attack
            elif args.trn_attack == 'tmim':
                if trn_adv_model is None:
                    adversary = MomentumIterativeAttack(
                        model,
                        nb_iter=args.trn_attack_iters,
                        eps=trn_epsilon,
                        loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                        eps_iter=trn_pgd_alpha,
                        clip_min=0,
                        clip_max=1,
                        targeted=False)
                else:
                    # Use a temporary wrapper so the normalization layer is not
                    # stacked onto trn_adv_model again on every batch.
                    tmp_model = nn.Sequential(
                        NormalizeByChannelMeanStd(CIFAR10_MEAN, CIFAR10_STD),
                        trn_adv_model)

                    adversary = MomentumIterativeAttack(
                        tmp_model,
                        nb_iter=args.trn_attack_iters,
                        eps=trn_epsilon,
                        loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                        eps_iter=trn_pgd_alpha,
                        clip_min=0,
                        clip_max=1,
                        targeted=False)
                data_adv = adversary.perturb(X, y)
                delta = data_adv - X
                delta = delta.detach()

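            # Adversarial forward pass: clamp the perturbed inputs to the valid range before normalizing.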
            robust_output = model(
                normalize(
                    torch.clamp(X + delta[:X.size(0)],
                                min=lower_limit,
                                max=upper_limit)))
            if args.mixup:
                robust_loss = mixup_criterion(criterion, robust_output, y_a,
                                              y_b, lam)
            else:
                robust_loss = criterion(robust_output, y)

            if args.l1:
                for name, param in model.named_parameters():
                    if 'bn' not in name and 'bias' not in name:
                        robust_loss += args.l1 * param.abs().sum()

            opt.zero_grad()
            robust_loss.backward()
            opt.step()

            output = model(normalize(X))
            if args.mixup:
                loss = mixup_criterion(criterion, output, y_a, y_b, lam)
            else:
                loss = criterion(output, y)

            train_robust_loss += robust_loss.item() * y.size(0)
            train_robust_acc += (robust_output.max(1)[1] == y).sum().item()
            train_loss += loss.item() * y.size(0)
            train_acc += (output.max(1)[1] == y).sum().item()
            train_n += y.size(0)

        train_time = time.time()

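        # Evaluate clean and robust loss/accuracy on the test set.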
        model.eval()
        test_loss = 0
        test_acc = 0
        test_robust_loss = 0
        test_robust_acc = 0
        test_n = 0
        for i, batch in enumerate(test_batches):
            X, y = batch['input'], batch['target']

            # Craft the test-time perturbation with the configured attack
            if args.tst_attack == 'none':
                delta = torch.zeros_like(X)
            elif args.tst_attack == 'pgd':
                delta = attack_pgd(model,
                                   X,
                                   y,
                                   tst_epsilon,
                                   tst_pgd_alpha,
                                   args.tst_attack_iters,
                                   args.tst_restarts,
                                   args.tst_norm,
                                   adv_models=tst_adv_model,
                                   rand_init=args.tst_fgsm_init)
            elif args.tst_attack == 'fgsm':
                delta = attack_pgd(model,
                                   X,
                                   y,
                                   tst_epsilon,
                                   tst_epsilon,
                                   1,
                                   1,
                                   args.tst_norm,
                                   rand_init=args.tst_fgsm_init,
                                   adv_models=tst_adv_model)
            # The Momentum Iterative Attack
            elif args.tst_attack == 'tmim':
                if tst_adv_model is None:
                    adversary = MomentumIterativeAttack(
                        model,
                        nb_iter=args.tst_attack_iters,
                        eps=tst_epsilon,
                        loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                        eps_iter=tst_pgd_alpha,
                        clip_min=0,
                        clip_max=1,
                        targeted=False)
                else:
                    tmp_model = nn.Sequential(
                        NormalizeByChannelMeanStd(cifar10_mean, cifar10_std),
                        tst_adv_model).to(device)

                    adversary = MomentumIterativeAttack(
                        tmp_model,
                        nb_iter=args.tst_attack_iters,
                        eps=tst_epsilon,
                        loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                        eps_iter=tst_pgd_alpha,
                        clip_min=0,
                        clip_max=1,
                        targeted=False)
                data_adv = adversary.perturb(X, y)
                delta = data_adv - X
            # elif args.tst_attack == 'pgd':
            #     if tst_adv_model is None:
            #         tmp_model = nn.Sequential(NormalizeByChannelMeanStd(cifar10_mean, cifar10_std), model).to(device)

            #         adversary = PGDAttack(tmp_model, nb_iter=args.tst_attack_iters,
            #                         eps = tst_epsilon,
            #                         loss_fn=nn.CrossEntropyLoss(reduction="sum"),
            #                         eps_iter=tst_pgd_alpha, clip_min = 0, clip_max = 1, targeted=False)
            #     else:
            #         tmp_model = nn.Sequential(NormalizeByChannelMeanStd(cifar10_mean, cifar10_std), tst_adv_model).to(device)

            #         adversary = PGDAttack(tmp_model, nb_iter=args.tst_attack_iters,
            #                         eps = tst_epsilon,
            #                         loss_fn=nn.CrossEntropyLoss(reduction="sum"),
            #                         eps_iter=tst_pgd_alpha, clip_min = 0, clip_max = 1, targeted=False)
            #     data_adv = adversary.perturb(X, y)
            #     delta = data_adv - X

            delta = delta.detach()

            robust_output = model(
                normalize(
                    torch.clamp(X + delta[:X.size(0)],
                                min=lower_limit,
                                max=upper_limit)))
            robust_loss = criterion(robust_output, y)

            output = model(normalize(X))
            loss = criterion(output, y)

            test_robust_loss += robust_loss.item() * y.size(0)
            test_robust_acc += (robust_output.max(1)[1] == y).sum().item()
            test_loss += loss.item() * y.size(0)
            test_acc += (output.max(1)[1] == y).sum().item()
            test_n += y.size(0)

        test_time = time.time()

        if args.val:
            val_loss = 0
            val_acc = 0
            val_robust_loss = 0
            val_robust_acc = 0
            val_n = 0
            for i, batch in enumerate(val_batches):
                X, y = batch['input'], batch['target']

                # Craft the validation-time perturbation with the configured attack
                if args.tst_attack == 'none':
                    delta = torch.zeros_like(X)
                elif args.tst_attack == 'pgd':
                    delta = attack_pgd(model,
                                       X,
                                       y,
                                       tst_epsilon,
                                       tst_pgd_alpha,
                                       args.tst_attack_iters,
                                       args.tst_restarts,
                                       args.tst_norm,
                                       early_stop=args.eval)
                elif args.tst_attack == 'fgsm':
                    delta = attack_pgd(model,
                                       X,
                                       y,
                                       tst_epsilon,
                                       tst_epsilon,
                                       1,
                                       1,
                                       args.tst_norm,
                                       early_stop=args.eval,
                                       rand_init=args.tst_fgsm_init)

                delta = delta.detach()

                robust_output = model(
                    normalize(
                        torch.clamp(X + delta[:X.size(0)],
                                    min=lower_limit,
                                    max=upper_limit)))
                robust_loss = criterion(robust_output, y)

                output = model(normalize(X))
                loss = criterion(output, y)

                val_robust_loss += robust_loss.item() * y.size(0)
                val_robust_acc += (robust_output.max(1)[1] == y).sum().item()
                val_loss += loss.item() * y.size(0)
                val_acc += (output.max(1)[1] == y).sum().item()
                val_n += y.size(0)

        if not args.eval:
            logger.info(
                '%d \t %.1f \t \t %.1f \t \t %.4f \t %.4f \t %.4f \t %.4f \t \t %.4f \t \t %.4f \t %.4f \t %.4f \t \t %.4f',
                epoch, train_time - start_time, test_time - train_time, lr,
                train_loss / train_n, train_acc / train_n,
                train_robust_loss / train_n, train_robust_acc / train_n,
                test_loss / test_n, test_acc / test_n,
                test_robust_loss / test_n, test_robust_acc / test_n)

            if args.val:
                logger.info('validation %.4f \t %.4f \t %.4f \t %.4f',
                            val_loss / val_n, val_acc / val_n,
                            val_robust_loss / val_n, val_robust_acc / val_n)

                if val_robust_acc / val_n > best_val_robust_acc:
                    torch.save(
                        {
                            'state_dict': model.state_dict(),
                            'test_robust_acc': test_robust_acc / test_n,
                            'test_robust_loss': test_robust_loss / test_n,
                            'test_loss': test_loss / test_n,
                            'test_acc': test_acc / test_n,
                            'val_robust_acc': val_robust_acc / val_n,
                            'val_robust_loss': val_robust_loss / val_n,
                            'val_loss': val_loss / val_n,
                            'val_acc': val_acc / val_n,
                        }, os.path.join(args.fname, f'model_val.pth'))
                    best_val_robust_acc = val_robust_acc / val_n

            # save checkpoint
            if (epoch + 1) % args.chkpt_iters == 0 or epoch + 1 == epochs:
                torch.save(model.state_dict(),
                           os.path.join(args.fname, f'model_{epoch}.pth'))
                torch.save(opt.state_dict(),
                           os.path.join(args.fname, f'opt_{epoch}.pth'))

            # save best
            if test_robust_acc / test_n > best_test_robust_acc:
                torch.save(
                    {
                        'state_dict': model.state_dict(),
                        'test_robust_acc': test_robust_acc / test_n,
                        'test_robust_loss': test_robust_loss / test_n,
                        'test_loss': test_loss / test_n,
                        'test_acc': test_acc / test_n,
                    }, os.path.join(args.fname, f'model_best.pth'))
                best_test_robust_acc = test_robust_acc / test_n
        else:
            logger.info(
                '%d \t %.1f \t \t %.1f \t \t %.4f \t %.4f \t %.4f \t %.4f \t \t %.4f \t \t %.4f \t %.4f \t %.4f \t \t %.4f',
                epoch, train_time - start_time, test_time - train_time, -1, -1,
                -1, -1, -1, test_loss / test_n, test_acc / test_n,
                test_robust_loss / test_n, test_robust_acc / test_n)
            return
Example #3
def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)
    # create model
    if args.arch == 'wrn':
        num_classes = 1000
        size = 64
        config = [args.depth, args.k, args.drop, num_classes, size]
        model = WideResNet(config)
    elif args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(
                (args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [50, 65, 75],
                                                     gamma=0.1)
    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    # normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
    #                                  std=[0.229, 0.224, 0.225])
    normalize = transforms.Normalize(mean=[0.482, 0.458, 0.408],
                                     std=[0.269, 0.261, 0.276])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

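    # With distributed training, a DistributedSampler gives each process a distinct shard of the data.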
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        valdir, transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion, args)
        return

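    # Train one epoch at a time, evaluate on the validation set, and checkpoint the best top-1 accuracy.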
    begin = time.time()
    for epoch in trange(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        # adjust_learning_rate(optimizer, epoch, args)
        scheduler.step()  # note: PyTorch >= 1.1 expects scheduler.step() after optimizer.step()

        epoch_begin = time.time()
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, args)
        epoch_end = time.time()

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)
        print("epoch %d, time %.2f" % (epoch, epoch_end - epoch_begin))

        if not args.multiprocessing_distributed or (
                args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            fname = "%s.pth.tar" % args.arch
            if args.arch == "wrn":
                fname = "%s-%d-%s.pth.tar" % (args.arch, args.depth, str(
                    args.k))
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer': optimizer.state_dict(),
                },
                is_best,
                filename=fname)
    print("num epoch %d, time %.2f" % (args.epochs, epoch_end - begin))
Example #4
def main(args):
    if torch.cuda.is_available():
        print('Utilizing GPU')
        # torch.cuda.set_device(args.gpu_num)
    train_loader, val_loader = load_data(args)
    # create model
    if args.dataset == 'image_net':
        model = alexnet()
        top_k = (1, 5)
        val_len = len(val_loader.dataset.imgs)
    else:
        model = WideResNet(args.layers,
                           10 if args.dataset == 'cifar10' else 100,
                           args.widen_factor, dropRate=args.droprate)
        top_k = (1,)
        val_len = len(val_loader.dataset.test_labels)

    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))
    # for training on multiple GPUs.
    model = torch.nn.DataParallel(model).cuda()

    server = ParameterServer.get_server(args.optimizer, model, args)
    val_statistics = Statistics.get_statistics('image_classification', args)
    train_statistics = Statistics.get_statistics('image_classification', args)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    # ghost batch normalization (128 as baseline)
    repeat = args.batch_size // 128 if args.gbn == 1 else 1
    total_iterations = args.iterations_per_epoch + val_len // args.batch_size
    if args.bar is True:
        train_bar = IncrementalBar('Training  ', max=args.iterations_per_epoch, suffix='%(percent)d%%')
        val_bar = IncrementalBar('Evaluating', max=total_iterations, suffix='%(percent)d%%')
    else:
        train_bar = None
        val_bar = None

    print(
        '{}: Training neural network for {} epochs with {} workers'.format(args.sim_num, args.epochs, args.workers_num))
    train_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train(train_loader, model, criterion, server, epoch, args.workers_num, args.grad_clip, repeat, train_bar)
        train_time = time.time() - train_time
        if args.bar is True:
            train_bar.finish()
            train_bar.index = 0

        # evaluate on validation set
        val_time = time.time()
        val_loss, val_error = validate(val_loader, model, criterion, server, val_statistics, top_k, val_bar)
        train_loss, train_error = validate(train_loader, model, criterion, server, train_statistics, top_k, val_bar,
                                           save_norm=True)
        val_time = time.time() - val_time
        if args.bar is True:
            val_bar.finish()
            val_bar.index = 0
        print('Epoch [{0:1d}]: Train: Time [{1:.2f}], Loss [{2:.3f}], Error[{3:.3f}] |'
              ' Test: Time [{4:.2f}], Loss [{5:.3f}], Error[{6:.3f}]'
              .format(epoch, train_time, train_loss, train_error, val_time, val_loss, val_error))
        train_time = time.time()

    return train_statistics, val_statistics
def main():
    args = get_args()

    if not os.path.exists(args.fname):
        os.makedirs(args.fname)

    logger = logging.getLogger(__name__)
    logging.basicConfig(
        format='[%(asctime)s] - %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S',
        level=logging.DEBUG,
        handlers=[
            logging.FileHandler(
                os.path.join(args.fname,
                             'eval.log' if args.eval else 'output.log')),
            logging.StreamHandler()
        ])

    logger.info(args)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    num_workers = 2
    train_dataset = datasets.CIFAR100(args.data_dir,
                                      train=True,
                                      transform=train_transform,
                                      download=True)
    test_dataset = datasets.CIFAR100(args.data_dir,
                                     train=False,
                                     transform=test_transform,
                                     download=True)
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=num_workers,
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        pin_memory=True,
        num_workers=2,
    )

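    # Scale the perturbation budget and PGD step size from pixel units to the [0, 1] input range.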
    epsilon = (args.epsilon / 255.)
    pgd_alpha = (args.pgd_alpha / 255.)

    if args.model == 'PreActResNet18':
        model = PreActResNet18(num_classes=100)
    elif args.model == 'WideResNet':
        model = WideResNet(34,
                           100,
                           widen_factor=args.width_factor,
                           dropRate=0.0)
    else:
        raise ValueError("Unknown model")

    model = model.cuda()
    model.train()

    params = model.parameters()

    opt = torch.optim.SGD(params,
                          lr=args.lr_max,
                          momentum=0.9,
                          weight_decay=5e-4)
    criterion = nn.CrossEntropyLoss()
    epochs = args.epochs

    if args.lr_schedule == 'cyclic':
        lr_schedule = lambda t: np.interp(
            [t], [0, args.epochs // 2, args.epochs], [0, args.lr_max, 0])[0]
    elif args.lr_schedule == 'onedrop':

        def lr_schedule(t):
            if t < args.lr_drop:
                return args.lr_max
            else:
                return args.lr_max / 10.
    else:
        raise ValueError("Unknown lr_schedule")

    best_test_robust_acc = 0
    best_val_robust_acc = 0
    if args.resume:
        start_epoch = args.resume
        model.load_state_dict(
            torch.load(os.path.join(args.fname, f'model_{start_epoch-1}.pth')))
        opt.load_state_dict(
            torch.load(os.path.join(args.fname, f'opt_{start_epoch-1}.pth')))
        logger.info(f'Resuming at epoch {start_epoch}')

        best_test_robust_acc = torch.load(
            os.path.join(args.fname, f'model_best.pth'))['test_robust_acc']
    else:
        start_epoch = 0

    if args.eval:
        if not args.resume:
            logger.info(
                "No model loaded to evaluate, specify with --resume FNAME")
            return
        logger.info("[Evaluation mode]")

    logger.info(
        'Epoch \t Train Time \t Test Time \t LR \t \t Train Robust Loss \t Train Robust Acc \t Test Loss \t Test Acc \t Test Robust Loss \t Test Robust Acc \t Defence Mean \t Attack Mean'
    )
    for epoch in range(start_epoch, epochs):
        model.train()
        start_time = time.time()
        train_robust_loss = 0
        train_robust_acc = 0
        train_n = 0
        defence_mean = 0
        for i, (X, y) in enumerate(train_loader):
            if args.eval:
                break
            X, y = X.cuda(), y.cuda()
            lr = lr_schedule(epoch + (i + 1) / len(train_loader))
            opt.param_groups[0].update(lr=lr)

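            # Craft the training perturbation with the selected attack (FGSM, MultiGrad, ZeroGrad, or PGD).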
            if args.attack == 'fgsm':
                delta = attack_pgd(model,
                                   X,
                                   y,
                                   epsilon,
                                   args.fgsm_alpha * epsilon,
                                   1,
                                   1,
                                   'l_inf',
                                   fgsm_init=args.fgsm_init)
            elif args.attack == 'multigrad':
                delta = multi_grad(model, X, y, epsilon,
                                   args.fgsm_alpha * epsilon, args.multi_samps,
                                   args.multi_th, args.multi_parallel)
            elif args.attack == 'zerograd':
                delta = zero_grad(model, X, y, epsilon,
                                  args.fgsm_alpha * epsilon, args.zero_qval,
                                  args.zero_iters, args.fgsm_init)
            elif args.attack == 'pgd':
                delta = attack_pgd(model, X, y, epsilon, pgd_alpha, 10, 1,
                                   'l_inf')
            delta = delta.detach()
            robust_output = model(
                normalize(
                    torch.clamp(X + delta[:X.size(0)],
                                min=lower_limit,
                                max=upper_limit)))
            robust_loss = criterion(robust_output, y)

            opt.zero_grad()
            robust_loss.backward()
            opt.step()

            train_robust_loss += robust_loss.item() * y.size(0)
            train_robust_acc += (robust_output.max(1)[1] == y).sum().item()
            train_n += y.size(0)
            defence_mean += torch.mean(torch.abs(delta)) * y.size(0)

        train_time = time.time()

        model.eval()
        test_loss = 0
        test_acc = 0
        test_robust_loss = 0
        test_robust_acc = 0
        test_n = 0
        attack_mean = 0
        for i, (X, y) in enumerate(test_loader):
            # Outside the final epoch, evaluate on ~10% of the test set unless --full_test is set
            if epoch + 1 != epochs and not args.full_test and i > len(
                    test_loader) / 10:
                break
            X, y = X.cuda(), y.cuda()

            delta = attack_pgd(model,
                               X,
                               y,
                               epsilon,
                               pgd_alpha,
                               args.test_iters,
                               args.test_restarts,
                               'l_inf',
                               early_stop=args.eval)
            delta = delta.detach()

            robust_output = model(
                normalize(
                    torch.clamp(X + delta[:X.size(0)],
                                min=lower_limit,
                                max=upper_limit)))
            robust_loss = criterion(robust_output, y)

            output = model(normalize(X))
            loss = criterion(output, y)

            test_robust_loss += robust_loss.item() * y.size(0)
            test_robust_acc += (robust_output.max(1)[1] == y).sum().item()
            test_loss += loss.item() * y.size(0)
            test_acc += (output.max(1)[1] == y).sum().item()
            test_n += y.size(0)
            attack_mean += torch.mean(torch.abs(delta)) * y.size(0)

        test_time = time.time()

        if not args.eval:
            logger.info(
                '%d \t %.1f \t \t %.1f \t \t %.4f \t %.4f \t \t %.4f \t \t %.4f \t %.4f \t %.4f \t \t %.4f \t \t %.4f \t %.4f',
                epoch, train_time - start_time, test_time - train_time, lr,
                train_robust_loss / train_n, train_robust_acc / train_n,
                test_loss / test_n, test_acc / test_n,
                test_robust_loss / test_n, test_robust_acc / test_n,
                defence_mean * 255 / train_n, attack_mean * 255 / test_n)

            # save checkpoint
            if (epoch + 1) % args.chkpt_iters == 0 or epoch + 1 == epochs:
                torch.save(model.state_dict(),
                           os.path.join(args.fname, f'model_{epoch}.pth'))
                torch.save(opt.state_dict(),
                           os.path.join(args.fname, f'opt_{epoch}.pth'))

            # save best
            if test_robust_acc / test_n > best_test_robust_acc:
                torch.save(
                    {
                        'state_dict': model.state_dict(),
                        'test_robust_acc': test_robust_acc / test_n,
                        'test_robust_loss': test_robust_loss / test_n,
                        'test_loss': test_loss / test_n,
                        'test_acc': test_acc / test_n,
                    }, os.path.join(args.fname, f'model_best.pth'))
                best_test_robust_acc = test_robust_acc / test_n
        else:
            logger.info(
                '%d \t %.1f \t \t %.1f \t \t %.4f \t %.4f \t \t %.4f \t \t %.4f \t %.4f \t %.4f \t \t %.4f \t \t %.4f \t %.4f',
                epoch, train_time - start_time, test_time - train_time, -1, -1,
                -1, test_loss / test_n, test_acc / test_n,
                test_robust_loss / test_n, test_robust_acc / test_n, -1,
                attack_mean * 255 / test_n)

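        # Final evaluation with PGD-50 (50 iterations, 10 restarts), run at the last epoch or in eval mode.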
        if args.eval or epoch + 1 == epochs:
            start_test_time = time.time()
            test_loss = 0
            test_acc = 0
            test_robust_loss = 0
            test_robust_acc = 0
            test_n = 0
            for i, (X, y) in enumerate(test_loader):
                X, y = X.cuda(), y.cuda()

                delta = attack_pgd(model,
                                   X,
                                   y,
                                   epsilon,
                                   pgd_alpha,
                                   50,
                                   10,
                                   'l_inf',
                                   early_stop=True)
                delta = delta.detach()

                robust_output = model(
                    normalize(
                        torch.clamp(X + delta[:X.size(0)],
                                    min=lower_limit,
                                    max=upper_limit)))
                robust_loss = criterion(robust_output, y)

                output = model(normalize(X))
                loss = criterion(output, y)

                test_robust_loss += robust_loss.item() * y.size(0)
                test_robust_acc += (robust_output.max(1)[1] == y).sum().item()
                test_loss += loss.item() * y.size(0)
                test_acc += (output.max(1)[1] == y).sum().item()
                test_n += y.size(0)

            logger.info(
                'PGD50 \t time: %.1f,\t clean loss: %.4f,\t clean acc: %.4f,\t robust loss: %.4f,\t robust acc: %.4f',
                time.time() - start_test_time, test_loss / test_n,
                test_acc / test_n, test_robust_loss / test_n,
                test_robust_acc / test_n)
            return
def main():
    args = get_args()
    if args.fname == 'auto':
        names = get_auto_fname(args)
        args.fname = 'noise_predictor/' + names
    else:
        args.fname = 'noise_predictor/' + args.fname

    if not os.path.exists(args.fname):
        os.makedirs(args.fname)

    logger = logging.getLogger(__name__)
    logging.basicConfig(
        format='[%(asctime)s] - %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S',
        level=logging.DEBUG,
        handlers=[
            logging.FileHandler(
                os.path.join(args.fname,
                             'eval.log' if args.eval else 'output.log')),
            logging.StreamHandler()
        ])

    logger.info(args)

    # Set seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    shuffle = True

    train_adv_images = None
    train_adv_labels = None
    test_adv_images = None
    test_adv_labels = None

    print("Attacks")
    attacks = args.list.split("_")
    print(attacks)

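    # Load precomputed adversarial examples for each attack; the attack index i is used as the class label.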
    for i in range(len(attacks)):
        _adv_dir = "adv_examples/{}/".format(attacks[i])
        train_path = _adv_dir + "train.pth"
        test_path = _adv_dir + "test.pth"

        adv_train_data = torch.load(train_path)
        adv_test_data = torch.load(test_path)

        if i == 0:
            train_adv_images = adv_train_data["adv"]
            test_adv_images = adv_test_data["adv"]
            train_adv_labels = [i] * len(adv_train_data["label"])
            test_adv_labels = [i] * len(adv_test_data["label"])
        else:
            train_adv_images = np.concatenate(
                (train_adv_images, adv_train_data["adv"]))
            test_adv_images = np.concatenate(
                (test_adv_images, adv_test_data["adv"]))
            train_adv_labels = np.concatenate(
                (train_adv_labels, [i] * len(adv_train_data["label"])))
            test_adv_labels = np.concatenate(
                (test_adv_labels, [i] * len(adv_test_data["label"])))

    train_adv_set = list(zip(train_adv_images, train_adv_labels))

    print("")
    print("Train Adv Attack Data: ", attacks)
    print("Len: ", len(train_adv_set))
    print("")

    train_adv_batches = Batches(train_adv_set,
                                args.batch_size,
                                shuffle=shuffle,
                                set_random_choices=False,
                                num_workers=4)

    test_adv_set = list(zip(test_adv_images, test_adv_labels))

    test_adv_batches = Batches(test_adv_set,
                               args.batch_size,
                               shuffle=False,
                               num_workers=4)

    # Set perturbations
    epsilon = (args.epsilon / 255.)
    test_epsilon = (args.test_epsilon / 255.)
    pgd_alpha = (args.pgd_alpha / 255.)
    test_pgd_alpha = (args.test_pgd_alpha / 255.)

    # Set models
    model = None
    if args.model == "resnet18":
        model = resnet18(num_classes=args.num_classes)
    elif args.model == "resnet20":
        model = resnet20()
    elif args.model == "vgg16bn":
        model = vgg16_bn(num_classes=args.num_classes)
    elif args.model == "densenet121":
        model = densenet121(pretrained=True)
    elif args.model == "googlenet":
        model = googlenet(pretrained=True)
    elif args.model == "inceptionv3":
        model = inception_v3(num_classes=3)
    elif args.model == 'WideResNet':
        model = WideResNet(34,
                           10,
                           widen_factor=10,
                           dropRate=0.0,
                           normalize=args.use_FNandWN,
                           activation=args.activation,
                           softplus_beta=args.softplus_beta)
    elif args.model == 'WideResNet_20':
        model = WideResNet(34,
                           10,
                           widen_factor=20,
                           dropRate=0.0,
                           normalize=args.use_FNandWN,
                           activation=args.activation,
                           softplus_beta=args.softplus_beta)
    else:
        raise ValueError("Unknown model")

    # Set training hyperparameters
    if args.l2:
        decay, no_decay = [], []
        for name, param in model.named_parameters():
            if 'bn' not in name and 'bias' not in name:
                decay.append(param)
            else:
                no_decay.append(param)
        params = [{
            'params': decay,
            'weight_decay': args.l2
        }, {
            'params': no_decay,
            'weight_decay': 0
        }]
    else:
        params = model.parameters()
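    # Choose the optimizer: Adam is paired with the cyclic schedule, otherwise the optimizer named in args.optimizer.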
    if args.lr_schedule == 'cyclic':
        opt = torch.optim.Adam(params,
                               lr=args.lr_max,
                               betas=(0.9, 0.999),
                               eps=1e-08,
                               weight_decay=args.weight_decay)
    else:
        if args.optimizer == 'momentum':
            opt = torch.optim.SGD(params,
                                  lr=args.lr_max,
                                  momentum=0.9,
                                  weight_decay=args.weight_decay)
        elif args.optimizer == 'Nesterov':
            opt = torch.optim.SGD(params,
                                  lr=args.lr_max,
                                  momentum=0.9,
                                  weight_decay=args.weight_decay,
                                  nesterov=True)
        elif args.optimizer == 'SGD_GC':
            opt = SGD_GC(params,
                         lr=args.lr_max,
                         momentum=0.9,
                         weight_decay=args.weight_decay)
        elif args.optimizer == 'SGD_GCC':
            opt = SGD_GCC(params,
                          lr=args.lr_max,
                          momentum=0.9,
                          weight_decay=args.weight_decay)
        elif args.optimizer == 'Adam':
            opt = torch.optim.Adam(params,
                                   lr=args.lr_max,
                                   betas=(0.9, 0.999),
                                   eps=1e-08,
                                   weight_decay=args.weight_decay)
        elif args.optimizer == 'AdamW':
            opt = torch.optim.AdamW(params,
                                    lr=args.lr_max,
                                    betas=(0.9, 0.999),
                                    eps=1e-08,
                                    weight_decay=args.weight_decay)
        else:
            raise ValueError("Unknown optimizer")

    # Cross-entropy (mean)
    if args.labelsmooth:
        criterion = LabelSmoothingLoss(smoothing=args.labelsmoothvalue)
    else:
        criterion = nn.CrossEntropyLoss()

    epochs = args.epochs

    # Set lr schedule
    if args.lr_schedule == 'superconverge':
        lr_schedule = lambda t: np.interp([t], [
            0, args.epochs * 2 // 5, args.epochs
        ], [0, args.lr_max, 0])[0]
    elif args.lr_schedule == 'piecewise':

        def lr_schedule(t, warm_up_lr=args.warmup_lr):
            if t < 100:
                if warm_up_lr and t < args.warmup_lr_epoch:
                    return (t + 1.) / args.warmup_lr_epoch * args.lr_max
                else:
                    return args.lr_max
            if args.lrdecay == 'lineardecay':
                if t < 105:
                    return args.lr_max * 0.02 * (105 - t)
                else:
                    return 0.
            elif args.lrdecay == 'intenselr':
                if t < 102:
                    return args.lr_max / 10.
                else:
                    return args.lr_max / 100.
            elif args.lrdecay == 'looselr':
                if t < 150:
                    return args.lr_max / 10.
                else:
                    return args.lr_max / 100.
            elif args.lrdecay == 'base':
                if t < 105:
                    return args.lr_max / 10.
                else:
                    return args.lr_max / 100.
    elif args.lr_schedule == 'linear':
        lr_schedule = lambda t: np.interp([t], [
            0, args.epochs // 3, args.epochs * 2 // 3, args.epochs
        ], [args.lr_max, args.lr_max, args.lr_max / 10, args.lr_max / 100])[0]
    elif args.lr_schedule == 'onedrop':

        def lr_schedule(t):
            if t < args.lr_drop_epoch:
                return args.lr_max
            else:
                return args.lr_one_drop
    elif args.lr_schedule == 'multipledecay':

        def lr_schedule(t):
            return args.lr_max - (t //
                                  (args.epochs // 10)) * (args.lr_max / 10)
    elif args.lr_schedule == 'cosine':

        def lr_schedule(t):
            return args.lr_max * 0.5 * (1 + np.cos(t / args.epochs * np.pi))
    elif args.lr_schedule == 'cyclic':

        def lr_schedule(t, stepsize=18, min_lr=1e-5, max_lr=args.lr_max):

            # Scaler: we can adapt this if we do not want the triangular CLR
            scaler = lambda x: 1.

            # Additional function to see where on the cycle we are
            cycle = math.floor(1 + t / (2 * stepsize))
            x = abs(t / stepsize - 2 * cycle + 1)
            relative = max(0, (1 - x)) * scaler(cycle)

            return min_lr + (max_lr - min_lr) * relative

    #### Set stronger adv attacks when decay the lr ####
    def eps_alpha_schedule(
            t,
            warm_up_eps=args.warmup_eps,
            if_use_stronger_adv=args.use_stronger_adv,
            stronger_index=args.stronger_index):  # Schedule number 0
        if stronger_index == 0:
            epsilon_s = [epsilon * 1.5, epsilon * 2]
            pgd_alpha_s = [pgd_alpha, pgd_alpha]
        elif stronger_index == 1:
            epsilon_s = [epsilon * 1.5, epsilon * 2]
            pgd_alpha_s = [pgd_alpha * 1.25, pgd_alpha * 1.5]
        elif stronger_index == 2:
            epsilon_s = [epsilon * 2, epsilon * 2.5]
            pgd_alpha_s = [pgd_alpha * 1.5, pgd_alpha * 2]
        else:
            print('Undefined stronger index')

        if if_use_stronger_adv:
            if t < 100:
                if t < args.warmup_eps_epoch and warm_up_eps:
                    return (
                        t + 1.
                    ) / args.warmup_eps_epoch * epsilon, pgd_alpha, args.restarts
                else:
                    return epsilon, pgd_alpha, args.restarts
            elif t < 105:
                return epsilon_s[0], pgd_alpha_s[0], args.restarts
            else:
                return epsilon_s[1], pgd_alpha_s[1], args.restarts
        else:
            if t < args.warmup_eps_epoch and warm_up_eps:
                return (
                    t + 1.
                ) / args.warmup_eps_epoch * epsilon, pgd_alpha, args.restarts
            else:
                return epsilon, pgd_alpha, args.restarts

    #### Set the counter for the early stop of PGD ####
    def early_stop_counter_schedule(t):
        if t < args.earlystopPGDepoch1:
            return 1
        elif t < args.earlystopPGDepoch2:
            return 2
        else:
            return 3

    best_test_adv_acc = 0
    best_val_robust_acc = 0
    if args.resume:
        start_epoch = args.resume
        model.load_state_dict(
            torch.load(os.path.join(args.fname, f'model_{start_epoch-1}.pth')))
        opt.load_state_dict(
            torch.load(os.path.join(args.fname, f'opt_{start_epoch-1}.pth')))
        logger.info(f'Resuming at epoch {start_epoch}')

        best_test_adv_acc = torch.load(
            os.path.join(args.fname, f'model_best.pth'))['test_adv_acc']
        if args.val:
            best_val_robust_acc = torch.load(
                os.path.join(args.fname, f'model_val.pth'))['val_robust_acc']
    else:
        start_epoch = 0

    if args.eval:
        if not args.resume:
            logger.info(
                "No model loaded to evaluate, specify with --resume FNAME")
            return
        logger.info("[Evaluation mode]")

    model.cuda()

    model.train()

    logger.info('Epoch \t Train Robust Acc \t Test Robust Acc')

    # Records per epoch for savetxt
    train_loss_record = []
    train_acc_record = []
    train_robust_loss_record = []
    train_robust_acc_record = []
    train_grad_record = []

    test_loss_record = []
    test_acc_record = []
    test_adv_loss_record = []
    test_adv_acc_record = []
    test_grad_record = []

    for epoch in range(start_epoch, epochs):
        model.train()
        start_time = time.time()

        train_robust_loss = 0
        train_robust_acc = 0
        train_n = 0
        train_grad = 0

        record_iter = torch.tensor([])
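        # record_iter is meant to collect per-batch PGD iteration counts when
        # early-stop PGD is enabled; it is not filled in this training loop.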

        for i, adv_batch in enumerate(train_adv_batches):
            if args.eval:
                break
            adv_input = normalize(adv_batch['input'])
            adv_y = adv_batch['target']
            adv_input.requires_grad = True
            robust_output = model(adv_input)

            robust_loss = criterion(robust_output, adv_y)

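            # Optional L1 penalty on weight tensors (BatchNorm and bias parameters excluded)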
            if args.l1:
                for name, param in model.named_parameters():
                    if 'bn' not in name and 'bias' not in name:
                        robust_loss += args.l1 * param.abs().sum()

            opt.zero_grad()
            robust_loss.backward()
            opt.step()

            # Record the statistics
            train_robust_loss += robust_loss.item() * adv_y.size(0)
            train_robust_acc += (robust_output.max(1)[1] == adv_y).sum().item()
            train_n += adv_y.size(0)

        train_time = time.time()
        if args.earlystopPGD:
            print('Iter mean: ',
                  record_iter.mean().item(), ' Iter std:  ',
                  record_iter.std().item())

        # Evaluate on test data
        model.eval()

        test_adv_loss = 0
        test_adv_acc = 0
        test_adv_n = 0

        for i, batch in enumerate(test_adv_batches):
            adv_input = normalize(batch['input'])
            y = batch['target']

            robust_output = model(adv_input)
            robust_loss = criterion(robust_output, y)

            test_adv_loss += robust_loss.item() * y.size(0)
            test_adv_acc += (robust_output.max(1)[1] == y).sum().item()
            test_adv_n += y.size(0)

        test_time = time.time()

        logger.info('%d \t %.4f \t\t %.4f', epoch + 1,
                    train_robust_acc / train_n, test_adv_acc / test_adv_n)

        # Save results
        train_robust_loss_record.append(train_robust_loss / train_n)
        train_robust_acc_record.append(train_robust_acc / train_n)

        np.savetxt(args.fname + '/train_robust_loss_record.txt',
                   np.array(train_robust_loss_record))
        np.savetxt(args.fname + '/train_robust_acc_record.txt',
                   np.array(train_robust_acc_record))

        test_adv_loss_record.append(test_adv_loss / test_adv_n)
        test_adv_acc_record.append(test_adv_acc / test_adv_n)

        np.savetxt(args.fname + '/test_adv_loss_record.txt',
                   np.array(test_adv_loss_record))
        np.savetxt(args.fname + '/test_adv_acc_record.txt',
                   np.array(test_adv_acc_record))

        # save checkpoint
        if epoch > 99 or (epoch +
                          1) % args.chkpt_iters == 0 or epoch + 1 == epochs:
            torch.save(model.state_dict(),
                       os.path.join(args.fname, f'model_{epoch}.pth'))
            torch.save(opt.state_dict(),
                       os.path.join(args.fname, f'opt_{epoch}.pth'))

        # save best
        if test_adv_acc / test_adv_n > best_test_adv_acc:
            torch.save(
                {
                    'state_dict': model.state_dict(),
                    'test_adv_acc': test_adv_acc / test_adv_n,
                    'test_adv_loss': test_adv_loss / test_adv_n,
                }, os.path.join(args.fname, f'model_best.pth'))
            best_test_adv_acc = test_adv_acc / test_adv_n
Exemplo n.º 7
0
def main():
    global args, best_prec1, data
    if args.tensorboard: configure("runs/%s"%(args.name))

    # Data loading code
    normalize = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]],
                                     std=[x/255.0 for x in [63.0, 62.1, 66.7]])

    if args.augment:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: F.pad(
                Variable(x.unsqueeze(0), requires_grad=False, volatile=True),
                (4, 4, 4, 4), mode='reflect').data.squeeze()),
            transforms.ToPILImage(),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
            ])
    else:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            normalize,
            ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize
        ])

    # create model
    print("Creating the model...")
    model = WideResNet(args.layers, args.dataset == 'cifar10' and 10 or 100,
                       args.widen_factor, dropRate=args.droprate)

    rank = MPI.COMM_WORLD.Get_rank()
    args.rank = rank
    args.seed += rank
    _set_seed(args.seed)

    print("Creating the DataLoader...")
    kwargs = {'num_workers': args.num_workers}#, 'pin_memory': args.use_cuda}
    assert(args.dataset == 'cifar10' or args.dataset == 'cifar100')
    train_loader = torch.utils.data.DataLoader(
        datasets.__dict__[args.dataset.upper()]('../data', train=True,
                                                download=True,
                                                transform=transform_train),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    val_loader = torch.utils.data.DataLoader(
        datasets.__dict__[args.dataset.upper()]('../data', train=False, transform=transform_test),
        batch_size=args.batch_size, shuffle=False, **kwargs)

    # get the number of model parameters
    args.num_parameters = sum([p.data.nelement() for p in model.parameters()])
    print('Number of model parameters: {}'.format(args.num_parameters))

    # for training on multiple GPUs.
    # Use CUDA_VISIBLE_DEVICES=0,1 to specify which GPUs to use
    if args.use_cuda:
        print("Moving the model to the GPU")
        model = torch.nn.DataParallel(model, device_ids=device_ids)
        model = model.cuda()
        #model = model.cuda()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # define loss function (criterion) and optimizer

    criterion = nn.CrossEntropyLoss()
    if args.use_cuda:
        criterion = criterion.cuda()
    #  optimizer = torch.optim.SGD(model.parameters(), args.lr,
                              #  momentum=args.momentum, nesterov=args.nesterov,
                              #  weight_decay=args.weight_decay)
    #  optimizer = torch.optim.ASGD(model.parameters(), args.lr)
    #  from distributed_opt import MiniBatchSGD
    #  import torch.distributed as dist
    #  rank = np.random.choice('gloo')

    # TODO: pass kwargs to code
    print('initing MiniBatchSGD')
    print(list(model.parameters())[6].view(-1)[:3])
    if args.code == 'sgd':
        code = codings.svd.SVD(compress=False)
    elif args.code == 'svd':
        print("train.py, svd_rank =", args.svd_rank)
        code = codings.svd.SVD(random_sample=args.svd_rescale, rank=args.svd_rank,
                               compress=args.compress)
    elif args.code == 'qsgd':
        code = codings.qsgd.QSGD(scheme='qsgd')
    elif args.code == 'terngrad':
        code = codings.qsgd.QSGD(scheme='terngrad')
    elif args.code == 'qsvd':
        code = codings.qsvd.QSVD(scheme=args.scheme, rank=args.svd_rank)
    else:
        raise ValueError('args.code not recognized')

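    # The chosen 'code' controls how gradients are compressed before being
    # communicated: low-rank SVD, QSGD / TernGrad quantization, or QSVD.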
    names = [n for n, p in model.named_parameters()]
    assert len(names) == len(set(names))
    #  optimizer = MPI_PS(model.named_parameters(), model.parameters(), args.lr,
                       #  code=code, optim='adam',
                       #  use_mpi=args.use_mpi, cuda=args.use_cuda)
    optimizer = SGD(model.named_parameters(), model.parameters(), args.lr,
                    code=code, optim='sgd',
                    use_mpi=args.use_mpi, cuda=args.use_cuda, compress_level=args.compress_level)

    data = []
    train_data = []
    train_time = 0
    for epoch in range(args.start_epoch, args.epochs + 1):
        print(f"epoch {epoch}")
        adjust_learning_rate(optimizer, epoch+1)

        # train for one epoch
        start = time.time()
        if epoch >= 0:
            train_d = train(train_loader, model, criterion, optimizer, epoch)
        else:
            train_d = []
        train_time += time.time() - start
        train_data += [dict(datum, **vars(args)) for datum in train_d]

        # evaluate on validation set
        if epoch >= args.epochs:
            train_datum = validate(train_loader, model, criterion, epoch)
        else:
            train_datum = {'acc_test': np.inf, 'loss_test': np.inf}
        datum = validate(val_loader, model, criterion, epoch)
        #  train_datum = {'acc_train': 0.1, 'loss_train': 2.3}
        data += [{'train_time': train_time,
                  'whole_train_acc': train_datum['acc_test'],
                  'whole_train_loss': train_datum['loss_test'],
                  'epoch': epoch + 1, **vars(args), **datum}]
        if epoch > 0:
            if len(data) > 1:
                data[-1]['epoch_train_time'] = data[-1]['train_time'] - data[-2]['train_time']
            for key in train_data[-1]:
                values = [datum[key] for i, datum in enumerate(train_data)]
                if 'time' in key:
                    data[-1]["epoch_" + key] = np.sum(values)
                else:
                    data[-1]["epoch_" + key] = values[0]

        df = pd.DataFrame(data)
        train_df = pd.DataFrame(train_data)
        if True:
            time.sleep(1)
            # Yes loss_test IS on train data. (look at what validate returns)
            print('\n\nmin_train_loss', train_datum['loss_test'],
                  optimizer.steps, '\n\n')
            time.sleep(1)
        ids = [str(getattr(args, key)) for key in
               ['layers', 'lr', 'batch_size', 'compress', 'seed', 'num_workers',
                'svd_rank', 'svd_rescale', 'use_mpi', 'qsgd', 'world_size',
                'rank', 'code', 'scheme']]
        _write_csv(df, id='-'.join(ids))
        _write_csv(train_df, id='-'.join(ids) + '_train')
        pprint({k: v for k, v in data[-1].items()
                if k in ['svd_rank', 'svd_rescale', 'qsgd', 'compress']})
        pprint({k: v for k, v in data[-1].items()
                if k in ['train_time', 'num_workers', 'loss_test',
                         'acc_test', 'epoch', 'compress', 'svd_rank', 'qsgd'] or 'time' in k})
        prec1 = datum['acc_test']

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
        }, is_best)
        print('Best accuracy: ', best_prec1)
Exemplo n.º 8
0
def main():
    global args, best_prec1
    args = parser.parse_args()
    if args.tensorboard:
        configure("runs/%s" % (args.name))

    # Data loading code
    # normalize = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]],
    #								 std=[x/255.0 for x in [63.0, 62.1, 66.7]])


# 	if args.augment:
# 		transform_train = transforms.Compose([
# 			transforms.Resize(256,256),
# 			transforms.ToTensor(),
# 			transforms.Lambda(lambda x: F.pad(
# 								Variable(x.unsqueeze(0), requires_grad=False, volatile=True),
# 								(4,4,4,4),mode='reflect').data.squeeze()),
# 			transforms.ToPILImage(),
# 			transforms.RandomCrop(32),
# 			transforms.RandomHorizontalFlip(),
# 			transforms.ToTensor(),
# 			normalize,
# 			])
# 	else:
# 		transform_train = transforms.Compose([
# 			transforms.ToTensor(),
# 			normalize,
# 			])

# 	transform_test = transforms.Compose([
# 		transforms.ToTensor(),
# 		normalize
# 		])

#kwargs = {'num_workers': 1, 'pin_memory': True}
#assert(args.dataset == 'cifar10' or args.dataset == 'cifar100')

    train_data_path = "/home/mil/gupta/ifood18/data/training_set/train_set/"
    val_data_path = "/home/mil/gupta/ifood18/data/val_set/"

    train_label = "/home/mil/gupta/ifood18/data/labels/train_info.csv"
    val_label = "/home/mil/gupta/ifood18/data/labels/val_info.csv"

    #transformations = transforms.Compose([transforms.ToTensor()])

    # train_h5 = "/home/mil/gupta/ifood18/data/h5data/train_data.h5py"
    train_h5 = "/home/mil/gupta/ifood18/data/h5data/train_data_partial.h5py"
    train_dataset = H5Dataset(train_h5)
    #val_dataset =  H5Dataset(val_data_path, val_label, transform= None)

    #custom_mnist_from_csv = \
    #    CustomDatasetFromCSV('../data/mnist_in_csv.csv',
    #                         28, 28,
    #                         transformations)

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=64,
                                               shuffle=True,
                                               num_workers=1)

    #val_loader = torch.utils.data.DataLoader(dataset=val_dataset,batch_size=256,num_workers=1,shuffle=True)

    # 	train_labels = pd.read_csv('./data/labels/train_info.csv')

    # 	train_loader = torch.utils.data.DataLoader(
    #         datasets.__dict__[args.dataset.upper()](train_data_path, train=True, download=True,
    #                          transform=transform_train),
    #         batch_size=args.batch_size, shuffle=True, **kwargs)
    #     val_loader = torch.utils.data.DataLoader(
    #         datasets.__dict__[args.dataset.upper()](val_data_path, train=False, transform=transform_test),
    #         batch_size=args.batch_size, shuffle=True, **kwargs)

    # create model
    model = WideResNet(args.layers,
                       211,
                       args.widen_factor,
                       dropRate=args.droprate)

    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    # for training on multiple GPUs.
    # Use CUDA_VISIBLE_DEVICES=0,1 to specify which GPUs to use
    # model = torch.nn.DataParallel(model).cuda()
    model = model.cuda()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                nesterov=args.nesterov,
                                weight_decay=args.weight_decay)

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch + 1)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        #if epoch % args.validate_freq == 0:
        # evaluate on validation set
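        # NOTE: val_loader is never created above (its construction is commented
        # out), so it must be re-enabled before this validation call will run.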
        prec3 = validate(val_loader, model, criterion, epoch)

        # remember best prec@1 and save checkpoint
        is_best = prec3 > best_prec1
        best_prec1 = max(prec3, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            }, is_best)

    print('Best accuracy: ', best_prec1)
Exemplo n.º 9
0
def get_resnet(multi_gpu=False):
    model = WideResNet(depth=16, num_classes=10, widen_factor=2, drop_rate=0.0)
    if multi_gpu:
        model = torch.nn.DataParallel(model)
    return model.cuda()
Exemplo n.º 10
0
def main():
    global args, best_prec1
    args = parser.parse_args()
    if args.tensorboard:
        configure("../WideResNet-pytorch/runs/%s" % args.name)

    # Data loading code
    normalize = transforms.Normalize(
        mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
        std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

    if args.augment:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: F.pad(x.unsqueeze(0), (4, 4, 4, 4),
                                              mode='reflect').squeeze()),
            transforms.ToPILImage(),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
    transform_test = transforms.Compose([transforms.ToTensor(), normalize])

    kwargs = {'num_workers': 1, 'pin_memory': True}
    assert (args.dataset == 'cifar10' or args.dataset == 'cifar100')

    train_indices = np.load(
        'underlying_valid_indices_ie_meta_train_indices.npy')
    valid_indices = np.load('meta_val_indices.npy')

    train_dataset = datasets.CIFAR100('../data',
                                      train_indices,
                                      train=True,
                                      transform=transform_test,
                                      download=True)
    valid_dataset = datasets.CIFAR100('../data',
                                      valid_indices,
                                      train=True,
                                      transform=transform_test,
                                      download=True)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False)

    val_loader = torch.utils.data.DataLoader(valid_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False)

    # create model
    model = WideResNet(args.layers,
                       100,
                       args.widen_factor,
                       dropRate=args.droprate)

    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    device = torch.device(CUDA_DEVICE)
    # for training on multiple GPUs.
    # Use CUDA_VISIBLE_DEVICES=0,1 to specify which GPUs to use
    # model = torch.nn.DataParallel(model).cuda()
    model = model.cuda()

    saved_state = torch.load('../WideResNet-pytorch/runs/model_60.pth.tar')

    model.load_state_dict(saved_state['state_dict'])

    cudnn.benchmark = True

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                nesterov=args.nesterov,
                                weight_decay=args.weight_decay)

    # generate intermediate outputs
    prec1 = generate_intermediate_outputs(val_loader, model, criterion, 0)
Exemplo n.º 11
0
def main():
    args = get_args()

    if not os.path.exists(args.fname):
        os.makedirs(args.fname)

    logger = logging.getLogger(__name__)
    logging.basicConfig(
        format='[%(asctime)s] - %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S',
        level=logging.DEBUG,
        handlers=[
            logging.FileHandler(os.path.join(args.fname, 'eval.log' if args.eval else 'output.log')),
            logging.StreamHandler()
        ])

    logger.info(args)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    train_transform = transforms.Compose([
        transforms.ToTensor(),
        # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    num_workers = 2
    train_dataset = datasets.SVHN(
        args.data_dir, split='train', transform=train_transform, download=True)
    test_dataset = datasets.SVHN(
        args.data_dir, split='test', transform=test_transform, download=True)
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=num_workers,
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        pin_memory=True,
        num_workers=2,
    )

    epsilon = (args.epsilon / 255.)
    pgd_alpha = (args.pgd_alpha / 255.)
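    # args.epsilon and args.pgd_alpha are given in 0-255 pixel units and are
    # scaled here to the [0, 1] image range.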

    # model = models_dict[args.architecture]().cuda()
    # model.apply(initialize_weights)
    if args.model == 'PreActResNet18':
        model = PreActResNet18()
    elif args.model == 'WideResNet':
        model = WideResNet(34, 10, widen_factor=args.width_factor, dropRate=0.0)
    else:
        raise ValueError("Unknown model")

    model = model.cuda()
    model.train()

    if args.l2:
        decay, no_decay = [], []
        for name,param in model.named_parameters():
            if 'bn' not in name and 'bias' not in name:
                decay.append(param)
            else:
                no_decay.append(param)
        params = [{'params':decay, 'weight_decay':args.l2},
                  {'params':no_decay, 'weight_decay': 0 }]
    else:
        params = model.parameters()

    opt = torch.optim.SGD(params, lr=args.lr_max, momentum=0.9, weight_decay=5e-4)

    criterion = nn.CrossEntropyLoss()

    epochs = args.epochs

    if args.lr_schedule == 'superconverge':
        lr_schedule = lambda t: np.interp([t], [0, args.epochs * 2 // 5, args.epochs], [0, args.lr_max, 0])[0]
        # lr_schedule = lambda t: np.interp([t], [0, args.epochs], [0, args.lr_max])[0]
    elif args.lr_schedule == 'piecewise':
        def lr_schedule(t):
            if t / args.epochs < 0.5:
                return args.lr_max
            elif t / args.epochs < 0.75:
                return args.lr_max / 10.
            else:
                return args.lr_max / 100.

    if args.resume:
        start_epoch = args.resume
        model.load_state_dict(torch.load(os.path.join(args.fname, f'model_{start_epoch-1}.pth')))
        opt.load_state_dict(torch.load(os.path.join(args.fname, f'opt_{start_epoch-1}.pth')))
        logger.info(f'Resuming at epoch {start_epoch}')
    else:
        start_epoch = 0


    if args.eval:
        if not args.resume:
            logger.info("No model loaded to evaluate, specify with --resume FNAME")
            return
        logger.info("[Evaluation mode]")

    logger.info('Epoch \t Train Time \t Test Time \t LR \t \t Train Loss \t Train Acc \t Train Robust Loss \t Train Robust Acc \t Test Loss \t Test Acc \t Test Robust Loss \t Test Robust Acc')
    for epoch in range(start_epoch, epochs):
        model.train()
        start_time = time.time()
        train_loss = 0
        train_acc = 0
        train_robust_loss = 0
        train_robust_acc = 0
        train_n = 0
        for i, (X, y) in enumerate(train_loader):
            if args.eval:
                break
            X, y = X.cuda(), y.cuda()
            if args.mixup:
                X, y_a, y_b, lam = mixup_data(X, y, args.mixup_alpha)
                X, y_a, y_b = map(Variable, (X, y_a, y_b))
            lr = lr_schedule(epoch + (i + 1) / len(train_loader))
            opt.param_groups[0].update(lr=lr)

            if args.attack == 'pgd':
                # Random initialization
                if args.mixup:
                    delta = attack_pgd(model, X, y, epsilon, pgd_alpha, args.attack_iters, args.restarts, args.norm, mixup=True, y_a=y_a, y_b=y_b, lam=lam)
                else:
                    delta = attack_pgd(model, X, y, epsilon, pgd_alpha, args.attack_iters, args.restarts, args.norm)
                delta = delta.detach()

            # Standard training
            elif args.attack == 'none':
                delta = torch.zeros_like(X)

            robust_output = model(normalize(torch.clamp(X + delta[:X.size(0)], min=lower_limit, max=upper_limit)))
            if args.mixup:
                robust_loss = mixup_criterion(criterion, robust_output, y_a, y_b, lam)
            else:
                robust_loss = criterion(robust_output, y)

            if args.l1:
                for name,param in model.named_parameters():
                    if 'bn' not in name and 'bias' not in name:
                        robust_loss += args.l1*param.abs().sum()

            opt.zero_grad()
            robust_loss.backward()
            opt.step()

            output = model(normalize(X))
            if args.mixup:
                loss = mixup_criterion(criterion, output, y_a, y_b, lam)
            else:
                loss = criterion(output, y)

            train_robust_loss += robust_loss.item() * y.size(0)
            train_robust_acc += (robust_output.max(1)[1] == y).sum().item()
            train_loss += loss.item() * y.size(0)
            train_acc += (output.max(1)[1] == y).sum().item()
            train_n += y.size(0)

        train_time = time.time()

        model.eval()
        test_loss = 0
        test_acc = 0
        test_robust_loss = 0
        test_robust_acc = 0
        test_n = 0
        for i, (X, y) in enumerate(test_loader):
            X, y = X.cuda(), y.cuda()

            # Random initialization
            if args.attack == 'none':
                delta = torch.zeros_like(X)
            else:
                delta = attack_pgd(model, X, y, epsilon, pgd_alpha, args.attack_iters, args.restarts, args.norm, early_stop=args.eval)
            delta = delta.detach()

            robust_output = model(normalize(torch.clamp(X + delta[:X.size(0)], min=lower_limit, max=upper_limit)))
            robust_loss = criterion(robust_output, y)

            output = model(normalize(X))
            loss = criterion(output, y)

            test_robust_loss += robust_loss.item() * y.size(0)
            test_robust_acc += (robust_output.max(1)[1] == y).sum().item()
            test_loss += loss.item() * y.size(0)
            test_acc += (output.max(1)[1] == y).sum().item()
            test_n += y.size(0)

        test_time = time.time()
        if not args.eval: 
            logger.info('%d \t %.1f \t \t %.1f \t \t %.4f \t %.4f \t %.4f \t %.4f \t \t %.4f \t \t %.4f \t %.4f \t %.4f \t \t %.4f',
                epoch, train_time - start_time, test_time - train_time, lr,
                train_loss/train_n, train_acc/train_n, train_robust_loss/train_n, train_robust_acc/train_n,
                test_loss/test_n, test_acc/test_n, test_robust_loss/test_n, test_robust_acc/test_n)
            
            if (epoch+1) % args.chkpt_iters == 0 or epoch+1 == epochs:
                torch.save(model.state_dict(), os.path.join(args.fname, f'model_{epoch}.pth'))
                torch.save(opt.state_dict(), os.path.join(args.fname, f'opt_{epoch}.pth'))
        else:
            logger.info('%d \t %.1f \t \t %.1f \t \t %.4f \t %.4f \t %.4f \t %.4f \t \t %.4f \t \t %.4f \t %.4f \t %.4f \t \t %.4f',
                epoch, train_time - start_time, test_time - train_time, -1,
                -1, -1, -1, -1,
                test_loss/test_n, test_acc/test_n, test_robust_loss/test_n, test_robust_acc/test_n)
            return
Exemplo n.º 12
0
def main():
    args = get_args()
    if args.awp_gamma <= 0.0:
        args.awp_warmup = np.inf
    fname = None
    if args.sample == 100:
        fname = args.output_dir + "/default/" + args.attack + "/"
    else:
        fname = args.output_dir + "/" + str(
            args.sample) + "sample/" + args.attack + "/"

    if args.attack == "combine":
        fname += args.list + "/"

    if not os.path.exists(fname):
        os.makedirs(fname)

    logger = logging.getLogger(__name__)
    logging.basicConfig(
        format='[%(asctime)s] - %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S',
        level=logging.DEBUG,
        handlers=[
            logging.FileHandler(
                os.path.join(fname,
                             'eval.log' if args.eval else 'output.log')),
            logging.StreamHandler()
        ])

    logger.info(args)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    shuffle = False

    # setup data loader
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])
    train_set = torchvision.datasets.CIFAR10(root='../data',
                                             train=True,
                                             download=True,
                                             transform=transform_train)
    test_set = torchvision.datasets.CIFAR10(root='../data',
                                            train=False,
                                            download=True,
                                            transform=transform_test)

    if args.attack == "all":
        train_data = np.array(train_set.data) / 255.
        train_data = transpose(train_data).astype(np.float32)

        train_labels = np.array(train_set.targets)

        oversampled_train_data = np.tile(train_data, (11, 1, 1, 1))
        oversampled_train_labels = np.tile(train_labels, (11))
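        # Replicate the clean data 11 times so it lines up one-to-one with the
        # eleven precomputed adversarial sets concatenated further below.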

        train_set = list(
            zip(torch.from_numpy(oversampled_train_data),
                torch.from_numpy(oversampled_train_labels)))

    elif args.attack == "combine":
        train_data = np.array(train_set.data) / 255.
        train_data = transpose(train_data).astype(np.float32)

        train_labels = np.array(train_set.targets)

        oversampled_train_data = train_data.copy()
        oversampled_train_labels = train_labels.copy()

        logger.info("Attacks")
        attacks = args.list.split("_")
        logger.info(attacks)

        oversampled_train_data = np.tile(train_data, (len(attacks), 1, 1, 1))
        oversampled_train_labels = np.tile(train_labels, (len(attacks)))

        train_set = list(
            zip(torch.from_numpy(oversampled_train_data),
                torch.from_numpy(oversampled_train_labels)))
    else:
        train_data = np.array(train_set.data) / 255.
        train_data = transpose(train_data).astype(np.float32)

        train_labels = np.array(train_set.targets)

        train_set = list(
            zip(torch.from_numpy(train_data), torch.from_numpy(train_labels)))

    test_data = np.array(test_set.data) / 255.
    test_data = transpose(test_data).astype(np.float32)
    test_labels = np.array(test_set.targets)

    test_set = list(
        zip(torch.from_numpy(test_data), torch.from_numpy(test_labels)))

    if args.sample != 100:
        n = len(train_set)
        n_sample = int(n * args.sample / 100)

        np.random.shuffle(train_set)
        train_set = train_set[:n_sample]

    train_batches = Batches(train_set,
                            args.batch_size,
                            shuffle=shuffle,
                            set_random_choices=False,
                            num_workers=2)

    test_batches = Batches(test_set,
                           args.batch_size,
                           shuffle=False,
                           num_workers=0)

    train_adv_images = None
    train_adv_labels = None
    test_adv_images = None
    test_adv_labels = None

    adv_dir = "adv_examples/{}/".format(args.attack)
    train_path = adv_dir + "train.pth"
    test_path = adv_dir + "test.pth"

    if args.attack in TOOLBOX_ADV_ATTACK_LIST:
        adv_train_data = torch.load(train_path)
        train_adv_images = adv_train_data["adv"]
        train_adv_labels = adv_train_data["label"]
        adv_test_data = torch.load(test_path)
        test_adv_images = adv_test_data["adv"]
        test_adv_labels = adv_test_data["label"]
    elif args.attack in ["ffgsm", "mifgsm", "tpgd"]:
        adv_data = {}
        adv_data["adv"], adv_data["label"] = torch.load(train_path)
        train_adv_images = adv_data["adv"].numpy()
        train_adv_labels = adv_data["label"].numpy()
        adv_data = {}
        adv_data["adv"], adv_data["label"] = torch.load(test_path)
        test_adv_images = adv_data["adv"].numpy()
        test_adv_labels = adv_data["label"].numpy()
    elif args.attack == "combine":
        print("Attacks")
        attacks = args.list.split("_")
        print(attacks)

        for i in range(len(attacks)):
            _adv_dir = "adv_examples/{}/".format(attacks[i])
            train_path = _adv_dir + "train.pth"
            test_path = _adv_dir + "test.pth"

            adv_train_data = torch.load(train_path)
            adv_test_data = torch.load(test_path)

            if i == 0:
                train_adv_images = adv_train_data["adv"]
                train_adv_labels = adv_train_data["label"]
                test_adv_images = adv_test_data["adv"]
                test_adv_labels = adv_test_data["label"]
            else:
                train_adv_images = np.concatenate(
                    (train_adv_images, adv_train_data["adv"]))
                train_adv_labels = np.concatenate(
                    (train_adv_labels, adv_train_data["label"]))
                test_adv_images = np.concatenate(
                    (test_adv_images, adv_test_data["adv"]))
                test_adv_labels = np.concatenate(
                    (test_adv_labels, adv_test_data["label"]))

    elif args.attack == "all":
        print("Loading attacks")
        ATTACK_LIST = [
            "autoattack", "autopgd", "bim", "cw", "deepfool", "fgsm",
            "newtonfool", "pgd", "pixelattack", "spatialtransformation",
            "squareattack"
        ]
        for i in range(len(ATTACK_LIST)):
            print("Attack: ", ATTACK_LIST[i])
            _adv_dir = "adv_examples/{}/".format(ATTACK_LIST[i])
            train_path = _adv_dir + "train.pth"
            test_path = _adv_dir + "test.pth"

            adv_train_data = torch.load(train_path)
            adv_test_data = torch.load(test_path)

            if i == 0:
                train_adv_images = adv_train_data["adv"]
                train_adv_labels = adv_train_data["label"]
                test_adv_images = adv_test_data["adv"]
                test_adv_labels = adv_test_data["label"]
            else:
                train_adv_images = np.concatenate(
                    (train_adv_images, adv_train_data["adv"]))
                train_adv_labels = np.concatenate(
                    (train_adv_labels, adv_train_data["label"]))
                test_adv_images = np.concatenate(
                    (test_adv_images, adv_test_data["adv"]))
                test_adv_labels = np.concatenate(
                    (test_adv_labels, adv_test_data["label"]))
    else:
        raise ValueError("Unknown adversarial data")

    print("")
    print("Train Adv Attack Data: ", args.attack)
    print("Dataset shape: ", train_adv_images.shape)
    print("Dataset type: ", type(train_adv_images))
    print("Label shape: ", len(train_adv_labels))
    print("")

    train_adv_set = list(zip(train_adv_images, train_adv_labels))

    if args.sample != 100:
        n = len(train_adv_set)
        n_sample = int(n * args.sample / 100)

        np.random.shuffle(train_adv_set)
        train_adv_set = train_adv_set[:n_sample]

    train_adv_batches = Batches(train_adv_set,
                                args.batch_size,
                                shuffle=shuffle,
                                set_random_choices=False,
                                num_workers=0)

    test_adv_set = list(zip(test_adv_images, test_adv_labels))

    test_adv_batches = Batches(test_adv_set,
                               args.batch_size,
                               shuffle=False,
                               num_workers=0)

    epsilon = (args.epsilon / 255.)
    pgd_alpha = (args.pgd_alpha / 255.)

    if args.model == "ResNet18":
        model = resnet18(pretrained=True)
        proxy = resnet18(pretrained=True)
    elif args.model == 'PreActResNet18':
        model = PreActResNet18()
        proxy = PreActResNet18()
    elif args.model == 'WideResNet':
        model = WideResNet(34,
                           10,
                           widen_factor=args.width_factor,
                           dropRate=0.0)
        proxy = WideResNet(34,
                           10,
                           widen_factor=args.width_factor,
                           dropRate=0.0)
    else:
        raise ValueError("Unknown model")

    model = nn.DataParallel(model).cuda()
    proxy = nn.DataParallel(proxy).cuda()

    if args.l2:
        decay, no_decay = [], []
        for name, param in model.named_parameters():
            if 'bn' not in name and 'bias' not in name:
                decay.append(param)
            else:
                no_decay.append(param)
        params = [{
            'params': decay,
            'weight_decay': args.l2
        }, {
            'params': no_decay,
            'weight_decay': 0
        }]
    else:
        params = model.parameters()

    opt = torch.optim.SGD(params,
                          lr=args.lr_max,
                          momentum=0.9,
                          weight_decay=5e-4)
    proxy_opt = torch.optim.SGD(proxy.parameters(), lr=0.01)
    awp_adversary = AdvWeightPerturb(model=model,
                                     proxy=proxy,
                                     proxy_optim=proxy_opt,
                                     gamma=args.awp_gamma)

    criterion = nn.CrossEntropyLoss()

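    # Persistent perturbation buffer, reused across minibatches for 'free'
    # adversarial training or FGSM initialized from the previous perturbation.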
    if args.attack == 'free':
        delta = torch.zeros(args.batch_size, 3, 32, 32).cuda()
        delta.requires_grad = True
    elif args.attack == 'fgsm' and args.fgsm_init == 'previous':
        delta = torch.zeros(args.batch_size, 3, 32, 32).cuda()
        delta.requires_grad = True

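    # 'Free' adversarial training replays each minibatch args.attack_iters
    # times, so the nominal epoch count is reduced to keep total cost comparable.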
    if args.attack == 'free':
        epochs = int(math.ceil(args.epochs / args.attack_iters))
    else:
        epochs = args.epochs

    if args.lr_schedule == 'superconverge':
        lr_schedule = lambda t: np.interp([t], [
            0, args.epochs * 2 // 5, args.epochs
        ], [0, args.lr_max, 0])[0]
    elif args.lr_schedule == 'piecewise':

        def lr_schedule(t):
            if t / args.epochs < 0.5:
                return args.lr_max
            elif t / args.epochs < 0.75:
                return args.lr_max / 10.
            else:
                return args.lr_max / 100.
    elif args.lr_schedule == 'linear':
        lr_schedule = lambda t: np.interp([t], [
            0, args.epochs // 3, args.epochs * 2 // 3, args.epochs
        ], [args.lr_max, args.lr_max, args.lr_max / 10, args.lr_max / 100])[0]
    elif args.lr_schedule == 'onedrop':

        def lr_schedule(t):
            if t < args.lr_drop_epoch:
                return args.lr_max
            else:
                return args.lr_one_drop
    elif args.lr_schedule == 'multipledecay':

        def lr_schedule(t):
            return args.lr_max - (t //
                                  (args.epochs // 10)) * (args.lr_max / 10)
    elif args.lr_schedule == 'cosine':

        def lr_schedule(t):
            return args.lr_max * 0.5 * (1 + np.cos(t / args.epochs * np.pi))
    elif args.lr_schedule == 'cyclic':
        lr_schedule = lambda t: np.interp(
            [t], [0, 0.4 * args.epochs, args.epochs], [0, args.lr_max, 0])[0]

    best_test_robust_acc = 0
    best_val_robust_acc = 0
    if args.resume:
        start_epoch = args.resume
        model.load_state_dict(
            torch.load(os.path.join(fname, f'model_{start_epoch-1}.pth')))
        opt.load_state_dict(
            torch.load(os.path.join(fname, f'opt_{start_epoch-1}.pth')))
        logger.info(f'Resuming at epoch {start_epoch}')

        if os.path.exists(os.path.join(fname, f'model_best.pth')):
            best_test_robust_acc = torch.load(
                os.path.join(fname, f'model_best.pth'))['test_robust_acc']
        if args.val:
            best_val_robust_acc = torch.load(
                os.path.join(args.output_dir,
                             f'model_val.pth'))['val_robust_acc']
    else:
        start_epoch = 0

    if args.eval:
        if not args.resume:
            logger.info(
                "No model loaded to evaluate, specify with --resume FNAME")
            return
        logger.info("[Evaluation mode]")

    model.cuda()
    model.eval()

    # Evaluate on original test data
    test_acc = 0
    test_n = 0

    for i, batch in enumerate(test_batches):
        X, y = batch['input'], batch['target']

        clean_input = normalize(X)
        output = model(clean_input)

        test_acc += (output.max(1)[1] == y).sum().item()
        test_n += y.size(0)

    logger.info('Initial Accuracy on Original Test Data: %.4f (Test Acc)',
                test_acc / test_n)

    test_adv_acc = 0
    test_adv_n = 0

    for i, batch in enumerate(test_adv_batches):
        adv_input = normalize(batch['input'])
        y = batch['target']

        robust_output = model(adv_input)
        test_adv_acc += (robust_output.max(1)[1] == y).sum().item()
        test_adv_n += y.size(0)

    logger.info(
        'Initial Accuracy on Adversarial Test Data: %.4f (Test Robust Acc)',
        test_adv_acc / test_adv_n)

    model.train()

    logger.info(
        'Epoch \t Train Acc \t Train Robust Acc \t Test Acc \t Test Robust Acc'
    )
    for epoch in range(start_epoch, epochs):
        start_time = time.time()
        train_loss = 0
        train_acc = 0
        train_robust_loss = 0
        train_robust_acc = 0
        train_n = 0
        for i, (batch,
                adv_batch) in enumerate(zip(train_batches, train_adv_batches)):
            if args.eval:
                break
            X, y = batch['input'], batch['target']
            if args.mixup:
                X, y_a, y_b, lam = mixup_data(X, y, args.mixup_alpha)
                X, y_a, y_b = map(Variable, (X, y_a, y_b))
            lr = lr_schedule(epoch + (i + 1) / len(train_batches))
            opt.param_groups[0].update(lr=lr)

            X_adv = normalize(adv_batch["input"])
            y_adv = adv_batch["target"]

            model.train()
            # calculate adversarial weight perturbation and perturb it
            if epoch >= args.awp_warmup:
                # not compatible with mixup currently.
                assert (not args.mixup)
                awp = awp_adversary.calc_awp(inputs_adv=X_adv, targets=y_adv)
                awp_adversary.perturb(awp)
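                # The weight perturbation is undone via awp_adversary.restore(awp)
                # after the optimizer step below.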

            robust_output = model(X_adv)
            if args.mixup:
                robust_loss = mixup_criterion(criterion, robust_output, y_a,
                                              y_b, lam)
            else:
                robust_loss = criterion(robust_output, y_adv)

            if args.l1:
                for name, param in model.named_parameters():
                    if 'bn' not in name and 'bias' not in name:
                        robust_loss += args.l1 * param.abs().sum()

            opt.zero_grad()
            robust_loss.backward()
            opt.step()

            if epoch >= args.awp_warmup:
                awp_adversary.restore(awp)

            output = model(normalize(X))
            if args.mixup:
                loss = mixup_criterion(criterion, output, y_a, y_b, lam)
            else:
                loss = criterion(output, y)

            train_robust_loss += robust_loss.item() * y_adv.size(0)
            train_robust_acc += (robust_output.max(1)[1] == y_adv).sum().item()
            train_loss += loss.item() * y.size(0)
            train_acc += (output.max(1)[1] == y).sum().item()
            train_n += y.size(0)

        train_time = time.time()

        model.eval()
        test_loss = 0
        test_acc = 0
        test_n = 0

        test_robust_loss = 0
        test_robust_acc = 0
        test_robust_n = 0

        for i, batch in enumerate(test_batches):
            X, y = normalize(batch['input']), batch['target']

            output = model(X)
            loss = criterion(output, y)

            test_loss += loss.item() * y.size(0)
            test_acc += (output.max(1)[1] == y).sum().item()
            test_n += y.size(0)

        for i, adv_batch in enumerate(test_adv_batches):
            X_adv, y_adv = normalize(adv_batch["input"]), adv_batch["target"]

            robust_output = model(X_adv)
            robust_loss = criterion(robust_output, y_adv)

            test_robust_loss += robust_loss.item() * y_adv.size(0)
            test_robust_acc += (robust_output.max(1)[1] == y_adv).sum().item()
            test_robust_n += y_adv.size(0)

        test_time = time.time()

        logger.info('%d \t %.3f \t\t %.3f \t\t\t %.3f \t\t %.3f', epoch,
                    train_acc / train_n, train_robust_acc / train_n,
                    test_acc / test_n, test_robust_acc / test_robust_n)

        # save checkpoint
        if (epoch + 1) % args.chkpt_iters == 0 or epoch + 1 == epochs:
            torch.save(model.state_dict(),
                       os.path.join(fname, f'model_{epoch}.pth'))
            torch.save(opt.state_dict(), os.path.join(fname,
                                                      f'opt_{epoch}.pth'))

        # save best
        if test_robust_acc / test_robust_n > best_test_robust_acc:
            torch.save(
                {
                    'state_dict': model.state_dict(),
                    'test_robust_acc': test_robust_acc / test_robust_n,
                    'test_robust_loss': test_robust_loss / test_robust_n,
                    'test_loss': test_loss / test_n,
                    'test_acc': test_acc / test_n,
                }, os.path.join(fname, f'model_best.pth'))
            best_test_robust_acc = test_robust_acc / test_robust_n
Exemplo n.º 13
0
            image_paths.append(image_path)

    with open("submission_file.txt", "w") as f:
        for image_path, top_5_prediction in zip(image_paths, predictions):

            f.write(image_path + " " + " ".join(map(str, top_5_prediction)))
            f.write("\n")


if __name__ == '__main__':

    parameters = load_parameters("parameters.ini")

    # fox = FoxNet()
    fox = WideResNet(depth=40, num_classes=100, widen_factor=4, dropRate=0.3)
    # fox = WideResNet(depth=16, num_classes=100, widen_factor=4, dropRate=0.3)

    # If loading
    # fox.load_state_dict(torch.load("current_best_model_weights"))

    use_cuda = torch.cuda.is_available()

    if use_cuda:
        print("Using CUDA")
        fox.cuda()

    epochs = 250

    train_fox(fox, epochs, use_cuda)
    # evaluate_foxnet(fox, use_cuda)
Exemplo n.º 14
0
def main():
    global args, best_prec1, data
    if args.tensorboard: configure("runs/%s" % (args.name))

    # Data loading code
    normalize = transforms.Normalize(
        mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
        std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

    if args.augment:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: F.pad(Variable(
                x.unsqueeze(0), requires_grad=False, volatile=True),
                                              (4, 4, 4, 4),
                                              mode='reflect').data.squeeze()),
            transforms.ToPILImage(),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
    transform_test = transforms.Compose([transforms.ToTensor(), normalize])

    kwargs = {'num_workers': args.num_workers, 'pin_memory': use_cuda}
    assert (args.dataset == 'cifar10' or args.dataset == 'cifar100')
    print("Creating the DataLoader...")
    train_loader = torch.utils.data.DataLoader(
        datasets.__dict__[args.dataset.upper()]('../data',
                                                train=True,
                                                download=True,
                                                transform=transform_train),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs)
    val_loader = torch.utils.data.DataLoader(
        datasets.__dict__[args.dataset.upper()]('../data',
                                                train=False,
                                                transform=transform_test),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs)

    # create model
    print("Creating the model...")
    model = WideResNet(args.layers,
                       args.dataset == 'cifar10' and 10 or 100,
                       args.widen_factor,
                       dropRate=args.droprate)

    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    # for training on multiple GPUs.
    # Use CUDA_VISIBLE_DEVICES=0,1 to specify which GPUs to use
    if use_cuda:
        print("Moving the model to the GPU")
        model = torch.nn.DataParallel(model, device_ids=device_ids)
        model = model.cuda()
        #model = model.cuda()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # define loss function (criterion) and optimizer

    criterion = nn.CrossEntropyLoss()
    if use_cuda:
        criterion = criterion.cuda()
    #optimizer = torch.optim.SGD(model.parameters(), args.lr,
    #                            momentum=args.momentum, nesterov=args.nesterov,
    #                            weight_decay=args.weight_decay)
    optimizer = torch.optim.ASGD(model.parameters(), args.lr)

    data = []
    train_time = 0
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch + 1)

        # train for one epoch
        start = time.time()
        train(train_loader, model, criterion, optimizer, epoch)
        train_time += time.time() - start

        # evaluate on validation set
        datum = validate(val_loader, model, criterion, epoch)
        data += [{
            'train_time': train_time,
            'epoch': epoch + 1,
            **vars(args),
            **datum
        }]
        df = pd.DataFrame(data)
        _write_csv(df, id=f'{args.num_workers}_{args.seed}')
        pprint({
            k: v
            for k, v in data[-1].items() if k in
            ['train_time', 'num_workers', 'test_loss', 'test_acc', 'epoch']
        })
        prec1 = datum['test_acc']

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            }, is_best)
    print('Best accuracy: ', best_prec1)
def main():
    args = get_args()
    if args.fname == 'auto':
        names = get_auto_fname(args)
        args.fname = '../../trained_models/' + names
    else:
        args.fname = '../../trained_models/' + args.fname

    if not os.path.exists(args.fname):
        os.makedirs(args.fname)

    logger = logging.getLogger(__name__)
    logging.basicConfig(
        format='[%(asctime)s] - %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S',
        level=logging.DEBUG,
        handlers=[
            logging.FileHandler(
                os.path.join(args.fname,
                             'eval.log' if args.eval else 'output.log')),
            logging.StreamHandler()
        ])

    logger.info(args)

    # Set seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # setup data loader
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])
    train_set = torchvision.datasets.CIFAR10(root='../data',
                                             train=True,
                                             download=True,
                                             transform=transform_train)
    test_set = torchvision.datasets.CIFAR10(root='../data',
                                            train=False,
                                            download=True,
                                            transform=transform_test)

    if args.attack == "all":
        train_data = np.array(train_set.data) / 255.
        train_data = transpose(train_data).astype(np.float32)

        train_labels = np.array(train_set.targets)

        oversampled_train_data = np.tile(train_data, (11, 1, 1, 1))
        oversampled_train_labels = np.tile(train_labels, (11))

        train_set = list(
            zip(torch.from_numpy(oversampled_train_data),
                torch.from_numpy(oversampled_train_labels)))

    elif args.attack == "combine":
        train_data = np.array(train_set.data) / 255.
        train_data = transpose(train_data).astype(np.float32)

        train_labels = np.array(train_set.targets)

        oversampled_train_data = train_data.copy()
        oversampled_train_labels = train_labels.copy()

        logger.info("Attacks")
        attacks = args.list.split("_")
        logger.info(attacks)

        oversampled_train_data = np.tile(train_data, (len(attacks), 1, 1, 1))
        oversampled_train_labels = np.tile(train_labels, (len(attacks)))

        train_set = list(
            zip(torch.from_numpy(oversampled_train_data),
                torch.from_numpy(oversampled_train_labels)))
    else:
        train_data = np.array(train_set.data) / 255.
        train_data = transpose(train_data).astype(np.float32)

        train_labels = np.array(train_set.targets)

        train_set = list(
            zip(torch.from_numpy(train_data), torch.from_numpy(train_labels)))

    test_data = np.array(test_set.data) / 255.
    test_data = transpose(test_data).astype(np.float32)
    test_labels = np.array(test_set.targets)

    test_set = list(
        zip(torch.from_numpy(test_data), torch.from_numpy(test_labels)))

    if args.sample != 100:
        n = len(train_set)
        n_sample = int(n * args.sample / 100)

        np.random.shuffle(train_set)
        train_set = train_set[:n_sample]

    print("")
    print("Train Original Data: ")
    print("Len: ", len(train_set))
    print("")

    shuffle = False
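    # keep the ordering fixed, presumably so clean and adversarial batches stay aligned when zipped during training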

    train_batches = Batches(train_set, args.batch_size, shuffle=shuffle)
    test_batches = Batches(test_set, args.batch_size, shuffle=False)

    train_adv_images = None
    train_adv_labels = None
    test_adv_images = None
    test_adv_labels = None

    adv_dir = "adv_examples/{}/".format(args.attack)
    train_path = adv_dir + "train.pth"
    test_path = adv_dir + "test.pth"

    #     ATTACK_LIST = ["autoattack", "autopgd", "bim", "cw", "deepfool", "fgsm", "newtonfool", "pgd", "pixelattack", "spatialtransformation", "squareattack"]
    ATTACK_LIST = [
        "pixelattack", "spatialtransformation", "squareattack", "fgsm",
        "deepfool", "bim", "cw", "pgd", "autoattack", "autopgd", "newtonfool"
    ]

    if args.attack in TOOLBOX_ADV_ATTACK_LIST:
        adv_train_data = torch.load(train_path)
        train_adv_images = adv_train_data["adv"]
        train_adv_labels = adv_train_data["label"]
        adv_test_data = torch.load(test_path)
        test_adv_images = adv_test_data["adv"]
        test_adv_labels = adv_test_data["label"]
    elif args.attack in ["ffgsm", "mifgsm", "tpgd"]:
        adv_data = {}
        adv_data["adv"], adv_data["label"] = torch.load(train_path)
        train_adv_images = adv_data["adv"].numpy()
        train_adv_labels = adv_data["label"].numpy()
        adv_data = {}
        adv_data["adv"], adv_data["label"] = torch.load(test_path)
        test_adv_images = adv_data["adv"].numpy()
        test_adv_labels = adv_data["label"].numpy()
    elif args.attack == "all":

        for i in range(len(ATTACK_LIST)):
            _adv_dir = "adv_examples/{}/".format(ATTACK_LIST[i])
            train_path = _adv_dir + "train.pth"
            test_path = _adv_dir + "test.pth"

            adv_train_data = torch.load(train_path)
            adv_test_data = torch.load(test_path)

            if i == 0:
                train_adv_images = adv_train_data["adv"]
                train_adv_labels = adv_train_data["label"]
                test_adv_images = adv_test_data["adv"]
                test_adv_labels = adv_test_data["label"]
            else:
                #                 print(train_adv_images.shape)
                #                 print(adv_train_data["adv"].shape)
                train_adv_images = np.concatenate(
                    (train_adv_images, adv_train_data["adv"]))
                train_adv_labels = np.concatenate(
                    (train_adv_labels, adv_train_data["label"]))
                test_adv_images = np.concatenate(
                    (test_adv_images, adv_test_data["adv"]))
                test_adv_labels = np.concatenate(
                    (test_adv_labels, adv_test_data["label"]))
    elif args.attack == "combine":

        print("Attacks")
        attacks = args.list.split("_")
        print(attacks)

        if args.balanced is None:
            for i in range(len(attacks)):
                _adv_dir = "adv_examples/{}/".format(attacks[i])
                train_path = _adv_dir + "train.pth"
                test_path = _adv_dir + "test.pth"

                adv_train_data = torch.load(train_path)
                adv_test_data = torch.load(test_path)

                if i == 0:
                    train_adv_images = adv_train_data["adv"]
                    train_adv_labels = adv_train_data["label"]
                    test_adv_images = adv_test_data["adv"]
                    test_adv_labels = adv_test_data["label"]
                else:
                    #                 print(train_adv_images.shape)
                    #                 print(adv_train_data["adv"].shape)
                    train_adv_images = np.concatenate(
                        (train_adv_images, adv_train_data["adv"]))
                    train_adv_labels = np.concatenate(
                        (train_adv_labels, adv_train_data["label"]))
                    test_adv_images = np.concatenate(
                        (test_adv_images, adv_test_data["adv"]))
                    test_adv_labels = np.concatenate(
                        (test_adv_labels, adv_test_data["label"]))
        else:
            proportion_str = args.balanced.split("_")
            proportion = [int(x) for x in proportion_str]
            sum_proportion = sum(proportion)
            proportion = [float(x) / float(sum_proportion) for x in proportion]
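            # e.g. --balanced 1_1_2 with three attacks yields proportions [0.25, 0.25, 0.5] of the 50000 training images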
            sum_samples = 0

            for i in range(len(attacks)):
                _adv_dir = "adv_examples/{}/".format(attacks[i])
                train_path = _adv_dir + "train.pth"
                test_path = _adv_dir + "test.pth"

                adv_train_data = torch.load(train_path)
                adv_test_data = torch.load(test_path)

                random_state = 0
                num_samples = 0
                total = 50000
                if i != len(attacks) - 1:
                    n_samples = int(proportion[i] * total)
                    sum_samples += n_samples
                else:
                    n_samples = total - sum_samples
                print("Sample")
                print(n_samples)

                if i == 0:
                    train_adv_images = resample(adv_train_data["adv"],
                                                n_samples=n_samples,
                                                random_state=random_state)
                    train_adv_labels = resample(adv_train_data["label"],
                                                n_samples=n_samples,
                                                random_state=random_state)
                    test_adv_images = resample(adv_test_data["adv"],
                                               n_samples=n_samples,
                                               random_state=random_state)
                    test_adv_labels = resample(adv_test_data["label"],
                                               n_samples=n_samples,
                                               random_state=random_state)
                else:
                    train_adv_images = np.concatenate(
                        (train_adv_images,
                         resample(adv_train_data["adv"],
                                  n_samples=n_samples,
                                  random_state=random_state)))
                    train_adv_labels = np.concatenate(
                        (train_adv_labels,
                         resample(adv_train_data["label"],
                                  n_samples=n_samples,
                                  random_state=random_state)))
                    test_adv_images = np.concatenate(
                        (test_adv_images,
                         resample(adv_test_data["adv"],
                                  n_samples=n_samples,
                                  random_state=random_state)))
                    test_adv_labels = np.concatenate(
                        (test_adv_labels,
                         resample(adv_test_data["label"],
                                  n_samples=n_samples,
                                  random_state=random_state)))

    else:
        raise ValueError("Unknown adversarial data")

    train_adv_set = list(zip(train_adv_images, train_adv_labels))

    if args.sample != 100:
        n = len(train_adv_set)
        n_sample = int(n * args.sample / 100)

        np.random.shuffle(train_adv_set)
        train_adv_set = train_adv_set[:n_sample]

    print("")
    print("Train Adv Attack Data: ", args.attack)
    print("Len: ", len(train_adv_set))
    print("")

    train_adv_batches = Batches(train_adv_set,
                                args.batch_size,
                                shuffle=shuffle,
                                set_random_choices=False,
                                num_workers=4)

    test_adv_set = list(zip(test_adv_images, test_adv_labels))

    test_adv_batches = Batches(test_adv_set,
                               args.batch_size,
                               shuffle=False,
                               num_workers=4)

    # Set perturbations
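    # command-line budgets are specified in 0-255 pixel units; convert them to the [0, 1] image scale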
    epsilon = (args.epsilon / 255.)
    test_epsilon = (args.test_epsilon / 255.)
    pgd_alpha = (args.pgd_alpha / 255.)
    test_pgd_alpha = (args.test_pgd_alpha / 255.)

    # Set models
    model = None
    if args.model == "resnet18":
        model = resnet18(pretrained=True)
    elif args.model == "resnet20":
        model = resnet20()
    elif args.model == "vgg16bn":
        model = vgg16_bn(pretrained=True)
    elif args.model == "densenet121":
        model = densenet121(pretrained=True)
    elif args.model == "googlenet":
        model = googlenet(pretrained=True)
    elif args.model == "inceptionv3":
        model = inception_v3(pretrained=True)
    elif args.model == 'WideResNet':
        model = WideResNet(34,
                           10,
                           widen_factor=10,
                           dropRate=0.0,
                           normalize=args.use_FNandWN,
                           activation=args.activation,
                           softplus_beta=args.softplus_beta)
    elif args.model == 'WideResNet_20':
        model = WideResNet(34,
                           10,
                           widen_factor=20,
                           dropRate=0.0,
                           normalize=args.use_FNandWN,
                           activation=args.activation,
                           softplus_beta=args.softplus_beta)
    else:
        raise ValueError("Unknown model")

    # Set training hyperparameters
    if args.l2:
        decay, no_decay = [], []
        for name, param in model.named_parameters():
            if 'bn' not in name and 'bias' not in name:
                decay.append(param)
            else:
                no_decay.append(param)
        params = [{
            'params': decay,
            'weight_decay': args.l2
        }, {
            'params': no_decay,
            'weight_decay': 0
        }]
    else:
        params = model.parameters()
    if args.lr_schedule == 'cyclic':
        opt = torch.optim.Adam(params,
                               lr=args.lr_max,
                               betas=(0.9, 0.999),
                               eps=1e-08,
                               weight_decay=args.weight_decay)
    else:
        if args.optimizer == 'momentum':
            opt = torch.optim.SGD(params,
                                  lr=args.lr_max,
                                  momentum=0.9,
                                  weight_decay=args.weight_decay)
        elif args.optimizer == 'Nesterov':
            opt = torch.optim.SGD(params,
                                  lr=args.lr_max,
                                  momentum=0.9,
                                  weight_decay=args.weight_decay,
                                  nesterov=True)
        elif args.optimizer == 'SGD_GC':
            opt = SGD_GC(params,
                         lr=args.lr_max,
                         momentum=0.9,
                         weight_decay=args.weight_decay)
        elif args.optimizer == 'SGD_GCC':
            opt = SGD_GCC(params,
                          lr=args.lr_max,
                          momentum=0.9,
                          weight_decay=args.weight_decay)
        elif args.optimizer == 'Adam':
            opt = torch.optim.Adam(params,
                                   lr=args.lr_max,
                                   betas=(0.9, 0.999),
                                   eps=1e-08,
                                   weight_decay=args.weight_decay)
        elif args.optimizer == 'AdamW':
            opt = torch.optim.AdamW(params,
                                    lr=args.lr_max,
                                    betas=(0.9, 0.999),
                                    eps=1e-08,
                                    weight_decay=args.weight_decay)

    # Cross-entropy (mean)
    if args.labelsmooth:
        criterion = LabelSmoothingLoss(smoothing=args.labelsmoothvalue)
    else:
        criterion = nn.CrossEntropyLoss()

    # If we use freeAT or fastAT with previous init
    if args.attack == 'free':
        delta = torch.zeros(args.batch_size, 3, 32, 32).cuda()
        delta.requires_grad = True
    elif args.attack == 'fgsm' and args.fgsm_init == 'previous':
        delta = torch.zeros(args.batch_size, 3, 32, 32).cuda()
        delta.requires_grad = True

    if args.attack == 'free':
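        # free adversarial training replays each minibatch attack_iters times, so the
        # epoch budget is divided to keep the total number of updates roughly constant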
        epochs = int(math.ceil(args.epochs / args.attack_iters))
    else:
        epochs = args.epochs

    # Set lr schedule
    if args.lr_schedule == 'superconverge':
        lr_schedule = lambda t: np.interp([t], [
            0, args.epochs * 2 // 5, args.epochs
        ], [0, args.lr_max, 0])[0]
    elif args.lr_schedule == 'piecewise':

        def lr_schedule(t, warm_up_lr=args.warmup_lr):
            if t < 100:
                if warm_up_lr and t < args.warmup_lr_epoch:
                    return (t + 1.) / args.warmup_lr_epoch * args.lr_max
                else:
                    return args.lr_max
            if args.lrdecay == 'lineardecay':
                if t < 105:
                    return args.lr_max * 0.02 * (105 - t)
                else:
                    return 0.
            elif args.lrdecay == 'intenselr':
                if t < 102:
                    return args.lr_max / 10.
                else:
                    return args.lr_max / 100.
            elif args.lrdecay == 'looselr':
                if t < 150:
                    return args.lr_max / 10.
                else:
                    return args.lr_max / 100.
            elif args.lrdecay == 'base':
                if t < 105:
                    return args.lr_max / 10.
                else:
                    return args.lr_max / 100.
    elif args.lr_schedule == 'linear':
        lr_schedule = lambda t: np.interp([t], [
            0, args.epochs // 3, args.epochs * 2 // 3, args.epochs
        ], [args.lr_max, args.lr_max, args.lr_max / 10, args.lr_max / 100])[0]
    elif args.lr_schedule == 'onedrop':

        def lr_schedule(t):
            if t < args.lr_drop_epoch:
                return args.lr_max
            else:
                return args.lr_one_drop
    elif args.lr_schedule == 'multipledecay':

        def lr_schedule(t):
            return args.lr_max - (t //
                                  (args.epochs // 10)) * (args.lr_max / 10)
    elif args.lr_schedule == 'cosine':

        def lr_schedule(t):
            return args.lr_max * 0.5 * (1 + np.cos(t / args.epochs * np.pi))
    elif args.lr_schedule == 'cyclic':
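        # triangular cyclic learning rate: each cycle spans 2 * stepsize epochs,
        # ramping linearly from min_lr up to max_lr and back down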

        def lr_schedule(t, stepsize=18, min_lr=1e-5, max_lr=args.lr_max):

            # Scaler: we can adapt this if we do not want the triangular CLR
            scaler = lambda x: 1.

            # Additional function to see where on the cycle we are
            cycle = math.floor(1 + t / (2 * stepsize))
            x = abs(t / stepsize - 2 * cycle + 1)
            relative = max(0, (1 - x)) * scaler(cycle)

            return min_lr + (max_lr - min_lr) * relative

    #### Set stronger adv attacks when decaying the lr ####
    def eps_alpha_schedule(
            t,
            warm_up_eps=args.warmup_eps,
            if_use_stronger_adv=args.use_stronger_adv,
            stronger_index=args.stronger_index):  # Schedule number 0
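        # returns (epsilon, pgd_alpha, restarts); with use_stronger_adv the perturbation
        # budget is enlarged once after epoch 100 and again after epoch 105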
        if stronger_index == 0:
            epsilon_s = [epsilon * 1.5, epsilon * 2]
            pgd_alpha_s = [pgd_alpha, pgd_alpha]
        elif stronger_index == 1:
            epsilon_s = [epsilon * 1.5, epsilon * 2]
            pgd_alpha_s = [pgd_alpha * 1.25, pgd_alpha * 1.5]
        elif stronger_index == 2:
            epsilon_s = [epsilon * 2, epsilon * 2.5]
            pgd_alpha_s = [pgd_alpha * 1.5, pgd_alpha * 2]
        else:
            print('Undefined stronger index')

        if if_use_stronger_adv:
            if t < 100:
                if t < args.warmup_eps_epoch and warm_up_eps:
                    return (
                        t + 1.
                    ) / args.warmup_eps_epoch * epsilon, pgd_alpha, args.restarts
                else:
                    return epsilon, pgd_alpha, args.restarts
            elif t < 105:
                return epsilon_s[0], pgd_alpha_s[0], args.restarts
            else:
                return epsilon_s[1], pgd_alpha_s[1], args.restarts
        else:
            if t < args.warmup_eps_epoch and warm_up_eps:
                return (
                    t + 1.
                ) / args.warmup_eps_epoch * epsilon, pgd_alpha, args.restarts
            else:
                return epsilon, pgd_alpha, args.restarts

    #### Set the counter for the early stop of PGD ####
    def early_stop_counter_schedule(t):
        if t < args.earlystopPGDepoch1:
            return 1
        elif t < args.earlystopPGDepoch2:
            return 2
        else:
            return 3

    best_test_adv_acc = 0
    best_val_robust_acc = 0
    if args.resume:
        start_epoch = args.resume
        model.load_state_dict(
            torch.load(os.path.join(args.fname, f'model_{start_epoch-1}.pth')))
        opt.load_state_dict(
            torch.load(os.path.join(args.fname, f'opt_{start_epoch-1}.pth')))
        logger.info(f'Resuming at epoch {start_epoch}')

        best_test_adv_acc = torch.load(
            os.path.join(args.fname, f'model_best.pth'))['test_adv_acc']
        if args.val:
            best_val_robust_acc = torch.load(
                os.path.join(args.fname, f'model_val.pth'))['val_robust_acc']
    else:
        start_epoch = 0

    if args.eval:
        if not args.resume:
            logger.info(
                "No model loaded to evaluate, specify with --resume FNAME")
            return
        logger.info("[Evaluation mode]")

    model.cuda()
    model.eval()

    # Evaluate on original test data
    test_acc = 0
    test_n = 0

    for i, batch in enumerate(test_batches):
        X, y = batch['input'], batch['target']

        clean_input = normalize(X)
        output = model(clean_input)

        test_acc += (output.max(1)[1] == y).sum().item()
        test_n += y.size(0)

    logger.info('Initial Accuracy on Original Test Data: %.4f (Test Acc)',
                test_acc / test_n)

    test_adv_acc = 0
    test_adv_n = 0

    for i, batch in enumerate(test_adv_batches):
        adv_input = normalize(batch['input'])
        y = batch['target']

        robust_output = model(adv_input)
        test_adv_acc += (robust_output.max(1)[1] == y).sum().item()
        test_adv_n += y.size(0)

    logger.info(
        'Initial Accuracy on Adversarial Test Data: %.4f (Test Robust Acc)',
        test_adv_acc / test_adv_n)

    model.train()

    logger.info(
        'Epoch \t Train Acc \t Train Robust Acc \t Test Acc \t Test Robust Acc'
    )

    # Records per epoch for savetxt
    train_loss_record = []
    train_acc_record = []
    train_robust_loss_record = []
    train_robust_acc_record = []
    train_grad_record = []

    test_loss_record = []
    test_acc_record = []
    test_adv_loss_record = []
    test_adv_acc_record = []
    test_grad_record = []

    for epoch in range(start_epoch, epochs):
        model.train()
        start_time = time.time()

        train_loss = 0
        train_acc = 0
        train_robust_loss = 0
        train_robust_acc = 0
        train_n = 0
        train_grad = 0

        record_iter = torch.tensor([])

        for i, (batch,
                adv_batch) in enumerate(zip(train_batches, train_adv_batches)):
            if args.eval:
                break
            X, y = batch['input'], batch['target']

            adv_input = normalize(adv_batch['input'])
            adv_y = adv_batch['target']
            adv_input.requires_grad = True
            robust_output = model(adv_input)

            # Training losses
            if args.mixup:
                # mix the clean batch so that y_a, y_b and lam are available
                # for the mixup criterion below
                X, y_a, y_b, lam = mixup_data(X, y, args.mixup_alpha)
                clean_input = normalize(X)
                clean_input.requires_grad = True
                output = model(clean_input)
                robust_loss = mixup_criterion(criterion, robust_output, y_a,
                                              y_b, lam)

            elif args.mixture:
                clean_input = normalize(X)
                clean_input.requires_grad = True
                output = model(clean_input)
                robust_loss = args.mixture_alpha * criterion(
                    robust_output,
                    adv_y) + (1 - args.mixture_alpha) * criterion(output, y)

            else:
                clean_input = normalize(X)
                clean_input.requires_grad = True
                output = model(clean_input)
                if args.focalloss:
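                    # focal-style reweighting: down-weight examples the model is
                    # already confident on via (1 - p_true) ** focallosslambda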
                    criterion_nonreduct = nn.CrossEntropyLoss(reduction='none')
                    # probability assigned to the true class for each example
                    robust_confidence = F.softmax(robust_output, dim=1).gather(
                        1, adv_y.unsqueeze(1)).squeeze(1).detach()
                    robust_loss = (criterion_nonreduct(robust_output, adv_y) *
                                   ((1. - robust_confidence)**
                                    args.focallosslambda)).mean()

                elif args.use_DLRloss:
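                    # blend cross-entropy with the DLR loss; the DLR weight
                    # beta_ grows linearly towards 0.8 over training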
                    beta_ = 0.8 * epoch / args.epochs
                    robust_loss = (1. - beta_) * F.cross_entropy(
                        robust_output, adv_y) + beta_ * dlr_loss(
                            robust_output, adv_y)

                elif args.use_CWloss:
                    beta_ = 0.8 * epoch / args.epochs
                    robust_loss = (1. - beta_) * F.cross_entropy(
                        robust_output, adv_y) + beta_ * CW_loss(
                            robust_output, adv_y)

                elif args.use_FNandWN:
                    #print('use FN and WN with margin')
                    robust_loss = criterion(
                        args.s_FN * robust_output -
                        onehot_target_withmargin_HE, adv_y)

                else:
                    robust_loss = criterion(robust_output, adv_y)

            if args.l1:
                for name, param in model.named_parameters():
                    if 'bn' not in name and 'bias' not in name:
                        robust_loss += args.l1 * param.abs().sum()

            opt.zero_grad()
            robust_loss.backward()
            opt.step()

            clean_input = normalize(X)
            clean_input.requires_grad = True
            output = model(clean_input)
            if args.mixup:
                loss = mixup_criterion(criterion, output, y_a, y_b, lam)
            else:
                loss = criterion(output, y)

            # Get the gradient norm values
            input_grads = torch.autograd.grad(loss,
                                              clean_input,
                                              create_graph=False)[0]

            # Record the statistics for this batch
            train_robust_loss += robust_loss.item() * adv_y.size(0)
            train_robust_acc += (robust_output.max(1)[1] == adv_y).sum().item()
            train_loss += loss.item() * y.size(0)
            train_acc += (output.max(1)[1] == y).sum().item()
            train_n += y.size(0)
            train_grad += input_grads.abs().sum()

        train_time = time.time()
        if args.earlystopPGD:
            print('Iter mean: ',
                  record_iter.mean().item(), ' Iter std:  ',
                  record_iter.std().item())

        # Evaluate on test data
        model.eval()
        test_loss = 0
        test_acc = 0
        test_n = 0

        test_adv_loss = 0
        test_adv_acc = 0
        test_adv_n = 0

        for i, batch in enumerate(test_batches):
            X, y = batch['input'], batch['target']

            clean_input = normalize(X)
            output = model(clean_input)
            loss = criterion(output, y)

            test_loss += loss.item() * y.size(0)
            test_acc += (output.max(1)[1] == y).sum().item()
            test_n += y.size(0)

        for i, batch in enumerate(test_adv_batches):
            adv_input = normalize(batch['input'])
            y = batch['target']

            robust_output = model(adv_input)
            robust_loss = criterion(robust_output, y)

            test_adv_loss += robust_loss.item() * y.size(0)
            test_adv_acc += (robust_output.max(1)[1] == y).sum().item()
            test_adv_n += y.size(0)

        test_time = time.time()

        logger.info('%d \t %.4f \t %.4f \t\t %.4f \t %.4f', epoch + 1,
                    train_acc / train_n, train_robust_acc / train_n,
                    test_acc / test_n, test_adv_acc / test_adv_n)

        # Save results
        train_loss_record.append(train_loss / train_n)
        train_acc_record.append(train_acc / train_n)
        train_robust_loss_record.append(train_robust_loss / train_n)
        train_robust_acc_record.append(train_robust_acc / train_n)

        np.savetxt(args.fname + '/train_loss_record.txt',
                   np.array(train_loss_record))
        np.savetxt(args.fname + '/train_acc_record.txt',
                   np.array(train_acc_record))
        np.savetxt(args.fname + '/train_robust_loss_record.txt',
                   np.array(train_robust_loss_record))
        np.savetxt(args.fname + '/train_robust_acc_record.txt',
                   np.array(train_robust_acc_record))

        test_loss_record.append(test_loss / test_n)
        test_acc_record.append(test_acc / test_n)
        test_adv_loss_record.append(test_adv_loss / test_adv_n)
        test_adv_acc_record.append(test_adv_acc / test_adv_n)

        np.savetxt(args.fname + '/test_loss_record.txt',
                   np.array(test_loss_record))
        np.savetxt(args.fname + '/test_acc_record.txt',
                   np.array(test_acc_record))
        np.savetxt(args.fname + '/test_adv_loss_record.txt',
                   np.array(test_adv_loss_record))
        np.savetxt(args.fname + '/test_adv_acc_record.txt',
                   np.array(test_adv_acc_record))

        # save checkpoint
        if epoch > 99 or (epoch +
                          1) % args.chkpt_iters == 0 or epoch + 1 == epochs:
            torch.save(model.state_dict(),
                       os.path.join(args.fname, f'model_{epoch}.pth'))
            torch.save(opt.state_dict(),
                       os.path.join(args.fname, f'opt_{epoch}.pth'))

        # save best
        if test_adv_acc / test_adv_n > best_test_adv_acc:
            torch.save(
                {
                    'state_dict': model.state_dict(),
                    'test_adv_acc': test_adv_acc / test_adv_n,
                    'test_adv_loss': test_adv_loss / test_adv_n,
                    'test_loss': test_loss / test_n,
                    'test_acc': test_acc / test_n,
                }, os.path.join(args.fname, f'model_best.pth'))
            best_test_adv_acc = test_adv_acc / test_adv_n
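
# ---- Reference sketch --------------------------------------------------------
# The example above switches to LabelSmoothingLoss when --labelsmooth is set,
# but that class is defined elsewhere in the repository. Below is a minimal,
# assumed implementation of a standard label-smoothing cross-entropy, included
# only for reference; the repository's actual LabelSmoothingLoss may differ in
# signature and reduction.
import torch
import torch.nn as nn
import torch.nn.functional as F


class LabelSmoothingLoss(nn.Module):
    """Cross-entropy with uniform label smoothing (sketch)."""

    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, logits, target):
        num_classes = logits.size(-1)
        log_probs = F.log_softmax(logits, dim=-1)
        with torch.no_grad():
            # (1 - smoothing) mass on the true class, the rest spread uniformly
            true_dist = torch.full_like(log_probs,
                                        self.smoothing / (num_classes - 1))
            true_dist.scatter_(1, target.unsqueeze(1), 1.0 - self.smoothing)
        return torch.mean(torch.sum(-true_dist * log_probs, dim=-1))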
Example No. 16
def main():
    args = get_args()
    if args.awp_gamma <= 0.0:
        args.awp_warmup = np.infty

    if not os.path.exists(args.fname):
        os.makedirs(args.fname)

    logger = logging.getLogger(__name__)
    logging.basicConfig(format='[%(asctime)s] - %(message)s',
                        datefmt='%Y/%m/%d %H:%M:%S',
                        level=logging.DEBUG,
                        handlers=[
                            logging.FileHandler(
                                os.path.join(args.fname, 'output.log')),
                            logging.StreamHandler()
                        ])

    logger.info(args)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    num_workers = 2
    train_dataset = datasets.CIFAR100(args.data_dir,
                                      train=True,
                                      transform=train_transform,
                                      download=True)
    test_dataset = datasets.CIFAR100(args.data_dir,
                                     train=False,
                                     transform=test_transform,
                                     download=True)
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=num_workers,
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        pin_memory=True,
        num_workers=2,
    )

    epsilon = (args.epsilon / 255.)
    pgd_alpha = (args.pgd_alpha / 255.)

    if args.model == 'PreActResNet18':
        model = PreActResNet18(num_classes=100)
        proxy = PreActResNet18(num_classes=100)
    elif args.model == 'WideResNet':
        model = WideResNet(34,
                           100,
                           widen_factor=args.width_factor,
                           dropRate=0.0)
        proxy = WideResNet(34,
                           100,
                           widen_factor=args.width_factor,
                           dropRate=0.0)
    else:
        raise ValueError("Unknown model")

    model = model.cuda()
    proxy = proxy.cuda()

    if args.l2:
        decay, no_decay = [], []
        for name, param in model.named_parameters():
            if 'bn' not in name and 'bias' not in name:
                decay.append(param)
            else:
                no_decay.append(param)
        params = [{
            'params': decay,
            'weight_decay': args.l2
        }, {
            'params': no_decay,
            'weight_decay': 0
        }]
    else:
        params = model.parameters()

    opt = torch.optim.SGD(params,
                          lr=args.lr_max,
                          momentum=0.9,
                          weight_decay=5e-4)
    proxy_opt = torch.optim.SGD(proxy.parameters(), lr=0.01)
    awp_adversary = AdvWeightPerturb(model=model,
                                     proxy=proxy,
                                     proxy_optim=proxy_opt,
                                     gamma=args.awp_gamma)

    criterion = nn.CrossEntropyLoss()

    epochs = args.epochs

    if args.lr_schedule == 'superconverge':
        lr_schedule = lambda t: np.interp([t], [
            0, args.epochs * 2 // 5, args.epochs
        ], [0, args.lr_max, 0])[0]
        # lr_schedule = lambda t: np.interp([t], [0, args.epochs], [0, args.lr_max])[0]
    elif args.lr_schedule == 'piecewise':

        def lr_schedule(t):
            if t / args.epochs < 0.5:
                return args.lr_max
            elif t / args.epochs < 0.75:
                return args.lr_max / 10.
            else:
                return args.lr_max / 100.

    best_test_robust_acc = 0
    if args.resume:
        start_epoch = args.resume
        model.load_state_dict(
            torch.load(os.path.join(args.fname, f'model_{start_epoch-1}.pth')))
        opt.load_state_dict(
            torch.load(os.path.join(args.fname, f'opt_{start_epoch-1}.pth')))
        logger.info(f'Resuming at epoch {start_epoch}')

        if os.path.exists(os.path.join(args.fname, f'model_best.pth')):
            best_test_robust_acc = torch.load(
                os.path.join(args.fname, f'model_best.pth'))['test_robust_acc']
    else:
        start_epoch = 0

    logger.info(
        'Epoch \t Train Time \t Test Time \t LR \t \t Train Loss \t Train Acc \t Train Robust Loss \t Train Robust Acc \t Test Loss \t Test Acc \t Test Robust Loss \t Test Robust Acc'
    )
    for epoch in range(start_epoch, epochs):
        model.train()
        start_time = time.time()
        train_loss = 0
        train_acc = 0
        train_robust_loss = 0
        train_robust_acc = 0
        train_n = 0
        for i, (X, y) in enumerate(train_loader):
            X, y = X.cuda(), y.cuda()
            if args.mixup:
                X, y_a, y_b, lam = mixup_data(X, y, args.mixup_alpha)
                X, y_a, y_b = map(Variable, (X, y_a, y_b))
            lr = lr_schedule(epoch + (i + 1) / len(train_loader))
            opt.param_groups[0].update(lr=lr)

            if args.attack == 'pgd':
                # Random initialization
                if args.mixup:
                    delta = attack_pgd(model,
                                       X,
                                       y,
                                       epsilon,
                                       pgd_alpha,
                                       args.attack_iters,
                                       args.restarts,
                                       args.norm,
                                       mixup=True,
                                       y_a=y_a,
                                       y_b=y_b,
                                       lam=lam)
                else:
                    delta = attack_pgd(model, X, y, epsilon, pgd_alpha,
                                       args.attack_iters, args.restarts,
                                       args.norm)
                delta = delta.detach()
            # Standard training
            elif args.attack == 'none':
                delta = torch.zeros_like(X)
            X_adv = normalize(
                torch.clamp(X + delta[:X.size(0)],
                            min=lower_limit,
                            max=upper_limit))

            model.train()
            # calculate adversarial weight perturbation and perturb it
            if epoch >= args.awp_warmup:
                # not compatible with mixup currently.
                assert (not args.mixup)
                awp = awp_adversary.calc_awp(inputs_adv=X_adv, targets=y)
                awp_adversary.perturb(awp)

            robust_output = model(X_adv)
            if args.mixup:
                robust_loss = mixup_criterion(criterion, robust_output, y_a,
                                              y_b, lam)
            else:
                robust_loss = criterion(robust_output, y)

            if args.l1:
                for name, param in model.named_parameters():
                    if 'bn' not in name and 'bias' not in name:
                        robust_loss += args.l1 * param.abs().sum()

            opt.zero_grad()
            robust_loss.backward()
            opt.step()

            if epoch >= args.awp_warmup:
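                # undo the adversarial weight perturbation now that the optimizer
                # step on the perturbed weights has been taken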
                awp_adversary.restore(awp)

            output = model(normalize(X))
            if args.mixup:
                loss = mixup_criterion(criterion, output, y_a, y_b, lam)
            else:
                loss = criterion(output, y)

            train_robust_loss += robust_loss.item() * y.size(0)
            train_robust_acc += (robust_output.max(1)[1] == y).sum().item()
            train_loss += loss.item() * y.size(0)
            train_acc += (output.max(1)[1] == y).sum().item()
            train_n += y.size(0)

        train_time = time.time()

        model.eval()
        test_loss = 0
        test_acc = 0
        test_robust_loss = 0
        test_robust_acc = 0
        test_n = 0
        for i, (X, y) in enumerate(test_loader):
            X, y = X.cuda(), y.cuda()

            # Random initialization
            if args.attack == 'none':
                delta = torch.zeros_like(X)
            else:
                delta = attack_pgd(model, X, y, epsilon, pgd_alpha,
                                   args.attack_iters_test, args.restarts,
                                   args.norm)
            delta = delta.detach()

            robust_output = model(
                normalize(
                    torch.clamp(X + delta[:X.size(0)],
                                min=lower_limit,
                                max=upper_limit)))
            robust_loss = criterion(robust_output, y)

            output = model(normalize(X))
            loss = criterion(output, y)

            test_robust_loss += robust_loss.item() * y.size(0)
            test_robust_acc += (robust_output.max(1)[1] == y).sum().item()
            test_loss += loss.item() * y.size(0)
            test_acc += (output.max(1)[1] == y).sum().item()
            test_n += y.size(0)

        test_time = time.time()
        logger.info(
            '%d \t %.1f \t \t %.1f \t \t %.4f \t %.4f \t %.4f \t %.4f \t \t %.4f \t \t %.4f \t %.4f \t %.4f \t \t %.4f',
            epoch, train_time - start_time, test_time - train_time, lr,
            train_loss / train_n, train_acc / train_n,
            train_robust_loss / train_n, train_robust_acc / train_n,
            test_loss / test_n, test_acc / test_n, test_robust_loss / test_n,
            test_robust_acc / test_n)

        if (epoch + 1) % args.chkpt_iters == 0 or epoch + 1 == epochs:
            torch.save(model.state_dict(),
                       os.path.join(args.fname, f'model_{epoch}.pth'))
            torch.save(opt.state_dict(),
                       os.path.join(args.fname, f'opt_{epoch}.pth'))

        # save best
        if test_robust_acc / test_n > best_test_robust_acc:
            torch.save(
                {
                    'state_dict': model.state_dict(),
                    'test_robust_acc': test_robust_acc / test_n,
                    'test_robust_loss': test_robust_loss / test_n,
                    'test_loss': test_loss / test_n,
                    'test_acc': test_acc / test_n,
                }, os.path.join(args.fname, f'model_best.pth'))
            best_test_robust_acc = test_robust_acc / test_n
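
# ---- Reference sketch --------------------------------------------------------
# mixup_data and mixup_criterion are called by the training loops above but are
# defined elsewhere in the repository. A minimal sketch of the standard mixup
# helpers (Zhang et al., 2018) is given below as an assumption; the actual
# helpers may draw the permutation or handle alpha <= 0 differently.
import numpy as np
import torch


def mixup_data(x, y, alpha=1.0):
    """Mix a batch with a shuffled copy of itself; lam ~ Beta(alpha, alpha)."""
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    index = torch.randperm(x.size(0), device=x.device)
    mixed_x = lam * x + (1 - lam) * x[index]
    return mixed_x, y, y[index], lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Convex combination of the loss against both label sets."""
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)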
Example No. 17
def main():
    global args, best_prec1, exp_dir

    best_prec1 = 0
    args = parser.parse_args()
    print(args.lr_decay_at)
    assert args.normalization in ['GCN_ZCA',
                                  'GCN'], 'normalization {} unknown'.format(
                                      args.normalization)

    global zca
    if 'ZCA' in args.normalization:
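        # load precomputed ZCA whitening statistics for CIFAR-10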
        zca_params = torch.load('./data/cifar-10-batches-py/zca_params.pth')
        zca = ZCA(zca_params)
    else:
        zca = None

    exp_dir = os.path.join('experiments', args.name)
    if os.path.exists(exp_dir):
        print("same experiment exist...")
        #return
    else:
        os.makedirs(exp_dir)

    # DATA SETTINGS
    global dataset_train, dataset_test
    if args.dataset == 'cifar10':
        import cifar
        dataset_train = cifar.CIFAR10(args, train=True)
        dataset_test = cifar.CIFAR10(args, train=False)
    if args.UDA:
        # loader for UDA
        dataset_train_uda = cifar.CIFAR10(args, True, True)
        uda_loader = torch.utils.data.DataLoader(
            dataset_train_uda,
            batch_size=args.batch_size_unsup,
            shuffle=True,
            num_workers=args.num_workers,
            pin_memory=True)
        iter_uda = iter(uda_loader)
    else:
        iter_uda = None

    train_loader, test_loader = initialize_loader()

    # MODEL SETTINGS
    if args.arch == 'WRN-28-2':
        model = WideResNet(28, [100, 10][int(args.dataset == 'cifar10')],
                           2,
                           dropRate=args.dropout_rate)
        model = torch.nn.DataParallel(model.cuda())
    else:
        raise NotImplementedError('arch {} is not implemented'.format(
            args.arch))
    if args.optimizer == 'Adam':
        print("use Adam optimizer")
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.l2_reg)
    elif args.optimizer == 'SGD':
        print("use SGD optimizer")
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=0.9,
                                    weight_decay=args.l2_reg,
                                    nesterov=args.nesterov)

    if args.lr_decay == 'cosine':
        print("use cosine lr scheduler")
        global scheduler
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, args.max_iter, eta_min=args.final_lr)

    global batch_time, losses_sup, losses_unsup, top1, losses_l1
    batch_time, losses_sup, losses_unsup, top1, losses_l1 = (
        AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(),
        AverageMeter())
    t = time.time()
    model.train()
    iter_sup = iter(train_loader)
    for train_iter in range(args.max_iter):
        # TRAIN
        lr = adjust_learning_rate(optimizer, train_iter + 1)
        train(model,
              iter_sup,
              optimizer,
              train_iter,
              data_iterator_uda=iter_uda)

        # LOGGING
        if (train_iter + 1) % args.print_freq == 0:
            print('ITER: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'L1 Loss {l1_loss.val:.4f} ({l1_loss.avg:.4f})\t'
                  'Unsup Loss {unsup_loss.val:.4f} ({unsup_loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Learning rate {2} TSA th {3}'.format(
                      train_iter,
                      args.max_iter,
                      lr,
                      TSA_th(train_iter),
                      batch_time=batch_time,
                      loss=losses_sup,
                      l1_loss=losses_l1,
                      unsup_loss=losses_unsup,
                      top1=top1))

        if (train_iter +
                1) % args.eval_iter == 0 or train_iter + 1 == args.max_iter:
            # EVAL
            print("evaluation at iter {}".format(train_iter))
            prec1 = test(model, test_loader)

            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'iter': train_iter + 1,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)
            print("* Best accuracy: {}".format(best_prec1))
            eval_interval_time = time.time() - t
            t = time.time()
            print("total {} sec for {} iterations".format(
                eval_interval_time, args.eval_iter))
            seconds_remaining = eval_interval_time / float(
                args.eval_iter) * (args.max_iter - train_iter)
            print("{}:{}:{} remaining".format(
                int(seconds_remaining // 3600),
                int((seconds_remaining % 3600) // 60),
                int(seconds_remaining % 60)))
            model.train()
            batch_time, losses_sup, losses_unsup, top1, losses_l1 = (
                AverageMeter(), AverageMeter(), AverageMeter(),
                AverageMeter(), AverageMeter())
            iter_sup = iter(train_loader)
            if iter_uda is not None:
                iter_uda = iter(uda_loader)
Example No. 18
def main():
    global args, best_prec1
    args = parser.parse_args()
    if args.tensorboard: configure("runs/%s" % (args.name))

    # Data loading code
    normalize = transforms.Normalize(
        mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
        std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

    if args.augment:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
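            # reflection-pad 4 pixels per side before the 32x32 random crop
            # (Variable/volatile is legacy PyTorch API)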
            transforms.Lambda(lambda x: F.pad(Variable(
                x.unsqueeze(0), requires_grad=False, volatile=True),
                                              (4, 4, 4, 4),
                                              mode='reflect').data.squeeze()),
            transforms.ToPILImage(),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
    transform_test = transforms.Compose([transforms.ToTensor(), normalize])

    kwargs = {'num_workers': 1, 'pin_memory': True}
    assert (args.dataset == 'cifar10' or args.dataset == 'cifar100')

    train_set = datasets.ImageFolder(root="../classes_data_with_black/train",
                                     transform=transform_train)
    val_set = datasets.ImageFolder(root="../classes_data_with_black/val",
                                   transform=transform_test)

    train_loader = torch.utils.data.DataLoader(dataset=train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               **kwargs)

    val_loader = torch.utils.data.DataLoader(dataset=val_set,
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             **kwargs)

    # create model
    model = WideResNet(args.layers,
                       args.dataset == 'cifar10' and 10 or 100,
                       args.widen_factor,
                       dropRate=args.droprate)
    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    # for training on multiple GPUs.
    # Use CUDA_VISIBLE_DEVICES=0,1 to specify which GPUs to use
    # model = torch.nn.DataParallel(model).cuda()
    model = model.cuda()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                nesterov=args.nesterov,
                                weight_decay=args.weight_decay)

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch + 1)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, epoch)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            }, is_best)
    print('Best accuracy: ', best_prec1)
    torch.save(model, 'model_original_mnist.pt')
Example No. 19
def main():
    args = get_args()
    #     names = get_auto_fname(args)
    #     args.fname = args.fname + names

    args.fname += args.train_adversarial + "/"

    if not os.path.exists(args.fname):
        os.makedirs(args.fname)


    # eval_dir = args.fname + '/eval/' + args.test_adversarial + "/"

    eval_dir = args.fname + "eval/"
    if args.model_epoch != -1:
        eval_dir += str(args.model_epoch) + "/"
    else:
        eval_dir += "best/"
    eval_dir += args.test_adversarial + "/"
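    # final layout: <fname>eval/{<model_epoch>|best}/<test_adversarial>/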

    if not os.path.exists(eval_dir):
        print("Make dirs: ", eval_dir)
        os.makedirs(eval_dir)

    logger = logging.getLogger(__name__)
    logging.basicConfig(format='[%(asctime)s] - %(message)s',
                        datefmt='%Y/%m/%d %H:%M:%S',
                        level=logging.DEBUG,
                        handlers=[
                            logging.FileHandler(
                                os.path.join(eval_dir, "output.log")),
                            logging.StreamHandler()
                        ])

    logger.info(args)

    # Set seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # Prepare data
    dataset = cifar10(args.data_dir)

    x_test = (dataset['test']['data'] / 255.)
    x_test = transpose(x_test).astype(np.float32)
    y_test = dataset['test']['labels']

    test_set = list(zip(x_test, y_test))
    test_batches = Batches(test_set,
                           args.batch_size,
                           shuffle=False,
                           num_workers=4)

    print("Train Adv Attack Data: ", args.train_adversarial)

    adv_dir = "adv_examples/{}/".format(args.test_adversarial)
    train_path = adv_dir + "train.pth"
    test_path = adv_dir + "test.pth"

    ATTACK_LIST = [
        "autoattack", "autopgd", "bim", "cw", "deepfool", "fgsm", "newtonfool",
        "pgd", "pixelattack", "spatialtransformation", "squareattack"
    ]

    if args.test_adversarial in TOOLBOX_ADV_ATTACK_LIST:
        adv_train_data = torch.load(train_path)
        test_adv_images_on_train = adv_train_data["adv"]
        test_adv_labels_on_train = adv_train_data["label"]
        adv_test_data = torch.load(test_path)
        test_adv_images_on_test = adv_test_data["adv"]
        test_adv_labels_on_test = adv_test_data["label"]
    elif args.test_adversarial == "all":
        for i in range(len(ATTACK_LIST)):
            _adv_dir = "adv_examples/{}/".format(ATTACK_LIST[i])
            train_path = _adv_dir + "train.pth"
            test_path = _adv_dir + "test.pth"

            adv_train_data = torch.load(train_path)
            adv_test_data = torch.load(test_path)

            if i == 0:
                train_adv_images = adv_train_data["adv"]
                train_adv_labels = adv_train_data["label"]
                test_adv_images = adv_test_data["adv"]
                test_adv_labels = adv_test_data["label"]
            else:
                #                 print(train_adv_images.shape)
                #                 print(adv_train_data["adv"].shape)
                train_adv_images = np.concatenate(
                    (train_adv_images, adv_train_data["adv"]))
                train_adv_labels = np.concatenate(
                    (train_adv_labels, adv_train_data["label"]))
                test_adv_images = np.concatenate(
                    (test_adv_images, adv_test_data["adv"]))
                test_adv_labels = np.concatenate(
                    (test_adv_labels, adv_test_data["label"]))

        test_adv_images_on_train = train_adv_images
        test_adv_labels_on_train = train_adv_labels
        test_adv_images_on_test = test_adv_images
        test_adv_labels_on_test = test_adv_labels

    elif args.test_adversarial in ["ffgsm", "mifgsm", "tpgd"]:
        adv_data = {}
        adv_data["adv"], adv_data["label"] = torch.load(train_path)
        test_adv_images_on_train = adv_data["adv"].numpy()
        test_adv_labels_on_train = adv_data["label"].numpy()
        adv_data = {}
        adv_data["adv"], adv_data["label"] = torch.load(test_path)
        test_adv_images_on_test = adv_data["adv"].numpy()
        test_adv_labels_on_test = adv_data["label"].numpy()
    else:
        raise ValueError("Unknown adversarial data")

    print("Test Adv Attack Data: ", args.test_adversarial)

    test_adv_on_train_set = list(
        zip(test_adv_images_on_train, test_adv_labels_on_train))

    test_adv_on_train_batches = Batches(test_adv_on_train_set,
                                        args.batch_size,
                                        shuffle=False,
                                        set_random_choices=False,
                                        num_workers=4)

    test_adv_on_test_set = list(
        zip(test_adv_images_on_test, test_adv_labels_on_test))

    test_adv_on_test_batches = Batches(test_adv_on_test_set,
                                       args.batch_size,
                                       shuffle=False,
                                       num_workers=4)

    # Set perturbations
    epsilon = (args.epsilon / 255.)
    test_epsilon = (args.test_epsilon / 255.)
    pgd_alpha = (args.pgd_alpha / 255.)
    test_pgd_alpha = (args.test_pgd_alpha / 255.)

    # Set models
    model = None
    if args.model == "resnet18":
        model = resnet18(pretrained=True)
    elif args.model == "resnet20":
        model = resnet20()
    elif args.model == "vgg16bn":
        model = vgg16_bn(pretrained=True)
    elif args.model == "densenet121":
        model = densenet121(pretrained=True)
    elif args.model == "googlenet":
        model = googlenet(pretrained=True)
    elif args.model == "inceptionv3":
        model = inception_v3(pretrained=True)
    elif args.model == 'WideResNet':
        model = WideResNet(34,
                           10,
                           widen_factor=10,
                           dropRate=0.0,
                           normalize=args.use_FNandWN,
                           activation=args.activation,
                           softplus_beta=args.softplus_beta)
    elif args.model == 'WideResNet_20':
        model = WideResNet(34,
                           10,
                           widen_factor=20,
                           dropRate=0.0,
                           normalize=args.use_FNandWN,
                           activation=args.activation,
                           softplus_beta=args.softplus_beta)
    else:
        raise ValueError("Unknown model")

    model.cuda()
    model.train()

    # Set training hyperparameters
    if args.l2:
        decay, no_decay = [], []
        for name, param in model.named_parameters():
            if 'bn' not in name and 'bias' not in name:
                decay.append(param)
            else:
                no_decay.append(param)
        params = [{
            'params': decay,
            'weight_decay': args.l2
        }, {
            'params': no_decay,
            'weight_decay': 0
        }]
    else:
        params = model.parameters()
    if args.lr_schedule == 'cyclic':
        opt = torch.optim.Adam(params,
                               lr=args.lr_max,
                               betas=(0.9, 0.999),
                               eps=1e-08,
                               weight_decay=args.weight_decay)
    else:
        if args.optimizer == 'momentum':
            opt = torch.optim.SGD(params,
                                  lr=args.lr_max,
                                  momentum=0.9,
                                  weight_decay=args.weight_decay)
        elif args.optimizer == 'Nesterov':
            opt = torch.optim.SGD(params,
                                  lr=args.lr_max,
                                  momentum=0.9,
                                  weight_decay=args.weight_decay,
                                  nesterov=True)
        elif args.optimizer == 'SGD_GC':
            opt = SGD_GC(params,
                         lr=args.lr_max,
                         momentum=0.9,
                         weight_decay=args.weight_decay)
        elif args.optimizer == 'SGD_GCC':
            opt = SGD_GCC(params,
                          lr=args.lr_max,
                          momentum=0.9,
                          weight_decay=args.weight_decay)
        elif args.optimizer == 'Adam':
            opt = torch.optim.Adam(params,
                                   lr=args.lr_max,
                                   betas=(0.9, 0.999),
                                   eps=1e-08,
                                   weight_decay=args.weight_decay)
        elif args.optimizer == 'AdamW':
            opt = torch.optim.AdamW(params,
                                    lr=args.lr_max,
                                    betas=(0.9, 0.999),
                                    eps=1e-08,
                                    weight_decay=args.weight_decay)
        else:
            raise ValueError("Unknown optimizer")

    # Cross-entropy (mean)
    if args.labelsmooth:
        criterion = LabelSmoothingLoss(smoothing=args.labelsmoothvalue)
    else:
        criterion = nn.CrossEntropyLoss()
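    # LabelSmoothingLoss is assumed to be the usual smoothed cross-entropy, with
    # args.labelsmoothvalue as the smoothing mass spread over non-target classes.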

    # If we use freeAT or fastAT with previous init, keep a persistent
    # perturbation tensor that is reused across minibatches
    if (args.train_adversarial == 'free'
            or (args.train_adversarial == 'fgsm'
                and args.fgsm_init == 'previous')):
        delta = torch.zeros(args.batch_size, 3, 32, 32).cuda()
        delta.requires_grad = True

    if args.train_adversarial == 'free':
        epochs = int(math.ceil(args.epochs / args.train_adversarial_iters))
    else:
        epochs = args.epochs
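    # Free AT replays each minibatch args.train_adversarial_iters times, so the
    # epoch budget is divided by that factor to keep the total cost comparable.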

    # Set lr schedule
    if args.lr_schedule == 'superconverge':
        lr_schedule = lambda t: np.interp([t], [
            0, args.epochs * 2 // 5, args.epochs
        ], [0, args.lr_max, 0])[0]
    elif args.lr_schedule == 'piecewise':

        def lr_schedule(t, warm_up_lr=args.warmup_lr):
            if t < 100:
                if warm_up_lr and t < args.warmup_lr_epoch:
                    return (t + 1.) / args.warmup_lr_epoch * args.lr_max
                else:
                    return args.lr_max
            if args.lrdecay == 'lineardecay':
                if t < 105:
                    return args.lr_max * 0.02 * (105 - t)
                else:
                    return 0.
            elif args.lrdecay == 'intenselr':
                if t < 102:
                    return args.lr_max / 10.
                else:
                    return args.lr_max / 100.
            elif args.lrdecay == 'looselr':
                if t < 150:
                    return args.lr_max / 10.
                else:
                    return args.lr_max / 100.
            elif args.lrdecay == 'base':
                if t < 105:
                    return args.lr_max / 10.
                else:
                    return args.lr_max / 100.
            else:
                raise ValueError("Unknown lrdecay")
    elif args.lr_schedule == 'linear':
        lr_schedule = lambda t: np.interp([t], [
            0, args.epochs // 3, args.epochs * 2 // 3, args.epochs
        ], [args.lr_max, args.lr_max, args.lr_max / 10, args.lr_max / 100])[0]
    elif args.lr_schedule == 'onedrop':

        def lr_schedule(t):
            if t < args.lr_drop_epoch:
                return args.lr_max
            else:
                return args.lr_one_drop
    elif args.lr_schedule == 'multipledecay':

        def lr_schedule(t):
            return args.lr_max - (t //
                                  (args.epochs // 10)) * (args.lr_max / 10)
    elif args.lr_schedule == 'cosine':

        def lr_schedule(t):
            return args.lr_max * 0.5 * (1 + np.cos(t / args.epochs * np.pi))
    elif args.lr_schedule == 'cyclic':

        def lr_schedule(t, stepsize=18, min_lr=1e-5, max_lr=args.lr_max):

            # Scaler: we can adapt this if we do not want the triangular CLR
            scaler = lambda x: 1.

            # Additional function to see where on the cycle we are
            cycle = math.floor(1 + t / (2 * stepsize))
            x = abs(t / stepsize - 2 * cycle + 1)
            relative = max(0, (1 - x)) * scaler(cycle)

            return min_lr + (max_lr - min_lr) * relative
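
    # Minimal sketch (not part of this file) of how these per-epoch schedules are
    # typically consumed in the training loop; the names below are assumptions:
    #     lr = lr_schedule(epoch + (i + 1) / len(train_batches))
    #     for group in opt.param_groups:
    #         group['lr'] = lr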

    #### Use stronger adversarial attacks when the lr is decayed ####
    # The schedule returns (epsilon, pgd_alpha, restarts) for epoch t: epsilon is
    # warmed up over the first warmup_eps_epoch epochs (when args.warmup_eps) and,
    # if args.use_stronger_adv, enlarged after epochs 100 and 105 according to
    # args.stronger_index.
    def eps_alpha_schedule(
            t,
            warm_up_eps=args.warmup_eps,
            if_use_stronger_adv=args.use_stronger_adv,
            stronger_index=args.stronger_index):  # Schedule number 0
        if stronger_index == 0:
            epsilon_s = [epsilon * 1.5, epsilon * 2]
            pgd_alpha_s = [pgd_alpha, pgd_alpha]
        elif stronger_index == 1:
            epsilon_s = [epsilon * 1.5, epsilon * 2]
            pgd_alpha_s = [pgd_alpha * 1.25, pgd_alpha * 1.5]
        elif stronger_index == 2:
            epsilon_s = [epsilon * 2, epsilon * 2.5]
            pgd_alpha_s = [pgd_alpha * 1.5, pgd_alpha * 2]
        else:
            raise ValueError('Undefined stronger index')

        if if_use_stronger_adv:
            if t < 100:
                if t < args.warmup_eps_epoch and warm_up_eps:
                    return (
                        t + 1.
                    ) / args.warmup_eps_epoch * epsilon, pgd_alpha, args.restarts
                else:
                    return epsilon, pgd_alpha, args.restarts
            elif t < 105:
                return epsilon_s[0], pgd_alpha_s[0], args.restarts
            else:
                return epsilon_s[1], pgd_alpha_s[1], args.restarts
        else:
            if t < args.warmup_eps_epoch and warm_up_eps:
                return (
                    t + 1.
                ) / args.warmup_eps_epoch * epsilon, pgd_alpha, args.restarts
            else:
                return epsilon, pgd_alpha, args.restarts

    #### Set the counter for the early stop of PGD ####
    def early_stop_counter_schedule(t):
        if t < args.earlystopPGDepoch1:
            return 1
        elif t < args.earlystopPGDepoch2:
            return 2
        else:
            return 3
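
    # early_stop_counter_schedule grows 1 -> 2 -> 3 as training passes the two
    # earlystopPGD epochs; presumably the counter controls how many extra PGD
    # steps are taken once an example is already misclassified (early-stopped PGD).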

    best_test_adv_acc = 0
    best_val_robust_acc = 0
    if args.resume:
        start_epoch = args.resume
        model.load_state_dict(
            torch.load(os.path.join(args.fname, f'model_{start_epoch-1}.pth')))
        opt.load_state_dict(
            torch.load(os.path.join(args.fname, f'opt_{start_epoch-1}.pth')))
        logger.info(f'Resuming at epoch {start_epoch}')

        best_test_adv_acc = torch.load(
            os.path.join(args.fname, 'model_best.pth'))['test_adv_acc']
        if args.val:
            best_val_robust_acc = torch.load(
                os.path.join(args.fname, 'model_val.pth'))['val_robust_acc']
    else:
        start_epoch = 0

    if args.model_epoch == -1:
        if args.train_adversarial == "original":
            logger.info('Run using the original model')
        else:
            logger.info('Run using the best model')
            model.load_state_dict(
                torch.load(os.path.join(args.fname,
                                        'model_best.pth'))["state_dict"])
    else:
        model.load_state_dict(
            torch.load(
                os.path.join(args.fname,
                             'model_' + str(args.model_epoch) + '.pth')))

    if args.eval:
        logger.info("[Evaluation mode]")

    # Evaluate on test data
    model.eval()

    test_loss = 0
    test_acc = 0
    test_n = 0
    y_original = np.array([])
    y_original_pred = np.array([])

    test_adv_test_loss = 0
    test_adv_test_acc = 0
    test_adv_test_n = 0
    y_adv = np.array([])
    y_adv_pred = np.array([])

    test_adv_train_loss = 0
    test_adv_train_acc = 0
    test_adv_train_n = 0

    # Pass 1: clean accuracy of the loaded model on the original test set
    for i, batch in enumerate(test_batches):
        X, y = batch['input'], batch['target']

        clean_input = normalize(X)
        output = model(clean_input)
        loss = criterion(output, y)

        test_loss += loss.item() * y.size(0)
        test_acc += (output.max(1)[1] == y).sum().item()
        test_n += y.size(0)

        y_original = np.append(y_original, y.cpu().numpy())
        y_original_pred = np.append(y_original_pred,
                                    output.max(1)[1].cpu().numpy())

    # Pass 2: accuracy on adversarial examples built from the test set
    # (test_adv_on_test_batches)
    for i, batch in enumerate(test_adv_on_test_batches):
        adv_input = normalize(batch['input'])
        y = batch['target']

        cross_robust_output = model(adv_input)
        cross_robust_loss = criterion(cross_robust_output, y)

        test_adv_test_loss += cross_robust_loss.item() * y.size(0)
        test_adv_test_acc += (cross_robust_output.max(1)[1] == y).sum().item()
        test_adv_test_n += y.size(0)

        y_adv = np.append(y_adv, y.cpu().numpy())
        y_adv_pred = np.append(y_adv_pred,
                               cross_robust_output.max(1)[1].cpu().numpy())

    # Pass 3: accuracy on adversarial examples built from the training set
    # (test_adv_on_train_batches)
    for i, batch in enumerate(test_adv_on_train_batches):
        adv_input = normalize(batch['input'])
        y = batch['target']

        cross_robust_output = model(adv_input)
        cross_robust_loss = criterion(cross_robust_output, y)

        test_adv_train_loss += cross_robust_loss.item() * y.size(0)
        test_adv_train_acc += (cross_robust_output.max(1)[1] == y).sum().item()
        test_adv_train_n += y.size(0)

    test_time = time.time()

    logger.info(
        "Test Acc \tTest Robust Acc on Test \tTest Robust Acc on Train")
    logger.info('%.4f \t\t %.4f \t\t %.4f', test_acc / test_n,
                test_adv_test_acc / test_adv_test_n,
                test_adv_train_acc / test_adv_train_n)

    # np.int was removed in NumPy 1.24; use the builtin int dtype instead
    y_original = y_original.astype(int)
    y_original_pred = y_original_pred.astype(int)

    y_adv = y_adv.astype(int)
    y_adv_pred = y_adv_pred.astype(int)

    logger.info("Y_original")
    logger.info(y_original)
    np.savetxt(os.path.join(eval_dir, "Y_original.txt"), y_original, fmt='%i')

    logger.info("Y_original_pred")
    logger.info(y_original_pred)
    np.savetxt(os.path.join(eval_dir, "Y_original_pred.txt"),
               y_original_pred,
               fmt='%i')

    logger.info("Y_adv")
    logger.info(y_adv)
    np.savetxt(os.path.join(eval_dir, "Y_adv.txt"), y_adv, fmt='%i')

    logger.info("Y_adv_pred")
    logger.info(y_adv_pred)
    np.savetxt(os.path.join(eval_dir, "Y_adv_pred.txt"), y_adv_pred, fmt='%i')