コード例 #1
0
    def test_returning_model_based_rollouts_data(self):
        """Extra batches attached via ``add_extra_batch`` show up in samples.

        Both agents of a 2-agent MultiAgentCartPole map to "p0", whose
        ``compute_actions_from_input_dict`` pretends to run a model-based
        rollout. Each call builds a 4-step fake trajectory for agent
        "extra_0" under policy "p1" and attaches it to the first episode.
        The sampled batch must then contain those extra "p1" steps.
        """
        def _attach_fake_rollout(real_episodes, first_obs):
            # Stand-in episode reusing the first real episode's policies and
            # mapping fn (the two lambdas are dummy callables).
            first = real_episodes[0]
            env_id = first.env_id
            extra_episode = MultiAgentEpisode(
                first._policies, first._policy_mapping_fn,
                lambda: None, lambda x: None, env_id)
            collector = get_global_worker().sampler.sample_collector
            agent_id = "extra_0"
            # Route the fake data to p1 so its count is easy to check.
            policy_id = "p1"
            collector.add_init_obs(extra_episode, agent_id, env_id,
                                   policy_id, -1, first_obs)
            for step in range(4):
                is_last = step == 3
                collector.add_action_reward_next_obs(
                    episode_id=extra_episode.episode_id,
                    agent_id=agent_id,
                    env_id=env_id,
                    policy_id=policy_id,
                    agent_done=is_last,
                    values={
                        "t": step,
                        "actions": 0,
                        "rewards": 0,
                        "dones": is_last,
                        "infos": {},
                        "new_obs": first_obs,
                    })
            trajectory = collector.postprocess_episode(
                episode=extra_episode, build=True)
            first.add_extra_batch(trajectory)

        class ModelBasedPolicy(DQNTFPolicy):
            def compute_actions_from_input_dict(self,
                                                input_dict,
                                                explore=None,
                                                timestep=None,
                                                episodes=None,
                                                **kwargs):
                observations = input_dict["obs"]
                # During the policy's loss-initialization phase no episodes
                # are passed in, so only fake a rollout for real calls.
                if episodes is not None:
                    _attach_fake_rollout(episodes, observations[0])
                # All-zero actions, no RNN state outputs, no extra fetches.
                return [0] * len(observations), [], {}

        cartpole = gym.make("CartPole-v0")
        obs_space = cartpole.observation_space
        act_space = cartpole.action_space
        policy_spec = {
            pid: (ModelBasedPolicy, obs_space, act_space, {})
            for pid in ("p0", "p1")
        }
        worker = RolloutWorker(
            env_creator=lambda _: MultiAgentCartPole({"num_agents": 2}),
            policy_spec=policy_spec,
            policy_config={"_use_trajectory_view_api": True},
            policy_mapping_fn=lambda agent_id: "p0",
            rollout_fragment_length=5)
        batch = worker.sample()
        # One count per environment step: rollout_fragment_length == 5.
        self.assertEqual(batch.count, 5)
        # Both agents map to p0, so p0 collects 2 * 5 = 10 agent steps.
        self.assertEqual(batch.policy_batches["p0"].count, 10)
        # Each time both agents step (one batched p0 action call), a 4-step
        # fake trajectory is injected for p1: 5 * 4 = 20.
        self.assertEqual(batch.policy_batches["p1"].count, 20)
コード例 #2
0
                            mode="async",
                            output_indexes=[1])

    return StandardMetricsReporting(train_op, workers, config)


if __name__ == "__main__":
    args = parser.parse_args()
    # --torch and --mixed-torch-tf are mutually exclusive.
    # NOTE(review): `assert` is stripped under `python -O`; consider
    # `parser.error(...)` for CLI validation instead.
    assert not (args.torch and args.mixed_torch_tf),\
        "Use either --torch or --mixed-torch-tf, not both!"

    ray.init()

    # Simple environment with 4 independent cartpole entities
    register_env("multi_agent_cartpole",
                 lambda _: MultiAgentCartPole({"num_agents": 4}))
    # A single-agent CartPole is created here to obtain the observation and
    # action spaces shared by both policies defined below.
    single_env = gym.make("CartPole-v0")
    obs_space = single_env.observation_space
    act_space = single_env.action_space

    # Note that since the trainer below does not include a default policy or
    # policy configs, we have to explicitly set them in the multiagent config:
    policies = {
        # PPO uses the torch policy under either --torch or --mixed-torch-tf.
        "ppo_policy":
        (PPOTorchPolicy if args.torch or args.mixed_torch_tf else PPOTFPolicy,
         obs_space, act_space, PPO_CONFIG),
        # DQN uses torch only under --torch, so --mixed-torch-tf pairs a
        # torch PPO policy with a TF DQN policy.
        "dqn_policy": (DQNTorchPolicy if args.torch else DQNTFPolicy,
                       obs_space, act_space, DQN_CONFIG),
    }

    def policy_mapping_fn(agent_id):