def generate_images(self):
    """Generate images using the trained model.

    When ``opt.swap_style`` is set, loads the generator list and style
    pyramid saved by another trained model (resolved via generate_dir2save)
    and swaps those weights into ``self.Gs`` before generating. In
    ``random_samples`` mode, runs SinGAN_generate either once from
    ``opt.gen_start_scale`` or once per scale when ``opt.gen_all_scales``
    is set. Output images are written by SinGAN_generate.
    """
    opt = self.opt

    if opt.swap_style is not None:
        # Temporarily point opt at the swap style/mode so generate_dir2save
        # resolves the directory of the model we want to borrow weights from.
        # Not elegant, but avoids duplicating the path-naming scheme here.
        saved_style, saved_mode = opt.style_input, opt.mode
        opt.style_input, opt.mode = opt.swap_style, 'train'
        swapdir = generate_dir2save(opt)
        opt.style_input, opt.mode = saved_style, saved_mode

        if not os.path.exists(swapdir):
            print("Style swap failed")
            return

        swap_Gs = torch.load('%s/Gs.pth' % swapdir)
        styles = torch.load('%s/styles.pth' % swapdir)
        self.Gs = swap_style_weights(self.Gs, swap_Gs)
        self.styles = styles

    if opt.mode == 'random_samples':
        if opt.gen_all_scales:
            # One generation pass starting at every scale of the pyramid.
            for scale in range(len(self.Gs)):
                print("Generating random samples on scale %d" % scale)
                opt.gen_start_scale = scale
                SinGAN_generate(self.Gs, self.Zs, self.reals, self.styles,
                                self.NoiseAmp, opt,
                                gen_start_scale=opt.gen_start_scale)
        else:
            SinGAN_generate(self.Gs, self.Zs, self.reals, self.styles,
                            self.NoiseAmp, opt,
                            gen_start_scale=opt.gen_start_scale)
    return
def __init__(self, Generator, Discriminator, opt):
    """Store model classes/options and build the content and style pyramids.

    Creates (or finds) the TrainedModel output directory, loads a previously
    trained pyramid when not in 'train' mode, reads the content and style
    images, resizes the style image to match the content image, and builds
    the multi-scale pyramids used during training.
    """
    self.Generator = Generator
    self.Discriminator = Discriminator
    self.opt = opt

    # Per-scale state accumulated over training of the 0th layer onward.
    self.Gs = []        # generator for each scale
    self.Zs = []        # optimal noise for each scale [z*, 0, 0, ..., 0]
    self.NoiseAmp = []  # noise ratio when merging with previous layer's output
    self.in_s = 0       # zero tensor at the scale-0 (downsampled) resolution

    # TrainedModel directory
    dir2save = generate_dir2save(self.opt)
    if os.path.exists(dir2save):
        print("Would you look at that, the TrainedModel directory already exists!")
    else:
        try:
            os.makedirs(dir2save)
        except OSError:
            print("Making the directory really didn't work out, hyelp")

    # When not training, restore the previously trained pyramid; content and
    # style images may be overridden for testing.
    if self.opt.mode != 'train':
        self.Gs, self.Zs, _, _, self.NoiseAmp = load_trained_pyramid(self.opt)
        if self.opt.test_content is not None:
            self.opt.content = self.opt.test_content
        if self.opt.test_style is not None:
            self.opt.style = self.opt.test_style

    # Content and style images; force the style image to the content shape.
    self.real_ = read_image(self.opt)
    self.style_ = read_image(self.opt, style=True)
    if self.style_.shape != self.real_.shape:
        target_hw = [self.real_.shape[2], self.real_.shape[3]]
        self.style_ = imresize_to_shape(self.style_, target_hw, opt)
    # Crop any rounding overshoot (no-op when shapes already agree).
    self.style_ = self.style_[:, :, :self.real_.shape[2], :self.real_.shape[3]]

    # "adjust_scales2image" also arranges network parameters according to
    # the input dimensions.
    assert self.real_.shape == self.style_.shape
    self.real = adjust_scales2image(self.real_, self.opt)
    self.reals = create_reals_pyramid(self.real, self.opt)
    self.style = imresize(self.style_, self.opt.scale1, self.opt)
    self.styles = create_reals_pyramid(self.style, self.opt)
