예제 #1
0
    def test_resets_after_limit(self):
        """Hitting the step limit ends the game; the following step resets it."""
        step_limit = 5
        env_mock = mock.MagicMock()
        limited_env = atari_wrappers.AtariTimeLimit(env_mock, step_limit)

        # The underlying game never reports game over by itself, so any
        # game-over observed below must come from the time limit.
        env_mock.gym.game_over = False
        env_mock.reset.return_value = ts.restart(1)
        env_mock.step.return_value = ts.transition(2, 0)
        noop = 1

        # One call more than the limit. Presumably the first call performs the
        # initial reset (reset is asserted to have run exactly once below) --
        # confirm against AtariTimeLimit / its base-class step().
        calls_to_limit = step_limit + 1
        for _ in range(calls_to_limit):
            limited_env.step(noop)

        self.assertTrue(limited_env.game_over)
        self.assertEqual(1, env_mock.reset.call_count)

        # Stepping once more after game over must trigger a fresh reset.
        limited_env.step(noop)
        self.assertFalse(limited_env.game_over)
        self.assertEqual(2, env_mock.reset.call_count)
예제 #2
0
    def test_game_over_after_limit(self):
        """Only the step after the first `step_limit` calls ends the episode."""
        step_limit = 5
        env_mock = mock.MagicMock()
        limited_env = atari_wrappers.AtariTimeLimit(env_mock, step_limit)

        # The underlying game never ends on its own; only the limit can end it.
        env_mock.gym.game_over = False
        env_mock.reset.return_value = ts.restart(1)
        env_mock.step.return_value = ts.transition(2, 0)
        noop = 1

        # Before any stepping, the wrapper must not report game over.
        self.assertFalse(limited_env.game_over)

        # Each of the first `step_limit` calls stays inside the episode.
        for _ in range(step_limit):
            current = limited_env.step(noop)
            self.assertFalse(current.is_last())
            self.assertFalse(limited_env.game_over)

        # The next call crosses the limit and must terminate the episode.
        final = limited_env.step(noop)
        self.assertTrue(final.is_last())
        self.assertTrue(limited_env.game_over)