def test_wrapper():
    """Wrap a GridWorld and check the wrapper behaves as a Model.

    The wrapper must be a ``Model`` instance that reports itself as both
    online and generative. Then reset/step/sample are each called once as
    a smoke test; their return values are deliberately not inspected.
    """
    wrapped = Wrapper(GridWorld())

    assert isinstance(wrapped, Model)
    assert wrapped.is_online()
    assert wrapped.is_generative()

    # Smoke-test the interaction API (online: reset/step, generative: sample).
    wrapped.reset()
    random_action = wrapped.action_space.sample()
    wrapped.step(random_action)

    random_state = wrapped.observation_space.sample()
    random_action = wrapped.action_space.sample()
    wrapped.sample(random_state, random_action)
def test_render2d_interface_wrapped(ModelClass):
    """Exercise 2D rendering and video export through a ``Wrapper``.

    The check is skipped (the test returns early) unless the wrapped env
    implements ``RenderInterface2D``. When it does, rendering is enabled.
    For online envs, 2 episodes of 5 random steps are run, asserting each
    pre-step state lies in the observation space and rendering after every
    episode. A video is then saved and the render buffer cleared. The video
    file is always removed afterwards, even if the test fails part-way.

    Args:
        ModelClass: env class to instantiate (presumably supplied by pytest
            parametrization elsewhere in the module — not shown here).
    """
    env = Wrapper(ModelClass())
    if not isinstance(env.env, RenderInterface2D):
        return

    video_path = "test_video.mp4"
    try:
        env.enable_rendering()
        if env.is_online():
            for _ in range(2):
                state = env.reset()
                for _ in range(5):
                    assert env.observation_space.contains(state)
                    action = env.action_space.sample()
                    state, _, _, _ = env.step(action)
                env.render(loop=False)
            env.save_video(video_path)
            env.clear_render_buffer()
    finally:
        # Best-effort cleanup that runs even when the body raised. The file
        # may legitimately be absent (offline env, or save_video failed),
        # so filesystem errors are ignored. Other exceptions are no longer
        # silently swallowed.
        try:
            os.remove(video_path)
        except OSError:
            pass