Exemple #1
0
def test_grad_through_normalize():
    """Check that gradients propagate through ``NormalizeByChannelMeanStd``.

    The layer is built with mean 0 and std 1. The assertion expects the
    gradient of ``sum(layer(x) ** 2)`` with respect to ``x`` to equal
    ``2 * x``, i.e. the layer must act as a differentiable identity for
    these statistics.
    """
    x = torch.rand((2, 1, 28, 28))
    x.requires_grad_()

    zero_mean = torch.tensor((0., ))
    unit_std = torch.tensor((1., ))
    layer = NormalizeByChannelMeanStd(zero_mean, unit_std)

    # Scalar objective so that backward() needs no explicit grad argument.
    objective = (layer(x)**2).sum()
    objective.backward()

    expected_grad = 2 * x
    assert torch_allclose(expected_grad, x.grad)
Exemple #2
0
class Transform(object):
    """Namespace holding preprocessing pipelines and layers as class attributes.

    Every attribute is created when the class body executes; in particular
    ``classifier_preprocess_layer`` is moved to CUDA at that moment.
    """
    # Two chained Normalize steps: the first uses mean -1 and std 2, i.e.
    # (x + 1) / 2; the second applies IMAGENET_MEAN / IMAGENET_STD.
    # presumably the input is a tensor in [-1, 1] (no ToTensor step here) --
    # TODO confirm against the caller.
    classifier_training = transforms.Compose([
        transforms.Normalize((-1., -1., -1.), (2., 2., 2.)),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
    ])
    # Same two steps as classifier_training, built as a separate Compose.
    classifier_testing = transforms.Compose([
        transforms.Normalize((-1., -1., -1.), (2., 2., 2.)),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
    ])

    # Default conversion applied to loaded images.
    default = transforms.ToTensor()
    # NOTE(review): an empty list rather than a layer object -- confirm that
    # callers treat this as "no deprocessing".
    gan_deprocess_layer = []
    # NOTE(review): uses MEAN / STD, not IMAGENET_MEAN / IMAGENET_STD as the
    # pipelines above do -- confirm these are the intended statistics.
    classifier_preprocess_layer = NormalizeByChannelMeanStd(MEAN, STD).cuda()
Exemple #3
0
class Transform(object):
    """Namespace holding CIFAR-10 preprocessing pipelines and layers.

    All members are class attributes created when the class body executes;
    the two ``NormalizeByChannelMeanStd`` layers are moved to CUDA at that
    moment.
    """
    # GAN training: ToTensor followed by mean 0.5 / std 0.5 per channel,
    # i.e. (x - 0.5) / 0.5.
    gan_training = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    # Classifier training: random 32x32 crop with 4-pixel padding and random
    # horizontal flip, then ToTensor and CIFAR-10 normalisation.
    classifier_training = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD)
    ])
    # Classifier testing: no augmentation, only ToTensor and normalisation.
    classifier_testing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD)
    ])

    # transformations used during testing / generating attacks
    # we assume loaded images to be in [0, 1]
    # since g(z) is in [-1, 1], the output should be deprocessed into [0, 1]
    default = transforms.ToTensor()
    # mean -1 / std 2 gives (x + 1) / 2, the [-1, 1] -> [0, 1] mapping
    # described above.
    gan_deprocess_layer = NormalizeByChannelMeanStd([-1., -1., -1.],
                                                    [2., 2., 2.]).cuda()
    # Normalisation layer built from the same CIFAR-10 statistics as the
    # classifier pipelines above.
    classifier_preprocess_layer = NormalizeByChannelMeanStd(
        CIFAR10_MEAN, CIFAR10_STD).cuda()
Exemple #4
0
    def __init__(self, block, num_blocks, num_classes=10):
        """Create the normalisation layer, stem, four stages and two heads.

        Args:
            block: residual block class; its ``expansion`` attribute scales
                the width of the classifier input.
            num_blocks: indexable with at least four entries, one block
                count per stage.
            num_classes: output size of both linear heads.
        """
        super(PreActResNet, self).__init__()
        self.in_planes = 64

        # Per-channel input statistics. The attribute name is spelled
        # `nomalize` in the original and is kept so existing references and
        # checkpoints keep working.
        channel_mean = [0.4914, 0.4822, 0.4465]
        channel_std = [0.2023, 0.1994, 0.2010]
        self.nomalize = NormalizeByChannelMeanStd(mean=channel_mean,
                                                  std=channel_std)

        # 3x3 stem convolution, 3 -> 64 channels, no bias.
        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=3, stride=1, padding=1, bias=False)

        # Four stages: 64, 128, 256, 512 planes; stride 2 from stage two on.
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)

        # Two classifier heads of identical shape.
        head_in_features = 512 * block.expansion
        self.linear = nn.Linear(head_in_features, num_classes)
        self.linear_main = nn.Linear(head_in_features, num_classes)
Exemple #5
0
def get_models(args,
               train=True,
               as_ensemble=False,
               model_file=None,
               leaky_relu=False,
               dataset=None):
    """Build a single normalisation-wrapped ResNet18 on the GPU.

    Args:
        args: namespace providing ``arch``; only ``'resnet'``
            (case-insensitive) is supported.
        train: put the model in training mode if true, else in eval mode.
        as_ensemble: accepted for interface compatibility; unused here.
        model_file: optional checkpoint path; its content is loaded with
            ``torch.load`` and passed to ``load_state_dict`` of the wrapped
            model.
        leaky_relu: accepted for interface compatibility; unused here.
        dataset: ``"CIFAR-10"``, ``"STL-10"`` or ``None``. All three use the
            same normalisation statistics.

    Returns:
        The ``ModelWrapper`` instance, moved to CUDA.

    Raises:
        ValueError: for an unsupported architecture or dataset name.
    """
    # Validate the request first so that bad arguments fail before any GPU
    # allocation or checkpoint loading takes place.
    if args.arch.lower() != 'resnet':
        raise ValueError(
            '[{:s}] architecture is not supported yet...'.format(args.arch))
    if dataset not in (None, "STL-10", "CIFAR-10"):
        raise ValueError(
            '[{}] dataset is not supported yet...'.format(dataset))

    # The original code assigned identical std values for STL-10 and
    # CIFAR-10 and left `std` unbound for the default dataset=None; the same
    # statistics are now used in all supported cases.
    mean = torch.tensor([0.4914, 0.4822, 0.4465], dtype=torch.float32).cuda()
    std = torch.tensor([0.2023, 0.1994, 0.2010], dtype=torch.float32).cuda()

    normalizer = NormalizeByChannelMeanStd(mean=mean, std=std)

    if model_file:
        state_dict = torch.load(model_file)
        if train:
            print('Loading pre-trained models...')

    # we include input normalization as a part of the model
    model = ModelWrapper(ResNet18(), normalizer)
    if model_file:
        model.load_state_dict(state_dict)

    if train:
        model.train()
    else:
        model.eval()

    model = model.cuda()

    return model
Exemple #6
0
def get_models(args, train=True, as_ensemble=False, model_file=None, leaky_relu=False):
    """Build one or more normalisation-wrapped ResNets on the GPU.

    Args:
        args: namespace providing ``arch`` (only ``'resnet'``,
            case-insensitive), ``depth`` and, when no checkpoint is given,
            ``model_num``.
        train: put every model in training mode if true, else eval mode.
        as_ensemble: wrap the models in an ``Ensemble`` (requires
            ``train=False``).
        model_file: optional checkpoint path. The loaded object is indexed
            by its own keys, one state dict per model.
        leaky_relu: forwarded to the ``ResNet`` constructor.

    Returns:
        An ``Ensemble`` in eval mode when ``as_ensemble`` is true, otherwise
        the list of models.

    Raises:
        ValueError: if ``args.arch`` is not supported.
        AssertionError: if ``as_ensemble`` is requested together with
            ``train``.
    """
    # Both checks are loop-invariant; run them up front so that invalid
    # arguments fail before any GPU or disk work.
    if args.arch.lower() != 'resnet':
        raise ValueError(
            '[{:s}] architecture is not supported yet...'.format(args.arch))
    if as_ensemble and train:
        # Raised explicitly (not via `assert`) so the check survives `-O`.
        raise AssertionError(
            'Must be in eval mode when getting models to form an ensemble')

    mean = torch.tensor([0.4914, 0.4822, 0.4465], dtype=torch.float32).cuda()
    std = torch.tensor([0.2023, 0.1994, 0.2010], dtype=torch.float32).cuda()
    # One normaliser instance is shared by every wrapped model.
    normalizer = NormalizeByChannelMeanStd(mean=mean, std=std)

    if model_file:
        state_dict = torch.load(model_file)
        if train:
            print('Loading pre-trained models...')

    # With a checkpoint, build one model per stored entry; otherwise build
    # args.model_num freshly initialised models.
    iter_m = state_dict.keys() if model_file else range(args.model_num)

    models = []
    for i in iter_m:
        model = ResNet(depth=args.depth, leaky_relu=leaky_relu)
        # we include input normalization as a part of the model
        model = ModelWrapper(model, normalizer)
        if model_file:
            model.load_state_dict(state_dict[i])
        if train:
            model.train()
        else:
            model.eval()
        models.append(model.cuda())

    if as_ensemble:
        ensemble = Ensemble(models)
        ensemble.eval()
        return ensemble
    return models
