예제 #1
0
    def reset_wait(self, timeout=None):
        """Collect the results of a pending reset from every worker pipe.

        Parameters
        ----------
        timeout : int or float, optional
            Number of seconds handed to `self._poll` before the call to
            `reset_wait` times out. If `None`, the call never times out.

        Returns
        -------
        observations : sample from `observation_space`
            `self.observations`, deep-copied when `self.copy` is truthy.

        Raises
        ------
        NoAsyncCallError
            If the current state is not `AsyncState.WAITING_RESET`.
        mp.TimeoutError
            If `self._poll(timeout)` reports that the workers are not ready.
        """
        self._assert_is_running()
        if self._state != AsyncState.WAITING_RESET:
            message = ('Calling `reset_wait` without any prior '
                       'call to `reset_async`.')
            raise NoAsyncCallError(message, AsyncState.WAITING_RESET.value)

        ready = self._poll(timeout)
        if not ready:
            # Leave the waiting state before reporting the timeout.
            self._state = AsyncState.DEFAULT
            plural = 's' if timeout > 1 else ''
            raise mp.TimeoutError(
                'The call to `reset_wait` has timed out after '
                '{0} second{1}.'.format(timeout, plural))

        # Each pipe yields a (result, success) pair; split them column-wise.
        replies = [pipe.recv() for pipe in self.parent_pipes]
        results, successes = zip(*replies)
        self._raise_if_errors(successes)
        self._state = AsyncState.DEFAULT

        if not self.shared_memory:
            # The return value is discarded here; presumably `concatenate`
            # fills `self.observations` in place -- TODO confirm.
            concatenate(results, self.observations, self.single_observation_space)

        if self.copy:
            return deepcopy(self.observations)
        return self.observations
예제 #2
0
    def reset_wait(
        self,
        timeout: Optional[Union[int, float]] = None,
        seed: Optional[int] = None,
        return_info: bool = False,
        options: Optional[dict] = None,
    ) -> Union[ObsType, Tuple[ObsType, List[dict]]]:
        """Wait for the reset started by :meth:`reset_async` and return its results.

        Args:
            timeout: Number of seconds before the call to `reset_wait` times out. If `None`, the call to `reset_wait` never times out.
            seed: Not used by this method.
            return_info: When true, each worker reply is treated as an ``(observation, info)`` pair and the aggregated infos are returned alongside the observations.
            options: Not used by this method.

        Returns:
            The batched observations (deep-copied when ``self.copy`` is truthy). When ``return_info`` is true, a tuple of those observations and the ``infos`` object built by folding every worker's info through ``self._add_info``, starting from an empty dict.

        Raises:
            ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
            NoAsyncCallError: If :meth:`reset_wait` was called without any prior call to :meth:`reset_async`.
            TimeoutError: If :meth:`reset_wait` timed out.
        """
        self._assert_is_running()
        if self._state != AsyncState.WAITING_RESET:
            raise NoAsyncCallError(
                "Calling `reset_wait` without any prior call to `reset_async`.",
                AsyncState.WAITING_RESET.value,
            )

        ready = self._poll(timeout)
        if not ready:
            # Leave the waiting state before reporting the timeout.
            self._state = AsyncState.DEFAULT
            raise mp.TimeoutError(
                f"The call to `reset_wait` has timed out after {timeout} second(s)."
            )

        # Each pipe yields a (result, success) pair; split them column-wise.
        replies = [pipe.recv() for pipe in self.parent_pipes]
        results, successes = zip(*replies)
        self._raise_if_errors(successes)
        self._state = AsyncState.DEFAULT

        infos = {}
        if return_info:
            # Separate the observations from the per-environment infos.
            results, info_data = zip(*results)
            for env_index, env_info in enumerate(info_data):
                infos = self._add_info(infos, env_info, env_index)

        if not self.shared_memory:
            self.observations = concatenate(
                self.single_observation_space, results, self.observations
            )

        observations = deepcopy(self.observations) if self.copy else self.observations
        if return_info:
            return observations, infos
        return observations
