def norm_output_fls_rot(fl_data_i, anchor_t_shape=None):
    """Rigidly register every landmark frame onto an anchor "T-shape".

    Each (68, 3) landmark frame is aligned to the anchor via ICP over nine
    stable facial points (nose bridge/tip and eye corners).

    Returns a tuple ``(rot_trans, rot_quats, fl_data_i)``:
    per-frame 3x4 affine matrices, per-frame quaternions, and the
    registered landmarks with shape (-1, 68, 3).
    """
    t_shape_idx = (27, 28, 29, 30, 33, 36, 39, 42, 45)
    if anchor_t_shape is None:
        # No anchor given: load the canonical one, normalize its scale by
        # the distance between two reference points, then center it.
        anchor_t_shape = np.loadtxt(
            r'src/dataset/utils/ANCHOR_T_SHAPE_{}.txt'.format(
                len(t_shape_idx)))
        scale = np.abs(anchor_t_shape[5, 0] - anchor_t_shape[8, 0])
        anchor_t_shape = anchor_t_shape / scale * 1.0
        center = np.mean(anchor_t_shape[[4, 5, 8], :], axis=0)
        anchor_t_shape -= center
    else:
        # Caller supplied a full 68-point face: keep only the stable points.
        anchor_t_shape = anchor_t_shape.reshape((68, 3))[t_shape_idx, :]

    frames = fl_data_i.reshape((-1, 68, 3)).copy()

    quat_list = []
    affine_list = []
    for frame_idx in range(frames.shape[0]):
        face = frames[frame_idx]
        # ICP on the stable subset gives the rigid transform to the anchor.
        T, distance, itr = icp(face[t_shape_idx, :], anchor_t_shape)
        rotation = T[:3, :3]
        translation = T[:3, 3:4]
        # Apply the transform so this frame sits in the anchor pose.
        frames[frame_idx] = np.dot(rotation, face.T).T + translation.T
        quat_list.append(R.from_matrix(rotation).as_quat())
        affine_list.append(T[:3, :])

    return np.array(affine_list), np.array(quat_list), frames
