示例#1
0
 def observation(self, obs):
     """Restructure a raw observation into "user" / "item" / "response".

     The values of ``obs["doc"]`` are stacked row-wise with ``np.vstack``
     and stored under the "item" key; "user" and "response" are passed
     through unchanged.

     Args:
         obs: Mapping providing the "user", "doc" and "response" entries.

     Returns:
         The restructured observation after being passed through
         ``convert_element_to_space_type`` together with
         ``self._sampled_obs``.
     """
     # Entries are evaluated in the order user -> doc -> response.
     restructured = OrderedDict([
         ("user", obs["user"]),
         ("item", np.vstack(list(obs["doc"].values()))),
         ("response", obs["response"]),
     ])
     return convert_element_to_space_type(restructured, self._sampled_obs)
示例#2
0
 def check_shape(self, observation: Any) -> None:
     """Validate that ``observation`` is contained in ``self._obs_space``.

     The validation only runs on calls where ``self._i`` is a multiple of
     ``OBS_VALIDATION_INTERVAL``. ``self._i`` is incremented at the end of
     every call that does not raise.

     Args:
         observation: The observation to validate. A Python list is
             converted to a float32 ``np.ndarray`` first when the space is
             a ``gym.spaces.Box``.

     Raises:
         ValueError: If the observation is still outside the space after
             ``convert_element_to_space_type`` was applied, or if an
             ``AttributeError`` is raised while checking it.
     """
     if self._i % OBS_VALIDATION_INTERVAL == 0:
         # Convert lists to np.ndarrays.
         if isinstance(observation, list) and isinstance(
                 self._obs_space, gym.spaces.Box):
             observation = np.array(observation).astype(np.float32)
         # On a first miss, retry after matching the element's types to
         # the reference observation.
         if not self._obs_space.contains(observation):
             observation = convert_element_to_space_type(
                 observation, self._obs_for_type_matching)
         try:
             if not self._obs_space.contains(observation):
                 # Format the message here; passing the template and the
                 # values as separate exception args leaves the "{}"
                 # placeholders unfilled.
                 raise ValueError(
                     "Observation ({} dtype={}) outside given space "
                     "({})!".format(
                         observation,
                         observation.dtype if isinstance(
                             self._obs_space, gym.spaces.Box) else None,
                         self._obs_space,
                     ))
         except AttributeError as err:
             # Keep the original AttributeError attached as the cause.
             raise ValueError(
                 "Observation for a Box/MultiBinary/MultiDiscrete space "
                 "should be an np.array, not a Python list.",
                 observation,
             ) from err
     self._i += 1
示例#3
0
 def observation(self, obs):
     """Re-key the documents of an observation by their position.

     The values of ``obs["doc"]`` are kept in iteration order but stored
     under the string keys "0", "1", ...; "user" and "response" are passed
     through unchanged.

     Args:
         obs: Mapping providing the "user", "doc" and "response" entries.

     Returns:
         The re-keyed observation after being passed through
         ``convert_element_to_space_type`` together with
         ``self._sampled_obs``.
     """
     rekeyed = OrderedDict()
     rekeyed["user"] = obs["user"]
     # Only the document values matter; their original keys are dropped.
     rekeyed["doc"] = {
         str(position): doc
         for position, doc in enumerate(obs["doc"].values())
     }
     rekeyed["response"] = obs["response"]
     return convert_element_to_space_type(rekeyed, self._sampled_obs)
示例#4
0
    def test_convert_element_to_space_type(self):
        """Check that a mistyped nested element is converted to fit its space.

        Builds a Dict space holding Box, Discrete, MultiDiscrete,
        MultiBinary, nested Dict and Tuple sub-spaces, feeds the converter
        an element whose leaves were cast to float64 / int32 / Python
        float, and asserts that the converted element is contained in the
        Dict space.
        """
        # Leaf spaces.
        box = Box(low=-1, high=1, shape=(2, ))
        discrete = Discrete(2)
        multi_discrete = MultiDiscrete([2, 2])
        multi_binary = MultiBinary(2)

        # Composite spaces built from the leaves.
        box_and_discrete = Tuple((box, discrete))
        inner_dict = Dict({"box2": box, "discrete2": discrete})
        nested_space = Dict({
            "box": box,
            "discrete": discrete,
            "multi_discrete": multi_discrete,
            "multi_binary": multi_binary,
            "dict_space": inner_dict,
            "tuple_space": box_and_discrete,
        })

        # Samples deliberately cast to other dtypes before conversion.
        recast_box = box.sample().astype(np.float64)
        recast_multi_discrete = multi_discrete.sample().astype(np.int32)
        recast_multi_binary = multi_binary.sample().astype(np.int32)
        recast_discrete = float(0)

        mistyped_element = {
            "box": recast_box,
            "discrete": recast_discrete,
            "multi_discrete": recast_multi_discrete,
            "multi_binary": recast_multi_binary,
            "tuple_space": (recast_box, recast_discrete),
            "dict_space": {
                "box2": recast_box,
                "discrete2": recast_discrete,
            },
        }

        reference_sample = nested_space.sample()
        converted = convert_element_to_space_type(mistyped_element,
                                                  reference_sample)
        assert nested_space.contains(converted)
