コード例 #1
0
class TestMultiagentAtariPresets(unittest.TestCase):
    """Smoke tests for multi-agent presets on the 'pong_v2' Atari environment.

    Each test builds a preset and hands it to ``validate_preset``, which
    checks that the training agent and the test agent can both act on an
    environment state, and that a test agent can still be built and act
    after the preset is saved to disk and loaded back.
    """

    def setUp(self):
        """Create a CPU 'pong_v2' environment and reset it before each test."""
        self.env = MultiagentAtariEnv('pong_v2', device='cpu')
        self.env.reset()

    def tearDown(self):
        """Remove the preset file written by ``validate_preset``, if present."""
        try:
            os.remove('test_preset.pt')
        except FileNotFoundError:
            # Nothing was saved (e.g. the test failed before preset.save).
            pass

    def test_independent(self):
        """Validate an IndependentMultiagentPreset with one DQN preset per agent."""
        # Reuse the fixture created in setUp rather than constructing a
        # second emulator that is never reset and only read for its
        # sub-environments and agent ids.
        env = self.env
        presets = {
            agent_id: dqn.device('cpu').env(env.subenvs[agent_id]).build()
            for agent_id in env.agents
        }
        self.validate_preset(
            IndependentMultiagentPreset('independent', 'cpu', presets), env)

    def validate_preset(self, preset, env):
        """Check that a preset's agents can act, including after save/load.

        Args:
            preset: The multi-agent preset under test.
            env: An already-reset environment; its ``last()`` state is fed
                to every agent built from the preset.
        """
        # normal agent
        agent = preset.agent(writer=DummyWriter(), train_steps=100000)
        agent.act(env.last())
        # test agent
        test_agent = preset.test_agent()
        test_agent.act(env.last())
        # test save/load
        preset.save('test_preset.pt')
        # NOTE(review): torch.load unpickles the file; this is acceptable here
        # because the file was written by this test one line above. Newer
        # PyTorch releases default to weights_only=True, which may reject a
        # pickled preset object -- confirm against the pinned torch version.
        preset = torch.load('test_preset.pt')
        test_agent = preset.test_agent()
        test_agent.act(env.last())
コード例 #2
0
 def test_last(self):
     """Check the state reported by last() right after resetting 'pong_v1'.

     Asserts the observation shape, a zero reward, a not-done flag, a mask
     of 1.0, and that the state's 'agent' entry is 'first_0'.
     """
     expected_shape = (1, 84, 84)
     pong = MultiagentAtariEnv('pong_v1', device='cpu')
     pong.reset()
     initial = pong.last()
     observation_shape = initial.observation.shape
     self.assertEqual(observation_shape, expected_shape)
     self.assertEqual(initial.reward, 0)
     self.assertEqual(initial.done, False)
     self.assertEqual(initial.mask, 1.0)
     acting_agent = initial['agent']
     self.assertEqual(acting_agent, 'first_0')