Example #1
def _fetch_atari_metrics(base_env: BaseEnv) -> Optional[List[RolloutMetrics]]:
    """Atari games have multiple logical episodes, one per life.

    However, for metrics reporting we count full episodes, all lives included.
    """
    unwrapped = base_env.get_unwrapped()
    if not unwrapped:
        return None
    atari_out = []
    for u in unwrapped:
        monitor = get_wrapper_by_cls(u, MonitorEnv)
        if not monitor:
            return None
        for eps_rew, eps_len in monitor.next_episode_results():
            atari_out.append(RolloutMetrics(eps_len, eps_rew))
    return atari_out
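
# A minimal usage sketch (hypothetical helper, not part of the example above):
# averages the full-episode reward across all sub-environments. It assumes the
# RolloutMetrics namedtuple exposes the `episode_reward` field populated above.
def _mean_atari_reward(base_env: BaseEnv) -> Optional[float]:
    metrics = _fetch_atari_metrics(base_env)
    if not metrics:
        # Not a monitored Atari env (or no finished episodes yet).
        return None
    return sum(m.episode_reward for m in metrics) / len(metrics)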
Example #2
def _process_observations(
    *,
    worker: "RolloutWorker",
    base_env: BaseEnv,
    policies: Dict[PolicyID, Policy],
    active_episodes: Dict[str, MultiAgentEpisode],
    unfiltered_obs: Dict[EnvID, Dict[AgentID, EnvObsType]],
    rewards: Dict[EnvID, Dict[AgentID, float]],
    dones: Dict[EnvID, Dict[AgentID, bool]],
    infos: Dict[EnvID, Dict[AgentID, EnvInfoDict]],
    horizon: int,
    preprocessors: Dict[PolicyID, Preprocessor],
    obs_filters: Dict[PolicyID, Filter],
    multiple_episodes_in_batch: bool,
    callbacks: "DefaultCallbacks",
    soft_horizon: bool,
    no_done_at_end: bool,
    observation_fn: "ObservationFunction",
    sample_collector: SampleCollector,
) -> Tuple[Set[EnvID], Dict[PolicyID, List[PolicyEvalData]], List[Union[
        RolloutMetrics, SampleBatchType]]]:
    """Record new data from the environment and prepare for policy evaluation.

    Args:
        worker (RolloutWorker): Reference to the current rollout worker.
        base_env (BaseEnv): Env implementing BaseEnv.
        policies (dict): Map of policy ids to Policy instances.
        active_episodes (Dict[str, MultiAgentEpisode]): Mapping from
            episode ID to currently ongoing MultiAgentEpisode object.
        unfiltered_obs (dict): Doubly keyed dict of env-ids -> agent ids
            -> unfiltered observation tensor, returned by a `BaseEnv.poll()`
            call.
        rewards (dict): Doubly keyed dict of env-ids -> agent ids ->
            rewards tensor, returned by a `BaseEnv.poll()` call.
        dones (dict): Doubly keyed dict of env-ids -> agent ids ->
            boolean done flags, returned by a `BaseEnv.poll()` call.
        infos (dict): Doubly keyed dict of env-ids -> agent ids ->
            info dicts, returned by a `BaseEnv.poll()` call.
        horizon (int): Horizon of the episode.
        preprocessors (dict): Map of policy id to preprocessor for the
            observations prior to filtering.
        obs_filters (dict): Map of policy id to filter used to process
            observations for the policy.
        multiple_episodes_in_batch (bool): Whether to pack multiple
            episodes into each batch. This guarantees batches will be exactly
            `rollout_fragment_length` in size.
        callbacks (DefaultCallbacks): User callbacks to run on episode events.
        soft_horizon (bool): Calculate rewards but don't reset the
            environment when the horizon is hit.
        no_done_at_end (bool): Ignore the done=True at the end of the episode
            and instead record done=False.
        observation_fn (ObservationFunction): Optional multi-agent
            observation func to use for preprocessing observations.
        sample_collector (SampleCollector): The SampleCollector object
            used to store and retrieve environment samples.

    Returns:
        Tuple:
            - active_envs: Set of non-terminated env ids.
            - to_eval: Map of policy_id to list of agent PolicyEvalData.
            - outputs: List of metrics and samples to return from the sampler.
    """

    # Output objects.
    active_envs: Set[EnvID] = set()
    to_eval: Dict[PolicyID, List[PolicyEvalData]] = defaultdict(list)
    outputs: List[Union[RolloutMetrics, SampleBatchType]] = []

    # For each (vectorized) sub-environment.
    # type: EnvID, Dict[AgentID, EnvObsType]
    for env_id, all_agents_obs in unfiltered_obs.items():
        is_new_episode: bool = env_id not in active_episodes
        episode: MultiAgentEpisode = active_episodes[env_id]

        if not is_new_episode:
            sample_collector.episode_step(episode)
            episode._add_agent_rewards(rewards[env_id])

        # Check episode termination conditions.
        if dones[env_id]["__all__"] or episode.length >= horizon:
            hit_horizon = (episode.length >= horizon
                           and not dones[env_id]["__all__"])
            all_agents_done = True
            atari_metrics: List[RolloutMetrics] = _fetch_atari_metrics(
                base_env)
            if atari_metrics is not None:
                for m in atari_metrics:
                    outputs.append(
                        m._replace(custom_metrics=episode.custom_metrics))
            else:
                outputs.append(
                    RolloutMetrics(episode.length, episode.total_reward,
                                   dict(episode.agent_rewards),
                                   episode.custom_metrics, {},
                                   episode.hist_data, episode.media))
        else:
            hit_horizon = False
            all_agents_done = False
            active_envs.add(env_id)

        # Custom observation function is applied before preprocessing.
        if observation_fn:
            all_agents_obs: Dict[AgentID, EnvObsType] = observation_fn(
                agent_obs=all_agents_obs,
                worker=worker,
                base_env=base_env,
                policies=policies,
                episode=episode)
            if not isinstance(all_agents_obs, dict):
                raise ValueError(
                    "observe() must return a dict of agent observations")

        # For each agent in the environment.
        # type: AgentID, EnvObsType
        for agent_id, raw_obs in all_agents_obs.items():
            assert agent_id != "__all__"

            last_observation: EnvObsType = episode.last_observation_for(
                agent_id)
            agent_done = bool(all_agents_done or dones[env_id].get(agent_id))

            # A new agent (initial obs) is already done -> Skip entirely.
            if last_observation is None and agent_done:
                continue

            policy_id: PolicyID = episode.policy_for(agent_id)

            prep_obs: EnvObsType = _get_or_raise(preprocessors,
                                                 policy_id).transform(raw_obs)
            if log_once("prep_obs"):
                logger.info("Preprocessed obs: {}".format(summarize(prep_obs)))
            filtered_obs: EnvObsType = _get_or_raise(obs_filters,
                                                     policy_id)(prep_obs)
            if log_once("filtered_obs"):
                logger.info("Filtered obs: {}".format(summarize(filtered_obs)))

            episode._set_last_observation(agent_id, filtered_obs)
            episode._set_last_raw_obs(agent_id, raw_obs)
            # Infos from the environment.
            agent_infos = infos[env_id].get(agent_id, {})
            episode._set_last_info(agent_id, agent_infos)

            # Record transition info if applicable.
            if last_observation is None:
                sample_collector.add_init_obs(episode, agent_id, env_id,
                                              policy_id, episode.length - 1,
                                              filtered_obs)
            else:
                # Add actions, rewards, next-obs to collectors.
                values_dict = {
                    "t": episode.length - 1,
                    "env_id": env_id,
                    "agent_index": episode._agent_index(agent_id),
                    # Action (slot 0) taken at timestep t.
                    "actions": episode.last_action_for(agent_id),
                    # Reward received after taking action a at timestep t.
                    "rewards": rewards[env_id][agent_id],
                    # After taking action=a, did we reach a terminal state?
                    "dones": (False if
                              (no_done_at_end or
                               (hit_horizon and soft_horizon)) else agent_done),
                    # Next observation.
                    "new_obs": filtered_obs,
                }
                # Add extra-action-fetches to collectors.
                pol = policies[policy_id]
                for key, value in episode.last_pi_info_for(agent_id).items():
                    if key in pol.view_requirements:
                        values_dict[key] = value
                # Env infos for this agent.
                if "infos" in pol.view_requirements:
                    values_dict["infos"] = agent_infos
                sample_collector.add_action_reward_next_obs(
                    episode.episode_id, agent_id, env_id, policy_id,
                    agent_done, values_dict)

            if not agent_done:
                item = PolicyEvalData(
                    env_id, agent_id, filtered_obs, agent_infos,
                    None if last_observation is None else
                    episode.rnn_state_for(agent_id),
                    None if last_observation is None else
                    episode.last_action_for(agent_id),
                    rewards[env_id][agent_id] or 0.0)
                to_eval[policy_id].append(item)

        # Invoke the `on_episode_step` callback after the step is logged
        # to the episode.
        # Exception: The very first env.poll() call causes the env to get reset
        # (no step taken yet, just a single starting observation logged).
        # We need to skip this callback in this case.
        if episode.length > 0:
            callbacks.on_episode_step(worker=worker,
                                      base_env=base_env,
                                      episode=episode,
                                      env_index=env_id)

        # Episode is done for all agents (dones[__all__] == True)
        # or we hit the horizon.
        if all_agents_done:
            is_done = dones[env_id]["__all__"]
            check_dones = is_done and not no_done_at_end

            # If we are not allowed to pack the next episode into the same
            # SampleBatch (batch_mode=complete_episodes) -> Build the
            # MultiAgentBatch from a single episode and add it to "outputs".
            # Otherwise, just postprocess and continue collecting across
            # episodes.
            ma_sample_batch = sample_collector.postprocess_episode(
                episode,
                is_done=is_done or (hit_horizon and not soft_horizon),
                check_dones=check_dones,
                build=not multiple_episodes_in_batch)
            if ma_sample_batch:
                outputs.append(ma_sample_batch)

            # Call each policy's Exploration.on_episode_end method.
            for p in policies.values():
                if getattr(p, "exploration", None) is not None:
                    p.exploration.on_episode_end(policy=p,
                                                 environment=base_env,
                                                 episode=episode,
                                                 tf_sess=getattr(
                                                     p, "_sess", None))
            # Call custom on_episode_end callback.
            callbacks.on_episode_end(
                worker=worker,
                base_env=base_env,
                policies=policies,
                episode=episode,
                env_index=env_id,
            )
            # Horizon hit and we have a soft horizon (no hard env reset).
            if hit_horizon and soft_horizon:
                episode.soft_reset()
                resetted_obs: Dict[AgentID, EnvObsType] = all_agents_obs
            else:
                del active_episodes[env_id]
                resetted_obs: Dict[AgentID,
                                   EnvObsType] = base_env.try_reset(env_id)
            # Reset not supported, drop this env from the ready list.
            if resetted_obs is None:
                if horizon != float("inf"):
                    raise ValueError(
                        "Setting episode horizon requires reset() support "
                        "from the environment.")
            # Create a new episode if this is not an async reset return.
            # If the reset is async, we will get its result in some future poll.
            elif resetted_obs != ASYNC_RESET_RETURN:
                new_episode: MultiAgentEpisode = active_episodes[env_id]
                if observation_fn:
                    resetted_obs: Dict[AgentID, EnvObsType] = observation_fn(
                        agent_obs=resetted_obs,
                        worker=worker,
                        base_env=base_env,
                        policies=policies,
                        episode=new_episode)
                # type: AgentID, EnvObsType
                for agent_id, raw_obs in resetted_obs.items():
                    policy_id: PolicyID = new_episode.policy_for(agent_id)
                    prep_obs: EnvObsType = _get_or_raise(
                        preprocessors, policy_id).transform(raw_obs)
                    filtered_obs: EnvObsType = _get_or_raise(
                        obs_filters, policy_id)(prep_obs)
                    new_episode._set_last_observation(agent_id, filtered_obs)

                    # Add initial obs to buffer.
                    sample_collector.add_init_obs(new_episode, agent_id,
                                                  env_id, policy_id,
                                                  new_episode.length - 1,
                                                  filtered_obs)

                    item = PolicyEvalData(
                        env_id, agent_id, filtered_obs,
                        new_episode.last_info_for(agent_id) or {},
                        new_episode.rnn_state_for(agent_id), None, 0.0)
                    to_eval[policy_id].append(item)

    # Try to build something.
    if multiple_episodes_in_batch:
        sample_batches = \
            sample_collector.try_build_truncated_episode_multi_agent_batch()
        if sample_batches:
            outputs.extend(sample_batches)

    return active_envs, to_eval, outputs
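
# Hedged sketch (not RLlib's actual sampler loop): one way the `to_eval` map
# returned above could feed batched policy inference. It assumes the standard
# `Policy.compute_actions(obs_batch)` interface, which returns a tuple of
# (actions, rnn_state_outs, extra_action_fetches).
def _evaluate_pending(to_eval: Dict[PolicyID, List[PolicyEvalData]],
                      policies: Dict[PolicyID, Policy]) -> Dict[PolicyID, tuple]:
    eval_results: Dict[PolicyID, tuple] = {}
    for policy_id, eval_data in to_eval.items():
        # Batch all pending observations for this policy into one forward pass.
        obs_batch = [d.obs for d in eval_data]
        eval_results[policy_id] = policies[policy_id].compute_actions(obs_batch)
    return eval_results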
Example #3
def _process_observations(base_env, policies, batch_builder_pool,
                          active_episodes, unfiltered_obs, rewards, dones,
                          infos, off_policy_actions, horizon, preprocessors,
                          obs_filters, rollout_fragment_length, pack,
                          callbacks, soft_horizon, no_done_at_end):
    """Record new data from the environment and prepare for policy evaluation.

    Returns:
        active_envs: set of non-terminated env ids
        to_eval: map of policy_id to list of agent PolicyEvalData
        outputs: list of metrics and samples to return from the sampler
    """

    active_envs = set()
    to_eval = defaultdict(list)
    outputs = []
    large_batch_threshold = max(1000, rollout_fragment_length * 10) if \
        rollout_fragment_length != float("inf") else 5000

    # For each environment
    for env_id, agent_obs in unfiltered_obs.items():
        new_episode = env_id not in active_episodes
        episode = active_episodes[env_id]
        if not new_episode:
            episode.length += 1
            episode.batch_builder.count += 1
            episode._add_agent_rewards(rewards[env_id])

        if (episode.batch_builder.total() > large_batch_threshold
                and log_once("large_batch_warning")):
            logger.warning(
                "More than {} observations for {} env steps ".format(
                    episode.batch_builder.total(),
                    episode.batch_builder.count) + "are buffered in "
                "the sampler. If this is more than you expected, check "
                "that you set a horizon on your environment correctly and "
                "that it terminates at some point. "
                "Note: In multi-agent environments, `rollout_fragment_length` "
                "sets the batch size based on environment steps, not the "
                "steps of "
                "individual agents, which can result in unexpectedly large "
                "batches. Also, you may be in evaluation waiting for your Env "
                "to terminate (batch_mode=`complete_episodes`). Make sure it "
                "does at some point.")

        # Check episode termination conditions
        if dones[env_id]["__all__"] or episode.length >= horizon:
            hit_horizon = (episode.length >= horizon
                           and not dones[env_id]["__all__"])
            all_done = True
            atari_metrics = _fetch_atari_metrics(base_env)
            if atari_metrics is not None:
                for m in atari_metrics:
                    outputs.append(
                        m._replace(custom_metrics=episode.custom_metrics))
            else:
                outputs.append(
                    RolloutMetrics(episode.length, episode.total_reward,
                                   dict(episode.agent_rewards),
                                   episode.custom_metrics, {},
                                   episode.hist_data))
        else:
            hit_horizon = False
            all_done = False
            active_envs.add(env_id)

        # For each agent in the environment.
        for agent_id, raw_obs in agent_obs.items():
            policy_id = episode.policy_for(agent_id)
            prep_obs = _get_or_raise(preprocessors,
                                     policy_id).transform(raw_obs)
            if log_once("prep_obs"):
                logger.info("Preprocessed obs: {}".format(summarize(prep_obs)))

            filtered_obs = _get_or_raise(obs_filters, policy_id)(prep_obs)
            if log_once("filtered_obs"):
                logger.info("Filtered obs: {}".format(summarize(filtered_obs)))

            agent_done = bool(all_done or dones[env_id].get(agent_id))
            if not agent_done:
                to_eval[policy_id].append(
                    PolicyEvalData(env_id, agent_id, filtered_obs,
                                   infos[env_id].get(agent_id, {}),
                                   episode.rnn_state_for(agent_id),
                                   episode.last_action_for(agent_id),
                                   rewards[env_id][agent_id] or 0.0))

            last_observation = episode.last_observation_for(agent_id)
            episode._set_last_observation(agent_id, filtered_obs)
            episode._set_last_raw_obs(agent_id, raw_obs)
            episode._set_last_info(agent_id, infos[env_id].get(agent_id, {}))

            # Record transition info if applicable
            if (last_observation is not None and infos[env_id].get(
                    agent_id, {}).get("training_enabled", True)):
                episode.batch_builder.add_values(
                    agent_id,
                    policy_id,
                    t=episode.length - 1,
                    eps_id=episode.episode_id,
                    agent_index=episode._agent_index(agent_id),
                    obs=last_observation,
                    actions=episode.last_action_for(agent_id),
                    rewards=rewards[env_id][agent_id],
                    prev_actions=episode.prev_action_for(agent_id),
                    prev_rewards=episode.prev_reward_for(agent_id),
                    dones=(False if (no_done_at_end
                                     or (hit_horizon and soft_horizon)) else
                           agent_done),
                    infos=infos[env_id].get(agent_id, {}),
                    new_obs=filtered_obs,
                    **episode.last_pi_info_for(agent_id))

        # Invoke the step callback after the step is logged to the episode
        if callbacks.get("on_episode_step"):
            callbacks["on_episode_step"]({"env": base_env, "episode": episode})

        # Cut the batch if we're not packing multiple episodes into one,
        # or if we've exceeded the requested batch size.
        if episode.batch_builder.has_pending_agent_data():
            if dones[env_id]["__all__"] and not no_done_at_end:
                episode.batch_builder.check_missing_dones()
            if (all_done and not pack) or \
                    episode.batch_builder.count >= rollout_fragment_length:
                outputs.append(episode.batch_builder.build_and_reset(episode))
            elif all_done:
                # Make sure postprocessor stays within one episode
                episode.batch_builder.postprocess_batch_so_far(episode)

        if all_done:
            # Handle episode termination
            batch_builder_pool.append(episode.batch_builder)
            # Call each policy's Exploration.on_episode_end method.
            for p in policies.values():
                p.exploration.on_episode_end(
                    policy=p,
                    environment=base_env,
                    episode=episode,
                    tf_sess=getattr(p, "_sess", None))
            # Call custom on_episode_end callback.
            if callbacks.get("on_episode_end"):
                callbacks["on_episode_end"]({
                    "env": base_env,
                    "policy": policies,
                    "episode": episode
                })
            if hit_horizon and soft_horizon:
                episode.soft_reset()
                resetted_obs = agent_obs
            else:
                del active_episodes[env_id]
                resetted_obs = base_env.try_reset(env_id)
            if resetted_obs is None:
                # Reset not supported, drop this env from the ready list
                if horizon != float("inf"):
                    raise ValueError(
                        "Setting episode horizon requires reset() support "
                        "from the environment.")
            elif resetted_obs != ASYNC_RESET_RETURN:
                # Create a new episode if this is not an async reset return.
                # If the reset is async, we will get its result in some future poll.
                episode = active_episodes[env_id]
                for agent_id, raw_obs in resetted_obs.items():
                    policy_id = episode.policy_for(agent_id)
                    policy = _get_or_raise(policies, policy_id)
                    prep_obs = _get_or_raise(preprocessors,
                                             policy_id).transform(raw_obs)
                    filtered_obs = _get_or_raise(obs_filters,
                                                 policy_id)(prep_obs)
                    episode._set_last_observation(agent_id, filtered_obs)
                    to_eval[policy_id].append(
                        PolicyEvalData(
                            env_id, agent_id, filtered_obs,
                            episode.last_info_for(agent_id) or {},
                            episode.rnn_state_for(agent_id),
                            np.zeros_like(
                                _flatten_action(policy.action_space.sample())),
                            0.0))

    return active_envs, to_eval, outputs
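
# Simplified restatement of the batch-cut rule above as a pure predicate
# (hypothetical helper): a batch is emitted when the episode is fully done and
# packing is disabled, or once `rollout_fragment_length` env steps are buffered.
def _should_cut_batch(all_done, pack, builder_count, rollout_fragment_length):
    return (all_done and not pack) or builder_count >= rollout_fragment_length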
Example #4
def _process_observations(
        worker: "RolloutWorker", base_env: BaseEnv,
        policies: Dict[PolicyID, Policy],
        batch_builder_pool: List[MultiAgentSampleBatchBuilder],
        active_episodes: Dict[str, MultiAgentEpisode],
        unfiltered_obs: Dict[EnvID, Dict[AgentID, EnvObsType]],
        rewards: Dict[EnvID, Dict[AgentID, float]],
        dones: Dict[EnvID, Dict[AgentID, bool]],
        infos: Dict[EnvID, Dict[AgentID, EnvInfoDict]], horizon: int,
        preprocessors: Dict[PolicyID, Preprocessor],
        obs_filters: Dict[PolicyID, Filter], rollout_fragment_length: int,
        pack_multiple_episodes_in_batch: bool, callbacks: "DefaultCallbacks",
        soft_horizon: bool, no_done_at_end: bool,
        observation_fn: "ObservationFunction",
        _use_trajectory_view_api: bool = False
) -> Tuple[Set[EnvID], Dict[PolicyID, List[PolicyEvalData]], List[Union[
        RolloutMetrics, SampleBatchType]]]:
    """Record new data from the environment and prepare for policy evaluation.

    Args:
        worker (RolloutWorker): Reference to the current rollout worker.
        base_env (BaseEnv): Env implementing BaseEnv.
        policies (dict): Map of policy ids to Policy instances.
        batch_builder_pool (List[SampleBatchBuilder]): List of pooled
            SampleBatchBuilder objects for recycling.
        active_episodes (Dict[str, MultiAgentEpisode]): Mapping from
            episode ID to currently ongoing MultiAgentEpisode object.
        unfiltered_obs (dict): Doubly keyed dict of env-ids -> agent ids ->
            unfiltered observation tensor, returned by a `BaseEnv.poll()` call.
        rewards (dict): Doubly keyed dict of env-ids -> agent ids ->
            rewards tensor, returned by a `BaseEnv.poll()` call.
        dones (dict): Doubly keyed dict of env-ids -> agent ids ->
            boolean done flags, returned by a `BaseEnv.poll()` call.
        infos (dict): Doubly keyed dict of env-ids -> agent ids ->
            info dicts, returned by a `BaseEnv.poll()` call.
        horizon (int): Horizon of the episode.
        preprocessors (dict): Map of policy id to preprocessor for the
            observations prior to filtering.
        obs_filters (dict): Map of policy id to filter used to process
            observations for the policy.
        rollout_fragment_length (int): Number of episode steps before
            `SampleBatch` is yielded. Set to infinity to yield complete
            episodes.
        pack_multiple_episodes_in_batch (bool): Whether to pack multiple
            episodes into each batch. This guarantees batches will be exactly
            `rollout_fragment_length` in size.
        callbacks (DefaultCallbacks): User callbacks to run on episode events.
        soft_horizon (bool): Calculate rewards but don't reset the
            environment when the horizon is hit.
        no_done_at_end (bool): Ignore the done=True at the end of the episode
            and instead record done=False.
        observation_fn (ObservationFunction): Optional multi-agent
            observation func to use for preprocessing observations.
        _use_trajectory_view_api (bool): Whether to use the (experimental)
            `_use_trajectory_view_api` to make generic trajectory views
            available to Models. Default: False.

    Returns:
        Tuple:
            - active_envs: Set of non-terminated env ids.
            - to_eval: Map of policy_id to list of agent PolicyEvalData.
            - outputs: List of metrics and samples to return from the sampler.
    """

    # Output objects.
    active_envs: Set[EnvID] = set()
    to_eval: Dict[PolicyID, List[PolicyEvalData]] = defaultdict(list)
    outputs: List[Union[RolloutMetrics, SampleBatchType]] = []

    large_batch_threshold: int = max(1000, rollout_fragment_length * 10) if \
        rollout_fragment_length != float("inf") else 5000

    # For each environment.
    # type: EnvID, Dict[AgentID, EnvObsType]
    for env_id, agent_obs in unfiltered_obs.items():
        is_new_episode: bool = env_id not in active_episodes
        episode: MultiAgentEpisode = active_episodes[env_id]
        if not is_new_episode:
            episode.length += 1
            episode.batch_builder.count += 1
            episode._add_agent_rewards(rewards[env_id])

        if (episode.batch_builder.total() > large_batch_threshold
                and log_once("large_batch_warning")):
            logger.warning(
                "More than {} observations for {} env steps ".format(
                    episode.batch_builder.total(),
                    episode.batch_builder.count) + "are buffered in "
                "the sampler. If this is more than you expected, check "
                "that you set a horizon on your environment correctly and "
                "that it terminates at some point. "
                "Note: In multi-agent environments, `rollout_fragment_length` "
                "sets the batch size based on environment steps, not the "
                "steps of "
                "individual agents, which can result in unexpectedly large "
                "batches. Also, you may be in evaluation waiting for your Env "
                "to terminate (batch_mode=`complete_episodes`). Make sure it "
                "does at some point.")

        # Check episode termination conditions.
        if dones[env_id]["__all__"] or episode.length >= horizon:
            hit_horizon = (episode.length >= horizon
                           and not dones[env_id]["__all__"])
            all_agents_done = True
            atari_metrics: List[RolloutMetrics] = _fetch_atari_metrics(
                base_env)
            if atari_metrics is not None:
                for m in atari_metrics:
                    outputs.append(
                        m._replace(custom_metrics=episode.custom_metrics))
            else:
                outputs.append(
                    RolloutMetrics(episode.length, episode.total_reward,
                                   dict(episode.agent_rewards),
                                   episode.custom_metrics, {},
                                   episode.hist_data))
        else:
            hit_horizon = False
            all_agents_done = False
            active_envs.add(env_id)

        # Custom observation function is applied before preprocessing.
        if observation_fn:
            agent_obs: Dict[AgentID, EnvObsType] = observation_fn(
                agent_obs=agent_obs,
                worker=worker,
                base_env=base_env,
                policies=policies,
                episode=episode)
            if not isinstance(agent_obs, dict):
                raise ValueError(
                    "observe() must return a dict of agent observations")

        # For each agent in the environment.
        # type: AgentID, EnvObsType
        for agent_id, raw_obs in agent_obs.items():
            assert agent_id != "__all__"
            policy_id: PolicyID = episode.policy_for(agent_id)
            prep_obs: EnvObsType = _get_or_raise(preprocessors,
                                                 policy_id).transform(raw_obs)
            if log_once("prep_obs"):
                logger.info("Preprocessed obs: {}".format(summarize(prep_obs)))

            filtered_obs: EnvObsType = _get_or_raise(obs_filters,
                                                     policy_id)(prep_obs)
            if log_once("filtered_obs"):
                logger.info("Filtered obs: {}".format(summarize(filtered_obs)))

            agent_done = bool(all_agents_done or dones[env_id].get(agent_id))
            if not agent_done:
                to_eval[policy_id].append(
                    PolicyEvalData(env_id, agent_id, filtered_obs,
                                   infos[env_id].get(agent_id, {}),
                                   episode.rnn_state_for(agent_id),
                                   episode.last_action_for(agent_id),
                                   rewards[env_id][agent_id] or 0.0))

            last_observation: EnvObsType = episode.last_observation_for(
                agent_id)
            episode._set_last_observation(agent_id, filtered_obs)
            episode._set_last_raw_obs(agent_id, raw_obs)
            episode._set_last_info(agent_id, infos[env_id].get(agent_id, {}))

            # Record transition info if applicable.
            if (last_observation is not None and infos[env_id].get(
                    agent_id, {}).get("training_enabled", True)):
                episode.batch_builder.add_values(
                    agent_id,
                    policy_id,
                    t=episode.length - 1,
                    eps_id=episode.episode_id,
                    agent_index=episode._agent_index(agent_id),
                    obs=last_observation,
                    actions=episode.last_action_for(agent_id),
                    rewards=rewards[env_id][agent_id],
                    prev_actions=episode.prev_action_for(agent_id),
                    prev_rewards=episode.prev_reward_for(agent_id),
                    dones=(False if (no_done_at_end
                                     or (hit_horizon and soft_horizon)) else
                           agent_done),
                    infos=infos[env_id].get(agent_id, {}),
                    new_obs=filtered_obs,
                    **episode.last_pi_info_for(agent_id))

        # Invoke the step callback after the step is logged to the episode
        callbacks.on_episode_step(
            worker=worker, base_env=base_env, episode=episode)

        # Cut the batch if ...
        # - all-agents-done and not packing multiple episodes into one
        #   (batch_mode="complete_episodes")
        # - or if we've exceeded the rollout_fragment_length.
        if episode.batch_builder.has_pending_agent_data():
            # Sanity check, whether all agents have done=True, if done[__all__]
            # is True.
            if dones[env_id]["__all__"] and not no_done_at_end:
                episode.batch_builder.check_missing_dones()

            # Reached end of episode and we are not allowed to pack the
            # next episode into the same SampleBatch -> Build the SampleBatch
            # and add it to "outputs".
            if (all_agents_done and not pack_multiple_episodes_in_batch) or \
                    episode.batch_builder.count >= rollout_fragment_length:
                outputs.append(episode.batch_builder.build_and_reset(episode))
            # Make sure postprocessor stays within one episode.
            elif all_agents_done:
                episode.batch_builder.postprocess_batch_so_far(episode)

        # Episode is done.
        if all_agents_done:
            # Handle episode termination.
            batch_builder_pool.append(episode.batch_builder)
            # Call each policy's Exploration.on_episode_end method.
            for p in policies.values():
                if getattr(p, "exploration", None) is not None:
                    p.exploration.on_episode_end(
                        policy=p,
                        environment=base_env,
                        episode=episode,
                        tf_sess=getattr(p, "_sess", None))
            # Call custom on_episode_end callback.
            callbacks.on_episode_end(
                worker=worker,
                base_env=base_env,
                policies=policies,
                episode=episode)
            if hit_horizon and soft_horizon:
                episode.soft_reset()
                resetted_obs: Dict[AgentID, EnvObsType] = agent_obs
            else:
                del active_episodes[env_id]
                resetted_obs: Dict[AgentID, EnvObsType] = base_env.try_reset(
                    env_id)
            if resetted_obs is None:
                # Reset not supported, drop this env from the ready list.
                if horizon != float("inf"):
                    raise ValueError(
                        "Setting episode horizon requires reset() support "
                        "from the environment.")
            elif resetted_obs != ASYNC_RESET_RETURN:
                # Create a new episode if this is not an async reset return.
                # If the reset is async, we will get its result in some future poll.
                episode: MultiAgentEpisode = active_episodes[env_id]
                if observation_fn:
                    resetted_obs: Dict[AgentID, EnvObsType] = observation_fn(
                        agent_obs=resetted_obs,
                        worker=worker,
                        base_env=base_env,
                        policies=policies,
                        episode=episode)
                # type: AgentID, EnvObsType
                for agent_id, raw_obs in resetted_obs.items():
                    policy_id: PolicyID = episode.policy_for(agent_id)
                    policy: Policy = _get_or_raise(policies, policy_id)
                    prep_obs: EnvObsType = _get_or_raise(
                        preprocessors, policy_id).transform(raw_obs)
                    filtered_obs: EnvObsType = _get_or_raise(
                        obs_filters, policy_id)(prep_obs)
                    episode._set_last_observation(agent_id, filtered_obs)
                    to_eval[policy_id].append(
                        PolicyEvalData(
                            env_id, agent_id, filtered_obs,
                            episode.last_info_for(agent_id) or {},
                            episode.rnn_state_for(agent_id),
                            np.zeros_like(
                                flatten_to_single_ndarray(
                                    policy.action_space.sample())), 0.0))

    return active_envs, to_eval, outputs
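
# The buffering threshold used in the large-batch warning above, factored out
# as a hypothetical helper: warn once more than max(1000, 10 * fragment length)
# agent steps are buffered, or more than 5000 when fragments are unbounded
# (batch_mode="complete_episodes").
def _large_batch_threshold(rollout_fragment_length) -> int:
    if rollout_fragment_length == float("inf"):
        return 5000
    return max(1000, rollout_fragment_length * 10)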
Example #5
def _process_observations(base_env, policies, batch_builder_pool,
                          active_episodes, unfiltered_obs, rewards, dones,
                          infos, off_policy_actions, horizon, preprocessors,
                          obs_filters, unroll_length, pack, callbacks,
                          soft_horizon, no_done_at_end):
    """Record new data from the environment and prepare for policy evaluation.

    Returns:
        active_envs: set of non-terminated env ids
        to_eval: map of policy_id to list of agent PolicyEvalData
        outputs: list of metrics and samples to return from the sampler
    """
    global i
    global tmp_dic
    global traffic_light_node_dict
    i += 1

    def inter_num_2_id(num):
        # Reverse lookup: map an intersection index back to its id key.
        return list(tmp_dic.keys())[list(tmp_dic.values()).index(num)]

    def read_traffic_light_node_dict():
        path_to_read = os.path.join(record_dir, 'traffic_light_node_dict.conf')
        with open(path_to_read, 'r') as f:
            # NOTE: eval() on file contents is unsafe; a stricter parser such
            # as ast.literal_eval would be safer if the file only contains
            # Python literals.
            traffic_light_node_dict = eval(f.read())
            print("Read traffic_light_node_dict")
            return traffic_light_node_dict

    if i <= 1:
        # Used here to read in the neighbor relations from the configuration.
        record_dir = base_env.envs[0].record_dir
        traffic_light_node_dict = base_env.envs[0].traffic_light_node_dict
        tmp_dic = traffic_light_node_dict['intersection_1_1'][
            'inter_id_to_index']

    active_envs = set()
    to_eval = defaultdict(list)
    outputs = []

    # For each environment
    for env_id, agent_obs in unfiltered_obs.items():
        new_episode = env_id not in active_episodes
        episode = active_episodes[env_id]
        if not new_episode:
            episode.length += 1
            episode.batch_builder.count += 1
            episode._add_agent_rewards(rewards[env_id])

        if (episode.batch_builder.total() > max(1000, unroll_length * 10)
                and log_once("large_batch_warning")):
            logger.warning(
                "More than {} observations for {} env steps ".format(
                    episode.batch_builder.total(),
                    episode.batch_builder.count) + "are buffered in "
                "the sampler. If this is more than you expected, check "
                "that you set a horizon on your environment correctly. Note "
                "that in multi-agent environments, `sample_batch_size` sets "
                "the batch size based on environment steps, not the steps of "
                "individual agents, which can result in unexpectedly large "
                "batches.")

        # Check episode termination conditions
        if dones[env_id]["__all__"] or episode.length >= horizon:
            hit_horizon = (episode.length >= horizon
                           and not dones[env_id]["__all__"])
            all_done = True
            atari_metrics = _fetch_atari_metrics(base_env)
            if atari_metrics is not None:
                for m in atari_metrics:
                    outputs.append(
                        m._replace(custom_metrics=episode.custom_metrics))
            else:
                outputs.append(
                    RolloutMetrics(episode.length, episode.total_reward,
                                   dict(episode.agent_rewards),
                                   episode.custom_metrics, {}))
        else:
            hit_horizon = False
            all_done = False
            active_envs.add(env_id)

        # For each agent in the environment
        for agent_id, raw_obs in agent_obs.items():
            policy_id = episode.policy_for(agent_id)  # eg: "policy_0"
            # print(policy_id)
            prep_obs = _get_or_raise(preprocessors,
                                     policy_id).transform(raw_obs)
            if log_once("prep_obs"):
                logger.info("Preprocessed obs: {}".format(summarize(prep_obs)))

            filtered_obs = _get_or_raise(obs_filters, policy_id)(prep_obs)
            '''
            NOTE (attention):
            A real-time Q eval is executed here, so the neighbor_obs values
            must be passed to the Q eval network.
            '''
            # Using the road-network relations in the traffic_light_node_dict,
            # find the neighbors of the current policy_id and store them in
            # the "policy_0" format.
            neighbor_pid_list = [
                'policy_{}'.format(pid_)
                for pid_ in traffic_light_node_dict[inter_num_2_id(
                    int(policy_id.split('_')[1]))]['adjacency_row']
                if pid_ is not None
            ]
            # print(neighbor_pid_list)
            neighbor_obs = [[]]

            # Size: (1, 5, 15). Only this shape can be fed into the
            # neighbor_obs placeholder of shape (batch, 5, 15).
            for neighbor_id in neighbor_pid_list:
                # Run this agent's raw obs through each neighbor's
                # preprocessor and filter.
                neighbor_prep_obs = _get_or_raise(
                    preprocessors, neighbor_id).transform(raw_obs)
                neighbor_filtered_obs = _get_or_raise(
                    obs_filters, neighbor_id)(neighbor_prep_obs)
                neighbor_obs[0].append(neighbor_filtered_obs)
            neighbor_obs = np.array(neighbor_obs).reshape(
                (len(neighbor_pid_list), len(raw_obs)))  # (5, 29)

            # ------------------------------------------------------------------
            if log_once("filtered_obs"):
                logger.info("Filtered obs: {}".format(summarize(filtered_obs)))

            agent_done = bool(all_done or dones[env_id].get(agent_id))
            if not agent_done:
                to_eval[policy_id].append(
                    PolicyEvalData(env_id, agent_id, filtered_obs,
                                   neighbor_obs,
                                   infos[env_id].get(agent_id, {}),
                                   episode.rnn_state_for(agent_id),
                                   episode.last_action_for(agent_id),
                                   rewards[env_id][agent_id] or 0.0))

            last_observation = episode.last_observation_for(agent_id)
            episode._set_last_observation(agent_id, filtered_obs)
            episode._set_last_raw_obs(agent_id, raw_obs)
            episode._set_last_info(agent_id, infos[env_id].get(agent_id, {}))

            # Record transition info if applicable
            if (last_observation is not None and infos[env_id].get(
                    agent_id, {}).get("training_enabled", True)):
                episode.batch_builder.add_values(
                    agent_id,
                    policy_id,
                    t=episode.length - 1,
                    eps_id=episode.episode_id,
                    agent_index=episode._agent_index(agent_id),
                    obs=last_observation,
                    actions=episode.last_action_for(agent_id),
                    rewards=rewards[env_id][agent_id],
                    prev_actions=episode.prev_action_for(agent_id),
                    prev_rewards=episode.prev_reward_for(agent_id),
                    dones=(False if
                           (no_done_at_end or
                            (hit_horizon and soft_horizon)) else agent_done),
                    infos=infos[env_id].get(agent_id, {}),
                    new_obs=filtered_obs,
                    **episode.last_pi_info_for(agent_id))

        # Invoke the step callback after the step is logged to the episode
        if callbacks.get("on_episode_step"):
            callbacks["on_episode_step"]({"env": base_env, "episode": episode})

        # Cut the batch if we're not packing multiple episodes into one,
        # or if we've exceeded the requested batch size.
        if episode.batch_builder.has_pending_data():
            if dones[env_id]["__all__"] and not no_done_at_end:
                episode.batch_builder.check_missing_dones()
            if (all_done and not pack) or \
                    episode.batch_builder.count >= unroll_length:
                outputs.append(episode.batch_builder.build_and_reset(episode))
            elif all_done:
                # Make sure postprocessor stays within one episode
                episode.batch_builder.postprocess_batch_so_far(episode)

        if all_done:
            # Handle episode termination
            batch_builder_pool.append(episode.batch_builder)
            if callbacks.get("on_episode_end"):
                callbacks["on_episode_end"]({
                    "env": base_env,
                    "policy": policies,
                    "episode": episode
                })
            if hit_horizon and soft_horizon:
                episode.soft_reset()
                resetted_obs = agent_obs
            else:
                del active_episodes[env_id]
                resetted_obs = base_env.try_reset(env_id)
            if resetted_obs is None:
                # Reset not supported, drop this env from the ready list
                if horizon != float("inf"):
                    raise ValueError(
                        "Setting episode horizon requires reset() support "
                        "from the environment.")
            elif resetted_obs != ASYNC_RESET_RETURN:
                # Create a new episode if this is not an async reset return.
                # If the reset is async, we will get its result in some future poll.
                episode = active_episodes[env_id]
                for agent_id, raw_obs in resetted_obs.items():
                    policy_id = episode.policy_for(agent_id)  # eg: "policy_0"
                    policy = _get_or_raise(policies, policy_id)
                    prep_obs = _get_or_raise(preprocessors,
                                             policy_id).transform(raw_obs)
                    filtered_obs = _get_or_raise(obs_filters,
                                                 policy_id)(prep_obs)
                    # print('policy_id' + str(policy_id))
                    # print('filtered_obs' + str(filtered_obs))
                    '''
                    NOTE (attention):
                    The episode terminates here; create a new episode.
                    A real-time Q eval is executed here, so the neighbor_obs
                    values must be passed to the Q eval network.
                    '''
                    # Using the road-network relations in the
                    # traffic_light_node_dict, find the neighbors of the
                    # current policy_id and store them in the "policy_0"
                    # format.
                    neighbor_pid_list = [
                        'policy_{}'.format(pid_)
                        for pid_ in traffic_light_node_dict[inter_num_2_id(
                            int(policy_id.split('_')[1]))]['adjacency_row']
                        if pid_ is not None
                    ]
                    # print(neighbor_pid_list)
                    neighbor_obs = [[]]

                    # Size: (1, 5, 29). Only this shape can be fed into the
                    # neighbor_obs placeholder of shape (batch, 5, 17).
                    for neighbor_id in neighbor_pid_list:
                        # Run this agent's raw obs through each neighbor's
                        # preprocessor and filter.
                        neighbor_prep_obs = _get_or_raise(
                            preprocessors, neighbor_id).transform(raw_obs)
                        neighbor_filtered_obs = _get_or_raise(
                            obs_filters, neighbor_id)(neighbor_prep_obs)
                        neighbor_obs[0].append(neighbor_filtered_obs)
                    neighbor_obs = np.squeeze(np.array(neighbor_obs))

                    # ------------------------------------------------------------------
                    episode._set_last_observation(agent_id, filtered_obs)
                    to_eval[policy_id].append(
                        PolicyEvalData(
                            env_id, agent_id, filtered_obs, neighbor_obs,
                            episode.last_info_for(agent_id) or {},
                            episode.rnn_state_for(agent_id),
                            np.zeros_like(
                                _flatten_action(policy.action_space.sample())),
                            0.0))

    return active_envs, to_eval, outputs
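
# Hedged sketch of the neighbor lookup duplicated twice above, assuming the
# traffic_light_node_dict layout implied by the code: each intersection entry
# carries an 'adjacency_row' list of neighbor indices, padded with None.
def _neighbor_policy_ids(policy_id, traffic_light_node_dict, tmp_dic):
    inter_index = int(policy_id.split('_')[1])
    # Reverse lookup: index -> intersection id (same as inter_num_2_id above).
    inter_id = list(tmp_dic.keys())[list(tmp_dic.values()).index(inter_index)]
    return [
        'policy_{}'.format(n)
        for n in traffic_light_node_dict[inter_id]['adjacency_row']
        if n is not None
    ]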
Example #6
def _process_observations(base_env, policies, policies_to_train, dead_policies,
                          policy_config, observation_filter, tf_sess,
                          batch_builder_pool, active_episodes, unfiltered_obs,
                          rewards, dones, infos, off_policy_actions, horizon,
                          preprocessors, obs_filters, unroll_length, pack,
                          callbacks, soft_horizon, no_done_at_end):
    #===MOD===
    """Record new data from the environment and prepare for policy evaluation.

    Returns:
        active_envs: set of non-terminated env ids
        to_eval: map of policy_id to list of agent PolicyEvalData
        outputs: list of metrics and samples to return from the sampler
    """

    active_envs = set()
    to_eval = defaultdict(list)
    outputs = []

    # For each environment
    for env_id, agent_obs in unfiltered_obs.items():
        new_episode = env_id not in active_episodes
        episode = active_episodes[env_id]
        if not new_episode:
            episode.length += 1
            episode.batch_builder.count += 1
            episode._add_agent_rewards(rewards[env_id])

        if (episode.batch_builder.total() > max(1000, unroll_length * 10)
                and log_once("large_batch_warning")):
            logger.warning(
                "More than {} observations for {} env steps ".format(
                    episode.batch_builder.total(),
                    episode.batch_builder.count) + "are buffered in "
                "the sampler. If this is more than you expected, check "
                "that you set a horizon on your environment correctly. Note "
                "that in multi-agent environments, `sample_batch_size` sets "
                "the batch size based on environment steps, not the steps of "
                "individual agents, which can result in unexpectedly large "
                "batches.")

        # Check episode termination conditions
        if dones[env_id]["__all__"] or episode.length >= horizon:
            # DEBUG
            # print("Trying to terminate.")
            # print("Dones of __all__ is set:", dones[env_id]["__all__"])
            # print("Horizon hit:", episode.length >= horizon)
            hit_horizon = (episode.length >= horizon
                           and not dones[env_id]["__all__"])
            all_done = True
            atari_metrics = _fetch_atari_metrics(base_env)
            if atari_metrics is not None:
                for m in atari_metrics:
                    outputs.append(
                        m._replace(custom_metrics=episode.custom_metrics))
            else:
                outputs.append(
                    RolloutMetrics(episode.length, episode.total_reward,
                                   dict(episode.agent_rewards),
                                   episode.custom_metrics, {}))
        else:
            hit_horizon = False
            all_done = False
            active_envs.add(env_id)

        #===MOD===
        additional_builders_ids = set()
        #===MOD===

        # For each agent in the environment
        for agent_id, raw_obs in agent_obs.items():

            #===MOD===
            policy_id, policy_constructor_tuple = episode.policy_for(agent_id)
            pols_tuple = generate_policies(
                policy_id,
                policy_constructor_tuple,
                policies,
                policies_to_train,
                dead_policies,
                policy_config,
                preprocessors,
                obs_filters,
                observation_filter,
                tf_sess,
            )
            (policies, preprocessors, obs_filters, policies_to_train,
             dead_policies) = pols_tuple
            #===MOD===

            prep_obs = _get_or_raise(preprocessors,
                                     policy_id).transform(raw_obs)
            if log_once("prep_obs"):
                logger.info("Preprocessed obs: {}".format(summarize(prep_obs)))

            filtered_obs = _get_or_raise(obs_filters, policy_id)(prep_obs)
            if log_once("filtered_obs"):
                logger.info("Filtered obs: {}".format(summarize(filtered_obs)))

            agent_done = bool(all_done or dones[env_id].get(agent_id))
            if not agent_done:
                to_eval[policy_id].append(
                    PolicyEvalData(env_id, agent_id, filtered_obs,
                                   infos[env_id].get(agent_id, {}),
                                   episode.rnn_state_for(agent_id),
                                   episode.last_action_for(agent_id),
                                   rewards[env_id][agent_id] or 0.0))

            last_observation = episode.last_observation_for(agent_id)
            episode._set_last_observation(agent_id, filtered_obs)
            episode._set_last_raw_obs(agent_id, raw_obs)
            episode._set_last_info(agent_id, infos[env_id].get(agent_id, {}))

            # Record transition info if applicable
            if (last_observation is not None and infos[env_id].get(
                    agent_id, {}).get("training_enabled", True)):
                #===MOD===
                additional_builders_ids.add(agent_id)
                #===MOD===
                episode.batch_builder.add_values(
                    agent_id,
                    policy_id,
                    t=episode.length - 1,
                    eps_id=episode.episode_id,
                    agent_index=episode._agent_index(agent_id),
                    obs=last_observation,
                    actions=episode.last_action_for(agent_id),
                    rewards=rewards[env_id][agent_id],
                    prev_actions=episode.prev_action_for(agent_id),
                    prev_rewards=episode.prev_reward_for(agent_id),
                    dones=(False if
                           (no_done_at_end or
                            (hit_horizon and soft_horizon)) else agent_done),
                    infos=infos[env_id].get(agent_id, {}),
                    new_obs=filtered_obs,
                    **episode.last_pi_info_for(agent_id))

            #===MOD===
            if agent_done:
                # Does it make sense to remove agent id from `agent_builders`?
                dead_policies.add(policy_id)
                print("Removing agent id from agent builders: %s" %
                      str(agent_id))
                episode.batch_builder.agent_builders.pop(agent_id)
                if policy_id in to_eval:
                    to_eval.pop(policy_id)
                    # print("Popping policy id from toeval.")
            #===MOD===

        start = time.time()

        #===MOD===
        print("sampler.py: ids added to agent builders:\t",
              additional_builders_ids)
        # Update ``self.policy_map`` in ``MultiAgentSampleBatchBuilder``.
        # TODO: policies is not being pruned in this file.
        episode.batch_builder.policy_map = policies
        print("sampler.py: policies: \t", policies.keys())
        #===MOD===

        # Invoke the step callback after the step is logged to the episode
        if callbacks.get("on_episode_step"):
            callbacks["on_episode_step"]({"env": base_env, "episode": episode})

        # Cut the batch if we're not packing multiple episodes into one,
        # or if we've exceeded the requested batch size.
        if episode.batch_builder.has_pending_data():
            if dones[env_id]["__all__"] and not no_done_at_end:
                episode.batch_builder.check_missing_dones()
            if (all_done and not pack) or \
                    episode.batch_builder.count >= unroll_length:
                outputs.append(episode.batch_builder.build_and_reset(episode))
            elif all_done:
                # Make sure postprocessor stays within one episode
                # KEYERROR
                episode.batch_builder.postprocess_batch_so_far(episode)

        if all_done:
            # Handle episode termination
            batch_builder_pool.append(episode.batch_builder)
            if callbacks.get("on_episode_end"):
                callbacks["on_episode_end"]({
                    "env": base_env,
                    "policy": policies,
                    "episode": episode
                })
            if hit_horizon and soft_horizon:
                episode.soft_reset()
                resetted_obs = agent_obs
            else:
                del active_episodes[env_id]
                resetted_obs = base_env.try_reset(env_id)
            if resetted_obs is None:
                # Reset not supported, drop this env from the ready list
                if horizon != float("inf"):
                    raise ValueError(
                        "Setting episode horizon requires reset() support "
                        "from the environment.")
            elif resetted_obs != ASYNC_RESET_RETURN:
                # print("Executing new episode non-async return.")
                time.sleep(1)
                raise NotImplementedError(
                    "Multiple episodes not supported by design.")
                # Create a new episode if this is not an async reset return.
                # If the reset is async, we will get its result in some future poll.
                episode = active_episodes[env_id]
                for agent_id, raw_obs in resetted_obs.items():

                    #===MOD===
                    policy_id, policy_constructor_tuple = episode.policy_for(
                        agent_id)
                    # with tf_sess.as_default():
                    pols_tuple = generate_policies(
                        policy_id,
                        policy_constructor_tuple,
                        policies,
                        policies_to_train,
                        dead_policies,
                        policy_config,
                        preprocessors,
                        obs_filters,
                        observation_filter,
                        tf_sess,
                    )
                    (policies, preprocessors, obs_filters, policies_to_train,
                     dead_policies) = pols_tuple
                    #===MOD===

                    policy = _get_or_raise(policies, policy_id)
                    prep_obs = _get_or_raise(preprocessors,
                                             policy_id).transform(raw_obs)
                    filtered_obs = _get_or_raise(obs_filters,
                                                 policy_id)(prep_obs)
                    episode._set_last_observation(agent_id, filtered_obs)
                    to_eval[policy_id].append(
                        PolicyEvalData(
                            env_id, agent_id, filtered_obs,
                            episode.last_info_for(agent_id) or {},
                            episode.rnn_state_for(agent_id),
                            np.zeros_like(
                                _flatten_action(policy.action_space.sample())),
                            0.0))

        #===MOD===
        pols_tuple = (policies, preprocessors, obs_filters, policies_to_train,
                      dead_policies)
        #===MOD===
    #===MOD===
    return active_envs, to_eval, outputs, pols_tuple
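
# Minimal sketch (hypothetical helper) of the dead-policy bookkeeping the
# ===MOD=== blocks above perform inline: once an agent is done, mark its policy
# dead, drop its pending per-agent builder, and remove it from the eval map.
def _retire_policy(policy_id, agent_id, episode, dead_policies, to_eval):
    dead_policies.add(policy_id)
    episode.batch_builder.agent_builders.pop(agent_id, None)
    to_eval.pop(policy_id, None)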