def test_convert_int64_box_message():
    """Check that convert_box_message turns an Int64Box message into a Box.

    The converted space must have int64 dtype, the shape declared by both
    bound tensors, and bound values equal to the message's values.
    """
    lower = Int64Tensor(value=[1, 2], shape=[1, 2])
    upper = Int64Tensor(value=[2, 3], shape=[1, 2])
    message = Int64Box(low=lower, high=upper)

    result = py_converters.convert_box_message(message)

    assert isinstance(result, Box)
    assert result.dtype == np.int64

    # Both bound tensors declare the shape the converted space must have.
    assert np.array_equal(message.low.shape, result.shape)
    assert np.array_equal(message.high.shape, result.shape)

    # The message stores values flat, so flatten the converted bounds to compare.
    assert np.array_equal(message.low.value, result.low.flatten())
    assert np.array_equal(message.high.value, result.high.flatten())
def test_convert_int64_tensor_message_to_numpy():
    """Check that an Int64Tensor message converts to an equivalent int64 array."""
    expected_shape = [1, 2, 3]
    expected_values = [1, 2, 3, 4, 5, 6]
    message = Int64Tensor(shape=expected_shape, value=expected_values)

    array = py_converters.convert_tensor_message_to_numpy(message)

    assert array.dtype == np.int64
    assert np.array_equal(array.shape, expected_shape)
    # Values are compared in flattened order, matching the message layout.
    assert np.array_equal(array.flatten(), expected_values)
 def get_observation(self, observation_space: ObservationSpace) -> Event:
     """Return a fixed placeholder observation for the requested space.

     Args:
         observation_space: The space to observe; it is logged and its
             ``name`` selects the observation.

     Raises:
         KeyError: If the name is not "ir", "features" or "runtime".
     """
     logging.info("Computing observation from space %s", observation_space)
     name = observation_space.name

     if name == "ir":
         return Event(string_value="Hello, world!")

     if name == "features":
         zeros = Int64Tensor(shape=[3], value=[0, 0, 0])
         return Event(int64_tensor=zeros)

     if name == "runtime":
         return Event(double_value=0)

     raise KeyError(name)
Exemple #4
0
 def get_observation(self, observation_space: ObservationSpace) -> Event:
     """Return the observation selected by the space's name.

     Handles "action_state", "flops" and "loop_tree".

     Raises:
         KeyError: If the space name is none of the above.
     """
     name = observation_space.name

     if name == "action_state":
         # cursor, (size, tail)
         entry = self.order[self.cursor]
         state = Int64Tensor(shape=[3], value=[self.cursor, entry[0], entry[1]])
         return Event(int64_tensor=state)

     if name == "flops":
         return Event(double_value=self.flops())

     if name == "loop_tree":
         tree, threaded = self.lower()

         def annotate(node):
             # Tag the nodes that lower() reported in its second return value.
             return "[thread]" if node in threaded else ""

         return Event(string_value=tree.dump(annotate))

     raise KeyError(name)
Exemple #5
0
 def get_observation(self, observation_space: ObservationSpace) -> Event:
     """Get one of the observations.

     Args:
         observation_space: The space to observe; only its ``name`` is read.

     Returns:
         An Event wrapping the requested value. Falsy text values fall back
         to "" ("{}" for "instruction_counts"), falsy object bytes to b"",
         and a size of None to -1.

     Raises:
         KeyError: If the space name is not recognised.
     """
     name = observation_space.name
     if name == "source":
         return Event(string_value=self.source or "")
     elif name == "rtl":
         return Event(string_value=self.rtl or "")
     elif name == "asm":
         return Event(string_value=self.asm or "")
     elif name == "asm_size":
         # Only a missing size becomes -1; `size or -1` would also turn a
         # genuine size of 0 into the "unavailable" sentinel.
         asm_size = self.asm_size
         return Event(int64_value=-1 if asm_size is None else asm_size)
     elif name == "asm_hash":
         return Event(string_value=self.asm_hash or "")
     elif name == "instruction_counts":
         return Event(string_value=self.instruction_counts or "{}")
     elif name == "obj":
         value = self.obj or b""
         return Event(byte_tensor=ByteTensor(shape=[len(value)], value=value))
     elif name == "obj_size":
         # Same reasoning as "asm_size": keep 0 distinct from "no value".
         obj_size = self.obj_size
         return Event(int64_value=-1 if obj_size is None else obj_size)
     elif name == "obj_hash":
         return Event(string_value=self.obj_hash or "")
     elif name == "choices":
         return Event(
             int64_tensor=Int64Tensor(
                 shape=[len(self.choices)], value=self.choices
             )
         )
     elif name == "command_line":
         return Event(
             string_value=gcc.bin
             + " "
             + " ".join(map(str, self.obj_command_line("src.c", "obj.o")))
         )
     else:
         raise KeyError(name)
