def test_nested_capture():
    """Nested captures each record only what is printed at their own level."""
    with capture_output() as outer_capture:
        with capture_output() as inner_capture:
            print("Hello")
        # Printed once the inner capture has closed, so the assertions below
        # expect it on the outer capture only.
        print("World!")

    expected_inner, expected_outer = "Hello\n", "World!\n"
    assert inner_capture.stdout == expected_inner
    assert outer_capture.stdout == expected_outer
def test_capture_print_statements():
    """Prints to stdout and stderr are recorded on separate attributes."""
    with capture_output() as captured:
        print("Hello")
        print("World!", file=sys.stderr)

    # Each stream must hold exactly the text that was printed to it.
    assert captured.stdout == "Hello\n"
    assert captured.stderr == "World!\n"
def io_check(input, output, rnd_seed=100):
    """Run the shell with the given input and check stdout against a regex.

    The expected pattern is the fixed shell banner followed by ``output``.
    The test fails through ``pytest.fail`` when the pattern does not match
    at the start of the captured stdout.

    :param input: Text installed as ``sys.stdin`` while the shell runs.
    :param output: Regex fragment appended to the banner pattern.
    :param rnd_seed: Value passed to ``seed()`` before the shell starts.
    """
    seed(rnd_seed)
    old_stdin = sys.stdin
    try:
        with capture_output() as out:
            try:
                sys.stdin = StringIO(input)
                main(["argv0", "--env=llvm-v0"])
            except SystemExit:
                pass  # Expected behaviour is to call sys.exit().
        # Echo the captured text after the capture context has closed.
        print(out.stdout)

        # The banner is matched line by line, so the literal needs real
        # newlines between rows.
        # NOTE(review): confirm the blank line that follows the banner
        # against the shell's actual intro text.
        pattern = (
            r"""Initialized environment in [0-9.mu]*s
Welcome to the CompilerGym Shell!
---------------------------------
Type help or \? for more information.
The 'tutorial' command will give a step by step guide.

"""
            + output
        )
        if not re.match(pattern, out.stdout):
            pytest.fail(
                f"Failed to match regex:\n{pattern}\n"
                + ("*" * 80)
                + f"\nUsing output:\n{out.stdout}\n"
            )
    finally:
        # Always put the real stdin back, even if the check failed.
        sys.stdin = old_stdin
def test_ls_env():
    """Running with ``--ls_env`` prints text that contains ``llvm-``."""
    with capture_output() as captured:
        try:
            main(["argv0", "--ls_env"])
        except SystemExit:
            # Expected behaviour is to call sys.exit().
            pass

    listing = captured.stdout
    assert "llvm-" in listing
def test_load_test(env, tmpwd):
    """Run the load test for 1-3 workers and check its report and CSV output.

    :param env: Fixture requested by this test; not used directly.
    :param tmpwd: Fixture requested by this test; not used directly.
    """
    del env  # Unused.
    del tmpwd  # Unused.
    set_command_line_flags(
        [
            "argv0",
            "--env=llvm-v0",
            "--benchmark=cbench-v1/crc32",
            "--max_nproc=3",
            "--nproc_increment=1",
            "--num_steps=2",
            "--num_episodes=2",
        ]
    )
    with capture_output() as out:
        load_test(["argv0"])

    # One report line per worker kind for each worker count from 1 to 3.
    for nproc in range(1, 4):
        assert f"Run {nproc} threaded workers in " in out.stdout
        assert f"Run {nproc} process workers in " in out.stdout
    # The results file is looked up relative to the current directory.
    assert Path("parallelization_load_test.csv").is_file()
def test_no_input(monkeypatch):
    """Empty stdin makes ``main`` exit and report that nothing was given."""
    set_command_line_flags(["argv0", "--env=llvm-ic-v0"])
    empty_stdin = StringIO("")
    monkeypatch.setattr("sys.stdin", empty_stdin)

    # "-" is passed as the only positional argument while stdin is patched.
    with capture_output() as captured, pytest.raises(SystemExit):
        main(["argv0", "-"])

    assert "No inputs to validate" in captured.stderr
def test_invalid_csv_format(monkeypatch):
    """Single-column input is rejected with a column-count error on stderr."""
    malformed_csv = "invalid\ncsv\nformat"
    set_command_line_flags(["argv0", "--env=llvm-ic-v0"])
    monkeypatch.setattr("sys.stdin", StringIO(malformed_csv))

    with capture_output() as captured, pytest.raises(SystemExit):
        main(["argv0", "-"])

    expected_error = "Expected 4 columns in the first row of CSV"
    assert expected_error in captured.stderr
