Example #1
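    # get_environment() should return a FastDownwardEnv, both for the default
    # configuration and after switching to the childsnack instance set.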
    def test_get_env(self):
        bench = FastDownwardBenchmark()
        env = bench.get_environment()
        self.assertTrue(issubclass(type(env), FastDownwardEnv))

        bench.config.instance_set_path = "../instance_sets/fast_downward/childsnack"
        bench.read_instance_set()
        env = bench.get_environment()
        self.assertTrue(issubclass(type(env), FastDownwardEnv))
Example #2
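    # read_instance_set() should load 30 instance paths as strings pointing to
    # existing files, and a freshly constructed benchmark's environment should
    # expose the same instance set.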
    def test_read_instances(self):
        bench = FastDownwardBenchmark()
        bench.read_instance_set()
        self.assertTrue(len(bench.config.instance_set) == 30)
        self.assertTrue(type(bench.config.instance_set[0]) == str)
        self.assertTrue(os.path.isfile(bench.config.instance_set[0]))
        path = bench.config.instance_set[0]
        bench2 = FastDownwardBenchmark()
        env = bench2.get_environment()
        self.assertTrue(type(env.instance_set[0]) == str)
        self.assertTrue(len(env.instance_set) == 30)
        self.assertTrue(path == env.instance_set[0])
Example #3
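    # Every shipped Fast Downward scenario config should load, build an
    # environment, and survive a reset and a single step.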
    def test_scenarios(self):
        scenarios = [
            "fd_barman.json",
            "fd_blocksworld.json",
            "fd_visitall.json",
            "fd_childsnack.json",
            "fd_sokoban.json",
            "fd_rovers.json",
        ]
        for s in scenarios:
            path = os.path.join("dacbench/scenarios/fast_downward/", s)
            bench = FastDownwardBenchmark(path)
            self.assertTrue(bench.config is not None)
            env = bench.get_environment()
            state = env.reset()
            self.assertTrue(state is not None)
            state, _, _, _ = env.step(0)
            self.assertTrue(state is not None)
Example #4
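    # Action tracking should render to an image with three color channels for
    # Discrete, Box, MultiDiscrete and MultiBinary action spaces, and raise
    # NotImplementedError for Dict and Tuple spaces.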
    def test_rendering(self):
        bench = FastDownwardBenchmark()
        env = bench.get_environment()
        wrapped = ActionFrequencyWrapper(env, 2)
        wrapped.reset()
        for _ in range(10):
            wrapped.step(1)
        img = wrapped.render_action_tracking()
        self.assertTrue(img.shape[-1] == 3)

        bench = CMAESBenchmark()
        env = bench.get_environment()
        wrapped = ActionFrequencyWrapper(env, 2)
        wrapped.reset()
        wrapped.step(np.ones(10))
        img = wrapped.render_action_tracking()
        self.assertTrue(img.shape[-1] == 3)

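        # Minimal stub environment with a Dict action space: rendering the
        # action tracking is expected to raise NotImplementedError.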
        class dict_action_env:
            def __init__(self):
                self.action_space = gym.spaces.Dict({
                    "one":
                    gym.spaces.Discrete(2),
                    "two":
                    gym.spaces.Box(low=np.array([-1, 1]),
                                   high=np.array([1, 5])),
                })
                self.observation_space = gym.spaces.Discrete(2)
                self.reward_range = (1, 2)
                self.metadata = {}

            def reset(self):
                return 1

            def step(self, action):
                return 1, 1, 1, 1

        env = dict_action_env()
        wrapped = ActionFrequencyWrapper(env)
        wrapped.reset()
        with self.assertRaises(NotImplementedError):
            wrapped.render_action_tracking()

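        # Stub environment with a Tuple action space, also expected to raise
        # NotImplementedError when rendering.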
        class tuple_action_env:
            def __init__(self):
                self.action_space = gym.spaces.Tuple((
                    gym.spaces.Discrete(2),
                    gym.spaces.Box(low=np.array([-1, 1]),
                                   high=np.array([1, 5])),
                ))
                self.observation_space = gym.spaces.Discrete(2)
                self.reward_range = (1, 2)
                self.metadata = {}

            def reset(self):
                return 1

            def step(self, action):
                return 1, 1, 1, 1

        env = tuple_action_env()
        wrapped = ActionFrequencyWrapper(env)
        wrapped.reset()
        with self.assertRaises(NotImplementedError):
            wrapped.render_action_tracking()

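        # Stub environment with a MultiDiscrete action space, which the
        # wrapper can render.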
        class multi_discrete_action_env:
            def __init__(self):
                self.action_space = gym.spaces.MultiDiscrete([2, 3])
                self.observation_space = gym.spaces.Discrete(2)
                self.reward_range = (1, 2)
                self.metadata = {}

            def reset(self):
                return 1

            def step(self, action):
                return 1, 1, 1, 1

        env = multi_discrete_action_env()
        wrapped = ActionFrequencyWrapper(env, 5)
        wrapped.reset()
        for _ in range(10):
            wrapped.step([1, 2])
        img = wrapped.render_action_tracking()
        self.assertTrue(img.shape[-1] == 3)

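        # Stub environment with a MultiBinary action space, which the
        # wrapper can render.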
        class multi_binary_action_env:
            def __init__(self):
                self.action_space = gym.spaces.MultiBinary(2)
                self.observation_space = gym.spaces.Discrete(2)
                self.reward_range = (1, 2)
                self.metadata = {}

            def reset(self):
                return 1

            def step(self, action):
                return 1, 1, 1, 1

        env = multi_binary_action_env()
        wrapped = ActionFrequencyWrapper(env)
        wrapped.reset()
        wrapped.step([1, 0])
        img = wrapped.render_action_tracking()
        self.assertTrue(img.shape[-1] == 3)

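        # Stub environment with a 15-dimensional Box action space to check
        # rendering of larger continuous actions.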
        class large_action_env:
            def __init__(self):
                self.action_space = gym.spaces.Box(low=np.zeros(15),
                                                   high=np.ones(15))
                self.observation_space = gym.spaces.Discrete(2)
                self.reward_range = (1, 2)
                self.metadata = {}

            def reset(self):
                return 1

            def step(self, action):
                return 1, 1, 1, 1

        env = large_action_env()
        wrapped = ActionFrequencyWrapper(env)
        wrapped.reset()
        wrapped.step(0.5 * np.ones(15))
        img = wrapped.render_action_tracking()
        self.assertTrue(img.shape[-1] == 3)
Example #5
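# Example script: train a chainerrl DQN agent on the FastDownward benchmark
# while tracking the states seen during training, then plot them.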
from chainerrl import wrappers
import matplotlib.pyplot as plt
from examples.example_utils import train_chainer, make_chainer_dqn
from dacbench.benchmarks import FastDownwardBenchmark
from dacbench.wrappers import StateTrackingWrapper

# Get FastDownward Environment
bench = FastDownwardBenchmark()
env = bench.get_environment()

# Wrap environment to track state
# In this case we also want the mean of each 5-step interval
env = StateTrackingWrapper(env, 5)

# Chainer requires casting to float32
env = wrappers.CastObservationToFloat32(env)

# Make chainer agent
obs_size = env.observation_space.low.size
agent = make_chainer_dqn(obs_size, env.action_space)

# Train for 10 episodes
train_chainer(agent, env)

# Plot state values after training
env.render_state_tracking()
plt.show()
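
# Note: plt.show() needs an interactive matplotlib backend; in a headless
# setup one could save the current figure instead, e.g. with plt.savefig(...).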