Code example #1
0
def args_preprocess(args):
    """Validate and normalize the parsed CLI arguments.

    Disables adversarial training when eps == 0, enforces mutually
    exclusive options, checks the dataset name, and fills any missing
    arguments with the defaults from :mod:`robustness.defaults`.

    Args:
        args: parsed argument namespace (mutated in place).

    Returns:
        The same ``args`` object with defaults filled in.
    """
    # NOTE(review): eval() here is only acceptable because eps comes from a
    # trusted command line (it permits expressions such as "8/255"); never
    # feed it untrusted input.
    if args.adv_train and eval(args.eps) == 0:
        print('[Switching to standard training since eps = 0]')
        args.adv_train = 0

    if args.pytorch_pretrained:
        assert not args.model_path, 'You can either specify pytorch_pretrained or model_path, not together.'

    ## CIFAR10 to CIFAR10 assertions
    if args.cifar10_cifar10:
        assert args.dataset == 'cifar10'

    # A single --data path overrides every per-dataset default location.
    if args.data != '':
        cs.PETS_PATH = cs.CARS_PATH = cs.FGVC_PATH = cs.FLOWERS_PATH = cs.DTD_PATH = cs.SUN_PATH = cs.FOOD_PATH = cs.BIRDS_PATH = args.data

    ALL_DS = list(transfer_datasets.DS_TO_FUNC.keys()) + [
        'imagenet', 'breeds_living_9', 'stylized_imagenet'
    ]
    assert args.dataset in ALL_DS

    # Important for automatic job retries on the cluster in case of preemptions. Avoid uuids.
    assert args.exp_name is not None

    # Fill whatever arguments are still missing with the library defaults.
    args = defaults.check_and_fill_args(args, defaults.CONFIG_ARGS, None)
    if not args.eval_only:
        args = defaults.check_and_fill_args(args, defaults.TRAINING_ARGS, None)
    if args.adv_train or args.adv_eval:
        args = defaults.check_and_fill_args(args, defaults.PGD_ARGS, None)
    args = defaults.check_and_fill_args(args, defaults.MODEL_LOADER_ARGS, None)

    return args
Code example #2
0
def setup_args(args):
    '''
    Fill the args object with reasonable defaults from
    :mod:`robustness.defaults`, and also perform a sanity check to make sure no
    args are missing.
    '''
    # Adversarial training is implied by a robust classifier loss or a
    # worst-case estimator loss.
    args.adv_train = ((args.classifier_loss == 'robust')
                      or (args.estimator_loss == 'worst'))

    # Optional JSON config overrides non-None values already on args.
    if args.config_path:
        args = cox.utils.override_json(args, args.config_path)

    ds_class = DATASETS[args.dataset]

    # (arg group, whether it applies to this run) — filled in order.
    arg_groups = [
        (defaults.CONFIG_ARGS, True),
        (defaults.TRAINING_ARGS, not args.eval_only),
        (defaults.PGD_ARGS, args.adv_train or args.adv_eval),
        (defaults.MODEL_LOADER_ARGS, True),
    ]
    for group, applies in arg_groups:
        if applies:
            args = check_and_fill_args(args, group, ds_class)

    if args.eval_only:
        assert args.resume is not None, \
            "Must provide a resume path if only evaluating"
    return args
Code example #3
0
# Open a cox store for logging under the output directory.
out_store = cox.store.Store(args.out_dir)

# Hard-coded base parameters: adv_train is off; the "attack" is a single
# GaussianNoise step with an optional eps fade-in schedule.
base_params = {
    'out_dir': args.out_dir,
    'adv_train': 0,
    'constraint': GaussianNoise,
    'eps': args.eps,
    'attack_lr': args.eps,
    'lr': args.lr,
    'attack_steps': 1,
    'step_lr': 2000,
    'random_start': 1,
    'use_best': False,
    'epochs': args.epochs,
    'save_ckpt_iters': -1,  # keep only the best and the last checkpoint
    'eps_fadein_epochs': args.fade_in,
}
train_args = Parameters(base_params)