def test_observed_value_types():
    """Exercise ObservationView via item access and the generated bound methods.

    Each access must return the value queued on the mock, and the mock must
    record the space names in the order they were requested.
    """
    ir_space = ObservationSpace(
        name="ir",
        space=Space(string_value=StringSpace(length_range=Int64Range(min=0))),
    )
    features_space = ObservationSpace(
        name="features",
        space=Space(
            int64_box=Int64Box(
                low=Int64Tensor(shape=[2], value=[-100, -100]),
                high=Int64Tensor(shape=[2], value=[100, 100]),
            ),
        ),
    )
    dfeat_space = ObservationSpace(
        name="dfeat",
        space=Space(
            double_box=DoubleBox(
                low=DoubleTensor(shape=[1], value=[0.5]),
                high=DoubleTensor(shape=[1], value=[2.5]),
            ),
        ),
    )
    binary_space = ObservationSpace(
        name="binary",
        space=Space(int64_value=Int64Range(min=5, max=5)),
    )

    # Two rounds of return values: the first is consumed by item access,
    # the second by the bound methods.
    mock = MockRawStep(
        ret=[
            "Hello, IR",
            [1.0, 2.0],
            [-5, 15],
            b"Hello, bytes\0",
            "Hello, IR",
            [1.0, 2.0],
            [-5, 15],
            b"Hello, bytes\0",
        ]
    )
    view = ObservationView(
        mock, [ir_space, features_space, dfeat_space, binary_space]
    )
    request_order = ["ir", "dfeat", "features", "binary"]

    # Round one: dictionary-style access.
    ir_value = view["ir"]
    assert isinstance(ir_value, str)
    assert ir_value == "Hello, IR"

    dfeat_value = view["dfeat"]
    np.testing.assert_array_almost_equal(dfeat_value, [1.0, 2.0])

    features_value = view["features"]
    np.testing.assert_array_equal(features_value, [-5, 15])

    binary_value = view["binary"]
    assert binary_value == b"Hello, bytes\0"

    # Check that the correct observation_space_list indices were used.
    assert mock.called_observation_spaces == request_order
    mock.called_observation_spaces = []

    # Round two: the same checks through the generated bound methods.
    ir_value = view.ir()
    assert isinstance(ir_value, str)
    assert ir_value == "Hello, IR"

    dfeat_value = view.dfeat()
    np.testing.assert_array_almost_equal(dfeat_value, [1.0, 2.0])

    features_value = view.features()
    np.testing.assert_array_equal(features_value, [-5, 15])

    binary_value = view.binary()
    assert binary_value == b"Hello, bytes\0"

    # Check that the correct observation_space_list indices were used.
    assert mock.called_observation_spaces == request_order
    def get_observation(self, observation_space: ObservationSpace) -> Event:
        """Compute the observation for the requested space.

        Handles the space names "ir", "Inst2vec", "Autophase",
        "AutophaseDict", "Programl", "runtime" and "size".

        Args:
            observation_space: The space to observe; only its ``name`` is read.

        Returns:
            An Event holding the observation.

        Raises:
            KeyError: If the space name is not one of the names above.
            ValueError: For "runtime", if the benchmark's stdout cannot be
                parsed as an integer. The original parse error is chained
                as the cause.
        """
        name = observation_space.name
        logging.info("Computing observation from space %s", name)

        def run_bundled_tool(relative_path):
            # Run a helper binary, located relative to this file, on the IR file.
            return run_command(
                [
                    os.path.join(os.path.dirname(__file__), relative_path),
                    self._llvm_path,
                ],
                timeout=30,
            )

        def autophase_counts():
            # The tool's output is parsed as single-space-separated integers.
            output = run_bundled_tool(
                "../../../compiler_gym/third_party/autophase/compute_autophase-prelinked"
            )
            return [int(token) for token in output.split(" ")]

        def build_executable(opt_flag):
            # compile LLVM to object file
            run_command(
                [
                    self._llc,
                    "-filetype=obj",
                    self._llvm_path,
                    "-o",
                    self._obj_path,
                ],
                timeout=30,
            )
            # build object file to binary
            run_command(
                [
                    self._clang,
                    self._obj_path,
                    opt_flag,
                    "-o",
                    self._exe_path,
                ],
                timeout=30,
            )

        if name == "ir":
            return Event(string_value=self.ir)
        elif name == "Inst2vec":
            preprocessed = self.inst2vec.preprocess(self.ir)
            token_ids = self.inst2vec.encode(preprocessed)
            return Event(
                int64_tensor=Int64Tensor(shape=[len(token_ids)], value=token_ids)
            )
        elif name == "Autophase":
            counts = autophase_counts()
            return Event(int64_tensor=Int64Tensor(shape=[len(counts)], value=counts))
        elif name == "AutophaseDict":
            named_counts = {
                feature: Event(int64_value=count)
                for feature, count in zip(AUTOPHASE_FEATURE_NAMES, autophase_counts())
            }
            return Event(event_dict=DictEvent(event=named_counts))
        elif name == "Programl":
            graph = run_bundled_tool(
                "../../../compiler_gym/third_party/programl/compute_programl"
            )
            return Event(string_value=graph)
        elif name == "runtime":
            build_executable("-O3")

            # TODO: add documentation that benchmarks need print out execution time
            # Running 5 times and taking the average of middle 3
            exec_times = []
            for _ in range(5):
                stdout = run_command(
                    [self._exe_path],
                    timeout=30,
                )
                try:
                    exec_times.append(int(stdout))
                except ValueError as err:
                    # Chain the parse failure so the traceback shows its cause.
                    raise ValueError(
                        "Error in parsing execution time from output of command\n"
                        "Please ensure that the source code of the benchmark measures execution time and prints to stdout\n"
                        f"Stdout of the program: {stdout}"
                    ) from err
            # Sorting lets the slice drop the fastest and the slowest run.
            exec_times = np.sort(exec_times)
            avg_exec_time = np.mean(exec_times[1:4])
            return Event(double_value=avg_exec_time)
        elif name == "size":
            build_executable("-Oz")
            binary_size = os.path.getsize(self._exe_path)
            return Event(double_value=binary_size)
        else:
            raise KeyError(name)
