def train(opt):
    """Train the two-domain model stage by stage, checkpointing after every stage.

    Both training images are read with ``functions.read_two_domains`` and turned into one
    image pyramid per domain. For every stage ``0 .. opt.stop_scale`` this function:

    * creates the stage output directory ``<opt.out_>/<stage>``;
    * for stages > 0, loads both discriminators from the previous stage's
      ``netD.pth`` / ``netD2.pth`` and calls ``init_next_stage()`` on both generators;
    * trains the stage with ``train_single_scale`` (logging to a per-stage SummaryWriter);
    * saves fixed noise, generator, pyramid and noise amplitudes for both domains.

    Side effects: sets ``opt.out_`` and ``opt.outf``; writes image, checkpoint and
    TensorBoard files to disk.
    """
    print("Training model with the following parameters:")
    print("\t number of stages: {}".format(opt.train_stages))
    print("\t number of concurrently trained stages: {}".format(opt.train_depth))
    print("\t learning rate scaling: {}".format(opt.lr_scale))
    print("\t non-linearity: {}".format(opt.activation))

    real, real2 = functions.read_two_domains(opt)
    real = functions.adjust_scales2image(real, opt)
    reals = functions.create_reals_pyramid(real, opt)
    real2 = functions.adjust_scales2image(real2, opt)
    reals2 = functions.create_reals_pyramid(real2, opt)

    generator, generator2 = init_G(opt)
    fixed_noise = []
    noise_amp = []
    fixed_noise2 = []
    noise_amp2 = []

    for scale_num in range(opt.stop_scale + 1):
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        # An existing directory is fine; any other OS error surfaces here instead of
        # failing later when the first file is written into the directory.
        os.makedirs(opt.outf, exist_ok=True)
        functions.save_image('{}/real_scale.jpg'.format(opt.outf), reals[scale_num])

        d_curr, d_curr2 = init_D(opt)
        if scale_num > 0:
            # Warm-start from the previous stage and grow both generators by one stage.
            d_curr.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))
            generator.init_next_stage()
            d_curr2.load_state_dict(torch.load('%s/%d/netD2.pth' % (opt.out_, scale_num - 1)))
            generator2.init_next_stage()

        writer = SummaryWriter(log_dir=opt.outf)
        try:
            (fixed_noise, noise_amp, generator, d_curr,
             fixed_noise2, noise_amp2, generator2, d_curr2) = train_single_scale(
                d_curr, generator, reals, fixed_noise, noise_amp,
                d_curr2, generator2, reals2, fixed_noise2, noise_amp2,
                opt, scale_num, writer)
        finally:
            # Close this stage's writer; previously only the last stage's writer was closed.
            writer.close()

        torch.save(fixed_noise, '%s/fixed_noise.pth' % (opt.out_))
        torch.save(generator, '%s/G.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(noise_amp, '%s/noise_amp.pth' % (opt.out_))
        torch.save(fixed_noise2, '%s/fixed_noise2.pth' % (opt.out_))
        torch.save(generator2, '%s/G2.pth' % (opt.out_))
        torch.save(reals2, '%s/reals2.pth' % (opt.out_))
        torch.save(noise_amp2, '%s/noise_amp2.pth' % (opt.out_))

        del d_curr, d_curr2
    return
def generate_samples(netG, reals_shapes, noise_amp, scale_w=1.0, scale_h=1.0, reconstruct=False, n=50):
    """Write either a reconstruction or ``n`` random samples of a trained generator to disk.

    This function also reads names that are not parameters (``opt``, ``dir2save``,
    ``fixed_noise``, ``reals``, ``real``, ``_name``, ``make_dir``); they must exist in the
    enclosing/module scope when it is called.

    Args:
        netG: generator, called as ``netG(noise, reals_shapes, noise_amp)``.
        reals_shapes: per-stage shapes; indices 2 and 3 are scaled for random samples.
        noise_amp: per-stage noise amplitudes forwarded to ``netG``.
        scale_w, scale_h: spatial scale factors for random samples (1.0 = native size).
        reconstruct: if true, reconstruct from ``fixed_noise`` and return the result.
        n: number of random samples to write.

    Returns:
        The reconstruction tensor when ``reconstruct`` is true, otherwise ``None``.
    """
    if reconstruct:
        recon = netG(fixed_noise, reals_shapes, noise_amp)
        mode = opt.train_mode
        if mode in ("generation", "retarget"):
            functions.save_image('{}/reconstruction.jpg'.format(dir2save), recon.detach())
            functions.save_image('{}/real_image.jpg'.format(dir2save), reals[-1].detach())
        elif mode in ("harmonization", "editing"):
            functions.save_image('{}/{}_wo_mask.jpg'.format(dir2save, _name), recon.detach())
            resized_real = imresize_to_shape(real, reals_shapes[-1][2:], opt)
            functions.save_image('{}/real_image.jpg'.format(dir2save), resized_real.detach())
        return recon

    native_scale = scale_w == 1. and scale_h == 1.
    if native_scale:
        out_dir = os.path.join(dir2save, "random_samples")
    else:
        # Rescale only the spatial dims (indices 2 and 3) of every stage shape.
        reals_shapes = [
            [shape[0], shape[1], int(shape[2] * scale_h), int(shape[3] * scale_w)]
            for shape in reals_shapes
        ]
        out_dir = os.path.join(
            dir2save, "random_samples_scale_h_{}_scale_w_{}".format(scale_h, scale_w))
    make_dir(out_dir)

    for idx in range(n):
        sample_noise = functions.sample_random_noise(opt.train_stages - 1, reals_shapes, opt)
        generated = netG(sample_noise, reals_shapes, noise_amp)
        functions.save_image('{}/gen_sample_{}.jpg'.format(out_dir, idx), generated.detach())
def generate_samples(netG, opt, depth, noise_amp, writer, reals, iter, n=25):
    """Draw ``n`` random samples at stage ``depth``, save them, and log an image grid.

    Each sample is written to ``<opt.out_>/gen_samples_stage_<depth>/gen_sample_<idx>.jpg``.
    In the grid logged to ``writer`` under ``gen_images_<depth>`` (at step ``iter``), the
    first tile is overwritten with ``reals[depth]``.

    Side effect: reassigns ``opt.out_`` from ``functions.generate_dir2save(opt)``.
    """
    opt.out_ = functions.generate_dir2save(opt)
    stage_dir = '{}/gen_samples_stage_{}'.format(opt.out_, depth)
    shapes = [img.shape for img in reals]
    try:
        os.makedirs(stage_dir)
    except OSError:
        pass  # best-effort: any OSError (e.g. directory already exists) is ignored

    collected = []
    with torch.no_grad():
        for idx in range(n):
            sample = netG(functions.sample_random_noise(depth, shapes, opt), shapes, noise_amp)
            collected.append(sample)
            functions.save_image('{}/gen_sample_{}.jpg'.format(stage_dir, idx), sample.detach())

        batch = torch.cat(collected, 0)
        batch[0] = reals[depth].squeeze()  # show the real image as the first tile
        grid = make_grid(batch, nrow=min(5, n), normalize=True)
        writer.add_image('gen_images_{}'.format(depth), grid, iter)
def _cyclic_param_groups(net, opt):
    """Freeze ``net``'s untrained lower body blocks and return its optimizer parameter groups.

    The last ``opt.train_depth`` blocks of ``net.body`` are trained. Block ``idx`` among them
    gets ``opt.lr_g * opt.lr_scale ** (k - 1 - idx)``, where ``k`` is the number of trained
    blocks, so the newest block uses ``opt.lr_g``. ``net.tail``, ``net.head2``, ``net.body2``
    and ``net.tail2`` are trained with ``opt.lr_g``. ``net.head`` is not added.
    """
    for frozen in net.body[:-opt.train_depth]:
        for param in frozen.parameters():
            param.requires_grad = False
    trained = net.body[-opt.train_depth:]
    groups = [{"params": block.parameters(),
               "lr": opt.lr_g * (opt.lr_scale ** (len(trained) - 1 - idx))}
              for idx, block in enumerate(trained)]
    for module in (net.tail, net.head2, net.body2, net.tail2):
        groups.append({"params": module.parameters(), "lr": opt.lr_g})
    return groups


def train_single_scale(netD, netG, reals, fixed_noise, noise_amp, netD2, netG2, reals2,
                       fixed_noise2, noise_amp2, opt, depth, writer, fakes, fakes2, in_s, in_s2):
    """Jointly train both generator/discriminator pairs of the two-domain model for one stage.

    From the calls below:

    * ``netD`` scores domain-2 images (``reals2``) against ``netG`` outputs;
    * ``netD2`` scores domain-1 images (``reals``) against ``netG2`` outputs.

    Generators are trained with an adversarial term plus ``TVLoss``, an L1 reconstruction
    term, and an L1 cycle term.

    The lists ``fixed_noise*``, ``noise_amp*``, ``fakes*`` and ``in_s*`` are mutated in place.
    Each of ``fixed_noise*`` and ``noise_amp*`` gains exactly one entry per call (stage).
    After the first iteration, the entry this stage added to ``fakes*`` / ``in_s*`` is
    replaced rather than appended again. ``writer`` is accepted but not used.

    Returns:
        ``(fixed_noise, noise_amp, netG, netD, fixed_noise2, noise_amp2, netG2, netD2,
        in_s, in_s2)``.

    NOTE(review): the two-domain ``train`` in this file calls ``train_single_scale`` with
    13 positional arguments and unpacks 8 results, while this variant takes 17 and returns
    10 -- confirm which variant is meant to be active.
    """
    reals_shapes = [real.shape for real in reals]
    real = reals[depth]
    real2 = reals2[depth]
    # Loss weights are fixed to 1 here; opt.lambda_idt / opt.lambda_cyc / opt.lambda_tv are not read.
    lambda_idt = 1
    lambda_cyc = 1
    lambda_tv = 1

    def _zeros(level):
        # Float zero image at pyramid level ``level``. torch.full(..., 0) would infer an
        # integer (int64) dtype for these image-like tensors.
        return torch.zeros([1, opt.nc_im, reals_shapes[level][2], reals_shapes[level][3]],
                           device=opt.device)

    # Fixed reconstruction noise for this stage, one per domain (3 channels, stage size).
    z_opt = functions.generate_noise(
        [3, reals_shapes[depth][2], reals_shapes[depth][3]], device=opt.device)
    z_opt2 = functions.generate_noise(
        [3, reals_shapes[depth][2], reals_shapes[depth][3]], device=opt.device)
    fixed_noise.append(z_opt.detach())
    fixed_noise2.append(z_opt2.detach())

    # One optimizer for both discriminators, one for both generators.
    optimizerD = optim.Adam(itertools.chain(netD.parameters(), netD2.parameters()),
                            lr=opt.lr_d, betas=(opt.beta1, 0.999))
    parameter_list = _cyclic_param_groups(netG, opt)
    parameter_list2 = _cyclic_param_groups(netG2, opt)
    optimizerG = optim.Adam(itertools.chain(parameter_list, parameter_list2),
                            lr=opt.lr_g, betas=(opt.beta1, 0.999))

    # Decay learning rates by opt.gamma at 80% of the iterations.
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(
        optimizer=optimizerD, milestones=[0.8 * opt.niter], gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(
        optimizer=optimizerG, milestones=[0.8 * opt.niter], gamma=opt.gamma)

    # Stateless criteria, built once instead of every iteration.
    mse = nn.MSELoss()
    l1 = nn.L1Loss()

    _iter = tqdm(range(opt.niter))
    loss_print = {}
    for iteration in _iter:
        _iter.set_description('stage [{}/{}]:'.format(depth, opt.stop_scale))

        ############################
        # (1) Update D networks: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            optimizerD.zero_grad()

            # train with real
            output = netD(real2).to(opt.device)
            errD_real = -output.mean()
            errD_real.backward(retain_graph=True)
            loss_print['errD_real'] = errD_real.item()

            output2 = netD2(real).to(opt.device)
            errD_real2 = -output2.mean()
            errD_real2.backward(retain_graph=True)
            loss_print['errD_real2'] = errD_real2.item()

            # Conditioning inputs for both directions: prev/c_prev come from cycle_rec,
            # z_prev from draw_concat (or zeros at the coarsest stage).
            if j == 0 and iteration == 0:
                if depth == 0:
                    noise_amp.append(1)
                    noise_amp2.append(1)
                    prev = _zeros(0)
                    in_s.append(prev)
                    prev2 = _zeros(0)
                    in_s2.append(prev2)
                    c_prev = _zeros(depth)
                    z_prev = _zeros(depth)
                    c_prev2 = _zeros(depth)
                    z_prev2 = _zeros(depth)
                else:
                    prev2, c_prev2 = cycle_rec(netG2, netG, fixed_noise2, reals2, noise_amp2,
                                               opt, depth, reals_shapes, in_s2)
                    prev, c_prev = cycle_rec(netG, netG2, fixed_noise, reals, noise_amp,
                                             opt, depth, reals_shapes, in_s)
                    z_prev2 = draw_concat(netG, reals2, 'rec', opt, depth, reals_shapes, in_s2)
                    z_prev = draw_concat(netG2, reals, 'rec', opt, depth, reals_shapes, in_s)
            else:
                # Leave out the entry for the current stage once more than one is present.
                if len(in_s2) > 1:
                    ins2_index = in_s2[:-1]
                    ins_index = in_s[:-1]
                else:
                    ins2_index = in_s2
                    ins_index = in_s
                prev2, c_prev2 = cycle_rec(netG2, netG, fixed_noise2, reals2, noise_amp2,
                                           opt, depth, reals_shapes, ins2_index)
                prev, c_prev = cycle_rec(netG, netG2, fixed_noise, reals, noise_amp,
                                         opt, depth, reals_shapes, ins_index)

            if j == 0 and depth > 0:
                # Re-estimate this stage's noise amplitude as 0.1 * RMSE of a reconstruction
                # from zero input at the current resolution.
                in_s_ = _zeros(depth)
                in_s2_ = _zeros(depth)
                if iteration == 0:
                    # Open this stage's slot once; later iterations overwrite it below.
                    # (Appending every iteration made noise_amp* grow by niter per stage.)
                    noise_amp.append(0)
                    noise_amp2.append(0)
                z_reconstruction = netG2(fixed_noise, in_s_, reals_shapes)
                z_reconstruction2 = netG(fixed_noise2, in_s2_, reals_shapes)
                if iteration != 0:
                    in_s.pop()
                    in_s2.pop()
                in_s.append(in_s_)
                in_s2.append(in_s2_)
                rmse = torch.sqrt(mse(z_reconstruction, real)).detach()
                rmse2 = torch.sqrt(mse(z_reconstruction2, real2)).detach()
                noise_amp[-1] = 0.1 * rmse  # hard-coded factor (opt.noise_amp_init not read)
                noise_amp2[-1] = 0.1 * rmse2

            noise = functions.sample_random_noise(reals, depth, reals_shapes, opt, noise_amp)
            noise2 = functions.sample_random_noise(reals2, depth, reals_shapes, opt, noise_amp2)

            # train with fake; only the last D step keeps the graph for the G update
            if j == opt.Dsteps - 1:
                fake = netG(noise, prev, reals_shapes)
                fake2 = netG2(noise2, prev2, reals_shapes)
            else:
                with torch.no_grad():
                    fake = netG(noise, prev, reals_shapes)
                    fake2 = netG2(noise2, prev2, reals_shapes)

            output = netD(fake.detach())
            errD_fake = output.mean()
            errD_fake.backward(retain_graph=True)
            loss_print['errD_fake'] = errD_fake.item()

            gradient_penalty = functions.calc_gradient_penalty(
                netD, real2, fake, opt.lambda_grad, opt.device)
            gradient_penalty.backward()
            loss_print['gradient_penalty'] = gradient_penalty.item()

            output2 = netD2(fake2.detach())
            errD_fake2 = output2.mean()
            errD_fake2.backward(retain_graph=True)
            loss_print['errD_fake2'] = errD_fake2.item()

            gradient_penalty2 = functions.calc_gradient_penalty(
                netD2, real, fake2, opt.lambda_grad, opt.device)
            gradient_penalty2.backward()
            loss_print['gradient_penalty2'] = gradient_penalty2.item()

            optimizerD.step()

        # After the first iteration, replace (not append) the fakes entry this stage added.
        if iteration != 0:
            fakes.pop()
            fakes2.pop()
        fakes.append(fake)
        fakes2.append(fake2)

        ############################
        # (2) Update G networks: maximize D(G(z)) + reconstruction + cycle terms
        ###########################
        optimizerG.zero_grad()
        loss_tv = TVLoss()

        output = netD(fake)
        errG = -output.mean() + lambda_tv * loss_tv(fake)
        errG.backward(retain_graph=True)
        loss_print['errG'] = errG.item()

        output2 = netD2(fake2)
        errG2 = -output2.mean() + lambda_tv * loss_tv(fake2)
        errG2.backward(retain_graph=True)
        loss_print['errG2'] = errG2.item()

        rec = netG(fixed_noise2, z_prev2, reals_shapes)
        rec_loss = lambda_idt * l1(rec, real2)
        rec_loss.backward(retain_graph=True)
        loss_print['rec_loss'] = rec_loss.item()

        cyc = netG(fakes2, c_prev2, reals_shapes)
        cyc_loss = lambda_cyc * l1(cyc, real2)
        cyc_loss.backward(retain_graph=True)
        loss_print['cyc_loss'] = cyc_loss.item()

        rec2 = netG2(fixed_noise, z_prev, reals_shapes)
        rec_loss2 = lambda_idt * l1(rec2, real)
        rec_loss2.backward(retain_graph=True)
        loss_print['rec_loss2'] = rec_loss2.item()

        cyc2 = netG2(fakes, c_prev, reals_shapes)
        cyc_loss2 = lambda_cyc * l1(cyc2, real)
        cyc_loss2.backward(retain_graph=True)
        loss_print['cyc_loss2'] = cyc_loss2.item()

        for _ in range(opt.Gsteps):
            optimizerG.step()

        ############################
        # (3) Log results
        ###########################
        if iteration % 500 == 0 or iteration == (opt.niter - 1):
            functions.save_image(
                '{}/fake_sample_{}.jpg'.format(opt.outf, iteration + 1), fake.detach())
            functions.save_image(
                '{}/fake_sample2_{}.jpg'.format(opt.outf, iteration + 1), fake2.detach())
            log = " Iteration [{}/{}]".format(iteration, opt.niter)
            for tag, value in loss_print.items():
                log += ", {}: {:.4f}".format(tag, value)
            print(log)

        schedulerD.step()
        schedulerG.step()

    functions.save_networks(netG, netD, z_opt, netG2, netD2, z_opt2, opt)
    return fixed_noise, noise_amp, netG, netD, fixed_noise2, noise_amp2, netG2, netD2, in_s, in_s2
functions.generate_noise([opt.nc_im, fixed_noise[0].shape[2], fixed_noise[0].shape[3]], device=opt.device) out = generate_samples(netG, reals_shapes, noise_amp, reconstruct=True) mask_file_name = '{}_mask{}'.format(opt.naive_img[:-4], opt.naive_img[-4:]) if os.path.exists(mask_file_name): mask = functions.read_image_dir(mask_file_name, opt) if mask.shape[3] != out.shape[3]: mask = imresize_to_shape(mask, [out.shape[2], out.shape[3]], opt) mask = functions.dilate_mask(mask, opt) out = (1 - mask) * reals[-1] + mask * out functions.save_image('{}/{}_w_mask.jpg'.format(dir2save, _name), out.detach()) else: print("Warning: mask {} not found.".format(mask_file_name)) print("Harmonization/Editing only performed without mask.") elif opt.train_mode == "animation": print("Generating GIFs...") for _start_scale in range(3): for _beta in range(80, 100, 5): functions.generate_gif(dir2save, netG, fixed_noise, reals, noise_amp, opt, alpha=0.1,
def train_single_scale(netD, netG, reals, fixed_noise, noise_amp, opt, depth, writer):
    """Train one stage of the single-image generator ``netG`` against discriminator ``netD``.

    Steps:

    * Append this stage's fixed reconstruction input to ``fixed_noise`` and its noise
      amplitude to ``noise_amp`` (both mutated in place).
    * Run ``opt.niter`` iterations of ``opt.Dsteps`` critic updates (WGAN-GP-style loss via
      ``functions.calc_gradient_penalty``) followed by ``opt.Gsteps`` generator steps on the
      adversarial loss plus ``opt.alpha`` times an MSE reconstruction loss.
    * Log scalars and images through ``writer`` along the way.

    Only the last ``opt.train_depth`` body blocks of ``netG`` (and ``head`` while
    ``depth < opt.train_depth``, plus ``tail``) are optimized; lower blocks are frozen.

    Returns:
        ``(fixed_noise, noise_amp, netG, netD)``.

    Raises:
        ValueError: at ``depth == 0`` when ``opt.train_mode`` is not "generation",
            "retarget" or "animation" (no reconstruction input is defined for it here).
    """
    reals_shapes = [real.shape for real in reals]
    real = reals[depth]
    alpha = opt.alpha

    ############################
    # define z_opt for training on reconstruction
    ###########################
    if depth == 0:
        if opt.train_mode == "generation" or opt.train_mode == "retarget":
            z_opt = reals[0]
        elif opt.train_mode == "animation":
            z_opt = functions.generate_noise(
                [opt.nc_im, reals_shapes[depth][2], reals_shapes[depth][3]],
                device=opt.device).detach()
        else:
            # Previously fell through and failed with an UnboundLocalError on z_opt.
            raise ValueError(
                "train_single_scale: unsupported train_mode {!r} at depth 0".format(opt.train_mode))
    elif opt.train_mode == "generation" or opt.train_mode == "animation":
        # Spatial size grows by 2 * opt.num_layer relative to the stage image.
        z_opt = functions.generate_noise(
            [opt.nfc, reals_shapes[depth][2] + opt.num_layer * 2,
             reals_shapes[depth][3] + opt.num_layer * 2],
            device=opt.device)
    else:
        z_opt = functions.generate_noise(
            [opt.nfc, reals_shapes[depth][2], reals_shapes[depth][3]],
            device=opt.device).detach()
    fixed_noise.append(z_opt.detach())

    ############################
    # define optimizers, learning rate schedulers, and learning rates for lower stages
    ###########################
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999))

    # remove gradients from stages that are not trained
    for block in netG.body[:-opt.train_depth]:
        for param in block.parameters():
            param.requires_grad = False

    # lower trained stages get their learning rate scaled down by opt.lr_scale per level
    trained_blocks = netG.body[-opt.train_depth:]
    parameter_list = [{"params": block.parameters(),
                       "lr": opt.lr_g * (opt.lr_scale ** (len(trained_blocks) - 1 - idx))}
                      for idx, block in enumerate(trained_blocks)]
    # the head is trained only while it is still within the concurrently trained stages
    if depth - opt.train_depth < 0:
        parameter_list += [{"params": netG.head.parameters(), "lr": opt.lr_g * (opt.lr_scale ** depth)}]
    parameter_list += [{"params": netG.tail.parameters(), "lr": opt.lr_g}]
    optimizerG = optim.Adam(parameter_list, lr=opt.lr_g, betas=(opt.beta1, 0.999))

    schedulerD = torch.optim.lr_scheduler.MultiStepLR(
        optimizer=optimizerD, milestones=[0.8 * opt.niter], gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(
        optimizer=optimizerG, milestones=[0.8 * opt.niter], gamma=opt.gamma)

    ############################
    # calculate noise_amp: 1 at the coarsest stage, else opt.noise_amp_init * RMSE
    ###########################
    if depth == 0:
        noise_amp.append(1)
    else:
        noise_amp.append(0)
        z_reconstruction = netG(fixed_noise, reals_shapes, noise_amp)
        rec_mse = nn.MSELoss()(z_reconstruction, real)
        noise_amp[-1] = opt.noise_amp_init * torch.sqrt(rec_mse).detach()

    mse = nn.MSELoss()

    # start training
    _iter = tqdm(range(opt.niter))
    for iter in _iter:
        _iter.set_description('stage [{}/{}]:'.format(depth, opt.stop_scale))

        ############################
        # (0) sample noise for unconditional generation
        ###########################
        noise = functions.sample_random_noise(depth, reals_shapes, opt)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            netD.zero_grad()
            output = netD(real)
            errD_real = -output.mean()

            # only the last D step keeps the graph to G for the generator update
            if j == opt.Dsteps - 1:
                fake = netG(noise, reals_shapes, noise_amp)
            else:
                with torch.no_grad():
                    fake = netG(noise, reals_shapes, noise_amp)

            output = netD(fake.detach())
            errD_fake = output.mean()

            gradient_penalty = functions.calc_gradient_penalty(
                netD, real, fake, opt.lambda_grad, opt.device)
            errD_total = errD_real + errD_fake + gradient_penalty
            errD_total.backward()
            optimizerD.step()

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        output = netD(fake)
        errG = -output.mean()

        if alpha != 0:
            rec = netG(fixed_noise, reals_shapes, noise_amp)
            rec_loss = alpha * mse(rec, real)
        else:
            # No reconstruction term; keep names defined so logging below still works.
            rec = None
            rec_loss = 0

        # clear any G gradients left over from the D steps before the G backward pass
        netG.zero_grad()
        errG_total = errG + rec_loss
        errG_total.backward()
        for _ in range(opt.Gsteps):
            optimizerG.step()

        ############################
        # (3) Log Results
        ###########################
        if iter % 250 == 0 or iter + 1 == opt.niter:
            rec_value = rec_loss.item() if torch.is_tensor(rec_loss) else float(rec_loss)
            writer.add_scalar('Loss/train/D/real/{}'.format(j), -errD_real.item(), iter + 1)
            writer.add_scalar('Loss/train/D/fake/{}'.format(j), errD_fake.item(), iter + 1)
            writer.add_scalar('Loss/train/D/gradient_penalty/{}'.format(j), gradient_penalty.item(), iter + 1)
            writer.add_scalar('Loss/train/G/gen', errG.item(), iter + 1)
            writer.add_scalar('Loss/train/G/reconstruction', rec_value, iter + 1)
        if iter % 500 == 0 or iter + 1 == opt.niter:
            functions.save_image('{}/fake_sample_{}.jpg'.format(opt.outf, iter + 1), fake.detach())
            if rec is not None:
                functions.save_image('{}/reconstruction_{}.jpg'.format(opt.outf, iter + 1), rec.detach())
            generate_samples(netG, opt, depth, noise_amp, writer, reals, iter + 1)

        schedulerD.step()
        schedulerG.step()

    functions.save_networks(netG, netD, z_opt, opt)
    return fixed_noise, noise_amp, netG, netD
def generate_samples(netG, img_to_augment, naive_img, naive_img_large, aug, opt, depth,
                     noise_amp, writer, reals, iter, n=16):
    """Save harmonized/edited samples for stage ``depth`` and log an image grid.

    The coarsest (level-0) generator input depends on the mode:

    * fine-tuning: the naive image itself;
    * otherwise, harmonization: an ``aug.transform`` of ``img_to_augment``;
    * otherwise, editing: ``functions.shuffle_grid`` of ``img_to_augment``.

    In editing mode, ``opt.noise_scaling``-weighted noise is added to that input. Higher
    levels get fresh ``functions.generate_noise`` inputs.

    Images are written under ``<opt.out_>/harmonized_samples_stage_<depth>``. When a
    ``<naive>_mask<ext>`` file exists next to ``opt.naive_img``, a masked composite is saved
    as well.

    Side effect: reassigns ``opt.out_``.
    """
    opt.out_ = functions.generate_dir2save(opt)
    out_dir = '{}/harmonized_samples_stage_{}'.format(opt.out_, depth)
    shapes = [r.shape for r in reals]
    label = "harmonized" if opt.train_mode == "harmonization" else "edited"
    grid_images = []
    try:
        os.makedirs(out_dir)
    except OSError:
        pass  # best-effort: any OSError (e.g. directory already exists) is ignored

    # Fine-tuning renders a single sample; otherwise one slot is reserved for the naive image.
    if opt.fine_tune:
        n = 1
    elif naive_img is not None:
        n = n - 1

    def perturbed(base):
        # Editing mode: add opt.noise_scaling-weighted noise at the coarsest resolution.
        return base + opt.noise_scaling * functions.generate_noise(
            [opt.nc_im, shapes[0][2], shapes[0][3]], device=opt.device).detach()

    def level_noise(level):
        # Fresh noise input for a non-coarsest level.
        return functions.generate_noise(
            [opt.nfc, shapes[level][2], shapes[level][3]], device=opt.device).detach()

    with torch.no_grad():
        for idx in range(n):
            noise = []
            for d in range(depth + 1):
                if d > 0:
                    noise.append(level_noise(d))
                elif opt.train_mode == "harmonization":
                    if opt.fine_tune:
                        augmented_image = functions.np2torch(naive_img, opt)
                    else:
                        augmented = aug.transform(image=img_to_augment)
                        augmented_image = functions.np2torch(augmented["image"], opt)
                    noise.append(augmented_image)
                elif opt.train_mode == "editing":
                    if opt.fine_tune:
                        augmented_image = functions.np2torch(naive_img, opt)
                    else:
                        augmented_image = functions.np2torch(
                            functions.shuffle_grid(img_to_augment), opt)
                    noise.append(perturbed(augmented_image))

            sample = netG(noise, shapes, noise_amp)
            functions.save_image('{}/{}_naive_sample.jpg'.format(out_dir, idx), augmented_image)
            functions.save_image('{}/{}_{}_sample.jpg'.format(out_dir, idx, label), sample.detach())
            augmented_image = imresize_to_shape(augmented_image, sample.shape[2:], opt)
            grid_images.extend([augmented_image, sample.detach()])

        if opt.fine_tune:
            mask_file_name = '{}_mask{}'.format(opt.naive_img[:-4], opt.naive_img[-4:])
            naive_resized = imresize_to_shape(naive_img_large, sample.shape[2:], opt)
            if os.path.exists(mask_file_name):
                mask = get_mask(mask_file_name, naive_resized, opt)
                sample_w_mask = (1 - mask) * naive_resized + mask * sample.detach()
                functions.save_image(
                    '{}/{}_sample_w_mask_{}.jpg'.format(out_dir, label, iter), sample_w_mask.detach())
                panel = [naive_resized, sample.detach(), sample_w_mask]
            else:
                print("Warning: no mask with name {} exists for image {}".format(
                    mask_file_name, opt.input_name))
                print("Only showing results without mask.")
                panel = [naive_resized, sample.detach()]
            grid = make_grid(torch.cat(panel, 0), nrow=len(panel), normalize=True)
            writer.add_image('{}_images_{}'.format(label, depth), grid, iter)
            functions.save_image('{}/{}_sample_{}.jpg'.format(out_dir, label, iter), sample.detach())
            return

        if naive_img is not None:
            # Render the naive image itself and put it (and its result) first in the grid.
            noise = []
            for d in range(depth + 1):
                if d > 0:
                    noise.append(level_noise(d))
                elif opt.train_mode == "harmonization":
                    noise.append(functions.np2torch(naive_img, opt))
                elif opt.train_mode == "editing":
                    noise.append(perturbed(functions.np2torch(naive_img, opt)))
            sample = netG(noise, shapes, noise_amp)
            naive_resized = imresize_to_shape(naive_img_large, sample.shape[2:], opt)
            grid_images[:0] = [naive_resized, sample.detach()]
            functions.save_image('{}/{}_sample_{}.jpg'.format(out_dir, label, iter), sample.detach())

            mask_file_name = '{}_mask{}'.format(opt.naive_img[:-4], opt.naive_img[-4:])
            if os.path.exists(mask_file_name):
                mask = get_mask(mask_file_name, naive_resized, opt)
                sample_w_mask = (1 - mask) * naive_resized + mask * sample.detach()
                functions.save_image(
                    '{}/{}_sample_w_mask_{}.jpg'.format(out_dir, label, iter), sample_w_mask)

        grid = make_grid(torch.cat(grid_images, 0), nrow=4, normalize=True)
        writer.add_image('{}_images_{}'.format(label, depth), grid, iter)
def train(opt):
    """Train (or fine-tune) the harmonization/editing model stage by stage.

    Builds the image pyramid of the training image. If ``opt.naive_img`` is set, it loads the
    naive image at both the coarsest and the finest pyramid resolution.

    When ``opt.fine_tune`` is set:

    * the generator is grown to ``opt.train_stages`` stages and loaded from
      ``<opt.model_dir>/<opt.train_stages - 1>/netG.pth``;
    * every stage's discriminator is loaded from the matching ``netD.pth``.

    Otherwise, each stage > 0 warm-starts its discriminator from the previous stage and grows
    the generator.

    After every stage, fixed noise, generator, pyramid and noise amplitudes are saved under
    ``opt.out_``.

    Side effects: sets ``opt.out_``, ``opt.outf`` and, in editing mode,
    ``opt.noise_scaling``; writes image, checkpoint and TensorBoard files.
    """
    print("Training model with the following parameters:")
    print("\t number of stages: {}".format(opt.train_stages))
    print("\t number of concurrently trained stages: {}".format(opt.train_depth))
    print("\t learning rate scaling: {}".format(opt.lr_scale))
    print("\t non-linearity: {}".format(opt.activation))

    real = functions.read_image(opt)
    real = functions.adjust_scales2image(real, opt)
    reals = functions.create_reals_pyramid(real, opt)
    print("Training on image pyramid: {}".format([r.shape for r in reals]))
    print("")

    if opt.naive_img != "":
        naive_img = functions.read_image_dir(opt.naive_img, opt)
        naive_img_large = imresize_to_shape(naive_img, reals[-1].shape[2:], opt)
        naive_img = imresize_to_shape(naive_img, reals[0].shape[2:], opt)
        naive_img = functions.convert_image_np(naive_img) * 255.0
    else:
        naive_img = None
        naive_img_large = None

    if opt.fine_tune:
        img_to_augment = naive_img
    else:
        img_to_augment = functions.convert_image_np(reals[0]) * 255.0

    if opt.train_mode == "editing":
        opt.noise_scaling = 0.1

    # Load fine-tuning checkpoints onto the current GPU, or onto the CPU when CUDA is
    # unavailable (a hard-coded "cuda:N" target fails on CPU-only machines).
    map_location = ("cuda:{}".format(torch.cuda.current_device())
                    if torch.cuda.is_available() else "cpu")

    generator = init_G(opt)
    if opt.fine_tune:
        for _ in range(opt.train_stages - 1):
            generator.init_next_stage()
        generator.load_state_dict(
            torch.load('{}/{}/netG.pth'.format(opt.model_dir, opt.train_stages - 1),
                       map_location=map_location))

    fixed_noise = []
    noise_amp = []

    for scale_num in range(opt.start_scale, opt.train_stages):
        opt.out_ = functions.generate_dir2save(opt)
        opt.outf = '%s/%d' % (opt.out_, scale_num)
        # An existing directory is fine; any other OS error surfaces here instead of
        # failing later when the first file is written into the directory.
        os.makedirs(opt.outf, exist_ok=True)
        functions.save_image('{}/real_scale.jpg'.format(opt.outf), reals[scale_num])

        d_curr = init_D(opt)
        if opt.fine_tune:
            d_curr.load_state_dict(
                torch.load('{}/{}/netD.pth'.format(opt.model_dir, opt.train_stages - 1),
                           map_location=map_location))
        elif scale_num > 0:
            d_curr.load_state_dict(torch.load('%s/%d/netD.pth' % (opt.out_, scale_num - 1)))
            generator.init_next_stage()

        writer = SummaryWriter(log_dir=opt.outf)
        try:
            fixed_noise, noise_amp, generator, d_curr = train_single_scale(
                d_curr, generator, reals, img_to_augment, naive_img, naive_img_large,
                fixed_noise, noise_amp, opt, scale_num, writer)
        finally:
            # Close this stage's writer; previously only the last stage's writer was closed.
            writer.close()

        torch.save(fixed_noise, '%s/fixed_noise.pth' % (opt.out_))
        torch.save(generator, '%s/G.pth' % (opt.out_))
        torch.save(reals, '%s/reals.pth' % (opt.out_))
        torch.save(noise_amp, '%s/noise_amp.pth' % (opt.out_))
        del d_curr
    return
def _generator_param_groups(net, depth, opt):
    """Freeze the untrained stages of ``net`` and return its Adam param groups.

    Every block of ``net.body`` except the last ``opt.train_depth`` gets
    ``requires_grad = False``. Each remaining block gets learning rate
    ``opt.lr_g * opt.lr_scale ** k``, where ``k`` counts blocks back from the
    newest one (the newest block gets plain ``opt.lr_g``). ``net.head`` is
    added with ``opt.lr_g * opt.lr_scale ** depth`` while
    ``depth < opt.train_depth``. ``net.tail`` is always added with ``opt.lr_g``.
    """
    for block in net.body[:-opt.train_depth]:
        for param in block.parameters():
            param.requires_grad = False

    trained_blocks = net.body[-opt.train_depth:]
    n_trained = len(trained_blocks)
    groups = [
        {"params": block.parameters(),
         "lr": opt.lr_g * (opt.lr_scale ** (n_trained - 1 - idx))}
        for idx, block in enumerate(trained_blocks)
    ]
    if depth - opt.train_depth < 0:
        groups.append({"params": net.head.parameters(),
                       "lr": opt.lr_g * (opt.lr_scale ** depth)})
    groups.append({"params": net.tail.parameters(), "lr": opt.lr_g})
    return groups


def _append_noise_amp(net, fixed_noise, reals_shapes, noise_amp, real, depth, opt):
    """Append the noise amplitude for stage ``depth`` to ``noise_amp`` in place.

    Stage 0 uses amplitude 1. For later stages, the new amplitude is first
    set to 0 and ``net`` reconstructs from ``fixed_noise``. The amplitude then
    becomes ``opt.noise_amp_init * RMSE(reconstruction, real)``.
    """
    if depth == 0:
        noise_amp.append(1)
        return
    noise_amp.append(0)
    z_reconstruction = net(fixed_noise, reals_shapes, noise_amp)
    rmse = torch.sqrt(nn.MSELoss()(z_reconstruction, real)).detach()
    noise_amp[-1] = opt.noise_amp_init * rmse


def train_single_scale(netD, netG, reals, fixed_noise, noise_amp, netD2, netG2,
                       reals2, fixed_noise2, noise_amp2, opt, depth, writer):
    """Jointly train both generator/discriminator pairs on pyramid stage ``depth``.

    * ``netG`` is fed the domain-1 pyramid ``reals``. Its output is scored by
      ``netD`` against domain-2 images (``real2``).
    * ``netG2`` is fed the domain-2 pyramid ``reals2``. Its output is scored
      by ``netD2`` against domain-1 images (``real``).

    Both generators also get a TV loss (``opt.lambda_tv``), an identity loss
    (``opt.lambda_idt``) and a cycle loss (``opt.lambda_cyc``).

    Args:
        netD, netD2: Critics for domain 2 and domain 1 respectively.
        netG, netG2: Multi-stage generators (domain 1 -> 2, domain 2 -> 1).
        reals, reals2: Image pyramids of the two domains; index ``depth`` is
            the stage being trained.
        fixed_noise, fixed_noise2: Per-stage reconstruction inputs; one entry
            is appended for this stage.
        noise_amp, noise_amp2: Per-stage noise amplitudes; one entry is
            appended for this stage.
        opt: Training options namespace.
        depth: Index of the stage being trained.
        writer: Object with ``add_scalar`` used for loss logging.

    Returns:
        ``(fixed_noise, noise_amp, netG, netD,
        fixed_noise2, noise_amp2, netG2, netD2)``.

    Raises:
        ValueError: if ``depth == 0`` and ``opt.train_mode`` is not
            "generation", "retarget" or "animation".
    """
    reals_shapes = [r.shape for r in reals]
    real = reals[depth]
    reals_shapes2 = [r.shape for r in reals2]
    real2 = reals2[depth]
    lambda_idt = opt.lambda_idt
    lambda_cyc = opt.lambda_cyc
    lambda_tv = opt.lambda_tv

    ############################
    # define z_opt / z_opt2 (reconstruction inputs) for this stage
    ###########################
    if depth == 0:
        if opt.train_mode == "generation" or opt.train_mode == "retarget":
            z_opt = reals[0]
            z_opt2 = reals2[0]
        elif opt.train_mode == "animation":
            z_opt = functions.generate_noise(
                [opt.nc_im, reals_shapes[depth][2], reals_shapes[depth][3]],
                device=opt.device).detach()
            z_opt2 = functions.generate_noise(
                [opt.nc_im, reals_shapes2[depth][2], reals_shapes2[depth][3]],
                device=opt.device).detach()
        else:
            # Previously fell through and failed later with UnboundLocalError.
            raise ValueError(
                "unsupported train_mode at stage 0: {}".format(opt.train_mode))
    elif opt.train_mode == "generation" or opt.train_mode == "animation":
        # Run netG one stage deeper using a temporary noise map for this
        # stage; its output becomes this stage's reconstruction input.
        z_opt0 = functions.generate_noise(
            [opt.nfc,
             reals_shapes[depth][2] + opt.num_layer * 2,
             reals_shapes[depth][3] + opt.num_layer * 2],
            device=opt.device)
        noise_amp_f = [0.1] * 15
        # NOTE(review): these two generator calls differ from every other
        # call in this function (input, shapes, amplitudes), and netG2 is fed
        # the domain-1 pyramid; confirm against the generator's forward().
        z_opt = netG(reals, fixed_noise + [z_opt0.detach()], reals_shapes, noise_amp_f)
        z_opt2 = netG2(reals[1:], z_opt, noise_amp_f)
        # fixed_noise2 is deliberately not truncated here. Nothing temporary
        # was appended to it, so slicing would drop the previous stage's entry.
    else:
        z_opt = functions.generate_noise(
            [opt.nfc, reals_shapes[depth][2], reals_shapes[depth][3]],
            device=opt.device).detach()
        # Domain-2 counterpart (was missing, leaving z_opt2 unbound).
        z_opt2 = functions.generate_noise(
            [opt.nfc, reals_shapes2[depth][2], reals_shapes2[depth][3]],
            device=opt.device).detach()
    fixed_noise.append(z_opt.detach())
    fixed_noise2.append(z_opt2.detach())

    ############################
    # define optimizers, learning rate schedulers, and learning rates for lower stages
    ###########################
    optimizerD = optim.Adam(itertools.chain(netD.parameters(), netD2.parameters()),
                            lr=opt.lr_d, betas=(opt.beta1, 0.999))
    parameter_list = _generator_param_groups(netG, depth, opt)
    parameter_list2 = _generator_param_groups(netG2, depth, opt)
    optimizerG = optim.Adam(parameter_list + parameter_list2,
                            lr=opt.lr_g, betas=(opt.beta1, 0.999))

    # MultiStepLR compares milestones with the integer epoch counter; a
    # non-integral float milestone would never trigger.
    milestones = [int(0.8 * opt.niter)]
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(
        optimizer=optimizerD, milestones=milestones, gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(
        optimizer=optimizerG, milestones=milestones, gamma=opt.gamma)

    ############################
    # calculate noise_amp for both generators
    ###########################
    _append_noise_amp(netG, fixed_noise, reals_shapes, noise_amp, real, depth, opt)
    # noise_amp2 used to never grow, so netG2 got fewer amplitudes than stages.
    # NOTE(review): mirrored from the domain-1 computation; confirm the target.
    _append_noise_amp(netG2, fixed_noise2, reals_shapes2, noise_amp2, real2, depth, opt)

    tv_loss = TVLoss()
    l1_loss = nn.L1Loss()

    # start training (this loop had been commented out, so only one pass ran
    # and the builtin `iter` reached the `% 250` logging check -> TypeError)
    _iter = tqdm(range(opt.niter))
    for iteration in _iter:
        _iter.set_description('stage [{}/{}]:'.format(depth, opt.stop_scale))

        ############################
        # (1) Update D networks: netD judges domain 2, netD2 judges domain 1
        ###########################
        for j in range(opt.Dsteps):
            optimizerD.zero_grad()
            errD_real = -netD(real2).mean()
            # Was netD(real): netD2 must be the critic that sees real domain-1 images.
            errD_real2 = -netD2(real).mean()

            # Keep the generator graph only on the last D step (reused by G).
            if j == opt.Dsteps - 1:
                fake = netG(reals, reals_shapes, noise_amp, add_noise=True)
            else:
                with torch.no_grad():
                    fake = netG(reals, reals_shapes, noise_amp, add_noise=True)
            errD_fake = netD(fake.detach()).mean()
            gradient_penalty = functions.calc_gradient_penalty(
                netD, real2, fake, opt.lambda_grad, opt.device)

            if j == opt.Dsteps - 1:
                fake2 = netG2(reals2, reals_shapes2, noise_amp2, add_noise=True)
            else:
                with torch.no_grad():
                    fake2 = netG2(reals2, reals_shapes2, noise_amp2, add_noise=True)
            errD_fake2 = netD2(fake2.detach()).mean()
            gradient_penalty2 = functions.calc_gradient_penalty(
                netD2, real, fake2, opt.lambda_grad, opt.device)

            errD_total = (errD_real + errD_fake + gradient_penalty
                          + errD_real2 + errD_fake2 + gradient_penalty2)
            errD_total.backward()
            optimizerD.step()

        ############################
        # (2) Update G networks: adversarial + TV + identity + cycle losses
        ###########################
        optimizerG.zero_grad()
        errG = -netD(fake).mean() + lambda_tv * tv_loss(fake)
        errG2 = -netD2(fake2).mean() + lambda_tv * tv_loss(fake2)

        # These losses were previously .detach()ed, which removed them from the
        # gradient entirely; they now train the generators as weighted.
        rec = netG(real2, reals_shapes2, noise_amp2)
        rec_loss = lambda_idt * l1_loss(rec, real2)
        cyc = netG(fake2, reals_shapes2, noise_amp2)
        cyc_loss = lambda_cyc * l1_loss(cyc, real2)
        rec2 = netG2(real, reals_shapes, noise_amp)
        rec_loss2 = lambda_idt * l1_loss(rec2, real)
        cyc2 = netG2(fake, reals_shapes, noise_amp)
        cyc_loss2 = lambda_cyc * l1_loss(cyc2, real)

        errG_total = errG + rec_loss + errG2 + cyc_loss + cyc_loss2 + rec_loss2
        errG_total.backward()
        # Same gradients applied opt.Gsteps times.
        for _ in range(opt.Gsteps):
            optimizerG.step()

        ############################
        # (3) Log Results
        ###########################
        if iteration % 250 == 0 or iteration + 1 == opt.niter:
            writer.add_scalar('Loss/train/D/real/{}'.format(j), -errD_real.item(), iteration + 1)
            writer.add_scalar('Loss/train/D/fake/{}'.format(j), errD_fake.item(), iteration + 1)
            writer.add_scalar('Loss/train/D/gradient_penalty/{}'.format(j), gradient_penalty.item(), iteration + 1)
            writer.add_scalar('Loss/train/D/real2/{}'.format(j), -errD_real2.item(), iteration + 1)
            writer.add_scalar('Loss/train/D/fake2/{}'.format(j), errD_fake2.item(), iteration + 1)
            writer.add_scalar('Loss/train/D/gradient_penalty2/{}'.format(j), gradient_penalty2.item(), iteration + 1)
            writer.add_scalar('Loss/train/G/gen', errG.item(), iteration + 1)
            writer.add_scalar('Loss/train/G/reconstruction', rec_loss.item(), iteration + 1)
            writer.add_scalar('Loss/train/G/cycle', cyc_loss.item(), iteration + 1)
            writer.add_scalar('Loss/train/G/gen2', errG2.item(), iteration + 1)
            writer.add_scalar('Loss/train/G/reconstruction2', rec_loss2.item(), iteration + 1)
            writer.add_scalar('Loss/train/G/cycle2', cyc_loss2.item(), iteration + 1)
        if iteration % 500 == 0 or iteration + 1 == opt.niter:
            functions.save_image('{}/fake_sample_{}.jpg'.format(opt.outf, iteration + 1), fake.detach())
            functions.save_image('{}/reconstruction_{}.jpg'.format(opt.outf, iteration + 1), rec.detach())
            generate_samples(netG, opt, depth, noise_amp, writer, reals, iteration + 1)

        schedulerD.step()
        schedulerG.step()

    functions.save_networks(netG, netD, z_opt, netG2, netD2, z_opt2, opt)
    return fixed_noise, noise_amp, netG, netD, fixed_noise2, noise_amp2, netG2, netD2