Code example #1
def generate_transformed_images(args):

    # only one operation runs at a time
    assert args.operation is not None, \
        "operation to run can't be None"
    assert OperationType.has_value(args.operation), \
        "\"{}\" operation not defined".format(args.operation)

    assert args.defenses is not None, "Defenses can't be None"
    assert not args.preprocessed_data, \
        "Trying to apply transformations on already transformed images"

    if args.operation == str(OperationType.TRANSFORM_ADVERSARIAL):
        for idx, defense_name in enumerate(args.defenses):
            defense = get_defense(defense_name, args)
            adv_params = constants.get_adv_params(args, idx)
            print("| adv_params: ", adv_params)
            dataset = _load_partial_dataset(args, 'valid', defense, adv_params)

            if args.data_batches is None:
                transformation_on_adv(args, dataset, defense_name, adv_params)
            else:
                for i in range(args.data_batches):
                    transformation_on_adv(args,
                                          dataset,
                                          defense_name,
                                          adv_params,
                                          data_batch_idx=i)

    elif args.operation == str(OperationType.CAT_DATA):
        for idx, defense_name in enumerate(args.defenses):
            adv_params = constants.get_adv_params(args, idx)
            print("| adv_params: ", adv_params)
            if args.data_batches is None:
                concatenate_data(args, defense_name, adv_params)
            else:
                for i in range(args.data_batches):
                    concatenate_data(args,
                                     defense_name,
                                     adv_params,
                                     data_batch_idx=i)

    elif args.operation == str(OperationType.TRANSFORM_RAW):
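        # each partition handles a contiguous slice of class indices,
        # selected via args.partition and args.partition_size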
        start_class_idx = args.partition * args.partition_size
        end_class_idx = (args.partition + 1) * args.partition_size
        class_indices = range(start_class_idx, end_class_idx)
        for defense_name in args.defenses:
            defense = get_defense(defense_name, args)
            data_type = args.data_type if args.data_type == "train" else "valid"
            dataset = load_dataset(args,
                                   data_type,
                                   defense,
                                   class_indices=class_indices)
            transformation_on_raw(args, dataset, defense_name)
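
For reference, the assertions above presume an OperationType enum whose members convert to the strings being compared and which exposes a has_value helper. A minimal sketch, assuming the member values (only the member names and has_value appear in the snippet):

from enum import Enum

class OperationType(Enum):
    # Member values are assumptions; only the names are visible above.
    TRANSFORM_ADVERSARIAL = "transform_adversarial"
    CAT_DATA = "cat_data"
    TRANSFORM_RAW = "transform_raw"

    def __str__(self):
        return self.value

    @classmethod
    def has_value(cls, value):
        # True if value matches the string form of any member
        return any(value == str(member) for member in cls)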
Code example #2
def classify_images(args):

    # assertions
    assert args.ensemble is None or args.ensemble in ENSEMBLE_TYPE, \
        "{} is not a supported ensemble type. Supported types are {}".format(
            args.ensemble, ENSEMBLE_TYPE)
    if not args.ensemble:
        assert args.ncrops is None or (len(args.ncrops) == 1
                                       and args.ncrops[0] == 1)
    if args.defenses is not None:
        for d in args.defenses:
            assert DefenseType.has_value(d), \
                "\"{}\" defense not defined".format(d)
        # one crop configuration is expected per defense
        assert (args.ncrops is None or len(args.ncrops) == len(
            args.defenses)), ("One ncrops value per defense is expected")
        assert (args.crop_type is None or len(args.crop_type) == len(
            args.defenses)), ("One crop_type value per defense is expected")
        # assert (len(args.crop_frac) == len(args.defenses)), (
        #     "crop_frac for each defense is expected")
    elif args.ncrops is not None:
        # no crop ensembling when defense is None
        assert len(args.ncrops) == 1
        assert args.crop_frac is not None and len(args.crop_frac) == 1, \
            "Only one crop_frac is expected as there is no defense"
        assert args.crop_type is not None and len(args.crop_type) == 1, \
            "Only one crop_type is expected as there is no defense"

    if args.defenses is None or len(args.defenses) == 0:
        defenses = [None]
    else:
        defenses = args.defenses

    all_defense_probs = None
    for idx, defense_name in enumerate(defenses):
        # initialize dataset
        defense = get_defense(defense_name, args)
        # Read preset params for adversary based on args
        adv_params = constants.get_adv_params(args, idx)
        print("| adv_params: ", adv_params)
        # setup crop
        ncrops = 1
        crop_type = None
        crop_frac = 1.0
        if args.ncrops:
            crop_type = args.crop_type[idx]
            crop_frac = args.crop_frac[idx]
            if crop_type == 'sliding':
                ncrops = 9
            else:
                ncrops = args.ncrops[idx]
        # Init custom crop function
        crop = transforms.Crop(crop_type, crop_frac)
        # initialize dataset
        dataset = load_dataset(args, 'valid', defense, adv_params, crop)
        # load model
        model, _, _ = get_model(args,
                                load_checkpoint=True,
                                defense_name=defense_name)

        # get crop probabilities for crops for current defense
        probs, targets = _eval_crops(args, dataset, model, defense, crop,
                                     ncrops, crop_type)

        if all_defense_probs is None:
            all_defense_probs = torch.zeros(len(defenses), len(dataset),
                                            probs.size(2))
        # Ensemble crop probabilities
        if args.ensemble == 'max':
            probs = torch.max(probs, dim=0)[0]
        elif args.ensemble == 'avg':  # for average ensembling
            probs = torch.mean(probs, dim=0)
        else:  # for no ensembling
            assert all_defense_probs.size(0) == 1
            probs = probs[0]
        all_defense_probs[idx, :, :] = probs

        # free memory
        dataset = None
        model = None

    # Ensemble defense probabilities
    if args.ensemble == 'max':
        all_defense_probs = torch.max(all_defense_probs, dim=0)[0]
    elif args.ensemble == 'avg':  # for average ensembling
        all_defense_probs = torch.mean(all_defense_probs, dim=0)
    else:  # for no ensembling
        assert all_defense_probs.size(0) == 1
        all_defense_probs = all_defense_probs[0]
    # Calculate top1 and top5 accuracy
    prec1, prec5 = accuracy(all_defense_probs, targets, topk=(1, 5))
    print('=' * 50)
    print('Results for model={}, attack={}, ensemble_type={}'.format(
        args.model, args.adversary, args.ensemble))
    prec1 = prec1[0]
    prec5 = prec5[0]
    print('| classification accuracy @1: %2.5f' % (prec1))
    print('| classification accuracy @5: %2.5f' % (prec5))
    print('| classification error @1: %2.5f' % (100. - prec1))
    print('| classification error @5: %2.5f' % (100. - prec5))
    print('| done.')
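
Both ensembling steps above reduce over dimension 0 of a (num_crops_or_defenses, num_images, num_classes) tensor. A toy illustration of the two reductions, with arbitrary shapes:

import torch

# 3 crops, 4 images, 5 classes (illustrative values only)
probs = torch.rand(3, 4, 5)

max_probs = torch.max(probs, dim=0)[0]  # element-wise max over crops -> (4, 5)
avg_probs = torch.mean(probs, dim=0)    # mean over crops -> (4, 5)
assert max_probs.shape == (4, 5) and avg_probs.shape == (4, 5)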
Code example #3
def train_model(args):

    # at most one defense, since there is no ensembling during training
    assert args.defenses is None or len(args.defenses) == 1
    defense_name = None if not args.defenses else args.defenses[0]
    defense = get_defense(defense_name, args)

    # Load model
    model, start_epoch, optimizer_ = get_model(args,
                                               load_checkpoint=args.resume,
                                               defense_name=defense_name,
                                               training=True)

    # set up optimizer:
    optimizer = _get_optimizer(model, args)

    # restore optimizer state from the checkpoint if available
    if start_epoch and optimizer_ is not None:
        args.start_epoch = start_epoch
        optimizer.load_state_dict(optimizer_)

    # set up criterion:
    criterion = nn.CrossEntropyLoss()

    if args.device == 'gpu':
        # move criterion and model to the GPU
        criterion = criterion.cuda()
        model = model.cuda()

    loaders = {}
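    # data loaders are created lazily, and reloaded per epoch when needed,
    # by data_loader_hook below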

    # set up start-of-epoch hook:
    def start_epoch_hook(epoch, model, optimizer):
        print('| epoch %d, training:' % epoch)
        adjust_learning_rate(args.lr, epoch, optimizer, args.lr_decay,
                             args.lr_decay_stepsize)

    # set up the end-of-epoch hook:
    def end_epoch_hook(epoch, model, optimizer, prec1=None, prec5=None):

        # print training error:
        if prec1 is not None:
            print('| training error @1 (epoch %d): %2.5f' %
                  (epoch, 100. - prec1))
        if prec5 is not None:
            print('| training error @5 (epoch %d): %2.5f' %
                  (epoch, 100. - prec5))

        # save checkpoint:
        print('| epoch %d, testing:' % epoch)
        save_checkpoint(
            args.models_root, {
                'epoch': epoch + 1,
                'model_name': args.model,
                'model_state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            })

        # measure validation error:
        prec1, prec5 = test(model, loaders['valid'])
        print('| validation error @1 (epoch %d): %2.5f' % (epoch, 100. - prec1))
        print('| validation error @5 (epoch %d): %2.5f' % (epoch, 100. - prec5))

    def data_loader_hook(epoch):
        # Reload data loader for epoch
        if args.preprocessed_epoch_data:
            print('| epoch %d, Loading data:' % epoch)
            for key in ('train', 'valid'):
                # load validation data only once; skip reloading on later epochs
                # (iterate a tuple, not a set, so 'train' is never skipped)
                if key == 'valid' and 'valid' in loaders:
                    continue
                loaders[key] = get_data_loader(
                    load_dataset(args, key, defense, epoch=epoch),
                    batchsize=args.batchsize,
                    device=args.device,
                    shuffle=True,
                )
        # if data needs to be loaded only once and is not yet loaded
        elif len(loaders) == 0:
            print('| epoch %d, Loading data:' % epoch)
            for key in ('train', 'valid'):
                loaders[key] = get_data_loader(
                    load_dataset(args, key, defense),
                    batchsize=args.batchsize,
                    device=args.device,
                    shuffle=True,
                )

        return loaders['train']

    # train the model:
    print('| training model...')
    train(model,
          criterion,
          optimizer,
          start_epoch_hook=start_epoch_hook,
          end_epoch_hook=end_epoch_hook,
          data_loader_hook=data_loader_hook,
          start_epoch=args.start_epoch,
          end_epoch=args.end_epoch,
          learning_rate=args.lr)
    print('| done.')
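
The start_epoch_hook above calls an adjust_learning_rate helper that is not shown. A plausible step-decay sketch matching the call site (the decay formula itself is an assumption):

def adjust_learning_rate(base_lr, epoch, optimizer, lr_decay, lr_decay_stepsize):
    # assumed schedule: scale base_lr by lr_decay once per lr_decay_stepsize epochs
    lr = base_lr * (lr_decay ** (epoch // lr_decay_stepsize))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr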