def __init__(self, config: configs.Solo8BaseConfig, use_gui: bool = False):
    """Create a solo8 env.

    Connects a PyBullet client (GUI or headless), sets gravity and the
    simulation stepping mode from ``config``, loads the ground plane and the
    robot bodies, builds the observation/reward/termination factories and
    performs an initial reset.

    Args:
        config (configs.Solo8BaseConfig): The SoloConfig. Required; its
            ``gravity`` and ``dt`` fields are read here.
        use_gui (bool): Whether or not to show the pybullet GUI. Defaults
            to False.
    """
    self.config = config
    self.client = bc.BulletClient(
        connection_mode=p.GUI if use_gui else p.DIRECT)
    # Make pbd's data directory searchable so the bare 'plane.urdf' below
    # can be resolved.
    self.client.setAdditionalSearchPath(pbd.getDataPath())
    self.client.setGravity(*self.config.gravity)

    if self.config.dt:
        # Fixed-step simulation: one physics substep per configured dt.
        self.client.setPhysicsEngineParameter(fixedTimeStep=self.config.dt,
                                              numSubSteps=1)
    else:
        # No (or zero) dt configured: fall back to real-time stepping.
        self.client.setRealTimeSimulation(1)

    self.client_configuration()
    self.plane = self.client.loadURDF('plane.urdf')
    self.load_bodies()

    self.obs_factory = obs.ObservationFactory(self.client)
    self.reward_factory = rewards.RewardFactory(self.client)
    self.termination_factory = terms.TerminationFactory()

    self.reset(init_call=True)
def test_register_happy(self):
    """Registering compliant observations grows the factory and its space."""
    factory = obs.ObservationFactory(self.client)

    with self.subTest('single obs'):
        first = CompliantObs(None)
        factory.register_observation(first)
        self.assertEqual(len(factory._observations), 1)
        self.assertEqual(factory._observations[0], first)
        self.assertEqual(factory.get_observation_space(),
                         CompliantObs.observation_space)

    with self.subTest('multiple_obs'):
        second = CompliantObs(2)
        factory.register_observation(second)
        self.assertEqual(len(factory._observations), 2)
        self.assertNotEqual(factory._observations[0], second)
        self.assertEqual(factory._observations[1], second)

    # The space computed during 'single obs' keeps being served until a
    # regeneration is requested with generate=True.
    with self.subTest('Cached observation space'):
        two_dim = spaces.Box(low=0, high=3, shape=(2,))
        self.assertEqual(factory.get_observation_space(), two_dim)

    with self.subTest('Fresh observation space'):
        four_dim = spaces.Box(low=0, high=3, shape=(4,))
        self.assertEqual(factory.get_observation_space(generate=True),
                         four_dim)
def test_empty(self):
    """A fresh factory holds no observations and refuses to compute one."""
    of = obs.ObservationFactory(self.client)
    self.assertFalse(of._observations)
    self.assertIsNone(of._obs_space)
    with self.assertRaises(ValueError):
        # Deliberately no tuple-unpacking: unpacking a wrong-arity return
        # value would itself raise ValueError and let this pass spuriously.
        of.get_obs()
def test_get_obs_single_observation(self):
    """A single registered observation is returned as-is with its labels."""
    factory = obs.ObservationFactory(self.client)
    factory.register_observation(CompliantObs(None))

    values, names = factory.get_obs()

    expected_values = np.array([1, 2])
    np.testing.assert_array_equal(values, expected_values)
    self.assertListEqual(names, ['1', '2'])
