def main(config):
    """Train the expression generator ``G`` with two discriminators and a
    landmark detector.

    Networks:
        * ``G`` (``FastTrackExpressionGenerater``): called as
          ``G(source_faces, target_points)``.
        * ``D`` (``SNResRealFakeDiscriminator``): scores single faces.
        * ``id_D`` (``SNResIdDiscriminator``): scores (source face, candidate
          face) pairs.
        * ``points_G`` (``LandMarksDetect``): trained to regress
          ``target_points`` from ``target_faces``; its predictions on the fake
          faces also drive the generator's keypoint loss.
        * ``FEN`` (``FeatureExtractNet``): kept in eval mode and frozen; only
          used for the feature-distance loss.

    Even iterations draw from the finetune loader and pair every face with the
    landmarks/face of a randomly permuted sample of the same batch. Odd
    iterations draw source faces from ``first_loader`` and targets from
    ``loader``.

    Args:
        config: run configuration. It is passed to the data loaders and logged.
            Attributes read here: ``gpu`` (only when command-line arguments are
            present), ``version``, ``resume_iter``, ``lambda_l1``,
            ``lambda_keypoint``, ``lambda_fake``, ``lambda_id`` and
            ``lambda_feature``.

    Side effects:
        * Sets ``CUDA_DEVICE_ORDER`` / ``CUDA_VISIBLE_DEVICES``, the
          module-level ``device`` and ``cudnn.benchmark``.
        * Writes ``log.txt`` into the checkpoint directory, and a checkpoint
          there every 1000 iterations.
        * Writes sample images into ``gan-sample-<version>`` under the current
          working directory.

    NOTE(review): SOURCE defines a second ``main`` after this one, which
    rebinds the module-level name; confirm which entry point is intended.
    """
    global device

    def _next_batch(batch_iter, source):
        """Return ``(batch, batch_iter)``, restarting from *source* when exhausted."""
        try:
            return next(batch_iter), batch_iter
        except StopIteration:
            batch_iter = iter(source)
            return next(batch_iter), batch_iter

    # For fast training.
    cudnn.benchmark = True
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    if len(sys.argv) > 1:
        os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = "5"
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    version = config.version

    finetune_loader = data_loader_augment.get_loader(
        "/media/data2/laixc/AI_DATA/expression_transfer/face12/crop_face",
        "/media/data2/laixc/AI_DATA/expression_transfer/face12/points_face",
        config)
    loader, first_loader = data_loader.get_loader(
        "/media/data2/laixc/AI_DATA/expression_transfer/face12/crop_sideface",
        "/media/data2/laixc/AI_DATA/expression_transfer/face12/points_sideface",
        config)

    points_G = LandMarksDetect()
    G = FastTrackExpressionGenerater()
    D = SNResRealFakeDiscriminator()
    FEN = FeatureExtractNet()
    id_D = SNResIdDiscriminator()

    ####### Load pretrained networks ######
    resume_iter = config.resume_iter
    ckpt_dir = "/media/data2/laixc/Facial_Expression_GAN/ckpt-{}".format(
        version)
    os.makedirs(ckpt_dir, exist_ok=True)
    log = Logger(os.path.join(ckpt_dir, 'log.txt'))

    if os.path.exists(os.path.join(ckpt_dir, '{}-G.ckpt'.format(resume_iter))):
        # Restore every network from the same iteration. The map_location
        # lambda keeps tensors on the CPU until the modules are moved below.
        for net, tag in ((G, 'G'), (D, 'D'), (id_D, 'idD'), (points_G, 'pG')):
            net_path = os.path.join(ckpt_dir,
                                    '{}-{}.ckpt'.format(resume_iter, tag))
            net.load_state_dict(
                torch.load(net_path, map_location=lambda storage, loc: storage))
    else:
        resume_iter = 0

    ##### Train face2keypoint (optimizers) #####
    points_G_optimizer = torch.optim.Adam(points_G.parameters(), lr=0.0001,
                                          betas=(0.5, 0.9))
    G_optimizer = torch.optim.Adam(G.parameters(), lr=0.0001, betas=(0.5, 0.9))
    D_optimizer = torch.optim.Adam(D.parameters(), lr=0.001, betas=(0.5, 0.9))
    idD_optimizer = torch.optim.Adam(id_D.parameters(), lr=0.001,
                                     betas=(0.5, 0.9))
    G.to(device)
    id_D.to(device)
    D.to(device)
    points_G.to(device)
    FEN.to(device)
    FEN.eval()
    # FEN has no optimizer, so gradients on its parameters were never used.
    # Freezing it stops them from accumulating. Gradients still flow through
    # FEN into faces_fake.
    FEN.requires_grad_(False)

    log.print(config)

    # Read the loss weights once instead of on every iteration.
    lambda_l1 = config.lambda_l1
    lambda_keypoint = config.lambda_keypoint  # 100 to 50
    lambda_fake = config.lambda_fake
    lambda_id = config.lambda_id
    lambda_feature = config.lambda_feature
    n_critic = 1

    sample_dir = "gan-sample-{}".format(version)
    os.makedirs(sample_dir, exist_ok=True)
    # Save checkpoints where resume looks for them. A CWD-relative directory
    # only matched ckpt_dir when the script ran from its own directory.
    model_save_dir = ckpt_dir

    finetune_iter = iter(finetune_loader)
    data_iter = iter(loader)
    first_data_iters = iter(first_loader)

    # Start training from scratch or resume training.
    print('Start training...')
    for i in range(resume_iter, 150000):
        # ================= 1. Preprocess input data ================= #
        if i % 2 == 0:
            # Pair each face with the landmarks/face of a randomly permuted
            # sample from the same finetune batch.
            (faces, origin_points), finetune_iter = _next_batch(
                finetune_iter, finetune_loader)
            rand_idx = torch.randperm(origin_points.size(0))
            target_points = origin_points[rand_idx].to(device)
            target_faces = faces[rand_idx].to(device)
            first_faces = faces.to(device)
        else:
            # Source faces and targets come from separate loaders.
            (first_faces, _), first_data_iters = _next_batch(
                first_data_iters, first_loader)
            (target_faces, target_points), data_iter = _next_batch(
                data_iter, loader)
            target_faces = target_faces.to(device)
            first_faces = first_faces.to(device)
            target_points = target_points.to(device)

        # ================= 2. Train the discriminators ================= #
        # Real/fake discriminator. Only D is updated here, so G runs without
        # autograd; its output values are unchanged.
        real_loss = torch.mean(softplus(-D(target_faces)))  # big for real
        with torch.no_grad():
            faces_fake = G(first_faces, target_points)
        fake_loss = torch.mean(softplus(D(faces_fake)))  # small for fake
        Dis_loss = real_loss + fake_loss
        D_optimizer.zero_grad()
        Dis_loss.backward()
        D_optimizer.step()

        # Pair (identity) discriminator.
        id_real_loss = torch.mean(
            softplus(-id_D(first_faces, target_faces)))  # big for real
        with torch.no_grad():
            faces_fake = G(first_faces, target_points)
        id_fake_loss = torch.mean(
            softplus(id_D(first_faces, faces_fake)))  # small for fake
        id_Dis_loss = id_real_loss + id_fake_loss
        idD_optimizer.zero_grad()
        id_Dis_loss.backward()
        idD_optimizer.step()

        # ================= 3. Train the keypoint detector ================= #
        points_detect = points_G(target_faces)
        detecter_loss = torch.mean(torch.abs(points_detect - target_points))
        points_G_optimizer.zero_grad()
        detecter_loss.backward()
        points_G_optimizer.step()

        # ================= 4. Train the generator ================= #
        if (i + 1) % n_critic == 0:
            faces_fake = G(first_faces, target_points)
            # Landmarks detected on the fake should match the requested ones.
            predict_points = points_G(faces_fake)
            g_keypoints_loss = torch.mean(
                torch.abs(predict_points - target_points))
            g_fake_loss = torch.mean(softplus(-D(faces_fake)))
            g_id_loss = torch.mean(softplus(-id_D(first_faces, faces_fake)))
            l1_loss = torch.mean(torch.abs(faces_fake - target_faces))
            feature_loss = torch.mean(
                torch.abs(FEN(faces_fake) - FEN(target_faces)))

            g_loss = (lambda_keypoint * g_keypoints_loss
                      + lambda_fake * g_fake_loss
                      + lambda_id * g_id_loss
                      + lambda_l1 * l1_loss
                      + lambda_feature * feature_loss)
            # This backward also writes gradients into points_G, D and id_D.
            # Each of those optimizers zeroes them before its own next backward.
            G_optimizer.zero_grad()
            g_loss.backward()
            G_optimizer.step()

        # Print out training information.
        if (i + 1) % 4 == 0 or i % 8 == 0:
            log.print(
                "iter {} - d_real_loss {:.2}, d_fake_loss {:.2}, id_real_loss {:.2}, "
                "id_fake_loss {:.2} , g_keypoints_loss {:.2}, "
                "g_fake_loss {:.2}, g_id_loss {:.2}, L1_loss {:.2}, feature_loss {:.2}"
                .format(i, real_loss.item(), fake_loss.item(),
                        id_real_loss.item(), id_fake_loss.item(),
                        lambda_keypoint * g_keypoints_loss.item(),
                        lambda_fake * g_fake_loss.item(),
                        lambda_id * g_id_loss.item(),
                        lambda_l1 * l1_loss.item(),
                        lambda_feature * feature_loss.item()))

        if (i + 1) % 24 == 0 or i % 20 == 0:
            with torch.no_grad():
                samples = (
                    (first_faces[0], '{}-image-face.jpg'),
                    (target_faces[0], '{}-image-targetface.jpg'),
                    (faces_fake[0], '{}-image-fake.jpg'),
                    (target_points[0], '{}-image-target_point.jpg'),
                    (predict_points[0], '{}-image-predict_point.jpg'),
                )
                for sample, file_pattern in samples:
                    save_image(denorm(sample.data.cpu()),
                               os.path.join(sample_dir,
                                            file_pattern.format(i + 1)))
                print('Saved real and fake images into {}...'.format(
                    sample_dir))

        # Save model checkpoints.
        if (i + 1) % 1000 == 0:
            os.makedirs(model_save_dir, exist_ok=True)
            for net, tag in ((points_G, 'pG'), (G, 'G'), (D, 'D'),
                             (id_D, 'idD')):
                torch.save(net.state_dict(),
                           os.path.join(model_save_dir,
                                        '{}-{}.ckpt'.format(i + 1, tag)))
            print('Saved model checkpoints into {}...'.format(model_save_dir))
def main(config):
    """Train ``G`` (``ExpressionGenerater``) against the pair discriminator
    ``color_D``.

    Each loader batch yields ``(color_faces, faces)``. A randomly permuted
    copy of ``faces`` serves as ``condition_faces``. ``G`` is called as
    ``G(color_faces, condition_faces)`` and trained to reproduce ``faces``.
    The generator loss is:

        adversarial + lambda_l1 * (L1 + (1 - SSIM)) + lambda_feature * FEN distance

    Args:
        config: run configuration. It is passed to the data loader and logged.
            Attributes read here: ``gpu`` (only when command-line arguments are
            present), ``version``, ``resume_iter``, ``lambda_l1``,
            ``lambda_id`` and ``lambda_feature``.

    Side effects:
        * Sets ``CUDA_DEVICE_ORDER`` / ``CUDA_VISIBLE_DEVICES``, the
          module-level ``device`` and ``cudnn.benchmark``.
        * Writes ``log.txt`` into the checkpoint directory, and a checkpoint
          there every 1000 iterations.
        * Writes sample images into ``gan-sample-<version>`` under the current
          working directory.

    NOTE(review): SOURCE defines another ``main`` before this one; this
    definition rebinds the module-level name. Confirm which entry point is
    intended.
    """
    global device

    # For fast training.
    cudnn.benchmark = True
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    if len(sys.argv) > 1:
        os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = "5"
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    version = config.version

    loader = data_loader.get_loader(
        "/media/data2/laixc/AI_DATA/expression_transfer/face12/crop_face",
        "/media/data2/laixc/AI_DATA/expression_transfer/face12/points_face",
        config)

    G = ExpressionGenerater()
    FEN = FeatureExtractNet()
    color_D = SNResIdDiscriminator()

    ####### Load pretrained networks ######
    resume_iter = config.resume_iter
    ckpt_dir = "/media/data2/laixc/Facial_Expression_GAN/ckpt-{}".format(
        version)
    os.makedirs(ckpt_dir, exist_ok=True)
    log = Logger(os.path.join(ckpt_dir, 'log.txt'))

    if os.path.exists(os.path.join(ckpt_dir, '{}-G.ckpt'.format(resume_iter))):
        # The map_location lambda keeps tensors on the CPU until the modules
        # are moved below.
        for net, tag in ((G, 'G'), (color_D, 'idD')):
            net_path = os.path.join(ckpt_dir,
                                    '{}-{}.ckpt'.format(resume_iter, tag))
            net.load_state_dict(
                torch.load(net_path, map_location=lambda storage, loc: storage))
    else:
        resume_iter = 0

    ##### Train face2keypoint (optimizers) #####
    G_optimizer = torch.optim.Adam(G.parameters(), lr=0.0001, betas=(0.5, 0.9))
    idD_optimizer = torch.optim.Adam(color_D.parameters(), lr=0.001,
                                     betas=(0.5, 0.9))
    G.to(device)
    color_D.to(device)
    FEN.to(device)
    FEN.eval()
    # FEN has no optimizer, so gradients on its parameters were never used.
    # Freezing it stops them from accumulating. Gradients still flow through
    # FEN into faces_fake.
    FEN.requires_grad_(False)

    log.print(config)

    # Read the loss weights once instead of on every iteration.
    lambda_l1 = config.lambda_l1
    lambda_id = config.lambda_id
    lambda_feature = config.lambda_feature
    n_critic = 1

    sample_dir = "gan-sample-{}".format(version)
    os.makedirs(sample_dir, exist_ok=True)
    # Save checkpoints where resume looks for them. A CWD-relative directory
    # only matched ckpt_dir when the script ran from its own directory.
    model_save_dir = ckpt_dir

    data_iter = iter(loader)

    # Start training from scratch or resume training.
    print('Start training...')
    for i in range(resume_iter, 150000):
        # ================= 1. Preprocess input data ================= #
        try:
            color_faces, faces = next(data_iter)
        except StopIteration:
            # Restart the loader so training runs a fixed number of iterations.
            data_iter = iter(loader)
            color_faces, faces = next(data_iter)
        rand_idx = torch.randperm(faces.size(0))
        condition_faces = faces[rand_idx].to(device)
        faces = faces.to(device)
        color_faces = color_faces.to(device)

        # ================= 2. Train the discriminator ================= #
        # Only color_D is updated here, so G runs without autograd; its
        # output values are unchanged.
        id_real_loss = torch.mean(
            softplus(-color_D(faces, condition_faces)))  # big for real
        with torch.no_grad():
            faces_fake = G(color_faces, condition_faces)
        id_fake_loss = torch.mean(
            softplus(color_D(faces_fake, condition_faces)))  # small for fake
        id_Dis_loss = id_real_loss + id_fake_loss
        idD_optimizer.zero_grad()
        id_Dis_loss.backward()
        idD_optimizer.step()

        # ================= 3. Train the generator ================= #
        if (i + 1) % n_critic == 0:
            faces_fake = G(color_faces, condition_faces)
            g_id_loss = torch.mean(
                softplus(-color_D(faces_fake, condition_faces)))
            # Pixel L1 plus a structural (1 - SSIM) term.
            l1_loss = torch.mean(
                torch.abs(faces_fake - faces)) + (1 - ssim(faces_fake, faces))
            feature_loss = torch.mean(torch.abs(FEN(faces_fake) - FEN(faces)))

            g_loss = (lambda_id * g_id_loss
                      + lambda_l1 * l1_loss
                      + lambda_feature * feature_loss)
            # This backward also writes gradients into color_D. idD_optimizer
            # zeroes them before color_D's next backward.
            G_optimizer.zero_grad()
            g_loss.backward()
            G_optimizer.step()

        # Print out training information.
        if (i + 1) % 4 == 0:
            log.print(
                "iter {} - id_real_loss {:.2}, "
                "id_fake_loss {:.2} , g_id_loss {:.2}, L1_loss {:.2}, feature_loss {:.2}"
                .format(i, id_real_loss.item(), id_fake_loss.item(),
                        lambda_id * g_id_loss.item(),
                        lambda_l1 * l1_loss.item(),
                        lambda_feature * feature_loss.item()))

        if (i + 1) % 24 == 0:
            with torch.no_grad():
                samples = (
                    (condition_faces[0], '{}-image-face.jpg'),
                    (color_faces[0], '{}-image-color.jpg'),
                    (faces_fake[0], '{}-image-fake.jpg'),
                )
                for sample, file_pattern in samples:
                    save_image(denorm(sample.data.cpu()),
                               os.path.join(sample_dir,
                                            file_pattern.format(i + 1)))
                print('Saved real and fake images into {}...'.format(
                    sample_dir))

        # Save model checkpoints.
        if (i + 1) % 1000 == 0:
            os.makedirs(model_save_dir, exist_ok=True)
            for net, tag in ((G, 'G'), (color_D, 'idD')):
                torch.save(net.state_dict(),
                           os.path.join(model_save_dir,
                                        '{}-{}.ckpt'.format(i + 1, tag)))
            print('Saved model checkpoints into {}...'.format(model_save_dir))