Exemplo n.º 1
0
def get_gif_from_list_of_params(generator, flame_params, step, alpha, noise, overlay_landmarks, flame_std, flame_mean,
                                overlay_visualizer, rendered_flame_as_condition, use_posed_constant_input,
                                normal_maps_as_cond, camera_params, identity_index=13):
    """Generate one image per row of ``flame_params`` with a fixed identity.

    The parameters are un-normalised (``flame_params * flame_std + flame_mean``),
    rendered with ``overlay_visualizer``, and the renderings and/or raw
    parameters are fed to ``generate_from_flame_sequence``.

    Args:
        generator: generator network passed to ``generate_from_flame_sequence``.
        flame_params: normalised FLAME parameters, one row per frame. The
            un-normalised product goes through ``torch.from_numpy``, so this is
            presumably a numpy array.
        step, alpha, noise: forwarded unchanged to ``generate_from_flame_sequence``.
        overlay_landmarks: unused; kept for interface compatibility.
        flame_std, flame_mean: statistics used to un-normalise ``flame_params``.
        overlay_visualizer: renderer; a fresh ``OverLayViz()`` is created when None.
        rendered_flame_as_condition: condition the generator on the rendered mesh
            instead of the raw parameters.
        use_posed_constant_input: also pass the global-rotation columns as ``pose``.
        normal_maps_as_cond: condition on rendered mesh + normal map (channel concat).
        camera_params: camera passed to ``get_rendered_mesh``.
        identity_index: identity embedding index used for every frame
            (previously hard-coded to 13; the default keeps that behaviour).

    Returns:
        The last output of ``generate_from_flame_sequence``, range-normalised and
        resized to the rendering size. When a rendering / normal map is used as
        condition it is mapped from [-1, 1] to [0, 1] and concatenated along the
        last (width) axis for side-by-side viewing.
    """
    if overlay_visualizer is None:
        overlay_visualizer = OverLayViz()

    # Same identity embedding for every frame of the sequence.
    fixed_embeddings = torch.full((flame_params.shape[0],), identity_index, dtype=torch.long, device='cuda')

    flame_params_unnorm = torch.from_numpy(flame_params * flame_std + flame_mean).cuda()
    normal_map_img, _, _, _, rend_imgs = overlay_visualizer.get_rendered_mesh(
        flame_params=(flame_params_unnorm[:, SHAPE_IDS[0]:SHAPE_IDS[1]],
                      flame_params_unnorm[:, EXP_IDS[0]:EXP_IDS[1]],
                      flame_params_unnorm[:, POSE_IDS[0]:POSE_IDS[1]],
                      flame_params_unnorm[:, TRANS_IDS[0]:TRANS_IDS[1]]),
        camera_params=camera_params)
    # Rescale the rendering to roughly [-1, 1]; presumably it is produced in
    # [0, 255] (the same /127 - 1 mapping is used elsewhere in this code).
    rend_imgs = rend_imgs / 127.0 - 1

    # Note: the pose is taken from the *normalised* parameters.
    pose = flame_params[:, constants.get_idx_list('GLOBAL_ROT')] if use_posed_constant_input else None

    if normal_maps_as_cond:
        # NOTE(review): this always stacks the rendering in front of the normal
        # map, even when rendered_flame_as_condition is False (unlike
        # InterpolatedTextureLoss, which uses the normal map alone). Kept as-is
        # because the generator's input channel count depends on it.
        gen_in = torch.cat((rend_imgs, normal_map_img), dim=1)
    elif rendered_flame_as_condition:
        gen_in = rend_imgs
    else:
        gen_in = flame_params

    fake_images = generate_from_flame_sequence(generator, gen_in, pose, step, alpha, noise,
                                               input_indices=fixed_embeddings)[-1]

    fake_images = overlay_visualizer.range_normalize_images(fast_image_reshape(fake_images,
                                                                               height_out=rend_imgs.shape[2],
                                                                               width_out=rend_imgs.shape[3]))
    if rendered_flame_as_condition:
        fake_images = torch.cat([fake_images.cpu(), (rend_imgs.cpu() + 1) / 2], dim=-1)

    if normal_maps_as_cond:
        fake_images = torch.cat([fake_images.cpu(), (normal_map_img.cpu() + 1) / 2], dim=-1)

    return fake_images
