Example #1
0
def main(path_to_e_hat_video='e_hat_video.tar',
         path_to_e_hat_images='e_hat_images.tar',
         path_to_chkpt='model_weights.tar',
         path_to_video='examples/fine_tuning/test_video.mp4',
         path_to_images='examples/fine_tuning/reeps'):
    """Compute and save averaged Embedder vectors for a video and an image set.

    ``T`` frames are selected from ``path_to_video`` and the frames returned by
    ``select_images_frames`` are read from ``path_to_images``. Each frame is
    paired with its cropped landmark image by ``generate_cropped_landmarks``,
    run through the ``Embedder`` restored from ``path_to_chkpt``, and the
    per-frame embeddings are averaged over the frame axis.

    Args:
        path_to_e_hat_video: Output file for the video embedding, written as
            ``{'e_hat': tensor}``.
        path_to_e_hat_images: Output file for the image-set embedding, written
            as ``{'e_hat': tensor}``.
        path_to_chkpt: Checkpoint file containing an ``'E_state_dict'`` entry.
        path_to_video: Video to select frames from.
        path_to_images: Location passed to ``select_images_frames``.
    """
    # Hyperparameters and config. Fall back to CPU so the script still runs
    # on machines without CUDA instead of failing on the first ``.to(device)``.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    cpu = torch.device("cpu")
    T = 32  # number of frames requested from the video

    def to_embedder_input(frames):
        """Build the (1, T, 2, 3, 256, 256) float Embedder input on ``device``."""
        frame_mark = generate_cropped_landmarks(frames, pad=50)
        frame_mark = torch.from_numpy(np.array(frame_mark)).type(
            dtype=torch.float)  # T,2,256,256,3
        # transpose(2, 4) swaps the channel axis with axis 2: T,2,H,W,3 becomes
        # T,2,3,W,H (H and W are both 256 here).
        frame_mark = frame_mark.transpose(2, 4).to(device)  # T,2,3,256,256
        return frame_mark.unsqueeze(0)  # 1,T,2,3,256,256

    def mean_embedding(embedder, f_lm):
        """Embed every (frame, landmark) pair of ``f_lm`` and average over frames."""
        f_lm_compact = f_lm.view(-1, *f_lm.shape[-4:])  # BxT,2,3,256,256
        e_vectors = embedder(f_lm_compact[:, 0],
                             f_lm_compact[:, 1])  # BxT,512,1
        e_vectors = e_vectors.view(-1, f_lm.shape[1], 512, 1)  # B,T,512,1
        return e_vectors.mean(dim=1)  # B,512,1

    # Loading Embedder input (video first, then images, as before).
    f_lm_video = to_embedder_input(select_frames(path_to_video, T))
    f_lm_images = to_embedder_input(select_images_frames(path_to_images))

    E = Embedder(256).to(device)
    E.eval()
    # Loading from past checkpoint. NOTE: torch.load unpickles the file, so
    # only load checkpoints from trusted sources.
    checkpoint = torch.load(path_to_chkpt, map_location=cpu)
    E.load_state_dict(checkpoint['E_state_dict'])

    # Inference
    with torch.no_grad():
        e_hat_video = mean_embedding(E, f_lm_video)
        e_hat_images = mean_embedding(E, f_lm_images)

    print('Saving e_hat...')
    torch.save({'e_hat': e_hat_video}, path_to_e_hat_video)
    torch.save({'e_hat': e_hat_images}, path_to_e_hat_images)
    print('...Done saving')
