import numpy as np
from stable_baselines3.common import vec_env

# `RewardFn` is imitation's reward-function protocol; its module path varies
# across imitation versions (e.g. `imitation.rewards.reward_function`).
from imitation.rewards.reward_function import RewardFn


def _reward_fn_normalize_inputs(
    obs: np.ndarray,
    acts: np.ndarray,
    next_obs: np.ndarray,
    dones: np.ndarray,
    *,
    reward_fn: RewardFn,
    vec_normalize: vec_env.VecNormalize,
    norm_reward: bool = True,
) -> np.ndarray:
    """Combine with `functools.partial` to create an input-normalizing RewardFn.

    Args:
        obs: Batch of observations.
        acts: Batch of actions.
        next_obs: Batch of successor observations.
        dones: Batch of episode-termination flags.
        reward_fn: The reward function that normalized inputs are evaluated on.
        vec_normalize: Instance of VecNormalize used to normalize inputs and
            rewards.
        norm_reward: If True, then also normalize reward before returning.

    Returns:
        The possibly normalized reward.
    """
    norm_obs = vec_normalize.normalize_obs(obs)
    norm_next_obs = vec_normalize.normalize_obs(next_obs)
    rew = reward_fn(norm_obs, acts, norm_next_obs, dones)
    if norm_reward:
        rew = vec_normalize.normalize_reward(rew)
    return rew
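
# Usage sketch (illustrative, not from the original source): binding the
# keyword-only arguments with `functools.partial`, as the docstring above
# suggests, yields a plain RewardFn that normalizes its inputs. The helper
# name `make_normalizing_reward_fn` is hypothetical.
import functools


def make_normalizing_reward_fn(
    base_reward_fn: RewardFn,
    vec_normalize: vec_env.VecNormalize,
    norm_reward: bool = True,
) -> RewardFn:
    # Bind the keyword-only arguments; the result matches the
    # (obs, acts, next_obs, dones) -> np.ndarray RewardFn signature.
    return functools.partial(
        _reward_fn_normalize_inputs,
        reward_fn=base_reward_fn,
        vec_normalize=vec_normalize,
        norm_reward=norm_reward,
    )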
import gym
import numpy as np
import pytest

from stable_baselines3 import A2C
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize


def test_eval_friendly_error():
    # Tests that the eval callback works with a VecNormalize-wrapped eval env
    # (keeping its statistics in sync with the training env) and raises a
    # friendly error when the eval env is not wrapped consistently.
    train_env = VecNormalize(DummyVecEnv([lambda: gym.make("CartPole-v1")]))
    eval_env = DummyVecEnv([lambda: gym.make("CartPole-v1")])
    eval_env = VecNormalize(eval_env, training=False, norm_reward=False)
    _ = train_env.reset()
    original_obs = train_env.get_original_obs()
    model = A2C("MlpPolicy", train_env, n_steps=50, seed=0)

    eval_callback = EvalCallback(
        eval_env,
        eval_freq=100,
        warn=False,
    )
    model.learn(100, callback=eval_callback)

    # Check synchronization: both wrappers must normalize the same raw
    # observation identically.
    assert np.allclose(
        train_env.normalize_obs(original_obs),
        eval_env.normalize_obs(original_obs),
    )

    # A non-vectorized eval env is inconsistent with the VecNormalize
    # training env and should trigger a warning and an AssertionError.
    wrong_eval_env = gym.make("CartPole-v1")
    eval_callback = EvalCallback(
        wrong_eval_env,
        eval_freq=100,
        warn=False,
    )
    with pytest.warns(Warning):
        with pytest.raises(AssertionError):
            model.learn(100, callback=eval_callback)
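
# Background sketch (not part of the original test): the synchronization the
# assertion above relies on is performed inside EvalCallback by SB3's
# `sync_envs_normalization` helper, which copies the running obs/return
# statistics from the training wrappers to the eval wrappers. A minimal
# hand-rolled version of the same check, assuming the usual SB3 setup:
def _example_manual_sync():
    from stable_baselines3.common.vec_env import sync_envs_normalization

    train_env = VecNormalize(DummyVecEnv([lambda: gym.make("CartPole-v1")]))
    eval_env = VecNormalize(
        DummyVecEnv([lambda: gym.make("CartPole-v1")]),
        training=False,
        norm_reward=False,
    )
    _ = train_env.reset()
    sync_envs_normalization(train_env, eval_env)
    # After syncing, both wrappers normalize a raw observation identically.
    raw = train_env.get_original_obs()
    assert np.allclose(train_env.normalize_obs(raw), eval_env.normalize_obs(raw))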