def test_meta_schedule_runner_matmul_test():
    """Test meta schedule RPC runner with the matmul module.

    Builds ``MatmulModule`` for the ``llvm`` target, runs it through an
    ``RPCRunner`` configured with custom allocation / evaluation callbacks,
    and checks that the computed output matches ``np.matmul`` and that every
    reported run time is a non-negative float.
    """
    def _check_correct_matmul(
        args_before: List[np.ndarray],
        args_after: List[np.ndarray],
    ) -> None:
        """Assert the inputs are untouched and the output equals ``a @ b``.

        Parameters
        ----------
        args_before : List[np.ndarray]
            ``[a, b, c]`` as captured right after allocation.
        args_after : List[np.ndarray]
            ``[a, b, c]`` as captured after the evaluator has run.
        """
        # The initial content of the output buffer is irrelevant: the
        # reference result is recomputed on the host from the inputs.
        a_before, b_before, _ = args_before
        a_after, b_after, c_after = args_after
        c_expected = np.matmul(a_before, b_before)
        assert (a_before == a_after).all()
        assert (b_before == b_after).all()
        tvm.testing.assert_allclose(c_expected, c_after, rtol=1e-5)

    def test_alloc_argument(
        session: RPCSession,
        device: Device,
        args_info: Any,
        alloc_repeat: int,
    ) -> List[Any]:
        """Allocate arguments with the default RPC allocator and snapshot them.

        A numpy copy of every allocated argument is stored in the global
        ``repeated_args_before`` so that ``test_run_evaluator`` can compare
        against it later.
        """
        # NOTE(review): a global (rather than a closure variable) is used to
        # hand data to `test_run_evaluator`; presumably because the runner
        # invokes these callbacks outside this frame -- keep it as is.
        global repeated_args_before  # pylint: disable=global-variable-undefined, invalid-name
        repeated_args_before = []  # type: ignore
        repeated_args = rpc_default_alloc_argument(session, device, args_info,
                                                   alloc_repeat)
        for args in repeated_args:
            repeated_args_before.append([arg.numpy()
                                         for arg in args])  # type: ignore
        return repeated_args

    def test_run_evaluator(
        session: RPCSession,  # pylint: disable=unused-argument
        rt_mod: Module,
        device: Device,
        evaluator_config: EvaluatorConfig,
        repeated_args: List[Any],
    ) -> List[float]:
        """Time the module's entry function and verify the matmul result.

        Returns
        -------
        costs : List[float]
            The flattened timing results over all argument sets.
        """
        global repeated_args_before  # pylint: disable=global-variable-undefined, invalid-name
        repeated_args_after = []
        evaluator = rt_mod.time_evaluator(
            func_name=rt_mod.entry_name,
            dev=device,
            number=evaluator_config.number,
            repeat=evaluator_config.repeat,
            min_repeat_ms=evaluator_config.min_repeat_ms,
            f_preproc="cache_flush_cpu_non_first_arg"
            if evaluator_config.enable_cpu_cache_flush else "",
        )
        repeated_costs: List[List[float]] = []
        for args in repeated_args:
            device.sync()
            profile_result = evaluator(*args)
            repeated_costs.append(profile_result.results)
            # Snapshot the arguments after execution for the correctness check.
            repeated_args_after.append([arg.numpy() for arg in args])
        costs = [
            float(cost)
            for cost in itertools.chain.from_iterable(repeated_costs)
        ]
        for args_before, args_after in zip(
                repeated_args_before,  # type: ignore
                repeated_args_after,
        ):
            _check_correct_matmul(args_before, args_after)
        # Drop the global snapshot once it has been consumed.
        del repeated_args_before  # type: ignore
        return costs

    # Build the module
    mod = MatmulModule
    builder = LocalBuilder()
    (builder_result, ) = builder.build([BuilderInput(mod, Target("llvm"))])
    assert builder_result.artifact_path is not None
    assert builder_result.error_msg is None

    # From here on an artifact exists; remove it even if an assertion fails.
    try:
        runner_input = RunnerInput(
            builder_result.artifact_path,
            "llvm",
            [
                TensorInfo("float32", (MATMUL_N, MATMUL_N)),
                TensorInfo("float32", (MATMUL_N, MATMUL_N)),
                TensorInfo("float32", (MATMUL_N, MATMUL_N)),
            ],
        )

        with LocalRPC() as rpc:
            rpc_config = RPCConfig(
                tracker_host=rpc.tracker_host,
                tracker_port=rpc.tracker_port,
                tracker_key=rpc.tracker_key,
                session_priority=1,
                session_timeout_sec=100,
            )
            evaluator_config = EvaluatorConfig(
                number=1,
                repeat=1,
                min_repeat_ms=0,
                enable_cpu_cache_flush=False,
            )
            runner = RPCRunner(
                rpc_config,
                evaluator_config,
                f_alloc_argument=test_alloc_argument,
                f_run_evaluator=test_run_evaluator,
            )
            # Run the module
            (runner_future, ) = runner.run([runner_input])
            runner_result = runner_future.result()
        assert runner_result.error_msg is None
        for result in runner_result.run_secs:
            # Entries may arrive wrapped as FloatImm; unwrap to a Python float.
            if isinstance(result, FloatImm):
                result = result.value
            assert isinstance(result, float)
            assert result >= 0.0
    finally:
        _clean_build(builder_result.artifact_path)
