コード例 #1
0
 def setUp(self):
     """Create and reset a single LunarLanderContinuous env and a two-copy DuplicateEnvironment."""
     env_id = 'LunarLanderContinuous-v2'
     self.env = GymEnvironment(env_id)
     self.env.reset()
     # Two independent copies of the same env wrapped as one batched environment.
     self.parallel_env = DuplicateEnvironment(
         [GymEnvironment(env_id) for _ in range(2)]
     )
     self.parallel_env.reset()
コード例 #2
0
 def test_reset(self):
     """reset() on a DuplicateEnvironment returns one batched initial state.

     Checks the observation batch shape, and that every copy starts with zero
     reward, not done, and mask 1.
     """
     count = 5
     duplicate = DuplicateEnvironment(make_vec_env(count))
     initial = duplicate.reset()
     # 4 = per-env observation size of make_vec_env's envs (presumably CartPole).
     self.assertEqual(initial.observation.shape, (count, 4))
     self.assertTrue((initial.reward == torch.zeros(count)).all())
     self.assertTrue((initial.done == torch.zeros(count)).all())
     self.assertTrue((initial.mask == torch.ones(count)).all())
コード例 #3
0
class TestContinuousPresets(unittest.TestCase):
    def setUp(self):
        self.env = GymEnvironment('LunarLanderContinuous-v2')
        self.env.reset()
        self.parallel_env = DuplicateEnvironment([
            GymEnvironment('LunarLanderContinuous-v2'),
            GymEnvironment('LunarLanderContinuous-v2'),
        ])
        self.parallel_env.reset()

    def tearDown(self):
        if os.path.exists('test_preset.pt'):
            os.remove('test_preset.pt')

    def test_ddpg(self):
        self.validate(ddpg)

    def test_ppo(self):
        self.validate(ppo)

    def test_sac(self):
        self.validate(sac)

    def validate(self, builder):
        preset = builder.device('cpu').env(self.env).build()
        if isinstance(preset, ParallelPreset):
            return self.validate_parallel_preset(preset)
        return self.validate_standard_preset(preset)

    def validate_standard_preset(self, preset):
        # train agent
        agent = preset.agent(writer=DummyWriter(), train_steps=100000)
        agent.act(self.env.state)
        # test agent
        test_agent = preset.test_agent()
        test_agent.act(self.env.state)
        # test save/load
        preset.save('test_preset.pt')
        preset = torch.load('test_preset.pt')
        test_agent = preset.test_agent()
        test_agent.act(self.env.state)

    def validate_parallel_preset(self, preset):
        # train agent
        agent = preset.agent(writer=DummyWriter(), train_steps=100000)
        agent.act(self.parallel_env.state_array)
        # test agent
        test_agent = preset.test_agent()
        test_agent.act(self.env.state)
        # parallel test_agent
        parallel_test_agent = preset.test_agent()
        parallel_test_agent.act(self.parallel_env.state_array)
        # test save/load
        preset.save('test_preset.pt')
        preset = torch.load('test_preset.pt')
        test_agent = preset.test_agent()
        test_agent.act(self.env.state)
コード例 #4
0
 def test_step_until_done(self):
     """Step three seeded copies with action 1 until copy 0 reports done, then inspect copy 0.

     At most 100 steps are taken. The final state of copy 0 must have a
     4-element observation, reward 1, done set and mask 0.
     """
     count = 3
     duplicate = DuplicateEnvironment(make_vec_env(count))
     duplicate.seed(5)
     duplicate.reset()
     for _ in range(100):
         state = duplicate.step(torch.ones(count, dtype=torch.int32))
         if state.done[0]:
             break
     first = state[0]
     self.assertEqual(first.observation.shape, (4, ))
     self.assertEqual(first.reward, 1.)
     self.assertTrue(first.done)
     self.assertEqual(first.mask, 0)
