def main():
    """Train a CNN on MNIST with adversarial-guided training and evaluate it.

    Parses CLI options, builds the MNIST train/test loaders, selects the
    model architecture from ``args.model``, then runs
    ``AdversarialGuidedTrain`` (FGSM-guided) and evaluates on the test set.

    Exits non-zero via ``SystemExit`` when ``args.model`` is unrecognized.
    """
    args = options()
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)  # reproducible init / shuffling
    device = torch.device("cuda" if use_cuda else "cpu")

    # Same normalization for train and test splits.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, ), (0.5, )),
    ])
    train_set = preload.datasets.MNISTDataset('../data',
                                              train=True,
                                              download=True,
                                              transform=transform)
    train_loader = preload.dataloader.DataLoader(train_set,
                                                 batch_size=args.batch_size)
    test_loader = preload.dataloader.DataLoader(
        preload.datasets.MNISTDataset('../data',
                                      train=False,
                                      transform=transform),
        batch_size=args.test_batch_size)

    if args.model == 'cnn':
        model = CNN().to(device)
    elif args.model == 'cnn_leaky_relu':
        # BUG FIX: original did `CNNLeakyReLU.to(device)` on the class
        # itself; the model must be instantiated first.
        model = CNNLeakyReLU().to(device)
    else:
        # BUG FIX: original printed "model error" then called exit(), which
        # terminates with status 0; SystemExit with a message exits 1 and
        # writes the message to stderr.
        raise SystemExit("model error")

    # Snapshot the freshly initialized weights so training can restart from
    # a common starting point when --load_model is not given.
    start_point = copy.deepcopy(model.state_dict())

    # (Commented-out experiments — normal, L2-regularized,
    #  adversarial-gradient-regularized, FGSM, BIM and PGD training —
    #  removed; recover them from version control if needed.)

    print("\nAdversarial guided training (FGSM):")
    if args.load_model:
        model.load_state_dict(torch.load("mnist_cnn_adv_guided.pt"))
    else:
        model.load_state_dict(start_point)
    optimizer = optim.SGD(model.parameters(), lr=args.lr)
    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    # Small fixed subset of the training data used to guide the
    # adversarial objective.
    guide_sets = make_guide_set(train_set, size=1000)
    adv_guided_method = AdversarialGuidedTrain(
        model,
        device,
        train_loader,
        optimizer,
        guide_sets=guide_sets,
        epsilon=args.eps,
        beta=args.beta,
        weight_decay=args.weight_decay,
        gradient_decay=args.gradient_decay)
    model_training(args, model, adv_guided_method, device, test_loader,
                   scheduler)
    if args.save_model:
        torch.save(model.state_dict(), "mnist_cnn_adv_guided.pt")
    evaluation(args, model, device, test_loader)
# Initialize fresh model weights for the requested architectures and persist
# them as "init" checkpoints so later training runs can share a common start.
ap = argparse.ArgumentParser()
ap.add_argument('--models', required=True, type=str, nargs="+")
ap.add_argument('--dataset', required=True, type=str)
args = vars(ap.parse_args())
args = Struct(**args)  # attribute-style access to the parsed args

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if 'cnn' in args.models:
    print("Initializing CNN...")
    model = CNN(cfg.input_sizes[args.dataset],
                cfg.output_sizes[args.dataset]).to(device)
    # 'NA' placeholder: unlike FCN below, only the output size is printed
    # for the CNN here.
    print('NA', model.output_size)
    init_path = '../ckpts/init/{}_cnn.init'.format(args.dataset)
    torch.save(model.state_dict(), init_path)
    print('Save init: {}'.format(init_path))

if 'fcn' in args.models:
    print("Initializing FCN...")
    model = FCN(cfg.input_sizes[args.dataset],
                cfg.output_sizes[args.dataset]).to(device)
    print(model.input_size, model.output_size)
    init_path = '../ckpts/init/{}_fcn.init'.format(args.dataset)
    torch.save(model.state_dict(), init_path)
    print('Save init: {}'.format(init_path))

if 'svm' in args.models:
    print("Initializing SVM...")
    model = SVM(cfg.input_sizes[args.dataset],
                cfg.output_sizes[args.dataset]).to(device)
    # NOTE(review): unlike the cnn/fcn branches above, no init checkpoint is
    # saved here in the visible source — this chunk may be truncated; confirm
    # whether a torch.save(...) for the SVM init belongs here.