Пример #1
0
    def __init__(self, config: configs.Solo8BaseConfig, use_gui: bool):
        """Build the pybullet simulation backing a solo8 env.

        Connects a bullet client, applies the configured gravity and time
        stepping, runs `client_configuration()`, loads the ground plane and
        the bodies, creates the observation / reward / termination factories
        and finally resets the env with `init_call=True`.

        Args:
          config (configs.Solo8BaseConfig): The env configuration. Its
            `gravity` and `dt` fields are read here.
          use_gui (bool): If true, connect with the pybullet GUI; otherwise
            connect in DIRECT mode.
        """
        self.config = config

        connection_mode = p.GUI if use_gui else p.DIRECT
        self.client = bc.BulletClient(connection_mode=connection_mode)
        self.client.setAdditionalSearchPath(pbd.getDataPath())

        gravity = self.config.gravity
        self.client.setGravity(*gravity)

        if not self.config.dt:
            # No (truthy) timestep configured: run the simulation in real time.
            self.client.setRealTimeSimulation(1)
        else:
            # Fixed-step simulation with a single sub-step per step.
            self.client.setPhysicsEngineParameter(
                fixedTimeStep=self.config.dt, numSubSteps=1)

        self.client_configuration()

        # World contents: ground plane first, then whatever load_bodies adds.
        self.plane = self.client.loadURDF('plane.urdf')
        self.load_bodies()

        # Factories are created before reset() is called below.
        self.obs_factory = obs.ObservationFactory(self.client)
        self.reward_factory = rewards.RewardFactory(self.client)
        self.termination_factory = terms.TerminationFactory()

        self.reset(init_call=True)
Пример #2
0
  def test_register_happy(self):
    """Valid observations register in order; the space stays cached.

    After a second registration the plain `get_observation_space()` call is
    still expected to return the 2-element space, while `generate=True` is
    expected to return the 4-element one.
    """
    factory = obs.ObservationFactory(self.client)

    with self.subTest('single obs'):
      first = CompliantObs(None)
      factory.register_observation(first)

      registered = factory._observations
      self.assertEqual(len(registered), 1)
      self.assertEqual(registered[0], first)

      space = factory.get_observation_space()
      self.assertEqual(space, CompliantObs.observation_space)

    with self.subTest('multiple_obs'):
      second = CompliantObs(2)
      factory.register_observation(second)

      registered = factory._observations
      self.assertEqual(len(registered), 2)
      self.assertNotEqual(registered[0], second)
      self.assertEqual(registered[1], second)

      with self.subTest('Cached observation space'):
        cached_space = factory.get_observation_space()
        self.assertEqual(cached_space,
                         spaces.Box(low=0, high=3, shape=(2,)))

      with self.subTest('Fresh observation space'):
        fresh_space = factory.get_observation_space(generate=True)
        self.assertEqual(fresh_space,
                         spaces.Box(low=0, high=3, shape=(4,)))
Пример #3
0
  def test_empty(self):
    """A fresh factory has no observations, no space, and get_obs raises.

    Checks that `_observations` is falsy and `_obs_space` is None right after
    construction, and that `get_obs()` raises ValueError.
    """
    of = obs.ObservationFactory(self.client)

    self.assertFalse(of._observations)
    self.assertIsNone(of._obs_space)

    # Call get_obs() without unpacking its result: a failed tuple unpack also
    # raises ValueError, which would let this check pass even if get_obs()
    # itself never raised.
    with self.assertRaises(ValueError):
      of.get_obs()
Пример #4
0
    def test_get_obs_single_observation(self):
        """get_obs with one CompliantObs yields [1, 2] labelled '1', '2'."""
        factory = obs.ObservationFactory(self.client)
        factory.register_observation(CompliantObs(None))

        values, names = factory.get_obs()

        expected_values = np.array([1, 2])
        np.testing.assert_array_equal(values, expected_values)
        self.assertListEqual(names, ['1', '2'])
