def multiple_reduction_blocks_rfactor(a: ty.handle, f: ty.handle) -> None:
    # Four chained reductions over A's last axis (C -> D -> E -> F), where the
    # first one has been split into two blocks:
    #   "C_rf": C_rf[ci, cj, vk1o] accumulates A[ci, cj, vk1o * 4 + vk1i]
    #           over the reduce axis vk1i;
    #   "C":    C[ci, cj] sums C_rf[ci, cj, :] over the reduce axis vk1o.
    # Each of D, E and F adds A[., ., k] plus the previous stage's value on
    # every reduction step.
    # NOTE(review): the name suggests this is the expected result of rfactor
    # on block "C" of multiple_reduction_blocks -- confirm against the test.
    A = tir.match_buffer(a, [16, 16, 16])
    C = tir.alloc_buffer([16, 16])
    D = tir.alloc_buffer([16, 16])
    E = tir.alloc_buffer([16, 16])
    F = tir.match_buffer(f, [16, 16])
    # Partial sums for block C; the last axis is indexed by the outer split vk1o.
    C_rf = tir.alloc_buffer([16, 16, 4])
    for i, j1, k1o, k1i in tir.grid(16, 16, 4, 4):
        with tir.block([4, 16, 16, tir.reduce_axis(0, 4)], "C_rf") as [vk1o, ci, cj, vk1i]:
            tir.bind(vk1o, k1o)
            tir.bind(ci, i)
            tir.bind(cj, j1)
            tir.bind(vk1i, k1i)
            with tir.init():
                C_rf[ci, cj, vk1o] = 0.0
            C_rf[ci, cj, vk1o] = C_rf[ci, cj, vk1o] + A[ci, cj, ((vk1o * 4) + vk1i)]
    for i_1 in tir.serial(0, 16):
        for j1_1 in tir.serial(0, 16):
            for k1o_1 in tir.serial(0, 4):
                # Final reduction of the partial sums into C.
                with tir.block([tir.reduce_axis(0, 4), 16, 16], "C") as [vk1o_1, ci_1, cj_1]:
                    tir.bind(vk1o_1, k1o_1)
                    tir.bind(ci_1, i_1)
                    tir.bind(cj_1, j1_1)
                    with tir.init():
                        C[ci_1, cj_1] = 0.0
                    C[ci_1, cj_1] = C[ci_1, cj_1] + C_rf[ci_1, cj_1, vk1o_1]
            for k2o, k2i in tir.grid(4, 4):
                with tir.block([16, 16, tir.reduce_axis(0, 16)], "D") as [di, dj, dk]:
                    tir.bind(di, i_1)
                    tir.bind(dj, j1_1)
                    tir.bind(dk, (k2o * 4) + k2i)
                    with tir.init():
                        D[di, dj] = 0.0
                    D[di, dj] = (D[di, dj] + A[di, dj, dk]) + C[di, dj]
        # E and F iterate over a separate j2 loop under the same row i_1.
        for j2 in tir.serial(0, 16):
            for k3o, k3i in tir.grid(4, 4):
                with tir.block([16, 16, tir.reduce_axis(0, 16)], "E") as [ei, ej, ek]:
                    tir.bind(ei, i_1)
                    tir.bind(ej, j2)
                    tir.bind(ek, (k3o * 4) + k3i)
                    with tir.init():
                        E[ei, ej] = 0.0
                    E[ei, ej] = (E[ei, ej] + A[ei, ej, ek]) + D[ei, ej]
            for k4o, k4i in tir.grid(4, 4):
                with tir.block([16, 16, tir.reduce_axis(0, 16)], "F") as [fi, fj, fk]:
                    tir.bind(fi, i_1)
                    tir.bind(fj, j2)
                    tir.bind(fk, (k4o * 4) + k4i)
                    with tir.init():
                        F[fi, fj] = 0.0
                    F[fi, fj] = (F[fi, fj] + A[fi, fj, fk]) + E[fi, fj]
def square_sum(a: ty.handle, c: ty.handle) -> None:
    # Single reduction block "C": for each b, C[b] is initialised to 0.0 and then
    # accumulates A[b, i, j] * A[b, i, j] over the two reduce axes i and j
    # (each ranging over [0, 256)).
    A = tir.match_buffer(a, [16, 256, 256])
    C = tir.match_buffer(c, [16])
    with tir.block([16, tir.reduce_axis(0, 256), tir.reduce_axis(0, 256)], "C") as [b, i, j]:
        with tir.init():
            C[b] = 0.0
        C[b] = C[b] + A[b, i, j] * A[b, i, j]
def main(a: ty.handle, b: ty.handle) -> None:
    # Reduction into B[i] over j in [0, 64) and k in [32, 64). There is no
    # tir.init(); the initialisation is instead an explicit `if` that fires at
    # the lower bound of both reduce axes (j == 0, k == 32). The accumulation
    # runs unconditionally.
    A = tir.match_buffer(a, [64, 64, 64])
    B = tir.match_buffer(b, [64])
    with tir.block([64, tir.reduce_axis(0, 64), tir.reduce_axis(32, 64)]) as [i, j, k]:
        if (j == 0) and (k == 32):
            B[i] = tir.float32(0)
        B[i] += A[i, j, k]
