def __init__(self, config: configs.Solo8BaseConfig, use_gui: bool):
    """Create a solo8 env.

    Connects to pybullet, applies the physics settings taken from
    ``config``, loads the ground plane and the bodies, builds the
    observation / reward / termination factories and finally performs the
    initial reset.

    Args:
      config (configs.Solo8BaseConfig): The SoloConfig. Its ``gravity`` and
        ``dt`` attributes are read here.
      use_gui (bool): Whether or not to show the pybullet GUI. When falsy,
        the client connects in ``p.DIRECT`` mode instead of ``p.GUI``.
    """
    self.config = config

    # Pick the pybullet connection mode before creating the client.
    if use_gui:
        mode = p.GUI
    else:
        mode = p.DIRECT
    self.client = bc.BulletClient(connection_mode=mode)

    sim = self.client
    sim.setAdditionalSearchPath(pbd.getDataPath())
    sim.setGravity(*self.config.gravity)

    # A falsy dt (e.g. None or 0) selects real-time simulation; otherwise
    # the simulation is stepped with a fixed timestep of dt.
    timestep = self.config.dt
    if not timestep:
        sim.setRealTimeSimulation(1)
    else:
        sim.setPhysicsEngineParameter(fixedTimeStep=timestep, numSubSteps=1)

    # Hook for additional client setup; runs before anything is loaded.
    self.client_configuration()

    self.plane = sim.loadURDF('plane.urdf')
    self.load_bodies()

    self.obs_factory = obs.ObservationFactory(sim)
    self.reward_factory = rewards.RewardFactory(sim)
    self.termination_factory = terms.TerminationFactory()

    self.reset(init_call=True)
def test_is_terminated(self):
    # Checks is_terminated() on a fresh factory for two mixes of dummy
    # terminations: one containing a True termination, one with only False.
    always_true = DummyTermination(0, True)
    always_false_a = DummyTermination(0, False)
    always_false_b = DummyTermination(0, False)

    # (subTest label, terminations to register, assertion to apply)
    scenarios = (
        ('True and False Termination',
         (always_true, always_false_a),
         self.assertTrue),
        ('False and False Termination',
         (always_false_a, always_false_b),
         self.assertFalse),
    )

    for label, registered, check in scenarios:
        with self.subTest(label):
            # Each scenario gets its own factory so registrations do not
            # carry over between sub-tests.
            factory = termination.TerminationFactory()
            factory.register_termination(*registered)
            check(factory.is_terminated())
def test_reset(self):
    # Expects the termination's reset counter to read 1 right after it is
    # registered, and 2 once the factory itself has been reset.
    factory = termination.TerminationFactory()
    dummy = DummyTermination(0, True)

    factory.register_termination(dummy)
    self.assertEqual(1, dummy.reset_counter)

    factory.reset()
    self.assertEqual(2, dummy.reset_counter)
def test_register_termination(self):
    # Registering two terminations in a single call should leave exactly
    # two entries in the factory's internal list.
    factory = termination.TerminationFactory()
    dummies = (DummyTermination(0, True), DummyTermination(0, False))

    factory.register_termination(*dummies)

    registered_count = len(factory._terminations)
    self.assertEqual(registered_count, 2)
def test_initialization(self):
    # A freshly constructed factory is expected to hold no terminations
    # and to have its _use_or flag truthy.
    factory = termination.TerminationFactory()

    registered = factory._terminations
    self.assertListEqual(registered, [])

    self.assertTrue(factory._use_or)