コード例 #1
0
class TestSimpleAndSMMObsWrapper(unittest.TestCase):
    """Shape checks for observations produced by ``SimpleAndSMMObsWrapper``.

    Every scenario is covered twice: once through an environment wrapped by
    the class under test, and once by handing an unwrapped environment's
    observation to ``process_obs`` on the wrapper class itself.
    """

    def setUp(self):
        """Build a fresh environment and bind the wrapper class under test."""
        self._env = FootballEnv(config=Config())
        self._sut = SimpleAndSMMObsWrapper

    def _assert_obs_shape(self, obs: Tuple[np.ndarray, np.ndarray]):
        """Assert ``obs`` has two parts shaped (1, 72, 96, 4) and (115,)."""
        expected_shapes = ((1, 72, 96, 4), (115, ))
        self.assertEqual(len(expected_shapes), len(obs))
        for wanted, part in zip(expected_shapes, obs):
            self.assertEqual(wanted, part.shape)

    def test_shapes_as_expected_with_env_on_reset(self):
        """The wrapped environment's ``reset`` yields the expected shapes."""
        # Arrange
        env_under_test = self._sut(self._env)

        # Act
        initial_obs = env_under_test.reset()

        # Assert
        self._assert_obs_shape(initial_obs)

    def test_shapes_as_expected_without_env_on_reset(self):
        """``process_obs`` converts an unwrapped ``reset`` observation."""
        # Arrange
        unprocessed = self._env.reset()

        # Act
        processed = self._sut.process_obs(unprocessed)

        # Assert
        self._assert_obs_shape(processed)

    def test_shapes_as_expected_with_env_on_step(self):
        """The wrapped environment's ``step`` yields the expected shapes."""
        # Arrange
        env_under_test = self._sut(self._env)
        env_under_test.reset()

        # Act
        stepped_obs, _reward, _done, _info = env_under_test.step(0)

        # Assert
        self._assert_obs_shape(stepped_obs)

    def test_shapes_as_expected_without_env_on_step(self):
        """``process_obs`` converts an unwrapped ``step`` observation."""
        # Arrange
        self._env.reset()
        unprocessed, _reward, _done, _info = self._env.step(0)

        # Act
        processed = self._sut.process_obs(unprocessed)

        # Assert
        self._assert_obs_shape(processed)
コード例 #2
0
    def test_stacked_with_simple_and_smm_obs_wrapper_on_reset(self):
        """Check value range and shapes of the observation from ``reset``.

        The environment is ``FootballEnv`` wrapped first by
        ``SimpleAndSMMObsWrapper`` and then by ``SMMFrameProcessWrapper``.
        """
        # Arrange
        base_env = FootballEnv(config=Config())
        env = SMMFrameProcessWrapper(SimpleAndSMMObsWrapper(base_env))

        # Act
        initial_obs = env.reset()

        # Assert
        # First element: values span exactly [0, 1] right after reset.
        self.assertEqual(1, np.max(initial_obs[0]))
        self.assertEqual(0, np.min(initial_obs[0]))
        self.assertEqual((72, 96, 4), initial_obs[0].shape)
        # Second element: flat vector of length 115.
        self.assertEqual((115, ), initial_obs[1].shape)
