コード例 #1
0
ファイル: test_async_vector_env.py プロジェクト: odunboye/gym
def test_step_async_vector_env(shared_memory, use_single_action_space):
    """Step an 8-env ``AsyncVectorEnv`` once and check the batched outputs.

    Args:
        shared_memory: Forwarded to ``AsyncVectorEnv`` (presumably a pytest
            parametrization toggling the shared-memory transport — confirm).
        use_single_action_space: If true, build the batch of actions as a list
            with one ``single_action_space`` sample per sub-env; otherwise
            sample the batched ``action_space`` directly.
    """
    num_envs = 8
    env_fns = [make_env("CubeCrash-v0", i) for i in range(num_envs)]
    # Construct outside the ``try``: if the constructor raised inside it,
    # ``env`` would be unbound and ``env.close()`` in ``finally`` would mask
    # the real error with an UnboundLocalError.
    env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
    try:
        env.reset()
        if use_single_action_space:
            actions = [env.single_action_space.sample() for _ in range(num_envs)]
        else:
            actions = env.action_space.sample()
        observations, rewards, dones, _ = env.step(actions)
    finally:
        env.close()

    # Observations: one batched ndarray matching the batched Box space.
    assert isinstance(env.observation_space, Box)
    assert isinstance(observations, np.ndarray)
    assert observations.dtype == env.observation_space.dtype
    assert observations.shape == (num_envs,) + env.single_observation_space.shape
    assert observations.shape == env.observation_space.shape

    # Rewards: a flat float array with one entry per sub-env.
    assert isinstance(rewards, np.ndarray)
    assert isinstance(rewards[0], (float, np.floating))
    assert rewards.ndim == 1
    assert rewards.size == num_envs

    # Dones: a flat boolean array with one entry per sub-env.
    assert isinstance(dones, np.ndarray)
    assert dones.dtype == np.bool_
    assert dones.ndim == 1
    assert dones.size == num_envs
コード例 #2
0
def test_custom_space_async_vector_env():
    """Check ``AsyncVectorEnv`` over sub-envs that use a ``CustomSpace``.

    The single-env spaces are ``CustomSpace`` and the batched spaces are
    ``Tuple``; reset/step observations come back as plain Python tuples with
    one element per sub-env.
    """
    env_fns = [make_custom_space_env(i) for i in range(4)]
    # shared_memory=False, presumably because a CustomSpace cannot be laid out
    # in shared memory — confirm. Construct before ``try`` so a constructor
    # failure is not masked by ``env.close()`` on an unbound name.
    env = AsyncVectorEnv(env_fns, shared_memory=False)
    try:
        reset_observations = env.reset()

        assert isinstance(env.single_action_space, CustomSpace)
        assert isinstance(env.action_space, Tuple)

        actions = ("action-2", "action-3", "action-5", "action-7")
        step_observations, _, _, _ = env.step(actions)
    finally:
        env.close()

    assert isinstance(env.single_observation_space, CustomSpace)
    assert isinstance(env.observation_space, Tuple)

    assert isinstance(reset_observations, tuple)
    assert reset_observations == ("reset", "reset", "reset", "reset")

    assert isinstance(step_observations, tuple)
    assert step_observations == (
        "step(action-2)",
        "step(action-3)",
        "step(action-5)",
        "step(action-7)",
    )
コード例 #3
0
def main(env_id="Ant-v3", num_envs=5, num_steps=5000):
    """Drive ``num_envs`` copies of ``env_id`` with random actions.

    Runs ``num_steps`` batched steps through an ``AsyncVectorEnv`` and, on
    every step where at least one sub-env reports ``done``, prints the list of
    those sub-env indices.

    Args:
        env_id: Environment id passed to ``make_env``.
        num_envs: Number of sub-environments.
        num_steps: Number of batched ``step`` calls to make.
    """
    vec_env = AsyncVectorEnv([make_env(env_id) for _ in range(num_envs)])
    # Always close the vectorized env, even if a step raises, so its
    # resources are released.
    try:
        vec_env.reset()
        for _ in range(num_steps):
            action = vec_env.action_space.sample()
            _, _, done, _ = vec_env.step(action)
            done_idx = [idx for idx, is_done in enumerate(done) if is_done]
            if done_idx:
                print(f"{done_idx}")
    finally:
        vec_env.close()
