Пример #1
0
def test_aggregation_chain():
    """Propagate a single summed event attribute through the step, episode and epoch aggregators."""

    class _EventInterface(ABC):
        @define_epoch_stats(sum)
        @define_episode_stats(sum)
        @define_step_stats(sum)
        def event1(self, attr1):
            pass

    # build the chain top-down: every aggregator receives the next-coarser one as its second argument
    epoch_aggregator = LogStatsAggregator(LogStatsLevel.EPOCH)
    episode_aggregator = LogStatsAggregator(LogStatsLevel.EPISODE, epoch_aggregator)
    step_aggregator = LogStatsAggregator(LogStatsLevel.STEP, episode_aggregator)

    # key layout: (event, output name, group key)
    stats_key = (_EventInterface.event1, None, None)
    n_steps, n_episodes = 5, 7

    for _ in range(n_episodes):
        for _ in range(n_steps):
            # two events per step, summing up to 5
            for attr_value in (2, 3):
                step_aggregator.add_event(
                    EventRecord(_EventInterface, _EventInterface.event1, dict(attr1=attr_value)))
            step_aggregator.reduce()

        episode_stats = episode_aggregator.reduce()
        assert len(episode_stats) == 1
        assert episode_stats[stats_key] == n_steps * 5

    epoch_stats = epoch_aggregator.reduce()
    assert len(epoch_stats) == 1
    assert epoch_stats[stats_key] == n_episodes * n_steps * 5
Пример #2
0
def test_multi_group_projection():
    """Project a three-attribute grouping onto each single group attribute and sum per projection."""

    class _EventInterface(ABC):
        @define_stats_grouping("group1", "group2", "group3")
        @define_step_stats(sum, group_by="group1", output_name="g1")
        @define_step_stats(sum, group_by="group2", output_name="g2")
        @define_step_stats(sum, group_by="group3", output_name="g3")
        def event1(self, group1, group2, group3, attr1):
            pass

    agg = LogStatsAggregator(LogStatsLevel.STEP)
    recorded_attributes = [
        dict(group1=1, group2=0, group3=0, attr1=1),
        dict(group1=0, group2=1, group3=0, attr1=2),
        dict(group1=0, group2=0, group3=1, attr1=4),
    ]
    for attributes in recorded_attributes:
        agg.add_event(EventRecord(_EventInterface, _EventInterface.event1, attributes))

    stats = agg.reduce()

    # (output name, group key) -> expected sum of attr1; group attributes that are projected away show up as None
    expected_sums = {
        ("g1", (0, None, None)): 6,
        ("g1", (1, None, None)): 1,
        ("g2", (None, 0, None)): 5,
        ("g2", (None, 1, None)): 2,
        ("g3", (None, None, 0)): 3,
        ("g3", (None, None, 1)): 4,
    }
    assert len(stats) == 6
    for (output_name, group_key), expected_sum in expected_sums.items():
        assert stats[(_EventInterface.event1, output_name, group_key)] == expected_sum
