def test_error_reporting():
    try:
        te.create_prim_func(te_scan())
        assert False
    except TypeError as e:
        error_message = str(e)
        assert error_message.find("Unsupported Operation: ScanOp.") != -1
        return
    assert False

def filter_func(args) -> bool:
    from tvm.te import create_prim_func  # pylint: disable=import-outside-toplevel

    has_complex_op = False
    visited = set()

    def traverse(t):
        nonlocal has_complex_op
        assert t.handle is not None
        if t.handle.value in visited:
            return
        if isinstance(t.op, te.PlaceholderOp):
            pass
        elif isinstance(t.op, te.ComputeOp):
            has_complex_op = has_complex_op or any(isinstance(e, tir.Reduce) for e in t.op.body)
            for x in t.op.input_tensors:
                traverse(x)
        visited.add(t.handle.value)

    for t in args:
        traverse(t)
    if not has_complex_op:
        return None
    return create_prim_func(args)

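# A minimal, hypothetical usage sketch (not part of the original code): apply
# filter_func to the tensors of a small reduction workload. Because te.sum
# produces a tir.Reduce node in the compute body, the filter keeps this
# workload and returns a PrimFunc; a purely element-wise workload would be
# rejected with None. The helper name and tensor shapes below are illustrative
# only and assume the module-level `te` import used above.
def _example_filter_func_usage():
    k = te.reduce_axis((0, 64), "k")
    a = te.placeholder((64, 64), name="A")
    b = te.compute((64,), lambda i: te.sum(a[i, k], axis=k), name="B")
    assert filter_func([a, b]) is not None
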
def test_cpu_matmul_relu():
    # pylint: disable=line-too-long
    expected = [
        [
            'b0 = sch.get_block(name="C", func_name="main")',
            'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")',
            "l1, l2, l3 = sch.get_loops(block=b0)",
            "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)",
            "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7])",
            "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)",
            "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15])",
            "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)",
            "l22, l23 = sch.split(loop=l3, factors=[v20, v21])",
            "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)",
            "b24, = sch.get_consumers(block=b0)",
            "sch.reverse_compute_at(block=b24, loop=l17, preserve_unit_loops=1)",
        ],
        [
            'b0 = sch.get_block(name="C", func_name="main")',
            'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")',
            "l1, l2, l3 = sch.get_loops(block=b0)",
            "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)",
            "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7])",
            "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)",
            "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15])",
            "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)",
            "l22, l23 = sch.split(loop=l3, factors=[v20, v21])",
            "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)",
            "b24, = sch.get_consumers(block=b0)",
            "sch.reverse_compute_at(block=b24, loop=l16, preserve_unit_loops=1)",
        ],
        [
            'b0 = sch.get_block(name="C", func_name="main")',
            'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")',
            "l1, l2, l3 = sch.get_loops(block=b0)",
            "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)",
            "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7])",
            "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)",
            "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15])",
            "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)",
            "l22, l23 = sch.split(loop=l3, factors=[v20, v21])",
            "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)",
        ],
    ]
    # pylint: enable=line-too-long
    target = Target("llvm")
    ctx = _create_context(
        create_prim_func(
            te_workload.matmul_relu(
                n=512,
                m=512,
                k=512,
            )
        ),
        target=target,
        rule=multi_level_tiling(target=target),
    )
    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
    assert len(spaces) == 3
    check_trace(spaces, expected)

def STIR_schedule_nhwc_8h2w32c2w(outs, ins, output_layout: str, input_layout: str):
    """Schedule for input and output layout nhwc-8h2w32c2w"""
    func = te.create_prim_func([ins, outs])
    s = tir.Schedule(func)
    Sum = s.get_block("sum")
    Avg = s.get_block("avg")

    input_transform_fn = get_layout_transform_fn(input_layout)
    output_transform_fn = get_layout_transform_fn(output_layout)
    s.transform_layout(Sum, ("read", 0), input_transform_fn)
    s.transform_layout(Avg, ("write", 0), output_transform_fn)

    # Schedule 'Avg'
    n, h, w, c = s.get_loops(Avg)
    ho, hi = s.split(h, [None, 8])
    wo, wi = s.split(w, [None, 4])
    wio, wii = s.split(wi, [None, 2])
    co, ci = s.split(c, [None, 32])
    s.reorder(n, ho, wo, co, hi, wio, ci, wii)
    ci_wii = s.fuse(ci, wii)
    s.vectorize(ci_wii)

    # Schedule 'Sum'
    s.compute_at(Sum, wio)
    Sum_axis = s.get_loops(Sum)
    s.reorder(Sum_axis[-2], Sum_axis[-1], Sum_axis[-4], Sum_axis[-3])
    ci_wii = s.fuse(Sum_axis[-4], Sum_axis[-3])
    # s.vectorize(ci_wii)  # Doesn't work
    return s

def STIR_schedule_n11c_1024c(outs, ins, output_layout: str, input_layout: str):
    """Schedule for output layout: n11c-1024c, input layout: nhwc-8h2w32c2w"""

    # NOTE: This function is a variation of the STIR_schedule_nhwc_8h2w32c2w
    # function. Most of that function's code comments apply to this function
    # as well, but are omitted for brevity.

    # NOTE: the "n11c-1024c" output layout is shorthand for this axis mapping:
    # [n, h, w, c // 1024, te.AXIS_SEPARATOR, c % 1024]
    func = te.create_prim_func([ins, outs])
    s = tir.Schedule(func)

    Max = s.get_block("max")

    input_transform_fn = get_layout_transform_fn(input_layout)
    output_transform_fn = get_layout_transform_fn(output_layout)
    s.transform_layout(Max, ("read", 0), input_transform_fn)
    s.transform_layout(Max, ("write", 0), output_transform_fn)

    (
        n,
        h,
        w,
        c,
        rh,
        rw,
    ) = s.get_loops(Max)
    co, ci = s.split(c, [None, 1024])
    # s.vectorize(ci)

    return s

def _relay_to_runtime(partition: relay.Function) -> tvm.runtime.Module:
    """Compile Relay functions to a runtime module using Tensor Expressions."""
    assert isinstance(partition, relay.Function)
    assert isinstance(partition.body, relay.Call)
    assert isinstance(partition.body.op, relay.Function)

    global_name = str(partition.attrs.global_symbol)
    comp_func = partition.body.op
    comp_name = comp_func.attrs["Composite"]
    assert comp_name in _LOWER_MAP
    assert isinstance(comp_func.body, relay.Call)

    op = comp_func.body
    inputs = []
    for i, param in enumerate(comp_func.params):
        inputs.append(
            te.placeholder(
                param.checked_type.shape,
                name=f"input_{i}",
                dtype=param.checked_type.dtype,
            )
        )

    output = _LOWER_MAP[comp_name](op, inputs)
    prim_func = te.create_prim_func(inputs + [output])
    return tvm.build(prim_func, target=target, name=global_name)

def relu_stir_schedule(Input, Output, input_layout, output_layout):
    """
    Schedule assumes the layout function to be bijective
    """
    if (input_layout != output_layout) or (output_layout != "nhwc-8h2w32c2w-2d"):
        raise RuntimeError(
            f"Unexpected input_layout, output_layout '{input_layout, output_layout}'"
        )
    relu_func = te.create_prim_func([Input, Output])
    sch = tir.Schedule(relu_func, debug_mask="all")
    block = sch.get_block("compute")

    sch.transform_layout(block, Input.name, get_layout_transform_fn(input_layout))
    sch.transform_layout(block, Output.name, get_layout_transform_fn(output_layout))

    n, h, w, c = sch.get_loops(block)
    h_o, h_i = sch.split(h, [None, 8])
    w_o, w_i = sch.split(w, [None, 4])
    c_o, c_i = sch.split(c, [None, 32])
    wio, wii = sch.split(w_i, [None, 2])

    sch.reorder(n, h_o, w_o, c_o, h_i, wio, c_i, wii)

    fused = sch.fuse(c_i, wii)
    sch.vectorize(fused)
    return sch

def test_unique_name():
    A = te.placeholder((16, 16), name="A")
    B = te.compute((16, 16), lambda x, y: A[x, y] * 2, name="main")
    C = te.compute((16, 16), lambda x, y: B[x, y] + 1, name="main")
    func = te.create_prim_func([A, C])
    s = tir.Schedule(func, debug_mask="all")
    assert isinstance(s.get_sref(s.get_block("main")), tir.schedule.StmtSRef)
    assert isinstance(s.get_sref(s.get_block("main_1")), tir.schedule.StmtSRef)

def test_int64_indices():
    n = te.var("n", "int64")
    A = te.placeholder((n,), name="A")
    B = te.compute(A.shape, lambda *i: A(*i) + 1, name="B")
    prim_func = te.create_prim_func([A, B])
    loop = prim_func.body.block.body
    assert loop.loop_var.dtype == "int64"
    assert loop.min.dtype == "int64"
    assert loop.extent.dtype == "int64"

def test_select_simplify():
    placeholder = te.placeholder([1, 128, 10, 10, 4], dtype="float32")
    tensor = topi.nn.adaptive_pool(placeholder, [1, 1], "avg", "NCHW4c")
    result = te.create_prim_func([placeholder, tensor])
    script_func = result.script()
    # There should be no Select
    assert script_func.find("Select") == -1
    # There should be no undefined vars
    assert script_func.find("Var") == -1

def test_unique_name_reduction_block():
    k1 = te.reduce_axis((0, 16), "k1")
    k2 = te.reduce_axis((0, 16), "k2")
    A = te.placeholder((16, 16), name="A")
    B = te.compute((16,), lambda i: te.sum(A[i, k1], axis=k1), name="sum")
    C = te.compute((), lambda: te.sum(B[k2], axis=k2), name="sum")
    func = te.create_prim_func([A, C])
    s = tir.Schedule(func, debug_mask="all")
    assert isinstance(s.get_sref(s.get_block("sum")), tir.schedule.StmtSRef)
    assert isinstance(s.get_sref(s.get_block("sum_1")), tir.schedule.StmtSRef)

def test_get_auto_tensorize_mapping_info_conv2d():
    conv2d = create_prim_func(te_workload.conv2d_nhwc_f16(4, 16, 16, 64, 64, 3, 1, 1))
    check_index_map(
        conv2d,
        "conv2d_nhwc",
        WMMA_SYNC_16x16x16_f16f16f32_INTRIN,
        lambda n, h, w, c, rh, rw, rc: (n * 256 + h * 16 + w, c, rh * 192 + rw * 64 + rc),
    )

def test_get_auto_tensorize_mapping_info_batch_matmul(b, m, n, k):
    matmul = create_prim_func(
        te_workload.batch_matmul_nkkm(b, m, n, k, in_dtype="float16", out_dtype="float32")
    )
    check_index_map(
        matmul, "Z", WMMA_SYNC_16x16x16_f16f16f32_INTRIN, lambda b, m, n, k: (b, m, n, k)
    )

def test_cast_fp32_fp16_slice(
    self,
    input_shape,
    dtype,
    input_layout,
    output_layout,
    transformed_input_np,
    transformed_expected_output_np,
    axis_sep,
    hexagon_session,
    working_scope,
):
    """
    Top level testing function for cast fp32 to fp16
    """
    if hexagon_session._launcher._serial_number != "simulator":
        pytest.skip(msg="Due to https://github.com/apache/tvm/issues/11957")

    target_hexagon = tvm.target.hexagon("v68")
    target = tvm.target.Target(target_hexagon, host=target_hexagon)
    cast_input = te.placeholder(input_shape, name="A", dtype=dtype)
    cast_output = sl.cast_f32_f16_compute(cast_input)
    cast_func = te.create_prim_func([cast_input, cast_output])
    tir_s = sl.cast_f32_f16_schedule(cast_func, input_layout, output_layout)

    input_data = allocate_hexagon_array(
        hexagon_session.device,
        data=transformed_input_np,
        axis_separators=axis_sep,
        mem_scope=working_scope,
    )
    output_data = allocate_hexagon_array(
        hexagon_session.device,
        tensor_shape=transformed_expected_output_np.shape,
        dtype=transformed_expected_output_np.dtype,
        axis_separators=axis_sep,
        mem_scope=working_scope,
    )
    with tvm.transform.PassContext(opt_level=3):
        tir_irm = tvm.lower(tir_s.mod, [cast_input, cast_output], name="cast_f32_f16")
        runtime_module = tvm.build(tir_irm, target=target, name="cast_f32_f16")
    mod = hexagon_session.load_module(runtime_module)

    mod(input_data, output_data)
    output_np = output_data.numpy()
    tvm.testing.assert_allclose(
        output_np,
        transformed_expected_output_np,
        1e-3,
        1e-3,
    )

def test_three_stage_gemm():
    N = K = M = 4096
    i_factors, j_factors, k_factors = [4, 8, 2, 4, 1], [1, 64, 2, 1, 2], [128, 2, 1]

    def is_ampere_or_newer():
        arch = tvm.contrib.nvcc.get_target_compute_version()
        major, _ = tvm.contrib.nvcc.parse_compute_version(arch)
        return major >= 8

    def index_map(i, j):
        return (
            i // 16,
            j // 16,
            *shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16),
        )

    workload = te.create_prim_func(te_workload.matmul_fp16(N, M, K))

    sch = mma_schedule(
        workload,
        16,
        "float16",
        False,
        i_factors,
        j_factors,
        k_factors,
        index_map,
        index_map,
        index_map,
        LDMATRIX_16x16_A_DYN_INTRIN,
        LDMATRIX_16x16_B_DYN_INTRIN,
        MMA_f16f16f32_INTRIN,
        MMA_fill_16x16_f32_INTRIN,
        MMA_store_16x16_f32_global_INTRIN,
        "shared.dyn",
    )

    k0 = sch.get_loops(sch.get_block("C_o_update"))[3]

    sch.annotate(k0, ann_key="software_pipeline_stage", ann_val=[0, 0, 3])
    sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2])

    if is_ampere_or_newer():
        f = tvm.build(sch.mod["main"], target="cuda")

        dev = tvm.device("cuda", 0)
        a_np = np.random.uniform(size=(N, K)).astype("float16")
        b_np = np.random.uniform(size=(K, M)).astype("float16")
        c_np = np.dot(a_np.astype("float32"), b_np.astype("float32"))

        a = tvm.nd.array(a_np, dev)
        b = tvm.nd.array(b_np, dev)
        c = tvm.nd.array(np.zeros((N, M), dtype="float32"), dev)

        f(a, b, c)
        tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)

def _lower_relay_to_tir(self, relay_prim_func: relay.Function) -> tvm.tir.PrimFunc:
    """Lower a Relay primitive function to an S-TIR primitive function.

    Parameters
    ----------
    relay_prim_func : tvm.relay.Function
        The Relay function to lower.

    Returns
    -------
    out : tvm.tir.PrimFunc
        The lowered schedulable TensorIR primitive function.

    """

    def _get_tensors(te_cached_func):
        outputs = list(te_cached_func.outputs)
        stack = []
        visited = set()
        for output_ in outputs:
            if output_ not in visited:
                visited.add(output_)
                stack.append(output_)

        args = []
        while len(stack) != 0:
            tensor = stack.pop()
            if isinstance(tensor.op, tvm.te.tensor.PlaceholderOp):
                args.append(tensor)
            elif isinstance(tensor.op, tvm.te.tensor.ComputeOp):
                inputs = tensor.op.input_tensors
                for input_ in inputs:
                    if input_ not in visited:
                        visited.add(input_)
                        stack.append(input_)

        return args + outputs

    lower_to_te = tvm._ffi.get_global_func("relay.backend.LowerToTE")
    te_cached_func = lower_to_te(relay_prim_func)
    x = _get_tensors(te_cached_func)
    tir_prim_func = te.create_prim_func(x)
    tir_prim_func = tir_prim_func.with_attr("global_symbol", relay_prim_func.attrs["global_symbol"])

    compiler_attr = relay_prim_func.attrs["Compiler"]
    target = tvm.target.Target.current()
    if target.kind.name != compiler_attr:
        target = tvm.target.Target(compiler_attr)

    tir_prim_func = tir_prim_func.with_attr("target", target)
    tir_prim_func = tir_prim_func.with_attr("relay_attrs", relay_prim_func.attrs)
    return tir_prim_func

def test_cuda_matmul():
    # pylint: disable=line-too-long
    expected = [
        [
            'b0 = sch.get_block(name="C", func_name="main")',
            'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")',
            "l1, l2, l3 = sch.get_loops(block=b0)",
            "v4, v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l1, n=5, max_innermost_factor=64)",
            "l9, l10, l11, l12, l13 = sch.split(loop=l1, factors=[v4, v5, v6, v7, v8])",
            "v14, v15, v16, v17, v18 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64)",
            "l19, l20, l21, l22, l23 = sch.split(loop=l2, factors=[v14, v15, v16, v17, v18])",
            "v24, v25, v26 = sch.sample_perfect_tile(loop=l3, n=3, max_innermost_factor=64)",
            "l27, l28, l29 = sch.split(loop=l3, factors=[v24, v25, v26])",
            "sch.reorder(l9, l19, l10, l20, l11, l21, l27, l28, l12, l22, l29, l13, l23)",
            "l30 = sch.fuse(l9, l19)",
            'sch.bind(loop=l30, thread_axis="blockIdx.x")',
            "l31 = sch.fuse(l10, l20)",
            'sch.bind(loop=l31, thread_axis="vthread.x")',
            "l32 = sch.fuse(l11, l21)",
            'sch.bind(loop=l32, thread_axis="threadIdx.x")',
            'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32)',
            'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024)',
            'b33 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")',
            "sch.reverse_compute_at(block=b33, loop=l32, preserve_unit_loops=1)",
            'b34 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")',
            "sch.compute_at(block=b34, loop=l27, preserve_unit_loops=1)",
            "l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)",
            "l41 = sch.fuse(l39, l40)",
            "v42 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
            'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v42)',
            'b43 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")',
            "sch.compute_at(block=b43, loop=l27, preserve_unit_loops=1)",
            "l44, l45, l46, l47, l48, l49 = sch.get_loops(block=b43)",
            "l50 = sch.fuse(l48, l49)",
            "v51 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
            'sch.annotate(block_or_loop=b43, ann_key="meta_schedule.cooperative_fetch", ann_val=v51)',
        ]
    ]
    # pylint: enable=line-too-long
    target = Target("cuda --max_threads_per_block=1024 --thread_warp_size=32", host="llvm")
    ctx = _create_context(
        create_prim_func(
            te_workload.matmul(
                n=512,
                m=512,
                k=512,
            )
        ),
        target=target,
        rule=multi_level_tiling(target=target),
    )
    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
    assert len(spaces) == 1
    check_trace(spaces, expected)

def test_te_extern_call(self, func, verify):
    ir_mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main"))
    prim_func = ir_mod["main"]

    input_tensors = create_input_tensors_for_primfunc(prim_func)
    output = te.extern_primfunc(input_tensors, prim_func)
    rt_prim_func = te.create_prim_func(tensors_from_extern_op(output, prim_func))
    tvm.ir.assert_structural_equal(tvm.lower(prim_func), tvm.lower(rt_prim_func))

    target = tvm.target.Target("llvm")
    func = tvm.build(rt_prim_func, target=target)
    verify(func)

def tir_broadcast_schedule(
    out_m,
    input_a,
    input_b,
    output_layout: str,
    input_a_layout: str,
    input_b_layout: str,
    op_name: str,
):
    """Schedule for input and output layout nhwc-8h2w32c2w-2d considering broadcast"""
    func = te.create_prim_func([input_a, input_b, out_m])

    s = tir.Schedule(func)

    block_dict = {"add": "T_add", "subtract": "T_subtract", "multiply": "T_multiply"}

    block = s.get_block(block_dict[op_name])

    if input_a_layout == "nhwc-8h2w32c2w-2d":
        input_a_transformed_layout = get_layout_transform_fn(input_a_layout)
        s.transform_layout(block, buffer=("read", 0), index_map=input_a_transformed_layout)

    if input_b_layout == "nhwc-8h2w32c2w-2d":
        input_b_transformed_layout = get_layout_transform_fn(input_b_layout)
        s.transform_layout(block, buffer=("read", 1), index_map=input_b_transformed_layout)

    output_transformed_layout = get_layout_transform_fn(output_layout)
    s.transform_layout(block, buffer=("write", 0), index_map=output_transformed_layout)

    n, h, w, c = s.get_loops(block)

    h_o, h_i = s.split(h, [None, 8])
    w_o, w_i = s.split(w, [None, 4])
    c_o, c_i = s.split(c, [None, 32])
    wio, wii = s.split(w_i, [None, 2])

    s.reorder(n, h_o, w_o, c_o, h_i, wio, c_i, wii)

    fused = s.fuse(c_i, wii)
    s.vectorize(fused)

    return s

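# A brief, hypothetical usage sketch (assumed, not part of the original code):
# given TE tensors for two broadcastable inputs and their elementwise-add output
# (created elsewhere with the corresponding compute definition), build the
# schedule for the nhwc-8h2w32c2w-2d layouts and inspect the transformed TIR.
# The helper name and argument values below are illustrative only.
def _example_tir_broadcast_schedule_usage(input_a, input_b, out_m):
    s = tir_broadcast_schedule(
        out_m,
        input_a,
        input_b,
        output_layout="nhwc-8h2w32c2w-2d",
        input_a_layout="nhwc-8h2w32c2w-2d",
        input_b_layout="nhwc-8h2w32c2w-2d",
        op_name="add",
    )
    print(s.mod.script())
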
def test_data_dependent_access():
    A = te.placeholder((10,), name="A")
    B = te.placeholder((10,), name="B", dtype="int32")
    C = te.compute((10,), lambda i: A[B[i]])

    func = te.create_prim_func([C, A, B])
    func = tvm.build(func)

    a_np = np.random.uniform(size=(10,)).astype(A.dtype)
    b_np = np.arange(10, dtype=B.dtype)
    c = tvm.nd.array(np.zeros(10, dtype=C.dtype))
    func(c, tvm.nd.array(a_np), tvm.nd.array(b_np))

    tvm.testing.assert_allclose(a_np[b_np], c.numpy())

def test_tensor_attr():
    k = te.reduce_axis((0, 128), "k")
    A = te.placeholder((128, 128), name="A")
    B = te.placeholder((128, 128), name="B")
    C = te.compute(
        (128, 128),
        lambda x, y: te.sum(A[x, k] * B[y, k], axis=k),
        name="C",
        attrs={"layout_free_placeholders": [B]},
    )
    func = te.create_prim_func([A, B, C])
    rt_func = tvm.script.from_source(func.script())
    tvm.ir.assert_structural_equal(func, rt_func)

def test_argmax_slice(
    self,
    input_shape,
    dtype,
    input_layout,
    output_layout,
    in_axis,
    transformed_input_np,
    transformed_expected_output_np,
    in_axis_sep,
    out_axis_sep,
    hexagon_session,
    working_scope,
):
    """Top level testing function for argmax"""
    target_hexagon = tvm.target.hexagon("v69")
    target = tvm.target.Target(target_hexagon, host=target_hexagon)
    argmax_input = te.placeholder(input_shape, name="A", dtype=dtype)
    output = sl.argmax.argmax_compute(argmax_input, in_axis)
    argmax_func = te.create_prim_func([argmax_input, output])
    tir_s = sl.argmax_schedule(argmax_func, input_layout, output_layout)
    input_data = allocate_hexagon_array(
        hexagon_session.device,
        data=transformed_input_np,
        axis_separators=in_axis_sep,
        mem_scope=working_scope,
    )
    output_data = allocate_hexagon_array(
        hexagon_session.device,
        tensor_shape=transformed_expected_output_np.shape,
        dtype=transformed_expected_output_np.dtype,
        axis_separators=out_axis_sep,
        mem_scope=working_scope,
    )
    with tvm.transform.PassContext(opt_level=3, config={"tir.disable_assert": True}):
        tir_irm = tvm.lower(tir_s.mod, [argmax_input, output], name="argmax")
        runtime_module = tvm.build(tir_irm, [argmax_input, output], target=target, name="argmax")
    mod = hexagon_session.load_module(runtime_module)
    mod(input_data, output_data)
    output_np = output_data.numpy()
    tvm.testing.assert_allclose(
        output_np,
        transformed_expected_output_np,
        1e-3,
        1e-3,
    )

def stir_schedule_nhwc_8h2w32c2w(
    out: te.Tensor,
    inp: te.Tensor,
    out_layout: str,
    in_layout: str,
) -> tir.Schedule:
    """Schedule for input and output layout nhwc-8h2w32c2w"""
    reshape_func = te.create_prim_func([inp, out])
    sch = tir.Schedule(reshape_func, debug_mask="all")
    compute = sch.get_block("T_reshape")

    sch.transform_layout(compute, inp.name, get_layout_transform_fn(in_layout))
    sch.transform_layout(compute, out.name, get_layout_transform_fn(out_layout))
    return sch

def test_constant():
    M = 11
    A = te.placeholder((M,), name="A")
    B = te.compute(tuple(), lambda: 2, name="B")
    # Manually craft ProducerLoad because `B[]` is not allowed.
    C = te.compute(
        (M,), lambda x: A[x] + tvm.tir.expr.ProducerLoad(B, []), name="C", tag="broadcast"
    )

    func = te.create_prim_func([C, A])
    func = tvm.build(func)
    a_np = np.random.uniform(size=(M,)).astype(A.dtype)
    c = tvm.nd.array(np.zeros(M, dtype=C.dtype))
    x = func(c, tvm.nd.array(a_np))
    tvm.testing.assert_allclose(a_np + 2, c.numpy())

def test_tanh(
    self,
    input_shape,
    dtype,
    input_layout,
    output_layout,
    transformed_input_np,
    transformed_expected_output_np,
    axis_sep,
    hexagon_session,
    working_scope,
):
    """Top Level testing function for tanh fp16 op"""
    target_hexagon = tvm.target.hexagon("v69")
    target = tvm.target.Target(target_hexagon, host=target_hexagon)
    A = te.placeholder(input_shape, name="A", dtype=dtype)
    M = sl.tanh_te_compute(A)
    tanhf16_func = te.create_prim_func([A, M])
    tir_s = sl.tanhf16_schedule(tanhf16_func, input_layout, output_layout)
    A_data = allocate_hexagon_array(
        hexagon_session.device,
        data=transformed_input_np,
        axis_separators=axis_sep,
        mem_scope=working_scope,
    )
    M_data = allocate_hexagon_array(
        hexagon_session.device,
        tensor_shape=transformed_expected_output_np.shape,
        dtype=transformed_expected_output_np.dtype,
        axis_separators=axis_sep,
        mem_scope=working_scope,
    )
    with tvm.transform.PassContext(opt_level=3):
        tir_irm = tvm.lower(tir_s.mod, [A, M], name="tanhf16")
        runtime_module = tvm.build(tir_irm, target=target, name="tanhf16")
    mod = hexagon_session.load_module(runtime_module)
    mod(A_data, M_data)
    output_np = M_data.numpy()
    tvm.testing.assert_allclose(
        output_np,
        transformed_expected_output_np,
        1e-3,
        1e-3,
    )

def test_tensorize_dpa4():
    m, n, k = 128, 128, 128

    X = te.placeholder((m, k), name="X", dtype="int8")
    W = te.placeholder((n, k), name="W", dtype="int8")
    ak = te.reduce_axis((0, k), name="k")

    matmul = te.compute(
        (m, n),
        lambda i, j: te.sum(
            X[i, ak].astype("int32") * W[j, ak].astype("int32"),
            axis=ak,
        ),
        name="compute",
    )

    func = te.create_prim_func([X, W, matmul])

    for intrin in [AMDGPU_SDOT4_INTRIN, DP4A_INTRIN]:
        sch = tir.Schedule(func, debug_mask="all")
        block = sch.get_block("compute")
        i, j, k = sch.get_loops(block)

        by, ty, yi = sch.split(i, factors=sch.sample_perfect_tile(i, n=3))
        bx, tx, xi = sch.split(j, factors=sch.sample_perfect_tile(j, n=3))
        ko, ki = sch.split(k, [None, 4])
        ko, kt = sch.split(ko, factors=sch.sample_perfect_tile(ko, n=2))

        sch.reorder(by, bx, ty, tx, yi, xi)

        CC = sch.cache_write(block, 0, "local")
        sch.reverse_compute_at(CC, tx)

        def fetch_to_shared(block, idx):
            block_read = sch.cache_read(block, idx, "shared")
            sch.compute_at(block_read, ko, True)
            return block_read

        fetch_to_shared(block, 0)
        fetch_to_shared(block, 1)

        sch.decompose_reduction(block, ko)
        sch.tensorize(ki, intrin)

        verify_trace_roundtrip(sch=sch, mod=func)

def test_tensor_layout_attr():
    k = te.reduce_axis((0, 128), "k")
    A = te.placeholder((128, 128), name="A")
    B = te.placeholder((128, 128), name="B")
    C = te.compute(
        (128, 128),
        lambda x, y: te.sum(A[x, k] * B[y, k], axis=k),
        name="C",
        attrs={"layout_free_placeholders": [B]},
    )
    D = te.compute(
        (128, 128),
        lambda x, y: C[x, y] + 1,
        name="D",
        attrs={"layout_free_placeholders": [C]},
    )
    func = te.create_prim_func([A, B, D])
    tvm.ir.assert_structural_equal(func, expected_layout_attr)

def get_matmul_packed(m, n, k, lhs_type, int32_lanes):
    X = te.placeholder((m, k), name="X", dtype=lhs_type)
    packed_W = te.placeholder(
        (n // int32_lanes, k // 4, int32_lanes, 4), name="packedW", dtype="int8"
    )

    ak = te.reduce_axis((0, k), name="k")
    matmul = te.compute(
        (m, n),
        lambda i, j: te.sum(
            X[i, ak].astype("int32")
            * packed_W[
                tvm.tir.indexdiv(j, 16), tvm.tir.indexdiv(ak, 4), j % 16, ak % 4
            ].astype("int32"),
            axis=ak,
        ),
        name="compute",
    )

    return te.create_prim_func([X, packed_W, matmul])

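# A minimal usage sketch (assumed, not part of the original code): build the
# packed int8 matmul PrimFunc for plain LLVM and run it on random data, with no
# scheduling or tensorization applied. The shapes, the "uint8" lhs_type, and
# int32_lanes=16 below are illustrative assumptions, as are the `np`/`tvm`
# module-level imports.
def _example_matmul_packed_usage():
    m, n, k = 128, 128, 128
    func = get_matmul_packed(m, n, k, "uint8", 16)
    built = tvm.build(func, target="llvm")
    dev = tvm.cpu(0)
    x = tvm.nd.array(np.random.randint(0, 8, size=(m, k)).astype("uint8"), dev)
    w = tvm.nd.array(np.random.randint(0, 8, size=(n // 16, k // 4, 16, 4)).astype("int8"), dev)
    out = tvm.nd.array(np.zeros((m, n), dtype="int32"), dev)
    built(x, w, out)
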
def test_get_tensorize_loop_mapping_matmul_mma():
    @T.prim_func
    def matmul_16x16x16xf16f16f16_desc(
        A: T.Buffer((16, 16), "float16", align=128, offset_factor=1),
        B: T.Buffer((16, 16), "float16", align=128, offset_factor=1),
        C: T.Buffer((16, 16), "float16", align=128, offset_factor=1),
    ) -> None:
        with T.block("root"):
            T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16])
            T.writes(C[0:16, 0:16])
            for i, j, k in T.grid(16, 16, 16):
                with T.block("update"):
                    vii, vjj, vkk = T.axis.remap("SSR", [i, j, k])
                    C[vii, vjj] = C[vii, vjj] + A[vii, vkk] * B[vjj, vkk]

    matmul = create_prim_func(
        te_workload.matmul_relu(
            n=512,
            m=512,
            k=512,
        )
    )

    s = Schedule(matmul)
    block = s.get_block("C")
    i0, i1, i2 = s.get_loops(block)
    desc_loops = collect_loops(matmul_16x16x16xf16f16f16_desc)

    for do_reorder in [False, True]:
        # Mapping should be invariant to the loop permutation
        if do_reorder:
            s.reorder(i2, i0, i1)

        info = get_tensorize_loop_mapping(s, block, matmul_16x16x16xf16f16f16_desc)

        assert info is not None
        desc_loop_to_sref = dict((v, k) for k, v in info.loop_map.items())

        for i in range(3):
            assert desc_loops[i] in desc_loop_to_sref

        assert s.get(desc_loop_to_sref[desc_loops[0]]) == s.get(i0)
        assert s.get(desc_loop_to_sref[desc_loops[1]]) == s.get(i1)
        assert s.get(desc_loop_to_sref[desc_loops[2]]) == s.get(i2)

def test_rewrite_cooperative_fetch():
    mod = create_prim_func(te_workload.matmul(n=512, m=512, k=512))
    target = _target()
    ctx = _create_context(mod, target)

    sch = tir.Schedule(mod, debug_mask="all")
    # fmt: off
    # pylint: disable=line-too-long,invalid-name
    b0 = sch.get_block(name="C", func_name="main")
    b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")
    l2, l3, l4 = sch.get_loops(block=b0)
    v5, v6, v7, v8, v9 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64, decision=[1, 16, 1, 2, 16])
    l10, l11, l12, l13, l14 = sch.split(loop=l2, factors=[v5, v6, v7, v8, v9])
    v15, v16, v17, v18, v19 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[16, 1, 8, 2, 2])
    l20, l21, l22, l23, l24 = sch.split(loop=l3, factors=[v15, v16, v17, v18, v19])
    v25, v26, v27 = sch.sample_perfect_tile(loop=l4, n=3, max_innermost_factor=64, decision=[1, 16, 32])
    l28, l29, l30 = sch.split(loop=l4, factors=[v25, v26, v27])
    sch.reorder(l10, l20, l11, l21, l12, l22, l28, l29, l13, l23, l30, l14, l24)
    l31 = sch.fuse(l10, l20)
    sch.bind(loop=l31, thread_axis="blockIdx.x")
    l32 = sch.fuse(l11, l21)
    sch.bind(loop=l32, thread_axis="vthread.x")
    l33 = sch.fuse(l12, l22)
    sch.bind(loop=l33, thread_axis="threadIdx.x")
    b34 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared")
    sch.compute_at(block=b34, loop=l28, preserve_unit_loops=True)
    _, _, _, _, l39, l40 = sch.get_loops(block=b34)
    l41 = sch.fuse(l39, l40)
    _, v43 = sch.sample_perfect_tile(loop=l41, n=2, max_innermost_factor=4, decision=[262144, 1])
    sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v43)
    b44 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")
    sch.compute_at(block=b44, loop=l28, preserve_unit_loops=True)
    _, _, _, _, l49, l50 = sch.get_loops(block=b44)
    l51 = sch.fuse(l49, l50)
    _, v53 = sch.sample_perfect_tile(loop=l51, n=2, max_innermost_factor=4, decision=[8192, 2])
    sch.annotate(block_or_loop=b44, ann_key="meta_schedule.cooperative_fetch", ann_val=v53)
    sch.reverse_compute_at(block=b1, loop=l33, preserve_unit_loops=True)
    # pylint: enable=line-too-long,invalid-name
    # fmt: on
    sch.enter_postproc()
    assert ctx.postprocs[0].apply(sch)
    tvm.ir.assert_structural_equal(sch.mod, AfterRewrite0)