def test_wrapper():
    """Check that ``Wrapper`` around a ``GridWorld`` behaves as a ``Model``.

    Asserts the wrapper is a ``Model`` instance reporting both online and
    generative capabilities. It then smoke-tests ``reset``, ``step`` and
    ``sample`` with randomly sampled actions and states; those calls only
    need to not raise.
    """
    wrapper = Wrapper(GridWorld())

    assert isinstance(wrapper, Model)
    # Online is checked first, then generative, as in the original order.
    for capability in (wrapper.is_online, wrapper.is_generative):
        assert capability()

    # Smoke test: the following calls are expected simply not to raise.
    wrapper.reset()
    wrapper.step(wrapper.action_space.sample())
    random_state = wrapper.observation_space.sample()
    random_action = wrapper.action_space.sample()
    wrapper.sample(random_state, random_action)
Example #2
0
def test_render2d_interface_wrapped(ModelClass):
    """Exercise 2D rendering through a ``Wrapper`` around ``ModelClass``.

    Does nothing unless the wrapped env implements ``RenderInterface2D``.

    For online envs, the test:

    * runs two rollouts of five random-action steps each;
    * checks that every observation (after ``reset`` and after each
      ``step``) lies in ``observation_space``;
    * renders once after each rollout;
    * saves the render buffer to a video and then clears the buffer.

    The video file is removed afterwards, even if the test fails midway.

    Args:
        ModelClass: class instantiated with no arguments and wrapped in
            ``Wrapper``. It is presumably supplied by pytest
            parametrization; confirm against the decorator outside this
            view.
    """
    env = Wrapper(ModelClass())

    if not isinstance(env.env, RenderInterface2D):
        return

    video_path = "test_video.mp4"
    env.enable_rendering()
    try:
        if env.is_online():
            for _ in range(2):
                state = env.reset()
                assert env.observation_space.contains(state)
                for _ in range(5):
                    action = env.action_space.sample()
                    # Unpacks a 4-tuple from step(); only the observation is used.
                    state, _, _, _ = env.step(action)
                    # Validate every post-step observation, including the last.
                    assert env.observation_space.contains(state)
                env.render(loop=False)
            env.save_video(video_path)
            env.clear_render_buffer()
    finally:
        # Best-effort cleanup. The file only exists if save_video ran, and
        # failing to delete it should not fail the test. Only OSError is
        # caught so that unrelated errors are not silently swallowed.
        try:
            os.remove(video_path)
        except OSError:
            pass