import numpy as np
from itertools import permutations

import heterocl as hcl
import hlib


def test_extern_ops():
    hcl.init()
    A = hcl.placeholder((10, 32), "A")

    def kernel(A):
        B = hcl.compute(A.shape, lambda *args: A[args] + 1, "B")
        C = hcl.compute(A.shape, lambda *args: B[args] + 1, "C")
        D = hcl.compute(A.shape, lambda *args: C[args] * 2, "D")
        return D

    target = hcl.Platform.aws_f1
    s = hcl.create_schedule([A], kernel)
    s.to(kernel.B, target.xcel)
    s.to(kernel.C, target.host)
    code = str(hcl.lower(s))
    print(code)
    assert "test(B, C)" in code
def inner_loop_tile():
    hcl.init()
    A = hcl.placeholder((10, 32), "A")

    def kernel(A):
        C = hcl.compute(A.shape, lambda *args: A[args] * 4, "C")
        return C

    target = hcl.platform.aws_f1
    s = hcl.create_schedule([A], kernel)
    stage = kernel.C
    yo, yi = s[stage].split(stage.axis[0], factor=3)
    xo, xi = s[stage].split(stage.axis[1], factor=3)
    s.to(kernel.C, target.xcel, axis=1)
    code = str(hcl.lower(s))
    assert "test(args.outer, C, A)" in code
def move_outputs():
    hcl.init()
    A = hcl.placeholder((10, 32), "A")

    def kernel(A):
        B = hcl.compute(A.shape, lambda i, j: A[i, j] * 2, "B")
        hcl.update(B, lambda i, j: B[i, j] + 1, "update1")
        hcl.update(B, lambda i, j: B[i, j] * 2, "update2")
        return B

    target = hcl.platform.aws_f1
    s = hcl.create_schedule([A], kernel)
    s.to(A, target.xcel)
    s.to(kernel.update1.B, target.host)
    code = str(hcl.lower(s))
    assert "test(A.channel, B.update.channel)" in code
def inter_stage_fork():
    hcl.init()
    A = hcl.placeholder((10, 32), "A")
    B = hcl.placeholder((10, 32), "B")

    def kernel(A, B):
        C = hcl.compute(A.shape, lambda i, j: A[i, j] + B[i, j], "C")
        D = hcl.compute(C.shape, lambda i, j: C[i, j] + 1, "D")
        E = hcl.compute(C.shape, lambda i, j: C[i, j] * 2, "E")
        return D, E

    target = hcl.Platform.aws_f1
    s = hcl.create_schedule([A, B], kernel)
    s.to(kernel.C, [kernel.D, kernel.E])
    code = str(hcl.lower(s))
    assert "allocate C.pipe.1[int32 * 10 * 32]" in code
    assert "allocate C.pipe.2[int32 * 10 * 32]" in code
def test_inter_stage():
    A = hcl.placeholder((10, 32), "A")
    B = hcl.placeholder((10, 32), "B")

    def kernel(A, B):
        C = hcl.compute(A.shape, lambda i, j: A[i][j] + B[i][j], "C")
        D = hcl.compute(C.shape, lambda i, j: C[i][j], "D")
        return D

    target = hcl.platform.aws_f1
    s = hcl.create_schedule([A, B], kernel)
    s.to(kernel.C, s[kernel.D])
    code = str(hcl.lower(s))
    assert "C.pipe1.write" in code
    assert "C.pipe1.read" in code
def test_multiple_subgraph():
    hcl.init()
    A = hcl.placeholder((10, 32), "A")
    B = hcl.placeholder((10, 32), "B")

    def kernel(A, B):
        C = hcl.compute(A.shape, lambda i, j: A[i, j] + 1, "C")
        D = hcl.compute(C.shape, lambda i, j: B[i, j] + 1, "D")
        return hcl.compute(C.shape, lambda i, j: C[i, j] + D[i, j], "E")

    target = hcl.Platform.aws_f1
    s = hcl.create_schedule([A, B], kernel)
    s.to([A, B], target.xcel)
    s.to([kernel.E], target.host)
    code = str(hcl.lower(s))
    assert "io attr: \"B\"" in code
    assert "io attr: \"A\"" in code
    assert "io attr: \"E\"" in code
def inter_stage_join():
    hcl.init()
    A = hcl.placeholder((10, 32), "A")
    B = hcl.placeholder((10, 32), "B")

    def kernel(A, B):
        C = hcl.compute(A.shape, lambda i, j: 0, "C")
        hcl.update(C, lambda i, j: A[i, j] + 1, "s1")
        hcl.update(C, lambda i, j: B[i, j] * 2, "s2")
        return hcl.compute(C.shape, lambda *args: C[args] + 3, "ret")

    target = hcl.platform.aws_f1
    s = hcl.create_schedule([A, B], kernel)
    s.join([kernel.s1.C, kernel.s2.C], kernel.ret.C)
    code = str(hcl.lower(s))
    assert "C.pipe1.read" in code
    assert "C.pipe2.write" in code
def test_reuse_at_with_streaming():
    hcl.init()
    A = hcl.placeholder((10, 10), name="A")

    def kernel(A):
        B = hcl.compute((10, 10), lambda y, x: A[y, x], "B")
        C = hcl.compute(
            (10, 8), lambda y, x: B[y, x] + B[y, x + 1] + B[y, x + 2], "C")
        return C

    s = hcl.create_schedule([A], kernel)
    target = hcl.platform.zc706
    target.config(compiler="vivado_hls", mode="csim")
    B_ = s.to(kernel.B, target.xcel)
    RB = s.reuse_at(B_, s[kernel.C], kernel.C.axis[1])
    s.to(kernel.C, target.host)
    print(hcl.lower(s))
def get_relay_model(model,
                    shape={},
                    frontend='keras',
                    dtype=hcl.Float(),
                    in_params=None):
    """
    Parameters
    ----------
    model :
        A machine learning framework model
    shape : dict
        The model's input shape
    frontend : str
        The machine learning framework the model comes from
    dtype : heterocl type
        The model's preferred data type
    in_params :
        The input parameters of the model, if not included in the model
    """
    # relay_parser, full_flatten, gen_params, partial_flatten, gen_func and
    # debug_mode are expected to be available at module level (relay-frontend
    # helpers of this module).
    out_var, out_type, out_env, params = relay_parser(model, shape, frontend)
    out_var = full_flatten(out_var)
    _param = gen_params(out_type, out_env)
    v_param = [holder for holder in _param if ("_param" in holder.name)]
    v_input = [holder for holder in _param if ("input" in holder.name)]
    v_param.sort(
        key=lambda x: int(''.join(filter(lambda i: i.isdigit(), x.name))))
    v_input.sort(
        key=lambda x: int(''.join(filter(lambda i: i.isdigit(), x.name))))
    _param = partial_flatten([v_input, v_param])
    func = gen_func(_param, out_var, out_type, out_env)
    _inputs = []
    if params is None:
        params = in_params
    for var in params:
        _inputs.append(hcl.asarray(params[var].asnumpy()))
    s = hcl.create_schedule(_param, func)
    if debug_mode:
        print(hcl.lower(s))
    return hcl.build(s), _inputs
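# Hedged usage sketch for get_relay_model (illustrative only, not part of the
# original tests): the model object and the "input_1" shape key below are
# assumptions, chosen to match typical Keras naming.
def _relay_model_sketch(keras_model, input_shape):
    f, params = get_relay_model(keras_model,
                                shape={"input_1": input_shape},
                                frontend="keras")
    # Per the ordering set up in get_relay_model, f expects the model inputs
    # first, then the parameter arrays in `params`, then the output array(s).
    return f, params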
def test_merge_kernel_stages():
    hcl.init()
    A = hcl.placeholder((10, 32), "A")
    B = hcl.placeholder((10, 32), "B")

    def kernel(A, B):
        C = hcl.compute(A.shape, lambda i, j: 0, "C")
        hcl.update(C, lambda i, j: A[i, j] + 1, "s1")
        hcl.update(C, lambda i, j: B[i, j] * 2, "s2")
        return hcl.compute(C.shape, lambda *args: C[args] + 3, "ret")

    target = hcl.platform.aws_f1
    s = hcl.create_schedule([A, B], kernel)
    A_, B_ = s.to([A, B], target.xcel)
    ret_ = s.to(kernel.ret, target.host)
    kernel = s.duplicate(inputs=[A_, B_], outputs=[ret_])
    print(hcl.lower(s))
def test_compute_at_blur_x_with_data_placement():
    hcl.init()
    A = hcl.placeholder((10, 10), name="A")

    def kernel(A):
        B = hcl.compute((10, 8),
                        lambda y, x: A[y, x] + A[y, x + 1] + A[y, x + 2],
                        name="B")
        C = hcl.compute((10, 8), lambda y, x: B[y, x], name="C")
        D = hcl.compute((10, 8), lambda y, x: C[y, x], name="D")
        return D

    s = hcl.create_schedule([A], kernel)
    target = hcl.Platform.xilinx_zc706
    target.config(compiler="vivado_hls", mode="csim")
    s[kernel.B].compute_at(s[kernel.C], kernel.C.axis[1])
    s.to(kernel.C, target.xcel)
    s.to(kernel.D, target.host)
    code = str(hcl.lower(s))
    assert "test(D, C)" in code
def test_reuse_blur_x_with_data_placement():
    hcl.init()
    A = hcl.placeholder((10, 10), name="A")

    def kernel(A):
        B = hcl.compute((10, 8),
                        lambda y, x: A[y, x] + A[y, x + 1] + A[y, x + 2],
                        name="B")
        C = hcl.compute((10, 8), lambda y, x: B[y, x], name="C")
        return C

    s = hcl.create_schedule([A], kernel)
    kernel_B = kernel.B
    target = hcl.Platform.xilinx_zc706
    target.config(compiler="vivado_hls", mode="csim")
    # RB = s.reuse_at(A, s[kernel_B], kernel_B.axis[1])
    s.to(kernel.B, target.xcel)
    s.to(kernel.C, target.host)
    print(hcl.lower(s))
    f = hcl.build(s, target)
def test_compute_at_with_reuse_2D():
    hcl.init()
    A = hcl.compute((10, 10), lambda y, x: x + y, "A")
    B = hcl.compute((8, 8),
                    lambda y, x: A[y, x] + A[y + 1, x + 1] + A[y + 2, x + 2],
                    "B")
    s = hcl.create_schedule([B])
    s[A].compute_at(s[B], B.axis[1])
    ir = hcl.lower(s)
    assert "allocate A[int32 * 3 * 3]" in str(ir)
    f = hcl.build(s)
    a_np = np.fromfunction(lambda i, j: i + j, A.shape, dtype="int")
    b_np = np.zeros(B.shape, dtype="int")
    c_np = np.zeros(B.shape, dtype="int")
    for y in range(0, 8):
        for x in range(0, 8):
            c_np[y][x] = a_np[y][x] + a_np[y + 1][x + 1] + a_np[y + 2][x + 2]
    b_hcl = hcl.asarray(b_np)
    f(b_hcl)
    np.testing.assert_array_equal(c_np, b_hcl.asnumpy())
def test_mixed_stream():
    A = hcl.placeholder((10, 32), "A")
    B = hcl.placeholder((10, 32), "B")

    def kernel(A, B):
        C = hcl.compute(A.shape, lambda i, j: A[i][j] + B[i][j], "C")
        D = hcl.compute(C.shape, lambda i, j: C[i][j], "D")
        return D

    target = hcl.Platform.aws_f1
    s = hcl.create_schedule([A, B], kernel)
    s.to([A, B], target.xcel)
    s.to(kernel.D, target.host)
    s.to(kernel.C, s[kernel.D])
    code = str(hcl.lower(s))
    assert "test(A, B, D)" in code
    assert "allocate C.pipe.1[int32 * 10 * 32]" in code
def test_stages_one_to_many():
    A = hcl.placeholder((10, 32), "A")
    B = hcl.placeholder((10, 32), "B")

    def kernel(A, B):
        C = hcl.compute(A.shape, lambda i, j: A[i][j] + B[i][j], "C")
        D = hcl.compute(C.shape, lambda i, j: C[i][j] + 1, "D")
        E = hcl.compute(C.shape, lambda i, j: C[i][j] * 2, "E")
        return D, E

    target = hcl.Platform.aws_f1
    s = hcl.create_schedule([A, B], kernel)
    s.to(kernel.C, s[kernel.D])
    s.to(kernel.C, s[kernel.E])
    code = str(hcl.lower(s))
    print(code)
    assert "allocate C.pipe.1[int32 * 10 * 32]" in code
    assert "allocate C.pipe.2[int32 * 10 * 32]" in code
def test_tutorial():
    hcl.init()
    A = hcl.placeholder((6, 6), "A")
    F = hcl.placeholder((3, 3), "F")

    def kernel(A, F):
        r = hcl.reduce_axis(0, 3)
        c = hcl.reduce_axis(0, 3)
        return hcl.compute(
            (4, 4),
            lambda y, x: hcl.sum(A[y + r, x + c] * F[r, c], axis=[r, c]), "B")

    # s = hcl.create_schedule([A, F], kernel)
    # print(hcl.lower(s))
    s_x = hcl.create_schedule([A, F], kernel)
    WB = s_x.reuse_at(A, s_x[kernel.B], kernel.B.axis[1], "WB")
    print(hcl.lower(s_x))
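# A minimal numeric-check sketch for the reuse_at schedule above, following the
# asarray/build pattern used by test_compute_at_with_reuse_2D in this file. It
# assumes the default CPU-simulation backend; the helper name is illustrative.
def _check_tutorial_conv():
    hcl.init()
    A = hcl.placeholder((6, 6), "A")
    F = hcl.placeholder((3, 3), "F")

    def kernel(A, F):
        r = hcl.reduce_axis(0, 3)
        c = hcl.reduce_axis(0, 3)
        return hcl.compute(
            (4, 4),
            lambda y, x: hcl.sum(A[y + r, x + c] * F[r, c], axis=[r, c]), "B")

    s = hcl.create_schedule([A, F], kernel)
    WB = s.reuse_at(A, s[kernel.B], kernel.B.axis[1], "WB")
    f = hcl.build(s)

    np_A = np.random.randint(0, 10, size=(6, 6))
    np_F = np.random.randint(0, 10, size=(3, 3))
    np_B = np.zeros((4, 4), dtype="int")
    for y in range(4):
        for x in range(4):
            # reference 3x3 convolution window
            np_B[y, x] = np.sum(np_A[y:y + 3, x:x + 3] * np_F)

    hcl_A = hcl.asarray(np_A)
    hcl_F = hcl.asarray(np_F)
    hcl_B = hcl.asarray(np.zeros((4, 4), dtype="int"))
    f(hcl_A, hcl_F, hcl_B)
    np.testing.assert_array_equal(np_B, hcl_B.asnumpy())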
def test_move_inputs():
    hcl.init()
    A = hcl.placeholder((10, 32), "A")
    B = hcl.placeholder((10, 32), "B")
    C = hcl.placeholder((10, 32), "C")
    D = hcl.compute(A.shape, lambda i, j: A[i][j] + B[i][j], "D")
    E = hcl.compute(C.shape, lambda i, j: C[i][j] * D[i][j], "E")
    F = hcl.compute(C.shape, lambda i, j: E[i][j] + 1, "F")
    target = hcl.platform.aws_f1
    s = hcl.create_schedule([A, B, C, D, E, F])
    s.to([A, B, C], target.xcel)
    s.to(E, target.host)
    code = str(hcl.lower(s))
    pattern = "test({}.channel, {}.channel, {}.channel, E.channel)"
    combination = [
        pattern.format(*_) for _ in list(permutations(["A", "B", "C"]))
    ]
    assert any([_ in code for _ in combination])
def test_compute_at_complex_num_axis():
    hcl.init()
    A = hcl.placeholder((10, 20, 30), name="A")
    B = hcl.compute(A.shape, lambda i, j, m: A[i, j, m] * 2, name="B")
    C = hcl.compute(B.shape, lambda ii, jj, mm: B[ii, jj, mm] + 1, name="C")
    D = hcl.compute(C.shape, lambda iii, jjj, mmm: C[iii, jjj, mmm] % 3,
                    name="D")
    s = hcl.create_schedule([A, D])
    s[B].compute_at(s[C], 1)
    s[C].compute_at(s[D], 2)
    ir = hcl.lower(s)
    assert "allocate B[int32 * 1 * 1 * 30]" in str(ir)
    assert "allocate C[int32 * 1 * 1 * 1]" in str(ir)
    f = hcl.build(s)
    a_np = np.random.randint(low=0, high=100, size=A.shape)
    a_hcl = hcl.asarray(a_np)
    d_hcl = hcl.asarray(np.zeros(D.shape), dtype="int32")
    f(a_hcl, d_hcl)
    d_np = (a_np * 2 + 1) % 3
    np.testing.assert_allclose(d_np, d_hcl.asnumpy())
def test_stencil_multi_stencil():
    A = hcl.placeholder((10, 10), "A")

    def kernel(A):
        B = hcl.compute((10, 8),
                        lambda y, x: A[y, x] + A[y, x + 1] + A[y, x + 2], "B")
        C = hcl.compute((8, 8),
                        lambda y, x: B[y, x] + B[y + 1, x] + B[y + 2, x], "C")

    s = hcl.create_schedule([A], kernel)
    s[kernel.B].stencil(burst_width=256, unroll_factor=4)
    s[kernel.C].stencil(burst_width=128, unroll_factor=8)
    ir = str(hcl.lower(s))
    assert "stencil burst_width=256 unroll_factor=4 num_iteration=1" in ir
    assert "inputs=[A]" in ir
    assert "outputs=[B]" in ir
    assert "stencil burst_width=128 unroll_factor=8 num_iteration=1" in ir
    assert "inputs=[B]" in ir
    assert "outputs=[C]" in ir
def test_dtype_struct():
    hcl.init()
    A = hcl.placeholder((100,), dtype=hcl.Int(8))
    B = hcl.placeholder((100,), dtype=hcl.Fixed(13, 11))
    C = hcl.placeholder((100,), dtype=hcl.Float())

    def kernel(A, B, C):
        stype = hcl.Struct({
            "fa": hcl.Int(8),
            "fb": hcl.Fixed(13, 11),
            "fc": hcl.Float()
        })
        D = hcl.compute(A.shape, lambda x: (A[x], B[x], C[x]), dtype=stype)
        E = hcl.compute(A.shape, lambda x: D[x].fa, dtype=hcl.Int(8))
        F = hcl.compute(A.shape, lambda x: D[x].fb, dtype=hcl.Fixed(13, 11))
        G = hcl.compute(A.shape, lambda x: D[x].fc, dtype=hcl.Float())
        # Check the data type
        assert D[0].fa.dtype == "int8"
        assert D[0].fb.dtype == "fixed13_11"
        assert D[0].fc.dtype == "float32"
        return E, F, G

    s = hcl.create_schedule([A, B, C], kernel)
    print(hcl.lower(s))
    f = hcl.build(s)
    np_A = np.random.randint(0, 500, size=100) - 250
    np_B = np.random.rand(100) - 0.5
    np_C = np.random.rand(100) - 0.5
    np_E = np.zeros(100)
    np_F = np.zeros(100)
    np_G = np.zeros(100)
    hcl_A = hcl.asarray(np_A, dtype=hcl.Int(8))
    hcl_B = hcl.asarray(np_B, dtype=hcl.Fixed(13, 11))
    hcl_C = hcl.asarray(np_C, dtype=hcl.Float())
    hcl_E = hcl.asarray(np_E, dtype=hcl.Int(8))
    hcl_F = hcl.asarray(np_F, dtype=hcl.Fixed(13, 11))
    hcl_G = hcl.asarray(np_G, dtype=hcl.Float())
    f(hcl_A, hcl_B, hcl_C, hcl_E, hcl_F, hcl_G)
    assert np.allclose(hcl_A.asnumpy(), hcl_E.asnumpy())
    assert np.allclose(hcl_B.asnumpy(), hcl_F.asnumpy())
    assert np.allclose(hcl_C.asnumpy(), hcl_G.asnumpy())
def systolic(m=16, k=16, n=16, dtype=hcl.Int(), target=None):
    hcl.init(dtype)
    dim_x, dim_y = 16, 16
    A = hcl.placeholder((m, k), dtype=dtype, name="A")
    B = hcl.placeholder((k, n), dtype=dtype, name="B")

    def kernel(A, B):
        localA = hcl.compute((m, k - 1), lambda *args: 0, "localA")
        localB = hcl.compute((k - 1, n), lambda *args: 0, "localB")
        output = hcl.compute((m, n), lambda *args: 0, "output")

        def update(k, y, x):
            localA[y, x] = hcl.select(x > 0, localA[y, x - 1], A[y, k])
            localB[y, x] = hcl.select(y > 0, localB[y - 1, x], B[k, x])
            output[y, x] = hcl.select(
                k == 0, 0, output[y, x]) + localA[y, x] * localB[y, x]

        hcl.mutate((m, dim_y, dim_x), lambda k, y, x: update(k, y, x),
                   name="update")
        return output

    s = hcl.create_schedule([A, B], kernel)
    k = kernel.update
    s[k].pipeline(k.axis[0])

    # self loopback streaming
    s.to(k.localA, kernel.update)
    s.to(k.localB, kernel.update)

    # move to xcel scope (host_only is expected to be a module-level flag)
    if not host_only:
        s.to([A, B], target.xcel)
        s.to(k.output, target.host)

    print(hcl.lower(s))
    f = hcl.build(s, target=target)
    return f
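# A minimal build sketch for systolic(), assuming the module-level host_only flag
# is set to False. The platform setup mirrors test_stencil_stream below; in
# "debug" mode, hcl.build returns the generated code rather than an executable.
def _build_systolic_debug():
    p = hcl.Platform.aws_f1
    p.config(compiler="vitis", mode="debug", backend="vhls")
    return systolic(16, 16, 16, dtype=hcl.Int(), target=p)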
def test_one_stage_on_dev():
    hcl.init()
    dtype = hcl.Float()
    M = 64
    K = 64
    N = 64
    A = hcl.placeholder((M, K), "A", dtype=dtype)
    B = hcl.placeholder((K, N), "B", dtype=dtype)
    k = hcl.reduce_axis(0, K)

    def kernel(A, B):
        C = hcl.compute(
            (M, N),
            lambda x, y: hcl.sum(A[x, k] * B[k, y], axis=k, dtype=dtype),
            "C", dtype=dtype)
        return C

    target = hcl.Platform.xilinx_zc706
    target.config(compiler="vivado_hls", mode="csyn", project="gemm")
    s = hcl.create_schedule([A, B], kernel)
    s.to([A, B], target.xcel)
    s.to(kernel.C, target.host)
    print(hcl.lower(s))
def test_mixed_stream():
    A = hcl.placeholder((10, 32), "A")
    B = hcl.placeholder((10, 32), "B")

    def kernel(A, B):
        C = hcl.compute(A.shape, lambda i, j: A[i][j] + B[i][j], "C")
        D = hcl.compute(C.shape, lambda i, j: C[i][j], "D")
        return D

    target = hcl.platform.aws_f1
    s = hcl.create_schedule([A, B], kernel)
    s.to([A, B], target.xcel)
    s.to(kernel.D, target.host)
    s.to(kernel.C, s[kernel.D])
    code = str(hcl.lower(s))
    pattern = "test({}.channel, {}.channel, D.channel)"
    combination = [pattern.format(*_) for _ in list(permutations(["A", "B"]))]
    cond = any([_ in code for _ in combination])
    assert cond, code
    assert "C.pipe1.write" in code
    assert "C.pipe1.read" in code
def _test_sim(length):
    hcl.init(hcl.Int())
    A = hcl.placeholder((length,), name="A")
    B = hcl.placeholder((length,), name="B")

    def math_func(A, B):
        res = hlib.ip.vadd_rtl(A, B, length)
        return hcl.compute(A.shape, lambda *args: res[args] * 2, "out")

    target = hcl.platform.aws_f1
    s = hcl.create_schedule([A, B], math_func)
    s.to([A, B], target.xcel)
    s.to(math_func.out, target.host)

    # test ir correctness
    ir = str(hcl.lower(s))
    pattern = "test({}.channel, {}.channel, out.channel)"
    combination = [
        pattern.format(*_) for _ in list(permutations(["A", "B"]))
    ]
    assert any([_ in ir for _ in combination])
def _test_invalid_stream_pattern():
    A = hcl.placeholder((10,), "A")

    def kernel(A):
        B = hcl.compute(A.shape, lambda i: A[i] + 1, "B")
        C = hcl.compute(B.shape,
                        lambda i: hcl.select(i < 9, B[i] + B[i + 1], B[i]),
                        "C")
        return C

    target = hcl.Platform.aws_f1
    s = hcl.create_schedule([A], kernel)
    s.to([A], target.xcel)
    s.to(kernel.C, target.host)
    s.to(kernel.B, s[kernel.C])

    # lowering this invalid stream pattern is expected to fail
    passed = False
    try:
        code = str(hcl.lower(s))
        passed = True
    except:
        pass
    assert not passed
def test_inter_kernel_channels():
    hcl.init()
    A = hcl.placeholder((10, 32), "A")
    C = hcl.placeholder((10, 32), "C")

    def kernel(A, C):
        B = hcl.compute((10, 32), lambda *args: 0, "B")

        @hcl.def_([(10, 32), (10, 32)])
        def add(A, B):
            hcl.update(B, lambda *args: A[args] + 1)

        @hcl.def_([(10, 32), (10, 32)])
        def mul(B, C):
            hcl.update(C, lambda *args: B[args] * 2)

        add(A, B)
        mul(B, C)

    s = hcl.create_schedule([A, C], kernel)
    s.to(kernel.mul.B, kernel.add.B, fifo_depth=10)
    code = str(hcl.lower(s))
    print(code)
def test_stencil_stream():
    hcl.init()
    A = hcl.placeholder((10, 10), "A")

    def stencil(A):
        B = hcl.compute((10, 8),
                        lambda y, x: A[y, x] + A[y, x + 1] + A[y, x + 2], "B")
        C = hcl.compute((8, 8),
                        lambda y, x: B[y, x] + B[y + 1, x] + B[y + 2, x], "C")
        return C

    target = hcl.Platform.aws_f1
    target.config(compiler="vitis", mode="debug", backend="vhls")
    s = hcl.create_schedule([A], stencil)

    # create stencil node
    s[stencil.B].stencil(burst_width=256, unroll_factor=4)
    s[stencil.C].stencil(burst_width=128, unroll_factor=8)

    # compute offloading to FPGA
    s.to(A, target.xcel, mode=hcl.IO.DMA)
    s.to(stencil.C, target.host, mode=hcl.IO.Stream)

    code = hcl.lower(s)
    assert "C[0].write" in str(code)
def test_kernel():
    hcl.init()
    A = hcl.placeholder((10, 32), "A")
    B = hcl.placeholder((10, 32), "B")

    def kernel(A, B):
        C = hcl.compute((10, 32), lambda *args: 10)

        @hcl.def_([(10, 32), (10, 32)])
        def add(A, B):
            hcl.update(B, lambda *args: A[args] + 1)

        @hcl.def_([(10, 32), (10, 32)])
        def mul(B, C):
            hcl.update(C, lambda *args: B[args] * 2)

        add(A, B)
        mul(B, C)

    s = hcl.create_schedule([A, B], kernel)
    s.to(B, s[kernel.mul], s[kernel.add])
    code = str(hcl.lower(s))
    assert "c_buf_1.write" in code
    assert "c_buf_1.read" in code
def conv():
    # batch_size and target are expected to be defined at module level
    image = hcl.placeholder((batch_size, 1, 256, 256), "input_image")
    k1 = hcl.placeholder((1, 1, 3, 3), "kernel_1")
    k2 = hcl.placeholder((1, 1, 3, 3), "kernel_2")

    def kernel(input_image, kernel_1, kernel_2):
        # return tensor required (cannot do def_())
        interm_shape = (1, 1, 254, 254)
        output_shape = (1, 1, 252, 252)

        # make compute wrapped in hcl def
        module1 = hcl.def_([input_image.shape, kernel_1.shape, interm_shape],
                           name="conv1")(hlib.nn.conv2d_nchw_imp)
        module2 = hcl.def_([interm_shape, kernel_2.shape, output_shape],
                           name="conv2")(hlib.nn.conv2d_nchw_imp)
        conv1 = hcl.compute(interm_shape, lambda *args: 0)
        conv2 = hcl.compute(output_shape, lambda *args: 0)
        module1(input_image, kernel_1, conv1)
        module2(conv1, kernel_2, conv2)

        # derivative module for normalization
        return hcl.compute(output_shape, lambda *args: conv2[args],
                           name="derv")

    s = hcl.create_schedule([image, k1, k2], kernel)

    # data moved to local
    i0, k10, k20 = s.to([image, k1, k2], target.fpga)
    # s.to([i0, k10], s[kernel.conv1])
    # s.to([k20], s[kernel.conv2])
    s.to(kernel.derv, target.cpu)

    # create stream channel between modules
    print(type(target.fpga), hcl.lower(s))
    return hcl.build(s, target)
def test_extract_subgraph(combine=False):
    hcl.init()
    A = hcl.placeholder((10, 32), "A")
    B = hcl.placeholder((10, 32), "B")

    def kernel(A, B):
        C = hcl.compute(A.shape, lambda i, j: 0, "C")
        hcl.update(C, lambda i, j: A[i, j] + 1, "s1")
        hcl.update(C, lambda i, j: B[i, j] * 2, "s2")
        return hcl.compute(C.shape, lambda *args: C[args] + 3, "ret")

    target = hcl.platform.aws_f1
    s = hcl.create_schedule([A, B], kernel)
    A_, B_ = s.to([A, B], target.xcel)
    ret_ = s.to(kernel.ret, target.host)

    # combine and split
    if combine:
        # merge the channel stages
        s[A_].compute_at(s[B_], 1)
        s[B_].compute_at(s[kernel.C], 1)

        # merge stages from top to bottom
        s[kernel.C].compute_at(s[kernel.s1], kernel.s1.axis[1])
        s[kernel.s1].compute_at(s[kernel.s2], kernel.s2.axis[1])
        s[kernel.s2].compute_at(s[kernel.ret], kernel.ret.axis[1])
        ret_s = s.placement[kernel.ret.name][0]
        s[kernel.ret].compute_at(ret_s, ret_s.op.axis[1])

        # split along the first axis
        ret_s.split(ret_s.op.axis[0], factor=2)

    nodes = s.subgraph(inputs=[A_, B_], outputs=[ret_])
    code = str(hcl.lower(s))
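# Hedged convenience runner: only the self-contained CPU-simulation tests are
# invoked here; the remaining functions depend on module-level globals (target,
# batch_size, host_only, debug_mode) or on vendor toolchains and are left to be
# called explicitly.
if __name__ == "__main__":
    test_compute_at_with_reuse_2D()
    test_compute_at_complex_num_axis()
    test_dtype_struct()
    test_tutorial()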