def main():
    """Run the 'trades' adversarial-training experiment and optionally evaluate it.

    Workflow (the helpers used here are defined elsewhere in the project):

    1. ``init_experiment`` builds ``args`` (via ``get_args``), the logger, the
       device, the train/test loaders and the model/optimizer/scheduler
       factories.
    2. A PGD adversary is configured. ``args.epsilon`` and ``args.alpha`` are
       divided by 255 and then by the per-channel ``STD`` tensor, presumably
       converting them from 0-255 pixel units into the normalised input space.
    3. ``fit`` trains the model with ``train_step`` bound to the PGD keyword
       arguments and is given
       ``model_path=<args.out_dir>/model_preact_resnet18_trades.pt``.
    4. Unless ``args.no_eval`` is set, the best state dict returned by ``fit``
       is loaded into a fresh model and passed to ``evaluate``.

    Raises:
        NotImplementedError: if ``args.geometry`` is not ``'linf'``.
    """
    name = 'trades'

    # set up the experiment
    (args, logger, device,
     (train_loader, test_loader),
     (model_factory, optimizer_factory, scheduler_factory)) = init_experiment(
        args_factory=get_args, name=name)
    model = model_factory()
    opt = optimizer_factory(model)
    scheduler = scheduler_factory(opt)

    ### training adversary config
    # Explicit raise instead of `assert`: asserts are stripped under `python -O`,
    # which would let an unsupported geometry continue silently.
    if args.geometry != 'linf':
        raise NotImplementedError(
            f'l2 adversary not supported yet (got geometry={args.geometry!r})')

    # Normalisation constants and input bounds, moved onto the training device.
    std = STD.to(device)
    upper_limit = UPPER_LIMIT.to(device)
    lower_limit = LOWER_LIMIT.to(device)
    epsilon = (args.epsilon / 255.) / std

    # parameters for PGD training
    pgd_train_kwargs = dict(
        reg_weight=args.reg_weight,
        geometry=args.geometry,
        epsilon=epsilon,
        alpha=(args.alpha / 255.) / std,
        lower_limit=lower_limit,
        upper_limit=upper_limit,
        attack_iters=args.attack_iters,
        criterion=nn.CrossEntropyLoss())

    # parameters for PGD early stopping: same epsilon as training, but a fixed
    # step size (2/255), 5 iterations and a single restart
    pgd_kwargs = dict(
        epsilon=epsilon,
        alpha=(2 / 255.) / std,
        lower_limit=lower_limit,
        upper_limit=upper_limit,
        attack_iters=5,
        restarts=1)

    # training
    model, best_state_dict = fit(
        step=functools.partial(train_step, **pgd_train_kwargs),
        epochs=args.epochs,
        model=model,
        optimizer=opt,
        scheduler=scheduler,
        data_loader=train_loader,
        model_path=os.path.join(args.out_dir, f'model_preact_resnet18_{name}.pt'),
        logger=logger,
        early_stop=args.early_stop,
        pgd_kwargs=pgd_kwargs)

    # eval: reload the best weights into a fresh model instance
    if not args.no_eval:
        model_test = model_factory()
        model_test.load_state_dict(best_state_dict)
        # NOTE(review): `verbose` receives `args.no_verbose`; this is only correct
        # if `no_verbose` is declared with action='store_false' (or similar) in
        # get_args — confirm.
        evaluate(model=model_test,
                 test_loader=test_loader,
                 upper_limit=upper_limit,
                 lower_limit=lower_limit,
                 verbose=args.no_verbose)
