예제 #1
0
def test_convert_space_sequence_space():
    """Check that a SpaceSequenceSpace message converts to a SpaceSequence
    whose element space is an int64 Scalar."""
    length_range = Int64Range(min=0, max=2)
    element_range = Int64Range(min=-1, max=1)
    message = Space(
        space_sequence=SpaceSequenceSpace(
            length_range=length_range,
            space=Space(int64_value=element_range),
        ),
    )

    result = py_converters.message_default_converter(message)

    # The outer space mirrors the length bounds stored in the message.
    assert isinstance(result, SpaceSequence)
    expected_length = message.space_sequence.length_range
    assert result.size_range[0] == expected_length.min
    assert result.size_range[1] == expected_length.max

    # The element space mirrors the int64 value bounds stored in the message.
    element = result.space
    assert isinstance(element, Scalar)
    assert np.dtype(element.dtype) == np.int64
    expected_values = message.space_sequence.space.int64_value
    assert element.min == expected_values.min
    assert element.max == expected_values.max
예제 #2
0
def test_observation_when_raw_step_returns_done():
    """Test that a ServiceError is raised when the raw_step() callback returns
    done=True, and that "error_details" from the info dict, when present,
    appears in the error message."""

    def make_failing_raw_step(error_msg=None):
        def failing_raw_step(*args, **kwargs):
            """A callback that returns done=True."""
            info = {"error_details": error_msg} if error_msg else {}
            return [], None, True, info

        return failing_raw_step

    spaces = [
        ObservationSpace(name="ir", space=Space(int64_value=Int64Range(min=0))),
    ]

    # Each case pairs a raw_step callback with the exact message expected.
    cases = [
        (
            make_failing_raw_step(),
            r"^Failed to compute observation 'ir'$",
        ),
        (
            make_failing_raw_step("Oh no!"),
            r"^Failed to compute observation 'ir': Oh no!$",
        ),
    ]
    for raw_step, expected_message in cases:
        observation = ObservationView(raw_step, spaces)
        with pytest.raises(ServiceError, match=expected_message):
            observation["ir"]  # pylint: disable=pointless-statement
예제 #3
0
def test_observation_when_raw_step_returns_incorrect_no_of_observations():
    """Test that a ServiceError is raised when raw_step() returns a number of
    observations other than the single one that was requested."""

    def make_failing_raw_step(n: int):
        def failing_raw_step(*args, **kwargs):
            """A callback that returns n observations with done=False."""
            del args, kwargs  # Unused
            return ["ir"] * n, None, False, {}

        return failing_raw_step

    spaces = [
        ObservationSpace(name="ir", space=Space(int64_value=Int64Range(min=0))),
    ]

    # Each case pairs an observation count with the exact message expected.
    cases = [
        (0, r"^Expected 1 'ir' observation but the service returned 0$"),
        (3, r"^Expected 1 'ir' observation but the service returned 3$"),
    ]
    for returned_count, expected_message in cases:
        observation = ObservationView(make_failing_raw_step(returned_count), spaces)
        with pytest.raises(ServiceError, match=expected_message):
            observation["ir"]  # pylint: disable=pointless-statement
예제 #4
0
def test_space_message_default_converter():
    """Check that a Space message holding a StringSpace converts to a str
    Sequence with the message's length bounds."""
    string_space = StringSpace(length_range=Int64Range(min=1, max=2))
    type_converter = py_converters.TypeBasedConverter(
        conversion_map={StringSpace: py_converters.convert_sequence_space}
    )
    convert = py_converters.SpaceMessageDefaultConverter(type_converter)

    result = convert(Space(string_value=string_space))

    assert isinstance(result, Sequence)
    assert result.dtype == str
    assert result.size_range[0] == 1
    assert result.size_range[1] == 2
예제 #5
0
def test_convert_permutation_space_message():
    """Check "permutation" Space messages: one with length range 5..5 converts
    to a Permutation, one with length range 3..5 raises ValueError."""

    def permutation_message(min_length, max_length):
        # Both messages share the 0..4 scalar range and differ only in length.
        return Space(
            type_id="permutation",
            int64_sequence=Int64SequenceSpace(
                length_range=Int64Range(min=min_length, max=max_length),
                scalar_range=Int64Range(min=0, max=4),
            ),
        )

    result = py_converters.message_default_converter(permutation_message(5, 5))
    assert isinstance(result, Permutation)
    assert result.scalar_range.min == 0
    assert result.scalar_range.max == 4
    assert result.size_range[0] == 5
    assert result.size_range[1] == 5

    rejected_message = permutation_message(3, 5)
    with pytest.raises(ValueError, match="Invalid permutation space message"):
        py_converters.message_default_converter(rejected_message)
예제 #6
0
def test_list_space_message_converter():
    """Check that a ListSpace holding one StringSpace converts to a Tuple with
    a single str-typed member space."""
    string_space = StringSpace(length_range=Int64Range(min=1, max=2))
    type_converter = py_converters.TypeBasedConverter(
        conversion_map={StringSpace: py_converters.convert_sequence_space}
    )
    convert_list = py_converters.ListSpaceMessageConverter(
        py_converters.SpaceMessageDefaultConverter(type_converter)
    )

    result = convert_list(ListSpace(space=[Space(string_value=string_space)]))

    assert isinstance(result, Tuple)
    assert len(result.spaces) == 1
    member = result.spaces[0]
    assert member.dtype == str
    assert member.size_range[0] == 1
    assert member.size_range[1] == 2
예제 #7
0
def test_dict_space_message_converter():
    """Check that a DictSpace with one StringSpace entry converts to a Dict
    with the same key and a str-typed member space."""
    string_space = StringSpace(length_range=Int64Range(min=1, max=2))
    type_converter = py_converters.TypeBasedConverter(
        conversion_map={StringSpace: py_converters.convert_sequence_space}
    )
    convert_dict = py_converters.DictSpaceMessageConverter(
        py_converters.SpaceMessageDefaultConverter(type_converter)
    )

    result = convert_dict(DictSpace(space={"key": Space(string_value=string_space)}))

    assert isinstance(result, Dict)
    assert len(result.spaces) == 1
    assert "key" in result.spaces
    member = result.spaces["key"]
    assert member.dtype == str
    assert member.size_range[0] == 1
    assert member.size_range[1] == 2
예제 #8
0
def test_observed_value_types():
    """Check the values returned for each observation space, first through
    indexing and then through the generated bound methods."""
    spaces = [
        ObservationSpace(
            name="ir",
            space=Space(string_value=StringSpace(length_range=Int64Range(min=0))),
        ),
        ObservationSpace(
            name="features",
            space=Space(
                int64_box=Int64Box(
                    low=Int64Tensor(shape=[2], value=[-100, -100]),
                    high=Int64Tensor(shape=[2], value=[100, 100]),
                ),
            ),
        ),
        ObservationSpace(
            name="dfeat",
            space=Space(
                double_box=DoubleBox(
                    low=DoubleTensor(shape=[1], value=[0.5]),
                    high=DoubleTensor(shape=[1], value=[2.5]),
                ),
            ),
        ),
        ObservationSpace(
            name="binary",
            space=Space(int64_value=Int64Range(min=5, max=5)),
        ),
    ]
    # One return value per lookup: four for the indexing pass and four more
    # for the bound-method pass.
    mock = MockRawStep(
        ret=[
            "Hello, IR",
            [1.0, 2.0],
            [-5, 15],
            b"Hello, bytes\0",
            "Hello, IR",
            [1.0, 2.0],
            [-5, 15],
            b"Hello, bytes\0",
        ]
    )
    observation = ObservationView(mock, spaces)

    def check_all(fetch):
        """Fetch the four observations in a fixed order and verify each."""
        ir_value = fetch("ir")
        assert isinstance(ir_value, str)
        assert ir_value == "Hello, IR"

        dfeat_value = fetch("dfeat")
        np.testing.assert_array_almost_equal(dfeat_value, [1.0, 2.0])

        features_value = fetch("features")
        np.testing.assert_array_equal(features_value, [-5, 15])

        binary_value = fetch("binary")
        assert binary_value == b"Hello, bytes\0"

        # Check that the correct observation_space_list indices were used.
        assert mock.called_observation_spaces == [
            "ir", "dfeat", "features", "binary"
        ]

    # First pass: index the view by observation space name.
    check_all(lambda name: observation[name])
    mock.called_observation_spaces = []

    # Second pass: repeat the checks using the generated bound methods.
    check_all(lambda name: getattr(observation, name)())
