def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID,
                 env_id: EnvID, policy_id: PolicyID, t: int,
                 init_obs: TensorType) -> None:
    # Make sure our mappings are up to date.
    agent_key = (episode.episode_id, agent_id)
    if agent_key not in self.agent_key_to_policy_id:
        self.agent_key_to_policy_id[agent_key] = policy_id
    else:
        assert self.agent_key_to_policy_id[agent_key] == policy_id
    policy = self.policy_map[policy_id]
    view_reqs = policy.model.view_requirements if \
        getattr(policy, "model", None) else policy.view_requirements

    # Add initial obs to Trajectory.
    assert agent_key not in self.agent_collectors
    # TODO: determine exact shift-before based on the view-req shifts.
    self.agent_collectors[agent_key] = _AgentCollector(view_reqs)
    self.agent_collectors[agent_key].add_init_obs(
        episode_id=episode.episode_id,
        agent_index=episode._agent_index(agent_id),
        env_id=env_id,
        t=t,
        init_obs=init_obs)

    self.episodes[episode.episode_id] = episode
    if episode.batch_builder is None:
        episode.batch_builder = self.policy_collector_groups.pop() if \
            self.policy_collector_groups else _PolicyCollectorGroup(
                self.policy_map)

    self._add_to_next_inference_call(agent_key)
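In isolation, the `view_reqs` fallback above reads as: prefer the model's view requirements whenever the policy exposes a model, else fall back to the policy's own. A runnable miniature of that pattern (the stub classes below are illustrative, not RLlib types):

class _Model:
    view_requirements = {"obs": "model-level"}

class _PolicyWithModel:
    model = _Model()
    view_requirements = {"obs": "policy-level"}

# Same conditional as in add_init_obs() above: the model's
# view_requirements win when a model attribute is present.
policy = _PolicyWithModel()
view_reqs = policy.model.view_requirements if \
    getattr(policy, "model", None) else policy.view_requirements
assert view_reqs == {"obs": "model-level"}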
def new_episode(env_id):
    episode = MultiAgentEpisode(
        worker.policy_map,
        worker.policy_mapping_fn,
        get_batch_builder,
        extra_batch_callback,
        env_id=env_id)
    # Call each policy's Exploration.on_episode_start method.
    # Note: This may break the exploration (e.g. ParameterNoise) of
    # policies in the `policy_map` that have not been recently used
    # (and are therefore stashed to disk). However, we certainly do not
    # want to loop through all (even stashed) policies here as that
    # would counter the purpose of the LRU policy caching.
    for p in worker.policy_map.cache.values():
        if getattr(p, "exploration", None) is not None:
            p.exploration.on_episode_start(
                policy=p,
                environment=base_env,
                episode=episode,
                tf_sess=p.get_session())
    callbacks.on_episode_start(
        worker=worker,
        base_env=base_env,
        policies=worker.policy_map,
        episode=episode,
        env_index=env_id,
    )
    return episode
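A sketch of how an `env_id`-taking factory like `new_episode` above is typically consumed: a `defaultdict` subclass whose `__missing__` forwards the missing key (the env id) to the factory, so each env index lazily gets its own episode. `_EpisodeDefaultDict` and `_new_episode` are hypothetical stand-ins for illustration, not RLlib APIs:

from collections import defaultdict

class _EpisodeDefaultDict(defaultdict):
    # Unlike a plain defaultdict, pass the missing key to the factory.
    def __missing__(self, env_id):
        episode = self.default_factory(env_id)
        self[env_id] = episode
        return episode

# Stub factory standing in for `new_episode(env_id)` above.
def _new_episode(env_id):
    return {"env_id": env_id}

active_episodes = _EpisodeDefaultDict(_new_episode)
assert active_episodes[3] == {"env_id": 3}  # built on first access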
def new_episode():
    episode = MultiAgentEpisode(policies, policy_mapping_fn,
                                get_batch_builder, extra_batch_callback)
    if callbacks.get("on_episode_start"):
        callbacks["on_episode_start"]({
            "env": async_vector_env,
            "episode": episode
        })
    return episode
def on_episode_start(
        worker: RolloutWorker,
        base_env: BaseEnv,
        policies: Dict[PolicyID, Policy],
        episode: MultiAgentEpisode,
        env_index: int,
        **kwargs,
):
    episode.user_data["ego_speed"] = []
def on_episode_step(
        worker: RolloutWorker,
        base_env: BaseEnv,
        episode: MultiAgentEpisode,
        env_index: int,
        **kwargs,
):
    # Assumes a single-agent episode: grab the only agent's raw
    # (un-preprocessed) observation and record its speed.
    single_agent_id = list(episode._agent_to_last_obs)[0]
    obs = episode.last_raw_obs_for(single_agent_id)
    episode.user_data["ego_speed"].append(obs["speed"])
import numpy as np

def on_episode_end(
        worker: RolloutWorker,
        base_env: BaseEnv,
        policies: Dict[PolicyID, Policy],
        episode: MultiAgentEpisode,
        env_index: int,
        **kwargs,
):
    mean_ego_speed = np.mean(episode.user_data["ego_speed"])
    print(f"ep. {episode.episode_id:<12} ended;"
          f" length={episode.length:<6}"
          f" mean_ego_speed={mean_ego_speed:.2f}")
    episode.custom_metrics["mean_ego_speed"] = mean_ego_speed
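A minimal sketch of wiring the three standalone callbacks above into a trainer, assuming the Ray 1.x-era `DefaultCallbacks` API; the class name `MyCallbacks` and the PPO/CartPole config line are illustrative assumptions, not from the original:

from ray import tune
from ray.rllib.agents.callbacks import DefaultCallbacks

class MyCallbacks(DefaultCallbacks):
    # Reuse the module-level functions above as static methods; their
    # signatures already match the keyword arguments RLlib passes in.
    on_episode_start = staticmethod(on_episode_start)
    on_episode_step = staticmethod(on_episode_step)
    on_episode_end = staticmethod(on_episode_end)

# tune.run("PPO", config={"env": "CartPole-v0", "callbacks": MyCallbacks})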
def new_episode():
    episode = MultiAgentEpisode(policies, policy_mapping_fn,
                                get_batch_builder, extra_batch_callback)
    # Call each policy's Exploration.on_episode_start method.
    for p in policies.values():
        p.exploration.on_episode_start(
            policy=p,
            environment=base_env,
            episode=episode,
            tf_sess=getattr(p, "_sess", None))
    callbacks.on_episode_start(
        worker=worker,
        base_env=base_env,
        policies=policies,
        episode=episode)
    return episode
def new_episode():
    episode = MultiAgentEpisode(policies, policy_mapping_fn,
                                get_batch_builder, extra_batch_callback)
    # Call each policy's Exploration.on_episode_start method.
    for p in policies.values():
        p.exploration.on_episode_start(
            policy=p,
            environment=base_env,
            episode=episode,
            tf_sess=getattr(p, "_sess", None))
    # Call custom on_episode_start callback.
    if callbacks.get("on_episode_start"):
        callbacks["on_episode_start"]({
            "env": base_env,
            "policy": policies,
            "episode": episode,
        })
    return episode
def compute_actions_from_input_dict(self,
                                    input_dict,
                                    explore=None,
                                    timestep=None,
                                    episodes=None,
                                    **kwargs):
    obs_batch = input_dict["obs"]
    # In policy loss initialization phase, no episodes are passed in.
    if episodes is not None:
        # Pretend we did a model-based rollout and want to return
        # the extra trajectory.
        env_id = episodes[0].env_id
        fake_eps = MultiAgentEpisode(
            episodes[0].policy_map, episodes[0]._policy_mapping_fn,
            lambda: None, lambda x: None, env_id)
        builder = get_global_worker().sampler.sample_collector
        agent_id = "extra_0"
        policy_id = "p1"  # use p1 so we can easily check it
        builder.add_init_obs(fake_eps, agent_id, env_id, policy_id,
                             -1, obs_batch[0])
        for t in range(4):
            builder.add_action_reward_next_obs(
                episode_id=fake_eps.episode_id,
                agent_id=agent_id,
                env_id=env_id,
                policy_id=policy_id,
                agent_done=t == 3,
                values=dict(
                    t=t,
                    actions=0,
                    rewards=0,
                    dones=t == 3,
                    infos={},
                    new_obs=obs_batch[0]))
        batch = builder.postprocess_episode(episode=fake_eps, build=True)
        episodes[0].add_extra_batch(batch)

    # Just return zeros for actions.
    return [0] * len(obs_batch), [], {}
def new_episode(env_id):
    episode = MultiAgentEpisode(
        worker.policy_map,
        worker.policy_mapping_fn,
        get_batch_builder,
        extra_batch_callback,
        env_id=env_id)
    # Call each policy's Exploration.on_episode_start method.
    # type: Policy
    for p in worker.policy_map.values():
        if getattr(p, "exploration", None) is not None:
            p.exploration.on_episode_start(
                policy=p,
                environment=base_env,
                episode=episode,
                tf_sess=p.get_session())
    callbacks.on_episode_start(
        worker=worker,
        base_env=base_env,
        policies=worker.policy_map,
        episode=episode,
        env_index=env_id,
    )
    return episode
def new_episode():
    return MultiAgentEpisode(policies, policy_mapping_fn,
                             get_batch_builder, extra_batch_callback)
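Zero-argument factories like this one are typically handed to a plain `defaultdict`, so an episode is created lazily the first time an env id is looked up and then reused. A runnable sketch; `_make_episode` is a stub standing in for `new_episode`:

from collections import defaultdict

def _make_episode():
    return object()  # stand-in for the MultiAgentEpisode built above

active_episodes = defaultdict(_make_episode)
episode = active_episodes["env_0"]  # created lazily on first access
assert active_episodes["env_0"] is episode  # cached and reused afterwards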