Esempio n. 1
0
    def __init__(
        self,
        env,
        n_bins_obs=10,
        memory_size=100,
        state_preprocess_fn=None,
        state_preprocess_kwargs=None,
    ):
        """
        Wrap `env` and set up a trajectory memory and two visit counters.

        Parameters
        ----------
        env
            Environment to wrap. Its action space must be ``spaces.Discrete``;
            when no `state_preprocess_fn` is given, its observation space
            must be ``spaces.Box``.
        n_bins_obs : int
            Forwarded as ``n_bins_obs`` to both ``DiscreteCounter`` instances.
        memory_size : int
            Forwarded to ``TrajectoryMemory``.
        state_preprocess_fn : callable or None
            Stored as ``self.state_preprocess_fn``; ``identity`` is used when
            it is None (or otherwise falsy).
        state_preprocess_kwargs : dict or None
            Stored as ``self.state_preprocess_kwargs``; an empty dict is used
            when it is None (or otherwise falsy).
        """
        Wrapper.__init__(self, env)

        # Without a custom preprocessing function the observations are used
        # as-is, so the observation space is required to be a Box.
        if state_preprocess_fn is None:
            assert isinstance(env.observation_space, spaces.Box)
        assert isinstance(env.action_space, spaces.Discrete)

        self.state_preprocess_fn = state_preprocess_fn or identity
        self.state_preprocess_kwargs = state_preprocess_kwargs or {}

        self.memory = TrajectoryMemory(memory_size)
        self.total_visit_counter = DiscreteCounter(
            self.env.observation_space,
            self.env.action_space,
            n_bins_obs=n_bins_obs)
        self.episode_visit_counter = DiscreteCounter(
            self.env.observation_space,
            self.env.action_space,
            n_bins_obs=n_bins_obs)
        self.current_state = None
        # The step counter used to be initialised only under the misspelled
        # name `curret_step`; initialise the correctly spelled attribute so
        # it exists as soon as the wrapper is constructed.
        self.current_step = 0
        # Misspelled name kept for backward compatibility with any code that
        # still reads it.
        self.curret_step = 0
Esempio n. 2
0
 def __init__(self, env):
     """
     Wrap an environment whose observation space is ``Discrete`` and expose
     a ``Box`` observation space with one entry per discrete observation.
     """
     Wrapper.__init__(self, env, wrap_spaces=True)

     wrapped_obs_space = self.env.observation_space
     assert isinstance(wrapped_obs_space, Discrete)

     # One component per discrete observation, each bounded in [0, 1].
     n_observations = wrapped_obs_space.n
     self.observation_space = Box(
         low=0.0,
         high=1.0,
         shape=(n_observations,),
         dtype=np.uint32,
     )
Esempio n. 3
0
 def __init__(self, env, horizon):
     """
     Wrap `env`, store the horizon and start the step counter at zero.

     Parameters
     ----------
     env
         Environment forwarded to ``Wrapper.__init__``.
     horizon: int
         Stored as ``self.horizon``; asserted to be at least 1.
     """
     Wrapper.__init__(self, env)

     # Store first, then validate the stored value.
     self.horizon = horizon
     assert self.horizon >= 1

     self.current_step = 0
def test_double_wrapper_copy_reseeding(ModelClass):
    """
    Build a doubly wrapped model, reseed a deep copy of it, and assert that
    `compare_trajectories` does not report the two 500-step trajectories as
    matching. Nothing is checked when the model is not online.
    """
    seeding.set_global_seed(123)
    original_env = Wrapper(Wrapper(ModelClass()))

    reseeded_env = deepcopy(original_env)
    reseeded_env.reseed()

    # As in the original test, is_online() is queried on a fresh deep copy.
    if not deepcopy(original_env).is_online():
        return

    original_traj = get_env_trajectory(original_env, 500)
    reseeded_traj = get_env_trajectory(reseeded_env, 500)
    assert not compare_trajectories(original_traj, reseeded_traj)
Esempio n. 5
0
def test_wrapper():
    """
    Wrap a GridWorld and check that the wrapper is a Model reporting itself
    as online and generative, then call reset/step/sample once each.
    """
    wrapped = Wrapper(GridWorld())

    assert isinstance(wrapped, Model)
    assert wrapped.is_online()
    assert wrapped.is_generative()

    # Online interface: reset, then a single step with a sampled action.
    wrapped.reset()
    sampled_action = wrapped.action_space.sample()
    wrapped.step(sampled_action)

    # Generative interface: sample from a sampled (state, action) pair.
    sampled_state = wrapped.observation_space.sample()
    sampled_action = wrapped.action_space.sample()
    wrapped.sample(sampled_state, sampled_action)