def _build_model(arch, args):
    """Instantiate the network called ``arch``; raise ValueError if unknown."""
    if arch == 'PreActResNet18':
        return PreActResNet18()
    if arch == 'WideResNet':
        return WideResNet(34,
                          10,
                          widen_factor=args.width_factor,
                          dropRate=0.0)
    if arch == 'DenseNet121':
        return DenseNet121()
    if arch == 'ResNet18':
        return ResNet18()
    raise ValueError("Unknown model")


def _unwrap_state_dict(checkpoint):
    """Return the weights from a raw state dict or a ``{'state_dict': ...}`` bundle."""
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        return checkpoint['state_dict']
    return checkpoint


def main():
    """Train and/or evaluate a CIFAR-10 classifier under adversarial attack.

    Steps, all driven by ``get_args()``:

    1. Set up logging into ``args.fname`` and seed numpy / torch.
    2. Build the train / test (and optional validation) batch iterators.
    3. Build the model and optional surrogate models used to craft
       training-time and test-time perturbations.
    4. For every epoch: train on perturbed inputs (unless ``args.eval``),
       evaluate clean and robust accuracy on the test set (and validation
       set), log the results and save checkpoints.

    In evaluation mode the function returns after the first evaluated epoch.
    """
    args = get_args()

    os.makedirs(args.fname, exist_ok=True)

    logger = logging.getLogger(__name__)
    logging.basicConfig(
        format='[%(asctime)s] - %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S',
        level=logging.DEBUG,
        handlers=[
            logging.FileHandler(
                os.path.join(args.fname,
                             'eval.log' if args.eval else 'output.log')),
            logging.StreamHandler()
        ])

    logger.info(args)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    transforms = [Crop(32, 32), FlipLR()]
    if args.cutout:
        transforms.append(Cutout(args.cutout_len, args.cutout_len))
    if args.val:
        try:
            dataset = torch.load("cifar10_validation_split.pth")
        except FileNotFoundError:
            # Only a missing file means "split not generated"; any other
            # failure (e.g. a corrupt file) is allowed to propagate.
            print(
                "Couldn't find a dataset with a validation split, did you run "
                "generate_validation.py?")
            return
        val_set = list(
            zip(transpose(dataset['val']['data'] / 255.),
                dataset['val']['labels']))
        val_batches = Batches(val_set,
                              args.batch_size,
                              shuffle=False,
                              num_workers=2)
    else:
        dataset = cifar10(args.data_dir)
    train_set = list(
        zip(transpose(pad(dataset['train']['data'], 4) / 255.),
            dataset['train']['labels']))
    train_set_x = Transform(train_set, transforms)
    train_batches = Batches(train_set_x,
                            args.batch_size,
                            shuffle=True,
                            set_random_choices=True,
                            num_workers=2)

    test_set = list(
        zip(transpose(dataset['test']['data'] / 255.),
            dataset['test']['labels']))
    test_batches = Batches(test_set,
                           args.batch_size,
                           shuffle=False,
                           num_workers=2)

    # Attack budgets are given on the 0-255 scale; images are in [0, 1].
    trn_epsilon = (args.trn_epsilon / 255.)
    trn_pgd_alpha = (args.trn_pgd_alpha / 255.)
    tst_epsilon = (args.tst_epsilon / 255.)
    tst_pgd_alpha = (args.tst_pgd_alpha / 255.)

    # The main model is used without nn.DataParallel ("temp testing" setup).
    model = _build_model(args.model, args).cuda()
    model.train()

    ##################################
    # load pretrained surrogate models if needed
    if args.trn_adv_models != 'None':
        trn_adv_model = nn.DataParallel(
            _build_model(args.trn_adv_arch, args)).cuda()
        trn_adv_model.load_state_dict(
            torch.load(
                os.path.join('./adv_models', args.trn_adv_models,
                             'model_best.pth'))['state_dict'])
        logger.info(f'loaded adv_model: {args.trn_adv_models}')
    else:
        trn_adv_model = None

    if args.tst_adv_models != 'None':
        # Not wrapped in nn.DataParallel ("temp testing" setup). Both raw
        # state dicts and {'state_dict': ...} bundles are accepted.
        tst_adv_model = _build_model(args.tst_adv_arch, args).cuda()
        tst_adv_model.load_state_dict(
            _unwrap_state_dict(
                torch.load(
                    os.path.join('./adv_models', args.tst_adv_models,
                                 'model_best.pth'))))
        logger.info(f'loaded adv_model: {args.tst_adv_models}')
    else:
        tst_adv_model = None
    ##################################

    if args.l2:
        # Weight decay is applied to everything except batch-norm and bias
        # parameters (selected by name).
        decay, no_decay = [], []
        for name, param in model.named_parameters():
            if 'bn' not in name and 'bias' not in name:
                decay.append(param)
            else:
                no_decay.append(param)
        params = [{
            'params': decay,
            'weight_decay': args.l2
        }, {
            'params': no_decay,
            'weight_decay': 0
        }]
    else:
        params = model.parameters()

    opt = torch.optim.SGD(params,
                          lr=args.lr_max,
                          momentum=0.9,
                          weight_decay=5e-4)

    criterion = nn.CrossEntropyLoss()

    # A persistent perturbation buffer is needed for 'free' training and for
    # FGSM initialised from the previous perturbation.
    if args.trn_attack == 'free' or (args.trn_attack == 'fgsm'
                                     and args.trn_fgsm_init == 'previous'):
        delta = torch.zeros(args.batch_size, 3, 32, 32).cuda()
        delta.requires_grad = True

    if args.trn_attack == 'free':
        epochs = int(math.ceil(args.epochs / args.trn_attack_iters))
    else:
        epochs = args.epochs

    if args.lr_schedule == 'superconverge':
        lr_schedule = lambda t: np.interp([t], [
            0, args.epochs * 2 // 5, args.epochs
        ], [0, args.lr_max, 0])[0]
    elif args.lr_schedule == 'piecewise':

        def lr_schedule(t):
            if t / args.epochs < 0.5:
                return args.lr_max
            elif t / args.epochs < 0.75:
                return args.lr_max / 10.
            else:
                return args.lr_max / 100.
    elif args.lr_schedule == 'linear':
        lr_schedule = lambda t: np.interp([t], [
            0, args.epochs // 3, args.epochs * 2 // 3, args.epochs
        ], [args.lr_max, args.lr_max, args.lr_max / 10, args.lr_max / 100])[0]
    elif args.lr_schedule == 'onedrop':

        def lr_schedule(t):
            if t < args.lr_drop_epoch:
                return args.lr_max
            else:
                return args.lr_one_drop
    elif args.lr_schedule == 'multipledecay':

        def lr_schedule(t):
            return args.lr_max - (t //
                                  (args.epochs // 10)) * (args.lr_max / 10)
    elif args.lr_schedule == 'cosine':

        def lr_schedule(t):
            return args.lr_max * 0.5 * (1 + np.cos(t / args.epochs * np.pi))

    best_test_robust_acc = 0
    best_val_robust_acc = 0
    if args.resume:
        # model_best.pth is written below as a {'state_dict': ...} bundle,
        # so unwrap it; raw state dicts are still accepted.
        model.load_state_dict(
            _unwrap_state_dict(
                torch.load(os.path.join(args.fname, 'model_best.pth'))))
        start_epoch = args.resume
        if args.val:
            best_val_robust_acc = torch.load(
                os.path.join(args.fname, 'model_val.pth'))['val_robust_acc']
    else:
        start_epoch = 0

    if args.eval:
        if not args.resume:
            logger.info(
                "No model loaded to evaluate, specify with --resume FNAME")
            return
        logger.info("[Evaluation mode]")

    logger.info(
        'Epoch \t Train Time \t Test Time \t LR \t \t Train Loss \t Train Acc \t Train Robust Loss \t Train Robust Acc \t Test Loss \t Test Acc \t Test Robust Loss \t Test Robust Acc'
    )

    # Normalisation-wrapped training surrogate for the 'tmim' attack. Built
    # lazily once; the original re-wrapped `trn_adv_model` on every batch,
    # stacking one more normalisation layer per step.
    trn_mim_surrogate = None

    for epoch in range(start_epoch, epochs):
        model.train()
        start_time = time.time()
        train_loss = 0
        train_acc = 0
        train_robust_loss = 0
        train_robust_acc = 0
        train_n = 0
        for i, batch in enumerate(train_batches):
            if args.eval:
                break
            X, y = batch['input'], batch['target']
            if args.mixup:
                X, y_a, y_b, lam = mixup_data(X, y, args.mixup_alpha)
                X, y_a, y_b = map(Variable, (X, y_a, y_b))
            lr = lr_schedule(epoch + (i + 1) / len(train_batches))
            # Apply the schedule to every parameter group; updating only
            # group 0 left the no-decay group at lr_max when args.l2 is set.
            for param_group in opt.param_groups:
                param_group.update(lr=lr)

            if args.trn_attack == 'pgd':
                # Random initialization
                mixup_kwargs = {}
                if args.mixup:
                    mixup_kwargs = dict(mixup=True, y_a=y_a, y_b=y_b, lam=lam)
                delta = attack_pgd(model,
                                   X,
                                   y,
                                   trn_epsilon,
                                   trn_pgd_alpha,
                                   args.trn_attack_iters,
                                   args.trn_restarts,
                                   args.trn_norm,
                                   adv_models=trn_adv_model,
                                   **mixup_kwargs)
                delta = delta.detach()
            elif args.trn_attack == 'fgsm':
                delta = attack_pgd(model,
                                   X,
                                   y,
                                   trn_epsilon,
                                   args.trn_fgsm_alpha * trn_epsilon,
                                   1,
                                   1,
                                   args.trn_norm,
                                   adv_models=trn_adv_model,
                                   rand_init=args.trn_fgsm_init)
                delta = delta.detach()
            # Standard training
            elif args.trn_attack == 'none':
                delta = torch.zeros_like(X)
            # The Momentum Iterative Attack
            elif args.trn_attack == 'tmim':
                if trn_adv_model is None:
                    # NOTE(review): the model is attacked without the
                    # `normalize` step used in the forward passes below --
                    # confirm this is intended.
                    attack_target = model
                else:
                    if trn_mim_surrogate is None:
                        # .cuda() mirrors the .to(device) applied to the
                        # test-time wrapper further down.
                        trn_mim_surrogate = nn.Sequential(
                            NormalizeByChannelMeanStd(CIFAR10_MEAN,
                                                      CIFAR10_STD),
                            trn_adv_model).cuda()
                    attack_target = trn_mim_surrogate
                adversary = MomentumIterativeAttack(
                    attack_target,
                    nb_iter=args.trn_attack_iters,
                    eps=trn_epsilon,
                    loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                    eps_iter=trn_pgd_alpha,
                    clip_min=0,
                    clip_max=1,
                    targeted=False)
                data_adv = adversary.perturb(X, y)
                delta = data_adv - X
                delta = delta.detach()

            # delta may be a fixed-size buffer, hence the slice to the
            # current batch size.
            robust_output = model(
                normalize(
                    torch.clamp(X + delta[:X.size(0)],
                                min=lower_limit,
                                max=upper_limit)))
            if args.mixup:
                robust_loss = mixup_criterion(criterion, robust_output, y_a,
                                              y_b, lam)
            else:
                robust_loss = criterion(robust_output, y)

            if args.l1:
                for name, param in model.named_parameters():
                    if 'bn' not in name and 'bias' not in name:
                        robust_loss += args.l1 * param.abs().sum()

            opt.zero_grad()
            robust_loss.backward()
            opt.step()

            # Clean forward pass, used for reporting only.
            output = model(normalize(X))
            if args.mixup:
                loss = mixup_criterion(criterion, output, y_a, y_b, lam)
            else:
                loss = criterion(output, y)

            train_robust_loss += robust_loss.item() * y.size(0)
            train_robust_acc += (robust_output.max(1)[1] == y).sum().item()
            train_loss += loss.item() * y.size(0)
            train_acc += (output.max(1)[1] == y).sum().item()
            train_n += y.size(0)

        train_time = time.time()

        model.eval()
        test_loss = 0
        test_acc = 0
        test_robust_loss = 0
        test_robust_acc = 0
        test_n = 0
        for i, batch in enumerate(test_batches):
            X, y = batch['input'], batch['target']

            # Random initialization
            if args.tst_attack == 'none':
                delta = torch.zeros_like(X)
            elif args.tst_attack == 'pgd':
                delta = attack_pgd(model,
                                   X,
                                   y,
                                   tst_epsilon,
                                   tst_pgd_alpha,
                                   args.tst_attack_iters,
                                   args.tst_restarts,
                                   args.tst_norm,
                                   adv_models=tst_adv_model,
                                   rand_init=args.tst_fgsm_init)
            elif args.tst_attack == 'fgsm':
                delta = attack_pgd(model,
                                   X,
                                   y,
                                   tst_epsilon,
                                   tst_epsilon,
                                   1,
                                   1,
                                   args.tst_norm,
                                   rand_init=args.tst_fgsm_init,
                                   adv_models=tst_adv_model)
            # The Momentum Iterative Attack
            elif args.tst_attack == 'tmim':
                if tst_adv_model is None:
                    attack_target = model
                else:
                    attack_target = nn.Sequential(
                        NormalizeByChannelMeanStd(cifar10_mean, cifar10_std),
                        tst_adv_model).to(device)
                adversary = MomentumIterativeAttack(
                    attack_target,
                    nb_iter=args.tst_attack_iters,
                    eps=tst_epsilon,
                    loss_fn=nn.CrossEntropyLoss(reduction="sum"),
                    eps_iter=tst_pgd_alpha,
                    clip_min=0,
                    clip_max=1,
                    targeted=False)
                data_adv = adversary.perturb(X, y)
                delta = data_adv - X

            delta = delta.detach()

            robust_output = model(
                normalize(
                    torch.clamp(X + delta[:X.size(0)],
                                min=lower_limit,
                                max=upper_limit)))
            robust_loss = criterion(robust_output, y)

            output = model(normalize(X))
            loss = criterion(output, y)

            test_robust_loss += robust_loss.item() * y.size(0)
            test_robust_acc += (robust_output.max(1)[1] == y).sum().item()
            test_loss += loss.item() * y.size(0)
            test_acc += (output.max(1)[1] == y).sum().item()
            test_n += y.size(0)

        test_time = time.time()

        if args.val:
            val_loss = 0
            val_acc = 0
            val_robust_loss = 0
            val_robust_acc = 0
            val_n = 0
            for i, batch in enumerate(val_batches):
                X, y = batch['input'], batch['target']

                # Random initialization
                # NOTE(review): 'tmim' has no branch here, so `delta` would
                # be left over from the test loop, and the surrogate
                # `tst_adv_model` is not passed to the attacks -- confirm.
                if args.tst_attack == 'none':
                    delta = torch.zeros_like(X)
                elif args.tst_attack == 'pgd':
                    delta = attack_pgd(model,
                                       X,
                                       y,
                                       tst_epsilon,
                                       tst_pgd_alpha,
                                       args.tst_attack_iters,
                                       args.tst_restarts,
                                       args.tst_norm,
                                       early_stop=args.eval)
                elif args.tst_attack == 'fgsm':
                    delta = attack_pgd(model,
                                       X,
                                       y,
                                       tst_epsilon,
                                       tst_epsilon,
                                       1,
                                       1,
                                       args.tst_norm,
                                       early_stop=args.eval,
                                       rand_init=args.tst_fgsm_init)

                delta = delta.detach()

                robust_output = model(
                    normalize(
                        torch.clamp(X + delta[:X.size(0)],
                                    min=lower_limit,
                                    max=upper_limit)))
                robust_loss = criterion(robust_output, y)

                output = model(normalize(X))
                loss = criterion(output, y)

                val_robust_loss += robust_loss.item() * y.size(0)
                val_robust_acc += (robust_output.max(1)[1] == y).sum().item()
                val_loss += loss.item() * y.size(0)
                val_acc += (output.max(1)[1] == y).sum().item()
                val_n += y.size(0)

        if not args.eval:
            logger.info(
                '%d \t %.1f \t \t %.1f \t \t %.4f \t %.4f \t %.4f \t %.4f \t \t %.4f \t \t %.4f \t %.4f \t %.4f \t \t %.4f',
                epoch, train_time - start_time, test_time - train_time, lr,
                train_loss / train_n, train_acc / train_n,
                train_robust_loss / train_n, train_robust_acc / train_n,
                test_loss / test_n, test_acc / test_n,
                test_robust_loss / test_n, test_robust_acc / test_n)

            if args.val:
                logger.info('validation %.4f \t %.4f \t %.4f \t %.4f',
                            val_loss / val_n, val_acc / val_n,
                            val_robust_loss / val_n, val_robust_acc / val_n)

                # Keep the checkpoint with the best robust validation accuracy.
                if val_robust_acc / val_n > best_val_robust_acc:
                    torch.save(
                        {
                            'state_dict': model.state_dict(),
                            'test_robust_acc': test_robust_acc / test_n,
                            'test_robust_loss': test_robust_loss / test_n,
                            'test_loss': test_loss / test_n,
                            'test_acc': test_acc / test_n,
                            'val_robust_acc': val_robust_acc / val_n,
                            'val_robust_loss': val_robust_loss / val_n,
                            'val_loss': val_loss / val_n,
                            'val_acc': val_acc / val_n,
                        }, os.path.join(args.fname, 'model_val.pth'))
                    best_val_robust_acc = val_robust_acc / val_n

            # save checkpoint
            if (epoch + 1) % args.chkpt_iters == 0 or epoch + 1 == epochs:
                torch.save(model.state_dict(),
                           os.path.join(args.fname, f'model_{epoch}.pth'))
                torch.save(opt.state_dict(),
                           os.path.join(args.fname, f'opt_{epoch}.pth'))

            # save best
            if test_robust_acc / test_n > best_test_robust_acc:
                torch.save(
                    {
                        'state_dict': model.state_dict(),
                        'test_robust_acc': test_robust_acc / test_n,
                        'test_robust_loss': test_robust_loss / test_n,
                        'test_loss': test_loss / test_n,
                        'test_acc': test_acc / test_n,
                    }, os.path.join(args.fname, 'model_best.pth'))
                best_test_robust_acc = test_robust_acc / test_n
        else:
            # Evaluation mode: training columns are reported as -1 and only
            # one epoch is evaluated.
            logger.info(
                '%d \t %.1f \t \t %.1f \t \t %.4f \t %.4f \t %.4f \t %.4f \t \t %.4f \t \t %.4f \t %.4f \t %.4f \t \t %.4f',
                epoch, train_time - start_time, test_time - train_time, -1, -1,
                -1, -1, -1, test_loss / test_n, test_acc / test_n,
                test_robust_loss / test_n, test_robust_acc / test_n)
            return
def main():
    """Train a jigsaw-permutation classifier on CIFAR-10 and test the best model.

    Parses the command line into the module-level ``args``, builds the
    normalisation-wrapped model, splits the CIFAR-10 training set 90/10
    into train and validation parts, trains for ``args.epochs`` epochs with
    SGD and a per-step ``LambdaLR`` schedule, checkpoints after every epoch
    (plus a separate file for the best validation accuracy), plots the
    accuracy curves and finally evaluates the best checkpoint on the test
    set.

    Reads and updates the module-level ``best_prec1``; exits with status 1
    when CUDA is unavailable.
    """
    global args, best_prec1
    args = parser.parse_args()
    print(args)

    if not torch.cuda.is_available():
        # logging.info is dropped by the default root logger level, which
        # made this exit silent; report it as an error instead.
        logging.error('no gpu device available')
        sys.exit(1)
    torch.cuda.set_device(int(args.gpu))

    setup_seed(args.seed)

    # Input normalisation is made part of the model so that data loaders
    # can deliver raw ToTensor() output.
    model = jigsaw_model(args.class_number)
    normalize = NormalizeByChannelMeanStd(mean=[0.485, 0.456, 0.406],
                                          std=[0.229, 0.224, 0.225])
    model = nn.Sequential(normalize, model)
    model.cuda()

    cudnn.benchmark = True

    train_trans = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, 4),
        transforms.ToTensor()
    ])

    val_trans = transforms.Compose([transforms.ToTensor()])

    # dataset process
    train_dataset = datasets.CIFAR10(args.data,
                                     train=True,
                                     transform=train_trans,
                                     download=True)
    # Second view of the same training images without augmentation. The
    # validation split previously inherited the random flip / crop of
    # train_trans, so validation accuracy was measured on augmented data.
    valid_dataset = datasets.CIFAR10(args.data,
                                     train=True,
                                     transform=val_trans,
                                     download=False)
    test_dataset = datasets.CIFAR10(args.data,
                                    train=False,
                                    transform=val_trans,
                                    download=True)

    # Random 90/10 split of the training indices.
    valid_size = 0.1
    indices = list(range(len(train_dataset)))
    split = int(np.floor(valid_size * len(train_dataset)))
    np.random.shuffle(indices)

    train_idx, valid_idx = indices[split:], indices[:split]
    train_subset = torch.utils.data.Subset(train_dataset, train_idx)
    valid_subset = torch.utils.data.Subset(valid_dataset, valid_idx)

    train_loader = torch.utils.data.DataLoader(train_subset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=2,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(valid_subset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=2,
                                             pin_memory=True)

    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=2,
                                              pin_memory=True)

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # The schedule is defined per optimisation step, hence the total of
    # epochs * batches-per-epoch steps.
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: cosine_annealing(
            step,
            args.epochs * len(train_loader),
            1,  # since lr_lambda computes multiplicative factor
            1e-6 / args.lr))

    print('std training')
    train_acc = []
    ta = []

    # generate the jigsaw orderings: class_number - 1 random permutations
    # of 16 tiles, saved to the current working directory
    permutation = np.array(
        [np.random.permutation(16) for i in range(args.class_number - 1)])
    np.save('permutation.npy', permutation)

    # makedirs handles nested paths and an already-existing directory.
    os.makedirs(args.save_dir, exist_ok=True)

    for epoch in range(args.epochs):

        print(optimizer.state_dict()['param_groups'][0]['lr'])
        acc, loss = train(train_loader, model, criterion, optimizer, epoch,
                          scheduler, permutation)

        # evaluate on validation set
        tacc, tloss = validate(val_loader, model, criterion, permutation)

        train_acc.append(acc)
        ta.append(tacc)

        # remember best prec@1 and save checkpoint
        is_best = tacc > best_prec1
        best_prec1 = max(tacc, best_prec1)

        checkpoint = {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
        }

        if is_best:
            save_checkpoint(checkpoint,
                            is_best,
                            filename=os.path.join(args.save_dir,
                                                  'best_model.pt'))

        # The latest state is always written, best or not.
        save_checkpoint(checkpoint,
                        is_best,
                        filename=os.path.join(args.save_dir, 'model.pt'))

        # Redraw the accuracy curves after every epoch.
        plt.plot(train_acc, label='train_acc')
        plt.plot(ta, label='TA')
        plt.legend()
        plt.savefig(os.path.join(args.save_dir, 'net_train.png'))
        plt.close()

    # Final report: reload the best-validation checkpoint and test it.
    model_path = os.path.join(args.save_dir, 'best_model.pt')
    model.load_state_dict(torch.load(model_path)['state_dict'])
    print('testing result of ta best model')
    tacc, tloss = validate(test_loader, model, criterion, permutation)
