Example #1
 def send_actions(self, action_dict: MultiEnvDict) -> None:
     if self.multiagent:
         for env_id, actions in action_dict.items():
             self.external_env._episodes[env_id].action_queue.put(actions)
     else:
         for env_id, action in action_dict.items():
             self.external_env._episodes[env_id].action_queue.put(
                 action[_DUMMY_AGENT_ID])
Example #2
    def _space_contains(self, space: gym.Space, x: MultiEnvDict) -> bool:
        """Check if the given space contains the observations of x.

        Args:
            space: The space to check whether x's observations are contained in.
            x: The observations to check.

        Returns:
            True if the observations of x are contained in space.
        """
        agents = set(self.get_agent_ids())
        for multi_agent_dict in x.values():
            for agent_id, obs in multi_agent_dict.items():
                # This is the case where we have a single agent and are
                # checking a vector env that has been converted to a BaseEnv.
                if agent_id == _DUMMY_AGENT_ID:
                    if not space.contains(obs):
                        return False
                # for the MultiAgent env case
                elif (agent_id
                      not in agents) or (not space[agent_id].contains(obs)):
                    return False

        return True
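The check above expects a MultiEnvDict of the form {env_id: {agent_id: obs}}. Below is a minimal, illustrative sketch (names and spaces are assumptions, not from the source) showing that shape and a per-agent containment check against a gym.spaces.Dict, mirroring the multi-agent branch:

import gym
import numpy as np

agent_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)
space = gym.spaces.Dict({"agent_0": agent_space, "agent_1": agent_space})

# MultiEnvDict shape: {env_id: {agent_id: observation}} (illustrative values).
multi_env_obs = {
    0: {
        "agent_0": np.zeros(2, dtype=np.float32),
        "agent_1": np.ones(2, dtype=np.float32),
    },
}

# Per-agent containment check, mirroring the multi-agent branch above.
ok = all(
    agent_id in space.spaces and space[agent_id].contains(obs)
    for agent_dict in multi_env_obs.values()
    for agent_id, obs in agent_dict.items()
)
print(ok)  # True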
Example #3
    def send_actions(self, action_dict: MultiEnvDict) -> None:
        for env_id, agent_dict in action_dict.items():
            if env_id in self.dones:
                raise ValueError(
                    f"Env {env_id} is already done and cannot accept new actions"
                )
            env = self.envs[env_id]
            try:
                obs, rewards, dones, infos = env.step(agent_dict)
            except Exception as e:
                if self.restart_failed_sub_environments:
                    logger.exception(e.args[0])
                    self.try_restart(env_id=env_id)
                    obs, rewards, dones, infos = e, {}, {"__all__": True}, {}
                else:
                    raise e

            assert isinstance(
                obs, (dict, Exception)
            ), "Not a multi-agent obs dict or an Exception!"
            assert isinstance(rewards, dict), "Not a multi-agent reward dict!"
            assert isinstance(dones, dict), "Not a multi-agent done dict!"
            assert isinstance(infos, dict), "Not a multi-agent info dict!"
            if isinstance(obs, dict) and set(infos).difference(set(obs)):
                raise ValueError("Key set for infos must be a subset of obs: "
                                 "{} vs {}".format(infos.keys(), obs.keys()))
            if "__all__" not in dones:
                raise ValueError(
                    "In multi-agent environments, '__all__': True|False must "
                    "be included in the 'done' dict: got {}.".format(dones))

            if dones["__all__"]:
                self.dones.add(env_id)
            self.env_states[env_id].observe(obs, rewards, dones, infos)
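To make the error-handling path above concrete, here is a hedged, self-contained sketch (BrokenEnv and step_or_error are illustrative stand-ins, not part of the source) of how a failed sub-environment step can be surfaced as an Exception "observation" with dones["__all__"] set to True, so the poll() consumer can detect and restart it:

# Illustrative sketch of the failure path: when step() raises and restarts
# are enabled, the Exception is returned in place of the obs dict and the
# episode is marked done for all agents.
def step_or_error(env, action_dict, restart):
    try:
        return env.step(action_dict)
    except Exception as e:
        if not restart:
            raise
        return e, {}, {"__all__": True}, {}

class BrokenEnv:
    def step(self, action_dict):
        raise RuntimeError("sub-environment crashed")

obs, rewards, dones, infos = step_or_error(BrokenEnv(), {}, restart=True)
print(isinstance(obs, Exception), dones["__all__"])  # True True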
Example #4
 def send_actions(self, action_dict: MultiEnvDict) -> None:
     for env_id, agent_dict in action_dict.items():
         if env_id in self.dones:
             raise ValueError("Env {} is already done".format(env_id))
         env = self.envs[env_id]
         obs, rewards, dones, infos = env.step(agent_dict)
         assert isinstance(obs, dict), "Not a multi-agent obs"
         assert isinstance(rewards, dict), "Not a multi-agent reward"
         assert isinstance(dones, dict), "Not a multi-agent done"
         assert isinstance(infos, dict), "Not a multi-agent info"
         # Allow `__common__` entry in `infos` for data unrelated with any
         # agent, but rather with the environment itself.
         if set(infos).difference(set(obs) | {"__common__"}):
             raise ValueError(
                 "Key set for infos must be a subset of obs: "
                 "{} vs {}".format(infos.keys(), obs.keys())
             )
         if "__all__" not in dones:
             raise ValueError(
                 "In multi-agent environments, '__all__': True|False must "
                 "be included in the 'done' dict: got {}.".format(dones)
             )
         if dones["__all__"]:
             self.dones.add(env_id)
         self.env_states[env_id].observe(obs, rewards, dones, infos)
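For context, a hedged sketch of the per-env step() return this variant expects from a multi-agent env, including the optional "__common__" entry in infos for environment-level data; all keys and values below are illustrative assumptions:

# Illustrative multi-agent step() return.
obs = {"agent_0": [0.0], "agent_1": [1.0]}
rewards = {"agent_0": 1.0, "agent_1": -1.0}
dones = {"agent_0": False, "agent_1": False, "__all__": False}
infos = {"agent_0": {}, "__common__": {"sim_time": 12.5}}

# These pass the checks above: info keys are a subset of obs keys plus
# "__common__", and the "__all__" flag is present in dones.
assert not set(infos).difference(set(obs) | {"__common__"})
assert "__all__" in dones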
Example #5
 def send_actions(self, action_dict: MultiEnvDict) -> None:
     for env_id, actions in action_dict.items():
         actor = self.actors[env_id]
         # `actor` is a simple single-agent (remote) env, e.g. a gym.Env
         # that was made a @ray.remote.
         if not self.multiagent and self.make_env_creates_actors:
             obj_ref = actor.step.remote(actions[_DUMMY_AGENT_ID])
         # `actor` is already a _RemoteSingleAgentEnv or
         # _RemoteMultiAgentEnv wrapper
         # (handles the multi-agent action_dict automatically).
         else:
             obj_ref = actor.step.remote(actions)
         self.pending[obj_ref] = actor
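A minimal sketch of the fire-and-forget pattern used above, assuming Ray is installed; ToyEnvActor and the dict shapes are illustrative, not part of the source:

import ray

@ray.remote
class ToyEnvActor:
    # Stand-in for a remote (single-agent) env actor.
    def step(self, action):
        return {"obs": action, "reward": 0.0, "done": False, "info": {}}

ray.init(ignore_reinit_error=True)
actors = {0: ToyEnvActor.remote(), 1: ToyEnvActor.remote()}

# Kick off asynchronous step() calls and remember which actor each object
# ref belongs to, so results can later be matched back to their envs.
pending = {}
for env_id, action in {0: 1, 1: 2}.items():
    obj_ref = actors[env_id].step.remote(action)
    pending[obj_ref] = actors[env_id]

# A poll()-style consumer would wait for whichever refs finish first.
ready, _ = ray.wait(list(pending), num_returns=len(pending))
print(ray.get(ready))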
Example #6
    def _space_contains(self, space: gym.Space, x: MultiEnvDict) -> bool:
        """Check if the given space contains the observations of x.

        Args:
            space: The space to check whether x's observations are contained in.
            x: The observations to check.

        Returns:
            True if the observations of x are contained in space.
        """
        agents = set(self.get_agent_ids())
        for multi_agent_dict in x.values():
            for agent_id, obs in multi_agent_dict.items():
                if (agent_id not in agents) or (not space[agent_id].contains(obs)):
                    return False

        return True
Example #7
    def _space_contains(space: gym.Space, x: MultiEnvDict) -> bool:
        """Check if the given space contains the observations of x.

        Args:
            space: The space to check whether x's observations are contained in.
            x: The observations to check.

        Note: With vector envs, we can process the raw observations
            and ignore the agent and env IDs, since a vector env's
            sub-environments are guaranteed to be identical.

        Returns:
            True if the observations of x are contained in space.
        """
        for _, multi_agent_dict in x.items():
            for _, element in multi_agent_dict.items():
                if not space.contains(element):
                    return False
        return True
Example #8
    def _space_contains(space: gym.Space, x: MultiEnvDict) -> bool:
        """Check if the given space contains the observations of x.

        Args:
            space: The space to check whether x's observations are contained in.
            x: The observations to check.

        Returns:
            True if the observations of x are contained in space.
        """
        # This strips the agent_id keys and inner dicts from the
        # MultiEnvDict, leaving a list of observations per env_id.
        flattened_obs = {
            env_id: list(obs.values())
            for env_id, obs in x.items()
        }
        ret = True
        for env_id in flattened_obs:
            for obs in flattened_obs[env_id]:
                ret = ret and space[env_id].contains(obs)
        return ret
Example #9
 def send_actions(self, action_dict: MultiEnvDict) -> None:
     for env_id, agent_dict in action_dict.items():
         if env_id in self.dones:
             raise ValueError("Env {} is already done".format(env_id))
         env = self.envs[env_id]
         obs, rewards, dones, infos = env.step(agent_dict)
         assert isinstance(obs, dict), "Not a multi-agent obs"
         assert isinstance(rewards, dict), "Not a multi-agent reward"
         assert isinstance(dones, dict), "Not a multi-agent done"
         assert isinstance(infos, dict), "Not a multi-agent info"
         if set(obs.keys()) != set(rewards.keys()):
             raise ValueError(
                 "Key set for obs and rewards must be the same: "
                 "{} vs {}".format(obs.keys(), rewards.keys()))
         if set(infos).difference(set(obs)):
             raise ValueError("Key set for infos must be a subset of obs: "
                              "{} vs {}".format(infos.keys(), obs.keys()))
         if "__all__" not in dones:
             raise ValueError(
                 "In multi-agent environments, '__all__': True|False must "
                 "be included in the 'done' dict: got {}.".format(dones))
         if dones["__all__"]:
             self.dones.add(env_id)
         self.env_states[env_id].observe(obs, rewards, dones, infos)
Example #10
    def _process_observations(
        self,
        unfiltered_obs: MultiEnvDict,
        rewards: MultiEnvDict,
        dones: MultiEnvDict,
        infos: MultiEnvDict,
    ) -> Tuple[Dict[PolicyID, List[_PolicyEvalData]],
               List[Union[RolloutMetrics, SampleBatchType]]]:
        """Process raw obs from env.

        Group data for active agents by policy. Reset environments that are done.

        Args:
            unfiltered_obs: The raw observations from the sub-environments,
                keyed by env ID, then agent ID.
            rewards: The per-agent rewards, keyed by env ID, then agent ID.
            dones: The per-agent done flags (incl. "__all__"), keyed by env ID.
            infos: The per-agent info dicts, keyed by env ID, then agent ID.

        Returns:
            A tuple of:
                _PolicyEvalData for active agents for policy evaluation.
                SampleBatches and RolloutMetrics for completed agents for output.
        """
        # Output objects.
        to_eval: Dict[PolicyID, List[_PolicyEvalData]] = defaultdict(list)
        outputs: List[Union[RolloutMetrics, SampleBatchType]] = []

        # For each (vectorized) sub-environment.
        # types: EnvID, Dict[AgentID, EnvObsType]
        for env_id, env_obs in unfiltered_obs.items():
            # Check whether env_id returned an error instead of a multi-agent
            # obs dict. This is how our BaseEnv tells the caller of `poll()`
            # that one of its sub-environments is faulty and should be
            # restarted (and the ongoing episode should not be used for
            # training).
            if isinstance(env_obs, Exception):
                assert dones[env_id]["__all__"] is True, (
                    f"ERROR: When a sub-environment (env-id {env_id}) returns an error "
                    "as observation, the dones[__all__] flag must also be set to True!"
                )
                # all_agents_obs is an Exception here.
                # Drop this episode and skip to next.
                self.end_episode(env_id, env_obs)
                continue

            episode: EpisodeV2 = self._active_episodes[env_id]

            # Episode length after this step.
            # If this is a brand new episode, this step is adding init_obs.
            # So env_steps will stay at 0. Otherwise, env_steps will advance by 1.
            next_episode_length = episode.length + 1 if episode.has_init_obs else 0
            # Check episode termination conditions.
            if dones[env_id]["__all__"] or next_episode_length >= self._horizon:
                hit_horizon = (next_episode_length >= self._horizon
                               and not dones[env_id]["__all__"])
                all_agents_done = True
                # Add rollout metrics.
                outputs.extend(self._get_rollout_metrics(episode))
            else:
                hit_horizon = False
                all_agents_done = False

            # Special handling of common info dict.
            episode.set_last_info("__common__",
                                  infos[env_id].get("__common__", {}))

            # Agent sample batches grouped by policy. Each set of sample batches will
            # go through agent connectors together.
            sample_batches_by_policy = defaultdict(list)
            # Whether an agent is done, regardless of no_done_at_end or soft_horizon.
            agent_dones = {}
            for agent_id, obs in env_obs.items():
                assert agent_id != "__all__"

                policy_id: PolicyID = episode.policy_for(agent_id)

                agent_done = bool(all_agents_done
                                  or dones[env_id].get(agent_id))
                agent_dones[agent_id] = agent_done

                # A completely new agent is already done -> Skip entirely.
                if not episode.has_init_obs and agent_done:
                    continue

                values_dict = {
                    SampleBatch.T: episode.length - 1,
                    SampleBatch.ENV_ID: env_id,
                    SampleBatch.AGENT_INDEX: episode.agent_index(agent_id),
                    # Last action (SampleBatch.ACTIONS) column will be
                    # populated by StateBufferConnector.
                    # Reward received after taking action at timestep t.
                    SampleBatch.REWARDS: rewards[env_id].get(agent_id, 0.0),
                    # After taking action=a, did we reach terminal?
                    SampleBatch.DONES: (
                        False
                        if (self._no_done_at_end
                            or (hit_horizon and self._soft_horizon))
                        else agent_done
                    ),
                    SampleBatch.INFOS: infos[env_id].get(agent_id, {}),
                    SampleBatch.NEXT_OBS: obs,
                }

                # Queue this obs sample for connector preprocessing.
                sample_batches_by_policy[policy_id].append(
                    (agent_id, values_dict))

            # The entire episode is done.
            if all_agents_done:
                # Check whether any agents have not received their last "done"
                # observation yet. If so, create fake final observations for
                # them (the environment is not required to provide these when
                # dones["__all__"] == True).
                for agent_id in episode.get_agents():
                    # If the latest obs we got for this agent is done, or if its
                    # episode state is already done, nothing to do.
                    if agent_dones.get(agent_id,
                                       False) or episode.is_done(agent_id):
                        continue

                    policy_id: PolicyID = episode.policy_for(agent_id)
                    policy = self._worker.policy_map[policy_id]

                    # Create a fake (all-0s) observation.
                    obs_space = policy.observation_space
                    obs_space = getattr(obs_space, "original_space", obs_space)
                    values_dict = {
                        SampleBatch.T: episode.length - 1,
                        SampleBatch.ENV_ID: env_id,
                        SampleBatch.AGENT_INDEX: episode.agent_index(agent_id),
                        SampleBatch.REWARDS: 0.0,
                        SampleBatch.DONES: True,
                        SampleBatch.INFOS: {},
                        SampleBatch.NEXT_OBS: tree.map_structure(
                            np.zeros_like, obs_space.sample()
                        ),
                    }

                    # Queue these fake obs for connector preprocessing too.
                    sample_batches_by_policy[policy_id].append(
                        (agent_id, values_dict))

            # Run agent connectors.
            processed = []
            for policy_id, batches in sample_batches_by_policy.items():
                policy: Policy = self._worker.policy_map[policy_id]
                # Collected full MultiAgentDicts for this environment.
                # Run agent connectors.
                assert (policy.agent_connectors
                        ), "EnvRunnerV2 requires agent connectors to work."

                acd_list: List[AgentConnectorDataType] = [
                    AgentConnectorDataType(env_id, agent_id, data)
                    for agent_id, data in batches
                ]
                processed.extend(policy.agent_connectors(acd_list))

            is_initial_obs = not episode.has_init_obs
            for d in processed:
                # Record transition info if applicable.
                if is_initial_obs:
                    episode.add_init_obs(
                        d.agent_id,
                        d.data.for_training[SampleBatch.T],
                        d.data.for_training[SampleBatch.NEXT_OBS],
                    )
                else:
                    episode.add_action_reward_done_next_obs(
                        d.agent_id, d.data.for_training)

                if not agent_dones[d.agent_id]:
                    item = _PolicyEvalData(d.env_id, d.agent_id,
                                           d.data.for_action)
                    to_eval[policy_id].append(item)

            # Exception: The very first env.poll() call causes the env to get reset
            # (no step taken yet, just a single starting observation logged).
            # We need to skip this callback in this case.
            if not is_initial_obs:
                # Finished advancing episode by 1 step, mark it so.
                episode.step()

                # Invoke the `on_episode_step` callback after the step is logged
                # to the episode.
                self._callbacks.on_episode_step(
                    worker=self._worker,
                    base_env=self._base_env,
                    policies=self._worker.policy_map,
                    episode=episode,
                    env_index=env_id,
                )

            # Episode is done for all agents (dones[__all__] == True)
            # or we hit the horizon.
            if all_agents_done:
                is_done = dones[env_id]["__all__"]
                # _handle_done_episode will build a MultiAgentBatch for all
                # the agents that are done during this step of rollout in
                # the case of _multiple_episodes_in_batch=False.
                self._handle_done_episode(env_id, env_obs, is_done,
                                          hit_horizon, to_eval, outputs)

            # Try to build something.
            if self._multiple_episodes_in_batch:
                sample_batch = self._try_build_truncated_episode_multi_agent_batch(
                    self._batch_builders[env_id], episode)
                if sample_batch:
                    outputs.append(sample_batch)

                    # SampleBatch built from data collected by batch_builder.
                    # Clean up and delete the batch_builder.
                    del self._batch_builders[env_id]

        return to_eval, outputs
Example #11
 def action_space_contains(self, x: MultiEnvDict) -> bool:
     return all(self.envs[0].action_space_contains(val)
                for val in x.values())
Example #12
 def send_actions(self, action_dict: MultiEnvDict) -> None:
     for env_id, actions in action_dict.items():
         actor = self.actors[env_id]
         obj_ref = actor.step.remote(actions)
         self.pending[obj_ref] = actor