def get_BigGAN(version="biggan-deep-256", cache_path="/scratch/binxu/torch/"):
    """Load a BigGAN generator from locally cached config + weight files.

    Args:
        version: Model shortcut name; used to derive the cached file names
            "<version>-config.json" and "<version>-pytorch_model.bin".
        cache_path: Directory holding the cached files.  (Generalized from a
            hard-coded cluster path; the default preserves old behavior.)

    Returns:
        BigGAN: the model with pretrained weights loaded (on CPU).
    """
    cfg = BigGANConfig.from_json_file(join(cache_path, "%s-config.json" % version))
    BGAN = BigGAN(cfg)
    BGAN.load_state_dict(torch.load(join(cache_path, "%s-pytorch_model.bin" % version)))
    return BGAN
def generate_model_file(model_type, device):
    """Trace a pretrained BigGAN with TorchScript and save it as an asset.

    Args:
        model_type: BigGAN shortcut name passed to `BigGAN.from_pretrained`.
        device: device string ("cpu"/"cuda"); also becomes part of the saved
            file name under <assets>/networks/.
    """
    model: nn.Module = BigGAN.from_pretrained(model_type)
    model.to(device)
    model.eval()
    truncation = torch.tensor([1.0]).to(device)
    # Freeze all weights -- inference only.  Bug fix: the original set
    # `p.require_grads`, a nonexistent attribute, which silently did nothing;
    # the real flag is `requires_grad`.
    for p in model.parameters():
        p.requires_grad = False
    # Remove the spectral norm from all layers, we can do this because we only do inference
    for module in model.modules():
        try:
            torch.nn.utils.remove_spectral_norm(module)
        except (AttributeError, ValueError):
            pass
    # Do a JIT precompute for additional speedup
    model = torch.jit.trace(
        model,
        (
            create_noise_vector(device=device),
            create_class_vector(643, device=device),
            truncation,
        ),
    )
    torch.jit.save(
        model,
        (get_asset_folder() / "networks" / (model_type + "-" + device)).__str__(),
    )
def __init__(self):
    """Wrap a pretrained BigGAN-deep-512 and record per-layer input shapes.

    The shape table lists where the "n_cuts" mechanism may inject a latent:
    one entry per generator stage, as ((feature-map shape), (cond shape)).
    """
    super().__init__()
    self.biggan = BigGAN.from_pretrained('biggan-deep-512')
    # Output resolution of the biggan-deep-512 generator.
    self.image_size = 512
    # Monkey patch forward methods for n_cuts
    # NOTE(review): patching the *class* attributes affects every BigGAN /
    # Generator instance in this process, not just this wrapper instance.
    BigGAN.forward = patch_biggan_forward
    Generator.forward = patch_generator_forward
    # NOTE - because each resblock reduces channels and
    # then increases, we cannot skip into the middle.
    # If we did, we would have no way to add channels
    # to the skip connection ("x0")
    self.input_shapes = [
        ((128, ), (128, )),  # Raw input shape
        ((2048, 4, 4), (256, )),  # Linear
        ((2048, 4, 4), (256, )),  # Block
        ((2048, 8, 8), (256, )),  # Block Up
        ((2048, 8, 8), (256, )),  # Block
        ((1024, 16, 16), (256, )),  # Block Up
        ((1024, 16, 16), (256, )),  # Block
        ((1024, 32, 32), (256, )),  # Block Up
        ((1024, 32, 32), (256, )),  # Block
        ((512, 64, 64), (256, )),  # Block Up
        ((512, 64, 64), (256, )),  # Self-Attention block
        ((512, 64, 64), (256, )),  # Block
        ((256, 128, 128), (256, )),  # Block Up
        ((256, 128, 128), (256, )),  # Block
        ((128, 256, 256), (256, )),  # Block Up
        ((128, 256, 256), (256, )),  # Block
        ((128, 512, 512), (256, )),  # Block Up
        ((3, 512, 512), ()),  # Final Conv
    ]
def load_model(resolution: int = 512):
    """Load a pretrained BigGAN model from the local cache directory.

    NOTE(review): `model_name` (built from `resolution`) is computed but never
    used -- `from_pretrained` receives only the cache *directory*, so the
    `resolution` argument currently has no effect.  Confirm whether
    `pretrained_path / model_name` (or `from_pretrained(model_name)`) was
    intended before relying on other resolutions.
    """
    print('[info] loading pre-trained model...')
    model_name = f'biggan-deep-{resolution}'
    # Default download cache of the pytorch_pretrained_biggan package.
    pretrained_path = Path.home().joinpath('.pytorch_pretrained_biggan')
    model = BigGAN.from_pretrained(pretrained_path)
    print('[info] loading complete.\n')
    return model
