def generate_samples(netG, reals_shapes, noise_amp, scale_w=1.0, scale_h=1.0, reconstruct=False, n=50):
    # Evaluation-time sampler. Relies on module-level globals set up by the
    # surrounding script: opt, dir2save, fixed_noise, reals, real, _name.
    if reconstruct:
        reconstruction = netG(fixed_noise, reals_shapes, noise_amp)
        if opt.train_mode == "generation" or opt.train_mode == "retarget":
            functions.save_image('{}/reconstruction.jpg'.format(dir2save), reconstruction.detach())
            functions.save_image('{}/real_image.jpg'.format(dir2save), reals[-1].detach())
        elif opt.train_mode == "harmonization" or opt.train_mode == "editing":
            functions.save_image('{}/{}_wo_mask.jpg'.format(dir2save, _name), reconstruction.detach())
            functions.save_image('{}/real_image.jpg'.format(dir2save),
                                 imresize_to_shape(real, reals_shapes[-1][2:], opt).detach())
        return reconstruction

    if scale_w == 1. and scale_h == 1.:
        dir2save_parent = os.path.join(dir2save, "random_samples")
    else:
        # rescale the per-stage shapes when sampling at a different aspect ratio
        reals_shapes = [[r_shape[0], r_shape[1], int(r_shape[2] * scale_h), int(r_shape[3] * scale_w)]
                        for r_shape in reals_shapes]
        dir2save_parent = os.path.join(dir2save,
                                       "random_samples_scale_h_{}_scale_w_{}".format(scale_h, scale_w))

    make_dir(dir2save_parent)

    for idx in range(n):
        noise = functions.sample_random_noise(opt.train_stages - 1, reals_shapes, opt)
        sample = netG(noise, reals_shapes, noise_amp)
        functions.save_image('{}/gen_sample_{}.jpg'.format(dir2save_parent, idx), sample.detach())
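# --- Hypothetical usage sketch (not part of the original project) -----------------
# generate_samples() above depends on module-level globals prepared by the
# evaluation script (`opt`, `dir2save`, `fixed_noise`, `reals`, `real`, `_name`,
# `make_dir`), so the helper below only illustrates the calling pattern; the
# helper name and argument values are illustrative assumptions.
def _demo_generate_samples(netG, reals, noise_amp):
    reals_shapes = [r.shape for r in reals]
    with torch.no_grad():
        # reconstruct the training image from the stored fixed noise maps
        reconstruction = generate_samples(netG, reals_shapes, noise_amp, reconstruct=True)
        # draw 10 unconditional samples at twice the training width
        generate_samples(netG, reals_shapes, noise_amp, scale_w=2.0, scale_h=1.0, n=10)
    return reconstruction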
def cycle_rec(netG, netG2, fixed_noise, reals, noise_amp, opt, depth, reals_shapes, in_s):
    # netG2,netG,fixed_noise2,reals2,noise_amp2,in_s2,opt, depth 2.1.1
    # Runs the A->B->A cycle through the stages trained so far: netG translates the
    # stage-wise inputs, netG2 (with its last stage removed) maps the result back.
    x_ab = in_s
    x_aba = in_s
    if depth > 0:
        netG_ = copy.deepcopy(netG)
        # netG_.body = netG_.body[:-1]
        netG2_ = copy.deepcopy(netG2)
        netG2_.body = netG2_.body[:-1]
        # for G,G2,Z_opt,real_curr,real_next,noise_amp in zip(Gs,Gs2,fixed_noise,reals,reals[1:],NoiseAmp):
        #     z = functions.generate_noise([3, fixed_noise[-2].shape[2], fixed_noise[-2].shape[3]], device=opt.device)
        #     z = z.expand(1, 3, z.shape[2], z.shape[3])  # opt.bsz
        #     z = m_noise(z)
        #     x_ab = x_ab[:,:,0:fixed_noise[-2].shape[2],0:fixed_noise[-2].shape[3]]
        #     x_ab = m_image(x_ab)
        #     z_in = z + reals[depth - 1]  # m_image(real_curr)
        x_ab_e = []
        for k in range(len(netG.body) - 1):
            # translate with the first k+1 stages of netG only
            z_in = functions.sample_random_noise(reals, k, reals_shapes, opt, noise_amp)
            netG_.body = netG.body[:k + 1]
            g_map = netG_(z_in, x_ab[k], reals_shapes)
            x_ab_e.append(g_map.detach())
            # x_ab = netG_(z_in, x_ab, reals_shapes)
        x_aba = netG2_(x_ab_e, x_aba[-1], reals_shapes)
        # x_ab = x_ab.detach()
        x_aba = x_aba.detach()
        if depth == 4:
            opt.scale_factor = 0.455
        if depth == 5:
            opt.scale_factor = 0.6
        # upsample to the current stage's resolution and crop to the real image size
        x_ab = imresize.imresize(x_ab_e[-1], 1 / opt.scale_factor, opt)  # detach
        x_aba = imresize.imresize(x_aba, 1 / opt.scale_factor, opt)
        x_ab = x_ab[:, :, 0:reals[depth].shape[2], 0:reals[depth].shape[3]]
        x_aba = x_aba[:, :, 0:reals[depth].shape[2], 0:reals[depth].shape[3]]
        # count += 1
        return x_ab, x_aba
    else:
        return x_ab[-1], x_aba[-1]
def generate_samples(netG, opt, depth, noise_amp, writer, reals, iter, n=25):
    opt.out_ = functions.generate_dir2save(opt)
    dir2save = '{}/gen_samples_stage_{}'.format(opt.out_, depth)
    reals_shapes = [r.shape for r in reals]
    all_images = []
    try:
        os.makedirs(dir2save)
    except OSError:
        pass
    with torch.no_grad():
        for idx in range(n):
            noise = functions.sample_random_noise(depth, reals_shapes, opt)
            sample = netG(noise, reals_shapes, noise_amp)
            all_images.append(sample)
            functions.save_image('{}/gen_sample_{}.jpg'.format(dir2save, idx), sample.detach())

        all_images = torch.cat(all_images, 0)
        all_images[0] = reals[depth].squeeze()
        grid = make_grid(all_images, nrow=min(5, n), normalize=True)
        writer.add_image('gen_images_{}'.format(depth), grid, iter)
def train_single_scale(netD, netG, reals, fixed_noise, noise_amp, netD2, netG2, reals2,
                       fixed_noise2, noise_amp2, opt, depth, writer, fakes, fakes2, in_s, in_s2):
    reals_shapes = [real.shape for real in reals]
    real = reals[depth]
    # reals_shapes2 = [real2.shape for real2 in reals2]
    real2 = reals2[depth]

    # alpha = opt.alpha
    # lambda_idt = opt.lambda_idt
    # lambda_cyc = opt.lambda_cyc
    # lambda_tv = opt.lambda_tv
    lambda_idt = 1
    lambda_cyc = 1
    lambda_tv = 1

    ############################
    # define z_opt for training on reconstruction
    ###########################
    z_opt = functions.generate_noise([3,  # opt.nfc
                                      reals_shapes[depth][2],
                                      reals_shapes[depth][3]],
                                     device=opt.device)
    z_opt2 = functions.generate_noise([3, reals_shapes[depth][2], reals_shapes[depth][3]],
                                      device=opt.device)
    fixed_noise.append(z_opt.detach())
    fixed_noise2.append(z_opt2.detach())

    ############################
    # define optimizers, learning rate schedulers, and learning rates for lower stages
    ###########################
    # setup optimizers for D (both discriminators share one optimizer)
    optimizerD = optim.Adam(itertools.chain(netD.parameters(), netD2.parameters()),
                            lr=opt.lr_d, betas=(opt.beta1, 0.999))

    # setup optimizers for G
    # remove gradients from stages that are not trained
    for block in netG.body[:-opt.train_depth]:
        for param in block.parameters():
            param.requires_grad = False

    # set different learning rate for lower stages
    parameter_list = [{"params": block.parameters(),
                       "lr": opt.lr_g * (opt.lr_scale ** (len(netG.body[-opt.train_depth:]) - 1 - idx))}
                      for idx, block in enumerate(netG.body[-opt.train_depth:])]

    # add parameters of head and tail to training
    # if depth - opt.train_depth < 0:
    #     parameter_list += [{"params": netG.head.parameters(), "lr": opt.lr_g * (opt.lr_scale**depth)}]
    parameter_list += [{"params": netG.tail.parameters(), "lr": opt.lr_g}]
    parameter_list += [{"params": netG.head2.parameters(), "lr": opt.lr_g}]
    parameter_list += [{"params": netG.body2.parameters(), "lr": opt.lr_g}]
    parameter_list += [{"params": netG.tail2.parameters(), "lr": opt.lr_g}]

    for block in netG2.body[:-opt.train_depth]:
        for param in block.parameters():
            param.requires_grad = False

    # set different learning rate for lower stages
    parameter_list2 = [{"params": block.parameters(),
                        "lr": opt.lr_g * (opt.lr_scale ** (len(netG2.body[-opt.train_depth:]) - 1 - idx))}
                       for idx, block in enumerate(netG2.body[-opt.train_depth:])]

    # add parameters of head and tail to training
    # if depth - opt.train_depth < 0:
    #     parameter_list2 += [{"params": netG2.head.parameters(), "lr": opt.lr_g * (opt.lr_scale**depth)}]
    parameter_list2 += [{"params": netG2.tail.parameters(), "lr": opt.lr_g}]
    parameter_list2 += [{"params": netG2.head2.parameters(), "lr": opt.lr_g}]
    parameter_list2 += [{"params": netG2.body2.parameters(), "lr": opt.lr_g}]
    parameter_list2 += [{"params": netG2.tail2.parameters(), "lr": opt.lr_g}]

    optimizerG = optim.Adam(itertools.chain(parameter_list, parameter_list2),
                            lr=opt.lr_g, betas=(opt.beta1, 0.999))

    # define learning rate schedules
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD,
                                                      milestones=[0.8 * opt.niter], gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG,
                                                      milestones=[0.8 * opt.niter], gamma=opt.gamma)

    ############################
    # calculate noise_amp  # netG(noise, prev, reals_shapes)[-1]
    ###########################
    # if depth == 0:
    #     noise_amp.append(1)
    # else:
    #     noise_amp.append(0)

    # start training
    _iter = tqdm(range(opt.niter))
    loss_print = {}
    for iter in _iter:
        _iter.set_description('stage [{}/{}]:'.format(depth, opt.stop_scale))

        ############################
        # (0) sample noise for unconditional generation
        ###########################
        # noise = functions.sample_random_noise(reals, depth, reals_shapes, opt, noise_amp)
        # noise2 = functions.sample_random_noise(reals2, depth, reals_shapes, opt, noise_amp)
        # 1.1.1 1.2.1 2.1.1 2.2.1

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            # netD.zero_grad()
            optimizerD.zero_grad()

            output = netD(real2).to(opt.device)
            errD_real = -output.mean()
            errD_real.backward(retain_graph=True)
            loss_print['errD_real'] = errD_real.item()

            output2 = netD2(real).to(opt.device)
            errD_real2 = -output2.mean()
            errD_real2.backward(retain_graph=True)
            loss_print['errD_real2'] = errD_real2.item()

            if (j == 0) & (iter == 0):
                if depth == 0:  # 1 opt.bsz 1.1.1
                    noise_amp.append(1)
                    noise_amp2.append(1)
                    prev = torch.full([1, opt.nc_im, reals_shapes[0][2], reals_shapes[0][3]],
                                      0, device=opt.device)
                    in_s.append(prev)
                    # in_s_ = prev
                    prev2 = torch.full([1, opt.nc_im, reals_shapes[0][2], reals_shapes[0][3]],
                                       0, device=opt.device)
                    in_s2.append(prev2)
                    # in_s2_ = prev2
                    c_prev = torch.full([1, opt.nc_im, reals_shapes[depth][2], reals_shapes[depth][3]],
                                        0, device=opt.device)
                    z_prev = torch.full([1, opt.nc_im, reals_shapes[depth][2], reals_shapes[depth][3]],
                                        0, device=opt.device)
                    c_prev2 = torch.full([1, opt.nc_im, reals_shapes[depth][2], reals_shapes[depth][3]],
                                         0, device=opt.device)
                    z_prev2 = torch.full([1, opt.nc_im, reals_shapes[depth][2], reals_shapes[depth][3]],
                                         0, device=opt.device)
                else:  # 2.1.1
                    # in_s2 = in_s2_
                    # in_s = in_s_
                    prev2, c_prev2 = cycle_rec(netG2, netG, fixed_noise2, reals2, noise_amp2,
                                               opt, depth, reals_shapes, in_s2)
                    prev, c_prev = cycle_rec(netG, netG2, fixed_noise, reals, noise_amp,
                                             opt, depth, reals_shapes, in_s)
                    z_prev2 = draw_concat(netG, reals2, 'rec', opt, depth, reals_shapes, in_s2)
                    z_prev = draw_concat(netG2, reals, 'rec', opt, depth, reals_shapes, in_s)
            else:  # 1.1.2 1.1.3 2.1.2 2.1.3 2.2.1
                if len(in_s2) > 1:
                    ins2_index = in_s2[:-1]
                    ins_index = in_s[:-1]
                else:
                    ins2_index = in_s2
                    ins_index = in_s
                prev2, c_prev2 = cycle_rec(netG2, netG, fixed_noise2, reals2, noise_amp2,
                                           opt, depth, reals_shapes, ins2_index)
                prev, c_prev = cycle_rec(netG, netG2, fixed_noise, reals, noise_amp,
                                         opt, depth, reals_shapes, ins_index)

            if j == 0:  # 1.1.1 1.2.1 2.1.1 2.2.1
                if depth > 0:
                    # refresh the current stage's noise amplitude from the reconstruction RMSE
                    in_s_ = torch.full([1, opt.nc_im, reals_shapes[depth][2], reals_shapes[depth][3]],
                                       0, device=opt.device)
                    in_s2_ = torch.full([1, opt.nc_im, reals_shapes[depth][2], reals_shapes[depth][3]],
                                        0, device=opt.device)
                    noise_amp.append(0)
                    noise_amp2.append(0)
                    z_reconstruction = netG2(fixed_noise, in_s_, reals_shapes)
                    z_reconstruction2 = netG(fixed_noise2, in_s2_, reals_shapes)
                    if iter != 0:
                        in_s.pop()
                        in_s2.pop()
                    in_s.append(in_s_)
                    in_s2.append(in_s2_)

                    criterion = nn.MSELoss()
                    rec_loss = criterion(z_reconstruction, real)
                    rec_loss2 = criterion(z_reconstruction2, real2)
                    RMSE = torch.sqrt(rec_loss).detach()
                    RMSE2 = torch.sqrt(rec_loss2).detach()
                    _noise_amp = 0.1 * RMSE  # opt.noise_amp_init
                    _noise_amp2 = 0.1 * RMSE2
                    noise_amp[-1] = _noise_amp
                    noise_amp2[-1] = _noise_amp2
                noise = functions.sample_random_noise(reals, depth, reals_shapes, opt, noise_amp)
                noise2 = functions.sample_random_noise(reals2, depth, reals_shapes, opt, noise_amp2)

            # train with fake
            if j == opt.Dsteps - 1:
                fake = netG(noise, prev, reals_shapes)
                fake2 = netG2(noise2, prev2, reals_shapes)
            else:
                with torch.no_grad():
                    fake = netG(noise, prev, reals_shapes)
                    fake2 = netG2(noise2, prev2, reals_shapes)

            output = netD(fake.detach())
            errD_fake = output.mean()
            errD_fake.backward(retain_graph=True)
            loss_print['errD_fake'] = errD_fake.item()

            gradient_penalty = functions.calc_gradient_penalty(netD, real2, fake, opt.lambda_grad, opt.device)
            gradient_penalty.backward()
            loss_print['gradient_penalty'] = gradient_penalty.item()

            output2 = netD2(fake2.detach())
            errD_fake2 = output2.mean()
            errD_fake2.backward(retain_graph=True)
            loss_print['errD_fake2'] = errD_fake2.item()

            gradient_penalty2 = functions.calc_gradient_penalty(netD2, real, fake2, opt.lambda_grad, opt.device)
            gradient_penalty2.backward()
            loss_print['gradient_penalty2'] = gradient_penalty2.item()

            optimizerD.step()

        # keep only the current stage's latest fake in the per-stage lists
        if iter != 0:
            fakes.pop()
            fakes2.pop()
        fakes.append(fake)
        fakes2.append(fake2)

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        optimizerG.zero_grad()
        loss_tv = TVLoss()

        output = netD(fake)
        errG = -output.mean() + lambda_tv * loss_tv(fake)
        errG.backward(retain_graph=True)
        loss_print['errG'] = errG.item()

        output2 = netD2(fake2)
        errG2 = -output2.mean() + lambda_tv * loss_tv(fake2)
        errG2.backward(retain_graph=True)
        loss_print['errG2'] = errG2.item()

        loss = nn.L1Loss()  # nn.MSELoss()

        # identity / reconstruction losses
        rec = netG(fixed_noise2, z_prev2, reals_shapes)
        rec_loss = lambda_idt * loss(rec, real2)
        rec_loss.backward(retain_graph=True)
        loss_print['rec_loss'] = rec_loss.item()
        rec_loss = rec_loss.detach()

        # cycle-consistency losses
        cyc = netG(fakes2, c_prev2, reals_shapes)
        cyc_loss = lambda_cyc * loss(cyc, real2)
        cyc_loss.backward(retain_graph=True)
        loss_print['cyc_loss'] = cyc_loss.item()
        cyc_loss = cyc_loss.detach()

        rec2 = netG2(fixed_noise, z_prev, reals_shapes)
        rec_loss2 = lambda_idt * loss(rec2, real)
        rec_loss2.backward(retain_graph=True)
        loss_print['rec_loss2'] = rec_loss2.item()
        rec_loss2 = rec_loss2.detach()

        cyc2 = netG2(fakes, c_prev, reals_shapes)
        cyc_loss2 = lambda_cyc * loss(cyc2, real)
        cyc_loss2.backward(retain_graph=True)
        loss_print['cyc_loss2'] = cyc_loss2.item()
        cyc_loss2 = cyc_loss2.detach()

        for _ in range(opt.Gsteps):  # opt.Gsteps
            optimizerG.step()

        ############################
        # (3) Log Results
        ###########################
        if iter % 500 == 0 or iter == (opt.niter - 1):
            functions.save_image('{}/fake_sample_{}.jpg'.format(opt.outf, iter + 1), fake.detach())
            functions.save_image('{}/fake_sample2_{}.jpg'.format(opt.outf, iter + 1), fake2.detach())
            # functions.save_image('{}/reconstruction_{}.jpg'.format(opt.outf, iter+1), rec.detach())
            # functions.save_image('{}/reconstruction2_{}.jpg'.format(opt.outf, iter+1), rec2.detach())
            # generate_samples(netG, opt, depth, noise_amp, writer, reals, iter+1)

        log = " Iteration [{}/{}]".format(iter, opt.niter)
        for tag, value in loss_print.items():
            log += ", {}: {:.4f}".format(tag, value)
        print(log)

        # if iter % 250 == 0 or iter+1 == opt.niter:
        #     writer.add_scalar('Loss/train/D/real/{}'.format(j), -errD_real.item(), iter+1)
        #     writer.add_scalar('Loss/train/D/fake/{}'.format(j), errD_fake.item(), iter+1)
        #     writer.add_scalar('Loss/train/D/gradient_penalty/{}'.format(j), gradient_penalty.item(), iter+1)
        #     writer.add_scalar('Loss/train/D/real2/{}'.format(j), -errD_real2.item(), iter+1)
        #     writer.add_scalar('Loss/train/D/fake2/{}'.format(j), errD_fake2.item(), iter+1)
        #     writer.add_scalar('Loss/train/D/gradient_penalty2/{}'.format(j), gradient_penalty2.item(), iter+1)
        #
        #     writer.add_scalar('Loss/train/G/gen', errG.item(), iter+1)
        #     writer.add_scalar('Loss/train/G/reconstruction', rec_loss.item(), iter+1)
        #     writer.add_scalar('Loss/train/G/cycle', cyc_loss.item(), iter+1)
        #     writer.add_scalar('Loss/train/G/gen2', errG2.item(), iter+1)
        #     writer.add_scalar('Loss/train/G/reconstruction2', rec_loss2.item(), iter+1)
        #     writer.add_scalar('Loss/train/G/cycle2', cyc_loss2.item(), iter+1)
        #
        # if iter % 500 == 0 or iter+1 == opt.niter:
        #     functions.save_image('{}/fake_sample_{}.jpg'.format(opt.outf, iter+1), fake.detach())
        #     functions.save_image('{}/reconstruction_{}.jpg'.format(opt.outf, iter+1), rec.detach())
        #     # generate_samples(netG, opt, depth, noise_amp, writer, reals, iter+1)

        schedulerD.step()
        schedulerG.step()

    functions.save_networks(netG, netD, z_opt, netG2, netD2, z_opt2, opt)
    return fixed_noise, noise_amp, netG, netD, fixed_noise2, noise_amp2, netG2, netD2, in_s, in_s2
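# --- Hypothetical driver loop (assumption, not shown in this file) ----------------
# Sketch of how the two-generator train_single_scale above is presumably called
# stage by stage, threading the per-stage state (fixed noise, noise amplitudes,
# in_s buffers, cached fakes) through successive depths. `init_models` and the
# exact stage count via `opt.stop_scale` are assumptions inferred from the
# function signature; the real driver may also grow the generators between stages.
def train_all_scales(opt, reals, reals2, writer):
    fixed_noise, noise_amp, in_s, fakes = [], [], [], []
    fixed_noise2, noise_amp2, in_s2, fakes2 = [], [], [], []
    netD, netG, netD2, netG2 = init_models(opt)  # hypothetical model factory
    for depth in range(opt.stop_scale + 1):
        (fixed_noise, noise_amp, netG, netD,
         fixed_noise2, noise_amp2, netG2, netD2,
         in_s, in_s2) = train_single_scale(netD, netG, reals, fixed_noise, noise_amp,
                                           netD2, netG2, reals2, fixed_noise2, noise_amp2,
                                           opt, depth, writer, fakes, fakes2, in_s, in_s2)
    return netG, netG2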
def train_single_scale(netD, netG, reals, fixed_noise, noise_amp, opt, depth, writer):
    reals_shapes = [real.shape for real in reals]
    real = reals[depth]
    alpha = opt.alpha

    ############################
    # define z_opt for training on reconstruction
    ###########################
    if depth == 0:
        if opt.train_mode == "generation" or opt.train_mode == "retarget":
            z_opt = reals[0]
        elif opt.train_mode == "animation":
            z_opt = functions.generate_noise([opt.nc_im, reals_shapes[depth][2], reals_shapes[depth][3]],
                                             device=opt.device).detach()
    else:
        if opt.train_mode == "generation" or opt.train_mode == "animation":
            z_opt = functions.generate_noise([opt.nfc,
                                              reals_shapes[depth][2] + opt.num_layer * 2,
                                              reals_shapes[depth][3] + opt.num_layer * 2],
                                             device=opt.device)
        else:
            z_opt = functions.generate_noise([opt.nfc, reals_shapes[depth][2], reals_shapes[depth][3]],
                                             device=opt.device).detach()
    fixed_noise.append(z_opt.detach())

    ############################
    # define optimizers, learning rate schedulers, and learning rates for lower stages
    ###########################
    # setup optimizers for D
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999))

    # setup optimizers for G
    # remove gradients from stages that are not trained
    for block in netG.body[:-opt.train_depth]:
        for param in block.parameters():
            param.requires_grad = False

    # set different learning rate for lower stages
    parameter_list = [{"params": block.parameters(),
                       "lr": opt.lr_g * (opt.lr_scale ** (len(netG.body[-opt.train_depth:]) - 1 - idx))}
                      for idx, block in enumerate(netG.body[-opt.train_depth:])]

    # add parameters of head and tail to training
    if depth - opt.train_depth < 0:
        parameter_list += [{"params": netG.head.parameters(), "lr": opt.lr_g * (opt.lr_scale ** depth)}]
    parameter_list += [{"params": netG.tail.parameters(), "lr": opt.lr_g}]

    optimizerG = optim.Adam(parameter_list, lr=opt.lr_g, betas=(opt.beta1, 0.999))

    # define learning rate schedules
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD,
                                                      milestones=[0.8 * opt.niter], gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG,
                                                      milestones=[0.8 * opt.niter], gamma=opt.gamma)

    ############################
    # calculate noise_amp
    ###########################
    if depth == 0:
        noise_amp.append(1)
    else:
        noise_amp.append(0)
        z_reconstruction = netG(fixed_noise, reals_shapes, noise_amp)

        criterion = nn.MSELoss()
        rec_loss = criterion(z_reconstruction, real)

        RMSE = torch.sqrt(rec_loss).detach()
        _noise_amp = opt.noise_amp_init * RMSE
        noise_amp[-1] = _noise_amp

    # start training
    _iter = tqdm(range(opt.niter))
    for iter in _iter:
        _iter.set_description('stage [{}/{}]:'.format(depth, opt.stop_scale))

        ############################
        # (0) sample noise for unconditional generation
        ###########################
        noise = functions.sample_random_noise(depth, reals_shapes, opt)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            netD.zero_grad()

            output = netD(real)
            errD_real = -output.mean()

            # train with fake
            if j == opt.Dsteps - 1:
                fake = netG(noise, reals_shapes, noise_amp)
            else:
                with torch.no_grad():
                    fake = netG(noise, reals_shapes, noise_amp)

            output = netD(fake.detach())
            errD_fake = output.mean()

            gradient_penalty = functions.calc_gradient_penalty(netD, real, fake, opt.lambda_grad, opt.device)

            errD_total = errD_real + errD_fake + gradient_penalty
            errD_total.backward()
            optimizerD.step()

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        output = netD(fake)
        errG = -output.mean()

        if alpha != 0:
            loss = nn.MSELoss()
            rec = netG(fixed_noise, reals_shapes, noise_amp)
            rec_loss = alpha * loss(rec, real)
        else:
            rec_loss = 0

        netG.zero_grad()
        errG_total = errG + rec_loss
        errG_total.backward()

        for _ in range(opt.Gsteps):
            optimizerG.step()

        ############################
        # (3) Log Results
        ###########################
        if iter % 250 == 0 or iter + 1 == opt.niter:
            writer.add_scalar('Loss/train/D/real/{}'.format(j), -errD_real.item(), iter + 1)
            writer.add_scalar('Loss/train/D/fake/{}'.format(j), errD_fake.item(), iter + 1)
            writer.add_scalar('Loss/train/D/gradient_penalty/{}'.format(j), gradient_penalty.item(), iter + 1)
            writer.add_scalar('Loss/train/G/gen', errG.item(), iter + 1)
            writer.add_scalar('Loss/train/G/reconstruction', rec_loss.item(), iter + 1)
        if iter % 500 == 0 or iter + 1 == opt.niter:
            functions.save_image('{}/fake_sample_{}.jpg'.format(opt.outf, iter + 1), fake.detach())
            functions.save_image('{}/reconstruction_{}.jpg'.format(opt.outf, iter + 1), rec.detach())
            generate_samples(netG, opt, depth, noise_amp, writer, reals, iter + 1)

        schedulerD.step()
        schedulerG.step()
        # break

    functions.save_networks(netG, netD, z_opt, opt)
    return fixed_noise, noise_amp, netG, netD
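# --- Reference sketch of the WGAN-GP penalty (assumption) -------------------------
# The discriminator updates above rely on functions.calc_gradient_penalty(netD,
# real, fake, lambda_grad, device), whose implementation is not shown in this
# file. The standard WGAN-GP formulation it is expected to compute, as used in the
# SinGAN family of repositories, looks roughly like this (torch is already
# imported at the top of the file):
def calc_gradient_penalty_sketch(netD, real_data, fake_data, lambda_grad, device):
    # random convex combination of real and fake samples
    alpha = torch.rand(1, 1, device=device).expand(real_data.size())
    interpolates = (alpha * real_data + (1 - alpha) * fake_data).requires_grad_(True)
    disc_interpolates = netD(interpolates)
    gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                    grad_outputs=torch.ones_like(disc_interpolates),
                                    create_graph=True, retain_graph=True, only_inputs=True)[0]
    # penalize deviation of the gradient norm from 1
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean() * lambda_grad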
def train_single_scale(netD, netG, reals, fixed_noise, noise_amp,
                       netD2, netG2, reals2, fixed_noise2, noise_amp2, opt, depth, writer):
    reals_shapes = [real.shape for real in reals]
    real = reals[depth]
    reals_shapes2 = [real2.shape for real2 in reals2]
    real2 = reals2[depth]

    # alpha = opt.alpha
    lambda_idt = opt.lambda_idt
    lambda_cyc = opt.lambda_cyc
    lambda_tv = opt.lambda_tv

    ############################
    # define z_opt for training on reconstruction
    ###########################
    if depth == 0:
        if opt.train_mode == "generation" or opt.train_mode == "retarget":
            z_opt = reals[0]
            z_opt2 = reals2[0]
        elif opt.train_mode == "animation":
            z_opt = functions.generate_noise([opt.nc_im, reals_shapes[depth][2], reals_shapes[depth][3]],
                                             device=opt.device).detach()
            z_opt2 = functions.generate_noise([opt.nc_im, reals_shapes2[depth][2], reals_shapes2[depth][3]],
                                              device=opt.device).detach()
    else:
        if opt.train_mode == "generation" or opt.train_mode == "animation":
            z_opt0 = functions.generate_noise([opt.nfc,
                                               reals_shapes[depth][2] + opt.num_layer * 2,
                                               reals_shapes[depth][3] + opt.num_layer * 2],
                                              device=opt.device)
            fixed_noise.append(z_opt0.detach())
            # fakes_shapes = [fake.shape for fake in fixed_noise]
            noise_amp_f = [0.1] * 15
            z_opt = netG(reals, fixed_noise, reals_shapes, noise_amp_f)
            fixed_noise = fixed_noise[:-1]

            z_opt02 = functions.generate_noise([opt.nfc,
                                                reals_shapes2[depth][2] + opt.num_layer * 2,
                                                reals_shapes2[depth][3] + opt.num_layer * 2],
                                               device=opt.device)
            # fixed_noise2.append(z_opt02.detach())
            # fakes_shapes2 = [fake2.shape for fake2 in fixed_noise2]
            z_opt2 = netG2(reals[1:], z_opt, noise_amp_f)
            fixed_noise2 = fixed_noise2[:-1]

            # criterion = nn.MSELoss()
            # rec_loss = criterion(z_opt1, real)
            #
            # RMSE = torch.sqrt(rec_loss).detach()
            # _noise_amp = opt.noise_amp_init * RMSE
            # noise_amp_f[-1] = _noise_amp
            # fixed_noise.pop()
        else:
            z_opt = functions.generate_noise([opt.nfc, reals_shapes[depth][2], reals_shapes[depth][3]],
                                             device=opt.device).detach()
            # not updated yet
    fixed_noise.append(z_opt.detach())
    fixed_noise2.append(z_opt2.detach())

    ############################
    # define optimizers, learning rate schedulers, and learning rates for lower stages
    ###########################
    # setup optimizers for D (both discriminators share one optimizer)
    optimizerD = optim.Adam(itertools.chain(netD.parameters(), netD2.parameters()),
                            lr=opt.lr_d, betas=(opt.beta1, 0.999))

    # setup optimizers for G
    # remove gradients from stages that are not trained
    for block in netG.body[:-opt.train_depth]:
        for param in block.parameters():
            param.requires_grad = False

    # set different learning rate for lower stages
    parameter_list = [{"params": block.parameters(),
                       "lr": opt.lr_g * (opt.lr_scale ** (len(netG.body[-opt.train_depth:]) - 1 - idx))}
                      for idx, block in enumerate(netG.body[-opt.train_depth:])]

    # add parameters of head and tail to training
    if depth - opt.train_depth < 0:
        parameter_list += [{"params": netG.head.parameters(), "lr": opt.lr_g * (opt.lr_scale ** depth)}]
    parameter_list += [{"params": netG.tail.parameters(), "lr": opt.lr_g}]

    for block in netG2.body[:-opt.train_depth]:
        for param in block.parameters():
            param.requires_grad = False

    # set different learning rate for lower stages
    parameter_list2 = [{"params": block.parameters(),
                        "lr": opt.lr_g * (opt.lr_scale ** (len(netG2.body[-opt.train_depth:]) - 1 - idx))}
                       for idx, block in enumerate(netG2.body[-opt.train_depth:])]

    # add parameters of head and tail to training
    if depth - opt.train_depth < 0:
        parameter_list2 += [{"params": netG2.head.parameters(), "lr": opt.lr_g * (opt.lr_scale ** depth)}]
    parameter_list2 += [{"params": netG2.tail.parameters(), "lr": opt.lr_g}]

    optimizerG = optim.Adam(itertools.chain(parameter_list, parameter_list2),
                            lr=opt.lr_g, betas=(opt.beta1, 0.999))

    # define learning rate schedules
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD,
                                                      milestones=[0.8 * opt.niter], gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG,
                                                      milestones=[0.8 * opt.niter], gamma=opt.gamma)

    ############################
    # calculate noise_amp
    ###########################
    if depth == 0:
        noise_amp.append(1)
    else:
        noise_amp.append(0)
        z_reconstruction = netG(fixed_noise, reals_shapes, noise_amp)

        criterion = nn.MSELoss()
        rec_loss = criterion(z_reconstruction, real)

        RMSE = torch.sqrt(rec_loss).detach()
        _noise_amp = opt.noise_amp_init * RMSE
        noise_amp[-1] = _noise_amp

    # start training
    _iter = tqdm(range(opt.niter))
    #
    for iter in _iter:
        _iter.set_description('stage [{}/{}]:'.format(depth, opt.stop_scale))

        ############################
        # (0) sample noise for unconditional generation
        ###########################
        noise = functions.sample_random_noise(depth, reals_shapes, opt)
        noise2 = functions.sample_random_noise(depth, reals_shapes, opt)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # # train with real
            # netD.zero_grad()
            optimizerD.zero_grad()

            output = netD(real2)
            errD_real = -output.mean()

            output2 = netD2(real)
            errD_real2 = -output2.mean()

            # train with fake
            if j == opt.Dsteps - 1:
                fake = netG(reals, reals_shapes, noise_amp, add_noise=True)  # noise + real image
                # fake2, _ = netG(noise2, reals_shapes2, noise_amp2)
            else:
                with torch.no_grad():
                    fake = netG(reals, reals_shapes, noise_amp, add_noise=True)
                    # fake2, _ = netG(noise2, reals_shapes2, noise_amp2)

            output = netD(fake.detach())
            errD_fake = output.mean()

            gradient_penalty = functions.calc_gradient_penalty(netD, real2, fake, opt.lambda_grad, opt.device)

            if j == opt.Dsteps - 1:
                fake2 = netG2(reals2, reals_shapes2, noise_amp2, add_noise=True)
            else:
                with torch.no_grad():
                    fake2 = netG2(reals2, reals_shapes2, noise_amp2, add_noise=True)

            output2 = netD2(fake2.detach())
            errD_fake2 = output2.mean()

            gradient_penalty2 = functions.calc_gradient_penalty(netD2, real, fake2, opt.lambda_grad, opt.device)

            errD_total = errD_real + errD_fake + gradient_penalty + errD_real2 + errD_fake2 + gradient_penalty2
            errD_total.backward()
            optimizerD.step()

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        optimizerG.zero_grad()
        loss_tv = TVLoss()

        output = netD(fake)
        errG = -output.mean() + lambda_tv * loss_tv(fake)

        output2 = netD2(fake2)
        errG2 = -output2.mean() + lambda_tv * loss_tv(fake2)

        loss = nn.L1Loss()  # nn.MSELoss()

        # identity and cycle terms; note that they are detached below,
        # so only errG/errG2 contribute gradients to errG_total here
        rec = netG(real2, reals_shapes2, noise_amp2)  # real
        rec_loss = lambda_idt * loss(rec, real2)
        rec_loss = rec_loss.detach()

        cyc = netG(fake2, reals_shapes2, noise_amp2)
        cyc_loss = lambda_cyc * loss(cyc, real2)
        cyc_loss = cyc_loss.detach()

        rec2 = netG2(real, reals_shapes, noise_amp)
        rec_loss2 = lambda_idt * loss(rec2, real)
        rec_loss2 = rec_loss2.detach()

        cyc2 = netG2(fake, reals_shapes, noise_amp)
        cyc_loss2 = lambda_cyc * loss(cyc2, real)
        cyc_loss2 = cyc_loss2.detach()

        errG_total = errG + rec_loss + errG2 + cyc_loss + cyc_loss2 + rec_loss2
        errG_total.backward()

        for _ in range(opt.Gsteps):  # opt.Gsteps
            optimizerG.step()

        ############################
        # (3) Log Results
        ###########################
        if iter % 250 == 0 or iter + 1 == opt.niter:
            writer.add_scalar('Loss/train/D/real/{}'.format(j), -errD_real.item(), iter + 1)
            writer.add_scalar('Loss/train/D/fake/{}'.format(j), errD_fake.item(), iter + 1)
            writer.add_scalar('Loss/train/D/gradient_penalty/{}'.format(j), gradient_penalty.item(), iter + 1)
            writer.add_scalar('Loss/train/D/real2/{}'.format(j), -errD_real2.item(), iter + 1)
            writer.add_scalar('Loss/train/D/fake2/{}'.format(j), errD_fake2.item(), iter + 1)
            writer.add_scalar('Loss/train/D/gradient_penalty2/{}'.format(j), gradient_penalty2.item(), iter + 1)
            writer.add_scalar('Loss/train/G/gen', errG.item(), iter + 1)
            writer.add_scalar('Loss/train/G/reconstruction', rec_loss.item(), iter + 1)
            writer.add_scalar('Loss/train/G/cycle', cyc_loss.item(), iter + 1)
            writer.add_scalar('Loss/train/G/gen2', errG2.item(), iter + 1)
            writer.add_scalar('Loss/train/G/reconstruction2', rec_loss2.item(), iter + 1)
            writer.add_scalar('Loss/train/G/cycle2', cyc_loss2.item(), iter + 1)
        if iter % 500 == 0 or iter + 1 == opt.niter:
            functions.save_image('{}/fake_sample_{}.jpg'.format(opt.outf, iter + 1), fake.detach())
            functions.save_image('{}/reconstruction_{}.jpg'.format(opt.outf, iter + 1), rec.detach())
            generate_samples(netG, opt, depth, noise_amp, writer, reals, iter + 1)

        schedulerD.step()
        schedulerG.step()
        # break

    functions.save_networks(netG, netD, z_opt, netG2, netD2, z_opt2, opt)
    return fixed_noise, noise_amp, netG, netD, fixed_noise2, noise_amp2, netG2, netD2
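# --- Minimal TVLoss sketch (assumption) --------------------------------------------
# The generator updates above apply TVLoss() without its definition appearing in
# this file. A common total-variation regularizer, which the project's TVLoss may
# or may not match exactly, penalizes differences between neighbouring pixels
# (nn is already imported at the top of the file):
class TVLossSketch(nn.Module):
    """Mean absolute difference between horizontally and vertically adjacent pixels."""

    def forward(self, x):
        # x: (N, C, H, W)
        dh = (x[:, :, 1:, :] - x[:, :, :-1, :]).abs().mean()
        dw = (x[:, :, :, 1:] - x[:, :, :, :-1]).abs().mean()
        return dh + dw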