Exemplo n.º 1
0
 def __init__(self,
              env,
              act_null_value=0,
              obs_null_value=0,
              force_float32=True):
     """Wrap ``env``, probe it once, and build the spaces and info schemas.

     The env is reset and stepped once with a randomly sampled action so
     that an example ``info`` dict is available for building the schemas.
     The env is left in that post-step state.

     Args:
         env: The gym environment to wrap.
         act_null_value: Forwarded as ``null_value`` to the action
             ``GymSpaceWrapper``.
         obs_null_value: Forwarded as ``null_value`` to the observation
             ``GymSpaceWrapper``.
         force_float32: Forwarded to both ``GymSpaceWrapper`` instances.
     """
     super().__init__(env)
     self.env.reset()
     # Only the info dict of the probe step is needed.
     _, _, _, info = self.env.step(self.env.action_space.sample())

     # Walk down the chain of wrappers looking for a TimeLimit.  Each test
     # must be made on the wrapper currently being visited (``env_``);
     # re-testing ``self.env`` would never find a nested TimeLimit.
     env_ = self.env
     time_limit = isinstance(env_, TimeLimit)
     while not time_limit and hasattr(env_, "env"):
         env_ = env_.env
         time_limit = isinstance(env_, TimeLimit)
     if time_limit:
         # gym's "TimeLimit.truncated" is an invalid field name, so it is
         # exposed as "timeout"; drop the original key if the probe step
         # happened to produce it.
         info["timeout"] = info.pop("TimeLimit.truncated", False)
     self._time_limit = time_limit

     self.action_space = GymSpaceWrapper(
         space=self.env.action_space,
         name="act",
         null_value=act_null_value,
         force_float32=force_float32,
     )
     self.observation_space = GymSpaceWrapper(
         space=self.env.observation_space,
         name="obs",
         null_value=obs_null_value,
         force_float32=force_float32,
     )
     self._info_schemas = {}
     self._build_info_schemas(info)
Exemplo n.º 2
0
class GymEnvWrapper(Wrapper):
    """Gym-style wrapper for converting the Openai Gym interface to the
    rlpyt interface.  Action and observation spaces are wrapped by rlpyt's
    ``GymSpaceWrapper``.

    Output `env_info` is automatically converted from a dictionary to a
    corresponding namedtuple, which the rlpyt sampler expects.  For this to
    work, every key that might appear in the gym environment's `env_info` at
    any step must appear at the first step after a reset, as the `env_info`
    entries will have sampler memory pre-allocated for them (so they also
    cannot change dtype or shape). (see `EnvInfoWrapper`, `build_info_schemas`,
    and `info_to_nt` in this file for more help/details)

    Warning:
        Unrecognized keys in `env_info` appearing later during use will be
        silently ignored.

    This wrapper looks for gym's ``TimeLimit`` env wrapper (at any depth in
    the chain of wrapped envs) to see whether to add the field ``timeout``
    to env info.
    """

    def __init__(self,
                 env,
                 act_null_value=0,
                 obs_null_value=0,
                 force_float32=True):
        """Wrap ``env``, probe it once, and build the spaces and info schemas.

        The env is reset and stepped once with a randomly sampled action so
        that an example ``info`` dict is available for building the schemas.
        The env is left in that post-step state.

        Args:
            env: The gym environment to wrap.
            act_null_value: Forwarded as ``null_value`` to the action
                ``GymSpaceWrapper``.
            obs_null_value: Forwarded as ``null_value`` to the observation
                ``GymSpaceWrapper``.
            force_float32: Forwarded to both ``GymSpaceWrapper`` instances.
        """
        super().__init__(env)
        self.env.reset()
        # Only the info dict of the probe step is needed.
        _, _, _, info = self.env.step(self.env.action_space.sample())

        # Walk down the chain of wrappers looking for a TimeLimit.  Each test
        # must be made on the wrapper currently being visited (``env_``);
        # re-testing ``self.env`` would never find a nested TimeLimit.
        env_ = self.env
        time_limit = isinstance(env_, TimeLimit)
        while not time_limit and hasattr(env_, "env"):
            env_ = env_.env
            time_limit = isinstance(env_, TimeLimit)
        if time_limit:
            # gym's "TimeLimit.truncated" is an invalid field name, so it is
            # exposed as "timeout"; drop the original key if the probe step
            # happened to produce it.
            info["timeout"] = info.pop("TimeLimit.truncated", False)
        self._time_limit = time_limit

        self.action_space = GymSpaceWrapper(
            space=self.env.action_space,
            name="act",
            null_value=act_null_value,
            force_float32=force_float32,
        )
        self.observation_space = GymSpaceWrapper(
            space=self.env.observation_space,
            name="obs",
            null_value=obs_null_value,
            force_float32=force_float32,
        )
        self._info_schemas = {}
        self._build_info_schemas(info)

    def _build_info_schemas(self, info, name="info"):
        """Register a ``NamedTupleSchema`` for ``info`` and any nested dicts.

        Schemas are stored in ``self._info_schemas`` keyed by ``name``; a
        nested dict under key ``k`` is registered as ``f"{name}_{k}"``.

        Raises:
            ValueError: If ``name`` is already registered with something
                other than a ``NamedTupleSchema`` having the same set of
                fields as ``info``.
        """
        ntc = self._info_schemas.get(name)
        if ntc is None:
            self._info_schemas[name] = NamedTupleSchema(
                name, list(info.keys()))
        elif not (isinstance(ntc, NamedTupleSchema)
                  and sorted(ntc._fields) == sorted(info.keys())):
            raise ValueError(f"Name clash in schema index: {name}.")
        for k, v in info.items():
            if isinstance(v, dict):
                self._build_info_schemas(v, "_".join([name, k]))

    def step(self, action):
        """Reverts the action from rlpyt format to gym format (i.e. if composite-to-
        dictionary spaces), steps the gym environment, converts the observation
        from gym to rlpyt format (i.e. if dict-to-composite), and converts the
        env_info from dictionary into namedtuple."""
        a = self.action_space.revert(action)
        o, r, d, info = self.env.step(a)
        obs = self.observation_space.convert(o)
        if self._time_limit:
            # Rename gym's key to "timeout"; False when the key is absent.
            info["timeout"] = info.pop("TimeLimit.truncated", False)
        info = info_to_nt(info, self._info_schemas)
        return EnvStep(obs, r, d, info)

    def reset(self):
        """Returns converted observation from gym env reset."""
        return self.observation_space.convert(self.env.reset())

    @property
    def spaces(self):
        """Returns the rlpyt spaces for the wrapped env."""
        return EnvSpaces(
            observation=self.observation_space,
            action=self.action_space,
        )