def main(path_to_e_hat_video='e_hat_video.tar',
         path_to_e_hat_images='e_hat_images.tar',
         path_to_chkpt='model_weights.tar',
         path_to_video='examples/fine_tuning/test_video.mp4',
         path_to_images='examples/fine_tuning/test_images'):
    """Hyperparameters and config"""
    device = torch.device("cuda:0")
    cpu = torch.device("cpu")
    T = 32

    """Loading Embedder input"""
    frame_mark_video = select_frames(path_to_video, T)
    frame_mark_video = generate_cropped_landmarks(frame_mark_video, pad=50)
    frame_mark_video = torch.from_numpy(np.array(frame_mark_video)).type(
        dtype=torch.float)  # T,2,256,256,3
    frame_mark_video = frame_mark_video.transpose(2, 4).to(device)  # T,2,3,256,256
    f_lm_video = frame_mark_video.unsqueeze(0)  # 1,T,2,3,256,256

    frame_mark_images = select_images_frames(path_to_images)
    frame_mark_images = generate_cropped_landmarks(frame_mark_images, pad=50)
    frame_mark_images = torch.from_numpy(np.array(frame_mark_images)).type(
        dtype=torch.float)  # T,2,256,256,3
    frame_mark_images = frame_mark_images.transpose(2, 4).to(device)  # T,2,3,256,256
    f_lm_images = frame_mark_images.unsqueeze(0)  # 1,T,2,3,256,256

    E = Embedder(256).to(device)
    E.eval()

    """Loading from past checkpoint"""
    checkpoint = torch.load(path_to_chkpt, map_location=cpu)
    E.load_state_dict(checkpoint['E_state_dict'])

    """Inference"""
    with torch.no_grad():
        # Calculate average encoding vector for video
        f_lm = f_lm_video
        f_lm_compact = f_lm.view(-1, f_lm.shape[-4], f_lm.shape[-3],
                                 f_lm.shape[-2], f_lm.shape[-1])  # BxT,2,3,256,256
        e_vectors = E(f_lm_compact[:, 0, :, :, :],
                      f_lm_compact[:, 1, :, :, :])  # BxT,512,1
        e_vectors = e_vectors.view(-1, f_lm.shape[1], 512, 1)  # B,T,512,1
        e_hat_video = e_vectors.mean(dim=1)

        # Calculate average encoding vector for images
        f_lm = f_lm_images
        f_lm_compact = f_lm.view(-1, f_lm.shape[-4], f_lm.shape[-3],
                                 f_lm.shape[-2], f_lm.shape[-1])  # BxT,2,3,256,256
        e_vectors = E(f_lm_compact[:, 0, :, :, :],
                      f_lm_compact[:, 1, :, :, :])  # BxT,512,1
        e_vectors = e_vectors.view(-1, f_lm.shape[1], 512, 1)  # B,T,512,1
        e_hat_images = e_vectors.mean(dim=1)

    print('Saving e_hat...')
    torch.save({'e_hat': e_hat_video}, path_to_e_hat_video)
    torch.save({'e_hat': e_hat_images}, path_to_e_hat_images)
    print('...Done saving')
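# Hedged usage sketch for main() above: the paths are the function's defaults
# and are assumed to exist on disk; torch is assumed imported at module level.
# The saved archives are plain torch.save dicts, so the averaged embeddings
# can be reloaded the same way they were written.
if __name__ == '__main__':
    main()  # writes e_hat_video.tar and e_hat_images.tar

    e_hat_video = torch.load('e_hat_video.tar', map_location='cpu')['e_hat']    # B,512,1
    e_hat_images = torch.load('e_hat_images.tar', map_location='cpu')['e_hat']  # B,512,1
    print(e_hat_video.shape, e_hat_images.shape)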
def embed(path_to_video, lmarks, path_to_chkpt=None, ckpt=None, T=1, E=None,
          ref_ids=None):
    print("embedding...")

    """Loading Embedder input"""
    # Assumes the input path ends in png/jpg (single image) or mp4 (video).
    if path_to_video.split('.')[-1] in ['png', 'jpg']:
        ref_img = cv2.imread(path_to_video)
        frame_mark_video = [cv2.cvtColor(ref_img, cv2.COLOR_BGR2RGB)]
        lmarks_list = [lmarks.copy()]
    elif path_to_video.split('.')[-1] in ['mp4']:
        frame_mark_video, lmarks_list = select_frames(
            path_to_video, T, lmarks=lmarks.copy(), specific_frames=ref_ids)
    inference_img = frame_mark_video
    reference_lmarks = lmarks_list
    frame_mark_video = generate_cropped_landmarks(frame_mark_video, pad=50)
    frame_mark_video = torch.from_numpy(np.array(frame_mark_video)).type(
        dtype=torch.float)  # T,2,256,256,3
    frame_mark_video = frame_mark_video.transpose(2, 4).to(device)  # T,2,3,256,256
    f_lm_video = frame_mark_video.unsqueeze(0)  # 1,T,2,3,256,256

    if E is None:
        E = Embedder(256).to(device)
        E.eval()

        """Loading from past checkpoint"""
        if path_to_chkpt is not None:
            checkpoint = torch.load(path_to_chkpt, map_location=cpu)
        else:
            assert ckpt is not None
            checkpoint = copy.deepcopy(ckpt)
        E.load_state_dict(checkpoint['E_state_dict'])
    else:
        E.eval()

    """Inference"""
    with torch.no_grad():
        # Calculate average encoding vector for video
        f_lm = f_lm_video
        f_lm_compact = f_lm.view(-1, f_lm.shape[-4], f_lm.shape[-3],
                                 f_lm.shape[-2], f_lm.shape[-1])  # BxT,2,3,256,256
        e_vectors = E(f_lm_compact[:, 0, :, :, :],
                      f_lm_compact[:, 1, :, :, :])  # BxT,512,1
        e_vectors = e_vectors.view(-1, f_lm.shape[1], 512, 1)  # B,T,512,1
        e_hat_video = e_vectors.mean(dim=1)

    return e_hat_video, inference_img, reference_lmarks
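# Minimal usage sketch for embed() above. embed() relies on module-level
# globals `device` and `cpu`, which are defined here; the file paths and the
# landmark array are hypothetical placeholders.
import numpy as np
import torch

if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    cpu = torch.device("cpu")

    lmarks = np.load('example_lmarks.npy')  # hypothetical per-frame landmarks

    # Single reference image: one frame/landmark pair is embedded.
    e_hat, ref_imgs, ref_lmarks = embed('example_face.png', lmarks,
                                        path_to_chkpt='model_weights.tar')

    # Video: the embedding is averaged over T frames chosen by ref_ids.
    e_hat, ref_imgs, ref_lmarks = embed('example_face.mp4', lmarks,
                                        path_to_chkpt='model_weights.tar',
                                        T=8, ref_ids=list(range(8)))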
def main():
    args = parse_args()

    """Hyperparameters and config"""
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    cpu = torch.device("cpu")
    path_to_e_hat_video = args.output
    path_to_video = args.video
    T = 32
    face_aligner = face_alignment.FaceAlignment(
        face_alignment.LandmarksType._2D,
        flip_input=False,
        device='cuda' if use_cuda else 'cpu')

    """Loading Embedder input"""
    frame_mark_video = select_frames(path_to_video, T)
    frame_mark_video = generate_cropped_landmarks(frame_mark_video, pad=50,
                                                  fa=face_aligner)
    frame_mark_video = torch.from_numpy(np.array(frame_mark_video)).type(
        dtype=torch.float)  # T,2,256,256,3
    frame_mark_video = frame_mark_video.transpose(2, 4).to(device) / 255  # T,2,3,256,256
    f_lm_video = frame_mark_video.unsqueeze(0)  # 1,T,2,3,256,256

    E = Embedder(256).to(device)
    E.eval()

    """Loading from past checkpoint"""
    checkpoint = torch.load(args.model, map_location=cpu)
    E.load_state_dict(checkpoint['E_state_dict'])

    """Inference"""
    with torch.no_grad():
        # Calculate average encoding vector for video
        f_lm = f_lm_video
        f_lm_compact = f_lm.view(-1, f_lm.shape[-4], f_lm.shape[-3],
                                 f_lm.shape[-2], f_lm.shape[-1])  # BxT,2,3,256,256
        print('Run inference...')
        e_vectors = E(f_lm_compact[:, 0, :, :, :],
                      f_lm_compact[:, 1, :, :, :])  # BxT,512,1
        e_vectors = e_vectors.view(-1, f_lm.shape[1], 512, 1)  # B,T,512,1
        e_hat_video = e_vectors.mean(dim=1)

    print('Saving e_hat...')
    torch.save({'e_hat': e_hat_video}, path_to_e_hat_video)
    print('...Done saving')
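# parse_args() is referenced above but not shown. A plausible argparse sketch,
# assuming only the three attributes main() actually reads (video, model,
# output); flag names and defaults are assumptions:
import argparse

def parse_args():
    parser = argparse.ArgumentParser(
        description='Compute an averaged identity embedding e_hat from a video.')
    parser.add_argument('--video', required=True, help='input .mp4 file')
    parser.add_argument('--model', default='model_weights.tar',
                        help='checkpoint containing E_state_dict')
    parser.add_argument('--output', default='e_hat_video.tar',
                        help='where to save the {"e_hat": ...} archive')
    return parser.parse_args()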
def __getitem__(self, idx):
    path = self.path_to_video
    frame_has_face = False
    while not frame_has_face:
        try:
            frame_mark = select_frames(path, 1)
            frame_mark = generate_cropped_landmarks(frame_mark, pad=50,
                                                    fa=self.fa)
            frame_has_face = True
        except:
            print('No face detected, retrying')
    frame_mark = torch.from_numpy(np.array(frame_mark)).type(
        dtype=torch.float)  # 1,2,256,256,3
    frame_mark = frame_mark.transpose(2, 4).to(self.device)  # 1,2,3,256,256
    x = frame_mark[0, 0].squeeze()
    g_y = frame_mark[0, 1].squeeze()
    return x, g_y
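# The __getitem__ above belongs to a Dataset subclass that is not shown. A
# minimal sketch of the surrounding class, assuming only the attributes the
# method touches (path_to_video, fa, device); the class name and __len__ are
# hypothetical, since each item re-samples a random frame from one video:
import torch.utils.data

class FineTuningVideoDataset(torch.utils.data.Dataset):  # hypothetical name
    def __init__(self, path_to_video, device, face_aligner, length=1):
        self.path_to_video = path_to_video
        self.device = device
        self.fa = face_aligner  # a face_alignment.FaceAlignment instance
        self.length = length

    def __len__(self):
        return self.length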
def __getitem__(self, idx):
    vid_idx = idx
    if idx < 0:
        idx = self.__len__() + idx
    path = list(Path(self.path_to_mp4).glob('**/*.mp4'))[idx]
    frame_mark = select_frames(path, self.K)
    frame_mark = generate_landmarks(frame_mark, fa=self.fa)
    frame_mark = torch.from_numpy(np.array(frame_mark)).type(
        dtype=torch.float)  # K,2,224,224,3
    frame_mark = frame_mark.transpose(2, 4).to(self.device)  # K,2,3,224,224

    """
    Commented out because it can be generated in training.
    I will save frame_mark and load it directly in training.
    g_idx = torch.randint(low=0, high=self.K, size=(1, 1))
    x = frame_mark[g_idx, 0].squeeze()
    g_y = frame_mark[g_idx, 1].squeeze()
    return frame_mark, x, g_y, vid_idx
    """
    save_video(Path(self.new_path), frame_mark, vid_idx)
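# Sketch of how this preprocessing __getitem__ might be driven: it writes each
# processed clip to disk via save_video() and returns nothing, so a plain
# sequential index loop suffices; the dataset constructor is assumed.
def preprocess_all(dataset):
    for i in range(len(dataset)):
        dataset[i]  # extracts K frame/landmark pairs and saves clip i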
import torch
import numpy as np

from dataset.video_extraction_conversion import select_frames, select_images_frames, generate_cropped_landmarks
from network.blocks import *
from network.model import Embedder

"""Hyperparameters and config"""
device = torch.device("cuda:0")
cpu = torch.device("cpu")
path_to_e_hat_video = 'e_hat_video.tar'
path_to_e_hat_images = 'e_hat_images.tar'
path_to_chkpt = 'model_weights.tar'
path_to_video = 'examples/fine_tuning/test_video.mp4'
path_to_images = 'examples/fine_tuning/test_images'
T = 32

"""Loading Embedder input"""
frame_mark_video = select_frames(path_to_video, T)
frame_mark_video = generate_cropped_landmarks(frame_mark_video, pad=50)
frame_mark_video = torch.from_numpy(np.array(frame_mark_video)).type(
    dtype=torch.float)  # T,2,256,256,3
frame_mark_video = frame_mark_video.transpose(2, 4).to(device)  # T,2,3,256,256
f_lm_video = frame_mark_video.unsqueeze(0)  # 1,T,2,3,256,256

frame_mark_images = select_images_frames(path_to_images)
frame_mark_images = generate_cropped_landmarks(frame_mark_images, pad=50)
frame_mark_images = torch.from_numpy(np.array(frame_mark_images)).type(
    dtype=torch.float)  # T,2,256,256,3
frame_mark_images = frame_mark_images.transpose(2, 4).to(device)  # T,2,3,256,256
f_lm_images = frame_mark_images.unsqueeze(0)  # 1,T,2,3,256,256

E = Embedder(256).to(device)
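# This snippet ends before the checkpoint load; the rest of the original
# script mirrors main() in the first snippet above. A hedged sketch of the
# averaging step it leads into, using only names already defined here:
E.eval()
checkpoint = torch.load(path_to_chkpt, map_location=cpu)
E.load_state_dict(checkpoint['E_state_dict'])

with torch.no_grad():
    b, t = f_lm_video.shape[:2]
    flat = f_lm_video.view(b * t, 2, 3, 256, 256)             # BxT,2,3,256,256
    e_vectors = E(flat[:, 0], flat[:, 1]).view(b, t, 512, 1)  # B,T,512,1
    e_hat_video = e_vectors.mean(dim=1)                       # B,512,1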