def lowered_zero_rank_buffer(a: T.handle, b: T.handle) -> None:
    # TVMScript fixture: sums the 128 elements of A into the rank-0 buffer B
    # with a cross-thread all-reduce over threadIdx.x.
    # NOTE(review): the `lowered_` prefix suggests this is the expected output
    # of a reduction-lowering pass; confirm against the consuming test.
    A = T.match_buffer(a, [128], dtype="float32")
    B = T.match_buffer(b, [], dtype="float32")  # rank-0 (scalar) output
    # One-element local scratch buffer whose data handle receives the all-reduce result.
    reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    for k in T.thread_binding(0, 128, thread="threadIdx.x"):
        with T.block("B_cross_thread_reduction"):
            vk = T.axis.reduce(128, k)
            T.reads([A[vk]])
            T.writes([reduce_temp0[0]])
            # "reduce_scope" attribute carrying the combiner (x + y, identity 0.0)
            # for the all-reduce that follows.
            T.attr(
                T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
                "reduce_scope",
                T.reinterpret(T.uint64(0), dtype="handle"),
            )
            T.evaluate(
                T.tvm_thread_allreduce(T.uint32(1), A[vk], True, reduce_temp0.data, k, dtype="handle"))
        with T.block("B_write_back"):
            T.reads([reduce_temp0[0]])
            T.writes([B[()]])
            # `B[()]` indexes the rank-0 buffer.
            B[()] = reduce_temp0[0]