def test_invalid_csv_format(monkeypatch):
    """Unparseable stdin makes ``main`` exit with a parse error on stderr."""
    malformed_csv = "invalid\ncsv\nformat"
    # Reset previously parsed flags before parsing this test's flags.
    flags.FLAGS.unparse_flags()
    flags.FLAGS(["argv0", "--env=llvm-ic-v0", "--dataset=cBench-v0"])
    monkeypatch.setattr("sys.stdin", StringIO(malformed_csv))

    with capture_output() as captured, pytest.raises(SystemExit):
        main(["argv0"])

    assert "Failed to parse input:" in captured.stderr
def test_validate_cbench_null_options(monkeypatch, benchmarks: List[str]):
    """Every benchmark validates when given the no-op ``opt`` command line.

    :param monkeypatch: pytest fixture used to replace ``sys.stdin``.
    :param benchmarks: Benchmark names, one CSV row is generated for each.
    """
    header = "benchmark,reward,walltime,commandline"
    # The reward column is left empty in every row.
    rows = [f"{name},,0,opt input.bc -o output.bc" for name in benchmarks]
    stdin = "\n".join([header, *rows])

    set_command_line_flags(["argv0", "--env=llvm-v0"])
    monkeypatch.setattr("sys.stdin", StringIO(stdin))
    with capture_output() as captured:
        main(["argv0", "-"])

    assert not captured.stderr
    # Every benchmark passed.
    assert captured.stdout.count("✅") == len(benchmarks)
def test_run_random_walk_smoke_test():
    """A five-step random walk reports that five steps were completed."""
    FLAGS.unparse_flags()
    FLAGS(["argv0"])

    # The environment is entered inside the capture and closed before it.
    with capture_output() as captured, compiler_gym.make(
        "llvm-autophase-ic-v0"
    ) as env:
        env.benchmark = "cbench-v1/crc32"
        run_random_walk(env=env, step_count=5)

    print(captured.stdout)
    # Note the ".*" before and after the step count to ignore the shell
    # formatting.
    summary = re.search(r"Completed .*5.* steps in ", captured.stdout)
    assert summary
def test_okay_llvm_result(monkeypatch):
    """A result row whose reward matches is reported as passing.

    :param monkeypatch: pytest fixture used to replace ``sys.stdin``.
    """
    # Header and data row must be newline-separated to form two CSV rows.
    stdin = "\n".join(
        [
            "benchmark,reward,commandline,walltime",
            "benchmark://cbench-v1/crc32,0,opt input.bc -o output.bc,0.3",
        ]
    )
    set_command_line_flags(["argv0", "--env=llvm-ic-v0"])
    monkeypatch.setattr("sys.stdin", StringIO(stdin))
    with capture_output() as out:
        main(["argv0", "-"])

    assert "✅ cbench-v1/crc32 " in out.stdout
    assert not out.stderr
def test_invalid_reward_llvm_result(monkeypatch):
    """A result row with the wrong reward fails validation and exits.

    :param monkeypatch: pytest fixture used to replace ``sys.stdin``.
    """
    # Header and data row must be newline-separated to form two CSV rows.
    stdin = "\n".join(
        [
            "benchmark,reward,commandline,walltime",
            "benchmark://cbench-v1/crc32,0.5,opt input.bc -o output.bc,0.3",
        ]
    )
    set_command_line_flags(["argv0", "--env=llvm-ic-v0"])
    monkeypatch.setattr("sys.stdin", StringIO(stdin))
    with capture_output() as out:
        with pytest.raises(SystemExit):
            main(["argv0", "-"])

    # The mismatch is reported on stdout; stderr stays empty.
    assert (
        "❌ cbench-v1/crc32 Expected reward 0.5 but received reward 0.0\n"
        in out.stdout
    )
    assert not out.stderr
def test_okay_llvm_result(monkeypatch):
    """A result row whose reward matches is reported as passing.

    :param monkeypatch: pytest fixture used to replace ``sys.stdin``.
    """
    # Header and data row must be newline-separated to form two CSV rows.
    # The local is named ``stdin`` so it does not shadow the ``input`` builtin.
    stdin = "\n".join(
        [
            "benchmark,reward,commandline,walltime",
            "benchmark://cBench-v0/dijkstra,0,opt input.bc -o output.bc,0.3",
        ]
    )
    # Reset previously parsed flags before parsing this test's flags.
    flags.FLAGS.unparse_flags()
    flags.FLAGS(["argv0", "--env=llvm-ic-v0", "--dataset=cBench-v0"])
    monkeypatch.setattr("sys.stdin", StringIO(stdin))
    with capture_output() as out:
        main(["argv0"])

    assert out.stdout.startswith("✅ cBench-v0/dijkstra ")
    assert not out.stderr
