Example #1
0
 def __init__(self, env: gym.Env, task_label: Optional[int],
              task_label_space: gym.Space):
     """Wrap `env`, remembering a task label and its space.

     Args:
         env: The environment to wrap (passed to the parent wrapper).
         task_label: The task label stored on this wrapper; may be None.
         task_label_space: The space of task labels. It is combined with the
             wrapped env's observation space (via `add_task_labels`) to
             produce this wrapper's `observation_space`.
     """
     super().__init__(env=env)
     self.task_label, self.task_label_space = task_label, task_label_space
     # The parent constructor has set `self.env`; extend its observation
     # space with the task-label space.
     wrapped_space = self.env.observation_space
     self.observation_space = add_task_labels(
         wrapped_space, task_labels=task_label_space
     )
Example #2
0
 def __init__(self, envs: List[gym.Env], add_task_ids: bool = False):
     """Wrap a list of environments (one per task), exposing one at a time.

     The wrapper starts on task 0, i.e. it initially wraps `envs[0]`.

     Args:
         envs: The environments, indexed by task id.
         add_task_ids: When True, `observation_space` is replaced by the result
             of `add_task_labels` applied to the wrapped env's observation space
             and `self.task_label_space`.

     Raises:
         ValueError: If `envs` is empty, since there would be no environment to
             wrap (previously this surfaced as an opaque IndexError).
     """
     if len(envs) == 0:
         raise ValueError(
             "`envs` must contain at least one environment (one per task)."
         )
     self._envs = envs
     self._current_task_id = 0
     self.nb_tasks = len(envs)
     # Boolean array with one entry per task env (presumably per-env "closed"
     # flags, judging by the name — confirm against `is_closed`).
     self._envs_is_closed: np.ndarray = np.zeros([self.nb_tasks], dtype=bool)
     self._add_task_labels = add_task_ids
     self.rng: np.random.Generator = np.random.default_rng()
     # gym.Wrapper's constructor sets `self.env` to the current task's env.
     super().__init__(env=self._envs[self._current_task_id])
     # One discrete label per task: values in [0, nb_tasks).
     self.task_label_space = spaces.Discrete(self.nb_tasks)
     if self._add_task_labels:
         self.observation_space = add_task_labels(
             self.env.observation_space, self.task_label_space
         )
Example #3
0
 def set_task(self, task_id: int) -> None:
     """Switch the wrapped environment to the one for task `task_id`.

     Args:
         task_id: Index into the list of task environments; must satisfy
             ``0 <= task_id < len(self._envs)``.

     Raises:
         gym.error.ClosedEnvironmentError: If the env has already been closed.
         IndexError: If `task_id` is out of range. The check happens before
             any state is changed, so a bad id leaves the current task intact.
     """
     if self.is_closed(env_index=None):
         raise gym.error.ClosedEnvironmentError(
             "Can't call set_task on the env, since it's already closed."
         )
     # Validate before mutating anything. Previously an out-of-range id
     # updated `_current_task_id` and only then failed on the list lookup,
     # leaving the id and `self.env` out of sync. A negative id was silently
     # accepted via negative indexing, producing a task id that is not a
     # valid element of a Discrete(nb_tasks) label space.
     nb_envs = len(self._envs)
     if not 0 <= task_id < nb_envs:
         raise IndexError(
             f"task_id must be in the range [0, {nb_envs}), got {task_id!r}."
         )
     self._current_task_id = task_id
     # Use super().__init__() to reset the `self.env` attribute in gym.Wrapper.
     # TODO: This also resets the '_is_closed' on self.
     # TODO: This resets the 'observation_' and 'action_' etc objects that are saved
     # in the constructor of the 'IterableWrapper'
     gym.Wrapper.__init__(self, env=self._envs[self._current_task_id])
     # Re-apply the task-label augmentation, since the call above rebuilt
     # the wrapper's spaces from the new env.
     if self._add_task_labels:
         self.observation_space = add_task_labels(
             self.env.observation_space, self.task_label_space
         )
Example #4
0
 def observation(self, observation: Union[IncrementalRLSetting.Observations,
                                          Any]):
     """Return `observation` combined with this wrapper's `task_label`.

     The combination is delegated entirely to `add_task_labels`.
     """
     label = self.task_label
     labelled = add_task_labels(observation, label)
     return labelled
Example #5
0
 def observation(self, observation):
     """Return `observation`, tagged with the current task id if enabled.

     When `self._add_task_labels` is falsy, the observation is returned as-is.
     Otherwise it is passed through `add_task_labels` with the current task id.
     """
     if not self._add_task_labels:
         return observation
     return add_task_labels(observation, task_labels=self._current_task_id)