Example #1
    def per_eval(self):

        self.model.eval()
        test_loader = self.data_loader['test']
        eval_loss = 0.
        N = 0.

        for joint_offsets, pos, affs, quat, quat_valid_idx, \
            text, text_valid_idx, intended_emotion, intended_polarity, \
            acting_task, gender, age, handedness, \
                native_tongue in self.yield_batch(self.args.batch_size, test_loader):
            with torch.no_grad():
                joint_lengths = torch.norm(joint_offsets, dim=-1)
                scales, _ = torch.max(joint_lengths, dim=-1)
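                # Build the decoder prelude from repeated EOS quaternions,
                # then place the first ground-truth frame in its last slot.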
                quat_prelude = self.quats_eos.view(1, -1).cuda() \
                    .repeat(self.T - 1, 1).unsqueeze(0).repeat(quat.shape[0], 1, 1).float()
                quat_prelude[:, -1] = quat[:, 0].clone()
                quat_pred, quat_pred_pre_norm = self.model(
                    text, intended_emotion, intended_polarity, acting_task,
                    gender, age, handedness, native_tongue, quat_prelude,
                    joint_lengths / scales[..., None])
                quat_pred_pre_norm = quat_pred_pre_norm.view(
                    quat_pred_pre_norm.shape[0], quat_pred_pre_norm.shape[1],
                    -1, self.D)
                quat_norm_loss = self.args.quat_norm_reg *\
                                 torch.mean((torch.sum(quat_pred_pre_norm ** 2, dim=-1) - 1) ** 2)

                quat_loss, quat_derv_loss = losses.quat_angle_loss(
                    quat_pred, quat[:, 1:], quat_valid_idx[:, 1:], self.V,
                    self.D, self.lower_body_start, self.args.upper_body_weight)
                quat_loss *= self.args.quat_reg

                root_pos = torch.zeros(quat_pred.shape[0], quat_pred.shape[1],
                                       self.C).cuda()
                pos_pred = MocapDataset.forward_kinematics(
                    quat_pred.contiguous().view(quat_pred.shape[0],
                                                quat_pred.shape[1], -1,
                                                self.D), root_pos,
                    self.joint_parents,
                    torch.cat((root_pos[:, 0:1], joint_offsets),
                              dim=1).unsqueeze(1))
                affs_pred = MocapDataset.get_mpi_affective_features(pos_pred)

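                # Count of valid quaternion elements per sequence, used to
                # normalize masked sums; zero rows are set to 1 to avoid
                # division by zero.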
                row_sums = quat_valid_idx.sum(1,
                                              keepdim=True) * self.D * self.V
                row_sums[row_sums == 0.] = 1.

                shifted_pos = pos - pos[:, :, 0:1]
                shifted_pos_pred = pos_pred - pos_pred[:, :, 0:1]

                recons_loss = self.recons_loss_func(shifted_pos_pred,
                                                    shifted_pos[:, 1:])
                # recons_loss = torch.abs(shifted_pos_pred[:, 1:] - shifted_pos[:, 1:]).sum(-1)
                # recons_loss = self.args.upper_body_weight * (recons_loss[:, :, :self.lower_body_start].sum(-1)) + \
                #               recons_loss[:, :, self.lower_body_start:].sum(-1)
                # recons_loss = self.args.recons_reg * torch.mean(
                #     (recons_loss * quat_valid_idx[:, 1:]).sum(-1) / row_sums)
                #
                # recons_derv_loss = torch.abs(shifted_pos_pred[:, 2:] - shifted_pos_pred[:, 1:-1] -
                #                              shifted_pos[:, 2:] + shifted_pos[:, 1:-1]).sum(-1)
                # recons_derv_loss = self.args.upper_body_weight * \
                #                    (recons_derv_loss[:, :, :self.lower_body_start].sum(-1)) + \
                #                    recons_derv_loss[:, :, self.lower_body_start:].sum(-1)
                # recons_derv_loss = 2. * self.args.recons_reg * \
                #                    torch.mean((recons_derv_loss * quat_valid_idx[:, 2:]).sum(-1) / row_sums)
                #
                # affs_loss = torch.abs(affs[:, 1:] - affs_pred[:, 1:]).sum(-1)
                # affs_loss = self.args.affs_reg * torch.mean((affs_loss * quat_valid_idx[:, 1:]).sum(-1) / row_sums)
                affs_loss = self.affs_loss_func(affs_pred, affs[:, 1:])

                eval_loss += (quat_norm_loss + quat_loss + recons_loss + affs_loss).item()
                # eval_loss += (quat_norm_loss + quat_loss + recons_loss + recons_derv_loss + affs_loss).item()
                N += quat.shape[0]

        eval_loss /= N
        self.epoch_info['mean_loss'] = eval_loss
        if self.epoch_info['mean_loss'] < self.best_loss and self.meta_info[
                'epoch'] > self.min_train_epochs:
            self.best_loss = self.epoch_info['mean_loss']
            self.best_loss_epoch = self.meta_info['epoch']
            self.loss_updated = True
        else:
            self.loss_updated = False
        self.show_epoch_info()
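The quat_norm_loss term above pushes each raw (pre-normalization) quaternion toward unit length. Factored out as a standalone helper with an illustrative name and weight (a minimal sketch, not the source's API):

import torch

def quat_unit_norm_penalty(quat_pre_norm, reg_weight=1.0):
    # quat_pre_norm: (..., J, D) raw network outputs before normalization.
    # Penalize the squared deviation of each squared norm from 1 so the
    # raw outputs stay close to valid unit quaternions.
    sq_norm = torch.sum(quat_pre_norm ** 2, dim=-1)
    return reg_weight * torch.mean((sq_norm - 1.) ** 2)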
Example #2
    def per_train(self):

        self.model.train()
        train_loader = self.data_loader['train']
        batch_loss = 0.
        N = 0.

        for joint_offsets, pos, affs, quat, quat_valid_idx,\
            text, text_valid_idx, intended_emotion, intended_polarity,\
                acting_task, gender, age, handedness,\
                native_tongue in self.yield_batch(self.args.batch_size, train_loader):
            quat_prelude = self.quats_eos.view(1, -1).cuda() \
                .repeat(self.T - 1, 1).unsqueeze(0).repeat(quat.shape[0], 1, 1).float()
            quat_prelude[:, -1] = quat[:, 0].clone()

            self.optimizer.zero_grad()
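            # detect_anomaly wraps every backward op with NaN/Inf checks;
            # helpful while debugging exploding losses, but slow, so it is
            # usually removed once training is stable.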
            with torch.autograd.detect_anomaly():
                joint_lengths = torch.norm(joint_offsets, dim=-1)
                scales, _ = torch.max(joint_lengths, dim=-1)
                quat_pred, quat_pred_pre_norm = self.model(
                    text, intended_emotion, intended_polarity, acting_task,
                    gender, age, handedness, native_tongue, quat_prelude,
                    joint_lengths / scales[..., None])

                quat_pred_pre_norm = quat_pred_pre_norm.view(
                    quat_pred_pre_norm.shape[0], quat_pred_pre_norm.shape[1],
                    -1, self.D)
                quat_norm_loss = self.args.quat_norm_reg *\
                                 torch.mean((torch.sum(quat_pred_pre_norm ** 2, dim=-1) - 1) ** 2)

                quat_loss, quat_derv_loss = losses.quat_angle_loss(
                    quat_pred, quat[:, 1:], quat_valid_idx[:, 1:], self.V,
                    self.D, self.lower_body_start, self.args.upper_body_weight)
                quat_loss *= self.args.quat_reg

                root_pos = torch.zeros(quat_pred.shape[0], quat_pred.shape[1],
                                       self.C).cuda()
                pos_pred = MocapDataset.forward_kinematics(
                    quat_pred.contiguous().view(quat_pred.shape[0],
                                                quat_pred.shape[1], -1,
                                                self.D), root_pos,
                    self.joint_parents,
                    torch.cat((root_pos[:, 0:1], joint_offsets),
                              dim=1).unsqueeze(1))
                affs_pred = MocapDataset.get_mpi_affective_features(pos_pred)

                # row_sums = quat_valid_idx.sum(1, keepdim=True) * self.D * self.V
                # row_sums[row_sums == 0.] = 1.

                shifted_pos = pos - pos[:, :, 0:1]
                shifted_pos_pred = pos_pred - pos_pred[:, :, 0:1]

                recons_loss = self.recons_loss_func(shifted_pos_pred,
                                                    shifted_pos[:, 1:])
                # recons_loss = torch.abs(shifted_pos_pred - shifted_pos[:, 1:]).sum(-1)
                # recons_loss = self.args.upper_body_weight * (recons_loss[:, :, :self.lower_body_start].sum(-1)) +\
                #               recons_loss[:, :, self.lower_body_start:].sum(-1)
                # recons_loss = self.args.recons_reg *\
                #               torch.mean((recons_loss * quat_valid_idx[:, 1:]).sum(-1) / row_sums)
                #
                # recons_derv_loss = torch.abs(shifted_pos_pred[:, 1:] - shifted_pos_pred[:, :-1] -
                #                              shifted_pos[:, 2:] + shifted_pos[:, 1:-1]).sum(-1)
                # recons_derv_loss = self.args.upper_body_weight *\
                #     (recons_derv_loss[:, :, :self.lower_body_start].sum(-1)) +\
                #                    recons_derv_loss[:, :, self.lower_body_start:].sum(-1)
                # recons_derv_loss = 2. * self.args.recons_reg *\
                #                    torch.mean((recons_derv_loss * quat_valid_idx[:, 2:]).sum(-1) / row_sums)
                #
                # affs_loss = torch.abs(affs[:, 1:] - affs_pred).sum(-1)
                # affs_loss = self.args.affs_reg * torch.mean((affs_loss * quat_valid_idx[:, 1:]).sum(-1) / row_sums)
                affs_loss = self.affs_loss_func(affs_pred, affs[:, 1:])

                train_loss = quat_norm_loss + quat_loss + recons_loss + affs_loss
                # train_loss = quat_norm_loss + quat_loss + recons_loss + recons_derv_loss + affs_loss
                train_loss.backward()
                # nn.utils.clip_grad_norm_(self.model.parameters(), self.args.gradient_clip)
                self.optimizer.step()

            # animation_pred = {
            #     'joint_names': self.joint_names,
            #     'joint_offsets': joint_offsets,
            #     'joint_parents': self.joint_parents,
            #     'positions': pos_pred,
            #     'rotations': quat_pred
            # }
            # MocapDataset.save_as_bvh(animation_pred,
            #                          dataset_name=self.dataset,
            #                          subset_name='test')
            # animation = {
            #     'joint_names': self.joint_names,
            #     'joint_offsets': joint_offsets,
            #     'joint_parents': self.joint_parents,
            #     'positions': pos,
            #     'rotations': quat
            # }
            # MocapDataset.save_as_bvh(animation,
            #                          dataset_name=self.dataset,
            #                          subset_name='gt')

            # Compute statistics
            batch_loss += train_loss.item()
            N += quat.shape[0]

            # statistics
            self.iter_info['loss'] = train_loss.data.item()
            self.iter_info['lr'] = '{:.6f}'.format(self.lr)
            self.iter_info['tf'] = '{:.6f}'.format(self.tf)
            self.show_iter_info()
            self.meta_info['iter'] += 1

        batch_loss = batch_loss / N
        self.epoch_info['mean_loss'] = batch_loss
        self.show_epoch_info()
        self.io.print_timer()
        self.adjust_lr()
        self.adjust_tf()
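losses.quat_angle_loss is called throughout but not shown. A common formulation penalizes the angle between predicted and target unit quaternions via their absolute dot product; the sketch below is one plausible, simplified version (the real function also takes joint masks and an upper-body weight, which are omitted here):

import torch

def quat_angle_loss_sketch(q_pred, q_target, valid_idx, V, D):
    # q_pred, q_target: (B, T, V * D) with D == 4 components per joint.
    # valid_idx: (B, T) mask of valid frames.
    B, T, _ = q_pred.shape
    q_pred = q_pred.reshape(B, T, V, D)
    q_target = q_target.reshape(B, T, V, D)
    # |<q1, q2>| = cos(theta / 2) for unit quaternions, so 1 - |dot| grows
    # monotonically with the rotation angle and is 0 at a perfect match.
    dot = torch.abs(torch.sum(q_pred * q_target, dim=-1)).clamp(max=1.)
    per_frame = (1. - dot).sum(-1)                          # (B, T)
    denom = valid_idx.sum(-1).clamp(min=1.)                 # (B,)
    angle_loss = torch.mean((per_frame * valid_idx).sum(-1) / denom)
    # The derivative term penalizes frame-to-frame rotation changes of the
    # prediction in the same way (a simple analogue of the source's version).
    derv_dot = torch.abs(torch.sum(q_pred[:, 1:] * q_pred[:, :-1], dim=-1))
    derv_loss = torch.mean(1. - derv_dot.clamp(max=1.))
    return angle_loss, derv_loss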
Example #3
    def forward_pass(self, joint_offsets, pos, affs, quat, quat_sos,
                     quat_valid_idx, text, text_valid_idx, perceived_emotion,
                     perceived_polarity, acting_task, gender, age, handedness,
                     native_tongue):
        self.optimizer.zero_grad()
        with torch.autograd.detect_anomaly():
            joint_lengths = torch.norm(joint_offsets, dim=-1)
            scales, _ = torch.max(joint_lengths, dim=-1)
            # text_latent = self.model(text, perceived_emotion, perceived_polarity,
            #                          acting_task, gender, age, handedness, native_tongue,
            #                          only_encoder=True)
            # quat_pred = torch.zeros_like(quat)
            # quat_pred_pre_norm = torch.zeros_like(quat)
            # quat_in = quat_sos[:, :self.T_steps]
            # quat_valid_idx_max = torch.max(torch.sum(quat_valid_idx, dim=-1))
            # for t in range(0, self.T, self.T_steps):
            #     if t > quat_valid_idx_max:
            #         break
            #     quat_pred[:, t:min(self.T, t + self.T_steps)],\
            #         quat_pred_pre_norm[:, t:min(self.T, t + self.T_steps)] =\
            #         self.model(text_latent, quat=quat_in, offset_lengths= joint_lengths / scales[..., None],
            #                    only_decoder=True)
            #     if torch.rand(1) > self.tf:
            #         if t + self.T_steps * 2 >= self.T:
            #             quat_in = quat_pred[:, -(self.T - t - self.T_steps):].clone()
            #         else:
            #             quat_in = quat_pred[:, t:t + self.T_steps].clone()
            #     else:
            #         if t + self.T_steps * 2 >= self.T:
            #             quat_in = quat[:, -(self.T - t - self.T_steps):].clone()
            #         else:
            #             quat_in = quat[:, t:min(self.T, t + self.T_steps)].clone()
            quat_pred, quat_pred_pre_norm = self.model(
                text, perceived_emotion, perceived_polarity, acting_task,
                gender, age, handedness, native_tongue, quat_sos[:, :-1],
                joint_lengths / scales[..., None])
            quat_fixed = qfix(quat.contiguous().view(
                quat.shape[0], quat.shape[1], -1,
                self.D)).contiguous().view(quat.shape[0], quat.shape[1], -1)
            quat_pred = qfix(quat_pred.contiguous().view(
                quat_pred.shape[0], quat_pred.shape[1], -1,
                self.D)).contiguous().view(quat_pred.shape[0],
                                           quat_pred.shape[1], -1)

            quat_pred_pre_norm = quat_pred_pre_norm.view(
                quat_pred_pre_norm.shape[0], quat_pred_pre_norm.shape[1], -1,
                self.D)
            quat_norm_loss = self.args.quat_norm_reg *\
                torch.mean((torch.sum(quat_pred_pre_norm ** 2, dim=-1) - 1) ** 2)

            # quat_loss, quat_derv_loss = losses.quat_angle_loss(quat_pred, quat_fixed,
            #                                                    quat_valid_idx[:, 1:],
            #                                                    self.V, self.D,
            #                                                    self.lower_body_start,
            #                                                    self.args.upper_body_weight)
            quat_loss, quat_derv_loss = losses.quat_angle_loss(
                quat_pred, quat_fixed[:, 1:], quat_valid_idx[:, 1:], self.V,
                self.D, self.lower_body_start, self.args.upper_body_weight)
            quat_loss *= self.args.quat_reg

            root_pos = torch.zeros(quat_pred.shape[0], quat_pred.shape[1],
                                   self.C).cuda()
            pos_pred = MocapDataset.forward_kinematics(
                quat_pred.contiguous().view(quat_pred.shape[0],
                                            quat_pred.shape[1], -1, self.D),
                root_pos, self.joint_parents,
                torch.cat((root_pos[:, 0:1], joint_offsets),
                          dim=1).unsqueeze(1))
            affs_pred = MocapDataset.get_mpi_affective_features(pos_pred)

            row_sums = quat_valid_idx.sum(1, keepdim=True) * self.D * self.V
            row_sums[row_sums == 0.] = 1.

            shifted_pos = pos - pos[:, :, 0:1]
            shifted_pos_pred = pos_pred - pos_pred[:, :, 0:1]

            # recons_loss = self.recons_loss_func(shifted_pos_pred, shifted_pos)
            # recons_arms = self.recons_loss_func(shifted_pos_pred[:, :, 7:15], shifted_pos[:, :, 7:15])
            recons_loss = self.recons_loss_func(shifted_pos_pred,
                                                shifted_pos[:, 1:])
            # recons_arms = self.recons_loss_func(shifted_pos_pred[:, :, 7:15], shifted_pos[:, 1:, 7:15])
            # recons_loss = torch.abs(shifted_pos_pred - shifted_pos[:, 1:]).sum(-1)
            # recons_loss = self.args.upper_body_weight * (recons_loss[:, :, :self.lower_body_start].sum(-1)) +\
            #               recons_loss[:, :, self.lower_body_start:].sum(-1)
            recons_loss = self.args.recons_reg *\
                torch.mean((recons_loss * quat_valid_idx[:, 1:]).sum(-1) / row_sums)
            #
            # recons_derv_loss = torch.abs(shifted_pos_pred[:, 1:] - shifted_pos_pred[:, :-1] -
            #                              shifted_pos[:, 2:] + shifted_pos[:, 1:-1]).sum(-1)
            # recons_derv_loss = self.args.upper_body_weight *\
            #     (recons_derv_loss[:, :, :self.lower_body_start].sum(-1)) +\
            #                    recons_derv_loss[:, :, self.lower_body_start:].sum(-1)
            # recons_derv_loss = 2. * self.args.recons_reg *\
            #                    torch.mean((recons_derv_loss * quat_valid_idx[:, 2:]).sum(-1) / row_sums)
            #
            # affs_loss = torch.abs(affs[:, 1:] - affs_pred).sum(-1)
            affs_loss = self.affs_loss_func(affs_pred, affs[:, 1:])
            affs_loss = self.args.affs_reg * torch.mean(
                (affs_loss * quat_valid_idx[:, 1:]).sum(-1) / row_sums)
            # affs_loss = self.affs_loss_func(affs_pred, affs)

            total_loss = quat_norm_loss + quat_loss + recons_loss + affs_loss
            # total_loss = quat_norm_loss + quat_loss + recons_loss + recons_derv_loss + affs_loss

        return total_loss
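qfix above resolves the quaternion double cover: q and -q encode the same rotation, so signs are flipped to keep consecutive frames on the same hemisphere. A sketch of the standard continuity fix for batched tensors (the source's exact implementation may differ):

import torch

def qfix_sketch(q):
    # q: (B, T, J, 4). Flip the sign of a frame's quaternion whenever an
    # odd number of sign changes (negative dot products between adjacent
    # frames) has accumulated up to that frame.
    dots = torch.sum(q[:, 1:] * q[:, :-1], dim=-1)          # (B, T - 1, J)
    flips = torch.cumsum((dots < 0).long(), dim=1) % 2
    sign = torch.ones_like(q[..., 0])
    sign[:, 1:] = 1. - 2. * flips.float()                   # +1 keep, -1 flip
    return q * sign.unsqueeze(-1)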
Example #4
    def per_test(self):

        self.model.eval()
        test_loader = self.data_loader['test']
        valid_loss = 0.
        N = 0.

        for pos, quat, orient, z_mean, z_dev,\
                root_speed, affs, spline, labels in self.yield_batch(self.args.batch_size, test_loader):
            with torch.no_grad():
                pos = torch.from_numpy(pos).cuda()
                orient = torch.from_numpy(orient).cuda()
                quat = torch.from_numpy(quat).cuda()
                z_mean = torch.from_numpy(z_mean).cuda()
                z_dev = torch.from_numpy(z_dev).cuda()
                root_speed = torch.from_numpy(root_speed).cuda()
                affs = torch.from_numpy(affs).cuda()
                spline = torch.from_numpy(spline).cuda()
                labels = torch.from_numpy(labels).cuda().repeat(1, quat.shape[1], 1)
                z_rs = torch.cat((z_dev, root_speed), dim=-1)
                quat_all = torch.cat((orient[:, self.prefix_length - 1:], quat[:, self.prefix_length - 1:]), dim=-1)

                pos_pred = pos.clone()
                orient_pred = orient.clone()
                quat_pred = quat.clone()
                z_rs_pred = z_rs.clone()
                affs_pred = affs.clone()
                spline_pred = spline.clone()
                orient_prenorm_terms = torch.zeros_like(orient_pred)
                quat_prenorm_terms = torch.zeros_like(quat_pred)

                # forward
                for t in range(self.target_length):
                    orient_pred[:, self.prefix_length + t:self.prefix_length + t + 1],\
                        quat_pred[:, self.prefix_length + t:self.prefix_length + t + 1],\
                        z_rs_pred[:, self.prefix_length + t:self.prefix_length + t + 1],\
                        self.orient_h, self.quat_h,\
                        orient_prenorm_terms[:, self.prefix_length + t: self.prefix_length + t + 1],\
                        quat_prenorm_terms[:, self.prefix_length + t: self.prefix_length + t + 1] = \
                        self.model(
                            orient_pred[:, t:self.prefix_length + t],
                            quat_pred[:, t:self.prefix_length + t],
                            z_rs_pred[:, t:self.prefix_length + t],
                            affs_pred[:, t:self.prefix_length + t],
                            spline[:, t:self.prefix_length + t],
                            labels[:, t:self.prefix_length + t],
                            orient_h=None if t == 0 else self.orient_h,
                            quat_h=None if t == 0 else self.quat_h, return_prenorm=True)
                    pos_pred[:, self.prefix_length + t:self.prefix_length + t + 1], \
                        affs_pred[:, self.prefix_length + t:self.prefix_length + t + 1],\
                        spline_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                        self.mocap.get_predicted_features(
                            pos_pred[:, :self.prefix_length + t],
                            pos_pred[:, self.prefix_length + t:self.prefix_length + t + 1, 0, [0, 2]],
                            z_rs_pred[:, self.prefix_length + t:self.prefix_length + t + 1, 0] + z_mean,
                            orient_pred[:, self.prefix_length + t:self.prefix_length + t + 1],
                            quat_pred[:, self.prefix_length + t:self.prefix_length + t + 1])

                prenorm_terms = torch.cat((orient_prenorm_terms, quat_prenorm_terms), dim=-1)
                prenorm_terms = prenorm_terms.view(prenorm_terms.shape[0], prenorm_terms.shape[1], -1, self.D)
                quat_norm_loss = self.args.quat_norm_reg *\
                    torch.mean((torch.sum(prenorm_terms ** 2, dim=-1) - 1) ** 2)

                quat_loss, quat_derv_loss = losses.quat_angle_loss(
                    torch.cat((orient_pred[:, self.prefix_length - 1:],
                               quat_pred[:, self.prefix_length - 1:]), dim=-1),
                    quat_all, self.V, self.D)
                quat_loss *= self.args.quat_reg

                recons_loss = self.args.recons_reg *\
                              (pos_pred[:, self.prefix_length:] - pos_pred[:, self.prefix_length:, 0:1] -
                               pos[:, self.prefix_length:] + pos[:, self.prefix_length:, 0:1]).norm()
                valid_loss += recons_loss.item()
                N += quat.shape[0]

        valid_loss /= N
        # if self.meta_info['epoch'] > 5 and self.loss_updated:
        #     pos_pred_np = pos_pred.contiguous().view(pos_pred.shape[0], pos_pred.shape[1], -1).permute(0, 2, 1).\
        #         detach().cpu().numpy()
        #     display_animations(pos_pred_np, self.V, self.C, self.mocap.joint_parents, save=True,
        #                        dataset_name=self.dataset, subset_name='epoch_' + str(self.best_loss_epoch),
        #                        overwrite=True)
        #     pos_in_np = pos_in.contiguous().view(pos_in.shape[0], pos_in.shape[1], -1).permute(0, 2, 1).\
        #         detach().cpu().numpy()
        #     display_animations(pos_in_np, self.V, self.C, self.mocap.joint_parents, save=True,
        #                        dataset_name=self.dataset, subset_name='epoch_' + str(self.best_loss_epoch) +
        #                                                               '_gt',
        #                        overwrite=True)

        self.epoch_info['mean_loss'] = valid_loss
        if self.epoch_info['mean_loss'] < self.best_loss and self.meta_info['epoch'] > self.min_train_epochs:
            self.best_loss = self.epoch_info['mean_loss']
            self.best_loss_epoch = self.meta_info['epoch']
            self.loss_updated = True
        else:
            self.loss_updated = False
        self.show_epoch_info()
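The validation reconstruction term above compares root-relative joint positions over the predicted frames only; pulled out as a helper (names are illustrative):

import torch

def root_relative_recons(pos_pred, pos_gt, prefix_length, recons_reg):
    # pos_pred, pos_gt: (B, T, J, 3). Subtracting the root joint removes
    # global translation before taking the error norm over predicted frames.
    pred = pos_pred[:, prefix_length:] - pos_pred[:, prefix_length:, 0:1]
    gt = pos_gt[:, prefix_length:] - pos_gt[:, prefix_length:, 0:1]
    return recons_reg * (pred - gt).norm()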
Example #5
    def per_train(self):

        self.model.train()
        train_loader = self.data_loader['train']
        batch_loss = 0.
        N = 0.

        for pos, quat, orient, z_mean, z_dev,\
                root_speed, affs, spline, labels in self.yield_batch(self.args.batch_size, train_loader):

            pos = torch.from_numpy(pos).cuda()
            orient = torch.from_numpy(orient).cuda()
            quat = torch.from_numpy(quat).cuda()
            z_mean = torch.from_numpy(z_mean).cuda()
            z_dev = torch.from_numpy(z_dev).cuda()
            root_speed = torch.from_numpy(root_speed).cuda()
            affs = torch.from_numpy(affs).cuda()
            spline = torch.from_numpy(spline).cuda()
            labels = torch.from_numpy(labels).cuda().repeat(1, quat.shape[1], 1)
            z_rs = torch.cat((z_dev, root_speed), dim=-1)
            quat_all = torch.cat((orient[:, self.prefix_length - 1:], quat[:, self.prefix_length - 1:]), dim=-1)

            pos_pred = pos.clone()
            orient_pred = orient.clone()
            quat_pred = quat.clone()
            z_rs_pred = z_rs.clone()
            affs_pred = affs.clone()
            spline_pred = spline.clone()
            pos_pred_all = pos.clone()
            orient_pred_all = orient.clone()
            quat_pred_all = quat.clone()
            z_rs_pred_all = z_rs.clone()
            affs_pred_all = affs.clone()
            spline_pred_all = spline.clone()
            orient_prenorm_terms = torch.zeros_like(orient_pred)
            quat_prenorm_terms = torch.zeros_like(quat_pred)
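            # *_pred buffers hold the rolling model inputs (updated with the
            # model's own outputs only when scheduled sampling fires below);
            # *_pred_all always store the model's outputs at each step.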

            # forward
            self.optimizer.zero_grad()
            for t in range(self.target_length):
                orient_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1],\
                    quat_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1],\
                    z_rs_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1],\
                    self.orient_h, self.quat_h,\
                    orient_prenorm_terms[:, self.prefix_length + t: self.prefix_length + t + 1],\
                    quat_prenorm_terms[:, self.prefix_length + t: self.prefix_length + t + 1] = \
                    self.model(
                        orient_pred[:, t:self.prefix_length + t],
                        quat_pred[:, t:self.prefix_length + t],
                        z_rs_pred[:, t:self.prefix_length + t],
                        affs_pred[:, t:self.prefix_length + t],
                        spline_pred[:, t:self.prefix_length + t],
                        labels[:, t:self.prefix_length + t],
                        orient_h=None if t == 0 else self.orient_h,
                        quat_h=None if t == 0 else self.quat_h, return_prenorm=True)
                pos_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1],\
                    affs_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1], \
                    spline_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                    self.mocap.get_predicted_features(
                        pos_pred[:, :self.prefix_length + t],
                        pos_pred[:, self.prefix_length + t:self.prefix_length + t + 1, 0, [0, 2]],
                        z_rs_pred[:, self.prefix_length + t:self.prefix_length + t + 1, 0] + z_mean,
                        orient_pred[:, self.prefix_length + t:self.prefix_length + t + 1],
                        quat_pred[:, self.prefix_length + t:self.prefix_length + t + 1])
                if np.random.uniform(size=1)[0] > self.tf:
                    pos_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                        pos_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1].clone()
                    orient_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                        orient_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1].clone()
                    quat_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                        quat_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1].clone()
                    z_rs_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                        z_rs_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1].clone()
                    affs_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                        affs_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1].clone()
                    spline_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                        spline_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1].clone()

            prenorm_terms = torch.cat((orient_prenorm_terms, quat_prenorm_terms), dim=-1)
            prenorm_terms = prenorm_terms.view(prenorm_terms.shape[0], prenorm_terms.shape[1], -1, self.D)
            quat_norm_loss = self.args.quat_norm_reg * torch.mean((torch.sum(prenorm_terms ** 2, dim=-1) - 1) ** 2)

            quat_loss, quat_derv_loss = losses.quat_angle_loss(
                torch.cat((orient_pred_all[:, self.prefix_length - 1:],
                           quat_pred_all[:, self.prefix_length - 1:]), dim=-1),
                quat_all, self.V, self.D)
            quat_loss *= self.args.quat_reg

            z_rs_loss = self.z_rs_loss_func(z_rs_pred_all[:, self.prefix_length:],
                                            z_rs[:, self.prefix_length:])
            spline_loss = self.spline_loss_func(spline_pred_all[:, self.prefix_length:],
                                                spline[:, self.prefix_length:])
            fs_loss = losses.foot_speed_loss(pos_pred, pos)
            loss_total = quat_norm_loss + quat_loss + quat_derv_loss + z_rs_loss + spline_loss + fs_loss
            loss_total.backward()
            # nn.utils.clip_grad_norm_(self.model.parameters(), self.args.gradient_clip)
            self.optimizer.step()

            # animation_pred = {
            #     'joint_names': self.joint_names,
            #     'joint_offsets': torch.from_numpy(self.joint_offsets[1:]).
            #         float().unsqueeze(0).repeat(pos_pred_all.shape[0], 1, 1),
            #     'joint_parents': self.joint_parents,
            #     'positions': pos_pred_all,
            #     'rotations': torch.cat((orient_pred_all, quat_pred_all), dim=-1)
            # }
            # MocapDataset.save_as_bvh(animation_pred,
            #                          dataset_name=self.dataset,
            #                          subset_name='test')

            # Compute statistics
            batch_loss += loss_total.item()
            N += quat.shape[0]

            # statistics
            self.iter_info['loss'] = loss_total.data.item()
            self.iter_info['lr'] = '{:.6f}'.format(self.lr)
            self.iter_info['tf'] = '{:.6f}'.format(self.tf)
            self.show_iter_info()
            self.meta_info['iter'] += 1

        batch_loss = batch_loss / N
        self.epoch_info['mean_loss'] = batch_loss
        self.show_epoch_info()
        self.io.print_timer()
        self.adjust_lr()
        self.adjust_tf()
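The `np.random.uniform(size=1)[0] > self.tf` branch above implements scheduled sampling: with probability 1 - tf the model's own prediction is fed back as the next input, otherwise the ground-truth buffers are kept. A minimal sketch of that decision plus a typical decay for adjust_tf (the source's actual schedule is not shown):

import numpy as np

def use_model_feedback(tf):
    # True with probability (1 - tf): feed the model its own prediction.
    return np.random.uniform() > tf

def adjust_tf_sketch(tf, decay=0.995, tf_min=0.0):
    # Hypothetical schedule: shrink the teacher-forcing ratio each epoch so
    # the model is gradually exposed to its own predictions.
    return max(tf_min, tf * decay)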
Example #6
    def per_train(self):

        self.model.train()
        train_loader = self.data_loader['train']
        batch_loss = 0.
        N = 0.

        for pos, quat, orient, affs, spline, p_rs, labels in self.yield_batch(
                self.args.batch_size, train_loader):

            pos = torch.from_numpy(pos).cuda()
            quat = torch.from_numpy(quat).cuda()
            orient = torch.from_numpy(orient).cuda()
            affs = torch.from_numpy(affs).cuda()
            spline = torch.from_numpy(spline).cuda()
            p_rs = torch.from_numpy(p_rs).cuda()
            labels = torch.from_numpy(labels).cuda()

            pos_pred = pos.clone()
            quat_pred = quat.clone()
            p_rs_pred = p_rs.clone()
            affs_pred = affs.clone()
            pos_pred_all = pos.clone()
            quat_pred_all = quat.clone()
            p_rs_pred_all = p_rs.clone()
            affs_pred_all = affs.clone()
            prenorm_terms = torch.zeros_like(quat_pred)

            # forward
            self.optimizer.zero_grad()
            for t in range(self.target_length):
                quat_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1], \
                    p_rs_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1], \
                    self.quat_h, prenorm_terms[:, self.prefix_length + t: self.prefix_length + t + 1] = \
                    self.model(
                        quat_pred[:, t:self.prefix_length + t],
                        p_rs_pred[:, t:self.prefix_length + t],
                        affs_pred[:, t:self.prefix_length + t],
                        spline[:, t:self.prefix_length + t],
                        orient[:, t:self.prefix_length + t],
                        labels,
                        quat_h=None if t == 0 else self.quat_h, return_prenorm=True)
                pos_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1],\
                    affs_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                    self.mocap.get_predicted_features(
                        pos_pred[:, self.prefix_length + t:self.prefix_length + t + 1, 0],
                        quat_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1],
                        orient[:, self.prefix_length + t:self.prefix_length + t + 1])
                if np.random.uniform(size=1)[0] > self.tf:
                    pos_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                        pos_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1]
                    quat_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                        quat_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1]
                    p_rs_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                        p_rs_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1]
                    affs_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                        affs_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1]

            prenorm_terms = prenorm_terms.view(prenorm_terms.shape[0],
                                               prenorm_terms.shape[1], -1,
                                               self.D)
            quat_norm_loss = self.args.quat_norm_reg * torch.mean(
                (torch.sum(prenorm_terms**2, dim=-1) - 1)**2)

            quat_loss, quat_derv_loss = losses.quat_angle_loss(
                quat_pred_all[:, self.prefix_length - 1:],
                quat[:, self.prefix_length - 1:], self.V, self.D)
            quat_loss *= self.args.quat_reg

            p_rs_loss = self.p_rs_loss_func(
                p_rs_pred_all[:, self.prefix_length:],
                p_rs[:, self.prefix_length:])
            affs_loss = self.affs_loss_func(
                affs_pred_all[:, self.prefix_length:],
                affs[:, self.prefix_length:])
            # recons_loss = self.args.recons_reg *\
            #               (pos_pred_all[:, self.prefix_length:] - pos_pred_all[:, self.prefix_length:, 0:1] -
            #                 pos[:, self.prefix_length:] + pos[:, self.prefix_length:, 0:1]).norm()

            loss_total = quat_norm_loss + quat_loss + quat_derv_loss + p_rs_loss + affs_loss  # + recons_loss
            loss_total.backward()
            # nn.utils.clip_grad_norm_(self.model.parameters(), self.args.gradient_clip)
            self.optimizer.step()

            # Compute statistics
            batch_loss += loss_total.item()
            N += quat.shape[0]

            # statistics
            self.iter_info['loss'] = loss_total.data.item()
            self.iter_info['lr'] = '{:.6f}'.format(self.lr)
            self.iter_info['tf'] = '{:.6f}'.format(self.tf)
            self.show_iter_info()
            self.meta_info['iter'] += 1

        batch_loss = batch_loss / N
        self.epoch_info['mean_loss'] = batch_loss
        self.show_epoch_info()
        self.io.print_timer()
        self.adjust_lr()
        self.adjust_tf()
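MocapDataset.forward_kinematics is used throughout but not shown. Standard forward kinematics rotates each joint's rest offset by the accumulated parent rotation and adds the parent's world position; a compact sketch under those assumptions (w-first quaternion layout and topologically ordered joints assumed, the source's conventions may differ):

import torch

def qmul(a, b):
    # Hamilton product of quaternions (..., 4), w-first layout.
    aw, ax, ay, az = a.unbind(-1)
    bw, bx, by, bz = b.unbind(-1)
    return torch.stack((
        aw * bw - ax * bx - ay * by - az * bz,
        aw * bx + ax * bw + ay * bz - az * by,
        aw * by - ax * bz + ay * bw + az * bx,
        aw * bz + ax * by - ay * bx + az * bw), dim=-1)

def quat_rotate(q, v):
    # Rotate vectors v (..., 3) by unit quaternions q (..., 4).
    qw, qv = q[..., :1], q[..., 1:]
    t = 2. * torch.cross(qv, v, dim=-1)
    return v + qw * t + torch.cross(qv, t, dim=-1)

def forward_kinematics_sketch(quats, root_pos, parents, offsets):
    # quats: (B, T, J, 4) local joint rotations; parents[j] is the parent
    # index with parents[0] == -1; offsets: (B, J, 3) rest-pose bone offsets.
    B, T, J, _ = quats.shape
    world_q = [quats[:, :, 0]]                   # root world rotation
    world_p = [root_pos]                         # (B, T, 3) root position
    for j in range(1, J):
        p = parents[j]
        off = offsets[:, j].view(B, 1, 3).expand(B, T, 3)
        world_p.append(world_p[p] + quat_rotate(world_q[p], off))
        world_q.append(qmul(world_q[p], quats[:, :, j]))
    return torch.stack(world_p, dim=2)           # (B, T, J, 3)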