Пример #3
0
def test_aggregation_chain_fork():
    """Run two event attributes with different reducers (sum and mean) through the whole aggregation chain."""

    class _EventInterface(ABC):
        @define_epoch_stats(sum, input_name="attr1_sum")
        @define_epoch_stats(np.mean, input_name="attr2_mean")
        @define_episode_stats(sum, input_name="attr1_sum")
        @define_episode_stats(np.mean, input_name="attr2_mean")
        @define_step_stats(sum, input_name="attr1", output_name="attr1_sum")
        @define_step_stats(np.mean, input_name="attr1", output_name="attr1_mean")
        @define_step_stats(sum, input_name="attr2", output_name="attr2_sum")
        @define_step_stats(np.mean, input_name="attr2", output_name="attr2_mean")
        def event1(self, attr1, attr2):
            pass

    epoch_aggregator = LogStatsAggregator(LogStatsLevel.EPOCH)
    episode_aggregator = LogStatsAggregator(LogStatsLevel.EPISODE, epoch_aggregator)
    step_aggregator = LogStatsAggregator(LogStatsLevel.STEP, episode_aggregator)

    event = _EventInterface.event1
    n_steps, n_episodes = 5, 7

    # four outputs are declared on step level, only attr1_sum and attr2_mean on the coarser levels
    expected_step_stats = {
        "attr1_sum": 5.0,
        "attr1_mean": 2.5,
        "attr2_sum": -5.0,
        "attr2_mean": -2.5,
    }

    for _ in range(n_episodes):
        for _ in range(n_steps):
            for attr1, attr2 in ((2.0, -2.0), (3.0, -3.0)):
                step_aggregator.add_event(EventRecord(_EventInterface, event, dict(attr1=attr1, attr2=attr2)))

            step_stats = step_aggregator.reduce()
            assert len(step_stats) == 4
            for output_name, expected_value in expected_step_stats.items():
                assert step_stats[(event, output_name, None)] == expected_value

        episode_stats = episode_aggregator.reduce()
        assert len(episode_stats) == 2
        assert episode_stats[(event, "attr1_sum", None)] == n_steps * 5.0
        assert episode_stats[(event, "attr2_mean", None)] == -2.5

    epoch_stats = epoch_aggregator.reduce()
    assert len(epoch_stats) == 2
    assert epoch_stats[(event, "attr1_sum", None)] == n_episodes * n_steps * 5.0
    assert epoch_stats[(event, "attr2_mean", None)] == -2.5
Пример #4
0
def test_event_single_attribute():
    """Check that the aggregation function is fed scalars when the event has exactly one attribute."""

    class _EventInterface(ABC):
        @define_step_stats(sum)
        def event1(self, attr1):
            pass

    agg = LogStatsAggregator(LogStatsLevel.STEP)
    for attr_value in (1, 2):
        agg.add_event(EventRecord(_EventInterface, _EventInterface.event1, dict(attr1=attr_value)))

    stats = agg.reduce()
    assert len(stats) == 1

    # exactly one entry is present, unpack it directly
    [(stats_key, stats_value)] = stats.items()
    assert stats_value == 3
    # key layout: (event, output name, group key)
    assert stats_key == (_EventInterface.event1, None, None)
Пример #5
0
def test_event_counting():
    """Use ``len`` as the aggregation function to count the events recorded during a step."""

    class _EventInterface(ABC):
        @define_step_stats(len)
        def event1(self, attr1, attr2):
            pass

    agg = LogStatsAggregator(LogStatsLevel.STEP)
    n_events = 2
    for _ in range(n_events):
        agg.add_event(EventRecord(_EventInterface, _EventInterface.event1, dict(attr1=1, attr2=2)))

    stats = agg.reduce()
    assert len(stats) == 1

    stats_key, stats_value = next(iter(stats.items()))
    assert stats_value == 2
    # key layout: (event, output name, group key)
    assert stats_key == (_EventInterface.event1, None, None)
Пример #6
0
def test_event_stats_histogram_2():
    """Test histogram logging on event level, with one histogram per event attribute."""

    class _EventInterface(ABC):
        @define_step_stats(histogram, input_name='attr1')
        @define_step_stats(histogram, input_name='attr2')
        def event1(self, attr1, attr2):
            pass

    agg = LogStatsAggregator(LogStatsLevel.STEP)
    for _ in range(2):
        agg.add_event(EventRecord(_EventInterface, _EventInterface.event1, dict(attr1=1, attr2=2)))

    stats = agg.reduce()
    assert len(stats) == 2

    # every recorded attribute value is expected in the respective histogram output
    attr1_values = stats[(_EventInterface.event1, "attr1", None)]
    attr2_values = stats[(_EventInterface.event1, "attr2", None)]

    assert attr1_values == [1, 1]
    assert attr2_values == [2, 2]
Пример #7
0
def test_event_attributes():
    """Aggregate each attribute of a two-attribute event separately via ``input_name``."""

    class _EventInterface(ABC):
        @define_step_stats(sum, input_name='attr1')
        @define_step_stats(sum, input_name='attr2')
        def event1(self, attr1, attr2):
            pass

    agg = LogStatsAggregator(LogStatsLevel.STEP)
    for _ in range(2):
        agg.add_event(EventRecord(_EventInterface, _EventInterface.event1, dict(attr1=1, attr2=3)))

    stats = agg.reduce()
    assert len(stats) == 2

    attr1_sum = stats[(_EventInterface.event1, "attr1", None)]
    attr2_sum = stats[(_EventInterface.event1, "attr2", None)]

    assert attr1_sum == 2
    assert attr2_sum == 6
