Esempio n. 1
0
    def register_derived_space(
        self,
        base_name: str,
        derived_name: str,
        derived_space: Space,
        cb: Callable[[observation_t], observation_t],
    ) -> None:
        """Register an observation space that is derived from an existing one.

        Subclasses of ObservationView call this method in their
        :code:`__init__()`, after the base class has been initialized, to
        expose additional observation spaces that are computed from the
        spaces provided by the CompilerService.

        The new space is stored in :code:`self.spaces` under
        :code:`derived_name`. Its :code:`index`, :code:`deterministic`, and
        :code:`platform_dependent` attributes are copied from the spec of the
        base space, and :code:`cb` is recorded in :code:`self._translate_cbs`
        under the same name.

        Example usage:

        Suppose a service provides a "src" observation space that returns a
        string of source code, and we want a new observation space,
        "src_len", that returns the length of that source code. We register
        it by providing a callback that translates a base observation into
        the derived value:

        .. code-block:: python

            class MyObservationView(ObservationView):
                def __init__(self, *args, **kwargs):
                    super().__init__(*args, **kwargs)
                    self.register_derived_space(
                        base_name="src",
                        derived_name="src_len",
                        derived_space=Box(low=0, high=float("inf"), shape=(1,), dtype=int),
                        cb=lambda src: [len(src)],
                    )

        The "src_len" observation space can then be requested like any other:

        >>> env.observation["src_len"]
        [1021,]

        :param base_name: The name of the observation space that this new
            observation space is derived from. It must already be a key of
            :code:`self.spaces`; otherwise the lookup error propagates to
            the caller.
        :param derived_name: The name of the derived observation space.
        :param derived_space: The space describing values of the derived
            observation.
        :param cb: A callback that translates an observation from the base
            space into an observation of the derived space.
        """
        # Look up the parent first so that an unknown base name fails before
        # anything is constructed or registered.
        parent = self.spaces[base_name]

        derived = ObservationSpaceSpec(id=derived_name, space=derived_space)
        # The derived space shares the service-side properties of its parent.
        for attr in ("index", "deterministic", "platform_dependent"):
            setattr(derived, attr, getattr(parent, attr))

        self.spaces[derived_name] = derived
        self._translate_cbs[derived_name] = cb
Esempio n. 2
0
    def __init__(
        self,
        get_observation: Callable[[ObservationRequest], Observation],
        spaces: List[ObservationSpace],
    ):
        """Build the view from a list of observation space descriptions.

        :param get_observation: Callable stored on the instance as
            :code:`_get_observation`; per its annotation it maps an
            ObservationRequest to an Observation.
        :param spaces: The observation spaces to expose. Each one is
            converted with :code:`ObservationSpaceSpec.from_proto()`, using
            its position in this list as the index, and stored in
            :code:`self.spaces` keyed by its :code:`name`.

        :raises ValueError: If :code:`spaces` is empty.
        """
        if not spaces:
            raise ValueError("No observation spaces")

        # Build the name -> spec table locally and publish it in one step.
        specs_by_name = {}
        for position, proto in enumerate(spaces):
            specs_by_name[proto.name] = ObservationSpaceSpec.from_proto(
                position, proto
            )
        self.spaces = specs_by_name

        # NOTE(review): -1 presumably marks "no active session" -- confirm
        # against the code that assigns real session IDs.
        self.session_id = -1

        self._get_observation = get_observation
Esempio n. 3
0
    def __init__(
        self,
        raw_step: Callable[
            [List[ActionType], List[ObservationType], List[RewardType]],
            StepType],
        spaces: List[ObservationSpace],
    ):
        """Build the view from a list of observation space descriptions.

        :param raw_step: Callable stored on the instance as
            :code:`_raw_step`.
        :param spaces: The observation spaces to expose. Each one is
            converted with :code:`ObservationSpaceSpec.from_proto()`, using
            its position in this list as the index, and then registered
            through :code:`self._add_space()`.

        :raises ValueError: If :code:`spaces` is empty.
        """
        if not spaces:
            raise ValueError("No observation spaces")

        # Both attributes are set before any space is registered, so that
        # _add_space() sees a fully prepared instance.
        self.spaces: Dict[str, ObservationSpaceSpec] = {}
        self._raw_step = raw_step

        # Lazy generator: each spec is converted right before it is added,
        # in list order.
        converted_specs = (
            ObservationSpaceSpec.from_proto(position, proto)
            for position, proto in enumerate(spaces)
        )
        for spec in converted_specs:
            self._add_space(spec)
Esempio n. 4
0
def compute_observation(observation_space: ObservationSpaceSpec,
                        bitcode: Path,
                        timeout: float = 300) -> ObservationType:
    """Compute an LLVM observation.

    This is a utility function that uses a standalone C++ binary to compute an
    observation from an LLVM bitcode file. It is intended for use cases where
    you want to compute an observation without the overhead of initializing a
    full environment.

    Example usage:

        >>> env = compiler_gym.make("llvm-v0")
        >>> space = env.observation.spaces["Ir"]
        >>> bitcode = Path("bitcode.bc")
        >>> observation = llvm.compute_observation(space, bitcode, timeout=30)

    .. warning::

        This is not part of the core CompilerGym API and may change in a future
        release.

    :param observation_space: The observation that is to be computed.

    :param bitcode: The path of an LLVM bitcode file.

    :param timeout: The maximum number of seconds to allow the computation to
        run before timing out.

    :return: The computed observation, as produced by
        :code:`observation_space.translate()`.

    :raises ValueError: If computing the observation fails, or if the output
        of the computation cannot be decoded or parsed.

    :raises TimeoutError: If computing the observation times out.

    :raises FileNotFoundError: If the given bitcode does not exist.
    """
    if not Path(bitcode).is_file():
        raise FileNotFoundError(bitcode)

    observation_space_name = pascal_case_to_enum(observation_space.id)

    process = subprocess.Popen(
        [str(_COMPUTE_OBSERVATION_BIN), observation_space_name,
         str(bitcode)],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )

    try:
        stdout, stderr = process.communicate(timeout=timeout)
    except subprocess.TimeoutExpired as e:
        # Popen.kill() exists on every Python version this code can run on,
        # so no version check is needed.
        process.kill()
        # Reap the child and drain its pipes. No timeout here: a second
        # TimeoutExpired raised from this handler would escape to the caller
        # and mask the TimeoutError below.
        process.communicate()
        raise TimeoutError(
            f"Failed to compute {observation_space.id} observation in "
            f"{timeout:.1f} {plural(int(round(timeout)), 'second', 'seconds')}"
        ) from e

    if process.returncode:
        # Only the decode is guarded; the failure itself is raised outside
        # the try so the except clause cannot interfere with it.
        try:
            error_output = stderr.decode("utf-8")
        except UnicodeDecodeError as e:
            raise ValueError(
                f"Failed to compute {observation_space.id} observation") from e
        raise ValueError(
            f"Failed to compute {observation_space.id} observation: {error_output}"
        )

    try:
        stdout = stdout.decode("utf-8")
    except UnicodeDecodeError as e:
        raise ValueError(
            f"Failed to parse {observation_space.id} observation: {e}") from e

    # The binary's stdout is parsed as a text-format Observation protocol
    # buffer.
    observation = Observation()
    try:
        google.protobuf.text_format.Parse(stdout, observation)
    except google.protobuf.text_format.ParseError as e:
        raise ValueError(
            f"Failed to parse {observation_space.id} observation") from e

    return observation_space.translate(observation)