def __init__(self, space, null_value=0, name="obs", force_float32=True,
             schemas=None):
    """Wrap a gym ``space``, expanding dict spaces into a ``Composite``.

    Input ``space`` is a gym space instance.  When it is a ``GymDict``,
    every sub-space is wrapped recursively and the wrappers are grouped
    in a ``Composite`` together with a ``NamedTupleSchema`` whose fields
    are the dict keys.  Any other space is stored unchanged.

    Input ``null_value`` is stored on the instance and forwarded to the
    nested wrappers.

    Input ``name`` governs naming of internal NamedTupleSchemas used to
    store Gym info: the schema for this level is registered under
    ``name`` and each sub-space under ``"<name>_<key>"``.

    Input ``force_float32`` selects ``np.float32`` as the stored dtype
    for non-dict spaces whose dtype is ``np.float64``; otherwise the
    stored dtype is ``None``.

    Input ``schemas`` is a dict mapping names to schemas, shared with
    all nested wrappers.  A new dict is created when ``None`` is given.

    Raises ``ValueError`` when ``schemas`` already holds an entry under
    ``name`` that is not a ``NamedTupleSchema`` with the same set of
    field names as the keys of ``space``.
    """
    self._gym_space = space
    self._base_name = name
    self._null_value = null_value
    self._schemas = {} if schemas is None else schemas
    # From here on use the dict that is actually stored, so nested
    # wrappers share it even when the caller passed ``None``.
    schemas = self._schemas

    if not isinstance(space, GymDict):
        self.space = space
        downcast = space.dtype == np.float64 and force_float32
        self._dtype = np.float32 if downcast else None
        return

    keys = list(space.spaces)
    schema = schemas.get(name)
    if schema is None:
        schema = NamedTupleSchema(name, keys)
        # Register the new schema in the shared dict under ``name``.
        schemas[name] = schema
    else:
        # Field order is ignored here; only the set of names must agree.
        compatible = (isinstance(schema, NamedTupleSchema)
                      and sorted(schema._fields) == sorted(keys))
        if not compatible:
            raise ValueError(f"Name clash in schemas: {name}.")

    wrapped = []
    for key, subspace in space.spaces.items():
        wrapped.append(GymSpaceWrapper(
            space=subspace,
            null_value=null_value,
            name="_".join((name, key)),
            force_float32=force_float32,
            schemas=schemas,
        ))
    self.space = Composite(wrapped, schema)
    self._dtype = None
Exemple #2
0
 def _build_info_schemas(self, info, name="info"):
     """Register a ``NamedTupleSchema`` for ``info`` and each nested dict.

     A schema with the keys of ``info`` as fields is stored in
     ``self._info_schemas`` under ``name`` unless an entry already
     exists.  Every value of ``info`` that is itself a ``dict`` is then
     processed recursively under the name ``"<name>_<key>"``.

     Raises ``ValueError`` when an entry already exists under ``name``
     but is not a ``NamedTupleSchema`` with the same set of field names.
     """
     known = self._info_schemas.get(name)
     keys = list(info)
     if known is None:
         self._info_schemas[name] = NamedTupleSchema(name, keys)
     else:
         # Field order is ignored; only the set of names must agree.
         same_schema = (isinstance(known, NamedTupleSchema)
                        and sorted(known._fields) == sorted(keys))
         if not same_schema:
             raise ValueError(f"Name clash in schema index: {name}.")
     for key, value in info.items():
         if not isinstance(value, dict):
             continue
         self._build_info_schemas(value, "_".join((name, key)))
Exemple #3
0
        )  # Wait for workers ready (e.g. decorrelate).
        return examples  # e.g. In case useful to build replay buffer.


class MuxCpuSampler(MuxParallelSampler, CpuSampler):
    """Sampler combining ``MuxParallelSampler`` with ``CpuSampler``.

    Defines no members of its own; all behaviour is inherited, with
    ``MuxParallelSampler`` listed first in the bases and therefore taking
    precedence in method resolution.
    """
    pass


class MuxGpuSampler(MuxParallelSampler, GpuSampler):
    """Sampler combining ``MuxParallelSampler`` with ``GpuSampler``.

    Defines no members of its own; all behaviour is inherited, with
    ``MuxParallelSampler`` listed first in the bases and therefore taking
    precedence in method resolution.
    """
    pass


# Schema for an observation record with four fields: the observation itself
# plus task, variant and source identifiers.
EnvIDObs = NamedTupleSchema(
    'EnvIDObs',
    # Conventions:
    # - variant_id is 0 for the demonstration variant
    # - source_id is 0 for novice rollouts, and can take values 1, 2, ... for
    #   all others
    ('observation', 'task_id', 'variant_id', 'source_id'))
# NamedArrayTupleSchema built from EnvIDObs's own type name and fields, so the
# two definitions cannot drift apart.
EnvIDObsArray = NamedArrayTupleSchema(EnvIDObs._typename, EnvIDObs._fields)


class EnvIDWrapper(Wrapper):
    def __init__(self, env, task_id, variant_id, source_id, num_tasks,
                 max_num_variants, num_demo_sources):
        super().__init__(env)
        assert isinstance(task_id, int) and isinstance(variant_id, int)
        self._task_id_np = np.asarray([task_id]).reshape(())
        self._variant_id_np = np.asarray([variant_id]).reshape(())
        self._source_id = np.asarray([source_id]).reshape(())
        task_space = IntBox(0, num_tasks)