Example #1
def main():
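    # Make sure the model cache directory exists, then build and cache every
    # network listed by _build_dataset().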
    model_cache_dir = args.model_cache_dir
    try:
        os.makedirs(model_cache_dir, exist_ok=True)
    except OSError:
        print(f"Directory {model_cache_dir} cannot be created successfully.")
    keys = _build_dataset()
    for name, input_shape in tqdm(keys):
        get_network(name=name, input_shape=input_shape, cache_dir=model_cache_dir)
Example #2
def test_meta_schedule_integration_extract_from_resnet():
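    # Extract tuning tasks from ResNet-18 through the meta schedule integration
    # and check them against the 20 expected fused task names.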
    mod, params, _ = get_network(name="resnet_18",
                                 input_shape=[1, 3, 224, 224])
    extracted_tasks = ms.integration.extract_task_from_relay(mod,
                                                             target="llvm",
                                                             params=params)
    expected_task_names = [
        "vm_mod_fused_" + s for s in [
            "nn_max_pool2d",
            "nn_adaptive_avg_pool2d",
            "nn_dense_add",
            "nn_conv2d_add",
            "nn_conv2d_add_1",
            "nn_conv2d_add_2",
            "nn_conv2d_add_add_nn_relu",
            "nn_conv2d_add_add_nn_relu_1",
            "nn_conv2d_add_nn_relu",
            "nn_conv2d_add_nn_relu_1",
            "nn_conv2d_add_nn_relu_2",
            "nn_conv2d_add_nn_relu_3",
            "nn_conv2d_add_nn_relu_4",
            "nn_conv2d_add_nn_relu_5",
            "nn_contrib_conv2d_winograd_without_weight_transform_add_add_nn_relu",
            "nn_contrib_conv2d_winograd_without_weight_transform_add_add_nn_relu_1",
            "nn_contrib_conv2d_winograd_without_weight_transform_add_nn_relu",
            "nn_contrib_conv2d_winograd_without_weight_transform_add_nn_relu_1",
            # The two tasks below are purely spatial and are ruled out by AutoScheduler
            "layout_transform",
            "layout_transform_reshape_squeeze",
        ]
    ]

    assert len(extracted_tasks) == 20
    for t in extracted_tasks:
        assert t.task_name in expected_task_names, t.task_name
Example #3
def test_relay_model(model_name: str, input_shape: List[int], use_trt: bool):
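    # Get the Relay module for the given model and run it through the
    # meta schedule + TensorRT verification helper; use_trt toggles the TensorRT path.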
    mod, params, _ = get_network(model_name, input_shape)
    verify_meta_schedule_with_tensorrt(
        mod,
        params,
        input_shape,
        use_trt,
    )
Example #4
def test_meta_schedule_integration_task_extraction_query():
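    # Query a TaskExtraction environment directly and check that the mock task
    # was recorded against the ResNet-18 module.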
    mod, _, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
    env = TaskExtraction()
    env.query(task_name="mock-task",
              mod=mod,
              target=Target("llvm"),
              dispatched=[MockModule])
    _check_mock_task(env.tasks, mod)
Example #5
def main():
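    # Tune the requested workload end-to-end with the meta schedule over RPC,
    # then run the tuned module remotely with a timer.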
    describe()
    print(f"Workload: {ARGS.workload}")

    mod, params, (input_name, input_shape, input_dtype) = get_network(
        ARGS.workload,
        ARGS.input_shape,
        cache_dir=ARGS.cache_dir,
    )
    input_info = {input_name: input_shape}
    input_data = {
        item["name"]: generate_input_data(item["shape"], item["dtype"])
        for item in ARGS.input_shape
    }
    for input_name, input_shape in input_info.items():
        print(f"  input_name : {input_name}")
        print(f"  input_shape: {input_shape}")
        print(f"  input_dtype: {input_dtype}")

    runner = ms.runner.RPCRunner(
        rpc_config=ARGS.rpc_config,
        evaluator_config=ms.runner.EvaluatorConfig(
            number=ARGS.number,
            repeat=ARGS.repeat,
            min_repeat_ms=ARGS.min_repeat_ms,
            enable_cpu_cache_flush=ARGS.cpu_flush,
        ),
        alloc_repeat=1,
    )

    with ms.Profiler() as profiler:
        lib = ms.tune_relay(
            mod=mod,
            target=ARGS.target,
            config=ms.TuneConfig(
                strategy="evolutionary",
                num_trials_per_iter=64,
                max_trials_per_task=ARGS.num_trials,
                max_trials_global=ARGS.num_trials,
                adaptive_training=ARGS.adaptive_training,
            ),
            runner=runner,  # type: ignore
            work_dir=ARGS.work_dir,
            params=params,
            backend=ARGS.backend,
        )

    print("Tuning Time:")
    print(profiler.table())

    run_module_via_rpc(
        rpc_config=ARGS.rpc_config,
        lib=lib,
        dev_type=ARGS.target.kind.name,
        args=input_data,
        continuation=create_timer(ARGS.backend),
        backend=ARGS.backend,
    )
Example #6
def test_meta_schedule_integration_query_inside_with_scope():
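    # Same query as the direct TaskExtraction example, but routed through
    # MetaScheduleContext.query_inside_with_scope while the context is active.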
    mod, _, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
    env = TaskExtraction()
    with env:
        MetaScheduleContext.query_inside_with_scope(
            task_name="mock-task",
            mod=mod,
            target=Target("llvm"),
            dispatched=[MockModule],
        )
    _check_mock_task(env.tasks, mod)
Example #7
def test_meta_schedule_integration_apply_history_best():
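    # Commit a tuning record for MockModule, then check that ApplyHistoryBest
    # answers the query with the recorded workload's module.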
    mod, _, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
    database = DummyDatabase()
    env = ApplyHistoryBest(database)
    target = Target("llvm")
    workload = database.commit_workload(MockModule)
    database.commit_tuning_record(
        TuningRecord(Schedule(MockModule).trace, [1.0], workload, target, [])
    )
    mod = env.query(task_name="mock-task", mod=mod, target=target, dispatched=[MockModule])
    assert tvm.ir.structural_equal(mod, workload.mod)
Example #8
def test_meta_schedule_tune_relay(
    model_name: str,
    input_shape: List[int],
    target: str,
):
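    # Tune end-to-end with the evolutionary strategy and a JSON database, then
    # compare outputs against an untuned opt_level=0 build for correctness.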
    dev = tvm.cpu() if str(target).startswith("llvm") else tvm.cuda()
    if model_name.startswith("bert"):
        data = tvm.nd.array(np.random.randint(0, 30521, size=input_shape),
                            dev)  # embedding size
    else:
        data = tvm.nd.array(
            np.random.randn(*input_shape).astype("float32"), dev)

    mod, params, (input_name, _, _) = get_network(name=model_name,
                                                  input_shape=input_shape)
    target = Target(target)
    with tempfile.TemporaryDirectory() as work_dir:
        rt_mod1: tvm.runtime.Module = tune_relay(
            mod=mod,
            params=params,
            target=target,
            config=TuneConfig(
                strategy="evolutionary",
                num_trials_per_iter=32,
                max_trials_per_task=20000,
                max_trials_global=20000,
                search_strategy_config={
                    "genetic_num_iters": 10,
                },
            ),
            work_dir=work_dir,
            database=JSONDatabase(
                osp.join(work_dir, "workload.json"),
                osp.join(work_dir, "records.json"),
            ),
        )
        # Compile without meta-scheduler for correctness check
        with tvm.transform.PassContext(opt_level=0):
            rt_mod2 = relay.build(mod, target=target, params=params)

        def get_output(data, lib):
            module = graph_executor.GraphModule(lib["default"](dev))
            module.set_input(input_name, data)
            module.run()
            return module.get_output(0).numpy()

        # Check correctness
        actual_output = get_output(data, rt_mod1)
        expected_output = get_output(data, rt_mod2)
        assert np.allclose(actual_output,
                           expected_output,
                           rtol=1e-4,
                           atol=2e-4)
Example #9
def test_meta_schedule_integration_apply_history_best():
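    # Define an in-memory DummyDatabase, commit a record for MockModule, and
    # check that ApplyHistoryBest returns the committed workload's module.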
    @derived_object
    class DummyDatabase(PyDatabase):
        def __init__(self):
            super().__init__()
            self.records = []
            self.workload_reg = []

        def has_workload(self, mod: IRModule) -> bool:
            for workload in self.workload_reg:
                if tvm.ir.structural_equal(workload.mod, mod):
                    return True
            return False

        def commit_tuning_record(self, record: TuningRecord) -> None:
            self.records.append(record)

        def commit_workload(self, mod: IRModule) -> Workload:
            for workload in self.workload_reg:
                if tvm.ir.structural_equal(workload.mod, mod):
                    return workload
            workload = Workload(mod)
            self.workload_reg.append(workload)
            return workload

        def get_top_k(self, workload: Workload,
                      top_k: int) -> List[TuningRecord]:
            return list(
                filter(
                    lambda x: x.workload == workload,
                    sorted(self.records,
                           key=lambda x: sum(x.run_secs) / len(x.run_secs)),
                ))[:int(top_k)]

        def __len__(self) -> int:
            return len(self.records)

        def print_results(self) -> None:
            print("\n".join([str(r) for r in self.records]))

    mod, _, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
    database = DummyDatabase()
    env = ApplyHistoryBest(database)
    target = Target("llvm")
    workload = database.commit_workload(MockModule)
    database.commit_tuning_record(
        TuningRecord(Schedule(MockModule).trace, [1.0], workload, target, []))
    mod = env.query(task_name="mock-task",
                    mod=mod,
                    target=target,
                    dispatched=[MockModule])
    mod = IRModule({"main": mod})
    assert tvm.ir.structural_equal(mod, workload.mod)
Example #10
def test_meta_schedule_tune_relay(
    model_name: str,
    input_shape: List[int],
    target: str,
):
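    # Tune with ReplayTraceConfig and a DummyDatabase, then compare the tuned
    # module's output with an untuned llvm build for correctness.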
    dev = tvm.cpu() if str(target).startswith("llvm") else tvm.cuda()
    if model_name.startswith("bert"):
        data = tvm.nd.array(np.random.randint(0, 30521, size=input_shape),
                            dev)  # embedding size
    else:
        data = tvm.nd.array(
            np.random.randn(*input_shape).astype("float32"), dev)

    mod, params, (input_name, _, _) = get_network(name=model_name,
                                                  input_shape=input_shape)
    target = Target(target)
    with tempfile.TemporaryDirectory() as work_dir:
        database = DummyDatabase()
        rt_mod: tvm.runtime.Module = tune_relay(
            mod=mod,
            params=params,
            target=target,
            config=ReplayTraceConfig(
                num_trials_per_iter=32,
                num_trials_total=32,
            ),
            work_dir=work_dir,
            database=database,
        )
        # Compile without meta-scheduler for correctness check
        with tvm.transform.PassContext(opt_level=0):
            rt_mod2 = relay.build(mod, target=Target("llvm"), params=params)

        def get_output(data, lib):
            module = graph_executor.GraphModule(lib["default"](dev))
            module.set_input(input_name, data)
            module.run()
            return module.get_output(0).numpy()

        # Check correctness
        actual_output = get_output(data, rt_mod)
        expected_output = get_output(
            tvm.nd.array(data.numpy(), device=tvm.cpu()), rt_mod2)
        assert np.allclose(actual_output,
                           expected_output,
                           rtol=1e-4,
                           atol=2e-4)
Example #11
def main():
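    # Tune with AutoScheduler over RPC, apply the best records from the log,
    # build with the selected backend, and time the result remotely.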
    log_file = os.path.join(ARGS.work_dir, f"{ARGS.workload}.json")

    runner = auto_scheduler.RPCRunner(
        key=ARGS.rpc_key,
        host=ARGS.rpc_host,
        port=ARGS.rpc_port,
        n_parallel=cpu_count(logical=True),
        number=ARGS.number,
        repeat=ARGS.repeat,
        min_repeat_ms=ARGS.min_repeat_ms,
        enable_cpu_cache_flush=ARGS.cpu_flush,
        timeout=ARGS.rpc_config.session_timeout_sec,
    )

    if ARGS.target.kind.name == "llvm":
        hardware_params = auto_scheduler.HardwareParams(
            num_cores=int(ARGS.target.attrs["num-cores"]),
            target=ARGS.target,
        )
    elif ARGS.target.kind.name == "cuda":
        hardware_params = auto_scheduler.HardwareParams(
            num_cores=-1,
            vector_unit_bytes=16,
            cache_line_bytes=64,
            max_shared_memory_per_block=int(
                ARGS.target.attrs["max_shared_memory_per_block"]),
            max_threads_per_block=int(
                ARGS.target.attrs["max_threads_per_block"]),
            # The value `max_local_memory_per_block` is not used in AutoScheduler,
            # but is required by the API.
            max_local_memory_per_block=12345678,
            max_vthread_extent=8,
            warp_size=32,
        )
    else:
        raise NotImplementedError(f"Unsupported target {ARGS.target}")

    describe()
    print(f"Workload: {ARGS.workload}")
    mod, params, (input_name, input_shape, input_dtype) = get_network(
        ARGS.workload,
        ARGS.input_shape,
        cache_dir=ARGS.cache_dir,
    )
    input_info = {input_name: input_shape}
    input_data = {
        item["name"]: generate_input_data(item["shape"], item["dtype"])
        for item in ARGS.input_shape
    }
    for input_name, input_shape in input_info.items():
        print(f"  input_name : {input_name}")
        print(f"  input_shape: {input_shape}")
        print(f"  input_dtype: {input_dtype}")

    with ms.Profiler() as profiler:
        tasks, task_weights = auto_scheduler.extract_tasks(
            mod["main"],
            params,
            target=ARGS.target,
            hardware_params=hardware_params,
        )
        for idx, (task, task_weight) in enumerate(zip(tasks, task_weights)):
            print(f"==== Task {idx}: {task.desc} "
                  f"(weight {task_weight} key: {task.workload_key}) =====")
            print(task.compute_dag)

        if ARGS.num_trials > 0:
            tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
            tuner.tune(
                auto_scheduler.TuningOptions(
                    num_measure_trials=ARGS.num_trials,
                    runner=runner,
                    measure_callbacks=[
                        auto_scheduler.RecordToFile(log_file),
                    ],
                ),
                adaptive_training=ARGS.adaptive_training,
            )

        relay_build = {
            "graph": relay.build,
            "vm": relay.vm.compile
        }[ARGS.backend]
        with auto_scheduler.ApplyHistoryBest(log_file):
            with tvm.transform.PassContext(
                    opt_level=3,
                    config={"relay.backend.use_auto_scheduler": True},
            ):
                lib = relay_build(
                    mod,
                    target=ARGS.target,
                    params=params,
                )
    print("Tuning Time:")
    print(profiler.table())

    run_module_via_rpc(
        rpc_config=ARGS.rpc_config,
        lib=lib,
        dev_type=ARGS.target.kind.name,
        args=input_data,
        continuation=create_timer(ARGS.backend),
        backend=ARGS.backend,
    )
Example #12
def main():
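    # Tune with the meta schedule over RPC, then measure end-to-end latency and
    # per-layer latency via the graph executor and the debug executor.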
    mod, params, (input_name, input_shape, input_dtype) = get_network(
        ARGS.workload,
        ARGS.input_shape,
        cache_dir=ARGS.cache_dir,
    )
    print(f"Workload: {ARGS.workload}")
    print(f"  input_name: {input_name}")
    print(f"  input_shape: {input_shape}")
    print(f"  input_dtype: {input_dtype}")
    alloc_repeat = 1
    runner = ms.runner.RPCRunner(
        rpc_config=ARGS.rpc_config,
        evaluator_config=ms.runner.EvaluatorConfig(
            number=3,
            repeat=1,
            min_repeat_ms=100,
            enable_cpu_cache_flush=False,
        ),
        alloc_repeat=alloc_repeat,
        max_workers=ARGS.rpc_workers,
    )
    lib = ms.tune_relay(
        mod=mod,
        target=ARGS.target,
        config=ms.TuneConfig(
            strategy="evolutionary",
            num_trials_per_iter=64,
            max_trials_per_task=ARGS.num_trials,
            max_trials_global=ARGS.num_trials,
        ),
        runner=runner,  # type: ignore
        work_dir=ARGS.work_dir,
        params=params,
    )
    graph, rt_mod, params = lib.graph_json, lib.lib, lib.params
    if input_dtype.startswith("float"):
        input_data = np.random.uniform(size=input_shape).astype(input_dtype)
    else:
        input_data = np.random.randint(low=0, high=10000, size=input_shape, dtype=input_dtype)

    def f_timer(rt_mod, dev, input_data):
        # pylint: disable=import-outside-toplevel
        from tvm.contrib.graph_executor import GraphModule

        # pylint: enable=import-outside-toplevel

        mod = GraphModule(rt_mod["default"](dev))
        mod.set_input(input_name, input_data)
        ftimer = mod.module.time_evaluator(
            "run",
            dev,
            min_repeat_ms=500,
            repeat=3,
        )
        results = list(np.array(ftimer().results) * 1000.0)  # type: ignore
        print("Running time in time_evaluator: ", results)

    run_module_via_rpc(
        rpc_config=ARGS.rpc_config,
        lib=lib,
        dev_type=ARGS.target.kind.name,
        args=[input_data],
        continuation=f_timer,
    )

    def f_per_layer(rt_mod, dev, input_data):
        # pylint: disable=import-outside-toplevel
        from tvm.contrib.debugger.debug_executor import create

        # pylint: enable=import-outside-toplevel
        mod = create(graph, rt_mod, dev)
        mod.set_input(input_name, input_data)
        graph_nodes = [n["name"] for n in json.loads(graph)["nodes"]]
        graph_time = mod.run_individual(number=10, repeat=1, min_repeat_ms=5000)
        print("|graph_nodes| = ", len(graph_nodes))
        print("|graph_time| = ", len(graph_time))
        graph_nodes_time = {k: float(v) for k, v in zip(graph_nodes, graph_time)}
        for k, v in graph_nodes_time.items():
            print(f"{k} : {v:.3f}")

    run_module_via_rpc(
        rpc_config=ARGS.rpc_config,
        lib=rt_mod,
        dev_type=ARGS.target.kind.name,
        args=[input_data],
        continuation=f_per_layer,
    )
Example #13
def main():
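    # AutoScheduler counterpart of the previous flow: extract tasks, tune, apply
    # the best records, then time end-to-end and per-layer over RPC.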
    log_file = os.path.join(ARGS.log_dir, f"{ARGS.workload}.json")

    runner = auto_scheduler.RPCRunner(
        key=ARGS.rpc_key,
        host=ARGS.rpc_host,
        port=ARGS.rpc_port,
        n_parallel=ARGS.rpc_workers,
        number=3,
        repeat=1,
        min_repeat_ms=100,  # TODO
        enable_cpu_cache_flush=False,  # TODO
    )

    if ARGS.target.kind.name == "llvm":
        hardware_params = auto_scheduler.HardwareParams(
            num_cores=int(ARGS.target.attrs["num-cores"]),
            target=ARGS.target,
        )
    elif ARGS.target.kind.name == "cuda":
        hardware_params = auto_scheduler.HardwareParams(
            num_cores=-1,
            vector_unit_bytes=16,
            cache_line_bytes=64,
            max_shared_memory_per_block=int(ARGS.target.attrs["max_shared_memory_per_block"]),
            max_threads_per_block=int(ARGS.target.attrs["max_threads_per_block"]),
            # The value `max_local_memory_per_block` is not used in AutoScheduler,
            # but is required by the API.
            max_local_memory_per_block=12345678,
            max_vthread_extent=8,
            warp_size=32,
        )
    else:
        raise NotImplementedError(f"Unsupported target {ARGS.target}")
    mod, params, (input_name, input_shape, input_dtype) = get_network(
        ARGS.workload,
        ARGS.input_shape,
        cache_dir=ARGS.cache_dir,
    )
    input_info = {input_name: input_shape}
    input_data = {}
    print(f"Workload: {ARGS.workload}")
    for input_name, input_shape in input_info.items():
        print(f"  input_name: {input_name}")
        print(f"  input_shape: {input_shape}")
        print(f"  input_dtype: {input_dtype}")
    tasks, task_weights = auto_scheduler.extract_tasks(
        mod["main"],
        params,
        target=ARGS.target,
        hardware_params=hardware_params,
    )
    for idx, (task, task_weight) in enumerate(zip(tasks, task_weights)):
        print(f"==== Task {idx}: {task.desc} (weight {task_weight} key: {task.workload_key}) =====")
        print(task.compute_dag)

    tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
    tuner.tune(
        auto_scheduler.TuningOptions(
            num_measure_trials=ARGS.num_trials,
            runner=runner,
            measure_callbacks=[
                auto_scheduler.RecordToFile(log_file),
            ],
        )
    )

    with auto_scheduler.ApplyHistoryBest(log_file):
        with tvm.transform.PassContext(
            opt_level=3,
            config={"relay.backend.use_auto_scheduler": True},
        ):
            lib = relay.build(
                mod,
                target=ARGS.target,
                params=params,
            )
    graph, rt_mod, params = lib.graph_json, lib.lib, lib.params
    for input_name, input_shape in input_info.items():
        if input_dtype.startswith("float"):
            input_data[input_name] = np.random.uniform(size=input_shape).astype(input_dtype)
        else:
            input_data[input_name] = np.random.randint(
                low=0, high=10000, size=input_shape, dtype=input_dtype
            )

    def f_timer(rt_mod, dev, input_data):
        # pylint: disable=import-outside-toplevel
        from tvm.contrib.graph_executor import GraphModule

        # pylint: enable=import-outside-toplevel

        mod = GraphModule(rt_mod["default"](dev))
        for input_name, input_value in input_data.items():
            mod.set_input(input_name, input_value)
        ftimer = mod.module.time_evaluator(
            "run",
            dev,
            min_repeat_ms=500,
            repeat=3,
        )
        results = list(np.array(ftimer().results) * 1000.0)  # type: ignore
        print("Running time in time_evaluator: ", results)

    run_module_via_rpc(
        rpc_config=ARGS.rpc_config,
        lib=lib,
        dev_type=ARGS.target.kind.name,
        args=input_data,
        continuation=f_timer,
    )

    def f_per_layer(rt_mod, dev, input_data):
        # pylint: disable=import-outside-toplevel
        from tvm.contrib.debugger.debug_executor import create

        # pylint: enable=import-outside-toplevel
        mod = create(graph, rt_mod, dev)
        for input_name, input_value in input_data.items():
            mod.set_input(input_name, input_value)
        graph_nodes = [n["name"] for n in json.loads(graph)["nodes"]]
        graph_time = mod.run_individual(number=10, repeat=1, min_repeat_ms=5000)
        print("|graph_nodes| = ", len(graph_nodes))
        print("|graph_time| = ", len(graph_time))
        graph_nodes_time = {k: float(v) for k, v in zip(graph_nodes, graph_time)}
        for k, v in graph_nodes_time.items():
            print(f"{k} : {v:.3f}")

    run_module_via_rpc(
        rpc_config=ARGS.rpc_config,
        lib=rt_mod,
        dev_type=ARGS.target.kind.name,
        args=input_data,
        continuation=f_per_layer,
    )
Example #14
def test_meta_schedule_integration_extract_from_bert_base():
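    # Extract tasks from BERT-base and check each task's weight and argument
    # buffer shapes against the expected table.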
    expected = {
        "fused_nn_dense_2": (
            12,
            [[64, 3072], [768, 3072], [64, 768]],
        ),
        "fused_nn_dense": (
            48,
            [[64, 768], [768, 768], [64, 768]],
        ),
        "fused_nn_dense_1": (
            12,
            [[64, 768], [3072, 768], [64, 3072]],
        ),
        "fused_subtract_add_sqrt_divide_multiply_add": (
            25,
            [[1, 64, 768], [1, 64, 1], [1, 64, 1], [768], [768], [1, 64, 768]],
        ),
        "fused_nn_batch_matmul": (
            24,
            [[12, 64, 64], [12, 64, 64], [12, 64, 64]],
        ),
        "fused_reshape_add_add": (
            24,
            [[64, 768], [768], [1, 64, 768], [1, 64, 768]],
        ),
        "fused_variance": (
            25,
            [[1, 64, 768], [1, 64, 1], [1, 64, 1]],
        ),
        "fused_mean": (
            25,
            [[1, 64, 768], [1, 64, 1]],
        ),
        "fused_reshape_add_reshape_transpose_reshape": (
            12,
            [[64, 768], [768], [12, 64, 64]],
        ),
        "fused_reshape_add_multiply_fast_erf_multiply_add_multiply_reshape": (
            12,
            [[64, 3072], [3072], [64, 3072]],
        ),
        "fused_nn_fast_softmax": (
            12,
            [[1, 12, 64, 64], [1, 12, 64, 64]],
        ),
        "fused_reshape_add_reshape_transpose_reshape_1": (
            24,
            [[64, 768], [768], [12, 64, 64]],
        ),
        "fused_reshape_divide_add": (
            12,
            [[12, 64, 64], [1, 1, 1, 64], [1, 12, 64, 64]],
        ),
        "fused_reshape_transpose_reshape": (
            12,
            [[12, 64, 64], [64, 768]],
        ),
        "fused_nn_dense_add_fast_tanh": (
            1,
            [[1, 768], [768, 768], [1, 768], [1, 768]],
        ),
        "fused_cast_take_add": (
            1,
            [[1, 64], [30522, 768], [1, 64, 768], [1, 64, 768]],
        ),
        "fused_take": (
            1,
            [[1, 64, 768], [1, 768]],
        ),
        "fused_reshape": (
            12,
            [[1, 12, 64, 64], [12, 64, 64]],
        ),
        "fused_reshape_1": (
            24,
            [[1, 64, 768], [64, 768]],
        ),
    }
    mod, params, _ = get_network(name="bert_base", input_shape=[1, 64])
    extracted_tasks = ms.extract_task_from_relay(mod,
                                                 target="llvm",
                                                 params=params)
    assert len(extracted_tasks) == len(expected)
    for t in extracted_tasks:
        prim_func = None
        for _, v in t.dispatched[0].functions.items():
            prim_func = v
        shape = [[int(x) for x in prim_func.buffer_map[b].shape]
                 for b in prim_func.params]
        assert t.task_name in expected
        expected_weight, expected_shape = expected[t.task_name]
        assert expected_weight == t.weight, t.task_name
        assert expected_shape == shape, t.task_name
def test_meta_schedule_integration_extract_from_resnet_with_filter_func():
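    # Extract tasks with a custom TE filter that keeps only subgraphs containing
    # a reduction, so purely spatial tasks are excluded from extraction.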
    def filter_func(args):
        from tvm.te import create_prim_func  # pylint: disable=import-outside-toplevel

        has_complex_op = False
        visited = set()

        def traverse(t):
            nonlocal has_complex_op
            assert t.handle is not None
            if t.handle.value in visited:
                return
            if isinstance(t.op, te.PlaceholderOp):
                pass
            elif isinstance(t.op, te.ComputeOp):
                has_complex_op = has_complex_op or any(
                    isinstance(e, tir.Reduce) for e in t.op.body)
                for x in t.op.input_tensors:
                    traverse(x)
            visited.add(t.handle.value)

        for t in args:
            traverse(t)
        if not has_complex_op:
            return None
        return create_prim_func(args)

    mod, params, _ = get_network(name="resnet_18",
                                 input_shape=[1, 3, 224, 224])
    extracted_tasks = ms.extract_task_from_relay(
        mod,
        target="llvm",
        params=params,
        te_filter_func=filter_func,
    )
    expected_task_names = [
        "fused_" + s for s in [
            "nn_max_pool2d",
            "nn_adaptive_avg_pool2d",
            "nn_dense_add",
            "nn_conv2d_add",
            "nn_conv2d_add_1",
            "nn_conv2d_add_2",
            "nn_conv2d_add_add_nn_relu",
            "nn_conv2d_add_add_nn_relu_1",
            "nn_conv2d_add_nn_relu",
            "nn_conv2d_add_nn_relu_1",
            "nn_conv2d_add_nn_relu_2",
            "nn_conv2d_add_nn_relu_3",
            "nn_conv2d_add_nn_relu_4",
            "nn_conv2d_add_nn_relu_5",
            "nn_contrib_conv2d_winograd_without_weight_transform_add_add_nn_relu",
            "nn_contrib_conv2d_winograd_without_weight_transform_add_add_nn_relu_1",
            "nn_contrib_conv2d_winograd_without_weight_transform_add_nn_relu",
            "nn_contrib_conv2d_winograd_without_weight_transform_add_nn_relu_1",
        ]
    ]

    assert len(extracted_tasks) == len(expected_task_names)
    for t in extracted_tasks:
        assert t.task_name in expected_task_names, t.task_name