def preprocess_content_image(opt, reals, scale):
    """Build the input injected at pyramid level ``scale`` from the reference image.

    The reference image ``'%s/%s' % (opt.ref_dir, opt.ref_name)`` is first
    matched to the size of the rescaled training image. It is then shrunk by
    ``opt.scale_factor ** (N - scale + 1)`` (with ``N = len(reals) - 1``) and
    cropped to the size of ``reals[scale - 1]``. Finally it is enlarged by
    ``1 / opt.scale_factor`` and cropped to the size of ``reals[scale]``.

    Args:
        opt: run configuration. ``functions.adjust_scales2image`` is called
            with it, so scale-related attributes of ``opt`` may be updated as
            a side effect.
        reals: image pyramid; ``reals[scale - 1]`` and ``reals[scale]`` fix
            the intermediate and final crop sizes.
        scale: pyramid level to inject at; must satisfy
            ``1 <= scale <= len(reals) - 1``.

    Returns:
        The processed reference tensor, cropped to ``reals[scale]``'s size.

    Raises:
        ValueError: if ``scale`` is outside ``[1, len(reals) - 1]``.
    """
    N = len(reals) - 1
    n = scale
    # Without this check, scale == 0 (or a negative value) would silently
    # index reals[-1] etc. below and produce a wrongly sized input.
    if not 1 <= n <= N:
        raise ValueError(
            'injection scale should be between 1 and %d, got %r' % (N, n))

    real = functions.read_image(opt)
    # Keep the rescaled image that adjust_scales2image returns, as the other
    # entry points in this file do. Matching the reference to the raw,
    # un-rescaled input makes the crops below cut off part of the image.
    real = functions.adjust_scales2image(real, opt)

    ref = functions.read_image_dir('%s/%s' % (opt.ref_dir, opt.ref_name), opt)
    # Dims 2 and 3 are the spatial dims (the same axes sliced below).
    height, width = real.shape[2], real.shape[3]
    # Compare both spatial dims; checking only the width let a reference
    # with a matching width but a different height through unresized.
    if ref.shape[2] != height or ref.shape[3] != width:
        ref = imresize_to_shape(ref, [height, width], opt)
        ref = ref[:, :, :height, :width]

    # Shrink to one level coarser than the target, then upsample one step so
    # the result carries the same resampling the generator would see.
    in_s = imresize(ref, pow(opt.scale_factor, (N - n + 1)), opt)
    in_s = in_s[:, :, :reals[n - 1].shape[2], :reals[n - 1].shape[3]]
    in_s = imresize(in_s, 1 / opt.scale_factor, opt)
    in_s = in_s[:, :, :reals[n].shape[2], :reals[n].shape[3]]
    return in_s
