def _fetch_atari_metrics(base_env: BaseEnv) -> List[RolloutMetrics]:
    """Atari games have multiple logical episodes, one per life.

    However, for metrics reporting we count full episodes, all lives included.
    """
    unwrapped = base_env.get_unwrapped()
    if not unwrapped:
        return None
    atari_out = []
    for u in unwrapped:
        monitor = get_wrapper_by_cls(u, MonitorEnv)
        if not monitor:
            return None
        for eps_rew, eps_len in monitor.next_episode_results():
            atari_out.append(RolloutMetrics(eps_len, eps_rew))
    return atari_out

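
# --- Hedged sketch (not part of the original module) -------------------------
# A minimal, self-contained illustration of the data flow inside
# `_fetch_atari_metrics` above: each wrapped sub-env reports finished per-life
# episodes as (reward, length) pairs, which are re-packed into one metrics
# record per full episode. `_FakeMonitor` and `_RolloutMetricsStandIn` are
# invented stand-ins for illustration only; the real MonitorEnv wrapper and
# RolloutMetrics namedtuple live in RLlib.
from collections import namedtuple

_RolloutMetricsStandIn = namedtuple("_RolloutMetricsStandIn",
                                    ["episode_length", "episode_reward"])


class _FakeMonitor:
    """Stand-in for MonitorEnv: buffers (reward, length) of finished episodes."""

    def __init__(self, results):
        self._results = list(results)

    def next_episode_results(self):
        # Yield and clear the buffered results, mirroring the intent of the
        # real wrapper.
        while self._results:
            yield self._results.pop(0)


_monitor = _FakeMonitor([(210.0, 812), (180.0, 790)])
_sketch_out = [
    _RolloutMetricsStandIn(eps_len, eps_rew)
    for eps_rew, eps_len in _monitor.next_episode_results()
]
assert _sketch_out[0].episode_length == 812
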
def _process_observations(
        *,
        worker: "RolloutWorker",
        base_env: BaseEnv,
        policies: Dict[PolicyID, Policy],
        active_episodes: Dict[str, MultiAgentEpisode],
        unfiltered_obs: Dict[EnvID, Dict[AgentID, EnvObsType]],
        rewards: Dict[EnvID, Dict[AgentID, float]],
        dones: Dict[EnvID, Dict[AgentID, bool]],
        infos: Dict[EnvID, Dict[AgentID, EnvInfoDict]],
        horizon: int,
        preprocessors: Dict[PolicyID, Preprocessor],
        obs_filters: Dict[PolicyID, Filter],
        multiple_episodes_in_batch: bool,
        callbacks: "DefaultCallbacks",
        soft_horizon: bool,
        no_done_at_end: bool,
        observation_fn: "ObservationFunction",
        sample_collector: SampleCollector,
) -> Tuple[Set[EnvID], Dict[PolicyID, List[PolicyEvalData]], List[Union[
        RolloutMetrics, SampleBatchType]]]:
    """Record new data from the environment and prepare for policy evaluation.

    Args:
        worker (RolloutWorker): Reference to the current rollout worker.
        base_env (BaseEnv): Env implementing BaseEnv.
        policies (dict): Map of policy ids to Policy instances.
        active_episodes (Dict[str, MultiAgentEpisode]): Mapping from
            episode ID to currently ongoing MultiAgentEpisode object.
        unfiltered_obs (dict): Doubly keyed dict of env-ids -> agent ids
            -> unfiltered observation tensor, returned by a
            `BaseEnv.poll()` call.
        rewards (dict): Doubly keyed dict of env-ids -> agent ids ->
            rewards tensor, returned by a `BaseEnv.poll()` call.
        dones (dict): Doubly keyed dict of env-ids -> agent ids ->
            boolean done flags, returned by a `BaseEnv.poll()` call.
        infos (dict): Doubly keyed dict of env-ids -> agent ids ->
            info dicts, returned by a `BaseEnv.poll()` call.
        horizon (int): Horizon of the episode.
        preprocessors (dict): Map of policy id to preprocessor for the
            observations prior to filtering.
        obs_filters (dict): Map of policy id to filter used to process
            observations for the policy.
        multiple_episodes_in_batch (bool): Whether to pack multiple
            episodes into each batch. This guarantees batches will be
            exactly `rollout_fragment_length` in size.
        callbacks (DefaultCallbacks): User callbacks to run on episode events.
        soft_horizon (bool): Calculate rewards but don't reset the
            environment when the horizon is hit.
        no_done_at_end (bool): Ignore the done=True at the end of the
            episode and instead record done=False.
        observation_fn (ObservationFunction): Optional multi-agent
            observation func to use for preprocessing observations.
        sample_collector (SampleCollector): The SampleCollector object
            used to store and retrieve environment samples.

    Returns:
        Tuple:
            - active_envs: Set of non-terminated env ids.
            - to_eval: Map of policy_id to list of agent PolicyEvalData.
            - outputs: List of metrics and samples to return from the sampler.
    """
    # Output objects.
    active_envs: Set[EnvID] = set()
    to_eval: Dict[PolicyID, List[PolicyEvalData]] = defaultdict(list)
    outputs: List[Union[RolloutMetrics, SampleBatchType]] = []

    # For each (vectorized) sub-environment.
    # type: EnvID, Dict[AgentID, EnvObsType]
    for env_id, all_agents_obs in unfiltered_obs.items():
        is_new_episode: bool = env_id not in active_episodes
        episode: MultiAgentEpisode = active_episodes[env_id]

        if not is_new_episode:
            sample_collector.episode_step(episode)
            episode._add_agent_rewards(rewards[env_id])

        # Check episode termination conditions.
        if dones[env_id]["__all__"] or episode.length >= horizon:
            hit_horizon = (episode.length >= horizon
                           and not dones[env_id]["__all__"])
            all_agents_done = True
            atari_metrics: List[RolloutMetrics] = _fetch_atari_metrics(
                base_env)
            if atari_metrics is not None:
                for m in atari_metrics:
                    outputs.append(
                        m._replace(custom_metrics=episode.custom_metrics))
            else:
                outputs.append(
                    RolloutMetrics(episode.length, episode.total_reward,
                                   dict(episode.agent_rewards),
                                   episode.custom_metrics, {},
                                   episode.hist_data, episode.media))
        else:
            hit_horizon = False
            all_agents_done = False
            active_envs.add(env_id)

        # Custom observation function is applied before preprocessing.
        if observation_fn:
            all_agents_obs: Dict[AgentID, EnvObsType] = observation_fn(
                agent_obs=all_agents_obs,
                worker=worker,
                base_env=base_env,
                policies=policies,
                episode=episode)
            if not isinstance(all_agents_obs, dict):
                raise ValueError(
                    "observe() must return a dict of agent observations")

        # For each agent in the environment.
        # type: AgentID, EnvObsType
        for agent_id, raw_obs in all_agents_obs.items():
            assert agent_id != "__all__"

            last_observation: EnvObsType = episode.last_observation_for(
                agent_id)
            agent_done = bool(all_agents_done or dones[env_id].get(agent_id))

            # A new agent (initial obs) is already done -> Skip entirely.
            if last_observation is None and agent_done:
                continue

            policy_id: PolicyID = episode.policy_for(agent_id)

            prep_obs: EnvObsType = _get_or_raise(preprocessors,
                                                 policy_id).transform(raw_obs)
            if log_once("prep_obs"):
                logger.info("Preprocessed obs: {}".format(summarize(prep_obs)))
            filtered_obs: EnvObsType = _get_or_raise(obs_filters,
                                                     policy_id)(prep_obs)
            if log_once("filtered_obs"):
                logger.info("Filtered obs: {}".format(summarize(filtered_obs)))

            episode._set_last_observation(agent_id, filtered_obs)
            episode._set_last_raw_obs(agent_id, raw_obs)

            # Infos from the environment.
            agent_infos = infos[env_id].get(agent_id, {})
            episode._set_last_info(agent_id, agent_infos)

            # Record transition info if applicable.
            if last_observation is None:
                sample_collector.add_init_obs(episode, agent_id, env_id,
                                              policy_id, episode.length - 1,
                                              filtered_obs)
            else:
                # Add actions, rewards, next-obs to collectors.
                values_dict = {
                    "t": episode.length - 1,
                    "env_id": env_id,
                    "agent_index": episode._agent_index(agent_id),
                    # Action (slot 0) taken at timestep t.
                    "actions": episode.last_action_for(agent_id),
                    # Reward received after taking a at timestep t.
                    "rewards": rewards[env_id][agent_id],
                    # After taking action=a, did we reach terminal?
                    "dones": (False if (no_done_at_end
                                        or (hit_horizon and soft_horizon))
                              else agent_done),
                    # Next observation.
                    "new_obs": filtered_obs,
                }
                # Add extra-action-fetches to collectors.
                pol = policies[policy_id]
                for key, value in episode.last_pi_info_for(agent_id).items():
                    if key in pol.view_requirements:
                        values_dict[key] = value
                # Env infos for this agent.
                if "infos" in pol.view_requirements:
                    values_dict["infos"] = agent_infos
                sample_collector.add_action_reward_next_obs(
                    episode.episode_id, agent_id, env_id, policy_id,
                    agent_done, values_dict)

            if not agent_done:
                item = PolicyEvalData(
                    env_id, agent_id, filtered_obs, agent_infos,
                    None if last_observation is None else
                    episode.rnn_state_for(agent_id),
                    None if last_observation is None else
                    episode.last_action_for(agent_id),
                    rewards[env_id][agent_id] or 0.0)
                to_eval[policy_id].append(item)

        # Invoke the `on_episode_step` callback after the step is logged
        # to the episode.
        # Exception: The very first env.poll() call causes the env to get
        # reset (no step taken yet, just a single starting observation
        # logged).
        # We need to skip this callback in this case.
        if episode.length > 0:
            callbacks.on_episode_step(
                worker=worker,
                base_env=base_env,
                episode=episode,
                env_index=env_id)

        # Episode is done for all agents (dones[__all__] == True)
        # or we hit the horizon.
        if all_agents_done:
            is_done = dones[env_id]["__all__"]
            check_dones = is_done and not no_done_at_end

            # If we are not allowed to pack the next episode into the same
            # SampleBatch (batch_mode=complete_episodes) -> Build the
            # MultiAgentBatch from a single episode and add it to "outputs".
            # Otherwise, just postprocess and continue collecting across
            # episodes.
            ma_sample_batch = sample_collector.postprocess_episode(
                episode,
                is_done=is_done or (hit_horizon and not soft_horizon),
                check_dones=check_dones,
                build=not multiple_episodes_in_batch)
            if ma_sample_batch:
                outputs.append(ma_sample_batch)

            # Call each policy's Exploration.on_episode_end method.
            for p in policies.values():
                if getattr(p, "exploration", None) is not None:
                    p.exploration.on_episode_end(
                        policy=p,
                        environment=base_env,
                        episode=episode,
                        tf_sess=getattr(p, "_sess", None))
            # Call custom on_episode_end callback.
            callbacks.on_episode_end(
                worker=worker,
                base_env=base_env,
                policies=policies,
                episode=episode,
                env_index=env_id,
            )
            # Horizon hit and we have a soft horizon (no hard env reset).
            if hit_horizon and soft_horizon:
                episode.soft_reset()
                resetted_obs: Dict[AgentID, EnvObsType] = all_agents_obs
            else:
                del active_episodes[env_id]
                resetted_obs: Dict[AgentID, EnvObsType] = base_env.try_reset(
                    env_id)
            # Reset not supported, drop this env from the ready list.
            if resetted_obs is None:
                if horizon != float("inf"):
                    raise ValueError(
                        "Setting episode horizon requires reset() support "
                        "from the environment.")
            # Creates a new episode if this is not async return.
            # If reset is async, we will get its result in some future poll.
            elif resetted_obs != ASYNC_RESET_RETURN:
                new_episode: MultiAgentEpisode = active_episodes[env_id]
                if observation_fn:
                    resetted_obs: Dict[AgentID, EnvObsType] = observation_fn(
                        agent_obs=resetted_obs,
                        worker=worker,
                        base_env=base_env,
                        policies=policies,
                        episode=new_episode)
                # type: AgentID, EnvObsType
                for agent_id, raw_obs in resetted_obs.items():
                    policy_id: PolicyID = new_episode.policy_for(agent_id)
                    prep_obs: EnvObsType = _get_or_raise(
                        preprocessors, policy_id).transform(raw_obs)
                    filtered_obs: EnvObsType = _get_or_raise(
                        obs_filters, policy_id)(prep_obs)
                    new_episode._set_last_observation(agent_id, filtered_obs)

                    # Add initial obs to buffer.
                    sample_collector.add_init_obs(
                        new_episode, agent_id, env_id, policy_id,
                        new_episode.length - 1, filtered_obs)

                    item = PolicyEvalData(
                        env_id, agent_id, filtered_obs,
                        episode.last_info_for(agent_id) or {},
                        episode.rnn_state_for(agent_id), None, 0.0)
                    to_eval[policy_id].append(item)

    # Try to build something.
    if multiple_episodes_in_batch:
        sample_batches = \
            sample_collector.try_build_truncated_episode_multi_agent_batch()
        if sample_batches:
            outputs.extend(sample_batches)

    return active_envs, to_eval, outputs

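
# --- Hedged sketch of the poll() data layout consumed above -------------------
# Not part of the original module. The shapes and keys mirror the docstring of
# the function above: doubly keyed dicts env_id -> agent_id -> value, plus the
# special "__all__" done flag. The concrete values below are invented for
# illustration only.
_example_unfiltered_obs = {0: {"agent_0": [0.1, 0.2], "agent_1": [0.3, 0.4]}}
_example_rewards = {0: {"agent_0": 1.0, "agent_1": -0.5}}
_example_dones = {0: {"agent_0": False, "agent_1": False, "__all__": False}}
_example_infos = {0: {"agent_0": {}, "agent_1": {}}}


def _episode_terminated(dones, env_id, episode_length, horizon):
    # Same termination condition used above: either the env reports done for
    # all agents, or the configured horizon is reached.
    return dones[env_id]["__all__"] or episode_length >= horizon


assert _episode_terminated(_example_dones, 0, episode_length=200, horizon=200)
assert not _episode_terminated(_example_dones, 0, episode_length=5, horizon=200)
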
def _process_observations(base_env, policies, batch_builder_pool,
                          active_episodes, unfiltered_obs, rewards, dones,
                          infos, off_policy_actions, horizon, preprocessors,
                          obs_filters, rollout_fragment_length, pack,
                          callbacks, soft_horizon, no_done_at_end):
    """Record new data from the environment and prepare for policy evaluation.

    Returns:
        active_envs: set of non-terminated env ids
        to_eval: map of policy_id to list of agent PolicyEvalData
        outputs: list of metrics and samples to return from the sampler
    """

    active_envs = set()
    to_eval = defaultdict(list)
    outputs = []
    large_batch_threshold = max(1000, rollout_fragment_length * 10) if \
        rollout_fragment_length != float("inf") else 5000

    # For each environment
    for env_id, agent_obs in unfiltered_obs.items():
        new_episode = env_id not in active_episodes
        episode = active_episodes[env_id]
        if not new_episode:
            episode.length += 1
            episode.batch_builder.count += 1
            episode._add_agent_rewards(rewards[env_id])

        if (episode.batch_builder.total() > large_batch_threshold
                and log_once("large_batch_warning")):
            logger.warning(
                "More than {} observations for {} env steps ".format(
                    episode.batch_builder.total(),
                    episode.batch_builder.count) + "are buffered in "
                "the sampler. If this is more than you expected, check that "
                "you set a horizon on your environment correctly and that it "
                "terminates at some point. "
                "Note: In multi-agent environments, `rollout_fragment_length` "
                "sets the batch size based on environment steps, not the "
                "steps of individual agents, which can result in unexpectedly "
                "large batches. Also, you may be in evaluation waiting for "
                "your Env to terminate (batch_mode=`complete_episodes`). Make "
                "sure it does at some point.")

        # Check episode termination conditions
        if dones[env_id]["__all__"] or episode.length >= horizon:
            hit_horizon = (episode.length >= horizon
                           and not dones[env_id]["__all__"])
            all_done = True
            atari_metrics = _fetch_atari_metrics(base_env)
            if atari_metrics is not None:
                for m in atari_metrics:
                    outputs.append(
                        m._replace(custom_metrics=episode.custom_metrics))
            else:
                outputs.append(
                    RolloutMetrics(episode.length, episode.total_reward,
                                   dict(episode.agent_rewards),
                                   episode.custom_metrics, {},
                                   episode.hist_data))
        else:
            hit_horizon = False
            all_done = False
            active_envs.add(env_id)

        # For each agent in the environment.
        for agent_id, raw_obs in agent_obs.items():
            policy_id = episode.policy_for(agent_id)
            prep_obs = _get_or_raise(preprocessors,
                                     policy_id).transform(raw_obs)
            if log_once("prep_obs"):
                logger.info("Preprocessed obs: {}".format(summarize(prep_obs)))

            filtered_obs = _get_or_raise(obs_filters, policy_id)(prep_obs)
            if log_once("filtered_obs"):
                logger.info("Filtered obs: {}".format(summarize(filtered_obs)))

            agent_done = bool(all_done or dones[env_id].get(agent_id))
            if not agent_done:
                to_eval[policy_id].append(
                    PolicyEvalData(env_id, agent_id, filtered_obs,
                                   infos[env_id].get(agent_id, {}),
                                   episode.rnn_state_for(agent_id),
                                   episode.last_action_for(agent_id),
                                   rewards[env_id][agent_id] or 0.0))

            last_observation = episode.last_observation_for(agent_id)
            episode._set_last_observation(agent_id, filtered_obs)
            episode._set_last_raw_obs(agent_id, raw_obs)
            episode._set_last_info(agent_id, infos[env_id].get(agent_id, {}))

            # Record transition info if applicable
            if (last_observation is not None and infos[env_id].get(
                    agent_id, {}).get("training_enabled", True)):
                episode.batch_builder.add_values(
                    agent_id,
                    policy_id,
                    t=episode.length - 1,
                    eps_id=episode.episode_id,
                    agent_index=episode._agent_index(agent_id),
                    obs=last_observation,
                    actions=episode.last_action_for(agent_id),
                    rewards=rewards[env_id][agent_id],
                    prev_actions=episode.prev_action_for(agent_id),
                    prev_rewards=episode.prev_reward_for(agent_id),
                    dones=(False if (no_done_at_end or
                                     (hit_horizon and soft_horizon)) else
                           agent_done),
                    infos=infos[env_id].get(agent_id, {}),
                    new_obs=filtered_obs,
                    **episode.last_pi_info_for(agent_id))

        # Invoke the step callback after the step is logged to the episode
        if callbacks.get("on_episode_step"):
            callbacks["on_episode_step"]({"env": base_env, "episode": episode})

        # Cut the batch if we're not packing multiple episodes into one,
        # or if we've exceeded the requested batch size.
        if episode.batch_builder.has_pending_agent_data():
            if dones[env_id]["__all__"] and not no_done_at_end:
                episode.batch_builder.check_missing_dones()
            if (all_done and not pack) or \
                    episode.batch_builder.count >= rollout_fragment_length:
                outputs.append(episode.batch_builder.build_and_reset(episode))
            elif all_done:
                # Make sure postprocessor stays within one episode
                episode.batch_builder.postprocess_batch_so_far(episode)

        if all_done:
            # Handle episode termination
            batch_builder_pool.append(episode.batch_builder)
            # Call each policy's Exploration.on_episode_end method.
            for p in policies.values():
                p.exploration.on_episode_end(
                    policy=p,
                    environment=base_env,
                    episode=episode,
                    tf_sess=getattr(p, "_sess", None))
            # Call custom on_episode_end callback.
            if callbacks.get("on_episode_end"):
                callbacks["on_episode_end"]({
                    "env": base_env,
                    "policy": policies,
                    "episode": episode
                })
            if hit_horizon and soft_horizon:
                episode.soft_reset()
                resetted_obs = agent_obs
            else:
                del active_episodes[env_id]
                resetted_obs = base_env.try_reset(env_id)
            if resetted_obs is None:
                # Reset not supported, drop this env from the ready list
                if horizon != float("inf"):
                    raise ValueError(
                        "Setting episode horizon requires reset() support "
                        "from the environment.")
            elif resetted_obs != ASYNC_RESET_RETURN:
                # Creates a new episode if this is not async return
                # If reset is async, we will get its result in some future
                # poll
                episode = active_episodes[env_id]
                for agent_id, raw_obs in resetted_obs.items():
                    policy_id = episode.policy_for(agent_id)
                    policy = _get_or_raise(policies, policy_id)
                    prep_obs = _get_or_raise(preprocessors,
                                             policy_id).transform(raw_obs)
                    filtered_obs = _get_or_raise(obs_filters,
                                                 policy_id)(prep_obs)
                    episode._set_last_observation(agent_id, filtered_obs)
                    to_eval[policy_id].append(
                        PolicyEvalData(
                            env_id, agent_id, filtered_obs,
                            episode.last_info_for(agent_id) or {},
                            episode.rnn_state_for(agent_id),
                            np.zeros_like(
                                _flatten_action(
                                    policy.action_space.sample())), 0.0))

    return active_envs, to_eval, outputs

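
# --- Hedged sketch: dict-style callbacks used by the older variant above ------
# Not part of the original module. In this variant, `callbacks` is a plain dict
# of optional functions keyed by event name, each called with a single info
# dict (see the `callbacks["on_episode_step"]({...})` call above). The handler
# and the episode stand-in below are assumptions for illustration only.
def _example_on_episode_step(info):
    episode = info["episode"]
    # Record a custom metric on every env step.
    episode.custom_metrics.setdefault("steps_seen", 0)
    episode.custom_metrics["steps_seen"] += 1


class _EpisodeStandIn:
    def __init__(self):
        self.custom_metrics = {}


_example_callbacks = {"on_episode_step": _example_on_episode_step}
_ep = _EpisodeStandIn()
if _example_callbacks.get("on_episode_step"):
    _example_callbacks["on_episode_step"]({"env": None, "episode": _ep})
assert _ep.custom_metrics["steps_seen"] == 1
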
def _process_observations(
        worker: "RolloutWorker",
        base_env: BaseEnv,
        policies: Dict[PolicyID, Policy],
        batch_builder_pool: List[MultiAgentSampleBatchBuilder],
        active_episodes: Dict[str, MultiAgentEpisode],
        unfiltered_obs: Dict[EnvID, Dict[AgentID, EnvObsType]],
        rewards: Dict[EnvID, Dict[AgentID, float]],
        dones: Dict[EnvID, Dict[AgentID, bool]],
        infos: Dict[EnvID, Dict[AgentID, EnvInfoDict]],
        horizon: int,
        preprocessors: Dict[PolicyID, Preprocessor],
        obs_filters: Dict[PolicyID, Filter],
        rollout_fragment_length: int,
        pack_multiple_episodes_in_batch: bool,
        callbacks: "DefaultCallbacks",
        soft_horizon: bool,
        no_done_at_end: bool,
        observation_fn: "ObservationFunction",
        _use_trajectory_view_api: bool = False
) -> Tuple[Set[EnvID], Dict[PolicyID, List[PolicyEvalData]], List[Union[
        RolloutMetrics, SampleBatchType]]]:
    """Record new data from the environment and prepare for policy evaluation.

    Args:
        worker (RolloutWorker): Reference to the current rollout worker.
        base_env (BaseEnv): Env implementing BaseEnv.
        policies (dict): Map of policy ids to Policy instances.
        batch_builder_pool (List[SampleBatchBuilder]): List of pooled
            SampleBatchBuilder objects for recycling.
        active_episodes (Dict[str, MultiAgentEpisode]): Mapping from
            episode ID to currently ongoing MultiAgentEpisode object.
        unfiltered_obs (dict): Doubly keyed dict of env-ids -> agent ids
            -> unfiltered observation tensor, returned by a
            `BaseEnv.poll()` call.
        rewards (dict): Doubly keyed dict of env-ids -> agent ids ->
            rewards tensor, returned by a `BaseEnv.poll()` call.
        dones (dict): Doubly keyed dict of env-ids -> agent ids ->
            boolean done flags, returned by a `BaseEnv.poll()` call.
        infos (dict): Doubly keyed dict of env-ids -> agent ids ->
            info dicts, returned by a `BaseEnv.poll()` call.
        horizon (int): Horizon of the episode.
        preprocessors (dict): Map of policy id to preprocessor for the
            observations prior to filtering.
        obs_filters (dict): Map of policy id to filter used to process
            observations for the policy.
        rollout_fragment_length (int): Number of episode steps before
            `SampleBatch` is yielded. Set to infinity to yield complete
            episodes.
        pack_multiple_episodes_in_batch (bool): Whether to pack multiple
            episodes into each batch. This guarantees batches will be
            exactly `rollout_fragment_length` in size.
        callbacks (DefaultCallbacks): User callbacks to run on episode events.
        soft_horizon (bool): Calculate rewards but don't reset the
            environment when the horizon is hit.
        no_done_at_end (bool): Ignore the done=True at the end of the
            episode and instead record done=False.
        observation_fn (ObservationFunction): Optional multi-agent
            observation func to use for preprocessing observations.
        _use_trajectory_view_api (bool): Whether to use the (experimental)
            `_use_trajectory_view_api` to make generic trajectory views
            available to Models. Default: False.

    Returns:
        Tuple:
            - active_envs: Set of non-terminated env ids.
            - to_eval: Map of policy_id to list of agent PolicyEvalData.
            - outputs: List of metrics and samples to return from the sampler.
    """
    # Output objects.
    active_envs: Set[EnvID] = set()
    to_eval: Dict[PolicyID, List[PolicyEvalData]] = defaultdict(list)
    outputs: List[Union[RolloutMetrics, SampleBatchType]] = []

    large_batch_threshold: int = max(1000, rollout_fragment_length * 10) if \
        rollout_fragment_length != float("inf") else 5000

    # For each environment.
    # type: EnvID, Dict[AgentID, EnvObsType]
    for env_id, agent_obs in unfiltered_obs.items():
        is_new_episode: bool = env_id not in active_episodes
        episode: MultiAgentEpisode = active_episodes[env_id]

        if not is_new_episode:
            episode.length += 1
            episode.batch_builder.count += 1
            episode._add_agent_rewards(rewards[env_id])

        if (episode.batch_builder.total() > large_batch_threshold
                and log_once("large_batch_warning")):
            logger.warning(
                "More than {} observations for {} env steps ".format(
                    episode.batch_builder.total(),
                    episode.batch_builder.count) + "are buffered in "
                "the sampler. If this is more than you expected, check "
                "that you set a horizon on your environment correctly and "
                "that it terminates at some point. "
                "Note: In multi-agent environments, `rollout_fragment_length` "
                "sets the batch size based on environment steps, not the "
                "steps of individual agents, which can result in unexpectedly "
                "large batches. Also, you may be in evaluation waiting for "
                "your Env to terminate (batch_mode=`complete_episodes`). Make "
                "sure it does at some point.")

        # Check episode termination conditions.
        if dones[env_id]["__all__"] or episode.length >= horizon:
            hit_horizon = (episode.length >= horizon
                           and not dones[env_id]["__all__"])
            all_agents_done = True
            atari_metrics: List[RolloutMetrics] = _fetch_atari_metrics(
                base_env)
            if atari_metrics is not None:
                for m in atari_metrics:
                    outputs.append(
                        m._replace(custom_metrics=episode.custom_metrics))
            else:
                outputs.append(
                    RolloutMetrics(episode.length, episode.total_reward,
                                   dict(episode.agent_rewards),
                                   episode.custom_metrics, {},
                                   episode.hist_data))
        else:
            hit_horizon = False
            all_agents_done = False
            active_envs.add(env_id)

        # Custom observation function is applied before preprocessing.
        if observation_fn:
            agent_obs: Dict[AgentID, EnvObsType] = observation_fn(
                agent_obs=agent_obs,
                worker=worker,
                base_env=base_env,
                policies=policies,
                episode=episode)
            if not isinstance(agent_obs, dict):
                raise ValueError(
                    "observe() must return a dict of agent observations")

        # For each agent in the environment.
        # type: AgentID, EnvObsType
        for agent_id, raw_obs in agent_obs.items():
            assert agent_id != "__all__"
            policy_id: PolicyID = episode.policy_for(agent_id)
            prep_obs: EnvObsType = _get_or_raise(preprocessors,
                                                 policy_id).transform(raw_obs)
            if log_once("prep_obs"):
                logger.info("Preprocessed obs: {}".format(summarize(prep_obs)))

            filtered_obs: EnvObsType = _get_or_raise(obs_filters,
                                                     policy_id)(prep_obs)
            if log_once("filtered_obs"):
                logger.info("Filtered obs: {}".format(summarize(filtered_obs)))

            agent_done = bool(all_agents_done or dones[env_id].get(agent_id))
            if not agent_done:
                to_eval[policy_id].append(
                    PolicyEvalData(env_id, agent_id, filtered_obs,
                                   infos[env_id].get(agent_id, {}),
                                   episode.rnn_state_for(agent_id),
                                   episode.last_action_for(agent_id),
                                   rewards[env_id][agent_id] or 0.0))

            last_observation: EnvObsType = episode.last_observation_for(
                agent_id)
            episode._set_last_observation(agent_id, filtered_obs)
            episode._set_last_raw_obs(agent_id, raw_obs)
            episode._set_last_info(agent_id, infos[env_id].get(agent_id, {}))

            # Record transition info if applicable.
            if (last_observation is not None and infos[env_id].get(
                    agent_id, {}).get("training_enabled", True)):
                episode.batch_builder.add_values(
                    agent_id,
                    policy_id,
                    t=episode.length - 1,
                    eps_id=episode.episode_id,
                    agent_index=episode._agent_index(agent_id),
                    obs=last_observation,
                    actions=episode.last_action_for(agent_id),
                    rewards=rewards[env_id][agent_id],
                    prev_actions=episode.prev_action_for(agent_id),
                    prev_rewards=episode.prev_reward_for(agent_id),
                    dones=(False if (no_done_at_end or
                                     (hit_horizon and soft_horizon)) else
                           agent_done),
                    infos=infos[env_id].get(agent_id, {}),
                    new_obs=filtered_obs,
                    **episode.last_pi_info_for(agent_id))

        # Invoke the step callback after the step is logged to the episode
        callbacks.on_episode_step(
            worker=worker, base_env=base_env, episode=episode)

        # Cut the batch if ...
        # - all-agents-done and not packing multiple episodes into one
        #   (batch_mode="complete_episodes")
        # - or if we've exceeded the rollout_fragment_length.
        if episode.batch_builder.has_pending_agent_data():
            # Sanity check, whether all agents have done=True, if
            # done[__all__] is True.
            if dones[env_id]["__all__"] and not no_done_at_end:
                episode.batch_builder.check_missing_dones()

            # Reached end of episode and we are not allowed to pack the
            # next episode into the same SampleBatch -> Build the SampleBatch
            # and add it to "outputs".
            if (all_agents_done and not pack_multiple_episodes_in_batch) or \
                    episode.batch_builder.count >= rollout_fragment_length:
                outputs.append(episode.batch_builder.build_and_reset(episode))
            # Make sure postprocessor stays within one episode.
            elif all_agents_done:
                episode.batch_builder.postprocess_batch_so_far(episode)

        # Episode is done.
        if all_agents_done:
            # Handle episode termination.
            batch_builder_pool.append(episode.batch_builder)

            # Call each policy's Exploration.on_episode_end method.
            for p in policies.values():
                if getattr(p, "exploration", None) is not None:
                    p.exploration.on_episode_end(
                        policy=p,
                        environment=base_env,
                        episode=episode,
                        tf_sess=getattr(p, "_sess", None))
            # Call custom on_episode_end callback.
            callbacks.on_episode_end(
                worker=worker,
                base_env=base_env,
                policies=policies,
                episode=episode)

            if hit_horizon and soft_horizon:
                episode.soft_reset()
                resetted_obs: Dict[AgentID, EnvObsType] = agent_obs
            else:
                del active_episodes[env_id]
                resetted_obs: Dict[AgentID, EnvObsType] = base_env.try_reset(
                    env_id)
            if resetted_obs is None:
                # Reset not supported, drop this env from the ready list.
                if horizon != float("inf"):
                    raise ValueError(
                        "Setting episode horizon requires reset() support "
                        "from the environment.")
            elif resetted_obs != ASYNC_RESET_RETURN:
                # Creates a new episode if this is not async return.
                # If reset is async, we will get its result in some future
                # poll.
                episode: MultiAgentEpisode = active_episodes[env_id]
                if observation_fn:
                    resetted_obs: Dict[AgentID, EnvObsType] = observation_fn(
                        agent_obs=resetted_obs,
                        worker=worker,
                        base_env=base_env,
                        policies=policies,
                        episode=episode)
                # type: AgentID, EnvObsType
                for agent_id, raw_obs in resetted_obs.items():
                    policy_id: PolicyID = episode.policy_for(agent_id)
                    policy: Policy = _get_or_raise(policies, policy_id)
                    prep_obs: EnvObsType = _get_or_raise(
                        preprocessors, policy_id).transform(raw_obs)
                    filtered_obs: EnvObsType = _get_or_raise(
                        obs_filters, policy_id)(prep_obs)
                    episode._set_last_observation(agent_id, filtered_obs)
                    to_eval[policy_id].append(
                        PolicyEvalData(
                            env_id, agent_id, filtered_obs,
                            episode.last_info_for(agent_id) or {},
                            episode.rnn_state_for(agent_id),
                            np.zeros_like(
                                flatten_to_single_ndarray(
                                    policy.action_space.sample())), 0.0))

    return active_envs, to_eval, outputs

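
# --- Hedged sketch: how the large-batch warning threshold is derived ----------
# Not part of the original module; it simply restates the expression used in
# the variants above so the intent can be checked in isolation.
def _large_batch_threshold(rollout_fragment_length):
    if rollout_fragment_length != float("inf"):
        return max(1000, rollout_fragment_length * 10)
    # With complete_episodes sampling (infinite fragment length), a fixed cap
    # is used instead.
    return 5000


assert _large_batch_threshold(200) == 2000
assert _large_batch_threshold(50) == 1000
assert _large_batch_threshold(float("inf")) == 5000
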
def _process_observations(base_env, policies, batch_builder_pool,
                          active_episodes, unfiltered_obs, rewards, dones,
                          infos, off_policy_actions, horizon, preprocessors,
                          obs_filters, unroll_length, pack, callbacks,
                          soft_horizon, no_done_at_end):
    """Record new data from the environment and prepare for policy evaluation.

    Returns:
        active_envs: set of non-terminated env ids
        to_eval: map of policy_id to list of agent PolicyEvalData
        outputs: list of metrics and samples to return from the sampler
    """

    global i
    global tmp_dic
    global traffic_light_node_dict
    i += 1

    def inter_num_2_id(num):
        return list(tmp_dic.keys())[list(tmp_dic.values()).index(num)]

    def read_traffic_light_node_dict():
        path_to_read = os.path.join(record_dir,
                                    'traffic_light_node_dict.conf')
        with open(path_to_read, 'r') as f:
            traffic_light_node_dict = eval(f.read())
        print("Read traffic_light_node_dict")
        return traffic_light_node_dict

    if i <= 1:
        # On the first call, read the neighbor layout from the config file.
        record_dir = base_env.envs[0].record_dir
        traffic_light_node_dict = base_env.envs[0].traffic_light_node_dict
        tmp_dic = traffic_light_node_dict['intersection_1_1'][
            'inter_id_to_index']

    active_envs = set()
    to_eval = defaultdict(list)
    outputs = []

    # For each environment
    for env_id, agent_obs in unfiltered_obs.items():
        new_episode = env_id not in active_episodes
        episode = active_episodes[env_id]
        if not new_episode:
            episode.length += 1
            episode.batch_builder.count += 1
            episode._add_agent_rewards(rewards[env_id])

        if (episode.batch_builder.total() > max(1000, unroll_length * 10)
                and log_once("large_batch_warning")):
            logger.warning(
                "More than {} observations for {} env steps ".format(
                    episode.batch_builder.total(),
                    episode.batch_builder.count) + "are buffered in "
                "the sampler. If this is more than you expected, check that "
                "you set a horizon on your environment correctly. Note "
                "that in multi-agent environments, `sample_batch_size` sets "
                "the batch size based on environment steps, not the steps of "
                "individual agents, which can result in unexpectedly large "
                "batches.")

        # Check episode termination conditions
        if dones[env_id]["__all__"] or episode.length >= horizon:
            hit_horizon = (episode.length >= horizon
                           and not dones[env_id]["__all__"])
            all_done = True
            atari_metrics = _fetch_atari_metrics(base_env)
            if atari_metrics is not None:
                for m in atari_metrics:
                    outputs.append(
                        m._replace(custom_metrics=episode.custom_metrics))
            else:
                outputs.append(
                    RolloutMetrics(episode.length, episode.total_reward,
                                   dict(episode.agent_rewards),
                                   episode.custom_metrics, {}))
        else:
            hit_horizon = False
            all_done = False
            active_envs.add(env_id)

        # For each agent in the environment
        for agent_id, raw_obs in agent_obs.items():
            policy_id = episode.policy_for(agent_id)  # e.g. "policy_0"
            # print(policy_id)
            prep_obs = _get_or_raise(preprocessors,
                                     policy_id).transform(raw_obs)
            if log_once("prep_obs"):
                logger.info("Preprocessed obs: {}".format(summarize(prep_obs)))

            filtered_obs = _get_or_raise(obs_filters, policy_id)(prep_obs)
            '''
            For Attention!
            The online Q-eval happens here, so the neighbor_obs values have
            to be fed to the Q-eval network.
            '''
            # Using the road-network relations stored in
            # traffic_light_node_dict, find the neighbors of the current
            # policy_id and store them in the "policy_0" form.
            neighbor_pid_list = [
                'policy_{}'.format(pid_)
                for pid_ in traffic_light_node_dict[inter_num_2_id(
                    int(policy_id.split('_')[1]))]['adjacency_row']
                if pid_ is not None
            ]
            # print(neighbor_pid_list)

            neighbor_obs = []
            neighbor_obs.append([])
            # Size: (1, 5, 15). Only this shape can be fed into the
            # neighbor_obs placeholder of shape (batch, 5, 15).
            for neighbor_id in neighbor_pid_list:
                neighbor_prep_obs = _get_or_raise(
                    preprocessors, neighbor_id).transform(raw_obs)
                neighbor_filtered_obs = _get_or_raise(
                    obs_filters, neighbor_id)(neighbor_prep_obs)
                neighbor_obs[0].append(neighbor_filtered_obs)
            neighbor_obs = np.array(neighbor_obs).reshape(
                (len(neighbor_pid_list), len(raw_obs)))  # (5, 29)
            # ------------------------------------------------------------------

            if log_once("filtered_obs"):
                logger.info("Filtered obs: {}".format(summarize(filtered_obs)))

            agent_done = bool(all_done or dones[env_id].get(agent_id))
            if not agent_done:
                to_eval[policy_id].append(
                    PolicyEvalData(env_id, agent_id, filtered_obs,
                                   neighbor_obs,
                                   infos[env_id].get(agent_id, {}),
                                   episode.rnn_state_for(agent_id),
                                   episode.last_action_for(agent_id),
                                   rewards[env_id][agent_id] or 0.0))

            last_observation = episode.last_observation_for(agent_id)
            episode._set_last_observation(agent_id, filtered_obs)
            episode._set_last_raw_obs(agent_id, raw_obs)
            episode._set_last_info(agent_id, infos[env_id].get(agent_id, {}))

            # Record transition info if applicable
            if (last_observation is not None and infos[env_id].get(
                    agent_id, {}).get("training_enabled", True)):
                episode.batch_builder.add_values(
                    agent_id,
                    policy_id,
                    t=episode.length - 1,
                    eps_id=episode.episode_id,
                    agent_index=episode._agent_index(agent_id),
                    obs=last_observation,
                    actions=episode.last_action_for(agent_id),
                    rewards=rewards[env_id][agent_id],
                    prev_actions=episode.prev_action_for(agent_id),
                    prev_rewards=episode.prev_reward_for(agent_id),
                    dones=(False if (no_done_at_end or
                                     (hit_horizon and soft_horizon)) else
                           agent_done),
                    infos=infos[env_id].get(agent_id, {}),
                    new_obs=filtered_obs,
                    **episode.last_pi_info_for(agent_id))

        # Invoke the step callback after the step is logged to the episode
        if callbacks.get("on_episode_step"):
            callbacks["on_episode_step"]({"env": base_env, "episode": episode})

        # Cut the batch if we're not packing multiple episodes into one,
        # or if we've exceeded the requested batch size.
        if episode.batch_builder.has_pending_data():
            if dones[env_id]["__all__"] and not no_done_at_end:
                episode.batch_builder.check_missing_dones()
            if (all_done and not pack) or \
                    episode.batch_builder.count >= unroll_length:
                outputs.append(episode.batch_builder.build_and_reset(episode))
            elif all_done:
                # Make sure postprocessor stays within one episode
                episode.batch_builder.postprocess_batch_so_far(episode)

        if all_done:
            # Handle episode termination
            batch_builder_pool.append(episode.batch_builder)
            if callbacks.get("on_episode_end"):
                callbacks["on_episode_end"]({
                    "env": base_env,
                    "policy": policies,
                    "episode": episode
                })
            if hit_horizon and soft_horizon:
                episode.soft_reset()
                resetted_obs = agent_obs
            else:
                del active_episodes[env_id]
                resetted_obs = base_env.try_reset(env_id)
            if resetted_obs is None:
                # Reset not supported, drop this env from the ready list
                if horizon != float("inf"):
                    raise ValueError(
                        "Setting episode horizon requires reset() support "
                        "from the environment.")
            elif resetted_obs != ASYNC_RESET_RETURN:
                # Creates a new episode if this is not async return
                # If reset is async, we will get its result in some future
                # poll
                episode = active_episodes[env_id]
                for agent_id, raw_obs in resetted_obs.items():
                    policy_id = episode.policy_for(agent_id)  # e.g. "policy_0"
                    policy = _get_or_raise(policies, policy_id)
                    prep_obs = _get_or_raise(preprocessors,
                                             policy_id).transform(raw_obs)
                    filtered_obs = _get_or_raise(obs_filters,
                                                 policy_id)(prep_obs)
                    # print('policy_id' + str(policy_id))
                    # print('filtered_obs' + str(filtered_obs))
                    '''
                    For Attention!
                    The episode has terminated here and a new episode is
                    created. The online Q-eval happens here as well, so the
                    neighbor_obs values have to be fed to the Q-eval network.
                    '''
                    # Using the road-network relations stored in
                    # traffic_light_node_dict, find the neighbors of the
                    # current policy_id and store them in the "policy_0" form.
                    neighbor_pid_list = [
                        'policy_{}'.format(pid_)
                        for pid_ in traffic_light_node_dict[inter_num_2_id(
                            int(policy_id.split('_')[1]))]['adjacency_row']
                        if pid_ is not None
                    ]
                    # print(neighbor_pid_list)

                    neighbor_obs = []
                    neighbor_obs.append([])
                    # Size: (1, 5, 29). Only this shape can be fed into the
                    # neighbor_obs placeholder of shape (batch, 5, 17).
                    for neighbor_id in neighbor_pid_list:
                        neighbor_prep_obs = _get_or_raise(
                            preprocessors, neighbor_id).transform(raw_obs)
                        neighbor_filtered_obs = _get_or_raise(
                            obs_filters, neighbor_id)(neighbor_prep_obs)
                        neighbor_obs[0].append(neighbor_filtered_obs)
                    neighbor_obs = np.squeeze(np.array(neighbor_obs))
                    # --------------------------------------------------------

                    episode._set_last_observation(agent_id, filtered_obs)
                    to_eval[policy_id].append(
                        PolicyEvalData(
                            env_id, agent_id, filtered_obs, neighbor_obs,
                            episode.last_info_for(agent_id) or {},
                            episode.rnn_state_for(agent_id),
                            np.zeros_like(
                                _flatten_action(
                                    policy.action_space.sample())), 0.0))

    return active_envs, to_eval, outputs

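
# --- Hedged sketch: assembling the neighbor observation block ------------------
# Not part of the original module. It mirrors the reshaping done in the fork
# above: one filtered observation per neighbor policy is stacked into a
# (num_neighbors, obs_dim) array before being handed to PolicyEvalData. The
# helper name and the example shapes are assumptions for illustration only.
import numpy as np


def _stack_neighbor_obs(neighbor_filtered_obs_list):
    # Shape: (num_neighbors, obs_dim), matching the reshape used above.
    return np.asarray(neighbor_filtered_obs_list, dtype=np.float32)


_stacked = _stack_neighbor_obs([np.zeros(29), np.ones(29)])
assert _stacked.shape == (2, 29)
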
def _process_observations(base_env, policies, policies_to_train,
                          dead_policies, policy_config, observation_filter,
                          tf_sess, batch_builder_pool, active_episodes,
                          unfiltered_obs, rewards, dones, infos,
                          off_policy_actions, horizon, preprocessors,
                          obs_filters, unroll_length, pack, callbacks,
                          soft_horizon, no_done_at_end):  #===MOD===
    """Record new data from the environment and prepare for policy evaluation.

    Returns:
        active_envs: set of non-terminated env ids
        to_eval: map of policy_id to list of agent PolicyEvalData
        outputs: list of metrics and samples to return from the sampler
    """

    active_envs = set()
    to_eval = defaultdict(list)
    outputs = []

    # For each environment
    for env_id, agent_obs in unfiltered_obs.items():
        new_episode = env_id not in active_episodes
        episode = active_episodes[env_id]
        if not new_episode:
            episode.length += 1
            episode.batch_builder.count += 1
            episode._add_agent_rewards(rewards[env_id])

        if (episode.batch_builder.total() > max(1000, unroll_length * 10)
                and log_once("large_batch_warning")):
            logger.warning(
                "More than {} observations for {} env steps ".format(
                    episode.batch_builder.total(),
                    episode.batch_builder.count) + "are buffered in "
                "the sampler. If this is more than you expected, check that "
                "you set a horizon on your environment correctly. Note "
                "that in multi-agent environments, `sample_batch_size` sets "
                "the batch size based on environment steps, not the steps of "
                "individual agents, which can result in unexpectedly large "
                "batches.")

        # Check episode termination conditions
        if dones[env_id]["__all__"] or episode.length >= horizon:
            # DEBUG
            # print("Trying to terminate.")
            # print("Dones of __all__ is set:", dones[env_id]["__all__"])
            # print("Horizon hit:", episode.length >= horizon)
            hit_horizon = (episode.length >= horizon
                           and not dones[env_id]["__all__"])
            all_done = True
            atari_metrics = _fetch_atari_metrics(base_env)
            if atari_metrics is not None:
                for m in atari_metrics:
                    outputs.append(
                        m._replace(custom_metrics=episode.custom_metrics))
            else:
                outputs.append(
                    RolloutMetrics(episode.length, episode.total_reward,
                                   dict(episode.agent_rewards),
                                   episode.custom_metrics, {}))
        else:
            hit_horizon = False
            all_done = False
            active_envs.add(env_id)

        #===MOD===
        additional_builders_ids = set()
        #===MOD===

        # For each agent in the environment
        for agent_id, raw_obs in agent_obs.items():
            #===MOD===
            policy_id, policy_constructor_tuple = episode.policy_for(agent_id)
            pols_tuple = generate_policies(
                policy_id,
                policy_constructor_tuple,
                policies,
                policies_to_train,
                dead_policies,
                policy_config,
                preprocessors,
                obs_filters,
                observation_filter,
                tf_sess,
            )
            policies, preprocessors, obs_filters, policies_to_train, \
                dead_policies = pols_tuple
            #===MOD===

            prep_obs = _get_or_raise(preprocessors,
                                     policy_id).transform(raw_obs)
            if log_once("prep_obs"):
                logger.info("Preprocessed obs: {}".format(summarize(prep_obs)))

            filtered_obs = _get_or_raise(obs_filters, policy_id)(prep_obs)
            if log_once("filtered_obs"):
                logger.info("Filtered obs: {}".format(summarize(filtered_obs)))

            agent_done = bool(all_done or dones[env_id].get(agent_id))
            if not agent_done:
                to_eval[policy_id].append(
                    PolicyEvalData(env_id, agent_id, filtered_obs,
                                   infos[env_id].get(agent_id, {}),
                                   episode.rnn_state_for(agent_id),
                                   episode.last_action_for(agent_id),
                                   rewards[env_id][agent_id] or 0.0))

            last_observation = episode.last_observation_for(agent_id)
            episode._set_last_observation(agent_id, filtered_obs)
            episode._set_last_raw_obs(agent_id, raw_obs)
            episode._set_last_info(agent_id, infos[env_id].get(agent_id, {}))

            # Record transition info if applicable
            if (last_observation is not None and infos[env_id].get(
                    agent_id, {}).get("training_enabled", True)):
                #===MOD===
                additional_builders_ids.add(agent_id)
                #===MOD===
                episode.batch_builder.add_values(
                    agent_id,
                    policy_id,
                    t=episode.length - 1,
                    eps_id=episode.episode_id,
                    agent_index=episode._agent_index(agent_id),
                    obs=last_observation,
                    actions=episode.last_action_for(agent_id),
                    rewards=rewards[env_id][agent_id],
                    prev_actions=episode.prev_action_for(agent_id),
                    prev_rewards=episode.prev_reward_for(agent_id),
                    dones=(False if (no_done_at_end or
                                     (hit_horizon and soft_horizon)) else
                           agent_done),
                    infos=infos[env_id].get(agent_id, {}),
                    new_obs=filtered_obs,
                    **episode.last_pi_info_for(agent_id))

            #===MOD===
            if agent_done:
                # Does it make sense to remove agent id from `agent_builders`?
                dead_policies.add(policy_id)
                print("Removing agent id from agent builders: %s" %
                      str(agent_id))
                episode.batch_builder.agent_builders.pop(agent_id)
                if policy_id in to_eval:
                    to_eval.pop(policy_id)
                    # print("Popping policy id from toeval.")
            #===MOD===

        #===MOD===
        start = time.time()
        #===MOD===
        print("sampler.py: ids added to agent builders:\t",
              additional_builders_ids)
        # Update ``self.policy_map`` in ``MultiAgentSampleBatchBuilder``.
        # TODO: policies is not being pruned in this file.
        episode.batch_builder.policy_map = policies
        print("sampler.py: policies: \t", policies.keys())
        #===MOD===

        # Invoke the step callback after the step is logged to the episode
        if callbacks.get("on_episode_step"):
            callbacks["on_episode_step"]({"env": base_env, "episode": episode})

        # Cut the batch if we're not packing multiple episodes into one,
        # or if we've exceeded the requested batch size.
        if episode.batch_builder.has_pending_data():
            if dones[env_id]["__all__"] and not no_done_at_end:
                episode.batch_builder.check_missing_dones()
            if (all_done and not pack) or \
                    episode.batch_builder.count >= unroll_length:
                outputs.append(episode.batch_builder.build_and_reset(episode))
            elif all_done:
                # Make sure postprocessor stays within one episode
                # KEYERROR
                episode.batch_builder.postprocess_batch_so_far(episode)

        if all_done:
            # Handle episode termination
            batch_builder_pool.append(episode.batch_builder)
            if callbacks.get("on_episode_end"):
                callbacks["on_episode_end"]({
                    "env": base_env,
                    "policy": policies,
                    "episode": episode
                })
            if hit_horizon and soft_horizon:
                episode.soft_reset()
                resetted_obs = agent_obs
            else:
                del active_episodes[env_id]
                resetted_obs = base_env.try_reset(env_id)
            if resetted_obs is None:
                # Reset not supported, drop this env from the ready list
                if horizon != float("inf"):
                    raise ValueError(
                        "Setting episode horizon requires reset() support "
                        "from the environment.")
            elif resetted_obs != ASYNC_RESET_RETURN:
                # print("Executing new episode non-async return.")
                time.sleep(1)
                raise NotImplementedError(
                    "Multiple episodes not supported by design.")
                # Creates a new episode if this is not async return
                # If reset is async, we will get its result in some future
                # poll
                episode = active_episodes[env_id]
                for agent_id, raw_obs in resetted_obs.items():
                    #===MOD===
                    policy_id, policy_constructor_tuple = episode.policy_for(
                        agent_id)
                    # with tf_sess.as_default():
                    pols_tuple = generate_policies(
                        policy_id,
                        policy_constructor_tuple,
                        policies,
                        policies_to_train,
                        dead_policies,
                        policy_config,
                        preprocessors,
                        obs_filters,
                        observation_filter,
                        tf_sess,
                    )
                    policies, preprocessors, obs_filters, policies_to_train, \
                        dead_policies = pols_tuple
                    #===MOD===
                    policy = _get_or_raise(policies, policy_id)
                    prep_obs = _get_or_raise(preprocessors,
                                             policy_id).transform(raw_obs)
                    filtered_obs = _get_or_raise(obs_filters,
                                                 policy_id)(prep_obs)
                    episode._set_last_observation(agent_id, filtered_obs)
                    to_eval[policy_id].append(
                        PolicyEvalData(
                            env_id, agent_id, filtered_obs,
                            episode.last_info_for(agent_id) or {},
                            episode.rnn_state_for(agent_id),
                            np.zeros_like(
                                _flatten_action(
                                    policy.action_space.sample())), 0.0))

    #===MOD===
    pols_tuple = (policies, preprocessors, obs_filters, policies_to_train,
                  dead_policies)
    #===MOD===
    #===MOD===
    return active_envs, to_eval, outputs, pols_tuple

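
# --- Hedged sketch: unpacking the ===MOD=== variant's extended return value ---
# Not part of the original module. The fork above returns a fourth element,
# pols_tuple, so a caller has to unpack four values instead of the usual three.
# `_fake_process_observations` is an invented stand-in used only to show the
# shape of that return value.
def _fake_process_observations():
    active_envs, to_eval, outputs = {0}, {}, []
    # Ordering mirrors the pols_tuple built above: policies, preprocessors,
    # obs_filters, policies_to_train, dead_policies.
    pols_tuple = ({}, {}, {}, set(), set())
    return active_envs, to_eval, outputs, pols_tuple


_active_envs, _to_eval, _outputs, (_policies, _preprocessors, _obs_filters,
                                   _policies_to_train, _dead_policies) = \
    _fake_process_observations()
assert isinstance(_dead_policies, set)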