Пример #5
0
  def test_register_mismatch(self):
    """Registering an inconsistent observation raises ValueError.

    The same observation instance is mutated in both sub-tests, so the second
    sub-test runs with the three labels assigned by the first one.
    """
    factory = obs.ObservationFactory(self.client)
    bad_obs = CompliantObs(None)

    with self.subTest('obs_space mismatch'):
      # Three labels on an otherwise unmodified CompliantObs.
      bad_obs.labels = ['1', '2', '3']
      self.assertRaises(ValueError, factory.register_observation, bad_obs)

    with self.subTest('observation mismatch'):
      # The space now has three entries too; compute() is left untouched.
      bad_obs.observation_space = spaces.Box(low=0, high=3, shape=(3,))
      self.assertRaises(ValueError, factory.register_observation, bad_obs)
Пример #6
0
    def test_get_obs_multiple_observations(self):
        """get_obs joins values and labels in registration order."""
        factory = obs.ObservationFactory(self.client)
        factory.register_observation(CompliantObs(None))

        def compute_five_six():
            # Stand-in compute() for the second observation.
            return np.array([5, 6])

        second = CompliantObs(None)
        second.compute = compute_five_six
        second.labels = ['5', '6']
        factory.register_observation(second)

        values, names = factory.get_obs()

        expected_values = np.array([1, 2, 5, 6])
        np.testing.assert_array_equal(values, expected_values)
        self.assertListEqual(names, ['1', '2', '5', '6'])
Пример #7
0
    def test_get_observation_space_normalized(self):
        """With normalize=True the reported space is [-1, 1] per dimension."""
        factory = obs.ObservationFactory(self.client, normalize=True)
        factory.register_observation(CompliantObs(None))

        # Second observation whose space has identical low and high bounds.
        lower = np.array([5., 6.])
        upper = np.array([5., 6.])
        custom_obs = CompliantObs(None)
        custom_obs.observation_space = spaces.Box(low=lower, high=upper)
        factory.register_observation(custom_obs)

        actual_space = factory.get_observation_space()
        expected_space = spaces.Box(low=np.array([-1., -1., -1., -1.]),
                                    high=np.array([1., 1., 1., 1.]))
        self.assertEqual(actual_space, expected_space)
Пример #8
0
  def test_get_observation_space_multiple_observations(self):
    """The observation space is reused until `generate=True` is passed."""
    factory = obs.ObservationFactory(self.client)
    factory.register_observation(CompliantObs(None))
    # Query the space once while only one observation is registered; the
    # 'from cache' sub-test below expects this result to be returned again.
    factory.get_observation_space()

    extra_obs = CompliantObs(None)
    extra_obs.observation_space = spaces.Box(low=np.array([5., 6.]),
                                             high=np.array([5., 6.]))
    factory.register_observation(extra_obs)

    with self.subTest('from cache'):
      cached = factory.get_observation_space()
      self.assertEqual(cached, CompliantObs.observation_space)

    with self.subTest('regenerate cache'):
      regenerated = factory.get_observation_space(generate=True)
      combined = spaces.Box(low=np.array([0., 0., 5., 6.]),
                            high=np.array([3., 3., 5., 6.]))
      self.assertEqual(regenerated, combined)
Пример #9
0
    def test_get_obs_normalized(self):
        """Normalized get_obs maps raw 1.5, 0 and 3 to 0, -1 and 1."""
        factory = obs.ObservationFactory(self.client, normalize=True)

        def register_constant(values):
            # Register a CompliantObs whose compute() always yields `values`.
            constant_obs = CompliantObs(None)
            constant_obs.compute = lambda: np.array(values)
            factory.register_observation(constant_obs)

        register_constant([1.5, 1.5])
        register_constant([0., 0.])
        register_constant([3., 3.])

        normalized, _ = factory.get_obs()

        expected = np.array([0., 0., -1., -1., 1., 1.])
        np.testing.assert_array_equal(normalized, expected)
Пример #10
0
 def test_get_obs_no_observations(self):
   """get_obs raises ValueError when nothing has been registered."""
   of = obs.ObservationFactory(self.client)

   # Call get_obs() without unpacking its result: a failed tuple unpack also
   # raises ValueError, which would let this check pass even if get_obs()
   # itself never raised.
   with self.assertRaises(ValueError):
     of.get_obs()
Пример #11
0
 def test_get_observation_space_no_observations(self):
   """get_observation_space raises ValueError on an empty factory."""
   factory = obs.ObservationFactory(self.client)
   self.assertRaises(ValueError, factory.get_observation_space)