def setUp(self):
    """Create and reset a single environment plus a two-copy duplicate environment."""
    env_id = 'LunarLanderContinuous-v2'

    # Single environment used for the non-parallel checks.
    self.env = GymEnvironment(env_id)
    self.env.reset()

    # Two independent copies wrapped for the parallel checks.
    copies = [GymEnvironment(env_id) for _ in range(2)]
    self.parallel_env = DuplicateEnvironment(copies)
    self.parallel_env.reset()
def test_reset(self):
    """Reset a five-environment DuplicateEnvironment and check the initial state."""
    num_envs = 5
    vec_env = DuplicateEnvironment(make_vec_env(num_envs))
    initial = vec_env.reset()

    self.assertEqual(initial.observation.shape, (num_envs, 4))

    # After reset: reward and done are all zero, mask is all one.
    expectations = (
        ('reward', torch.zeros),
        ('done', torch.zeros),
        ('mask', torch.ones),
    )
    for field, make_expected in expectations:
        matches = getattr(initial, field) == make_expected(num_envs)
        self.assertTrue(matches.all())
class TestContinuousPresets(unittest.TestCase):
    """Smoke tests that build each continuous-control preset on CPU and exercise its agents."""

    def setUp(self):
        """Create and reset the single and the two-copy duplicate environment."""
        env_id = 'LunarLanderContinuous-v2'
        self.env = GymEnvironment(env_id)
        self.env.reset()
        copies = [GymEnvironment(env_id) for _ in range(2)]
        self.parallel_env = DuplicateEnvironment(copies)
        self.parallel_env.reset()

    def tearDown(self):
        """Delete the preset file written by the save/load check, if present."""
        if not os.path.exists('test_preset.pt'):
            return
        os.remove('test_preset.pt')

    def test_ddpg(self):
        self.validate(ddpg)

    def test_ppo(self):
        self.validate(ppo)

    def test_sac(self):
        self.validate(sac)

    def validate(self, builder):
        """Build the preset from ``builder`` and run the validator matching its kind."""
        preset = builder.device('cpu').env(self.env).build()
        checker = (
            self.validate_parallel_preset
            if isinstance(preset, ParallelPreset)
            else self.validate_standard_preset
        )
        return checker(preset)

    def validate_standard_preset(self, preset):
        """Act once with the training and test agents, then check save/load."""
        # train agent
        preset.agent(writer=DummyWriter(), train_steps=100000).act(self.env.state)
        # test agent
        preset.test_agent().act(self.env.state)
        # test save/load
        self._validate_save_load(preset)

    def validate_parallel_preset(self, preset):
        """Act once with the training, test and parallel test agents, then check save/load."""
        # train agent
        preset.agent(writer=DummyWriter(), train_steps=100000).act(self.parallel_env.state_array)
        # test agent
        preset.test_agent().act(self.env.state)
        # parallel test_agent
        preset.test_agent().act(self.parallel_env.state_array)
        # test save/load
        self._validate_save_load(preset)

    def _validate_save_load(self, preset):
        """Save ``preset`` to disk, load it back, and act once with the reloaded test agent."""
        preset.save('test_preset.pt')
        reloaded = torch.load('test_preset.pt')
        reloaded.test_agent().act(self.env.state)
def test_step_until_done(self):
    """Step three seeded environments until the first one finishes and check its final state."""
    env_count = 3
    vec_env = DuplicateEnvironment(make_vec_env(env_count))
    vec_env.seed(5)
    vec_env.reset()

    # Apply action 1 in every environment, for at most 100 steps,
    # stopping as soon as the first environment reports done.
    for _ in range(100):
        actions = torch.ones(env_count, dtype=torch.int32)
        state = vec_env.step(actions)
        if state.done[0]:
            break

    first = state[0]
    self.assertEqual(first.observation.shape, (4, ))
    self.assertEqual(first.reward, 1.)
    self.assertTrue(first.done)
    self.assertEqual(first.mask, 0)
def test_same_as_duplicate(self):
    """Check that DuplicateEnvironment and GymVectorEnvironment agree for the same seed and actions."""
    n_envs = 3
    torch.manual_seed(42)
    duplicate_env = DuplicateEnvironment([GymEnvironment('CartPole-v0') for _ in range(n_envs)])
    vector_env = GymVectorEnvironment(make_vec_env(n_envs), "CartPole-v0")
    duplicate_env.seed(42)
    vector_env.seed(42)
    duplicate_state = duplicate_env.reset()
    vector_state = vector_env.reset()

    # Metadata must match between the two wrappers.
    self.assertEqual(duplicate_env.name, vector_env.name)
    self.assertEqual(duplicate_env.action_space.n, vector_env.action_space.n)
    self.assertEqual(duplicate_env.observation_space.shape, vector_env.observation_space.shape)
    self.assertEqual(duplicate_env.num_envs, 3)
    self.assertEqual(vector_env.num_envs, 3)

    action_space = duplicate_env.action_space
    compared_fields = ('observation', 'reward', 'done', 'mask')
    # Each pass compares the current states and then steps both environments,
    # so the states produced by the final step are not compared.
    for _ in range(2):
        for field in compared_fields:
            same = torch.eq(getattr(duplicate_state, field), getattr(vector_state, field))
            self.assertTrue(torch.all(same))
        actions = torch.tensor([action_space.sample() for _ in range(n_envs)])
        duplicate_state = duplicate_env.step(actions)
        vector_state = vector_env.step(actions)
