def main(
    A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256), "float32"]
) -> None:
    # Softmax along the second axis of a 256x256 float32 buffer, expressed as
    # three consecutive loop nests:
    #   1. T_softmax_maxelem[i] = max over k of A[i, k]
    #   2. T_softmax_expsum[i]  = sum over k of exp(A[i, k] - T_softmax_maxelem[i])
    #   3. T_softmax_norm[i, j] = exp(A[i, j] - T_softmax_maxelem[i]) / T_softmax_expsum[i]
    #
    # A: input buffer.
    # T_softmax_norm: output buffer.

    # Per-row intermediates, one element per row.
    T_softmax_maxelem = T.alloc_buffer([256], dtype="float32")
    T_softmax_expsum = T.alloc_buffer([256], dtype="float32")

    # Stage 1: running maximum. "SR" binds i0 as a spatial axis and i1 as a
    # reduction axis.
    for i0, i1 in T.grid(256, 256):
        with T.block("T_softmax_maxelem"):
            i0_1, k = T.axis.remap("SR", [i0, i1])
            with T.init():
                # Start from the smallest float32 so any element replaces it.
                T_softmax_maxelem[i0_1] = T.min_value("float32")
            T_softmax_maxelem[i0_1] = T.max(T_softmax_maxelem[i0_1], A[i0_1, k])

    # Stage 2: sum of exponentials, shifted by the row maximum from stage 1.
    for i0, i1 in T.grid(256, 256):
        with T.block("T_softmax_expsum"):
            i0_2, k = T.axis.remap("SR", [i0, i1])
            with T.init():
                T_softmax_expsum[i0_2] = T.float32(0)
            T_softmax_expsum[i0_2] = T_softmax_expsum[i0_2] + T.exp(
                A[i0_2, k] - T_softmax_maxelem[i0_2], dtype="float32")

    # Stage 3: normalize every element; both axes are spatial ("SS").
    for i0_3, i1 in T.grid(256, 256):
        with T.block("T_softmax_norm"):
            i0_4, i1_1 = T.axis.remap("SS", [i0_3, i1])
            T_softmax_norm[i0_4, i1_1] = T.exp(
                A[i0_4, i1_1] - T_softmax_maxelem[i0_4], dtype="float32") / T_softmax_expsum[i0_4]
def softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None:
    # Softmax along the second axis of a 256x256 float32 buffer with thread
    # bindings applied: the row loop i0 is bound to blockIdx.x, and each
    # 256-wide inner axis is split into 8 serial steps x 32 lanes bound to
    # threadIdx.x.
    #
    # var_A: handle matched to the input buffer A.
    # var_T_softmax_norm: handle matched to the output buffer T_softmax_norm.
    A = T.match_buffer(var_A, [256, 256], dtype="float32")
    T_softmax_norm = T.match_buffer(var_T_softmax_norm, [256, 256], dtype="float32")
    # Per-row intermediates, allocated with scope="shared".
    T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    for i0 in T.thread_binding(0, 256, thread="blockIdx.x"):
        # Row maximum: the reduction axis k is rebuilt from the split loops.
        for ax0_0 in T.serial(0, 8):
            for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
                with T.block("T_softmax_maxelem"):
                    i0_1 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
                    T.reads([T_softmax_maxelem_shared[i0_1], A[i0_1, k]])
                    T.writes([T_softmax_maxelem_shared[i0_1]])
                    with T.init():
                        T_softmax_maxelem_shared[i0_1] = T.min_value("float32")
                    T_softmax_maxelem_shared[i0_1] = T.max(
                        T_softmax_maxelem_shared[i0_1], A[i0_1, k]
                    )
        # Row sum of exp(A - max), using the same 8 x 32 split of k.
        for ax0_0 in T.serial(0, 8):
            for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
                with T.block("T_softmax_expsum"):
                    i0_2 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
                    T.reads(
                        [
                            T_softmax_expsum_shared[i0_2],
                            A[i0_2, k],
                            T_softmax_maxelem_shared[i0_2],
                        ]
                    )
                    T.writes([T_softmax_expsum_shared[i0_2]])
                    with T.init():
                        T_softmax_expsum_shared[i0_2] = T.float32(0)
                    T_softmax_expsum_shared[i0_2] = T_softmax_expsum_shared[i0_2] + T.exp(
                        A[i0_2, k] - T_softmax_maxelem_shared[i0_2], dtype="float32"
                    )
        # Normalization: purely spatial, column index i1 = i1_0 * 32 + i1_1.
        for i1_0 in T.serial(0, 8):
            for i1_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
                with T.block("T_softmax_norm"):
                    i0_3 = T.axis.spatial(256, i0)
                    i1 = T.axis.spatial(256, i1_0 * 32 + i1_1)
                    T.reads(
                        [
                            A[i0_3, i1],
                            T_softmax_maxelem_shared[i0_3],
                            T_softmax_expsum_shared[i0_3],
                        ]
                    )
                    T.writes([T_softmax_norm[i0_3, i1]])
                    # NOTE(review): the "axis" annotation presumably records the
                    # softmax axis from the originating operator - confirm.
                    T.block_attr({"axis": 1})
                    T_softmax_norm[i0_3, i1] = (
                        T.exp(
                            A[i0_3, i1] - T_softmax_maxelem_shared[i0_3],
                            dtype="float32",
                        )
                        / T_softmax_expsum_shared[i0_3]
                    )
