def test_step(self):
    """A single ``step(1)`` right after ``reset()`` returns a live, zero-reward state.

    Checks observation shape, reward, ``done``, ``mask`` and the
    ``'life_lost'`` entry of the returned state.
    """
    environment = AtariEnvironment('Breakout')
    environment.reset()
    result = environment.step(1)
    self.assertEqual(result.observation.shape, (1, 84, 84))
    self.assertEqual(result.reward, 0)
    self.assertFalse(result.done)
    self.assertEqual(result.mask, 1)
    self.assertEqual(result['life_lost'], False)
def test_step_until_done(self):
    """Repeat ``step(1)`` until the episode reports ``done`` (at most 1000 steps).

    The last state returned is then checked: observation shape, zero reward,
    ``done`` set, ``mask`` of 0 and ``'life_lost'`` equal to False.
    """
    environment = AtariEnvironment('Breakout')
    environment.reset()
    steps_left = 1000
    while steps_left > 0:
        last = environment.step(1)
        steps_left -= 1
        if last.done:
            break
    self.assertEqual(last.observation.shape, (1, 84, 84))
    self.assertEqual(last.reward, 0)
    self.assertTrue(last.done)
    self.assertEqual(last.mask, 0)
    self.assertEqual(last['life_lost'], False)
class TestAtariPresets(unittest.TestCase):
    """Smoke tests that each Atari preset builds, acts, and survives save/load.

    Every ``test_<builder>`` method hands one preset builder to
    :meth:`validate_preset`.
    """

    # File written by the save/load round trip; removed again in tearDown.
    PRESET_FILE = 'test_preset.pt'

    def setUp(self):
        # Fresh, already-reset Breakout environment for every test.
        self.env = AtariEnvironment('Breakout')
        self.env.reset()

    def tearDown(self):
        # Clean up the saved preset if the test got far enough to write it.
        try:
            os.remove(self.PRESET_FILE)
        except FileNotFoundError:
            pass

    def test_a2c(self):
        self.validate_preset(a2c)

    def test_c51(self):
        self.validate_preset(c51)

    def test_ddqn(self):
        self.validate_preset(ddqn)

    def test_dqn(self):
        self.validate_preset(dqn)

    def test_ppo(self):
        self.validate_preset(ppo)

    def test_rainbow(self):
        self.validate_preset(rainbow)

    def test_vac(self):
        self.validate_preset(vac)

    def test_vpg(self):
        # Previously misnamed ``test_vpq``; it validates the ``vpg`` builder.
        self.validate_preset(vpg)

    def test_vsarsa(self):
        self.validate_preset(vsarsa)

    def test_vqn(self):
        self.validate_preset(vqn)

    def validate_preset(self, builder):
        """Build ``builder`` on CPU for ``self.env`` and exercise the preset.

        Checks that a training agent and a test agent can each act on the
        current state, and that a test agent can still act after the preset
        is saved with ``preset.save`` and reloaded with ``torch.load``.
        """
        preset = builder.device('cpu').env(self.env).build()
        # training-mode agent
        agent = preset.agent(writer=DummyWriter(), train_steps=100000)
        agent.act(self.env.state)
        # test-mode agent
        test_agent = preset.test_agent()
        test_agent.act(self.env.state)
        # save/load round trip
        preset.save(self.PRESET_FILE)
        # NOTE(review): newer torch releases default torch.load to
        # weights_only=True, which may refuse a pickled preset object —
        # confirm against the torch version in use.
        preset = torch.load(self.PRESET_FILE)
        test_agent = preset.test_agent()
        test_agent.act(self.env.state)
def test_reset(self):
    """``reset()`` returns a live, zero-reward state with the expected observation shape."""
    environment = AtariEnvironment('Breakout')
    initial = environment.reset()
    self.assertEqual(initial.observation.shape, (1, 84, 84))
    self.assertEqual(initial.reward, 0)
    self.assertFalse(initial.done)
    self.assertEqual(initial.mask, 1)
