예제 #1
0
    def __init__(self, name: str, n_repeat_action: int = 1, n_workers: int = 8,
                 blocking: bool = True, **kwargs):
        """Create a local dm_control environment plus a batch of worker copies.

        :param name: Name of the environment. Passed to the base class and to
            every environment constructed here.
        :param n_repeat_action: Action-repeat value forwarded to every
            environment constructed here.
        :param n_workers: Number of ``ExternalDMControl`` instances handed to
            the ``BatchEnv``.
        :param blocking: Forwarded unchanged to ``BatchEnv``. Presumably
            selects whether batched steps wait for the workers — confirm
            against ``BatchEnv``.
        :param kwargs: Extra keyword arguments forwarded to every environment
            constructed here.
        """
        super(ParallelDMControl, self).__init__(name=name)

        # Keyword arguments shared by every worker and by the local env.
        shared = dict(n_repeat_action=n_repeat_action, **kwargs)

        workers = []
        for _ in range(n_workers):
            workers.append(ExternalDMControl(name=name, **shared))
        self._batch_env = BatchEnv(workers, blocking)

        # Local environment; this wrapper exposes its spaces as its own.
        self._env = DMControlEnv(name, **shared)
        self.observation_space = self._env.observation_space
        self.action_space = self._env.action_space
예제 #2
0
class ParallelDMControl(Environment):
    """Wrap a dm_control environment so it can be stepped in parallel.

    The wrapper holds two things:

    * ``self._env`` — a local ``DMControlEnv``, used by :meth:`step`,
      :meth:`reset`, :meth:`get_state` and :meth:`set_state`;
    * ``self._batch_env`` — a ``BatchEnv`` of ``n_workers``
      ``ExternalDMControl`` instances, used by :meth:`step_batch`.

    Whenever the local environment's state changes through :meth:`reset` or
    :meth:`set_state`, that state is pushed to the batch (:meth:`sync_states`).
    Attributes not found on the wrapper are looked up on the local environment.
    """

    def __init__(self, name: str, n_repeat_action: int = 1, n_workers: int = 8,
                 blocking: bool = True, **kwargs):
        """Create the local environment and the batch of worker environments.

        :param name: Name of the environment. Passed to the base class and to
            every environment constructed here.
        :param n_repeat_action: Action-repeat value forwarded to every
            environment constructed here.
        :param n_workers: Number of ``ExternalDMControl`` instances handed to
            the ``BatchEnv``.
        :param blocking: Forwarded unchanged to ``BatchEnv``. Presumably
            selects whether batched steps wait for the workers — confirm
            against ``BatchEnv``.
        :param kwargs: Extra keyword arguments forwarded to every environment
            constructed here.
        """
        super().__init__(name=name)

        envs = [ExternalDMControl(name=name, n_repeat_action=n_repeat_action, **kwargs)
                for _ in range(n_workers)]
        self._batch_env = BatchEnv(envs, blocking)
        self._env = DMControlEnv(name, n_repeat_action=n_repeat_action, **kwargs)

        self.observation_space = self._env.observation_space
        self.action_space = self._env.action_space

    def __getattr__(self, item):
        """Forward attribute lookups that failed on the wrapper to ``self._env``."""
        # __getattr__ only runs after normal lookup fails. If ``_env`` itself
        # is missing (e.g. the instance was created without __init__, as
        # during unpickling), reading ``self._env`` below would re-enter this
        # method forever. Raise AttributeError instead.
        if item == "_env":
            raise AttributeError(item)
        return getattr(self._env, item)

    def step_batch(self, actions: np.ndarray, states: np.ndarray = None,
                   n_repeat_action: [np.ndarray, int] = None):
        """Step the worker batch and return whatever ``BatchEnv.step_batch`` returns.

        :param actions: Actions forwarded to the batch.
        :param states: Optional states forwarded to the batch.
        :param n_repeat_action: Optional action-repeat value(s) forwarded to
            the batch.
        """
        return self._batch_env.step_batch(actions=actions, states=states,
                                          n_repeat_action=n_repeat_action)

    def step(self, action: np.ndarray, state: np.ndarray = None, n_repeat_action: int = None):
        """Step the local environment and return its result unchanged."""
        return self._env.step(action=action, state=state, n_repeat_action=n_repeat_action)

    def reset(self, return_state: bool = True, blocking: bool = True):
        """Reset the local environment and push its new state to the workers.

        :param return_state: If True, return ``(state, obs)``; otherwise
            return only ``obs``.
        :param blocking: Accepted for interface compatibility; not used here.
        :return: ``(state, obs)`` or ``obs``, depending on ``return_state``.
        """
        # Always request the state: sync_states() needs the local state anyway.
        state, obs = self._env.reset(return_state=True)
        self.sync_states()
        # Parenthesised so that ``return_state=False`` really returns only obs.
        return (state, obs) if return_state else obs

    def get_state(self):
        """Return the local environment's state."""
        return self._env.get_state()

    def set_state(self, state):
        """Set the local environment's state, then push it to the workers."""
        self._env.set_state(state)
        self.sync_states()

    def sync_states(self):
        """Send the local environment's current state to the worker batch."""
        self._batch_env.sync_states(self.get_state())
예제 #3
0
    def __init__(self, name: str, n_repeat_action: int = 1, height: float = 100,
                 width: float = 100, wrappers=None, n_workers: int = 8,
                 blocking: bool = True, **kwargs):
        """Create a local retro environment plus a batch of worker copies.

        :param name: Name of the environment. Passed to the base class and to
            every environment constructed here.
        :param n_repeat_action: Action-repeat value forwarded to every
            environment constructed here.
        :param height: Forwarded to every environment constructed here.
            Presumably the observation height — confirm against
            ``RetroEnvironment``.
        :param width: Forwarded to every environment constructed here.
            Presumably the observation width — confirm against
            ``RetroEnvironment``.
        :param wrappers: Forwarded unchanged to every environment constructed
            here.
        :param n_workers: Number of ``ExternalRetro`` instances handed to the
            ``BatchEnv``.
        :param blocking: Forwarded unchanged to ``BatchEnv``. Presumably
            selects whether batched steps wait for the workers — confirm
            against ``BatchEnv``.
        :param kwargs: Extra keyword arguments forwarded to every environment
            constructed here.
        """
        super(ParallelRetro, self).__init__(name=name)

        # Keyword arguments shared by every worker and by the local env.
        env_kwargs = dict(n_repeat_action=n_repeat_action, height=height,
                          width=width, wrappers=wrappers, **kwargs)

        workers = [ExternalRetro(name=name, **env_kwargs) for _ in range(n_workers)]
        self._batch_env = BatchEnv(workers, blocking)

        # The local env is initialised before its spaces are read, presumably
        # because they are only available after init_env() — confirm.
        self._env = RetroEnvironment(name, **env_kwargs)
        self._env.init_env()
        self.observation_space = self._env.observation_space
        self.action_space = self._env.action_space