Ejemplo n.º 1
0
    def test_get_states(self):
        """Check ``get_states`` for a wrapper built with and without a second argument.

        Two wrappers around the same Luby environment are each reset and
        stepped four times. The wrapper built without a second argument must
        return only the recorded states; the wrapper built with ``2`` must
        return the same states plus a list of intervals.
        """
        bench = LubyBenchmark()
        env = bench.get_environment()

        wrapped = StateTrackingWrapper(env)
        wrapped.reset()
        for i in range(4):
            wrapped.step(i)

        # NOTE(review): the second positional argument is presumably the
        # tracking interval size -- confirm against StateTrackingWrapper.
        wrapped2 = StateTrackingWrapper(env, 2)
        wrapped2.reset()
        for i in range(4):
            wrapped2.step(i)

        overall_states_only = wrapped.get_states()
        overall_states, intervals = wrapped2.get_states()

        self.assertTrue(np.array_equal(overall_states, overall_states_only))
        # assertEqual reports the actual value on failure, unlike
        # assertTrue(a == b).
        self.assertEqual(len(overall_states_only), 5)
        self.assertEqual(len(overall_states_only[4]), 6)

        # Three intervals holding 2, 2 and 1 states respectively.
        self.assertEqual([len(interval) for interval in intervals], [2, 2, 1])
Ejemplo n.º 2
0
    def test_step_reset(self):
        """Check the wrapper's bookkeeping across ``reset`` and ``step``.

        After the first reset one overall state is recorded. After one step
        there are two overall states, two current states and no completed
        intervals. A second reset brings the counts to three overall states,
        one current state and one state interval.
        """
        bench = LubyBenchmark()
        env = bench.get_environment()
        wrapped = StateTrackingWrapper(env, 2)

        state = wrapped.reset()
        self.assertGreater(len(state), 1)
        self.assertEqual(len(wrapped.overall_states), 1)

        state, reward, done, _ = wrapped.step(1)
        self.assertGreater(len(state), 1)
        self.assertLessEqual(reward, 0)
        self.assertFalse(done)

        self.assertEqual(len(wrapped.overall_states), 2)
        self.assertEqual(len(wrapped.current_states), 2)
        self.assertEqual(len(wrapped.state_intervals), 0)

        # The returned state is not inspected here; only the bookkeeping is.
        wrapped.reset()
        self.assertEqual(len(wrapped.overall_states), 3)
        self.assertEqual(len(wrapped.current_states), 1)
        self.assertEqual(len(wrapped.state_intervals), 1)
Ejemplo n.º 3
0
    def test_rendering(self):
        """Exercise ``render_state_tracking`` for several observation spaces.

        The default CMA-ES environment and a CMA-ES environment with a
        ``Tuple`` observation space are both expected to raise
        ``NotImplementedError``. Every other configuration (``Box``, the Luby
        benchmark, ``Discrete``, ``MultiDiscrete``, ``MultiBinary``) must
        return an image whose last axis has length 3.
        """

        def assert_renders_image(tracked, n_steps):
            """Reset, take ``n_steps`` steps with action 1, then check the image."""
            tracked.reset()
            for _ in range(n_steps):
                tracked.step(1)
            img = tracked.render_state_tracking()
            # Last axis of length 3 -- presumably the colour channels.
            self.assertEqual(img.shape[-1], 3)

        class StubEnv:
            """Minimal environment that always emits the same observation.

            ``make_observation`` is a zero-argument callable so every call to
            ``reset``/``step`` hands out a fresh observation object.
            """

            def __init__(self, observation_space, make_observation):
                self.observation_space = observation_space
                self.action_space = gym.spaces.Discrete(2)
                self.reward_range = (1, 2)
                self.metadata = {}
                self._make_observation = make_observation

            def reset(self):
                return self._make_observation()

            def step(self, action):
                return self._make_observation(), 1, 1, 1

        # Default CMA-ES environment: rendering is expected to be unsupported.
        bench = CMAESBenchmark()
        env = bench.get_environment()
        wrapped = StateTrackingWrapper(env)
        wrapped.reset()
        with pytest.raises(NotImplementedError):
            wrapped.render_state_tracking()

        # Tuple observation space: rendering is expected to be unsupported.
        bench = CMAESBenchmark()

        def dummy():
            return [1, [2, 3]]

        bench.config.state_method = dummy
        bench.config.observation_space = gym.spaces.Tuple(
            (
                gym.spaces.Discrete(2),
                gym.spaces.Box(low=np.array([-1, 1]), high=np.array([5, 5])),
            )
        )
        env = bench.get_environment()
        wrapped = StateTrackingWrapper(env)
        wrapped.reset()
        with pytest.raises(NotImplementedError):
            wrapped.render_state_tracking()

        # One-dimensional Box observation space (same benchmark object,
        # reconfigured).
        def dummy2():
            return [0.5]

        bench.config.state_method = dummy2
        bench.config.observation_space = gym.spaces.Box(
            low=np.array([0]), high=np.array([1])
        )
        env = bench.get_environment()
        assert_renders_image(StateTrackingWrapper(env), 2)

        # Luby benchmark, wrapper built with second argument 2.
        bench = LubyBenchmark()
        env = bench.get_environment()
        assert_renders_image(StateTrackingWrapper(env, 2), 2)

        # Discrete observation space.
        env = StubEnv(gym.spaces.Discrete(2), lambda: 1)
        assert_renders_image(StateTrackingWrapper(env, 2), 1)

        # MultiDiscrete observation space.
        env = StubEnv(gym.spaces.MultiDiscrete([2, 3]), lambda: [1, 2])
        assert_renders_image(StateTrackingWrapper(env), 1)

        # MultiBinary observation space.
        env = StubEnv(gym.spaces.MultiBinary(2), lambda: [1, 1])
        assert_renders_image(StateTrackingWrapper(env), 1)