Example #1
 def get_quats_sos_and_eos(self):
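     # Compute a mean start-of-sequence (SOS) and end-of-sequence (EOS)
     # quaternion per joint from the first and last frames of all training
     # samples (via an eigenvector of the quaternion outer-product matrix),
     # cache them to an .npz file, and prepend/append the corresponding
     # forward-kinematics poses and affective features to every sample.
     # Note: data_path is assumed to be available from the enclosing scope.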
     quats_sos_and_eos_file = os.path.join(data_path,
                                           'quats_sos_and_eos.npz')
     keys = list(self.data_loader['train'].keys())
     num_samples = len(self.data_loader['train'])
     try:
         mean_quats_sos = np.load(quats_sos_and_eos_file,
                                  allow_pickle=True)['quats_sos']
         mean_quats_eos = np.load(quats_sos_and_eos_file,
                                  allow_pickle=True)['quats_eos']
     except FileNotFoundError:
         mean_quats_sos = np.zeros((self.V, self.D))
         mean_quats_eos = np.zeros((self.V, self.D))
         for j in range(self.V):
             quats_sos = np.zeros((self.D, num_samples))
             quats_eos = np.zeros((self.D, num_samples))
             for s in range(num_samples):
                 quats_sos[:, s] = self.data_loader['train'][
                     keys[s]]['rotations'][0, j]
                 quats_eos[:, s] = self.data_loader['train'][
                     keys[s]]['rotations'][-1, j]
             _, sos_eig_vectors = np.linalg.eig(
                 np.dot(quats_sos, quats_sos.T))
             mean_quats_sos[j] = sos_eig_vectors[:, 0]
             _, eos_eig_vectors = np.linalg.eig(
                 np.dot(quats_eos, quats_eos.T))
             mean_quats_eos[j] = eos_eig_vectors[:, 0]
         np.savez_compressed(quats_sos_and_eos_file,
                             quats_sos=mean_quats_sos,
                             quats_eos=mean_quats_eos)
     mean_quats_sos = torch.from_numpy(mean_quats_sos).unsqueeze(0)
     mean_quats_eos = torch.from_numpy(mean_quats_eos).unsqueeze(0)
     for s in range(num_samples):
         pos_sos = \
             MocapDataset.forward_kinematics(mean_quats_sos.unsqueeze(0),
                                             torch.from_numpy(self.data_loader['train'][keys[s]]
                                             ['positions'][0:1, 0]).double().unsqueeze(0),
                                             self.joint_parents,
                                             torch.from_numpy(self.data_loader['train'][keys[s]]['joints_dict']
                                             ['joints_offsets_all']).unsqueeze(0)).squeeze(0).numpy()
         affs_sos = MocapDataset.get_mpi_affective_features(pos_sos)
         pos_eos = \
             MocapDataset.forward_kinematics(mean_quats_eos.unsqueeze(0),
                                             torch.from_numpy(self.data_loader['train'][keys[s]]
                                             ['positions'][-1:, 0]).double().unsqueeze(0),
                                             self.joint_parents,
                                             torch.from_numpy(self.data_loader['train'][keys[s]]['joints_dict']
                                             ['joints_offsets_all']).unsqueeze(0)).squeeze(0).numpy()
         affs_eos = MocapDataset.get_mpi_affective_features(pos_eos)
         self.data_loader['train'][keys[s]]['positions'] = \
             np.concatenate((pos_sos, self.data_loader['train'][keys[s]]['positions'], pos_eos), axis=0)
         self.data_loader['train'][keys[s]]['affective_features'] = \
             np.concatenate((affs_sos, self.data_loader['train'][keys[s]]['affective_features'], affs_eos),
                            axis=0)
     return mean_quats_sos, mean_quats_eos
Example #2
 def compute_next_traj_point_sans_markers(self, pos_last, quat_next, z_pred, rs_pred):
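     # Predict the next root trajectory point on the ground (XZ) plane: run
     # forward kinematics with the predicted quaternions, estimate the facing
     # direction from the root and the two shoulder joints, and advance the
     # previous root position by the predicted root speed along that direction.
     # Note: z_pred and the locally built offsets tensor are currently unused.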
     # pos_next = torch.zeros_like(pos_last)
     offsets = torch.from_numpy(self.mocap.joint_offsets).cuda().float(). \
         unsqueeze(0).unsqueeze(0).repeat(pos_last.shape[0], pos_last.shape[1], 1, 1)
     pos_next = MocapDataset.forward_kinematics(quat_next.contiguous().view(quat_next.shape[0],
                                                                            quat_next.shape[1], -1, self.D),
                                                pos_last[:, :, 0],
                                                self.joint_parents,
                                                torch.from_numpy(self.joint_offsets).float().cuda())
     # for joint in range(1, self.V):
     #     pos_next[:, :, joint] = qrot(quat_copy[:, :, joint - 1], offsets[:, :, joint]) \
     #                             + pos_next[:, :, self.mocap.joint_parents[joint]]
     root = pos_next[:, :, 0]
     l_shoulder = pos_next[:, :, 18]
     r_shoulder = pos_next[:, :, 25]
     facing = torch.cross(l_shoulder - root, r_shoulder - root, dim=-1)[..., [0, 2]]
     facing /= (torch.norm(facing, dim=-1)[..., None] + 1e-9)
     return rs_pred * facing + pos_last[:, :, 0, [0, 2]]
Example #3
    def generate_motion(self,
                        load_saved_model=True,
                        samples_to_generate=10,
                        epoch=None,
                        randomized=True,
                        animations_as_videos=False):

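        # Generate motion for a few test samples: optionally load a saved
        # checkpoint, predict joint quaternions from the text and annotation
        # attributes using an EOS-initialized prelude, enforce quaternion
        # continuity with qfix, run forward kinematics with a fixed root, and
        # save both the predictions and the ground truth as BVH (optionally
        # also rendering the predictions as videos).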
        if epoch is None:
            epoch = 'best'
        if load_saved_model:
            self.load_model_at_epoch(epoch=epoch)
        self.model.eval()
        test_loader = self.data_loader['test']

        joint_offsets, pos, affs, quat, quat_valid_idx, \
            text, text_valid_idx, intended_emotion, intended_polarity, \
            acting_task, gender, age, handedness, \
            native_tongue = self.return_batch([samples_to_generate], test_loader, randomized=randomized)
        with torch.no_grad():
            joint_lengths = torch.norm(joint_offsets, dim=-1)
            scales, _ = torch.max(joint_lengths, dim=-1)
            quat_prelude = self.quats_eos.view(1, -1).cuda() \
                .repeat(self.T - 1, 1).unsqueeze(0).repeat(quat.shape[0], 1, 1).float()
            quat_prelude[:, -1] = quat[:, 0].clone()
            quat_pred, quat_pred_pre_norm = self.model(
                text, intended_emotion, intended_polarity, acting_task, gender,
                age, handedness, native_tongue, quat_prelude,
                joint_lengths / scales[..., None])
            for s in range(len(quat_pred)):
                quat_pred[s] = qfix(
                    quat_pred[s].view(quat_pred[s].shape[0], self.V,
                                      -1)).view(quat_pred[s].shape[0], -1)

            root_pos = torch.zeros(quat_pred.shape[0], quat_pred.shape[1],
                                   self.C).cuda()
            pos_pred = MocapDataset.forward_kinematics(
                quat_pred.contiguous().view(quat_pred.shape[0],
                                            quat_pred.shape[1], -1, self.D),
                root_pos, self.joint_parents,
                torch.cat((root_pos[:, 0:1], joint_offsets),
                          dim=1).unsqueeze(1))

        animation_pred = {
            'joint_names': self.joint_names,
            'joint_offsets': joint_offsets,
            'joint_parents': self.joint_parents,
            'positions': pos_pred,
            'rotations': quat_pred
        }
        MocapDataset.save_as_bvh(animation_pred,
                                 dataset_name=self.dataset + '_glove',
                                 subset_name='test')
        animation = {
            'joint_names': self.joint_names,
            'joint_offsets': joint_offsets,
            'joint_parents': self.joint_parents,
            'positions': pos,
            'rotations': quat
        }
        MocapDataset.save_as_bvh(animation,
                                 dataset_name=self.dataset + '_glove',
                                 subset_name='gt')

        if animations_as_videos:
            pos_pred_np = pos_pred.contiguous().view(pos_pred.shape[0],
                                                     pos_pred.shape[1], -1).permute(0, 2, 1).\
                detach().cpu().numpy()
            display_animations(pos_pred_np,
                               self.joint_parents,
                               save=True,
                               dataset_name=self.dataset,
                               subset_name='epoch_' +
                               str(self.best_loss_epoch),
                               overwrite=True)
Example #4
    def per_eval(self):

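        # Evaluate on the test set: predict quaternions from an EOS-initialized
        # prelude, accumulate the quaternion-norm, quaternion-angle,
        # root-relative position reconstruction, and affective-feature losses,
        # and record whether the mean loss improved on the best epoch so far.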
        self.model.eval()
        test_loader = self.data_loader['test']
        eval_loss = 0.
        N = 0.

        for joint_offsets, pos, affs, quat, quat_valid_idx, \
            text, text_valid_idx, intended_emotion, intended_polarity, \
            acting_task, gender, age, handedness, \
                native_tongue in self.yield_batch(self.args.batch_size, test_loader):
            with torch.no_grad():
                joint_lengths = torch.norm(joint_offsets, dim=-1)
                scales, _ = torch.max(joint_lengths, dim=-1)
                quat_prelude = self.quats_eos.view(1, -1).cuda() \
                    .repeat(self.T - 1, 1).unsqueeze(0).repeat(quat.shape[0], 1, 1).float()
                quat_prelude[:, -1] = quat[:, 0].clone()
                quat_pred, quat_pred_pre_norm = self.model(
                    text, intended_emotion, intended_polarity, acting_task,
                    gender, age, handedness, native_tongue, quat_prelude,
                    joint_lengths / scales[..., None])
                quat_pred_pre_norm = quat_pred_pre_norm.view(
                    quat_pred_pre_norm.shape[0], quat_pred_pre_norm.shape[1],
                    -1, self.D)
                quat_norm_loss = self.args.quat_norm_reg *\
                                 torch.mean((torch.sum(quat_pred_pre_norm ** 2, dim=-1) - 1) ** 2)

                quat_loss, quat_derv_loss = losses.quat_angle_loss(
                    quat_pred, quat[:, 1:], quat_valid_idx[:, 1:], self.V,
                    self.D, self.lower_body_start, self.args.upper_body_weight)
                quat_loss *= self.args.quat_reg

                root_pos = torch.zeros(quat_pred.shape[0], quat_pred.shape[1],
                                       self.C).cuda()
                pos_pred = MocapDataset.forward_kinematics(
                    quat_pred.contiguous().view(quat_pred.shape[0],
                                                quat_pred.shape[1], -1,
                                                self.D), root_pos,
                    self.joint_parents,
                    torch.cat((root_pos[:, 0:1], joint_offsets),
                              dim=1).unsqueeze(1))
                affs_pred = MocapDataset.get_mpi_affective_features(pos_pred)

                row_sums = quat_valid_idx.sum(1,
                                              keepdim=True) * self.D * self.V
                row_sums[row_sums == 0.] = 1.

                shifted_pos = pos - pos[:, :, 0:1]
                shifted_pos_pred = pos_pred - pos_pred[:, :, 0:1]

                recons_loss = self.recons_loss_func(shifted_pos_pred,
                                                    shifted_pos[:, 1:])
                # recons_loss = torch.abs(shifted_pos_pred[:, 1:] - shifted_pos[:, 1:]).sum(-1)
                # recons_loss = self.args.upper_body_weight * (recons_loss[:, :, :self.lower_body_start].sum(-1)) + \
                #               recons_loss[:, :, self.lower_body_start:].sum(-1)
                # recons_loss = self.args.recons_reg * torch.mean(
                #     (recons_loss * quat_valid_idx[:, 1:]).sum(-1) / row_sums)
                #
                # recons_derv_loss = torch.abs(shifted_pos_pred[:, 2:] - shifted_pos_pred[:, 1:-1] -
                #                              shifted_pos[:, 2:] + shifted_pos[:, 1:-1]).sum(-1)
                # recons_derv_loss = self.args.upper_body_weight * \
                #                    (recons_derv_loss[:, :, :self.lower_body_start].sum(-1)) + \
                #                    recons_derv_loss[:, :, self.lower_body_start:].sum(-1)
                # recons_derv_loss = 2. * self.args.recons_reg * \
                #                    torch.mean((recons_derv_loss * quat_valid_idx[:, 2:]).sum(-1) / row_sums)
                #
                # affs_loss = torch.abs(affs[:, 1:] - affs_pred[:, 1:]).sum(-1)
                # affs_loss = self.args.affs_reg * torch.mean((affs_loss * quat_valid_idx[:, 1:]).sum(-1) / row_sums)
                affs_loss = self.affs_loss_func(affs_pred, affs[:, 1:])

                eval_loss += quat_norm_loss + quat_loss + recons_loss + affs_loss
                # eval_loss += quat_norm_loss + quat_loss + recons_loss + recons_derv_loss + affs_loss
                N += quat.shape[0]

        eval_loss /= N
        self.epoch_info['mean_loss'] = eval_loss
        if self.epoch_info['mean_loss'] < self.best_loss and self.meta_info[
                'epoch'] > self.min_train_epochs:
            self.best_loss = self.epoch_info['mean_loss']
            self.best_loss_epoch = self.meta_info['epoch']
            self.loss_updated = True
        else:
            self.loss_updated = False
        self.show_epoch_info()
Example #5
    def per_train(self):

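        # Train for one epoch: for every mini-batch, predict quaternions from
        # an EOS-initialized prelude, combine the quaternion-norm, quaternion,
        # reconstruction, and affective-feature losses, back-propagate inside
        # an anomaly-detection context, then decay the learning rate and the
        # teacher-forcing rate at the end of the epoch.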
        self.model.train()
        train_loader = self.data_loader['train']
        batch_loss = 0.
        N = 0.

        for joint_offsets, pos, affs, quat, quat_valid_idx,\
            text, text_valid_idx, intended_emotion, intended_polarity,\
                acting_task, gender, age, handedness,\
                native_tongue in self.yield_batch(self.args.batch_size, train_loader):
            quat_prelude = self.quats_eos.view(1, -1).cuda() \
                .repeat(self.T - 1, 1).unsqueeze(0).repeat(quat.shape[0], 1, 1).float()
            quat_prelude[:, -1] = quat[:, 0].clone()

            self.optimizer.zero_grad()
            with torch.autograd.detect_anomaly():
                joint_lengths = torch.norm(joint_offsets, dim=-1)
                scales, _ = torch.max(joint_lengths, dim=-1)
                quat_pred, quat_pred_pre_norm = self.model(
                    text, intended_emotion, intended_polarity, acting_task,
                    gender, age, handedness, native_tongue, quat_prelude,
                    joint_lengths / scales[..., None])

                quat_pred_pre_norm = quat_pred_pre_norm.view(
                    quat_pred_pre_norm.shape[0], quat_pred_pre_norm.shape[1],
                    -1, self.D)
                quat_norm_loss = self.args.quat_norm_reg *\
                                 torch.mean((torch.sum(quat_pred_pre_norm ** 2, dim=-1) - 1) ** 2)

                quat_loss, quat_derv_loss = losses.quat_angle_loss(
                    quat_pred, quat[:, 1:], quat_valid_idx[:, 1:], self.V,
                    self.D, self.lower_body_start, self.args.upper_body_weight)
                quat_loss *= self.args.quat_reg

                root_pos = torch.zeros(quat_pred.shape[0], quat_pred.shape[1],
                                       self.C).cuda()
                pos_pred = MocapDataset.forward_kinematics(
                    quat_pred.contiguous().view(quat_pred.shape[0],
                                                quat_pred.shape[1], -1,
                                                self.D), root_pos,
                    self.joint_parents,
                    torch.cat((root_pos[:, 0:1], joint_offsets),
                              dim=1).unsqueeze(1))
                affs_pred = MocapDataset.get_mpi_affective_features(pos_pred)

                # row_sums = quat_valid_idx.sum(1, keepdim=True) * self.D * self.V
                # row_sums[row_sums == 0.] = 1.

                shifted_pos = pos - pos[:, :, 0:1]
                shifted_pos_pred = pos_pred - pos_pred[:, :, 0:1]

                recons_loss = self.recons_loss_func(shifted_pos_pred,
                                                    shifted_pos[:, 1:])
                # recons_loss = torch.abs(shifted_pos_pred - shifted_pos[:, 1:]).sum(-1)
                # recons_loss = self.args.upper_body_weight * (recons_loss[:, :, :self.lower_body_start].sum(-1)) +\
                #               recons_loss[:, :, self.lower_body_start:].sum(-1)
                # recons_loss = self.args.recons_reg *\
                #               torch.mean((recons_loss * quat_valid_idx[:, 1:]).sum(-1) / row_sums)
                #
                # recons_derv_loss = torch.abs(shifted_pos_pred[:, 1:] - shifted_pos_pred[:, :-1] -
                #                              shifted_pos[:, 2:] + shifted_pos[:, 1:-1]).sum(-1)
                # recons_derv_loss = self.args.upper_body_weight *\
                #     (recons_derv_loss[:, :, :self.lower_body_start].sum(-1)) +\
                #                    recons_derv_loss[:, :, self.lower_body_start:].sum(-1)
                # recons_derv_loss = 2. * self.args.recons_reg *\
                #                    torch.mean((recons_derv_loss * quat_valid_idx[:, 2:]).sum(-1) / row_sums)
                #
                # affs_loss = torch.abs(affs[:, 1:] - affs_pred).sum(-1)
                # affs_loss = self.args.affs_reg * torch.mean((affs_loss * quat_valid_idx[:, 1:]).sum(-1) / row_sums)
                affs_loss = self.affs_loss_func(affs_pred, affs[:, 1:])

                train_loss = quat_norm_loss + quat_loss + recons_loss + affs_loss
                # train_loss = quat_norm_loss + quat_loss + recons_loss + recons_derv_loss + affs_loss
                train_loss.backward()
                # nn.utils.clip_grad_norm_(self.model.parameters(), self.args.gradient_clip)
                self.optimizer.step()

            # animation_pred = {
            #     'joint_names': self.joint_names,
            #     'joint_offsets': joint_offsets,
            #     'joint_parents': self.joint_parents,
            #     'positions': pos_pred,
            #     'rotations': quat_pred
            # }
            # MocapDataset.save_as_bvh(animation_pred,
            #                          dataset_name=self.dataset,
            #                          subset_name='test')
            # animation = {
            #     'joint_names': self.joint_names,
            #     'joint_offsets': joint_offsets,
            #     'joint_parents': self.joint_parents,
            #     'positions': pos,
            #     'rotations': quat
            # }
            # MocapDataset.save_as_bvh(animation,
            #                          dataset_name=self.dataset,
            #                          subset_name='gt')

            # Compute statistics
            batch_loss += train_loss.item()
            N += quat.shape[0]

            # statistics
            self.iter_info['loss'] = train_loss.data.item()
            self.iter_info['lr'] = '{:.6f}'.format(self.lr)
            self.iter_info['tf'] = '{:.6f}'.format(self.tf)
            self.show_iter_info()
            self.meta_info['iter'] += 1

        batch_loss = batch_loss / N
        self.epoch_info['mean_loss'] = batch_loss
        self.show_epoch_info()
        self.io.print_timer()
        self.adjust_lr()
        self.adjust_tf()
Example #6
    def generate_motion(self,
                        load_saved_model=True,
                        samples_to_generate=10,
                        randomized=True,
                        epoch=None,
                        animations_as_videos=False):

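        # Generate motion for test samples with teacher-forced decoding (the
        # ground-truth quaternions shifted by one frame are fed to the model),
        # fix quaternion sign flips with qfix, run forward kinematics with a
        # fixed root, and save the predictions and the root-relative ground
        # truth as BVH, reporting the wall-clock generation time.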
        if epoch is None:
            epoch = 'best'
        if load_saved_model:
            self.load_model_at_epoch(epoch=epoch)
        self.model.eval()
        test_loader = self.data_loader['test']

        start_time = time.time()
        joint_offsets, pos, affs, quat, quat_valid_idx, \
            text, text_valid_idx, perceived_emotion, perceived_polarity, \
            acting_task, gender, age, handedness, \
            native_tongue = self.return_batch([samples_to_generate], test_loader, randomized=randomized)
        with torch.no_grad():
            joint_lengths = torch.norm(joint_offsets, dim=-1)
            scales, _ = torch.max(joint_lengths, dim=-1)
            quat_pred = torch.zeros_like(quat)
            quat_pred[:, 0] = torch.cat(
                quat_pred.shape[0] * [self.quats_sos]).view(quat_pred[:,
                                                                      0].shape)

            quat_pred, quat_pred_pre_norm = self.model(
                text, perceived_emotion, perceived_polarity, acting_task,
                gender, age, handedness, native_tongue, quat[:, :-1],
                joint_lengths / scales[..., None])
            # text_latent = self.model(text, intended_emotion, intended_polarity,
            #                          acting_task, gender, age, handedness, native_tongue, only_encoder=True)
            # for t in range(1, self.T):
            #     quat_pred_curr, _ = self.model(text_latent, quat=quat_pred[:, 0:t],
            #                                    offset_lengths=joint_lengths / scales[..., None],
            #                                    only_decoder=True)
            #     quat_pred[:, t:t + 1] = quat_pred_curr[:, -1:].clone()

            # for s in range(len(quat_pred)):
            #     quat_pred[s] = qfix(quat_pred[s].view(quat_pred[s].shape[0],
            #                                           self.V, -1)).view(quat_pred[s].shape[0], -1)
            quat_pred = torch.cat((quat[:, 1:2], quat_pred), dim=1)
            quat_pred = qfix(
                quat_pred.view(quat_pred.shape[0], quat_pred.shape[1], self.V,
                               -1)).view(quat_pred.shape[0],
                                         quat_pred.shape[1], -1)
            quat_pred = quat_pred[:, 1:]

            root_pos = torch.zeros(quat_pred.shape[0], quat_pred.shape[1],
                                   self.C).cuda()
            pos_pred = MocapDataset.forward_kinematics(
                quat_pred.contiguous().view(quat_pred.shape[0],
                                            quat_pred.shape[1], -1, self.D),
                root_pos, self.joint_parents,
                torch.cat((root_pos[:, 0:1], joint_offsets),
                          dim=1).unsqueeze(1))

        quat_np = quat.detach().cpu().numpy()
        quat_pred_np = quat_pred.detach().cpu().numpy()

        animation_pred = {
            'joint_names': self.joint_names,
            'joint_offsets': joint_offsets,
            'joint_parents': self.joint_parents,
            'positions': pos_pred,
            'rotations': quat_pred,
            'valid_idx': quat_valid_idx
        }
        MocapDataset.save_as_bvh(animation_pred,
                                 dataset_name=self.dataset,
                                 subset_name='test_epoch_{}'.format(epoch),
                                 include_default_pose=False)
        end_time = time.time()
        print('Time taken: {} secs.'.format(end_time - start_time))
        shifted_pos = pos - pos[:, :, 0:1]
        animation = {
            'joint_names': self.joint_names,
            'joint_offsets': joint_offsets,
            'joint_parents': self.joint_parents,
            'positions': shifted_pos,
            'rotations': quat,
            'valid_idx': quat_valid_idx
        }

        MocapDataset.save_as_bvh(animation,
                                 dataset_name=self.dataset,
                                 subset_name='gt',
                                 include_default_pose=False)

        if animations_as_videos:
            pos_pred_np = pos_pred.contiguous().view(pos_pred.shape[0],
                                                     pos_pred.shape[1], -1).permute(0, 2, 1).\
                detach().cpu().numpy()
            display_animations(pos_pred_np,
                               self.joint_parents,
                               save=True,
                               dataset_name=self.dataset,
                               subset_name='epoch_' +
                               str(self.best_loss_epoch),
                               overwrite=True)
Example #7
    def forward_pass(self, joint_offsets, pos, affs, quat, quat_sos,
                     quat_valid_idx, text, text_valid_idx, perceived_emotion,
                     perceived_polarity, acting_task, gender, age, handedness,
                     native_tongue):
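        # Compute the total loss for one batch: predict quaternions from the
        # text, the perceived-attribute inputs, and the quat_sos sequence, then
        # combine the quaternion-norm, quaternion-angle, masked position
        # reconstruction, and masked affective-feature losses.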
        self.optimizer.zero_grad()
        with torch.autograd.detect_anomaly():
            joint_lengths = torch.norm(joint_offsets, dim=-1)
            scales, _ = torch.max(joint_lengths, dim=-1)
            # text_latent = self.model(text, perceived_emotion, perceived_polarity,
            #                          acting_task, gender, age, handedness, native_tongue,
            #                          only_encoder=True)
            # quat_pred = torch.zeros_like(quat)
            # quat_pred_pre_norm = torch.zeros_like(quat)
            # quat_in = quat_sos[:, :self.T_steps]
            # quat_valid_idx_max = torch.max(torch.sum(quat_valid_idx, dim=-1))
            # for t in range(0, self.T, self.T_steps):
            #     if t > quat_valid_idx_max:
            #         break
            #     quat_pred[:, t:min(self.T, t + self.T_steps)],\
            #         quat_pred_pre_norm[:, t:min(self.T, t + self.T_steps)] =\
            #         self.model(text_latent, quat=quat_in, offset_lengths= joint_lengths / scales[..., None],
            #                    only_decoder=True)
            #     if torch.rand(1) > self.tf:
            #         if t + self.T_steps * 2 >= self.T:
            #             quat_in = quat_pred[:, -(self.T - t - self.T_steps):].clone()
            #         else:
            #             quat_in = quat_pred[:, t:t + self.T_steps].clone()
            #     else:
            #         if t + self.T_steps * 2 >= self.T:
            #             quat_in = quat[:, -(self.T - t - self.T_steps):].clone()
            #         else:
            #             quat_in = quat[:, t:min(self.T, t + self.T_steps)].clone()
            quat_pred, quat_pred_pre_norm = self.model(
                text, perceived_emotion, perceived_polarity, acting_task,
                gender, age, handedness, native_tongue, quat_sos[:, :-1],
                joint_lengths / scales[..., None])
            quat_fixed = qfix(quat.contiguous().view(
                quat.shape[0], quat.shape[1], -1,
                self.D)).contiguous().view(quat.shape[0], quat.shape[1], -1)
            quat_pred = qfix(quat_pred.contiguous().view(
                quat_pred.shape[0], quat_pred.shape[1], -1,
                self.D)).contiguous().view(quat_pred.shape[0],
                                           quat_pred.shape[1], -1)

            quat_pred_pre_norm = quat_pred_pre_norm.view(
                quat_pred_pre_norm.shape[0], quat_pred_pre_norm.shape[1], -1,
                self.D)
            quat_norm_loss = self.args.quat_norm_reg *\
                torch.mean((torch.sum(quat_pred_pre_norm ** 2, dim=-1) - 1) ** 2)

            # quat_loss, quat_derv_loss = losses.quat_angle_loss(quat_pred, quat_fixed,
            #                                                    quat_valid_idx[:, 1:],
            #                                                    self.V, self.D,
            #                                                    self.lower_body_start,
            #                                                    self.args.upper_body_weight)
            quat_loss, quat_derv_loss = losses.quat_angle_loss(
                quat_pred, quat_fixed[:, 1:], quat_valid_idx[:, 1:], self.V,
                self.D, self.lower_body_start, self.args.upper_body_weight)
            quat_loss *= self.args.quat_reg

            root_pos = torch.zeros(quat_pred.shape[0], quat_pred.shape[1],
                                   self.C).cuda()
            pos_pred = MocapDataset.forward_kinematics(
                quat_pred.contiguous().view(quat_pred.shape[0],
                                            quat_pred.shape[1], -1, self.D),
                root_pos, self.joint_parents,
                torch.cat((root_pos[:, 0:1], joint_offsets),
                          dim=1).unsqueeze(1))
            affs_pred = MocapDataset.get_mpi_affective_features(pos_pred)

            row_sums = quat_valid_idx.sum(1, keepdim=True) * self.D * self.V
            row_sums[row_sums == 0.] = 1.

            shifted_pos = pos - pos[:, :, 0:1]
            shifted_pos_pred = pos_pred - pos_pred[:, :, 0:1]

            # recons_loss = self.recons_loss_func(shifted_pos_pred, shifted_pos)
            # recons_arms = self.recons_loss_func(shifted_pos_pred[:, :, 7:15], shifted_pos[:, :, 7:15])
            recons_loss = self.recons_loss_func(shifted_pos_pred,
                                                shifted_pos[:, 1:])
            # recons_arms = self.recons_loss_func(shifted_pos_pred[:, :, 7:15], shifted_pos[:, 1:, 7:15])
            # recons_loss = torch.abs(shifted_pos_pred - shifted_pos[:, 1:]).sum(-1)
            # recons_loss = self.args.upper_body_weight * (recons_loss[:, :, :self.lower_body_start].sum(-1)) +\
            #               recons_loss[:, :, self.lower_body_start:].sum(-1)
            recons_loss = self.args.recons_reg *\
                torch.mean((recons_loss * quat_valid_idx[:, 1:]).sum(-1) / row_sums)
            #
            # recons_derv_loss = torch.abs(shifted_pos_pred[:, 1:] - shifted_pos_pred[:, :-1] -
            #                              shifted_pos[:, 2:] + shifted_pos[:, 1:-1]).sum(-1)
            # recons_derv_loss = self.args.upper_body_weight *\
            #     (recons_derv_loss[:, :, :self.lower_body_start].sum(-1)) +\
            #                    recons_derv_loss[:, :, self.lower_body_start:].sum(-1)
            # recons_derv_loss = 2. * self.args.recons_reg *\
            #                    torch.mean((recons_derv_loss * quat_valid_idx[:, 2:]).sum(-1) / row_sums)
            #
            # affs_loss = torch.abs(affs[:, 1:] - affs_pred).sum(-1)
            affs_loss = self.affs_loss_func(affs_pred, affs[:, 1:])
            affs_loss = self.args.affs_reg * torch.mean(
                (affs_loss * quat_valid_idx[:, 1:]).sum(-1) / row_sums)
            # affs_loss = self.affs_loss_func(affs_pred, affs)

            total_loss = quat_norm_loss + quat_loss + recons_loss + affs_loss
            # train_loss = quat_norm_loss + quat_loss + recons_loss + recons_derv_loss + affs_loss

        return total_loss
Example #8
File: loader.py  Project: Tanmay-r/STEP
def load_data_MPI(_path, _ftype, coords, joints, cycles=3):

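    # Load the MPI mocap data: read each BVH/tag file pair, flatten the joint
    # positions per frame, map the language tag (Hindi/German/English) to a
    # class label, tile every sequence to the maximum length for the requested
    # number of cycles, and return the full data plus a class-balanced train
    # split and a held-out test split.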
    # Counts: 'Hindi': 292, 'German': 955, 'English': 200
    bvhDirectory = os.path.join(_path, "bvh")
    tagDirectory = os.path.join(_path, "tags")

    data_list = {}
    num_samples = 0
    time_steps = 0
    labels_list = {}
    fileIDs = []
    for filenum in range(1, 1452):
        filename = str(filenum).zfill(6)
        # print(filenum)
        if not os.path.exists(os.path.join(tagDirectory, filename + ".txt")):
            print(os.path.join(tagDirectory, filename + ".txt"), " not found!")
            continue
        names, parents, offsets, positions, rotations = MocapDataset.load_bvh(
            os.path.join(bvhDirectory, filename + ".bvh"))
        tag, text = MocapDataset.read_tags(
            os.path.join(tagDirectory, filename + ".txt"))
        num_samples += 1
        positions = np.reshape(
            positions,
            (positions.shape[0], positions.shape[1] * positions.shape[2]))
        data_list[filenum] = list(positions)
        time_steps_curr = len(positions)
        if time_steps_curr > time_steps:
            time_steps = time_steps_curr
        if "Hindi" in tag:
            labels_list[filenum] = 0
        elif "German" in tag:
            labels_list[filenum] = 1
        elif "English" in tag:
            labels_list[filenum] = 2
        else:
            print("ERROR: ", tag)
        fileIDs.append(filenum)

    labels = np.empty(num_samples)
    data = np.empty((num_samples, time_steps * cycles, joints * coords))
    index = 0
    for si in fileIDs:
        data_list_curr = np.tile(
            data_list[si], (int(np.ceil(time_steps / len(data_list[si]))), 1))
        if (data_list_curr.shape[1] != 69):
            continue
        for ci in range(cycles):
            data[index, time_steps * ci:time_steps *
                 (ci + 1), :] = data_list_curr[0:time_steps]
        labels[index] = labels_list[si]
        index += 1
    data = data[:index]
    labels = labels[:index]
    print(index, num_samples)

    # data = common.get_affective_features(np.reshape(data, (data.shape[0], data.shape[1], joints, coords)))[:, :, :48]
    data_train, data_test, labels_train, labels_test = train_test_split(
        data, labels, test_size=0.1)
    data_train, labels_train = balance_classes(data_train, labels_train)

    return data, labels, data_train, labels_train, data_test, labels_test
Example #9
def load_data(_path, dataset, frame_drop=1, add_mirrored=False):
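    # Load a cached, frame-dropped data dictionary if one exists; otherwise
    # (for the 'mpi' dataset) parse every BVH/tag file pair, down-sample the
    # frames, compute affective features, encode the annotation tags as
    # one-hot vectors and VAD values, and cache the result to an .npz file.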
    data_path = os.path.join(_path, dataset)
    data_dict_file = os.path.join(data_path,
                                  'data_dict_drop_' + str(frame_drop) + '.npz')
    try:
        data_dict = np.load(data_dict_file,
                            allow_pickle=True)['data_dict'].item()
        tag_categories = list(
            np.load(data_dict_file, allow_pickle=True)['tag_categories'])
        max_text_length = np.load(data_dict_file,
                                  allow_pickle=True)['max_text_length'].item()
        max_time_steps = np.load(data_dict_file,
                                 allow_pickle=True)['max_time_steps'].item()
        print('Data file found. Returning data.')
    except FileNotFoundError:
        data_dict = []
        tag_categories = []
        max_text_length = 0.
        max_time_steps = 0.
        if dataset == 'mpi':
            channel_map = {
                'Xrotation': 'x',
                'Yrotation': 'y',
                'Zrotation': 'z'
            }
            data_dict = dict()
            tag_names = []
            with open(os.path.join(data_path, 'tag_names.txt')) as names_file:
                for line in names_file.readlines():
                    line = line[:-1]
                    tag_names.append(line)
            id = tag_names.index('ID')
            relevant_tags = [
                'Intended emotion', 'Intended polarity', 'Perceived category',
                'Perceived polarity', 'Acting task', 'Gender', 'Age',
                'Handedness', 'Native tongue', 'Text'
            ]
            tag_categories = [[] for _ in range(len(relevant_tags) - 1)]
            tag_files = glob.glob(os.path.join(data_path, 'tags/*.txt'))
            num_files = len(tag_files)
            for tag_file in tag_files:
                tag_data = []
                with open(tag_file) as f:
                    for line in f.readlines():
                        line = line[:-1]
                        tag_data.append(line)
                for category in range(len(tag_categories)):
                    tag_to_append = relevant_tags[category]
                    if tag_data[tag_names.index(
                            tag_to_append)] not in tag_categories[category]:
                        tag_categories[category].append(
                            tag_data[tag_names.index(tag_to_append)])

            for data_counter, tag_file in enumerate(tag_files):
                tag_data = []
                with open(tag_file) as f:
                    for line in f.readlines():
                        line = line[:-1]
                        tag_data.append(line)
                bvh_file = os.path.join(data_path,
                                        'bvh/' + tag_data[id] + '.bvh')
                names, parents, offsets,\
                    positions, rotations, base_fps = MocapDataset.load_bvh(bvh_file, channel_map)
                positions_down_sampled = positions[1::frame_drop]
                rotations_down_sampled = rotations[1::frame_drop]
                if len(positions_down_sampled) > max_time_steps:
                    max_time_steps = len(positions_down_sampled)
                joints_dict = dict()
                joints_dict['joints_to_model'] = np.arange(len(parents))
                joints_dict['joints_parents_all'] = parents
                joints_dict['joints_parents'] = parents
                joints_dict['joints_names_all'] = names
                joints_dict['joints_names'] = names
                joints_dict['joints_offsets_all'] = offsets
                joints_dict['joints_left'] = [
                    idx for idx, name in enumerate(names)
                    if 'left' in name.lower()
                ]
                joints_dict['joints_right'] = [
                    idx for idx, name in enumerate(names)
                    if 'right' in name.lower()
                ]
                data_dict[tag_data[id]] = dict()
                data_dict[tag_data[id]]['joints_dict'] = joints_dict
                data_dict[tag_data[id]]['positions'] = positions_down_sampled
                data_dict[tag_data[id]]['rotations'] = rotations_down_sampled
                data_dict[tag_data[id]]['affective_features'] =\
                    MocapDataset.get_mpi_affective_features(positions_down_sampled)
                for tag_index, tag_name in enumerate(relevant_tags):
                    if tag_name.lower() == 'text':
                        data_dict[tag_data[id]][tag_name] =\
                            tag_data[tag_names.index(tag_name)].replace(' s ', '\'s ').replace(' t ', '\'t ')
                        text_vad = []
                        words = data_dict[tag_data[id]][tag_name].split(' ')
                        for lexeme in words:
                            if lexeme.isalpha():
                                if len(lexeme) == 1 and not (
                                        lexeme.lower() == 'a'
                                        or lexeme.lower() == 'i'):
                                    continue
                                text_vad.append(get_vad(lexeme))
                        try:
                            data_dict[tag_data[id]][
                                tag_name + ' VAD'] = np.stack(text_vad)
                            data_dict[tag_data[id]]['best_tts_rate'],\
                                data_dict[tag_data[id]]['gesture_splits'] =\
                                get_gesture_splits(data_dict[tag_data[id]][tag_name], words,
                                                   len(data_dict[tag_data[id]]['positions']),
                                                   base_fps / frame_drop)
                        except ValueError:
                            data_dict[tag_data[id]][tag_name +
                                                    ' VAD'] = np.zeros((0, 3))
                        text_length = len(data_dict[tag_data[id]][tag_name])
                        if text_length > max_text_length:
                            max_text_length = text_length
                        continue
                    if tag_name.lower() == 'age':
                        data_dict[tag_data[id]][tag_name] = float(
                            tag_data[tag_names.index(tag_name)]) / 100.
                        continue
                    if tag_name == 'Perceived category':
                        categories = tag_categories[0]
                    elif tag_name == 'Perceived polarity':
                        categories = tag_categories[1]
                    else:
                        categories = tag_categories[tag_index]
                    data_dict[tag_data[id]][tag_name] = to_one_hot(
                        tag_data[tag_names.index(tag_name)], categories)
                    if tag_name == 'Intended emotion' or tag_name == 'Perceived category':
                        data_dict[tag_data[id]][tag_name + ' VAD'] = get_vad(
                            tag_data[tag_names.index(tag_name)])
                print('\rData file not found. Processing file {}/{}: {:3.2f}%'.
                      format(data_counter + 1, num_files,
                             data_counter * 100. / num_files),
                      end='')
            print('\rData file not found. Processing files: done. Saving...',
                  end='')
            np.savez_compressed(data_dict_file,
                                data_dict=data_dict,
                                tag_categories=tag_categories,
                                max_text_length=max_text_length,
                                max_time_steps=max_time_steps)
            print('done. Returning data.')
        elif dataset == 'creative_it':
            mocap_data_dirs = os.listdir(os.path.join(data_path, 'mocap'))
            for mocap_dir in mocap_data_dirs:
                mocap_data_files = glob.glob(
                    os.path.join(data_path, 'mocap/' + mocap_dir + '/*.txt'))
        else:
            raise FileNotFoundError('Dataset not found.')

    return data_dict, tag_categories, max_text_length, max_time_steps
Example #10
    def __init__(self, args, dataset, data_loader, T, V, C, D, A, S,
                 joints_dict, joint_names, joint_offsets, joint_parents,
                 num_labels, prefix_length, target_length,
                 min_train_epochs=20, generate_while_train=False,
                 save_path=None, device='cuda:0'):

        self.args = args
        self.dataset = dataset
        self.mocap = MocapDataset(V, C, np.arange(V), joints_dict)
        self.joint_names = joint_names
        self.joint_offsets = joint_offsets
        self.joint_parents = joint_parents
        self.device = device
        self.data_loader = data_loader
        self.num_labels = num_labels
        self.result = dict()
        self.iter_info = dict()
        self.epoch_info = dict()
        self.meta_info = dict(epoch=0, iter=0)
        self.io = IO(
            self.args.work_dir,
            save_log=self.args.save_log,
            print_log=self.args.print_log)

        # model
        self.T = T
        self.V = V
        self.C = C
        self.D = D
        self.A = A
        self.S = S
        self.O = 4
        self.Z = 1
        self.RS = 1
        self.o_scale = 10.
        self.prefix_length = prefix_length
        self.target_length = target_length
        self.model = quater_emonet.QuaterEmoNet(V, D, S, A, self.O, self.Z, self.RS, num_labels[0])
        self.model.cuda(device)
        self.orient_h = None
        self.quat_h = None
        self.z_rs_loss_func = nn.L1Loss()
        self.affs_loss_func = nn.L1Loss()
        self.spline_loss_func = nn.L1Loss()
        self.best_loss = math.inf
        self.loss_updated = False
        self.mean_ap_updated = False
        self.step_epochs = [math.ceil(float(self.args.num_epoch * x)) for x in self.args.step]
        self.best_loss_epoch = None
        self.min_train_epochs = min_train_epochs

        # generate
        self.generate_while_train = generate_while_train
        self.save_path = save_path

        # optimizer
        if self.args.optimizer == 'SGD':
            self.optimizer = optim.SGD(
                self.model.parameters(),
                lr=self.args.base_lr,
                momentum=0.9,
                nesterov=self.args.nesterov,
                weight_decay=self.args.weight_decay)
        elif self.args.optimizer == 'Adam':
            self.optimizer = optim.Adam(
                self.model.parameters(),
                lr=self.args.base_lr)
                # weight_decay=self.args.weight_decay)
        else:
            raise ValueError()
        self.lr = self.args.base_lr
        self.tf = self.args.base_tr
Example #11
class Processor(object):
    """
        Processor for emotive gait generation
    """

    def __init__(self, args, dataset, data_loader, T, V, C, D, A, S,
                 joints_dict, joint_names, joint_offsets, joint_parents,
                 num_labels, prefix_length, target_length,
                 min_train_epochs=20, generate_while_train=False,
                 save_path=None, device='cuda:0'):

        self.args = args
        self.dataset = dataset
        self.mocap = MocapDataset(V, C, np.arange(V), joints_dict)
        self.joint_names = joint_names
        self.joint_offsets = joint_offsets
        self.joint_parents = joint_parents
        self.device = device
        self.data_loader = data_loader
        self.num_labels = num_labels
        self.result = dict()
        self.iter_info = dict()
        self.epoch_info = dict()
        self.meta_info = dict(epoch=0, iter=0)
        self.io = IO(
            self.args.work_dir,
            save_log=self.args.save_log,
            print_log=self.args.print_log)

        # model
        self.T = T
        self.V = V
        self.C = C
        self.D = D
        self.A = A
        self.S = S
        self.O = 4
        self.Z = 1
        self.RS = 1
        self.o_scale = 10.
        self.prefix_length = prefix_length
        self.target_length = target_length
        self.model = quater_emonet.QuaterEmoNet(V, D, S, A, self.O, self.Z, self.RS, num_labels[0])
        self.model.cuda(device)
        self.orient_h = None
        self.quat_h = None
        self.z_rs_loss_func = nn.L1Loss()
        self.affs_loss_func = nn.L1Loss()
        self.spline_loss_func = nn.L1Loss()
        self.best_loss = math.inf
        self.loss_updated = False
        self.mean_ap_updated = False
        self.step_epochs = [math.ceil(float(self.args.num_epoch * x)) for x in self.args.step]
        self.best_loss_epoch = None
        self.min_train_epochs = min_train_epochs

        # generate
        self.generate_while_train = generate_while_train
        self.save_path = save_path

        # optimizer
        if self.args.optimizer == 'SGD':
            self.optimizer = optim.SGD(
                self.model.parameters(),
                lr=self.args.base_lr,
                momentum=0.9,
                nesterov=self.args.nesterov,
                weight_decay=self.args.weight_decay)
        elif self.args.optimizer == 'Adam':
            self.optimizer = optim.Adam(
                self.model.parameters(),
                lr=self.args.base_lr)
                # weight_decay=self.args.weight_decay)
        else:
            raise ValueError()
        self.lr = self.args.base_lr
        self.tf = self.args.base_tr

    def process_data(self, data, poses, quat, trans, affs):
        data = data.float().to(self.device)
        poses = poses.float().to(self.device)
        quat = quat.float().to(self.device)
        trans = trans.float().to(self.device)
        affs = affs.float().to(self.device)
        return data, poses, quat, trans, affs

    def load_best_model(self, ):
        model_name, self.best_loss_epoch, self.best_loss =\
            get_best_epoch_and_loss(self.args.work_dir)
        best_model_found = False
        try:
            loaded_vars = torch.load(os.path.join(self.args.work_dir, model_name))
            self.model.load_state_dict(loaded_vars['model_dict'])
            self.orient_h = loaded_vars['orient_h']
            self.quat_h = loaded_vars['quat_h']
            best_model_found = True
        except (FileNotFoundError, IsADirectoryError):
            print('No saved model found.')
        return best_model_found

    def adjust_lr(self):
        self.lr = self.lr * self.args.lr_decay
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr

    def adjust_tf(self):
        if self.meta_info['epoch'] > 20:
            self.tf = self.tf * self.args.tf_decay

    def show_epoch_info(self):

        print_epochs = [self.best_loss_epoch if self.best_loss_epoch is not None else 0]
        best_metrics = [self.best_loss]
        i = 0
        for k, v in self.epoch_info.items():
            self.io.print_log('\t{}: {}. Best so far: {} (epoch: {:d}).'.
                              format(k, v, best_metrics[i], print_epochs[i]))
            i += 1
        if self.args.pavi_log:
            self.io.log('train', self.meta_info['iter'], self.epoch_info)

    def show_iter_info(self):

        if self.meta_info['iter'] % self.args.log_interval == 0:
            info = '\tIter {} Done.'.format(self.meta_info['iter'])
            for k, v in self.iter_info.items():
                if isinstance(v, float):
                    info = info + ' | {}: {:.4f}'.format(k, v)
                else:
                    info = info + ' | {}: {}'.format(k, v)

            self.io.print_log(info)

            if self.args.pavi_log:
                self.io.log('train', self.meta_info['iter'], self.iter_info)

    def yield_batch(self, batch_size, dataset):
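        # Yield mini-batches for one pseudo-epoch: keys are sampled with
        # replacement, weighted by each sequence's spline size, and every batch
        # packs positions, joint quaternions, root orientation, the z-channel
        # mean and deviation, root speed, affective features, spline features,
        # and labels.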
        batch_pos = np.zeros((batch_size, self.T, self.V, self.C), dtype='float32')
        batch_quat = np.zeros((batch_size, self.T, (self.V - 1) * self.D), dtype='float32')
        batch_orient = np.zeros((batch_size, self.T, self.O), dtype='float32')
        batch_z_mean = np.zeros((batch_size, self.Z), dtype='float32')
        batch_z_dev = np.zeros((batch_size, self.T, self.Z), dtype='float32')
        batch_root_speed = np.zeros((batch_size, self.T, self.RS), dtype='float32')
        batch_affs = np.zeros((batch_size, self.T, self.A), dtype='float32')
        batch_spline = np.zeros((batch_size, self.T, self.S), dtype='float32')
        batch_labels = np.zeros((batch_size, 1, self.num_labels[0]), dtype='float32')
        pseudo_passes = (len(dataset) + batch_size - 1) // batch_size

        probs = []
        for k in dataset.keys():
            if 'spline' not in dataset[k]:
                raise KeyError('No splines found. Perhaps you forgot to compute them?')
            probs.append(dataset[k]['spline'].size())
        probs = np.array(probs) / np.sum(probs)

        for p in range(pseudo_passes):
            rand_keys = np.random.choice(len(dataset), size=batch_size, replace=True, p=probs)
            for i, k in enumerate(rand_keys):
                pos = dataset[str(k)]['positions'][:self.T]
                quat = dataset[str(k)]['rotations'][:self.T, 1:]
                orient = dataset[str(k)]['rotations'][:self.T, 0]
                affs = dataset[str(k)]['affective_features'][:self.T]
                spline, phase = Spline.extract_spline_features(dataset[str(k)]['spline'])
                spline = spline[:self.T]
                phase = phase[:self.T]
                z = dataset[str(k)]['trans_and_controls'][:, 1][:self.T]
                z_mean = np.mean(z[:self.prefix_length])
                z_dev = z - z_mean
                root_speed = dataset[str(k)]['trans_and_controls'][:, -1][:self.T]
                labels = dataset[str(k)]['labels'][:self.num_labels[0]]

                batch_pos[i] = pos
                batch_quat[i] = quat.reshape(self.T, -1)
                batch_orient[i] = orient.reshape(self.T, -1)
                batch_z_mean[i] = z_mean.reshape(-1, 1)
                batch_z_dev[i] = z_dev.reshape(self.T, -1)
                batch_root_speed[i] = root_speed.reshape(self.T, 1)
                batch_affs[i] = affs
                batch_spline[i] = spline
                batch_labels[i] = np.expand_dims(labels, axis=0)
            yield batch_pos, batch_quat, batch_orient, batch_z_mean, batch_z_dev,\
                  batch_root_speed, batch_affs, batch_spline, batch_labels

    def return_batch(self, batch_size, dataset, randomized=True):
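        # Return a single batch: if batch_size contains more than one entry,
        # treat those entries as sample keys; otherwise draw batch_size[0]
        # samples, either randomly (weighted by spline size, without
        # replacement) or as the first batch_size[0] entries of the dataset.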
        if len(batch_size) > 1:
            rand_keys = np.copy(batch_size)
            batch_size = len(batch_size)
        else:
            batch_size = batch_size[0]
            probs = []
            for k in dataset.keys():
                if 'spline' not in dataset[k]:
                    raise KeyError('No splines found. Perhaps you forgot to compute them?')
                probs.append(dataset[k]['spline'].size())
            probs = np.array(probs) / np.sum(probs)
            if randomized:
                rand_keys = np.random.choice(len(dataset), size=batch_size, replace=False, p=probs)
            else:
                rand_keys = np.arange(batch_size)

        batch_pos = np.zeros((batch_size, self.T, self.V, self.C), dtype='float32')
        batch_quat = np.zeros((batch_size, self.T, (self.V - 1) * self.D), dtype='float32')
        batch_orient = np.zeros((batch_size, self.T, self.O), dtype='float32')
        batch_z_mean = np.zeros((batch_size, self.Z), dtype='float32')
        batch_z_dev = np.zeros((batch_size, self.T, self.Z), dtype='float32')
        batch_root_speed = np.zeros((batch_size, self.T, self.RS), dtype='float32')
        batch_affs = np.zeros((batch_size, self.T, self.A), dtype='float32')
        batch_spline = np.zeros((batch_size, self.T, self.S), dtype='float32')
        batch_labels = np.zeros((batch_size, 1, self.num_labels[0]), dtype='float32')
        pseudo_passes = (len(dataset) + batch_size - 1) // batch_size

        for i, k in enumerate(rand_keys):
            pos = dataset[str(k)]['positions'][:self.T]
            quat = dataset[str(k)]['rotations'][:self.T, 1:]
            orient = dataset[str(k)]['rotations'][:self.T, 0]
            affs = dataset[str(k)]['affective_features'][:self.T]
            spline, phase = Spline.extract_spline_features(dataset[str(k)]['spline'])
            spline = spline[:self.T]
            phase = phase[:self.T]
            z = dataset[str(k)]['trans_and_controls'][:, 1][:self.T]
            z_mean = np.mean(z[:self.prefix_length])
            z_dev = z - z_mean
            root_speed = dataset[str(k)]['trans_and_controls'][:, -1][:self.T]
            labels = dataset[str(k)]['labels'][:self.num_labels[0]]

            batch_pos[i] = pos
            batch_quat[i] = quat.reshape(self.T, -1)
            batch_orient[i] = orient.reshape(self.T, -1)
            batch_z_mean[i] = z_mean.reshape(-1, 1)
            batch_z_dev[i] = z_dev.reshape(self.T, -1)
            batch_root_speed[i] = root_speed.reshape(self.T, 1)
            batch_affs[i] = affs
            batch_spline[i] = spline
            batch_labels[i] = np.expand_dims(labels, axis=0)

        return batch_pos, batch_quat, batch_orient, batch_z_mean, batch_z_dev,\
            batch_root_speed, batch_affs, batch_spline, batch_labels

    def per_train(self):

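        # Train for one epoch with an autoregressive rollout: starting from a
        # ground-truth prefix, predict one frame at a time, recompute positions
        # and derived features from the predictions, and feed predictions back
        # in place of ground truth with probability (1 - tf) (scheduled
        # sampling). The loss combines quaternion norm/angle/derivative,
        # z/root-speed, spline, and foot-speed terms.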
        self.model.train()
        train_loader = self.data_loader['train']
        batch_loss = 0.
        N = 0.

        for pos, quat, orient, z_mean, z_dev,\
                root_speed, affs, spline, labels in self.yield_batch(self.args.batch_size, train_loader):

            pos = torch.from_numpy(pos).cuda()
            orient = torch.from_numpy(orient).cuda()
            quat = torch.from_numpy(quat).cuda()
            z_mean = torch.from_numpy(z_mean).cuda()
            z_dev = torch.from_numpy(z_dev).cuda()
            root_speed = torch.from_numpy(root_speed).cuda()
            affs = torch.from_numpy(affs).cuda()
            spline = torch.from_numpy(spline).cuda()
            labels = torch.from_numpy(labels).cuda().repeat(1, quat.shape[1], 1)
            z_rs = torch.cat((z_dev, root_speed), dim=-1)
            quat_all = torch.cat((orient[:, self.prefix_length - 1:], quat[:, self.prefix_length - 1:]), dim=-1)

            pos_pred = pos.clone()
            orient_pred = orient.clone()
            quat_pred = quat.clone()
            z_rs_pred = z_rs.clone()
            affs_pred = affs.clone()
            spline_pred = spline.clone()
            pos_pred_all = pos.clone()
            orient_pred_all = orient.clone()
            quat_pred_all = quat.clone()
            z_rs_pred_all = z_rs.clone()
            affs_pred_all = affs.clone()
            spline_pred_all = spline.clone()
            orient_prenorm_terms = torch.zeros_like(orient_pred)
            quat_prenorm_terms = torch.zeros_like(quat_pred)

            # forward
            self.optimizer.zero_grad()
            for t in range(self.target_length):
                orient_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1],\
                    quat_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1],\
                    z_rs_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1],\
                    self.orient_h, self.quat_h,\
                    orient_prenorm_terms[:, self.prefix_length + t: self.prefix_length + t + 1],\
                    quat_prenorm_terms[:, self.prefix_length + t: self.prefix_length + t + 1] = \
                    self.model(
                        orient_pred[:, t:self.prefix_length + t],
                        quat_pred[:, t:self.prefix_length + t],
                        z_rs_pred[:, t:self.prefix_length + t],
                        affs_pred[:, t:self.prefix_length + t],
                        spline_pred[:, t:self.prefix_length + t],
                        labels[:, t:self.prefix_length + t],
                        orient_h=None if t == 0 else self.orient_h,
                        quat_h=None if t == 0 else self.quat_h, return_prenorm=True)
                pos_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1],\
                    affs_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1], \
                    spline_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                    self.mocap.get_predicted_features(
                        pos_pred[:, :self.prefix_length + t],
                        pos_pred[:, self.prefix_length + t:self.prefix_length + t + 1, 0, [0, 2]],
                        z_rs_pred[:, self.prefix_length + t:self.prefix_length + t + 1, 0] + z_mean,
                        orient_pred[:, self.prefix_length + t:self.prefix_length + t + 1],
                        quat_pred[:, self.prefix_length + t:self.prefix_length + t + 1])
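                # Scheduled sampling: with probability (1 - self.tf) the model's own
                # prediction is fed back in at this frame instead of the ground truth.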
                if np.random.uniform(size=1)[0] > self.tf:
                    pos_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                        pos_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1].clone()
                    orient_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                        orient_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1].clone()
                    quat_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                        quat_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1].clone()
                    z_rs_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                        z_rs_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1].clone()
                    affs_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                        affs_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1].clone()
                    spline_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                        spline_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1].clone()

            prenorm_terms = torch.cat((orient_prenorm_terms, quat_prenorm_terms), dim=-1)
            prenorm_terms = prenorm_terms.view(prenorm_terms.shape[0], prenorm_terms.shape[1], -1, self.D)
            quat_norm_loss = self.args.quat_norm_reg * torch.mean((torch.sum(prenorm_terms ** 2, dim=-1) - 1) ** 2)

            quat_loss, quat_derv_loss = losses.quat_angle_loss(
                torch.cat((orient_pred_all[:, self.prefix_length - 1:],
                           quat_pred_all[:, self.prefix_length - 1:]), dim=-1),
                quat_all, self.V, self.D)
            quat_loss *= self.args.quat_reg

            z_rs_loss = self.z_rs_loss_func(z_rs_pred_all[:, self.prefix_length:],
                                            z_rs[:, self.prefix_length:])
            spline_loss = self.spline_loss_func(spline_pred_all[:, self.prefix_length:],
                                                spline[:, self.prefix_length:])
            fs_loss = losses.foot_speed_loss(pos_pred, pos)
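            # Total loss: quaternion unit-norm regularizer, quaternion angle and
            # derivative terms, z-deviation/root-speed L1, spline L1 and foot-speed loss.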
            loss_total = quat_norm_loss + quat_loss + quat_derv_loss + z_rs_loss + spline_loss + fs_loss
            loss_total.backward()
            # nn.utils.clip_grad_norm_(self.model.parameters(), self.args.gradient_clip)
            self.optimizer.step()

            # animation_pred = {
            #     'joint_names': self.joint_names,
            #     'joint_offsets': torch.from_numpy(self.joint_offsets[1:]).
            #         float().unsqueeze(0).repeat(pos_pred_all.shape[0], 1, 1),
            #     'joint_parents': self.joint_parents,
            #     'positions': pos_pred_all,
            #     'rotations': torch.cat((orient_pred_all, quat_pred_all), dim=-1)
            # }
            # MocapDataset.save_as_bvh(animation_pred,
            #                          dataset_name=self.dataset,
            #                          subset_name='test')

            # Compute statistics
            batch_loss += loss_total.item()
            N += quat.shape[0]

            # statistics
            self.iter_info['loss'] = loss_total.data.item()
            self.iter_info['lr'] = '{:.6f}'.format(self.lr)
            self.iter_info['tf'] = '{:.6f}'.format(self.tf)
            self.show_iter_info()
            self.meta_info['iter'] += 1

        batch_loss = batch_loss / N
        self.epoch_info['mean_loss'] = batch_loss
        self.show_epoch_info()
        self.io.print_timer()
        self.adjust_lr()
        self.adjust_tf()

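    # per_test repeats the autoregressive rollout without teacher forcing under
    # torch.no_grad() and accumulates a root-relative position reconstruction loss
    # as the validation metric.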
    def per_test(self):

        self.model.eval()
        test_loader = self.data_loader['test']
        valid_loss = 0.
        N = 0.

        for pos, quat, orient, z_mean, z_dev,\
                root_speed, affs, spline, labels in self.yield_batch(self.args.batch_size, test_loader):
            with torch.no_grad():
                pos = torch.from_numpy(pos).cuda()
                orient = torch.from_numpy(orient).cuda()
                quat = torch.from_numpy(quat).cuda()
                z_mean = torch.from_numpy(z_mean).cuda()
                z_dev = torch.from_numpy(z_dev).cuda()
                root_speed = torch.from_numpy(root_speed).cuda()
                affs = torch.from_numpy(affs).cuda()
                spline = torch.from_numpy(spline).cuda()
                labels = torch.from_numpy(labels).cuda().repeat(1, quat.shape[1], 1)
                z_rs = torch.cat((z_dev, root_speed), dim=-1)
                quat_all = torch.cat((orient[:, self.prefix_length - 1:], quat[:, self.prefix_length - 1:]), dim=-1)

                pos_pred = pos.clone()
                orient_pred = orient.clone()
                quat_pred = quat.clone()
                z_rs_pred = z_rs.clone()
                affs_pred = affs.clone()
                spline_pred = spline.clone()
                orient_prenorm_terms = torch.zeros_like(orient_pred)
                quat_prenorm_terms = torch.zeros_like(quat_pred)

                # forward
                for t in range(self.target_length):
                    orient_pred[:, self.prefix_length + t:self.prefix_length + t + 1],\
                        quat_pred[:, self.prefix_length + t:self.prefix_length + t + 1],\
                        z_rs_pred[:, self.prefix_length + t:self.prefix_length + t + 1],\
                        self.orient_h, self.quat_h,\
                        orient_prenorm_terms[:, self.prefix_length + t: self.prefix_length + t + 1],\
                        quat_prenorm_terms[:, self.prefix_length + t: self.prefix_length + t + 1] = \
                        self.model(
                            orient_pred[:, t:self.prefix_length + t],
                            quat_pred[:, t:self.prefix_length + t],
                            z_rs_pred[:, t:self.prefix_length + t],
                            affs_pred[:, t:self.prefix_length + t],
                            spline[:, t:self.prefix_length + t],
                            labels[:, t:self.prefix_length + t],
                            orient_h=None if t == 0 else self.orient_h,
                            quat_h=None if t == 0 else self.quat_h, return_prenorm=True)
                    pos_pred[:, self.prefix_length + t:self.prefix_length + t + 1], \
                        affs_pred[:, self.prefix_length + t:self.prefix_length + t + 1],\
                        spline_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                        self.mocap.get_predicted_features(
                            pos_pred[:, :self.prefix_length + t],
                            pos_pred[:, self.prefix_length + t:self.prefix_length + t + 1, 0, [0, 2]],
                            z_rs_pred[:, self.prefix_length + t:self.prefix_length + t + 1, 0] + z_mean,
                            orient_pred[:, self.prefix_length + t:self.prefix_length + t + 1],
                            quat_pred[:, self.prefix_length + t:self.prefix_length + t + 1])

                prenorm_terms = torch.cat((orient_prenorm_terms, quat_prenorm_terms), dim=-1)
                prenorm_terms = prenorm_terms.view(prenorm_terms.shape[0], prenorm_terms.shape[1], -1, self.D)
                quat_norm_loss = self.args.quat_norm_reg *\
                    torch.mean((torch.sum(prenorm_terms ** 2, dim=-1) - 1) ** 2)

                quat_loss, quat_derv_loss = losses.quat_angle_loss(
                    torch.cat((orient_pred[:, self.prefix_length - 1:],
                               quat_pred[:, self.prefix_length - 1:]), dim=-1),
                    quat_all, self.V, self.D)
                quat_loss *= self.args.quat_reg

                recons_loss = self.args.recons_reg *\
                              (pos_pred[:, self.prefix_length:] - pos_pred[:, self.prefix_length:, 0:1] -
                               pos[:, self.prefix_length:] + pos[:, self.prefix_length:, 0:1]).norm()
                valid_loss += recons_loss
                N += quat.shape[0]

        valid_loss /= N
        # if self.meta_info['epoch'] > 5 and self.loss_updated:
        #     pos_pred_np = pos_pred.contiguous().view(pos_pred.shape[0], pos_pred.shape[1], -1).permute(0, 2, 1).\
        #         detach().cpu().numpy()
        #     display_animations(pos_pred_np, self.V, self.C, self.mocap.joint_parents, save=True,
        #                        dataset_name=self.dataset, subset_name='epoch_' + str(self.best_loss_epoch),
        #                        overwrite=True)
        #     pos_in_np = pos_in.contiguous().view(pos_in.shape[0], pos_in.shape[1], -1).permute(0, 2, 1).\
        #         detach().cpu().numpy()
        #     display_animations(pos_in_np, self.V, self.C, self.mocap.joint_parents, save=True,
        #                        dataset_name=self.dataset, subset_name='epoch_' + str(self.best_loss_epoch) +
        #                                                               '_gt',
        #                        overwrite=True)

        self.epoch_info['mean_loss'] = valid_loss
        if self.epoch_info['mean_loss'] < self.best_loss and self.meta_info['epoch'] > self.min_train_epochs:
            self.best_loss = self.epoch_info['mean_loss']
            self.best_loss_epoch = self.meta_info['epoch']
            self.loss_updated = True
        else:
            self.loss_updated = False
        self.show_epoch_info()

    def train(self):

        if self.args.load_last_best:
            best_model_found = self.load_best_model()
            self.args.start_epoch = self.best_loss_epoch if best_model_found else 0
        for epoch in range(self.args.start_epoch, self.args.num_epoch):
            self.meta_info['epoch'] = epoch

            # training
            self.io.print_log('Training epoch: {}'.format(epoch))
            self.per_train()
            self.io.print_log('Done.')

            # evaluation
            if (epoch % self.args.eval_interval == 0) or (
                    epoch + 1 == self.args.num_epoch):
                self.io.print_log('Eval epoch: {}'.format(epoch))
                self.per_test()
                self.io.print_log('Done.')

            # save model and weights
            if self.loss_updated:
                torch.save({'model_dict': self.model.state_dict(),
                            'orient_h': self.orient_h,
                            'quat_h': self.quat_h},
                           os.path.join(self.args.work_dir, 'epoch_{}_loss_{:.4f}_model.pth.tar'.
                                        format(epoch, self.best_loss)))

                if self.generate_while_train:
                    self.generate_motion(load_saved_model=False, samples_to_generate=1)

    def copy_prefix(self, var, prefix_length=None):
        if prefix_length is None:
            prefix_length = self.prefix_length
        return [var[s, :prefix_length].unsqueeze(0) for s in range(var.shape[0])]

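    # generate_linear_trajectory extrapolates the last prefix step along its current
    # direction (scaled by 1/alpha) to place a single distant straight-line marker.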
    def generate_linear_trajectory(self, traj, alpha=0.001):
        traj_markers = (traj[:, self.prefix_length - 2] +
                        (traj[:, self.prefix_length - 1] - traj[:, self.prefix_length - 2]) / alpha).unsqueeze(1)
        return traj_markers

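    # generate_circular_trajectory rotates the scaled last prefix segment about the
    # vertical axis in num_segments equal steps, laying markers on an approximate circle.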
    def generate_circular_trajectory(self, traj, alpha=5., num_segments=10):
        last_segment = alpha * traj[:, self.prefix_length - 1:self.prefix_length] -\
                       traj[:, self.prefix_length - 2:self.prefix_length - 1]
        last_marker = traj[:, self.prefix_length - 1:self.prefix_length]
        traj_markers = last_marker.clone()
        angle_per_segment = 2. * np.pi / num_segments
        for _ in range(num_segments):
            next_segment = qrot(expmap_to_quaternion(
                torch.tensor([0, -angle_per_segment, 0]).cuda().float().repeat(
                    last_segment.shape[0], last_segment.shape[1], 1)), torch.cat((
                last_segment[..., 0:1],
                torch.zeros_like(last_segment[..., 0:1]),
                last_segment[..., 1:]), dim=-1))[..., [0, 2]]
            next_marker = next_segment + last_marker
            traj_markers = torch.cat((traj_markers, next_marker), dim=1)
            last_segment = next_segment.clone()
            last_marker = next_marker.clone()
        traj_markers = traj_markers[:, 1:]
        return traj_markers

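    # compute_next_traj_point advances the root by rs_pred along the unit tangent
    # pointing from the current trajectory position towards the next marker.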
    def compute_next_traj_point(self, traj, traj_marker, rs_pred):
        tangent = traj_marker - traj
        # keepdim so the norm broadcasts correctly over the xz components of the tangent
        tangent = tangent / (torch.norm(tangent, dim=-1, keepdim=True) + 1e-9)
        return tangent * rs_pred + traj

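    # compute_next_traj_point_sans_markers runs forward kinematics on the predicted
    # rotations, derives the horizontal facing direction from the shoulder-root cross
    # product, and steps the root along it by the predicted root speed.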
    def compute_next_traj_point_sans_markers(self, pos_last, quat_next, z_pred, rs_pred):
        # pos_next = torch.zeros_like(pos_last)
        offsets = torch.from_numpy(self.mocap.joint_offsets).cuda().float(). \
            unsqueeze(0).unsqueeze(0).repeat(pos_last.shape[0], pos_last.shape[1], 1, 1)
        pos_next = MocapDataset.forward_kinematics(quat_next.contiguous().view(quat_next.shape[0],
                                                                               quat_next.shape[1], -1, self.D),
                                                   pos_last[:, :, 0],
                                                   self.joint_parents,
                                                   torch.from_numpy(self.joint_offsets).float().cuda())
        # for joint in range(1, self.V):
        #     pos_next[:, :, joint] = qrot(quat_copy[:, :, joint - 1], offsets[:, :, joint]) \
        #                             + pos_next[:, :, self.mocap.joint_parents[joint]]
        root = pos_next[:, :, 0]
        l_shoulder = pos_next[:, :, 18]
        r_shoulder = pos_next[:, :, 25]
        facing = torch.cross(l_shoulder - root, r_shoulder - root, dim=-1)[..., [0, 2]]
        facing /= (torch.norm(facing, dim=-1)[..., None] + 1e-9)
        return rs_pred * facing + pos_last[:, :, 0, [0, 2]]

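    # get_diff_from_traj compares the facing direction of the predicted poses with the
    # tangents of the predicted trajectory and returns the rotation axis and angle
    # needed to align them (zero wherever the trajectory barely moves).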
    def get_diff_from_traj(self, pos_pred, traj_pred, s):
        root = pos_pred[s][:, :, 0]
        l_shoulder = pos_pred[s][:, :, 18]
        r_shoulder = pos_pred[s][:, :, 25]
        facing = torch.cross(l_shoulder - root, r_shoulder - root, dim=-1)[..., [0, 2]]
        facing /= (torch.norm(facing, dim=-1)[..., None] + 1e-9)
        tangents = traj_pred[s][:, 1:] - traj_pred[s][:, :-1]
        tangent_norms = torch.norm(tangents, dim=-1)
        tangents /= (tangent_norms[..., None] + 1e-9)
        tangents = torch.cat((torch.zeros_like(tangents[:, 0:1]), tangents), dim=1)
        tangent_norms = torch.cat((torch.zeros_like(tangent_norms[:, 0:1]), tangent_norms), dim=1)
        axis_diff = torch.cross(torch.cat((facing[..., 0:1],
                                           torch.zeros_like(facing[..., 0:1]),
                                           facing[..., 1:]), dim=-1),
                                torch.cat((tangents[..., 0:1],
                                           torch.zeros_like(tangents[..., 0:1]),
                                           tangents[..., 1:]), dim=-1))
        axis_diff_norms = torch.norm(axis_diff, dim=-1)
        axis_diff /= (axis_diff_norms[..., None] + 1e-9)
        angle_diff = torch.acos(torch.einsum('ijk,ijk->ij', facing, tangents).clamp(min=-1., max=1.))
        angle_diff[tangent_norms < 1e-6] = 0.
        return axis_diff, angle_diff

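    # rotate_gaits applies the trajectory-alignment rotation quat_diff to the root
    # orientation and multiplies fixed head-tilt / shoulder-slouch quaternions into
    # joints 14, 16/17 and 23/24, with the inverse applied to the following joint.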
    def rotate_gaits(self, orient_pred, quat_pred, quat_diff, head_tilt, l_shoulder_slouch, r_shoulder_slouch):
        quat_reshape = quat_pred.contiguous().view(quat_pred.shape[0], quat_pred.shape[1], -1, self.D).clone()
        quat_reshape[..., 14, :] = qmul(torch.from_numpy(head_tilt).cuda().float(),
                                        quat_reshape[..., 14, :])
        quat_reshape[..., 16, :] = qmul(torch.from_numpy(l_shoulder_slouch).cuda().float(),
                                        quat_reshape[..., 16, :])
        quat_reshape[..., 17, :] = qmul(torch.from_numpy(qinv(l_shoulder_slouch)).cuda().float(),
                                        quat_reshape[..., 17, :])
        quat_reshape[..., 23, :] = qmul(torch.from_numpy(r_shoulder_slouch).cuda().float(),
                                        quat_reshape[..., 23, :])
        quat_reshape[..., 24, :] = qmul(torch.from_numpy(qinv(r_shoulder_slouch)).cuda().float(),
                                        quat_reshape[..., 24, :])
        return qmul(quat_diff, orient_pred), quat_reshape.contiguous().view(quat_reshape.shape[0],
                                                                            quat_reshape.shape[1], -1)

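    # generate_motion loads the best checkpoint, rolls the model out autoregressively
    # for up to max_steps frames on the hand-picked emotion samples, re-computing
    # positions, affective features and splines at every step, aligns the facing
    # direction with the predicted trajectory, and exports each result as BVH.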
    def generate_motion(self, load_saved_model=True, samples_to_generate=1534, max_steps=300, randomized=True):

        if load_saved_model:
            self.load_best_model()
        self.model.eval()
        test_loader = self.data_loader['test']

        pos, quat, orient, z_mean, z_dev,\
            root_speed, affs, spline, labels = \
            self.return_batch([samples_to_generate], test_loader, randomized=randomized)
        pos = torch.from_numpy(pos).cuda()
        traj = pos[:, :, 0, [0, 2]].clone()
        orient = torch.from_numpy(orient).cuda()
        quat = torch.from_numpy(quat).cuda()
        z_mean = torch.from_numpy(z_mean).cuda()
        z_dev = torch.from_numpy(z_dev).cuda()
        root_speed = torch.from_numpy(root_speed).cuda()
        affs = torch.from_numpy(affs).cuda()
        spline = torch.from_numpy(spline).cuda()
        z_rs = torch.cat((z_dev, root_speed), dim=-1)
        quat_all = torch.cat((orient[:, self.prefix_length - 1:], quat[:, self.prefix_length - 1:]), dim=-1)
        labels = np.tile(labels, (1, max_steps + self.prefix_length, 1))

        # Begin for transition
        # traj[:, self.prefix_length - 2] = torch.tensor([-0.208, 4.8]).cuda().float()
        # traj[:, self.prefix_length - 1] = torch.tensor([-0.204, 5.1]).cuda().float()
        # final_emo_idx = int(max_steps/2)
        # labels[:, final_emo_idx:] = np.array([1., 0., 0., 0.])
        # labels[:, :final_emo_idx + 1] = np.linspace(labels[:, 0], labels[:, final_emo_idx],
        #                                             num=final_emo_idx + 1, axis=1)
        # End for transition
        labels = torch.from_numpy(labels).cuda()

        # traj_np = traj_markers.detach().cpu().numpy()
        # import matplotlib.pyplot as plt
        # plt.plot(traj_np[6, :, 0], traj_np[6, :, 1])
        # plt.show()

        happy_idx = [25, 295, 390, 667, 1196]
        sad_idx = [169, 184, 258, 948, 974]
        angry_idx = [89, 93, 96, 112, 289, 290, 978]
        neutral_idx = [72, 106, 143, 237, 532, 747, 1177]
        sample_idx = np.squeeze(np.concatenate((happy_idx, sad_idx, angry_idx, neutral_idx)))

        ## CHANGE HERE
        # scene_corners = torch.tensor([[149.862, 50.833],
        #                               [149.862, 36.81],
        #                               [161.599, 36.81],
        #                               [161.599, 50.833]]).cuda().float()
        # character_heights = torch.tensor([0.95, 0.88, 0.86, 0.90, 0.95, 0.82]).cuda().float()
        # num_characters_per_side = torch.tensor([2, 3, 2, 3]).cuda().int()
        # traj_markers, traj_offsets, character_scale =\
        #     generate_trajectories(scene_corners, z_mean, character_heights,
        #                           num_characters_per_side, traj[:, :self.prefix_length])
        # num_characters_per_side = torch.tensor([4, 0, 0, 0]).cuda().int()
        # traj_markers, traj_offsets, character_scale =\
        #     generate_simple_trajectories(scene_corners, z_mean[:4], z_mean[:4],
        #                                  num_characters_per_side, traj[sample_idx, :self.prefix_length])
        # traj_markers, traj_offsets, character_scale =\
        #     generate_rvo_trajectories(scene_corners, z_mean[:4], z_mean[:4],
        #                               num_characters_per_side, traj[sample_idx, :self.prefix_length])

        # traj[sample_idx, :self.prefix_length] += traj_offsets
        # pos_sampled = pos[sample_idx].clone()
        # pos_sampled[:, :self.prefix_length, :, [0, 2]] += traj_offsets.unsqueeze(2).repeat(1, 1, self.V, 1)
        # pos[sample_idx] = pos_sampled
        # traj_markers = self.generate_linear_trajectory(traj)

        pos_pred = self.copy_prefix(pos)
        traj_pred = self.copy_prefix(traj)
        orient_pred = self.copy_prefix(orient)
        quat_pred = self.copy_prefix(quat)
        z_rs_pred = self.copy_prefix(z_rs)
        affs_pred = self.copy_prefix(affs)
        spline_pred = self.copy_prefix(spline)
        labels_pred = self.copy_prefix(labels, prefix_length=max_steps + self.prefix_length)

        # forward
        elapsed_time = np.zeros(len(sample_idx))
        for counter, s in enumerate(sample_idx):  # range(samples_to_generate):
            start_time = time.time()
            orient_h_copy = self.orient_h.clone()
            quat_h_copy = self.quat_h.clone()
            ## CHANGE HERE
            num_markers = max_steps + self.prefix_length + 1
            # num_markers = traj_markers[s].shape[0]
            marker_idx = 0
            t = -1
            with torch.no_grad():
                while marker_idx < num_markers:
                    t += 1
                    if t > max_steps:
                        print('Sample: {}. Did not reach end in {} steps.'.format(s, max_steps), end='')
                        break
                    pos_pred[s] = torch.cat((pos_pred[s], torch.zeros_like(pos_pred[s][:, -1:])), dim=1)
                    traj_pred[s] = torch.cat((traj_pred[s], torch.zeros_like(traj_pred[s][:, -1:])), dim=1)
                    orient_pred[s] = torch.cat((orient_pred[s], torch.zeros_like(orient_pred[s][:, -1:])), dim=1)
                    quat_pred[s] = torch.cat((quat_pred[s], torch.zeros_like(quat_pred[s][:, -1:])), dim=1)
                    z_rs_pred[s] = torch.cat((z_rs_pred[s], torch.zeros_like(z_rs_pred[s][:, -1:])), dim=1)
                    affs_pred[s] = torch.cat((affs_pred[s], torch.zeros_like(affs_pred[s][:, -1:])), dim=1)
                    spline_pred[s] = torch.cat((spline_pred[s], torch.zeros_like(spline_pred[s][:, -1:])), dim=1)

                    orient_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1], \
                    quat_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1], \
                    z_rs_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1], \
                    orient_h_copy, quat_h_copy = \
                        self.model(
                            orient_pred[s][:, t:self.prefix_length + t],
                            quat_pred[s][:, t:self.prefix_length + t],
                            z_rs_pred[s][:, t:self.prefix_length + t],
                            affs_pred[s][:, t:self.prefix_length + t],
                            spline_pred[s][:, t:self.prefix_length + t],
                            labels_pred[s][:, t:self.prefix_length + t],
                            orient_h=None if t == 0 else orient_h_copy,
                            quat_h=None if t == 0 else quat_h_copy, return_prenorm=False)

                    traj_curr = traj_pred[s][:, self.prefix_length + t - 1:self.prefix_length + t].clone()
                    # root_speed = torch.norm(
                    #     pos_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1, 0] - \
                    #     pos_pred[s][:, self.prefix_length + t - 1:self.prefix_length + t, 0], dim=-1)

                    ## CHANGE HERE
                    # traj_next = \
                    #     self.compute_next_traj_point(
                    #         traj_curr,
                    #         traj_markers[s, marker_idx],
                    #         o_z_rs_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1, 2])
                    try:
                        traj_next = traj[s, self.prefix_length + t]
                    except IndexError:
                        traj_next = \
                            self.compute_next_traj_point_sans_markers(
                                pos_pred[s][:, self.prefix_length + t - 1:self.prefix_length + t],
                                quat_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1],
                                z_rs_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1, 0],
                                z_rs_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1, 1])

                    pos_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1], \
                    affs_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1], \
                    spline_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1] = \
                        self.mocap.get_predicted_features(
                            pos_pred[s][:, :self.prefix_length + t],
                            traj_next,
                            z_rs_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1, 0] + z_mean[s:s + 1],
                            orient_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1],
                            quat_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1])

                    # min_speed_pred = torch.min(torch.cat((lf_speed_pred.unsqueeze(-1),
                    #                                        rf_speed_pred.unsqueeze(-1)), dim=-1), dim=-1)[0]
                    # if min_speed_pred - diff_speeds_mean[s] - diff_speeds_std[s] < 0.:
                    #     root_speed_pred = 0.
                    # else:
                    #     root_speed_pred = o_z_rs_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1, 2]
                    #

                    ## CHANGE HERE
                    # traj_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1] = \
                    #     self.compute_next_traj_point(
                    #         traj_curr,
                    #         traj_markers[s, marker_idx],
                    #         root_speed_pred)
                    # if torch.norm(traj_next - traj_curr, dim=-1).squeeze() >= \
                    #         torch.norm(traj_markers[s, marker_idx] - traj_curr, dim=-1).squeeze():
                    #     marker_idx += 1
                    traj_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1] = traj_next
                    marker_idx += 1
                    pos_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1], \
                    affs_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1], \
                    spline_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1] = \
                        self.mocap.get_predicted_features(
                            pos_pred[s][:, :self.prefix_length + t],
                            pos_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1, 0, [0, 2]],
                            z_rs_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1, 0] + z_mean[s:s + 1],
                            orient_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1],
                            quat_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1])
                    print('Sample: {}. Steps: {}'.format(s, t), end='\r')
            print()

            # shift = torch.zeros((1, scene_corners.shape[1] + 1)).cuda().float()
            # shift[..., [0, 2]] = scene_corners[0]
            # pos_pred[s] = (pos_pred[s] - shift) / character_scale + shift
            # pos_pred_np = pos_pred[s].contiguous().view(pos_pred[s].shape[0],
            #                                             pos_pred[s].shape[1], -1).permute(0, 2, 1).\
            #     detach().cpu().numpy()
            # display_animations(pos_pred_np, self.V, self.C, self.mocap.joint_parents, save=True,
            #                    dataset_name=self.dataset, subset_name='epoch_' + str(self.best_loss_epoch),
            #                    save_file_names=[str(s).zfill(6)],
            #                    overwrite=True)

            # plt.cla()
            # fig, (ax1, ax2) = plt.subplots(2, 1)
            # ax1.plot(root_speeds[s])
            # ax1.plot(lf_speeds[s])
            # ax1.plot(rf_speeds[s])
            # ax1.plot(min_speeds[s] - root_speeds[s])
            # ax1.legend(['root', 'left', 'right', 'diff'])
            # ax2.plot(root_speeds_pred)
            # ax2.plot(lf_speeds_pred)
            # ax2.plot(rf_speeds_pred)
            # ax2.plot(min_speeds_pred - root_speeds_pred)
            # ax2.legend(['root', 'left', 'right', 'diff'])
            # plt.show()

            head_tilt = np.tile(np.array([0., 0., 0.]), (1, quat_pred[s].shape[1], 1))
            l_shoulder_slouch = np.tile(np.array([0., 0., 0.]), (1, quat_pred[s].shape[1], 1))
            r_shoulder_slouch = np.tile(np.array([0., 0., 0.]), (1, quat_pred[s].shape[1], 1))
            head_tilt = Quaternions.from_euler(head_tilt, order='xyz').qs
            l_shoulder_slouch = Quaternions.from_euler(l_shoulder_slouch, order='xyz').qs
            r_shoulder_slouch = Quaternions.from_euler(r_shoulder_slouch, order='xyz').qs

            # Begin for aligning facing direction to trajectory
            axis_diff, angle_diff = self.get_diff_from_traj(pos_pred, traj_pred, s)
            angle_thres = 0.3
            # angle_thres = torch.max(angle_diff[:, 1:self.prefix_length])
            angle_diff[angle_diff <= angle_thres] = 0.
            angle_diff[:, self.prefix_length] = 0.
            # End for aligning facing direction to trajectory
            # pos_copy, quat_copy = self.rotate_gaits(pos_pred, quat_pred, quat_diff,
            #                                         head_tilt, l_shoulder_slouch, r_shoulder_slouch, s)
            # pos_pred[s] = pos_copy.clone()
            # angle_diff_intermediate = self.get_diff_from_traj(pos_pred, traj_pred, s)
            # if torch.max(angle_diff_intermediate[:, self.prefix_length:]) > np.pi / 2.:
            #     quat_diff = Quaternions.from_angle_axis(-angle_diff.cpu().numpy(), np.array([0, 1, 0])).qs
            #     pos_copy, quat_copy = self.rotate_gaits(pos_pred, quat_pred, quat_diff,
            #                                         head_tilt, l_shoulder_slouch, r_shoulder_slouch, s)
            # pos_pred[s] = pos_copy.clone()
            # axis_diff = torch.zeros_like(axis_diff)
            # axis_diff[..., 1] = 1.
            # angle_diff = torch.zeros_like(angle_diff)
            quat_diff = torch.from_numpy(Quaternions.from_angle_axis(
                angle_diff.cpu().numpy(), axis_diff.cpu().numpy()).qs).cuda().float()
            orient_pred[s], quat_pred[s] = self.rotate_gaits(orient_pred[s], quat_pred[s],
                                                             quat_diff, head_tilt,
                                                             l_shoulder_slouch, r_shoulder_slouch)

            if labels_pred[s][:, 0, 0] > 0.5:
                label_dir = 'happy'
            elif labels_pred[s][:, 0, 1] > 0.5:
                label_dir = 'sad'
            elif labels_pred[s][:, 0, 2] > 0.5:
                label_dir = 'angry'
            else:
                label_dir = 'neutral'

            ## CHANGE HERE
            # pos_pred[s] = pos_pred[s][:, self.prefix_length + 5:]
            # o_z_rs_pred[s] = o_z_rs_pred[s][:, self.prefix_length + 5:]
            # quat_pred[s] = quat_pred[s][:, self.prefix_length + 5:]

            traj_pred_np = pos_pred[s][0, :, 0].cpu().numpy()

            save_file_name = '{:06}_{:.2f}_{:.2f}_{:.2f}_{:.2f}'.format(s,
                                                                        labels_pred[s][0, 0, 0],
                                                                        labels_pred[s][0, 0, 1],
                                                                        labels_pred[s][0, 0, 2],
                                                                        labels_pred[s][0, 0, 3])

            animation_pred = {
                'joint_names': self.joint_names,
                'joint_offsets': torch.from_numpy(self.joint_offsets[1:]).float().unsqueeze(0).repeat(
                    pos_pred[s].shape[0], 1, 1),
                'joint_parents': self.joint_parents,
                'positions': pos_pred[s],
                'rotations': torch.cat((orient_pred[s], quat_pred[s]), dim=-1)
            }
            self.mocap.save_as_bvh(animation_pred,
                                   dataset_name=self.dataset,
                                   # subset_name='epoch_' + str(self.best_loss_epoch),
                                   # save_file_names=[str(s).zfill(6)])
                                   subset_name=os.path.join('no_aff_epoch_' + str(self.best_loss_epoch),
                                                            str(counter).zfill(2) + '_' + label_dir),
                                   save_file_names=['root'])
            end_time = time.time()
            elapsed_time[counter] = end_time - start_time
            print('Elapsed Time: {}'.format(elapsed_time[counter]))

            # display_animations(pos_pred_np, self.V, self.C, self.mocap.joint_parents, save=True,
            #                    dataset_name=self.dataset,
            #                    # subset_name='epoch_' + str(self.best_loss_epoch),
            #                    # save_file_names=[str(s).zfill(6)],
            #                    subset_name=os.path.join('epoch_' + str(self.best_loss_epoch), label_dir),
            #                    save_file_names=[save_file_name],
            #                    overwrite=True)
        print('Mean Elapsed Time: {}'.format(np.mean(elapsed_time)))
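
# ---------------------------------------------------------------------------
# A minimal usage sketch for the Processor class defined below (not part of the
# original example; every argument value here is an illustrative assumption):
#
#   processor = Processor(args, dataset_name, data_loader,
#                         T, V, C, D, A, S, joint_parents, num_labels,
#                         prefix_length=20, target_length=40)
#   processor.train()
#   processor.generate_motion(load_saved_model=True, samples_to_generate=10)
# ---------------------------------------------------------------------------
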
class Processor(object):
    """
        Processor for gait generation
    """
    def __init__(self,
                 args,
                 dataset,
                 data_loader,
                 T,
                 V,
                 C,
                 D,
                 A,
                 S,
                 joint_parents,
                 num_labels,
                 prefix_length,
                 target_length,
                 min_train_epochs=-1,
                 generate_while_train=False,
                 save_path=None,
                 device='cuda:0'):

        self.args = args
        self.dataset = dataset
        self.mocap = MocapDataset(V, C, joint_parents)
        self.device = device
        self.data_loader = data_loader
        self.num_labels = num_labels
        self.result = dict()
        self.iter_info = dict()
        self.epoch_info = dict()
        self.meta_info = dict(epoch=0, iter=0)
        self.io = IO(self.args.work_dir,
                     save_log=self.args.save_log,
                     print_log=self.args.print_log)

        # model
        self.T = T
        self.V = V
        self.C = C
        self.D = D
        self.A = A
        self.S = S
        self.O = 1
        self.PRS = 2
        self.prefix_length = prefix_length
        self.target_length = target_length
        self.joint_parents = joint_parents
        self.model = quater_emonet.QuaterEmoNet(V, D, S, A, self.O,
                                                num_labels[0], self.PRS)
        self.model.cuda(device)
        self.quat_h = None
        self.p_rs_loss_func = nn.L1Loss()
        self.affs_loss_func = nn.L1Loss()
        self.best_loss = math.inf
        self.best_mean_ap = 0.
        self.loss_updated = False
        self.mean_ap_updated = False
        self.step_epochs = [
            math.ceil(float(self.args.num_epoch * x)) for x in self.args.step
        ]
        self.best_loss_epoch = None
        self.best_acc_epoch = None
        self.min_train_epochs = min_train_epochs

        # generate
        self.generate_while_train = generate_while_train
        self.save_path = save_path

        # optimizer
        if self.args.optimizer == 'SGD':
            self.optimizer = optim.SGD(self.model.parameters(),
                                       lr=self.args.base_lr,
                                       momentum=0.9,
                                       nesterov=self.args.nesterov,
                                       weight_decay=self.args.weight_decay)
        elif self.args.optimizer == 'Adam':
            self.optimizer = optim.Adam(self.model.parameters(),
                                        lr=self.args.base_lr)
            # weight_decay=self.args.weight_decay)
        else:
            raise ValueError()
        self.lr = self.args.base_lr
        self.tf = self.args.base_tr

    def process_data(self, data, poses, quat, trans, affs):
        data = data.float().to(self.device)
        poses = poses.float().to(self.device)
        quat = quat.float().to(self.device)
        trans = trans.float().to(self.device)
        affs = affs.float().to(self.device)
        return data, poses, quat, trans, affs

    def load_best_model(self):
        # resolve the best checkpoint name first, so it is defined even when the
        # best epoch/loss are already cached from an earlier call
        model_name, best_loss_epoch, best_loss, best_mean_ap = \
            get_best_epoch_and_loss(self.args.work_dir)
        if self.best_loss_epoch is None:
            self.best_loss_epoch = best_loss_epoch
            self.best_loss = best_loss
            self.best_mean_ap = best_mean_ap
        # load model weights and the saved recurrent hidden state
        loaded_vars = torch.load(os.path.join(self.args.work_dir, model_name))
        self.model.load_state_dict(loaded_vars['model_dict'])
        self.quat_h = loaded_vars['quat_h']

    def adjust_lr(self):
        self.lr = self.lr * self.args.lr_decay
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr

    def adjust_tf(self):
        if self.meta_info['epoch'] > 20:
            self.tf = self.tf * self.args.tf_decay

    def show_epoch_info(self):

        print_epochs = [
            self.best_loss_epoch if self.best_loss_epoch is not None else 0,
            self.best_acc_epoch if self.best_acc_epoch is not None else 0,
            self.best_acc_epoch if self.best_acc_epoch is not None else 0
        ]
        best_metrics = [self.best_loss, 0, self.best_mean_ap]
        i = 0
        for k, v in self.epoch_info.items():
            self.io.print_log(
                '\t{}: {}. Best so far: {} (epoch: {:d}).'.format(
                    k, v, best_metrics[i], print_epochs[i]))
            i += 1
        if self.args.pavi_log:
            self.io.log('train', self.meta_info['iter'], self.epoch_info)

    def show_iter_info(self):

        if self.meta_info['iter'] % self.args.log_interval == 0:
            info = '\tIter {} Done.'.format(self.meta_info['iter'])
            for k, v in self.iter_info.items():
                if isinstance(v, float):
                    info = info + ' | {}: {:.4f}'.format(k, v)
                else:
                    info = info + ' | {}: {}'.format(k, v)

            self.io.print_log(info)

            if self.args.pavi_log:
                self.io.log('train', self.meta_info['iter'], self.iter_info)

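    # yield_batch samples keys with probability proportional to spline size and yields
    # mini-batches of positions, rotations, orientations, affective features, splines,
    # phase/root-speed (scaled by 1/pi) and labels for one pseudo-epoch.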
    def yield_batch(self, batch_size, dataset):
        batch_pos = np.zeros((batch_size, self.T, self.V, self.C),
                             dtype='float32')
        batch_quat = np.zeros((batch_size, self.T, (self.V - 1) * self.D),
                              dtype='float32')
        batch_orient = np.zeros((batch_size, self.T, self.O), dtype='float32')
        batch_affs = np.zeros((batch_size, self.T, self.A), dtype='float32')
        batch_spline = np.zeros((batch_size, self.T, self.S), dtype='float32')
        batch_phase_and_root_speed = np.zeros((batch_size, self.T, self.PRS),
                                              dtype='float32')
        batch_labels = np.zeros((batch_size, 1, self.num_labels[0]),
                                dtype='float32')
        pseudo_passes = (len(dataset) + batch_size - 1) // batch_size

        probs = []
        for k in dataset.keys():
            if 'spline' not in dataset[k]:
                raise KeyError(
                    'No splines found. Perhaps you forgot to compute them?')
            probs.append(dataset[k]['spline'].size())
        probs = np.array(probs) / np.sum(probs)

        for p in range(pseudo_passes):
            rand_keys = np.random.choice(len(dataset),
                                         size=batch_size,
                                         replace=True,
                                         p=probs)
            for i, k in enumerate(rand_keys):
                pos = dataset[str(k)]['positions_world']
                quat = dataset[str(k)]['rotations']
                orient = dataset[str(k)]['orientations']
                affs = dataset[str(k)]['affective_features']
                spline, phase = Spline.extract_spline_features(
                    dataset[str(k)]['spline'])
                root_speed = dataset[str(k)]['trans_and_controls'][:, -1].reshape(-1, 1)
                labels = dataset[str(k)]['labels'][:self.num_labels[0]]

                batch_pos[i] = pos
                batch_quat[i] = quat.reshape(self.T, -1)
                batch_orient[i] = orient.reshape(self.T, -1)
                batch_affs[i] = affs
                batch_spline[i] = spline
                batch_phase_and_root_speed[i] = np.concatenate(
                    (phase, root_speed), axis=-1)
                batch_labels[i] = np.expand_dims(labels, axis=0)
            yield batch_pos, batch_quat, batch_orient, batch_affs, batch_spline,\
                  batch_phase_and_root_speed / np.pi, batch_labels

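    # return_batch either uses the explicit list of keys passed in batch_size or draws
    # batch_size[0] keys without replacement, and returns one fixed batch (including
    # the root trajectory) instead of a generator.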
    def return_batch(self, batch_size, dataset):
        if len(batch_size) > 1:
            rand_keys = np.copy(batch_size)
            batch_size = len(batch_size)
        else:
            batch_size = batch_size[0]
            probs = []
            for k in dataset.keys():
                if 'spline' not in dataset[k]:
                    raise KeyError(
                        'No splines found. Perhaps you forgot to compute them?'
                    )
                probs.append(dataset[k]['spline'].size())
            probs = np.array(probs) / np.sum(probs)
            rand_keys = np.random.choice(len(dataset),
                                         size=batch_size,
                                         replace=False,
                                         p=probs)

        batch_pos = np.zeros((batch_size, self.T, self.V, self.C),
                             dtype='float32')
        batch_traj = np.zeros((batch_size, self.T, self.C), dtype='float32')
        batch_quat = np.zeros((batch_size, self.T, (self.V - 1) * self.D),
                              dtype='float32')
        batch_orient = np.zeros((batch_size, self.T, self.O), dtype='float32')
        batch_affs = np.zeros((batch_size, self.T, self.A), dtype='float32')
        batch_spline = np.zeros((batch_size, self.T, self.S), dtype='float32')
        batch_phase_and_root_speed = np.zeros((batch_size, self.T, self.PRS),
                                              dtype='float32')
        batch_labels = np.zeros((batch_size, 1, self.num_labels[0]),
                                dtype='float32')

        for i, k in enumerate(rand_keys):
            pos = dataset[str(k)]['positions_world']
            traj = dataset[str(k)]['trajectory']
            quat = dataset[str(k)]['rotations']
            orient = dataset[str(k)]['orientations']
            affs = dataset[str(k)]['affective_features']
            spline, phase = Spline.extract_spline_features(
                dataset[str(k)]['spline'])
            root_speed = dataset[str(k)]['trans_and_controls'][:, -1].reshape(
                -1, 1)
            labels = dataset[str(k)]['labels'][:self.num_labels[0]]

            batch_pos[i] = pos
            batch_traj[i] = traj
            batch_quat[i] = quat.reshape(self.T, -1)
            batch_orient[i] = orient.reshape(self.T, -1)
            batch_affs[i] = affs
            batch_spline[i] = spline
            batch_phase_and_root_speed[i] = np.concatenate((phase, root_speed),
                                                           axis=-1)
            batch_labels[i] = np.expand_dims(labels, axis=0)

        return batch_pos, batch_traj, batch_quat, batch_orient, batch_affs, batch_spline,\
               batch_phase_and_root_speed, batch_labels

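    # per_train mirrors the training loop of the earlier example, but conditions the
    # model on splines and orientations directly and predicts phase + root speed
    # (p_rs) instead of the z-deviation/root-speed pair.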
    def per_train(self):

        self.model.train()
        train_loader = self.data_loader['train']
        batch_loss = 0.
        N = 0.

        for pos, quat, orient, affs, spline, p_rs, labels in self.yield_batch(
                self.args.batch_size, train_loader):

            pos = torch.from_numpy(pos).cuda()
            quat = torch.from_numpy(quat).cuda()
            orient = torch.from_numpy(orient).cuda()
            affs = torch.from_numpy(affs).cuda()
            spline = torch.from_numpy(spline).cuda()
            p_rs = torch.from_numpy(p_rs).cuda()
            labels = torch.from_numpy(labels).cuda()

            pos_pred = pos.clone()
            quat_pred = quat.clone()
            p_rs_pred = p_rs.clone()
            affs_pred = affs.clone()
            pos_pred_all = pos.clone()
            quat_pred_all = quat.clone()
            p_rs_pred_all = p_rs.clone()
            affs_pred_all = affs.clone()
            prenorm_terms = torch.zeros_like(quat_pred)

            # forward
            self.optimizer.zero_grad()
            for t in range(self.target_length):
                quat_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1], \
                    p_rs_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1], \
                    self.quat_h, prenorm_terms[:, self.prefix_length + t: self.prefix_length + t + 1] = \
                    self.model(
                        quat_pred[:, t:self.prefix_length + t],
                        p_rs_pred[:, t:self.prefix_length + t],
                        affs_pred[:, t:self.prefix_length + t],
                        spline[:, t:self.prefix_length + t],
                        orient[:, t:self.prefix_length + t],
                        labels,
                        quat_h=None if t == 0 else self.quat_h, return_prenorm=True)
                pos_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1],\
                    affs_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                    self.mocap.get_predicted_features(
                        pos_pred[:, self.prefix_length + t:self.prefix_length + t + 1, 0],
                        quat_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1],
                        orient[:, self.prefix_length + t:self.prefix_length + t + 1])
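                # Teacher forcing: with probability (1 - self.tf) the model's own
                # prediction replaces the ground-truth frame in the next input window.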
                if np.random.uniform(size=1)[0] > self.tf:
                    pos_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                        pos_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1]
                    quat_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                        quat_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1]
                    p_rs_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                        p_rs_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1]
                    affs_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                        affs_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1]

            prenorm_terms = prenorm_terms.view(prenorm_terms.shape[0],
                                               prenorm_terms.shape[1], -1,
                                               self.D)
            quat_norm_loss = self.args.quat_norm_reg * torch.mean(
                (torch.sum(prenorm_terms**2, dim=-1) - 1)**2)

            quat_loss, quat_derv_loss = losses.quat_angle_loss(
                quat_pred_all[:, self.prefix_length - 1:],
                quat[:, self.prefix_length - 1:], self.V, self.D)
            quat_loss *= self.args.quat_reg

            p_rs_loss = self.p_rs_loss_func(
                p_rs_pred_all[:, self.prefix_length:],
                p_rs[:, self.prefix_length:])
            affs_loss = self.affs_loss_func(
                affs_pred_all[:, self.prefix_length:],
                affs[:, self.prefix_length:])
            # recons_loss = self.args.recons_reg *\
            #               (pos_pred_all[:, self.prefix_length:] - pos_pred_all[:, self.prefix_length:, 0:1] -
            #                 pos[:, self.prefix_length:] + pos[:, self.prefix_length:, 0:1]).norm()

            loss_total = quat_norm_loss + quat_loss + quat_derv_loss + p_rs_loss + affs_loss  # + recons_loss
            loss_total.backward()
            # nn.utils.clip_grad_norm_(self.model.parameters(), self.args.gradient_clip)
            self.optimizer.step()

            # Compute statistics
            batch_loss += loss_total.item()
            N += quat.shape[0]

            # statistics
            self.iter_info['loss'] = loss_total.data.item()
            self.iter_info['lr'] = '{:.6f}'.format(self.lr)
            self.iter_info['tf'] = '{:.6f}'.format(self.tf)
            self.show_iter_info()
            self.meta_info['iter'] += 1

        batch_loss = batch_loss / N
        self.epoch_info['mean_loss'] = batch_loss
        self.show_epoch_info()
        self.io.print_timer()
        self.adjust_lr()
        self.adjust_tf()

    def per_test(self):

        self.model.eval()
        test_loader = self.data_loader['test']
        valid_loss = 0.
        N = 0.

        for pos, quat, orient, affs, spline, p_rs, labels in self.yield_batch(
                self.args.batch_size, test_loader):
            pos = torch.from_numpy(pos).cuda()
            quat = torch.from_numpy(quat).cuda()
            orient = torch.from_numpy(orient).cuda()
            affs = torch.from_numpy(affs).cuda()
            spline = torch.from_numpy(spline).cuda()
            p_rs = torch.from_numpy(p_rs).cuda()
            labels = torch.from_numpy(labels).cuda()

            pos_pred = pos.clone()
            quat_pred = quat.clone()
            p_rs_pred = p_rs.clone()
            affs_pred = affs.clone()
            prenorm_terms = torch.zeros_like(quat_pred)

            # forward
            self.optimizer.zero_grad()
            for t in range(self.target_length):
                quat_pred[:, self.prefix_length + t:self.prefix_length + t + 1], \
                    p_rs_pred[:, self.prefix_length + t:self.prefix_length + t + 1], \
                    self.quat_h, prenorm_terms[:, self.prefix_length + t: self.prefix_length + t + 1] = \
                    self.model(
                        quat_pred[:, t:self.prefix_length + t],
                        p_rs_pred[:, t:self.prefix_length + t],
                        affs_pred[:, t:self.prefix_length + t],
                        spline[:, t:self.prefix_length + t],
                        orient[:, t:self.prefix_length + t],
                        labels,
                        quat_h=None if t == 0 else self.quat_h, return_prenorm=True)
                pos_pred[:, self.prefix_length + t:self.prefix_length + t + 1], \
                affs_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                    self.mocap.get_predicted_features(
                        pos_pred[:, self.prefix_length + t:self.prefix_length + t + 1, 0],
                        quat_pred[:, self.prefix_length + t:self.prefix_length + t + 1],
                        orient[:, self.prefix_length + t:self.prefix_length + t + 1])

            prenorm_terms = prenorm_terms.view(prenorm_terms.shape[0],
                                               prenorm_terms.shape[1], -1,
                                               self.D)
            quat_norm_loss = self.args.quat_norm_reg * torch.mean(
                (torch.sum(prenorm_terms**2, dim=-1) - 1)**2)

            quat_loss, quat_derv_loss = losses.quat_angle_loss(
                quat_pred[:, self.prefix_length - 1:],
                quat[:, self.prefix_length - 1:], self.V, self.D)
            quat_loss *= self.args.quat_reg

            recons_loss = self.args.recons_reg *\
                          (pos_pred[:, self.prefix_length:] - pos_pred[:, self.prefix_length:, 0:1] -
                           pos[:, self.prefix_length:] + pos[:, self.prefix_length:, 0:1]).norm()
            # use .item() so the validation graph is not retained across batches
            valid_loss += recons_loss.item()
            N += quat.shape[0]

        valid_loss /= N
        # if self.meta_info['epoch'] > 5 and self.loss_updated:
        #     pos_pred_np = pos_pred.contiguous().view(pos_pred.shape[0], pos_pred.shape[1], -1).permute(0, 2, 1).\
        #         detach().cpu().numpy()
        #     display_animations(pos_pred_np, self.V, self.C, self.joint_parents, save=True,
        #                        dataset_name=self.dataset, subset_name='epoch_' + str(self.best_loss_epoch),
        #                        overwrite=True)
        #     pos_in_np = pos_in.contiguous().view(pos_in.shape[0], pos_in.shape[1], -1).permute(0, 2, 1).\
        #         detach().cpu().numpy()
        #     display_animations(pos_in_np, self.V, self.C, self.joint_parents, save=True,
        #                        dataset_name=self.dataset, subset_name='epoch_' + str(self.best_loss_epoch) +
        #                                                               '_gt',
        #                        overwrite=True)

        self.epoch_info['mean_loss'] = valid_loss
        if self.epoch_info['mean_loss'] < self.best_loss and self.meta_info['epoch'] > self.min_train_epochs:
            self.best_loss = self.epoch_info['mean_loss']
            self.best_loss_epoch = self.meta_info['epoch']
            self.loss_updated = True
        else:
            self.loss_updated = False
        self.show_epoch_info()

    def train(self):

        if self.args.load_last_best:
            self.load_best_model()
            self.args.start_epoch = self.best_loss_epoch
        for epoch in range(self.args.start_epoch, self.args.num_epoch):
            self.meta_info['epoch'] = epoch

            # training
            self.io.print_log('Training epoch: {}'.format(epoch))
            self.per_train()
            self.io.print_log('Done.')

            # evaluation
            if (epoch % self.args.eval_interval
                    == 0) or (epoch + 1 == self.args.num_epoch):
                self.io.print_log('Eval epoch: {}'.format(epoch))
                self.per_test()
                self.io.print_log('Done.')

            # save model and weights
            if self.loss_updated:
                torch.save(
                    {
                        'model_dict': self.model.state_dict(),
                        'quat_h': self.quat_h
                    },
                    os.path.join(
                        self.args.work_dir,
                        'epoch_{}_loss_{:.4f}_acc_{:.2f}_model.pth.tar'.format(
                            epoch, self.best_loss, self.best_mean_ap * 100.)))

                if self.generate_while_train:
                    self.generate_motion(load_saved_model=False,
                                         samples_to_generate=1)

    def copy_prefix(self, var, target_length):
        shape = list(var.shape)
        shape[1] = self.prefix_length + target_length
        var_pred = torch.zeros(torch.Size(shape)).cuda().float()
        var_pred[:, :self.prefix_length] = var[:, :self.prefix_length]
        return var_pred

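    # flip_trajectory mirrors the tail of the seed trajectory in time so the character
    # retraces its path for the extra generated frames; the matching orientations are
    # currently returned as zeros (the angle-from-tangent computation is commented out).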
    def flip_trajectory(self, traj, target_length):
        traj_flipped = traj[:, -(target_length - self.target_length):].flip(
            dims=[1])
        orient_flipped = torch.zeros(
            (traj_flipped.shape[0], traj_flipped.shape[1], 1)).cuda().float()
        # orient_flipped[:, 0] = np.pi
        # traj_diff = traj_flipped[:, 1:, [0, 2]] - traj_flipped[:, :-1, [0, 2]]
        # traj_diff /= torch.norm(traj_diff, dim=-1)[..., None]
        # orient_flipped[:, 1:, 0] = torch.atan2(traj_diff[:, :, 1], traj_diff[:, :, 0])
        return traj_flipped, orient_flipped

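    # generate_motion rolls the model out for target_length frames along the flipped
    # trajectory, reconstructing positions, affective features and splines at each
    # step, computes a root-relative reconstruction loss, and renders the predicted
    # and ground-truth motions.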
    def generate_motion(self,
                        load_saved_model=True,
                        target_length=100,
                        samples_to_generate=10):

        if load_saved_model:
            self.load_best_model()
        self.model.eval()
        test_loader = self.data_loader['test']

        pos, traj, quat, orient, affs, spline, p_rs, labels = self.return_batch(
            [samples_to_generate], test_loader)
        pos = torch.from_numpy(pos).cuda()
        traj = torch.from_numpy(traj).cuda()
        quat = torch.from_numpy(quat).cuda()
        orient = torch.from_numpy(orient).cuda()
        affs = torch.from_numpy(affs).cuda()
        spline = torch.from_numpy(spline).cuda()
        p_rs = torch.from_numpy(p_rs).cuda()
        labels = torch.from_numpy(labels).cuda()

        traj_flipped, orient_flipped = self.flip_trajectory(
            traj, target_length)
        traj = torch.cat((traj, traj_flipped), dim=1)
        orient = torch.cat((orient, orient_flipped), dim=1)

        pos_pred = self.copy_prefix(pos, target_length)
        quat_pred = self.copy_prefix(quat, target_length)
        p_rs_pred = self.copy_prefix(p_rs, target_length)
        affs_pred = self.copy_prefix(affs, target_length)
        spline_pred = self.copy_prefix(spline, target_length)

        # forward
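        # Autoregressive roll-out: at each step t the model is conditioned on the
        # previous prefix_length frames and the running hidden state quat_h,
        # predicts the next frame's quaternions and p_rs, and the prediction is
        # converted back to world positions, affective features and spline via
        # get_predicted_features to form the next conditioning window.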
        with torch.no_grad():
            for t in range(target_length):
                quat_pred[:, self.prefix_length + t:self.prefix_length + t + 1], \
                    p_rs_pred[:, self.prefix_length + t:self.prefix_length + t + 1], \
                    self.quat_h = \
                    self.model(
                        quat_pred[:, t:self.prefix_length + t],
                        p_rs_pred[:, t:self.prefix_length + t],
                        affs_pred[:, t:self.prefix_length + t],
                        spline_pred[:, t:self.prefix_length + t],
                        orient[:, t:self.prefix_length + t],
                        labels,
                        quat_h=None if t == 0 else self.quat_h, return_prenorm=False)
                data_pred = \
                    self.mocap.get_predicted_features(
                        pos_pred[:, :self.prefix_length + t],
                        orient[:, :self.prefix_length + t],
                        traj[:, self.prefix_length + t:self.prefix_length + t + 1],
                        quat_pred[:, self.prefix_length + t:self.prefix_length + t + 1],
                        orient[:, self.prefix_length + t:self.prefix_length + t + 1])
                pos_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                    data_pred['positions_world']
                affs_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                    data_pred['affective_features']
                spline_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                    data_pred['spline']

            recons_loss = self.args.recons_reg * \
                (pos_pred[:, self.prefix_length:self.T] - pos_pred[:, self.prefix_length:self.T, 0:1] -
                 pos[:, self.prefix_length:self.T] + pos[:, self.prefix_length:self.T, 0:1]).norm()

        pos_pred_np = pos_pred.contiguous().view(pos_pred.shape[0], pos_pred.shape[1], -1).permute(0, 2, 1).\
            detach().cpu().numpy()
        pos_np = pos.contiguous().view(pos.shape[0], pos.shape[1], -1).permute(0, 2, 1).\
            detach().cpu().numpy()
        display_animations(pos_pred_np,
                           self.V,
                           self.C,
                           self.joint_parents,
                           save=True,
                           dataset_name=self.dataset,
                           subset_name='epoch_' + str(self.best_loss_epoch),
                           overwrite=True)
        display_animations(pos_np,
                           self.V,
                           self.C,
                           self.joint_parents,
                           save=True,
                           dataset_name=self.dataset,
                           subset_name='epoch_' + str(self.best_loss_epoch) +
                           '_gt',
                           overwrite=True)
        self.mocap.save_as_bvh(
            traj.detach().cpu().numpy(),
            orient.detach().cpu().numpy(),
            np.reshape(quat_pred.detach().cpu().numpy(),
                       (quat_pred.shape[0], quat_pred.shape[1], -1, self.D)),
            'render/bvh')
Example #13
0
    def generate_motion(self,
                        load_saved_model=True,
                        samples_to_generate=10,
                        max_steps=300,
                        randomized=True):

        if load_saved_model:
            self.load_best_model()
        self.model.eval()
        test_loader = self.data_loader['test']

        joint_offsets, pos, affs, quat, quat_valid_idx, \
            text, text_valid_idx, intended_emotion, intended_polarity, \
            acting_task, gender, age, handedness, \
            native_tongue = self.return_batch([samples_to_generate], test_loader, randomized=randomized)
        with torch.no_grad():
            joint_lengths = torch.norm(joint_offsets, dim=-1)
            scales, _ = torch.max(joint_lengths, dim=-1)
            quat_pred = torch.zeros_like(quat)
            quat_pred_pre_norm = torch.zeros_like(quat)
            quat_pred[:, 0] = torch.cat(
                quat_pred.shape[0] * [self.quats_sos]).view(quat_pred[:, 0].shape)
            text_latent = self.model(text,
                                     intended_emotion,
                                     intended_polarity,
                                     acting_task,
                                     gender,
                                     age,
                                     handedness,
                                     native_tongue,
                                     only_encoder=True)
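            # Greedy autoregressive decoding: quat_pred[:, 0] holds the
            # start-of-sequence quaternion, and each step feeds all frames
            # generated so far into the decoder together with the text latent,
            # keeping only the last predicted frame.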
            for t in range(1, self.T):
                quat_pred_curr, quat_pred_pre_norm_curr = \
                    self.model(text_latent, quat=quat_pred[:, 0:t],
                               offset_lengths=joint_lengths / scales[..., None],
                               only_decoder=True)
                quat_pred[:, t:t + 1] = quat_pred_curr[:, -1:].clone()
                quat_pred_pre_norm[:, t:t + 1] = quat_pred_pre_norm_curr[:, -1:].clone()

            for s in range(len(quat_pred)):
                quat_pred[s] = qfix(
                    quat_pred[s].view(quat_pred[s].shape[0], self.V,
                                      -1)).view(quat_pred[s].shape[0], -1)

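            # Recover world-space joint positions from the predicted rotations via
            # forward kinematics, using a zero root trajectory and the per-sample
            # joint offsets (with a zero offset prepended for the root).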
            root_pos = torch.zeros(quat_pred.shape[0], quat_pred.shape[1],
                                   self.C).cuda()
            pos_pred = MocapDataset.forward_kinematics(
                quat_pred.contiguous().view(quat_pred.shape[0],
                                            quat_pred.shape[1], -1, self.D),
                root_pos, self.joint_parents,
                torch.cat((root_pos[:, 0:1], joint_offsets),
                          dim=1).unsqueeze(1))

        animation_pred = {
            'joint_names': self.joint_names,
            'joint_offsets': joint_offsets,
            'joint_parents': self.joint_parents,
            'positions': pos_pred,
            'rotations': quat_pred
        }
        MocapDataset.save_as_bvh(animation_pred,
                                 dataset_name=self.dataset + '_new',
                                 subset_name='test')
        animation = {
            'joint_names': self.joint_names,
            'joint_offsets': joint_offsets,
            'joint_parents': self.joint_parents,
            'positions': pos,
            'rotations': quat
        }
        MocapDataset.save_as_bvh(animation,
                                 dataset_name=self.dataset + '_new',
                                 subset_name='gt')
        pos_pred_np = pos_pred.contiguous().view(pos_pred.shape[0],
                                                 pos_pred.shape[1], -1).permute(0, 2, 1).\
            detach().cpu().numpy()
Example #14
0
def load_edin_data(_path,
                   num_labels,
                   frame_drop=1,
                   add_mirrored=False,
                   randomized=True):
    if add_mirrored:
        edin_data_dict_file = os.path.join(
            _path,
            'edin_data_dict_with_mirrored_drop_' + str(frame_drop) + '.npz')
    else:
        edin_data_dict_file = os.path.join(
            _path, 'edin_data_dict_drop_' + str(frame_drop) + '.npz')
    try:
        data_dict = np.load(edin_data_dict_file,
                            allow_pickle=True)['data_dict'].item()
        print('Data file found. Returning data.')
    except FileNotFoundError:
        data_dict = dict()
        data_counter = 0
        bvh_files = glob.glob(
            os.path.join(_path, 'data_edin_original') + '/*.bvh')
        num_files = len(bvh_files)
        discard_idx = []
        for fidx, bvh_file in enumerate(bvh_files):
            names, parents, offsets, positions, rotations = \
                MocapDataset.load_bvh([f for f in bvh_files if str(fidx).zfill(6) in f][0])
            if len(positions) < 241:
                discard_idx.append(fidx)
                continue
            positions_down_sampled = positions[1::frame_drop]
            rotations_down_sampled = rotations[1::frame_drop]
            joints_dict = dict()
            joints_dict['joints_to_model'] = np.arange(len(parents))
            joints_dict['joint_parents_all'] = parents
            joints_dict['joint_parents'] = parents
            joints_dict['joint_names_all'] = names
            joints_dict['joint_names'] = names
            joints_dict['joint_offsets_all'] = offsets
            joints_dict['joints_left'] = [
                idx for idx, name in enumerate(names)
                if 'left' in name.lower()
            ]
            joints_dict['joints_right'] = [
                idx for idx, name in enumerate(names)
                if 'right' in name.lower()
            ]
            dict_key = str(data_counter)
            data_counter += 1
            data_dict[dict_key] = dict()
            data_dict[dict_key]['joints_dict'] = joints_dict
            data_dict[dict_key]['positions'] = positions_down_sampled
            data_dict[dict_key]['rotations'] = rotations_down_sampled
            data_dict[dict_key][
                'affective_features'] = MocapDataset.get_affective_features(
                    positions_down_sampled)
            data_dict[dict_key]['trans_and_controls'] =\
                MocapDataset.compute_translations_and_controls(data_dict[dict_key])
            data_dict[dict_key]['spline'] = MocapDataset.compute_splines(
                data_dict[dict_key])
            print('\rData file not found. Processing file: {:3.2f}%'.format(
                fidx * 100. / num_files),
                  end='')
        print('\rData file not found. Processing file: done. Saving...',
              end='')
        labels, num_annotators = load_edin_labels(
            _path, np.array([num_files], dtype='int'))
        if add_mirrored:
            labels = np.repeat(labels, 2, axis=0)
        label_partitions = np.append([0], np.cumsum(num_labels))
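        # L1-normalize each label partition per sample so that every annotated
        # group of categories sums to 1.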
        for lpidx in range(len(num_labels)):
            labels[:, label_partitions[lpidx]:label_partitions[lpidx + 1]] = \
                labels[:, label_partitions[lpidx]:label_partitions[lpidx + 1]] / \
                np.linalg.norm(labels[:, label_partitions[lpidx]:label_partitions[lpidx + 1]], ord=1, axis=1)[:, None]
        for data_counter, idx in enumerate(
                np.setdiff1d(range(num_files), discard_idx)):
            data_dict[str(data_counter)]['labels'] = labels[idx]
        np.savez_compressed(edin_data_dict_file, data_dict=data_dict)
        print('done. Returning data.')
    return data_dict, split_data_dict(data_dict, randomized=randomized)
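A minimal usage sketch for the loader above; the data path, the label-group sizes in num_labels and the frame_drop value are illustrative assumptions, not values from the original project:

data_dict, split = load_edin_data('data/edin', num_labels=[4, 3],
                                  frame_drop=2, randomized=True)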
Example #15
0
def load_data_with_glove(_path,
                         dataset,
                         embedding_src,
                         frame_drop=1,
                         add_mirrored=False):
    data_path = os.path.join(_path, dataset)
    data_dict_file = os.path.join(
        data_path, 'data_dict_glove_drop_' + str(frame_drop) + '.npz')
    try:
        data_dict = np.load(data_dict_file,
                            allow_pickle=True)['data_dict'].item()
        word2idx = np.load(data_dict_file,
                           allow_pickle=True)['word2idx'].item()
        embedding_table = np.load(data_dict_file,
                                  allow_pickle=True)['embedding_table']
        tag_categories = list(
            np.load(data_dict_file, allow_pickle=True)['tag_categories'])
        max_time_steps = np.load(data_dict_file,
                                 allow_pickle=True)['max_time_steps'].item()
        print('Data file found. Returning data.')
    except FileNotFoundError:
        data_dict = []
        word2idx = []
        embedding_table = []
        tag_categories = []
        max_time_steps = 0.
        if dataset == 'mpi':
            channel_map = {
                'Xrotation': 'x',
                'Yrotation': 'y',
                'Zrotation': 'z'
            }
            data_dict = dict()
            tag_names = []
            with open(os.path.join(data_path, 'tag_names.txt')) as names_file:
                for line in names_file.readlines():
                    line = line[:-1]
                    tag_names.append(line)
            id = tag_names.index('ID')
            relevant_tags = [
                'Intended emotion', 'Intended polarity', 'Perceived category',
                'Perceived polarity', 'Acting task', 'Gender', 'Age',
                'Handedness', 'Native tongue', 'Text'
            ]
            tag_categories = [[] for _ in range(len(relevant_tags) - 1)]
            tag_files = glob.glob(os.path.join(data_path, 'tags/*.txt'))
            num_files = len(tag_files)
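            # First pass over the tag files: record every distinct value observed
            # for each relevant categorical tag, so the second pass can one-hot
            # encode tags against these category lists.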
            for tag_file in tag_files:
                tag_data = []
                with open(tag_file) as f:
                    for line in f.readlines():
                        line = line[:-1]
                        tag_data.append(line)
                for category in range(len(tag_categories)):
                    tag_to_append = relevant_tags[category]
                    if tag_data[tag_names.index(
                            tag_to_append)] not in tag_categories[category]:
                        tag_categories[category].append(
                            tag_data[tag_names.index(tag_to_append)])

            all_texts = [[] for _ in range(len(tag_files))]
            for data_counter, tag_file in enumerate(tag_files):
                tag_data = []
                with open(tag_file) as f:
                    for line in f.readlines():
                        line = line[:-1]
                        tag_data.append(line)
                bvh_file = os.path.join(data_path,
                                        'bvh/' + tag_data[id] + '.bvh')
                names, parents, offsets, positions, rotations = \
                    MocapDataset.load_bvh(bvh_file, channel_map)
                positions_down_sampled = positions[1::frame_drop]
                rotations_down_sampled = rotations[1::frame_drop]
                if len(positions_down_sampled) > max_time_steps:
                    max_time_steps = len(positions_down_sampled)
                joints_dict = dict()
                joints_dict['joints_to_model'] = np.arange(len(parents))
                joints_dict['joints_parents_all'] = parents
                joints_dict['joints_parents'] = parents
                joints_dict['joints_names_all'] = names
                joints_dict['joints_names'] = names
                joints_dict['joints_offsets_all'] = offsets
                joints_dict['joints_left'] = [
                    idx for idx, name in enumerate(names)
                    if 'left' in name.lower()
                ]
                joints_dict['joints_right'] = [
                    idx for idx, name in enumerate(names)
                    if 'right' in name.lower()
                ]
                data_dict[tag_data[id]] = dict()
                data_dict[tag_data[id]]['joints_dict'] = joints_dict
                data_dict[tag_data[id]]['positions'] = positions_down_sampled
                data_dict[tag_data[id]]['rotations'] = rotations_down_sampled
                data_dict[tag_data[id]]['affective_features'] =\
                    MocapDataset.get_mpi_affective_features(positions_down_sampled)
                for tag_index, tag_name in enumerate(relevant_tags):
                    if tag_name.lower() == 'text':
                        all_texts[data_counter] = [
                            e for e in str.split(tag_data[tag_names.index(
                                tag_name)]) if e.isalnum()
                        ]
                        data_dict[tag_data[id]][tag_name] = tag_data[
                            tag_names.index(tag_name)]
                        text_length = len(data_dict[tag_data[id]][tag_name])
                        continue
                    if tag_name.lower() == 'age':
                        data_dict[tag_data[id]][tag_name] = float(
                            tag_data[tag_names.index(tag_name)]) / 100.
                        continue
                    if tag_name == 'Perceived category':
                        categories = tag_categories[0]
                    elif tag_name == 'Perceived polarity':
                        categories = tag_categories[1]
                    else:
                        categories = tag_categories[tag_index]
                    data_dict[tag_data[id]][tag_name] = to_one_hot(
                        tag_data[tag_names.index(tag_name)], categories)
                print(
                    '\rData file not found. Reading data files {}/{}: {:3.2f}%'
                    .format(data_counter + 1, num_files,
                            data_counter * 100. / num_files),
                    end='')
            print('\rData file not found. Reading files: done.')
            print('Preparing embedding table:')
            word2idx = build_vocab_idx(all_texts, min_word_count=0)
            embedding_table = build_embedding_table(embedding_src, word2idx)
            np.savez_compressed(data_dict_file,
                                data_dict=data_dict,
                                word2idx=word2idx,
                                embedding_table=embedding_table,
                                tag_categories=tag_categories,
                                max_time_steps=max_time_steps)
            print('done. Returning data.')
        elif dataset == 'creative_it':
            mocap_data_dirs = os.listdir(os.path.join(data_path, 'mocap'))
            for mocap_dir in mocap_data_dirs:
                mocap_data_files = glob.glob(
                    os.path.join(data_path, 'mocap/' + mocap_dir + '/*.txt'))
        else:
            raise FileNotFoundError('Dataset not found.')

    return data_dict, word2idx, embedding_table, tag_categories, max_time_steps
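A minimal usage sketch for load_data_with_glove; the data root and the GloVe embedding path below are illustrative assumptions:

data_dict, word2idx, embedding_table, tag_categories, max_time_steps = \
    load_data_with_glove('data', 'mpi', 'glove/glove.6B.300d.txt', frame_drop=2)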
Example #16
0
def load_cmu_data(_path,
                  V,
                  C,
                  joints_to_model=None,
                  frame_drop=1,
                  add_mirrored=False):
    data_path = os.path.join(_path, 'data_cmu_cleaned')
    cmu_data_dict_file = os.path.join(
        _path, 'cmu_data_dict_drop_' + str(frame_drop) + '.npz')
    try:
        data_dict = np.load(cmu_data_dict_file,
                            allow_pickle=True)['data_dict'].item()
        min_time_steps = np.load(cmu_data_dict_file,
                                 allow_pickle=True)['min_time_steps'].item()
        print('Data file found. Returning data.')
    except FileNotFoundError:
        cmu_data_files = glob.glob(os.path.join(data_path, '*.bvh'))
        channel_map = {'Xrotation': 'x', 'Yrotation': 'y', 'Zrotation': 'z'}
        mocap = MocapDataset(V, C)
        data_dict = dict()
        labels_file = os.path.join(_path, 'labels_cmu/cmu_labels.csv')
        data_name = ['' for _ in range(len(cmu_data_files))]
        labels = []
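        # Parse the label CSV: the header row gives the label names, each
        # subsequent row maps a clip name to its per-label scores; the scores are
        # then L1-normalized and reduced to the happy/sad/angry/neutral columns.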
        with open(os.path.join(labels_file)) as csv_file:
            read_lines = csv.reader(csv_file, delimiter=',')
            row_count = -1
            for row in read_lines:
                row_count += 1
                if row_count == 0:
                    labels_order = [x.lower() for x in row[1:]]
                    continue
                data_name[row_count - 1] = row[0]
                labels.append(list(map(float, row[1:])))
        labels = np.stack(labels)
        labels /= np.linalg.norm(labels, ord=1, axis=-1)[..., None]
        emo_idx = [
            labels_order.index(x)
            for x in ['happy', 'sad', 'angry', 'neutral']
        ]
        labels = labels[:, emo_idx]
        num_files = len(cmu_data_files)
        min_time_steps = np.inf
        for data_counter, file in enumerate(cmu_data_files):
            offsets, positions, orientations, rot_in, rotations =\
                mocap.load_bvh(file, channel_map, joints_to_model=joints_to_model)
            if len(positions) - 1 < min_time_steps:
                min_time_steps = len(positions) - 1
            data_dict[str(data_counter)] = mocap.get_features_from_data(
                'cmu',
                offsets=offsets,
                positions=positions[1::frame_drop],
                orientations=orientations[1::frame_drop],
                rotations=rotations[1::frame_drop])
            file_name = os.path.basename(file).split('.')[0]
            # mocap.save_as_bvh(np.expand_dims(positions[:, 0], axis=0),
            #                   np.expand_dims(np.expand_dims(orientations, axis=-1), axis=0),
            #                   np.expand_dims(rotations, axis=0),
            #                   np.expand_dims(rot_in, axis=0),
            #                   dataset_name='cmu', subset_name='test', save_file_names=[file_name])
            data_dict[str(data_counter)]['labels'] = labels[data_name.index(
                file_name)]
            print('\rData file not found. Processing file {}/{}: {:3.2f}%'.
                  format(data_counter + 1, num_files,
                         data_counter * 100. / num_files),
                  end='')
        min_time_steps = int(min_time_steps / frame_drop)
        print('\rData file not found. Processing files: done. Saving...',
              end='')
        np.savez_compressed(cmu_data_dict_file,
                            data_dict=data_dict,
                            min_time_steps=min_time_steps)
        print('done. Returning data.')
    return data_dict, min_time_steps
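A minimal usage sketch for load_cmu_data; the data root and the joint/channel counts V and C are illustrative assumptions:

data_dict, min_time_steps = load_cmu_data('data', V=21, C=3, frame_drop=2)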
Example #17
0
import torch
from utils.mocap_dataset import MocapDataset as MD

anim = MD.load_bvh(
    '/media/uttaran/repo0/Gamma/MotionSim/src/quater_long_term_emonet_2/render/bvh/edin/test/000001.bvh'
)
animation_pred = {
    'joint_names': anim[0],
    'joint_offsets': torch.from_numpy(anim[2][1:]).unsqueeze(0),
    'joint_parents': anim[1],
    'positions': torch.from_numpy(anim[3]).unsqueeze(0),
    'rotations': torch.from_numpy(anim[4][1:]).unsqueeze(0)
}
MD.save_as_bvh(animation_pred, dataset_name='edin', subset_name='fixed')