def SinGAN_generate(Gs, Zs, reals, NoiseAmp, opt, in_s=None, scale_v=1,
                    scale_h=1, n=0, gen_start_scale=0, num_samples=50):
    """Draw random samples from a trained SinGAN pyramid, coarse to fine.

    At every scale ``num_samples`` images are generated; sample ``i`` at a
    scale is conditioned on (a resized copy of) sample ``i`` from the previous
    scale. Unless ``opt.mode`` is "harmonization", "editing", "SR" or
    "paint2image", every generated sample is written to ``<dir>/<i>.png``
    (so samples of later scales overwrite those of earlier scales).

    Args:
        Gs, Zs, NoiseAmp: per-scale generators, reconstruction noise maps and
            noise amplitudes, iterated in lockstep.
        reals: real-image pyramid; ``reals[0]`` sets the shape of the default
            start image and ``reals[n]`` the crop size at scale ``n``.
        opt: options namespace (ker_size, num_layer, nc_z, device,
            scale_factor, mode, out, input_name, ...).
        in_s: start image for the coarsest scale; zeros when ``None``.
        scale_v, scale_h: vertical / horizontal output scaling.
        n: index of the first scale processed (incremented per generator).
        gen_start_scale: scales with index below this reuse ``Z_opt``
            instead of fresh noise.
        num_samples: number of samples per scale.

    Returns:
        The last sample of the finest scale, detached from the graph.

    Raises:
        ValueError: if ``Gs`` is empty or ``num_samples`` < 1, since no sample
            would exist to return.
    """
    if len(Gs) == 0 or num_samples < 1:
        raise ValueError(
            "SinGAN_generate needs at least one generator and one sample")

    if in_s is None:
        # Explicit floating-point zeros: torch.full(shape, 0) infers int64
        # on recent PyTorch versions.
        in_s = torch.zeros(reals[0].shape, device=opt.device)

    # The output directory depends only on opt and gen_start_scale, so it is
    # resolved and created once instead of once per sample and scale.
    if opt.mode == 'train':
        dir2save = '%s/RandomSamples/%s/gen_start_scale=%d' % (
            opt.out, opt.input_name[:-4], gen_start_scale)
    else:
        dir2save = functions.generate_dir2save(opt)
    os.makedirs(dir2save, exist_ok=True)
    save_samples = opt.mode not in ("harmonization", "editing", "SR",
                                    "paint2image")

    pad1 = ((opt.ker_size - 1) * opt.num_layer) / 2
    m = nn.ZeroPad2d(int(pad1))

    images_cur = []
    for G, Z_opt, noise_amp in zip(Gs, Zs, NoiseAmp):
        nzx = (Z_opt.shape[2] - pad1 * 2) * scale_v
        nzy = (Z_opt.shape[3] - pad1 * 2) * scale_h
        images_prev = images_cur
        images_cur = []
        for i in range(num_samples):
            if n == 0:
                # Coarsest scale: one noise channel shared by all 3 channels.
                z_curr = functions.generate_noise([1, nzx, nzy],
                                                  device=opt.device)
                z_curr = z_curr.expand(1, 3, z_curr.shape[2],
                                       z_curr.shape[3])
            else:
                z_curr = functions.generate_noise([opt.nc_z, nzx, nzy],
                                                  device=opt.device)
            z_curr = m(z_curr)

            if not images_prev:
                # First processed scale: start from in_s.
                I_prev = m(in_s)
            else:
                # Resize the matching sample of the previous scale.
                I_prev = imresize(images_prev[i], 1 / opt.scale_factor, opt)
                if opt.mode != "SR":
                    I_prev = I_prev[:, :,
                                    0:round(scale_v * reals[n].shape[2]),
                                    0:round(scale_h * reals[n].shape[3])]
                    I_prev = m(I_prev)
                    I_prev = I_prev[:, :, 0:z_curr.shape[2],
                                    0:z_curr.shape[3]]
                    I_prev = functions.upsampling(I_prev, z_curr.shape[2],
                                                  z_curr.shape[3])
                else:
                    I_prev = m(I_prev)

            if n < gen_start_scale:
                z_curr = Z_opt

            z_in = noise_amp * z_curr + I_prev
            I_curr = G(z_in.detach(), I_prev)

            if save_samples:
                plt.imsave('%s/%d.png' % (dir2save, i),
                           functions.convert_image_np(I_curr.detach()),
                           vmin=0, vmax=1)
            images_cur.append(I_curr)
        n += 1
    return I_curr.detach()
def draw_concat(Gs, Zs, reals, NoiseAmp, in_s, mode, m_noise, m_image, opt):
    """Pass ``in_s`` through every already-trained generator in ``Gs``.

    Args:
        Gs, Zs, reals, NoiseAmp: per-scale generators, reconstruction noise
            maps, real-image pyramid and noise amplitudes.
        in_s: image fed into the coarsest generator.
        mode: 'rand' feeds freshly sampled noise at every scale, 'rec' feeds
            the stored maps ``Zs``. Any other value, or an empty ``Gs``,
            returns ``in_s`` unchanged.
        m_noise, m_image: padding modules applied to noise / image.
        opt: options namespace.

    Returns:
        The last generator's output resized by ``1/opt.scale_factor`` and
        cropped to the spatial size of the next real image.
    """
    if len(Gs) == 0:
        return in_s

    def advance(G, z, noise_amp, image, real_curr, real_next):
        # One pyramid step: crop to the current real size, pad, inject the
        # scaled noise, generate, resize, then crop to the next real size.
        image = m_image(image[:, :, :real_curr.shape[2], :real_curr.shape[3]])
        image = G((noise_amp * z + image).detach(), image)
        image = imresize(image, 1 / opt.scale_factor, opt)
        return image[:, :, :real_next.shape[2], :real_next.shape[3]]

    G_z = in_s
    levels = zip(Gs, Zs, reals, reals[1:], NoiseAmp)
    if mode == 'rand':
        pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
        if opt.mode == 'animation_train':
            pad_noise = 0
        for level, (G, Z_opt, real_curr, real_next, noise_amp) in enumerate(levels):
            size = [Z_opt.shape[2] - 2 * pad_noise, Z_opt.shape[3] - 2 * pad_noise]
            if level == 0:
                # coarsest scale: one noise channel repeated over 3 channels
                z = functions.generate_noise([1] + size, device=opt.device)
                z = z.expand(1, 3, z.shape[2], z.shape[3])
            else:
                z = functions.generate_noise([opt.nc_z] + size, device=opt.device)
            G_z = advance(G, m_noise(z), noise_amp, G_z, real_curr, real_next)
    elif mode == 'rec':
        for G, Z_opt, real_curr, real_next, noise_amp in levels:
            G_z = advance(G, Z_opt, noise_amp, G_z, real_curr, real_next)
    return G_z
def draw_concat(Gs, Zs, reals, NoiseAmp, in_s, mode, m_noise, m_image, opt):
    """TensorFlow port: pass ``in_s`` through the already-trained ``Gs``.

    Tensors here are indexed channels-last (spatial axes 1 and 2, channels in
    the last axis), per the original "PY: NCWH, TF: NWHC" note.
    mode='rand' samples fresh noise per scale and mode='rec' uses the stored
    maps ``Zs``. Any other mode, or an empty ``Gs``, returns ``in_s``
    unchanged.

    Returns:
        The last generator's output resized by ``1/opt.scale_factor`` and
        cropped to the spatial size of the next real image.
    """
    if len(Gs) == 0:
        return in_s

    def crop(t, ref):
        # Crop spatial axes 1 and 2 of ``t`` to those of ``ref``.
        return t[:, 0:ref.shape[1], 0:ref.shape[2], :]

    out = in_s
    levels = zip(Gs, Zs, reals, reals[1:], NoiseAmp)
    # NOTE(review): the generators of coarser scales are called with
    # training=True; confirm this is intended (it matters for layers such
    # as BatchNormalization).
    if mode == 'rand':
        pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
        if opt.mode == 'animation_train':
            pad_noise = 0
        for k, (G, Z_opt, real_curr, real_next, noise_amp) in enumerate(levels):
            height = Z_opt.shape[1] - 2 * pad_noise
            width = Z_opt.shape[2] - 2 * pad_noise
            if k == 0:
                z = functions.generate_noise([1, height, width])
                z = tf.broadcast_to(z, [1, z.shape[1], z.shape[2], 3])
            else:
                z = functions.generate_noise([opt.nc_z, height, width])
            z = m_noise(z)
            out = m_image(crop(out, real_curr))
            out = G(noise_amp * z + out, out, training=True)
            out = crop(imresize(out, 1 / opt.scale_factor, opt), real_next)
    elif mode == 'rec':
        for G, Z_opt, real_curr, real_next, noise_amp in levels:
            out = m_image(crop(out, real_curr))
            out = G(noise_amp * Z_opt + out, out, training=True)
            out = crop(imresize(out, 1 / opt.scale_factor, opt), real_next)
    return out
def draw_concat(Gs, Zs, reals, NoiseAmp, in_s, mode, m_noise, m_image, opt):
    """Generate an image through all previously trained (coarser) generators.

    Args:
        Gs, Zs, reals, NoiseAmp: per-scale generators, reconstruction noise
            maps, real-image pyramid and noise amplitudes.
        in_s: image fed into the coarsest generator.
        mode: 'rand' samples new noise at each scale; 'rec' uses ``Zs``.
            Any other value leaves ``in_s`` unchanged.
        m_noise, m_image: padding modules for noise and image.
        opt: options (ker_size, num_layer, nc_z, device, scale_factor, mode).

    Returns:
        The last generator's output resized by ``1/opt.scale_factor`` and
        cropped to the next real image's spatial size; ``in_s`` when ``Gs``
        is empty.
    """
    G_z = in_s
    # at the first scale there are no previous generators: do nothing
    if len(Gs) > 0:
        # random mode: fresh noise at every scale
        if mode == 'rand':
            count = 0
            pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
            # animation training builds its noise maps with pad_noise = 0,
            # so no padding must be subtracted from Z_opt's size
            if opt.mode == 'animation_train':
                pad_noise = 0
            # for each scale
            for G, Z_opt, real_curr, real_next, noise_amp in zip(
                    Gs, Zs, reals, reals[1:], NoiseAmp):
                if count == 0:
                    # coarsest scale: a single noise channel ...
                    z = functions.generate_noise([
                        1, Z_opt.shape[2] - 2 * pad_noise,
                        Z_opt.shape[3] - 2 * pad_noise
                    ], device=opt.device)
                    # ... broadcast to the 3 colour channels
                    z = z.expand(1, 3, z.shape[2], z.shape[3])
                else:
                    # later scales: noise with opt.nc_z channels
                    z = functions.generate_noise([
                        opt.nc_z, Z_opt.shape[2] - 2 * pad_noise,
                        Z_opt.shape[3] - 2 * pad_noise
                    ], device=opt.device)
                # pad the noise
                z = m_noise(z)
                # crop G_z to the current real image's spatial size, then pad
                G_z = G_z[:, :, 0:real_curr.shape[2], 0:real_curr.shape[3]]
                G_z = m_image(G_z)
                # scale the noise by this scale's amplitude and add the image
                z_in = noise_amp * z + G_z
                # run this scale's generator on the detached input
                G_z = G(z_in.detach(), G_z)
                # resize by 1/opt.scale_factor (upsampling when
                # scale_factor < 1)
                G_z = imresize(G_z, 1 / opt.scale_factor, opt)
                # crop to the next real image's spatial size
                G_z = G_z[:, :, 0:real_next.shape[2], 0:real_next.shape[3]]
                count += 1
        if mode == 'rec':
            count = 0
            # for each scale
            for G, Z_opt, real_curr, real_next, noise_amp in zip(
                    Gs, Zs, reals, reals[1:], NoiseAmp):
                # same steps as above, but with the stored Z_opt as noise
                G_z = G_z[:, :, 0:real_curr.shape[2], 0:real_curr.shape[3]]
                G_z = m_image(G_z)
                z_in = noise_amp * Z_opt + G_z
                G_z = G(z_in.detach(), G_z)
                G_z = imresize(G_z, 1 / opt.scale_factor, opt)
                G_z = G_z[:, :, 0:real_next.shape[2], 0:real_next.shape[3]]
                count += 1
    return G_z
def draw_concat(Gs, Zs, reals, NoiseAmp, in_s, mode, m_noise, m_image, opt):
    """Generate through all coarser, already-trained generators.

    Starting from ``in_s``, each generator in ``Gs`` is applied in turn. Its
    output is resized by ``1/opt.scale_factor`` and cropped to the next real
    image's size. The noise injected at each scale depends on ``mode``:
    freshly sampled for 'rand', the stored reconstruction map ``Z_opt`` for
    'rec'. With no generators yet (first pyramid level), or any other mode,
    ``in_s`` is returned as is.
    """
    if len(Gs) == 0 or mode not in ('rand', 'rec'):
        return in_s

    use_random = mode == 'rand'
    if use_random:
        pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
        if opt.mode == 'animation_train':
            pad_noise = 0

    current = in_s
    scales = zip(Gs, Zs, reals, reals[1:], NoiseAmp)
    for idx, (G, Z_opt, real_curr, real_next, noise_amp) in enumerate(scales):
        if not use_random:
            # reconstruction path: stored noise map of this scale
            z = Z_opt
        else:
            rows = Z_opt.shape[2] - 2 * pad_noise
            cols = Z_opt.shape[3] - 2 * pad_noise
            if idx == 0:
                # one noise channel repeated over the 3 colour channels
                z = functions.generate_noise([1, rows, cols], device=opt.device)
                z = z.expand(1, 3, z.shape[2], z.shape[3])
            else:
                # independent noise for each of the opt.nc_z channels
                z = functions.generate_noise([opt.nc_z, rows, cols],
                                             device=opt.device)
            z = m_noise(z)
        # keep the image aligned with the real pyramid, then pad it
        current = current[:, :, 0:real_curr.shape[2], 0:real_curr.shape[3]]
        current = m_image(current)
        z_in = noise_amp * z + current
        current = G(z_in.detach(), current)
        # resize towards the next level and crop to its real size
        current = imresize(current, 1 / opt.scale_factor, opt)
        current = current[:, :, 0:real_next.shape[2], 0:real_next.shape[3]]
    return current
def _create_noise_for_draw_concat(opt, count, pad_noise, m_noise, Z_opt, noise_mode):
    """Sample one padded noise map for a draw_concat step.

    The spatial size is ``Z_opt``'s size minus ``pad_noise`` on each side. At
    the first step (``count == 0``) a single channel is sampled and repeated
    over 3 channels; otherwise ``opt.nc_z`` channels are sampled. The result
    is padded with ``m_noise``.
    """
    height = Z_opt.shape[2] - 2 * pad_noise
    width = Z_opt.shape[3] - 2 * pad_noise
    first = count == 0
    channels = 1 if first else opt.nc_z
    z = functions.generate_noise(
        [channels, height, width],
        device=opt.device,
        noise_mode=noise_mode,
        gaussian_noise_z_distance=opt.gaussian_noise_z_distance)
    if first:
        z = z.expand(1, 3, z.shape[2], z.shape[3])
    return m_noise(z)
def compute_z_diff(n, Z_opt, z_prev1, z_prev2, beta, device):
    """Compute z_diff_n(t+1) = beta * (z_prev1 - z_prev2) + (1 - beta) * z_rand.

    ``z_rand`` is fresh noise with the same shape as ``Z_opt``. At the
    coarsest scale (``n == 0``) a single noise channel is sampled and
    repeated over all channels of ``Z_opt``.

    Args:
        n: scale index (0 = coarsest).
        Z_opt: noise map of the scale, 4-D (N, C, H, W); only its shape is
            read.
        z_prev1, z_prev2: noise at the two previous time steps.
        beta: weight of the previous difference (between 0 and 1 in
            practice -- TODO confirm with callers).
        device: torch device for the sampled noise.

    Returns:
        The noise difference for the next time step.
    """
    # Channel count taken from Z_opt instead of a hard-coded 3 so the sum
    # broadcasts for any number of noise channels.
    nc_z, nzx, nzy = Z_opt.shape[1], Z_opt.shape[2], Z_opt.shape[3]
    if n == 0:
        # make z_rand identical across channels
        z_rand = functions.generate_noise([1, nzx, nzy], device=device)
        z_rand = z_rand.expand(1, nc_z, nzx, nzy)
    else:
        z_rand = functions.generate_noise([nc_z, nzx, nzy], device=device)
    return beta * (z_prev1 - z_prev2) + (1 - beta) * z_rand
