def __new__(cls, *args, **kwargs):
    """Create the wrapper that matches the type of the input environment.

    The environment is read from the ``env`` keyword argument when present,
    otherwise from the first positional argument. It may be either a
    ``gym.Env`` instance or an environment name string.

    Args:
        *args: Positional arguments
        **kwargs: Keyword arguments

    Returns:
        garage.envs.bullet.BulletEnv: if the environment is a bullet-based
            environment. Else returns a garage.envs.GymEnv

    """
    # pylint: disable=import-outside-toplevel
    # Locate the environment: the keyword takes precedence over position.
    if 'env' in kwargs:
        candidate = kwargs['env']
    elif args:
        candidate = args[0]
    else:
        candidate = None

    # Decide whether the candidate names or wraps a bullet-based environment.
    if isinstance(candidate, gym.Env):
        is_bullet = bool(candidate.spec and hasattr(candidate.spec, 'id')
                         and 'Bullet' in candidate.spec.id)
    elif isinstance(candidate, str):
        is_bullet = 'Bullet' in candidate
    else:
        is_bullet = False

    if is_bullet:
        # Imported lazily, only when a bullet environment is requested.
        from garage.envs.bullet import BulletEnv
        return BulletEnv(*args, **kwargs)

    return super(GymEnv, cls).__new__(cls)
def __new__(cls, *args, **kwargs):
    """Returns environment specific wrapper based on input environment type.

    The environment is taken from the ``env`` keyword argument or, failing
    that, from the first positional argument when it is a ``TimeLimit``
    instance. The environment name is taken from the ``env_name`` keyword
    argument or, failing that, from the second positional argument. Either
    one selects ``BulletEnv`` when its id is listed by
    ``_get_bullet_env_list()``.

    Args:
        args: positional arguments
        kwargs: keyword arguments

    Returns:
        garage.envs.bullet.BulletEnv: if the environment is a bullet-based
            environment. Else returns a garage.envs.GarageEnv

    """
    # Determine if the input env is a bullet-based gym environment
    env = None
    if 'env' in kwargs:  # env passed as a keyword arg
        env = kwargs['env']
    elif len(args) >= 1 and isinstance(args[0], TimeLimit):
        # env passed as a positional arg
        # only checks env created by gym.make(), which has type TimeLimit
        env = args[0]

    if env is not None:
        # A keyword `env` is not guaranteed to be a TimeLimit wrapper, and
        # the wrapped environment may carry no spec. Resolve the id
        # defensively so such inputs fall through to the default wrapper
        # instead of raising AttributeError.
        inner_spec = getattr(getattr(env, 'env', None), 'spec', None)
        env_id = getattr(inner_spec, 'id', None)
        if env_id is not None and env_id in _get_bullet_env_list():
            return BulletEnv(env)

    env_name = ''
    if 'env_name' in kwargs:  # env_name as a keyword arg
        env_name = kwargs['env_name']
    elif len(args) >= 2:  # env_name as a positional arg
        env_name = args[1]

    if env_name != '' and env_name in _get_bullet_env_list():
        return BulletEnv(gym.make(env_name))

    return super(GarageEnv, cls).__new__(cls)
def test_pickle_creates_new_server(env_ids):
    """Test pickling a Bullet environment creates a new connection.

    If every unpickled copy creates a new connection, no client id may
    repeat among the original environment and its copies.

    Args:
        env_ids: iterable of environment id strings; every ``'- '``
            occurrence is removed from each entry before use.

    """
    n_env = 4
    for env_id in env_ids:
        # extract id string
        env_id = env_id.replace('- ', '')
        if env_id == 'RacecarZedBulletEnv-v0':
            bullet_env = BulletEnv(gym.make(env_id, renders=False))
        else:
            bullet_env = BulletEnv(gym.make(env_id))
        envs = [pickle.loads(pickle.dumps(bullet_env)) for _ in range(n_env)]

        try:
            inner = bullet_env.env
            if hasattr(inner, '_pybullet_client'):
                client_attr = '_pybullet_client'
            elif hasattr(inner, '_p') and isinstance(inner._p, BulletClient):
                client_attr = '_p'
            else:
                # Some environments have _p as the pybullet module, and
                # they don't store client id, so can't check here
                client_attr = None

            if client_attr is not None:
                id_set = {getattr(inner, client_attr)._client}
                for e in envs:
                    new_id = getattr(e._env, client_attr)._client
                    assert new_id not in id_set
                    id_set.add(new_id)
        finally:
            # Close the unpickled copies so their connections do not
            # accumulate across the environments under test.
            for e in envs:
                e.close()