# Fill whatever parameters are missing, using the ImageNet defaults.
imagenet_ds = datasets.DATASETS['imagenet']
for arg_group in (defaults.TRAINING_ARGS, defaults.PGD_ARGS):
    train_args = defaults.check_and_fill_args(train_args, arg_group,
                                              imagenet_ds)

# Train a model
train.train_model(train_args, model, (train_loader, val_loader), store=out_store)
Code example #4
0
            for k in new_args), set(new_args.keys()) - set(vars(args).keys())
        # Copy every override onto the args namespace.
        for k in new_args:
            setattr(args, k, new_args[k])

    assert not args.adv_train, 'not supported yet slatta dog'
    assert args.training_mode is not None, "training_mode is required"

    # Important for automatic job retries on the cluster in case of preemptions. Avoid uuids.
    # 'random' requests a one-off uuid4 name; otherwise the caller-supplied
    # name is kept so retried jobs resume the same experiment.
    if args.exp_name == 'random':
        args.exp_name = str(uuid4())
        print(f"Experiment name: {args.exp_name}")
    # NOTE(review): prefer `is not None` over `!= None` (PEP 8).
    assert args.exp_name != None

    # Preprocess args
    # Fall back to the ImageNet defaults for datasets robustness doesn't know.
    default_ds = args.dataset if args.dataset in datasets.DATASETS else "imagenet"
    args = defaults.check_and_fill_args(args, defaults.CONFIG_ARGS,
                                        datasets.DATASETS[default_ds])
    if not args.eval_only:
        args = defaults.check_and_fill_args(args, defaults.TRAINING_ARGS,
                                            datasets.DATASETS[default_ds])
    # NOTE(review): `if False and ...` permanently disables PGD-arg filling —
    # dead code; confirm whether it was meant to be re-enabled.
    if False and (args.adv_train or args.adv_eval):
        args = defaults.check_and_fill_args(args, defaults.PGD_ARGS,
                                            datasets.DATASETS[default_ds])
    args = defaults.check_and_fill_args(args, defaults.MODEL_LOADER_ARGS,
                                        datasets.DATASETS[default_ds])

    # Record the full argument set in the store once, for reproducibility.
    store = cox.store.Store(args.out_dir, args.exp_name)
    if 'metadata' not in store.keys:
        args_dict = args.__dict__
        schema = cox.store.schema_from_dict(args_dict)
        store.add_table('metadata', schema)
        store['metadata'].append_row(args_dict)
Code example #5
0
            loss.backward()
            # One ascent step in the sign of the gradient; self.clip then
            # presumably projects adv back near inp within eps — confirm
            # against the clip implementation.
            adv = self.clip(adv + step_size * torch.sign(adv.grad.data), inp,
                            eps)  # gradient ASCENT
        # Detach so the returned example carries no autograd history.
        return adv.clone().detach()


# Build CIFAR-10, restore a ResNet-18, and swap in a custom white-box attacker.
ds = CIFAR('/scratch/raunakc/datasets/cifar10')
model, _ = make_and_restore_model(arch='resnet18', dataset=ds)
model.attacker = WhiteboxPGD(model.model, ds)

# Standard CIFAR PGD recipe: eps 8/255, step 2/255, 10 steps.
pgd_params = {
    'dataset': 'cifar',
    'arch': 'resnet',
    'out_dir': "train_out",
    'adv_train': 1,
    'adv_eval': 1,
    'eps': 8 / 255,
    'attack_lr': 2 / 255,
    'attack_steps': 10,
    'constraint': 'inf'  # unused by the custom attacker, but the arg checker demands it
}

args = utils.Parameters(pgd_params)
args = check_and_fill_args(args, defaults.TRAINING_ARGS, ds.__class__)
if args.adv_train or args.adv_eval:
    args = check_and_fill_args(args, defaults.PGD_ARGS, ds.__class__)
args = check_and_fill_args(args, defaults.MODEL_LOADER_ARGS, ds.__class__)