def __init__(self, image_sentiment_model=join(config['basedir'], "models/image_sentiment.model"), noise=0.3, in_batch=4):
    """Set up the GAN generator, the image-sentiment scorer and preprocessing.

    Args:
        image_sentiment_model: path to the saved sentiment-model weights.
        noise: noise magnitude used when exploring the GAN latent space.
        in_batch: images generated per step; 1000 total steps are split into
            batches of this size.
    """
    device = config['device']

    # BigGAN generator (512 px, pretrained).
    self.gan_model = BigGAN.from_pretrained('biggan-deep-512')
    self.gan_model.to(device)

    # MobileNet-based image sentiment scorer with saved weights.
    self.sentiment_model = get_pretrained_mobile_net()
    self.sentiment_model.to(device)
    self.sentiment_model.load_state_dict(torch.load(image_sentiment_model))

    self.in_batch = in_batch
    self.noise = noise
    self.n_iters = int(1000 / in_batch)

    # Standard ImageNet preprocessing for the sentiment model.
    self.transform_image = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
def main():
    """Load the input image and pretrained ResNet50/BigGAN, then train.

    NOTE(review): `target` is not defined in this function (presumably a
    module-level global) -- confirm before running.
    """
    # Read and preprocess the input image into a (1, C, H, W) batch.
    img_dir = args.train_url + args.img_dir
    raw_img = Image.open(img_dir)
    input_tensor = trans(raw_img)  # renamed from `input` (shadowed builtin)
    input_tensor = input_tensor.view(1, input_tensor.size(0), input_tensor.size(1), input_tensor.size(2))
    print('input:', input_tensor.size())

    print('=========> Load models from train_url')
    checkpoint_dir = os.path.join(args.train_url, 'resnet50-19c8e357.pth')
    print('checkpoint_dir:', checkpoint_dir)
    checkpoint = torch.load(checkpoint_dir)
    model = networks_imagenet.resnet.resnet50()
    model.load_state_dict(checkpoint)
    netG = BigGAN.from_pretrained('biggan-deep-512', model_dir=args.train_url)
    print('=========> Load models ended')

    # (Removed dead code: an unused '''if args.cloud ...''' string block, a
    # commented-out feature-extraction call, and unused acc/loss lists.)
    train(input_tensor, target, netG, model)
def build_models(self):
    """Load the pretrained BigGAN and VGG16 and build the z-approximator.

    Side effects: sets `self.biggan`, `self.vgg` and `self.z_approximator`.
    """
    print('Loading BigGAN...')
    self.biggan = BigGAN.from_pretrained(self.args.model).eval().to(self.args.device)
    print('Loading vgg...')
    self.vgg = torchvision.models.vgg.vgg16(pretrained=True).eval().to(self.args.device)
    print('Building other models...')
    # (Removed unused local `bgc = self.biggan.config`.)
    self.z_approximator = ZApproximator(self.biggan, vgg=self.vgg)
def model_resolution(resolution=128):
    """Load a BigGAN-deep generator at the requested resolution.

    Args:
        resolution: 128, 256 or 512, as int or str.  Default 128, matching the
            original docstring (the old code had no default and raised
            TypeError for int inputs because it concatenated `resolution`
            directly onto a str).

    Returns:
        The pretrained BigGAN model.  Lower resolution = faster generation,
        lower quality.
    """
    model_name = 'biggan-deep-' + str(resolution)
    model = BigGAN.from_pretrained(model_name)
    return model
def set_model(self, resolution):
    """Load BigGAN-deep at `resolution`, wrap it in DataParallel when several
    GPUs are configured, and move it to the selected device."""
    args = self.args
    self.big_gan = BigGAN.from_pretrained('biggan-deep-' + str(resolution))
    self.device, self.multi_gpu = get_device(args)
    if self.multi_gpu:
        gpu_ids = list(range(args.gpu_1st, args.gpu_1st + args.ngpu))
        self.big_gan = nn.DataParallel(
            self.big_gan, device_ids=gpu_ids, output_device=args.gpu_1st)
    self.big_gan.to(self.device)
def gan_network(self, resolution):
    """Build the pretrained BigGAN generator used by this application.

    Args:
        resolution (int): The resolution of BIGGAN to load

    Returns:
        BigGAN: A BigGAN model
    """
    logger.info(f'Loading BigGAN with resolution {resolution}.')
    return BigGAN.from_pretrained(f'biggan-deep-{resolution}')
