def create_sr_inference_input(self, real, iter_num):
    resized_real = real
    pad = nn.ZeroPad2d(5)
    finest_G = self.Gs[-1]
    finest_D = self.Ds[-1]
    finest_noise_amp = self.noise_amps[-1]

    self.Zs = []
    self.Gs = []
    self.Ds = []
    self.reals = []
    self.noise_amps = []

    for i in range(iter_num):
        resized_real = resize_img(resized_real, pow(1 / self.config.scale_factor, 1), self.config)
        print(f'{self.config.infer_dir}/{0}_real.png')
        plt.imsave(f'{self.config.infer_dir}/{0}_real.png', torch2np(resized_real), vmin=0, vmax=1)

        self.reals.append(resized_real)
        self.Gs.append(finest_G)
        self.Ds.append(finest_D)
        self.noise_amps.append(finest_noise_amp)

        rec_z = torch.full(resized_real.shape, 0, device=self.config.device)
        padded_rec_z = pad(rec_z)
        self.Zs.append(padded_rec_z)

    return self.reals[0]
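# `torch2np` (imported from utils.image) is used throughout for plt.imsave. It is not
# shown in this file; a minimal sketch, assuming it converts a (1, C, H, W) tensor in
# [-1, 1] to an H x W x C numpy array in [0, 1] (as the original SinGAN's
# convert_image_np does). The name `torch2np_sketch` is hypothetical; it relies on the
# same `torch` import as the surrounding code.
def torch2np_sketch(tensor):
    img = (tensor[0] + 1) / 2                      # denormalize from [-1, 1] to [0, 1]
    img = img.clamp(0, 1).to(torch.device('cpu'))
    return img.numpy().transpose(1, 2, 0)          # C,H,W -> H,W,C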
def train(self):
    # Prepare image pyramid
    train_img = read_img(self.config)
    real = resize_img(train_img, self.config.start_scale, self.config)
    self.reals = creat_reals_pyramid(real, self.reals, self.config)

    prev_nfc = 0
    self.writer = SummaryWriter(f'{self.config.exp_dir}/logs')

    # Pyramid training
    for scale_iter in range(self.config.stop_scale + 1):
        # Feature depth grows as scale_iter increases (capped at 128)
        self.config.nfc = min(self.config.nfc_init * pow(2, math.floor(scale_iter / 4)), 128)
        self.config.min_nfc = min(self.config.min_nfc_init * pow(2, math.floor(scale_iter / 4)), 128)

        # Prepare directory to save images
        self.config.result_dir = f'{self.config.exp_dir}/{scale_iter}'
        os.makedirs(self.config.result_dir, exist_ok=True)
        plt.imsave(f'{self.config.result_dir}/real_scale.png',
                   torch2np(self.reals[scale_iter]), vmin=0, vmax=1)

        cur_discriminator, cur_generator = self.init_models()
        if prev_nfc == self.config.nfc:
            cur_generator.load_state_dict(
                torch.load(f'{self.config.exp_dir}/{scale_iter - 1}/generator.pth'))
            cur_discriminator.load_state_dict(
                torch.load(f'{self.config.exp_dir}/{scale_iter - 1}/discriminator.pth'))

        cur_z, cur_generator = self.train_single_stage(cur_discriminator, cur_generator)

        cur_generator = reset_grads(cur_generator, False)
        cur_generator.eval()
        cur_discriminator = reset_grads(cur_discriminator, False)
        cur_discriminator.eval()

        self.Gs.append(cur_generator)
        self.Zs.append(cur_z)
        self.noise_amps.append(self.config.noise_amp)

        torch.save(self.Zs, f'{self.config.exp_dir}/Zs.pth')
        torch.save(self.Gs, f'{self.config.exp_dir}/Gs.pth')
        torch.save(self.reals, f'{self.config.exp_dir}/reals.pth')
        torch.save(self.noise_amps, f'{self.config.exp_dir}/noiseAmp.pth')

        prev_nfc = self.config.nfc
        del cur_discriminator, cur_generator
    return
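# The pyramid construction above relies on the `creat_reals_pyramid` helper from utils,
# which is not shown in this file. A minimal sketch of what it plausibly does, following
# the original SinGAN recipe (downscale the real image by scale_factor^(stop_scale - i)
# for each scale i); the actual helper in this repo may differ in details. The name
# `creat_reals_pyramid_sketch` is hypothetical and it reuses `resize_img` from above.
def creat_reals_pyramid_sketch(real, reals, config):
    for i in range(config.stop_scale + 1):
        scale = pow(config.scale_factor, config.stop_scale - i)
        reals.append(resize_img(real, scale, config))
    return reals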
def inference(self, start_img_input):
    if self.config.save_attention_map:
        global global_att_dir
        global global_epoch
        global_att_dir = f'{self.config.infer_dir}/attention'
        os.makedirs(global_att_dir, exist_ok=True)

    if start_img_input is None:
        start_img_input = torch.full(self.reals[0].shape, 0, device=self.config.device)

    cur_images = []
    for idx, (G, D, Z_opt, noise_amp, real) in enumerate(
            zip(self.Gs, self.Ds, self.Zs, self.noise_amps, self.reals)):
        padding_size = ((self.config.kernel_size - 1) * self.config.num_layers) / 2
        pad = nn.ZeroPad2d(int(padding_size))
        output_h = (Z_opt.shape[2] - padding_size * 2) * self.config.scale_h
        output_w = (Z_opt.shape[3] - padding_size * 2) * self.config.scale_w

        prev_images = cur_images
        cur_images = []
        for i in tqdm(range(self.config.num_samples)):
            if idx == 0:
                random_z = generate_noise([1, output_h, output_w], device=self.config.device)
                random_z = random_z.expand(1, 3, random_z.shape[2], random_z.shape[3])
                padded_random_z = pad(random_z)
            else:
                random_z = generate_noise([self.config.img_channel, output_h, output_w],
                                          device=self.config.device)
                padded_random_z = pad(random_z)

            if self.config.use_fixed_noise and idx < self.config.gen_start_scale:
                padded_random_z = Z_opt

            if not prev_images:
                padded_random_img = pad(start_img_input)
            else:
                prev_img = prev_images[i]
                upscaled_prev_random_img = resize_img(prev_img, 1 / self.config.scale_factor, self.config)
                if self.config.mode == "train_SR":
                    padded_random_img = pad(upscaled_prev_random_img)
                else:
                    upscaled_prev_random_img = upscaled_prev_random_img[
                        :, :,
                        0:round(self.config.scale_h * self.reals[idx].shape[2]),
                        0:round(self.config.scale_w * self.reals[idx].shape[3])]
                    padded_random_img = pad(upscaled_prev_random_img)
                    padded_random_img = padded_random_img[:, :,
                                                          0:padded_random_z.shape[2],
                                                          0:padded_random_z.shape[3]]
                    padded_random_img = upsampling(padded_random_img,
                                                   padded_random_z.shape[2], padded_random_z.shape[3])

            padded_random_img_with_z = noise_amp * padded_random_z + padded_random_img
            cur_image = G(padded_random_img_with_z.detach(), padded_random_img)
            np_cur_image = torch2np(cur_image.detach())

            if self.config.save_all_pyramid:
                plt.imsave(f'{self.config.infer_dir}/{i}_{idx}.png', np_cur_image, vmin=0, vmax=1)
                if self.config.save_attention_map:
                    np_real = torch2np(real)
                    _, _, cur_add_att_maps, cur_sub_att_maps = D(cur_image.detach())
                    cur_add_att_maps = cur_add_att_maps.detach().to(torch.device('cpu')).numpy().transpose(1, 2, 3, 0)
                    cur_sub_att_maps = cur_sub_att_maps.detach().to(torch.device('cpu')).numpy().transpose(1, 2, 3, 0)
                    global_epoch = f'{i}_{idx}thG'
                    parmap.map(save_heatmap,
                               [[np_cur_image, cur_add_att_maps, 'infer_add'],
                                [np_real, cur_sub_att_maps, 'infer_sub']],
                               pm_pbar=False, pm_processes=2)
            elif idx == len(self.reals) - 1:
                plt.imsave(f'{self.config.infer_dir}/{i}.png', np_cur_image, vmin=0, vmax=1)
                if self.config.save_attention_map:
                    np_real = torch2np(real)
                    _, _, cur_add_att_maps, cur_sub_att_maps = D(cur_image.detach())
                    cur_add_att_maps = cur_add_att_maps.detach().to(torch.device('cpu')).numpy().transpose(1, 2, 3, 0)
                    cur_sub_att_maps = cur_sub_att_maps.detach().to(torch.device('cpu')).numpy().transpose(1, 2, 3, 0)
                    global_epoch = f'{i}_{idx}thG'
                    parmap.map(save_heatmap,
                               [[np_cur_image, cur_add_att_maps, 'infer_add'],
                                [np_real, cur_sub_att_maps, 'infer_sub']],
                               pm_pbar=False, pm_processes=2)

            cur_images.append(cur_image)
    return cur_image.detach()
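# `generate_noise` is imported from utils and not defined in this file. A minimal sketch
# under the assumption that it draws standard Gaussian noise of the requested
# [channels, height, width] shape (the original SinGAN also supports other noise types,
# omitted here). The name `generate_noise_sketch` is hypothetical; it relies on the same
# `torch` import as the surrounding code.
def generate_noise_sketch(size, device):
    return torch.randn(1, size[0], round(size[1]), round(size[2]), device=device)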
def train_single_stage(self, cur_discriminator, cur_generator):
    real = self.reals[len(self.Gs)]
    _, _, real_h, real_w = real.shape

    # Set padding layer (initial padding) - TODO: change this to noise padding instead of zero-padding
    self.config.receptive_field = self.config.kernel_size + ((self.config.kernel_size - 1) * (self.config.num_layers - 1)) * self.config.stride
    padding_size = int(((self.config.kernel_size - 1) * self.config.num_layers) / 2)
    noise_pad = nn.ZeroPad2d(int(padding_size))
    image_pad = nn.ZeroPad2d(int(padding_size))

    # MultiStepLR: lr *= gamma each time a milestone is reached
    D_optimizer = optim.Adam(cur_discriminator.parameters(), lr=self.config.d_lr,
                             betas=(self.config.beta1, self.config.beta2))
    G_optimizer = optim.Adam(cur_generator.parameters(), lr=self.config.g_lr,
                             betas=(self.config.beta1, self.config.beta2))
    D_scheduler = optim.lr_scheduler.MultiStepLR(optimizer=D_optimizer,
                                                 milestones=self.config.milestones,
                                                 gamma=self.config.gamma)
    G_scheduler = optim.lr_scheduler.MultiStepLR(optimizer=G_optimizer,
                                                 milestones=self.config.milestones,
                                                 gamma=self.config.gamma)

    # Calculate noise amp (amount of info to generate) and recover the previous reconstruction image
    if not self.Gs:
        rec_z = generate_noise([1, real_h, real_w], device=self.config.device).expand(1, 3, real_h, real_w)
        self.first_img_input = torch.full([1, self.config.img_channel, real_h, real_w], 0,
                                          device=self.config.device)
        upscaled_prev_rec_img = self.first_img_input
        self.config.noise_amp = 1
    else:
        rec_z = torch.full([1, self.config.img_channel, real_h, real_w], 0, device=self.config.device)
        upscaled_prev_rec_img = self.draw_sequentially('rec', noise_pad, image_pad)
        criterion = nn.MSELoss()
        rmse = torch.sqrt(criterion(real, upscaled_prev_rec_img))
        self.config.noise_amp = self.config.noise_amp_init * rmse
    padded_rec_z = noise_pad(rec_z)
    padded_rec_img = image_pad(upscaled_prev_rec_img)

    for epoch in tqdm(range(self.config.num_iter), desc=f'{len(self.Gs)}th GAN'):
        random_z = generate_noise([1, real_h, real_w], device=self.config.device).expand(1, 3, real_h, real_w) \
            if not self.Gs else generate_noise([self.config.img_channel, real_h, real_w], device=self.config.device)
        padded_random_z = noise_pad(random_z)

        # Train Discriminator: Maximize D(x) - D(G(z)) -> Minimize D(G(z)) - D(x)
        for i in range(self.config.n_critic):
            # Make random image input
            upscaled_prev_random_img = self.draw_sequentially('rand', noise_pad, image_pad)
            padded_random_img = image_pad(upscaled_prev_random_img)
            padded_random_img_with_z = self.config.noise_amp * padded_random_z + padded_random_img

            # Calculate loss with real data
            cur_discriminator.zero_grad()
            real_prob_out, real_acm_oth, real_add_att_maps, real_sub_att_maps = cur_discriminator(real)
            d_real_loss = -real_prob_out.mean()  # Maximize D(x) -> Minimize -D(x)

            # Calculate loss with fake data
            fake = cur_generator(padded_random_img_with_z.detach(), padded_random_img)
            fake_prob_out, fake_acm_oth, _, _ = cur_discriminator(fake.detach())
            d_fake_loss = fake_prob_out.mean()  # Minimize D(G(z))

            # Gradient penalty
            gradient_penalty = calcul_gp(cur_discriminator, real, fake, self.config.device)

            # Update parameters
            d_loss = d_real_loss + d_fake_loss + (gradient_penalty * self.config.gp_weights)
            if self.config.use_acm_oth:
                d_loss += (torch.abs(real_acm_oth.mean()) + torch.abs(fake_acm_oth.mean())) * self.config.acm_weights
            d_loss.backward()
            D_optimizer.step()

            # Log losses
            critic = -d_real_loss.item() - d_fake_loss.item()  # D(x) - D(G(z))
            self.log_losses[f'{len(self.Gs)}th_D/d'] = d_loss.item()
            self.log_losses[f'{len(self.Gs)}th_D/d_critic'] = critic
            self.log_losses[f'{len(self.Gs)}th_D/d_gp'] = gradient_penalty.item()
            if self.config.use_acm_oth:
                self.log_losses[f'{len(self.Gs)}th_D/d_oth'] = real_acm_oth.item() + fake_acm_oth.item()

        # Train Generator: Maximize D(G(z)) -> Minimize -D(G(z))
        for i in range(self.config.generator_iter):
            cur_generator.zero_grad()

            # Make fake sample for every iteration
            upscaled_prev_random_img = self.draw_sequentially('rand', noise_pad, image_pad)
            padded_random_img = image_pad(upscaled_prev_random_img)
            padded_random_img_with_z = self.config.noise_amp * padded_random_z + padded_random_img
            fake = cur_generator(padded_random_img_with_z.detach(), padded_random_img)

            # Adversarial loss
            fake_prob_out, _, fake_add_att_maps, fake_sub_att_maps = cur_discriminator(fake)
            g_adv_loss = -fake_prob_out.mean()

            # Reconstruction loss
            mse_criterion = nn.MSELoss()
            padded_rec_img_with_z = self.config.noise_amp * padded_rec_z + padded_rec_img
            g_rec_loss = mse_criterion(cur_generator(padded_rec_img_with_z.detach(), padded_rec_img), real)

            # Update parameters
            g_loss = g_adv_loss + (g_rec_loss * self.config.rec_weights)
            g_loss.backward()
            G_optimizer.step()

            # Log losses
            self.log_losses[f'{len(self.Gs)}th_G/g'] = g_loss.item()
            self.log_losses[f'{len(self.Gs)}th_G/g_critic'] = -g_adv_loss.item()
            self.log_losses[f'{len(self.Gs)}th_G/g_rec'] = g_rec_loss.item()

        # Log losses
        for key, value in self.log_losses.items():
            self.writer.add_scalar(key, value, epoch)
        self.log_losses = {}

        # Log images
        if epoch % self.config.img_save_iter == 0 or epoch == (self.config.num_iter - 1):
            np_real = torch2np(real)
            np_fake = torch2np(fake.detach())
            plt.imsave(f'{self.config.result_dir}/{epoch}_fake_sample.png', np_fake, vmin=0, vmax=1)
            plt.imsave(f'{self.config.result_dir}/{epoch}_fixed_noise.png',
                       torch2np(padded_rec_img_with_z.detach() * 2 - 1), vmin=0, vmax=1)
            plt.imsave(f'{self.config.result_dir}/{epoch}_reconstruction.png',
                       torch2np(cur_generator(padded_rec_img_with_z.detach(), padded_rec_img).detach()),
                       vmin=0, vmax=1)

            real_add_att_maps = real_add_att_maps.detach().to(torch.device('cpu')).numpy().transpose(1, 2, 3, 0)
            real_sub_att_maps = real_sub_att_maps.detach().to(torch.device('cpu')).numpy().transpose(1, 2, 3, 0)
            fake_add_att_maps = fake_add_att_maps.detach().to(torch.device('cpu')).numpy().transpose(1, 2, 3, 0)
            fake_sub_att_maps = fake_sub_att_maps.detach().to(torch.device('cpu')).numpy().transpose(1, 2, 3, 0)

            global global_att_dir
            global global_epoch
            global_att_dir = self.config.att_dir
            global_epoch = epoch
            parmap.map(save_heatmap,
                       [[np_real, real_add_att_maps, 'real_add'],
                        [np_real, real_sub_att_maps, 'real_sub'],
                        [np_fake, fake_add_att_maps, 'fake_add'],
                        [np_real, fake_sub_att_maps, 'fake_sub']],
                       pm_pbar=False, pm_processes=4)

        D_scheduler.step()
        G_scheduler.step()

    # Save model weights
    torch.save(cur_generator.state_dict(), f'{self.config.result_dir}/generator.pth')
    torch.save(cur_discriminator.state_dict(), f'{self.config.result_dir}/ACM_discriminator.pth')
    return padded_rec_z, cur_generator, cur_discriminator
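# `calcul_gp` (imported from utils) supplies the WGAN-GP gradient penalty used in both
# discriminator updates above. A standard sketch of that penalty (Gulrajani et al.),
# written against a discriminator that may return either a single score map or the ACM
# tuple; the real helper's signature differs slightly between the two variants (the
# plain one also takes a trailing flag), so this is an assumption, not the repo's code.
def calcul_gp_sketch(discriminator, real, fake, device):
    alpha = torch.rand(1, 1, device=device).expand(real.size())
    interpolates = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    prob_out = discriminator(interpolates)
    if isinstance(prob_out, tuple):  # ACM variant returns (score, oth, add_att, sub_att)
        prob_out = prob_out[0]
    gradients = torch.autograd.grad(outputs=prob_out, inputs=interpolates,
                                    grad_outputs=torch.ones_like(prob_out),
                                    create_graph=True, retain_graph=True)[0]
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()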
def inference(self, start_img_input):
    if start_img_input is None:
        start_img_input = torch.full(self.reals[0].shape, 0, device=self.config.device)

    cur_images = []
    for idx, (G, Z_opt, noise_amp) in tqdm(enumerate(zip(self.Gs, self.Zs, self.noise_amps))):
        padding_size = ((self.config.kernel_size - 1) * self.config.num_layers) / 2
        pad = nn.ZeroPad2d(int(padding_size))
        output_h = (Z_opt.shape[2] - padding_size * 2) * self.config.scale_h
        output_w = (Z_opt.shape[3] - padding_size * 2) * self.config.scale_w

        prev_images = cur_images
        cur_images = []
        for i in range(self.config.num_samples):
            if idx == 0:
                random_z = generate_noise([1, output_h, output_w], device=self.config.device)
                random_z = random_z.expand(1, 3, random_z.shape[2], random_z.shape[3])
                padded_random_z = pad(random_z)
            else:
                random_z = generate_noise([self.config.img_channel, output_h, output_w],
                                          device=self.config.device)
                padded_random_z = pad(random_z)

            if self.config.use_fixed_noise and idx < self.config.gen_start_scale:
                padded_random_z = Z_opt

            if not prev_images:
                padded_random_img = pad(start_img_input)
            else:
                prev_img = prev_images[i]
                upscaled_prev_random_img = resize_img(prev_img, 1 / self.config.scale_factor, self.config)
                if self.config.mode == "train_SR":
                    padded_random_img = pad(upscaled_prev_random_img)
                else:
                    upscaled_prev_random_img = upscaled_prev_random_img[
                        :, :,
                        0:round(self.config.scale_h * self.reals[idx].shape[2]),
                        0:round(self.config.scale_w * self.reals[idx].shape[3])]
                    padded_random_img = pad(upscaled_prev_random_img)
                    padded_random_img = padded_random_img[:, :,
                                                          0:padded_random_z.shape[2],
                                                          0:padded_random_z.shape[3]]
                    padded_random_img = upsampling(padded_random_img,
                                                   padded_random_z.shape[2], padded_random_z.shape[3])

            padded_random_img_with_z = noise_amp * padded_random_z + padded_random_img
            cur_image = G(padded_random_img_with_z.detach(), padded_random_img)

            if self.config.save_all_pyramid:
                plt.imsave(f'{self.config.infer_dir}/{i}_{idx}.png',
                           torch2np(cur_image.detach()), vmin=0, vmax=1)
            elif idx == len(self.reals) - 1:
                plt.imsave(f'{self.config.infer_dir}/{i}.png',
                           torch2np(cur_image.detach()), vmin=0, vmax=1)

            cur_images.append(cur_image)
    return cur_image.detach()
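# `upsampling` (used in both inference loops to realign the previous-scale image with
# the padded noise map) comes from utils and is not shown here. In the original SinGAN
# it is a thin wrapper around nn.Upsample; a sketch under that assumption, relying on
# the same `torch.nn as nn` import as the surrounding code.
def upsampling_sketch(img, target_h, target_w):
    up = nn.Upsample(size=[round(target_h), round(target_w)], mode='bilinear', align_corners=True)
    return up(img)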
def train_single_stage(self, cur_discriminator, cur_generator):
    real = self.reals[len(self.Gs)]
    _, _, real_h, real_w = real.shape

    # Set padding layer (initial padding) - TODO: change this to noise padding instead of zero-padding
    self.config.receptive_field = self.config.kernel_size + ((self.config.kernel_size - 1) * (self.config.num_layers - 1)) * self.config.stride
    padding_size = int(((self.config.kernel_size - 1) * self.config.num_layers) / 2)
    noise_pad = nn.ZeroPad2d(int(padding_size))
    image_pad = nn.ZeroPad2d(int(padding_size))

    # MultiStepLR: lr *= gamma each time a milestone is reached
    D_optimizer = optim.Adam(cur_discriminator.parameters(), lr=self.config.d_lr,
                             betas=(self.config.beta1, self.config.beta2))
    G_optimizer = optim.Adam(cur_generator.parameters(), lr=self.config.g_lr,
                             betas=(self.config.beta1, self.config.beta2))
    D_scheduler = optim.lr_scheduler.MultiStepLR(optimizer=D_optimizer,
                                                 milestones=self.config.milestones,
                                                 gamma=self.config.gamma)
    G_scheduler = optim.lr_scheduler.MultiStepLR(optimizer=G_optimizer,
                                                 milestones=self.config.milestones,
                                                 gamma=self.config.gamma)

    rec_z = torch.full([1, self.config.img_channel, real_h, real_w], 0, device=self.config.device)

    for epoch in tqdm(range(self.config.num_iter), desc=f'{len(self.Gs)}th GAN'):
        # Make noise input
        if not self.Gs:
            rec_z = generate_noise([1, real_h, real_w], device=self.config.device).expand(1, 3, real_h, real_w)
            random_z = generate_noise([1, real_h, real_w], device=self.config.device).expand(1, 3, real_h, real_w)
        else:
            random_z = generate_noise([self.config.img_channel, real_h, real_w], device=self.config.device)
        padded_rec_z = noise_pad(rec_z)
        padded_random_z = noise_pad(random_z)

        # Train Discriminator: Maximize D(x) - D(G(z)) -> Minimize D(G(z)) - D(x)
        for i in range(self.config.n_critic):
            # Train with real data
            cur_discriminator.zero_grad()
            real_prob_out = cur_discriminator(real)
            d_real_loss = -real_prob_out.mean()  # Maximize D(x) -> Minimize -D(x)
            d_real_loss.backward(retain_graph=True)
            D_x = -d_real_loss.item()

            if i == 0 and epoch == 0:
                if not self.Gs:
                    self.first_img_input = torch.full([1, self.config.img_channel, real_h, real_w], 0,
                                                      device=self.config.device)
                    upscaled_prev_random_img = self.first_img_input
                    padded_random_img = image_pad(upscaled_prev_random_img)
                    upscaled_prev_rec_img = torch.full([1, self.config.img_channel, real_h, real_w], 0,
                                                       device=self.config.device)
                    padded_rec_img = image_pad(upscaled_prev_rec_img)
                    self.config.noise_amp = 1
                else:
                    upscaled_prev_random_img = self.draw_sequentially('rand', noise_pad, image_pad)
                    padded_random_img = image_pad(upscaled_prev_random_img)
                    upscaled_prev_rec_img = self.draw_sequentially('rec', noise_pad, image_pad)
                    criterion = nn.MSELoss()
                    rmse = torch.sqrt(criterion(real, upscaled_prev_rec_img))
                    self.config.noise_amp = self.config.noise_amp_init * rmse
                    padded_rec_img = image_pad(upscaled_prev_rec_img)
            else:
                upscaled_prev_random_img = self.draw_sequentially('rand', noise_pad, image_pad)
                padded_random_img = image_pad(upscaled_prev_random_img)

            # Make random image input
            if not self.Gs:
                padded_random_img_with_z = padded_random_z
            else:
                padded_random_img_with_z = (self.config.noise_amp * padded_random_z) + padded_random_img

            # Train with fake data
            fake = cur_generator(padded_random_img_with_z.detach(), padded_random_img)
            fake_prob_out = cur_discriminator(fake.detach())
            d_fake_loss = fake_prob_out.mean()  # Minimize D(G(z))
            d_fake_loss.backward(retain_graph=True)
            D_G_z = d_fake_loss.item()

            # Gradient penalty
            gradient_penalty = calcul_gp(cur_discriminator, real, fake,
                                         self.config.device, False) * self.config.gp_weights
            gradient_penalty.backward()
            D_optimizer.step()

            # Log losses
            d_loss = d_real_loss + d_fake_loss + gradient_penalty
            critic = D_x - D_G_z
            self.log_losses[f'{len(self.Gs)}th_D/d'] = d_loss.item()
            self.log_losses[f'{len(self.Gs)}th_D/d_critic'] = critic
            self.log_losses[f'{len(self.Gs)}th_D/d_gp'] = gradient_penalty.item()

        # Train Generator: Maximize D(G(z)) -> Minimize -D(G(z))
        for i in range(self.config.generator_iter):
            cur_generator.zero_grad()

            # Make fake sample for every iteration
            # upscaled_prev_random_img = self.draw_sequentially('rand', noise_pad, image_pad)
            # padded_random_img = image_pad(upscaled_prev_random_img)
            # padded_random_img_with_z = self.config.noise_amp * padded_random_z + padded_random_img
            # fake = cur_generator(padded_random_img_with_z.detach(), padded_random_img)

            # Adversarial loss
            fake_prob_out = cur_discriminator(fake)
            g_adv_loss = -fake_prob_out.mean()
            g_adv_loss.backward(retain_graph=True)
            g_adv_loss = g_adv_loss.item()

            # Reconstruction loss
            mse_criterion = nn.MSELoss()
            padded_rec_img_with_z = self.config.noise_amp * padded_rec_z + padded_rec_img
            g_rec_loss = self.config.rec_weights * mse_criterion(
                cur_generator(padded_rec_img_with_z.detach(), padded_rec_img), real)
            g_rec_loss.backward(retain_graph=True)
            g_rec_loss = g_rec_loss.item()
            G_optimizer.step()

            # Log losses (the reconstruction term already carries rec_weights)
            g_loss = g_adv_loss + g_rec_loss
            self.log_losses[f'{len(self.Gs)}th_G/g'] = g_loss
            self.log_losses[f'{len(self.Gs)}th_G/g_critic'] = -g_adv_loss
            self.log_losses[f'{len(self.Gs)}th_G/g_rec'] = g_rec_loss

        # Log losses
        for key, value in self.log_losses.items():
            self.writer.add_scalar(key, value, epoch)
        self.log_losses = {}

        # Log images
        if epoch % self.config.img_save_iter == 0 or epoch == (self.config.num_iter - 1):
            plt.imsave(f'{self.config.result_dir}/{epoch}_fake_sample.png',
                       torch2np(fake.detach()), vmin=0, vmax=1)
            plt.imsave(f'{self.config.result_dir}/{epoch}_fixed_noise.png',
                       torch2np(padded_rec_img_with_z.detach() * 2 - 1), vmin=0, vmax=1)
            plt.imsave(f'{self.config.result_dir}/{epoch}_reconstruction.png',
                       torch2np(cur_generator(padded_rec_img_with_z.detach(), padded_rec_img).detach()),
                       vmin=0, vmax=1)

        D_scheduler.step()
        G_scheduler.step()

    # Save model weights
    torch.save(cur_generator.state_dict(), f'{self.config.result_dir}/generator.pth')
    torch.save(cur_discriminator.state_dict(), f'{self.config.result_dir}/discriminator.pth')
    return padded_rec_z, cur_generator
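# `self.draw_sequentially(mode, noise_pad, image_pad)` regenerates the current image
# from the already-trained coarser scales, much like SinGAN's draw_concat: 'rand' draws
# fresh noise at each scale, 'rec' reuses the stored reconstruction noise. The method is
# not shown in this excerpt; a hedged sketch of the expected behavior, written as a
# standalone function taking the SinGAN object (the repo's implementation may differ).
def draw_sequentially_sketch(self, mode, noise_pad, image_pad):
    generated_img = self.first_img_input
    for G, Z_opt, cur_real, next_real, noise_amp in zip(
            self.Gs, self.Zs, self.reals, self.reals[1:], self.noise_amps):
        if mode == 'rand':
            padded_z = noise_pad(generate_noise(
                [self.config.img_channel, cur_real.shape[2], cur_real.shape[3]],
                device=self.config.device))
        else:  # 'rec': reuse the stored (already padded) reconstruction noise
            padded_z = Z_opt
        generated_img = generated_img[:, :, :cur_real.shape[2], :cur_real.shape[3]]
        padded_img = image_pad(generated_img)
        generated_img = G((noise_amp * padded_z + padded_img).detach(), padded_img)
        # Upscale to the next scale of the pyramid and crop to its exact size
        generated_img = resize_img(generated_img, 1 / self.config.scale_factor, self.config)
        generated_img = generated_img[:, :, :next_real.shape[2], :next_real.shape[3]]
    return generated_img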
import os

import matplotlib.pyplot as plt

from config import Config
# from model.ACM_SinGAN import SinGAN_ACM as SinGAN
from model.SinGAN import SinGAN
from utils.image import read_img, torch2np
from utils.utils import process_config, adjust_scales, calcul_sr_scale

if __name__ == '__main__':
    process_config(Config)
    Config.infer_dir = f'{Config.exp_dir}/infer_pyramid' if Config.save_all_pyramid else f'{Config.exp_dir}/infer'

    singan = SinGAN(config=Config)
    singan.load_trained_weights()

    inference_img = read_img(Config)
    os.makedirs(Config.infer_dir, exist_ok=True)
    plt.imsave(f'{Config.exp_dir}/real.png', torch2np(inference_img), vmin=0, vmax=1)

    if Config.mode == 'train':
        adjust_scales(inference_img, Config)
        start_img_input = singan.create_inference_input()
    elif Config.mode == 'train_SR':
        in_scale, iter_num = calcul_sr_scale(Config)
        Config.scale_factor = 1 / in_scale
        Config.scale_factor_init = 1 / in_scale
        Config.scale_w = 1
        Config.scale_h = 1
        start_img_input = singan.create_sr_inference_input(singan.reals[-1], iter_num)
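# `calcul_sr_scale` is imported from utils and not shown here. In the original SinGAN,
# the SR setup derives a per-stage scale and the number of extra stages from the desired
# super-resolution factor; a sketch of that computation, assuming Config exposes an
# `sr_factor` attribute (hypothetical name; the repo's helper may differ).
import math

def calcul_sr_scale_sketch(config):
    in_scale = pow(1 / 2, 1 / 3)                                # initial per-stage scale guess
    iter_num = round(math.log(1 / config.sr_factor, in_scale))  # stages needed to reach sr_factor
    in_scale = pow(config.sr_factor, 1 / iter_num)              # refine so iter_num stages hit it exactly
    return in_scale, iter_num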