def setup_store_with_metadata(args):
    '''
    Sets up a store for training according to the arguments object. See the
    argparse object above for options.
    '''
    # Create the store
    store = cox.store.Store(args.out_dir, args.exp_name)
    args_dict = args.__dict__
    schema = cox.store.schema_from_dict(args_dict)
    store.add_table('metadata', schema)
    store['metadata'].append_row(args_dict)
    return store
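A minimal usage sketch for the helper above, assuming cox is installed; the argument names and default values here are placeholders:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--out-dir', default='/tmp/logs')
parser.add_argument('--exp-name', default='demo-run')
parser.add_argument('--lr', type=float, default=0.1)
args = parser.parse_args([])

store = setup_store_with_metadata(args)
print(store['metadata'].df)  # the logged arguments come back as a pandas DataFrame
store.close()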
Example #2
def evaluate(args, store):
    ds, (train_loader, val_loader) = get_dataset_and_loaders(args)

    if 'corruptions_eval' not in store.keys:
        store.add_table('corruptions_eval', {'type': str, 'severity': int, 'corr_acc': float})
    store.save_dir = store.save_dir + '_eval'
    if not os.path.exists(store.save_dir):
        os.makedirs(store.save_dir)
    
    model = get_boosted_model(args, ds) 
    # model.booster = None

    def iteration_hook(model, i, loop_type, inp, target):
        if model.module.booster is None:
            return
        if i % 50 == 0:
            inp, target = inp.cuda(), target.cuda()
            example_boosted = model.module.booster(inp, target)
            bs_path = Path(store.save_dir) / f'boosted_{i}.jpg'
            save_image(example_boosted[:4], bs_path)
            example_adversaried = model.module.boosted_model.apply(example_boosted)
            inp_path = Path(store.save_dir) / f'inp_{i}.jpg'
            adv_path = Path(store.save_dir) / f'adv_{i}.jpg'
            save_image(inp[:4], inp_path)
            save_image(example_adversaried[:4], adv_path)
            if i == 0:
                print(f'Saved in {store.save_dir}')

    args.iteration_hook = iteration_hook

    with ch.no_grad():
        # evaluate on corrupted boosted dataset
        for corr in FIXED_CONFIG[ds.ds_name]:
            for severity in range(1,6):
                print('---------------------------------------------------')
                print(f"Dataset: {args.dataset} | Model: {args.arch} | Corruption: {corr} | Severity: {severity}")
                print('---------------------------------------------------')
                model.boosted_model.augmentations = [corr]
                model.boosted_model.severity = severity
                result = train.eval_model(args, model, val_loader, store=None)
                store['corruptions_eval'].append_row({'type': corr[:14], 'severity': severity, 'corr_acc': result['nat_prec1']})

        # evaluate on the clean boosted dataset
        print('---------------------------------------------------')
        print(f"Dataset: {args.dataset} | Model: {args.arch} | Corruption: None | Severity: N/A")
        print('---------------------------------------------------')
        model.boosted_model.augmentations = []
        result = train.eval_model(args, model, val_loader, store=None)
        store['corruptions_eval'].append_row({'type': 'clean', 'severity': -1, 'corr_acc': result['nat_prec1']})
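Once evaluation finishes, the per-corruption results can be read back out of the store; a sketch that relies only on the corruptions_eval table logged above (cox tables expose their rows as a pandas DataFrame via .df):

df = store['corruptions_eval'].df
# Mean accuracy per corruption type; the clean baseline is the row with severity -1
print(df[df['severity'] > 0].groupby('type')['corr_acc'].mean())
print('clean:', df[df['type'] == 'clean']['corr_acc'].values)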
Example #3
def setup_store_with_metadata(args):
    '''
    Sets up a store for training according to the arguments object. See the
    argparse object above for options.
    '''
    # Add git commit to args
    repo = git.Repo(path=os.path.dirname(os.path.realpath(__file__)),
                        search_parent_directories=True)
    git_commit = repo.head.object.hexsha
    args.git_commit = git_commit

    # Create the store
    store = cox.store.Store(args.out_dir, args.exp_name)
    args_dict = args.as_dict()
    schema = cox.store.schema_from_dict(args_dict)
    store.add_table('metadata', schema)
    store['metadata'].append_row(args_dict)

    return store
Example #4
def setup_store_with_metadata(args):
    '''
    Sets up a store for training according to the arguments object. See the
    argparse object above for options.
    '''
    # Add git commit to args
    try:
        repo = git.Repo(path=os.path.dirname(os.path.realpath(__file__)),
                        search_parent_directories=True)
        version = repo.head.object.hexsha
    except git.exc.InvalidGitRepositoryError:
        version = __version__
    args.version = version

    # Create the store
    store = cox.store.Store(args.out_dir, args.exp_name)
    args_dict = args.__dict__
    schema = cox.store.schema_from_dict(args_dict)
    store.add_table('metadata', schema)
    store['metadata'].append_row(args_dict)

    return store
Example #5
def main():
    args = Bunch(config)

    print("Translating model file")
    path_hash = hashlib.md5(args.model_path.encode("utf-8")).hexdigest()
    translated_model_path = f"/tmp/checkpoint{path_hash}"
    g = ch.load(args.model_path)
    sd = {}
    for k, v in g["state_dict"].items():
        kk = k[len("1.module."):]
        sd[f"module.attacker.model.{kk}"] = v
        sd[f"module.model.{kk}"] = v
    ch.save({"state_dict": sd, "epoch": g["epoch"]}, translated_model_path)
    args.__dict__["model_path"] = translated_model_path
    print("Done translating")

    # Create store and log the args
    store = StoreWrapper(os.path.join(output_dir, "cox"))
    if "metadata" not in store.keys:
        args_dict = args.__dict__
        schema = cox.store.schema_from_dict(args_dict)
        store.add_table("metadata", schema)
        store["metadata"].append_row(args_dict)
    else:
        print("[Found existing metadata in store. Skipping this part.]")

    ds, train_loader, validation_loader = get_dataset_and_loaders(args)

    if args.per_class_accuracy:
        assert args.dataset in [
            "pets",
            "caltech101",
            "caltech256",
            "flowers",
            "aircraft",
        ], f"Per-class accuracy not supported for the {args.dataset} dataset."

        # VERY IMPORTANT
        # We report the per-class accuracy using the validation set
        # distribution, so ignore the training accuracy (you will see it go
        # beyond 100; don't freak out, it doesn't capture anything meaningful)
        # and just look at the validation accuracy.
        args.custom_accuracy = get_per_class_accuracy(args, validation_loader)

    model, checkpoint = get_model(args, ds)

    if args.eval_only:
        return train.eval_model(args, model, validation_loader, store=store)

    update_params = freeze_model(model, freeze_level=args.freeze_level)

    log_info({"state.progress": 0.0})
    print(f"Dataset: {args.dataset} | Model: {args.arch}")
    train.train_model(
        args,
        model,
        (train_loader, validation_loader),
        store=store,
        checkpoint=checkpoint,
        update_params=update_params,
    )
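The state-dict translation at the top of main() duplicates every weight under the two prefixes that the robustness library's DataParallel-wrapped AttackerModel looks for (module.model.* and module.attacker.model.*). The same idea as a standalone helper, with the strip_prefix default taken from the snippet above:

import torch as ch

def translate_state_dict(ckpt_path, strip_prefix="1.module."):
    # Re-key a plain checkpoint so it loads into the AttackerModel wrapper.
    g = ch.load(ckpt_path)
    sd = {}
    for k, v in g["state_dict"].items():
        kk = k[len(strip_prefix):]
        sd[f"module.attacker.model.{kk}"] = v
        sd[f"module.model.{kk}"] = v
    return {"state_dict": sd, "epoch": g["epoch"]}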
Example #6
    return args


if __name__ == "__main__":

    args = parser.parse_args()
    args = args_preprocess(args)

    pytorch_models = {
        'alexnet': models.alexnet,
        'vgg16': models.vgg16,
        'vgg16_bn': models.vgg16_bn,
        'squeezenet': models.squeezenet1_0,
        'densenet': models.densenet161,
        'shufflenet': models.shufflenet_v2_x1_0,
        'mobilenet': models.mobilenet_v2,
        'resnext50_32x4d': models.resnext50_32x4d,
        'mnasnet': models.mnasnet1_0,
    }

    # Create store and log the args
    store = cox.store.Store(args.out_dir, args.exp_name)
    if 'metadata' not in store.keys:
        args_dict = args.__dict__
        schema = cox.store.schema_from_dict(args_dict)
        store.add_table('metadata', schema)
        store['metadata'].append_row(args_dict)
    else:
        print('[Found existing metadata in store. Skipping this part.]')
    main(args, store)
Example #7
    # Useful for evaluating QR codes, since they are not robust to occlusions at all
    if args.no_translation:
        constants.PATCH_TRANSFORMS['translate'] = (0., 0.)

    # Preprocess args
    args = defaults.check_and_fill_args(
        args, defaults.CONFIG_ARGS, datasets.DATASETS[args.dataset])
    args = defaults.check_and_fill_args(
        args, defaults.MODEL_LOADER_ARGS, datasets.DATASETS[args.dataset])

    store = cox.store.Store(args.out_dir, args.exp_name)
    if args.args_from_store:
        args_from_store = args.args_from_store.split(',')
        df = store['metadata'].df
        print(f'==>[Reading from existing store in {store.path}]')
        for a in args_from_store:
            if a not in df:
                raise Exception(f'Did not find {a} in the store {store.path}')
            setattr(args, a, df[a][0])
            print(f'==>[Read {a} = ({getattr(args, a)})]')

    if 'metadata_eval' not in store.keys:
        args_dict = args.__dict__
        schema = cox.store.schema_from_dict(args_dict)
        store.add_table('metadata_eval', schema)
        store['metadata_eval'].append_row(args_dict)
    else:
        print('[Found existing metadata_eval in store. Skipping this part.]')

    evaluate(args, store)
Example #8
def train_model(args, model, loaders, *, checkpoint=None,
                store=None, update_params=None):
    # Logging setup
    writer = store.tensorboard if store else None
    if store is not None:
        store.add_table(consts.LOGS_TABLE, consts.LOGS_SCHEMA)
        store.add_table(consts.CKPTS_TABLE, consts.CKPTS_SCHEMA)

    # Reformat and read arguments
    check_required_args(args) # Argument sanity check
    args.eps = eval(str(args.eps)) if has_attr(args, 'eps') else None
    args.attack_lr = eval(str(args.attack_lr)) if has_attr(args, 'attack_lr') else None

    # Initial setup
    train_loader, val_loader = loaders
    opt, schedule = make_optimizer_and_schedule(args, model, checkpoint, update_params)

    best_prec1, start_epoch = (0, 0)
    if checkpoint:
        start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint[f"{'adv' if args.adv_train else 'nat'}_prec1"]

    # Put the model into parallel mode
    assert not hasattr(model, "module"), "model is already in DataParallel."
    model = ch.nn.DataParallel(model).cuda()

    # Timestamp for training start time
    start_time = time.time()

    for epoch in range(start_epoch, args.epochs):
        # train for one epoch
        train_prec1, train_loss = _model_loop(args, 'train', train_loader,
                model, opt, epoch, args.adv_train, writer)
        last_epoch = (epoch == (args.epochs - 1))

        # evaluate on validation set
        sd_info = {
            'model':model.state_dict(),
            'optimizer':opt.state_dict(),
            'schedule':(schedule and schedule.state_dict()),
            'epoch': epoch+1
        }

        def save_checkpoint(filename):
            ckpt_save_path = os.path.join(args.out_dir if not store else \
                                          store.path, filename)
            ch.save(sd_info, ckpt_save_path, pickle_module=dill)

        save_its = args.save_ckpt_iters
        should_save_ckpt = (save_its > 0) and (epoch % save_its == 0)
        should_log = (epoch % args.log_iters == 0)

        if should_log or last_epoch or should_save_ckpt:
            # log + get best
            with ch.no_grad():
                prec1, nat_loss = _model_loop(args, 'val', val_loader, model,
                        None, epoch, False, writer)

            # loader, model, epoch, input_adv_exs
            should_adv_eval = args.adv_eval or args.adv_train
            adv_val = should_adv_eval and _model_loop(args, 'val', val_loader,
                    model, None, epoch, True, writer)
            adv_prec1, adv_loss = adv_val or (-1.0, -1.0)

            # remember best prec@1 and save checkpoint
            our_prec1 = adv_prec1 if args.adv_train else prec1
            is_best = our_prec1 > best_prec1
            best_prec1 = max(our_prec1, best_prec1)

            # log every checkpoint
            log_info = {
                'epoch':epoch + 1,
                'nat_prec1':prec1,
                'adv_prec1':adv_prec1,
                'nat_loss':nat_loss,
                'adv_loss':adv_loss,
                'train_prec1':train_prec1,
                'train_loss':train_loss,
                'time': time.time() - start_time
            }

            # Log info into the logs table
            if store: store[consts.LOGS_TABLE].append_row(log_info)
            # If we are at a saving epoch (or the last epoch), save a checkpoint
            if should_save_ckpt or last_epoch: save_checkpoint(ckpt_at_epoch(epoch))
            # If a store exists and this is the last epoch, log the state dict to the checkpoints table
            if last_epoch and store: store[consts.CKPTS_TABLE].append_row(sd_info)

            # Update the latest and best checkpoints (overrides old one)
            save_checkpoint(consts.CKPT_NAME_LATEST)
            if is_best: save_checkpoint(consts.CKPT_NAME_BEST)

        if schedule: schedule.step()
        if has_attr(args, 'epoch_hook'): args.epoch_hook(model, log_info)

    return model
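A sketch of how train_model might be driven end to end, reusing helpers from the earlier examples; the loader-unpacking shape follows Example #2, and a CUDA device is assumed since train_model wraps the model in DataParallel(...).cuda():

store = setup_store_with_metadata(args)
ds, (train_loader, val_loader) = get_dataset_and_loaders(args)
model, checkpoint = get_model(args, ds)  # checkpoint resumes training if present
model = train_model(args, model, (train_loader, val_loader),
                    checkpoint=checkpoint, store=store)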