Beispiel #1
0
def eval_model(args, atm, loader, *, checkpoint=None, store=None):
    """
    Evaluate a model for standard (and optionally adversarial) accuracy.

    Args:
        args (object) : A list of arguments---should be a python object
            implementing ``getattr()`` and ``setattr()``. ``args.adv_eval``
            decides whether the adversarial pass is run; when it is,
            ``args.eps`` and ``args.attack_lr`` are re-parsed in place.
        atm (AttackerModel) : model to evaluate
        loader (iterable) : a dataloader serving `(input, label)` batches from
            the validation set
        checkpoint (dict) : optional loaded checkpoint; if given, its
            ``'epoch'`` entry is passed to the evaluation loop as the epoch
            number, otherwise epoch 0 is used
        store (cox.Store) : store for saving results in (via tensorboardX)
    """
    # Fall back to epoch 0 so that evaluating without a checkpoint (the
    # default) does not hit an unbound `start_epoch`.
    start_epoch = checkpoint['epoch'] if checkpoint else 0

    check_required_args(args, eval_only=True)
    start_time = time.time()

    if store is not None:
        store.add_table(consts.LOGS_TABLE, consts.LOGS_SCHEMA_EXP2)
    writer = store.tensorboard if store else None

    # Natural (non-adversarial) pass
    prec1, nat_loss = _model_loop(args, 'val', loader, atm, None, start_epoch,
                                  (False, ), writer)

    # NaN marks "adversarial evaluation was not run" in the logged row
    adv_prec1, adv_loss = float('nan'), float('nan')
    if args.adv_eval:
        # NOTE(security): eval() parses string-valued eps / attack_lr but will
        # execute arbitrary expressions -- args must come from a trusted source.
        args.eps = eval(str(args.eps)) if has_attr(args, 'eps') else None
        args.attack_lr = eval(str(args.attack_lr)) if has_attr(
            args, 'attack_lr') else None

        # Store the adversarial results separately so the natural results
        # above are not overwritten.
        adv_prec1, adv_loss = _model_loop(args, 'val', loader, atm, None,
                                          start_epoch, (True, ), writer)

    log_info = {
        'epoch': 0,
        'nat_prec1': prec1,
        'adv_prec1': adv_prec1,
        'nat_loss': nat_loss,
        'adv_loss': adv_loss,
        'train_prec1': float('nan'),
        'train_loss': float('nan'),
        'time': time.time() - start_time
    }

    # Log info into the logs table
    if store: store[consts.LOGS_TABLE].append_row(log_info)
def check_and_fill_args(args, arg_list, ds_class):
    """
    Populate attributes missing from ``args`` using the defaults declared in
    an arguments list (e.g., TRAINING_ARGS), looking up dataset-specific
    defaults by dataset class (e.g., datasets.CIFAR) where requested.

    Args:
        args (object) : Any object exposing :samp:`setattr` and
            :samp:`getattr` (e.g. cox.utils.Parameters)
        arg_list (list) : Entries of the form
            [NAME, TYPE/CHOICES, HELP, DEFAULT]; dashes in NAME are replaced
            by underscores to form the attribute name
        ds_class (type) : A dataset class, used as the key into
            ``TRAINING_DEFAULTS`` for entries whose default is ``BY_DATASET``

    Returns:
        args (object): The same :samp:`args` object, with defaults filled in.

    Raises:
        ValueError: if an entry whose default is ``REQ`` is not set on ``args``.
    """
    for entry in arg_list:
        flag, _kind, _help, default = entry
        attr = flag.replace("-", "_")

        # Only touch attributes the caller has not already provided
        if not helpers.has_attr(args, attr):
            if default == REQ:
                raise ValueError(f"{flag} required")
            if default == BY_DATASET:
                setattr(args, attr, TRAINING_DEFAULTS[ds_class][attr])
            elif default is not None:
                setattr(args, attr, default)

    return args
Beispiel #3
0
def check_required_args(args, eval_only=False):
    """
    Check that the required training arguments are present.

    Args:
        args (argparse object): the arguments to check
        eval_only (bool) : whether to check only the arguments for evaluation

    Raises:
        AssertionError: if a required argument is missing from ``args``.
        ValueError: if a custom train loss is supplied for adversarial
            training/evaluation without a custom adversarial loss.
    """
    required_args_eval = ["adv_eval"]
    required_args_train = [
        "epochs",
        "out_dir",
        "log_iters",
        "lr",
        "momentum",
        "weight_decay"
    ]
    adv_required_args = [
        "attack_steps", "eps", "constraint", "use_best", "eps_fadein_epochs",
        "attack_lr", "random_restarts"
    ]

    # Generic function for checking all arguments in a list.
    # Raised explicitly (instead of via `assert`) so the check still runs
    # when Python is started with -O; the exception type is unchanged.
    def check_args(args_list):
        for arg in args_list:
            if not has_attr(args, arg):
                raise AssertionError(f"Missing argument {arg}")

    # Different required args based on training or eval:
    if not eval_only:
        check_args(required_args_train)
    else:
        check_args(required_args_eval)

    # More required args if we are robustly training or evaling.
    # Neither flag is guaranteed to exist at this point (adv_train is not in
    # any required list, adv_eval only in the eval list), so a missing flag
    # is treated as "off" rather than raising AttributeError.
    is_adv = bool(getattr(args, 'adv_train', None)) or \
        bool(getattr(args, 'adv_eval', None))
    if is_adv:
        check_args(adv_required_args)

    # More required args if the user provides a custom training loss
    has_custom_train = has_attr(args, 'custom_train_loss')
    has_custom_adv = has_attr(args, 'custom_adv_loss')
    if has_custom_train and is_adv and not has_custom_adv:
        raise ValueError("Cannot use custom train loss \
            without a custom adversarial loss (see docs)")