Example #2
0
def embed(path_to_video,
          lmarks,
          path_to_chkpt=None,
          ckpt=None,
          T=1,
          E=None,
          ref_ids=None):
    """Compute the frame-averaged Embedder vector for a reference image or video.

    Args:
        path_to_video: Path to a ``.png``/``.jpg``/``.jpeg`` image or an ``.mp4``
            video; the extension is matched case-insensitively.
        lmarks: Landmarks for the input. Only ``lmarks.copy()`` is used, so the
            caller's object is not mutated here. For an image, the copy is
            returned as the single reference landmark entry; for a video it is
            forwarded to ``select_frames``.
        path_to_chkpt: Checkpoint file with an ``'E_state_dict'`` entry. Used
            only when ``E`` is None; takes precedence over ``ckpt``.
        ckpt: Already-loaded checkpoint dict, used when ``E`` is None and
            ``path_to_chkpt`` is not given.
        T: Number of frames requested from ``select_frames`` for a video.
        E: Optional ready-made Embedder; it is switched to eval mode.
        ref_ids: Forwarded to ``select_frames`` as ``specific_frames``.

    Returns:
        ``(e_hat_video, inference_img, reference_lmarks)``: the embedding
        averaged over frames (shape ``B,512,1``), the frames before landmark
        cropping, and the matching landmark list.

    Raises:
        ValueError: unsupported file extension, or none of ``E``,
            ``path_to_chkpt`` and ``ckpt`` was provided.
        FileNotFoundError: OpenCV could not read the image.

    Relies on module-level ``device`` and ``cpu`` torch devices.
    """
    print("embedding...")
    # Loading Embedder input
    ext = str(path_to_video).rsplit('.', 1)[-1].lower()
    if ext in ('png', 'jpg', 'jpeg'):
        ref_img = cv2.imread(str(path_to_video))
        if ref_img is None:
            # cv2.imread reports failure by returning None rather than raising.
            raise FileNotFoundError(f"could not read image: {path_to_video}")
        frame_mark_video = [cv2.cvtColor(ref_img, cv2.COLOR_BGR2RGB)]
        lmarks_list = [lmarks.copy()]
    elif ext == 'mp4':
        frame_mark_video, lmarks_list = select_frames(path_to_video,
                                                      T,
                                                      lmarks=lmarks.copy(),
                                                      specific_frames=ref_ids)
    else:
        # Without this, an unknown extension used to surface later as an
        # UnboundLocalError on frame_mark_video.
        raise ValueError(f"unsupported input file type: {path_to_video}")
    inference_img = frame_mark_video
    reference_lmarks = lmarks_list

    frame_mark_video = generate_cropped_landmarks(frame_mark_video, pad=50)
    frame_mark_video = torch.from_numpy(np.array(frame_mark_video)).type(
        dtype=torch.float)  # T,2,256,256,3
    frame_mark_video = frame_mark_video.transpose(2,
                                                  4).to(device)  # T,2,3,256,256
    f_lm_video = frame_mark_video.unsqueeze(0)  # 1,T,2,3,256,256

    if E is None:
        E = Embedder(256).to(device)
        # Loading from past checkpoint
        if path_to_chkpt is not None:
            checkpoint = torch.load(path_to_chkpt, map_location=cpu)
        elif ckpt is not None:
            # load_state_dict copies the values into E's own parameters and
            # leaves the dict untouched, so no deepcopy of ckpt is needed.
            checkpoint = ckpt
        else:
            raise ValueError("embed() needs one of E, path_to_chkpt or ckpt")
        E.load_state_dict(checkpoint['E_state_dict'])
    E.eval()

    # Inference
    with torch.no_grad():
        # Calculate average encoding vector for video
        f_lm = f_lm_video
        f_lm_compact = f_lm.view(-1, f_lm.shape[-4], f_lm.shape[-3],
                                 f_lm.shape[-2],
                                 f_lm.shape[-1])  # BxT,2,3,256,256
        e_vectors = E(f_lm_compact[:, 0, :, :, :],
                      f_lm_compact[:, 1, :, :, :])  # BxT,512,1
        e_vectors = e_vectors.view(-1, f_lm.shape[1], 512, 1)  # B,T,512,1
        e_hat_video = e_vectors.mean(dim=1)  # B,512,1

    return e_hat_video, inference_img, reference_lmarks