def __train_pass__(self, epoch, log_loss, is_training=True):
    """Run one train/eval epoch of the content model ``self.C``.

    For each video clip: pick a reference (most-closed-lip) face, optionally
    ICP-register it to the anchor shape, then run several in-batch epochs of
    segment-wise prediction, accumulating losses in ``log_loss`` and
    periodically dumping visualization files.
    """
    st_epoch = time.time()

    # Step 1: init setup — choose model mode, data split and label.
    if(is_training):
        self.C.train()
        data = self.train_data
        dataloader = self.train_dataloader
        status = 'TRAIN'
    else:
        self.C.eval()
        data = self.eval_data
        dataloader = self.eval_dataloader
        status = 'EVAL'

    # Clip indices selected for visualization this epoch.
    random_clip_index = np.random.permutation(len(dataloader))[0:self.opt_parser.random_clip_num]
    print('random visualize clip index', random_clip_index)

    # Step 2: train for each batch (one batch == one video clip here;
    # metadata is fetched from `data` by the same index i).
    for i, batch in enumerate(dataloader):
        global_id, video_name = data[i][0][1][0], data[i][0][1][1][:-4]
        inputs_fl, inputs_au = batch
        inputs_fl_ori, inputs_au_ori = inputs_fl.to(device), inputs_au.to(device)
        std_fls_list, fls_pred_face_id_list, fls_pred_pos_list = [], [], []
        seg_bs = 512  # segment batch size (frames per forward pass)

        ''' pick a most closed lip frame from entire clip data '''
        # Subsample every 10th frame and let the helper choose the index.
        close_fl_list = inputs_fl_ori[::10, 0, :]
        idx = self.__close_face_lip__(close_fl_list.detach().cpu().numpy())
        input_face_id = close_fl_list[idx:idx + 1, :]

        ''' register face '''
        # Optionally ICP-align the reference face to the anchor T-shape so
        # the identity input lives in a canonical pose.
        if (self.opt_parser.use_reg_as_std):
            landmarks = input_face_id.detach().cpu().numpy().reshape(68, 3)
            frame_t_shape = landmarks[self.t_shape_idx, :]
            T, distance, itr = icp(frame_t_shape, self.anchor_t_shape)
            landmarks = np.hstack((landmarks, np.ones((68, 1))))
            registered_landmarks = np.dot(T, landmarks.T).T
            input_face_id = torch.tensor(registered_landmarks[:, 0:3].reshape(1, 204),
                                         requires_grad=False, dtype=torch.float).to(device)

        # Several passes over the same clip ("in-batch epochs").
        for in_batch in range(self.opt_parser.in_batch_nepoch):
            std_fls_list, fls_pred_face_id_list, fls_pred_pos_list = [], [], []
            if (is_training):
                # Random temporal crop: drop a random prefix (up to 20%).
                rand_start = np.random.randint(0, inputs_fl_ori.shape[0] // 5, 1).reshape(-1)
                inputs_fl = inputs_fl_ori[rand_start[0]:]
                inputs_au = inputs_au_ori[rand_start[0]:]
            else:
                inputs_fl = inputs_fl_ori
                inputs_au = inputs_au_ori

            for j in range(0, inputs_fl.shape[0], seg_bs):
                # Step 3.1: load segments
                inputs_fl_segments = inputs_fl[j: j + seg_bs]
                inputs_au_segments = inputs_au[j: j + seg_bs]
                fl_std = inputs_fl_segments[:, 0, :].data.cpu().numpy()
                # Skip tail segments that are too short to train on.
                if(inputs_fl_segments.shape[0] < 10):
                    continue
                fl_dis_pred_pos, input_face_id, loss = \
                    self.__train_content__(inputs_fl_segments, inputs_au_segments, input_face_id, is_training)
                # Predicted displacement + identity -> absolute landmarks.
                fl_dis_pred_pos = (fl_dis_pred_pos + input_face_id).data.cpu().numpy()
                ''' solve inverse lip '''
                fl_dis_pred_pos = self.__solve_inverse_lip2__(fl_dis_pred_pos)
                fls_pred_pos_list += [fl_dis_pred_pos.reshape((-1, 204))]
                std_fls_list += [fl_std.reshape((-1, 204))]

                # NOTE: log_loss keys are matched against *local variable
                # names* via locals() (e.g. 'loss' above). Renaming locals
                # here would silently stop logging that quantity.
                for key in log_loss.keys():
                    if (key not in locals().keys()):
                        continue
                    if (type(locals()[key]) == float):
                        log_loss[key].add(locals()[key])
                    else:
                        log_loss[key].add(locals()[key].data.cpu().numpy())

            # Periodic visualization + intermediate checkpoint.
            if (epoch % self.opt_parser.jpg_freq == 0 and
                    (i in random_clip_index or in_batch % self.opt_parser.jpg_freq == 1)):
                def save_fls_av(fake_fls_list, postfix='', ifsmooth=True):
                    # Dump predicted landmark sequence to txt and render it
                    # with the audio track via Vis_old.
                    fake_fls_np = np.concatenate(fake_fls_list)
                    filename = 'fake_fls_{}_{}_{}.txt'.format(epoch, video_name, postfix)
                    np.savetxt(
                        os.path.join(self.opt_parser.dump_dir, '../nn_result',
                                     self.opt_parser.name, filename),
                        fake_fls_np, fmt='%.6f')
                    audio_filename = '{:05d}_{}_audio.wav'.format(global_id, video_name)
                    from util.vis import Vis_old
                    Vis_old(run_name=self.opt_parser.name, pred_fl_filename=filename,
                            audio_filename=audio_filename,
                            fps=62.5, av_name='e{:04d}_{}_{}'.format(epoch, in_batch, postfix),
                            postfix=postfix, root_dir=self.opt_parser.root_dir, ifsmooth=ifsmooth)
                if (self.opt_parser.show_animation and not is_training):
                    print('show animation ....')
                    save_fls_av(fls_pred_pos_list, 'pred_{}'.format(i), ifsmooth=True)
                    save_fls_av(std_fls_list, 'std_{}'.format(i), ifsmooth=False)
                    # Side-by-side comparison of prediction vs ground truth.
                    from util.vis import Vis_comp
                    Vis_comp(run_name=self.opt_parser.name,
                             pred1='fake_fls_{}_{}_{}.txt'.format(epoch, video_name, 'pred_{}'.format(i)),
                             pred2='fake_fls_{}_{}_{}.txt'.format(epoch, video_name, 'std_{}'.format(i)),
                             audio_filename='{:05d}_{}_audio.wav'.format(global_id, video_name),
                             fps=62.5, av_name='e{:04d}_{}_{}'.format(epoch, in_batch, 'comp_{}'.format(i)),
                             postfix='comp_{}'.format(i), root_dir=self.opt_parser.root_dir, ifsmooth=False)
                self.__save_model__(save_type='last_inbatch', epoch=epoch)

            # Per-in-batch progress report.
            if (self.opt_parser.verbose <= 1):
                print('{} Epoch: #{} batch #{}/{} inbatch #{}/{}'.format(
                    status, epoch, i, len(dataloader), in_batch, self.opt_parser.in_batch_nepoch), end=': ')
                for key in log_loss.keys():
                    print(key, '{:.5f}'.format(log_loss[key].per('batch')), end=', ')
                print('')

    # Epoch summary + checkpointing.
    if (self.opt_parser.verbose <= 2):
        print('==========================================================')
        print('{} Epoch: #{}'.format(status, epoch), end=':')
        for key in log_loss.keys():
            print(key, '{:.4f}'.format(log_loss[key].per('epoch')), end=', ')
        print(
            'Epoch time usage: {:.2f} sec\n==========================================================\n'.format(
                time.time() - st_epoch))
    self.__save_model__(save_type='last_epoch', epoch=epoch)
    if (epoch % self.opt_parser.ckpt_epoch_freq == 0):
        self.__save_model__(save_type='e_{}'.format(epoch), epoch=epoch)
def test_end2end(self, jpg_shape):
    """End-to-end evaluation: drive the speaker-aware model with each of a
    fixed set of saved speaker embeddings and dump predicted landmark
    sequences to ``examples/pred_fls_<video>_<key>.txt``.

    jpg_shape: flat (204,) landmark array (68 x 3) of the target face image.
    """
    self.G.eval()
    self.C.eval()
    data = self.eval_data
    dataloader = self.eval_dataloader
    for i, batch in enumerate(dataloader):
        global_id, video_name = data[i][0][1][0], data[i][0][1][1][:-4]
        inputs_fl, inputs_au, inputs_emb, inputs_reg_fl, inputs_rot_tran, inputs_rot_quat = batch
        # Try every saved speaker embedding in turn.
        # NOTE(review): inputs_emb (and inputs_fl/inputs_au/...) are
        # overwritten with device tensors on the first key iteration, so
        # np.tile(...)/.to(device) on later iterations operates on the
        # already-converted values — presumably works because shapes are
        # unchanged, but verify on a multi-key run.
        for key in ['irx71tYyI-Q', 'J-NPsvtQ8lE', 'Z7WRt--g-h4', 'E0zgrhQ0QDw',
                    'bXpavyiCu10', 'W6uRNCJmdtI', 'sxCbrYjBsGA', 'wAAMEC1OsRc',
                    '_ldiVrXgZKc', '48uYS3bHIA8', 'E_kmpT-EfOg']:
            emb_val = self.test_embs[key]
            # Broadcast the single embedding over every frame.
            inputs_emb = np.tile(emb_val, (inputs_emb.shape[0], 1))
            inputs_emb = torch.tensor(inputs_emb, dtype=torch.float, requires_grad=False)
            # this_emb = key
            # inputs_emb = torch.zeros(size=(inputs_au.shape[0], len(self.test_embs_dic.keys())))
            # inputs_emb[:, self.test_embs_dic[this_emb]] = 1.
            inputs_fl, inputs_au, inputs_emb = inputs_fl.to(device), inputs_au.to(device), inputs_emb.to(device)
            inputs_reg_fl, inputs_rot_tran, inputs_rot_quat = inputs_reg_fl.to(device), inputs_rot_tran.to(device), inputs_rot_quat.to(device)
            std_fls_list, fls_pred_face_id_list, fls_pred_pos_list = [], [], []
            seg_bs = self.opt_parser.segment_batch_size
            # input_face_id = self.std_face_id
            # Use the caller-provided face shape as the identity input.
            input_face_id = torch.tensor(jpg_shape.reshape(1, 204), requires_grad=False,
                                         dtype=torch.float).to(device)
            ''' register face '''
            # ICP-align the identity face to the anchor T-shape.
            if (True):
                landmarks = input_face_id.detach().cpu().numpy().reshape(68, 3)
                frame_t_shape = landmarks[self.t_shape_idx, :]
                T, distance, itr = icp(frame_t_shape, self.anchor_t_shape)
                landmarks = np.hstack((landmarks, np.ones((68, 1))))
                registered_landmarks = np.dot(T, landmarks.T).T
                input_face_id = torch.tensor(registered_landmarks[:, 0:3].reshape(1, 204),
                                             requires_grad=False, dtype=torch.float).to(device)
            for j in range(0, inputs_fl.shape[0], seg_bs):
                # Step 3.1: load segments
                inputs_fl_segments = inputs_fl[j: j + seg_bs]
                inputs_au_segments = inputs_au[j: j + seg_bs]
                inputs_emb_segments = inputs_emb[j: j + seg_bs]
                inputs_reg_fl_segments = inputs_reg_fl[j: j + seg_bs]
                inputs_rot_tran_segments = inputs_rot_tran[j: j + seg_bs]
                inputs_rot_quat_segments = inputs_rot_quat[j: j + seg_bs]
                # Skip tail segments that are too short.
                if(inputs_fl_segments.shape[0] < 10):
                    continue
                fl_dis_pred_pos, pos_pred, input_face_id, (loss, loss_reg_fls, loss_laplacian, loss_pos) = \
                    self.__train_speaker_aware__(inputs_fl_segments, inputs_au_segments, inputs_emb_segments,
                                                 input_face_id, inputs_reg_fl_segments, inputs_rot_tran_segments,
                                                 inputs_rot_quat_segments, is_training=False, use_residual=True)
                fl_dis_pred_pos = fl_dis_pred_pos.data.cpu().numpy()
                pos_pred = pos_pred.data.cpu().numpy()
                fl_std = inputs_reg_fl_segments[:, 0, :].data.cpu().numpy()
                pos_std = inputs_rot_tran_segments[:, 0, :].data.cpu().numpy()
                ''' solve inverse lip '''
                fl_dis_pred_pos = self.__solve_inverse_lip2__(fl_dis_pred_pos)
                fl_dis_pred_pos = fl_dis_pred_pos.reshape((-1, 68, 3))
                fl_std = fl_std.reshape((-1, 68, 3))
                if(self.opt_parser.pos_dim == 12):
                    # Pose predicted as a 3x4 affine: undo it per frame.
                    pos_pred = pos_pred.reshape((-1, 3, 4))
                    for k in range(fl_dis_pred_pos.shape[0]):
                        fl_dis_pred_pos[k] = np.dot(pos_pred[k, :3, :3].T + np.eye(3),
                                                    (fl_dis_pred_pos[k] - pos_pred[k, :, 3].T).T).T
                    pos_std = pos_std.reshape((-1, 3, 4))
                    for k in range(fl_std.shape[0]):
                        fl_std[k] = np.dot(pos_std[k, :3, :3].T + np.eye(3),
                                           (fl_std[k] - pos_std[k, :, 3].T).T).T
                else:
                    # Pose predicted as quaternion (4) + translation:
                    # smooth temporally (odd window <= 27), then undo.
                    smooth_length = int(min(pos_pred.shape[0] - 1, 27) // 2 * 2 + 1)
                    pos_pred = savgol_filter(pos_pred, smooth_length, 3, axis=0)
                    quat = pos_pred[:, :4]
                    trans = pos_pred[:, 4:]
                    for k in range(fl_dis_pred_pos.shape[0]):
                        fl_dis_pred_pos[k] = np.dot(R.from_quat(quat[k]).as_matrix().T,
                                                    (fl_dis_pred_pos[k] - trans[k:k+1]).T).T
                    pos_std = pos_std.reshape((-1, 3, 4))
                    for k in range(fl_std.shape[0]):
                        fl_std[k] = np.dot(pos_std[k, :3, :3].T + np.eye(3),
                                           (fl_std[k] - pos_std[k, :, 3].T).T).T
                fls_pred_pos_list += [fl_dis_pred_pos.reshape((-1, 204))]
                std_fls_list += [fl_std.reshape((-1, 204))]
            # Dump the whole predicted sequence for this (video, embedding).
            fake_fls_np = np.concatenate(fls_pred_pos_list)
            filename = 'pred_fls_{}_{}.txt'.format(video_name.split('/')[-1], key)
            np.savetxt(os.path.join('examples', filename), fake_fls_np, fmt='%.6f')
def __train_pass__(self, epoch, log_loss, is_training=True):
    """Run one train/eval epoch of the speaker-aware generator ``self.G``
    (variant without speaker-embedding inputs).

    Per clip: choose an identity face, optionally ICP-register it, predict
    landmarks segment by segment, log losses, and periodically dump
    prediction/ground-truth landmark files.
    """
    st_epoch = time.time()

    # Step 1: init setup — choose model mode, data split and label.
    if (is_training):
        self.G.train()
        data = self.train_data
        dataloader = self.train_dataloader
        status = 'TRAIN'
    else:
        self.G.eval()
        data = self.eval_data
        dataloader = self.eval_dataloader
        status = 'EVAL'
    # random_clip_index = np.random.randint(0, len(dataloader)-1, 4)
    # random_clip_index = np.random.randint(0, 64, 4)
    random_clip_index = list(range(len(dataloader)))  # visualize all clips
    # print('random_clip_index', random_clip_index)

    # Step 2: train for each batch (one batch == one clip; metadata is
    # fetched from `data` by the same index i).
    for i, batch in enumerate(dataloader):
        # if(i>=512):
        #     break
        st = time.time()
        global_id, video_name = data[i][0][1][0], data[i][0][1][1][:-4]

        # Step 2.1: load batch data from dataloader (in segments)
        inputs_fl, inputs_au = batch
        if (is_training):
            # Random temporal crop: drop a random prefix (up to 20%).
            rand_start = np.random.randint(0, inputs_fl.shape[0] // 5, 1).reshape(-1)
            inputs_fl = inputs_fl[rand_start[0]:]
            inputs_au = inputs_au[rand_start[0]:]
        inputs_fl, inputs_au = inputs_fl.to(device), inputs_au.to(device)
        std_fls_list, fls_pred_face_id_list, fls_pred_pos_list = [], [], []
        seg_bs = self.opt_parser.segment_batch_size

        # Pick the most-closed-lip frame as the identity face.
        close_fl_list = inputs_fl[::10, 0, :]
        idx = self.__close_face_lip__(close_fl_list.detach().cpu().numpy())
        input_face_id = close_fl_list[idx:idx + 1, :]

        ''' register face '''
        # Optionally ICP-align the identity face to the anchor T-shape.
        if (self.opt_parser.use_reg_as_std):
            landmarks = input_face_id.detach().cpu().numpy().reshape(68, 3)
            frame_t_shape = landmarks[self.t_shape_idx, :]
            T, distance, itr = icp(frame_t_shape, self.anchor_t_shape)
            landmarks = np.hstack((landmarks, np.ones((68, 1))))
            registered_landmarks = np.dot(T, landmarks.T).T
            input_face_id = torch.tensor(registered_landmarks[:, 0:3].reshape(1, 204),
                                         requires_grad=False, dtype=torch.float).to(device)

        for j in range(0, inputs_fl.shape[0], seg_bs):
            # Step 3.1: load segments
            inputs_fl_segments = inputs_fl[j: j + seg_bs]
            inputs_au_segments = inputs_au[j: j + seg_bs]
            # Skip tail segments that are too short.
            if(inputs_fl_segments.shape[0] < 10):
                continue
            if(self.opt_parser.test_emb):
                input_face_id = self.std_face_id
            fl_dis_pred_pos, pos_pred, input_face_id, (loss, loss_g, loss_laplacian) = \
                self.__train_speaker_aware__(inputs_fl_segments, inputs_au_segments, input_face_id,
                                             is_training=is_training)
            fl_dis_pred_pos = fl_dis_pred_pos.data.cpu().numpy()
            fl_std = inputs_fl_segments[:, 0, :].data.cpu().numpy()
            ''' solve inverse lip '''
            if(not is_training):
                fl_dis_pred_pos = self.__solve_inverse_lip2__(fl_dis_pred_pos)
            fls_pred_pos_list += [fl_dis_pred_pos.reshape((-1, 204))]
            std_fls_list += [fl_std.reshape((-1, 204))]

            # NOTE: log_loss keys are matched against *local variable names*
            # via locals() (e.g. 'loss', 'loss_g', 'loss_laplacian' above).
            # Renaming those locals would silently stop logging them.
            for key in log_loss.keys():
                if (key not in locals().keys()):
                    continue
                if (type(locals()[key]) == float):
                    log_loss[key].add(locals()[key])
                else:
                    log_loss[key].add(locals()[key].data.cpu().numpy())

        # Periodic dump of predicted / ground-truth landmark sequences.
        if (epoch % 5 == 0):  # and i in [0, 200, 400, 600, 800, 1000]):
            def save_fls_av(fake_fls_list, postfix='', ifsmooth=True):
                # Save the landmark sequence to a txt file under nn_result.
                fake_fls_np = np.concatenate(fake_fls_list)
                filename = 'fake_fls_{}_{}_{}.txt'.format(epoch, video_name, postfix)
                np.savetxt(
                    os.path.join(self.opt_parser.dump_dir, '../nn_result',
                                 self.opt_parser.name, filename),
                    fake_fls_np, fmt='%.6f')
                # audio_filename = '{:05d}_{}_audio.wav'.format(global_id, video_name)
                # from util.vis import Vis_old
                # Vis_old(run_name=self.opt_parser.name, pred_fl_filename=filename, audio_filename=audio_filename,
                #         fps=62.5, av_name='e{:04d}_{}_{}'.format(epoch, i, postfix),
                #         postfix=postfix, root_dir=self.opt_parser.root_dir, ifsmooth=ifsmooth)
            if (True):
                if (self.opt_parser.show_animation):
                    print('show animation ....')
                save_fls_av(fls_pred_pos_list, 'pred', ifsmooth=True)
                save_fls_av(std_fls_list, 'std', ifsmooth=False)

        # Per-batch progress report.
        if (self.opt_parser.verbose <= 1):
            print('{} Epoch: #{} batch #{}/{}'.format(status, epoch, i, len(dataloader)), end=': ')
            for key in log_loss.keys():
                print(key, '{:.5f}'.format(log_loss[key].per('batch')), end=', ')
            print('')
        self.__tensorboard_write__(status, log_loss, 'batch')

    # Epoch summary + checkpointing.
    if (self.opt_parser.verbose <= 2):
        print('==========================================================')
        print('{} Epoch: #{}'.format(status, epoch), end=':')
        for key in log_loss.keys():
            print(key, '{:.4f}'.format(log_loss[key].per('epoch')), end=', ')
        print('Epoch time usage: {:.2f} sec\n==========================================================\n'.format(time.time() - st_epoch))
    self.__save_model__(save_type='last_epoch', epoch=epoch)
    if(epoch % 5 == 0):
        self.__save_model__(save_type='e_{}'.format(epoch), epoch=epoch)
    self.__tensorboard_write__(status, log_loss, 'epoch')
num_points=512) net = ICP().cuda() src = torch.tensor(pointcloud1).cuda() target = torch.tensor(pointcloud2).cuda() rotation_ab = torch.tensor(R_ab).cuda() translation_ab = torch.tensor(t_ab).cuda() batch_size = src.size(0) src, src_corr, rotation_ab_pred, translation_ab_pred, rotation_ba_pred, translation_ba_pred = net( src, target) identity = torch.eye(3).cuda().unsqueeze(0).repeat(batch_size, 1, 1) loss = F.mse_loss( torch.matmul(rotation_ab_pred.transpose(2, 1), rotation_ab), identity) + F.mse_loss(translation_ab_pred, translation_ab) print(loss) src_np = src.detach().cpu().numpy().squeeze() target_np = target.detach().cpu().numpy().squeeze() T, distances, iterations = icp(src_np.T, target_np.T, tolerance=0.000001) print(iterations) T1 = np.eye(4) T1[:3, :3] = R_ab.squeeze() T1[:3, 3] = t_ab T = torch.tensor(T, dtype=torch.float32).cuda() T1 = torch.tensor(T1, dtype=torch.float32).cuda() identity = torch.eye(3).cuda() loss = F.mse_loss(torch.matmul(T[:3, :3].transpose(1, 0), T1[:3, :3]), identity) + F.mse_loss(T[:3, 3], T1[:3, 3]) print(loss)
std_face_jaw[:, 8, :])**2, axis=1)) fls_jaw = fls.reshape((-1, 68, 3))[:, 0:17, 0:2] jaw_l = np.sqrt( np.sum((0.5 * (fls_jaw[:, 0, :] + fls_jaw[:, 16, :]) - fls_jaw[:, 8, :])**2, axis=1, keepdims=True)) scaled_face_jaw = np.tile(std_face_jaw, (fls_jaw.shape[0], 1, 1)) scaled_face_jaw[:, :, 1] *= (jaw_l / jaw_l_reference - 1.) * 0.4 + 1. for i in range(scaled_face_jaw.shape[0]): src = scaled_face_jaw[i] trg = fls_jaw[i] T, distance, itr = icp(src, trg) rot_mat = T[:2, :2] scaled_face_jaw[i] = (np.dot(rot_mat, src.T) + T[:2, 2:3]).T # fls[:, 17:, 0:2] *= 0.99 trg = fls.reshape((-1, 68, 3)) src = std_face.reshape((-1, 68, 3)) dis = - 1.0 * (0.5 * (scaled_face_jaw[:, 0, :] + scaled_face_jaw[:, 16, :]) - 0.5 * ( trg[:, 36, 0:2] + trg[:, 45, 0:2])) + \ 1.0 * (0.5 * (std_face_jaw[:, 0, :] + std_face_jaw[:, 16, :]) - 0.5 * (src[:, 36, 0:2] + src[:, 45, 0:2])) scaled_face_jaw = scaled_face_jaw + np.expand_dims(dis, axis=1) fls[:, 0:17, 0:2] = scaled_face_jaw fls[:, 0:17, 1] *= 1.02 if (DEMO_CH in [ 'paint', 'mulaney', 'cartoonM', 'beer', 'color', 'JohnMulaney', 'vangogh', 'jm', 'roy', 'lineface'
def __train_pass__(self, au_emb=None, centerize_face=False, no_y_rotation=False, vis_fls=False):
    """Inference pass: predict landmark sequences for every eval clip and
    save them as txt files in ``self.opt_parser.output_folder``.

    au_emb:          optional list of per-clip speaker embeddings; when None,
                     embeddings come from ``self.test_embs`` keyed by
                     ``opt_parser.reuse_train_emb_list`` (or 'audio_embed').
    centerize_face:  recenter every frame onto the standard face's centroid.
    no_y_rotation:   zero out the x-axis (pitch) component of the ICP
                     rotation before re-applying it per frame.
    vis_fls:         render the result with util.vis.Vis.

    Fixes vs. previous revision: uses R.from_matrix/.as_matrix() (from_dcm/
    as_dcm were removed in SciPy 1.6; the rest of this file already uses
    as_matrix), and the per-frame loop no longer shadows the batch index i.
    """
    # Step 1: init setup
    self.G.eval()
    self.C.eval()
    data = self.eval_data
    dataloader = self.eval_dataloader

    # Step 2: predict for each batch (one batch == one clip; metadata is
    # fetched from `data` by the same index i).
    for i, batch in enumerate(dataloader):
        global_id, video_name = data[i][0][1][0], data[i][0][1][1][:-4]

        # Step 2.1: load batch data from dataloader (in segments)
        inputs_fl, inputs_au, inputs_emb = batch

        keys = self.opt_parser.reuse_train_emb_list
        if (len(keys) == 0):
            keys = ['audio_embed']
        for key in keys:  # ['45hn7-LXDX8']: #['sxCbrYjBsGA']:#
            # load saved emb
            if (au_emb is None):
                emb_val = self.test_embs[key]
            else:
                emb_val = au_emb[i]

            # Broadcast the single speaker embedding over every frame.
            inputs_emb = np.tile(emb_val, (inputs_emb.shape[0], 1))
            inputs_emb = torch.tensor(inputs_emb, dtype=torch.float, requires_grad=False)
            inputs_fl, inputs_au, inputs_emb = inputs_fl.to(
                device), inputs_au.to(device), inputs_emb.to(device)

            std_fls_list, fls_pred_face_id_list, fls_pred_pos_list = [], [], []
            seg_bs = 512  # segment batch size (frames per forward pass)

            for j in range(0, inputs_fl.shape[0], seg_bs):
                # Step 3.1: load segments
                inputs_fl_segments = inputs_fl[j:j + seg_bs]
                inputs_au_segments = inputs_au[j:j + seg_bs]
                inputs_emb_segments = inputs_emb[j:j + seg_bs]

                # Skip tail segments that are too short.
                if (inputs_fl_segments.shape[0] < 10):
                    continue

                input_face_id = self.std_face_id

                fl_dis_pred_pos, input_face_id = \
                    self.__train_face_and_pos__(inputs_fl_segments, inputs_au_segments,
                                                inputs_emb_segments, input_face_id)

                # Predicted displacement + identity -> absolute landmarks.
                fl_dis_pred_pos = (fl_dis_pred_pos + input_face_id).data.cpu().numpy()
                ''' solve inverse lip '''
                fl_dis_pred_pos = self.__solve_inverse_lip2__(fl_dis_pred_pos)
                fls_pred_pos_list += [fl_dis_pred_pos]

            fake_fls_np = np.concatenate(fls_pred_pos_list)

            # revise nose top point: extrapolate from the two points below it.
            fake_fls_np[:, 27 * 3:28 * 3] = fake_fls_np[:, 28 * 3:29 * 3] * 2 - fake_fls_np[:, 29 * 3:30 * 3]
            # fake_fls_np[:, 48*3+1::3] += 0.1

            # smooth temporally
            from scipy.signal import savgol_filter
            fake_fls_np = savgol_filter(fake_fls_np, 5, 3, axis=0)

            if (centerize_face):
                # Move every frame's centroid onto the standard face centroid.
                std_m = np.mean(
                    self.std_face_id.detach().cpu().numpy().reshape((1, 68, 3)),
                    axis=1, keepdims=True)
                fake_fls_np = fake_fls_np.reshape((-1, 68, 3))
                fake_fls_np = fake_fls_np - np.mean(
                    fake_fls_np, axis=1, keepdims=True) + std_m
                fake_fls_np = fake_fls_np.reshape((-1, 68 * 3))

            if (no_y_rotation):
                std = self.std_face_id.detach().cpu().numpy().reshape(68, 3)
                std_t_shape = std[self.t_shape_idx, :]
                fake_fls_np = fake_fls_np.reshape((fake_fls_np.shape[0], 68, 3))
                frame_t_shape = fake_fls_np[:, self.t_shape_idx, :]
                from util.icp import icp
                from scipy.spatial.transform import Rotation as R
                # `f` (not `i`) so the batch index is not shadowed.
                for f in range(frame_t_shape.shape[0]):
                    T, distance, itr = icp(frame_t_shape[f], std_t_shape)
                    rot_mat = T[:3, :3]
                    # Keep only the y/z Euler components of the ICP rotation.
                    # (from_matrix/as_matrix are the modern names of the
                    # removed from_dcm/as_dcm.)
                    r = R.from_matrix(rot_mat).as_euler('xyz')
                    r = [0., r[1], r[2]]
                    r = R.from_euler('xyz', r).as_matrix()
                    landmarks = np.hstack(
                        (fake_fls_np[f] - T[:3, 3:4].T, np.ones((68, 1))))
                    T2 = np.hstack((r, T[:3, 3:4]))
                    fake_fls_np[f] = np.dot(T2, landmarks.T).T
                fake_fls_np = fake_fls_np.reshape((-1, 68 * 3))

            filename = 'pred_fls_{}_{}.txt'.format(video_name.split('/')[-1], key)
            np.savetxt(os.path.join(self.opt_parser.output_folder, filename),
                       fake_fls_np, fmt='%.6f')

            # ''' Visualize result in landmarks '''
            if (vis_fls):
                from util.vis import Vis
                # NOTE(review): 'audio_filenam' matches the Vis() keyword as
                # spelled in util.vis — do not "fix" without changing callee.
                Vis(fls=fake_fls_np, filename=video_name.split('/')[-1],
                    fps=62.5,
                    audio_filenam='examples/' + video_name.split('/')[-1] + '.wav')
