예제 #1
0
 def setup_method(self):
     """Build the fixture: a PointEnv and its one-hot-wrapped counterpart.

     The wrapper is created for task index 1 out of 5 total tasks.
     """
     task_index, n_total_tasks = 1, 5
     env = PointEnv()
     self.env = env
     # Length of whatever reset() returns on the unwrapped env
     # (presumably the raw observation -- confirm against the env API).
     self.base_len = len(env.reset())
     self.n_total_tasks = n_total_tasks
     self.task_index = task_index
     self.wrapped = TaskOnehotWrapper(env, task_index, n_total_tasks)
예제 #2
0
class TestSingleWrappedEnv:
    """Checks on one PointEnv wrapped by a TaskOnehotWrapper."""

    def setup_method(self):
        """Create the base env and wrap it as task 1 of 5."""
        task_index, n_total_tasks = 1, 5
        env = PointEnv()
        self.env = env
        # reset() is unpacked as a pair; only the first element (the
        # observation) is needed to record the unwrapped length.
        first_obs, _ = env.reset()
        self.base_len = len(first_obs)
        self.n_total_tasks = n_total_tasks
        self.task_index = task_index
        self.wrapped = TaskOnehotWrapper(env, task_index, n_total_tasks)

    def test_produces_correct_onehots(self):
        """The wrapped observation grows by n_total_tasks and ends in the
        one-hot vector for task index 1."""
        first_obs, _ = self.wrapped.reset()
        assert len(first_obs) == self.base_len + self.n_total_tasks
        expected_onehot = np.array([0, 1, 0, 0, 0])
        onehot_tail = first_obs[-self.n_total_tasks:]
        assert (onehot_tail == expected_onehot).all()

    def test_spec_obs_space(self):
        """The observation space and the spec's observation space both
        contain a reset observation and compare equal to each other."""
        wrapped = self.wrapped
        first_obs, _ = wrapped.reset()
        assert wrapped.observation_space.contains(first_obs)
        assert wrapped.spec.observation_space.contains(first_obs)
        assert (wrapped.spec.observation_space == wrapped.observation_space)

    def test_visualization(self):
        """Render modes and the first mode's render output match between
        the base env and the wrapper."""
        env, wrapped = self.env, self.wrapped
        assert env.render_modes == wrapped.render_modes
        first_mode = env.render_modes[0]
        base_render = env.render(first_mode)
        wrapped_render = wrapped.render(first_mode)
        assert base_render == wrapped_render
예제 #3
0
def test_wrapped_env_list_produces_correct_onehots():
    """Each env from wrap_env_list carries its own one-hot task id.

    Four PointEnv instances are wrapped together; for every wrapped env the
    trailing n_total_tasks entries of both the reset observation and the
    next-step observation must equal the one-hot vector for its position.
    """
    base_envs = [PointEnv() for _ in range(4)]
    base_len = len(base_envs[0].reset())
    n_total_tasks = len(base_envs)
    wrapped_envs = TaskOnehotWrapper.wrap_env_list(base_envs)
    assert len(wrapped_envs) == n_total_tasks
    # Row k of the identity matrix is the one-hot vector for task k.
    for wrapped_env, expected in zip(wrapped_envs, np.eye(n_total_tasks)):
        first_obs = wrapped_env.reset()
        assert len(first_obs) == base_len + n_total_tasks
        assert (first_obs[-n_total_tasks:] == expected).all()
        action = wrapped_env.action_space.sample()
        next_obs, _, _, _ = wrapped_env.step(action)
        assert (next_obs[-n_total_tasks:] == expected).all()
예제 #4
0
def test_wrapped_env_cons_list_produces_correct_onehots():
    """Each constructor from wrap_env_cons_list builds an env tagged with
    its own one-hot task id.

    Six PointEnv constructors are wrapped; every env built from them must
    append the matching one-hot vector to both the reset observation and the
    observation of the following step.
    """
    base_len = 3
    constructors = [PointEnv] * 6
    n_total_tasks = len(constructors)
    wrapped_constructors = TaskOnehotWrapper.wrap_env_cons_list(constructors)
    wrapped_envs = [build() for build in wrapped_constructors]
    assert len(wrapped_envs) == n_total_tasks
    # Row k of the identity matrix is the one-hot vector for task k.
    for wrapped_env, expected in zip(wrapped_envs, np.eye(n_total_tasks)):
        first_obs, _ = wrapped_env.reset()
        assert len(first_obs) == base_len + n_total_tasks
        assert (first_obs[-n_total_tasks:] == expected).all()
        action = wrapped_env.action_space.sample()
        step_result = wrapped_env.step(action)
        next_obs = step_result.observation
        assert (next_obs[-n_total_tasks:] == expected).all()
예제 #5
0
        def wrap(env, task):
            """Apply the benchmark's wrapper stack to one environment.

            The env is first wrapped in GymEnv (episode length taken from
            ``env.max_path_length``) and TaskNameWrapper (named after
            ``task.env_name``). A TaskOnehotWrapper is added when
            ``add_env_onehot`` is truthy, and ``inner_wrapper`` is applied
            last when one was supplied.

            Args:
                env (gym.Env): A metaworld / gym environment.
                task (metaworld.Task): A metaworld task.

            Returns:
                garage.Env: The wrapped environment.

            """
            task_name = task.env_name
            wrapped = GymEnv(env, max_episode_length=env.max_path_length)
            wrapped = TaskNameWrapper(wrapped, task_name=task_name)
            if add_env_onehot:
                # The one-hot slot is looked up by the task's env name.
                wrapped = TaskOnehotWrapper(
                    wrapped,
                    task_index=task_indices[task_name],
                    n_total_tasks=len(task_indices))
            if inner_wrapper is None:
                return wrapped
            return inner_wrapper(wrapped, task)