예제 #9
0
class LoopsOptCompilationSession(CompilationSession):
    """Represents an instance of an interactive compilation session.

    On construction the benchmark's C source is written to the working
    directory and lowered to LLVM-IR with clang. Each action then rewrites
    that IR file in place by running a loop unrolling or vectorization flag
    over it.
    """

    compiler_version: str = "1.0.0"

    # The list of actions that are supported by this service.
    action_spaces = [
        ActionSpace(
            name="loop-opt",
            space=Space(
                named_discrete=NamedDiscreteSpace(
                    name=[
                        "--loop-unroll --unroll-count=2",
                        "--loop-unroll --unroll-count=4",
                        "--loop-unroll --unroll-count=8",
                        "--loop-unroll --unroll-count=16",
                        "--loop-unroll --unroll-count=32",
                        "--loop-vectorize -force-vector-width=2",
                        "--loop-vectorize -force-vector-width=4",
                        "--loop-vectorize -force-vector-width=8",
                        "--loop-vectorize -force-vector-width=16",
                        "--loop-vectorize -force-vector-width=32",
                    ]
                ),
            ),
        )
    ]

    # A list of observation spaces supported by this service. Each of these
    # ObservationSpace protos describes an observation space.
    observation_spaces = [
        ObservationSpace(
            name="ir",
            space=Space(string_value=StringSpace(length_range=Int64Range(min=0))),
            deterministic=True,
            platform_dependent=False,
            default_observation=Event(string_value=""),
        ),
        ObservationSpace(
            name="Inst2vec",
            space=Space(
                int64_sequence=Int64SequenceSpace(length_range=Int64Range(min=0)),
            ),
        ),
        ObservationSpace(
            name="Autophase",
            space=Space(
                int64_sequence=Int64SequenceSpace(
                    length_range=Int64Range(
                        min=len(AUTOPHASE_FEATURE_NAMES),
                        max=len(AUTOPHASE_FEATURE_NAMES),
                    )
                ),
            ),
            deterministic=True,
            platform_dependent=False,
        ),
        ObservationSpace(
            name="AutophaseDict",
            space=Space(
                space_dict=DictSpace(
                    space={
                        name: Space(int64_value=Int64Range(min=0))
                        for name in AUTOPHASE_FEATURE_NAMES
                    }
                )
            ),
            deterministic=True,
            platform_dependent=False,
        ),
        ObservationSpace(
            name="Programl",
            space=Space(string_value=StringSpace(length_range=Int64Range(min=0))),
            deterministic=True,
            platform_dependent=False,
            default_observation=Event(string_value=""),
        ),
        ObservationSpace(
            name="runtime",
            space=Space(double_value=DoubleRange(min=0)),
            deterministic=False,
            platform_dependent=True,
            default_observation=Event(double_value=0),
        ),
        ObservationSpace(
            name="size",
            space=Space(double_value=DoubleRange(min=0)),
            deterministic=True,
            platform_dependent=True,
            default_observation=Event(double_value=0),
        ),
    ]

    def __init__(
        self,
        working_directory: Path,
        action_space: ActionSpace,
        benchmark: Benchmark,
        use_custom_opt: bool = True,
    ):
        """Write the benchmark source to disk and lower it to LLVM-IR.

        :param working_directory: Forwarded to the base class constructor.
        :param action_space: The action space whose named-discrete entries
            are interpreted as command-line flags by :meth:`apply_action`.
        :param benchmark: The benchmark to compile; its
            ``program.contents`` bytes are written out as the C source.
        :param use_custom_opt: If true, actions are applied with the
            ``opt_loops`` executable located relative to this file;
            otherwise LLVM's ``opt`` is used.
        """
        super().__init__(working_directory, action_space, benchmark)
        logging.info("Started a compilation session for %s", benchmark.uri)
        self._benchmark = benchmark
        self._action_space = action_space

        self.inst2vec = _INST2VEC_ENCODER

        # Resolve the paths to LLVM binaries once now.
        self._clang = str(llvm.clang_path())
        self._llc = str(llvm.llc_path())
        self._llvm_diff = str(llvm.llvm_diff_path())
        self._opt = str(llvm.opt_path())
        # LLVM's opt does not always enforce the loop optimization options
        # passed as CLI arguments. Hence, we created our own executable with
        # custom unrolling and vectorization passes in
        # examples/loops_opt_service/opt_loops that enforces the unrolling and
        # vectorization factors passed on its command line. If
        # self._use_custom_opt is true, use our custom executable, otherwise
        # use LLVM's opt.
        self._use_custom_opt = use_custom_opt

        # Dump the benchmark source to disk.
        self._src_path = str(self.working_dir / "benchmark.c")
        with open(self._src_path, "wb") as f:
            f.write(benchmark.program.contents)

        self._llvm_path = str(self.working_dir / "benchmark.ll")
        self._llvm_before_path = str(self.working_dir / "benchmark.previous.ll")
        self._obj_path = str(self.working_dir / "benchmark.o")
        self._exe_path = str(self.working_dir / "benchmark.exe")

        # Lower the C source to textual LLVM-IR. -disable-O0-optnone is passed
        # through to clang's frontend; without it functions compiled at -O0
        # would be marked optnone (assumption based on the flag name).
        run_command(
            [
                self._clang,
                "-Xclang",
                "-disable-O0-optnone",
                "-emit-llvm",
                "-S",
                self._src_path,
                "-o",
                self._llvm_path,
            ],
            timeout=30,
        )

    def apply_action(
            self, action: Event) -> Tuple[bool, Optional[ActionSpace], bool]:
        """Apply one loop optimization to the IR file, rewriting it in place.

        :param action: Event whose ``int64_value`` is an index into the names
            of the session's named-discrete action space.
        :return: A tuple ``(end_of_session, new_action_space,
            action_had_no_effect)``. ``end_of_session`` is always ``False``
            and ``new_action_space`` is always ``None``;
            ``action_had_no_effect`` is true when llvm-diff exits
            successfully for the IR before and after the action.
        :raises ValueError: If the index is outside the action space.
        """
        flag_names = self._action_space.space.named_discrete.name

        # This is the index into the action space's list of command-line
        # flag strings that the user selected.
        choice_index = action.int64_value
        if choice_index < 0 or choice_index >= len(flag_names):
            raise ValueError("Out-of-range")

        flags = flag_names[choice_index]
        logging.info(
            "Applying action %d, equivalent command-line arguments: '%s'",
            choice_index,
            flags,
        )
        args = flags.split()

        # Make a copy of the LLVM file to compare its contents after applying
        # the action.
        shutil.copyfile(self._llvm_path, self._llvm_before_path)

        if self._use_custom_opt:
            # Our custom opt_loops has an additional `f` at the beginning of
            # each argument name: insert it after the first two characters,
            # e.g. --<argument> becomes --f<argument>.
            args = [arg[:2] + "f" + arg[2:] for arg in args]
            command = [
                os.path.join(os.path.dirname(__file__), "../opt_loops/opt_loops"),
                self._llvm_path,
                *args,
                "-S",
                "-o",
                self._llvm_path,
            ]
        else:
            command = [
                self._opt,
                *args,
                self._llvm_path,
                "-S",
                "-o",
                self._llvm_path,
            ]
        run_command(command, timeout=30)

        # Compare the IR files to check if the action had an effect. A
        # non-zero exit status from llvm-diff is treated as "files differ".
        try:
            subprocess.check_call(
                [self._llvm_diff, self._llvm_before_path, self._llvm_path],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
                timeout=60,
            )
            action_had_no_effect = True
        except subprocess.CalledProcessError:
            action_had_no_effect = False

        # TODO: this needs investigation: for how long can we apply loop
        # optimizations? e.g., detect if there are no more loops in the IR? or
        # look at the metadata?
        end_of_session = False
        new_action_space = None
        return (end_of_session, new_action_space, action_had_no_effect)

    @property
    def ir(self) -> str:
        """The current contents of the LLVM-IR file."""
        with open(self._llvm_path) as f:
            return f.read()

    def _run_analysis_tool(self, relative_path: str) -> str:
        """Run a tool located relative to this file on the current IR file
        and return whatever run_command() returns for it."""
        return run_command(
            [
                os.path.join(os.path.dirname(__file__), relative_path),
                self._llvm_path,
            ],
            timeout=30,
        )

    def _compute_autophase(self) -> list:
        """Run the Autophase tool on the current IR and parse its output as a
        list of integers."""
        output = self._run_analysis_tool(
            "../../../compiler_gym/third_party/autophase/compute_autophase-prelinked"
        )
        # Split on any run of whitespace so that repeated or trailing
        # separators do not produce empty tokens that int() would reject.
        return [int(token) for token in output.split()]

    def _build_executable(self, opt_flag: str) -> None:
        """Compile the current IR to an object file with llc, then link it
        into the session's executable with clang using ``opt_flag``."""
        # Compile LLVM to object file.
        run_command(
            [
                self._llc,
                "-filetype=obj",
                self._llvm_path,
                "-o",
                self._obj_path,
            ],
            timeout=30,
        )

        # Build object file to binary.
        run_command(
            [
                self._clang,
                self._obj_path,
                opt_flag,
                "-o",
                self._exe_path,
            ],
            timeout=30,
        )

    def _measure_runtime(self) -> float:
        """Run the built executable five times and average the middle three
        of the sorted execution times that it prints.

        :raises ValueError: If the program's output cannot be parsed as an
            integer.
        """
        # TODO: add documentation that benchmarks need print out execution time
        exec_times = []
        for _ in range(5):
            stdout = run_command([self._exe_path], timeout=30)
            try:
                exec_times.append(int(stdout))
            except ValueError as e:
                raise ValueError(
                    "Error in parsing execution time from output of command\n"
                    "Please ensure that the source code of the benchmark measures execution time and prints to stdout\n"
                    f"Stdout of the program: {stdout}") from e
        # Drop the fastest and slowest runs before averaging.
        exec_times = np.sort(exec_times)
        return np.mean(exec_times[1:4])

    def get_observation(self, observation_space: ObservationSpace) -> Event:
        """Compute the observation named by ``observation_space.name``.

        :param observation_space: The space to compute; only its ``name`` is
            read.
        :return: An Event holding the observation value.
        :raises KeyError: If the name is not one of the supported spaces.
        :raises ValueError: For "runtime", if the benchmark's output cannot
            be parsed as an integer.
        """
        logging.info("Computing observation from space %s",
                     observation_space.name)
        name = observation_space.name

        if name == "ir":
            return Event(string_value=self.ir)

        if name == "Inst2vec":
            preprocessed = self.inst2vec.preprocess(self.ir)
            token_ids = self.inst2vec.encode(preprocessed)
            return Event(int64_tensor=Int64Tensor(shape=[len(token_ids)],
                                                  value=token_ids))

        if name == "Autophase":
            features = self._compute_autophase()
            return Event(int64_tensor=Int64Tensor(shape=[len(features)],
                                                  value=features))

        if name == "AutophaseDict":
            features = self._compute_autophase()
            feature_events = {
                feature_name: Event(int64_value=value)
                for feature_name, value in zip(AUTOPHASE_FEATURE_NAMES, features)
            }
            return Event(event_dict=DictEvent(event=feature_events))

        if name == "Programl":
            graph = self._run_analysis_tool(
                "../../../compiler_gym/third_party/programl/compute_programl"
            )
            return Event(string_value=graph)

        if name == "runtime":
            self._build_executable("-O3")
            return Event(double_value=self._measure_runtime())

        if name == "size":
            self._build_executable("-Oz")
            binary_size = os.path.getsize(self._exe_path)
            return Event(double_value=binary_size)

        raise KeyError(name)