def _create_noise_for_iteration(is_first_scale, m_noise, opt, default_z_opt, noise_mode):
    """Sample the per-epoch noise for training one scale.

    Returns ``(noise_, z_opt)``. At the first scale both maps are freshly
    sampled as a single channel repeated over 3 channels, and ``z_opt`` is
    drawn before ``noise_``. Otherwise ``z_opt`` is ``default_z_opt`` and
    ``noise_`` has ``opt.nc_z`` channels. Every freshly sampled map is padded
    with ``m_noise``.
    """
    def sample(channels):
        return functions.generate_noise(
            [channels, opt.nzx, opt.nzy],
            device=opt.device,
            noise_mode=noise_mode,
            gaussian_noise_z_distance=opt.gaussian_noise_z_distance)

    if not is_first_scale:
        return m_noise(sample(opt.nc_z)), default_z_opt

    z_opt = m_noise(sample(1).expand(1, 3, opt.nzx, opt.nzy))
    noise_ = m_noise(sample(1).expand(1, 3, opt.nzx, opt.nzy))
    return noise_, z_opt
def _generate_noise_for_sampling(m, n, nzx, nzy, opt, noise_mode):
    """Sample a padded noise map of spatial size (nzx, nzy) for scale ``n``.

    The coarsest scale (``n == 0``) uses a single channel repeated over 3
    channels; other scales use ``opt.nc_z`` channels. The map is padded
    with ``m``.
    """
    coarsest = n == 0
    z_curr = functions.generate_noise(
        [1 if coarsest else opt.nc_z, nzx, nzy],
        device=opt.device,
        noise_mode=noise_mode,
        gaussian_noise_z_distance=opt.gaussian_noise_z_distance)
    if coarsest:
        z_curr = z_curr.expand(1, 3, z_curr.shape[2], z_curr.shape[3])
    return m(z_curr)
def compute_z_prev(n, Z_opt, device):
    """Compute z_n at the previous time steps, i.e. z_n(t) and z_n(t-1).

    :param:
        n -- int, scale level (0 = first generator, i.e. coarsest level)
        Z_opt -- noise map of the n-th scale, 4-D (N, C, H, W)
        device -- torch.device, CUDA / CPU
    :return: (z_prev1, z_prev2) where z_prev2 is ``Z_opt`` itself and
        z_prev1 = 0.95 * Z_opt + 0.05 * fresh noise shaped like Z_opt. At
        n == 0 the fresh noise is one channel repeated over all channels.
    """
    # Channel count taken from Z_opt instead of a hard-coded 3 so the mix
    # broadcasts for any number of noise channels.
    nc_z, nzx, nzy = Z_opt.shape[1], Z_opt.shape[2], Z_opt.shape[3]
    if n == 0:
        # gaussian noise, identical across channels
        z_rand = functions.generate_noise([1, nzx, nzy], device=device)
        z_rand = z_rand.expand(1, nc_z, nzx, nzy)
    else:
        z_rand = functions.generate_noise([nc_z, nzx, nzy], device=device)
    z_prev1 = 0.95 * Z_opt + 0.05 * z_rand
    z_prev2 = Z_opt
    return z_prev1, z_prev2
def SinGAN_anchor_generate(Gs, Zs, reals, NoiseAmp, opt, in_s=None, scale_v=1,
                           scale_h=1, n=0, gen_start_scale=0, num_samples=1,
                           anchor_image=None, direction=None, transfer=None,
                           noise_solutions=None, factor=None, base=None,
                           insert_limit=0):
    """Generate samples coarse-to-fine, optionally steered by anchor images.

    Args:
        Gs, Zs, NoiseAmp: per-scale generators, reconstruction noise maps and
            noise amplitudes.
        reals: real-image pyramid; ``reals[0]`` sets the default start shape.
        opt: options namespace (reads scale1, stop_scale, scale_factor, ...).
        in_s: coarsest input image; zeros when ``None``.
        scale_v, scale_h: vertical / horizontal output scaling.
        n: index of the first scale processed.
        gen_start_scale: scales below this index use ``Z_opt`` as noise.
        num_samples: samples per scale.
        anchor_image, direction: numpy images. When both are given, each
            scale's output is passed through ``reinforcement`` with the
            matching pyramid levels, and a second time at ``opt.stop_scale``.
        transfer: unused.
        noise_solutions: optional per-scale noise maps. When given, the
            generator input mixes ``noise_solutions[n]`` and the freshly
            sampled noise with weights ``1 - factor`` and ``factor``.
        factor: blend weight used with ``noise_solutions`` and ``base``.
        base: optional numpy image blended into the output at scale
            ``insert_limit`` with weight ``factor``.
        insert_limit: scale index at which ``base`` is blended in.

    Returns:
        The last generated sample, converted with
        ``functions.convert_image_np``.

    Raises:
        ValueError: if nothing would be generated, or if ``noise_solutions``
            is given without ``factor``.
    """
    if len(Gs) == 0 or num_samples < 1:
        raise ValueError(
            "SinGAN_anchor_generate needs at least one generator and one sample")
    if noise_solutions is not None and factor is None:
        raise ValueError("factor is required when noise_solutions is given")

    def to_pyramid(image):
        # numpy image -> tensor -> rescale by opt.scale1 -> per-scale pyramid
        tensor = functions.np2torch(image, opt)
        return functions.creat_reals_pyramid(
            imresize(tensor, opt.scale1, opt), [], opt)

    anchors = to_pyramid(anchor_image) if anchor_image is not None else None
    directions = to_pyramid(direction) if direction is not None else None
    bases = to_pyramid(base) if base is not None else None
    steer = anchors is not None and directions is not None

    if in_s is None:
        # explicit float zeros (torch.full(shape, 0) infers int64)
        in_s = torch.zeros(reals[0].shape, device=opt.device)

    pad1 = ((opt.ker_size - 1) * opt.num_layer) / 2
    m = nn.ZeroPad2d(int(pad1))

    images_cur = []
    for G, Z_opt, noise_amp in zip(Gs, Zs, NoiseAmp):
        nzx = (Z_opt.shape[2] - pad1 * 2) * scale_v
        nzy = (Z_opt.shape[3] - pad1 * 2) * scale_h
        images_prev = images_cur
        images_cur = []
        for i in range(num_samples):
            if n == 0:
                # coarsest scale: one noise channel repeated over 3 channels
                z_curr = functions.generate_noise([1, nzx, nzy],
                                                  device=opt.device)
                z_curr = z_curr.expand(1, 3, z_curr.shape[2],
                                       z_curr.shape[3])
            else:
                z_curr = functions.generate_noise([opt.nc_z, nzx, nzy],
                                                  device=opt.device)
            z_curr = m(z_curr)
            z_orig = z_curr  # freshly sampled noise, kept for blending

            if not images_prev:
                # first processed scale: start from in_s
                I_prev = m(in_s)
            else:
                # resize the matching sample of the previous scale
                I_prev = imresize(images_prev[i], 1 / opt.scale_factor, opt)
                if opt.mode != "SR":
                    I_prev = I_prev[:, :,
                                    0:round(scale_v * reals[n].shape[2]),
                                    0:round(scale_h * reals[n].shape[3])]
                    I_prev = m(I_prev)
                    I_prev = I_prev[:, :, 0:z_curr.shape[2],
                                    0:z_curr.shape[3]]
                    # make it fit the padded noise
                    I_prev = functions.upsampling(I_prev, z_curr.shape[2],
                                                  z_curr.shape[3])
                else:
                    I_prev = m(I_prev)

            if n < gen_start_scale:
                z_curr = Z_opt  # Z_opt comes from the trained pyramid

            if noise_solutions is not None:
                z_in = ((1 - factor) * noise_amp * noise_solutions[n] + I_prev
                        + factor * noise_amp * z_orig)
            else:
                z_in = noise_amp * z_curr + I_prev
            I_curr = G(z_in.detach(), I_prev)

            if bases is not None and n == insert_limit:
                I_curr = bases[n] * factor + I_curr * (1 - factor)

            if steer:
                I_curr = reinforcement(anchors[n], I_curr, directions[n])
                if n == opt.stop_scale:
                    # Enforce the anchor once more at the finest scale. Use
                    # the pyramid level (not the full-resolution direction
                    # tensor) so its size matches anchors[n] even when
                    # opt.scale1 != 1.
                    I_curr = reinforcement(anchors[n], I_curr, directions[n])

            images_cur.append(I_curr)
        n += 1
    return functions.convert_image_np(I_curr.detach())
def SinGAN_generate(Gs, Zs, reals, NoiseAmp, opt, in_s=None, scale_v=1,
                    scale_h=1, n=0, gen_start_scale=0, num_samples=50):
    """Generate random samples after re-reading the real image pyramid.

    The input image is read again (``functions.read_image``), resized to the
    finest level of ``reals`` and turned into a new pyramid. Each of its
    levels is resized to the shape of the matching ``reals`` entry and
    replaces it. Every level is saved as ``real_<i>.png`` under
    ``<opt.out>/RandomSamples/<name>/gen_start_scale=<k>``.

    Sampling then proceeds coarse-to-fine. If ``opt.skip`` names a scale
    index, that scale's generator is bypassed and its resized input is
    passed on unchanged. Samples of the finest scale are written to
    ``<dir>/<i>.png`` unless ``opt.mode`` is "harmonization", "editing", "SR"
    or "paint2image".

    Returns:
        The last sample of the finest scale, detached.

    Raises:
        ValueError: if ``Gs`` is empty or ``num_samples`` < 1.
    """
    if len(Gs) == 0 or num_samples < 1:
        raise ValueError(
            "SinGAN_generate needs at least one generator and one sample")

    real_dir = '%s/RandomSamples/%s/gen_start_scale=%d' % (
        opt.out, opt.input_name[:-4], gen_start_scale)

    # .cpu() before .numpy(): numpy conversion only works for CPU tensors,
    # and the image / pyramid tensors may be on a CUDA device.
    real = functions.read_image(opt).detach().cpu().numpy()
    real = torch.from_numpy(resize(real, reals[-1].shape))
    new_reals = creat_reals_pyramid(real, [], opt)
    reals = [
        torch.from_numpy(resize(new_real.detach().cpu().numpy(), old.shape))
        for new_real, old in zip(new_reals, reals)
    ]

    # The directory must exist before the real pyramid is saved into it.
    os.makedirs(real_dir, exist_ok=True)
    for i, real_img in enumerate(reals):
        plt.imsave('%s/%s_%d.png' % (real_dir, "real", i),
                   functions.convert_image_np(real_img.detach()),
                   vmin=0, vmax=1)

    if in_s is None:
        # float32 zeros: only the shape of reals[0] is wanted here
        in_s = torch.zeros(reals[0].shape, device=opt.device)

    save_samples = opt.mode not in ("harmonization", "editing", "SR",
                                    "paint2image")
    skip_scale = int(opt.skip) if opt.skip != '' else None
    last_scale = len(reals) - 1
    pad1 = ((opt.ker_size - 1) * opt.num_layer) / 2
    m = nn.ZeroPad2d(int(pad1))

    images_cur = []
    for G, Z_opt, noise_amp in zip(Gs, Zs, NoiseAmp):
        nzx = (Z_opt.shape[2] - pad1 * 2) * scale_v
        nzy = (Z_opt.shape[3] - pad1 * 2) * scale_h
        images_prev = images_cur
        images_cur = []

        if n == last_scale:
            # only finest-scale samples are written to disk
            if opt.mode == 'train':
                dir2save = real_dir
            else:
                dir2save = functions.generate_dir2save(opt)
            os.makedirs(dir2save, exist_ok=True)

        for i in range(num_samples):
            if n == 0:
                z_curr = functions.generate_noise([1, nzx, nzy],
                                                  device=opt.device)
                z_curr = z_curr.expand(1, 3, z_curr.shape[2],
                                       z_curr.shape[3])
            else:
                z_curr = functions.generate_noise([opt.nc_z, nzx, nzy],
                                                  device=opt.device)
            z_curr = m(z_curr)

            if not images_prev:
                I_prev = m(in_s)
            else:
                I_prev = imresize(images_prev[i], 1 / opt.scale_factor, opt)
                if opt.mode != "SR":
                    I_prev = I_prev[:, :,
                                    0:round(scale_v * reals[n].shape[2]),
                                    0:round(scale_h * reals[n].shape[3])]
                    I_prev = m(I_prev)
                    I_prev = I_prev[:, :, 0:z_curr.shape[2],
                                    0:z_curr.shape[3]]
                    I_prev = functions.upsampling(I_prev, z_curr.shape[2],
                                                  z_curr.shape[3])
                else:
                    I_prev = m(I_prev)

            if n < gen_start_scale:
                z_curr = Z_opt

            z_in = noise_amp * z_curr + I_prev
            if n == skip_scale:
                # bypass this scale's generator
                I_curr = I_prev
            else:
                I_curr = G(z_in.detach(), I_prev)

            if n == last_scale and save_samples:
                plt.imsave('%s/%d.png' % (dir2save, i),
                           functions.convert_image_np(I_curr.detach()),
                           vmin=0, vmax=1)
            images_cur.append(I_curr)
        n += 1
    return I_curr.detach()
def train_single_scale(
        netD,  # discriminator trained at this scale
        netG,  # generator trained at this scale
        reals,  # real-image pyramid; reals[len(Gs)] is this scale's target
        Gs,  # generators of the already trained (coarser) scales
        Zs,  # reconstruction noise maps of those scales
        in_s,  # coarsest-scale input; replaced by zeros at the first scale
        NoiseAmp,  # noise amplitudes of those scales
        opt,  # options namespace; several attributes are written below
        centers=None):  # not used in this variant
    """Train the generator/discriminator pair of a single pyramid scale.

    The discriminator loss is -mean(D(real)) + mean(D(fake)) plus the
    gradient penalty from ``functions.calc_gradient_penalty``. The generator
    loss is -mean(D(fake)) plus, when ``opt.alpha != 0``, ``alpha`` times the
    MSE between the reconstruction ``netG(opt.noise_amp * z_opt + z_prev,
    z_prev)`` and the real image.

    Side effects: writes ``opt.nzx``, ``opt.nzy``, ``opt.receptive_field`` and
    ``opt.noise_amp``. After the last epoch it saves ``fake_sample.png``,
    ``G(z_opt).png`` and ``z_opt.pth`` to ``opt.outf``. Finally it calls
    ``functions.save_networks``.

    Returns:
        ``(z_opt, in_s, netG)``: the padded reconstruction noise of this
        scale, the coarsest input image and the trained generator.
    """
    real = reals[len(Gs)]  # real image of the scale being trained
    # spatial size of this scale, stored on opt and used to size the noise
    opt.nzx = real.shape[2]
    opt.nzy = real.shape[3]
    # receptive field of the conv stack (stored on opt, not used below)
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) * (opt.num_layer - 1)) * opt.stride
    # padding per side: (ker_size - 1) * num_layer / 2
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    # zero-padding modules for noise maps and images
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))
    # weight of the reconstruction loss
    alpha = opt.alpha
    # only the shape of this noise is used below (to size z_opt)
    fixed_noise = functions.generate_noise(
        [
            opt.nc_z,  # number of noise channels
            opt.nzx,
            opt.nzy
        ],
        device=opt.device)
    # reconstruction noise: zeros, except at the first scale where it is
    # resampled every epoch inside the loop.
    # NOTE(review): torch.full with the integer 0 infers an int64 tensor on
    # recent PyTorch versions -- confirm this dtype is acceptable.
    z_opt = torch.full(fixed_noise.shape, 0, device=opt.device)
    z_opt = m_noise(z_opt)  # pad to the generator's input size
    # Adam optimizers; MultiStepLR multiplies the learning rate by opt.gamma
    # after 1600 scheduler steps
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD, milestones=[1600], gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG, milestones=[1600], gamma=opt.gamma)
    # loss / score histories (collected here, not returned)
    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []
    # one iteration per epoch, with a tqdm progress bar
    for epoch in tqdm_notebook(range(opt.niter), desc=f"scale {len(Gs)}", leave=False):
        if (Gs == []):
            # first scale: z_opt and noise_ are one noise channel repeated
            # over 3 channels, resampled every epoch, then padded
            z_opt = functions.generate_noise([1, opt.nzx, opt.nzy], device=opt.device)
            z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy], device=opt.device)
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            # later scales: opt.nc_z independent noise channels, padded
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy], device=opt.device)
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) - D(G(z)), i.e. minimize
        #     -D(x) + D(G(z)) + gradient penalty
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            netD.zero_grad()
            output = netD(real).to(opt.device)
            errD_real = -output.mean()  # -D(x)
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()  # mean critic score of the real image

            # train with fake
            # prev / z_prev are initialised once, at the very first D step
            if (j == 0) & (epoch == 0):
                if (Gs == []):
                    # first scale: zero previous image and zero z_prev
                    prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=opt.device)
                    in_s = prev
                    prev = m_image(prev)
                    z_prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=opt.device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                else:
                    # random sample of the coarser scales
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                    prev = m_image(prev)
                    # reconstruction of the coarser scales (kept fixed for
                    # the rest of training at this scale)
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec', m_noise, m_image, opt)
                    # noise amplitude = noise_amp_init * RMSE(real, reconstruction)
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                # fresh random sample of the coarser scales
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                prev = m_image(prev)

            if (Gs == []):
                # first scale: generator input is the noise alone
                noise = noise_
            else:
                # scaled noise added to the upsampled coarser sample
                noise = opt.noise_amp * noise_ + prev

            # detach() keeps gradients from flowing back through the input
            fake = netG(noise.detach(), prev)
            output = netD(fake.detach())
            errD_fake = output.mean()  # +D(G(z))
            errD_fake.backward(retain_graph=True)
            D_G_z = output.mean().item()  # mean critic score of the fake

            gradient_penalty = functions.calc_gradient_penalty(
                netD, real, fake, opt.lambda_grad, opt.device)
            gradient_penalty.backward()

            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()

        errD2plot.append(errD.detach())

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        for j in range(opt.Gsteps):
            netG.zero_grad()
            output = netD(fake)
            # adversarial loss: -D(G(z))
            errG = -output.mean()
            errG.backward(retain_graph=True)
            if alpha != 0:
                # reconstruction loss: alpha * MSE(G(noise_amp*z_opt + z_prev), real)
                loss = nn.MSELoss()
                Z_opt = opt.noise_amp * z_opt + z_prev
                rec_loss = alpha * loss(netG(Z_opt.detach(), z_prev), real)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                # no reconstruction loss
                Z_opt = z_opt
                rec_loss = 0

            optimizerG.step()

        errG2plot.append(errG.detach() + rec_loss)
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        # progress is shown by tqdm; images are saved after the last epoch only
        if epoch == (opt.niter - 1):
            # last fake sample
            plt.imsave(f'{opt.outf}/fake_sample.png', functions.convert_image_np(fake.detach()), vmin=0, vmax=1)
            # reconstruction from z_opt
            plt.imsave(f'{opt.outf}/G(z_opt).png', functions.convert_image_np(
                netG(Z_opt.detach(), z_prev).detach()), vmin=0, vmax=1)
            # reconstruction noise of this scale
            torch.save(z_opt, f'{opt.outf}/z_opt.pth')

        # advance the learning-rate schedules once per epoch
        schedulerD.step()
        schedulerG.step()

    # persist netG, netD and z_opt
    functions.save_networks(netG, netD, z_opt, opt)
    # z_opt: reconstruction noise; in_s: coarsest input; netG: trained generator
    return z_opt, in_s, netG
