Ejemplo n.º 1
0
def get_joints_from_mocap_data(data, apply_transformations=True):
    """Extract per-frame joint positions and, optionally, accumulate root motion.

    ``data`` is squeezed and its last axis is moved to the front, so every row
    of the working array is one frame.

    Args:
        data: Array-like mocap features whose last axis (after
            ``np.squeeze``) indexes frames. It is never modified.
        apply_transformations: If True, the feature width must be 73 or 66;
            the root columns ``root_x``/``root_z`` (per-frame x/z displacement)
            and ``root_r`` (per-frame angle about the y axis) are accumulated
            into a root transform that is applied to every frame's joints.
            If False, every column is treated as a joint coordinate.

    Returns:
        joints: ``(num_frames, num_joints, 3)`` array (always a copy).
        rotations: ``(num_frames, 4)`` accumulated root quaternions; the
            identity ``(1, 0, 0, 0)`` for every frame when
            ``apply_transformations`` is False.
        translations: ``(num_frames, 3)`` accumulated root translations; all
            zeros when ``apply_transformations`` is False.

    Raises:
        AssertionError: if ``apply_transformations`` is True and the feature
            width is neither 73 nor 66.
    """
    data = np.moveaxis(np.squeeze(data), -1, 0)
    # Copy the joint columns: slicing gives views into the caller's array and
    # the transform below writes into `joints` in place.
    if not apply_transformations:
        joints = np.copy(data)
    elif data.shape[-1] == 73:
        joints = np.copy(data[:, 3:-7])
        root_x, root_z, root_r = data[:, -7], data[:, -6], data[:, -5]
    elif data.shape[-1] == 66:
        joints = np.copy(data[:, :-3])
        root_x, root_z, root_r = data[:, -3], data[:, -2], data[:, -1]
    else:
        raise AssertionError('Input data format not understood')
    joints = joints.reshape((len(joints), -1, 3))
    num_frames = len(joints)

    # Start every frame at the identity quaternion (w first) and zero
    # translation; frames after the first are overwritten when the root motion
    # is accumulated below.
    rotations = np.zeros((num_frames, 4))
    rotations[:, 0] = 1.
    translations = np.zeros((num_frames, 3))

    if apply_transformations:
        y_axis = np.array([0, 1, 0])
        for i in range(num_frames):
            joints[i, :, :] = Quaternions(rotations[i, :]) * joints[i]
            # Only the horizontal (x, z) components of the root translation apply.
            joints[i, :, 0] += translations[i, 0]
            joints[i, :, 2] += translations[i, 2]
            if i + 1 < num_frames:
                rotations[i + 1, :] = (Quaternions.from_angle_axis(-root_r[i], y_axis) *
                                       Quaternions(rotations[i, :])).qs
                translations[i + 1, :] = translations[i, :] + \
                    Quaternions(rotations[i + 1, :]) * np.array([root_x[i], 0, root_z[i]])
    return joints, rotations, translations
    def get_positions_and_transformations(self, raw_data, mirrored=False):
        """Recover world-space joint positions and per-frame transforms from mocap features.

        ``raw_data`` is squeezed and its first and last axes are swapped so rows
        are frames. Feature widths of 73 and 66 are supported; the root columns
        ``root_x``/``root_z`` (per-frame x/z displacement) and ``root_r``
        (per-frame angle about the y axis) are accumulated into a root
        transform that carries local positions into world space.

        Args:
            raw_data: Mocap features whose last axis (after ``np.squeeze``)
                indexes frames. It is not modified.
            mirrored: If True, swap ``self.joints_left`` with
                ``self.joints_right`` and negate the x and z coordinates.

        Returns:
            Tuple ``(positions_local, positions_world, trajectory, orientations,
            rotations, rotations_euler, translations, cum_rotations,
            cum_rotations_euler, cum_translations, offsets)``; with ``F`` frames
            and ``J`` joints:

            - positions_local, positions_world: ``(F, J, 3)``
            - trajectory: ``(F, 3)`` world position of joint 0
            - orientations: ``(F,)``, equal to ``-root_r``
            - rotations: ``(F, J - 1, 4)`` quaternions taking
              ``self.joint_offsets[1:]`` onto the world-space limb vectors
            - rotations_euler: ``(F, J - 1, 3)`` ('yzx' order)
            - translations: ``(F, J, 3)``; only joint 0 is filled
            - cum_rotations: ``(F, 4)``; cum_rotations_euler: ``(F, 3)``
            - cum_translations: ``(F, 3)``
            - offsets: list of ``F`` vectors, ``[0, 0, 1]`` rotated by the
              cumulative rotation

        Raises:
            AssertionError: if the feature width is neither 73 nor 66.
        """
        data = np.swapaxes(np.squeeze(raw_data), -1, 0)
        if data.shape[-1] == 73:
            positions, root_x, root_z, root_r = data[:, 3:-7], data[:, -7], data[:, -6], data[:, -5]
        elif data.shape[-1] == 66:
            positions, root_x, root_z, root_r = data[:, :-3], data[:, -3], data[:, -2], data[:, -1]
        else:
            raise AssertionError('Input data format not understood')
        num_frames = len(positions)
        # Copy: when raw_data is an ndarray the slices above are views into it,
        # and the in-place mirroring below would otherwise modify the caller's data.
        positions_local = positions.reshape((num_frames, -1, 3)).copy()
        if mirrored:
            positions_local[:, self.joints_left], positions_local[:, self.joints_right] = \
                positions_local[:, self.joints_right], positions_local[:, self.joints_left]
            positions_local[:, :, [0, 2]] = -positions_local[:, :, [0, 2]]
        positions_world = np.zeros_like(positions_local)
        num_joints = positions_world.shape[1]

        trajectory = np.empty((num_frames, 3))
        orientations = np.empty(num_frames)
        rotations = np.zeros((num_frames, num_joints - 1, 4))
        cum_rotations = np.zeros((num_frames, 4))
        rotations_euler = np.zeros((num_frames, num_joints - 1, 3))
        cum_rotations_euler = np.zeros((num_frames, 3))
        translations = np.zeros((num_frames, num_joints, 3))
        cum_translations = np.zeros((num_frames, 3))
        offsets = []

        y_axis = np.array([0, 1, 0])
        forward = np.array([0, 0, 1])
        for t in range(num_frames):
            # Accumulated root transform of the previous frame (identity at t == 0).
            prev_rotation = Quaternions(cum_rotations[t - 1]) if t > 0 else Quaternions.id(1)
            prev_translation = cum_translations[t - 1] if t > 0 else np.zeros(3)

            positions_world[t, :, :] = prev_rotation * positions_local[t]
            # Only the horizontal (x, z) components of the root translation apply.
            positions_world[t, :, 0] += prev_translation[0]
            positions_world[t, :, 2] += prev_translation[2]
            trajectory[t] = positions_world[t, 0]

            limbs = positions_world[t, 1:] - positions_world[t, self.joint_parents[1:]]
            rotations[t] = Quaternions.between(self.joint_offsets[1:], limbs).qs
            rotations_euler[t] = Quaternions(rotations[t]).euler('yzx')

            orientations[t] = -root_r[t]
            cum_rotations[t] = (Quaternions.from_angle_axis(orientations[t], y_axis) *
                                prev_rotation).qs
            cum_rotations_euler[t] = Quaternions(cum_rotations[t]).euler('yzx')
            offsets.append(Quaternions(cum_rotations[t]) * forward)
            translations[t, 0] = Quaternions(cum_rotations[t]) * np.array([root_x[t], 0, root_z[t]])
            cum_translations[t] = prev_translation + translations[t, 0]
        return positions_local, positions_world, trajectory, orientations, rotations, rotations_euler, \
            translations, cum_rotations, cum_rotations_euler, cum_translations, offsets
