示例#1
0
 def __init__(self, env):
     """Wrap ``env`` and expose a three-part tuple action space.

     The action space is a ``spaces.Tuple`` holding one ``Discrete(2)``
     entry followed by two ``Box`` entries, each of shape ``(2,)`` with
     float32 values bounded in ``[-1, 1]``.

     :param env: the environment handed to the parent constructor and
         also stored on ``self.env``.
     """
     super().__init__(env)
     self.env = env

     # Build the components first so the tuple itself reads as a short list.
     choice_space = spaces.Discrete(2)
     first_box = spaces.Box(low=-1, high=1, shape=(2, ), dtype=np.float32)
     second_box = spaces.Box(low=-1, high=1, shape=(2, ), dtype=np.float32)
     self.action_space = spaces.Tuple((choice_space, first_box, second_box))
示例#2
0
 def __init__(self, max_steps):
     """Gym environment for testing that terminal observation is inserted
     correctly.

     Actions are ``Discrete(2)``; observations are a single integer-typed
     value bounded by 0 and 999.

     :param max_steps: stored as ``self.max_steps``; the step counter
         ``self.current_step`` starts at 0.
     """
     self.action_space = spaces.Discrete(2)

     # Name the bounds so the Box construction stays on one line.
     obs_low = np.array([0])
     obs_high = np.array([999])
     self.observation_space = spaces.Box(obs_low, obs_high, dtype="int")

     self.max_steps = max_steps
     self.current_step = 0
示例#3
0
def test_identity_spaces(model_class, policy_class, env):
    """
    Additional tests for DQN/SAC/TD3 to check observation space support
    for MultiDiscrete and MultiBinary.

    The model is built on a time-limited copy of ``env``, trained for a
    small number of timesteps and then passed to ``evaluate_policy``.
    """
    if model_class == DQN:
        # DQN only supports discrete actions, so swap in a discrete space
        env.action_space = spaces.Discrete(4)

    # Cap episode length at 100 steps
    limited_env = gym.wrappers.TimeLimit(env, max_episode_steps=100)

    net_settings = dict(net_arch=[64])
    model = model_class(
        policy_class,
        limited_env,
        gamma=0.5,
        seed=1,
        policy_kwargs=net_settings,
    )
    model.learn(total_timesteps=500)

    evaluate_policy(model, limited_env, n_eval_episodes=5, warn=False)
示例#4
0
    def __init__(self):
        """Define the action and observation spaces of the dummy environment.

        The action space is a ``spaces.Tuple`` with three entries: a
        ``Discrete(2)`` selector followed by two float32 ``Box`` spaces of
        shape ``(4,)`` bounded in ``[-1, 1]``. The observation space is a
        float32 ``Box`` of shape ``(2,)`` bounded in ``[0, 255]``.
        """
        super(DummyEnv, self).__init__()

        # Tuple layout: entry 0 selects the skill, and every following entry
        # holds the parameters of one specific skill. With 2 skills the tuple
        # therefore needs 3 entries in total:
        #   entry 0 -> discrete space with n = 2 (which skill to run)
        #   entry 1 -> parameters for skill1
        #   entry 2 -> parameters for skill2
        skill_selector = spaces.Discrete(2)
        skill1_params = spaces.Box(low=-1, high=1, shape=(4,), dtype=np.float32)
        skill2_params = spaces.Box(low=-1, high=1, shape=(4,), dtype=np.float32)
        self.action_space = spaces.Tuple(
            (skill_selector, skill1_params, skill2_params)
        )
        # Alternative single-Box action space, kept disabled:
        # self.action_space = spaces.Box(low=-1, high=1, shape=(1, ))

        # Observations are 2-element float32 vectors bounded in [0, 255].
        # NOTE(review): this was labelled as an image-input example, but a
        # shape of (2,) is not image-shaped -- confirm the intent.
        self.observation_space = spaces.Box(low=0, high=255, shape=(2,), dtype=np.float32)
示例#5
0
def test_image_space_checks():
    """Exercise ``is_image_space`` and ``is_image_space_channels_first``.

    First checks a series of spaces that must be rejected as images, then
    spaces that must be accepted, and finally the channel-order helper.
    """
    # Spaces that must NOT be recognised as images; the reason precedes each.
    rejected_spaces = [
        # 1D shape, no dtype given
        spaces.Box(0, 1, shape=(10, )),
        # Not uint8
        spaces.Box(0, 255, shape=(10, 10, 3)),
        # Not correct shape
        spaces.Box(0, 255, shape=(10, 10), dtype=np.uint8),
        # Not correct low/high
        spaces.Box(0, 10, shape=(10, 10, 3), dtype=np.uint8),
        # Not correct space
        spaces.Discrete(n=10),
    ]
    for candidate in rejected_spaces:
        assert not is_image_space(candidate)

    # A uint8 Box in [0, 255] with 3 trailing channels is accepted.
    rgb_space = spaces.Box(0, 255, shape=(10, 10, 3), dtype=np.uint8)
    assert is_image_space(rgb_space)

    # Five channels: accepted by default...
    five_channel_space = spaces.Box(0, 255, shape=(10, 10, 5), dtype=np.uint8)
    assert is_image_space(five_channel_space)
    # ...but should not pass if we check if channels are valid for an image
    assert not is_image_space(five_channel_space, check_channels=True)

    # Test if channel-check works
    leading_channels = spaces.Box(0, 255, shape=(3, 10, 10), dtype=np.uint8)
    assert is_image_space_channels_first(leading_channels)

    trailing_channels = spaces.Box(0, 255, shape=(10, 10, 3), dtype=np.uint8)
    assert not is_image_space_channels_first(trailing_channels)

    middle_channels = spaces.Box(0, 255, shape=(10, 3, 10), dtype=np.uint8)
    # Should raise a warning
    with pytest.warns(Warning):
        assert not is_image_space_channels_first(middle_channels)
