def observation(self, obs):
    """Collapse the per-document mapping into one stacked "item" array.

    Builds a new OrderedDict with keys "user", "item" and "response", in
    that order. "user" and "response" are copied from ``obs`` as-is. "item"
    is ``np.vstack`` applied to the values of ``obs["doc"]`` in its
    iteration order. The result is passed through
    ``convert_element_to_space_type`` with ``self._sampled_obs`` as the
    type template, and the converted value is returned.
    """
    flattened = OrderedDict(
        [
            ("user", obs["user"]),
            ("item", np.vstack(list(obs["doc"].values()))),
            ("response", obs["response"]),
        ]
    )
    return convert_element_to_space_type(flattened, self._sampled_obs)
def check_shape(self, observation: Any) -> None:
    """Checks the shape of the given observation.

    Validation only runs on calls where ``self._i`` is a multiple of
    ``OBS_VALIDATION_INTERVAL``. The counter ``self._i`` is incremented on
    every call, whether or not validation ran.

    Args:
        observation: The observation to validate against ``self._obs_space``.

    Raises:
        ValueError: If the observation is still not contained in
            ``self._obs_space`` after two adjustments: a Python list is
            first converted to a float32 array (Box spaces only), and the
            value is then cast via ``convert_element_to_space_type``.
        ValueError: If the containment check (or reading ``.dtype`` for a
            Box space) raises AttributeError. The message reports this as
            a Python list passed where an np.array was expected.
    """
    if self._i % OBS_VALIDATION_INTERVAL == 0:
        # Convert lists to np.ndarrays.
        if isinstance(observation, list) and isinstance(
                self._obs_space, gym.spaces.Box):
            observation = np.array(observation).astype(np.float32)
        if not self._obs_space.contains(observation):
            # Retry after casting to the element types of the reference
            # observation kept for type matching.
            observation = convert_element_to_space_type(
                observation, self._obs_for_type_matching)
        try:
            if not self._obs_space.contains(observation):
                # Evaluated inside the try on purpose: a missing `.dtype`
                # surfaces as the "np.array, not a Python list" error below.
                obs_dtype = (observation.dtype if isinstance(
                    self._obs_space, gym.spaces.Box) else None)
                # The original passed the template and its values as
                # separate ValueError args, so the `{}` were never filled.
                raise ValueError(
                    "Observation ({} dtype={}) outside given space ({})!".
                    format(observation, obs_dtype, self._obs_space))
        except AttributeError as e:
            raise ValueError(
                "Observation for a Box/MultiBinary/MultiDiscrete space "
                "should be an np.array, not a Python list.",
                observation,
            ) from e
    self._i += 1
def observation(self, obs):
    """Replace the keys of the "doc" mapping with string positions.

    Builds a new OrderedDict with keys "user", "doc" and "response", in
    that order. "user" and "response" are copied from ``obs``. "doc" maps
    "0", "1", ... to the values of ``obs["doc"]`` in its iteration order,
    and the original doc keys are discarded. The result is passed through
    ``convert_element_to_space_type`` with ``self._sampled_obs`` as the
    type template.
    """
    rekeyed = OrderedDict(
        [
            ("user", obs["user"]),
            (
                "doc",
                {
                    str(position): doc
                    for position, doc in enumerate(obs["doc"].values())
                },
            ),
            ("response", obs["response"]),
        ]
    )
    return convert_element_to_space_type(rekeyed, self._sampled_obs)
def test_convert_element_to_space_type(self):
    """Check that convert_element_to_space_type fixes leaf types.

    A nested element is built whose leaves use types other than those
    produced by the spaces: float64 Box arrays, floats for Discrete, and
    int32 MultiDiscrete/MultiBinary arrays. After conversion against a
    sample of the Dict space, the element must be contained in that space.
    """
    box = Box(low=-1, high=1, shape=(2, ))
    discrete = Discrete(2)
    multi_discrete = MultiDiscrete([2, 2])
    multi_binary = MultiBinary(2)
    box_and_discrete = Tuple((box, discrete))
    nested_space = Dict({
        "box": box,
        "discrete": discrete,
        "multi_discrete": multi_discrete,
        "multi_binary": multi_binary,
        "dict_space": Dict({
            "box2": box,
            "discrete2": discrete,
        }),
        "tuple_space": box_and_discrete,
    })

    # Leaves deliberately cast away from the spaces' own sample types.
    box_unconverted = box.sample().astype(np.float64)
    multi_discrete_unconverted = multi_discrete.sample().astype(np.int32)
    multi_binary_unconverted = multi_binary.sample().astype(np.int32)
    discrete_as_float = float(0)

    unconverted_element = {
        "box": box_unconverted,
        "discrete": discrete_as_float,
        "multi_discrete": multi_discrete_unconverted,
        "multi_binary": multi_binary_unconverted,
        "tuple_space": (box_unconverted, discrete_as_float),
        "dict_space": {
            "box2": box_unconverted,
            "discrete2": discrete_as_float,
        },
    }

    type_template = nested_space.sample()
    converted = convert_element_to_space_type(unconverted_element,
                                              type_template)
    assert nested_space.contains(converted)
def reset(self):
    """Reset the wrapped env and return a type-matched first observation.

    The "response" entry of the observation returned by the parent
    ``reset()`` is overwritten with a fresh sample from
    ``self.env.observation_space["response"]``. Presumably this is done
    because a reset observation carries no real response; confirm with
    the env. The observation is then cast via
    ``convert_element_to_space_type`` against ``self._sampled_obs``.
    """
    first_obs = super().reset()
    response_space = self.env.observation_space["response"]
    first_obs["response"] = response_space.sample()
    return convert_element_to_space_type(first_obs, self._sampled_obs)