Ejemplo n.º 3
0
def get_del_pos_and_orientation(pos1, pos2, dim):
    """Split the change from ``pos1`` to ``pos2`` into a rotation and a residual offset.

    The rotation (axis, angle) comes from ``get_del_orientation``. ``pos1`` is
    rotated by it, and the remaining difference to ``pos2`` is the positional
    residual.

    Args:
        pos1: Starting points.
        pos2: Target points.
        dim: Number of coordinates per point, used to reshape the residual.

    Returns:
        ``(quaternions, residual)``: both flattened to 1-D; ``quaternions`` from
        ``Quaternions.from_angle_axis(theta, axis).qs`` and ``residual`` is
        ``pos2 - rotated(pos1)``.
    """
    rot_axis, rot_angle = get_del_orientation(pos1, pos2, dim)
    quaternions = Quaternions.from_angle_axis(rot_angle, rot_axis).qs

    # NOTE(review): an all-zero axis row divides by zero here and yields NaN;
    # confirm whether get_del_orientation can return one.
    axis_length = np.sum(rot_axis ** 2, axis=-1) ** 0.5
    unit_axis = axis_length[:, None]
    unit_axis = rot_axis / unit_axis

    rotated_start = get_rotated_points(unit_axis, rot_angle, pos1)
    residual = np.reshape(pos2 - rotated_start, (-1, dim))
    return quaternions.flatten(), residual.flatten()