def __new__(cls, *args, **kwargs):
    """Returns environment specific wrapper based on input environment type.

    The environment is taken from the ``env`` keyword argument or, failing
    that, from the first positional argument when it is a ``TimeLimit``
    instance. The environment name is taken from the ``env_name`` keyword
    argument or, failing that, from the second positional argument. Either
    one selects ``BulletEnv`` when its id contains ``'Bullet'``.

    Args:
        args: positional arguments
        kwargs: keyword arguments

    Returns:
        garage.envs.bullet.BulletEnv: if the environment is a bullet-based
            environment. Else returns a garage.envs.GarageEnv

    """
    # pylint: disable=import-outside-toplevel
    # Determine if the input env is a bullet-based gym environment
    env = None
    if 'env' in kwargs:  # env passed as a keyword arg
        env = kwargs['env']
    elif len(args) >= 1 and isinstance(args[0], TimeLimit):
        # env passed as a positional arg
        # only checks env created by gym.make(), which has type TimeLimit
        env = args[0]

    if env is not None:
        # A keyword `env` is not guaranteed to be a TimeLimit wrapper, and
        # the wrapped environment may carry no spec. Resolve the id
        # defensively so such inputs fall through to the default wrapper
        # instead of raising AttributeError.
        inner_spec = getattr(getattr(env, 'env', None), 'spec', None)
        env_id = getattr(inner_spec, 'id', None)
        if isinstance(env_id, str) and 'Bullet' in env_id:
            from garage.envs.bullet import BulletEnv
            return BulletEnv(env)

    env_name = ''
    if 'env_name' in kwargs:  # env_name as a keyword arg
        env_name = kwargs['env_name']
    elif len(args) >= 2:  # env_name as a positional arg
        env_name = args[1]

    # Only strings can name an environment; anything else (e.g. None)
    # falls through to the default wrapper.
    if isinstance(env_name, str) and 'Bullet' in env_name:
        from garage.envs.bullet import BulletEnv
        return BulletEnv(gym.make(env_name))

    return super(GarageEnv, cls).__new__(cls)
def test_pickleable(env_ids):
    """Check that each Bullet environment survives a pickle round trip."""
    for raw_id in env_ids:
        # Remove every '- ' occurrence to extract the id string.
        name = raw_id.replace('- ', '')
        env = BulletEnv(env_name=name)
        serialized = pickle.dumps(env)
        round_trip = pickle.loads(serialized)
        assert round_trip
        env.close()
def test_pickleable(env_ids):
    """Check that each supported Bullet environment can be pickled."""
    for raw_id in env_ids:
        # Remove every '- ' occurrence to extract the id string.
        name = raw_id.replace('- ', '')
        # NOTE(review): pytest.skip() ends the whole test at the first
        # unsupported id, so later ids are not exercised - confirm that
        # this is intended.
        if name in _get_unsupported_env_list():
            pytest.skip('Skip unsupported Bullet environments')
        env = BulletEnv(env_name=name)
        serialized = pickle.dumps(env)
        round_trip = pickle.loads(serialized)
        assert round_trip
        env.close()
def test_pickleable(env_ids):
    """Test Bullet environments are pickleable.

    Args:
        env_ids: iterable of environment id strings; every ``'- '``
            occurrence is removed from each entry before use.

    """
    for env_id in env_ids:
        # extract id string
        env_id = env_id.replace('- ', '')
        if env_id == 'RacecarZedBulletEnv-v0':
            env = BulletEnv(gym.make(env_id, renders=False))
        else:
            env = BulletEnv(gym.make(env_id))
        try:
            round_trip = pickle.loads(pickle.dumps(env))
            assert round_trip
        finally:
            # Close the environment so that it is not left open while the
            # remaining ids are tested.
            env.close()
def test_time_limit_env():
    """Test BulletEnv reports time limit expiration.

    With max_episode_steps set to 50, after 50 steps the last step is
    expected to be flagged as truncated rather than done, and
    info['BulletEnv.TimeLimitTerminated'] is expected to be True.

    """
    step_limit = 50
    env = BulletEnv(gym.make('MinitaurBulletEnv-v0'))
    env.env._max_episode_steps = step_limit
    env.reset()

    for _ in range(step_limit):
        action = env.spec.action_space.sample()
        _, _, done, info = env.step(action)

    # Only the outcome of the final step is inspected.
    assert not done
    assert info['TimeLimit.truncated']
    assert info['BulletEnv.TimeLimitTerminated']