def check_gym_environments(env: gym.Env) -> None:
    """Checking for common errors in gym environments.

    The function checks that:

    - ``env`` exposes gym observation and action spaces;
    - samples drawn from those spaces are contained in them;
    - the observations produced by ``env.reset()`` and by one
      ``env.step()`` call (with a sampled action) lie within the
      observation space.

    An observation that fails the containment check is retried once after
    casting it with ``convert_element_to_space_type``, using an
    ``env.observation_space.sample()`` as the type template. An error is
    raised only if that retry also fails.

    Args:
        env: Environment to be checked.

    Warning:
        If env has no attribute spec with a sub attribute,
            max_episode_steps.

    Raises:
        AttributeError: If env has no observation space.
        AttributeError: If env has no action space.
        ValueError: Observation space must be a gym.spaces.Space.
        ValueError: Action space must be a gym.spaces.Space.
        ValueError: Observation sampled from observation space must be
            contained in the observation space.
        ValueError: Action sampled from action space must be contained in
            the action space.
        ValueError: If an observation collected from a call to env.reset()
            is not contained in the observation_space.
        ValueError: If the observation collected from env.step() is not
            contained in the observation_space.
        AssertionError: If env.step() returns a reward that is not an int
            or float (checked by ``_check_reward``).
        AssertionError: If env.step() returns a done that is not a bool
            (checked by ``_check_done``).
        AssertionError: If env.step() returns an env_info that is not a
            dict (checked by ``_check_info``).

    Exceptions raised by ``env.reset()`` or ``env.step()`` themselves are
    not caught; they propagate to the caller unchanged.
    """
    # Check that env has observation and action spaces.
    if not hasattr(env, "observation_space"):
        raise AttributeError("Env must have observation_space.")
    if not hasattr(env, "action_space"):
        raise AttributeError("Env must have action_space.")

    # Check that observation and action spaces are gym.spaces.
    if not isinstance(env.observation_space, gym.spaces.Space):
        raise ValueError("Observation space must be a gym.space")
    if not isinstance(env.action_space, gym.spaces.Space):
        raise ValueError("Action space must be a gym.space")

    # Raise a warning if there isn't a max_episode_steps attribute.
    if not hasattr(env, "spec") or not hasattr(env.spec, "max_episode_steps"):
        logger.warning(
            "Your env doesn't have a .spec.max_episode_steps "
            "attribute. This is fine if you have set 'horizon' "
            "in your config dictionary, or `soft_horizon`. "
            "However, if you haven't, 'horizon' will default "
            "to infinity, and your environment will not be "
            "reset."
        )

    def get_type(var):
        return var.dtype if hasattr(var, "dtype") else type(var)

    # Check that sampled actions and observations are contained within their
    # respective action and observation spaces. The original code sampled
    # both but never performed this documented check.
    sampled_action = env.action_space.sample()
    if not env.action_space.contains(sampled_action):
        raise ValueError(
            f"The action sampled from env.action_space.sample() was not "
            f"contained within env.action_space.\n\n sampled_action: "
            f"{sampled_action}\n\n env.action_space: {env.action_space}"
        )
    # Also used below as the type template for retrying failed observations.
    sampled_observation = env.observation_space.sample()
    if not env.observation_space.contains(sampled_observation):
        raise ValueError(
            f"The observation sampled from env.observation_space.sample() "
            f"was not contained within env.observation_space.\n\n "
            f"sampled_observation: {sampled_observation}\n\n "
            f"env.observation_space: {env.observation_space}"
        )

    # Check that the observation returned by env.reset() is contained within
    # the observation space (after a type-converting retry).
    reset_obs = env.reset()
    if not env.observation_space.contains(reset_obs):
        temp_sampled_reset_obs = convert_element_to_space_type(
            reset_obs, sampled_observation
        )
        if not env.observation_space.contains(temp_sampled_reset_obs):
            # Diagnostics are only gathered once the retry has failed too.
            reset_obs_type = get_type(reset_obs)
            space_type = env.observation_space.dtype
            raise ValueError(
                f"The observation collected from env.reset() was not "
                f"contained within your env's observation space. Its possible "
                f"that There was a type mismatch, or that one of the "
                f"sub-observations was out of bounds: \n\n reset_obs: "
                f"{reset_obs}\n\n env.observation_space: "
                f"{env.observation_space}\n\n reset_obs's dtype: "
                f"{reset_obs_type}\n\n env.observation_space's dtype: "
                f"{space_type}"
            )

    # Check that env.step can run and produces an observation, reward, done
    # signal and info that are within their respective spaces and of the
    # correct types.
    next_obs, reward, done, info = env.step(sampled_action)
    if not env.observation_space.contains(next_obs):
        temp_sampled_next_obs = convert_element_to_space_type(
            next_obs, sampled_observation
        )
        if not env.observation_space.contains(temp_sampled_next_obs):
            next_obs_type = get_type(next_obs)
            space_type = env.observation_space.dtype
            raise ValueError(
                f"The observation collected from env.step(sampled_action) was "
                f"not contained within your env's observation space. Its "
                f"possible that There was a type mismatch, or that one of the "
                f"sub-observations was out of bounds:\n\n next_obs: {next_obs}"
                f"\n\n env.observation_space: {env.observation_space}"
                f"\n\n next_obs's dtype: {next_obs_type}"
                f"\n\n env.observation_space's dtype: {space_type}"
            )

    _check_done(done)
    _check_reward(reward)
    _check_info(info)