Ejemplo n.º 4
0
def load_ewalk_data(_path, coords, joints, upsample=1):
    """Load the EWalk gait features and labels and split them into train/test sets.

    Reads ``features_ewalk.h5`` and ``labels_ewalk.h5`` from ``_path``; the
    i-th feature group is paired with the i-th label key (in key iteration
    order). Every sequence is resampled to ``upsample`` times its length,
    zero-padded to the longest sequence, and each joint's coordinates are
    remapped (see inline comment). Features are then produced by
    ``common.get_ewalk_differentials_with_padding``.

    Args:
        _path: Directory containing the two HDF5 files.
        coords: Coordinates per joint. The remap below indexes three columns
            per joint, so 3 is presumably expected — TODO confirm.
        joints: Number of joints per frame.
        upsample: Integer factor applied to every sequence length.

    Returns:
        The result of ``train_test_split`` over ``(poses, differentials,
        affective_features, num_frames, labels)`` with ``test_size=0.1``.
    """
    file_feature = os.path.join(_path, 'features' + '_ewalk.h5')
    file_label = os.path.join(_path, 'labels' + '_ewalk.h5')
    # Context managers guarantee both HDF5 handles are closed.
    with h5py.File(file_feature, 'r') as ff, h5py.File(file_label, 'r') as fl:
        # Materialise the key lists once instead of once per sample.
        feature_keys = list(ff.keys())
        label_keys = list(fl.keys())
        num_samples = len(feature_keys)
        data_list = [list(ff[key]) for key in feature_keys]  # Get the data
        labels = np.empty(num_samples)
        for si in range(num_samples):
            labels[si] = fl[label_keys[si]][()]

    time_steps = max((len(seq) for seq in data_list), default=0)
    data = np.zeros((num_samples, time_steps * upsample, joints * coords))
    num_frames = np.empty(num_samples)
    # Loop-invariant: rotation by pi/2 about the x axis.
    rotation = Quaternions.from_angle_axis(np.pi / 2., np.array([1, 0, 0]))
    for si, seq in enumerate(data_list):
        seq_arr = np.array(seq)
        tsteps_curr = len(seq) * upsample
        for lidx in range(seq_arr.shape[1]):
            data[si, :tsteps_curr, lidx] = signal.resample(seq_arr[:, lidx], tsteps_curr)
            # Once a joint's last column is resampled, remap that joint:
            # (x, y, z) -> (x, z, -y), then rotate by pi/2 about x. This runs
            # for every joint, including the last one.
            # NOTE(review): under the usual right-handed quaternion rotation the
            # two steps cancel out (up to rounding) — confirm the remap is needed.
            if (lidx + 1) % coords == 0:
                end = lidx + 1
                temp = np.copy(data[si, :tsteps_curr, end - 1])
                data[si, :tsteps_curr, end - 1] = -data[si, :tsteps_curr, end - 2]
                data[si, :tsteps_curr, end - 2] = temp
                for t in range(tsteps_curr):
                    data[si, t, end - 3:end] = rotation * data[si, t, end - 3:end]
        num_frames[si] = tsteps_curr
    poses, differentials, affective_features = common.get_ewalk_differentials_with_padding(
        data, num_frames, coords)
    return train_test_split(poses,
                            differentials,
                            affective_features,
                            num_frames,
                            labels,
                            test_size=0.1)
