import numpy as np
import tvm
from tvm import autotvm

# Assumption: the original snippet reads the matrix sizes from module scope;
# concrete values are given here only so the template is self-contained.
M = N = K = 1024

# Register as an autotvm template so get_config() is driven by a tuner.
@autotvm.template
def matmul():
    # Algorithm
    k = tvm.reduce_axis((0, K), 'k')
    A = tvm.placeholder((M, K), name='A')
    B = tvm.placeholder((K, N), name='B')

    ##### define space begin #####
    cfg = autotvm.get_config()
    cfg.define_split("tile_x", M, num_outputs=3)
    cfg.define_split("tile_y", N, num_outputs=3)
    cfg.define_split("tile_k", K, num_outputs=2)
    ##### define space end #####

    # We have to re-write the algorithm slightly: pack B so its innermost
    # dimension matches the y-tile size and is read contiguously.
    bn = cfg["tile_y"].size[-1]
    packedB = tvm.compute((N // bn, K, bn),
                          lambda x, y, z: B[y, x * bn + z],
                          name='packedB')
    C = tvm.compute(
        (M, N),
        lambda x, y: tvm.sum(A[x, k] * packedB[tvm.div(y, bn), k, y % bn],
                             axis=k),
        name='C')

    s = tvm.create_schedule(C.op)
    x, y = s[C].op.axis
    k, = s[C].op.reduce_axis

    # Schedule according to config: allocate a write cache for C.
    CC = s.cache_write(C, 'global')

    xt, xo, xi = cfg["tile_x"].apply(s, C, x)
    yt, yo, yi = cfg["tile_y"].apply(s, C, y)
    s[C].reorder(xt, yt, xo, yo, xi, yi)
    xyt = s[C].fuse(xt, yt)
    # Parallelize across the fused outer tiles.
    s[C].parallel(xyt)
    xyo = s[C].fuse(xo, yo)
    s[C].unroll(xi)
    s[C].vectorize(yi)

    # Write cache is computed at xyo.
    s[CC].compute_at(s[C], xyo)

    # New inner axes of the write cache.
    xc, yc = s[CC].op.axis

    k, = s[CC].op.reduce_axis
    ko, ki = cfg["tile_k"].apply(s, CC, k)
    s[CC].reorder(ko, xc, ki, yc)
    s[CC].unroll(xc)
    s[CC].unroll(ki)
    s[CC].vectorize(yc)

    x, y, z = s[packedB].op.axis
    s[packedB].vectorize(z)
    s[packedB].parallel(x)

    return s, [A, B, C]
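
# A minimal tuning driver for the template above -- a sketch only: the log
# file name and trial count are illustrative, and it assumes the pre-0.7
# autotvm API in which @autotvm.template functions are passed directly to
# autotvm.task.create.
def tune_matmul():
    task = autotvm.task.create(matmul, args=(), target='llvm')
    measure_option = autotvm.measure_option(
        builder='local',
        runner=autotvm.LocalRunner(number=5))
    tuner = autotvm.tuner.RandomTuner(task)
    tuner.tune(n_trial=20,
               measure_option=measure_option,
               callbacks=[autotvm.callback.log_to_file('matmul.log')])
    # Apply the best config found and compile.
    with autotvm.apply_history_best('matmul.log'):
        with tvm.target.create('llvm'):
            s, arg_bufs = matmul()
            return tvm.build(s, arg_bufs)
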
def check_div(start, end, divisor, dtype):
    # tvm.div on integers is C-style truncated division, so the reference
    # result int(float(i) / divisor) also truncates toward zero.
    T = tvm.compute((end - start,),
                    lambda i: tvm.div(tvm.expr.Cast(dtype, (start + i)),
                                      tvm.const(divisor, dtype)))
    s = tvm.create_schedule([T.op])
    f = tvm.build(s, [T], "llvm")
    a = tvm.nd.empty((end - start,), dtype)
    f(a)
    ref = [int(float(i) / divisor) for i in range(start, end)]
    tvm.testing.assert_allclose(a.asnumpy(), ref)
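
# Illustrative invocations (the ranges and divisors are assumptions, not
# from the original test): negative operands exercise the truncate-toward-
# zero semantics, and the unsigned case sticks to non-negative values.
check_div(-12, 12, 5, 'int32')
check_div(-12, 12, -3, 'int32')
check_div(0, 12, 5, 'uint32')
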
def test_reduce_simplify():
    ck = CanonicalChecker()
    k = tvm.reduce_axis((0, 10), name="k")
    j = tvm.reduce_axis((-5, 3), name="j")
    A = tvm.placeholder((10,), name='A')
    ck.verify(tvm.sum(tvm.expr.Select(k + j < 12, k + j, 0), [k, j]),
              tvm.sum(k + j, [k, j]))
    ck.verify(tvm.sum(A[3], []), A[3])
    # The rule below is not typical, removed for now
    ck.verify(tvm.sum(tvm.div(k, 10), k),
              tvm.sum(tvm.const(0, "int32"), k))
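
# A minimal sketch of the CanonicalChecker helper the test above assumes
# (it is not shown in this section); modeled on TVM's canonical-simplify
# unit tests: run canonical simplification through tvm.arith.Analyzer and
# assert structural equality of the result.
class CanonicalChecker:
    def __init__(self):
        self.analyzer = tvm.arith.Analyzer()

    def verify(self, data, expected):
        res = self.analyzer.canonical_simplify(data)
        assert tvm.ir_pass.Equal(res, expected), \
            "\ndata={}\nres={}\nexpected={}".format(data, res, expected)
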
def check_llvm_reciprocal(n):
    A = tvm.placeholder((n,), name='A')
    # With the input value 100 below, 1e+37 * 100 overflows float32 to
    # +inf, so the reciprocal is exactly zero.
    B = tvm.compute((n,), lambda i: tvm.div(1.0, (1e+37 * A[i])), name='B')
    s = tvm.create_schedule(B.op)
    f = tvm.build(s, [A, B], "llvm")
    a = tvm.nd.array(np.full((n,), 100, 'float32'))
    b = tvm.nd.empty((n,), 'float32')
    f(a, b)
    tvm.testing.assert_allclose(b.asnumpy(), np.zeros((n,), 'float32'))
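
# Exercise the reciprocal codegen at a few sizes (the values are
# illustrative, not from the original test).
check_llvm_reciprocal(4)
check_llvm_reciprocal(8)
check_llvm_reciprocal(16)
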
from tvm.autotvm.task.task import compute_flop

def test_average_pool():
    for _ in range(5):
        N, H, W, CO, CI, KH, KW = [np.random.randint(10, 32) for _ in range(7)]
        (input_dtype, acc_dtype) = random_dtypes()
        D = tvm.placeholder((N, CI, H, W), dtype=input_dtype)

        KH = min(H, KH)
        KW = min(W, KW)

        kh = tvm.reduce_axis((0, KH))
        kw = tvm.reduce_axis((0, KW))

        OH = (H - KH) + 1
        OW = (W - KW) + 1

        # Note: compute_flop only inspects the expression tree and never
        # executes it, so indexing D's CI channels with co (which ranges
        # over CO) is harmless here.
        C = tvm.compute(
            (N, CO, OH, OW),
            lambda n, co, h, w: tvm.sum(
                tvm.div(D[n][co][h + kh][w + kw].astype(acc_dtype), (KW * KH)),
                axis=[kh, kw]))

        s = tvm.create_schedule([C.op])

        # Each output element reduces over KH * KW inputs, and each
        # reduction step costs one division plus one addition: 2 FLOPs.
        assert compute_flop(s) == 2 * N * CO * OH * OW * KH * KW
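
# A sketch of the random_dtypes helper the test above assumes (not shown
# in this section), modeled on TVM's FLOP-calculator unit tests: it picks
# a random (input, accumulator) dtype pair.
def random_dtypes():
    """Return a pair of (input, accumulator) dtypes."""
    candidates = [("float32", "float32"),
                  ("float16", "float32"),
                  ("int8", "int32")]
    return candidates[np.random.choice(len(candidates))]
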