# Code example #1
def get_multi_task_env(
    batch_size: int = 1,
) -> Environment[RLSetting.Observations, RLSetting.Actions, RLSetting.Rewards]:
    """Build a batched, multi-task CartPole environment typed like an RLSetting env.

    Each of the `batch_size` sub-environments is a short (10-step) CartPole-v0
    whose pole length changes per task; a random task is drawn on every reset
    and the task id is appended to the observations.

    Args:
        batch_size: number of parallel copies in the vectorized env.

    Returns:
        A seeded vectorized env whose observations/actions/rewards are wrapped
        as RLSetting.Observations / RLSetting.Actions / RLSetting.Rewards.
    """

    def single_env_fn() -> gym.Env:
        # One sub-environment: CartPole with a 5-task schedule over pole length.
        env = gym.make("CartPole-v0")
        env = TimeLimit(env, max_episode_steps=10)
        env = MultiTaskEnvironment(
            env,
            task_schedule={
                0: {"length": 0.1},
                100: {"length": 0.2},
                200: {"length": 0.3},
                300: {"length": 0.4},
                400: {"length": 0.5},
            },
            add_task_id_to_obs=True,
            new_random_task_on_reset=True,
        )
        return env

    # BUG FIX: `batch_size` used to be reassigned to 1 here, silently ignoring
    # the caller's argument. The parameter is now honored.
    env = SyncVectorEnv([single_env_fn for _ in range(batch_size)])
    from sequoia.common.gym_wrappers import AddDoneToObservation
    from sequoia.settings.active import TypedObjectsWrapper

    env = AddDoneToObservation(env)
    # Wrap the observations so they appear as though they are from the given setting.
    env = TypedObjectsWrapper(
        env,
        observations_type=RLSetting.Observations,
        actions_type=RLSetting.Actions,
        rewards_type=RLSetting.Rewards,
    )
    env.seed(123)
    return env
# Code example #2
def test_space_with_tuple_observations(batch_size: int, n_workers: Optional[int]):
    """Check that dict observation spaces (x + task_labels) batch correctly."""

    def make_env():
        base = gym.make("Breakout-v0")
        return MultiTaskEnvironment(
            base, add_task_id_to_obs=True, add_task_dict_to_info=True
        )

    # from gym.vector.utils import batch_space
    # env = BatchedVectorEnv(env_fns, n_workers=n_workers)
    from gym.vector import SyncVectorEnv
    env = SyncVectorEnv([make_env for _ in range(batch_size)]) # FIXME: debugging
    # env = AsyncVectorEnv(env_fns)
    env.seed(123)

    # The batched space stacks `x` along a leading batch axis and turns the
    # per-env Discrete task label into a MultiDiscrete of size batch_size.
    expected_batched = spaces.Dict(
        x=spaces.Box(0, 255, (batch_size, 210, 160, 3), np.uint8),
        task_labels=spaces.MultiDiscrete(np.ones(batch_size)),
    )
    assert env.observation_space == expected_batched

    expected_single = spaces.Dict(
        x=spaces.Box(0, 255, (210, 160, 3), np.uint8),
        task_labels=spaces.Discrete(1),
    )
    assert env.single_observation_space == expected_single

    # Reset observations should match the batched space exactly.
    obs = env.reset()
    assert obs["x"].shape == env.observation_space["x"].shape
    assert obs["task_labels"].shape == env.observation_space["task_labels"].shape
    assert obs in env.observation_space

    # A step should also produce batched obs/rewards/dones/infos.
    actions = env.action_space.sample()
    step_obs, rewards, done, info = env.step(actions)
    assert step_obs in env.observation_space

    assert len(rewards) == batch_size
    assert len(done) == batch_size
    assert all(isinstance(v, bool) for v in done.tolist()), [type(v) for v in done]
    assert len(info) == batch_size