Example #1

# Imports inferred from the snippet below; the test-case class name is assumed.
import unittest

import gym
import numpy as np
import pytest

from dacbench.benchmarks import CMAESBenchmark, LubyBenchmark
from dacbench.wrappers import StateTrackingWrapper


class TestStateTrackingWrapper(unittest.TestCase):
    def test_rendering(self):
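        # Rendering tracked states is not implemented for the default CMA-ES observation space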
        bench = CMAESBenchmark()
        env = bench.get_environment()
        wrapped = StateTrackingWrapper(env)
        wrapped.reset()
        with pytest.raises(NotImplementedError):
            wrapped.render_state_tracking()

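        # Tuple observation spaces are not supported for rendering either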
        bench = CMAESBenchmark()

        def dummy():
            return [1, [2, 3]]

        bench.config.state_method = dummy
        bench.config.observation_space = gym.spaces.Tuple(
            (
                gym.spaces.Discrete(2),
                gym.spaces.Box(low=np.array([-1, 1]), high=np.array([5, 5])),
            )
        )
        env = bench.get_environment()
        wrapped = StateTrackingWrapper(env)
        wrapped.reset()
        with pytest.raises(NotImplementedError):
            wrapped.render_state_tracking()

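        # A one-dimensional Box observation space can be rendered; the result is an RGB image array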
        def dummy2():
            return [0.5]

        bench.config.state_method = dummy2
        bench.config.observation_space = gym.spaces.Box(
            low=np.array([0]), high=np.array([1])
        )
        env = bench.get_environment()
        wrapped = StateTrackingWrapper(env)
        wrapped.reset()
        wrapped.step(1)
        wrapped.step(1)
        img = wrapped.render_state_tracking()
        self.assertTrue(img.shape[-1] == 3)

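        # The Luby benchmark with a tracking interval of 2 renders as well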
        bench = LubyBenchmark()
        env = bench.get_environment()
        wrapped = StateTrackingWrapper(env, 2)
        wrapped.reset()
        wrapped.step(1)
        wrapped.step(1)
        img = wrapped.render_state_tracking()
        self.assertTrue(img.shape[-1] == 3)

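        # Minimal dummy environment with a Discrete observation space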
        class discrete_obs_env:
            def __init__(self):
                self.observation_space = gym.spaces.Discrete(2)
                self.action_space = gym.spaces.Discrete(2)
                self.reward_range = (1, 2)
                self.metadata = {}

            def reset(self):
                return 1

            def step(self, action):
                return 1, 1, 1, 1

        env = discrete_obs_env()
        wrapped = StateTrackingWrapper(env, 2)
        wrapped.reset()
        wrapped.step(1)
        img = wrapped.render_state_tracking()
        self.assertTrue(img.shape[-1] == 3)

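        # Dummy environment with a MultiDiscrete observation space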
        class multi_discrete_obs_env:
            def __init__(self):
                self.observation_space = gym.spaces.MultiDiscrete([2, 3])
                self.action_space = gym.spaces.Discrete(2)
                self.reward_range = (1, 2)
                self.metadata = {}

            def reset(self):
                return [1, 2]

            def step(self, action):
                return [1, 2], 1, 1, 1

        env = multi_discrete_obs_env()
        wrapped = StateTrackingWrapper(env)
        wrapped.reset()
        wrapped.step(1)
        img = wrapped.render_state_tracking()
        self.assertTrue(img.shape[-1] == 3)

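        # Dummy environment with a MultiBinary observation space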
        class multi_binary_obs_env:
            def __init__(self):
                self.observation_space = gym.spaces.MultiBinary(2)
                self.action_space = gym.spaces.Discrete(2)
                self.reward_range = (1, 2)
                self.metadata = {}

            def reset(self):
                return [1, 1]

            def step(self, action):
                return [1, 1], 1, 1, 1

        env = multi_binary_obs_env()
        wrapped = StateTrackingWrapper(env)
        wrapped.reset()
        wrapped.step(1)
        img = wrapped.render_state_tracking()
        self.assertTrue(img.shape[-1] == 3)
Example #2
from chainerrl import wrappers
import matplotlib.pyplot as plt
from examples.example_utils import train_chainer, make_chainer_dqn
from dacbench.benchmarks import FastDownwardBenchmark
from dacbench.wrappers import StateTrackingWrapper

# Get FastDownward Environment
bench = FastDownwardBenchmark()
env = bench.get_environment()

# Wrap environment to track state
# In this case we also want the mean of each 5-step interval
env = StateTrackingWrapper(env, 5)

# Chainer requires casting to float32
env = wrappers.CastObservationToFloat32(env)

# Make chainer agent
obs_size = env.observation_space.low.size
agent = make_chainer_dqn(obs_size, env.action_space)

# Train for 10 episodes
train_chainer(agent, env)

# Plot state values after training
env.render_state_tracking()
plt.show()
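
The example above depends on chainerrl and the repository's example utilities. A minimal sketch of the same tracking and rendering loop with random actions, assuming only the dacbench API already used above (LubyBenchmark, StateTrackingWrapper, render_state_tracking), could look like this:

# Hypothetical minimal sketch: random actions instead of a trained agent
import matplotlib.pyplot as plt

from dacbench.benchmarks import LubyBenchmark
from dacbench.wrappers import StateTrackingWrapper

env = StateTrackingWrapper(LubyBenchmark().get_environment(), 5)
env.reset()
done = False
while not done:
    # Random actions are enough to populate the tracked states
    _, _, done, _ = env.step(env.action_space.sample())

env.render_state_tracking()
plt.savefig("state_tracking.png")  # save the figure instead of plt.show()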