def test_monitor(tmp_path):
    """
    Test the monitor wrapper.

    Runs CartPole-v1 under ``Monitor`` for a fixed number of random-action
    steps while tracking every episode that ends (``done``) by hand, then
    checks that the wrapper's in-memory statistics and the monitor CSV file
    agree with that hand-kept record.
    """
    env = gym.make("CartPole-v1")
    env.seed(0)
    monitor_file = os.path.join(
        str(tmp_path),
        "stable_baselines-test-{}.monitor.csv".format(uuid.uuid4()))
    monitor_env = Monitor(env, monitor_file)
    monitor_env.reset()
    total_steps = 1000
    # Reward and length of each episode that ended inside the loop; a final
    # episode still running when the loop stops is not appended.
    ep_rewards = []
    ep_lengths = []
    ep_len, ep_reward = 0, 0
    for _ in range(total_steps):
        _, reward, done, _ = monitor_env.step(
            monitor_env.action_space.sample())
        ep_len += 1
        ep_reward += reward
        if done:
            ep_rewards.append(ep_reward)
            ep_lengths.append(ep_len)
            monitor_env.reset()
            ep_len, ep_reward = 0, 0

    monitor_env.close()
    # Without at least one finished episode the comparisons below would
    # trivially pass on empty data.
    assert ep_lengths, "expected at least one completed episode"
    assert monitor_env.get_total_steps() == total_steps
    # Compare per-episode lengths, not just their sum.
    assert list(monitor_env.get_episode_lengths()) == ep_lengths
    assert sum(monitor_env.get_episode_rewards()) == sum(ep_rewards)
    # One timestamp is expected per completed episode.
    assert len(monitor_env.get_episode_times()) == len(ep_lengths)

    with open(monitor_file, "rt") as file_handler:
        # First line: "#" followed by a JSON metadata header.
        first_line = file_handler.readline()
        assert first_line.startswith("#")
        metadata = json.loads(first_line[1:])
        assert metadata["env_id"] == "CartPole-v1"
        assert set(metadata.keys()) == {"env_id", "t_start"}, \
            "Incorrect keys in monitor metadata"

        # Remainder of the file: CSV table of logged episodes. Expect one
        # row per completed episode, with matching lengths.
        episode_log = pandas.read_csv(file_handler, index_col=None)
        assert set(episode_log.keys()) == {"l", "t", "r"}, \
            "Incorrect keys in monitor logline"
        assert len(episode_log) == len(ep_lengths)
        assert episode_log["l"].sum() == sum(ep_lengths)
    os.remove(monitor_file)
예제 #2
0
               out_ref=inflow.expectation(),
               **env_cfg)
# Save text descriptions of the environment and inflow next to the run's
# other outputs.
with open(f'{timestamp}/env.txt', 'w') as f:
    print(str(env), file=f)
with open(f'{timestamp}/inflow.txt', 'w') as f:
    print(str(inflow), file=f)
env = Monitor(env)

model = PPO('MlpPolicy',
            env,
            verbose=1,
            tensorboard_log=f'{timestamp}/',
            gamma=.5)
model.learn(total_timesteps=5000000)
model.save(f'{timestamp}/model')

visualize_nets(env, model, timestamp)

# Roll out the trained policy deterministically for 1000 steps, capturing one
# video frame per step and resetting whenever an episode ends.
rec = VideoRecorder(env, f'{timestamp}/vid.mp4')
try:
    obs = env.reset()
    for _ in range(1000):
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, done, info = env.step(action)
        env.unwrapped.render()
        rec.capture_frame()
        if done:
            obs = env.reset()
finally:
    # Runs even if the rollout raises. The recorder is closed before the env
    # it captures from, so it can finalize its output first.
    rec.close()
    env.close()