コード例 #1
0
ファイル: tune_network_arm.py プロジェクト: junrushao1994/tvm
def tune_and_evaluate():
    """Tune all tasks on the RPC device, then build, upload and benchmark the network.

    Uses names defined outside this function: ``tasks``, ``task_weights``,
    ``use_ndk``, ``device_key``, ``rpc_host``, ``rpc_port``, ``log_file``,
    ``mod``, ``target``, ``params``, ``input_shape`` and ``dtype``.
    """
    print("Begin tuning...")
    scheduler = auto_scheduler.TaskScheduler(tasks, task_weights)

    # Candidates are built locally (through the NDK when requested) and
    # handed to an RPC runner pointed at the remote device.
    builder = auto_scheduler.LocalBuilder(
        build_func="ndk" if use_ndk else "default")
    runner = auto_scheduler.RPCRunner(
        device_key,
        host=rpc_host,
        port=rpc_port,
        timeout=30,
        repeat=1,
        min_repeat_ms=200,
        enable_cpu_cache_flush=True,
    )
    options = auto_scheduler.TuningOptions(
        num_measure_trials=200,  # change this to 20000 to achieve the best performance
        builder=builder,
        runner=runner,
        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    )
    scheduler.tune(options)

    # Compile with the history best
    print("Compile...")
    pass_config = {"relay.backend.use_auto_scheduler": True}
    with auto_scheduler.ApplyHistoryBest(log_file):
        with tvm.transform.PassContext(opt_level=3, config=pass_config):
            compiled = relay.build(mod, target=target, params=params)

    # Export library: a shared object via the NDK, otherwise a tarball.
    workdir = tempdir()
    if use_ndk:
        from tvm.contrib import ndk

        filename = "net.so"
        export_args = (ndk.create_shared, )
    else:
        filename = "net.tar"
        export_args = ()
    local_path = workdir.relpath(filename)
    compiled.export_library(local_path, *export_args)

    # Upload module to device
    print("Upload...")
    session = auto_scheduler.utils.request_remote(device_key,
                                                  rpc_host,
                                                  rpc_port,
                                                  timeout=10000)
    session.upload(local_path)
    remote_lib = session.load_module(filename)

    # Create graph executor and feed it one random input tensor.
    device = session.cpu()
    executor = graph_executor.GraphModule(remote_lib["default"](device))
    random_input = np.random.uniform(size=input_shape).astype(dtype)
    executor.set_input("data", tvm.nd.array(random_input))

    # Evaluate
    print("Evaluate inference time cost...")
    report = executor.benchmark(device, repeat=3, min_repeat_ms=500)
    print(report)
コード例 #2
0
def test_measure_special_inputs_map_by_name_rpc_runner():
    """Build and run a workload whose ``Index`` input is supplied via ``task_inputs``.

    The measurement is performed twice through a ``LocalRPCMeasureContext``
    runner, once with ``enable_cpu_cache_flush=True`` and once with ``False``.
    Both the build result and the measure result must report ``error_no == 0``.
    """
    @auto_scheduler.register_workload
    def foo():
        X = te.placeholder(shape=[10], dtype="int32")
        Index = te.placeholder(shape=[1], dtype="int32", name="Index")
        Y = te.compute((1, ), lambda i: X[Index[i]])
        return [X, Index, Y]

    # This workload cannot use random input for the `Index` input
    task = auto_scheduler.SearchTask(
        func=foo,
        target="llvm",
        task_inputs={
            "Index": tvm.nd.array(np.array([5], dtype="int32")),
        },
    )

    for enable_cpu_cache_flush in [True, False]:
        minp = auto_scheduler.MeasureInput(task, task.compute_dag.init_state)
        local_builder = auto_scheduler.LocalBuilder()
        measure_ctx = auto_scheduler.LocalRPCMeasureContext(
            timeout=60, enable_cpu_cache_flush=enable_cpu_cache_flush)
        rpc_runner = measure_ctx.runner

        bress = local_builder.build([minp])
        assert bress[0].error_no == 0
        mress = rpc_runner.run([minp], bress)
        assert mress[0].error_no == 0

        # Drop the context before the next iteration creates a new one, so two
        # contexts never coexist (presumably this tears down the local RPC
        # tracker/server -- confirm against LocalRPCMeasureContext).
        del measure_ctx