示例#5
0
文件: recsim.py 项目: stjordanis/ray
 def reset(self):
     """Reset via the parent class and fill in a sampled "response" entry.

     Returns:
         The parent's reset observation, with its "response" key set to a
         sample from ``self.env.observation_space["response"]``, after
         being passed through ``convert_element_to_space_type`` together
         with ``self._sampled_obs``.
     """
     initial_obs = super().reset()
     response_space = self.env.observation_space["response"]
     # The "response" entry is overwritten in place with a random sample.
     initial_obs["response"] = response_space.sample()
     return convert_element_to_space_type(initial_obs, self._sampled_obs)
示例#6
0
def check_gym_environments(env: gym.Env) -> None:
    """Run basic sanity checks against a gym environment.

    The env is reset once and stepped once with a sampled action. An
    observation that is not contained in ``env.observation_space`` gets a
    second chance after being passed through
    ``convert_element_to_space_type``; only if it is still not contained
    is an error raised.

    Args:
        env: Environment to be checked.

    Warning:
        Logged if env has no ``spec`` attribute, or its ``spec`` has no
        ``max_episode_steps`` attribute.

    Raises:
        AttributeError: If env has no observation_space or no action_space.
        ValueError: If the observation space or the action space is not a
            gym.spaces.Space.
        ValueError: If the observation returned by env.reset() is not
            contained in the observation space, even after conversion.
        ValueError: If the observation returned by env.step() is not
            contained in the observation space, even after conversion.

    The done flag, reward and info returned by env.step() are handed to
    ``_check_done``, ``_check_reward`` and ``_check_info``; any errors they
    raise propagate to the caller.
    """

    def describe_type(value):
        # Prefer a dtype attribute when present, else the Python type.
        return getattr(value, "dtype", type(value))

    # Both spaces have to exist before anything else can be verified.
    if not hasattr(env, "observation_space"):
        raise AttributeError("Env must have observation_space.")
    if not hasattr(env, "action_space"):
        raise AttributeError("Env must have action_space.")

    # ... and both have to be proper gym spaces.
    if not isinstance(env.observation_space, gym.spaces.Space):
        raise ValueError("Observation space must be a gym.space")
    if not isinstance(env.action_space, gym.spaces.Space):
        raise ValueError("Action space must be a gym.space")

    # Warn when no step limit is declared on the env's spec.
    has_step_limit = hasattr(env, "spec") and hasattr(
        env.spec, "max_episode_steps"
    )
    if not has_step_limit:
        logger.warning(
            "Your env doesn't have a .spec.max_episode_steps attribute. "
            "This is fine if you have set 'horizon' in your config "
            "dictionary, or `soft_horizon`. However, if you haven't, "
            "'horizon' will default to infinity, and your environment "
            "will not be reset."
        )

    # Samples: the action drives env.step(), the observation serves as the
    # reference element for type conversion.
    action_sample = env.action_space.sample()
    obs_sample = env.observation_space.sample()

    # The observation produced by env.reset() must lie within the
    # observation space.
    initial_obs = env.reset()
    if not env.observation_space.contains(initial_obs):
        initial_obs_dtype = describe_type(initial_obs)
        expected_dtype = env.observation_space.dtype
        # The message is built before the conversion attempt below.
        reset_failure = (
            f"The observation collected from env.reset() was not  contained "
            f"within your env's observation space. Its possible that There "
            f"was a type mismatch, or that one of the sub-observations  was "
            f"out of bounds: \n\n reset_obs: {initial_obs}"
            f"\n\n env.observation_space: {env.observation_space}"
            f"\n\n reset_obs's dtype: {initial_obs_dtype}"
            f"\n\n env.observation_space's dtype: {expected_dtype}"
        )
        converted_initial_obs = convert_element_to_space_type(
            initial_obs, obs_sample
        )
        if not env.observation_space.contains(converted_initial_obs):
            raise ValueError(reset_failure)

    # env.step() must run, and its observation must lie within the
    # observation space as well.
    stepped_obs, reward, done, info = env.step(action_sample)
    if not env.observation_space.contains(stepped_obs):
        stepped_obs_dtype = describe_type(stepped_obs)
        expected_dtype = env.observation_space.dtype
        # The message is built before the conversion attempt below.
        step_failure = (
            f"The observation collected from env.step(sampled_action) was not "
            f"contained within your env's observation space. Its possible "
            f"that There was a type mismatch, or that one of the "
            f"sub-observations was out of bounds:"
            f"\n\n next_obs: {stepped_obs}"
            f"\n\n env.observation_space: {env.observation_space}"
            f"\n\n next_obs's dtype: {stepped_obs_dtype}"
            f"\n\n env.observation_space's dtype: {expected_dtype}"
        )
        converted_stepped_obs = convert_element_to_space_type(
            stepped_obs, obs_sample
        )
        if not env.observation_space.contains(converted_stepped_obs):
            raise ValueError(step_failure)

    # Remaining step outputs are validated by dedicated helpers.
    _check_done(done)
    _check_reward(reward)
    _check_info(info)