예제 #10
0
class ExampleCompilationSession(CompilationSession):
    """An example compilation session.

    Actions are validated but do not change any state, and every observation
    is a fixed placeholder value.
    """

    compiler_version: str = "1.0.0"

    # The action spaces supported by this service. Here we will implement a
    # single action space, called "default", that represents a command line with
    # three options: "a", "b", and "c".
    action_spaces = [
        ActionSpace(
            name="default",
            space=Space(
                named_discrete=NamedDiscreteSpace(name=["a", "b", "c"]),
            ),
        )
    ]

    # A list of observation spaces supported by this service. Each of these
    # ObservationSpace protos describes an observation space.
    observation_spaces = [
        ObservationSpace(
            name="ir",
            space=Space(string_value=StringSpace(length_range=Int64Range(min=0))),
            deterministic=True,
            platform_dependent=False,
            default_observation=Event(string_value=""),
        ),
        ObservationSpace(
            name="features",
            space=Space(
                int64_box=Int64Box(
                    low=Int64Tensor(shape=[3], value=[-100, -100, -100]),
                    high=Int64Tensor(shape=[3], value=[100, 100, 100]),
                ),
            ),
        ),
        ObservationSpace(
            name="runtime",
            space=Space(double_value=DoubleRange(min=0)),
            deterministic=False,
            platform_dependent=True,
            default_observation=Event(double_value=0),
        ),
    ]

    def __init__(self, working_directory: Path, action_space: ActionSpace,
                 benchmark: Benchmark):
        """Forward all arguments to the base class and log the benchmark URI."""
        super().__init__(working_directory, action_space, benchmark)
        logging.info("Started a compilation session for %s", benchmark.uri)

    def apply_action(
            self, action: Event) -> Tuple[bool, Optional[ActionSpace], bool]:
        """Validate the selected action index; no state is updated.

        :param action: Event whose ``int64_value`` indexes into the names of
            the first action space ("a", "b", "c").
        :return: Always ``(False, None, False)``.
        :raises ValueError: If the index is outside the action space.
        """
        num_choices = len(self.action_spaces[0].space.named_discrete.name)

        # The user's selection, e.g. 0 -> "a", 1 -> "b", 2 -> "c".
        choice_index = action.int64_value
        logging.info("Applying action %d", choice_index)

        if not 0 <= choice_index < num_choices:
            raise ValueError("Out-of-range")

        # Here is where we would run the actual action to update the
        # environment's state.

        end_of_session = False
        new_action_space = None
        action_had_no_effect = False
        return end_of_session, new_action_space, action_had_no_effect

    def get_observation(self, observation_space: ObservationSpace) -> Event:
        """Return the fixed placeholder observation for the named space.

        :param observation_space: The space to compute.
        :return: An Event holding the placeholder value for that space.
        :raises KeyError: If the space's name is not supported.
        """
        logging.info("Computing observation from space %s", observation_space)
        name = observation_space.name

        if name == "ir":
            return Event(string_value="Hello, world!")
        if name == "features":
            return Event(int64_tensor=Int64Tensor(shape=[3], value=[0, 0, 0]))
        if name == "runtime":
            return Event(double_value=0)
        raise KeyError(name)
