Example 1
def main():
    """Adversarially train a PreAct ResNet-18 with TRADES-style l-inf PGD.

    Sets up the experiment (args, logger, device, data loaders, factories),
    builds the PGD configurations for training and for early stopping, runs
    `fit`, and optionally evaluates the best checkpoint on the test set.
    """
    name = 'trades'

    # set up the experiment
    args, logger, device, \
    (train_loader, test_loader), \
    (model_factory, optimizer_factory, scheduler_factory) \
        = init_experiment(args_factory=get_args, name=name)

    model = model_factory()
    opt = optimizer_factory(model)
    scheduler = scheduler_factory(opt)

    ### training adversary config
    # Explicit raise instead of `assert`: asserts are stripped under
    # `python -O`, so input validation must not rely on them.
    if args.geometry != 'linf':
        raise NotImplementedError('l2 adversary not supported yet')
    std = STD.to(device)
    upper_limit = UPPER_LIMIT.to(device)
    lower_limit = LOWER_LIMIT.to(device)
    # Convert the pixel-scale budget (0-255) into the normalized input space.
    epsilon = (args.epsilon / 255.) / std
    # parameters for PGD training
    pgd_train_kwargs = dict(reg_weight=args.reg_weight,
                            geometry=args.geometry,
                            epsilon=epsilon,
                            alpha=(args.alpha / 255.) / std,
                            lower_limit=lower_limit,
                            upper_limit=upper_limit,
                            attack_iters=args.attack_iters,
                            criterion=nn.CrossEntropyLoss())
    # parameters for the (cheaper) PGD attack used for early stopping
    pgd_kwargs = dict(epsilon=epsilon,
                      alpha=(2 / 255.) / std,
                      lower_limit=lower_limit,
                      upper_limit=upper_limit,
                      attack_iters=5,
                      restarts=1)

    # training
    model, best_state_dict = fit(
        step=functools.partial(train_step, **pgd_train_kwargs),
        epochs=args.epochs,
        model=model,
        optimizer=opt,
        scheduler=scheduler,
        data_loader=train_loader,
        model_path=os.path.join(args.out_dir,
                                f'model_preact_resnet18_{name}.pt'),
        logger=logger,
        early_stop=args.early_stop,
        pgd_kwargs=pgd_kwargs)

    # eval: reload the best checkpoint into a fresh model before evaluating
    if not args.no_eval:
        model_test = model_factory()
        model_test.load_state_dict(best_state_dict)
        evaluate(model=model_test,
                 test_loader=test_loader,
                 upper_limit=upper_limit,
                 lower_limit=lower_limit,
                 # NOTE(review): passing `args.no_verbose` as `verbose` looks
                 # inverted — confirm the intended flag semantics with `get_args`.
                 verbose=args.no_verbose)