コード例 #3
0
def tune_network(network, target):
    """Extract tasks from `network`, tune them briefly and build with the best records.

    `network` is passed to ``get_network``; `target` is converted to a
    ``tvm.target.Target``. Tuning records go to a temporary file that only
    exists for the duration of the call.
    """
    # Extract tasks
    mod, params = get_network(network)
    target = tvm.target.Target(target)
    tasks, task_weights = auto_scheduler.extract_tasks(
        mod["main"], params, target)

    with tempfile.NamedTemporaryFile() as record_file:
        log_file = record_file.name

        # Tuning
        measure_ctx = auto_scheduler.LocalRPCMeasureContext(timeout=60)
        scheduler = auto_scheduler.TaskScheduler(tasks, task_weights)
        rpc_runner = measure_ctx.runner
        builder = auto_scheduler.LocalBuilder(timeout=60)
        options = auto_scheduler.TuningOptions(
            num_measure_trials=100,
            num_measures_per_round=2,
            early_stopping=1,
            runner=rpc_runner,
            builder=builder,
            measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        )
        scheduler.tune(options, search_policy="sketch.random")
        del measure_ctx

        # Compile with the history best
        pass_config = {"relay.backend.use_auto_scheduler": True}
        with auto_scheduler.ApplyHistoryBest(log_file):
            with tvm.transform.PassContext(opt_level=3, config=pass_config):
                compiled = relay.build(mod, target=target, params=params)
コード例 #4
0
    def tune_and_evaluate():
        """Tune over RPC, build with the NDK, deploy to the device and time inference.

        Uses names from the surrounding scope: ``tasks``, ``task_weights``,
        ``device_key``, ``rpc_host``, ``rpc_port``, ``log_file``, ``mod``,
        ``target``, ``params``, ``shape_dict`` and ``ndk``.
        """
        print("Begin tuning...")
        scheduler = auto_scheduler.TaskScheduler(tasks, task_weights)
        builder = auto_scheduler.LocalBuilder(build_func="ndk")
        runner = auto_scheduler.RPCRunner(
            device_key,
            host=rpc_host,
            port=rpc_port,
            timeout=30,
            repeat=1,
            min_repeat_ms=200,
            enable_cpu_cache_flush=True,
        )
        options = auto_scheduler.TuningOptions(
            num_measure_trials=200,
            builder=builder,
            runner=runner,
            measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        )
        scheduler.tune(options)

        # Compile with the history best
        print("Compile...")
        pass_config = {"relay.backend.use_auto_scheduler": True}
        with auto_scheduler.ApplyHistoryBest(log_file):
            with tvm.transform.PassContext(opt_level=3, config=pass_config):
                compiled = relay.build(mod, target=target, params=params)

        # Export library
        workdir = tempdir()
        filename = "net.so"
        local_path = workdir.relpath(filename)
        compiled.export_library(local_path, ndk.create_shared)

        # Upload module to device
        print("Upload...")
        session = auto_scheduler.utils.request_remote(device_key,
                                                      rpc_host,
                                                      rpc_port,
                                                      timeout=10000)
        session.upload(local_path)
        remote_lib = session.load_module(filename)

        # Create graph executor and fill every declared input with random float32 data.
        device = session.cpu()
        executor = graph_executor.GraphModule(remote_lib["default"](device))
        for input_name, shape in shape_dict.items():
            random_input = np.random.uniform(size=shape).astype("float32")
            executor.set_input(input_name, tvm.nd.array(random_input))

        # Evaluate
        print("Evaluate inference time cost...")
        timer = executor.module.time_evaluator("run",
                                               device,
                                               repeat=3,
                                               min_repeat_ms=500)
        times_ms = np.array(timer().results) * 1e3  # convert to millisecond
        mean_ms = np.mean(times_ms)
        std_ms = np.std(times_ms)
        print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
              (mean_ms, std_ms))
