Example #1
0
def check_support_multiagent(alg, config):
    """Smoke-test that agent ``alg`` can run one iteration on a multi-agent env.

    Registers two-agent MountainCar and CartPole environments, builds the
    agent for ``alg`` on one of them, runs a single ``train()`` call and
    always stops the agent afterwards.

    Args:
        alg: Algorithm name, as accepted by ``get_agent_class``.
        config: Agent config dict. Mutated in place: ``"log_level"`` is
            set to ``"ERROR"``.
    """
    def _make_mountaincar(_env_config):
        return MultiAgentMountainCar({"num_agents": 2})

    def _make_cartpole(_env_config):
        return MultiAgentCartPole({"num_agents": 2})

    register_env("multi_agent_mountaincar", _make_mountaincar)
    register_env("multi_agent_cartpole", _make_cartpole)
    config["log_level"] = "ERROR"

    # Any algorithm whose name contains "DDPG" runs on MountainCar; all
    # others run on CartPole.
    env_name = ("multi_agent_mountaincar" if "DDPG" in alg
                else "multi_agent_cartpole")
    agent = get_agent_class(alg)(config=config, env=env_name)
    try:
        agent.train()
    finally:
        # Stop the agent even when train() raises.
        agent.stop()
Example #2
0
def check_support_multiagent(alg, config):
    """Smoke-test multi-agent support of agent ``alg`` under torch and tf.

    Registers two-agent MountainCar and CartPole environments. Then, for each
    framework yielded by ``framework_iterator`` (restricted to "torch" and
    "tf"), builds the agent for ``alg``, runs one ``train()`` iteration,
    prints the result and stops the agent.

    Args:
        alg: Algorithm name, as accepted by ``get_agent_class``.
        config: Agent config dict. Mutated in place: ``"log_level"`` is
            set to ``"ERROR"``. ``framework_iterator`` presumably also
            rewrites the framework setting on each pass (confirm).
    """
    register_env("multi_agent_mountaincar",
                 lambda _: MultiAgentMountainCar({"num_agents": 2}))
    register_env("multi_agent_cartpole",
                 lambda _: MultiAgentCartPole({"num_agents": 2}))
    config["log_level"] = "ERROR"

    # The env depends only on `alg`, so pick it once outside the loop.
    if alg in ("DDPG", "APEX_DDPG", "SAC"):
        env_name = "multi_agent_mountaincar"
    else:
        env_name = "multi_agent_cartpole"

    for _ in framework_iterator(config, frameworks=("torch", "tf")):
        agent = get_agent_class(alg)(config=config, env=env_name)
        try:
            print(agent.train())
        finally:
            # Stop the agent even when train() raises, so a failing
            # iteration does not leave it running.
            agent.stop()
Example #3
0
def check_support_multiagent(alg, config):
    """Smoke-test multi-agent support of trainer ``alg`` under each framework.

    Registers two-agent MountainCar and CartPole environments. Then, for each
    framework yielded by ``framework_iterator``, builds the trainer for
    ``alg``, runs one ``train()`` iteration, prints the result and stops the
    trainer. Some algorithm/framework combinations are skipped (see below).

    Args:
        alg: Algorithm name, as accepted by ``get_trainer_class``.
        config: Trainer config dict. Mutated in place: ``"log_level"`` is
            set to ``"ERROR"``. ``framework_iterator`` presumably also
            rewrites the framework setting on each pass (confirm).
    """
    register_env("multi_agent_mountaincar",
                 lambda _: MultiAgentMountainCar({"num_agents": 2}))
    register_env("multi_agent_cartpole",
                 lambda _: MultiAgentCartPole({"num_agents": 2}))
    config["log_level"] = "ERROR"

    # The env depends only on `alg`, so pick it once outside the loop.
    if alg in ("DDPG", "APEX_DDPG", "SAC"):
        env_name = "multi_agent_mountaincar"
    else:
        env_name = "multi_agent_cartpole"

    for fw in framework_iterator(config):
        # These algorithms are not exercised under "tf2"/"tfe".
        if fw in ("tf2", "tfe") and \
                alg in ("A3C", "APEX", "APEX_DDPG", "IMPALA"):
            continue
        trainer = get_trainer_class(alg)(config=config, env=env_name)
        try:
            print(trainer.train())
        finally:
            # Stop the trainer even when train() raises, so a failure under
            # one framework does not leave it running.
            trainer.stop()
Example #4
0
def check_support_multiagent(alg, config):
    """Smoke-test a two-policy multi-agent setup for trainer ``alg``.

    Registers two-agent MountainCar and CartPole environments and installs a
    ``"multiagent"`` config with two policies that differ only in ``gamma``.
    Then, for each framework yielded by ``framework_iterator``, builds the
    trainer for ``alg``, runs one ``train()`` iteration, validates the
    results with ``check_train_results``, prints them and stops the trainer.

    Args:
        alg: Algorithm name, as accepted by ``get_trainer_class``.
        config: Trainer config dict. Mutated in place: its ``"multiagent"``
            key is overwritten. ``framework_iterator`` presumably also
            rewrites the framework setting on each pass (confirm).
    """
    register_env("multi_agent_mountaincar",
                 lambda _: MultiAgentMountainCar({"num_agents": 2}))
    register_env("multi_agent_cartpole",
                 lambda _: MultiAgentCartPole({"num_agents": 2}))

    # Simulate a simple multi-agent setup.
    policies = {
        "policy_0": PolicySpec(config={"gamma": 0.99}),
        "policy_1": PolicySpec(config={"gamma": 0.95}),
    }
    policy_ids = list(policies.keys())

    def policy_mapping_fn(agent_id, episode, worker, **kwargs):
        # Uses agent_id as a list index. This assumes the envs' agent ids are
        # the integers 0..num_agents-1 (num_agents == 2 == len(policy_ids));
        # confirm against the env implementations.
        return policy_ids[agent_id]

    config["multiagent"] = {
        "policies": policies,
        "policy_mapping_fn": policy_mapping_fn,
    }

    # The env depends only on `alg`, so pick it once outside the loop.
    if alg in ("DDPG", "APEX_DDPG", "SAC"):
        env_name = "multi_agent_mountaincar"
    else:
        env_name = "multi_agent_cartpole"

    for fw in framework_iterator(config):
        # These algorithms are not exercised under "tf2"/"tfe".
        if fw in ("tf2", "tfe") and \
                alg in ("A3C", "APEX", "APEX_DDPG", "IMPALA"):
            continue
        trainer = get_trainer_class(alg)(config=config, env=env_name)
        try:
            results = trainer.train()
            check_train_results(results)
            print(results)
        finally:
            # Stop the trainer even when train() or the results check raises,
            # so a failure does not leave it running.
            trainer.stop()