Example #1
0
    def _convert_observation(  # type: ignore[override]
            self, agent: str, observe: Union[dict, np.ndarray],
            done: bool) -> types.OLT:
        """Convert one raw environment observation for ``agent`` into an OLT.

        Args:
            agent: id of the agent whose observation is being converted.
            observe: either a dict holding ``"observation"`` and
                ``"action_mask"`` entries, or the bare observation array.
            done: whether the episode is over for ``agent``.

        Returns:
            A ``types.OLT`` holding the observation, the legal-action mask
            and a float32 ``terminal`` array ``[done]``.
        """
        if isinstance(observe, dict) and "action_mask" in observe:
            legals = observe["action_mask"]
            observation = observe["observation"]
        else:
            # No mask supplied: mark every action legal, using the shape of
            # the agent's action spec and the dtype of its action space.
            action_space = self._environment.action_spaces[agent]
            legals = np.ones(
                _convert_to_spec(action_space).shape,
                dtype=action_space.dtype,
            )
            observation = observe

        # Observations are not expected to be int8 downstream; use float32.
        if observation.dtype == np.int8:
            observation = observation.astype(np.float32)
        # int8 action masks are widened to int64.
        if legals.dtype == np.int8:
            legals = legals.astype(np.int64)

        return types.OLT(
            observation=observation,
            legal_actions=legals,
            terminal=np.asarray([done], dtype=np.float32),
        )
Example #2
0
 def observation_spec(self) -> types.Observation:
     """Return the per-agent OLT observation spec.

     The wrapped environment's ``observation_spec()`` is queried once. Its
     ``"info_state"`` and ``"legal_actions"`` entries are used as the shapes
     of the float32 observation and legal-action specs for every agent.

     Returns:
         Dict mapping each agent in ``self.possible_agents`` to a
         ``types.OLT`` of ``specs.Array``s, with a ``(1,)`` terminal spec.
     """
     # The environment-level spec does not depend on the agent, so fetch it
     # once instead of on every loop iteration.
     env_spec = self._environment.observation_spec()
     return {
         agent: types.OLT(
             observation=specs.Array(env_spec["info_state"], np.float32),
             legal_actions=specs.Array(env_spec["legal_actions"], np.float32),
             terminal=specs.Array((1, ), np.float32),
         )
         for agent in self.possible_agents
     }
Example #3
0
 def observation_spec(self) -> types.Observation:
     """Build the observation spec for every agent.

     All agents share one layout. The ``"observation"`` and ``"action_mask"``
     sub-spaces of ``self.observation_space`` become the observation and
     legal-action specs. A float32 terminal flag of shape ``(1,)`` is added.
     """
     per_agent_specs = {}
     for agent_id in self._possible_agents:
         obs_spec = _convert_to_spec(self.observation_space["observation"])
         mask_spec = _convert_to_spec(self.observation_space["action_mask"])
         per_agent_specs[agent_id] = types.OLT(
             observation=obs_spec,
             legal_actions=mask_spec,
             terminal=specs.Array((1, ), np.float32),
         )
     return per_agent_specs
Example #4
0
 def observation_spec(self) -> types.Observation:
     """Return per-agent OLT specs derived from the wrapped env's spaces.

     For each agent, the observation spec comes from its observation space
     and the legal-action spec from its action space. The terminal spec is a
     float32 array of shape ``(1,)``.
     """
     return {
         agent_id: types.OLT(
             observation=_convert_to_spec(
                 self._environment.observation_spaces[agent_id]),
             legal_actions=_convert_to_spec(
                 self._environment.action_spaces[agent_id]),
             terminal=specs.Array((1, ), np.float32),
         )
         for agent_id in self.possible_agents
     }
Example #5
0
    def _convert_observation(  # type: ignore[override]
            self, agent: str, observe: Union[dict, np.ndarray],
            done: bool) -> types.OLT:
        """Wrap a single raw observation into an OLT.

        A dict observation must carry ``"observation"`` and ``"action_mask"``
        keys. Anything else is treated as the bare observation, and all
        ``self.num_actions`` actions are marked legal.

        Returns:
            ``types.OLT`` with a float32 legal-action mask and a float32
            ``terminal`` array ``[done]``.
        """
        if not isinstance(observe, dict):
            legal_mask = np.ones(self.num_actions, np.float32)
            raw_obs = np.array(observe)
        else:
            legal_mask = np.array(observe["action_mask"], np.float32)
            raw_obs = np.array(observe["observation"])

        return types.OLT(
            observation=raw_obs,
            legal_actions=legal_mask,
            terminal=np.asarray([done], dtype=np.float32),
        )