def loadBigGAN(version="biggan-deep-256"):
    """Load BigGAN: from the local cluster cache on linux, otherwise via the
    normal pretrained download.  Weights are frozen and moved to CUDA."""
    from pytorch_pretrained_biggan import BigGAN, truncated_noise_sample, BigGANConfig
    if platform == "linux":
        cache_dir = "/scratch/binxu/torch/"
        config = BigGANConfig.from_json_file(
            join(cache_dir, "%s-config.json" % version))
        BGAN = BigGAN(config)
        state = torch.load(join(cache_dir, "%s-pytorch_model.bin" % version))
        BGAN.load_state_dict(state)
    else:
        BGAN = BigGAN.from_pretrained(version)
    # Inference only: freeze every parameter.
    for param in BGAN.parameters():
        param.requires_grad_(False)
    BGAN.cuda()
    return BGAN
def generate_image(dense_class_vector=None, name=None, noise_seed_vector=None,
                   truncation=0.4, gan_model=None,
                   pretrained_gan_model_name='biggan-deep-128'):
    """Generate one image (numpy uint8 array, HxWx3) with BigGAN.

    Class conditioning comes from either:
      - `name`: an ImageNet class name, mapped through BigGAN's internal
        class embeddings (takes precedence when given), or
      - `dense_class_vector`: a 128-element torch.Tensor used directly in
        place of the internal embedding.

    Args:
        dense_class_vector: optional 128-d class embedding tensor.
        name: optional ImageNet class name (string).
        noise_seed_vector: tensor whose element sum seeds the noise sampler
            (None -> unseeded).
        truncation: float in (0, 1], quality/diversity tradeoff (BigGAN paper).
        gan_model: optional preloaded BigGAN from pytorch_pretrained_biggan;
            if None one is instantiated from `pretrained_gan_model_name`.
            List of possible names:
            https://github.com/huggingface/pytorch-pretrained-BigGAN#models
        pretrained_gan_model_name: shortcut name used when gan_model is None.
            Default 'biggan-deep-128'.
    """
    seed = None
    if noise_seed_vector is not None:
        seed = int(noise_seed_vector.sum().item())
    noise_vector = torch.from_numpy(
        truncated_noise_sample(truncation=truncation, batch_size=1, seed=seed))

    if gan_model is None:
        gan_model = BigGAN.from_pretrained(pretrained_gan_model_name)

    if name is not None:
        one_hot = torch.from_numpy(one_hot_from_names([name], batch_size=1))
        dense_class_vector = gan_model.embeddings(one_hot)
    else:
        dense_class_vector = dense_class_vector.view(1, 128)
    input_vector = torch.cat([noise_vector, dense_class_vector], dim=1)

    # Generate, then map [-1, 1] output to uint8 pixels.
    with torch.no_grad():
        output = gan_model.generator(input_vector, truncation)
    raw = output.cpu().numpy().transpose((0, 2, 3, 1))
    raw = ((raw + 1.0) / 2.0) * 256
    raw.clip(0, 255, out=raw)
    return np.asarray(np.uint8(raw[0]), dtype=np.uint8)
def main(): acc_epoch = [] loss_epoch = [] ''' if args.cloud: isExists = os.path.exists(basic_path) if not isExists: os.makedirs(basic_path) ''' img_dir = args.train_url + args.img_dir raw_img = Image.open(img_dir) input = trans(raw_img) # input = input.view(1, input.size(0), input.size(1), input.size(2)) input = torch.unsqueeze(input, 0) print('input:', input.size()) print('=========> Load models from train_url') checkpoint_dir = os.path.join(args.train_url, 'resnet50-19c8e357.pth') print('checkpoint_dir:', checkpoint_dir) checkpoint = torch.load(checkpoint_dir) model = networks_imagenet.resnet.resnet50() model.load_state_dict(checkpoint) netG = BigGAN.from_pretrained('biggan-deep-512', model_dir=args.train_url) def make_layers(cfg, batch_norm=False): layers = [] in_channels = 3 for v in cfg: if v == 'M': layers += [nn.MaxPool2d(kernel_size=2, stride=2)] else: conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) if batch_norm: layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] else: layers += [conv2d, nn.ReLU(inplace=True)] in_channels = v return nn.Sequential(*layers) vgg_config = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'] vgg16 = VGG(make_layers(vgg_config, batch_norm=False)) vgg_state_dict = load_state_dict_from_url('https://download.pytorch.org/models/vgg16-397923af.pth', progress=True) vgg16.load_state_dict(vgg_state_dict) print('=========> Load models ended') train(input, target, netG, model, vgg16)
def build_models(self):
    """Load BigGAN, word2vec and VGG16, then build the z-approximator and the
    semantic z-encoder.  Sets the corresponding attributes on self."""
    device = self.args.device

    print('Loading BigGAN...')
    self.biggan = BigGAN.from_pretrained(self.args.model).eval().to(device)

    print('Loading w2v...')
    self.w2v = gensim.models.KeyedVectors.load_word2vec_format(
        self.args.w2v_model_path, binary=True, limit=self.args.vocab_limit)

    print('Loading vgg...')
    self.vgg = torchvision.models.vgg.vgg16(pretrained=True).eval().to(device)

    print('Building other models...')
    bgc = self.biggan.config
    # NOTE: intentionally not moved to the device (matches original behavior).
    self.z_approximator = ZApproximator(self.biggan, vgg=self.vgg)
    self.semantic_z_encoder = SemanticZEncoder(
        self.args.semantic_dims, bgc.z_dim, bgc.num_classes).to(device)
    print(self.z_approximator)
    print(self.semantic_z_encoder)
