class MorphableGraphStateMachine(StateMachineController):
    """Animate a scene object by traversing a morphable motion graph.

    Motion is represented as a sequence of motion states (``self.state``).
    New states either come from the state queue that ``self.planner``
    (an ``MGStatePlanner``) fills from the worker thread started in
    ``enqueue_states``, or are generated directly by
    ``transition_to_next_state_controlled``.

    ``self.lock`` serializes changes of the current state / pose buffer
    between ``update`` and the other public methods; the planner's queue is
    additionally guarded by ``self.planner.state_queue.mutex``.
    """

    def __init__(self, scene_object, graph, start_node=None, use_all_joints=False, config=DEFAULT_CONFIG, pfnn_data=None):
        """Set up the controller, its planner and the initial idle state.

        Args:
            scene_object: scene object this controller animates.
            graph: motion graph; ``nodes``, ``start_node`` and ``skeleton``
                are read from it.
            start_node: node to start in. Falls back to ``graph.start_node``
                when it is None or not a node of ``graph``.
            use_all_joints: if True, frames are expanded to the skeleton's
                full reference frame length (see ``set_initial_idle_state``).
            config: configuration dict. Reads the optional "speed" entry,
                ``config["algorithm"]["inverse_kinematics_settings"]`` and the
                optional "data_directory"/"clip_directory" pair; it is also
                passed on to ``MGStatePlanner``.
            pfnn_data: currently unused (loading PFNN data here is disabled;
                see ``load_pfnn_controller``).
        """
        StateMachineController.__init__(self, scene_object)
        self._graph = graph
        if start_node is None or start_node not in self._graph.nodes:
            start_node = self._graph.start_node
        self.start_node = start_node
        self.frame_time = self._graph.skeleton.frame_time
        self.skeleton = self._graph.skeleton
        self.thread = None  # planner worker thread, see enqueue_states
        self._visualization = None
        set_log_mode(LOG_MODE_DEBUG)
        self.current_node = self.start_node
        self.use_all_joints = use_all_joints
        self.node_type = NODE_TYPE_IDLE
        self.state = None
        self.set_initial_idle_state(use_all_joints)
        print("start node", self.current_node)
        self.start_pose = {"position": [0, 0, 0], "orientation": [0, 0, 0]}
        self.speed = 1
        if "speed" in config:
            self.speed = config["speed"]
        print("set speed", self.speed)

        # Most recent poses; trimmed to buffer_size in update_transformation.
        self.pose_buffer = []
        self.buffer_size = 10
        self.max_step_length = 80
        self.direction_vector = np.array([-1.0, 0.0, 0.0])
        self.action_constraint = None
        self.target_projection_len = 0
        self.n_joints = len(self.skeleton.animated_joints)
        self.n_max_state_queries = 20

        self.retarget_engine = None
        self.target_skeleton = None
        self.activate_emit = False
        self.show_skeleton = True
        self.node_queue = []
        self.activate_grounding = False
        self.collision_boundary = None
        self.hand_collision_boundary = None

        self.aligning_transform = np.eye(4)
        self.draw_root_trajectory = False
        self.planner = MGStatePlanner(self, self._graph, config)
        self.motion_grounding = MotionGrounding(self.skeleton, config["algorithm"]["inverse_kinematics_settings"], self.skeleton.skeleton_model)
        self.actions = self.planner.action_definitions
        self.planner.settings.use_all_joints = use_all_joints
        self.state.play = True
        self.animation_server = None
        self.success = True
        self.is_recording = False
        self.stop_current_state = False
        self.lock = threading.Lock()
        self.recorded_poses = list()

        # Pre-recorded clips that play_clip() can switch to. Always defined so
        # that play_clip() works even when no clip directory is configured.
        self.loaded_clips = dict()
        # Both keys are needed to build the clip path.
        if "data_directory" in config and "clip_directory" in config:
            data_path = config["data_directory"] + os.sep + config["clip_directory"]
            if os.path.isdir(data_path):
                self.loaded_clips = load_clips(data_path)
            else:
                print("Could not find clip directory", data_path)
        # Alias kept for code that used the former attribute name.
        self.load_clips = self.loaded_clips

    def set_graph(self, graph, start_node):
        """Switch to ``graph`` and restart in a sampled idle state of ``start_node``.

        A running planner thread is stopped (and joined) first and the
        planner's state queue is cleared.

        NOTE(review): ``self.skeleton``, ``self.frame_time`` and the planner
        (constructed with the previous graph in ``__init__``) are not updated
        here; confirm this is intended.
        """
        print("set graph")
        with self.lock:
            self.stop_current_state = True
            if self.thread is not None:
                print("stop thread")
                self.planner.stop_thread = True
                self.thread.join()
                self.thread = None
            self._graph = graph
            self.start_node = start_node
            self.current_node = self.start_node
            self.set_initial_idle_state(self.planner.settings.use_all_joints)
            self.planner.state_queue.reset()
        if self.animation_server is not None:
            print("restarted animation server..............................")

    def create_collision_boundary(self, radius, length, visualize=True, active=True):
        """Create a "collision_boundary" component and share it with the planner.

        Does nothing unless ``constants.activate_simulation`` is set.
        """
        if not constants.activate_simulation:
            return
        print("create collision boundary", radius, length)
        self.collision_boundary = self.scene_object.scene.object_builder.create_component("collision_boundary", self.scene_object, radius, length, "morphablegraph_state_machine", visualize=visualize)
        self.collision_boundary.active = active
        self.planner.collision_boundary = self.collision_boundary

    def create_hand_collision_boundary(self, joint_name, radius, visualize=True, active=True):
        """Create a "hand_collision_boundary" component for ``joint_name`` and share it with the planner.

        Does nothing unless ``constants.activate_simulation`` is set.
        """
        if not constants.activate_simulation:
            return
        print("create collision boundary",joint_name, radius)
        self.hand_collision_boundary = self.scene_object.scene.object_builder.create_component("hand_collision_boundary", joint_name, self.scene_object, radius, "morphablegraph_state_machine", visualize=visualize)
        self.hand_collision_boundary.active = active
        self.planner.hand_collision_boundary = self.hand_collision_boundary

    def load_pfnn_controller(self, path, mean_folder, src_skeleton=None):
        """Load a PFNN model from ``path``/``mean_folder`` and enable it in the planner."""
        self.planner.pfnn_wrapper = PFNNWrapper.load_from_file(self.skeleton, path, mean_folder, src_skeleton)
        self.planner.use_pfnn = True

    def _get_node_animated_joints(self, node_id):
        """Return the animated joints of graph node ``node_id``, or ANIMATED_JOINTS_CUSTOM if it reports none."""
        animated_joints = self._graph.nodes[node_id].get_animated_joints()
        if len(animated_joints) == 0:
            animated_joints = ANIMATED_JOINTS_CUSTOM
        return animated_joints

    def set_initial_idle_state(self, use_all_joints=False):
        """Replace ``self.state`` with a freshly sampled motion of the current node.

        Args:
            use_all_joints: if True, each sampled (reduced) frame is expanded
                to ``self.skeleton.reference_frame_length`` values by adding
                fixed parameters for joints the node does not animate.
        """
        mv = MotionVector(self.skeleton)
        print("node", self.current_node)
        mv.frames = self._graph.nodes[self.current_node].sample().get_motion_vector()
        mv.frame_time = self.frame_time
        mv.n_frames = len(mv.frames)
        print("before", mv.frames.shape, self.skeleton.reference_frame_length)
        if use_all_joints:
            other_animated_joints = self._get_node_animated_joints(self.current_node)
            full_frames = np.zeros((len(mv.frames), self.skeleton.reference_frame_length))
            for idx, reduced_frame in enumerate(mv.frames):
                full_frames[idx] = self.skeleton.add_fixed_joint_parameters_to_other_frame(reduced_frame,
                                                                                           other_animated_joints)
            mv.frames = full_frames
        self.state = MotionState(mv)
        self.state.play = self.play

    def set_config(self, config):
        """Apply the optional "activate_grounding" flag and forward ``config`` to the planner."""
        if "activate_grounding" in config:
            self.activate_grounding = config["activate_grounding"]
        self.planner.set_config(config)

    def set_visualization(self, visualization):
        """Attach ``visualization`` and push the current direction and pose to it."""
        self._visualization = visualization
        self._visualization.update_dir_vis(self.direction_vector, self.target_projection_len)
        self.update_transformation()

    def update(self, dt):
        """ update current frame and global joint transformation matrices

        Advances the current state by ``self.speed * dt``. When the state
        reports a transition (or a queued state should interrupt the current
        one), the next state is taken from the planner queue if the planner
        is active, otherwise generated via
        ``transition_to_next_state_controlled``.
        """
        if not self.play:
            return
        transition = self.state.update(self.speed * dt)
        # `with` guarantees the lock is released even if state generation fails;
        # otherwise every later update() would deadlock.
        with self.lock:
            if transition or (len(self.planner.state_queue) > 0 and self.stop_current_state):
                # use the state planner if it is working on a task or has queued states
                use_state_planner = self.planner.is_processing or len(self.planner.state_queue) > 0
                if use_state_planner:
                    # poll the queue up to n_max_state_queries times for a generated state
                    success = False
                    n_queries = 0
                    while not success and n_queries < self.n_max_state_queries:
                        self.planner.state_queue.mutex.acquire()
                        try:
                            success = self.pop_motion_state_from_queue()
                        finally:
                            self.planner.state_queue.mutex.release()
                        if not success:
                            n_queries += 1
                    if not success:
                        print("Warning: transition to idle state due to empty state queue")
                        state_entry = self.planner.state_queue.generate_idle_state(dt, self.pose_buffer, False)
                        self.set_state_entry(state_entry)
                    self.stop_current_state = False
                else:
                    # otherwise transition to new state without the planner, e.g. to idle state
                    self.transition_to_next_state_controlled()
        self.update_transformation()

    def pop_motion_state_from_queue(self):
        """Make the first queued state current and remove it from the queue.

        Called by ``update`` while it holds ``state_queue.mutex``.

        Returns:
            True if a state was taken from the queue, False if it was empty.
        """
        if len(self.planner.state_queue) == 0:
            return False
        state_entry = self.planner.state_queue.get_first_state()
        self.set_state_entry(state_entry)
        self.planner.state_queue.pop_first_state()
        return True

    def set_state_entry(self, state_entry):
        """Adopt state, node, node type and a shallow copy of the pose buffer from ``state_entry``."""
        self.state = state_entry.state
        self.current_node = state_entry.node
        self.node_type = state_entry.node_type
        self.pose_buffer = copy.copy(state_entry.pose_buffer)

    def set_global_position(self, position):
        """Move the current state and every buffered pose to ``position``."""
        with self.lock:
            self.state.set_position(position)
            self.set_buffer_position(position)
        # The buffer is empty right after construction or reset_planner().
        if self.pose_buffer:
            assert not np.isnan(self.pose_buffer[-1]).any(), "Error in set pos "+str(position)

    def set_global_orientation(self, orientation):
        """Set ``orientation`` on the current state and on every buffered pose."""
        with self.lock:
            self.state.set_orientation(orientation)
            self.set_buffer_orientation(orientation)
        # The buffer is empty right after construction or reset_planner().
        if self.pose_buffer:
            assert not np.isnan(self.pose_buffer[-1]).any(), "Error in set orientation "+str(orientation)

    def set_buffer_position(self, pos):
        """Overwrite values [0:3] of every buffered pose (presumably the root translation) in place."""
        for buffered_pose in self.pose_buffer:
            buffered_pose[:3] = pos

    def set_buffer_orientation(self, orientation):
        """Overwrite values [3:7] of every buffered pose (presumably the root quaternion) in place."""
        for buffered_pose in self.pose_buffer:
            buffered_pose[3:7] = orientation

    def unpause(self):
        """Resume playback of the current state."""
        self.state.hold_last_frame = False
        self.state.paused = False

    def play_clip(self, clip_name):
        """Switch immediately to the pre-loaded clip ``clip_name`` and clear the planner queue.

        Unknown clip names are ignored.
        """
        print("play clip")
        if clip_name not in self.loaded_clips:
            return
        state = self.loaded_clips[clip_name]
        node_id = ("walk", "idle")
        node_type = NODE_TYPE_IDLE
        with self.lock:
            self.planner.state_queue.mutex.acquire()
            try:
                self.stop_current_state = True
                state_entry = StateQueueEntry(node_id, node_type, state, self.pose_buffer)
                self.set_state_entry(state_entry)
                self.planner.state_queue.reset()
                print("set state entry ", clip_name)
            finally:
                self.planner.state_queue.mutex.release()

    def generate_action_constraints(self, action_desc):
        """Translate one action description dict into the planner's constraint dict.

        Optional keys of ``action_desc``: "locomotionUpperBodyAction" or
        "upperBodyGesture", "velocityFactor" (default 1.0), "nCycles"
        (default 1), "constrainLookAt", "lookAtConstraints", and either
        "controlPoints", or "directionAngle"/"direction" together with
        "nSteps" and "stepDistance". "directionAngle" is applied relative to
        the current root direction in the ground plane.
        """
        action_name = action_desc["name"]
        upper_body_gesture = None
        if "locomotionUpperBodyAction" in action_desc:
            upper_body_gesture = {"name": action_desc["locomotionUpperBodyAction"]}
        elif "upperBodyGesture" in action_desc:
            upper_body_gesture = action_desc["upperBodyGesture"]
        velocity_factor = action_desc.get("velocityFactor", 1.0)
        n_cycles = action_desc.get("nCycles", 1)
        constrain_look_at = action_desc.get("constrainLookAt", False)
        look_at_constraints = action_desc.get("lookAtConstraints", False)
        print("enqueue states", action_name)
        frame_constraints, end_direction, body_orientation_targets = self.planner.constraint_builder.extract_constraints_from_dict(action_desc, look_at_constraints)
        out = dict()
        out["action_name"] = action_name
        out["frame_constraints"] = frame_constraints
        out["end_direction"] = end_direction
        out["body_orientation_targets"] = body_orientation_targets
        has_steps = "nSteps" in action_desc and "stepDistance" in action_desc
        if "controlPoints" in action_desc:
            out["control_points"] = action_desc["controlPoints"]
        elif "directionAngle" in action_desc and has_steps:
            root_dir = get_global_node_orientation_vector(self.skeleton, self.skeleton.aligning_root_node, self.get_current_frame(), self.skeleton.aligning_root_dir)
            # lift the 2D ground-plane direction into 3D (x, 0, z)
            root_dir = np.array([root_dir[0], 0, root_dir[1]])
            out["direction"] = rotate_vector_deg(root_dir, action_desc["directionAngle"])
            out["n_steps"] = action_desc["nSteps"]
            out["step_distance"] = action_desc["stepDistance"]
        elif "direction" in action_desc and has_steps:
            out["direction"] = action_desc["direction"]
            out["n_steps"] = action_desc["nSteps"]
            out["step_distance"] = action_desc["stepDistance"]
        out["upper_body_gesture"] = upper_body_gesture
        out["velocity_factor"] = velocity_factor
        out["n_cycles"] = n_cycles
        out["constrain_look_at"] = constrain_look_at
        out["look_at_constraints"] = look_at_constraints
        return out

    def enqueue_states(self, action_sequence, dt, refresh=False):
        """ generates states until all control points have been reached

        The states are produced asynchronously: any running planner thread is
        stopped and joined, the state queue is cleared and a new thread is
        started on ``planner.generate_motion_states_from_action_sequence``.

        Args:
            action_sequence: list of action description dicts
                (see ``generate_action_constraints``).
            dt: time step forwarded to the planner.
            refresh: if True, the current state is interrupted as soon as a
                new state is available and the planner starts from the
                controller's current pose buffer.
        """
        _action_sequence = []
        for action_desc in action_sequence:
            if "collisionObjectsUpdates" in action_desc:
                func_name = "a"
                params = self.scene_object.scene, action_desc["collisionObjectsUpdates"]
                self.scene_object.scene.schedule_func_call(func_name, update_collision_objects, params)
            _action_sequence.append(self.generate_action_constraints(action_desc))

        if self.thread is not None:
            print("stop thread")
            self.planner.stop_thread = True
            self.thread.join()
            self.stop_current_state = refresh
            self.thread = None

        self.planner.state_queue.mutex.acquire()
        try:
            start_node = self.current_node
            start_node_type = self.node_type
            pose_buffer = [np.array(frame) for frame in self.state.get_frames()[-self.buffer_size:]]
            self.planner.state_queue.reset()
        finally:
            self.planner.state_queue.mutex.release()
        self.planner.stop_thread = False
        self.planner.is_processing = True
        if refresh:
            with self.lock:
                self.stop_current_state = True
                pose_buffer = list(self.pose_buffer)

        method_args = (_action_sequence, start_node, start_node_type, pose_buffer, dt)
        self.thread = threading.Thread(target=self.planner.generate_motion_states_from_action_sequence, name="c", args=method_args)
        self.thread.start()

    def draw(self, modelMatrix, viewMatrix, projectionMatrix, lightSources):
        """Rendering hook; currently disabled and returns immediately.

        The statements after ``return`` are unreachable and kept only for
        reference. NOTE(review): ``self.line_renderer`` is never assigned in
        this class.
        """
        return
        if self.show_skeleton:
            self._visualization.draw(modelMatrix, viewMatrix, projectionMatrix, lightSources)
        self._visualization.update_dir_vis(self.direction_vector, self.target_projection_len)
        self.line_renderer.draw(modelMatrix, viewMatrix, projectionMatrix)

    def transition_to_next_state_randomly(self):
        """Follow a random standard transition of the current node and play a sample of the new node."""
        self.current_node = self._graph.nodes[self.current_node].generate_random_transition(NODE_TYPE_STANDARD)
        spline = self._graph.nodes[self.current_node].sample()
        self.set_state_by_spline(spline)

    def emit_update(self):
        """Hook for notifying listeners about state changes; currently a no-op."""
        if self.activate_emit:
            return

    def set_aligning_transform(self):
        """ uses a random sample of the morphable model to find an aligning transformation to bring constraints into the local coordinate system"""
        sample = self._graph.nodes[self.current_node].sample(False)
        frames = sample.get_motion_vector()
        m = get_node_aligning_2d_transform(self.skeleton, self.skeleton.aligning_root_node,
                                           self.pose_buffer, frames)
        self.aligning_transform = np.linalg.inv(m)

    def transition_to_next_state_controlled(self):
        """Select the next node and build its state without the asynchronous planner.

        Static motion primitives are sampled directly; other primitives are
        sampled under walk constraints derived from ``direction_vector`` and
        ``target_projection_len``.
        """
        self.current_node, self.node_type, self.node_queue = self.select_next_node(self.current_node, self.node_type, self.node_queue, self.target_projection_len)
        self.set_aligning_transform()
        node = self._graph.nodes[self.current_node]
        if isinstance(node.motion_primitive, StaticMotionPrimitive):
            spline = node.sample()
        else:
            mp_constraints = self.planner.constraint_builder.generate_walk_constraints(self.current_node, self.aligning_transform, self.direction_vector, self.target_projection_len, self.pose_buffer)
            s = self.planner.mp_generator.generate_constrained_sample(node, mp_constraints)
            spline = node.back_project(s, use_time_parameters=False)
        new_frames = spline.get_motion_vector()

        if self.planner.settings.use_all_joints:
            new_frames = self.planner.complete_frames(self.current_node, new_frames)
        ignore_rotation = False
        if self.current_node[1] == "idle" and self.planner.settings.ignore_idle_rotation:
            ignore_rotation = True
        self.state = self.planner.state_queue.build_state(new_frames, self.pose_buffer, ignore_rotation)
        self.state.play = self.play
        self.emit_update()

    def select_next_node(self, current_node, current_node_type, node_queue, step_distance):
        """Return ``(next_node, next_node_type, remaining_node_queue)``.

        Queued nodes take precedence; otherwise a random transition of the
        requested type is chosen, falling back to the idle start node when
        the current node has no such transition.
        """
        if len(node_queue):
            (next_node, next_node_type), node_queue = node_queue[0], node_queue[1:]
        else:
            next_node_type = self.planner.get_next_node_type(current_node_type, step_distance)
            next_node = self._graph.nodes[current_node].generate_random_transition(next_node_type)
            if next_node is None:
                next_node = self.start_node
                next_node_type = NODE_TYPE_IDLE
        return next_node, next_node_type, node_queue

    def apply_ik_on_transition(self, spline):
        """Align feet and wrists in the first spline coefficient via IK, then make
        the quaternions of all other coefficients sign-consistent with it."""
        left_foot = self.skeleton.skeleton_model["joints"]["left_foot"]
        right_foot = self.skeleton.skeleton_model["joints"]["right_foot"]
        right_hand = self.skeleton.skeleton_model["joints"]["right_wrist"]
        left_hand = self.skeleton.skeleton_model["joints"]["left_wrist"]
        n_coeffs = len(spline.coeffs)
        ik_chains = self.skeleton.skeleton_model["ik_chains"]
        ik_window = 5  # n_coeffs - 2
        align_joint(self.skeleton, spline.coeffs, 0, left_foot, ik_chains["foot_l"], ik_window)
        align_joint(self.skeleton, spline.coeffs, 0, right_foot, ik_chains["foot_r"], ik_window)
        align_joint(self.skeleton, spline.coeffs, 0, left_hand, ik_chains[left_hand], ik_window)
        align_joint(self.skeleton, spline.coeffs, 0, right_hand, ik_chains[right_hand], ik_window)

        for i in range(1, n_coeffs):
            spline.coeffs[i] = self.align_frames(spline.coeffs[i], spline.coeffs[0])

    def align_frames(self, frame, ref_frame):
        """Negate each 4-value joint quaternion of ``frame`` (offsets 3 + 4*i) whose
        dot product with the same slice of ``ref_frame`` is negative; modifies and
        returns ``frame``."""
        for i in range(self.n_joints):
            o = i*4+3
            q = frame[o:o+4]
            frame[o:o+4] = -q if np.dot(ref_frame[o:o+4], q) < 0 else q
        return frame

    def update_transformation(self):
        """Append the current (optionally grounded) pose to the buffer and update the visualization."""
        pose = self.state.get_pose()
        if self.activate_grounding:
            pose = self.motion_grounding.apply_on_frame(pose, self.scene_object.scene)
        self.pose_buffer.append(np.array(pose))
        _pose = copy.copy(pose)
        if self.show_skeleton and self._visualization is not None:
            self._visualization.updateTransformation(_pose , self.scene_object.scale_matrix)
            self._visualization.update_dir_vis(self.direction_vector, self.target_projection_len)
        self.pose_buffer = self.pose_buffer[-self.buffer_size:]
        if self.is_recording:
            self.recorded_poses.append(pose)

    def getPosition(self):
        """Return the first three pose values of the current state, or the origin without a state."""
        if self.state is not None:
            return self.state.get_pose()[:3]
        else:
            return [0, 0, 0]

    def get_global_transformation(self):
        """Return the global matrix of the root joint for the most recent buffered pose."""
        return self.skeleton.nodes[self.skeleton.root].get_global_matrix(self.pose_buffer[-1])

    def handle_keyboard_input(self, key):
        """Debug controls: "p" triggers "placeLeft", "a"/"d" rotate the walk direction
        by -/+10 degrees, "w"/"s" change the step length in steps of 10 within
        [0, max_step_length]."""
        if key == "p":
            self.transition_to_action("placeLeft")
        elif key == "a":
            self.rotate_dir_vector(-10)
        elif key == "d":
            self.rotate_dir_vector(10)
        elif key == "w":
            self.target_projection_len = min(self.target_projection_len + 10, self.max_step_length)
        elif key == "s":
            self.target_projection_len = max(self.target_projection_len - 10, 0)
        self.emit_update()

    def create_action_constraint(self, action_name, keyframe_label, position, joint_name=None):
        """Create a UnityFrameConstraint for the constraint slot ``keyframe_label`` of ``action_name``.

        ``joint_name`` defaults to the joint registered for that slot.
        """
        slot = self.actions[action_name]["constraint_slots"][keyframe_label]
        node = slot["node"]
        if joint_name is None:
            joint_name = slot["joint"]
        action_constraint = UnityFrameConstraint((action_name, node), keyframe_label, joint_name, position, None)
        return action_constraint

    def transition_to_action(self, action, constraint=None):
        """Queue the node sequence of ``action`` and transition immediately.

        Only takes effect while the current node belongs to the "walk" graph.
        When the controller is idle, the start node is queued after the action.
        """
        self.action_constraint = constraint
        if self.current_node[0] != "walk":
            return
        for node_name, node_type in self.actions[action]["node_sequence"]:
            self.node_queue.append(((action, node_name), node_type))
        if self.node_type == NODE_TYPE_IDLE:
            self.node_queue.append((self.start_node, NODE_TYPE_IDLE))
        self.transition_to_next_state_controlled()

    def rotate_dir_vector(self, angle):
        """Rotate ``direction_vector`` by ``angle`` degrees in the x/z plane and renormalize it."""
        r = np.radians(angle)
        s = np.sin(r)
        c = np.cos(r)
        # Read both components before writing: updating x in place first would
        # feed the already-rotated x into the z update.
        x = self.direction_vector[0]
        z = self.direction_vector[2]
        self.direction_vector[0] = c * x - s * z
        self.direction_vector[2] = s * x + c * z
        self.direction_vector /= np.linalg.norm(self.direction_vector)
        print("rotate",self.direction_vector)

    def get_n_frames(self):
        """Return the number of frames of the current state."""
        return self.state.get_n_frames()

    def get_frame_time(self):
        """Return the frame time of the current state."""
        return self.state.get_frame_time()

    def get_pose(self, frame_idx=None):
        """Return the pose at ``frame_idx`` (current frame if None), retargeted when a target skeleton is set."""
        frame = self.state.get_pose(frame_idx)
        if self.retarget_engine is not None:
            return self.retarget_engine.retarget_frame(frame, None)
        return frame

    def get_current_frame_idx(self):
        """Return the frame index of the current state."""
        return self.state.frame_idx

    def get_current_annotation(self):
        """Return the current annotation of the state."""
        return self.state.get_current_annotation()

    def get_n_annotations(self):
        """Return the number of annotations of the state."""
        return self.state.get_n_annotations()

    def get_semantic_annotation(self):
        """Return the semantic annotation of the state."""
        return self.state.get_semantic_annotation()

    def set_target_skeleton(self, target_skeleton):
        """Retarget all future poses returned by get_pose/get_current_frame to ``target_skeleton``.

        Also disables drawing of the source skeleton.
        """
        self.target_skeleton = target_skeleton
        target_knee = target_skeleton.skeleton_model["joints"]["right_knee"]
        src_knee = self.skeleton.skeleton_model["joints"]["right_knee"]
        scale = 1.0 # np.linalg.norm(target_skeleton.nodes[target_knee].offset) / np.linalg.norm(self.skeleton.nodes[src_knee].offset)
        joint_map = generate_joint_map(self.skeleton.skeleton_model, target_skeleton.skeleton_model)
        skeleton_copy = copy.deepcopy(self.skeleton)
        self.retarget_engine = Retargeting(skeleton_copy, target_skeleton, joint_map, scale, additional_rotation_map=None, place_on_ground=False)
        self.activate_emit = False
        self.show_skeleton = False

    def get_actions(self):
        """Return the names of all actions known to the planner."""
        return list(self.actions.keys())

    def get_keyframe_labels(self, action_name):
        """Return the constraint slot labels of ``action_name`` (empty list for unknown actions).

        Raises:
            Exception: if the action exists but defines no "constraint_slots".
        """
        if action_name not in self.actions:
            return list()
        action = self.actions[action_name]
        if "constraint_slots" not in action:
            raise Exception("No constraint slots defined for action " + action_name)
        return list(action["constraint_slots"].keys())

    def get_skeleton(self):
        """Return the target skeleton if one is set, otherwise the graph skeleton."""
        if self.target_skeleton is not None:
            return self.target_skeleton
        return self.skeleton

    def get_animated_joints(self):
        """Return the animated joints of the motion graph."""
        return self._graph.animated_joints

    def get_current_frame(self):
        """Return the current pose, retargeted and optionally snapped to the ground height.

        Grounding sets pose[1] to the scene height at (pose[0], pose[2]) and is
        only applied when a target skeleton is set.
        """
        pose = self.state.get_pose(None)
        if self.target_skeleton is not None:
            pose = self.retarget_engine.retarget_frame(pose, self.target_skeleton.reference_frame)
            if self.activate_grounding:
                x = pose[0]
                z = pose[2]
                target_ground_height = self.scene_object.scene.get_height(x, z)
                shift = target_ground_height - pose[1]
                pose[1] += shift
        return pose

    def get_events(self):
        """Pop and return the events of the first due entry in ``self.state.events``.

        An entry is due when the state's frame index has reached its key.
        Returns an empty list when no entry is due.
        """
        events_by_frame = self.state.events
        for key in list(events_by_frame.keys()):
            if self.state.frame_idx >= key:
                events = events_by_frame[key]
                del events_by_frame[key]
                return events
        return list()

    def get_current_annotation_label(self):
        """Return the current annotation label (always empty for this controller)."""
        return ""

    def isPlaying(self):
        """Report that the controller is playing (always True)."""
        return True

    def has_success(self):
        """Return the success flag."""
        return self.success

    def reset_planner(self):
        """Abort a running planner and restart in a sampled idle state of the start node.

        Only has an effect when the planner is currently processing.
        """
        print("reset planner")
        was_processing = self.planner.is_processing
        if was_processing:
            # threading.Thread has no stop(); request termination via the planner
            # flag and join, as enqueue_states/set_graph do. The join happens
            # before taking the queue mutex so the worker is not blocked by it
            # (assumption: the planner may take the mutex while running - confirm).
            self.planner.stop_thread = True
            if self.thread is not None:
                self.thread.join()
                self.thread = None
        self.planner.state_queue.mutex.acquire()
        try:
            if was_processing:
                self.planner.is_processing = False
                self.current_node = self.start_node
                self.node_type = NODE_TYPE_IDLE
                self.planner.state_queue.reset()
                self.pose_buffer = list()
                self.set_initial_idle_state(self.use_all_joints)
        finally:
            self.planner.state_queue.mutex.release()

    def start_recording(self):
        """Start collecting the poses produced by update_transformation."""
        self.is_recording = True
        self.recorded_poses = list()

    def save_recording_to_file(self):
        """Export the recorded poses to "recording_<ddmmyy_HHMMSS>.bvh" and stop recording.

        Reduced poses are expanded to full frames using the animated joints of
        the current node. Nothing is written (and recording continues) when no
        poses were recorded.
        """
        time_str = datetime.now().strftime("%d%m%y_%H%M%S")
        filename = "recording_"+time_str+".bvh"
        n_frames = len(self.recorded_poses)
        if n_frames > 0:
            other_animated_joints = self._get_node_animated_joints(self.current_node)
            full_frames = np.zeros((n_frames, self.skeleton.reference_frame_length))
            for idx, reduced_frame in enumerate(self.recorded_poses):
                full_frames[idx] = self.skeleton.add_fixed_joint_parameters_to_other_frame(reduced_frame,
                                                                                           other_animated_joints)
            mv = MotionVector()
            mv.frames = full_frames
            mv.n_frames = n_frames
            mv.frame_time = self.frame_time
            mv.export(self.skeleton, filename)
            print("saved recording with", n_frames, "to file", filename)
            self.is_recording = False

    def get_bone_matrices(self):
        """Return the bone matrices of the attached visualization."""
        return self._visualization.matrices

    def handle_collision(self):
        """Stop the planner thread, clear the state queue and interrupt the current state."""
        print("handle collision")
        with self.lock:
            if self.thread is not None:
                print("stop thread")
                self.planner.stop_thread = True
                self.thread.join()
                self.thread = None

            self.planner.state_queue.mutex.acquire()
            try:
                self.planner.state_queue.reset()
            finally:
                self.planner.state_queue.mutex.release()
            self.planner.stop_thread = False
            self.planner.is_processing = True
            self.stop_current_state = True
