コード例 #1
0
ファイル: gym_env.py プロジェクト: kristian-georgiev/garage
    def __new__(cls, *args, **kwargs):
        """Create a Bullet-specific wrapper when the input env is Bullet-based.

        The environment is read from the ``env`` keyword argument when it is
        present, otherwise from the first positional argument. It counts as
        Bullet-based when it is a ``gym.Env`` whose spec id contains
        ``'Bullet'``, or when it is a string containing ``'Bullet'``.

        Args:
            *args: Positional arguments, forwarded unchanged to BulletEnv.
            **kwargs: Keyword arguments, forwarded unchanged to BulletEnv.

        Returns:
            garage.envs.bullet.BulletEnv: if the environment is a bullet-based
                environment. Else returns a new instance of ``cls`` created
                by the parent ``__new__``.

        """
        # pylint: disable=import-outside-toplevel
        # Locate the environment argument: keyword first, then positional.
        if 'env' in kwargs:
            candidate = kwargs['env']
        elif args:
            candidate = args[0]
        else:
            candidate = None

        # Decide whether the candidate refers to a Bullet environment.
        if isinstance(candidate, gym.Env):
            is_bullet = (candidate.spec and hasattr(candidate.spec, 'id')
                         and candidate.spec.id.find('Bullet') != -1)
        elif isinstance(candidate, str):
            is_bullet = 'Bullet' in candidate
        else:
            is_bullet = False

        if is_bullet:
            # Imported inside the method (see the pylint disable above).
            from garage.envs.bullet import BulletEnv
            return BulletEnv(*args, **kwargs)

        return super(GymEnv, cls).__new__(cls)
コード例 #2
0
ファイル: garage_env.py プロジェクト: maciejwolczyk/garage-1
    def __new__(cls, *args, **kwargs):
        """Returns environment specific wrapper based on input environment type.

        The environment instance is taken from the ``env`` keyword argument,
        or from the first positional argument when that argument is a
        ``TimeLimit`` wrapper. The environment name is taken from the
        ``env_name`` keyword argument or the second positional argument.
        Either one selects a BulletEnv when its id appears in the list
        returned by ``_get_bullet_env_list()``.

        Args:
            args: positional arguments
            kwargs: keyword arguments

        Returns:
             garage.envs.bullet.BulletEnv: if the environment is a bullet-based
                environment. Else returns a garage.envs.GarageEnv
        """
        # Determine if the input env is a bullet-based gym environment
        env = None
        if 'env' in kwargs:  # env passed as a keyword arg
            env = kwargs['env']
        elif args and isinstance(args[0], TimeLimit):
            # env passed as a positional arg
            # only checks env created by gym.make(), which has type TimeLimit
            env = args[0]

        if env is not None:
            # An env passed by keyword is not type-checked above, so it may
            # lack an inner ``env`` or carry a ``None`` spec. Walk the chain
            # defensively instead of raising AttributeError, and fall through
            # to the regular wrapper when no spec id can be found.
            inner_spec = getattr(getattr(env, 'env', None), 'spec', None)
            spec_id = getattr(inner_spec, 'id', None)
            if spec_id is not None and spec_id in _get_bullet_env_list():
                return BulletEnv(env)

        env_name = ''
        if 'env_name' in kwargs:  # env_name as a keyword arg
            env_name = kwargs['env_name']
        elif len(args) >= 2:
            # env_name as a positional arg
            env_name = args[1]
        # The empty default never matches, so skip the list lookup for it.
        if env_name and env_name in _get_bullet_env_list():
            return BulletEnv(gym.make(env_name))

        return super(GarageEnv, cls).__new__(cls)
