def test_last_video(self):
    """
    Ensure display() flushes the final, still-open video.

    Every episode is recorded (``video_callable`` always returns True), so
    the number of videos must equal the number of episodes started.

    Ref: https://gitlab.com/ymd_h/gym-notebook-wrapper/-/issues/2
    """
    env = gnwrapper.Monitor(gym.make('CartPole-v1'),
                            directory="./test_last_videos/",
                            video_callable=lambda ep: True)
    env.reset()
    episodes_started = 1  # the reset just above opens the first episode
    for _step in range(100):
        _obs, _rew, done, _info = env.step(env.action_space.sample())
        if not done:
            continue
        env.reset()
        episodes_started += 1

    env.display()
    self.assertEqual(len(env.videos), episodes_started)
    for video in env.videos:
        video_path = video[0]
        with self.subTest(file=video_path):
            self.assertTrue(os.path.exists(video_path))

    # The wrapper must keep working normally after display().
    env.reset()
    for _step in range(100):
        _obs, _rew, done, _info = env.step(env.action_space.sample())
        if done:
            env.reset()
    env.display()
def test_display(self):
    """Run 100 random CartPole steps recorded into ``./`` and then display."""
    monitored = gnwrapper.Monitor(gym.make('CartPole-v1'), directory="./")
    monitored.reset()
    for _step in range(100):
        action = monitored.action_space.sample()
        _obs, _rew, episode_done, _info = monitored.step(action)
        if episode_done:
            monitored.reset()
    monitored.display()
def test_reset_videos(self):
    """``display(reset=True)`` must empty the list of recorded videos."""
    env = gnwrapper.Monitor(gym.make('CartPole-v1'),
                            directory="./test_reset_videos/")
    env.reset()
    for _step in range(100):
        action = env.action_space.sample()
        _obs, _rew, episode_over, _info = env.step(action)
        if episode_over:
            env.reset()

    # Some video must exist before the reset, otherwise the final check
    # would be meaningless.
    self.assertNotEqual(len(env.videos), 0)
    env.display(reset=True)
    self.assertEqual(len(env.videos), 0)
def test_default_directory(self):
    """
    Without an explicit ``directory``, recorded video paths carry a stamp.

    Each path in ``env.videos`` must match ``[0-9]{8}-[0-9]{6}``
    (presumably a YYYYMMDD-HHMMSS timestamp in the default directory name).
    """
    env = gnwrapper.Monitor(gym.make('CartPole-v1'))
    env.reset()
    for _step in range(100):
        _obs, _rew, done, _info = env.step(env.action_space.sample())
        if done:
            env.reset()

    # Without this guard the loop below asserts nothing when no video was
    # recorded, and the test would pass vacuously.
    self.assertNotEqual(len(env.videos), 0)
    for video in env.videos:
        with self.subTest(file=video[0]):
            self.assertIsNotNone(re.search(r"[0-9]{8}-[0-9]{6}", video[0]))
    env.display()
def test_display_after_close(self):
    """Calling display() on an already-closed Monitor must not fail."""
    env = gnwrapper.Monitor(gym.make('CartPole-v1'),
                            directory="./test_display_after_close/",
                            video_callable=lambda ep: True)
    env.reset()
    for _step in range(100):
        step_result = env.step(env.action_space.sample())
        _obs, _rew, done, _info = step_result
        if done:
            env.reset()

    env.close()
    env.display()
def test_KeyboardInterrupt(self):
    """
    The Monitor must stay usable after a KeyboardInterrupt during step/reset.

    An interrupt used to kill the notebook kernel. For each patched target,
    the interrupted call must re-raise KeyboardInterrupt. After that, a full
    reset/step/display/render cycle must still work.

    Ref: https://gitlab.com/ymd_h/gym-notebook-wrapper/-/issues/4
    """
    cartpole_cls = "gym.envs.classic_control.cartpole.CartPoleEnv"
    recorder_cls = "gym.wrappers.monitoring.video_recorder.VideoRecorder"
    env = gnwrapper.Monitor(gym.make('CartPole-v1'),
                            directory="./test_keyboard_interrupt/",
                            video_callable=lambda ep: True)

    def interrupt_then_recover(target, interrupted_call):
        # Start each case from a fresh episode.
        env.reset()
        with self.subTest(function=target):
            with patch(target, MagicMock(side_effect=KeyboardInterrupt)):
                with self.assertRaises(KeyboardInterrupt):
                    interrupted_call()
            # The patch has been lifted here, so normal use must work again.
            env.reset()
            env.step(env.action_space.sample())
            env.display()
            env.render(mode='rgb_array')

    # Interrupts raised while stepping / capturing a frame.
    for target in (f"{cartpole_cls}.step", f"{recorder_cls}.capture_frame"):
        interrupt_then_recover(
            target, lambda: env.step(env.action_space.sample()))

    # Interrupts raised while resetting (including waiting on a child process).
    for target in (f"{cartpole_cls}.reset", "os.waitpid"):
        interrupt_then_recover(target, env.reset)
import torch
import numpy as np
import gym
import gnwrapper
from torch import nn
from stable_baselines3 import A2C
from stable_baselines3.a2c import MlpPolicy
from stable_baselines3.common.env_checker import check_env
from gym_pybullet_drones.utils.Logger import Logger
from gym_pybullet_drones.envs.single_agent_rl.TakeoffAviary import TakeoffAviary
from gym_pybullet_drones.utils.utils import sync, str2bool

# Random-action rollouts of the take-off task, recorded through the notebook
# Monitor. Only one simulation is built: `env` is the environment wrapped by
# `monitor_env`, so no second, unused instance is left open.
env = gym.make("takeoff-aviary-v0")
monitor_env = gnwrapper.Monitor(env, size=(400, 300), directory='.',
                                force=True,
                                video_callable=lambda ep: True)  # record every episode

n_episodes = 10
episode_max_steps = 300
for episode_idx in range(n_episodes):
    monitor_env.reset()
    total_rew = 0.
    for _ in range(episode_max_steps):
        _, rew, done, _ = monitor_env.step(monitor_env.action_space.sample())
        total_rew += rew
        if done:
            break
    print("iter={0: 3d} total reward: {1: 4.4f}".format(
        episode_idx, total_rew))
monitor_env.display()