コード例 #1
0
  def test_register_mismatch(self):
    """Registering an internally inconsistent observation raises ValueError.

    Each subtest builds its own factory and CompliantObs, so a case never
    depends on attributes mutated (or registrations attempted) by another.
    """
    with self.subTest('obs_space mismatch'):
      of = obs.ObservationFactory(self.client)
      test_obs = CompliantObs(None)
      # Three labels against CompliantObs's stock observation_space, which
      # presumably has a different number of entries (TODO confirm against
      # CompliantObs).
      test_obs.labels = ['1', '2', '3']
      with self.assertRaises(ValueError):
        of.register_observation(test_obs)

    with self.subTest('observation mismatch'):
      of = obs.ObservationFactory(self.client)
      test_obs = CompliantObs(None)
      # Labels and observation_space are made to agree on three entries, so
      # the mismatch being exercised is presumably against the observation
      # CompliantObs actually computes. The labels are set explicitly rather
      # than inherited from the subtest above.
      test_obs.labels = ['1', '2', '3']
      test_obs.observation_space = spaces.Box(low=0, high=3, shape=(3,))
      with self.assertRaises(ValueError):
        of.register_observation(test_obs)
コード例 #2
0
ファイル: test_obs_factory.py プロジェクト: WPI-MMR/gym_solo
    def test_get_observation_space_normalized(self):
        """A normalizing factory reports a [-1, 1] Box over all observations.

        Registers the stock CompliantObs plus a second CompliantObs whose
        observation_space is overridden with a 2-element Box, then checks
        that the factory's combined space is a 4-element Box bounded by -1
        and 1 in every dimension.
        """
        factory = obs.ObservationFactory(self.client, normalize=True)

        stock_obs = CompliantObs(None)
        factory.register_observation(stock_obs)

        override_obs = CompliantObs(None)
        override_obs.observation_space = spaces.Box(
            low=np.array([5., 6.]), high=np.array([5., 6.]))
        factory.register_observation(override_obs)

        actual = factory.get_observation_space()
        expected = spaces.Box(low=-np.ones(4), high=np.ones(4))
        self.assertEqual(actual, expected)
コード例 #3
0
  def test_get_observation_space_multiple_observations(self):
    """The observation space is cached until regeneration is requested.

    The space is first computed with only the stock CompliantObs registered.
    After a second observation is added, a plain call must still return the
    cached single-observation space. ``generate=True`` must rebuild it so the
    override's [5, 6] bounds follow the stock [0, 3] bounds.
    """
    factory = obs.ObservationFactory(self.client)
    factory.register_observation(CompliantObs(None))
    # Populate the cache before the second observation exists.
    factory.get_observation_space()

    extra_obs = CompliantObs(None)
    extra_obs.observation_space = spaces.Box(low=np.array([5., 6.]),
                                             high=np.array([5., 6.]))
    factory.register_observation(extra_obs)

    with self.subTest('from cache'):
      cached_space = factory.get_observation_space()
      self.assertEqual(cached_space, CompliantObs.observation_space)

    with self.subTest('regenerate cache'):
      rebuilt_space = factory.get_observation_space(generate=True)
      expected_space = spaces.Box(low=np.array([0., 0., 5., 6.]),
                                  high=np.array([3., 3., 5., 6.]))
      self.assertEqual(rebuilt_space, expected_space)