コード例 #3
0
    def test_stacked_with_simple_and_smm_obs_wrapper_on_step(self):
        """Check value range and shapes of observations returned by ``step``.

        The environment is ``FootballEnv`` wrapped by
        ``SimpleAndSMMObsWrapper`` and then ``SMMFrameProcessWrapper``. Four
        consecutive steps with action ``0`` are taken. The first two stepped
        observations must have a first element spanning [-1, 1]. The reset
        observation and every stepped observation must keep the expected
        shapes.
        """
        # Arrange
        env = SMMFrameProcessWrapper(
            SimpleAndSMMObsWrapper(FootballEnv(config=Config())))
        obs = env.reset()

        # Act
        first_call, _, _, _ = env.step(0)
        second_call, _, _, _ = env.step(0)
        third_call, _, _, _ = env.step(0)
        fourth_call, _, _, _ = env.step(0)

        # Assert
        self.assertEqual(1, np.max(first_call[0]))
        self.assertEqual(-1, np.min(first_call[0]))
        self.assertEqual(1, np.max(second_call[0]))
        self.assertEqual(-1, np.min(second_call[0]))
        # The original only checked the shapes of the reset observation, so
        # the outputs of ``step`` were never shape-checked. Verify all of
        # them, keeping the reset check as well.
        labelled_observations = (
            ("reset", obs),
            ("step 1", first_call),
            ("step 2", second_call),
            ("step 3", third_call),
            ("step 4", fourth_call),
        )
        for label, observation in labelled_observations:
            with self.subTest(observation=label):
                self.assertEqual((72, 96, 4), observation[0].shape)
                self.assertEqual((115, ), observation[1].shape)
コード例 #4
0
 def setUp(self):
     """Create the Kaggle and GFootball environments and a ``RawObs``."""
     self._kaggle_env = gym.make("GFootball-kaggle_11_vs_11-v0")
     # A single FootballEnv instance is shared under both attribute names.
     football_env = FootballEnv(config=Config())
     self._gfootball_env = football_env
     self._env = football_env
     self._sut = RawObs()
コード例 #5
0
 def setUp(self):
     """Bind the wrapper class under test and build a fresh environment."""
     # The wrapper is stored as a class; each test instantiates it itself.
     self._sut = SimpleAndSMMObsWrapper
     self._env = FootballEnv(config=Config())
コード例 #6
0
 def setUp(self) -> None:
     """Build a fresh environment and bind the wrapper class under test."""
     env_config = Config()
     self._env = FootballEnv(config=env_config)
     # The wrapper is stored as a class, not an instance.
     self._sut = SimpleAndRawObsWrapper
