示例#1
0
def test_random_search():
    """Smoke-test a leaderboard policy run on a single benchmark.

    The command line is faked through ``sys.argv`` and the evaluation is
    expected to terminate with ``SystemExit``. The original ``sys.argv`` is
    restored afterwards so the fake arguments do not leak into other tests.
    """
    # NOTE(review): despite the test name, the policy exercised here is
    # e_greedy_search - confirm that this is intentional.
    saved_argv = sys.argv
    sys.argv = [
        "argv0",
        "--n=1",
        "--max_benchmarks=1",
        "--nproc=1",
        "--novalidate",
    ]
    try:
        with pytest.raises(SystemExit):
            eval_llvm_instcount_policy(e_greedy_search)
    finally:
        # sys.argv is process-wide state; put it back even if the test fails.
        sys.argv = saved_argv
示例#2
0
def test_tabular_q():
    """Smoke-test the ``train_and_run`` policy on a single benchmark.

    Any previously parsed absl flags are un-parsed, the flags are re-parsed
    from a fixed argument list, and the evaluation is expected to terminate
    with ``SystemExit``.
    """
    argv = ["argv0", "--n=1", "--max_benchmarks=1", "--nproc=1", "--novalidate"]
    FLAGS.unparse_flags()
    FLAGS(argv)
    with pytest.raises(SystemExit):
        eval_llvm_instcount_policy(train_and_run)
示例#3
0
def test_eval_llvm_instcount_policy_resume(tmpwd):
    """Check that ``--resume`` extends an existing results file.

    The evaluation is run three times against the same ``test.csv`` with a
    growing workload. After every run the file must exist, must start with the
    contents left by the previous run, and must contain the expected number of
    lines (header row included).
    """
    results_path = Path("test.csv")
    # Each stage is (--n, --max_benchmarks, expected number of log lines).
    stages = (
        (1, 1, 2),  # One benchmark: a header row plus a single entry.
        (1, 2, 3),  # Two benchmarks.
        (2, 2, 5),  # Two runs of each of the two benchmarks.
    )

    previous_log = None
    for runs, benchmarks, expected_lines in stages:
        set_command_line_flags([
            "argv0",
            f"--n={runs}",
            f"--max_benchmarks={benchmarks}",
            "--novalidate",
            "--resume",
            "--leaderboard_results=test.csv",
        ])
        with pytest.raises(SystemExit):
            eval_llvm_instcount_policy(null_policy)

        assert results_path.is_file()
        with open(results_path) as f:
            current_log = f.read()
        if previous_log is not None:
            # The log must extend, not replace, the previous stage's log.
            assert current_log.startswith(previous_log)
        assert len(current_log.rstrip().split("\n")) == expected_lines
        previous_log = current_log
示例#4
0
#
"""Evaluate deep q network for leaderboard"""
from absl import app
from compiler_gym.leaderboard.llvm_instcount import eval_llvm_instcount_policy
from compiler_gym.envs import LlvmEnv
from dqn import rollout, train, Agent, train_and_run
import torch


def run(env: LlvmEnv) -> None:
    """Roll out an agent with saved weights on ``env``.

    An ``Agent`` is constructed, ``env`` is switched to the ``InstCountNorm``
    observation space, the weights stored in ``./H10-N4000-INSTCOUNTNORM.pth``
    are loaded into the agent's ``Q_eval`` network, and ``rollout`` is called
    once with the agent and the environment.
    """
    dqn_agent = Agent(n_actions=15, input_dims=[69])
    env.observation_space = "InstCountNorm"
    saved_weights = torch.load("./H10-N4000-INSTCOUNTNORM.pth")
    dqn_agent.Q_eval.load_state_dict(saved_weights)
    rollout(dqn_agent, env)


# Script entry point: evaluate the `run` policy with the leaderboard helper.
if __name__ == "__main__":
    # The value returned by eval_llvm_instcount_policy(run) is handed to
    # absl's app.run().
    # NOTE(review): whether that return value is a callable suitable for
    # app.run() depends on the installed compiler_gym version - confirm.
    app.run(eval_llvm_instcount_policy(run))
示例#5
0
def test_eval_llvm_instcount_policy_invalid_flag():
    """A negative ``--n`` value must make the evaluation raise AssertionError."""
    invalid_argv = ["argv0", "--n=-1"]
    set_command_line_flags(invalid_argv)
    with pytest.raises(AssertionError):
        eval_llvm_instcount_policy(null_policy)
示例#6
0
def test_eval_llvm_instcount_policy():
    """Evaluating ``null_policy`` on one benchmark must end with SystemExit."""
    argv = ["argv0", "--n=1", "--max_benchmarks=1", "--novalidate"]
    set_command_line_flags(argv)
    with pytest.raises(SystemExit):
        eval_llvm_instcount_policy(null_policy)
示例#7
0
from absl import app, flags

from compiler_gym.envs import LlvmEnv
from compiler_gym.leaderboard.llvm_instcount import eval_llvm_instcount_policy