def main(a: ty.handle, b: ty.handle) -> None:
    # Same reduction as the explicit-if variant, but the block accesses B and A
    # through block-level match_buffer views:
    #   BB: the scalar (0-d) view of B[i];
    #   AA: the (64, 64) slice A[i, 0:64, 0:64].
    # The init is an explicit `if` at the first reduction point (j == 0, k == 32).
    A = tir.match_buffer(a, [64, 64, 64])
    B = tir.match_buffer(b, [64])
    with tir.block([64, tir.reduce_axis(0, 64), tir.reduce_axis(32, 64)]) as [i, j, k]:
        BB = tir.match_buffer(B[i], ())
        AA = tir.match_buffer(A[i, 0:64, 0:64], (64, 64))
        if (j == 0) and (k == 32):
            BB[()] = tir.float32(0)
        BB[()] += AA[j, k]
def transformed_func() -> None:
    # Buffers B, C and D are allocated inside the blocks that use them rather
    # than at function scope.
    # Structure: zero-fill A; then a 32x32x(reduce 32) block. At k == 0 that
    # block copies a 4x4 tile of A into B, then accumulates C, and then D * C,
    # into that tile inside nested opaque blocks with explicit reads/writes.
    # NOTE(review): the name suggests this is the expected output of a buffer
    # allocation-placement transform -- confirm against the test using it.
    A = tir.alloc_buffer([128, 128])
    with tir.block([128, 128], "") as [i, j]:
        A[i, j] = tir.float32(0)
    with tir.block([32, 32, tir.reduce_axis(0, 32)], "") as [i, j, k]:
        B = tir.alloc_buffer([128, 128])
        # Explicit initialisation on the first reduction step (no tir.init()).
        if k == 0:
            for ii, jj in tir.grid(4, 4):
                B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj]
        for ii, jj in tir.grid(4, 4):
            with tir.block([], ""):
                tir.reads([B[((i * 4) + ii), ((j * 4) + jj)]])
                tir.writes([B[((i * 4) + ii), ((j * 4) + jj)]])
                C = tir.alloc_buffer([128, 128])
                for kk in tir.serial(0, 4):
                    B[((i * 4) + ii), ((j * 4) + jj)] = (B[((i * 4) + ii), ((j * 4) + jj)] + C[((i * 4) + ii), ((k * 4) + kk)])
                # Nested here because it reads C, which is local to the enclosing block.
                for kk in tir.serial(0, 4):
                    with tir.block([], ""):
                        tir.reads([
                            B[((i * 4) + ii), ((j * 4) + jj)],
                            C[((i * 4) + ii), ((k * 4) + kk)],
                        ])
                        tir.writes([B[((i * 4) + ii), ((j * 4) + jj)]])
                        D = tir.alloc_buffer([128, 128])
                        B[((i * 4) + ii), ((j * 4) + jj)] = B[((i * 4) + ii), ((j * 4) + jj)] + (D[((j * 4) + jj), ((k * 4) + kk)] * C[((i * 4) + ii), ((k * 4) + kk)])
def rowsum_wrong_reduce_pattern2(a: ty.handle, b: ty.handle) -> None:
    # Row reduction whose init writes 0.0 but whose update SUBTRACTS:
    # B[vi] = B[vi] - A[vi, vk].
    # NOTE(review): the name suggests this is a deliberately unsupported
    # reduction pattern used as a negative test case -- do not "fix" it.
    A = tir.match_buffer(a, (128, 128))
    B = tir.match_buffer(b, (128, ))
    with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]:
        with tir.init():
            B[vi] = 0.0
        B[vi] = B[vi] - A[vi, vk]
def rowsum_zero_dim(a: ty.handle, b: ty.handle) -> None:
    # Full sum of the 1-D buffer A into the zero-dimensional buffer B,
    # addressed with the empty index B[()]. The block has only a reduce axis.
    A = tir.match_buffer(a, [128])
    B = tir.match_buffer(b, [])
    with tir.block([tir.reduce_axis(0, 128)], "B") as [k]:
        with tir.init():
            B[()] = 0.0
        B[()] = B[()] + A[k]