Example 2
def main():
    """Train a distilled (non-robust) model from a pretrained teacher.

    If no pretrained teacher is supplied, first pretrains one by invoking
    `train_standard.py` as a subprocess; then distills a fresh model against
    the teacher's softened logits and optionally evaluates the result.
    """
    name = 'distilled'

    # set up the experiment
    args, logger, device, \
    (train_loader, test_loader), \
    (model_factory, optimizer_factory, scheduler_factory) \
        = init_experiment(args_factory=get_args, name=name)

    # Explicit raise instead of `assert`: asserts are stripped under
    # `python -O`, so input validation must not rely on them.
    if args.early_stop:
        raise ValueError(
            'Model distillation is non-robust, do not early stop wrt adversarial accuracy.')

    if args.pretrained_model is None:
        logger.info('No pretrained model specified... pretraining model via `train_standard.py`.')
        # check=True: fail loudly here rather than later with a confusing
        # "checkpoint not found" error when loading the teacher.
        subprocess.run([
            'python', 'train_standard.py', 
            '--epochs', str(args.epochs),
            '--no-eval', 
            '--temperature', str(args.softmax_temperature)],
            check=True)
        pretrained_model_path = os.path.join(args.out_dir, 'model_preact_resnet18_standard.pt')
    else:
        logger.info(f'Pretrained model specified as `{args.pretrained_model}`... skipping pretraining.')
        pretrained_model_path = args.pretrained_model

    pretrained_model = load_model(
        path         =pretrained_model_path, 
        model_factory=model_factory, 
        device       =device)

    model     = model_factory()
    opt       = optimizer_factory(model)
    scheduler = scheduler_factory(opt)

    ### train model (distillation against the frozen teacher)
    model, best_state_dict = fit(
        step       =functools.partial(train_step, 
            pretrained_model   =pretrained_model, 
            softmax_temperature=args.softmax_temperature),
        epochs     =args.epochs,
        model      =model,
        optimizer  =opt,
        scheduler  =scheduler,
        data_loader=train_loader,
        model_path =os.path.join(args.out_dir, f'model_preact_resnet18_{name}.pt'),
        logger     =logger,
        early_stop =args.early_stop)

    ### evaluate model: reload the best checkpoint into a fresh model
    if not args.no_eval:
        model_test = model_factory()
        model_test.load_state_dict(best_state_dict)
        evaluate(
            model      =model_test,
            test_loader=test_loader,
            upper_limit=UPPER_LIMIT,
            lower_limit=LOWER_LIMIT,
            # NOTE(review): passing `args.no_verbose` as `verbose` looks
            # inverted — confirm the intended flag semantics with `get_args`.
            verbose    =args.no_verbose)
Example 3
def performance(model_local, data_local):
    """Return evaluation metrics for `model_local` on `data_local`.

    Predictions come from the model's own `predict`; ground-truth labels are
    gathered from the data source in a fixed (unshuffled) order so they line
    up with the predictions.
    """
    preds = model_local.predict(data_generator=data_local)

    # Gather ground-truth labels batch by batch, preserving dataset order.
    label_batches = [
        batch["label"]
        for batch in data_local.generate(batch_size=512, random_shuffle=False)
    ]
    all_labels = np.concatenate(label_batches, axis=0)

    # Labels are one-hot encoded; recover the class indices.
    trues = np.argmax(all_labels, axis=-1)

    return evaluate(preds=preds, trues=trues)
Example 4
def main():
    """Train a model with label smoothing (a non-robust baseline).

    Sets up the experiment, trains with a smoothed cross-entropy objective,
    and optionally evaluates the best checkpoint on the test set.
    """
    name = 'label_smoothing'

    # set up the experiment
    args, logger, device, \
    (train_loader, test_loader), \
    (model_factory, optimizer_factory, scheduler_factory) \
        = init_experiment(args_factory=get_args, name=name)

    # Explicit raise instead of `assert` (asserts are stripped under
    # `python -O`). The original message was copy-pasted from the
    # distillation script; this one names label smoothing.
    if args.early_stop:
        raise ValueError(
            'Label smoothing is non-robust, do not early stop wrt adversarial accuracy.')

    model = model_factory()
    opt = optimizer_factory(model)
    scheduler = scheduler_factory(opt)

    model, best_state_dict = fit(
        step=functools.partial(train_step, smoothing=args.smoothing),
        epochs=args.epochs,
        model=model,
        optimizer=opt,
        scheduler=scheduler,
        data_loader=train_loader,
        model_path=os.path.join(args.out_dir,
                                f'model_preact_resnet18_{name}.pt'),
        logger=logger,
        early_stop=args.early_stop)

    # eval: reload the best checkpoint into a fresh model before evaluating
    if not args.no_eval:
        model_test = model_factory()
        model_test.load_state_dict(best_state_dict)
        evaluate(model=model_test,
                 test_loader=test_loader,
                 upper_limit=UPPER_LIMIT,
                 lower_limit=LOWER_LIMIT,
                 # NOTE(review): passing `args.no_verbose` as `verbose` looks
                 # inverted — confirm the intended flag semantics with `get_args`.
                 verbose=args.no_verbose)
Example 5
def test_evaluate_model():
    """Smoke test: a freshly built model evaluates to a truthy result."""
    assert evaluate(get_model(Args), Args)