Esempio n. 6
0
def test_gym_wrapper():
    """
    Wrap the gym "Acrobot-v1" environment and check that the wrapper is a
    Model reporting itself as online but not generative, then call reseed(),
    close() and seed() on it.
    """
    wrapped = Wrapper(gym.make("Acrobot-v1"))

    assert isinstance(wrapped, Model)
    assert wrapped.is_online()
    assert not wrapped.is_generative()

    wrapped.reseed()

    # Methods called through the wrapper (gym functions, per the original
    # comment).
    wrapped.close()
    wrapped.seed()
Esempio n. 7
0
def test_wrapper_copy_reseeding(ModelClass):
    """
    Reseed a wrapped model with Seeder(123), reseed a deep copy of it without
    arguments, and assert that `compare_trajectories` does not report the two
    500-step trajectories as matching. Nothing is checked when the model is
    not online.
    """
    original_env = Wrapper(ModelClass())
    original_env.reseed(Seeder(123))

    reseeded_env = deepcopy(original_env)
    reseeded_env.reseed()

    # As in the original test, is_online() is queried on a fresh deep copy.
    if not deepcopy(original_env).is_online():
        return

    original_traj = get_env_trajectory(original_env, 500)
    reseeded_traj = get_env_trajectory(reseeded_env, 500)
    assert not compare_trajectories(original_traj, reseeded_traj)
def test_gym_copy_reseeding_2():
    """
    With gym installed, wrap "Acrobot-v1" in nested wrappers, reseed a deep
    copy, and assert that `compare_trajectories` does not report the two
    500-step trajectories as matching.
    """
    seeding.set_global_seed(123)
    if not _GYM_INSTALLED:
        return

    # Nested wrapping: two plain Wrappers inside a RescaleRewardWrapper.
    inner_env = Wrapper(Wrapper(gym.make('Acrobot-v1')))
    original_env = RescaleRewardWrapper(inner_env, (0, 1))

    reseeded_env = deepcopy(original_env)
    reseeded_env.reseed()

    # As in the original test, is_online() is queried on a fresh deep copy.
    if not deepcopy(original_env).is_online():
        return

    original_traj = get_env_trajectory(original_env, 500)
    reseeded_traj = get_env_trajectory(reseeded_env, 500)
    assert not compare_trajectories(original_traj, reseeded_traj)
Esempio n. 9
0
def test_gym_copy_reseeding():
    """
    With gym installed, wrap "Acrobot-v1", reseed it with Seeder(123), reseed
    a deep copy without arguments, and assert that `compare_trajectories`
    does not report the two 500-step trajectories as matching.
    """
    # The seeder is created before the gym check, as in the original test.
    base_seeder = Seeder(123)
    if not _GYM_INSTALLED:
        return

    original_env = Wrapper(gym.make("Acrobot-v1"))
    original_env.reseed(base_seeder)

    reseeded_env = deepcopy(original_env)
    reseeded_env.reseed()

    # As in the original test, is_online() is queried on a fresh deep copy.
    if not deepcopy(original_env).is_online():
        return

    original_traj = get_env_trajectory(original_env, 500)
    reseeded_traj = get_env_trajectory(reseeded_env, 500)
    assert not compare_trajectories(original_traj, reseeded_traj)
def test_wrapper_seeding(ModelClass):
    """
    Create three wrapped models after setting the global seed to 123, 456 and
    123 respectively, then assert on their 500-step trajectories:
    `compare_trajectories` must be falsy for the 123/456 pair and truthy for
    the 123/123 pair. Nothing is checked when the model is not online.
    """
    seeding.set_global_seed(123)
    env_seed_a = Wrapper(ModelClass())

    seeding.set_global_seed(456)
    env_seed_b = Wrapper(ModelClass())

    seeding.set_global_seed(123)
    env_seed_a_again = Wrapper(ModelClass())

    # As in the original test, is_online() is queried on a fresh deep copy.
    if not deepcopy(env_seed_a).is_online():
        return

    # Trajectories are collected in the same order as the envs were created.
    traj_a, traj_b, traj_a_again = [
        get_env_trajectory(env, 500)
        for env in (env_seed_a, env_seed_b, env_seed_a_again)
    ]

    assert not compare_trajectories(traj_a, traj_b)
    assert compare_trajectories(traj_a, traj_a_again)