コード例 #5
0
def tune_network(network, target):
    """Tune `network`, then check the tuned build against an ``opt_level=0`` build.

    Both libraries are executed on ``tvm.gpu()`` with the same seeded random
    input and their first outputs are compared with ``assert_allclose``.

    Raises ``ValueError`` when `network` is neither "mlp" nor "winograd-test".
    """
    # Extract tasks
    mod, params = get_network(network)
    target = tvm.target.Target(target)
    tasks, task_weights = auto_scheduler.extract_tasks(
        mod["main"], params, target)

    def run_on_gpu(data, lib):
        """Run `lib` once on the GPU with `data` as input "data"; return output 0."""
        ctx = tvm.gpu()
        executor = graph_runtime.GraphModule(lib["default"](ctx))
        executor.set_input("data", data)
        executor.run()
        return executor.get_output(0).asnumpy()

    with tempfile.NamedTemporaryFile() as record_file:
        log_file = record_file.name

        # Tuning
        measure_ctx = auto_scheduler.LocalRPCMeasureContext(timeout=60)
        scheduler = auto_scheduler.TaskScheduler(tasks, task_weights)
        rpc_runner = measure_ctx.runner
        builder = auto_scheduler.LocalBuilder(timeout=60)
        options = auto_scheduler.TuningOptions(
            num_measure_trials=100,
            num_measures_per_round=2,
            early_stopping=1,
            runner=rpc_runner,
            builder=builder,
            measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        )
        scheduler.tune(options, search_policy="sketch.random")
        del measure_ctx

        # Compile with the history best
        pass_config = {"relay.backend.use_auto_scheduler": True}
        with auto_scheduler.ApplyHistoryBest(log_file):
            with tvm.transform.PassContext(opt_level=3, config=pass_config):
                tuned_lib = relay.build(mod, target=target, params=params)

        # Compile without auto-scheduler and any other optimization for correctness check
        with tvm.transform.PassContext(opt_level=0):
            reference_lib = relay.build(mod, target=target, params=params)

        # Check the correctness on a reproducible random input.
        input_shapes = {
            "mlp": (1, 32),
            "winograd-test": (1, 23, 40, 32),
        }
        np.random.seed(0)
        if network not in input_shapes:
            raise ValueError("Unknown network: " + network)
        data = np.random.uniform(size=input_shapes[network])

        actual_output = run_on_gpu(data, tuned_lib)
        expected_output = run_on_gpu(data, reference_lib)

        tvm.testing.assert_allclose(actual_output,
                                    expected_output,
                                    rtol=1e-4,
                                    atol=1e-4)
コード例 #6
0
def tune_and_evaluate():
    """Tune over RPC, build the network, deploy it to the remote device and time it.

    Uses names defined outside this function: ``tasks``, ``task_weights``,
    ``use_ndk``, ``device_key``, ``log_file``, ``mod``, ``target``,
    ``target_host``, ``params``, ``input_shape`` and ``dtype``. The RPC
    endpoint is hard-coded to ``0.0.0.0:9190`` and the remote context is
    ``remote.cl()``.
    """
    print("Begin tuning...")
    scheduler = auto_scheduler.TaskScheduler(tasks, task_weights)
    builder = auto_scheduler.LocalBuilder(
        build_func="ndk" if use_ndk else "default")
    runner = auto_scheduler.RPCRunner(device_key,
                                      host="0.0.0.0",
                                      port=9190,
                                      repeat=3,
                                      timeout=50)
    options = auto_scheduler.TuningOptions(
        num_measure_trials=200,  # change this to 20000 to achieve the best performance
        builder=builder,
        runner=runner,
        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    )
    scheduler.tune(options)

    # Compile the whole network
    print("Compile...")
    pass_config = {"relay.backend.use_auto_scheduler": True}
    with auto_scheduler.ApplyHistoryBest(log_file):
        with tvm.transform.PassContext(opt_level=3, config=pass_config):
            compiled = relay.build(mod,
                                   target=target,
                                   target_host=target_host,
                                   params=params)

    # Create graph runtime
    print("=============== Request Remote ===============")
    from tvm.auto_scheduler.utils import request_remote

    session = request_remote(device_key, "0.0.0.0", 9190)
    ctx = session.cl()
    from tvm.contrib import utils, ndk

    workdir = utils.tempdir()
    filename = "deploy_lib.so"
    local_path = workdir.relpath(filename)
    compiled.export_library(local_path, ndk.create_shared)
    session.upload(local_path)
    remote_lib = session.load_module(filename)
    executor = graph_runtime.GraphModule(remote_lib["default"](ctx))
    random_input = np.random.uniform(size=input_shape).astype(dtype)
    executor.set_input("data", tvm.nd.array(random_input))

    # Evaluate
    print("Evaluate inference time cost...")
    timer = executor.module.time_evaluator("run",
                                           ctx,
                                           repeat=3,
                                           min_repeat_ms=500)
    times_ms = np.array(timer().results) * 1e3  # convert to millisecond
    mean_ms = np.mean(times_ms)
    std_ms = np.std(times_ms)
    print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
          (mean_ms, std_ms))