예제 #3
0
    def call_wait(self, timeout=None):
        """Collect the results of a pending `call_async` from every worker pipe.

        Parameters
        ----------
        timeout : int or float, optional
            Number of seconds before the call to `call_wait` times out. If
            `None` (default), the call to `call_wait` never times out.

        Returns
        -------
        results : tuple
            One entry per pipe in `self.parent_pipes`, in pipe order: the
            result each worker sent back for the method or property call.

        Raises
        ------
        NoAsyncCallError
            If the current state is not `AsyncState.WAITING_CALL`.
        mp.TimeoutError
            If `self._poll(timeout)` reports that the workers are not ready.
        """
        self._assert_is_running()
        if self._state != AsyncState.WAITING_CALL:
            message = "Calling `call_wait` without any prior call to `call_async`."
            raise NoAsyncCallError(message, AsyncState.WAITING_CALL.value)

        ready = self._poll(timeout)
        if not ready:
            # Leave the waiting state before reporting the timeout.
            self._state = AsyncState.DEFAULT
            raise mp.TimeoutError(
                f"The call to `call_wait` has timed out after {timeout} second(s)."
            )

        # Each pipe yields a (result, success) pair; split them column-wise.
        replies = [pipe.recv() for pipe in self.parent_pipes]
        results, successes = zip(*replies)
        self._raise_if_errors(successes)
        self._state = AsyncState.DEFAULT

        return results
예제 #4
0
    def step_wait(
        self,
        timeout: Optional[Union[int, float]] = None
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[dict]]:
        """Wait for the calls to :obj:`step` in each sub-environment to finish.

        Args:
            timeout: Number of seconds before the call to :meth:`step_wait` times out. If ``None``, the call to :meth:`step_wait` never times out.

        Returns:
            A 4-tuple ``(observations, rewards, dones, infos)``: the batched observations (deep-copied when ``self.copy`` is truthy), an array of rewards, a boolean array of done flags, and the infos object built by folding each worker's info through ``self._add_info``, starting from an empty dict.

        Raises:
            ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
            NoAsyncCallError: If :meth:`step_wait` was called without any prior call to :meth:`step_async`.
            TimeoutError: If :meth:`step_wait` timed out.
        """
        self._assert_is_running()
        if self._state != AsyncState.WAITING_STEP:
            raise NoAsyncCallError(
                "Calling `step_wait` without any prior call "
                "to `step_async`.",
                AsyncState.WAITING_STEP.value,
            )

        if not self._poll(timeout):
            # Leave the waiting state before reporting the timeout.
            self._state = AsyncState.DEFAULT
            raise mp.TimeoutError(
                f"The call to `step_wait` has timed out after {timeout} second(s)."
            )

        observations_list, rewards, dones, infos = [], [], [], {}
        successes = []
        for i, pipe in enumerate(self.parent_pipes):
            # Every pipe is read, even after a failure, so no reply is left
            # queued for a later call.
            result, success = pipe.recv()
            successes.append(success)
            if not success:
                # A failed worker's payload is not guaranteed to be a
                # 4-tuple; unpacking it here could raise and mask the real
                # error, which `_raise_if_errors` reports below.
                continue

            obs, rew, done, info = result
            observations_list.append(obs)
            rewards.append(rew)
            dones.append(done)
            infos = self._add_info(infos, info, i)

        self._raise_if_errors(successes)
        self._state = AsyncState.DEFAULT

        if not self.shared_memory:
            self.observations = concatenate(
                self.single_observation_space,
                observations_list,
                self.observations,
            )

        return (
            deepcopy(self.observations) if self.copy else self.observations,
            np.array(rewards),
            np.array(dones, dtype=np.bool_),
            infos,
        )
