Example #1
0
def main():
    """Train (or load) and evaluate an MNIST model with adversarial guided training.

    Steps:
      1. Parse CLI options via ``options()`` and seed torch with ``args.seed``.
      2. Build MNIST train/test loaders from ``../data`` (the train split is
         downloaded if needed).
      3. Instantiate the model selected by ``args.model``
         (``'cnn'`` or ``'cnn_leaky_relu'``).
      4. With ``args.load_model``, load weights from
         ``mnist_cnn_adv_guided.pt``. Otherwise train with
         ``AdversarialGuidedTrain`` (SGD + StepLR), optionally saving to the
         same file.
      5. Run ``evaluation`` on the test loader.

    Other training regimes are kept as commented-out experiments. Each one
    restores ``start_point`` so every run begins from the same initial weights.

    Raises:
        SystemExit: with status 1 if ``args.model`` is not a known model name.
    """
    args = options()

    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    # Shared preprocessing for both splits: ToTensor, then per-channel
    # (x - 0.5) / 0.5 normalization.
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, ), (0.5, )),
    ])

    train_set = preload.datasets.MNISTDataset('../data',
                                              train=True,
                                              download=True,
                                              transform=mnist_transform)

    train_loader = preload.dataloader.DataLoader(train_set,
                                                 batch_size=args.batch_size)

    test_loader = preload.dataloader.DataLoader(
        preload.datasets.MNISTDataset('../data',
                                      train=False,
                                      transform=mnist_transform),
        batch_size=args.test_batch_size)

    if args.model == 'cnn':
        model = CNN().to(device)
    elif args.model == 'cnn_leaky_relu':
        # Instantiate first: calling ``.to`` on the class itself does not
        # produce a model instance.
        model = CNNLeakyReLU().to(device)
    else:
        print("model error")
        # Non-zero exit status so shells/callers can detect the failure.
        raise SystemExit(1)
    # Snapshot of the initial weights, so each training regime can be
    # restarted from the same point via load_state_dict(start_point).
    start_point = copy.deepcopy(model.state_dict())

    # print("\nNormal training:")
    # if args.load_model:
    #     model.load_state_dict(torch.load("mnist_cnn.pt"))
    # else:
    #     model.load_state_dict(start_point)
    #     optimizer = optim.SGD(model.parameters(), lr=args.lr)
    #     scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    #     normal_method = NormalTrain(model, device, train_loader, optimizer)
    #     model_training(args, model, normal_method, device, test_loader, scheduler)
    #     if args.save_model:
    #         torch.save(model.state_dict(), "mnist_cnn.pt")
    # evaluation(args, model, device, test_loader)

    # print("\nNormal training with L2 regularization:")
    # if args.load_model:
    #     model.load_state_dict(torch.load("mnist_cnn_l2_regular.pt"))
    # else:
    #     model.load_state_dict(start_point)
    #     optimizer = optim.SGD(model.parameters(), lr=args.lr)
    #     scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    #     l2_method = L2RegularTrain(model,
    #                                 device,
    #                                 train_loader,
    #                                 optimizer,
    #                                 weight_decay=args.weight_decay)
    #     model_training(args, model, l2_method, device, test_loader, scheduler)
    #     if args.save_model:
    #         torch.save(model.state_dict(), "mnist_cnn_l2_regular.pt")
    # evaluation(args, model, device, test_loader)

    # print("\nTraining with adversarial gradient regularization:")
    # if args.load_model:
    #     model.load_state_dict(torch.load("mnist_cnn_adv_grad_regular.pt"))
    # else:
    #     model.load_state_dict(start_point)
    #     optimizer = optim.SGD(model.parameters(), lr=args.lr)
    #     scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    #     adv_grad_reg_method = AdversarialGradientRegularTrain(model,
    #                                                         device,
    #                                                         train_loader,
    #                                                         optimizer,
    #                                                         gradient_decay=args.gradient_decay)
    #     model_training(args, model, adv_grad_reg_method, device, test_loader, scheduler)
    #     if args.save_model:
    #         torch.save(model.state_dict(), "mnist_cnn_adv_grad_regular.pt")
    # evaluation(args, model, device, test_loader)

    # print("\nAdversarial training (FGSM):")
    # if args.load_model:
    #     model.load_state_dict(torch.load("mnist_cnn_fgsm.pt"))
    # else:
    #     model.load_state_dict(start_point)
    #     optimizer = optim.SGD(model.parameters(), lr=args.lr)
    #     scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    #     fgsm = FastGradientSignMethod(lf=F.nll_loss, eps=args.eps)
    #     adv_method = AdversarialTrain(model, device, train_loader, optimizer, attack=fgsm)
    #     model_training(args, model, adv_method, device, test_loader, scheduler)
    #     if args.save_model:
    #         torch.save(model.state_dict(), "mnist_cnn_{}.pt".format(fgsm.name))
    # evaluation(args, model, device, test_loader)

    # print("\nAdversarial training (BIM):")
    # if args.load_model:
    #     model.load_state_dict(torch.load("mnist_cnn_bim.pt"))
    # else:
    #     model.load_state_dict(start_point)
    #     optimizer = optim.SGD(model.parameters(), lr=args.lr)
    #     scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    #     bim = BasicIterativeMethod(lf=F.nll_loss, eps=args.eps, alpha=args.alpha, iter_max=args.iter_max)
    #     adv_method = AdversarialTrain(model, device, train_loader, optimizer, attack=bim)
    #     model_training(args, model, adv_method, device, test_loader, scheduler)
    #     if args.save_model:
    #         torch.save(model.state_dict(), "mnist_cnn_{}.pt".format(bim.name))
    # evaluation(args, model, device, test_loader)

    # print("\nAdversarial training (PGD):")
    # if args.load_model:
    #     model.load_state_dict(torch.load("mnist_cnn_pgd.pt"))
    # else:
    #     model.load_state_dict(start_point)
    #     optimizer = optim.SGD(model.parameters(), lr=args.lr)
    #     scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    #     pgd = ProjectedGradientDescent(lf=F.nll_loss, eps=args.eps, alpha=args.alpha, iter_max=args.iter_max)
    #     adv_method = AdversarialTrain(model, device, train_loader, optimizer, attack=pgd)
    #     model_training(args, model, adv_method, device, test_loader, scheduler)
    #     if args.save_model:
    #         torch.save(model.state_dict(), "mnist_cnn_{}.pt".format(pgd.name))
    # evaluation(args, model, device, test_loader)

    print("\nAdversarial guided training (FGSM):")
    if args.load_model:
        # map_location lets a checkpoint saved on GPU load on a CPU-only run
        # (and vice versa).
        model.load_state_dict(
            torch.load("mnist_cnn_adv_guided.pt", map_location=device))
    else:
        model.load_state_dict(start_point)
        optimizer = optim.SGD(model.parameters(), lr=args.lr)
        # step_size=1: the learning rate is multiplied by args.gamma after
        # every scheduler step.
        scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
        guide_sets = make_guide_set(train_set, size=1000)
        adv_guided_method = AdversarialGuidedTrain(
            model,
            device,
            train_loader,
            optimizer,
            guide_sets=guide_sets,
            epsilon=args.eps,
            beta=args.beta,
            weight_decay=args.weight_decay,
            gradient_decay=args.gradient_decay)
        model_training(args, model, adv_guided_method, device, test_loader,
                       scheduler)
        if args.save_model:
            torch.save(model.state_dict(), "mnist_cnn_adv_guided.pt")
    evaluation(args, model, device, test_loader)
