def test_step(self):
    """A single step in a freshly reset Breakout env yields the expected initial state.

    Checks the preprocessed observation shape (1, 84, 84), zero reward,
    not-done status, mask of 1, and the 'life_lost' info flag.
    """
    environment = AtariEnvironment('Breakout')
    environment.reset()
    result = environment.step(1)
    self.assertEqual(result.observation.shape, (1, 84, 84))
    self.assertEqual(result.reward, 0)
    self.assertFalse(result.done)
    self.assertEqual(result.mask, 1)
    self.assertEqual(result['life_lost'], False)
def test_runs(self):
    """Smoke test: a ParallelAtariBody drives several seeded Breakout envs without error.

    Seeds numpy and torch for determinism, builds 4 reset environments,
    and runs 200 act/step cycles through the body with a mock agent.
    """
    np.random.seed(0)
    torch.random.manual_seed(0)
    num_envs = 4

    def _fresh_env():
        # Each env is reset immediately so it has a valid initial state.
        env = AtariEnvironment('Breakout')
        env.reset()
        return env

    envs = [_fresh_env() for _ in range(num_envs)]
    agent = MockAgent(num_envs, max_action=4)
    body = ParallelAtariBody(agent, envs, noop_max=30)
    for _ in range(200):
        states = [env.state for env in envs]
        rewards = torch.tensor([env.reward for env in envs]).float()
        actions = body.act(states, rewards)
        for idx, env in enumerate(envs):
            # A None action means the body chose not to step this env.
            if actions[idx] is not None:
                env.step(actions[idx])
def test_step_until_done(self):
    """Repeatedly stepping the env eventually reaches a terminal state.

    Steps up to 1000 times with action 1, stopping at the first done
    state, then asserts the terminal state's observation shape, zero
    reward, done flag, mask of 0, and 'life_lost' info flag.
    """
    environment = AtariEnvironment('Breakout')
    environment.reset()
    for _ in range(1000):
        state = environment.step(1)
        if state.done:
            break
    self.assertEqual(state.observation.shape, (1, 84, 84))
    self.assertEqual(state.reward, 0)
    self.assertTrue(state.done)
    self.assertEqual(state.mask, 0)
    self.assertEqual(state['life_lost'], False)