def test_scan():
    """Build a single-state scan and define nested schedule-analysis checks.

    The update step is ``update[t, i] = (s_state[t - 1, i] + 1) + (x[t, i] + 1)``
    where ``x`` is a constant-1 float32 tensor and ``s_init`` is taken from ``x[0]``.

    NOTE(review): ``test_getbody``, ``test_attach_path`` and ``test_fix_pt``
    are defined below but never called from this function, so none of their
    assertions run. ``test_fix_pt`` also reads ``s_scan.spatial_axis_``, while
    the other scan tests in this file read ``s_scan.op.spatial_axis_``. Confirm
    both points before wiring these checks up.
    """
    m = tvm.var("m")
    n = tvm.var("n")
    x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
    s_state = tvm.placeholder((m, n))
    s_init = tvm.compute((1, n), lambda _, i: x[0, i], name="s_init")
    # x_trans does not read the state; up1 reads the previous timestep.
    x_trans = tvm.compute((m, n), lambda i, j: x[i, j] + 1, name="x_trans")
    s_up1 = tvm.compute((m, n), lambda t, i: s_state[t - 1, i] + 1, name="up1")
    s_update = tvm.compute((m, n), lambda t, i: s_up1[t, i] + x_trans[t, i], name="update")
    s_scan = tvm.scan(s_init, s_update, s_state)

    def test_getbody():
        # Expects the body to hold the scan op plus the state-dependent update
        # ops, and not x_trans.
        body = tvm.schedule.ScanGetBody(s_scan.op)
        assert set(body) == set([s_scan.op, s_update.op, s_up1.op])

    def test_attach_path():
        # After attaching x_trans under update's first (t) axis, its attach path
        # is expected to go through that axis and then the scan axis.
        s = tvm.create_schedule(s_scan.op)
        s[x_trans].compute_at(s[s_update], s_update.op.axis[0])
        apath = tvm.schedule.CreateAttachPath(s)
        assert(tuple(apath[s_update.op]) == tuple([s_scan.op.scan_axis]))
        assert(tuple(apath[x_trans.op]) == tuple([s_update.op.axis[0], s_scan.op.scan_axis]))

    def test_fix_pt():
        # The spatial axis is read at the same index every step, so the
        # analysis is expected to report a non-zero fixed point for it.
        body = tvm.schedule.ScanGetBody(s_scan.op)
        fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op, body)
        assert(fxpt[s_scan.spatial_axis_[0]].value != 0)
def test_lstm_cell_inline():
    """Lower an LSTM-cell scan after inlining every gate computation.

    The cell carries two scan states: the hidden state ``h`` and the cell
    state ``c``. The four gates are slices of a fused ``gates`` tensor whose
    axis 1 selects, in order, the input gate, the input transform, the forget
    gate and the output gate. The test only checks that ``tvm.lower`` succeeds.
    """
    num_step = 128
    num_input = 256
    num_hidden = 1152
    batch_size = 4
    # Global transition matrix
    X = tvm.placeholder((num_step - 1, batch_size, num_input), name="X")
    Wi2h = tvm.placeholder((4, num_hidden, num_input), name="Wi2h")
    Wh2h = tvm.placeholder((4, num_hidden, num_hidden), name="Wh2h")
    # h: output hidden state, c: cell state.
    s_state_h = tvm.placeholder((num_step, batch_size, num_hidden))
    s_state_c = tvm.placeholder((num_step, batch_size, num_hidden))
    s_init_c = tvm.compute((1, batch_size, num_hidden), lambda *i: 0.0, name="init_c")
    s_init_h = tvm.compute((1, batch_size, num_hidden), lambda *i: 0.0, name="init_h")
    # LSTM transition: input-to-hidden product.
    k = tvm.reduce_axis((0, num_input), name="ki2h")
    s_i2h = tvm.compute(
        (num_step, 4, batch_size, num_hidden),
        lambda t, x, i, j: tvm.sum(X[t - 1, i, k] * Wi2h[x, j, k], axis=k),
        name="s_i2h")
    # Hidden-to-hidden product. Give its reduce axis its own name so the two
    # reductions can be told apart in the lowered IR.
    k = tvm.reduce_axis((0, num_hidden), name="kh2h")
    s_h2h = tvm.compute(
        (num_step, 4, batch_size, num_hidden),
        lambda t, x, i, j: tvm.sum(s_state_h[t - 1, i, k] * Wh2h[x, j, k], axis=k),
        name="s_h2h")
    # Gate rules
    gates = tvm.compute(s_i2h.shape, lambda *i: s_i2h(*i) + s_h2h(*i), name="gates")
    gshape = (num_step, batch_size, num_hidden)
    in_gate = tvm.compute(gshape, lambda t, i, j: tvm.sigmoid(gates[t, 0, i, j]), name="in_gate")
    in_transform = tvm.compute(gshape, lambda t, i, j: tvm.tanh(gates[t, 1, i, j]), name="in_transform")
    forget_gate = tvm.compute(gshape, lambda t, i, j: tvm.sigmoid(gates[t, 2, i, j]), name="forget_gate")
    out_gate = tvm.compute(gshape, lambda t, i, j: tvm.sigmoid(gates[t, 3, i, j]), name="out_gate")
    next_c = tvm.compute(
        gshape,
        lambda t, i, j: forget_gate[t, i, j] * s_state_c[t - 1, i, j]
        + in_gate[t, i, j] * in_transform[t, i, j],
        name="next_c")
    next_h = tvm.compute(
        gshape, lambda t, i, j: out_gate[t, i, j] * tvm.tanh(next_c[t, i, j]), name="next_h")
    update_c = tvm.compute(gshape, lambda *i: next_c(*i), name="update_c")
    update_h = tvm.compute(gshape, lambda *i: next_h(*i), name="update_h")
    # Build the scan over both states.
    scan_h, scan_c = tvm.scan(
        [s_init_h, s_init_c],
        [update_h, update_c],
        [s_state_h, s_state_c],
        inputs=[X],
        name="lstm_scan")
    # schedule
    s = tvm.create_schedule(scan_h.op)
    # Inline gate computations
    s[gates].compute_inline()
    s[in_gate].compute_inline()
    s[in_transform].compute_inline()
    s[forget_gate].compute_inline()
    s[out_gate].compute_inline()
    # verify we can lower correctly
    tvm.lower(s, [X, Wi2h, Wh2h, scan_h, scan_c])
def test_scan():
    """Build a single-state scan and define nested schedule-analysis checks.

    The update step is ``update[t, i] = (s_state[t - 1, i] + 1) + (x[t, i] + 1)``
    where ``x`` is a constant-1 float32 tensor.

    NOTE(review): the nested checks are defined but not invoked by this
    function; call them explicitly if they should run.
    """
    m = tvm.var("m")
    n = tvm.var("n")
    x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
    s_state = tvm.placeholder((m, n))
    s_init = tvm.compute((1, n), lambda _, i: x[0, i], name="s_init")
    x_trans = tvm.compute((m, n), lambda i, j: x[i, j] + 1, name="x_trans")
    s_up1 = tvm.compute((m, n), lambda t, i: s_state[t - 1, i] + 1, name="up1")
    s_update = tvm.compute((m, n), lambda t, i: s_up1[t, i] + x_trans[t, i], name="update")
    s_scan = tvm.scan(s_init, s_update, s_state)

    def test_getbody():
        body = tvm.schedule.ScanGetBody(s_scan.op)
        assert set(body) == set([s_scan.op, s_update.op, s_up1.op])

    def test_attach_path():
        s = tvm.create_schedule(s_scan.op)
        s[x_trans].compute_at(s[s_update], s_update.op.axis[0])
        apath = tvm.schedule.CreateAttachPath(s)
        assert tuple(apath[s_update.op]) == tuple([s_scan.op.scan_axis])
        assert tuple(apath[x_trans.op]) == tuple(
            [s_update.op.axis[0], s_scan.op.scan_axis])

    def test_fix_pt():
        body = tvm.schedule.ScanGetBody(s_scan.op)
        fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op, body)
        # spatial_axis_ lives on the ScanOp, not on the output Tensor.
        assert fxpt[s_scan.op.spatial_axis_[0]].value != 0