Ejemplo n.º 5
0
    def get_positions_and_transformations(self, raw_data, mirrored=False, save_path=None):
        """Recover world-space joint positions and per-frame transforms from mocap features.

        ``raw_data`` is squeezed and its first and last axes are swapped so rows
        are frames. Feature widths of 73 and 66 are supported; the root columns
        ``root_x``/``root_z`` (per-frame x/z displacement) and ``root_r``
        (per-frame angle about the y axis) are accumulated into a root
        transform that carries local positions into world space.

        Args:
            raw_data: Mocap features whose last axis (after ``np.squeeze``)
                indexes frames. It is not modified.
            mirrored: If True, swap ``self.joints_left`` with
                ``self.joints_right`` and negate the x and z coordinates.
            save_path: Optional directory. When given, the recovered root
                trajectory, orientations and rotations are also exported via
                ``self.save_as_bvh(..., save_path)``. The export used to run
                unconditionally with keyword arguments ``save_as_bvh`` does not
                accept, which made every call raise ``TypeError``.

        Returns:
            Tuple ``(positions_local, positions_world, trajectory, orientations,
            rotations, rotations_euler, translations, cum_rotations,
            cum_rotations_euler, cum_translations, offsets)``; with ``F`` frames
            and ``J`` joints:

            - positions_local, positions_world: ``(F, J, 3)``
            - trajectory: ``(F, 3)`` world position of joint 0
            - orientations: ``(F,)``, equal to ``-root_r``
            - rotations: ``(F, J - 1, 4)`` quaternions taking
              ``self.joint_offsets[1:]`` onto the world-space limb vectors
            - rotations_euler: ``(F, J - 1, 3)`` ('yzx' order)
            - translations: ``(F, J, 3)``; only joint 0 is filled
            - cum_rotations: ``(F, 4)``; cum_rotations_euler: ``(F, 3)``
            - cum_translations: ``(F, 3)``
            - offsets: list of ``F`` vectors, ``[0, 0, 1]`` rotated by the
              cumulative rotation

        Raises:
            AssertionError: if the feature width is neither 73 nor 66.
        """
        data = np.swapaxes(np.squeeze(raw_data), -1, 0)
        if data.shape[-1] == 73:
            positions, root_x, root_z, root_r = data[:, 3:-7], data[:, -7], data[:, -6], data[:, -5]
        elif data.shape[-1] == 66:
            positions, root_x, root_z, root_r = data[:, :-3], data[:, -3], data[:, -2], data[:, -1]
        else:
            raise AssertionError('Input data format not understood')
        num_frames = len(positions)
        # Copy: when raw_data is an ndarray the slices above are views into it,
        # and the in-place mirroring below would otherwise modify the caller's data.
        positions_local = positions.reshape((num_frames, -1, 3)).copy()
        if mirrored:
            positions_local[:, self.joints_left], positions_local[:, self.joints_right] = \
                positions_local[:, self.joints_right], positions_local[:, self.joints_left]
            positions_local[:, :, [0, 2]] = -positions_local[:, :, [0, 2]]
        positions_world = np.zeros_like(positions_local)
        num_joints = positions_world.shape[1]

        trajectory = np.empty((num_frames, 3))
        orientations = np.empty(num_frames)
        rotations = np.zeros((num_frames, num_joints - 1, 4))
        cum_rotations = np.zeros((num_frames, 4))
        rotations_euler = np.zeros((num_frames, num_joints - 1, 3))
        cum_rotations_euler = np.zeros((num_frames, 3))
        translations = np.zeros((num_frames, num_joints, 3))
        cum_translations = np.zeros((num_frames, 3))
        offsets = []

        y_axis = np.array([0, 1, 0])
        forward = np.array([0, 0, 1])
        for t in range(num_frames):
            # Accumulated root transform of the previous frame (identity at t == 0).
            prev_rotation = Quaternions(cum_rotations[t - 1]) if t > 0 else Quaternions.id(1)
            prev_translation = cum_translations[t - 1] if t > 0 else np.zeros(3)

            positions_world[t, :, :] = prev_rotation * positions_local[t]
            # Only the horizontal (x, z) components of the root translation apply.
            positions_world[t, :, 0] += prev_translation[0]
            positions_world[t, :, 2] += prev_translation[2]
            trajectory[t] = positions_world[t, 0]

            limbs = positions_world[t, 1:] - positions_world[t, self.joint_parents[1:]]
            rotations[t] = Quaternions.between(self.joint_offsets[1:], limbs).qs
            rotations_euler[t] = Quaternions(rotations[t]).euler('yzx')

            orientations[t] = -root_r[t]
            cum_rotations[t] = (Quaternions.from_angle_axis(orientations[t], y_axis) *
                                prev_rotation).qs
            cum_rotations_euler[t] = Quaternions(cum_rotations[t]).euler('yzx')
            offsets.append(Quaternions(cum_rotations[t]) * forward)
            translations[t, 0] = Quaternions(cum_rotations[t]) * np.array([root_x[t], 0, root_z[t]])
            cum_translations[t] = prev_translation + translations[t, 0]

        if save_path is not None:
            self.save_as_bvh(np.expand_dims(positions_world[:, 0], 0),
                             np.expand_dims(orientations.reshape(-1, 1), 0),
                             np.expand_dims(rotations, 0),
                             save_path)
        return positions_local, positions_world, trajectory, orientations, rotations, rotations_euler, \
            translations, cum_rotations, cum_rotations_euler, cum_translations, offsets
    def save_as_bvh(self, trajectory, orientations, quaternions, save_path,
                    save_file_names=None, frame_time=0.032):
        """Write every sample of a batch to its own BVH file.

        Args:
            trajectory: Indexed per sample; ``trajectory[s]`` is written into the
                root (joint 0) position of every frame of sample ``s``.
            orientations: Angles that are negated and passed to
                ``Quaternions.from_angle_axis`` about the y axis; the resulting
                quaternions are prepended to ``quaternions`` along axis -2 and
                become source joint 0.
            quaternions: Per-joint quaternions, indexed as
                ``quaternions[sample, frame, joint]`` after the concatenation.
            save_path: Directory the files are written into.
            save_file_names: Optional per-sample file names, used verbatim.
                When None, files are named ``<s zero-padded to 6>.bvh``.
            frame_time: Forwarded to ``BVH.save``.
        """
        y_rotations = Quaternions.from_angle_axis(-orientations, np.array([0., 1., 0.])).qs
        quaternions = np.concatenate((y_rotations, quaternions), axis=-2)
        num_joints = len(self.joint_parents_all)

        # (BVH joint, source parent joint, source child joint). Each listed BVH
        # joint receives parent * child, where the parent quaternion goes through
        # unary minus (Quaternions.__neg__) unless it is source joint 0. BVH
        # joints absent from this table keep the identity quaternion.
        joint_map = (
            (1, 0, 1), (2, 1, 2), (3, 2, 3), (4, 3, 4),
            (6, 0, 5), (7, 5, 6), (8, 6, 7), (9, 7, 8),
            (11, 0, 9), (12, 9, 10), (13, 10, 11),
            (15, 11, 12),
            (17, 11, 13), (18, 13, 14), (19, 14, 15), (20, 15, 16),
            (22, 11, 17), (23, 17, 18), (24, 18, 19), (25, 19, 20),
        )

        for s, sample_quats in enumerate(quaternions):
            num_frames = sample_quats.shape[0]
            positions = np.tile(self.joint_offsets_all, (num_frames, 1, 1))
            positions[:, 0] = trajectory[s]
            orients = Quaternions.id(num_joints)

            if save_file_names is None:
                file_name = str(s).zfill(6) + '.bvh'
            else:
                file_name = save_file_names[s]
            save_file_name = os.path.join(save_path, file_name)

            local_quats = np.zeros((num_frames, num_joints, quaternions.shape[-1]))
            local_quats[..., 0] = 1.
            for dst, parent, child in joint_map:
                parent_quat = Quaternions(sample_quats[:, parent])
                if parent != 0:
                    parent_quat = -parent_quat
                local_quats[:, dst] = parent_quat * Quaternions(sample_quats[:, child])

            BVH.save(save_file_name,
                     Animation(Quaternions(local_quats), positions, orients,
                               self.joint_offsets_all, self.joint_parents_all),
                     names=self.joint_names_all, frame_time=frame_time)
