def get_quats_sos_and_eos(self):
    # data_path is assumed to be available in the enclosing scope (e.g., a module-level constant).
    quats_sos_and_eos_file = os.path.join(data_path, 'quats_sos_and_eos.npz')
    keys = list(self.data_loader['train'].keys())
    num_samples = len(self.data_loader['train'])
    try:
        mean_quats_sos = np.load(quats_sos_and_eos_file, allow_pickle=True)['quats_sos']
        mean_quats_eos = np.load(quats_sos_and_eos_file, allow_pickle=True)['quats_eos']
    except FileNotFoundError:
        mean_quats_sos = np.zeros((self.V, self.D))
        mean_quats_eos = np.zeros((self.V, self.D))
        for j in range(self.V):
            quats_sos = np.zeros((self.D, num_samples))
            quats_eos = np.zeros((self.D, num_samples))
            for s in range(num_samples):
                quats_sos[:, s] = self.data_loader['train'][keys[s]]['rotations'][0, j]
                quats_eos[:, s] = self.data_loader['train'][keys[s]]['rotations'][-1, j]
            # Mean quaternion per joint = eigenvector of the largest eigenvalue of the
            # accumulator matrix Q Q^T. np.linalg.eig does not sort eigenvalues, so the
            # dominant eigenvector is selected explicitly instead of taking column 0.
            sos_eig_values, sos_eig_vectors = np.linalg.eig(np.dot(quats_sos, quats_sos.T))
            mean_quats_sos[j] = np.real(sos_eig_vectors[:, np.argmax(np.real(sos_eig_values))])
            eos_eig_values, eos_eig_vectors = np.linalg.eig(np.dot(quats_eos, quats_eos.T))
            mean_quats_eos[j] = np.real(eos_eig_vectors[:, np.argmax(np.real(eos_eig_values))])
        np.savez_compressed(quats_sos_and_eos_file,
                            quats_sos=mean_quats_sos,
                            quats_eos=mean_quats_eos)
    mean_quats_sos = torch.from_numpy(mean_quats_sos).unsqueeze(0)
    mean_quats_eos = torch.from_numpy(mean_quats_eos).unsqueeze(0)
    # Prepend/append the SOS/EOS poses (and their affective features) to every training sample.
    for s in range(num_samples):
        pos_sos = MocapDataset.forward_kinematics(
            mean_quats_sos.unsqueeze(0),
            torch.from_numpy(self.data_loader['train'][keys[s]]['positions'][0:1, 0]).double().unsqueeze(0),
            self.joint_parents,
            torch.from_numpy(self.data_loader['train'][keys[s]]['joints_dict']
                             ['joints_offsets_all']).unsqueeze(0)).squeeze(0).numpy()
        affs_sos = MocapDataset.get_mpi_affective_features(pos_sos)
        pos_eos = MocapDataset.forward_kinematics(
            mean_quats_eos.unsqueeze(0),
            torch.from_numpy(self.data_loader['train'][keys[s]]['positions'][-1:, 0]).double().unsqueeze(0),
            self.joint_parents,
            torch.from_numpy(self.data_loader['train'][keys[s]]['joints_dict']
                             ['joints_offsets_all']).unsqueeze(0)).squeeze(0).numpy()
        affs_eos = MocapDataset.get_mpi_affective_features(pos_eos)
        self.data_loader['train'][keys[s]]['positions'] = np.concatenate(
            (pos_sos, self.data_loader['train'][keys[s]]['positions'], pos_eos), axis=0)
        self.data_loader['train'][keys[s]]['affective_features'] = np.concatenate(
            (affs_sos, self.data_loader['train'][keys[s]]['affective_features'], affs_eos), axis=0)
    return mean_quats_sos, mean_quats_eos

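# A minimal, self-contained sketch of the averaging step used above, for reference only:
# the mean of a set of unit quaternions can be taken as the eigenvector associated with
# the largest eigenvalue of the accumulator matrix sum_i q_i q_i^T (eigenvector-based
# quaternion averaging). The helper name `_average_unit_quaternions` is illustrative and
# is not used elsewhere in this codebase.
def _average_unit_quaternions(quats):
    """quats: (N, 4) array of unit quaternions. Returns a (4,) unit quaternion."""
    accumulator = quats.T @ quats                           # 4 x 4 symmetric PSD matrix
    eig_values, eig_vectors = np.linalg.eigh(accumulator)   # eigenvalues in ascending order
    mean_quat = eig_vectors[:, -1]                          # dominant eigenvector
    # Note: the overall sign is arbitrary, since q and -q encode the same rotation.
    return mean_quat / np.linalg.norm(mean_quat)
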
def compute_next_traj_point_sans_markers(self, pos_last, quat_next, z_pred, rs_pred):
    # Computes the next (x, z) root trajectory point from the predicted root speed rs_pred
    # and the body facing direction estimated from the shoulder joints.
    # pos_next = torch.zeros_like(pos_last)
    # offsets is only needed by the commented-out per-joint FK loop below.
    offsets = torch.from_numpy(self.mocap.joint_offsets).cuda().float(). \
        unsqueeze(0).unsqueeze(0).repeat(pos_last.shape[0], pos_last.shape[1], 1, 1)
    pos_next = MocapDataset.forward_kinematics(
        quat_next.contiguous().view(quat_next.shape[0], quat_next.shape[1], -1, self.D),
        pos_last[:, :, 0], self.joint_parents,
        torch.from_numpy(self.joint_offsets).float().cuda())
    # for joint in range(1, self.V):
    #     pos_next[:, :, joint] = qrot(quat_copy[:, :, joint - 1], offsets[:, :, joint]) \
    #                             + pos_next[:, :, self.mocap.joint_parents[joint]]
    root = pos_next[:, :, 0]
    l_shoulder = pos_next[:, :, 18]
    r_shoulder = pos_next[:, :, 25]
    # Facing direction on the ground plane (x, z), normalized.
    facing = torch.cross(l_shoulder - root, r_shoulder - root, dim=-1)[..., [0, 2]]
    facing /= (torch.norm(facing, dim=-1)[..., None] + 1e-9)
    return rs_pred * facing + pos_last[:, :, 0, [0, 2]]

def generate_motion(self, load_saved_model=True, samples_to_generate=10,
                    epoch=None, randomized=True, animations_as_videos=False):
    if epoch is None:
        epoch = 'best'
    if load_saved_model:
        self.load_model_at_epoch(epoch=epoch)
    self.model.eval()
    test_loader = self.data_loader['test']

    joint_offsets, pos, affs, quat, quat_valid_idx, \
        text, text_valid_idx, intended_emotion, intended_polarity, \
        acting_task, gender, age, handedness, \
        native_tongue = self.return_batch([samples_to_generate], test_loader, randomized=randomized)
    with torch.no_grad():
        joint_lengths = torch.norm(joint_offsets, dim=-1)
        scales, _ = torch.max(joint_lengths, dim=-1)
        quat_prelude = self.quats_eos.view(1, -1).cuda() \
            .repeat(self.T - 1, 1).unsqueeze(0).repeat(quat.shape[0], 1, 1).float()
        quat_prelude[:, -1] = quat[:, 0].clone()
        quat_pred, quat_pred_pre_norm = self.model(
            text, intended_emotion, intended_polarity, acting_task,
            gender, age, handedness, native_tongue, quat_prelude,
            joint_lengths / scales[..., None])

        for s in range(len(quat_pred)):
            quat_pred[s] = qfix(
                quat_pred[s].view(quat_pred[s].shape[0], self.V, -1)).view(quat_pred[s].shape[0], -1)

        root_pos = torch.zeros(quat_pred.shape[0], quat_pred.shape[1], self.C).cuda()
        pos_pred = MocapDataset.forward_kinematics(
            quat_pred.contiguous().view(quat_pred.shape[0], quat_pred.shape[1], -1, self.D),
            root_pos, self.joint_parents,
            torch.cat((root_pos[:, 0:1], joint_offsets), dim=1).unsqueeze(1))

    animation_pred = {
        'joint_names': self.joint_names,
        'joint_offsets': joint_offsets,
        'joint_parents': self.joint_parents,
        'positions': pos_pred,
        'rotations': quat_pred
    }
    MocapDataset.save_as_bvh(animation_pred,
                             dataset_name=self.dataset + '_glove',
                             subset_name='test')
    animation = {
        'joint_names': self.joint_names,
        'joint_offsets': joint_offsets,
        'joint_parents': self.joint_parents,
        'positions': pos,
        'rotations': quat
    }
    MocapDataset.save_as_bvh(animation,
                             dataset_name=self.dataset + '_glove',
                             subset_name='gt')
    if animations_as_videos:
        pos_pred_np = pos_pred.contiguous().view(pos_pred.shape[0], pos_pred.shape[1], -1).permute(0, 2, 1). \
            detach().cpu().numpy()
        display_animations(pos_pred_np, self.joint_parents, save=True,
                           dataset_name=self.dataset,
                           subset_name='epoch_' + str(self.best_loss_epoch),
                           overwrite=True)

def per_eval(self):
    self.model.eval()
    test_loader = self.data_loader['test']
    eval_loss = 0.
    N = 0.

    for joint_offsets, pos, affs, quat, quat_valid_idx, \
            text, text_valid_idx, intended_emotion, intended_polarity, \
            acting_task, gender, age, handedness, \
            native_tongue in self.yield_batch(self.args.batch_size, test_loader):
        with torch.no_grad():
            joint_lengths = torch.norm(joint_offsets, dim=-1)
            scales, _ = torch.max(joint_lengths, dim=-1)
            quat_prelude = self.quats_eos.view(1, -1).cuda() \
                .repeat(self.T - 1, 1).unsqueeze(0).repeat(quat.shape[0], 1, 1).float()
            quat_prelude[:, -1] = quat[:, 0].clone()
            quat_pred, quat_pred_pre_norm = self.model(
                text, intended_emotion, intended_polarity, acting_task,
                gender, age, handedness, native_tongue, quat_prelude,
                joint_lengths / scales[..., None])

            quat_pred_pre_norm = quat_pred_pre_norm.view(
                quat_pred_pre_norm.shape[0], quat_pred_pre_norm.shape[1], -1, self.D)
            quat_norm_loss = self.args.quat_norm_reg * \
                torch.mean((torch.sum(quat_pred_pre_norm ** 2, dim=-1) - 1) ** 2)

            quat_loss, quat_derv_loss = losses.quat_angle_loss(
                quat_pred, quat[:, 1:], quat_valid_idx[:, 1:], self.V, self.D,
                self.lower_body_start, self.args.upper_body_weight)
            quat_loss *= self.args.quat_reg

            root_pos = torch.zeros(quat_pred.shape[0], quat_pred.shape[1], self.C).cuda()
            pos_pred = MocapDataset.forward_kinematics(
                quat_pred.contiguous().view(quat_pred.shape[0], quat_pred.shape[1], -1, self.D),
                root_pos, self.joint_parents,
                torch.cat((root_pos[:, 0:1], joint_offsets), dim=1).unsqueeze(1))
            affs_pred = MocapDataset.get_mpi_affective_features(pos_pred)

            row_sums = quat_valid_idx.sum(1, keepdim=True) * self.D * self.V
            row_sums[row_sums == 0.] = 1.

            shifted_pos = pos - pos[:, :, 0:1]
            shifted_pos_pred = pos_pred - pos_pred[:, :, 0:1]

            recons_loss = self.recons_loss_func(shifted_pos_pred, shifted_pos[:, 1:])
            # recons_loss = torch.abs(shifted_pos_pred[:, 1:] - shifted_pos[:, 1:]).sum(-1)
            # recons_loss = self.args.upper_body_weight * (recons_loss[:, :, :self.lower_body_start].sum(-1)) + \
            #     recons_loss[:, :, self.lower_body_start:].sum(-1)
            # recons_loss = self.args.recons_reg * torch.mean(
            #     (recons_loss * quat_valid_idx[:, 1:]).sum(-1) / row_sums)
            #
            # recons_derv_loss = torch.abs(shifted_pos_pred[:, 2:] - shifted_pos_pred[:, 1:-1] -
            #                              shifted_pos[:, 2:] + shifted_pos[:, 1:-1]).sum(-1)
            # recons_derv_loss = self.args.upper_body_weight * \
            #     (recons_derv_loss[:, :, :self.lower_body_start].sum(-1)) + \
            #     recons_derv_loss[:, :, self.lower_body_start:].sum(-1)
            # recons_derv_loss = 2. * self.args.recons_reg * \
            #     torch.mean((recons_derv_loss * quat_valid_idx[:, 2:]).sum(-1) / row_sums)
            #
            # affs_loss = torch.abs(affs[:, 1:] - affs_pred[:, 1:]).sum(-1)
            # affs_loss = self.args.affs_reg * torch.mean((affs_loss * quat_valid_idx[:, 1:]).sum(-1) / row_sums)
            affs_loss = self.affs_loss_func(affs_pred, affs[:, 1:])

            eval_loss += quat_norm_loss + quat_loss + recons_loss + affs_loss
            # eval_loss += quat_norm_loss + quat_loss + recons_loss + recons_derv_loss + affs_loss
            N += quat.shape[0]

    eval_loss /= N
    self.epoch_info['mean_loss'] = eval_loss
    if self.epoch_info['mean_loss'] < self.best_loss and \
            self.meta_info['epoch'] > self.min_train_epochs:
        self.best_loss = self.epoch_info['mean_loss']
        self.best_loss_epoch = self.meta_info['epoch']
        self.loss_updated = True
    else:
        self.loss_updated = False
    self.show_epoch_info()

def per_train(self):
    self.model.train()
    train_loader = self.data_loader['train']
    batch_loss = 0.
    N = 0.

    for joint_offsets, pos, affs, quat, quat_valid_idx, \
            text, text_valid_idx, intended_emotion, intended_polarity, \
            acting_task, gender, age, handedness, \
            native_tongue in self.yield_batch(self.args.batch_size, train_loader):
        quat_prelude = self.quats_eos.view(1, -1).cuda() \
            .repeat(self.T - 1, 1).unsqueeze(0).repeat(quat.shape[0], 1, 1).float()
        quat_prelude[:, -1] = quat[:, 0].clone()

        self.optimizer.zero_grad()
        with torch.autograd.detect_anomaly():
            joint_lengths = torch.norm(joint_offsets, dim=-1)
            scales, _ = torch.max(joint_lengths, dim=-1)
            quat_pred, quat_pred_pre_norm = self.model(
                text, intended_emotion, intended_polarity, acting_task,
                gender, age, handedness, native_tongue, quat_prelude,
                joint_lengths / scales[..., None])

            quat_pred_pre_norm = quat_pred_pre_norm.view(
                quat_pred_pre_norm.shape[0], quat_pred_pre_norm.shape[1], -1, self.D)
            quat_norm_loss = self.args.quat_norm_reg * \
                torch.mean((torch.sum(quat_pred_pre_norm ** 2, dim=-1) - 1) ** 2)

            quat_loss, quat_derv_loss = losses.quat_angle_loss(
                quat_pred, quat[:, 1:], quat_valid_idx[:, 1:], self.V, self.D,
                self.lower_body_start, self.args.upper_body_weight)
            quat_loss *= self.args.quat_reg

            root_pos = torch.zeros(quat_pred.shape[0], quat_pred.shape[1], self.C).cuda()
            pos_pred = MocapDataset.forward_kinematics(
                quat_pred.contiguous().view(quat_pred.shape[0], quat_pred.shape[1], -1, self.D),
                root_pos, self.joint_parents,
                torch.cat((root_pos[:, 0:1], joint_offsets), dim=1).unsqueeze(1))
            affs_pred = MocapDataset.get_mpi_affective_features(pos_pred)

            # row_sums = quat_valid_idx.sum(1, keepdim=True) * self.D * self.V
            # row_sums[row_sums == 0.] = 1.

            shifted_pos = pos - pos[:, :, 0:1]
            shifted_pos_pred = pos_pred - pos_pred[:, :, 0:1]

            recons_loss = self.recons_loss_func(shifted_pos_pred, shifted_pos[:, 1:])
            # recons_loss = torch.abs(shifted_pos_pred - shifted_pos[:, 1:]).sum(-1)
            # recons_loss = self.args.upper_body_weight * (recons_loss[:, :, :self.lower_body_start].sum(-1)) + \
            #     recons_loss[:, :, self.lower_body_start:].sum(-1)
            # recons_loss = self.args.recons_reg * \
            #     torch.mean((recons_loss * quat_valid_idx[:, 1:]).sum(-1) / row_sums)
            #
            # recons_derv_loss = torch.abs(shifted_pos_pred[:, 1:] - shifted_pos_pred[:, :-1] -
            #                              shifted_pos[:, 2:] + shifted_pos[:, 1:-1]).sum(-1)
            # recons_derv_loss = self.args.upper_body_weight * \
            #     (recons_derv_loss[:, :, :self.lower_body_start].sum(-1)) + \
            #     recons_derv_loss[:, :, self.lower_body_start:].sum(-1)
            # recons_derv_loss = 2. * self.args.recons_reg * \
            #     torch.mean((recons_derv_loss * quat_valid_idx[:, 2:]).sum(-1) / row_sums)
            #
            # affs_loss = torch.abs(affs[:, 1:] - affs_pred).sum(-1)
            # affs_loss = self.args.affs_reg * torch.mean((affs_loss * quat_valid_idx[:, 1:]).sum(-1) / row_sums)
            affs_loss = self.affs_loss_func(affs_pred, affs[:, 1:])

            train_loss = quat_norm_loss + quat_loss + recons_loss + affs_loss
            # train_loss = quat_norm_loss + quat_loss + recons_loss + recons_derv_loss + affs_loss
            train_loss.backward()
            # nn.utils.clip_grad_norm_(self.model.parameters(), self.args.gradient_clip)
            self.optimizer.step()

        # animation_pred = {
        #     'joint_names': self.joint_names,
        #     'joint_offsets': joint_offsets,
        #     'joint_parents': self.joint_parents,
        #     'positions': pos_pred,
        #     'rotations': quat_pred
        # }
        # MocapDataset.save_as_bvh(animation_pred,
        #                          dataset_name=self.dataset,
        #                          subset_name='test')
        # animation = {
        #     'joint_names': self.joint_names,
        #     'joint_offsets': joint_offsets,
        #     'joint_parents': self.joint_parents,
        #     'positions': pos,
        #     'rotations': quat
        # }
        # MocapDataset.save_as_bvh(animation,
        #                          dataset_name=self.dataset,
        #                          subset_name='gt')

        # Compute statistics
        batch_loss += train_loss.item()
        N += quat.shape[0]

        # statistics
        self.iter_info['loss'] = train_loss.data.item()
        self.iter_info['lr'] = '{:.6f}'.format(self.lr)
        self.iter_info['tf'] = '{:.6f}'.format(self.tf)
        self.show_iter_info()
        self.meta_info['iter'] += 1

    batch_loss = batch_loss / N
    self.epoch_info['mean_loss'] = batch_loss
    self.show_epoch_info()
    self.io.print_timer()
    self.adjust_lr()
    self.adjust_tf()

def generate_motion(self, load_saved_model=True, samples_to_generate=10,
                    randomized=True, epoch=None, animations_as_videos=False):
    if epoch is None:
        epoch = 'best'
    if load_saved_model:
        self.load_model_at_epoch(epoch=epoch)
    self.model.eval()
    test_loader = self.data_loader['test']

    start_time = time.time()
    joint_offsets, pos, affs, quat, quat_valid_idx, \
        text, text_valid_idx, perceived_emotion, perceived_polarity, \
        acting_task, gender, age, handedness, \
        native_tongue = self.return_batch([samples_to_generate], test_loader, randomized=randomized)
    with torch.no_grad():
        joint_lengths = torch.norm(joint_offsets, dim=-1)
        scales, _ = torch.max(joint_lengths, dim=-1)
        quat_pred = torch.zeros_like(quat)
        quat_pred[:, 0] = torch.cat(
            quat_pred.shape[0] * [self.quats_sos]).view(quat_pred[:, 0].shape)

        quat_pred, quat_pred_pre_norm = self.model(
            text, perceived_emotion, perceived_polarity, acting_task,
            gender, age, handedness, native_tongue, quat[:, :-1],
            joint_lengths / scales[..., None])

        # text_latent = self.model(text, intended_emotion, intended_polarity,
        #                          acting_task, gender, age, handedness, native_tongue, only_encoder=True)
        # for t in range(1, self.T):
        #     quat_pred_curr, _ = self.model(text_latent, quat=quat_pred[:, 0:t],
        #                                    offset_lengths=joint_lengths / scales[..., None],
        #                                    only_decoder=True)
        #     quat_pred[:, t:t + 1] = quat_pred_curr[:, -1:].clone()

        # for s in range(len(quat_pred)):
        #     quat_pred[s] = qfix(quat_pred[s].view(quat_pred[s].shape[0],
        #                                           self.V, -1)).view(quat_pred[s].shape[0], -1)
        quat_pred = torch.cat((quat[:, 1:2], quat_pred), dim=1)
        quat_pred = qfix(
            quat_pred.view(quat_pred.shape[0], quat_pred.shape[1],
                           self.V, -1)).view(quat_pred.shape[0], quat_pred.shape[1], -1)
        quat_pred = quat_pred[:, 1:]

        root_pos = torch.zeros(quat_pred.shape[0], quat_pred.shape[1], self.C).cuda()
        pos_pred = MocapDataset.forward_kinematics(
            quat_pred.contiguous().view(quat_pred.shape[0], quat_pred.shape[1], -1, self.D),
            root_pos, self.joint_parents,
            torch.cat((root_pos[:, 0:1], joint_offsets), dim=1).unsqueeze(1))

    quat_np = quat.detach().cpu().numpy()
    quat_pred_np = quat_pred.detach().cpu().numpy()

    animation_pred = {
        'joint_names': self.joint_names,
        'joint_offsets': joint_offsets,
        'joint_parents': self.joint_parents,
        'positions': pos_pred,
        'rotations': quat_pred,
        'valid_idx': quat_valid_idx
    }
    MocapDataset.save_as_bvh(animation_pred,
                             dataset_name=self.dataset,
                             subset_name='test_epoch_{}'.format(epoch),
                             include_default_pose=False)
    end_time = time.time()
    print('Time taken: {} secs.'.format(end_time - start_time))

    shifted_pos = pos - pos[:, :, 0:1]
    animation = {
        'joint_names': self.joint_names,
        'joint_offsets': joint_offsets,
        'joint_parents': self.joint_parents,
        'positions': shifted_pos,
        'rotations': quat,
        'valid_idx': quat_valid_idx
    }
    MocapDataset.save_as_bvh(animation,
                             dataset_name=self.dataset,
                             subset_name='gt',
                             include_default_pose=False)
    if animations_as_videos:
        pos_pred_np = pos_pred.contiguous().view(pos_pred.shape[0], pos_pred.shape[1], -1).permute(0, 2, 1). \
            detach().cpu().numpy()
        display_animations(pos_pred_np, self.joint_parents, save=True,
                           dataset_name=self.dataset,
                           subset_name='epoch_' + str(self.best_loss_epoch),
                           overwrite=True)

def forward_pass(self, joint_offsets, pos, affs, quat, quat_sos, quat_valid_idx,
                 text, text_valid_idx, perceived_emotion, perceived_polarity,
                 acting_task, gender, age, handedness, native_tongue):
    self.optimizer.zero_grad()
    with torch.autograd.detect_anomaly():
        joint_lengths = torch.norm(joint_offsets, dim=-1)
        scales, _ = torch.max(joint_lengths, dim=-1)
        # text_latent = self.model(text, perceived_emotion, perceived_polarity,
        #                          acting_task, gender, age, handedness, native_tongue,
        #                          only_encoder=True)
        # quat_pred = torch.zeros_like(quat)
        # quat_pred_pre_norm = torch.zeros_like(quat)
        # quat_in = quat_sos[:, :self.T_steps]
        # quat_valid_idx_max = torch.max(torch.sum(quat_valid_idx, dim=-1))
        # for t in range(0, self.T, self.T_steps):
        #     if t > quat_valid_idx_max:
        #         break
        #     quat_pred[:, t:min(self.T, t + self.T_steps)], \
        #         quat_pred_pre_norm[:, t:min(self.T, t + self.T_steps)] = \
        #         self.model(text_latent, quat=quat_in,
        #                    offset_lengths=joint_lengths / scales[..., None],
        #                    only_decoder=True)
        #     if torch.rand(1) > self.tf:
        #         if t + self.T_steps * 2 >= self.T:
        #             quat_in = quat_pred[:, -(self.T - t - self.T_steps):].clone()
        #         else:
        #             quat_in = quat_pred[:, t:t + self.T_steps].clone()
        #     else:
        #         if t + self.T_steps * 2 >= self.T:
        #             quat_in = quat[:, -(self.T - t - self.T_steps):].clone()
        #         else:
        #             quat_in = quat[:, t:min(self.T, t + self.T_steps)].clone()
        quat_pred, quat_pred_pre_norm = self.model(
            text, perceived_emotion, perceived_polarity, acting_task,
            gender, age, handedness, native_tongue, quat_sos[:, :-1],
            joint_lengths / scales[..., None])

        quat_fixed = qfix(quat.contiguous().view(
            quat.shape[0], quat.shape[1], -1, self.D)).contiguous().view(quat.shape[0], quat.shape[1], -1)
        quat_pred = qfix(quat_pred.contiguous().view(
            quat_pred.shape[0], quat_pred.shape[1], -1,
            self.D)).contiguous().view(quat_pred.shape[0], quat_pred.shape[1], -1)

        quat_pred_pre_norm = quat_pred_pre_norm.view(
            quat_pred_pre_norm.shape[0], quat_pred_pre_norm.shape[1], -1, self.D)
        quat_norm_loss = self.args.quat_norm_reg * \
            torch.mean((torch.sum(quat_pred_pre_norm ** 2, dim=-1) - 1) ** 2)

        # quat_loss, quat_derv_loss = losses.quat_angle_loss(quat_pred, quat_fixed,
        #                                                    quat_valid_idx[:, 1:],
        #                                                    self.V, self.D,
        #                                                    self.lower_body_start,
        #                                                    self.args.upper_body_weight)
        quat_loss, quat_derv_loss = losses.quat_angle_loss(
            quat_pred, quat_fixed[:, 1:], quat_valid_idx[:, 1:], self.V, self.D,
            self.lower_body_start, self.args.upper_body_weight)
        quat_loss *= self.args.quat_reg

        root_pos = torch.zeros(quat_pred.shape[0], quat_pred.shape[1], self.C).cuda()
        pos_pred = MocapDataset.forward_kinematics(
            quat_pred.contiguous().view(quat_pred.shape[0], quat_pred.shape[1], -1, self.D),
            root_pos, self.joint_parents,
            torch.cat((root_pos[:, 0:1], joint_offsets), dim=1).unsqueeze(1))
        affs_pred = MocapDataset.get_mpi_affective_features(pos_pred)

        row_sums = quat_valid_idx.sum(1, keepdim=True) * self.D * self.V
        row_sums[row_sums == 0.] = 1.

        shifted_pos = pos - pos[:, :, 0:1]
        shifted_pos_pred = pos_pred - pos_pred[:, :, 0:1]

        # recons_loss = self.recons_loss_func(shifted_pos_pred, shifted_pos)
        # recons_arms = self.recons_loss_func(shifted_pos_pred[:, :, 7:15], shifted_pos[:, :, 7:15])
        recons_loss = self.recons_loss_func(shifted_pos_pred, shifted_pos[:, 1:])
        # recons_arms = self.recons_loss_func(shifted_pos_pred[:, :, 7:15], shifted_pos[:, 1:, 7:15])
        # recons_loss = torch.abs(shifted_pos_pred - shifted_pos[:, 1:]).sum(-1)
        # recons_loss = self.args.upper_body_weight * (recons_loss[:, :, :self.lower_body_start].sum(-1)) + \
        #     recons_loss[:, :, self.lower_body_start:].sum(-1)
        recons_loss = self.args.recons_reg * \
            torch.mean((recons_loss * quat_valid_idx[:, 1:]).sum(-1) / row_sums)
        #
        # recons_derv_loss = torch.abs(shifted_pos_pred[:, 1:] - shifted_pos_pred[:, :-1] -
        #                              shifted_pos[:, 2:] + shifted_pos[:, 1:-1]).sum(-1)
        # recons_derv_loss = self.args.upper_body_weight * \
        #     (recons_derv_loss[:, :, :self.lower_body_start].sum(-1)) + \
        #     recons_derv_loss[:, :, self.lower_body_start:].sum(-1)
        # recons_derv_loss = 2. * self.args.recons_reg * \
        #     torch.mean((recons_derv_loss * quat_valid_idx[:, 2:]).sum(-1) / row_sums)
        #
        # affs_loss = torch.abs(affs[:, 1:] - affs_pred).sum(-1)
        affs_loss = self.affs_loss_func(affs_pred, affs[:, 1:])
        affs_loss = self.args.affs_reg * torch.mean(
            (affs_loss * quat_valid_idx[:, 1:]).sum(-1) / row_sums)
        # affs_loss = self.affs_loss_func(affs_pred, affs)

        total_loss = quat_norm_loss + quat_loss + recons_loss + affs_loss
        # train_loss = quat_norm_loss + quat_loss + recons_loss + recons_derv_loss + affs_loss
    return total_loss

def generate_motion(self, load_saved_model=True, samples_to_generate=10, max_steps=300, randomized=True):
    if load_saved_model:
        self.load_best_model()
    self.model.eval()
    test_loader = self.data_loader['test']

    joint_offsets, pos, affs, quat, quat_valid_idx, \
        text, text_valid_idx, intended_emotion, intended_polarity, \
        acting_task, gender, age, handedness, \
        native_tongue = self.return_batch([samples_to_generate], test_loader, randomized=randomized)
    with torch.no_grad():
        joint_lengths = torch.norm(joint_offsets, dim=-1)
        scales, _ = torch.max(joint_lengths, dim=-1)
        quat_pred = torch.zeros_like(quat)
        quat_pred_pre_norm = torch.zeros_like(quat)
        quat_pred[:, 0] = torch.cat(
            quat_pred.shape[0] * [self.quats_sos]).view(quat_pred[:, 0].shape)

        # Encode the text and subject attributes once, then decode the rotations autoregressively.
        text_latent = self.model(text, intended_emotion, intended_polarity,
                                 acting_task, gender, age, handedness, native_tongue,
                                 only_encoder=True)
        for t in range(1, self.T):
            quat_pred_curr, quat_pred_pre_norm_curr = \
                self.model(text_latent, quat=quat_pred[:, 0:t],
                           offset_lengths=joint_lengths / scales[..., None],
                           only_decoder=True)
            quat_pred[:, t:t + 1] = quat_pred_curr[:, -1:].clone()
            # Store the pre-normalization prediction for frame t (was mistakenly written back
            # into quat_pred_pre_norm_curr instead of the accumulator tensor).
            quat_pred_pre_norm[:, t:t + 1] = quat_pred_pre_norm_curr[:, -1:].clone()

        for s in range(len(quat_pred)):
            quat_pred[s] = qfix(
                quat_pred[s].view(quat_pred[s].shape[0], self.V, -1)).view(quat_pred[s].shape[0], -1)

        root_pos = torch.zeros(quat_pred.shape[0], quat_pred.shape[1], self.C).cuda()
        pos_pred = MocapDataset.forward_kinematics(
            quat_pred.contiguous().view(quat_pred.shape[0], quat_pred.shape[1], -1, self.D),
            root_pos, self.joint_parents,
            torch.cat((root_pos[:, 0:1], joint_offsets), dim=1).unsqueeze(1))

    animation_pred = {
        'joint_names': self.joint_names,
        'joint_offsets': joint_offsets,
        'joint_parents': self.joint_parents,
        'positions': pos_pred,
        'rotations': quat_pred
    }
    MocapDataset.save_as_bvh(animation_pred,
                             dataset_name=self.dataset + '_new',
                             subset_name='test')
    animation = {
        'joint_names': self.joint_names,
        'joint_offsets': joint_offsets,
        'joint_parents': self.joint_parents,
        'positions': pos,
        'rotations': quat
    }
    MocapDataset.save_as_bvh(animation,
                             dataset_name=self.dataset + '_new',
                             subset_name='gt')
    pos_pred_np = pos_pred.contiguous().view(pos_pred.shape[0], pos_pred.shape[1], -1).permute(0, 2, 1). \
        detach().cpu().numpy()