def test_runs(self):
    """Smoke test: drive four Breakout envs through ParallelAtariBody for 200 rounds.

    There are no assertions; the test passes as long as nothing raises.
    NumPy and torch RNGs are seeded first.
    """
    np.random.seed(0)
    torch.random.manual_seed(0)
    n = 4

    def make_env():
        # Create and reset one environment before the next is created,
        # matching the original create/reset interleaving.
        fresh = AtariEnvironment('Breakout')
        fresh.reset()
        return fresh

    envs = [make_env() for _ in range(n)]
    agent = MockAgent(n, max_action=4)
    body = ParallelAtariBody(agent, envs, noop_max=30)
    for _ in range(200):
        observed = [e.state for e in envs]
        rewards = torch.tensor([e.reward for e in envs]).float()
        actions = body.act(observed, rewards)
        for idx, target in enumerate(envs):
            # Environments whose action is None are not stepped this round.
            if actions[idx] is None:
                continue
            target.step(actions[idx])
def test_rainbow_model_cpu(self):
    """Compare one CPU forward pass of ``nature_rainbow`` with recorded outputs.

    The input is the raw state of a freshly reset Breakout environment,
    repeated four times along dim 1 and cast to float. The output is compared
    with the recorded values to 3 decimal places.
    """
    # NOTE(review): the recorded values depend on the network's weight
    # initialisation, so this presumably relies on a seed set outside this
    # method (e.g. in setUp) — confirm.
    environment = AtariEnvironment('Breakout')
    net = nature_rainbow(environment)
    environment.reset()
    frame = environment.state.raw
    stacked = torch.cat([frame, frame, frame, frame], dim=1).float()
    predicted = net(stacked)
    expected = torch.tensor([[
        0.0676, -0.0235, 0.0690, -0.0713, -0.0287, 0.0053, -0.0463, 0.0495, -0.0222, -0.0504,
        0.0064, -0.0204, 0.0168, 0.0127, -0.0113, -0.0586, -0.0544, 0.0114, -0.0077, 0.0666,
        -0.0663, -0.0420, -0.0698, -0.0314, 0.0272, 0.0361, -0.0537, 0.0301, 0.0036, -0.0472,
        -0.0499, 0.0114, 0.0182, 0.0008, -0.0132, -0.0803, -0.0087, -0.0017, 0.0598, -0.0627,
        0.0859, 0.0117, 0.0105, 0.0309, -0.0370, -0.0111, -0.0262, 0.0338, 0.0141, -0.0385,
        0.0547, 0.0648, -0.0370, 0.0107, -0.0629, -0.0163, 0.0282, -0.0670, 0.0161, -0.0244,
        -0.0030, 0.0038, -0.0208, 0.0005, 0.0125, 0.0608, -0.0089, 0.0026, 0.0562, -0.0678,
        0.0841, -0.0265, -0.0461, -0.0124, 0.0276, 0.0364, 0.0195, -0.0309, -0.0337, -0.0603,
        -0.0252, -0.0356, 0.0221, 0.0184, -0.0154, -0.0136, -0.0277, 0.0283, 0.0495, 0.0185,
        -0.0357, 0.0305, -0.0052, -0.0432, -0.0135, -0.0554, -0.0094, 0.0272, 0.1030, 0.0049,
        0.0012, -0.0140, 0.0146, -0.0979, 0.0487, 0.0122, -0.0204, 0.0496, -0.0055, -0.0015,
        -0.0170, 0.0053, 0.0104, -0.0742, 0.0742, -0.0381, 0.0104, -0.0065, -0.0564, 0.0453,
        -0.0057, -0.0029, -0.0722, 0.0094, -0.0561, 0.0284, 0.0402, 0.0233, -0.0716, -0.0424,
        0.0165, -0.0505, 0.0006, 0.0219, -0.0601, 0.0656, -0.0175, -0.0524, 0.0355, 0.0007,
        -0.0042, -0.0443, 0.0871, -0.0403, -0.0031, 0.0171, -0.0359, -0.0520, -0.0344, 0.0239,
        0.0099, 0.0004, 0.0235, 0.0238, -0.0153, 0.0501, -0.0052, 0.0162, 0.0313, -0.0121,
        0.0009, -0.0366, -0.0628, 0.0386, -0.0671, 0.0480, -0.0595, 0.0568, -0.0604, -0.0540,
        0.0403, -0.0187, 0.0649, 0.0029, -0.0003, 0.0020, -0.0056, 0.0471, -0.0145, -0.0126,
        -0.0395, -0.0455, -0.0437, 0.0056, 0.0331, 0.0004, 0.0127, -0.0022, -0.0502, 0.0362,
        0.0624, -0.0012, -0.0515, 0.0303, -0.0357, -0.0420, 0.0321, -0.0162, 0.0007, -0.0272,
        0.0227, 0.0187, -0.0459, 0.0496
    ]])
    tt.assert_almost_equal(predicted, expected, decimal=3)