class ExampleCompilationSession(CompilationSession):
    """Represents an instance of an interactive compilation session."""

    compiler_version: str = "1.0.0"

    # The action spaces supported by this service. Here we will implement a
    # single action space, called "default", that represents a command line with
    # three options: "a", "b", and "c".
    action_spaces = [
        ActionSpace(
            name="default",
            space=Space(
                named_discrete=NamedDiscreteSpace(
                    name=["a", "b", "c"],
                ),
            ),
        )
    ]

    # A list of observation spaces supported by this service. Each of these
    # ObservationSpace protos describes an observation space.
    observation_spaces = [
        ObservationSpace(
            name="ir",
            space=Space(
                string_value=StringSpace(length_range=Int64Range(min=0)),
            ),
            deterministic=True,
            platform_dependent=False,
            default_observation=Event(string_value=""),
        ),
        ObservationSpace(
            name="features",
            space=Space(
                int64_box=Int64Box(
                    low=Int64Tensor(shape=[3], value=[-100, -100, -100]),
                    high=Int64Tensor(shape=[3], value=[100, 100, 100]),
                ),
            ),
        ),
        ObservationSpace(
            name="runtime",
            space=Space(
                double_value=DoubleRange(min=0),
            ),
            deterministic=False,
            platform_dependent=True,
            default_observation=Event(
                double_value=0,
            ),
        ),
    ]

    def __init__(
        self,
        working_directory: Path,
        action_space: ActionSpace,
        benchmark: Benchmark,
    ):
        """Initialise the base session and log the benchmark URI."""
        super().__init__(working_directory, action_space, benchmark)
        logging.info("Started a compilation session for %s", benchmark.uri)

    def apply_action(self, action: Event) -> Tuple[bool, Optional[ActionSpace], bool]:
        """Validate the chosen action index; no state is actually changed.

        Returns:
            Always ``(False, None, False)``.

        Raises:
            ValueError: If the index is outside the first action space.
        """
        option_names = self.action_spaces[0].space.named_discrete.name
        num_choices = len(option_names)

        # This is the index into the action space's values ("a", "b", "c") that
        # the user selected, e.g. 0 -> "a", 1 -> "b", 2 -> "c".
        choice_index = action.int64_value
        logging.info("Applying action %d", choice_index)

        if not 0 <= choice_index < num_choices:
            raise ValueError("Out-of-range")

        # Here is where we would run the actual action to update the environment's
        # state.

        return False, None, False

    def get_observation(self, observation_space: ObservationSpace) -> Event:
        """Return a fixed placeholder observation for the requested space.

        Raises:
            KeyError: If the name is not "ir", "features" or "runtime".
        """
        logging.info("Computing observation from space %s", observation_space)
        name = observation_space.name

        if name == "ir":
            return Event(string_value="Hello, world!")

        if name == "features":
            zeros = Int64Tensor(shape=[3], value=[0, 0, 0])
            return Event(int64_tensor=zeros)

        if name == "runtime":
            return Event(double_value=0)

        raise KeyError(name)
