Example #1
    def generate_simulated_trajectory(self, robot_state_batch,
                                      human_state_batch, action_batch,
                                      next_human_state_batch):
        # next_state = robot_state.clone()
        # action_list = []
        # if self.kinematics == 'holonomic':
        #     for i in range(next_state.shape[0]):
        #         action = self.action_space[action_index[i]]
        #         action_list.append([action.vx, action.vy])
        #     action_tensor = torch.tensor(action_list)
        #     next_state[:, :, 0:2] = next_state[:, :, 0:2] + action_tensor * self.time_step
        #     next_state[:, :, 2:4] = action_tensor
        # else:
        #     for i in range(next_state.shape[0]):
        #         action = self.action_space[action_index[i]]
        #         action_list.append([action.v, action.r])
        #     action_tensor = torch.tensor(action_list)
        #     next_state[:, :, 8] = (next_state[:, :, 8] + action_tensor[:, 1]) % (2 * np.pi)
        #     next_state[:, :, 2] = np.cos(next_state[:, :, 8]) * action_tensor[:, 0]
        #     next_state[:, :, 3] = np.sin(next_state[:, :, 8]) * action_tensor[:, 0]
        #     next_state[:, :, 0:2] = next_state[:, :, 0:2] + next_state[:, :, 2:4] * self.time_step
        # return next_state
        expand_next_robot_state = None
        expand_reward = []
        expand_done = []
        for i in range(robot_state_batch.shape[0]):
            action = self.action_space[action_batch[i]]
            cur_robot_state = robot_state_batch[i, :, :]
            cur_human_state = human_state_batch[i, :, :]
            cur_state = tensor_to_joint_state(
                (cur_robot_state, cur_human_state))
            next_robot_state = self.compute_next_robot_state(
                cur_robot_state, action)
            next_human_state = next_human_state_batch[i, :, :]
            next_state = tensor_to_joint_state(
                (next_robot_state, next_human_state))
            reward, info = self.reward_estimator.estimate_reward_on_predictor(
                cur_state, next_state)
            expand_reward.append(reward)
            done = False
            if isinstance(info, ReachGoal) or isinstance(info, Collision):
                done = True
            expand_done.append(done)
            if expand_next_robot_state is None:
                expand_next_robot_state = next_robot_state
            else:
                expand_next_robot_state = torch.cat(
                    (expand_next_robot_state, next_robot_state), dim=0)
            # expand_next_robot_state.append(next_robot_state)
        # expand_next_robot_state = torch.Tensor(expand_next_robot_state)
        expand_reward = torch.Tensor(expand_reward).unsqueeze(dim=1)
        expand_done = torch.Tensor(expand_done).unsqueeze(dim=1)
        return expand_next_robot_state, expand_reward, expand_done
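    # Illustrative use of generate_simulated_trajectory above (the tensor shapes are assumptions,
    # not taken from the original code): roll a batch of B sampled transitions forward one step,
    # e.g. for a Bellman-style backup:
    #   next_robot_states, rewards, dones = self.generate_simulated_trajectory(
    #       robot_state_batch,        # (B, 1, robot_state_dim)
    #       human_state_batch,        # (B, num_humans, human_state_dim)
    #       action_batch,             # (B,) indices into self.action_space
    #       next_human_state_batch)   # (B, num_humans, human_state_dim), predicted next humans
    #   # rewards and dones come back as (B, 1) float tensors.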
    def estimate_reward(self, state, action):
        """ If the time step is small enough, it's okay to model agent as linear movement during this period

        """
        # collision detection
        if isinstance(state, list) or isinstance(state, tuple):
            state = tensor_to_joint_state(state)
        human_states = state.human_states
        robot_state = state.robot_state

        dmin = float('inf')
        collision = False
        for i, human in enumerate(human_states):
            px = human.px - robot_state.px
            py = human.py - robot_state.py
            if self.kinematics == 'holonomic':
                vx = human.vx - action.vx
                vy = human.vy - action.vy
            else:
                vx = human.vx - action.v * np.cos(action.r + robot_state.theta)
                vy = human.vy - action.v * np.sin(action.r + robot_state.theta)
            ex = px + vx * self.time_step
            ey = py + vy * self.time_step
            # closest distance between boundaries of two agents
            closest_dist = point_to_segment_dist(
                px, py, ex, ey, 0, 0) - human.radius - robot_state.radius
            if closest_dist < 0:
                collision = True
                break
            elif closest_dist < dmin:
                dmin = closest_dist

        # check if reaching the goal
        if self.kinematics == 'holonomic':
            px = robot_state.px + action.vx * self.time_step
            py = robot_state.py + action.vy * self.time_step
        else:
            theta = robot_state.theta + action.r
            px = robot_state.px + np.cos(theta) * action.v * self.time_step
            py = robot_state.py + np.sin(theta) * action.v * self.time_step

        end_position = np.array((px, py))
        reaching_goal = norm(
            end_position -
            np.array([robot_state.gx, robot_state.gy])) < robot_state.radius

        if collision:
            reward = -0.25
        elif reaching_goal:
            reward = 1
        elif dmin < 0.2:
            # adjust the reward based on FPS
            reward = (dmin - 0.2) * 0.5 * self.time_step
        else:
            reward = 0

        return reward
    def V_planning(self, state, depth, width):
        """ Plans n steps into future. Computes the value for the current state as well as the trajectories
        defined as a list of (state, action, reward) triples

        """

        current_state_value = self.value_estimator(state)
        if depth == 1:
            return current_state_value, [(state, None, None)]

        if self.do_action_clip:
            action_space_clipped = self.action_clip(state, self.action_space, width)
        else:
            action_space_clipped = self.action_space

        returns = []
        trajs = []
        actions = []
        if self.kinematics == "holonomic":
            actions.append(ActionXY(0, 0))
        else:
            actions.append(ActionRot(0, 0))
        # actions.append(ActionXY(0, 0))
        pre_next_state = self.state_predictor(state, actions)
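        # Note: the state predictor is queried once with a single "stay" action; the predicted
        # human states are shared across all candidate actions, and only the robot state is
        # propagated per action below.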
        for action in action_space_clipped:
            next_robot_state = self.compute_next_robot_state(state[0], action)
            next_state_est = next_robot_state, pre_next_state[1]
            # reward_est = self.estimate_reward(state, action)
            reward_est, _ = self.reward_estimator.estimate_reward_on_predictor(tensor_to_joint_state(state),
                                                                               tensor_to_joint_state(next_state_est))
            next_value, next_traj = self.V_planning(next_state_est, depth - 1, self.planning_width)
            return_value = current_state_value / depth + (depth - 1) / depth * (self.get_normalized_gamma() * next_value + reward_est)
            returns.append(return_value)
            trajs.append([(state, action, reward_est)] + next_traj)
        max_index = np.argmax(returns)
        max_return = returns[max_index]
        max_traj = trajs[max_index]
        return max_return, max_traj
    def action_clip(self, state, action_space, width, depth=1):
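        # One-step lookahead pruning: score every action by its estimated immediate reward plus
        # the discounted value of the predicted next state, then keep the `width` best actions
        # (at most one per coarse action group when sparse_search is enabled).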
        values = []
        actions = []
        if self.kinematics == "holonomic":
            actions.append(ActionXY(0, 0))
        else:
            actions.append(ActionRot(0, 0))
        # actions.append(ActionXY(0, 0))
        next_robot_states = None
        next_human_states = None
        pre_next_state = self.state_predictor(state, actions)
        for action in action_space:
            # actions = []
            # actions.append(action)
            next_robot_state = self.compute_next_robot_state(state[0], action)
            next_human_state = pre_next_state[1]
            if next_robot_states is None and next_human_states is None:
                next_robot_states = next_robot_state
                next_human_states = next_human_state
            else:
                next_robot_states = torch.cat((next_robot_states, next_robot_state), dim=0)
                next_human_states = torch.cat((next_human_states, next_human_state), dim=0)
            next_state_tensor = (next_robot_state, next_human_state)
            next_state = tensor_to_joint_state(next_state_tensor)
            reward_est, _ = self.reward_estimator.estimate_reward_on_predictor(state, next_state)
            values.append(reward_est)
        next_return = self.value_estimator((next_robot_states, next_human_states)).squeeze()
        next_return = np.array(next_return.data.detach())
        values = np.array(values) + self.get_normalized_gamma() * next_return
        values = values.tolist()

        if self.sparse_search:
            # self.sparse_speed_samples = 2
            # search in a sparse grained action space
            added_groups = set()
            max_indices = np.argsort(np.array(values))[::-1]
            clipped_action_space = []
            for index in max_indices:
                if self.action_group_index[index] not in added_groups:
                    clipped_action_space.append(action_space[index])
                    added_groups.add(self.action_group_index[index])
                    if len(clipped_action_space) == width:
                        break
        else:
            max_indexes = np.argpartition(np.array(values), -width)[-width:]
            clipped_action_space = [action_space[i] for i in max_indexes]

        # print(clipped_action_space)
        return clipped_action_space
    def render(self, mode='video', output_file=None):
        from matplotlib import animation
        import matplotlib.pyplot as plt
        # plt.rcParams['animation.ffmpeg_path'] = '/usr/bin/ffmpeg'
        x_offset = 0.2
        y_offset = 0.4
        cmap = plt.cm.get_cmap('hsv', 10)
        robot_color = 'black'
        arrow_style = patches.ArrowStyle("->", head_length=4, head_width=2)
        display_numbers = False

        if mode == 'traj':
            fig, ax = plt.subplots(figsize=(7, 7))
            ax.tick_params(labelsize=16)

            # add human start positions and goals
            human_colors = [cmap(i) for i in range(len(self.humans))]
            for i in range(len(self.humans)):
                human = self.humans[i]
                human_goal = mlines.Line2D([human.get_goal_position()[0]], [human.get_goal_position()[1]],
                                           color=human_colors[i],
                                           marker='*', linestyle='None', markersize=15)
                ax.add_artist(human_goal)
                human_start = mlines.Line2D([human.get_start_position()[0]], [human.get_start_position()[1]],
                                            color=human_colors[i],
                                            marker='o', linestyle='None', markersize=15)
                ax.add_artist(human_start)

            robot_positions = [self.states[i][0].position for i in range(len(self.states))]
            human_positions = [[self.states[i][1][j].position for j in range(len(self.humans))]
                               for i in range(len(self.states))]

            for k in range(len(self.states)):
                if k % 4 == 0 or k == len(self.states) - 1:
                    robot = plt.Circle(robot_positions[k], self.robot.radius, fill=False, color=robot_color)
                    humans = [plt.Circle(human_positions[k][i], self.humans[i].radius, fill=False, color=cmap(i))
                              for i in range(len(self.humans))]
                    ax.add_artist(robot)
                    for human in humans:
                        ax.add_artist(human)

                # add time annotation
                global_time = k * self.time_step
                if global_time % 4 == 0 or k == len(self.states) - 1:
                    agents = humans + [robot]
                    times = [plt.text(agents[i].center[0] - x_offset, agents[i].center[1] - y_offset,
                                      '{:.1f}'.format(global_time),
                                      color='black', fontsize=14) for i in range(self.human_num + 1)]
                    for time in times:
                        ax.add_artist(time)
                if k != 0:
                    nav_direction = plt.Line2D((self.states[k - 1][0].px, self.states[k][0].px),
                                               (self.states[k - 1][0].py, self.states[k][0].py),
                                               color=robot_color, ls='solid')
                    human_directions = [plt.Line2D((self.states[k - 1][1][i].px, self.states[k][1][i].px),
                                                   (self.states[k - 1][1][i].py, self.states[k][1][i].py),
                                                   color=cmap(i), ls='solid')
                                        for i in range(self.human_num)]
                    ax.add_artist(nav_direction)
                    for human_direction in human_directions:
                        ax.add_artist(human_direction)
            plt.legend([robot], ['Robot'], fontsize=16)
            plt.show()
        elif mode == 'video':
            fig, ax = plt.subplots(figsize=(8, 8))
            ax.tick_params(labelsize=12)
            ax.set_xlim(-11, 11)
            ax.set_ylim(-11, 11)
            ax.set_xlabel('x(m)', fontsize=14)
            ax.set_ylabel('y(m)', fontsize=14)
            show_human_start_goal = True
            show_sensor_range = True
            show_eval_info = True
            show_social_zone = True

            # add human start positions and goals
            human_colors = [cmap(i) for i in range(len(self.humans))]
            if show_human_start_goal:
                for i in range(len(self.humans)):
                    human = self.humans[i]
                    human_goal = mlines.Line2D([human.get_goal_position()[0]], [human.get_goal_position()[1]],
                                               color=human_colors[i],
                                               marker='*', linestyle='None', markersize=8)
                    ax.add_artist(human_goal)
                    human_start = mlines.Line2D([human.get_start_position()[0]], [human.get_start_position()[1]],
                                                color=human_colors[i],
                                                marker='o', linestyle='None', markersize=8)
                    ax.add_artist(human_start)
            # add robot start position
            robot_start = mlines.Line2D([self.robot.get_start_position()[0]], [self.robot.get_start_position()[1]],
                                        color=robot_color,
                                        marker='o', linestyle='None', markersize=8, label='Start')
            robot_start_position = [self.robot.get_start_position()[0], self.robot.get_start_position()[1]]
            ax.add_artist(robot_start)
            # add robot and its goal 
            robot_positions = [state[0].position for state in self.states]
            goal = mlines.Line2D([self.robot.get_goal_position()[0]], [self.robot.get_goal_position()[1]],
                                 color=robot_color, marker='*', linestyle='None',
                                 markersize=15, label='Goal')
            robot = plt.Circle(robot_positions[0], self.robot.radius, fill=False, color=robot_color)
            ax.add_artist(robot)
            ax.add_artist(goal)
            plt.legend([robot, goal, robot_start], ['Robot', 'Goal', 'Start'], fontsize=14)
            # if show_sensor_range:
            #     sensor_range = plt.Circle(robot_positions[0], self.robot_sensor_range, fill=False, ls='dashed')
            #     ax.add_artist(sensor_range)


            # add humans and their numbers
            human_positions = [[state[1][j].position for j in range(len(self.humans))] for state in self.states]
            humans = [plt.Circle(human_positions[0][i], self.humans[i].radius, fill=False, color=cmap(i))
                      for i in range(len(self.humans))]

            # optionally show human numbers above each agent
            if display_numbers:
                human_numbers = [plt.text(humans[i].center[0] - x_offset, humans[i].center[1] + y_offset, str(i),
                                          color='black') for i in range(len(self.humans))]
            
            for i, human in enumerate(humans):
                ax.add_artist(human)
                if display_numbers:
                    ax.add_artist(human_numbers[i])

            # add time annotation
            time = plt.text(0.5, 0.9, f'Time: {0}', fontsize=16, transform=ax.transAxes, horizontalalignment='center',
                verticalalignment='center')
            ax.add_artist(time)

            # add evaluation annotation
            if show_eval_info:
                eval_text = plt.text(0.6, 0.07, 
                    f"Aggregated Time: {0}\nMinimum Separation: {0}\nSocial Zone Violations: {0}\nJerk Cost: {0}",
                    fontsize=12, transform=ax.transAxes, horizontalalignment='left', verticalalignment='center')

            # calculate evaluation information
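            # running aggregates per frame: cumulative times, violation counts and jerk cost,
            # plus the running minimum separation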
            list_aggregated_time = [self.infos[0]['aggregated_time']]
            list_min_separation = [self.infos[0]['min_separation']]
            list_personal_violation_cnt = [self.infos[0]['personal_violation_cnt']]
            list_social_violation_cnt = [self.infos[0]['social_violation_cnt']]
            list_jerk_cost = [self.infos[0]['jerk_cost']]
            for i in range(1, len(self.infos)):
                list_aggregated_time.append(list_aggregated_time[i-1] + self.infos[i]['aggregated_time'])
                list_min_separation.append(min(list_min_separation[i-1], self.infos[i]['min_separation']))
                list_social_violation_cnt.append(list_social_violation_cnt[i-1] + self.infos[i]['social_violation_cnt'])
                list_jerk_cost.append(list_jerk_cost[i-1] + self.infos[i]['jerk_cost'])
                list_personal_violation_cnt.append(list_personal_violation_cnt[i-1] + self.infos[i]['personal_violation_cnt'])

            # visualize attention scores
            # if hasattr(self.robot.policy, 'get_attention_weights'):
            #     attention_scores = [
            #         plt.text(-5.5, 5 - 0.5 * i, 'Human {}: {:.2f}'.format(i + 1, self.attention_weights[0][i]),
            #                  fontsize=16) for i in range(len(self.humans))]

            # compute social zone for each step
            social_zones_all_agents = []
            
            if show_social_zone:
                for i in range(self.human_num + 1):
                    social_zones = []
                    step_cnt = 0
                    for state in self.states:
                        step_cnt += 1
                        agent_state = state[0] if i == self.human_num else state[1][i]
                        if i == self.human_num: # robot
                            rect = AgentHeadingRect(agent_state.px, agent_state.py, self.robot.radius, agent_state.vx, agent_state.vy, self.robot.kinematics)
                            if step_cnt < len(self.infos) and self.infos[step_cnt]['social_violation_cnt'] > 0:
                                rect.color = 'red'
                        else:
                            rect = AgentHeadingRect(agent_state.px, agent_state.py, self.humans[i].radius, agent_state.vx, agent_state.vy, self.humans[i].kinematics)
                        social_zones.append(rect.get_pyplot_rect())
                    social_zones_all_agents.append(social_zones)

            # draw the zones for the first step
            social_zones_drawn = []
            for zones in social_zones_all_agents:
                ax.add_artist(zones[0])
                social_zones_drawn.append(zones[0])

            # compute orientation in each step and use arrow to show the direction
            radius = self.robot.radius
            orientations = []
            for i in range(self.human_num + 1):
                orientation = []
                for state in self.states:
                    agent_state = state[0] if i == 0 else state[1][i - 1]
                    if self.robot.kinematics == 'unicycle' and i == 0:  # TODO: why unicycle only?
                        direction = (
                        (agent_state.px, agent_state.py), (agent_state.px + radius * np.cos(agent_state.theta),
                                                           agent_state.py + radius * np.sin(agent_state.theta)))
                    else:
                        theta = np.arctan2(agent_state.vy, agent_state.vx)
                        direction = ((agent_state.px, agent_state.py), (agent_state.px + radius * np.cos(theta),
                                                                        agent_state.py + radius * np.sin(theta)))
                    orientation.append(direction)
                orientations.append(orientation)
                if i == 0:
                    arrow_color = 'black'
                    arrows = [patches.FancyArrowPatch(*orientation[0], color=arrow_color, arrowstyle=arrow_style)]
                else:
                    arrows.extend(
                        [patches.FancyArrowPatch(*orientation[0], color=human_colors[i - 1], arrowstyle=arrow_style)])
            for arrow in arrows:
                ax.add_artist(arrow)

            global_step = 0

            if len(self.trajs) != 0:
                human_future_positions = []
                human_future_circles = []
                for traj in self.trajs:
                    human_future_position = [[tensor_to_joint_state(traj[step+1][0]).human_states[i].position
                                              for step in range(self.robot.policy.planning_depth)]
                                             for i in range(self.human_num)]
                    human_future_positions.append(human_future_position)

                for i in range(self.human_num):
                    circles = []
                    for j in range(self.robot.policy.planning_depth):
                        circle = plt.Circle(human_future_positions[0][i][j], self.humans[0].radius/(1.7+j), fill=False, color=cmap(i))
                        ax.add_artist(circle)
                        circles.append(circle)
                    human_future_circles.append(circles)

            def update(frame_num):
                nonlocal global_step
                nonlocal arrows
                nonlocal social_zones_drawn
                global_step = frame_num
                robot.center = robot_positions[frame_num]

                for i, human in enumerate(humans):
                    human.center = human_positions[frame_num][i]
                    if display_numbers:
                        human_numbers[i].set_position((human.center[0] - x_offset, human.center[1] + y_offset))
                for arrow in arrows:
                    arrow.remove()
                for zone in social_zones_drawn: # remove last step's social zones
                    zone.remove()

                # draw social zones for each step
                if show_social_zone:
                    social_zones_drawn = []
                    for i in range(self.human_num + 1):
                        zones = social_zones_all_agents[i]
                        social_zones_drawn.append(zones[frame_num])
                        ax.add_artist(zones[frame_num])

                for i in range(self.human_num + 1):
                    orientation = orientations[i]
                    if i == 0:
                        arrows = [patches.FancyArrowPatch(*orientation[frame_num], color='black',
                                                          arrowstyle=arrow_style)]
                    else:
                        arrows.extend([patches.FancyArrowPatch(*orientation[frame_num], color=cmap(i - 1),
                                                               arrowstyle=arrow_style)])

                for arrow in arrows:
                    ax.add_artist(arrow)
                    # if hasattr(self.robot.policy, 'get_attention_weights'):
                    #     attention_scores[i].set_text('human {}: {:.2f}'.format(i, self.attention_weights[frame_num][i]))

                time.set_text('Time: {:.2f}'.format(frame_num * self.time_step))

                if show_eval_info:
                    eval_text.set_text(f"Aggregated Time: {list_aggregated_time[frame_num]}\
                        \nPersonal Zone Violations: {list_personal_violation_cnt[frame_num]}\
                        \nSocial Zone Violations: {list_social_violation_cnt[frame_num]}\
                        \nJerk Cost: {list_jerk_cost[frame_num]: .3f}")

                if len(self.trajs) != 0:
                    for i, circles in enumerate(human_future_circles):
                        for j, circle in enumerate(circles):
                            circle.center = human_future_positions[global_step][i][j]

            def plot_value_heatmap():
                if self.robot.kinematics != 'holonomic':
                    print('Kinematics is not holonomic')
                    return
                # for agent in [self.states[global_step][0]] + self.states[global_step][1]:
                #     print(('{:.4f}, ' * 6 + '{:.4f}').format(agent.px, agent.py, agent.gx, agent.gy,
                #                                              agent.vx, agent.vy, agent.theta))

                # when any key is pressed draw the action value plot
                fig, axis = plt.subplots()
                speeds = [0] + self.robot.policy.speeds
                rotations = self.robot.policy.rotations + [np.pi * 2]
                r, th = np.meshgrid(speeds, rotations)
                z = np.array(self.action_values[global_step % len(self.states)][1:])
                z = (z - np.min(z)) / (np.max(z) - np.min(z))
                z = np.reshape(z, (self.robot.policy.rotation_samples, self.robot.policy.speed_samples))
                polar = plt.subplot(projection="polar")
                polar.tick_params(labelsize=16)
                mesh = plt.pcolormesh(th, r, z, vmin=0, vmax=1)
                plt.plot(rotations, r, color='k', ls='none')
                plt.grid()
                cbaxes = fig.add_axes([0.85, 0.1, 0.03, 0.8])
                cbar = plt.colorbar(mesh, cax=cbaxes)
                cbar.ax.tick_params(labelsize=16)
                plt.show()

            def print_matrix_A():
                # with np.printoptions(precision=3, suppress=True):
                #     print(self.As[global_step])
                h, w = self.As[global_step].shape
                print('   ' + ' '.join(['{:>5}'.format(i - 1) for i in range(w)]))
                for i in range(h):
                    print('{:<3}'.format(i-1) + ' '.join(['{:.3f}'.format(self.As[global_step][i][j]) for j in range(w)]))
                # with np.printoptions(precision=3, suppress=True):
                #     print('A is: ')
                #     print(self.As[global_step])

            def print_feat():
                with np.printoptions(precision=3, suppress=True):
                    print('feat is: ')
                    print(self.feats[global_step])

            def print_X():
                with np.printoptions(precision=3, suppress=True):
                    print('X is: ')
                    print(self.Xs[global_step])

            def on_click(event):
                if anim.running:
                    anim.event_source.stop()
                    if event.key == 'a':
                        if hasattr(self.robot.policy, 'get_matrix_A'):
                            print_matrix_A()
                        if hasattr(self.robot.policy, 'get_feat'):
                            print_feat()
                        if hasattr(self.robot.policy, 'get_X'):
                            print_X()
                        # if hasattr(self.robot.policy, 'action_values'):
                        #    plot_value_heatmap()
                else:
                    anim.event_source.start()
                anim.running ^= True

            fig.canvas.mpl_connect('key_press_event', on_click)
            anim = animation.FuncAnimation(fig, update, frames=len(self.states), interval=self.time_step * 500, repeat_delay=500)
            anim.running = True

            if output_file is not None:
                # save as video
                # ffmpeg_writer = animation.FFMpegWriter(fps=10, metadata=dict(artist='Me'), bitrate=1800)
                # writer = ffmpeg_writer(fps=10, metadata=dict(artist='Me'), bitrate=1800)
                # anim.save(output_file, writer=ffmpeg_writer)

                # save output file as gif if imagemagick is installed
                plt.rcParams["animation.convert_path"] = r'/usr/bin/convert'
                anim.save(output_file, writer='imagemagick', fps=12)
            else:
                plt.show()
        else:
            raise NotImplementedError
    def predict(self, state):
        """
        A base class for all methods that take the pairwise joint state as input to the value network.
        The input to the value network is always of shape (batch_size, # humans, rotated joint state length).

        """
        if self.phase is None or self.device is None:
            raise AttributeError('Phase, device attributes have to be set!')
        if self.phase == 'train' and self.epsilon is None:
            raise AttributeError('Epsilon attribute has to be set in training phase')

        if self.reach_destination(state):
            return ActionXY(0, 0) if self.kinematics == 'holonomic' else ActionRot(0, 0)
        if self.action_space is None:
            self.build_action_space(state.robot_state.v_pref)

        probability = np.random.random()
        if self.phase == 'train' and probability < self.epsilon:
            max_action_index = np.random.choice(len(self.action_space))
            max_action = self.action_space[max_action_index]
        else:
            max_action = None
            max_value = float('-inf')
            max_traj = None

            if self.do_action_clip:
                state_tensor = state.to_tensor(add_batch_size=True, device=self.device)
                action_space_clipped = self.action_clip(state_tensor, self.action_space, self.planning_width)
            else:
                action_space_clipped = self.action_space
            state_tensor = state.to_tensor(add_batch_size=True, device=self.device)
            actions = []
            if self.kinematics == "holonomic":
                actions.append(ActionXY(0, 0))
            else:
                actions.append(ActionRot(0, 0))
            # actions.append(ActionXY(0, 0))
            pre_next_state = self.state_predictor(state_tensor, actions)
            next_robot_states = None
            next_human_states = None
            next_value = []
            rewards = []
            for action in action_space_clipped:
                next_robot_state = self.compute_next_robot_state(state_tensor[0], action)
                next_human_state = pre_next_state[1]
                if next_robot_states is None and next_human_states is None:
                    next_robot_states = next_robot_state
                    next_human_states = next_human_state
                else:
                    next_robot_states = torch.cat((next_robot_states, next_robot_state), dim=0)
                    next_human_states = torch.cat((next_human_states, next_human_state), dim=0)
                next_state = tensor_to_joint_state((next_robot_state, next_human_state))
                reward_est, _ = self.reward_estimator.estimate_reward_on_predictor(state, next_state)
                max_next_return, max_next_traj = self.V_planning((next_robot_state, next_human_state), self.planning_depth, self.planning_width)
                value = reward_est + self.get_normalized_gamma() * max_next_return
                if value > max_value:
                    max_value = value
                    max_action = action
                    max_traj = [(state_tensor, action, reward_est)] + max_next_traj
                # reward_est = self.estimate_reward(state, action)
                # rewards.append(reward_est)
                # next_state = self.state_predictor(state_tensor, action)
            # rewards_tensor = torch.tensor(rewards).to(self.device)
            # next_state_batch = (next_robot_states, next_human_states)
            # next_value = self.value_estimator(next_state_batch).squeeze(1)
            # value = rewards_tensor + next_value * self.get_normalized_gamma()
            # max_action_index = value.argmax()
            # best_value = value[max_action_index]
            # if best_value > max_value:
            #     max_action = action_space_clipped[max_action_index]
            #
            #     next_state = tensor_to_joint_state((next_robot_states[max_action_index], next_human_states[max_action_index]))
            #     max_next_traj = [(next_state.to_tensor(), None, None)]
            #     # max_next_return, max_next_traj = self.V_planning(next_state, self.planning_depth, self.planning_width)
            #     # reward_est = self.estimate_reward(state, action)
            #     # value = reward_est + self.get_normalized_gamma() * max_next_return
            #     # if value > max_value:
            #     #     max_value = value
            #     #     max_action = action
            #     max_traj = [(state_tensor, max_action, rewards[max_action_index])] + max_next_traj
            if max_action is None:
                raise ValueError('Value network is not well trained.')

        if self.phase == 'train':
            self.last_state = self.transform(state)
        else:
            self.traj = max_traj
        for action_index in range(len(self.action_space)):
            action = self.action_space[action_index]
            if action is max_action:
                max_action_index = action_index
                break
        return max_action, int(max_action_index)
Example #7
 def estimate_reward_on_predictor(self, state, next_state):
     """ If the time step is small enough, it's okay to model agent as linear movement during this period
     """
     # collision detection
     info = Nothing()
     if isinstance(state, list) or isinstance(state, tuple):
         state = tensor_to_joint_state(state)
     human_states = state.human_states
     robot_state = state.robot_state
     weight_goal = self.goal_factor
     weight_safe = self.discomfort_penalty_factor
     weight_terminal = 1.0
     re_collision = self.collision_penalty
     re_arrival = self.success_reward
     next_robot_state = next_state.robot_state
     next_human_states = next_state.human_states
     cur_position = np.array((robot_state.px, robot_state.py))
     end_position = np.array((next_robot_state.px, next_robot_state.py))
     goal_position = np.array((robot_state.gx, robot_state.gy))
     reward_goal = (norm(cur_position - goal_position) -
                    norm(end_position - goal_position))
     # check if reaching the goal
     reaching_goal = norm(end_position - goal_position) < robot_state.radius
     dmin = float('inf')
     collision = False
     safety_penalty = 0
     if human_states is None or next_human_states is None:
         safety_penalty = 0.0
         collision = False
     else:
         for i, human in enumerate(human_states):
             next_human = next_human_states[i]
             px = human.px - robot_state.px
             py = human.py - robot_state.py
             ex = next_human.px - next_robot_state.px
             ey = next_human.py - next_robot_state.py
             # closest distance between boundaries of two agents
             closest_dist = point_to_segment_dist(
                 px, py, ex, ey, 0, 0) - human.radius - robot_state.radius
             if closest_dist < 0:
                 collision = True
             if closest_dist < dmin:
                 dmin = closest_dist
             if closest_dist < self.discomfort_dist:
                 safety_penalty = safety_penalty + (closest_dist -
                                                    self.discomfort_dist)
         # dis_begin = np.sqrt(px ** 2 + py ** 2) - human.radius - robot_state.radius
         # dis_end = np.sqrt(ex ** 2 + ey ** 2) - human.radius - robot_state.radius
         # penalty_begin = 0
         # penalty_end = 0
         # discomfort_dist = 0.5
         # if dis_begin < discomfort_dist:
         #     penalty_begin = dis_begin - discomfort_dist
         # if dis_end < discomfort_dist:
         #     penalty_end = dis_end - discomfort_dist
         # safety_penalty = safety_penalty + (penalty_end - penalty_begin)
     reward_col = 0
     reward_arrival = 0
     if collision:
         reward_col = re_collision
         info = Collision()
     elif reaching_goal:
         reward_arrival = re_arrival
         info = ReachGoal()
     reward_terminal = reward_col + reward_arrival
     reward = weight_terminal * reward_terminal + weight_goal * reward_goal + weight_safe * safety_penalty
     # if collision:
     # reward = reward - 100
     return reward, info
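 # Summary of the shaping above: the step reward combines a terminal term (collision penalty or
 # arrival bonus), progress toward the goal, and a discomfort penalty accumulated over all humans
 # that come closer than discomfort_dist:
 #   reward = w_terminal * (r_col + r_arrival) + w_goal * (d_goal(s) - d_goal(s')) + w_safe * sum_i min(0, d_i - d_discomfort)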
Example #8
    def V_planning(self, state, depth, width):
        """ Plans n steps into future based on state action value function. Computes the value for the current state as well as the trajectories
        defined as a list of (state, action, reward) triples
        """
        # current_state_value = self.value_estimator(state)
        robot_state_batch = state[0]
        human_state_batch = state[1]
        if state[1] is None:
            if depth == 0:
                q_value = torch.Tensor(self.value_estimator(state))
                max_action_value, max_action_indexes = torch.max(q_value,
                                                                 dim=1)
                trajs = []
                for i in range(robot_state_batch.shape[0]):
                    cur_state = (robot_state_batch[i, :, :].unsqueeze(0), None)
                    trajs.append([(cur_state, None, None)])
                return max_action_value, max_action_indexes, trajs
            else:
                q_value = torch.Tensor(self.value_estimator(state))
                max_action_value, max_action_indexes = torch.topk(q_value,
                                                                  width,
                                                                  dim=1)
            action_stay = []
            for i in range(robot_state_batch.shape[0]):
                if self.kinematics == "holonomic":
                    action_stay.append(ActionXY(0, 0))
                else:
                    action_stay.append(ActionRot(0, 0))
            pre_next_state = None
            next_robot_state_batch = None
            next_human_state_batch = None
            reward_est = torch.zeros(state[0].shape[0], width)

            for i in range(robot_state_batch.shape[0]):
                cur_state = (robot_state_batch[i, :, :].unsqueeze(0), None)
                next_human_state = None
                for j in range(width):
                    cur_action = self.action_space[max_action_indexes[i][j]]
                    next_robot_state = self.compute_next_robot_state(
                        cur_state[0], cur_action)
                    if next_robot_state_batch is None:
                        next_robot_state_batch = next_robot_state
                    else:
                        next_robot_state_batch = torch.cat(
                            (next_robot_state_batch, next_robot_state), dim=0)
                    reward_est[i][
                        j], _ = self.reward_estimator.estimate_reward_on_predictor(
                            tensor_to_joint_state(cur_state),
                            tensor_to_joint_state(
                                (next_robot_state, next_human_state)))

            next_state_batch = (next_robot_state_batch, next_human_state_batch)
            if self.planning_depth - depth >= 2 and self.planning_depth > 2:
                cur_width = 1
            else:
                cur_width = int(self.planning_width / 2)
            next_values, next_action_indexes, next_trajs = self.V_planning(
                next_state_batch, depth - 1, cur_width)
            next_values = next_values.view(state[0].shape[0], width)
            returns = (reward_est + self.get_normalized_gamma() * next_values +
                       max_action_value) / 2

            max_action_return, max_action_index = torch.max(returns, dim=1)
            trajs = []
            max_returns = []
            max_actions = []
            for i in range(robot_state_batch.shape[0]):
                cur_state = (robot_state_batch[i, :, :].unsqueeze(0), None)
                action_id = max_action_index[i]
                trajs_id = i * width + action_id
                action = max_action_indexes[i][action_id]
                next_traj = next_trajs[trajs_id]
                trajs.append([(cur_state, action, reward_est[i][action_id])] + next_traj)
                max_returns.append(max_action_return[i].data)
                max_actions.append(action)
            max_returns = torch.tensor(max_returns)
            return max_returns, max_actions, trajs
        else:
            if depth == 0:
                q_value = torch.Tensor(self.value_estimator(state))
                max_action_value, max_action_indexes = torch.max(q_value,
                                                                 dim=1)
                trajs = []
                for i in range(robot_state_batch.shape[0]):
                    cur_state = (robot_state_batch[i, :, :].unsqueeze(0),
                                 human_state_batch[i, :, :].unsqueeze(0))
                    trajs.append([(cur_state, None, None)])
                return max_action_value, max_action_indexes, trajs
            else:
                q_value = torch.Tensor(self.value_estimator(state))
                max_action_value, max_action_indexes = torch.topk(q_value,
                                                                  width,
                                                                  dim=1)
            action_stay = []
            for i in range(robot_state_batch.shape[0]):
                if self.kinematics == "holonomic":
                    action_stay.append(ActionXY(0, 0))
                else:
                    action_stay.append(ActionRot(0, 0))
            _, pre_next_state = self.state_predictor(state, action_stay)
            next_robot_state_batch = None
            next_human_state_batch = None
            reward_est = torch.zeros(state[0].shape[0], width)

            for i in range(robot_state_batch.shape[0]):
                cur_state = (robot_state_batch[i, :, :].unsqueeze(0),
                             human_state_batch[i, :, :].unsqueeze(0))
                next_human_state = pre_next_state[i, :, :].unsqueeze(0)
                for j in range(width):
                    cur_action = self.action_space[max_action_indexes[i][j]]
                    next_robot_state = self.compute_next_robot_state(
                        cur_state[0], cur_action)
                    if next_robot_state_batch is None:
                        next_robot_state_batch = next_robot_state
                        next_human_state_batch = next_human_state
                    else:
                        next_robot_state_batch = torch.cat(
                            (next_robot_state_batch, next_robot_state), dim=0)
                        next_human_state_batch = torch.cat(
                            (next_human_state_batch, next_human_state), dim=0)
                    reward_est[i][
                        j], _ = self.reward_estimator.estimate_reward_on_predictor(
                            tensor_to_joint_state(cur_state),
                            tensor_to_joint_state(
                                (next_robot_state, next_human_state)))
            next_state_batch = (next_robot_state_batch, next_human_state_batch)
            if self.planning_depth - depth >= 2 and self.planning_depth > 2:
                cur_width = 1
            else:
                cur_width = int(self.planning_width / 2)
            next_values, next_action_indexes, next_trajs = self.V_planning(
                next_state_batch, depth - 1, cur_width)
            next_values = next_values.view(state[0].shape[0], width)
            returns = (reward_est + self.get_normalized_gamma() * next_values +
                       max_action_value) / 2

            max_action_return, max_action_index = torch.max(returns, dim=1)
            trajs = []
            max_returns = []
            max_actions = []
            for i in range(robot_state_batch.shape[0]):
                cur_state = (robot_state_batch[i, :, :].unsqueeze(0),
                             human_state_batch[i, :, :].unsqueeze(0))
                action_id = max_action_index[i]
                trajs_id = i * width + action_id
                action = max_action_indexes[i][action_id]
                next_traj = next_trajs[trajs_id]
                trajs.append([(cur_state, action, reward_est[i][action_id])] + next_traj)
                max_returns.append(max_action_return[i].data)
                max_actions.append(action)
            max_returns = torch.tensor(max_returns)
            return max_returns, max_actions, trajs
    def render(self, mode="video", output_file=None):
        from matplotlib import animation
        import matplotlib.pyplot as plt

        # plt.rcParams['animation.ffmpeg_path'] = '/usr/bin/ffmpeg'
        x_offset = 0.2
        y_offset = 0.4
        cmap = plt.cm.get_cmap("hsv", 10)
        robot_color = "black"
        arrow_style = patches.ArrowStyle("->", head_length=4, head_width=2)
        display_numbers = True

        if mode == "traj":
            fig, ax = plt.subplots(figsize=(7, 7))
            ax.tick_params(labelsize=16)
            ax.set_xlim(-5, 5)
            ax.set_ylim(-5, 5)
            ax.set_xlabel("x(m)", fontsize=16)
            ax.set_ylabel("y(m)", fontsize=16)

            # add human start positions and goals
            human_colors = [cmap(i) for i in range(len(self.humans))]
            for i in range(len(self.humans)):
                human = self.humans[i]
                human_goal = mlines.Line2D(
                    [human.get_goal_position()[0]],
                    [human.get_goal_position()[1]],
                    color=human_colors[i],
                    marker="*",
                    linestyle="None",
                    markersize=15,
                )
                ax.add_artist(human_goal)
                human_start = mlines.Line2D(
                    [human.get_start_position()[0]],
                    [human.get_start_position()[1]],
                    color=human_colors[i],
                    marker="o",
                    linestyle="None",
                    markersize=15,
                )
                ax.add_artist(human_start)

            robot_positions = [
                self.states[i][0].position for i in range(len(self.states))
            ]
            human_positions = [
                [self.states[i][1][j].position for j in range(len(self.humans))]
                for i in range(len(self.states))
            ]

            for k in range(len(self.states)):
                if k % 4 == 0 or k == len(self.states) - 1:
                    robot = plt.Circle(
                        robot_positions[k],
                        self.robot.radius,
                        fill=False,
                        color=robot_color,
                    )
                    humans = [
                        plt.Circle(
                            human_positions[k][i],
                            self.humans[i].radius,
                            fill=False,
                            color=cmap(i),
                        )
                        for i in range(len(self.humans))
                    ]
                    ax.add_artist(robot)
                    for human in humans:
                        ax.add_artist(human)

                # add time annotation
                global_time = k * self.time_step
                if global_time % 4 == 0 or k == len(self.states) - 1:
                    agents = humans + [robot]
                    times = [
                        plt.text(
                            agents[i].center[0] - x_offset,
                            agents[i].center[1] - y_offset,
                            "{:.1f}".format(global_time),
                            color="black",
                            fontsize=14,
                        )
                        for i in range(self.human_num + 1)
                    ]
                    for time in times:
                        ax.add_artist(time)
                if k != 0:
                    nav_direction = plt.Line2D(
                        (self.states[k - 1][0].px, self.states[k][0].px),
                        (self.states[k - 1][0].py, self.states[k][0].py),
                        color=robot_color,
                        ls="solid",
                    )
                    human_directions = [
                        plt.Line2D(
                            (self.states[k - 1][1][i].px, self.states[k][1][i].px),
                            (self.states[k - 1][1][i].py, self.states[k][1][i].py),
                            color=cmap(i),
                            ls="solid",
                        )
                        for i in range(self.human_num)
                    ]
                    ax.add_artist(nav_direction)
                    for human_direction in human_directions:
                        ax.add_artist(human_direction)
            plt.legend([robot], ["Robot"], fontsize=16)
            plt.show()
        elif mode == "video":
            fig, ax = plt.subplots(figsize=(7, 7))
            ax.tick_params(labelsize=12)
            ax.set_xlim(-11, 11)
            ax.set_ylim(-11, 11)
            ax.set_xlabel("x(m)", fontsize=14)
            ax.set_ylabel("y(m)", fontsize=14)
            show_human_start_goal = False

            # add human start positions and goals
            human_colors = [cmap(i) for i in range(len(self.humans))]
            if show_human_start_goal:
                for i in range(len(self.humans)):
                    human = self.humans[i]
                    human_goal = mlines.Line2D(
                        [human.get_goal_position()[0]],
                        [human.get_goal_position()[1]],
                        color=human_colors[i],
                        marker="*",
                        linestyle="None",
                        markersize=8,
                    )
                    ax.add_artist(human_goal)
                    human_start = mlines.Line2D(
                        [human.get_start_position()[0]],
                        [human.get_start_position()[1]],
                        color=human_colors[i],
                        marker="o",
                        linestyle="None",
                        markersize=8,
                    )
                    ax.add_artist(human_start)
            # add robot start position
            robot_start = mlines.Line2D(
                [self.robot.get_start_position()[0]],
                [self.robot.get_start_position()[1]],
                color=robot_color,
                marker="o",
                linestyle="None",
                markersize=8,
            )
            robot_start_position = [
                self.robot.get_start_position()[0],
                self.robot.get_start_position()[1],
            ]
            ax.add_artist(robot_start)
            # add robot and its goal
            robot_positions = [state[0].position for state in self.states]
            goal = mlines.Line2D(
                [self.robot.get_goal_position()[0]],
                [self.robot.get_goal_position()[1]],
                color=robot_color,
                marker="*",
                linestyle="None",
                markersize=15,
                label="Goal",
            )
            robot = plt.Circle(
                robot_positions[0], self.robot.radius, fill=False, color=robot_color
            )
            # sensor_range = plt.Circle(robot_positions[0], self.robot_sensor_range, fill=False, ls='dashed')
            ax.add_artist(robot)
            ax.add_artist(goal)
            plt.legend([robot, goal], ["Robot", "Goal"], fontsize=14)

            # add humans and their numbers
            human_positions = [
                [state[1][j].position for j in range(len(self.humans))]
                for state in self.states
            ]
            humans = [
                plt.Circle(
                    human_positions[0][i],
                    self.humans[i].radius,
                    fill=False,
                    color=cmap(i),
                )
                for i in range(len(self.humans))
            ]

            # optionally show human numbers above each agent
            if display_numbers:
                human_numbers = [
                    plt.text(
                        humans[i].center[0] - x_offset,
                        humans[i].center[1] + y_offset,
                        str(i),
                        color="black",
                    )
                    for i in range(len(self.humans))
                ]

            for i, human in enumerate(humans):
                ax.add_artist(human)
                if display_numbers:
                    ax.add_artist(human_numbers[i])

            # add time annotation
            time = plt.text(
                0.4, 0.9, "Time: {}".format(0), fontsize=16, transform=ax.transAxes
            )
            ax.add_artist(time)

            # visualize attention scores
            # if hasattr(self.robot.policy, 'get_attention_weights'):
            #     attention_scores = [
            #         plt.text(-5.5, 5 - 0.5 * i, 'Human {}: {:.2f}'.format(i + 1, self.attention_weights[0][i]),
            #                  fontsize=16) for i in range(len(self.humans))]

            # compute orientation in each step and use arrow to show the direction
            radius = self.robot.radius
            orientations = []
            for i in range(self.human_num + 1):
                orientation = []
                for state in self.states:
                    agent_state = state[0] if i == 0 else state[1][i - 1]
                    if self.robot.kinematics == "unicycle" and i == 0:
                        direction = (
                            (agent_state.px, agent_state.py),
                            (
                                agent_state.px + radius * np.cos(agent_state.theta),
                                agent_state.py + radius * np.sin(agent_state.theta),
                            ),
                        )
                    else:
                        theta = np.arctan2(agent_state.vy, agent_state.vx)
                        direction = (
                            (agent_state.px, agent_state.py),
                            (
                                agent_state.px + radius * np.cos(theta),
                                agent_state.py + radius * np.sin(theta),
                            ),
                        )
                    orientation.append(direction)
                orientations.append(orientation)
                if i == 0:
                    arrow_color = "black"
                    arrows = [
                        patches.FancyArrowPatch(
                            *orientation[0], color=arrow_color, arrowstyle=arrow_style
                        )
                    ]
                else:
                    arrows.extend(
                        [
                            patches.FancyArrowPatch(
                                *orientation[0],
                                color=human_colors[i - 1],
                                arrowstyle=arrow_style
                            )
                        ]
                    )

            for arrow in arrows:
                ax.add_artist(arrow)
            global_step = 0

            if len(self.trajs) != 0:
                human_future_positions = []
                human_future_circles = []
                for traj in self.trajs:
                    human_future_position = [
                        [
                            tensor_to_joint_state(traj[step + 1][0])
                            .human_states[i]
                            .position
                            for step in range(self.robot.policy.planning_depth)
                        ]
                        for i in range(self.human_num)
                    ]
                    human_future_positions.append(human_future_position)

                for i in range(self.human_num):
                    circles = []
                    for j in range(self.robot.policy.planning_depth):
                        circle = plt.Circle(
                            human_future_positions[0][i][j],
                            self.humans[0].radius / (1.7 + j),
                            fill=False,
                            color=cmap(i),
                        )
                        ax.add_artist(circle)
                        circles.append(circle)
                    human_future_circles.append(circles)

            def update(frame_num):
                nonlocal global_step
                nonlocal arrows
                global_step = frame_num
                robot.center = robot_positions[frame_num]

                for i, human in enumerate(humans):
                    human.center = human_positions[frame_num][i]
                    if display_numbers:
                        human_numbers[i].set_position(
                            (human.center[0] - x_offset, human.center[1] + y_offset)
                        )
                for arrow in arrows:
                    arrow.remove()

                for i in range(self.human_num + 1):
                    orientation = orientations[i]
                    if i == 0:
                        arrows = [
                            patches.FancyArrowPatch(
                                *orientation[frame_num],
                                color="black",
                                arrowstyle=arrow_style
                            )
                        ]
                    else:
                        arrows.extend(
                            [
                                patches.FancyArrowPatch(
                                    *orientation[frame_num],
                                    color=cmap(i - 1),
                                    arrowstyle=arrow_style
                                )
                            ]
                        )

                for arrow in arrows:
                    ax.add_artist(arrow)
                    # if hasattr(self.robot.policy, 'get_attention_weights'):
                    #     attention_scores[i].set_text('human {}: {:.2f}'.format(i, self.attention_weights[frame_num][i]))

                time.set_text("Time: {:.2f}".format(frame_num * self.time_step))

                if len(self.trajs) != 0:
                    for i, circles in enumerate(human_future_circles):
                        for j, circle in enumerate(circles):
                            circle.center = human_future_positions[global_step][i][j]

            def plot_value_heatmap():
                if self.robot.kinematics != "holonomic":
                    print("Kinematics is not holonomic")
                    return
                # for agent in [self.states[global_step][0]] + self.states[global_step][1]:
                #     print(('{:.4f}, ' * 6 + '{:.4f}').format(agent.px, agent.py, agent.gx, agent.gy,
                #                                              agent.vx, agent.vy, agent.theta))

                # when any key is pressed draw the action value plot
                fig, axis = plt.subplots()
                speeds = [0] + self.robot.policy.speeds
                rotations = self.robot.policy.rotations + [np.pi * 2]
                r, th = np.meshgrid(speeds, rotations)
                z = np.array(self.action_values[global_step % len(self.states)][1:])
                z = (z - np.min(z)) / (np.max(z) - np.min(z))
                z = np.reshape(
                    z,
                    (
                        self.robot.policy.rotation_samples,
                        self.robot.policy.speed_samples,
                    ),
                )
                polar = plt.subplot(projection="polar")
                polar.tick_params(labelsize=16)
                mesh = plt.pcolormesh(th, r, z, vmin=0, vmax=1)
                plt.plot(rotations, r, color="k", ls="none")
                plt.grid()
                cbaxes = fig.add_axes([0.85, 0.1, 0.03, 0.8])
                cbar = plt.colorbar(mesh, cax=cbaxes)
                cbar.ax.tick_params(labelsize=16)
                plt.show()

            def print_matrix_A():
                # with np.printoptions(precision=3, suppress=True):
                #     print(self.As[global_step])
                h, w = self.As[global_step].shape
                print("   " + " ".join(["{:>5}".format(i - 1) for i in range(w)]))
                for i in range(h):
                    print(
                        "{:<3}".format(i - 1)
                        + " ".join(
                            [
                                "{:.3f}".format(self.As[global_step][i][j])
                                for j in range(w)
                            ]
                        )
                    )
                # with np.printoptions(precision=3, suppress=True):
                #     print('A is: ')
                #     print(self.As[global_step])

            def print_feat():
                with np.printoptions(precision=3, suppress=True):
                    print("feat is: ")
                    print(self.feats[global_step])

            def print_X():
                with np.printoptions(precision=3, suppress=True):
                    print("X is: ")
                    print(self.Xs[global_step])

            def on_click(event):
                if anim.running:
                    anim.event_source.stop()
                    if event.key == "a":
                        if hasattr(self.robot.policy, "get_matrix_A"):
                            print_matrix_A()
                        if hasattr(self.robot.policy, "get_feat"):
                            print_feat()
                        if hasattr(self.robot.policy, "get_X"):
                            print_X()
                        # if hasattr(self.robot.policy, 'action_values'):
                        #    plot_value_heatmap()
                else:
                    anim.event_source.start()
                anim.running ^= True

            fig.canvas.mpl_connect("key_press_event", on_click)
            anim = animation.FuncAnimation(
                fig, update, frames=len(self.states), interval=self.time_step * 500
            )
            anim.running = True

            if output_file is not None:
                # save as video
                ffmpeg_writer = animation.FFMpegWriter(
                    fps=10, metadata=dict(artist="Me"), bitrate=1800
                )
                # writer = ffmpeg_writer(fps=10, metadata=dict(artist='Me'), bitrate=1800)
                anim.save(output_file, writer=ffmpeg_writer)

                # save output file as gif if imagemagick is installed
                # anim.save(output_file, writer='imagemagick', fps=12)
            else:
                plt.show()
        else:
            raise NotImplementedError