示例#6
0
def test_subproc_start_method():
    """Run the space checks on ``SubprocVecEnv`` for each start method.

    ``None`` is always tried first, followed by whichever of
    ``forkserver``, ``spawn`` and ``fork`` the platform reports as
    available. An unknown start method must raise ``ValueError``.
    """
    # Only test thread-safe methods. Others may deadlock tests! (gh/428)
    # Note: adding unsafe `fork` method as we are now using PyTorch
    candidate_methods = {"forkserver", "spawn", "fork"}
    supported = candidate_methods.intersection(
        multiprocessing.get_all_start_methods()
    )
    start_methods = [None, *supported]
    space = spaces.Discrete(2)

    def assert_obs_in_space(obs):
        return check_vecenv_obs(obs, space)

    for start_method in start_methods:
        env_factory = functools.partial(SubprocVecEnv, start_method=start_method)
        check_vecenv_spaces(env_factory, space, assert_obs_in_space)

    expected_error = "cannot find context for 'illegal_method'"
    with pytest.raises(ValueError, match=expected_error):
        bad_factory = functools.partial(SubprocVecEnv, start_method="illegal_method")
        check_vecenv_spaces(bad_factory, space, assert_obs_in_space)
示例#7
0
                assert np.all(prev_obs < terminal_obs)
                assert np.all(obs < prev_obs)

                if not isinstance(vec_env, VecNormalize):
                    # more precise tests that we can't do with VecNormalize
                    # (which changes observation values)
                    assert np.all(prev_obs + 1 == terminal_obs)
                    assert np.all(obs == 0)

        prev_obs_b = obs_b

    vec_env.close()


# Sample spaces keyed by a short label, in this insertion order:
# discrete, multi-discrete, multi-binary and continuous (a Box from 0 to 1
# along 2 dimensions).
SPACES = collections.OrderedDict()
SPACES["discrete"] = spaces.Discrete(2)
SPACES["multidiscrete"] = spaces.MultiDiscrete([2, 3])
SPACES["multibinary"] = spaces.MultiBinary(3)
SPACES["continuous"] = spaces.Box(low=np.zeros(2), high=np.ones(2))


def check_vecenv_spaces(vec_env_class, space, obs_assert):
    """Helper method to check observation spaces in vectorized environments.

    :param vec_env_class: callable that receives a list of environment
        factories and returns the vectorized environment.
    :param space: space passed to ``CustomGymEnv`` for every sub-environment.
    :param obs_assert: callable invoked with the observation returned by
        ``reset()``.
    """
    def _build_env():
        return CustomGymEnv(space)

    # The same factory is reused for each of the N_ENVS sub-environments.
    env_factories = [_build_env] * N_ENVS
    vec_env = vec_env_class(env_factories)

    first_obs = vec_env.reset()
    obs_assert(first_obs)
示例#8
0
    check_env(env)


@pytest.mark.parametrize(
    "new_obs_space",
    [
        # Small image
        spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8),
        # Range not in [0, 255]
        spaces.Box(low=0, high=1, shape=(64, 64, 3), dtype=np.uint8),
        # Wrong dtype
        spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=np.float32),
        # Not an image, it should be a 1D vector
        spaces.Box(low=-1, high=1, shape=(64, 3), dtype=np.float32),
        # Tuple space is not supported by SB
        spaces.Tuple([spaces.Discrete(5),
                      spaces.Discrete(10)]),
        # Dict space is not supported by SB when env is not a GoalEnv
        spaces.Dict({"position": spaces.Discrete(5)}),
    ],
)
def test_non_default_spaces(new_obs_space):
    """Set up a ``FakeImageEnv`` whose observation space is ``new_obs_space``.

    The environment's ``reset`` and ``step`` are replaced so that the
    observations they return are sampled from the substituted space.

    :param new_obs_space: the observation space installed on the environment
        (one of the parametrized cases above).
    """
    env = FakeImageEnv()
    env.observation_space = new_obs_space
    # Patch methods to avoid errors
    # reset() now returns a random sample drawn from the new space
    env.reset = new_obs_space.sample

    def patched_step(_action):
        # Ignores the action; presumably (obs, reward, done, info) -- confirm
        return new_obs_space.sample(), 0.0, False, {}

    env.step = patched_step