def single_reduction_loop_with_block_predicate(
    A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256), "float32"]
) -> None:
    # Softmax along the second axis of a 256x256 float32 buffer where each
    # 256-wide inner axis is driven by a single threadIdx.x loop of extent 512.
    # Every block carries a T.where predicate (`... * 512 + lane < 256`) that
    # restricts it to lanes whose index falls inside the axis extent.
    #
    # A: input buffer.
    # T_softmax_norm: output buffer.

    # Per-row intermediates, allocated with scope="shared".
    T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    for i0 in T.serial(256):
        # Row maximum. ax0 and ax1_0 are extent-1 loops; only ax1_1 is bound to
        # the reduction axis.
        for ax0, ax1_0 in T.grid(1, 1):
            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_maxelem"):
                    i0_1 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax1_1)
                    T.where(ax1_0 * 512 + ax1_1 < 256)
                    T.reads(T_softmax_maxelem_shared[i0_1], A[i0_1, k])
                    T.writes(T_softmax_maxelem_shared[i0_1])
                    with T.init():
                        # Literal initial value for the max reduction (a large
                        # negative float32 constant).
                        T_softmax_maxelem_shared[i0_1] = T.float32(
                            -3.4028234663852886e38)
                    T_softmax_maxelem_shared[i0_1] = T.max(
                        T_softmax_maxelem_shared[i0_1], A[i0_1, k])
        # Row sum of exp(A - max), same loop structure and predicate.
        for ax0, ax1_0 in T.grid(1, 1):
            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_expsum"):
                    i0_2 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax1_1)
                    T.where(ax1_0 * 512 + ax1_1 < 256)
                    T.reads(T_softmax_expsum_shared[i0_2], A[i0_2, k],
                            T_softmax_maxelem_shared[i0_2])
                    T.writes(T_softmax_expsum_shared[i0_2])
                    with T.init():
                        T_softmax_expsum_shared[i0_2] = T.float32(0)
                    T_softmax_expsum_shared[
                        i0_2] = T_softmax_expsum_shared[i0_2] + T.exp(
                            A[i0_2, k] - T_softmax_maxelem_shared[i0_2], dtype="float32")
        # Normalization: spatial block, predicated the same way.
        for i1_0 in T.serial(1):
            for i1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_norm"):
                    i0_3 = T.axis.spatial(256, i0)
                    i1 = T.axis.spatial(256, i1_1)
                    T.where(i1_0 * 512 + i1_1 < 256)
                    T.reads(A[i0_3, i1], T_softmax_maxelem_shared[i0_3],
                            T_softmax_expsum_shared[i0_3])
                    T.writes(T_softmax_norm[i0_3, i1])
                    T.block_attr({"axis": 1})
                    T_softmax_norm[i0_3, i1] = (
                        T.exp(A[i0_3, i1] - T_softmax_maxelem_shared[i0_3], dtype="float32") /
                        T_softmax_expsum_shared[i0_3])
