def __init__(self, name: str, n_repeat_action: int = 1, n_workers: int = 8,
             blocking: bool = True, **kwargs):
    """
    Create the pool of dm_control workers and one local environment.

    :param name: Name of the environment; given to every worker and to the
        local environment.
    :param n_repeat_action: Forwarded as ``n_repeat_action`` to every worker
        and to the local environment.
    :param n_workers: Number of ``ExternalDMControl`` workers handed to
        ``BatchEnv``.
    :param blocking: Forwarded unchanged to ``BatchEnv``. NOTE(review): the
        name suggests ``True`` means workers are waited on one after another
        (synchronous stepping) -- confirm against ``BatchEnv``.
    :param kwargs: Extra keyword arguments forwarded to every worker and to
        the local environment.
    """
    super(ParallelDMControl, self).__init__(name=name)

    # Every worker receives exactly the same construction arguments.
    worker_kwargs = dict(name=name, n_repeat_action=n_repeat_action, **kwargs)
    workers = []
    for _ in range(n_workers):
        workers.append(ExternalDMControl(**worker_kwargs))
    self._batch_env = BatchEnv(workers, blocking)

    # A separate local DMControlEnv supplies the observation/action spaces.
    local_env = self._env = DMControlEnv(name, n_repeat_action=n_repeat_action, **kwargs)
    self.observation_space = local_env.observation_space
    self.action_space = local_env.action_space
class ParallelDMControl(Environment):
    """Wrap a dm_control environment to be stepped in parallel.

    ``n_workers`` ``ExternalDMControl`` instances are grouped in a ``BatchEnv``
    for batched stepping. A separate local ``DMControlEnv`` handles single
    steps, resets, state get/set, and any attribute not defined on this class.
    """

    def __init__(self, name: str, n_repeat_action: int = 1, n_workers: int = 8,
                 blocking: bool = True, **kwargs):
        """
        :param name: Name of the environment; given to every worker and to the
            local environment.
        :param n_repeat_action: Forwarded as ``n_repeat_action`` to every
            worker and to the local environment.
        :param n_workers: Number of ``ExternalDMControl`` workers handed to
            ``BatchEnv``.
        :param blocking: Forwarded unchanged to ``BatchEnv``. NOTE(review):
            the name suggests ``True`` means synchronous (one-after-another)
            stepping -- confirm against ``BatchEnv``.
        :param kwargs: Extra keyword arguments forwarded to every worker and
            to the local environment.
        """
        super(ParallelDMControl, self).__init__(name=name)
        envs = [
            ExternalDMControl(name=name, n_repeat_action=n_repeat_action, **kwargs)
            for _ in range(n_workers)
        ]
        self._batch_env = BatchEnv(envs, blocking)
        self._env = DMControlEnv(name, n_repeat_action=n_repeat_action, **kwargs)
        self.observation_space = self._env.observation_space
        self.action_space = self._env.action_space

    def __getattr__(self, item):
        """Delegate attributes not found on this wrapper to the local env."""
        # __getattr__ only runs after normal lookup fails. If ``_env`` itself
        # is missing (e.g. copy/unpickling builds the object without calling
        # __init__, or __init__ failed early), ``self._env`` below would call
        # __getattr__("_env") again and recurse until RecursionError.
        if item == "_env":
            raise AttributeError(
                "%r object has no attribute %r" % (type(self).__name__, item)
            )
        return getattr(self._env, item)

    def step_batch(self, actions: np.ndarray, states: np.ndarray = None,
                   n_repeat_action: [np.ndarray, int] = None):
        """
        Step the batched workers via ``BatchEnv.step_batch``.

        :param actions: Actions forwarded to the batch environment.
        :param states: Optional states forwarded to the batch environment.
        :param n_repeat_action: Optional per-call repeat count (an int or an
            array), forwarded unchanged.
        :return: Whatever ``BatchEnv.step_batch`` returns.
        """
        return self._batch_env.step_batch(
            actions=actions, states=states, n_repeat_action=n_repeat_action
        )

    def step(self, action: np.ndarray, state: np.ndarray = None,
             n_repeat_action: int = None):
        """
        Step only the local environment (the workers are not touched).

        :return: Whatever the local environment's ``step`` returns.
        """
        return self._env.step(action=action, state=state, n_repeat_action=n_repeat_action)

    def reset(self, return_state: bool = True, blocking: bool = True):
        """
        Reset the local environment and sync its state to the workers.

        :param return_state: If True return ``(state, obs)``; otherwise only
            ``obs``.
        :param blocking: Accepted for interface compatibility; unused here.
        :return: ``(state, obs)`` if ``return_state`` else ``obs``.
        """
        state, obs = self._env.reset(return_state=True)
        self.sync_states()
        # The tuple must be parenthesized: ``return state, obs if c else obs``
        # parses as ``(state, (obs if c else obs))`` and always returns a pair.
        return (state, obs) if return_state else obs

    def get_state(self):
        """Return the local environment's current state."""
        return self._env.get_state()

    def set_state(self, state):
        """Set the local environment's state, then sync it to the workers."""
        self._env.set_state(state)
        self.sync_states()

    def sync_states(self):
        """Pass the local environment's state to ``BatchEnv.sync_states``."""
        self._batch_env.sync_states(self.get_state())
def __init__(self, name: str, n_repeat_action: int = 1, height: float = 100,
             width: float = 100, wrappers=None, n_workers: int = 8,
             blocking: bool = True, **kwargs):
    """
    Create the pool of retro workers and one initialised local environment.

    :param name: Name of the environment; given to every worker and to the
        local environment.
    :param n_repeat_action: Forwarded as ``n_repeat_action`` to every worker
        and to the local environment.
    :param height: Forwarded as ``height`` to every worker and to the local
        environment (presumably the observation height -- confirm).
    :param width: Forwarded as ``width`` to every worker and to the local
        environment (presumably the observation width -- confirm).
    :param wrappers: Forwarded as ``wrappers`` to every worker and to the
        local environment.
    :param n_workers: Number of ``ExternalRetro`` workers handed to
        ``BatchEnv``.
    :param blocking: Forwarded unchanged to ``BatchEnv``. NOTE(review): the
        name suggests ``True`` means synchronous (one-after-another)
        stepping -- confirm against ``BatchEnv``.
    :param kwargs: Extra keyword arguments forwarded to every worker and to
        the local environment.
    """
    super(ParallelRetro, self).__init__(name=name)

    # Construction options shared by the workers and the local environment.
    shared = dict(n_repeat_action=n_repeat_action, height=height, width=width,
                  wrappers=wrappers, **kwargs)

    workers = []
    for _ in range(n_workers):
        workers.append(ExternalRetro(name=name, **shared))
    self._batch_env = BatchEnv(workers, blocking)

    # Only the local environment has init_env() called on it here.
    local_env = self._env = RetroEnvironment(name, **shared)
    local_env.init_env()
    self.observation_space = local_env.observation_space
    self.action_space = local_env.action_space