コード例 #3
0
ファイル: test_bullet_env.py プロジェクト: j-donahue/garage
def test_pickle_creates_new_server(env_ids):
    """Test pickling a Bullet environment creates a new connection.

    If every unpickled copy opens its own connection, no client id is
    repeated among the original environment and its copies.

    Args:
        env_ids: Iterable of environment id strings; any '- ' marker in an
            id is removed before use.

    """
    n_env = 4
    for env_id in env_ids:
        # extract id string
        env_id = env_id.replace('- ', '')
        if env_id == 'RacecarZedBulletEnv-v0':
            bullet_env = BulletEnv(gym.make(env_id, renders=False))
        else:
            bullet_env = BulletEnv(gym.make(env_id))

        envs = []
        try:
            for _ in range(n_env):
                envs.append(pickle.loads(pickle.dumps(bullet_env)))

            # Find which attribute (if any) holds a client object that
            # exposes a client id.
            inner = bullet_env.env
            if hasattr(inner, '_pybullet_client'):
                client_attr = '_pybullet_client'
            elif hasattr(inner, '_p') and isinstance(inner._p, BulletClient):
                client_attr = '_p'
            else:
                # Some environments have _p as the pybullet module, and they
                # don't store client id, so can't check here
                client_attr = None

            if client_attr is not None:
                id_set = {getattr(inner, client_attr)._client}
                for e in envs:
                    new_id = getattr(e._env, client_attr)._client
                    assert new_id not in id_set
                    id_set.add(new_id)
        finally:
            # Release every connection opened for this id, even when an
            # assertion above fails, so they do not pile up across ids.
            for e in envs:
                e.close()
            bullet_env.close()
コード例 #4
0
    def __new__(cls, *args, **kwargs):
        """Returns environment specific wrapper based on input environment type.

        The environment instance is taken from the ``env`` keyword argument,
        or from the first positional argument when that argument is a
        ``TimeLimit`` wrapper. The environment name is taken from the
        ``env_name`` keyword argument or the second positional argument.
        Either one selects a BulletEnv when its id contains ``'Bullet'``.

        Args:
            args: positional arguments
            kwargs: keyword arguments

        Returns:
             garage.envs.bullet.BulletEnv: if the environment is a bullet-based
                environment. Else returns a garage.envs.GarageEnv
        """
        # pylint: disable=import-outside-toplevel
        # Determine if the input env is a bullet-based gym environment
        env = None
        if 'env' in kwargs:  # env passed as a keyword arg
            env = kwargs['env']
        elif args and isinstance(args[0], TimeLimit):
            # env passed as a positional arg
            # only checks env created by gym.make(), which has type TimeLimit
            env = args[0]

        if env is not None:
            # An env passed by keyword is not type-checked above, so it may
            # lack an inner ``env`` or carry a ``None`` spec. Walk the chain
            # defensively instead of raising AttributeError, and fall through
            # to the regular wrapper when no spec id can be found.
            inner_spec = getattr(getattr(env, 'env', None), 'spec', None)
            spec_id = getattr(inner_spec, 'id', None)
            if isinstance(spec_id, str) and 'Bullet' in spec_id:
                from garage.envs.bullet import BulletEnv
                return BulletEnv(env)

        env_name = ''
        if 'env_name' in kwargs:  # env_name as a keyword arg
            env_name = kwargs['env_name']
        elif len(args) >= 2:
            # env_name as a positional arg
            env_name = args[1]
        # Guard the type so a non-string name (e.g. None) falls through to
        # the regular wrapper instead of failing on a missing str method.
        if isinstance(env_name, str) and 'Bullet' in env_name:
            from garage.envs.bullet import BulletEnv
            return BulletEnv(gym.make(env_name))

        return super(GarageEnv, cls).__new__(cls)
コード例 #5
0
def test_pickleable(env_ids):
    """Check that each listed Bullet environment survives a pickle round trip.

    Args:
        env_ids: Iterable of environment id strings; any '- ' marker in an
            id is removed before use.

    """
    # Lazily yield the cleaned ids so each one is processed just before use.
    cleaned_ids = (raw_id.replace('- ', '') for raw_id in env_ids)
    for name in cleaned_ids:
        env = BulletEnv(env_name=name)
        serialized = pickle.dumps(env)
        restored = pickle.loads(serialized)
        assert restored
        env.close()
コード例 #6
0
ファイル: test_bullet_env.py プロジェクト: j-donahue/garage-1
def test_pickleable(env_ids):
    """Test Bullet environments are pickle-able.

    Unsupported environments are skipped one at a time so the remaining ids
    are still exercised. The test is reported as skipped only when ids were
    given and every one of them was unsupported.

    Args:
        env_ids: Iterable of environment id strings; any '- ' marker in an
            id is removed before use.

    """
    n_tested = 0
    n_skipped = 0
    for env_id in env_ids:
        # extract id string
        env_id = env_id.replace('- ', '')
        if env_id in _get_unsupported_env_list():
            # pytest.skip() raises, which would abort the whole loop and
            # leave every later id untested; move on to the next id instead.
            n_skipped += 1
            continue
        env = BulletEnv(env_name=env_id)
        round_trip = pickle.loads(pickle.dumps(env))
        assert round_trip
        env.close()
        n_tested += 1

    if n_skipped and not n_tested:
        pytest.skip('Skip unsupported Bullet environments')