# Put the "examples" directory, located three levels above this file's
# directory, at the front of the import path so that the `tabular_q` import
# below can be resolved (presumably that is where tabular_q lives - confirm).
sys.path.insert(
    0,
    os.path.dirname(os.path.realpath(__file__)) + "/../../../examples")
from tabular_q import (  # noqa pylint: disable=wrong-import-position
    StateActionTuple, rollout, train,
)

# Module-level alias for absl's flag container.
FLAGS = flags.FLAGS


def train_and_run(env: LlvmEnv) -> None:
    """Train a tabular Q policy on a fork of ``env``, then roll it out on ``env``.

    Args:
        env: The environment to evaluate on. Its observation space is set to
            ``"Autophase"``. Training runs on a fork of this environment; the
            final rollout runs on ``env`` itself.
    """
    FLAGS.log_every = 0  # Disable printing to stdout

    q_table: Dict[StateActionTuple, float] = {}
    env.observation_space = "Autophase"
    training_env = env.fork()
    try:
        train(q_table, training_env)
    finally:
        # Always release the forked environment, even if training raises.
        training_env.close()
    rollout(q_table, env, printout=False)


# Script entry point: evaluate the `train_and_run` policy with the leaderboard
# helper.
if __name__ == "__main__":
    # The value returned by eval_llvm_instcount_policy(train_and_run) is handed
    # to absl's app.run().
    # NOTE(review): whether that return value is a callable suitable for
    # app.run() depends on the installed compiler_gym version - confirm.
    app.run(eval_llvm_instcount_policy(train_and_run))
示例#8
0
        ) for _ in range(FLAGS.nproc)
    ]
    for worker in workers:
        worker.start()

    sleep(FLAGS.search_time)

    # Stop the workers.
    for worker in workers:
        worker.alive = False
    for worker in workers:
        worker.join()

    # Aggregate the best results.
    best_actions = []
    best_reward = -float("inf")
    for worker in workers:
        if worker.best_returns > best_reward:
            best_reward, best_actions = worker.best_returns, list(
                worker.best_actions)

    # Replay the best sequence of actions to produce the final environment
    # state.
    for action in best_actions:
        _, _, done, _ = env.step(action)
        assert not done


# Script entry point: evaluate the `random_search` policy with the leaderboard
# helper when this file is executed directly.
if __name__ == "__main__":
    eval_llvm_instcount_policy(random_search)
示例#9
0
def rollout(agent, env):
    """Play one episode with ``agent`` on ``env`` and return the summed reward.

    The environment is reset once. The agent then picks an action for every
    observation until ``env.step`` reports that the episode is done.
    """
    total_reward = 0
    observation = env.reset()
    finished = False
    while not finished:
        chosen_action = agent.compute_action(observation)
        observation, step_reward, finished, _ = env.step(chosen_action)
        total_reward += step_reward
    return total_reward

def test(env, agent_path):
    """Load the agent stored at ``agent_path`` and roll it out on ``env``.

    ``env`` is switched to the ``InstCount`` observation space and wrapped in
    ``stepWrapper`` before the rollout.
    """
    test_agent = load(agent_path)
    env.observation_space = "InstCount"
    # Wrapped so the episode can terminate after n rewardless steps.
    rollout(test_agent, stepWrapper(env))

# Train an agent, then evaluate it with the leaderboard helper.
if __name__ == "__main__":
    save_dir = './log_dir'
    agent_path, analysis_obj = train({"episodes_total": 1000000}, save_dir)
    # NOTE(review): this agent is not used below; test() loads its own copy
    # from agent_path. Kept in case load() has side effects that are relied on.
    agent = load(agent_path)

    def _policy(env):
        """Run ``test`` on ``env`` using the agent saved at ``agent_path``."""
        test(env, agent_path=agent_path)

    # The evaluator needs a callable that takes the environment. Calling
    # test(agent_path=...) directly here would raise TypeError because the
    # required ``env`` argument is missing.
    eval_llvm_instcount_policy(_policy)

示例#10
0
                    logging.debug(
                        "Greedy search terminated after %d steps, "
                        "no further reward attainable",
                        step_count,
                    )
                    done = True
                else:
                    _, reward, done, _ = env.step(best.action)
                    logging.debug(
                        "Step %d, greedy action %s, reward %.4f, cumulative %.4f",
                        step_count,
                        env.action_space.flags[best.action],
                        reward,
                        env.episode_reward,
                    )
                    if env.reward_space.deterministic and reward != best.reward:
                        logging.warning(
                            "Action %s produced different reward on replay, %.4f != %.4f",
                            env.action_space.flags[best.action],
                            best.reward,
                            reward,
                        )

                # Stop the search if we have reached a terminal state.
                if done:
                    return


# Script entry point: evaluate the `e_greedy_search` policy with the
# leaderboard helper when this file is executed directly.
if __name__ == "__main__":
    eval_llvm_instcount_policy(e_greedy_search)