예제 #5
0
    def step_wait(self, timeout=None):
        """Collect the results of a pending step from every worker pipe.

        Parameters
        ----------
        timeout : int or float, optional
            Number of seconds before the call to `step_wait` times out. If
            `None`, the call to `step_wait` never times out.

        Returns
        -------
        observations : sample from `observation_space`
            `self.observations`, deep-copied when `self.copy` is truthy.

        rewards : `np.ndarray` instance
            The per-environment rewards, converted with `np.array`.

        dones : `np.ndarray` instance (dtype `np.bool_`)
            The per-environment done flags.

        infos : tuple
            The per-environment info objects, in pipe order.

        Raises
        ------
        NoAsyncCallError
            If the current state is not `AsyncState.WAITING_STEP`.
        mp.TimeoutError
            If `self._poll(timeout)` reports that the workers are not ready.
        """
        self._assert_is_running()
        if self._state != AsyncState.WAITING_STEP:
            message = "Calling `step_wait` without any prior call to `step_async`."
            raise NoAsyncCallError(message, AsyncState.WAITING_STEP.value)

        ready = self._poll(timeout)
        if not ready:
            # Leave the waiting state before reporting the timeout.
            self._state = AsyncState.DEFAULT
            plural = "s" if timeout > 1 else ""
            raise mp.TimeoutError(
                "The call to `step_wait` has timed out after "
                "{0} second{1}.".format(timeout, plural))

        # Each pipe yields a (result, success) pair; split them column-wise.
        replies = [pipe.recv() for pipe in self.parent_pipes]
        results, successes = zip(*replies)
        self._raise_if_errors(successes)
        self._state = AsyncState.DEFAULT

        # Each result is an (observation, reward, done, info) 4-tuple.
        observations_list, rewards, dones, infos = zip(*results)

        if not self.shared_memory:
            self.observations = concatenate(
                observations_list, self.observations, self.single_observation_space
            )

        observations = deepcopy(self.observations) if self.copy else self.observations
        return observations, np.array(rewards), np.array(dones, dtype=np.bool_), infos
예제 #6
0
    def reset_wait(self, timeout=None):
        """Collect the results of a pending :obj:`reset` from every worker pipe.

        Parameters
        ----------
        timeout : int or float, optional
            Number of seconds before the call to :meth:`reset_wait` times out.
            If ``None``, the call to :meth:`reset_wait` never times out.

        Returns
        -------
        element of :attr:`~VectorEnv.observation_space`
            ``self.observations``, deep-copied when ``self.copy`` is truthy.

        Raises
        ------
        ClosedEnvironmentError
            If the environment was closed (if :meth:`close` was previously called).

        NoAsyncCallError
            If :meth:`reset_wait` was called without any prior call to
            :meth:`reset_async`.

        TimeoutError
            If :meth:`reset_wait` timed out.
        """
        self._assert_is_running()
        if self._state != AsyncState.WAITING_RESET:
            message = "Calling `reset_wait` without any prior call to `reset_async`."
            raise NoAsyncCallError(message, AsyncState.WAITING_RESET.value)

        ready = self._poll(timeout)
        if not ready:
            # Leave the waiting state before reporting the timeout.
            self._state = AsyncState.DEFAULT
            plural = "s" if timeout > 1 else ""
            raise mp.TimeoutError(
                f"The call to `reset_wait` has timed out after {timeout} second{plural}."
            )

        # Each pipe yields a (result, success) pair; split them column-wise.
        replies = [pipe.recv() for pipe in self.parent_pipes]
        results, successes = zip(*replies)
        self._raise_if_errors(successes)
        self._state = AsyncState.DEFAULT

        if not self.shared_memory:
            self.observations = concatenate(
                results, self.observations, self.single_observation_space
            )

        if self.copy:
            return deepcopy(self.observations)
        return self.observations
예제 #7
0
    def render_wait(self, timeout=None):
        """Collect the result of a pending render from the first worker pipe.

        Parameters
        ----------
        timeout : int or float, optional
            Value handed to `self._poll`; when the poll reports that the
            workers are not ready, the call times out.

        Returns
        -------
        result
            The payload received from `self.parent_pipes[0]`. Only the first
            pipe is read.

        Raises
        ------
        NoAsyncCallError
            If the current state is not `AsyncState.WAITING_RENDER`.
        mp.TimeoutError
            If `self._poll(timeout)` reports that the workers are not ready.
        """
        self._assert_is_running()
        if self._state.value != AsyncState.WAITING_RENDER.value:
            # Report the state this method actually expects (WAITING_RENDER);
            # the error previously carried WAITING_RESET's value.
            raise NoAsyncCallError(
                'Calling `render_wait` without any prior '
                'call to `render_async`.', AsyncState.WAITING_RENDER.value)

        if not self._poll(timeout):
            # Leave the waiting state before reporting the timeout.
            self._state = self.default_state
            raise mp.TimeoutError(
                'The call to `render_wait` has timed out after '
                '{0} second{1}.'.format(timeout, 's' if timeout > 1 else ''))

        result, success = self.parent_pipes[0].recv()
        self._raise_if_errors([success])
        self._state = self.default_state

        return result