def train_single_scale(netD, netG, reals, crops, masks, eye_color, Gs, Zs,
                       in_s, NoiseAmp, opt, centers=None):
    """Train one scale of the mask/eye-conditioned SinGAN variant.

    The generator input is the noise stacked with a mask and a coloured eye
    map (``functions.make_input``). The generator output is composited onto
    the real image by ``functions.gen_fake``, and the result is scored by
    ``netD``. Losses: -mean(D(real)) + mean(D(fake)) + gradient penalty for D;
    -mean(D(fake)) + alpha * reconstruction MSE for G.

    Args:
        netD, netG: discriminator / generator of the current scale.
        reals, crops, masks: per-scale real images, crops (only their size is
            read) and masks.
        eye_color: colour used to tint the eye map (divided by 255 per
            channel, so presumably 0-255 RGB). Re-drawn from the real image
            every epoch when ``opt.random_eye_color`` is set.
        Gs, Zs, NoiseAmp: generators, reconstruction noise and noise
            amplitudes of the already trained scales.
        in_s: coarsest input; replaced by zeros at the first scale.
        opt: options namespace. ``nzx``, ``nzy``, ``receptive_field`` and
            ``noise_amp`` are written to it.
        centers: colour centres used when ``opt.mode == 'paint_train'``.

    Returns:
        ``(z_opt, in_s, netG)``. At the last scale (``len(Gs) ==
        opt.stop_scale``), the returned generator is a deep copy taken during
        the last discriminator step of the last epoch, i.e. before that
        epoch's generator update.
    """
    import copy

    scale = len(Gs)
    # Gs and opt.mode are not modified here, so evaluate these once.
    first_scale = Gs == [] and opt.mode != 'SR_train'
    # Only the last scale returns a snapshot of netG. Deep-copying the whole
    # generator is expensive, so copy only when (and once) it is needed.
    keep_snapshot = scale == opt.stop_scale

    real_fullsize = reals[scale]
    crop_size = crops[scale].size()[2]
    if opt.random_crop:
        real, _, _ = functions.random_crop(real_fullsize.clone(), crop_size)
    else:
        real = real_fullsize.clone()
    mask = masks[scale]

    opt.nzx = real.shape[2]  # +(opt.ker_size-1)*(opt.num_layer) width
    opt.nzy = real.shape[3]  # +(opt.ker_size-1)*(opt.num_layer) height
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) * (opt.num_layer - 1)) * opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    if opt.mode == 'animation_train':
        # animation: noise maps are sampled at padded size and not padded
        opt.nzx = real.shape[2] + (opt.ker_size - 1) * (opt.num_layer)
        opt.nzy = real.shape[3] + (opt.ker_size - 1) * (opt.num_layer)
        pad_noise = 0
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))

    alpha = opt.alpha

    # Only the shape of this noise is used. Explicit float zeros, because
    # torch.full(shape, 0) infers int64 on recent PyTorch.
    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy],
                                           device=opt.device)
    z_opt = m_noise(torch.zeros(fixed_noise.shape, device=opt.device))

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD, milestones=[1600], gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG, milestones=[1600], gamma=opt.gamma)

    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    eye = functions.generate_eye_mask(opt, masks[-1], opt.stop_scale - scale)

    for epoch in range(opt.niter):
        last_epoch = epoch == (opt.niter - 1)

        if opt.resize:
            # random patch size between 0.75x and 1.25x the mask size,
            # bounded by the image size
            max_patch_size = int(
                min(real.size()[2], real.size()[3], mask.size()[2] * 1.25))
            min_patch_size = int(max(mask.size()[2] * 0.75, 1))
            patch_size = random.randint(min_patch_size, max_patch_size)
            mask_in = nn.functional.interpolate(mask.clone(), size=patch_size)
            eye_in = nn.functional.interpolate(eye.clone(), size=patch_size)
        else:
            mask_in = mask.clone()
            eye_in = eye.clone()

        eye_colored = eye_in.clone()
        if opt.random_eye_color:
            eye_color = functions.get_eye_color(real)
        for c in range(3):
            eye_colored[:, c, :, :] *= (eye_color[c] / 255)

        if first_scale:
            z_opt = functions.generate_noise([1, opt.nzx, opt.nzy], device=opt.device)
            z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy], device=opt.device)
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy], device=opt.device)
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) - D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            netD.zero_grad()
            output = netD(real).to(opt.device)
            errD_real = -output.mean()
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()

            # train with fake; prev / z_prev are initialised at the first step
            if j == 0 and epoch == 0:
                if first_scale:
                    prev = torch.zeros([1, opt.nc_z, opt.nzx, opt.nzy], device=opt.device)
                    in_s = prev
                    prev = m_image(prev)
                    z_prev = torch.zeros([1, opt.nc_z, opt.nzx, opt.nzy], device=opt.device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = functions.draw_concat(Gs, Zs, reals, crops, masks, eye_colored, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = functions.draw_concat(Gs, Zs, reals, crops, masks, eye_colored, NoiseAmp, in_s, 'rec', m_noise, m_image, opt)
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                prev = functions.draw_concat(Gs, Zs, reals, crops, masks, eye_colored, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                prev = m_image(prev)

            if opt.mode == 'paint_train':
                prev = functions.quant2centers(prev, centers)
                plt.imsave('%s/prev.png' % (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)

            if first_scale:
                noise = noise_
            else:
                noise = opt.noise_amp * noise_ + prev

            # Stacking masks and noise to make input
            G_input = functions.make_input(noise, mask_in, eye_colored)
            fake_background = netG(G_input.detach(), prev)
            # Snapshot returned at the last scale: only the copy taken at the
            # final D step of the final epoch was ever used.
            if keep_snapshot and last_epoch and j == opt.Dsteps - 1:
                netG_copy = copy.deepcopy(netG)

            # Cropping mask shape from generated image and putting on top of
            # real image at random location
            fake, fake_ind, eye_ind = functions.gen_fake(
                real, fake_background, mask_in, eye_in, eye_color, opt)
            output = netD(fake.detach())
            errD_fake = output.mean()
            errD_fake.backward(retain_graph=True)
            D_G_z = output.mean().item()

            gradient_penalty = functions.calc_gradient_penalty(
                netD, real, fake, opt.lambda_grad, opt.device)
            gradient_penalty.backward()

            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()

        errD2plot.append(errD.detach())

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        for j in range(opt.Gsteps):
            netG.zero_grad()
            output = netD(fake)
            errG = -output.mean()
            errG.backward(retain_graph=True)
            if alpha != 0:
                loss = nn.MSELoss()
                if opt.mode == 'paint_train':
                    z_prev = functions.quant2centers(z_prev, centers)
                    plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
                Z_opt = opt.noise_amp * z_opt + z_prev
                input_opt = functions.make_input(Z_opt, mask_in, eye_in)
                rec_loss = alpha * loss(netG(input_opt.detach(), z_prev), real)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                rec_loss = 0

            optimizerG.step()

        errG2plot.append(errG.detach() + rec_loss)
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        if epoch % 25 == 0 or last_epoch:
            print('scale %d:[%d/%d]' % (scale, epoch, opt.niter))

        if epoch % 500 == 0 or last_epoch:
            plt.imsave('%s/fake_sample.png' % (opt.outf), functions.convert_image_np(fake.detach()))
            plt.imsave('%s/fake_indicator.png' % (opt.outf), functions.convert_image_np(fake_ind.detach()))
            plt.imsave('%s/eye_indicator.png' % (opt.outf), functions.convert_image_np(eye_ind.detach()))
            plt.imsave('%s/background.png' % (opt.outf), functions.convert_image_np(fake_background.detach()))
            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))

        schedulerD.step()
        schedulerG.step()

        if opt.random_crop:
            # randomly find a new crop in the image for the next epoch
            real, _, _ = functions.random_crop(real_fullsize, crop_size)
        if opt.random_eye:
            eye = functions.generate_eye_mask(opt, masks[-1], opt.stop_scale - scale)

    functions.save_networks(netG, netD, z_opt, opt)
    if keep_snapshot:
        netG = netG_copy
    return z_opt, in_s, netG
def SinGAN_denoise(Gs, Zs, reals, NoiseAmp, opt, in_s=None, scale_v=1, scale_h=1, n=0, gen_start_scale=0, num_samples=1):
    """Run the trained pyramid coarse-to-fine using the stored reconstruction noise.

    Every scale is fed its stored noise map ``Z_opt`` (scaled by its noise
    amplitude) instead of fresh random noise. ``gen_start_scale`` therefore has
    no effect; it is kept only for signature compatibility with
    ``SinGAN_generate``.

    Args:
        Gs, Zs, NoiseAmp: per-scale generators, stored noise maps and noise
            amplitudes, iterated in lockstep.
        reals: per-scale real images. ``reals[0]`` sets the shape of the default
            ``in_s``; ``reals[n]`` sets the crop size at scale ``n``.
        opt: options object (``ker_size``, ``num_layer``, ``device``, ``nc_z``,
            ``scale_factor`` and ``mode`` are read).
        in_s: input of the first scale; defaults to zeros shaped like ``reals[0]``.
        scale_v, scale_h: vertical / horizontal scaling applied to the noise size.
        n: index of the first scale; incremented once per scale.
        gen_start_scale: unused (see above).
        num_samples: number of samples carried through every scale.

    Returns:
        The detached output of the last generator for the last sample.
    """
    if in_s is None:
        in_s = torch.full(reals[0].shape, 0, device=opt.device)
    images_cur = []
    for G, Z_opt, noise_amp in zip(Gs, Zs, NoiseAmp):
        # Zero padding of ((ker_size - 1) * num_layer) / 2 on every side.
        pad1 = ((opt.ker_size - 1) * opt.num_layer) / 2
        m = nn.ZeroPad2d(int(pad1))
        # Unpadded noise size of this scale, scaled by scale_v / scale_h.
        nzx = (Z_opt.shape[2] - pad1 * 2) * scale_v
        nzy = (Z_opt.shape[3] - pad1 * 2) * scale_h

        images_prev = images_cur
        images_cur = []

        for i in range(num_samples):
            # This random draw only provides the padded spatial shape used to
            # crop/upsample I_prev below; its values are replaced by Z_opt.
            if n == 0:
                z_curr = functions.generate_noise([1, nzx, nzy], device=opt.device)
                z_curr = z_curr.expand(1, 3, z_curr.shape[2], z_curr.shape[3])
            else:
                z_curr = functions.generate_noise([opt.nc_z, nzx, nzy], device=opt.device)
            z_curr = m(z_curr)

            if images_prev == []:
                # First scale: start from in_s.
                I_prev = m(in_s)
            else:
                I_prev = images_prev[i]
                I_prev = imresize(I_prev, 1 / opt.scale_factor, opt)
                if opt.mode != "SR":
                    I_prev = I_prev[:, :, 0:round(scale_v * reals[n].shape[2]), 0:round(scale_h * reals[n].shape[3])]
                    I_prev = m(I_prev)
                    I_prev = I_prev[:, :, 0:z_curr.shape[2], 0:z_curr.shape[3]]
                    I_prev = functions.upsampling(I_prev, z_curr.shape[2], z_curr.shape[3])
                else:
                    I_prev = m(I_prev)

            # Denoising always uses the stored reconstruction noise, at every
            # scale, regardless of gen_start_scale.
            z_curr = Z_opt

            z_in = noise_amp * z_curr + I_prev
            I_curr = G(z_in.detach(), I_prev)
            images_cur.append(I_curr)
        n += 1
    return I_curr.detach()
def train_single_scale(netD, netG, reals, Gs, Zs, in_s, NoiseAmp, opt, centers=None):
    """Train the generator/discriminator pair of the next pyramid scale.

    The scale being trained is ``len(Gs)``; ``Gs`` holds the generators of the
    scales already trained. The discriminator minimises
    ``-D(real) + D(fake) + gradient_penalty``. The generator minimises
    ``-D(fake)`` plus an ``opt.alpha``-weighted MSE reconstruction loss.

    Args:
        netD, netG: discriminator and generator of the current scale.
        reals: per-scale real images; ``reals[len(Gs)]`` is the target here.
        Gs, Zs, NoiseAmp: generators, noise maps and noise amplitudes of the
            already-trained scales (forwarded to ``draw_concat``).
        in_s: input of the coarsest scale. It is replaced by a zero tensor
            when training the coarsest scale outside ``SR_train`` mode.
        opt: options object. ``opt.nzx``, ``opt.nzy``, ``opt.receptive_field``
            and ``opt.noise_amp`` are written as side effects.
        centers: colour centres used by ``paint_train`` mode.

    Returns:
        ``(z_opt, in_s, netG)``.
    """
    # Gs holds one generator per already-trained scale.
    print("len(Gs):", len(Gs))
    # Real image at the current scale.
    real = reals[len(Gs)]
    opt.nzx = real.shape[2]  # +(opt.ker_size-1)*(opt.num_layer)
    opt.nzy = real.shape[3]  # +(opt.ker_size-1)*(opt.num_layer)
    # Receptive field of the stacked convolutions.
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) * (opt.num_layer - 1)) * opt.stride
    print(opt.receptive_field)
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    print(pad_noise)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    # ZeroPad2d zero-pads all four borders of its input.
    m_noise = nn.ZeroPad2d(int(pad_noise))
    print('m_noise', m_noise)
    m_image = nn.ZeroPad2d(int(pad_image))

    alpha = opt.alpha
    print(alpha)

    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy])
    z_opt = torch.full(fixed_noise.shape, 0, device=opt.device)
    z_opt = m_noise(z_opt)

    # Optimisers and learning-rate schedules.
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD, milestones=[1600], gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG, milestones=[1600], gamma=opt.gamma)

    # Per-epoch histories used for plotting.
    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    for epoch in range(opt.niter):
        if (Gs == []) & (opt.mode != 'SR_train'):
            # Coarsest scale: a single noise channel expanded to 3 channels.
            z_opt = functions.generate_noise([1, opt.nzx, opt.nzy])
            z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy])
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy])
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # Train with the real image.
            netD.zero_grad()
            output = netD(real).to(opt.device)
            errD_real = -output.mean()
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()

            # Train with a fake. z_prev (the reconstruction input) is only
            # computed on the very first D step of the first epoch.
            if (j == 0) & (epoch == 0):
                if (Gs == []) & (opt.mode != 'SR_train'):
                    prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=opt.device)
                    in_s = prev
                    prev = m_image(prev)
                    z_prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=opt.device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    criterion = nn.MSELoss()
                    # Root-mean-square error between real and z_prev.
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec', m_noise, m_image, opt)
                    criterion = nn.MSELoss()
                    # Noise amplitude scales with the RMSE of the 'rec' output.
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                prev = m_image(prev)

            if (Gs == []) & (opt.mode != 'SR_train'):
                noise = noise_
            else:
                noise = opt.noise_amp * noise_ + prev

            # .detach() stops gradients from flowing back into the noise input.
            fake = netG(noise.detach(), prev)
            output = netD(fake.detach())
            errD_fake = output.mean()
            errD_fake.backward(retain_graph=True)
            D_G_z = output.mean().item()

            # Gradient penalty.
            gradient_penalty = functions.calc_gradient_penalty(
                netD, real, fake, opt.lambda_grad)
            gradient_penalty.backward()

            # Total D loss: adversarial terms + gradient penalty.
            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()

        errD2plot.append(errD.detach())

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        # NOTE(review): for opt.Gsteps > 1, later iterations backprop through
        # `fake`, whose graph was built before optimizerG.step() updated netG
        # in place. Newer PyTorch versions may reject this; confirm against the
        # installed version.
        for j in range(opt.Gsteps):
            netG.zero_grad()
            output = netD(fake)
            errG = -output.mean()
            errG.backward(retain_graph=True)
            if alpha != 0:
                loss = nn.MSELoss()
                if opt.mode == 'paint_train':
                    z_prev = functions.quant2centers(z_prev, centers)
                    plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
                # Reconstruction loss: G(noise_amp * z_opt + z_prev) should reproduce `real`.
                Z_opt = opt.noise_amp * z_opt + z_prev
                rec_loss = alpha * loss(netG(Z_opt.detach(), z_prev), real)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                rec_loss = 0

            optimizerG.step()

        errG2plot.append(errG.detach() + rec_loss)
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        if epoch % 100 == 0 or epoch == (opt.niter - 1):
            print('scale %d:[%d/%d]' % (len(Gs), epoch, opt.niter))

        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            plt.imsave('%s/fake_sample.png' % (opt.outf), functions.convert_image_np(fake.detach()), vmin=0, vmax=1)
            plt.imsave('%s/G(z_opt).png' % (opt.outf), functions.convert_image_np(
                netG(Z_opt.detach(), z_prev).detach()), vmin=0, vmax=1)
            plt.imsave('%s/noise.png' % (opt.outf), functions.convert_image_np(noise), vmin=0, vmax=1)
            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))

        # Learning-rate schedules step after the optimisers (PyTorch >= 1.1
        # order), so the milestone at epoch 1600 is honoured exactly.
        schedulerD.step()
        schedulerG.step()

    # Move the loss histories to host memory once, after training.
    errDint = [err.cpu().numpy() for err in errD2plot]
    errGint = [err.cpu().numpy() for err in errG2plot]

    # Loss curves are plotted only for scales 0, 1 and 8.
    if len(Gs) == 0:
        name = 'G_loos & G_loos'
        functions.plot_learning_curves(errGint, errDint, opt.niter, 'Generator', 'Discriminator', name)
    elif len(Gs) == 1:
        name = 'G_loos & G_loos_1'
        functions.plot_learning_curves_1(errGint, errDint, opt.niter, 'Generator', 'Discriminator', name)
    elif len(Gs) == 8:
        name = 'G_loos & G_loos_8'
        functions.plot_learning_curves_8(errGint, errDint, opt.niter, 'Generator', 'Discriminator', name)

    functions.save_networks(netG, netD, z_opt, opt)
    return z_opt, in_s, netG