def multiple_reduction_blocks(a: ty.handle, f: ty.handle) -> None:
    # Four chained reduction blocks over A's last axis (split 4x4 each):
    #   C[i, j] = sum_k A[i, j, k]
    #   D, E, F: each update adds A[., ., k] plus the previous stage's value
    #   (C, D and E respectively).
    # C and D share the (i, j1) loops; E and F share a separate (i, j2) loop.
    A = tir.match_buffer(a, (16, 16, 16))
    C = tir.alloc_buffer((16, 16))
    D = tir.alloc_buffer((16, 16))
    E = tir.alloc_buffer((16, 16))
    F = tir.match_buffer(f, (16, 16))
    for i in tir.serial(0, 16):
        for j1 in tir.serial(0, 16):
            for k1o, k1i in tir.grid(4, 4):
                with tir.block([16, 16, tir.reduce_axis(0, 16)], "C") as [ci, cj, ck]:
                    tir.bind(ci, i)
                    tir.bind(cj, j1)
                    tir.bind(ck, k1o * 4 + k1i)
                    with tir.init():
                        C[ci, cj] = 0.0
                    C[ci, cj] = C[ci, cj] + A[ci, cj, ck]
            for k2o, k2i in tir.grid(4, 4):
                with tir.block([16, 16, tir.reduce_axis(0, 16)], "D") as [di, dj, dk]:
                    tir.bind(di, i)
                    tir.bind(dj, j1)
                    tir.bind(dk, k2o * 4 + k2i)
                    with tir.init():
                        D[di, dj] = 0.0
                    D[di, dj] = D[di, dj] + A[di, dj, dk] + C[di, dj]
        for j2 in tir.serial(0, 16):
            for k3o, k3i in tir.grid(4, 4):
                with tir.block([16, 16, tir.reduce_axis(0, 16)], "E") as [ei, ej, ek]:
                    tir.bind(ei, i)
                    tir.bind(ej, j2)
                    tir.bind(ek, k3o * 4 + k3i)
                    with tir.init():
                        E[ei, ej] = 0.0
                    E[ei, ej] = E[ei, ej] + A[ei, ej, ek] + D[ei, ej]
            for k4o, k4i in tir.grid(4, 4):
                with tir.block([16, 16, tir.reduce_axis(0, 16)], "F") as [fi, fj, fk]:
                    tir.bind(fi, i)
                    tir.bind(fj, j2)
                    tir.bind(fk, k4o * 4 + k4i)
                    with tir.init():
                        F[fi, fj] = 0.0
                    F[fi, fj] = F[fi, fj] + A[fi, fj, fk] + E[fi, fj]
def rowsum_not_dominant(a: ty.handle, b: ty.handle) -> None:
    # "Reduction" whose output B is 2-D and is indexed by the reduce iterator
    # vk itself (B[vi, vk]), so each reduction step writes a different element.
    # NOTE(review): the name suggests this is a negative fixture for the
    # dominant-block/reduction checks -- confirm before changing.
    A = tir.match_buffer(a, (128, 128))
    B = tir.match_buffer(b, (128, 128))
    with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]:
        with tir.init():
            B[vi, vk] = 0.0
        B[vi, vk] = B[vi, vk] + A[vi, vk]
def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # 128x128 matrix product with B accessed as B[vj, vk], i.e.
    # C[vi, vj] = sum_vk A[vi, vk] * B[vj, vk] (A times B-transposed).
    # The block is named "update" and initialises C with tir.init().
    A = tir.match_buffer(a, [128, 128])
    B = tir.match_buffer(b, [128, 128])
    C = tir.match_buffer(c, [128, 128])
    with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
        with tir.init():
            C[vi, vj] = tir.float32(0)
        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
def tir_matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # Unnamed 128x128 matmul block using augmented assignment:
    # C[i, j] += A[i, k] * B[j, k] (B accessed transposed), with C zeroed in init.
    A = tir.match_buffer(a, (128, 128))
    B = tir.match_buffer(b, (128, 128))
    C = tir.match_buffer(c, (128, 128))
    with tir.block([128, 128, tir.reduce_axis(0, 128)]) as [i, j, k]:
        with tir.init():
            C[i, j] = 0.0
        C[i, j] += A[i, k] * B[j, k]
def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # Matmul written with a separate "init" block instead of tir.init(). Inside
    # the (i, j) loops, "init" zeroes C[vi, vj]. A k loop then runs the
    # "update" block, which accumulates A[vi, vk] * B[vj, vk].
    # The block iterators are bound implicitly to the enclosing loops.
    A = tir.match_buffer(a, [128, 128])
    B = tir.match_buffer(b, [128, 128])
    C = tir.match_buffer(c, [128, 128])
    for i, j in tir.grid(128, 128):
        with tir.block([128, 128], "init") as [vi, vj]:
            C[vi, vj] = tir.float32(0)
        for k in range(0, 128):
            with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
