Example #1
File: common.py  Project: ymetz/imitation
def _reward_fn_normalize_inputs(
    obs: np.ndarray,
    acts: np.ndarray,
    next_obs: np.ndarray,
    dones: np.ndarray,
    *,
    reward_fn: RewardFn,
    vec_normalize: vec_env.VecNormalize,
    norm_reward: bool = True,
) -> np.ndarray:
    """Combine with `functools.partial` to create an input-normalizing RewardFn.

    Args:
        obs: Batch of observations.
        acts: Batch of actions taken in response to `obs`.
        next_obs: Batch of observations after the actions were taken.
        dones: Batch of episode-completion flags.
        reward_fn: The reward function that normalized inputs are evaluated on.
        vec_normalize: Instance of VecNormalize used to normalize inputs and
            rewards.
        norm_reward: If True, then also normalize reward before returning.

    Returns:
        The possibly normalized reward.
    """
    norm_obs = vec_normalize.normalize_obs(obs)
    norm_next_obs = vec_normalize.normalize_obs(next_obs)
    rew = reward_fn(norm_obs, acts, norm_next_obs, dones)
    if norm_reward:
        rew = vec_normalize.normalize_reward(rew)
    return rew
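
As the docstring suggests, this helper is meant to be partially applied so that the result has the plain RewardFn signature. A minimal sketch of that usage, assuming an existing reward function `base_reward_fn` and a fitted VecNormalize wrapper `venv_norm` (both names are illustrative, not part of the project):

import functools

# `base_reward_fn` and `venv_norm` are hypothetical placeholders for an
# existing RewardFn and a fitted stable_baselines3 VecNormalize wrapper.
normalized_reward_fn = functools.partial(
    _reward_fn_normalize_inputs,
    reward_fn=base_reward_fn,
    vec_normalize=venv_norm,
    norm_reward=True,
)

# The partial application behaves like an ordinary RewardFn:
# rewards = normalized_reward_fn(obs, acts, next_obs, dones)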
Example #2
def test_eval_friendly_error():
    # tests that the eval callback keeps VecNormalize statistics in sync with
    # the training env and fails with a friendly error when the eval env is
    # not wrapped consistently
    train_env = VecNormalize(DummyVecEnv([lambda: gym.make("CartPole-v1")]))
    eval_env = DummyVecEnv([lambda: gym.make("CartPole-v1")])
    eval_env = VecNormalize(eval_env, training=False, norm_reward=False)
    _ = train_env.reset()
    original_obs = train_env.get_original_obs()
    model = A2C("MlpPolicy", train_env, n_steps=50, seed=0)

    eval_callback = EvalCallback(
        eval_env,
        eval_freq=100,
        warn=False,
    )
    model.learn(100, callback=eval_callback)

    # Check synchronization
    assert np.allclose(train_env.normalize_obs(original_obs), eval_env.normalize_obs(original_obs))

    wrong_eval_env = gym.make("CartPole-v1")
    eval_callback = EvalCallback(
        wrong_eval_env,
        eval_freq=100,
        warn=False,
    )

    with pytest.warns(Warning):
        with pytest.raises(AssertionError):
            model.learn(100, callback=eval_callback)
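
The test above relies on EvalCallback synchronizing VecNormalize statistics between the training and evaluation environments before each evaluation, and on it rejecting an eval env that is not wrapped consistently with the training env. A minimal sketch of doing that synchronization by hand with stable-baselines3's sync_envs_normalization helper, assuming the same CartPole setup (variable names are illustrative; newer SB3 versions use `gymnasium` instead of `gym`):

import gym
import numpy as np
from stable_baselines3.common.vec_env import (
    DummyVecEnv,
    VecNormalize,
    sync_envs_normalization,
)

# Training env with normalization; eval env wrapped the same way but frozen
# (no statistics updates, raw rewards).
train_env = VecNormalize(DummyVecEnv([lambda: gym.make("CartPole-v1")]))
eval_env = VecNormalize(
    DummyVecEnv([lambda: gym.make("CartPole-v1")]),
    training=False,
    norm_reward=False,
)

# Copy the running observation/return statistics from train_env to eval_env;
# this is roughly what EvalCallback does before each evaluation rollout.
sync_envs_normalization(train_env, eval_env)

obs = train_env.reset()
original_obs = train_env.get_original_obs()
assert np.allclose(
    train_env.normalize_obs(original_obs),
    eval_env.normalize_obs(original_obs),
)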