Exemplo n.º 2
0
                settings_for_runs[run_idx]['rendered_flame_as_condition']:
            cam = flm_batch[:, constants.DECA_IDX['cam'][0]:constants.
                            DECA_IDX['cam'][1]:]
            shape = flm_batch[:, constants.INDICES['SHAPE'][0]:constants.
                              INDICES['SHAPE'][1]]
            exp = flm_batch[:, constants.INDICES['EXP'][0]:constants.
                            INDICES['EXP'][1]]
            pose = flm_batch[:, constants.INDICES['POSE'][0]:constants.
                             INDICES['POSE'][1]]
            # import ipdb; ipdb.set_trace()
            light_code = \
                flm_batch[:, constants.DECA_IDX['lit'][0]:constants.DECA_IDX['lit'][1]:].view((batch_size_true, 9, 3))
            texture_code = flm_batch[:, constants.DECA_IDX['tex'][0]:constants.
                                     DECA_IDX['tex'][1]:]
            norma_map_img, _, _, _, rend_flm = \
                overlay_visualizer.get_rendered_mesh(flame_params=(shape, exp, pose, light_code, texture_code),
                                                     camera_params=cam)
            rend_flm = torch.clamp(rend_flm, 0, 1) * 2 - 1
            norma_map_img = torch.clamp(norma_map_img, 0, 1) * 2 - 1
            rend_flm = fast_image_reshape(rend_flm,
                                          height_out=256,
                                          width_out=256,
                                          mode='bilinear')
            norma_map_img = fast_image_reshape(norma_map_img,
                                               height_out=256,
                                               width_out=256,
                                               mode='bilinear')

        else:
            rend_flm = None
            norma_map_img = None
Exemplo n.º 3
0
                texture_code = flm_batch[:,
                                         constants.DECA_IDX['tex'][0]:constants
                                         .DECA_IDX['tex'][1]:]

                params_to_save['cam'].append(cam.cpu().detach().numpy())
                params_to_save['shape'].append(shape.cpu().detach().numpy())
                params_to_save['shape'].append(shape.cpu().detach().numpy())
                params_to_save['exp'].append(exp.cpu().detach().numpy())
                params_to_save['pose'].append(pose.cpu().detach().numpy())
                params_to_save['light_code'].append(
                    light_code.cpu().detach().numpy())
                params_to_save['texture_code'].append(
                    texture_code.cpu().detach().numpy())

                norma_map_img, _, _, _, rend_flm = \
                    overlay_visualizer.get_rendered_mesh(flame_params=(shape, exp, pose, light_code, texture_code),
                                                         camera_params=cam)
                # import ipdb; ipdb.set_trace()

                rend_flm = torch.clamp(rend_flm, 0, 1) * 2 - 1
                norma_map_img = torch.clamp(norma_map_img, 0, 1) * 2 - 1
                rend_flm = fast_image_reshape(rend_flm,
                                              height_out=256,
                                              width_out=256,
                                              mode='bilinear')
                norma_map_img = fast_image_reshape(norma_map_img,
                                                   height_out=256,
                                                   width_out=256,
                                                   mode='bilinear')

                # Render the 2nd time to get backface culling and white texture
                # norma_map_img_to_save, _, _, _, rend_flm_to_save = \
