def setup_store_with_metadata(args):
    '''
    Create a cox store for a training run and record its arguments.

    The store lives at ``args.out_dir`` under the experiment name
    ``args.exp_name``. A ``metadata`` table is created whose schema is
    inferred from the attributes of ``args``, and a single row holding
    those attributes is appended. See the argparse definition for the
    available options.

    Args:
        args: the parsed arguments object.

    Returns:
        The newly created ``cox.store.Store``.
    '''
    store = cox.store.Store(args.out_dir, args.exp_name)

    metadata = args.__dict__
    store.add_table('metadata', cox.store.schema_from_dict(metadata))
    store['metadata'].append_row(metadata)

    return store
def evaluate(args, store):
    '''
    Evaluate a boosted model under every corruption/severity pair and on clean data.

    For each corruption in ``FIXED_CONFIG[ds.ds_name]`` and each severity in
    1..5, the boosted model is configured with that single augmentation and
    evaluated on the validation loader. The ``nat_prec1`` value returned by
    ``train.eval_model`` is appended to the ``corruptions_eval`` table of
    ``store``. A final row with ``type='clean'`` and ``severity=-1`` records
    the result with no augmentation applied.

    Side effects:
      * creates the ``corruptions_eval`` table if the store lacks it;
      * appends ``'_eval'`` to ``store.save_dir`` and creates that directory;
      * sets ``args.iteration_hook``. When the model has a booster, the hook
        runs every 50 iterations and moves the batch to CUDA. It then writes
        the first four boosted images, input images, and outputs of
        ``boosted_model.apply`` into ``store.save_dir``.

    Args:
        args: experiment arguments (``dataset``, ``arch`` and whatever
            ``get_dataset_and_loaders`` / ``get_boosted_model`` read).
        store: cox store that receives the results.
    '''
    # Only the validation loader is needed for evaluation.
    ds, (_, val_loader) = get_dataset_and_loaders(args)
    if 'corruptions_eval' not in store.keys:
        store.add_table('corruptions_eval',
                        {'type': str, 'severity': int, 'corr_acc': float})

    store.save_dir = store.save_dir + '_eval'
    # exist_ok=True replaces the racy "exists? then makedirs" check.
    os.makedirs(store.save_dir, exist_ok=True)

    model = get_boosted_model(args, ds)

    def iteration_hook(model, i, loop_type, inp, target):
        # The hook receives a wrapped model (hence ``.module``); presumably
        # the DataParallel wrapper added by train.eval_model -- confirm.
        if model.module.booster is None:
            return
        if i % 50 == 0:
            inp, target = inp.cuda(), target.cuda()
            out_dir = Path(store.save_dir)
            example_boosted = model.module.booster(inp, target)
            save_image(example_boosted[:4], out_dir / f'boosted_{i}.jpg')
            example_adversaried = model.module.boosted_model.apply(example_boosted)
            save_image(inp[:4], out_dir / f'inp_{i}.jpg')
            save_image(example_adversaried[:4], out_dir / f'adv_{i}.jpg')
            if i == 0:
                print(f'Saved in {store.save_dir}')

    args.iteration_hook = iteration_hook

    def run_eval(corruption, severity):
        """Print a banner for the current setting, evaluate, and return ``nat_prec1``."""
        print('---------------------------------------------------')
        print(f"Dataset: {args.dataset} | Model: {args.arch} | Corruption: {corruption} | Severity: {severity}")
        print('---------------------------------------------------')
        result = train.eval_model(args, model, val_loader, store=None)
        return result['nat_prec1']

    with ch.no_grad():
        # Evaluate on the corrupted boosted dataset.
        for corr in FIXED_CONFIG[ds.ds_name]:
            for severity in range(1, 6):
                model.boosted_model.augmentations = [corr]
                model.boosted_model.severity = severity
                acc = run_eval(corr, severity)
                # The corruption name is stored truncated to 14 characters;
                # presumably a column-width limit -- confirm before changing.
                store['corruptions_eval'].append_row(
                    {'type': corr[:14], 'severity': severity, 'corr_acc': acc})

        # Evaluate on the clean boosted dataset (no augmentation).
        model.boosted_model.augmentations = []
        acc = run_eval('None', 'N/A')
        store['corruptions_eval'].append_row(
            {'type': 'clean', 'severity': -1, 'corr_acc': acc})