# Example No. 2
# 0
class SkeletonAnimationController(SkeletonAnimationControllerBase):
    """ The class controls the pose of a skeleton based on an instance of a MotionState class.
        The scene containing a controller connects to signals emitted by an instance of the class and relays them to the GUI.
    """
    def __init__(self, scene_object):
        """Initialise the controller for *scene_object* with no motion and no visualization.

        Motion data and the visualization are attached later through
        set_motion() and set_visualization().
        """
        SkeletonAnimationControllerBase.__init__(self, scene_object)
        # loading / file bookkeeping
        self.loadedCorrectly = self.hasVisualization = False
        self.filePath = ""
        self.name = ""
        # attached later via set_visualization() / set_motion()
        self._visualization = None
        self._motion = None
        # constraint markers keyed by joint name, and the ragdoll recorder component
        self.markers = {}
        self.recorder = None
        # updateTransformation() does nothing while relative_root is set
        self.relative_root = False
        self.root_pos = self.root_q = None
        # playback configuration
        self.type = CONTROLLER_TYPE_ANIMATION
        self.animationSpeed = 1.0
        self.loopAnimation = False
        self.activate_emit = True
        self.visualize = True

    def set_skeleton(self, skeleton, visualize=True):
        self.visualize = visualize
        if visualize:
            self._visualization.set_skeleton(skeleton, visualize)

    def set_motion(self, motion):
        """Wrap *motion* in a MotionState and make it the controller's active motion."""
        motion_state = MotionState(motion)
        self._motion = motion_state

    def set_color_annotation(self, semantic_annotation, color_map):
        self._motion.set_color_annotation(semantic_annotation, color_map)

    def set_time_function(self, time_function):
        self._motion.set_time_function(time_function)

    def set_color_annotation_legacy(self, annotation, color_map):
        self._motion.set_color_annotation_legacy(annotation, color_map)

    def set_random_color_annotation(self):
        self._motion.set_random_color_annotation()

    def set_visualization(self,
                          visualization,
                          draw_mode=SKELETON_DRAW_MODE_BOXES):
        self._visualization = visualization
        self._visualization.draw_mode = draw_mode
        self._visualization.updateTransformation(
            self._motion.get_pose(), self.scene_object.scale_matrix)

    def update(self, dt):
        """ update current frame and global joint transformation matrices
        """
        if not self.isLoadedCorrectly():
            return
        reset = self._motion.update(dt * self.animationSpeed)
        if self._motion.play:
            self.updateTransformation()
            # update gui
            if reset:
                self.reached_end_of_animation.emit(self.loopAnimation)
                self._motion.play = self.loopAnimation
            else:
                if self.activate_emit:
                    self.updated_animation_frame.emit(
                        self._motion.get_current_frame_idx())

    def draw(self, modelMatrix, viewMatrix, projectionMatrix, lightSources):
        if self.isLoadedCorrectly():
            self._visualization.draw(modelMatrix, viewMatrix, projectionMatrix,
                                     lightSources)

    def updateTransformation(self):
        if self.relative_root:
            return
        self.set_transformation_from_frame(self._motion.get_pose())

    def set_transformation_from_frame(self, frame):
        if frame is None:
            return
        self._visualization.updateTransformation(
            frame, self.scene_object.scale_matrix)
        #self.update_markers()
        self.updateAnnotation()

    def updateAnnotation(self):
        if self._motion.get_current_frame_idx(
        ) < self._motion.get_n_annotations():
            current_annotation = self._motion.get_current_annotation()
            self._visualization.set_color(current_annotation["color"])

    def get_current_annotation_label(self):
        return self._motion.get_current_annotation_label()

    def resetAnimationTime(self):
        self._motion.reset()
        self.updateTransformation()

    def setCurrentFrameNumber(self, frame_idx):
        self._motion.set_frame_idx(frame_idx)
        self.updateTransformation()
        #self.update_markers()

    def getNumberOfFrames(self):
        return self._motion.get_n_frames()

    def isLoadedCorrectly(self):
        return self._motion is not None

    def getFrameTime(self):
        if self.isLoadedCorrectly():
            # print self.frameTime
            return self._motion.get_frame_time()
        else:
            return 0

    def getScaleFactor(self):
        if self.isLoadedCorrectly():
            return self.scaleFactor
        else:
            return -1

    def getFilePath(self):
        if self.isLoadedCorrectly():
            return self.filePath

    def getNumberOfJoints(self):
        return len(self._visualization.skeleton.get_n_joints())

    def setColor(self, color):
        print("set color", color)
        self._visualization.set_color(color)

    def getColor(self):
        return self._visualization.color

    def getPosition(self):
        m = self.scene_object.transformation
        if self._motion is not None:
            root = self._visualization.skeleton.root
            pos = self._visualization.skeleton.nodes[
                root].offset + self._motion.get_pose()[:3]
            pos = [pos[0], pos[1], pos[2], 1]
            pos = np.dot(m, pos)[:3]
            return np.array(pos)
        else:
            return m[3, :3]

    def get_visualization(self):
        return self._visualization

    def create_ragdoll(self, use_reference_frame=True, create_markers=True):
        if self._motion is not None and self._visualization.skeleton.skeleton_model is not None:
            frame = self._motion.get_pose()
            skeleton = self._visualization.skeleton
            if use_reference_frame:
                frame = skeleton.get_reduced_reference_frame()
            o = self.scene_object.scene.object_builder.create_component(
                "ragdoll_from_skeleton",
                skeleton,
                frame,
                figure_def,
                add_contact_vis=False)
            #o = self.scene_object.scene.object_builder.create_ragdoll_from_skeleton(self._visualization.skeleton, frame)
            self.scene_object.scene.addAnimationController(
                o, "character_animation_recorder")
            self.recorder = o._components["character_animation_recorder"]
        if create_markers:
            self.create_markers()

    def create_markers(self, figure_def, scale=1.0):
        if self.recorder is not None:
            markers = self.recorder.generate_constraint_markers_v9(
                self, scale, figure_def)
            self.attach_constraint_markers(markers)

    def attach_constraint_markers(self, markers):
        self.markers = markers

    def detach_constraint_markers(self):
        self.markers = dict()

    def update_markers(self):
        frame = self._motion.get_pose()
        scale = self.scene_object.scale_matrix[0][0]
        for joint in list(self.markers.keys()):
            for marker in self.markers[joint]:
                m = self._visualization.skeleton.nodes[
                    joint].get_global_matrix(frame, True)
                position = np.dot(m, marker["relative_trans"])[:3, 3]
                marker["object"].setPosition(position * scale)

    def toggle_animation_loop(self):
        self.loopAnimation = not self.loopAnimation

    def get_bvh_string(self):
        """Serialize the current motion, expanded to full frames, into a BVH string."""
        skeleton = self._visualization.skeleton
        print("generate bvh string", len(skeleton.animated_joints))
        full_frames = skeleton.add_fixed_joint_parameters_to_motion(self._motion.get_frames())
        writer = BVHWriter(None, skeleton, full_frames, self._motion.get_frame_time(), True)
        return writer.generate_bvh_string()

    def get_json_data(self):
        self._motion.mv.skeleton = self._visualization.skeleton
        return self._motion.mv.to_db_format()

    def export_to_file(self, filename, export_format="bvh", frame_range=None):
        if self._motion is not None:
            frame_time = self._motion.get_frame_time()
            if export_format == "bvh":
                skeleton = self._visualization.skeleton
                frames = self._motion.get_frames()
                frames = np.array(frames)
                if frames is not None:
                    print("frames shape", frames.shape)
                else:
                    print("frames is none")

                print("ref framee length", skeleton.reference_frame_length)
                joint_count = 0
                for joint_name in skeleton.nodes.keys():
                    if len(skeleton.nodes[joint_name].children
                           ) > 0 and "EndSite" not in joint_name:
                        joint_count += 1
                skeleton.reference_frame_length = joint_count * 4 + 3
                frames = skeleton.add_fixed_joint_parameters_to_motion(frames)
                if frame_range is not None:
                    bvh_writer = BVHWriter(
                        None, skeleton,
                        frames[frame_range[0]:frame_range[1], :], frame_time,
                        True)
                else:
                    bvh_writer = BVHWriter(None, skeleton, frames, frame_time,
                                           True)
                bvh_writer.write(filename)
            elif export_format == "fbx":
                export_motion_vector_to_fbx_file(self._visualization.skeleton,
                                                 self._motion, filename)
            elif export_format == "json":
                self._visualization.skeleton.save_to_json(filename)
            else:
                print("unsupported format", export_format)

    def retarget_from_src(self,
                          src_controller,
                          scale_factor=1.0,
                          src_model=None,
                          target_model=None,
                          frame_range=None):
        """Retarget the motion of *src_controller* onto this controller's skeleton.

        Two source types are handled, selected by exact type:
        SkeletonAnimationController (skeleton to skeleton) and
        PointCloudAnimationController (point cloud to skeleton). Other sources
        leave the motion untouched. On success the retargeted frames replace
        this controller's motion, playback is rewound to frame 0, the source's
        frame time is adopted and the update signals are emitted.

        Returns:
            The frame count of this controller's motion after the call.

        Raises:
            Exception: for a skeleton source when either skeleton has no identity frame.
        """
        target_skeleton = self._visualization.skeleton
        frame_time = src_controller.get_frame_time()
        if target_model is not None:
            target_skeleton.skeleton_model = target_model

        src_type = type(src_controller)
        new_frames = None
        if src_type == SkeletonAnimationController:
            src_skeleton = src_controller._visualization.skeleton
            src_frames = src_controller._motion.get_frames()
            if src_model is not None:
                src_skeleton.skeleton_model = src_model
            identity_missing = (src_skeleton.identity_frame is None
                                or target_skeleton.identity_frame is None)
            if identity_missing:
                raise Exception("Error identiframe is None")
            new_frames = retarget_from_src_to_target(src_skeleton,
                                                     target_skeleton,
                                                     src_frames,
                                                     scale_factor=scale_factor,
                                                     frame_range=frame_range)
        elif src_type == PointCloudAnimationController:
            point_joints = src_controller._joints
            point_frames = src_controller._animated_points
            if src_model is None:
                src_model = src_controller.skeleton_model
            new_frames = retarget_from_point_cloud_to_target(
                point_joints,
                src_model,
                target_skeleton,
                point_frames,
                scale_factor=scale_factor,
                frame_range=frame_range)

        if new_frames is not None:
            motion_vector = self._motion.mv
            motion_vector.frames = new_frames
            motion_vector.n_frames = len(new_frames)
            self._motion.frame_idx = 0
            motion_vector.frame_time = frame_time
            self.currentFrameNumber = 0
            self.updateTransformation()
            self.update_scene_object.emit(-1)
            self.updated_animation_frame.emit(self.currentFrameNumber)
            print("finished retargeting", self._motion.get_n_frames(), "frames")
        return self._motion.get_n_frames()

    def retarget_from_frames(self,
                             src_skeleton,
                             src_frames,
                             scale_factor=1.0,
                             target_model=None,
                             frame_range=None,
                             place_on_ground=False,
                             joint_filter=None):
        """Retarget raw *src_frames* of *src_skeleton* onto this controller's skeleton.

        The retargeted frames, if any, replace the motion vector's frames.
        Unlike retarget_from_src(), the frame index and frame time are not
        touched and no signals are emitted.

        Returns:
            The frame count of this controller's motion after the call.
        """
        skeleton = self._visualization.skeleton
        if target_model is not None:
            skeleton.skeleton_model = target_model
        retargeted = retarget_from_src_to_target(src_skeleton,
                                                 skeleton,
                                                 src_frames,
                                                 scale_factor=scale_factor,
                                                 frame_range=frame_range,
                                                 place_on_ground=place_on_ground,
                                                 joint_filter=joint_filter)
        if retargeted is not None:
            motion_vector = self._motion.mv
            motion_vector.frames = retargeted
            motion_vector.n_frames = len(retargeted)
            print("finished retargeting", self._motion.get_n_frames(), "frames")
        return self._motion.get_n_frames()

    def set_scale(self, scale_factor):
        """Scale the skeleton and the motion's root, then rebuild the visualization.

        A new SkeletonVisualization is created with the previous color for the
        rescaled skeleton, the pose is redrawn and the scene object's
        transformation is reset to identity.

        NOTE(review): the new visualization is built without the previous
        ``draw_mode``; confirm whether that is intended.
        """
        previous_color = self._visualization.color
        skeleton = self._visualization.skeleton
        skeleton.scale(scale_factor)
        self._motion.mv.scale_root(scale_factor)
        rebuilt = SkeletonVisualization(self.scene_object, previous_color)
        self._visualization = rebuilt
        rebuilt.set_skeleton(skeleton)
        self.updateTransformation()
        self.scene_object.transformation = np.eye(4)

    def load_annotation(self, filename):
        with open(filename, "r") as in_file:
            annotation_data = json.load(in_file)
            semantic_annotation = annotation_data["semantic_annotation"]
            color_map = annotation_data["color_map"]
            self.set_color_annotation(semantic_annotation, color_map)

    def save_annotation(self, filename):
        with open(filename, "w") as out_file:
            data = dict()
            data["semantic_annotation"] = self._motion._semantic_annotation
            data["color_map"] = self._motion.label_color_map
            json.dump(data, out_file)

    def plot_joint_trajectories(self, joint_list):
        joint_objects = []
        for j in joint_list:
            o = self.plot_joint_trajectory(j)
            if o is not None:
                joint_objects.append(o)
        return joint_objects

    def plot_joint_trajectory(self, joint_name):
        scene_object = None
        if joint_name in list(self._visualization.skeleton.nodes.keys()):
            trajectory = list()
            for f in self._motion.get_frames():
                p = self.get_joint_position(joint_name, f)
                if p is not None:
                    trajectory.append(p)
            if len(trajectory) > 0:
                name = self.scene_object.name + "_" + joint_name + "_trajectory"
                scene_object = self.scene_object.scene.addSplineObject(
                    name, trajectory, get_random_color(), granularity=500)
            else:
                print("No points to plot for joint", joint_name)
        return scene_object

    def get_joint_position(self, joint_name, frame):
        if joint_name in self._visualization.skeleton.nodes.keys():
            return self._visualization.skeleton.nodes[
                joint_name].get_global_position(frame)
        else:
            return None

    def get_skeleton_copy(self):
        return deepcopy(self._visualization.skeleton)

    def get_motion_vector_copy(self, start_frame=0, end_frame=-1):
        """Return a new MotionVector holding a copy of ``frames[start_frame:end_frame]``.

        A non-positive *end_frame* (the default -1) means "through the last
        frame". The copied frames are not shared with this controller's motion,
        and the copy carries the source frame time.
        """
        mv_copy = MotionVector()
        source_frames = self._motion.mv.frames
        if end_frame > 0:
            mv_copy.frames = deepcopy(source_frames[start_frame:end_frame])
        else:
            # start_frame used to be ignored in this branch; honour it so that
            # get_motion_vector_copy(start_frame=k) copies from frame k onwards.
            # With the default start_frame=0 this is the same full copy as before.
            mv_copy.frames = np.array(source_frames[start_frame:])
        mv_copy.n_frames = len(mv_copy.frames)
        mv_copy.frame_time = self._motion.mv.frame_time
        return mv_copy

    def get_current_frame(self):
        return self._motion.get_pose()

    def apply_delta_frame(self, skeleton, frame):
        self._motion.apply_delta_frame(skeleton, frame)

    def replace_current_frame(self, frame):
        self._motion.replace_current_frame(frame)
        self.updateTransformation()

    def replace_current_frames(self, frames):
        self._motion.replace_frames(frames)

    def replace_motion_from_file(self, filename):
        if filename.endswith(".bvh"):
            bvh_reader = BVHReader(filename)
            motion_vector = MotionVector()
            motion_vector.from_bvh_reader(bvh_reader, False)
            self._motion.replace_frames(motion_vector.frames)
            self.currentFrameNumber = 0
            self.updateTransformation()
            self.update_scene_object.emit(-1)
            self.updated_animation_frame.emit(self.currentFrameNumber)
        elif filename.endswith("_mg.zip"):
            self.scene_object.scene.attach_mg_state_machine(
                self.scene_object, filename)
            self._motion = self.scene_object._components[
                "morphablegraph_state_machine"]
            self._motion.set_target_skeleton(self._visualization.skeleton)
            self.activate_emit = False
        elif filename.endswith("amc"):
            amc_frames = parse_amc_file(filename)
            motion_vector = MotionVector()
            motion_vector.from_amc_data(self._visualization.skeleton,
                                        amc_frames)
            self._motion.replace_frames(motion_vector.frames)
            self._motion.mv.frame_time = 1.0 / 120
            self.currentFrameNumber = 0
            self.updateTransformation()
            self.update_scene_object.emit(-1)
            self.updated_animation_frame.emit(self.currentFrameNumber)

    def replace_motion_from_str(self, bvh_str):
        """Parse BVH text from *bvh_str* and replace the current motion's frames with it.

        Empty lines are dropped before parsing. The frame counter is not reset
        and no signals are emitted.
        """
        reader = BVHReader("")
        raw_lines = bvh_str.split("\n")
        print(len(raw_lines))
        reader.process_lines([line for line in raw_lines if line])
        parsed = MotionVector()
        parsed.from_bvh_reader(reader, False)
        self._motion.replace_frames(parsed.frames)

    def replace_skeleton_model(self, filename):
        """Load a skeleton model from a JSON file and install it on the visualized skeleton.

        The whole JSON document is used as the model; it is not wrapped under a key.
        """
        self.set_skeleton_model(load_json_file(filename))

    def set_skeleton_model(self, model):
        self._visualization.skeleton.skeleton_model = model

    def attach_animated_mesh_component(
            self, filename, animation_controller="animation_controller"):
        """Load a model from an FBX file and attach it as an "animated_mesh" component.

        The component is created on this controller's scene object and is given
        the name of the *animation_controller* component that should drive it.
        """
        model_data = load_model_from_fbx_file(filename)
        builder = self.scene_object.scene.object_builder
        builder.create_component("animated_mesh", self.scene_object, model_data,
                                 animation_controller)

    def get_bone_matrices(self):
        return self._visualization.matrices

    def set_color_annotation_from_labels(self, labels, colors):
        self._motion.set_color_annotation_from_labels(labels, colors)

    def set_reference_frame(self, frame_idx):
        self._visualization.skeleton.set_reference_frame(
            self._motion.get_pose(frame_idx))

    def get_semantic_annotation(self):
        return self._motion.get_semantic_annotation()

    def get_label_color_map(self):
        return self._motion.get_label_color_map()

    def isPlaying(self):
        return self._motion.play

    def stopAnimation(self):
        self._motion.play = False

    def startAnimation(self):
        self._motion.play = True

    def toggleAnimation(self):
        self._motion.play = not self._motion.play

    def setAnimationSpeed(self, speed):
        self.animationSpeed = speed

    def get_current_frame_idx(self):
        return self._motion.get_current_frame_idx()

    def toggle_animation(self):
        self._motion.play = not self._motion.play

    def set_frame_time(self, frame_time):
        self._motion.mv.frame_time = frame_time

    def get_frames(self):
        return self._motion.get_frames()

    def get_max_time(self):
        return self._motion.get_n_frames() * self._motion.get_frame_time()

    def get_frame_time(self):
        return self._motion.get_frame_time()

    def get_skeleton(self):
        return self._visualization.skeleton

    def set_time(self, t):
        self._motion.set_time(t)

    def get_animated_joints(self):
        return self._visualization.skeleton.animated_joints

    def add_skeleton_mirror(self, snapshot_interval=10):
        """Create a SkeletonMirrorComponent, register it as "skeleton_mirror" and return it."""
        mirror = SkeletonMirrorComponent(self.scene_object, self._visualization,
                                         snapshot_interval)
        self.scene_object._components["skeleton_mirror"] = mirror
        return mirror

    def set_ticker(self, tick):
        self._motion.set_ticker(tick)

    def replace_frames(self, frames):
        return self._motion.replace_frames(frames)

    def get_labeled_points(self):
        p = [m[:3, 3] for m in self._visualization.matrices]
        return self.get_animated_joints(), p