# Script section: parse --models/--dataset, build each requested model with
# the dataset's input/output sizes from cfg, and save its freshly
# constructed state_dict under ../ckpts/init/.
import os

ap = argparse.ArgumentParser()
ap.add_argument('--models', required=True, type=str, nargs="+",
                help="models to initialize, e.g. cnn fcn svm")
ap.add_argument('--dataset', required=True, type=str,
                help="dataset name; key into cfg.input_sizes / cfg.output_sizes")
args = Struct(**vars(ap.parse_args()))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# torch.save does not create missing parent directories, so make sure the
# checkpoint directory exists before any model is saved.
os.makedirs('../ckpts/init', exist_ok=True)

if 'cnn' in args.models:
    print("Initializing CNN...")
    model = CNN(cfg.input_sizes[args.dataset],
                cfg.output_sizes[args.dataset]).to(device)
    # NOTE(review): prints 'NA' in place of an input size, presumably because
    # CNN does not expose input_size; confirm against the CNN class.
    print('NA', model.output_size)
    init_path = '../ckpts/init/{}_cnn.init'.format(args.dataset)
    torch.save(model.state_dict(), init_path)
    print('Save init: {}'.format(init_path))

if 'fcn' in args.models:
    print("Initializing FCN...")
    model = FCN(cfg.input_sizes[args.dataset],
                cfg.output_sizes[args.dataset]).to(device)
    print(model.input_size, model.output_size)
    init_path = '../ckpts/init/{}_fcn.init'.format(args.dataset)
    torch.save(model.state_dict(), init_path)
    print('Save init: {}'.format(init_path))
if 'svm' in args.models:
    print("Initializing SVM...")
    model = SVM(cfg.input_sizes[args.dataset],
                cfg.output_sizes[args.dataset]).to(device)