コード例 #1
0
 def test_step(self):
     """One step after reset gives a non-terminal (1, 84, 84) state with zero reward."""
     breakout = AtariEnvironment('Breakout')
     breakout.reset()
     result = breakout.step(1)
     # Observation geometry of the returned state.
     self.assertEqual(result.observation.shape, (1, 84, 84))
     # Reward / termination bookkeeping: nothing earned, episode continues.
     self.assertEqual(result.reward, 0)
     self.assertFalse(result.done)
     self.assertEqual(result.mask, 1)
     self.assertEqual(result['life_lost'], False)
コード例 #2
0
 def test_step_until_done(self):
     """Repeating action 1 for at most 1000 steps must end in a terminal state."""
     breakout = AtariEnvironment('Breakout')
     breakout.reset()
     steps_left = 1000
     while True:
         result = breakout.step(1)
         steps_left -= 1
         # Stop at the first terminal state, or once the step budget is spent.
         if result.done or steps_left == 0:
             break
     # The last state seen must be terminal, with mask 0 and no reward.
     self.assertEqual(result.observation.shape, (1, 84, 84))
     self.assertEqual(result.reward, 0)
     self.assertTrue(result.done)
     self.assertEqual(result.mask, 0)
     self.assertEqual(result['life_lost'], False)
コード例 #3
0
class TestAtariPresets(unittest.TestCase):
    """Smoke tests for the Atari presets on a CPU Breakout environment.

    Every ``test_*`` method hands one preset builder to ``validate_preset``,
    which builds it, lets its training and test agents act once, and then
    round-trips the preset through ``save``/``torch.load``.
    """

    # Checkpoint written by validate_preset and removed again in tearDown.
    PRESET_PATH = 'test_preset.pt'

    def setUp(self):
        # Reset before use so the tests can read ``self.env.state``.
        self.env = AtariEnvironment('Breakout')
        self.env.reset()

    def tearDown(self):
        if os.path.exists(self.PRESET_PATH):
            os.remove(self.PRESET_PATH)

    def test_a2c(self):
        self.validate_preset(a2c)

    def test_c51(self):
        self.validate_preset(c51)

    def test_ddqn(self):
        self.validate_preset(ddqn)

    def test_dqn(self):
        self.validate_preset(dqn)

    def test_ppo(self):
        self.validate_preset(ppo)

    def test_rainbow(self):
        self.validate_preset(rainbow)

    def test_vac(self):
        self.validate_preset(vac)

    # Named after the ``vpg`` preset it exercises (was misspelled ``test_vpq``).
    def test_vpg(self):
        self.validate_preset(vpg)

    def test_vsarsa(self):
        self.validate_preset(vsarsa)

    def test_vqn(self):
        self.validate_preset(vqn)

    def validate_preset(self, builder):
        """Build ``builder`` for ``self.env`` on CPU and check its agents can act.

        Args:
            builder: preset builder exposing ``device(...)``, ``env(...)`` and
                ``build()``, as used below.
        """
        preset = builder.device('cpu').env(self.env).build()
        # Training agent.
        agent = preset.agent(writer=DummyWriter(), train_steps=100000)
        agent.act(self.env.state)
        # Test agent.
        test_agent = preset.test_agent()
        test_agent.act(self.env.state)
        # Save/load round trip, then act with a test agent from the reloaded preset.
        preset.save(self.PRESET_PATH)
        # NOTE(review): torch >= 2.6 defaults torch.load(weights_only=True), which
        # refuses to unpickle arbitrary objects such as a whole preset. If the
        # pinned torch is that new, pass weights_only=False here; the file was
        # written by this test, so it is trusted.
        preset = torch.load(self.PRESET_PATH)
        test_agent = preset.test_agent()
        test_agent.act(self.env.state)
コード例 #4
0
 def test_reset(self):
     """reset() yields a (1, 84, 84) non-terminal state with zero reward and mask 1."""
     breakout = AtariEnvironment('Breakout')
     initial = breakout.reset()
     # Shape of the first observation.
     self.assertEqual(initial.observation.shape, (1, 84, 84))
     # A fresh episode: no reward yet, not done, mask still set.
     self.assertEqual(initial.reward, 0)
     self.assertFalse(initial.done)
     self.assertEqual(initial.mask, 1)