コード例 #7
0
ファイル: test_bullet_env.py プロジェクト: j-donahue/garage
def test_pickleable(env_ids):
    """Test Bullet environments are pickleable.

    Args:
        env_ids: Iterable of environment id strings; any '- ' marker in an
            id is removed before use.

    """
    for env_id in env_ids:
        # extract id string
        env_id = env_id.replace('- ', '')
        if env_id == 'RacecarZedBulletEnv-v0':
            env = BulletEnv(gym.make(env_id, renders=False))
        else:
            env = BulletEnv(gym.make(env_id))
        try:
            round_trip = pickle.loads(pickle.dumps(env))
            assert round_trip
        finally:
            # Close the environment created for this id so that environments
            # do not accumulate while looping over every id.
            env.close()
コード例 #8
0
ファイル: test_bullet_env.py プロジェクト: j-donahue/garage-1
def test_time_limit_env():
    """Test BulletEnv reports time limit expiration through ``info``.

    After setting max_episode_steps=50 and taking 50 steps, the last step is
    expected to return a false ``done`` while info['TimeLimit.truncated'] and
    info['BulletEnv.TimeLimitTerminated'] are both true.

    """
    env = BulletEnv(gym.make('MinitaurBulletEnv-v0'))
    try:
        env.env._max_episode_steps = 50
        env.reset()
        for _ in range(50):
            _, _, done, info = env.step(env.spec.action_space.sample())
        # Separate assertions so a failure points at the exact expectation.
        assert not done
        assert info['TimeLimit.truncated']
        assert info['BulletEnv.TimeLimitTerminated']
    finally:
        # Always close the environment, even when an assertion fails.
        env.close()
コード例 #9
0
ファイル: test_bullet_env.py プロジェクト: j-donahue/garage
def test_can_step(env_ids):
    """Check that each listed Bullet environment can be reset and stepped.

    Args:
        env_ids: Iterable of environment id strings; any '- ' marker in an
            id is removed before use.

    """
    for raw_id in env_ids:
        # Drop the '- ' marker to get the bare id string.
        name = raw_id.replace('- ', '')

        # Kuka environments call py_bullet.resetSimulation() in reset()
        # unconditionally, which globally resets other simulations, so this
        # Kuka variant is left out.
        if name != 'KukaCamBulletEnv-v0':
            # Only the Racecar Zed environment gets renders=False.
            make_kwargs = {}
            if name == 'RacecarZedBulletEnv-v0':
                make_kwargs['renders'] = False
            env = BulletEnv(gym.make(name, **make_kwargs))

            observation_space = env.observation_space
            action_space = env.action_space
            env.reset()

            sampled_obs = observation_space.sample()
            assert observation_space.contains(sampled_obs)
            sampled_action = action_space.sample()
            assert action_space.contains(sampled_action)

            # Rendering is skipped because it causes TravisCI to run out of
            # memory.
            step_env(env, render=False)
            env.close()
コード例 #10
0
def trpo_cartpole_bullet(ctxt=None, seed=1):
    """Run TRPO on Pybullet's CartPoleBulletEnv-v1 environment.

    Args:
        ctxt (garage.experiment.ExperimentContext): The experiment
            configuration used by LocalRunner to create the snapshotter.
        seed (int): Used to seed the random number generator to produce
            determinism.

    """
    set_seed(seed)
    with LocalTFRunner(ctxt) as runner:
        # Build the gym environment first, then wrap it for garage.
        cartpole = gym.make('CartPoleBulletEnv-v1',
                            renders=False,
                            discrete_actions=True)
        env = BulletEnv(cartpole)

        policy = CategoricalMLPPolicy(
            name='policy',
            env_spec=env.spec,
            hidden_sizes=(32, 32),
        )
        baseline = LinearFeatureBaseline(env_spec=env.spec)

        # Collect the algorithm settings in one place before constructing it.
        trpo_settings = {
            'env_spec': env.spec,
            'policy': policy,
            'baseline': baseline,
            'max_episode_length': 1000,
            'discount': 0.99,
            'max_kl_step': 0.01,
        }
        algo = TRPO(**trpo_settings)

        runner.setup(algo, env)
        runner.train(n_epochs=100, batch_size=4000)