コード例 #4
0
ファイル: test_vector_env.py プロジェクト: MalteEbner/gym
def test_vector_env_equal(shared_memory):
    """Check that ``AsyncVectorEnv`` and ``SyncVectorEnv`` behave identically.

    Both vectorized envs are built from the same factories and seeded with 0.
    Their spaces, reset observations, and ``num_steps`` steps of
    observations/rewards/dones must match. Whenever a sub-env is done, both
    implementations must agree on it and expose ``"terminal_observation"`` in
    that sub-env's info dict.

    Args:
        shared_memory: Forwarded to ``AsyncVectorEnv``.
    """
    from contextlib import ExitStack

    env_fns = [make_env("CubeCrash-v0", i) for i in range(4)]
    num_steps = 100
    # ExitStack closes every env that was actually created, even if a later
    # constructor or an earlier ``close`` raises. A plain try/finally would
    # reference unbound names or skip the second close in those cases.
    with ExitStack() as stack:
        async_env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
        stack.callback(async_env.close)
        sync_env = SyncVectorEnv(env_fns)
        stack.callback(sync_env.close)

        async_env.seed(0)
        sync_env.seed(0)

        assert async_env.num_envs == sync_env.num_envs
        assert async_env.observation_space == sync_env.observation_space
        assert async_env.single_observation_space == sync_env.single_observation_space
        assert async_env.action_space == sync_env.action_space
        assert async_env.single_action_space == sync_env.single_action_space

        async_observations = async_env.reset()
        sync_observations = sync_env.reset()
        assert np.all(async_observations == sync_observations)

        for _ in range(num_steps):
            actions = async_env.action_space.sample()
            assert actions in sync_env.action_space

            # fmt: off
            async_observations, async_rewards, async_dones, async_infos = async_env.step(
                actions)
            sync_observations, sync_rewards, sync_dones, sync_infos = sync_env.step(
                actions)
            # fmt: on

            for idx, sync_done in enumerate(sync_dones):
                if sync_done:
                    # The original re-asserted ``sync_dones[idx]`` here, which
                    # is always true inside this branch; check the async side.
                    assert async_dones[idx]
                    assert "terminal_observation" in async_infos[idx]
                    assert "terminal_observation" in sync_infos[idx]

            assert np.all(async_observations == sync_observations)
            assert np.all(async_rewards == sync_rewards)
            assert np.all(async_dones == sync_dones)
def test_custom_space_async_vector_env():
    """Check ``AsyncVectorEnv`` over sub-envs that use a ``CustomSpace``.

    Reset and step observations must come back as plain Python tuples with
    one element per sub-env; the batched observation space is a ``Tuple``.
    """
    env_fns = [make_custom_space_env(i) for i in range(4)]
    # Construct before ``try`` so a constructor failure is not masked by
    # ``env.close()`` raising UnboundLocalError in ``finally``.
    env = AsyncVectorEnv(env_fns, shared_memory=False)
    try:
        reset_observations = env.reset()
        actions = ('action-2', 'action-3', 'action-5', 'action-7')
        step_observations, _, _, _ = env.step(actions)
    finally:
        env.close()

    assert isinstance(env.single_observation_space, CustomSpace)
    assert isinstance(env.observation_space, Tuple)

    assert isinstance(reset_observations, tuple)
    assert reset_observations == ('reset', 'reset', 'reset', 'reset')

    assert isinstance(step_observations, tuple)
    assert step_observations == ('step(action-2)', 'step(action-3)',
                                 'step(action-5)', 'step(action-7)')
