Ejemplo n.º 1
0
    def test_last_video(self):
        """
        Verify that ``display`` flushes the video of the episode still running.

        Every episode is recorded (``video_callable`` always returns True), so
        after ``display`` the number of stored videos must equal the number of
        episodes started, including the unfinished last one, and every stored
        video file must exist on disk. The wrapper must also keep working for
        further episodes afterwards.

        Ref: https://gitlab.com/ymd_h/gym-notebook-wrapper/-/issues/2
        """
        monitor = gnwrapper.Monitor(gym.make('CartPole-v1'),
                                    directory="./test_last_videos/",
                                    video_callable=lambda ep: True)
        monitor.reset()

        # The initial reset starts episode #1; each reset after a terminal
        # step starts one more recorded episode.
        episodes = 1
        for _step in range(100):
            action = monitor.action_space.sample()
            _obs, _reward, done, _info = monitor.step(action)
            if not done:
                continue
            monitor.reset()
            episodes += 1

        monitor.display()
        self.assertEqual(len(monitor.videos), episodes)
        for video in monitor.videos:
            path = video[0]
            with self.subTest(file=path):
                self.assertTrue(os.path.exists(path))

        # The wrapper must remain usable after display() has flushed videos.
        monitor.reset()
        for _step in range(100):
            action = monitor.action_space.sample()
            _obs, _reward, done, _info = monitor.step(action)
            if done:
                monitor.reset()

        monitor.display()
Ejemplo n.º 2
0
    def test_display(self):
        """Run a 100-step random CartPole rollout, then call ``display``."""
        wrapped = gnwrapper.Monitor(gym.make('CartPole-v1'), directory="./")
        wrapped.reset()

        for _step in range(100):
            action = wrapped.action_space.sample()
            _obs, _reward, finished, _info = wrapped.step(action)
            # Start a fresh episode whenever the current one terminates.
            if finished:
                wrapped.reset()

        wrapped.display()
Ejemplo n.º 3
0
    def test_reset_videos(self):
        """
        Check that ``display(reset=True)`` empties the stored video list.

        No ``video_callable`` is passed, so the wrapper's default schedule
        decides which episodes get recorded. After each finished episode the
        wrapper is expected to hold at least one video already.
        """
        monitor = gnwrapper.Monitor(gym.make('CartPole-v1'),
                                    directory="./test_reset_videos/")

        monitor.reset()
        for _step in range(100):
            action = monitor.action_space.sample()
            _obs, _reward, finished, _info = monitor.step(action)
            if finished:
                monitor.reset()
                self.assertNotEqual(len(monitor.videos), 0)

        monitor.display(reset=True)
        self.assertEqual(len(monitor.videos), 0)
Ejemplo n.º 4
0
    def test_default_directory(self):
        """
        Check the video paths produced when no ``directory`` is given.

        Each stored video path must contain eight digits, a dash and six
        digits (presumably a date-time stamp in the default directory name).
        """
        monitor = gnwrapper.Monitor(gym.make('CartPole-v1'))
        monitor.reset()

        for _step in range(100):
            action = monitor.action_space.sample()
            _obs, _reward, done, _info = monitor.step(action)
            if done:
                monitor.reset()

        stamp_pattern = r"[0-9]{8}-[0-9]{6}"
        for video in monitor.videos:
            path = video[0]
            with self.subTest(file=path):
                self.assertIsNotNone(re.search(stamp_pattern, path))
        monitor.display()
Ejemplo n.º 5
0
    def test_display_after_close(self):
        """
        Check that ``display`` can still be called once the monitor is closed.

        Every episode is recorded (``video_callable`` always returns True).
        """
        monitor = gnwrapper.Monitor(gym.make('CartPole-v1'),
                                    directory="./test_display_after_close/",
                                    video_callable=lambda ep: True)
        monitor.reset()

        for _step in range(100):
            action = monitor.action_space.sample()
            _obs, _reward, done, _info = monitor.step(action)
            if done:
                monitor.reset()

        monitor.close()
        monitor.display()
Ejemplo n.º 6
0
    def test_KeyboardInterrupt(self):
        """
        A KeyboardInterrupt inside the wrapped call chain must be recoverable.

        The originally reported symptom was that the notebook kernel died
        after a KeyboardInterrupt. Each target below is patched to raise
        KeyboardInterrupt, and the interrupt must propagate out of the
        triggering ``step``/``reset`` call. Afterwards ``reset``, ``step``,
        ``display`` and ``render`` must all still work.

        Ref: https://gitlab.com/ymd_h/gym-notebook-wrapper/-/issues/4
        """
        cartpole = "gym.envs.classic_control.cartpole.CartPoleEnv"
        recorder = "gym.wrappers.monitoring.video_recorder.VideoRecorder"

        monitor = gnwrapper.Monitor(gym.make('CartPole-v1'),
                                    directory="./test_keyboard_interrupt/",
                                    video_callable=lambda ep: True)

        def trigger_step():
            monitor.step(monitor.action_space.sample())

        def trigger_reset():
            monitor.reset()

        # (patch target, call that should reach it) -- step-path targets
        # first, then reset-path targets, as in the original two loops.
        cases = [
            (f"{cartpole}.step", trigger_step),
            (f"{recorder}.capture_frame", trigger_step),
            (f"{cartpole}.reset", trigger_reset),
            ("os.waitpid", trigger_reset),
        ]

        for target, trigger in cases:
            monitor.reset()
            with self.subTest(function=target):
                with patch(target,
                           MagicMock(side_effect=KeyboardInterrupt)):
                    with self.assertRaises(KeyboardInterrupt):
                        trigger()

                # Recovery: the wrapper must still be fully usable.
                monitor.reset()
                monitor.step(monitor.action_space.sample())
                monitor.display()
                monitor.render(mode='rgb_array')
Ejemplo n.º 7
0
import torch
import numpy as np
import gym
import gnwrapper
from torch import nn

from stable_baselines3 import A2C
from stable_baselines3.a2c import MlpPolicy
from stable_baselines3.common.env_checker import check_env

from gym_pybullet_drones.utils.Logger import Logger
from gym_pybullet_drones.envs.single_agent_rl.TakeoffAviary import TakeoffAviary
from gym_pybullet_drones.utils.utils import sync, str2bool

# Record random-action rollouts of the TakeoffAviary environment as videos
# and show them with gnwrapper's display().
env = gym.make("takeoff-aviary-v0")
# Wrap the environment created above. The original called gym.make() a second
# time here, which built another independent environment instance, while
# `env` itself was never used or closed.
monitor_env = gnwrapper.Monitor(env, size=(400, 300), directory='.',
                                force=True, video_callable=lambda ep: True)

# Upper bound on steps per episode; an episode also stops early on `done`.
episode_max_steps = 300
# Number of random-action episodes to run (every one is recorded, since
# video_callable always returns True).
n_episodes = 10

try:
    for episode_idx in range(n_episodes):
        monitor_env.reset()
        total_rew = 0.
        for _ in range(episode_max_steps):
            _, rew, done, _ = monitor_env.step(
                monitor_env.action_space.sample())
            total_rew += rew
            if done:
                break
        print("iter={0: 3d} total reward: {1: 4.4f}".format(
            episode_idx, total_rew))
    monitor_env.display()
finally:
    # Release the recorder and the wrapped environment even if the rollout
    # or display fails part-way.
    monitor_env.close()