def _unpack_observation(self, obs_batch):
    """Split a grouped observation batch into observations and action mask.

    The batch is first restored to its original (pre-flattening) structure
    via ``_unpack_obs`` using ``self.observation_space.original_space``; the
    resulting pieces are then joined along axis 1 and reshaped.

    Args:
        obs_batch: Batch of flattened, grouped observations. Its length is
            used as the batch dimension ``B``.

    Returns:
        A tuple ``(obs, action_mask)``:
            obs: array reshaped to ``[B, n_agents, obs_size]``.
            action_mask: array of shape ``[B, n_agents, n_actions]``; taken
                from the ``"action_mask"`` entries when
                ``self.has_action_mask`` is set, otherwise all ones.
    """
    pieces = _unpack_obs(
        np.array(obs_batch),
        self.observation_space.original_space,
        tensorlib=np,
    )

    batch_size = len(obs_batch)
    obs_shape = [batch_size, self.n_agents, self.obs_size]
    mask_shape = [batch_size, self.n_agents, self.n_actions]

    if not self.has_action_mask:
        # No mask supplied: every piece is a plain obs array and every
        # action is treated as available.
        obs = np.concatenate(pieces, axis=1).reshape(obs_shape)
        return obs, np.ones(mask_shape)

    # Each piece is a mapping holding both "obs" and "action_mask".
    obs = np.concatenate(
        [piece["obs"] for piece in pieces], axis=1).reshape(obs_shape)
    action_mask = np.concatenate(
        [piece["action_mask"] for piece in pieces],
        axis=1).reshape(mask_shape)
    return obs, action_mask
def _unpack_observation(self, obs_batch):
    """Return ``(obs, action_mask)`` arrays extracted from ``obs_batch``.

    ``obs_batch`` is converted to an array and unpacked with ``_unpack_obs``
    against ``self.observation_space.original_space``. The unpacked parts
    are concatenated along axis 1 and reshaped so that:

    * ``obs`` has shape ``[len(obs_batch), n_agents, obs_size]``
    * ``action_mask`` has shape ``[len(obs_batch), n_agents, n_actions]``

    When ``self.has_action_mask`` is false the mask is filled with ones.
    """
    parts = _unpack_obs(
        np.array(obs_batch),
        self.observation_space.original_space,
        tensorlib=np)

    n_rows = len(obs_batch)
    obs_dims = [n_rows, self.n_agents, self.obs_size]
    mask_dims = [n_rows, self.n_agents, self.n_actions]

    if self.has_action_mask:
        # Parts are mappings; pull the two fields out separately.
        obs_parts = [p["obs"] for p in parts]
        joined_obs = np.concatenate(obs_parts, axis=1)
        obs = joined_obs.reshape(obs_dims)

        mask_parts = [p["action_mask"] for p in parts]
        joined_mask = np.concatenate(mask_parts, axis=1)
        action_mask = joined_mask.reshape(mask_dims)
    else:
        # Parts are already the raw observation arrays.
        joined_obs = np.concatenate(parts, axis=1)
        obs = joined_obs.reshape(obs_dims)
        action_mask = np.ones(mask_dims)

    return obs, action_mask
def _unpack_observation(self, obs_batch):
    """Extract observations, action mask and optional global state.

    The batch is cast to ``float32``, unpacked with ``_unpack_obs`` against
    ``self.observation_space.original_space``, and the unpacked parts are
    concatenated along axis 1 before being reshaped.

    Args:
        obs_batch: Batch of flattened, grouped observations; its length is
            the batch dimension ``B``.

    Returns:
        A tuple ``(obs, action_mask, state)``:
            obs (np.ndarray): shape ``[B, n_agents, obs_size]``.
            action_mask (np.ndarray): shape ``[B, n_agents, n_actions]``;
                all ones (``float32``) when ``self.has_action_mask`` is
                false.
            state (np.ndarray or None): the ``ENV_STATE`` entry of the
                first unpacked part when ``self.has_env_global_state`` is
                set, otherwise ``None``.
    """
    parts = _unpack_obs(
        np.array(obs_batch, dtype=np.float32),
        self.observation_space.original_space,
        tensorlib=np,
    )

    batch_size = len(obs_batch)
    obs_shape = [batch_size, self.n_agents, self.obs_size]
    mask_shape = [batch_size, self.n_agents, self.n_actions]

    if self.has_action_mask:
        obs = np.concatenate(
            [part["obs"] for part in parts], axis=1).reshape(obs_shape)
        action_mask = np.concatenate(
            [part["action_mask"] for part in parts],
            axis=1).reshape(mask_shape)
    else:
        # Parts may be dicts even without an action mask (they can still
        # carry other keys), so pick out "obs" only in that case.
        obs_parts = ([part["obs"] for part in parts]
                     if isinstance(parts[0], dict) else parts)
        obs = np.concatenate(obs_parts, axis=1).reshape(obs_shape)
        action_mask = np.ones(mask_shape, dtype=np.float32)

    # The global state is read from the first part only.
    state = parts[0][ENV_STATE] if self.has_env_global_state else None
    return obs, action_mask, state