def train_single_scale(netD, netG, reals, Gs, Zs, in_s, NoiseAmp, opt, scale_num, netG_optimizer, netD_optimizer):
    """Train one pyramid scale (TensorFlow port).

    Tensors are channels-last: ``real.shape[1]`` / ``real.shape[2]`` give the
    spatial size, and noise is broadcast to ``[1, h, w, 3]``. The discriminator
    minimises ``-D(real) + D(fake) + gp``. The generator minimises ``-D(fake)``
    plus an ``opt.alpha``-weighted MSE reconstruction loss.

    Args:
        netD, netG: discriminator and generator of the current scale.
        reals: per-scale real images; ``reals[len(Gs)]`` is the target here.
        Gs, Zs, NoiseAmp: state of the already-trained scales (for ``draw_concat``).
        in_s: input of the coarsest scale; replaced by zeros at the coarsest scale.
        opt: options object. ``opt.nzx``, ``opt.nzy``, ``opt.receptive_field``
            and ``opt.noise_amp`` are written as side effects.
        scale_num: scale index forwarded to ``functions.save_networks``.
        netG_optimizer, netD_optimizer: Keras optimisers for netG / netD.

    Returns:
        ``(z_opt, in_s, netG)``.
    """
    real = reals[len(Gs)]
    opt.nzx = real.shape[1]  # +(opt.ker_size-1)*(opt.num_layer)
    opt.nzy = real.shape[2]  # +(opt.ker_size-1)*(opt.num_layer)
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) * (opt.num_layer - 1)) * opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    m_noise = tf.keras.layers.ZeroPadding2D(padding=(int(pad_noise), int(pad_noise)))
    m_image = tf.keras.layers.ZeroPadding2D(padding=(int(pad_image), int(pad_image)))

    alpha = opt.alpha

    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy])
    z_opt = tf.zeros_like(fixed_noise)
    z_opt = m_noise(z_opt)

    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    for epoch in range(opt.niter):
        # Persistent tapes: gradient() is called once per D step / G step.
        with tf.GradientTape(persistent=True) as netD_tape, tf.GradientTape(persistent=True) as netG_tape:
            if (Gs == []) & (opt.mode != 'SR_train'):
                z_opt = functions.generate_noise([1, opt.nzx, opt.nzy])  # e.g. (1, 33, 25)
                z_opt = tf.broadcast_to(z_opt, [1, z_opt.shape[1], z_opt.shape[2], 3])  # e.g. (1, 33, 25, 3)
                z_opt = m_noise(z_opt)
                noise_ = functions.generate_noise([1, opt.nzx, opt.nzy])
                noise_ = tf.broadcast_to(noise_, [1, noise_.shape[1], noise_.shape[2], 3])
                noise_ = m_noise(noise_)
            else:
                noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy])
                noise_ = m_noise(noise_)

            ############################
            # (1) Update D network: maximize D(x) + D(G(z))
            ###########################
            for j in range(opt.Dsteps):
                output_real = netD(real)
                errD_real = -tf.reduce_mean(output_real)
                D_x = float(-errD_real.numpy())  # Conversion to numpy required

                # train with fake
                if (j == 0) & (epoch == 0):
                    if (Gs == []) & (opt.mode != 'SR_train'):
                        prev = tf.zeros([1, opt.nzx, opt.nzy, opt.nc_z])
                        in_s = prev
                        prev = m_image(prev)
                        z_prev = tf.zeros([1, opt.nzx, opt.nzy, opt.nc_z])
                        z_prev = m_noise(z_prev)
                        opt.noise_amp = 1
                    else:
                        prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                        prev = m_image(prev)
                        z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec', m_noise, m_image, opt)
                        RMSE = tf.sqrt(tf.reduce_mean(tf.square(tf.subtract(real, z_prev))))
                        opt.noise_amp = opt.noise_amp_init * RMSE
                        z_prev = m_image(z_prev)
                else:
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                    prev = m_image(prev)

                if (Gs == []) & (opt.mode != 'SR_train'):
                    noise = noise_
                else:
                    noise = opt.noise_amp * noise_ + prev

                fake = netG(noise, prev, training=True)
                output_fake = netD(fake)
                errD_fake = tf.reduce_mean(output_fake)
                D_G_z = float(output_fake.numpy().mean())

                # Gradient penalty on D's gradient w.r.t. the fake_gp samples,
                # normed over the channel axis (axis 3).
                fake_gp = functions.fake_gp_generator(real, fake)
                with tf.GradientTape() as gp_tape:
                    gp_tape.watch(fake_gp)
                    gp_D_src = netD(fake_gp)
                gp_D_grad = gp_tape.gradient(gp_D_src, fake_gp)
                gp = opt.lambda_grad * tf.reduce_mean(((tf.norm(gp_D_grad, ord=2, axis=3) - 1.0) ** 2))

                errD = errD_real + errD_fake + gp
                print('errD_real:', errD_real)
                print('errD_fake:', errD_fake)
                netD_gradients = netD_tape.gradient(errD, netD.trainable_variables)
                netD_optimizer.apply_gradients(zip(netD_gradients, netD.trainable_variables))

            ############################
            # (2) Update G network: maximize D(G(z))
            ###########################
            for j in range(opt.Gsteps):
                errG_fake = -tf.reduce_mean(output_fake)
                if alpha != 0:
                    Z_opt = opt.noise_amp * z_opt + z_prev
                    rec_loss = alpha * tf.reduce_mean(tf.square(tf.subtract(netG(Z_opt, z_prev, training=True), real)))
                else:
                    Z_opt = z_opt
                    rec_loss = 0
                # errG already includes the reconstruction term.
                errG = errG_fake + rec_loss
                print('errG_fake:', errG_fake)
                print('rec_loss:', rec_loss)
                # Only netG's variables are updated here.
                netG_gradients = netG_tape.gradient(errG, netG.trainable_variables)
                netG_optimizer.apply_gradients(zip(netG_gradients, netG.trainable_variables))
        del netG_tape, netD_tape

        # errG == errG_fake + rec_loss, so it is recorded as-is (adding
        # rec_loss again would count the reconstruction loss twice).
        errG2plot.append(errG)
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        # Progress is printed every epoch.
        print('scale %d:[%d/%d]' % (len(Gs), epoch, opt.niter))

        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            plt.imsave('%s/fake_sample.png' % (opt.outf), functions.convert_image_np(fake))
            plt.imsave('%s/G(z_opt).png' % (opt.outf), functions.convert_image_np(netG(Z_opt, z_prev, training=False)))

    # Save after each single-scale training run.
    functions.save_networks(netD, netG, z_opt, opt, scale_num)
    return z_opt, in_s, netG
# 3. Generate GIFs (varying beta && start_scale) in_s = torch.full(Zs[0].shape, 0, device=device) images_cur = [] count = 0 for G, Z_opt, noise_amp, real in zip(Gs, Zs, NoiseAmp, reals): pad_image = int(((ker_size - 1) * num_layer) / 2) # what it means?? nzx = Z_opt.shape[2] nzy = Z_opt.shape[3] m_image = nn.ZeroPad2d(int(pad_image)) images_prev = images_cur images_cur = [] if count == 0: # z_rand is gaussian noise z_rand = functions.generate_noise([1, nzx, nzy], device=device) z_rand = z_rand.expand(1, 3, Z_opt.shape[2], Z_opt.shape[3]) z_prev1 = 0.95 * Z_opt + 0.05 * z_rand z_prev2 = Z_opt else: z_prev1 = 0.95 * Z_opt + 0.05 * functions.generate_noise( [nc_z, nzx, nzy], device=device) z_prev2 = Z_opt for i in range(0, 100, 1): if count == 0: z_rand = functions.generate_noise([1, nzx, nzy], device=device) # make z_rand same across channels z_rand = z_rand.expand(1, 3, Z_opt.shape[2], Z_opt.shape[3]) diff_curr = beta * (z_prev1 - z_prev2) + (1 - beta) * z_rand else:
def train_single_scale(netD, netG, reals, Gs, Zs, in_s, NoiseAmp, opt, centers=None):
    """Train one pyramid scale and report timing / GPU-memory statistics.

    Uses the same procedure as the plain variant: a WGAN-GP style
    discriminator loss, and a generator loss of ``-D(fake)`` plus an
    ``opt.alpha``-weighted MSE reconstruction term. In addition:

    * Every 25 epochs (and on the last epoch), the time since the previous
      entry of the module-level ``timestamp`` list and the result of
      ``memory_check()`` are printed. They are appended to the module-level
      ``full_time`` / ``full_memory`` lists. These names are defined outside
      this function, and ``timestamp`` must already hold at least one entry.
    * After training, GPU 0's memory use is queried through NVML.

    Returns:
        ``(z_opt, in_s, netG, mbs, percent)``. ``mbs`` is GPU 0's used memory
        in MiB (bytes / 1024**2); ``percent`` is the used/total fraction.
    """
    real = reals[len(Gs)]
    opt.nzx = real.shape[2]  # +(opt.ker_size-1)*(opt.num_layer)
    opt.nzy = real.shape[3]  # +(opt.ker_size-1)*(opt.num_layer)
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) * (opt.num_layer - 1)) * opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    if opt.mode == 'animation_train':
        opt.nzx = real.shape[2] + (opt.ker_size - 1) * (opt.num_layer)
        opt.nzy = real.shape[3] + (opt.ker_size - 1) * (opt.num_layer)
        pad_noise = 0
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))

    alpha = opt.alpha

    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy], device=opt.device)
    z_opt = torch.full(fixed_noise.shape, 0, device=opt.device)
    z_opt = m_noise(z_opt)

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD, milestones=[1600], gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG, milestones=[1600], gamma=opt.gamma)

    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    for epoch in range(opt.niter):
        if (Gs == []) & (opt.mode != 'SR_train'):
            z_opt = functions.generate_noise([1, opt.nzx, opt.nzy], device=opt.device)
            z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy], device=opt.device)
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy], device=opt.device)
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            netD.zero_grad()
            output = netD(real).to(opt.device)
            errD_real = -output.mean()
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()

            # train with fake
            if (j == 0) & (epoch == 0):
                if (Gs == []) & (opt.mode != 'SR_train'):
                    prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=opt.device)
                    in_s = prev
                    prev = m_image(prev)
                    z_prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=opt.device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec', m_noise, m_image, opt)
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                prev = m_image(prev)

            if opt.mode == 'paint_train':
                prev = functions.quant2centers(prev, centers)
                plt.imsave('%s/prev.png' % (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)

            if (Gs == []) & (opt.mode != 'SR_train'):
                noise = noise_
            else:
                noise = opt.noise_amp * noise_ + prev

            fake = netG(noise.detach(), prev)
            output = netD(fake.detach())
            errD_fake = output.mean()
            errD_fake.backward(retain_graph=True)
            D_G_z = output.mean().item()

            gradient_penalty = functions.calc_gradient_penalty(
                netD, real, fake, opt.lambda_grad, opt.device)
            gradient_penalty.backward()

            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()

        errD2plot.append(errD.detach())

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        for j in range(opt.Gsteps):
            netG.zero_grad()
            output = netD(fake)
            errG = -output.mean()
            errG.backward(retain_graph=True)
            if alpha != 0:
                loss = nn.MSELoss()
                if opt.mode == 'paint_train':
                    z_prev = functions.quant2centers(z_prev, centers)
                    plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
                Z_opt = opt.noise_amp * z_opt + z_prev
                rec_loss = alpha * loss(netG(Z_opt.detach(), z_prev), real)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                rec_loss = 0

            optimizerG.step()

        errG2plot.append(errG.detach() + rec_loss)
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        if epoch % 25 == 0 or epoch == (opt.niter - 1):
            stamp = datetime.datetime.now()
            timestamp.append(stamp)
            # total_seconds() keeps sub-second precision (printed with %.3f)
            # and does not wrap after a day, unlike timedelta.seconds.
            delta = (timestamp[-1] - timestamp[-2]).total_seconds()
            mbs, percent = memory_check()
            print('scale %d:[%d/%d] | Mb: %.3f | Percent %.3f | secs %.3f ' % (len(Gs), epoch, opt.niter, mbs, 100 * percent, delta))
            full_memory.append([mbs, percent])
            full_time.append(delta)

        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            plt.imsave('%s/fake_sample.png' % (opt.outf), functions.convert_image_np(fake.detach()), vmin=0, vmax=1)
            plt.imsave('%s/G(z_opt).png' % (opt.outf), functions.convert_image_np(
                netG(Z_opt.detach(), z_prev).detach()), vmin=0, vmax=1)
            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))

        schedulerD.step()
        schedulerG.step()

    functions.save_networks(netG, netD, z_opt, opt)

    # Query GPU memory through NVML. Per the NVML API docs, init/shutdown are
    # reference counted, so pairing them here leaves other NVML users intact.
    nvidia_smi.nvmlInit()
    try:
        # Card id 0 is hard-coded; NVML can also enumerate all devices.
        handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0)
        mem_res = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
    finally:
        nvidia_smi.nvmlShutdown()
    mbs = mem_res.used / (1024**2)  # bytes -> MiB
    percent = mem_res.used / mem_res.total
    print(f'mem: {mbs} (MiB)')  # usage in MiB
    print(f'mem: {100 * percent:.6f}%')  # percentage
    return z_opt, in_s, netG, mbs, percent
