Example #1
def test_meta_schedule_integration_extract_from_resnet():
    mod, params, _ = get_network(name="resnet_18",
                                 input_shape=[1, 3, 224, 224])
    extracted_tasks = ms.extract_task_from_relay(mod,
                                                 target="llvm",
                                                 params=params)
    expected_task_names = [
        "fused_" + s for s in [
            "nn_max_pool2d",
            "nn_adaptive_avg_pool2d",
            "nn_dense_add",
            "nn_conv2d_add",
            "nn_conv2d_add_1",
            "nn_conv2d_add_2",
            "nn_conv2d_add_add_nn_relu",
            "nn_conv2d_add_add_nn_relu_1",
            "nn_conv2d_add_nn_relu",
            "nn_conv2d_add_nn_relu_1",
            "nn_conv2d_add_nn_relu_2",
            "nn_conv2d_add_nn_relu_3",
            "nn_conv2d_add_nn_relu_4",
            "nn_conv2d_add_nn_relu_5",
            "nn_contrib_conv2d_winograd_without_weight_transform_add_add_nn_relu",
            "nn_contrib_conv2d_winograd_without_weight_transform_add_add_nn_relu_1",
            "nn_contrib_conv2d_winograd_without_weight_transform_add_nn_relu",
            "nn_contrib_conv2d_winograd_without_weight_transform_add_nn_relu_1",
            # The two tasks below are purely spatial; AutoScheduler rules them
            # out, but MetaSchedule still extracts them.
            "layout_transform",
            "layout_transform_reshape_squeeze",
        ]
    ]

    assert len(extracted_tasks) == len(expected_task_names)
    for t in extracted_tasks:
        assert t.task_name in expected_task_names, t.task_name
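The fields these tests read generalize: each ExtractedTask carries the fused Relay subgraph (task.mod), the lowered TIR candidate(s) (task.dispatched), and a weight counting how many times the subgraph occurs in the network. A minimal inspection sketch, assuming the same get_network helper and TVM version as above:

mod, params, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
for task in ms.extract_task_from_relay(mod, target="llvm", params=params):
    # task.weight counts how often this fused subgraph appears in the model
    print(task.task_name, task.weight)
    # task.dispatched holds the lowered TIR candidate(s) for the subgraph
    print(task.dispatched[0])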
Example #2
def extract_task_qbert():
    mod, params, _ = load_quantized_bert_base(batch_size=1, seq_len=128)
    target = "llvm -mcpu=cascadelake"
    extracted_tasks = ms.extract_task_from_relay(mod, target, params)
    tune_tasks = list(
        filter(
            lambda task: "dense" in task.task_name or "batch_matmul" in task.task_name,
            extracted_tasks,
        )
    )
    # three int8 dense, two int8 bmm, and one fp32 dense
    assert len(tune_tasks) == 6

    for task in tune_tasks:
        relay_func = list(task.mod.functions.values())[0]
        out_type = relay_func.body.checked_type

        if out_type.dtype == "float32":
            continue

        mod = ms.default_config.mod(task.dispatched[0])
        sch = tvm.tir.Schedule(mod)
        block = sch.get_block("compute")
        annotations = sch.get(block).annotations

        assert "schedule_rule" in annotations
        assert "vnni" in annotations["schedule_rule"]
Example #3
def extract_from_relay(
    mod: IRModule,
    target: Target,
    params: Optional[Dict[str, NDArray]],
    name: str,
    input_shape: List[int],
    *,
    cache_dir: Optional[str] = None,
    opt_level: int = 3,
    pass_config: Optional[Dict[str, Any]] = None,
    disabled_pass: Optional[List[str]] = None,
) -> List[ExtractedTask]:
    """Extract the tasks from a network.

    Parameters
    ----------
    mod : IRModule
        The IRModule representing the network.
    target : Target
        The target that the network will be deployed to.
    params : Optional[Dict[str, NDArray]]
        The parameters of the network.
    name : str
        The name of the network.
    input_shape : List[int]
        The shape of the input tensor.
    cache_dir : Optional[str]
        The directory used to cache the extracted tasks.
        If not specified, caching is disabled.
    opt_level : int
        The optimization level of the compiler.
    pass_config : Optional[Dict[str, Any]]
        The pass config of the compiler.
    disabled_pass : Optional[List[str]]
        The passes to disable during compilation.

    Returns
    -------
    extracted_tasks : List[ExtractedTask]
        The extracted tasks.
    """
    filename = f'tasks-{target.kind.name}-{name}-{",".join(str(i) for i in input_shape)}.json'
    extracted_tasks = _load_cache(cache_dir, filename)
    if extracted_tasks is None:
        extracted_tasks = extract_task_from_relay(
            mod=mod,
            target=target,
            params=params,
            opt_level=opt_level,
            pass_config=pass_config,
            disabled_pass=disabled_pass,
        )
        extracted_tasks = list(extracted_tasks)
        _save_cache(cache_dir, filename, extracted_tasks)
    return extracted_tasks
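A minimal usage sketch for this caching wrapper, reusing the get_network helper from Example #1 (the cache path is hypothetical):

mod, params, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
tasks = extract_from_relay(
    mod,
    target=Target("llvm"),  # a Target object, since the filename uses target.kind.name
    params=params,
    name="resnet_18",
    input_shape=[1, 3, 224, 224],
    cache_dir="/tmp/task_cache",  # hypothetical; a second call reloads tasks from here
)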
Example #4
def apply_fixed_schedules(
    relay_mod: Union[RelayFunc, IRModule],
    target: Union[str, Target],
    params: Optional[Dict[str, NDArray]],
    schedule_fn: Callable[[ms.ExtractedTask, Schedule], bool],
    te_filter_func=None,
):
    """Apply fixed schedules (manually written, without any tunable knobs) as specified by
    schedule_fn to extracted tasks, and return a database that can be passed to ApplyHistoryBest.

    Parameters
    ----------
    relay_mod : Union[RelayFunc, IRModule]
        The Relay module to apply fixed schedules to.
    target : Union[str, Target]
        The target used to extract tasks.
    params : Optional[Dict[str, tvm.runtime.NDArray]]
        The associated parameters of the module.
    schedule_fn : Callable[[ExtractedTask, Schedule], bool]
        A callable that is applied for each extracted task and the corresponding default schedule.
        Returns True if the given schedule should be committed to the database, False otherwise.
    te_filter_func : Union[str, None, Callable[[List[Tensor]], PrimFunc]] = None
        The filtering function for TE computation.
        If it's a string, it's the name of the filtering function. Built-in functions are
          - "meta_schedule.DefaultTaskFilter"
          - "meta_schedule.DefaultTaskFilterAllowExtern"
        If it's None, the default filtering function is used.
        If it's a callable, it is used directly as the filtering function.

    Returns
    -------
    database : Database
        The database containing dummy tuning records for manually scheduled traces.
    """
    target = Target(target) if isinstance(target, str) else target
    extracted_tasks = ms.extract_task_from_relay(
        relay_mod,
        target,
        params,
        te_filter_func=te_filter_func,
    )
    database = ms.database.MemoryDatabase()
    for task in extracted_tasks:
        mod = ms.default_config.mod(task.dispatched[0])
        sch = Schedule(mod)

        if schedule_fn(task, sch):
            workload = database.commit_workload(mod)
            tune_rec = ms.database.TuningRecord(sch.trace, workload, [0.0],
                                                target, [])
            database.commit_tuning_record(tune_rec)

    return database
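A minimal sketch of the schedule_fn contract (Example #8 below shows a full manual-schedule version): returning True commits the current trace for the task, returning False leaves the task to the fallback TE schedule. relay_mod and params are assumed to be defined as in the other examples.

def schedule_fn(task, sch):
    # Commit only dense tasks; everything else keeps its TE schedule.
    # Any sch.* transformations applied before returning True would be
    # recorded in the committed trace.
    return "dense" in task.task_name

database = apply_fixed_schedules(relay_mod, "llvm", params, schedule_fn)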
Example #5
def tune_each_task(
    mod,
    target,
    config,
    runner,
    work_dir,
    params,
):
    extracted_tasks = ms.extract_task_from_relay(mod, target, params)
    database = ms.database.JSONDatabase(
        path_workload=os.path.join(work_dir, "default_database_workload.json"),
        path_tuning_record=os.path.join(work_dir, "default_database_tuning_record.json"),
    )
    for task in extracted_tasks:
        # pylint: disable=protected-access
        tune_context = ms.tune.Parse._tune_context(
            tune_context=None,
            mod=ms.tune.Parse._mod(task.dispatched[0]),
            target=target,
            config=config,
            task_name=task.task_name,
            space_generator=None,
            sch_rules=None,
            postprocs=None,
            mutator_probs=None,
            num_threads=os.cpu_count(),
        )
        task_scheduler = ms.tune.Parse._task_scheduler(
            None,
            [tune_context],
            task_weights=[1.0],
            builder=ms.tune.Parse._builder(None),
            runner=ms.tune.Parse._runner(runner),
            database=database,
            max_trials=config.max_trials_per_task,
            cost_model=ms.tune.Parse._cost_model(None),
            measure_callbacks=ms.tune.Parse._callbacks(None),
        )
        # pylint: enable=protected-access
        task_scheduler.tune()
    with target, ms.ApplyHistoryBest(database):
        with PassContext(
            opt_level=3,
            config={"relay.backend.use_meta_schedule": True},
        ):
            return relay_build(mod, target=target, params=params)
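A usage sketch for tune_each_task, assuming this TVM version's ms.TuneConfig (fields as in Example #8) and ms.runner.LocalRunner; the trial counts and work_dir are illustrative:

config = ms.TuneConfig(
    strategy="evolutionary",
    num_trials_per_iter=32,
    max_trials_per_task=128,
    max_trials_global=1024,
)
lib = tune_each_task(
    mod,
    Target("llvm -num-cores 4"),
    config,
    runner=ms.runner.LocalRunner(),
    work_dir="/tmp/tune_each_task",  # hypothetical directory
    params=params,
)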
Example #6
def test_extract_task_arm_conv2d_nchwc():
    data_shape = (1, 64, 128, 128)
    weight_shape = (32, 64, 1, 1)
    bias_shape = (weight_shape[0], )
    padding = (1, 1)

    data = relay.var("data", shape=data_shape, dtype="int8")
    weight = relay.var("weight", shape=weight_shape, dtype="int8")
    bias = relay.var("bias", shape=bias_shape, dtype="int32")
    conv2d = relay.nn.conv2d(
        data=data,
        weight=weight,
        kernel_size=weight_shape[2:],
        channels=weight_shape[0],
        padding=padding,
        strides=(1, 1),
        out_dtype="int32",
    )
    bias_add = relay.nn.bias_add(conv2d, bias)
    relay_mod = tvm.IRModule.from_expr(bias_add)

    weight_np = np.random.uniform(1, 10, size=weight_shape).astype("int8")
    bias_np = np.random.uniform(1, 10, size=bias_shape).astype("int32")

    params = {"weight": weight_np, "bias": bias_np}

    target = "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon"
    extracted_tasks = ms.extract_task_from_relay(relay_mod, target, params)
    tune_tasks = list(
        filter(
            lambda task: "conv2d" in task.task_name,
            extracted_tasks,
        ))

    assert len(tune_tasks) == 1

    relay_func = list(tune_tasks[0].mod.functions.values())[0]
    out_type = relay_func.body.checked_type

    # Check that the output is in NCHWc layout: the 32 output channels are
    # split into 8 outer x 4 inner blocks, and the 1x1 kernel with padding
    # (1, 1) gives spatial extent 128 + 2 = 130.
    assert list(out_type.shape) == [1, 8, 130, 130, 4]
Example #7
def extract_and_save_tasks(cache_file):
    """Extract tuning tasks and cache the nonspatial ones in the given directory.

    Parameters
    ----------
    cache_file : str
        The filename of the cached model.

    Returns
    -------
    None
    """

    mod, params_bytearray, _ = _load_cache(args.model_cache_dir, cache_file)
    params = load_param_dict(params_bytearray)
    try:
        extracted_tasks = ms.extract_task_from_relay(mod,
                                                     target=args.target,
                                                     params=params)
    except tvm.error.TVMError as error:
        print(str(error))
        return
    task_cache_path = os.path.join(
        args.task_cache_dir,
        cache_file.split(".")[0] + "_extracted_tasks.json")
    is_spatial = tvm.get_global_func("tir.schedule.IsSpatialPrimFunc")
    with open(task_cache_path, "w", encoding="utf8") as file:
        for i, task in enumerate(extracted_tasks):
            subgraph = task.dispatched[0]
            prim_func = subgraph[subgraph.get_global_vars()[0]]
            if not is_spatial(prim_func):
                subgraph_str = save_json(subgraph)
                json_obj = [task.task_name, json.loads(subgraph_str)]
                json_str = json.dumps(json_obj)
                assert "\n" not in json_str, "Failed to generate single line string."
                if i == len(extracted_tasks) - 1:
                    file.write(json_str)
                else:
                    file.write(json_str + "\n")
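A matching loader sketch for the cache written above, assuming tvm.ir.load_json is available to reconstruct each IRModule (the file holds one [task_name, mod_json] pair per line):

def load_extracted_tasks(task_cache_path):
    tasks = []
    with open(task_cache_path, "r", encoding="utf8") as file:
        for line in file:
            task_name, mod_json = json.loads(line)
            # save_json produced a JSON string, so round-trip through dumps
            tasks.append((task_name, tvm.ir.load_json(json.dumps(mod_json))))
    return tasks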
Example #8
def manual_tir_common(do_tune=False):
    M, N, K = 1024, 1024, 1024  # pylint: disable=invalid-name
    data_shape = (M, K)
    weight_shape = (N, K)

    data_dtype = "uint8"
    data = relay.var("data", shape=data_shape, dtype=data_dtype)
    weight = relay.var("weight", shape=weight_shape, dtype="int8")
    bias = relay.var("bias", shape=(weight_shape[0], ), dtype="int32")

    # dense is tuned by the TIR schedule above, bmm is scheduled by TE (topi/x86/batch_matmul.py)
    dense = relay.nn.dense(data, weight, out_dtype="int32")
    bias_add = relay.nn.bias_add(dense, bias) + relay.const(1, dtype="int32")
    out = relay.nn.batch_matmul(
        relay.cast(relay.expand_dims(bias_add, 0), "uint8"),
        relay.cast(relay.expand_dims(bias_add, 0), "int8"),
        out_dtype="int32",
    )

    relay_mod = tvm.IRModule.from_expr(out)

    target = "llvm -mcpu=cascadelake -num-cores 4"
    dev = tvm.device(target, 0)

    data = np.random.uniform(1, 10, size=(M, K)).astype("uint8")
    weight_np = np.random.uniform(1, 10, size=weight_shape).astype("int8")
    bias_np = np.random.uniform(1, 10, size=(weight_shape[0],)).astype("int32")

    ref = (
        relay.create_executor("vm", mod=relay_mod, device=dev, target=target)
        .evaluate()(*[data, weight_np, bias_np])
        .numpy()
    )

    params = {"weight": weight_np, "bias": bias_np}

    if do_tune:
        extracted_tasks = ms.extract_task_from_relay(relay_mod, target, params)
        # Filter out tasks that we don't intend to schedule / tune with TIR.
        tune_tasks = list(
            filter(
                lambda task: "dense" in task.task_name,
                extracted_tasks,
            ))
        config = ms.TuneConfig(
            strategy="replay_trace",
            num_trials_per_iter=64,
            max_trials_per_task=20000,
            max_trials_global=20000,
        )

        with tempfile.TemporaryDirectory() as work_dir:
            # postprocs=lambda: [] is important to prevent default post processors from
            # tampering with the manual schedule.
            database = ms.tune_extracted_tasks(
                tune_tasks,
                config,
                work_dir=work_dir,
                postprocs=lambda: [],
            )
    else:

        def schedule_fn(task, sch):
            if "dense" not in task.task_name:
                return False

            block = sch.get_block("compute")

            # Looks up schedule_rule annotation.
            # See the comment in test_tune_relay_manual_tir_vnni().
            schedule_rule = sch.get(block).annotations["schedule_rule"]

            assert "dense_vnni" in schedule_rule

            schedule_dense(block, M, False, sch)

            return True

        database = apply_fixed_schedules(relay_mod, target, params,
                                         schedule_fn)

    with ms.ApplyHistoryBest(database):
        with tvm.transform.PassContext(
                opt_level=3,
                config={"relay.backend.use_meta_schedule": True},
        ):
            # pylint: disable=W0105
            """
            The log should say
            Warning: Cannot find workload: tvmgen_default_fused_expand_dims
            Warning: Cannot find workload: tvmgen_default_fused_cast
            Warning: Cannot find workload: tvmgen_default_fused_cast_1
            Warning: Cannot find workload: tvmgen_default_fused_nn_batch_matmul

            This means batch matmul and others are scheduled by TE, and dense (the one not warned)
            is found in the meta schedule tuning database during ApplyHistoryBest
            """
            # pylint: enable=W0105
            lib = relay.build(relay_mod, target=target, params=params)

    runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))

    runtime.set_input("data", data)
    runtime.run()

    out = runtime.get_output(0).numpy()

    np.testing.assert_equal(out, ref)
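Both branches end in the same numerical check, so a driver only toggles do_tune; running either path requires a VNNI-capable CPU, since the target is llvm -mcpu=cascadelake:

manual_tir_common(do_tune=False)  # fixed schedule via apply_fixed_schedules
manual_tir_common(do_tune=True)   # searched schedule via tune_extracted_tasks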
Example #9
def test_meta_schedule_integration_extract_from_bert_base():
    expected = {
        "fused_nn_dense_2": (
            12,
            [[64, 3072], [768, 3072], [64, 768]],
        ),
        "fused_nn_dense": (
            48,
            [[64, 768], [768, 768], [64, 768]],
        ),
        "fused_nn_dense_1": (
            12,
            [[64, 768], [3072, 768], [64, 3072]],
        ),
        "fused_subtract_add_sqrt_divide_multiply_add": (
            25,
            [[1, 64, 768], [1, 64, 1], [1, 64, 1], [768], [768], [1, 64, 768]],
        ),
        "fused_nn_batch_matmul": (
            24,
            [[12, 64, 64], [12, 64, 64], [12, 64, 64]],
        ),
        "fused_reshape_add_add": (
            24,
            [[64, 768], [768], [1, 64, 768], [1, 64, 768]],
        ),
        "fused_variance": (
            25,
            [[1, 64, 768], [1, 64, 1], [1, 64, 1]],
        ),
        "fused_mean": (
            25,
            [[1, 64, 768], [1, 64, 1]],
        ),
        "fused_reshape_add_reshape_transpose_reshape": (
            12,
            [[64, 768], [768], [12, 64, 64]],
        ),
        "fused_reshape_add_multiply_fast_erf_multiply_add_multiply_reshape": (
            12,
            [[64, 3072], [3072], [64, 3072]],
        ),
        "fused_nn_fast_softmax": (
            12,
            [[1, 12, 64, 64], [1, 12, 64, 64]],
        ),
        "fused_reshape_add_reshape_transpose_reshape_1": (
            24,
            [[64, 768], [768], [12, 64, 64]],
        ),
        "fused_reshape_divide_add": (
            12,
            [[12, 64, 64], [1, 1, 1, 64], [1, 12, 64, 64]],
        ),
        "fused_reshape_transpose_reshape": (
            12,
            [[12, 64, 64], [64, 768]],
        ),
        "fused_nn_dense_add_fast_tanh": (
            1,
            [[1, 768], [768, 768], [1, 768], [1, 768]],
        ),
        "fused_cast_take_add": (
            1,
            [[1, 64], [30522, 768], [1, 64, 768], [1, 64, 768]],
        ),
        "fused_take": (
            1,
            [[1, 64, 768], [1, 768]],
        ),
        "fused_reshape": (
            12,
            [[1, 12, 64, 64], [12, 64, 64]],
        ),
        "fused_reshape_1": (
            24,
            [[1, 64, 768], [64, 768]],
        ),
    }
    mod, params, _ = get_network(name="bert_base", input_shape=[1, 64])
    extracted_tasks = ms.extract_task_from_relay(mod,
                                                 target="llvm",
                                                 params=params)
    assert len(extracted_tasks) == len(expected)
    for t in extracted_tasks:
        prim_func = None
        for _, v in t.dispatched[0].functions.items():
            prim_func = v
        shape = [[int(x) for x in prim_func.buffer_map[b].shape]
                 for b in prim_func.params]
        assert t.task_name in expected
        expected_weight, expected_shape = expected[t.task_name]
        assert expected_weight == t.weight, t.task_name
        assert expected_shape == shape, t.task_name
def test_meta_schedule_dynamic_loop_extent():
    a = relay.var("a", shape=(1, 8, 8, 512), dtype="float32")
    b = relay.nn.adaptive_avg_pool2d(a, (7, 7), "NHWC")
    mod = IRModule({"main": relay.Function([a], b)})
    extracted_tasks = ms.extract_task_from_relay(mod, target="llvm", params={})
    assert not extracted_tasks
def test_meta_schedule_integration_extract_from_resnet_with_filter_func():
    # Keep tasks that contain a reduction; returning None filters out
    # purely spatial tasks.
    def filter_func(args):
        from tvm.te import create_prim_func  # pylint: disable=import-outside-toplevel

        has_complex_op = False
        visited = set()

        def traverse(t):
            nonlocal has_complex_op
            assert t.handle is not None
            if t.handle.value in visited:
                return
            if isinstance(t.op, te.PlaceholderOp):
                pass
            elif isinstance(t.op, te.ComputeOp):
                has_complex_op = has_complex_op or any(
                    isinstance(e, tir.Reduce) for e in t.op.body)
                for x in t.op.input_tensors:
                    traverse(x)
            visited.add(t.handle.value)

        for t in args:
            traverse(t)
        if not has_complex_op:
            return None
        return create_prim_func(args)

    mod, params, _ = get_network(name="resnet_18",
                                 input_shape=[1, 3, 224, 224])
    extracted_tasks = ms.extract_task_from_relay(
        mod,
        target="llvm",
        params=params,
        te_filter_func=filter_func,
    )
    expected_task_names = [
        "fused_" + s for s in [
            "nn_max_pool2d",
            "nn_adaptive_avg_pool2d",
            "nn_dense_add",
            "nn_conv2d_add",
            "nn_conv2d_add_1",
            "nn_conv2d_add_2",
            "nn_conv2d_add_add_nn_relu",
            "nn_conv2d_add_add_nn_relu_1",
            "nn_conv2d_add_nn_relu",
            "nn_conv2d_add_nn_relu_1",
            "nn_conv2d_add_nn_relu_2",
            "nn_conv2d_add_nn_relu_3",
            "nn_conv2d_add_nn_relu_4",
            "nn_conv2d_add_nn_relu_5",
            "nn_contrib_conv2d_winograd_without_weight_transform_add_add_nn_relu",
            "nn_contrib_conv2d_winograd_without_weight_transform_add_add_nn_relu_1",
            "nn_contrib_conv2d_winograd_without_weight_transform_add_nn_relu",
            "nn_contrib_conv2d_winograd_without_weight_transform_add_nn_relu_1",
        ]
    ]

    assert len(extracted_tasks) == len(expected_task_names)
    for t in extracted_tasks:
        assert t.task_name in expected_task_names, t.task_name