def _rollout_greedy_action(agent, env, state):
    """Return the action index the agent selects for ``state``.

    The state is converted with ``dqn.get_state_vector_from_base_state``,
    encoded with ``env._make_state`` and evaluated with ``agent.eval``;
    ``.item()`` unwraps the result into a plain Python value.
    """
    vector_state = dqn.get_state_vector_from_base_state(state)
    # NOTE(review): the meaning of the ``False`` flag and of the ``0`` passed
    # to ``agent.eval`` is defined by env/agent, not here (the ``0`` is
    # presumably an exploration rate -- confirm).
    return agent.eval(env._make_state(vector_state, False), 0).item()


def plot_rollouts(states: List[prediction.HighwayState], agent, env):
    """Plot and save a policy rollout and an ST path for every given state.

    For the state at index ``j`` two figures are saved into
    ``<Settings.FULL_LOG_DIR>/plots``:

    * ``j`` -- the start state plus every state predicted while repeatedly
      applying the jerk chosen by ``agent``.
    * ``st_j`` -- the plot drawn by ``st.plot_s_path`` for the start state.

    The rollout stops as soon as a prediction step reports a crash, or once
    the step counter exceeds ``max(Settings.ROLLOUT_LENGTH, 1)``; because the
    counter is checked before it is incremented, up to
    ``max(Settings.ROLLOUT_LENGTH, 1) + 1`` prediction steps are performed.

    Args:
        states: Start states to roll out from.
        agent: Object whose ``eval(encoded_state, 0).item()`` yields an index
            into ``Settings.JERK_VALUES_DQN``.
        env: Object providing ``_make_state`` to encode a state vector for
            the agent.
    """
    Settings.ensure_run_plot_directory()
    plot_directory = os.path.join(Settings.FULL_LOG_DIR, "plots")
    # Loop-invariant step limit; always allow at least one prediction step.
    step_limit = max(Settings.ROLLOUT_LENGTH, 1)

    for j, start_state in enumerate(states):
        # --- Figure 1: rollout under the agent's policy ---
        plt.figure()
        try:
            state = start_state
            action = _rollout_greedy_action(agent, env, state)
            crash_predicted = False
            step = 0
            state.plot_state(step)
            while not crash_predicted and step <= step_limit:
                step += 1
                if step > 1:
                    # The action for the first step was already chosen from
                    # the start state; later steps re-evaluate the agent on
                    # the newly predicted state.
                    action = _rollout_greedy_action(agent, env, state)
                selected_speed = control.get_ego_speed_from_jerk(
                    state.ego_speed,
                    state.ego_acceleration,
                    Settings.JERK_VALUES_DQN[action],
                )
                state, crash_predicted = state.predict_step_with_ego(
                    selected_speed, Settings.TICK_LENGTH
                )
                state.plot_state(step)
            plt.savefig(os.path.join(plot_directory, str(j)))
        finally:
            # Close the figure even if the rollout or the save fails, so a
            # failure does not leave figures accumulating in pyplot.
            plt.close()

        # --- Figure 2: ST prediction for the same start state ---
        # The fifth return value (distances) is not needed for plotting.
        s_sequence, obstacles, s_values, t_values, _distances = (
            st.get_appropriate_base_st_path_and_obstacles(start_state)
        )
        try:
            st.plot_s_path(obstacles, s_values, t_values, s_sequence)
            plt.savefig(os.path.join(plot_directory, "st_{}".format(j)))
        finally:
            plt.close()
        print("Saved crash.")