class TestAtariPresets(unittest.TestCase):
    """Smoke tests for Atari presets, covering standard and parallel presets.

    Every ``test_<builder>`` method hands one preset builder to
    :meth:`validate_preset`, which dispatches on whether the built preset is a
    ``ParallelPreset``.
    """

    # File written by the save/load round trip; removed again in tearDown.
    PRESET_FILE = 'test_preset.pt'

    def setUp(self):
        # Single reset environment, used by standard presets and test agents.
        self.env = AtariEnvironment('Breakout')
        self.env.reset()
        # Two Breakout environments wrapped in DuplicateEnvironment, used to
        # feed state arrays to parallel agents.
        self.parallel_env = DuplicateEnvironment([AtariEnvironment('Breakout'), AtariEnvironment('Breakout')])
        self.parallel_env.reset()

    def tearDown(self):
        # Clean up the saved preset if the test got far enough to write it.
        try:
            os.remove(self.PRESET_FILE)
        except FileNotFoundError:
            pass

    def test_a2c(self):
        self.validate_preset(a2c)

    def test_c51(self):
        self.validate_preset(c51)

    def test_ddqn(self):
        self.validate_preset(ddqn)

    def test_dqn(self):
        self.validate_preset(dqn)

    def test_ppo(self):
        self.validate_preset(ppo)

    def test_rainbow(self):
        self.validate_preset(rainbow)

    def test_vac(self):
        self.validate_preset(vac)

    def test_vpg(self):
        # Previously misnamed ``test_vpq``; it validates the ``vpg`` builder.
        self.validate_preset(vpg)

    def test_vsarsa(self):
        self.validate_preset(vsarsa)

    def test_vqn(self):
        self.validate_preset(vqn)

    def validate_preset(self, builder):
        """Build ``builder`` on CPU for ``self.env`` and run the matching checks."""
        preset = builder.device('cpu').env(self.env).build()
        if isinstance(preset, ParallelPreset):
            return self.validate_parallel_preset(preset)
        return self.validate_standard_preset(preset)

    def validate_standard_preset(self, preset):
        """Check that training/test agents act and that save/load round-trips."""
        # train agent
        agent = preset.agent(writer=DummyWriter(), train_steps=100000)
        agent.act(self.env.state)
        # test agent
        test_agent = preset.test_agent()
        test_agent.act(self.env.state)
        # test save/load
        preset.save(self.PRESET_FILE)
        # NOTE(review): newer torch releases default torch.load to
        # weights_only=True, which may refuse a pickled preset object —
        # confirm against the torch version in use.
        preset = torch.load(self.PRESET_FILE)
        test_agent = preset.test_agent()
        test_agent.act(self.env.state)

    def validate_parallel_preset(self, preset):
        """Like validate_standard_preset, but the training agent acts on a state array.

        A test agent is also exercised on both the single env state and the
        parallel state array.
        """
        # train agent
        agent = preset.agent(writer=DummyWriter(), train_steps=100000)
        agent.act(self.parallel_env.state_array)
        # test agent
        test_agent = preset.test_agent()
        test_agent.act(self.env.state)
        # parallel test_agent
        # NOTE(review): this builds the agent with test_agent(), the same call
        # as above; if ParallelPreset exposes a dedicated parallel test-agent
        # factory, that is probably what was intended here — confirm.
        parallel_test_agent = preset.test_agent()
        parallel_test_agent.act(self.parallel_env.state_array)
        # test save/load
        preset.save(self.PRESET_FILE)
        preset = torch.load(self.PRESET_FILE)
        test_agent = preset.test_agent()
        test_agent.act(self.env.state)