def setup_store_with_metadata(args):
    '''
    Sets up a store for training according to the arguments object.
    See the argparse object above for options.

    Before logging, ``args.git_commit`` is set to the HEAD commit hash of
    the git repository that contains this file. When the file is not inside
    a git checkout, ``'unknown'`` is recorded instead of aborting the run.

    Args:
        args: the arguments object; must provide ``out_dir``, ``exp_name``
            and ``as_dict()``.

    Returns:
        A new ``cox.store.Store`` whose ``metadata`` table holds one row
        with every argument (including ``git_commit``).
    '''
    # Add git commit to args
    try:
        repo = git.Repo(path=os.path.dirname(os.path.realpath(__file__)),
                        search_parent_directories=True)
        git_commit = repo.head.object.hexsha
    except git.exc.InvalidGitRepositoryError:
        # E.g. running from an installed copy rather than a checkout. A string
        # placeholder keeps the column type the same as a real hexsha.
        git_commit = 'unknown'
    args.git_commit = git_commit

    # Create the store
    store = cox.store.Store(args.out_dir, args.exp_name)
    args_dict = args.as_dict()
    schema = cox.store.schema_from_dict(args_dict)
    store.add_table('metadata', schema)
    store['metadata'].append_row(args_dict)
    return store
def setup_store_with_metadata(args):
    '''
    Create a cox store for this run and log the full argument set to it.

    Before logging, ``args.version`` is set to the HEAD commit hash of the
    git repository containing this file. When the file is not inside a git
    checkout, the module-level ``__version__`` is used instead.

    Args:
        args: the parsed arguments object (see the argparse definition for
            the options); ``out_dir`` and ``exp_name`` locate the store.

    Returns:
        The newly created ``cox.store.Store`` with a ``metadata`` table
        holding one row of ``args``.
    '''
    source_dir = os.path.dirname(os.path.realpath(__file__))
    try:
        head = git.Repo(path=source_dir, search_parent_directories=True).head
        args.version = head.object.hexsha
    except git.exc.InvalidGitRepositoryError:
        # Not a git checkout: fall back to the released version string.
        args.version = __version__

    store = cox.store.Store(args.out_dir, args.exp_name)

    metadata = args.__dict__
    store.add_table('metadata', cox.store.schema_from_dict(metadata))
    store['metadata'].append_row(metadata)

    return store
# Datasets for which per-class accuracy reporting is supported by main().
_PER_CLASS_ACCURACY_DATASETS = (
    "pets",
    "caltech101",
    "caltech256",
    "flowers",
    "aircraft",
)


def _translate_checkpoint(model_path):
    """Rewrite the checkpoint at ``model_path`` into a new key layout and return its path.

    For every ``state_dict`` key, the first ``len("1.module.")`` characters
    are dropped. The value is then stored twice: under
    ``module.attacker.model.<rest>`` and under ``module.model.<rest>``. The
    result, with the original ``epoch``, is saved to
    ``/tmp/checkpoint<md5 of model_path>``.
    """
    path_hash = hashlib.md5(model_path.encode("utf-8")).hexdigest()
    translated_model_path = f"/tmp/checkpoint{path_hash}"
    g = ch.load(model_path)
    # NOTE(review): the prefix is sliced off by length, not checked -- every
    # key is assumed to start with "1.module.".
    prefix_len = len("1.module.")
    sd = {}
    for k, v in g["state_dict"].items():
        kk = k[prefix_len:]
        sd[f"module.attacker.model.{kk}"] = v
        sd[f"module.model.{kk}"] = v
    ch.save({"state_dict": sd, "epoch": g["epoch"]}, translated_model_path)
    return translated_model_path


def main():
    """Translate the configured checkpoint, log the arguments, then evaluate or train.

    The arguments come from the module-level ``config``. They are written
    to the ``metadata`` table of the cox store under ``output_dir/cox``;
    this step is skipped if that table already exists. If
    ``args.eval_only`` is set, the result of ``train.eval_model`` is
    returned. Otherwise the model is frozen according to
    ``args.freeze_level`` and trained with ``train.train_model``.

    Raises:
        ValueError: if ``args.per_class_accuracy`` is set for a dataset that
            does not support it.
    """
    args = Bunch(config)

    print("Translating model file")
    args.__dict__["model_path"] = _translate_checkpoint(args.model_path)
    print("Done translating")

    # Create store and log the args (only once per store).
    store = StoreWrapper(os.path.join(output_dir, "cox"))
    if "metadata" not in store.keys:
        args_dict = args.__dict__
        schema = cox.store.schema_from_dict(args_dict)
        store.add_table("metadata", schema)
        store["metadata"].append_row(args_dict)
    else:
        print("[Found existing metadata in store. Skipping this part.]")

    ds, train_loader, validation_loader = get_dataset_and_loaders(args)

    if args.per_class_accuracy:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if args.dataset not in _PER_CLASS_ACCURACY_DATASETS:
            raise ValueError(
                f"Per-class accuracy not supported for the {args.dataset} dataset.")
        # VERY IMPORTANT
        # Per-class accuracy is reported using the validation-set class
        # distribution. Ignore the training accuracy: it can exceed 100 and
        # does not capture anything meaningful. Look only at the validation
        # accuracy.
        args.custom_accuracy = get_per_class_accuracy(args, validation_loader)

    model, checkpoint = get_model(args, ds)

    if args.eval_only:
        return train.eval_model(args, model, validation_loader, store=store)

    update_params = freeze_model(model, freeze_level=args.freeze_level)

    log_info({"state.progress": 0.0})
    print(f"Dataset: {args.dataset} | Model: {args.arch}")
    train.train_model(
        args,
        model,
        (train_loader, validation_loader),
        store=store,
        checkpoint=checkpoint,
        update_params=update_params,
    )
