D = [] for i in range(len(targets_dir)): tmpD = networks.discriminator(args.in_ndc, args.out_ndc, args.ndf) if args.latest_discriminator_model != '': if torch.cuda.is_available(): tmpD.load_state_dict( torch.load(targets_dir[i] + args.latest_discriminator_model)) else: tmpD.load_state_dict( torch.load(targets_dir[i] + args.latest_discriminator_model, map_location=lambda storage, loc: storage)) tmpD.to(device) tmpD.train() D.append(tmpD) VGG = networks.VGG19(init_weights=args.vgg_model, feature_mode=True) VGG.to(device) VGG.eval() print('---------- Networks initialized -------------') utils.print_network(G) utils.print_network(D[0]) utils.print_network(VGG) print('-----------------------------------------------') # loss BCE_loss = nn.BCELoss().to(device) L1_loss = nn.L1Loss().to(device) # Adam optimizer G_optimizer = optim.Adam(G.parameters(), lr=args.lrG,
def main():
    """Train the hint-conditioned anime-style generator.

    Pipeline: (1) optional VGG-feature reconstruction pre-training of the
    generator on landscape (color, hint, sketch) triples, (2) adversarial
    training against anime images with edge-smoothed negatives, saving
    sample images and checkpoints under ``<args.name>_results``.

    Relies on module-level ``args``, ``networks``, ``utils``, ``plt``,
    ``prepare_result``, ``make_edge_promoting_img``, ``mask_gen`` and the
    dataloader factories.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if torch.backends.cudnn.enabled:
        # Fixed input sizes -> let cuDNN benchmark and cache the fastest kernels.
        torch.backends.cudnn.benchmark = True
    prepare_result()
    make_edge_promoting_img()

    # Data loaders. Landscape loaders yield (color, hint, sketch) triples;
    # the anime loader yields (color, edge-smoothed, _) triples — TODO confirm
    # against CreateTrainDataLoader.
    landscape_dataloader = CreateTrainDataLoader(args, "landscape")
    anime_dataloader = CreateTrainDataLoader(args, "anime")
    landscape_test_dataloader = CreateTestDataLoader(args, "landscape")
    # NOTE(review): never iterated below; kept so loader-side effects (if any)
    # are preserved.
    anime_test_dataloader = CreateTestDataLoader(args, "anime")

    # Networks, optionally resumed from checkpoints.
    generator = networks.Generator(args.ngf)
    if args.latest_generator_model != '':
        if torch.cuda.is_available():
            generator.load_state_dict(torch.load(args.latest_generator_model))
        else:
            # CPU mode: remap CUDA-saved tensors onto CPU storage.
            generator.load_state_dict(
                torch.load(args.latest_generator_model,
                           map_location=lambda storage, loc: storage))

    discriminator = networks.Discriminator(args.in_ndc, args.out_ndc, args.ndf)
    if args.latest_discriminator_model != '':
        if torch.cuda.is_available():
            discriminator.load_state_dict(
                torch.load(args.latest_discriminator_model))
        else:
            discriminator.load_state_dict(
                torch.load(args.latest_discriminator_model,
                           map_location=lambda storage, loc: storage))

    # Frozen VGG19 used purely as a perceptual-feature extractor.
    VGG = networks.VGG19(init_weights=args.vgg_model, feature_mode=True)

    generator.to(device)
    discriminator.to(device)
    VGG.to(device)
    generator.train()
    discriminator.train()
    VGG.eval()

    G_optimizer = optim.Adam(generator.parameters(), lr=args.lrG,
                             betas=(args.beta1, args.beta2))
    D_optimizer = optim.Adam(discriminator.parameters(), lr=args.lrD,
                             betas=(args.beta1, args.beta2))

    print('---------- Networks initialized -------------')
    utils.print_network(generator)
    utils.print_network(discriminator)
    utils.print_network(VGG)
    print('-----------------------------------------------')

    # Losses. Adversarial loss is plain BCE (discriminator output is assumed
    # sigmoid-activated — TODO confirm in networks.Discriminator). Unused
    # HingeEmbedding/MSE loss instances from the original were removed.
    BCE_loss = nn.BCELoss().to(device)
    L1_loss = nn.L1Loss().to(device)
    Adv_loss = BCE_loss

    pre_train_hist = {}
    pre_train_hist['Recon_loss'] = []
    pre_train_hist['per_epoch_time'] = []
    pre_train_hist['total_time'] = []

    # ---- Pre-train: VGG-feature reconstruction only (skipped on resume) ----
    if args.latest_generator_model == '':
        print('Pre-training start!')
        start_time = time.time()
        for epoch in range(args.pre_train_epoch):
            epoch_start_time = time.time()
            Recon_losses = []
            for lcimg, lhimg, lsimg in landscape_dataloader:
                lcimg, lhimg, lsimg = (lcimg.to(device), lhimg.to(device),
                                       lsimg.to(device))
                # Train generator G on the perceptual reconstruction loss only.
                G_optimizer.zero_grad()
                x_feature = VGG((lcimg + 1) / 2)  # map [-1, 1] -> [0, 1] for VGG
                mask = mask_gen()
                # Masked color hints concatenated with the mask channel.
                hint = torch.cat((lhimg * mask, mask), 1)
                gen_img = generator(lsimg, hint)
                G_feature = VGG((gen_img + 1) / 2)
                Recon_loss = 10 * L1_loss(G_feature, x_feature.detach())
                Recon_losses.append(Recon_loss.item())
                pre_train_hist['Recon_loss'].append(Recon_loss.item())
                Recon_loss.backward()
                G_optimizer.step()
            per_epoch_time = time.time() - epoch_start_time
            pre_train_hist['per_epoch_time'].append(per_epoch_time)
            print('[%d/%d] - time: %.2f, Recon loss: %.3f' %
                  ((epoch + 1), args.pre_train_epoch, per_epoch_time,
                   torch.mean(torch.FloatTensor(Recon_losses))))
            # Save a few reconstruction samples every 5 epochs.
            if (epoch + 1) % 5 == 0:
                with torch.no_grad():
                    generator.eval()
                    for n, (lcimg, lhimg,
                            lsimg) in enumerate(landscape_dataloader):
                        lcimg, lhimg, lsimg = (lcimg.to(device),
                                               lhimg.to(device),
                                               lsimg.to(device))
                        mask = mask_gen()
                        hint = torch.cat((lhimg * mask, mask), 1)
                        g_recon = generator(lsimg, hint)
                        result = torch.cat((lcimg[0], g_recon[0]), 2)
                        path = os.path.join(
                            args.name + '_results', 'Reconstruction',
                            args.name + '_train_recon_' + f'epoch_{epoch}_' +
                            str(n + 1) + '.png')
                        plt.imsave(
                            path,
                            (result.cpu().numpy().transpose(1, 2, 0) + 1) / 2)
                        if n == 4:
                            break
                    for n, (lcimg, lhimg,
                            lsimg) in enumerate(landscape_test_dataloader):
                        lcimg, lhimg, lsimg = (lcimg.to(device),
                                               lhimg.to(device),
                                               lsimg.to(device))
                        mask = mask_gen()
                        hint = torch.cat((lhimg * mask, mask), 1)
                        g_recon = generator(lsimg, hint)
                        result = torch.cat((lcimg[0], g_recon[0]), 2)
                        path = os.path.join(
                            args.name + '_results', 'Reconstruction',
                            args.name + '_test_recon_' + f'epoch_{epoch}_' +
                            str(n + 1) + '.png')
                        plt.imsave(
                            path,
                            (result.cpu().numpy().transpose(1, 2, 0) + 1) / 2)
                        if n == 4:
                            break
                # BUG FIX: restore train mode after sampling. The original left
                # the generator in eval mode, so every pre-train epoch after the
                # first sampling ran with BatchNorm/Dropout frozen.
                generator.train()
        total_time = time.time() - start_time
        pre_train_hist['total_time'].append(total_time)
        with open(os.path.join(args.name + '_results', 'pre_train_hist.pkl'),
                  'wb') as f:
            pickle.dump(pre_train_hist, f)
        torch.save(
            generator.state_dict(),
            os.path.join(args.name + '_results', 'generator_pretrain.pkl'))
    else:
        print('Load the latest generator model, no need to pre-train')

    # ---- Adversarial training ----
    train_hist = {}
    train_hist['Disc_loss'] = []
    train_hist['Gen_loss'] = []
    train_hist['Con_loss'] = []
    train_hist['per_epoch_time'] = []
    train_hist['total_time'] = []

    print('training start!')
    start_time = time.time()
    # BCE targets; discriminator output is assumed (B, 1, H/4, W/4) —
    # TODO confirm against networks.Discriminator.
    real = torch.ones(args.batch_size, 1, args.input_size // 4,
                      args.input_size // 4).to(device)
    fake = torch.zeros(args.batch_size, 1, args.input_size // 4,
                       args.input_size // 4).to(device)
    for epoch in range(args.train_epoch):
        epoch_start_time = time.time()
        generator.train()
        Disc_losses = []
        Gen_losses = []
        Con_losses = []
        for i, ((acimg, ac_smooth_img, _),
                (lcimg, lhimg, lsimg)) in enumerate(
                    zip(anime_dataloader, landscape_dataloader)):
            acimg, ac_smooth_img, lcimg, lhimg, lsimg = (
                acimg.to(device), ac_smooth_img.to(device), lcimg.to(device),
                lhimg.to(device), lsimg.to(device))
            # Update G only once every n_dis discriminator updates.
            if i % args.n_dis == 0:
                G_optimizer.zero_grad()
                mask = mask_gen()
                hint = torch.cat((lhimg * mask, mask), 1)
                gen_img = generator(lsimg, hint)
                D_fake = discriminator(gen_img)
                D_fake_loss = Adv_loss(D_fake, real)  # fool the discriminator
                x_feature = VGG((lcimg + 1) / 2)
                G_feature = VGG((gen_img + 1) / 2)
                Con_loss = args.con_lambda * L1_loss(G_feature,
                                                     x_feature.detach())
                Gen_loss = D_fake_loss + Con_loss
                # NOTE(review): as in the original, 'Gen_loss' history records
                # only the adversarial term, not D_fake_loss + Con_loss.
                Gen_losses.append(D_fake_loss.item())
                train_hist['Gen_loss'].append(D_fake_loss.item())
                Con_losses.append(Con_loss.item())
                train_hist['Con_loss'].append(Con_loss.item())
                Gen_loss.backward()
                G_optimizer.step()

            # Update D: real anime vs. generated and edge-smoothed negatives.
            D_optimizer.zero_grad()
            D_real = discriminator(acimg)
            D_real_loss = Adv_loss(D_real, real)
            mask = mask_gen()
            hint = torch.cat((lhimg * mask, mask), 1)
            gen_img = generator(lsimg, hint)
            D_fake = discriminator(gen_img)
            D_fake_loss = Adv_loss(D_fake, fake)
            D_edge = discriminator(ac_smooth_img)
            D_edge_loss = Adv_loss(D_edge, fake)
            Disc_loss = D_real_loss + D_fake_loss + D_edge_loss
            Disc_losses.append(Disc_loss.item())
            train_hist['Disc_loss'].append(Disc_loss.item())
            Disc_loss.backward()
            D_optimizer.step()

        per_epoch_time = time.time() - epoch_start_time
        train_hist['per_epoch_time'].append(per_epoch_time)
        print(
            '[%d/%d] - time: %.2f, Disc loss: %.3f, Gen loss: %.3f, Con loss: %.3f'
            % ((epoch + 1), args.train_epoch, per_epoch_time,
               torch.mean(torch.FloatTensor(Disc_losses)),
               torch.mean(torch.FloatTensor(Gen_losses)),
               torch.mean(torch.FloatTensor(Con_losses))))

        # Every other epoch (and on the last), dump transfer samples and the
        # latest checkpoints.
        if epoch % 2 == 1 or epoch == args.train_epoch - 1:
            with torch.no_grad():
                generator.eval()
                for n, (lcimg, lhimg,
                        lsimg) in enumerate(landscape_dataloader):
                    lcimg, lhimg, lsimg = (lcimg.to(device), lhimg.to(device),
                                           lsimg.to(device))
                    mask = mask_gen()
                    hint = torch.cat((lhimg * mask, mask), 1)
                    g_recon = generator(lsimg, hint)
                    result = torch.cat((lcimg[0], g_recon[0]), 2)
                    path = os.path.join(
                        args.name + '_results', 'Transfer',
                        str(epoch + 1) + '_epoch_' + args.name + '_train_' +
                        str(n + 1) + '.png')
                    plt.imsave(
                        path,
                        (result.cpu().numpy().transpose(1, 2, 0) + 1) / 2)
                    if n == 4:
                        break
                for n, (lcimg, lhimg,
                        lsimg) in enumerate(landscape_test_dataloader):
                    lcimg, lhimg, lsimg = (lcimg.to(device), lhimg.to(device),
                                           lsimg.to(device))
                    mask = mask_gen()
                    hint = torch.cat((lhimg * mask, mask), 1)
                    g_recon = generator(lsimg, hint)
                    result = torch.cat((lcimg[0], g_recon[0]), 2)
                    path = os.path.join(
                        args.name + '_results', 'Transfer',
                        str(epoch + 1) + '_epoch_' + args.name + '_test_' +
                        str(n + 1) + '.png')
                    plt.imsave(
                        path,
                        (result.cpu().numpy().transpose(1, 2, 0) + 1) / 2)
                    if n == 4:
                        break
            torch.save(
                generator.state_dict(),
                os.path.join(args.name + '_results', 'generator_latest.pkl'))
            # BUG FIX: the original saved generator.state_dict() here, so the
            # discriminator checkpoint actually contained generator weights.
            torch.save(
                discriminator.state_dict(),
                os.path.join(args.name + '_results',
                             'discriminator_latest.pkl'))

    total_time = time.time() - start_time
    train_hist['total_time'].append(total_time)
    print("Avg one epoch time: %.2f, total %d epochs time: %.2f" %
          (torch.mean(torch.FloatTensor(train_hist['per_epoch_time'])),
           args.train_epoch, total_time))
    print("Training finish!... save training results")
    torch.save(generator.state_dict(),
               os.path.join(args.name + '_results', 'generator_param.pkl'))
    torch.save(
        discriminator.state_dict(),
        os.path.join(args.name + '_results', 'discriminator_param.pkl'))
    with open(os.path.join(args.name + '_results', 'train_hist.pkl'),
              'wb') as f:
        pickle.dump(train_hist, f)
def main():
    """Train CartoonGAN-style photo-to-cartoon transfer.

    Optionally pre-trains G on VGG-feature reconstruction, then runs
    adversarial training where D sees real cartoons as positives and both
    generated images and edge-smoothed cartoons as negatives. Samples and
    checkpoints are written under ``<args.name>_results``.

    Relies on module-level ``args``, ``networks``, ``utils``, ``plt``,
    ``transforms``, ``prepare_result`` and ``make_edge_promoting_img``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if torch.backends.cudnn.enabled:
        # Fixed input sizes -> let cuDNN benchmark the fastest kernels.
        torch.backends.cudnn.benchmark = True
    prepare_result()
    make_edge_promoting_img()

    # Data loaders. Inputs are normalized to [-1, 1].
    src_transform = transforms.Compose([
        transforms.Resize((args.input_size, args.input_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])
    tgt_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])
    train_loader_src = utils.data_load(os.path.join('data', args.src_data),
                                       'train',
                                       src_transform,
                                       args.batch_size,
                                       shuffle=True,
                                       drop_last=True)
    train_loader_tgt = utils.data_load(os.path.join('data', args.tgt_data),
                                       'pair',
                                       tgt_transform,
                                       args.batch_size,
                                       shuffle=True,
                                       drop_last=True)
    test_loader_src = utils.data_load(os.path.join('data', args.src_data),
                                      'test',
                                      src_transform,
                                      1,
                                      shuffle=True,
                                      drop_last=True)

    # Networks, optionally resumed from checkpoints.
    G = networks.generator(args.in_ngc, args.out_ngc, args.ngf, args.nb)
    if args.latest_generator_model != '':
        if torch.cuda.is_available():
            G.load_state_dict(torch.load(args.latest_generator_model))
        else:
            # CPU mode: remap CUDA-saved tensors onto CPU storage.
            G.load_state_dict(
                torch.load(args.latest_generator_model,
                           map_location=lambda storage, loc: storage))

    D = networks.discriminator(args.in_ndc, args.out_ndc, args.ndf)
    if args.latest_discriminator_model != '':
        if torch.cuda.is_available():
            D.load_state_dict(torch.load(args.latest_discriminator_model))
        else:
            D.load_state_dict(
                torch.load(args.latest_discriminator_model,
                           map_location=lambda storage, loc: storage))

    # Frozen VGG19 used purely as a perceptual-feature extractor.
    VGG = networks.VGG19(init_weights=args.vgg_model, feature_mode=True)
    G.to(device)
    D.to(device)
    VGG.to(device)
    G.train()
    D.train()
    VGG.eval()

    print('---------- Networks initialized -------------')
    utils.print_network(G)
    utils.print_network(D)
    utils.print_network(VGG)
    print('-----------------------------------------------')

    # Losses (BCE assumes D's output is sigmoid-activated — TODO confirm).
    BCE_loss = nn.BCELoss().to(device)
    L1_loss = nn.L1Loss().to(device)

    # Adam optimizers with stepwise LR decay at 1/2 and 3/4 of training.
    G_optimizer = optim.Adam(G.parameters(),
                             lr=args.lrG,
                             betas=(args.beta1, args.beta2))
    D_optimizer = optim.Adam(D.parameters(),
                             lr=args.lrD,
                             betas=(args.beta1, args.beta2))
    G_scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer=G_optimizer,
        milestones=[args.train_epoch // 2, args.train_epoch // 4 * 3],
        gamma=0.1)
    D_scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer=D_optimizer,
        milestones=[args.train_epoch // 2, args.train_epoch // 4 * 3],
        gamma=0.1)

    pre_train_hist = {}
    pre_train_hist['Recon_loss'] = []
    pre_train_hist['per_epoch_time'] = []
    pre_train_hist['total_time'] = []

    # ---- Pre-train: VGG-feature reconstruction only (skipped on resume) ----
    if args.latest_generator_model == '':
        print('Pre-training start!')
        start_time = time.time()
        for epoch in range(args.pre_train_epoch):
            epoch_start_time = time.time()
            Recon_losses = []
            for x, _ in train_loader_src:
                x = x.to(device)
                # Train generator G on the perceptual reconstruction loss only.
                G_optimizer.zero_grad()
                x_feature = VGG((x + 1) / 2)  # map [-1, 1] -> [0, 1] for VGG
                G_ = G(x)
                G_feature = VGG((G_ + 1) / 2)
                Recon_loss = 10 * L1_loss(G_feature, x_feature.detach())
                Recon_losses.append(Recon_loss.item())
                pre_train_hist['Recon_loss'].append(Recon_loss.item())
                Recon_loss.backward()
                G_optimizer.step()
            per_epoch_time = time.time() - epoch_start_time
            pre_train_hist['per_epoch_time'].append(per_epoch_time)
            print('[%d/%d] - time: %.2f, Recon loss: %.3f' %
                  ((epoch + 1), args.pre_train_epoch, per_epoch_time,
                   torch.mean(torch.FloatTensor(Recon_losses))))
        total_time = time.time() - start_time
        pre_train_hist['total_time'].append(total_time)
        with open(os.path.join(args.name + '_results', 'pre_train_hist.pkl'),
                  'wb') as f:
            pickle.dump(pre_train_hist, f)

        # Save a few reconstruction samples from the pre-trained generator.
        with torch.no_grad():
            G.eval()
            for n, (x, _) in enumerate(train_loader_src):
                x = x.to(device)
                G_recon = G(x)
                result = torch.cat((x[0], G_recon[0]), 2)
                path = os.path.join(
                    args.name + '_results', 'Reconstruction',
                    args.name + '_train_recon_' + str(n + 1) + '.png')
                plt.imsave(path,
                           (result.cpu().numpy().transpose(1, 2, 0) + 1) / 2)
                if n == 4:
                    break
            for n, (x, _) in enumerate(test_loader_src):
                x = x.to(device)
                G_recon = G(x)
                result = torch.cat((x[0], G_recon[0]), 2)
                path = os.path.join(
                    args.name + '_results', 'Reconstruction',
                    args.name + '_test_recon_' + str(n + 1) + '.png')
                plt.imsave(path,
                           (result.cpu().numpy().transpose(1, 2, 0) + 1) / 2)
                if n == 4:
                    break
    else:
        print('Load the latest generator model, no need to pre-train')

    # ---- Adversarial training ----
    train_hist = {}
    train_hist['Disc_loss'] = []
    train_hist['Gen_loss'] = []
    train_hist['Con_loss'] = []
    train_hist['per_epoch_time'] = []
    train_hist['total_time'] = []

    print('training start!')
    start_time = time.time()
    # BCE targets; D's output is assumed (B, 1, H/4, W/4) — TODO confirm
    # against networks.discriminator.
    real = torch.ones(args.batch_size, 1, args.input_size // 4,
                      args.input_size // 4).to(device)
    fake = torch.zeros(args.batch_size, 1, args.input_size // 4,
                       args.input_size // 4).to(device)
    for epoch in range(args.train_epoch):
        epoch_start_time = time.time()
        G.train()
        Disc_losses = []
        Gen_losses = []
        Con_losses = []
        for (x, _), (y, _) in zip(train_loader_src, train_loader_tgt):
            # 'pair' images are [cartoon | edge-smoothed cartoon] concatenated
            # along width; split into the positive y and the edge negative e.
            e = y[:, :, :, args.input_size:]
            y = y[:, :, :, :args.input_size]
            x, y, e = x.to(device), y.to(device), e.to(device)

            # Train D: real cartoons vs. generated and edge-smoothed fakes.
            D_optimizer.zero_grad()
            D_real = D(y)
            D_real_loss = BCE_loss(D_real, real)
            G_ = G(x)
            D_fake = D(G_)
            D_fake_loss = BCE_loss(D_fake, fake)
            D_edge = D(e)
            D_edge_loss = BCE_loss(D_edge, fake)
            Disc_loss = D_real_loss + D_fake_loss + D_edge_loss
            Disc_losses.append(Disc_loss.item())
            train_hist['Disc_loss'].append(Disc_loss.item())
            Disc_loss.backward()
            D_optimizer.step()

            # Train G: adversarial term + VGG content preservation.
            G_optimizer.zero_grad()
            G_ = G(x)
            D_fake = D(G_)
            D_fake_loss = BCE_loss(D_fake, real)  # fool the discriminator
            x_feature = VGG((x + 1) / 2)
            G_feature = VGG((G_ + 1) / 2)
            Con_loss = args.con_lambda * L1_loss(G_feature, x_feature.detach())
            Gen_loss = D_fake_loss + Con_loss
            # NOTE(review): as in the original, 'Gen_loss' history records only
            # the adversarial term, not D_fake_loss + Con_loss.
            Gen_losses.append(D_fake_loss.item())
            train_hist['Gen_loss'].append(D_fake_loss.item())
            Con_losses.append(Con_loss.item())
            train_hist['Con_loss'].append(Con_loss.item())
            Gen_loss.backward()
            G_optimizer.step()

        # BUG FIX: the schedulers were stepped once per *batch* while their
        # milestones are epoch counts, decaying the LR by 100x almost
        # immediately. MultiStepLR must be stepped once per epoch, after the
        # optimizer updates.
        G_scheduler.step()
        D_scheduler.step()

        per_epoch_time = time.time() - epoch_start_time
        train_hist['per_epoch_time'].append(per_epoch_time)
        print(
            '[%d/%d] - time: %.2f, Disc loss: %.3f, Gen loss: %.3f, Con loss: %.3f'
            % ((epoch + 1), args.train_epoch, per_epoch_time,
               torch.mean(torch.FloatTensor(Disc_losses)),
               torch.mean(torch.FloatTensor(Gen_losses)),
               torch.mean(torch.FloatTensor(Con_losses))))

        # Every other epoch (and on the last), dump transfer samples and the
        # latest checkpoints.
        if epoch % 2 == 1 or epoch == args.train_epoch - 1:
            with torch.no_grad():
                G.eval()
                for n, (x, _) in enumerate(train_loader_src):
                    x = x.to(device)
                    G_recon = G(x)
                    result = torch.cat((x[0], G_recon[0]), 2)
                    path = os.path.join(
                        args.name + '_results', 'Transfer',
                        str(epoch + 1) + '_epoch_' + args.name + '_train_' +
                        str(n + 1) + '.png')
                    plt.imsave(
                        path,
                        (result.cpu().numpy().transpose(1, 2, 0) + 1) / 2)
                    if n == 4:
                        break
                for n, (x, _) in enumerate(test_loader_src):
                    x = x.to(device)
                    G_recon = G(x)
                    result = torch.cat((x[0], G_recon[0]), 2)
                    path = os.path.join(
                        args.name + '_results', 'Transfer',
                        str(epoch + 1) + '_epoch_' + args.name + '_test_' +
                        str(n + 1) + '.png')
                    plt.imsave(
                        path,
                        (result.cpu().numpy().transpose(1, 2, 0) + 1) / 2)
                    if n == 4:
                        break
            torch.save(
                G.state_dict(),
                os.path.join(args.name + '_results', 'generator_latest.pkl'))
            torch.save(
                D.state_dict(),
                os.path.join(args.name + '_results',
                             'discriminator_latest.pkl'))

    total_time = time.time() - start_time
    train_hist['total_time'].append(total_time)
    print("Avg one epoch time: %.2f, total %d epochs time: %.2f" %
          (torch.mean(torch.FloatTensor(train_hist['per_epoch_time'])),
           args.train_epoch, total_time))
    print("Training finish!... save training results")
    torch.save(G.state_dict(),
               os.path.join(args.name + '_results', 'generator_param.pkl'))
    torch.save(D.state_dict(),
               os.path.join(args.name + '_results', 'discriminator_param.pkl'))
    with open(os.path.join(args.name + '_results', 'train_hist.pkl'),
              'wb') as f:
        pickle.dump(train_hist, f)