Exemplo n.º 4
0
                expcode.append(torch.tensor(params['exp'])[None, ...].cuda())
                posecode.append(torch.tensor(params['pose'])[None, ...].cuda())
                cam.append(torch.tensor(params['cam'])[None, ...].cuda())
                texcode.append(torch.tensor(params['tex'])[None, ...].cuda())
                lightcode.append(torch.tensor(params['lit'])[None, ...].cuda())

            shapecode = torch.cat(shapecode, dim=0)
            expcode = torch.cat(expcode, dim=0)
            posecode = torch.cat(posecode, dim=0)
            cam = torch.cat(cam, dim=0)
            texcode = torch.cat(texcode, dim=0)
            lightcode = torch.cat(lightcode, dim=0)

            # render
            normal_images, _, _, _, textured_images = \
                overlay_viz.get_rendered_mesh((shapecode, expcode, posecode, lightcode, texcode), cam)

            count = 0
            for item_id in range(batch_size):
                i = batch_id * batch_size + item_id
                if not str(i).zfill(5) + '.pkl' in param_files:
                    continue
                textured_image = (textured_images[count].detach().cpu().numpy()*255).astype('uint8').transpose((1, 2, 0))
                textured_image = Image.fromarray(textured_image)

                normal_image = (normal_images[count].detach().cpu().numpy()*255).astype('uint8').transpose((1, 2, 0))
                normal_image = Image.fromarray(normal_image)

                # # Just for inspection
                # fig, (ax1, ax2,) = plt.subplots(1, 2)
                # ax1.imshow(textured_image)