예제 #11
0
def make_gcc_compilation_session(gcc_bin: str):
    """Create a class to represent a GCC compilation service.

    :param gcc_bin: Path to the gcc executable. This can be a command name, like
        "gcc", or it can be path to the executable. Finally, if prefixed with
        "docker:" it can be the name of a docker image, e.g. "docker:gcc:11.2.0"
    """
    gcc = Gcc(gcc_bin)

    # The available actions
    actions = []

    # Actions that are small will have all their various choices made as
    # explicit actions.
    # Actions that are not small will have the ability to increment the choice
    # by different amounts.
    for i, option in enumerate(gcc.spec.options):
        if len(option) < 10:
            for j in range(len(option)):
                actions.append(SimpleAction(option, i, j))
        if len(option) >= 10:
            actions.append(IncrAction(option, i, 1))
            actions.append(IncrAction(option, i, -1))
        if len(option) >= 50:
            actions.append(IncrAction(option, i, 10))
            actions.append(IncrAction(option, i, -10))
        if len(option) >= 500:
            actions.append(IncrAction(option, i, 100))
            actions.append(IncrAction(option, i, -100))
        if len(option) >= 5000:
            actions.append(IncrAction(option, i, 1000))
            actions.append(IncrAction(option, i, -1000))

    action_spaces_ = [
        ActionSpace(
            name="default",
            space=Space(
                named_discrete=NamedDiscreteSpace(name=[str(a) for a in actions]),
            ),
        ),
    ]

    observation_spaces_ = [
        # A string of the source code
        ObservationSpace(
            name="source",
            space=Space(string_value=StringSpace(length_range=Int64Range(min=0))),
            deterministic=True,
            platform_dependent=False,
            default_observation=Event(string_value=""),
        ),
        # A string of the rtl code
        ObservationSpace(
            name="rtl",
            space=Space(string_value=StringSpace(length_range=Int64Range(min=0))),
            deterministic=True,
            platform_dependent=True,
            default_observation=Event(string_value=""),
        ),
        # A string of the assembled code
        ObservationSpace(
            name="asm",
            space=Space(string_value=StringSpace(length_range=Int64Range(min=0))),
            deterministic=True,
            platform_dependent=True,
            default_observation=Event(string_value=""),
        ),
        # The size of the assembled code
        ObservationSpace(
            name="asm_size",
            space=Space(int64_value=Int64Range(min=-1)),
            deterministic=True,
            platform_dependent=True,
            default_observation=Event(
                int64_value=-1,
            ),
        ),
        # The hash of the assembled code
        ObservationSpace(
            name="asm_hash",
            space=Space(
                string_value=StringSpace(length_range=Int64Range(min=0, max=200)),
            ),
            deterministic=True,
            platform_dependent=True,
            default_observation=Event(string_value=""),
        ),
        # Asm instruction counts - Counter as a JSON string
        ObservationSpace(
            name="instruction_counts",
            space=Space(
                string_value=StringSpace(length_range=Int64Range(min=0)),
            ),
            deterministic=True,
            platform_dependent=True,
            default_observation=Event(string_value=""),
        ),
        # A bytes of the object code
        ObservationSpace(
            name="obj",
            space=Space(
                byte_sequence=ByteSequenceSpace(length_range=Int64Range(min=0)),
            ),
            deterministic=True,
            platform_dependent=False,
            default_observation=Event(byte_tensor=ByteTensor(shape=[0], value=b"")),
        ),
        # The size of the object code
        ObservationSpace(
            name="obj_size",
            space=Space(int64_value=Int64Range(min=-1)),
            deterministic=True,
            platform_dependent=True,
            default_observation=Event(
                int64_value=-1,
            ),
        ),
        # The hash of the object code
        ObservationSpace(
            name="obj_hash",
            space=Space(
                string_value=StringSpace(length_range=Int64Range(min=0, max=200)),
            ),
            deterministic=True,
            platform_dependent=True,
            default_observation=Event(string_value=""),
        ),
        # A list of the choices. Each element corresponds to an option in the spec.
        # '-1' indicates that this is empty on the command line (e.g. if the choice
        # corresponding to the '-O' option is -1, then no -O flag will be emitted.)
        # If a nonnegative number is given then that particular choice is used
        # (e.g. for the -O flag, 5 means use '-Ofast' on the command line.)
        ObservationSpace(
            name="choices",
            space=Space(
                space_list=ListSpace(
                    space=[
                        Space(int64_value=Int64Range(min=0, max=len(option) - 1))
                        for option in gcc.spec.options
                    ]
                ),
            ),
        ),
        # The command line for compiling the object file as a string
        ObservationSpace(
            name="command_line",
            space=Space(
                string_value=StringSpace(length_range=Int64Range(min=0, max=200)),
            ),
            deterministic=True,
            platform_dependent=True,
            default_observation=Event(string_value=""),
        ),
    ]

    class GccCompilationSession(CompilationSession):
        """A GCC interactive compilation session.

        The session stores one integer choice per entry of ``gcc.spec.options``
        (``-1`` means "option not set"). Build artifacts (object file, assembly
        and RTL dump) are produced lazily on first access and cached until the
        choices change, at which point :meth:`reset_cached` discards them.
        """

        compiler_version: str = gcc.spec.version
        action_spaces = action_spaces_
        observation_spaces = observation_spaces_

        def __init__(
            self,
            working_directory: Path,
            action_space: ActionSpace,
            benchmark: Benchmark,
        ):
            """Create a session for ``benchmark``; nothing is compiled yet."""
            super().__init__(working_directory, action_space, benchmark)
            # The benchmark being used
            self.benchmark = benchmark
            # Timeout value for compilation (in seconds); None means no timeout
            self._timeout = None
            # The source code
            self._source = None
            # The rtl code
            self._rtl = None
            # The assembled code
            self._asm = None
            # Size of the assembled code
            self._asm_size = None
            # Hash of the assembled code
            self._asm_hash = None
            # The object binary
            self._obj = None
            # Size of the object binary
            self._obj_size = None
            # Hash of the object binary
            self._obj_hash = None
            # Set the path to the GCC executable
            self._gcc_bin = "gcc"
            # Initially the choices are empty. They are initialised lazily by
            # the `choices` property.
            self._choices = None

        @property
        def num_actions(self) -> int:
            """Number of named actions in the first action space."""
            return len(self.action_spaces[0].space.named_discrete.name)

        @property
        def choices(self) -> List[int]:
            """Current choice per GCC option; created as all ``-1`` on first use."""
            if self._choices is None:
                self._choices = [-1] * len(gcc.spec.options)
            return self._choices

        @choices.setter
        def choices(self, value: List[int]):
            self._choices = value

        @property
        def source(self) -> str:
            """Get the benchmark source"""
            self.prepare_files()
            return self._source

        @property
        def rtl(self) -> str:
            """Get the RTL code (decoded text of the RTL dump)"""
            self.dump_rtl()
            return self._rtl

        @property
        def asm(self) -> str:
            """Get the assembled code (decoded text of the assembly file)"""
            self.assemble()
            return self._asm

        @property
        def asm_size(self) -> int:
            """Get the assembled code size"""
            self.assemble()
            return self._asm_size

        @property
        def asm_hash(self) -> str:
            """Get the assembled code hash"""
            self.assemble()
            return self._asm_hash

        @property
        def instruction_counts(self) -> str:
            """Get the instruction counts as a JSON string.

            A line counts as an instruction when it consists of a tab followed
            by a mnemonic made of letters, digits, ``.`` or ``-``.
            """
            self.assemble()
            insn_pat = re.compile("\t([a-zA-Z-0-9.-]+).*")
            insn_cnts = Counter()
            for line in self._asm.split("\n"):
                m = insn_pat.fullmatch(line)
                if m:
                    insn_cnts[m.group(1)] += 1

            return json.dumps(insn_cnts)

        @property
        def obj(self) -> bytes:
            """Get the compiled code"""
            self.compile()
            return self._obj

        @property
        def obj_size(self) -> int:
            """Get the compiled code size"""
            self.compile()
            return self._obj_size

        @property
        def obj_hash(self) -> str:
            """Get the compiled code hash"""
            self.compile()
            return self._obj_hash

        @property
        def src_path(self) -> Path:
            """Get the path to the source file"""
            return self.working_dir / "src.c"

        @property
        def obj_path(self) -> Path:
            """Get the path to object file"""
            return self.working_dir / "obj.o"

        @property
        def asm_path(self) -> Path:
            """Get the path to the assembly"""
            return self.working_dir / "asm.s"

        @property
        def rtl_path(self) -> Path:
            """Get the path to the rtl"""
            return self.working_dir / "rtl.lsp"

        def _selected_options(self) -> List[str]:
            """Return the option strings for every choice that is set (>= 0)."""
            return [
                option[choice]
                for option, choice in zip(gcc.spec.options, self.choices)
                if choice >= 0
            ]

        def _load_asm(self) -> None:
            """Read the assembly file and refresh the cached text, size and hash."""
            with open(self.asm_path, "rb") as f:
                asm_bytes = f.read()
            self._asm = asm_bytes.decode()
            self._asm_size = os.path.getsize(self.asm_path)
            self._asm_hash = hashlib.md5(asm_bytes).hexdigest()

        def obj_command_line(
            self, src_path: Optional[Path] = None, obj_path: Optional[Path] = None
        ) -> List[str]:
            """Get the command line to create the object file.

            The 'src_path' and 'obj_path' give the input and output paths. If not
            set, then they are taken from 'self.src_path' and 'self.obj_path'.
            Passing explicit paths is useful for printing where the actual paths
            are not important. The path elements are returned as given (they are
            not converted to strings).
            """
            src_path = src_path or self.src_path
            obj_path = obj_path or self.obj_path
            return self._selected_options() + ["-w", "-c", src_path, "-o", obj_path]

        def asm_command_line(
            self, src_path: Optional[Path] = None, asm_path: Optional[Path] = None
        ) -> List[str]:
            """Get the command line to create the assembly file.

            The 'src_path' and 'asm_path' give the input and output paths. If not
            set, then they are taken from 'self.src_path' and 'self.asm_path'.
            Passing explicit paths is useful for printing where the actual paths
            are not important.
            """
            src_path = src_path or self.src_path
            asm_path = asm_path or self.asm_path
            return self._selected_options() + ["-w", "-S", src_path, "-o", asm_path]

        def rtl_command_line(
            self,
            src_path: Optional[Path] = None,
            rtl_path: Optional[Path] = None,
            asm_path: Optional[Path] = None,
        ) -> List[str]:
            """Get the command line to create the rtl file - might as well do the
            asm at the same time.

            The 'src_path', 'rtl_path', 'asm_path' give the input and output
            paths. If not set, then they are taken from 'self.src_path',
            'self.rtl_path' and 'self.asm_path'. Passing explicit paths is useful
            for printing where the actual paths are not important.
            """
            src_path = src_path or self.src_path
            rtl_path = rtl_path or self.rtl_path
            asm_path = asm_path or self.asm_path
            return self._selected_options() + [
                "-w",
                "-S",
                src_path,
                f"-fdump-rtl-dfinish={rtl_path}",
                "-o",
                asm_path,
            ]

        def prepare_files(self) -> None:
            """Load the source (once) and write it to the working directory.

            The source comes from the benchmark's inline contents when present,
            otherwise it is downloaded from the benchmark's URI.
            """
            if self._source:
                return

            if self.benchmark.program.contents:
                self._source = self.benchmark.program.contents.decode()
            else:
                with urlopen(self.benchmark.program.uri) as r:
                    self._source = r.read().decode()

            with open(self.src_path, "w") as f:
                print(self._source, file=f)

        def compile(self) -> None:
            """Compile the benchmark to an object file, unless already cached."""
            if self._obj:
                return

            self.prepare_files()
            cmd_line = self.obj_command_line()
            logger.debug("Compiling: %s", " ".join(map(str, cmd_line)))
            gcc(
                *cmd_line,
                cwd=self.working_dir,
                timeout=self._timeout,
            )
            with open(self.obj_path, "rb") as f:
                self._obj = f.read()
            self._obj_size = os.path.getsize(self.obj_path)
            self._obj_hash = hashlib.md5(self._obj).hexdigest()

        def assemble(self) -> None:
            """Assemble the benchmark, unless the assembly is already cached."""
            if self._asm:
                return

            self.prepare_files()
            cmd_line = self.asm_command_line()
            logger.debug("Assembling: %s", " ".join(map(str, cmd_line)))
            gcc(
                *cmd_line,
                cwd=self.working_dir,
                timeout=self._timeout,
            )
            self._load_asm()

        def dump_rtl(self) -> None:
            """Dump the RTL (and assemble the benchmark), unless already cached."""
            if self._rtl:
                return

            self.prepare_files()
            cmd_line = self.rtl_command_line()
            logger.debug("Dumping RTL: %s", " ".join(map(str, cmd_line)))
            gcc(
                *cmd_line,
                cwd=self.working_dir,
                timeout=self._timeout,
            )
            # The same GCC invocation produces the assembly, so cache it too.
            self._load_asm()
            with open(self.rtl_path, "rb") as f:
                self._rtl = f.read().decode()

        def reset_cached(self) -> None:
            """Reset the cached values"""
            self._obj = None
            self._obj_size = None
            self._obj_hash = None
            self._rtl = None
            self._asm = None
            self._asm_size = None
            self._asm_hash = None

        def apply_action(
            self, action_proto: Event
        ) -> Tuple[bool, Optional[ActionSpace], bool]:
            """Apply an action.

            Raises:
                ValueError: If the event carries no ``int64_value`` or the index
                    is outside ``[0, num_actions)``.

            Returns:
                Always ``(False, None, False)``; nothing is compiled until an
                observation is requested.
            """
            if not action_proto.HasField("int64_value"):
                raise ValueError("Invalid action, int64_value expected.")

            choice_index = action_proto.int64_value
            if choice_index < 0 or choice_index >= self.num_actions:
                raise ValueError("Out-of-range")

            # Get the action
            action = actions[choice_index]
            # Apply the action to this session and check if we changed anything
            old_choices = self.choices.copy()
            action(self)
            logger.debug("Applied action %s", action)

            # Reset the internal variables if this action has caused a change in the
            # choices
            if old_choices != self.choices:
                self.reset_cached()

            # The action has not changed anything yet. That waits until an
            # observation is taken
            return False, None, False

        def get_observation(self, observation_space: ObservationSpace) -> Event:
            """Get one of the observations.

            Raises:
                KeyError: If the observation space name is not recognised.
            """
            name = observation_space.name
            if name == "source":
                return Event(string_value=self.source or "")
            if name == "rtl":
                return Event(string_value=self.rtl or "")
            if name == "asm":
                return Event(string_value=self.asm or "")
            if name == "asm_size":
                return Event(int64_value=self.asm_size or -1)
            if name == "asm_hash":
                return Event(string_value=self.asm_hash or "")
            if name == "instruction_counts":
                return Event(string_value=self.instruction_counts or "{}")
            if name == "obj":
                value = self.obj or b""
                return Event(byte_tensor=ByteTensor(shape=[len(value)], value=value))
            if name == "obj_size":
                return Event(int64_value=self.obj_size or -1)
            if name == "obj_hash":
                return Event(string_value=self.obj_hash or "")
            if name == "choices":
                return Event(
                    int64_tensor=Int64Tensor(
                        shape=[len(self.choices)], value=self.choices
                    )
                )
            if name == "command_line":
                # Use placeholder paths: the real working directory is not
                # meaningful to the reader of the command line.
                return Event(
                    string_value=gcc.bin
                    + " "
                    + " ".join(map(str, self.obj_command_line("src.c", "obj.o")))
                )
            raise KeyError(name)

        def handle_session_parameter(self, key: str, value: str) -> Optional[str]:
            """Handle a session parameter.

            Supported keys:
              * ``gcc_spec``: returns the pickled, base64-encoded GCC spec.
              * ``choices``: comma-separated integers, one per GCC option.
              * ``timeout``: compilation timeout in seconds; empty string
                clears it.

            Raises:
                ValueError: If ``choices`` has the wrong length or contains an
                    index that is not ``-1`` or a valid index into its option.

            Returns:
                The reply string, or ``None`` if the key is not handled.
            """
            if key == "gcc_spec":
                return codecs.encode(pickle.dumps(gcc.spec), "base64").decode()

            if key == "choices":
                choices = list(map(int, value.split(",")))
                if len(choices) != len(gcc.spec.options):
                    raise ValueError(
                        f"Expected {len(gcc.spec.options)} choices, "
                        f"received {len(choices)}"
                    )
                for i, (option, choice) in enumerate(zip(gcc.spec.options, choices)):
                    # A choice indexes into its option, so the upper bound is
                    # exclusive; -1 means the option is left unset.
                    if not -1 <= choice < len(option):
                        raise ValueError(
                            f"Choice {choice} for option {i} is out of range "
                            f"[-1, {len(option)})"
                        )
                if choices != self.choices:
                    self.choices = choices
                    self.reset_cached()
                return ""

            if key == "timeout":
                self._timeout = None if value == "" else int(value)
                return ""

            return None

    return GccCompilationSession
