Exemplo n.º 1
0
def _process_observations(
    *,
    worker: "RolloutWorker",
    base_env: BaseEnv,
    policies: Dict[PolicyID, Policy],
    active_episodes: Dict[EnvID, MultiAgentEpisode],
    unfiltered_obs: Dict[EnvID, Dict[AgentID, EnvObsType]],
    rewards: Dict[EnvID, Dict[AgentID, float]],
    dones: Dict[EnvID, Dict[AgentID, bool]],
    infos: Dict[EnvID, Dict[AgentID, EnvInfoDict]],
    horizon: int,
    preprocessors: Dict[PolicyID, Preprocessor],
    obs_filters: Dict[PolicyID, Filter],
    multiple_episodes_in_batch: bool,
    callbacks: "DefaultCallbacks",
    soft_horizon: bool,
    no_done_at_end: bool,
    observation_fn: "ObservationFunction",
    sample_collector: SampleCollector,
) -> Tuple[Set[EnvID], Dict[PolicyID, List[PolicyEvalData]], List[Union[
        RolloutMetrics, SampleBatchType]]]:
    """Record new data from the environment and prepare for policy evaluation.

    For every sub-environment contained in the latest `BaseEnv.poll()`
    result, this:
      - advances the ongoing episode (or starts tracking a new one),
      - preprocesses/filters each agent's observation and records the
        transition (or the initial observation) in `sample_collector`,
      - queues every agent that is not done for action computation,
      - if the episode is done for all agents or hit `horizon`: emits
        `RolloutMetrics`, postprocesses the episode, calls the exploration
        and user `on_episode_end` hooks and resets the sub-environment
        (soft reset if `soft_horizon` and only the horizon was hit).

    Args:
        worker (RolloutWorker): Reference to the current rollout worker.
        base_env (BaseEnv): Env implementing BaseEnv.
        policies (dict): Map of policy ids to Policy instances.
        active_episodes (Dict[EnvID, MultiAgentEpisode]): Mapping from
            env ID to the currently ongoing MultiAgentEpisode object. This
            function indexes it with env ids that are not (yet) present
            and deletes entries of hard-reset envs, so it is expected to
            create a fresh episode for a missing key (e.g. a
            `defaultdict`) -- presumably set up by the caller.
        unfiltered_obs (dict): Doubly keyed dict of env-ids -> agent ids
            -> unfiltered observation tensor, returned by a `BaseEnv.poll()`
            call.
        rewards (dict): Doubly keyed dict of env-ids -> agent ids ->
            rewards tensor, returned by a `BaseEnv.poll()` call.
        dones (dict): Doubly keyed dict of env-ids -> agent ids ->
            boolean done flags, returned by a `BaseEnv.poll()` call. Each
            inner dict must contain the "__all__" key.
        infos (dict): Doubly keyed dict of env-ids -> agent ids ->
            info dicts, returned by a `BaseEnv.poll()` call.
        horizon (int): Horizon of the episode. Must be `float("inf")` if
            the env does not support resetting.
        preprocessors (dict): Map of policy id to preprocessor for the
            observations prior to filtering.
        obs_filters (dict): Map of policy id to filter used to process
            observations for the policy.
        multiple_episodes_in_batch (bool): Whether to pack multiple
            episodes into each batch. If False
            (batch_mode=complete_episodes), one batch is built per finished
            episode; if True, truncated batches are built across episodes
            by `sample_collector` at the end of this call.
        callbacks (DefaultCallbacks): User callbacks to run on episode events.
        soft_horizon (bool): Calculate rewards but don't reset the
            environment when the horizon is hit.
        no_done_at_end (bool): Ignore the done=True at the end of the episode
            and instead record done=False.
        observation_fn (ObservationFunction): Optional multi-agent
            observation func to use for preprocessing observations.
        sample_collector (SampleCollector): The SampleCollector object
            used to store and retrieve environment samples.

    Returns:
        Tuple:
            - active_envs: Set of env ids whose episode is still running
              (neither done for all agents nor at the horizon).
            - to_eval: Map of policy_id to list of agent PolicyEvalData.
            - outputs: List of metrics and samples to return from the sampler.

    Raises:
        ValueError: If `observation_fn` does not return a dict, or if a
            finite `horizon` is set but the env does not support reset.
    """

    # Output objects.
    active_envs: Set[EnvID] = set()
    to_eval: Dict[PolicyID, List[PolicyEvalData]] = defaultdict(list)
    outputs: List[Union[RolloutMetrics, SampleBatchType]] = []

    # For each (vectorized) sub-environment.
    # type: EnvID, Dict[AgentID, EnvObsType]
    for env_id, all_agents_obs in unfiltered_obs.items():
        # Membership must be tested *before* indexing, because indexing a
        # missing env id is expected to create a fresh episode.
        is_new_episode: bool = env_id not in active_episodes
        episode: MultiAgentEpisode = active_episodes[env_id]

        # A new episode only carries its reset observation: no step has been
        # taken yet, so there is nothing to advance and no reward to add.
        if not is_new_episode:
            sample_collector.episode_step(episode)
            episode._add_agent_rewards(rewards[env_id])

        # Check episode termination conditions.
        if dones[env_id]["__all__"] or episode.length >= horizon:
            # Only count it as a horizon hit if the env itself didn't finish.
            hit_horizon = (episode.length >= horizon
                           and not dones[env_id]["__all__"])
            all_agents_done = True
            atari_metrics: List[RolloutMetrics] = _fetch_atari_metrics(
                base_env)
            if atari_metrics is not None:
                for m in atari_metrics:
                    outputs.append(
                        m._replace(custom_metrics=episode.custom_metrics))
            else:
                outputs.append(
                    RolloutMetrics(episode.length, episode.total_reward,
                                   dict(episode.agent_rewards),
                                   episode.custom_metrics, {},
                                   episode.hist_data, episode.media))
        else:
            hit_horizon = False
            all_agents_done = False
            active_envs.add(env_id)

        # Custom observation function is applied before preprocessing.
        if observation_fn:
            all_agents_obs: Dict[AgentID, EnvObsType] = observation_fn(
                agent_obs=all_agents_obs,
                worker=worker,
                base_env=base_env,
                policies=policies,
                episode=episode)
            if not isinstance(all_agents_obs, dict):
                raise ValueError(
                    "observe() must return a dict of agent observations")

        # For each agent in the environment.
        # type: AgentID, EnvObsType
        for agent_id, raw_obs in all_agents_obs.items():
            assert agent_id != "__all__"

            # Read the previous obs before it gets overwritten below; None
            # means this is the agent's first observation in the episode.
            last_observation: EnvObsType = episode.last_observation_for(
                agent_id)
            agent_done = bool(all_agents_done or dones[env_id].get(agent_id))

            # A new agent (initial obs) is already done -> Skip entirely.
            if last_observation is None and agent_done:
                continue

            policy_id: PolicyID = episode.policy_for(agent_id)

            prep_obs: EnvObsType = _get_or_raise(preprocessors,
                                                 policy_id).transform(raw_obs)
            if log_once("prep_obs"):
                logger.info("Preprocessed obs: {}".format(summarize(prep_obs)))
            filtered_obs: EnvObsType = _get_or_raise(obs_filters,
                                                     policy_id)(prep_obs)
            if log_once("filtered_obs"):
                logger.info("Filtered obs: {}".format(summarize(filtered_obs)))

            episode._set_last_observation(agent_id, filtered_obs)
            episode._set_last_raw_obs(agent_id, raw_obs)
            # Infos from the environment.
            agent_infos = infos[env_id].get(agent_id, {})
            episode._set_last_info(agent_id, agent_infos)

            # Record transition info if applicable.
            if last_observation is None:
                sample_collector.add_init_obs(episode, agent_id, env_id,
                                              policy_id, episode.length - 1,
                                              filtered_obs)
            else:
                # Add actions, rewards, next-obs to collectors.
                values_dict = {
                    "t":
                    episode.length - 1,
                    "env_id":
                    env_id,
                    "agent_index":
                    episode._agent_index(agent_id),
                    # Action (slot 0) taken at timestep t.
                    "actions":
                    episode.last_action_for(agent_id),
                    # Reward received after taking a at timestep t.
                    "rewards":
                    rewards[env_id][agent_id],
                    # After taking action=a, did we reach terminal?
                    # Recorded as False if `no_done_at_end` is set or if a
                    # soft horizon was hit (the env keeps running).
                    "dones":
                    (False if
                     (no_done_at_end or
                      (hit_horizon and soft_horizon)) else agent_done),
                    # Next observation.
                    "new_obs":
                    filtered_obs,
                }
                # Add extra-action-fetches to collectors.
                pol = policies[policy_id]
                for key, value in episode.last_pi_info_for(agent_id).items():
                    if key in pol.view_requirements:
                        values_dict[key] = value
                # Env infos for this agent.
                if "infos" in pol.view_requirements:
                    values_dict["infos"] = agent_infos
                sample_collector.add_action_reward_next_obs(
                    episode.episode_id, agent_id, env_id, policy_id,
                    agent_done, values_dict)

            # Agents that are not done need an action for the next step.
            if not agent_done:
                item = PolicyEvalData(
                    env_id, agent_id, filtered_obs, agent_infos,
                    None if last_observation is None else
                    episode.rnn_state_for(agent_id),
                    None if last_observation is None else
                    episode.last_action_for(agent_id),
                    rewards[env_id][agent_id] or 0.0)
                to_eval[policy_id].append(item)

        # Invoke the `on_episode_step` callback after the step is logged
        # to the episode.
        # Exception: The very first env.poll() call causes the env to get reset
        # (no step taken yet, just a single starting observation logged).
        # We need to skip this callback in this case.
        if episode.length > 0:
            callbacks.on_episode_step(worker=worker,
                                      base_env=base_env,
                                      episode=episode,
                                      env_index=env_id)

        # Episode is done for all agents (dones[__all__] == True)
        # or we hit the horizon.
        if all_agents_done:
            is_done = dones[env_id]["__all__"]
            check_dones = is_done and not no_done_at_end

            # If, we are not allowed to pack the next episode into the same
            # SampleBatch (batch_mode=complete_episodes) -> Build the
            # MultiAgentBatch from a single episode and add it to "outputs".
            # Otherwise, just postprocess and continue collecting across
            # episodes.
            ma_sample_batch = sample_collector.postprocess_episode(
                episode,
                is_done=is_done or (hit_horizon and not soft_horizon),
                check_dones=check_dones,
                build=not multiple_episodes_in_batch)
            if ma_sample_batch:
                outputs.append(ma_sample_batch)

            # Call each policy's Exploration.on_episode_end method.
            for p in policies.values():
                if getattr(p, "exploration", None) is not None:
                    p.exploration.on_episode_end(policy=p,
                                                 environment=base_env,
                                                 episode=episode,
                                                 tf_sess=getattr(
                                                     p, "_sess", None))
            # Call custom on_episode_end callback.
            callbacks.on_episode_end(
                worker=worker,
                base_env=base_env,
                policies=policies,
                episode=episode,
                env_index=env_id,
            )
            # Horizon hit and we have a soft horizon (no hard env reset).
            if hit_horizon and soft_horizon:
                episode.soft_reset()
                # Use the env's original observations: `all_agents_obs` may
                # already have been transformed by `observation_fn` above,
                # and `observation_fn` is applied again below.
                resetted_obs: Dict[AgentID, EnvObsType] = \
                    unfiltered_obs[env_id]
            else:
                del active_episodes[env_id]
                resetted_obs: Dict[AgentID,
                                   EnvObsType] = base_env.try_reset(env_id)
            # Reset not supported: this env is not in `active_envs`, so it is
            # dropped from the ready list.
            if resetted_obs is None:
                if horizon != float("inf"):
                    raise ValueError(
                        "Setting episode horizon requires reset() support "
                        "from the environment.")
            # Creates a new episode if this is not async return.
            # If reset is async, we will get its result in some future poll.
            elif resetted_obs != ASYNC_RESET_RETURN:
                # After a hard reset this is a freshly created episode; after
                # a soft reset it is the same (soft-reset) episode object.
                new_episode: MultiAgentEpisode = active_episodes[env_id]
                if observation_fn:
                    resetted_obs: Dict[AgentID, EnvObsType] = observation_fn(
                        agent_obs=resetted_obs,
                        worker=worker,
                        base_env=base_env,
                        policies=policies,
                        episode=new_episode)
                    if not isinstance(resetted_obs, dict):
                        raise ValueError(
                            "observe() must return a dict of agent "
                            "observations")
                # type: AgentID, EnvObsType
                for agent_id, raw_obs in resetted_obs.items():
                    policy_id: PolicyID = new_episode.policy_for(agent_id)
                    prep_obs: EnvObsType = _get_or_raise(
                        preprocessors, policy_id).transform(raw_obs)
                    filtered_obs: EnvObsType = _get_or_raise(
                        obs_filters, policy_id)(prep_obs)
                    new_episode._set_last_observation(agent_id, filtered_obs)

                    # Add initial obs to buffer.
                    sample_collector.add_init_obs(new_episode, agent_id,
                                                  env_id, policy_id,
                                                  new_episode.length - 1,
                                                  filtered_obs)

                    # Info and RNN state must come from the *new* episode;
                    # the finished `episode` still holds its last info and
                    # final RNN state, which must not leak into this one.
                    item = PolicyEvalData(
                        env_id, agent_id, filtered_obs,
                        new_episode.last_info_for(agent_id) or {},
                        new_episode.rnn_state_for(agent_id), None, 0.0)
                    to_eval[policy_id].append(item)

    # Try to build something.
    if multiple_episodes_in_batch:
        sample_batches = \
            sample_collector.try_build_truncated_episode_multi_agent_batch()
        if sample_batches:
            outputs.extend(sample_batches)

    return active_envs, to_eval, outputs
