def __init__(self, env):
    """Wrap ``env`` and replace its action space with a three-part ``spaces.Tuple``.

    The resulting action space is ``(Discrete(2), Box(-1, 1, (2,)), Box(-1, 1, (2,)))``,
    where both boxes are float32.

    :param env: the environment to wrap; forwarded to the parent constructor.
    """
    super().__init__(env)
    # NOTE(review): the parent constructor may already store ``env``; kept explicit as in the original.
    self.env = env
    # Common settings for the two boxes. A separate Box instance is still built per slot,
    # so the two entries do not share state.
    box_kwargs = dict(low=-1, high=1, shape=(2,), dtype=np.float32)
    self.action_space = spaces.Tuple(
        (
            spaces.Discrete(2),
            spaces.Box(**box_kwargs),
            spaces.Box(**box_kwargs),
        )
    )
def __init__(self, max_steps):
    """Set up the environment used to check that the terminal observation is inserted correctly.

    :param max_steps: stored on ``self.max_steps``; presumably the episode length
        (TODO confirm against ``step``, which is not visible here).
    """
    # Episode bookkeeping (independent of the space definitions below).
    self.max_steps = max_steps
    self.current_step = 0

    # Two discrete actions; observations are a single integer in [0, 999].
    # NOTE(review): dtype="int" resolves to the platform's default integer type
    # (e.g. int64 vs int32) -- confirm this is intended.
    self.action_space = spaces.Discrete(2)
    obs_low, obs_high = np.array([0]), np.array([999])
    self.observation_space = spaces.Box(obs_low, obs_high, dtype="int")
def test_identity_spaces(model_class, policy_class, env):
    """
    Additional tests for DQN/SAC/TD3 to check observation space support
    for MultiDiscrete and MultiBinary.

    :param model_class: algorithm class under test.
    :param policy_class: policy passed to the algorithm constructor.
    :param env: environment instance supplied by the test parametrization.
    """
    import copy

    # DQN only supports discrete actions
    if model_class == DQN:
        # pytest passes parametrized values as-is (no copy). Overwriting the action space
        # on the shared instance would leak Discrete(4) into later parametrized runs
        # (e.g. SAC/TD3) that receive the same env object, so mutate a private copy.
        env = copy.deepcopy(env)
        env.action_space = spaces.Discrete(4)

    env = gym.wrappers.TimeLimit(env, max_episode_steps=100)

    model = model_class(policy_class, env, gamma=0.5, seed=1, policy_kwargs=dict(net_arch=[64]))
    model.learn(total_timesteps=500)

    evaluate_policy(model, env, n_eval_episodes=5, warn=False)
def __init__(self):
    super(DummyEnv, self).__init__()
    # Action space layout (a spaces.Tuple):
    #   - the first entry selects the skill (a Discrete space with n = number of skills),
    #   - each remaining entry holds the parameters of one specific skill.
    # With 2 skills the Tuple therefore needs 3 entries in total:
    # Discrete(2), then one 4-dim float32 parameter Box for skill 1 and one for skill 2.
    n_skills = 2
    # Build one distinct Box per skill (the entries must not share a single Box object).
    skill_param_spaces = (
        spaces.Box(low=-1, high=1, shape=(4,), dtype=np.float32) for _ in range(n_skills)
    )
    self.action_space = spaces.Tuple((spaces.Discrete(n_skills), *skill_param_spaces))

    # Observations are 2-dim float32 vectors bounded to [0, 255].
    self.observation_space = spaces.Box(low=0, high=255, shape=(2,), dtype=np.float32)