def main(A: T.handle, tensor: T.handle) -> None:
    # TVMScript fixture (legacy TE-schedule form) with two chained max-pools:
    #   A_1 (1, 12, 14, 16) --3x3 max--> tensor_2 (1, 10, 12, 16)
    #   tensor_2            --3x5 max--> tensor_1 (1, 8, 8, 16)
    # The intermediate tensor_2 is realized as a 6-row window per ax1_outer
    # step and tagged with the "rolling_buffer_scope" attribute.
    # function attr dict
    T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
    # buffer definition
    # tensor_2 is declared, not bound to a parameter: it is the intermediate.
    tensor_2 = T.buffer_decl([1, 10, 12, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1)
    A_1 = T.match_buffer(A, [1, 12, 14, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1)
    tensor_1 = T.match_buffer(tensor, [1, 8, 8, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1)
    # body
    T.realize(tensor_1[0:1, 0:8, 0:8, 0:16], "")
    for ax1_outer in T.serial(0, 2):
        # Rows [4*ax1_outer, 4*ax1_outer + 6) of tensor_2: the 4 rows of
        # output produced below need 4 + (3 - 1) input rows.
        T.realize(tensor_2[0:1, (ax1_outer*4):((ax1_outer*4) + 6), 0:12, 0:16], "")
        T.attr(tensor_2, "rolling_buffer_scope", True)
        # Producer: each tensor_2 element starts at 0 and takes the max over a
        # 3x3 window of A_1.
        for ax1 in T.serial(0, 6):
            for ax2 in T.serial(0, 12):
                for ax3 in T.serial(0, 16):
                    tensor_2[0, (ax1 + (ax1_outer*4)), ax2, ax3] = T.int8(0)
                    for dh in T.serial(0, 3):
                        for dw in T.serial(0, 3):
                            tensor_2[0, (ax1 + (ax1_outer*4)), ax2, ax3] = T.max(
                                tensor_2[0, (ax1 + (ax1_outer*4)), ax2, ax3],
                                A_1[0, ((ax1 + (ax1_outer*4)) + dh), (ax2 + dw), ax3])
        # Consumer: each tensor_1 element starts at 0 and takes the max over a
        # 3 (rows) x 5 (cols) window of tensor_2.
        for ax1_inner in T.serial(0, 4):
            for ax2_inner in T.serial(0, 8):
                for ax3_inner in T.serial(0, 16):
                    tensor_1[0, (ax1_inner + (ax1_outer*4)), ax2_inner, ax3_inner] = T.int8(0)
                    for dh_1 in T.serial(0, 3):
                        for dw_1 in T.serial(0, 5):
                            tensor_1[0, (ax1_inner + (ax1_outer*4)), ax2_inner, ax3_inner] = T.max(
                                tensor_1[0, (ax1_inner + (ax1_outer*4)), ax2_inner, ax3_inner],
                                tensor_2[0, ((ax1_inner + (ax1_outer*4)) + dh_1), (ax2_inner + dw_1), ax3_inner])
def ptx_global_to_shared_dyn_copy_fp16x8(
    A: T.Buffer[(32, 128), "float16"],
    B: T.Buffer[(32, 128), "float16"],
    C: T.Buffer[(32, 128), "float16"],
) -> None:
    # TVMScript fixture: one block of 32 threads. Each thread copies its row of
    # A and B into "shared.dyn" buffers in 8-wide vectorized chunks inside an
    # "async_scope" attribute region. It then commits and waits on the async
    # group and writes C = A_shared + B_shared.
    # NOTE(review): the name suggests this exercises PTX async-copy injection
    # for fp16 x 8 lanes; confirm against the consuming test.
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    bx = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(bx, 1)
    T.launch_thread(tx, 32)
    with T.block():
        A_shared = T.alloc_buffer([32, 128], "float16", scope="shared.dyn")
        B_shared = T.alloc_buffer([32, 128], "float16", scope="shared.dyn")
        T.reads(A[0:32, 0:128], B[0:32, 0:128])
        T.writes(C[0:32, 0:128])

        T.attr("default", "async_scope", 1)
        # 16 chunks x 8 vector lanes = the 128 columns of row `tx`.
        for i in T.serial(16):
            for j in T.vectorized(8):
                A_shared[tx, i * 8 + j] = A[tx, i * 8 + j]
                B_shared[tx, i * 8 + j] = B[tx, i * 8 + j]

        T.evaluate(T.ptx_commit_group(dtype=""))
        T.evaluate(T.ptx_wait_group(0, dtype=""))

        for i in range(128):
            C[tx, i] = A_shared[tx, i] + B_shared[tx, i]
def lowered_reducer_max(a: T.handle, b: T.handle) -> None:
    # TVMScript fixture: B[i] = max over k of A[i, k]. For each of the 128
    # rows, the 128-wide k axis is bound to threadIdx.x and combined with a
    # cross-thread all-reduce using a max reducer whose identity is
    # T.min_value("float32").
    A = T.match_buffer(a, [128, 128], dtype="float32")
    B = T.match_buffer(b, [128], dtype="float32")
    # One-element local scratch buffer whose data handle receives the all-reduce result.
    reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    for i in T.serial(0, 128):
        for k in T.thread_binding(0, 128, thread="threadIdx.x"):
            with T.block("B_cross_thread_reduction"):
                vi, vk = T.axis.remap("SR", [i, k])
                T.reads([A[vi, vk]])
                T.writes([reduce_temp0[0]])
                T.attr(
                    T.comm_reducer(lambda x, y: T.max(x, y), [T.min_value("float32")]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
                )
                T.evaluate(
                    T.tvm_thread_allreduce(T.uint32(1), A[vi, vk], True, reduce_temp0.data, k, dtype="handle"))
            with T.block("B_write_back"):
                vi = T.axis.spatial(128, i)
                T.reads([reduce_temp0[0]])
                T.writes([B[vi]])
                B[vi] = reduce_temp0[0]
def lowered_two_bound_loops(a: T.handle, b: T.handle) -> None:
    # TVMScript fixture: B[i] = sum over k of A[i, k].
    # The 128-wide reduction axis is split as k = ko * 32 + ki:
    #   - ko (extent 4) is bound to threadIdx.x;
    #   - ki (extent 32) is bound to threadIdx.y.
    # The all-reduce therefore receives both loop vars (ko, ki).
    A = T.match_buffer(a, [128, 128], dtype="float32")
    B = T.match_buffer(b, [128], dtype="float32")
    reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    for i in T.serial(0, 128):
        for ko in T.thread_binding(0, 4, thread="threadIdx.x"):
            for ki in T.thread_binding(0, 32, thread="threadIdx.y"):
                with T.block("B_cross_thread_reduction"):
                    vi = T.axis.spatial(128, i)
                    vk = T.axis.reduce(128, ko * 32 + ki)
                    T.reads([A[vi, vk]])
                    T.writes([reduce_temp0[0]])
                    T.attr(
                        T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
                        "reduce_scope",
                        T.reinterpret(T.uint64(0), dtype="handle"),
                    )
                    T.evaluate(
                        T.tvm_thread_allreduce(T.uint32(1), A[vi, vk], True, reduce_temp0.data, ko, ki, dtype="handle"))
                with T.block("B_write_back"):
                    vi = T.axis.spatial(128, i)
                    T.reads([reduce_temp0[0]])
                    T.writes([B[vi]])
                    B[vi] = reduce_temp0[0]
def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None:
    # TVMScript fixture of a runner function: allocates two global int8
    # scratch buffers and chains three extern operator calls:
    #   input -> sid_9 -> sid_8 -> output
    # The raw allocation results are passed directly (no `.data`).
    # NOTE(review): other fixtures use `.data`; this one presumably targets a
    # TVMScript version where T.allocate yields the data var; confirm.
    # function attr dict
    T.func_attr({
        "global_symbol": "tvmgen_default_run_model",
        "runner_function": True
    })
    # body
    T.attr("default", "device_id", 0)
    T.attr("default", "device_type", 1)
    sid_9 = T.allocate([301056], "int8", "global")
    sid_8 = T.allocate([802816], "int8", "global")
    T.evaluate(
        T.call_extern("tvmgen_default_fused_cast_subtract", input,
                      T.lookup_param("p0", dtype="handle"), sid_9, dtype="int32"))
    T.evaluate(
        T.call_extern(
            "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9,
            T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"),
            sid_8, dtype="int32"))
    T.evaluate(
        T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8, output,
                      dtype="int32"))
def undefined_buffer(a: T.handle) -> None:  # negative fixture: must fail at `# error`
    A = T.match_buffer(a, (16, 16), "float32")
    T.attr(A, "realize_scope", "")
    T.realize(C[0:16, 0:16], "")  # error: `C` is not defined in this function
    for i in T.serial(16):
        for j in T.serial(0, 16):
            A[i, j] = 0.0
    # NOTE(review): intentionally invalid. Keep the `# error` line at its
    # current offset from `def`, since error-report tests may assert line numbers.
def range_missing_args(a: T.handle) -> None:  # negative fixture: must fail at `# error`
    A = T.match_buffer(a, (16, 16), "float32")
    T.attr(A, "realize_scope", "")
    T.realize(A[0:16, 0:16], "")
    for i in T.serial(16):  # error: T.serial is given a single argument here
        for j in T.serial(0, 16):
            A[i, j] = 0.0
    # NOTE(review): intentionally invalid for the parser version this targets.
    # Keep the `# error` line at its current offset from `def`, since
    # error-report tests may assert line numbers.
def unsupported_function_call(a: T.handle) -> None:  # negative fixture: must fail at `# error`
    A = T.match_buffer(a, (16, 16), "float32")
    T.attr(A, "realize_scope", "")
    T.realize(A[0:16, 0:16], "")
    for i in T.const_range(16):  # error: `T.const_range` is the unsupported call (per the function name)
        for j in T.serial(0, 16):
            A[i, j] = 0.0
    # NOTE(review): intentionally invalid. Keep the `# error` line at its
    # current offset from `def`, since error-report tests may assert line numbers.
def lowered_multiple_blocks_under_reduction_loop(a: T.handle, b: T.handle) -> None:
    # TVMScript fixture: B[i] = sum over (k0, k1) of A[i, k0, k1].
    # i is bound to blockIdx.x and k0 is split as k0o * 4 + k0i, with k0o
    # bound to threadIdx.x. Each thread:
    #   1. accumulates partial sums over k1 into B_rf_local[k0, i] ("B_rf");
    #   2. folds its 4 k0 values into normal_reduce_temp0 ("B_normal_reduction");
    #   3. combines across threads with an all-reduce;
    #   4. writes the result back to B.
    A = T.match_buffer(a, [16, 16, 16], dtype="float32")
    B = T.match_buffer(b, [16], dtype="float32")
    B_rf_local = T.alloc_buffer([16, 16], dtype="float32", scope="local")
    # Per-thread scratch: all-reduce result / in-thread partial sum.
    reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    normal_reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    for i in T.thread_binding(0, 16, thread="blockIdx.x"):
        for k0o in T.thread_binding(0, 4, thread="threadIdx.x"):
            with T.block("B_in_thread_init"):
                T.reads([])
                T.writes([normal_reduce_temp0[0]])
                normal_reduce_temp0[0] = T.float32(0)
            for k0i0, k1 in T.grid(4, 16):
                with T.block("B_rf"):
                    vk0 = T.axis.spatial(16, k0o * 4 + k0i0)
                    vi, vk1 = T.axis.remap("SR", [i, k1])
                    T.reads([B_rf_local[vk0, vi], A[vi, vk0, vk1]])
                    T.writes([B_rf_local[vk0, vi]])
                    with T.init():
                        B_rf_local[vk0, vi] = T.float32(0)
                    B_rf_local[vk0, vi] = B_rf_local[vk0, vi] + A[vi, vk0, vk1]
            for k0i1 in T.serial(0, 4):
                with T.block("B_normal_reduction"):
                    vk0 = T.axis.reduce(16, k0o * 4 + k0i1)
                    vi = T.axis.spatial(16, i)
                    T.reads([B_rf_local[vk0, vi], normal_reduce_temp0[0]])
                    T.writes([normal_reduce_temp0[0]])
                    normal_reduce_temp0[0] = normal_reduce_temp0[0] + B_rf_local[vk0, vi]
            with T.block("B_cross_thread_reduction"):
                T.reads([normal_reduce_temp0[0]])
                T.writes([reduce_temp0[0]])
                T.attr(
                    T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
                )
                T.evaluate(
                    T.tvm_thread_allreduce(
                        T.uint32(1),
                        normal_reduce_temp0[0],
                        True,
                        reduce_temp0.data,
                        k0o,
                        dtype="handle",
                    ))
            with T.block("B_write_back"):
                vi = T.axis.spatial(16, i)
                T.reads([reduce_temp0[0]])
                T.writes([B[vi]])
                B[vi] = reduce_temp0[0]
def main(placeholder1: T.Buffer[(100, ), "int8"], placeholder2: T.Buffer[(100, ), "int8"]) -> None:
    # TVMScript fixture: copies 96 int8 values (6 groups of 16) from
    # placeholder2 to placeholder1, rotating the groups. Output group i3 reads
    # input group (i3 + 4) mod 6. The loop nest is tagged with
    # pragma_layout "NHCWB16".
    T.attr("i0", "pragma_layout", "NHCWB16")
    for i0 in T.serial(0, 1):
        for i1 in T.serial(0, 1):
            for i2 in T.serial(0, 1):
                for i3 in T.serial(0, 6):
                    for i4 in T.serial(0, 16):
                        placeholder1[((i3 * 16) + i4)] = placeholder2[((T.floormod(
                            (i3 + 4), 6) * 16) + i4)]
def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None:
    # TVMScript fixture of a runner function: four global int8 scratch buffers
    # and five extern operator calls. The data flow is:
    #   input -> sid_2 -> sid_8 -> sid_7 -> sid_6
    #   (sid_2, sid_6) -> output
    # sid_2 is therefore read by two calls.
    # function attr dict
    T.func_attr({
        "global_symbol": "tvmgen_default_run_model",
        "runner_function": True
    })
    # body
    T.attr("default", "device_id", 0)
    T.attr("default", "device_type", 1)
    sid_2 = T.allocate([720000], "int8", "global")
    sid_6 = T.allocate([5760000], "int8", "global")
    sid_7 = T.allocate([720000], "int8", "global")
    sid_8 = T.allocate([720000], "int8", "global")
    T.evaluate(
        T.call_extern(
            "tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast",
            input, T.lookup_param("p0", dtype="handle"), sid_2.data, dtype="int32"))
    T.evaluate(
        T.call_extern(
            "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast",
            sid_2.data, T.lookup_param("p3", dtype="handle"),
            T.lookup_param("p4", dtype="handle"), sid_8.data, dtype="int32"))
    T.evaluate(
        T.call_extern(
            "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1",
            sid_8.data, T.lookup_param("p5", dtype="handle"),
            T.lookup_param("p6", dtype="handle"), sid_7.data, dtype="int32"))
    T.evaluate(
        T.call_extern(
            "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_",
            sid_7.data, T.lookup_param("p7", dtype="handle"),
            T.lookup_param("p8", dtype="handle"), sid_6.data, dtype="int32"))
    T.evaluate(
        T.call_extern(
            "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_",
            sid_2.data, T.lookup_param("p1", dtype="handle"),
            T.lookup_param("p2", dtype="handle"), sid_6.data, output, dtype="int32"))
def main(placeholder1: T.Buffer[(100, ), "int8"], placeholder2: T.Buffer[(100, ), "int8"]) -> None:
    # TVMScript fixture: a 5 x 6 x 4 int8 copy loop tagged with
    # pragma_layout "NHWC".
    #   - The output index is i1*24 + i2*4 + i3.
    #   - The input row offset is floordiv(i1 - 1, 2)*48 + floormod(i1 + 1, 2)*24,
    #     shifted by +96.
    # NOTE(review): these indices reach 119 (output) and 191 (input), beyond
    # the declared (100,) extents. Presumably the consuming test only inspects
    # the index expressions; confirm before relying on the shapes.
    T.attr("i0", "pragma_layout", "NHWC")
    for i0 in T.serial(0, 1):
        for i1 in T.serial(0, 5):
            for i2 in T.serial(0, 6):
                for i3 in T.serial(0, 4):
                    placeholder1[(
                        ((i1 * 24) + (i2 * 4)) + i3)] = placeholder2[(((((T.floordiv(
                            (i1 - 1), 2) * 48) + (T.floormod(
                                (i1 + 1), 2) * 24)) + (i2 * 4)) + i3) + 96)]
def main(buffer2: T.Buffer[(160,), "uint8"], buffer4: T.Buffer[(144,), "uint8"], buffer6: T.Buffer[(144,), "uint8"], buffer8: T.Buffer[(144,), "uint8"]) -> None:
    # TVMScript fixture for an Ethos-U style command stream:
    #   - Four "ethosu_copy" calls take buffer2/4/6/8 as the first argument and
    #     p4/p7/p10/p11 as the last (presumably source and destination), with
    #     sizes 160/144/144/144.
    #   - Four "ethosu_conv2d" calls read buffer1 and write buffer10 at offsets
    #     0/2/4/6. Each also takes one of p4/p7/p10/p11 at two offsets.
    #   - Every call is wrapped in a "pragma_compute_cycles_hint" attribute on
    #     its own iter_var.
    #   - The calls are interleaved copy, copy, conv, copy, conv, copy, conv, conv.
    # NOTE(review): the meaning of the positional ethosu_conv2d arguments is
    # defined by the extern, not visible here.
    # function attr dict
    T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
    # Symbolic vars used only as iter_vars of the cycle-hint attributes below.
    v1a = T.var("int32")
    v1c = T.var("int32")
    v2a = T.var("int32")
    v2c = T.var("int32")
    v3a = T.var("int32")
    v3c = T.var("int32")
    v4a = T.var("int32")
    v4c = T.var("int32")
    buffer1 = T.buffer_decl([8192], "int8")
    buffer10 = T.buffer_decl([2048], "int8")
    # body
    p4 = T.allocate([160], "uint8", "global")
    p7 = T.allocate([144], "uint8", "global")
    p10 = T.allocate([144], "uint8", "global")
    p11 = T.allocate([144], "uint8", "global")
    with T.attr(T.iter_var(v1a, None, "DataPar", ""), "pragma_compute_cycles_hint", 201):
        T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 160, p4[0], dtype="handle"))
    with T.attr(T.iter_var(v2a, None, "DataPar", ""), "pragma_compute_cycles_hint", 205):
        T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 144, p7[0], dtype="handle"))
    with T.attr(T.iter_var(v1c, None, "DataPar", ""), "pragma_compute_cycles_hint", 300):
        T.evaluate(T.call_extern(
            "ethosu_conv2d",
            "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1,
            "int8", 16, 16, 2, 16, 0, 16, buffer10[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1,
            1, 1, 1, 1, 1, 1, p4[0], 128, 12, p4[128], 32, 0, 0, 0, 0,
            "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
    with T.attr(T.iter_var(v3a, None, "DataPar", ""), "pragma_compute_cycles_hint", 209):
        T.evaluate(T.call_extern("ethosu_copy", buffer6[0], 144, p10[0], dtype="handle"))
    with T.attr(T.iter_var(v2c, None, "DataPar", ""), "pragma_compute_cycles_hint", 301):
        T.evaluate(T.call_extern(
            "ethosu_conv2d",
            "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1,
            "int8", 16, 16, 2, 16, 0, 16, buffer10[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1,
            1, 1, 1, 1, 1, 1, p7[0], 112, 12, p7[112], 32, 0, 0, 0, 0,
            "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
    with T.attr(T.iter_var(v4a, None, "DataPar", ""), "pragma_compute_cycles_hint", 213):
        T.evaluate(T.call_extern("ethosu_copy", buffer8[0], 144, p11[0], dtype="handle"))
    with T.attr(T.iter_var(v3c, None, "DataPar", ""), "pragma_compute_cycles_hint", 302):
        T.evaluate(T.call_extern(
            "ethosu_conv2d",
            "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1,
            "int8", 16, 16, 2, 16, 0, 16, buffer10[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1,
            1, 1, 1, 1, 1, 1, p10[0], 112, 12, p10[112], 32, 0, 0, 0, 0,
            "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
    with T.attr(T.iter_var(v4c, None, "DataPar", ""), "pragma_compute_cycles_hint", 303):
        T.evaluate(T.call_extern(
            "ethosu_conv2d",
            "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1,
            "int8", 16, 16, 2, 16, 0, 16, buffer10[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1,
            1, 1, 1, 1, 1, 1, p11[0], 112, 12, p11[112], 32, 0, 0, 0, 0,
            "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
def lowered_with_block_predicate(a: T.handle, b: T.handle) -> None:
    # TVMScript fixture: B[i] = sum over k < 120 of A[i, k].
    # The 120-wide reduction is split as k = ko * 32 + ki:
    #   - ko (extent 4) is serial;
    #   - ki (extent 32) is bound to threadIdx.x.
    # 4 * 32 = 128 exceeds 120, so T.where masks the 8 out-of-range iterations.
    # Each thread accumulates in-thread into normal_reduce_temp0, then the
    # threads are combined with a cross-thread all-reduce into reduce_temp0.
    A = T.match_buffer(a, [128, 120], dtype="float32")
    B = T.match_buffer(b, [128], dtype="float32")
    reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    normal_reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    for i in T.serial(0, 128):
        for ki in T.thread_binding(0, 32, thread="threadIdx.x"):
            with T.block("B_in_thread_init"):
                T.reads([])
                T.writes([normal_reduce_temp0[0]])
                normal_reduce_temp0[0] = T.float32(0)
            for ko in T.serial(0, 4):
                with T.block("B_normal_reduction"):
                    vi = T.axis.spatial(128, i)
                    vk = T.axis.reduce(120, ko * 32 + ki)
                    T.where(ko * 32 + ki < 120)
                    T.reads([A[vi, vk], normal_reduce_temp0[0]])
                    T.writes([normal_reduce_temp0[0]])
                    normal_reduce_temp0[0] = normal_reduce_temp0[0] + A[vi, vk]
            with T.block("B_cross_thread_reduction"):
                T.reads([normal_reduce_temp0[0]])
                T.writes([reduce_temp0[0]])
                T.attr(
                    T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
                )
                T.evaluate(
                    T.tvm_thread_allreduce(
                        T.uint32(1),
                        normal_reduce_temp0[0],
                        True,
                        # The result destination is the buffer's data handle,
                        # not a loaded element value. This matches every other
                        # lowered cross-thread-reduction fixture.
                        reduce_temp0.data,
                        ki,
                        dtype="handle",
                    ))
            with T.block("B_write_back"):
                vi = T.axis.spatial(128, i)
                T.reads([reduce_temp0[0]])
                T.writes([B[vi]])
                B[vi] = reduce_temp0[0]
def ptx_global_to_shared_copy_fp32x1(
    A: T.Buffer[(32, 128), "float32"], B: T.Buffer[(32, 128), "float32"]) -> None:
    # TVMScript fixture: one block of 32 threads. Each thread copies its
    # 128-element row of A into a "shared" buffer, one scalar at a time, inside
    # an "async_scope" attribute region. It then commits and waits on the async
    # group and copies the row back out to B.
    # NOTE(review): the name suggests this exercises PTX async-copy injection
    # for fp32 x 1 lane; confirm against the consuming test.
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    bx = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(bx, 1)
    T.launch_thread(tx, 32)
    with T.block():
        A_shared = T.alloc_buffer([32, 128], "float32", scope="shared")
        T.reads(A[0:32, 0:128])
        T.writes(B[0:32, 0:128])

        T.attr("default", "async_scope", 1)
        for i in T.serial(128):
            A_shared[tx, i] = A[tx, i]

        T.evaluate(T.ptx_commit_group(dtype=""))
        T.evaluate(T.ptx_wait_group(0, dtype=""))

        for i in range(128):
            B[tx, i] = A_shared[tx, i]
def main(A_param: T.handle, C_param: T.handle):
    # TVMScript fixture (legacy TE-schedule form) with 100 iterations on a
    # single thread. Each iteration:
    #   - allocates a 4-element "shared" buffer B;
    #   - fills B from A[4*i : 4*i + 4] inside a "double_buffer_scope"
    #     attribute;
    #   - writes C[j] = B[j] + 1.0.
    A = T.match_buffer(A_param, (400,), "float32", strides=[1])
    C = T.match_buffer(C_param, (4,), "float32", strides=[1])
    T.func_attr({"from_legacy_te_schedule": True})
    threadIdx_x = T.env_thread("threadIdx.x")
    T.launch_thread(threadIdx_x, 1)
    for i in T.serial(0, 100):
        B = T.allocate([4], "float32", scope="shared", strides=[1])
        # Only the producer loop sits inside the attribute; the consumer follows it.
        with T.attr(B.data, "double_buffer_scope", 1):
            for j in T.serial(0, 4):
                B[j] = A[4 * i + j]

        for j in T.serial(0, 4):
            C[j] = B[j] + 1.0
def lowered_single_reduction_loop_with_block_predicate(
    A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256), "float32"]) -> None:
    # TVMScript fixture: row-wise softmax of A (256 x 256) into T_softmax_norm.
    # Each reduction (row max, then row sum of exp(A - max)) uses:
    #   - a single reduction loop bound to 512 threadIdx.x threads;
    #   - T.where(... < 256) to mask the 256 threads past the row length;
    #   - an in-thread step and a cross-thread all-reduce, written back to a
    #     "shared" per-row buffer.
    # The final normalization uses the same 512-thread masking.
    T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    # *_0 buffers serve the max reduction, *_1 buffers the exp-sum reduction.
    cross_thread_0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    in_thread_0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    cross_thread_1 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    in_thread_1 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    for i0 in T.serial(256):
        # Row max.
        for ax0, ax1_0 in T.grid(1, 1):
            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_maxelem_in_thread_init"):
                    T.reads()
                    T.writes(in_thread_0[0])
                    # Lowest finite float32: the identity for max.
                    in_thread_0[0] = T.float32(-3.4028234663852886e38)
                with T.block("T_softmax_maxelem_in_thread"):
                    i0_1 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax1_1)
                    T.where(ax1_0 * 512 + ax1_1 < 256)
                    T.reads(A[i0_1, k], in_thread_0[0])
                    T.writes(in_thread_0[0])
                    in_thread_0[0] = T.max(in_thread_0[0], A[i0_1, k])
                with T.block("T_softmax_maxelem_cross_thread"):
                    T.reads(in_thread_0[0])
                    T.writes(cross_thread_0[0])
                    T.attr(
                        T.comm_reducer(lambda x, y: T.max(x, y), [T.float32(-3.4028234663852886e38)]),
                        "reduce_scope",
                        T.reinterpret(T.uint64(0), dtype="handle"),
                    )
                    T.evaluate(
                        T.tvm_thread_allreduce(
                            T.uint32(1),
                            in_thread_0[0],
                            True,
                            cross_thread_0.data,
                            ax1_1,
                            dtype="handle",
                        ))
                with T.block("T_softmax_maxelem_write_back"):
                    i0_2 = T.axis.spatial(256, i0)
                    T.reads(cross_thread_0[0])
                    T.writes(T_softmax_maxelem_shared[i0_2])
                    T_softmax_maxelem_shared[i0_2] = cross_thread_0[0]
        # Row sum of exp(A - row max).
        for ax0, ax1_0 in T.grid(1, 1):
            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_expsum_in_thread_init"):
                    T.reads()
                    T.writes(in_thread_1[0])
                    in_thread_1[0] = T.float32(0)
                with T.block("T_softmax_expsum_in_thread"):
                    i0_3 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax1_1)
                    T.where(ax1_0 * 512 + ax1_1 < 256)
                    T.reads(A[i0_3, k], T_softmax_maxelem_shared[i0_3], in_thread_1[0])
                    T.writes(in_thread_1[0])
                    in_thread_1[0] = in_thread_1[0] + T.exp(
                        A[i0_3, k] - T_softmax_maxelem_shared[i0_3], dtype="float32")
                with T.block("T_softmax_expsum_cross_thread"):
                    T.reads(in_thread_1[0])
                    T.writes(cross_thread_1[0])
                    T.attr(
                        T.comm_reducer(lambda x_1, y_1: x_1 + y_1, [T.float32(0)]),
                        "reduce_scope",
                        T.reinterpret(T.uint64(0), dtype="handle"),
                    )
                    T.evaluate(
                        T.tvm_thread_allreduce(
                            T.uint32(1),
                            in_thread_1[0],
                            True,
                            cross_thread_1.data,
                            ax1_1,
                            dtype="handle",
                        ))
                with T.block("T_softmax_expsum_write_back"):
                    i0_4 = T.axis.spatial(256, i0)
                    T.reads(cross_thread_1[0])
                    T.writes(T_softmax_expsum_shared[i0_4])
                    T_softmax_expsum_shared[i0_4] = cross_thread_1[0]
        # Normalize: exp(A - max) / sum.
        for i1_0 in T.serial(1):
            for i1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_norm"):
                    i0_5 = T.axis.spatial(256, i0)
                    i1 = T.axis.spatial(256, i1_1)
                    T.where(i1_0 * 512 + i1_1 < 256)
                    T.reads(A[i0_5, i1], T_softmax_maxelem_shared[i0_5], T_softmax_expsum_shared[i0_5])
                    T.writes(T_softmax_norm[i0_5, i1])
                    T.block_attr({"axis": 1})
                    T_softmax_norm[i0_5, i1] = (
                        T.exp(A[i0_5, i1] - T_softmax_maxelem_shared[i0_5], dtype="float32")
                        / T_softmax_expsum_shared[i0_5])
def lowered_softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None:
    # TVMScript fixture: row-wise softmax of A (256 x 256). Each row i0 is
    # bound to blockIdx.x. Each of the two row reductions (max, then sum of
    # exp(A - max)) runs in three steps:
    #   1. 32 threadIdx.x threads each fold 8 serial elements in-thread
    #      (k = ax0_0 * 32 + ax0_1);
    #   2. a cross-thread all-reduce combines the threads;
    #   3. the result is written back to a "shared" per-row buffer.
    A = T.match_buffer(var_A, [256, 256], dtype="float32")
    T_softmax_norm = T.match_buffer(var_T_softmax_norm, [256, 256], dtype="float32")
    T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    # *0 buffers serve the max reduction, *1 buffers the exp-sum reduction.
    reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    normal_reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    reduce_temp1 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    normal_reduce_temp1 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    for i0 in T.thread_binding(0, 256, thread="blockIdx.x"):
        # Row max.
        for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
            with T.block("T_softmax_maxelem_normal_reduction_init"):
                T.reads([])
                T.writes([normal_reduce_temp0[0]])
                normal_reduce_temp0[0] = T.min_value("float32")
            for ax0_0 in T.serial(0, 8):
                with T.block("T_softmax_maxelem_normal_reduction"):
                    i0_1 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
                    T.reads([A[i0_1, k], normal_reduce_temp0[0]])
                    T.writes([normal_reduce_temp0[0]])
                    normal_reduce_temp0[0] = T.max(normal_reduce_temp0[0], A[i0_1, k])
            with T.block("T_softmax_maxelem_cross_thread_reduction"):
                T.reads([normal_reduce_temp0[0]])
                T.writes([reduce_temp0[0]])
                T.attr(
                    T.comm_reducer(lambda x, y: T.max(x, y), [T.min_value("float32")]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
                )
                T.evaluate(
                    T.tvm_thread_allreduce(
                        T.uint32(1),
                        normal_reduce_temp0[0],
                        True,
                        reduce_temp0.data,
                        ax0_1,
                        dtype="handle",
                    ))
            with T.block("T_softmax_maxelem_write_back"):
                i0_2 = T.axis.spatial(256, i0)
                T.reads([reduce_temp0[0]])
                T.writes([T_softmax_maxelem_shared[i0_2]])
                T_softmax_maxelem_shared[i0_2] = reduce_temp0[0]
        # Row sum of exp(A - row max).
        for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
            with T.block("T_softmax_expsum_normal_reduction_init"):
                T.reads([])
                T.writes([normal_reduce_temp1[0]])
                normal_reduce_temp1[0] = T.float32(0)
            for ax0_0 in T.serial(0, 8):
                with T.block("T_softmax_expsum_normal_reduction"):
                    i0_3 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
                    T.reads([
                        A[i0_3, k],
                        T_softmax_maxelem_shared[i0_3],
                        normal_reduce_temp1[0],
                    ])
                    T.writes([normal_reduce_temp1[0]])
                    normal_reduce_temp1[0] = normal_reduce_temp1[0] + T.exp(
                        A[i0_3, k] - T_softmax_maxelem_shared[i0_3], dtype="float32")
            with T.block("T_softmax_expsum_cross_thread_reduction"):
                T.reads([normal_reduce_temp1[0]])
                T.writes([reduce_temp1[0]])
                T.attr(
                    T.comm_reducer(lambda x_1, y_1: x_1 + y_1, [T.float32(0)]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
                )
                T.evaluate(
                    T.tvm_thread_allreduce(
                        T.uint32(1),
                        normal_reduce_temp1[0],
                        True,
                        reduce_temp1.data,
                        ax0_1,
                        dtype="handle",
                    ))
            with T.block("T_softmax_expsum_write_back"):
                i0_4 = T.axis.spatial(256, i0)
                T.reads([reduce_temp1[0]])
                T.writes([T_softmax_expsum_shared[i0_4]])
                T_softmax_expsum_shared[i0_4] = reduce_temp1[0]
        # Normalize: exp(A - max) / sum, 8 serial x 32 threads per row.
        for i1_0 in T.serial(0, 8):
            for i1_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
                with T.block("T_softmax_norm"):
                    i0_5 = T.axis.spatial(256, i0)
                    i1 = T.axis.spatial(256, i1_0 * 32 + i1_1)
                    T.reads([
                        A[i0_5, i1],
                        T_softmax_maxelem_shared[i0_5],
                        T_softmax_expsum_shared[i0_5],
                    ])
                    T.writes([T_softmax_norm[i0_5, i1]])
                    T.block_attr({"axis": 1})
                    T_softmax_norm[i0_5, i1] = (T.exp(
                        A[i0_5, i1] - T_softmax_maxelem_shared[i0_5],
                        dtype="float32",
                    ) / T_softmax_expsum_shared[i0_5])