Example #1
def cifar10_attack(args, N_img):
    img_size = args.cifar10_img_size
    model_file_name = model_settings.get_model_file_name("cifar10", args)

    from models import CifarDNN
    model = CifarDNN(model_type='res18',
                     pretrained=False,
                     gpu=args.use_gpu,
                     img_size=img_size).eval()
    model.load_state_dict(
        torch.load('../class_models/%s_res18.model' % (model_file_name)))

    # mean = np.array([0.485, 0.456, 0.406])
    # std = np.array([0.229, 0.224, 0.225])
    mean, std = constants.get_mean_std('cifar10')
    fmodel = foolbox.models.PyTorchModel(model,
                                         bounds=(0, 1),
                                         num_classes=10,
                                         preprocessing=(mean.reshape(
                                             (3, 1, 1)), std.reshape(
                                                 (3, 1, 1))))

    import torchvision.datasets as datasets

    transform = model_settings.get_data_transformation_without_normalization(
        'cifar10', args)

    dataset = datasets.CIFAR10(root=f"{constants.ROOT_DATA_PATH}",
                               train=False,
                               download=True,
                               transform=transform)
    src_images, src_labels = [], []
    tgt_images, tgt_labels = [], []
    used_ids = set()
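    # Sample random (source, target) index pairs until N_img pairs whose predicted labels differ have been collected.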
    while (len(src_images) < N_img):
        sid = np.random.randint(len(dataset))
        tid = np.random.randint(len(dataset))
        if (sid, tid) in used_ids:
            continue
        used_ids.add((sid, tid))
        src_image, _ = dataset[sid]
        tgt_image, _ = dataset[tid]
        src_image, tgt_image = src_image.numpy(), tgt_image.numpy()
        src_label = np.argmax(fmodel.forward_one(src_image))
        tgt_label = np.argmax(fmodel.forward_one(tgt_image))
        if (src_label != tgt_label):
            src_images.append(src_image)
            tgt_images.append(tgt_image)
            src_labels.append(src_label)
            tgt_labels.append(tgt_label)
    mask = None

    return src_images, src_labels, tgt_images, tgt_labels, fmodel, mask
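The (source, target) pairs returned above are typically fed to a decision-based attack. A minimal driver sketch, assuming the foolbox 2.x API that `forward_one` implies; the attack class, the iteration count, and the `args` namespace are illustrative assumptions, not taken from this codebase:

# Hypothetical driver; assumes foolbox 2.x and an argparse.Namespace `args` like the ones built in the later examples.
import numpy as np
import foolbox

src_images, src_labels, tgt_images, tgt_labels, fmodel, mask = cifar10_attack(args, N_img=10)

attack = foolbox.attacks.BoundaryAttack(fmodel)
for src, src_y, tgt in zip(src_images, src_labels, tgt_images):
    # Start from the differently classified target image and walk toward the source.
    adv = attack(src, label=int(src_y), starting_point=tgt, iterations=1000)
    if adv is not None:
        print('L2 distortion: %.4f' % np.linalg.norm(adv - src))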
Example #2
def mnist_attack(args, N_img, num_class, mounted=False):
    img_size = args.mnist_img_size
    model_file_name = model_settings.get_model_file_name("mnist", args)

    from models import MNISTDNN
    model = MNISTDNN(model_type='res18',
                     gpu=args.use_gpu,
                     n_class=num_class,
                     img_size=img_size).eval()
    model.load_state_dict(
        torch.load('../class_models/%s_%s.model' % (model_file_name, 'res18')))
    mean, std = constants.get_mean_std('mnist')
    fmodel = foolbox.models.PyTorchModel(model,
                                         bounds=(0, 1),
                                         num_classes=num_class,
                                         preprocessing=(mean, std))

    import torchvision.transforms as transforms
    transform = model_settings.get_data_transformation_without_normalization(
        'mnist', args)

    import torchvision
    dataset = torchvision.datasets.MNIST(root=f"{constants.ROOT_DATA_PATH}",
                                         train=False,
                                         download=True,
                                         transform=transform)
    src_images, src_labels = [], []
    tgt_images, tgt_labels = [], []
    used_ids = set()
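    # Sample random (source, target) pairs; keep only those where both predictions match the ground truth and the two predicted labels differ.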
    while (len(src_images) < N_img):
        sid = np.random.randint(len(dataset))
        tid = np.random.randint(len(dataset))
        if (sid, tid) in used_ids:
            continue
        used_ids.add((sid, tid))
        src_image, src_y = dataset[sid]
        tgt_image, tgt_y = dataset[tid]
        src_image, tgt_image = src_image.numpy(), tgt_image.numpy()
        src_label = np.argmax(fmodel.forward_one(src_image))
        tgt_label = np.argmax(fmodel.forward_one(tgt_image))
        if (src_label != tgt_label) and (src_y == src_label) and (
                tgt_y == tgt_label):  # predictions should match gt
            src_images.append(src_image)
            tgt_images.append(tgt_image)
            src_labels.append(src_label)
            tgt_labels.append(tgt_label)
    mask = None
    print("MNIST attack, %d src imgs, %d tgt imgs" %
          (len(src_images), len(tgt_images)))

    return src_images, src_labels, tgt_images, tgt_labels, fmodel, mask
Example #3
    parser.add_argument('--N_Z', type=int)
    parser.add_argument('--mounted', action='store_true')

    parser.add_argument('--mnist_img_size', type=int, default=28)
    parser.add_argument('--mnist_padding_size', type=int, default=0)
    parser.add_argument('--mnist_padding_first', action='store_true')

    parser.add_argument('--cifar10_img_size', type=int, default=32)
    parser.add_argument('--cifar10_padding_size', type=int, default=0)
    parser.add_argument('--cifar10_padding_first', action='store_true')
    args = parser.parse_args()

    GPU = True
    TASK = args.TASK
    if TASK == 'mnist' or TASK == 'cifar10':
        TASK = model_settings.get_model_file_name(TASK, args) # TODO MNIST N_Z 3136

    if TASK.startswith('mnist'):
        n_channels = 1
    else:
        n_channels = 3

    N_Z = args.N_Z

    model = AEGenerator(n_channels=n_channels, gpu=GPU, N_Z=N_Z)

    if N_Z == 128:
        ENC_SHAPE = (8, 4, 4)
    elif N_Z == 9408:
        ENC_SHAPE = (48, 14, 14)
    else:
Example #4
    parser.add_argument('--cifar10_padding_size', type=int, default=0)
    parser.add_argument('--cifar10_padding_first', action='store_true')

    parser.add_argument('--smooth', action='store_true')
    parser.add_argument('--smooth_suffix', type=str, default='')
    args = parser.parse_args()

    GPU = True
    # TASK = 'celeba'
    TASK = args.TASK
    # if TASK == 'mnist' or TASK == 'cifar10':
    #     TASK, output_file_name = model_settings.get_model_file_name(TASK, args)
    # else:
    #     output_file_name = TASK
    if TASK == 'mnist' or TASK == 'cifar10':
        TASK = model_settings.get_model_file_name(TASK, args)

    N_Z = args.N_Z
    print(TASK, N_Z)

    if TASK.startswith('mnist'):
        n_channels = 1
    else:
        n_channels = 3

    if TASK.startswith('mnist') and N_Z == 3136:
        model = MNIST224VAEGenerator(n_channels=n_channels, gpu=GPU)
    else:
        model = VAEGenerator(n_channels=n_channels, gpu=GPU)

Example #5
def load_model_dataset(TASK, REF, root_dir):
    if TASK == 'imagenet':
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])
        trainset = ImageNetDataset(train=True, transform=transform)
        testset = ImageNetDataset(train=False, transform=transform)
        if REF == 'dense121':
            ref_model = models.densenet121(pretrained=True).eval()
        elif REF == 'res18':
            ref_model = models.resnet18(pretrained=True).eval()
        elif REF == 'res50':
            ref_model = models.resnet50(pretrained=True).eval()
        elif REF == 'vgg16':
            ref_model = models.vgg16(pretrained=True).eval()
        elif REF == 'googlenet':
            ref_model = models.googlenet(pretrained=True).eval()
        elif REF == 'wideresnet':
            ref_model = models.wide_resnet50_2(pretrained=True).eval()
        if GPU:
            ref_model.cuda()
        preprocess_std = (0.229, 0.224, 0.225)

    elif TASK == 'celeba':
        image_size = 224
        transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        attr_o_i = 'Mouth_Slightly_Open'
        list_attr_path = '%s/celebA/list_attr_celeba.txt' % (root_dir)
        img_data_path = '%s/celebA/img_align_celeba' % (root_dir)
        trainset = CelebAAttributeDataset(attr_o_i, list_attr_path, img_data_path, data_split='train',
                                          transform=transform)
        testset = CelebAAttributeDataset(attr_o_i, list_attr_path, img_data_path, data_split='test',
                                         transform=transform)

        ref_model = CelebADNN(model_type=REF, pretrained=False, gpu=GPU)
        ref_model.load_state_dict(torch.load('../class_models/celeba_%s_%s.model' % (attr_o_i, REF)))
        preprocess_std = (0.5, 0.5, 0.5)

    elif TASK == 'celebaid':
        image_size = 224
        transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        num_class = 10
        trainset = CelebAIDDataset(root_dir=root_dir, is_train=True, transform=transform, preprocess=False,
                                   random_sample=False, n_id=num_class)
        testset = CelebAIDDataset(root_dir=root_dir, is_train=False, transform=transform, preprocess=False,
                                  random_sample=False, n_id=num_class)

        ref_model = CelebAIDDNN(model_type=REF, num_class=num_class, pretrained=False, gpu=GPU).eval()
        ref_model.load_state_dict(torch.load('../class_models/celeba_id_%s.model' % (REF)))

        preprocess_std = (0.5, 0.5, 0.5)

    elif TASK == 'dogcat2':
        BATCH_SIZE = 32
        image_size = 224
        transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
        ])
        num_class = 2
        trainset = DogCatDataset(root_dir=f'{constants.ROOT_DATA_PATH}/dogcat', mode='train', transform=transform)
        testset = DogCatDataset(root_dir=f'{constants.ROOT_DATA_PATH}/dogcat', mode='test', transform=transform)

        ref_model = NClassDNN(model_type=REF, pretrained=False, gpu=GPU, n_class=num_class)
        ref_model.load_state_dict(torch.load('../class_models/%s_%s.model' % ('dogcat2', REF)))

        preprocess_std = (1, 1, 1)

    elif TASK == 'cifar10':
        # mean = (0.485, 0.456, 0.406)
        # std = (0.229, 0.224, 0.225)
        mean, std = constants.plot_mean_std(TASK)
        transform = model_settings.get_data_transformation(TASK, args)
        model_file_name = model_settings.get_model_file_name(TASK, args)

        trainset = torchvision.datasets.CIFAR10(root='%s/'%(constants.ROOT_DATA_PATH), train=True, download=True, transform=transform)
        testset = torchvision.datasets.CIFAR10(root='%s/'%(constants.ROOT_DATA_PATH), train=False, download=True, transform=transform)

        ref_model = CifarDNN(model_type=REF, gpu=GPU, pretrained=False, img_size=cifar10_img_size)
        ref_model.load_state_dict(torch.load('../class_models/%s_%s.model' % (model_file_name, REF)))
        if GPU:
            ref_model.cuda()
        preprocess_std = std

    elif TASK == 'mnist':
        mean, std = constants.plot_mean_std(TASK)
        transform = model_settings.get_data_transformation(TASK, args)
        model_file_name = model_settings.get_model_file_name(TASK, args)
        num_class = 10
        trainset = torchvision.datasets.MNIST(root='%s/'%(constants.ROOT_DATA_PATH), train=True, download=True, transform=transform)
        testset = torchvision.datasets.MNIST(root='%s/'%(constants.ROOT_DATA_PATH), train=False, download=True, transform=transform)

        ref_model = MNISTDNN(model_type=REF, gpu=GPU, n_class=num_class, img_size=mnist_img_size)
        ref_model.load_state_dict(torch.load('../class_models/%s_%s.model' % (model_file_name, REF)))

        preprocess_std = std

    elif TASK == 'celeba2':
        image_size = 224
        num_class = 2
        transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        root_dir = f"{constants.ROOT_DATA_PATH}"
        trainset = CelebABinIDDataset(poi=args.celeba_poi, root_dir=root_dir, is_train=True, transform=transform,
                                      get_data=False, random_sample=False, n_id=10)
        testset = CelebABinIDDataset(poi=args.celeba_poi, root_dir=root_dir, is_train=False, transform=transform,
                                     get_data=False, random_sample=False, n_id=10)

        ref_model = NClassDNN(model_type=REF, pretrained=False, gpu=GPU, n_class=num_class)
        ref_model.load_state_dict(torch.load('../class_models/%s_%s.model' % ('celeba2', REF)))

        preprocess_std = (0.5, 0.5, 0.5)

    else:
        assert 0

    return ref_model, trainset, testset, preprocess_std
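A sketch of how the returned tuple might be consumed; the task/architecture strings, batch size, and the accuracy loop are illustrative assumptions, not part of the original code (only the module-level GPU flag is reused from it):

import torch

ref_model, trainset, testset, preprocess_std = load_model_dataset('cifar10', 'res18', root_dir='../raw_data')
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False)

correct, total = 0, 0
with torch.no_grad():
    for x, y in testloader:
        if GPU:
            x, y = x.cuda(), y.cuda()
        logits = ref_model(x)
        correct += (logits.argmax(dim=1) == y).sum().item()
        total += y.numel()
print('reference model test accuracy: %.4f' % (correct / total))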
Example #6
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--cifar10_img_size', type=int, default=32)
    parser.add_argument('--cifar10_padding_size', type=int, default=0)
    parser.add_argument('--cifar10_padding_first', action='store_true')
    parser.add_argument('--do_train', action='store_true')
    args = parser.parse_args()

    img_size = args.cifar10_img_size

    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)

    transform = model_settings.get_data_transformation("cifar10", args)
    # model_file_name, _ = model_settings.get_model_file_name("cifar10", args)
    model_file_name = model_settings.get_model_file_name("cifar10", args)
    print(model_file_name)

    trainset = torchvision.datasets.CIFAR10(root='../raw_data/',
                                            train=True,
                                            download=True,
                                            transform=transform)
    testset = torchvision.datasets.CIFAR10(root='../raw_data/',
                                           train=False,
                                           download=True,
                                           transform=transform)

    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=BATCH_SIZE,
                                              shuffle=True)
    testloader = torch.utils.data.DataLoader(testset,
Example #7
    parser.add_argument('--mnist_img_size', type=int, default=28)
    parser.add_argument('--mnist_padding_size', type=int, default=0)
    parser.add_argument('--mnist_padding_first', action='store_true')

    parser.add_argument('--cifar10_img_size', type=int, default=32)
    parser.add_argument('--cifar10_padding_size', type=int, default=0)
    parser.add_argument('--cifar10_padding_first', action='store_true')

    parser.add_argument('--pretrained', action='store_true')
    args = parser.parse_args()

    GPU = True
    TASK = args.TASK
    if TASK == 'mnist' or TASK == 'cifar10':
        TASK = model_settings.get_model_file_name(TASK, args) #TODO mnist N_Z 3136
    model_file_name = TASK

    if TASK.startswith('mnist'):
        n_channels = 1
    else:
        n_channels = 3

    N_Z = args.N_Z

    modelD = GANDiscriminator(n_channels=n_channels, gpu=GPU)
    modelG = GANGenerator(n_z=N_Z, n_channels=n_channels, gpu=GPU)

    if args.pretrained:
        modelD.load_state_dict(torch.load('./gen_models/normal_%s_gradient_gan_%d_discriminator.model' % (TASK, N_Z)))
        modelG.load_state_dict(torch.load('./gen_models/normal_%s_gradient_gan_%d_generator.model' % (TASK, N_Z)))
Example #8
def calc_internal_dim(task, pgen_type, args):
    B = 5000
    BATCH_SIZE = 32
    p_gen = attack_setting.load_pgen(task=task, pgen_type=pgen_type, args=args)
    input_size = attack_setting.pgen_input_size(task=task,
                                                pgen_type=pgen_type,
                                                args=args)
    latent_data = torch.randn((B, *input_size))
    if p_gen is not None:
        if args.use_gpu:
            latent_data = latent_data.cuda()
        # projected_data = p_gen.project(latent_data)
        projected_np = None
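        # Project the latent samples through the generator in batches to keep GPU memory bounded.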
        for _i in range(int(B / BATCH_SIZE) + 1):
            _data = latent_data[_i * BATCH_SIZE:(_i + 1) * BATCH_SIZE]
            _B = _data.shape[0]
            if _B < 1:
                break
            _projected = p_gen.project(_data)
            _np = _projected.detach().cpu().numpy().reshape(_B, -1)
            if projected_np is None:
                projected_np = _np
            else:
                projected_np = np.concatenate((projected_np, _np), axis=0)
    else:
        projected_np = latent_data.numpy().reshape(B, -1)

    model_file_name = model_settings.get_model_file_name(TASK=task, args=args)

    if args.do_svd:
        print('Doing svd...')
        u, s, v = np.linalg.svd(projected_np, full_matrices=False)
        np.save(
            'BAPP_result/%s_%s_internal_dim_u.npy' %
            (model_file_name, pgen_type), u)
        np.save(
            'BAPP_result/%s_%s_internal_dim_s.npy' %
            (model_file_name, pgen_type), s)
        np.save(
            'BAPP_result/%s_%s_internal_dim_v.npy' %
            (model_file_name, pgen_type), v)
    else:
        u = np.load('BAPP_result/%s_%s_internal_dim_u.npy' %
                    (model_file_name, pgen_type))
        s = np.load('BAPP_result/%s_%s_internal_dim_s.npy' %
                    (model_file_name, pgen_type))
        v = np.load('BAPP_result/%s_%s_internal_dim_v.npy' %
                    (model_file_name, pgen_type))
        projected_np = u.dot(np.diag(s)).dot(v)
    cos_sims = []
    s_keep = np.zeros(s.shape)
    with tqdm(range(s.shape[0])) as pbar:
        for i in pbar:
            s_keep[i] = s[i]
            if i % inter_gap == 0:
                reconstructed = u.dot(np.diag(s_keep)).dot(v)
                cos_sim = np.mean(
                    utils.calc_cos_sim(x1=projected_np, x2=reconstructed, dim=1))
                cos_sims.append(cos_sim)
                pbar.set_description('Keep dim %d, cosine similarity %f' %
                                     (i, cos_sim))
                np.save(
                    'BAPP_result/%s_%s_internal_dim.npy' %
                    (model_file_name, pgen_type), cos_sims)
                if cos_sim > 0.9999:
                    break
    print("%s, %s, Cosine similarity for internal dim done" %
          (model_file_name, pgen_type))
    print(cos_sims)
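calc_internal_dim relies on utils.calc_cos_sim, which these examples do not show. A plausible row-wise implementation over numpy arrays, consistent with the x1, x2, dim=1 call above (a sketch, not the project's actual helper):

import numpy as np

def calc_cos_sim(x1, x2, dim=1, eps=1e-12):
    # Cosine similarity between corresponding rows of x1 and x2 along `dim`.
    dot = np.sum(x1 * x2, axis=dim)
    norms = np.linalg.norm(x1, axis=dim) * np.linalg.norm(x2, axis=dim)
    return dot / (norms + eps)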
Example #9
def load_pgen(task, pgen_type, args):
    if task == 'imagenet' or task == 'celeba' or task == 'celebaid' or task == 'celeba2':
        if pgen_type == 'naive':
            p_gen = None

        elif pgen_type == 'resize9408':
            p_gen = ResizeGenerator(factor=4.0)

        elif pgen_type == 'DCT2352':
            p_gen = DCTGenerator(factor=8.0)
        elif pgen_type == 'DCT4107':
            p_gen = DCTGenerator(factor=6.0)
        elif pgen_type == 'DCT9408':
            p_gen = DCTGenerator(factor=4.0)
        elif pgen_type == 'DCT16428':
            p_gen = DCTGenerator(factor=3.0)

        elif pgen_type == 'PCA9408':
            p_gen = PCAGenerator(N_b=9408, approx=True, basis_only=True)
            # p_gen.load('./gen_models/pca_gen_%s_%d.npy' % (task, 9408))
            p_gen.load('./gen_models/pca_gen_%s_%d.npy' % ('imagenet', 9408))

        elif pgen_type == 'AE128':
            p_gen = AEGenerator(n_channels=3, gpu=args.use_gpu, N_Z=128)
            p_gen.load_state_dict(
                torch.load('./gen_models/%s_gradient_ae_%d_generator.model' %
                           (task, 128)))
        elif pgen_type == 'AE9408':
            p_gen = AEGenerator(n_channels=3, gpu=args.use_gpu, N_Z=9408)
            p_gen.load_state_dict(
                torch.load('./gen_models/%s_gradient_ae_%d_generator.model' %
                           (task, 9408)))

        elif pgen_type == 'VAE9408':
            p_gen = VAEGenerator(n_channels=3, gpu=args.use_gpu)
            if args.TASK == 'celeba' and args.smooth:
                p_gen.load_state_dict(
                    torch.load(
                        './gen_models/%s_gradient_vae_%d_smooth%s_generator.model'
                        % (task, 9408, args.smooth_suffix)))
            else:
                p_gen.load_state_dict(
                    torch.load(
                        './gen_models/%s_gradient_vae_%d_generator.model' %
                        (task, 9408)))

        elif pgen_type == 'GAN128':
            p_gen = GANGenerator(n_z=128, n_channels=3, gpu=args.use_gpu)
            p_gen.load_state_dict(
                torch.load('./gen_models/%s_gradient_gan_%d_generator.model' %
                           (task, 128)))
        elif pgen_type == 'GAN9408':
            p_gen = GANGenerator(n_z=9408, n_channels=3, gpu=args.use_gpu)
            p_gen.load_state_dict(
                torch.load('./gen_models/%s_gradient_gan_%d_generator.model' %
                           (task, 9408)))

        elif pgen_type == 'oldAE9408':
            p_gen = OldAEGenerator(n_channels=3, gpu=args.use_gpu)
            p_gen.load_state_dict(
                torch.load(
                    './gen_models/%s_gradient_old_ae_%d_generator.model' %
                    (task, 9408)))

        elif pgen_type == 'expcos9408':
            p_gen = ExpCosGenerator(n_channels=3,
                                    gpu=args.use_gpu,
                                    N_Z=9408,
                                    lmbd=args.lmbd)
            p_gen.load_state_dict(
                torch.load(
                    './gen_models/%s_gradient_expcos_%d_generator.model' %
                    (task, 9408)))

    elif task == 'cifar10':
        model_file_name = model_settings.get_model_file_name("cifar10", args)
        n_channels = 3
        if args.cifar10_img_size == 32:
            if pgen_type == 'naive':
                p_gen = None

            elif pgen_type == 'resize192':
                p_gen = ResizeGenerator(factor=4.0)

            elif pgen_type == 'DCT192':
                p_gen = DCTGenerator(factor=4.0)

            elif pgen_type == 'AE192':
                p_gen = Cifar10AEGenerator(gpu=args.use_gpu)
                p_gen.load_state_dict(
                    torch.load(
                        './gen_models/%s_gradient_ae_%d_generator.model' %
                        (task, 192)))

        elif args.cifar10_img_size == 224:
            if pgen_type == 'naive':
                p_gen = None

            elif pgen_type == 'resize9408':
                p_gen = ResizeGenerator(factor=4.0)

            elif pgen_type == 'DCT9408':
                p_gen = DCTGenerator(factor=4.0)

            elif pgen_type == 'PCA9408':
                N_b = 9408
                approx = True
                p_gen = PCAGenerator(N_b=N_b, approx=approx, basis_only=True)
                # p_gen.load('./gen_models/pca_gen_%s_%d.npy' % (model_file_name, 9408))
                p_gen.load('./gen_models/pca_gen_%s_%d.npy' %
                           ('imagenet', 9408))

            elif pgen_type.startswith('AE'):
                N_Z = int(pgen_type[2:])
                p_gen = AEGenerator(n_channels=3,
                                    preprocess=None,
                                    gpu=args.use_gpu,
                                    N_Z=N_Z)
                p_gen.load_state_dict(
                    torch.load(
                        './gen_models/%s_gradient_ae_%d_generator.model' %
                        (model_file_name, N_Z)))

            elif pgen_type.startswith('GAN'):
                N_Z = int(pgen_type[3:])
                p_gen = GANGenerator(n_z=N_Z,
                                     n_channels=n_channels,
                                     gpu=args.use_gpu)
                p_gen.load_state_dict(
                    torch.load(
                        './gen_models/%s_gradient_gan_%d_generator.model' %
                        (model_file_name, N_Z)))

            elif pgen_type == 'DCGAN':
                p_gen = Cifar10DCGenerator(ngpu=1).eval()
                p_gen.load_state_dict(
                    torch.load('./models/weights/cifar10_netG_epoch_199.pth'),
                    strict=False)

            elif pgen_type.startswith('DCGAN_finetune'):
                n_epoch = int(pgen_type[-1])
                p_gen = Cifar10DCGenerator(ngpu=1).eval()
                # p_gen.load_state_dict(torch.load('./models/weights/cifar10_224_netG_100_epoch%d.pth'%(n_epoch)), strict=False)
                p_gen.load_state_dict(torch.load(
                    './gen_models/cifar10_224_netG_100_epoch%d.pth' %
                    (n_epoch)),
                                      strict=False)

            elif pgen_type.startswith('VAE'):
                N_Z = int(pgen_type[3:])
                assert N_Z == 9408
                p_gen = VAEGenerator(n_channels=n_channels, gpu=args.use_gpu)
                if args.smooth:
                    p_gen.load_state_dict(
                        torch.load(
                            './gen_models/%s_gradient_vae_%d_smooth%s_generator.model'
                            % (model_file_name, int(
                                pgen_type[3:]), args.smooth_suffix)))
                else:
                    p_gen.load_state_dict(
                        torch.load(
                            './gen_models/%s_gradient_vae_%d_generator.model' %
                            (model_file_name, int(pgen_type[3:]))))
        else:
            print('cifar10', args.cifar10_img_size, pgen_type,
                  "Not implemented")
            assert 0

    elif task == 'mnist':
        # model_file_name, output_file_name = model_settings.get_model_file_name("mnist", args)
        model_file_name = model_settings.get_model_file_name("mnist", args)
        n_channels = 1
        if args.mnist_img_size == 28:
            if pgen_type == 'naive':
                p_gen = None

            elif pgen_type == 'resize':
                p_gen = ResizeGenerator(factor=4.0)

            elif pgen_type == 'DCT':
                p_gen = DCTGenerator(factor=4.0)

            elif pgen_type.startswith('AE'):
                p_gen = MNISTAEGenerator(gpu=args.use_gpu,
                                         N_Z=int(pgen_type[2:]))
                p_gen.load_state_dict(
                    torch.load(
                        './gen_models/%s_%d_gradient_ae_%d_generator.model' %
                        (task, args.mnist_img_size, int(pgen_type[2:]))))

            elif pgen_type.startswith('GAN'):
                print("Not implemented yet")
                assert 0

            elif pgen_type.startswith('VAE'):
                print("Not implemented yet")
                assert 0

        elif args.mnist_img_size == 224:
            if pgen_type == 'naive':
                p_gen = None

            elif pgen_type == 'resize9408':
                p_gen = ResizeGenerator(factor=4.0)

            elif pgen_type == 'DCT9408':
                p_gen = DCTGenerator(factor=4.0)

            elif pgen_type == 'PCA9408':
                N_b = 9408
                approx = True
                p_gen = PCAGenerator(N_b=N_b, approx=approx, basis_only=True)
                # p_gen.load('./gen_models/pca_gen_%s_%d.npy' % (model_file_name, 9408))
                p_gen.load('./gen_models/pca_gen_%s_%d.npy' %
                           ('imagenet', 9408))

            elif pgen_type.startswith('AE'):
                if pgen_type.endswith('9408'):
                    p_gen = AEGenerator(n_channels=n_channels,
                                        preprocess=None,
                                        gpu=args.use_gpu,
                                        N_Z=9408)
                else:
                    print("Not implemented yet")
                    assert 0
                p_gen.load_state_dict(
                    torch.load(
                        './gen_models/%s_gradient_ae_%d_generator.model' %
                        (model_file_name, int(pgen_type[2:]))))

            elif pgen_type.startswith('GAN'):
                assert int(pgen_type[3:]) == 9408
                p_gen = GANGenerator(n_z=int(pgen_type[3:]),
                                     n_channels=n_channels,
                                     gpu=args.use_gpu)
                p_gen.load_state_dict(
                    torch.load(
                        './gen_models/%s_gradient_gan_%d_generator.model' %
                        (model_file_name, int(pgen_type[3:]))))

            elif pgen_type.startswith('DCGAN'):
                p_gen = MNISTDCGenerator(ngpu=1).eval()
                p_gen.load_state_dict(
                    torch.load('./models/weights/MNIST_netG_epoch_99.pth'),
                    strict=False)

            elif pgen_type.startswith('VAE'):
                if int(pgen_type[3:]) == 9408:
                    p_gen = VAEGenerator(n_channels=n_channels,
                                         gpu=args.use_gpu)
                    if args.smooth:
                        p_gen.load_state_dict(
                            torch.load(
                                './gen_models/%s_gradient_vae_%d_smooth%s_generator.model'
                                % (model_file_name, 9408, args.smooth_suffix)))
                    else:
                        p_gen.load_state_dict(
                            torch.load(
                                './gen_models/%s_gradient_vae_%d_generator.model'
                                % (model_file_name, 9408)))
                elif int(pgen_type[3:]) == 3136:
                    p_gen = MNIST224VAEGenerator(n_channels=n_channels,
                                                 gpu=args.use_gpu)
                    p_gen.load_state_dict(
                        torch.load(
                            './gen_models/%s_gradient_vae_%d_generator.model' %
                            (model_file_name, 3136)))

        else:
            print('mnist', args.mnist_img_size, pgen_type, "Not implemented")
            assert 0

    elif task == 'dogcat2':
        if pgen_type == 'naive':
            p_gen = None

        elif pgen_type == 'resize9408':
            p_gen = ResizeGenerator(factor=4.0)

        elif pgen_type == 'DCT9408':
            p_gen = DCTGenerator(factor=4.0)

        elif pgen_type == 'AE9408':
            p_gen = AEGenerator(n_channels=3, gpu=args.use_gpu, N_Z=9408)
            p_gen.load_state_dict(
                torch.load('./gen_models/%s_gradient_ae_%d_generator.model' %
                           (task, 9408)))

        elif pgen_type == 'GAN9408':
            p_gen = GANGenerator(n_z=9408, n_channels=3, gpu=args.use_gpu)
            p_gen.load_state_dict(
                torch.load('./gen_models/%s_gradient_gan_%d_generator.model' %
                           (task, 9408)))

        elif pgen_type == 'VAE9408':
            p_gen = VAEGenerator(n_channels=3, gpu=args.use_gpu)
            p_gen.load_state_dict(
                torch.load('./gen_models/%s_gradient_vae_%d_generator.model' %
                           (task, 9408)))

    return p_gen
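Each non-None generator returned by load_pgen exposes a project method, as used in the calc_internal_dim example above. A minimal usage sketch; the task, the pgen_type, and the reliance on pgen_input_size from the same attack_setting module are assumptions for illustration:

# Hypothetical usage; `args` is an argparse.Namespace like the ones built in the other examples.
import torch
import attack_setting  # assumed module containing load_pgen / pgen_input_size

p_gen = attack_setting.load_pgen(task='imagenet', pgen_type='DCT9408', args=args)
input_size = attack_setting.pgen_input_size(task='imagenet', pgen_type='DCT9408', args=args)
latent = torch.randn((1, *input_size))
if args.use_gpu:
    latent = latent.cuda()
# `naive` returns None, in which case the latent sample is already the image-space perturbation.
perturbation = latent if p_gen is None else p_gen.project(latent)
print(perturbation.shape)  # image-space perturbation, e.g. (1, 3, 224, 224)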
Example #10
    # args.padding_size
    # 0: no padding
    # for simplicity, if input size [A*A], output size [B*B], then padding_size = (B-A)/2
    parser.add_argument('--mnist_padding_size', type=int, default=0)
    parser.add_argument('--mnist_padding_first', action='store_true')

    parser.add_argument('--mounted', action='store_true')
    parser.add_argument('--N_Z', type=int, default=9408)
    args = parser.parse_args()

    mnist_img_size = args.mnist_img_size
    n_class = 10

    TASK = 'mnist'

    TASK = model_settings.get_model_file_name(TASK, args)

    transform = model_settings.get_data_transformation("mnist", args)
    model_file_name = model_settings.get_model_file_name("mnist", args)
    print(model_file_name)

    trainset = torchvision.datasets.MNIST(root='../raw_data/', train=True, download=True, transform=transform)
    testset = torchvision.datasets.MNIST(root='../raw_data/', train=False, download=True, transform=transform)

    trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
    testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=True)

    if TASK.startswith('mnist'):
        n_channels = 1
    else:
        n_channels = 3