def tir_conv2d(a: ty.handle, w: ty.handle, b: ty.handle) -> None:
    # 3x3 convolution with 1-pixel zero padding.
    # "Apad": Apad[nn, cc, yy, xx] = A[nn, cc, yy - 1, xx - 1] when
    #         1 <= yy <= 14 and 1 <= xx <= 14, else 0.0.
    # "B":    B[nn, ff, yy, xx] = sum over rc, ry, rx of
    #         Apad[nn, rc, yy + ry, xx + rx] * W[rc, ry, rx, ff].
    # W is indexed as [rc, ry, rx, ff] (input channel first, output channel last).
    A = tir.match_buffer(a, [16, 16, 14, 14])
    W = tir.match_buffer(w, [16, 3, 3, 32])
    B = tir.match_buffer(b, [16, 32, 14, 14])
    Apad = tir.alloc_buffer([16, 16, 16, 16])
    with tir.block([16, 16, 16, 16], "Apad") as [nn, cc, yy, xx]:
        Apad[nn, cc, yy, xx] = tir.if_then_else(
            yy >= 1 and yy - 1 < 14 and xx >= 1 and xx - 1 < 14,
            A[nn, cc, yy - 1, xx - 1],
            0.0,
            dtype="float32",
        )
    with tir.block(
        [16, 32, 14, 14, tir.reduce_axis(0, 16), tir.reduce_axis(0, 3), tir.reduce_axis(0, 3)], "B"
    ) as [nn, ff, yy, xx, rc, ry, rx]:
        with tir.init():
            B[nn, ff, yy, xx] = 0.0
        B[nn, ff, yy, xx] += Apad[nn, rc, yy + ry, xx + rx] * W[rc, ry, rx, ff]
