def _check_for_episode_termination(self, reset_rules_status, agents_info_map):
    """Decide whether the current episode should pause or finish.

    The reset rules are examined in priority order: episode completion
    first, then a crash, then any other rule that reports True.

    Args:
        reset_rules_status: dictionary of reset rules status with key as
            reset rule name and value as reset rule bool status
        agents_info_map: dictionary of agents info with key as agent name
            and value as agent info

    Returns:
        tuple (string, bool, bool): episode status, pause flag, and done flag
    """
    status = EpisodeStatus.get_episode_status(reset_rules_status)

    # Completion is tested before anything else because the agent might
    # crash right at the finish line.
    complete_value = EpisodeStatus.EPISODE_COMPLETE.value
    if reset_rules_status.get(complete_value):
        return complete_value, False, True

    if not reset_rules_status.get(EpisodeStatus.CRASHED.value):
        # No crash reported: any other triggered rule goes through the
        # regular phase-change check, otherwise nothing happens.
        if any(reset_rules_status.values()):
            done, pause = self._check_for_phase_change()
            return status, pause, done
        return status, False, False

    # A crash is only evaluated while the agent is in the RUN phase;
    # in any other phase the agent is simply flagged as paused.
    in_run_phase = (
        self._ctrl_status[AgentCtrlStatus.AGENT_PHASE.value]
        == AgentPhase.RUN.value
    )
    if not in_run_phase:
        return status, True, False

    own_info = agents_info_map[self._agent_name_]
    self._curr_crashed_object_name = own_info[
        AgentInfo.CRASHED_OBJECT_NAME.value]

    # A static obstacle has no progress to compare against, so it always
    # goes straight to the phase-change check.
    if 'obstacle' in self._curr_crashed_object_name:
        done, pause = self._check_for_phase_change()
        return status, pause, done

    # Any other crashed object: compare normalized progress of both parties.
    own_raw_progress = own_info[AgentInfo.CURRENT_PROGRESS.value]
    other_info = agents_info_map[self._curr_crashed_object_name]
    other_raw_progress = other_info[AgentInfo.CURRENT_PROGRESS.value]
    other_start_ndist = other_info[AgentInfo.START_NDIST.value]
    other_progress = get_normalized_progress(
        other_raw_progress, start_ndist=other_start_ndist)
    own_progress = get_normalized_progress(
        own_raw_progress, start_ndist=self._data_dict_['start_ndist'])

    if own_progress < other_progress:
        # This agent is behind the object it crashed with.
        done, pause = self._check_for_phase_change()
        return status, pause, done

    # This agent is level with or ahead of the other object: the crash is
    # not held against it and the episode is reported as in progress.
    return EpisodeStatus.IN_PROGRESS.value, False, False
def judge_action(self, agents_info_map):
    """Update the bot car phase based on crashes reported by the agents.

    While in the RUN phase, the pause duration is cleared and every
    collidable agent is inspected; the first one found to have crashed
    into a bot car while being ahead of it (by normalized progress)
    switches the bot cars to the PAUSE phase and applies the penalty.
    While in the PAUSE phase, the bot cars return to RUN once the pause
    duration is no longer positive. This method never decreases
    ``pause_duration`` itself (presumably it is counted down elsewhere).

    Args:
        agents_info_map: Dictionary contains all agents info with agent
            name as the key and info as the value

    Returns:
        tuple: None, None, None

    Raises:
        GenericRolloutException: bot car phase is not defined
    """
    no_reset = (None, None, None)

    if self.bot_car_phase == AgentPhase.PAUSE.value:
        # transition to AgentPhase.RUN.value
        if self.pause_duration <= 0.0:
            self.bot_car_phase = AgentPhase.RUN.value
        return no_reset

    if self.bot_car_phase != AgentPhase.RUN.value:
        raise GenericRolloutException(
            "bot car phase {} is not defined".format(self.bot_car_phase))

    # RUN phase: start from a clean pause duration and look for a crash.
    self.pause_duration = 0.0
    for name, info in agents_info_map.items():
        if not self.track_data.is_object_collidable(name):
            continue

        # only trainable racecar agent has 'bot_car' as possible crashed
        # object; agents without the key are treated as not crashed
        hit_name = info.get(AgentInfo.CRASHED_OBJECT_NAME.value, "")
        if "bot_car" not in hit_name:
            continue

        racecar_progress = get_normalized_progress(
            info[AgentInfo.CURRENT_PROGRESS.value],
            start_ndist=info[AgentInfo.START_NDIST.value],
        )
        hit_info = agents_info_map[hit_name]
        bot_car_progress = get_normalized_progress(
            hit_info[AgentInfo.CURRENT_PROGRESS.value],
            start_ndist=hit_info[AgentInfo.START_NDIST.value],
        )

        # transition to AgentPhase.PAUSE.value
        if racecar_progress > bot_car_progress:
            # Push every lane change end time back by the penalty.
            penalty = self.penalty_seconds
            shifted_end_times = []
            for end_time in self.bot_cars_lane_change_end_times:
                shifted_end_times.append(end_time + penalty)
            self.bot_cars_lane_change_end_times = shifted_end_times
            self.bot_car_crash_count += 1
            self.pause_duration += self.penalty_seconds
            self.bot_car_phase = AgentPhase.PAUSE.value
            break

    return no_reset