def train(self):
    """Train the GAN for niter epochs over stop_scale + 1 scales.

    Main training loop: calls ``train_scale`` once per scale and controls
    the transition between layers. After a scale finishes, its G and D are
    frozen (requires_grad=False + eval mode), the per-scale artifacts
    (Gs, Zs, reals, styles, NoiseAmp) are appended and saved to disk, and
    the next scale may be warm-started from the previous scale's weights
    when the filter counts match. Finishes with one SinGAN_generate pass
    using the training variables.
    """
    scale_num = 0
    nfc_prev = 0

    ### Visualization state (visdom windows updated with G(z_opt))
    self.opt.viswindows = []
    # NOTE(review): shape[0] of a converted numpy image is conventionally the
    # vertical axis — max_width/max_height may be swapped here; confirm
    # against convert_image_np before relying on these values.
    self.max_width = convert_image_np(self.real).shape[0]
    self.max_height = convert_image_np(self.real).shape[1]

    ### Load the VGG network (frozen — used only as a fixed feature extractor)
    vgg = VGG()
    vgg.load_state_dict(
        torch.load(self.opt.pretrained_VGG, map_location=self.opt.device))
    self.vgg = vgg.to(self.opt.device)
    for parameter in self.vgg.parameters():
        parameter.requires_grad_(False)

    # Training loop over scales, coarsest first
    while scale_num < self.opt.stop_scale + 1:
        # Filter counts in D and G double every 4th scale, capped at 128.
        self.opt.nfc = min(self.opt.nfc_init * 2 ** (scale_num // 4), 128)
        self.opt.min_nfc = min(self.opt.min_nfc_init * 2 ** (scale_num // 4), 128)

        # Create the per-scale output directory and save the downsampled image.
        self.opt.out_ = generate_dir2save(self.opt)
        self.opt.outf = '%s/%d' % (self.opt.out_, scale_num)
        # exist_ok replaces the old try/except-pass: an already-present
        # directory is fine, but other OSErrors (e.g. permissions) surface.
        os.makedirs(self.opt.outf, exist_ok=True)
        plt.imsave('%s/real_scale.png' % (self.opt.outf),
                   convert_image_np(self.reals[scale_num]),
                   vmin=0, vmax=1)

        # Initialize D and G of the current scale; warm-start from the
        # previous scale's weights when the filter counts match.
        D_curr, G_curr = self.init_models()
        if nfc_prev == self.opt.nfc:
            G_curr.load_state_dict(
                torch.load('%s/%d/netG.pth' % (self.opt.out_, scale_num - 1)))
            D_curr.load_state_dict(
                torch.load('%s/%d/netD.pth' % (self.opt.out_, scale_num - 1)))

        # Train this single scale.
        z_curr, G_curr = self.train_scale(G_curr, D_curr, self.opt)

        # Freeze the trained scale so later scales cannot update it.
        G_curr = reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = reset_grads(D_curr, False)
        D_curr.eval()

        # Store the necessary variables of this scale.
        self.Gs.append(G_curr)
        self.Zs.append(z_curr)
        self.NoiseAmp.append(self.opt.noise_amp)

        # Persist networks and important parameters after every scale so a
        # partial run is still usable.
        torch.save(self.Zs, '%s/Zs.pth' % (self.opt.out_))
        torch.save(self.Gs, '%s/Gs.pth' % (self.opt.out_))
        torch.save(self.reals, '%s/reals.pth' % (self.opt.out_))
        torch.save(self.styles, '%s/styles.pth' % (self.opt.out_))
        torch.save(self.NoiseAmp, '%s/NoiseAmp.pth' % (self.opt.out_))

        scale_num += 1
        nfc_prev = self.opt.nfc  # remember for next scale's warm-start check
        del D_curr, G_curr

    # Generate with the training variables.
    SinGAN_generate(self.Gs, self.Zs, self.reals, self.styles,
                    self.NoiseAmp, self.opt)