Beispiel #1
0
  def test_register_mismatch(self):
    """Check that register_observation raises ValueError on inconsistent input.

    Each subtest builds its own ObservationFactory and CompliantObs, so
    neither subtest depends on state mutated by the other.
    """
    with self.subTest('obs_space mismatch'):
      of = obs.ObservationFactory(self.client)
      test_obs = CompliantObs(None)
      # Three labels against CompliantObs's default observation_space
      # (presumably sized differently — confirm against CompliantObs).
      test_obs.labels = ['1', '2', '3']
      with self.assertRaises(ValueError):
        of.register_observation(test_obs)

    with self.subTest('observation mismatch'):
      of = obs.ObservationFactory(self.client)
      test_obs = CompliantObs(None)
      # Labels and observation_space are made to agree on three entries
      # (the labels were previously inherited implicitly from the first
      # subtest). The mismatch under test is then with the value produced by
      # compute(), which presumably has a different length — TODO confirm.
      test_obs.labels = ['1', '2', '3']
      test_obs.observation_space = spaces.Box(low=0, high=3, shape=(3,))
      with self.assertRaises(ValueError):
        of.register_observation(test_obs)
Beispiel #2
0
    def test_get_obs_multiple_observations(self):
        """Check get_obs over two registered observations.

        A default CompliantObs is registered first, followed by one whose
        compute/labels are overridden to [5, 6] / ['5', '6']. get_obs must
        return both observations' values and labels concatenated in
        registration order.
        """
        factory = obs.ObservationFactory(self.client)
        first = CompliantObs(None)
        factory.register_observation(first)

        second = CompliantObs(None)
        second.compute = lambda: np.array([5, 6])
        second.labels = ['5', '6']
        factory.register_observation(second)

        values, names = factory.get_obs()

        expected_values = np.array([1, 2, 5, 6])
        expected_names = ['1', '2', '5', '6']
        np.testing.assert_array_equal(values, expected_values)
        self.assertListEqual(names, expected_names)