class TestRawObsWithEnvs(unittest.TestCase):
    """Check RawObs against observations produced by live environments.

    These tests use real environments rather than the fixture. The fixture
    matches the output passed to an agent when evaluated with env.make() and
    env.run() using kaggle_environments.make. That differs slightly from the
    raw observations returned by the envs (gym.make or gfootball.make), which
    contain arrays in place of lists for some fields. The processing in RawObs
    must work with both, so it cannot rely on list-only idioms such as
    ``[] + []``: if one operand is an array, the result is something other
    than list concatenation.

    The class is deliberately named differently from the fixture-based
    ``TestRawObs`` class defined later in this module. Two classes sharing one
    name leave only the last definition bound at module level, so test
    loaders that collect from module attributes would never see these tests.
    """

    @classmethod
    def setUpClass(cls) -> None:
        # NOTE(review): presumably registers the environment ids that
        # gym.make() resolves in setUp -- confirm against register_all().
        register_all()

    def setUp(self):
        """Create one environment of each flavour and a fresh RawObs."""
        self._kaggle_env = gym.make("GFootball-kaggle_11_vs_11-v0")
        self._gfootball_env = self._env = FootballEnv(config=Config())
        self._sut = RawObs()

    def test_process_active_output_shapes_as_expected_with_kaggle_env(self):
        # Arrange
        ro = self._sut.set_obs(self._kaggle_env.reset())
        # Act
        active = ro.process_key('active')
        # Assert
        self.assertEqual((1, RawObs.active_n), active.shape)

    def test_process_active_output_shapes_as_expected_with_gf_env(self):
        # Arrange
        ro = self._sut.set_obs(self._gfootball_env.reset())
        # Act
        active = ro.process_key('active')
        # Assert
        self.assertEqual((1, RawObs.active_n), active.shape)

    def test_process_ball_output_shapes_as_expected_with_kaggle_env(self):
        # Arrange
        ro = self._sut.set_obs(self._kaggle_env.reset())
        # Act
        ball = ro.process_key('ball')
        # Assert
        self.assertEqual((1, RawObs.ball_n), ball.shape)

    def test_process_ball_info_output_shapes_as_expected_with_gf_env(self):
        # Arrange
        ro = self._sut.set_obs(self._gfootball_env.reset())
        # Act
        ball = ro.process_key('ball')
        # Assert
        self.assertEqual((1, RawObs.ball_n), ball.shape)

    def test_process_tired_factor_output_shapes_as_expected_with_kaggle_env(
            self):
        # Arrange
        ro = self._sut.set_obs(self._kaggle_env.reset())
        # Act
        tired_factor = ro.process_key('left_team_tired_factor')
        # Assert
        self.assertEqual(
            (1, RawObs.left_team_tired_factor_n), tired_factor.shape)

    def test_process_tired_factor_output_shapes_as_expected_with_gf_env(self):
        # Arrange
        ro = self._sut.set_obs(self._gfootball_env.reset())
        # Act
        tired_factor = ro.process_key('right_team_tired_factor')
        # Assert
        self.assertEqual(
            (1, RawObs.right_team_tired_factor_n), tired_factor.shape)

    def test_output_shape_as_expected_with_kaggle_env_reset(self):
        # Arrange
        obs = self._kaggle_env.reset()
        # Act
        raw_obs = RawObs.convert_observation(obs)
        # Assert
        self.assertEqual(RawObs().shape, raw_obs.shape)

    def test_output_shape_as_expected_with_gfootball_env_reset(self):
        # Arrange
        obs = self._gfootball_env.reset()
        # Act
        raw_obs = RawObs.convert_observation(obs)
        # Assert
        self.assertEqual(RawObs().shape, raw_obs.shape)

    def test_process_all_shape_as_expected_with_kaggle_env(self):
        # Arrange
        raw_obs = self._kaggle_env.reset()
        # Act
        processed_obs = self._sut.set_obs(raw_obs).process()
        # Assert
        self.assertEqual(RawObs().shape, processed_obs.shape)

    def test_process_defaults_shape_as_expected_with_kaggle_env(self):
        # Arrange
        raw_obs = self._kaggle_env.reset()
        # Act
        processed_obs = RawObs(
            using=RawObs.standard_keys).set_obs(raw_obs).process()
        # Assert
        self.assertEqual(
            RawObs(using=RawObs.standard_keys).shape, processed_obs.shape)

    def test_process_all_shape_as_expected_with_gfootball_env(self):
        # Arrange
        raw_obs = self._gfootball_env.reset()
        # Act
        processed_obs = self._sut.set_obs(raw_obs).process()
        # Assert
        self.assertEqual(RawObs().shape, processed_obs.shape)

    def test_process_defaults_shape_as_expected_with_gfootball_env(self):
        # Arrange
        raw_obs = self._gfootball_env.reset()
        # Act
        processed_obs = RawObs(
            using=RawObs.standard_keys).set_obs(raw_obs).process()
        # Assert
        self.assertEqual(
            RawObs(using=RawObs.standard_keys).shape, processed_obs.shape)
class TestRawObs(unittest.TestCase):
    """Exercise RawObs against the observation held by RawObsFixture."""

    def setUp(self):
        """Build the fixture and a standard-keys RawObs loaded with its data."""
        self._raw_obs_fixture = RawObsFixture()
        self.sut = RawObs(using=RawObs.standard_keys)
        self.sut.set_obs(self._raw_obs_fixture.data)

    def _assert_single_row_2d(self, processed):
        """Assert ``processed`` has two dimensions and a first dimension of 1."""
        self.assertEqual(2, len(processed.shape))
        self.assertEqual(1, processed.shape[0])

    def test_process_list_field(self):
        # Act
        processed = self.sut.process_key('ball')
        # Assert
        self._assert_single_row_2d(processed)

    def test_process_non_list_field(self):
        # Act
        processed = self.sut.process_key('active')
        # Assert
        self._assert_single_row_2d(processed)

    def test_process_non_flat_field(self):
        # Act
        processed = self.sut.process_key('left_team')
        # Assert
        self._assert_single_row_2d(processed)

    def test_process(self):
        # Act
        processed = self.sut.process()
        # Assert
        self.assertEqual(self.sut.shape, processed.shape)

    def test__add_distance_to_ball(self):
        # Act
        distances = self.sut._add_distance_to_ball()
        # Assert
        self.assertEqual((1, self.sut.distance_to_ball_n), distances.shape)

    def test_with_obs_indexed_out_of_list(self):
        # Arrange: hand over the first element of the fixture data rather
        # than the whole container.
        unwrapped = RawObs().set_obs(self._raw_obs_fixture.data[0])
        # Act
        processed = unwrapped.process()
        # Assert
        self.assertEqual(unwrapped.shape, processed.shape)

    def test_all_keys(self):
        # Each standard key must yield as many columns as its ``<key>_n``
        # attribute on the RawObs instance declares.
        for key in self.sut.standard_keys:
            processed = self.sut.process_key(key)
            self.assertEqual(
                getattr(self.sut, f"{key}_n"), processed.shape[1])

    def test_process_returns_none_if_nothing_to_do(self):
        # Arrange: no keys selected, so there is nothing to process.
        empty = RawObs(using=[]).set_obs(self._raw_obs_fixture.data[0])
        # Act
        processed = empty.process()
        # Assert
        self.assertEqual((1, 0), empty.shape)
        self.assertIsNone(processed)