def tune_and_check(mod, data, weight):
    """Tune `mod` with one measurement trial and compare it against an ``opt_level=0`` build.

    `weight` is bound as the "weight" parameter for task extraction and for
    both builds; `data` is fed as input "data". Both libraries run on
    ``tvm.cpu()`` and their first outputs must agree within
    ``rtol=1e-4, atol=2e-4``.
    """
    # Extract tasks from a relay program
    target = tvm.target.Target("llvm")
    bound_params = {"weight": weight}
    tasks, task_weights = auto_scheduler.extract_tasks(
        mod, target=target, params=bound_params)

    def run_on_cpu(data, lib):
        """Execute `lib` once on the CPU and return its first output as numpy."""
        dev = tvm.cpu()
        executor = graph_executor.GraphModule(lib["default"](dev))
        executor.set_input("data", data)
        executor.run()
        return executor.get_output(0).numpy()

    with tempfile.NamedTemporaryFile() as record_file:
        log_file = record_file.name

        # Tune tasks
        scheduler = auto_scheduler.TaskScheduler(tasks,
                                                 task_weights,
                                                 callbacks=[])
        builder = auto_scheduler.LocalBuilder(timeout=60)
        options = auto_scheduler.TuningOptions(
            num_measure_trials=1,
            num_measures_per_round=1,
            builder=builder,
            measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        )
        scheduler.tune(options, search_policy="sketch.random")

        # Compile
        pass_config = {"relay.backend.use_auto_scheduler": True}
        with auto_scheduler.ApplyHistoryBest(log_file):
            with tvm.transform.PassContext(opt_level=3, config=pass_config):
                tuned_lib = relay.build(mod,
                                        target=target,
                                        params=bound_params)

        # Compile without auto-scheduler for correctness check
        with tvm.transform.PassContext(opt_level=0):
            reference_lib = relay.build(mod,
                                        target=target,
                                        params=bound_params)

        # Check correctness
        actual_output = run_on_cpu(data, tuned_lib)
        expected_output = run_on_cpu(data, reference_lib)

        tvm.testing.assert_allclose(actual_output,
                                    expected_output,
                                    rtol=1e-4,
                                    atol=2e-4)
コード例 #8
0
def test_measure_local_builder_runner():
    """Build and run the state from ``get_tiled_matmul()`` locally; both must succeed.

    Returns without testing anything when the "llvm" runtime is not enabled.
    """
    if not tvm.runtime.enabled("llvm"):
        return

    dag, initial_state = get_tiled_matmul()
    llvm_target = tvm.target.create("llvm")
    task = auto_scheduler.SearchTask(dag, "test", llvm_target)

    measure_input = auto_scheduler.MeasureInput(task, initial_state)
    builder = auto_scheduler.LocalBuilder()
    runner = auto_scheduler.LocalRunner(timeout=60)

    # error_no == 0 is the success code checked for both stages.
    build_results = builder.build([measure_input])
    first_build = build_results[0]
    assert first_build.error_no == 0

    measure_results = runner.run([measure_input], build_results)
    first_measure = measure_results[0]
    assert first_measure.error_no == 0