def get_model(args):
    """Return a pretrained generator selected by `args.model`.

    Recognizes 'BigGAN' (resolution parsed from the digits in the name),
    'WGAN-GP' and 'BEGAN'; returns None for anything else.
    """
    if 'BigGAN' in args.model:
        resolution = int(''.join(filter(str.isdigit, args.model)))
        return BigGAN.from_pretrained(f'biggan-deep-{resolution}')
    if 'WGAN-GP' in args.model:
        generator_path = "/deep/group/gen-eval/model-training/src/GAN_models/improved-wgan-pytorch/experiments/exp4_wgan_gp/generator.pt"
        return torch.load(generator_path)
    if 'BEGAN' in args.model:
        generator_path = "/deep/group/gen-eval/model-training/src/GAN_models/BEGAN-pytorch/trained_models/64/models/gen_97000.pth"
        began = models.BEGANGenerator()
        began.load_state_dict(torch.load(generator_path))
        return began
    return None
def main():
    """Load an image plus pretrained ResNet50/BigGAN from train_url and train.

    NOTE(review): `target` is not defined in this function (presumably a
    module-level global) -- confirm before running.
    """
    # Read and preprocess the input image into a (1, C, H, W) batch.
    img_dir = args.train_url + args.img_dir
    raw_img = Image.open(img_dir)
    input_tensor = trans(raw_img)  # renamed from `input` (shadowed builtin)
    input_tensor = input_tensor.view(1, input_tensor.size(0), input_tensor.size(1), input_tensor.size(2))

    print('=========> Load models from train_url')
    checkpoint_dir = os.path.join(args.train_url, 'resnet50-19c8e357.pth')
    print('checkpoint_dir:', checkpoint_dir)
    checkpoint = torch.load(checkpoint_dir)
    model = networks_imagenet.resnet.resnet50()
    model.load_state_dict(checkpoint)
    netG = BigGAN.from_pretrained('biggan-deep-512', model_dir=args.train_url)

    # (Removed dead commented-out code and unused acc/loss lists.)
    train(input_tensor, target, netG, model)
def evaluate_inversion(args, inverted_net_path):
    """Load a trained inversion network, invert a test image into noise, and
    regenerate it with BigGAN for comparison.

    Args:
        args: namespace with gpu_ids, model, device, truncation.
        inverted_net_path: checkpoint path for the trained inverter.

    NOTE(review): `batch_size` is not defined in this function (presumably a
    module-level global) -- confirm before running.
    """
    # Load saved inverted net
    device = 'cuda:{}'.format(args.gpu_ids[0]) if len(args.gpu_ids) > 0 else 'cpu'
    ckpt_dict = torch.load(inverted_net_path, map_location=device)

    # Build model, load parameters
    model_args = ckpt_dict['model_args']
    inverted_net = models.ResNet18(**model_args)
    inverted_net = nn.DataParallel(inverted_net, args.gpu_ids)
    inverted_net.load_state_dict(ckpt_dict['model_state'])

    # (Removed leftover `import pdb; pdb.set_trace()` debug breakpoint.)

    # Get test images (CelebA)
    initial_generated_image_dir = '/deep/group/sharonz/generator/z_test_images/'
    initial_generated_image_name = '058004_crop.jpg'
    initial_generated_image = util.get_image(initial_generated_image_dir, initial_generated_image_name)
    initial_generated_image = initial_generated_image / 255.
    # Bug fix: the original assigned the .cuda() result to a misspelled name
    # (`intiial_generated_image`), so the tensor was never actually moved.
    initial_generated_image = initial_generated_image.cuda()

    inverted_noise = inverted_net(initial_generated_image)

    if 'BigGAN' in args.model:
        class_vector = one_hot_from_int(207, batch_size=batch_size)
        class_vector = torch.from_numpy(class_vector)

        num_params = int(''.join(filter(str.isdigit, args.model)))
        generator = BigGAN.from_pretrained(f'biggan-deep-{num_params}')
        generator = generator.to(args.device)

        generated_image = generator.forward(inverted_noise, class_vector, args.truncation)

    # Get difference btw initial and subsequent generated image
    # Save both
    return