def exp_exp_opaque_access_with_tvm_access_ptr_inlined(
    lookup_table: T.Buffer[(1024,), "int8"],
    x: T.Buffer[(16,), "float16"],
    compute: T.Buffer[(16,), "float16"],
) -> None:
    # Single-block form of a two-stage exp computation: the inner
    # T.exp(x[i]) appears directly as the first argument of the outer T.exp
    # call, which also receives a read-only access pointer to lookup_table.
    # NOTE(review): presumably the expected result of inlining block "compute"
    # of exp_exp_opaque_access_with_tvm_access_ptr - confirm against the test.
    #
    # lookup_table: 1024-element int8 buffer, passed via access_ptr("r").
    # x: 16-element float16 input buffer.
    # compute: 16-element float16 output buffer.
    for i0 in T.serial(16):
        with T.block("compute_1"):
            i0_1 = T.axis.spatial(16, i0)
            # Do not put the opaque access to new write region when opaque access
            # wrapped with a tvm_access_ptr and the access mask set to "read only"
            T.reads(x[i0_1], lookup_table[0:1024])
            T.writes(compute[i0_1])
            compute[i0_1] = T.exp(
                T.exp(x[i0_1], dtype="float16"),
                lookup_table.access_ptr("r"),
                dtype="float16",
            )
def exp_exp_opaque_access_with_tvm_access_ptr(
    lookup_table: T.Buffer[(1024,), "int8"],
    x: T.Buffer[(16,), "float16"],
    compute: T.Buffer[(16,), "float16"],
) -> None:
    # Two-stage exp computation over 16 float16 elements. Block "compute"
    # writes exp(x) into the intermediate buffer compute_1; block "compute_1"
    # then calls T.exp on that intermediate together with a read-only access
    # pointer to lookup_table and stores the result in `compute`.
    #
    # lookup_table: 1024-element int8 buffer, passed via access_ptr("r").
    # x: 16-element float16 input buffer.
    # compute: 16-element float16 output buffer.
    compute_1 = T.alloc_buffer([16], dtype="float16")
    # Stage 1: elementwise exp into the intermediate buffer.
    for i0 in T.serial(16):
        with T.block("compute"):
            i0_1 = T.axis.spatial(16, i0)
            T.reads(x[i0_1])
            T.writes(compute_1[i0_1])
            compute_1[i0_1] = T.exp(x[i0_1], dtype="float16")
    # Stage 2: consumes the intermediate; the whole lookup_table range is
    # declared as a read region and nothing of it is declared as written.
    for i0 in T.serial(16):
        with T.block("compute_1"):
            i0_2 = T.axis.spatial(16, i0)
            T.reads(compute_1[i0_2], lookup_table[0:1024])
            T.writes(compute[i0_2])
            compute[i0_2] = T.exp(
                compute_1[i0_2],
                lookup_table.access_ptr("r"),
                dtype="float16",
            )