def train_single_scale(netD, netG, reals, Gs, Zs, in_s, NoiseAmp, opt, centers=None):
    """Train the generator/discriminator pair of the next pyramid scale.

    The scale being trained is ``len(Gs)``. The discriminator minimises
    ``-D(real) + D(fake) + gradient_penalty``. The generator minimises
    ``-D(fake)`` plus an ``opt.alpha``-weighted MSE reconstruction loss.

    Args:
        netD, netG: discriminator and generator of the current scale.
        reals: per-scale real images; ``reals[len(Gs)]`` is the target here.
        Gs, Zs, NoiseAmp: state of the already-trained scales (for ``draw_concat``).
        in_s: input of the coarsest scale. It is replaced by zeros when
            training the coarsest scale outside ``SR_train`` mode.
        opt: options object. ``opt.nzx``, ``opt.nzy``, ``opt.receptive_field``
            and ``opt.noise_amp`` are written as side effects.
        centers: colour centres used by ``paint_train`` mode.

    Returns:
        ``(z_opt, in_s, netG)``.
    """
    real = reals[len(Gs)]
    opt.nzx = real.shape[2]  # +(opt.ker_size-1)*(opt.num_layer)
    opt.nzy = real.shape[3]  # +(opt.ker_size-1)*(opt.num_layer)
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) * (opt.num_layer - 1)) * opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    if opt.mode == 'animation_train':
        opt.nzx = real.shape[2] + (opt.ker_size - 1) * (opt.num_layer)
        opt.nzy = real.shape[3] + (opt.ker_size - 1) * (opt.num_layer)
        pad_noise = 0
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))

    alpha = opt.alpha

    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy])
    z_opt = torch.full(fixed_noise.shape, 0, device=opt.device)
    z_opt = m_noise(z_opt)

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD, milestones=[1600], gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG, milestones=[1600], gamma=opt.gamma)

    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    for epoch in range(opt.niter):
        if (Gs == []) & (opt.mode != 'SR_train'):
            z_opt = functions.generate_noise([1, opt.nzx, opt.nzy])
            z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy])
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy])
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            netD.zero_grad()
            output = netD(real).to(opt.device)
            errD_real = -output.mean()
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()

            # train with fake
            if (j == 0) & (epoch == 0):
                if (Gs == []) & (opt.mode != 'SR_train'):
                    prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=opt.device)
                    in_s = prev
                    prev = m_image(prev)
                    z_prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=opt.device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec', m_noise, m_image, opt)
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                prev = m_image(prev)

            if opt.mode == 'paint_train':
                prev = functions.quant2centers(prev, centers)
                plt.imsave('%s/prev.png' % (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)

            if (Gs == []) & (opt.mode != 'SR_train'):
                noise = noise_
            else:
                noise = opt.noise_amp * noise_ + prev

            fake = netG(noise.detach(), prev)
            output = netD(fake.detach())
            errD_fake = output.mean()
            errD_fake.backward(retain_graph=True)
            D_G_z = output.mean().item()

            gradient_penalty = functions.calc_gradient_penalty(
                netD, real, fake, opt.lambda_grad)
            gradient_penalty.backward()

            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()

        errD2plot.append(errD.detach())

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        for j in range(opt.Gsteps):
            netG.zero_grad()
            output = netD(fake)
            errG = -output.mean()
            errG.backward(retain_graph=True)
            if alpha != 0:
                loss = nn.MSELoss()
                if opt.mode == 'paint_train':
                    z_prev = functions.quant2centers(z_prev, centers)
                    plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
                Z_opt = opt.noise_amp * z_opt + z_prev
                rec_loss = alpha * loss(netG(Z_opt.detach(), z_prev), real)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                rec_loss = 0

            optimizerG.step()

        errG2plot.append(errG.detach() + rec_loss)
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        if epoch % 25 == 0 or epoch == (opt.niter - 1):
            print('scale %d:[%d/%d]' % (len(Gs), epoch, opt.niter))

        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            plt.imsave('%s/fake_sample.png' % (opt.outf), functions.convert_image_np(fake.detach()), vmin=0, vmax=1)
            plt.imsave('%s/G(z_opt).png' % (opt.outf), functions.convert_image_np(
                netG(Z_opt.detach(), z_prev).detach()), vmin=0, vmax=1)
            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))

        # Step the LR schedules after the optimisers have stepped (PyTorch
        # >= 1.1 order); stepping first would skip the schedule's first value.
        schedulerD.step()
        schedulerG.step()

    functions.save_networks(netG, netD, z_opt, opt)
    return z_opt, in_s, netG
def SinGAN_generate(Gs, Zs, reals, NoiseAmp, opt, in_s=None, scale_v=1, scale_h=1, n=0, gen_start_scale=0, num_samples=50, output_image=False):
    """Generate samples by running the SinGAN pyramid coarse-to-fine.

    Args:
        Gs, Zs, NoiseAmp: per-scale generators, stored noise maps and noise
            amplitudes, iterated in lockstep.
        reals: per-scale real images. ``reals[0]`` sets the default ``in_s``
            shape; ``reals[n]`` sets the crop size at scale ``n``.
        opt: options object (``ker_size``, ``num_layer``, ``device``, ``nc_z``,
            ``scale_factor`` are read).
        in_s: input of the first scale; defaults to zeros shaped like ``reals[0]``.
        scale_v, scale_h: vertical / horizontal scaling of the generated size.
        n: index of the first scale; incremented once per scale.
        gen_start_scale: scales with ``n < gen_start_scale`` use the stored
            noise ``Z_opt`` instead of fresh random noise.
        num_samples: number of samples carried through every scale.
        output_image: if true, each final-scale sample ``i`` is also written to
            ``<functions.generate_dir2save(opt)>/<i>.png``.

    Returns:
        ``(last_sample, output_list)``. ``last_sample`` is the detached output
        of the last generator for the last sample. ``output_list`` holds
        ``functions.convert_image_np`` of every sample produced at scale
        ``n == len(reals) - 1``.
    """
    if in_s is None:
        # Zero tensor with reals[0]'s shape as the first scale's input.
        in_s = torch.full(reals[0].shape, 0, device=opt.device)
    images_cur = []
    for G, Z_opt, noise_amp in zip(Gs, Zs, NoiseAmp):
        # Zero padding of ((ker_size - 1) * num_layer) / 2 on every side.
        pad1 = ((opt.ker_size - 1) * opt.num_layer) / 2
        m = nn.ZeroPad2d(int(pad1))
        # Unpadded noise size of this scale, scaled by scale_v / scale_h.
        nzx = (Z_opt.shape[2] - pad1 * 2) * scale_v
        nzy = (Z_opt.shape[3] - pad1 * 2) * scale_h

        # Outputs of the previous scale, one per sample.
        images_prev = images_cur
        images_cur = []
        output_list = []

        for i in range(num_samples):
            if n == 0:
                # First scale: one noise channel broadcast to 3 channels.
                z_curr = functions.generate_noise([1, nzx, nzy], device=opt.device)
                z_curr = z_curr.expand(1, 3, z_curr.shape[2], z_curr.shape[3])
            else:
                z_curr = functions.generate_noise([opt.nc_z, nzx, nzy], device=opt.device)
            z_curr = m(z_curr)

            if images_prev == []:
                # First scale: start from in_s.
                I_prev = m(in_s)
            else:
                # Resize the previous scale's output by 1 / opt.scale_factor,
                # crop it to this scale's (scaled) real size, pad it, then crop
                # and upsample it to the padded noise shape.
                I_prev = images_prev[i]
                I_prev = imresize(I_prev, 1 / opt.scale_factor, opt)
                I_prev = I_prev[:, :, 0:round(scale_v * reals[n].shape[2]), 0:round(scale_h * reals[n].shape[3])]
                I_prev = m(I_prev)
                I_prev = I_prev[:, :, 0:z_curr.shape[2], 0:z_curr.shape[3]]
                I_prev = functions.upsampling(I_prev, z_curr.shape[2], z_curr.shape[3])

            if n < gen_start_scale:
                # Below gen_start_scale, reuse the stored noise.
                z_curr = Z_opt

            # Scale the noise by this scale's amplitude and add the previous image.
            z_in = noise_amp * z_curr + I_prev
            I_curr = G(z_in.detach(), I_prev)

            if n == len(reals) - 1:
                if output_image:
                    dir2save = functions.generate_dir2save(opt)
                    os.makedirs(dir2save, exist_ok=True)
                    plt.imsave(f'{dir2save}/{i}.png', functions.convert_image_np(I_curr.detach()), vmin=0, vmax=1)
                output_list.append(functions.convert_image_np(I_curr.detach()))
            images_cur.append(I_curr)
        n += 1
    return I_curr.detach(), output_list
def _load_noisy_reference(path):
    """Load the image at ``path`` as a ``(1, 3, H, W)`` float tensor in [-1, 1].

    The array returned by ``img.imread`` is divided by 255, mapped to [-1, 1]
    with ``(x - 0.5) * 2``, clamped, and cut to its first three channels.
    NOTE(review): the /255 assumes an 8-bit array (e.g. a JPEG); confirm for
    other formats.
    """
    real = img.imread(path)
    real = real[:, :, :, None]
    real = real.transpose((3, 2, 0, 1)) / 255  # (H, W, C, 1) -> (1, C, H, W)
    real = torch.from_numpy(real)
    real = move_to_gpu(real)
    # .float() keeps the tensor on the device chosen by move_to_gpu; the former
    # torch.cuda.FloatTensor cast failed on CPU-only machines.
    real = real.float()
    real = ((real - 0.5) * 2).clamp(-1, 1)
    return real[:, 0:3, :, :]


def SinGAN_generate(Gs,Zs,reals,NoiseAmp,opt,in_s=None,scale_v=1,scale_h=1,n=0,gen_start_scale=0,num_samples=50,
                    noisy_image_path=r"D:\MVA\CompVision\Project\SinGAN-master\Input\Images/Noisy_Golden_Bridge_by_night.jpg"):
    """Denoising experiment: feed a fixed (noisy) image through every scale of the pyramid.

    At each scale the generator input ``I_prev`` is the image at
    ``noisy_image_path``, zero-padded and resampled to the noise size. The
    noise term is multiplied by 0, so ``z_in == I_prev``. Shapes are printed
    for debugging. At the last scale each sample is saved as ``<i>.png``
    unless ``opt.mode`` is one of harmonization / editing / SR / paint2image.

    Args:
        noisy_image_path: image fed to every scale. The default is the
            previously hard-coded path, written as a raw string with the same
            value.
        Other arguments: as for the standard SinGAN ``SinGAN_generate``.

    Returns:
        The detached output of the last generator call.
    """
    if in_s is None:
        in_s = torch.full(reals[0].shape, 0, device=opt.device)
    # The reference image is the same for every scale and sample: read it once.
    reference = _load_noisy_reference(noisy_image_path)
    images_cur = []
    for G,Z_opt,noise_amp in zip(Gs,Zs,NoiseAmp):
        pad1 = ((opt.ker_size-1)*opt.num_layer)/2
        m = nn.ZeroPad2d(int(pad1))
        nzx = (Z_opt.shape[2]-pad1*2)*scale_v
        nzy = (Z_opt.shape[3]-pad1*2)*scale_h
        padded_reference = m(reference)
        images_prev = images_cur
        images_cur = []
        for i in range(num_samples):
            if n == 0:
                z_curr = functions.generate_noise([1,nzx,nzy], device=opt.device)
                z_curr = z_curr.expand(1,3,z_curr.shape[2],z_curr.shape[3])
            else:
                z_curr = functions.generate_noise([opt.nc_z,nzx,nzy], device=opt.device)
            z_curr = m(z_curr)
            if not images_prev:
                # Debug output only: this I_prev is replaced by the reference image below.
                print("in_s shape before padding with m", in_s.shape)
                I_prev = m(in_s)
                print("in_s shape after padding with m now I_prev", I_prev.shape)
                I_prev = functions.upsampling(I_prev,z_curr.shape[2],z_curr.shape[3])
                print("I_prev shape after upsampling using noise shape", I_prev.shape)
            if n < gen_start_scale:
                z_curr = Z_opt
            I_prev = functions.upsampling(padded_reference,z_curr.shape[2],z_curr.shape[3])
            print("I_prev",I_prev.shape)
            print("z_curr",z_curr.shape)
            print('---')
            # Noise deliberately disabled: the generator only sees the reference image.
            z_in = 0*(z_curr)+I_prev
            I_curr = G(z_in.detach(),I_prev)
            if n == len(reals)-1:
                if opt.mode == 'train':
                    dir2save = '%s/RandomSamples/%s/gen_start_scale=%d' % (opt.out, opt.input_name[:-4], gen_start_scale)
                else:
                    dir2save = functions.generate_dir2save(opt)
                try:
                    os.makedirs(dir2save)
                except OSError:
                    pass
                if (opt.mode != "harmonization") & (opt.mode != "editing") & (opt.mode != "SR") & (opt.mode != "paint2image"):
                    plt.imsave('%s/%d.png' % (dir2save, i), functions.convert_image_np(I_curr.detach()), vmin=0,vmax=1)
            images_cur.append(I_curr)
        n+=1
    return I_curr.detach()