def test_scan_group():
    """Verify group membership and compute_at legality for scan-body stages."""
    m, n = tvm.var("m"), tvm.var("n")
    x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
    s_state = tvm.placeholder((m, n))
    s_init = tvm.compute((1, n), lambda _, i: x[0, i])
    upd1 = tvm.compute((m, n), lambda t, i: s_state[t - 1, i] + x[t, i])
    upd2 = tvm.compute((m, n), lambda t, i: upd1[t, i] + 1)
    upd3 = tvm.compute((m, n), lambda t, i: upd2[t, i] + 1)
    res = tvm.scan(s_init, upd3, s_state, inputs=x)

    sch = tvm.create_schedule(res.op)
    # Right after scheduling, the update stages already share one group.
    assert sch[upd1].group is not None
    assert sch[upd2].group == sch[upd1].group

    # Attaching one stage to another inside the same group is allowed.
    sch[upd1].compute_at(sch[upd2], upd2.op.axis[1])

    # Wrap upd1 and upd2 in a nested group bounded by the scan inputs.
    grp = sch.create_group(outputs=upd2, inputs=[s_state, x])
    assert grp.group is not None
    assert grp.group == sch[upd3].group
    assert sch[upd2].group == grp
    assert sch[upd1].group == grp

    grp.compute_at(sch[upd3], upd3.op.axis[1])
    assert grp.attach_stage == sch[upd3]

    # Attaching a grouped stage to a stage outside its group must fail.
    try:
        sch[upd2].compute_at(sch[s_init], s_init.op.axis[0])
    except tvm.TVMError:
        pass
    else:
        assert False
def test_scan_group():
    """Check stage grouping and compute_at rules inside a scan body.

    The three update stages start out in one group. A nested group is then
    created around s_update1/s_update2 and attached under s_update3.
    Attaching s_update2 to s_init, which is outside the group, is expected
    to raise ``tvm.TVMError``.
    """
    m = tvm.var("m")
    n = tvm.var("n")
    x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
    s_state = tvm.placeholder((m, n))
    s_init = tvm.compute((1, n), lambda _, i: x[0, i])
    s_update1 = tvm.compute((m, n), lambda t, i: s_state[t - 1, i] + x[t, i])
    s_update2 = tvm.compute((m, n), lambda t, i: s_update1[t, i] + 1)
    s_update3 = tvm.compute((m, n), lambda t, i: s_update2[t, i] + 1)
    res = tvm.scan(s_init, s_update3, s_state, inputs=x)

    s = tvm.create_schedule(res.op)
    # Right after create_schedule the update stages already share a group.
    assert s[s_update1].group is not None
    assert s[s_update2].group == s[s_update1].group
    # Attaching within the group is valid.
    s[s_update1].compute_at(s[s_update2], s_update2.op.axis[1])
    # Create a new group for s_update2 and s_update1.
    g2 = s.create_group(outputs=s_update2, inputs=[s_state, x])
    # The new group is nested inside the group that holds s_update3.
    assert g2.group is not None
    assert g2.group == s[s_update3].group
    assert s[s_update2].group == g2
    assert s[s_update1].group == g2
    g2.compute_at(s[s_update3], s_update3.op.axis[1])
    assert g2.attach_stage == s[s_update3]
    try:
        # Attaching to a stage outside the group must raise.
        s[s_update2].compute_at(s[s_init], s_init.op.axis[0])
        # NOTE(review): `assert False` is stripped under `python -O`, which
        # would let this negative check pass silently.
        assert False
    except tvm.TVMError:
        pass
def test_scan1():
    """Check that a transposed read of the state defeats both spatial fixed points.

    Uses ``l``, ``m``, ``n``, ``x``, ``s_state`` and ``s_init`` from an
    enclosing scope that is not visible here.
    """
    def transposed(t, i, j):
        # Element (i, j) reads the swapped position (j, i) of both x and the state.
        return x[t, j, i] + s_state[t - 1, j, i]

    s_update = tvm.compute((l, m, n), transposed, name="update")
    s_scan = tvm.scan(s_init, s_update, s_state)
    body = tvm.schedule.ScanGetBody(s_scan.op)
    fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op, body)
    for k in (0, 1):
        assert fxpt[s_scan.op.spatial_axis_[k]].value == 0
def test_tensor_scan():
    """Check that a scan over an (m, n) state placeholder yields an (m, n) tensor."""
    m, n = tvm.size_var("m"), tvm.size_var("n")
    x = tvm.placeholder((m, n))
    state = tvm.placeholder((m, n))
    init = tvm.compute((1, n), lambda _, i: x[0, i])
    update = tvm.compute((m, n), lambda t, i: state[t - 1, i] + x[t, i])
    res = tvm.scan(init, update, state)
    assert tuple(res.shape) == (m, n)
def test_tensor_scan():
    """Check that the scan result has the same (m, n) shape as its state placeholder."""
    m = tvm.var("m")
    n = tvm.var("n")
    x = tvm.placeholder((m, n))
    s = tvm.placeholder((m, n))
    first_row = tvm.compute((1, n), lambda _, i: x[0, i])
    step = tvm.compute((m, n), lambda t, i: s[t - 1, i] + x[t, i])
    assert tuple(tvm.scan(first_row, step, s).shape) == (m, n)
def cumsum(X):
    """Return an exclusive prefix-sum scan of the 1-D tensor ``X``.

    The result has ``len(X) + 1`` entries with ``Y[i] = sum(X[:i])``, so
    ``Y[0] == 0``. The scan state is declared as int32.
    """
    m, = X.shape
    state = tvm.placeholder((m + 1,), dtype="int32", name="state")
    zero = tvm.compute((1,), lambda _: tvm.const(0, "int32"))
    # Each step adds the element just before the current position.
    step = tvm.compute((m + 1,), lambda l: state[l - 1] + X[l - 1])
    return tvm.scan(zero, step, state, inputs=[X], name="cumsum")
def test_scan4_reach_other():
    """Check that cross-axis state access leaves neither spatial axis a fixed point.

    ``h1`` reads ``s_state[t - 1, j, j]``. After the transpose in ``update``,
    output element ``(i, j)`` therefore depends on ``s_state[t - 1, j, j]``,
    so each axis reaches the other.

    Uses ``l``, ``m``, ``n``, ``s_state`` and ``s_init`` from an enclosing
    scope that is not visible here.
    """
    s_h1 = tvm.compute((l, n, m), lambda t, j, i: s_state[t - 1, j, j], name="h1")
    # Distinct name so the two intermediates can be told apart in lowered IR.
    s_h2 = tvm.compute((l, m, n), lambda t, i, j: s_state[t - 1, i, j] * 2, name="h2")
    s_update = tvm.compute(
        (l, m, n), lambda t, i, j: s_h1[t, j, i] + s_h2[t, i, j], name="update")
    s_scan = tvm.scan(s_init, s_update, s_state)
    fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op)
    assert fxpt[s_scan.op.spatial_axis_[0]].value == 0
    assert fxpt[s_scan.op.spatial_axis_[1]].value == 0
def test_scan3_not_exact_reach():
    """Check that axis 0 is an exact fixed point but axis 1 is not.

    ``h2`` always reads column 10 of the previous state, so the second
    spatial axis is not reached exactly.

    Uses ``l``, ``m``, ``n``, ``s_state`` and ``s_init`` from an enclosing
    scope that is not visible here.
    """
    s_h1 = tvm.compute((l, n, m), lambda t, j, i: s_state[t - 1, i, j], name="h1")
    # Distinct name so the two intermediates can be told apart in lowered IR.
    s_h2 = tvm.compute((l, m, n), lambda t, i, j: s_state[t - 1, i, 10] * 2, name="h2")
    s_update = tvm.compute(
        (l, m, n), lambda t, i, j: s_h1[t, j, i] + s_h2[t, i, j], name="update")
    s_scan = tvm.scan(s_init, s_update, s_state)
    fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op)
    assert fxpt[s_scan.op.spatial_axis_[0]].value == 1
    assert fxpt[s_scan.op.spatial_axis_[1]].value == 0
def test_scan1():
    """Check that a transposed state read leaves no spatial axis a fixed point.

    ``l``, ``m``, ``n``, ``x``, ``s_state`` and ``s_init`` come from an
    enclosing scope that is not visible here.
    """
    s_update = tvm.compute(
        (l, m, n),
        lambda t, i, j: x[t, j, i] + s_state[t - 1, j, i],
        name="update")
    scan_op = tvm.scan(s_init, s_update, s_state).op
    fxpt = tvm.schedule.ScanFixPointAnalysis(scan_op, tvm.schedule.ScanGetBody(scan_op))
    first, second = scan_op.spatial_axis_[0], scan_op.spatial_axis_[1]
    assert fxpt[first].value == 0
    assert fxpt[second].value == 0
def test_scan4_reach_other():
    """Check that cross-axis state access leaves neither spatial axis a fixed point.

    ``h1`` reads ``s_state[t - 1, j, j]``. After the transpose in ``update``,
    output element ``(i, j)`` therefore depends on ``s_state[t - 1, j, j]``,
    so each axis reaches the other.

    Uses ``l``, ``m``, ``n``, ``s_state`` and ``s_init`` from an enclosing
    scope that is not visible here.
    """
    s_h1 = tvm.compute((l, n, m), lambda t, j, i: s_state[t - 1, j, j], name="h1")
    # Distinct name so the two intermediates can be told apart in lowered IR.
    s_h2 = tvm.compute((l, m, n), lambda t, i, j: s_state[t - 1, i, j] * 2, name="h2")
    s_update = tvm.compute(
        (l, m, n), lambda t, i, j: s_h1[t, j, i] + s_h2[t, i, j], name="update")
    s_scan = tvm.scan(s_init, s_update, s_state)
    fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op)
    assert fxpt[s_scan.op.spatial_axis_[0]].value == 0
    assert fxpt[s_scan.op.spatial_axis_[1]].value == 0