def test_image_space_checks():
    """Check ``is_image_space`` and ``is_image_space_channels_first`` on image and non-image spaces."""
    # Spaces that ``is_image_space`` must reject, each tagged with the reason.
    rejected_spaces = [
        spaces.Box(0, 1, shape=(10, )),  # 1D, default dtype, range [0, 1]
        spaces.Box(0, 255, shape=(10, 10, 3)),  # Not uint8
        spaces.Box(0, 255, shape=(10, 10), dtype=np.uint8),  # Not correct shape
        spaces.Box(0, 10, shape=(10, 10, 3), dtype=np.uint8),  # Not correct low/high
        spaces.Discrete(n=10),  # Not correct space
    ]
    for candidate in rejected_spaces:
        assert not is_image_space(candidate)

    # A regular channel-last uint8 image is accepted.
    assert is_image_space(spaces.Box(0, 255, shape=(10, 10, 3), dtype=np.uint8))

    # Five channels: accepted by default, rejected once the channel count is validated.
    odd_channels = spaces.Box(0, 255, shape=(10, 10, 5), dtype=np.uint8)
    assert is_image_space(odd_channels)
    assert not is_image_space(odd_channels, check_channels=True)

    # Channel-order detection.
    assert is_image_space_channels_first(spaces.Box(0, 255, shape=(3, 10, 10), dtype=np.uint8))
    assert not is_image_space_channels_first(spaces.Box(0, 255, shape=(10, 10, 3), dtype=np.uint8))

    # Channels on the middle axis: the check must emit a warning.
    channel_mid_space = spaces.Box(0, 255, shape=(10, 3, 10), dtype=np.uint8)
    with pytest.warns(Warning):
        assert not is_image_space_channels_first(channel_mid_space)
def test_subproc_start_method():
    """Run SubprocVecEnv with every available start method and check an unknown one is rejected."""
    # ``None`` exercises the default path (no explicit start method).
    start_methods = [None]
    # Originally only thread-safe methods were tested, since others may deadlock tests (gh/428).
    # The unsafe `fork` method is now included as well, as we are now using PyTorch.
    all_methods = {"forkserver", "spawn", "fork"}
    available_methods = multiprocessing.get_all_start_methods()
    # Sort so the methods run in the same order on every run: iteration order of a set of
    # strings depends on hash randomization.
    start_methods += sorted(all_methods.intersection(available_methods))

    space = spaces.Discrete(2)

    def obs_assert(obs):
        return check_vecenv_obs(obs, space)

    for start_method in start_methods:
        vec_env_class = functools.partial(SubprocVecEnv, start_method=start_method)
        check_vecenv_spaces(vec_env_class, space, obs_assert)

    # Building the partial cannot raise; only the env construction inside the helper should.
    illegal_vec_env_class = functools.partial(SubprocVecEnv, start_method="illegal_method")
    with pytest.raises(ValueError, match="cannot find context for 'illegal_method'"):
        check_vecenv_spaces(illegal_vec_env_class, space, obs_assert)
assert np.all(prev_obs < terminal_obs) assert np.all(obs < prev_obs) if not isinstance(vec_env, VecNormalize): # more precise tests that we can't do with VecNormalize # (which changes observation values) assert np.all(prev_obs + 1 == terminal_obs) assert np.all(obs == 0) prev_obs_b = obs_b vec_env.close() SPACES = collections.OrderedDict([ ("discrete", spaces.Discrete(2)), ("multidiscrete", spaces.MultiDiscrete([2, 3])), ("multibinary", spaces.MultiBinary(3)), ("continuous", spaces.Box(low=np.zeros(2), high=np.ones(2))), ]) def check_vecenv_spaces(vec_env_class, space, obs_assert): """Helper method to check observation spaces in vectorized environments.""" def make_env(): return CustomGymEnv(space) vec_env = vec_env_class([make_env for _ in range(N_ENVS)]) obs = vec_env.reset() obs_assert(obs)
check_env(env) @pytest.mark.parametrize( "new_obs_space", [ # Small image spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8), # Range not in [0, 255] spaces.Box(low=0, high=1, shape=(64, 64, 3), dtype=np.uint8), # Wrong dtype spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=np.float32), # Not an image, it should be a 1D vector spaces.Box(low=-1, high=1, shape=(64, 3), dtype=np.float32), # Tuple space is not supported by SB spaces.Tuple([spaces.Discrete(5), spaces.Discrete(10)]), # Dict space is not supported by SB when env is not a GoalEnv spaces.Dict({"position": spaces.Discrete(5)}), ], ) def test_non_default_spaces(new_obs_space): env = FakeImageEnv() env.observation_space = new_obs_space # Patch methods to avoid errors env.reset = new_obs_space.sample def patched_step(_action): return new_obs_space.sample(), 0.0, False, {} env.step = patched_step