コード例 #5
0
    def test_runs(self):
        """Drive ParallelAtariBody with a MockAgent over four Breakout envs for 200 iterations."""
        np.random.seed(0)
        torch.random.manual_seed(0)
        num_envs = 4

        def make_env():
            # Each environment is reset before the body is constructed over it.
            fresh = AtariEnvironment('Breakout')
            fresh.reset()
            return fresh

        envs = [make_env() for _ in range(num_envs)]
        agent = MockAgent(num_envs, max_action=4)
        body = ParallelAtariBody(agent, envs, noop_max=30)

        for _ in range(200):
            states = [env.state for env in envs]
            rewards = torch.tensor([env.reward for env in envs]).float()
            actions = body.act(states, rewards)
            for idx, env in enumerate(envs):
                action = actions[idx]
                # Only step the environments for which an action was returned.
                if action is not None:
                    env.step(action)
コード例 #6
0
 def test_rainbow_model_cpu(self):
     """Check nature_rainbow's CPU forward pass against hard-coded reference outputs.

     A Breakout environment is reset, its raw frame is concatenated four times
     along dim 1, and the model output on that input is compared with the
     expected tensor using ``decimal=3``.
     """
     # NOTE(review): no RNG seed is set in this method, yet the expected values
     # depend on the model's random initialisation; presumably a seed is set
     # elsewhere (e.g. in setUp) — confirm.
     env = AtariEnvironment('Breakout')
     model = nature_rainbow(env)
     env.reset()
     # Repeat the single raw frame 4 times along dim 1 (presumably standing in
     # for a 4-frame stack — confirm) and cast to float for the network.
     x = torch.cat([env.state.raw] * 4, dim=1).float()
     out = model(x)
     tt.assert_almost_equal(
         out,
         torch.tensor([[
             0.0676, -0.0235, 0.0690, -0.0713, -0.0287, 0.0053, -0.0463,
             0.0495, -0.0222, -0.0504, 0.0064, -0.0204, 0.0168, 0.0127,
             -0.0113, -0.0586, -0.0544, 0.0114, -0.0077, 0.0666, -0.0663,
             -0.0420, -0.0698, -0.0314, 0.0272, 0.0361, -0.0537, 0.0301,
             0.0036, -0.0472, -0.0499, 0.0114, 0.0182, 0.0008, -0.0132,
             -0.0803, -0.0087, -0.0017, 0.0598, -0.0627, 0.0859, 0.0117,
             0.0105, 0.0309, -0.0370, -0.0111, -0.0262, 0.0338, 0.0141,
             -0.0385, 0.0547, 0.0648, -0.0370, 0.0107, -0.0629, -0.0163,
             0.0282, -0.0670, 0.0161, -0.0244, -0.0030, 0.0038, -0.0208,
             0.0005, 0.0125, 0.0608, -0.0089, 0.0026, 0.0562, -0.0678,
             0.0841, -0.0265, -0.0461, -0.0124, 0.0276, 0.0364, 0.0195,
             -0.0309, -0.0337, -0.0603, -0.0252, -0.0356, 0.0221, 0.0184,
             -0.0154, -0.0136, -0.0277, 0.0283, 0.0495, 0.0185, -0.0357,
             0.0305, -0.0052, -0.0432, -0.0135, -0.0554, -0.0094, 0.0272,
             0.1030, 0.0049, 0.0012, -0.0140, 0.0146, -0.0979, 0.0487,
             0.0122, -0.0204, 0.0496, -0.0055, -0.0015, -0.0170, 0.0053,
             0.0104, -0.0742, 0.0742, -0.0381, 0.0104, -0.0065, -0.0564,
             0.0453, -0.0057, -0.0029, -0.0722, 0.0094, -0.0561, 0.0284,
             0.0402, 0.0233, -0.0716, -0.0424, 0.0165, -0.0505, 0.0006,
             0.0219, -0.0601, 0.0656, -0.0175, -0.0524, 0.0355, 0.0007,
             -0.0042, -0.0443, 0.0871, -0.0403, -0.0031, 0.0171, -0.0359,
             -0.0520, -0.0344, 0.0239, 0.0099, 0.0004, 0.0235, 0.0238,
             -0.0153, 0.0501, -0.0052, 0.0162, 0.0313, -0.0121, 0.0009,
             -0.0366, -0.0628, 0.0386, -0.0671, 0.0480, -0.0595, 0.0568,
             -0.0604, -0.0540, 0.0403, -0.0187, 0.0649, 0.0029, -0.0003,
             0.0020, -0.0056, 0.0471, -0.0145, -0.0126, -0.0395, -0.0455,
             -0.0437, 0.0056, 0.0331, 0.0004, 0.0127, -0.0022, -0.0502,
             0.0362, 0.0624, -0.0012, -0.0515, 0.0303, -0.0357, -0.0420,
             0.0321, -0.0162, 0.0007, -0.0272, 0.0227, 0.0187, -0.0459,
             0.0496
         ]]),
         decimal=3)
