コード例 #1
0
    def predict(self, prefix, target_length):
        """
        Autoregressively predict ``target_length`` steps that follow ``prefix``.

        The prefix is reshaped into groups of 4 components, passed through
        ``qeuler_np(..., 'zyx')`` and then ``euler_to_quaternion(..., 'zyx')``
        (presumably a quaternion -> Euler -> quaternion round trip), fixed
        with ``qfix``, and flattened into a float32 tensor of shape
        ``(1, frames, -1)``.

        The model is run once on the full prefix. Each of the remaining
        ``target_length - 1`` steps feeds the previous prediction and hidden
        state back into the model.

        Returns a NumPy array that keeps the first two axes of the
        concatenated predictions and splits the trailing axis into groups
        of 4.
        """
        assert target_length > 0

        with torch.no_grad():
            # NOTE(review): prefix.shape[1] is used as the leading axis, which
            # presumably assumes a batch dimension of size 1 -- confirm.
            quats = prefix.reshape(prefix.shape[1], -1, 4)
            euler = qeuler_np(quats, 'zyx')
            quats = qfix(euler_to_quaternion(euler, 'zyx'))
            flat = quats.reshape(1, quats.shape[0], -1).astype('float32')

            net_input = torch.from_numpy(flat)
            if self.use_cuda:
                net_input = net_input.cuda()

            # First step consumes the whole prefix; later steps are fed back.
            step_out, state = self.model(net_input)
            outputs = [step_out]
            for _ in range(target_length - 1):
                step_out, state = self.model(step_out, state)
                outputs.append(step_out)

            stacked = torch.cat(outputs, dim=1)
            grouped = stacked.view(stacked.shape[0], stacked.shape[1], -1, 4)
            return grouped.cpu().numpy()
コード例 #2
0
    def predict(self, prefix, target_length):
        """
        Predict ``target_length`` steps that follow ``prefix``.

        The prefix is reshaped into groups of 4 components, passed through
        ``qeuler_np(..., 'zyx')`` and then ``euler_to_quaternion(..., 'zyx')``
        (presumably a quaternion -> Euler -> quaternion round trip), fixed
        with ``qfix``, and flattened into a float32 tensor of shape
        ``(1, frames, -1)``.

        The model is called once with ``target_length`` and a
        teacher-forcing ratio of 0, so it generates the continuation on its
        own.

        Returns a NumPy array that keeps the first two axes of the
        concatenated predictions and splits the trailing axis into groups
        of 4.
        """
        assert target_length > 0

        with torch.no_grad():
            # NOTE(review): prefix.shape[1] is used as the leading axis, which
            # presumably assumes a batch dimension of size 1 -- confirm.
            prefix = prefix.reshape(prefix.shape[1], -1, 4)
            prefix = qeuler_np(prefix, 'zyx')
            prefix = qfix(euler_to_quaternion(prefix, 'zyx'))
            inputs = torch.from_numpy(
                prefix.reshape(1, prefix.shape[0], -1).astype('float32'))

            if self.use_cuda:
                inputs = inputs.cuda()

            # Teacher-forcing ratio 0: no ground truth is fed back at test
            # time. The model's second return value is not needed here.
            predicted, _ = self.model(inputs, target_length, 0)

            # torch.cat takes a sequence of tensors, so ``predicted`` must be
            # a per-step list/tuple of tensors, joined along dim 1.
            result = torch.cat(predicted, dim=1)
            return result.view(result.shape[0], result.shape[1], -1,
                               4).cpu().numpy()
コード例 #3
0
        def process_file(filename):
            """Load one BVH file and return its root trajectory, rotations and class.

            The class label comes from the file's base name with its extension
            stripped, after dropping the first 11 and the last 8 characters.
            NOTE(review): these offsets presumably strip a fixed prefix/suffix
            from the dataset's file-naming scheme -- confirm.

            Returns:
                (pos, rot, cls): ``anim.positions[:, 0]`` (root joint
                trajectory), the ``qfix``-ed local joint rotation quaternions,
                and the class string.
            """
            stem = os.path.splitext(os.path.basename(filename))[0]
            cls = stem[11:-8]

            # Joint names and frame time are not needed here.
            anim, _names, _frametime = BVH.load(filename)

            pos = anim.positions[:, 0]  # Root joint trajectory
            rot = qfix(anim.rotations.qs)  # Local joint rotations as quaternions
            return pos, rot, cls
コード例 #4
0
    def _mirror_sequence(self, sequence):
        """Return a left/right mirrored copy of ``sequence``.

        ``sequence`` must provide 'rotations' and 'trajectory' arrays, which
        are copied rather than modified. In the copy:
          * rotations of the skeleton's left and right joints are exchanged,
          * quaternion components 2 and 3 are negated,
          * component 0 of every trajectory point is negated.
        The mirrored rotations are passed through ``qfix`` before returning.

        Returns:
            dict with keys 'rotations' and 'trajectory'.
        """
        source_rot = sequence['rotations']
        rotations = source_rot.copy()
        trajectory = sequence['trajectory'].copy()

        left = self._skeleton.joints_left()
        right = self._skeleton.joints_right()

        # Exchange left/right joints, always reading from the untouched source
        # so the second assignment does not see the first one's result.
        rotations[:, left] = source_rot[:, right]
        rotations[:, right] = source_rot[:, left]

        # Reflect the quaternions and the trajectory.
        rotations[:, :, [2, 3]] *= -1
        trajectory[:, 0] *= -1

        return {'rotations': qfix(rotations), 'trajectory': trajectory}
コード例 #5
0
        out_actions = []

        print('Converting dataset...')
        subjects = sorted(glob(output_directory + '/h3.6m/dataset/*'))
        for subject in subjects:
            actions = sorted(glob(subject + '/*'))
            result_ = {}
            for action_filename in actions:
                data = read_file(action_filename)

                # Discard the first joint, which represents a corrupted translation
                data = data[:, 1:]

                # Convert to quaternion and fix antipodal representations
                quat = expmap_to_quaternion(-data)
                quat = qfix(quat)

                out_pos.append(np.zeros(
                    (quat.shape[0], 3)))  # No trajectory for H3.6M
                out_rot.append(quat)
                tokens = re.split('\/|\.', action_filename.replace('\\', '/'))
                subject_name = tokens[-3]
                out_subjects.append(subject_name)
                action_name = tokens[-2]
                out_actions.append(action_name)

        print('Saving...')
        np.savez_compressed(output_file_path,
                            trajectories=out_pos,
                            rotations=out_rot,
                            subjects=out_subjects,