def train_single_scale(netD, netG, reals, Gs, Zs, in_s, NoiseAmp, opt, centers=None):
    """Train the generator/discriminator pair of one pyramid scale.

    The scale being trained is ``len(Gs)``. ``Gs``, ``Zs`` and ``NoiseAmp``
    hold the already-trained coarser scales, and ``reals[len(Gs)]`` is the
    target image. Each epoch has two phases:

    * ``opt.Dsteps`` discriminator updates: a real score, a fake score and
      ``functions.calc_gradient_penalty``.
    * ``opt.Gsteps`` generator updates: an adversarial loss plus, when
      ``opt.alpha != 0``, ``alpha`` times the MSE between
      ``netG(noise_amp * z_opt + z_prev)`` and the real image.

    Side effects: writes ``opt.nzx``, ``opt.nzy``, ``opt.receptive_field`` and
    ``opt.noise_amp``. It saves sample images, the loss trace and ``z_opt``
    under ``opt.outf``, and calls ``functions.save_networks`` at the end.

    Args:
        netD, netG: discriminator and generator of this scale.
        reals: image pyramid, coarsest first.
        Gs, Zs, NoiseAmp: generators, reconstruction noise maps and noise
            amplitudes of the coarser scales.
        in_s: coarsest-scale input; replaced by a zero tensor on the first scale.
        opt: training configuration namespace.
        centers: colour centres for ``functions.quant2centers`` in ``paint_train`` mode.

    Returns:
        ``(z_opt, in_s, netG)``.
    """
    real = reals[len(Gs)]  # real image at current scale
    opt.nzx = real.shape[2]  #+(opt.ker_size-1)*(opt.num_layer)
    opt.nzy = real.shape[3]  #+(opt.ker_size-1)*(opt.num_layer)
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) * (opt.num_layer - 1)) * opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    if opt.mode == 'animation_train':
        # Supplementary says they generate noise on the border in animation mode
        opt.nzx = real.shape[2] + (opt.ker_size - 1) * (opt.num_layer)
        opt.nzy = real.shape[3] + (opt.ker_size - 1) * (opt.num_layer)
        pad_noise = 0
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))

    alpha = opt.alpha

    # Only the shape of this noise is used; z_opt itself starts as all zeros.
    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy])
    z_opt = torch.full(fixed_noise.shape, 0, device=opt.device)
    z_opt = m_noise(z_opt)  # zero padding keeps it all zeros

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD, milestones=[1600], gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG, milestones=[1600], gamma=opt.gamma)

    errD2plot = []  # total D loss per epoch
    errG2plot = []  # adversarial + reconstruction G loss per epoch
    D_real2plot = []  # D score on the real image
    D_fake2plot = []  # D score on the fake image
    z_opt2plot = []  # reconstruction loss

    for epoch in range(opt.niter):
        if (Gs == []) & (opt.mode != 'SR_train'):  # first (coarsest) scale
            # z_opt is redrawn at the start of every epoch on this scale
            z_opt = functions.generate_noise([1, opt.nzx, opt.nzy], device=opt.device)
            z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            # single-channel noise copied to the RGB channels by expand
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy], device=opt.device)
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            # one noise_ draw is shared by all D and G steps of this epoch
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy], device=opt.device)
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            netD.zero_grad()
            output = netD(real).to(opt.device)  # patch scores; their mean is the score
            errD_real = -output.mean()  # maximize D output on real patches
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()

            # train with fake, built from the coarser generators
            if (j == 0) & (epoch == 0):  # first D step of this scale
                if (Gs == []) & (opt.mode != 'SR_train'):  # initial scale
                    prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=opt.device)
                    in_s = prev  # in_s stays unpadded
                    prev = m_image(prev)
                    z_prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=opt.device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec', m_noise, m_image, opt)
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    # noise amplitude scales with the reconstruction RMSE
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                prev = m_image(prev)

            if opt.mode == 'paint_train':
                prev = functions.quant2centers(prev, centers)
                plt.imsave('%s/prev.png' % (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)

            if (Gs == []) & (opt.mode != 'SR_train'):
                noise = noise_
            else:
                noise = opt.noise_amp * noise_ + prev

            # netG takes the noise and the upsampled previous image
            fake = netG(noise.detach(), prev)
            # fake is detached so the D loss does not back-propagate into G
            output = netD(fake.detach())
            errD_fake = output.mean()  # minimize D output on fake patches
            errD_fake.backward(retain_graph=True)
            D_G_z = output.mean().item()

            gradient_penalty = functions.calc_gradient_penalty(netD, real, fake, opt.lambda_grad, opt.device)
            gradient_penalty.backward()

            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()

        errD2plot.append(errD.detach().item())

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        for j in range(opt.Gsteps):
            netG.zero_grad()
            output = netD(fake)  # the same fake from the last D step is reused each G step
            errG = -output.mean()  # adversarial loss
            errG.backward(retain_graph=True)
            if alpha != 0:
                loss = nn.MSELoss()
                if opt.mode == 'paint_train':
                    z_prev = functions.quant2centers(z_prev, centers)
                    plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
                Z_opt = opt.noise_amp * z_opt + z_prev
                rec_loss = alpha * loss(netG(Z_opt.detach(), z_prev), real)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                # A zero tensor (not the int 0) so the .item() calls below work
                # when the reconstruction loss is disabled.
                rec_loss = torch.zeros(())

            optimizerG.step()

        errG2plot.append(errG.detach().item() + rec_loss.item())
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss.item())

        if epoch % 25 == 0 or epoch == (opt.niter - 1):
            print('scale %d:[%d/%d]' % (len(Gs), epoch, opt.niter))

        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            plt.imsave('%s/fake_sample.png' % (opt.outf), functions.convert_image_np(fake.detach()), vmin=0, vmax=1)
            plt.imsave('%s/G(z_opt).png' % (opt.outf),
                       functions.convert_image_np(netG(Z_opt.detach(), z_prev).detach()), vmin=0, vmax=1)
            # the noise fed to the generator, i.e. prev + amp * noise_
            plt.imsave('%s/noise.png' % (opt.outf), functions.convert_image_np(noise), vmin=0, vmax=1)
            plt.imsave('%s/Z_opt.png' % (opt.outf), functions.convert_image_np(Z_opt), vmin=0, vmax=1)
            plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
            torch.save(
                {
                    "errD": errD2plot,
                    "errG": errG2plot,
                    "D_real": D_real2plot,
                    "D_fake": D_fake2plot,
                    "recons": z_opt2plot
                }, '%s/loss_trace.pth' % (opt.outf))
            plt.figure()
            plt.plot(errD2plot, label="errD")
            plt.plot(errG2plot, label="errG")
            plt.plot(D_real2plot, label="Dreal")
            plt.plot(D_fake2plot, label="Dfake")
            plt.plot(z_opt2plot, label="recons")
            plt.legend()
            plt.savefig('%s/loss_trace.png' % (opt.outf))
            plt.close()
            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))

        schedulerD.step()
        schedulerG.step()

    functions.save_networks(netG, netD, z_opt, opt)
    return z_opt, in_s, netG
def SinGAN_generate(Gs, Zs, reals, NoiseAmp, opt, modification=None, in_s=None, scale_v=1, scale_h=1, n=0, gen_start_scale=0, num_samples=10):
    """Generate samples from a trained pyramid, optionally modifying the generator input.

    From scale ``gen_start_scale`` upward, when ``modification`` is given,
    ``z_in`` is replaced by ``modify_input_to_generator(z_in, cont_in,
    modification, opacity=1)``. Here ``cont_in`` comes from
    ``preprocess_content_image(opt, reals, n)``.

    The input is saved at scale ``gen_start_scale`` as
    ``<i>_before_modification.png`` and, when modified, as
    ``<i>_after_modification.png``. Final-scale samples are saved as
    ``<i>.png`` unless ``opt.mode`` is harmonization / editing / SR /
    paint2image. All files go to
    ``<opt.out>/RandomSamples/<input name>/<modification>/gen_start_scale=<k>``.

    Notes on the modification step:
    * It runs at every scale from ``gen_start_scale`` upward; change ``>=`` to
      ``==`` to apply it only once.
    * Blending options and opacity are chosen inside
      ``modify_input_to_generator``.

    Returns:
        The detached output of the last generator call.

    Raises:
        ValueError: if the modification changes the shape of the generator input.
    """
    if in_s is None:
        in_s = torch.full(reals[0].shape, 0, device=opt.device)

    # Every file is written to this one directory, so build and create it once.
    dir2save = '%s/RandomSamples/%s/%s/gen_start_scale=%d' % (opt.out, opt.input_name[:-4], modification, gen_start_scale)
    try:
        os.makedirs(dir2save)
    except OSError:
        pass
    save_final_samples = (opt.mode != "harmonization") & (opt.mode != "editing") & (opt.mode != "SR") & (
        opt.mode != "paint2image")

    images_cur = []
    for G, Z_opt, noise_amp in zip(Gs, Zs, NoiseAmp):
        pad1 = ((opt.ker_size - 1) * opt.num_layer) / 2
        m = nn.ZeroPad2d(int(pad1))
        nzx = (Z_opt.shape[2] - pad1 * 2) * scale_v
        nzy = (Z_opt.shape[3] - pad1 * 2) * scale_h

        images_prev = images_cur
        images_cur = []

        for i in range(num_samples):
            if n == 0:
                z_curr = functions.generate_noise([1, nzx, nzy], device=opt.device)
                z_curr = z_curr.expand(1, 3, z_curr.shape[2], z_curr.shape[3])
            else:
                z_curr = functions.generate_noise([opt.nc_z, nzx, nzy], device=opt.device)
            z_curr = m(z_curr)

            if not images_prev:
                I_prev = m(in_s)
            else:
                I_prev = images_prev[i]
                I_prev = imresize(I_prev, 1 / opt.scale_factor, opt)
                if opt.mode != "SR":
                    I_prev = I_prev[:, :, 0:round(scale_v * reals[n].shape[2]), 0:round(scale_h * reals[n].shape[3])]
                    I_prev = m(I_prev)
                    I_prev = I_prev[:, :, 0:z_curr.shape[2], 0:z_curr.shape[3]]
                    I_prev = functions.upsampling(I_prev, z_curr.shape[2], z_curr.shape[3])
                else:
                    I_prev = m(I_prev)

            if n < gen_start_scale:
                z_curr = Z_opt

            z_in = noise_amp * (z_curr) + I_prev

            if n == gen_start_scale:
                plt.imsave('%s/%d_before_modification.png' % (dir2save, i), functions.convert_image_np(z_in.detach()), vmin=0, vmax=1)

            # ########################## Image modification ##########################
            if (n >= gen_start_scale) and (modification is not None):
                shape = z_in.shape
                cont_in = preprocess_content_image(opt, reals, n)
                z_in = modify_input_to_generator(z_in, cont_in, modification, opacity=1)
                # explicit check instead of assert, which is stripped under -O
                if z_in.shape != shape:
                    raise ValueError(
                        f"modification {modification!r} changed the generator input shape "
                        f"from {tuple(shape)} to {tuple(z_in.shape)}")
                if n == gen_start_scale:
                    plt.imsave('%s/%d_after_modification.png' % (dir2save, i), functions.convert_image_np(z_in.detach()), vmin=0, vmax=1)
            # ####################### End of image modification ######################

            I_curr = G(z_in.detach(), I_prev)

            if n == len(reals) - 1 and save_final_samples:
                plt.imsave('%s/%d.png' % (dir2save, i), functions.convert_image_np(I_curr.detach()), vmin=0, vmax=1)
            images_cur.append(I_curr)
        n += 1
    return I_curr.detach()
def train_single_scale(netD, netD_mask1, netD_mask2,netG,reals1, reals2, Gs,Zs,in_s1, in_s2,NoiseAmp,opt):
    """Train one pyramid scale of a single generator on two images.

    The scale trained is ``len(Gs)``. The targets are ``reals1[len(Gs)]`` and
    ``reals2[len(Gs)]``, assumed to have the same size. With
    ``opt.replace_background``, image 2 is first composited over a background
    derived from image 1.

    Each fake comes from ``netG`` fed with noise built by
    ``functions.merge_noise_vectors``: image-1 noise merged with zeros for
    fake 1, and zeros merged with image-2 noise for fake 2. ``netD`` is
    trained on both images. When ``opt.enable_mask`` is set, ``netD_mask1`` /
    ``netD_mask2`` are also trained on the masked fakes returned by
    ``_generate_fake``.

    Side effects: writes ``opt.nzx``, ``opt.nzy`` and ``opt.receptive_field``.
    It saves images, noise maps and learning curves under ``opt.outf``, and
    calls ``functions.save_networks`` at the end.

    Returns:
        ``((z_opt1, z_opt2), (in_s1, in_s2), netG)``.
    """
    real1 = reals1[len(Gs)]
    real2 = reals2[len(Gs)]
    if opt.replace_background:
        background_real1 = create_background(functions.convert_image_np(real1))
        real2 = create_img_over_background(functions.convert_image_np(real2), background_real1)
        plt.imsave('%s/background_real_scale1.png' % (opt.outf), background_real1, vmin=0, vmax=1)
        plt.imsave('%s/real_scale2_new.png' % (opt.outf), real2, vmin=0, vmax=1)
        real2 = functions.np2torch(real2, opt)

    # assumption: the images are the same size
    opt.nzx = real1.shape[2]  # +(opt.ker_size-1)*(opt.num_layer)
    opt.nzy = real1.shape[3]  # +(opt.ker_size-1)*(opt.num_layer)
    opt.receptive_field = opt.ker_size + ((opt.ker_size-1)*(opt.num_layer-1))*opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)

    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))

    alpha = opt.alpha

    # Only the shape of this noise is used; z_opt starts as (padded) zeros.
    fixed_noise = functions.generate_noise([opt.nc_z,opt.nzx,opt.nzy],device=opt.device)
    z_opt = torch.full(fixed_noise.shape, 0, device=opt.device)
    z_opt = m_noise(z_opt)

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999), weight_decay=opt.weight_decay_d)
    optimizerD_masked1 = optim.Adam(netD_mask1.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999), weight_decay=opt.weight_decay_d_mask1)
    optimizerD_masked2 = optim.Adam(netD_mask2.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999), weight_decay=opt.weight_decay_d_mask2)
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999))

    def _cyclic_scheduler(optimizer, lr):
        # Cycles between lr * gamma and lr; half-cycle length is niter / 10 epochs.
        return torch.optim.lr_scheduler.CyclicLR(optimizer=optimizer, base_lr=lr*opt.gamma, max_lr=lr,
                                                 step_size_up=opt.niter/10, mode="triangular2",
                                                 cycle_momentum=False)

    def _step_scheduler(optimizer):
        return torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=[1600], gamma=opt.gamma)

    if opt.cyclic_lr:
        schedulerD = _cyclic_scheduler(optimizerD, opt.lr_d)
        schedulerD_masked1 = _cyclic_scheduler(optimizerD_masked1, opt.lr_d)
        schedulerD__masked2 = _cyclic_scheduler(optimizerD_masked2, opt.lr_d)
        # The generator cycles around its own learning rate (it previously
        # used the discriminator's lr_d bounds, overriding opt.lr_g).
        schedulerG = _cyclic_scheduler(optimizerG, opt.lr_g)
    else:
        schedulerD = _step_scheduler(optimizerD)
        schedulerD_masked1 = _step_scheduler(optimizerD_masked1)
        schedulerD__masked2 = _step_scheduler(optimizerD_masked2)
        schedulerG = _step_scheduler(optimizerG)

    discriminators = [netD, netD_mask1, netD_mask2]
    discriminators_optimizers = [optimizerD, optimizerD_masked1, optimizerD_masked2]
    discriminators_schedulers = [schedulerD, schedulerD_masked1, schedulerD__masked2]

    err_D_img1_2plot = []
    err_D_img2_2plot = []
    errG_total_loss_2plot = []
    errG_total_loss1_2plot = []
    errG_total_loss2_2plot = []
    errG_fake1_2plot = []
    errG_fake2_2plot = []
    D1_real2plot = []
    D2_real2plot = []
    D1_fake2plot = []
    D2_fake2plot = []
    reconstruction_loss1_2plot = []
    reconstruction_loss2_2plot = []

    for epoch in range(opt.niter):
        # We want a specific set of input noise maps that generates the
        # original image x: {z*, 0, 0, ..., 0}, where z* is a fixed noise map.
        # On the first scale we create that z* (aka z_opt); on other scales
        # z_opt is just zeros (initialized above).
        is_first_scale = len(Gs) == 0
        noise1_, z_opt1 = _create_noise_for_iteration(is_first_scale, m_noise, opt, z_opt, NoiseMode.Z1)
        noise2_, z_opt2 = _create_noise_for_iteration(is_first_scale, m_noise, opt, z_opt, NoiseMode.Z2)

        ############################
        # (1) Update D networks:
        # - netD: train with real on 2 input images (real1, real2)
        # - netD: train with fake on 2 fake images from different noise source (NoiseMode.Z1, NoiseMode.Z2)
        # if opt.enable_mask is ON, then:
        # - netD_mask1: train with real on real1
        # - netD_mask2: train with real on real2
        # - netD_mask1: train with fake on the generated fake image with mask1 applied on it
        # - netD_mask2: train with fake on the generated fake image with mask2 applied on it
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            for discriminator in discriminators:
                discriminator.zero_grad()

            errD_real1, D_x1 = discriminator_train_with_real(netD, opt, real1)
            errD_real2, D_x2 = discriminator_train_with_real(netD, opt, real2)

            # single discriminator for each image
            if opt.enable_mask:
                discriminator_train_with_real(netD_mask1, opt, real1)
                discriminator_train_with_real(netD_mask2, opt, real2)

            # train with fake
            in_s1, noise1, prev1, new_z_prev1 = _prepare_discriminator_train_with_fake_input(Gs, NoiseAmp, Zs, epoch, in_s1, is_first_scale, j, m_image, m_noise, noise1_, opt, real1, reals1, NoiseMode.Z1)
            in_s2, noise2, prev2, new_z_prev2 = _prepare_discriminator_train_with_fake_input(Gs, NoiseAmp, Zs, epoch, in_s2, is_first_scale, j, m_image, m_noise, noise2_, opt, real2, reals2, NoiseMode.Z2)
            if new_z_prev1 is not None:
                z_prev1 = new_z_prev1
            if new_z_prev2 is not None:
                z_prev2 = new_z_prev2

            # Z1 only:
            mixed_noise1 = functions.merge_noise_vectors(noise1, torch.zeros(noise1.shape, device=opt.device), opt.noise_vectors_merge_method)
            fake1_output = _generate_fake(netG, mixed_noise1, prev1)
            if opt.enable_mask:
                fake1, fake1_mask1, fake1_mask2 = fake1_output
            else:
                fake1 = fake1_output[0]
            D_G_z_1, errD_fake1, gradient_penalty1 = _train_discriminator_with_fake(netD, fake1, opt, real1)

            # Z2 only:
            mixed_noise2 = functions.merge_noise_vectors(torch.zeros(noise2.shape, device=opt.device), noise2, opt.noise_vectors_merge_method)
            fake2_output = _generate_fake(netG, mixed_noise2, prev2)
            if opt.enable_mask:
                fake2, fake2_mask1, fake2_mask2 = fake2_output
            else:
                fake2 = fake2_output[0]
            D_G_z_2, errD_fake2, gradient_penalty2 = _train_discriminator_with_fake(netD, fake2, opt, real2)

            if opt.enable_mask:
                _train_discriminator_with_fake(netD_mask1, fake1_mask1, opt, real1)
                _train_discriminator_with_fake(netD_mask2, fake2_mask2, opt, real2)

            errD_image1 = errD_real1 + errD_fake1 + gradient_penalty1
            errD_image2 = errD_real2 + errD_fake2 + gradient_penalty2

            for discriminator_optimizer in discriminators_optimizers:
                discriminator_optimizer.step()

        err_D_img1_2plot.append(errD_image1.detach())
        err_D_img2_2plot.append(errD_image2.detach())

        ############################
        # (2) Update G network:
        # - netG: train with fake on 2 fake images from different noise source (NoiseMode.Z1, NoiseMode.Z2) against netD
        # - netG: reconstruction loss against 2 real images
        # if opt.enable_mask is ON, then:
        # - netD_mask1: train with fake on the generated fake image with mask1 applied on it against netD_mask1
        # - netD_mask2: train with fake on the generated fake image with mask2 applied on it against netD_mask2
        ###########################
        for j in range(opt.Gsteps):
            netG.zero_grad()
            errG_fake1, D_fake1_map = _generator_train_with_fake(fake1, netD)
            errG_fake2, D_fake2_map = _generator_train_with_fake(fake2, netD)

            rec_loss1, Z_opt1 = _reconstruction_loss(alpha, netG, opt, z_opt1, z_prev1, real1, NoiseMode.Z1, opt.noise_amp1)
            rec_loss2, Z_opt2 = _reconstruction_loss(alpha, netG, opt, z_opt2, z_prev2, real2, NoiseMode.Z2, opt.noise_amp2)

            if opt.enable_mask:
                mask_loss_fake1_mask1, D_mask1_fake1_mask1_map = _generator_train_with_fake(fake1_mask1, netD_mask1)
                mask_loss_fake2_mask1, D_mask1_fake2_mask1_map = _generator_train_with_fake(fake2_mask1, netD_mask1)
                mask_loss_fake1_mask2, D_mask2_fake1_mask2_map = _generator_train_with_fake(fake1_mask2, netD_mask2)
                mask_loss_fake2_mask2, D_mask2_fake2_mask2_map = _generator_train_with_fake(fake2_mask2, netD_mask2)

            optimizerG.step()

        errG_total_loss1_2plot.append(errG_fake1.detach()+rec_loss1)
        errG_total_loss2_2plot.append(errG_fake2.detach()+rec_loss2)
        errG_fake1_2plot.append(errG_fake1.detach())
        errG_fake2_2plot.append(errG_fake2.detach())
        G_total_loss = errG_fake1.detach()+rec_loss1 + errG_fake2.detach()+rec_loss2
        errG_total_loss_2plot.append(G_total_loss)
        D1_real2plot.append(D_x1)
        D2_real2plot.append(D_x2)
        D1_fake2plot.append(D_G_z_1)
        D2_fake2plot.append(D_G_z_2)
        reconstruction_loss1_2plot.append(rec_loss1)
        reconstruction_loss2_2plot.append(rec_loss2)

        if epoch % 25 == 0 or epoch == (opt.niter-1):
            logger.info('scale %d:[%d/%d]' % (len(Gs), epoch, opt.niter))

        if epoch % 500 == 0 or epoch == (opt.niter-1):
            plt.imsave('%s/fake_sample1.png' % (opt.outf), functions.convert_image_np(fake1.detach()), vmin=0, vmax=1)
            plt.imsave('%s/fake_sample2.png' % (opt.outf), functions.convert_image_np(fake2.detach()), vmin=0, vmax=1)
            plt.imsave('%s/G(z_opt1).png' % (opt.outf), functions.convert_image_np(netG(Z_opt1.detach(), z_prev1)[0].detach()), vmin=0, vmax=1)
            plt.imsave('%s/G(z_opt2).png' % (opt.outf), functions.convert_image_np(netG(Z_opt2.detach(), z_prev2)[0].detach()), vmin=0, vmax=1)
            torch.save(z_opt1, '%s/z_opt1.pth' % (opt.outf))
            torch.save(z_opt2, '%s/z_opt2.pth' % (opt.outf))

            if epoch == (opt.niter-1) and opt.enable_mask:
                plt.imsave('%s/fake1_mask1.png' % (opt.outf), functions.convert_image_np(fake1_mask1.detach()), vmin=0, vmax=1)
                plt.imsave('%s/fake2_mask1.png' % (opt.outf), functions.convert_image_np(fake2_mask1.detach()), vmin=0, vmax=1)
                plt.imsave('%s/fake1_mask2.png' % (opt.outf), functions.convert_image_np(fake1_mask2.detach()), vmin=0, vmax=1)
                plt.imsave('%s/fake2_mask2.png' % (opt.outf), functions.convert_image_np(fake2_mask2.detach()), vmin=0, vmax=1)
                _imsave_discriminator_map(D_fake1_map, "D_fake1_map", opt)
                _imsave_discriminator_map(D_fake2_map, "D_fake2_map", opt)
                _imsave_discriminator_map(D_mask1_fake1_mask1_map, "D_mask1_fake1_mask1_map", opt)
                _imsave_discriminator_map(D_mask1_fake2_mask1_map, "D_mask1_fake2_mask1_map", opt)
                _imsave_discriminator_map(D_mask2_fake1_mask2_map, "D_mask2_fake1_mask2_map", opt)
                _imsave_discriminator_map(D_mask2_fake2_mask2_map, "D_mask2_fake2_mask2_map", opt)

        for discriminator_scheduler in discriminators_schedulers:
            discriminator_scheduler.step()
        schedulerG.step()

    functions.save_networks(netG,netD, netD_mask1, netD_mask2,z_opt1, z_opt2,opt)
    functions.plot_learning_curves("G_loss", opt.niter, [errG_total_loss_2plot, errG_total_loss1_2plot, errG_total_loss2_2plot, errG_fake1_2plot, errG_fake2_2plot, reconstruction_loss1_2plot, reconstruction_loss2_2plot], ["G_total_loss", "G_total_loss1", "G_total_loss2", "G_fake1_loss", "G_fake2_loss", "G_recon_loss_1", "G_recon_loss_2"], opt.outf)
    d_plots = [err_D_img1_2plot, err_D_img2_2plot]
    d_labels = ["D1_total_loss", "D2_total_loss"]
    functions.plot_learning_curves("D_loss", opt.niter, d_plots, d_labels, opt.outf)
    functions.plot_learning_curves("G_vs_D_loss", opt.niter, [errG_total_loss_2plot, errG_total_loss1_2plot, errG_total_loss2_2plot, err_D_img1_2plot, err_D_img2_2plot], ["G_total_loss", "G_total_loss1", "G_total_loss2", "D1_total_loss", "D2_total_loss"], opt.outf)
    return (z_opt1, z_opt2), (in_s1, in_s2), netG