Пример #8
0
def test_grouping():
    """Sum an event attribute separately for every value of a single grouping attribute."""

    class _EventInterface(ABC):
        @define_stats_grouping("group")
        @define_step_stats(sum)
        def event1(self, group, attr1):
            pass

    agg = LogStatsAggregator(LogStatsLevel.STEP)
    for base_value in (1, 3):
        # group 0 receives the plain value, group 1 the doubled value
        for group_id, factor in ((0, 1), (1, 2)):
            agg.add_event(EventRecord(_EventInterface, _EventInterface.event1,
                                      dict(group=group_id, attr1=base_value * factor)))

    stats = agg.reduce()
    assert len(stats) == 2

    # the group key is a tuple with one entry per grouping attribute
    group0_sum = stats[(_EventInterface.event1, None, (0,))]
    group1_sum = stats[(_EventInterface.event1, None, (1,))]

    assert group0_sum == 4
    assert group1_sum == 8
Пример #9
0
def test_aggregation_chain_multi_attribute():
    """Propagate two separately summed event attributes through the step, episode and epoch aggregators."""

    class _EventInterface(ABC):
        @define_epoch_stats(sum, input_name="attr1")
        @define_epoch_stats(sum, input_name="attr2")
        @define_episode_stats(sum, input_name="attr1")
        @define_episode_stats(sum, input_name="attr2")
        @define_step_stats(sum, input_name="attr1")
        @define_step_stats(sum, input_name="attr2")
        def event1(self, attr1, attr2):
            pass

    epoch_aggregator = LogStatsAggregator(LogStatsLevel.EPOCH)
    episode_aggregator = LogStatsAggregator(LogStatsLevel.EPISODE, epoch_aggregator)
    step_aggregator = LogStatsAggregator(LogStatsLevel.STEP, episode_aggregator)

    attr1_key = (_EventInterface.event1, "attr1", None)
    attr2_key = (_EventInterface.event1, "attr2", None)
    n_steps, n_episodes = 5, 7

    for _ in range(n_episodes):
        for _ in range(n_steps):
            # attr2 mirrors attr1 with a flipped sign, the per-step sums are 5 and -5
            for magnitude in (2, 3):
                step_aggregator.add_event(
                    EventRecord(_EventInterface, _EventInterface.event1, dict(attr1=magnitude, attr2=-magnitude)))
            step_aggregator.reduce()

        episode_stats = episode_aggregator.reduce()
        assert len(episode_stats) == 2
        assert episode_stats[attr1_key] == n_steps * 5
        assert episode_stats[attr2_key] == -n_steps * 5

    epoch_stats = epoch_aggregator.reduce()
    assert len(epoch_stats) == 2
    assert epoch_stats[attr1_key] == n_episodes * n_steps * 5
    assert epoch_stats[attr2_key] == -n_episodes * n_steps * 5
Пример #10
0
def test_event_skip_aggregation():
    """Test once-per-step logging: with ``None`` as aggregation function the single value is passed through."""

    class _EventInterface(ABC):
        @define_step_stats(None)
        def event1(self, attr1):
            pass

    def new_record():
        # a fresh record per call, the attribute dict is never shared between events
        return EventRecord(_EventInterface, _EventInterface.event1, dict(attr1=3))

    agg = LogStatsAggregator(LogStatsLevel.STEP)
    agg.add_event(new_record())

    stats = agg.reduce()
    assert len(stats) == 1

    [(stats_key, stats_value)] = stats.items()
    assert stats_value == 3
    # key layout: (event, output name, group key)
    assert stats_key == (_EventInterface.event1, None, None)

    # recording the event twice within one step has to be rejected with an AssertionError
    with pytest.raises(AssertionError):
        agg.add_event(new_record())
        agg.add_event(new_record())
        agg.reduce()