Exemplo n.º 5
0
class InterpolatedTextureLoss:
    """Texture-consistency loss across images generated for one identity.

    ``get_image_and_textures`` feeds a (truncated) batch of FLAME/DECA
    parameters to a generator with a single randomly chosen identity index,
    then maps the generated images to textures and texture masks with
    ``FlameTextureSpace``. ``tex_sp_intrp_loss`` compares randomly sampled
    pairs of those textures inside the face-region mask, restricted to the
    region where both texture masks are non-zero.
    """

    def __init__(self, max_images_in_batch):
        """Load the texture space, face-region mask and fixed camera.

        Args:
            max_images_in_batch: nominal batch size; only the first
                ``max_images_in_batch - 1`` rows of a batch are used.
        """
        # NOTE(review): allow_pickle=True unpickles this file, so it must come
        # from a trusted source (here it is a configured path, not user input).
        texture_data = np.load(cnst.flame_texture_space_dat_file,
                               allow_pickle=True,
                               encoding='latin1').item()
        self.flm_tex_dec = FlameTextureSpace(texture_data,
                                             data_un_normalizer=None).cuda()
        self.flame_visualizer = OverLayViz()
        # Keep only the first image channel: (H, W, 1) -> (1, H, W), then scale
        # by 1/255 and add a leading batch axis -> (1, 1, H, W).
        self.face_region_only_mask = np.array(
            Image.open(cnst.face_region_mask_file))[:, :, 0:1].transpose(
                (2, 0, 1))
        self.face_region_only_mask = (
            torch.from_numpy(self.face_region_only_mask.astype('float32')) /
            255.0)[None, ...]
        print('read face region mask')
        flength = 5000
        cam_t = np.array([0., 0., 0])
        self.camera_params = camera_ringnetpp((512, 512),
                                              trans=cam_t,
                                              focal=flength)
        self.max_num = max_images_in_batch - 1
        # Every unordered index pair (i, j) with i < j over the used rows.
        self.pairs = self._all_index_pairs(self.max_num)

    @staticmethod
    def _all_index_pairs(num):
        """Return an int array of shape (k, 2) with every (i, j), 0 <= i < j < num.

        The result is always 2-D (shape (0, 2) when there are no pairs), so
        column indexing works even for tiny batches.
        """
        pairs = [(i, j) for i in range(num) for j in range(i + 1, num)]
        return np.array(pairs, dtype=int).reshape(-1, 2)

    def pairwise_texture_loss(self, tx1, tx2):
        """Mean of ``sigmoid((tx1 - tx2) ** 2)`` weighted by the face-region mask.

        The mask is moved to ``tx1``'s device (and cached there) and resized
        when its width differs from ``tx1``'s. Since sigmoid(0) == 0.5,
        identical textures yield ``0.5 * mean(mask)``, not 0.
        """
        if self.face_region_only_mask.device != tx1.device:
            self.face_region_only_mask = self.face_region_only_mask.to(
                tx1.device)

        if self.face_region_only_mask.shape[-1] != tx1.shape[-1]:
            face_region_only_mask = fast_image_reshape(
                self.face_region_only_mask, tx1.shape[1], tx1.shape[2])
        else:
            face_region_only_mask = self.face_region_only_mask

        return torch.mean(
            torch.sigmoid(torch.pow(tx1 - tx2, 2)) * face_region_only_mask[0])
        # Alternative (plain squared error):
        # return torch.mean(torch.pow(tx1 - tx2, 2) * face_region_only_mask[0])

    def tex_sp_intrp_loss(self, flame_batch, generator, step, alpha, max_ids,
                          normal_maps_as_cond, use_posed_constant_input,
                          rendered_flame_as_condition):
        """Average pairwise texture loss over randomly sampled image pairs, times 16.

        Up to ``self.max_num`` distinct pairs are drawn without replacement
        from the pairs whose indices exist in the generated batch.

        Raises:
            ValueError: if fewer than two images are available to compare.
        """
        textures, tx_masks, _ = self.get_image_and_textures(
            alpha, flame_batch, generator, max_ids, normal_maps_as_cond,
            rendered_flame_as_condition, step, use_posed_constant_input)

        # The generated batch may hold fewer than max_num rows; keep only pairs
        # that index into it (pairs satisfy i < j, so bounding j suffices).
        valid_pairs = self.pairs[self.pairs[:, 1] < len(textures)]
        # With few images there can be fewer pairs than max_num; sampling more
        # than the population without replacement would raise in numpy.
        num_samples = min(self.max_num, len(valid_pairs))
        if num_samples == 0:
            raise ValueError('Texture interpolation loss needs at least two '
                             'generated images to compare')

        random_pairs_idxs = np.random.choice(len(valid_pairs),
                                             num_samples,
                                             replace=False)
        random_pairs = valid_pairs[random_pairs_idxs]
        loss = 0
        for first, second in random_pairs:
            # Compare only where both texture masks are non-zero.
            tx_mask_common = tx_masks[second] * tx_masks[first]
            loss += self.pairwise_texture_loss(
                tx1=textures[first] * tx_mask_common,
                tx2=textures[second] * tx_mask_common)

        return 16 * loss / len(random_pairs)

    def get_image_and_textures(self, alpha, flame_batch, generator, max_ids,
                               normal_maps_as_cond,
                               rendered_flame_as_condition, step,
                               use_posed_constant_input):
        """Generate images for ``flame_batch`` and map them to textures.

        Only the first ``self.max_num`` rows are used. When a rendered
        condition is requested, the parameters are rendered first: 159
        columns are treated as non-DECA (shape/exp/pose/translation, fixed
        camera), 236 columns as DECA (per-row camera; light and texture codes
        taken from the first row and shared by the whole batch).

        Returns:
            ``(textures, tx_masks, generated_image)``.

        Raises:
            ValueError: if a rendered condition is requested and the parameter
                width is neither 159 nor 236.
        """
        flame_batch = flame_batch[:self.max_num, :]  # Just to limit run time
        # Must be read *after* truncation: the shared light/texture codes below
        # are repeated to this size and have to match shape/exp/pose.
        batch_size = flame_batch.shape[0]
        if rendered_flame_as_condition or normal_maps_as_cond:
            shape = flame_batch[:, constants.INDICES['SHAPE'][0]:constants.
                                INDICES['SHAPE'][1]]
            exp = flame_batch[:, constants.INDICES['EXP'][0]:constants.
                              INDICES['EXP'][1]]
            pose = flame_batch[:, constants.INDICES['POSE'][0]:constants.
                               INDICES['POSE'][1]]
            if flame_batch.shape[-1] == 159:  # non DECA params
                flame_params = (
                    shape, exp, pose,
                    flame_batch[:, constants.INDICES['TRANS'][0]:constants.
                                INDICES['TRANS'][1]])  # translation
                norma_map_img, _, _, _, rend_flm = self.flame_visualizer.get_rendered_mesh(
                    flame_params=flame_params,
                    camera_params=self.camera_params)
                rend_flm = rend_flm / 127 - 1.0
                norma_map_img = norma_map_img * 2 - 1
            elif flame_batch.shape[-1] == 236:  # DECA
                cam = flame_batch[:, constants.DECA_IDX['cam'][0]:constants.
                                  DECA_IDX['cam'][1]:]

                # Same lightcode for the whole batch
                light_code = flame_batch[
                    0:1,
                    constants.DECA_IDX['lit'][0]:constants.DECA_IDX['lit'][1]:]
                light_code = light_code.repeat(batch_size, 1)
                light_code = light_code.view((batch_size, 9, 3))

                # Same texture code for the whole batch
                texture_code = flame_batch[
                    0:1,
                    constants.DECA_IDX['tex'][0]:constants.DECA_IDX['tex'][1]:]
                texture_code = texture_code.repeat(batch_size, 1)
                norma_map_img, _, _, _, rend_flm = \
                    self.flame_visualizer.get_rendered_mesh(flame_params=(shape, exp, pose, light_code, texture_code),
                                                            camera_params=cam)
                # Clamp to [0, 1], then map to [-1, 1] for the generator.
                rend_flm = torch.clamp(rend_flm, 0, 1) * 2 - 1
                norma_map_img = torch.clamp(norma_map_img, 0, 1) * 2 - 1
                rend_flm = fast_image_reshape(rend_flm,
                                              height_out=256,
                                              width_out=256,
                                              mode='bilinear')
                norma_map_img = fast_image_reshape(norma_map_img,
                                                   height_out=256,
                                                   width_out=256,
                                                   mode='bilinear')
            else:
                raise ValueError('Flame prameter format not understood')
        if use_posed_constant_input:
            pose = flame_batch[:, constants.get_idx_list('GLOBAL_ROT')]
        else:
            pose = None
        # One random identity shared by every image in the batch.
        fixed_identities = torch.ones(
            flame_batch.shape[0], dtype=torch.long,
            device='cuda') * np.random.randint(0, max_ids)
        if rendered_flame_as_condition and normal_maps_as_cond:
            gen_in = torch.cat((rend_flm, norma_map_img), dim=1)
        elif rendered_flame_as_condition:
            gen_in = rend_flm
        elif normal_maps_as_cond:
            gen_in = norma_map_img
        else:
            gen_in = flame_batch
        generated_image = generator(gen_in,
                                    pose=pose,
                                    step=step,
                                    alpha=alpha,
                                    input_indices=fixed_identities)[-1]
        textures, tx_masks = self.flm_tex_dec(generated_image, flame_batch)
        return textures, tx_masks, generated_image
