Example #1
def main():
    global args, best_prec1
    args = parser.parse_args()
    if args.tensorboard: configure("runs/%s" % (args.name))

    # Data loading code
    normalize = transforms.Normalize(
        mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
        std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

    if args.augment:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
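            # reflection-pad 4 px per side before the 32x32 random crop (Variable/volatile is pre-0.4 PyTorch API)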
            transforms.Lambda(lambda x: F.pad(Variable(
                x.unsqueeze(0), requires_grad=False, volatile=True),
                                              (4, 4, 4, 4),
                                              mode='reflect').data.squeeze()),
            transforms.ToPILImage(),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
    transform_test = transforms.Compose([transforms.ToTensor(), normalize])

    kwargs = {'num_workers': 1, 'pin_memory': True}
    assert (args.dataset == 'cifar10' or args.dataset == 'cifar100')

    train_set = datasets.ImageFolder(root="../classes_data_with_black/train",
                                     transform=transform_train)
    val_set = datasets.ImageFolder(root="../classes_data_with_black/val",
                                   transform=transform_test)

    train_loader = torch.utils.data.DataLoader(dataset=train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               **kwargs)

    val_loader = torch.utils.data.DataLoader(dataset=val_set,
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             **kwargs)

    # create model
    model = WideResNet(args.layers,
                       10 if args.dataset == 'cifar10' else 100,
                       args.widen_factor,
                       dropRate=args.droprate)
    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    # for training on multiple GPUs.
    # Use CUDA_VISIBLE_DEVICES=0,1 to specify which GPUs to use
    # model = torch.nn.DataParallel(model).cuda()
    model = model.cuda()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                nesterov=args.nesterov,
                                weight_decay=args.weight_decay)

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch + 1)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, epoch)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            }, is_best)
    print('Best accuracy: ', best_prec1)
    torch.save(model, 'model_original_mnist.pt')
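
Note: the loop above calls adjust_learning_rate and save_checkpoint, which are defined elsewhere in the script. A minimal sketch of what such helpers could look like; the milestone epochs, decay factor, and file names below are assumptions, not taken from this example:

import shutil
import torch

def adjust_learning_rate(optimizer, epoch):
    # Hypothetical step schedule: decay the base LR at assumed milestones.
    # Uses the module-level args, as the example above does.
    lr = args.lr * (0.1 ** sum(epoch >= m for m in (60, 120, 160)))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    # Always write the latest state; copy it aside when it is the best so far.
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')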
Example #2
def main():
    global args, best_prec1, data
    if args.tensorboard: configure("runs/%s" % (args.name))

    # Data loading code
    normalize = transforms.Normalize(
        mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
        std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

    if args.augment:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: F.pad(Variable(
                x.unsqueeze(0), requires_grad=False, volatile=True),
                                              (4, 4, 4, 4),
                                              mode='reflect').data.squeeze()),
            transforms.ToPILImage(),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
    transform_test = transforms.Compose([transforms.ToTensor(), normalize])

    kwargs = {'num_workers': args.num_workers, 'pin_memory': use_cuda}
    assert (args.dataset == 'cifar10' or args.dataset == 'cifar100')
    print("Creating the DataLoader...")
    train_loader = torch.utils.data.DataLoader(
        datasets.__dict__[args.dataset.upper()]('../data',
                                                train=True,
                                                download=True,
                                                transform=transform_train),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs)
    val_loader = torch.utils.data.DataLoader(
        datasets.__dict__[args.dataset.upper()]('../data',
                                                train=False,
                                                transform=transform_test),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs)

    # create model
    print("Creating the model...")
    model = WideResNet(args.layers,
                       10 if args.dataset == 'cifar10' else 100,
                       args.widen_factor,
                       dropRate=args.droprate)

    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    # for training on multiple GPUs.
    # Use CUDA_VISIBLE_DEVICES=0,1 to specify which GPUs to use
    if use_cuda:
        print("Moving the model to the GPU")
        model = torch.nn.DataParallel(model, device_ids=device_ids)
        model = model.cuda()
        #model = model.cuda()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # define loss function (criterion) and optimizer

    criterion = nn.CrossEntropyLoss()
    if use_cuda:
        criterion = criterion.cuda()
    #optimizer = torch.optim.SGD(model.parameters(), args.lr,
    #                            momentum=args.momentum, nesterov=args.nesterov,
    #                            weight_decay=args.weight_decay)
    optimizer = torch.optim.ASGD(model.parameters(), args.lr)

    data = []
    train_time = 0
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch + 1)

        # train for one epoch
        start = time.time()
        train(train_loader, model, criterion, optimizer, epoch)
        train_time += time.time() - start

        # evaluate on validation set
        datum = validate(val_loader, model, criterion, epoch)
        data += [{
            'train_time': train_time,
            'epoch': epoch + 1,
            **vars(args),
            **datum
        }]
        df = pd.DataFrame(data)
        _write_csv(df, id=f'{args.num_workers}_{args.seed}')
        pprint({
            k: v
            for k, v in data[-1].items() if k in
            ['train_time', 'num_workers', 'test_loss', 'test_acc', 'epoch']
        })
        prec1 = datum['test_acc']

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            }, is_best)
    print('Best accuracy: ', best_prec1)
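
Note: this variant expects validate to return a metrics dict (at least test_loss and test_acc) and relies on a small _write_csv helper. A plausible sketch under those assumptions; the output directory and file naming are guesses, not read from this example:

import os
import pandas as pd

def _write_csv(df: pd.DataFrame, id: str, out_dir: str = 'results') -> None:
    # Hypothetical helper: persist the accumulated per-epoch metrics table.
    os.makedirs(out_dir, exist_ok=True)
    df.to_csv(os.path.join(out_dir, f'{id}.csv'), index=False)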
Example #3
def main():
    args = get_args()
    if args.awp_gamma <= 0.0:
        args.awp_warmup = np.inf

    if not os.path.exists(args.fname):
        os.makedirs(args.fname)

    logger = logging.getLogger(__name__)
    logging.basicConfig(format='[%(asctime)s] - %(message)s',
                        datefmt='%Y/%m/%d %H:%M:%S',
                        level=logging.DEBUG,
                        handlers=[
                            logging.FileHandler(
                                os.path.join(args.fname, 'output.log')),
                            logging.StreamHandler()
                        ])

    logger.info(args)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    num_workers = 2
    train_dataset = datasets.CIFAR100(args.data_dir,
                                      train=True,
                                      transform=train_transform,
                                      download=True)
    test_dataset = datasets.CIFAR100(args.data_dir,
                                     train=False,
                                     transform=test_transform,
                                     download=True)
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=num_workers,
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        pin_memory=True,
        num_workers=2,
    )

    epsilon = (args.epsilon / 255.)
    pgd_alpha = (args.pgd_alpha / 255.)

    if args.model == 'PreActResNet18':
        model = PreActResNet18(num_classes=100)
        proxy = PreActResNet18(num_classes=100)
    elif args.model == 'WideResNet':
        model = WideResNet(34,
                           100,
                           widen_factor=args.width_factor,
                           dropRate=0.0)
        proxy = WideResNet(34,
                           100,
                           widen_factor=args.width_factor,
                           dropRate=0.0)
    else:
        raise ValueError("Unknown model")

    model = model.cuda()
    proxy = proxy.cuda()

    if args.l2:
        decay, no_decay = [], []
        for name, param in model.named_parameters():
            if 'bn' not in name and 'bias' not in name:
                decay.append(param)
            else:
                no_decay.append(param)
        params = [{
            'params': decay,
            'weight_decay': args.l2
        }, {
            'params': no_decay,
            'weight_decay': 0
        }]
    else:
        params = model.parameters()

    opt = torch.optim.SGD(params,
                          lr=args.lr_max,
                          momentum=0.9,
                          weight_decay=5e-4)
    proxy_opt = torch.optim.SGD(proxy.parameters(), lr=0.01)
    awp_adversary = AdvWeightPerturb(model=model,
                                     proxy=proxy,
                                     proxy_optim=proxy_opt,
                                     gamma=args.awp_gamma)

    criterion = nn.CrossEntropyLoss()

    epochs = args.epochs

    if args.lr_schedule == 'superconverge':
        lr_schedule = lambda t: np.interp([t], [
            0, args.epochs * 2 // 5, args.epochs
        ], [0, args.lr_max, 0])[0]
        # lr_schedule = lambda t: np.interp([t], [0, args.epochs], [0, args.lr_max])[0]
    elif args.lr_schedule == 'piecewise':

        def lr_schedule(t):
            if t / args.epochs < 0.5:
                return args.lr_max
            elif t / args.epochs < 0.75:
                return args.lr_max / 10.
            else:
                return args.lr_max / 100.

    best_test_robust_acc = 0
    if args.resume:
        start_epoch = args.resume
        model.load_state_dict(
            torch.load(os.path.join(args.fname, f'model_{start_epoch-1}.pth')))
        opt.load_state_dict(
            torch.load(os.path.join(args.fname, f'opt_{start_epoch-1}.pth')))
        logger.info(f'Resuming at epoch {start_epoch}')

        if os.path.exists(os.path.join(args.fname, f'model_best.pth')):
            best_test_robust_acc = torch.load(
                os.path.join(args.fname, f'model_best.pth'))['test_robust_acc']
    else:
        start_epoch = 0

    logger.info(
        'Epoch \t Train Time \t Test Time \t LR \t \t Train Loss \t Train Acc \t Train Robust Loss \t Train Robust Acc \t Test Loss \t Test Acc \t Test Robust Loss \t Test Robust Acc'
    )
    for epoch in range(start_epoch, epochs):
        model.train()
        start_time = time.time()
        train_loss = 0
        train_acc = 0
        train_robust_loss = 0
        train_robust_acc = 0
        train_n = 0
        for i, (X, y) in enumerate(train_loader):
            X, y = X.cuda(), y.cuda()
            if args.mixup:
                X, y_a, y_b, lam = mixup_data(X, y, args.mixup_alpha)
                X, y_a, y_b = map(Variable, (X, y_a, y_b))
            lr = lr_schedule(epoch + (i + 1) / len(train_loader))
            opt.param_groups[0].update(lr=lr)

            if args.attack == 'pgd':
                # Random initialization
                if args.mixup:
                    delta = attack_pgd(model,
                                       X,
                                       y,
                                       epsilon,
                                       pgd_alpha,
                                       args.attack_iters,
                                       args.restarts,
                                       args.norm,
                                       mixup=True,
                                       y_a=y_a,
                                       y_b=y_b,
                                       lam=lam)
                else:
                    delta = attack_pgd(model, X, y, epsilon, pgd_alpha,
                                       args.attack_iters, args.restarts,
                                       args.norm)
                delta = delta.detach()
            # Standard training
            elif args.attack == 'none':
                delta = torch.zeros_like(X)
            X_adv = normalize(
                torch.clamp(X + delta[:X.size(0)],
                            min=lower_limit,
                            max=upper_limit))

            model.train()
            # calculate adversarial weight perturbation and perturb it
            if epoch >= args.awp_warmup:
                # not compatible to mixup currently.
                assert (not args.mixup)
                awp = awp_adversary.calc_awp(inputs_adv=X_adv, targets=y)
                awp_adversary.perturb(awp)

            robust_output = model(X_adv)
            if args.mixup:
                robust_loss = mixup_criterion(criterion, robust_output, y_a,
                                              y_b, lam)
            else:
                robust_loss = criterion(robust_output, y)

            if args.l1:
                for name, param in model.named_parameters():
                    if 'bn' not in name and 'bias' not in name:
                        robust_loss += args.l1 * param.abs().sum()

            opt.zero_grad()
            robust_loss.backward()
            opt.step()

            if epoch >= args.awp_warmup:
                awp_adversary.restore(awp)

            output = model(normalize(X))
            if args.mixup:
                loss = mixup_criterion(criterion, output, y_a, y_b, lam)
            else:
                loss = criterion(output, y)

            train_robust_loss += robust_loss.item() * y.size(0)
            train_robust_acc += (robust_output.max(1)[1] == y).sum().item()
            train_loss += loss.item() * y.size(0)
            train_acc += (output.max(1)[1] == y).sum().item()
            train_n += y.size(0)

        train_time = time.time()

        model.eval()
        test_loss = 0
        test_acc = 0
        test_robust_loss = 0
        test_robust_acc = 0
        test_n = 0
        for i, (X, y) in enumerate(test_loader):
            X, y = X.cuda(), y.cuda()

            # Random initialization
            if args.attack == 'none':
                delta = torch.zeros_like(X)
            else:
                delta = attack_pgd(model, X, y, epsilon, pgd_alpha,
                                   args.attack_iters_test, args.restarts,
                                   args.norm)
            delta = delta.detach()

            robust_output = model(
                normalize(
                    torch.clamp(X + delta[:X.size(0)],
                                min=lower_limit,
                                max=upper_limit)))
            robust_loss = criterion(robust_output, y)

            output = model(normalize(X))
            loss = criterion(output, y)

            test_robust_loss += robust_loss.item() * y.size(0)
            test_robust_acc += (robust_output.max(1)[1] == y).sum().item()
            test_loss += loss.item() * y.size(0)
            test_acc += (output.max(1)[1] == y).sum().item()
            test_n += y.size(0)

        test_time = time.time()
        logger.info(
            '%d \t %.1f \t \t %.1f \t \t %.4f \t %.4f \t %.4f \t %.4f \t \t %.4f \t \t %.4f \t %.4f \t %.4f \t \t %.4f',
            epoch, train_time - start_time, test_time - train_time, lr,
            train_loss / train_n, train_acc / train_n,
            train_robust_loss / train_n, train_robust_acc / train_n,
            test_loss / test_n, test_acc / test_n, test_robust_loss / test_n,
            test_robust_acc / test_n)

        if (epoch + 1) % args.chkpt_iters == 0 or epoch + 1 == epochs:
            torch.save(model.state_dict(),
                       os.path.join(args.fname, f'model_{epoch}.pth'))
            torch.save(opt.state_dict(),
                       os.path.join(args.fname, f'opt_{epoch}.pth'))

        # save best
        if test_robust_acc / test_n > best_test_robust_acc:
            torch.save(
                {
                    'state_dict': model.state_dict(),
                    'test_robust_acc': test_robust_acc / test_n,
                    'test_robust_loss': test_robust_loss / test_n,
                    'test_loss': test_loss / test_n,
                    'test_acc': test_acc / test_n,
                }, os.path.join(args.fname, f'model_best.pth'))
            best_test_robust_acc = test_robust_acc / test_n
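
Note: normalize, lower_limit and upper_limit are module-level globals in this script. For CIFAR-style inputs scaled to [0, 1], a typical definition looks like the sketch below; the mean/std values are the commonly used CIFAR-100 statistics and are an assumption, not read from this example:

import torch

cifar100_mean = torch.tensor([0.5071, 0.4865, 0.4409]).view(3, 1, 1).cuda()
cifar100_std = torch.tensor([0.2673, 0.2564, 0.2762]).view(3, 1, 1).cuda()

lower_limit, upper_limit = 0.0, 1.0  # keep adversarial images in the valid pixel range

def normalize(x):
    # Standardize a batch of [0, 1] images channel-wise before the forward pass.
    return (x - cifar100_mean) / cifar100_std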
Example #4
def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)
    # create model
    if args.arch == 'wrn':
        num_classes = 1000
        size = 64
        config = [args.depth, args.k, args.drop, num_classes, size]
        model = WideResNet(config)
    elif args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(
                (args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [50, 65, 75],
                                                     gamma=0.1)
    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    # normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
    #                                  std=[0.229, 0.224, 0.225])
    normalize = transforms.Normalize(mean=[0.482, 0.458, 0.408],
                                     std=[0.269, 0.261, 0.276])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        valdir, transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion, args)
        return

    begin = time.time()
    for epoch in trange(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        # adjust_learning_rate(optimizer, epoch, args)
        scheduler.step()

        epoch_begin = time.time()
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, args)
        epoch_end = time.time()

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)
        print("epoch %d, time %.2f" % (epoch, epoch_end - epoch_begin))

        if not args.multiprocessing_distributed or (
                args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            fname = "%s.pth.tar" % args.arch
            if args.arch == "wrn":
                fname = "%s-%d-%s.pth.tar" % (args.arch, args.depth, str(
                    args.k))
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer': optimizer.state_dict(),
                },
                is_best,
                filename=fname)
    print("num epoch %d, time %.2f" % (args.epochs, epoch_end - begin))
Example #5
def main():
    args = get_args()

    if not os.path.exists(args.fname):
        os.makedirs(args.fname)

    logger = logging.getLogger(__name__)
    logging.basicConfig(
        format='[%(asctime)s] - %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S',
        level=logging.DEBUG,
        handlers=[
            logging.FileHandler(
                os.path.join(args.fname,
                             'eval.log' if args.eval else 'output.log')),
            logging.StreamHandler()
        ])

    logger.info(args)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    transforms = [Crop(32, 32), FlipLR()]
    # transforms = [Crop(32, 32)]
    if args.cutout:
        transforms.append(Cutout(args.cutout_len, args.cutout_len))
    if args.val:
        try:
            dataset = torch.load("cifar10_validation_split.pth")
        except:
            print(
                "Couldn't find a dataset with a validation split, did you run "
                "generate_validation.py?")
            return
        val_set = list(
            zip(transpose(dataset['val']['data'] / 255.),
                dataset['val']['labels']))
        val_batches = Batches(val_set,
                              args.batch_size,
                              shuffle=False,
                              num_workers=2)
    else:
        dataset = cifar10(args.data_dir)
    train_set = list(
        zip(transpose(pad(dataset['train']['data'], 4) / 255.),
            dataset['train']['labels']))
    train_set_x = Transform(train_set, transforms)
    train_batches = Batches(train_set_x,
                            args.batch_size,
                            shuffle=True,
                            set_random_choices=True,
                            num_workers=2)

    test_set = list(
        zip(transpose(dataset['test']['data'] / 255.),
            dataset['test']['labels']))
    test_batches = Batches(test_set,
                           args.batch_size,
                           shuffle=False,
                           num_workers=2)

    trn_epsilon = (args.trn_epsilon / 255.)
    trn_pgd_alpha = (args.trn_pgd_alpha / 255.)
    tst_epsilon = (args.tst_epsilon / 255.)
    tst_pgd_alpha = (args.tst_pgd_alpha / 255.)

    if args.model == 'PreActResNet18':
        model = PreActResNet18()
    elif args.model == 'WideResNet':
        model = WideResNet(34,
                           10,
                           widen_factor=args.width_factor,
                           dropRate=0.0)
    elif args.model == 'DenseNet121':
        model = DenseNet121()
    elif args.model == 'ResNet18':
        model = ResNet18()
    else:
        raise ValueError("Unknown model")

    ### temp testing ###
    model = model.cuda()
    # model = nn.DataParallel(model).cuda()
    model.train()

    ##################################
    # load pretrained model if needed
    if args.trn_adv_models != 'None':
        if args.trn_adv_arch == 'PreActResNet18':
            trn_adv_model = PreActResNet18()
        elif args.trn_adv_arch == 'WideResNet':
            trn_adv_model = WideResNet(34,
                                       10,
                                       widen_factor=args.width_factor,
                                       dropRate=0.0)
        elif args.trn_adv_arch == 'DenseNet121':
            trn_adv_model = DenseNet121()
        elif args.trn_adv_arch == 'ResNet18':
            trn_adv_model = ResNet18()
        trn_adv_model = nn.DataParallel(trn_adv_model).cuda()
        trn_adv_model.load_state_dict(
            torch.load(
                os.path.join('./adv_models', args.trn_adv_models,
                             'model_best.pth'))['state_dict'])
        logger.info(f'loaded adv_model: {args.trn_adv_models}')
    else:
        trn_adv_model = None

    if args.tst_adv_models != 'None':
        if args.tst_adv_arch == 'PreActResNet18':
            tst_adv_model = PreActResNet18()
        elif args.tst_adv_arch == 'WideResNet':
            tst_adv_model = WideResNet(34,
                                       10,
                                       widen_factor=args.width_factor,
                                       dropRate=0.0)
        elif args.tst_adv_arch == 'DenseNet121':
            tst_adv_model = DenseNet121()
        elif args.tst_adv_arch == 'ResNet18':
            tst_adv_model = ResNet18()
        ### temp testing ###
        tst_adv_model = tst_adv_model.cuda()
        tst_adv_model.load_state_dict(
            torch.load(
                os.path.join('./adv_models', args.tst_adv_models,
                             'model_best.pth')))
        # tst_adv_model = nn.DataParallel(tst_adv_model).cuda()
        # tst_adv_model.load_state_dict(torch.load(os.path.join('./adv_models',args.tst_adv_models, 'model_best.pth'))['state_dict'])
        logger.info(f'loaded adv_model: {args.tst_adv_models}')
    else:
        tst_adv_model = None
    ##################################

    if args.l2:
        decay, no_decay = [], []
        for name, param in model.named_parameters():
            if 'bn' not in name and 'bias' not in name:
                decay.append(param)
            else:
                no_decay.append(param)
        params = [{
            'params': decay,
            'weight_decay': args.l2
        }, {
            'params': no_decay,
            'weight_decay': 0
        }]
    else:
        params = model.parameters()

    opt = torch.optim.SGD(params,
                          lr=args.lr_max,
                          momentum=0.9,
                          weight_decay=5e-4)

    criterion = nn.CrossEntropyLoss()

    if args.trn_attack == 'free':
        delta = torch.zeros(args.batch_size, 3, 32, 32).cuda()
        delta.requires_grad = True
    elif args.trn_attack == 'fgsm' and args.trn_fgsm_init == 'previous':
        delta = torch.zeros(args.batch_size, 3, 32, 32).cuda()
        delta.requires_grad = True

    if args.trn_attack == 'free':
        epochs = int(math.ceil(args.epochs / args.trn_attack_iters))
    else:
        epochs = args.epochs

    if args.lr_schedule == 'superconverge':
        lr_schedule = lambda t: np.interp([t], [
            0, args.epochs * 2 // 5, args.epochs
        ], [0, args.lr_max, 0])[0]
    elif args.lr_schedule == 'piecewise':

        def lr_schedule(t):
            if t / args.epochs < 0.5:
                return args.lr_max
            elif t / args.epochs < 0.75:
                return args.lr_max / 10.
            else:
                return args.lr_max / 100.
    elif args.lr_schedule == 'linear':
        lr_schedule = lambda t: np.interp([t], [
            0, args.epochs // 3, args.epochs * 2 // 3, args.epochs
        ], [args.lr_max, args.lr_max, args.lr_max / 10, args.lr_max / 100])[0]
    elif args.lr_schedule == 'onedrop':

        def lr_schedule(t):
            if t < args.lr_drop_epoch:
                return args.lr_max
            else:
                return args.lr_one_drop
    elif args.lr_schedule == 'multipledecay':

        def lr_schedule(t):
            return args.lr_max - (t //
                                  (args.epochs // 10)) * (args.lr_max / 10)
    elif args.lr_schedule == 'cosine':

        def lr_schedule(t):
            return args.lr_max * 0.5 * (1 + np.cos(t / args.epochs * np.pi))

    best_test_robust_acc = 0
    best_val_robust_acc = 0
    if args.resume:
        ### temp testing ###
        model.load_state_dict(
            torch.load(os.path.join(args.fname, 'model_best.pth')))
        start_epoch = args.resume
        # model.load_state_dict(torch.load(os.path.join(args.fname, f'model_{start_epoch-1}.pth')))
        # opt.load_state_dict(torch.load(os.path.join(args.fname, f'opt_{start_epoch-1}.pth')))
        # logger.info(f'Resuming at epoch {start_epoch}')

        # best_test_robust_acc = torch.load(os.path.join(args.fname, f'model_best.pth'))['test_robust_acc']
        if args.val:
            best_val_robust_acc = torch.load(
                os.path.join(args.fname, f'model_val.pth'))['val_robust_acc']
    else:
        start_epoch = 0

    if args.eval:
        if not args.resume:
            logger.info(
                "No model loaded to evaluate, specify with --resume FNAME")
            return
        logger.info("[Evaluation mode]")

    logger.info(
        'Epoch \t Train Time \t Test Time \t LR \t \t Train Loss \t Train Acc \t Train Robust Loss \t Train Robust Acc \t Test Loss \t Test Acc \t Test Robust Loss \t Test Robust Acc'
    )
    for epoch in range(start_epoch, epochs):
        model.train()
        start_time = time.time()
        train_loss = 0
        train_acc = 0
        train_robust_loss = 0
        train_robust_acc = 0
        train_n = 0
        for i, batch in enumerate(train_batches):
            if args.eval:
                break
            X, y = batch['input'], batch['target']
            if args.mixup:
                X, y_a, y_b, lam = mixup_data(X, y, args.mixup_alpha)
                X, y_a, y_b = map(Variable, (X, y_a, y_b))
            lr = lr_schedule(epoch + (i + 1) / len(train_batches))
            opt.param_groups[0].update(lr=lr)

            if args.trn_attack == 'pgd':
                # Random initialization
                if args.mixup:
                    delta = attack_pgd(model,
                                       X,
                                       y,
                                       trn_epsilon,
                                       trn_pgd_alpha,
                                       args.trn_attack_iters,
                                       args.trn_restarts,
                                       args.trn_norm,
                                       mixup=True,
                                       y_a=y_a,
                                       y_b=y_b,
                                       lam=lam,
                                       adv_models=trn_adv_model)
                else:
                    delta = attack_pgd(model,
                                       X,
                                       y,
                                       trn_epsilon,
                                       trn_pgd_alpha,
                                       args.trn_attack_iters,
                                       args.trn_restarts,
                                       args.trn_norm,
                                       adv_models=trn_adv_model)
                delta = delta.detach()
            elif args.trn_attack == 'fgsm':
                delta = attack_pgd(model,
                                   X,
                                   y,
                                   trn_epsilon,
                                   args.trn_fgsm_alpha * trn_epsilon,
                                   1,
                                   1,
                                   args.trn_norm,
                                   adv_models=trn_adv_model,
                                   rand_init=args.trn_fgsm_init)
                delta = delta.detach()
            # Standard training
            elif args.trn_attack == 'none':
                delta = torch.zeros_like(X)
            # The Momentum Iterative Attack
            elif args.trn_attack == 'tmim':
                if trn_adv_model is None:
                    adversary = MomentumIterativeAttack(
                        model,
                        nb_iter=args.trn_attack_iters,
                        eps=trn_epsilon,
                        loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                        eps_iter=trn_pgd_alpha,
                        clip_min=0,
                        clip_max=1,
                        targeted=False)
                else:
                    trn_adv_model = nn.Sequential(
                        NormalizeByChannelMeanStd(CIFAR10_MEAN, CIFAR10_STD),
                        trn_adv_model)

                    adversary = MomentumIterativeAttack(
                        trn_adv_model,
                        nb_iter=args.trn_attack_iters,
                        eps=trn_epsilon,
                        loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                        eps_iter=trn_pgd_alpha,
                        clip_min=0,
                        clip_max=1,
                        targeted=False)
                data_adv = adversary.perturb(X, y)
                delta = data_adv - X
                delta = delta.detach()

            robust_output = model(
                normalize(
                    torch.clamp(X + delta[:X.size(0)],
                                min=lower_limit,
                                max=upper_limit)))
            if args.mixup:
                robust_loss = mixup_criterion(criterion, robust_output, y_a,
                                              y_b, lam)
            else:
                robust_loss = criterion(robust_output, y)

            if args.l1:
                for name, param in model.named_parameters():
                    if 'bn' not in name and 'bias' not in name:
                        robust_loss += args.l1 * param.abs().sum()

            opt.zero_grad()
            robust_loss.backward()
            opt.step()

            output = model(normalize(X))
            if args.mixup:
                loss = mixup_criterion(criterion, output, y_a, y_b, lam)
            else:
                loss = criterion(output, y)

            train_robust_loss += robust_loss.item() * y.size(0)
            train_robust_acc += (robust_output.max(1)[1] == y).sum().item()
            train_loss += loss.item() * y.size(0)
            train_acc += (output.max(1)[1] == y).sum().item()
            train_n += y.size(0)

        train_time = time.time()

        model.eval()
        test_loss = 0
        test_acc = 0
        test_robust_loss = 0
        test_robust_acc = 0
        test_n = 0
        for i, batch in enumerate(test_batches):
            X, y = batch['input'], batch['target']

            # Random initialization
            if args.tst_attack == 'none':
                delta = torch.zeros_like(X)
            elif args.tst_attack == 'pgd':
                delta = attack_pgd(model,
                                   X,
                                   y,
                                   tst_epsilon,
                                   tst_pgd_alpha,
                                   args.tst_attack_iters,
                                   args.tst_restarts,
                                   args.tst_norm,
                                   adv_models=tst_adv_model,
                                   rand_init=args.tst_fgsm_init)
            elif args.tst_attack == 'fgsm':
                delta = attack_pgd(model,
                                   X,
                                   y,
                                   tst_epsilon,
                                   tst_epsilon,
                                   1,
                                   1,
                                   args.tst_norm,
                                   rand_init=args.tst_fgsm_init,
                                   adv_models=tst_adv_model)
            # The Momentum Iterative Attack
            elif args.tst_attack == 'tmim':
                if tst_adv_model is None:
                    adversary = MomentumIterativeAttack(
                        model,
                        nb_iter=args.tst_attack_iters,
                        eps=tst_epsilon,
                        loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                        eps_iter=tst_pgd_alpha,
                        clip_min=0,
                        clip_max=1,
                        targeted=False)
                else:
                    tmp_model = nn.Sequential(
                        NormalizeByChannelMeanStd(cifar10_mean, cifar10_std),
                        tst_adv_model).to(device)

                    adversary = MomentumIterativeAttack(
                        tmp_model,
                        nb_iter=args.tst_attack_iters,
                        eps=tst_epsilon,
                        loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                        eps_iter=tst_pgd_alpha,
                        clip_min=0,
                        clip_max=1,
                        targeted=False)
                data_adv = adversary.perturb(X, y)
                delta = data_adv - X
            # elif args.tst_attack == 'pgd':
            #     if tst_adv_model is None:
            #         tmp_model = nn.Sequential(NormalizeByChannelMeanStd(cifar10_mean, cifar10_std), model).to(device)

            #         adversary = PGDAttack(tmp_model, nb_iter=args.tst_attack_iters,
            #                         eps = tst_epsilon,
            #                         loss_fn=nn.CrossEntropyLoss(reduction="sum"),
            #                         eps_iter=tst_pgd_alpha, clip_min = 0, clip_max = 1, targeted=False)
            #     else:
            #         tmp_model = nn.Sequential(NormalizeByChannelMeanStd(cifar10_mean, cifar10_std), tst_adv_model).to(device)

            #         adversary = PGDAttack(tmp_model, nb_iter=args.tst_attack_iters,
            #                         eps = tst_epsilon,
            #                         loss_fn=nn.CrossEntropyLoss(reduction="sum"),
            #                         eps_iter=tst_pgd_alpha, clip_min = 0, clip_max = 1, targeted=False)
            #     data_adv = adversary.perturb(X, y)
            #     delta = data_adv - X

            delta = delta.detach()

            robust_output = model(
                normalize(
                    torch.clamp(X + delta[:X.size(0)],
                                min=lower_limit,
                                max=upper_limit)))
            robust_loss = criterion(robust_output, y)

            output = model(normalize(X))
            loss = criterion(output, y)

            test_robust_loss += robust_loss.item() * y.size(0)
            test_robust_acc += (robust_output.max(1)[1] == y).sum().item()
            test_loss += loss.item() * y.size(0)
            test_acc += (output.max(1)[1] == y).sum().item()
            test_n += y.size(0)

        test_time = time.time()

        if args.val:
            val_loss = 0
            val_acc = 0
            val_robust_loss = 0
            val_robust_acc = 0
            val_n = 0
            for i, batch in enumerate(val_batches):
                X, y = batch['input'], batch['target']

                # Random initialization
                if args.tst_attack == 'none':
                    delta = torch.zeros_like(X)
                elif args.tst_attack == 'pgd':
                    delta = attack_pgd(model,
                                       X,
                                       y,
                                       tst_epsilon,
                                       tst_pgd_alpha,
                                       args.tst_attack_iters,
                                       args.tst_restarts,
                                       args.tst_norm,
                                       early_stop=args.eval)
                elif args.tst_attack == 'fgsm':
                    delta = attack_pgd(model,
                                       X,
                                       y,
                                       tst_epsilon,
                                       tst_epsilon,
                                       1,
                                       1,
                                       args.tst_norm,
                                       early_stop=args.eval,
                                       rand_init=args.tst_fgsm_init)

                delta = delta.detach()

                robust_output = model(
                    normalize(
                        torch.clamp(X + delta[:X.size(0)],
                                    min=lower_limit,
                                    max=upper_limit)))
                robust_loss = criterion(robust_output, y)

                output = model(normalize(X))
                loss = criterion(output, y)

                val_robust_loss += robust_loss.item() * y.size(0)
                val_robust_acc += (robust_output.max(1)[1] == y).sum().item()
                val_loss += loss.item() * y.size(0)
                val_acc += (output.max(1)[1] == y).sum().item()
                val_n += y.size(0)

        if not args.eval:
            logger.info(
                '%d \t %.1f \t \t %.1f \t \t %.4f \t %.4f \t %.4f \t %.4f \t \t %.4f \t \t %.4f \t %.4f \t %.4f \t \t %.4f',
                epoch, train_time - start_time, test_time - train_time, lr,
                train_loss / train_n, train_acc / train_n,
                train_robust_loss / train_n, train_robust_acc / train_n,
                test_loss / test_n, test_acc / test_n,
                test_robust_loss / test_n, test_robust_acc / test_n)

            if args.val:
                logger.info('validation %.4f \t %.4f \t %.4f \t %.4f',
                            val_loss / val_n, val_acc / val_n,
                            val_robust_loss / val_n, val_robust_acc / val_n)

                if val_robust_acc / val_n > best_val_robust_acc:
                    torch.save(
                        {
                            'state_dict': model.state_dict(),
                            'test_robust_acc': test_robust_acc / test_n,
                            'test_robust_loss': test_robust_loss / test_n,
                            'test_loss': test_loss / test_n,
                            'test_acc': test_acc / test_n,
                            'val_robust_acc': val_robust_acc / val_n,
                            'val_robust_loss': val_robust_loss / val_n,
                            'val_loss': val_loss / val_n,
                            'val_acc': val_acc / val_n,
                        }, os.path.join(args.fname, f'model_val.pth'))
                    best_val_robust_acc = val_robust_acc / val_n

            # save checkpoint
            if (epoch + 1) % args.chkpt_iters == 0 or epoch + 1 == epochs:
                torch.save(model.state_dict(),
                           os.path.join(args.fname, f'model_{epoch}.pth'))
                torch.save(opt.state_dict(),
                           os.path.join(args.fname, f'opt_{epoch}.pth'))

            # save best
            if test_robust_acc / test_n > best_test_robust_acc:
                torch.save(
                    {
                        'state_dict': model.state_dict(),
                        'test_robust_acc': test_robust_acc / test_n,
                        'test_robust_loss': test_robust_loss / test_n,
                        'test_loss': test_loss / test_n,
                        'test_acc': test_acc / test_n,
                    }, os.path.join(args.fname, f'model_best.pth'))
                best_test_robust_acc = test_robust_acc / test_n
        else:
            logger.info(
                '%d \t %.1f \t \t %.1f \t \t %.4f \t %.4f \t %.4f \t %.4f \t \t %.4f \t \t %.4f \t %.4f \t %.4f \t \t %.4f',
                epoch, train_time - start_time, test_time - train_time, -1, -1,
                -1, -1, -1, test_loss / test_n, test_acc / test_n,
                test_robust_loss / test_n, test_robust_acc / test_n)
            return
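
Note: when --mixup is set, Examples #3 and #5 rely on mixup_data and mixup_criterion helpers. A common formulation is sketched below; it is the standard mixup recipe, not necessarily the exact code used by these scripts:

import numpy as np
import torch

def mixup_data(x, y, alpha=1.0):
    # Blend each image with a randomly permuted partner; lambda ~ Beta(alpha, alpha).
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    index = torch.randperm(x.size(0), device=x.device)
    mixed_x = lam * x + (1 - lam) * x[index]
    return mixed_x, y, y[index], lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    # The loss is the same convex combination applied to both label sets.
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)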
Example #6
def main():
    args = get_args()
    if args.awp_gamma <= 0.0:
        args.awp_warmup = np.inf
    fname = None
    if args.sample == 100:
        fname = args.output_dir + "/default/" + args.attack + "/"
    else:
        fname = args.output_dir + "/" + str(
            args.sample) + "sample/" + args.attack + "/"

    if args.attack == "combine":
        fname += args.list + "/"

    if not os.path.exists(fname):
        os.makedirs(fname)

    logger = logging.getLogger(__name__)
    logging.basicConfig(
        format='[%(asctime)s] - %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S',
        level=logging.DEBUG,
        handlers=[
            logging.FileHandler(
                os.path.join(fname,
                             'eval.log' if args.eval else 'output.log')),
            logging.StreamHandler()
        ])

    logger.info(args)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    shuffle = False

    # setup data loader
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])
    train_set = torchvision.datasets.CIFAR10(root='../data',
                                             train=True,
                                             download=True,
                                             transform=transform_train)
    test_set = torchvision.datasets.CIFAR10(root='../data',
                                            train=False,
                                            download=True,
                                            transform=transform_test)

    if args.attack == "all":
        train_data = np.array(train_set.data) / 255.
        train_data = transpose(train_data).astype(np.float32)

        train_labels = np.array(train_set.targets)

        oversampled_train_data = np.tile(train_data, (11, 1, 1, 1))
        oversampled_train_labels = np.tile(train_labels, (11))

        train_set = list(
            zip(torch.from_numpy(oversampled_train_data),
                torch.from_numpy(oversampled_train_labels)))

    elif args.attack == "combine":
        train_data = np.array(train_set.data) / 255.
        train_data = transpose(train_data).astype(np.float32)

        train_labels = np.array(train_set.targets)

        oversampled_train_data = train_data.copy()
        oversampled_train_labels = train_labels.copy()

        logger.info("Attacks")
        attacks = args.list.split("_")
        logger.info(attacks)

        oversampled_train_data = np.tile(train_data, (len(attacks), 1, 1, 1))
        oversampled_train_labels = np.tile(train_labels, (len(attacks)))

        train_set = list(
            zip(torch.from_numpy(oversampled_train_data),
                torch.from_numpy(oversampled_train_labels)))
    else:
        train_data = np.array(train_set.data) / 255.
        train_data = transpose(train_data).astype(np.float32)

        train_labels = np.array(train_set.targets)

        train_set = list(
            zip(torch.from_numpy(train_data), torch.from_numpy(train_labels)))

    test_data = np.array(test_set.data) / 255.
    test_data = transpose(test_data).astype(np.float32)
    test_labels = np.array(test_set.targets)

    test_set = list(
        zip(torch.from_numpy(test_data), torch.from_numpy(test_labels)))

    if args.sample != 100:
        n = len(train_set)
        n_sample = int(n * args.sample / 100)

        np.random.shuffle(train_set)
        train_set = train_set[:n_sample]

    train_batches = Batches(train_set,
                            args.batch_size,
                            shuffle=shuffle,
                            set_random_choices=False,
                            num_workers=2)

    test_batches = Batches(test_set,
                           args.batch_size,
                           shuffle=False,
                           num_workers=0)

    train_adv_images = None
    train_adv_labels = None
    test_adv_images = None
    test_adv_labels = None

    adv_dir = "adv_examples/{}/".format(args.attack)
    train_path = adv_dir + "train.pth"
    test_path = adv_dir + "test.pth"

    if args.attack in TOOLBOX_ADV_ATTACK_LIST:
        adv_train_data = torch.load(train_path)
        train_adv_images = adv_train_data["adv"]
        train_adv_labels = adv_train_data["label"]
        adv_test_data = torch.load(test_path)
        test_adv_images = adv_test_data["adv"]
        test_adv_labels = adv_test_data["label"]
    elif args.attack in ["ffgsm", "mifgsm", "tpgd"]:
        adv_data = {}
        adv_data["adv"], adv_data["label"] = torch.load(train_path)
        train_adv_images = adv_data["adv"].numpy()
        train_adv_labels = adv_data["label"].numpy()
        adv_data = {}
        adv_data["adv"], adv_data["label"] = torch.load(test_path)
        test_adv_images = adv_data["adv"].numpy()
        test_adv_labels = adv_data["label"].numpy()
    elif args.attack == "combine":
        print("Attacks")
        attacks = args.list.split("_")
        print(attacks)

        for i in range(len(attacks)):
            _adv_dir = "adv_examples/{}/".format(attacks[i])
            train_path = _adv_dir + "train.pth"
            test_path = _adv_dir + "test.pth"

            adv_train_data = torch.load(train_path)
            adv_test_data = torch.load(test_path)

            if i == 0:
                train_adv_images = adv_train_data["adv"]
                train_adv_labels = adv_train_data["label"]
                test_adv_images = adv_test_data["adv"]
                test_adv_labels = adv_test_data["label"]
            else:
                train_adv_images = np.concatenate(
                    (train_adv_images, adv_train_data["adv"]))
                train_adv_labels = np.concatenate(
                    (train_adv_labels, adv_train_data["label"]))
                test_adv_images = np.concatenate(
                    (test_adv_images, adv_test_data["adv"]))
                test_adv_labels = np.concatenate(
                    (test_adv_labels, adv_test_data["label"]))

    elif args.attack == "all":
        print("Loading attacks")
        ATTACK_LIST = [
            "autoattack", "autopgd", "bim", "cw", "deepfool", "fgsm",
            "newtonfool", "pgd", "pixelattack", "spatialtransformation",
            "squareattack"
        ]
        for i in range(len(ATTACK_LIST)):
            print("Attack: ", ATTACK_LIST[i])
            _adv_dir = "adv_examples/{}/".format(ATTACK_LIST[i])
            train_path = _adv_dir + "train.pth"
            test_path = _adv_dir + "test.pth"

            adv_train_data = torch.load(train_path)
            adv_test_data = torch.load(test_path)

            if i == 0:
                train_adv_images = adv_train_data["adv"]
                train_adv_labels = adv_train_data["label"]
                test_adv_images = adv_test_data["adv"]
                test_adv_labels = adv_test_data["label"]
            else:
                train_adv_images = np.concatenate(
                    (train_adv_images, adv_train_data["adv"]))
                train_adv_labels = np.concatenate(
                    (train_adv_labels, adv_train_data["label"]))
                test_adv_images = np.concatenate(
                    (test_adv_images, adv_test_data["adv"]))
                test_adv_labels = np.concatenate(
                    (test_adv_labels, adv_test_data["label"]))
    else:
        raise ValueError("Unknown adversarial data")

    print("")
    print("Train Adv Attack Data: ", args.attack)
    print("Dataset shape: ", train_adv_images.shape)
    print("Dataset type: ", type(train_adv_images))
    print("Label shape: ", len(train_adv_labels))
    print("")

    train_adv_set = list(zip(train_adv_images, train_adv_labels))

    if args.sample != 100:
        n = len(train_adv_set)
        n_sample = int(n * args.sample / 100)

        np.random.shuffle(train_adv_set)
        train_adv_set = train_adv_set[:n_sample]

    train_adv_batches = Batches(train_adv_set,
                                args.batch_size,
                                shuffle=shuffle,
                                set_random_choices=False,
                                num_workers=0)

    test_adv_set = list(zip(test_adv_images, test_adv_labels))

    test_adv_batches = Batches(test_adv_set,
                               args.batch_size,
                               shuffle=False,
                               num_workers=0)

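    # Convert the attack budget and PGD step size from pixel units (0-255)
    # to the [0, 1] input scale used by the data pipeline.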
    epsilon = (args.epsilon / 255.)
    pgd_alpha = (args.pgd_alpha / 255.)

    if args.model == "ResNet18":
        model = resnet18(pretrained=True)
        proxy = resnet18(pretrained=True)
    elif args.model == 'PreActResNet18':
        model = PreActResNet18()
        proxy = PreActResNet18()
    elif args.model == 'WideResNet':
        model = WideResNet(34,
                           10,
                           widen_factor=args.width_factor,
                           dropRate=0.0)
        proxy = WideResNet(34,
                           10,
                           widen_factor=args.width_factor,
                           dropRate=0.0)
    else:
        raise ValueError("Unknown model")

    model = nn.DataParallel(model).cuda()
    proxy = nn.DataParallel(proxy).cuda()

    if args.l2:
        decay, no_decay = [], []
        for name, param in model.named_parameters():
            if 'bn' not in name and 'bias' not in name:
                decay.append(param)
            else:
                no_decay.append(param)
        params = [{
            'params': decay,
            'weight_decay': args.l2
        }, {
            'params': no_decay,
            'weight_decay': 0
        }]
    else:
        params = model.parameters()

    opt = torch.optim.SGD(params,
                          lr=args.lr_max,
                          momentum=0.9,
                          weight_decay=5e-4)
    proxy_opt = torch.optim.SGD(proxy.parameters(), lr=0.01)
    awp_adversary = AdvWeightPerturb(model=model,
                                     proxy=proxy,
                                     proxy_optim=proxy_opt,
                                     gamma=args.awp_gamma)

    criterion = nn.CrossEntropyLoss()

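    # For free adversarial training, or FGSM initialized from the previous
    # batch's perturbation, keep a persistent delta buffer across batches.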
    if args.attack == 'free':
        delta = torch.zeros(args.batch_size, 3, 32, 32).cuda()
        delta.requires_grad = True
    elif args.attack == 'fgsm' and args.fgsm_init == 'previous':
        delta = torch.zeros(args.batch_size, 3, 32, 32).cuda()
        delta.requires_grad = True

    if args.attack == 'free':
        epochs = int(math.ceil(args.epochs / args.attack_iters))
    else:
        epochs = args.epochs

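    # Select a learning-rate schedule; each branch defines lr_schedule(t),
    # where t is the (fractional) epoch, returning the learning rate to use.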
    if args.lr_schedule == 'superconverge':
        lr_schedule = lambda t: np.interp([t], [
            0, args.epochs * 2 // 5, args.epochs
        ], [0, args.lr_max, 0])[0]
    elif args.lr_schedule == 'piecewise':

        def lr_schedule(t):
            if t / args.epochs < 0.5:
                return args.lr_max
            elif t / args.epochs < 0.75:
                return args.lr_max / 10.
            else:
                return args.lr_max / 100.
    elif args.lr_schedule == 'linear':
        lr_schedule = lambda t: np.interp([t], [
            0, args.epochs // 3, args.epochs * 2 // 3, args.epochs
        ], [args.lr_max, args.lr_max, args.lr_max / 10, args.lr_max / 100])[0]
    elif args.lr_schedule == 'onedrop':

        def lr_schedule(t):
            if t < args.lr_drop_epoch:
                return args.lr_max
            else:
                return args.lr_one_drop
    elif args.lr_schedule == 'multipledecay':

        def lr_schedule(t):
            return args.lr_max - (t //
                                  (args.epochs // 10)) * (args.lr_max / 10)
    elif args.lr_schedule == 'cosine':

        def lr_schedule(t):
            return args.lr_max * 0.5 * (1 + np.cos(t / args.epochs * np.pi))
    elif args.lr_schedule == 'cyclic':
        lr_schedule = lambda t: np.interp(
            [t], [0, 0.4 * args.epochs, args.epochs], [0, args.lr_max, 0])[0]

    best_test_robust_acc = 0
    best_val_robust_acc = 0
    if args.resume:
        start_epoch = args.resume
        model.load_state_dict(
            torch.load(os.path.join(fname, f'model_{start_epoch-1}.pth')))
        opt.load_state_dict(
            torch.load(os.path.join(fname, f'opt_{start_epoch-1}.pth')))
        logger.info(f'Resuming at epoch {start_epoch}')

        if os.path.exists(os.path.join(fname, f'model_best.pth')):
            best_test_robust_acc = torch.load(
                os.path.join(fname, f'model_best.pth'))['test_robust_acc']
        if args.val:
            best_val_robust_acc = torch.load(
                os.path.join(args.output_dir,
                             f'model_val.pth'))['val_robust_acc']
    else:
        start_epoch = 0

    if args.eval:
        if not args.resume:
            logger.info(
                "No model loaded to evaluate, specify with --resume FNAME")
            return
        logger.info("[Evaluation mode]")

    model.cuda()
    model.eval()

    # Evaluate on original test data
    test_acc = 0
    test_n = 0

    for i, batch in enumerate(test_batches):
        X, y = batch['input'], batch['target']

        clean_input = normalize(X)
        output = model(clean_input)

        test_acc += (output.max(1)[1] == y).sum().item()
        test_n += y.size(0)

    logger.info('Initial Accuracy on Original Test Data: %.4f (Test Acc)',
                test_acc / test_n)

    test_adv_acc = 0
    test_adv_n = 0

    for i, batch in enumerate(test_adv_batches):
        adv_input = normalize(batch['input'])
        y = batch['target']

        robust_output = model(adv_input)
        test_adv_acc += (robust_output.max(1)[1] == y).sum().item()
        test_adv_n += y.size(0)

    logger.info(
        'Initial Accuracy on Adversarial Test Data: %.4f (Test Robust Acc)',
        test_adv_acc / test_adv_n)

    model.train()

    logger.info(
        'Epoch \t Train Acc \t Train Robust Acc \t Test Acc \t Test Robust Acc'
    )
    for epoch in range(start_epoch, epochs):
        start_time = time.time()
        train_loss = 0
        train_acc = 0
        train_robust_loss = 0
        train_robust_acc = 0
        train_n = 0
        for i, (batch,
                adv_batch) in enumerate(zip(train_batches, train_adv_batches)):
            if args.eval:
                break
            X, y = batch['input'], batch['target']
            if args.mixup:
                X, y_a, y_b, lam = mixup_data(X, y, args.mixup_alpha)
                X, y_a, y_b = map(Variable, (X, y_a, y_b))
            lr = lr_schedule(epoch + (i + 1) / len(train_batches))
            opt.param_groups[0].update(lr=lr)

            X_adv = normalize(adv_batch["input"])
            y_adv = adv_batch["target"]

            model.train()
            # calculate adversarial weight perturbation and perturb it
            if epoch >= args.awp_warmup:
                # not compatible with mixup currently.
                assert (not args.mixup)
                awp = awp_adversary.calc_awp(inputs_adv=X_adv, targets=y_adv)
                awp_adversary.perturb(awp)

            robust_output = model(X_adv)
            if args.mixup:
                robust_loss = mixup_criterion(criterion, robust_output, y_a,
                                              y_b, lam)
            else:
                robust_loss = criterion(robust_output, y_adv)

            if args.l1:
                for name, param in model.named_parameters():
                    if 'bn' not in name and 'bias' not in name:
                        robust_loss += args.l1 * param.abs().sum()

            opt.zero_grad()
            robust_loss.backward()
            opt.step()

            if epoch >= args.awp_warmup:
                awp_adversary.restore(awp)

            output = model(normalize(X))
            if args.mixup:
                loss = mixup_criterion(criterion, output, y_a, y_b, lam)
            else:
                loss = criterion(output, y)

            train_robust_loss += robust_loss.item() * y_adv.size(0)
            train_robust_acc += (robust_output.max(1)[1] == y_adv).sum().item()
            train_loss += loss.item() * y.size(0)
            train_acc += (output.max(1)[1] == y).sum().item()
            train_n += y.size(0)

        train_time = time.time()

        model.eval()
        test_loss = 0
        test_acc = 0
        test_n = 0

        test_robust_loss = 0
        test_robust_acc = 0
        test_robust_n = 0

        for i, batch in enumerate(test_batches):
            X, y = normalize(batch['input']), batch['target']

            output = model(X)
            loss = criterion(output, y)

            test_loss += loss.item() * y.size(0)
            test_acc += (output.max(1)[1] == y).sum().item()
            test_n += y.size(0)

        for i, adv_batch in enumerate(test_adv_batches):
            X_adv, y_adv = normalize(adv_batch["input"]), adv_batch["target"]

            robust_output = model(X_adv)
            robust_loss = criterion(robust_output, y_adv)

            test_robust_loss += robust_loss.item() * y_adv.size(0)
            test_robust_acc += (robust_output.max(1)[1] == y_adv).sum().item()
            test_robust_n += y_adv.size(0)

        test_time = time.time()

        logger.info('%d \t %.3f \t\t %.3f \t\t\t %.3f \t\t %.3f', epoch,
                    train_acc / train_n, train_robust_acc / train_n,
                    test_acc / test_n, test_robust_acc / test_robust_n)

        # save checkpoint
        if (epoch + 1) % args.chkpt_iters == 0 or epoch + 1 == epochs:
            torch.save(model.state_dict(),
                       os.path.join(fname, f'model_{epoch}.pth'))
            torch.save(opt.state_dict(), os.path.join(fname,
                                                      f'opt_{epoch}.pth'))

        # save best
        if test_robust_acc / test_robust_n > best_test_robust_acc:
            torch.save(
                {
                    'state_dict': model.state_dict(),
                    'test_robust_acc': test_robust_acc / test_robust_n,
                    'test_robust_loss': test_robust_loss / test_robust_n,
                    'test_loss': test_loss / test_n,
                    'test_acc': test_acc / test_n,
                }, os.path.join(fname, f'model_best.pth'))
            best_test_robust_acc = test_robust_acc / test_robust_n
def main():
    args = get_args()

    if not os.path.exists(args.fname):
        os.makedirs(args.fname)

    logger = logging.getLogger(__name__)
    logging.basicConfig(
        format='[%(asctime)s] - %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S',
        level=logging.DEBUG,
        handlers=[
            logging.FileHandler(
                os.path.join(args.fname,
                             'eval.log' if args.eval else 'output.log')),
            logging.StreamHandler()
        ])

    logger.info(args)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    transforms = [Crop(32, 32), FlipLR()]
    dataset = cifar10(args.data_dir)
    train_set = list(
        zip(transpose(pad(dataset['train']['data'], 4) / 255.),
            dataset['train']['labels']))
    train_set_x = Transform(train_set, transforms)
    train_batches = Batches(train_set_x,
                            args.batch_size,
                            shuffle=True,
                            set_random_choices=True,
                            num_workers=2)

    test_set = list(
        zip(transpose(dataset['test']['data'] / 255.),
            dataset['test']['labels']))
    test_batches = Batches(test_set,
                           args.batch_size,
                           shuffle=False,
                           num_workers=2)

    epsilon = (args.epsilon / 255.)
    pgd_alpha = (args.pgd_alpha / 255.)

    if args.model == 'PreActResNet18':
        model = PreActResNet18()
    elif args.model == 'WideResNet':
        model = WideResNet(34,
                           10,
                           widen_factor=args.width_factor,
                           dropRate=0.0)
    else:
        raise ValueError("Unknown model")

    model = nn.DataParallel(model).cuda()
    model.train()

    params = model.parameters()

    opt = torch.optim.SGD(params,
                          lr=args.lr_max,
                          momentum=0.9,
                          weight_decay=5e-4)
    criterion = nn.CrossEntropyLoss()
    epochs = args.epochs

    if args.lr_schedule == 'cyclic':
        lr_schedule = lambda t: np.interp(
            [t], [0, args.epochs // 2, args.epochs], [0, args.lr_max, 0])[0]
    elif args.lr_schedule == 'onedrop':

        def lr_schedule(t):
            if t < args.lr_drop:
                return args.lr_max
            else:
                return args.lr_max / 10.

    best_test_robust_acc = 0
    best_val_robust_acc = 0
    if args.resume:
        start_epoch = args.resume
        model.load_state_dict(
            torch.load(os.path.join(args.fname, f'model_{start_epoch-1}.pth')))
        opt.load_state_dict(
            torch.load(os.path.join(args.fname, f'opt_{start_epoch-1}.pth')))
        logger.info(f'Resuming at epoch {start_epoch}')

        best_test_robust_acc = torch.load(
            os.path.join(args.fname, f'model_best.pth'))['test_robust_acc']
    else:
        start_epoch = 0

    if args.eval:
        if not args.resume:
            logger.info(
                "No model loaded to evaluate, specify with --resume FNAME")
            return
        logger.info("[Evaluation mode]")

    logger.info(
        'Epoch \t Train Time \t Test Time \t LR \t \t Train Robust Loss \t Train Robust Acc \t Test Loss \t Test Acc \t Test Robust Loss \t Test Robust Acc \t Defence Mean \t Attack Mean'
    )
    for epoch in range(start_epoch, epochs):
        model.train()
        start_time = time.time()
        train_robust_loss = 0
        train_robust_acc = 0
        train_n = 0
        defence_mean = 0
        for i, batch in enumerate(train_batches):
            if args.eval:
                break
            X, y = batch['input'], batch['target']
            lr = lr_schedule(epoch + (i + 1) / len(train_batches))
            opt.param_groups[0].update(lr=lr)

            if args.attack == 'fgsm':
                delta = attack_pgd(model,
                                   X,
                                   y,
                                   epsilon,
                                   args.fgsm_alpha * epsilon,
                                   1,
                                   1,
                                   'l_inf',
                                   fgsm_init=args.fgsm_init)
            elif args.attack == 'multigrad':
                delta = multi_grad(model, X, y, epsilon,
                                   args.fgsm_alpha * epsilon, args.multi_samps,
                                   args.multi_th, args.multi_parallel)
            elif args.attack == 'zerograd':
                delta = zero_grad(model, X, y, epsilon,
                                  args.fgsm_alpha * epsilon, args.zero_qval,
                                  args.zero_iters, args.fgsm_init)
            elif args.attack == 'pgd':
                delta = attack_pgd(model, X, y, epsilon, pgd_alpha, 10, 1,
                                   'l_inf')
            delta = delta.detach()
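            # Apply the perturbation, clamp to the valid input range, and
            # normalize before the robust forward pass.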
            robust_output = model(
                normalize(
                    torch.clamp(X + delta[:X.size(0)],
                                min=lower_limit,
                                max=upper_limit)))
            robust_loss = criterion(robust_output, y)

            opt.zero_grad()
            robust_loss.backward()
            opt.step()

            train_robust_loss += robust_loss.item() * y.size(0)
            train_robust_acc += (robust_output.max(1)[1] == y).sum().item()
            train_n += y.size(0)
            defence_mean += torch.mean(torch.abs(delta)) * y.size(0)

        train_time = time.time()

        model.eval()
        test_loss = 0
        test_acc = 0
        test_robust_loss = 0
        test_robust_acc = 0
        test_n = 0
        attack_mean = 0
        for i, batch in enumerate(test_batches):
            # During intermediate epochs (unless --full_test), evaluate on only
            # ~10% of the test batches to save time.
            if epoch + 1 != epochs and not args.full_test and i > len(
                    test_batches) / 10:
                break
            X, y = batch['input'], batch['target']

            delta = attack_pgd(model,
                               X,
                               y,
                               epsilon,
                               pgd_alpha,
                               args.test_iters,
                               args.test_restarts,
                               'l_inf',
                               early_stop=args.eval)
            delta = delta.detach()

            robust_output = model(
                normalize(
                    torch.clamp(X + delta[:X.size(0)],
                                min=lower_limit,
                                max=upper_limit)))
            robust_loss = criterion(robust_output, y)

            output = model(normalize(X))
            loss = criterion(output, y)

            test_robust_loss += robust_loss.item() * y.size(0)
            test_robust_acc += (robust_output.max(1)[1] == y).sum().item()
            test_loss += loss.item() * y.size(0)
            test_acc += (output.max(1)[1] == y).sum().item()
            test_n += y.size(0)
            attack_mean += torch.mean(torch.abs(delta)) * y.size(0)

        test_time = time.time()

        if not args.eval:
            logger.info(
                '%d \t %.1f \t \t %.1f \t \t %.4f \t %.4f \t \t %.4f \t \t %.4f \t %.4f \t %.4f \t \t %.4f \t \t %.4f \t %.4f',
                epoch, train_time - start_time, test_time - train_time, lr,
                train_robust_loss / train_n, train_robust_acc / train_n,
                test_loss / test_n, test_acc / test_n,
                test_robust_loss / test_n, test_robust_acc / test_n,
                defence_mean * 255 / train_n, attack_mean * 255 / test_n)

            # save checkpoint
            if (epoch + 1) % args.chkpt_iters == 0 or epoch + 1 == epochs:
                torch.save(model.state_dict(),
                           os.path.join(args.fname, f'model_{epoch}.pth'))
                torch.save(opt.state_dict(),
                           os.path.join(args.fname, f'opt_{epoch}.pth'))

            # save best
            if test_robust_acc / test_n > best_test_robust_acc:
                torch.save(
                    {
                        'state_dict': model.state_dict(),
                        'test_robust_acc': test_robust_acc / test_n,
                        'test_robust_loss': test_robust_loss / test_n,
                        'test_loss': test_loss / test_n,
                        'test_acc': test_acc / test_n,
                    }, os.path.join(args.fname, f'model_best.pth'))
                best_test_robust_acc = test_robust_acc / test_n
        else:
            logger.info(
                '%d \t %.1f \t \t %.1f \t \t %.4f \t %.4f \t \t %.4f \t \t %.4f \t %.4f \t %.4f \t \t %.4f \t \t %.4f \t %.4f',
                epoch, train_time - start_time, test_time - train_time, -1, -1,
                -1, test_loss / test_n, test_acc / test_n,
                test_robust_loss / test_n, test_robust_acc / test_n, -1,
                attack_mean * 255 / test_n)

        if args.eval or epoch + 1 == epochs:
            start_test_time = time.time()
            test_loss = 0
            test_acc = 0
            test_robust_loss = 0
            test_robust_acc = 0
            test_n = 0
            for i, batch in enumerate(test_batches):
                X, y = batch['input'], batch['target']

                delta = attack_pgd(model,
                                   X,
                                   y,
                                   epsilon,
                                   pgd_alpha,
                                   50,
                                   10,
                                   'l_inf',
                                   early_stop=True)
                delta = delta.detach()

                robust_output = model(
                    normalize(
                        torch.clamp(X + delta[:X.size(0)],
                                    min=lower_limit,
                                    max=upper_limit)))
                robust_loss = criterion(robust_output, y)

                output = model(normalize(X))
                loss = criterion(output, y)

                test_robust_loss += robust_loss.item() * y.size(0)
                test_robust_acc += (robust_output.max(1)[1] == y).sum().item()
                test_loss += loss.item() * y.size(0)
                test_acc += (output.max(1)[1] == y).sum().item()
                test_n += y.size(0)

            logger.info(
                'PGD50 \t time: %.1f,\t clean loss: %.4f,\t clean acc: %.4f,\t robust loss: %.4f,\t robust acc: %.4f',
                time.time() - start_test_time, test_loss / test_n,
                test_acc / test_n, test_robust_loss / test_n,
                test_robust_acc / test_n)
            return
Example No. 8
0
def main():
    global args, best_prec1, exp_dir

    best_prec1 = 0
    args = parser.parse_args()
    print(args.lr_decay_at)
    assert args.normalization in ['GCN_ZCA', 'GCN'], \
        'normalization {} unknown'.format(args.normalization)

    global zca
    if 'ZCA' in args.normalization:
        zca_params = torch.load('./data/cifar-10-batches-py/zca_params.pth')
        zca = ZCA(zca_params)
    else:
        zca = None

    exp_dir = os.path.join('experiments', args.name)
    if os.path.exists(exp_dir):
        print("same experiment exist...")
        #return
    else:
        os.makedirs(exp_dir)

    # DATA SETTINGS
    global dataset_train, dataset_test
    if args.dataset == 'cifar10':
        import cifar
        dataset_train = cifar.CIFAR10(args, train=True)
        dataset_test = cifar.CIFAR10(args, train=False)
    if args.UDA:
        # loader for UDA
        dataset_train_uda = cifar.CIFAR10(args, True, True)
        uda_loader = torch.utils.data.DataLoader(
            dataset_train_uda,
            batch_size=args.batch_size_unsup,
            shuffle=True,
            num_workers=args.num_workers,
            pin_memory=True)
        iter_uda = iter(uda_loader)
    else:
        iter_uda = None

    train_loader, test_loader = initialize_loader()

    # MODEL SETTINGS
    if args.arch == 'WRN-28-2':
        model = WideResNet(28, [100, 10][int(args.dataset == 'cifar10')],
                           2,
                           dropRate=args.dropout_rate)
        model = torch.nn.DataParallel(model.cuda())
    else:
        raise NotImplementedError('arch {} is not implemented'.format(
            args.arch))
    if args.optimizer == 'Adam':
        print("use Adam optimizer")
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.l2_reg)
    elif args.optimizer == 'SGD':
        print("use SGD optimizer")
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=0.9,
                                    weight_decay=args.l2_reg,
                                    nesterov=args.nesterov)

    if args.lr_decay == 'cosine':
        print("use cosine lr scheduler")
        global scheduler
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, args.max_iter, eta_min=args.final_lr)

    global batch_time, losses_sup, losses_unsup, top1, losses_l1
    batch_time, losses_sup, losses_unsup, top1, losses_l1 = (
        AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(),
        AverageMeter())
    t = time.time()
    model.train()
    iter_sup = iter(train_loader)
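    # Iteration-based training loop: adjust the lr, run one training step
    # (optionally with the UDA batch), log periodically, and evaluate and
    # checkpoint every args.eval_iter iterations.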
    for train_iter in range(args.max_iter):
        # TRAIN
        lr = adjust_learning_rate(optimizer, train_iter + 1)
        train(model,
              iter_sup,
              optimizer,
              train_iter,
              data_iterator_uda=iter_uda)

        # LOGGING
        if (train_iter + 1) % args.print_freq == 0:
            print('ITER: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'L1 Loss {l1_loss.val:.4f} ({l1_loss.avg:.4f})\t'
                  'Unsup Loss {unsup_loss.val:.4f} ({unsup_loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Learning rate {2} TSA th {3}'.format(
                      train_iter,
                      args.max_iter,
                      lr,
                      TSA_th(train_iter),
                      batch_time=batch_time,
                      loss=losses_sup,
                      l1_loss=losses_l1,
                      unsup_loss=losses_unsup,
                      top1=top1))

        if (train_iter +
                1) % args.eval_iter == 0 or train_iter + 1 == args.max_iter:
            # EVAL
            print("evaluation at iter {}".format(train_iter))
            prec1 = test(model, test_loader)

            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'iter': train_iter + 1,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)
            print("* Best accuracy: {}".format(best_prec1))
            eval_interval_time = time.time() - t
            t = time.time()
            print("total {} sec for {} iterations".format(
                eval_interval_time, args.eval_iter))
            seconds_remaining = eval_interval_time / float(
                args.eval_iter) * (args.max_iter - train_iter)
            print("{}:{}:{} remaining".format(
                int(seconds_remaining // 3600),
                int((seconds_remaining % 3600) // 60),
                int(seconds_remaining % 60)))
            model.train()
            batch_time, losses_sup, losses_unsup, top1, losses_l1 = (
                AverageMeter(), AverageMeter(), AverageMeter(),
                AverageMeter(), AverageMeter())
            iter_sup = iter(train_loader)
            if iter_uda is not None:
                iter_uda = iter(uda_loader)
def main():
    args = get_args()
    if args.fname == 'auto':
        names = get_auto_fname(args)
        args.fname = 'noise_predictor/' + names
    else:
        args.fname = 'noise_predictor/' + args.fname

    if not os.path.exists(args.fname):
        os.makedirs(args.fname)

    logger = logging.getLogger(__name__)
    logging.basicConfig(
        format='[%(asctime)s] - %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S',
        level=logging.DEBUG,
        handlers=[
            logging.FileHandler(
                os.path.join(args.fname,
                             'eval.log' if args.eval else 'output.log')),
            logging.StreamHandler()
        ])

    logger.info(args)

    # Set seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    shuffle = True

    train_adv_images = None
    train_adv_labels = None
    test_adv_images = None
    test_adv_labels = None

    print("Attacks")
    attacks = args.list.split("_")
    print(attacks)

    for i in range(len(attacks)):
        _adv_dir = "adv_examples/{}/".format(attacks[i])
        train_path = _adv_dir + "train.pth"
        test_path = _adv_dir + "test.pth"

        adv_train_data = torch.load(train_path)
        adv_test_data = torch.load(test_path)

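        # Label each adversarial example with the index of the attack that
        # produced it, so the classifier is trained to predict the attack type.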
        if i == 0:
            train_adv_images = adv_train_data["adv"]
            test_adv_images = adv_test_data["adv"]
            train_adv_labels = [i] * len(adv_train_data["label"])
            test_adv_labels = [i] * len(adv_test_data["label"])
        else:
            train_adv_images = np.concatenate(
                (train_adv_images, adv_train_data["adv"]))
            test_adv_images = np.concatenate(
                (test_adv_images, adv_test_data["adv"]))
            train_adv_labels = np.concatenate(
                (train_adv_labels, [i] * len(adv_train_data["label"])))
            test_adv_labels = np.concatenate(
                (test_adv_labels, [i] * len(adv_test_data["label"])))

    train_adv_set = list(zip(train_adv_images, train_adv_labels))

    print("")
    print("Train Adv Attack Data: ", attacks)
    print("Len: ", len(train_adv_set))
    print("")

    train_adv_batches = Batches(train_adv_set,
                                args.batch_size,
                                shuffle=shuffle,
                                set_random_choices=False,
                                num_workers=4)

    test_adv_set = list(zip(test_adv_images, test_adv_labels))

    test_adv_batches = Batches(test_adv_set,
                               args.batch_size,
                               shuffle=False,
                               num_workers=4)

    # Set perturbations
    epsilon = (args.epsilon / 255.)
    test_epsilon = (args.test_epsilon / 255.)
    pgd_alpha = (args.pgd_alpha / 255.)
    test_pgd_alpha = (args.test_pgd_alpha / 255.)

    # Set models
    model = None
    if args.model == "resnet18":
        model = resnet18(num_classes=args.num_classes)
    elif args.model == "resnet20":
        model = resnet20()
    elif args.model == "vgg16bn":
        model = vgg16_bn(num_classes=args.num_classes)
    elif args.model == "densenet121":
        model = densenet121(pretrained=True)
    elif args.model == "googlenet":
        model = googlenet(pretrained=True)
    elif args.model == "inceptionv3":
        model = inception_v3(num_classes=3)
    elif args.model == 'WideResNet':
        model = WideResNet(34,
                           10,
                           widen_factor=10,
                           dropRate=0.0,
                           normalize=args.use_FNandWN,
                           activation=args.activation,
                           softplus_beta=args.softplus_beta)
    elif args.model == 'WideResNet_20':
        model = WideResNet(34,
                           10,
                           widen_factor=20,
                           dropRate=0.0,
                           normalize=args.use_FNandWN,
                           activation=args.activation,
                           softplus_beta=args.softplus_beta)
    else:
        raise ValueError("Unknown model")

    # Set training hyperparameters
    if args.l2:
        decay, no_decay = [], []
        for name, param in model.named_parameters():
            if 'bn' not in name and 'bias' not in name:
                decay.append(param)
            else:
                no_decay.append(param)
        params = [{
            'params': decay,
            'weight_decay': args.l2
        }, {
            'params': no_decay,
            'weight_decay': 0
        }]
    else:
        params = model.parameters()
    if args.lr_schedule == 'cyclic':
        opt = torch.optim.Adam(params,
                               lr=args.lr_max,
                               betas=(0.9, 0.999),
                               eps=1e-08,
                               weight_decay=args.weight_decay)
    else:
        if args.optimizer == 'momentum':
            opt = torch.optim.SGD(params,
                                  lr=args.lr_max,
                                  momentum=0.9,
                                  weight_decay=args.weight_decay)
        elif args.optimizer == 'Nesterov':
            opt = torch.optim.SGD(params,
                                  lr=args.lr_max,
                                  momentum=0.9,
                                  weight_decay=args.weight_decay,
                                  nesterov=True)
        elif args.optimizer == 'SGD_GC':
            opt = SGD_GC(params,
                         lr=args.lr_max,
                         momentum=0.9,
                         weight_decay=args.weight_decay)
        elif args.optimizer == 'SGD_GCC':
            opt = SGD_GCC(params,
                          lr=args.lr_max,
                          momentum=0.9,
                          weight_decay=args.weight_decay)
        elif args.optimizer == 'Adam':
            opt = torch.optim.Adam(params,
                                   lr=args.lr_max,
                                   betas=(0.9, 0.999),
                                   eps=1e-08,
                                   weight_decay=args.weight_decay)
        elif args.optimizer == 'AdamW':
            opt = torch.optim.AdamW(params,
                                    lr=args.lr_max,
                                    betas=(0.9, 0.999),
                                    eps=1e-08,
                                    weight_decay=args.weight_decay)

    # Cross-entropy (mean)
    if args.labelsmooth:
        criterion = LabelSmoothingLoss(smoothing=args.labelsmoothvalue)
    else:
        criterion = nn.CrossEntropyLoss()

    epochs = args.epochs

    # Set lr schedule
    if args.lr_schedule == 'superconverge':
        lr_schedule = lambda t: np.interp([t], [
            0, args.epochs * 2 // 5, args.epochs
        ], [0, args.lr_max, 0])[0]
    elif args.lr_schedule == 'piecewise':

        def lr_schedule(t, warm_up_lr=args.warmup_lr):
            if t < 100:
                if warm_up_lr and t < args.warmup_lr_epoch:
                    return (t + 1.) / args.warmup_lr_epoch * args.lr_max
                else:
                    return args.lr_max
            if args.lrdecay == 'lineardecay':
                if t < 105:
                    return args.lr_max * 0.02 * (105 - t)
                else:
                    return 0.
            elif args.lrdecay == 'intenselr':
                if t < 102:
                    return args.lr_max / 10.
                else:
                    return args.lr_max / 100.
            elif args.lrdecay == 'looselr':
                if t < 150:
                    return args.lr_max / 10.
                else:
                    return args.lr_max / 100.
            elif args.lrdecay == 'base':
                if t < 105:
                    return args.lr_max / 10.
                else:
                    return args.lr_max / 100.
    elif args.lr_schedule == 'linear':
        lr_schedule = lambda t: np.interp([t], [
            0, args.epochs // 3, args.epochs * 2 // 3, args.epochs
        ], [args.lr_max, args.lr_max, args.lr_max / 10, args.lr_max / 100])[0]
    elif args.lr_schedule == 'onedrop':

        def lr_schedule(t):
            if t < args.lr_drop_epoch:
                return args.lr_max
            else:
                return args.lr_one_drop
    elif args.lr_schedule == 'multipledecay':

        def lr_schedule(t):
            return args.lr_max - (t //
                                  (args.epochs // 10)) * (args.lr_max / 10)
    elif args.lr_schedule == 'cosine':

        def lr_schedule(t):
            return args.lr_max * 0.5 * (1 + np.cos(t / args.epochs * np.pi))
    elif args.lr_schedule == 'cyclic':

        def lr_schedule(t, stepsize=18, min_lr=1e-5, max_lr=args.lr_max):

            # Scaler: we can adapt this if we do not want the triangular CLR
            scaler = lambda x: 1.

            # Additional function to see where on the cycle we are
            cycle = math.floor(1 + t / (2 * stepsize))
            x = abs(t / stepsize - 2 * cycle + 1)
            relative = max(0, (1 - x)) * scaler(cycle)

            return min_lr + (max_lr - min_lr) * relative

    #### Set stronger adv attacks when decay the lr ####
    def eps_alpha_schedule(
            t,
            warm_up_eps=args.warmup_eps,
            if_use_stronger_adv=args.use_stronger_adv,
            stronger_index=args.stronger_index):  # Schedule number 0
        if stronger_index == 0:
            epsilon_s = [epsilon * 1.5, epsilon * 2]
            pgd_alpha_s = [pgd_alpha, pgd_alpha]
        elif stronger_index == 1:
            epsilon_s = [epsilon * 1.5, epsilon * 2]
            pgd_alpha_s = [pgd_alpha * 1.25, pgd_alpha * 1.5]
        elif stronger_index == 2:
            epsilon_s = [epsilon * 2, epsilon * 2.5]
            pgd_alpha_s = [pgd_alpha * 1.5, pgd_alpha * 2]
        else:
            raise ValueError('Undefined stronger index')

        if if_use_stronger_adv:
            if t < 100:
                if t < args.warmup_eps_epoch and warm_up_eps:
                    return (
                        t + 1.
                    ) / args.warmup_eps_epoch * epsilon, pgd_alpha, args.restarts
                else:
                    return epsilon, pgd_alpha, args.restarts
            elif t < 105:
                return epsilon_s[0], pgd_alpha_s[0], args.restarts
            else:
                return epsilon_s[1], pgd_alpha_s[1], args.restarts
        else:
            if t < args.warmup_eps_epoch and warm_up_eps:
                return (
                    t + 1.
                ) / args.warmup_eps_epoch * epsilon, pgd_alpha, args.restarts
            else:
                return epsilon, pgd_alpha, args.restarts

    #### Set the counter for the early stop of PGD ####
    def early_stop_counter_schedule(t):
        if t < args.earlystopPGDepoch1:
            return 1
        elif t < args.earlystopPGDepoch2:
            return 2
        else:
            return 3

    best_test_adv_acc = 0
    best_val_robust_acc = 0
    if args.resume:
        start_epoch = args.resume
        model.load_state_dict(
            torch.load(os.path.join(args.fname, f'model_{start_epoch-1}.pth')))
        opt.load_state_dict(
            torch.load(os.path.join(args.fname, f'opt_{start_epoch-1}.pth')))
        logger.info(f'Resuming at epoch {start_epoch}')

        best_test_adv_acc = torch.load(
            os.path.join(args.fname, f'model_best.pth'))['test_adv_acc']
        if args.val:
            best_val_robust_acc = torch.load(
                os.path.join(args.fname, f'model_val.pth'))['val_robust_acc']
    else:
        start_epoch = 0

    if args.eval:
        if not args.resume:
            logger.info(
                "No model loaded to evaluate, specify with --resume FNAME")
            return
        logger.info("[Evaluation mode]")

    model.cuda()

    model.train()

    logger.info('Epoch \t Train Robust Acc \t Test Robust Acc')

    # Records per epoch for savetxt
    train_loss_record = []
    train_acc_record = []
    train_robust_loss_record = []
    train_robust_acc_record = []
    train_grad_record = []

    test_loss_record = []
    test_acc_record = []
    test_adv_loss_record = []
    test_adv_acc_record = []
    test_grad_record = []

    for epoch in range(start_epoch, epochs):
        model.train()
        start_time = time.time()

        train_robust_loss = 0
        train_robust_acc = 0
        train_n = 0
        train_grad = 0

        record_iter = torch.tensor([])

        for i, adv_batch in enumerate(train_adv_batches):
            if args.eval:
                break
            adv_input = normalize(adv_batch['input'])
            adv_y = adv_batch['target']
            adv_input.requires_grad = True
            robust_output = model(adv_input)

            robust_loss = criterion(robust_output, adv_y)

            if args.l1:
                for name, param in model.named_parameters():
                    if 'bn' not in name and 'bias' not in name:
                        robust_loss += args.l1 * param.abs().sum()

            opt.zero_grad()
            robust_loss.backward()
            opt.step()

            # Record the statistics
            train_robust_loss += robust_loss.item() * adv_y.size(0)
            train_robust_acc += (robust_output.max(1)[1] == adv_y).sum().item()
            train_n += adv_y.size(0)

        train_time = time.time()
        if args.earlystopPGD:
            print('Iter mean: ',
                  record_iter.mean().item(), ' Iter std:  ',
                  record_iter.std().item())

        # Evaluate on test data
        model.eval()

        test_adv_loss = 0
        test_adv_acc = 0
        test_adv_n = 0

        for i, batch in enumerate(test_adv_batches):
            adv_input = normalize(batch['input'])
            y = batch['target']

            robust_output = model(adv_input)
            robust_loss = criterion(robust_output, y)

            test_adv_loss += robust_loss.item() * y.size(0)
            test_adv_acc += (robust_output.max(1)[1] == y).sum().item()
            test_adv_n += y.size(0)

        test_time = time.time()

        logger.info('%d \t %.4f \t\t %.4f', epoch + 1,
                    train_robust_acc / train_n, test_adv_acc / test_adv_n)

        # Save results
        train_robust_loss_record.append(train_robust_loss / train_n)
        train_robust_acc_record.append(train_robust_acc / train_n)

        np.savetxt(args.fname + '/train_robust_loss_record.txt',
                   np.array(train_robust_loss_record))
        np.savetxt(args.fname + '/train_robust_acc_record.txt',
                   np.array(train_robust_acc_record))

        test_adv_loss_record.append(test_adv_loss / test_adv_n)
        test_adv_acc_record.append(test_adv_acc / test_adv_n)

        np.savetxt(args.fname + '/test_adv_loss_record.txt',
                   np.array(test_adv_loss_record))
        np.savetxt(args.fname + '/test_adv_acc_record.txt',
                   np.array(test_adv_acc_record))

        # save checkpoint
        if epoch > 99 or (epoch +
                          1) % args.chkpt_iters == 0 or epoch + 1 == epochs:
            torch.save(model.state_dict(),
                       os.path.join(args.fname, f'model_{epoch}.pth'))
            torch.save(opt.state_dict(),
                       os.path.join(args.fname, f'opt_{epoch}.pth'))

        # save best
        if test_adv_acc / test_adv_n > best_test_adv_acc:
            torch.save(
                {
                    'state_dict': model.state_dict(),
                    'test_adv_acc': test_adv_acc / test_adv_n,
                    'test_adv_loss': test_adv_loss / test_adv_n,
                }, os.path.join(args.fname, f'model_best.pth'))
            best_test_adv_acc = test_adv_acc / test_adv_n
def main():
    args = get_args()
    if args.fname == 'auto':
        names = get_auto_fname(args)
        args.fname = '../../trained_models/' + names
    else:
        args.fname = '../../trained_models/' + args.fname

    if not os.path.exists(args.fname):
        os.makedirs(args.fname)

    logger = logging.getLogger(__name__)
    logging.basicConfig(
        format='[%(asctime)s] - %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S',
        level=logging.DEBUG,
        handlers=[
            logging.FileHandler(
                os.path.join(args.fname,
                             'eval.log' if args.eval else 'output.log')),
            logging.StreamHandler()
        ])

    logger.info(args)

    # Set seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # setup data loader
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])
    train_set = torchvision.datasets.CIFAR10(root='../data',
                                             train=True,
                                             download=True,
                                             transform=transform_train)
    test_set = torchvision.datasets.CIFAR10(root='../data',
                                            train=False,
                                            download=True,
                                            transform=transform_test)

    if args.attack == "all":
        train_data = np.array(train_set.data) / 255.
        train_data = transpose(train_data).astype(np.float32)

        train_labels = np.array(train_set.targets)

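        # Replicate the clean training set 11 times so it lines up with the
        # concatenation of adversarial examples from the 11 attacks loaded below.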
        oversampled_train_data = np.tile(train_data, (11, 1, 1, 1))
        oversampled_train_labels = np.tile(train_labels, (11))

        train_set = list(
            zip(torch.from_numpy(oversampled_train_data),
                torch.from_numpy(oversampled_train_labels)))

    elif args.attack == "combine":
        train_data = np.array(train_set.data) / 255.
        train_data = transpose(train_data).astype(np.float32)

        train_labels = np.array(train_set.targets)

        oversampled_train_data = train_data.copy()
        oversampled_train_labels = train_labels.copy()

        logger.info("Attacks")
        attacks = args.list.split("_")
        logger.info(attacks)

        oversampled_train_data = np.tile(train_data, (len(attacks), 1, 1, 1))
        oversampled_train_labels = np.tile(train_labels, (len(attacks)))

        train_set = list(
            zip(torch.from_numpy(oversampled_train_data),
                torch.from_numpy(oversampled_train_labels)))
    else:
        train_data = np.array(train_set.data) / 255.
        train_data = transpose(train_data).astype(np.float32)

        train_labels = np.array(train_set.targets)

        train_set = list(
            zip(torch.from_numpy(train_data), torch.from_numpy(train_labels)))

    test_data = np.array(test_set.data) / 255.
    test_data = transpose(test_data).astype(np.float32)
    test_labels = np.array(test_set.targets)

    test_set = list(
        zip(torch.from_numpy(test_data), torch.from_numpy(test_labels)))

    if args.sample != 100:
        n = len(train_set)
        n_sample = int(n * args.sample / 100)

        np.random.shuffle(train_set)
        train_set = train_set[:n_sample]

    print("")
    print("Train Original Data: ")
    print("Len: ", len(train_set))
    print("")

    shuffle = False

    train_batches = Batches(train_set, args.batch_size, shuffle=shuffle)
    test_batches = Batches(test_set, args.batch_size, shuffle=False)

    train_adv_images = None
    train_adv_labels = None
    test_adv_images = None
    test_adv_labels = None

    adv_dir = "adv_examples/{}/".format(args.attack)
    train_path = adv_dir + "train.pth"
    test_path = adv_dir + "test.pth"

    #     ATTACK_LIST = ["autoattack", "autopgd", "bim", "cw", "deepfool", "fgsm", "newtonfool", "pgd", "pixelattack", "spatialtransformation", "squareattack"]
    ATTACK_LIST = [
        "pixelattack", "spatialtransformation", "squareattack", "fgsm",
        "deepfool", "bim", "cw", "pgd", "autoattack", "autopgd", "newtonfool"
    ]

    if args.attack in TOOLBOX_ADV_ATTACK_LIST:
        adv_train_data = torch.load(train_path)
        train_adv_images = adv_train_data["adv"]
        train_adv_labels = adv_train_data["label"]
        adv_test_data = torch.load(test_path)
        test_adv_images = adv_test_data["adv"]
        test_adv_labels = adv_test_data["label"]
    elif args.attack in ["ffgsm", "mifgsm", "tpgd"]:
        adv_data = {}
        adv_data["adv"], adv_data["label"] = torch.load(train_path)
        train_adv_images = adv_data["adv"].numpy()
        train_adv_labels = adv_data["label"].numpy()
        adv_data = {}
        adv_data["adv"], adv_data["label"] = torch.load(test_path)
        test_adv_images = adv_data["adv"].numpy()
        test_adv_labels = adv_data["label"].numpy()
    elif args.attack == "all":

        for i in range(len(ATTACK_LIST)):
            _adv_dir = "adv_examples/{}/".format(ATTACK_LIST[i])
            train_path = _adv_dir + "train.pth"
            test_path = _adv_dir + "test.pth"

            adv_train_data = torch.load(train_path)
            adv_test_data = torch.load(test_path)

            if i == 0:
                train_adv_images = adv_train_data["adv"]
                train_adv_labels = adv_train_data["label"]
                test_adv_images = adv_test_data["adv"]
                test_adv_labels = adv_test_data["label"]
            else:
                #                 print(train_adv_images.shape)
                #                 print(adv_train_data["adv"].shape)
                train_adv_images = np.concatenate(
                    (train_adv_images, adv_train_data["adv"]))
                train_adv_labels = np.concatenate(
                    (train_adv_labels, adv_train_data["label"]))
                test_adv_images = np.concatenate(
                    (test_adv_images, adv_test_data["adv"]))
                test_adv_labels = np.concatenate(
                    (test_adv_labels, adv_test_data["label"]))
    elif args.attack == "combine":

        print("Attacks")
        attacks = args.list.split("_")
        print(attacks)

        if args.balanced is None:
            for i in range(len(attacks)):
                _adv_dir = "adv_examples/{}/".format(attacks[i])
                train_path = _adv_dir + "train.pth"
                test_path = _adv_dir + "test.pth"

                adv_train_data = torch.load(train_path)
                adv_test_data = torch.load(test_path)

                if i == 0:
                    train_adv_images = adv_train_data["adv"]
                    train_adv_labels = adv_train_data["label"]
                    test_adv_images = adv_test_data["adv"]
                    test_adv_labels = adv_test_data["label"]
                else:
                    #                 print(train_adv_images.shape)
                    #                 print(adv_train_data["adv"].shape)
                    train_adv_images = np.concatenate(
                        (train_adv_images, adv_train_data["adv"]))
                    train_adv_labels = np.concatenate(
                        (train_adv_labels, adv_train_data["label"]))
                    test_adv_images = np.concatenate(
                        (test_adv_images, adv_test_data["adv"]))
                    test_adv_labels = np.concatenate(
                        (test_adv_labels, adv_test_data["label"]))
        else:
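            # args.balanced gives underscore-separated integer weights; normalize
            # them to fractions and resample each attack's examples accordingly.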
            proportion_str = args.balanced.split("_")
            proportion = [int(x) for x in proportion_str]
            sum_proportion = sum(proportion)
            proportion = [float(x) / float(sum_proportion) for x in proportion]
            sum_samples = 0

            for i in range(len(attacks)):
                _adv_dir = "adv_examples/{}/".format(attacks[i])
                train_path = _adv_dir + "train.pth"
                test_path = _adv_dir + "test.pth"

                adv_train_data = torch.load(train_path)
                adv_test_data = torch.load(test_path)

                random_state = 0
                total = 50000
                if i != len(attacks) - 1:
                    n_samples = int(proportion[i] * total)
                    sum_samples += n_samples
                else:
                    n_samples = total - sum_samples
                print("Sample")
                print(n_samples)

                if i == 0:
                    train_adv_images = resample(adv_train_data["adv"],
                                                n_samples=n_samples,
                                                random_state=random_state)
                    train_adv_labels = resample(adv_train_data["label"],
                                                n_samples=n_samples,
                                                random_state=random_state)
                    test_adv_images = resample(adv_test_data["adv"],
                                               n_samples=n_samples,
                                               random_state=random_state)
                    test_adv_labels = resample(adv_test_data["label"],
                                               n_samples=n_samples,
                                               random_state=random_state)
                else:
                    train_adv_images = np.concatenate(
                        (train_adv_images,
                         resample(adv_train_data["adv"],
                                  n_samples=n_samples,
                                  random_state=random_state)))
                    train_adv_labels = np.concatenate(
                        (train_adv_labels,
                         resample(adv_train_data["label"],
                                  n_samples=n_samples,
                                  random_state=random_state)))
                    test_adv_images = np.concatenate(
                        (test_adv_images,
                         resample(adv_test_data["adv"],
                                  n_samples=n_samples,
                                  random_state=random_state)))
                    test_adv_labels = np.concatenate(
                        (test_adv_labels,
                         resample(adv_test_data["label"],
                                  n_samples=n_samples,
                                  random_state=random_state)))

    else:
        raise ValueError("Unknown adversarial data")

    train_adv_set = list(zip(train_adv_images, train_adv_labels))

    if args.sample != 100:
        n = len(train_adv_set)
        n_sample = int(n * args.sample / 100)

        np.random.shuffle(train_adv_set)
        train_adv_set = train_adv_set[:n_sample]

    print("")
    print("Train Adv Attack Data: ", args.attack)
    print("Len: ", len(train_adv_set))
    print("")

    train_adv_batches = Batches(train_adv_set,
                                args.batch_size,
                                shuffle=shuffle,
                                set_random_choices=False,
                                num_workers=4)

    test_adv_set = list(zip(test_adv_images, test_adv_labels))

    test_adv_batches = Batches(test_adv_set,
                               args.batch_size,
                               shuffle=False,
                               num_workers=4)

    # Set perturbations
    epsilon = (args.epsilon / 255.)
    test_epsilon = (args.test_epsilon / 255.)
    pgd_alpha = (args.pgd_alpha / 255.)
    test_pgd_alpha = (args.test_pgd_alpha / 255.)

    # Set models
    model = None
    if args.model == "resnet18":
        model = resnet18(pretrained=True)
    elif args.model == "resnet20":
        model = resnet20()
    elif args.model == "vgg16bn":
        model = vgg16_bn(pretrained=True)
    elif args.model == "densenet121":
        model = densenet121(pretrained=True)
    elif args.model == "googlenet":
        model = googlenet(pretrained=True)
    elif args.model == "inceptionv3":
        model = inception_v3(pretrained=True)
    elif args.model == 'WideResNet':
        model = WideResNet(34,
                           10,
                           widen_factor=10,
                           dropRate=0.0,
                           normalize=args.use_FNandWN,
                           activation=args.activation,
                           softplus_beta=args.softplus_beta)
    elif args.model == 'WideResNet_20':
        model = WideResNet(34,
                           10,
                           widen_factor=20,
                           dropRate=0.0,
                           normalize=args.use_FNandWN,
                           activation=args.activation,
                           softplus_beta=args.softplus_beta)
    else:
        raise ValueError("Unknown model")

    # Set training hyperparameters
    if args.l2:
        decay, no_decay = [], []
        for name, param in model.named_parameters():
            if 'bn' not in name and 'bias' not in name:
                decay.append(param)
            else:
                no_decay.append(param)
        params = [{
            'params': decay,
            'weight_decay': args.l2
        }, {
            'params': no_decay,
            'weight_decay': 0
        }]
    else:
        params = model.parameters()
    if args.lr_schedule == 'cyclic':
        opt = torch.optim.Adam(params,
                               lr=args.lr_max,
                               betas=(0.9, 0.999),
                               eps=1e-08,
                               weight_decay=args.weight_decay)
    else:
        if args.optimizer == 'momentum':
            opt = torch.optim.SGD(params,
                                  lr=args.lr_max,
                                  momentum=0.9,
                                  weight_decay=args.weight_decay)
        elif args.optimizer == 'Nesterov':
            opt = torch.optim.SGD(params,
                                  lr=args.lr_max,
                                  momentum=0.9,
                                  weight_decay=args.weight_decay,
                                  nesterov=True)
        elif args.optimizer == 'SGD_GC':
            opt = SGD_GC(params,
                         lr=args.lr_max,
                         momentum=0.9,
                         weight_decay=args.weight_decay)
        elif args.optimizer == 'SGD_GCC':
            opt = SGD_GCC(params,
                          lr=args.lr_max,
                          momentum=0.9,
                          weight_decay=args.weight_decay)
        elif args.optimizer == 'Adam':
            opt = torch.optim.Adam(params,
                                   lr=args.lr_max,
                                   betas=(0.9, 0.999),
                                   eps=1e-08,
                                   weight_decay=args.weight_decay)
        elif args.optimizer == 'AdamW':
            opt = torch.optim.AdamW(params,
                                    lr=args.lr_max,
                                    betas=(0.9, 0.999),
                                    eps=1e-08,
                                    weight_decay=args.weight_decay)

    # Cross-entropy (mean)
    if args.labelsmooth:
        criterion = LabelSmoothingLoss(smoothing=args.labelsmoothvalue)
    else:
        criterion = nn.CrossEntropyLoss()

    # If we use freeAT or fastAT with previous init
    if args.attack == 'free':
        delta = torch.zeros(args.batch_size, 3, 32, 32).cuda()
        delta.requires_grad = True
    elif args.attack == 'fgsm' and args.fgsm_init == 'previous':
        delta = torch.zeros(args.batch_size, 3, 32, 32).cuda()
        delta.requires_grad = True

    if args.attack == 'free':
        epochs = int(math.ceil(args.epochs / args.attack_iters))
    else:
        epochs = args.epochs

    # Set lr schedule
    if args.lr_schedule == 'superconverge':
        lr_schedule = lambda t: np.interp([t], [
            0, args.epochs * 2 // 5, args.epochs
        ], [0, args.lr_max, 0])[0]
    elif args.lr_schedule == 'piecewise':

        def lr_schedule(t, warm_up_lr=args.warmup_lr):
            if t < 100:
                if warm_up_lr and t < args.warmup_lr_epoch:
                    return (t + 1.) / args.warmup_lr_epoch * args.lr_max
                else:
                    return args.lr_max
            if args.lrdecay == 'lineardecay':
                if t < 105:
                    return args.lr_max * 0.02 * (105 - t)
                else:
                    return 0.
            elif args.lrdecay == 'intenselr':
                if t < 102:
                    return args.lr_max / 10.
                else:
                    return args.lr_max / 100.
            elif args.lrdecay == 'looselr':
                if t < 150:
                    return args.lr_max / 10.
                else:
                    return args.lr_max / 100.
            elif args.lrdecay == 'base':
                if t < 105:
                    return args.lr_max / 10.
                else:
                    return args.lr_max / 100.
    elif args.lr_schedule == 'linear':
        lr_schedule = lambda t: np.interp([t], [
            0, args.epochs // 3, args.epochs * 2 // 3, args.epochs
        ], [args.lr_max, args.lr_max, args.lr_max / 10, args.lr_max / 100])[0]
    elif args.lr_schedule == 'onedrop':

        def lr_schedule(t):
            if t < args.lr_drop_epoch:
                return args.lr_max
            else:
                return args.lr_one_drop
    elif args.lr_schedule == 'multipledecay':

        def lr_schedule(t):
            return args.lr_max - (t //
                                  (args.epochs // 10)) * (args.lr_max / 10)
    elif args.lr_schedule == 'cosine':

        def lr_schedule(t):
            return args.lr_max * 0.5 * (1 + np.cos(t / args.epochs * np.pi))
    elif args.lr_schedule == 'cyclic':

        def lr_schedule(t, stepsize=18, min_lr=1e-5, max_lr=args.lr_max):

            # Scaler: we can adapt this if we do not want the triangular CLR
            scaler = lambda x: 1.

            # Additional function to see where on the cycle we are
            cycle = math.floor(1 + t / (2 * stepsize))
            x = abs(t / stepsize - 2 * cycle + 1)
            relative = max(0, (1 - x)) * scaler(cycle)

            return min_lr + (max_lr - min_lr) * relative
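        # Worked example of the triangular cycle above (stepsize=18, min_lr=1e-5):
        # t=0 -> min_lr, t=9 -> roughly 0.5 * lr_max, t=18 -> lr_max, t=36 -> min_lr again,
        # and the pattern repeats every 2 * stepsize epochs.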

    #### Set stronger adv attacks when decaying the lr ####
    def eps_alpha_schedule(
            t,
            warm_up_eps=args.warmup_eps,
            if_use_stronger_adv=args.use_stronger_adv,
            stronger_index=args.stronger_index):  # Schedule number 0
        if stronger_index == 0:
            epsilon_s = [epsilon * 1.5, epsilon * 2]
            pgd_alpha_s = [pgd_alpha, pgd_alpha]
        elif stronger_index == 1:
            epsilon_s = [epsilon * 1.5, epsilon * 2]
            pgd_alpha_s = [pgd_alpha * 1.25, pgd_alpha * 1.5]
        elif stronger_index == 2:
            epsilon_s = [epsilon * 2, epsilon * 2.5]
            pgd_alpha_s = [pgd_alpha * 1.5, pgd_alpha * 2]
        else:
            raise ValueError('Undefined stronger index')

        if if_use_stronger_adv:
            if t < 100:
                if t < args.warmup_eps_epoch and warm_up_eps:
                    return (
                        t + 1.
                    ) / args.warmup_eps_epoch * epsilon, pgd_alpha, args.restarts
                else:
                    return epsilon, pgd_alpha, args.restarts
            elif t < 105:
                return epsilon_s[0], pgd_alpha_s[0], args.restarts
            else:
                return epsilon_s[1], pgd_alpha_s[1], args.restarts
        else:
            if t < args.warmup_eps_epoch and warm_up_eps:
                return (
                    t + 1.
                ) / args.warmup_eps_epoch * epsilon, pgd_alpha, args.restarts
            else:
                return epsilon, pgd_alpha, args.restarts
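    # eps_alpha_schedule returns a per-epoch (epsilon, pgd_alpha, restarts) triple.
    # Example with warm-up enabled and warmup_eps_epoch=10: epoch 0 uses epsilon/10,
    # the budget ramps linearly to the full epsilon by epoch 9, and with
    # use_stronger_adv the epsilon is further enlarged from epoch 100 on, as selected
    # by stronger_index.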

    #### Set the counter for the early stop of PGD ####
    def early_stop_counter_schedule(t):
        if t < args.earlystopPGDepoch1:
            return 1
        elif t < args.earlystopPGDepoch2:
            return 2
        else:
            return 3

    best_test_adv_acc = 0
    best_val_robust_acc = 0
    if args.resume:
        start_epoch = args.resume
        model.load_state_dict(
            torch.load(os.path.join(args.fname, f'model_{start_epoch-1}.pth')))
        opt.load_state_dict(
            torch.load(os.path.join(args.fname, f'opt_{start_epoch-1}.pth')))
        logger.info(f'Resuming at epoch {start_epoch}')

        best_test_adv_acc = torch.load(
            os.path.join(args.fname, f'model_best.pth'))['test_adv_acc']
        if args.val:
            best_val_robust_acc = torch.load(
                os.path.join(args.fname, f'model_val.pth'))['val_robust_acc']
    else:
        start_epoch = 0

    if args.eval:
        if not args.resume:
            logger.info(
                "No model loaded to evaluate, specify with --resume FNAME")
            return
        logger.info("[Evaluation mode]")

    model.cuda()
    model.eval()

    # Evaluate on original test data
    test_acc = 0
    test_n = 0

    for i, batch in enumerate(test_batches):
        X, y = batch['input'], batch['target']

        clean_input = normalize(X)
        output = model(clean_input)

        test_acc += (output.max(1)[1] == y).sum().item()
        test_n += y.size(0)

    logger.info('Initial Accuracy on Original Test Data: %.4f (Test Acc)',
                test_acc / test_n)

    test_adv_acc = 0
    test_adv_n = 0

    for i, batch in enumerate(test_adv_batches):
        adv_input = normalize(batch['input'])
        y = batch['target']

        robust_output = model(adv_input)
        test_adv_acc += (robust_output.max(1)[1] == y).sum().item()
        test_adv_n += y.size(0)

    logger.info(
        'Initial Accuracy on Adversarial Test Data: %.4f (Test Robust Acc)',
        test_adv_acc / test_adv_n)

    model.train()

    logger.info(
        'Epoch \t Train Acc \t Train Robust Acc \t Test Acc \t Test Robust Acc'
    )

    # Records per epoch for savetxt
    train_loss_record = []
    train_acc_record = []
    train_robust_loss_record = []
    train_robust_acc_record = []
    train_grad_record = []

    test_loss_record = []
    test_acc_record = []
    test_adv_loss_record = []
    test_adv_acc_record = []
    test_grad_record = []

    for epoch in range(start_epoch, epochs):
        model.train()
        start_time = time.time()

        train_loss = 0
        train_acc = 0
        train_robust_loss = 0
        train_robust_acc = 0
        train_n = 0
        train_grad = 0

        record_iter = torch.tensor([])

        for i, (batch,
                adv_batch) in enumerate(zip(train_batches, train_adv_batches)):
            if args.eval:
                break
            X, y = batch['input'], batch['target']

            adv_input = normalize(adv_batch['input'])
            adv_y = adv_batch['target']
            adv_input.requires_grad = True
            robust_output = model(adv_input)

            # Training losses
            if args.mixup:
                clean_input = normalize(X)
                clean_input.requires_grad = True
                output = model(clean_input)
                robust_loss = mixup_criterion(criterion, robust_output, y_a,
                                              y_b, lam)

            elif args.mixture:
                clean_input = normalize(X)
                clean_input.requires_grad = True
                output = model(clean_input)
                robust_loss = args.mixture_alpha * criterion(
                    robust_output,
                    adv_y) + (1 - args.mixture_alpha) * criterion(output, y)

            else:
                clean_input = normalize(X)
                clean_input.requires_grad = True
                output = model(clean_input)
                if args.focalloss:
                    criterion_nonreduct = nn.CrossEntropyLoss(reduction='none')
                    robust_confidence = F.softmax(robust_output,
                                                  dim=1)[:, adv_y].detach()
                    robust_loss = (criterion_nonreduct(robust_output, adv_y) *
                                   ((1. - robust_confidence)**
                                    args.focallosslambda)).mean()

                elif args.use_DLRloss:
                    beta_ = 0.8 * epoch / args.epochs
                    robust_loss = (1. - beta_) * F.cross_entropy(
                        robust_output, adv_y) + beta_ * dlr_loss(
                            robust_output, adv_y)

                elif args.use_CWloss:
                    beta_ = 0.8 * epoch / args.epochs
                    robust_loss = (1. - beta_) * F.cross_entropy(
                        robust_output, adv_y) + beta_ * CW_loss(
                            robust_output, adv_y)

                elif args.use_FNandWN:
                    #print('use FN and WN with margin')
                    robust_loss = criterion(
                        args.s_FN * robust_output -
                        onehot_target_withmargin_HE, adv_y)

                else:
                    robust_loss = criterion(robust_output, adv_y)

            if args.l1:
                for name, param in model.named_parameters():
                    if 'bn' not in name and 'bias' not in name:
                        robust_loss += args.l1 * param.abs().sum()

            opt.zero_grad()
            robust_loss.backward()
            opt.step()

            clean_input = normalize(X)
            clean_input.requires_grad = True
            output = model(clean_input)
            if args.mixup:
                loss = mixup_criterion(criterion, output, y_a, y_b, lam)
            else:
                loss = criterion(output, y)

            # Get the gradient norm values
            input_grads = torch.autograd.grad(loss,
                                              clean_input,
                                              create_graph=False)[0]
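            # input_grads is d(loss)/d(clean_input); its absolute sum is accumulated
            # below as a rough per-epoch statistic of the input-gradient magnitude.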

            # Record the statistic values
            train_robust_loss += robust_loss.item() * adv_y.size(0)
            train_robust_acc += (robust_output.max(1)[1] == adv_y).sum().item()
            train_loss += loss.item() * y.size(0)
            train_acc += (output.max(1)[1] == y).sum().item()
            train_n += y.size(0)
            train_grad += input_grads.abs().sum()

        train_time = time.time()
        if args.earlystopPGD:
            print('Iter mean: ',
                  record_iter.mean().item(), ' Iter std:  ',
                  record_iter.std().item())

        # Evaluate on test data
        model.eval()
        test_loss = 0
        test_acc = 0
        test_n = 0

        test_adv_loss = 0
        test_adv_acc = 0
        test_adv_n = 0

        for i, batch in enumerate(test_batches):
            X, y = batch['input'], batch['target']

            clean_input = normalize(X)
            output = model(clean_input)
            loss = criterion(output, y)

            test_loss += loss.item() * y.size(0)
            test_acc += (output.max(1)[1] == y).sum().item()
            test_n += y.size(0)

        for i, batch in enumerate(test_adv_batches):
            adv_input = normalize(batch['input'])
            y = batch['target']

            robust_output = model(adv_input)
            robust_loss = criterion(robust_output, y)

            test_adv_loss += robust_loss.item() * y.size(0)
            test_adv_acc += (robust_output.max(1)[1] == y).sum().item()
            test_adv_n += y.size(0)

        test_time = time.time()

        logger.info('%d \t %.4f \t %.4f \t\t %.4f \t %.4f', epoch + 1,
                    train_acc / train_n, train_robust_acc / train_n,
                    test_acc / test_n, test_adv_acc / test_adv_n)

        # Save results
        train_loss_record.append(train_loss / train_n)
        train_acc_record.append(train_acc / train_n)
        train_robust_loss_record.append(train_robust_loss / train_n)
        train_robust_acc_record.append(train_robust_acc / train_n)

        np.savetxt(args.fname + '/train_loss_record.txt',
                   np.array(train_loss_record))
        np.savetxt(args.fname + '/train_acc_record.txt',
                   np.array(train_acc_record))
        np.savetxt(args.fname + '/train_robust_loss_record.txt',
                   np.array(train_robust_loss_record))
        np.savetxt(args.fname + '/train_robust_acc_record.txt',
                   np.array(train_robust_acc_record))

        test_loss_record.append(test_loss / test_n)
        test_acc_record.append(test_acc / test_n)
        test_adv_loss_record.append(test_adv_loss / test_adv_n)
        test_adv_acc_record.append(test_adv_acc / test_adv_n)

        np.savetxt(args.fname + '/test_loss_record.txt',
                   np.array(test_loss_record))
        np.savetxt(args.fname + '/test_acc_record.txt',
                   np.array(test_acc_record))
        np.savetxt(args.fname + '/test_adv_loss_record.txt',
                   np.array(test_adv_loss_record))
        np.savetxt(args.fname + '/test_adv_acc_record.txt',
                   np.array(test_adv_acc_record))

        # save checkpoint
        if epoch > 99 or (epoch +
                          1) % args.chkpt_iters == 0 or epoch + 1 == epochs:
            torch.save(model.state_dict(),
                       os.path.join(args.fname, f'model_{epoch}.pth'))
            torch.save(opt.state_dict(),
                       os.path.join(args.fname, f'opt_{epoch}.pth'))

        # save best
        if test_adv_acc / test_adv_n > best_test_adv_acc:
            torch.save(
                {
                    'state_dict': model.state_dict(),
                    'test_adv_acc': test_adv_acc / test_adv_n,
                    'test_adv_loss': test_adv_loss / test_adv_n,
                    'test_loss': test_loss / test_n,
                    'test_acc': test_acc / test_n,
                }, os.path.join(args.fname, f'model_best.pth'))
            best_test_adv_acc = test_adv_acc / test_adv_n
Example No. 11
                           num_workers=args.n_workers)
    testloader = DataLoader(dataset=testset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.n_workers)

    print('[init cache]')
    partition_size, _ = divmod(len(unlabeledloader.sampler), args.n_partitions)
    assert _ == 0
    cache = Cache(n_entries=partition_size,
                  entry_size=args.sparsity).to(args.output_device)
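    # Note: the assert above requires the unlabeled pool to split evenly into
    # args.n_partitions partitions; the (project-specific) Cache is sized to hold
    # one entry per sample of a partition, each entry of length args.sparsity.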

    print('[init model]')
    model = WideResNet(num_classes=args.n_classes).to(args.output_device)
    model_ema = WideResNet(num_classes=args.n_classes).to(args.output_device)
    model_ema.load_state_dict(model.state_dict())

    print('[init optimizer]')
    optimizer = optim.AdamW(model.parameters(),
                            lr=args.lr,
                            weight_decay=args.weight_decay)

    print('[init criteria]')
    criterion_labeled = CrossEntropyLoss()
    criterion_unlabeled = MatchingLoss()
    criterion_val = nn.CrossEntropyLoss()

    print('[start training]')
    with SummaryWriter(log_dir=args.tensorboard_dir) as tblogger:
        run(labeledloader=labeledloader,
            unlabeledloader=unlabeledloader,
Example No. 12
def main():
    global args, best_prec1
    args = parser.parse_args()
    if args.tensorboard:
        configure("runs/%s" % (args.name))

    # Data loading code
    # normalize = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]],
    #								 std=[x/255.0 for x in [63.0, 62.1, 66.7]])


# 	if args.augment:
# 		transform_train = transforms.Compose([
# 			transforms.Resize(256,256),
# 			transforms.ToTensor(),
# 			transforms.Lambda(lambda x: F.pad(
# 								Variable(x.unsqueeze(0), requires_grad=False, volatile=True),
# 								(4,4,4,4),mode='reflect').data.squeeze()),
# 			transforms.ToPILImage(),
# 			transforms.RandomCrop(32),
# 			transforms.RandomHorizontalFlip(),
# 			transforms.ToTensor(),
# 			normalize,
# 			])
# 	else:
# 		transform_train = transforms.Compose([
# 			transforms.ToTensor(),
# 			normalize,
# 			])

# 	transform_test = transforms.Compose([
# 		transforms.ToTensor(),
# 		normalize
# 		])

#kwargs = {'num_workers': 1, 'pin_memory': True}
#assert(args.dataset == 'cifar10' or args.dataset == 'cifar100')

    train_data_path = "/home/mil/gupta/ifood18/data/training_set/train_set/"
    val_data_path = "/home/mil/gupta/ifood18/data/val_set/"

    train_label = "/home/mil/gupta/ifood18/data/labels/train_info.csv"
    val_label = "/home/mil/gupta/ifood18/data/labels/val_info.csv"

    #transformations = transforms.Compose([transforms.ToTensor()])

    # train_h5 = "/home/mil/gupta/ifood18/data/h5data/train_data.h5py"
    train_h5 = "/home/mil/gupta/ifood18/data/h5data/train_data_partial.h5py"
    train_dataset = H5Dataset(train_h5)
    #val_dataset =  H5Dataset(val_data_path, val_label, transform= None)

    #custom_mnist_from_csv = \
    #    CustomDatasetFromCSV('../data/mnist_in_csv.csv',
    #                         28, 28,
    #                         transformations)

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=64,
                                               shuffle=True,
                                               num_workers=1)

    #val_loader = torch.utils.data.DataLoader(dataset=val_dataset,batch_size=256,num_workers=1,shuffle=True)

    # 	train_labels = pd.read_csv('./data/labels/train_info.csv')

    # 	train_loader = torch.utils.data.DataLoader(
    #         datasets.__dict__[args.dataset.upper()](train_data_path, train=True, download=True,
    #                          transform=transform_train),
    #         batch_size=args.batch_size, shuffle=True, **kwargs)
    #     val_loader = torch.utils.data.DataLoader(
    #         datasets.__dict__[args.dataset.upper()](val_data_path, train=False, transform=transform_test),
    #         batch_size=args.batch_size, shuffle=True, **kwargs)

    # create model
    model = WideResNet(args.layers,
                       211,
                       args.widen_factor,
                       dropRate=args.droprate)

    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    # for training on multiple GPUs.
    # Use CUDA_VISIBLE_DEVICES=0,1 to specify which GPUs to use
    # model = torch.nn.DataParallel(model).cuda()
    model = model.cuda()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                nesterov=args.nesterov,
                                weight_decay=args.weight_decay)

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch + 1)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        #if epoch % args.validate_freq == 0:
        # evaluate on validation set
        prec3 = validate(val_loader, model, criterion, epoch)

        # remember the best precision so far and save checkpoint
        is_best = prec3 > best_prec1
        best_prec1 = max(prec3, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            }, is_best)

    print('Best accuracy: ', best_prec1)
Example No. 13
def main():
    args = get_args()
    if args.awp_gamma <= 0.0:
        args.awp_warmup = np.inf

    if not os.path.exists(args.fname):
        os.makedirs(args.fname)

    logger = logging.getLogger(__name__)
    logging.basicConfig(
        format='[%(asctime)s] - %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S',
        level=logging.DEBUG,
        handlers=[
            logging.FileHandler(os.path.join(args.fname, 'eval.log' if args.eval else 'output.log')),
            logging.StreamHandler()
        ])

    logger.info(args)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    transforms = [Crop(32, 32), FlipLR()]
    if args.cutout:
        transforms.append(Cutout(args.cutout_len, args.cutout_len))
    if args.val:
        try:
            dataset = torch.load("cifar10_validation_split.pth")
        except:
            print("Couldn't find a dataset with a validation split, did you run "
                  "generate_validation.py?")
            return
        val_set = list(zip(transpose(dataset['val']['data']/255.), dataset['val']['labels']))
        val_batches = Batches(val_set, args.batch_size, shuffle=False, num_workers=2)
    else:
        dataset = cifar10(args.data_dir)
    train_set = list(zip(transpose(pad(dataset['train']['data'], 4)/255.),
        dataset['train']['labels']))
    train_set_x = Transform(train_set, transforms)
    train_batches = Batches(train_set_x, args.batch_size, shuffle=True, set_random_choices=True, num_workers=2)

    test_set = list(zip(transpose(dataset['test']['data']/255.), dataset['test']['labels']))
    test_batches = Batches(test_set, args.batch_size_test, shuffle=False, num_workers=2)

    epsilon = (args.epsilon / 255.)
    pgd_alpha = (args.pgd_alpha / 255.)

    if args.model == 'PreActResNet18':
        model = PreActResNet18()
        proxy = PreActResNet18()
    elif args.model == 'WideResNet':
        model = WideResNet(34, 10, widen_factor=args.width_factor, dropRate=0.0)
        proxy = WideResNet(34, 10, widen_factor=args.width_factor, dropRate=0.0)
    else:
        raise ValueError("Unknown model")

    model = nn.DataParallel(model).cuda()
    proxy = nn.DataParallel(proxy).cuda()

    if args.l2:
        decay, no_decay = [], []
        for name,param in model.named_parameters():
            if 'bn' not in name and 'bias' not in name:
                decay.append(param)
            else:
                no_decay.append(param)
        params = [{'params':decay, 'weight_decay':args.l2},
                  {'params':no_decay, 'weight_decay': 0 }]
    else:
        params = model.parameters()

    opt = torch.optim.SGD(params, lr=args.lr_max, momentum=0.9, weight_decay=5e-4)
    proxy_opt = torch.optim.SGD(proxy.parameters(), lr=0.01)
    awp_adversary = AdvWeightPerturb(model=model, proxy=proxy, proxy_optim=proxy_opt, gamma=args.awp_gamma)
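    # AdvWeightPerturb is used in the training loop below as:
    #   awp = awp_adversary.calc_awp(inputs_adv=X_adv, targets=y)  # compute weight perturbation
    #   awp_adversary.perturb(awp)      # apply it before the robust backward pass
    #   ...robust_loss.backward(); opt.step()...
    #   awp_adversary.restore(awp)      # undo it after the parameter update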

    criterion = nn.CrossEntropyLoss()

    if args.attack == 'free':
        delta = torch.zeros(args.batch_size, 3, 32, 32).cuda()
        delta.requires_grad = True
    elif args.attack == 'fgsm' and args.fgsm_init == 'previous':
        delta = torch.zeros(args.batch_size, 3, 32, 32).cuda()
        delta.requires_grad = True

    if args.attack == 'free':
        epochs = int(math.ceil(args.epochs / args.attack_iters))
    else:
        epochs = args.epochs

    if args.lr_schedule == 'superconverge':
        lr_schedule = lambda t: np.interp([t], [0, args.epochs * 2 // 5, args.epochs], [0, args.lr_max, 0])[0]
    elif args.lr_schedule == 'piecewise':
        def lr_schedule(t):
            if t / args.epochs < 0.5:
                return args.lr_max
            elif t / args.epochs < 0.75:
                return args.lr_max / 10.
            else:
                return args.lr_max / 100.
    elif args.lr_schedule == 'linear':
        lr_schedule = lambda t: np.interp([t], [0, args.epochs // 3, args.epochs * 2 // 3, args.epochs], [args.lr_max, args.lr_max, args.lr_max / 10, args.lr_max / 100])[0]
    elif args.lr_schedule == 'onedrop':
        def lr_schedule(t):
            if t < args.lr_drop_epoch:
                return args.lr_max
            else:
                return args.lr_one_drop
    elif args.lr_schedule == 'multipledecay':
        def lr_schedule(t):
            return args.lr_max - (t//(args.epochs//10))*(args.lr_max/10)
    elif args.lr_schedule == 'cosine':
        def lr_schedule(t):
            return args.lr_max * 0.5 * (1 + np.cos(t / args.epochs * np.pi))
    elif args.lr_schedule == 'cyclic':
        lr_schedule = lambda t: np.interp([t], [0, 0.4 * args.epochs, args.epochs], [0, args.lr_max, 0])[0]

    best_test_robust_acc = 0
    best_val_robust_acc = 0
    if args.resume:
        start_epoch = args.resume
        model.load_state_dict(torch.load(os.path.join(args.fname, f'model_{start_epoch-1}.pth')))
        opt.load_state_dict(torch.load(os.path.join(args.fname, f'opt_{start_epoch-1}.pth')))
        logger.info(f'Resuming at epoch {start_epoch}')

        if os.path.exists(os.path.join(args.fname, f'model_best.pth')):
            best_test_robust_acc = torch.load(os.path.join(args.fname, f'model_best.pth'))['test_robust_acc']
        if args.val:
            best_val_robust_acc = torch.load(os.path.join(args.fname, f'model_val.pth'))['val_robust_acc']
    else:
        start_epoch = 0

    if args.eval:
        if not args.resume:
            logger.info("No model loaded to evaluate, specify with --resume FNAME")
            return
        logger.info("[Evaluation mode]")

    logger.info('Epoch \t Train Time \t Test Time \t LR \t \t Train Loss \t Train Acc \t Train Robust Loss \t Train Robust Acc \t Test Loss \t Test Acc \t Test Robust Loss \t Test Robust Acc')
    for epoch in range(start_epoch, epochs):
        start_time = time.time()
        train_loss = 0
        train_acc = 0
        train_robust_loss = 0
        train_robust_acc = 0
        train_n = 0
        for i, batch in enumerate(train_batches):
            if args.eval:
                break
            X, y = batch['input'], batch['target']
            if args.mixup:
                X, y_a, y_b, lam = mixup_data(X, y, args.mixup_alpha)
                X, y_a, y_b = map(Variable, (X, y_a, y_b))
            lr = lr_schedule(epoch + (i + 1) / len(train_batches))
            opt.param_groups[0].update(lr=lr)

            if args.attack == 'pgd':
                # Random initialization
                if args.mixup:
                    delta = attack_pgd(model, X, y, epsilon, pgd_alpha, args.attack_iters, args.restarts, args.norm, mixup=True, y_a=y_a, y_b=y_b, lam=lam)
                else:
                    delta = attack_pgd(model, X, y, epsilon, pgd_alpha, args.attack_iters, args.restarts, args.norm)
                delta = delta.detach()
            elif args.attack == 'fgsm':
                delta = attack_pgd(model, X, y, epsilon, args.fgsm_alpha*epsilon, 1, 1, args.norm)
            # Standard training
            elif args.attack == 'none':
                delta = torch.zeros_like(X)
            X_adv = normalize(torch.clamp(X + delta[:X.size(0)], min=lower_limit, max=upper_limit))
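            # X + delta is clamped back into the valid pixel range [lower_limit, upper_limit]
            # (defined elsewhere in this script) before the dataset normalization is applied.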

            model.train()
            # calculate adversarial weight perturbation and perturb it
            if epoch >= args.awp_warmup:
                # not compatible with mixup currently.
                assert (not args.mixup)
                awp = awp_adversary.calc_awp(inputs_adv=X_adv,
                                             targets=y)
                awp_adversary.perturb(awp)

            robust_output = model(X_adv)
            if args.mixup:
                robust_loss = mixup_criterion(criterion, robust_output, y_a, y_b, lam)
            else:
                robust_loss = criterion(robust_output, y)

            if args.l1:
                for name,param in model.named_parameters():
                    if 'bn' not in name and 'bias' not in name:
                        robust_loss += args.l1*param.abs().sum()

            opt.zero_grad()
            robust_loss.backward()
            opt.step()

            if epoch >= args.awp_warmup:
                awp_adversary.restore(awp)

            output = model(normalize(X))
            if args.mixup:
                loss = mixup_criterion(criterion, output, y_a, y_b, lam)
            else:
                loss = criterion(output, y)

            train_robust_loss += robust_loss.item() * y.size(0)
            train_robust_acc += (robust_output.max(1)[1] == y).sum().item()
            train_loss += loss.item() * y.size(0)
            train_acc += (output.max(1)[1] == y).sum().item()
            train_n += y.size(0)

        train_time = time.time()

        model.eval()
        test_loss = 0
        test_acc = 0
        test_robust_loss = 0
        test_robust_acc = 0
        test_n = 0
        for i, batch in enumerate(test_batches):
            X, y = batch['input'], batch['target']

            # Random initialization
            if args.attack == 'none':
                delta = torch.zeros_like(X)
            else:
                delta = attack_pgd(model, X, y, epsilon, pgd_alpha, args.attack_iters_test, args.restarts, args.norm, early_stop=args.eval)
            delta = delta.detach()

            robust_output = model(normalize(torch.clamp(X + delta[:X.size(0)], min=lower_limit, max=upper_limit)))
            robust_loss = criterion(robust_output, y)

            output = model(normalize(X))
            loss = criterion(output, y)

            test_robust_loss += robust_loss.item() * y.size(0)
            test_robust_acc += (robust_output.max(1)[1] == y).sum().item()
            test_loss += loss.item() * y.size(0)
            test_acc += (output.max(1)[1] == y).sum().item()
            test_n += y.size(0)

        test_time = time.time()

        if args.val:
            val_loss = 0
            val_acc = 0
            val_robust_loss = 0
            val_robust_acc = 0
            val_n = 0
            for i, batch in enumerate(val_batches):
                X, y = batch['input'], batch['target']

                # Random initialization
                if args.attack == 'none':
                    delta = torch.zeros_like(X)
                else:
                    delta = attack_pgd(model, X, y, epsilon, pgd_alpha, args.attack_iters_test, args.restarts, args.norm, early_stop=args.eval)
                delta = delta.detach()

                robust_output = model(normalize(torch.clamp(X + delta[:X.size(0)], min=lower_limit, max=upper_limit)))
                robust_loss = criterion(robust_output, y)

                output = model(normalize(X))
                loss = criterion(output, y)

                val_robust_loss += robust_loss.item() * y.size(0)
                val_robust_acc += (robust_output.max(1)[1] == y).sum().item()
                val_loss += loss.item() * y.size(0)
                val_acc += (output.max(1)[1] == y).sum().item()
                val_n += y.size(0)

        if not args.eval:
            logger.info('%d \t %.1f \t \t %.1f \t \t %.4f \t %.4f \t %.4f \t %.4f \t \t %.4f \t \t %.4f \t %.4f \t %.4f \t \t %.4f',
                epoch, train_time - start_time, test_time - train_time, lr,
                train_loss/train_n, train_acc/train_n, train_robust_loss/train_n, train_robust_acc/train_n,
                test_loss/test_n, test_acc/test_n, test_robust_loss/test_n, test_robust_acc/test_n)

            if args.val:
                logger.info('validation %.4f \t %.4f \t %.4f \t %.4f',
                    val_loss/val_n, val_acc/val_n, val_robust_loss/val_n, val_robust_acc/val_n)

                if val_robust_acc/val_n > best_val_robust_acc:
                    torch.save({
                            'state_dict':model.state_dict(),
                            'test_robust_acc':test_robust_acc/test_n,
                            'test_robust_loss':test_robust_loss/test_n,
                            'test_loss':test_loss/test_n,
                            'test_acc':test_acc/test_n,
                            'val_robust_acc':val_robust_acc/val_n,
                            'val_robust_loss':val_robust_loss/val_n,
                            'val_loss':val_loss/val_n,
                            'val_acc':val_acc/val_n,
                        }, os.path.join(args.fname, f'model_val.pth'))
                    best_val_robust_acc = val_robust_acc/val_n

            # save checkpoint
            if (epoch+1) % args.chkpt_iters == 0 or epoch+1 == epochs:
                torch.save(model.state_dict(), os.path.join(args.fname, f'model_{epoch}.pth'))
                torch.save(opt.state_dict(), os.path.join(args.fname, f'opt_{epoch}.pth'))

            # save best
            if test_robust_acc/test_n > best_test_robust_acc:
                torch.save({
                        'state_dict':model.state_dict(),
                        'test_robust_acc':test_robust_acc/test_n,
                        'test_robust_loss':test_robust_loss/test_n,
                        'test_loss':test_loss/test_n,
                        'test_acc':test_acc/test_n,
                    }, os.path.join(args.fname, f'model_best.pth'))
                best_test_robust_acc = test_robust_acc/test_n
        else:
            logger.info('%d \t %.1f \t \t %.1f \t \t %.4f \t %.4f \t %.4f \t %.4f \t \t %.4f \t \t %.4f \t %.4f \t %.4f \t \t %.4f',
                epoch, train_time - start_time, test_time - train_time, -1,
                -1, -1, -1, -1,
                test_loss/test_n, test_acc/test_n, test_robust_loss/test_n, test_robust_acc/test_n)
            return
Example No. 14
def main():
    args = get_args()

    if not os.path.exists(args.fname):
        os.makedirs(args.fname)

    logger = logging.getLogger(__name__)
    logging.basicConfig(
        format='[%(asctime)s] - %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S',
        level=logging.DEBUG,
        handlers=[
            logging.FileHandler(os.path.join(args.fname, 'eval.log' if args.eval else 'output.log')),
            logging.StreamHandler()
        ])

    logger.info(args)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    train_transform = transforms.Compose([
        transforms.ToTensor(),
        # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    num_workers = 2
    train_dataset = datasets.SVHN(
        args.data_dir, split='train', transform=train_transform, download=True)
    test_dataset = datasets.SVHN(
        args.data_dir, split='test', transform=test_transform, download=True)
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=num_workers,
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        pin_memory=True,
        num_workers=2,
    )

    epsilon = (args.epsilon / 255.)
    pgd_alpha = (args.pgd_alpha / 255.)

    # model = models_dict[args.architecture]().cuda()
    # model.apply(initialize_weights)
    if args.model == 'PreActResNet18':
        model = PreActResNet18()
    elif args.model == 'WideResNet':
        model = WideResNet(34, 10, widen_factor=args.width_factor, dropRate=0.0)
    else:
        raise ValueError("Unknown model")

    model = model.cuda()
    model.train()

    if args.l2:
        decay, no_decay = [], []
        for name,param in model.named_parameters():
            if 'bn' not in name and 'bias' not in name:
                decay.append(param)
            else:
                no_decay.append(param)
        params = [{'params':decay, 'weight_decay':args.l2},
                  {'params':no_decay, 'weight_decay': 0 }]
    else:
        params = model.parameters()

    opt = torch.optim.SGD(params, lr=args.lr_max, momentum=0.9, weight_decay=5e-4)

    criterion = nn.CrossEntropyLoss()

    epochs = args.epochs

    if args.lr_schedule == 'superconverge':
        lr_schedule = lambda t: np.interp([t], [0, args.epochs * 2 // 5, args.epochs], [0, args.lr_max, 0])[0]
        # lr_schedule = lambda t: np.interp([t], [0, args.epochs], [0, args.lr_max])[0]
    elif args.lr_schedule == 'piecewise':
        def lr_schedule(t):
            if t / args.epochs < 0.5:
                return args.lr_max
            elif t / args.epochs < 0.75:
                return args.lr_max / 10.
            else:
                return args.lr_max / 100.

    if args.resume:
        start_epoch = args.resume
        model.load_state_dict(torch.load(os.path.join(args.fname, f'model_{start_epoch-1}.pth')))
        opt.load_state_dict(torch.load(os.path.join(args.fname, f'opt_{start_epoch-1}.pth')))
        logger.info(f'Resuming at epoch {start_epoch}')
    else:
        start_epoch = 0


    if args.eval:
        if not args.resume:
            logger.info("No model loaded to evaluate, specify with --resume FNAME")
            return
        logger.info("[Evaluation mode]")

    logger.info('Epoch \t Train Time \t Test Time \t LR \t \t Train Loss \t Train Acc \t Train Robust Loss \t Train Robust Acc \t Test Loss \t Test Acc \t Test Robust Loss \t Test Robust Acc')
    for epoch in range(start_epoch, epochs):
        model.train()
        start_time = time.time()
        train_loss = 0
        train_acc = 0
        train_robust_loss = 0
        train_robust_acc = 0
        train_n = 0
        for i, (X, y) in enumerate(train_loader):
            if args.eval:
                break
            X, y = X.cuda(), y.cuda()
            if args.mixup:
                X, y_a, y_b, lam = mixup_data(X, y, args.mixup_alpha)
                X, y_a, y_b = map(Variable, (X, y_a, y_b))
            lr = lr_schedule(epoch + (i + 1) / len(train_loader))
            opt.param_groups[0].update(lr=lr)

            if args.attack == 'pgd':
                # Random initialization
                if args.mixup:
                    delta = attack_pgd(model, X, y, epsilon, pgd_alpha, args.attack_iters, args.restarts, args.norm, mixup=True, y_a=y_a, y_b=y_b, lam=lam)
                else:
                    delta = attack_pgd(model, X, y, epsilon, pgd_alpha, args.attack_iters, args.restarts, args.norm)
                delta = delta.detach()

            # Standard training
            elif args.attack == 'none':
                delta = torch.zeros_like(X)

            robust_output = model(normalize(torch.clamp(X + delta[:X.size(0)], min=lower_limit, max=upper_limit)))
            if args.mixup:
                robust_loss = mixup_criterion(criterion, robust_output, y_a, y_b, lam)
            else:
                robust_loss = criterion(robust_output, y)

            if args.l1:
                for name,param in model.named_parameters():
                    if 'bn' not in name and 'bias' not in name:
                        robust_loss += args.l1*param.abs().sum()

            opt.zero_grad()
            robust_loss.backward()
            opt.step()

            output = model(normalize(X))
            if args.mixup:
                loss = mixup_criterion(criterion, output, y_a, y_b, lam)
            else:
                loss = criterion(output, y)

            train_robust_loss += robust_loss.item() * y.size(0)
            train_robust_acc += (robust_output.max(1)[1] == y).sum().item()
            train_loss += loss.item() * y.size(0)
            train_acc += (output.max(1)[1] == y).sum().item()
            train_n += y.size(0)

        train_time = time.time()

        model.eval()
        test_loss = 0
        test_acc = 0
        test_robust_loss = 0
        test_robust_acc = 0
        test_n = 0
        for i, (X, y) in enumerate(test_loader):
            X, y = X.cuda(), y.cuda()

            # Random initialization
            if args.attack == 'none':
                delta = torch.zeros_like(X)
            else:
                delta = attack_pgd(model, X, y, epsilon, pgd_alpha, args.attack_iters, args.restarts, args.norm, early_stop=args.eval)
            delta = delta.detach()

            robust_output = model(normalize(torch.clamp(X + delta[:X.size(0)], min=lower_limit, max=upper_limit)))
            robust_loss = criterion(robust_output, y)

            output = model(normalize(X))
            loss = criterion(output, y)

            test_robust_loss += robust_loss.item() * y.size(0)
            test_robust_acc += (robust_output.max(1)[1] == y).sum().item()
            test_loss += loss.item() * y.size(0)
            test_acc += (output.max(1)[1] == y).sum().item()
            test_n += y.size(0)

        test_time = time.time()
        if not args.eval: 
            logger.info('%d \t %.1f \t \t %.1f \t \t %.4f \t %.4f \t %.4f \t %.4f \t \t %.4f \t \t %.4f \t %.4f \t %.4f \t \t %.4f',
                epoch, train_time - start_time, test_time - train_time, lr,
                train_loss/train_n, train_acc/train_n, train_robust_loss/train_n, train_robust_acc/train_n,
                test_loss/test_n, test_acc/test_n, test_robust_loss/test_n, test_robust_acc/test_n)
            
            if (epoch+1) % args.chkpt_iters == 0 or epoch+1 == epochs:
                torch.save(model.state_dict(), os.path.join(args.fname, f'model_{epoch}.pth'))
                torch.save(opt.state_dict(), os.path.join(args.fname, f'opt_{epoch}.pth'))
        else:
            logger.info('%d \t %.1f \t \t %.1f \t \t %.4f \t %.4f \t %.4f \t %.4f \t \t %.4f \t \t %.4f \t %.4f \t %.4f \t \t %.4f',
                epoch, train_time - start_time, test_time - train_time, -1,
                -1, -1, -1, -1,
                test_loss/test_n, test_acc/test_n, test_robust_loss/test_n, test_robust_acc/test_n)
            return
Example No. 15
    def train_fusionWRN6(
        self,
        epochs1=120,
        epochs2=3,
        device="cuda:0"
    ):  # https://github.com/xternalz/WideResNet-pytorch.git #120 80
        sys.path.append('/media/rene/code/WideResNet-pytorch')
        from wideresnet import WideResNet

        epochs1, epochs2 = int(epochs1), int(epochs2)
        num_workers = 4

        PATH = Path('/media/rene/data/')
        save_path = Path('/media/rene/code/WideResNet-pytorch/runs')
        model_name_list = [
            'WideResNet-28-10_0/model_best.pth.tar',
            'WideResNet-28-10_1/model_best.pth.tar',
            'WideResNet-28-10_2/model_best.pth.tar',
            'WideResNet-28-10_3/model_best.pth.tar',
            'WideResNet-28-10_4/model_best.pth.tar',
            'WideResNet-28-10_5/model_best.pth.tar'
        ]
        batch_size = 8

        dataloaders, dataset_sizes = make_batch_gen_cifar(str(PATH),
                                                          batch_size,
                                                          num_workers,
                                                          valid_name='valid')

        pretrained_model_list = []
        # First trained model was with DATA PARALLEL
        model = WideResNet(28, 10, 20)
        model = model.to(device)
        state_dict = torch.load(
            os.path.join(
                save_path,
                'WideResNet-28-10_0/model_best.pth.tar'))['state_dict']
        # create new OrderedDict that does not contain `module.`
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k[7:]  # remove `module.`
            new_state_dict[name] = v
        model.load_state_dict(new_state_dict)

        pretrained_model_list.append(model)

        # get all the models
        for i, model_name in enumerate(model_name_list[1:]):
            print('------------loading model: ', model_name)
            model = WideResNet(28, 10, 20)
            model = model.to(device)

            # original saved file with DataParallel
            state_dict = torch.load(os.path.join(save_path,
                                                 model_name))['state_dict']
            model.load_state_dict(state_dict)
            pretrained_model_list.append(model)

        model = Fusion6(pretrained_model_list, num_input=60, num_output=10)

        ######################  TRAIN LAST FEW LAYERS
        # print('training last few layers')

        model_name = 'Fusion6_WRN_1'
        for p in model.parameters():
            p.requires_grad = True

        for p in model.model1.parameters():
            p.requires_grad = False
        for p in model.model2.parameters():
            p.requires_grad = False
        for p in model.model3.parameters():
            p.requires_grad = False
        for p in model.model4.parameters():
            p.requires_grad = False
        for p in model.model5.parameters():
            p.requires_grad = False
        for p in model.model6.parameters():
            p.requires_grad = False

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                     model.parameters()),
                              lr=.005,
                              momentum=0.9,
                              weight_decay=5e-4)
        scheduler = lr_scheduler.StepLR(optimizer,
                                        step_size=int(epochs1 / 3),
                                        gamma=0.1)

        best_acc, model = train_model(model,
                                      criterion,
                                      optimizer,
                                      scheduler,
                                      epochs1,
                                      dataloaders,
                                      dataset_sizes,
                                      device=device)
        torch.save(model.state_dict(), str(save_path / model_name))

        ########################   TRAIN ALL LAYERS

        # net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
        # model.load_state_dict(torch.load(save_path / 'Fusion2_WRN_1'))

        model_name = 'Fusion6_WRN_2'
        batch_size = 1
        print('---------', batch_size)
        dataloaders, dataset_sizes = make_batch_gen_cifar(str(PATH),
                                                          batch_size,
                                                          num_workers,
                                                          valid_name='valid')

        for p in model.parameters():
            p.requires_grad = True

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                     model.parameters()),
                              lr=.0001,
                              momentum=0.9,
                              weight_decay=5e-4)
        scheduler = lr_scheduler.StepLR(optimizer,
                                        step_size=int(epochs2 / 3),
                                        gamma=0.1)

        best_acc, model = train_model(model,
                                      criterion,
                                      optimizer,
                                      scheduler,
                                      epochs2,
                                      dataloaders,
                                      dataset_sizes,
                                      device=device)

        torch.save(model.state_dict(), str(save_path / model_name))
Example No. 16
def main():
    global args, best_prec1
    args = parser.parse_args()
    if args.tensorboard: configure(args.checkpoint_dir+"/%s"%(args.name))

    # Data loading code
    normalize = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]],
                                     std=[x/255.0 for x in [63.0, 62.1, 66.7]])

    if args.augment:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: F.pad(x.unsqueeze(0),
                                              (4, 4, 4, 4), mode='reflect').squeeze()),
            transforms.ToPILImage(),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
            ])
    else:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            normalize,
            ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize
        ])

    kwargs = {'num_workers': 1, 'pin_memory': True}
    assert(args.dataset == 'cifar10' or args.dataset == 'cifar100')
    train_loader = torch.utils.data.DataLoader(
        datasets.__dict__[args.dataset.upper()](args.dataset_path, train=True, download=True,
                         transform=transform_train),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    val_loader = torch.utils.data.DataLoader(
        datasets.__dict__[args.dataset.upper()](args.dataset_path, train=False, transform=transform_test),
        batch_size=args.batch_size, shuffle=True, **kwargs)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # create model
    model = WideResNet(args.layers, args.dataset == 'cifar10' and 10 or 100,
                            args.widen_factor, dropRate=args.droprate,
                            semantic_loss=args.sloss, device=device)

    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    # for training on multiple GPUs.
    # Use CUDA_VISIBLE_DEVICES=0,1 to specify which GPUs to use
    # model = torch.nn.DataParallel(model).cuda()
    model = model.to(device)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location=torch.device('cpu'))
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum, nesterov=args.nesterov,
                                weight_decay=args.weight_decay)

    # cosine learning rate
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, len(train_loader)*args.epochs, eta_min=1e-6)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=40, gamma=.2)
    calc_logic, logic_net, examples, logic_optimizer, decoder_optimizer, logic_scheduler, decoder_scheduler = None,None,None,None,None,None,None

    if args.dataset == "cifar100":
        examples, logic_fn, group_precision = get_cifar100_experiment_params(train_loader.dataset)
        assert logic_fn(torch.arange(100), examples).all()
    else:
        examples, logic_fn, group_precision = get_cifar10_experiment_params(train_loader.dataset)
        assert logic_fn(torch.arange(10), examples).all()

    if args.sloss:

        examples = examples.to(device)

        logic_net = LogicNet(num_classes=len(train_loader.dataset.classes))
        logic_net.to(device)

        # logic_optimizer = torch.optim.Adam(logic_net.parameters(), 1e-1*args.lr)
        logic_optimizer = torch.optim.SGD(logic_net.parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    nesterov=args.nesterov,
                                    weight_decay=args.weight_decay)
        # logic_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(logic_optimizer, len(train_loader) * args.epochs, eta_min=1e-8)
        logic_scheduler = torch.optim.lr_scheduler.StepLR(logic_optimizer, step_size=25, gamma=.2)

        # decoder_optimizer = torch.optim.Adam(model.global_paramters, args.lr)
        decoder_optimizer = torch.optim.SGD(model.global_paramters,
                                          args.lr,
                                          momentum=args.momentum,
                                          nesterov=args.nesterov,
                                          weight_decay=args.weight_decay)
        # decoder_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(decoder_optimizer, len(train_loader) * args.epochs, eta_min=1e-8)
        decoder_scheduler = torch.optim.lr_scheduler.StepLR(decoder_optimizer, step_size=25, gamma=.2)

        calc_logic = lambda predictions, targets: calc_logic_loss(
            predictions, targets, logic_net, logic_fn,
            num_classes=model.num_classes, device=device)
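        # calc_logic_loss (defined elsewhere in this project) is wrapped in a
        # closure so train() only needs (predictions, targets); logic_net,
        # logic_fn and the class count are captured here.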

        # override the optimizer from above
        optimizer = torch.optim.SGD(model.local_parameters,  # TODO: plain model.parameters() might still work better
                                    args.lr,
                                    momentum=args.momentum,
                                    nesterov=args.nesterov,
                                    weight_decay=args.weight_decay)
        # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, len(train_loader) * args.epochs)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=40, gamma=.2)
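        # With the semantic loss enabled, the parameters are split: `optimizer`
        # now updates only model.local_parameters, while decoder_optimizer
        # (above) updates the global/decoder parameters, each on its own
        # StepLR schedule.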

    name = "_".join([str(getattr(args, source)) for source in ['lr', 'sloss', 'sloss_weight', 'dataset']])

    if args.resume:
        targets, preds, outs = validate(val_loader, model, criterion, 1, args, group_precision, device=device)
        from sklearn.metrics import confusion_matrix
        import pickle
        print(confusion_matrix(targets, preds))
        print(group_precision(torch.tensor(targets), torch.tensor(np.concatenate(outs, axis=0))))
        dict_ = {"targets": targets, "pred": np.concatenate(outs, axis=0)}
        with open('../semantic_loss/notebooks/results.pickle', 'wb') as f:
            pickle.dump(dict_, f)
        import pdb
        pdb.set_trace()

    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train(train_loader, model, logic_net,
              criterion, examples,
              optimizer, logic_optimizer, decoder_optimizer,
              scheduler, logic_scheduler, decoder_scheduler,
              epoch, args, calc_logic, device=device)
        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, epoch, args, group_precision, device=device)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
        }, is_best, filename=f"{name}.checkpoint.pt")

        if args.sloss:
            logic_scheduler.step()
            decoder_scheduler.step()
        scheduler.step()

    print('Best accuracy: ', best_prec1)
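
# `save_checkpoint` is not shown in this example. A minimal sketch consistent
# with how it is called above (a state dict, an `is_best` flag and a filename)
# might look like the following; the name used for the "best" copy is an
# assumption.
import shutil

import torch


def save_checkpoint(state, is_best, filename='checkpoint.pt'):
    # Always persist the latest state; keep a separate copy of the best model
    # seen so far.
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pt')  # destination name assumed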
Example No. 17
0
    def train_fusionWRN_last3(
        self,
        epochs1=40,
        epochs2=25,
        device="cuda:1"
    ):  # https://github.com/xternalz/WideResNet-pytorch.git #120 80
        with torch.cuda.device(1):
            sys.path.append('/media/rene/code/WideResNet-pytorch')
            from wideresnet import WideResNet

            epochs1, epochs2 = int(epochs1), int(epochs2)
            num_workers = 4

            PATH = Path('/media/rene/data/')
            save_path = Path('/media/rene/code/WideResNet-pytorch/runs')
            model_name_list = [
                'WideResNet-28-10_0/model_best.pth.tar',
                'WideResNet-28-10_1/model_best.pth.tar',
                'WideResNet-28-10_2/model_best.pth.tar',
                'WideResNet-28-10_3/model_best.pth.tar',
                'WideResNet-28-10_4/model_best.pth.tar',
                'WideResNet-28-10_5/model_best.pth.tar'
            ]
            batch_size = 300

            transform_test = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                    std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
            ])
            dataloaders, dataset_sizes = make_batch_gen_cifar(
                str(PATH),
                batch_size,
                num_workers,
                valid_name='valid',
                transformation=transform_test)

            pretrained_model_list = []
            # First trained model was with DATA PARALLEL
            model = WideResNet(28, 10, 20)
            model = model.to(device)

            state_dict = torch.load(
                os.path.join(
                    save_path,
                    'WideResNet-28-10_0/model_best.pth.tar'))['state_dict']

            # create new OrderedDict that does not contain `module.`
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                name = k[7:]  # remove `module.`
                new_state_dict[name] = v
            model.load_state_dict(new_state_dict)
            pretrained_model_list.append(model)

            # get all the models
            for i, model_name in enumerate(model_name_list[1:3]):
                print('------------loading model: ', model_name)
                model = WideResNet(28, 10, 20)
                model = model.to(device)

                # original saved file with DataParallel
                state_dict = torch.load(os.path.join(save_path,
                                                     model_name))['state_dict']
                model.load_state_dict(state_dict)
                pretrained_model_list.append(model)

            model = Fusion3(pretrained_model_list, num_input=30, num_output=10)
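            # Fusion3 (project helper, defined elsewhere) stacks the three
            # frozen WideResNets; num_input=30 is presumably 3 models x 10
            # class logits, mapped down to num_output=10 classes.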

            ######################  TRAIN LAST FEW LAYERS
            print('training last few layers')

            model_name = 'fusionWRN_last3_1'
            for p in model.parameters():
                p.requires_grad = True
            for submodel in (model.model1, model.model2, model.model3):
                for p in submodel.parameters():
                    p.requires_grad = False

            # criterion = nn.CrossEntropyLoss()
            # optimizer = optim.SGD(filter(lambda p: p.requires_grad,model.parameters()), lr=.005, momentum=0.9, weight_decay=5e-4)
            # scheduler = lr_scheduler.StepLR(optimizer, step_size=int(epochs1/3), gamma=0.3)

            # best_acc, model = train_model(model, criterion, optimizer, scheduler, epochs1,
            #                            dataloaders, dataset_sizes, device=device)
            # torch.save(model.state_dict(), str(save_path / model_name))

            ########################   TRAIN ALL LAYERS
            model.load_state_dict(torch.load(save_path / 'fusionWRN_last3_1'))
            model = model.to(device)
            model_name = 'fusionWRN_last3_2'

            batch_size = 88
            dataloaders, dataset_sizes = make_batch_gen_cifar(
                str(PATH),
                batch_size,
                num_workers,
                valid_name='valid',
                transformation=transform_test)

            ### ONLY THE LAST BLOCK:
            for submodel in (model.model1, model.model2, model.model3):
                for i, child in enumerate(submodel.children()):
                    if i >= 3:
                        for p in child.parameters():
                            p.requires_grad = True

            criterion = nn.CrossEntropyLoss()
            optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                         model.parameters()),
                                  lr=.0001,
                                  momentum=0.9,
                                  weight_decay=5e-4)
            scheduler = lr_scheduler.StepLR(optimizer,
                                            step_size=int(epochs2 / 2),
                                            gamma=0.1)
            best_acc, model = train_model(model,
                                          criterion,
                                          optimizer,
                                          scheduler,
                                          2,
                                          dataloaders,
                                          dataset_sizes,
                                          device=device,
                                          multi_gpu=False)

            optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                         model.parameters()),
                                  lr=.001,
                                  momentum=0.9,
                                  weight_decay=5e-4)
            scheduler = lr_scheduler.StepLR(optimizer,
                                            step_size=int(epochs2 / 2),
                                            gamma=0.2)

            best_acc, model = train_model(model,
                                          criterion,
                                          optimizer,
                                          scheduler,
                                          epochs2,
                                          dataloaders,
                                          dataset_sizes,
                                          device=device,
                                          multi_gpu=False)

            torch.save(model.state_dict(), str(save_path / model_name))
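
# A minimal sketch of what a Fusion3-style module could look like, assuming it
# concatenates the three base networks' class logits (3 x 10 = 30) and learns a
# linear head on top; the project's actual Fusion3 may differ.
import torch
import torch.nn as nn


class Fusion3(nn.Module):
    # Hypothetical sketch, not the project's actual implementation.
    def __init__(self, models, num_input=30, num_output=10):
        super().__init__()
        self.model1, self.model2, self.model3 = models
        self.fc = nn.Linear(num_input, num_output)

    def forward(self, x):
        # Concatenate per-model class scores and map them to the final classes.
        out = torch.cat([self.model1(x), self.model2(x), self.model3(x)], dim=1)
        return self.fc(out)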
Example No. 18
0
def main():
    global args, best_prec1, data
    if args.tensorboard: configure("runs/%s"%(args.name))

    # Data loading code
    normalize = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]],
                                     std=[x/255.0 for x in [63.0, 62.1, 66.7]])

    if args.augment:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: F.pad(
                Variable(x.unsqueeze(0), requires_grad=False, volatile=True),
                (4, 4, 4, 4), mode='reflect').data.squeeze()),
            transforms.ToPILImage(),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    # create model
    print("Creating the model...")
    model = WideResNet(args.layers, args.dataset == 'cifar10' and 10 or 100,
                       args.widen_factor, dropRate=args.droprate)

    rank = MPI.COMM_WORLD.Get_rank()
    args.rank = rank
    args.seed += rank
    _set_seed(args.seed)

    print("Creating the DataLoader...")
    kwargs = {'num_workers': args.num_workers}  # , 'pin_memory': args.use_cuda}
    assert(args.dataset == 'cifar10' or args.dataset == 'cifar100')
    train_loader = torch.utils.data.DataLoader(
        datasets.__dict__[args.dataset.upper()]('../data', train=True,
                                                download=True,
                                                transform=transform_train),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    val_loader = torch.utils.data.DataLoader(
        datasets.__dict__[args.dataset.upper()]('../data', train=False, transform=transform_test),
        batch_size=args.batch_size, shuffle=False, **kwargs)

    # get the number of model parameters
    args.num_parameters = sum([p.data.nelement() for p in model.parameters()])
    print('Number of model parameters: {}'.format(args.num_parameters))

    # for training on multiple GPUs.
    # Use CUDA_VISIBLE_DEVICES=0,1 to specify which GPUs to use
    if args.use_cuda:
        print("Moving the model to the GPU")
        model = torch.nn.DataParallel(model, device_ids=device_ids)
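        # device_ids is assumed to be defined at module scope (the GPU ids made
        # visible via CUDA_VISIBLE_DEVICES); it is not set in this snippet.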
        model = model.cuda()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # define loss function (criterion) and optimizer

    criterion = nn.CrossEntropyLoss()
    if args.use_cuda:
        criterion = criterion.cuda()
    #  optimizer = torch.optim.SGD(model.parameters(), args.lr,
                              #  momentum=args.momentum, nesterov=args.nesterov,
                              #  weight_decay=args.weight_decay)
    #  optimizer = torch.optim.ASGD(model.parameters(), args.lr)
    #  from distributed_opt import MiniBatchSGD
    #  import torch.distributed as dist
    #  rank = np.random.choice('gloo')

    # TODO: pass kwargs to code
    print('initing MiniBatchSGD')
    print(list(model.parameters())[6].view(-1)[:3])
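    # The print above dumps a small slice of one weight tensor, presumably as a
    # sanity check that every rank starts from identical parameters. The codes
    # below choose how gradients are compressed before communication: 'svd'
    # keeps a low-rank approximation of each gradient, 'qsgd' quantizes entries
    # (Alistarh et al.), 'terngrad' quantizes to three levels (Wen et al.), and
    # 'sgd' sends them uncompressed.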
    if args.code == 'sgd':
        code = codings.svd.SVD(compress=False)
    elif args.code == 'svd':
        print("train.py, svd_rank =", args.svd_rank)
        code = codings.svd.SVD(random_sample=args.svd_rescale, rank=args.svd_rank,
                               compress=args.compress)
    elif args.code == 'qsgd':
        code = codings.qsgd.QSGD(scheme='qsgd')
    elif args.code == 'terngrad':
        code = codings.qsgd.QSGD(scheme='terngrad')
    elif args.code == 'qsvd':
        code = codings.qsvd.QSVD(scheme=args.scheme, rank=args.svd_rank)
    else:
        raise ValueError('args.code not recognized')

    names = [n for n, p in model.named_parameters()]
    assert len(names) == len(set(names))
    #  optimizer = MPI_PS(model.named_parameters(), model.parameters(), args.lr,
                       #  code=code, optim='adam',
                       #  use_mpi=args.use_mpi, cuda=args.use_cuda)
    optimizer = SGD(model.named_parameters(), model.parameters(), args.lr,
                    code=code, optim='sgd',
                    use_mpi=args.use_mpi, cuda=args.use_cuda, compress_level=args.compress_level)
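    # Note: this SGD is the project's distributed optimizer, not torch.optim.SGD;
    # it takes the named parameters, a gradient coding and MPI/compression options.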

    data = []
    train_data = []
    train_time = 0
    for epoch in range(args.start_epoch, args.epochs + 1):
        print(f"epoch {epoch}")
        adjust_learning_rate(optimizer, epoch+1)
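        # adjust_learning_rate (defined elsewhere) is assumed to decay args.lr at
        # fixed epoch milestones by writing the new value into each of the
        # optimizer's param_groups.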

        # train for one epoch
        start = time.time()
        if epoch >= 0:
            train_d = train(train_loader, model, criterion, optimizer, epoch)
        else:
            train_d = []
        train_time += time.time() - start
        train_data += [dict(datum, **vars(args)) for datum in train_d]

        # evaluate on validation set
        if epoch >= args.epochs:
            train_datum = validate(train_loader, model, criterion, epoch)
        else:
            train_datum = {'acc_test': np.inf, 'loss_test': np.inf}
        datum = validate(val_loader, model, criterion, epoch)
        #  train_datum = {'acc_train': 0.1, 'loss_train': 2.3}
        data += [{'train_time': train_time,
                  'whole_train_acc': train_datum['acc_test'],
                  'whole_train_loss': train_datum['loss_test'],
                  'epoch': epoch + 1, **vars(args), **datum}]
        if epoch > 0:
            if len(data) > 1:
                data[-1]['epoch_train_time'] = data[-1]['train_time'] - data[-2]['train_time']
            for key in train_data[-1]:
                values = [datum[key] for i, datum in enumerate(train_data)]
                if 'time' in key:
                    data[-1]["epoch_" + key] = np.sum(values)
                else:
                    data[-1]["epoch_" + key] = values[0]

        df = pd.DataFrame(data)
        train_df = pd.DataFrame(train_data)
        if True:
            time.sleep(1)
            # Yes loss_test IS on train data. (look at what validate returns)
            print('\n\nmin_train_loss', train_datum['loss_test'],
                  optimizer.steps, '\n\n')
            time.sleep(1)
        ids = [str(getattr(args, key)) for key in
               ['layers', 'lr', 'batch_size', 'compress', 'seed', 'num_workers',
                'svd_rank', 'svd_rescale', 'use_mpi', 'qsgd', 'world_size',
                'rank', 'code', 'scheme']]
        _write_csv(df, id='-'.join(ids))
        _write_csv(train_df, id='-'.join(ids) + '_train')
        pprint({k: v for k, v in data[-1].items()
                if k in ['svd_rank', 'svd_rescale', 'qsgd', 'compress']})
        pprint({k: v for k, v in data[-1].items()
                if k in ['train_time', 'num_workers', 'loss_test',
                         'acc_test', 'epoch', 'compress', 'svd_rank', 'qsgd'] or 'time' in k})
        prec1 = datum['acc_test']

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
        }, is_best)
        print('Best accuracy: ', best_prec1)