Beispiel #4
0
 def check_args(args_list):
     """
     Assert that ``has_attr`` holds on ``args`` for each name in ``args_list``.

     ``args`` is not a parameter; it is resolved from the surrounding scope.
     Fails with ``AssertionError`` on the first name that does not pass.
     """
     for required_name in args_list:
         assert has_attr(args, required_name), \
             f"Missing argument {required_name}"
Beispiel #5
0
def train_model(args, atm, loaders, *, checkpoint=None, store=None):
    """
    Main function for training a model.

    Args:
        args (object) : A python object for arguments, implementing
            ``getattr()`` and ``setattr()`` and having the following
            attributes. See :attr:`robustness.defaults.TRAINING_ARGS` for a
            list of arguments, and you can use
            :meth:`robustness.defaults.check_and_fill_args` to make sure that
            all required arguments are filled and to fill missing args with
            reasonable defaults:

            adv_train (int or bool, *required*)
                if 1/True, adversarially train, otherwise if 0/False do
                standard training
            task (str, *required*)
                for ``'train-model'`` or ``'train-classifier'`` the best
                checkpoint is the one with the highest precision; for any
                other value it is the one with the lowest loss. With
                ``'train-classifier'`` the epoch/best-precision stored in
                `checkpoint` are not restored.
            epochs (int, *required*)
                number of epochs to train for
            lr (float, *required*)
                learning rate for SGD optimizer
            weight_decay (float, *required*)
                weight decay for SGD optimizer
            momentum (float, *required*)
                momentum parameter for SGD optimizer
            step_lr (int)
                if given, drop learning rate by 10x every `step_lr` steps
            custom_schedule (str)
                If given, use a custom LR schedule (format: [(epoch, LR),...])
            adv_eval (int or bool)
                If True/1, then also do adversarial evaluation, otherwise skip
                (ignored if adv_train is True)
            log_iters (int, *required*)
                How frequently (in epochs) to save training logs
            save_ckpt_iters (int, *required*)
                How frequently (in epochs) to save checkpoints (if not
                positive, e.g. -1, then only save latest and best ckpts)
            attack_lr (float or str, *required if adv_train or adv_eval*)
                float (or float-parseable string) for the adv attack step size
            constraint (str, *required if adv_train or adv_eval*)
                the type of adversary constraint
                (:attr:`robustness.attacker.STEPS`)
            eps (float or str, *required if adv_train or adv_eval*)
                float (or float-parseable string) for the adv attack budget
            attack_steps (int, *required if adv_train or adv_eval*)
                number of steps to take in adv attack
            eps_fadein_epochs (int, *required if adv_train or adv_eval*)
                If greater than 0, fade in epsilon along this many epochs
            use_best (int or bool, *required if adv_train or adv_eval*) :
                If True/1, use the best (in terms of loss) PGD step as the
                attack, if False/0 use the last step
            random_restarts (int, *required if adv_train or adv_eval*)
                Number of random restarts to use for adversarial evaluation
            custom_train_loss (function, optional)
                If given, a custom loss instead of the default CrossEntropyLoss.
                Takes in `(logits, targets)` and returns a scalar.
            custom_adv_loss (function, *required if custom_train_loss*)
                If given, a custom loss function for the adversary. The custom
                loss function takes in `model, input, target` and should return
                a vector representing the loss for each element of the batch, as
                well as the classifier output.
            regularizer (function, optional)
                If given, this function of `model, input, target` returns a
                (scalar) that is added on to the training loss without being
                subject to adversarial attack
            iteration_hook (function, optional)
                If given, this function is called every training iteration by
                the training loop (useful for custom logging). The function is
                given arguments `model, iteration #, loop_type [train/eval],
                current_batch_ims, current_batch_labels`.
            epoch_hook (function, optional)
                Similar to iteration_hook but called every epoch instead, and
                given arguments `model, log_info` where `log_info` is the
                dictionary from the most recent evaluation, with keys `epoch,
                nat_prec1, adv_prec1, nat_loss, adv_loss, train_prec1,
                train_loss, time` (or `None` if no evaluation has run yet).

        atm (AttackerModel) : the model to train. One optimizer/schedule pair
            is created for every non-``None`` entry of
            ``atm.attacker.model.models``.
        loaders (tuple[iterable]) : `tuple` of data loaders of the form
            `(train_loader, val_loader)`
        checkpoint (dict) : a loaded checkpoint previously saved by this library
            (if resuming from checkpoint)
        store (cox.Store) : a cox store for logging training progress

    Returns:
        The ``atm`` object that was passed in, after training.
    """
    big_encoder = atm.attacker.model

    # Logging setup
    writer = store.tensorboard if store else None
    if store is not None:
        store.add_table(consts.LOGS_TABLE, consts.LOGS_SCHEMA_EXP2)
        store.add_table(consts.CKPTS_TABLE, consts.CKPTS_SCHEMA)

    # Reformat and read arguments
    check_required_args(args)  # Argument sanity check
    # NOTE(security): eval() parses string-valued eps / attack_lr but will
    # execute arbitrary expressions -- args must come from a trusted source.
    args.eps = eval(str(args.eps)) if has_attr(args, 'eps') else None
    args.attack_lr = eval(str(args.attack_lr)) if has_attr(
        args, 'attack_lr') else None

    # Initial setup
    train_loader, val_loader = loaders

    # One optimizer/schedule slot per submodel. A None submodel keeps a None
    # placeholder so both lists stay index-aligned with big_encoder.models.
    opts, schedules = [], []
    for submodel in big_encoder.models:
        if submodel is None:
            opts.append(None)
            schedules.append(None)
            continue
        opt, schedule = make_optimizer_and_schedule(
            args, submodel, checkpoint, args.lr, args.step_lr)
        opts.append(opt)
        schedules.append(schedule)

    best_prec1, best_loss, start_epoch = (0, float('inf'), 0)
    if checkpoint and args.task != 'train-classifier':
        start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint[f"{'adv' if args.adv_train else 'nat'}_prec1"]

    # Timestamp for training start time
    start_time = time.time()

    # Bound up front so epoch_hook can always be called, even when a resumed
    # run starts on an epoch that is not evaluated.
    log_info = None

    for epoch in range(start_epoch, args.epochs):
        # train for one epoch
        train_prec1, train_loss = _model_loop(
            args, 'train', train_loader, atm, opts, epoch, (None,), writer)

        last_epoch = (epoch == (args.epochs - 1))

        # Checkpoint payload. Optimizer and schedule state are not
        # serialized; -1 placeholders are stored under those keys.
        sd_info = {
            'model': atm.state_dict(),
            'optimizer': -1,
            'schedule': -1,
            'epoch': epoch + 1
        }

        def save_checkpoint(filename):
            # Checkpoints are only written when a store is available, and
            # always into the store's directory.
            if not store:
                return
            ckpt_save_path = os.path.join(store.path, filename)
            ch.save(sd_info, ckpt_save_path, pickle_module=dill)

        save_its = args.save_ckpt_iters
        # Test positivity first so save_its == 0 cannot divide by zero
        should_save_ckpt = (save_its > 0) and (epoch % save_its == 0)
        should_log = (epoch % args.log_iters == 0)

        if should_log or last_epoch or should_save_ckpt:
            # evaluate on validation set (natural pass needs no gradients)
            with ch.no_grad():
                prec1, nat_loss = _model_loop(args, 'val', val_loader, atm,
                                              None, epoch, (False, ), writer)

            # The adversarial pass is deliberately outside no_grad()
            should_adv_eval = args.adv_eval or args.adv_train
            if should_adv_eval:
                adv_prec1, adv_loss = _model_loop(args, 'val', val_loader, atm,
                                                  None, epoch, (True, ),
                                                  writer)
            else:
                # -1.0 marks "adversarial evaluation was not run"
                adv_prec1, adv_loss = -1.0, -1.0

            # remember best prec@1 (or best loss) and save checkpoint
            if args.task == 'train-model' or args.task == 'train-classifier':
                our_prec1 = adv_prec1 if args.adv_train else prec1
                is_best = our_prec1 > best_prec1
                best_prec1 = max(our_prec1, best_prec1)
            else:
                our_loss = adv_loss if args.adv_train else nat_loss
                is_best = our_loss < best_loss
                best_loss = min(our_loss, best_loss)

            # log every checkpoint
            log_info = {
                'epoch': epoch + 1,
                'nat_prec1': prec1,
                'adv_prec1': adv_prec1,
                'nat_loss': nat_loss,
                'adv_loss': adv_loss,
                'train_prec1': train_prec1,
                'train_loss': train_loss,
                'time': time.time() - start_time
            }

            # Log info into the logs table
            if store: store[consts.LOGS_TABLE].append_row(log_info)
            # If we are at a saving epoch (or the last epoch), save a checkpoint
            if should_save_ckpt or last_epoch:
                save_checkpoint(ckpt_at_epoch(epoch))
            # If store exists and this is the last epoch, record the checkpoint
            if last_epoch and store:
                store[consts.CKPTS_TABLE].append_row(sd_info)

            # Update the latest and best checkpoints (overrides old one)
            save_checkpoint(consts.CKPT_NAME_LATEST)
            if is_best: save_checkpoint(consts.CKPT_NAME_BEST)

        # Advance every submodel's schedule, not just the last one created
        for sched in schedules:
            if sched:
                sched.step()
        if has_attr(args, 'epoch_hook'): args.epoch_hook(atm, log_info)

    return atm