Exemplo n.º 1
0
class GymSpaceWrapper:
    """Wraps a gym space to match the rlpyt interface; most of
    the functionality is for automatically converting a GymDict (dictionary)
    space into an rlpyt Composite space (and converting between the two).  Use
    inside the initialization of the environment wrapper for a gym environment.
    """
    def __init__(self,
                 space,
                 null_value=0,
                 name="obs",
                 force_float32=True,
                 schemas=None):
        """Input ``space`` is a gym space instance.

        Input ``name`` governs naming of internal NamedTupleSchemas used to
        store Gym info.
        """
        self._gym_space = space
        self._base_name = name
        self._null_value = null_value
        if schemas is None:
            schemas = {}
        self._schemas = schemas
        if isinstance(space, GymDict):
            nt = self._schemas.get(name)
            if nt is None:
                nt = NamedTupleSchema(name, [k for k in space.spaces.keys()])
                schemas[name] = nt  # Put at module level for pickle.
            elif not (isinstance(nt, NamedTupleSchema) and sorted(nt._fields)
                      == sorted([k for k in space.spaces.keys()])):
                raise ValueError(f"Name clash in schemas: {name}.")
            spaces = [
                GymSpaceWrapper(
                    space=v,
                    null_value=null_value,
                    name="_".join([name, k]),
                    force_float32=force_float32,
                    schemas=schemas,
                ) for k, v in space.spaces.items()
            ]
            self.space = Composite(spaces, nt)
            self._dtype = None
        else:
            self.space = space
            self._dtype = (np.float32 if (space.dtype == np.float64
                                          and force_float32) else None)

    def sample(self):
        """Returns a single sample in a namedtuple (for composite) or numpy
        array using the the ``sample()`` method of the underlying gym
        space(s)."""
        sample = self.space.sample()
        if self.space is self._gym_space:  # Not Composite.
            # Force numpy array, might force float64->float32.
            sample = np.asarray(sample, dtype=self._dtype)
        return sample

    def null_value(self):
        """Similar to ``sample()`` but returning a null value."""
        if self.space is self._gym_space:
            null = np.asarray(self.space.sample(), dtype=self._dtype)
            if self._null_value is not None:
                try:
                    null[:] = self._null_value
                except IndexError:  # e.g. scalar.
                    null.fill(self._null_value)
            else:
                null.fill(0)
        else:  # Is composite.
            null = self.space.null_value()
        return null

    def convert(self, value):
        """For dictionary space, use to convert wrapped env's dict to rlpyt
        namedtuple, i.e. inside the environment wrapper's ``step()``, for
        observation output to the rlpyt sampler (see helper function in
        file)"""
        return dict_to_nt(value, name=self._base_name, schemas=self._schemas)

    def revert(self, value):
        """For dictionary space, use to revert namedtuple action into wrapped
        env's dict, i.e. inside the environment wrappers ``step()``, for input
        to the underlying gym environment (see helper function in file)."""
        return nt_to_dict(value)

    @property
    def dtype(self):
        return self._dtype or self.space.dtype

    @property
    def shape(self):
        return self.space.shape

    def contains(self, x):
        return self.space.contains(x)

    def __repr__(self):
        return self.space.__repr__()

    def __eq__(self, other):
        return self.space.__eq__(other)

    @property
    def low(self):
        return self.space.low

    @property
    def high(self):
        return self.space.high

    @property
    def n(self):
        return self.space.n

    def seed(self, seed=None):
        if type(self.space) is Composite:
            return [space.seed(seed=seed) for space in self.space.spaces]
        else:
            return self.space.seed(seed=seed)
Exemplo n.º 2
0
class GymSpaceWrapper:
    """Wraps a gym space to interface from dictionaries to namedtuples."""
    def __init__(self, space, null_value=0, name="obs", force_float32=True):
        self._gym_space = space
        self._base_name = name
        self._null_value = null_value
        if isinstance(space, GymDict):
            nt = globals().get(name)
            if nt is None:
                nt = namedtuple(name, [k for k in space.spaces.keys()])
                globals()[name] = nt  # Put at module level for pickle.
            elif not (is_namedtuple_class(nt) and sorted(nt._fields) == sorted(
                [k for k in space.spaces.keys()])):
                raise ValueError(f"Name clash in globals: {name}.")
            spaces = [
                GymSpaceWrapper(space=v,
                                null_value=null_value,
                                name="_".join([name, k]),
                                force_float32=force_float32)
                for k, v in space.spaces.items()
            ]
            self.space = Composite(spaces, nt)
            self._dtype = None
        else:
            self.space = space
            self._dtype = np.float32 if (space.dtype == np.float64
                                         and force_float32) else None

    def sample(self):
        sample = self.space.sample()
        if self.space is self._gym_space:  # Not Composite.
            # Force numpy array, might force float64->float32.
            sample = np.asarray(sample, dtype=self._dtype)
        return sample

    def null_value(self):
        if self.space is self._gym_space:
            null = np.asarray(self.space.sample(), dtype=self._dtype)
            if self._null_value is not None:
                try:
                    null[:] = self._null_value
                except IndexError:  # e.g. scalar.
                    null.fill(self._null_value)
            else:
                null.fill(0)
        else:  # Is composite.
            null = self.space.null_value()
        return null

    def convert(self, value):
        # Convert wrapped env's observation from dict to namedtuple.
        return dict_to_nt(value, name=self._base_name)

    def revert(self, value):
        # Revert namedtuple action into wrapped env's dict.
        return nt_to_dict(value)

    @property
    def dtype(self):
        return self._dtype or self.space.dtype

    @property
    def shape(self):
        return self.space.shape

    def contains(self, x):
        return self.space.contains(x)

    def __repr__(self):
        return self.space.__repr__()

    def __eq__(self, other):
        return self.space.__eq__(other)

    @property
    def low(self):
        return self.space.low

    @property
    def high(self):
        return self.space.high

    @property
    def n(self):
        return self.space.n