예제 #8
0
    def step_wait(self, timeout=None):
        """Collect the results of a pending :obj:`step` from every worker pipe.

        Parameters
        ----------
        timeout : int or float, optional
            Number of seconds before the call to :meth:`step_wait` times out. If
            ``None``, the call to :meth:`step_wait` never times out.

        Returns
        -------
        observations : element of :attr:`~VectorEnv.observation_space`
            ``self.observations``, deep-copied when ``self.copy`` is truthy.

        rewards : :obj:`np.ndarray`
            The per-environment rewards, converted with ``np.array``.

        dones : :obj:`np.ndarray`, dtype :obj:`np.bool_`
            The per-environment done flags.

        infos : tuple
            The per-environment info objects, in pipe order.

        Raises
        ------
        ClosedEnvironmentError
            If the environment was closed (if :meth:`close` was previously called).

        NoAsyncCallError
            If :meth:`step_wait` was called without any prior call to
            :meth:`step_async`.

        TimeoutError
            If :meth:`step_wait` timed out.
        """
        self._assert_is_running()
        if self._state != AsyncState.WAITING_STEP:
            message = "Calling `step_wait` without any prior call to `step_async`."
            raise NoAsyncCallError(message, AsyncState.WAITING_STEP.value)

        ready = self._poll(timeout)
        if not ready:
            # Leave the waiting state before reporting the timeout.
            self._state = AsyncState.DEFAULT
            raise mp.TimeoutError(
                f"The call to `step_wait` has timed out after {timeout} second(s)."
            )

        # Each pipe yields a (result, success) pair; split them column-wise.
        replies = [pipe.recv() for pipe in self.parent_pipes]
        results, successes = zip(*replies)
        self._raise_if_errors(successes)
        self._state = AsyncState.DEFAULT

        # Each result is an (observation, reward, done, info) 4-tuple.
        observations_list, rewards, dones, infos = zip(*results)

        if not self.shared_memory:
            self.observations = concatenate(
                self.single_observation_space, observations_list, self.observations
            )

        observations = deepcopy(self.observations) if self.copy else self.observations
        return observations, np.array(rewards), np.array(dones, dtype=np.bool_), infos
예제 #9
0
    def reset_wait(
        self,
        timeout=None,
        seed: Optional[int] = None,
        return_info: bool = False,
        options: Optional[dict] = None,
    ):
        """Collect the results of a pending reset from every worker pipe.

        Parameters
        ----------
        timeout : int or float, optional
            Number of seconds before the call to `reset_wait` times out. If
            `None`, the call to `reset_wait` never times out.
        seed : not used by this method
        return_info : bool
            When true, each worker reply is treated as an
            ``(observation, info)`` pair and the infos are returned as well.
        options : not used by this method

        Returns
        -------
        element of :attr:`~VectorEnv.observation_space`
            ``self.observations``, deep-copied when ``self.copy`` is truthy.
        infos : list
            Only when ``return_info`` is true: the per-environment info
            objects, in pipe order, as the second element of a tuple.

        Raises
        ------
        ClosedEnvironmentError
            If the environment was closed (if :meth:`close` was previously called).

        NoAsyncCallError
            If :meth:`reset_wait` was called without any prior call to
            :meth:`reset_async`.

        TimeoutError
            If :meth:`reset_wait` timed out.
        """
        self._assert_is_running()
        if self._state != AsyncState.WAITING_RESET:
            message = "Calling `reset_wait` without any prior call to `reset_async`."
            raise NoAsyncCallError(message, AsyncState.WAITING_RESET.value)

        ready = self._poll(timeout)
        if not ready:
            # Leave the waiting state before reporting the timeout.
            self._state = AsyncState.DEFAULT
            raise mp.TimeoutError(
                f"The call to `reset_wait` has timed out after {timeout} second(s)."
            )

        # Each pipe yields a (result, success) pair; split them column-wise.
        replies = [pipe.recv() for pipe in self.parent_pipes]
        results, successes = zip(*replies)
        self._raise_if_errors(successes)
        self._state = AsyncState.DEFAULT

        infos = None
        if return_info:
            # Separate the observations from the per-environment infos.
            results, info_columns = zip(*results)
            infos = list(info_columns)

        if not self.shared_memory:
            self.observations = concatenate(
                self.single_observation_space, results, self.observations
            )

        observations = deepcopy(self.observations) if self.copy else self.observations
        if return_info:
            return observations, infos
        return observations