def test_random_search_smoke_test():
    """End-to-end smoke test: run a short random search and check its artifacts.

    Runs `random_search` for a few seconds into a temporary directory, asserts
    that every expected output file is produced, then replays the logged best
    action sequence and checks the replay artifacts too.
    """
    with tempfile.TemporaryDirectory() as tmp:
        outdir = Path(tmp)
        # Reset absl flags to a clean state so flag values from other tests
        # (or a previous parse) do not leak into this run.
        flags.FLAGS.unparse_flags()
        flags.FLAGS(["argv0"])
        random_search(
            make_env=make_env,
            outdir=outdir,
            patience=50,
            total_runtime=3,  # seconds — keep the smoke test fast
            nproc=1,
            skip_done=False,
        )
        # Artifacts written by the search itself.
        assert (outdir / "random_search.json").is_file()
        assert (outdir / "random_search_progress.csv").is_file()
        assert (outdir / "random_search_best_actions.txt").is_file()
        assert (outdir / "optimized.bc").is_file()

        env = make_env()
        try:
            # Replay the logged best actions; this writes additional artifacts.
            replay_actions_from_logs(env, Path(outdir))
            assert (outdir / "random_search_best_actions_progress.csv").is_file()
            assert (outdir / "random_search_best_actions_commandline.txt").is_file()
        finally:
            env.close()
def main(argv):
    """Main entry point.

    Parses flags, optionally lists available benchmarks or rewards, then runs
    a random search and exits non-zero if the best reward found is below
    --fail_threshold.

    :param argv: Raw command-line arguments; only argv[0] may remain after
        flag parsing.
    :raises app.UsageError: On unknown arguments or a missing benchmark.
    """
    argv = FLAGS(argv)
    if len(argv) != 1:
        raise app.UsageError(f"Unknown command line arguments: {argv[1:]}")

    if FLAGS.ls_benchmark:
        # Use a context manager so the environment (and its backing service)
        # is released even if printing raises.
        with env_from_flags() as env:
            print("\n".join(sorted(env.benchmarks)))
        return

    if FLAGS.ls_reward:
        with env_from_flags() as env:
            print("\n".join(sorted(env.reward.indices.keys())))
        return

    assert FLAGS.patience >= 0, "--patience must be >= 0"

    def make_env():
        # Factory handed to the search workers; each worker gets its own env.
        return env_from_flags(benchmark=benchmark_from_flags())

    # Create an environment now to catch configuration errors (e.g. a missing
    # benchmark) before launching a bunch of workers.
    with make_env() as env:
        env.reset()
        if not env.benchmark:
            raise app.UsageError("No benchmark specified.")

    best_reward, _ = random_search(
        make_env=make_env,
        outdir=Path(FLAGS.output_dir) if FLAGS.output_dir else None,
        patience=FLAGS.patience,
        total_runtime=FLAGS.runtime,
        nproc=FLAGS.nproc,
        skip_done=FLAGS.skip_done,
    )

    # Exit with error if --fail_threshold was set and the best reward does not
    # meet this value.
    if FLAGS.fail_threshold is not None and best_reward < FLAGS.fail_threshold:
        print(
            f"Best reward {best_reward:.3f} below threshold of {FLAGS.fail_threshold}",
            file=sys.stderr,
        )
        sys.exit(1)
def main(argv):
    """Main entry point."""
    argv = FLAGS(argv)
    if len(argv) != 1:
        raise app.UsageError(f"Unknown command line arguments: {argv[1:]}")

    if FLAGS.ls_reward:
        # List the available reward spaces and stop.
        with env_from_flags() as env:
            print("\n".join(sorted(env.reward.indices.keys())))
        return

    assert FLAGS.patience >= 0, "--patience must be >= 0"

    # Create an environment now to catch a startup time error before we launch
    # a bunch of workers.
    with env_from_flags() as env:
        env.reset(benchmark=benchmark_from_flags())

    def make_env():
        # Per-worker environment factory used by the search.
        return env_from_flags(benchmark=benchmark_from_flags())

    env = random_search(
        make_env=make_env,
        outdir=Path(FLAGS.output_dir) if FLAGS.output_dir else None,
        patience=FLAGS.patience,
        total_runtime=FLAGS.runtime,
        nproc=FLAGS.nproc,
        skip_done=FLAGS.skip_done,
    )
    try:
        # Exit with error if --fail_threshold was set and the best reward does not
        # meet this value.
        if FLAGS.fail_threshold is None:
            return
        if env.episode_reward >= FLAGS.fail_threshold:
            return
        print(
            f"Best reward {env.episode_reward:.3f} below threshold of {FLAGS.fail_threshold}",
            file=sys.stderr,
        )
        sys.exit(1)
    finally:
        env.close()