def different_access_indices(a: T.handle, b: T.handle) -> None:
    # Reduction of A[vi, vj, vk] over vk into B, where the init statement and
    # the update statement index B differently: init touches B[vj, vi]
    # (transposed) while the update touches B[vi, vj]. The declared write
    # region is therefore the range [min(vi, vj), max(vi, vj) + 1) in both
    # dimensions, which covers both access patterns.
    #
    # a: handle matched to the 128x128x128 float32 input buffer A.
    # b: handle matched to the 128x128 float32 buffer B.
    A = T.match_buffer(a, [128, 128, 128], dtype="float32")
    B = T.match_buffer(b, [128, 128], dtype="float32")
    for i, j in T.grid(128, 128):
        # The reduction loop k is bound to threadIdx.x.
        for k in T.thread_binding(0, 128, thread="threadIdx.x"):
            with T.block("B"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                T.reads([B[vi, vj], A[vi, vj, vk]])
                T.writes(
                    [
                        B[
                            T.min(vj, vi):T.min(vj, vi)  # type: ignore[misc]
                            + (T.max(vj, vi) + 1 - T.min(vj, vi)),
                            T.min(vi, vj):T.min(vi, vj)  # type: ignore[misc]
                            + (T.max(vi, vj) + 1 - T.min(vi, vj)),
                        ]
                    ]
                )
                with T.init():
                    # Transposed indices relative to the update below.
                    B[vj, vi] = T.exp(B[vj, vi], dtype="float32")
                B[vi, vj] = B[vi, vj] + A[vi, vj, vk]
def func_rewritten(A: T.Buffer[(8, ), "float32"]) -> None:
    # Fills the 8-element float32 buffer A with (x + 1) / (x - 1), where
    # x = exp(3.14) is computed through a one-element scratch allocation B.
    # In this variant the T.allocate call sits before the loop rather than
    # inside its body.
    #
    # A: output buffer.
    B = T.allocate((1, ), "float32", "global")
    for i in range(8):
        B[0] = 3.14
        x: T.float32 = T.exp(B[0], dtype="float32")
        A[i] = (x + 1.0) / (x - 1.0)
def func(A: T.Buffer[(8, ), "float32"]):
    # Fills the 8-element float32 buffer A with (x + 1) / (x - 1), where
    # x = exp(3.14) is computed through a one-element scratch allocation B.
    # In this variant the T.allocate call sits inside the loop body.
    #
    # A: output buffer.
    for i in range(8):
        B = T.allocate((1, ), "float32", "global")
        B[0] = 3.14
        x: T.float32 = T.exp(B[0], dtype="float32")
        A[i] = (x + 1.0) / (x - 1.0)
def lowered_softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None:
    # Softmax along the second axis of a 256x256 float32 buffer in which each
    # of the two reductions (row max, row sum of exp) is spelled out as four
    # blocks inside the threadIdx.x loop:
    #   *_normal_reduction_init  - reset a per-thread accumulator
    #   *_normal_reduction       - accumulate 8 elements serially per thread
    #   *_cross_thread_reduction - combine accumulators via T.tvm_thread_allreduce
    #   *_write_back             - store the combined value into the shared buffer
    # NOTE(review): presumably the expected output of lowering `softmax` with
    # the cross-thread-reduction pass - confirm against the test that uses it.
    #
    # var_A: handle matched to the input buffer A.
    # var_T_softmax_norm: handle matched to the output buffer T_softmax_norm.
    A = T.match_buffer(var_A, [256, 256], dtype="float32")
    T_softmax_norm = T.match_buffer(var_T_softmax_norm, [256, 256], dtype="float32")
    T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    # One-element scope="local" temporaries: normal_reduce_temp* hold the
    # per-thread partial result, reduce_temp* receive the allreduce result.
    reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    normal_reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    reduce_temp1 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    normal_reduce_temp1 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    for i0 in T.thread_binding(0, 256, thread="blockIdx.x"):
        # ---- Row maximum ----
        for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
            with T.block("T_softmax_maxelem_normal_reduction_init"):
                T.reads([])
                T.writes([normal_reduce_temp0[0]])
                normal_reduce_temp0[0] = T.min_value("float32")
            # Each thread folds elements k = ax0_0 * 32 + ax0_1 for ax0_0 in 0..7.
            for ax0_0 in T.serial(0, 8):
                with T.block("T_softmax_maxelem_normal_reduction"):
                    i0_1 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
                    T.reads([A[i0_1, k], normal_reduce_temp0[0]])
                    T.writes([normal_reduce_temp0[0]])
                    normal_reduce_temp0[0] = T.max(normal_reduce_temp0[0], A[i0_1, k])
            with T.block("T_softmax_maxelem_cross_thread_reduction"):
                T.reads([normal_reduce_temp0[0]])
                T.writes([reduce_temp0[0]])
                # Attach the max combiner (identity: min float32) as the
                # "reduce_scope" attribute for the allreduce below.
                T.attr(
                    T.comm_reducer(lambda x, y: T.max(x, y), [T.min_value("float32")]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
                )
                # NOTE(review): arguments appear to be (value count, source,
                # predicate, destination, reducing thread var) - confirm
                # against the tvm_thread_allreduce builtin documentation.
                T.evaluate(
                    T.tvm_thread_allreduce(
                        T.uint32(1),
                        normal_reduce_temp0[0],
                        True,
                        reduce_temp0.data,
                        ax0_1,
                        dtype="handle",
                    ))
            with T.block("T_softmax_maxelem_write_back"):
                i0_2 = T.axis.spatial(256, i0)
                T.reads([reduce_temp0[0]])
                T.writes([T_softmax_maxelem_shared[i0_2]])
                T_softmax_maxelem_shared[i0_2] = reduce_temp0[0]
        # ---- Row sum of exp(A - max) ----
        for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
            with T.block("T_softmax_expsum_normal_reduction_init"):
                T.reads([])
                T.writes([normal_reduce_temp1[0]])
                normal_reduce_temp1[0] = T.float32(0)
            for ax0_0 in T.serial(0, 8):
                with T.block("T_softmax_expsum_normal_reduction"):
                    i0_3 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
                    T.reads([
                        A[i0_3, k],
                        T_softmax_maxelem_shared[i0_3],
                        normal_reduce_temp1[0],
                    ])
                    T.writes([normal_reduce_temp1[0]])
                    normal_reduce_temp1[0] = normal_reduce_temp1[0] + T.exp(
                        A[i0_3, k] - T_softmax_maxelem_shared[i0_3], dtype="float32")
            with T.block("T_softmax_expsum_cross_thread_reduction"):
                T.reads([normal_reduce_temp1[0]])
                T.writes([reduce_temp1[0]])
                # Sum combiner with identity 0.
                T.attr(
                    T.comm_reducer(lambda x_1, y_1: x_1 + y_1, [T.float32(0)]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
                )
                T.evaluate(
                    T.tvm_thread_allreduce(
                        T.uint32(1),
                        normal_reduce_temp1[0],
                        True,
                        reduce_temp1.data,
                        ax0_1,
                        dtype="handle",
                    ))
            with T.block("T_softmax_expsum_write_back"):
                i0_4 = T.axis.spatial(256, i0)
                T.reads([reduce_temp1[0]])
                T.writes([T_softmax_expsum_shared[i0_4]])
                T_softmax_expsum_shared[i0_4] = reduce_temp1[0]
        # ---- Normalization (spatial only; no reduction blocks needed) ----
        for i1_0 in T.serial(0, 8):
            for i1_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
                with T.block("T_softmax_norm"):
                    i0_5 = T.axis.spatial(256, i0)
                    i1 = T.axis.spatial(256, i1_0 * 32 + i1_1)
                    T.reads([
                        A[i0_5, i1],
                        T_softmax_maxelem_shared[i0_5],
                        T_softmax_expsum_shared[i0_5],
                    ])
                    T.writes([T_softmax_norm[i0_5, i1]])
                    T.block_attr({"axis": 1})
                    T_softmax_norm[i0_5, i1] = (T.exp(
                        A[i0_5, i1] - T_softmax_maxelem_shared[i0_5],
                        dtype="float32",
                    ) / T_softmax_expsum_shared[i0_5])
def lowered_single_reduction_loop_with_block_predicate(
    A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256), "float32"]
) -> None:
    # Softmax along the second axis of a 256x256 float32 buffer in which each
    # reduction is spelled out as four blocks inside a 512-wide threadIdx.x
    # loop: *_in_thread_init, *_in_thread (carrying the
    # `ax1_0 * 512 + ax1_1 < 256` predicate), *_cross_thread (a
    # T.tvm_thread_allreduce call) and *_write_back.
    # NOTE(review): presumably the expected output of lowering
    # single_reduction_loop_with_block_predicate - confirm against the test.
    #
    # A: input buffer.
    # T_softmax_norm: output buffer.
    T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    # One-element scope="local" temporaries: in_thread_* hold the per-thread
    # value, cross_thread_* receive the allreduce result.
    cross_thread_0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    in_thread_0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    cross_thread_1 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    in_thread_1 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    for i0 in T.serial(256):
        # ---- Row maximum ----
        for ax0, ax1_0 in T.grid(1, 1):
            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                # The init block has no predicate, so the accumulator is reset
                # regardless of the lane index.
                with T.block("T_softmax_maxelem_in_thread_init"):
                    T.reads()
                    T.writes(in_thread_0[0])
                    in_thread_0[0] = T.float32(-3.4028234663852886e38)
                # Only this block is predicated on the lane index.
                with T.block("T_softmax_maxelem_in_thread"):
                    i0_1 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax1_1)
                    T.where(ax1_0 * 512 + ax1_1 < 256)
                    T.reads(A[i0_1, k], in_thread_0[0])
                    T.writes(in_thread_0[0])
                    in_thread_0[0] = T.max(in_thread_0[0], A[i0_1, k])
                with T.block("T_softmax_maxelem_cross_thread"):
                    T.reads(in_thread_0[0])
                    T.writes(cross_thread_0[0])
                    # Max combiner; its identity is the same literal used by
                    # the init block above.
                    T.attr(
                        T.comm_reducer(lambda x, y: T.max(x, y),
                                       [T.float32(-3.4028234663852886e38)]),
                        "reduce_scope",
                        T.reinterpret(T.uint64(0), dtype="handle"),
                    )
                    T.evaluate(
                        T.tvm_thread_allreduce(
                            T.uint32(1),
                            in_thread_0[0],
                            True,
                            cross_thread_0.data,
                            ax1_1,
                            dtype="handle",
                        ))
                with T.block("T_softmax_maxelem_write_back"):
                    i0_2 = T.axis.spatial(256, i0)
                    T.reads(cross_thread_0[0])
                    T.writes(T_softmax_maxelem_shared[i0_2])
                    T_softmax_maxelem_shared[i0_2] = cross_thread_0[0]
        # ---- Row sum of exp(A - max) ----
        for ax0, ax1_0 in T.grid(1, 1):
            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_expsum_in_thread_init"):
                    T.reads()
                    T.writes(in_thread_1[0])
                    in_thread_1[0] = T.float32(0)
                with T.block("T_softmax_expsum_in_thread"):
                    i0_3 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax1_1)
                    T.where(ax1_0 * 512 + ax1_1 < 256)
                    T.reads(A[i0_3, k], T_softmax_maxelem_shared[i0_3], in_thread_1[0])
                    T.writes(in_thread_1[0])
                    in_thread_1[0] = in_thread_1[0] + T.exp(
                        A[i0_3, k] - T_softmax_maxelem_shared[i0_3], dtype="float32")
                with T.block("T_softmax_expsum_cross_thread"):
                    T.reads(in_thread_1[0])
                    T.writes(cross_thread_1[0])
                    # Sum combiner with identity 0.
                    T.attr(
                        T.comm_reducer(lambda x_1, y_1: x_1 + y_1, [T.float32(0)]),
                        "reduce_scope",
                        T.reinterpret(T.uint64(0), dtype="handle"),
                    )
                    T.evaluate(
                        T.tvm_thread_allreduce(
                            T.uint32(1),
                            in_thread_1[0],
                            True,
                            cross_thread_1.data,
                            ax1_1,
                            dtype="handle",
                        ))
                with T.block("T_softmax_expsum_write_back"):
                    i0_4 = T.axis.spatial(256, i0)
                    T.reads(cross_thread_1[0])
                    T.writes(T_softmax_expsum_shared[i0_4])
                    T_softmax_expsum_shared[i0_4] = cross_thread_1[0]
        # ---- Normalization (spatial block, still predicated) ----
        for i1_0 in T.serial(1):
            for i1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_norm"):
                    i0_5 = T.axis.spatial(256, i0)
                    i1 = T.axis.spatial(256, i1_1)
                    T.where(i1_0 * 512 + i1_1 < 256)
                    T.reads(A[i0_5, i1], T_softmax_maxelem_shared[i0_5],
                            T_softmax_expsum_shared[i0_5])
                    T.writes(T_softmax_norm[i0_5, i1])
                    T.block_attr({"axis": 1})
                    T_softmax_norm[i0_5, i1] = (
                        T.exp(A[i0_5, i1] - T_softmax_maxelem_shared[i0_5], dtype="float32") /
                        T_softmax_expsum_shared[i0_5])