Exemplo n.º 2
0
def _process_observations(
        worker: "RolloutWorker", base_env: BaseEnv,
        policies: Dict[PolicyID, Policy],
        batch_builder_pool: List[MultiAgentSampleBatchBuilder],
        active_episodes: Dict[str, MultiAgentEpisode],
        unfiltered_obs: Dict[EnvID, Dict[AgentID, EnvObsType]],
        rewards: Dict[EnvID, Dict[AgentID, float]],
        dones: Dict[EnvID, Dict[AgentID, bool]],
        infos: Dict[EnvID, Dict[AgentID, EnvInfoDict]], horizon: int,
        preprocessors: Dict[PolicyID, Preprocessor],
        obs_filters: Dict[PolicyID, Filter], rollout_fragment_length: int,
        pack_multiple_episodes_in_batch: bool, callbacks: "DefaultCallbacks",
        soft_horizon: bool, no_done_at_end: bool,
        observation_fn: "ObservationFunction",
        _use_trajectory_view_api: bool = False
) -> Tuple[Set[EnvID], Dict[PolicyID, List[PolicyEvalData]], List[Union[
        RolloutMetrics, SampleBatchType]]]:
    """Record new data from the environment and prepare for policy evaluation.

    For every sub-environment contained in the latest `BaseEnv.poll()`
    result, this:
      - advances the ongoing episode (or starts tracking a new one),
      - preprocesses/filters each agent's observation, queues agents that
        are not done for action computation and records the transition in
        the episode's batch builder,
      - cuts a `SampleBatch` once the builder holds
        `rollout_fragment_length` env steps (or the episode ended and
        packing is disabled),
      - if the episode is done for all agents or hit `horizon`: emits
        `RolloutMetrics`, returns the batch builder to the pool, calls the
        exploration and user `on_episode_end` hooks and resets the
        sub-environment (soft reset if `soft_horizon` and only the horizon
        was hit).

    Args:
        worker (RolloutWorker): Reference to the current rollout worker.
        base_env (BaseEnv): Env implementing BaseEnv.
        policies (dict): Map of policy ids to Policy instances.
        batch_builder_pool (List[MultiAgentSampleBatchBuilder]): Pool of
            batch builders for recycling; the builder of every episode that
            ends in this call is appended to it.
        active_episodes (Dict[str, MultiAgentEpisode]): Mapping from
            env ID to the currently ongoing MultiAgentEpisode object. It is
            indexed with env ids that are not (yet) present, so it is
            presumably a defaultdict-like mapping that creates a new
            episode on demand -- confirm with the caller.
        unfiltered_obs (dict): Doubly keyed dict of env-ids -> agent ids ->
            unfiltered observation tensor, returned by a `BaseEnv.poll()` call.
        rewards (dict): Doubly keyed dict of env-ids -> agent ids ->
            rewards tensor, returned by a `BaseEnv.poll()` call.
        dones (dict): Doubly keyed dict of env-ids -> agent ids ->
            boolean done flags, returned by a `BaseEnv.poll()` call. Each
            inner dict must contain the "__all__" key.
        infos (dict): Doubly keyed dict of env-ids -> agent ids ->
            info dicts, returned by a `BaseEnv.poll()` call. An info value
            of `training_enabled=False` keeps that agent's transition out
            of the batch.
        horizon (int): Horizon of the episode. Must be `float("inf")` if
            the env does not support resetting.
        preprocessors (dict): Map of policy id to preprocessor for the
            observations prior to filtering.
        obs_filters (dict): Map of policy id to filter used to process
            observations for the policy.
        rollout_fragment_length (int): Number of env steps accumulated in an
            episode's batch builder before a `SampleBatch` is yielded. Set
            to infinity to yield complete episodes.
        pack_multiple_episodes_in_batch (bool): Whether to pack multiple
            episodes into each batch. This guarantees batches will be exactly
            `rollout_fragment_length` in size.
        callbacks (DefaultCallbacks): User callbacks to run on episode events.
        soft_horizon (bool): Calculate rewards but don't reset the
            environment when the horizon is hit.
        no_done_at_end (bool): Ignore the done=True at the end of the episode
            and instead record done=False.
        observation_fn (ObservationFunction): Optional multi-agent
            observation func to use for preprocessing observations.
        _use_trajectory_view_api (bool): Whether to use the (experimental)
            `_use_trajectory_view_api` to make generic trajectory views
            available to Models. Not read anywhere in this function body;
            accepted for interface compatibility. Default: False.

    Returns:
        Tuple:
            - active_envs: Set of env ids whose episode is still running
              (neither done for all agents nor at the horizon).
            - to_eval: Map of policy_id to list of agent PolicyEvalData.
            - outputs: List of metrics and samples to return from the sampler.

    Raises:
        ValueError: If `observation_fn` does not return a dict, or if a
            finite `horizon` is set but the env does not support reset.
    """

    # Output objects.
    active_envs: Set[EnvID] = set()
    to_eval: Dict[PolicyID, List[PolicyEvalData]] = defaultdict(list)
    outputs: List[Union[RolloutMetrics, SampleBatchType]] = []

    # Buffer size above which a one-time "large batch" warning is logged:
    # 10x the fragment length (at least 1000), or 5000 when collecting
    # complete episodes (infinite fragment length).
    large_batch_threshold: int = max(1000, rollout_fragment_length * 10) if \
        rollout_fragment_length != float("inf") else 5000

    # For each environment.
    # type: EnvID, Dict[AgentID, EnvObsType]
    for env_id, agent_obs in unfiltered_obs.items():
        # Membership must be tested *before* indexing, because indexing a
        # missing env id is expected to create a fresh episode.
        is_new_episode: bool = env_id not in active_episodes
        episode: MultiAgentEpisode = active_episodes[env_id]
        # A new episode only carries its reset observation: no step has been
        # taken yet, so there is nothing to count and no reward to add.
        if not is_new_episode:
            episode.length += 1
            episode.batch_builder.count += 1
            episode._add_agent_rewards(rewards[env_id])

        if (episode.batch_builder.total() > large_batch_threshold
                and log_once("large_batch_warning")):
            logger.warning(
                "More than {} observations for {} env steps ".format(
                    episode.batch_builder.total(),
                    episode.batch_builder.count) + "are buffered in "
                "the sampler. If this is more than you expected, check "
                "that you set a horizon on your environment correctly and "
                "that it terminates at some point. "
                "Note: In multi-agent environments, `rollout_fragment_length` "
                "sets the batch size based on environment steps, not the "
                "steps of "
                "individual agents, which can result in unexpectedly large "
                "batches. Also, you may be in evaluation waiting for your Env "
                "to terminate (batch_mode=`complete_episodes`). Make sure it "
                "does at some point.")

        # Check episode termination conditions.
        if dones[env_id]["__all__"] or episode.length >= horizon:
            # Only count it as a horizon hit if the env itself didn't finish.
            hit_horizon = (episode.length >= horizon
                           and not dones[env_id]["__all__"])
            all_agents_done = True
            atari_metrics: List[RolloutMetrics] = _fetch_atari_metrics(
                base_env)
            if atari_metrics is not None:
                for m in atari_metrics:
                    outputs.append(
                        m._replace(custom_metrics=episode.custom_metrics))
            else:
                outputs.append(
                    RolloutMetrics(episode.length, episode.total_reward,
                                   dict(episode.agent_rewards),
                                   episode.custom_metrics, {},
                                   episode.hist_data))
        else:
            hit_horizon = False
            all_agents_done = False
            active_envs.add(env_id)

        # Custom observation function is applied before preprocessing.
        if observation_fn:
            agent_obs: Dict[AgentID, EnvObsType] = observation_fn(
                agent_obs=agent_obs,
                worker=worker,
                base_env=base_env,
                policies=policies,
                episode=episode)
            if not isinstance(agent_obs, dict):
                raise ValueError(
                    "observe() must return a dict of agent observations")

        # For each agent in the environment.
        # type: AgentID, EnvObsType
        for agent_id, raw_obs in agent_obs.items():
            assert agent_id != "__all__"
            policy_id: PolicyID = episode.policy_for(agent_id)
            prep_obs: EnvObsType = _get_or_raise(preprocessors,
                                                 policy_id).transform(raw_obs)
            if log_once("prep_obs"):
                logger.info("Preprocessed obs: {}".format(summarize(prep_obs)))

            filtered_obs: EnvObsType = _get_or_raise(obs_filters,
                                                     policy_id)(prep_obs)
            if log_once("filtered_obs"):
                logger.info("Filtered obs: {}".format(summarize(filtered_obs)))

            # Queue agents that are not done for action computation. This
            # reads the episode's RNN state and last action, so it must run
            # before the episode's per-agent state is updated below.
            agent_done = bool(all_agents_done or dones[env_id].get(agent_id))
            if not agent_done:
                to_eval[policy_id].append(
                    PolicyEvalData(env_id, agent_id, filtered_obs,
                                   infos[env_id].get(agent_id, {}),
                                   episode.rnn_state_for(agent_id),
                                   episode.last_action_for(agent_id),
                                   rewards[env_id][agent_id] or 0.0))

            # Read the previous obs before overwriting it; None means this is
            # the agent's first observation in the episode.
            last_observation: EnvObsType = episode.last_observation_for(
                agent_id)
            episode._set_last_observation(agent_id, filtered_obs)
            episode._set_last_raw_obs(agent_id, raw_obs)
            episode._set_last_info(agent_id, infos[env_id].get(agent_id, {}))

            # Record transition info if applicable: there must be a previous
            # obs, and the env may opt an agent out of training via the
            # `training_enabled` info key (default: True).
            if (last_observation is not None and infos[env_id].get(
                    agent_id, {}).get("training_enabled", True)):
                episode.batch_builder.add_values(
                    agent_id,
                    policy_id,
                    t=episode.length - 1,
                    eps_id=episode.episode_id,
                    agent_index=episode._agent_index(agent_id),
                    obs=last_observation,
                    actions=episode.last_action_for(agent_id),
                    rewards=rewards[env_id][agent_id],
                    prev_actions=episode.prev_action_for(agent_id),
                    prev_rewards=episode.prev_reward_for(agent_id),
                    # Recorded as False if `no_done_at_end` is set or a soft
                    # horizon was hit (the env keeps running).
                    dones=(False if (no_done_at_end
                                     or (hit_horizon and soft_horizon)) else
                           agent_done),
                    infos=infos[env_id].get(agent_id, {}),
                    new_obs=filtered_obs,
                    **episode.last_pi_info_for(agent_id))

        # Invoke the step callback after the step is logged to the episode
        # NOTE(review): unlike the SampleCollector-based variant, this also
        # fires for the very first poll of an episode (episode.length == 0,
        # no step taken yet) -- confirm callers expect that.
        callbacks.on_episode_step(
            worker=worker, base_env=base_env, episode=episode)

        # Cut the batch if ...
        # - all-agents-done and not packing multiple episodes into one
        #   (batch_mode="complete_episodes")
        # - or if we've exceeded the rollout_fragment_length.
        if episode.batch_builder.has_pending_agent_data():
            # Sanity check, whether all agents have done=True, if done[__all__]
            # is True.
            if dones[env_id]["__all__"] and not no_done_at_end:
                episode.batch_builder.check_missing_dones()

            # Reached end of episode and we are not allowed to pack the
            # next episode into the same SampleBatch -> Build the SampleBatch
            # and add it to "outputs".
            if (all_agents_done and not pack_multiple_episodes_in_batch) or \
                    episode.batch_builder.count >= rollout_fragment_length:
                outputs.append(episode.batch_builder.build_and_reset(episode))
            # Make sure postprocessor stays within one episode.
            elif all_agents_done:
                episode.batch_builder.postprocess_batch_so_far(episode)

        # Episode is done.
        if all_agents_done:
            # Handle episode termination.
            # NOTE(review): this also runs in the soft-horizon case below,
            # where the soft-reset episode keeps using this same builder
            # (unless soft_reset() replaces it, which is not visible here)
            # -- confirm sharing a pooled builder is intended.
            batch_builder_pool.append(episode.batch_builder)
            # Call each policy's Exploration.on_episode_end method.
            for p in policies.values():
                if getattr(p, "exploration", None) is not None:
                    p.exploration.on_episode_end(
                        policy=p,
                        environment=base_env,
                        episode=episode,
                        tf_sess=getattr(p, "_sess", None))
            # Call custom on_episode_end callback.
            callbacks.on_episode_end(
                worker=worker,
                base_env=base_env,
                policies=policies,
                episode=episode)
            if hit_horizon and soft_horizon:
                # Soft horizon: keep the env running and reuse the current
                # observations as the "reset" observations.
                # NOTE(review): `agent_obs` has already been passed through
                # `observation_fn` above and is passed through it again
                # below -- likely a double application; verify.
                episode.soft_reset()
                resetted_obs: Dict[AgentID, EnvObsType] = agent_obs
            else:
                del active_episodes[env_id]
                resetted_obs: Dict[AgentID, EnvObsType] = base_env.try_reset(
                    env_id)
            if resetted_obs is None:
                # Reset not supported, drop this env from the ready list.
                # (It was not added to `active_envs` above.)
                if horizon != float("inf"):
                    raise ValueError(
                        "Setting episode horizon requires reset() support "
                        "from the environment.")
            elif resetted_obs != ASYNC_RESET_RETURN:
                # Creates a new episode if this is not async return.
                # If reset is async, we will get its result in some future poll
                # `episode` is rebound here: after a hard reset it is the
                # newly created episode, after a soft reset the same object.
                episode: MultiAgentEpisode = active_episodes[env_id]
                if observation_fn:
                    resetted_obs: Dict[AgentID, EnvObsType] = observation_fn(
                        agent_obs=resetted_obs,
                        worker=worker,
                        base_env=base_env,
                        policies=policies,
                        episode=episode)
                # type: AgentID, EnvObsType
                for agent_id, raw_obs in resetted_obs.items():
                    policy_id: PolicyID = episode.policy_for(agent_id)
                    policy: Policy = _get_or_raise(policies, policy_id)
                    prep_obs: EnvObsType = _get_or_raise(
                        preprocessors, policy_id).transform(raw_obs)
                    filtered_obs: EnvObsType = _get_or_raise(
                        obs_filters, policy_id)(prep_obs)
                    episode._set_last_observation(agent_id, filtered_obs)
                    # Queue the initial step; the "previous action" is a zero
                    # array shaped like a flattened sample of the policy's
                    # action space, and the previous reward is 0.0.
                    to_eval[policy_id].append(
                        PolicyEvalData(
                            env_id, agent_id, filtered_obs,
                            episode.last_info_for(agent_id) or {},
                            episode.rnn_state_for(agent_id),
                            np.zeros_like(
                                flatten_to_single_ndarray(
                                    policy.action_space.sample())), 0.0))

    return active_envs, to_eval, outputs