Example #1
    def __init__(self, max_images_in_batch):
        texture_data = np.load(cnst.flame_texture_space_dat_file,
                               allow_pickle=True,
                               encoding='latin1').item()
        self.flm_tex_dec = FlameTextureSpace(texture_data,
                                             data_un_normalizer=None).cuda()
        self.flame_visualizer = OverLayViz()
        self.face_region_only_mask = np.array(
            Image.open(cnst.face_region_mask_file))[:, :, 0:1].transpose(
                (2, 0, 1))
        self.face_region_only_mask = (
            torch.from_numpy(self.face_region_only_mask.astype('float32')) /
            255.0)[None, ...]
        print('read face region mask')
        flength = 5000
        cam_t = np.array([0., 0., 0])
        self.camera_params = camera_ringnetpp((512, 512),
                                              trans=cam_t,
                                              focal=flength)
        self.max_num = max_images_in_batch - 1
        # self.max_num = 5
        self.pairs = []
        for i in range(self.max_num):
            for j in range(i + 1, self.max_num):
                self.pairs.append((i, j))

        self.pairs = np.array(self.pairs)
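
A minimal standalone sketch (not part of the original file) of the pair bookkeeping above: enumerate every unordered index pair once, then draw a random subset per call, as tex_sp_intrp_loss in Example #8 does.

import numpy as np

max_num = 7  # hypothetical value; stands in for max_images_in_batch - 1
# Enumerate all unordered pairs of indices 0..max_num-1
pairs = np.array([(i, j) for i in range(max_num)
                  for j in range(i + 1, max_num)])
# Sample max_num distinct pairs, mirroring the np.random.choice call in Example #8
random_pairs = pairs[np.random.choice(len(pairs), max_num, replace=False)]
print(random_pairs)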
Example #2
    def __init__(self, gen_i, gen_j, sampling_flame_labels, dataset, input_indices, overlay_mesh=False):
        self.gen_i = gen_i
        self.gen_j = gen_j
        self.sampling_flame_labels = sampling_flame_labels
        self.overlay_mesh = overlay_mesh
        self.overlay_visualizer = OverLayViz()
        self.dataset = dataset
        self.input_indices = input_indices
        self.cam_t = np.array([0., 0., 2.5])
Example #3
def get_gif_from_list_of_params(generator, flame_params, step, alpha, noise, overlay_landmarks, flame_std, flame_mean,
                                overlay_visualizer, rendered_flame_as_condition, use_posed_constant_input,
                                normal_maps_as_cond, camera_params):
    # cam_t = np.array([0., 0., 2.5])
    # camera_params = camera_dynamic((224, 224), cam_t)
    if overlay_visualizer is None:
        overlay_visualizer = OverLayViz()

    fixed_embeddings = torch.ones(flame_params.shape[0], dtype=torch.long, device='cuda')*13
    # print(generator.module.get_embddings()[fixed_embeddings])
    flame_params_unnorm = flame_params * flame_std + flame_mean

    flame_params_unnorm = torch.from_numpy(flame_params_unnorm).cuda()
    normal_map_img, _, _, _, rend_imgs = \
        overlay_visualizer.get_rendered_mesh(flame_params=(flame_params_unnorm[:, SHAPE_IDS[0]:SHAPE_IDS[1]],
                                                           flame_params_unnorm[:, EXP_IDS[0]:EXP_IDS[1]],
                                                           flame_params_unnorm[:, POSE_IDS[0]:POSE_IDS[1]],
                                                           flame_params_unnorm[:, TRANS_IDS[0]:TRANS_IDS[1]]),
                                             camera_params=camera_params)
    rend_imgs = (rend_imgs/127.0 - 1)

    if use_posed_constant_input:
        pose = flame_params[:, constants.get_idx_list('GLOBAL_ROT')]
    else:
        pose = None

    if rendered_flame_as_condition:
        gen_in = rend_imgs
    else:
        gen_in = flame_params

    if normal_maps_as_cond:
        gen_in = torch.cat((rend_imgs, normal_map_img), dim=1)

    fake_images = generate_from_flame_sequence(generator, gen_in, pose, step, alpha, noise,
                                               input_indices=fixed_embeddings)[-1]

    fake_images = overlay_visualizer.range_normalize_images(fast_image_reshape(fake_images,
                                                                               height_out=rend_imgs.shape[2],
                                                                               width_out=rend_imgs.shape[3]))
    if rendered_flame_as_condition:
        fake_images = torch.cat([fake_images.cpu(), (rend_imgs.cpu() + 1)/2], dim=-1)

    if normal_maps_as_cond:
        fake_images = torch.cat([fake_images.cpu(), (normal_map_img.cpu() + 1) / 2], dim=-1)

    return fake_images
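
The slicing above depends on a packed FLAME label layout; a hedged sketch of the de-normalize-then-slice step, where the index bounds are illustrative assumptions standing in for the repo's SHAPE_IDS, EXP_IDS, etc.

import numpy as np
import torch

SHAPE_IDS = (0, 100)  # assumed bounds, not the repo's actual constants
flame_params = np.random.randn(4, 159).astype('float32')  # dummy normalized labels
flame_std = np.ones(159, dtype='float32')
flame_mean = np.zeros(159, dtype='float32')
# Undo label normalization, then slice out one parameter group
flame_params_unnorm = torch.from_numpy(flame_params * flame_std + flame_mean)
shape = flame_params_unnorm[:, SHAPE_IDS[0]:SHAPE_IDS[1]]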
Example #4
class VisualizationSaver():
    def __init__(self, gen_i, gen_j, sampling_flame_labels, dataset, input_indices, overlay_mesh=False):
        self.gen_i = gen_i
        self.gen_j = gen_j
        self.sampling_flame_labels = sampling_flame_labels
        self.overlay_mesh = overlay_mesh
        self.overlay_visualizer = OverLayViz()
        self.dataset = dataset
        self.input_indices = input_indices
        self.cam_t = np.array([0., 0., 2.5])

    def set_flame_params(self, pose, sampling_flame_labels, input_indices):
        self.pose = pose
        self.sampling_flame_labels = sampling_flame_labels
        self.input_indices = input_indices

    def save_samples(self, i, model, step, alpha, resolution, fid, run_id):
        images = []
        # camera_params = camera_dynamic((resolution, resolution), self.cam_t)
        flength = 5000
        cam_t = np.array([0., 0., 0])
        camera_params = camera_ringnetpp((512, 512), trans=cam_t, focal=flength)

        with torch.no_grad():
            for img_idx in range(self.gen_i):
                flame_param_this_batch = self.sampling_flame_labels[img_idx * self.gen_j:(img_idx + 1) * self.gen_j]
                if self.pose is not None:
                    pose_this_batch = self.pose[img_idx * self.gen_j:(img_idx + 1) * self.gen_j]
                else:
                    pose_this_batch = None
                idx_this_batch = self.input_indices[img_idx * self.gen_j:(img_idx + 1) * self.gen_j]
                img_tensor = model(flame_param_this_batch.clone(), pose_this_batch, step=step, alpha=alpha,
                                   input_indices=idx_this_batch)[-1]

                img_tensor = self.overlay_visualizer.range_normalize_images(
                    dataset_loaders.fast_image_reshape(img_tensor, height_out=256, width_out=256,
                                                       non_diff_allowed=True))

                images.append(img_tensor.data.cpu())

        torchvision.utils.save_image(
            torch.cat(images, 0),
            f'{cnst.output_root}sample/{str(run_id)}/{str(i + 1).zfill(6)}_res{resolution}x{resolution}_fid_{fid:.2f}.png',
            nrow=self.gen_i,
            normalize=True,
            range=(0, 1))
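
A self-contained sketch of the final grid save on dummy tensors; note that recent torchvision releases renamed the range keyword of save_image to value_range.

import torch
import torchvision

images = [torch.rand(2, 3, 64, 64) for _ in range(3)]  # stand-in image batches
torchvision.utils.save_image(torch.cat(images, 0), 'sample_grid.png',
                             nrow=3, normalize=True, value_range=(0, 1))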
Example #5
# run_ids_1 = [7, 24, 8, 3]
# run_ids_1 = [7, 8, 3]

settings_for_runs = \
    {24: {'name': 'vector_cond', 'model_idx': '216000_1', 'normal_maps_as_cond': False,
          'rendered_flame_as_condition': False, 'apply_sqrt2_fac_in_eq_lin': False},
     29: {'name': 'full_model', 'model_idx': '026000_1', 'normal_maps_as_cond': True,
          'rendered_flame_as_condition': True, 'apply_sqrt2_fac_in_eq_lin': True},
     7: {'name': 'flm_rndr_tex_interp', 'model_idx': '488000_1', 'normal_maps_as_cond': False,
         'rendered_flame_as_condition': True, 'apply_sqrt2_fac_in_eq_lin': False},
     3: {'name': 'norm_mp_tex_interp', 'model_idx': '040000_1', 'normal_maps_as_cond': True,
         'rendered_flame_as_condition': False, 'apply_sqrt2_fac_in_eq_lin': False},
     8: {'name': 'norm_map_rend_flm_no_tex_interp', 'model_idx': '460000_1', 'normal_maps_as_cond': True,
         'rendered_flame_as_condition': True, 'apply_sqrt2_fac_in_eq_lin': False},}

overlay_visualizer = OverLayViz()
# overlay_visualizer.setup_renderer(mesh_file=None)

flm_params = np.zeros((num_smpl_to_eval_on, code_size)).astype('float32')
fl_param_dict = np.load(cnst.all_flame_params_file, allow_pickle=True).item()
for i, key in enumerate(fl_param_dict):
    flame_param = fl_param_dict[key]
    flame_param = np.hstack(
        (flame_param['shape'], flame_param['exp'], flame_param['pose'],
         flame_param['cam'], flame_param['tex'], flame_param['lit'].flatten()))
    # tz = camera_params['f'][0] / (camera_params['c'][0] * flame_param[:, 156:157])
    # flame_param[:, 156:159] = np.concatenate((flame_param[:, 157:], tz), axis=1)

    # import ipdb; ipdb.set_trace()
    flm_params[i, :] = flame_param.astype('float32')
    if i == num_smpl_to_eval_on - 1:
        break  # stop once num_smpl_to_eval_on samples are collected
Example #6
# run_ids_1 = [7, 8, 3]
# run_ids_1 = [7]

settings_for_runs = \
    {24: {'name': 'vector_cond', 'model_idx': '216000_1', 'normal_maps_as_cond': False,
          'rendered_flame_as_condition': False, 'apply_sqrt2_fac_in_eq_lin': False},
     29: {'name': 'full_model', 'model_idx': '026000_1', 'normal_maps_as_cond': True,
          'rendered_flame_as_condition': True, 'apply_sqrt2_fac_in_eq_lin': True},
     7: {'name': 'flm_rndr_tex_interp', 'model_idx': '488000_1', 'normal_maps_as_cond': False,
         'rendered_flame_as_condition': True, 'apply_sqrt2_fac_in_eq_lin': False},
     3: {'name': 'norm_mp_tex_interp', 'model_idx': '040000_1', 'normal_maps_as_cond': True,
         'rendered_flame_as_condition': False, 'apply_sqrt2_fac_in_eq_lin': False},
     8: {'name': 'norm_map_rend_flm_no_tex_interp', 'model_idx': '460000_1', 'normal_maps_as_cond': True,
         'rendered_flame_as_condition': True, 'apply_sqrt2_fac_in_eq_lin': False},}

overlay_visualizer = OverLayViz()
# overlay_visualizer.setup_renderer(mesh_file=None)

flm_params = np.zeros((num_smpl_to_eval_on, code_size)).astype('float32')
fl_param_dict = np.load(cnst.all_flame_params_file, allow_pickle=True).item()
np.random.seed(2)
for i, key in enumerate(fl_param_dict):
    flame_param = fl_param_dict[key]
    shape_params = np.concatenate((np.random.normal(0, 1, [
        3,
    ]), np.zeros(97))).astype('float32')
    exp_params = np.concatenate((np.random.normal(0, 1, [
        3,
    ]), np.zeros(47))).astype('float32')
    # +- pi/4 for bad samples +- pi/8 for good samples
    # pose = np.array([0, np.random.uniform(-np.pi/4, np.pi/4, 1), 0,
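
For reference, the vector packed by np.hstack above is 236-dimensional in the DECA case; a hedged sketch with zeros in place of real DECA estimates, whose component sizes are consistent with the 236-dim check and the (9, 3) lighting reshape in Example #8.

import numpy as np

flame_param = np.hstack((
    np.zeros(100, dtype='float32'),               # shape
    np.zeros(50, dtype='float32'),                # exp
    np.zeros(6, dtype='float32'),                 # pose
    np.zeros(3, dtype='float32'),                 # cam
    np.zeros(50, dtype='float32'),                # tex
    np.zeros((9, 3), dtype='float32').flatten(),  # lit
))
assert flame_param.shape[0] == 236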
Example #7
import sys
sys.path.append('../')
import constants as cnst
import tqdm
import numpy as np
import torch
import lmdb
from io import BytesIO
from PIL import Image
import time
from my_utils.visualize_flame_overlay import OverLayViz
# import matplotlib.pyplot as plt


overlay_viz = OverLayViz()
num_validation_images = -1
num_files = 70_000
batch_size = 32
resolution = cnst.flame_config['image_size']
flame_param_dict = np.load(cnst.all_flame_params_file, allow_pickle=True).item()
param_files = flame_param_dict.keys()

with lmdb.open(cnst.rendered_flame_root, map_size=1024 ** 4, readahead=False) as env:
    with env.begin(write=True) as transaction:
        total = 0
        for batch_id in tqdm.tqdm(range(num_files//batch_size + 1)):
            shapecode = []
            expcode = []
            posecode = []
            cam = []
            texcode = []
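
A minimal sketch of the LMDB write pattern this snippet sets up (the key format below is an assumption, not necessarily the repo's): encode an image to PNG in memory and put it inside the write transaction.

import lmdb
import numpy as np
from io import BytesIO
from PIL import Image

with lmdb.open('/tmp/example_lmdb', map_size=1024 ** 3, readahead=False) as env:
    with env.begin(write=True) as transaction:
        buffer = BytesIO()
        # Encode a dummy 8x8 black image as PNG bytes
        Image.fromarray(np.zeros((8, 8, 3), dtype=np.uint8)).save(buffer, format='png')
        transaction.put(b'000000', buffer.getvalue())  # hypothetical key scheme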
Example #8
class InterpolatedTextureLoss:
    def __init__(self, max_images_in_batch):
        texture_data = np.load(cnst.flame_texture_space_dat_file,
                               allow_pickle=True,
                               encoding='latin1').item()
        self.flm_tex_dec = FlameTextureSpace(texture_data,
                                             data_un_normalizer=None).cuda()
        self.flame_visualizer = OverLayViz()
        self.face_region_only_mask = np.array(
            Image.open(cnst.face_region_mask_file))[:, :, 0:1].transpose(
                (2, 0, 1))
        self.face_region_only_mask = (
            torch.from_numpy(self.face_region_only_mask.astype('float32')) /
            255.0)[None, ...]
        print('read face region mask')
        flength = 5000
        cam_t = np.array([0., 0., 0])
        self.camera_params = camera_ringnetpp((512, 512),
                                              trans=cam_t,
                                              focal=flength)
        self.max_num = max_images_in_batch - 1
        # self.max_num = 5
        self.pairs = []
        for i in range(self.max_num):
            for j in range(i + 1, self.max_num):
                self.pairs.append((i, j))

        self.pairs = np.array(self.pairs)

    def pairwise_texture_loss(self, tx1, tx2):
        if self.face_region_only_mask.device != tx1.device:
            self.face_region_only_mask = self.face_region_only_mask.to(
                tx1.device)

        if self.face_region_only_mask.shape[-1] != tx1.shape[-1]:
            face_region_only_mask = fast_image_reshape(
                self.face_region_only_mask, tx1.shape[1], tx1.shape[2])
        else:
            face_region_only_mask = self.face_region_only_mask

        # import ipdb; ipdb.set_trace()
        return torch.mean(
            torch.sigmoid(torch.pow(tx1 - tx2, 2)) * face_region_only_mask[0])
        # return torch.mean(torch.pow(tx1 - tx2, 2) * face_region_only_mask[0])

    def tex_sp_intrp_loss(self, flame_batch, generator, step, alpha, max_ids,
                          normal_maps_as_cond, use_posed_constant_input,
                          rendered_flame_as_condition):

        textures, tx_masks, _ = self.get_image_and_textures(
            alpha, flame_batch, generator, max_ids, normal_maps_as_cond,
            rendered_flame_as_condition, step, use_posed_constant_input)

        # import ipdb; ipdb.set_trace()

        random_pairs_idxs = np.random.choice(len(self.pairs),
                                             self.max_num,
                                             replace=False)
        random_pairs = self.pairs[random_pairs_idxs]
        loss = 0
        for cur_pair in random_pairs:
            tx_mask_common = tx_masks[cur_pair[1]] * tx_masks[cur_pair[0]]
            loss += self.pairwise_texture_loss(
                tx1=textures[cur_pair[0]] * tx_mask_common,
                tx2=textures[cur_pair[1]] * tx_mask_common)

        return 16 * loss / len(random_pairs)

    def get_image_and_textures(self, alpha, flame_batch, generator, max_ids,
                               normal_maps_as_cond,
                               rendered_flame_as_condition, step,
                               use_posed_constant_input):
        batch_size = flame_batch.shape[0]
        flame_batch = flame_batch[:self.max_num, :]  # Just to limit run time
        # import ipdb; ipdb.set_trace()
        if rendered_flame_as_condition or normal_maps_as_cond:
            shape = flame_batch[:, constants.INDICES['SHAPE'][0]:constants.
                                INDICES['SHAPE'][1]]
            exp = flame_batch[:, constants.INDICES['EXP'][0]:constants.
                              INDICES['EXP'][1]]
            pose = flame_batch[:, constants.INDICES['POSE'][0]:constants.
                               INDICES['POSE'][1]]
            if flame_batch.shape[-1] == 159:  # non DECA params
                flame_params = (
                    shape, exp, pose,
                    flame_batch[:, constants.INDICES['TRANS'][0]:constants.
                                INDICES['TRANS'][1]])  # translation
                norma_map_img, _, _, _, rend_flm = self.flame_visualizer.get_rendered_mesh(
                    flame_params=flame_params,
                    camera_params=self.camera_params)
                rend_flm = rend_flm / 127 - 1.0
                norma_map_img = norma_map_img * 2 - 1
            elif flame_batch.shape[-1] == 236:  # DECA
                cam = flame_batch[:, constants.DECA_IDX['cam'][0]:constants.
                                  DECA_IDX['cam'][1]:]

                # Same lightcode for the whole batch
                light_code = flame_batch[
                    0:1,
                    constants.DECA_IDX['lit'][0]:constants.DECA_IDX['lit'][1]:]
                light_code = light_code.repeat(batch_size, 1)
                light_code = light_code.view((batch_size, 9, 3))

                # Same texture code for the whole batch
                texture_code = flame_batch[
                    0:1,
                    constants.DECA_IDX['tex'][0]:constants.DECA_IDX['tex'][1]:]
                texture_code = texture_code.repeat(batch_size, 1)
                # import ipdb; ipdb.set_trace()
                norma_map_img, _, _, _, rend_flm = \
                    self.flame_visualizer.get_rendered_mesh(flame_params=(shape, exp, pose, light_code, texture_code),
                                                            camera_params=cam)
                # rend_flm = rend_flm * 2 - 1
                rend_flm = torch.clamp(rend_flm, 0, 1) * 2 - 1
                norma_map_img = torch.clamp(norma_map_img, 0, 1) * 2 - 1
                rend_flm = fast_image_reshape(rend_flm,
                                              height_out=256,
                                              width_out=256,
                                              mode='bilinear')
                norma_map_img = fast_image_reshape(norma_map_img,
                                                   height_out=256,
                                                   width_out=256,
                                                   mode='bilinear')
            else:
                raise ValueError('FLAME parameter format not understood')
        # import ipdb; ipdb.set_trace()
        if use_posed_constant_input:
            pose = flame_batch[:, constants.get_idx_list('GLOBAL_ROT')]
        else:
            pose = None
        fixed_identities = torch.ones(
            flame_batch.shape[0], dtype=torch.long,
            device='cuda') * np.random.randint(0, max_ids)
        if rendered_flame_as_condition and normal_maps_as_cond:
            gen_in = torch.cat((rend_flm, norma_map_img), dim=1)
        elif rendered_flame_as_condition:
            gen_in = rend_flm
        elif normal_maps_as_cond:
            gen_in = norma_map_img
        else:
            gen_in = flame_batch
        generated_image = generator(gen_in,
                                    pose=pose,
                                    step=step,
                                    alpha=alpha,
                                    input_indices=fixed_identities)[-1]
        textures, tx_masks = self.flm_tex_dec(generated_image, flame_batch)
        return textures, tx_masks, generated_image
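
The core of pairwise_texture_loss is a sigmoid-squashed squared difference restricted to the face region; the same computation as a standalone sketch on random CPU tensors.

import torch

tx1, tx2 = torch.rand(3, 256, 256), torch.rand(3, 256, 256)  # dummy textures
face_mask = (torch.rand(1, 1, 256, 256) > 0.5).float()  # stand-in for the file-based mask
loss = torch.mean(torch.sigmoid(torch.pow(tx1 - tx2, 2)) * face_mask[0])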
Example #9
normalization_file_path = None
flame_param_est = None

# import ipdb; ipdb.set_trace()
list_bad_images = np.load('/is/cluster/work/pghosh/gif1.0/DECA_inferred/b_box_stats.npz')['bad_images']
dataset = FFHQ(real_img_root=data_root, rendered_flame_root=rendered_flame_root, params_dir=params_dir,
               generic_transform=generic_transform, pose_cam_from_yao=False,
               rendered_flame_as_condition=True, resolution=256,
               normalization_file_path=normalization_file_path, debug=True, random_crop=False, get_normal_images=True,
               flame_version='DECA', list_bad_images=list_bad_images)

show_img_at_res = 256


overlay_visualizer = OverLayViz(full_neck=False, add_random_noise_to_background=False, inside_mouth_faces=True,
                                background_img=None, texture_pattern_name='MEAN_TEXTURE_WITH_CHKR_BOARD',
                                flame_version='DECA', image_size=256)
deca_flame = overlay_visualizer.deca.flame

example_out_dir = '/is/cluster/work/pghosh/gif1.0/eye_cntr_images'

normalized_eye_cntr_L = []
normalized_eye_cntr_R = []
for indx in range(50):
    # i = int(random.randint(0, 60_000))
    i = indx

    fig, ax1 = plt.subplots(1, 1)
    img, flm_rndr, flm_lbl, index = dataset.__getitem__(i, bypass_valid_indexing=False)
    img = img[None, ...]
Example #10
dataset = FFHQ(real_img_root=data_root,
               rendered_flame_root=rendered_flame_root,
               params_dir=params_dir,
               generic_transform=generic_transform,
               pose_cam_from_yao=False,
               rendered_flame_as_condition=True,
               resolution=256,
               normalization_file_path=normalization_file_path,
               debug=True,
               random_crop=False,
               get_normal_images=True,
               flame_version='DECA',
               list_bad_images=list_bad_images)

show_img_at_res = 256

overlay_visualizer = OverLayViz(
    full_neck=False,
    add_random_noise_to_background=False,
    inside_mouth_faces=True,
    background_img=None,
    texture_pattern_name='MEAN_TEXTURE_WITH_CHKR_BOARD',
    flame_version='DECA',
    image_size=256)

example_out_dir = '/is/cluster/work/pghosh/gif1.0/DECA_inferred/saved_rerendered_side_by_side'

for indx in range(5):
    # i = int(random.randint(0, 60_000))
    i = indx

    fig, (ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8) = plt.subplots(1, 8)
    img, flm_rndr, flm_lbl, index = dataset.__getitem__(
        i, bypass_valid_indexing=False)
    img = img[None, ...]
    flm_rndr = flm_rndr[0][None, ...]
Example #11
generator = torch.nn.DataParallel(generator)
generator.load_state_dict(ckpt['generator_running'])
# generator.load_state_dict(ckpt['generator'])
# generator.eval()

log_dir = os.path.join(cnst.output_root, 'gif_smpls/FFHQ')

if random_background:
    torch.manual_seed(2)
    back_ground_noise = (torch.randn(
        (3, 224, 224), dtype=torch.float32) * 255).clamp(min=0,
                                                         max=255).cuda()
else:
    back_ground_noise = None
# Don't add random noise to the background here; otherwise every frame would get different noise, which looks bad.
overlay_visualizer = OverLayViz()
# overlay_visualizer.setup_renderer(mesh_file=None)

if flame_version == 'FLAME_2020_revisited':
    flength = 5000
    cam_t = np.array([0., 0., 0])
    camera_params = camera_ringnetpp((512, 512), trans=cam_t, focal=flength)
elif flame_version == 'DECA':
    pass
else:
    cam_t = np.array([0., 0., 2.5])
    camera_params = camera_dynamic((resolution, resolution), cam_t)

fixed_identity_embeddings = torch.ones(1, dtype=torch.long, device='cuda')

expt_name = 'teaser_figure'
Example #12
import numpy as np  # needed for cam_t and the image-to-tensor conversion below
import tqdm  # needed for the progress loop below
from my_utils.visualize_flame_overlay import OverLayViz
from my_utils.ringnet_overlay.util import tensor_vis_landmarks
from PIL import Image
import torch
import matplotlib.pyplot as plt

resolution = 512
flength = 5000
cam_t = np.array([0., 0., 0])
camera_params = camera_ringnetpp((resolution, resolution),
                                 trans=cam_t,
                                 focal=flength)

overlay_visualizer = OverLayViz(full_neck=False,
                                add_random_noise_to_background=False,
                                background_img=None,
                                texture_pattern_name='CHKR_BRD_FLT_TEETH',
                                flame_version='FLAME_2020_revisited',
                                image_size=resolution)

dest_dir = '/is/cluster/work/pghosh/gif1.0/fitting_viz/flame_2020/'
# dest_dir = '/is/cluster/work/pghosh/gif1.0/fitting_viz/flame_old/'

for i in tqdm.tqdm(range(0, 300)):
    img_file = '/is/cluster/scratch/partha/face_gan_data/FFHQ/images1024x1024/' + str(
        i).zfill(5) + '.png'
    img_original = Image.open(img_file).resize((resolution, resolution))
    images = torch.from_numpy(
        np.array(img_original).transpose(
            (2, 0, 1)).astype('float32'))[None, ...].cuda()

    flame_param_file = f'/is/cluster/pghosh/face_gan_data/FFHQ/flmae_photometric_opt/' \
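
The PIL-to-tensor conversion above (HWC uint8 image to a batched CHW float tensor) in isolation, using a synthetic image so it runs without the FFHQ files or a GPU.

import numpy as np
import torch
from PIL import Image

img_original = Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))
images = torch.from_numpy(
    np.array(img_original).transpose((2, 0, 1)).astype('float32'))[None, ...]
assert images.shape == (1, 3, 512, 512)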
Example #13
# embeddings = generator.get_embddings()
generator = torch.nn.DataParallel(generator)
generator.load_state_dict(ckpt['generator_running'])
# generator.load_state_dict(ckpt['generator'])
# generator.eval()

log_dir = os.path.join(cnst.output_root, 'gif_smpls/FFHQ')

if random_background:
    torch.manual_seed(2)
    back_ground_noise = (torch.randn((3, 224, 224), dtype=torch.float32)*255).clamp(min=0, max=255).cuda()
else:
    back_ground_noise = None

# Don't add random noise to the background here; otherwise every frame would get different noise, which looks bad.
overlay_visualizer = OverLayViz()
# overlay_visualizer.setup_renderer(mesh_file=None)

if flame_version == 'FLAME_2020_revisited':
    flength = 5000
    cam_t = np.array([0., 0., 0])
    camera_params = camera_ringnetpp((512, 512), trans=cam_t, focal=flength)
elif flame_version == 'DECA':
    pass
else:
    cam_t = np.array([0., 0., 2.5])
    camera_params = camera_dynamic((resolution, resolution), cam_t)

fixed_identity_embeddings = torch.ones(1, dtype=torch.long, device='cuda')

expt_name = 'teaser_figure'
Example #14
if ignore_global_rotation:
    pose = np.hstack((seqs['frame_pose_params'][:, 0:3] * 0,
                      seqs['frame_pose_params'][:, 6:9]))
else:
    pose = np.hstack(
        (seqs['frame_pose_params'][:, 0:3], seqs['frame_pose_params'][:, 6:9]))

num_frames = seqs['frame_exp_params'].shape[0]
translation = np.zeros((num_frames, 3))
flame_shape = np.repeat(
    seqs['seq_shape_params'][np.newaxis, :].astype('float32'), (num_frames, ),
    axis=0)
flm_batch = np.hstack((flame_shape, seqs['frame_exp_params'], pose,
                       translation)).astype('float32')[::8]
flm_batch = torch.from_numpy(flm_batch).cuda()

overlay_visualizer = OverLayViz()

config_obj = util.dict2obj(cnst.flame_config)
flame_decoder = FLAME.FLAME(config_obj).cuda().eval()
flm_batch = position_to_given_location(flame_decoder, flm_batch)

# Render FLAME
batch_size_true = flm_batch.shape[0]
cam = flm_batch[:, constants.DECA_IDX['cam'][0]:constants.DECA_IDX['cam'][1]:]
shape = flm_batch[:,
                  constants.INDICES['SHAPE'][0]:constants.INDICES['SHAPE'][1]]
exp = flm_batch[:, constants.INDICES['EXP'][0]:constants.INDICES['EXP'][1]]
pose = flm_batch[:, constants.INDICES['POSE'][0]:constants.INDICES['POSE'][1]]
# import ipdb; ipdb.set_trace()

Example #15
fl_param_dict = np.load(cnst.all_flame_params_file, allow_pickle=True).item()
seqs = np.load(cnst.voca_flame_seq_file)

if ignore_global_rotation:
    pose = np.hstack((seqs['frame_pose_params'][:, 0:3]*0, seqs['frame_pose_params'][:, 6:9]))
else:
    pose = np.hstack((seqs['frame_pose_params'][:, 0:3], seqs['frame_pose_params'][:, 6:9]))

num_frames = seqs['frame_exp_params'].shape[0]
translation = np.zeros((num_frames, 3))
shape_seq = seqs['seq_shape_params']
shape_seq[3:] *= 0
flame_shape = np.repeat(shape_seq[np.newaxis, :].astype('float32'), (num_frames,), axis=0)
flm_batch = np.hstack((flame_shape, seqs['frame_exp_params'], pose, translation)).astype('float32')
flm_batch = torch.from_numpy(flm_batch).cuda()

overlay_visualizer = OverLayViz()

config_obj = util.dict2obj(cnst.flame_config)
flame_decoder = FLAME.FLAME(config_obj).cuda().eval()
flm_batch = position_to_given_location(flame_decoder, flm_batch)


# Render FLAME
seq_len = flm_batch.shape[0]
cam = flm_batch[:, constants.DECA_IDX['cam'][0]:constants.DECA_IDX['cam'][1]:]
shape = flm_batch[:, constants.INDICES['SHAPE'][0]:constants.INDICES['SHAPE'][1]]
exp = flm_batch[:, constants.INDICES['EXP'][0]:constants.INDICES['EXP'][1]]
pose = flm_batch[:, constants.INDICES['POSE'][0]:constants.INDICES['POSE'][1]]
# import ipdb; ipdb.set_trace()

light_texture_id_code_source = np.load('../teaser/params.npy', allow_pickle=True).item()
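
A hedged sketch of the sequence batching both snippets perform: repeat a single shape code across all frames, zero the translation, and hstack everything into flat 159-dim labels (100 shape + 50 exp + 6 pose + 3 translation, matching the non-DECA check in Example #8).

import numpy as np
import torch

num_frames = 16  # stand-in for seqs['frame_exp_params'].shape[0]
flame_shape = np.repeat(np.zeros((1, 100), dtype='float32'), num_frames, axis=0)
exp = np.zeros((num_frames, 50), dtype='float32')
pose = np.zeros((num_frames, 6), dtype='float32')
translation = np.zeros((num_frames, 3), dtype='float32')
flm_batch = torch.from_numpy(np.hstack((flame_shape, exp, pose, translation)))
assert flm_batch.shape == (num_frames, 159)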