Exemple #9
0
class UnrollingCompilationSession(CompilationSession):
    """Represents an instance of an interactive compilation session."""

    compiler_version: str = "1.0.0"

    # The list of actions that are supported by this service.
    action_spaces = [
        ActionSpace(
            name="unrolling",
            space=Space(
                named_discrete=NamedDiscreteSpace(
                    name=[
                        "-loop-unroll -unroll-count=2",
                        "-loop-unroll -unroll-count=4",
                        "-loop-unroll -unroll-count=8",
                    ],
                ),
            ),
        )
    ]

    # A list of observation spaces supported by this service. Each of these
    # ObservationSpace protos describes an observation space.
    observation_spaces = [
        ObservationSpace(
            name="ir",
            space=Space(
                string_value=StringSpace(length_range=Int64Range(min=0)),
            ),
            deterministic=True,
            platform_dependent=False,
            default_observation=Event(string_value=""),
        ),
        ObservationSpace(
            name="features",
            space=Space(
                int64_box=Int64Box(
                    low=Int64Tensor(shape=[3], value=[0, 0, 0]),
                    high=Int64Tensor(shape=[3], value=[100000, 100000, 100000]),
                ),
            ),
        ),
        ObservationSpace(
            name="runtime",
            space=Space(
                double_value=DoubleRange(min=0),
            ),
            deterministic=False,
            platform_dependent=True,
            default_observation=Event(
                double_value=0,
            ),
        ),
        ObservationSpace(
            name="size",
            space=Space(
                double_value=DoubleRange(min=0),
            ),
            deterministic=True,
            platform_dependent=True,
            default_observation=Event(
                double_value=0,
            ),
        ),
    ]

    def __init__(
        self,
        working_directory: Path,
        action_space: ActionSpace,
        benchmark: Benchmark,
        use_custom_opt: bool = True,
    ):
        """Write the benchmark source to disk and compile it to LLVM IR.

        Args:
            working_directory: Forwarded to the base class.
            action_space: The action space whose named-discrete entries are
                used as command-line arguments by ``apply_action``.
            benchmark: The benchmark; ``benchmark.program.contents`` is
                written out as ``benchmark.c``.
            use_custom_opt: If true, actions run the custom loop_unroller
                executable; otherwise they run LLVM's ``opt``.
        """
        super().__init__(working_directory, action_space, benchmark)
        logging.info("Started a compilation session for %s", benchmark.uri)
        self._benchmark = benchmark
        self._action_space = action_space

        # Resolve the paths to LLVM binaries once now.
        self._clang = str(llvm.clang_path())
        self._llc = str(llvm.llc_path())
        self._llvm_diff = str(llvm.llvm_diff_path())
        self._opt = str(llvm.opt_path())
        # LLVM's opt does not always enforce the unrolling options passed as
        # cli arguments. Hence, we created our own executable with a custom
        # unrolling pass in examples/example_unrolling_service/loop_unroller
        # that enforces the unrolling factors passed in its cli.
        # If self._use_custom_opt is true, use our custom executable,
        # otherwise use LLVM's opt.
        self._use_custom_opt = use_custom_opt

        # Dump the benchmark source to disk.
        self._src_path = str(self.working_dir / "benchmark.c")
        with open(self._src_path, "wb") as f:
            f.write(benchmark.program.contents)

        self._llvm_path = str(self.working_dir / "benchmark.ll")
        self._llvm_before_path = str(self.working_dir / "benchmark.previous.ll")
        self._obj_path = str(self.working_dir / "benchmark.o")
        self._exe_path = str(self.working_dir / "benchmark.exe")

        # Emit textual LLVM IR for the benchmark source.
        run_command(
            [
                self._clang,
                "-Xclang",
                "-disable-O0-optnone",
                "-emit-llvm",
                "-S",
                self._src_path,
                "-o",
                self._llvm_path,
            ],
            timeout=30,
        )

    def apply_action(self, action: Event) -> Tuple[bool, Optional[ActionSpace], bool]:
        """Run the selected unrolling command over the IR file, in place.

        Args:
            action: Event whose ``int64_value`` indexes the action space's
                named-discrete entries.

        Returns:
            ``(end_of_session, new_action_space, action_had_no_effect)``.
            The first two are always ``False`` and ``None``; the third is
            true when llvm-diff exits successfully on the before/after IR.

        Raises:
            ValueError: If the action index is out of range.
        """
        available_actions = self._action_space.space.named_discrete.name
        num_choices = len(available_actions)

        # This is the index into the action space's list of command-line
        # argument strings that the user selected.
        choice_index = action.int64_value
        if choice_index < 0 or choice_index >= num_choices:
            raise ValueError("Out-of-range")

        args = available_actions[choice_index]
        logging.info(
            "Applying action %d, equivalent command-line arguments: '%s'",
            choice_index,
            args,
        )
        args = args.split()

        # make a copy of the LLVM file to compare its contents after applying the action
        shutil.copyfile(self._llvm_path, self._llvm_before_path)

        # apply action
        if self._use_custom_opt:
            # our custom unroller has an additional `f` at the beginning of
            # each argument: convert -<argument> to -f<argument>
            args = [arg[0] + "f" + arg[1:] for arg in args]
            run_command(
                [
                    "../loop_unroller/loop_unroller",
                    self._llvm_path,
                    *args,
                    "-S",
                    "-o",
                    self._llvm_path,
                ],
                timeout=30,
            )
        else:
            run_command(
                [
                    self._opt,
                    *args,
                    self._llvm_path,
                    "-S",
                    "-o",
                    self._llvm_path,
                ],
                timeout=30,
            )

        # compare the IR files to check if the action had an effect
        try:
            subprocess.check_call(
                [self._llvm_diff, self._llvm_before_path, self._llvm_path],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
                timeout=60,
            )
        except subprocess.CalledProcessError:
            # A non-zero exit from llvm-diff is treated as "the IR changed".
            action_had_no_effect = False
        else:
            action_had_no_effect = True

        end_of_session = False  # TODO: this needs investigation: for how long can we apply loop unrolling? e.g., detect if there are no more loops in the IR?
        new_action_space = None
        return (end_of_session, new_action_space, action_had_no_effect)

    @property
    def ir(self) -> str:
        """The current contents of the LLVM IR file."""
        with open(self._llvm_path) as f:
            return f.read()

    def _build_executable(self, opt_flag: str) -> None:
        """Compile the IR to an object file, then link it with ``opt_flag``."""
        # compile LLVM to object file
        run_command(
            [
                self._llc,
                "-filetype=obj",
                self._llvm_path,
                "-o",
                self._obj_path,
            ],
            timeout=30,
        )

        # build object file to binary
        run_command(
            [
                self._clang,
                self._obj_path,
                opt_flag,
                "-o",
                self._exe_path,
            ],
            timeout=30,
        )

    def _mean_execution_time(self):
        """Run the executable 5 times and average the middle 3 reported times.

        Raises:
            ValueError: If the program's stdout cannot be parsed as an
                integer; the parse error is chained as the cause.
        """
        # TODO: add documentation that benchmarks need print out execution time
        exec_times = []
        for _ in range(5):
            stdout = run_command(
                [self._exe_path],
                timeout=30,
            )
            try:
                exec_times.append(int(stdout))
            except ValueError as err:
                raise ValueError(
                    "Error in parsing execution time from output of command\n"
                    "Please ensure that the source code of the benchmark measures execution time and prints to stdout\n"
                    f"Stdout of the program: {stdout}"
                ) from err
        # Sorting lets the slice drop the fastest and the slowest run.
        exec_times = np.sort(exec_times)
        return np.mean(exec_times[1:4])

    def get_observation(self, observation_space: ObservationSpace) -> Event:
        """Compute the observation for the requested space.

        Handles the space names "ir", "features", "runtime" and "size".

        Raises:
            KeyError: If the space name is not one of the names above.
            ValueError: For "runtime", if the benchmark's stdout cannot be
                parsed as an integer.
        """
        name = observation_space.name
        logging.info("Computing observation from space %s", name)
        if name == "ir":
            return Event(string_value=self.ir)
        elif name == "features":
            stats = utils.extract_statistics_from_ir(self.ir)
            # Materialise the values once so shape and value stay in step.
            values = list(stats.values())
            return Event(int64_tensor=Int64Tensor(shape=[len(values)], value=values))
        elif name == "runtime":
            self._build_executable("-O3")
            return Event(double_value=self._mean_execution_time())
        elif name == "size":
            self._build_executable("-Oz")
            binary_size = os.path.getsize(self._exe_path)
            return Event(double_value=binary_size)
        else:
            raise KeyError(name)
