Example #1
def test_error_reporting():
    # te.create_prim_func cannot lower ScanOp, so a TypeError is expected here.
    try:
        te.create_prim_func(te_scan())
        assert False, "create_prim_func should have raised TypeError"
    except TypeError as e:
        error_message = str(e)
        assert error_message.find("Unsupported Operation: ScanOp.") != -1
        return
    assert False, "TypeError was not raised"
Example #2
    def filter_func(args):
        """Return a PrimFunc only if the TE graph contains a reduction, otherwise None."""
        from tvm.te import create_prim_func  # pylint: disable=import-outside-toplevel

        has_complex_op = False
        visited = set()

        def traverse(t):
            nonlocal has_complex_op
            assert t.handle is not None
            if t.handle.value in visited:
                return
            if isinstance(t.op, te.PlaceholderOp):
                pass
            elif isinstance(t.op, te.ComputeOp):
                has_complex_op = has_complex_op or any(
                    isinstance(e, tir.Reduce) for e in t.op.body)
                for x in t.op.input_tensors:
                    traverse(x)
            visited.add(t.handle.value)

        for t in args:
            traverse(t)
        if not has_complex_op:
            return None
        return create_prim_func(args)
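A minimal usage sketch of the filter above (assuming `te` is imported and `filter_func` were available at module scope, which is an illustrative simplification): a TE graph containing a reduction is kept, while a purely elementwise one is rejected.

A = te.placeholder((128, 128), name="A")
B = te.placeholder((128, 128), name="B")
k = te.reduce_axis((0, 128), "k")
matmul = te.compute((128, 128), lambda i, j: te.sum(A[i, k] * B[j, k], axis=k), name="C")
add = te.compute((128, 128), lambda i, j: A[i, j] + B[i, j], name="D")
assert filter_func([A, B, matmul]) is not None  # contains tir.Reduce -> PrimFunc returned
assert filter_func([A, B, add]) is None         # elementwise only -> filtered out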
Example #3
def test_cpu_matmul_relu():
    # pylint: disable=line-too-long
    expected = [
        [
            'b0 = sch.get_block(name="C", func_name="main")',
            'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")',
            "l1, l2, l3 = sch.get_loops(block=b0)",
            "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)",
            "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7])",
            "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)",
            "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15])",
            "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)",
            "l22, l23 = sch.split(loop=l3, factors=[v20, v21])",
            "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)",
            "b24, = sch.get_consumers(block=b0)",
            "sch.reverse_compute_at(block=b24, loop=l17, preserve_unit_loops=1)",
        ],
        [
            'b0 = sch.get_block(name="C", func_name="main")',
            'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")',
            "l1, l2, l3 = sch.get_loops(block=b0)",
            "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)",
            "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7])",
            "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)",
            "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15])",
            "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)",
            "l22, l23 = sch.split(loop=l3, factors=[v20, v21])",
            "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)",
            "b24, = sch.get_consumers(block=b0)",
            "sch.reverse_compute_at(block=b24, loop=l16, preserve_unit_loops=1)",
        ],
        [
            'b0 = sch.get_block(name="C", func_name="main")',
            'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")',
            "l1, l2, l3 = sch.get_loops(block=b0)",
            "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)",
            "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7])",
            "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)",
            "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15])",
            "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)",
            "l22, l23 = sch.split(loop=l3, factors=[v20, v21])",
            "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)",
        ],
    ]
    # pylint: enable=line-too-long
    target = Target("llvm")
    ctx = _create_context(
        create_prim_func(
            te_workload.matmul_relu(
                n=512,
                m=512,
                k=512,
            )
        ),
        target=target,
        rule=multi_level_tiling(target=target),
    )
    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
    assert len(spaces) == 3
    check_trace(spaces, expected)
Example #4
def STIR_schedule_nhwc_8h2w32c2w(outs, ins, output_layout: str,
                                 input_layout: str):
    """Schedule for input and output layout nhwc-8h2w32c2w"""
    func = te.create_prim_func([ins, outs])
    s = tir.Schedule(func)
    Sum = s.get_block("sum")
    Avg = s.get_block("avg")

    input_transform_fn = get_layout_transform_fn(input_layout)
    output_transform_fn = get_layout_transform_fn(output_layout)
    s.transform_layout(Sum, ("read", 0), input_transform_fn)
    s.transform_layout(Avg, ("write", 0), output_transform_fn)

    # Schedule 'Avg'
    n, h, w, c = s.get_loops(Avg)
    ho, hi = s.split(h, [None, 8])
    wo, wi = s.split(w, [None, 4])
    wio, wii = s.split(wi, [None, 2])
    co, ci = s.split(c, [None, 32])
    s.reorder(n, ho, wo, co, hi, wio, ci, wii)
    ci_wii = s.fuse(ci, wii)
    s.vectorize(ci_wii)

    # Schedule 'Sum'
    s.compute_at(Sum, wio)
    Sum_axis = s.get_loops(Sum)
    s.reorder(Sum_axis[-2], Sum_axis[-1], Sum_axis[-4], Sum_axis[-3])
    ci_wii = s.fuse(Sum_axis[-4], Sum_axis[-3])
    # s.vectorize(ci_wii) # Doesn't work
    return s
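For context, the split factors above (h by 8, w by 4 and then 2, c by 32) mirror the nhwc-8h2w32c2w blocking. The actual index map is returned by get_layout_transform_fn; the sketch below shows what such a mapping plausibly looks like and is an assumption for illustration, not the verified helper.

def nhwc_8h2w32c2w_sketch(n, h, w, c):
    # Hypothetical index map: outer (n, h//8, w//4, c//32) axes, then the
    # 8h x 2w x 32c x 2w inner tile after the axis separator.
    return [n, h // 8, w // 4, c // 32, te.AXIS_SEPARATOR,
            h % 8, (w % 4) // 2, c % 32, w % 2]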
Example #5
def STIR_schedule_n11c_1024c(outs, ins, output_layout: str, input_layout: str):
    """Schedule for output layout: n11c-1024c, input layout: nhwc-8h2w32c2w"""

    # NOTE: This function is a variation of STIR_schedule_nhwc_8h2w32c2w.
    # Most of that function's code comments apply to this function as well,
    # but are omitted for brevity.

    # NOTE: the "n11c-1024c" output layout is shorthand for this axis mapping:
    # [n, h, w, c // 1024, te.AXIS_SEPARATOR, c % 1024]
    func = te.create_prim_func([ins, outs])

    s = tir.Schedule(func)
    Max = s.get_block("max")

    input_transform_fn = get_layout_transform_fn(input_layout)
    output_transform_fn = get_layout_transform_fn(output_layout)
    s.transform_layout(Max, ("read", 0), input_transform_fn)
    s.transform_layout(Max, ("write", 0), output_transform_fn)

    (
        n,
        h,
        w,
        c,
        rh,
        rw,
    ) = s.get_loops(Max)
    co, ci = s.split(c, [None, 1024])
    # s.vectorize(ci)

    return s
Example #6
    def _relay_to_runtime(partition: relay.Function) -> tvm.runtime.Module:
        """Compile Relay functions to a runtime module using Tensor Expressions."""
        assert isinstance(partition, relay.Function)
        assert isinstance(partition.body, relay.Call)
        assert isinstance(partition.body.op, relay.Function)

        global_name = str(partition.attrs.global_symbol)
        comp_func = partition.body.op
        comp_name = comp_func.attrs["Composite"]
        assert comp_name in _LOWER_MAP
        assert isinstance(comp_func.body, relay.Call)

        op = comp_func.body
        inputs = []
        for i, param in enumerate(comp_func.params):
            inputs.append(
                te.placeholder(
                    param.checked_type.shape,
                    name=f"input_{i}",
                    dtype=param.checked_type.dtype,
                )
            )

        output = _LOWER_MAP[comp_name](op, inputs)
        prim_func = te.create_prim_func(inputs + [output])
        return tvm.build(prim_func, target=target, name=global_name)
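For illustration, a _LOWER_MAP entry could look like the sketch below (hypothetical names; the real map, its pattern keys, and the lowering bodies are defined elsewhere in this codegen). Each entry receives the composite function's body call plus the te.placeholder inputs and returns the output tensor that is then passed to te.create_prim_func.

from tvm import relay, te, topi  # assumed imports for this sketch

def _lower_example_add(call: relay.Call, inputs):
    # "example.add" is a made-up composite pattern name used only for illustration.
    return topi.add(inputs[0], inputs[1])

# _LOWER_MAP = {"example.add": _lower_example_add}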
Example #7
def relu_stir_schedule(Input, Output, input_layout, output_layout):
    """
    Schedule assumes the layout function to be bijective
    """
    if (input_layout != output_layout) or (output_layout != "nhwc-8h2w32c2w-2d"):
        raise RuntimeError(
            f"Unexpected input_layout, output_layout '{input_layout, output_layout}'"
        )
    relu_func = te.create_prim_func([Input, Output])
    sch = tir.Schedule(relu_func, debug_mask="all")
    block = sch.get_block("compute")
    sch.transform_layout(block, Input.name,
                         get_layout_transform_fn(input_layout))
    sch.transform_layout(block, Output.name,
                         get_layout_transform_fn(output_layout))

    n, h, w, c = sch.get_loops(block)
    h_o, h_i = sch.split(h, [None, 8])
    w_o, w_i = sch.split(w, [None, 4])
    c_o, c_i = sch.split(c, [None, 32])
    wio, wii = sch.split(w_i, [None, 2])

    sch.reorder(n, h_o, w_o, c_o, h_i, wio, c_i, wii)

    fused = sch.fuse(c_i, wii)
    sch.vectorize(fused)
    return sch
Example #8
def test_unique_name():
    A = te.placeholder((16, 16), name="A")
    B = te.compute((16, 16), lambda x, y: A[x, y] * 2, name="main")
    C = te.compute((16, 16), lambda x, y: B[x, y] + 1, name="main")
    func = te.create_prim_func([A, C])
    s = tir.Schedule(func, debug_mask="all")
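    # Both compute ops were given the name "main"; create_prim_func de-duplicates
    # the second block name to "main_1".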
    assert isinstance(s.get_sref(s.get_block("main")), tir.schedule.StmtSRef)
    assert isinstance(s.get_sref(s.get_block("main_1")), tir.schedule.StmtSRef)
Example #9
def test_int64_indices():
    n = te.var("n", "int64")
    A = te.placeholder((n,), name="A")
    B = te.compute(A.shape, lambda *i: A(*i) + 1, name="B")
    prim_func = te.create_prim_func([A, B])
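    # The loop derived from the symbolic extent `n` should carry int64 indices.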
    loop = prim_func.body.block.body
    assert loop.loop_var.dtype == "int64"
    assert loop.min.dtype == "int64"
    assert loop.extent.dtype == "int64"
Example #10
def test_select_simplify():
    placeholder = te.placeholder([1, 128, 10, 10, 4], dtype="float32")
    tensor = topi.nn.adaptive_pool(placeholder, [1, 1], "avg", "NCHW4c")
    result = te.create_prim_func([placeholder, tensor])
    script_func = result.script()
    # There should be no Select
    assert script_func.find("Select") == -1
    # There should be no undefined vars
    assert script_func.find("Var") == -1
Example #11
def test_unique_name_reduction_block():
    k1 = te.reduce_axis((0, 16), "k1")
    k2 = te.reduce_axis((0, 16), "k2")
    A = te.placeholder((16, 16), name="A")
    B = te.compute((16,), lambda i: te.sum(A[i, k1], axis=k1), name="sum")
    C = te.compute((), lambda: te.sum(B[k2], axis=k2), name="sum")
    func = te.create_prim_func([A, C])
    s = tir.Schedule(func, debug_mask="all")
    assert isinstance(s.get_sref(s.get_block("sum")), tir.schedule.StmtSRef)
    assert isinstance(s.get_sref(s.get_block("sum_1")), tir.schedule.StmtSRef)
Example #12
def test_get_auto_tensorize_mapping_info_conv2d():
    conv2d = create_prim_func(
        te_workload.conv2d_nhwc_f16(4, 16, 16, 64, 64, 3, 1, 1))
    check_index_map(
        conv2d,
        "conv2d_nhwc",
        WMMA_SYNC_16x16x16_f16f16f32_INTRIN,
        lambda n, h, w, c, rh, rw, rc:
        (n * 256 + h * 16 + w, c, rh * 192 + rw * 64 + rc),
    )
Example #13
def test_get_auto_tensorize_mapping_info_batch_matmul(b, m, n, k):
    matmul = create_prim_func(
        te_workload.batch_matmul_nkkm(b,
                                      m,
                                      n,
                                      k,
                                      in_dtype="float16",
                                      out_dtype="float32"))
    check_index_map(matmul, "Z", WMMA_SYNC_16x16x16_f16f16f32_INTRIN,
                    lambda b, m, n, k: (b, m, n, k))
Example #14
    def test_cast_fp32_fp16_slice(
        self,
        input_shape,
        dtype,
        input_layout,
        output_layout,
        transformed_input_np,
        transformed_expected_output_np,
        axis_sep,
        hexagon_session,
        working_scope,
    ):
        """
        Top level testing function for cast fp32 to fp16
        """
        if hexagon_session._launcher._serial_number != "simulator":
            pytest.skip(
                msg="Due to https://github.com/apache/tvm/issues/11957")

        target_hexagon = tvm.target.hexagon("v68")
        target = tvm.target.Target(target_hexagon, host=target_hexagon)
        cast_input = te.placeholder(input_shape, name="A", dtype=dtype)
        cast_output = sl.cast_f32_f16_compute(cast_input)
        cast_func = te.create_prim_func([cast_input, cast_output])
        tir_s = sl.cast_f32_f16_schedule(cast_func, input_layout,
                                         output_layout)
        input_data = allocate_hexagon_array(
            hexagon_session.device,
            data=transformed_input_np,
            axis_separators=axis_sep,
            mem_scope=working_scope,
        )
        output_data = allocate_hexagon_array(
            hexagon_session.device,
            tensor_shape=transformed_expected_output_np.shape,
            dtype=transformed_expected_output_np.dtype,
            axis_separators=axis_sep,
            mem_scope=working_scope,
        )
        with tvm.transform.PassContext(opt_level=3):
            tir_irm = tvm.lower(tir_s.mod, [cast_input, cast_output],
                                name="cast_f32_f16")
            runtime_module = tvm.build(tir_irm,
                                       target=target,
                                       name="cast_f32_f16")
        mod = hexagon_session.load_module(runtime_module)

        mod(input_data, output_data)
        output_np = output_data.numpy()
        tvm.testing.assert_allclose(
            output_np,
            transformed_expected_output_np,
            1e-3,
            1e-3,
        )
Example #15
def test_three_stage_gemm():
    N = K = M = 4096
    i_factors, j_factors, k_factors = [4, 8, 2, 4, 1], [1, 64, 2, 1, 2], [128, 2, 1]

    def is_ampere_or_newer():
        arch = tvm.contrib.nvcc.get_target_compute_version()
        major, _ = tvm.contrib.nvcc.parse_compute_version(arch)
        return major >= 8

    def index_map(i, j):
        return (
            i // 16,
            j // 16,
            *shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16),
        )

    workload = te.create_prim_func(te_workload.matmul_fp16(N, M, K))

    sch = mma_schedule(
        workload,
        16,
        "float16",
        False,
        i_factors,
        j_factors,
        k_factors,
        index_map,
        index_map,
        index_map,
        LDMATRIX_16x16_A_DYN_INTRIN,
        LDMATRIX_16x16_B_DYN_INTRIN,
        MMA_f16f16f32_INTRIN,
        MMA_fill_16x16_f32_INTRIN,
        MMA_store_16x16_f32_global_INTRIN,
        "shared.dyn",
    )

    k0 = sch.get_loops(sch.get_block("C_o_update"))[3]

    sch.annotate(k0, ann_key="software_pipeline_stage", ann_val=[0, 0, 3])
    sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2])

    if is_ampere_or_newer():
        f = tvm.build(sch.mod["main"], target="cuda")

        dev = tvm.device("cuda", 0)
        a_np = np.random.uniform(size=(N, K)).astype("float16")
        b_np = np.random.uniform(size=(K, M)).astype("float16")
        c_np = np.dot(a_np.astype("float32"), b_np.astype("float32"))
        a = tvm.nd.array(a_np, dev)
        b = tvm.nd.array(b_np, dev)
        c = tvm.nd.array(np.zeros((N, M), dtype="float32"), dev)
        f(a, b, c)
        tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
Example #16
    def _lower_relay_to_tir(
            self, relay_prim_func: relay.Function) -> tvm.tir.PrimFunc:
        """Lower a Relay primitive function to a S-TIR primitive function.

        Parameters
        ----------
        relay_prim_func : tvm.relay.Function
            The Relay function to lower.

        Returns
        -------
        out : tvm.tir.PrimFunc
            The lowered schedulable TensorIR primitive function.

        """
        def _get_tensors(te_cached_func):
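            # Walk the TE graph backwards from the outputs, collecting placeholder
            # tensors; create_prim_func expects the inputs followed by the outputs.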
            outputs = list(te_cached_func.outputs)
            stack = []
            visited = set()
            for output_ in outputs:
                if output_ not in visited:
                    visited.add(output_)
                    stack.append(output_)

            args = []
            while len(stack) != 0:
                tensor = stack.pop()
                if isinstance(tensor.op, tvm.te.tensor.PlaceholderOp):
                    args.append(tensor)
                elif isinstance(tensor.op, tvm.te.tensor.ComputeOp):
                    inputs = tensor.op.input_tensors
                    for input_ in inputs:
                        if input_ not in visited:
                            visited.add(input_)
                            stack.append(input_)

            return args + outputs

        lower_to_te = tvm._ffi.get_global_func("relay.backend.LowerToTE")
        te_cached_func = lower_to_te(relay_prim_func)
        x = _get_tensors(te_cached_func)
        tir_prim_func = te.create_prim_func(x)
        tir_prim_func = tir_prim_func.with_attr(
            "global_symbol", relay_prim_func.attrs["global_symbol"])

        compiler_attr = relay_prim_func.attrs["Compiler"]
        target = tvm.target.Target.current()
        if target.kind.name != compiler_attr:
            target = tvm.target.Target(compiler_attr)

        tir_prim_func = tir_prim_func.with_attr("target", target)
        tir_prim_func = tir_prim_func.with_attr("relay_attrs",
                                                relay_prim_func.attrs)
        return tir_prim_func
Example #17
def test_cuda_matmul():
    # pylint: disable=line-too-long
    expected = [
        [
            'b0 = sch.get_block(name="C", func_name="main")',
            'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")',
            "l1, l2, l3 = sch.get_loops(block=b0)",
            "v4, v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l1, n=5, max_innermost_factor=64)",
            "l9, l10, l11, l12, l13 = sch.split(loop=l1, factors=[v4, v5, v6, v7, v8])",
            "v14, v15, v16, v17, v18 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64)",
            "l19, l20, l21, l22, l23 = sch.split(loop=l2, factors=[v14, v15, v16, v17, v18])",
            "v24, v25, v26 = sch.sample_perfect_tile(loop=l3, n=3, max_innermost_factor=64)",
            "l27, l28, l29 = sch.split(loop=l3, factors=[v24, v25, v26])",
            "sch.reorder(l9, l19, l10, l20, l11, l21, l27, l28, l12, l22, l29, l13, l23)",
            "l30 = sch.fuse(l9, l19)",
            'sch.bind(loop=l30, thread_axis="blockIdx.x")',
            "l31 = sch.fuse(l10, l20)",
            'sch.bind(loop=l31, thread_axis="vthread.x")',
            "l32 = sch.fuse(l11, l21)",
            'sch.bind(loop=l32, thread_axis="threadIdx.x")',
            'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32)',
            'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024)',
            'b33 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")',
            "sch.reverse_compute_at(block=b33, loop=l32, preserve_unit_loops=1)",
            'b34 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")',
            "sch.compute_at(block=b34, loop=l27, preserve_unit_loops=1)",
            "l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)",
            "l41 = sch.fuse(l39, l40)",
            "v42 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
            'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v42)',
            'b43 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")',
            "sch.compute_at(block=b43, loop=l27, preserve_unit_loops=1)",
            "l44, l45, l46, l47, l48, l49 = sch.get_loops(block=b43)",
            "l50 = sch.fuse(l48, l49)",
            "v51 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
            'sch.annotate(block_or_loop=b43, ann_key="meta_schedule.cooperative_fetch", ann_val=v51)',
        ]
    ]
    # pylint: enable=line-too-long
    target = Target("cuda --max_threads_per_block=1024 --thread_warp_size=32", host="llvm")
    ctx = _create_context(
        create_prim_func(
            te_workload.matmul(
                n=512,
                m=512,
                k=512,
            )
        ),
        target=target,
        rule=multi_level_tiling(target=target),
    )
    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
    assert len(spaces) == 1
    check_trace(spaces, expected)
Example #18
    def test_te_extern_call(self, func, verify):
        ir_mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main"))
        prim_func = ir_mod["main"]

        input_tensors = create_input_tensors_for_primfunc(prim_func)
        output = te.extern_primfunc(input_tensors, prim_func)
        rt_prim_func = te.create_prim_func(tensors_from_extern_op(output, prim_func))
        tvm.ir.assert_structural_equal(tvm.lower(prim_func), tvm.lower(rt_prim_func))

        target = tvm.target.Target("llvm")
        func = tvm.build(rt_prim_func, target=target)
        verify(func)
Example #19
def tir_broadcast_schedule(
    out_m,
    input_a,
    input_b,
    output_layout: str,
    input_a_layout: str,
    input_b_layout: str,
    op_name: str,
):
    """Schedule for input and output layout nhwc-8h2w32c2w-2d considering broadcast"""
    func = te.create_prim_func([input_a, input_b, out_m])

    s = tir.Schedule(func)

    block_dict = {
        "add": "T_add",
        "subtract": "T_subtract",
        "multiply": "T_multiply"
    }

    block = s.get_block(block_dict[op_name])

    if input_a_layout == "nhwc-8h2w32c2w-2d":
        input_a_transformed_layout = get_layout_transform_fn(input_a_layout)
        s.transform_layout(block,
                           buffer=("read", 0),
                           index_map=input_a_transformed_layout)

    if input_b_layout == "nhwc-8h2w32c2w-2d":
        input_b_transformed_layout = get_layout_transform_fn(input_b_layout)
        s.transform_layout(block,
                           buffer=("read", 1),
                           index_map=input_b_transformed_layout)

    output_transformed_layout = get_layout_transform_fn(output_layout)
    s.transform_layout(block,
                       buffer=("write", 0),
                       index_map=output_transformed_layout)

    n, h, w, c = s.get_loops(block)

    h_o, h_i = s.split(h, [None, 8])
    w_o, w_i = s.split(w, [None, 4])
    c_o, c_i = s.split(c, [None, 32])
    wio, wii = s.split(w_i, [None, 2])

    s.reorder(n, h_o, w_o, c_o, h_i, wio, c_i, wii)

    fused = s.fuse(c_i, wii)
    s.vectorize(fused)

    return s
Example #20
def test_data_dependent_access():
    A = te.placeholder((10, ), name="A")
    B = te.placeholder((10, ), name="B", dtype="int32")
    C = te.compute((10, ), lambda i: A[B[i]])
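    # C gathers elements of A at the (data-dependent) indices stored in B.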

    func = te.create_prim_func([C, A, B])
    func = tvm.build(func)

    a_np = np.random.uniform(size=(10, )).astype(A.dtype)
    b_np = np.arange(10, dtype=B.dtype)
    c = tvm.nd.array(np.zeros(10, dtype=C.dtype))
    func(c, tvm.nd.array(a_np), tvm.nd.array(b_np))
    tvm.testing.assert_allclose(a_np[b_np], c.numpy())
Example #21
def test_tensor_attr():
    k = te.reduce_axis((0, 128), "k")
    A = te.placeholder((128, 128), name="A")
    B = te.placeholder((128, 128), name="B")
    C = te.compute(
        (128, 128),
        lambda x, y: te.sum(A[x, k] * B[y, k], axis=k),
        name="C",
        attrs={"layout_free_placeholders": [B]},
    )
    func = te.create_prim_func([A, B, C])
    rt_func = tvm.script.from_source(func.script())
    tvm.ir.assert_structural_equal(func, rt_func)
Example #22
    def test_argmax_slice(
        self,
        input_shape,
        dtype,
        input_layout,
        output_layout,
        in_axis,
        transformed_input_np,
        transformed_expected_output_np,
        in_axis_sep,
        out_axis_sep,
        hexagon_session,
        working_scope,
    ):
        """Top level testing function for argmax"""
        target_hexagon = tvm.target.hexagon("v69")
        target = tvm.target.Target(target_hexagon, host=target_hexagon)
        argmax_input = te.placeholder(input_shape, name="A", dtype=dtype)
        output = sl.argmax.argmax_compute(argmax_input, in_axis)
        argmax_func = te.create_prim_func([argmax_input, output])
        tir_s = sl.argmax_schedule(argmax_func, input_layout, output_layout)
        input_data = allocate_hexagon_array(
            hexagon_session.device,
            data=transformed_input_np,
            axis_separators=in_axis_sep,
            mem_scope=working_scope,
        )
        output_data = allocate_hexagon_array(
            hexagon_session.device,
            tensor_shape=transformed_expected_output_np.shape,
            dtype=transformed_expected_output_np.dtype,
            axis_separators=out_axis_sep,
            mem_scope=working_scope,
        )
        with tvm.transform.PassContext(opt_level=3,
                                       config={"tir.disable_assert": True}):
            tir_irm = tvm.lower(tir_s.mod, [argmax_input, output],
                                name="argmax")
            runtime_module = tvm.build(tir_irm, [argmax_input, output],
                                       target=target,
                                       name="argmax")
        mod = hexagon_session.load_module(runtime_module)

        mod(input_data, output_data)
        output_np = output_data.numpy()
        tvm.testing.assert_allclose(
            output_np,
            transformed_expected_output_np,
            1e-3,
            1e-3,
        )
Example #23
def stir_schedule_nhwc_8h2w32c2w(
    out: te.Tensor,
    inp: te.Tensor,
    out_layout: str,
    in_layout: str,
) -> tir.Schedule:
    """Schedule for input and output layout nhwc-8h2w32c2w"""
    reshape_func = te.create_prim_func([inp, out])
    sch = tir.Schedule(reshape_func, debug_mask="all")
    compute = sch.get_block("T_reshape")

    sch.transform_layout(compute, inp.name, get_layout_transform_fn(in_layout))
    sch.transform_layout(compute, out.name, get_layout_transform_fn(out_layout))
    return sch
Example #24
def test_constant():
    M = 11
    A = te.placeholder((M,), name="A")
    B = te.compute(tuple(), lambda: 2, name="B")
    # Manually craft ProducerLoad because `B[]` is not allowed.
    C = te.compute(
        (M,), lambda x: A[x] + tvm.tir.expr.ProducerLoad(B, []), name="C", tag="broadcast"
    )

    func = te.create_prim_func([C, A])
    func = tvm.build(func)
    a_np = np.random.uniform(size=(M,)).astype(A.dtype)
    c = tvm.nd.array(np.zeros(M, dtype=C.dtype))
    func(c, tvm.nd.array(a_np))
    tvm.testing.assert_allclose(a_np + 2, c.numpy())
Example #25
    def test_tanh(
        self,
        input_shape,
        dtype,
        input_layout,
        output_layout,
        transformed_input_np,
        transformed_expected_output_np,
        axis_sep,
        hexagon_session,
        working_scope,
    ):
        """Top Level testing function for tanh fp16 op"""

        target_hexagon = tvm.target.hexagon("v69")
        target = tvm.target.Target(target_hexagon, host=target_hexagon)
        A = te.placeholder(input_shape, name="A", dtype=dtype)
        M = sl.tanh_te_compute(A)
        tanhf16_func = te.create_prim_func([A, M])
        tir_s = sl.tanhf16_schedule(tanhf16_func, input_layout, output_layout)
        A_data = allocate_hexagon_array(
            hexagon_session.device,
            data=transformed_input_np,
            axis_separators=axis_sep,
            mem_scope=working_scope,
        )
        M_data = allocate_hexagon_array(
            hexagon_session.device,
            tensor_shape=transformed_expected_output_np.shape,
            dtype=transformed_expected_output_np.dtype,
            axis_separators=axis_sep,
            mem_scope=working_scope,
        )
        with tvm.transform.PassContext(opt_level=3):
            tir_irm = tvm.lower(tir_s.mod, [A, M], name="tanhf16")
            runtime_module = tvm.build(tir_irm, target=target, name="tanhf16")
        mod = hexagon_session.load_module(runtime_module)

        mod(A_data, M_data)
        output_np = M_data.numpy()
        tvm.testing.assert_allclose(
            output_np,
            transformed_expected_output_np,
            1e-3,
            1e-3,
        )
Example #26
def test_tensorize_dpa4():
    m, n, k = 128, 128, 128

    X = te.placeholder((m, k), name="X", dtype="int8")
    W = te.placeholder((n, k), name="W", dtype="int8")
    ak = te.reduce_axis((0, k), name="k")

    matmul = te.compute(
        (m, n),
        lambda i, j: te.sum(
            X[i, ak].astype("int32") * W[j, ak].astype("int32"),
            axis=ak,
        ),
        name="compute",
    )

    func = te.create_prim_func([X, W, matmul])

    for intrin in [AMDGPU_SDOT4_INTRIN, DP4A_INTRIN]:
        sch = tir.Schedule(func, debug_mask="all")
        block = sch.get_block("compute")
        i, j, k = sch.get_loops(block)

        by, ty, yi = sch.split(i, factors=sch.sample_perfect_tile(i, n=3))
        bx, tx, xi = sch.split(j, factors=sch.sample_perfect_tile(j, n=3))
        ko, ki = sch.split(k, [None, 4])
        ko, kt = sch.split(ko, factors=sch.sample_perfect_tile(ko, n=2))

        sch.reorder(by, bx, ty, tx, yi, xi)

        CC = sch.cache_write(block, 0, "local")
        sch.reverse_compute_at(CC, tx)

        def fetch_to_shared(block, idx):
            block_read = sch.cache_read(block, idx, "shared")
            sch.compute_at(block_read, ko, True)
            return block_read

        fetch_to_shared(block, 0)
        fetch_to_shared(block, 1)

        sch.decompose_reduction(block, ko)
        sch.tensorize(ki, intrin)

        verify_trace_roundtrip(sch=sch, mod=func)
Example #27
def test_tensor_layout_attr():
    k = te.reduce_axis((0, 128), "k")
    A = te.placeholder((128, 128), name="A")
    B = te.placeholder((128, 128), name="B")
    C = te.compute(
        (128, 128),
        lambda x, y: te.sum(A[x, k] * B[y, k], axis=k),
        name="C",
        attrs={"layout_free_placeholders": [B]},
    )
    D = te.compute(
        (128, 128),
        lambda x, y: C[x, y] + 1,
        name="D",
        attrs={"layout_free_placeholders": [C]},
    )
    func = te.create_prim_func([A, B, D])
    tvm.ir.assert_structural_equal(func, expected_layout_attr)
Example #28
def get_matmul_packed(m, n, k, lhs_type, int32_lanes):
    X = te.placeholder((m, k), name="X", dtype=lhs_type)
    packed_W = te.placeholder((n // int32_lanes, k // 4, int32_lanes, 4),
                              name="packedW",
                              dtype="int8")

    ak = te.reduce_axis((0, k), name="k")
    matmul = te.compute(
        (m, n),
        lambda i, j: te.sum(
            X[i, ak].astype("int32") * packed_W[tvm.tir.indexdiv(j, 16),
                                                tvm.tir.indexdiv(ak, 4), j %
                                                16, ak % 4].astype("int32"),
            axis=ak,
        ),
        name="compute",
    )

    return te.create_prim_func([X, packed_W, matmul])
Example #29
def test_get_tensorize_loop_mapping_matmul_mma():
    @T.prim_func
    def matmul_16x16x16xf16f16f16_desc(
        A: T.Buffer((16, 16), "float16", align=128, offset_factor=1),
        B: T.Buffer((16, 16), "float16", align=128, offset_factor=1),
        C: T.Buffer((16, 16), "float16", align=128, offset_factor=1),
    ) -> None:
        with T.block("root"):
            T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16])
            T.writes(C[0:16, 0:16])
            for i, j, k in T.grid(16, 16, 16):
                with T.block("update"):
                    vii, vjj, vkk = T.axis.remap("SSR", [i, j, k])
                    C[vii, vjj] = C[vii, vjj] + A[vii, vkk] * B[vjj, vkk]

    matmul = create_prim_func(te_workload.matmul_relu(
        n=512,
        m=512,
        k=512,
    ))

    s = Schedule(matmul)
    block = s.get_block("C")
    i0, i1, i2 = s.get_loops(block)
    desc_loops = collect_loops(matmul_16x16x16xf16f16f16_desc)

    for do_reorder in [False, True]:
        # Mapping should be invariant to the loop permutation
        if do_reorder:
            s.reorder(i2, i0, i1)

        info = get_tensorize_loop_mapping(s, block,
                                          matmul_16x16x16xf16f16f16_desc)
        assert info is not None
        desc_loop_to_sref = dict((v, k) for k, v in info.loop_map.items())

        for i in range(3):
            assert desc_loops[i] in desc_loop_to_sref

        assert s.get(desc_loop_to_sref[desc_loops[0]]) == s.get(i0)
        assert s.get(desc_loop_to_sref[desc_loops[1]]) == s.get(i1)
        assert s.get(desc_loop_to_sref[desc_loops[2]]) == s.get(i2)
Example #30
def test_rewrite_cooperative_fetch():
    mod = create_prim_func(te_workload.matmul(n=512, m=512, k=512))
    target = _target()
    ctx = _create_context(mod, target)

    sch = tir.Schedule(mod, debug_mask="all")
    # fmt: off
    # pylint: disable=line-too-long,invalid-name
    b0 = sch.get_block(name="C", func_name="main")
    b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")
    l2, l3, l4 = sch.get_loops(block=b0)
    v5, v6, v7, v8, v9 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64, decision=[1, 16, 1, 2, 16])
    l10, l11, l12, l13, l14 = sch.split(loop=l2, factors=[v5, v6, v7, v8, v9])
    v15, v16, v17, v18, v19 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[16, 1, 8, 2, 2])
    l20, l21, l22, l23, l24 = sch.split(loop=l3, factors=[v15, v16, v17, v18, v19])
    v25, v26, v27 = sch.sample_perfect_tile(loop=l4, n=3, max_innermost_factor=64, decision=[1, 16, 32])
    l28, l29, l30 = sch.split(loop=l4, factors=[v25, v26, v27])
    sch.reorder(l10, l20, l11, l21, l12, l22, l28, l29, l13, l23, l30, l14, l24)
    l31 = sch.fuse(l10, l20)
    sch.bind(loop=l31, thread_axis="blockIdx.x")
    l32 = sch.fuse(l11, l21)
    sch.bind(loop=l32, thread_axis="vthread.x")
    l33 = sch.fuse(l12, l22)
    sch.bind(loop=l33, thread_axis="threadIdx.x")
    b34 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared")
    sch.compute_at(block=b34, loop=l28, preserve_unit_loops=True)
    _, _, _, _, l39, l40 = sch.get_loops(block=b34)
    l41 = sch.fuse(l39, l40)
    _, v43 = sch.sample_perfect_tile(loop=l41, n=2, max_innermost_factor=4, decision=[262144, 1])
    sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v43)
    b44 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")
    sch.compute_at(block=b44, loop=l28, preserve_unit_loops=True)
    _, _, _, _, l49, l50 = sch.get_loops(block=b44)
    l51 = sch.fuse(l49, l50)
    _, v53 = sch.sample_perfect_tile(loop=l51, n=2, max_innermost_factor=4, decision=[8192, 2])
    sch.annotate(block_or_loop=b44, ann_key="meta_schedule.cooperative_fetch", ann_val=v53)
    sch.reverse_compute_at(block=b1, loop=l33, preserve_unit_loops=True)
    # pylint: enable=line-too-long,invalid-name
    # fmt: on
    sch.enter_postproc()
    assert ctx.postprocs[0].apply(sch)
    tvm.ir.assert_structural_equal(sch.mod, AfterRewrite0)