Esempio n. 1
0
    def __init__(self, env: MazeEnv, logging_prefix: Optional[str] = None):
        """Set up the statistics aggregation hierarchy around ``env``.

        Do not call this constructor directly; use :meth:`wrap` instead.

        :param env: The environment to wrap.
        :param logging_prefix: If given, a stats logger obtained from ``get_stats_logger(logging_prefix)`` is
                               registered as a consumer of the epoch-level statistics.
        """
        # BaseEnv is a subset of gym.Env
        super().__init__(env)

        # Build the aggregator chain. Each finer-grained aggregator receives the next coarser one as its second
        # constructor argument (presumably the target its reduced stats are passed on to -- confirm in
        # LogStatsAggregator).
        epoch_level = LogStatsAggregator(LogStatsLevel.EPOCH)
        episode_level = LogStatsAggregator(LogStatsLevel.EPISODE, epoch_level)
        step_level = LogStatsAggregator(LogStatsLevel.STEP, episode_level)

        self.epoch_stats = epoch_level
        self.episode_stats = episode_level
        self.step_stats = step_level

        # Lookup from stats level to the matching aggregator
        self.stats_map = {
            LogStatsLevel.EPOCH: epoch_level,
            LogStatsLevel.EPISODE: episode_level,
            LogStatsLevel.STEP: step_level
        }

        if logging_prefix is not None:
            stats_logger = get_stats_logger(logging_prefix)
            self.epoch_stats.register_consumer(stats_logger)

        # Recording state: time of the last stats recording, reward events not yet assigned to a step log and the
        # (lazily created) event log of the running episode
        self.last_env_time: Optional[int] = None
        self.reward_events = EventCollection()
        self.episode_event_log: Optional[EpisodeEventLog] = None

        self.step_stats_renderer = EventStatsRenderer()

        # Hook into the env context's post-step callbacks, so stats get recorded even when a wrapper further down
        # the stack steps the environment itself (e.g. during step-skipping).
        context = getattr(env, "context", None)
        if isinstance(context, EnvironmentContext):
            context.register_post_step(self._record_stats_if_ready)
Esempio n. 2
0
    def step(self, action: Any) -> Tuple[Any, Any, bool, Dict[Any, Any]]:
        """Step the wrapped env and append a state record to the current episode record.

        For envs implementing ``TimeEnvMixin`` a record is only added when the env time differs from the time of
        the last record, so every time step is recorded once. For all other envs every call is recorded.

        :param action: The action forwarded to the wrapped environment.
        :return: The ``(observation, reward, done, info)`` values returned by the wrapped environment.
        """
        assert self.episode_record is not None, "Environment must be reset before stepping."

        observation, reward, done, info = self.env.step(action)

        is_time_env = isinstance(self.env, TimeEnvMixin)

        # Time envs: nothing to record while the env time has not changed since the last record
        if is_time_env and self.env.get_env_time() == self.last_env_time:
            return observation, reward, done, info

        # Advance the recording time: actual env time for time envs, a simple step counter otherwise
        if is_time_env:
            self.last_env_time = self.env.get_env_time()
        else:
            self.last_env_time = self.last_env_time + 1

        # Use the env's own MazeAction when it is recordable, otherwise wrap the raw action
        if isinstance(self.env, RecordableEnvMixin):
            maze_action = deepcopy(self.env.get_maze_action())
        else:
            maze_action = RawMazeAction(action)

        step_events = self.env.get_step_events() if isinstance(self.env, EventEnvMixin) else []
        step_event_log = StepEventLog(self.last_env_time, events=EventCollection(step_events))

        # Pair the most recently collected state/components (refreshed just below) with this step's outcome
        self.episode_record.step_records.append(
            StateRecord(self.last_maze_state, maze_action, step_event_log, reward, done, info,
                        self.last_serializable_components))

        # Collect state and components for the next step
        self._collect_state_and_components(observation)

        return observation, reward, done, info