コード例 #5
0
 def test_same_as_duplicate(self):
     """Seeded DuplicateEnvironment and GymVectorEnvironment must behave identically.

     Both wrap ``n_envs`` copies of CartPole-v0 and are seeded with 42. Their
     name, action space size, observation shape and env count must match. The
     reset state and the state after every step must be element-wise equal when
     both are driven by the same sampled actions.
     """
     n_envs = 3
     n_steps = 2
     torch.manual_seed(42)
     env1 = DuplicateEnvironment([GymEnvironment('CartPole-v0') for _ in range(n_envs)])
     env2 = GymVectorEnvironment(make_vec_env(n_envs), "CartPole-v0")
     env1.seed(42)
     env2.seed(42)
     state1 = env1.reset()
     state2 = env2.reset()
     self.assertEqual(env1.name, env2.name)
     self.assertEqual(env1.action_space.n, env2.action_space.n)
     self.assertEqual(env1.observation_space.shape, env2.observation_space.shape)
     # Compare against n_envs instead of a literal so the two cannot drift apart.
     self.assertEqual(env1.num_envs, n_envs)
     self.assertEqual(env2.num_envs, n_envs)
     act_space = env1.action_space

     def assert_same_state(s1, s2):
         # Every batched field must agree element-wise between the two wrappers.
         for field in ('observation', 'reward', 'done', 'mask'):
             self.assertTrue(
                 torch.all(torch.eq(getattr(s1, field), getattr(s2, field))),
                 field + ' differs',
             )

     assert_same_state(state1, state2)
     for _ in range(n_steps):
         actions = torch.tensor([act_space.sample() for _ in range(n_envs)])
         state1 = env1.step(actions)
         state2 = env2.step(actions)
         # Check after every step, including the last one, so no step goes unverified.
         assert_same_state(state1, state2)
コード例 #6
0
 def setUp(self):
     """Create and reset a single Breakout env and a two-copy DuplicateEnvironment."""
     game = 'Breakout'
     self.env = AtariEnvironment(game)
     self.env.reset()
     copies = [AtariEnvironment(game) for _ in range(2)]
     self.parallel_env = DuplicateEnvironment(copies)
     self.parallel_env.reset()
コード例 #7
0
class TestAtariPresets(unittest.TestCase):
    def setUp(self):
        self.env = AtariEnvironment('Breakout')
        self.env.reset()
        self.parallel_env = DuplicateEnvironment([AtariEnvironment('Breakout'), AtariEnvironment('Breakout')])
        self.parallel_env.reset()

    def tearDown(self):
        if os.path.exists('test_preset.pt'):
            os.remove('test_preset.pt')

    def test_a2c(self):
        self.validate_preset(a2c)

    def test_c51(self):
        self.validate_preset(c51)

    def test_ddqn(self):
        self.validate_preset(ddqn)

    def test_dqn(self):
        self.validate_preset(dqn)

    def test_ppo(self):
        self.validate_preset(ppo)

    def test_rainbow(self):
        self.validate_preset(rainbow)

    def test_vac(self):
        self.validate_preset(vac)

    def test_vpq(self):
        self.validate_preset(vpg)

    def test_vsarsa(self):
        self.validate_preset(vsarsa)

    def test_vqn(self):
        self.validate_preset(vqn)

    def validate_preset(self, builder):
        preset = builder.device('cpu').env(self.env).build()
        if isinstance(preset, ParallelPreset):
            return self.validate_parallel_preset(preset)
        return self.validate_standard_preset(preset)

    def validate_standard_preset(self, preset):
        # train agent
        agent = preset.agent(writer=DummyWriter(), train_steps=100000)
        agent.act(self.env.state)
        # test agent
        test_agent = preset.test_agent()
        test_agent.act(self.env.state)
        # test save/load
        preset.save('test_preset.pt')
        preset = torch.load('test_preset.pt')
        test_agent = preset.test_agent()
        test_agent.act(self.env.state)

    def validate_parallel_preset(self, preset):
        # train agent
        agent = preset.agent(writer=DummyWriter(), train_steps=100000)
        agent.act(self.parallel_env.state_array)
        # test agent
        test_agent = preset.test_agent()
        test_agent.act(self.env.state)
        # parallel test_agent
        parallel_test_agent = preset.test_agent()
        parallel_test_agent.act(self.parallel_env.state_array)
        # test save/load
        preset.save('test_preset.pt')
        preset = torch.load('test_preset.pt')
        test_agent = preset.test_agent()
        test_agent.act(self.env.state)
コード例 #8
0
 def test_num_envs(self):
     """num_envs reports the number of wrapped envs, and reset() yields a state batch of that size."""
     count = 5
     duplicate = DuplicateEnvironment(make_vec_env(count))
     self.assertEqual(duplicate.num_envs, count)
     initial = duplicate.reset()
     self.assertEqual((count, ), initial.shape)
コード例 #9
0
 def test_env_name(self):
     """Wrapping make_vec_env()'s default envs must yield the name 'CartPole-v0'."""
     wrapped = make_vec_env()
     duplicate = DuplicateEnvironment(wrapped)
     self.assertEqual(duplicate.name, 'CartPole-v0')
コード例 #10
0
 def setUp(self):
     """Create and reset a single CartPole env and a two-copy DuplicateEnvironment."""
     env_id = 'CartPole-v0'
     self.env = GymEnvironment(env_id)
     self.env.reset()
     copies = [GymEnvironment(env_id) for _ in range(2)]
     self.parallel_env = DuplicateEnvironment(copies)
     self.parallel_env.reset()