def _check_for_episode_termination(self, reset_rules_status,
                                   agents_info_map):
        """Decide the episode status plus the pause and done flags for this step.

        Precedence, highest first:
          1. EPISODE_COMPLETE fired -> done. This is checked before CRASHED
             because an agent may crash right at the finish line.
          2. CRASHED fired -> only acted on while this agent is in the RUN
             phase; outside RUN the agent is just paused. A crash into an
             object whose name contains 'obstacle' always goes through the
             phase change. A crash into any other object goes through the
             phase change only when this agent's normalized progress is behind
             that object's; otherwise the step is reported as IN_PROGRESS.
          3. Any other reset rule fired -> phase change.

        Args:
            reset_rules_status: dict keyed by reset rule name whose values are
                                the rules' bool status
            agents_info_map: dict keyed by agent name whose values are that
                             agent's info dict

        Returns:
            tuple (string, bool, bool): episode status, pause flag, and done flag
        """
        status = EpisodeStatus.get_episode_status(reset_rules_status)

        def _after_phase_change():
            # _check_for_phase_change yields (done, pause); reorder to the
            # (status, pause, done) shape this method returns.
            phase_done, phase_pause = self._check_for_phase_change()
            return status, phase_pause, phase_done

        if reset_rules_status.get(EpisodeStatus.EPISODE_COMPLETE.value):
            return EpisodeStatus.EPISODE_COMPLETE.value, False, True

        if not reset_rules_status.get(EpisodeStatus.CRASHED.value):
            if any(reset_rules_status.values()):
                return _after_phase_change()
            return status, False, False

        # A crash is only evaluated while this agent is running.
        if self._ctrl_status[AgentCtrlStatus.AGENT_PHASE.value] != AgentPhase.RUN.value:
            return status, True, False

        own_info = agents_info_map[self._agent_name_]
        self._curr_crashed_object_name = own_info[
            AgentInfo.CRASHED_OBJECT_NAME.value]

        # Static obstacles skip the progress comparison entirely.
        if 'obstacle' in self._curr_crashed_object_name:
            return _after_phase_change()

        own_raw_progress = own_info[AgentInfo.CURRENT_PROGRESS.value]
        other_info = agents_info_map[self._curr_crashed_object_name]
        other_progress = get_normalized_progress(
            other_info[AgentInfo.CURRENT_PROGRESS.value],
            start_ndist=other_info[AgentInfo.START_NDIST.value])
        own_progress = get_normalized_progress(
            own_raw_progress,
            start_ndist=self._data_dict_['start_ndist'])

        # Only the agent that is behind the object it hit takes the crash.
        if own_progress < other_progress:
            return _after_phase_change()
        return EpisodeStatus.IN_PROGRESS.value, False, False
Exemplo n.º 2
0
    def judge_action(self, agents_info_map):
        """Advance the bot-car RUN/PAUSE state machine by one step.

        In the RUN phase the pause duration is reset to 0.0. The method then
        walks the collidable agents looking for one whose reported crashed
        object name contains "bot_car". For the first such agent whose
        normalized progress is strictly greater than that bot car's, the bot
        cars are penalized and the remaining agents are not examined. The
        penalty is:
          - every lane-change end time is shifted by ``penalty_seconds``;
          - the crash counter is incremented;
          - ``penalty_seconds`` is added to ``pause_duration``;
          - the phase switches to PAUSE.

        In the PAUSE phase the phase switches back to RUN once
        ``pause_duration`` is <= 0.0. Presumably it is counted down elsewhere;
        that is not visible here.

        Args:
            agents_info_map: Dictionary contains all agents info with agent
                             name as the key and info as the value

        Returns:
            tuple: None, None, None

        Raises:
            GenericRolloutException: bot car phase is not defined
        """
        if self.bot_car_phase == AgentPhase.RUN.value:
            self.pause_duration = 0.0
            hit_key = AgentInfo.CRASHED_OBJECT_NAME.value
            progress_key = AgentInfo.CURRENT_PROGRESS.value
            ndist_key = AgentInfo.START_NDIST.value
            for name, info in agents_info_map.items():
                if not self.track_data.is_object_collidable(name):
                    continue
                hit_object = info.get(hit_key, "")
                # Only a trainable racecar can report a "bot_car" as what it hit.
                if "bot_car" not in hit_object:
                    continue
                racer_progress = get_normalized_progress(
                    info[progress_key],
                    start_ndist=info[ndist_key],
                )
                bot_info = agents_info_map[hit_object]
                bot_progress = get_normalized_progress(
                    bot_info[progress_key],
                    start_ndist=bot_info[ndist_key],
                )
                if racer_progress > bot_progress:
                    # Racecar is strictly ahead of the bot car it hit:
                    # penalize the bot cars and move to the PAUSE phase.
                    self.bot_cars_lane_change_end_times = [
                        end_time + self.penalty_seconds
                        for end_time in self.bot_cars_lane_change_end_times
                    ]
                    self.bot_car_crash_count += 1
                    self.pause_duration += self.penalty_seconds
                    self.bot_car_phase = AgentPhase.PAUSE.value
                    break
        elif self.bot_car_phase == AgentPhase.PAUSE.value:
            # Penalty served: resume running.
            if self.pause_duration <= 0.0:
                self.bot_car_phase = AgentPhase.RUN.value
        else:
            raise GenericRolloutException(
                "bot car phase {} is not defined".format(self.bot_car_phase))
        return None, None, None