Exemplo n.º 6
0
    shape = flame_interp_batch[:, constants.INDICES['SHAPE'][0]:constants.
                               INDICES['SHAPE'][1]]
    exp = flame_interp_batch[:, constants.INDICES['EXP'][0]:constants.
                             INDICES['EXP'][1]]
    pose = flame_interp_batch[:, constants.INDICES['POSE'][0]:constants.
                              INDICES['POSE'][1]]
    texture_code_batch = linear_interpolate(texture_code[interp_type_idx],
                                            texture_code[interp_type_idx + 1],
                                            num_frames)
    light_code_batch = linear_interpolate(light_code[interp_type_idx],
                                          light_code[interp_type_idx + 1],
                                          num_frames)
    light_code_batch = light_code_batch.view(num_frames, 9, 3)

    norma_map_img, _, _, _, rend_flm = \
        overlay_visualizer.get_rendered_mesh(flame_params=(shape, exp, pose, light_code_batch, texture_code_batch),
                                             camera_params=cam)
    rend_flm = torch.clamp(rend_flm, 0, 1) * 2 - 1
    norma_map_img = torch.clamp(norma_map_img, 0, 1) * 2 - 1
    rend_flm = fast_image_reshape(rend_flm,
                                  height_out=256,
                                  width_out=256,
                                  mode='bilinear')
    norma_map_img = fast_image_reshape(norma_map_img,
                                       height_out=256,
                                       width_out=256,
                                       mode='bilinear')

    if normal_maps_as_cond and rendered_flame_as_condition:
        # norma_map_img = norma_map_img * 2 - 1
        gen_in = torch.cat((rend_flm, norma_map_img), dim=1)
    elif normal_maps_as_cond:
Exemplo n.º 7
0
    try:
        flame_param = np.load(flame_param_file, allow_pickle=True)
    except FileNotFoundError as e:
        continue

    flame_param = np.hstack((flame_param['shape'], flame_param['exp'],
                             flame_param['pose'], flame_param['cam']))

    tz = flength / (0.5 * resolution * flame_param[:, 156:157])
    flame_param[:, 156:159] = np.concatenate((flame_param[:, 157:], tz),
                                             axis=1)

    flame_param = torch.from_numpy(flame_param).cuda()

    norma_map_img, pos_mask, alpha_images, key_points2d, rend_imgs = \
        overlay_visualizer.get_rendered_mesh(
            flame_params=(flame_param[:, 0:100],  # shape
                          flame_param[:, 100:150],  # exp
                          flame_param[:, 150:156],  # Pose
                          flame_param[:, 156:159]),  # translation
            camera_params=camera_params)

    overlayed_imgs_original = tensor_vis_landmarks(images / 255,
                                                   key_points2d[:, 17:])
    plt.imshow(overlayed_imgs_original[0].cpu().detach().numpy().transpose(
        (1, 2, 0)))
    print('Shown')
    plt.show()
    # img_to_save = (255*overlayed_imgs_original[0]).cpu().detach().numpy().transpose((1, 2, 0)).astype('uint8')
    # img_to_save = Image.fromarray(img_to_save)
    # img_to_save.save(dest_dir + str(i) + '.png')
Exemplo n.º 8
0
    rend_flm = []

    for batch_idx in range(0, seq_len, 32):
        shape_batch = shape[batch_idx:batch_idx+32]
        exp_batch = exp[batch_idx:batch_idx+32]
        pose_batch = pose[batch_idx:batch_idx+32]
        cam_batch = cam[batch_idx:batch_idx+32]
        true_batch_size = cam_batch.shape[0]
        light_code = light_texture_id_code_source['light_code'][id].astype('float32')[None, ...].\
            repeat(true_batch_size, axis=0)
        texture_code = light_texture_id_code_source['texture_code'][id].astype('float32')[None, ...].\
            repeat(true_batch_size, axis=0)
        # import ipdb; ipdb.set_trace()
        norma_map_img_batch, _, _, _, rend_flm_batch = \
            overlay_visualizer.get_rendered_mesh(flame_params=(shape_batch, exp_batch, pose_batch,
                                                               torch.from_numpy(light_code).cuda(),
                                                               torch.from_numpy(texture_code).cuda()),
                                                 camera_params=cam_batch)
        rend_flm_batch = torch.clamp(rend_flm_batch, 0, 1) * 2 - 1
        norma_map_img_batch = torch.clamp(norma_map_img_batch, 0, 1) * 2 - 1
        rend_flm_batch = fast_image_reshape(rend_flm_batch, height_out=256, width_out=256, mode='bilinear')
        norma_map_img_batch = fast_image_reshape(norma_map_img_batch, height_out=256, width_out=256, mode='bilinear')

        # with back face culling and white texture
        norma_map_img_batch_to_save, _, _, _, rend_flm_batch_to_save = \
            overlay_visualizer.get_rendered_mesh(flame_params=(shape_batch, exp_batch, pose_batch,
                                                               torch.from_numpy(light_code).cuda(),
                                                               torch.from_numpy(texture_code).cuda()),
                                                 camera_params=cam_batch, cull_backfaces=True,
                                                 constant_albedo=0.6)
        rend_flm_batch_to_save = torch.clamp(rend_flm_batch_to_save, 0, 1) * 2 - 1
        norma_map_img_batch_to_save = torch.clamp(norma_map_img_batch_to_save, 0, 1) * 2 - 1