def generate_data(random_state, batch_size, num_images_per_classes, device, output_path):
    """Dump `num_images_per_classes` BigGAN-128 samples for each of the 1000
    ImageNet classes, saved as JPEGs under "<output_path>/<class_idx>/<id>"."""
    # Pretrained 128 px generator, cached locally.
    generator_model = BigGAN.from_pretrained(
        "biggan-deep-128",
        cache_dir=os.path.join("./data/checkpoint", "cached_model")).to(device)

    truncation = 0.4
    op_paths.build_dirs(f"{output_path}")
    batches_per_class = int(num_images_per_classes / batch_size)

    for class_idx in range(1000):
        image_id = 0
        op_paths.build_dirs(f"{output_path}/{class_idx}")
        for _ in range(batches_per_class):
            # Conditioning and noise for one batch of this class.
            one_hot = one_hot_from_int(class_idx, batch_size=batch_size)
            noise = truncated_noise_sample(truncation=truncation, batch_size=batch_size)
            noise = torch.from_numpy(noise).to(device)
            one_hot = torch.from_numpy(one_hot).to(device)

            # Sample and clamp to the generator's [-1, 1] output range.
            with torch.no_grad():
                generated_images = generator_model(noise, one_hot, truncation).clamp(min=-1, max=1)

            for image in generated_images:
                torchvision.utils.save_image(
                    image,
                    fp=f"{output_path}/{class_idx}/{image_id}",
                    format="JPEG",
                    scale_each=True,
                    normalize=True,
                )
                image_id += 1
        print(f"finished {class_idx + 1}/1000.")
def main():
    """Assemble PreGAN + BigGAN into a BothGAN, retrain it on fragment data,
    then generate vases indefinitely."""
    pregan = PreGAN()
    pregan.init_weights()
    biggan = BigGAN.from_pretrained('biggan-deep-512')

    # BothGAN owns device placement of both sub-networks.
    vaseGen = BothGAN(pregan, biggan, lr=1e-5)
    vaseGen.to('cuda')

    data_gen = FragmentDataset()
    n_samples = 100
    batch_size = 1
    retrain(vaseGen, data_gen, n_samples, batch_size)

    # Endless generation loop.
    while True:
        vase_generate(vaseGen, data_gen)