def main():
    """Train a ResNet-18 over several pruning states using stored masks.

    Parses the command line into the global ``args``, builds the model and
    the train/val/test loaders for ``args.dataset``, optionally resumes from
    ``args.resume``, and then iterates over pruning states. Every state
    trains for ``args.epochs`` epochs and writes a checkpoint plus an
    accuracy plot after each epoch. Between states the weights are reset to
    ``initalization`` and the mask for the next state is read from
    ``args.mask_path``.

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset name.
    """
    global args, best_sa
    args = parser.parse_args()
    print(args)

    torch.cuda.set_device(int(args.gpu))
    os.makedirs(args.save_dir, exist_ok=True)
    if args.seed:
        setup_seed(args.seed)

    # Every checkpoint read below is mapped onto the selected GPU.
    device = torch.device('cuda:' + str(args.gpu))

    if args.task == 'rotation':
        print('train for rotation classification')
        class_number = 4
    else:
        print('train for supervised classification')
        if args.dataset in ('cifar10', 'cifar_10_10', 'fmnist'):
            class_number = 10
        elif args.dataset == 'cifar100':
            class_number = 100
        else:
            print('error dataset')
            # Raise instead of `assert 0`: asserts are stripped under -O.
            raise ValueError(
                'unsupported dataset: {}'.format(args.dataset))

    # prepare dataset
    if args.dataset == 'cifar10':
        print('training on cifar10 dataset')

        model = resnet18(num_classes=class_number)
        model.normalize = NormalizeByChannelMeanStd(
            mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])
        train_loader, val_loader, test_loader = cifar10_dataloaders(
            batch_size=args.batch_size, data_dir=args.data)

    elif args.dataset == 'cifar_10_10':
        print('training on cifar10 subset')

        model = resnet18(num_classes=class_number)
        model.normalize = NormalizeByChannelMeanStd(
            mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])
        train_loader, val_loader, test_loader = cifar10_subset_dataloaders(
            batch_size=args.batch_size, data_dir=args.data)

    elif args.dataset == 'cifar100':

        model = resnet18(num_classes=class_number)
        model.normalize = NormalizeByChannelMeanStd(
            mean=[0.5071, 0.4865, 0.4409], std=[0.2673, 0.2564, 0.2762])
        train_loader, val_loader, test_loader = cifar100_dataloaders(
            batch_size=args.batch_size, data_dir=args.data)

    elif args.dataset == 'fmnist':

        model = resnet18(num_classes=class_number)
        model.normalize = NormalizeByChannelMeanStd(
            mean=[0.2860], std=[0.3530])
        # Single-channel input, so the first convolution is replaced.
        model.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1,
                                bias=False)
        train_loader, val_loader, test_loader = fashionmnist_dataloaders(
            batch_size=args.batch_size, data_dir=args.data)

    else:
        print('dataset not support')
        # Without this, execution continued and failed on an undefined
        # `model` a few lines further down.
        raise ValueError('unsupported dataset: {}'.format(args.dataset))

    model.cuda()
    criterion = nn.CrossEntropyLoss()
    decreasing_lr = list(map(int, args.decreasing_lr.split(',')))

    if args.prune_type == 'lt':
        print( 'report lottery tickets setting')
        initalization = deepcopy(model.state_dict())
    else:
        # NOTE(review): for prune types other than 'lt' and 'pt' this stays
        # None unless a checkpoint is resumed, and the rewind at the end of
        # the first state would then fail - confirm the supported values.
        initalization = None

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=decreasing_lr, gamma=0.1)

    if args.resume:
        print('resume from checkpoint')
        checkpoint = torch.load(args.resume, map_location=device)
        best_sa = checkpoint['best_sa']
        start_epoch = checkpoint['epoch']
        all_result = checkpoint['result']
        start_state = checkpoint['state']

        if start_state>0:
            # Re-apply the stored mask before loading the weights so the
            # model has the pruned parametrization the checkpoint expects.
            current_mask = extract_mask(checkpoint['state_dict'])
            prune_model_custom(model, current_mask)
            check_sparsity(model)

        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        initalization = checkpoint['init_weight']
        print('loading state:', start_state)
        print('loading from epoch: ',start_epoch, 'best_sa=', best_sa)

    else:
        # NOTE(review): `best_sa` is not assigned on this path; presumably it
        # is initialised at module level - verify.
        all_result = {}
        all_result['train'] = []
        all_result['test_ta'] = []
        all_result['ta'] = []

        start_epoch = 0
        start_state = 0

    print('######################################## Start Standard Training Iterative Pruning ########################################')
    print(model.normalize)

    for state in range(start_state, args.pruning_times):

        print('******************************************')
        print('pruning state', state)
        print('******************************************')

        for epoch in range(start_epoch, args.epochs):

            print(optimizer.state_dict()['param_groups'][0]['lr'])
            check_sparsity(model)
            acc = train(train_loader, model, criterion, optimizer, epoch)

            # evaluate on validation set
            tacc = validate(val_loader, model, criterion)
            # evaluate on test set
            test_tacc = validate(test_loader, model, criterion)

            scheduler.step()

            all_result['train'].append(acc)
            all_result['ta'].append(tacc)
            all_result['test_ta'].append(test_tacc)

            # remember best prec@1 and save checkpoint
            is_best_sa = tacc > best_sa
            best_sa = max(tacc, best_sa)

            save_checkpoint({
                'state': state,
                'result': all_result,
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_sa': best_sa,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'init_weight': initalization
            }, is_SA_best=is_best_sa, pruning=state, save_path=args.save_dir)

            plt.plot(all_result['train'], label='train_acc')
            plt.plot(all_result['ta'], label='val_acc')
            plt.plot(all_result['test_ta'], label='test_acc')
            plt.legend()
            plt.savefig(os.path.join(args.save_dir, str(state)+'net_train.png'))
            plt.close()

        # report result
        check_sparsity(model, True)
        print('report best SA={}'.format(best_sa))

        # Reset the per-state bookkeeping before the next pruning state.
        all_result = {}
        all_result['train'] = []
        all_result['test_ta'] = []
        all_result['ta'] = []

        best_sa = 0
        start_epoch = 0

        if args.prune_type == 'pt':
            print('report loading pretrained weight')
            initalization = torch.load(
                os.path.join(args.save_dir, '0model_SA_best.pth.tar'),
                map_location=device)['state_dict']

        pruning_model(model, args.rate)
        check_sparsity(model)
        # The mask for the next state comes from a stored checkpoint rather
        # than from the model that was just pruned.
        current_mask = torch.load(
            os.path.join(args.mask_path,
                         '{}checkpoint.pth.tar'.format(state+1)),
            map_location=device)['state_dict']
        remove_prune(model)

        # rewind weight to init
        model.load_state_dict(initalization)
        prune_model_custom(model, current_mask)
        check_sparsity(model)

        # Fresh optimizer and LR schedule for the next state.
        optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=decreasing_lr, gamma=0.1)