return args


# Script entry point: parse and preprocess the CLI arguments, open the cox
# store, record the arguments once, then hand off to main().
if __name__ == "__main__":
    args = parser.parse_args()
    args = args_preprocess(args)

    # Short architecture names mapped to torchvision model constructors.
    # NOTE(review): not referenced in the lines shown here; presumably read
    # by main() as a module-level global -- verify before removing.
    pytorch_models = {
        'alexnet': models.alexnet,
        'vgg16': models.vgg16,
        'vgg16_bn': models.vgg16_bn,
        'squeezenet': models.squeezenet1_0,
        'densenet': models.densenet161,
        'shufflenet': models.shufflenet_v2_x1_0,
        'mobilenet': models.mobilenet_v2,
        'resnext50_32x4d': models.resnext50_32x4d,
        'mnasnet': models.mnasnet1_0,
    }

    # Create store and log the args. If the store already has a metadata
    # table (e.g. a re-run with the same out_dir/exp_name), no second row
    # is appended.
    store = cox.store.Store(args.out_dir, args.exp_name)
    if 'metadata' not in store.keys:
        args_dict = args.__dict__
        schema = cox.store.schema_from_dict(args_dict)
        store.add_table('metadata', schema)
        store['metadata'].append_row(args_dict)
    else:
        print('[Found existing metadata in store. Skipping this part.]')

    main(args, store)