コード例 #9
0
def test_measure_local_builder_runner(enable_cpu_cache_flush=False):
    """Build and run the state from ``get_tiled_matmul()`` with the local builder/runner.

    `enable_cpu_cache_flush` is forwarded to ``auto_scheduler.LocalRunner``.
    Returns without testing anything when the "llvm" device is not enabled.
    """
    if not tvm.testing.device_enabled("llvm"):
        return

    dag, initial_state = get_tiled_matmul()
    llvm_target = tvm.target.Target("llvm")
    task = auto_scheduler.SearchTask(dag, "test", llvm_target)

    measure_input = auto_scheduler.MeasureInput(task, initial_state)
    builder = auto_scheduler.LocalBuilder()
    runner = auto_scheduler.LocalRunner(
        timeout=60,
        enable_cpu_cache_flush=enable_cpu_cache_flush,
    )

    # error_no == 0 is the success code checked for both stages.
    build_results = builder.build([measure_input])
    first_build = build_results[0]
    assert first_build.error_no == 0

    measure_results = runner.run([measure_input], build_results)
    first_measure = measure_results[0]
    assert first_measure.error_no == 0
コード例 #10
0
def test_measure_local_builder_runner():
    """Build and run a 512x512x512 ``matmul_auto_scheduler_test`` task locally.

    The measurement is repeated with ``enable_cpu_cache_flush`` set to ``True``
    and then ``False``. Returns early when the "llvm" device is not enabled.
    """
    if not tvm.testing.device_enabled("llvm"):
        return

    matmul_shape = [512, 512, 512]
    task = auto_scheduler.create_task(matmul_auto_scheduler_test, matmul_shape, "llvm")

    for flush_cache in (True, False):
        measure_input = auto_scheduler.MeasureInput(task, task.compute_dag.init_state)
        builder = auto_scheduler.LocalBuilder()
        runner = auto_scheduler.LocalRunner(timeout=60, enable_cpu_cache_flush=flush_cache)

        # error_no == 0 is the success code checked for both stages.
        build_results = builder.build([measure_input])
        first_build = build_results[0]
        assert first_build.error_no == 0

        measure_results = runner.run([measure_input], build_results)
        first_measure = measure_results[0]
        assert first_measure.error_no == 0
コード例 #11
0
def test_check_auto_schedule_tuning(host, port):  # pylint: disable=too-many-locals
    """Run a short auto-scheduler tuning through the RPC endpoint at `host`:`port`.

    Only the first two extracted tasks of an MLP workload are tuned. Any
    exception raised during tuning, or any logged measure result whose
    ``error_no`` is not ``MeasureErrorNo.NO_ERROR``, is printed and makes the
    final assertion fail.
    """
    log_file = TEMPORARY_DIRECTORY.relpath("ios_tuning_stat.log")
    target = tvm.target.Target(target=f"llvm -mtriple={ARCH}-apple-darwin")
    mod, params = relay.testing.mlp.get_workload(batch_size=4,
                                                 image_shape=(1, 4, 4))

    try:
        rpc_runner = auto_scheduler.RPCRunner(
            DEVICE_KEY,
            host,
            port,
            min_repeat_ms=1,
            timeout=10,
            n_parallel=multiprocessing.cpu_count(),
        )
        dylib_builder = auto_scheduler.LocalBuilder(
            timeout=10, build_func=ios_create_dylib)
        options = auto_scheduler.TuningOptions(
            builder=dylib_builder,
            num_measure_trials=2,
            num_measures_per_round=1,
            runner=rpc_runner,
            measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
            verbose=0,
        )

        all_tasks, all_weights = auto_scheduler.extract_tasks(
            mod["main"], params, target)
        # Keep the test short: tune only the first two tasks.
        first_tasks = all_tasks[:2]
        first_weights = all_weights[:2]
        scheduler = auto_scheduler.TaskScheduler(first_tasks, first_weights)
        scheduler.tune(options, search_policy="sketch.random")

        # Check tuning log
        logged_records = list(load_records(log_file))
        for _, measure_result in logged_records:
            if measure_result.error_no != MeasureErrorNo.NO_ERROR:
                raise ValueError(
                    f"Error for MeasureResult. Error code: {measure_result.error_no},"
                    f" for details see MeasureErrorNO.")

    except Exception as e:  # pylint: disable=broad-except
        print(e)
        status_ok = False
    else:
        status_ok = True

    assert status_ok, "Tuning failed, see logs."