コード例 #11
0
def step_bullet_kuka_env(n_steps=1000):
    """Create a Bullet Kuka environment, render it, and take random steps.

    Stepping stops as soon as a step reports ``done`` or once ``n_steps``
    steps have been taken.

    Args:
        n_steps (int): number of steps to run.

    """
    # Build the gym environment with rendering on, then wrap it.
    kuka = gym.make('KukaBulletEnv-v0',
                    renders=True,
                    isDiscrete=True,
                    maxSteps=10000000)
    env = BulletEnv(kuka)

    # Reset before the first render call.
    env.reset()
    env.render()

    # Sample random actions until the episode ends or the budget runs out.
    taken = 0
    done = False
    while not done and taken < n_steps:
        _, _, done, _ = env.step(env.action_space.sample())
        taken += 1
コード例 #12
0
ファイル: test_bullet_env.py プロジェクト: j-donahue/garage-1
def test_pickle_creates_new_server(env_ids):
    """Test pickling a Bullet environment creates a new connection.

    If every unpickled copy opens its own connection, no client id is
    repeated among the original environment and its copies.

    Unsupported environments are skipped one at a time so the remaining ids
    are still exercised. The test is reported as skipped only when ids were
    given and every one of them was unsupported.

    Args:
        env_ids: Iterable of environment id strings; any '- ' marker in an
            id is removed before use.

    """
    n_env = 4
    n_tested = 0
    n_skipped = 0
    for env_id in env_ids:
        # extract id string
        env_id = env_id.replace('- ', '')
        if env_id in _get_unsupported_env_list():
            # pytest.skip() raises, which would abort the whole loop and
            # leave every later id untested; move on to the next id instead.
            n_skipped += 1
            continue

        bullet_env = BulletEnv(env_name=env_id)
        envs = []
        try:
            for _ in range(n_env):
                envs.append(pickle.loads(pickle.dumps(bullet_env)))

            # Find which attribute (if any) holds a client object that
            # exposes a client id.
            inner = bullet_env.env
            if hasattr(inner, '_pybullet_client'):
                client_attr = '_pybullet_client'
            elif hasattr(inner, '_p') and isinstance(inner._p, BulletClient):
                client_attr = '_p'
            else:
                # Some environments have _p as the pybullet module, and they
                # don't store client id, so can't check here
                client_attr = None

            if client_attr is not None:
                id_set = {getattr(inner, client_attr)._client}
                for e in envs:
                    new_id = getattr(e._env, client_attr)._client
                    assert new_id not in id_set
                    id_set.add(new_id)
        finally:
            # Release every connection opened for this id, including the
            # original environment, even when an assertion above fails.
            for env in envs:
                env.close()
            bullet_env.close()
        n_tested += 1

    if n_skipped and not n_tested:
        pytest.skip('Skip unsupported Bullet environments')
コード例 #13
0
ファイル: test_bullet_env.py プロジェクト: ziyiwu9494/garage
def test_can_step(env_ids):
    """Check that each listed Bullet environment can be reset and stepped.

    Args:
        env_ids: Iterable of environment id strings; any '- ' marker in an
            id is removed before use.

    """
    # Kuka environments call pybullet.resetSimulation() in reset()
    # unconditionally, which globally resets other simulations, so these
    # Kuka variants are left out.
    excluded = ('KukaCamBulletEnv-v0', 'KukaDiverseObjectGrasping-v0')

    for raw_id in env_ids:
        # Drop the '- ' marker to get the bare id string.
        name = raw_id.replace('- ', '')
        if name not in excluded:
            env = BulletEnv(name)
            observation_space = env.observation_space
            action_space = env.action_space
            env.reset()

            sampled_obs = observation_space.sample()
            assert observation_space.contains(sampled_obs)
            sampled_action = action_space.sample()
            assert action_space.contains(sampled_action)

            # Rendering is skipped because it causes TravisCI to run out of
            # memory.
            step_env(env, visualize=False)
            env.close()