def draw_concat(Gs, Zs, reals, NoiseAmp, in_s, mode, m_noise, m_image, opt,
                index_image, cuda_device):
    """Propagate the starting image ``in_s`` through all trained scales.

    For every previously-trained generator in ``Gs`` the running image is
    cropped/padded, mixed with noise, passed through that scale's generator,
    and upsampled per image to the next scale's resolution.

    mode == 'rand': fresh noise is drawn at each scale (random sample path).
    mode == 'rec':  the stored reconstruction noise ``Zs`` is reused at each
                    scale (fixed reconstruction path).

    Returns the accumulated image batch ``G_z`` (just ``in_s`` when no scale
    has been trained yet).

    NOTE(review): assumes ``Zs``/``reals``/``NoiseAmp`` are indexed
    [image][scale] while ``Gs`` is indexed [scale] — confirm against train().
    """
    G_z = in_s
    if len(Gs) > 0:
        if mode == 'rand':
            count = 0
            # Padding stripped from each side of the stored noise maps.
            pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
            for scale_idx, (G) in enumerate(Gs):
                # Batch the per-image tensors of this scale along dim 0.
                Z_opt = torch.cat([Zs[idx][scale_idx] for idx in index_image],
                                  dim=0)
                real_curr = torch.cat(
                    [reals[idx][scale_idx] for idx in index_image], dim=0)
                # reals[idx][1:][scale_idx] == reals[idx][scale_idx + 1]:
                # the real image at the *next* scale.
                real_next = torch.cat(
                    [reals[idx][1:][scale_idx] for idx in index_image], dim=0)
                noise_amp = torch.cat(
                    ([NoiseAmp[id][scale_idx] for id in range(opt.num_images)]),
                    dim=0).to(cuda_device)
                if count == 0:
                    # Coarsest scale: single-channel noise expanded so all
                    # three RGB channels share the same noise map.
                    z = functions.generate_noise([
                        1, Z_opt.shape[2] - 2 * pad_noise,
                        Z_opt.shape[3] - 2 * pad_noise
                    ],
                                                 device=cuda_device,
                                                 num_samp=real_curr.shape[0])
                    z = z.expand(real_curr.shape[0], 3, z.shape[2], z.shape[3])
                else:
                    z = functions.generate_noise([
                        opt.nc_z, Z_opt.shape[2] - 2 * pad_noise,
                        Z_opt.shape[3] - 2 * pad_noise
                    ],
                                                 device=cuda_device,
                                                 num_samp=real_curr.shape[0])
                # Broadcast each image's scalar amplitude to a full-size map.
                noise_amp_tensor = torch.full(
                    [1, z.shape[1], z.shape[2], z.shape[3]],
                    noise_amp[0][0].item(),
                    dtype=torch.float).to(cuda_device)
                for i in range(1, opt.num_images):
                    temp = torch.full([1, z.shape[1], z.shape[2], z.shape[3]],
                                      noise_amp[i][0].item(),
                                      dtype=torch.float).to(cuda_device)
                    noise_amp_tensor = torch.cat((noise_amp_tensor, temp),
                                                 dim=0)
                z = m_noise(z)
                # Crop the running image to this scale's size, then pad it.
                G_z = G_z[:, :, 0:real_curr.shape[2], 0:real_curr.shape[3]]
                G_z = m_image(G_z)
                z_in = m_noise(noise_amp_tensor) * z + G_z
                padded_id_z_in = pad_image_id(z_in, index_image)
                G_z_temp = G(padded_id_z_in.detach(), G_z)
                if isinstance(G_z_temp, list):
                    # Multi-device output: gather shards onto one device.
                    G_z_temp = [tens.to(cuda_device) for tens in G_z_temp]
                    G_z_temp = torch.cat(G_z_temp)
                # Upsample each image separately to the next scale's size.
                G_z = imresize(torch.unsqueeze(G_z_temp[0], dim=0),
                               1 / opt.scale_factor, opt)
                for id in range(1, opt.num_images):
                    G_z = torch.cat(
                        (G_z,
                         imresize(torch.unsqueeze(G_z_temp[id], dim=0),
                                  1 / opt.scale_factor, opt)),
                        dim=0)
                G_z = G_z[:, :, 0:real_next.shape[2], 0:real_next.shape[3]]
                count += 1
        if mode == 'rec':
            count = 0
            for scale_idx, (G) in enumerate((Gs)):
                Z_opt = torch.cat([Zs[idx][scale_idx] for idx in index_image],
                                  dim=0)
                real_curr = torch.cat(
                    [reals[idx][scale_idx] for idx in index_image], dim=0)
                # Next-scale real, used only for the final crop size.
                real_next = torch.cat(
                    [reals[idx][1:][scale_idx] for idx in index_image], dim=0)
                noise_amp = torch.cat(
                    ([NoiseAmp[id][scale_idx] for id in range(opt.num_images)]),
                    dim=0).to(cuda_device)
                # Per-image amplitude maps matching the stored Z_opt shape.
                noise_amp_tensor = torch.full(
                    [1, Z_opt.shape[1], Z_opt.shape[2], Z_opt.shape[3]],
                    noise_amp[0][0].item(),
                    dtype=torch.float).to(cuda_device)
                for i in range(1, opt.num_images):
                    temp = torch.full(
                        [1, Z_opt.shape[1], Z_opt.shape[2], Z_opt.shape[3]],
                        noise_amp[i][0].item(),
                        dtype=torch.float).to(cuda_device)
                    noise_amp_tensor = torch.cat((noise_amp_tensor, temp),
                                                 dim=0)
                G_z = G_z[:, :, 0:real_curr.shape[2], 0:real_curr.shape[3]]
                G_z = m_image(G_z)
                # Reconstruction path: stored Z_opt instead of fresh noise.
                z_in = noise_amp_tensor * Z_opt + G_z
                padded_id_z_in = pad_image_id(z_in, index_image)
                G_z_temp = G(padded_id_z_in.detach(), G_z)
                if isinstance(G_z_temp, list):
                    G_z_temp = [tens.to(cuda_device) for tens in G_z_temp]
                    G_z_temp = torch.cat(G_z_temp)
                G_z = imresize(torch.unsqueeze(G_z_temp[0], dim=0),
                               1 / opt.scale_factor, opt)
                for id in range(1, opt.num_images):
                    G_z = torch.cat(
                        (G_z,
                         imresize(torch.unsqueeze(G_z_temp[id], dim=0),
                                  1 / opt.scale_factor, opt)),
                        dim=0)
                G_z = G_z[:, :, 0:real_next.shape[2], 0:real_next.shape[3]]
                count += 1
    return G_z
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Top-level multi-scale training loop.

    Trains one (D, G) pair per scale from the starting ``scale_num`` returned
    by ``collect_reals`` up to ``opt.stop_scale``, appending the trained
    generator and per-image noise data to ``Gs``/``Zs``/``NoiseAmp`` and
    checkpointing everything to disk after each scale.

    Mutates ``opt`` in place (nfc, min_nfc, out_, outf, global_outf, ...).

    Fix vs. original: the per-directory bare ``except: pass`` around
    ``os.makedirs`` swallowed *every* error (e.g. permission problems), not
    just "directory exists". ``exist_ok=True`` keeps the intended tolerance
    while real OS errors now reach the outer handler and get reported.
    """
    reals, Zs, NoiseAmp, in_s, scale_num = functions.collect_reals(
        opt, reals, Zs, NoiseAmp)
    nfc_prev = 0
    # Start each image with an empty per-scale noise-amplitude list.
    for index_image in range(int(opt.num_images)):
        NoiseAmp[index_image] = []
    while scale_num < opt.stop_scale + 1:
        # Double channel counts every 4 scales, capped at opt.size_image.
        opt.nfc = min(opt.nfc_init * pow(2, math.floor(scale_num / 4)),
                      opt.size_image)
        opt.min_nfc = min(
            opt.min_nfc_init * pow(2, math.floor(scale_num / 4)),
            opt.size_image)
        D_curr, G_curr = init_models(opt)
        index_arr_flag = 0  # falsy only on the first epoch of each scale
        for epoch in range(opt.num_epochs):
            # NOTE(review): optimizers/schedulers are re-created every epoch,
            # which also resets Adam state and LR decay — confirm intended.
            optimizerD = optim.Adam(D_curr.parameters(),
                                    lr=opt.lr_d,
                                    betas=(opt.beta1, 0.999))
            optimizerG = optim.Adam(G_curr.parameters(),
                                    lr=opt.lr_g,
                                    betas=(opt.beta1, 0.999))
            schedulerD = torch.optim.lr_scheduler.MultiStepLR(
                optimizer=optimizerD, milestones=[500], gamma=opt.gamma)
            schedulerG = torch.optim.lr_scheduler.MultiStepLR(
                optimizer=optimizerG, milestones=[1000], gamma=opt.gamma)
            print(" ")
            print("this is class: ", opt.pos_class)
            print("size image: ", opt.size_image)
            print("num images: ", opt.num_images)
            print("index of download: ", opt.index_download)
            print("num transforms: ", opt.num_transforms)
            print(" ")
            index_image = range(int(opt.num_images))
            G_curr = functions.reset_grads(G_curr, True)
            D_curr = functions.reset_grads(D_curr, True)
            opt.out_ = functions.generate_dir2save(opt)
            opt.global_outf = '%s/%d' % (opt.out_, scale_num)
            # Wipe any stale output from a previous run of this scale.
            if os.path.exists(opt.global_outf):
                shutil.rmtree(opt.global_outf)
            opt.outf = [
                '%s/%d/index_image_%d' % (opt.out_, scale_num, id)
                for id in index_image
            ]
            try:
                for j in opt.outf:
                    # exist_ok=True tolerates pre-existing dirs; any other
                    # OSError propagates to the handler below.
                    os.makedirs(j, exist_ok=True)
            except OSError as err:
                print("OS error: {0}".format(err))
            for id in index_image:
                plt.imsave('%s/real_scale%d.png' % (opt.outf[id], id),
                           functions.convert_image_np(reals[id][scale_num]),
                           vmin=0,
                           vmax=1)
            # Same width as the previous scale: warm-start from its weights.
            if (nfc_prev == opt.nfc):
                G_curr.load_state_dict(
                    torch.load('%s/%d/netG.pth' % (opt.out_, scale_num - 1)))
                D_curr.load_state_dict(
                    torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))
                nfc_prev = 0  # only warm-start once per scale
            if not index_arr_flag:
                # First epoch of this scale: derive sizes and draw the fixed
                # reconstruction noise z_opt.
                real = torch.cat([reals[id][scale_num] for id in index_image],
                                 dim=0)
                opt.nzx = real.shape[2]
                opt.nzy = real.shape[3]
                opt.receptive_field = opt.ker_size + (
                    (opt.ker_size - 1) * (opt.num_layer - 1)) * opt.stride  # 11
                pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)  # 5
                m_noise = nn.ZeroPad2d(int(pad_noise))
                z_opt = functions.generate_noise([1, opt.nzx, opt.nzy],
                                                 device=opt.device)
                z_opt = m_noise(
                    z_opt.expand(real.shape[0], 3, opt.nzx, opt.nzy))
            else:
                # Later epochs: reuse the z_opt stored after epoch 0.
                z_opt = torch.cat(([Zs[id][scale_num] for id in index_image]),
                                  dim=0)
            in_s, G_curr = train_single_scale(D_curr, G_curr, reals, Gs, Zs,
                                              in_s, NoiseAmp, opt, index_image,
                                              z_opt, optimizerD, optimizerG,
                                              schedulerD, schedulerG,
                                              index_arr_flag)
            if not index_arr_flag:
                # First epoch: record z_opt and noise amplitude per image.
                for id in index_image:
                    Zs[id].append(z_opt[id:id + 1])
                    NoiseAmp[id].append(opt.noise_amp[id:id + 1])
            else:
                # Later epochs: overwrite this scale's amplitude in place.
                for id in index_image:
                    NoiseAmp[id][scale_num] = opt.noise_amp[id:id + 1]
            index_arr_flag = True
        # Freeze the finished scale and checkpoint everything.
        G_curr = functions.reset_grads(G_curr, False)
        G_curr.eval()
        D_curr = functions.reset_grads(D_curr, False)
        D_curr.eval()
        Gs.append(G_curr)
        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))
        scale_num += 1
        nfc_prev = opt.nfc
        del D_curr, G_curr, optimizerD, optimizerG
        torch.cuda.empty_cache()
    return
def train_single_scale(netD,
                       netG,
                       reals,
                       Gs,
                       Zs,
                       in_s,
                       NoiseAmp,
                       opt,
                       index_image,
                       z_opt,
                       optimizerD,
                       optimizerG,
                       schedulerD,
                       schedulerG,
                       is_passing_im_before,
                       centers=None):
    """Train the (D, G) pair of the current scale for ``opt.niter`` epochs.

    Each epoch alternates ``opt.Dsteps`` discriminator updates with
    ``opt.Gsteps`` generator updates. Real and fake batches are expanded with
    the augmentations in ``opt.list_transformations`` before scoring, and a
    reconstruction loss weighted by ``opt.alpha`` ties G(z_opt) to the reals.

    Returns ``(in_s, netG)``: the coarsest-scale base image and the trained
    generator. Side effects: sets ``opt.num_transforms``, ``opt.noise_amp``
    and ``opt.noise_amp_tensor``; periodically writes sample images under
    ``opt.outf`` and saves networks via ``functions.save_networks``.

    Fixes vs. original:
    - ``alpha == 0`` set ``rec_loss`` to the int 0, so the logging call
      ``rec_loss.detach().item()`` crashed at epoch 0; it is now a 0-dim
      tensor with the same value.
    - In the D step, the list-output branch built ``id_padding`` on the CPU,
      unlike the other three constructions of the same list; it now moves
      the tensors to ``opt.device`` for consistency.
    """
    real = torch.cat([reals[id][len(Gs)] for id in index_image], dim=0)
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))
    alpha = opt.alpha  # reconstruction-loss weight
    errD2plot = []
    errG2plot = []
    D_real2plot = []
    z_opt2plot = []
    for epoch in range(opt.niter):
        if Gs == []:
            # Coarsest scale: single-channel noise shared across RGB.
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy],
                                              device=opt.device,
                                              num_samp=real.shape[0])
            noise_ = m_noise(
                noise_.expand(real.shape[0], 3, opt.nzx, opt.nzy))
        else:
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy],
                                              device=opt.device,
                                              num_samp=real.shape[0])
            noise_ = m_noise(noise_)
        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for p in netD.parameters():
            p.requires_grad = True  # to avoid computation
        for j in range(opt.Dsteps):
            netD.zero_grad()
            # Build the augmented real batch, one entry per transformation.
            reals_arr = []
            num_transforms = 0
            for index_transform, pair in enumerate(opt.list_transformations):
                num_transforms += 1
                flag_color, is_flip, tx, ty, k_rotate = pair
                real_transform = apply_augmentation(
                    real, is_flip, tx, ty, k_rotate,
                    flag_color).to(opt.device)
                real_transform = torch.squeeze(real_transform)
                reals_arr.append(real_transform)
            opt.num_transforms = num_transforms
            num_transforms_range = range(opt.num_transforms)
            real_transform = torch.stack(reals_arr)
            if opt.num_images > 1:
                # Flatten (transform, image) into one batch dimension.
                real_transform = real_transform.reshape(
                    -1, real_transform.shape[2], real_transform.shape[3],
                    real_transform.shape[4])
            output = netD(real_transform)
            # One transformation-id map per transform, matching D's output
            # spatial size (used by the auxiliary classification loss).
            if isinstance(output, list):
                id_padding = [
                    torch.full((opt.num_images, 1, output[0].shape[2],
                                output[0].shape[3]),
                               id,
                               dtype=torch.float).to(opt.device)
                    for id in num_transforms_range
                ]
            else:
                id_padding = [
                    torch.full(
                        (opt.num_images, 1, output.shape[2], output.shape[3]),
                        id,
                        dtype=torch.float).to(opt.device)
                    for id in num_transforms_range
                ]
            errD_real = get_err_D_real_and_backward(output, id_padding,
                                                    opt.device_ids)
            D_x = -errD_real.item()
            # train with fake - this is the first time in this scale
            if (j == 0) and (epoch == 0):
                if Gs == [] and not is_passing_im_before:
                    # Very first scale, first epoch: zero previous image and
                    # unit noise amplitude.
                    prev = torch.full(
                        [real.shape[0], opt.nc_z, opt.nzx, opt.nzy],
                        0,
                        device=opt.device,
                        dtype=torch.long)
                    in_s = prev
                    prev = m_image(prev)
                    z_prev = torch.full(
                        [real.shape[0], opt.nc_z, opt.nzx, opt.nzy],
                        0,
                        device=opt.device,
                        dtype=torch.long)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = torch.full([real.shape[0], 1],
                                               1,
                                               dtype=torch.long).to(opt.device)
                    opt.noise_amp_tensor = torch.full(
                        [real.shape[0], opt.nc_z, opt.nzx, opt.nzy],
                        1,
                        dtype=torch.long).to(opt.device)
                else:
                    # Build prev / z_prev by running through earlier scales.
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand',
                                       m_noise, m_image, opt, index_image,
                                       cuda_device=opt.device)
                    prev = m_image(prev)
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec',
                                         m_noise, m_image, opt, index_image,
                                         cuda_device=opt.device)
                    criterion = nn.MSELoss(reduction='none')
                    opt.noise_amp = torch.cat(([
                        NoiseAmp[id][len(Gs) - 1]
                        for id in range(opt.num_images)
                    ]),
                                              dim=0).to(opt.device)
                    if not is_passing_im_before:
                        # First epoch of this scale: amplitude from the
                        # per-image reconstruction RMSE.
                        temp = criterion(real, z_prev)
                        RMSE = torch.sqrt(temp)
                        opt.noise_amp_init_tensor = torch.full(
                            [real.shape[0], opt.nc_z, opt.nzx, opt.nzy],
                            opt.noise_amp_init,
                            dtype=torch.float).to(opt.device)
                        opt.noise_amp = opt.noise_amp_init_tensor * RMSE
                        # Reduce to one scalar amplitude per image.
                        opt.noise_amp = torch.unsqueeze(torch.mean(
                            torch.flatten(opt.noise_amp, start_dim=1), dim=1),
                                                        dim=1)
                        opt.noise_amp_tensor = torch.full(
                            [1, opt.nc_z, opt.nzx, opt.nzy],
                            opt.noise_amp[0][0].item(),
                            dtype=torch.float).to(opt.device)
                        for i in range(1, real.shape[0]):
                            temp = torch.full(
                                [1, opt.nc_z, opt.nzx, opt.nzy],
                                opt.noise_amp[i][0].item(),
                                dtype=torch.float).to(opt.device)
                            opt.noise_amp_tensor = torch.cat(
                                (opt.noise_amp_tensor, temp), dim=0)
                    else:
                        # Later epochs: reuse amplitudes stored in NoiseAmp.
                        opt.noise_amp = torch.cat(([
                            NoiseAmp[id][len(Gs) - 1]
                            for id in range(opt.num_images)
                        ]),
                                                  dim=0).to(opt.device)
                        opt.noise_amp_tensor = torch.full(
                            [1, opt.nc_z, opt.nzx, opt.nzy],
                            opt.noise_amp[0][0].item(),
                            dtype=torch.float).to(opt.device)
                        for i in range(1, real.shape[0]):
                            temp = torch.full(
                                [1, opt.nc_z, opt.nzx, opt.nzy],
                                opt.noise_amp[i][0].item(),
                                dtype=torch.float).to(opt.device)
                            opt.noise_amp_tensor = torch.cat(
                                (opt.noise_amp_tensor, temp), dim=0)
                    z_prev = m_image(z_prev)
            else:
                # Subsequent D steps: redraw the random previous image.
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand',
                                   m_noise, m_image, opt, index_image,
                                   cuda_device=opt.device)
                prev = m_image(prev)
            if Gs == []:
                noise = noise_
            else:
                noise = m_noise(opt.noise_amp_tensor) * noise_ + prev
            padded_id_noise = pad_image_id(noise, index_image)
            fake = netG(padded_id_noise.detach(), prev)
            if isinstance(fake, list):
                fake = [tens.to(opt.device) for tens in fake]
                fake = torch.cat(fake)
            # Augment the (detached) fakes the same way as the reals.
            fakes_arr = []
            for index_transform, pair in enumerate(opt.list_transformations):
                flag_color, is_flip, tx, ty, k_rotate = pair
                fake_transform = apply_augmentation(
                    fake.detach(), is_flip, tx, ty, k_rotate,
                    flag_color).to(opt.device)
                fake_transform = torch.squeeze(fake_transform)
                fakes_arr.append(fake_transform)
            fake_transform = torch.stack(fakes_arr)
            if opt.num_images > 1:
                fake_transform = fake_transform.reshape(
                    -1, fake_transform.shape[2], fake_transform.shape[3],
                    fake_transform.shape[4])
            output = netD(fake_transform)
            errD_fake = get_err_D_fake_and_backward(output,
                                                    opt.num_transforms,
                                                    opt.num_images,
                                                    opt.device_ids)
            if isinstance(output, list):
                output = [tens.to(opt.device) for tens in output]
                output = torch.cat(output)
            D_G_z = output.mean().item()
            errD = errD_real + errD_fake
            if j == opt.Dsteps - 1 and index_transform == opt.num_transforms - 1 and (
                    epoch % 50 == 0 or epoch == (opt.niter - 1)):
                print("errD real: ", D_x, "errD fake: ", D_G_z)
            optimizerD.step()
        errD2plot.append(errD.detach())
        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        for p in netD.parameters():
            p.requires_grad = False  # to avoid computation
        for j in range(opt.Gsteps):
            netG.zero_grad()
            padded_id_noise = pad_image_id(noise, index_image)
            fake = netG(padded_id_noise, prev)
            if isinstance(fake, list):
                fake = [tens.to(opt.device) for tens in fake]
                fake = torch.cat(fake)
            # Augment the (non-detached) fakes so gradients reach G.
            fakes_arr_G = []
            for index_transform, pair in enumerate(opt.list_transformations):
                flag_color, is_flip, tx, ty, k_rotate = pair
                fake_transform_G = apply_augmentation(
                    fake, is_flip, tx, ty, k_rotate,
                    flag_color).to(opt.device)
                fake_transform_G = torch.squeeze(fake_transform_G)
                fakes_arr_G.append(fake_transform_G)
            fake_transform_G = torch.stack(fakes_arr_G)
            if opt.num_images > 1:
                fake_transform_G = fake_transform_G.reshape(
                    -1, fake_transform_G.shape[2], fake_transform_G.shape[3],
                    fake_transform_G.shape[4])
            output = netD(fake_transform_G)
            if isinstance(output, list):
                id_padding = [
                    torch.full((opt.num_images, 1, output[0].shape[2],
                                output[0].shape[3]),
                               id,
                               dtype=torch.float).to(opt.device)
                    for id in num_transforms_range
                ]
            else:
                id_padding = [
                    torch.full(
                        (opt.num_images, 1, output.shape[2], output.shape[3]),
                        id,
                        dtype=torch.float).to(opt.device)
                    for id in num_transforms_range
                ]
            errG = get_err_G_fake_and_backward(output, id_padding,
                                               opt.device_ids)
            if alpha != 0:
                # Reconstruction loss: G(z_opt) should reproduce the reals.
                loss = nn.MSELoss()
                Z_opt = m_noise(opt.noise_amp_tensor) * z_opt + z_prev
                padded_id_Z_opt = pad_image_id(Z_opt.detach(), index_image)
                negG_output = netG(padded_id_Z_opt.detach(), z_prev)
                if isinstance(output, list):
                    negG_output = [
                        tens.to(opt.device) for tens in negG_output
                    ]
                    negG_output = torch.cat(negG_output)
                rec_loss = alpha * loss(negG_output, real)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                # 0-dim tensor (was int 0), so the logging and plotting code
                # below can call .detach()/.item() without crashing.
                rec_loss = torch.zeros((), device=opt.device)
            if j == opt.Gsteps - 1 and index_transform == opt.num_transforms - 1 and (
                    epoch % 50 == 0 or epoch == (opt.niter - 1)):
                print("errG fake: ",
                      errG.detach().item(), "rec loss: ",
                      rec_loss.detach().item())
            optimizerG.step()
        errG2plot.append(errG.detach() + rec_loss)
        D_real2plot.append(D_x)
        z_opt2plot.append(rec_loss)
        if (epoch % 25 == 0 or epoch == (opt.niter - 1)):
            print('scale %d:[%d/%d]' % (len(Gs), epoch, opt.niter))
            print(" ")
        # Periodically dump per-image samples and the reconstruction.
        if index_transform == opt.num_transforms - 1 and (
                epoch % 150 == 0 or epoch == (opt.niter - 1)):
            for j, id in enumerate(opt.outf):
                plt.imsave('%s/fake_sample_epoch%d.png' % (id, epoch),
                           functions.convert_image_np(fake[j:j + 1].detach()),
                           vmin=0,
                           vmax=1)
            padded_id_Z_opt = pad_image_id(Z_opt.detach(), index_image)
            G_z_opt = netG(padded_id_Z_opt.detach(), z_prev.detach())
            if isinstance(G_z_opt, list):
                G_z_opt = [tens.to(opt.device) for tens in G_z_opt]
                G_z_opt = torch.cat(G_z_opt)
            for j, id in enumerate(opt.outf):
                plt.imsave('%s/G(z_opt)_epoch%d.png' % (id, epoch),
                           functions.convert_image_np(
                               G_z_opt[j:j + 1].detach()),
                           vmin=0,
                           vmax=1)
            torch.save(z_opt, '%s/z_opt.pth' % (opt.global_outf))
        schedulerD.step()
        schedulerG.step()
    functions.save_networks(netG, netD, z_opt, opt)
    return in_s, netG
def SinGAN_generate(Gs,
                    Zs,
                    reals,
                    NoiseAmp,
                    opt,
                    in_s=None,
                    scale_v=1,
                    scale_h=1,
                    n=0,
                    gen_start_scale=0,
                    num_samples=200):
    """Generate ``num_samples`` images by running noise through all scales.

    Scales below ``gen_start_scale`` reuse the stored reconstruction noise
    ``Z_opt`` instead of fresh noise. At the final scale the samples are
    saved to disk (except in harmonization/editing/SR/paint2image modes).
    Returns the last generated batch, detached.

    Fix vs. original: the three hardcoded ``.cuda()`` calls are replaced
    with ``.to(opt.device)``, consistent with the rest of the file and
    allowing CPU-only generation.
    """
    if in_s is None:
        in_s = torch.full(reals[0][0].shape,
                          0,
                          device=opt.device,
                          dtype=torch.long)
    images_cur = []
    index_image = range(int(opt.num_images))
    for scale_idx, (G) in enumerate(Gs):
        Z_opt = torch.cat([Zs[idx][scale_idx] for idx in index_image], dim=0)
        noise_amp = torch.cat(
            ([NoiseAmp[id][scale_idx] for id in range(opt.num_images)]),
            dim=0).to(opt.device)
        pad1 = ((opt.ker_size - 1) * opt.num_layer) / 2
        m = nn.ZeroPad2d(int(pad1))
        # Unpadded noise size, optionally scaled for larger outputs.
        nzx = (Z_opt.shape[2] - pad1 * 2) * scale_v
        nzy = (Z_opt.shape[3] - pad1 * 2) * scale_h
        images_prev = images_cur
        images_cur = []
        for i in range(0, num_samples, 1):
            if n == 0:
                # Coarsest scale: single-channel noise expanded to RGB.
                z_curr = functions.generate_noise([1, nzx, nzy],
                                                  device=opt.device)
                z_curr = z_curr.expand(1, 3, z_curr.shape[2], z_curr.shape[3])
                z_curr = m(z_curr)
            else:
                z_curr = functions.generate_noise([opt.nc_z, nzx, nzy],
                                                  device=opt.device)
                z_curr = m(z_curr)
            if images_prev == []:
                I_prev = m(in_s)
            else:
                # Upsample the previous scale's i-th sample per image.
                I_prev_temp = images_prev[i]
                I_prev = imresize(torch.unsqueeze(I_prev_temp[0], dim=0),
                                  1 / opt.scale_factor, opt)
                for id in range(1, opt.num_images):
                    I_prev = torch.cat(
                        (I_prev,
                         imresize(torch.unsqueeze(I_prev_temp[id], dim=0),
                                  1 / opt.scale_factor, opt)),
                        dim=0)
                if opt.mode != "SR":
                    I_prev = I_prev[:, :,
                                    0:round(scale_v * reals[0][n].shape[2]),
                                    0:round(scale_h * reals[0][n].shape[3])]
                    I_prev = m(I_prev)
                    I_prev = I_prev[:, :, 0:z_curr.shape[2],
                                    0:z_curr.shape[3]]
                    I_prev = functions.upsampling(I_prev, z_curr.shape[2],
                                                  z_curr.shape[3])
                else:
                    I_prev = m(I_prev)
            if n < gen_start_scale:
                # Below the start scale: reproduce the reconstruction path.
                z_curr = Z_opt
            # Broadcast each image's scalar amplitude to a full-size map.
            noise_amp_tensor = torch.full(
                [1, z_curr.shape[1], z_curr.shape[2], z_curr.shape[3]],
                noise_amp[0][0].item(),
                dtype=torch.float).to(opt.device)
            for j in range(1, opt.num_images):
                temp = torch.full(
                    [1, z_curr.shape[1], z_curr.shape[2], z_curr.shape[3]],
                    noise_amp[j][0].item(),
                    dtype=torch.float).to(opt.device)
                noise_amp_tensor = torch.cat((noise_amp_tensor, temp), dim=0)
            z_in = noise_amp_tensor * (z_curr) + I_prev
            padded_id_Z_opt = pad_image_id(z_in, index_image)
            # NOTE(review): unlike the training code, a list-valued G output
            # is not handled here — confirm single-device use at inference.
            I_curr = G(padded_id_Z_opt.detach(), I_prev)
            if n == len(Gs) - 1:
                if opt.mode == 'train':
                    dir2save = '%s/RandomSamples/%s/gen_start_scale=%d' % (
                        opt.out, opt.input_name[:-4], gen_start_scale)
                else:
                    dir2save = functions.generate_dir2save(opt)
                try:
                    os.makedirs(dir2save)
                except OSError:
                    pass
                if (opt.mode != "harmonization") & (opt.mode != "editing") & (
                        opt.mode != "SR") & (opt.mode != "paint2image"):
                    for j in range(opt.num_images):
                        plt.imsave('%s/%d%d.png' % (dir2save, i, j),
                                   functions.convert_image_np(
                                       I_curr[j].unsqueeze(dim=0).detach()),
                                   vmin=0,
                                   vmax=1)
            images_cur.append(I_curr)
        n += 1
    return I_curr.detach()