Пример #11
0
def test_multi_grouping():
    """Sum an event attribute per distinct combination of three grouping attributes."""

    class _EventInterface(ABC):
        @define_stats_grouping("group1", "group2", "group3")
        @define_step_stats(sum)
        def event1(self, group1, group2, group3, attr1):
            pass

    # (weight of attr1, (group1, group2, group3))
    weighted_groups = (
        (1, (1, 0, 0)),
        (2, (0, 1, 0)),
        (4, (0, 0, 1)),
    )

    agg = LogStatsAggregator(LogStatsLevel.STEP)
    for scale in (1, 8):
        for weight, (g1, g2, g3) in weighted_groups:
            agg.add_event(EventRecord(_EventInterface, _EventInterface.event1,
                                      dict(group1=g1, group2=g2, group3=g3, attr1=weight * scale)))

    stats = agg.reduce()
    assert len(stats) == 3

    # every group combination accumulates weight * (1 + 8)
    expected_sums = {(1, 0, 0): 9, (0, 1, 0): 18, (0, 0, 1): 36}
    for group_key, expected_sum in expected_sums.items():
        assert stats[(_EventInterface.event1, None, group_key)] == expected_sum
Пример #12
0
class LogStatsWrapper(Wrapper[MazeEnv], LogStatsEnv):
    """A statistics logging wrapper for :class:`~maze.core.env.base_env.BaseEnv`.

    Recorded events are fed into a chain of three aggregators (step -> episode -> epoch). In addition, the raw
    events of the running episode are kept in an episode event log, which is passed on to
    ``LogEventsWriterRegistry`` once the episode is finished.

    :param env: The environment to wrap.
    :param logging_prefix: If given, the stats logger obtained for this prefix is registered as consumer of the
                           epoch statistics. If None, no consumer is registered.
    """

    def __init__(self, env: MazeEnv, logging_prefix: Optional[str] = None):
        """Avoid calling this constructor directly, use :meth:`wrap` instead."""
        # BaseEnv is a subset of gym.Env
        super().__init__(env)

        # initialize the aggregator chain, every aggregator receives the next-coarser one as its second argument
        self.epoch_stats = LogStatsAggregator(LogStatsLevel.EPOCH)
        self.episode_stats = LogStatsAggregator(LogStatsLevel.EPISODE, self.epoch_stats)
        self.step_stats = LogStatsAggregator(LogStatsLevel.STEP, self.episode_stats)

        # lookup table used to resolve a statistics level to its aggregator
        self.stats_map = {
            LogStatsLevel.EPOCH: self.epoch_stats,
            LogStatsLevel.EPISODE: self.episode_stats,
            LogStatsLevel.STEP: self.step_stats
        }

        if logging_prefix is not None:
            self.epoch_stats.register_consumer(get_stats_logger(logging_prefix))

        # env time of the last stats recording, None until the first recording of an episode
        self.last_env_time: Optional[int] = None
        # reward events collected since the last stats recording
        self.reward_events = EventCollection()
        # raw event log of the running episode, lazily created with the first recorded step
        self.episode_event_log: Optional[EpisodeEventLog] = None

        self.step_stats_renderer = EventStatsRenderer()

        # register a post-step callback, so stats are recorded even in case that a wrapper
        # in the middle of the stack steps the environment (as done e.g. during step-skipping)
        if hasattr(env, "context") and isinstance(env.context, EnvironmentContext):
            env.context.register_post_step(self._record_stats_if_ready)

    T = TypeVar("T")

    @classmethod
    def wrap(cls, env: T, logging_prefix: Optional[str] = None) -> Union[T, LogStatsEnv]:
        """Creation method providing appropriate type hints. Preferred method to construct the wrapper
        compared to calling the class constructor directly.

        :param env: The environment to be wrapped.
        :param logging_prefix: The episode statistics is connected to the logging system with this tagging
                               prefix. If None, no logging happens.

        :return: A newly created wrapper instance.
        """
        return cls(env, logging_prefix)

    @override(BaseEnv)
    def step(self, action: Any) -> Tuple[Any, Any, bool, Dict[Any, Any]]:
        """Step the wrapped environment and collect the reward for the logging statistics.

        :param action: The action, passed on to the wrapped environment unchanged.
        :return: The (observation, reward, done, info) tuple as returned by the wrapped environment.
        """

        # get identifier of current substep
        substep_id, _ = self.env.actor_id() if isinstance(self.env, StructuredEnv) else (None, None)

        # take core env step
        obs, rew, done, info = self.env.step(action)

        # record the reward
        self.reward_events.append(EventRecord(BaseEnvEvents, BaseEnvEvents.reward, dict(value=rew)))

        self._record_stats_if_ready()

        return obs, rew, done, info

    def _record_stats_if_ready(self) -> None:
        """Checks if stats are ready to record based on env time (for structured envs, we wait till the end
        of the whole structured step) and if so, does the recording.
        """
        if self.last_env_time is None:
            self.last_env_time = self.env.initial_env_time

        # Recording of event logs and stats happens:
        #  - for TimeEnvs:   Only if the env time changed, so that we record once per time step
        #  - for other envs: Every step
        if isinstance(self.env, TimeEnvMixin) and self.env.get_env_time() == self.last_env_time:
            return

        # the collected reward events are handed over to the step log, start a fresh collection
        step_event_log = StepEventLog(env_time=self.last_env_time, events=self.reward_events)
        self.reward_events = EventCollection()

        if isinstance(self.env, EventEnvMixin):
            step_event_log.extend(self.env.get_step_events())

        # add all recorded events to the step aggregator
        for event_record in step_event_log.events:
            self.step_stats.add_event(event_record)

        # trigger logging statistics calculation
        self.step_stats.reduce()

        # lazy init new episode event log if needed
        if not self.episode_event_log:
            episode_id = self.env.get_episode_id() if isinstance(self.env, RecordableEnvMixin) else str(uuid.uuid4())
            self.episode_event_log = EpisodeEventLog(episode_id)

        # log raw events and init new step log
        self.episode_event_log.step_event_logs.append(step_event_log)

        # update the time of last stats recording
        self.last_env_time = self.env.get_env_time() if isinstance(self.env, TimeEnvMixin) else self.last_env_time + 1

    @override(BaseEnv)
    def reset(self) -> Any:
        """Reset the environment and trigger the episode statistics calculation of the previous run.

        :return: The value returned by the reset of the wrapped environment.
        """
        # Generate the episode stats from the previous rollout if any
        self._calculate_kpis()
        self.episode_stats.reduce()
        self._write_episode_event_log()

        # Initialize recording for the new episode (so we can record events already during env reset)
        self.last_env_time = None
        self.reward_events = EventCollection()

        return self.env.reset()

    @override(BaseEnv)
    def close(self):
        """Close the stats rendering figure if needed and close the wrapped environment."""
        try:
            self.step_stats_renderer.close()
        finally:
            # this method overrides the close of the wrapper stack, so the wrapped env has to be
            # closed explicitly, otherwise its resources would never be released
            self.env.close()

    @override(LogStatsEnv)
    def get_stats(self, level: LogStatsLevel) -> LogStatsAggregator:
        """Implementation of the LogStatsEnv interface, return the statistics aggregator."""
        return self.stats_map[level]

    @override(LogStatsEnv)
    def write_epoch_stats(self):
        """Implementation of the LogStatsEnv interface, call reduce on the epoch aggregator.

        If an episode is still in progress, its KPIs and episode statistics are calculated first, so that
        they are part of the epoch statistics.
        """
        if self.episode_event_log:
            self._calculate_kpis()
            self.episode_stats.reduce()
        self.epoch_stats.reduce()
        # also clears the episode event log
        self._write_episode_event_log()

    @override(LogStatsEnv)
    def get_stats_value(self,
                        event: Callable,
                        level: LogStatsLevel,
                        name: Optional[str] = None) -> LogStatsValue:
        """Implementation of the LogStatsEnv interface, obtain the value from the cached aggregator statistics.

        :param event: The event interface method of the value in question.
        :param level: The statistics level whose aggregator is queried.
        :param name: The output name of the statistics, None if no output name has been specified.
        :return: The value found in the stats cached by the aggregator of the given level.
        """
        # resolve the aggregator via the level instead of unconditionally reading the epoch stats
        return self.stats_map[level].last_stats[(event, name, None)]

    @override(LogStatsEnv)
    def clear_epoch_stats(self) -> None:
        """Implementation of the LogStatsEnv interface, clear out episode statistics collected so far in this epoch."""
        self.epoch_stats.clear_inputs()

    def render_stats(self,
                     event_name: str = "BaseEnvEvents.reward",
                     metric_name: str = "value",
                     aggregation_func: Optional[Union[str, Callable]] = None,
                     group_by: Optional[str] = None,
                     post_processing_func: Optional[Union[str, Callable]] = 'cumsum'):
        """Render statistics from the currently running episode.

        Rendering is based on event logs. You can select arbitrary events from those dispatched by the currently
        running environment.

        :param event_name: Name of the event the event log corresponds to
        :param metric_name: Metric to use (one of the event attributes, e.g. "n_items" -- depends on the event type)
        :param aggregation_func: Optionally, specifies how to aggregate the metric on step level, i.e. when there
                                 are multiple same events dispatched during the same step.
        :param group_by: Optionally, another of event attributes to group by on the step level (e.g. "product_id")
        :param post_processing_func: Optionally, a function to post-process the data ("cumsum" is often used)"""
        self.step_stats_renderer.render_current_episode_stats(
            self.episode_event_log, event_name, metric_name,
            aggregation_func, group_by, post_processing_func)

    def _calculate_kpis(self):
        """Calculate KPIs and append them to both aggregated and logged events."""
        if not isinstance(self.env, EventEnvMixin) or not self.episode_event_log:
            return

        kpi_calculator = self.env.get_kpi_calculator()
        if kpi_calculator is None:
            return

        last_maze_state = self.env.get_maze_state() if isinstance(self.env, RecordableEnvMixin) else None

        kpis_dict = kpi_calculator.calculate_kpis(self.episode_event_log, last_maze_state)
        # the KPI events are attached to the last recorded step of the episode
        last_step_events = self.episode_event_log.step_event_logs[-1].events
        for name, value in kpis_dict.items():
            event_record = EventRecord(BaseEnvEvents, BaseEnvEvents.kpi, dict(name=name, value=value))
            self.episode_stats.add_event(event_record)  # Add the event to episode aggregator
            last_step_events.append(event_record)  # Log the event

    def _write_episode_event_log(self):
        """Send the episode event log to writers and clear it afterwards."""
        if self.episode_event_log:
            LogEventsWriterRegistry.record_event_logs(self.episode_event_log)

        self.episode_event_log = None

    @override(Wrapper)
    def get_observation_and_action_dicts(self, maze_state: Optional[MazeStateType],
                                         maze_action: Optional[MazeActionType],
                                         first_step_in_episode: bool) \
            -> Tuple[Optional[Dict[Union[int, str], Any]], Optional[Dict[Union[int, str], Any]]]:
        """Keep both actions and observation the same."""
        return self.env.get_observation_and_action_dicts(maze_state, maze_action, first_step_in_episode)

    @override(SimulatedEnvMixin)
    def clone_from(self, env: 'LogStatsWrapper') -> None:
        """implementation of :class:`~maze.core.env.simulated_env_mixin.SimulatedEnvMixin`.

        :raises RuntimeError: Always, cloning is not supported by this wrapper.
        """
        raise RuntimeError("Cloning the 'LogStatsWrapper' is not supported.")

    def get_last_step_events(self, query: Optional[Union[Callable, Iterable[Callable]]] = None):
        """Convenience accessor to all events recorded during the last step.

        :param query: Specify which events to return (one or more interface methods)
        :return: Recorded events from the last step (all if no query is present), an empty list if no step
                 has been recorded in the current episode yet
        """
        if not self.episode_event_log or len(self.episode_event_log.step_event_logs) == 0:
            return []

        last_step_log = self.episode_event_log.step_event_logs[-1]
        if query:
            return list(last_step_log.events.query_events(query))
        else:
            return list(last_step_log.events.events)