コード例 #6
0
def test_vector_env_equal(shared_memory):
    """Check that ``AsyncVectorEnv`` and ``SyncVectorEnv`` behave identically.

    Both are built from the same CartPole factories and reset with ``seed=0``.
    Their spaces, reset observations, and ``num_steps`` steps of
    observations/rewards/dones must match. On any step where a sub-env is
    done, both info dicts must contain the ``"terminal_observation"`` and
    ``"_terminal_observation"`` keys.

    Args:
        shared_memory: Forwarded to ``AsyncVectorEnv``.
    """
    from contextlib import ExitStack

    env_fns = [make_env("CartPole-v1", i) for i in range(4)]
    num_steps = 100
    # ExitStack closes every env that was actually created, even if a later
    # constructor or an earlier ``close`` raises. A plain try/finally would
    # reference unbound names or skip the second close in those cases.
    with ExitStack() as stack:
        async_env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
        stack.callback(async_env.close)
        sync_env = SyncVectorEnv(env_fns)
        stack.callback(sync_env.close)

        assert async_env.num_envs == sync_env.num_envs
        assert async_env.observation_space == sync_env.observation_space
        assert async_env.single_observation_space == sync_env.single_observation_space
        assert async_env.action_space == sync_env.action_space
        assert async_env.single_action_space == sync_env.single_action_space

        async_observations = async_env.reset(seed=0)
        sync_observations = sync_env.reset(seed=0)
        assert np.all(async_observations == sync_observations)

        for _ in range(num_steps):
            actions = async_env.action_space.sample()
            assert actions in sync_env.action_space

            # fmt: off
            async_observations, async_rewards, async_dones, async_infos = async_env.step(
                actions)
            sync_observations, sync_rewards, sync_dones, sync_infos = sync_env.step(
                actions)
            # fmt: on

            if any(sync_dones):
                assert "terminal_observation" in async_infos
                assert "_terminal_observation" in async_infos
                assert "terminal_observation" in sync_infos
                assert "_terminal_observation" in sync_infos

            assert np.all(async_observations == sync_observations)
            assert np.all(async_rewards == sync_rewards)
            assert np.all(async_dones == sync_dones)
コード例 #7
0
def test_vector_env_equal(shared_memory):
    """Check that ``AsyncVectorEnv`` and ``SyncVectorEnv`` behave identically.

    Both are built from the same factories and seeded with 0. Their spaces,
    reset observations, and ``num_steps`` steps of observations/rewards/dones
    must match.

    Args:
        shared_memory: Forwarded to ``AsyncVectorEnv``.
    """
    from contextlib import ExitStack

    env_fns = [make_env('CubeCrash-v0', i) for i in range(4)]
    num_steps = 100
    # ExitStack closes every env that was actually created, even if a later
    # constructor or an earlier ``close`` raises. A plain try/finally would
    # reference unbound names or skip the second close in those cases.
    with ExitStack() as stack:
        async_env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
        stack.callback(async_env.close)
        sync_env = SyncVectorEnv(env_fns)
        stack.callback(sync_env.close)

        async_env.seed(0)
        sync_env.seed(0)

        assert async_env.num_envs == sync_env.num_envs
        assert async_env.observation_space == sync_env.observation_space
        assert async_env.single_observation_space == sync_env.single_observation_space
        assert async_env.action_space == sync_env.action_space
        assert async_env.single_action_space == sync_env.single_action_space

        async_observations = async_env.reset()
        sync_observations = sync_env.reset()
        assert np.all(async_observations == sync_observations)

        for _ in range(num_steps):
            actions = async_env.action_space.sample()
            assert actions in sync_env.action_space

            async_observations, async_rewards, async_dones, _ = async_env.step(
                actions)
            sync_observations, sync_rewards, sync_dones, _ = sync_env.step(
                actions)

            assert np.all(async_observations == sync_observations)
            assert np.all(async_rewards == sync_rewards)
            assert np.all(async_dones == sync_dones)