    def test_returning_model_based_rollouts_data(self):
        class ModelBasedPolicy(DQNTFPolicy):
            def compute_actions_from_input_dict(self,
                                                 input_dict,
                                                 explore=None,
                                                 timestep=None,
                                                 episodes=None,
                                                 **kwargs):
                obs_batch = input_dict["obs"]
                # In policy loss initialization phase, no episodes are passed
                # in.
                if episodes is not None:
                    # Pretend we did a model-based rollout and want to return
                    # the extra trajectory.
                    env_id = episodes[0].env_id
                    fake_eps = MultiAgentEpisode(
                        episodes[0]._policies,
                        episodes[0]._policy_mapping_fn,
                        lambda: None, lambda x: None, env_id)
                    builder = get_global_worker().sampler.sample_collector
                    agent_id = "extra_0"
                    policy_id = "p1"  # use p1 so we can easily check it
                    builder.add_init_obs(fake_eps, agent_id, env_id,
                                         policy_id, -1, obs_batch[0])
                    for t in range(4):
                        builder.add_action_reward_next_obs(
                            episode_id=fake_eps.episode_id,
                            agent_id=agent_id,
                            env_id=env_id,
                            policy_id=policy_id,
                            agent_done=t == 3,
                            values=dict(
                                t=t,
                                actions=0,
                                rewards=0,
                                dones=t == 3,
                                infos={},
                                new_obs=obs_batch[0]))
                    batch = builder.postprocess_episode(
                        episode=fake_eps, build=True)
                    episodes[0].add_extra_batch(batch)

                # Just return zeros for actions.
                return [0] * len(obs_batch), [], {}

        single_env = gym.make("CartPole-v0")
        obs_space = single_env.observation_space
        act_space = single_env.action_space
        ev = RolloutWorker(
            env_creator=lambda _: MultiAgentCartPole({"num_agents": 2}),
            policy_spec={
                "p0": (ModelBasedPolicy, obs_space, act_space, {}),
                "p1": (ModelBasedPolicy, obs_space, act_space, {}),
            },
            policy_config={"_use_trajectory_view_api": True},
            policy_mapping_fn=lambda agent_id: "p0",
            rollout_fragment_length=5)
        batch = ev.sample()
        # 5 environment steps (rollout_fragment_length).
        self.assertEqual(batch.count, 5)
        # 10 agent steps for p0: 2 agents, both using p0 as their policy.
        self.assertEqual(batch.policy_batches["p0"].count, 10)
        # 20 agent steps for p1: Each time both(!) agents take 1 step,
        # p1 takes 4: 5 (rollout-fragment length) * 4 = 20.
        self.assertEqual(batch.policy_batches["p1"].count, 20)
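
# A minimal sketch of the imports the test above appears to rely on; the
# module paths assume the older RLlib layout (DQNTFPolicy, RolloutWorker's
# `policy_spec`, the `_use_trajectory_view_api` flag) and may differ in
# other Ray versions.
import gym

from ray.rllib.agents.dqn.dqn_tf_policy import DQNTFPolicy
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.evaluation.rollout_worker import RolloutWorker, \
    get_global_worker
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole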
mode="async", output_indexes=[1]) return StandardMetricsReporting(train_op, workers, config) if __name__ == "__main__": args = parser.parse_args() assert not (args.torch and args.mixed_torch_tf),\ "Use either --torch or --mixed-torch-tf, not both!" ray.init() # Simple environment with 4 independent cartpole entities register_env("multi_agent_cartpole", lambda _: MultiAgentCartPole({"num_agents": 4})) single_env = gym.make("CartPole-v0") obs_space = single_env.observation_space act_space = single_env.action_space # Note that since the trainer below does not include a default policy or # policy configs, we have to explicitly set it in the multiagent config: policies = { "ppo_policy": (PPOTorchPolicy if args.torch or args.mixed_torch_tf else PPOTFPolicy, obs_space, act_space, PPO_CONFIG), "dqn_policy": (DQNTorchPolicy if args.torch else DQNTFPolicy, obs_space, act_space, DQN_CONFIG), } def policy_mapping_fn(agent_id):