def test_scan5_multi_output():
    """Check that two independent scan states each have an exact spatial fixed point.

    Each update reads its own state at ``[t - 1, i]``. Previously only
    ``spatial_axis_[0]`` was asserted, which belongs to the first state even
    when accessed through ``r1.op``. Both states' axes are now checked.
    """
    m = tvm.var("m")
    n = tvm.var("n")
    x1 = tvm.placeholder((m, n))
    s1 = tvm.placeholder((m, n))
    x2 = tvm.placeholder((m, n))
    s2 = tvm.placeholder((m, n))
    s1_init = tvm.compute((1, n), lambda _, i: x1[0, i])
    s2_init = tvm.compute((1, n), lambda _, i: x2[0, i])
    s1_update = tvm.compute((m, n), lambda t, i: s1[t - 1, i] + x1[t, i])
    s2_update = tvm.compute((m, n), lambda t, i: x2[t, i] + s2[t - 1, i])
    r0, r1 = tvm.scan([s1_init, s2_init],
                      [s1_update, s2_update],
                      [s1, s2])
    fxpt = tvm.schedule.ScanFixPointAnalysis(r0.op)
    # NOTE(review): assumes spatial_axis_ lists the non-scan axes of each
    # state in order (index 0 -> s1, index 1 -> s2); confirm against ScanOp.
    assert fxpt[r0.op.spatial_axis_[0]].value == 1
    assert fxpt[r1.op.spatial_axis_[1]].value == 1
def test_schedule_scan():
    """Check that bound inference starts a simple cumulative scan's scan axis at 1."""
    m, n = tvm.var("m"), tvm.var("n")
    x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
    state = tvm.placeholder((m, n))
    init = tvm.compute((1, n), lambda _, i: x[0, i])
    update = tvm.compute((m, n), lambda t, i: state[t - 1, i] + x[t, i])
    res = tvm.scan(init, update, state)
    assert tuple(res.shape) == (m, n)

    sch = tvm.create_schedule(res.op).normalize()
    bounds = tvm.schedule.InferBound(sch)
    scan_range = bounds[res.op.scan_axis]
    assert scan_range.min.value == 1
    # Only checks that ScheduleOps runs with the inferred bounds.
    stmt = tvm.schedule.ScheduleOps(sch, bounds)
def test_schedule_scan():
    """Check that InferBound places the scan axis minimum of a cumulative scan at 1."""
    m = tvm.var("m")
    n = tvm.var("n")
    ones = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
    s_state = tvm.placeholder((m, n))
    res = tvm.scan(
        tvm.compute((1, n), lambda _, i: ones[0, i]),
        tvm.compute((m, n), lambda t, i: s_state[t - 1, i] + ones[t, i]),
        s_state)
    assert tuple(res.shape) == (m, n)
    s = tvm.create_schedule(res.op)
    s = s.normalize()
    bounds = tvm.schedule.InferBound(s)
    scan_min = bounds[res.op.scan_axis].min.value
    assert scan_min == 1
    tvm.schedule.ScheduleOps(s, bounds)
def test_scan3_not_exact_reach():
    """Check that axis 0 is an exact fixed point but axis 1 is not.

    ``h2`` always reads column 10 of the previous state, so the second
    spatial axis is not reached exactly.

    Uses ``l``, ``m``, ``n``, ``s_state`` and ``s_init`` from an enclosing
    scope that is not visible here.
    """
    s_h1 = tvm.compute((l, n, m), lambda t, j, i: s_state[t - 1, i, j], name="h1")
    # Distinct name so the two intermediates can be told apart in lowered IR.
    s_h2 = tvm.compute((l, m, n), lambda t, i, j: s_state[t - 1, i, 10] * 2, name="h2")
    s_update = tvm.compute(
        (l, m, n), lambda t, i, j: s_h1[t, j, i] + s_h2[t, i, j], name="update")
    s_scan = tvm.scan(s_init, s_update, s_state)
    fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op)
    assert fxpt[s_scan.op.spatial_axis_[0]].value == 1
    assert fxpt[s_scan.op.spatial_axis_[1]].value == 0
def test_scan():
    """Run a GPU cumulative sum built from a scan followed by a compute stage.

    Builds for every enabled GPU backend and compares against
    ``np.cumsum`` along axis 0. Disabled backends are skipped with a message.
    """
    m = tvm.size_var("m")
    n = tvm.size_var("n")
    X = tvm.placeholder((m, n), name="X")
    s_state = tvm.placeholder((m, n))
    s_init = tvm.compute((1, n), lambda _, i: X[0, i])
    s_update = tvm.compute((m, n), lambda t, i: s_state[t - 1, i] + X[t, i])
    scan = tvm.scan(s_init, s_update, s_state)
    # test scan + compute case
    res = tvm.compute((m, n), lambda i, j: scan[i, j])

    # schedule
    s = tvm.create_schedule(res.op)
    num_thread = 256
    block_x = tvm.thread_axis(None, "blockIdx.x")
    thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
    # Split each stage's second axis and bind the halves to blocks / threads.
    for stage, axis in ((s_init, s_init.op.axis[1]),
                        (s_update, s_update.op.axis[1]),
                        (res, res.op.axis[1])):
        outer, inner = s[stage].split(axis, factor=num_thread)
        s[stage].bind(outer, block_x)
        s[stage].bind(inner, thread_x)

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("skip because %s is not enabled.." % device)
            return
        fscan = tvm.build(s, [X, res], device, name="myscan")
        # launch the kernel with concrete sizes.
        rows, cols = 10, 1024
        a_np = np.random.uniform(size=(rows, cols)).astype(res.dtype)
        a = tvm.nd.array(a_np, ctx)
        b = tvm.nd.array(np.zeros((rows, cols), dtype=res.dtype), ctx)
        fscan(a, b)
        tvm.testing.assert_allclose(b.asnumpy(), np.cumsum(a_np, axis=0))

    for device in ("vulkan", "cuda", "metal", "opencl"):
        check_device(device)
def test_scan5_multi_output():
    """Check that two independent scan states each have an exact spatial fixed point.

    Each update reads its own state at ``[t - 1, i]``. Previously only
    ``spatial_axis_[0]`` was asserted, which belongs to the first state even
    when accessed through ``r1.op``. Both states' axes are now checked.
    """
    m = tvm.var("m")
    n = tvm.var("n")
    x1 = tvm.placeholder((m, n))
    s1 = tvm.placeholder((m, n))
    x2 = tvm.placeholder((m, n))
    s2 = tvm.placeholder((m, n))
    s1_init = tvm.compute((1, n), lambda _, i: x1[0, i])
    s2_init = tvm.compute((1, n), lambda _, i: x2[0, i])
    s1_update = tvm.compute((m, n), lambda t, i: s1[t - 1, i] + x1[t, i])
    s2_update = tvm.compute((m, n), lambda t, i: x2[t, i] + s2[t - 1, i])
    r0, r1 = tvm.scan([s1_init, s2_init],
                      [s1_update, s2_update],
                      [s1, s2])
    fxpt = tvm.schedule.ScanFixPointAnalysis(r0.op)
    # NOTE(review): assumes spatial_axis_ lists the non-scan axes of each
    # state in order (index 0 -> s1, index 1 -> s2); confirm against ScanOp.
    assert fxpt[r0.op.spatial_axis_[0]].value == 1
    assert fxpt[r1.op.spatial_axis_[1]].value == 1
def test_scan_inline1():
    """Check that a two-state scan still lowers after inlining the ``x1`` intermediate."""
    m, n = tvm.var("m"), tvm.var("n")
    x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
    s_state1 = tvm.placeholder((m, n))
    s_state2 = tvm.placeholder((m, n))
    # Both states start from the first row of x.
    s_init1 = tvm.compute((1, n), lambda _, i: x[0, i])
    s_init2 = tvm.compute((1, n), lambda _, i: x[0, i])
    s_x1 = tvm.compute((m, n), lambda t, i: s_state1[t - 1, i] + x[t, i], name="x1")
    s_x2 = tvm.compute((m, n), lambda t, i: s_state2[t - 1, i] + 1, name="x2")
    s_update1 = tvm.compute((m, n), lambda t, i: s_x1[t, i], name="u1")
    s_update2 = tvm.compute((m, n), lambda t, i: s_x2[t, i], name="u2")
    res1, res2 = tvm.scan(
        [s_init1, s_init2],
        [s_update1, s_update2],
        [s_state1, s_state2])
    sched = tvm.create_schedule(res1.op)
    sched[s_x1].compute_inline()
    stmt = tvm.lower(sched, [x, res1, res2])
def test_bound_scan():
    """Check that a cache_read attached inside a split scan update gets the split extent."""
    m, n = tvm.var("m"), tvm.var("n")
    X = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
    s_state = tvm.placeholder((m, n))
    s_init = tvm.compute((1, n), lambda _, i: X[0, i])
    s_update = tvm.compute((m, n), lambda t, i: s_state[t - 1, i] + X[t, i])
    s_scan = tvm.scan(s_init, s_update, s_state)
    assert tuple(s_scan.shape) == (m, n)

    sched = tvm.create_schedule(s_scan.op)
    XX = sched.cache_read(X, "local", s_update)
    outer, _inner = sched[s_update].split(s_update.op.axis[1], factor=4)
    sched[XX].compute_at(sched[s_update], outer)
    sched = sched.normalize()
    bounds = tvm.schedule.InferBound(sched)
    tvm.schedule.ScheduleOps(sched, bounds)
    # The cached copy only needs a 4-wide tile per outer iteration.
    assert bounds[XX.op.axis[1]].extent.value == 4
def test_scan():
    """Run a GPU cumulative sum via scan and check it against ``np.cumsum``.

    Each GPU backend that is not enabled is skipped with a message.
    """
    m, n = tvm.var("m"), tvm.var("n")
    X = tvm.placeholder((m, n), name="X")
    s_state = tvm.placeholder((m, n))
    s_init = tvm.compute((1, n), lambda _, i: X[0, i])
    s_update = tvm.compute((m, n), lambda t, i: s_state[t - 1, i] + X[t, i])
    res = tvm.scan(s_init, s_update, s_state)

    # schedule
    s = tvm.create_schedule(res.op)
    num_thread = 256
    block_x = tvm.thread_axis(None, "blockIdx.x")
    thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
    for stage in (s_init, s_update):
        outer, inner = s[stage].split(stage.op.axis[1], factor=num_thread)
        s[stage].bind(outer, block_x)
        s[stage].bind(inner, thread_x)

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("skip because %s is not enabled.." % device)
            return
        fscan = tvm.build(s, [X, res], device, name="myscan")
        # launch the kernel.
        shape = (10, 1024)
        a_np = np.random.uniform(size=shape).astype(res.dtype)
        a = tvm.nd.array(a_np, ctx)
        b = tvm.nd.array(np.zeros(shape, dtype=res.dtype), ctx)
        fscan(a, b)
        np.testing.assert_allclose(b.asnumpy(), np.cumsum(a_np, axis=0))

    for target in ("vulkan", "cuda", "metal", "opencl"):
        check_device(target)
def test_scan_multi_out():
    """Check multi-state scan outputs' value_index and the ScanOp's JSON round trip."""
    m = tvm.size_var("m")
    n = tvm.size_var("n")
    x1, s1, x2, s2 = (tvm.placeholder((m, n)) for _ in range(4))
    s1_init = tvm.compute((1, n), lambda _, i: x1[0, i])
    s2_init = tvm.compute((1, n), lambda _, i: x2[0, i])
    # State 1 also reads state 2; state 2 only reads itself.
    s1_update = tvm.compute(
        (m, n), lambda t, i: s1[t - 1, i] + s2[t - 1, i] + x1[t, i])
    s2_update = tvm.compute((m, n), lambda t, i: x2[t, i] + s2[t - 1, i])
    r0, r1 = tvm.scan([s1_init, s2_init], [s1_update, s2_update], [s1, s2])
    for expected, out in enumerate((r0, r1)):
        assert out.value_index == expected
    restored = tvm.load_json(tvm.save_json(r0.op))
    assert isinstance(restored, tvm.tensor.ScanOp)
def test_scan_multi_out():
    """Check that each scan output records its position and the ScanOp round-trips via JSON."""
    m = tvm.var("m")
    n = tvm.var("n")
    shape = (m, n)
    x1 = tvm.placeholder(shape)
    s1 = tvm.placeholder(shape)
    x2 = tvm.placeholder(shape)
    s2 = tvm.placeholder(shape)
    inits = [tvm.compute((1, n), lambda _, i: x1[0, i]),
             tvm.compute((1, n), lambda _, i: x2[0, i])]
    updates = [
        tvm.compute((m, n), lambda t, i: s1[t - 1, i] + s2[t - 1, i] + x1[t, i]),
        tvm.compute((m, n), lambda t, i: x2[t, i] + s2[t - 1, i]),
    ]
    r0, r1 = tvm.scan(inits, updates, [s1, s2])
    assert r0.value_index == 0
    assert r1.value_index == 1
    json_str = tvm.save_json(r0.op)
    assert isinstance(tvm.load_json(json_str), tvm.tensor.ScanOp)
