def cifar10_attack(args, N_img):
    """Collect ``N_img`` random CIFAR-10 test-set (source, target) pairs for a targeted attack.

    Loads the ResNet-18 checkpoint named by
    ``model_settings.get_model_file_name("cifar10", args)`` and wraps it in a
    foolbox ``PyTorchModel`` with bounds (0, 1) and per-channel mean/std
    preprocessing. It then draws random ordered index pairs from the test split,
    never repeating a pair, until ``N_img`` pairs with *different* predicted
    labels have been found. Ground-truth labels are not consulted.

    Returns:
        (src_images, src_labels, tgt_images, tgt_labels, fmodel, mask)
        - Images are numpy arrays produced by the un-normalized data transform.
        - Labels are the model's argmax predictions.
        - ``mask`` is always None.
    """
    model_file_name = model_settings.get_model_file_name("cifar10", args)

    from models import CifarDNN
    net = CifarDNN(model_type='res18', pretrained=False, gpu=args.use_gpu,
                   img_size=args.cifar10_img_size).eval()
    net.load_state_dict(
        torch.load('../class_models/%s_res18.model' % model_file_name))

    # Broadcast the per-channel statistics over (H, W).
    mean, std = constants.get_mean_std('cifar10')
    per_channel = (3, 1, 1)
    fmodel = foolbox.models.PyTorchModel(
        net, bounds=(0, 1), num_classes=10,
        preprocessing=(mean.reshape(per_channel), std.reshape(per_channel)))

    import torchvision.datasets as datasets
    dataset = datasets.CIFAR10(
        root=f"{constants.ROOT_DATA_PATH}", train=False, download=True,
        transform=model_settings.get_data_transformation_without_normalization(
            'cifar10', args))

    n_total = len(dataset)
    seen_pairs = set()
    src_images, src_labels, tgt_images, tgt_labels = [], [], [], []
    while len(src_images) < N_img:
        # Source index is drawn before target index, as in the original loop.
        pair = (np.random.randint(n_total), np.random.randint(n_total))
        if pair in seen_pairs:
            continue
        seen_pairs.add(pair)

        src_x, tgt_x = (dataset[idx][0].numpy() for idx in pair)
        src_pred = np.argmax(fmodel.forward_one(src_x))
        tgt_pred = np.argmax(fmodel.forward_one(tgt_x))
        if src_pred == tgt_pred:
            # A targeted attack needs a target of a different class.
            continue

        src_images.append(src_x)
        tgt_images.append(tgt_x)
        src_labels.append(src_pred)
        tgt_labels.append(tgt_pred)

    return src_images, src_labels, tgt_images, tgt_labels, fmodel, None
def mnist_attack(args, N_img, num_class, mounted=False):
    """Collect ``N_img`` random MNIST test-set (source, target) pairs for a targeted attack.

    Loads the ResNet-18 checkpoint named by
    ``model_settings.get_model_file_name("mnist", args)`` and wraps it in a
    foolbox ``PyTorchModel`` with bounds (0, 1) and the MNIST mean/std
    preprocessing. It then draws random ordered index pairs from the test split,
    never repeating a pair, and keeps a pair only if three conditions hold:
    - the two predicted labels differ;
    - the source prediction matches its ground-truth label;
    - the target prediction matches its ground-truth label.

    ``mounted`` is accepted for call-site compatibility but is not used.

    Returns:
        (src_images, src_labels, tgt_images, tgt_labels, fmodel, mask)
        - Images are numpy arrays from the un-normalized transform.
        - Labels are the model's argmax predictions.
        - ``mask`` is always None.
    """
    model_file_name = model_settings.get_model_file_name("mnist", args)

    from models import MNISTDNN
    net = MNISTDNN(model_type='res18', gpu=args.use_gpu, n_class=num_class,
                   img_size=args.mnist_img_size).eval()
    net.load_state_dict(
        torch.load('../class_models/%s_%s.model' % (model_file_name, 'res18')))

    mean, std = constants.get_mean_std('mnist')
    fmodel = foolbox.models.PyTorchModel(net, bounds=(0, 1),
                                         num_classes=num_class,
                                         preprocessing=(mean, std))

    import torchvision
    dataset = torchvision.datasets.MNIST(
        root=f"{constants.ROOT_DATA_PATH}", train=False, download=True,
        transform=model_settings.get_data_transformation_without_normalization(
            'mnist', args))

    n_total = len(dataset)
    seen_pairs = set()
    src_images, src_labels, tgt_images, tgt_labels = [], [], [], []
    while len(src_images) < N_img:
        # Source index is drawn before target index, as in the original loop.
        pair = (np.random.randint(n_total), np.random.randint(n_total))
        if pair in seen_pairs:
            continue
        seen_pairs.add(pair)

        src_x, src_gt = dataset[pair[0]]
        tgt_x, tgt_gt = dataset[pair[1]]
        src_x, tgt_x = src_x.numpy(), tgt_x.numpy()
        src_pred = np.argmax(fmodel.forward_one(src_x))
        tgt_pred = np.argmax(fmodel.forward_one(tgt_x))

        classes_differ = src_pred != tgt_pred
        # Only use images the classifier gets right.
        both_correct = (src_gt == src_pred) and (tgt_gt == tgt_pred)
        if not (classes_differ and both_correct):
            continue

        src_images.append(src_x)
        tgt_images.append(tgt_x)
        src_labels.append(src_pred)
        tgt_labels.append(tgt_pred)

    print("MNIST attack, %d src imgs, %d tgt imgs" %
          (len(src_images), len(tgt_images)))
    return src_images, src_labels, tgt_images, tgt_labels, fmodel, None