def __train_pass__(self, epoch, log_loss, is_training=True):
    """Run one train/eval epoch of the speaker-aware model (variant with
    speaker-embedding, registered-landmark and head-pose inputs).

    Per clip: optionally crop a random prefix, pick and register an identity
    face, predict displacement + head pose segment by segment, undo the pose
    for visualization, log losses, and periodically dump/render results.
    """
    st_epoch = time.time()

    # Step 1: init setup — choose model mode, data split and label.
    if (is_training):
        self.G.train()
        self.C.train()
        data = self.train_data
        dataloader = self.train_dataloader
        status = 'TRAIN'
    else:
        self.G.eval()
        self.C.eval()
        data = self.eval_data
        dataloader = self.eval_dataloader
        status = 'EVAL'
    # random_clip_index = np.random.randint(0, len(dataloader)-1, 4)
    # random_clip_index = np.random.randint(0, 64, 4)
    random_clip_index = list(range(len(dataloader)))  # visualize all clips
    print('random_clip_index', random_clip_index)

    # Step 2: train for each batch (one batch == one clip; metadata is
    # fetched from `data` by the same index i).
    for i, batch in enumerate(dataloader):
        # if(i>=64):
        #     break
        st = time.time()
        global_id, video_name = data[i][0][1][0], data[i][0][1][1][:-4]

        # Step 2.1: load batch data from dataloader (in segments)
        inputs_fl, inputs_au, inputs_emb, inputs_reg_fl, inputs_rot_tran, inputs_rot_quat = batch
        # inputs_emb = torch.zeros(size=(inputs_au.shape[0], len(self.test_embs_dic.keys())))
        # this_emb = video_name.split('_x_')[1]
        # inputs_emb[:, self.test_embs_dic[this_emb]] = 1.

        if (is_training):
            # Random temporal crop: drop the same random prefix (up to 20%)
            # from every aligned input stream.
            rand_start = np.random.randint(0, inputs_fl.shape[0] // 5, 1).reshape(-1)
            inputs_fl = inputs_fl[rand_start[0]:]
            inputs_au = inputs_au[rand_start[0]:]
            inputs_emb = inputs_emb[rand_start[0]:]
            inputs_reg_fl = inputs_reg_fl[rand_start[0]:]
            inputs_rot_tran = inputs_rot_tran[rand_start[0]:]
            inputs_rot_quat = inputs_rot_quat[rand_start[0]:]

        inputs_fl, inputs_au, inputs_emb = inputs_fl.to(
            device), inputs_au.to(device), inputs_emb.to(device)
        inputs_reg_fl, inputs_rot_tran, inputs_rot_quat = inputs_reg_fl.to(
            device), inputs_rot_tran.to(device), inputs_rot_quat.to(device)

        std_fls_list, fls_pred_face_id_list, fls_pred_pos_list = [], [], []
        seg_bs = self.opt_parser.segment_batch_size

        # Pick the most-closed-lip frame as the identity face.
        close_fl_list = inputs_fl[::10, 0, :]
        idx = self.__close_face_lip__(close_fl_list.detach().cpu().numpy())
        input_face_id = close_fl_list[idx:idx + 1, :]

        ''' register face '''
        # Optionally ICP-align the identity face to the anchor T-shape.
        if (self.opt_parser.use_reg_as_std):
            landmarks = input_face_id.detach().cpu().numpy().reshape(68, 3)
            frame_t_shape = landmarks[self.t_shape_idx, :]
            T, distance, itr = icp(frame_t_shape, self.anchor_t_shape)
            landmarks = np.hstack((landmarks, np.ones((68, 1))))
            registered_landmarks = np.dot(T, landmarks.T).T
            input_face_id = torch.tensor(registered_landmarks[:, 0:3].reshape(
                1, 204), requires_grad=False, dtype=torch.float).to(device)

        for j in range(0, inputs_fl.shape[0], seg_bs):
            # Step 3.1: load segments
            inputs_fl_segments = inputs_fl[j:j + seg_bs]
            inputs_au_segments = inputs_au[j:j + seg_bs]
            inputs_emb_segments = inputs_emb[j:j + seg_bs]
            inputs_reg_fl_segments = inputs_reg_fl[j:j + seg_bs]
            inputs_rot_tran_segments = inputs_rot_tran[j:j + seg_bs]
            inputs_rot_quat_segments = inputs_rot_quat[j:j + seg_bs]
            # Skip tail segments that are too short.
            if (inputs_fl_segments.shape[0] < 10):
                continue
            if (self.opt_parser.test_emb):
                input_face_id = self.std_face_id

            fl_dis_pred_pos, pos_pred, input_face_id, (loss, loss_reg_fls, loss_laplacian, loss_pos) = \
                self.__train_speaker_aware__(inputs_fl_segments, inputs_au_segments,
                                             inputs_emb_segments, input_face_id,
                                             inputs_reg_fl_segments, inputs_rot_tran_segments,
                                             inputs_rot_quat_segments, is_training=is_training,
                                             use_residual=self.opt_parser.use_residual)
            fl_dis_pred_pos = fl_dis_pred_pos.data.cpu().numpy()
            pos_pred = pos_pred.data.cpu().numpy()
            fl_std = inputs_reg_fl_segments[:, 0, :].data.cpu().numpy()
            pos_std = inputs_rot_tran_segments[:, 0, :].data.cpu().numpy()

            ''' solve inverse lip '''
            if (not is_training):
                fl_dis_pred_pos = self.__solve_inverse_lip2__(
                    fl_dis_pred_pos)
            fl_dis_pred_pos = fl_dis_pred_pos.reshape((-1, 68, 3))
            fl_std = fl_std.reshape((-1, 68, 3))

            if (self.opt_parser.pos_dim == 12):
                # Pose predicted as a 3x4 affine: undo it per frame.
                pos_pred = pos_pred.reshape((-1, 3, 4))
                for k in range(fl_dis_pred_pos.shape[0]):
                    fl_dis_pred_pos[k] = np.dot(
                        pos_pred[k, :3, :3].T + np.eye(3),
                        (fl_dis_pred_pos[k] - pos_pred[k, :, 3].T).T).T
                pos_std = pos_std.reshape((-1, 3, 4))
                for k in range(fl_std.shape[0]):
                    fl_std[k] = np.dot(pos_std[k, :3, :3].T + np.eye(3),
                                       (fl_std[k] - pos_std[k, :, 3].T).T).T
            else:
                # Pose predicted as quaternion (4) + translation.
                if (not is_training):
                    # Temporal smoothing only at eval (odd window <= 27).
                    smooth_length = int(
                        min(pos_pred.shape[0] - 1, 27) // 2 * 2 + 1)
                    pos_pred = savgol_filter(pos_pred, smooth_length, 3, axis=0)
                quat = pos_pred[:, :4]
                trans = pos_pred[:, 4:]
                for k in range(fl_dis_pred_pos.shape[0]):
                    fl_dis_pred_pos[k] = np.dot(
                        R.from_quat(quat[k]).as_matrix().T,
                        (fl_dis_pred_pos[k] - trans[k:k + 1]).T).T
                pos_std = pos_std.reshape((-1, 3, 4))
                for k in range(fl_std.shape[0]):
                    fl_std[k] = np.dot(pos_std[k, :3, :3].T + np.eye(3),
                                       (fl_std[k] - pos_std[k, :, 3].T).T).T

            fls_pred_pos_list += [fl_dis_pred_pos.reshape((-1, 204))]
            std_fls_list += [fl_std.reshape((-1, 204))]

            # NOTE: log_loss keys are matched against *local variable names*
            # via locals() (e.g. 'loss', 'loss_reg_fls', 'loss_laplacian',
            # 'loss_pos' above). Renaming those locals silently stops logging.
            for key in log_loss.keys():
                if (key not in locals().keys()):
                    continue
                if (type(locals()[key]) == float):
                    log_loss[key].add(locals()[key])
                else:
                    log_loss[key].add(locals()[key].data.cpu().numpy())

        # Periodic visualization of selected clips.
        if (epoch % self.opt_parser.jpg_freq == 0 and i in random_clip_index):
            def save_fls_av(fake_fls_list, postfix='', ifsmooth=True):
                # Dump the landmark sequence to txt and render it with audio.
                fake_fls_np = np.concatenate(fake_fls_list)
                filename = 'fake_fls_{}_{}_{}.txt'.format(
                    epoch, video_name, postfix)
                np.savetxt(os.path.join(self.opt_parser.dump_dir, '../nn_result',
                                        self.opt_parser.name, filename),
                           fake_fls_np, fmt='%.6f')
                audio_filename = '{:05d}_{}_audio.wav'.format(
                    global_id, video_name)
                from util.vis import Vis_old
                Vis_old(run_name=self.opt_parser.name, pred_fl_filename=filename,
                        audio_filename=audio_filename,
                        fps=62.5, av_name='e{:04d}_{}_{}'.format(epoch, i, postfix),
                        postfix=postfix, root_dir=self.opt_parser.root_dir,
                        ifsmooth=ifsmooth)
            if (True):
                if (self.opt_parser.show_animation):
                    print('show animation ....')
                save_fls_av(fls_pred_pos_list, 'pred', ifsmooth=True)
                save_fls_av(std_fls_list, 'std', ifsmooth=False)

        # Per-batch progress report.
        if (self.opt_parser.verbose <= 1):
            print('{} Epoch: #{} batch #{}/{}'.format(
                status, epoch, i, len(dataloader)), end=': ')
            for key in log_loss.keys():
                print(key, '{:.5f}'.format(log_loss[key].per('batch')), end=', ')
            print('')
        self.__tensorboard_write__(status, log_loss, 'batch')

    # Epoch summary + checkpointing.
    if (self.opt_parser.verbose <= 2):
        print('==========================================================')
        print('{} Epoch: #{}'.format(status, epoch), end=':')
        for key in log_loss.keys():
            print(key, '{:.4f}'.format(log_loss[key].per('epoch')), end=', ')
        print(
            'Epoch time usage: {:.2f} sec\n==========================================================\n'
            .format(time.time() - st_epoch))
    self.__save_model__(save_type='last_epoch', epoch=epoch)
    if (epoch % self.opt_parser.ckpt_epoch_freq == 0):
        self.__save_model__(save_type='e_{}'.format(epoch), epoch=epoch)
    self.__tensorboard_write__(status, log_loss, 'epoch')