def test_run_train_smoke_test():
    """Train for two small epochs on CPU and check the first epoch is logged."""
    argv = [
        "argv0",
        "--dataset_size=64",
        "--batch_size=4",
        "--num_epoch=2",
        "--device=cpu",
    ]
    saved_argv = sys.argv
    sys.argv = argv
    try:
        # Reset previously parsed flags so this test parses from defaults.
        FLAGS.unparse_flags()
        FLAGS(argv)
        with capture_output() as out:
            main(["argv0"])
    finally:
        # Do not leak the replaced argv into later tests.
        sys.argv = saved_argv

    assert "Epoch num 0 training" in out.stdout
def test_run_tabular_q_smoke_test():
    """A short run on crc32 prints a "Resulting sequence" line."""
    argv = [
        "argv0",
        "--episode_length=5",
        "--episodes=10",
        "--log_every=2",
        "--benchmark=cbench-v1/crc32",
    ]
    # Reset previously parsed flags before parsing this test's flags.
    FLAGS.unparse_flags()
    FLAGS(argv)

    with capture_output() as captured:
        main(["argv0"])

    assert "Resulting sequence" in captured.stdout
def test_multiple_valid_inputs(monkeypatch):
    """Three identical valid rows are each validated and reported.

    :param monkeypatch: pytest fixture used to replace ``sys.stdin``.
    """
    header = "benchmark,reward,walltime,commandline"
    row = "benchmark://cbench-v1/crc32,,0,opt input.bc -o output.bc"
    # Rows must be newline-separated; the same row is submitted three times.
    stdin = "\n".join([header] + [row] * 3)

    set_command_line_flags(["argv0", "--env=llvm-v0"])
    monkeypatch.setattr("sys.stdin", StringIO(stdin))
    with capture_output() as out:
        main(["argv0", "-"])

    assert not out.stderr
    assert out.stdout.count("✅") == 3  # Every benchmark passed.
def test_okay_llvm_result_file_input():
    """A valid result CSV passed by file path is reported as passing."""
    # Header and data row must be newline-separated to form two CSV rows.
    csv_text = "\n".join(
        [
            "benchmark,reward,commandline,walltime",
            "benchmark://cbench-v1/crc32,0,opt input.bc -o output.bc,0.3",
        ]
    )
    with tempfile.TemporaryDirectory() as d:
        path = Path(d) / "test.csv"
        path.write_text(csv_text)
        set_command_line_flags(["argv0", "--env=llvm-ic-v0"])

        # main() has to run while the temporary directory still exists.
        with capture_output() as out:
            main(["argv0", str(path)])

    assert "✅ cbench-v1/crc32 " in out.stdout
    assert not out.stderr
def test_invalid_reward_llvm_result(monkeypatch):
    """A result row with the wrong reward fails validation and exits.

    :param monkeypatch: pytest fixture used to replace ``sys.stdin``.
    """
    # Header and data row must be newline-separated to form two CSV rows.
    # The local is named ``stdin`` so it does not shadow the ``input`` builtin.
    stdin = "\n".join(
        [
            "benchmark,reward,commandline,walltime",
            "benchmark://cBench-v0/dijkstra,0.5,opt input.bc -o output.bc,0.3",
        ]
    )
    # Reset previously parsed flags before parsing this test's flags.
    flags.FLAGS.unparse_flags()
    flags.FLAGS(["argv0", "--env=llvm-ic-v0", "--dataset=cBench-v0"])
    monkeypatch.setattr("sys.stdin", StringIO(stdin))
    with capture_output() as out:
        with pytest.raises(SystemExit):
            main(["argv0"])

    # The mismatch is reported on stdout; stderr stays empty.
    assert out.stdout.startswith(
        "❌ cBench-v0/dijkstra Expected reward 0.5000 but received reward 0.0000\n"
    )
    assert not out.stderr
def test_run_actor_critic_smoke_test():
    """A short seeded run on crc32 prints the final performance summary."""
    argv = [
        "argv0",
        "--seed=0",
        "--episode_len=2",
        "--episodes=10",
        "--log_interval=5",
        "--benchmark=cbench-v1/crc32",
    ]
    saved_argv = sys.argv
    sys.argv = argv
    try:
        # Reset previously parsed flags before parsing this test's flags.
        FLAGS.unparse_flags()
        FLAGS(argv)
        with capture_output() as out:
            main(["argv0"])
    finally:
        # Do not leak the replaced argv into later tests.
        sys.argv = saved_argv

    assert "Final performance (avg reward)" in out.stdout