コード例 #7
0
class TestAtariPresets(unittest.TestCase):
    """Smoke tests for the Atari presets, covering both standard and parallel presets.

    Every ``test_*`` method hands one preset builder to ``validate_preset``,
    which builds it on CPU and dispatches on whether the result is a
    ``ParallelPreset``. Both paths let the agents act once and round-trip the
    preset through ``save``/``torch.load``.
    """

    # Checkpoint written by the validate_* helpers and removed in tearDown.
    PRESET_PATH = 'test_preset.pt'

    def setUp(self):
        # A single reset environment for standard agents, plus a reset
        # two-environment DuplicateEnvironment for parallel agents.
        self.env = AtariEnvironment('Breakout')
        self.env.reset()
        self.parallel_env = DuplicateEnvironment([AtariEnvironment('Breakout'), AtariEnvironment('Breakout')])
        self.parallel_env.reset()

    def tearDown(self):
        if os.path.exists(self.PRESET_PATH):
            os.remove(self.PRESET_PATH)

    def test_a2c(self):
        self.validate_preset(a2c)

    def test_c51(self):
        self.validate_preset(c51)

    def test_ddqn(self):
        self.validate_preset(ddqn)

    def test_dqn(self):
        self.validate_preset(dqn)

    def test_ppo(self):
        self.validate_preset(ppo)

    def test_rainbow(self):
        self.validate_preset(rainbow)

    def test_vac(self):
        self.validate_preset(vac)

    # Named after the ``vpg`` preset it exercises (was misspelled ``test_vpq``).
    def test_vpg(self):
        self.validate_preset(vpg)

    def test_vsarsa(self):
        self.validate_preset(vsarsa)

    def test_vqn(self):
        self.validate_preset(vqn)

    def validate_preset(self, builder):
        """Build ``builder`` for ``self.env`` on CPU and run the matching validator.

        Args:
            builder: preset builder exposing ``device(...)``, ``env(...)`` and
                ``build()``, as used below.
        """
        preset = builder.device('cpu').env(self.env).build()
        if isinstance(preset, ParallelPreset):
            return self.validate_parallel_preset(preset)
        return self.validate_standard_preset(preset)

    def validate_standard_preset(self, preset):
        """Act with the training and test agents of a non-parallel preset, then save/load it."""
        # Training agent.
        agent = preset.agent(writer=DummyWriter(), train_steps=100000)
        agent.act(self.env.state)
        # Test agent.
        test_agent = preset.test_agent()
        test_agent.act(self.env.state)
        # Save/load round trip, then act with a test agent from the reloaded preset.
        preset.save(self.PRESET_PATH)
        # NOTE(review): torch >= 2.6 defaults torch.load(weights_only=True), which
        # refuses to unpickle a whole preset object; pass weights_only=False if the
        # pinned torch is that new (the file is written by this test).
        preset = torch.load(self.PRESET_PATH)
        test_agent = preset.test_agent()
        test_agent.act(self.env.state)

    def validate_parallel_preset(self, preset):
        """Exercise a ParallelPreset's agents on single and batched states, then save/load it."""
        # The training agent acts on the batched state array of the parallel env.
        agent = preset.agent(writer=DummyWriter(), train_steps=100000)
        agent.act(self.parallel_env.state_array)
        # The test agent acts on a single environment's state.
        test_agent = preset.test_agent()
        test_agent.act(self.env.state)
        # NOTE(review): despite its name, this agent comes from preset.test_agent(),
        # the same factory as the single-env test agent above. If ParallelPreset
        # offers a dedicated parallel test-agent factory, it should probably be
        # used here — confirm.
        parallel_test_agent = preset.test_agent()
        parallel_test_agent.act(self.parallel_env.state_array)
        # Save/load round trip; only the single-env test agent is re-checked.
        preset.save(self.PRESET_PATH)
        # NOTE(review): see the torch.load weights_only note in validate_standard_preset.
        preset = torch.load(self.PRESET_PATH)
        test_agent = preset.test_agent()
        test_agent.act(self.env.state)