parser.add_argument('--N_Z', type=int) parser.add_argument('--mounted', action='store_true') parser.add_argument('--mnist_img_size', type=int, default=28) parser.add_argument('--mnist_padding_size', type=int, default=0) parser.add_argument('--mnist_padding_first', action='store_true') parser.add_argument('--cifar10_img_size', type=int, default=32) parser.add_argument('--cifar10_padding_size', type=int, default=0) parser.add_argument('--cifar10_padding_first', action='store_true') args = parser.parse_args() GPU = True TASK = args.TASK if TASK == 'mnist' or TASK == 'cifar10': TASK = model_settings.get_model_file_name(TASK, args) # TODO MNIST N_Z 3136 if TASK.startswith('mnist'): n_channels = 1 else: n_channels = 3 N_Z = args.N_Z model = AEGenerator(n_channels=n_channels, gpu=GPU, N_Z=N_Z) if N_Z == 128: ENC_SHAPE = (8, 4, 4) elif N_Z == 9408: ENC_SHAPE = (48, 14, 14) else:
parser.add_argument('--cifar10_padding_size', type=int, default=0) parser.add_argument('--cifar10_padding_first', action='store_true') parser.add_argument('--smooth', action='store_true') parser.add_argument('--smooth_suffix', type=str, default='') args = parser.parse_args() GPU = True # TASK = 'celeba' TASK = args.TASK # if TASK == 'mnist' or TASK == 'cifar10': # TASK, output_file_name = model_settings.get_model_file_name(TASK, args) # else: # output_file_name = TASK if TASK == 'mnist' or TASK == 'cifar10': TASK = model_settings.get_model_file_name(TASK, args) N_Z = args.N_Z print(TASK, N_Z) if TASK.startswith('mnist'): n_channels = 1 else: n_channels = 3 if TASK.startswith('mnist') and N_Z == 3136: model = MNIST224VAEGenerator(n_channels=n_channels, gpu=GPU) else: model = VAEGenerator(n_channels=n_channels, gpu=GPU)
def load_model_dataset(TASK, REF, root_dir):
    """Build the reference classifier and the train/test datasets for ``TASK``.

    Args:
        TASK: one of 'imagenet', 'celeba', 'celebaid', 'dogcat2', 'cifar10',
            'mnist', 'celeba2'.
        REF: architecture name of the reference classifier (e.g. 'res18').
            For 'imagenet' it selects a pretrained torchvision model.
        root_dir: data root for the 'celeba' and 'celebaid' datasets. The
            'celeba2' branch replaces it with ``constants.ROOT_DATA_PATH``.
            The other tasks do not use it.

    Returns:
        ``(ref_model, trainset, testset, preprocess_std)``.
        - ``ref_model`` is switched to eval mode before returning.
        - ``preprocess_std`` is a per-channel std. It is the std given to
          ``transforms.Normalize`` for imagenet/celeba/celebaid/celeba2,
          ``(1, 1, 1)`` for 'dogcat2' (its transform has no Normalize), and
          the std returned by ``constants.plot_mean_std`` for 'cifar10'/'mnist'.

    Raises:
        AssertionError: if ``TASK`` is unknown, or ``REF`` is unknown for
            'imagenet'.

    NOTE(review): besides its arguments, this function reads module-level
    names that the calling script must define:
    - ``GPU`` and ``args``;
    - ``cifar10_img_size`` for the cifar10 branch;
    - ``mnist_img_size`` for the mnist branch.
    """

    def _load_weights(model, path):
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts;
        # load_state_dict copies the tensors into the model's own parameters.
        model.load_state_dict(torch.load(path, map_location='cpu'))

    def _crop_transform(size, normalize=True):
        # Resize + center-crop pipeline shared by the CelebA and dog/cat tasks.
        steps = [transforms.Resize(size),
                 transforms.CenterCrop(size),
                 transforms.ToTensor()]
        if normalize:
            steps.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
        return transforms.Compose(steps)

    if TASK == 'imagenet':
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])
        trainset = ImageNetDataset(train=True, transform=transform)
        testset = ImageNetDataset(train=False, transform=transform)
        # Map REF to a torchvision.models constructor name.
        imagenet_archs = {
            'dense121': 'densenet121',
            'res18': 'resnet18',
            'res50': 'resnet50',
            'vgg16': 'vgg16',
            'googlenet': 'googlenet',
            'wideresnet': 'wide_resnet50_2',
        }
        if REF not in imagenet_archs:
            raise AssertionError('Unknown ImageNet reference model %r' % (REF,))
        ref_model = getattr(models, imagenet_archs[REF])(pretrained=True).eval()
        if GPU:
            ref_model.cuda()
        preprocess_std = (0.229, 0.224, 0.225)
    elif TASK == 'celeba':
        transform = _crop_transform(224)
        attr_o_i = 'Mouth_Slightly_Open'
        list_attr_path = '%s/celebA/list_attr_celeba.txt' % (root_dir)
        img_data_path = '%s/celebA/img_align_celeba' % (root_dir)
        trainset = CelebAAttributeDataset(attr_o_i, list_attr_path, img_data_path,
                                          data_split='train', transform=transform)
        testset = CelebAAttributeDataset(attr_o_i, list_attr_path, img_data_path,
                                         data_split='test', transform=transform)
        ref_model = CelebADNN(model_type=REF, pretrained=False, gpu=GPU)
        _load_weights(ref_model,
                      '../class_models/celeba_%s_%s.model' % (attr_o_i, REF))
        preprocess_std = (0.5, 0.5, 0.5)
    elif TASK == 'celebaid':
        transform = _crop_transform(224)
        num_class = 10
        trainset = CelebAIDDataset(root_dir=root_dir, is_train=True,
                                   transform=transform, preprocess=False,
                                   random_sample=False, n_id=num_class)
        testset = CelebAIDDataset(root_dir=root_dir, is_train=False,
                                  transform=transform, preprocess=False,
                                  random_sample=False, n_id=num_class)
        ref_model = CelebAIDDNN(model_type=REF, num_class=num_class,
                                pretrained=False, gpu=GPU).eval()
        _load_weights(ref_model, '../class_models/celeba_id_%s.model' % (REF))
        preprocess_std = (0.5, 0.5, 0.5)
    elif TASK == 'dogcat2':
        transform = _crop_transform(224, normalize=False)
        num_class = 2
        data_root = f'{constants.ROOT_DATA_PATH}/dogcat'
        trainset = DogCatDataset(root_dir=data_root, mode='train',
                                 transform=transform)
        testset = DogCatDataset(root_dir=data_root, mode='test',
                                transform=transform)
        ref_model = NClassDNN(model_type=REF, pretrained=False, gpu=GPU,
                              n_class=num_class)
        _load_weights(ref_model,
                      '../class_models/%s_%s.model' % ('dogcat2', REF))
        preprocess_std = (1, 1, 1)
    elif TASK == 'cifar10':
        _, std = constants.plot_mean_std(TASK)
        transform = model_settings.get_data_transformation(TASK, args)
        model_file_name = model_settings.get_model_file_name(TASK, args)
        data_root = '%s/' % (constants.ROOT_DATA_PATH)
        trainset = torchvision.datasets.CIFAR10(root=data_root, train=True,
                                                download=True,
                                                transform=transform)
        testset = torchvision.datasets.CIFAR10(root=data_root, train=False,
                                               download=True,
                                               transform=transform)
        # NOTE(review): img size comes from a module-level global, not args.
        ref_model = CifarDNN(model_type=REF, gpu=GPU, pretrained=False,
                             img_size=cifar10_img_size)
        _load_weights(ref_model,
                      '../class_models/%s_%s.model' % (model_file_name, REF))
        if GPU:
            ref_model.cuda()
        preprocess_std = std
    elif TASK == 'mnist':
        _, std = constants.plot_mean_std(TASK)
        transform = model_settings.get_data_transformation(TASK, args)
        model_file_name = model_settings.get_model_file_name(TASK, args)
        num_class = 10
        data_root = '%s/' % (constants.ROOT_DATA_PATH)
        trainset = torchvision.datasets.MNIST(root=data_root, train=True,
                                              download=True,
                                              transform=transform)
        testset = torchvision.datasets.MNIST(root=data_root, train=False,
                                             download=True,
                                             transform=transform)
        # NOTE(review): img size comes from a module-level global, not args.
        ref_model = MNISTDNN(model_type=REF, gpu=GPU, n_class=num_class,
                             img_size=mnist_img_size)
        _load_weights(ref_model,
                      '../class_models/%s_%s.model' % (model_file_name, REF))
        preprocess_std = std
    elif TASK == 'celeba2':
        num_class = 2
        transform = _crop_transform(224)
        root_dir = f"{constants.ROOT_DATA_PATH}"
        trainset = CelebABinIDDataset(poi=args.celeba_poi, root_dir=root_dir,
                                      is_train=True, transform=transform,
                                      get_data=False, random_sample=False,
                                      n_id=10)
        testset = CelebABinIDDataset(poi=args.celeba_poi, root_dir=root_dir,
                                     is_train=False, transform=transform,
                                     get_data=False, random_sample=False,
                                     n_id=10)
        ref_model = NClassDNN(model_type=REF, pretrained=False, gpu=GPU,
                              n_class=num_class)
        _load_weights(ref_model,
                      '../class_models/%s_%s.model' % ('celeba2', REF))
        preprocess_std = (0.5, 0.5, 0.5)
    else:
        # Raised explicitly (not `assert 0`) so the check survives `python -O`.
        raise AssertionError('Unknown TASK %r' % (TASK,))

    # The reference model is only used for inference; make every branch
    # consistent with the imagenet/celebaid ones, which already used eval mode.
    ref_model.eval()
    return ref_model, trainset, testset, preprocess_std
