Example #1
def check_support_multiagent(alg, config):
    register_env("multi_mountaincar", lambda _: MultiMountainCar(2))
    register_env("multi_cartpole", lambda _: MultiCartpole(2))
    if "DDPG" in alg:
        a = get_agent_class(alg)(config=config, env="multi_mountaincar")
    else:
        a = get_agent_class(alg)(config=config, env="multi_cartpole")
    try:
        a.train()
    finally:
        a.stop()
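A minimal sketch of how this helper might be invoked, assuming RLlib's string algorithm names of that era ("PG", "DDPG", ...); the specific name and config values below are illustrative, not from the source:

import ray

ray.init(num_cpus=2)
try:
    # "PG" exercises the discrete multi_cartpole path; a "DDPG" name would
    # exercise the continuous multi_mountaincar path instead.
    check_support_multiagent("PG", {"num_workers": 0})
finally:
    ray.shutdown()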
Example #2
File: test_io.py  Project: zofuthan/ray
    def testMultiAgent(self):
        register_env("multi_cartpole", lambda _: MultiCartpole(10))
        single_env = gym.make("CartPole-v0")

        def gen_policy():
            obs_space = single_env.observation_space
            act_space = single_env.action_space
            return (PGPolicyGraph, obs_space, act_space, {})

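        # First run: train while writing experiences out to self.test_dir.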
        pg = PGAgent(
            env="multi_cartpole",
            config={
                "num_workers": 0,
                "output": self.test_dir,
                "multiagent": {
                    "policy_graphs": {
                        "policy_1": gen_policy(),
                        "policy_2": gen_policy(),
                    },
                    "policy_mapping_fn":
                    (lambda agent_id: random.choice(["policy_1", "policy_2"])),
                },
            })
        pg.train()
        self.assertEqual(len(os.listdir(self.test_dir)), 1)

        pg.stop()
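        # Second run: read the logged experiences back in and estimate
        # performance by simulating the policy ("input_evaluation").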
        pg = PGAgent(
            env="multi_cartpole",
            config={
                "num_workers": 0,
                "input": self.test_dir,
                "input_evaluation": "simulation",
                "train_batch_size": 2000,
                "multiagent": {
                    "policy_graphs": {
                        "policy_1": gen_policy(),
                        "policy_2": gen_policy(),
                    },
                    "policy_mapping_fn":
                    (lambda agent_id: random.choice(["policy_1", "policy_2"])),
                },
            })
        for _ in range(50):
            result = pg.train()
            if not np.isnan(result["episode_reward_mean"]):
                return  # simulation ok
            time.sleep(0.1)
        assert False, "did not see any simulation results"
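As a side note, the "output" setting writes the collected sample batches as JSON files into self.test_dir, which is what the len(os.listdir(...)) == 1 check verifies. A small, hypothetical helper for peeking at what was written; it assumes one JSON-encoded batch per line, which is the layout these writers use:

import json
import os

def peek_output_dir(out_dir):
    # Print each output file's name and the keys of its first recorded batch.
    for name in sorted(os.listdir(out_dir)):
        with open(os.path.join(out_dir, name)) as f:
            batch = json.loads(f.readline())
        print(name, sorted(batch.keys()))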
Example #3
    def _build_layers_v2(self, input_dict, num_outputs, options):
        # NOTE: the opening lines of this example were lost in extraction; the
        # method signature above is an assumption inferred from the code below.
        last_layer = slim.fully_connected(input_dict["obs"],
                                          64,
                                          activation_fn=tf.nn.relu,
                                          scope="fc1")
        output = slim.fully_connected(last_layer,
                                      num_outputs,
                                      activation_fn=None,
                                      scope="fc_out")
        return output, last_layer


if __name__ == "__main__":
    args = parser.parse_args()
    ray.init()

    # Simple environment with `num_agents` independent cartpole entities
    register_env("multi_cartpole", lambda _: MultiCartpole(args.num_agents))
    ModelCatalog.register_custom_model("model1", CustomModel1)
    ModelCatalog.register_custom_model("model2", CustomModel2)
    single_env = gym.make("CartPole-v0")
    obs_space = single_env.observation_space
    act_space = single_env.action_space

    # Each policy can have a different configuration (including custom model)
    def gen_policy(i):
        config = {
            "model": {
                "custom_model": ["model1", "model2"][i % 2],
            },
            "gamma": random.choice([0.5, 0.8, 0.9, 0.95, 0.99]),
            "n_step": random.choice([1, 2, 3, 4, 5]),
        }
        # The rest of this function is truncated; a return of the same
        # policy-spec tuple shape as in Example #2 is assumed here.
        return (None, obs_space, act_space, config)
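The remainder of this example is cut off. A hedged sketch of how such generated policies are typically wired into a trainer, following the multiagent pattern of Example #2; PGAgent, args.num_policies, and args.num_iters are assumptions, not taken from the truncated code:

    policy_graphs = {
        "policy_{}".format(i): gen_policy(i)
        for i in range(args.num_policies)
    }
    policy_ids = list(policy_graphs.keys())

    trainer = PGAgent(
        env="multi_cartpole",
        config={
            "multiagent": {
                "policy_graphs": policy_graphs,
                "policy_mapping_fn": (
                    lambda agent_id: random.choice(policy_ids)),
            },
        })
    for _ in range(args.num_iters):
        print(trainer.train())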
Example #4
from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph
from ray.rllib.agents.ppo.ppo import PPOAgent
from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
from ray.rllib.test.test_multi_agent_env import MultiCartpole
from ray.tune.logger import pretty_print
from ray.tune.registry import register_env

parser = argparse.ArgumentParser()
parser.add_argument("--num-iters", type=int, default=20)

if __name__ == "__main__":
    args = parser.parse_args()
    ray.init()

    # Simple environment with 4 independent cartpole entities
    register_env("multi_cartpole", lambda _: MultiCartpole(4))
    single_env = gym.make("CartPole-v0")
    obs_space = single_env.observation_space
    act_space = single_env.action_space

    # You can also have multiple policy graphs per trainer, but here we just
    # show one each for PPO and DQN.
    policy_graphs = {
        "ppo_policy": (PPOPolicyGraph, obs_space, act_space, {}),
        "dqn_policy": (DQNPolicyGraph, obs_space, act_space, {}),
    }

    def policy_mapping_fn(agent_id):
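        # Even-numbered agents are controlled by PPO, odd-numbered ones by DQN.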
        if agent_id % 2 == 0:
            return "ppo_policy"
        else: