def main(config): # For fast training. cudnn.benchmark = True os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" if len(sys.argv) > 1: os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu else: os.environ["CUDA_VISIBLE_DEVICES"] = "5" global device device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') version = config.version beta1 = 0.5 beta2 = 0.999 loader = data_loader.get_loader( "/media/data2/laixc/AI_DATA/expression_transfer/face12/crop_face", "/media/data2/laixc/AI_DATA/expression_transfer/face12/points_face", config) G = ExpressionGenerater() D = RealFakeDiscriminator() #FEN = FeatureExtractNet() id_D = IdDiscriminator() kp_D = KeypointDiscriminator() points_G = LandMarksDetect() ####### 载入预训练网络 ###### resume_iter = config.resume_iter ckpt_dir = "/media/data2/laixc/Facial_Expression_GAN/ckpt-{}".format( version) log = Logger(os.path.join(ckpt_dir, 'log.txt')) if os.path.exists(os.path.join(ckpt_dir, '{}-G.ckpt'.format(resume_iter))): G_path = os.path.join(ckpt_dir, '{}-G.ckpt'.format(resume_iter)) G.load_state_dict( torch.load(G_path, map_location=lambda storage, loc: storage)) D_path = os.path.join(ckpt_dir, '{}-D.ckpt'.format(resume_iter)) D.load_state_dict( torch.load(D_path, map_location=lambda storage, loc: storage)) IdD_path = os.path.join(ckpt_dir, '{}-idD.ckpt'.format(resume_iter)) id_D.load_state_dict( torch.load(IdD_path, map_location=lambda storage, loc: storage)) kp_D_path = os.path.join(ckpt_dir, '{}-kpD.ckpt'.format(resume_iter)) kp_D.load_state_dict( torch.load(kp_D_path, map_location=lambda storage, loc: storage)) points_G_path = os.path.join(ckpt_dir, '{}-pG.ckpt'.format(resume_iter)) points_G.load_state_dict( torch.load(points_G_path, map_location=lambda storage, loc: storage)) else: resume_iter = 0 ##### 训练face2keypoint #### points_G_optimizer = torch.optim.Adam(points_G.parameters(), lr=0.0001, betas=(0.5, 0.9)) kp_D_optimizer = torch.optim.Adam(kp_D.parameters(), lr=0.0001, betas=(0.5, 0.9)) G_optimizer = torch.optim.Adam(G.parameters(), lr=0.0001, betas=(0.5, 0.9)) D_optimizer = torch.optim.Adam(D.parameters(), lr=0.001, betas=(0.5, 0.9)) idD_optimizer = torch.optim.Adam(id_D.parameters(), lr=0.001, betas=(0.5, 0.9)) G.to(device) id_D.to(device) D.to(device) kp_D.to(device) points_G.to(device) #FEN.to(device) #FEN.eval() # Start training from scratch or resume training. start_iters = resume_iter trigger_rec = 1 data_iter = iter(loader) # Start training. print('Start training...') for i in range(start_iters, 150000): # =================================================================================== # # 1. Preprocess input data # # =================================================================================== # #faces, origin_points = next(data_iter) #_, target_points = next(data_iter) try: faces, origin_points = next(data_iter) except StopIteration: data_iter = iter(loader) faces, origin_points = next(data_iter) rand_idx = torch.randperm(origin_points.size(0)) target_points = origin_points[rand_idx] target_faces = faces[rand_idx] faces = faces.to(device) target_faces = target_faces.to(device) origin_points = origin_points.to(device) target_points = target_points.to(device) # =================================================================================== # # 3. Train the discriminator # # =================================================================================== # # Real fake Dis real_loss = -torch.mean(D(faces)) # big for real faces_fake = G(faces, target_points) fake_loss = torch.mean(D(faces_fake)) # small for fake # Compute loss for gradient penalty. alpha = torch.rand(faces.size(0), 1, 1, 1).to(device) x_hat = (alpha * faces.data + (1 - alpha) * faces_fake.data).requires_grad_(True) out_src = D(x_hat) d_loss_gp = gradient_penalty(out_src, x_hat) lambda_gp = 10 Dis_loss = real_loss + fake_loss + lambda_gp * d_loss_gp D_optimizer.zero_grad() Dis_loss.backward() D_optimizer.step() # ID Dis id_real_loss = -torch.mean(id_D(faces, target_faces)) # big for real faces_fake = G(faces, target_points) id_fake_loss = torch.mean(id_D(faces, faces_fake)) # small for fake # Compute loss for gradient penalty. alpha = torch.rand(target_faces.size(0), 1, 1, 1).to(device) x_hat = (alpha * target_faces.data + (1 - alpha) * faces_fake.data).requires_grad_(True) out_src = id_D(faces, x_hat) id_d_loss_gp = gradient_penalty(out_src, x_hat) id_lambda_gp = 10 id_Dis_loss = id_real_loss + id_fake_loss + id_lambda_gp * id_d_loss_gp idD_optimizer.zero_grad() id_Dis_loss.backward() idD_optimizer.step() # Keypoints Dis kp_real_loss = -torch.mean(kp_D(target_faces, target_points)) # big for real faces_fake = G(faces, target_points) points_fake = points_G(faces_fake) kp_fake_loss = torch.mean(kp_D(target_faces, points_fake)) # small for fake # Compute loss for gradient penalty. alpha = torch.rand(target_faces.size(0), 1, 1, 1).to(device) x_hat = (alpha * target_points.data + (1 - alpha) * points_fake.data).requires_grad_(True) out_src = kp_D(target_faces, x_hat) kp_d_loss_gp = gradient_penalty(out_src, x_hat) kp_lambda_gp = 10 kp_Dis_loss = kp_real_loss + kp_fake_loss + kp_lambda_gp * kp_d_loss_gp kp_D_optimizer.zero_grad() kp_Dis_loss.backward() kp_D_optimizer.step() # if (i + 1) % 5 == 0: # print("iter {} - d_real_loss {:.2}, d_fake_loss {:.2}, d_loss_gp {:.2}".format(i,real_loss.item(), # fake_loss.item(), # lambda_gp * d_loss_gp # )) # =================================================================================== # # 3. Train the keypointsDetecter # # =================================================================================== # points_detect = points_G(faces) detecter_loss_clear = torch.mean( torch.abs(points_detect - origin_points)) detecter_loss = detecter_loss_clear points_G_optimizer.zero_grad() detecter_loss.backward() points_G_optimizer.step() # =================================================================================== # # 3. Train the generator # # =================================================================================== # n_critic = 4 if (i + 1) % n_critic == 0: # Original-to-target domain. faces_fake = G(faces, target_points) predict_points = points_G(faces_fake) g_keypoints_loss = -torch.mean(kp_D(target_faces, predict_points)) g_fake_loss = -torch.mean(D(faces_fake)) # reconstructs = G(faces_fake, origin_points) # g_cycle_loss = torch.mean(torch.abs(reconstructs - faces)) g_id_loss = -torch.mean(id_D(faces, faces_fake)) l1_loss = torch.mean(torch.abs(faces_fake - target_faces)) #feature_loss = torch.mean(torch.abs(FEN(faces_fake) - FEN(target_faces))) # 轮流训练 # if (i+1) % 50 == 0: # trigger_rec = 1 - trigger_rec # print("trigger_rec : ", trigger_rec) lambda_rec = config.lambda_rec # 2 to 4 to 8 lambda_l1 = config.lambda_l1 lambda_keypoint = config.lambda_keypoint # 100 to 50 lambda_fake = config.lambda_fake lambda_id = config.lambda_id lambda_feature = config.lambda_feature g_loss = lambda_keypoint * g_keypoints_loss + lambda_fake*g_fake_loss \ + lambda_id * g_id_loss + lambda_l1 * l1_loss# + lambda_feature*feature_loss G_optimizer.zero_grad() g_loss.backward() G_optimizer.step() # Print out training information. if (i + 1) % 4 == 0: print( "iter {} - d_real_loss {:.2}, d_fake_loss {:.2}, d_loss_gp {:.2}, id_real_loss {:.2}, " "id_fake_loss {:.2}, id_loss_gp {:.2} , g_keypoints_loss {:.2}, " "g_fake_loss {:.2}, g_id_loss {:.2}, L1_loss {:.2}".format( i, real_loss.item(), fake_loss.item(), lambda_gp * d_loss_gp, id_real_loss.item(), id_fake_loss.item(), id_lambda_gp * id_d_loss_gp, lambda_keypoint * g_keypoints_loss.item(), lambda_fake * g_fake_loss.item(), lambda_id * g_id_loss.item(), lambda_l1 * l1_loss)) sample_dir = "gan-sample-{}".format(version) if not os.path.isdir(sample_dir): os.mkdir(sample_dir) if (i + 1) % 24 == 0: with torch.no_grad(): target_point = target_points[0] fake_face = faces_fake[0] face = faces[0] #reconstruct = reconstructs[0] predict_point = predict_points[0] sample_path_face = os.path.join( sample_dir, '{}-image-face.jpg'.format(i + 1)) save_image(denorm(face.data.cpu()), sample_path_face) # sample_path_rec = os.path.join(sample_dir, '{}-image-reconstruct.jpg'.format(i + 1)) # save_image(denorm(reconstruct.data.cpu()), sample_path_rec) sample_path_fake = os.path.join( sample_dir, '{}-image-fake.jpg'.format(i + 1)) save_image(denorm(fake_face.data.cpu()), sample_path_fake) sample_path_target = os.path.join( sample_dir, '{}-image-target_point.jpg'.format(i + 1)) save_image(denorm(target_point.data.cpu()), sample_path_target) sample_path_predict_points = os.path.join( sample_dir, '{}-image-predict_point.jpg'.format(i + 1)) save_image(denorm(predict_point.data.cpu()), sample_path_predict_points) print('Saved real and fake images into {}...'.format( sample_path_face)) # Save model checkpoints. model_save_dir = "ckpt-{}".format(version) if (i + 1) % 1000 == 0: if not os.path.isdir(model_save_dir): os.mkdir(model_save_dir) point_G_path = os.path.join(model_save_dir, '{}-pG.ckpt'.format(i + 1)) torch.save(points_G.state_dict(), point_G_path) kp_D_path = os.path.join(model_save_dir, '{}-kpD.ckpt'.format(i + 1)) torch.save(kp_D.state_dict(), kp_D_path) G_path = os.path.join(model_save_dir, '{}-G.ckpt'.format(i + 1)) torch.save(G.state_dict(), G_path) D_path = os.path.join(model_save_dir, '{}-D.ckpt'.format(i + 1)) torch.save(D.state_dict(), D_path) idD_path = os.path.join(model_save_dir, '{}-idD.ckpt'.format(i + 1)) torch.save(id_D.state_dict(), idD_path) print('Saved model checkpoints into {}...'.format(model_save_dir))
def main(config): # For fast training. cudnn.benchmark = True os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" if len(sys.argv) > 1: os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu else: os.environ["CUDA_VISIBLE_DEVICES"] = "5" global device device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') version = config.version beta1 = 0.5 beta2 = 0.999 loader = data_loader_fusion_v5.get_loader( "/media/data2/laixc/AI_DATA/expression_transfer/face12/fusion_face", config) G = FusionGenerater() D = RealFakeDiscriminator() # FEN = FeatureExtractNet() ####### 载入预训练网络 ###### resume_iter = config.resume_iter ckpt_dir = "/media/data2/laixc/Facial_Expression_GAN/fusion-ckpt-{}".format( version) if os.path.exists(os.path.join(ckpt_dir, '{}-G.ckpt'.format(resume_iter))): G_path = os.path.join(ckpt_dir, '{}-G.ckpt'.format(resume_iter)) G.load_state_dict( torch.load(G_path, map_location=lambda storage, loc: storage)) D_path = os.path.join(ckpt_dir, '{}-D.ckpt'.format(resume_iter)) D.load_state_dict( torch.load(D_path, map_location=lambda storage, loc: storage)) print("Load ckpt !!!") else: print("Found NO ckpt !!!") resume_iter = 0 ##### 训练face2keypoint #### G_optimizer = torch.optim.Adam(G.parameters(), lr=0.0001, betas=(0.5, 0.9)) D_optimizer = torch.optim.Adam(D.parameters(), lr=0.001, betas=(0.5, 0.9)) G.to(device) D.to(device) # FEN.to(device) # FEN.eval() # Start training from scratch or resume training. start_iters = resume_iter data_iter = iter(loader) # Start training. print('Start training...') for i in range(start_iters, 100000): # =================================================================================== # # 1. Preprocess input data # # =================================================================================== # try: fullfaces, cropfaces, maskfaces, _, knockoutfaces = next(data_iter) except StopIteration: print("DATA STOPPED !!") data_iter = iter(loader) fullfaces, cropfaces, maskfaces, _, knockoutfaces = next(data_iter) rand_idx = torch.randperm(fullfaces.size(0)) target_crops = cropfaces[rand_idx] target_masks = maskfaces[rand_idx] target_fulls = fullfaces[rand_idx] fullfaces = fullfaces.to(device) # TODO: alt. change to knockout knockoutfaces = knockoutfaces.to(device) target_crops = target_crops.to(device) target_masks = target_masks.to(device) target_fulls = target_fulls.to(device) # =================================================================================== # # 3. Train the discriminator # # =================================================================================== # # Real fake Dis real_loss = -torch.mean(D(fullfaces)) # big for real faces_fake = G(knockoutfaces, target_masks, target_crops) fake_loss = torch.mean(D(faces_fake)) # small for fake # Compute loss for gradient penalty. alpha = torch.rand(fullfaces.size(0), 1, 1, 1).to(device) x_hat = (alpha * fullfaces.data + (1 - alpha) * faces_fake.data).requires_grad_(True) out_src = D(x_hat) d_loss_gp = gradient_penalty(out_src, x_hat) lambda_gp = 10 Dis_loss = real_loss + fake_loss + lambda_gp * d_loss_gp D_optimizer.zero_grad() Dis_loss.backward() D_optimizer.step() # =================================================================================== # # 3. Train the generator # # =================================================================================== # n_critic = 4 if (i + 1) % n_critic == 0: # Original-to-target domain. faces_fake = G(knockoutfaces, target_masks, target_crops) g_fake_loss = -torch.mean(D(faces_fake)) # reconstructs = G(faces_fake, origin_points) # g_cycle_loss = torch.mean(torch.abs(reconstructs - faces)) l1_loss = torch.mean(maskfaces * torch.abs(faces_fake - target_fulls)) feature_loss = 0 # feature_loss = torch.mean(torch.abs(FEN(faces_fake) - FEN(target_fulls))) lambda_rec = config.lambda_rec # 2 to 4 to 8 lambda_l1 = config.lambda_l1 lambda_keypoint = config.lambda_keypoint # 100 to 50 lambda_fake = config.lambda_fake lambda_id = config.lambda_id lambda_feature = config.lambda_feature g_loss = lambda_fake*g_fake_loss \ + lambda_l1 * l1_loss + lambda_feature*feature_loss G_optimizer.zero_grad() g_loss.backward() G_optimizer.step() # Print out training information. if (i + 1) % 4 == 0: print( "iter {} - d_real_loss {:.2}, d_fake_loss {:.2}, d_loss_gp {:.2}" " g_fake_loss {:.2}, L1_loss {:.2}".format( i, real_loss.item(), fake_loss.item(), lambda_gp * d_loss_gp, lambda_fake * g_fake_loss.item(), lambda_l1 * l1_loss.item())) sample_dir = "fusion-sample-{}".format(version) if not os.path.isdir(sample_dir): os.mkdir(sample_dir) if (i + 1) % 24 == 0: with torch.no_grad(): target_mask = target_masks[0] fake_face = faces_fake[0] target_full = target_fulls[0] target_crop = target_crops[0] #origin_full = fullfaces[0] knockoutface = knockoutfaces[0] sample_path_face = os.path.join( sample_dir, '{}-image-mask.jpg'.format(i + 1)) save_image(denorm(target_mask.data.cpu()), sample_path_face) sample_path_fake = os.path.join( sample_dir, '{}-image-fake.jpg'.format(i + 1)) save_image(denorm(fake_face.data.cpu()), sample_path_fake) sample_path_target = os.path.join( sample_dir, '{}-image-target-full.jpg'.format(i + 1)) save_image(denorm(target_full.data.cpu()), sample_path_target) sample_path_predict_points = os.path.join( sample_dir, '{}-image-crop.jpg'.format(i + 1)) save_image(denorm(target_crop.data.cpu()), sample_path_predict_points) # sample_path_predict_full = os.path.join(sample_dir, '{}-image-origin_full.jpg'.format(i + 1)) # save_image(denorm(origin_full.data.cpu()), sample_path_predict_full) sample_path_predict_full = os.path.join( sample_dir, '{}-image-origin_knockout.jpg'.format(i + 1)) save_image(denorm(knockoutface.data.cpu()), sample_path_predict_full) print('Saved real and fake images into {}...'.format( sample_path_face)) # Save model checkpoints. model_save_dir = "fusion-ckpt-{}".format(version) if (i + 1) % 1000 == 0: if not os.path.isdir(model_save_dir): os.mkdir(model_save_dir) G_path = os.path.join(model_save_dir, '{}-G.ckpt'.format(i + 1)) torch.save(G.state_dict(), G_path) D_path = os.path.join(model_save_dir, '{}-D.ckpt'.format(i + 1)) torch.save(D.state_dict(), D_path) print('Saved model checkpoints into {}...'.format(model_save_dir))