img[:, :, k][im_mask[:, :, k] == 1] = img[:, :, k][im_mask[:, :, k] == 0].mean() cv2.imwrite( '%s/%s_global_mean%s' % (opt.input_dir, opt.ref_name[:-4], opt.ref_name[-4:]), img) ref = functions.read_image_dir( '%s/%s_global_mean%s' % (opt.input_dir, opt.ref_name[:-4], opt.ref_name[-4:]), opt) mask = functions.read_image_dir( '%s/%s_mask%s' % (opt.ref_dir, opt.ref_name[:-4], opt.ref_name[-4:]), opt) if ref.shape[3] != real.shape[3]: mask = imresize_to_shape(mask, [real.shape[2], real.shape[3]], opt) mask = mask[:, :, :real.shape[2], :real.shape[3]] ref = imresize_to_shape(ref, [real.shape[2], real.shape[3]], opt) ref = ref[:, :, :real.shape[2], :real.shape[3]] mask = functions.dilate_mask(mask, opt) N = len(reals) - 1 n = opt.inpainting_start_scale in_s = imresize(ref, pow(opt.scale_factor, (N - n + 1)), opt) in_s = in_s[:, :, :reals[n - 1].shape[2], :reals[n - 1].shape[3]] in_s = imresize(in_s, 1 / opt.scale_factor, opt) in_s = in_s[:, :, :reals[n].shape[2], :reals[n].shape[3]] out = SinGAN_generate(Gs[n:], Zs[n:], reals,
NoiseAmp = [] dir2save = functions.generate_dir2save(opt) if dir2save is None: print('task does not exist') #elif (os.path.exists(dir2save)): # print("output already exist") else: try: os.makedirs(dir2save) except OSError: pass real = functions.read_image(opt) if opt.max_size < 251: a = imresize_to_shape(real, [real.shape[2], real.shape[3]], opt) real = functions.adjust_scales2image(real, opt) Gs, Zs, reals, NoiseAmp = functions.load_trained_pyramid(opt) if (opt.inpainting_scale_start < 1) | (opt.inpainting_scale_start > (len(Gs) - 1)): print("injection scale should be between 1 and %d" % (len(Gs) - 1)) else: #Importing masked image if opt.on_drive != None: masked_img = cv2.imread( '%s/%s/%s' % (opt.on_drive, opt.input_dir, opt.input_name)) masked_img = cv2.cvtColor(masked_img, cv2.COLOR_BGR2RGB) else: masked_img = cv2.imread('%s/%s' % (opt.input_dir, opt.input_name)) masked_img = cv2.cvtColor(masked_img, cv2.COLOR_BGR2RGB)
os.makedirs(dir2save) except OSError: pass #real = functions.read_image(opt) #real = functions.adjust_scales2image(real, opt) #Gs, Zs, reals, NoiseAmp = functions.load_trained_pyramid(opt) real1, real2 = functions.read_image(opt) real1 = functions.adjust_scales2image(real1, opt) real2 = functions.adjust_scales2image(real2, opt) Gs, Zs, reals1, reals2, NoiseAmp = functions.load_trained_pyramid(opt) if (opt.paint_start_scale < 1) | (opt.paint_start_scale > (len(Gs)-1)): print("injection scale should be between 1 and %d" % (len(Gs)-1)) else: ref = functions.read_image_dir('%s/%s' % (opt.ref_dir, opt.ref_name), opt) if ref.shape[3] != real1.shape[3]: ref = imresize_to_shape(ref, [real1.shape[2], real1.shape[3]], opt) ref = ref[:, :, :real1.shape[2], :real1.shape[3]] N = len(reals1) - 1 n = opt.paint_start_scale in_s = imresize(ref, pow(opt.scale_factor, (N - n + 1)), opt) in_s = in_s[:, :, :reals1[n - 1].shape[2], :reals1[n - 1].shape[3]] in_s = imresize(in_s, 1 / opt.scale_factor, opt) in_s = in_s[:, :, :reals1[n].shape[2], :reals1[n].shape[3]] if opt.quantization_flag: opt.mode = 'paint_train' dir2trained_model = functions.generate_dir2save(opt) real_s = imresize(real, pow(opt.scale_factor, (N - n)), opt) real_s = real_s[:, :, :reals[n].shape[2], :reals[n].shape[3]] real_quant, centers = functions.quant(real_s, opt.device) plt.imsave('%s/real_quant.png' % dir2save, functions.convert_image_np(real_quant), vmin=0, vmax=1)
def train(opt, Gs, Zs, reals, NoiseAmp):
    """Train the generator pyramid level by level, coarsest first.

    The training image comes from ``functions.read_image(opt)`` and is
    expanded into a pyramid by ``functions.creat_reals_pyramid``. For every
    level after the first, the previous level is resized to the current
    level's shape. Both that upsampled image and the residual
    ``|fine - upsampled| - 1`` are kept. Level 0 has no coarser neighbour,
    so the real image stands in for both.

    Levels ``0 .. opt.stop_scale`` (inclusive) are then trained in order. For
    each level this function:

    * sets ``opt.nfc`` / ``opt.min_nfc`` (doubled every 4 levels, capped at
      128) and the output folders ``opt.out_`` / ``opt.outf``;
    * writes ``real_scale.png``, ``diff.png`` and ``upsampled.png`` previews
      into ``opt.outf``;
    * creates models with ``init_models`` and trains them with
      ``train_single_scale``;
    * calls ``functions.reset_grads(..., False)`` and ``.eval()`` on both
      networks, appends the generator, the returned ``z`` value and
      ``opt.noise_amp`` to ``Gs``, ``Zs`` and ``NoiseAmp``, then saves
      ``Zs``, ``Gs``, ``reals`` and ``NoiseAmp`` under ``opt.out_``.

    Args:
        opt: run configuration; several attributes are overwritten here.
        Gs: list extended with one trained generator per level.
        Zs: list extended with the ``z`` value returned for each level.
        reals: list handed to ``creat_reals_pyramid``; the returned pyramid
            is the one trained on and saved.
        NoiseAmp: list extended with ``opt.noise_amp`` after each level.
    """
    source_image = functions.read_image(opt)
    # in_s starts at 0 and is replaced by whatever train_single_scale hands back.
    in_s = 0
    reals = functions.creat_reals_pyramid(source_image, reals, opt)

    upsamples = [reals[0]]
    diffs = [reals[0]]
    # Levels 1 .. opt.stop_scale get an upsampled coarser image and a residual.
    for k in range(1, opt.stop_scale + 1):
        coarse, fine = reals[k - 1], reals[k]
        # Assumes (N, C, H, W) layout; imresize_to_shape receives (H, W, C).
        _, channels, height, width = fine.shape
        enlarged = imresize_to_shape(coarse, (height, width, channels), opt)
        upsamples.append(enlarged)
        # Shifted absolute residual; lies in [-1, 1] if both images are in [-1, 1].
        diffs.append((fine - enlarged).abs() - 1)

    level = 0
    # opt.stop_scale is re-read every iteration; the last level trained is stop_scale.
    while level < opt.stop_scale + 1:
        width_mult = 2 ** (level // 4)
        opt.nfc = min(opt.nfc_init * width_mult, 128)
        opt.min_nfc = min(opt.min_nfc_init * width_mult, 128)

        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, level)
        try:
            os.makedirs(opt.outf)
        except OSError:
            pass  # ignore failures such as the folder already existing

        real_preview = functions.convert_image_np(reals[level])
        plt.imsave('%s/real_scale.png' % (opt.outf), real_preview, vmin=0, vmax=1)
        diff_preview = functions.convert_image_np(diffs[level].detach())
        plt.imsave('%s/diff.png' % (opt.outf), diff_preview, vmin=0, vmax=1)
        upsampled_preview = functions.convert_image_np(upsamples[level].detach())
        plt.imsave('%s/upsampled.png' % (opt.outf), upsampled_preview, vmin=0, vmax=1)

        # Fresh models per level; architecture width follows opt.nfc set above.
        netD, netG = init_models(opt)
        z_level, in_s, netG = train_single_scale(
            netD, netG, reals, upsamples, level, in_s, NoiseAmp, opt)

        netG = functions.reset_grads(netG, False)
        netG.eval()
        netD = functions.reset_grads(netD, False)
        netD.eval()

        Gs.append(netG)
        Zs.append(z_level)
        NoiseAmp.append(opt.noise_amp)

        # Checkpoint all state trained so far after every level.
        torch.save(Zs, '%s/Zs.pth' % (opt.out_))
        torch.save(Gs, '%s/Gs.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(NoiseAmp, '%s/NoiseAmp.pth' % (opt.out_))

        level += 1
        del netD, netG
        torch.cuda.empty_cache()
    return