Exemple #10
0
])

# CIFAR-10 training split (downloaded to ./data if missing), shuffled batches of 128.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

# CIFAR-10 test split, fixed order, batches of 100.
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

# Class names for the ten labels; presumably indexed by label id - verify.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

os.makedirs('logs', exist_ok=True)

# Model
# Per-channel statistics handed to the normalization layer that is placed in
# front of the network, so normalization happens inside the model.
MEAN = torch.Tensor([0.4914, 0.4822, 0.4465])
STD = torch.Tensor([0.2023, 0.1994, 0.2010])
norm = NormalizeByChannelMeanStd(mean=MEAN, std=STD)

print('==> Building model..')
# Normalization + backbone are moved to the GPU and wrapped in DataParallel.
net = torch.nn.DataParallel(nn.Sequential(norm, ResNet18_feat()).cuda())

cudnn.benchmark = True

# Weights are always loaded from this fixed checkpoint; its 'net' entry holds
# the state dict for the DataParallel-wrapped model.
checkpoint = torch.load('./checkpoint/batch_adv0_grad0_lambda_1.0.t7')
net.load_state_dict(checkpoint['net'])
# Freezing the last layer is currently disabled:
# net.module[-1].linear.weight.requires_grad = False
# net.module[-1].linear.bias.requires_grad = False

