def select_model(m):
    """Build and return the CUDA MNIST model named by ``m``.

    Parameters
    ----------
    m : str
        One of ``'large'``, ``'wide'``, ``'deep'``; any other value
        selects the default small model.

    Returns
    -------
    The requested ``pblm`` model, moved to the GPU via ``.cuda()``.

    NOTE(review): the previous version also built a test loader in the
    'large'/'wide'/'deep' branches (``pblm.mnist_loaders(...)``) and
    discarded it — only ``model`` was ever returned.  That dead work has
    been removed; callers construct their own loaders.  Reads the
    module-level ``args`` for ``model_factor``.
    """
    if m == 'large':
        model = pblm.mnist_model_large().cuda()
    elif m == 'wide':
        print("Using wide model with model_factor={}".format(args.model_factor))
        model = pblm.mnist_model_wide(args.model_factor).cuda()
    elif m == 'deep':
        print("Using deep model with model_factor={}".format(args.model_factor))
        model = pblm.mnist_model_deep(args.model_factor).cuda()
    else:
        # Default: the small convolutional MNIST model.
        model = pblm.mnist_model().cuda()
    return model
# --- Script entry: parse CLI args and derive a run prefix for log files. ---
args = parser.parse_args()
# Default prefix encodes epsilon and lr; dots replaced so it is filename-safe.
args.prefix = args.prefix or 'fashion_mnist_conv_{:.4f}_{:.4f}_0'.format(
    args.epsilon, args.lr).replace(".", "_")
#setproctitle.setproctitle(args.prefix)
# NOTE(review): these log files are never explicitly closed in the visible
# code — confirm they are closed (or flushed) elsewhere.
train_log = open(args.prefix + "_train.log", "w")
test_log = open(args.prefix + "_test.log", "w")
# Separate loader calls for train vs. test; test uses batch size 2
# (presumably to bound memory during robust evaluation — TODO confirm).
train_loader, _ = pblm.fashion_mnist_loaders(args.batch_size)
_, test_loader = pblm.fashion_mnist_loaders(2)
# Seed both CPU and CUDA RNGs before model construction for reproducibility.
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
# Model selection from flags (note: this path does not use select_model()).
if args.large:
    model = pblm.mnist_model_large().cuda()
elif args.vgg:
    model = pblm.mnist_model_vgg().cuda()
else:
    model = pblm.mnist_model().cuda()
opt = optim.Adam(model.parameters(), lr=args.lr)
for t in range(args.epochs):
    # Epsilon warm-up: linearly ramp from starting_epsilon to epsilon over
    # the first half of training, then hold at the final epsilon.
    if t <= args.epochs // 2 and args.starting_epsilon is not None:
        epsilon = args.starting_epsilon + (t / (args.epochs // 2)) * (
            args.epsilon - args.starting_epsilon)
    else:
        epsilon = args.epsilon
    # Robust training step (call continues beyond this chunk).
    train_robust(train_loader, model,