Esempio n. 1
0
    def _unpack_observation(self, obs_batch):
        """Split a grouped observation batch into per-agent obs and masks.

        The batch is expanded with ``_unpack_obs`` against
        ``self.observation_space.original_space``. The resulting pieces are
        joined along axis 1 and reshaped so that the second dimension
        indexes agents.

        Args:
            obs_batch: batch of grouped observations. It is converted with
                ``np.array`` and its ``len`` is used as the batch size B.

        Returns:
            obs (np.ndarray): observations of shape [B, n_agents, obs_size]
            action_mask (np.ndarray): mask of shape [B, n_agents, n_actions];
                all ones when ``self.has_action_mask`` is false
        """
        agent_parts = _unpack_obs(
            np.array(obs_batch),
            self.observation_space.original_space,
            tensorlib=np)

        batch_size = len(obs_batch)
        obs_shape = [batch_size, self.n_agents, self.obs_size]
        mask_shape = [batch_size, self.n_agents, self.n_actions]

        if not self.has_action_mask:
            # The unpacked pieces are the observations themselves; with no
            # mask supplied, every action is marked as available.
            flat_obs = np.concatenate(agent_parts, axis=1)
            return flat_obs.reshape(obs_shape), np.ones(mask_shape)

        # Each unpacked piece is a mapping holding "obs" and "action_mask".
        flat_obs = np.concatenate(
            [part["obs"] for part in agent_parts], axis=1)
        obs = flat_obs.reshape(obs_shape)
        flat_mask = np.concatenate(
            [part["action_mask"] for part in agent_parts], axis=1)
        action_mask = flat_mask.reshape(mask_shape)
        return obs, action_mask
Esempio n. 2
0
    def _unpack_observation(self, obs_batch):
        """Unpack grouped agent observations into obs and action-mask arrays.

        ``_unpack_obs`` restores the structure described by
        ``self.observation_space.original_space``; the restored pieces are
        concatenated along axis 1 and reshaped to one row per agent.

        Args:
            obs_batch: batch of grouped observations; must support ``len``
                and conversion through ``np.array``.

        Returns:
            obs (np.ndarray): array of shape [B, n_agents, obs_size]
            action_mask (np.ndarray): array of shape
                [B, n_agents, n_actions]; filled with ones if the policy
                has no action mask
        """

        def per_agent(pieces, width):
            # Join along axis 1, then fold into [B, n_agents, width].
            joined = np.concatenate(pieces, axis=1)
            return joined.reshape([len(obs_batch), self.n_agents, width])

        pieces = _unpack_obs(
            np.array(obs_batch),
            self.observation_space.original_space,
            tensorlib=np)

        if self.has_action_mask:
            # Pieces are mappings carrying both the obs and its mask.
            obs = per_agent([p["obs"] for p in pieces], self.obs_size)
            action_mask = per_agent(
                [p["action_mask"] for p in pieces], self.n_actions)
            return obs, action_mask

        obs = per_agent(pieces, self.obs_size)
        # No mask available: treat every action as permitted.
        action_mask = np.ones(
            [len(obs_batch), self.n_agents, self.n_actions])
        return obs, action_mask
Esempio n. 3
0
 def _unpack_observation(self, obs_batch):
     """Unpack a grouped observation batch into obs and action mask.

     Args:
         obs_batch: batch of grouped observations; converted via
             ``np.array`` and measured with ``len`` for the batch size B.

     Returns:
         obs (np.ndarray): array of shape [B, n_agents, obs_size]
         action_mask (np.ndarray): array of shape
             [B, n_agents, n_actions]; all ones when
             ``self.has_action_mask`` is false
     """
     groups = _unpack_obs(np.array(obs_batch),
                          self.observation_space.original_space,
                          tensorlib=np)
     masked = self.has_action_mask

     # With a mask, each unpacked entry is a mapping and the observation
     # sits under "obs"; otherwise the entries are the observations.
     obs_pieces = [g["obs"] for g in groups] if masked else groups
     obs = np.concatenate(obs_pieces, axis=1).reshape(
         [len(obs_batch), self.n_agents, self.obs_size])

     mask_shape = [len(obs_batch), self.n_agents, self.n_actions]
     if masked:
         mask_pieces = [g["action_mask"] for g in groups]
         action_mask = np.concatenate(mask_pieces,
                                      axis=1).reshape(mask_shape)
     else:
         # No mask supplied: mark every action as available.
         action_mask = np.ones(mask_shape)
     return obs, action_mask
Esempio n. 4
0
    def _unpack_observation(self, obs_batch):
        """Unpack obs, action mask and optional global state from a batch.

        The batch is cast to float32 and expanded with ``_unpack_obs``
        against ``self.observation_space.original_space``. Per-agent pieces
        are concatenated along axis 1 and reshaped to one row per agent.

        Args:
            obs_batch: batch of grouped observations; converted via
                ``np.array`` and measured with ``len`` for the batch size B.

        Returns:
            obs (np.ndarray): array of shape [B, n_agents, obs_size]
            action_mask (np.ndarray): array of shape
                [B, n_agents, n_actions]; float32 ones when
                ``self.has_action_mask`` is false
            state (np.ndarray or None): the ``ENV_STATE`` entry of the first
                unpacked piece (expected shape [B, state_size]), or None
                when ``self.has_env_global_state`` is false
        """
        agent_parts = _unpack_obs(
            np.array(obs_batch, dtype=np.float32),
            self.observation_space.original_space,
            tensorlib=np)

        batch_size = len(obs_batch)
        obs_shape = [batch_size, self.n_agents, self.obs_size]
        mask_shape = [batch_size, self.n_agents, self.n_actions]

        if self.has_action_mask:
            obs_parts = [part["obs"] for part in agent_parts]
            obs = np.concatenate(obs_parts, axis=1).reshape(obs_shape)
            mask_parts = [part["action_mask"] for part in agent_parts]
            action_mask = np.concatenate(
                mask_parts, axis=1).reshape(mask_shape)
        else:
            # Pieces may still be mappings (e.g. when they also carry the
            # global state); in that case the observation is under "obs".
            obs_parts = (
                [part["obs"] for part in agent_parts]
                if isinstance(agent_parts[0], dict) else agent_parts)
            obs = np.concatenate(obs_parts, axis=1).reshape(obs_shape)
            # No mask supplied: mark every action as available.
            action_mask = np.ones(mask_shape, dtype=np.float32)

        # The global state is read from the first piece only.
        state = (
            agent_parts[0][ENV_STATE] if self.has_env_global_state else None)
        return obs, action_mask, state