Code Example #1
    def set_new_root(self, new_root):
        euler = torch.tensor(self.anim.rotations[:, 0, :], dtype=torch.float)
        transform = ForwardKinematics.transform_from_euler(euler, 'xyz')
        offset = torch.tensor(self.anim.offsets[new_root], dtype=torch.float)
        new_pos = torch.matmul(transform, offset)
        new_pos = new_pos.numpy() + self.anim.positions[:, 0, :]
        self.anim.offsets[0] = -self.anim.offsets[new_root]
        self.anim.offsets[new_root] = np.zeros((3, ))
        self.anim.positions[:, new_root, :] = new_pos
        rot0 = Quaternions.from_euler(np.radians(self.anim.rotations[:, 0, :]),
                                      order='xyz')
        rot1 = Quaternions.from_euler(np.radians(self.anim.rotations[:, new_root, :]),
                                      order='xyz')
        new_rot1 = rot0 * rot1
        new_rot0 = (-rot1)
        new_rot0 = np.degrees(new_rot0.euler())
        new_rot1 = np.degrees(new_rot1.euler())
        self.anim.rotations[:, 0, :] = new_rot0
        self.anim.rotations[:, new_root, :] = new_rot1

        new_seq = []
        vis = [0] * self.anim.rotations.shape[1]
        new_idx = [-1] * len(vis)
        new_parent = [0] * len(vis)

        def relabel(x):
            nonlocal new_seq, vis, new_idx, new_parent
            new_idx[x] = len(new_seq)
            new_seq.append(x)
            vis[x] = 1
            for y in range(len(vis)):
                if not vis[y] and (self.anim.parents[x] == y
                                   or self.anim.parents[y] == x):
                    relabel(y)
                    new_parent[new_idx[y]] = new_idx[x]

        relabel(new_root)
        self.anim.rotations = self.anim.rotations[:, new_seq, :]
        self.anim.offsets = self.anim.offsets[new_seq]
        names = self._names.copy()
        for i, j in enumerate(new_seq):
            self._names[i] = names[j]
        self.anim.parents = np.array(new_parent, dtype=int)  # np.int was removed in NumPy 1.24
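
The `relabel` closure above re-roots the skeleton by walking the parent links as an undirected graph. A minimal, self-contained sketch of that traversal on a toy five-joint chain (the topology and the `reroot_parents` name are illustrative, not part of the repository):

    import numpy as np

    def reroot_parents(parents, new_root):
        # DFS over the undirected skeleton graph: returns the visit order and
        # the parent array re-expressed in the new indexing (the root keeps
        # the default parent 0, matching the code above).
        n = len(parents)
        new_seq, new_idx, new_parent = [], [-1] * n, [0] * n
        vis = [False] * n

        def relabel(x):
            new_idx[x] = len(new_seq)
            new_seq.append(x)
            vis[x] = True
            for y in range(n):
                if not vis[y] and (parents[x] == y or parents[y] == x):
                    relabel(y)
                    new_parent[new_idx[y]] = new_idx[x]

        relabel(new_root)
        return new_seq, np.array(new_parent, dtype=int)

    # toy chain 0-1-2-3-4, re-rooted at joint 2
    order, new_parents = reroot_parents(np.array([-1, 0, 1, 2, 3]), new_root=2)
    print(order)        # [2, 1, 0, 3, 4]
    print(new_parents)  # [0 0 1 0 3]
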
Code Example #2
 def get_root_rotation(self):
     """
     Get the rotation of the root joint
     :return: A numpy Array in shape [frame, 1, 4]
     """
     # rotation in shape[frame, 1, 4]
     rotation = self.anim.rotations[:, self.corps, :]
     rotation = rotation[:, 0:1, :]
     # convert to quaternion format
     rotation = Quaternions.from_euler(np.radians(rotation)).qs
     return rotation
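
`Quaternions.from_euler(np.radians(rotation)).qs` yields w-first quaternion components, so the result has shape [frame, 1, 4]. A pure-NumPy sketch of that conversion for an intrinsic xyz order (the exact composition order is an assumption; the project's `Quaternions` class may use a different convention):

    import numpy as np

    def euler_xyz_to_quat(e):
        # e: (..., 3) Euler angles in radians -> (..., 4) w-first quaternions.
        def axis_quat(angle, axis):
            q = np.zeros(angle.shape + (4,))
            q[..., 0] = np.cos(angle / 2)
            q[..., 1 + axis] = np.sin(angle / 2)
            return q

        def qmul(a, b):
            # Hamilton product of two (..., 4) quaternion arrays.
            w1, x1, y1, z1 = (a[..., i] for i in range(4))
            w2, x2, y2, z2 = (b[..., i] for i in range(4))
            return np.stack([w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
                             w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
                             w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
                             w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2], axis=-1)

        qx, qy, qz = (axis_quat(e[..., i], i) for i in range(3))
        return qmul(qx, qmul(qy, qz))

    # 90 degrees about x: expect [cos 45deg, sin 45deg, 0, 0]
    print(euler_xyz_to_quat(np.radians([[90.0, 0.0, 0.0]])))
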
Code Example #3
 def rotate(self, theta, axis):
     q = Quaternions(
         np.hstack((np.cos(theta / 2), np.sin(theta / 2) * axis)))
     position = self.anim.positions[:, 0, :].copy()
     rotation = self.anim.rotations[:, 0, :]
     position[1:, ...] -= position[0:-1, ...]
     q_position = Quaternions(
         np.hstack((np.zeros((position.shape[0], 1)), position)))
     q_rotation = Quaternions.from_euler(np.radians(rotation))
     q_rotation = q * q_rotation
     q_position = q * q_position * (-q)
     self.anim.rotations[:, 0, :] = np.degrees(q_rotation.euler())
     position = q_position.imaginaries
     for i in range(1, position.shape[0]):
         position[i] += position[i - 1]
     self.anim.positions[:, 0, :] = position
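
Because a single fixed rotation is a linear map, the diff → rotate → accumulate round trip in `rotate` lands on the same trajectory as rotating the absolute positions directly (the displacement form only starts to matter once the rotation varies per frame). A quick check with SciPy standing in for the quaternion sandwich `q * p * (-q)`; the sample trajectory is made up:

    import numpy as np
    from scipy.spatial.transform import Rotation as R

    positions = np.array([[0., 0., 0.], [1., 0., 0.], [2., 0., 1.], [2., 1., 1.]])
    rot = R.from_rotvec((np.pi / 2) * np.array([0., 1., 0.]))  # 90 deg about y

    disp = positions.copy()
    disp[1:] -= disp[:-1]  # frame-to-frame displacements; frame 0 stays absolute
    roundtrip = np.cumsum(rot.apply(disp), axis=0)

    assert np.allclose(roundtrip, rot.apply(positions))
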
Code Example #4
 def get_rotation(self):
     """
     Get the rotation of each joint's parent except the root joint
     :return: A numpy Array in shape [frame, simple_joint_num - 1, 4]
     """
     # rotation in shape[frame, simple_joint_num, 3]
     rotations = self.anim.rotations[:, self.corps, :]
     # transform euler degree into radians then into quaternion
     # in shape [frame, simple_joint_num, 4]
     rotations = Quaternions.from_euler(np.radians(rotations)).qs
     # All joint information except the root is stored in edges,
     # because in a .bvh file each joint's rotation is the rotation applied
     # to its child joints, while here the simplified skeleton's rotations
     # are the rotations of each bone's parent in the full skeleton; these
     # rotations therefore have to be replaced with those of the
     # corresponding parent in the simplified skeleton. And since the edge
     # information is used, this can to some extent be read as how much each
     # edge itself rotates.
     index = []
     for e in self.edges:
         index.append(e[0])
     # now rotation is in shape[frame, simple_joint_num - 1, 4]
     rotations = rotations[:, index, :]
     return rotations
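
A toy illustration of the edge re-indexing described in the comments above (the skeleton and shapes are made up):

    import numpy as np

    # Hypothetical simplified skeleton: edges given as (parent, child) pairs.
    edges = [(0, 1), (1, 2), (1, 3)]
    rotations = np.zeros((10, 4, 4))          # [frame, simple_joint_num, 4]

    # For each edge, keep the rotation of its parent joint -- the same
    # fancy-indexing step performed at the end of get_rotation.
    index = [e[0] for e in edges]
    edge_rotations = rotations[:, index, :]
    print(edge_rotations.shape)               # (10, 3, 4)
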
Code Example #5
    def load_bvh(file_name,
                 channel_map=None,
                 start=None,
                 end=None,
                 order=None,
                 world=False):
        '''
        Reads a BVH file and constructs an animation

        Parameters
        ----------
        file_name: str
            File to be opened

        channel_map: Dict
            Mapping from BVH rotation channel names
            (e.g. 'Xrotation') to the axis letters 'x', 'y', 'z'

        start : int
            Optional Starting Frame

        end : int
            Optional Ending Frame

        order : str
            Optional Specifier for joint order.
            Given as a string, e.g. 'xyz', 'zxy'

        world : bool
            If set to true euler angles are applied
            together in world space rather than local
            space

        Returns
        -------

        (names, parents, offsets, positions, rotations)
            Tuple of joint names, parent indices, rest-pose offsets, and
            per-frame world-space positions and quaternion rotations
        '''

        if channel_map is None:
            channel_map = {
                'Xrotation': 'x',
                'Yrotation': 'y',
                'Zrotation': 'z'
            }
        f = open(file_name, 'r')

        i = 0
        active = -1
        end_site = False

        names = []
        orients = Quaternions.id(0)
        offsets = np.array([]).reshape((0, 3))
        parents = np.array([], dtype=int)
        prev_data_match = []
        for line in f:

            if 'HIERARCHY' in line:
                continue
            if 'MOTION' in line:
                continue

            root_match = re.match(r'ROOT (\w+)', line)
            if root_match:
                names.append(root_match.group(1))
                offsets = np.append(offsets, np.array([[0, 0, 0]]), axis=0)
                orients.qs = np.append(orients.qs,
                                       np.array([[1, 0, 0, 0]]),
                                       axis=0)
                parents = np.append(parents, active)
                active = (len(parents) - 1)
                continue

            if '{' in line:
                continue

            if '}' in line:
                if end_site:
                    end_site = False
                else:
                    active = parents[active]
                continue

            offset_match = re.match(
                r'\s*OFFSET\s+([\-\d\.e]+)\s+([\-\d\.e]+)\s+([\-\d\.e]+)',
                line)
            if offset_match:
                if not end_site:
                    offsets[active] = np.array(
                        [list(map(float, offset_match.groups()))])
                continue

            channel_match = re.match(r'\s*CHANNELS\s+(\d+)', line)
            if channel_match:
                channels = int(channel_match.group(1))
                if order is None:
                    channel_is = 0 if channels == 3 else 3
                    channel_ie = 3 if channels == 3 else 6
                    parts = line.split()[2 + channel_is:2 + channel_ie]
                    if any([p not in channel_map for p in parts]):
                        continue
                    order = ''.join([channel_map[p] for p in parts])
                continue

            joint_match = re.match(r'\s*JOINT\s+(\w+)', line)
            if joint_match:
                names.append(joint_match.group(1))
                offsets = np.append(offsets, np.array([[0, 0, 0]]), axis=0)
                orients.qs = np.append(orients.qs,
                                       np.array([[1, 0, 0, 0]]),
                                       axis=0)
                parents = np.append(parents, active)
                active = (len(parents) - 1)
                continue

            if 'End Site' in line:
                end_site = True
                continue

            frame_match = re.match(r'\s*Frames:\s+(\d+)', line)
            if frame_match:
                if start and end:
                    frame_num = (end - start) - 1
                else:
                    frame_num = int(frame_match.group(1))
                joint_num = len(parents)
                positions = offsets[np.newaxis].repeat(frame_num, axis=0)
                rotations = np.zeros((frame_num, len(orients), 3))
                continue

            frame_match = re.match(r'\s*Frame Time:\s+([\d\.]+)', line)
            if frame_match:
                frame_time = float(frame_match.group(1))
                continue

            if (start and end) and (i < start or i >= end - 1):
                i += 1
                continue
            if '#IND' in line:
                # Replace invalid '-1.#IND' (NaN) tokens with the value from
                # the previous frame. Use a fresh loop variable so the frame
                # counter i is not clobbered.
                new_data_match = line.strip().split(' ')
                for k in range(len(new_data_match)):
                    if '#IND' in new_data_match[k]:
                        new_data_match[k] = prev_data_match[k]
                line = ' '.join(new_data_match)

            data_match = line.strip().split(' ')
            prev_data_match = data_match
            if data_match:

                data_block = np.array(list(map(float, data_match)))
                N = len(parents)
                fi = i - start if start else i
                if channels == 3:
                    positions[fi, 0:1] = data_block[0:3]
                    rotations[fi, :] = data_block[3:].reshape(N, 3)
                elif channels == 6:
                    data_block = data_block.reshape(N, 6)
                    positions[fi, :] = data_block[:, 0:3]
                    rotations[fi, :] = data_block[:, 3:6]
                elif channels == 9:
                    positions[fi, 0] = data_block[0:3]
                    data_block = data_block[3:].reshape(N - 1, 9)
                    rotations[fi, 1:] = data_block[:, 3:6]
                    positions[fi,
                              1:] += data_block[:, 0:3] * data_block[:, 6:9]
                else:
                    raise Exception('Too many channels! {}'.format(channels))

                i += 1

        f.close()

        rotations = qfix(
            Quaternions.from_euler(np.radians(rotations),
                                   order=order,
                                   world=world).qs)
        positions = MocapDataset.forward_kinematics(
            torch.from_numpy(rotations).cuda().float().unsqueeze(0),
            torch.from_numpy(positions[:, 0]).cuda().float().unsqueeze(0),
            parents,
            torch.from_numpy(offsets).cuda().float()).squeeze().cpu().numpy()
        orientations, _ = Quaternions(rotations[:, 0]).angle_axis()
        return names, parents, offsets, positions, rotations
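
`qfix` enforces sign continuity across frames: q and -q encode the same rotation, so frame-to-frame sign flips must be removed before the values are used numerically. A sketch in the spirit of the QuaterNet helper; the actual function imported here may differ:

    import numpy as np

    def qfix(q):
        # q: (frames, joints, 4). Flip signs so each frame's quaternion has a
        # non-negative dot product with the previous frame's quaternion.
        dots = np.sum(q[1:] * q[:-1], axis=-1)
        flip = (np.cumsum(dots < 0, axis=0) % 2).astype(bool)
        out = q.copy()
        out[1:][flip] *= -1
        return out
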
Code Example #6
File: model_zoo.py Project: liuguoyou/MotioNet
 def convert_eular_to_quaternions(self, rotations):
     rotations_reshape = rotations.view((-1, 3))
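     # Note: unlike the other examples there is no np.radians conversion here,
     # so the inputs are presumably already in radians (an assumption about
     # the caller).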
     q = Quaternions.from_euler(rotations_reshape)
     return q.qs
Code Example #7
def load(filename, start=None, end=None, order=None, world=False):
    """
    Reads a BVH file and constructs an animation

    Parameters
    ----------
    filename: str
        File to be opened

    start : int
        Optional Starting Frame

    end : int
        Optional Ending Frame

    order : str
        Optional Specifier for joint order.
        Given as a string, e.g. 'xyz', 'zxy'

    world : bool
        If set to true euler angles are applied
        together in world space rather than local
        space

    Returns
    -------

    (animation, joint_names, frametime)
        Tuple of loaded animation and joint names
    """

    f = open(filename, "r")

    i = 0
    active = -1
    end_site = False

    names = []
    orients = Quaternions.id(0)
    offsets = np.array([]).reshape((0, 3))
    parents = np.array([], dtype=int)

    for line in f:

        if "HIERARCHY" in line: continue
        if "MOTION" in line: continue

        rmatch = re.match(r"ROOT (\w+)", line)
        if rmatch:
            names.append(rmatch.group(1))
            offsets = np.append(offsets, np.array([[0, 0, 0]]), axis=0)
            orients.qs = np.append(orients.qs,
                                   np.array([[1, 0, 0, 0]]),
                                   axis=0)
            parents = np.append(parents, active)
            active = (len(parents) - 1)
            continue

        if "{" in line: continue

        if "}" in line:
            if end_site:
                end_site = False
            else:
                active = parents[active]
            continue

        offmatch = re.match(
            r"\s*OFFSET\s+([\-\d\.e]+)\s+([\-\d\.e]+)\s+([\-\d\.e]+)", line)
        if offmatch:
            if not end_site:
                offsets[active] = np.array(
                    [list(map(float, offmatch.groups()))])
            continue

        chanmatch = re.match(r"\s*CHANNELS\s+(\d+)", line)
        if chanmatch:
            channels = int(chanmatch.group(1))
            if order is None:
                channelis = 0 if channels == 3 else 3
                channelie = 3 if channels == 3 else 6
                parts = line.split()[2 + channelis:2 + channelie]
                if any([p not in channelmap for p in parts]):
                    continue
                order = "".join([channelmap[p] for p in parts])
            continue

        jmatch = re.match("\s*JOINT\s+(\w+)", line)
        if jmatch:
            names.append(jmatch.group(1))
            offsets = np.append(offsets, np.array([[0, 0, 0]]), axis=0)
            orients.qs = np.append(orients.qs,
                                   np.array([[1, 0, 0, 0]]),
                                   axis=0)
            parents = np.append(parents, active)
            active = (len(parents) - 1)
            continue

        if "End Site" in line:
            end_site = True
            continue

        fmatch = re.match("\s*Frames:\s+(\d+)", line)
        if fmatch:
            if start and end:
                fnum = (end - start) - 1
            else:
                fnum = int(fmatch.group(1))
            jnum = len(parents)
            positions = offsets[np.newaxis].repeat(fnum, axis=0)
            rotations = np.zeros((fnum, len(orients), 3))
            continue

        fmatch = re.match("\s*Frame Time:\s+([\d\.]+)", line)
        if fmatch:
            frametime = float(fmatch.group(1))
            continue

        if (start and end) and (i < start or i >= end - 1):
            i += 1
            continue

        dmatch = line.strip().split(' ')
        if dmatch:
            data_block = np.array(list(map(float, dmatch)))
            N = len(parents)
            fi = i - start if start else i
            if channels == 3:
                positions[fi, 0:1] = data_block[0:3]
                rotations[fi, :] = data_block[3:].reshape(N, 3)
            elif channels == 6:
                data_block = data_block.reshape(N, 6)
                positions[fi, :] = data_block[:, 0:3]
                rotations[fi, :] = data_block[:, 3:6]
            elif channels == 9:
                positions[fi, 0] = data_block[0:3]
                data_block = data_block[3:].reshape(N - 1, 9)
                rotations[fi, 1:] = data_block[:, 3:6]
                positions[fi, 1:] += data_block[:, 0:3] * data_block[:, 6:9]
            else:
                raise Exception("Too many channels! %i" % channels)

            i += 1

    f.close()

    rotations = Quaternions.from_euler(np.radians(rotations),
                                       order=order,
                                       world=world)

    return Animation(rotations, positions, orients, offsets,
                     parents), names, frametime
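
A hypothetical call, assuming a BVH file on disk (`walk.bvh` is illustrative). Note that `load` references a module-level `channelmap` dict; the default in Example #5 shows its expected contents:

    channelmap = {'Xrotation': 'x', 'Yrotation': 'y', 'Zrotation': 'z'}

    anim, names, frametime = load('walk.bvh', start=100, end=340)
    print(len(names), anim.rotations.shape, frametime)
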
Code Example #8
def load_from_maya(root, start, end):
    """
    Load Animation Object from Maya Joint Skeleton    
    
    Parameters
    ----------
    
    root : PyNode
        Root Joint of Maya Skeleton
        
    start, end : int, int
        Start and End frame index of Maya Animation
    
    Returns
    -------
    
    animation : Animation
        Loaded animation from Maya
        
    names : [str]
        Joint names from Maya
    """

    import pymel.core as pm
    
    original_time = pm.currentTime(q=True)
    pm.currentTime(start)
    
    """ Build Structure """
    
    names, parents = AnimationStructure.load_from_maya(root)
    descendants = AnimationStructure.descendants_list(parents)
    orients = Quaternions.id(len(names))
    offsets = np.array([pm.xform(j, q=True, translation=True) for j in names])
    
    for j, name in enumerate(names):
        scale = pm.xform(pm.PyNode(name), q=True, scale=True, relative=True)
        if len(descendants[j]) == 0: continue
        offsets[descendants[j]] *= scale
    
    """ Load Animation """

    eulers    = np.zeros((end-start, len(names), 3))
    positions = np.zeros((end-start, len(names), 3))
    rotations = Quaternions.id((end-start, len(names)))
    
    for i in range(end-start):
        
        pm.currentTime(start+i+1, u=True)
        
        scales = {}
        
        for j, name, parent in zip(range(len(names)), names, parents):
            
            node = pm.PyNode(name)
            
            if i == 0 and pm.hasAttr(node, 'jointOrient'):
                ort = node.getOrientation()
                orients[j] = Quaternions(np.array([ort[3], ort[0], ort[1], ort[2]]))
            
            if pm.hasAttr(node, 'rotate'):    eulers[i,j]    = np.radians(pm.xform(node, q=True, rotation=True))
            if pm.hasAttr(node, 'translate'): positions[i,j] = pm.xform(node, q=True, translation=True)
            if pm.hasAttr(node, 'scale'):     scales[j]      = pm.xform(node, q=True, scale=True, relative=True)

        for j in scales:
            if len(descendants[j]) == 0: continue
            positions[i,descendants[j]] *= scales[j] 
        
        positions[i,0] = pm.xform(root, q=True, translation=True, worldSpace=True)
    
    rotations = orients[np.newaxis] * Quaternions.from_euler(eulers, order='xyz', world=True)
    
    """ Done """
    
    pm.currentTime(original_time)
    
    return Animation(rotations, positions, orients, offsets, parents), names
Code Example #9
    def generate_motion(self, load_saved_model=True, samples_to_generate=1534, max_steps=300, randomized=True):

        if load_saved_model:
            self.load_best_model()
        self.model.eval()
        test_loader = self.data_loader['test']

        pos, quat, orient, z_mean, z_dev, \
        root_speed, affs, spline, labels = self.return_batch([samples_to_generate], test_loader, randomized=randomized)
        pos = torch.from_numpy(pos).cuda()
        traj = pos[:, :, 0, [0, 2]].clone()
        orient = torch.from_numpy(orient).cuda()
        quat = torch.from_numpy(quat).cuda()
        z_mean = torch.from_numpy(z_mean).cuda()
        z_dev = torch.from_numpy(z_dev).cuda()
        root_speed = torch.from_numpy(root_speed).cuda()
        affs = torch.from_numpy(affs).cuda()
        spline = torch.from_numpy(spline).cuda()
        z_rs = torch.cat((z_dev, root_speed), dim=-1)
        quat_all = torch.cat((orient[:, self.prefix_length - 1:], quat[:, self.prefix_length - 1:]), dim=-1)
        labels = np.tile(labels, (1, max_steps + self.prefix_length, 1))

        # Begin for transition
        # traj[:, self.prefix_length - 2] = torch.tensor([-0.208, 4.8]).cuda().float()
        # traj[:, self.prefix_length - 1] = torch.tensor([-0.204, 5.1]).cuda().float()
        # final_emo_idx = int(max_steps/2)
        # labels[:, final_emo_idx:] = np.array([1., 0., 0., 0.])
        # labels[:, :final_emo_idx + 1] = np.linspace(labels[:, 0], labels[:, final_emo_idx],
        #                                             num=final_emo_idx + 1, axis=1)
        # End for transition
        labels = torch.from_numpy(labels).cuda()

        # traj_np = traj_markers.detach().cpu().numpy()
        # import matplotlib.pyplot as plt
        # plt.plot(traj_np[6, :, 0], traj_np[6, :, 1])
        # plt.show()

        happy_idx = [25, 295, 390, 667, 1196]
        sad_idx = [169, 184, 258, 948, 974]
        angry_idx = [89, 93, 96, 112, 289, 290, 978]
        neutral_idx = [72, 106, 143, 237, 532, 747, 1177]
        sample_idx = np.squeeze(np.concatenate((happy_idx, sad_idx, angry_idx, neutral_idx)))

        ## CHANGE HERE
        # scene_corners = torch.tensor([[149.862, 50.833],
        #                               [149.862, 36.81],
        #                               [161.599, 36.81],
        #                               [161.599, 50.833]]).cuda().float()
        # character_heights = torch.tensor([0.95, 0.88, 0.86, 0.90, 0.95, 0.82]).cuda().float()
        # num_characters_per_side = torch.tensor([2, 3, 2, 3]).cuda().int()
        # traj_markers, traj_offsets, character_scale =\
        #     generate_trajectories(scene_corners, z_mean, character_heights,
        #                           num_characters_per_side, traj[:, :self.prefix_length])
        # num_characters_per_side = torch.tensor([4, 0, 0, 0]).cuda().int()
        # traj_markers, traj_offsets, character_scale =\
        #     generate_simple_trajectories(scene_corners, z_mean[:4], z_mean[:4],
        #                                  num_characters_per_side, traj[sample_idx, :self.prefix_length])
        # traj_markers, traj_offsets, character_scale =\
        #     generate_rvo_trajectories(scene_corners, z_mean[:4], z_mean[:4],
        #                               num_characters_per_side, traj[sample_idx, :self.prefix_length])

        # traj[sample_idx, :self.prefix_length] += traj_offsets
        # pos_sampled = pos[sample_idx].clone()
        # pos_sampled[:, :self.prefix_length, :, [0, 2]] += traj_offsets.unsqueeze(2).repeat(1, 1, self.V, 1)
        # pos[sample_idx] = pos_sampled
        # traj_markers = self.generate_linear_trajectory(traj)

        pos_pred = self.copy_prefix(pos)
        traj_pred = self.copy_prefix(traj)
        orient_pred = self.copy_prefix(orient)
        quat_pred = self.copy_prefix(quat)
        z_rs_pred = self.copy_prefix(z_rs)
        affs_pred = self.copy_prefix(affs)
        spline_pred = self.copy_prefix(spline)
        labels_pred = self.copy_prefix(labels, prefix_length=max_steps + self.prefix_length)

        # forward
        elapsed_time = np.zeros(len(sample_idx))
        for counter, s in enumerate(sample_idx):  # range(samples_to_generate):
            start_time = time.time()
            orient_h_copy = self.orient_h.clone()
            quat_h_copy = self.quat_h.clone()
            ## CHANGE HERE
            num_markers = max_steps + self.prefix_length + 1
            # num_markers = traj_markers[s].shape[0]
            marker_idx = 0
            t = -1
            with torch.no_grad():
                while marker_idx < num_markers:
                    t += 1
                    if t > max_steps:
                        print('Sample: {}. Did not reach end in {} steps.'.format(s, max_steps), end='')
                        break
                    pos_pred[s] = torch.cat((pos_pred[s], torch.zeros_like(pos_pred[s][:, -1:])), dim=1)
                    traj_pred[s] = torch.cat((traj_pred[s], torch.zeros_like(traj_pred[s][:, -1:])), dim=1)
                    orient_pred[s] = torch.cat((orient_pred[s], torch.zeros_like(orient_pred[s][:, -1:])), dim=1)
                    quat_pred[s] = torch.cat((quat_pred[s], torch.zeros_like(quat_pred[s][:, -1:])), dim=1)
                    z_rs_pred[s] = torch.cat((z_rs_pred[s], torch.zeros_like(z_rs_pred[s][:, -1:])), dim=1)
                    affs_pred[s] = torch.cat((affs_pred[s], torch.zeros_like(affs_pred[s][:, -1:])), dim=1)
                    spline_pred[s] = torch.cat((spline_pred[s], torch.zeros_like(spline_pred[s][:, -1:])), dim=1)

                    orient_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1], \
                    quat_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1], \
                    z_rs_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1], \
                    orient_h_copy, quat_h_copy = \
                        self.model(
                            orient_pred[s][:, t:self.prefix_length + t],
                            quat_pred[s][:, t:self.prefix_length + t],
                            z_rs_pred[s][:, t:self.prefix_length + t],
                            affs_pred[s][:, t:self.prefix_length + t],
                            spline_pred[s][:, t:self.prefix_length + t],
                            labels_pred[s][:, t:self.prefix_length + t],
                            orient_h=None if t == 0 else orient_h_copy,
                            quat_h=None if t == 0 else quat_h_copy, return_prenorm=False)

                    traj_curr = traj_pred[s][:, self.prefix_length + t - 1:self.prefix_length + t].clone()
                    # root_speed = torch.norm(
                    #     pos_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1, 0] - \
                    #     pos_pred[s][:, self.prefix_length + t - 1:self.prefix_length + t, 0], dim=-1)

                    ## CHANGE HERE
                    # traj_next = \
                    #     self.compute_next_traj_point(
                    #         traj_curr,
                    #         traj_markers[s, marker_idx],
                    #         o_z_rs_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1, 2])
                    try:
                        traj_next = traj[s, self.prefix_length + t]
                    except IndexError:
                        traj_next = \
                            self.compute_next_traj_point_sans_markers(
                                pos_pred[s][:, self.prefix_length + t - 1:self.prefix_length + t],
                                quat_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1],
                                z_rs_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1, 0],
                                z_rs_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1, 1])

                    pos_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1], \
                    affs_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1], \
                    spline_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1] = \
                        self.mocap.get_predicted_features(
                            pos_pred[s][:, :self.prefix_length + t],
                            traj_next,
                            z_rs_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1, 0] + z_mean[s:s + 1],
                            orient_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1],
                            quat_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1])

                    # min_speed_pred = torch.min(torch.cat((lf_speed_pred.unsqueeze(-1),
                    #                                        rf_speed_pred.unsqueeze(-1)), dim=-1), dim=-1)[0]
                    # if min_speed_pred - diff_speeds_mean[s] - diff_speeds_std[s] < 0.:
                    #     root_speed_pred = 0.
                    # else:
                    #     root_speed_pred = o_z_rs_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1, 2]
                    #

                    ## CHANGE HERE
                    # traj_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1] = \
                    #     self.compute_next_traj_point(
                    #         traj_curr,
                    #         traj_markers[s, marker_idx],
                    #         root_speed_pred)
                    # if torch.norm(traj_next - traj_curr, dim=-1).squeeze() >= \
                    #         torch.norm(traj_markers[s, marker_idx] - traj_curr, dim=-1).squeeze():
                    #     marker_idx += 1
                    traj_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1] = traj_next
                    marker_idx += 1
                    pos_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1], \
                    affs_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1], \
                    spline_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1] = \
                        self.mocap.get_predicted_features(
                            pos_pred[s][:, :self.prefix_length + t],
                            pos_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1, 0, [0, 2]],
                            z_rs_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1, 0] + z_mean[s:s + 1],
                            orient_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1],
                            quat_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1])
                    print('Sample: {}. Steps: {}'.format(s, t), end='\r')
            print()

            # shift = torch.zeros((1, scene_corners.shape[1] + 1)).cuda().float()
            # shift[..., [0, 2]] = scene_corners[0]
            # pos_pred[s] = (pos_pred[s] - shift) / character_scale + shift
            # pos_pred_np = pos_pred[s].contiguous().view(pos_pred[s].shape[0],
            #                                             pos_pred[s].shape[1], -1).permute(0, 2, 1).\
            #     detach().cpu().numpy()
            # display_animations(pos_pred_np, self.V, self.C, self.mocap.joint_parents, save=True,
            #                    dataset_name=self.dataset, subset_name='epoch_' + str(self.best_loss_epoch),
            #                    save_file_names=[str(s).zfill(6)],
            #                    overwrite=True)

            # plt.cla()
            # fig, (ax1, ax2) = plt.subplots(2, 1)
            # ax1.plot(root_speeds[s])
            # ax1.plot(lf_speeds[s])
            # ax1.plot(rf_speeds[s])
            # ax1.plot(min_speeds[s] - root_speeds[s])
            # ax1.legend(['root', 'left', 'right', 'diff'])
            # ax2.plot(root_speeds_pred)
            # ax2.plot(lf_speeds_pred)
            # ax2.plot(rf_speeds_pred)
            # ax2.plot(min_speeds_pred - root_speeds_pred)
            # ax2.legend(['root', 'left', 'right', 'diff'])
            # plt.show()

            head_tilt = np.tile(np.array([0., 0., 0.]), (1, quat_pred[s].shape[1], 1))
            l_shoulder_slouch = np.tile(np.array([0., 0., 0.]), (1, quat_pred[s].shape[1], 1))
            r_shoulder_slouch = np.tile(np.array([0., 0., 0.]), (1, quat_pred[s].shape[1], 1))
            head_tilt = Quaternions.from_euler(head_tilt, order='xyz').qs
            l_shoulder_slouch = Quaternions.from_euler(l_shoulder_slouch, order='xyz').qs
            r_shoulder_slouch = Quaternions.from_euler(r_shoulder_slouch, order='xyz').qs

            # Begin for aligning facing direction to trajectory
            axis_diff, angle_diff = self.get_diff_from_traj(pos_pred, traj_pred, s)
            angle_thres = 0.3
            # angle_thres = torch.max(angle_diff[:, 1:self.prefix_length])
            angle_diff[angle_diff <= angle_thres] = 0.
            angle_diff[:, self.prefix_length] = 0.
            # End for aligning facing direction to trajectory
            # pos_copy, quat_copy = self.rotate_gaits(pos_pred, quat_pred, quat_diff,
            #                                         head_tilt, l_shoulder_slouch, r_shoulder_slouch, s)
            # pos_pred[s] = pos_copy.clone()
            # angle_diff_intermediate = self.get_diff_from_traj(pos_pred, traj_pred, s)
            # if torch.max(angle_diff_intermediate[:, self.prefix_length:]) > np.pi / 2.:
            #     quat_diff = Quaternions.from_angle_axis(-angle_diff.cpu().numpy(), np.array([0, 1, 0])).qs
            #     pos_copy, quat_copy = self.rotate_gaits(pos_pred, quat_pred, quat_diff,
            #                                         head_tilt, l_shoulder_slouch, r_shoulder_slouch, s)
            # pos_pred[s] = pos_copy.clone()
            # axis_diff = torch.zeros_like(axis_diff)
            # axis_diff[..., 1] = 1.
            # angle_diff = torch.zeros_like(angle_diff)
            quat_diff = torch.from_numpy(Quaternions.from_angle_axis(
                angle_diff.cpu().numpy(), axis_diff.cpu().numpy()).qs).cuda().float()
            orient_pred[s], quat_pred[s] = self.rotate_gaits(orient_pred[s], quat_pred[s],
                                                             quat_diff, head_tilt,
                                                             l_shoulder_slouch, r_shoulder_slouch)

            if labels_pred[s][:, 0, 0] > 0.5:
                label_dir = 'happy'
            elif labels_pred[s][:, 0, 1] > 0.5:
                label_dir = 'sad'
            elif labels_pred[s][:, 0, 2] > 0.5:
                label_dir = 'angry'
            else:
                label_dir = 'neutral'

            ## CHANGE HERE
            # pos_pred[s] = pos_pred[s][:, self.prefix_length + 5:]
            # o_z_rs_pred[s] = o_z_rs_pred[s][:, self.prefix_length + 5:]
            # quat_pred[s] = quat_pred[s][:, self.prefix_length + 5:]

            traj_pred_np = pos_pred[s][0, :, 0].cpu().numpy()

            save_file_name = '{:06}_{:.2f}_{:.2f}_{:.2f}_{:.2f}'.format(s,
                                                                        labels_pred[s][0, 0, 0],
                                                                        labels_pred[s][0, 0, 1],
                                                                        labels_pred[s][0, 0, 2],
                                                                        labels_pred[s][0, 0, 3])

            animation_pred = {
                'joint_names': self.joint_names,
                'joint_offsets': torch.from_numpy(self.joint_offsets[1:]).float().unsqueeze(0).repeat(
                    len(pos_pred), 1, 1),
                'joint_parents': self.joint_parents,
                'positions': pos_pred[s],
                'rotations': torch.cat((orient_pred[s], quat_pred[s]), dim=-1)
            }
            self.mocap.save_as_bvh(animation_pred,
                                   dataset_name=self.dataset,
                                   # subset_name='epoch_' + str(self.best_loss_epoch),
                                   # save_file_names=[str(s).zfill(6)])
                                   subset_name=os.path.join('no_aff_epoch_' + str(self.best_loss_epoch),
                                                            str(counter).zfill(2) + '_' + label_dir),
                                   save_file_names=['root'])
            end_time = time.time()
            elapsed_time[counter] = end_time - start_time
            print('Elapsed Time: {}'.format(elapsed_time[counter]))

            # display_animations(pos_pred_np, self.V, self.C, self.mocap.joint_parents, save=True,
            #                    dataset_name=self.dataset,
            #                    # subset_name='epoch_' + str(self.best_loss_epoch),
            #                    # save_file_names=[str(s).zfill(6)],
            #                    subset_name=os.path.join('epoch_' + str(self.best_loss_epoch), label_dir),
            #                    save_file_names=[save_file_name],
            #                    overwrite=True)
        print('Mean Elapsed Time: {}'.format(np.mean(elapsed_time)))
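
Stripped of the trajectory and emotion-label bookkeeping, the loop above is a standard autoregressive rollout: condition on a sliding window of the last `prefix_length` frames and append one predicted frame per step. A minimal sketch with illustrative names (not the repository's API):

    import torch

    def rollout(model, seed, prefix_length, max_steps):
        # seed: (B, prefix_length, C); model: (B, prefix_length, C) -> (B, 1, C)
        pred = seed.clone()
        for t in range(max_steps):
            window = pred[:, t:prefix_length + t]        # last prefix_length frames
            pred = torch.cat((pred, model(window)), dim=1)
        return pred

    # e.g. with a dummy "model" that repeats the window's last frame:
    out = rollout(lambda w: w[:, -1:], torch.zeros(1, 8, 4), prefix_length=8, max_steps=5)
    print(out.shape)  # torch.Size([1, 13, 4])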