Example 1
    def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID,
                     env_id: EnvID, policy_id: PolicyID, t: int,
                     init_obs: TensorType) -> None:
        # Make sure our mappings are up to date.
        agent_key = (episode.episode_id, agent_id)
        if agent_key not in self.agent_key_to_policy_id:
            self.agent_key_to_policy_id[agent_key] = policy_id
        else:
            assert self.agent_key_to_policy_id[agent_key] == policy_id
        policy = self.policy_map[policy_id]
        view_reqs = policy.model.view_requirements if \
            getattr(policy, "model", None) else policy.view_requirements

        # Add initial obs to Trajectory.
        assert agent_key not in self.agent_collectors
        # TODO: determine exact shift-before based on the view-req shifts.
        self.agent_collectors[agent_key] = _AgentCollector(view_reqs)
        self.agent_collectors[agent_key].add_init_obs(
            episode_id=episode.episode_id,
            agent_index=episode._agent_index(agent_id),
            env_id=env_id,
            t=t,
            init_obs=init_obs)

        self.episodes[episode.episode_id] = episode
        if episode.batch_builder is None:
            episode.batch_builder = self.policy_collector_groups.pop() if \
                self.policy_collector_groups else _PolicyCollectorGroup(
                self.policy_map)

        self._add_to_next_inference_call(agent_key)
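Note: the collector interface defined here is exercised end-to-end in Example 9 below. As a condensed, hypothetical sketch of that caller-side pattern (the collector argument and all agent/policy names are placeholders, not part of the original example), one agent's trajectory is recorded with an initial observation followed by one row per step:

def record_single_agent_rollout(collector, episode, observations, actions,
                                rewards, agent_id="agent_0",
                                policy_id="default_policy", env_id=0):
    # Hedged sketch: `collector` is assumed to expose the add_init_obs /
    # add_action_reward_next_obs / postprocess_episode API shown in
    # Examples 1 and 9.
    collector.add_init_obs(episode, agent_id, env_id, policy_id,
                           t=-1, init_obs=observations[0])
    last = len(actions) - 1
    for t, (act, rew, new_obs) in enumerate(
            zip(actions, rewards, observations[1:])):
        collector.add_action_reward_next_obs(
            episode_id=episode.episode_id,
            agent_id=agent_id,
            env_id=env_id,
            policy_id=policy_id,
            agent_done=(t == last),
            values=dict(t=t, actions=act, rewards=rew, dones=(t == last),
                        infos={}, new_obs=new_obs))
    # Build and return the finished batch for this episode.
    return collector.postprocess_episode(episode=episode, build=True)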
Example 2
File: sampler.py Project: Yard1/ray
 def new_episode(env_id):
     episode = MultiAgentEpisode(worker.policy_map,
                                 worker.policy_mapping_fn,
                                 get_batch_builder,
                                 extra_batch_callback,
                                 env_id=env_id)
     # Call each policy's Exploration.on_episode_start method.
     # Note: This may break the exploration (e.g. ParameterNoise) of
     # policies in the `policy_map` that have not been recently used
     # (and are therefore stashed to disk). However, we certainly do not
     # want to loop through all (even stashed) policies here as that
     # would counter the purpose of the LRU policy caching.
     for p in worker.policy_map.cache.values():
         if getattr(p, "exploration", None) is not None:
             p.exploration.on_episode_start(policy=p,
                                            environment=base_env,
                                            episode=episode,
                                            tf_sess=p.get_session())
     callbacks.on_episode_start(
         worker=worker,
         base_env=base_env,
         policies=worker.policy_map,
         episode=episode,
         env_index=env_id,
     )
     return episode
Example 3
 def new_episode():
     episode = MultiAgentEpisode(policies, policy_mapping_fn,
                                 get_batch_builder, extra_batch_callback)
     if callbacks.get("on_episode_start"):
         callbacks["on_episode_start"]({
             "env": async_vector_env,
             "episode": episode
         })
     return episode
Example 4
    def on_episode_start(
        worker: RolloutWorker,
        base_env: BaseEnv,
        policies: Dict[PolicyID, Policy],
        episode: MultiAgentEpisode,
        env_index: int,
        **kwargs,
    ):
        episode.user_data["ego_speed"] = []
Example 5
    def on_episode_step(
        worker: RolloutWorker,
        base_env: BaseEnv,
        episode: MultiAgentEpisode,
        env_index: int,
        **kwargs,
    ):
        single_agent_id = list(episode._agent_to_last_obs)[0]
        obs = episode.last_raw_obs_for(single_agent_id)
        episode.user_data["ego_speed"].append(obs["speed"])
Example 6
    def on_episode_end(
        worker: RolloutWorker,
        base_env: BaseEnv,
        policies: Dict[PolicyID, Policy],
        episode: MultiAgentEpisode,
        env_index: int,
        **kwargs,
    ):
        mean_ego_speed = np.mean(episode.user_data["ego_speed"])
        print(f"ep. {episode.episode_id:<12} ended;"
              f" length={episode.length:<6}"
              f" mean_ego_speed={mean_ego_speed:.2f}")
        episode.custom_metrics["mean_ego_speed"] = mean_ego_speed
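Examples 4-6 are three pieces of one metrics callback. Under RLlib's class-based callbacks API (Ray 1.x), they would normally live on a DefaultCallbacks subclass registered through the trainer config; a minimal sketch is shown below (the PPO trainer and the "my_env" name are assumptions, not taken from the original examples):

import numpy as np
from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.agents.ppo import PPOTrainer  # any trainer class would do

class EgoSpeedCallbacks(DefaultCallbacks):
    def on_episode_start(self, *, worker, base_env, policies, episode,
                         env_index=None, **kwargs):
        episode.user_data["ego_speed"] = []

    def on_episode_step(self, *, worker, base_env, episode,
                        env_index=None, **kwargs):
        single_agent_id = list(episode._agent_to_last_obs)[0]
        obs = episode.last_raw_obs_for(single_agent_id)
        episode.user_data["ego_speed"].append(obs["speed"])

    def on_episode_end(self, *, worker, base_env, policies, episode,
                       env_index=None, **kwargs):
        episode.custom_metrics["mean_ego_speed"] = np.mean(
            episode.user_data["ego_speed"])

# Pass the class (not an instance) under the "callbacks" config key;
# "my_env" stands in for whatever environment is actually registered.
trainer = PPOTrainer(env="my_env", config={"callbacks": EgoSpeedCallbacks})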
Example 7
 def new_episode():
     episode = MultiAgentEpisode(policies, policy_mapping_fn,
                                 get_batch_builder, extra_batch_callback)
     # Call each policy's Exploration.on_episode_start method.
     for p in policies.values():
         p.exploration.on_episode_start(policy=p,
                                        environment=base_env,
                                        episode=episode,
                                        tf_sess=getattr(p, "_sess", None))
     callbacks.on_episode_start(worker=worker,
                                base_env=base_env,
                                policies=policies,
                                episode=episode)
     return episode
Example 8
 def new_episode():
     episode = MultiAgentEpisode(policies, policy_mapping_fn,
                                 get_batch_builder, extra_batch_callback)
     # Call each policy's Exploration.on_episode_start method.
     for p in policies.values():
         p.exploration.on_episode_start(policy=p,
                                        environment=base_env,
                                        episode=episode,
                                        tf_sess=getattr(p, "_sess", None))
     # Call custom on_episode_start callback.
     if callbacks.get("on_episode_start"):
         callbacks["on_episode_start"]({
             "env": base_env,
             "policy": policies,
             "episode": episode,
         })
     return episode
Example 9
            def compute_actions_from_input_dict(self,
                                                input_dict,
                                                explore=None,
                                                timestep=None,
                                                episodes=None,
                                                **kwargs):
                obs_batch = input_dict["obs"]
                # In policy loss initialization phase, no episodes are passed
                # in.
                if episodes is not None:
                    # Pretend we did a model-based rollout and want to return
                    # the extra trajectory.
                    env_id = episodes[0].env_id
                    fake_eps = MultiAgentEpisode(
                        episodes[0].policy_map, episodes[0]._policy_mapping_fn,
                        lambda: None, lambda x: None, env_id)
                    builder = get_global_worker().sampler.sample_collector
                    agent_id = "extra_0"
                    policy_id = "p1"  # use p1 so we can easily check it
                    builder.add_init_obs(fake_eps, agent_id, env_id, policy_id,
                                         -1, obs_batch[0])
                    for t in range(4):
                        builder.add_action_reward_next_obs(
                            episode_id=fake_eps.episode_id,
                            agent_id=agent_id,
                            env_id=env_id,
                            policy_id=policy_id,
                            agent_done=t == 3,
                            values=dict(
                                t=t,
                                actions=0,
                                rewards=0,
                                dones=t == 3,
                                infos={},
                                new_obs=obs_batch[0]))
                    batch = builder.postprocess_episode(
                        episode=fake_eps, build=True)
                    episodes[0].add_extra_batch(batch)

                # Just return zeros for actions
                return [0] * len(obs_batch), [], {}
Example 10
 def new_episode(env_id):
     episode = MultiAgentEpisode(worker.policy_map,
                                 worker.policy_mapping_fn,
                                 get_batch_builder,
                                 extra_batch_callback,
                                 env_id=env_id)
     # Call each policy's Exploration.on_episode_start method.
     # type: Policy
     for p in worker.policy_map.values():
         if getattr(p, "exploration", None) is not None:
             p.exploration.on_episode_start(policy=p,
                                            environment=base_env,
                                            episode=episode,
                                            tf_sess=p.get_session())
     callbacks.on_episode_start(
         worker=worker,
         base_env=base_env,
         policies=worker.policy_map,
         episode=episode,
         env_index=env_id,
     )
     return episode
Example 11
 def new_episode():
     return MultiAgentEpisode(policies, policy_mapping_fn,
                              get_batch_builder, extra_batch_callback)
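Across all of these examples, the MultiAgentEpisode constructor only needs the policy map, a policy-mapping function, a batch-builder factory, and an extra-batch callback (plus an optional env_id). A throwaway episode, like the fake one built in Example 9, can therefore be created with stand-in callables; a minimal sketch, assuming the Ray 1.x import path, with every name below being a placeholder:

from ray.rllib.evaluation.episode import MultiAgentEpisode

def make_dummy_episode(policy_map, env_id=0):
    # The mapping fn tolerates any calling convention and routes every agent
    # to one policy; the batch-builder factory and extra-batch callback are
    # no-ops here.
    return MultiAgentEpisode(
        policy_map,
        lambda *args, **kwargs: "default_policy",
        lambda: None,
        lambda batch: None,
        env_id=env_id)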