def main():
    """Run the 'distilled' experiment: train against a pretrained model, then evaluate.

    If ``args.pretrained_model`` is ``None``, a pretrained model is first
    produced by running ``train_standard.py`` in a subprocess and loaded from
    ``<args.out_dir>/model_preact_resnet18_standard.pt``. Otherwise the given
    path is loaded directly. The pretrained model and
    ``args.softmax_temperature`` are bound into ``train_step``, which is passed
    to ``fit``. Unless ``args.no_eval`` is set, the best state dict returned by
    ``fit`` is loaded into a fresh model and passed to ``evaluate``.

    Raises:
        ValueError: if ``args.early_stop`` is set. Distillation is treated as
            non-robust here, so early stopping wrt adversarial accuracy is
            disallowed.
        subprocess.CalledProcessError: if the pretraining subprocess exits
            with a non-zero status.
    """
    import sys  # local import: only needed to locate the running interpreter

    name = 'distilled'

    # set up the experiment
    (args, logger, device,
     (train_loader, test_loader),
     (model_factory, optimizer_factory, scheduler_factory)) = init_experiment(
        args_factory=get_args, name=name)

    # Explicit raise instead of `assert` so the check survives `python -O`.
    if args.early_stop:
        raise ValueError(
            'Model distillation is non-robust, do not early stop wrt adversarial accuracy.')

    if args.pretrained_model is None:
        logger.info('No pretrained model specified... pretraining model via `train_standard.py`.')
        # `sys.executable` reuses this interpreter/venv; a bare 'python' may
        # resolve to a different one. `check=True` aborts here on failure
        # instead of failing later with a confusing missing-checkpoint error.
        # NOTE(review): 'train_standard.py' is resolved relative to the current
        # working directory, and `args.out_dir` is not forwarded. This assumes the
        # standard script saves into the same out_dir by default — confirm.
        subprocess.run(
            [sys.executable, 'train_standard.py',
             '--epochs', str(args.epochs),
             '--no-eval',
             '--temperature', str(args.softmax_temperature)],
            check=True)
        pretrained_model_path = os.path.join(args.out_dir, 'model_preact_resnet18_standard.pt')
    else:
        logger.info(f'Pretrained model specified as `{args.pretrained_model}`... skipping pretraining.')
        pretrained_model_path = args.pretrained_model

    pretrained_model = load_model(
        path=pretrained_model_path,
        model_factory=model_factory,
        device=device)

    model = model_factory()
    opt = optimizer_factory(model)
    scheduler = scheduler_factory(opt)

    ### train model
    model, best_state_dict = fit(
        step=functools.partial(train_step,
                               pretrained_model=pretrained_model,
                               softmax_temperature=args.softmax_temperature),
        epochs=args.epochs,
        model=model,
        optimizer=opt,
        scheduler=scheduler,
        data_loader=train_loader,
        model_path=os.path.join(args.out_dir, f'model_preact_resnet18_{name}.pt'),
        logger=logger,
        early_stop=args.early_stop)

    ### evaluate model
    if not args.no_eval:
        model_test = model_factory()
        model_test.load_state_dict(best_state_dict)
        # NOTE(review): unlike the 'trades' script, the limits are passed without
        # `.to(device)`. Confirm that evaluate handles device placement, and that
        # `verbose=args.no_verbose` is not inverted.
        evaluate(
            model=model_test,
            test_loader=test_loader,
            upper_limit=UPPER_LIMIT,
            lower_limit=LOWER_LIMIT,
            verbose=args.no_verbose)
def main():
    """Run the 'label_smoothing' experiment: train with label smoothing, then evaluate.

    ``train_step`` is bound to ``args.smoothing`` and passed to ``fit``, which
    is given ``model_path=<args.out_dir>/model_preact_resnet18_label_smoothing.pt``.
    Unless ``args.no_eval`` is set, the best state dict returned by ``fit`` is
    loaded into a fresh model and passed to ``evaluate``.

    Raises:
        ValueError: if ``args.early_stop`` is set. Label-smoothed training is
            treated as non-robust here, so early stopping wrt adversarial
            accuracy is disallowed.
    """
    name = 'label_smoothing'

    # set up the experiment
    (args, logger, device,
     (train_loader, test_loader),
     (model_factory, optimizer_factory, scheduler_factory)) = init_experiment(
        args_factory=get_args, name=name)

    # Explicit raise instead of `assert` so the check survives `python -O`.
    # The message previously said "Model distillation" (copy-paste from the
    # distillation script); it now names this experiment.
    if args.early_stop:
        raise ValueError(
            'Label smoothing is non-robust, do not early stop wrt adversarial accuracy.')

    model = model_factory()
    opt = optimizer_factory(model)
    scheduler = scheduler_factory(opt)

    # training
    model, best_state_dict = fit(
        step=functools.partial(train_step, smoothing=args.smoothing),
        epochs=args.epochs,
        model=model,
        optimizer=opt,
        scheduler=scheduler,
        data_loader=train_loader,
        model_path=os.path.join(args.out_dir, f'model_preact_resnet18_{name}.pt'),
        logger=logger,
        early_stop=args.early_stop)

    # eval: reload the best weights into a fresh model instance
    if not args.no_eval:
        model_test = model_factory()
        model_test.load_state_dict(best_state_dict)
        # NOTE(review): the limits are passed without `.to(device)`. Confirm that
        # evaluate handles device placement, and that `verbose=args.no_verbose`
        # is not inverted.
        evaluate(model=model_test,
                 test_loader=test_loader,
                 upper_limit=UPPER_LIMIT,
                 lower_limit=LOWER_LIMIT,
                 verbose=args.no_verbose)