def test_meta_schedule_local_runner_add_test():
    """Test meta schedule local runner with add module.

    Builds ``AddModule`` for the ``llvm`` target, runs it through a
    ``LocalRunner`` configured with custom allocation / evaluation callbacks,
    and checks that the computed output equals ``a + b`` and that every
    reported run time is a non-negative float.
    """
    def _check_correct_add(args_before: List[np.ndarray],
                           args_after: List[np.ndarray]) -> None:
        """Assert the inputs are untouched and the output equals ``a + b``.

        Parameters
        ----------
        args_before : List[np.ndarray]
            ``[a, b, c]`` as captured right after allocation.
        args_after : List[np.ndarray]
            ``[a, b, c]`` as captured after the evaluator has run.
        """
        # The initial content of the output buffer is irrelevant: the
        # reference result is recomputed on the host from the inputs.
        a_before, b_before, _ = args_before
        a_after, b_after, c_after = args_after
        c_expected = a_before + b_before
        assert (a_before == a_after).all()
        assert (b_before == b_after).all()
        assert (c_expected == c_after).all()

    def test_alloc_argument(
        device: Device,
        args_info: T_ARG_INFO_JSON_OBJ_LIST,
        alloc_repeat: int,
    ) -> List[T_ARGUMENT_LIST]:
        """Allocate arguments with the default local allocator and snapshot them.

        A numpy copy of every allocated argument is stored in the global
        ``repeated_args_before`` so that ``test_run_evaluator`` can compare
        against it later.
        """
        # NOTE(review): a global (rather than a closure variable) is used to
        # hand data to `test_run_evaluator`; presumably because the runner
        # invokes these callbacks outside this frame -- keep it as is.
        global repeated_args_before  # pylint: disable=global-variable-undefined, invalid-name
        repeated_args_before = []
        repeated_args = local_default_alloc_argument(device, args_info,
                                                     alloc_repeat)
        for args in repeated_args:
            # `.numpy()` matches the accessor used by the RPC runner test in
            # this file (`.asnumpy()` is the older spelling).
            repeated_args_before.append([arg.numpy() for arg in args])
        return repeated_args

    def test_run_evaluator(
        rt_mod: Module,
        device: Device,
        evaluator_config: EvaluatorConfig,
        repeated_args: List[Any],
    ) -> List[float]:
        """Time the module's entry function and verify the add result.

        Returns
        -------
        costs : List[float]
            The flattened timing results over all argument sets.
        """
        global repeated_args_before  # pylint: disable=global-variable-undefined, invalid-name
        repeated_args_after = []
        evaluator = rt_mod.time_evaluator(
            func_name=rt_mod.entry_name,
            dev=device,
            number=evaluator_config.number,
            repeat=evaluator_config.repeat,
            min_repeat_ms=evaluator_config.min_repeat_ms,
            f_preproc="cache_flush_cpu_non_first_arg"
            if evaluator_config.enable_cpu_cache_flush else "",
        )
        repeated_costs: List[List[float]] = []
        for args in repeated_args:
            device.sync()
            profile_result = evaluator(*args)
            repeated_costs.append(profile_result.results)
            # Snapshot the arguments after execution for the correctness check.
            repeated_args_after.append([arg.numpy() for arg in args])
        costs = [
            float(cost)
            for cost in itertools.chain.from_iterable(repeated_costs)
        ]
        for args_before, args_after in zip(repeated_args_before,
                                           repeated_args_after):
            _check_correct_add(args_before, args_after)
        # Drop the global snapshot once it has been consumed.
        del repeated_args_before
        return costs

    # Build the module
    mod = AddModule
    builder = LocalBuilder()
    (builder_result, ) = builder.build([BuilderInput(mod, Target("llvm"))])
    assert builder_result.artifact_path is not None
    assert builder_result.error_msg is None

    # From here on an artifact exists; remove it even if an assertion fails.
    try:
        runner_input = RunnerInput(
            builder_result.artifact_path,
            "llvm",
            [
                TensorInfo("float32", [MATMUL_M]),
                TensorInfo("float32", [MATMUL_M]),
                TensorInfo("float32", [MATMUL_M]),
            ],
        )

        evaluator_config = EvaluatorConfig(
            number=1,
            repeat=1,
            min_repeat_ms=0,
            enable_cpu_cache_flush=False,
        )
        runner = LocalRunner(
            timeout_sec=100,
            evaluator_config=evaluator_config,
            f_alloc_argument=test_alloc_argument,
            f_run_evaluator=test_run_evaluator,
        )
        # Run the module
        (runner_future, ) = runner.run([runner_input])
        runner_result = runner_future.result()
        assert runner_result.error_msg is None
        for result in runner_result.run_secs:
            # Entries may arrive wrapped as FloatImm; unwrap to a Python float.
            if isinstance(result, FloatImm):
                result = result.value
            assert isinstance(result, float)
            assert result >= 0.0
    finally:
        _clean_build(builder_result.artifact_path)
