def transformed_simple_compute(
    A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]
) -> None:
    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
        with T.block():
            T.reads([A[tx, 0:16]])
            T.writes([C[tx, 0:16]])
            B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared")
            with T.block():
                T.reads([A[tx, 0]])
                T.writes([B[0, tx, 0]])
                B[0, tx, 0] = A[tx, 0] * T.float32(2)
            with T.block():
                T.reads([A[tx, 1:16], B[0:2, tx, 0]])
                T.writes([B[0:2, tx, 0], C[tx, 0:15]])
                for i in T.serial(0, 15):
                    with T.block():
                        T.reads([A[tx, i + 1]])
                        T.writes([B[(i + 1) % 2, tx, 0]])
                        B[(i + 1) % 2, tx, 0] = A[tx, i + 1] * T.float32(2)
                    with T.block():
                        T.reads([B[i % 2, tx, 0]])
                        T.writes([C[tx, i]])
                        C[tx, i] = B[i % 2, tx, 0] + T.float32(1)
            with T.block():
                T.reads([B[1, tx, 0]])
                T.writes([C[tx, 15]])
                C[tx, 15] = B[1, tx, 0] + T.float32(1)

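# --- editor's sketch, not part of the original listing ---
# A pipelined kernel like transformed_simple_compute above is typically produced by
# annotating the serial loop of the un-pipelined kernel with
# "software_pipeline_stage"/"software_pipeline_order" (see
# simple_compute_conflicting_order below for the annotation shape) and running the
# InjectSoftwarePipeline pass. `simple_compute_mod` is an assumed IRModule that
# wraps the annotated source kernel.
import tvm

pipelined_mod = tvm.tir.transform.InjectSoftwarePipeline()(simple_compute_mod)
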
def conv2d_nhwc_transformed(
    Input: T.Buffer[(1, 224, 224, 3), "float32"],
    Weight: T.Buffer[(7, 7, 3, 64), "float32"],
    Conv2d_nhwc: T.Buffer[(1, 112, 112, 64), "float32"],
) -> None:
    PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32")
    for i0, i1, i2, i3 in T.grid(1, 230, 230, 3):
        with T.block("PadInput"):
            i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            T.reads(Input[i0_1, i1_1 - 3, i2_1 - 3, i3_1])
            T.writes(PadInput[i0_1, i1_1, i2_1, i3_1])
            PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
                i1_1 >= 3 and i1_1 < 227 and i2_1 >= 3 and i2_1 < 227,
                Input[i0_1, i1_1 - 3, i2_1 - 3, i3_1],
                T.float32(0),
                dtype="float32",
            )
    for ax0, ax1, ax2 in T.grid(12544, 64, 147):
        with T.block("conv2d_nhwc"):
            bv0, bv1, bv2 = T.axis.remap("SSR", [ax0, ax1, ax2])
            T.reads(
                PadInput[0, bv0 // 112 * 2 + bv2 // 21, bv0 % 112 * 2 + bv2 % 21 // 3, bv2 % 3],
                Weight[bv2 // 21, bv2 % 21 // 3, bv2 % 3, bv1],
            )
            T.writes(Conv2d_nhwc[0, bv0 // 112, bv0 % 112, bv1])
            with T.init():
                Conv2d_nhwc[0, bv0 // 112, bv0 % 112, bv1] = T.float32(0)
            Conv2d_nhwc[0, bv0 // 112, bv0 % 112, bv1] = (
                Conv2d_nhwc[0, bv0 // 112, bv0 % 112, bv1]
                + PadInput[0, bv0 // 112 * 2 + bv2 // 21, bv0 % 112 * 2 + bv2 % 21 // 3, bv2 % 3]
                * Weight[bv2 // 21, bv2 % 21 // 3, bv2 % 3, bv1]
            )

def square_sum_square_root_factor_one_2_rfactor(
    A: T.Buffer[(16, 256, 256), "float32"], D: T.Buffer[(16,), "float32"]
) -> None:
    C = T.alloc_buffer([16], dtype="float32")
    C_rf = T.alloc_buffer([16, 1], dtype="float32")
    for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 1, 65536):
        with T.block("C_rf"):
            b = T.axis.spatial(16, i0)
            i = T.axis.reduce(256, i1_i2_fused_inner // 256)
            j = T.axis.reduce(256, i1_i2_fused_inner % 256)
            vi1_i2_fused_outer = T.axis.spatial(1, i1_i2_fused_outer)
            with T.init():
                C_rf[b, vi1_i2_fused_outer] = T.float32(0)
            C_rf[b, vi1_i2_fused_outer] = (
                C_rf[b, vi1_i2_fused_outer] + A[b, i, j] * A[b, i, j]
            )
    for i0, i1_i2_fused_outer in T.grid(16, 1):
        with T.block("C"):
            b, vi1_i2_fused_outer = T.axis.remap("SR", [i0, i1_i2_fused_outer])
            with T.init():
                C[b] = T.float32(0)
            C[b] = C[b] + C_rf[b, vi1_i2_fused_outer]
    for i0_1 in T.serial(16):
        with T.block("D"):
            b_1 = T.axis.spatial(16, i0_1)
            D[b_1] = T.sqrt(C[b_1], dtype="float32")

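# --- editor's sketch, not part of the original listing ---
# An rfactor-ed kernel like the one above usually comes from Schedule.rfactor:
# fuse the two reduction loops, split off a unit outer loop, and factor it into
# the new C_rf buffer (factor_axis=1 places the factored axis last).
# `square_sum_square_root` is an assumed @T.prim_func for the un-factored workload.
from tvm import tir

sch = tir.Schedule(square_sum_square_root)
block_c = sch.get_block("C")
b, i, j = sch.get_loops(block_c)
fused = sch.fuse(i, j)
outer, inner = sch.split(fused, factors=[1, 65536])
sch.rfactor(outer, factor_axis=1)
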
def element_wise_storage_align(a: T.handle, c: T.handle) -> None:
    C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
    A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
    # body
    with T.block("root"):
        T.reads([])
        T.writes([])
        B = T.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1)
        for i0 in T.serial(0, 128):
            for ax1 in T.serial(0, 128):
                with T.block("B"):
                    vi = T.axis.S(128, i0)
                    vj = T.axis.S(128, ax1)
                    T.reads([A[vi, vj]])
                    T.writes([B[vi, vj]])
                    T.block_attr({"buffer_dim_align": [[0, 0, 128, 127]]})
                    B[vi, vj] = A[vi, vj] * T.float32(2)
            for i1 in T.serial(0, 128):
                with T.block("C"):
                    vi_1, vj_1 = T.axis.remap("SS", [i0, i1])
                    T.reads([B[vi_1, vj_1]])
                    T.writes([C[vi_1, vj_1]])
                    C[vi_1, vj_1] = B[vi_1, vj_1] + T.float32(1)

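# --- editor's sketch, not part of the original listing ---
# The {"buffer_dim_align": [[0, 0, 128, 127]]} annotation above is what
# Schedule.storage_align attaches: buffer index 0, axis 0, factor 128, offset 127.
# The input is the unaligned `element_wise` variant defined later in this file
# (assumed to be a decorated @T.prim_func).
from tvm import tir

sch = tir.Schedule(element_wise)
sch.storage_align(sch.get_block("B"), buffer_index=0, axis=0, factor=128, offset=127)
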
def inline_block_with_init(
    A: T.Buffer[(1, 512, 7, 7), "float32"],
    B: T.Buffer[(1, 512, 1, 1), "float32"],
) -> None:
    B_rf = T.alloc_buffer([1, 512, 1, 1, 49], dtype="float32")
    for i0, i1, i2, i3, i4, i5 in T.grid(1, 512, 1, 1, 49, 1):
        with T.block("tensor_rf"):
            vi4 = T.axis.spatial(49, i4)
            ax0 = T.axis.spatial(1, 0)
            ax1 = T.axis.spatial(512, i1)
            ax2 = T.axis.spatial(1, 0)
            ax3 = T.axis.spatial(1, 0)
            with T.init():
                B_rf[ax0, ax1, ax2, ax3, vi4] = T.float32(0)
            B_rf[ax0, ax1, ax2, ax3, vi4] = (
                B_rf[ax0, ax1, ax2, ax3, vi4]
                + A[ax0, ax1, ax2 * 7 + vi4 // 7, ax3 * 7 + vi4 % 7]
            )
    for i0, i1 in T.grid(1, 512):
        for ax0, ax1, ax2, ax3, ax4 in T.grid(49, 1, 1, 1, 1):
            with T.block("tensor"):
                vi4, ax0_1 = T.axis.remap("RS", [ax0, ax1])
                ax1_1 = T.axis.spatial(512, i1 + ax2)
                ax2_1, ax3_1 = T.axis.remap("SS", [ax3, ax4])
                with T.init():
                    B[ax0_1, ax1_1, ax2_1, ax3_1] = T.float32(0)
                B[ax0_1, ax1_1, ax2_1, ax3_1] = (
                    B[ax0_1, ax1_1, ax2_1, ax3_1]
                    + B_rf[ax0_1, ax1_1, ax2_1, ax3_1, vi4]
                )

def get_valid_counts(
    data: T.handle,
    valid_count: T.handle,
    out: T.handle,
    out_indices: T.handle,
    score_threshold: T.float32,
    id_index: T.int32,
    score_index: T.int32,
) -> None:
    data_buf = T.match_buffer(data, (1, 2500, 6), "float32")
    valid_count_buf = T.match_buffer(valid_count, (1,), "int32")
    out_buf = T.match_buffer(out, (1, 2500, 6), "float32")
    out_indices_buf = T.match_buffer(out_indices, (1, 2500), "int32")
    with T.block([1], "init") as [vi]:
        valid_count_buf[vi] = T.int32(0)
        with T.block([2500], "update") as [vj]:
            # The last axis has extent 6, so the accessed region is 0:6.
            T.reads([data_buf[vi, vj, 0:6]])
            T.writes([
                valid_count_buf[vi], out_indices_buf[vi, vj], out_buf[vi, vj, 0:6]
            ])
            if (data_buf[vi, vj, score_index] > score_threshold) and (
                (id_index < 0) or (data_buf[vi, vj, id_index] >= T.float32(0))
            ):
                for k in T.serial(0, 6):
                    out_buf[vi, valid_count_buf[vi], k] = data_buf[vi, vj, k]
                out_indices_buf[vi, valid_count_buf[vi]] = vj
                valid_count_buf[vi] = valid_count_buf[vi] + 1
            if vj >= valid_count_buf[vi]:
                for k in T.serial(0, 6):
                    out_buf[vi, vj, k] = T.float32(-1)
                out_indices_buf[vi, vj] = T.int32(-1)

def main(a: T.handle, b: T.handle) -> None:
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    A = T.match_buffer(a, (128, 128, 4), dtype="float32", scope="global.texture")
    B = T.alloc_buffer((128, 128, 4), dtype="float32", scope="global.texture")
    C = T.match_buffer(b, (128, 128, 4), dtype="float32", scope="global.texture")
    for block_idx in T.thread_binding(0, 128, thread="blockIdx.x"):
        for thread_idx in T.thread_binding(0, 128, thread="threadIdx.x"):
            for k in T.serial(4):
                with T.block("B"):
                    vb, vt, vk = T.axis.remap("SSS", [block_idx, thread_idx, k])
                    B[vb, vt, vk] = A[vb, vt, vk] + T.float32(1)
    for block_idx in T.thread_binding(0, 128, thread="blockIdx.x"):
        for thread_idx in T.thread_binding(0, 128, thread="threadIdx.x"):
            for k in T.serial(4):
                with T.block("C"):
                    vb, vt, vk = T.axis.remap("SSS", [block_idx, thread_idx, k])
                    C[vb, vt, vk] = B[vb, vt, vk] * T.float32(2)

def simple_compute_conflicting_order(
    A: T.Buffer[(16, 16), "float32"], D: T.Buffer[(16, 16), "float32"]
):
    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
        for i in T.serial(
            0,
            16,
            annotations={
                "software_pipeline_stage": [0, 1, 1],
                "software_pipeline_order": [0, 1, 1],
            },
        ):
            with T.block():
                T.reads(A[tx, i])
                T.writes(D[tx, i])
                B = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
                C = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
                with T.block():
                    T.reads(A[tx, i])
                    T.writes(B[tx, 0])
                    B[tx, 0] = A[tx, i] * T.float32(2)
                with T.block():
                    T.reads(B[tx, 0])
                    T.writes(C[tx, 0])
                    C[tx, 0] = B[tx, 0] + T.float32(2)
                with T.block():
                    T.reads(C[tx, 0])
                    T.writes(D[tx, i])
                    D[tx, i] = C[tx, 0] + T.float32(1)

def conv2d_nhwc(
    Input: T.Buffer[(1, 224, 224, 3), "float32"],
    Weight: T.Buffer[(7, 7, 3, 64), "float32"],
    Conv2d_nhwc: T.Buffer[(1, 112, 112, 64), "float32"],
) -> None:
    PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32")
    for i0, i1, i2, i3 in T.grid(1, 230, 230, 3):
        with T.block("PadInput"):
            i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
                i1_1 >= 3 and i1_1 < 227 and i2_1 >= 3 and i2_1 < 227,
                Input[i0_1, i1_1 - 3, i2_1 - 3, i3_1],
                T.float32(0),
                dtype="float32",
            )
    for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 112, 112, 64, 7, 7, 3):
        with T.block("conv2d_nhwc"):
            n, h, w, co, rh, rw, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6])
            with T.init():
                Conv2d_nhwc[n, h, w, co] = T.float32(0)
            Conv2d_nhwc[n, h, w, co] = Conv2d_nhwc[n, h, w, co] + (
                PadInput[n, h * 2 + rh, w * 2 + rw, T.floordiv(co, 64) * 3 + rc]
                * Weight[rh, rw, rc, co]
            )

def element_wise(a: T.handle, c: T.handle) -> None:
    C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
    A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
    # body
    with T.block("root"):
        T.reads([])
        T.writes([])
        B = T.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1)
        for i0 in T.serial(0, 128):
            for ax1 in T.serial(0, 128):
                with T.block("B"):
                    vi, vj = T.axis.remap("SS", [i0, ax1])
                    T.reads([A[vi, vj]])
                    T.writes([B[vi, vj]])
                    B[vi, vj] = A[vi, vj] * T.float32(2)
            for i1 in T.serial(0, 128):
                with T.block("C"):
                    vi_1, vj_1 = T.axis.remap("SS", [i0, i1])
                    T.reads([B[vi_1, vj_1]])
                    T.writes([C[vi_1, vj_1]])
                    C[vi_1, vj_1] = B[vi_1, vj_1] + T.float32(1)

def tiled_conv2d_with_padding(
    inputs: T.Buffer[(1, 224, 224, 3), "float32"],
    weight: T.Buffer[(7, 7, 3, 64), "float32"],
    conv2d_nhwc: T.Buffer[(1, 112, 112, 64), "float32"],
) -> None:
    PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32")
    for i0, i1, i2, i3 in T.grid(1, 230, 230, 3):
        with T.block("PadInput"):
            i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            T.reads(inputs[i0_1, i1_1 - 3, i2_1 - 3, i3_1])
            T.writes(PadInput[i0_1, i1_1, i2_1, i3_1])
            PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
                3 <= i1_1 and i1_1 < 227 and 3 <= i2_1 and i2_1 < 227,
                inputs[i0_1, i1_1 - 3, i2_1 - 3, i3_1],
                T.float32(0),
                dtype="float32",
            )
    for (
        i0_0, i1_0, i2_0, i3_0,
        i0_1_1, i1_1_1, i2_1_1, i3_1_1,
        i4_0, i5_0, i6_0,
        i0_2, i1_2, i2_2, i3_2,
        i4_1, i5_1, i6_1,
        i0_3, i1_3, i2_3, i3_3,
    ) in T.grid(1, 1, 4, 1, 1, 2, 4, 1, 7, 7, 1, 1, 1, 1, 1, 1, 1, 3, 1, 56, 7, 64):
        with T.block("conv2d_nhwc"):
            n = T.axis.spatial(1, 0)
            h = T.axis.spatial(112, i1_1_1 * 56 + i1_3)
            w = T.axis.spatial(112, i2_0 * 28 + i2_1_1 * 7 + i2_3)
            co, rh, rw, rc = T.axis.remap("SRRR", [i3_3, i4_0, i5_0, i6_1])
            T.reads(
                conv2d_nhwc[n, h, w, co],
                PadInput[n, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc],
                weight[rh, rw, rc, co],
            )
            T.writes(conv2d_nhwc[n, h, w, co])
            with T.init():
                conv2d_nhwc[n, h, w, co] = T.float32(0)
            conv2d_nhwc[n, h, w, co] = (
                conv2d_nhwc[n, h, w, co]
                + PadInput[n, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc]
                * weight[rh, rw, rc, co]
            )

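# --- editor's sketch, not part of the original listing ---
# The 22-deep loop nest above is a multi-level tiling of conv2d_nhwc (defined
# earlier): each original loop is split into tile levels and the levels are
# reordered. A hand-rolled two-level illustration on the h and w loops (tile
# sizes here are assumptions, not the ones used above):
from tvm import tir

sch = tir.Schedule(conv2d_nhwc)
n, h, w, co, rh, rw, rc = sch.get_loops(sch.get_block("conv2d_nhwc"))
h_o, h_i = sch.split(h, factors=[2, 56])
w_o, w_i = sch.split(w, factors=[16, 7])
sch.reorder(h_o, w_o, h_i, w_i)
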
def lowered_multiple_blocks_under_reduction_loop(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [16, 16, 16], dtype="float32")
    B = T.match_buffer(b, [16], dtype="float32")
    B_rf_local = T.alloc_buffer([16, 16], dtype="float32", scope="local")
    reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    normal_reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    for i in T.thread_binding(0, 16, thread="blockIdx.x"):
        for k0o in T.thread_binding(0, 4, thread="threadIdx.x"):
            with T.block("B_in_thread_init"):
                T.reads([])
                T.writes([normal_reduce_temp0[0]])
                normal_reduce_temp0[0] = T.float32(0)
            for k0i0, k1 in T.grid(4, 16):
                with T.block("B_rf"):
                    vk0 = T.axis.spatial(16, k0o * 4 + k0i0)
                    vi, vk1 = T.axis.remap("SR", [i, k1])
                    T.reads([B_rf_local[vk0, vi], A[vi, vk0, vk1]])
                    T.writes([B_rf_local[vk0, vi]])
                    with T.init():
                        B_rf_local[vk0, vi] = T.float32(0)
                    B_rf_local[vk0, vi] = B_rf_local[vk0, vi] + A[vi, vk0, vk1]
            for k0i1 in T.serial(0, 4):
                with T.block("B_normal_reduction"):
                    vk0 = T.axis.reduce(16, k0o * 4 + k0i1)
                    vi = T.axis.spatial(16, i)
                    T.reads([B_rf_local[vk0, vi], normal_reduce_temp0[0]])
                    T.writes([normal_reduce_temp0[0]])
                    normal_reduce_temp0[0] = normal_reduce_temp0[0] + B_rf_local[vk0, vi]
            with T.block("B_cross_thread_reduction"):
                T.reads([normal_reduce_temp0[0]])
                T.writes([reduce_temp0[0]])
                T.attr(
                    T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
                )
                T.evaluate(
                    T.tvm_thread_allreduce(
                        T.uint32(1),
                        normal_reduce_temp0[0],
                        True,
                        reduce_temp0.data,
                        k0o,
                        dtype="handle",
                    )
                )
            with T.block("B_write_back"):
                vi = T.axis.spatial(16, i)
                T.reads([reduce_temp0[0]])
                T.writes([B[vi]])
                B[vi] = reduce_temp0[0]

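# --- editor's sketch, not part of the original listing ---
# The allreduce structure above (normal_reduce_temp0 / reduce_temp0 /
# T.tvm_thread_allreduce) is what tvm.tir.transform.LowerCrossThreadReduction
# emits for a reduction loop bound to threadIdx.
# `multiple_blocks_under_reduction_loop_mod` is an assumed IRModule wrapping the
# un-lowered kernel.
import tvm

lowered = tvm.tir.transform.LowerCrossThreadReduction()(
    multiple_blocks_under_reduction_loop_mod
)
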
def main(
    X: T.Buffer[(1, 512, 56, 56), "float32"],
    W: T.Buffer[(512, 512, 3, 3), "float32"],
    B: T.Buffer[(512, 1, 1), "float32"],
    bn_scale: T.Buffer[(512, 1, 1), "float32"],
    bn_offset: T.Buffer[(512, 1, 1), "float32"],
    compute: T.Buffer[(1, 512, 56, 56), "float32"],
) -> None:
    compute_local = T.alloc_buffer([1, 512, 56, 56], dtype="float32", scope="local")
    for i0_0_i1_0_i2_0_i3_0_fused in T.thread_binding(224, thread="blockIdx.x"):
        for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(2, thread="vthread.x"):
            for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(8, thread="threadIdx.x"):
                for (
                    i4_0, i5_0, i6_0, i4_1, i5_1, i6_1,
                    i0_3, i1_3, i2_3, i3_3,
                    i4_2, i5_2, i6_2,
                    i0_4, i1_4, i2_4, i3_4,
                ) in T.grid(1, 3, 1, 32, 1, 1, 1, 1, 1, 1, 16, 1, 3, 1, 8, 2, 28):
                    with T.block("compute"):
                        nn = T.axis.spatial(1, 0)
                        ff = T.axis.spatial(
                            512,
                            i0_0_i1_0_i2_0_i3_0_fused // 14 * 32
                            + i0_2_i1_2_i2_2_i3_2_fused // 2 * 8
                            + i1_4,
                        )
                        yy = T.axis.spatial(
                            56,
                            i0_0_i1_0_i2_0_i3_0_fused // 2 % 7 * 8
                            + i0_1_i1_1_i2_1_i3_1_fused * 4
                            + i0_2_i1_2_i2_2_i3_2_fused % 2 * 2
                            + i2_4,
                        )
                        xx = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 2 * 28 + i3_4)
                        rc = T.axis.reduce(512, i4_1 * 16 + i4_2)
                        ry, rx = T.axis.remap("RR", [i5_0, i6_2])
                        with T.init():
                            compute_local[nn, ff, yy, xx] = T.float32(0)
                        compute_local[nn, ff, yy, xx] = compute_local[
                            nn, ff, yy, xx
                        ] + T.if_then_else(
                            yy + ry >= 1 and yy + ry < 57 and xx + rx >= 1 and xx + rx < 57,
                            X[nn, rc, yy + ry - 1, xx + rx - 1],
                            T.float32(0),
                            dtype="float32",
                        ) * W[ff, rc, ry, rx]
                for ax0, ax1, ax2, ax3 in T.grid(1, 8, 2, 28):
                    with T.block("compute_local"):
                        v0 = T.axis.spatial(1, ax0)
                        v1 = T.axis.spatial(
                            512,
                            i0_0_i1_0_i2_0_i3_0_fused // 14 * 32
                            + i0_2_i1_2_i2_2_i3_2_fused // 2 * 8
                            + ax1,
                        )
                        v2 = T.axis.spatial(
                            56,
                            i0_0_i1_0_i2_0_i3_0_fused % 14 // 2 * 8
                            + i0_1_i1_1_i2_1_i3_1_fused * 4
                            + i0_2_i1_2_i2_2_i3_2_fused % 2 * 2
                            + ax2,
                        )
                        v3 = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 2 * 28 + ax3)
                        compute[v0, v1, v2, v3] = T.max(
                            (compute_local[v0, v1, v2, v3] + B[v1, 0, 0]) * bn_scale[v1, 0, 0]
                            + bn_offset[v1, 0, 0],
                            T.float32(0),
                        )

def main(a: T.handle, b: T.handle) -> None:
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # var definition
    threadIdx_x = T.env_thread("threadIdx.x")
    threadIdx_y = T.env_thread("threadIdx.y")
    blockIdx_x = T.env_thread("blockIdx.x")
    blockIdx_y = T.env_thread("blockIdx.y")
    blockIdx_z = T.env_thread("blockIdx.z")
    A = T.match_buffer(a, [14 * 14 * 256 * 256], dtype="float32")
    B = T.match_buffer(b, [14 * 14 * 512 * 256], dtype="float32")
    # body
    T.launch_thread(blockIdx_z, 196)
    B_local = T.allocate([64], "float32", "local")
    Apad_shared = T.allocate([512000], "float32", "shared")
    Apad_shared_local = T.allocate([8], "float32", "local")
    T.launch_thread(blockIdx_y, 8)
    T.launch_thread(blockIdx_x, 4)
    T.launch_thread(threadIdx_y, 8)
    T.launch_thread(threadIdx_x, 8)
    for ff_c_init, nn_c_init in T.grid(8, 8):
        B_local[ff_c_init * 8 + nn_c_init] = T.float32(0)
    for rc_outer, ry, rx in T.grid(32, 3, 3):
        for ax3_inner_outer in T.serial(0, 2):
            Apad_shared[
                T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4)
            ] = T.if_then_else(
                1 <= blockIdx_z // 14 + ry
                and blockIdx_z // 14 + ry < 15
                and 1 <= rx + blockIdx_z % 14
                and rx + blockIdx_z % 14 < 15,
                A[
                    T.ramp(
                        ry * 917504
                        + blockIdx_z * 65536
                        + rx * 65536
                        + rc_outer * 2048
                        + threadIdx_y * 256
                        + blockIdx_x * 64
                        + threadIdx_x * 8
                        + ax3_inner_outer * 4
                        - 983040,
                        1,
                        4,
                    )
                ],
                T.broadcast(T.float32(0), 4),
                dtype="float32x4",
            )
        # Access of the last element of Apad_shared prevents buffer compacting
        # from reducing the amount of shared memory used.
        Apad_shared[512000 - 1] = 0.0
        for rc_inner in T.serial(0, 8):
            for ax3 in T.serial(0, 8):
                Apad_shared_local[ax3] = Apad_shared[rc_inner * 64 + threadIdx_x * 8 + ax3]
            for ff_c, nn_c in T.grid(8, 8):
                B_local[ff_c * 8 + nn_c] = B_local[ff_c * 8 + nn_c] + Apad_shared_local[nn_c]
    for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8):
        B[
            blockIdx_z * 131072
            + blockIdx_y * 16384
            + threadIdx_y * 2048
            + ff_inner_inner_inner * 256
            + blockIdx_x * 64
            + threadIdx_x * 8
            + nn_inner_inner_inner
        ] = B_local[ff_inner_inner_inner * 8 + nn_inner_inner_inner]

def func_with_opaque_block(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])
    with T.block([]) as []:
        with T.block([]) as []:
            B[0, 0] = A[0, 0] + T.float32(1)
        with T.block([128, 128]) as [vi, vj]:
            C[vi, vj] = B[vi, vj] + T.float32(1)

def elementwise_with_root(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])
    with T.block([]) as []:
        with T.block([128, 128]) as [vi, vj]:
            B[vi, vj] = A[vi, vj] + T.float32(1)
        with T.block([128, 128]) as [vi, vj]:
            C[vi, vj] = B[vi, vj] + T.float32(1)

def expected_after(A: T.Buffer[128, "float32"], B: T.Buffer[130, "float32"]):
    for i, j in T.grid(2, 65):
        if i * 65 + j >= 0 and i * 65 + j < 128:
            A[i * 65 + j] = T.float32(0)
    for i, j in T.grid(2, 65):
        B[i * 65 + j] = T.if_then_else(
            i * 65 + j >= 0 and i * 65 + j < 128,
            A[i * 65 + j],
            T.float32(0),
            dtype="float32",
        )

def func_with_opaque_block(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])
    with T.block():
        with T.block():
            B[0, 0] = A[0, 0] + T.float32(1)
        for i, j in T.grid(128, 128):
            with T.block():
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = B[vi, vj] + T.float32(1)

def nested_pipeline_double_buffer(
    A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16), "float32"]
):
    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
        for i in T.serial(
            0,
            16,
            annotations={
                "software_pipeline_stage": [0, 0, 0, 1, 1],
                "software_pipeline_order": [0, 2, 3, 1, 4],
            },
        ):
            with T.block():
                T.reads(A[tx, i, 0:16])
                T.writes(C[tx, i, 0:16])
                A_shared = T.alloc_buffer((16, 1, 16), dtype="float32", scope="shared")
                A_local = T.alloc_buffer((1, 1, 16), dtype="float32", scope="local")
                for j in T.serial(0, 16):
                    with T.block():
                        T.reads(A[tx, i, j])
                        T.writes(A_shared[tx, 0, j])
                        A_shared[tx, 0, j] = A[tx, i, j]
                for j in T.serial(0, 16):
                    with T.block():
                        T.block_attr({"double_buffer_scope": 0})
                        T.reads(A_shared[tx, 0, j])
                        T.writes(A_local[0, 0, j])
                        # A_shared's middle axis has extent 1, so index 0 is used here.
                        A_local[0, 0, j] = A_shared[tx, 0, j]
                for j in T.serial(
                    0,
                    16,
                    annotations={
                        "software_pipeline_stage": [0, 1],
                        "software_pipeline_order": [0, 1],
                    },
                ):
                    with T.block():
                        T.reads(A_local[0, 0, j])
                        T.writes(C[tx, i, j])
                        B = T.alloc_buffer((16, 1, 1), dtype="float32", scope="shared")
                        with T.block():
                            T.reads(A_local[0, 0, j])
                            T.writes(B[tx, 0, 0])
                            # B has shape (16, 1, 1); its trailing axes are indexed at 0.
                            B[tx, 0, 0] = A_local[0, 0, j] * T.float32(2)
                        with T.block():
                            T.reads(B[tx, 0, 0])
                            T.writes(C[tx, i, j])
                            C[tx, i, j] = B[tx, 0, 0] + T.float32(1)

def element_wise_set_scope(
    A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]
) -> None:
    B_shared = T.alloc_buffer([128, 128], dtype="float32", scope="shared")
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B_shared[vi, vj] = A[vi, vj] * T.float32(2)
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B_shared[vi, vj] + T.float32(1)

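# --- editor's sketch, not part of the original listing ---
# A scoped intermediate like B_shared above is produced by Schedule.set_scope on
# the write buffer of block "B". `element_wise_plain` stands in for an assumed
# variant whose B buffer still lives in the default "global" scope.
from tvm import tir

sch = tir.Schedule(element_wise_plain)
sch.set_scope(sch.get_block("B"), buffer_index=0, storage_scope="shared")
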
def func_1(A: T.Buffer[(16,), "float32"], C: T.Buffer[(1,), "float32"]):
    for i in T.serial(0, 16):
        with T.block():
            B = T.alloc_buffer((1,), dtype="float32")
            with T.block():
                B[0] = A[i] * T.float32(2)
            with T.block():
                C[0] = C[0] + A[i] + B[0] + T.float32(1)
                A[i] = B[0] + T.float32(1)

def main(
    var_X: T.handle,
    var_W: T.handle,
    var_B: T.handle,
    var_bn_scale: T.handle,
    var_bn_offset: T.handle,
    var_compute: T.handle,
) -> None:
    X = T.match_buffer(var_X, [1, 512, 56, 56], dtype="float32")
    W = T.match_buffer(var_W, [512, 512, 3, 3], dtype="float32")
    B = T.match_buffer(var_B, [512, 1, 1], dtype="float32")
    bn_scale = T.match_buffer(var_bn_scale, [512, 1, 1], dtype="float32")
    bn_offset = T.match_buffer(var_bn_offset, [512, 1, 1], dtype="float32")
    compute = T.match_buffer(var_compute, [1, 512, 56, 56], dtype="float32")
    pad_temp = T.alloc_buffer([1, 512, 58, 58], dtype="float32")
    compute_1 = T.alloc_buffer([1, 512, 56, 56], dtype="float32")
    bias_add = T.alloc_buffer([1, 512, 56, 56], dtype="float32")
    bn_mul = T.alloc_buffer([1, 512, 56, 56], dtype="float32")
    bn_add = T.alloc_buffer([1, 512, 56, 56], dtype="float32")
    for i0, i1, i2, i3 in T.grid(1, 512, 58, 58):
        with T.block("pad_temp"):
            i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            pad_temp[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
                i2_1 >= 1 and i2_1 < 57 and i3_1 >= 1 and i3_1 < 57,
                X[i0_1, i1_1, i2_1 - 1, i3_1 - 1],
                T.float32(0),
                dtype="float32",
            )
    for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 512, 56, 56, 512, 3, 3):
        with T.block("compute"):
            nn, ff, yy, xx, rc, ry, rx = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6])
            with T.init():
                compute_1[nn, ff, yy, xx] = T.float32(0)
            compute_1[nn, ff, yy, xx] = (
                compute_1[nn, ff, yy, xx]
                + pad_temp[nn, rc, yy + ry, xx + rx] * W[ff, rc, ry, rx]
            )
    for i0, i1, i2, i3 in T.grid(1, 512, 56, 56):
        with T.block("bias_add"):
            i, j, k, l = T.axis.remap("SSSS", [i0, i1, i2, i3])
            bias_add[i, j, k, l] = compute_1[i, j, k, l] + B[j, 0, 0]
    for i0, i1, i2, i3 in T.grid(1, 512, 56, 56):
        with T.block("bn_mul"):
            i, j, k, l = T.axis.remap("SSSS", [i0, i1, i2, i3])
            bn_mul[i, j, k, l] = bias_add[i, j, k, l] * bn_scale[j, 0, 0]
    for i0, i1, i2, i3 in T.grid(1, 512, 56, 56):
        with T.block("bn_add"):
            i, j, k, l = T.axis.remap("SSSS", [i0, i1, i2, i3])
            bn_add[i, j, k, l] = bn_mul[i, j, k, l] + bn_offset[j, 0, 0]
    for i0, i1, i2, i3 in T.grid(1, 512, 56, 56):
        with T.block("compute_1"):
            i0_2, i1_2, i2_2, i3_2 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            compute[i0_2, i1_2, i2_2, i3_2] = T.max(bn_add[i0_2, i1_2, i2_2, i3_2], T.float32(0))

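# --- editor's sketch, not part of the original listing ---
# The elementwise tail of this workload (bias_add, bn_mul, bn_add) is what
# fused schedules like the earlier `main` with compute_local fold into the final
# stage, e.g. by inlining each producer into its consumer.
# `conv_bn_relu_mod` is an assumed IRModule wrapping this prim func.
from tvm import tir

sch = tir.Schedule(conv_bn_relu_mod)
for name in ["bias_add", "bn_mul", "bn_add"]:
    sch.compute_inline(sch.get_block(name))
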
def element_wise_set_axis_separator(
    A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]
) -> None:
    B = T.alloc_buffer([128, 128], dtype="float32", axis_separators=[1])
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * T.float32(2)
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi, vj] + T.float32(1)

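# --- editor's sketch, not part of the original listing ---
# The axis_separators=[1] above keeps the 2-D allocation as two flattened axes
# (useful for 2-D memory such as textures). With the Schedule API this comes from
# Schedule.set_axis_separator; the ("write", 0) buffer selector and the use of the
# plain `two_elementwise` input (defined below) are assumptions here.
from tvm import tir

sch = tir.Schedule(two_elementwise)
sch.set_axis_separator(sch.get_block("B"), ("write", 0), axis_separators=[1])
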
def transformed_strided_buffer_func(
    A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]
) -> None:
    # body
    for i0 in T.serial(4):
        B = T.allocate([4, 17], "float32", "global")
        B_1 = T.buffer_decl([4, 16], dtype="float32", data=B.data, strides=[17, 1])
        for i1, j in T.grid(4, 16):
            B_1[i1, j] = A[i0 * 4 + i1, j] + T.float32(1)
        for i1, j in T.grid(4, 16):
            C[i0 * 4 + i1, j] = B_1[i1, j] * T.float32(2)

def single_reduction_loop_with_block_predicate(
    A: T.Buffer[(256, 256), "float32"],
    T_softmax_norm: T.Buffer[(256, 256), "float32"],
) -> None:
    T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    for i0 in T.serial(256):
        for ax0, ax1_0 in T.grid(1, 1):
            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_maxelem"):
                    i0_1 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax1_1)
                    T.where(ax1_0 * 512 + ax1_1 < 256)
                    T.reads(T_softmax_maxelem_shared[i0_1], A[i0_1, k])
                    T.writes(T_softmax_maxelem_shared[i0_1])
                    with T.init():
                        T_softmax_maxelem_shared[i0_1] = T.float32(-3.4028234663852886e38)
                    T_softmax_maxelem_shared[i0_1] = T.max(
                        T_softmax_maxelem_shared[i0_1], A[i0_1, k]
                    )
        for ax0, ax1_0 in T.grid(1, 1):
            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_expsum"):
                    i0_2 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax1_1)
                    T.where(ax1_0 * 512 + ax1_1 < 256)
                    T.reads(
                        T_softmax_expsum_shared[i0_2],
                        A[i0_2, k],
                        T_softmax_maxelem_shared[i0_2],
                    )
                    T.writes(T_softmax_expsum_shared[i0_2])
                    with T.init():
                        T_softmax_expsum_shared[i0_2] = T.float32(0)
                    T_softmax_expsum_shared[i0_2] = T_softmax_expsum_shared[i0_2] + T.exp(
                        A[i0_2, k] - T_softmax_maxelem_shared[i0_2], dtype="float32"
                    )
        for i1_0 in T.serial(1):
            for i1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_norm"):
                    i0_3 = T.axis.spatial(256, i0)
                    i1 = T.axis.spatial(256, i1_1)
                    T.where(i1_0 * 512 + i1_1 < 256)
                    T.reads(
                        A[i0_3, i1],
                        T_softmax_maxelem_shared[i0_3],
                        T_softmax_expsum_shared[i0_3],
                    )
                    T.writes(T_softmax_norm[i0_3, i1])
                    T.block_attr({"axis": 1})
                    T_softmax_norm[i0_3, i1] = (
                        T.exp(A[i0_3, i1] - T_softmax_maxelem_shared[i0_3], dtype="float32")
                        / T_softmax_expsum_shared[i0_3]
                    )

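# --- editor's sketch, not part of the original listing ---
# The T.where(... < 256) predicates above appear when a 256-extent loop is split
# by a 512-thread factor and the inner part is bound to threadIdx.x. The pattern,
# shown on an assumed un-scheduled `softmax` prim func with the same block names:
from tvm import tir

sch = tir.Schedule(softmax)
block = sch.get_block("T_softmax_maxelem")
_, k = sch.get_loops(block)
k_o, k_i = sch.split(k, factors=[None, 512])
sch.bind(k_i, "threadIdx.x")
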
def elementwise_with_root(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])
    with T.block():
        for i, j in T.grid(128, 128):
            with T.block():
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj] + T.float32(1)
        for i, j in T.grid(128, 128):
            with T.block():
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = B[vi, vj] + T.float32(1)

def func_with_part_access_region(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])
    with T.block([]) as []:
        with T.block([128, 128]) as [vi, vj]:
            T.reads(A[vi, vj])
            B[vi, vj] = A[vi, vj] + T.float32(1)
        with T.block([128, 128]) as [vi, vj]:
            T.writes(C[vi, vj])
            C[vi, vj] = B[vi, vj] + T.float32(1)

def two_elementwise(
    A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]
) -> None:
    B = T.alloc_buffer([128, 128], dtype="float32")
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A[vi, vj])
            T.writes(B[vi, vj])
            B[vi, vj] = A[vi, vj] * T.float32(2)
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(B[vi, vj])
            T.writes(C[vi, vj])
            C[vi, vj] = B[vi, vj] + T.float32(1)

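# --- editor's sketch, not part of the original listing ---
# two_elementwise is the canonical producer/consumer pair for compute_at-style
# tests: moving block "B" under the first loop of block "C" makes the producer
# compute only the row the consumer needs. A possible driver (loop structure as
# defined above; the function is assumed to be a decorated @T.prim_func):
from tvm import tir

sch = tir.Schedule(two_elementwise)
block_b = sch.get_block("B")
i, _ = sch.get_loops(sch.get_block("C"))
sch.compute_at(block_b, i, preserve_unit_loops=False)
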
def f(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, [M, N])
    B = T.match_buffer(b, [M, N])
    C = T.match_buffer(c, [M, N])
    with T.block():
        for i, j in T.grid(M, N):
            with T.block("s1"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj] + T.float32(1)
        for i, j in T.grid(M, N):
            with T.block("s2"):
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = B[vi, vj] + T.float32(1)

def buffer_opaque_access(b: T.handle, c: T.handle) -> None:
    B = T.match_buffer(b, [16, 16], "float32")
    C = T.match_buffer(c, [16, 16], "float32")
    with T.block([]):
        T.reads([])
        T.writes(B[0:16, 0:16])
        A = T.allocate([256], "float32", "global")
        for i, j in T.grid(16, 16):
            T.store(A, i * 16 + j, 1)
        for i in range(0, 16):
            for j in range(0, 16):
                T.evaluate(T.load("float32", A, i * 16 + j))
            for j in range(0, 16):
                T.evaluate(
                    T.tvm_fill_fragment(B.data, 16, 16, 16, 0, T.float32(0), dtype="handle")
                )
    for i, j in T.grid(16, 16):
        with T.block([16, 16]) as [vi, vj]:
            T.bind(vi, i)
            T.bind(vj, j)
            C[vi, vj] = B[vi, vj]