Example No. 1
0
    def test_train_multi_agent_cartpole_multi_policy(self):
        """Train PG on a 10-agent CartPole env configured with two policies.

        Both policies get randomly chosen ``gamma``/``n_step`` settings, but
        the mapping function sends every agent to ``policy_1``. The test
        checks that:

        * ten training iterations run without crashing,
        * both configured policies can compute an action (0 or 1),
        * querying an unknown policy id raises ``KeyError``.

        The trainer is always stopped, even when an assertion fails, so its
        resources are not leaked into later tests.
        """
        n = 10
        register_env(
            "multi_agent_cartpole", lambda _: MultiAgentCartPole({"num_agents": n})
        )

        def gen_policy():
            # Randomize per-policy hyperparameters so the two policies are
            # (potentially) configured differently.
            config = {
                "gamma": random.choice([0.5, 0.8, 0.9, 0.95, 0.99]),
                "n_step": random.choice([1, 2, 3, 4, 5]),
            }
            return PolicySpec(config=config)

        pg = PGTrainer(
            env="multi_agent_cartpole",
            config={
                "num_workers": 0,
                "multiagent": {
                    "policies": {
                        "policy_1": gen_policy(),
                        "policy_2": gen_policy(),
                    },
                    # Every agent maps to policy_1; policy_2 is configured
                    # but never receives experiences.
                    "policy_mapping_fn": lambda aid, **kwargs: "policy_1",
                },
                "framework": "tf",
            },
        )

        try:
            # Just check that it runs without crashing
            for i in range(10):
                result = pg.train()
                print(
                    "Iteration {}, reward {}, timesteps {}".format(
                        i, result["episode_reward_mean"], result["timesteps_total"]
                    )
                )

            # Both configured policies must return an action of 0 or 1,
            # including the one that was never trained.
            self.assertIn(
                pg.compute_single_action([0, 0, 0, 0], policy_id="policy_1"), [0, 1]
            )
            self.assertIn(
                pg.compute_single_action([0, 0, 0, 0], policy_id="policy_2"), [0, 1]
            )

            # A policy id that was never configured must be rejected.
            with self.assertRaisesRegex(KeyError, "not found in PolicyMap"):
                pg.compute_single_action([0, 0, 0, 0], policy_id="policy_3")
        finally:
            # Release the trainer's workers/resources regardless of outcome.
            pg.stop()