def test_add_validation_callbacks_call_count():
    """Test that custom validation callbacks are called on validate()."""
    a_call_count = 0
    b_call_count = 0

    def a(env):
        nonlocal a_call_count
        a_call_count += 1

    def b(env):
        nonlocal b_call_count
        b_call_count += 1

    benchmark = Benchmark(BenchmarkProto(uri="benchmark://example-v0/foobar"))
    benchmark.add_validation_callback(a)
    errors = benchmark.validate(env=None)
    assert errors == []
    assert a_call_count == 1
    assert b_call_count == 0

    benchmark.add_validation_callback(b)
    errors = benchmark.validate(env=None)
    assert errors == []
    assert a_call_count == 2
    assert b_call_count == 1
def test_benchmark_immutable():
    """Test that benchmark properties are immutable."""
    benchmark = Benchmark(BenchmarkProto(uri="benchmark://example-v0/foobar"))
    with pytest.raises(AttributeError):
        benchmark.uri = 123
    with pytest.raises(AttributeError):
        benchmark.proto = 123
def benchmark_from_parsed_uri(self, uri: BenchmarkUri) -> Benchmark:
    """Select a benchmark.

    Returns the corresponding :class:`Benchmark
    <compiler_gym.datasets.Benchmark>`, regardless of whether the containing
    dataset is installed or deprecated.

    :param uri: The parsed URI of the benchmark to return.

    :return: A :class:`Benchmark <compiler_gym.datasets.Benchmark>` instance.
    """
    if uri.scheme == "proto":
        path = Path(os.path.normpath(f"{uri.dataset}/{uri.path}"))
        if not path.is_file():
            raise FileNotFoundError(str(path))
        proto = BenchmarkProto()
        with open(path, "rb") as f:
            proto.ParseFromString(f.read())
        return Benchmark(proto=proto)

    if uri.scheme == "file":
        path = Path(os.path.normpath(f"{uri.dataset}/{uri.path}"))
        if not path.is_file():
            raise FileNotFoundError(str(path))
        return Benchmark.from_file(uri=uri, path=path)

    dataset = self.dataset_from_parsed_uri(uri)
    return dataset.benchmark_from_parsed_uri(uri)
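# A minimal sketch (not from the source) of exercising the file:// branch
# above, mirroring test_dataset_proto_scheme below. The file name and payload
# are illustrative; it assumes Datasets.benchmark() parses the URI string and
# delegates to benchmark_from_parsed_uri(), as the proto:// test does.
def example_dataset_file_scheme(tmpdir):
    tmpdir = Path(tmpdir)
    datasets = Datasets(datasets={})

    with open(tmpdir / "program.bc", "wb") as f:
        f.write(b"illustrative payload")

    # Resolved through the uri.scheme == "file" branch, which embeds the
    # file contents in the benchmark proto via Benchmark.from_file().
    benchmark = datasets.benchmark(f"file://{tmpdir}/program.bc")
    assert benchmark.proto.program.contents == b"illustrative payload"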
def test_add_benchmark_invalid_protocol(env: CompilerEnv):
    with pytest.raises(ValueError) as ctx:
        env.reset(
            benchmark=Benchmark(
                BenchmarkProto(
                    uri="benchmark://foo",
                    program=File(uri="https://invalid/protocol"),
                ),
            )
        )
    assert str(ctx.value) == (
        "Invalid benchmark data URI. "
        'Only the file:/// protocol is supported: "https://invalid/protocol"'
    )
@classmethod
def from_file_contents(cls, uri: str, data: bytes):
    """Construct a benchmark from raw data.

    :param uri: The URI of the benchmark.

    :param data: An array of bytes that will be passed to the compiler
        service.
    """
    return cls(proto=BenchmarkProto(uri=uri, program=File(contents=data)))
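# A minimal usage sketch (illustrative URI and payload, not from the source):
# the raw bytes are embedded directly in the File message of the proto.
def example_from_file_contents():
    benchmark = Benchmark.from_file_contents(
        uri="benchmark://example-v0/inline",
        data=b"int main() { return 0; }",
    )
    assert benchmark.proto.program.contents == b"int main() { return 0; }"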
def test_invalid_benchmark_missing_file(env: LlvmEnv):
    benchmark = Benchmark(
        BenchmarkProto(
            uri="benchmark://new",
        )
    )
    with pytest.raises(ValueError, match="No program set"):
        env.reset(benchmark=benchmark)
def test_validation_callback_error_iter():
    """Test error propagation from custom validation callback using an iterator."""

    def a(env):
        yield ValidationError(type="Compilation Error")
        yield ValidationError(type="Runtime Error")

    benchmark = Benchmark(BenchmarkProto(uri="benchmark://example-v0/foobar"))
    benchmark.add_validation_callback(a)

    errors = benchmark.ivalidate(env=None)
    assert next(errors) == ValidationError(type="Compilation Error")
    assert next(errors) == ValidationError(type="Runtime Error")
def test_benchmark_path_invalid_protocol(env: LlvmEnv):
    benchmark = Benchmark(
        BenchmarkProto(
            uri="benchmark://new", program=File(uri="invalid_protocol://test")
        ),
    )
    with pytest.raises(
        ValueError,
        match=(
            "Invalid benchmark data URI. "
            'Only the file:/// protocol is supported: "invalid_protocol://test"'
        ),
    ):
        env.reset(benchmark=benchmark)
def test_dataset_proto_scheme(tmpdir):
    """Test the proto:// scheme handler."""
    tmpdir = Path(tmpdir)
    datasets = Datasets(datasets={})

    proto = BenchmarkProto(uri="hello world")
    with open(tmpdir / "file.pb", "wb") as f:
        f.write(proto.SerializeToString())

    benchmark = datasets.benchmark(f"proto://{tmpdir}/file.pb")
    assert benchmark.proto.uri == "hello world"
    assert benchmark.uri == "benchmark://hello world"
def test_validation_callback_error():
    """Test error propagation from custom validation callback."""

    def a(env):
        yield ValidationError(type="Compilation Error")
        yield ValidationError(type="Runtime Error")

    benchmark = Benchmark(BenchmarkProto(uri="benchmark://example-v0/foobar"))
    benchmark.add_validation_callback(a)

    errors = benchmark.validate(env=None)
    assert errors == [
        ValidationError(type="Compilation Error"),
        ValidationError(type="Runtime Error"),
    ]
def test_dataset_equality_and_sorting():
    """Test comparison operators between benchmarks."""
    a = Benchmark(BenchmarkProto(uri="benchmark://example-v0/a"))
    a2 = Benchmark(BenchmarkProto(uri="benchmark://example-v0/a"))
    b = Benchmark(BenchmarkProto(uri="benchmark://example-v0/b"))

    assert a == a2
    assert a != b
    assert a < b
    assert a <= b
    assert b > a
    assert b >= a

    # String comparisons.
    assert a == "benchmark://example-v0/a"
    assert a != "benchmark://example-v0/b"
    assert a < "benchmark://example-v0/b"

    # Sorting.
    assert sorted([a2, b, a]) == [
        "benchmark://example-v0/a",
        "benchmark://example-v0/a",
        "benchmark://example-v0/b",
    ]
def test_add_validation_callbacks_values():
    """Test methods for adding and checking custom validation callbacks."""

    def a(env):
        pass

    benchmark = Benchmark(BenchmarkProto(uri="benchmark://example-v0/foobar"))
    assert benchmark.validation_callbacks() == []
    assert not benchmark.is_validatable()

    benchmark.add_validation_callback(a)
    assert benchmark.validation_callbacks() == [a]
    assert benchmark.is_validatable()

    benchmark.add_validation_callback(a)
    assert benchmark.validation_callbacks() == [a, a]
def test_benchmark_sources(tmpwd: Path):
    a = Benchmark(
        BenchmarkProto(uri="benchmark://example-v0/foo"),
        sources=[("example.py", "Hello, world!".encode("utf-8"))],
    )
    a.add_source(BenchmarkSource(filename="foo.py", contents="Hi".encode("utf-8")))

    assert list(a.sources) == [
        BenchmarkSource("example.py", "Hello, world!".encode("utf-8")),
        BenchmarkSource(filename="foo.py", contents="Hi".encode("utf-8")),
    ]

    a.write_sources_to_directory("benchmark_sources")
    with open(tmpwd / "benchmark_sources" / "example.py") as f:
        assert f.read() == "Hello, world!"
    with open(tmpwd / "benchmark_sources" / "foo.py") as f:
        assert f.read() == "Hi"
def test_validation_callback_flaky():
    """Test error propagation from a callback which *may* fail."""
    flaky = False

    def a(env):
        nonlocal flaky
        del env
        if flaky:
            yield ValidationError(type="Runtime Error")

    benchmark = Benchmark(BenchmarkProto(uri="benchmark://example-v0/foobar"))
    benchmark.add_validation_callback(a)

    errors = benchmark.validate(env=None)
    assert errors == []

    flaky = True
    errors = benchmark.validate(env=None)
    assert errors == [
        ValidationError(type="Runtime Error"),
    ]
@classmethod
def from_file(cls, uri: str, path: Path):
    """Construct a benchmark from a file.

    :param uri: The URI of the benchmark.

    :param path: A filesystem path.

    :raise FileNotFoundError: If the path does not exist.

    :return: A :class:`Benchmark <compiler_gym.datasets.Benchmark>` instance.
    """
    path = Path(path)
    if not path.is_file():
        raise FileNotFoundError(path)
    # Read the file data into memory and embed it inside the File protocol
    # buffer. An alternative would be to simply embed the file path in the
    # File.uri field, but this won't work for distributed services which
    # don't share a filesystem.
    with open(path, "rb") as f:
        contents = f.read()
    return cls(proto=BenchmarkProto(uri=uri, program=File(contents=contents)))
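# A minimal usage sketch (illustrative path and payload, not from the source):
# write a program to disk, then embed its contents in a new benchmark.
def example_from_file(tmpdir):
    path = Path(tmpdir) / "program.c"
    path.write_bytes(b"int main() { return 0; }")

    benchmark = Benchmark.from_file(uri="benchmark://example-v0/program", path=path)
    # The file contents, not the path, are stored in the proto.
    assert benchmark.proto.program.contents == b"int main() { return 0; }"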
def __init__(self, invocation: GccInvocation, bitcode: bytes, timeout: int):
    uri = f"benchmark://clang-v0/{urllib.parse.quote_plus(join_cmd(invocation.original_argv))}"
    super().__init__(
        proto=BenchmarkProto(uri=str(uri), program=File(contents=bitcode))
    )
    self.command_line = invocation.original_argv

    # Modify the command line so that it takes the bitcode file as input.
    #
    # Strip the original sources from the build command, but leave any
    # object file inputs.
    sources = set(s for s in invocation.sources if not s.endswith(".o"))
    build_command = [arg for arg in invocation.original_argv if arg not in sources]

    # Convert any object file inputs to absolute paths since the backend
    # service will have a different working directory.
    #
    # TODO(github.com/facebookresearch/CompilerGym/issues/325): To support
    # distributed execution, we should embed the contents of these object
    # files in the benchmark proto.
    object_files = set(s for s in invocation.sources if s.endswith(".o"))
    build_command = [
        os.path.abspath(arg) if arg in object_files else arg
        for arg in build_command
    ]

    # Strip any existing "-o <file>" argument pairs, then append the new
    # bitcode input and the absolute path to the output.
    for i in range(len(build_command) - 2, -1, -1):
        if build_command[i] == "-o":
            del build_command[i + 1]
            del build_command[i]
    build_command += ["-xir", "$IN", "-o", str(invocation.output_path)]
    self.proto.dynamic_config.build_cmd.argument[:] = build_command
    self.proto.dynamic_config.build_cmd.outfile[:] = [str(invocation.output_path)]
    self.proto.dynamic_config.build_cmd.timeout_seconds = timeout
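# A worked sketch (illustrative argv, not from the source) of the command
# line rewrite above: source files are stripped, any "-o <file>" pair is
# removed by the reverse scan, and the bitcode placeholder plus output path
# are appended.
def example_rewrite_build_command():
    build_command = ["clang", "a.c", "-O2", "-o", "a.out"]
    sources = {"a.c"}

    build_command = [arg for arg in build_command if arg not in sources]
    for i in range(len(build_command) - 2, -1, -1):
        if build_command[i] == "-o":
            del build_command[i + 1]
            del build_command[i]
    build_command += ["-xir", "$IN", "-o", "/abs/path/a.out"]

    assert build_command == ["clang", "-O2", "-xir", "$IN", "-o", "/abs/path/a.out"]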
def test_benchmark_properties():
    """Test benchmark properties."""
    benchmark = Benchmark(BenchmarkProto(uri="benchmark://example-v0/foobar"))
    assert benchmark.uri == "benchmark://example-v0/foobar"
    assert benchmark.proto == BenchmarkProto(uri="benchmark://example-v0/foobar")
def test_ne_strings():
    a = Benchmark(BenchmarkProto(uri="benchmark://example-v0/foo"))
    b = "benchmark://example-v0/bar"
    assert a != b
def test_ne_benchmarks():
    a = Benchmark(BenchmarkProto(uri="benchmark://example-v0/foo"))
    b = Benchmark(BenchmarkProto(uri="benchmark://example-v0/bar"))
    assert a != b
def test_eq_strings():
    a = Benchmark(BenchmarkProto(uri="benchmark://example-v0/foo"))
    b = "benchmark://example-v0/foo"
    assert a == b
def test_eq_benchmarks():
    a = Benchmark(BenchmarkProto(uri="benchmark://example-v0/foo"))
    b = Benchmark(BenchmarkProto(uri="benchmark://example-v0/foo"))
    assert a == b