コード例 #7
0
class TestSimpleAndRawObsWrapper(unittest.TestCase):
    """Tests for ``SimpleAndRawObsWrapper``: output shapes and JSON dumps.

    Shape tests expect a flat array whose length is ``RawObs(...).shape[1]``
    plus 115. Dump tests expect a JSON file holding a single ``players_raw``
    key whose first entry exposes exactly ``RawObs.standard_keys``.
    """

    _raw_obs_fixture = RawObsFixture()

    @classmethod
    def setUpClass(cls) -> None:
        """Run ``register_all`` and create a temp directory for dump files."""
        register_all()
        cls.tmp_dir = tempfile.TemporaryDirectory()

    @classmethod
    def tearDownClass(cls) -> None:
        """Remove the temporary directory created in ``setUpClass``."""
        cls.tmp_dir.cleanup()

    def setUp(self) -> None:
        """Build a fresh environment and bind the wrapper class under test."""
        self._env = FootballEnv(config=Config())
        self._sut = SimpleAndRawObsWrapper

    def _assert_dump_contents(self, dump_path) -> None:
        """Assert ``dump_path`` exists and holds the expected JSON layout."""
        self.assertTrue(os.path.exists(dump_path))
        with open(dump_path, 'r') as dump_file:
            dumped = json.load(dump_file)
        self.assertListEqual(['players_raw'], list(dumped.keys()))
        first_entry = dumped['players_raw'][0]
        self.assertListEqual(sorted(RawObs.standard_keys),
                             sorted(first_entry))

    def test_shapes_as_expected_with_env_on_reset(self):
        """``reset`` on the wrapped env returns an array of the right size."""
        # Arrange
        env_under_test = self._sut(self._env)

        # Act
        observation = env_under_test.reset()

        # Assert
        self.assertIsInstance(observation, np.ndarray)
        self.assertEqual(env_under_test.observation_space.shape,
                         observation.shape)
        expected_shape = (RawObs().shape[1] + 115, )
        self.assertEqual(expected_shape, observation.shape)

    def test_shapes_as_expected_with_kaggle_env(self):
        """The wrapper yields the same shape around the Kaggle gym env."""
        # Arrange
        kaggle_env = gym.make("GFootball-kaggle_11_vs_11-v0")
        env_under_test = self._sut(kaggle_env)

        # Act
        observation = env_under_test.reset()

        # Assert
        self.assertIsInstance(observation, np.ndarray)
        self.assertEqual(env_under_test.observation_space.shape,
                         observation.shape)
        expected_shape = (RawObs().shape[1] + 115, )
        self.assertEqual(expected_shape, observation.shape)

    def test_shapes_as_expected_without_env_on_reset(self):
        """``process_obs`` converts an unwrapped ``reset`` observation."""
        # Arrange
        unprocessed = self._env.reset()

        # Act
        processed = self._sut.process_obs(unprocessed)

        # Assert
        self.assertIsInstance(processed, np.ndarray)
        expected_shape = (RawObs().shape[1] + 115, )
        self.assertEqual(expected_shape, processed.shape)

    def test_shapes_as_expected_with_env_on_reset_with_all(self):
        """Passing ``raw_using=RawObs.standard_keys`` sizes the output."""
        # Arrange
        env_under_test = self._sut(self._env,
                                   raw_using=RawObs.standard_keys)

        # Act
        observation = env_under_test.reset()

        # Assert
        self.assertIsInstance(observation, np.ndarray)
        self.assertEqual(env_under_test.observation_space.shape,
                         observation.shape)
        raw_width = RawObs(using=RawObs.standard_keys).shape[1]
        self.assertEqual((raw_width + 115, ), observation.shape)

    def test_shapes_as_expected_with_env_on_step(self):
        """``step`` on the wrapped env returns an array of the right size."""
        # Arrange
        env_under_test = self._sut(self._env)
        env_under_test.reset()

        # Act
        observation, _reward, _done, _info = env_under_test.step(0)

        # Assert
        self.assertIsInstance(observation, np.ndarray)
        self.assertEqual(env_under_test.observation_space.shape,
                         observation.shape)
        expected_shape = (RawObs().shape[1] + 115, )
        self.assertEqual(expected_shape, observation.shape)

    def test_shapes_as_expected_without_env_on_step(self):
        """``process_obs`` converts an unwrapped ``step`` observation."""
        # Arrange
        self._env.reset()
        unprocessed, _reward, _done, _info = self._env.step(0)

        # Act
        processed = self._sut.process_obs(unprocessed)

        # Assert
        self.assertIsInstance(processed, np.ndarray)
        expected_shape = (RawObs().shape[1] + 115, )
        self.assertEqual(expected_shape, processed.shape)

    def test_dump_created_when_path_set(self):
        """A dump file is present at ``raw_dump_path`` after reset and step."""
        # Arrange
        raw_dump_path = f"{self.tmp_dir.name}/1.json"
        env_under_test = self._sut(self._env, raw_dump_path=raw_dump_path)

        # Act
        env_under_test.reset()
        _obs, _reward, _done, _info = env_under_test.step(0)

        # Assert
        self._assert_dump_contents(raw_dump_path)

    def test_dumps_but_does_not_add_to_returned_obs_when_using_is_empty(self):
        """With ``raw_using=[]`` the obs is (115,) yet the dump still exists."""
        # Arrange
        raw_dump_path = f"{self.tmp_dir.name}/2.json"
        env_under_test = self._sut(self._env,
                                   raw_dump_path=raw_dump_path,
                                   raw_using=[])

        # Act
        env_under_test.reset()
        observation, _reward, _done, _info = env_under_test.step(0)

        # Assert
        self.assertEqual((115, ), observation.shape)
        self._assert_dump_contents(raw_dump_path)