def get_multi_task_env(
    batch_size: int = 1,
) -> Environment[RLSetting.Observations, RLSetting.Actions, RLSetting.Rewards]:
    """Build a seeded, vectorized multi-task CartPole environment for tests.

    Each sub-environment is ``CartPole-v0`` wrapped in ``TimeLimit`` (10 steps
    per episode) and in a ``MultiTaskEnvironment``. That wrapper gets:

    - a task schedule over the pole ``length`` (0.1 to 0.5 at keys 0, 100, 200,
      300 and 400);
    - ``add_task_id_to_obs=True``;
    - ``new_random_task_on_reset=True``.

    The sub-environments are batched with ``SyncVectorEnv`` and then wrapped
    with ``AddDoneToObservation`` and ``TypedObjectsWrapper``. The latter makes
    observations, actions and rewards appear as ``RLSetting`` typed objects.

    Args:
        batch_size: Number of sub-environments in the vectorized environment.

    Returns:
        The wrapped vectorized environment, seeded with 123.
    """

    def single_env_fn() -> gym.Env:
        # A single CartPole env with short episodes and a pole-length schedule.
        env = gym.make("CartPole-v0")
        env = TimeLimit(env, max_episode_steps=10)
        env = MultiTaskEnvironment(
            env,
            task_schedule={
                0: {"length": 0.1},
                100: {"length": 0.2},
                200: {"length": 0.3},
                300: {"length": 0.4},
                400: {"length": 0.5},
            },
            add_task_id_to_obs=True,
            new_random_task_on_reset=True,
        )
        return env

    # One sub-environment per batch entry, as requested by the caller.
    env = SyncVectorEnv([single_env_fn for _ in range(batch_size)])

    from sequoia.common.gym_wrappers import AddDoneToObservation
    from sequoia.settings.active import TypedObjectsWrapper

    env = AddDoneToObservation(env)
    # Wrap the observations so they appear as though they are from the given setting.
    env = TypedObjectsWrapper(
        env,
        observations_type=RLSetting.Observations,
        actions_type=RLSetting.Actions,
        rewards_type=RLSetting.Rewards,
    )
    env.seed(123)
    return env
def test_space_with_tuple_observations(batch_size: int, n_workers: Optional[int]):
    """Check spaces and step outputs of a vectorized multi-task Breakout env.

    Builds ``batch_size`` copies of ``Breakout-v0``, each wrapped in
    ``MultiTaskEnvironment`` with ``add_task_id_to_obs=True`` and
    ``add_task_dict_to_info=True``, batched with ``SyncVectorEnv``.

    It then asserts that:

    - the batched and single observation spaces are ``Dict`` spaces with ``x``
      and ``task_labels`` entries;
    - ``reset`` and ``step`` return observations within the batched space;
    - ``step`` returns rewards, dones and infos of length ``batch_size``, with
      every ``done`` a Python ``bool``.

    Args:
        batch_size: Number of sub-environments.
        n_workers: Currently unused (see the TODO in the body).
    """

    def make_env():
        env = gym.make("Breakout-v0")
        env = MultiTaskEnvironment(
            env, add_task_id_to_obs=True, add_task_dict_to_info=True
        )
        return env

    from gym.vector import SyncVectorEnv

    # TODO: `n_workers` is unused because this test builds a SyncVectorEnv; it
    # was intended for `BatchedVectorEnv(env_fns, n_workers=n_workers)`.
    env = SyncVectorEnv([make_env for _ in range(batch_size)])
    try:
        env.seed(123)

        assert env.observation_space == spaces.Dict(
            x=spaces.Box(0, 255, (batch_size, 210, 160, 3), np.uint8),
            task_labels=spaces.MultiDiscrete(np.ones(batch_size)),
        )
        assert env.single_observation_space == spaces.Dict(
            x=spaces.Box(0, 255, (210, 160, 3), np.uint8), task_labels=spaces.Discrete(1)
        )

        obs = env.reset()
        assert obs["x"].shape == env.observation_space["x"].shape
        assert obs["task_labels"].shape == env.observation_space["task_labels"].shape
        assert obs in env.observation_space

        actions = env.action_space.sample()
        step_obs, rewards, done, info = env.step(actions)
        assert step_obs in env.observation_space

        assert len(rewards) == batch_size
        assert len(done) == batch_size
        assert all(isinstance(v, bool) for v in done.tolist()), [type(v) for v in done]
        assert len(info) == batch_size
    finally:
        # Release the sub-environments even when an assertion fails.
        env.close()