Esempio n. 1
0
def test_random_search_smoke_test():
    """Smoke-test a short random search and the replay of its logged actions.

    Runs ``random_search`` for a few seconds with a single worker into a
    temporary directory, checks that the expected output files were written,
    then replays the logged actions and checks for the replay's output files.
    """
    # Files expected in the output directory once the search has finished.
    search_artifacts = (
        "random_search.json",
        "random_search_progress.csv",
        "random_search_best_actions.txt",
        "optimized.bc",
    )
    # Additional files expected after replaying the best actions from the logs.
    replay_artifacts = (
        "random_search_best_actions_progress.csv",
        "random_search_best_actions_commandline.txt",
    )

    with tempfile.TemporaryDirectory() as tmp:
        outdir = Path(tmp)

        # Discard any previously parsed flag values and re-parse with nothing
        # but a program name, so the search does not see stray arguments.
        flags.FLAGS.unparse_flags()
        flags.FLAGS(["argv0"])

        random_search(
            make_env=make_env,
            outdir=outdir,
            patience=50,
            total_runtime=3,
            nproc=1,
            skip_done=False,
        )

        for filename in search_artifacts:
            assert (outdir / filename).is_file()

        env = make_env()
        try:
            replay_actions_from_logs(env, Path(outdir))
            for filename in replay_artifacts:
                assert (outdir / filename).is_file()
        finally:
            # Always release the environment, even if the replay or a check fails.
            env.close()
Esempio n. 2
0
def main(argv):
    """Main entry point.

    Parses the command-line flags, serves the ``--ls_benchmark`` and
    ``--ls_reward`` listing requests, and otherwise runs a random search using
    the flag values. If ``--fail_threshold`` is set and the best reward found
    is below it, the process exits with status 1.

    Args:
        argv: The command-line arguments, starting with the program name.

    Raises:
        app.UsageError: If unrecognized command-line arguments remain after
            flag parsing, if ``--patience`` is negative, or if no benchmark
            is specified.
    """
    argv = FLAGS(argv)
    if len(argv) != 1:
        raise app.UsageError(f"Unknown command line arguments: {argv[1:]}")

    if FLAGS.ls_benchmark:
        env = env_from_flags()
        # Close the environment even if listing or printing fails.
        try:
            print("\n".join(sorted(env.benchmarks)))
        finally:
            env.close()
        return
    if FLAGS.ls_reward:
        env = env_from_flags()
        try:
            print("\n".join(sorted(env.reward.indices.keys())))
        finally:
            env.close()
        return

    # Validate with an explicit exception rather than `assert`, which is
    # stripped when Python runs with -O.
    if FLAGS.patience < 0:
        raise app.UsageError("--patience must be >= 0")

    def make_env():
        return env_from_flags(benchmark=benchmark_from_flags())

    # Build and reset one environment up front so that a missing benchmark is
    # reported before the search starts.
    env = make_env()
    try:
        env.reset()
        if not env.benchmark:
            raise app.UsageError("No benchmark specified.")
    finally:
        env.close()

    best_reward, _ = random_search(
        make_env=make_env,
        outdir=Path(FLAGS.output_dir) if FLAGS.output_dir else None,
        patience=FLAGS.patience,
        total_runtime=FLAGS.runtime,
        nproc=FLAGS.nproc,
        skip_done=FLAGS.skip_done,
    )

    # Exit with error if --fail_threshold was set and the best reward does not
    # meet this value.
    if FLAGS.fail_threshold is not None and best_reward < FLAGS.fail_threshold:
        print(
            f"Best reward {best_reward:.3f} below threshold of {FLAGS.fail_threshold}",
            file=sys.stderr,
        )
        sys.exit(1)
Esempio n. 3
0
def main(argv):
    """Main entry point.

    Parses the command-line flags, serves the ``--ls_reward`` listing request,
    and otherwise runs a random search using the flag values. If
    ``--fail_threshold`` is set and the episode reward of the environment
    returned by the search is below it, the process exits with status 1.

    Args:
        argv: The command-line arguments, starting with the program name.

    Raises:
        app.UsageError: If unrecognized command-line arguments remain after
            flag parsing, or if ``--patience`` is negative.
    """
    argv = FLAGS(argv)
    if len(argv) != 1:
        raise app.UsageError(f"Unknown command line arguments: {argv[1:]}")

    if FLAGS.ls_reward:
        with env_from_flags() as env:
            print("\n".join(sorted(env.reward.indices.keys())))
        return

    # Validate with an explicit exception rather than `assert`, which is
    # stripped when Python runs with -O.
    if FLAGS.patience < 0:
        raise app.UsageError("--patience must be >= 0")

    # Create an environment now to catch a startup time error before we launch
    # a bunch of workers.
    with env_from_flags() as env:
        env.reset(benchmark=benchmark_from_flags())

    env = random_search(
        make_env=lambda: env_from_flags(benchmark=benchmark_from_flags()),
        outdir=Path(FLAGS.output_dir) if FLAGS.output_dir else None,
        patience=FLAGS.patience,
        total_runtime=FLAGS.runtime,
        nproc=FLAGS.nproc,
        skip_done=FLAGS.skip_done,
    )
    # The search hands back an environment; make sure it is closed even when
    # we exit early through sys.exit().
    try:
        # Exit with error if --fail_threshold was set and the best reward does not
        # meet this value.
        if FLAGS.fail_threshold is not None:
            best_reward = env.episode_reward
            if best_reward < FLAGS.fail_threshold:
                print(
                    f"Best reward {best_reward:.3f} below threshold of {FLAGS.fail_threshold}",
                    file=sys.stderr,
                )
                sys.exit(1)
    finally:
        env.close()