Exemplo n.º 1
0
    def _space_contains(self, space: gym.Space, x: MultiEnvDict) -> bool:
        """Report whether every observation in ``x`` lies inside ``space``.

        Args:
            space: The space to check x's observations against. For
                multi-agent data it is indexed by agent id to obtain the
                per-agent sub-space.
            x: The observations to check, as a mapping whose values are
                ``{agent_id: observation}`` dicts.

        Returns:
            True if all observations of x are contained in space, False as
            soon as one observation is not contained or belongs to an agent
            id that ``self.get_agent_ids()`` does not report.
        """
        known_agents = set(self.get_agent_ids())
        for per_env_obs in x.values():
            for agent_id, obs in per_env_obs.items():
                if agent_id == _DUMMY_AGENT_ID:
                    # Single-agent case: a Vector env that has been
                    # converted to a BaseEnv, so the observation is checked
                    # against the space directly.
                    if space.contains(obs):
                        continue
                    return False

                # Multi-agent case: the agent must be known and its
                # observation must fit that agent's own sub-space.
                if agent_id not in known_agents:
                    return False
                if not space[agent_id].contains(obs):
                    return False

        return True
Exemplo n.º 2
0
    def _space_contains(space: gym.Space, x: MultiEnvDict) -> bool:
        """Check if the given space contains the observations of x.

        Args:
            space: The space to if x's observations are contained in.
            x: The observations to check.

        Note: With vector envs, we can process the raw observations
            and ignore the agent ids and env ids, since vector envs'
            sub environements are guaranteed to be the same

        Returns:
            True if the observations of x are contained in space.
        """
        for _, multi_agent_dict in x.items():
            for _, element in multi_agent_dict.items():
                if not space.contains(element):
                    return False
        return True
Exemplo n.º 3
0
def _gym_space_contains(space: gym.Space, x: Any) -> None:
    """Validate that ``x`` belongs to ``space``, raising on failure.

    A stricter variant of ``gym.Space.contains`` that explains why the
    validation failed instead of returning a boolean.

    * ``spaces.Dict``: ``x`` must be a dict with the same length as the
      space, every key of the space must be present, and each value is
      validated recursively against its sub-space.
    * ``spaces.Tuple``: lists and ndarrays are first promoted to tuples;
      ``x`` must then be a tuple with the same length as the space, and each
      element is validated recursively against its sub-space.
    * Any other space: defers to ``space.contains``.

    Args:
        space: The space to validate against.
        x: The sample to validate.

    Raises:
        GymSpaceValidationError: If ``x`` is not contained in ``space``.
            Failures in nested sub-spaces are chained as the cause.
    """
    if isinstance(space, spaces.Dict):
        dict_shape_ok = isinstance(x, dict) and len(x) == len(space)
        if not dict_shape_ok:
            raise GymSpaceValidationError(
                "Sample must be a dict with same length as space.", space, x)
        for k, child_space in space.spaces.items():
            if k not in x:
                raise GymSpaceValidationError(
                    f"Key {k} not found in sample.", space, x)
            try:
                _gym_space_contains(child_space, x[k])
            except GymSpaceValidationError as err:
                # Re-raise at this level so the chain shows the full path
                # from the outermost space down to the failing leaf.
                raise GymSpaceValidationError(
                    f"Subspace of key {k} validation error.", space,
                    x) from err
        return

    if isinstance(space, spaces.Tuple):
        # Promote list and ndarray to tuple for the contains check.
        if isinstance(x, (list, np.ndarray)):
            x = tuple(x)
        tuple_shape_ok = isinstance(x, tuple) and len(x) == len(space)
        if not tuple_shape_ok:
            raise GymSpaceValidationError(
                "Sample must be a tuple with same length as space.", space, x)
        for i, (child_space, item) in enumerate(zip(space, x)):
            try:
                _gym_space_contains(child_space, item)
            except GymSpaceValidationError as err:
                raise GymSpaceValidationError(
                    f"Subspace of index {i} validation error.", space,
                    x) from err
        return

    # Leaf space: rely on gym's own membership test.
    if not space.contains(x):
        raise GymSpaceValidationError(
            "Validation error reported by gym.", space, x)
 def __init__(self, ob_space: gym.Space, ac_space: gym.Space, fixed_val: np.ndarray):
     """Store ``fixed_val`` after initialising the parent with both spaces.

     Args:
         ob_space: Observation space, forwarded to the parent initialiser.
         ac_space: Action space, forwarded to the parent initialiser and
             used to validate ``fixed_val``.
         fixed_val: Value kept on the instance as ``self.fixed_val``.

     Raises:
         ValueError: If ``ac_space`` does not contain ``fixed_val``.
     """
     super().__init__(ob_space, ac_space)
     self.fixed_val = fixed_val
     if ac_space.contains(fixed_val):
         return
     raise ValueError(  # pragma: no cover
         f"fixed_val = '{fixed_val}' not contained in ac_space = '{ac_space}'"
     )