def __init__(self, space, null_value=0, name="obs", force_float32=True, schemas=None):
    """Wrap a gym space; a gym Dict space becomes a ``Composite`` of wrappers.

    Args:
        space: gym space instance to wrap.
        null_value: stored on this wrapper and forwarded to every nested
            wrapper created for a Dict space.
        name: key under which the ``NamedTupleSchema`` of a Dict space is
            looked up / registered in ``schemas``.  Nested wrappers are
            named ``"<name>_<key>"``.
        force_float32: when true, a non-Dict space whose dtype is float64
            has ``np.float32`` recorded as this wrapper's ``_dtype``.
        schemas: dict mapping names to schemas, shared with all nested
            wrappers.  A fresh dict is created when ``None``.

    Raises:
        ValueError: if ``schemas`` already holds an entry under ``name``
            that is not a ``NamedTupleSchema`` with the same set of field
            names as the Dict space's keys.
    """
    self._gym_space = space
    self._base_name = name
    self._null_value = null_value
    if schemas is None:
        schemas = {}
    self._schemas = schemas

    if not isinstance(space, GymDict):
        # Plain space: keep it as is, only remember an optional down-cast.
        self.space = space
        self._dtype = np.float32 if (space.dtype == np.float64 and force_float32) else None
        return

    keys = list(space.spaces.keys())
    nt = schemas.get(name)
    if nt is None:
        nt = NamedTupleSchema(name, keys)
        # Register so that later wrappers using the same name reuse it.
        schemas[name] = nt
        ordered_keys = keys
    elif isinstance(nt, NamedTupleSchema) and sorted(nt._fields) == sorted(keys):
        # The reused schema is only required to have the same *set* of
        # fields, so its field order may differ from this space's key
        # order.  Build the sub-spaces in the schema's order so that the
        # i-th sub-space always belongs to the i-th field.
        # NOTE(review): assumes Composite pairs ``spaces`` with the schema
        # fields positionally -- confirm against Composite.
        ordered_keys = list(nt._fields)
    else:
        raise ValueError(f"Name clash in schemas: {name}.")

    spaces = [
        GymSpaceWrapper(
            space=space.spaces[k],
            null_value=null_value,
            name="_".join([name, k]),
            force_float32=force_float32,
            schemas=schemas,
        )
        for k in ordered_keys
    ]
    self.space = Composite(spaces, nt)
    self._dtype = None
def _build_info_schemas(self, info, name="info"):
    """Register a ``NamedTupleSchema`` for ``info`` and for each nested dict.

    The schema for ``info`` is stored in ``self._info_schemas`` under
    ``name``; every value of ``info`` that is itself a dict is processed
    recursively under the name ``"<name>_<key>"``.

    Raises:
        ValueError: if ``self._info_schemas`` already has an entry under
            the same name that is not a ``NamedTupleSchema`` with the same
            set of field names as the keys of the dict being registered.
    """
    field_names = list(info.keys())
    existing = self._info_schemas.get(name)

    if existing is None:
        self._info_schemas[name] = NamedTupleSchema(name, field_names)
    else:
        # Field order is ignored; only the set of names has to agree.
        same_schema = (
            isinstance(existing, NamedTupleSchema)
            and sorted(existing._fields) == sorted(field_names)
        )
        if not same_schema:
            raise ValueError(f"Name clash in schema index: {name}.")

    for key, value in info.items():
        if not isinstance(value, dict):
            continue
        # Nested dicts get their own schema, namespaced by the parent name.
        self._build_info_schemas(value, "_".join([name, key]))
) # Wait for workers ready (e.g. decorrelate). return examples # e.g. In case useful to build replay buffer. class MuxCpuSampler(MuxParallelSampler, CpuSampler): pass class MuxGpuSampler(MuxParallelSampler, GpuSampler): pass EnvIDObs = NamedTupleSchema( 'EnvIDObs', # Conventions: # - variant_id is 0 for the demonstration variant # - source_id is 0 for novice rollouts, and can take values 1,2,… for all # others ('observation', 'task_id', 'variant_id', 'source_id')) EnvIDObsArray = NamedArrayTupleSchema(EnvIDObs._typename, EnvIDObs._fields) class EnvIDWrapper(Wrapper): def __init__(self, env, task_id, variant_id, source_id, num_tasks, max_num_variants, num_demo_sources): super().__init__(env) assert isinstance(task_id, int) and isinstance(variant_id, int) self._task_id_np = np.asarray([task_id]).reshape(()) self._variant_id_np = np.asarray([variant_id]).reshape(()) self._source_id = np.asarray([source_id]).reshape(()) task_space = IntBox(0, num_tasks)