Esempio n. 11
0
    def __init__(self,
                 env,
                 uncertainty_estimator_fn,
                 uncertainty_estimator_kwargs=None,
                 bonus_scale_factor=1.0,
                 bonus_max=np.inf):
        """
        Wrap `env` and build an uncertainty estimator for its spaces.

        Parameters
        ----------
        env
            Environment to wrap; its observation and action spaces are passed
            to the estimator constructor.
        uncertainty_estimator_fn : callable or str
            Called as ``fn(observation_space, action_space, **kwargs)``.
            A string is first resolved through ``load``.
        uncertainty_estimator_kwargs : dict or None
            Extra keyword arguments for the estimator; None means no extras.
        bonus_scale_factor : float
            Stored as ``self.bonus_scale_factor``.
        bonus_max : float
            Stored as ``self.bonus_max``.
        """
        Wrapper.__init__(self, env)

        self.bonus_scale_factor = bonus_scale_factor
        self.bonus_max = bonus_max

        estimator_kwargs = uncertainty_estimator_kwargs or {}

        # Resolve string specifications into the actual callable.
        estimator_factory = uncertainty_estimator_fn
        if isinstance(estimator_factory, str):
            estimator_factory = load(estimator_factory)

        self.uncertainty_estimator = estimator_factory(
            env.observation_space,
            env.action_space,
            **estimator_kwargs)
        self.previous_obs = None
Esempio n. 12
0
 def __init__(self, env):
     """
     Wrap `env` by delegating to ``Wrapper.__init__``; no additional state
     is set up here.
     """
     Wrapper.__init__(self, env)
Esempio n. 13
0
 def __init__(self, env, reward_range):
     """
     Wrap `env` and store the target reward range.

     Parameters
     ----------
     env
         Environment forwarded to ``Wrapper.__init__``.
     reward_range
         Indexable pair stored as ``self.reward_range``; asserted to satisfy
         ``reward_range[0] < reward_range[1]`` with both bounds finite.
     """
     Wrapper.__init__(self, env)
     self.reward_range = reward_range

     # The range must be non-empty ...
     assert reward_range[0] < reward_range[1]
     # ... and bounded on both sides.
     assert reward_range[0] > -np.inf
     assert reward_range[1] < np.inf
Esempio n. 14
0
def test_render2d_interface_wrapped(ModelClass):
    """
    Exercise the 2D rendering interface through a Wrapper.

    Only runs when the wrapped model is a ``RenderInterface2D``. For online
    models, two episodes of five steps are played (asserting each state is
    contained in the observation space), each episode is rendered, and a
    video is saved to "test_video.mp4". The video file is removed afterwards
    on a best-effort basis.
    """
    env = Wrapper(ModelClass())

    if isinstance(env.env, RenderInterface2D):
        env.enable_rendering()
        if env.is_online():
            for _ in range(2):
                state = env.reset()
                for _ in range(5):
                    assert env.observation_space.contains(state)
                    action = env.action_space.sample()
                    next_s, _, _, _ = env.step(action)
                    state = next_s
                env.render(loop=False)
            env.save_video("test_video.mp4")
            env.clear_render_buffer()
        try:
            os.remove("test_video.mp4")
        except OSError:
            # Best-effort cleanup: the file may not exist (e.g. the model is
            # not online, so no video was saved) or may not be removable.
            # Only filesystem errors are ignored; anything else surfaces.
            pass
Esempio n. 15
0
def test_wrapper_seeding(ModelClass):
    """
    Create three wrapped models reseeded with Seeder(123), Seeder(456) and
    Seeder(123) respectively, then assert on their 500-step trajectories:
    `compare_trajectories` must be falsy for the 123/456 pair and truthy for
    the 123/123 pair. Nothing is checked when the model is not online.
    """
    seeded_envs = []
    for seed in (123, 456, 123):
        wrapped = Wrapper(ModelClass())
        wrapped.reseed(Seeder(seed))
        seeded_envs.append(wrapped)
    env_a, env_b, env_a_again = seeded_envs

    # As in the original test, is_online() is queried on a fresh deep copy.
    if not deepcopy(env_a).is_online():
        return

    traj_a = get_env_trajectory(env_a, 500)
    traj_b = get_env_trajectory(env_b, 500)
    traj_a_again = get_env_trajectory(env_a_again, 500)

    assert not compare_trajectories(traj_a, traj_b)
    assert compare_trajectories(traj_a, traj_a_again)