class TestSingleWrappedEnv:
    """Tests for a single PointEnv wrapped in a TaskOnehotWrapper."""

    def setup_method(self):
        """Build the bare env and its one-hot-wrapped counterpart."""
        self.n_total_tasks = 5
        self.task_index = 1
        self.env = PointEnv()
        # Length of the unwrapped observation, used as the baseline below.
        initial_obs, _ = self.env.reset()
        self.base_len = len(initial_obs)
        self.wrapped = TaskOnehotWrapper(
            self.env,
            self.task_index,
            self.n_total_tasks,
        )

    def test_produces_correct_onehots(self):
        """The wrapped observation is the base one plus a task one-hot."""
        first_obs, _ = self.wrapped.reset()
        expected_len = self.base_len + self.n_total_tasks
        assert len(first_obs) == expected_len

        # The trailing n_total_tasks entries must mark task index 1.
        onehot_tail = first_obs[-self.n_total_tasks:]
        expected_onehot = np.array([0, 1, 0, 0, 0])
        assert (onehot_tail == expected_onehot).all()

    def test_spec_obs_space(self):
        """The env's and the spec's observation spaces agree on reset obs."""
        first_obs, _ = self.wrapped.reset()
        env_space = self.wrapped.observation_space
        assert env_space.contains(first_obs)

        spec_space = self.wrapped.spec.observation_space
        assert spec_space.contains(first_obs)
        assert spec_space == self.wrapped.observation_space

    def test_visualization(self):
        """Render modes and render output pass through the wrapper."""
        base_modes = self.env.render_modes
        assert base_modes == self.wrapped.render_modes

        first_mode = self.env.render_modes[0]
        base_render = self.env.render(first_mode)
        wrapped_render = self.wrapped.render(first_mode)
        assert base_render == wrapped_render
def test_observation_dimension(self):
    """Check that RL2Env widens observations by ``action_dim + 2``.

    The expected width is the plain PointEnv observation width plus the
    action width plus 2. It is checked against the wrapped env's spec,
    the first observation from ``reset()``, and the observation after
    one ``step()``.
    """
    env = PointEnv()
    wrapped_env = RL2Env(PointEnv())
    expected_dim = (env.observation_space.shape[0] +
                    env.action_space.shape[0] + 2)
    assert wrapped_env.spec.observation_space.shape[0] == expected_dim

    # reset() yields a (first_observation, episode_info) pair -- the same
    # contract TestSingleWrappedEnv relies on for PointEnv -- so it must
    # be unpacked before reading ``.shape``.
    obs, _ = env.reset()
    obs2, _ = wrapped_env.reset()
    assert obs.shape[0] + env.action_space.shape[0] + 2 == obs2.shape[0]

    # NOTE(review): assumes the garage Environment API, where step()
    # returns an EnvStep record rather than a 4-tuple -- confirm against
    # the PointEnv / RL2Env versions in use.
    obs = env.step(env.action_space.sample()).observation
    obs2 = wrapped_env.step(env.action_space.sample()).observation
    assert obs.shape[0] + env.action_space.shape[0] + 2 == obs2.shape[0]