예제 #12
0
class UnrollingCompilationSession(CompilationSession):
    """Represents an instance of an interactive compilation session.

    The benchmark source is lowered to LLVM-IR once at construction time. Each
    action then runs a loop-unrolling pass over that IR file in place.
    """

    compiler_version: str = "1.0.0"

    # The list of actions that are supported by this service.
    action_spaces = [
        ActionSpace(
            name="unrolling",
            space=Space(
                named_discrete=NamedDiscreteSpace(
                    name=[
                        "-loop-unroll -unroll-count=2",
                        "-loop-unroll -unroll-count=4",
                        "-loop-unroll -unroll-count=8",
                    ],
                ),
            ),
        )
    ]

    # A list of observation spaces supported by this service. Each of these
    # ObservationSpace protos describes an observation space.
    observation_spaces = [
        ObservationSpace(
            name="ir",
            space=Space(
                string_value=StringSpace(length_range=Int64Range(min=0)),
            ),
            deterministic=True,
            platform_dependent=False,
            default_observation=Event(string_value=""),
        ),
        ObservationSpace(
            name="features",
            space=Space(
                int64_box=Int64Box(
                    low=Int64Tensor(shape=[3], value=[0, 0, 0]),
                    high=Int64Tensor(shape=[3], value=[100000, 100000, 100000]),
                ),
            ),
        ),
        ObservationSpace(
            name="runtime",
            space=Space(
                double_value=DoubleRange(min=0),
            ),
            deterministic=False,
            platform_dependent=True,
            default_observation=Event(
                double_value=0,
            ),
        ),
        ObservationSpace(
            name="size",
            space=Space(
                double_value=DoubleRange(min=0),
            ),
            deterministic=True,
            platform_dependent=True,
            default_observation=Event(
                double_value=0,
            ),
        ),
    ]

    def __init__(
        self,
        working_directory: Path,
        action_space: ActionSpace,
        benchmark: Benchmark,
        use_custom_opt: bool = True,
    ):
        """Write the benchmark source to disk and compile it to LLVM-IR.

        Args:
            working_directory: Passed through to the base class.
            action_space: The action space used to decode actions.
            benchmark: The benchmark; ``benchmark.program.contents`` must hold
                the C source bytes.
            use_custom_opt: If true, actions run the custom ``loop_unroller``
                executable; otherwise they run LLVM's ``opt``.
        """
        super().__init__(working_directory, action_space, benchmark)
        logging.info("Started a compilation session for %s", benchmark.uri)
        self._benchmark = benchmark
        self._action_space = action_space

        # Resolve the paths to LLVM binaries once now.
        self._clang = str(llvm.clang_path())
        self._llc = str(llvm.llc_path())
        self._llvm_diff = str(llvm.llvm_diff_path())
        self._opt = str(llvm.opt_path())
        # LLVM's opt does not always enforce the unrolling options passed as
        # cli arguments. Hence, we created our own executable with a custom
        # unrolling pass in examples/example_unrolling_service/loop_unroller
        # that enforces the unrolling factors passed in its cli.
        # If self._use_custom_opt is true, use our custom executable, otherwise
        # use LLVM's opt.
        self._use_custom_opt = use_custom_opt

        # Dump the benchmark source to disk.
        self._src_path = str(self.working_dir / "benchmark.c")
        with open(self._src_path, "wb") as f:
            f.write(benchmark.program.contents)

        self._llvm_path = str(self.working_dir / "benchmark.ll")
        self._llvm_before_path = str(self.working_dir / "benchmark.previous.ll")
        self._obj_path = str(self.working_dir / "benchmark.o")
        self._exe_path = str(self.working_dir / "benchmark.exe")

        run_command(
            [
                self._clang,
                "-Xclang",
                "-disable-O0-optnone",
                "-emit-llvm",
                "-S",
                self._src_path,
                "-o",
                self._llvm_path,
            ],
            timeout=30,
        )

    def apply_action(self, action: Event) -> Tuple[bool, Optional[ActionSpace], bool]:
        """Run the selected unrolling pass over the IR file in place.

        Raises:
            ValueError: If the event carries no ``int64_value`` or the index is
                outside the action space.

        Returns:
            ``(end_of_session, new_action_space, action_had_no_effect)``. The
            first two are always ``False`` and ``None``; the third is true when
            ``llvm-diff`` reports no difference between the IR before and after.
        """
        # Reject malformed events explicitly: an unset int64_value would
        # otherwise read as 0 and silently apply the first action.
        if not action.HasField("int64_value"):
            raise ValueError("Invalid action. int64_value expected.")

        action_names = self._action_space.space.named_discrete.name

        # This is the index into the action space's values ("a", "b", "c") that
        # the user selected, e.g. 0 -> "a", 1 -> "b", 2 -> "c".
        choice_index = action.int64_value
        if choice_index < 0 or choice_index >= len(action_names):
            raise ValueError("Out-of-range")

        args = action_names[choice_index]
        logging.info(
            "Applying action %d, equivalent command-line arguments: '%s'",
            choice_index,
            args,
        )
        args = args.split()

        # make a copy of the LLVM file to compare its contents after applying the action
        shutil.copyfile(self._llvm_path, self._llvm_before_path)

        # apply action
        if self._use_custom_opt:
            # our custom unroller has an additional `f` at the beginning of
            # each argument: convert -<argument> to -f<argument>
            args = [arg[0] + "f" + arg[1:] for arg in args]
            run_command(
                [
                    "../loop_unroller/loop_unroller",
                    self._llvm_path,
                    *args,
                    "-S",
                    "-o",
                    self._llvm_path,
                ],
                timeout=30,
            )
        else:
            run_command(
                [
                    self._opt,
                    *args,
                    self._llvm_path,
                    "-S",
                    "-o",
                    self._llvm_path,
                ],
                timeout=30,
            )

        # compare the IR files to check if the action had an effect: a
        # non-zero exit status from llvm-diff is treated as "files differ"
        try:
            subprocess.check_call(
                [self._llvm_diff, self._llvm_before_path, self._llvm_path],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
                timeout=60,
            )
        except subprocess.CalledProcessError:
            action_had_no_effect = False
        else:
            action_had_no_effect = True

        end_of_session = False  # TODO: this needs investigation: for how long can we apply loop unrolling? e.g., detect if there are no more loops in the IR?
        new_action_space = None
        return (end_of_session, new_action_space, action_had_no_effect)

    @property
    def ir(self) -> str:
        """The current contents of the LLVM-IR file."""
        with open(self._llvm_path) as f:
            return f.read()

    def _build_executable(self, optimization_flag: str) -> None:
        """Compile the current IR to an object file and link it into the executable.

        Args:
            optimization_flag: Flag passed to clang for the link step, e.g.
                ``"-O3"`` or ``"-Oz"``.
        """
        # compile LLVM to object file
        run_command(
            [
                self._llc,
                "-filetype=obj",
                self._llvm_path,
                "-o",
                self._obj_path,
            ],
            timeout=30,
        )

        # build object file to binary
        run_command(
            [
                self._clang,
                self._obj_path,
                optimization_flag,
                "-o",
                self._exe_path,
            ],
            timeout=30,
        )

    def get_observation(self, observation_space: ObservationSpace) -> Event:
        """Compute the observation for the named space.

        Raises:
            KeyError: If the observation space name is not recognised.
            ValueError: For "runtime", if the benchmark's stdout cannot be
                parsed as an integer execution time.
        """
        logging.info("Computing observation from space %s", observation_space.name)
        if observation_space.name == "ir":
            return Event(string_value=self.ir)

        if observation_space.name == "features":
            stats = utils.extract_statistics_from_ir(self.ir)
            values = list(stats.values())
            return Event(int64_tensor=Int64Tensor(shape=[len(values)], value=values))

        if observation_space.name == "runtime":
            self._build_executable("-O3")

            # TODO: add documentation that benchmarks need print out execution time
            # Running 5 times and taking the average of middle 3
            exec_times = []
            for _ in range(5):
                stdout = run_command(
                    [self._exe_path],
                    timeout=30,
                )
                try:
                    exec_times.append(int(stdout))
                except ValueError as err:
                    raise ValueError(
                        "Error in parsing execution time from output of command\n"
                        "Please ensure that the source code of the benchmark measures execution time and prints to stdout\n"
                        f"Stdout of the program: {stdout}"
                    ) from err
            # Sort so that the slice drops the fastest and slowest run.
            exec_times = np.sort(exec_times)
            avg_exec_time = np.mean(exec_times[1:4])
            return Event(double_value=avg_exec_time)

        if observation_space.name == "size":
            self._build_executable("-Oz")
            binary_size = os.path.getsize(self._exe_path)
            return Event(double_value=binary_size)

        raise KeyError(observation_space.name)