train_loader, val_loader = ds.make_loaders(batch_size=128, workers=8)
train.train_model(args, model, (train_loader, val_loader))
Code example #6
0
    # eps == 0 makes the adversary a no-op, so fall back to standard training.
    if args.adv_train and args.eps == 0:
        print('[Switching to standard training since eps = 0]')
        args.adv_train = 0

    assert not args.adv_train, 'not supported yet slatta dog'

    # Important for automatic job retries on the cluster in case of preemptions. Avoid uuids.
    # NOTE(review): prefer `is not None` over `!= None` (PEP 8).
    assert args.exp_name != None

    # Useful for evaluation QRCodes since they are not robust to occlusions at all
    if args.no_translation: 
        constants.PATCH_TRANSFORMS['translate'] = (0., 0.)

    # Preprocess args: fill missing values with the library defaults.
    args = defaults.check_and_fill_args(
        args, defaults.CONFIG_ARGS, datasets.DATASETS[args.dataset])
    args = defaults.check_and_fill_args(
        args, defaults.MODEL_LOADER_ARGS, datasets.DATASETS[args.dataset])

    # Optionally pull selected argument values back out of an existing store
    # (args.args_from_store is a comma-separated list of column names).
    store = cox.store.Store(args.out_dir, args.exp_name)
    if args.args_from_store:
        args_from_store = args.args_from_store.split(',')
        df = store['metadata'].df
        print(f'==>[Reading from existing store in {store.path}]')
        for a in args_from_store:
            if a not in df:
                raise Exception(f'Did not find {a} in the store {store.path}')
            setattr(args,a, df[a][0])
            print(f'==>[Read {a} = ({getattr(args, a)}) ')

    if 'metadata_eval' not in store.keys:
Code example #7
0
# BCE on sigmoid probabilities; used by custom_train_loss below.
train_crit = torch.nn.BCELoss()


def custom_train_loss(logits, targ):
    """BCE training loss against remapped targets.

    ``targ`` is passed through the module-level ``order`` lookup table
    (presumably a class-index remapping — confirm where ``order`` is built)
    before computing sigmoid + BCE on ``logits``.
    """
    # Build the target tensor on CUDA when available, CPU otherwise.
    if torch.cuda.is_available():
        targets = torch.from_numpy(order[targ.cpu().numpy()]).cuda()
    else:
        targets = torch.from_numpy(order[targ.numpy()])
    # sigmoid of a float tensor is already float — the original's extra
    # outputs.float() cast was a redundant no-op and has been dropped.
    outputs = torch.sigmoid(logits.float())
    return train_crit(outputs, targets.float())


# Hand the custom loss to the training loop via the args object.
train_args.custom_train_loss = custom_train_loss

# Fill whatever parameters are missing from the defaults.
train_args = defaults.check_and_fill_args(
    train_args, defaults.TRAINING_ARGS, CustomCIFAR)

# Train a model (standard training; no PGD args are filled here).
train.train_model(train_args, m, (train_loader, val_loader), store=out_store)
# from robustness.datasets import DATASETS
# from robustness.model_utils import make_and_restore_model
# from robustness.train import train_model
# from robustness.defaults import check_and_fill_args
# from robustness.tools import constants, helpers
# from robustness import defaults

# from cox import utils
# from cox import store
Code example #8
0
        'eps_fadein_epochs': 0,
        'random_restarts': 0,
        'lr': parsed_args.start_lr,
        'use_adv_eval_criteria': 1,
        'regularizer': regularizer,
        'let_reg_handle_loss': True
    }

    ds_class = DATASETS[train_kwargs['dataset']]

    train_args = cox.utils.Parameters(train_kwargs)

    dx = utils.CIFAR10()
    dataset = dx.get_dataset()

    # Fill missing arguments with the library defaults.
    # NOTE(review): both calls pass `train_args` rather than threading the
    # returned `args` through — this only works if check_and_fill_args
    # mutates its argument in place; confirm.
    args = check_and_fill_args(train_args, defaults.TRAINING_ARGS, ds_class)
    args = check_and_fill_args(train_args, defaults.MODEL_LOADER_ARGS,
                               ds_class)

    model, _ = make_and_restore_model(arch='vgg19', dataset=dataset)

    # Make the data loaders
    train_loader, val_loader = dataset.make_loaders(args.workers,
                                                    args.batch_size,
                                                    data_aug=bool(
                                                        args.data_aug))

    # Prefetches data to improve performance
    train_loader = helpers.DataPrefetcher(train_loader)
    val_loader = helpers.DataPrefetcher(val_loader)