# NOTE(review): `args` is not defined in the visible part of this script -
# presumably parsed earlier; confirm.
if args.resume or args.test:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
Exemple #11
0
def main():
    """Run iterative pruning with weight rewinding on a CIFAR-10 ResNet-18.

    Training data comes from ``results.pth``; only the validation split of
    ``cifar10_dataloaders`` is used for evaluation. Each pruning state trains
    for ``args.epochs`` epochs, saving a checkpoint and an accuracy plot
    after every epoch. At the end of a state the model is pruned, its mask
    extracted, the weights are reloaded from the stored initial state dict,
    and the mask is re-applied before a fresh optimizer is built.
    """
    global args, best_sa
    args = parser.parse_args()
    print(args)

    torch.cuda.set_device(int(args.gpu))
    os.makedirs(args.save_dir, exist_ok=True)
    if args.seed:
        setup_seed(args.seed)

    _, val_loader, _ = cifar10_dataloaders()
    model = resnet18(num_classes=10)

    # prepare dataset
    img_results = torch.load("results.pth", map_location="cpu")
    model.normalize = NormalizeByChannelMeanStd(
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2470, 0.2435, 0.2616],
    )
    model.cuda()
    criterion = nn.CrossEntropyLoss()

    def build_optimizer():
        # SGD over the current model parameters; rebuilt after each pruning
        # state.
        return torch.optim.SGD(model.parameters(),
                               args.lr,
                               momentum=args.momentum,
                               weight_decay=args.weight_decay)

    def empty_history():
        # Per-state accuracy curves. 'test_ta' is never filled in this
        # function, but it is still stored and plotted.
        return {'train': [], 'test_ta': [], 'ta': []}

    # Snapshot of the weights that every pruning state rewinds to.
    init_weights = deepcopy(model.state_dict())
    optimizer = build_optimizer()
    history = empty_history()

    first_epoch = 0
    first_state = 0

    print('######################################## Start Standard Training Iterative Pruning ########################################')
    print(model.normalize)

    for state in range(first_state, args.pruning_times):

        print('******************************************')
        print('pruning state', state)
        print('******************************************')

        for epoch in range(first_epoch, args.epochs):

            print(optimizer.state_dict()['param_groups'][0]['lr'])
            check_sparsity(model)
            train_acc = train(img_results, model, criterion, optimizer, epoch)

            # evaluate on validation set
            val_acc = validate(val_loader, model, criterion)

            history['train'].append(train_acc)
            history['ta'].append(val_acc)

            # remember best prec@1 and save checkpoint
            improved = val_acc > best_sa
            best_sa = max(val_acc, best_sa)

            snapshot = {
                'state': state,
                'result': history,
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_sa': best_sa,
                'optimizer': optimizer.state_dict(),
                'init_weight': init_weights,
            }
            save_checkpoint(snapshot,
                            is_SA_best=improved,
                            pruning=state,
                            save_path=args.save_dir)

            curves = (('train', 'train_acc'),
                      ('ta', 'val_acc'),
                      ('test_ta', 'test_acc'))
            for key, label in curves:
                plt.plot(history[key], label=label)
            plt.legend()
            figure_path = os.path.join(args.save_dir,
                                       str(state) + 'net_train.png')
            plt.savefig(figure_path)
            plt.close()

        # report result
        check_sparsity(model, True)
        print('report best SA={}'.format(best_sa))

        # Start the next state with clean bookkeeping.
        history = empty_history()
        best_sa = 0
        first_epoch = 0

        if args.prune_type == 'pt':
            print('report loading pretrained weight')
            pretrained_path = os.path.join(args.save_dir,
                                           '0model_SA_best.pth.tar')
            target_device = torch.device('cuda:' + str(args.gpu))
            pretrained = torch.load(pretrained_path,
                                    map_location=target_device)
            init_weights = pretrained['state_dict']

        pruning_model(model, args.rate)
        check_sparsity(model)
        current_mask = extract_mask(model.state_dict())
        remove_prune(model)

        # rewind weight to init
        model.load_state_dict(init_weights)

        # pruning using custom mask
        prune_model_custom(model, current_mask)
        check_sparsity(model)

        optimizer = build_optimizer()
