def square_sum_square_root_rfactor(a: T.handle, d: T.handle) -> None:
    # Three-stage TIR function over A[16, 256, 256] producing D[16]:
    #   block "C_rf": accumulates A[b, i, j] * A[b, i, j] into the
    #                 intermediate buffer C_rf[1, 16],
    #   block "C":    reduces the extent-1 leading axis of C_rf into C[16],
    #   block "D":    stores sqrt(C[b]) into D[b].
    A = T.match_buffer(a, [16, 256, 256])
    D = T.match_buffer(d, [16])
    C = T.alloc_buffer([16])
    C_rf = T.alloc_buffer([1, 16])

    # The fused loop of extent 65536 is decomposed back into the two
    # 256-wide reduction indices via floordiv / floormod by 256.
    for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 65536, 1):
        with T.block("C_rf"):
            vi1_i2_fused_inner = T.axis.spatial(1, i1_i2_fused_inner)
            b = T.axis.spatial(16, i0)
            i = T.axis.reduce(256, i1_i2_fused_outer // 256)
            j = T.axis.reduce(256, i1_i2_fused_outer % 256)
            with T.init():
                C_rf[vi1_i2_fused_inner, b] = 0.0
            C_rf[vi1_i2_fused_inner, b] = (
                C_rf[vi1_i2_fused_inner, b] + A[b, i, j] * A[b, i, j]
            )

    # The extent-1 axis of C_rf is the reduction axis of the final block.
    for i0_1, i1_i2_fused_inner_1 in T.grid(16, 1):
        with T.block("C"):
            vi1_i2_fused_inner_1 = T.axis.reduce(1, i1_i2_fused_inner_1)
            b_1 = T.axis.spatial(16, i0_1)
            with T.init():
                C[b_1] = 0.0
            C[b_1] = C[b_1] + C_rf[vi1_i2_fused_inner_1, b_1]

    for i0_2 in T.serial(16):
        with T.block("D"):
            b_2 = T.axis.spatial(16, i0_2)
            D[b_2] = T.sqrt(C[b_2], dtype="float32")
def square_sum_square_root_factor_one_2_rfactor(
    A: T.Buffer[(16, 256, 256), "float32"], D: T.Buffer[(16, ), "float32"]
) -> None:
    # Three-stage TIR function over A[16, 256, 256] producing D[16]:
    #   block "C_rf": accumulates A[b, i, j] * A[b, i, j] into the
    #                 intermediate buffer C_rf[16, 1],
    #   block "C":    reduces the extent-1 trailing axis of C_rf into C[16],
    #   block "D":    stores sqrt(C[b]) into D[b].
    C = T.alloc_buffer([16], dtype="float32")
    C_rf = T.alloc_buffer([16, 1], dtype="float32")

    # Here the *inner* fused loop carries the 65536 reduction iterations and
    # the outer fused loop has extent 1.
    for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 1, 65536):
        with T.block("C_rf"):
            b = T.axis.S(16, i0)
            i = T.axis.R(256, T.floordiv(i1_i2_fused_inner, 256))
            j = T.axis.R(256, T.floormod(i1_i2_fused_inner, 256))
            vi1_i2_fused_outer = T.axis.S(1, i1_i2_fused_outer)
            with T.init():
                C_rf[b, vi1_i2_fused_outer] = T.float32(0)
            C_rf[b, vi1_i2_fused_outer] = (
                C_rf[b, vi1_i2_fused_outer] + A[b, i, j] * A[b, i, j]
            )

    for i0, i1_i2_fused_outer in T.grid(16, 1):
        with T.block("C"):
            b = T.axis.S(16, i0)
            vi1_i2_fused_outer = T.axis.R(1, i1_i2_fused_outer)
            with T.init():
                C[b] = T.float32(0)
            C[b] = C[b] + C_rf[b, vi1_i2_fused_outer]

    for i0_1 in T.serial(0, 16):
        with T.block("D"):
            b_1 = T.axis.S(16, i0_1)
            D[b_1] = T.sqrt(C[b_1], dtype="float32")
def main(
    A: T.Buffer[(1, 256, 256), "float32"], D: T.Buffer[(1, ), "float32"]
) -> None:
    # Two-stage TIR function over A[1, 256, 256] producing D[1]:
    #   block "C": sums A[b, i, j] * A[b, i, j] over i and j into C[b],
    #   block "D": stores sqrt(C[b]) into D[b].
    C = T.alloc_buffer([1], dtype="float32")

    for i0, i1, i2 in T.grid(1, 256, 256):
        with T.block("C"):
            # One spatial axis followed by the two reduction axes.
            b = T.axis.spatial(1, i0)
            i = T.axis.reduce(256, i1)
            j = T.axis.reduce(256, i2)
            with T.init():
                C[b] = T.float32(0)
            C[b] = C[b] + A[b, i, j] * A[b, i, j]

    for i0 in T.serial(0, 1):
        with T.block("D"):
            b = T.axis.spatial(1, i0)
            D[b] = T.sqrt(C[b], dtype="float32")