예제 #13
0
class LoopToolCompilationSession(CompilationSession):
    """Represents an instance of an interactive loop_tool session.

    The session models a single element-wise add over one variable, split into
    a three-level loop nest. ``self.order`` holds one ``(size, tail)`` pair per
    loop, outermost first. Actions move a cursor between the loops, resize the
    loop under the cursor, or toggle threading for it.
    """

    compiler_version: str = pkg_resources.get_distribution(
        "loop-tool-py").version

    # keep it simple for now: 1 variable, 1 nest
    action_spaces = [
        ActionSpace(
            name="simple",
            space=Space(
                # shift around a single pre-split order, changing the size of splits
                named_discrete=NamedDiscreteSpace(
                    name=["toggle_mode", "up", "down", "toggle_thread"], ), ),
        ),
        ActionSpace(
            name="split",
            space=Space(
                # potentially define new splits
                named_discrete=NamedDiscreteSpace(name=[
                    "toggle_mode", "up", "down", "toggle_thread", "split"
                ], ), ),
        ),
    ]

    observation_spaces = [
        ObservationSpace(
            name="flops",
            space=Space(double_value=DoubleRange()),
            deterministic=False,
            platform_dependent=True,
            default_observation=Event(double_value=0, ),
        ),
        ObservationSpace(
            name="loop_tree",
            space=Space(
                string_value=StringSpace(length_range=Int64Range(min=0)), ),
            deterministic=True,
            platform_dependent=False,
            default_observation=Event(string_value="", ),
        ),
        # NOTE(review): this space is declared with shape [1], but
        # get_observation() returns a shape-[3] tensor (cursor, size, tail)
        # for "action_state". Confirm which one clients rely on.
        ObservationSpace(
            name="action_state",
            space=Space(int64_box=Int64Box(
                low=Int64Tensor(shape=[1], value=[0]),
                high=Int64Tensor(shape=[1], value=[2**36]),
            ), ),
            deterministic=True,
            platform_dependent=False,
            default_observation=Event(int64_tensor=Int64Tensor(shape=[1],
                                                               value=[0]), ),
        ),
    ]

    def __init__(self, working_directory: Path, action_space: ActionSpace,
                 benchmark: Benchmark):
        """Build the IR for ``benchmark`` and initialise the loop order.

        The backend is "cuda" when the benchmark URI contains "cuda", else
        "cpu". The problem size is the integer in the last "/"-separated
        component of the URI.

        Raises:
            EnvironmentNotSupported: If the chosen backend is not reported by
                ``lt.backends()``.
        """
        super().__init__(working_directory, action_space, benchmark)
        self.action_space = action_space
        if "cuda" in benchmark.uri:
            self.backend = "cuda"
            lt.set_default_hardware("cuda")
        else:
            self.backend = "cpu"
        if self.backend not in lt.backends():
            raise EnvironmentNotSupported(
                f"Failed to load {self.backend} dataset for loop_tool.  Have you installed all required dependecies?  See <https://facebookresearch.github.io/CompilerGym/envs/loop_tool.html#installation> for details. "
            )
        # IR for: write(add(read(a), read(a))) over a single variable "a".
        self.ir = lt.IR()
        self.var = self.ir.create_var("a")
        r0 = self.ir.create_node("read", [], [self.var])
        r1 = self.ir.create_node("read", [], [self.var])
        add = self.ir.create_node("add", [r0, r1], [self.var])
        w = self.ir.create_node("write", [add], [self.var])
        self.ir.set_inputs([r0, r1])
        self.ir.set_outputs([w])
        self.size = int(benchmark.uri.split("/")[-1])
        self.Ap = np.random.randn(self.size)
        self.Bp = np.random.randn(self.size)
        # (size, tail) per loop, outermost first. Initially the outermost loop
        # covers the whole problem.
        self.order = [(self.size, 0), (1, 0), (1, 0)]
        self.thread = [1, 0, 0]
        self.cursor = 0
        self.mode = "size"
        logger.info("Started a compilation session for %s", benchmark.uri)

    def resize(self, increment):
        """Resize the loop under the cursor, rebalancing its parent loop.

        The idea is pull from or add to the parent loop.

        Three mutations possible to any size:
        A) x, y -> x + 1, 0
          remove the tail, increase loop size, shrink parent
        B) x, y -> x, 0
          only remove the tail, add to parent
        C) x, 0 -> x - 1, 0
          if no tail, shrink the loop size, increase parent

        note: this means tails can never exist on innermost loops. this makes good sense :)

        A)

        [(a, b), (x, y), ...k] -> [(a', b'), (x + 1, 0), ...k]
        a * (x * k + y) + b = a' * (x + 1) * k + b'
        a' = (a * (x * k + y) + b) // ((x + 1) * k)
        b' = "                   " %  "           "

        B)

        [(a, b), (x, y), ...k] -> [(a', b'), (x, 0), ...k]
        a * (x * k + y) + b = a' * (x) * k + b'
        a' = (a * (x * k + y) + b) // ((x) * k)
        b' = "                   " %  "           "

        C)

        [(a, b), (x, y), ...k] -> [(a', b'), (x - 1, 0), ...k]
        a * (x * k + y) + b = a' * (x - 1) * k + b'
        a' = (a * (x * k + y) + b) // ((x - 1) * k)
        b' = "                   " %  "           "

        example interaction model:
        1. cursor = 1        [1024, 1, 1]
        2. up                [512, 2, 1]
        3. up                [(341,1), 3, 1]
        4. up                [256, 4, 1]
        5. cursor = 2, up    [256, 2, 2]
        6. up                [256, (1, 1), 3]
        7. cursor = 1, down  [(341, 1), 1, 3]
        8. cursor = 2, down  [(341, 1), (1, 1), 2]
        9. cursor = 1, down  [512, 1, 2]

        The call is a no-op when the cursor is on the outermost loop, or when
        the new loop size or new parent size would fall outside
        ``[1, self.size]``.
        """
        if self.cursor == 0:
            # The outermost loop has no parent to trade iterations with.
            return
        a, b = self.order[self.cursor - 1]
        x, y = self.order[self.cursor]

        def lam(v, x):
            return v * x[0] + x[1]

        # Number of elements covered by one iteration of the cursor loop.
        k = reduce(lam, self.order[self.cursor + 1:][::-1], 1)
        if increment == -1 and y:
            # Mutation B: shrinking a loop that has a tail only drops the tail.
            increment = 0
        if (x + increment) < 1:
            return
        if (x + increment) > self.size:
            return
        # Elements covered by the parent and cursor loops together. The tail
        # `y` must be included, otherwise a * y elements are lost and the
        # size invariant checked below no longer holds.
        n = a * (x * k + y) + b
        d = (x + increment) * k
        a_ = n // d
        b_ = n % d
        if a_ < 1:
            return
        if a_ > self.size:
            return
        self.order[self.cursor - 1] = (a_, b_)
        self.order[self.cursor] = (x + increment, 0)
        end_size = reduce(lam, self.order[::-1], 1)
        assert (
            end_size == self.size
        ), f"{end_size} != {self.size} ({a}, {b}), ({x}, {y}) -> ({a_}, {b_}), ({x + increment}, 0)"

    def apply_action(
            self, action: Event) -> Tuple[bool, Optional[ActionSpace], bool]:
        """Apply one named action to the session state.

        "toggle_mode" switches between "size" and "select" mode. In "size"
        mode "up"/"down" resize the loop under the cursor; in "select" mode
        they move the cursor, wrapping around. "toggle_thread" flips the
        thread flag of the loop under the cursor. Any other action name
        (e.g. "split") is accepted but has no effect.

        Raises:
            ValueError: If the event carries no ``int64_value`` or the index is
                outside the action space.
            RuntimeError: If ``self.mode`` holds an unknown mode.

        Returns:
            Always ``(False, None, False)``.
        """
        if not action.HasField("int64_value"):
            raise ValueError("Invalid action. int64_value expected.")

        choice_index = action.int64_value
        if choice_index < 0 or choice_index >= len(
                self.action_space.space.named_discrete.name):
            raise ValueError("Out-of-range")

        logger.info("Applied action %d", choice_index)

        act = self.action_space.space.named_discrete.name[choice_index]
        if self.mode not in ["size", "select"]:
            raise RuntimeError("Invalid mode set: {}".format(self.mode))

        if act == "toggle_mode":
            self.mode = "select" if self.mode == "size" else "size"
        elif act == "toggle_thread":
            self.thread[self.cursor] = not self.thread[self.cursor]
        elif act in ("down", "up"):
            step = -1 if act == "down" else 1
            if self.mode == "size":
                self.resize(step)
            else:
                # always loop around
                self.cursor = (self.cursor + step) % len(self.order)

        return False, None, False

    def lower(self):
        """Apply the current order to the IR and build the loop tree.

        Returns:
            A ``(loop_tree, parallel)`` pair, where ``parallel`` is the set of
            loop-tree nodes whose thread flag is set.
        """
        for n in self.ir.nodes:
            o = [(self.var, k) for k in self.order]
            self.ir.set_order(n, o)
            # always disable innermost
            self.ir.disable_reuse(n, len(o) - 1)
        loop_tree = lt.LoopTree(self.ir)
        parallel = set()
        # Walk down the nest from the first root, one level per thread flag.
        t = loop_tree.roots[0]
        for b in self.thread:
            if b:
                parallel.add(t)
                if self.backend == "cpu":
                    loop_tree.annotate(t, "cpu_parallel")
            t = loop_tree.children(t)[0]
        return loop_tree, parallel

    def flops(self):
        """Compile the current loop tree, time it, and return GFLOP/s.

        The kernel is run 50 times as warm-up, then 1000 timed iterations.
        """
        loop_tree, parallel = self.lower()
        if self.backend == "cuda":
            c = lt.cuda(loop_tree, parallel)
        else:
            c = lt.cpu(loop_tree)
        A = lt.Tensor(self.size)
        B = lt.Tensor(self.size)
        C = lt.Tensor(self.size)
        A.set(self.Ap)
        B.set(self.Bp)
        iters = 1000
        warmup = 50
        for _ in range(warmup):
            c([A, B, C])
        # perf_counter() is monotonic and high-resolution, unlike the
        # wall-clock time.time(), so the interval cannot go backwards.
        t = time.perf_counter()
        for _ in range(iters - 1):
            c([A, B, C], False)
        c([A, B, C])
        t_ = time.perf_counter()
        flops = self.size * iters / (t_ - t) / 1e9
        return flops

    def get_observation(self, observation_space: ObservationSpace) -> Event:
        """Compute the observation for the named space.

        Raises:
            KeyError: If the observation space name is not recognised.
        """
        if observation_space.name == "action_state":
            # cursor, (size, tail)
            o = self.order[self.cursor]
            return Event(int64_tensor=Int64Tensor(
                shape=[3], value=[self.cursor, o[0], o[1]]))
        if observation_space.name == "flops":
            return Event(double_value=self.flops())
        if observation_space.name == "loop_tree":
            loop_tree, parallel = self.lower()
            return Event(string_value=loop_tree.dump(lambda x: "[thread]"
                                                     if x in parallel else ""))
        raise KeyError(observation_space.name)