# Useful for evaluation QRCodes since they are not robust to occlustions at all if args.no_translation: constants.PATCH_TRANSFORMS['translate'] = (0., 0.) # Preprocess args args = defaults.check_and_fill_args( args, defaults.CONFIG_ARGS, datasets.DATASETS[args.dataset]) args = defaults.check_and_fill_args( args, defaults.MODEL_LOADER_ARGS, datasets.DATASETS[args.dataset]) store = cox.store.Store(args.out_dir, args.exp_name) if args.args_from_store: args_from_store = args.args_from_store.split(',') df = store['metadata'].df print(f'==>[Reading from existing store in {store.path}]') for a in args_from_store: if a not in df: raise Exception(f'Did not find {a} in the store {store.path}') setattr(args,a, df[a][0]) print(f'==>[Read {a} = ({getattr(args, a)}) ') if 'metadata_eval' not in store.keys: args_dict = args.__dict__ schema = cox.store.schema_from_dict(args_dict) store.add_table('metadata_eval', schema) store['metadata_eval'].append_row(args_dict) else: print('[Found existing metadata_eval in store. Skipping this part.]') evaluate(args, store)
def train_model(args, model, loaders, *, checkpoint=None, store=None,
                update_params=None):
    '''
    Train ``model``, periodically evaluating, logging and checkpointing.

    The model is wrapped in ``ch.nn.DataParallel`` and moved to CUDA. Each
    epoch from ``start_epoch`` to ``args.epochs - 1`` runs one training
    pass via ``_model_loop``. Some epochs also get evaluation, logging and
    checkpointing; these are:
      * logging epochs (``epoch % args.log_iters == 0``);
      * checkpoint epochs (``args.save_ckpt_iters > 0`` and
        ``epoch % args.save_ckpt_iters == 0``);
      * the last epoch.

    On those epochs the model is evaluated on the validation loader. The
    adversarial evaluation is added when ``args.adv_eval`` or
    ``args.adv_train`` is set. A row is appended to the logs table and the
    following checkpoints are written:
      * an epoch-specific one on checkpoint and last epochs;
      * the "latest" checkpoint;
      * the "best" checkpoint, when the tracked accuracy improved (adv if
        ``args.adv_train`` else nat).

    Args:
        args: training arguments (validated by ``check_required_args``).
            ``args.eps`` and ``args.attack_lr`` are replaced in place by the
            result of ``eval`` on their string form.
        model: the model to train; must not already be DataParallel.
        loaders: ``(train_loader, val_loader)``.
        checkpoint: optional dict to resume from. Its ``'epoch'`` and its
            ``'adv_prec1'``/``'nat_prec1'`` entry are read here, and it is
            passed to ``make_optimizer_and_schedule``.
        store: optional cox store. When given, the logs and checkpoints
            tables are added and checkpoints are saved under ``store.path``
            instead of ``args.out_dir``.
        update_params: forwarded to ``make_optimizer_and_schedule``.

    Returns:
        The trained model, wrapped in ``ch.nn.DataParallel``.
    '''
    # Logging setup
    writer = store.tensorboard if store else None
    if store is not None:
        store.add_table(consts.LOGS_TABLE, consts.LOGS_SCHEMA)
        store.add_table(consts.CKPTS_TABLE, consts.CKPTS_SCHEMA)

    # Reformat and read arguments
    check_required_args(args)  # Argument sanity check
    # SECURITY: eval() executes arbitrary code. It presumably exists so that
    # eps / attack_lr may be given as expressions (e.g. "8/255"); only pass
    # trusted arguments here.
    args.eps = eval(str(args.eps)) if has_attr(args, 'eps') else None
    args.attack_lr = eval(str(args.attack_lr)) if has_attr(args, 'attack_lr') else None

    # Initial setup
    train_loader, val_loader = loaders
    opt, schedule = make_optimizer_and_schedule(args, model, checkpoint, update_params)

    best_prec1, start_epoch = (0, 0)
    if checkpoint:
        start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint[f"{'adv' if args.adv_train else 'nat'}_prec1"]

    # Put the model into parallel mode
    assert not hasattr(model, "module"), "model is already in DataParallel."
    model = ch.nn.DataParallel(model).cuda()

    def save_checkpoint(sd_info, filename):
        # Checkpoints go to the store directory when a store is in use.
        ckpt_save_path = os.path.join(args.out_dir if not store else store.path,
                                      filename)
        ch.save(sd_info, ckpt_save_path, pickle_module=dill)

    # Bound before the loop: when resuming on an epoch that neither logs nor
    # saves, the epoch hook below would otherwise hit an unbound local.
    log_info = None

    # Timestamp for training start time
    start_time = time.time()

    for epoch in range(start_epoch, args.epochs):
        # train for one epoch
        train_prec1, train_loss = _model_loop(args, 'train', train_loader,
                                              model, opt, epoch, args.adv_train, writer)
        last_epoch = (epoch == (args.epochs - 1))

        save_its = args.save_ckpt_iters
        # Test save_its > 0 first so that save_ckpt_iters == 0 disables
        # checkpointing instead of raising ZeroDivisionError on `epoch % 0`.
        should_save_ckpt = (save_its > 0) and (epoch % save_its == 0)
        should_log = (epoch % args.log_iters == 0)

        if should_log or last_epoch or should_save_ckpt:
            # Training state to checkpoint; only built on epochs that may save it.
            sd_info = {
                'model': model.state_dict(),
                'optimizer': opt.state_dict(),
                'schedule': (schedule and schedule.state_dict()),
                'epoch': epoch + 1,
            }

            # log + get best
            with ch.no_grad():
                prec1, nat_loss = _model_loop(args, 'val', val_loader, model,
                                              None, epoch, False, writer)

            # Adversarial evaluation runs outside no_grad (presumably the
            # attack needs gradients). (-1.0, -1.0) marks "not evaluated".
            should_adv_eval = args.adv_eval or args.adv_train
            adv_val = should_adv_eval and _model_loop(args, 'val', val_loader,
                                                      model, None, epoch, True, writer)
            adv_prec1, adv_loss = adv_val or (-1.0, -1.0)

            # remember best prec@1 and save checkpoint
            our_prec1 = adv_prec1 if args.adv_train else prec1
            is_best = our_prec1 > best_prec1
            best_prec1 = max(our_prec1, best_prec1)

            # log every checkpoint
            log_info = {
                'epoch': epoch + 1,
                'nat_prec1': prec1,
                'adv_prec1': adv_prec1,
                'nat_loss': nat_loss,
                'adv_loss': adv_loss,
                'train_prec1': train_prec1,
                'train_loss': train_loss,
                'time': time.time() - start_time,
            }

            # Log info into the logs table
            if store:
                store[consts.LOGS_TABLE].append_row(log_info)

            # If we are at a saving epoch (or the last epoch), save a checkpoint
            if should_save_ckpt or last_epoch:
                save_checkpoint(sd_info, ckpt_at_epoch(epoch))

            # If store exists and this is the last epoch, record the checkpoint
            if last_epoch and store:
                store[consts.CKPTS_TABLE].append_row(sd_info)

            # Update the latest and best checkpoints (overrides old one)
            save_checkpoint(sd_info, consts.CKPT_NAME_LATEST)
            if is_best:
                save_checkpoint(sd_info, consts.CKPT_NAME_BEST)

        if schedule:
            schedule.step()
        # log_info is the most recent logged dict (possibly from an earlier
        # epoch), or None if nothing has been logged yet in this run.
        if has_attr(args, 'epoch_hook'):
            args.epoch_hook(model, log_info)

    return model