def generate_video(config, first_frm_file, json_file):
    # For fast training.
    cudnn.benchmark = True

    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    if len(sys.argv) > 1:
        os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = "5"
    global device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    version = config.version
    G = ExpressionGenerater()

    ckpt_dir = "/media/data2/laixc/Facial_Expression_GAN/fusion-ckpt-{}".format(config.fusion_version)
    FG = FusionGenerater()
    FG_path = os.path.join(ckpt_dir, '{}-G.ckpt'.format(config.fusion_resume_iter))
    FG.load_state_dict(torch.load(FG_path, map_location=lambda storage, loc: storage))
    FG.eval()

    ############# process data ##########
    def crop_face(img, bbox, keypoint):
        # Some images contain no detectable face.
        if len(bbox) == 0:
            return None
        x, y, w, h = [int(v) for v in bbox]
        return img[y:y + h, x:x + w]

    def load_json(path):
        with open(path, 'r') as f:
            data = json.load(f)
        print("load %d video annotations totally." % len(data))
        vid_anns = dict()
        for anns in data:
            name = anns['video']
            path = anns['video_path']
            vid_anns[name] = {'path': path}
            for ann in anns['annotations']:
                idx = ann['index']
                keypoints = ann['keypoint']
                bbox = ann['bbox']
                vid_anns[name][idx] = [bbox, keypoints]
        return vid_anns

    def draw_bbox_keypoints(img, bbox, keypoint):
        flags = list()
        points = list()
        # Some images contain no detectable face.
        if len(bbox) == 0:
            return None, None
        points_image = np.zeros_like(img, np.uint8)
        # draw bbox
        # x, y, w, h = [int(v) for v in bbox]
        # cv2.rectangle(img, (x, y), (x+w, y+h), (0, 0, 255), 2)
        # draw points
        for i in range(0, len(keypoint), 3):
            x, y, flag = [int(k) for k in keypoint[i:i + 3]]
            flags.append(flag)
            points.append([x, y])
            if flag == 0:  # keypoint does not exist
                continue
            elif flag == 1:  # keypoint exists but is invisible
                cv2.circle(points_image, (x, y), 3, (0, 0, 255), -1)
            elif flag == 2:  # keypoint exists and is visible
                cv2.circle(points_image, (x, y), 3, (0, 255, 0), -1)
            else:
                raise ValueError("flag of keypoint must be 0, 1, or 2.")
        return crop_face(points_image, bbox, keypoint), crop_face(img, bbox, keypoint)

    def extract_image(img, bbox, keypoint):
        flags = list()
        points = list()
        # Some images contain no detectable face.
        if len(bbox) == 0:
            return None, None, None, None, None, None
        mask_image = np.zeros_like(img, np.uint8)
        mask = np.zeros_like(img, np.uint8)
        Knockout_image = img.copy()
        # draw bbox
        x, y, w, h = [int(v) for v in bbox]
        print(x, y, w, h)
        cv2.rectangle(mask_image, (x, y), (x + w, y + h), (255, 255, 255), cv2.FILLED)
        cv2.rectangle(Knockout_image, (x, y), (x + w, y + h), (0, 0, 0), cv2.FILLED)
        mask = cv2.rectangle(mask, (x, y), (x + w, y + h), (1, 1, 1), cv2.FILLED)
        onlyface = img * mask
        mask_points = mask_image.copy()
        # draw points
        for i in range(0, len(keypoint), 3):
            x, y, flag = [int(k) for k in keypoint[i:i + 3]]
            flags.append(flag)
            points.append([x, y])
            if flag == 0:  # keypoint does not exist
                continue
            elif flag == 1:  # keypoint exists but is invisible
                cv2.circle(mask_points, (x, y), 3, (0, 0, 255), -1)
            elif flag == 2:  # keypoint exists and is visible
                cv2.circle(mask_points, (x, y), 3, (0, 255, 0), -1)
            else:
                raise ValueError("flag of keypoint must be 0, 1, or 2.")
        # `mask` is returned as well: the fusion step below needs the binary face mask.
        return mask_image, Knockout_image, onlyface, img, mask_points, mask

    first_frm = Image.open(first_frm_file)
    vid_annos = load_json(json_file)
    vid_name = os.path.basename(first_frm_file)
    vid_name = "{}.mp4".format(vid_name.split(".")[0])
    anno = vid_annos[vid_name]
    first_bbox, first_keypoint = anno[1]
    _, first_knockout_image, _, first_img, _, _ = extract_image(np.array(first_frm), first_bbox, first_keypoint)

    frm = cv2.imread(first_frm_file)
    _, _, channels = frm.shape
    size = (544, 720)
    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    sample_dir = "gan-sample-{}".format(version)
    if not os.path.isdir(sample_dir):
        os.mkdir(sample_dir)
    out_path = os.path.join(sample_dir,
                            'out_{}_{}_{}.mp4'.format(config.version, config.test_image, config.resume_iter))
    vid_writer = cv2.VideoWriter(out_path, fourcc, 10, size)

    if not os.path.isdir("test_result"):
        os.mkdir("test_result")

    print(len(anno))
    frm_crop = None
    is_first = True
    for idx in range(len(anno) - 1):
        print(idx)
        bbox, keypoint = anno[idx + 1]

        ###### Fusion
        cv2.imwrite("test_result/maskface_{}.jpg".format(idx),
                    cv2.cvtColor(np.asarray(first_frm), cv2.COLOR_RGB2BGR))
        mask_image, Knockout_image, onlyface, img, mask_points, mask = extract_image(
            np.array(first_frm), bbox, keypoint)
        cv2.imwrite("test_result/mask_{}.jpg".format(idx), mask_points)
        # TODO: at this point we have mask_points, mask_image, and Knockout_image.

        def transform_images(a, b, c):
            transform = []
            transform.append(T.Resize((720, 544)))
            transform.append(T.ToTensor())
            transform.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
            transform = T.Compose(transform)
            a_copy = Image.fromarray(a, 'RGB')
            b = Image.fromarray(b, 'RGB')
            c = Image.fromarray(c, 'RGB')
            a_copy = transform(a_copy)
            b = transform(b)
            c = transform(c)
            a_copy = a_copy.unsqueeze(0)
            b = b.unsqueeze(0)
            c = c.unsqueeze(0)
            return a_copy, b, c

        first_img_copy, mask_image, mask = transform_images(first_img, mask_image, mask)
        fake_frm = FG(first_img_copy, mask_image, mask)
        fake_frm = denorm(fake_frm.data.cpu())
        toPIL = T.ToPILImage()
        fake_frm = toPIL(fake_frm.squeeze())
        # Convert the PIL image to a BGR array so OpenCV can save and encode it.
        fake_frm = cv2.cvtColor(np.asarray(fake_frm), cv2.COLOR_RGB2BGR)
        cv2.imwrite("test_result/fake_frm_{}.jpg".format(idx), fake_frm)
        vid_writer.write(fake_frm)

    vid_writer.release()
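# `denorm` is used throughout these scripts but defined elsewhere in the repo.
# A minimal sketch consistent with the T.Normalize(mean=0.5, std=0.5) transforms
# above (an assumption, not the repo's confirmed implementation): it maps network
# outputs from [-1, 1] back to [0, 1] for saving and display.
def denorm(x):
    """Map a tensor normalized to [-1, 1] back to the [0, 1] image range."""
    out = (x + 1) / 2
    return out.clamp_(0, 1)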
def main(config):
    # For fast training.
    cudnn.benchmark = True

    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    if len(sys.argv) > 1:
        os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = "5"
    global device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    version = config.version
    beta1 = 0.5
    beta2 = 0.999
    loader = data_loader.get_loader(
        "/media/data2/laixc/AI_DATA/expression_transfer/face12/crop_face",
        "/media/data2/laixc/AI_DATA/expression_transfer/face12/points_face",
        config)
    G = ExpressionGenerater()
    D = RealFakeDiscriminator()
    # FEN = FeatureExtractNet()
    id_D = IdDiscriminator()
    kp_D = KeypointDiscriminator()
    points_G = LandMarksDetect()

    ####### Load the pretrained networks ######
    resume_iter = config.resume_iter
    ckpt_dir = "/media/data2/laixc/Facial_Expression_GAN/ckpt-{}".format(version)
    if not os.path.isdir(ckpt_dir):
        os.mkdir(ckpt_dir)
    log = Logger(os.path.join(ckpt_dir, 'log.txt'))
    if os.path.exists(os.path.join(ckpt_dir, '{}-G.ckpt'.format(resume_iter))):
        G_path = os.path.join(ckpt_dir, '{}-G.ckpt'.format(resume_iter))
        G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
        D_path = os.path.join(ckpt_dir, '{}-D.ckpt'.format(resume_iter))
        D.load_state_dict(torch.load(D_path, map_location=lambda storage, loc: storage))
        IdD_path = os.path.join(ckpt_dir, '{}-idD.ckpt'.format(resume_iter))
        id_D.load_state_dict(torch.load(IdD_path, map_location=lambda storage, loc: storage))
        kp_D_path = os.path.join(ckpt_dir, '{}-kpD.ckpt'.format(resume_iter))
        kp_D.load_state_dict(torch.load(kp_D_path, map_location=lambda storage, loc: storage))
        points_G_path = os.path.join(ckpt_dir, '{}-pG.ckpt'.format(resume_iter))
        points_G.load_state_dict(torch.load(points_G_path, map_location=lambda storage, loc: storage))
    else:
        resume_iter = 0

    ##### Optimizers (face2keypoint and GAN) ####
    points_G_optimizer = torch.optim.Adam(points_G.parameters(), lr=0.0001, betas=(0.5, 0.9))
    kp_D_optimizer = torch.optim.Adam(kp_D.parameters(), lr=0.0001, betas=(0.5, 0.9))
    G_optimizer = torch.optim.Adam(G.parameters(), lr=0.0001, betas=(0.5, 0.9))
    D_optimizer = torch.optim.Adam(D.parameters(), lr=0.001, betas=(0.5, 0.9))
    idD_optimizer = torch.optim.Adam(id_D.parameters(), lr=0.001, betas=(0.5, 0.9))

    G.to(device)
    id_D.to(device)
    D.to(device)
    kp_D.to(device)
    points_G.to(device)
    # FEN.to(device)
    # FEN.eval()

    # Start training from scratch or resume training.
    start_iters = resume_iter
    trigger_rec = 1
    data_iter = iter(loader)

    # Start training.
    print('Start training...')
    for i in range(start_iters, 150000):
        # =================================================================== #
        #                      1. Preprocess input data                       #
        # =================================================================== #
        try:
            faces, origin_points = next(data_iter)
        except StopIteration:
            data_iter = iter(loader)
            faces, origin_points = next(data_iter)
        rand_idx = torch.randperm(origin_points.size(0))
        target_points = origin_points[rand_idx]
        target_faces = faces[rand_idx]
        faces = faces.to(device)
        target_faces = target_faces.to(device)
        origin_points = origin_points.to(device)
        target_points = target_points.to(device)

        # =================================================================== #
        #                     2. Train the discriminators                     #
        # =================================================================== #
        # Real/fake discriminator
        real_loss = -torch.mean(D(faces))  # big for real
        faces_fake = G(faces, target_points)
        fake_loss = torch.mean(D(faces_fake))  # small for fake

        # Compute loss for gradient penalty.
        alpha = torch.rand(faces.size(0), 1, 1, 1).to(device)
        x_hat = (alpha * faces.data + (1 - alpha) * faces_fake.data).requires_grad_(True)
        out_src = D(x_hat)
        d_loss_gp = gradient_penalty(out_src, x_hat)

        lambda_gp = 10
        Dis_loss = real_loss + fake_loss + lambda_gp * d_loss_gp
        D_optimizer.zero_grad()
        Dis_loss.backward()
        D_optimizer.step()

        # Identity discriminator
        id_real_loss = -torch.mean(id_D(faces, target_faces))  # big for real
        faces_fake = G(faces, target_points)
        id_fake_loss = torch.mean(id_D(faces, faces_fake))  # small for fake

        # Compute loss for gradient penalty.
        alpha = torch.rand(target_faces.size(0), 1, 1, 1).to(device)
        x_hat = (alpha * target_faces.data + (1 - alpha) * faces_fake.data).requires_grad_(True)
        out_src = id_D(faces, x_hat)
        id_d_loss_gp = gradient_penalty(out_src, x_hat)

        id_lambda_gp = 10
        id_Dis_loss = id_real_loss + id_fake_loss + id_lambda_gp * id_d_loss_gp
        idD_optimizer.zero_grad()
        id_Dis_loss.backward()
        idD_optimizer.step()

        # Keypoint discriminator
        kp_real_loss = -torch.mean(kp_D(target_faces, target_points))  # big for real
        faces_fake = G(faces, target_points)
        points_fake = points_G(faces_fake)
        kp_fake_loss = torch.mean(kp_D(target_faces, points_fake))  # small for fake

        # Compute loss for gradient penalty.
        alpha = torch.rand(target_faces.size(0), 1, 1, 1).to(device)
        x_hat = (alpha * target_points.data + (1 - alpha) * points_fake.data).requires_grad_(True)
        out_src = kp_D(target_faces, x_hat)
        kp_d_loss_gp = gradient_penalty(out_src, x_hat)

        kp_lambda_gp = 10
        kp_Dis_loss = kp_real_loss + kp_fake_loss + kp_lambda_gp * kp_d_loss_gp
        kp_D_optimizer.zero_grad()
        kp_Dis_loss.backward()
        kp_D_optimizer.step()

        # =================================================================== #
        #                   3. Train the keypoint detector                    #
        # =================================================================== #
        points_detect = points_G(faces)
        detecter_loss_clear = torch.mean(torch.abs(points_detect - origin_points))
        detecter_loss = detecter_loss_clear
        points_G_optimizer.zero_grad()
        detecter_loss.backward()
        points_G_optimizer.step()

        # =================================================================== #
        #                       4. Train the generator                        #
        # =================================================================== #
        n_critic = 4
        if (i + 1) % n_critic == 0:
            # Original-to-target domain.
            faces_fake = G(faces, target_points)
            predict_points = points_G(faces_fake)
            g_keypoints_loss = -torch.mean(kp_D(target_faces, predict_points))
            g_fake_loss = -torch.mean(D(faces_fake))
            # reconstructs = G(faces_fake, origin_points)
            # g_cycle_loss = torch.mean(torch.abs(reconstructs - faces))
            g_id_loss = -torch.mean(id_D(faces, faces_fake))
            l1_loss = torch.mean(torch.abs(faces_fake - target_faces))
            # feature_loss = torch.mean(torch.abs(FEN(faces_fake) - FEN(target_faces)))

            # Alternate training (disabled):
            # if (i+1) % 50 == 0:
            #     trigger_rec = 1 - trigger_rec
            #     print("trigger_rec : ", trigger_rec)
            lambda_rec = config.lambda_rec  # 2 to 4 to 8
            lambda_l1 = config.lambda_l1
            lambda_keypoint = config.lambda_keypoint  # 100 to 50
            lambda_fake = config.lambda_fake
            lambda_id = config.lambda_id
            lambda_feature = config.lambda_feature
            g_loss = (lambda_keypoint * g_keypoints_loss + lambda_fake * g_fake_loss
                      + lambda_id * g_id_loss + lambda_l1 * l1_loss)  # + lambda_feature * feature_loss

            G_optimizer.zero_grad()
            g_loss.backward()
            G_optimizer.step()

            # Print out training information.
            if (i + 1) % 4 == 0:
                print("iter {} - d_real_loss {:.2}, d_fake_loss {:.2}, d_loss_gp {:.2}, id_real_loss {:.2}, "
                      "id_fake_loss {:.2}, id_loss_gp {:.2}, g_keypoints_loss {:.2}, "
                      "g_fake_loss {:.2}, g_id_loss {:.2}, L1_loss {:.2}".format(
                          i, real_loss.item(), fake_loss.item(),
                          lambda_gp * d_loss_gp.item(),
                          id_real_loss.item(), id_fake_loss.item(),
                          id_lambda_gp * id_d_loss_gp.item(),
                          lambda_keypoint * g_keypoints_loss.item(),
                          lambda_fake * g_fake_loss.item(),
                          lambda_id * g_id_loss.item(),
                          lambda_l1 * l1_loss.item()))

            sample_dir = "gan-sample-{}".format(version)
            if not os.path.isdir(sample_dir):
                os.mkdir(sample_dir)
            if (i + 1) % 24 == 0:
                with torch.no_grad():
                    target_point = target_points[0]
                    fake_face = faces_fake[0]
                    face = faces[0]
                    # reconstruct = reconstructs[0]
                    predict_point = predict_points[0]
                    sample_path_face = os.path.join(sample_dir, '{}-image-face.jpg'.format(i + 1))
                    save_image(denorm(face.data.cpu()), sample_path_face)
                    # sample_path_rec = os.path.join(sample_dir, '{}-image-reconstruct.jpg'.format(i + 1))
                    # save_image(denorm(reconstruct.data.cpu()), sample_path_rec)
                    sample_path_fake = os.path.join(sample_dir, '{}-image-fake.jpg'.format(i + 1))
                    save_image(denorm(fake_face.data.cpu()), sample_path_fake)
                    sample_path_target = os.path.join(sample_dir, '{}-image-target_point.jpg'.format(i + 1))
                    save_image(denorm(target_point.data.cpu()), sample_path_target)
                    sample_path_predict_points = os.path.join(sample_dir, '{}-image-predict_point.jpg'.format(i + 1))
                    save_image(denorm(predict_point.data.cpu()), sample_path_predict_points)
                    print('Saved real and fake images into {}...'.format(sample_path_face))

        # Save model checkpoints.
        model_save_dir = "ckpt-{}".format(version)
        if (i + 1) % 1000 == 0:
            if not os.path.isdir(model_save_dir):
                os.mkdir(model_save_dir)
            point_G_path = os.path.join(model_save_dir, '{}-pG.ckpt'.format(i + 1))
            torch.save(points_G.state_dict(), point_G_path)
            kp_D_path = os.path.join(model_save_dir, '{}-kpD.ckpt'.format(i + 1))
            torch.save(kp_D.state_dict(), kp_D_path)
            G_path = os.path.join(model_save_dir, '{}-G.ckpt'.format(i + 1))
            torch.save(G.state_dict(), G_path)
            D_path = os.path.join(model_save_dir, '{}-D.ckpt'.format(i + 1))
            torch.save(D.state_dict(), D_path)
            idD_path = os.path.join(model_save_dir, '{}-idD.ckpt'.format(i + 1))
            torch.save(id_D.state_dict(), idD_path)
            print('Saved model checkpoints into {}...'.format(model_save_dir))
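# `gradient_penalty(out_src, x_hat)` above is defined elsewhere in the repo. A
# plausible WGAN-GP implementation matching that call signature (an assumption,
# mirroring the common StarGAN-style helper): it penalizes the squared deviation
# of the critic's gradient norm from 1 on the interpolated samples x_hat.
def gradient_penalty(y, x):
    """Compute E[(||dy/dx||_2 - 1)^2] for outputs y with respect to inputs x."""
    weight = torch.ones(y.size()).to(device)
    dydx = torch.autograd.grad(outputs=y,
                               inputs=x,
                               grad_outputs=weight,
                               retain_graph=True,
                               create_graph=True,
                               only_inputs=True)[0]
    dydx = dydx.view(dydx.size(0), -1)
    dydx_l2norm = torch.sqrt(torch.sum(dydx ** 2, dim=1))
    return torch.mean((dydx_l2norm - 1) ** 2)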
def main(config):
    # For fast training.
    cudnn.benchmark = True

    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    if len(sys.argv) > 1:
        os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = "5"
    global device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    version = config.version
    beta1 = 0.5
    beta2 = 0.999
    loader = data_loader.get_loader(
        "/media/data2/laixc/AI_DATA/expression_transfer/face12/crop_face",
        "/media/data2/laixc/AI_DATA/expression_transfer/face12/points_face",
        config)
    G = ExpressionGenerater()
    FEN = FeatureExtractNet()
    color_D = SNResIdDiscriminator()

    ####### Load the pretrained networks ######
    resume_iter = config.resume_iter
    ckpt_dir = "/media/data2/laixc/Facial_Expression_GAN/ckpt-{}".format(version)
    if not os.path.isdir(ckpt_dir):
        os.mkdir(ckpt_dir)
    log = Logger(os.path.join(ckpt_dir, 'log.txt'))
    if os.path.exists(os.path.join(ckpt_dir, '{}-G.ckpt'.format(resume_iter))):
        G_path = os.path.join(ckpt_dir, '{}-G.ckpt'.format(resume_iter))
        G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
        IdD_path = os.path.join(ckpt_dir, '{}-idD.ckpt'.format(resume_iter))
        color_D.load_state_dict(torch.load(IdD_path, map_location=lambda storage, loc: storage))
    else:
        resume_iter = 0

    ##### Optimizers ####
    G_optimizer = torch.optim.Adam(G.parameters(), lr=0.0001, betas=(0.5, 0.9))
    idD_optimizer = torch.optim.Adam(color_D.parameters(), lr=0.001, betas=(0.5, 0.9))

    G.to(device)
    color_D.to(device)
    FEN.to(device)
    FEN.eval()

    log.print(config)

    # Start training from scratch or resume training.
    start_iters = resume_iter
    trigger_rec = 1
    data_iter = iter(loader)

    # Start training.
    print('Start training...')
    for i in range(start_iters, 150000):
        # =================================================================== #
        #                      1. Preprocess input data                       #
        # =================================================================== #
        try:
            color_faces, faces = next(data_iter)
        except StopIteration:
            data_iter = iter(loader)
            color_faces, faces = next(data_iter)
        rand_idx = torch.randperm(faces.size(0))
        condition_faces = faces[rand_idx]
        faces = faces.to(device)
        color_faces = color_faces.to(device)
        condition_faces = condition_faces.to(device)

        # =================================================================== #
        #                     2. Train the discriminator                      #
        # =================================================================== #
        # Identity discriminator (non-saturating softplus loss)
        id_real_loss = torch.mean(softplus(-color_D(faces, condition_faces)))  # big for real
        faces_fake = G(color_faces, condition_faces)
        id_fake_loss = torch.mean(softplus(color_D(faces_fake, condition_faces)))  # small for fake

        id_Dis_loss = id_real_loss + id_fake_loss
        idD_optimizer.zero_grad()
        id_Dis_loss.backward()
        idD_optimizer.step()

        # =================================================================== #
        #                       3. Train the generator                        #
        # =================================================================== #
        n_critic = 1
        if (i + 1) % n_critic == 0:
            # Original-to-target domain.
            faces_fake = G(color_faces, condition_faces)
            g_id_loss = torch.mean(softplus(-color_D(faces_fake, condition_faces)))
            l1_loss = torch.mean(torch.abs(faces_fake - faces)) + (1 - ssim(faces_fake, faces))
            feature_loss = torch.mean(torch.abs(FEN(faces_fake) - FEN(faces)))

            lambda_l1 = config.lambda_l1
            lambda_id = config.lambda_id
            lambda_feature = config.lambda_feature
            g_loss = lambda_id * g_id_loss + lambda_l1 * l1_loss + lambda_feature * feature_loss

            G_optimizer.zero_grad()
            g_loss.backward()
            G_optimizer.step()

            # Print out training information.
            if (i + 1) % 4 == 0:
                log.print("iter {} - id_real_loss {:.2}, "
                          "id_fake_loss {:.2}, g_id_loss {:.2}, L1_loss {:.2}, feature_loss {:.2}".format(
                              i, id_real_loss.item(), id_fake_loss.item(),
                              lambda_id * g_id_loss.item(),
                              lambda_l1 * l1_loss.item(),
                              lambda_feature * feature_loss.item()))

            sample_dir = "gan-sample-{}".format(version)
            if not os.path.isdir(sample_dir):
                os.mkdir(sample_dir)
            if (i + 1) % 24 == 0:
                with torch.no_grad():
                    fake_face = faces_fake[0]
                    condition_face = condition_faces[0]
                    color_face = color_faces[0]
                    # reconstruct = reconstructs[0]
                    sample_path_face = os.path.join(sample_dir, '{}-image-face.jpg'.format(i + 1))
                    save_image(denorm(condition_face.data.cpu()), sample_path_face)
                    sample_path_rec = os.path.join(sample_dir, '{}-image-color.jpg'.format(i + 1))
                    save_image(denorm(color_face.data.cpu()), sample_path_rec)
                    sample_path_fake = os.path.join(sample_dir, '{}-image-fake.jpg'.format(i + 1))
                    save_image(denorm(fake_face.data.cpu()), sample_path_fake)
                    print('Saved real and fake images into {}...'.format(sample_path_face))

        # Save model checkpoints.
        model_save_dir = "ckpt-{}".format(version)
        if (i + 1) % 1000 == 0:
            if not os.path.isdir(model_save_dir):
                os.mkdir(model_save_dir)
            G_path = os.path.join(model_save_dir, '{}-G.ckpt'.format(i + 1))
            torch.save(G.state_dict(), G_path)
            idD_path = os.path.join(model_save_dir, '{}-idD.ckpt'.format(i + 1))
            torch.save(color_D.state_dict(), idD_path)
            print('Saved model checkpoints into {}...'.format(model_save_dir))
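# `Logger` is imported from elsewhere in the repo. A minimal sketch consistent
# with how it is used above (log.print(...) echoing to stdout and appending to
# log.txt); the real utility may differ:
class Logger:
    """Echo messages to stdout and append them to a log file."""

    def __init__(self, path):
        self.path = path

    def print(self, *args):
        msg = " ".join(str(a) for a in args)
        print(msg)
        with open(self.path, 'a') as f:
            f.write(msg + "\n")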
class VideoGenerator():
    def __init__(self, config, json_file):
        # For fast training.
        self.config = config
        cudnn.benchmark = True

        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        if len(sys.argv) > 1:
            os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu
        else:
            os.environ["CUDA_VISIBLE_DEVICES"] = "5"
        global device
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.version = config.version
        self.color_version = config.color_version
        self.sf_version = config.sf_version
        self.point_threshold = config.point_threshold
        self.G = ExpressionGenerater()
        self.sfG = ExpressionGenerater()
        self.colorG = NoTrackExpressionGenerater()

        ####### Load the pretrained networks ######
        self.load()
        self.G.to(device)
        self.colorG.to(device)
        if config.eval == "1":
            self.G.eval()
            self.colorG.eval()

        self.transform = []
        self.transform.append(T.Resize((224, 224)))
        self.transform.append(T.ToTensor())
        self.transform.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
        self.transform = T.Compose(self.transform)

        self.vid_annos = self.load_json(json_file)

    def load(self):
        resume_iter = self.config.resume_iter
        color_resume_iter = self.config.color_resume_iter
        ckpt_dir = "/media/data2/laixc/Facial_Expression_GAN/ckpt-{}".format(self.version)
        sf_ckpt_dir = "/media/data2/laixc/Facial_Expression_GAN/ckpt-{}".format(self.sf_version)
        color_ckpt_dir = "/media/data2/laixc/Facial_Expression_GAN/ckpt-{}".format(self.color_version)
        if os.path.exists(os.path.join(ckpt_dir, '{}-G.ckpt'.format(resume_iter))):
            G_path = os.path.join(ckpt_dir, '{}-G.ckpt'.format(resume_iter))
            self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
            G_path = os.path.join(sf_ckpt_dir, '{}-G.ckpt'.format(121000))
            self.sfG.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
            cG_path = os.path.join(color_ckpt_dir, '{}-G.ckpt'.format(color_resume_iter))
            self.colorG.load_state_dict(torch.load(cG_path, map_location=lambda storage, loc: storage))
            print("load ckpt")
        else:
            print("found no ckpt")
            return None

    ############# process data ##########
    def crop_face(self, img, bbox, keypoint):
        # Some images contain no detectable face.
        if len(bbox) == 0:
            return None
        x, y, w, h = [int(v) for v in bbox]
        return img[y:y + h, x:x + w]

    def load_json(self, path):
        with open(path, 'r') as f:
            data = json.load(f)
        print("load %d video annotations totally." % len(data))
        vid_anns = dict()
        for anns in data:
            name = anns['video']
            path = anns['video_path']
            vid_anns[name] = {'path': path}
            for ann in anns['annotations']:
                idx = ann['index']
                keypoints = ann['keypoint']
                bbox = ann['bbox']
                vid_anns[name][idx] = [bbox, keypoints]
        return vid_anns

    def draw_rotate_keypoints(self, img, bbox, keypoint, first_bbox, first_keypoint):
        flags = list()
        points = list()
        first_flags = list()
        first_points = list()
        # Some images contain no detectable face.
        if len(bbox) == 0:
            return None, None
        points_image = np.zeros((224, 224, 3), np.uint8)
        # Draw the keypoints, normalized to the bbox.
        bx, by, bw, bh = [int(v) for v in bbox]
        for i in range(0, len(keypoint), 3):
            x, y, flag = [int(k) for k in keypoint[i:i + 3]]
            x = int((x - bx) / bw * 224)
            y = int((y - by) / bh * 224)
            flags.append(flag)
            points.append([x, y])
            if flag == 0:  # keypoint does not exist
                continue
            elif flag == 1 or flag == 2:  # keypoint exists (invisible or visible)
                cv2.circle(points_image, (x, y), 2, (0, 255, 0), -1)
            else:
                raise ValueError("flag of keypoint must be 0, 1, or 2.")
        fbx, fby, fbw, fbh = [int(v) for v in first_bbox]
        for i in range(0, len(first_keypoint), 3):
            x, y, flag = [int(k) for k in first_keypoint[i:i + 3]]
            x = int((x - fbx) / fbw * 224)
            y = int((y - fby) / fbh * 224)
            first_flags.append(flag)
            first_points.append([x, y])

        # Face alignment.
        # Point 52 is the leftmost point of the left eye, point 61 the rightmost
        # point of the right eye; x runs along the width, y along the height.
        if not (flags[52] == 0 or flags[61] == 0):
            left_x, left_y = points[52]
            right_x, right_y = points[61]
            if left_x > 224 / 4 or right_x < 224 - 224 / 4:
                return np.asarray(points_image), 0
            deltaH = right_y - left_y
            deltaW = right_x - left_x
            if math.sqrt(deltaW ** 2 + deltaH ** 2) < 1:
                return np.asarray(points_image), 0
            angle = math.asin(deltaH / math.sqrt(deltaW ** 2 + deltaH ** 2))
            angle = angle / math.pi * 180  # radians to degrees
            # Compute the angle of the first frame.
            first_angle = 0
            if not (first_flags[52] == 0 or first_flags[61] == 0):
                left_x, left_y = first_points[52]
                right_x, right_y = first_points[61]
                deltaH = right_y - left_y
                deltaW = right_x - left_x
                if not math.sqrt(deltaW ** 2 + deltaH ** 2) < 1:
                    first_angle = math.asin(deltaH / math.sqrt(deltaW ** 2 + deltaH ** 2))
                    first_angle = first_angle / math.pi * 180  # radians to degrees
            angle = angle - first_angle
            if abs(angle) < 5:
                return np.asarray(points_image), 0
            points_image = Image.fromarray(np.uint8(points_image))
            points_image = points_image.rotate(angle)
        else:
            angle = 0
        return np.asarray(points_image), angle

    def draw_bbox_keypoints(self, img, bbox, keypoint):
        flags = list()
        points = list()
        # Some images contain no detectable face.
        if len(bbox) == 0:
            return None, None
        points_image = np.zeros((224, 224, 3), np.uint8)
        # Draw the keypoints, normalized to the bbox.
        bx, by, bw, bh = [int(v) for v in bbox]
        for i in range(0, len(keypoint), 3):
            x, y, flag = [int(k) for k in keypoint[i:i + 3]]
            x = int((x - bx) / bw * 224)
            y = int((y - by) / bh * 224)
            flags.append(flag)
            points.append([x, y])
            if flag == 0:  # keypoint does not exist
                continue
            elif flag == 1 or flag == 2:  # keypoint exists (invisible or visible)
                cv2.circle(points_image, (x, y), 2, (0, 255, 0), -1)
            else:
                raise ValueError("flag of keypoint must be 0, 1, or 2.")
        return points_image, self.crop_face(img, bbox, keypoint)

    def extract_image(self, img, bbox):
        # Knock the face region out of the frame.
        Knockout_image = img.copy()
        x, y, w, h = [int(v) for v in bbox]
        cv2.rectangle(Knockout_image, (x, y), (x + w, y + h), (0, 0, 0), cv2.FILLED)
        return Knockout_image, img

    def generate(self, first_frm_file, bg_file, alpha_file, body_file,
                 first_frm_id, special_vids, hard2middle):
        first_frm = Image.open(first_frm_file)
        bg_body = np.array(first_frm)
        background = np.asarray(Image.open(bg_file))
        body = np.asarray(Image.open(body_file))
        body_alpha = np.load(alpha_file)
        vid_name = os.path.basename(first_frm_file)
        vid_name = "{}.mp4".format(vid_name.split(".")[0])
        anno = self.vid_annos[vid_name]
        first_bbox, first_keypoint = anno[1]

        # Create the result directories.
        if not os.path.isdir("test_result"):
            os.mkdir("test_result")
        sample_dir = "test_result/gan-sample-{}-{}".format(self.version, self.config.resume_iter)
        if not os.path.isdir(sample_dir):
            os.mkdir(sample_dir)

        # Set up the video writer.
        frm = cv2.imread(first_frm_file)
        height, width, channels = frm.shape
        size = (width, height)
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        out_path = os.path.join(sample_dir, '{}.mp4'.format(first_frm_id))
        vid_writer = cv2.VideoWriter(out_path, fourcc, 25, size)

        first_face_crop = None
        is_first = True
        fake_frm = None
        angle = None
        G = None

        # Special-case videos.
        if first_frm_id.startswith("12") and int(first_frm_id) in special_vids:
            print("special {}-->{}".format(first_frm_id, hard2middle[first_frm_id]))
            vid_name = "{}.mp4".format(hard2middle[first_frm_id])
            temp_anno = self.vid_annos[vid_name]
            level2_first_bbox, level2_first_keypoint = temp_anno[1]
            level2_first_frm = os.path.join(first_frm_file[:-9], "{}.jpg".format(int(first_frm_id) - 1000))
            level2_first_frm = Image.open(level2_first_frm)
            _, first_face_crop = self.draw_bbox_keypoints(
                np.asarray(level2_first_frm), level2_first_bbox, level2_first_keypoint)

        match_cnt = 0
        rotate_match_cnt = 0

        # ############################## #
        # Main loop: generate the video  #
        # ############################## #
        for idx in tqdm(range(len(anno) - 1)):
            bbox, keypoint = anno[idx + 1]
            if len(bbox) >= 2:
                bbox[0] = max(0, bbox[0])
                bbox[1] = max(0, bbox[1])
            if is_first:
                is_first = False
                _, real_first_face_crop = self.draw_bbox_keypoints(np.asarray(first_frm), bbox, keypoint)
                if first_face_crop is None:
                    _, first_face_crop = self.draw_bbox_keypoints(np.asarray(first_frm), bbox, keypoint)
                if len(bbox) == 0:
                    vid_writer.write(frm)
                    is_first = True  # reset the "first frame" state
                    continue
                if len(first_bbox) == 0:
                    first_bbox = bbox
                    first_keypoint = keypoint
                vid_writer.write(cv2.cvtColor(np.asarray(first_frm), cv2.COLOR_RGB2BGR))
                continue
            else:
                # Keypoint map for frame idx.
                draw_kp, angle = self.draw_rotate_keypoints(
                    np.asarray(first_frm), bbox, keypoint, first_bbox, first_keypoint)
                draw_kp_noratota, _ = self.draw_bbox_keypoints(np.asarray(first_frm), bbox, keypoint)
                if draw_kp is None or first_face_crop is None:
                    print("no bbox")
                    if first_face_crop is None:
                        vid_writer.write(frm)
                        continue
                    if fake_frm is None:
                        vid_writer.write(frm)
                    else:
                        vid_writer.write(fake_frm)
                    continue

                first_face_crop_arr = Image.fromarray(first_face_crop, 'RGB')
                first_face_tensor = self.transform(first_face_crop_arr)
                first_face_tensor = first_face_tensor.unsqueeze(0)
                first_draw_kp, _ = self.draw_bbox_keypoints(np.asarray(first_frm), first_bbox, first_keypoint)

                if angle == 360:
                    angle = 0
                    G = self.sfG
                else:
                    G = self.G

                # If the target keypoints are very close to the first frame's,
                # reuse the original face instead of generating one.
                point_threshold = self.point_threshold
                if np.mean(first_draw_kp - draw_kp_noratota) < point_threshold:
                    match_cnt += 1
                    fake_frm = knn_fusion(bg_body, background, body, body_alpha,
                                          first_bbox, np.asarray(first_face_crop), bbox)
                    fake_frm = cv2.cvtColor(np.asarray(fake_frm), cv2.COLOR_RGB2BGR)
                    fake_frm = cv2.resize(fake_frm, (width, height))
                    vid_writer.write(fake_frm)
                    continue

                # Generate one frame with the model.
                def generate_frm(draw_kp, first_face_tensor):
                    img_kp = Image.fromarray(draw_kp, 'RGB')
                    key_points = self.transform(img_kp)
                    key_points = key_points.unsqueeze(0)
                    first_face_tensor = first_face_tensor.to(device)
                    key_points = key_points.to(device)
                    face_fake = G(first_face_tensor, key_points)
                    return face_fake

                def to_PIL(face_fake):
                    frm = denorm(face_fake.data.cpu())
                    frm = frm[0]
                    toPIL = T.ToPILImage()
                    return toPIL(frm)

                # Likewise, reuse the original face if the rotated keypoints match.
                if np.mean(first_draw_kp - draw_kp) < point_threshold:
                    rotate_match_cnt += 1
                    frm = Image.fromarray(np.array(first_face_crop))
                    frm = frm.resize(size=(224, 224))
                else:
                    face_fake = generate_frm(draw_kp, first_face_tensor)
                    frm = to_PIL(face_fake)
                    # Handle rotation.
                    if angle != 0:
                        face_fake_norotate = generate_frm(draw_kp_noratota, first_face_tensor)
                        frm = frm.rotate(-angle)
                        norotate_frm = to_PIL(face_fake_norotate)
                        # Blend the rotated face with the unrotated one to fill
                        # the black borders introduced by the rotation.
                        frm = np.array(frm)
                        frm.flags.writeable = True
                        norotate_frm = np.asarray(norotate_frm)
                        mask = np.mean(frm, axis=2)
                        frm[:, :, 0] = np.where(mask < 0.1, norotate_frm[:, :, 0], frm[:, :, 0])
                        frm[:, :, 1] = np.where(mask < 0.1, norotate_frm[:, :, 1], frm[:, :, 1])
                        frm[:, :, 2] = np.where(mask < 0.1, norotate_frm[:, :, 2], frm[:, :, 2])
                        frm = Image.fromarray(frm)

                if self.config.color_transfer == "1":
                    fake_face = cv2.cvtColor(np.asarray(frm), cv2.COLOR_RGB2BGR)
                    first_face = cv2.cvtColor(np.asarray(Image.fromarray(first_face_crop, 'RGB')),
                                              cv2.COLOR_RGB2BGR)
                    frm = color_transfer(first_face, fake_face)
                    frm = cv2.cvtColor(np.asarray(frm), cv2.COLOR_BGR2RGB)
                if self.config.sharpen == "1":
                    kernel = np.array([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]) * self.config.sharpen_lambda
                    kernel[1, 1] = kernel[1, 1] + 1
                    frm = cv2.filter2D(np.asarray(frm), -1, kernel)

                ###### Fusion
                # fake_frm = fusion(first_img, first_bbox, np.asarray(frm), bbox)
                fake_frm = knn_fusion(bg_body, background, body, body_alpha,
                                      first_bbox, np.asarray(frm), bbox)
                # fake_frm = cv2.cvtColor(np.asarray(fake_frm), cv2.COLOR_RGB2BGR)
                fake_frm = cv2.resize(fake_frm, (width, height))
                vid_writer.write(fake_frm)

        vid_writer.release()
        print("match_cnt {}, rotate_match_cnt {}".format(match_cnt, rotate_match_cnt))
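# `color_transfer(first_face, fake_face)` above is defined elsewhere in the
# repo. A minimal sketch of the usual Reinhard-style statistics transfer in
# L*a*b* space (an assumption about the actual implementation): it shifts the
# per-channel mean/std of the generated face toward the reference face, which
# matches how the call site passes BGR uint8 inputs and expects a BGR result.
def color_transfer(source, target):
    """Match the L*a*b* channel statistics of `target` to those of `source`."""
    src = cv2.cvtColor(source, cv2.COLOR_BGR2LAB).astype(np.float32)
    tgt = cv2.cvtColor(target, cv2.COLOR_BGR2LAB).astype(np.float32)
    s_mean, s_std = src.mean(axis=(0, 1)), src.std(axis=(0, 1))
    t_mean, t_std = tgt.mean(axis=(0, 1)), tgt.std(axis=(0, 1))
    result = (tgt - t_mean) / (t_std + 1e-6) * s_std + s_mean
    result = np.clip(result, 0, 255).astype(np.uint8)
    return cv2.cvtColor(result, cv2.COLOR_LAB2BGR)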