def create(model_id='l2-0'):
    model_url = robust_models[model_id]
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if model_url is not None:
        weights = zoo.fetch_weights(weights_uri=model_url, unzip=False)
        ds = ImageNet('/tmp/')
        m = ds.get_model(arch='resnet50', pretrained=False)
        checkpt = torch.load(weights, pickle_module=dill, map_location=device)
        model_keys = ['model', 'state_dict']
        model_key = [k for k in model_keys if k in checkpt.keys()][0]
        layer_keys = filter(lambda x: x.startswith('module.model'),
                            checkpt[model_key].keys())
        checkpt = {
            k[len('module.model.'):]: checkpt[model_key][k]
            for k in layer_keys
        }
        m.load_state_dict(checkpt)
    else:
        m = torchvision.models.resnet50(pretrained=True)
    m.eval()

    preprocessing = dict(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225],
                         axis=-3)

    fmodel = fb.models.PyTorchModel(m,
                                    bounds=(0, 1),
                                    preprocessing=preprocessing)

    return fmodel
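# Hypothetical usage sketch (not part of the original snippet), assuming
# foolbox 3.x and that the surrounding module defines the robust_models
# mapping used by create(); the sample loader and the L-inf PGD attack below
# follow the standard foolbox API.
import foolbox as fb

fmodel = create('l2-0')
images, labels = fb.utils.samples(fmodel, dataset='imagenet', batchsize=8)
print('clean accuracy:', fb.utils.accuracy(fmodel, images, labels))

attack = fb.attacks.LinfPGD()
raw, clipped, is_adv = attack(fmodel, images, labels, epsilons=[8 / 255])
print('attack success rate:', is_adv.float().mean().item())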
Example #2
def get_custom_imagenet(restricted=False,
                        data_path='./data',
                        data_aug=False,
                        shuffle_val=False,
                        batch_size=20):
    """
        We use helpers from robustness library
    Returns:
        loader: dataset loader
        norm: normalization function for dataset
        label_map: label map (class numbers to names) for dataset
    """
    if restricted:
        ds = RestrictedImageNet(data_path)
        label_map = CLASS_DICT['RestrictedImageNet']
    else:
        ds = ImageNet(data_path)
        label_map = CLASS_DICT['ImageNet']
        label_map = {k: v.split(',')[0] for k, v in label_map.items()}

    normalization = helpers.InputNormalize(ds.mean.cuda(), ds.std.cuda())
    loaders = ds.make_loaders(1,
                              batch_size=batch_size,
                              data_aug=data_aug,
                              shuffle_val=shuffle_val)

    return loaders, normalization, label_map
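# Hypothetical usage sketch: classify one validation batch with a standard
# torchvision ResNet-50, applying the dataset normalization first. The ./data
# path and the batch size are illustrative, and a CUDA device is assumed
# because the normalizer above is placed on the GPU.
import torch
import torchvision

(train_loader, val_loader), norm, label_map = get_custom_imagenet(
    restricted=False, data_path='./data', batch_size=20)

classifier = torchvision.models.resnet50(pretrained=True).cuda().eval()
images, targets = next(iter(val_loader))
with torch.no_grad():
    preds = classifier(norm(images.cuda())).argmax(dim=1)
print([label_map[p.item()] for p in preds[:5]])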
Example #3
def get_transfered_model(file_name: str, num_targets: int, save_dir_models: str):
    model = imagenet_models.resnet18()
    model.fc = nn.Linear(512, num_targets).cuda()
    model = AttackerModel(model, ImageNet(''))
    # unwrap the AttackerModel wrapper down to the underlying classifier
    while hasattr(model, 'model'):
        model = model.model
    checkpoint = torch.load(f"{save_dir_models}/{file_name}.pt")
    model.load_state_dict(checkpoint)
    return model.cuda()
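# Hypothetical call to the helper above; the file name, class count and
# checkpoint directory are placeholders, and a CUDA device is assumed because
# the helper moves both the new head and the final model to the GPU.
net = get_transfered_model('resnet18_transfer', num_targets=10,
                           save_dir_models='./checkpoints')
net.eval()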
Example #4
def load_imagenet_model_from_checkpoint(model, checkpoint):
    """
        Loads pretrained ImageNet models from the https://github.com/Microsoft/robust-models-transfer
    """
    # Makes us able to load models saved with legacy versions
    state_dict_key = 'model'
    if not ('model' in checkpoint):
        state_dict_key = 'state_dict'

    sd = checkpoint[state_dict_key]
    sd = {k[len('module.'):]: v for k, v in sd.items()}
    model = AttackerModel(model, ImageNet(''))
    model.load_state_dict(sd)

    while hasattr(model, 'model'):
        model = model.model

    return model.cuda()
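# Hypothetical usage sketch: restore a robust-models-transfer checkpoint into
# a plain torchvision backbone. The checkpoint path is a placeholder and a
# CUDA device is assumed, since the function returns the model with .cuda().
import torch
from torchvision import models

backbone = models.resnet50(pretrained=False)
checkpoint = torch.load('resnet50_l2_eps3.ckpt', map_location='cpu')
robust_resnet = load_imagenet_model_from_checkpoint(backbone, checkpoint).eval()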
Example #5
            if args.rho > 0:
                if k == 'conv1.weight':
                    print('Modifying', k)
                    data = generateSmoothKernel(data.cpu().numpy(), args.rho)
                    data = torch.from_numpy(data)
            new_state_dict[k] = data
        net.load_state_dict(new_state_dict)
        # print(model)
        # input('continue')
        # net.load_state_dict(torch.load('./cifar_linf_8.pt')['state_dict'])

    elif args.net == "imagenet-madry":
        print("Using ResNet-50 with Madry training")
        from robustness.model_utils import make_and_restore_model
        from robustness.datasets import ImageNet
        ds = ImageNet('../data')
        net, _ = make_and_restore_model(parallel=False,
                                        arch='resnet50',
                                        dataset=ds,
                                        resume_path='./imagenet_linf_8.pt')
        net = net.model

    elif args.net == "resnet-trades":
        print("Using WideResNet with TRADES training")
        from TRADES.models.wideresnet import WideResNet
        net = WideResNet()
        net.load_state_dict(
            torch.load('../../drive/My Drive/cifar/model_cifar_wrn.pt'))
        net.eval()
        for params in net.parameters():
            params.requires_grad = False
Example #6
def main():
    # Parse image location from command argument
    parser = argparse.ArgumentParser()
    parser.add_argument('--image', type=str, required=True)
    args = parser.parse_args()

    # Check if the given location exists and it is a valid image file
    if os.path.exists(args.image) and args.image.endswith(
        ('png', 'jpg', 'jpeg')):
        # Open the image from the given location
        image = Image.open(args.image)

        # Transform the image to a PyTorch tensor
        transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        image = transform(image)
        image = image.unsqueeze(0).cuda()

        # Attack parameters
        kwargs = {
            'constraint': 'inf',
            'eps': 16.0 / 255.0,
            'step_size': 1.0 / 255.0,
            'iterations': 500,
            'do_tqdm': True,
        }

        # Set the dataset for the robustness model
        dataset = ImageNet('dataset/')

        # Initialize a pretrained model via the robustness library
        model, _ = make_and_restore_model(arch='resnet50',
                                          dataset=dataset,
                                          pytorch_pretrained=True)
        model = model.cuda()

        # For evaluation, the standard ResNet50 from torchvision is used
        eval_model = torchvision.models.resnet50(pretrained=True).cpu().eval()

        # Get the model prediction for the original image
        label = eval_model(image.cpu())
        label = torch.argmax(label[0])
        label = label.view(1).cuda()

        # Create an adversarial example of the original images
        _, adversarial_example = model(image, label, make_adv=True, **kwargs)

        # Get the prediction of the model for the adversarial image
        adversarial_prediction = eval_model(adversarial_example.cpu())
        adversarial_prediction = torch.argmax(adversarial_prediction[0])

        # Print the original and the adversarial predictions
        print('Original prediction: ' + str(label.item()))
        print('Adversarial prediction: ' + str(adversarial_prediction.item()))

        # Save the adversarial example in the same folder as the original image
        filename_and_extension = args.image.split('.')
        adversarial_location = (filename_and_extension[0] + '_adversarial.' +
                                filename_and_extension[-1])
        save_image(adversarial_example[0], adversarial_location)

    else:
        print('Incorrect image path!')
Example #7

from PIL import Image
import numpy as np

from robustness.datasets import ImageNet
from robustness.model_utils import make_and_restore_model
import torch
import matplotlib.pyplot as plt

ds = ImageNet('/tmp')
model, _ = make_and_restore_model(
    arch='resnet50',
    dataset=ds,
    resume_path='/home/siddhant/Downloads/imagenet_l2_3_0.pt')
model.eval()

img = np.asarray(
    Image.open(
        '/home/siddhant/CMU/robustness_applications/sample_inputs/img_bear.jpg'
    ).resize((224, 224)))
img = img / 255.  # scale pixel values to [0, 1]
img = np.transpose(img, (2, 0, 1))

_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STDDEV = [0.229, 0.224, 0.225]

img_var = torch.tensor(img, dtype=torch.float)[None, :]
img = img_var.clone().detach().cpu().numpy()
img = img[0]

img = img.transpose((1, 2, 0))
img *= 255
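# Hypothetical continuation: the lines above convert the tensor back into an
# HWC array scaled to 0-255, so it can be shown with the matplotlib import
# already present in this snippet.
plt.imshow(img.astype(np.uint8))
plt.axis('off')
plt.show()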
Example #8
def load_model(model_type):
    if model_type == "simclr":
        # load checkpoint for simclr
        checkpoint = torch.load(
            '/content/gdrive/MyDrive/model_checkpoints/resnet50-1x.pth')
        resnet = models.resnet50(pretrained=False)
        resnet.load_state_dict(checkpoint['state_dict'])
        # preprocess images for simclr
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(256),
            transforms.ToTensor()
        ])
        return resnet

    if model_type == "simclr_v2_0":
        # load checkpoint for simclr
        checkpoint = torch.load('/content/gdrive/MyDrive/r50_1x_sk0.pth')
        resnet = models.resnet50(pretrained=False)
        resnet.load_state_dict(checkpoint['resnet'])
        # preprocess images for simclr
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(256),
            transforms.ToTensor()
        ])
        return resnet
    if model_type == "moco":
        # load checkpoints of moco
        state_dict = torch.load(
            '/content/gdrive/MyDrive/model_checkpoints/moco_v1_200ep_pretrain.pth.tar',
            map_location=torch.device('cpu'))['state_dict']
        resnet = models.resnet50(pretrained=False)
        for k in list(state_dict.keys()):
            if (k.startswith('module.encoder_q')
                    and not k.startswith('module.encoder_q.fc')):
                state_dict[k[len("module.encoder_q."):]] = state_dict[k]
            del state_dict[k]
        msg = resnet.load_state_dict(state_dict, strict=False)
        assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
        #preprocess for moco
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        return resnet

    if model_type == "mocov2":
        # load checkpoints of mocov2
        state_dict = torch.load(
            '/content/gdrive/MyDrive/moco/moco_v2_200ep_pretrain.pth.tar',
            map_location=torch.device('cpu'))['state_dict']
        resnet = models.resnet50(pretrained=False)
        for k in list(state_dict.keys()):
            if (k.startswith('module.encoder_q')
                    and not k.startswith('module.encoder_q.fc')):
                state_dict[k[len("module.encoder_q."):]] = state_dict[k]
            del state_dict[k]
        msg = resnet.load_state_dict(state_dict, strict=False)
        assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
        #preprocess for mocov2
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        return resnet

    if model_type == "InsDis":
        # load checkpoints for instance recognition resnet
        resnet = models.resnet50(pretrained=False)
        state_dict = torch.load(
            '/content/gdrive/MyDrive/model_checkpoints/lemniscate_resnet50_update.pth',
            map_location=torch.device('cpu'))['state_dict']
        for k in list(state_dict.keys()):
            if k.startswith('module') and not k.startswith('module.fc'):
                state_dict[k[len("module."):]] = state_dict[k]
            del state_dict[k]
        msg = resnet.load_state_dict(state_dict, strict=False)
        assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
        # preprocess for instance recognition resnet
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        return resnet

    if model_type == "place365_rn50":
        # load checkpoints for place365 resnet
        resnet = models.resnet50(pretrained=False)
        state_dict = torch.load(
            '/content/gdrive/MyDrive/model_checkpoints/resnet50_places365.pth.tar',
            map_location=torch.device('cuda'))['state_dict']
        #     for k in list(state_dict.keys()):
        #         if k.startswith('module') and not k.startswith('module.fc'):
        #             state_dict[k[len("module."):]] = state_dict[k]
        #         del state_dict[k]
        msg = resnet.load_state_dict(state_dict, strict=False)
        #     assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
        #preprocess for place365-resnet50
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        return resnet

    if model_type == "resnext101":
        #load ResNeXt 101_32x8 imagenet trained model
        resnet = models.resnext101_32x8d(pretrained=True)
        #preprocess for resnext101
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        return resnet

    if model_type == "wsl_resnext101":
        # load wsl resnext101
        resnet = models.resnext101_32x8d(pretrained=False)
        checkpoint = torch.load(
            "/content/gdrive/MyDrive/model_checkpoints/ig_resnext101_32x8-c38310e5.pth"
        )
        resnet.load_state_dict(checkpoint)
        #preprocess for wsl resnext101
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        return resnet

    if model_type == "st_resnet":
        # load checkpoint for st resnet
        resnet = models.resnet50(pretrained=True)
        #preprocess for st_resnet50
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        return resnet

    if model_type == "resnet101":
        # load torchvision resnet101 pretrained on ImageNet
        resnet = models.resnet101(pretrained=True)
        # preprocess for resnet101
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        return resnet
    if model_type == "wide_resnet101":
        # load torchvision wide_resnet101_2 pretrained on ImageNet
        resnet = models.wide_resnet101_2(pretrained=True)
        # preprocess for wide_resnet101
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        return resnet
    if model_type == "wide_resnet50":
        # load torchvision wide_resnet50_2 pretrained on ImageNet
        resnet = models.wide_resnet50_2(pretrained=True)
        # preprocess for wide_resnet50
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        return resnet

    if model_type == "untrained_resnet50":
        # randomly initialized resnet50 (no pretrained weights)
        resnet = models.resnet50(pretrained=False)
        # preprocess for untrained resnet50
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        return resnet

    if model_type == "untrained_resnet101":
        # randomly initialized resnet101 (no pretrained weights)
        resnet = models.resnet101(pretrained=False)
        # preprocess for untrained resnet101
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        return resnet

    if model_type == "untrained_wrn50":
        # randomly initialized wide_resnet50_2 (no pretrained weights)
        resnet = models.wide_resnet50_2(pretrained=False)
        # preprocess for untrained wide_resnet50
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        return resnet

    if model_type == "untrained_wrn101":
        # randomly initialized wide_resnet101_2 (no pretrained weights)
        resnet = models.wide_resnet101_2(pretrained=False)
        # preprocess for untrained wide_resnet101
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        return resnet

    if model_type == "st_alexnet":
        # load checkpoint for st alexnet
        alexnet = models.alexnet(pretrained=True)
        #preprocess for alexnet
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        return alexnet

    if model_type == "clip":
        import clip
        resnet, preprocess = clip.load("RN50")
        return resnet

    if model_type == 'linf_8':
        #     resnet = torch.load('/content/gdrive/MyDrive/model_checkpoints/imagenet_linf_8_model.pt') # https://drive.google.com/file/d/1DRkIcM_671KQNhz1BIXMK6PQmHmrYy_-/view?usp=sharing
        #     preprocess = transforms.Compose([
        #     transforms.Resize(256),
        #     transforms.CenterCrop(224),
        #     transforms.ToTensor(),
        #     transforms.Normalize(
        #     mean=[0.485, 0.456, 0.406],
        #     std=[0.229, 0.224, 0.225])
        #     ])
        #     return resnet
        resnet = models.resnet50(pretrained=False)
        checkpoint = torch.load(
            '/content/gdrive/MyDrive/model_checkpoints/imagenet_linf_8.pt',
            map_location=torch.device('cuda'))
        state_dict = checkpoint['model']
        for k in list(state_dict.keys()):
            if k.startswith('module.attacker.model.'):
                state_dict[k[len('module.attacker.model.'):]] = state_dict[k]
            del state_dict[k]
        resnet.load_state_dict(state_dict)
        return resnet

    if model_type == 'linf_4':
        #     resnet = torch.load('/content/gdrive/MyDrive/model_checkpoints/robust_resnet.pt')#https://drive.google.com/file/d/1_tOhMBqaBpfOojcueSnYQRw_QgXdPVS6/view?usp=sharing
        #     preprocess = transforms.Compose([
        #     transforms.Resize(256),
        #     transforms.CenterCrop(224),
        #     transforms.ToTensor(),
        #     transforms.Normalize(
        #     mean=[0.485, 0.456, 0.406],
        #     std=[0.229, 0.224, 0.225])
        #     ])
        #     return resnet
        resnet = models.resnet50(pretrained=False)
        checkpoint = torch.load(
            '/content/gdrive/MyDrive/model_checkpoints/imagenet_linf_4.pt',
            map_location=torch.device('cuda'))
        state_dict = checkpoint['model']
        for k in list(state_dict.keys()):
            #         if k.startswith('module.attacker.model.') and not k.startswith('module.attacker.normalize') :
            if k.startswith('module.attacker.model.'):
                state_dict[k[len('module.attacker.model.'):]] = state_dict[k]
            del state_dict[k]
        resnet.load_state_dict(state_dict)
        return resnet

    if model_type == 'l2_3':
        #     resnet = torch.load('/content/gdrive/MyDrive/model_checkpoints/imagenet_l2_3_0_model.pt') # https://drive.google.com/file/d/1SM9wnNr_WnkEIo8se3qd3Di50SUT9apn/view?usp=sharing
        #     preprocess = transforms.Compose([
        #     transforms.Resize(256),
        #     transforms.CenterCrop(224),
        #     transforms.ToTensor(),
        #     transforms.Normalize(
        #     mean=[0.485, 0.456, 0.406],
        #     std=[0.229, 0.224, 0.225])
        #     ])
        #     return resnet
        resnet = models.resnet50(pretrained=False)
        checkpoint = torch.load(
            '/content/gdrive/MyDrive/model_checkpoints/imagenet_l2_3_0.pt',
            map_location=torch.device('cuda'))
        state_dict = checkpoint['model']
        for k in list(state_dict.keys()):
            if k.startswith('module.attacker.model.'):
                state_dict[k[len('module.attacker.model.'):]] = state_dict[k]
            del state_dict[k]
        resnet.load_state_dict(state_dict)
        return resnet

    if model_type in ('resnet50_l2_eps0.01', 'resnet50_l2_eps0.03',
                      'resnet50_l2_eps0.25', 'resnet50_l2_eps0.5',
                      'resnet50_l2_eps1', 'resnet50_l2_eps3',
                      'resnet50_l2_eps5'):
        resnet = models.resnet50(pretrained=False)
        ds = ImageNet('/tmp')
        total_resnet, checkpoint = make_and_restore_model(
            arch='resnet50',
            dataset=ds,
            resume_path=f'/content/gdrive/MyDrive/model_checkpoints/{model_type}.ckpt')
        # resnet = total_resnet.attacker
        state_dict = checkpoint['model']
        for k in list(state_dict.keys()):
            if k.startswith('module.attacker.model.'):
                state_dict[k[len('module.attacker.model.'):]] = state_dict[k]
            del state_dict[k]
        resnet.load_state_dict(state_dict)
        return resnet
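# Hypothetical usage sketch for load_model. Note that the preprocess transforms
# built inside the function are never returned, so the caller has to rebuild
# them; the image path below is a placeholder.
import torch
from PIL import Image
from torchvision import transforms

net = load_model('st_resnet').eval()
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
x = preprocess(Image.open('example.jpg').convert('RGB')).unsqueeze(0)
with torch.no_grad():
    print(net(x).argmax(dim=1).item())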
Example #9

    parser.add_argument('--out-path', default='/tmp/path_to_imagenet_features')
    parser.add_argument('--batch-size', default=256, type=int)
    parser.add_argument('--chunk-threshold', default=20000, type=int)
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--num-workers', default=8, type=int)
    parser.add_argument('--max-chunks', default=-1, type=int)

    args = parser.parse_args()

    # Personal preference here to default to grad not enabled; 
    # explicitly enable grad when necessary for memory reasons
    ch.manual_seed(0)
    ch.set_grad_enabled(False)

    print("Initializing dataset and loader...")
    ds_imagenet = ImageNet(args.dataset_path)
    train_loader, test_loader = ds_imagenet.make_loaders(
        args.num_workers, args.batch_size,
        data_aug=False, shuffle_train=False, shuffle_val=False)

    print("Loading model...")
    model, _ = make_and_restore_model( 
        arch=args.arch, 
        dataset=ds_imagenet,
        resume_path=args.model_root
    )
    model.eval()
    model = ch.nn.DataParallel(model.to(args.device))

    out_dir = args.out_path
    if not os.path.exists(out_dir):
        print(f"Making directory {out_dir}")
Example #10
def convert_to_robustness(model, state_dict):
    dataset = ImageNet('dataset/imagenet-airplanes')
    model, _ = make_and_restore_model(arch=model, dataset=dataset)
    state_dict = {k[len('module.'):]: v for k, v in state_dict.items()}
    return model, state_dict


def main():
    images_location = 'results/images/'
    current_time = str(datetime.datetime.now().strftime('%d-%m-%Y_%H:%M:%S'))
    save_location = images_location + current_time
    os.mkdir(save_location)

    print('Enter mode type: rgb or grayscale')
    mode = input()

    if mode == 'grayscale':
        imageset = (
            torch.load('./dataset/imagenet-airplanes-images-grayscale.pt'))
        image_loader = torch.utils.data.DataLoader(imageset,
                                                   batch_size=4,
                                                   num_workers=2)
        labels = (torch.load('./dataset/imagenet-airplanes-labels.pt'))
    else:
        imageset = (torch.load('./dataset/imagenet-airplanes-images.pt'))
        image_loader = torch.utils.data.DataLoader(imageset,
                                                   batch_size=4,
                                                   num_workers=2)
        labels = (torch.load('./dataset/imagenet-airplanes-labels.pt'))

    kwargs = {
        'constraint': 'inf',
        'eps': 64 / 255.0,
        'step_size': 1 / 255.0,
        'iterations': 500,
        'do_tqdm': True,
    }

    dataset = ImageNet('dataset/imagenet-airplanes')

    model, _ = make_and_restore_model(arch='resnet50',
                                      dataset=dataset,
                                      pytorch_pretrained=True)
    model = model.cuda()

    for batch_index, (images_batch,
                      labels_batch) in enumerate(zip(image_loader, labels)):
        images_batch = images_batch[:2]
        labels_batch = labels_batch[:2]

        label = torch.LongTensor(2)
        label[0] = 101
        label[1] = 101

        print(label)

        print(images_batch.shape)
        print(labels_batch.shape)

        _, images_adversarial = model(images_batch.cuda(),
                                      labels_batch.cuda(),
                                      make_adv=True,
                                      **kwargs)
        predictions, _ = model(images_adversarial)
        save_images(images_batch, images_adversarial, batch_index,
                    save_location)

    print('Finished!')
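# Hypothetical entry point, assuming the snippet above lives in a standalone
# script where save_images and the remaining imports are defined.
if __name__ == '__main__':
    main()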