def test_rainbow_model_cuda(self):
    """Check ``nature_rainbow`` on CUDA against recorded outputs.

    The model is run on the raw state of a freshly reset Breakout environment
    (repeated four times along dim 1). Its output is compared with recorded
    values. The model is then updated with one SGD step (lr=0.01) on
    ``out.sum()`` and compared again. Both comparisons use 3 decimal places.
    Skipped when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        # The .cuda() calls below would raise on a machine without CUDA.
        self.skipTest('CUDA unavailable')
    # NOTE(review): the recorded values depend on the network's weight
    # initialisation, so this presumably relies on a seed set outside this
    # method (e.g. in setUp) — confirm.
    env = AtariEnvironment('Breakout')
    model = nature_rainbow(env).cuda()
    env.reset()
    x = torch.cat([env.state.raw] * 4, dim=1).float().cuda()
    out = model(x)
    tt.assert_almost_equal(
        out.cpu(),
        torch.tensor([[
            -1.4765e-02, -4.0353e-02, -2.1705e-02, -2.2314e-02, 3.6881e-02,
            -1.4175e-02, 1.2442e-02, -6.8713e-03, 2.4970e-02, 2.5681e-02,
            -4.5859e-02, -2.3327e-02, 3.6205e-02, 7.1024e-03, -2.7564e-02,
            2.1592e-02, -3.2728e-02, 1.3602e-02, -1.1690e-02, -4.3082e-02,
            -1.2996e-02, 1.7184e-02, 1.3446e-02, -3.3587e-03, -4.6350e-02,
            -1.7646e-02, 2.1954e-02, 8.5546e-03, -2.1359e-02, -2.4206e-02,
            -2.3151e-02, -3.6330e-02, 4.4699e-02, 3.9887e-03, 1.5609e-02,
            -4.3950e-02, 1.0955e-02, -2.4277e-02, 1.4915e-02, 3.2508e-03,
            6.1454e-02, 3.5242e-02, -1.5274e-02, -2.6729e-02, -2.4072e-02,
            1.5696e-02, 2.6622e-02, -3.5404e-02, 5.1701e-02, -5.3047e-02,
            -1.8412e-02, 8.6640e-03, -3.1722e-02, 4.0329e-02, 1.2896e-02,
            -1.4139e-02, -4.9200e-02, -4.6193e-02, -2.9064e-03, -2.2078e-02,
            -4.0084e-02, -8.3519e-03, -2.7589e-02, -4.9979e-03, -1.6055e-02,
            -4.5311e-02, -2.6951e-02, 2.8032e-02, -4.0069e-03, 3.2405e-02,
            -5.3164e-03, -3.0139e-03, 6.6179e-04, -4.9243e-02, 3.2515e-02,
            9.8307e-03, -3.4257e-03, -3.9522e-02, 1.2594e-02, -2.7210e-02,
            2.3451e-02, 4.2257e-02, 2.2239e-02, 1.4304e-04, 4.2905e-04,
            1.5193e-02, 3.1897e-03, -1.0828e-02, -4.8345e-02, 6.8747e-02,
            -7.1725e-03, -9.7815e-03, -1.6331e-02, 1.0434e-02, -8.8083e-04,
            3.8219e-02, 6.8332e-03, -2.0189e-02, 2.8141e-02, 1.4913e-02,
            -2.4925e-02, -2.8922e-02, -7.1546e-03, 1.9791e-02, 1.1160e-02,
            1.0306e-02, -1.3631e-02, 2.7318e-03, 1.4050e-03, -8.2064e-03,
            3.5836e-02, -1.5877e-02, -1.1198e-02, 1.9514e-02, 3.0832e-03,
            -6.2730e-02, 6.1493e-03, -1.2340e-02, 3.9110e-02, -2.6895e-02,
            -5.1718e-03, 7.5017e-03, 1.2673e-03, 4.7525e-02, 1.7373e-03,
            -5.1745e-03, -2.8621e-02, 3.4984e-02, -3.2622e-02, 1.0748e-02,
            1.2499e-02, -1.8788e-02, -8.6717e-03, 4.3620e-02, 2.8460e-02,
            -6.8146e-03, -3.5824e-02, 9.2931e-03, 3.7893e-03, 2.4187e-02,
            1.3393e-02, -5.9393e-03, -9.9837e-03, -8.1019e-03, -2.1840e-02,
            -3.8945e-02, 1.6736e-02, -4.7475e-02, 4.9770e-02, 3.4695e-02,
            1.8961e-02, 2.7416e-02, -1.3578e-02, -9.8595e-03, 2.2834e-03,
            2.4829e-02, -4.3998e-02, 3.2398e-02, -1.4200e-02, 2.4907e-02,
            -2.2542e-02, -9.2765e-03, 2.0658e-03, -4.1246e-03, -1.8095e-02,
            -1.2732e-02, -3.2090e-03, 1.3127e-02, -2.0888e-02, 1.4931e-02,
            -4.0576e-02, 4.2877e-02, 7.9411e-05, -4.4377e-02, 3.2357e-03,
            1.6201e-02, 4.0387e-02, -1.9023e-02, 5.8033e-02, -3.3424e-02,
            2.9598e-03, -1.8526e-02, -2.2967e-02, 4.3449e-02, -1.2564e-02,
            -9.3756e-03, -2.1745e-02, -2.7089e-02, -3.6791e-02, -5.2018e-02,
            2.4588e-02, 1.0037e-03, 3.9753e-02, 4.3534e-02, 2.6446e-02,
            -1.1808e-02, 2.1426e-02, 7.5522e-03, 2.2847e-03, -2.7211e-02,
            4.1364e-02, -1.1281e-02, 1.6523e-03, -1.9913e-03
        ]]),
        decimal=3)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    # One plain SGD step on the sum of the first forward pass's outputs.
    loss = out.sum()
    loss.backward()
    optimizer.step()
    out = model(x)
    tt.assert_almost_equal(
        out.cpu(),
        torch.tensor([[
            -0.0247, -0.0172, -0.0633, -0.0154, -0.0156, -0.1156, -0.0793, -0.0184, -0.0408, 0.0005,
            -0.0920, -0.0481, -0.0597, -0.0243, 0.0006, -0.1045, -0.0476, -0.0030, -0.0230, -0.0869,
            -0.0149, -0.0412, -0.0753, -0.0640, -0.1106, -0.0632, -0.0645, -0.0474, -0.0124, -0.0698,
            -0.0275, -0.0415, -0.0916, -0.0957, -0.0851, -0.1296, -0.1049, -0.0196, -0.0823, -0.0380,
            -0.1085, -0.0526, -0.0083, -0.1274, -0.0426, -0.0183, -0.0585, -0.0366, -0.1111, -0.0074,
            -0.1238, -0.0324, -0.0166, -0.0719, -0.0285, -0.0427, -0.1158, -0.0569, 0.0075, -0.0419,
            -0.0288, -0.1189, -0.0220, -0.0370, 0.0040, 0.0228, -0.0958, -0.0258, -0.0276, -0.0405,
            -0.0958, -0.0201, -0.0639, -0.0543, -0.0705, -0.0940, -0.0700, -0.0921, -0.0426, 0.0026,
            -0.0556, -0.0439, -0.0386, -0.0957, -0.0915, -0.0679, -0.1272, -0.0754, -0.0076, -0.1046,
            -0.0350, -0.0887, -0.0350, -0.0270, -0.1188, -0.0449, 0.0020, -0.0406, 0.0011, -0.0842,
            -0.0422, -0.1280, -0.0205, 0.0002, -0.0789, -0.0185, -0.0510, -0.1180, -0.0550, -0.0159,
            -0.0702, -0.0029, -0.0891, -0.0253, -0.0485, -0.0128, 0.0010, -0.0870, -0.0230, -0.0233,
            -0.0411, -0.0870, -0.0419, -0.0688, -0.0583, -0.0448, -0.0864, -0.0926, -0.0758, -0.0540,
            0.0058, -0.0843, -0.0365, -0.0608, -0.0787, -0.0938, -0.0680, -0.0995, -0.0764, 0.0061,
            -0.0821, -0.0636, -0.0848, -0.0373, -0.0285, -0.1086, -0.0464, -0.0228, -0.0464, -0.0279,
            -0.1053, -0.0224, -0.1268, -0.0006, -0.0186, -0.0836, -0.0011, -0.0415, -0.1222, -0.0668,
            -0.0015, -0.0535, -0.0071, -0.1202, -0.0257, -0.0503, 0.0004, 0.0099, -0.1113, -0.0182,
            -0.0080, -0.0216, -0.0661, -0.0115, -0.0468, -0.0716, -0.0404, -0.0950, -0.0681, -0.0933,
            -0.0699, -0.0154, -0.0853, -0.0414, -0.0403, -0.0700, -0.0685, -0.0975, -0.0934, -0.1016,
            -0.0121, -0.1084, -0.0391, -0.1006, -0.0441, -0.0024, -0.1232, -0.0159, 0.0012, -0.0480,
            -0.0013, -0.0789, -0.0309, -0.1101
        ]]),
        decimal=3)