def invert_model(test_image, model_name, scales2invert=None, penalty=1e-3, show=True):
    """Invert a trained SinGAN pyramid: optimise per-scale noise maps that reproduce ``test_image``.

    ``test_image`` is an image array and ``model_name`` the name of the trained
    model (used as ``opt.input_name``). For each scale ``s <
    scales2invert``, a noise map ``z`` is optimised with RMSprop. The loss is
    ``mse(G(noise_amp * pad(z) + I_prev, I_prev), X_s) - penalty *
    mean(log N(z; 0, 1))``, where ``I_prev`` is the previous scale's output
    (upsampled) or zeros on the first scale.

    Args:
        test_image: array converted with ``functions.np2torch``.
        model_name: name of the trained model to load.
        scales2invert: number of scales to invert; defaults to ``opt.stop_scale + 1``.
        penalty: weight of the Gaussian log-prior term.
        show: forwarded to ``show_image``; when True the loss curve of every
            scale is also plotted.

    Returns:
        ``(noise_solutions, reconstructed_image, rec_error)``:
        ``noise_solutions`` holds the padded noise input of each inverted
        scale; ``reconstructed_image`` is what ``show_image`` returned for the
        last scale's output (``None`` if no scale was inverted); ``rec_error``
        is the detached MSE of the last optimisation step (``0`` if none).
    """
    Noise_Solutions = []

    parser = get_arguments()
    parser.add_argument('--input_dir', help='input image dir', default='Input/Images')
    parser.add_argument('--mode', default='RandomSamples')
    opt = parser.parse_args("")
    opt.input_name = model_name
    opt.reg = penalty

    if model_name == 'islands2_basis_2.jpg':  # HARDCODED
        opt.scale_factor = 0.6

    opt = functions.post_config(opt)

    ### Loading in Generators
    Gs, Zs, reals, NoiseAmp = functions.load_trained_pyramid(opt)
    for G in Gs:
        G = functions.reset_grads(G, False)
        G.eval()

    ### Loading in Ground Truth Test Images
    reals = []  # discard the training pyramid's real images
    real = functions.np2torch(test_image, opt)
    functions.adjust_scales2image(real, opt)

    real_ = functions.np2torch(test_image, opt)
    real = imresize(real_, opt.scale1, opt)
    reals = functions.creat_reals_pyramid(real, reals, opt)

    ### General Padding
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    m_noise = nn.ZeroPad2d(int(pad_noise))
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    m_image = nn.ZeroPad2d(int(pad_image))

    # Per-scale learning rates and iteration counts; scales beyond the end of a
    # table reuse its last entry instead of raising IndexError.
    LR = [2e-3, 2e-2, 2e-1, 2e-1, 2e-1, 2e-1, 2e-1, 2e-1, 2e-1, 2e-1, 2e-1]
    niter = [200, 400, 400, 400, 400, 400, 400, 400, 400, 400, 400, 400, 400, 400, 400]

    # prior distribution for the noise penalty
    pdf = torch.distributions.Normal(0, 1)
    alpha = opt.reg

    I_prev = None
    REC_ERROR = 0
    reconstructed_image = None

    if scales2invert is None:
        scales2invert = opt.stop_scale + 1

    for scale in range(scales2invert):
        # Get X, G
        X = reals[scale]
        G = Gs[scale]
        noise_amp = NoiseAmp[scale]

        # Defining Dimensions
        opt.nc_z = X.shape[1]
        opt.nzx = X.shape[2]
        opt.nzy = X.shape[3]

        # Defining Z
        if scale == 0:
            z_init = functions.generate_noise([1, opt.nzx, opt.nzy], device=opt.device)  # single channel, expanded to RGB below
        else:
            z_init = functions.generate_noise([3, opt.nzx, opt.nzy], device=opt.device)  # 3-channel noise
        # Leaf tensor on opt.device to optimise (replaces the CUDA-only Variable(z.cuda())).
        z_init = z_init.detach().to(opt.device).requires_grad_(True)

        # Building I_prev
        if I_prev is None:  # first scale: all zeros
            in_s = torch.full(reals[0].shape, 0, device=opt.device)
            I_prev = m_image(in_s)
        else:  # upsample the previous scale's output
            I_prev = imresize(I_prev, 1 / opt.scale_factor, opt)
            I_prev = m_image(I_prev)
            # Crop/resample to the padded noise size (X plus the padding on both
            # sides) so rounding in imresize cannot cause a size mismatch. This was
            # hard-coded as +10, which only holds when the padding is 5.
            padded_h = X.shape[2] + 2 * pad_image
            padded_w = X.shape[3] + 2 * pad_image
            I_prev = I_prev[:, :, 0:padded_h, 0:padded_w]
            I_prev = functions.upsampling(I_prev, padded_h, padded_w)

        Zoptimizer = torch.optim.RMSprop([z_init], lr=LR[min(scale, len(LR) - 1)])
        x_loss = []  # for plotting
        epochs = []  # for plotting

        for epoch in range(niter[min(scale, len(niter) - 1)]):  # gradient descent on z
            if scale == 0:
                noise_input = m_noise(z_init.expand(1, 3, opt.nzx, opt.nzy))  # expand and pad
            else:
                noise_input = m_noise(z_init)  # pad

            z_in = noise_amp * noise_input + I_prev
            G_z = G(z_in, I_prev)

            x_recLoss = F.mse_loss(G_z, X)  # reconstruction loss
            logProb = pdf.log_prob(z_init).mean()  # Gaussian log-prior (already a scalar)
            loss = x_recLoss - (alpha * logProb)

            Zoptimizer.zero_grad()
            loss.backward()
            Zoptimizer.step()

            x_loss.append(loss.item())
            epochs.append(epoch)

        # detached so the returned value does not keep the autograd graph alive
        REC_ERROR = x_recLoss.detach()

        if show:
            plt.plot(epochs, x_loss, label='x_loss')
            plt.legend()
            plt.show()

        # the final output of this scale seeds the next one
        I_prev = G_z.detach()

        _ = show_image(X, show, 'target')
        reconstructed_image = show_image(I_prev, show, 'output')
        _ = show_image(noise_input.detach().cpu(), show, 'noise')

        Noise_Solutions.append(noise_input.detach())

    return Noise_Solutions, reconstructed_image, REC_ERROR
def generate_gif(Gs,Zs,reals,NoiseAmp,opt,alpha=0.1,beta=0.9,start_scale=2,fps=10):
    """Animate a trained pyramid by random-walking its noise maps and save the result as a GIF.

    At every scale the noise follows a damped walk over 100 steps. Each step
    keeps ``beta`` of the previous step's displacement, adds ``1 - beta``
    fresh noise, and the result is pulled toward the reconstruction map
    ``Z_opt`` with weight ``alpha``. Scales with index below ``start_scale``
    ignore the walk and use ``Z_opt`` directly. The 100 frames of the last
    scale, converted to uint8 RGB, are written to
    ``<generate_dir2save(opt)>/start_scale=<s>/alpha=<a>_beta=<b>.gif``.
    """
    final_scale = len(Gs) - 1
    blank_input = torch.full(Zs[0].shape, 0, device=opt.device)
    image_pad = nn.ZeroPad2d(int(((opt.ker_size - 1) * opt.num_layer) / 2))

    def draw_noise(scale_idx, height, width):
        # The coarsest scale uses a single noise channel copied to the three image channels.
        if scale_idx == 0:
            mono = functions.generate_noise([1, height, width], device=opt.device)
            return mono.expand(1, 3, height, width)
        return functions.generate_noise([opt.nc_z, height, width], device=opt.device)

    def as_uint8_frame(image):
        # denormalise, drop the batch axis, move channels last, scale to 0..255
        array = functions.denorm(image).detach()
        array = array[0, :, :, :].cpu().numpy()
        array = array.transpose(1, 2, 0) * 255
        return array.astype(np.uint8)

    frames = []
    for scale_idx, (G, Z_opt, noise_amp, real) in enumerate(zip(Gs, Zs, NoiseAmp, reals)):
        height, width = Z_opt.shape[2], Z_opt.shape[3]
        previous_frames, frames = frames, []

        # walk state: the two most recent noise maps
        walk_last = 0.95*Z_opt + 0.05*draw_noise(scale_idx, height, width)
        walk_before = Z_opt

        for frame_idx in range(100):
            step = beta*(walk_last - walk_before) + (1 - beta)*draw_noise(scale_idx, height, width)
            walk_now = alpha*Z_opt + (1 - alpha)*(walk_last + step)
            walk_before, walk_last = walk_last, walk_now

            if previous_frames:
                upscaled = imresize(previous_frames[frame_idx], 1 / opt.scale_factor, opt)
                I_prev = image_pad(upscaled[:, :, 0:real.shape[2], 0:real.shape[3]])
            else:
                I_prev = blank_input

            z_used = Z_opt if scale_idx < start_scale else walk_now
            I_curr = G((noise_amp*z_used + I_prev).detach(), I_prev)

            frames.append(as_uint8_frame(I_curr) if scale_idx == final_scale else I_curr)

    dir2save = functions.generate_dir2save(opt)
    try:
        os.makedirs('%s/start_scale=%d' % (dir2save,start_scale) )
    except OSError:
        pass
    imageio.mimsave('%s/start_scale=%d/alpha=%f_beta=%f.gif' % (dir2save,start_scale,alpha,beta),frames,fps=fps)