def main():
    """Compute the frame-averaged Embedder vector of a video and save it.

    Inputs come from ``parse_args()``:
    - ``args.video`` is the input video.
    - ``args.model`` is a checkpoint holding ``'E_state_dict'``.
    - ``args.output`` is the destination file, written as ``{'e_hat': tensor}``.
    """
    args = parse_args()
    # Hyperparameters and config
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    cpu = torch.device("cpu")
    path_to_e_hat_video = args.output
    path_to_video = args.video
    T = 32  # number of frames requested from the video

    # face_alignment >= 1.4 renamed LandmarksType._2D to TWO_D; accept both so
    # the script works with old and new releases.
    landmarks_type = getattr(face_alignment.LandmarksType, 'TWO_D', None)
    if landmarks_type is None:
        landmarks_type = face_alignment.LandmarksType._2D
    face_aligner = face_alignment.FaceAlignment(
        landmarks_type,
        flip_input=False,
        device='cuda' if use_cuda else 'cpu')

    # Loading Embedder input
    frame_mark_video = select_frames(path_to_video, T)
    frame_mark_video = generate_cropped_landmarks(frame_mark_video,
                                                  pad=50,
                                                  fa=face_aligner)
    frame_mark_video = torch.from_numpy(np.array(frame_mark_video)).type(
        dtype=torch.float)  # T,2,256,256,3
    # Channel axis forward, then divide by 255 (presumably mapping 8-bit pixel
    # values to [0, 1] as the model expects — confirm against training code).
    frame_mark_video = frame_mark_video.transpose(
        2, 4).to(device) / 255  # T,2,3,256,256
    f_lm_video = frame_mark_video.unsqueeze(0)  # 1,T,2,3,256,256

    E = Embedder(256).to(device)
    E.eval()
    # Loading from past checkpoint. NOTE: torch.load unpickles the file, so
    # only load checkpoints from trusted sources.
    checkpoint = torch.load(args.model, map_location=cpu)
    E.load_state_dict(checkpoint['E_state_dict'])

    # Inference
    with torch.no_grad():
        # Calculate average encoding vector for video
        f_lm = f_lm_video
        f_lm_compact = f_lm.view(-1, f_lm.shape[-4], f_lm.shape[-3],
                                 f_lm.shape[-2],
                                 f_lm.shape[-1])  # BxT,2,3,256,256
        print('Run inference...')
        e_vectors = E(f_lm_compact[:, 0, :, :, :],
                      f_lm_compact[:, 1, :, :, :])  # BxT,512,1
        e_vectors = e_vectors.view(-1, f_lm.shape[1], 512, 1)  # B,T,512,1
        e_hat_video = e_vectors.mean(dim=1)  # B,512,1

    print('Saving e_hat...')
    torch.save({'e_hat': e_hat_video}, path_to_e_hat_video)
    print('...Done saving')
    def __getitem__(self, idx):
        """Return one (frame, landmark image) pair taken from ``self.path_to_video``.

        ``idx`` is not used. Each attempt asks ``select_frames`` for a single
        frame and crops its landmarks. The retry loop assumes successive calls
        can return different frames, and that ``generate_cropped_landmarks``
        raises when no face is found.

        Attempts are capped by the ``max_face_retries`` attribute (default 100)
        so an unreadable or face-free video no longer hangs forever.

        Returns:
            ``(x, g_y)``: the frame and its landmark image, each ``3,256,256``
            per the shape comments below, on ``self.device``.

        Raises:
            RuntimeError: every attempt failed; chained to the last error.
        """
        path = self.path_to_video
        max_attempts = getattr(self, 'max_face_retries', 100)
        last_err = None
        for _ in range(max_attempts):
            try:
                frame_mark = select_frames(path, 1)
                frame_mark = generate_cropped_landmarks(frame_mark, pad=50, fa=self.fa)
                break
            except Exception as err:  # not bare: let KeyboardInterrupt/SystemExit through
                last_err = err
                print('No face detected, retrying')
        else:
            raise RuntimeError(
                f'no face detected in {path} after {max_attempts} attempts'
            ) from last_err
        frame_mark = torch.from_numpy(np.array(frame_mark)).type(
            dtype=torch.float
        )  # 1,2,256,256,3
        frame_mark = frame_mark.transpose(2, 4).to(self.device)  # 1,2,3,256,256

        x = frame_mark[0, 0].squeeze()
        g_y = frame_mark[0, 1].squeeze()
        return x, g_y
    def __getitem__(self, idx):
        """Extract landmark frames for one video and save them under ``self.new_path``.

        Steps:
        - Take the ``idx``-th ``*.mp4`` found recursively under ``self.path_to_mp4``.
        - Select ``self.K`` frames from it with ``select_frames``.
        - Pair each frame with its landmark image via ``generate_landmarks``.
        - Hand the resulting tensor to ``save_video``.

        Nothing is returned.

        A negative ``idx`` is offset by ``self.__len__()`` for the file lookup,
        but ``save_video`` receives the index exactly as the caller passed it.
        NOTE(review): confirm that un-normalized index is intended.
        """
        requested_idx = idx
        position = idx if idx >= 0 else self.__len__() + idx
        videos = list(Path(self.path_to_mp4).glob('**/*.mp4'))
        frames = select_frames(videos[position], self.K)
        frames = generate_landmarks(frames, fa=self.fa)
        as_tensor = torch.from_numpy(np.array(frames)).type(dtype=torch.float)
        # K,2,224,224,3 -> K,2,3,224,224: swap the channel axis into position 2.
        frame_mark = as_tensor.transpose(2, 4).to(self.device)
        # Picking a random (x, g_y) training pair out of frame_mark used to live
        # here; it is now done at training time from the saved frame_mark.
        save_video(Path(self.new_path), frame_mark, requested_idx)
