Ejemplo n.º 1
0
def test_multidiscrete():
    """Check conversion of a gym MultiDiscrete space.

    The converted object must be an ``rlberry.spaces.MultiDiscrete`` and,
    after reseeding, ten consecutive samples must all be contained in it.
    """
    converted = convert_space_from_gym(gym.spaces.MultiDiscrete([5, 2, 2]))
    assert isinstance(converted, rlberry.spaces.MultiDiscrete)

    converted.reseed(123)
    n_draws = 10
    assert all(converted.contains(converted.sample()) for _ in range(n_draws))
Ejemplo n.º 2
0
def test_multibinary():
    """Check conversion of gym MultiBinary spaces built from several sizes.

    For each ``n`` in ``[1, 5, [3, 4]]`` the converted space must be an
    ``rlberry.spaces.MultiBinary`` and ten samples drawn from it must all be
    contained in it.
    """
    for n in [1, 5, [3, 4]]:
        gym_sp = gym.spaces.MultiBinary(n)
        sp = convert_space_from_gym(gym_sp)
        assert isinstance(sp, rlberry.spaces.MultiBinary)
        # Reseed *before* sampling, as the other space tests in this file do,
        # so the seed applies to the samples that are actually checked
        # (presumably making the draws reproducible -- confirm against
        # the rlberry ``reseed`` contract).
        sp.reseed(123)
        for _ in range(10):
            assert sp.contains(sp.sample())
Ejemplo n.º 3
0
def test_box_space_case_1(low, high, dim):
    """Check conversion of a gym Box built from bounds and a ``(dim, 1)`` shape.

    Parameters
    ----------
    low, high :
        Bounds forwarded unchanged to ``gym.spaces.Box``.
    dim :
        First dimension of the box shape; ``2**dim`` samples are checked.
    """
    gym_box = gym.spaces.Box(low, high, shape=(dim, 1))
    box = convert_space_from_gym(gym_box)
    assert isinstance(box, rlberry.spaces.Box)

    box.reseed(123)
    n_samples = 2**dim
    for _ in range(n_samples):
        drawn = box.sample()
        assert box.contains(drawn)
Ejemplo n.º 4
0
def test_dict():
    """Check conversion of a deeply nested gym Dict space.

    The nested space mixes Dict, Tuple, Box, Discrete, MultiDiscrete and
    MultiBinary sub-spaces. Two conversions are exercised:

    1. the gym Dict is converted and must contain its own samples;
    2. a new gym Dict is rebuilt from the converted sub-spaces
       (``sp.spaces``), converted again, and its samples must be accepted by
       the first converted space.
    """
    nested_observation_space = gym.spaces.Dict({
        "sensors": gym.spaces.Dict({
            "position": gym.spaces.Box(low=-100, high=100, shape=(3, )),
            "velocity": gym.spaces.Box(low=-1, high=1, shape=(3, )),
            "front_cam": gym.spaces.Tuple((
                gym.spaces.Box(low=0, high=1, shape=(10, 10, 3)),
                gym.spaces.Box(low=0, high=1, shape=(10, 10, 3)),
            )),
            "rear_cam": gym.spaces.Box(low=0, high=1, shape=(10, 10, 3)),
        }),
        "ext_controller": gym.spaces.MultiDiscrete((5, 2, 2)),
        "inner_state": gym.spaces.Dict({
            "charge": gym.spaces.Discrete(100),
            "system_checks": gym.spaces.MultiBinary(10),
            "job_status": gym.spaces.Dict({
                "task": gym.spaces.Discrete(5),
                "progress": gym.spaces.Box(low=0, high=100, shape=()),
            }),
        }),
    })
    gym_sp = nested_observation_space
    sp = convert_space_from_gym(gym_sp)
    assert isinstance(sp, rlberry.spaces.Dict)
    # Reseed before sampling, as the other space tests in this file do, so the
    # seed applies to the samples that are actually checked.
    sp.reseed(123)
    for _ in range(10):
        assert sp.contains(sp.sample())

    # Round trip: wrap the already-converted sub-spaces in a gym Dict and
    # convert once more.
    gym_sp2 = gym.spaces.Dict(sp.spaces)
    sp2 = convert_space_from_gym(gym_sp2)
    assert isinstance(sp2, rlberry.spaces.Dict)
    sp2.reseed(123)
    for _ in range(10):
        # Samples of the re-converted space are checked against the first
        # converted space.
        assert sp.contains(sp2.sample())