def test_can_step(env_ids):
    """Check that each Bullet environment can be reset and stepped."""
    for raw_id in env_ids:
        # Remove every '- ' occurrence to extract the id string.
        name = raw_id.replace('- ', '')

        if name == 'KukaCamBulletEnv-v0':
            # Kuka environments calls py_bullet.resetSimulation() in reset()
            # unconditionally, which globally resets other simulations. So
            # only one Kuka environment is tested.
            continue

        # RacecarZed is the only environment created with renders=False.
        make_kwargs = {}
        if name == 'RacecarZedBulletEnv-v0':
            make_kwargs['renders'] = False
        env = BulletEnv(gym.make(name, **make_kwargs))

        ob_space = env.observation_space
        act_space = env.action_space
        env.reset()

        sampled_ob = ob_space.sample()
        assert ob_space.contains(sampled_ob)
        sampled_action = act_space.sample()
        assert act_space.contains(sampled_action)

        # Skip rendering because it causes TravisCI to run out of memory
        step_env(env, render=False)
        env.close()
def trpo_cartpole_bullet(ctxt=None, seed=1):
    """Train TRPO on Pybullet's CartPoleBulletEnv environment.

    Args:
        ctxt (garage.experiment.ExperimentContext): The experiment
            configuration used by LocalRunner to create the snapshotter.
        seed (int): Used to seed the random number generator to produce
            determinism.

    """
    set_seed(seed)
    with LocalTFRunner(ctxt) as runner:
        cartpole = gym.make('CartPoleBulletEnv-v1',
                            renders=False,
                            discrete_actions=True)
        env = BulletEnv(cartpole)
        spec = env.spec

        policy = CategoricalMLPPolicy(name='policy',
                                      env_spec=spec,
                                      hidden_sizes=(32, 32))
        baseline = LinearFeatureBaseline(env_spec=spec)

        algo = TRPO(env_spec=spec,
                    policy=policy,
                    baseline=baseline,
                    max_episode_length=1000,
                    discount=0.99,
                    max_kl_step=0.01)

        runner.setup(algo, env)
        runner.train(n_epochs=100, batch_size=4000)
def step_bullet_kuka_env(n_steps=1000):
    """Load, step, and visualize a Bullet Kuka environment.

    Args:
        n_steps (int): number of steps to run.

    """
    # Build the environment with rendering enabled.
    kuka = gym.make('KukaBulletEnv-v0',
                    renders=True,
                    isDiscrete=True,
                    maxSteps=10000000)
    env = BulletEnv(kuka)

    # Reset the environment and launch the viewer
    env.reset()
    env.render()

    # Take random actions until the episode ends or n_steps is reached.
    for _ in range(n_steps):
        _, _, done, _ = env.step(env.action_space.sample())
        if done:
            break
def test_pickle_creates_new_server(env_ids):
    """Test that unpickling a Bullet environment opens a new connection.

    If every unpickled copy creates a new connection, no client id may
    repeat among the original environment and its copies.

    """
    n_env = 4
    for raw_id in env_ids:
        # Remove every '- ' occurrence to extract the id string.
        name = raw_id.replace('- ', '')
        # NOTE(review): pytest.skip() ends the whole test at the first
        # unsupported id, so later ids are not exercised - confirm that
        # this is intended.
        if name in _get_unsupported_env_list():
            pytest.skip('Skip unsupported Bullet environments')

        bullet_env = BulletEnv(env_name=name)
        payload_copies = []
        for _ in range(n_env):
            payload_copies.append(pickle.loads(pickle.dumps(bullet_env)))

        # Pick the attribute that holds the client, if there is one.
        inner = bullet_env.env
        if hasattr(inner, '_pybullet_client'):
            client_attr = '_pybullet_client'
        elif hasattr(inner, '_p') and isinstance(inner._p, BulletClient):
            client_attr = '_p'
        else:
            # Some environments have _p as the pybullet module, and they
            # don't store client id, so can't check here
            client_attr = None

        if client_attr is not None:
            seen_ids = {getattr(inner, client_attr)._client}
            for copy in payload_copies:
                new_id = getattr(copy._env, client_attr)._client
                assert new_id not in seen_ids
                seen_ids.add(new_id)

        for copy in payload_copies:
            copy.close()
def test_can_step(env_ids):
    """Check that each Bullet environment can be reset and stepped."""
    # Kuka environments calls pybullet.resetSimulation() in reset()
    # unconditionally, which globally resets other simulations. So
    # only one Kuka environment is tested.
    skipped_kuka_ids = ('KukaCamBulletEnv-v0', 'KukaDiverseObjectGrasping-v0')

    for raw_id in env_ids:
        # Remove every '- ' occurrence to extract the id string.
        name = raw_id.replace('- ', '')
        if name in skipped_kuka_ids:
            continue

        env = BulletEnv(name)
        ob_space = env.observation_space
        act_space = env.action_space
        env.reset()

        sampled_ob = ob_space.sample()
        assert ob_space.contains(sampled_ob)
        sampled_action = act_space.sample()
        assert act_space.contains(sampled_action)

        # Skip rendering because it causes TravisCI to run out of memory
        step_env(env, visualize=False)
        env.close()