Example #1
def test_replay_buffer_normalization(replay_buffer_cls):
    env = {ReplayBuffer: DummyEnv, DictReplayBuffer: DummyDictEnv}[replay_buffer_cls]
    env = make_vec_env(env)
    env = VecNormalize(env)

    buffer = replay_buffer_cls(100, env.observation_space, env.action_space)

    # Interact and store (unnormalized) transitions
    env.reset()
    obs = env.get_original_obs()
    for _ in range(100):
        action = env.action_space.sample()
        _, _, done, info = env.step(action)
        next_obs = env.get_original_obs()
        reward = env.get_original_reward()
        buffer.add(obs, next_obs, action, reward, done, info)
        obs = next_obs

    sample = buffer.sample(50, env)
    # Test observation normalization
    for observations in [sample.observations, sample.next_observations]:
        if isinstance(sample, DictReplayBufferSamples):
            for key in observations.keys():
                assert th.allclose(observations[key].mean(0), th.zeros(1), atol=1)
        elif isinstance(sample, ReplayBufferSamples):
            assert th.allclose(observations.mean(0), th.zeros(1), atol=1)
    # Test reward normalization
    assert np.allclose(sample.rewards.mean(0), np.zeros(1), atol=1)
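
This test assumes pytest parametrization and two small helper environments that are not part of the snippet. A possible sketch of the surrounding boilerplate is shown below; the imports follow stable-baselines3's public API, while the DummyEnv / DummyDictEnv definitions are illustrative assumptions rather than the project's own fixtures.

import gym
import numpy as np
import pytest
import torch as th

from stable_baselines3.common.buffers import DictReplayBuffer, ReplayBuffer
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.type_aliases import DictReplayBufferSamples, ReplayBufferSamples
from stable_baselines3.common.vec_env import VecNormalize


class DummyEnv(gym.Env):
    # Illustrative stand-in: a toy env with a Box observation space
    observation_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)
    action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(1,), dtype=np.float32)

    def reset(self):
        return self.observation_space.sample()

    def step(self, action):
        return self.observation_space.sample(), float(np.random.rand()), False, {}


class DummyDictEnv(DummyEnv):
    # Illustrative stand-in: same dynamics, but with a Dict observation space
    observation_space = gym.spaces.Dict(
        {"obs": gym.spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)}
    )


# The buffer class is injected through pytest parametrization
@pytest.mark.parametrize("replay_buffer_cls", [ReplayBuffer, DictReplayBuffer])
def test_replay_buffer_normalization(replay_buffer_cls):
    ...  # body as in Example #1 above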
Example #2
def test_eval_friendly_error():
    # Tests that EvalCallback keeps VecNormalize statistics in sync
    # and raises a friendly error when the eval env is not wrapped correctly
    train_env = VecNormalize(DummyVecEnv([lambda: gym.make("CartPole-v1")]))
    eval_env = DummyVecEnv([lambda: gym.make("CartPole-v1")])
    eval_env = VecNormalize(eval_env, training=False, norm_reward=False)
    _ = train_env.reset()
    original_obs = train_env.get_original_obs()
    model = A2C("MlpPolicy", train_env, n_steps=50, seed=0)

    eval_callback = EvalCallback(
        eval_env,
        eval_freq=100,
        warn=False,
    )
    model.learn(100, callback=eval_callback)

    # Check that normalization statistics are synchronized between train and eval envs
    assert np.allclose(train_env.normalize_obs(original_obs), eval_env.normalize_obs(original_obs))

    wrong_eval_env = gym.make("CartPole-v1")
    eval_callback = EvalCallback(
        wrong_eval_env,
        eval_freq=100,
        warn=False,
    )

    with pytest.warns(Warning):
        with pytest.raises(AssertionError):
            model.learn(100, callback=eval_callback)
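
Example #2 works because EvalCallback copies the training VecNormalize statistics into the evaluation wrapper before each evaluation. The same synchronization can be performed manually with sync_envs_normalization from stable_baselines3.common.vec_env; the standalone sketch below is an illustrative usage under that assumption, not part of the original test.

import gym

from stable_baselines3 import A2C
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize, sync_envs_normalization

train_env = VecNormalize(DummyVecEnv([lambda: gym.make("CartPole-v1")]))
eval_env = VecNormalize(
    DummyVecEnv([lambda: gym.make("CartPole-v1")]),
    training=False,     # do not update running statistics during evaluation
    norm_reward=False,  # report unnormalized rewards when evaluating
)

model = A2C("MlpPolicy", train_env, n_steps=50, seed=0)
model.learn(500)

# Copy running obs/reward statistics from the training wrapper to the eval wrapper,
# so evaluation normalizes observations the same way training does
sync_envs_normalization(train_env, eval_env)

obs = eval_env.reset()
action, _ = model.predict(obs, deterministic=True)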