def rnn_matexp():
    """Build, time and optionally verify a repeated mat-vec scan on CUDA.

    Each timestep computes
    ``update[t, i, j] = sum_kh state[t - 1, i, kh] * Whh[kh, j]``
    starting from an all-ones ``init``. The reduction is rfactored across
    ``threadIdx.y``. The state is cached in "shared" then "local" scope, and
    ``Whh`` in "local" scope.

    Reads the module-level flags ``DETECT_GLOBAL_BARRIER``,
    ``PERSIST_KERNEL`` and ``SKIP_CHECK``, which are defined outside this
    function.
    """
    n_num_step = 128
    n_num_hidden = 1152
    n_batch_size = 4
    detect_global_barrier = DETECT_GLOBAL_BARRIER

    num_step = tvm.var("num_step")
    num_hidden = tvm.convert(n_num_hidden)
    batch_size = tvm.convert(n_batch_size)
    num_thread_y = 8
    num_thread_x = 16 * 3
    num_sm = 24

    Whh = tvm.placeholder((num_hidden, num_hidden), name="Whh")
    s_init = tvm.compute((1, batch_size, num_hidden),
                         lambda _, i, j: 1.0, name="init")
    s_state = tvm.placeholder((num_step, batch_size, num_hidden))
    kh = tvm.reduce_axis((0, num_hidden), name="kh")
    s_update = tvm.compute(
        (num_step, batch_size, num_hidden),
        lambda t, i, j: tvm.sum(s_state[t - 1, i, kh] * Whh[kh, j], axis=kh),
        name="update")
    s_scan = tvm.scan(s_init, s_update, s_state)

    # schedule
    s = tvm.create_schedule(s_scan.op)
    CL = s_update
    SS = s.cache_read(s_state, "shared", [CL])
    SL = s.cache_read(SS, "local", [CL])
    WhhL = s.cache_read(Whh, "local", [CL])
    # Split the reduction into num_thread_y parts and rfactor so each part is
    # a separate partial-sum stage (CLF).
    ko, ki = s[CL].split(s[CL].op.reduce_axis[0], nparts=num_thread_y)
    CLF = s.rfactor(CL, ko)

    block_x = tvm.thread_axis((0, num_sm), "blockIdx.x")
    thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
    thread_y = tvm.thread_axis((0, num_thread_y), "threadIdx.y")
    if PERSIST_KERNEL:
        s[s_scan.op].env_threads([block_x, thread_y, thread_x])

    bx, xi = s[s_init].split(s_init.op.axis[2], nparts=num_sm)
    tx, xi = s[s_init].split(xi, nparts=num_thread_x)
    s[s_init].bind(bx, block_x)
    s[s_init].bind(tx, thread_x)

    bx, xi = s[s_update].split(s[CL].op.axis[2], nparts=num_sm)
    tx, xi = s[s_update].split(xi, nparts=num_thread_x)
    s[s_update].bind(bx, block_x)
    s[s_update].bind(tx, thread_x)

    s[CL].bind(s[CL].op.reduce_axis[0], thread_y)
    s[CLF].compute_at(s[CL], s[CL].op.reduce_axis[0])
    # Duplicate store predicate: only threadIdx.y == 0 writes the result.
    s[CL].set_store_predicate(thread_y.equal(0))

    if PERSIST_KERNEL:
        s[WhhL].compute_at(s[s_scan], thread_x)
        s[WhhL].unroll(WhhL.op.axis[0])
    else:
        s[WhhL].compute_at(s[CLF], CLF.op.axis[3])

    kr, ki = s[CLF].split(CLF.op.reduce_axis[0], nparts=1)
    ko, ki = s[CLF].split(ki, factor=4)
    s[SS].compute_at(s[CLF], kr)
    s[SL].compute_at(s[CLF], ko)

    xo, xi = s[SS].split(SS.op.axis[2], factor=num_thread_x * num_thread_y * 3)
    ty, xi = s[SS].split(xi, nparts=num_thread_y)
    tx, xi = s[SS].split(xi, nparts=num_thread_x)
    s[SS].bind(ty, thread_y)
    s[SS].bind(tx, thread_x)

    def check_device(target):
        with tvm.build_config(
                detect_global_barrier=detect_global_barrier,
                auto_unroll_max_step=128,
                unroll_explicit=False):
            f = tvm.build(s, [s_scan, Whh], target)
        ctx = tvm.gpu(0) if target == "cuda" else tvm.cl(0)
        # launch the kernel.
        res_np = np.zeros(
            (n_num_step, n_batch_size, n_num_hidden)).astype("float32")
        Whh_np = np.zeros((n_num_hidden, n_num_hidden)).astype("float32")
        Whh_np[:] = 2.0 / n_num_hidden
        Whh_np[:, n_num_hidden // 2:] = 0

        res_a = tvm.nd.array(res_np, ctx)
        Whh_a = tvm.nd.array(Whh_np, ctx)
        # Warm-up run; only the second call is timed.
        f(res_a, Whh_a)
        ctx.sync()
        # perf_counter is monotonic and high-resolution, unlike time.time().
        tstart = time.perf_counter()
        f(res_a, Whh_a)
        ctx.sync()
        tgap = time.perf_counter() - tstart
        print("Time cost=%g" % tgap)
        # correctness
        if not SKIP_CHECK:
            res_gpu = res_a.asnumpy()
            res_cmp = np.ones_like(res_np).astype("float64")
            Whh_np = Whh_np.astype("float64")
            for t in range(1, n_num_step):
                res_cmp[t][:] = np.dot(res_cmp[t - 1], Whh_np)
            # Report every batch-0 element that differs by more than 1e-5.
            # argwhere yields indices in row-major (step, hidden) order, matching
            # a nested loop but without ~147k Python iterations.
            diff = np.abs(res_cmp[:, 0, :] - res_gpu[:, 0, :])
            for i, j in np.argwhere(diff > 1e-5):
                print("%d, %d: %g vs %g" % (i, j, res_cmp[i, 0, j], res_gpu[i, 0, j]))
            tvm.testing.assert_allclose(res_gpu, res_cmp, rtol=1e-3)

    check_device("cuda")
def rnn_matexp():
    """Build, time and optionally verify a repeated mat-vec scan on CUDA.

    Each timestep computes
    ``update[t, i, j] = sum_kh state[t - 1, i, kh] * Whh[kh, j]``
    starting from an all-ones ``init``. Reads the module-level flags
    ``DETECT_GLOBAL_BARRIER``, ``PERSIST_KERNEL`` and ``SKIP_CHECK``, which
    are defined outside this function.
    """
    n_num_step = 128
    n_num_hidden = 1152
    n_batch_size = 4
    detect_global_barrier = DETECT_GLOBAL_BARRIER

    # Symbolic step count; hidden size and batch size are constants.
    num_step = tvm.var("num_step")
    num_hidden = tvm.convert(n_num_hidden)
    batch_size = tvm.convert(n_batch_size)
    num_thread_y = 8
    num_thread_x = 16 * 3
    num_sm = 24

    Whh = tvm.placeholder((num_hidden, num_hidden), name="Whh")
    s_init = tvm.compute((1, batch_size, num_hidden), lambda _, i, j: 1.0, name="init")
    s_state = tvm.placeholder((num_step, batch_size, num_hidden))
    kh = tvm.reduce_axis((0, num_hidden), name="kh")
    s_update = tvm.compute(
        (num_step, batch_size, num_hidden),
        lambda t, i, j: tvm.sum(s_state[t - 1, i, kh] * Whh[kh, j], axis=kh),
        name="update")
    s_scan = tvm.scan(s_init, s_update, s_state)
    # schedule
    s = tvm.create_schedule(s_scan.op)
    CL = s_update
    # Stage the state through a "shared" copy and then a "local" copy;
    # Whh gets a "local" copy.
    SS = s.cache_read(s_state, "shared", [CL])
    SL = s.cache_read(SS, "local", [CL])
    WhhL = s.cache_read(Whh, "local", [CL])
    # Split the reduction into num_thread_y parts and rfactor it, so each part
    # becomes its own partial-sum stage (CLF).
    ko, ki = s[CL].split(s[CL].op.reduce_axis[0], nparts=num_thread_y)
    CLF = s.rfactor(CL, ko)
    block_x = tvm.thread_axis((0, num_sm), "blockIdx.x")
    thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
    thread_y = tvm.thread_axis((0, num_thread_y), "threadIdx.y")
    if PERSIST_KERNEL:
        # Persistent mode: thread axes are declared on the scan op itself
        # (presumably kept live across all scan steps).
        s[s_scan.op].env_threads([block_x, thread_y, thread_x])
    bx, xi = s[s_init].split(s_init.op.axis[2], nparts=num_sm)
    tx, xi = s[s_init].split(xi, nparts=num_thread_x)
    s[s_init].bind(bx, block_x)
    s[s_init].bind(tx, thread_x)
    bx, xi = s[s_update].split(s[CL].op.axis[2], nparts=num_sm)
    tx, xi = s[s_update].split(xi, nparts=num_thread_x)
    s[s_update].bind(bx, block_x)
    s[s_update].bind(tx, thread_x)
    # Each threadIdx.y computes one rfactor partial sum.
    s[CL].bind(s[CL].op.reduce_axis[0], thread_y)
    s[CLF].compute_at(s[CL], s[CL].op.reduce_axis[0])
    # Duplicate store predicate: only threadIdx.y == 0 writes the result.
    s[CL].set_store_predicate(thread_y.equal(0))
    if PERSIST_KERNEL:
        s[WhhL].compute_at(s[s_scan], thread_x)
        s[WhhL].unroll(WhhL.op.axis[0])
    else:
        s[WhhL].compute_at(s[CLF], CLF.op.axis[3])
    kr, ki = s[CLF].split(CLF.op.reduce_axis[0], nparts=1)
    ko, ki = s[CLF].split(ki, factor=4)
    s[SS].compute_at(s[CLF], kr)
    s[SL].compute_at(s[CLF], ko)
    # Cooperative load of the shared state copy across threadIdx.y/threadIdx.x.
    xo, xi = s[SS].split(SS.op.axis[2], factor=num_thread_x * num_thread_y * 3)
    ty, xi = s[SS].split(xi, nparts=num_thread_y)
    tx, xi = s[SS].split(xi, nparts=num_thread_x)
    s[SS].bind(ty, thread_y)
    s[SS].bind(tx, thread_x)

    def check_device(target):
        """Build for `target`, time one run after a warm-up, then check against numpy."""
        with tvm.build_config(detect_global_barrier=detect_global_barrier,
                              auto_unroll_max_step=128,
                              unroll_explicit=False):
            f = tvm.build(s, [s_scan, Whh], target)
        ctx = tvm.gpu(0) if target == "cuda" else tvm.cl(0)
        # launch the kernel.
        res_np = np.zeros(
            (n_num_step, n_batch_size, n_num_hidden)).astype("float32")
        Whh_np = np.zeros((n_num_hidden, n_num_hidden)).astype("float32")
        # Left half of Whh's columns is 2 / n_num_hidden, right half is zero.
        Whh_np[:] = 2.0 / n_num_hidden
        Whh_np[:, n_num_hidden // 2:] = 0
        res_a = tvm.nd.array(res_np, ctx)
        Whh_a = tvm.nd.array(Whh_np, ctx)
        # Warm-up run; the timed run is the second call.
        f(res_a, Whh_a)
        ctx.sync()
        # measure time cost of second step.
        tstart = time.time()
        f(res_a, Whh_a)
        ctx.sync()
        tgap = time.time() - tstart
        print("Time cost=%g" % tgap)
        # correctness
        if not SKIP_CHECK:
            res_gpu = res_a.asnumpy()
            # Reference in float64: all-ones start, repeated dot with Whh.
            res_cmp = np.ones_like(res_np).astype("float64")
            Whh_np = Whh_np.astype("float64")
            for t in range(1, n_num_step):
                res_cmp[t][:] = np.dot(res_cmp[t - 1], Whh_np)
            # Print every batch-0 element that differs by more than 1e-5.
            # NOTE(review): this is a pure-Python double loop over
            # n_num_step * n_num_hidden elements; a vectorized diff would be faster.
            for i in range(n_num_step):
                for j in range(n_num_hidden):
                    if abs(res_cmp[i, 0, j] - res_gpu[i, 0, j]) > 1e-5:
                        print("%d, %d: %g vs %g" % (i, j, res_cmp[i, 0, j], res_gpu[i, 0, j]))
            tvm.testing.assert_allclose(res_gpu, res_cmp, rtol=1e-3)

    check_device("cuda")
's1') nnpu.utils.MarkScope(s_update_1) k = tvm.reduce_axis((0, m), 'k1') s_update_2 = tvm.compute(h_shape, lambda t, i: tvm.sum(u_buf[i, k] * h_state[t - 1, k], axis=k), 's2') nnpu.utils.MarkScope(s_update_2) s_update_3 = tvm.compute(h_shape, lambda t, i: s_update_1[t, i] + s_update_2[t, i], 's3') nnpu.utils.MarkScope(s_update_3) s_update_4 = tvm.compute(h_shape, lambda t, i: s_update_3[t, i] + b_buf[i], 's4') nnpu.utils.MarkScope(s_update_4) s_scan = tvm.scan(h_init_buf, s_update_4, h_state, inputs=[x_buf]) nnpu.utils.MarkScope(s_scan) #res = nnpu.utils.reshape(s_scan, h_shape) #res_host, _ = nnpu.utils.CopyBufToH(res, 'sc') s = nnpu.create_schedule(s_scan.op) # tensorize s[s_update_1].tensorize(s_update_1.op.axis[1], env.intrins.get('GEMM', shape=gemm_shape, mode='inc', reduce=True)) #s[s_update_2].tensorize(s_update_2.op.axis[1], # env.intrins.get('GEMM', shape=gemm_shape, mode='w', reduce=True)) s[s_update_3].tensorize(s_update_3.op.axis[1], env.intrins.get('VAddV', mode='w')) #s[s_update_4].tensorize(s_update_4.op.axis[1], # env.intrins.get('VAddV', mode='w')) print(tvm.lower(s, [x, w, u, b, h_init, s_scan], simple_mode=True))
# :code:`s_update` describes how to update the value at timestep t. The update # value can refer back to the values of previous timestep via state placeholder. # Note that while it is invalid to refer to :code:`s_state` at current or later timestep. # # The scan takes in state placeholder, initial value and update description. # It is also recommended(although not necessary) to list the inputs to the scan cell. # The result of the scan is a tensor, giving the result of :code:`s_state` after the # update over the time domain. # m = tvm.var("m") n = tvm.var("n") X = tvm.placeholder((m, n), name="X") s_state = tvm.placeholder((m, n)) s_init = tvm.compute((1, n), lambda _, i: X[0, i]) s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i]) s_scan = tvm.scan(s_init, s_update, s_state, inputs=[X]) ###################################################################### # Schedule the Scan Cell # ---------------------- # We can schedule the body of the scan by scheduling the update and # init part seperately. Note that it is invalid to schedule the # first iteration dimension of the update part. # To split on the time iteration, user can schedule on scan_op.scan_axis instead. # s = tvm.create_schedule(s_scan.op) num_thread = 256 block_x = tvm.thread_axis("blockIdx.x") thread_x = tvm.thread_axis("threadIdx.x") xo, xi = s[s_init].split(s_init.op.axis[1], factor=num_thread) s[s_init].bind(xo, block_x)
def test_lstm_cell_inline():
    """Lower an LSTM-cell scan once all gate stages are inlined.

    The scan carries a hidden state ``h`` and a cell state ``c``. Gate ``k``
    is read from slice ``k`` of axis 1 of the fused ``gates`` tensor.
    """
    num_step = 128
    num_input = 256
    num_hidden = 1152
    batch_size = 4
    # Global transition matrix
    X = tvm.placeholder((num_step - 1, batch_size, num_input), name="X")
    Wi2h = tvm.placeholder((4, num_hidden, num_input), name="Wi2h")
    Wh2h = tvm.placeholder((4, num_hidden, num_hidden), name="Wh2h")
    # h: output hidden state, c: cell state.
    state_shape = (num_step, batch_size, num_hidden)
    s_state_h = tvm.placeholder(state_shape)
    s_state_c = tvm.placeholder(state_shape)
    init_shape = (1, batch_size, num_hidden)
    s_init_c = tvm.compute(init_shape, lambda *i: 0.0, name="init_c")
    s_init_h = tvm.compute(init_shape, lambda *i: 0.0, name="init_h")

    # LSTM transition: input-to-hidden and hidden-to-hidden products.
    fused_shape = (num_step, 4, batch_size, num_hidden)
    k_in = tvm.reduce_axis((0, num_input), name="ki2h")
    s_i2h = tvm.compute(
        fused_shape,
        lambda t, x, i, j: tvm.sum(X[t - 1, i, k_in] * Wi2h[x, j, k_in], axis=k_in),
        name="s_i2h")
    k_hid = tvm.reduce_axis((0, num_hidden), name="ki2h")
    s_h2h = tvm.compute(
        fused_shape,
        lambda t, x, i, j: tvm.sum(s_state_h[t - 1, i, k_hid] * Wh2h[x, j, k_hid], axis=k_hid),
        name="s_h2h")

    # Gate rules
    gates = tvm.compute(s_i2h.shape, lambda *i: s_i2h(*i) + s_h2h(*i), name="gates")
    gshape = (num_step, batch_size, num_hidden)

    def gate(index, activation, name):
        # Apply `activation` to slice `index` of the fused gates tensor.
        return tvm.compute(
            gshape, lambda t, i, j: activation(gates[t, index, i, j]), name=name)

    in_gate = gate(0, tvm.sigmoid, "in_gate")
    in_transform = gate(1, tvm.tanh, "in_transform")
    forget_gate = gate(2, tvm.sigmoid, "forget_gate")
    out_gate = gate(3, tvm.sigmoid, "out_gate")

    next_c = tvm.compute(
        gshape,
        lambda t, i, j: forget_gate[t, i, j] * s_state_c[t - 1, i, j]
        + in_gate[t, i, j] * in_transform[t, i, j],
        name="next_c")
    next_h = tvm.compute(
        gshape, lambda t, i, j: out_gate[t, i, j] * tvm.tanh(next_c[t, i, j]), name="next_h")
    update_c = tvm.compute(gshape, lambda *i: next_c(*i), name="update_c")
    update_h = tvm.compute(gshape, lambda *i: next_h(*i), name="update_h")

    scan_h, scan_c = tvm.scan(
        [s_init_h, s_init_c],
        [update_h, update_c],
        [s_state_h, s_state_c],
        inputs=[X],
        name="lstm_scan")

    # schedule: inline every gate computation.
    s = tvm.create_schedule(scan_h.op)
    for stage in (gates, in_gate, in_transform, forget_gate, out_gate):
        s[stage].compute_inline()
    # verify we can lower correctly
    tvm.lower(s, [X, Wi2h, Wh2h, scan_h, scan_c])
def lstm():
    """Build, run and time an LSTM recurrence as a persistent CUDA kernel.

    Only the hidden-to-hidden transition (``s_h2h``) is computed inside the
    scan; the input-to-hidden contribution is passed in precomputed as
    ``Xi2h``. The ``s_h2h`` reduction is split ``nparts=num_thread_y`` ways,
    rfactored, and the remaining cross-thread reduction is bound to
    ``threadIdx.y``.

    Reads the module-level flags ``PERSIST_KERNEL``, ``UNROLL_WLOAD`` and
    ``DETECT_GLOBAL_BARRIER``, which are defined elsewhere in this file.

    Raises:
        ValueError: if ``PERSIST_KERNEL`` is false.
    """
    if not PERSIST_KERNEL:
        raise ValueError("Non persist LSTM not yet supported")
    num_thread_y = 8
    num_thread_x = 16 * 3 // 2
    num_sm = 24
    # Concrete step count used when running; the schedule itself uses the
    # symbolic `num_step` below.
    n_num_step = 128
    num_step = tvm.var('num_step')
    num_hidden = 1152 // 2
    batch_size = 1
    # Global transition matrix
    # Input hidden channel can be pre-calculated by a gemm
    Xi2h = tvm.placeholder((num_step, batch_size, 4, num_hidden), name="Xi2h")
    # Only handle hidden transition, saves space.
    Wh2h = tvm.placeholder((4, num_hidden, num_hidden), name="Wh2h")
    # h: output hidden state, c: cell state.
    s_state_h = tvm.placeholder((num_step, batch_size, num_hidden))
    s_state_c = tvm.placeholder((num_step, batch_size, num_hidden))
    s_init_c = tvm.compute((1, batch_size, num_hidden),
                           lambda *i: 0.0, name="init_c")
    s_init_h = tvm.compute((1, batch_size, num_hidden),
                           lambda *i: 0.0, name="init_h")
    # LSTM transition: h[t - 1] @ Wh2h for each of the 4 gate slots.
    k = tvm.reduce_axis((0, num_hidden), name="ki2h")
    s_h2h = tvm.compute(
        (num_step, batch_size, 4, num_hidden),
        lambda t, i, x, j: tvm.sum(s_state_h[t - 1, i, k] * Wh2h[x, j, k],
                                   axis=k),
        name="s_h2h")
    # Gate rules
    gates = tvm.compute(Xi2h.shape, lambda *i:
                        Xi2h(*i) + s_h2h(*i), name="gates")
    gshape = (num_step, batch_size, num_hidden)
    in_gate = tvm.compute(gshape, lambda t, i, j: tvm.sigmoid(gates[t, i, 0, j]),
                          name="in_gate")
    in_transform = tvm.compute(gshape, lambda t, i, j: tvm.tanh(gates[t, i, 1, j]),
                               name="in_transform")
    forget_gate = tvm.compute(gshape, lambda t, i, j: tvm.sigmoid(gates[t, i, 2, j]),
                              name="forget_gate")
    out_gate = tvm.compute(gshape, lambda t, i, j: tvm.sigmoid(gates[t, i, 3, j]),
                           name="out_gate")
    next_c = tvm.compute(
        gshape,
        lambda t, i, j: forget_gate[t, i, j] * s_state_c[
            t - 1, i, j] + in_gate[t, i, j] * in_transform[t, i, j],
        name="next_c")
    next_h = tvm.compute(
        gshape,
        lambda t, i, j: out_gate[t, i, j] * tvm.tanh(next_c[t, i, j]),
        name="next_h")
    update_c = tvm.compute(gshape, lambda *i: next_c(*i), name="update_c")
    update_h = tvm.compute(gshape, lambda *i: next_h(*i), name="update_h")
    # schedule
    scan_h, scan_c = tvm.scan(
        [s_init_h, s_init_c],
        [update_h, update_c],
        [s_state_h, s_state_c],
        inputs=[Xi2h],
        name="lstm_scan")
    # schedule
    s = tvm.create_schedule(scan_h.op)
    # Inline gate computations
    s[gates].compute_inline()
    s[in_gate].compute_inline()
    s[in_transform].compute_inline()
    s[forget_gate].compute_inline()
    s[out_gate].compute_inline()

    block_x = tvm.thread_axis((0, num_sm), "blockIdx.x")
    thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
    thread_y = tvm.thread_axis((0, num_thread_y), "threadIdx.y")

    # Stage the recurrent state in "shared" scope and the weights in
    # "local" scope for their respective consumers.
    s_state_h_S = s.cache_read(s_state_h, "shared", [s_h2h])
    s_state_c_S = s.cache_read(s_state_c, "shared", [next_c])
    Wh2hL = s.cache_read(Wh2h, "local", [s_h2h])

    # Split the reduction into num_thread_y parts, rfactor the outer part,
    # and bind the remaining cross-part reduction to threadIdx.y.
    ko, ki = s[s_h2h].split(s[s_h2h].op.reduce_axis[0], nparts=num_thread_y)
    s_h2h_rf = s.rfactor(s_h2h, ko)
    s[s_h2h].bind(s[s_h2h].op.reduce_axis[0], thread_y)
    s[s_h2h_rf].compute_at(s[s_h2h], s[s_h2h].op.reduce_axis[0])

    if PERSIST_KERNEL:
        # Attach the launch threads to the scan stage itself, presumably so
        # all time steps run inside one persistent kernel.
        s[scan_h.op].env_threads([block_x, thread_y, thread_x])
        s[Wh2hL].compute_at(s[scan_h.op], thread_x)
    else:
        # Unreachable as written: the function raises at the top when
        # PERSIST_KERNEL is false.
        s[Wh2hL].compute_at(s[s_h2h], s[s_h2h].op.axis[3])

    if UNROLL_WLOAD:
        s[Wh2hL].unroll(Wh2hL.op.axis[0])
        s[Wh2hL].unroll(Wh2hL.op.axis[2])

    s[s_state_h_S].compute_at(s[s_h2h_rf], s[s_h2h_rf].op.axis[3])
    s[s_state_c_S].compute_at(s[scan_h.op], s[scan_h].op.scan_axis)

    # Cooperative load of the shared h state across (threadIdx.y, threadIdx.x).
    for ss in [s_state_h_S]:
        xo, xi = s[ss].split(ss.op.axis[2], factor=num_thread_x * num_thread_y)
        ty, xi = s[ss].split(xi, nparts=num_thread_y)
        tx, xi = s[ss].split(xi, nparts=num_thread_x)
        s[ss].bind(ty, thread_y)
        s[ss].bind(tx, thread_x)

    # Spread the hidden axis of the init stages over blocks and threadIdx.x.
    for init in [s_init_c, s_init_h]:
        bx, xi = s[init].split(init.op.axis[2], nparts=num_sm)
        tx, xi = s[init].split(xi, nparts=num_thread_x)
        s[init].bind(bx, block_x)
        s[init].bind(tx, thread_x)

    # Only threads with threadIdx.y == 0 store these results.
    s[next_c].set_store_predicate(thread_y.equal(0))
    s[next_h].set_store_predicate(thread_y.equal(0))

    for update in [update_c, update_h]:
        bx, xi = s[update].split(s[update].op.axis[2], nparts=num_sm)
        tx, xi = s[update].split(xi, nparts=num_thread_x)
        s[update].bind(bx, block_x)
        s[update].bind(tx, thread_x)
        s[update].set_store_predicate(thread_y.equal(0))

    # verify we can lower correctly
    def check_device(target):
        """Build for `target`, run once, then time it with time_evaluator."""
        # Concrete length for the host arrays; the built kernel's num_step
        # is symbolic (tvm.var) and is presumably bound from argument shapes.
        num_step = n_num_step
        flstm = tvm.build(s, [Xi2h, Wh2h, scan_h, scan_c], target)
        ctx = tvm.gpu(0) if target == "cuda" else tvm.cl(0)
        # launch the kernel.
        scan_h_np = np.zeros(
            (num_step, batch_size, num_hidden)).astype("float32")
        scan_c_np = np.zeros(
            (num_step, batch_size, num_hidden)).astype("float32")
        Xi2h_np = np.random.normal(
            size=(num_step, batch_size, 4, num_hidden)).astype("float32")
        Wh2h_np = np.random.normal(
            size=(4, num_hidden, num_hidden)).astype("float32")
        scan_h_a = tvm.nd.array(scan_h_np, ctx)
        scan_c_a = tvm.nd.array(scan_c_np, ctx)
        Xi2h_a = tvm.nd.array(Xi2h_np, ctx)
        Wh2h_a = tvm.nd.array(Wh2h_np, ctx)
        flstm(Xi2h_a, Wh2h_a, scan_h_a, scan_c_a)
        ctx.sync()
        # measure time cost of subsequent runs (number=1, repeat=1000); the
        # call above presumably serves as a warm-up.
        evaluator = flstm.time_evaluator(flstm.entry_name, ctx, 1, repeat=1000)
        eval_result = evaluator(Xi2h_a, Wh2h_a, scan_h_a, scan_c_a)
        print("Time cost=%g" % eval_result.mean)

    # set unroll_explicit (here False) for more readable code.
    with tvm.build_config(
            detect_global_barrier=DETECT_GLOBAL_BARRIER,
            auto_unroll_max_step=128,
            unroll_explicit=False):
        check_device("cuda")
def lstm():
    """Build, run and time an LSTM recurrence as a persistent CUDA kernel.

    Only the hidden-to-hidden transition (``s_h2h``) is computed inside the
    scan; the input-to-hidden contribution is passed in precomputed as
    ``Xi2h``. The ``s_h2h`` reduction is split ``nparts=num_thread_y`` ways,
    rfactored, and the remaining cross-thread reduction is bound to
    ``threadIdx.y``. One warm-up call is made, then a single further call is
    timed with a monotonic clock.

    Reads the module-level flags ``PERSIST_KERNEL``, ``UNROLL_WLOAD`` and
    ``DETECT_GLOBAL_BARRIER``, which are defined elsewhere in this file.

    NOTE(review): this redefines an earlier ``lstm`` in the same file, so
    only this definition is bound at module level afterwards.

    Raises:
        ValueError: if ``PERSIST_KERNEL`` is false.
    """
    if not PERSIST_KERNEL:
        raise ValueError("Non persist LSTM not yet supported")
    num_thread_y = 8
    # Floor division keeps these integral on Python 3. True division would
    # give 24.0 / 576.0, which are used as thread extents and tensor shapes
    # below (numpy rejects float shapes with a TypeError).
    num_thread_x = 16 * 3 // 2
    num_sm = 24
    # Concrete step count used when running; the schedule itself uses the
    # symbolic `num_step` below.
    n_num_step = 128
    num_step = tvm.var('num_step')
    num_hidden = 1152 // 2
    batch_size = 1
    # Global transition matrix.
    # Input hidden channel can be pre-calculated by a gemm.
    Xi2h = tvm.placeholder((num_step, batch_size, 4, num_hidden), name="Xi2h")
    # Only handle hidden transition, saves space.
    Wh2h = tvm.placeholder((4, num_hidden, num_hidden), name="Wh2h")
    # h: output hidden state, c: cell state.
    s_state_h = tvm.placeholder((num_step, batch_size, num_hidden))
    s_state_c = tvm.placeholder((num_step, batch_size, num_hidden))
    s_init_c = tvm.compute((1, batch_size, num_hidden),
                           lambda *i: 0.0, name="init_c")
    s_init_h = tvm.compute((1, batch_size, num_hidden),
                           lambda *i: 0.0, name="init_h")
    # LSTM transition: h[t - 1] @ Wh2h for each of the 4 gate slots.
    k = tvm.reduce_axis((0, num_hidden), name="ki2h")
    s_h2h = tvm.compute(
        (num_step, batch_size, 4, num_hidden),
        lambda t, i, x, j: tvm.sum(s_state_h[t - 1, i, k] * Wh2h[x, j, k],
                                   axis=k),
        name="s_h2h")
    # Gate rules.
    gates = tvm.compute(Xi2h.shape, lambda *i: Xi2h(*i) + s_h2h(*i),
                        name="gates")
    gshape = (num_step, batch_size, num_hidden)
    in_gate = tvm.compute(gshape, lambda t, i, j: tvm.sigmoid(gates[t, i, 0, j]),
                          name="in_gate")
    in_transform = tvm.compute(gshape, lambda t, i, j: tvm.tanh(gates[t, i, 1, j]),
                               name="in_transform")
    forget_gate = tvm.compute(gshape, lambda t, i, j: tvm.sigmoid(gates[t, i, 2, j]),
                              name="forget_gate")
    out_gate = tvm.compute(gshape, lambda t, i, j: tvm.sigmoid(gates[t, i, 3, j]),
                           name="out_gate")
    next_c = tvm.compute(
        gshape,
        lambda t, i, j: (forget_gate[t, i, j] * s_state_c[t - 1, i, j]
                         + in_gate[t, i, j] * in_transform[t, i, j]),
        name="next_c")
    next_h = tvm.compute(
        gshape,
        lambda t, i, j: out_gate[t, i, j] * tvm.tanh(next_c[t, i, j]),
        name="next_h")
    update_c = tvm.compute(gshape, lambda *i: next_c(*i), name="update_c")
    update_h = tvm.compute(gshape, lambda *i: next_h(*i), name="update_h")
    # Recurrence over time.
    scan_h, scan_c = tvm.scan(
        [s_init_h, s_init_c],
        [update_h, update_c],
        [s_state_h, s_state_c],
        inputs=[Xi2h],
        name="lstm_scan")
    # Schedule.
    s = tvm.create_schedule(scan_h.op)
    # Inline gate computations.
    s[gates].compute_inline()
    s[in_gate].compute_inline()
    s[in_transform].compute_inline()
    s[forget_gate].compute_inline()
    s[out_gate].compute_inline()

    block_x = tvm.thread_axis((0, num_sm), "blockIdx.x")
    thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
    thread_y = tvm.thread_axis((0, num_thread_y), "threadIdx.y")

    # Stage the recurrent state in "shared" scope and the weights in
    # "local" scope for their respective consumers.
    s_state_h_S = s.cache_read(s_state_h, "shared", [s_h2h])
    s_state_c_S = s.cache_read(s_state_c, "shared", [next_c])
    Wh2hL = s.cache_read(Wh2h, "local", [s_h2h])

    # Split the reduction into num_thread_y parts, rfactor the outer part,
    # and bind the remaining cross-part reduction to threadIdx.y.
    ko, ki = s[s_h2h].split(s[s_h2h].op.reduce_axis[0], nparts=num_thread_y)
    s_h2h_rf = s.rfactor(s_h2h, ko)
    s[s_h2h].bind(s[s_h2h].op.reduce_axis[0], thread_y)
    s[s_h2h_rf].compute_at(s[s_h2h], s[s_h2h].op.reduce_axis[0])

    if PERSIST_KERNEL:
        # Attach the launch threads to the scan stage itself, presumably so
        # all time steps run inside one persistent kernel.
        s[scan_h.op].env_threads([block_x, thread_y, thread_x])
        s[Wh2hL].compute_at(s[scan_h.op], thread_x)
    else:
        # Unreachable as written: the function raises at the top when
        # PERSIST_KERNEL is false.
        s[Wh2hL].compute_at(s[s_h2h], s[s_h2h].op.axis[3])

    if UNROLL_WLOAD:
        s[Wh2hL].unroll(Wh2hL.op.axis[0])
        s[Wh2hL].unroll(Wh2hL.op.axis[2])

    s[s_state_h_S].compute_at(s[s_h2h_rf], s[s_h2h_rf].op.axis[3])
    s[s_state_c_S].compute_at(s[scan_h.op], s[scan_h].op.scan_axis)

    # Cooperative load of the shared h state across (threadIdx.y, threadIdx.x).
    for ss in [s_state_h_S]:
        xo, xi = s[ss].split(ss.op.axis[2], factor=num_thread_x * num_thread_y)
        ty, xi = s[ss].split(xi, nparts=num_thread_y)
        tx, xi = s[ss].split(xi, nparts=num_thread_x)
        s[ss].bind(ty, thread_y)
        s[ss].bind(tx, thread_x)

    # Spread the hidden axis of the init stages over blocks and threadIdx.x.
    for init in [s_init_c, s_init_h]:
        bx, xi = s[init].split(init.op.axis[2], nparts=num_sm)
        tx, xi = s[init].split(xi, nparts=num_thread_x)
        s[init].bind(bx, block_x)
        s[init].bind(tx, thread_x)

    # Only threads with threadIdx.y == 0 store these results.
    s[next_c].set_store_predicate(thread_y.equal(0))
    s[next_h].set_store_predicate(thread_y.equal(0))

    for update in [update_c, update_h]:
        bx, xi = s[update].split(s[update].op.axis[2], nparts=num_sm)
        tx, xi = s[update].split(xi, nparts=num_thread_x)
        s[update].bind(bx, block_x)
        s[update].bind(tx, thread_x)
        s[update].set_store_predicate(thread_y.equal(0))

    def check_device(target):
        """Build for `target`, run once, then time one further run."""
        # Concrete length for the host arrays; the built kernel's num_step
        # is symbolic (tvm.var) and is presumably bound from argument shapes.
        num_step = n_num_step
        flstm = tvm.build(s, [Xi2h, Wh2h, scan_h, scan_c], target)
        ctx = tvm.gpu(0) if target == "cuda" else tvm.cl(0)
        # Launch the kernel.
        scan_h_np = np.zeros(
            (num_step, batch_size, num_hidden)).astype("float32")
        scan_c_np = np.zeros(
            (num_step, batch_size, num_hidden)).astype("float32")
        Xi2h_np = np.random.normal(
            size=(num_step, batch_size, 4, num_hidden)).astype("float32")
        Wh2h_np = np.random.normal(
            size=(4, num_hidden, num_hidden)).astype("float32")
        scan_h_a = tvm.nd.array(scan_h_np, ctx)
        scan_c_a = tvm.nd.array(scan_c_np, ctx)
        Xi2h_a = tvm.nd.array(Xi2h_np, ctx)
        Wh2h_a = tvm.nd.array(Wh2h_np, ctx)
        flstm(Xi2h_a, Wh2h_a, scan_h_a, scan_c_a)
        ctx.sync()
        # Measure the time cost of the second run. perf_counter is monotonic
        # and high-resolution, unlike time.time.
        tstart = time.perf_counter()
        flstm(Xi2h_a, Wh2h_a, scan_h_a, scan_c_a)
        ctx.sync()
        tgap = time.perf_counter() - tstart
        print("Time cost=%g" % tgap)

    # Set unroll_explicit (here False) for more readable code.
    with tvm.build_config(
            detect_global_barrier=DETECT_GLOBAL_BARRIER,
            auto_unroll_max_step=128,
            unroll_explicit=False):
        check_device("cuda")
import tvm
import numpy as np

# Cumulative sum along the first axis expressed with tvm.scan:
#   s_scan[0, :] = X[0, :]
#   s_scan[t, :] = s_scan[t - 1, :] + X[t, :]
# where s_state is the placeholder standing for the scan's earlier output.
m = tvm.var('m')
n = tvm.var('n')
X = tvm.placeholder((m, n), name='X')
s_state = tvm.placeholder((m, n))
s_init = tvm.compute((1, n), lambda _, i: X[0, i])
s_update = tvm.compute((m, n), lambda t, i: s_state[t - 1, i] + X[t, i])
s_scan = tvm.scan(s_init, s_update, s_state, inputs=[X])

# Schedule the Scan Cell
s = tvm.create_schedule(s_scan.op)
num_thread = 256
block_x = tvm.thread_axis('blockIdx.x')
thread_x = tvm.thread_axis('threadIdx.x')
# For both the init and update stages, split the column axis by num_thread
# and bind the outer part to blockIdx.x and the inner part to threadIdx.x.
xo, xi = s[s_init].split(s_init.op.axis[1], factor=num_thread)
s[s_init].bind(xo, block_x)
s[s_init].bind(xi, thread_x)
xo, xi = s[s_update].split(s_update.op.axis[1], factor=num_thread)
s[s_update].bind(xo, block_x)
s[s_update].bind(xi, thread_x)
print(tvm.lower(s, [X, s_scan], simple_mode=True))

# Build and Verify
f_scan = tvm.build(s, [X, s_scan], 'cuda', name='my_scan')
ctx = tvm.gpu(0)
# Concrete sizes for the run; these rebind the symbolic m / n defined above.
n = 1024
m = 10
a_np = np.random.uniform(size=(m, n)).astype(s_scan.dtype)
a = tvm.nd.array(a_np, ctx=ctx)