Code example #1
File: test_learning.py  Project: valaxkong/SMARTS
def test_learning_regression_rllib():
    from examples.rllib.rllib_agent import TrainingModel, rllib_agent

    ModelCatalog.register_custom_model(TrainingModel.NAME, TrainingModel)
    rllib_policies = {
        "policy": (
            None,
            rllib_agent["observation_space"],
            rllib_agent["action_space"],
            {"model": {"custom_model": TrainingModel.NAME}},
        )
    }

    # XXX: We should be able to simply provide "scenarios/loop"?
    scenario_path = Path(__file__).parents[2] / "scenarios/loop"
    scenario_path = str(scenario_path.absolute())

    tune_config = {
        "env": RLlibHiWayEnv,
        "env_config": {
            "scenarios": [scenario_path],
            "seed": 42,
            "headless": True,
            "agent_specs": {
                "Agent-007": rllib_agent["agent_spec"]
            },
        },
        "multiagent": {
            "policies": rllib_policies,
            "policy_mapping_fn": lambda _: "policy",
        },
        "log_level": "WARN",
        "num_workers": multiprocessing.cpu_count() - 1,
        "horizon": HORIZON,
    }

    analysis = tune.run(
        "PPO",
        name="learning_regression_test",
        stop={"training_iteration": 60},
        max_failures=10,
        local_dir=make_dir_in_smarts_log_dir("smarts_learning_regression"),
        config=tune_config,
    )

    df = analysis.dataframe()

    # Lower bound of the 95% confidence interval of the mean reward after one hour, generated by manual analysis.
    # If you need to update this, run tools/regression_rllib.py.
    ci95_reward_file = Path(__file__).parent / "ci95_reward_lo"
    with open(ci95_reward_file.absolute()) as f:
        CI95_REWARD_MEAN_1_HOUR = float(f.readline())

    assert (
        df["episode_reward_mean"][0] >= CI95_REWARD_MEAN_1_HOUR
    ), "Mean reward did not reach the expected value ({} < {})".format(
        df["episode_reward_mean"][0], CI95_REWARD_MEAN_1_HOUR
    )
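In the "multiagent" block above, the policy_mapping_fn routes every agent id to the single "policy" entry. With more than one trainable policy, the mapping function would dispatch on the agent id instead. A minimal sketch of that pattern, using the same single-argument mapping signature as the example (the policy names and the id prefix are illustrative, not part of the SMARTS code):

def policy_mapping_fn(agent_id):
    # Illustrative only: send ego agents and background agents to different
    # (hypothetical) policies registered in the "policies" dict.
    if agent_id.startswith("Agent"):
        return "ego_policy"
    return "background_policy"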
Code example #2
def test_rllib_hiway_env(rllib_agent):
    def on_episode_end(info):
        episode = info["episode"]
        agent_info = episode.last_info_for(AGENT_ID)

        assert INFO_EXTRA_KEY in agent_info, "Failed to apply info adapter."

    rllib_model_name = "FullyConnectedNetwork"
    ModelCatalog.register_custom_model(rllib_model_name, FullyConnectedNetwork)

    rllib_policies = {
        "policy": (
            None,
            rllib_agent["observation_space"],
            rllib_agent["action_space"],
            {"model": {"custom_model": rllib_model_name}},
        )
    }

    # XXX: We should be able to simply provide "scenarios/loop"?
    scenario_path = Path(__file__).parent / "../../../scenarios/loop"
    tune_config = {
        "env": RLlibHiWayEnv,
        "env_config": {
            "scenarios": [str(scenario_path.absolute())],
            "seed": 42,
            "headless": True,
            "agent_specs": {AGENT_ID: rllib_agent["agent_spec"]},
        },
        "callbacks": {"on_episode_end": on_episode_end},
        "multiagent": {
            "policies": rllib_policies,
            "policy_mapping_fn": lambda _: "policy",
        },
        "log_level": "WARN",
        "num_workers": 1,
    }

    # Run tune with one fewer than the number of physical CPUs, with a minimum of 2 CPUs
    num_cpus = max(2, psutil.cpu_count(logical=False) - 1)
    ray.init(num_cpus=num_cpus, num_gpus=0)
    analysis = tune.run(
        "PPO",
        name="RLlibHiWayEnv test",
        # terminate as soon as possible (this will run one training iteration)
        stop={"time_total_s": 1},
        max_failures=0,  # On failures, exit immediately
        local_dir=make_dir_in_smarts_log_dir("smarts_rllib_smoke_test"),
        config=tune_config,
    )

    # trial status will be ERROR if there are any issues with the environment
    assert analysis.get_best_trial("episode_reward_mean").status == "TERMINATED"
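The dict-style "callbacks" entry used above belongs to an older Ray RLlib API. In more recent Ray releases (around 1.x) the same hook is written as a DefaultCallbacks subclass; a rough sketch of the equivalent check, assuming such a Ray version is installed (the import path and method signature vary between releases):

from ray.rllib.agents.callbacks import DefaultCallbacks

class AssertInfoAdapter(DefaultCallbacks):
    def on_episode_end(self, *, episode, **kwargs):
        # Same check as on_episode_end above: the info adapter must have run.
        agent_info = episode.last_info_for(AGENT_ID)
        assert INFO_EXTRA_KEY in agent_info, "Failed to apply info adapter."

The class (not an instance) would then be passed in the tune config as "callbacks": AssertInfoAdapter.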
Code example #3
    def _resolve_log_dir(self, log_dir):
        if log_dir is None:
            log_dir = make_dir_in_smarts_log_dir("_sumo_run_logs")

        return os.path.abspath(log_dir)
Code example #4
    def _resolve_log_dir(self, log_dir):
        if log_dir is None:
            log_dir = make_dir_in_smarts_log_dir("_duarouter_routing")

        return os.path.abspath(log_dir)
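Code examples #3 and #4 follow the same pattern: when the caller does not supply a log directory, create one under the SMARTS log root, then normalize whatever was chosen to an absolute path. A small standalone sketch of that fallback behaviour, with make_dir_in_smarts_log_dir stubbed out via tempfile purely for illustration:

import os
import tempfile

def resolve_log_dir(log_dir=None):
    if log_dir is None:
        # Stand-in for make_dir_in_smarts_log_dir("_sumo_run_logs")
        log_dir = tempfile.mkdtemp(prefix="_sumo_run_logs_")
    return os.path.abspath(log_dir)

print(resolve_log_dir())        # fresh temporary directory, already absolute
print(resolve_log_dir("logs"))  # caller-supplied path, made absolute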