def test_state_from_csv_invalid_format():
    buf = StringIO("abcdef\n")
    reader = CompilerEnvStateReader(buf)
    with pytest.raises(
            ValueError,
            match=r"Expected 4 columns in the first row of CSV: \['abcdef'\]"):
        next(iter(reader))


def test_compiler_env_state_reader_no_header():
    buf = StringIO("benchmark://cbench-v0/foo,2.0,5.0,-a -b -c\n")
    reader = CompilerEnvStateReader(buf)
    assert list(reader) == [
        CompilerEnvState(
            benchmark="benchmark://cbench-v0/foo",
            walltime=5,
            commandline="-a -b -c",
            reward=2,
        )
    ]


def test_compiler_env_state_reader_with_header_out_of_order_columns():
    buf = StringIO("commandline,reward,benchmark,walltime\n"
                   "-a -b -c,2.0,benchmark://cbench-v0/foo,5.0\n")
    reader = CompilerEnvStateReader(buf)
    assert list(reader) == [
        CompilerEnvState(
            benchmark="benchmark://cbench-v0/foo",
            walltime=5,
            commandline="-a -b -c",
            reward=2,
        )
    ]
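
# A minimal sketch (not one of the original tests; import paths are assumed):
# the canonical in-order layout with an explicit header row, matching the
# four columns the reader expects.
from io import StringIO

from compiler_gym import CompilerEnvState
from compiler_gym.compiler_env_state import CompilerEnvStateReader

buf = StringIO("benchmark,reward,walltime,commandline\n"
               "benchmark://cbench-v0/foo,2.0,5.0,-a -b -c\n")
assert list(CompilerEnvStateReader(buf)) == [
    CompilerEnvState(
        benchmark="benchmark://cbench-v0/foo",
        walltime=5,
        commandline="-a -b -c",
        reward=2,
    )
]
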
def test_state_serialize_deserialize_equality_no_reward():
    original_state = CompilerEnvState(benchmark="benchmark://cbench-v0/foo",
                                      walltime=100,
                                      commandline="-a -b -c")
    buf = StringIO()
    CompilerEnvStateWriter(buf).write_state(original_state)
    buf.seek(0)  # Rewind the buffer for reading.
    state_from_csv = next(iter(CompilerEnvStateReader(buf)))

    assert state_from_csv.benchmark == "benchmark://cbench-v0/foo"
    assert state_from_csv.walltime == 100
    assert state_from_csv.reward is None
    assert state_from_csv.commandline == "-a -b -c"

Example #5

def test_state_serialize_deserialize_equality(env: LlvmEnv):
    env.reset(benchmark="cbench-v1/crc32")
    env.episode_reward = 10

    state = env.state
    assert state.reward == 10

    buf = StringIO()
    CompilerEnvStateWriter(buf).write_state(state)
    buf.seek(0)  # Rewind the buffer for reading.
    state_from_csv = next(iter(CompilerEnvStateReader(buf)))

    assert state_from_csv.reward == 10
    assert state == state_from_csv
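
# A hedged sketch (not part of the test above) of obtaining a
# CompilerEnvState without the pytest `env` fixture; "llvm-ic-v0" and the
# benchmark name are taken from the surrounding examples.
import gym

import compiler_gym  # noqa: F401  Registers the CompilerGym environments.

with gym.make("llvm-ic-v0") as env:
    env.reset(benchmark="cbench-v1/crc32")
    state = env.state  # Captures benchmark, commandline, walltime and reward.
    print(state)
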
def test_compiler_env_state_reader_header_only():
    buf = StringIO("benchmark,reward,walltime,commandline\n")
    reader = CompilerEnvStateReader(buf)
    assert list(reader) == []


def test_compiler_env_state_reader_empty_input():
    buf = StringIO("")
    reader = CompilerEnvStateReader(buf)
    assert list(reader) == []
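
# A hedged sketch (the file name is illustrative) of reading states from
# files on disk rather than an in-memory buffer, using the read_paths()
# helper that the validation script later in these examples relies on.
from compiler_gym.compiler_env_state import CompilerEnvStateReader

for state in CompilerEnvStateReader.read_paths(["results.csv"]):
    print(state.benchmark, state.reward, state.walltime)
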

Example #8

    def main(argv):
        assert len(argv) == 1, f"Unknown args: {argv[1:]}"
        assert FLAGS.n > 0, "n must be > 0"

        with gym.make("llvm-ic-v0") as env:

            # Stream verbose CompilerGym logs to file.
            logger = logging.getLogger("compiler_gym")
            logger.setLevel(logging.DEBUG)
            log_handler = logging.FileHandler(FLAGS.leaderboard_logfile)
            logger.addHandler(log_handler)
            logger.propagate = False

            print(f"Writing results to {FLAGS.leaderboard_results}")
            print(f"Writing logs to {FLAGS.leaderboard_logfile}")

            # Build the list of benchmarks to evaluate.
            benchmarks = env.datasets[FLAGS.test_dataset].benchmark_uris()
            if FLAGS.max_benchmarks:
                benchmarks = islice(benchmarks, FLAGS.max_benchmarks)
            benchmarks = list(benchmarks)

            # Repeat the searches for the requested number of iterations.
            benchmarks *= FLAGS.n
            total_count = len(benchmarks)

            # If we are resuming from a previous job, read the states that have
            # already been processed and remove those benchmarks from the list
            # of benchmarks to evaluate.
            init_states = []
            if FLAGS.resume and Path(FLAGS.leaderboard_results).is_file():
                with CompilerEnvStateReader(open(
                        FLAGS.leaderboard_results)) as reader:
                    for state in reader:
                        init_states.append(state)
                        if state.benchmark in benchmarks:
                            benchmarks.remove(state.benchmark)

            # Run the benchmark loop in background so that we can asynchronously
            # log progress.
            worker = _EvalPolicyWorker(env, benchmarks, policy, init_states)
            worker.start()
            timer = Timer().reset()
            try:
                print(f"=== Evaluating policy on "
                      f"{humanize.intcomma(total_count)} "
                      f"{FLAGS.test_dataset} benchmarks ==="
                      "\n\n"  # Blank lines will be filled below
                      )
                while worker.is_alive():
                    done_count = len(worker.states)
                    remaining_count = total_count - done_count
                    time = timer.time
                    gmean_reward = geometric_mean(
                        [s.reward for s in worker.states])
                    mean_walltime = (arithmetic_mean(
                        [s.walltime for s in worker.states]) or time)
                    print(
                        "\r\033[2A"
                        "\033[K"
                        f"Runtime: {humanize_duration_hms(time)}. "
                        f"Estimated completion: {humanize_duration_hms(mean_walltime * remaining_count)}. "
                        f"Completed: {humanize.intcomma(done_count)} / {humanize.intcomma(total_count)} "
                        f"({done_count / total_count:.1%})."
                        "\n\033[K"
                        f"Current mean walltime: {mean_walltime:.3f}s / benchmark."
                        "\n\033[K"
                        f"Current geomean reward: {gmean_reward:.4f}.",
                        flush=True,
                        end="",
                    )
                    sleep(1)
            except KeyboardInterrupt:
                print("\nkeyboard interrupt", flush=True)
                worker.alive = False
                # User interrupt, don't validate.
                FLAGS.validate = False

        if FLAGS.validate:
            FLAGS.env = "llvm-ic-v0"
            validate(["argv0", FLAGS.leaderboard_results])

Example #9

def main(argv):
    """Main entry point."""
    try:
        states = list(CompilerEnvStateReader.read_paths(argv[1:]))
    except ValueError as e:
        print(e, file=sys.stderr)
        sys.exit(1)

    if not states:
        print(
            "No inputs to validate. Pass a CSV file path as an argument, or "
            "use - to read from stdin.",
            file=sys.stderr,
        )
        sys.exit(1)

    # Send the states off for validation
    if FLAGS.debug_force_valid:
        validation_results = (
            ValidationResult(
                state=state,
                reward_validated=True,
                actions_replay_failed=False,
                reward_validation_failed=False,
                benchmark_semantics_validated=False,
                benchmark_semantics_validation_failed=False,
                walltime=0,
            )
            for state in states
        )
    else:
        validation_results = validate_states(
            env_from_flags,
            states,
            nproc=FLAGS.nproc,
            inorder=FLAGS.inorder,
        )

    # Determine the name of the reward space.
    with env_from_flags() as env:
        if FLAGS.reward_aggregation == "geomean":

            def reward_aggregation(a):
                return geometric_mean(np.clip(a, 0, None))

            reward_aggregation_name = "Geometric mean"
        elif FLAGS.reward_aggregation == "mean":
            reward_aggregation = arithmetic_mean
            reward_aggregation_name = "Mean"
        else:
            raise app.UsageError(
                f"Unknown aggregation type: '{FLAGS.reward_aggregation}'"
            )

        if env.reward_space:
            reward_name = f"{reward_aggregation_name} {env.reward_space.id}"
        else:
            reward_name = ""

    # Determine the maximum column width required for printing tabular output.
    max_state_name_length = max(
        len(s)
        for s in [state_name(s) for s in states]
        + [
            "Mean inference walltime",
            reward_name,
        ]
    )
    name_col_width = min(max_state_name_length + 2, 78)

    error_count = 0
    rewards = []
    walltimes = []

    if FLAGS.summary_only:

        def intermediate_print(*args, **kwargs):
            del args
            del kwargs

    else:
        intermediate_print = print

    def progress_message(i):
        intermediate_print(
            f"{i} remaining {plural(i, 'state', 'states')} to validate ... ",
            end="",
            flush=True,
        )

    progress_message(len(states))
    result_dicts = []

    def dump_result_dicts_to_json():
        with open(FLAGS.validation_logfile, "w") as f:
            json.dump(result_dicts, f)

    for i, result in enumerate(validation_results, start=1):
        intermediate_print("\r\033[K", to_string(result, name_col_width), sep="")
        progress_message(len(states) - i)
        result_dicts.append(result.dict())

        if not result.okay():
            error_count += 1
        elif result.reward_validated and not result.reward_validation_failed:
            rewards.append(result.state.reward)
            walltimes.append(result.state.walltime)

        if not i % 10:
            dump_result_dicts_to_json()

    dump_result_dicts_to_json()

    # Print a summary footer.
    intermediate_print("\r\033[K----", "-" * name_col_width, "-----------", sep="")
    print(f"Number of validated results: {emph(len(walltimes))} of {len(states)}")
    walltime_mean = f"{arithmetic_mean(walltimes):.3f}s"
    walltime_std = f"{stdev(walltimes):.3f}s"
    print(
        f"Mean walltime per benchmark: {emph(walltime_mean)} "
        f"(std: {emph(walltime_std)})"
    )
    reward = f"{reward_aggregation(rewards):.3f}"
    reward_std = f"{stdev(rewards):.3f}"
    print(f"{reward_name}: {emph(reward)} " f"(std: {emph(reward_std)})")

    if error_count:
        sys.exit(1)
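
# A minimal sketch (path and values are illustrative; import paths assumed)
# of producing a results CSV in the format this validator consumes through
# CompilerEnvStateReader.read_paths(): one CompilerEnvState per row.
from compiler_gym import CompilerEnvState
from compiler_gym.compiler_env_state import CompilerEnvStateWriter

with open("results.csv", "w") as f:
    writer = CompilerEnvStateWriter(f)
    writer.write_state(
        CompilerEnvState(
            benchmark="benchmark://cbench-v1/crc32",
            commandline="-a -b -c",
            walltime=5.0,
            reward=2.0,
        )
    )
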

Example #10

def test_state_from_csv_invalid_benchmark_uri():
    buf = StringIO("benchmark,reward,walltime,commandline\n"
                   "invalid-uri,2.0,5.0,-a -b -c\n")
    reader = CompilerEnvStateReader(buf)
    with pytest.raises(ValueError, match="string does not match regex"):
        next(iter(reader))