if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--cifar10_img_size', type=int, default=32) parser.add_argument('--cifar10_padding_size', type=int, default=0) parser.add_argument('--cifar10_padding_first', action='store_true') parser.add_argument('--do_train', action='store_true') args = parser.parse_args() img_size = args.cifar10_img_size mean = (0.485, 0.456, 0.406) std = (0.229, 0.224, 0.225) transform = model_settings.get_data_transformation("cifar10", args) # model_file_name, _ = model_settings.get_model_file_name("cifar10", args) model_file_name = model_settings.get_model_file_name("cifar10", args) print(model_file_name) trainset = torchvision.datasets.CIFAR10(root='../raw_data/', train=True, download=True, transform=transform) testset = torchvision.datasets.CIFAR10(root='../raw_data/', train=False, download=True, transform=transform) trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True) testloader = torch.utils.data.DataLoader(testset,
parser.add_argument('--mnist_img_size', type=int, default=28) parser.add_argument('--mnist_padding_size', type=int, default=0) parser.add_argument('--mnist_padding_first', action='store_true') parser.add_argument('--cifar10_img_size', type=int, default=32) parser.add_argument('--cifar10_padding_size', type=int, default=0) parser.add_argument('--cifar10_padding_first', action='store_true') parser.add_argument('--pretrained', action='store_true') args = parser.parse_args() GPU = True TASK = args.TASK if TASK == 'mnist' or TASK == 'cifar10': TASK = model_settings.get_model_file_name(TASK, args) #TODO mnist N_Z 3136 model_file_name = TASK if TASK.startswith('mnist'): n_channels = 1 else: n_channels = 3 N_Z = args.N_Z modelD = GANDiscriminator(n_channels=n_channels, gpu=GPU) modelG = GANGenerator(n_z=N_Z, n_channels=n_channels, gpu=GPU) if args.pretrained: modelD.load_state_dict(torch.load('./gen_models/normal_%s_gradient_gan_%d_discriminator.model' % (TASK, N_Z))) modelG.load_state_dict(torch.load('./gen_models/normal_%s_gradient_gan_%d_generator.model' % (TASK, N_Z)))
def calc_internal_dim(task, pgen_type, args, n_samples=5000, batch_size=32):
    """Measure how many singular directions a generator's outputs occupy.

    Steps:
    1. Draw ``n_samples`` standard-normal latents shaped by
       ``attack_setting.pgen_input_size(...)``.
    2. Map them through the generator from ``attack_setting.load_pgen`` in
       mini-batches of ``batch_size``, or use them directly when that returns
       None.
    3. Flatten each sample into one row of a ``(n_samples, D)`` matrix.
    4. If ``args.do_svd`` is set, compute the thin SVD of that matrix and save
       it under ``BAPP_result/``. Otherwise load a previously saved SVD and
       rebuild the matrix from it.
    5. For every ``i`` divisible by ``inter_gap``, rebuild the matrix from the
       first ``i + 1`` singular triplets and record the mean row-wise cosine
       similarity (``utils.calc_cos_sim``) to the full matrix. The list of
       similarities is re-saved after each measurement, and the loop stops
       once a similarity exceeds 0.9999.

    Args:
        task, pgen_type, args: forwarded to ``attack_setting`` /
            ``model_settings``. Reads ``args.use_gpu`` and ``args.do_svd``.
        n_samples: number of latent samples (default 5000).
        batch_size: projection mini-batch size (default 32).

    NOTE(review): ``inter_gap`` is read from module scope and must be defined
    by the calling script.
    """
    p_gen = attack_setting.load_pgen(task=task, pgen_type=pgen_type, args=args)
    input_size = attack_setting.pgen_input_size(task=task, pgen_type=pgen_type,
                                                args=args)
    latent_data = torch.randn((n_samples, *input_size))
    if p_gen is not None:
        if args.use_gpu:
            latent_data = latent_data.cuda()
        # Project in mini-batches to bound memory. Collect the chunks and
        # concatenate once: concatenating inside the loop would be quadratic.
        chunks = []
        for start in range(0, n_samples, batch_size):
            batch = latent_data[start:start + batch_size]
            projected = p_gen.project(batch)
            chunks.append(
                projected.detach().cpu().numpy().reshape(batch.shape[0], -1))
        projected_np = np.concatenate(chunks, axis=0)
    else:
        projected_np = latent_data.numpy().reshape(n_samples, -1)

    model_file_name = model_settings.get_model_file_name(TASK=task, args=args)
    prefix = 'BAPP_result/%s_%s_internal_dim' % (model_file_name, pgen_type)
    if args.do_svd:
        print('Doing svd...')
        u, s, v = np.linalg.svd(projected_np, full_matrices=False)
        np.save(prefix + '_u.npy', u)
        np.save(prefix + '_s.npy', s)
        np.save(prefix + '_v.npy', v)
    else:
        u = np.load(prefix + '_u.npy')
        s = np.load(prefix + '_s.npy')
        v = np.load(prefix + '_v.npy')
        # (u * s) scales column j of u by s[j]. This equals u @ diag(s)
        # without building the dense K x K diagonal matrix.
        projected_np = (u * s) @ v

    cos_sims = []
    s_keep = np.zeros(s.shape)
    with tqdm(range(s.shape[0])) as pbar:
        for i in pbar:
            s_keep[i] = s[i]
            if i % inter_gap != 0:
                continue
            approx = (u * s_keep) @ v  # rank-(i + 1) reconstruction
            cos_sim = np.mean(
                utils.calc_cos_sim(x1=projected_np, x2=approx, dim=1))
            cos_sims.append(cos_sim)
            pbar.set_description('Keep dim %d, cosine similarity %f' %
                                 (i, cos_sim))
            np.save(prefix + '.npy', cos_sims)
            if cos_sim > 0.9999:
                break
    print("%s, %s, Cosine similarity for internal dim done" %
          (model_file_name, pgen_type))
    print(cos_sims)