def transformed_square_sum_square_root_factor_one_2(a: T.handle, d: T.handle) -> None:
    # Two-stage TIR function over A[16, 256, 256] producing D[16]:
    #   block "C": sums A[b, i, j] * A[b, i, j] over i and j into C[b],
    #   block "D": stores sqrt(C[b]) into D[b].
    A = T.match_buffer(a, [16, 256, 256])
    D = T.match_buffer(d, [16])
    C = T.alloc_buffer([16])

    # The outer fused loop has extent 1; the inner fused loop (extent 65536)
    # is decomposed into the two 256-wide reduction indices.
    for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 1, 65536):
        with T.block("C"):
            b = T.axis.spatial(16, i0)
            i = T.axis.reduce(256, i1_i2_fused_inner // 256)
            j = T.axis.reduce(256, i1_i2_fused_inner % 256)
            with T.init():
                C[b] = 0.0
            C[b] = C[b] + A[b, i, j] * A[b, i, j]

    for i0_1 in T.serial(16):
        with T.block("D"):
            b_1 = T.axis.spatial(16, i0_1)
            D[b_1] = T.sqrt(C[b_1], dtype="float32")
def main(
    A: T.Buffer[(1, 256, 256), "float32"], D: T.Buffer[(1, ), "float32"]
) -> None:
    # Two-stage TIR function over A[1, 256, 256] producing D[1], with each
    # stage nested under an extent-1 blockIdx.x / threadIdx.x binding:
    #   block "C": sums A[b, i, j] * A[b, i, j] over i and j into C[b],
    #   block "D": stores sqrt(C[b]) into D[b].
    C = T.alloc_buffer([1], dtype="float32")

    for i0_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
        for i0_fused_1 in T.thread_binding(1, thread="threadIdx.x"):
            for i1, i2 in T.grid(256, 256):
                with T.block("C"):
                    # The spatial axis is bound to the constant 0, not to a
                    # loop variable.
                    b = T.axis.spatial(1, 0)
                    i = T.axis.reduce(256, i1)
                    j = T.axis.reduce(256, i2)
                    with T.init():
                        C[b] = T.float32(0)
                    C[b] = C[b] + A[b, i, j] * A[b, i, j]

    for i0_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
        for i0_fused_1 in T.thread_binding(1, thread="threadIdx.x"):
            with T.block("D"):
                b = T.axis.spatial(1, 0)
                D[b] = T.sqrt(C[b], dtype="float32")
def transformed_square_sum_square_root(a: T.handle, d: T.handle) -> None:
    # Two-stage TIR function over A[16, 256, 256] producing D[16]:
    #   block "C": sums A[b, i, j] * A[b, i, j] over i and j into C[b],
    #   block "D": stores sqrt(C[b]) into D[b].
    #
    # Written with the `T.block("name")` + `T.axis.*` form used by the other
    # functions in this file, replacing the legacy block-signature form
    # (`T.block([...], "name") as [...]` plus `T.bind`). Iter-var order,
    # extents, bindings and the explicit read/write regions are unchanged.
    A = T.match_buffer(a, [16, 256, 256])
    D = T.match_buffer(d, [16])
    C = T.alloc_buffer([16])

    # The outer fused loop (extent 65536) is decomposed into the two
    # 256-wide reduction indices; the inner fused loop has extent 1.
    for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 65536, 1):
        with T.block("C"):
            b = T.axis.S(16, i0)
            i = T.axis.R(256, T.floordiv(i1_i2_fused_outer, 256))
            j = T.axis.R(256, T.floormod(i1_i2_fused_outer, 256))
            T.reads(C[b], A[b, i, j])
            T.writes(C[b])
            with T.init():
                C[b] = 0.0
            C[b] = C[b] + (A[b, i, j] * A[b, i, j])

    for i0_1 in T.serial(0, 16):
        with T.block("D"):
            b_1 = T.axis.S(16, i0_1)
            T.reads(C[b_1])
            T.writes(D[b_1])
            D[b_1] = T.sqrt(C[b_1], dtype="float32")
def square_sum_square_root_rfactor(a: T.handle, d: T.handle) -> None:
    # Three-stage TIR function over A[16, 256, 256] producing D[16]:
    #   block "C_rf": accumulates A[b, i, j] * A[b, i, j] into the
    #                 intermediate buffer C_rf[1, 16],
    #   block "C":    reduces the extent-1 leading axis of C_rf into C[16],
    #   block "D":    stores sqrt(C[b]) into D[b].
    #
    # Written with the `T.block("name")` + `T.axis.*` form used by the other
    # functions in this file, replacing the legacy block-signature form
    # (`T.block([...], "name") as [...]` plus `T.bind`). Iter-var order,
    # extents and bindings are unchanged.
    A = T.match_buffer(a, [16, 256, 256])
    D = T.match_buffer(d, [16])
    C = T.alloc_buffer([16])
    C_rf = T.alloc_buffer([1, 16])

    for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 65536, 1):
        with T.block("C_rf"):
            vi1_i2_fused_inner = T.axis.S(1, i1_i2_fused_inner)
            b = T.axis.S(16, i0)
            i = T.axis.R(256, T.floordiv(i1_i2_fused_outer, 256))
            j = T.axis.R(256, T.floormod(i1_i2_fused_outer, 256))
            with T.init():
                C_rf[vi1_i2_fused_inner, b] = 0.0
            C_rf[vi1_i2_fused_inner, b] = C_rf[vi1_i2_fused_inner, b] + (
                A[b, i, j] * A[b, i, j]
            )

    # The extent-1 axis of C_rf is the reduction axis of the final block.
    for i0_1, i1_i2_fused_inner_1 in T.grid(16, 1):
        with T.block("C"):
            vi1_i2_fused_inner_1 = T.axis.R(1, i1_i2_fused_inner_1)
            b_1 = T.axis.S(16, i0_1)
            with T.init():
                C[b_1] = 0.0
            C[b_1] = C[b_1] + C_rf[vi1_i2_fused_inner_1, b_1]

    for i0_2 in T.serial(0, 16):
        with T.block("D"):
            b_2 = T.axis.S(16, i0_2)
            D[b_2] = T.sqrt(C[b_2], dtype="float32")