def setUp(self):
    """Create and reset a single Breakout environment plus a two-copy duplicate environment."""
    game = 'Breakout'

    # Single environment used for the non-parallel checks.
    self.env = AtariEnvironment(game)
    self.env.reset()

    # Two independent copies wrapped for the parallel checks.
    copies = [AtariEnvironment(game) for _ in range(2)]
    self.parallel_env = DuplicateEnvironment(copies)
    self.parallel_env.reset()
class TestAtariPresets(unittest.TestCase):
    """Smoke tests that build each Atari preset on CPU and exercise its agents.

    Every ``test_*`` method hands one preset builder to ``validate_preset``,
    which builds it against ``self.env`` and runs the standard or the
    parallel validation depending on the type of the built preset.
    """

    def setUp(self):
        """Create and reset a single Breakout env and a two-copy duplicate env."""
        self.env = AtariEnvironment('Breakout')
        self.env.reset()
        self.parallel_env = DuplicateEnvironment([AtariEnvironment('Breakout'), AtariEnvironment('Breakout')])
        self.parallel_env.reset()

    def tearDown(self):
        """Delete the preset file written by the save/load check, if present."""
        if os.path.exists('test_preset.pt'):
            os.remove('test_preset.pt')

    def test_a2c(self):
        self.validate_preset(a2c)

    def test_c51(self):
        self.validate_preset(c51)

    def test_ddqn(self):
        self.validate_preset(ddqn)

    def test_dqn(self):
        self.validate_preset(dqn)

    def test_ppo(self):
        self.validate_preset(ppo)

    def test_rainbow(self):
        self.validate_preset(rainbow)

    def test_vac(self):
        self.validate_preset(vac)

    def test_vpg(self):
        # Named after the builder it validates (was misspelled ``test_vpq``).
        self.validate_preset(vpg)

    def test_vsarsa(self):
        self.validate_preset(vsarsa)

    def test_vqn(self):
        self.validate_preset(vqn)

    def validate_preset(self, builder):
        """Build the preset from ``builder`` and run the validator matching its kind.

        The builder is configured for the CPU device and ``self.env`` before
        ``build()`` is called. Presets that are instances of ``ParallelPreset``
        go through ``validate_parallel_preset``; all others go through
        ``validate_standard_preset``.
        """
        preset = builder.device('cpu').env(self.env).build()
        if isinstance(preset, ParallelPreset):
            return self.validate_parallel_preset(preset)
        return self.validate_standard_preset(preset)

    def validate_standard_preset(self, preset):
        """Act once with the training and test agents, then check save/load."""
        # train agent
        agent = preset.agent(writer=DummyWriter(), train_steps=100000)
        agent.act(self.env.state)
        # test agent
        test_agent = preset.test_agent()
        test_agent.act(self.env.state)
        # test save/load: the reloaded preset must still produce a usable test agent
        preset.save('test_preset.pt')
        preset = torch.load('test_preset.pt')
        test_agent = preset.test_agent()
        test_agent.act(self.env.state)

    def validate_parallel_preset(self, preset):
        """Act once with the training, test and parallel test agents, then check save/load."""
        # train agent (acts on the state array of the duplicate environment)
        agent = preset.agent(writer=DummyWriter(), train_steps=100000)
        agent.act(self.parallel_env.state_array)
        # test agent (acts on the single-environment state)
        test_agent = preset.test_agent()
        test_agent.act(self.env.state)
        # parallel test_agent
        parallel_test_agent = preset.test_agent()
        parallel_test_agent.act(self.parallel_env.state_array)
        # test save/load: the reloaded preset must still produce a usable test agent
        preset.save('test_preset.pt')
        preset = torch.load('test_preset.pt')
        test_agent = preset.test_agent()
        test_agent.act(self.env.state)
def test_num_envs(self):
    """Check that a five-environment wrapper reports five envs and resets to a state of that shape."""
    expected_count = 5
    vec_env = DuplicateEnvironment(make_vec_env(expected_count))
    self.assertEqual(vec_env.num_envs, expected_count)

    initial = vec_env.reset()
    self.assertEqual((expected_count, ), initial.shape)
def test_env_name(self):
    """Check that the duplicate environment reports the name 'CartPole-v0'."""
    wrapped = make_vec_env()
    vec_env = DuplicateEnvironment(wrapped)
    self.assertEqual(vec_env.name, 'CartPole-v0')
def setUp(self):
    """Create and reset a single CartPole environment plus a two-copy duplicate environment."""
    env_id = 'CartPole-v0'

    # Single environment used for the non-parallel checks.
    self.env = GymEnvironment(env_id)
    self.env.reset()

    # Two independent copies wrapped for the parallel checks.
    copies = [GymEnvironment(env_id) for _ in range(2)]
    self.parallel_env = DuplicateEnvironment(copies)
    self.parallel_env.reset()