Esempio n. 1
0
def test_return_normalization_wrapper():
    """Check that ReturnNormalizationRewardWrapper hands back a plain Python float reward.

    The reward must be a ``float`` and must not carry a ``shape`` attribute, which rules out
    array-like rewards such as NumPy scalars.
    """
    obs_conversion = ObservationConversion()
    core_env = DummyCoreEnvironment(obs_conversion.space())
    base_env = DummyEnvironment(
        core_env=core_env,
        action_conversion=[DictActionConversion()],
        observation_conversion=[obs_conversion],
    )

    normalized_env = ReturnNormalizationRewardWrapper(base_env, gamma=0.99, epsilon=1e-8)
    normalized_env.reset()

    random_action = normalized_env.action_space.sample()
    step_result = normalized_env.step(random_action)
    step_reward = step_result[1]

    assert isinstance(step_reward, float)
    # NumPy scalars subclass float in some cases but expose `.shape`; reject those here.
    assert not hasattr(step_reward, 'shape')


def test_reward_scaling_wrapper():
    """Check that RewardScalingWrapper multiplies the environment reward by ``scale``.

    The same action is stepped once through the bare env and once through the wrapped env.
    NumPy's global RNG is re-seeded before each step, presumably so any randomness inside the
    step is reproduced identically. Rewards are compared with a tolerance because scaling by
    a non-power-of-two factor such as 0.1 is not exact in floating point.
    """
    scale = 0.1
    observation_conversion = ObservationConversion()

    env = DummyEnvironment(
        core_env=DummyCoreEnvironment(observation_conversion.space()),
        action_conversion=[DictActionConversion()],
        observation_conversion=[observation_conversion],
    )

    env.reset()
    action = env.action_space.sample()
    np.random.seed(1234)
    original_reward = env.step(action)[1]

    wrapped_env = RewardScalingWrapper(env, scale=scale)
    np.random.seed(1234)
    wrapped_reward = wrapped_env.step(action)[1]

    # Exact `==` can fail spuriously, e.g. 3 * 0.1 * 10 == 3.0000000000000004.
    assert np.isclose(wrapped_reward, original_reward * scale)


def test_skipping_wrapper_and_reward_aggregation():
    """Step skipping unit test.

    Wraps the dummy env in ``StepSkipWrapper`` with ``n_steps`` skipped steps in 'sticky'
    mode and checks that every wrapped step returns the per-step dummy reward aggregated
    over all ``n_steps`` underlying steps.
    """
    # Reward this test expects the dummy core env to emit for a single underlying step.
    reward_per_core_step = 10
    n_steps = 3

    observation_conversion = ObservationConversion()
    env = DummyEnvironment(
        core_env=DummyCoreEnvironment(observation_conversion.space()),
        action_conversion=[DictActionConversion()],
        observation_conversion=[observation_conversion],
    )
    env = StepSkipWrapper.wrap(env, n_steps=n_steps, skip_mode='sticky')

    env.reset()
    for _ in range(4):
        action = env.action_space.sample()
        _, reward, _, _ = env.step(action)

        assert reward == n_steps * reward_per_core_step