def load_pgen(task, pgen_type, args):
    """Construct the perturbation generator named by ``pgen_type`` for ``task``.

    Learned generators (AE / VAE / GAN / DCGAN / ...) also have their weights
    loaded from ``./gen_models`` or ``./models/weights``. Which ``pgen_type``
    names are accepted depends on ``task``, and for 'cifar10' / 'mnist' also on
    ``args.cifar10_img_size`` / ``args.mnist_img_size``. Examples include
    'naive', 'resize9408', 'DCT9408', 'PCA9408', 'AE9408', 'GAN9408' and
    'VAE9408'.

    Args:
        task: 'imagenet', 'celeba', 'celebaid', 'celeba2', 'cifar10', 'mnist'
            or 'dogcat2'.
        pgen_type: generator identifier.
        args: parsed command-line namespace. Reads ``use_gpu``; depending on
            the branch it may also read ``TASK``, ``smooth``,
            ``smooth_suffix``, ``lmbd``, ``cifar10_img_size`` and
            ``mnist_img_size``.

    Returns:
        The generator object, or None for 'naive'.

    Raises:
        AssertionError: the (task, image size, pgen_type) combination is not
            supported. Previously several such combinations fell through and
            surfaced as ``UnboundLocalError``.
    """
    import re

    def _load(p_gen, path, **kwargs):
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts;
        # load_state_dict copies the tensors into the module's own parameters.
        p_gen.load_state_dict(torch.load(path, map_location='cpu'), **kwargs)
        return p_gen

    def _unsupported(*info):
        # Print the diagnostic, then hand back the error for the caller to raise.
        print(*info)
        return AssertionError('Unsupported pgen_type %r for task %r' %
                              (pgen_type, task))

    def _imagenet_pca9408():
        p_gen = PCAGenerator(N_b=9408, approx=True, basis_only=True)
        # Always loads the ImageNet PCA basis, whatever the task is.
        p_gen.load('./gen_models/pca_gen_%s_%d.npy' % ('imagenet', 9408))
        return p_gen

    if task in ('imagenet', 'celeba', 'celebaid', 'celeba2'):
        dct_factors = {'DCT2352': 8.0, 'DCT4107': 6.0,
                       'DCT9408': 4.0, 'DCT16428': 3.0}
        if pgen_type == 'naive':
            return None
        if pgen_type == 'resize9408':
            return ResizeGenerator(factor=4.0)
        if pgen_type in dct_factors:
            return DCTGenerator(factor=dct_factors[pgen_type])
        if pgen_type == 'PCA9408':
            return _imagenet_pca9408()
        if pgen_type in ('AE128', 'AE9408'):
            n_z = int(pgen_type[2:])
            return _load(
                AEGenerator(n_channels=3, gpu=args.use_gpu, N_Z=n_z),
                './gen_models/%s_gradient_ae_%d_generator.model' % (task, n_z))
        if pgen_type == 'VAE9408':
            p_gen = VAEGenerator(n_channels=3, gpu=args.use_gpu)
            # NOTE(review): keyed on args.TASK rather than `task`; kept as-is.
            if args.TASK == 'celeba' and args.smooth:
                path = ('./gen_models/%s_gradient_vae_%d_smooth%s_generator.model'
                        % (task, 9408, args.smooth_suffix))
            else:
                path = ('./gen_models/%s_gradient_vae_%d_generator.model'
                        % (task, 9408))
            return _load(p_gen, path)
        if pgen_type in ('GAN128', 'GAN9408'):
            n_z = int(pgen_type[3:])
            return _load(
                GANGenerator(n_z=n_z, n_channels=3, gpu=args.use_gpu),
                './gen_models/%s_gradient_gan_%d_generator.model' % (task, n_z))
        if pgen_type == 'oldAE9408':
            return _load(
                OldAEGenerator(n_channels=3, gpu=args.use_gpu),
                './gen_models/%s_gradient_old_ae_%d_generator.model'
                % (task, 9408))
        if pgen_type == 'expcos9408':
            return _load(
                ExpCosGenerator(n_channels=3, gpu=args.use_gpu, N_Z=9408,
                                lmbd=args.lmbd),
                './gen_models/%s_gradient_expcos_%d_generator.model'
                % (task, 9408))
        raise _unsupported(task, pgen_type, "Not implemented")

    if task == 'cifar10':
        model_file_name = model_settings.get_model_file_name("cifar10", args)
        n_channels = 3
        if args.cifar10_img_size == 32:
            if pgen_type == 'naive':
                return None
            if pgen_type == 'resize192':
                return ResizeGenerator(factor=4.0)
            if pgen_type == 'DCT192':
                return DCTGenerator(factor=4.0)
            if pgen_type == 'AE192':
                return _load(
                    Cifar10AEGenerator(gpu=args.use_gpu),
                    './gen_models/%s_gradient_ae_%d_generator.model'
                    % (task, 192))
        elif args.cifar10_img_size == 224:
            if pgen_type == 'naive':
                return None
            if pgen_type == 'resize9408':
                return ResizeGenerator(factor=4.0)
            if pgen_type == 'DCT9408':
                return DCTGenerator(factor=4.0)
            if pgen_type == 'PCA9408':
                return _imagenet_pca9408()
            if pgen_type.startswith('AE'):
                n_z = int(pgen_type[2:])
                return _load(
                    AEGenerator(n_channels=3, preprocess=None,
                                gpu=args.use_gpu, N_Z=n_z),
                    './gen_models/%s_gradient_ae_%d_generator.model'
                    % (model_file_name, n_z))
            if pgen_type.startswith('GAN'):
                n_z = int(pgen_type[3:])
                return _load(
                    GANGenerator(n_z=n_z, n_channels=n_channels,
                                 gpu=args.use_gpu),
                    './gen_models/%s_gradient_gan_%d_generator.model'
                    % (model_file_name, n_z))
            if pgen_type == 'DCGAN':
                return _load(Cifar10DCGenerator(ngpu=1).eval(),
                             './models/weights/cifar10_netG_epoch_199.pth',
                             strict=False)
            if pgen_type.startswith('DCGAN_finetune'):
                # The epoch is the trailing digits, so 'DCGAN_finetune12'
                # gives 12 (the old `int(pgen_type[-1])` would have given 2).
                epoch = re.search(r'\d+$', pgen_type)
                if epoch is None:
                    raise _unsupported('cifar10', args.cifar10_img_size,
                                       pgen_type, "Not implemented")
                return _load(Cifar10DCGenerator(ngpu=1).eval(),
                             './gen_models/cifar10_224_netG_100_epoch%d.pth'
                             % int(epoch.group()),
                             strict=False)
            if pgen_type.startswith('VAE'):
                n_z = int(pgen_type[3:])
                if n_z != 9408:
                    raise AssertionError(
                        'cifar10 VAE generator supports only N_Z=9408, got %d'
                        % n_z)
                p_gen = VAEGenerator(n_channels=n_channels, gpu=args.use_gpu)
                if args.smooth:
                    path = ('./gen_models/%s_gradient_vae_%d_smooth%s_generator.model'
                            % (model_file_name, n_z, args.smooth_suffix))
                else:
                    path = ('./gen_models/%s_gradient_vae_%d_generator.model'
                            % (model_file_name, n_z))
                return _load(p_gen, path)
        raise _unsupported('cifar10', args.cifar10_img_size, pgen_type,
                           "Not implemented")

    if task == 'mnist':
        model_file_name = model_settings.get_model_file_name("mnist", args)
        n_channels = 1
        if args.mnist_img_size == 28:
            if pgen_type == 'naive':
                return None
            if pgen_type == 'resize':
                return ResizeGenerator(factor=4.0)
            if pgen_type == 'DCT':
                return DCTGenerator(factor=4.0)
            if pgen_type.startswith('AE'):
                n_z = int(pgen_type[2:])
                return _load(
                    MNISTAEGenerator(gpu=args.use_gpu, N_Z=n_z),
                    './gen_models/%s_%d_gradient_ae_%d_generator.model'
                    % (task, args.mnist_img_size, n_z))
            if pgen_type.startswith(('GAN', 'VAE')):
                raise _unsupported("Not implemented yet")
        elif args.mnist_img_size == 224:
            if pgen_type == 'naive':
                return None
            if pgen_type == 'resize9408':
                return ResizeGenerator(factor=4.0)
            if pgen_type == 'DCT9408':
                return DCTGenerator(factor=4.0)
            if pgen_type == 'PCA9408':
                return _imagenet_pca9408()
            if pgen_type.startswith('AE'):
                if not pgen_type.endswith('9408'):
                    raise _unsupported("Not implemented yet")
                return _load(
                    AEGenerator(n_channels=n_channels, preprocess=None,
                                gpu=args.use_gpu, N_Z=9408),
                    './gen_models/%s_gradient_ae_%d_generator.model'
                    % (model_file_name, int(pgen_type[2:])))
            if pgen_type.startswith('GAN'):
                n_z = int(pgen_type[3:])
                if n_z != 9408:
                    raise AssertionError(
                        'mnist GAN generator supports only N_Z=9408, got %d'
                        % n_z)
                return _load(
                    GANGenerator(n_z=n_z, n_channels=n_channels,
                                 gpu=args.use_gpu),
                    './gen_models/%s_gradient_gan_%d_generator.model'
                    % (model_file_name, n_z))
            if pgen_type.startswith('DCGAN'):
                return _load(MNISTDCGenerator(ngpu=1).eval(),
                             './models/weights/MNIST_netG_epoch_99.pth',
                             strict=False)
            if pgen_type.startswith('VAE'):
                n_z = int(pgen_type[3:])
                if n_z == 9408:
                    p_gen = VAEGenerator(n_channels=n_channels,
                                         gpu=args.use_gpu)
                    if args.smooth:
                        path = ('./gen_models/%s_gradient_vae_%d_smooth%s_generator.model'
                                % (model_file_name, 9408, args.smooth_suffix))
                    else:
                        path = ('./gen_models/%s_gradient_vae_%d_generator.model'
                                % (model_file_name, 9408))
                    return _load(p_gen, path)
                if n_z == 3136:
                    return _load(
                        MNIST224VAEGenerator(n_channels=n_channels,
                                             gpu=args.use_gpu),
                        './gen_models/%s_gradient_vae_%d_generator.model'
                        % (model_file_name, 3136))
        raise _unsupported('mnist', args.mnist_img_size, pgen_type,
                           "Not implemented")

    if task == 'dogcat2':
        if pgen_type == 'naive':
            return None
        if pgen_type == 'resize9408':
            return ResizeGenerator(factor=4.0)
        if pgen_type == 'DCT9408':
            return DCTGenerator(factor=4.0)
        if pgen_type == 'AE9408':
            return _load(
                AEGenerator(n_channels=3, gpu=args.use_gpu, N_Z=9408),
                './gen_models/%s_gradient_ae_%d_generator.model' % (task, 9408))
        if pgen_type == 'GAN9408':
            return _load(
                GANGenerator(n_z=9408, n_channels=3, gpu=args.use_gpu),
                './gen_models/%s_gradient_gan_%d_generator.model' % (task, 9408))
        if pgen_type == 'VAE9408':
            return _load(
                VAEGenerator(n_channels=3, gpu=args.use_gpu),
                './gen_models/%s_gradient_vae_%d_generator.model' % (task, 9408))
        raise _unsupported(task, pgen_type, "Not implemented")

    raise _unsupported(task, pgen_type, "Not implemented")
# args.padding_size # 0: no padding # for simplicity, if input size [A*A], output size [B*B], then padding_size = (B-A)/2 parser.add_argument('--mnist_padding_size', type=int, default=0) parser.add_argument('--mnist_padding_first', action='store_true') parser.add_argument('--mounted', action='store_true') parser.add_argument('--N_Z', type=int, default=9408) args = parser.parse_args() mnist_img_size = args.mnist_img_size n_class = 10 TASK = 'mnist' TASK = model_settings.get_model_file_name(TASK, args) transform = model_settings.get_data_transformation("mnist", args) model_file_name = model_settings.get_model_file_name("mnist", args) print(model_file_name) trainset = torchvision.datasets.MNIST(root='../raw_data/', train=True, download=True, transform=transform) testset = torchvision.datasets.MNIST(root='../raw_data/', train=False, download=True, transform=transform) trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True) testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=True) if TASK.startswith('mnist'): n_channels = 1 else: n_channels = 3