コード例 #8
0
 def test_rainbow_model_cuda(self):
     """Check nature_rainbow on CUDA against reference outputs before and after one SGD step.

     The model's forward output on a Breakout reset frame (concatenated four
     times along dim 1) is compared with a hard-coded tensor (``decimal=3``).
     One SGD step (lr=0.01) is then taken on the sum of those outputs, and the
     new forward output is compared with a second hard-coded tensor.
     """
     # Skip instead of erroring out on machines without a GPU.
     if not torch.cuda.is_available():
         self.skipTest('CUDA is not available')
     # NOTE(review): the expected values depend on the model's random
     # initialisation; presumably the RNG is seeded elsewhere (e.g. in setUp) — confirm.
     env = AtariEnvironment('Breakout')
     model = nature_rainbow(env).cuda()
     env.reset()
     x = torch.cat([env.state.raw] * 4, dim=1).float().cuda()
     out = model(x)
     # Reference output of the freshly initialised model.
     tt.assert_almost_equal(
         out.cpu(),
         torch.tensor([[
             -1.4765e-02, -4.0353e-02, -2.1705e-02, -2.2314e-02, 3.6881e-02,
             -1.4175e-02, 1.2442e-02, -6.8713e-03, 2.4970e-02, 2.5681e-02,
             -4.5859e-02, -2.3327e-02, 3.6205e-02, 7.1024e-03, -2.7564e-02,
             2.1592e-02, -3.2728e-02, 1.3602e-02, -1.1690e-02, -4.3082e-02,
             -1.2996e-02, 1.7184e-02, 1.3446e-02, -3.3587e-03, -4.6350e-02,
             -1.7646e-02, 2.1954e-02, 8.5546e-03, -2.1359e-02, -2.4206e-02,
             -2.3151e-02, -3.6330e-02, 4.4699e-02, 3.9887e-03, 1.5609e-02,
             -4.3950e-02, 1.0955e-02, -2.4277e-02, 1.4915e-02, 3.2508e-03,
             6.1454e-02, 3.5242e-02, -1.5274e-02, -2.6729e-02, -2.4072e-02,
             1.5696e-02, 2.6622e-02, -3.5404e-02, 5.1701e-02, -5.3047e-02,
             -1.8412e-02, 8.6640e-03, -3.1722e-02, 4.0329e-02, 1.2896e-02,
             -1.4139e-02, -4.9200e-02, -4.6193e-02, -2.9064e-03,
             -2.2078e-02, -4.0084e-02, -8.3519e-03, -2.7589e-02,
             -4.9979e-03, -1.6055e-02, -4.5311e-02, -2.6951e-02, 2.8032e-02,
             -4.0069e-03, 3.2405e-02, -5.3164e-03, -3.0139e-03, 6.6179e-04,
             -4.9243e-02, 3.2515e-02, 9.8307e-03, -3.4257e-03, -3.9522e-02,
             1.2594e-02, -2.7210e-02, 2.3451e-02, 4.2257e-02, 2.2239e-02,
             1.4304e-04, 4.2905e-04, 1.5193e-02, 3.1897e-03, -1.0828e-02,
             -4.8345e-02, 6.8747e-02, -7.1725e-03, -9.7815e-03, -1.6331e-02,
             1.0434e-02, -8.8083e-04, 3.8219e-02, 6.8332e-03, -2.0189e-02,
             2.8141e-02, 1.4913e-02, -2.4925e-02, -2.8922e-02, -7.1546e-03,
             1.9791e-02, 1.1160e-02, 1.0306e-02, -1.3631e-02, 2.7318e-03,
             1.4050e-03, -8.2064e-03, 3.5836e-02, -1.5877e-02, -1.1198e-02,
             1.9514e-02, 3.0832e-03, -6.2730e-02, 6.1493e-03, -1.2340e-02,
             3.9110e-02, -2.6895e-02, -5.1718e-03, 7.5017e-03, 1.2673e-03,
             4.7525e-02, 1.7373e-03, -5.1745e-03, -2.8621e-02, 3.4984e-02,
             -3.2622e-02, 1.0748e-02, 1.2499e-02, -1.8788e-02, -8.6717e-03,
             4.3620e-02, 2.8460e-02, -6.8146e-03, -3.5824e-02, 9.2931e-03,
             3.7893e-03, 2.4187e-02, 1.3393e-02, -5.9393e-03, -9.9837e-03,
             -8.1019e-03, -2.1840e-02, -3.8945e-02, 1.6736e-02, -4.7475e-02,
             4.9770e-02, 3.4695e-02, 1.8961e-02, 2.7416e-02, -1.3578e-02,
             -9.8595e-03, 2.2834e-03, 2.4829e-02, -4.3998e-02, 3.2398e-02,
             -1.4200e-02, 2.4907e-02, -2.2542e-02, -9.2765e-03, 2.0658e-03,
             -4.1246e-03, -1.8095e-02, -1.2732e-02, -3.2090e-03, 1.3127e-02,
             -2.0888e-02, 1.4931e-02, -4.0576e-02, 4.2877e-02, 7.9411e-05,
             -4.4377e-02, 3.2357e-03, 1.6201e-02, 4.0387e-02, -1.9023e-02,
             5.8033e-02, -3.3424e-02, 2.9598e-03, -1.8526e-02, -2.2967e-02,
             4.3449e-02, -1.2564e-02, -9.3756e-03, -2.1745e-02, -2.7089e-02,
             -3.6791e-02, -5.2018e-02, 2.4588e-02, 1.0037e-03, 3.9753e-02,
             4.3534e-02, 2.6446e-02, -1.1808e-02, 2.1426e-02, 7.5522e-03,
             2.2847e-03, -2.7211e-02, 4.1364e-02, -1.1281e-02, 1.6523e-03,
             -1.9913e-03
         ]]),
         decimal=3)
     # One SGD step on sum(out), so every output feeds the gradient.
     optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
     loss = out.sum()
     loss.backward()
     optimizer.step()
     out = model(x)
     # Reference output after the update.
     tt.assert_almost_equal(
         out.cpu(),
         torch.tensor([[
             -0.0247, -0.0172, -0.0633, -0.0154, -0.0156, -0.1156, -0.0793,
             -0.0184, -0.0408, 0.0005, -0.0920, -0.0481, -0.0597, -0.0243,
             0.0006, -0.1045, -0.0476, -0.0030, -0.0230, -0.0869, -0.0149,
             -0.0412, -0.0753, -0.0640, -0.1106, -0.0632, -0.0645, -0.0474,
             -0.0124, -0.0698, -0.0275, -0.0415, -0.0916, -0.0957, -0.0851,
             -0.1296, -0.1049, -0.0196, -0.0823, -0.0380, -0.1085, -0.0526,
             -0.0083, -0.1274, -0.0426, -0.0183, -0.0585, -0.0366, -0.1111,
             -0.0074, -0.1238, -0.0324, -0.0166, -0.0719, -0.0285, -0.0427,
             -0.1158, -0.0569, 0.0075, -0.0419, -0.0288, -0.1189, -0.0220,
             -0.0370, 0.0040, 0.0228, -0.0958, -0.0258, -0.0276, -0.0405,
             -0.0958, -0.0201, -0.0639, -0.0543, -0.0705, -0.0940, -0.0700,
             -0.0921, -0.0426, 0.0026, -0.0556, -0.0439, -0.0386, -0.0957,
             -0.0915, -0.0679, -0.1272, -0.0754, -0.0076, -0.1046, -0.0350,
             -0.0887, -0.0350, -0.0270, -0.1188, -0.0449, 0.0020, -0.0406,
             0.0011, -0.0842, -0.0422, -0.1280, -0.0205, 0.0002, -0.0789,
             -0.0185, -0.0510, -0.1180, -0.0550, -0.0159, -0.0702, -0.0029,
             -0.0891, -0.0253, -0.0485, -0.0128, 0.0010, -0.0870, -0.0230,
             -0.0233, -0.0411, -0.0870, -0.0419, -0.0688, -0.0583, -0.0448,
             -0.0864, -0.0926, -0.0758, -0.0540, 0.0058, -0.0843, -0.0365,
             -0.0608, -0.0787, -0.0938, -0.0680, -0.0995, -0.0764, 0.0061,
             -0.0821, -0.0636, -0.0848, -0.0373, -0.0285, -0.1086, -0.0464,
             -0.0228, -0.0464, -0.0279, -0.1053, -0.0224, -0.1268, -0.0006,
             -0.0186, -0.0836, -0.0011, -0.0415, -0.1222, -0.0668, -0.0015,
             -0.0535, -0.0071, -0.1202, -0.0257, -0.0503, 0.0004, 0.0099,
             -0.1113, -0.0182, -0.0080, -0.0216, -0.0661, -0.0115, -0.0468,
             -0.0716, -0.0404, -0.0950, -0.0681, -0.0933, -0.0699, -0.0154,
             -0.0853, -0.0414, -0.0403, -0.0700, -0.0685, -0.0975, -0.0934,
             -0.1016, -0.0121, -0.1084, -0.0391, -0.1006, -0.0441, -0.0024,
             -0.1232, -0.0159, 0.0012, -0.0480, -0.0013, -0.0789, -0.0309,
             -0.1101
         ]]),
         decimal=3)