def test_records_maze_states_and_actions():
    """Record trajectories of a dummy env and verify the written episode records.

    A mock ``TrajectoryWriter`` is registered as the only writer. It receives the episode
    records produced by ``TrajectoryRecordingWrapper`` and checks the types of the recorded
    states and actions. Afterwards, the maze states observed while stepping the first episode
    are compared key by key against the states stored in that episode's step records.
    """
    n_episodes = 5
    n_steps_per_episode = 10

    class CustomDummyRewardAggregator(RewardAggregator):
        """Customized dummy reward aggregator subscribed to BaseEnvEvents."""

        def get_interfaces(self) -> List[Type[ABC]]:
            """
            Return events class is subscribed to.
            """
            additional_interfaces: List[Type[ABC]] = [BaseEnvEvents]
            parent_interfaces = super().get_interfaces()
            return additional_interfaces + parent_interfaces

    class CustomDummyCoreEnv(DummyCoreEnvironment):
        """
        Customized dummy core env with serializable components that regenerates state only in step.
        """

        def __init__(self, observation_space):
            super().__init__(observation_space)
            self.reward_aggregator = CustomDummyRewardAggregator()
            self.maze_state = self.observation_space.sample()
            self.pubsub: Pubsub = Pubsub(self.context.event_service)
            self.pubsub.register_subscriber(self.reward_aggregator)
            self.base_event_publisher = self.pubsub.create_event_topic(BaseEnvEvents)
            self.renderer = DummyRenderer()

        def get_renderer(self) -> DummyRenderer:
            """
            Returns DummyRenderer.
            :return: DummyRenderer.
            """
            return self.renderer

        def step(self, maze_action):
            """
            Steps through the environment.
            """
            # Replace (not mutate) the state so previously recorded states stay untouched.
            self.maze_state = self.observation_space.sample()
            # Publish an event so every step record has a non-empty event log.
            self.base_event_publisher.reward(10)
            return super().step(maze_action)

        def get_maze_state(self):
            """
            Returns current state.
            """
            return self.maze_state

        def get_serializable_components(self) -> Dict[str, Any]:
            """
            Returns minimal dict. with components to serialize.
            """
            return {"value": 0}

    class TestWriter(TrajectoryWriter):
        """Mock writer checking the recorded data"""

        def __init__(self):
            self.episode_count = 0
            self.step_count = 0
            self.episode_records = []

        def write(self, episode_record: StateTrajectoryRecord):
            """Count steps and episodes & check instance types"""
            self.episode_records.append(episode_record)
            self.episode_count += 1
            self.step_count += len(episode_record.step_records)

            for step_record in episode_record.step_records[:-1]:
                assert isinstance(step_record.maze_state, dict)
                assert isinstance(step_record.maze_action, dict)
                assert step_record.serializable_components != {}
                assert len(step_record.step_event_log.events) > 0

            # The last record holds only the final state, without an action.
            final_state_record = episode_record.step_records[-1]
            assert isinstance(final_state_record.maze_state, dict)
            assert final_state_record.maze_action is None
            assert final_state_record.serializable_components != {}

            assert isinstance(episode_record.renderer, Renderer)

    writer = TestWriter()
    # The writer registry is global state: keep the previous writers so they can be restored
    # and this mock writer does not leak into other tests.
    previous_writers = TrajectoryWriterRegistry.writers
    TrajectoryWriterRegistry.writers = []  # Ensure there is no other writer
    TrajectoryWriterRegistry.register_writer(writer)

    try:
        observation_conversion = ObservationConversion()
        env = DummyEnvironment(
            core_env=CustomDummyCoreEnv(observation_conversion.space()),
            action_conversion=[DictActionConversion()],
            observation_conversion=[observation_conversion]
        )
        env = TrajectoryRecordingWrapper.wrap(env)

        policy = DummyGreedyPolicy()
        states = []  # Observe changes in states over time.

        for _ in range(n_episodes):
            obs = env.reset()
            for _ in range(n_steps_per_episode):
                maze_state = env.get_maze_state()
                states.append(deepcopy(maze_state))
                obs, _, _, _ = env.step(
                    policy.compute_action(observation=obs, maze_state=maze_state, deterministic=True))

        # final env reset required
        env.reset()
    finally:
        TrajectoryWriterRegistry.writers = previous_writers

    assert writer.step_count == n_episodes * (n_steps_per_episode + 1)  # Count also the recorded final state
    assert writer.episode_count == n_episodes

    # Compare if the recorded state changes from the first episode match with the trajectory records.
    # NOTE: `np.all(<generator>)` would be vacuously truthy, so each key is asserted explicitly.
    first_episode_step_records = writer.episode_records[0].step_records
    for step_id in range(n_steps_per_episode):
        recorded_state = first_episode_step_records[step_id].maze_state
        for key in env.observation_conversion.space().spaces:
            assert np.array_equal(states[step_id][key], recorded_state[key]), \
                f"Recorded state mismatch at step {step_id}, key {key!r}"