def matmul_not_same_buffer_access(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # The init writes C[vi, vj], but the update reads and writes the transposed
    # element C[vj, vi], so init and update touch different buffer indices.
    # NOTE(review): the name indicates this mismatch is intentional (negative
    # fixture) -- do not "fix" the indices.
    A = tir.match_buffer(a, (128, 128))
    B = tir.match_buffer(b, (128, 128))
    C = tir.match_buffer(c, (128, 128))
    with tir.block([128, 128, tir.reduce_axis(0, 128)], "C") as [vi, vj, vk]:
        with tir.init():
            C[vi, vj] = 0.0
        C[vj, vi] = C[vj, vi] + A[vi, vk] * B[vk, vj]
def matmul_m_128(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # Matmul with a symbolic leading extent m (an int32 tir.var) and a fixed
    # reduction length of 128: C[vi, vj] = sum_vk A[vi, vk] * B[vj, vk].
    m = tir.var("int32")
    A = tir.match_buffer(a, [m, 128])
    B = tir.match_buffer(b, [m, 128])
    C = tir.match_buffer(c, [m, m])
    with tir.block([m, m, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
        with tir.init():
            C[vi, vj] = 0.0
        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
def rowsum_transformed(a: ty.handle, b: ty.handle) -> None:
    # Row sum B[vi] = sum_vk A[vi, vk] under a transformed loop nest:
    #   vi = io * 4 + floordiv(ii_ko_fused, 32)
    #   vk = floormod(ii_ko_fused, 32) * 4 + ki
    # NOTE(review): the loop names suggest a split of i and k followed by a fuse
    # of ii with ko -- confirm against the schedule that produces it.
    A = tir.match_buffer(a, (128, 128))
    B = tir.match_buffer(b, (128, ))
    for io, ii_ko_fused, ki in tir.grid(32, 128, 4):
        with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]:
            tir.bind(vi, io * 4 + tir.floordiv(ii_ko_fused, 32))
            tir.bind(vk, tir.floormod(ii_ko_fused, 32) * 4 + ki)
            with tir.init():
                B[vi] = 0.0
            B[vi] = B[vi] + A[vi, vk]
def batch_matmul(  # pylint: disable=no-self-argument
        a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # Batched matmul over 16 batches:
    # C[vn, vi, vj] = sum_vk A[vn, vi, vk] * B[vn, vj, vk] (B accessed transposed).
    # The function attributes set the global symbol and tir.noalias = True.
    tir.func_attr({"global_symbol": "batch_matmul", "tir.noalias": True})
    A = tir.match_buffer(a, [16, 128, 128])
    B = tir.match_buffer(b, [16, 128, 128])
    C = tir.match_buffer(c, [16, 128, 128])
    with tir.block([16, 128, 128, tir.reduce_axis(0, 128)], "update") as [vn, vi, vj, vk]:
        with tir.init():
            C[vn, vi, vj] = 0.0
        C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk]
def matmul(  # pylint: disable=no-self-argument
        a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # 1024x1024 float32 matmul in the standard layout:
    # C[vi, vj] = sum_vk A[vi, vk] * B[vk, vj].
    # The function attributes set the global symbol and tir.noalias = True.
    tir.func_attr({"global_symbol": "matmul", "tir.noalias": True})
    A = tir.match_buffer(a, (1024, 1024), "float32")
    B = tir.match_buffer(b, (1024, 1024), "float32")
    C = tir.match_buffer(c, (1024, 1024), "float32")
    with tir.block([1024, 1024, tir.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]:
        with tir.init():
            C[vi, vj] = 0.0
        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
def rowsum_not_quasi_affine(a: ty.handle, b: ty.handle) -> None:
    # Row sum whose reduce iterator is bound to floordiv(k * k, 2), a quadratic
    # (non-affine) function of the loop variable k.
    # NOTE(review): the name indicates this is a negative fixture for binding
    # checks -- keep the non-affine binding as is.
    A = tir.match_buffer(a, (128, 128))
    B = tir.match_buffer(b, (128, ))
    for i, k in tir.grid(128, 16):
        with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]:
            tir.bind(vi, i)
            tir.bind(vk, tir.floordiv(k * k, 2))
            with tir.init():
                B[vi] = 0.0
            B[vi] = B[vi] + A[vi, vk]
def cuda_matmul_1(a: ty.handle, b: ty.handle, c: ty.handle) -> None:  # pylint: disable=undefined-loop-variable
    # 2048x2048 float32 matmul laid out for GPU thread binding:
    #   1. copy A/B into "shared"-scope buffers, then into "local"-scope buffers;
    #   2. under blockIdx.y/x, vthread.y/x and threadIdx.y/x loops, accumulate
    #      C_local[vi, vj] += A_shared_local[vk, vi] * B_shared_local[vk, vj]
    #      (note: A is indexed [vk, vi]);
    #   3. write each 4x4 tile of C_local back to C.
    # Each thread's tile origin is (by*64 + vy*32 + ty*4, bx*64 + vx*32 + tx*4).
    A = tir.match_buffer(a, [2048, 2048], "float32")
    B = tir.match_buffer(b, [2048, 2048], "float32")
    C = tir.match_buffer(c, [2048, 2048], "float32")
    A_shared = tir.alloc_buffer([2048, 2048], "float32", scope="shared")
    B_shared = tir.alloc_buffer([2048, 2048], "float32", scope="shared")
    A_shared_local = tir.alloc_buffer([2048, 2048], "float32", scope="local")
    B_shared_local = tir.alloc_buffer([2048, 2048], "float32", scope="local")
    C_local = tir.alloc_buffer([2048, 2048], "float32", scope="local")
    with tir.block([2048, 2048], "A_shared") as [v0, v1]:
        A_shared[v0, v1] = A[v0, v1]
    with tir.block([2048, 2048], "B_shared") as [v0, v1]:
        B_shared[v0, v1] = B[v0, v1]
    with tir.block([2048, 2048], "A_shared_local") as [v0, v1]:
        A_shared_local[v0, v1] = A_shared[v0, v1]
    with tir.block([2048, 2048], "B_shared_local") as [v0, v1]:
        B_shared_local[v0, v1] = B_shared[v0, v1]
    for by in tir.thread_binding(0, 32, thread="blockIdx.y"):
        for bx in tir.thread_binding(0, 32, thread="blockIdx.x"):
            for vy in tir.thread_binding(0, 2, thread="vthread.y"):
                for vx in tir.thread_binding(0, 2, thread="vthread.x"):
                    # The loop variable `ty` shadows the `ty` name used in the signature annotations.
                    for ty in tir.thread_binding(0, 8, thread="threadIdx.y"):
                        for tx in tir.thread_binding(0, 8, thread="threadIdx.x"):
                            # Reduction index vk = k_0 * 8 + k_1, with k_1 unrolled.
                            for k_0 in tir.serial(0, 256):
                                for k_1 in tir.unroll(0, 8):
                                    for _, i, j in tir.grid(1, 4, 4):
                                        with tir.block([
                                            2048, 2048, tir.reduce_axis(0, 2048)
                                        ], "C") as [vi, vj, vk]:
                                            tir.bind(
                                                vi, by * 64 + vy * 32 + ty * 4 + i)
                                            tir.bind(
                                                vj, bx * 64 + vx * 32 + tx * 4 + j)
                                            tir.bind(vk, k_0 * 8 + k_1)
                                            with tir.init():
                                                C_local[vi, vj] = 0.0
                                            C_local[vi, vj] = C_local[
                                                vi, vj] + A_shared_local[
                                                vk, vi] * B_shared_local[
                                                vk, vj]
                            # Write-back of this thread's 4x4 tile after the reduction loop.
                            for i, j in tir.grid(4, 4):
                                with tir.block([2048, 2048], "C_local") as [vi, vj]:
                                    tir.bind(vi, by * 64 + vy * 32 + ty * 4 + i)
                                    tir.bind(vj, bx * 64 + vx * 32 + tx * 4 + j)
                                    C[vi, vj] = C_local[vi, vj]
def rowsum_not_serial(a: ty.handle, b: ty.handle) -> None:
    # Row sum whose reduction loop k is a tir.parallel loop rather than a
    # serial one.
    # NOTE(review): the name indicates this is a negative fixture for
    # reduction-loop checks -- keep the parallel loop.
    A = tir.match_buffer(a, (128, 128))
    B = tir.match_buffer(b, (128, ))
    for i in tir.serial(0, 128):
        for k in tir.parallel(0, 128):
            with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]:
                tir.bind(vi, i)
                tir.bind(vk, k)
                with tir.init():
                    B[vi] = 0.0
                B[vi] = B[vi] + A[vi, vk]
def matmul_rfactor(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # Matmul (A times B-transposed) with its reduction split three ways,
    # k = vi2_outer * 32 + vi2_inner_outer * 4 + vi2_inner_inner.
    # The innermost factor is turned into a spatial axis of C_rf:
    #   "update_rf": C_rf[vi2_inner_inner, vi, vj] accumulates over vi2_outer
    #                and vi2_inner_outer;
    #   "update":    C[vi, vj] sums C_rf[:, vi, vj] over the first axis.
    # NOTE(review): presumably the expected output of rfactor on the
    # i2_inner_inner loop -- confirm against the test.
    A = tir.match_buffer(a, [128, 128])
    B = tir.match_buffer(b, [128, 128])
    C = tir.match_buffer(c, [128, 128])
    C_rf = tir.alloc_buffer([4, 128, 128])
    for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in tir.grid(
            128, 128, 4, 8, 4):
        with tir.block(
            [4, 128, 128, tir.reduce_axis(0, 4), tir.reduce_axis(0, 8)],
                "update_rf") as [
            vi2_inner_inner, vi, vj, vi2_outer, vi2_inner_outer
        ]:
            tir.bind(vi2_inner_inner, i2_inner_inner)
            tir.bind(vi, i0)
            tir.bind(vj, i1)
            tir.bind(vi2_outer, i2_outer)
            tir.bind(vi2_inner_outer, i2_inner_outer)
            with tir.init():
                C_rf[vi2_inner_inner, vi, vj] = 0.0
            C_rf[vi2_inner_inner, vi, vj] = C_rf[vi2_inner_inner, vi, vj] + (A[vi, (
                ((vi2_outer * 32) + (vi2_inner_outer * 4)) + vi2_inner_inner)] * B[vj, (
                    ((vi2_outer * 32) + (vi2_inner_outer * 4)) + vi2_inner_inner)])
    for i0_1, i1_1, i2_inner_inner_1 in tir.grid(128, 128, 4):
        with tir.block([tir.reduce_axis(0, 4), 128, 128], "update") as [
            vi2_inner_inner_1,
            vi_1,
            vj_1,
        ]:
            tir.bind(vi2_inner_inner_1, i2_inner_inner_1)
            tir.bind(vi_1, i0_1)
            tir.bind(vj_1, i1_1)
            with tir.init():
                C[vi_1, vj_1] = 0.0
            C[vi_1, vj_1] = C[vi_1, vj_1] + C_rf[vi2_inner_inner_1, vi_1, vj_1]
def square_sum_rfactor(a: ty.handle, c: ty.handle) -> None:
    # square_sum with the j reduction turned into a spatial axis of C_rf:
    #   "C_rf": C_rf[b, vi2] = sum_i A[b, i, vi2] * A[b, i, vi2];
    #   "C":    C[b] = sum_vi2 C_rf[b, vi2].
    # NOTE(review): presumably the expected output of rfactor on the i2 loop of
    # square_sum -- confirm against the test.
    A = tir.match_buffer(a, [16, 256, 256])
    C = tir.match_buffer(c, [16])
    C_rf = tir.alloc_buffer([16, 256])
    for i0, i1, i2 in tir.grid(16, 256, 256):
        with tir.block([256, 16, tir.reduce_axis(0, 256)], "C_rf") as [vi2, b, i]:
            tir.bind(vi2, i2)
            tir.bind(b, i0)
            tir.bind(i, i1)
            with tir.init():
                C_rf[b, vi2] = 0.0
            C_rf[b, vi2] = C_rf[b, vi2] + (A[b, i, vi2] * A[b, i, vi2])
    for i0_1, i2_1 in tir.grid(16, 256):
        with tir.block([tir.reduce_axis(0, 256), 16], "C") as [vi2_1, b_1]:
            tir.bind(vi2_1, i2_1)
            tir.bind(b_1, i0_1)
            with tir.init():
                C[b_1] = 0.0
            C[b_1] = C[b_1] + C_rf[b_1, vi2_1]
def factorized_after_reverse_compute_at(a: ty.handle, b: ty.handle) -> None:
    # Two-stage reduction under GPU thread bindings (j -> blockIdx.x,
    # i_o -> threadIdx.x):
    #   "B_rf": B_rf_local[vi, vj] = sum_vk A[vj, vi, vk], computed for
    #           vi = i_o * 4 + i_i;
    #   "B":    B[j] accumulates B_rf_local[i_o * 4 + k, j] for k in [0, 4),
    #           placed inside the i_o loop, so each i_o iteration adds 4 of
    #           the 16 partial sums.
    # B_rf_local is allocated with scope="local".
    # NOTE(review): the name suggests this is the expected result of rfactor
    # followed by reverse_compute_at -- confirm against the test.
    A = tir.match_buffer(a, [16, 16, 16], "float32")
    B = tir.match_buffer(b, [16], "float32")
    B_rf_local = tir.alloc_buffer([16, 16], "float32", scope="local")
    for j in tir.thread_binding(0, 16, thread="blockIdx.x"):
        for i_o in tir.thread_binding(0, 4, thread="threadIdx.x"):
            for i_i, k in tir.grid(4, 16):
                with tir.block([16, 16, tir.reduce_axis(0, 16)], "B_rf") as [vi, vj, vk]:
                    tir.bind(vi, i_o * 4 + i_i)
                    tir.bind(vj, j)
                    tir.bind(vk, k)
                    with tir.init():
                        B_rf_local[vi, vj] = 0.0
                    B_rf_local[vi, vj] = B_rf_local[vi, vj] + A[vj, vi, vk]
            for k in tir.serial(0, 4):
                with tir.block([16, tir.reduce_axis(0, 16)], "B") as [vi, vk]:
                    tir.bind(vi, j)
                    tir.bind(vk, i_o * 4 + k)
                    with tir.init():
                        B[vi] = 0.0
                    B[vi] = B[vi] + B_rf_local[vk, vi]
def matmul_not_stage_pipeline(a: ty.handle, b: ty.handle, d: ty.handle) -> None:
    # Block "C" iterates over only a 128x128 domain (reduce length 128) of the
    # 256x256 buffer C. Block "D" then copies all 256x256 elements of C into D,
    # so D reads elements that block "C" never writes.
    # NOTE(review): the name indicates this is intentionally not a valid stage
    # pipeline (negative fixture) -- keep the mismatched extents.
    A = tir.match_buffer(a, [256, 256])
    B = tir.match_buffer(b, [256, 256])
    D = tir.match_buffer(d, [256, 256])
    C = tir.alloc_buffer([256, 256])
    with tir.block([128, 128, tir.reduce_axis(0, 128)], "C") as [vi, vj, vk]:
        with tir.init():
            C[vi, vj] = 0.0
        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
    with tir.block([256, 256], "D") as [vi, vj]:
        D[vi, vj] = C[vi, vj]
def transformed_square_sum_square_root(a: ty.handle, d: ty.handle) -> None:
    # D[b] = sqrt(sum over i, j of A[b, i, j]^2), where the i/j reduction loops
    # have been fused into one 65536-long loop split into (65536, 1):
    #   i = floordiv(i1_i2_fused_outer, 256), j = floormod(i1_i2_fused_outer, 256).
    # Both blocks carry explicit tir.reads/tir.writes annotations.
    A = tir.match_buffer(a, [16, 256, 256])
    D = tir.match_buffer(d, [16])
    C = tir.alloc_buffer([16])
    for i0, i1_i2_fused_outer, i1_i2_fused_inner in tir.grid(16, 65536, 1):
        with tir.block(
            [16, tir.reduce_axis(0, 256), tir.reduce_axis(0, 256)], "C") as [b, i, j]:
            tir.bind(b, i0)
            tir.bind(i, tir.floordiv(i1_i2_fused_outer, 256))
            tir.bind(j, tir.floormod(i1_i2_fused_outer, 256))
            tir.reads([C[b], A[b, i, j]])
            tir.writes([C[b]])
            with tir.init():
                C[b] = 0.0
            C[b] = C[b] + (A[b, i, j] * A[b, i, j])
    for i0_1 in tir.serial(0, 16):
        with tir.block([16], "D") as [b_1]:
            tir.bind(b_1, i0_1)
            tir.reads([C[b_1]])
            tir.writes([D[b_1]])
            D[b_1] = tir.sqrt(C[b_1], dtype="float32")
def rowsum_zero_dim_rfactor(a: ty.handle, b: ty.handle) -> None:
    # rowsum_zero_dim with its single reduce axis turned into the spatial axis
    # of B_rf:
    #   "B_rf": B_rf[vi0] is initialised to 0.0 and then has A[vi0] added to it;
    #   "B":    the 0-d output B[()] sums B_rf over its reduce axis.
    # NOTE(review): presumably the expected output of rfactor on rowsum_zero_dim
    # -- confirm against the test.
    A = tir.match_buffer(a, [128])
    B = tir.match_buffer(b, [])
    B_rf = tir.alloc_buffer([128])
    with tir.block([128], "B_rf") as [vi0]:
        with tir.init():
            B_rf[vi0] = 0.0
        B_rf[vi0] = B_rf[vi0] + A[vi0]
    with tir.block([tir.reduce_axis(0, 128)], "B") as [vi0_1]:
        with tir.init():
            B[()] = 0.0
        B[()] = B[()] + B_rf[vi0_1]
def square_sum_square_root_rfactor(a: ty.handle, d: ty.handle) -> None:
    # The fused square-sum-then-sqrt pipeline with the extent-1 inner fused loop
    # (i1_i2_fused_inner) factored out:
    #   "C_rf": C_rf[0, b] accumulates A[b, i, j]^2 (C_rf's first axis has extent 1);
    #   "C":    C[b] sums C_rf over that extent-1 axis;
    #   "D":    D[b] = sqrt(C[b]).
    # i = floordiv(i1_i2_fused_outer, 256), j = floormod(i1_i2_fused_outer, 256).
    A = tir.match_buffer(a, [16, 256, 256])
    D = tir.match_buffer(d, [16])
    C = tir.alloc_buffer([16])
    C_rf = tir.alloc_buffer([1, 16])
    for i0, i1_i2_fused_outer, i1_i2_fused_inner in tir.grid(16, 65536, 1):
        with tir.block(
            [1, 16, tir.reduce_axis(0, 256), tir.reduce_axis(0, 256)], "C_rf") as [
            vi1_i2_fused_inner,
            b,
            i,
            j,
        ]:
            tir.bind(vi1_i2_fused_inner, i1_i2_fused_inner)
            tir.bind(b, i0)
            tir.bind(i, tir.floordiv(i1_i2_fused_outer, 256))
            tir.bind(j, tir.floormod(i1_i2_fused_outer, 256))
            with tir.init():
                C_rf[vi1_i2_fused_inner, b] = 0.0
            C_rf[vi1_i2_fused_inner, b] = C_rf[vi1_i2_fused_inner, b] + (A[b, i, j] * A[b, i, j])
    for i0_1, i1_i2_fused_inner_1 in tir.grid(16, 1):
        with tir.block([tir.reduce_axis(0, 1), 16], "C") as [vi1_i2_fused_inner_1, b_1]:
            tir.bind(vi1_i2_fused_inner_1, i1_i2_fused_inner_1)
            tir.bind(b_1, i0_1)
            with tir.init():
                C[b_1] = 0.0
            C[b_1] = C[b_1] + C_rf[vi1_i2_fused_inner_1, b_1]
    for i0_2 in tir.serial(0, 16):
        with tir.block([16], "D") as [b_2]:
            tir.bind(b_2, i0_2)
            D[b_2] = tir.sqrt(C[b_2], dtype="float32")
def matmul_loop_multiple_children(a: ty.handle, b: ty.handle, c: ty.handle, d: ty.handle) -> None:
    # One (k, i, j) loop nest, with the reduction loop k outermost, whose
    # innermost body holds two sibling reduction blocks:
    #   "C": C[ci, cj] += A[ci, ck] * B[ck, cj]
    #   "D": D[di, dj] += B[di, dk] * A[dk, dj]   (A and B swapped)
    A = tir.match_buffer(a, [128, 128])
    B = tir.match_buffer(b, [128, 128])
    C = tir.match_buffer(c, [128, 128])
    D = tir.match_buffer(d, [128, 128])
    for k, i, j in tir.grid(128, 128, 128):
        with tir.block([tir.reduce_axis(0, 128), 128, 128], "C") as [ck, ci, cj]:
            tir.bind(ck, k)
            tir.bind(ci, i)
            tir.bind(cj, j)
            with tir.init():
                C[ci, cj] = 0.0
            C[ci, cj] = C[ci, cj] + A[ci, ck] * B[ck, cj]
        with tir.block([tir.reduce_axis(0, 128), 128, 128], "D") as [dk, di, dj]:
            tir.bind(dk, k)
            tir.bind(di, i)
            tir.bind(dj, j)
            with tir.init():
                D[di, dj] = 0.0
            D[di, dj] = D[di, dj] + B[di, dk] * A[dk, dj]
def buffer_load_store_func(a: ty.handle, b: ty.handle) -> None:
    # Fixture mixing tiled buffer loads and stores:
    #   1. zero-fill A;
    #   2. in a 32x32x(reduce 32) block, the init copies a 4x4 tile of A into B;
    #      the update then adds C into that tile, and afterwards D * C (D indexed
    #      [j*4+jj, k*4+kk]), using plain `range` loops and `+=`.
    # C and D are allocated but never written within this function.
    A = tir.match_buffer(a, (128, 128), "float32")
    B = tir.match_buffer(b, (128, 128), "float32")
    C = tir.alloc_buffer((128, 128), "float32")
    D = tir.alloc_buffer((128, 128), "float32")
    with tir.block([128, 128]) as [i, j]:
        A[i, j] = tir.float32(0)
    with tir.block([32, 32, tir.reduce_axis(0, 32)]) as [i, j, k]:
        with tir.init():
            for ii, jj in tir.grid(4, 4):
                B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj]
        for ii, jj in tir.grid(4, 4):
            for kk in range(0, 4):
                B[i * 4 + ii, j * 4 + jj] += C[i * 4 + ii, k * 4 + kk]
            for kk in range(0, 4):
                B[i * 4 + ii, j * 4 + jj] += D[j * 4 + jj, k * 4 + kk] * C[i * 4 + ii, k * 4 + kk]