def test_register_mismatch(self):
    """register_observation rejects observations whose parts disagree.

    Each subtest builds its own factory and observation, so an unexpected
    successful registration in one case cannot leak state into the next.
    """
    with self.subTest('obs_space mismatch'):
        of = obs.ObservationFactory(self.client)
        test_obs = CompliantObs(None)
        # Three labels against CompliantObs' 2-element observation space.
        test_obs.labels = ['1', '2', '3']
        with self.assertRaises(ValueError):
            of.register_observation(test_obs)

    with self.subTest('observation mismatch'):
        of = obs.ObservationFactory(self.client)
        test_obs = CompliantObs(None)
        # Labels and space agree on 3 elements, but compute() still yields 2.
        test_obs.labels = ['1', '2', '3']
        test_obs.observation_space = spaces.Box(low=0, high=3, shape=(3,))
        with self.assertRaises(ValueError):
            of.register_observation(test_obs)
def test_get_obs_multiple_observations(self):
    """Observations and labels are concatenated in registration order."""
    factory = obs.ObservationFactory(self.client)
    factory.register_observation(CompliantObs(None))

    def _five_six():
        return np.array([5, 6])

    second = CompliantObs(None)
    second.compute = _five_six
    second.labels = ['5', '6']
    factory.register_observation(second)

    values, names = factory.get_obs()
    np.testing.assert_array_equal(values, np.array([1, 2, 5, 6]))
    self.assertListEqual(names, ['1', '2', '5', '6'])
def test_get_observation_space_normalized(self):
    """With normalize=True every dimension of the space is [-1, 1]."""
    factory = obs.ObservationFactory(self.client, normalize=True)
    factory.register_observation(CompliantObs(None))

    pinned = CompliantObs(None)
    # Degenerate box (low == high) is expected to normalize to [-1, 1] too.
    pinned.observation_space = spaces.Box(low=np.array([5., 6.]),
                                          high=np.array([5., 6.]))
    factory.register_observation(pinned)

    unit_box = spaces.Box(low=np.array([-1., -1., -1., -1.]),
                          high=np.array([1., 1., 1., 1.]))
    self.assertEqual(factory.get_observation_space(), unit_box)
def test_get_observation_space_multiple_observations(self):
    """The space stays cached until regeneration is explicitly requested."""
    factory = obs.ObservationFactory(self.client)
    factory.register_observation(CompliantObs(None))
    # Prime the cache with the single-observation space.
    factory.get_observation_space()

    extra = CompliantObs(None)
    extra.observation_space = spaces.Box(low=np.array([5., 6.]),
                                         high=np.array([5., 6.]))
    factory.register_observation(extra)

    with self.subTest('from cache'):
        cached = factory.get_observation_space()
        self.assertEqual(cached, CompliantObs.observation_space)

    with self.subTest('regenerate cache'):
        combined = spaces.Box(low=np.array([0., 0., 5., 6.]),
                              high=np.array([3., 3., 5., 6.]))
        self.assertEqual(factory.get_observation_space(generate=True),
                         combined)
def test_get_obs_normalized(self):
    """Raw values in the [0, 3] space are rescaled onto [-1, 1]."""
    factory = obs.ObservationFactory(self.client, normalize=True)

    def _returning(values):
        # Zero-argument compute callable producing a fresh array each call.
        return lambda: np.array(values)

    # Midpoint, lower bound, upper bound of the [0, 3] space, in that order.
    for raw in ([1.5, 1.5], [0., 0.], [3., 3.]):
        observation = CompliantObs(None)
        observation.compute = _returning(raw)
        factory.register_observation(observation)

    normalized, _ = factory.get_obs()
    np.testing.assert_array_equal(normalized,
                                  np.array([0., 0., -1., -1., 1., 1.]))
def test_get_obs_no_observations(self):
    """get_obs on a factory with nothing registered raises ValueError."""
    of = obs.ObservationFactory(self.client)
    with self.assertRaises(ValueError):
        # Deliberately no tuple-unpacking: a wrong-arity return value would
        # raise ValueError while unpacking and mask a missing empty check.
        of.get_obs()
def test_get_observation_space_no_observations(self):
    """Requesting a space from an empty factory raises ValueError."""
    empty_factory = obs.ObservationFactory(self.client)
    self.assertRaises(ValueError, empty_factory.get_observation_space)