Ejemplo n.º 5
0
def test_discrete_space(n):
    """Check conversion of a gym Discrete space with ``n`` elements.

    The converted space must be an ``rlberry.spaces.Discrete``, contain every
    integer in ``range(n)``, and contain each of ``2 * n`` samples drawn
    from it after reseeding.
    """
    space = convert_space_from_gym(gym.spaces.Discrete(n))
    assert isinstance(space, rlberry.spaces.Discrete)
    space.reseed(123)

    # Exhaustive membership check over range(n).
    assert all(space.contains(value) for value in range(n))

    # Sampled values must also be members.
    n_draws = 2 * n
    for _ in range(n_draws):
        assert space.contains(space.sample())
Ejemplo n.º 6
0
def test_tuple():
    """Check conversion of a gym Tuple made of a Box and a Discrete space.

    The converted space must be an ``rlberry.spaces.Tuple`` whose first two
    sub-spaces are an rlberry Box and an rlberry Discrete, and ten samples
    drawn after reseeding must be contained in it.
    """
    gym_members = [
        gym.spaces.Box(0.0, 1.0, shape=(3, 2)),
        gym.spaces.Discrete(2),
    ]
    converted = convert_space_from_gym(gym.spaces.Tuple(gym_members))
    assert isinstance(converted, rlberry.spaces.Tuple)

    # Sub-spaces must be converted as well, position by position.
    expected_types = (rlberry.spaces.Box, rlberry.spaces.Discrete)
    for position, expected_type in enumerate(expected_types):
        assert isinstance(converted.spaces[position], expected_type)

    converted.reseed(123)
    assert all(converted.contains(converted.sample()) for _ in range(10))
Ejemplo n.º 7
0
def test_box_space_case_2(low, high):
    """Check conversion of a float64 gym Box defined only by its bounds.

    Parameters
    ----------
    low, high :
        Bounds forwarded unchanged to ``gym.spaces.Box``; they must support
        the ``in`` operator, which is used to look for infinite entries.

    The converted box must report itself as unbounded exactly when ``low``
    contains ``-inf`` or ``high`` contains ``+inf``, and must contain each of
    ``2**shape[0]`` samples drawn from it.
    """
    box = convert_space_from_gym(gym.spaces.Box(low, high, dtype=np.float64))
    assert isinstance(box, rlberry.spaces.Box)
    box.reseed(123)

    has_infinite_bound = (-np.inf in low) or (np.inf in high)
    # Bounded if and only if no infinite bound was supplied.
    assert bool(box.is_bounded()) == (not has_infinite_bound)

    n_samples = 2**box.shape[0]
    for _ in range(n_samples):
        assert box.contains(box.sample())
Ejemplo n.º 8
0
    def __init__(self, env, wrap_spaces=False):
        """Wrap ``env`` and mirror its metadata, spaces and reward range.

        Parameters
        ----------
        env :
            Environment to wrap; it must expose ``metadata``,
            ``observation_space`` and ``action_space``.
        wrap_spaces : bool, default False
            If True, the observation and action spaces are passed through
            ``convert_space_from_gym``; otherwise they are taken from ``env``
            unchanged.
        """
        # Initialise the base class first.
        Model.__init__(self)

        # Keep a reference to the wrapped environment.
        self.env = env
        wrapped = self.env
        self.metadata = wrapped.metadata

        # One code path for both modes: convert, or pass the space through.
        adapt = convert_space_from_gym if wrap_spaces else (lambda space: space)
        self.observation_space = adapt(wrapped.observation_space)
        self.action_space = adapt(wrapped.action_space)

        # Fall back to an unbounded range when the wrapped environment does
        # not define ``reward_range``.
        self.reward_range = getattr(wrapped, "reward_range", (-np.inf, np.inf))