def test_replay_buffer_normalization(replay_buffer_cls):
    """Check that replay-buffer samples come back normalized when an env is passed.

    Transitions are stored using the values returned by ``get_original_obs()`` /
    ``get_original_reward()`` of a ``VecNormalize``-wrapped env. Sampling with
    that env passed to ``buffer.sample`` is then expected to yield observations,
    next observations and rewards whose batch mean is within ``atol=1`` of zero.

    :param replay_buffer_cls: buffer class under test; must be ``ReplayBuffer``
        or ``DictReplayBuffer`` (any other value raises ``KeyError`` in the
        env lookup below).
    """
    # Pick the dummy env that matches the buffer flavour (flat vs. dict observations).
    env_cls = {ReplayBuffer: DummyEnv, DictReplayBuffer: DummyDictEnv}[replay_buffer_cls]
    env = make_vec_env(env_cls)
    env = VecNormalize(env)

    buffer = replay_buffer_cls(100, env.observation_space, env.action_space)

    # Interact and store transitions
    env.reset()
    obs = env.get_original_obs()
    for _ in range(100):
        action = env.action_space.sample()
        _, _, done, info = env.step(action)
        next_obs = env.get_original_obs()
        reward = env.get_original_reward()
        buffer.add(obs, next_obs, action, reward, done, info)
        obs = next_obs

    # Passing the env asks the buffer to apply the env's normalization on sampling.
    sample = buffer.sample(50, env)

    # Test observation normalization
    fields = (("observations", sample.observations), ("next_observations", sample.next_observations))
    for field_name, observations in fields:
        if isinstance(sample, DictReplayBufferSamples):
            for key, value in observations.items():
                assert th.allclose(
                    value.mean(0), th.zeros(1), atol=1
                ), f"{field_name}[{key!r}] batch mean is not within 1 of zero"
        elif isinstance(sample, ReplayBufferSamples):
            assert th.allclose(
                observations.mean(0), th.zeros(1), atol=1
            ), f"{field_name} batch mean is not within 1 of zero"

    # Test reward normalization
    assert np.allclose(sample.rewards.mean(0), np.zeros(1), atol=1), "reward batch mean is not within 1 of zero"
def test_eval_friendly_error():
    """Exercise ``EvalCallback`` with a vectorized eval env, then with a bare env.

    Part one trains ``A2C`` with an ``EvalCallback`` whose eval env is wrapped
    in ``VecNormalize`` and asserts that, afterwards, the train and eval envs
    normalize the same raw observation to the same values.

    Part two hands the callback an env straight from ``gym.make`` and expects
    ``learn`` to both emit a warning and raise ``AssertionError``.
    """

    def make_cartpole():
        return gym.make("CartPole-v1")

    # tests that eval callback does not crash when given a vector
    train_env = VecNormalize(DummyVecEnv([make_cartpole]))
    eval_env = VecNormalize(
        DummyVecEnv([make_cartpole]),
        training=False,
        norm_reward=False,
    )
    train_env.reset()
    raw_obs = train_env.get_original_obs()
    model = A2C("MlpPolicy", train_env, n_steps=50, seed=0)

    model.learn(100, callback=EvalCallback(eval_env, eval_freq=100, warn=False))

    # Check synchronization
    train_normalized = train_env.normalize_obs(raw_obs)
    eval_normalized = eval_env.normalize_obs(raw_obs)
    assert np.allclose(train_normalized, eval_normalized)

    # Same callback setup, but with an env that was not wrapped in a VecEnv.
    unwrapped_env = gym.make("CartPole-v1")
    bad_callback = EvalCallback(unwrapped_env, eval_freq=100, warn=False)

    with pytest.warns(Warning), pytest.raises(AssertionError):
        model.learn(100, callback=bad_callback)