Ejemplo n.º 7
0
    def generate_motion(self, load_saved_model=True, samples_to_generate=1534, max_steps=300, randomized=True):

        if load_saved_model:
            self.load_best_model()
        self.model.eval()
        test_loader = self.data_loader['test']

        pos, quat, orient, z_mean, z_dev, \
        root_speed, affs, spline, labels = self.return_batch([samples_to_generate], test_loader, randomized=randomized)
        pos = torch.from_numpy(pos).cuda()
        traj = pos[:, :, 0, [0, 2]].clone()
        orient = torch.from_numpy(orient).cuda()
        quat = torch.from_numpy(quat).cuda()
        z_mean = torch.from_numpy(z_mean).cuda()
        z_dev = torch.from_numpy(z_dev).cuda()
        root_speed = torch.from_numpy(root_speed).cuda()
        affs = torch.from_numpy(affs).cuda()
        spline = torch.from_numpy(spline).cuda()
        z_rs = torch.cat((z_dev, root_speed), dim=-1)
        quat_all = torch.cat((orient[:, self.prefix_length - 1:], quat[:, self.prefix_length - 1:]), dim=-1)
        labels = np.tile(labels, (1, max_steps + self.prefix_length, 1))

        # Begin for transition
        # traj[:, self.prefix_length - 2] = torch.tensor([-0.208, 4.8]).cuda().float()
        # traj[:, self.prefix_length - 1] = torch.tensor([-0.204, 5.1]).cuda().float()
        # final_emo_idx = int(max_steps/2)
        # labels[:, final_emo_idx:] = np.array([1., 0., 0., 0.])
        # labels[:, :final_emo_idx + 1] = np.linspace(labels[:, 0], labels[:, final_emo_idx],
        #                                             num=final_emo_idx + 1, axis=1)
        # End for transition
        labels = torch.from_numpy(labels).cuda()

        # traj_np = traj_markers.detach().cpu().numpy()
        # import matplotlib.pyplot as plt
        # plt.plot(traj_np[6, :, 0], traj_np[6, :, 1])
        # plt.show()

        happy_idx = [25, 295, 390, 667, 1196]
        sad_idx = [169, 184, 258, 948, 974]
        angry_idx = [89, 93, 96, 112, 289, 290, 978]
        neutral_idx = [72, 106, 143, 237, 532, 747, 1177]
        sample_idx = np.squeeze(np.concatenate((happy_idx, sad_idx, angry_idx, neutral_idx)))

        ## CHANGE HERE
        # scene_corners = torch.tensor([[149.862, 50.833],
        #                               [149.862, 36.81],
        #                               [161.599, 36.81],
        #                               [161.599, 50.833]]).cuda().float()
        # character_heights = torch.tensor([0.95, 0.88, 0.86, 0.90, 0.95, 0.82]).cuda().float()
        # num_characters_per_side = torch.tensor([2, 3, 2, 3]).cuda().int()
        # traj_markers, traj_offsets, character_scale =\
        #     generate_trajectories(scene_corners, z_mean, character_heights,
        #                           num_characters_per_side, traj[:, :self.prefix_length])
        # num_characters_per_side = torch.tensor([4, 0, 0, 0]).cuda().int()
        # traj_markers, traj_offsets, character_scale =\
        #     generate_simple_trajectories(scene_corners, z_mean[:4], z_mean[:4],
        #                                  num_characters_per_side, traj[sample_idx, :self.prefix_length])
        # traj_markers, traj_offsets, character_scale =\
        #     generate_rvo_trajectories(scene_corners, z_mean[:4], z_mean[:4],
        #                               num_characters_per_side, traj[sample_idx, :self.prefix_length])

        # traj[sample_idx, :self.prefix_length] += traj_offsets
        # pos_sampled = pos[sample_idx].clone()
        # pos_sampled[:, :self.prefix_length, :, [0, 2]] += traj_offsets.unsqueeze(2).repeat(1, 1, self.V, 1)
        # pos[sample_idx] = pos_sampled
        # traj_markers = self.generate_linear_trajectory(traj)

        pos_pred = self.copy_prefix(pos)
        traj_pred = self.copy_prefix(traj)
        orient_pred = self.copy_prefix(orient)
        quat_pred = self.copy_prefix(quat)
        z_rs_pred = self.copy_prefix(z_rs)
        affs_pred = self.copy_prefix(affs)
        spline_pred = self.copy_prefix(spline)
        labels_pred = self.copy_prefix(labels, prefix_length=max_steps + self.prefix_length)

        # forward
        elapsed_time = np.zeros(len(sample_idx))
        for counter, s in enumerate(sample_idx):  # range(samples_to_generate):
            start_time = time.time()
            orient_h_copy = self.orient_h.clone()
            quat_h_copy = self.quat_h.clone()
            ## CHANGE HERE
            num_markers = max_steps + self.prefix_length + 1
            # num_markers = traj_markers[s].shape[0]
            marker_idx = 0
            t = -1
            with torch.no_grad():
                while marker_idx < num_markers:
                    t += 1
                    if t > max_steps:
                        print('Sample: {}. Did not reach end in {} steps.'.format(s, max_steps), end='')
                        break
                    pos_pred[s] = torch.cat((pos_pred[s], torch.zeros_like(pos_pred[s][:, -1:])), dim=1)
                    traj_pred[s] = torch.cat((traj_pred[s], torch.zeros_like(traj_pred[s][:, -1:])), dim=1)
                    orient_pred[s] = torch.cat((orient_pred[s], torch.zeros_like(orient_pred[s][:, -1:])), dim=1)
                    quat_pred[s] = torch.cat((quat_pred[s], torch.zeros_like(quat_pred[s][:, -1:])), dim=1)
                    z_rs_pred[s] = torch.cat((z_rs_pred[s], torch.zeros_like(z_rs_pred[s][:, -1:])), dim=1)
                    affs_pred[s] = torch.cat((affs_pred[s], torch.zeros_like(affs_pred[s][:, -1:])), dim=1)
                    spline_pred[s] = torch.cat((spline_pred[s], torch.zeros_like(spline_pred[s][:, -1:])), dim=1)

                    orient_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1], \
                    quat_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1], \
                    z_rs_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1], \
                    orient_h_copy, quat_h_copy = \
                        self.model(
                            orient_pred[s][:, t:self.prefix_length + t],
                            quat_pred[s][:, t:self.prefix_length + t],
                            z_rs_pred[s][:, t:self.prefix_length + t],
                            affs_pred[s][:, t:self.prefix_length + t],
                            spline_pred[s][:, t:self.prefix_length + t],
                            labels_pred[s][:, t:self.prefix_length + t],
                            orient_h=None if t == 0 else orient_h_copy,
                            quat_h=None if t == 0 else quat_h_copy, return_prenorm=False)

                    traj_curr = traj_pred[s][:, self.prefix_length + t - 1:self.prefix_length + t].clone()
                    # root_speed = torch.norm(
                    #     pos_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1, 0] - \
                    #     pos_pred[s][:, self.prefix_length + t - 1:self.prefix_length + t, 0], dim=-1)

                    ## CHANGE HERE
                    # traj_next = \
                    #     self.compute_next_traj_point(
                    #         traj_curr,
                    #         traj_markers[s, marker_idx],
                    #         o_z_rs_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1, 2])
                    try:
                        traj_next = traj[s, self.prefix_length + t]
                    except IndexError:
                        traj_next = \
                            self.compute_next_traj_point_sans_markers(
                                pos_pred[s][:, self.prefix_length + t - 1:self.prefix_length + t],
                                quat_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1],
                                z_rs_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1, 0],
                                z_rs_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1, 1])

                    pos_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1], \
                    affs_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1], \
                    spline_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1] = \
                        self.mocap.get_predicted_features(
                            pos_pred[s][:, :self.prefix_length + t],
                            traj_next,
                            z_rs_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1, 0] + z_mean[s:s + 1],
                            orient_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1],
                            quat_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1])

                    # min_speed_pred = torch.min(torch.cat((lf_speed_pred.unsqueeze(-1),
                    #                                        rf_speed_pred.unsqueeze(-1)), dim=-1), dim=-1)[0]
                    # if min_speed_pred - diff_speeds_mean[s] - diff_speeds_std[s] < 0.:
                    #     root_speed_pred = 0.
                    # else:
                    #     root_speed_pred = o_z_rs_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1, 2]
                    #

                    ## CHANGE HERE
                    # traj_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1] = \
                    #     self.compute_next_traj_point(
                    #         traj_curr,
                    #         traj_markers[s, marker_idx],
                    #         root_speed_pred)
                    # if torch.norm(traj_next - traj_curr, dim=-1).squeeze() >= \
                    #         torch.norm(traj_markers[s, marker_idx] - traj_curr, dim=-1).squeeze():
                    #     marker_idx += 1
                    traj_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1] = traj_next
                    marker_idx += 1
                    pos_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1], \
                    affs_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1], \
                    spline_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1] = \
                        self.mocap.get_predicted_features(
                            pos_pred[s][:, :self.prefix_length + t],
                            pos_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1, 0, [0, 2]],
                            z_rs_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1, 0] + z_mean[s:s + 1],
                            orient_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1],
                            quat_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1])
                    print('Sample: {}. Steps: {}'.format(s, t), end='\r')
            print()

            # shift = torch.zeros((1, scene_corners.shape[1] + 1)).cuda().float()
            # shift[..., [0, 2]] = scene_corners[0]
            # pos_pred[s] = (pos_pred[s] - shift) / character_scale + shift
            # pos_pred_np = pos_pred[s].contiguous().view(pos_pred[s].shape[0],
            #                                             pos_pred[s].shape[1], -1).permute(0, 2, 1).\
            #     detach().cpu().numpy()
            # display_animations(pos_pred_np, self.V, self.C, self.mocap.joint_parents, save=True,
            #                    dataset_name=self.dataset, subset_name='epoch_' + str(self.best_loss_epoch),
            #                    save_file_names=[str(s).zfill(6)],
            #                    overwrite=True)

            # plt.cla()
            # fig, (ax1, ax2) = plt.subplots(2, 1)
            # ax1.plot(root_speeds[s])
            # ax1.plot(lf_speeds[s])
            # ax1.plot(rf_speeds[s])
            # ax1.plot(min_speeds[s] - root_speeds[s])
            # ax1.legend(['root', 'left', 'right', 'diff'])
            # ax2.plot(root_speeds_pred)
            # ax2.plot(lf_speeds_pred)
            # ax2.plot(rf_speeds_pred)
            # ax2.plot(min_speeds_pred - root_speeds_pred)
            # ax2.legend(['root', 'left', 'right', 'diff'])
            # plt.show()

            head_tilt = np.tile(np.array([0., 0., 0.]), (1, quat_pred[s].shape[1], 1))
            l_shoulder_slouch = np.tile(np.array([0., 0., 0.]), (1, quat_pred[s].shape[1], 1))
            r_shoulder_slouch = np.tile(np.array([0., 0., 0.]), (1, quat_pred[s].shape[1], 1))
            head_tilt = Quaternions.from_euler(head_tilt, order='xyz').qs
            l_shoulder_slouch = Quaternions.from_euler(l_shoulder_slouch, order='xyz').qs
            r_shoulder_slouch = Quaternions.from_euler(r_shoulder_slouch, order='xyz').qs

            # Begin for aligning facing direction to trajectory
            axis_diff, angle_diff = self.get_diff_from_traj(pos_pred, traj_pred, s)
            angle_thres = 0.3
            # angle_thres = torch.max(angle_diff[:, 1:self.prefix_length])
            angle_diff[angle_diff <= angle_thres] = 0.
            angle_diff[:, self.prefix_length] = 0.
            # End for aligning facing direction to trajectory
            # pos_copy, quat_copy = self.rotate_gaits(pos_pred, quat_pred, quat_diff,
            #                                         head_tilt, l_shoulder_slouch, r_shoulder_slouch, s)
            # pos_pred[s] = pos_copy.clone()
            # angle_diff_intermediate = self.get_diff_from_traj(pos_pred, traj_pred, s)
            # if torch.max(angle_diff_intermediate[:, self.prefix_length:]) > np.pi / 2.:
            #     quat_diff = Quaternions.from_angle_axis(-angle_diff.cpu().numpy(), np.array([0, 1, 0])).qs
            #     pos_copy, quat_copy = self.rotate_gaits(pos_pred, quat_pred, quat_diff,
            #                                         head_tilt, l_shoulder_slouch, r_shoulder_slouch, s)
            # pos_pred[s] = pos_copy.clone()
            # axis_diff = torch.zeros_like(axis_diff)
            # axis_diff[..., 1] = 1.
            # angle_diff = torch.zeros_like(angle_diff)
            quat_diff = torch.from_numpy(Quaternions.from_angle_axis(
                angle_diff.cpu().numpy(), axis_diff.cpu().numpy()).qs).cuda().float()
            orient_pred[s], quat_pred[s] = self.rotate_gaits(orient_pred[s], quat_pred[s],
                                                             quat_diff, head_tilt,
                                                             l_shoulder_slouch, r_shoulder_slouch)

            if labels_pred[s][:, 0, 0] > 0.5:
                label_dir = 'happy'
            elif labels_pred[s][:, 0, 1] > 0.5:
                label_dir = 'sad'
            elif labels_pred[s][:, 0, 2] > 0.5:
                label_dir = 'angry'
            else:
                label_dir = 'neutral'

            ## CHANGE HERE
            # pos_pred[s] = pos_pred[s][:, self.prefix_length + 5:]
            # o_z_rs_pred[s] = o_z_rs_pred[s][:, self.prefix_length + 5:]
            # quat_pred[s] = quat_pred[s][:, self.prefix_length + 5:]

            traj_pred_np = pos_pred[s][0, :, 0].cpu().numpy()

            save_file_name = '{:06}_{:.2f}_{:.2f}_{:.2f}_{:.2f}'.format(s,
                                                                        labels_pred[s][0, 0, 0],
                                                                        labels_pred[s][0, 0, 1],
                                                                        labels_pred[s][0, 0, 2],
                                                                        labels_pred[s][0, 0, 3])

            animation_pred = {
                'joint_names': self.joint_names,
                'joint_offsets': torch.from_numpy(self.joint_offsets[1:]).float().unsqueeze(0).repeat(
                    len(pos_pred), 1, 1),
                'joint_parents': self.joint_parents,
                'positions': pos_pred[s],
                'rotations': torch.cat((orient_pred[s], quat_pred[s]), dim=-1)
            }
            self.mocap.save_as_bvh(animation_pred,
                                   dataset_name=self.dataset,
                                   # subset_name='epoch_' + str(self.best_loss_epoch),
                                   # save_file_names=[str(s).zfill(6)])
                                   subset_name=os.path.join('no_aff_epoch_' + str(self.best_loss_epoch),
                                                            str(counter).zfill(2) + '_' + label_dir),
                                   save_file_names=['root'])
            end_time = time.time()
            elapsed_time[counter] = end_time - start_time
            print('Elapsed Time: {}'.format(elapsed_time[counter]))

            # display_animations(pos_pred_np, self.V, self.C, self.mocap.joint_parents, save=True,
            #                    dataset_name=self.dataset,
            #                    # subset_name='epoch_' + str(self.best_loss_epoch),
            #                    # save_file_names=[str(s).zfill(6)],
            #                    subset_name=os.path.join('epoch_' + str(self.best_loss_epoch), label_dir),
            #                    save_file_names=[save_file_name],
            #                    overwrite=True)
        print('Mean Elapsed Time: {}'.format(np.mean(elapsed_time)))