コード例 #1
0
def test_score_agent(config):
    """Smoke test for score agent to check it runs with some different configs.

    Runs ``score_ex`` with ``config`` (rendering disabled, one episode unless
    the config says otherwise), then checks that the run completed and that
    ties plus wins add up to the number of episodes. When ``record_traj`` is
    set, also checks the per-agent ``agent_{i}.npz`` trajectory archives.

    :param config: (dict) config updates for the experiment; a copy is
        modified, the caller's dict is left untouched.
    """
    import shutil  # local import: only needed for trajectory clean-up

    config = dict(config)
    config.setdefault("episodes", 1)  # speed up tests
    # faster without, test_experiment already tests with render
    config["render"] = False

    run = score_ex.run(config_updates=config)
    assert run.status == "COMPLETED"

    outcomes = [run.result[k] for k in ["ties", "win0", "win1"]]
    assert sum(outcomes) == run.config["episodes"]

    if config.get("record_traj", False):
        save_dir = config["record_traj_params"]["save_dir"]
        try:
            for i in range(2):
                traj_file_path = os.path.join(save_dir, f"agent_{i}.npz")
                # np.load on an .npz keeps the archive open and reads arrays
                # lazily, so keep all accesses inside the `with` and let it
                # close the file handle deterministically.
                with np.load(traj_file_path) as traj_data:
                    assert set(traj_data.keys()).issuperset(
                        ["observations", "actions", "rewards"])
                    for k, ep_data in traj_data.items():
                        assert len(ep_data) == config["episodes"], (
                            f"unexpected array length at '{k}'")
        finally:
            # Remove the directory together with whatever is left in it.
            # os.rmdir would raise on a non-empty directory whenever an
            # assertion above failed, hiding the real failure behind an
            # OSError and leaving files behind.
            shutil.rmtree(save_dir, ignore_errors=True)
コード例 #2
0
def score_worker(base_config, tune_config, reporter):
    """Run one aprl.score experiment and hand its result to the reporter.

    The effective config is a copy of base_config updated with the flattened
    tune_config. The reporter receives the run result as ``score`` and, as
    ``idx``, the config entries whose key starts with "agent", equals
    "env_name", or is listed under the config's "index_keys".

    :param base_config: (dict) default config
    :param tune_config: (dict) overrides values in base_config
    :param reporter: (ray.tune.StatusReporter) Ray Tune internal logger."""
    common_worker.fix_sacred_capture()

    # Imported here rather than at module level: score_ex is not pickleable,
    # so this function must not close over it.
    from aprl.score_agent import score_ex

    merged = dict(base_config)
    common_worker.update(merged, common_worker.flatten_config(tune_config))

    # NOTE: running an experiment from inside another experiment breaks the
    # Sacred interface. It is the best available option because the
    # experiment has to be run with varying configs, but it is fragile.
    sacred_dir = osp.join("data", "sacred", "score")
    score_ex.observers.append(observers.FileStorageObserver(sacred_dir))
    run = score_ex.run(config_updates=merged)

    extra_keys = merged.get("index_keys", [])
    idx = {}
    for key, value in merged.items():
        if key.startswith("agent") or key == "env_name" or key in extra_keys:
            idx[key] = value

    reporter(done=True, score=run.result, idx=idx)
コード例 #3
0
def test_score_agent_video():
    """Check score agent video saving with the "none_dir" and "specified_dir" configs.

    The "specified_dir" config is run twice: the first run must complete, the
    second must raise AssertionError because it targets the same directory.
    The specified save directory is removed afterwards in every case.
    """
    # Experiment should run properly when saving videos to a temp dir.
    temp_dir_run = score_ex.run(
        config_updates=SCORE_AGENT_VIDEO_CONFIGS["none_dir"])
    assert temp_dir_run.status == "COMPLETED"

    explicit_cfg = SCORE_AGENT_VIDEO_CONFIGS["specified_dir"]
    try:
        # First attempt at saving videos to the specified dir works.
        first_run = score_ex.run(config_updates=explicit_cfg)
        assert first_run.status == "COMPLETED"

        # Second attempt at saving to the same specified dir must fail.
        with pytest.raises(AssertionError):
            score_ex.run(config_updates=explicit_cfg)
    finally:
        video_dir = explicit_cfg["video_params"]["save_dir"]
        shutil.rmtree(video_dir)