Exemple #12
0
    def init_advertorch(self, model, device, attack_params, dataset_params):
        """Build the settings dictionary for an advertorch attack.

        Stores a ``NormalizeByChannelMeanStd`` layer on ``self.normalize``
        and places it in front of ``model`` (with a BPDA-wrapped preprocess
        stage in between when ``attack_params['bpda']`` is set), then moves
        the resulting pipeline to ``device``.

        Args:
            model: network to be attacked.
            device: device the assembled pipeline is moved to.
            attack_params: dict read for 'bpda', 'attack' and the
                attack-specific keys ('iterations', 'stepsize', 'epsilon',
                'random', and 'preprocess' when BPDA is enabled).
            dataset_params: dict providing 'mean', 'std' and 'num_classes'.

        Returns:
            dict holding the attack class under 'attack', the assembled
            pipeline under 'model', and the attack-specific settings.

        Raises:
            ValueError: if the attack name is not 'pgd', 'cw' or 'fgsm'
                (compared case-insensitively).
        """
        channel_mean = dataset_params['mean']
        channel_std = dataset_params['std']
        n_classes = dataset_params['num_classes']

        self.normalize = NormalizeByChannelMeanStd(mean=channel_mean,
                                                   std=channel_std)

        # Collect the pipeline stages first, then wrap them once.
        if attack_params['bpda'] == True:
            defense = attack_params['preprocess']
            bpda_stage = BPDAWrapper(defense, forwardsub=defense.back_approx)
            stages = [self.normalize, bpda_stage, model]
        else:
            stages = [self.normalize, model]
        pipeline = nn.Sequential(*stages).to(device)

        requested = attack_params['attack'].lower()

        if requested == 'pgd':
            n_iter = attack_params['iterations']
            step_size = attack_params['stepsize']
            eps = attack_params['epsilon']
            attack_cls = advertorch.attacks.LinfPGDAttack
            use_random = attack_params['random']
            return {
                'attack': attack_cls,
                'iterations': n_iter,
                'epsilon': eps,
                'stepsize': step_size,
                'model': pipeline,
                'random': use_random
            }

        if requested == 'cw':
            n_iter = attack_params['iterations']
            eps = attack_params['epsilon']
            attack_cls = advertorch.attacks.CarliniWagnerL2Attack
            return {
                'attack': attack_cls,
                'iterations': n_iter,
                'epsilon': eps,
                'model': pipeline,
                'num_classes': n_classes
            }

        if requested == 'fgsm':
            eps = attack_params['epsilon']
            attack_cls = advertorch.attacks.FGSM
            # FGSM is reported with a single iteration.
            return {
                'attack': attack_cls,
                'iterations': 1,
                'epsilon': eps,
                'model': pipeline,
                'num_classes': n_classes
            }

        # None of the known attack names matched.
        raise ValueError("Unsupported attack")
Exemple #13
0
def main():
    """Train a ResNet50 on CIFAR-10 and track clean and adversarial accuracy.

    Parses the command line into the global ``args``, builds the model with
    an input-normalization layer in front, splits the CIFAR-10 training set
    90/10 into train and validation subsets, and trains for ``args.epochs``
    epochs. After every epoch the accuracies returned by ``validate`` (TA)
    and ``validate_adv`` (ATA) are recorded, checkpoints are written to
    ``args.save_dir`` and the accuracy curves are re-plotted. Finally the
    best-ATA and best-TA checkpoints are reloaded and evaluated on the test
    set.

    Exits with status 1 when no CUDA device is available.
    """
    global args, best_prec1, best_ata
    args = parser.parse_args()
    print(args)

    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)
    torch.cuda.set_device(int(args.gpu))
    setup_seed(args.seed)

    model = ResNet50()
    normalize = NormalizeByChannelMeanStd(mean=[0.485, 0.456, 0.406],
                                          std=[0.229, 0.224, 0.225])
    # Normalization is part of the model, so the loaders feed raw tensors.
    model = nn.Sequential(normalize, model)
    model.cuda()
    cudnn.benchmark = True

    if args.pretrained_model:
        model_dict_pretrain = torch.load(
            args.pretrained_model,
            map_location=torch.device('cuda:' + str(args.gpu)))
        # strict=False: keys that do not match the model are tolerated.
        model.load_state_dict(model_dict_pretrain, strict=False)
        print('model loaded:', args.pretrained_model)

    # dataset process
    train_trans = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, 4),
        transforms.ToTensor()
    ])

    val_trans = transforms.Compose([transforms.ToTensor()])

    # dataset process
    train_dataset = datasets.CIFAR10(args.data,
                                     train=True,
                                     transform=train_trans,
                                     download=True)
    test_dataset = datasets.CIFAR10(args.data,
                                    train=False,
                                    transform=val_trans,
                                    download=True)

    # Hold out 10% of the (shuffled) training indices for validation.
    valid_size = 0.1
    indices = list(range(len(train_dataset)))
    split = int(np.floor(valid_size * len(train_dataset)))
    np.random.shuffle(indices)

    train_idx, valid_idx = indices[split:], indices[:split]
    # NOTE(review): both subsets wrap `train_dataset`, so the validation
    # subset also goes through `train_trans` - confirm this is intended.
    train_sampler = torch.utils.data.Subset(train_dataset, train_idx)
    valid_sampler = torch.utils.data.Subset(train_dataset, valid_idx)

    train_loader = torch.utils.data.DataLoader(train_sampler,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=2,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(valid_sampler,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=2,
                                             pin_memory=True)

    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=2,
                                              pin_memory=True)

    decreasing_lr = list(map(int, args.decreasing_lr.split(',')))

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=decreasing_lr,
                                                     gamma=0.1)

    print('adv training')

    train_acc = []
    ta = []
    ata = []

    # exist_ok avoids the check-then-create race and also creates missing
    # parent directories, which os.mkdir would not.
    os.makedirs(args.save_dir, exist_ok=True)

    for epoch in range(args.epochs):

        print(optimizer.state_dict()['param_groups'][0]['lr'])
        acc, loss = train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        tacc, tloss = validate(val_loader, model, criterion)

        atacc, atloss = validate_adv(val_loader, model, criterion)

        scheduler.step()

        train_acc.append(acc)
        ta.append(tacc)
        ata.append(atacc)

        # remember best prec@1 and save checkpoint
        is_best = tacc > best_prec1
        best_prec1 = max(tacc, best_prec1)

        ata_is_best = atacc > best_ata
        best_ata = max(atacc, best_ata)

        if is_best:

            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                },
                is_best,
                filename=os.path.join(args.save_dir, 'best_model.pt'))

        if ata_is_best:

            # NOTE(review): this passes `is_best` (the TA flag), not
            # `ata_is_best`; whether that matters depends on how
            # save_checkpoint uses the flag - confirm.
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                },
                is_best,
                filename=os.path.join(args.save_dir, 'ata_best_model.pt'))

        # Latest weights are written every epoch, whether or not they are best.
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            },
            is_best,
            filename=os.path.join(args.save_dir, 'model.pt'))

        plt.plot(train_acc, label='train_acc')
        plt.plot(ta, label='TA')
        plt.plot(ata, label='ATA')
        plt.legend()
        plt.savefig(os.path.join(args.save_dir, 'net_train.png'))
        plt.close()

    best_model_path = os.path.join(args.save_dir, 'ata_best_model.pt')
    print('start testing ATA best model')
    model.load_state_dict(torch.load(best_model_path)['state_dict'])
    tacc, tloss = validate(test_loader, model, criterion)
    atacc, atloss = validate_adv(test_loader, model, criterion)

    best_model_path = os.path.join(args.save_dir, 'best_model.pt')
    print('start testing TA best model')
    model.load_state_dict(torch.load(best_model_path)['state_dict'])
    tacc, tloss = validate(test_loader, model, criterion)
    atacc, atloss = validate_adv(test_loader, model, criterion)
