Example #1
0
def test_add_validation_callbacks_call_count():
    """Check that each registered validation callback runs once per validate()."""
    calls = {"a": 0, "b": 0}

    def a(env):
        calls["a"] += 1

    def b(env):
        calls["b"] += 1

    benchmark = Benchmark(BenchmarkProto(uri="benchmark://example-v0/foobar"))

    # Only ``a`` is registered, so only ``a`` runs.
    benchmark.add_validation_callback(a)
    assert benchmark.validate(env=None) == []
    assert (calls["a"], calls["b"]) == (1, 0)

    # After registering ``b``, both callbacks run on the next validate().
    benchmark.add_validation_callback(b)
    assert benchmark.validate(env=None) == []
    assert (calls["a"], calls["b"]) == (2, 1)
Example #2
0
def test_benchmark_immutable():
    """Assigning to a benchmark's ``uri`` or ``proto`` raises AttributeError."""
    benchmark = Benchmark(BenchmarkProto(uri="benchmark://example-v0/foobar"))
    for attribute in ("uri", "proto"):
        with pytest.raises(AttributeError):
            setattr(benchmark, attribute, 123)
Example #3
0
    def benchmark_from_parsed_uri(self, uri: BenchmarkUri) -> Benchmark:
        """Select a benchmark.

        Returns the corresponding :class:`Benchmark
        <compiler_gym.datasets.Benchmark>`, regardless of whether the containing
        dataset is installed or deprecated.

        URIs with the ``proto`` or ``file`` scheme are resolved directly from
        the local filesystem. Any other scheme is delegated to the dataset
        returned by ``dataset_from_parsed_uri()``.

        :param uri: The parsed URI of the benchmark to return.

        :return: A :class:`Benchmark <compiler_gym.datasets.Benchmark>`
            instance.

        :raise FileNotFoundError: If a ``proto`` or ``file`` URI does not
            name an existing file.
        """
        if uri.scheme not in ("proto", "file"):
            return self.dataset_from_parsed_uri(uri).benchmark_from_parsed_uri(uri)

        # For local schemes, the dataset and path components together form
        # the filesystem path.
        local_path = Path(os.path.normpath(f"{uri.dataset}/{uri.path}"))
        if not local_path.is_file():
            raise FileNotFoundError(str(local_path))

        if uri.scheme == "file":
            return Benchmark.from_file(uri=uri, path=local_path)

        # Remaining case is the "proto" scheme: the file holds a serialized
        # BenchmarkProto message.
        message = BenchmarkProto()
        with open(local_path, "rb") as f:
            message.ParseFromString(f.read())
        return Benchmark(proto=message)
Example #4
0
def test_add_benchmark_invalid_protocol(env: CompilerEnv):
    """Resetting with a benchmark whose program URI uses https:// is rejected."""
    expected_message = (
        "Invalid benchmark data URI. "
        'Only the file:/// protocol is supported: "https://invalid/protocol"'
    )
    with pytest.raises(ValueError) as ctx:
        program = File(uri="https://invalid/protocol")
        benchmark = Benchmark(BenchmarkProto(uri="benchmark://foo", program=program))
        env.reset(benchmark=benchmark)
    assert str(ctx.value) == expected_message
Example #5
0
    def from_file_contents(cls, uri: str, data: bytes):
        """Construct a benchmark whose program is the given in-memory bytes.

        :param uri: The URI of the benchmark.

        :param data: An array of bytes that will be passed to the compiler
            service.

        :return: A new ``cls`` instance wrapping a ``BenchmarkProto`` that
            embeds ``data`` as the program contents.
        """
        program = File(contents=data)
        proto = BenchmarkProto(uri=uri, program=program)
        return cls(proto=proto)
def test_invalid_benchmark_missing_file(env: LlvmEnv):
    """Resetting with a benchmark that carries no program data is rejected."""
    empty_benchmark = Benchmark(BenchmarkProto(uri="benchmark://new"))
    with pytest.raises(ValueError, match="No program set"):
        env.reset(benchmark=empty_benchmark)