Exemple #10
0
    def get_observation(self, observation_space: ObservationSpace) -> Event:
        """Compute the observation for the requested space.

        Handles the space names "ir", "features", "runtime" and "size".

        Args:
            observation_space: The space to observe; only its ``name`` is read.

        Returns:
            An Event holding the observation.

        Raises:
            KeyError: If the space name is not one of the names above.
            ValueError: For "runtime", if the benchmark's stdout cannot be
                parsed as an integer. The original parse error is chained
                as the cause.
        """
        name = observation_space.name
        logging.info("Computing observation from space %s", name)

        def build_executable(opt_flag):
            # compile LLVM to object file
            run_command(
                [
                    self._llc,
                    "-filetype=obj",
                    self._llvm_path,
                    "-o",
                    self._obj_path,
                ],
                timeout=30,
            )
            # build object file to binary
            run_command(
                [
                    self._clang,
                    self._obj_path,
                    opt_flag,
                    "-o",
                    self._exe_path,
                ],
                timeout=30,
            )

        if name == "ir":
            return Event(string_value=self.ir)
        elif name == "features":
            stats = utils.extract_statistics_from_ir(self.ir)
            # Materialise the values once so shape and value stay in step.
            values = list(stats.values())
            return Event(int64_tensor=Int64Tensor(shape=[len(values)], value=values))
        elif name == "runtime":
            build_executable("-O3")

            # TODO: add documentation that benchmarks need print out execution time
            # Running 5 times and taking the average of middle 3
            exec_times = []
            for _ in range(5):
                stdout = run_command(
                    [self._exe_path],
                    timeout=30,
                )
                try:
                    exec_times.append(int(stdout))
                except ValueError as err:
                    # Chain the parse failure so the traceback shows its cause.
                    raise ValueError(
                        "Error in parsing execution time from output of command\n"
                        "Please ensure that the source code of the benchmark measures execution time and prints to stdout\n"
                        f"Stdout of the program: {stdout}"
                    ) from err
            # Sorting lets the slice drop the fastest and the slowest run.
            exec_times = np.sort(exec_times)
            avg_exec_time = np.mean(exec_times[1:4])
            return Event(double_value=avg_exec_time)
        elif name == "size":
            build_executable("-Oz")
            binary_size = os.path.getsize(self._exe_path)
            return Event(double_value=binary_size)
        else:
            raise KeyError(name)
