예제 #1
0
 def __init__(self, env, horizon):
     """
     Wrap ``env`` and record the step horizon.

     Parameters
     ----------
     env:
         Object forwarded unchanged to ``Wrapper.__init__``.
     horizon: int
         Must be at least 1 (checked with an ``assert`` after it is stored).
         NOTE(review): presumably the number of steps after which the
         wrapper ends an episode; the logic that consumes it is outside
         this method -- confirm.
     """
     Wrapper.__init__(self, env)
     self.horizon = horizon
     assert self.horizon >= 1
     # current_step starts at 0; nothing else in this method modifies it.
     self.current_step = 0
예제 #2
0
def test_double_wrapper_copy_reseeding(ModelClass):
    """
    Check that a reseeded deep copy of a doubly wrapped model does not
    reproduce the trajectory of the original.

    Parameters
    ----------
    ModelClass:
        Callable invoked with no arguments to build the model that is
        wrapped twice (presumably provided by test parametrization).
    """
    seeding.set_global_seed(123)
    original = Wrapper(Wrapper(ModelClass()))

    reseeded_copy = deepcopy(original)
    reseeded_copy.reseed()

    # is_online() is queried on a throwaway copy, presumably so that
    # `original` is untouched until its trajectory is collected.
    if not deepcopy(original).is_online():
        return

    original_traj = get_env_trajectory(original, 500)
    copy_traj = get_env_trajectory(reseeded_copy, 500)
    assert not compare_trajectories(original_traj, copy_traj)
예제 #3
0
def test_wrapper():
    """
    Smoke-test a Wrapper built around a GridWorld: check its type and
    capability flags, then make a few calls through the wrapper.
    """
    base_env = GridWorld()
    wrapped = Wrapper(base_env)
    assert isinstance(wrapped, Model)
    assert wrapped.is_online()
    assert wrapped.is_generative()

    # Reset, then take a single step with a sampled action.
    wrapped.reset()
    step_action = wrapped.action_space.sample()
    wrapped.step(step_action)

    # Call sample() with a sampled observation and a sampled action;
    # the observation is drawn first, then the action.
    observation = wrapped.observation_space.sample()
    sample_action = wrapped.action_space.sample()
    wrapped.sample(observation, sample_action)
예제 #4
0
def test_gym_wrapper():
    """
    Smoke-test a Wrapper built around the gym 'Acrobot-v1' environment:
    it must be a Model, report online, and not report generative.
    """
    acrobot = gym.make('Acrobot-v1')
    wrapped = Wrapper(acrobot)

    assert isinstance(wrapped, Model)
    assert wrapped.is_online()
    assert not wrapped.is_generative()

    wrapped.reseed()

    # gym-style methods called on the wrapper: close first, then seed.
    wrapped.close()
    wrapped.seed()
예제 #5
0
def test_gym_copy_reseeding_2():
    """
    Check that a reseeded deep copy of a nested-wrapped gym environment
    does not reproduce the trajectory of the original.

    Does nothing beyond setting the global seed when gym is not installed.
    """
    seeding.set_global_seed(123)
    if not _GYM_INSTALLED:
        return

    acrobot = gym.make('Acrobot-v1')
    # Nested wrapping: two plain Wrappers, then the reward rescaler.
    doubly_wrapped = Wrapper(Wrapper(acrobot))
    original = RescaleRewardWrapper(doubly_wrapped, (0, 1))

    reseeded_copy = deepcopy(original)
    reseeded_copy.reseed()

    # is_online() is queried on a throwaway copy rather than on `original`.
    if not deepcopy(original).is_online():
        return

    original_traj = get_env_trajectory(original, 500)
    copy_traj = get_env_trajectory(reseeded_copy, 500)
    assert not compare_trajectories(original_traj, copy_traj)
예제 #6
0
def test_wrapper_seeding(ModelClass):
    """
    Check that wrapped models built under the same global seed produce
    matching trajectories, and that a different seed does not.

    Parameters
    ----------
    ModelClass:
        Callable invoked with no arguments to build each wrapped model
        (presumably provided by test parametrization).
    """
    # One wrapped environment per seed, created right after the global
    # seed is set: 123, then 456, then 123 again.
    wrapped_envs = []
    for seed in (123, 456, 123):
        seeding.set_global_seed(seed)
        wrapped_envs.append(Wrapper(ModelClass()))
    reference, different_seed, same_seed = wrapped_envs

    # is_online() is queried on a throwaway copy rather than on `reference`.
    if not deepcopy(reference).is_online():
        return

    reference_traj = get_env_trajectory(reference, 500)
    different_traj = get_env_trajectory(different_seed, 500)
    same_traj = get_env_trajectory(same_seed, 500)

    assert not compare_trajectories(reference_traj, different_traj)
    assert compare_trajectories(reference_traj, same_traj)
예제 #7
0
def test_gym_copy_reseeding():
    """
    Check that a reseeded deep copy of a wrapped gym environment does not
    reproduce the trajectory of the original.

    Does nothing beyond setting the global seed when gym is not installed.
    """
    seeding.set_global_seed(123)
    if not _GYM_INSTALLED:
        return

    acrobot = gym.make('Acrobot-v1')
    original = Wrapper(acrobot)

    reseeded_copy = deepcopy(original)
    reseeded_copy.reseed()

    # is_online() is queried on a throwaway copy rather than on `original`.
    if not deepcopy(original).is_online():
        return

    original_traj = get_env_trajectory(original, 500)
    copy_traj = get_env_trajectory(reseeded_copy, 500)
    assert not compare_trajectories(original_traj, copy_traj)
예제 #8
0
def gym_make(env_name):
    """
    Build a gym environment with ``gym.make`` and return it wrapped in
    ``Wrapper``, so that its seeding is unified with rlberry.

    Parameters
    ----------
    env_name:
        Passed unchanged to ``gym.make``.

    Returns
    -------
    Wrapper
        The wrapper around the newly created gym environment.
    """
    raw_env = gym.make(env_name)
    return Wrapper(raw_env)
예제 #9
0
 def __init__(self, env, reward_range):
     """
     Wrap ``env`` and store the reward range.

     Parameters
     ----------
     env:
         Object forwarded unchanged to ``Wrapper.__init__``.
     reward_range:
         Indexable with 0 and 1. Element 0 must be strictly smaller than
         element 1, and both bounds must be finite (checked with asserts
         after the range is stored).
     """
     Wrapper.__init__(self, env)
     self.reward_range = reward_range
     # The bounds must be strictly ordered ...
     assert reward_range[0] < reward_range[1]
     # ... and neither may be infinite.
     assert reward_range[0] > -np.inf
     assert reward_range[1] < np.inf