Example #1
import numpy as np
from scipy.spatial.transform import Rotation as R
from util.icp import icp


def norm_output_fls_rot(fl_data_i, anchor_t_shape=None):

    # fl_data_i = savgol_filter(fl_data_i, 21, 3, axis=0)

    # rigid-registration landmarks: nose bridge (27-30), nose base center (33),
    # and the four eye corners (36, 39, 42, 45) of the 68-point layout
    t_shape_idx = (27, 28, 29, 30, 33, 36, 39, 42, 45)
    if (anchor_t_shape is None):
        anchor_t_shape = np.loadtxt(
            r'src/dataset/utils/ANCHOR_T_SHAPE_{}.txt'.format(
                len(t_shape_idx)))
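        # normalize the anchor: unit x-distance between the outer eye corners
        # (rows 5 and 8 = landmarks 36 and 45), centered on the mean of
        # landmarks 33, 36 and 45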
        s = np.abs(anchor_t_shape[5, 0] - anchor_t_shape[8, 0])
        anchor_t_shape = anchor_t_shape / s * 1.0
        c2 = np.mean(anchor_t_shape[[4, 5, 8], :], axis=0)
        anchor_t_shape -= c2

    else:
        anchor_t_shape = anchor_t_shape.reshape((68, 3))
        anchor_t_shape = anchor_t_shape[t_shape_idx, :]

    fl_data_i = fl_data_i.reshape((-1, 68, 3)).copy()

    # get rot_mat
    rot_quats = []
    rot_trans = []
    for i in range(fl_data_i.shape[0]):
        line = fl_data_i[i]
        frame_t_shape = line[t_shape_idx, :]
        T, distance, itr = icp(frame_t_shape, anchor_t_shape)
        rot_mat = T[:3, :3]
        trans_mat = T[:3, 3:4]

        # norm to anchor
        fl_data_i[i] = np.dot(rot_mat, line.T).T + trans_mat.T

        # inverse (anchor -> real_t)
        # tmp = np.dot(rot_mat.T, (anchor_t_shape - trans_mat.T).T).T

        r = R.from_matrix(rot_mat)
        rot_quats.append(r.as_quat())
        # rot_eulers.append(r.as_euler('xyz'))
        rot_trans.append(T[:3, :])

    rot_quats = np.array(rot_quats)
    rot_trans = np.array(rot_trans)

    return rot_trans, rot_quats, fl_data_i
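
A minimal usage sketch (hypothetical data; assumes numpy and icp are importable
as above): norm_output_fls_rot takes per-frame landmarks that reshape to
(-1, 68, 3) and returns per-frame 3x4 pose matrices, rotation quaternions, and
pose-normalized landmarks. The random sequence and anchor below are
illustrative stand-ins only.

fl_seq = np.random.randn(100, 68 * 3)    # 100 frames of 68 3D landmarks (illustrative)
anchor = np.random.randn(68 * 3)         # stand-in for a real anchor face
rot_trans, rot_quats, fl_norm = norm_output_fls_rot(fl_seq, anchor_t_shape=anchor)
print(rot_trans.shape, rot_quats.shape, fl_norm.shape)   # (100, 3, 4) (100, 4) (100, 68, 3)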
Example #2
    def __train_pass__(self, epoch, log_loss, is_training=True):
        st_epoch = time.time()

        # Step 1: init setup
        if(is_training):
            self.C.train()
            data = self.train_data
            dataloader = self.train_dataloader
            status = 'TRAIN'
        else:
            self.C.eval()
            data = self.eval_data
            dataloader = self.eval_dataloader
            status = 'EVAL'

        random_clip_index = np.random.permutation(len(dataloader))[0:self.opt_parser.random_clip_num]
        print('random visualize clip index', random_clip_index)

        # Step 2: train for each batch
        for i, batch in enumerate(dataloader):

            global_id, video_name = data[i][0][1][0], data[i][0][1][1][:-4]
            inputs_fl, inputs_au = batch
            inputs_fl_ori, inputs_au_ori = inputs_fl.to(device), inputs_au.to(device)

            std_fls_list, fls_pred_face_id_list, fls_pred_pos_list = [], [], []
            seg_bs = 512

            ''' pick the frame with the most closed lips from the entire clip '''
            close_fl_list = inputs_fl_ori[::10, 0, :]
            idx = self.__close_face_lip__(close_fl_list.detach().cpu().numpy())
            input_face_id = close_fl_list[idx:idx + 1, :]

            ''' register face '''
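            # rigidly align the chosen identity frame to the canonical anchor face:
            # ICP on the stable t-shape points, then apply the 3x4 transform in
            # homogeneous coordinates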
            if (self.opt_parser.use_reg_as_std):
                landmarks = input_face_id.detach().cpu().numpy().reshape(68, 3)
                frame_t_shape = landmarks[self.t_shape_idx, :]
                T, distance, itr = icp(frame_t_shape, self.anchor_t_shape)
                landmarks = np.hstack((landmarks, np.ones((68, 1))))
                registered_landmarks = np.dot(T, landmarks.T).T
                input_face_id = torch.tensor(registered_landmarks[:, 0:3].reshape(1, 204), requires_grad=False,
                                             dtype=torch.float).to(device)

            for in_batch in range(self.opt_parser.in_batch_nepoch):

                std_fls_list, fls_pred_face_id_list, fls_pred_pos_list = [], [], []

                if (is_training):
                    rand_start = np.random.randint(0, inputs_fl_ori.shape[0] // 5, 1).reshape(-1)
                    inputs_fl = inputs_fl_ori[rand_start[0]:]
                    inputs_au = inputs_au_ori[rand_start[0]:]
                else:
                    inputs_fl = inputs_fl_ori
                    inputs_au = inputs_au_ori

                for j in range(0, inputs_fl.shape[0], seg_bs):

                    # Step 3.1: load segments
                    inputs_fl_segments = inputs_fl[j: j + seg_bs]
                    inputs_au_segments = inputs_au[j: j + seg_bs]
                    fl_std = inputs_fl_segments[:, 0, :].data.cpu().numpy()

                    if(inputs_fl_segments.shape[0] < 10):
                        continue

                    fl_dis_pred_pos, input_face_id, loss = \
                        self.__train_content__(inputs_fl_segments, inputs_au_segments, input_face_id, is_training)

                    fl_dis_pred_pos = (fl_dis_pred_pos + input_face_id).data.cpu().numpy()
                    ''' solve inverse lip '''
                    fl_dis_pred_pos = self.__solve_inverse_lip2__(fl_dis_pred_pos)

                    fls_pred_pos_list += [fl_dis_pred_pos.reshape((-1, 204))]
                    std_fls_list += [fl_std.reshape((-1, 204))]

                    # log any local variable whose name matches a tracked loss key
                    for key in log_loss.keys():
                        if (key not in locals().keys()):
                            continue
                        if (type(locals()[key]) == float):
                            log_loss[key].add(locals()[key])
                        else:
                            log_loss[key].add(locals()[key].data.cpu().numpy())


                if (epoch % self.opt_parser.jpg_freq == 0 and (i in random_clip_index or in_batch % self.opt_parser.jpg_freq == 1)):
                    def save_fls_av(fake_fls_list, postfix='', ifsmooth=True):
                        fake_fls_np = np.concatenate(fake_fls_list)
                        filename = 'fake_fls_{}_{}_{}.txt'.format(epoch, video_name, postfix)
                        np.savetxt(
                            os.path.join(self.opt_parser.dump_dir, '../nn_result', self.opt_parser.name, filename),
                            fake_fls_np, fmt='%.6f')
                        audio_filename = '{:05d}_{}_audio.wav'.format(global_id, video_name)
                        from util.vis import Vis_old
                        Vis_old(run_name=self.opt_parser.name, pred_fl_filename=filename, audio_filename=audio_filename,
                                fps=62.5, av_name='e{:04d}_{}_{}'.format(epoch, in_batch, postfix),
                                postfix=postfix, root_dir=self.opt_parser.root_dir, ifsmooth=ifsmooth)

                    if (self.opt_parser.show_animation and not is_training):
                        print('show animation ....')
                        save_fls_av(fls_pred_pos_list, 'pred_{}'.format(i), ifsmooth=True)
                        save_fls_av(std_fls_list, 'std_{}'.format(i), ifsmooth=False)
                        from util.vis import Vis_comp
                        Vis_comp(run_name=self.opt_parser.name,
                                 pred1='fake_fls_{}_{}_{}.txt'.format(epoch, video_name, 'pred_{}'.format(i)),
                                 pred2='fake_fls_{}_{}_{}.txt'.format(epoch, video_name, 'std_{}'.format(i)),
                                 audio_filename='{:05d}_{}_audio.wav'.format(global_id, video_name),
                                fps=62.5, av_name='e{:04d}_{}_{}'.format(epoch, in_batch, 'comp_{}'.format(i)),
                                postfix='comp_{}'.format(i), root_dir=self.opt_parser.root_dir, ifsmooth=False)

                    self.__save_model__(save_type='last_inbatch', epoch=epoch)

                if (self.opt_parser.verbose <= 1):
                    print('{} Epoch: #{} batch #{}/{} inbatch #{}/{}'.format(
                        status, epoch, i, len(dataloader),
                    in_batch, self.opt_parser.in_batch_nepoch), end=': ')
                    for key in log_loss.keys():
                        print(key, '{:.5f}'.format(log_loss[key].per('batch')), end=', ')
                    print('')

        if (self.opt_parser.verbose <= 2):
            print('==========================================================')
            print('{} Epoch: #{}'.format(status, epoch), end=':')
            for key in log_loss.keys():
                print(key, '{:.4f}'.format(log_loss[key].per('epoch')), end=', ')
            print(
                'Epoch time usage: {:.2f} sec\n==========================================================\n'.format(
                    time.time() - st_epoch))
        self.__save_model__(save_type='last_epoch', epoch=epoch)
        if (epoch % self.opt_parser.ckpt_epoch_freq == 0):
            self.__save_model__(save_type='e_{}'.format(epoch), epoch=epoch)
Example #3
    def test_end2end(self, jpg_shape):

        self.G.eval()
        self.C.eval()
        data = self.eval_data
        dataloader = self.eval_dataloader

        for i, batch in enumerate(dataloader):

            global_id, video_name = data[i][0][1][0], data[i][0][1][1][:-4]

            inputs_fl, inputs_au, inputs_emb, inputs_reg_fl, inputs_rot_tran, inputs_rot_quat = batch

            for key in ['irx71tYyI-Q', 'J-NPsvtQ8lE', 'Z7WRt--g-h4', 'E0zgrhQ0QDw', 'bXpavyiCu10', 'W6uRNCJmdtI', 'sxCbrYjBsGA', 'wAAMEC1OsRc', '_ldiVrXgZKc', '48uYS3bHIA8', 'E_kmpT-EfOg']:
                emb_val = self.test_embs[key]
                inputs_emb = np.tile(emb_val, (inputs_emb.shape[0], 1))
                inputs_emb = torch.tensor(inputs_emb, dtype=torch.float, requires_grad=False)

                # this_emb = key
                # inputs_emb = torch.zeros(size=(inputs_au.shape[0], len(self.test_embs_dic.keys())))
                # inputs_emb[:, self.test_embs_dic[this_emb]] = 1.

                inputs_fl, inputs_au, inputs_emb = inputs_fl.to(device), inputs_au.to(device), inputs_emb.to(device)
                inputs_reg_fl, inputs_rot_tran, inputs_rot_quat = inputs_reg_fl.to(device), inputs_rot_tran.to(device), inputs_rot_quat.to(device)

                std_fls_list, fls_pred_face_id_list, fls_pred_pos_list = [], [], []
                seg_bs = self.opt_parser.segment_batch_size

                # input_face_id = self.std_face_id
                input_face_id = torch.tensor(jpg_shape.reshape(1, 204), requires_grad=False, dtype=torch.float).to(device)

                ''' register face '''
                landmarks = input_face_id.detach().cpu().numpy().reshape(68, 3)
                frame_t_shape = landmarks[self.t_shape_idx, :]
                T, distance, itr = icp(frame_t_shape, self.anchor_t_shape)
                landmarks = np.hstack((landmarks, np.ones((68, 1))))
                registered_landmarks = np.dot(T, landmarks.T).T
                input_face_id = torch.tensor(registered_landmarks[:, 0:3].reshape(1, 204),
                                             requires_grad=False, dtype=torch.float).to(device)

                for j in range(0, inputs_fl.shape[0], seg_bs):
                    # Step 3.1: load segments
                    inputs_fl_segments = inputs_fl[j: j + seg_bs]
                    inputs_au_segments = inputs_au[j: j + seg_bs]
                    inputs_emb_segments = inputs_emb[j: j + seg_bs]
                    inputs_reg_fl_segments = inputs_reg_fl[j: j + seg_bs]
                    inputs_rot_tran_segments = inputs_rot_tran[j: j + seg_bs]
                    inputs_rot_quat_segments = inputs_rot_quat[j: j + seg_bs]

                    if(inputs_fl_segments.shape[0] < 10):
                        continue

                    fl_dis_pred_pos, pos_pred, input_face_id, (loss, loss_reg_fls, loss_laplacian, loss_pos) = \
                        self.__train_speaker_aware__(inputs_fl_segments, inputs_au_segments, inputs_emb_segments,
                                                       input_face_id,  inputs_reg_fl_segments, inputs_rot_tran_segments,
                                                     inputs_rot_quat_segments,
                                                     is_training=False, use_residual=True)

                    fl_dis_pred_pos = fl_dis_pred_pos.data.cpu().numpy()
                    pos_pred = pos_pred.data.cpu().numpy()
                    fl_std = inputs_reg_fl_segments[:, 0, :].data.cpu().numpy()
                    pos_std = inputs_rot_tran_segments[:, 0, :].data.cpu().numpy()

                    ''' solve inverse lip '''
                    fl_dis_pred_pos = self.__solve_inverse_lip2__(fl_dis_pred_pos)

                    fl_dis_pred_pos = fl_dis_pred_pos.reshape((-1, 68, 3))
                    fl_std = fl_std.reshape((-1, 68, 3))
                    if(self.opt_parser.pos_dim == 12):
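                        # pos_dim == 12: pose is a flattened 3x4 [R|t]; the predicted
                        # rotation is a residual around identity, so undo it with (R + I)^T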
                        pos_pred = pos_pred.reshape((-1, 3, 4))
                        for k in range(fl_dis_pred_pos.shape[0]):
                            fl_dis_pred_pos[k] = np.dot(pos_pred[k, :3, :3].T + np.eye(3),
                                                        (fl_dis_pred_pos[k] - pos_pred[k, :, 3].T).T).T
                        pos_std = pos_std.reshape((-1, 3, 4))
                        for k in range(fl_std.shape[0]):
                            fl_std[k] = np.dot(pos_std[k, :3, :3].T + np.eye(3),
                                                        (fl_std[k] - pos_std[k, :, 3].T).T).T
                    else:
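                        # pos_dim == 7: pose is a quaternion (4) plus translation (3);
                        # smooth the pose track over time, then undo the rigid motion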
                        smooth_length = int(min(pos_pred.shape[0] - 1, 27) // 2 * 2 + 1)
                        pos_pred = savgol_filter(pos_pred, smooth_length, 3, axis=0)
                        quat = pos_pred[:, :4]
                        trans = pos_pred[:, 4:]
                        for k in range(fl_dis_pred_pos.shape[0]):
                            fl_dis_pred_pos[k] = np.dot(R.from_quat(quat[k]).as_matrix().T,
                                                        (fl_dis_pred_pos[k] - trans[k:k+1]).T).T
                        pos_std = pos_std.reshape((-1, 3, 4))
                        for k in range(fl_std.shape[0]):
                            fl_std[k] = np.dot(pos_std[k, :3, :3].T + np.eye(3),
                                               (fl_std[k] - pos_std[k, :, 3].T).T).T

                    fls_pred_pos_list += [fl_dis_pred_pos.reshape((-1, 204))]
                    std_fls_list += [fl_std.reshape((-1, 204))]

                fake_fls_np = np.concatenate(fls_pred_pos_list)
                filename = 'pred_fls_{}_{}.txt'.format(video_name.split('/')[-1], key)
                np.savetxt(os.path.join('examples', filename), fake_fls_np, fmt='%.6f')
Example #4
    def __train_pass__(self, epoch, log_loss, is_training=True):
        st_epoch = time.time()

        # Step 1: init setup
        if (is_training):
            self.G.train()
            data = self.train_data
            dataloader = self.train_dataloader
            status = 'TRAIN'
        else:
            self.G.eval()
            data = self.eval_data
            dataloader = self.eval_dataloader
            status = 'EVAL'

        # random_clip_index = np.random.randint(0, len(dataloader)-1, 4)
        # random_clip_index = np.random.randint(0, 64, 4)
        random_clip_index = list(range(len(dataloader)))
        # print('random_clip_index', random_clip_index)
        # Step 2: train for each batch
        for i, batch in enumerate(dataloader):

            # if(i>=512):
            #     break

            st = time.time()
            global_id, video_name = data[i][0][1][0], data[i][0][1][1][:-4]

            # Step 2.1: load batch data from dataloader (in segments)
            inputs_fl, inputs_au = batch

            if (is_training):
                rand_start = np.random.randint(0, inputs_fl.shape[0] // 5, 1).reshape(-1)
                inputs_fl = inputs_fl[rand_start[0]:]
                inputs_au = inputs_au[rand_start[0]:]

            inputs_fl, inputs_au = inputs_fl.to(device), inputs_au.to(device)
            std_fls_list, fls_pred_face_id_list, fls_pred_pos_list = [], [], []
            seg_bs = self.opt_parser.segment_batch_size

            close_fl_list = inputs_fl[::10, 0, :]
            idx = self.__close_face_lip__(close_fl_list.detach().cpu().numpy())
            input_face_id = close_fl_list[idx:idx + 1, :]

            ''' register face '''
            if (self.opt_parser.use_reg_as_std):
                landmarks = input_face_id.detach().cpu().numpy().reshape(68, 3)
                frame_t_shape = landmarks[self.t_shape_idx, :]
                T, distance, itr = icp(frame_t_shape, self.anchor_t_shape)
                landmarks = np.hstack((landmarks, np.ones((68, 1))))
                registered_landmarks = np.dot(T, landmarks.T).T
                input_face_id = torch.tensor(registered_landmarks[:, 0:3].reshape(1, 204), requires_grad=False,
                                             dtype=torch.float).to(device)

            for j in range(0, inputs_fl.shape[0], seg_bs):
                # Step 3.1: load segments
                inputs_fl_segments = inputs_fl[j: j + seg_bs]
                inputs_au_segments = inputs_au[j: j + seg_bs]


                if(inputs_fl_segments.shape[0] < 10):
                    continue

                if(self.opt_parser.test_emb):
                    input_face_id = self.std_face_id

                fl_dis_pred_pos, pos_pred, input_face_id, (loss, loss_g, loss_laplacian) = \
                    self.__train_speaker_aware__(inputs_fl_segments, inputs_au_segments, input_face_id,
                                                 is_training=is_training)

                fl_dis_pred_pos = fl_dis_pred_pos.data.cpu().numpy()
                fl_std = inputs_fl_segments[:, 0, :].data.cpu().numpy()
                ''' solve inverse lip '''
                if(not is_training):
                    fl_dis_pred_pos = self.__solve_inverse_lip2__(fl_dis_pred_pos)

                fls_pred_pos_list += [fl_dis_pred_pos.reshape((-1, 204))]
                std_fls_list += [fl_std.reshape((-1, 204))]

                for key in log_loss.keys():
                    if (key not in locals().keys()):
                        continue
                    if (type(locals()[key]) == float):
                        log_loss[key].add(locals()[key])
                    else:
                        log_loss[key].add(locals()[key].data.cpu().numpy())

            if (epoch % 5 == 0): # and i in [0, 200, 400, 600, 800, 1000]):
                def save_fls_av(fake_fls_list, postfix='', ifsmooth=True):
                    fake_fls_np = np.concatenate(fake_fls_list)
                    filename = 'fake_fls_{}_{}_{}.txt'.format(epoch, video_name, postfix)
                    np.savetxt(
                        os.path.join(self.opt_parser.dump_dir, '../nn_result', self.opt_parser.name, filename),
                        fake_fls_np, fmt='%.6f')
                    # audio_filename = '{:05d}_{}_audio.wav'.format(global_id, video_name)
                    # from util.vis import Vis_old
                    # Vis_old(run_name=self.opt_parser.name, pred_fl_filename=filename, audio_filename=audio_filename,
                    #     fps=62.5, av_name='e{:04d}_{}_{}'.format(epoch, i, postfix),
                    #     postfix=postfix, root_dir=self.opt_parser.root_dir, ifsmooth=ifsmooth)

                if (self.opt_parser.show_animation):
                    print('show animation ....')
                    save_fls_av(fls_pred_pos_list, 'pred', ifsmooth=True)
                    save_fls_av(std_fls_list, 'std', ifsmooth=False)

            if (self.opt_parser.verbose <= 1):
                print('{} Epoch: #{} batch #{}/{}'.format(status, epoch, i, len(dataloader)), end=': ')
                for key in log_loss.keys():
                    print(key, '{:.5f}'.format(log_loss[key].per('batch')), end=', ')
                print('')
            self.__tensorboard_write__(status, log_loss, 'batch')

        if (self.opt_parser.verbose <= 2):
            print('==========================================================')
            print('{} Epoch: #{}'.format(status, epoch), end=':')
            for key in log_loss.keys():
                print(key, '{:.4f}'.format(log_loss[key].per('epoch')), end=', ')
            print('Epoch time usage: {:.2f} sec\n==========================================================\n'.format(time.time() - st_epoch))
        self.__save_model__(save_type='last_epoch', epoch=epoch)
        if(epoch % 5 == 0):
            self.__save_model__(save_type='e_{}'.format(epoch), epoch=epoch)
        self.__tensorboard_write__(status, log_loss, 'epoch')
Example #5
                                                      num_points=512)
    net = ICP().cuda()
    src = torch.tensor(pointcloud1).cuda()
    target = torch.tensor(pointcloud2).cuda()
    rotation_ab = torch.tensor(R_ab).cuda()
    translation_ab = torch.tensor(t_ab).cuda()
    batch_size = src.size(0)

    src, src_corr, rotation_ab_pred, translation_ab_pred, rotation_ba_pred, translation_ba_pred = net(
        src, target)
    identity = torch.eye(3).cuda().unsqueeze(0).repeat(batch_size, 1, 1)
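    # rotation error: R_pred^T @ R_gt should equal identity for a perfect prediction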
    loss = F.mse_loss(
        torch.matmul(rotation_ab_pred.transpose(2, 1), rotation_ab),
        identity) + F.mse_loss(translation_ab_pred, translation_ab)
    print(loss)

    src_np = src.detach().cpu().numpy().squeeze()
    target_np = target.detach().cpu().numpy().squeeze()
    T, distances, iterations = icp(src_np.T, target_np.T, tolerance=0.000001)
    print(iterations)
    T1 = np.eye(4)
    T1[:3, :3] = R_ab.squeeze()
    T1[:3, 3] = t_ab
    T = torch.tensor(T, dtype=torch.float32).cuda()
    T1 = torch.tensor(T1, dtype=torch.float32).cuda()

    identity = torch.eye(3).cuda()
    loss = F.mse_loss(torch.matmul(T[:3, :3].transpose(1, 0), T1[:3, :3]),
                      identity) + F.mse_loss(T[:3, 3], T1[:3, 3])
    print(loss)
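
A minimal sketch of the icp() interface these examples rely on, assuming the
common SVD-based point-to-point formulation: icp(A, B) rigidly aligns source
points A (N x 3) onto target points B (N x 3) and returns a 4x4 homogeneous
transform T, the final nearest-neighbor distances, and the iteration count.
This is an illustrative re-implementation, not necessarily the project's exact
util.icp.

import numpy as np
from scipy.spatial import cKDTree

def best_fit_transform(A, B):
    # least-squares rigid transform mapping A onto B (Kabsch / SVD)
    cA, cB = A.mean(axis=0), B.mean(axis=0)
    H = (A - cA).T @ (B - cB)
    U, _, Vt = np.linalg.svd(H)
    Rm = Vt.T @ U.T
    if np.linalg.det(Rm) < 0:          # guard against reflections
        Vt[-1, :] *= -1
        Rm = Vt.T @ U.T
    T = np.eye(4)
    T[:3, :3] = Rm
    T[:3, 3] = cB - Rm @ cA
    return T

def icp(A, B, max_iterations=20, tolerance=0.001):
    src = A.copy()
    prev_error = 0.
    tree = cKDTree(B)
    for i in range(max_iterations):
        distances, idx = tree.query(src)          # nearest target per source point
        T = best_fit_transform(src, B[idx])
        src = src @ T[:3, :3].T + T[:3, 3]        # apply the incremental transform
        if abs(prev_error - distances.mean()) < tolerance:
            break
        prev_error = distances.mean()
    return best_fit_transform(A, src), distances, i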
Example #6
        jaw_l_reference = np.sqrt(
            np.sum((0.5 * (std_face_jaw[:, 0, :] + std_face_jaw[:, 16, :]) -
                    std_face_jaw[:, 8, :])**2,
                   axis=1))
        fls_jaw = fls.reshape((-1, 68, 3))[:, 0:17, 0:2]
        jaw_l = np.sqrt(
            np.sum((0.5 * (fls_jaw[:, 0, :] + fls_jaw[:, 16, :]) -
                    fls_jaw[:, 8, :])**2,
                   axis=1,
                   keepdims=True))

        scaled_face_jaw = np.tile(std_face_jaw, (fls_jaw.shape[0], 1, 1))
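        # scale the template jaw's height by 40% of each frame's relative
        # jaw-length deviation from the reference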
        scaled_face_jaw[:, :, 1] *= (jaw_l / jaw_l_reference - 1.) * 0.4 + 1.

        for i in range(scaled_face_jaw.shape[0]):
            src = scaled_face_jaw[i]
            trg = fls_jaw[i]
            T, distance, itr = icp(src, trg)
            rot_mat = T[:2, :2]
            scaled_face_jaw[i] = (np.dot(rot_mat, src.T) + T[:2, 2:3]).T
        # fls[:, 17:, 0:2] *= 0.99
        trg = fls.reshape((-1, 68, 3))
        src = std_face.reshape((-1, 68, 3))
        dis = - 1.0 * (0.5 * (scaled_face_jaw[:, 0, :] + scaled_face_jaw[:, 16, :]) - 0.5 * (
                    trg[:, 36, 0:2] + trg[:, 45, 0:2])) + \
              1.0 * (0.5 * (std_face_jaw[:, 0, :] + std_face_jaw[:, 16, :]) - 0.5 * (src[:, 36, 0:2] + src[:, 45, 0:2]))
        scaled_face_jaw = scaled_face_jaw + np.expand_dims(dis, axis=1)
        fls[:, 0:17, 0:2] = scaled_face_jaw
        fls[:, 0:17, 1] *= 1.02

    if (DEMO_CH in [
            'paint', 'mulaney', 'cartoonM', 'beer', 'color', 'JohnMulaney',
            'vangogh', 'jm', 'roy', 'lineface'
Example #7
    def __train_pass__(self,
                       au_emb=None,
                       centerize_face=False,
                       no_y_rotation=False,
                       vis_fls=False):

        # Step 1: init setup
        self.G.eval()
        self.C.eval()
        data = self.eval_data
        dataloader = self.eval_dataloader

        # Step 2: train for each batch
        for i, batch in enumerate(dataloader):

            global_id, video_name = data[i][0][1][0], data[i][0][1][1][:-4]

            # Step 2.1: load batch data from dataloader (in segments)
            inputs_fl, inputs_au, inputs_emb = batch

            keys = self.opt_parser.reuse_train_emb_list
            if (len(keys) == 0):
                keys = ['audio_embed']
            for key in keys:  # ['45hn7-LXDX8']: #['sxCbrYjBsGA']:#
                # load saved emb
                if (au_emb is None):
                    emb_val = self.test_embs[key]
                else:
                    emb_val = au_emb[i]

                inputs_emb = np.tile(emb_val, (inputs_emb.shape[0], 1))
                inputs_emb = torch.tensor(inputs_emb,
                                          dtype=torch.float,
                                          requires_grad=False)
                inputs_fl, inputs_au, inputs_emb = inputs_fl.to(
                    device), inputs_au.to(device), inputs_emb.to(device)

                std_fls_list, fls_pred_face_id_list, fls_pred_pos_list = [], [], []
                seg_bs = 512

                for j in range(0, inputs_fl.shape[0], seg_bs):

                    # Step 3.1: load segments
                    inputs_fl_segments = inputs_fl[j:j + seg_bs]
                    inputs_au_segments = inputs_au[j:j + seg_bs]
                    inputs_emb_segments = inputs_emb[j:j + seg_bs]

                    if (inputs_fl_segments.shape[0] < 10):
                        continue

                    input_face_id = self.std_face_id

                    fl_dis_pred_pos, input_face_id = \
                        self.__train_face_and_pos__(inputs_fl_segments, inputs_au_segments, inputs_emb_segments,
                                                           input_face_id)

                    fl_dis_pred_pos = (fl_dis_pred_pos +
                                       input_face_id).data.cpu().numpy()
                    ''' solve inverse lip '''
                    fl_dis_pred_pos = self.__solve_inverse_lip2__(
                        fl_dis_pred_pos)
                    fls_pred_pos_list += [fl_dis_pred_pos]

                fake_fls_np = np.concatenate(fls_pred_pos_list)

                # revise nose top point: extrapolate landmark 27 from 28 and 29
                fake_fls_np[:, 27 * 3:28 * 3] = \
                    fake_fls_np[:, 28 * 3:29 * 3] * 2 - fake_fls_np[:, 29 * 3:30 * 3]

                # fake_fls_np[:, 48*3+1::3] += 0.1

                # smooth
                from scipy.signal import savgol_filter
                fake_fls_np = savgol_filter(fake_fls_np, 5, 3, axis=0)

                if (centerize_face):
                    std_m = np.mean(
                        self.std_face_id.detach().cpu().numpy().reshape(
                            (1, 68, 3)),
                        axis=1,
                        keepdims=True)
                    fake_fls_np = fake_fls_np.reshape((-1, 68, 3))
                    fake_fls_np = fake_fls_np - np.mean(
                        fake_fls_np, axis=1, keepdims=True) + std_m
                    fake_fls_np = fake_fls_np.reshape((-1, 68 * 3))

                if (no_y_rotation):
                    std = self.std_face_id.detach().cpu().numpy().reshape(
                        68, 3)
                    std_t_shape = std[self.t_shape_idx, :]
                    fake_fls_np = fake_fls_np.reshape(
                        (fake_fls_np.shape[0], 68, 3))
                    frame_t_shape = fake_fls_np[:, self.t_shape_idx, :]
                    from util.icp import icp
                    from scipy.spatial.transform import Rotation as R
                    for i in range(frame_t_shape.shape[0]):
                        T, distance, itr = icp(frame_t_shape[i], std_t_shape)
                        landmarks = np.hstack((frame_t_shape[i], np.ones(
                            (9, 1))))
                        rot_mat = T[:3, :3]
                        r = R.from_matrix(rot_mat).as_euler('xyz')
                        r = [0., r[1], r[2]]
                        r = R.from_euler('xyz', r).as_matrix()
                        # print(frame_t_shape[i, 0], r)
                        landmarks = np.hstack(
                            (fake_fls_np[i] - T[:3, 3:4].T, np.ones((68, 1))))
                        T2 = np.hstack((r, T[:3, 3:4]))
                        fake_fls_np[i] = np.dot(T2, landmarks.T).T
                        # print(frame_t_shape[i, 0])
                    fake_fls_np = fake_fls_np.reshape((-1, 68 * 3))

                filename = 'pred_fls_{}_{}.txt'.format(
                    video_name.split('/')[-1], key)
                np.savetxt(os.path.join(self.opt_parser.output_folder,
                                        filename),
                           fake_fls_np,
                           fmt='%.6f')

                # ''' Visualize result in landmarks '''
                if (vis_fls):
                    from util.vis import Vis
                    Vis(fls=fake_fls_np,
                        filename=video_name.split('/')[-1],
                        fps=62.5,
                        audio_filenam='examples/' + video_name.split('/')[-1] +
                        '.wav')
Example #8
    def __train_pass__(self, epoch, log_loss, is_training=True):
        st_epoch = time.time()

        # Step 1: init setup
        if (is_training):
            self.G.train()
            self.C.train()
            data = self.train_data
            dataloader = self.train_dataloader
            status = 'TRAIN'
        else:
            self.G.eval()
            self.C.eval()
            data = self.eval_data
            dataloader = self.eval_dataloader
            status = 'EVAL'

        # random_clip_index = np.random.randint(0, len(dataloader)-1, 4)
        # random_clip_index = np.random.randint(0, 64, 4)
        random_clip_index = list(range(len(dataloader)))
        print('random_clip_index', random_clip_index)
        # Step 2: train for each batch
        for i, batch in enumerate(dataloader):

            # if(i>=64):
            #     break

            st = time.time()
            global_id, video_name = data[i][0][1][0], data[i][0][1][1][:-4]

            # Step 2.1: load batch data from dataloader (in segments)
            inputs_fl, inputs_au, inputs_emb, inputs_reg_fl, inputs_rot_tran, inputs_rot_quat = batch
            # inputs_emb = torch.zeros(size=(inputs_au.shape[0], len(self.test_embs_dic.keys())))
            # this_emb = video_name.split('_x_')[1]
            # inputs_emb[:, self.test_embs_dic[this_emb]] = 1.

            if (is_training):
                rand_start = np.random.randint(0, inputs_fl.shape[0] // 5,
                                               1).reshape(-1)
                inputs_fl = inputs_fl[rand_start[0]:]
                inputs_au = inputs_au[rand_start[0]:]
                inputs_emb = inputs_emb[rand_start[0]:]
                inputs_reg_fl = inputs_reg_fl[rand_start[0]:]
                inputs_rot_tran = inputs_rot_tran[rand_start[0]:]
                inputs_rot_quat = inputs_rot_quat[rand_start[0]:]

            inputs_fl, inputs_au, inputs_emb = inputs_fl.to(
                device), inputs_au.to(device), inputs_emb.to(device)
            inputs_reg_fl, inputs_rot_tran, inputs_rot_quat = inputs_reg_fl.to(
                device), inputs_rot_tran.to(device), inputs_rot_quat.to(device)

            std_fls_list, fls_pred_face_id_list, fls_pred_pos_list = [], [], []
            seg_bs = self.opt_parser.segment_batch_size

            close_fl_list = inputs_fl[::10, 0, :]
            idx = self.__close_face_lip__(close_fl_list.detach().cpu().numpy())
            input_face_id = close_fl_list[idx:idx + 1, :]
            ''' register face '''
            if (self.opt_parser.use_reg_as_std):
                landmarks = input_face_id.detach().cpu().numpy().reshape(68, 3)
                frame_t_shape = landmarks[self.t_shape_idx, :]
                T, distance, itr = icp(frame_t_shape, self.anchor_t_shape)
                landmarks = np.hstack((landmarks, np.ones((68, 1))))
                registered_landmarks = np.dot(T, landmarks.T).T
                input_face_id = torch.tensor(registered_landmarks[:, 0:3].reshape(1, 204),
                                             requires_grad=False,
                                             dtype=torch.float).to(device)

            for j in range(0, inputs_fl.shape[0], seg_bs):
                # Step 3.1: load segments
                inputs_fl_segments = inputs_fl[j:j + seg_bs]
                inputs_au_segments = inputs_au[j:j + seg_bs]
                inputs_emb_segments = inputs_emb[j:j + seg_bs]
                inputs_reg_fl_segments = inputs_reg_fl[j:j + seg_bs]
                inputs_rot_tran_segments = inputs_rot_tran[j:j + seg_bs]
                inputs_rot_quat_segments = inputs_rot_quat[j:j + seg_bs]

                if (inputs_fl_segments.shape[0] < 10):
                    continue

                if (self.opt_parser.test_emb):
                    input_face_id = self.std_face_id

                fl_dis_pred_pos, pos_pred, input_face_id, (loss, loss_reg_fls, loss_laplacian, loss_pos) = \
                    self.__train_speaker_aware__(inputs_fl_segments, inputs_au_segments, inputs_emb_segments,
                                                   input_face_id,  inputs_reg_fl_segments, inputs_rot_tran_segments,
                                                 inputs_rot_quat_segments,
                                                 is_training=is_training,
                                                   use_residual=self.opt_parser.use_residual)

                fl_dis_pred_pos = fl_dis_pred_pos.data.cpu().numpy()
                pos_pred = pos_pred.data.cpu().numpy()
                fl_std = inputs_reg_fl_segments[:, 0, :].data.cpu().numpy()
                pos_std = inputs_rot_tran_segments[:, 0, :].data.cpu().numpy()
                ''' solve inverse lip '''
                if (not is_training):
                    fl_dis_pred_pos = self.__solve_inverse_lip2__(
                        fl_dis_pred_pos)

                fl_dis_pred_pos = fl_dis_pred_pos.reshape((-1, 68, 3))
                fl_std = fl_std.reshape((-1, 68, 3))
                if (self.opt_parser.pos_dim == 12):
                    pos_pred = pos_pred.reshape((-1, 3, 4))
                    for k in range(fl_dis_pred_pos.shape[0]):
                        fl_dis_pred_pos[k] = np.dot(
                            pos_pred[k, :3, :3].T + np.eye(3),
                            (fl_dis_pred_pos[k] - pos_pred[k, :, 3].T).T).T
                    pos_std = pos_std.reshape((-1, 3, 4))
                    for k in range(fl_std.shape[0]):
                        fl_std[k] = np.dot(pos_std[k, :3, :3].T + np.eye(3),
                                           (fl_std[k] -
                                            pos_std[k, :, 3].T).T).T
                else:
                    if (not is_training):
                        smooth_length = int(
                            min(pos_pred.shape[0] - 1, 27) // 2 * 2 + 1)
                        pos_pred = savgol_filter(pos_pred,
                                                 smooth_length,
                                                 3,
                                                 axis=0)
                    quat = pos_pred[:, :4]
                    trans = pos_pred[:, 4:]
                    for k in range(fl_dis_pred_pos.shape[0]):
                        fl_dis_pred_pos[k] = np.dot(
                            R.from_quat(quat[k]).as_matrix().T,
                            (fl_dis_pred_pos[k] - trans[k:k + 1]).T).T
                    pos_std = pos_std.reshape((-1, 3, 4))
                    for k in range(fl_std.shape[0]):
                        fl_std[k] = np.dot(pos_std[k, :3, :3].T + np.eye(3),
                                           (fl_std[k] -
                                            pos_std[k, :, 3].T).T).T

                fls_pred_pos_list += [fl_dis_pred_pos.reshape((-1, 204))]
                std_fls_list += [fl_std.reshape((-1, 204))]

                for key in log_loss.keys():
                    if (key not in locals().keys()):
                        continue
                    if (type(locals()[key]) == float):
                        log_loss[key].add(locals()[key])
                    else:
                        log_loss[key].add(locals()[key].data.cpu().numpy())

            if (epoch % self.opt_parser.jpg_freq == 0
                    and i in random_clip_index):

                def save_fls_av(fake_fls_list, postfix='', ifsmooth=True):
                    fake_fls_np = np.concatenate(fake_fls_list)
                    filename = 'fake_fls_{}_{}_{}.txt'.format(
                        epoch, video_name, postfix)
                    np.savetxt(os.path.join(self.opt_parser.dump_dir,
                                            '../nn_result',
                                            self.opt_parser.name, filename),
                               fake_fls_np,
                               fmt='%.6f')
                    audio_filename = '{:05d}_{}_audio.wav'.format(
                        global_id, video_name)
                    from util.vis import Vis_old
                    Vis_old(run_name=self.opt_parser.name,
                            pred_fl_filename=filename,
                            audio_filename=audio_filename,
                            fps=62.5,
                            av_name='e{:04d}_{}_{}'.format(epoch, i, postfix),
                            postfix=postfix,
                            root_dir=self.opt_parser.root_dir,
                            ifsmooth=ifsmooth)

                if (self.opt_parser.show_animation):
                    print('show animation ....')
                    save_fls_av(fls_pred_pos_list, 'pred', ifsmooth=True)
                    save_fls_av(std_fls_list, 'std', ifsmooth=False)

            if (self.opt_parser.verbose <= 1):
                print('{} Epoch: #{} batch #{}/{}'.format(
                    status, epoch, i, len(dataloader)),
                      end=': ')
                for key in log_loss.keys():
                    print(key,
                          '{:.5f}'.format(log_loss[key].per('batch')),
                          end=', ')
                print('')
            self.__tensorboard_write__(status, log_loss, 'batch')

        if (self.opt_parser.verbose <= 2):
            print('==========================================================')
            print('{} Epoch: #{}'.format(status, epoch), end=':')
            for key in log_loss.keys():
                print(key,
                      '{:.4f}'.format(log_loss[key].per('epoch')),
                      end=', ')
            print(
                'Epoch time usage: {:.2f} sec\n==========================================================\n'
                .format(time.time() - st_epoch))
        self.__save_model__(save_type='last_epoch', epoch=epoch)
        if (epoch % self.opt_parser.ckpt_epoch_freq == 0):
            self.__save_model__(save_type='e_{}'.format(epoch), epoch=epoch)
        self.__tensorboard_write__(status, log_loss, 'epoch')