Example #1
 def get_quats_sos_and_eos(self):
     # data_path is assumed to be available from the enclosing scope.
     quats_sos_and_eos_file = os.path.join(data_path,
                                           'quats_sos_and_eos.npz')
     keys = list(self.data_loader['train'].keys())
     num_samples = len(self.data_loader['train'])
     try:
         # Load the cached SOS/EOS quaternions if they have already been computed.
         quats_sos_and_eos = np.load(quats_sos_and_eos_file, allow_pickle=True)
         mean_quats_sos = quats_sos_and_eos['quats_sos']
         mean_quats_eos = quats_sos_and_eos['quats_eos']
     except FileNotFoundError:
         # Otherwise, estimate them from the first and last frames of every training sample.
         mean_quats_sos = np.zeros((self.V, self.D))
         mean_quats_eos = np.zeros((self.V, self.D))
         for j in range(self.V):
             quats_sos = np.zeros((self.D, num_samples))
             quats_eos = np.zeros((self.D, num_samples))
             for s in range(num_samples):
                 quats_sos[:, s] = self.data_loader['train'][
                     keys[s]]['rotations'][0, j]
                 quats_eos[:, s] = self.data_loader['train'][
                     keys[s]]['rotations'][-1, j]
             # Per-joint quaternion average: take an eigenvector of the
             # accumulator matrix Q Q^T built from all training samples.
             _, sos_eig_vectors = np.linalg.eig(
                 np.dot(quats_sos, quats_sos.T))
             mean_quats_sos[j] = sos_eig_vectors[:, 0]
             _, eos_eig_vectors = np.linalg.eig(
                 np.dot(quats_eos, quats_eos.T))
             mean_quats_eos[j] = eos_eig_vectors[:, 0]
         np.savez_compressed(quats_sos_and_eos_file,
                             quats_sos=mean_quats_sos,
                             quats_eos=mean_quats_eos)
     mean_quats_sos = torch.from_numpy(mean_quats_sos).unsqueeze(0)
     mean_quats_eos = torch.from_numpy(mean_quats_eos).unsqueeze(0)
     for s in range(num_samples):
         pos_sos = \
             MocapDataset.forward_kinematics(mean_quats_sos.unsqueeze(0),
                                             torch.from_numpy(self.data_loader['train'][keys[s]]
                                             ['positions'][0:1, 0]).double().unsqueeze(0),
                                             self.joint_parents,
                                             torch.from_numpy(self.data_loader['train'][keys[s]]['joints_dict']
                                             ['joints_offsets_all']).unsqueeze(0)).squeeze(0).numpy()
         affs_sos = MocapDataset.get_mpi_affective_features(pos_sos)
         pos_eos = \
             MocapDataset.forward_kinematics(mean_quats_eos.unsqueeze(0),
                                             torch.from_numpy(self.data_loader['train'][keys[s]]
                                             ['positions'][-1:, 0]).double().unsqueeze(0),
                                             self.joint_parents,
                                             torch.from_numpy(self.data_loader['train'][keys[s]]['joints_dict']
                                             ['joints_offsets_all']).unsqueeze(0)).squeeze(0).numpy()
         affs_eos = MocapDataset.get_mpi_affective_features(pos_eos)
         self.data_loader['train'][keys[s]]['positions'] = \
             np.concatenate((pos_sos, self.data_loader['train'][keys[s]]['positions'], pos_eos), axis=0)
         self.data_loader['train'][keys[s]]['affective_features'] = \
             np.concatenate((affs_sos, self.data_loader['train'][keys[s]]['affective_features'], affs_eos),
                            axis=0)
     return mean_quats_sos, mean_quats_eos
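
The start-of-sequence (SOS) and end-of-sequence (EOS) quaternions above are per-joint averages over the first and last frames of every training clip, obtained from an eigenvector of the 4x4 accumulator matrix Q Q^T. A minimal NumPy sketch of that averaging step is shown below; the helper name is illustrative, and unlike the snippet above it sorts the eigenvalues explicitly so the dominant eigenvector is always the one returned (np.linalg.eig makes no ordering guarantee).

    import numpy as np

    def average_quaternions(quats):
        """Average unit quaternions given as an array of shape (N, 4).

        Builds the accumulator matrix sum_i q_i q_i^T and returns the eigenvector
        associated with its largest eigenvalue, renormalized to unit length.
        Assumes the input quaternions are sign-consistent (q and -q encode the
        same rotation but would cancel in the accumulator).
        """
        q = np.asarray(quats, dtype=np.float64)      # (N, 4)
        accumulator = q.T @ q                        # 4x4 symmetric matrix
        eig_values, eig_vectors = np.linalg.eigh(accumulator)
        mean_quat = eig_vectors[:, np.argmax(eig_values)]
        return mean_quat / np.linalg.norm(mean_quat)

    # Example: average a small cluster of rotations near the identity quaternion.
    quats = np.array([[1.0, 0.0, 0.0, 0.0],
                      [0.999, 0.04, 0.0, 0.0],
                      [0.998, 0.0, 0.06, 0.0]])
    quats /= np.linalg.norm(quats, axis=-1, keepdims=True)
    print(average_quaternions(quats))
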
Example #2
 def compute_next_traj_point_sans_markers(self, pos_last, quat_next, z_pred, rs_pred):
     # pos_next = torch.zeros_like(pos_last)
     # Note: 'offsets' is only used by the commented-out per-joint FK loop below;
     # the forward_kinematics call handles the offsets internally.
     offsets = torch.from_numpy(self.mocap.joint_offsets).cuda().float(). \
         unsqueeze(0).unsqueeze(0).repeat(pos_last.shape[0], pos_last.shape[1], 1, 1)
     pos_next = MocapDataset.forward_kinematics(quat_next.contiguous().view(quat_next.shape[0],
                                                                            quat_next.shape[1], -1, self.D),
                                                pos_last[:, :, 0],
                                                self.joint_parents,
                                                torch.from_numpy(self.joint_offsets).float().cuda())
     # for joint in range(1, self.V):
     #     pos_next[:, :, joint] = qrot(quat_copy[:, :, joint - 1], offsets[:, :, joint]) \
     #                             + pos_next[:, :, self.mocap.joint_parents[joint]]
     # Estimate the facing direction in the ground (XZ) plane from the root and
     # the two shoulder joints, then advance the root along it by rs_pred.
     root = pos_next[:, :, 0]
     l_shoulder = pos_next[:, :, 18]
     r_shoulder = pos_next[:, :, 25]
     facing = torch.cross(l_shoulder - root, r_shoulder - root, dim=-1)[..., [0, 2]]
     facing /= (torch.norm(facing, dim=-1)[..., None] + 1e-9)
     return rs_pred * facing + pos_last[:, :, 0, [0, 2]]
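
The commented-out loop above spells out what MocapDataset.forward_kinematics does for every joint: rotate the joint's offset by its parent's world rotation and add the parent's world position. A single-pose sketch of that recursion is given below, assuming local per-joint rotations in (w, x, y, z) order; all names are illustrative and the real implementation is batched over samples and time.

    import numpy as np

    def quat_rotate(q, v):
        """Rotate a 3-vector v by a unit quaternion q = (w, x, y, z)."""
        w, u = q[0], q[1:]
        t = 2.0 * np.cross(u, v)
        return v + w * t + np.cross(u, t)

    def quat_multiply(q1, q2):
        """Hamilton product of two (w, x, y, z) quaternions."""
        w1, x1, y1, z1 = q1
        w2, x2, y2, z2 = q2
        return np.array([w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
                         w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
                         w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
                         w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2])

    def forward_kinematics_single_pose(local_quats, root_pos, parents, offsets):
        """World joint positions for one pose.

        local_quats: (V, 4) local rotations, root_pos: (3,), parents: length-V
        list with parents[0] == -1, offsets: (V, 3) bone offsets from the parent.
        """
        V = len(parents)
        world_quats = np.zeros((V, 4))
        positions = np.zeros((V, 3))
        world_quats[0] = local_quats[0]
        positions[0] = root_pos
        for j in range(1, V):
            p = parents[j]
            positions[j] = positions[p] + quat_rotate(world_quats[p], offsets[j])
            world_quats[j] = quat_multiply(world_quats[p], local_quats[j])
        return positions

    # Example: a 3-joint chain along the y-axis with identity rotations.
    parents = [-1, 0, 1]
    offsets = np.array([[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0]])
    identity = np.tile(np.array([1.0, 0.0, 0.0, 0.0]), (3, 1))
    print(forward_kinematics_single_pose(identity, np.zeros(3), parents, offsets))
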
Example #3
    def generate_motion(self,
                        load_saved_model=True,
                        samples_to_generate=10,
                        epoch=None,
                        randomized=True,
                        animations_as_videos=False):

        if epoch is None:
            epoch = 'best'
        if load_saved_model:
            self.load_model_at_epoch(epoch=epoch)
        self.model.eval()
        test_loader = self.data_loader['test']

        joint_offsets, pos, affs, quat, quat_valid_idx, \
            text, text_valid_idx, intended_emotion, intended_polarity, \
            acting_task, gender, age, handedness, \
            native_tongue = self.return_batch([samples_to_generate], test_loader, randomized=randomized)
        with torch.no_grad():
            joint_lengths = torch.norm(joint_offsets, dim=-1)
            scales, _ = torch.max(joint_lengths, dim=-1)
            # Decoder prelude: T - 1 frames of the EOS pose, with the ground-truth
            # first frame of the sequence placed in the last slot.
            quat_prelude = self.quats_eos.view(1, -1).cuda() \
                .repeat(self.T - 1, 1).unsqueeze(0).repeat(quat.shape[0], 1, 1).float()
            quat_prelude[:, -1] = quat[:, 0].clone()
            quat_pred, quat_pred_pre_norm = self.model(
                text, intended_emotion, intended_polarity, acting_task, gender,
                age, handedness, native_tongue, quat_prelude,
                joint_lengths / scales[..., None])
            for s in range(len(quat_pred)):
                quat_pred[s] = qfix(
                    quat_pred[s].view(quat_pred[s].shape[0], self.V,
                                      -1)).view(quat_pred[s].shape[0], -1)

            root_pos = torch.zeros(quat_pred.shape[0], quat_pred.shape[1],
                                   self.C).cuda()
            pos_pred = MocapDataset.forward_kinematics(
                quat_pred.contiguous().view(quat_pred.shape[0],
                                            quat_pred.shape[1], -1, self.D),
                root_pos, self.joint_parents,
                torch.cat((root_pos[:, 0:1], joint_offsets),
                          dim=1).unsqueeze(1))

        animation_pred = {
            'joint_names': self.joint_names,
            'joint_offsets': joint_offsets,
            'joint_parents': self.joint_parents,
            'positions': pos_pred,
            'rotations': quat_pred
        }
        MocapDataset.save_as_bvh(animation_pred,
                                 dataset_name=self.dataset + '_glove',
                                 subset_name='test')
        animation = {
            'joint_names': self.joint_names,
            'joint_offsets': joint_offsets,
            'joint_parents': self.joint_parents,
            'positions': pos,
            'rotations': quat
        }
        MocapDataset.save_as_bvh(animation,
                                 dataset_name=self.dataset + '_glove',
                                 subset_name='gt')

        if animations_as_videos:
            pos_pred_np = pos_pred.contiguous().view(pos_pred.shape[0],
                                                     pos_pred.shape[1], -1).permute(0, 2, 1).\
                detach().cpu().numpy()
            display_animations(pos_pred_np,
                               self.joint_parents,
                               save=True,
                               dataset_name=self.dataset,
                               subset_name='epoch_' +
                               str(self.best_loss_epoch),
                               overwrite=True)
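
qfix is applied to every predicted sequence above because q and -q encode the same rotation: a network that regresses quaternions frame by frame can flip hemisphere between frames, which produces jumps when the sequence is later interpolated or converted for BVH export. A minimal sketch of the usual continuity fix is shown below; it mirrors the common QuaterNet-style qfix, not necessarily the exact implementation imported by this repository.

    import numpy as np

    def qfix_numpy(q):
        """Enforce temporal sign continuity on a quaternion sequence.

        q: array of shape (T, J, 4). For each joint, frame t is negated whenever
        its dot product with frame t - 1 is negative, keeping consecutive
        quaternions in the same hemisphere.
        """
        q = np.array(q, dtype=np.float64, copy=True)
        for t in range(1, q.shape[0]):
            dot = np.sum(q[t] * q[t - 1], axis=-1, keepdims=True)   # (J, 1)
            q[t] = np.where(dot < 0.0, -q[t], q[t])
        return q

    # Example: the second frame flipped hemisphere and gets negated back.
    seq = np.array([[[1.0, 0.0, 0.0, 0.0]],
                    [[-0.99, 0.0, 0.0, -0.1]]])
    print(qfix_numpy(seq))
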
Example #4
    def per_eval(self):

        self.model.eval()
        test_loader = self.data_loader['test']
        eval_loss = 0.
        N = 0.

        for joint_offsets, pos, affs, quat, quat_valid_idx, \
            text, text_valid_idx, intended_emotion, intended_polarity, \
            acting_task, gender, age, handedness, \
                native_tongue in self.yield_batch(self.args.batch_size, test_loader):
            with torch.no_grad():
                joint_lengths = torch.norm(joint_offsets, dim=-1)
                scales, _ = torch.max(joint_lengths, dim=-1)
                quat_prelude = self.quats_eos.view(1, -1).cuda() \
                    .repeat(self.T - 1, 1).unsqueeze(0).repeat(quat.shape[0], 1, 1).float()
                quat_prelude[:, -1] = quat[:, 0].clone()
                quat_pred, quat_pred_pre_norm = self.model(
                    text, intended_emotion, intended_polarity, acting_task,
                    gender, age, handedness, native_tongue, quat_prelude,
                    joint_lengths / scales[..., None])
                quat_pred_pre_norm = quat_pred_pre_norm.view(
                    quat_pred_pre_norm.shape[0], quat_pred_pre_norm.shape[1],
                    -1, self.D)
                quat_norm_loss = self.args.quat_norm_reg *\
                                 torch.mean((torch.sum(quat_pred_pre_norm ** 2, dim=-1) - 1) ** 2)

                quat_loss, quat_derv_loss = losses.quat_angle_loss(
                    quat_pred, quat[:, 1:], quat_valid_idx[:, 1:], self.V,
                    self.D, self.lower_body_start, self.args.upper_body_weight)
                quat_loss *= self.args.quat_reg

                root_pos = torch.zeros(quat_pred.shape[0], quat_pred.shape[1],
                                       self.C).cuda()
                pos_pred = MocapDataset.forward_kinematics(
                    quat_pred.contiguous().view(quat_pred.shape[0],
                                                quat_pred.shape[1], -1,
                                                self.D), root_pos,
                    self.joint_parents,
                    torch.cat((root_pos[:, 0:1], joint_offsets),
                              dim=1).unsqueeze(1))
                affs_pred = MocapDataset.get_mpi_affective_features(pos_pred)

                row_sums = quat_valid_idx.sum(1,
                                              keepdim=True) * self.D * self.V
                row_sums[row_sums == 0.] = 1.

                shifted_pos = pos - pos[:, :, 0:1]
                shifted_pos_pred = pos_pred - pos_pred[:, :, 0:1]

                recons_loss = self.recons_loss_func(shifted_pos_pred,
                                                    shifted_pos[:, 1:])
                # recons_loss = torch.abs(shifted_pos_pred[:, 1:] - shifted_pos[:, 1:]).sum(-1)
                # recons_loss = self.args.upper_body_weight * (recons_loss[:, :, :self.lower_body_start].sum(-1)) + \
                #               recons_loss[:, :, self.lower_body_start:].sum(-1)
                # recons_loss = self.args.recons_reg * torch.mean(
                #     (recons_loss * quat_valid_idx[:, 1:]).sum(-1) / row_sums)
                #
                # recons_derv_loss = torch.abs(shifted_pos_pred[:, 2:] - shifted_pos_pred[:, 1:-1] -
                #                              shifted_pos[:, 2:] + shifted_pos[:, 1:-1]).sum(-1)
                # recons_derv_loss = self.args.upper_body_weight * \
                #                    (recons_derv_loss[:, :, :self.lower_body_start].sum(-1)) + \
                #                    recons_derv_loss[:, :, self.lower_body_start:].sum(-1)
                # recons_derv_loss = 2. * self.args.recons_reg * \
                #                    torch.mean((recons_derv_loss * quat_valid_idx[:, 2:]).sum(-1) / row_sums)
                #
                # affs_loss = torch.abs(affs[:, 1:] - affs_pred[:, 1:]).sum(-1)
                # affs_loss = self.args.affs_reg * torch.mean((affs_loss * quat_valid_idx[:, 1:]).sum(-1) / row_sums)
                affs_loss = self.affs_loss_func(affs_pred, affs[:, 1:])

                eval_loss += quat_norm_loss + quat_loss + recons_loss + affs_loss
                # eval_loss += quat_norm_loss + quat_loss + recons_loss + recons_derv_loss + affs_loss
                N += quat.shape[0]

        eval_loss /= N
        self.epoch_info['mean_loss'] = eval_loss
        if self.epoch_info['mean_loss'] < self.best_loss and \
                self.meta_info['epoch'] > self.min_train_epochs:
            self.best_loss = self.epoch_info['mean_loss']
            self.best_loss_epoch = self.meta_info['epoch']
            self.loss_updated = True
        else:
            self.loss_updated = False
        self.show_epoch_info()
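
The quat_norm_loss term used in both evaluation and training penalizes the pre-normalization quaternions for drifting away from unit length, i.e. quat_norm_reg * mean((||q||^2 - 1)^2). An isolated sketch of just that term follows; the tensor shapes and the weight value are illustrative.

    import torch

    def quat_norm_regularizer(quat_pred_pre_norm, quat_norm_reg=0.1):
        """Unit-norm regularizer for raw (pre-normalization) quaternion outputs.

        quat_pred_pre_norm: tensor of shape (batch, time, joints, 4).
        Returns quat_norm_reg * mean of (|q|^2 - 1)^2 over all quaternions.
        """
        sq_norms = torch.sum(quat_pred_pre_norm ** 2, dim=-1)
        return quat_norm_reg * torch.mean((sq_norms - 1.0) ** 2)

    # Example: raw network outputs for a batch of 2, 8 frames, 21 joints.
    raw = torch.randn(2, 8, 21, 4)
    print(quat_norm_regularizer(raw).item())
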
Example #5
    def per_train(self):

        self.model.train()
        train_loader = self.data_loader['train']
        batch_loss = 0.
        N = 0.

        for joint_offsets, pos, affs, quat, quat_valid_idx,\
            text, text_valid_idx, intended_emotion, intended_polarity,\
                acting_task, gender, age, handedness,\
                native_tongue in self.yield_batch(self.args.batch_size, train_loader):
            quat_prelude = self.quats_eos.view(1, -1).cuda() \
                .repeat(self.T - 1, 1).unsqueeze(0).repeat(quat.shape[0], 1, 1).float()
            quat_prelude[:, -1] = quat[:, 0].clone()

            self.optimizer.zero_grad()
            with torch.autograd.detect_anomaly():
                joint_lengths = torch.norm(joint_offsets, dim=-1)
                scales, _ = torch.max(joint_lengths, dim=-1)
                quat_pred, quat_pred_pre_norm = self.model(
                    text, intended_emotion, intended_polarity, acting_task,
                    gender, age, handedness, native_tongue, quat_prelude,
                    joint_lengths / scales[..., None])

                quat_pred_pre_norm = quat_pred_pre_norm.view(
                    quat_pred_pre_norm.shape[0], quat_pred_pre_norm.shape[1],
                    -1, self.D)
                quat_norm_loss = self.args.quat_norm_reg *\
                                 torch.mean((torch.sum(quat_pred_pre_norm ** 2, dim=-1) - 1) ** 2)

                quat_loss, quat_derv_loss = losses.quat_angle_loss(
                    quat_pred, quat[:, 1:], quat_valid_idx[:, 1:], self.V,
                    self.D, self.lower_body_start, self.args.upper_body_weight)
                quat_loss *= self.args.quat_reg

                root_pos = torch.zeros(quat_pred.shape[0], quat_pred.shape[1],
                                       self.C).cuda()
                pos_pred = MocapDataset.forward_kinematics(
                    quat_pred.contiguous().view(quat_pred.shape[0],
                                                quat_pred.shape[1], -1,
                                                self.D), root_pos,
                    self.joint_parents,
                    torch.cat((root_pos[:, 0:1], joint_offsets),
                              dim=1).unsqueeze(1))
                affs_pred = MocapDataset.get_mpi_affective_features(pos_pred)

                # row_sums = quat_valid_idx.sum(1, keepdim=True) * self.D * self.V
                # row_sums[row_sums == 0.] = 1.

                shifted_pos = pos - pos[:, :, 0:1]
                shifted_pos_pred = pos_pred - pos_pred[:, :, 0:1]

                recons_loss = self.recons_loss_func(shifted_pos_pred,
                                                    shifted_pos[:, 1:])
                # recons_loss = torch.abs(shifted_pos_pred - shifted_pos[:, 1:]).sum(-1)
                # recons_loss = self.args.upper_body_weight * (recons_loss[:, :, :self.lower_body_start].sum(-1)) +\
                #               recons_loss[:, :, self.lower_body_start:].sum(-1)
                # recons_loss = self.args.recons_reg *\
                #               torch.mean((recons_loss * quat_valid_idx[:, 1:]).sum(-1) / row_sums)
                #
                # recons_derv_loss = torch.abs(shifted_pos_pred[:, 1:] - shifted_pos_pred[:, :-1] -
                #                              shifted_pos[:, 2:] + shifted_pos[:, 1:-1]).sum(-1)
                # recons_derv_loss = self.args.upper_body_weight *\
                #     (recons_derv_loss[:, :, :self.lower_body_start].sum(-1)) +\
                #                    recons_derv_loss[:, :, self.lower_body_start:].sum(-1)
                # recons_derv_loss = 2. * self.args.recons_reg *\
                #                    torch.mean((recons_derv_loss * quat_valid_idx[:, 2:]).sum(-1) / row_sums)
                #
                # affs_loss = torch.abs(affs[:, 1:] - affs_pred).sum(-1)
                # affs_loss = self.args.affs_reg * torch.mean((affs_loss * quat_valid_idx[:, 1:]).sum(-1) / row_sums)
                affs_loss = self.affs_loss_func(affs_pred, affs[:, 1:])

                train_loss = quat_norm_loss + quat_loss + recons_loss + affs_loss
                # train_loss = quat_norm_loss + quat_loss + recons_loss + recons_derv_loss + affs_loss
                train_loss.backward()
                # nn.utils.clip_grad_norm_(self.model.parameters(), self.args.gradient_clip)
                self.optimizer.step()

            # animation_pred = {
            #     'joint_names': self.joint_names,
            #     'joint_offsets': joint_offsets,
            #     'joint_parents': self.joint_parents,
            #     'positions': pos_pred,
            #     'rotations': quat_pred
            # }
            # MocapDataset.save_as_bvh(animation_pred,
            #                          dataset_name=self.dataset,
            #                          subset_name='test')
            # animation = {
            #     'joint_names': self.joint_names,
            #     'joint_offsets': joint_offsets,
            #     'joint_parents': self.joint_parents,
            #     'positions': pos,
            #     'rotations': quat
            # }
            # MocapDataset.save_as_bvh(animation,
            #                          dataset_name=self.dataset,
            #                          subset_name='gt')

            # Compute statistics
            batch_loss += train_loss.item()
            N += quat.shape[0]

            # statistics
            self.iter_info['loss'] = train_loss.data.item()
            self.iter_info['lr'] = '{:.6f}'.format(self.lr)
            self.iter_info['tf'] = '{:.6f}'.format(self.tf)
            self.show_iter_info()
            self.meta_info['iter'] += 1

        batch_loss = batch_loss / N
        self.epoch_info['mean_loss'] = batch_loss
        self.show_epoch_info()
        self.io.print_timer()
        self.adjust_lr()
        self.adjust_tf()
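
Both per_train and per_eval compare root-relative poses: the root position is subtracted from every joint before the reconstruction loss, so the loss measures body pose rather than global trajectory. A minimal sketch of that centering step, assuming positions of shape (batch, time, joints, 3) with joint 0 as the root and an arbitrary elementwise loss (the function names here are illustrative, not the repository's recons_loss_func):

    import torch
    import torch.nn.functional as F

    def root_relative_recons_loss(pos_pred, pos_gt, loss_func=F.smooth_l1_loss):
        """Reconstruction loss on root-centered joint positions.

        pos_pred, pos_gt: tensors of shape (batch, time, joints, 3).
        Subtracting joint 0 (the root) from every joint removes global
        translation, leaving only differences in body pose.
        """
        shifted_pred = pos_pred - pos_pred[:, :, 0:1]
        shifted_gt = pos_gt - pos_gt[:, :, 0:1]
        return loss_func(shifted_pred, shifted_gt)

    # Example with random tensors: batch of 4, 16 frames, 21 joints.
    pred = torch.randn(4, 16, 21, 3)
    gt = torch.randn(4, 16, 21, 3)
    print(root_relative_recons_loss(pred, gt).item())
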
Example #6
    def generate_motion(self,
                        load_saved_model=True,
                        samples_to_generate=10,
                        randomized=True,
                        epoch=None,
                        animations_as_videos=False):

        if epoch is None:
            epoch = 'best'
        if load_saved_model:
            self.load_model_at_epoch(epoch=epoch)
        self.model.eval()
        test_loader = self.data_loader['test']

        start_time = time.time()
        joint_offsets, pos, affs, quat, quat_valid_idx, \
            text, text_valid_idx, perceived_emotion, perceived_polarity, \
            acting_task, gender, age, handedness, \
            native_tongue = self.return_batch([samples_to_generate], test_loader, randomized=randomized)
        with torch.no_grad():
            joint_lengths = torch.norm(joint_offsets, dim=-1)
            scales, _ = torch.max(joint_lengths, dim=-1)
            # Seed the first frame with the SOS pose; this seeding only matters for
            # the commented-out autoregressive loop below, since the one-shot model
            # call overwrites quat_pred.
            quat_pred = torch.zeros_like(quat)
            quat_pred[:, 0] = torch.cat(
                quat_pred.shape[0] * [self.quats_sos]).view(quat_pred[:, 0].shape)

            quat_pred, quat_pred_pre_norm = self.model(
                text, perceived_emotion, perceived_polarity, acting_task,
                gender, age, handedness, native_tongue, quat[:, :-1],
                joint_lengths / scales[..., None])
            # text_latent = self.model(text, intended_emotion, intended_polarity,
            #                          acting_task, gender, age, handedness, native_tongue, only_encoder=True)
            # for t in range(1, self.T):
            #     quat_pred_curr, _ = self.model(text_latent, quat=quat_pred[:, 0:t],
            #                                    offset_lengths=joint_lengths / scales[..., None],
            #                                    only_decoder=True)
            #     quat_pred[:, t:t + 1] = quat_pred_curr[:, -1:].clone()

            # for s in range(len(quat_pred)):
            #     quat_pred[s] = qfix(quat_pred[s].view(quat_pred[s].shape[0],
            #                                           self.V, -1)).view(quat_pred[s].shape[0], -1)
            quat_pred = torch.cat((quat[:, 1:2], quat_pred), dim=1)
            quat_pred = qfix(
                quat_pred.view(quat_pred.shape[0], quat_pred.shape[1], self.V,
                               -1)).view(quat_pred.shape[0],
                                         quat_pred.shape[1], -1)
            quat_pred = quat_pred[:, 1:]

            root_pos = torch.zeros(quat_pred.shape[0], quat_pred.shape[1],
                                   self.C).cuda()
            pos_pred = MocapDataset.forward_kinematics(
                quat_pred.contiguous().view(quat_pred.shape[0],
                                            quat_pred.shape[1], -1, self.D),
                root_pos, self.joint_parents,
                torch.cat((root_pos[:, 0:1], joint_offsets),
                          dim=1).unsqueeze(1))

        quat_np = quat.detach().cpu().numpy()
        quat_pred_np = quat_pred.detach().cpu().numpy()

        animation_pred = {
            'joint_names': self.joint_names,
            'joint_offsets': joint_offsets,
            'joint_parents': self.joint_parents,
            'positions': pos_pred,
            'rotations': quat_pred,
            'valid_idx': quat_valid_idx
        }
        MocapDataset.save_as_bvh(animation_pred,
                                 dataset_name=self.dataset,
                                 subset_name='test_epoch_{}'.format(epoch),
                                 include_default_pose=False)
        end_time = time.time()
        print('Time taken: {} secs.'.format(end_time - start_time))
        shifted_pos = pos - pos[:, :, 0:1]
        animation = {
            'joint_names': self.joint_names,
            'joint_offsets': joint_offsets,
            'joint_parents': self.joint_parents,
            'positions': shifted_pos,
            'rotations': quat,
            'valid_idx': quat_valid_idx
        }

        MocapDataset.save_as_bvh(animation,
                                 dataset_name=self.dataset,
                                 subset_name='gt',
                                 include_default_pose=False)

        if animations_as_videos:
            pos_pred_np = pos_pred.contiguous().view(pos_pred.shape[0],
                                                     pos_pred.shape[1], -1).permute(0, 2, 1).\
                detach().cpu().numpy()
            display_animations(pos_pred_np,
                               self.joint_parents,
                               save=True,
                               dataset_name=self.dataset,
                               subset_name='epoch_' +
                               str(self.best_loss_epoch),
                               overwrite=True)
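
Seeding quat_pred[:, 0] with self.quats_sos simply repeats the single start-of-sequence pose across the batch dimension. Assuming quats_sos has shape (1, V * D), the torch.cat(...).view(...) construction above is equivalent to the broadcast below (shapes are assumed for illustration):

    import torch

    V, D = 21, 4                        # joints and quaternion dimension (assumed)
    batch_size, T = 8, 32

    quats_sos = torch.randn(1, V * D)   # stand-in for self.quats_sos
    quat_pred = torch.zeros(batch_size, T, V * D)

    # Equivalent to: torch.cat(batch_size * [quats_sos]).view(quat_pred[:, 0].shape)
    quat_pred[:, 0] = quats_sos.expand(batch_size, -1)
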
Example #7
    def forward_pass(self, joint_offsets, pos, affs, quat, quat_sos,
                     quat_valid_idx, text, text_valid_idx, perceived_emotion,
                     perceived_polarity, acting_task, gender, age, handedness,
                     native_tongue):
        self.optimizer.zero_grad()
        with torch.autograd.detect_anomaly():
            joint_lengths = torch.norm(joint_offsets, dim=-1)
            scales, _ = torch.max(joint_lengths, dim=-1)
            # text_latent = self.model(text, perceived_emotion, perceived_polarity,
            #                          acting_task, gender, age, handedness, native_tongue,
            #                          only_encoder=True)
            # quat_pred = torch.zeros_like(quat)
            # quat_pred_pre_norm = torch.zeros_like(quat)
            # quat_in = quat_sos[:, :self.T_steps]
            # quat_valid_idx_max = torch.max(torch.sum(quat_valid_idx, dim=-1))
            # for t in range(0, self.T, self.T_steps):
            #     if t > quat_valid_idx_max:
            #         break
            #     quat_pred[:, t:min(self.T, t + self.T_steps)],\
            #         quat_pred_pre_norm[:, t:min(self.T, t + self.T_steps)] =\
            #         self.model(text_latent, quat=quat_in, offset_lengths= joint_lengths / scales[..., None],
            #                    only_decoder=True)
            #     if torch.rand(1) > self.tf:
            #         if t + self.T_steps * 2 >= self.T:
            #             quat_in = quat_pred[:, -(self.T - t - self.T_steps):].clone()
            #         else:
            #             quat_in = quat_pred[:, t:t + self.T_steps].clone()
            #     else:
            #         if t + self.T_steps * 2 >= self.T:
            #             quat_in = quat[:, -(self.T - t - self.T_steps):].clone()
            #         else:
            #             quat_in = quat[:, t:min(self.T, t + self.T_steps)].clone()
            quat_pred, quat_pred_pre_norm = self.model(
                text, perceived_emotion, perceived_polarity, acting_task,
                gender, age, handedness, native_tongue, quat_sos[:, :-1],
                joint_lengths / scales[..., None])
            quat_fixed = qfix(quat.contiguous().view(
                quat.shape[0], quat.shape[1], -1,
                self.D)).contiguous().view(quat.shape[0], quat.shape[1], -1)
            quat_pred = qfix(quat_pred.contiguous().view(
                quat_pred.shape[0], quat_pred.shape[1], -1,
                self.D)).contiguous().view(quat_pred.shape[0],
                                           quat_pred.shape[1], -1)

            quat_pred_pre_norm = quat_pred_pre_norm.view(
                quat_pred_pre_norm.shape[0], quat_pred_pre_norm.shape[1], -1,
                self.D)
            quat_norm_loss = self.args.quat_norm_reg *\
                torch.mean((torch.sum(quat_pred_pre_norm ** 2, dim=-1) - 1) ** 2)

            # quat_loss, quat_derv_loss = losses.quat_angle_loss(quat_pred, quat_fixed,
            #                                                    quat_valid_idx[:, 1:],
            #                                                    self.V, self.D,
            #                                                    self.lower_body_start,
            #                                                    self.args.upper_body_weight)
            quat_loss, quat_derv_loss = losses.quat_angle_loss(
                quat_pred, quat_fixed[:, 1:], quat_valid_idx[:, 1:], self.V,
                self.D, self.lower_body_start, self.args.upper_body_weight)
            quat_loss *= self.args.quat_reg

            root_pos = torch.zeros(quat_pred.shape[0], quat_pred.shape[1],
                                   self.C).cuda()
            pos_pred = MocapDataset.forward_kinematics(
                quat_pred.contiguous().view(quat_pred.shape[0],
                                            quat_pred.shape[1], -1, self.D),
                root_pos, self.joint_parents,
                torch.cat((root_pos[:, 0:1], joint_offsets),
                          dim=1).unsqueeze(1))
            affs_pred = MocapDataset.get_mpi_affective_features(pos_pred)

            row_sums = quat_valid_idx.sum(1, keepdim=True) * self.D * self.V
            row_sums[row_sums == 0.] = 1.

            shifted_pos = pos - pos[:, :, 0:1]
            shifted_pos_pred = pos_pred - pos_pred[:, :, 0:1]

            # recons_loss = self.recons_loss_func(shifted_pos_pred, shifted_pos)
            # recons_arms = self.recons_loss_func(shifted_pos_pred[:, :, 7:15], shifted_pos[:, :, 7:15])
            recons_loss = self.recons_loss_func(shifted_pos_pred,
                                                shifted_pos[:, 1:])
            # recons_arms = self.recons_loss_func(shifted_pos_pred[:, :, 7:15], shifted_pos[:, 1:, 7:15])
            # recons_loss = torch.abs(shifted_pos_pred - shifted_pos[:, 1:]).sum(-1)
            # recons_loss = self.args.upper_body_weight * (recons_loss[:, :, :self.lower_body_start].sum(-1)) +\
            #               recons_loss[:, :, self.lower_body_start:].sum(-1)
            recons_loss = self.args.recons_reg *\
                torch.mean((recons_loss * quat_valid_idx[:, 1:]).sum(-1) / row_sums)
            #
            # recons_derv_loss = torch.abs(shifted_pos_pred[:, 1:] - shifted_pos_pred[:, :-1] -
            #                              shifted_pos[:, 2:] + shifted_pos[:, 1:-1]).sum(-1)
            # recons_derv_loss = self.args.upper_body_weight *\
            #     (recons_derv_loss[:, :, :self.lower_body_start].sum(-1)) +\
            #                    recons_derv_loss[:, :, self.lower_body_start:].sum(-1)
            # recons_derv_loss = 2. * self.args.recons_reg *\
            #                    torch.mean((recons_derv_loss * quat_valid_idx[:, 2:]).sum(-1) / row_sums)
            #
            # affs_loss = torch.abs(affs[:, 1:] - affs_pred).sum(-1)
            affs_loss = self.affs_loss_func(affs_pred, affs[:, 1:])
            affs_loss = self.args.affs_reg * torch.mean(
                (affs_loss * quat_valid_idx[:, 1:]).sum(-1) / row_sums)
            # affs_loss = self.affs_loss_func(affs_pred, affs)

            total_loss = quat_norm_loss + quat_loss + recons_loss + affs_loss
            # train_loss = quat_norm_loss + quat_loss + recons_loss + recons_derv_loss + affs_loss

        return total_loss
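
In this variant the losses are masked by quat_valid_idx and normalized by row_sums, the per-sequence count of valid (joint, channel) entries, so padded frames contribute nothing and short sequences are not drowned out by long ones. A standalone sketch of that masked averaging, assuming a per-frame loss of shape (batch, time) and a 0/1 validity mask (names are illustrative):

    import torch

    def masked_sequence_mean(per_frame_loss, valid_idx, num_joints, num_channels):
        """Average a per-frame loss over valid frames only.

        per_frame_loss: (batch, time) loss already reduced over joints/channels.
        valid_idx: (batch, time) mask, 1 for real frames and 0 for padding.
        Each sequence is normalized by its own number of valid entries, so
        padding neither adds loss nor dilutes the average.
        """
        row_sums = valid_idx.sum(dim=1, keepdim=True) * num_joints * num_channels
        row_sums = torch.clamp(row_sums, min=1.0)      # avoid division by zero
        per_seq = (per_frame_loss * valid_idx).sum(dim=1, keepdim=True) / row_sums
        return per_seq.mean()

    # Example: 3 sequences of 10 frames; the last 4 frames of sequence 0 are padding.
    loss = torch.rand(3, 10)
    mask = torch.ones(3, 10)
    mask[0, 6:] = 0.0
    print(masked_sequence_mean(loss, mask, num_joints=21, num_channels=4).item())
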
Example #8
    def generate_motion(self,
                        load_saved_model=True,
                        samples_to_generate=10,
                        max_steps=300,
                        randomized=True):

        if load_saved_model:
            self.load_best_model()
        self.model.eval()
        test_loader = self.data_loader['test']

        joint_offsets, pos, affs, quat, quat_valid_idx, \
            text, text_valid_idx, intended_emotion, intended_polarity, \
            acting_task, gender, age, handedness, \
            native_tongue = self.return_batch([samples_to_generate], test_loader, randomized=randomized)
        with torch.no_grad():
            joint_lengths = torch.norm(joint_offsets, dim=-1)
            scales, _ = torch.max(joint_lengths, dim=-1)
            quat_pred = torch.zeros_like(quat)
            quat_pred_pre_norm = torch.zeros_like(quat)
            quat_pred[:, 0] = torch.cat(
                quat_pred.shape[0] * [self.quats_sos]).view(quat_pred[:, 0].shape)
            text_latent = self.model(text,
                                     intended_emotion,
                                     intended_polarity,
                                     acting_task,
                                     gender,
                                     age,
                                     handedness,
                                     native_tongue,
                                     only_encoder=True)
            for t in range(1, self.T):
                quat_pred_curr, quat_pred_pre_norm_curr = \
                    self.model(text_latent, quat=quat_pred[:, 0:t],
                               offset_lengths=joint_lengths / scales[..., None],
                               only_decoder=True)
                quat_pred[:, t:t + 1] = quat_pred_curr[:, -1:].clone()
                quat_pred_pre_norm[:, t:t + 1] = \
                    quat_pred_pre_norm_curr[:, -1:].clone()

            for s in range(len(quat_pred)):
                quat_pred[s] = qfix(
                    quat_pred[s].view(quat_pred[s].shape[0], self.V,
                                      -1)).view(quat_pred[s].shape[0], -1)

            root_pos = torch.zeros(quat_pred.shape[0], quat_pred.shape[1],
                                   self.C).cuda()
            pos_pred = MocapDataset.forward_kinematics(
                quat_pred.contiguous().view(quat_pred.shape[0],
                                            quat_pred.shape[1], -1, self.D),
                root_pos, self.joint_parents,
                torch.cat((root_pos[:, 0:1], joint_offsets),
                          dim=1).unsqueeze(1))

        animation_pred = {
            'joint_names': self.joint_names,
            'joint_offsets': joint_offsets,
            'joint_parents': self.joint_parents,
            'positions': pos_pred,
            'rotations': quat_pred
        }
        MocapDataset.save_as_bvh(animation_pred,
                                 dataset_name=self.dataset + '_new',
                                 subset_name='test')
        animation = {
            'joint_names': self.joint_names,
            'joint_offsets': joint_offsets,
            'joint_parents': self.joint_parents,
            'positions': pos,
            'rotations': quat
        }
        MocapDataset.save_as_bvh(animation,
                                 dataset_name=self.dataset + '_new',
                                 subset_name='gt')
        pos_pred_np = pos_pred.contiguous().view(pos_pred.shape[0],
                                                 pos_pred.shape[1], -1).permute(0, 2, 1).\
            detach().cpu().numpy()
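
Example #8 decodes autoregressively: the text and speaker attributes are encoded once, and frames are then generated one at a time, each conditioned on everything predicted so far. A schematic sketch of that loop with a generic encoder/decoder interface is shown below; all names and signatures are placeholders, not the repository's API.

    import torch

    @torch.no_grad()
    def autoregressive_decode(encoder, decoder, cond_inputs, quats_sos, T):
        """Generate T frames of joint rotations one step at a time.

        encoder(cond_inputs) -> latent conditioning tensor (computed once).
        decoder(latent, quat_history) -> rotations for every frame seen so far;
        only the last frame of each call is kept and appended to the history.
        quats_sos: (1, V * D) start-of-sequence pose used to seed frame 0.
        """
        latent = encoder(cond_inputs)
        batch_size = latent.shape[0]
        quat_pred = torch.zeros(batch_size, T, quats_sos.shape[-1],
                                device=latent.device)
        quat_pred[:, 0] = quats_sos.expand(batch_size, -1)
        for t in range(1, T):
            step_out = decoder(latent, quat_pred[:, 0:t])   # (batch, t, V * D)
            quat_pred[:, t] = step_out[:, -1]               # keep only the newest frame
        return quat_pred
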