def io_check(input, output, rnd_seed=100):
    """Run the shell with the given input and check stdout against a regex.

    The expected pattern is the fixed shell banner, then ``output``, then
    the exit prompt. Captured stdout has ANSI escape sequences and trailing
    whitespace removed before matching. On mismatch the test fails with a
    unified diff between the pattern and the cleaned output.

    :param input: Text installed as ``sys.stdin`` while the shell runs.
    :param output: Regex fragment placed between the banner and exit prompt.
    :param rnd_seed: Value passed to ``seed()`` before the shell starts.
    """
    seed(rnd_seed)
    old_stdin = sys.stdin
    try:
        with capture_output() as out:
            try:
                sys.stdin = StringIO(input)
                main(["argv0", "--env=llvm-v0"])
            except SystemExit:
                pass  # Expected behaviour is to call sys.exit().
        # Echo the captured text after the capture context has closed.
        print(out.stdout)

        # Both the pattern and the output are split on "\n" below, so the
        # literals need real newlines between rows.
        # NOTE(review): confirm the blank lines around ``output`` against
        # the shell's actual banner and exit text.
        pattern = (
            r"""Initialized environment in [0-9.mu]*s
Welcome to the CompilerGym Shell!
---------------------------------
Type help or \? for more information.
The 'tutorial' command will give a step by step guide.

"""
            + output
            + r"""

compiler_gym:[a-zA-Z0-9/-]+> Exiting
"""
        )

        # Strip ANSI escape sequences from output that are used for formatting.
        ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")
        stdout = ansi_escape.sub("", out.stdout)
        # Strip trailing whitespace from output.
        stdout = "\n".join(n.rstrip() for n in stdout.split("\n"))
        if not re.match(pattern, stdout):
            # Create a diff between the expected regex and the actual output.
            # Diffing a regex will create a lot of false-positives, since any
            # character groups or other expressions will be different, but can
            # still be helpful for tracking down the important differences.
            diff = unified_diff(
                pattern.split("\n"),
                stdout.split("\n"),
                fromfile="Expected output regex",
                tofile="Actual output",
                # The inputs carry no line endings, so the control lines must
                # not either; otherwise the join below doubles the newlines.
                lineterm="",
            )
            pytest.fail("\n".join(diff))
    finally:
        # Always put the real stdin back, even if the check failed.
        sys.stdin = old_stdin
def __call__(self, env: CompilerEnv, seed: int = 0xCC) -> CompilerEnvState:
    """Autotune the given environment.

    :param env: The environment to autotune.

    :param seed: The random seed for the autotuner.

    :returns: A CompilerEnvState tuple describing the autotuning result.
    """
    scratch_root = transient_cache_path(".")
    # The autotuner runs from a throwaway working directory with its
    # stdout/stderr captured. Only the autotune call itself is timed.
    with tempfile.TemporaryDirectory(
        dir=scratch_root, prefix="autotune-"
    ) as workdir:
        with temporary_working_directory(Path(workdir)):
            with capture_output(), Timer() as timer:
                self.autotune(env, seed=seed, **self.autotune_kwargs)

    # Read the result fields from the environment after the run.
    benchmark_uri = env.benchmark.uri
    commandline = env.commandline()
    elapsed = timer.time
    final_reward = self.optimization_target.final_reward(env)
    return CompilerEnvState(
        benchmark=benchmark_uri,
        commandline=commandline,
        walltime=elapsed,
        reward=final_reward,
    )
def test_validate_cBenh_null_options(monkeypatch):
    """Every listed cBench-v0 benchmark validates with the no-op command line.

    :param monkeypatch: pytest fixture used to replace ``sys.stdin``.
    """
    benchmarks = [
        "gsm",
        "lame",
        "stringsearch",
        "ghostscript",
        "qsort",
        "sha",
        "ispell",
        "blowfish",
        "adpcm",
        "tiffdither",
        "bzip2",
        "stringsearch2",
        "bitcount",
        "jpeg-d",
        "jpeg-c",
        "dijkstra",
        "rijndael",
        "patricia",
        "tiff2rgba",
        "crc32",
        "tiff2bw",
        "tiffmedian",
        "susan",
    ]
    header = "benchmark,reward,walltime,commandline"
    # One newline-separated CSV row per benchmark, reward column left empty.
    # The local is named ``stdin`` so it does not shadow the ``input`` builtin.
    rows = [
        f"benchmark://cBench-v0/{name},,0,opt input.bc -o output.bc"
        for name in benchmarks
    ]
    stdin = "\n".join([header] + rows)

    # Reset previously parsed flags before parsing this test's flags.
    flags.FLAGS.unparse_flags()
    flags.FLAGS(["argv0", "--env=llvm-v0", "--dataset=cBench-v0"])
    monkeypatch.setattr("sys.stdin", StringIO(stdin))
    with capture_output() as out:
        main(["argv0"])

    # Every benchmark passed (23 rows).
    assert out.stdout.count("✅") == len(benchmarks)
    assert not out.stderr