from dataset.video_extraction_conversion import select_frames, select_images_frames, generate_cropped_landmarks
from network.blocks import *
from network.model import Embedder

import numpy as np
"""Hyperparameters and config"""
# NOTE(review): torch is not imported by name in the import lines of this
# script; it presumably comes from ``from network.blocks import *`` — confirm.
# NOTE(review): "cuda:0" is hard-coded, so this script requires a CUDA GPU.
device = torch.device("cuda:0")
cpu = torch.device("cpu")  # separate CPU device handle
# Hard-coded input/output locations.
path_to_e_hat_video = 'e_hat_video.tar'
path_to_e_hat_images = 'e_hat_images.tar'
path_to_chkpt = 'model_weights.tar'
path_to_video = 'examples/fine_tuning/test_video.mp4'
path_to_images = 'examples/fine_tuning/test_images'
T = 32  # number of frames requested from the video by select_frames
"""Loading Embedder input"""
# Video input: select T frames, pair each with its cropped landmark image
# (pad=50), convert to float32 and move the channel axis forward.
frame_mark_video = select_frames(path_to_video, T)
frame_mark_video = generate_cropped_landmarks(frame_mark_video, pad=50)
frame_mark_video = torch.from_numpy(np.array(frame_mark_video)).type(
    dtype=torch.float)  #T,2,256,256,3
frame_mark_video = frame_mark_video.transpose(2, 4).to(device)  #T,2,3,256,256
f_lm_video = frame_mark_video.unsqueeze(0)  #1,T,2,3,256,256 (batch of one)

# Image input: same pipeline applied to the frames from select_images_frames.
frame_mark_images = select_images_frames(path_to_images)
frame_mark_images = generate_cropped_landmarks(frame_mark_images, pad=50)
frame_mark_images = torch.from_numpy(np.array(frame_mark_images)).type(
    dtype=torch.float)  #T,2,256,256,3
frame_mark_images = frame_mark_images.transpose(2,
                                                4).to(device)  #T,2,3,256,256
f_lm_images = frame_mark_images.unsqueeze(0)  #1,T,2,3,256,256 (batch of one)

# Embedder network placed on ``device``.
E = Embedder(256).to(device)