Example #7
0
def test_validation_callback_error_iter():
    """Test error propagation from custom validation callback using iterable."""

    def a(env):
        yield ValidationError(type="Compilation Error")
        yield ValidationError(type="Runtime Error")

    benchmark = Benchmark(BenchmarkProto(uri="benchmark://example-v0/foobar"))
    benchmark.add_validation_callback(a)

    errors = benchmark.ivalidate(env=None)
    # Each comparison must be asserted. A bare ``==`` expression is evaluated
    # and silently discarded, so it can never fail the test.
    assert next(errors) == ValidationError(type="Compilation Error")
    assert next(errors) == ValidationError(type="Runtime Error")
    # The only registered callback yields exactly two errors, so the iterator
    # should now be exhausted.
    assert next(errors, None) is None
def test_benchmark_path_invalid_protocol(env: LlvmEnv):
    """Resetting with a benchmark whose program URI uses an unsupported protocol is rejected."""
    import re  # Only this test needs regex escaping.

    benchmark = Benchmark(
        BenchmarkProto(uri="benchmark://new", program=File(uri="invalid_protocol://test"))
    )
    expected_message = (
        "Invalid benchmark data URI. "
        'Only the file:/// protocol is supported: "invalid_protocol://test"'
    )

    # pytest.raises(match=...) performs re.search, so the literal message must
    # be escaped. Otherwise its "." characters act as wildcards and the check
    # is looser than intended.
    with pytest.raises(ValueError, match=re.escape(expected_message)):
        env.reset(benchmark=benchmark)
Example #9
0
def test_dataset_proto_scheme(tmpdir):
    """A ``proto://`` URI loads a serialized BenchmarkProto from disk."""
    directory = Path(tmpdir)
    datasets = Datasets(datasets={})

    # Serialize a minimal proto to disk so the proto:// handler can read it back.
    serialized = BenchmarkProto(uri="hello world").SerializeToString()
    (directory / "file.pb").write_bytes(serialized)

    benchmark = datasets.benchmark(f"proto://{directory}/file.pb")

    assert benchmark.proto.uri == "hello world"
    assert benchmark.uri == "benchmark://hello world"
Example #10
0
def test_validation_callback_error():
    """validate() collects every error yielded by a callback, in order."""

    def a(env):
        for error_type in ("Compilation Error", "Runtime Error"):
            yield ValidationError(type=error_type)

    benchmark = Benchmark(BenchmarkProto(uri="benchmark://example-v0/foobar"))
    benchmark.add_validation_callback(a)

    assert benchmark.validate(env=None) == [
        ValidationError(type="Compilation Error"),
        ValidationError(type="Runtime Error"),
    ]
Example #11
0
def test_dataset_equality_and_sorting():
    """Test comparison operators and sorting for benchmarks.

    Benchmarks are compared with each other and with plain URI strings. The
    assertions below rely on comparisons being based on the benchmark URI.
    """
    a = Benchmark(BenchmarkProto(uri="benchmark://example-v0/a"))
    a2 = Benchmark(BenchmarkProto(uri="benchmark://example-v0/a"))
    b = Benchmark(BenchmarkProto(uri="benchmark://example-v0/b"))

    # Benchmark vs. benchmark comparisons.
    assert a == a2
    assert a != b
    assert a < b
    assert a <= b
    assert b > a
    assert b >= a

    # Benchmark vs. URI string comparisons.
    assert a == "benchmark://example-v0/a"
    assert a != "benchmark://example-v0/b"
    assert a < "benchmark://example-v0/b"

    # Sorting orders by URI, and the sorted benchmarks compare equal to the
    # expected URI strings.
    assert sorted([a2, b, a]) == [
        "benchmark://example-v0/a",
        "benchmark://example-v0/a",
        "benchmark://example-v0/b",
    ]
