Example #1
import numpy as np
from gym import spaces

from stable_baselines3.common.preprocessing import is_image_space


def maybe_transpose(observation: np.ndarray, observation_space: spaces.Space) -> np.ndarray:
    """
    Handle the different cases for images, as PyTorch uses the channel-first format.

    :param observation: the input observation
    :param observation_space: the observation space
    :return: channel-first observation if the observation is an image
    """
    # Avoid circular import
    from stable_baselines3.common.vec_env import VecTransposeImage

    if is_image_space(observation_space):
        if not (observation.shape == observation_space.shape or observation.shape[1:] == observation_space.shape):
            # Try to re-order the channels
            transpose_obs = VecTransposeImage.transpose_image(observation)
            if transpose_obs.shape == observation_space.shape or transpose_obs.shape[1:] == observation_space.shape:
                observation = transpose_obs
    return observation
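
A minimal usage sketch, assuming a channel-first 3x84x84 image space (as SB3 stores it after VecTransposeImage) and a channel-last raw observation; the shapes are illustrative:

import numpy as np
from gym import spaces

# Channel-first image space (C, H, W)
obs_space = spaces.Box(low=0, high=255, shape=(3, 84, 84), dtype=np.uint8)
# Channel-last observation (H, W, C), e.g. straight from an Atari env
obs = np.zeros((84, 84, 3), dtype=np.uint8)

obs = maybe_transpose(obs, obs_space)
print(obs.shape)  # (3, 84, 84)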
Example #2
    def predict(
        self,
        observation: np.ndarray,
        partner_idx: int = 0,
        state: Optional[np.ndarray] = None,
        mask: Optional[np.ndarray] = None,
        deterministic: bool = False,
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """
        Get the policy action and state from an observation (and optional state).
        Includes sugar-coating to handle different observations (e.g. normalizing images).

        :param observation: (np.ndarray) the input observation
        :param state: (Optional[np.ndarray]) The last states (can be None, used in recurrent policies)
        :param mask: (Optional[np.ndarray]) The last masks (can be None, used in recurrent policies)
        :param deterministic: (bool) Whether or not to return deterministic actions.
        :return: (Tuple[np.ndarray, Optional[np.ndarray]]) the model's action and the next state
            (used in recurrent policies)
        """
        # TODO (GH/1): add support for RNN policies
        # if state is None:
        #     state = self.initial_state
        # if mask is None:
        #     mask = [False for _ in range(self.n_envs)]
        observation = np.array(observation)

        # Handle the different cases for images,
        # as PyTorch uses the channel-first format
        if is_image_space(self.observation_space) and not (
            observation.shape == self.observation_space.shape
            or observation.shape[1:] == self.observation_space.shape
        ):
            # Try to re-order the channels (e.g. HxWxC -> CxHxW)
            transpose_obs = VecTransposeImage.transpose_image(observation)
            if (
                transpose_obs.shape == self.observation_space.shape
                or transpose_obs.shape[1:] == self.observation_space.shape
            ):
                observation = transpose_obs

        # Check whether the observation is batched (i.e. comes from a vectorized env)
        vectorized_env = is_vectorized_observation(observation, self.observation_space)

        # Add a batch dimension if it is missing
        observation = observation.reshape((-1,) + self.observation_space.shape)

        observation = th.as_tensor(observation).to(self.device)
        with th.no_grad():
            actions = self._predict(observation,
                                    partner_idx=partner_idx,
                                    deterministic=deterministic)
        # Convert to numpy
        actions = actions.cpu().numpy()

        if isinstance(self.action_space, gym.spaces.Box):
            if self.squash_output:
                # Rescale to proper domain when using squashing
                actions = self.unscale_action(actions)
            else:
                # Actions could be on arbitrary scale, so clip the actions to avoid
                # out of bound error (e.g. if sampling from a Gaussian distribution)
                actions = np.clip(actions, self.action_space.low,
                                  self.action_space.high)

        if not vectorized_env:
            if state is not None:
                raise ValueError(
                    "Error: The environment must be vectorized when using recurrent policies."
                )
            actions = actions[0]

        return actions, state
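
A usage sketch, assuming model is an agent exposing the predict method above and env is a single (non-vectorized) gym environment; both names are hypothetical stand-ins:

obs = env.reset()
done = False
while not done:
    # With a non-vectorized env, predict returns a single (unbatched) action
    action, _state = model.predict(obs, partner_idx=0, deterministic=True)
    obs, reward, done, info = env.step(action)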
Example #3
    def forward(self, obs):
        # Re-order channels to PyTorch's channel-first convention
        obs_transposed = VecTransposeImage.transpose_image(obs)
        # Encode the observation with the actor-critic's feature extractor
        latent, _, _ = self.ac_model._get_latent(
            th.tensor(obs_transposed).to(self.device))
        # Map the latent features to a scalar reward
        return self.reward_net(latent)
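
Note on the design: this forward pass reuses the policy's feature extractor (the private _get_latent helper of SB3's ActorCriticPolicy) as an encoder for a separate reward head, a common pattern in reward modeling; ac_model, reward_net, and device are assumed to be attributes defined elsewhere in the surrounding class.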