コード例 #12
0
def test_measure_local_builder_rpc_runner():
    """Build locally and run through a local RPC measure context; both must succeed.

    The 512x512x512 ``matmul_auto_scheduler_test`` task is measured with
    ``enable_cpu_cache_flush`` set to ``True`` and then ``False``. Returns early
    when the "llvm" device is not enabled.
    """
    if not tvm.testing.device_enabled("llvm"):
        return

    task = auto_scheduler.SearchTask(
        func=matmul_auto_scheduler_test,
        args=(512, 512, 512),
        target="llvm",
    )

    for flush_cache in (True, False):
        measure_input = auto_scheduler.MeasureInput(
            task, task.compute_dag.init_state)
        builder = auto_scheduler.LocalBuilder()
        measure_ctx = auto_scheduler.LocalRPCMeasureContext(
            timeout=60, enable_cpu_cache_flush=flush_cache)
        rpc_runner = measure_ctx.runner

        # error_no == 0 is the success code checked for both stages.
        build_results = builder.build([measure_input])
        first_build = build_results[0]
        assert first_build.error_no == 0

        measure_results = rpc_runner.run([measure_input], build_results)
        first_measure = measure_results[0]
        assert first_measure.error_no == 0

        del measure_ctx
コード例 #13
0
def test_dag_measure_local_builder_runner():
    """Measure a task created from a ``ComputeDAG`` whose registry entry was round-tripped.

    The workload (512x512 matmul followed by two ReLUs) is registered from its
    tensors, serialized, pickled and unpickled, removed from
    ``WORKLOAD_FUNC_REGISTRY`` and restored via deserialization before the
    task is built and run locally. Returns early when "llvm" is not enabled.
    """
    if not tvm.testing.device_enabled("llvm"):
        return

    lhs = te.placeholder((512, 512), name="A")
    rhs = te.placeholder((512, 512), name="B")
    k = te.reduce_axis((0, 512), name="k")
    matmul = te.compute(
        (512, 512),
        lambda i, j: te.sum(lhs[i][k] * rhs[k][j], axis=[k]),
        name="C",
    )
    first_relu = topi.nn.relu(matmul)
    output = topi.nn.relu(first_relu)

    tensors = [lhs, rhs, output]
    dag = auto_scheduler.ComputeDAG(tensors)
    key = workload_registry.register_workload_tensors(dag.workload_key(),
                                                      tensors)

    # Round-trip the registry entry through pickle, then re-register it.
    serialized_entry = workload_registry.serialize_workload_registry_entry(key)
    restored_entry = pickle.loads(pickle.dumps(serialized_entry))
    del workload_registry.WORKLOAD_FUNC_REGISTRY[key]
    workload_registry.deserialize_workload_registry_entry(restored_entry)

    llvm_target = tvm.target.Target("llvm")
    task = auto_scheduler.SearchTask(compute_dag=dag,
                                     workload_key=key,
                                     target=llvm_target)

    for flush_cache in (True, False):
        measure_input = auto_scheduler.MeasureInput(
            task, task.compute_dag.init_state)
        builder = auto_scheduler.LocalBuilder()
        runner = auto_scheduler.LocalRunner(
            timeout=60, enable_cpu_cache_flush=flush_cache)

        # error_no == 0 is the success code checked for both stages.
        build_results = builder.build([measure_input])
        first_build = build_results[0]
        assert first_build.error_no == 0

        measure_results = runner.run([measure_input], build_results)
        first_measure = measure_results[0]
        assert first_measure.error_no == 0