Example #12
0
def test_add_validation_callbacks_values():
    """validation_callbacks() reports registered callbacks, duplicates included."""

    def noop(env):
        return None

    benchmark = Benchmark(BenchmarkProto(uri="benchmark://example-v0/foobar"))

    # A fresh benchmark has no callbacks and so cannot be validated.
    assert benchmark.validation_callbacks() == []
    assert not benchmark.is_validatable()

    # One registration makes the benchmark validatable.
    benchmark.add_validation_callback(noop)
    assert benchmark.validation_callbacks() == [noop]
    assert benchmark.is_validatable()

    # Registering the same callback again stores it a second time.
    benchmark.add_validation_callback(noop)
    assert benchmark.validation_callbacks() == [noop, noop]
Example #13
0
def test_benchmark_sources(tmpwd: Path):
    """Benchmark sources can be listed and written out to a directory."""
    hello = "Hello, world!".encode("utf-8")
    hi = "Hi".encode("utf-8")

    a = Benchmark(
        BenchmarkProto(uri="benchmark://example-v0/foo"),
        sources=[("example.py", hello)],
    )
    a.add_source(BenchmarkSource(filename="foo.py", contents=hi))

    assert list(a.sources) == [
        BenchmarkSource("example.py", hello),
        BenchmarkSource(filename="foo.py", contents=hi),
    ]

    # The directory argument is relative; the files are read back from under tmpwd.
    a.write_sources_to_directory("benchmark_sources")

    out_dir = tmpwd / "benchmark_sources"
    assert (out_dir / "example.py").read_text() == "Hello, world!"
    assert (out_dir / "foo.py").read_text() == "Hi"
Example #14
0
def test_validation_callback_flaky():
    """A callback's output may differ between successive validate() calls."""
    flaky = False

    def a(env):
        del env  # Unused.
        # ``flaky`` is only read here, so the closure sees later rebindings in
        # the enclosing scope without needing ``nonlocal``.
        if flaky:
            yield ValidationError(type="Runtime Error")

    benchmark = Benchmark(BenchmarkProto(uri="benchmark://example-v0/foobar"))
    benchmark.add_validation_callback(a)

    # While the flag is off, the callback yields nothing.
    assert benchmark.validate(env=None) == []

    # Once the flag is on, the callback reports one error.
    flaky = True
    assert benchmark.validate(env=None) == [ValidationError(type="Runtime Error")]
Example #15
0
    def from_file(cls, uri: str, path: Path):
        """Construct a benchmark from a file.

        :param uri: The URI of the benchmark. It is converted with ``str()``
            before being stored, so a parsed URI object whose string form is
            the URI is also accepted.

        :param path: A filesystem path.

        :raise FileNotFoundError: If the path is not an existing regular file.

        :return: A :class:`Benchmark <compiler_gym.datasets.Benchmark>`
            instance.
        """
        path = Path(path)
        if not path.is_file():
            raise FileNotFoundError(path)
        # Read the file data into memory and embed it inside the File protocol
        # buffer. An alternative would be to simply embed the file path in the
        # File.uri field, but this won't work for distributed services which
        # don't share a filesystem.
        with open(path, "rb") as f:
            contents = f.read()
        # The proto ``uri`` field is a string. Convert explicitly so that callers
        # passing a parsed URI object do not fail when the proto is built.
        return cls(proto=BenchmarkProto(uri=str(uri), program=File(contents=contents)))