Example #6
0
    def _convert_observations(self, observes: Dict[str, np.ndarray],
                              dones: Dict[str, bool]) -> types.Observation:
        """Convert raw per-agent observations into OLTs.

        Args:
            observes: raw observation per agent id. An entry may be a dict
                with ``"observation"`` and ``"action_mask"`` keys.
            dones: per-agent done flags, keyed by the same agent ids.

        Returns:
            Dict mapping agent id to ``types.OLT``. When an entry has no
            action mask, every action is marked legal. The mask uses the
            spec shape and dtype of ``self.action_space``.
        """
        converted: Dict[str, types.OLT] = {}
        for agent_id, raw in observes.items():
            has_mask = isinstance(raw, dict) and "action_mask" in raw
            if has_mask:
                legal_mask = raw["action_mask"]
                obs_value = raw["observation"]
            else:
                legal_mask = np.ones(
                    _convert_to_spec(self.action_space).shape,
                    dtype=self.action_space.dtype,
                )
                obs_value = raw
            converted[agent_id] = types.OLT(
                observation=obs_value,
                legal_actions=legal_mask,
                terminal=np.asarray([dones[agent_id]], dtype=np.float32),
            )
        return converted
Example #7
0
    def _convert_observations(self, observes: Dict[str, np.ndarray],
                              dones: Dict[str, bool]) -> types.Observation:
        """Convert raw per-agent observations into OLTs.

        Args:
            observes: raw observation per agent id. An entry may be a dict
                with ``"observation"`` and ``"action_mask"`` keys.
            dones: per-agent done flags, keyed by the same agent ids.

        Returns:
            Dict mapping agent id to ``types.OLT``. Entries without an action
            mask get an all-ones mask, shaped like that agent's action spec
            and typed like its action space.
        """
        converted: Dict[str, types.OLT] = {}
        for agent_id, raw in observes.items():
            if isinstance(raw, dict) and "action_mask" in raw:
                legal_mask = raw["action_mask"]
                obs_value = raw["observation"]
            else:
                # TODO Handle legal actions better for continuous envs,
                #  maybe have min and max for each action and clip the agents
                #  actions accordingly.
                agent_action_space = self._environment.action_spaces[agent_id]
                legal_mask = np.ones(
                    _convert_to_spec(agent_action_space).shape,
                    dtype=agent_action_space.dtype,
                )
                obs_value = raw

            converted[agent_id] = types.OLT(
                observation=obs_value,
                legal_actions=legal_mask,
                terminal=np.asarray([dones[agent_id]], dtype=np.float32),
            )
        return converted
Example #8
0
 def observation_spec(self) -> types.Observation:
     """Return per-agent OLT specs built from the wrapped env's spaces.

     If an agent's observation space is a ``gym.spaces.Dict``, its
     ``"observation"`` and ``"action_mask"`` entries give the observation and
     legal-action spaces. Otherwise the observation space itself and the
     agent's action space are used.

     The spaces are deep-copied before their dtype is adjusted, so the
     environment's own spaces are left untouched. int8 observation spaces
     become float32 and int8 legal-action spaces become int64.
     """
     result = {}
     for agent_id in self._environment.possible_agents:
         agent_obs_space = self._environment.observation_spaces[agent_id]
         if isinstance(agent_obs_space, gym.spaces.Dict):
             obs_space = copy.deepcopy(agent_obs_space["observation"])
             mask_space = copy.deepcopy(agent_obs_space["action_mask"])
         else:
             obs_space = copy.deepcopy(agent_obs_space)
             mask_space = copy.deepcopy(
                 self._environment.action_spaces[agent_id])

         # Adjust dtypes on the private copies only.
         if obs_space.dtype == np.int8:
             obs_space.dtype = np.dtype(np.float32)
         if mask_space.dtype == np.int8:
             mask_space.dtype = np.dtype(np.int64)

         result[agent_id] = types.OLT(
             observation=_convert_to_spec(obs_space),
             legal_actions=_convert_to_spec(mask_space),
             terminal=specs.Array((1, ), np.float32),
         )
     return result