コード例 #14
0
ファイル: main.py プロジェクト: irvingzhang0512/tvm_examples
    def remote_auto_scheduler(self, device_key, rpc_host, rpc_port):
        """Tune ``self.mod`` through an RPC runner and rebuild ``self._lib`` from the records.

        Parameters
        ----------
        device_key, rpc_host, rpc_port
            Forwarded to ``auto_scheduler.RPCRunner`` (as the device key,
            ``host`` and ``port`` respectively).

        Tuning records are written to ``self.log_file``; the library built
        under ``ApplyHistoryBest`` is stored in ``self._lib``.
        """
        # generate tasks
        main_func = self.mod["main"]
        tasks, task_weights = auto_scheduler.extract_tasks(
            main_func, self.params, self.target)
        for idx, task in enumerate(tasks):
            logger.debug("========== Task %d  (workload key: %s) ==========" %
                         (idx, task.workload_key))
            logger.debug(task.compute_dag)

        # generate tuner
        scheduler = auto_scheduler.TaskScheduler(tasks, task_weights)

        builder = auto_scheduler.LocalBuilder()
        rpc_runner = auto_scheduler.RPCRunner(
            device_key,
            host=rpc_host,
            port=rpc_port,
            timeout=30,
            repeat=1,
            min_repeat_ms=200,
            enable_cpu_cache_flush=True,
        )
        options = auto_scheduler.TuningOptions(
            num_measure_trials=200,
            builder=builder,
            runner=rpc_runner,
            measure_callbacks=[auto_scheduler.RecordToFile(self.log_file)],
        )
        scheduler.tune(options)

        # update self.lib
        pass_config = {"relay.backend.use_auto_scheduler": True}
        with auto_scheduler.ApplyHistoryBest(self.log_file):
            with tvm.transform.PassContext(opt_level=3, config=pass_config):
                self._lib = relay.build(self.mod,
                                        target=self.target,
                                        params=self.params)
            logger.info(f"load optimized library from {self.log_file}")
コード例 #15
0
def tune_network(network, target):
    """Tune `network` for `target` with relay integration enabled, then build it.

    `network` is passed to ``get_network``; `target` is converted to a
    ``tvm.target.Target``. Relay integration is switched on for the duration
    of the call and is always switched off again, even when extraction,
    tuning or compilation raises, so a failure here cannot leave the global
    switch enabled for later code.
    """
    auto_scheduler.enable_relay_integration()
    try:
        # Extract tasks
        mod, params = get_network(network)
        target = tvm.target.Target(target)
        tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params,
                                                           target)

        with tempfile.NamedTemporaryFile() as fp:
            log_file = fp.name

            # Tuning
            measure_ctx = auto_scheduler.LocalRPCMeasureContext(timeout=60)
            tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
            tune_option = auto_scheduler.TuningOptions(
                num_measure_trials=100,
                num_measures_per_round=2,
                early_stopping=1,
                runner=measure_ctx.runner,
                builder=auto_scheduler.LocalBuilder(timeout=60),
                measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
            )
            tuner.tune(tune_option, search_policy="sketch.random")
            del measure_ctx

            # Compile with the history best
            with auto_scheduler.ApplyHistoryBest(log_file):
                with tvm.transform.PassContext(opt_level=3):
                    lib = relay.build(mod, target=target, params=params)

        # Todo(merrymercy): when the cpu backend is upstreamed, do the following things:
        # 1. compile without history to test the fallback mechanism
        # 2. check the correctness of layout rewrite / winograd pre-transform
    finally:
        # Restore the global switch on every exit path.
        auto_scheduler.enable_relay_integration(False)
