def __init__(self, space, null_value=0, name="obs", force_float32=True):
    """Wrap a gym space so it can be used as an rlpyt space.

    Input ``space`` is a gym space instance.  Input ``name`` is used to
    disambiguate different gym spaces being wrapped, which is necessary if
    more than one GymDict space is to be wrapped in the same file.  The
    reason is that the associated namedtuples must be defined in the globals
    of this file (so they can be pickled), so they must have distinct names.
    Entries of a GymDict are wrapped recursively under the name
    ``"<name>_<key>"``.

    Args:
        space: gym space instance to wrap.
        null_value: stored on the wrapper and passed to every sub-space
            wrapper of a GymDict.
        name: module-level name of the namedtuple class built from a
            GymDict's keys.
        force_float32: if True and ``space.dtype`` is ``np.float64``,
            ``self._dtype`` is set to ``np.float32``.

    Raises:
        ValueError: if ``name`` is already bound in this module's globals to
            something other than a namedtuple class whose fields are exactly
            the keys of ``space``.
    """
    self._gym_space = space
    self._base_name = name
    self._null_value = null_value
    if isinstance(space, GymDict):
        keys = list(space.spaces)
        nt = globals().get(name)
        if nt is None:
            nt = namedtuple(name, keys)
            globals()[name] = nt  # Put at module level for pickle.
        elif not (is_namedtuple_class(nt) and sorted(nt._fields) == sorted(keys)):
            raise ValueError(f"Name clash in globals: {name}.")
        # Build the sub-spaces in the namedtuple's field order rather than the
        # gym space's key order: a class reused from globals may list the same
        # keys in a different order, and the sub-spaces are passed to
        # Composite as a plain list (presumably paired positionally with the
        # namedtuple fields -- confirm against Composite).
        spaces = [
            GymSpaceWrapper(
                space=space.spaces[k],
                null_value=null_value,
                name="_".join([name, k]),
                force_float32=force_float32,
            )
            for k in nt._fields
        ]
        self.space = Composite(spaces, nt)
        self._dtype = None
    else:
        self.space = space
        # Optionally downcast float64 spaces to float32.
        self._dtype = (
            np.float32 if (space.dtype == np.float64 and force_float32) else None
        )
def __init__(self, space, null_value=0, name="obs", force_float32=True):
    """Wrap a gym space so it can be used as an rlpyt space.

    A GymDict space becomes a ``Composite`` of recursively wrapped
    sub-spaces, described by a namedtuple class registered in this module's
    globals under ``name`` (module level so it can be pickled).  Each entry's
    wrapper is named ``"<name>_<key>"``.  Any other space is stored as-is.

    Args:
        space: gym space instance to wrap.
        null_value: stored on the wrapper and passed to every sub-space
            wrapper of a GymDict.
        name: module-level name of the namedtuple class built from a
            GymDict's keys; must be distinct for each differently-keyed
            GymDict wrapped in this module.
        force_float32: if True and ``space.dtype`` is ``np.float64``,
            ``self._dtype`` is set to ``np.float32``.

    Raises:
        ValueError: if ``name`` is already bound in this module's globals to
            something other than a namedtuple class whose fields are exactly
            the keys of ``space``.
    """
    self._gym_space = space
    self._base_name = name
    self._null_value = null_value
    if isinstance(space, GymDict):
        keys = list(space.spaces)
        nt = globals().get(name)
        if nt is None:
            nt = namedtuple(name, keys)
            globals()[name] = nt  # Put at module level for pickle.
        elif not (is_namedtuple_class(nt) and sorted(nt._fields) == sorted(keys)):
            raise ValueError(f"Name clash in globals: {name}.")
        # Build the sub-spaces in the namedtuple's field order rather than the
        # gym space's key order: a class reused from globals may list the same
        # keys in a different order, and the sub-spaces are passed to
        # Composite as a plain list (presumably paired positionally with the
        # namedtuple fields -- confirm against Composite).
        spaces = [
            GymSpaceWrapper(
                space=space.spaces[k],
                null_value=null_value,
                name="_".join([name, k]),
                force_float32=force_float32,
            )
            for k in nt._fields
        ]
        self.space = Composite(spaces, nt)
        self._dtype = None
    else:
        self.space = space
        # Optionally downcast float64 spaces to float32.
        self._dtype = (
            np.float32 if (space.dtype == np.float64 and force_float32) else None
        )
def build_info_tuples(info, name="info"):
    """Register namedtuple classes mirroring the key structure of ``info``.

    A namedtuple class called ``name``, with the keys of ``info`` as its
    fields, is placed in this module's globals (module level so instances
    can be pickled).  Every value of ``info`` that is itself a dict is
    handled the same way, recursively, under the name ``"<name>_<key>"``.

    Raises:
        ValueError: if ``name`` is already bound in globals to something that
            is not a namedtuple class whose fields are exactly the keys of
            ``info``.
    """
    module_globals = globals()
    fields = list(info.keys())
    existing = module_globals.get(name)  # Define at module level for pickle.
    if existing is None:
        module_globals[name] = namedtuple(name, fields)
    elif not is_namedtuple_class(existing) or sorted(existing._fields) != sorted(fields):
        raise ValueError(f"Name clash in globals: {name}.")
    for key, value in info.items():
        if not isinstance(value, dict):
            continue
        build_info_tuples(value, "_".join([name, key]))
def build_info_tuples(info, name="info"):
    """Create module-level namedtuple classes for ``info`` and its sub-dicts.

    The class for ``info`` is stored in this module's globals as ``name``;
    each dict-valued entry gets its own class, recursively, stored as
    ``"<name>_<key>"``.  Non-dict values only contribute a field name.

    Raises:
        ValueError: if ``name`` already exists in globals and is not a
            namedtuple class with exactly the keys of ``info`` as fields.
    """
    # Define namedtuples at module level for pickle.
    # Only place rlpyt uses pickle is in the sampler, when getting the
    # first examples, to avoid MKL threading issues...can probably turn
    # that off, (look for subprocess=True --> False), and then might
    # be able to define these directly within the class.
    registry = globals()
    field_names = list(info.keys())
    cls = registry.get(name)
    if cls is not None:
        fields_match = is_namedtuple_class(cls) and sorted(cls._fields) == sorted(field_names)
        if not fields_match:
            raise ValueError(f"Name clash in globals: {name}.")
    else:
        registry[name] = namedtuple(name, field_names)
    nested = ((k, v) for k, v in info.items() if isinstance(v, dict))
    for key, sub_info in nested:
        build_info_tuples(sub_info, "_".join([name, key]))