def test_meta_schedule_local_multiple_runs():
    """Test meta schedule local runner for multiple runs.

    Builds three modules (matmul, matmul + relu, batch matmul) for the
    ``llvm`` target, submits all of them to one ``LocalRunner`` in a single
    ``run`` call, and checks that each result reports no error and only
    non-negative float run times.
    """
    # Build the module
    mods = [
        MatmulModule,
        MatmulReluModule,
        BatchMatmulModule,
    ]
    builder = LocalBuilder()
    builder_inputs = [BuilderInput(mod, Target("llvm")) for mod in mods]
    builder_results = builder.build(builder_inputs)

    # Artifacts may exist from here on; remove them even if an assertion fails.
    try:
        # One result per module is required, otherwise the pairing with
        # `args_infos` below would silently drop entries.
        assert len(builder_results) == len(mods)
        for builder_result in builder_results:
            assert builder_result.artifact_path is not None
            assert builder_result.error_msg is None

        # Argument descriptions, in the same order as `mods`.
        args_infos = [
            [
                TensorInfo("float32", (MATMUL_N, MATMUL_N)),
                TensorInfo("float32", (MATMUL_N, MATMUL_N)),
                TensorInfo("float32", (MATMUL_N, MATMUL_N)),
            ],
            [
                TensorInfo("float32", (MATMUL_N, MATMUL_N)),
                TensorInfo("float32", (MATMUL_N, MATMUL_N)),
                TensorInfo("float32", (MATMUL_N, MATMUL_N)),
            ],
            [
                TensorInfo("float32", [16, MATMUL_M, MATMUL_M]),
                TensorInfo("float32", [16, MATMUL_M, MATMUL_M]),
                TensorInfo("float32", [16, MATMUL_M, MATMUL_M]),
            ],
        ]

        runner_inputs = [
            RunnerInput(builder_result.artifact_path, "llvm", args_info)
            for builder_result, args_info in zip(builder_results, args_infos)
        ]

        evaluator_config = EvaluatorConfig(
            number=1,
            repeat=1,
            min_repeat_ms=0,
            enable_cpu_cache_flush=False,
        )

        runner = LocalRunner(timeout_sec=100,
                             evaluator_config=evaluator_config)

        # Run the module
        runner_futures = runner.run(runner_inputs)
        runner_results = [
            runner_future.result() for runner_future in runner_futures
        ]

        for runner_result in runner_results:
            assert runner_result.error_msg is None
            for result in runner_result.run_secs:
                # Entries may arrive wrapped as FloatImm; unwrap to a float.
                if isinstance(result, FloatImm):
                    result = result.value
                assert isinstance(result, float)
                assert result >= 0.0
    finally:
        for builder_result in builder_results:
            # A failed build has no artifact to remove.
            if builder_result.artifact_path is not None:
                _clean_build(builder_result.artifact_path)