Example #1
# Imports assumed by this example: torch and numpy, plus the project's own
# frame-selection, cropping, and network helpers (the module paths below are
# assumptions inferred from the names used in the code).
import numpy as np
import torch

from dataset.video_extraction_conversion import (select_frames,
                                                 select_images_frames,
                                                 generate_cropped_landmarks)
from network.model import Embedder


def main(path_to_e_hat_video='e_hat_video.tar',
         path_to_e_hat_images='e_hat_images.tar',
         path_to_chkpt='model_weights.tar',
         path_to_video='examples/fine_tuning/test_video.mp4',
         path_to_images='examples/fine_tuning/reeps'):
    """Hyperparameters and config"""
    device = torch.device("cuda:0")
    cpu = torch.device("cpu")

    T = 32
    """Loading Embedder input"""
    frame_mark_video = select_frames(path_to_video, T)
    frame_mark_video = generate_cropped_landmarks(frame_mark_video, pad=50)
    frame_mark_video = torch.from_numpy(np.array(frame_mark_video)).type(
        dtype=torch.float)  #T,2,256,256,3
    frame_mark_video = frame_mark_video.transpose(2, 4).to(device)  #T,2,3,256,256
    f_lm_video = frame_mark_video.unsqueeze(0)  #1,T,2,3,256,256

    frame_mark_images = select_images_frames(path_to_images)
    frame_mark_images = generate_cropped_landmarks(frame_mark_images, pad=50)
    frame_mark_images = torch.from_numpy(np.array(frame_mark_images)).type(
        dtype=torch.float)  #T,2,256,256,3
    frame_mark_images = frame_mark_images.transpose(2, 4).to(device)  #T,2,3,256,256
    f_lm_images = frame_mark_images.unsqueeze(0)  #1,T,2,3,256,256

    E = Embedder(256).to(device)
    E.eval()
    """Loading from past checkpoint"""
    checkpoint = torch.load(path_to_chkpt, map_location=cpu)
    E.load_state_dict(checkpoint['E_state_dict'])
    """Inference"""
    with torch.no_grad():
        #forward
        # Calculate average encoding vector for video
        f_lm = f_lm_video
        f_lm_compact = f_lm.view(-1, f_lm.shape[-4], f_lm.shape[-3],
                                 f_lm.shape[-2], f_lm.shape[-1])  #BxT,2,3,256,256
        e_vectors = E(f_lm_compact[:, 0, :, :, :],
                      f_lm_compact[:, 1, :, :, :])  #BxT,512,1
        e_vectors = e_vectors.view(-1, f_lm.shape[1], 512, 1)  #B,T,512,1
        e_hat_video = e_vectors.mean(dim=1)

        f_lm = f_lm_images
        f_lm_compact = f_lm.view(-1, f_lm.shape[-4], f_lm.shape[-3],
                                 f_lm.shape[-2], f_lm.shape[-1])  #BxT,2,3,256,256
        e_vectors = E(f_lm_compact[:, 0, :, :, :],
                      f_lm_compact[:, 1, :, :, :])  #BxT,512,1
        e_vectors = e_vectors.view(-1, f_lm.shape[1], 512, 1)  #B,T,512,1
        e_hat_images = e_vectors.mean(dim=1)

    print('Saving e_hat...')
    torch.save({'e_hat': e_hat_video}, path_to_e_hat_video)
    torch.save({'e_hat': e_hat_images}, path_to_e_hat_images)
    print('...Done saving')
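
The tensors saved above are plain checkpoint dictionaries, so the averaged embeddings can be read back with torch.load. A minimal sketch of the consuming side, assuming the default file names from the signature above:

import torch

# Load the averaged identity embedding written by main().
e_hat = torch.load('e_hat_video.tar', map_location='cpu')['e_hat']  # B,512,1
print(e_hat.shape)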
Example #2

# Imports assumed by this example (same project layout as Example #1, plus
# the face_alignment package); parse_args is sketched after the example.
import face_alignment
import numpy as np
import torch

from dataset.video_extraction_conversion import select_frames, generate_cropped_landmarks
from network.model import Embedder


def main():
    args = parse_args()
    """Hyperparameters and config"""
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    cpu = torch.device("cpu")
    path_to_e_hat_video = args.output
    path_to_video = args.video
    T = 32
    face_aligner = face_alignment.FaceAlignment(
        face_alignment.LandmarksType._2D,
        flip_input=False,
        device='cuda' if use_cuda else 'cpu')
    """Loading Embedder input"""
    frame_mark_video = select_frames(path_to_video, T)
    frame_mark_video = generate_cropped_landmarks(frame_mark_video,
                                                  pad=50,
                                                  fa=face_aligner)
    frame_mark_video = torch.from_numpy(np.array(frame_mark_video)).type(
        dtype=torch.float)  #T,2,256,256,3
    frame_mark_video = frame_mark_video.transpose(2, 4).to(device) / 255  #T,2,3,256,256
    f_lm_video = frame_mark_video.unsqueeze(0)  #1,T,2,3,256,256

    E = Embedder(256).to(device)
    E.eval()
    """Loading from past checkpoint"""
    checkpoint = torch.load(args.model, map_location=cpu)
    E.load_state_dict(checkpoint['E_state_dict'])
    """Inference"""
    with torch.no_grad():
        #forward
        # Calculate average encoding vector for video
        f_lm = f_lm_video
        f_lm_compact = f_lm.view(-1, f_lm.shape[-4], f_lm.shape[-3],
                                 f_lm.shape[-2], f_lm.shape[-1])  #BxT,2,3,256,256
        print('Run inference...')
        e_vectors = E(f_lm_compact[:, 0, :, :, :],
                      f_lm_compact[:, 1, :, :, :])  #BxT,512,1
        e_vectors = e_vectors.view(-1, f_lm.shape[1], 512, 1)  #B,T,512,1
        e_hat_video = e_vectors.mean(dim=1)

    print('Saving e_hat...')
    torch.save({'e_hat': e_hat_video}, path_to_e_hat_video)
    print('...Done saving')
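
Example #2 calls parse_args() without defining it. A minimal sketch consistent with the attributes the example reads (args.video, args.model, args.output); the flag names, help strings, and default value are assumptions:

import argparse


def parse_args():
    parser = argparse.ArgumentParser(
        description='Compute an averaged identity embedding (e_hat) for a video.')
    parser.add_argument('--video', required=True,
                        help='Path to the input video.')
    parser.add_argument('--model', required=True,
                        help='Checkpoint containing an E_state_dict entry.')
    parser.add_argument('--output', default='e_hat_video.tar',
                        help='Where to save the e_hat tensor.')
    return parser.parse_args()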
Example #3

# Fragment: assumes the same imports and the device/T/path_to_* setup from
# Example #1.
frame_mark_video = select_frames(path_to_video, T)
frame_mark_video = generate_cropped_landmarks(frame_mark_video, pad=50)
frame_mark_video = torch.from_numpy(np.array(frame_mark_video)).type(
    dtype=torch.float)  #T,2,256,256,3
frame_mark_video = frame_mark_video.transpose(2, 4).to(device)  #T,2,3,256,256
f_lm_video = frame_mark_video.unsqueeze(0)  #1,T,2,3,256,256

frame_mark_images = select_images_frames(path_to_images)
frame_mark_images = generate_cropped_landmarks(frame_mark_images, pad=50)
frame_mark_images = torch.from_numpy(np.array(frame_mark_images)).type(
    dtype=torch.float)  #T,2,256,256,3
frame_mark_images = frame_mark_images.transpose(2, 4).to(device)  #T,2,3,256,256
f_lm_images = frame_mark_images.unsqueeze(0)  #1,T,2,3,256,256

E = Embedder(256).to(device)
E.eval()
"""Loading from past checkpoint"""
checkpoint = torch.load(path_to_chkpt, map_location=cpu)
E.load_state_dict(checkpoint['E_state_dict'])
"""Inference"""
with torch.no_grad():
    #forward
    # Calculate average encoding vector for video
    f_lm = f_lm_video
    f_lm_compact = f_lm.view(-1, f_lm.shape[-4], f_lm.shape[-3],
                             f_lm.shape[-2], f_lm.shape[-1])  #BxT,2,3,256,256
    e_vectors = E(f_lm_compact[:, 0, :, :, :],
                  f_lm_compact[:, 1, :, :, :])  #BxT,512,1
    e_vectors = e_vectors.view(-1, f_lm.shape[1], 512, 1)  #B,T,512,1
    e_hat_video = e_vectors.mean(dim=1)
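
Each example ends with the same view/mean pattern: the batch and frame axes are flattened so the embedder sees one frame-landmark pair per row, and the per-frame embeddings are then reshaped back and averaged over the frame axis. A standalone sketch of that shape flow with dummy tensors (the embedder is replaced by a stub returning one 512x1 vector per pair):

import torch

B, T = 1, 32
f_lm = torch.rand(B, T, 2, 3, 256, 256)         # B,T,2,3,256,256
f_lm_compact = f_lm.view(-1, *f_lm.shape[-4:])  # BxT,2,3,256,256

# Stub standing in for E(f_lm_compact[:, 0], f_lm_compact[:, 1]).
e_vectors = torch.rand(f_lm_compact.shape[0], 512, 1)  # BxT,512,1

e_vectors = e_vectors.view(B, T, 512, 1)  # B,T,512,1
e_hat = e_vectors.mean(dim=1)             # B,512,1: one averaged embedding per batch item
print(e_hat.shape)  # torch.Size([1, 512, 1])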
Example #4
# Fragment: assumes device, VGGFace_body_path, and VGGFace_weight_path are
# defined earlier (the statement opening this snippet was truncated).
"""Create dataset and net"""
dataset = MyNewDataset(path_to_images='clean_dataset', device=device)
dataLoader = DataLoader(dataset, batch_size=2, shuffle=True)

# Initialize SummaryWriter for tensorboard
RUN_NAME = datetime.now().strftime(format='%b%d_%H-%M-%S')
writer = SummaryWriter(log_dir=f'runs/{RUN_NAME}')

# Initialize PartialInceptionNetwork for calculating FID score
inception = PartialInceptionNetwork().eval()
# Initialize Cosine similarity for calculating CSIM score
cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)

G = Generator(224).to(device)
E = Embedder(224).to(device)
D = Discriminator(len(dataset)).to(device)

G.train()
E.train()
D.train()

optimizerG = optim.Adam(params=list(E.parameters()) + list(G.parameters()),
                        lr=5e-5)
optimizerD = optim.Adam(params=D.parameters(), lr=2e-4)
"""Criterion"""
criterionG = LossG(
    VGGFace_body_path=VGGFace_body_path,
    VGGFace_weight_path=VGGFace_weight_path,
    device=device,
)
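
The cos module above is meant for the CSIM metric: cosine similarity between identity embeddings of ground-truth and generated frames. A minimal sketch of such a score, assuming embeddings in the B,512,1 layout used throughout these examples (feeding E's outputs into it is an assumption about this snippet's later code):

import torch

cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)

# Hypothetical embeddings for real and generated frames, shape B,512,1.
e_real = torch.rand(2, 512, 1)
e_fake = torch.rand(2, 512, 1)

# Similarity along the 512-dim axis, averaged over the batch.
csim = cos(e_real.squeeze(-1), e_fake.squeeze(-1)).mean()
print(csim.item())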