Ejemplo n.º 1
0
class TestSingleWrappedEnv:
    """Tests for a ``PointEnv`` wrapped once by ``TaskOnehotWrapper``."""

    def setup_method(self):
        """Build the base env and wrap it with task index 1 of 5 tasks."""
        self.n_total_tasks = 5
        self.task_index = 1
        self.env = PointEnv()
        # Record the unwrapped observation length so tests can check how
        # many entries the wrapper adds.
        initial_obs, _unused = self.env.reset()
        self.base_len = len(initial_obs)
        self.wrapped = TaskOnehotWrapper(
            self.env,
            self.task_index,
            self.n_total_tasks,
        )

    def test_produces_correct_onehots(self):
        """Wrapped observations end with the one-hot for the task index."""
        observation, _unused = self.wrapped.reset()
        n_tasks = self.n_total_tasks

        # The wrapper is expected to add one entry per task.
        expected_len = self.base_len + n_tasks
        assert len(observation) == expected_len

        # The trailing entries must be the one-hot vector for index 1.
        onehot_tail = observation[-n_tasks:]
        expected_onehot = np.array([0, 1, 0, 0, 0])
        assert (onehot_tail == expected_onehot).all()

    def test_spec_obs_space(self):
        """The env and its spec report the same, valid observation space."""
        observation, _unused = self.wrapped.reset()

        env_space = self.wrapped.observation_space
        assert env_space.contains(observation)

        spec_space = self.wrapped.spec.observation_space
        assert spec_space.contains(observation)
        assert spec_space == env_space

    def test_visualization(self):
        """The wrapper mirrors the base env's render modes and output."""
        base_modes = self.env.render_modes
        assert base_modes == self.wrapped.render_modes

        mode = base_modes[0]
        base_render = self.env.render(mode)
        wrapped_render = self.wrapped.render(mode)
        assert base_render == wrapped_render
Ejemplo n.º 2
0
    def test_visualization(self):
        """Check rendering through ``NormalizedEnv`` matches the inner env.

        After enabling visualization and resetting the wrapper, its render
        modes and the output of the first render mode must equal those of
        the wrapped ``PointEnv``.
        """
        base_env = PointEnv(goal=(1., 2.))
        normalized_env = NormalizedEnv(base_env)

        normalized_env.visualize()
        normalized_env.reset()

        base_modes = base_env.render_modes
        assert base_modes == normalized_env.render_modes

        first_mode = base_modes[0]
        base_render = base_env.render(first_mode)
        normalized_render = normalized_env.render(first_mode)
        assert base_render == normalized_render
Ejemplo n.º 3
0
    def test_visualization(self):
        """Check rendering through ``RL2Env`` matches the base env.

        Compares render modes and the output of the first render mode, then
        runs a reset / visualize / step / close sequence on the wrapper.
        """
        base_env = PointEnv()
        rl2_env = RL2Env(base_env)

        base_modes = base_env.render_modes
        assert base_modes == rl2_env.render_modes

        first_mode = base_modes[0]
        base_render = base_env.render(first_mode)
        rl2_render = rl2_env.render(first_mode)
        assert base_render == rl2_render

        # Exercise one step with visualization enabled, using a sampled
        # action, then release the env.
        rl2_env.reset()
        rl2_env.visualize()
        sampled_action = rl2_env.action_space.sample()
        rl2_env.step(sampled_action)
        rl2_env.close()