Esempio n. 3
0
    def reset(self) -> Any:
        """Finalize the statistics of the previous episode (if any), then reset the wrapped environment.

        KPIs are calculated, episode stats are reduced and the episode event log is written before the recording
        state is cleared. The clearing happens before the env reset, so events emitted during the reset itself can
        already be recorded.

        :return: Whatever the wrapped environment's ``reset`` returns.
        """
        # Wrap up the previous rollout, if any
        self._calculate_kpis()
        self.episode_stats.reduce()
        self._write_episode_event_log()

        # Fresh recording state for the upcoming episode
        self.reward_events, self.last_env_time = EventCollection(), None

        initial_observation = self.env.reset()
        return initial_observation
Esempio n. 4
0
    def _write_episode_record(self):
        """Append a final state record and pass the finished episode to the trajectory data writers.

        Does nothing if there is no episode record or it does not contain any step records yet.
        """
        # Empty episodes are not recorded
        if not self.episode_record or not self.episode_record.step_records:
            return

        if isinstance(self.env, TimeEnvMixin):
            env_time = self.env.get_env_time()
        else:
            env_time = None

        pending_events = self.env.get_step_events() if isinstance(self.env, EventEnvMixin) else []

        # The final record captures only the last state, so action/reward/done/info are left empty
        self.episode_record.step_records.append(StateRecord(
            maze_state=self.last_maze_state,
            maze_action=None,
            step_event_log=StepEventLog(env_time, events=EventCollection(pending_events)),
            reward=None,
            done=None,
            info=None,
            serializable_components=self.last_serializable_components))

        # Dispatch the completed episode
        TrajectoryWriterRegistry.record_trajectory_data(self.episode_record)
Esempio n. 5
0
    def _record_stats_if_ready(self) -> None:
        """Record step statistics and the step event log, unless the current step is not yet complete.

        For envs implementing ``TimeEnvMixin`` nothing is recorded while the env time still equals the time of the
        last recording. For structured envs this means waiting until the whole structured step is done. All
        other envs are recorded on every call.
        """
        if self.last_env_time is None:
            self.last_env_time = self.env.initial_env_time

        tracks_time = isinstance(self.env, TimeEnvMixin)
        if tracks_time and self.env.get_env_time() == self.last_env_time:
            # env time has not advanced -> the step is still in progress
            return

        # The reward events gathered so far belong to this step; start a new collection for the next one
        step_event_log = StepEventLog(env_time=self.last_env_time, events=self.reward_events)
        self.reward_events = EventCollection()

        if isinstance(self.env, EventEnvMixin):
            step_event_log.extend(self.env.get_step_events())

        # Feed every event of this step into the step-level aggregator, then reduce
        for recorded_event in step_event_log.events:
            self.step_stats.add_event(recorded_event)
        self.step_stats.reduce()

        # Create the episode event log on first use
        if not self.episode_event_log:
            if isinstance(self.env, RecordableEnvMixin):
                episode_id = self.env.get_episode_id()
            else:
                episode_id = str(uuid.uuid4())
            self.episode_event_log = EpisodeEventLog(episode_id)

        self.episode_event_log.step_event_logs.append(step_event_log)

        # Remember when stats were last recorded (actual env time, or a step counter for non-time envs)
        self.last_env_time = self.env.get_env_time() if tracks_time else self.last_env_time + 1
Esempio n. 6
0
 def __init__(self, env_time: int, events: Optional[EventCollection] = None):
     """Create the event log of a single step.

     :param env_time: Env time associated with the logged step.
     :param events: Initial events; when omitted, a new empty ``EventCollection`` is used.
     """
     self.env_time = env_time
     self.events = EventCollection() if events is None else events
Esempio n. 7
0
    def reset(self):
        """Discard all aggregated events.

        The ``events`` attribute is rebound to a new, empty ``EventCollection``; the previous collection object is
        left as it was.

        :return: None
        """
        empty_collection = EventCollection()
        self.events = empty_collection
Esempio n. 8
0
 def __init__(self):
     """Initialize the event storage and bring the instance into its reset state."""
     # Assign `events` before calling `reset()`, so the attribute exists even if `reset` does not (re)assign it
     initial_events = EventCollection()
     self.events = initial_events
     self.reset()