def test_register_mismatch(self):
    """register_observation must reject internally inconsistent observations.

    Two cases are checked, each in its own subtest:

    * 'obs_space mismatch' -- the observation's ``labels`` disagree with its
      ``observation_space``.
    * 'observation mismatch' -- ``labels`` and ``observation_space`` agree
      with each other but not with the observation itself.

    Each subtest builds a fresh factory and observation so that neither
    depends on attribute mutations (or any partial registration) left
    behind by the other.
    """
    with self.subTest('obs_space mismatch'):
        of = obs.ObservationFactory(self.client)
        test_obs = CompliantObs(None)
        # Three labels against the unmodified CompliantObs observation space.
        test_obs.labels = ['1', '2', '3']
        with self.assertRaises(ValueError):
            of.register_observation(test_obs)
    with self.subTest('observation mismatch'):
        of = obs.ObservationFactory(self.client)
        test_obs = CompliantObs(None)
        # Set labels explicitly (previously inherited from the first
        # subtest) so labels and space agree on 3 entries; the remaining
        # inconsistency is then with the observation itself (presumably
        # CompliantObs does not produce a 3-element observation).
        test_obs.labels = ['1', '2', '3']
        test_obs.observation_space = spaces.Box(low=0, high=3, shape=(3,))
        with self.assertRaises(ValueError):
            of.register_observation(test_obs)
def test_get_observation_space_normalized(self):
    """With normalize=True every dimension of the combined space is [-1, 1].

    Two observations are registered (the stock CompliantObs and a copy whose
    space has been replaced by a 2-element box), and the factory's combined
    space is expected to be a 4-element box bounded by -1 and 1.
    """
    factory = obs.ObservationFactory(self.client, normalize=True)
    factory.register_observation(CompliantObs(None))

    second = CompliantObs(None)
    second.observation_space = spaces.Box(
        low=np.array([5., 6.]),
        high=np.array([5., 6.]),
    )
    factory.register_observation(second)

    # Normalization should map each of the 4 combined dimensions to [-1, 1].
    expected = spaces.Box(low=np.full(4, -1.), high=np.ones(4))
    self.assertEqual(factory.get_observation_space(), expected)
def test_get_observation_space_multiple_observations(self):
    """get_observation_space caches its result until asked to regenerate.

    The space is computed once after the first registration. A second
    observation is then registered. The cached value must still equal
    CompliantObs.observation_space. Only with generate=True does the factory
    return the concatenation of both observations' bounds.
    """
    factory = obs.ObservationFactory(self.client)
    factory.register_observation(CompliantObs(None))
    # Prime the cache while only one observation is registered.
    factory.get_observation_space()

    extra = CompliantObs(None)
    extra.observation_space = spaces.Box(
        low=np.array([5., 6.]),
        high=np.array([5., 6.]),
    )
    factory.register_observation(extra)

    with self.subTest('from cache'):
        cached = factory.get_observation_space()
        self.assertEqual(cached, CompliantObs.observation_space)

    with self.subTest('regenerate cache'):
        combined = spaces.Box(
            low=np.array([0., 0., 5., 6.]),
            high=np.array([3., 3., 5., 6.]),
        )
        self.assertEqual(factory.get_observation_space(generate=True), combined)