Exemple #14
0
def main():
    """Train the selfie model together with the ``P`` network, then test.

    Parses the command line into the global ``args``, builds the datasets
    for ``args.dataset`` ('cifar', 'imagenet' or 'imagenet224'), splits the
    training data 90/10 into train and validation subsets and optimizes
    both networks with a single SGD optimizer. After every epoch the
    results of ``val_selfie`` and ``val_pgd_selfie`` are recorded, and the
    best-TA, best-ATA and latest checkpoints are written to
    ``args.modeldir``. When ``args.epochs > 0`` the two best checkpoints are
    finally reloaded and evaluated on the test loader.

    Exits with status 1 when no CUDA device is available.

    Raises:
        ValueError: if ``args.dataset`` is not one of the supported names.
    """
    global args, best_prec1, ata_best_prec1

    args = parser.parse_args()
    print(args)

    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)
    torch.cuda.set_device(int(args.gpu))

    setup_seed(args.seed)

    # Data Preprocess
    data_transforms = {
        'train': transforms.Compose([
            transforms.ToPILImage(),
            transforms.Pad(2),
            transforms.RandomCrop(32),
            # transforms.RandomHorizontalFlip(),
            # transforms.RandomRotation(5),
            transforms.ToTensor(),
            # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]),
        'val': transforms.Compose([
            transforms.ToPILImage(),
            # transforms.Pad(2),
            # transforms.RandomCrop(32),
            transforms.ToTensor(),
            # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
    }
    if args.dataset == 'cifar':
        train_dataset = CifarDataset(args.data, True, data_transforms['train'], args.percent)
        test_dataset = CifarDataset(args.data, False, data_transforms['val'], 1)
    elif args.dataset == 'imagenet':
        train_dataset = ImageNetDataset(args.data, True, data_transforms['train'], args.percent)
        test_dataset = ImageNetDataset(args.data, False, data_transforms['val'], 1)

    elif args.dataset == 'imagenet224':
        # NOTE(review): these transforms normalize the input, and `P` below
        # normalizes again with the same statistics - confirm intended.
        data_transforms = {
            'train': transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]),
            'val': transforms.Compose([
                transforms.ToPILImage(),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ])
        }
        train_dataset = datasets.ImageNet(args.data, 'train', True, data_transforms['train'])
        # NOTE(review): the test set is also built from the 'train' split -
        # confirm this is intended.
        test_dataset = datasets.ImageNet(args.data, 'train', True, data_transforms['val'])

    else:
        # Previously an unknown name fell through and failed later with a
        # NameError on `train_dataset`.
        raise ValueError('unsupported dataset: {}'.format(args.dataset))

    # Hold out 10% of the (shuffled) training indices for validation.
    valid_size = 0.1
    indices = list(range(len(train_dataset)))
    split = int(np.floor(valid_size*len(train_dataset)))
    np.random.shuffle(indices)

    train_idx, valid_idx = indices[split:], indices[:split]
    train_sampler = torch.utils.data.Subset(train_dataset, train_idx)
    valid_sampler = torch.utils.data.Subset(train_dataset, valid_idx)

    train_loader = torch.utils.data.DataLoader(
        train_sampler,
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        valid_sampler,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    # define model
    n_split = 4
    selfie_model = get_selfie_model(n_split)
    selfie_model = selfie_model.cuda()

    P = get_P_model()
    normalize = NormalizeByChannelMeanStd(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    # Normalization is prepended so `P` can be fed un-normalized inputs.
    P = nn.Sequential(normalize, P)
    P = P.cuda()

    # define optimizer and scheduler
    params_list = [{'params': selfie_model.parameters(), 'lr': args.lr,
                    'weight_decay': args.weight_decay},]
    params_list.append({'params': P.parameters(), 'lr': args.lr, 'weight_decay': args.weight_decay})
    optimizer = torch.optim.SGD(params_list, lr=args.lr,
                                momentum=0.9,
                                weight_decay=args.weight_decay, nesterov=True)

    # The horizon passed to cosine_annealing is counted in iterations
    # (epochs * batches per epoch), not epochs.
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: cosine_annealing(
            step,
            args.epochs * len(train_loader),
            1,  # since lr_lambda computes multiplicative factor
            1e-7 / args.lr))

    print("Training model.")
    # exist_ok avoids the check-then-create race and also creates missing
    # parent directories, which os.mkdir would not.
    os.makedirs(args.modeldir, exist_ok=True)
    stats_ = stats(args.modeldir, args.start_epoch)

    if args.epochs > 0:

        # order of patches
        all_seq = [np.random.permutation(16) for ind in range(400)]
        # Persist the permutations; the context manager closes the file.
        with open(os.path.join(args.modeldir, 'img_test_seq.pkl'), 'wb') as seq_file:
            pickle.dump(all_seq, seq_file)

        print("Begin selfie training...")
        for epoch in range(args.start_epoch, args.epochs):
            print("The learning rate is {}".format(optimizer.param_groups[0]['lr']))
            trainObj, top1 = train_selfie_adv(train_loader, selfie_model, P, criterion, optimizer, epoch, scheduler)

            valObj, prec1 = val_selfie(val_loader, selfie_model, P, criterion, all_seq)

            adv_valObj, adv_prec1 = val_pgd_selfie(val_loader, selfie_model, P, criterion, all_seq)

            stats_._update(trainObj, top1, valObj, prec1, adv_valObj, adv_prec1)

            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)

            ata_is_best = adv_prec1 > ata_best_prec1
            ata_best_prec1 = max(adv_prec1, ata_best_prec1)
            if is_best:
                torch.save(
                    {
                        'epoch': epoch,
                        'P_state': P.state_dict(),
                        'selfie_state': selfie_model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'best_prec1': best_prec1,
                    }, os.path.join(args.modeldir, 'adv_selfie_TA_model_best.pth.tar'))

            if ata_is_best:
                torch.save(
                    {
                        'epoch': epoch,
                        'P_state': P.state_dict(),
                        'selfie_state': selfie_model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'best_prec1': best_prec1,
                    }, os.path.join(args.modeldir, 'adv_selfie_ATA_model_best.pth.tar'))

            # Latest state is written every epoch, whether or not it is best.
            torch.save(
                {
                    'epoch': epoch,
                    'P_state': P.state_dict(),
                    'selfie_state': selfie_model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_prec1': best_prec1,
                }, os.path.join(args.modeldir, 'adv_selfie_checkpoint.pth.tar'))

            plot_curve(stats_, args.modeldir, True)
            data = stats_
            sio.savemat(os.path.join(args.modeldir, 'stats.mat'), {'data': data})

        print("testing ATA best selfie model from checkpoint...")
        model_path = os.path.join(args.modeldir, 'adv_selfie_ATA_model_best.pth.tar')
        model_loaded = torch.load(model_path)

        P.load_state_dict(model_loaded['P_state'])
        selfie_model.load_state_dict(model_loaded['selfie_state'])
        print("Best ATA selfie model loaded! ")

        valObj, prec1 = val_selfie(test_loader, selfie_model, P, criterion, all_seq)
        adv_valObj, adv_prec1 = val_pgd_selfie(test_loader, selfie_model, P, criterion, all_seq)

        print("testing TA best selfie model from checkpoint...")
        model_path = os.path.join(args.modeldir, 'adv_selfie_TA_model_best.pth.tar')
        model_loaded = torch.load(model_path)

        P.load_state_dict(model_loaded['P_state'])
        selfie_model.load_state_dict(model_loaded['selfie_state'])
        print("Best TA selfie model loaded! ")

        valObj, prec1 = val_selfie(test_loader, selfie_model, P, criterion, all_seq)
        adv_valObj, adv_prec1 = val_pgd_selfie(test_loader, selfie_model, P, criterion, all_seq)
Exemple #15
0
    def __init__(
        self,
        clip_values,
        step=200,
        means=None,
        stds=None,
        gan_ckpt='styleGAN.pt',
        encoder_ckpt='encoder.pt',
        optimize_noise=True,
        use_noise_regularize=False,
        use_lpips=False,
        apply_fit=False,
        apply_predict=True,
        mse=500,
        lr_rampup=0.05,
        lr_rampdown=0.05,
        noise=0.05,
        noise_ramp=0.75,
        noise_regularize=1e5,
        lr=0.1
    ):
        """Store the settings and load the generator and encoder onto the GPU.

        ``means`` and ``stds`` default to the identity normalization and must
        each contain exactly three values. The remaining keyword arguments
        are stored unchanged as attributes of the same name (``apply_fit``
        and ``apply_predict`` with a leading underscore). Generator and
        encoder weights are read from the files resolved by
        ``maybe_download_weights_from_s3``; both networks are put in eval
        mode and moved to CUDA. Finally the spread of the generator's style
        outputs over 10000 random inputs is stored as ``self.latent_std``.

        Raises:
            ValueError: if ``means`` or ``stds`` does not have three values.
        """
        super(InvGAN, self).__init__()
        self._apply_fit = apply_fit
        self._apply_predict = apply_predict

        # Normalization parameters; None means "leave the input unchanged".
        means = (0.0, 0.0, 0.0) if means is None else means
        if len(means) != 3:
            raise ValueError("means must have 3 values, one per channel")
        self.means = means

        stds = (1.0, 1.0, 1.0) if stds is None else stds
        if len(stds) != 3:
            raise ValueError("stds must have 3 values, one per channel")
        self.stds = stds
        self.clip_values = clip_values

        # Optimization settings, stored as given.
        self.optimize_noise = optimize_noise
        self.use_noise_regularize = use_noise_regularize
        self.use_lpips = use_lpips
        self.step = step
        self.mse = mse
        self.lr = lr
        self.lr_rampup = lr_rampup
        self.lr_rampdown = lr_rampdown
        self.noise = noise
        self.noise_ramp = noise_ramp
        self.noise_regularize = noise_regularize

        # Generator: weights live under the 'g_ema' key of the checkpoint.
        self.generator = Generator(256, 512, 8)
        gan_state = torch.load(maybe_download_weights_from_s3(gan_ckpt))
        self.generator.load_state_dict(gan_state['g_ema'])
        self.generator.eval()
        self.generator.cuda()
        self.deprocess_layer = NormalizeByChannelMeanStd(
            [-1., -1., -1.], [2., 2., 2.]).cuda()

        # Encoder: weights live under the 'netE' key of the checkpoint.
        self.encoder = Encoder()
        encoder_state = torch.load(maybe_download_weights_from_s3(encoder_ckpt))
        self.encoder.load_state_dict(encoder_state['netE'])
        self.encoder.eval()
        self.encoder.cuda()

        # Perceptual loss is only instantiated when requested.
        if use_lpips:
            self.lpips = PerceptualLoss().cuda()

        # Estimate latent code statistics from random generator inputs.
        n_mean_latent = 10000
        with torch.no_grad():
            z_samples = torch.randn(n_mean_latent, 512, device='cuda')
            styles = self.generator.style(z_samples)

            centered = styles - styles.mean(0)
            self.latent_std = (centered.pow(2).sum() / n_mean_latent) ** 0.5
Exemple #16
0
    def __init__(self,
                 block,
                 layers,
                 num_classes=1000,
                 zero_init_residual=False,
                 groups=1,
                 width_per_group=64,
                 replace_stride_with_dilation=None,
                 norm_layer=None):
        """Assemble the network: input normalization, stem, four stages, head.

        Args:
            block: residual block class; it is forwarded to ``_make_layer``
                and its ``expansion`` attribute sizes the final linear layer.
            layers: indexable with four entries, one per stage, forwarded to
                ``_make_layer`` (presumably the block count of each stage --
                confirm against ``_make_layer``).
            num_classes: number of outputs of the final ``nn.Linear``.
            zero_init_residual: when true, set the weight of the last norm
                layer of every ``Bottleneck`` (``bn3``) and ``BasicBlock``
                (``bn2``) to zero after the default initialization.
            groups: stored as ``self.groups``.
            width_per_group: stored as ``self.base_width``.
            replace_stride_with_dilation: ``None`` (treated as three
                ``False`` flags) or a 3-element sequence; element ``i`` is
                passed as ``dilate`` when building stage ``i + 2``.
            norm_layer: factory called with a channel count to build the
                normalization layer; defaults to ``nn.BatchNorm2d``.

        Raises:
            ValueError: if ``replace_stride_with_dilation`` is given and does
                not have exactly three elements.
        """
        super(ResNet, self).__init__()

        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        self._norm_layer = norm_layer

        # Running state; read here for the stem and presumably updated by
        # _make_layer as stages are built -- confirm there.
        self.inplanes = 64
        self.dilation = 1

        # One flag per strided stage (layer2..layer4): whether that stage
        # should use a dilated convolution instead of its stride.
        dilation_flags = replace_stride_with_dilation
        if dilation_flags is None:
            dilation_flags = [False, False, False]
        if len(dilation_flags) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                "or a 3-element tuple, got {}".format(dilation_flags))

        self.groups = groups
        self.base_width = width_per_group

        # Per-channel input normalization with fixed statistics baked into
        # the model.
        channel_mean = [0.4914, 0.4822, 0.4465]
        channel_std = [0.2470, 0.2435, 0.2616]
        self.normalize = NormalizeByChannelMeanStd(mean=channel_mean,
                                                   std=channel_std)

        # Stem: a single 3x3 stride-1 convolution; the max-pool slot is an
        # identity, so the stem performs no spatial downsampling.
        self.conv1 = nn.Conv2d(
            3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.Identity()

        # Four stages; every stage after the first is built with stride 2.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(
            block, 128, layers[1], stride=2, dilate=dilation_flags[0])
        self.layer3 = self._make_layer(
            block, 256, layers[2], stride=2, dilate=dilation_flags[1])
        self.layer4 = self._make_layer(
            block, 512, layers[3], stride=2, dilate=dilation_flags[2])

        # Head: global average pooling followed by the classifier.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Default initialization: Kaiming-normal (fan-out, ReLU) for every
        # convolution, unit weight / zero bias for every norm layer.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(
                    module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

        # Zero-initialize the last BN in each residual branch, so that the
        # residual branch starts with zeros and each residual block behaves
        # like an identity. This improves the model by 0.2~0.3% according to
        # https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for module in self.modules():
                if isinstance(module, Bottleneck):
                    last_norm = module.bn3
                elif isinstance(module, BasicBlock):
                    last_norm = module.bn2
                else:
                    continue
                nn.init.constant_(last_norm.weight, 0)