def train_single_scale(netD, netG, reals, masks, Gs, Zs, in_s, NoiseAmp, opt, centers=None):
    """Train the generator/discriminator pair of scale ``len(Gs)`` with a mask.

    The discriminator's output map on the real image is multiplied by
    ``masks[len(Gs)]``, cropped by ``r`` pixels per side, and the reconstruction
    loss compares only masked pixels. Periodically saves sample images, ``z_opt``
    and pickled loss curves to ``opt.outf``, then saves the networks via
    ``functions.save_networks``.

    Args:
        netD, netG: discriminator and generator for this scale.
        reals, masks: per-scale real images and masks; index ``len(Gs)`` is used.
        Gs, Zs, NoiseAmp: already-trained coarser scales (forwarded to draw_concat).
        in_s: coarsest input image. It is replaced by zeros on the first step of
            the coarsest non-SR scale and otherwise passed through.
        opt: options namespace (read and also written: nzx, nzy, receptive_field,
            noise_amp).
        centers: colour centers for ``paint_train`` quantisation.

    Returns:
        (z_opt, in_s, netG)
    """
    real = reals[len(Gs)]
    mask = masks[len(Gs)]
    opt.nzx = real.shape[2]
    opt.nzy = real.shape[3]
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) * (opt.num_layer - 1)) * opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    if opt.mode == 'animation_train':
        opt.nzx = real.shape[2] + (opt.ker_size - 1) * (opt.num_layer)
        opt.nzy = real.shape[3] + (opt.ker_size - 1) * (opt.num_layer)
        pad_noise = 0
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))

    alpha = opt.alpha
    is_coarsest = not Gs and opt.mode != 'SR_train'

    # Decrease in output size due to the conv layers; needed to crop the mask
    # to the size of the discriminator's output map.
    # TODO: currently, calculation doesn't consider opt.stride
    if opt.ker_size % 2 == 0:
        r = opt.num_layer * (opt.ker_size - 1)
    else:
        r = opt.num_layer * (opt.ker_size - 1) / 2
    r = int(r)
    _, _, h, w = mask.size()
    discriminators_mask = mask.detach()[:, :, r:h - r, r:w - r][:, 0, :, :].unsqueeze(0)
    _, _, h, w = discriminators_mask.size()

    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy], device=opt.device)
    z_opt = torch.full(fixed_noise.shape, 0, device=opt.device)
    z_opt = m_noise(z_opt)

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD, milestones=[1600], gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG, milestones=[1600], gamma=opt.gamma)

    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    # norm[opt.norm] rescales the masked real score:
    # 0 -> plain mean, 1 -> mean * (map size / sum of mask values).
    norm = [1, (h * w) / discriminators_mask.sum().item()]

    plt.imsave('%s/mask.png' % (opt.outf), functions.convert_image_np(real * mask))

    # Target of the reconstruction loss. It is kept separate from `real`, which
    # must stay unmasked because it also feeds the discriminator and the
    # gradient penalty on every epoch.
    masked_real = real * mask

    for epoch in range(opt.niter):
        if is_coarsest:
            if epoch == 0:
                # Reconstruction noise of the coarsest scale is sampled once.
                z_opt = functions.generate_noise([1, opt.nzx, opt.nzy], device=opt.device)
                z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy], device=opt.device)
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy], device=opt.device)
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            netD.zero_grad()
            output = netD(real).to(opt.device)
            output = output * discriminators_mask
            D_real_map = output.detach()
            errD_real = -(output.mean()) * norm[opt.norm]
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()

            # train with fake
            if j == 0 and epoch == 0:
                # Initialise prev (random sample from the coarser scales) and
                # z_prev (their reconstruction).
                if is_coarsest:
                    prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=opt.device)
                    in_s = prev
                    prev = m_image(prev)
                    z_prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=opt.device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec', m_noise, m_image, opt)
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                prev = m_image(prev)

            if opt.mode == 'paint_train':
                prev = functions.quant2centers(prev, centers)
                plt.imsave('%s/prev.png' % (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)

            if is_coarsest:
                noise = noise_
            else:
                noise = opt.noise_amp * noise_ + prev

            fake = netG(noise.detach(), prev)
            output = netD(fake.detach())
            # NOTE(review): unlike the real score above, the fake score is not
            # multiplied by discriminators_mask / norm[opt.norm]; confirm intended.
            errD_fake = output.mean()
            errD_fake.backward(retain_graph=True)
            D_G_z = output.mean().item()

            gradient_penalty = functions.calc_gradient_penalty(
                netD, real, fake, opt.lambda_grad, opt.device, discriminators_mask * norm[opt.norm])
            gradient_penalty.backward()

            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        for j in range(opt.Gsteps):
            netG.zero_grad()
            output = netD(fake)
            D_fake_map = output.detach()
            errG = -output.mean()
            errG.backward(retain_graph=True)
            if alpha != 0:
                loss = nn.MSELoss()
                if opt.mode == 'paint_train':
                    z_prev = functions.quant2centers(z_prev, centers)
                    plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
                Z_opt = opt.noise_amp * z_opt + z_prev
                netG_out = netG(Z_opt.detach(), z_prev) * mask
                rec_loss = alpha * loss(netG_out, masked_real)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                rec_loss = 0

            optimizerG.step()

        errD2plot.append(errD.detach())
        errG2plot.append(errG.detach())
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        if epoch % 25 == 0 or epoch == (opt.niter - 1):
            print('scale %d:[%d/%d]' % (len(Gs), epoch, opt.niter))

        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            plt.imsave('%s/fake_sample.png' % (opt.outf), functions.convert_image_np(fake.detach()), vmin=0, vmax=1)
            plt.imsave('%s/G(z_opt).png' % (opt.outf),
                       functions.convert_image_np(netG(Z_opt.detach(), z_prev).detach()), vmin=0, vmax=1)
            plt.imsave('%s/D_fake.png' % (opt.outf), functions.convert_image_np(D_fake_map))
            plt.imsave('%s/D_real.png' % (opt.outf), functions.convert_image_np(D_real_map))
            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))
            for fname, series in (('D_real2plot', D_real2plot),
                                  ('D_fake2plot', D_fake2plot),
                                  ('errD2plot', errD2plot),
                                  ('errG2plot', errG2plot),
                                  ('z_opt2plot', z_opt2plot)):
                with open('%s/%s' % (opt.outf, fname), "wb") as fp:  # Pickling
                    pickle.dump(series, fp)

        schedulerD.step()
        schedulerG.step()

    functions.save_networks(netG, netD, z_opt, opt)
    return z_opt, in_s, netG
def train_single_scale(netD, netG, reals, Gs, Zs, in_s, NoiseAmp, opt, centers=None):
    """Train the generator/discriminator pair of scale ``len(Gs)`` (WGAN-GP + reconstruction).

    Prints progress (time since the previous report and peak CUDA memory) every
    100 epochs and on the last epoch. On the last epoch it also saves sample
    images and ``z_opt`` to ``opt.outf``. Finally it saves the networks via
    ``functions.save_networks``.

    Args:
        netD, netG: discriminator and generator for this scale.
        reals: per-scale real images; index ``len(Gs)`` is trained here.
        Gs, Zs, NoiseAmp: already-trained coarser scales (forwarded to draw_concat).
        in_s: coarsest input image. It is replaced by zeros on the first step of
            the coarsest non-SR scale and otherwise passed through.
        opt: options namespace (read and also written: nzx, nzy, receptive_field,
            noise_amp).
        centers: colour centers for ``paint_train`` quantisation.

    Returns:
        (z_opt, in_s, netG)
    """
    real = reals[len(Gs)]
    opt.nzx = real.shape[2]
    opt.nzy = real.shape[3]
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) * (opt.num_layer - 1)) * opt.stride
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    if opt.mode == 'animation_train':
        opt.nzx = real.shape[2] + (opt.ker_size - 1) * (opt.num_layer)
        opt.nzy = real.shape[3] + (opt.ker_size - 1) * (opt.num_layer)
        pad_noise = 0
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))

    alpha = opt.alpha
    is_coarsest = not Gs and opt.mode != 'SR_train'

    # z_opt: input noise for the generator's reconstruction loss
    # (zeros everywhere except at the coarsest scale, which resamples it below).
    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy], device=opt.device)
    z_opt = torch.full(fixed_noise.shape, 0, device=opt.device)
    z_opt = m_noise(z_opt)

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD, milestones=[1600], gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG, milestones=[1600], gamma=opt.gamma)

    # TODO: these plots are not visualized
    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    # Start of the current reporting window. It is set once here and reset after
    # each progress print, so the printed time covers all epochs since the
    # previous report (it used to be reset every epoch).
    start_time = time.time()
    for epoch in range(opt.niter):
        if is_coarsest:
            # Coarsest generator: z_opt is not zero-initialised here.
            # NOTE(review): z_opt is resampled on every epoch at this scale;
            # confirm that a changing reconstruction noise is intended.
            z_opt = functions.generate_noise([1, opt.nzx, opt.nzy], device=opt.device)
            z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            # noise_: input noise for the fake image shown to the discriminator
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy], device=opt.device)
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy], device=opt.device)
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            netD.zero_grad()
            output = netD(real).to(opt.device)
            errD_real = -output.mean()
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()

            # train with fake
            if j == 0 and epoch == 0:
                # Initialise prev (random sample from the coarser scales) and
                # z_prev (their reconstruction from the fixed noises).
                if is_coarsest:
                    prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=opt.device)
                    in_s = prev
                    prev = m_image(prev)
                    z_prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=opt.device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec', m_noise, m_image, opt)
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                prev = m_image(prev)

            if opt.mode == 'paint_train':
                prev = functions.quant2centers(prev, centers)
                plt.imsave('%s/prev.png' % (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)

            if is_coarsest:
                noise = noise_
            else:
                noise = opt.noise_amp * noise_ + prev

            fake = netG(noise.detach(), prev)
            output = netD(fake.detach())
            errD_fake = output.mean()
            errD_fake.backward(retain_graph=True)
            D_G_z = output.mean().item()

            gradient_penalty = functions.calc_gradient_penalty(
                netD, real, fake, opt.lambda_grad, opt.device)
            gradient_penalty.backward()

            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()

        errD2plot.append(errD.detach())

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        for j in range(opt.Gsteps):
            netG.zero_grad()
            output = netD(fake)
            errG = -output.mean()
            errG.backward(retain_graph=True)
            if alpha != 0:
                loss = nn.MSELoss()
                if opt.mode == 'paint_train':
                    z_prev = functions.quant2centers(z_prev, centers)
                    plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
                Z_opt = opt.noise_amp * z_opt + z_prev
                rec_loss = alpha * loss(netG(Z_opt.detach(), z_prev), real)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                rec_loss = 0

            optimizerG.step()

        errG2plot.append(errG.detach() + rec_loss)
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        if epoch > 1 and (epoch % 100 == 0 or epoch == (opt.niter - 1)):
            total_time = time.time() - start_time
            start_time = time.time()
            print('scale %d:[%d/%d], total time: %f' % (len(Gs), epoch, opt.niter, total_time))
            memory = torch.cuda.max_memory_allocated()
            print('allocated memory: %.03f GB' % (memory / (1024 * 1024 * 1024 * 1.0)))

        if epoch == (opt.niter - 1):
            plt.imsave('%s/fake_sample.png' % (opt.outf), functions.convert_image_np(fake.detach()), vmin=0, vmax=1)
            plt.imsave('%s/G(z_opt).png' % (opt.outf),
                       functions.convert_image_np(netG(Z_opt.detach(), z_prev).detach()), vmin=0, vmax=1)
            plt.imsave('%s/prev.png' % (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)
            plt.imsave('%s/prev_plus_noise.png' % (opt.outf), functions.convert_image_np(noise), vmin=0, vmax=1)
            plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))

        schedulerD.step()
        schedulerG.step()

    functions.save_networks(netG, netD, z_opt, opt)
    return z_opt, in_s, netG
def SinGAN_generate(Gs, Zs, reals, crops, masks, NoiseAmp, opt, in_s=None, scale_v=1, scale_h=1, n=0,
                    gen_start_scale=0, num_samples=20, mask_locs=None):
    """Generate ``num_samples`` composited fakes and save them as PNGs.

    For each sample:
    1. Build an "eye" mask with ``functions.generate_eye_mask``, optionally tinted
       by ``functions.get_eye_color``.
    2. Draw a random sample from the coarser scales with ``functions.draw_concat``.
    3. Run ``Gs[-1]`` on the input built by ``functions.make_input`` to obtain
       ``fake_background``.
    4. Composite it with ``functions.gen_fake``, either onto a random crop of
       ``reals[-1]`` (``opt.random_crop``) or onto the whole image.

    Results go to the ``fake``/``background``/``mask``/``eye`` (and, with random
    crops, ``full_fake``/``full_mask``) subfolders of the output directory.

    Args:
        Gs, Zs, reals, crops, masks, NoiseAmp: per-scale pyramids; the last entry is the finest scale.
        opt: options namespace. ``opt.eye_color`` is overwritten when ``opt.random_eye_color`` is set.
        in_s: coarsest input image; zeros of shape ``[1, nc_z, nzx, nzy]`` when None.
        scale_v, scale_h, n, gen_start_scale: unused here; kept for signature compatibility.
        num_samples: number of samples to generate.
        mask_locs: optional per-sample ``mask_loc`` values for ``functions.gen_fake``.
            When None, ``gen_fake`` is called without ``mask_loc`` (the old code
            raised TypeError by indexing None).
    """
    # NOTE(review): the finest generator is switched to train() mode for sampling; confirm intended.
    Gs[-1].train()
    if in_s is None:
        # Explicit float zeros: torch.full with an integer fill value gives int64.
        in_s = torch.zeros([1, opt.nc_z, opt.nzx, opt.nzy], device=opt.device)

    # Padding modules are loop-invariant.
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    m_noise = nn.ZeroPad2d(pad_noise)
    m_image = nn.ZeroPad2d(pad_image)

    for i in range(num_samples):
        eye = functions.generate_eye_mask(opt, masks[-1], 0)  # generate eye in random location
        eye_colored = eye.clone()
        if opt.random_eye_color:
            eye_color = functions.get_eye_color(reals[-1])
            opt.eye_color = eye_color
            for channel in range(3):
                eye_colored[:, channel, :, :] *= (eye_color[channel] / 255)

        noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy], device=opt.device)
        noise_ = m_noise(noise_)
        prev = functions.draw_concat(Gs, Zs, reals, crops, masks, eye_colored, NoiseAmp, in_s, 'rand',
                                     m_noise, m_image, opt)
        prev = m_image(prev)
        noise = opt.noise_amp * noise_ + prev
        G_input = functions.make_input(noise, masks[-1], eye_colored)
        fake_background = Gs[-1](G_input.detach(), prev)

        border = False  # TODO
        # Forward mask_loc only when the caller supplied per-sample locations.
        gen_kwargs = {} if mask_locs is None else {'mask_loc': mask_locs[i]}
        if opt.random_crop:
            crop_size = crops[-1].size()[2]
            crop, h_idx, w_idx = functions.random_crop(reals[-1], crop_size)
            I_curr, fake_ind, eye_ind = functions.gen_fake(
                crop, fake_background, masks[-1], eye, opt.eye_color, opt, border, **gen_kwargs)
            # Paste the composited crop (and its mask) back into full-size canvases.
            full_fake = reals[-1].clone()
            full_fake[:, :, h_idx:h_idx + crop_size, w_idx:w_idx + crop_size] = I_curr
            full_mask = torch.zeros_like(full_fake)
            full_mask[:, :, h_idx:h_idx + crop_size, w_idx:w_idx + crop_size] = fake_ind
        else:
            I_curr, fake_ind, eye_ind = functions.gen_fake(
                reals[-1], fake_background, masks[-1], eye, opt.eye_color, opt, border, **gen_kwargs)

        if opt.mode == 'train':
            dir2save = '%s/RandomSamples/%s/SinGAN/%s' % (opt.out, opt.input_name[:-4], opt.run_name)
        else:
            dir2save = functions.generate_dir2save(opt)

        # Create each subfolder independently. The old single try/except stopped
        # at the first existing folder and skipped the rest.
        subdirs = ["fake", "background", "mask", "eye"]
        if opt.random_crop:
            subdirs += ["full_fake", "full_mask"]
        for sub in subdirs:
            os.makedirs(dir2save + "/" + sub, exist_ok=True)

        if opt.mode not in ("harmonization", "editing", "SR", "paint2image"):
            plt.imsave('%s/%s/%d.png' % (dir2save, "fake", i), functions.convert_image_np(I_curr.detach()))
            plt.imsave('%s/%s/%d.png' % (dir2save, "background", i),
                       functions.convert_image_np(fake_background.detach()))
            plt.imsave('%s/%s/%d.png' % (dir2save, "mask", i), functions.convert_image_np(fake_ind.detach()))
            plt.imsave('%s/%s/%d.png' % (dir2save, "eye", i), functions.convert_image_np(eye_ind.detach()))
            if opt.random_crop:
                plt.imsave('%s/%s/%d.png' % (dir2save, "full_fake", i),
                           functions.convert_image_np(full_fake.detach()))
                plt.imsave('%s/%s/%d.png' % (dir2save, "full_mask", i),
                           functions.convert_image_np(full_mask.detach()))