Exemple #11
0
class LoopToolCompilationSession(CompilationSession):
    """Represents an instance of an interactive loop_tool session.

    The session builds a small loop_tool IR over a single variable ``a``:
    two ``read`` nodes feed an ``add`` node whose result is consumed by a
    ``write`` node. The iteration space of that variable is described by
    ``self.order``, a list of ``(size, tail)`` tuples ordered from the
    outermost loop to the innermost loop. Actions move a cursor over those
    loops, resize the loop under the cursor, or toggle a per-loop thread
    flag.

    Instance state:
        backend: ``"cuda"`` if the benchmark URI contains ``"cuda"``, else
            ``"cpu"``.
        size: Problem size, parsed from the last ``/`` component of the
            benchmark URI.
        order: ``(size, tail)`` per loop level, outermost first.
        thread: One truthy/falsy flag per loop level marking it as threaded.
        cursor: Index into ``order`` of the currently selected loop.
        mode: ``"size"`` (up/down resize the selected loop) or ``"select"``
            (up/down move the cursor).
    """

    compiler_version: str = pkg_resources.get_distribution(
        "loop-tool-py").version

    # keep it simple for now: 1 variable, 1 nest
    action_spaces = [
        ActionSpace(
            name="simple",
            space=Space(
                # shift around a single pre-split order, changing the size of splits
                named_discrete=NamedDiscreteSpace(
                    name=["toggle_mode", "up", "down", "toggle_thread"], ), ),
        ),
        ActionSpace(
            name="split",
            space=Space(
                # potentially define new splits
                named_discrete=NamedDiscreteSpace(name=[
                    "toggle_mode", "up", "down", "toggle_thread", "split"
                ], ), ),
        ),
    ]

    observation_spaces = [
        ObservationSpace(
            name="flops",
            space=Space(double_value=DoubleRange()),
            deterministic=False,
            platform_dependent=True,
            default_observation=Event(double_value=0, ),
        ),
        ObservationSpace(
            name="loop_tree",
            space=Space(
                string_value=StringSpace(length_range=Int64Range(min=0)), ),
            deterministic=True,
            platform_dependent=False,
            default_observation=Event(string_value="", ),
        ),
        # NOTE(review): get_observation() emits a 3-element tensor
        # [cursor, size, tail] for this space, while the box and the default
        # observation declared here have shape [1] -- confirm which one is
        # intended before relying on the declared shape.
        ObservationSpace(
            name="action_state",
            space=Space(int64_box=Int64Box(
                low=Int64Tensor(shape=[1], value=[0]),
                high=Int64Tensor(shape=[1], value=[2**36]),
            ), ),
            deterministic=True,
            platform_dependent=False,
            default_observation=Event(int64_tensor=Int64Tensor(shape=[1],
                                                               value=[0]), ),
        ),
    ]

    def __init__(self, working_directory: Path, action_space: ActionSpace,
                 benchmark: Benchmark):
        """Start a session for the given benchmark.

        Args:
            working_directory: Forwarded to the base class constructor.
            action_space: The action space whose ``named_discrete`` names are
                used to decode incoming actions.
            benchmark: The benchmark to run. Its URI selects the backend
                (``"cuda"`` substring) and its last ``/`` component is parsed
                as the integer problem size.

        Raises:
            EnvironmentNotSupported: If the selected backend is not reported
                by ``lt.backends()``.
            ValueError: If the last URI component is not an integer.
        """
        super().__init__(working_directory, action_space, benchmark)
        self.action_space = action_space
        if "cuda" in benchmark.uri:
            self.backend = "cuda"
            lt.set_default_hardware("cuda")
        else:
            self.backend = "cpu"
        if self.backend not in lt.backends():
            raise EnvironmentNotSupported(
                f"Failed to load {self.backend} dataset for loop_tool.  Have you installed all required dependecies?  See <https://facebookresearch.github.io/CompilerGym/envs/loop_tool.html#installation> for details. "
            )

        # Build the IR: w = write(add(read(a), read(a))).
        self.ir = lt.IR()
        self.var = self.ir.create_var("a")
        r0 = self.ir.create_node("read", [], [self.var])
        r1 = self.ir.create_node("read", [], [self.var])
        add = self.ir.create_node("add", [r0, r1], [self.var])
        w = self.ir.create_node("write", [add], [self.var])
        self.ir.set_inputs([r0, r1])
        self.ir.set_outputs([w])

        self.size = int(benchmark.uri.split("/")[-1])
        # Random input data, generated once and reused for every measurement.
        self.Ap = np.random.randn(self.size)
        self.Bp = np.random.randn(self.size)
        # Three loop levels; initially the outermost loop covers everything.
        self.order = [(self.size, 0), (1, 0), (1, 0)]
        self.thread = [1, 0, 0]
        self.cursor = 0
        self.mode = "size"
        logger.info("Started a compilation session for %s", benchmark.uri)

    def resize(self, increment):
        """Grow or shrink the loop under the cursor, rebalancing its parent.

        The idea is pull from or add to the parent loop.

        Three mutations possible to any size:
        A) x, y -> x + 1, 0
          remove the tail, increase loop size, shrink parent
        B) x, y -> x, 0
          only remove the tail, add to parent
        C) x, 0 -> x - 1, 0
          if no tail, shrink the loop size, increase parent

        note: this means tails can never exist on innermost loops. this makes good sense :)

        A)

        [(a, b), (x, y), ...k] -> [(a', b'), (x + 1, 0), ...k]
        a * (x * k + y) + b = a' * (x + 1) * k + b'
        a' = (a * (x * k + y) + b) // ((x + 1) * k)
        b' = "                   " %  "           "

        B)

        [(a, b), (x, y), ...k] -> [(a', b'), (x, 0), ...k]
        a * (x * k + y) + b = a' * (x) * k + b'
        a' = (a * (x * k + y) + b) // ((x) * k)
        b' = "                   " %  "           "

        C)

        [(a, b), (x, y), ...k] -> [(a', b'), (x - 1, 0), ...k]
        a * (x * k + y) + b = a' * (x - 1) * k + b'
        a' = (a * (x * k + y) + b) // ((x - 1) * k)
        b' = "                   " %  "           "

        example interaction model:
        1. cursor = 1        [1024, 1, 1]
        2. up                [512, 2, 1]
        3. up                [(341,1), 3, 1]
        4. up                [256, 4, 1]
        5. cursor = 2, up    [256, 2, 2]
        6. up                [256, (1, 1), 3]
        7. cursor = 1, down  [(341, 1), 1, 3]
        8. cursor = 2, down  [(341, 1), (1, 1), 2]
        9. cursor = 1, down  [512, 1, 2]

        Args:
            increment: ``+1`` to grow the selected loop, ``-1`` to shrink it.
                A ``-1`` on a loop that currently has a tail only drops the
                tail (mutation B).

        The call is a no-op when the cursor is on the outermost loop (it has
        no parent to rebalance against) or when the resulting loop or parent
        size would fall outside ``[1, self.size]``.
        """
        if self.cursor == 0:
            return
        a, b = self.order[self.cursor - 1]
        x, y = self.order[self.cursor]

        def lam(v, x):
            # Fold one (size, tail) level around the inner extent ``v``.
            return v * x[0] + x[1]

        # Total extent of everything nested inside the selected loop.
        k = reduce(lam, self.order[self.cursor + 1:][::-1], 1)
        if increment == -1 and y:
            # Shrinking a loop that has a tail only removes the tail.
            increment = 0
        new_x = x + increment
        if new_x < 1:
            return
        if new_x > self.size:
            return
        # Extent covered by the parent loop and everything below it. The
        # tail ``y`` of the selected loop must be included, otherwise the
        # elements it covers are lost when the tail is reset to zero.
        n = a * (x * k + y) + b
        d = new_x * k
        a_ = n // d
        b_ = n % d
        if a_ < 1:
            return
        if a_ > self.size:
            return
        self.order[self.cursor - 1] = (a_, b_)
        self.order[self.cursor] = (new_x, 0)
        # Sanity check: the loop nest must still cover exactly self.size.
        end_size = reduce(lam, self.order[::-1], 1)
        assert (
            end_size == self.size
        ), f"{end_size} != {self.size} ({a}, {b}), ({x}, {y}) -> ({a_}, {b_}), ({x + increment}, 0)"

    def apply_action(
            self, action: Event) -> Tuple[bool, Optional[ActionSpace], bool]:
        """Apply one named discrete action to the session state.

        Args:
            action: An event carrying ``int64_value``, the index of the
                action name in the session's ``named_discrete`` space.

        Returns:
            Always ``(False, None, False)``.

        Raises:
            ValueError: If the event has no ``int64_value`` or the index is
                outside the action space.
            RuntimeError: If ``self.mode`` holds an unknown value.
        """
        if not action.HasField("int64_value"):
            raise ValueError("Invalid action. int64_value expected.")

        action_names = self.action_space.space.named_discrete.name
        choice_index = action.int64_value
        if choice_index < 0 or choice_index >= len(action_names):
            raise ValueError("Out-of-range")

        logger.info("Applied action %d", choice_index)

        act = action_names[choice_index]
        if self.mode not in ["size", "select"]:
            raise RuntimeError("Invalid mode set: {}".format(self.mode))

        # Any other action name (e.g. "split") has no handler and is
        # accepted as a no-op.
        if act == "toggle_mode":
            self.mode = "select" if self.mode == "size" else "size"
        elif act == "toggle_thread":
            self.thread[self.cursor] = not self.thread[self.cursor]
        elif act in ("up", "down"):
            step = 1 if act == "up" else -1
            if self.mode == "size":
                self.resize(step)
            else:
                # always loop around
                self.cursor = (self.cursor + step) % len(self.order)

        return False, None, False

    def lower(self):
        """Apply the current order to the IR and build its loop tree.

        Every IR node is given the order described by ``self.order`` and has
        reuse disabled on the innermost level. The loop tree is then walked
        from its first root along first children, one level per entry of
        ``self.thread``; levels whose flag is set are collected and, on the
        cpu backend, annotated with ``"cpu_parallel"``.

        Returns:
            A ``(loop_tree, parallel)`` pair where ``parallel`` is the set of
            loop tree references flagged as threaded.
        """
        for n in self.ir.nodes:
            o = [(self.var, k) for k in self.order]
            self.ir.set_order(n, o)
            # always disable innermost
            self.ir.disable_reuse(n, len(o) - 1)
        loop_tree = lt.LoopTree(self.ir)
        parallel = set()
        t = loop_tree.roots[0]
        for b in self.thread:
            if b:
                parallel.add(t)
                if self.backend == "cpu":
                    loop_tree.annotate(t, "cpu_parallel")
            t = loop_tree.children(t)[0]
        return loop_tree, parallel

    def flops(self):
        """Compile the current loop tree and time it.

        The compiled function is called 50 times as warm-up and then 1000
        times under a wall-clock timer.

        Returns:
            ``self.size * iters / elapsed_seconds / 1e9``, i.e. billions of
            output elements produced per second.
        """
        loop_tree, parallel = self.lower()
        if self.backend == "cuda":
            c = lt.cuda(loop_tree, parallel)
        else:
            c = lt.cpu(loop_tree)
        A = lt.Tensor(self.size)
        B = lt.Tensor(self.size)
        C = lt.Tensor(self.size)
        A.set(self.Ap)
        B.set(self.Bp)
        iters = 1000
        warmup = 50
        for _ in range(warmup):
            c([A, B, C])
        t = time.time()
        # NOTE(review): the second positional argument presumably disables
        # synchronisation for all but the final call -- confirm against the
        # loop_tool API.
        for _ in range(iters - 1):
            c([A, B, C], False)
        c([A, B, C])
        t_ = time.time()
        flops = self.size * iters / (t_ - t) / 1e9
        return flops

    def get_observation(self, observation_space: ObservationSpace) -> Event:
        """Compute the observation for the requested space.

        Args:
            observation_space: The space to observe; dispatched on its name.

        Returns:
            For ``"action_state"``, an int64 tensor ``[cursor, size, tail]``
            describing the selected loop; for ``"flops"``, the measured value
            from :meth:`flops`; for ``"loop_tree"``, the textual dump of the
            loop tree with threaded loops tagged ``[thread]``.

        Raises:
            KeyError: If the observation space name is not recognised.
        """
        if observation_space.name == "action_state":
            # cursor, (size, tail)
            o = self.order[self.cursor]
            return Event(int64_tensor=Int64Tensor(
                shape=[3], value=[self.cursor, o[0], o[1]]))
        elif observation_space.name == "flops":
            return Event(double_value=self.flops())
        elif observation_space.name == "loop_tree":
            loop_tree, parallel = self.lower()
            return Event(string_value=loop_tree.dump(lambda x: "[thread]"
                                                     if x in parallel else ""))
        else:
            raise KeyError(observation_space.name)