def __init__(self, env, horizon):
    """
    Wrap ``env`` and record the horizon.

    Parameters
    ----------
    env
        Environment to wrap; handed to ``Wrapper.__init__``.
    horizon: int
        Horizon value; must be at least 1 (checked with ``assert``).
    """
    Wrapper.__init__(self, env)
    # Store the horizon and start the step counter at zero.
    self.horizon, self.current_step = horizon, 0
    assert self.horizon >= 1
def test_double_wrapper_copy_reseeding(ModelClass):
    """A reseeded deep copy of a doubly wrapped model must diverge from the original."""
    seeding.set_global_seed(123)
    original = Wrapper(Wrapper(ModelClass()))
    reseeded_copy = deepcopy(original)
    reseeded_copy.reseed()

    # Trajectories are only collected for online models.
    if not deepcopy(original).is_online():
        return
    first = get_env_trajectory(original, 500)
    second = get_env_trajectory(reseeded_copy, 500)
    assert not compare_trajectories(first, second)
def test_wrapper():
    """Wrapping a GridWorld gives an online, generative Model whose API is callable."""
    wrapped = Wrapper(GridWorld())
    assert isinstance(wrapped, Model)
    assert wrapped.is_online()
    assert wrapped.is_generative()

    # Smoke-test the forwarded methods.
    wrapped.reset()
    wrapped.step(wrapped.action_space.sample())
    state = wrapped.observation_space.sample()
    action = wrapped.action_space.sample()
    wrapped.sample(state, action)
def test_gym_wrapper():
    """Wrapping a gym env gives an online, non-generative Model whose gym API works.

    Does nothing when gym is unavailable, mirroring the ``_GYM_INSTALLED``
    guard used by the other gym tests in this module.
    """
    # NOTE(review): presumably ``gym`` is only importable when _GYM_INSTALLED
    # is true; without this guard the test would crash instead of being skipped.
    if not _GYM_INSTALLED:
        return
    gym_env = gym.make('Acrobot-v1')
    wrapped = Wrapper(gym_env)
    assert isinstance(wrapped, Model)
    assert wrapped.is_online()
    assert not wrapped.is_generative()
    wrapped.reseed()
    # Calling some gym functions through the wrapper.
    wrapped.close()
    wrapped.seed()
def test_gym_copy_reseeding_2():
    """Reseeding a deep copy of a nested, reward-rescaled gym env must change its trajectories."""
    seeding.set_global_seed(123)
    if not _GYM_INSTALLED:
        return

    # Three layers of wrapping around the raw gym environment.
    inner = Wrapper(Wrapper(gym.make('Acrobot-v1')))
    env = RescaleRewardWrapper(inner, (0, 1))
    c_env = deepcopy(env)
    c_env.reseed()

    # Trajectories are only collected for online models.
    if not deepcopy(env).is_online():
        return
    traj_env = get_env_trajectory(env, 500)
    traj_copy = get_env_trajectory(c_env, 500)
    assert not compare_trajectories(traj_env, traj_copy)
def test_wrapper_seeding(ModelClass):
    """Equal global seeds give equal trajectories; different seeds give different ones."""
    envs = []
    # Build one env per global seed, in this exact order: 123, 456, 123.
    for global_seed in (123, 456, 123):
        seeding.set_global_seed(global_seed)
        envs.append(Wrapper(ModelClass()))

    if deepcopy(envs[0]).is_online():
        traj_a, traj_b, traj_c = [get_env_trajectory(e, 500) for e in envs]
        assert not compare_trajectories(traj_a, traj_b)
        assert compare_trajectories(traj_a, traj_c)
def test_gym_copy_reseeding():
    """A reseeded deep copy of a wrapped gym env must produce different trajectories."""
    seeding.set_global_seed(123)
    if not _GYM_INSTALLED:
        return

    env = Wrapper(gym.make('Acrobot-v1'))
    c_env = deepcopy(env)
    c_env.reseed()

    # Trajectories are only collected for online models.
    if not deepcopy(env).is_online():
        return
    trajectories = [get_env_trajectory(e, 500) for e in (env, c_env)]
    assert not compare_trajectories(*trajectories)
def gym_make(env_name, **kwargs):
    """
    Same as gym.make, but wraps the environment to ensure unified seeding
    with rlberry.

    Parameters
    ----------
    env_name : str
        Identifier of the gym environment, passed to ``gym.make``.
    **kwargs
        Extra keyword arguments forwarded unchanged to ``gym.make``
        (for example, options for the environment constructor).

    Returns
    -------
    Wrapper
        The environment created by ``gym.make``, wrapped in ``Wrapper``.
    """
    return Wrapper(gym.make(env_name, **kwargs))
def __init__(self, env, reward_range):
    """
    Wrap ``env`` and record a reward range.

    Parameters
    ----------
    env
        Environment to wrap; handed to ``Wrapper.__init__``.
    reward_range
        Indexable pair ``(low, high)``. Requires ``low < high`` and both
        bounds finite (checked with ``assert``).
    """
    Wrapper.__init__(self, env)
    self.reward_range = reward_range
    low, high = reward_range[0], reward_range[1]
    assert low < high
    # Both bounds must be finite.
    assert low > -np.inf and high < np.inf