def main():
    """Sample BigGAN images for one or more (optionally mixed) classes and
    dump them, plus tiled preview sheets, to disk."""
    args = parser.parse_args()
    import logging
    logging.basicConfig(level=logging.INFO)

    global model
    model = BigGAN.from_pretrained(args.model_dir).to('cuda')

    # Class conditioning: one one-hot vector per requested label.
    label_str = args.labels.strip()
    labels = [l.replace('_', ' ') for l in label_str.split(',') if len(l) > 0]
    class_base_vecs = one_hot_from_names(labels)

    # Optional second label set, mixed in with weight (1 - mixture_prop).
    label_alt_str = args.labels_alt.strip()
    labels_alt = [l.replace('_', ' ') for l in label_alt_str.split(',') if len(l) > 0]
    print(labels, labels_alt)
    if len(labels_alt) > 0:
        assert len(labels_alt) == len(labels)
        c1 = one_hot_from_names(labels_alt)
        class_base_vecs = args.mixture_prop * class_base_vecs + (1 - args.mixture_prop) * c1

    # Bug fix: the original reused the name `labels` for this accumulator,
    # clobbering the label-name list above; distinct names are used instead.
    sample_batches = []
    class_id_batches = []
    for _ in trange(0, args.n_samples // args.batch_size):
        # Pick a random base class for each sample in the batch.
        cls = np.random.randint(0, class_base_vecs.shape[0], size=(args.batch_size,))
        class_vector = class_base_vecs[cls]
        noise_vector = truncated_noise_sample(
            truncation=args.truncation, batch_size=args.batch_size)
        sample_batches.append(gen_image(
            noise_vector, class_vector, args.crop_ratio, args.truncation))
        class_id_batches.append(cls)

    outs = np.concatenate(sample_batches)
    class_ids = np.concatenate(class_id_batches)
    # The .npz keys are unchanged ('samples', 'labels').
    np.savez(args.dump_dir + '.npz', args=vars(args), samples=outs, labels=class_ids)
    Image.fromarray(tile_images(outs[:81])).save(args.dump_dir + '.samples-81.png')
    Image.fromarray(tile_images(outs[:16])).save(args.dump_dir + '.samples-16.png')
def __init__(self, sound_model, translator_model, stylizer_model=None, logo_path=join(config['basedir'], 'resources/deep_sing.png')):
    """Wire together the sound -> sentiment -> GAN-space pipeline.

    Args:
        sound_model: model translating sound into sentiment.
        translator_model: model mapping sentiment into the GAN latent space.
        stylizer_model: optional model for stylizing the generated images.
        logo_path: path of the logo overlay image, or None to disable it.
    """
    self.sound_model = sound_model
    self.translator_model = translator_model
    self.stylizer = stylizer_model

    # Pretrained 512 px BigGAN generator on the configured device.
    self.gan_model = BigGAN.from_pretrained('biggan-deep-512')
    self.gan_model.to(config['device'])

    # Logo overlay is optional.
    self.logo = cv2.imread(logo_path) if logo_path is not None else None

    print("deepsing hyper-space synapses loaded!")
# NOTE(review): this excerpt starts mid-`if`; the condition selecting the
# scaled smooth_factor branch lies outside the visible chunk.
    # Scale the smoothing window to the frame length (512-sample reference).
    smooth_factor=int(args.smooth_factor * 512 / frame_length)
else:
    smooth_factor=args.smooth_factor

#set duration
# Number of frame batches to render: from an explicit --duration in seconds,
# otherwise from the full length of the loaded audio `y` (resampled to 22050 Hz).
if args.duration:
    seconds=args.duration
    frame_lim=int(np.floor(seconds*22050/frame_length/batch_size))
else:
    frame_lim=int(np.floor(len(y)/sr*22050/frame_length/batch_size))

# Load pre-trained model
model = BigGAN.from_pretrained(model_name)

#set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

########################################
########################################
########################################
########################################
########################################

#create spectrogram
# Load the language models (GPT-2 and XLNet) and the BigGAN image generator.
tokenizerG = GPT2Tokenizer.from_pretrained("gpt2")
modelG = GPT2LMHeadModel.from_pretrained("gpt2")
modelG.to("cuda")
print("Done")
print("XLNet Time")
tokenizerX = XLNetTokenizer.from_pretrained("xlnet-base-cased")
modelX = XLNetLMHeadModel.from_pretrained("xlnet-base-cased")
print("BigGan Time!")
from pytorch_pretrained_biggan import (
    BigGAN,
    one_hot_from_names,
    truncated_noise_sample,
    convert_to_images,
)
modelBG = BigGAN.from_pretrained("biggan-deep-256")
# NOTE(review): modelBG is never moved to CUDA here while modelX is -- the
# adjacency suggests `modelBG.to("cuda")` may have been intended; confirm.
modelX.to("cuda")
print("All prep complete!")
# ImageNet class-id -> label-name mapping fetched over HTTP.
labels = {
    int(key): value
    for (key, value) in requests.get(
        "https://s3.amazonaws.com/outcome-blog/imagenet/labels.json"
    )
    .json()
    .items()
}
# COCO detection labels; the expression continues past this excerpt.
detect_labels = {
    int(key): value
    for (key, value) in requests.get(
        "https://gist.githubusercontent.com/RehanSD/6f74a9992848e25658e091148ee20e17/raw/fae1f9f3ee0c3eb20ca9829e99cd8b616f22fa45/cocolabels.json"
# Demo: generate three images from ImageNet class names with BigGAN-deep-512.
import torch
from pytorch_pretrained_biggan import (BigGAN, one_hot_from_names, truncated_noise_sample,
                                       save_as_images, display_in_terminal)

# OPTIONAL: if you want to have more information on what's happening, activate the logger as follows
import logging
logging.basicConfig(level=logging.INFO)

# Load the pre-trained BigGAN generator.
model = BigGAN.from_pretrained('biggan-deep-512')

# Prepare an input: one one-hot class vector and one truncated-noise vector
# per requested image (batch of 3).
truncation = 0.4  # quality/diversity tradeoff in (0, 1]
class_vector = one_hot_from_names(['soap bubble', 'coffee', 'mushroom'], batch_size=3)
noise_vector = truncated_noise_sample(truncation=truncation, batch_size=3)

# All in tensors
noise_vector = torch.from_numpy(noise_vector)
class_vector = torch.from_numpy(class_vector)

# # If you have a GPU, put everything on cuda
# noise_vector = noise_vector.to('cuda')
# class_vector = class_vector.to('cuda')
# model.to('cuda')

# Generate an image
with torch.no_grad():
    output = model(noise_vector, class_vector, truncation)

# If you have a GPU put back on CPU
# output = output.to('cpu')
# Demo: generate one 'lakeshore' image from a locally stored BigGAN checkpoint.
import os, os.path
import cv2
import pickle
import torch
from pytorch_pretrained_biggan import (BigGAN, one_hot_from_names, truncated_noise_sample,
                                       save_as_images, display_in_terminal)

# OPTIONAL: if you want to have more information on what's happening, activate the logger as follows
import logging
from PIL import Image
logging.basicConfig(level=logging.INFO)

# Load the pre-trained BigGAN generator from a local checkpoint directory.
# NOTE(review): './BigGAN/model' is a relative path -- the script must be run
# from the directory that contains it.
model = BigGAN.from_pretrained('./BigGAN/model')

# Prepare an input: one class vector and one truncated-noise vector.
truncation = 0.4
class_vector = one_hot_from_names(['lakeshore'], batch_size=1)
noise_vector = truncated_noise_sample(truncation=truncation, batch_size=1)

# All in tensors
noise_vector = torch.from_numpy(noise_vector)
class_vector = torch.from_numpy(class_vector)

# If you have a GPU, put everything on cuda
# noise_vector = noise_vector.to('cuda')
# class_vector = class_vector.to('cuda')
# model.to('cuda')
# CLI: optimize an image through a generator so its AlexNet features match
# each input image's features; results are written to --save_folder.
parser.add_argument('--save_folder', required=True, type=str, help='folder to save generated images')
parser.add_argument('--image_folder', required=True, type=str, help='images to generate for')
parser.add_argument('--model', default='deepsim', type=str, choices=['deepsim', 'biggan'], help='which generator model to use for optimizing images')
args = parser.parse_args()

# Start from a clean output folder on every run.
shutil.rmtree(args.save_folder, ignore_errors=True)
os.mkdir(args.save_folder)

# AlexNet with its final classification layer removed -> feature encoder.
encoder = alexnet(pretrained=True)
encoder.classifier = encoder.classifier[:-1]
encoder.eval()

# Generator choice from the CLI.
if args.model == 'deepsim':
    generator = DeePSiM()
elif args.model == 'biggan':
    generator = BigGAN.from_pretrained('biggan-deep-256')

if torch.cuda.is_available():
    encoder.cuda()
    generator.cuda()

# For each input image: encode it, then optimize a generated image to match
# the encoding under MSE loss, and save the result under the same file name.
for image_file in tqdm(os.listdir(args.image_folder)):
    image = image_to_tensor(os.path.join(args.image_folder, image_file), resolution=256)
    with torch.no_grad():
        target = encoder(image.unsqueeze(0)).squeeze(0)
    target_mean_square = (target ** 2).mean().item()
    generated_image, _, lowest_loss, _ = optimize(generator, encoder, target, F.mse_loss)
    generated_image = to_pil_image(generated_image)
    generated_image.save(os.path.join(args.save_folder, image_file))
    print('Lowest loss for {}:\t{}\nMean square of target for {}:\t{}\n'
          .format(image_file, lowest_loss, image_file, target_mean_square))
# Model-serving prelude: preload BigGAN and StyleGAN checkpoints for a Quart app.
import os
import pickle
import numpy as np
import PIL.Image
import dnnlib
import dnnlib.tflib as tflib
import config
import random

app = Quart(__name__)
tflib.init_tf()

# Preloaded generator models keyed by short name.  The StyleGAN entries are
# (..., ..., Gs) pickles; [-1] selects the final (inference) network.
# NOTE(review): pickle.load on checkpoint files executes arbitrary code --
# only load .pkl files from trusted sources.
models = {
    "biggan-deep-512": BigGAN.from_pretrained('biggan-deep-512'),
    "biggan-deep-256": BigGAN.from_pretrained('biggan-deep-256'),
    "waifu": pickle.load(
        open("2019-04-30-stylegan-danbooru2018-portraits-02095-066083.pkl", 'rb'))[-1],
    "celeb": pickle.load(open("karras2019stylegan-celebahq-1024x1024.pkl", 'rb'))[-1]
}


def get_model(name="biggan-deep-256"):
    "Get the deep model from known models"
    return models[name]
def __init__(self, config):
    """Wrap DeepMind's pretrained BigGAN-deep-256 generator.

    Only the generator is used; no discriminator is attached.
    """
    super(DeepMindBigGAN, self).__init__()
    self.config = config
    self.D = None  # sampling only -- no adversarial training here
    self.G = DMBigGAN.from_pretrained("biggan-deep-256")
def reconstruct_image(img_size=224, test_case=1, num_epochs=200, print_iter=10, save_iter=50):
    """Optimize an image so its VGG19 features match the target features.

    Args:
        img_size: side length of the square image being optimized.
        test_case: not used inside this function.
        num_epochs: number of optimization steps.
        print_iter: log losses every this many epochs.
        save_iter: save the current image every this many epochs.

    NOTE(review): relies on module-level globals (`use_gpu`, `gpu_id`,
    `use_dgn`, `original_feats`, `examples_dir`, `cls`, `key`) -- confirm
    they are defined before calling.
    """
    # Pretrained VGG19 provides the feature space for the reconstruction loss.
    model = models.vgg19(pretrained=True)
    if use_gpu:
        model = model.cuda(gpu_id)
    model.eval()

    # Start from a small-amplitude random image optimized directly in pixel space.
    if use_gpu:
        recon_img = Variable(
            1e-1 * torch.randn(1, 3, img_size, img_size).cuda(gpu_id),
            requires_grad=True)
    else:
        recon_img = Variable(1e-1 * torch.randn(1, 3, img_size, img_size),
                             requires_grad=True)

    # SGD on the pixels; LR decays by a factor of 0.1 every 200 epochs.
    optimizer = optim.SGD([recon_img], lr=1e2, momentum=0.9)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.1)

    # Optionally replace the random start with a BigGAN sample (all-zero class
    # vector), resized to img_size.
    if use_dgn:
        print('Loading deep generator network...')
        dgn = BigGAN.from_pretrained('biggan-deep-256')
        truncation = 0.4
        class_vector = torch.from_numpy(np.zeros((1, 1000), dtype='float32'))
        noise_vector = torch.from_numpy(
            truncated_noise_sample(truncation=truncation, batch_size=1))
        if use_gpu:
            class_vector = class_vector.cuda(gpu_id)
            noise_vector = noise_vector.cuda(gpu_id)
            dgn.cuda(gpu_id)
        with torch.no_grad():
            output = dgn(noise_vector, class_vector, truncation)
        output = output.cpu()
        output = nn.functional.interpolate(
            output, size=(img_size, img_size), mode='bilinear', align_corners=True)
        if use_gpu:
            recon_img = Variable(output.cuda(gpu_id), requires_grad=True)
        else:
            recon_img = Variable(output, requires_grad=True)

    # Training
    for epoch in range(num_epochs):
        optimizer.zero_grad()
        output_features = get_features_from_layers(model, recon_img)
        euc_loss, alpha_loss, tv_loss, total_loss = get_total_loss(
            recon_img, output_features, original_feats)
        total_loss.backward()
        optimizer.step()
        # Bug fix: step the LR scheduler *after* optimizer.step(); the
        # original called scheduler.step() at the top of the loop, which skips
        # the first learning-rate value and triggers a PyTorch warning.
        scheduler.step()

        if (epoch + 1) % print_iter == 0:
            print('Epoch %d:\tAlpha: %.6f\tTV: %.6f\tEuc: %.6f\tLoss: %.6f' % (
                epoch + 1, alpha_loss.data.cpu().numpy(), tv_loss.data.cpu().numpy(),
                euc_loss.data.cpu().numpy(), total_loss.data.cpu().numpy()))
        if (epoch + 1) % save_iter == 0:
            img_sample = torch.squeeze(recon_img.cpu())
            im_path = examples_dir + cls + '_' + key + '_fmri.jpg'
            save_image(img_sample, im_path, normalize=True)
def __init__(self, opt):
    """Load BigGAN-deep-256 and look up the preprocessing string registered
    for the selected RobustBench dataset/threat/model combination."""
    super().__init__()
    self.model = BigGAN.from_pretrained('biggan-deep-256')
    dataset = BenchmarkDataset(opt.dataset_to_use)
    threat = ThreatModel(opt.use_robust_bench_threat)
    entry = all_models_rb[dataset][threat][opt.use_robust_bench_model]
    self.imagenet_preprocessing_str = entry['preprocessing']