def data_loader_hook(epoch):
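    """Return the training data loader for `epoch`.

    With args.preprocessed_epoch_data, per-epoch preprocessed training data is
    reloaded every epoch while validation data is loaded only once; otherwise
    both splits are loaded a single time and reused on later calls.
    """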
    # Reload the data loader for this epoch:
    if args.preprocessed_epoch_data:
        print('| epoch %d, Loading data:' % epoch)
        for key in ('train', 'valid'):
            # Load validation data only once; `continue` (not `break`) so the
            # training data is still reloaded for the new epoch:
            if key == 'valid' and 'valid' in loaders:
                continue
            loaders[key] = get_data_loader(
                load_dataset(args, key, defense, epoch=epoch),
                batchsize=args.batchsize,
                device=args.device,
                shuffle=True,
            )
    # if data needs to be loaded only once and is not yet loaded:
    elif len(loaders) == 0:
        print('| epoch %d, Loading data:' % epoch)
        for key in ('train', 'valid'):
            loaders[key] = get_data_loader(
                load_dataset(args, key, defense),
                batchsize=args.batchsize,
                device=args.device,
                shuffle=True,
            )

    return loaders['train']
def gather_patches(image_dataset, num_patches, patch_size, patch_transform=None):
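    """Sample random square patches from the images in `image_dataset`.

    Draws roughly `num_patches` patches of side `patch_size` (at least one per
    image), optionally applies `patch_transform` to each patch, and returns
    them as a 2D tensor with one flattened patch per row.
    """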

    # assertions:
    assert isinstance(image_dataset, torch.utils.data.dataset.Dataset)
    assert isinstance(num_patches, int) and num_patches > 0
    assert isinstance(patch_size, int) and patch_size > 0
    if patch_transform is not None:
        assert callable(patch_transform)

    # gather patches (TODO: speed this up):
    patches, n = [], 0
    num_images = len(image_dataset)
    bar = progressbar.ProgressBar(num_images)
    bar.start()
    data_loader = get_data_loader(image_dataset)
    for (img, _) in data_loader:
        n += 1
        img = img.squeeze()
        # sample roughly num_patches / num_images patches per image (at least one):
        for _ in range(max(1, num_patches // num_images)):
            # pick a random top-left corner for a patch_size x patch_size crop:
            y = random.randint(0, img.size(1) - patch_size)
            x = random.randint(0, img.size(2) - patch_size)
            patch = img[:, y:y + patch_size, x:x + patch_size]
            if patch_transform is not None:
                patch = patch_transform(patch)
            patches.append(patch)
        if n % 100 == 0:
            bar.update(n)

    # copy all patches into a single tensor, one flattened patch per row:
    patches = torch.stack(patches, dim=0)
    patches = patches.view(patches.size(0), -1)
    return patches
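
# A minimal usage sketch for gather_patches (hypothetical dataset name):
# sample roughly 5000 8x8 patches from a torch Dataset of (image, target)
# pairs, e.g. to build a patch database for a quilting-style defense:
#
#   patches = gather_patches(image_dataset, num_patches=5000, patch_size=8)
#   # for 3-channel images, patches has shape (~5000, 3 * 8 * 8)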
def _eval_crops(args, dataset, model, defense, crop, ncrops, crop_type):
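    """Evaluate `model` on `dataset` over `ncrops` crops.

    For crop_type == 'sliding', the crop position is advanced and the dataset
    transformation is updated before each pass. Prints the top-1 accuracy for
    each crop and returns an (ncrops, num_examples, num_classes) tensor of
    class probabilities together with the targets.
    """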

    # assertions
    assert dataset is not None, "dataset expected"
    assert model is not None, "model expected"
    assert crop_type is None or isinstance(crop_type, str)
    if crop is not None:
        assert callable(crop)
    assert isinstance(ncrops, int)

    probs = None

    for crop_num in range(ncrops):

        # For sliding crop update crop function in dataset
        if crop_type == 'sliding':
            crop.update_sliding_position(crop_num)
            dataset = update_dataset_transformation(dataset, args, 'valid',
                                                    defense, crop)

        # set up dataloader:
        print('| set up data loader...')
        data_loader = get_data_loader(
            dataset,
            batchsize=args.batchsize,
            device=args.device,
            shuffle=False,
        )

        # test
        prob, targets = get_prob(model, data_loader)
        # collect prob for each run
        if probs is None:
            probs = torch.zeros(ncrops, len(dataset), prob.size(1))
        probs[crop_num, :, :] = prob

        # measure and print top-1 accuracy (taken from the top-5 ranking):
        _, _prob = prob.topk(5, dim=1)
        _correct = _prob.eq(targets.view(-1, 1).expand_as(_prob))
        _top1 = _correct.select(1, 0).float().mean() * 100
        defense_name = "no defense" if defense is None else defense.get_name()
        print('| crop[%d]: top1 acc for %s = %f' %
              (crop_num, defense_name, _top1))

        # release the data loader before it is rebuilt for the next crop:
        data_loader = None

    return probs, targets
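
# A minimal sketch of how the per-crop probabilities might be combined
# (hypothetical call; `crop` and `ncrops` depend on the crop setup):
#
#   probs, targets = _eval_crops(args, dataset, model, defense,
#                                crop, ncrops=10, crop_type='sliding')
#   avg_prob = probs.mean(dim=0)  # (num_examples, num_classes), averaged over crops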
def generate_adversarial_images(args):
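    """Generate adversarial images with the attack named in args.adversary_to_generate.

    Runs the selected attack (DeepFool, FGS, IFGS, CW-L2, or CW-Linf) over the
    chosen data split, optionally computes robustness / SSIM / success-rate
    statistics, then unnormalizes the clean and adversarial images and saves
    them together with the targets and per-example attack status.
    """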
    # assertions
    assert args.adversary_to_generate is not None, \
        "adversary_to_generate can't be None"
    assert AdversaryType.has_value(args.adversary_to_generate), \
        "\"{}\" adversary_to_generate not defined".format(args.adversary_to_generate)

    defense_name = None if not args.defenses else args.defenses[0]
    # defense = get_defense(defense_name, args)
    data_indices = _get_data_indices(args)
    data_type = args.data_type if args.data_type == "train" else "valid"
    dataset = load_dataset(args, data_type, None, data_indices=data_indices)
    data_loader = get_data_loader(
        dataset,
        batchsize=args.batchsize,
        device=args.device,
        shuffle=False)

    model, _, _ = get_model(args, load_checkpoint=True, defense_name=defense_name)

    adv_params = constants.get_adv_params(args)
    print('| adv_params:', adv_params)
    status = None
    all_inputs = None
    all_outputs = None
    all_targets = None
    bar = progressbar.ProgressBar(len(data_loader))
    bar.start()
    for batch_num, (imgs, targets) in enumerate(data_loader):
        if args.adversary_to_generate == str(AdversaryType.DEEPFOOL):
            assert adv_params['learning_rate'] is not None
            s, r = adversary.deepfool(
                model, imgs, targets, args.data_params['NUM_CLASSES'],
                train_mode=(args.data_type == 'train'), max_iter=args.max_adv_iter,
                step_size=adv_params['learning_rate'], batch_size=args.batchsize,
                labels=dataset.get_classes())
        elif args.adversary_to_generate == str(AdversaryType.FGS):
            s, r = adversary.fgs(
                model, imgs, targets, train_mode=(args.data_type == 'train'),
                mode=args.fgs_mode)
        elif args.adversary_to_generate == str(AdversaryType.IFGS):
            assert adv_params['learning_rate'] is not None
            s, r = adversary.ifgs(
                model, imgs, targets,
                train_mode=(args.data_type == 'train'), max_iter=args.max_adv_iter,
                step_size=adv_params['learning_rate'], mode=args.fgs_mode)
        elif args.adversary_to_generate == str(AdversaryType.CWL2):
            assert args.adv_strength is not None and len(args.adv_strength) == 1
            if len(args.crop_frac) == 1:
                crop_frac = args.crop_frac[0]
            else:
                crop_frac = 1.0
            s, r = adversary.cw(
                model, imgs, targets, args.adv_strength[0], 'l2',
                tv_weight=args.tvm_weight,
                train_mode=(args.data_type == 'train'), max_iter=args.max_adv_iter,
                drop_rate=args.pixel_drop_rate, crop_frac=crop_frac,
                kappa=args.margin)
        elif args.adversary_to_generate == str(AdversaryType.CWLINF):
            assert args.adv_strength is not None and len(args.adv_strength) == 1
            s, r = adversary.cw(
                model, imgs, targets, args.adv_strength[0], 'linf',
                bound=args.adv_bound,
                tv_weight=args.tvm_weight,
                train_mode=(args.data_type == 'train'), max_iter=args.max_adv_iter,
                drop_rate=args.pixel_drop_rate, crop_frac=args.crop_frac,
                kappa=args.margin)
        else:
            # guard against attack types that pass the assertion above but
            # have no branch here, which would leave s and r undefined:
            raise ValueError('unsupported adversary_to_generate: {}'.format(
                args.adversary_to_generate))

        if status is None:
            status = s.clone()
            all_inputs = imgs.clone()
            all_outputs = imgs + r
            all_targets = targets.clone()
        else:
            status = torch.cat((status, s), 0)
            all_inputs = torch.cat((all_inputs, imgs), 0)
            all_outputs = torch.cat((all_outputs, imgs + r), 0)
            all_targets = torch.cat((all_targets, targets), 0)
        bar.update(batch_num)

    print("| computing adversarial stats...")
    if args.compute_stats:
        rb, ssim, sc = adversary.compute_stats(all_inputs, all_outputs, status)
        print('| average robustness = ' + str(rb))
        print('| average SSIM = ' + str(ssim))
        print('| success rate = ' + str(sc))

    # Unnormalize before saving
    unnormalize = Unnormalize(args.data_params['MEAN_STD']['MEAN'],
                              args.data_params['MEAN_STD']['STD'])
    all_inputs = unnormalize(all_inputs)
    all_outputs = unnormalize(all_outputs)
    # save output
    output_file = get_adversarial_file_path(
        args, args.adversarial_root, defense_name, adv_params,
        data_indices['end_idx'], start_idx=data_indices['start_idx'],
        with_defense=False)
    print("| Saving adversarial data at " + output_file)
    if not os.path.isdir(args.adversarial_root):
        os.makedirs(args.adversarial_root)
    torch.save({'status': status, 'all_inputs': all_inputs,
                'all_outputs': all_outputs, 'all_targets': all_targets},
               output_file)