コード例 #16
0
def tune_network(network, target):
    """Tune `network` for `target` and verify the tuned builds against an unoptimized one.

    Besides the history-best build, this checks that ``ApplyHistoryBest``
    accepts a list of log files and an iterator of loaded records, and builds
    once more under ``ApplyHistoryBestOrSample``. All three libraries are run
    on ``tvm.cuda()`` with the same seeded random input.

    Raises ``ValueError`` when `network` is neither "mlp" nor "winograd-test".
    """
    # Extract tasks
    mod, params = get_network(network)
    target = tvm.target.Target(target)
    tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target)

    def build_with_auto_scheduler():
        """Build `mod` at opt_level 3 with the auto-scheduler relay backend enabled."""
        with tvm.transform.PassContext(
            opt_level=3, config={"relay.backend.use_auto_scheduler": True}
        ):
            return relay.build(mod, target=target, params=params)

    def run_on_cuda(data, lib):
        """Run `lib` once on the CUDA device with `data` as input "data"; return output 0."""
        dev = tvm.cuda()
        executor = graph_executor.GraphModule(lib["default"](dev))
        executor.set_input("data", data)
        executor.run()
        return executor.get_output(0).numpy()

    with tempfile.NamedTemporaryFile() as record_file:
        log_file = record_file.name

        # Tuning
        measure_ctx = auto_scheduler.LocalRPCMeasureContext(timeout=60, device=0)
        scheduler = auto_scheduler.TaskScheduler(tasks, task_weights, callbacks=[])
        rpc_runner = measure_ctx.runner
        builder = auto_scheduler.LocalBuilder(timeout=60)
        options = auto_scheduler.TuningOptions(
            num_measure_trials=100,
            num_measures_per_round=2,
            early_stopping=1,
            runner=rpc_runner,
            builder=builder,
            measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        )
        scheduler.tune(options, search_policy="sketch.random")
        del measure_ctx

        # Compile with the history best
        with auto_scheduler.ApplyHistoryBest(log_file):
            tuned_lib = build_with_auto_scheduler()

        # Also test that multiple log files can be loaded.
        history_cls = auto_scheduler.dispatcher.ApplyHistoryBest
        with auto_scheduler.ApplyHistoryBest([log_file, log_file]) as best:
            assert isinstance(best, history_cls), "Unable to load multiple log files jointly."

        # Confirm iterables can be directly loaded.
        loaded_recs = auto_scheduler.dispatcher.load_records(log_file)
        with auto_scheduler.ApplyHistoryBest(iter(loaded_recs)) as best:
            assert isinstance(
                best, auto_scheduler.dispatcher.ApplyHistoryBest
            ), "Unable to ingest logs from an interator."

        # Sample a schedule when missing
        with auto_scheduler.ApplyHistoryBestOrSample(None, num_measure=2):
            sampled_lib = build_with_auto_scheduler()

        # Compile without auto-scheduler and any other optimization for correctness check
        with tvm.transform.PassContext(opt_level=0):
            ref_lib = relay.build(mod, target=target, params=params)

        # Check the correctness on a reproducible random input.
        input_shapes = {
            "mlp": (1, 32),
            "winograd-test": (1, 23, 40, 32),
        }
        np.random.seed(0)
        if network not in input_shapes:
            raise ValueError("Unknown network: " + network)
        data = np.random.uniform(size=input_shapes[network])

        # Run all three libraries first, then compare, in the same order as before.
        tuned_output, sampled_output, expected_output = [
            run_on_cuda(data, built) for built in (tuned_lib, sampled_lib, ref_lib)
        ]
        for actual_output in (tuned_output, sampled_output):
            tvm.testing.assert_allclose(actual_output, expected_output, rtol=1e-4, atol=1e-4)
コード例 #17
0
#
# * :code:`num_measure_trials` is the number of measurement trials we can use during the search.
#   We use 1000 trials here, which is a good value for the search to converge. You can lower it
#   for a fast demonstration, or do more trials according to your time budget.
# * In addition, we use :code:`RecordToFile` to dump measurement records into the file
#   `spmm_mali.json`.
#   The measurement records can be used to query the history best, resume the search,
#   and do more analyses later.
# * see :any:`auto_scheduler.TuningOptions` for more parameters
# * Here, we need to create an :code:`auto_scheduler.SketchPolicy` object and add the custom
#   sketch rule as one of its `init_search_callbacks`.

# File that receives the measurement records (passed to RecordToFile below).
log_file = "spmm_mali.json"
# Tuning configuration: candidates are built with the default local build
# function and handed to an RPC runner addressed by `device_key` at
# `rpc_host`:`rpc_port`. See auto_scheduler.RPCRunner for the exact meaning of
# the timing arguments (timeout / repeat / min_repeat_ms / cache flush).
tune_option = auto_scheduler.TuningOptions(
    num_measure_trials=1000,
    builder=auto_scheduler.LocalBuilder(build_func="default"),
    runner=auto_scheduler.RPCRunner(
        device_key,
        host=rpc_host,
        port=rpc_port,
        timeout=30,
        repeat=1,
        min_repeat_ms=200,
        enable_cpu_cache_flush=True,
    ),
    # Every measurement is recorded to `log_file`.
    measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    verbose=2,
)

search_policy = auto_scheduler.SketchPolicy(
    task,