Example #16
0
    def __init__(self, invocation: GccInvocation, bitcode: bytes, timeout: int):
        """Create a benchmark from a compiler invocation and its bitcode.

        The benchmark URI is ``benchmark://clang-v0/`` followed by the
        URL-quoted original command line, and the program contents are
        ``bitcode``. A rewritten build command is stored in
        ``self.proto.dynamic_config.build_cmd``. That command takes ``$IN`` as
        input and writes to ``invocation.output_path``.

        :param invocation: The parsed compiler invocation. Its
            ``original_argv``, ``sources`` and ``output_path`` are read.

        :param bitcode: Bytes embedded as the benchmark's program contents.

        :param timeout: Written to the build command's ``timeout_seconds``
            field.
        """
        uri = f"benchmark://clang-v0/{urllib.parse.quote_plus(join_cmd(invocation.original_argv))}"
        super().__init__(
            proto=BenchmarkProto(uri=str(uri), program=File(contents=bitcode))
        )
        # Keep the unmodified command line on the instance.
        self.command_line = invocation.original_argv

        # Modify the commandline so that it takes the bitcode file as input.
        #
        # Strip the original sources from the build command, but leave any
        # object file inputs.
        sources = set(s for s in invocation.sources if not s.endswith(".o"))
        build_command = [arg for arg in invocation.original_argv if arg not in sources]

        # Convert any object file inputs to absolute paths since the backend
        # service will have a different working directory.
        #
        # TODO(github.com/facebookresearch/CompilerGym/issues/325): To support
        # distributed execution, we should embed the contents of these object
        # files in the benchmark proto.
        object_files = set(s for s in invocation.sources if s.endswith(".o"))
        build_command = [
            os.path.abspath(arg) if arg in object_files else arg
            for arg in build_command
        ]

        # Append the new source to the build command and specify the absolute path
        # to the output.
        #
        # First remove every existing "-o <file>" pair. Iterating from the end
        # means deletions never shift indices that are still to be visited.
        # Starting at len - 2 guarantees build_command[i + 1] exists.
        # NOTE(review): a trailing "-o" with no argument is left in place, and
        # a "-o -o <file>" sequence would raise IndexError. Both look
        # pathological but are unhandled.
        for i in range(len(build_command) - 2, -1, -1):
            if build_command[i] == "-o":
                del build_command[i + 1]
                del build_command[i]
        # "$IN" is presumably substituted with the bitcode input path by
        # whatever executes build_cmd; "-xir" presumably selects LLVM IR as the
        # input language. TODO confirm both.
        build_command += ["-xir", "$IN", "-o", str(invocation.output_path)]
        self.proto.dynamic_config.build_cmd.argument[:] = build_command
        self.proto.dynamic_config.build_cmd.outfile[:] = [str(invocation.output_path)]
        self.proto.dynamic_config.build_cmd.timeout_seconds = timeout
Example #17
0
def test_benchmark_properties():
    """A benchmark exposes its URI and the proto it was constructed from."""
    uri = "benchmark://example-v0/foobar"
    benchmark = Benchmark(BenchmarkProto(uri=uri))
    assert benchmark.uri == uri
    assert benchmark.proto == BenchmarkProto(uri=uri)
Example #18
0
def test_ne_strings():
    """A benchmark is unequal to a string holding a different URI."""
    benchmark = Benchmark(BenchmarkProto(uri="benchmark://example-v0/foo"))
    other_uri = "benchmark://example-v0/bar"
    assert benchmark != other_uri
Example #19
0
def test_ne_benchmarks():
    """Benchmarks with different URIs are unequal."""
    foo, bar = [
        Benchmark(BenchmarkProto(uri=f"benchmark://example-v0/{name}"))
        for name in ("foo", "bar")
    ]
    assert foo != bar
Example #20
0
def test_eq_strings():
    """A benchmark equals the string form of its own URI."""
    benchmark = Benchmark(BenchmarkProto(uri="benchmark://example-v0/foo"))
    same_uri = "benchmark://example-v0/foo"
    assert benchmark == same_uri
Example #21
0
def test_eq_benchmarks():
    """Two separately built benchmarks with the same URI are equal."""
    first, second = [
        Benchmark(BenchmarkProto(uri="benchmark://example-v0/foo")) for _ in range(2)
    ]
    assert first == second