예제 #1
0
def test_scan():
    """Build a small scan graph and define graph-analysis checks over it.

    The scan update is ``up1 + x_trans`` where ``up1`` reads the previous
    state and ``x_trans`` is computed from the constant input ``x``.
    """
    m = tvm.var("m")
    n = tvm.var("n")
    x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
    s_state = tvm.placeholder((m, n))
    s_init = tvm.compute((1, n), lambda _, i: x[0, i], name="s_init")
    x_trans = tvm.compute((m, n), lambda i, j: x[i, j] + 1, name="x_trans")
    s_up1 = tvm.compute((m, n), lambda t, i: s_state[t - 1, i] + 1, name="up1")
    s_update = tvm.compute((m, n), lambda t, i: s_up1[t, i] + x_trans[t, i], name="update")
    s_scan = tvm.scan(s_init, s_update, s_state)

    # NOTE(review): the nested checks below are defined but never invoked in
    # this function, so none of their assertions currently execute. Confirm
    # whether they are meant to be called before enabling them.

    def test_getbody():
        # Expected scan body: the scan op plus the stages reading the state.
        body = tvm.schedule.ScanGetBody(s_scan.op)
        assert set(body) == set([s_scan.op, s_update.op, s_up1.op])

    def test_attach_path():
        # Attach x_trans under the first axis of the update stage and check
        # the attach paths reported for both stages.
        s = tvm.create_schedule(s_scan.op)
        s[x_trans].compute_at(s[s_update], s_update.op.axis[0])
        apath = tvm.schedule.CreateAttachPath(s)
        assert(tuple(apath[s_update.op]) == tuple([s_scan.op.scan_axis]))
        assert(tuple(apath[x_trans.op]) == tuple([s_update.op.axis[0], s_scan.op.scan_axis]))

    def test_fix_pt():
        body = tvm.schedule.ScanGetBody(s_scan.op)
        fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op, body)
        # The spatial axes live on the scan operation, not on the output
        # tensor, so go through ``.op`` as the other scan tests do.
        assert(fxpt[s_scan.op.spatial_axis_[0]].value != 0)
예제 #2
0
def test_lstm_cell_inline():
    """Check that an LSTM scan cell with inlined gate stages can be lowered.

    Builds the input-to-hidden and hidden-to-hidden reductions, the four gate
    stages and the cell/hidden updates, wraps them in a two-output scan,
    inlines the gate computations and calls ``tvm.lower`` on the result.
    """
    num_step = 128
    num_input = 256
    num_hidden = 1152
    batch_size = 4
    # Global transition matrix
    X = tvm.placeholder((num_step - 1, batch_size, num_input), name="X")
    Wi2h = tvm.placeholder((4, num_hidden, num_input), name="Wi2h")
    Wh2h = tvm.placeholder((4, num_hidden, num_hidden), name="Wh2h")
    # h: output hidden state, c: cell state.
    s_state_h = tvm.placeholder((num_step, batch_size, num_hidden))
    s_state_c = tvm.placeholder((num_step, batch_size, num_hidden))
    s_init_c = tvm.compute((1, batch_size, num_hidden),
                           lambda *i: 0.0, name="init_c")
    s_init_h = tvm.compute((1, batch_size, num_hidden),
                           lambda *i: 0.0, name="init_h")
    # LSTM transition
    # Each reduction gets its own axis variable and its own label; the second
    # axis was previously also labelled "ki2h" and rebound the same local.
    k_i2h = tvm.reduce_axis((0, num_input), name="ki2h")
    s_i2h = tvm.compute(
        (num_step, 4, batch_size, num_hidden),
        lambda t, x, i, j: tvm.sum(X[t - 1, i, k_i2h] * Wi2h[x, j, k_i2h], axis=k_i2h),
        name="s_i2h")
    k_h2h = tvm.reduce_axis((0, num_hidden), name="kh2h")
    s_h2h = tvm.compute(
        (num_step, 4, batch_size, num_hidden),
        lambda t, x, i, j: tvm.sum(s_state_h[t - 1, i, k_h2h] * Wh2h[x, j, k_h2h], axis=k_h2h),
        name="s_h2h")
    # Gate rules
    gates = tvm.compute(s_i2h.shape, lambda *i:
                        s_i2h(*i) + s_h2h(*i), name="gates")
    gshape = (num_step, batch_size, num_hidden)
    in_gate = tvm.compute(gshape, lambda t, i, j: tvm.sigmoid(gates[t, 0, i, j]), name="in_gate")
    in_transform = tvm.compute(gshape, lambda t, i, j: tvm.tanh(gates[t, 1, i, j]), name="in_transform")
    forget_gate = tvm.compute(gshape, lambda t, i, j: tvm.sigmoid(gates[t, 2, i, j]), name="forget_gate")
    out_gate = tvm.compute(gshape, lambda t, i, j: tvm.sigmoid(gates[t, 3, i, j]), name="out_gate")
    next_c = tvm.compute(gshape,
                         lambda t, i, j:
                         forget_gate[t, i, j] * s_state_c[t - 1, i, j] +
                         in_gate[t, i, j] * in_transform[t, i, j], name="next_c")
    next_h = tvm.compute(gshape,
                         lambda t, i, j: out_gate[t, i, j] * tvm.tanh(next_c[t, i, j]), name="next_h")
    update_c = tvm.compute(gshape, lambda *i: next_c(*i), name="update_c")
    update_h = tvm.compute(gshape, lambda *i: next_h(*i), name="update_h")
    # scan over the time axis, carrying both h and c
    scan_h, scan_c = tvm.scan(
        [s_init_h, s_init_c],
        [update_h, update_c],
        [s_state_h, s_state_c],
        inputs=[X],
        name="lstm_scan")
    # schedule
    s = tvm.create_schedule(scan_h.op)
    # Inline gate computations
    s[gates].compute_inline()
    s[in_gate].compute_inline()
    s[in_transform].compute_inline()
    s[forget_gate].compute_inline()
    s[out_gate].compute_inline()
    # verify we can lower correctly
    tvm.lower(s, [X, Wi2h, Wh2h, scan_h, scan_c])
def test_scan():
    """Build a small scan graph and define graph-analysis checks over it.

    The scan update is ``up1 + x_trans`` where ``up1`` reads the previous
    state and ``x_trans`` is computed from the constant input ``x``.
    """
    m = tvm.var("m")
    n = tvm.var("n")
    x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
    s_state = tvm.placeholder((m, n))
    s_init = tvm.compute((1, n), lambda _, i: x[0, i], name="s_init")
    x_trans = tvm.compute((m, n), lambda i, j: x[i, j] + 1, name="x_trans")
    s_up1 = tvm.compute((m, n), lambda t, i: s_state[t - 1, i] + 1, name="up1")
    s_update = tvm.compute((m, n),
                           lambda t, i: s_up1[t, i] + x_trans[t, i],
                           name="update")
    s_scan = tvm.scan(s_init, s_update, s_state)

    # NOTE(review): the nested checks below are defined but never invoked in
    # this function, so none of their assertions currently execute. Confirm
    # whether they are meant to be called before enabling them.

    def test_getbody():
        # Expected scan body: the scan op plus the stages reading the state.
        body = tvm.schedule.ScanGetBody(s_scan.op)
        assert set(body) == set([s_scan.op, s_update.op, s_up1.op])

    def test_attach_path():
        # Attach x_trans under the first axis of the update stage and check
        # the attach paths reported for both stages.
        s = tvm.create_schedule(s_scan.op)
        s[x_trans].compute_at(s[s_update], s_update.op.axis[0])
        apath = tvm.schedule.CreateAttachPath(s)
        assert (tuple(apath[s_update.op]) == tuple([s_scan.op.scan_axis]))
        assert (tuple(apath[x_trans.op]) == tuple(
            [s_update.op.axis[0], s_scan.op.scan_axis]))

    def test_fix_pt():
        body = tvm.schedule.ScanGetBody(s_scan.op)
        fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op, body)
        # The spatial axes live on the scan operation, not on the output
        # tensor, so go through ``.op`` as the other scan tests do.
        assert (fxpt[s_scan.op.spatial_axis_[0]].value != 0)
예제 #4
0
def test_scan_group():
    """Exercise stage grouping inside a scan: membership, nesting, attachment."""
    m = tvm.var("m")
    n = tvm.var("n")
    x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
    s_state = tvm.placeholder((m, n))
    s_init = tvm.compute((1, n), lambda _, i: x[0, i])

    # Chain of three update stages; only the last one is handed to the scan.
    s_update1 = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + x[t, i])
    s_update2 = tvm.compute((m, n), lambda t, i: s_update1[t, i] + 1)
    s_update3 = tvm.compute((m, n), lambda t, i: s_update2[t, i] + 1)
    scan_out = tvm.scan(s_init, s_update3, s_state, inputs=x)

    sched = tvm.create_schedule(scan_out.op)
    assert sched[s_update1].group is not None
    assert sched[s_update2].group == sched[s_update1].group
    # Attaching one member to another member of the same group is valid.
    sched[s_update1].compute_at(sched[s_update2], s_update2.op.axis[1])

    # Create a nested group covering s_update2 and s_update1.
    inner_group = sched.create_group(outputs=s_update2, inputs=[s_state, x])
    assert inner_group.group is not None
    assert inner_group.group == sched[s_update3].group
    for member in (s_update2, s_update1):
        assert sched[member].group == inner_group
    inner_group.compute_at(sched[s_update3], s_update3.op.axis[1])
    assert inner_group.attach_stage == sched[s_update3]

    # Attaching a group member to a stage outside its group is expected to
    # raise; reaching the else branch means no error was raised.
    try:
        sched[s_update2].compute_at(sched[s_init], s_init.op.axis[0])
    except tvm.TVMError:
        pass
    else:
        assert False
def test_scan_group():
    """Exercise stage grouping inside a scan: membership, nesting, attachment."""
    m = tvm.var("m")
    n = tvm.var("n")
    x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
    s_state = tvm.placeholder((m, n))
    s_init = tvm.compute((1, n), lambda _, i: x[0, i])

    # Chain of three update stages; only the last one is handed to the scan.
    s_update1 = tvm.compute((m, n), lambda t, i: s_state[t - 1, i] + x[t, i])
    s_update2 = tvm.compute((m, n), lambda t, i: s_update1[t, i] + 1)
    s_update3 = tvm.compute((m, n), lambda t, i: s_update2[t, i] + 1)
    scan_out = tvm.scan(s_init, s_update3, s_state, inputs=x)

    sched = tvm.create_schedule(scan_out.op)
    assert sched[s_update1].group is not None
    assert sched[s_update2].group == sched[s_update1].group
    # Attaching one member to another member of the same group is valid.
    sched[s_update1].compute_at(sched[s_update2], s_update2.op.axis[1])

    # Create a nested group covering s_update2 and s_update1.
    inner_group = sched.create_group(outputs=s_update2, inputs=[s_state, x])
    assert inner_group.group is not None
    assert inner_group.group == sched[s_update3].group
    for member in (s_update2, s_update1):
        assert sched[member].group == inner_group
    inner_group.compute_at(sched[s_update3], s_update3.op.axis[1])
    assert inner_group.attach_stage == sched[s_update3]

    # Attaching a group member to a stage outside its group is expected to
    # raise; reaching the else branch means no error was raised.
    try:
        sched[s_update2].compute_at(sched[s_init], s_init.op.axis[0])
    except tvm.TVMError:
        pass
    else:
        assert False
예제 #6
0
 def test_scan1():
     """Check fix-point analysis when the update reads the state with swapped indices.

     NOTE(review): ``l``, ``m``, ``n``, ``x``, ``s_state`` and ``s_init`` are
     not defined in this block; presumably they come from an enclosing scope
     outside this view -- confirm.
     """
     s_update = tvm.compute(
         (l, m, n),
         lambda t, i, j: x[t, j, i]  + s_state[t-1, j, i],
         name="update")
     scan_op = tvm.scan(s_init, s_update, s_state).op
     fixpoint = tvm.schedule.ScanFixPointAnalysis(
         scan_op, tvm.schedule.ScanGetBody(scan_op))
     # Both spatial axes are expected to be reported with value 0.
     for axis_idx in (0, 1):
         assert fixpoint[scan_op.spatial_axis_[axis_idx]].value == 0
예제 #7
0
def test_tensor_scan():
    """A scan over an (m, n) state yields an output tensor of shape (m, n)."""
    m = tvm.size_var("m")
    n = tvm.size_var("n")
    x = tvm.placeholder((m, n))
    s = tvm.placeholder((m, n))
    # Initial row copies x[0]; each later row adds x[t] to the previous state.
    init = tvm.compute((1, n), lambda _, i: x[0, i])
    update = tvm.compute((m, n), lambda t, i: s[t - 1, i] + x[t, i])
    res = tvm.scan(init, update, s)
    assert tuple(res.shape) == (m, n)
예제 #8
0
def test_tensor_scan():
    """A scan over an (m, n) state yields an output tensor of shape (m, n)."""
    m = tvm.var("m")
    n = tvm.var("n")
    x = tvm.placeholder((m, n))
    s = tvm.placeholder((m, n))
    # Initial row copies x[0]; each later row adds x[t] to the previous state.
    init = tvm.compute((1, n), lambda _, i: x[0, i])
    update = tvm.compute((m, n), lambda t, i: s[t-1, i] + x[t, i])
    res = tvm.scan(init, update, s)
    assert tuple(res.shape) == (m, n)
 def cumsum(X):
     """Return a scan tensor for the running sum ``Y[i] = sum(X[:i])``.

     The state has one more element than ``X``: the init stage supplies an
     int32 zero, and every update step adds ``X[l - 1]`` to the previous
     state value.
     """
     (length, ) = X.shape
     state = tvm.placeholder((length + 1, ), dtype="int32", name="state")
     zero_init = tvm.compute((1, ), lambda _: tvm.const(0, "int32"))
     step = tvm.compute((length + 1, ), lambda l: state[l - 1] + X[l - 1])
     return tvm.scan(zero_init, step, state, inputs=[X], name="cumsum")
예제 #10
0
 def test_scan4_reach_other():
     """Check fix-point analysis when one stage reads the state at ``[j, j]``.

     NOTE(review): ``l``, ``m``, ``n``, ``s_state`` and ``s_init`` are not
     defined in this block; presumably they come from an enclosing scope
     outside this view -- confirm.
     """
     s_h1 = tvm.compute((l, n, m), lambda t, j, i: s_state[t-1, j, j], name="h1")
     # Labelled "h2": this stage previously reused the "h1" label of s_h1.
     s_h2 = tvm.compute((l, m, n), lambda t, i, j: s_state[t-1, i, j] * 2, name="h2")
     s_update = tvm.compute((l, m, n),
                            lambda t, i, j: s_h1[t, j, i] + s_h2[t, i, j], name="update")
     s_scan = tvm.scan(s_init, s_update, s_state)
     fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op)
     # Both spatial axes are expected to be reported with value 0.
     assert(fxpt[s_scan.op.spatial_axis_[0]].value == 0)
     assert(fxpt[s_scan.op.spatial_axis_[1]].value == 0)
예제 #11
0
 def test_scan3_not_exact_reach():
     """Check fix-point analysis when one stage reads the state at a constant index.

     NOTE(review): ``l``, ``m``, ``n``, ``s_state`` and ``s_init`` are not
     defined in this block; presumably they come from an enclosing scope
     outside this view -- confirm.
     """
     s_h1 = tvm.compute((l, n, m), lambda t, j, i: s_state[t-1, i, j], name="h1")
     # Labelled "h2": this stage previously reused the "h1" label of s_h1.
     s_h2 = tvm.compute((l, m, n), lambda t, i, j: s_state[t-1, i, 10] * 2, name="h2")
     s_update = tvm.compute((l, m, n), lambda t, i, j: s_h1[t, j, i] + s_h2[t, i, j], name="update")
     s_scan = tvm.scan(s_init, s_update, s_state)
     # The body is not passed to the analysis below; the call is kept only so
     # that body extraction is still exercised on this graph.
     tvm.schedule.ScanGetBody(s_scan.op)
     fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op)
     # Expected: value 1 for the first spatial axis, 0 for the second.
     assert(fxpt[s_scan.op.spatial_axis_[0]].value == 1)
     assert(fxpt[s_scan.op.spatial_axis_[1]].value == 0)
 def test_scan1():
     """Check fix-point analysis when the update reads the state with swapped indices.

     NOTE(review): ``l``, ``m``, ``n``, ``x``, ``s_state`` and ``s_init`` are
     not defined in this block; presumably they come from an enclosing scope
     outside this view -- confirm.
     """
     s_update = tvm.compute((l, m, n),
                            lambda t, i, j: x[t, j, i] + s_state[t - 1, j, i],
                            name="update")
     scan_op = tvm.scan(s_init, s_update, s_state).op
     fixpoint = tvm.schedule.ScanFixPointAnalysis(
         scan_op, tvm.schedule.ScanGetBody(scan_op))
     # Both spatial axes are expected to be reported with value 0.
     for axis_idx in (0, 1):
         assert fixpoint[scan_op.spatial_axis_[axis_idx]].value == 0
 def test_scan4_reach_other():
     """Check fix-point analysis when one stage reads the state at ``[j, j]``.

     NOTE(review): ``l``, ``m``, ``n``, ``s_state`` and ``s_init`` are not
     defined in this block; presumably they come from an enclosing scope
     outside this view -- confirm.
     """
     s_h1 = tvm.compute((l, n, m),
                        lambda t, j, i: s_state[t - 1, j, j],
                        name="h1")
     # Labelled "h2": this stage previously reused the "h1" label of s_h1.
     s_h2 = tvm.compute((l, m, n),
                        lambda t, i, j: s_state[t - 1, i, j] * 2,
                        name="h2")
     s_update = tvm.compute((l, m, n),
                            lambda t, i, j: s_h1[t, j, i] + s_h2[t, i, j],
                            name="update")
     s_scan = tvm.scan(s_init, s_update, s_state)
     fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op)
     # Both spatial axes are expected to be reported with value 0.
     assert (fxpt[s_scan.op.spatial_axis_[0]].value == 0)
     assert (fxpt[s_scan.op.spatial_axis_[1]].value == 0)
 def test_scan5_multi_output():
     """Check fix-point analysis on a scan carrying two independent states."""
     m = tvm.var("m")
     n = tvm.var("n")
     x1 = tvm.placeholder((m, n))
     s1 = tvm.placeholder((m, n))
     x2 = tvm.placeholder((m, n))
     s2 = tvm.placeholder((m, n))
     s1_init = tvm.compute((1, n), lambda _, i: x1[0, i])
     s2_init = tvm.compute((1, n), lambda _, i: x2[0, i])
     s1_update = tvm.compute((m, n), lambda t, i: s1[t - 1, i] + x1[t, i])
     s2_update = tvm.compute((m, n), lambda t, i: x2[t, i] + s2[t - 1, i])
     inits = [s1_init, s2_init]
     updates = [s1_update, s2_update]
     states = [s1, s2]
     r0, r1 = tvm.scan(inits, updates, states)
     # The body itself is not inspected; the call is kept as in the original.
     _body = tvm.schedule.ScanGetBody(r0.op)
     fixpoint = tvm.schedule.ScanFixPointAnalysis(r0.op)
     first_axis = r1.op.spatial_axis_[0]
     assert fixpoint[first_axis].value == 1
예제 #15
0
def test_schedule_scan():
    """Infer bounds for a plain scan and check the scan axis starts at 1."""
    m = tvm.var("m")
    n = tvm.var("n")
    x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
    s_state = tvm.placeholder((m, n))
    s_init = tvm.compute((1, n), lambda _, i: x[0, i])
    s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + x[t, i])
    res = tvm.scan(s_init, s_update, s_state)
    assert tuple(res.shape) == (m, n)

    sched = tvm.create_schedule(res.op).normalize()
    bounds = tvm.schedule.InferBound(sched)
    scan_axis_min = bounds[res.op.scan_axis].min
    assert scan_axis_min.value == 1
    # The generated statement is not inspected; the call only has to succeed.
    tvm.schedule.ScheduleOps(sched, bounds)
예제 #16
0
def test_schedule_scan():
    """Infer bounds for a plain scan and check the scan axis starts at 1."""
    m = tvm.var("m")
    n = tvm.var("n")
    x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
    s_state = tvm.placeholder((m, n))
    s_init = tvm.compute((1, n), lambda _, i: x[0, i])
    s_update = tvm.compute((m, n), lambda t, i: s_state[t - 1, i] + x[t, i])
    res = tvm.scan(s_init, s_update, s_state)
    assert tuple(res.shape) == (m, n)

    sched = tvm.create_schedule(res.op).normalize()
    bounds = tvm.schedule.InferBound(sched)
    scan_axis_min = bounds[res.op.scan_axis].min
    assert scan_axis_min.value == 1
    # The generated statement is not inspected; the call only has to succeed.
    tvm.schedule.ScheduleOps(sched, bounds)
 def test_scan3_not_exact_reach():
     """Check fix-point analysis when one stage reads the state at a constant index.

     NOTE(review): ``l``, ``m``, ``n``, ``s_state`` and ``s_init`` are not
     defined in this block; presumably they come from an enclosing scope
     outside this view -- confirm.
     """
     s_h1 = tvm.compute((l, n, m),
                        lambda t, j, i: s_state[t - 1, i, j],
                        name="h1")
     # Labelled "h2": this stage previously reused the "h1" label of s_h1.
     s_h2 = tvm.compute((l, m, n),
                        lambda t, i, j: s_state[t - 1, i, 10] * 2,
                        name="h2")
     s_update = tvm.compute((l, m, n),
                            lambda t, i, j: s_h1[t, j, i] + s_h2[t, i, j],
                            name="update")
     s_scan = tvm.scan(s_init, s_update, s_state)
     # The body is not passed to the analysis below; the call is kept only so
     # that body extraction is still exercised on this graph.
     tvm.schedule.ScanGetBody(s_scan.op)
     fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op)
     # Expected: value 1 for the first spatial axis, 0 for the second.
     assert (fxpt[s_scan.op.spatial_axis_[0]].value == 1)
     assert (fxpt[s_scan.op.spatial_axis_[1]].value == 0)
예제 #18
0
def test_scan():
    """Build a scan followed by a compute stage and compare with ``np.cumsum``.

    The init, update and final copy stages are each split along their second
    axis and bound to block/thread indices; every enabled device then builds
    and runs the kernel on random input.
    """
    m = tvm.size_var("m")
    n = tvm.size_var("n")
    X = tvm.placeholder((m, n), name="X")
    s_state = tvm.placeholder((m, n))
    s_init = tvm.compute((1, n), lambda _, i: X[0, i])
    s_update = tvm.compute((m, n), lambda t, i: s_state[t - 1, i] + X[t, i])
    scan = tvm.scan(s_init, s_update, s_state)
    # Scan output consumed by a further compute stage.
    res = tvm.compute((m, n), lambda i, j: scan[i, j])

    s = tvm.create_schedule(res.op)
    num_thread = 256
    block_x = tvm.thread_axis(None, "blockIdx.x")
    thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
    # Same split-and-bind recipe for every stage, in the original order.
    for tensor in (s_init, s_update, res):
        outer, inner = s[tensor].split(tensor.op.axis[1], factor=num_thread)
        s[tensor].bind(outer, block_x)
        s[tensor].bind(inner, thread_x)

    def check_device(device):
        """Build and run the scan on ``device``; return early if unavailable."""
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("skip because %s is not enabled.." % device)
            return
        fscan = tvm.build(s, [X, res], device, name="myscan")
        # Concrete problem size used for the launch.
        rows, cols = 10, 1024
        a_np = np.random.uniform(size=(rows, cols)).astype(res.dtype)
        a = tvm.nd.array(a_np, ctx)
        b = tvm.nd.array(np.zeros((rows, cols), dtype=res.dtype), ctx)
        fscan(a, b)
        expected = np.cumsum(a_np, axis=0)
        tvm.testing.assert_allclose(b.asnumpy(), expected)

    for device in ("vulkan", "cuda", "metal", "opencl"):
        check_device(device)
예제 #19
0
 def test_scan5_multi_output():
     """Check fix-point analysis on a scan carrying two independent states."""
     m = tvm.var("m")
     n = tvm.var("n")
     x1 = tvm.placeholder((m, n))
     s1 = tvm.placeholder((m, n))
     x2 = tvm.placeholder((m, n))
     s2 = tvm.placeholder((m, n))
     s1_init = tvm.compute((1, n), lambda _, i: x1[0, i])
     s2_init = tvm.compute((1, n), lambda _, i: x2[0, i])
     s1_update = tvm.compute((m, n), lambda t, i: s1[t-1, i] +  x1[t, i])
     s2_update = tvm.compute((m, n), lambda t, i: x2[t, i] + s2[t-1,i])
     inits = [s1_init, s2_init]
     updates = [s1_update, s2_update]
     states = [s1, s2]
     r0, r1 = tvm.scan(inits, updates, states)
     # The body itself is not inspected; the call is kept as in the original.
     _body = tvm.schedule.ScanGetBody(r0.op)
     fixpoint = tvm.schedule.ScanFixPointAnalysis(r0.op)
     first_axis = r1.op.spatial_axis_[0]
     assert fixpoint[first_axis].value == 1
def test_scan_inline1():
    """Lower a two-state scan after inlining one of its intermediate stages."""
    m = tvm.var("m")
    n = tvm.var("n")
    x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
    s_state1 = tvm.placeholder((m, n))
    s_state2 = tvm.placeholder((m, n))
    s_init1 = tvm.compute((1, n), lambda _, i: x[0, i])
    s_init2 = tvm.compute((1, n), lambda _, i: x[0, i])
    s_x1 = tvm.compute(
        (m, n), lambda t, i: s_state1[t-1, i] + x[t, i], name="x1")
    s_x2 = tvm.compute(
        (m, n), lambda t, i: s_state2[t-1, i] + 1 , name="x2")
    # The update stages simply forward the intermediate stages.
    s_update1 = tvm.compute((m, n), lambda t, i: s_x1[t, i], name="u1")
    s_update2 = tvm.compute((m, n), lambda t, i: s_x2[t, i], name="u2")
    inits = [s_init1, s_init2]
    updates = [s_update1, s_update2]
    states = [s_state1, s_state2]
    res1, res2 = tvm.scan(inits, updates, states)
    sched = tvm.create_schedule(res1.op)
    # Only the first intermediate stage is inlined.
    sched[s_x1].compute_inline()
    tvm.lower(sched, [x, res1, res2])
예제 #21
0
def test_bound_scan():
    """Check the inferred extent of a cache stage attached inside a scan update."""
    m = tvm.var("m")
    n = tvm.var("n")
    X = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
    s_state = tvm.placeholder((m, n))
    s_init = tvm.compute((1, n), lambda _, i: X[0, i])
    s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i])
    s_scan = tvm.scan(s_init, s_update, s_state)
    assert tuple(s_scan.shape) == (m, n)

    sched = tvm.create_schedule(s_scan.op)
    XX = sched.cache_read(X, "local", s_update)
    # Split the update's second axis by 4 and attach the cache at the outer part.
    outer, _inner = sched[s_update].split(s_update.op.axis[1], factor=4)
    sched[XX].compute_at(sched[s_update], outer)
    sched = sched.normalize()
    bounds = tvm.schedule.InferBound(sched)
    # The generated statement is not inspected; the call only has to succeed.
    tvm.schedule.ScheduleOps(sched, bounds)
    cache_extent = bounds[XX.op.axis[1]].extent
    assert cache_extent.value == 4
def test_bound_scan():
    """Check the inferred extent of a cache stage attached inside a scan update."""
    m = tvm.var("m")
    n = tvm.var("n")
    X = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
    s_state = tvm.placeholder((m, n))
    s_init = tvm.compute((1, n), lambda _, i: X[0, i])
    s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i])
    s_scan = tvm.scan(s_init, s_update, s_state)
    assert tuple(s_scan.shape) == (m, n)

    sched = tvm.create_schedule(s_scan.op)
    XX = sched.cache_read(X, "local", s_update)
    # Split the update's second axis by 4 and attach the cache at the outer part.
    outer, _inner = sched[s_update].split(s_update.op.axis[1], factor=4)
    sched[XX].compute_at(sched[s_update], outer)
    sched = sched.normalize()
    bounds = tvm.schedule.InferBound(sched)
    # The generated statement is not inspected; the call only has to succeed.
    tvm.schedule.ScheduleOps(sched, bounds)
    cache_extent = bounds[XX.op.axis[1]].extent
    assert cache_extent.value == 4
예제 #23
0
def test_scan_inline1():
    """Lower a two-state scan after inlining one of its intermediate stages."""
    m = tvm.var("m")
    n = tvm.var("n")
    x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
    s_state1 = tvm.placeholder((m, n))
    s_state2 = tvm.placeholder((m, n))
    s_init1 = tvm.compute((1, n), lambda _, i: x[0, i])
    s_init2 = tvm.compute((1, n), lambda _, i: x[0, i])
    s_x1 = tvm.compute(
        (m, n), lambda t, i: s_state1[t-1, i] + x[t, i], name="x1")
    s_x2 = tvm.compute(
        (m, n), lambda t, i: s_state2[t-1, i] + 1 , name="x2")
    # The update stages simply forward the intermediate stages.
    s_update1 = tvm.compute((m, n), lambda t, i: s_x1[t, i], name="u1")
    s_update2 = tvm.compute((m, n), lambda t, i: s_x2[t, i], name="u2")
    inits = [s_init1, s_init2]
    updates = [s_update1, s_update2]
    states = [s_state1, s_state2]
    res1, res2 = tvm.scan(inits, updates, states)
    sched = tvm.create_schedule(res1.op)
    # Only the first intermediate stage is inlined.
    sched[s_x1].compute_inline()
    tvm.lower(sched, [x, res1, res2])
예제 #24
0
파일: test_scan.py 프로젝트: gwli/tvm
def test_scan():
    """Build a scan, bind it to GPU threads and compare with ``np.cumsum``.

    The init and update stages are each split along their second axis and
    bound to block/thread indices; every enabled device then builds and runs
    the kernel on random input.
    """
    m = tvm.var("m")
    n = tvm.var("n")
    X = tvm.placeholder((m, n), name="X")
    s_state = tvm.placeholder((m, n))
    s_init = tvm.compute((1, n), lambda _, i: X[0, i])
    s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i])
    res = tvm.scan(s_init, s_update, s_state)

    s = tvm.create_schedule(res.op)
    num_thread = 256
    block_x = tvm.thread_axis(None, "blockIdx.x")
    thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
    # Same split-and-bind recipe for both stages, in the original order.
    for tensor in (s_init, s_update):
        outer, inner = s[tensor].split(tensor.op.axis[1], factor=num_thread)
        s[tensor].bind(outer, block_x)
        s[tensor].bind(inner, thread_x)

    def check_device(device):
        """Build and run the scan on ``device``; return early if unavailable."""
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("skip because %s is not enabled.." % device)
            return
        fscan = tvm.build(s, [X, res], device, name="myscan")
        # Concrete problem size used for the launch.
        rows, cols = 10, 1024
        a_np = np.random.uniform(size=(rows, cols)).astype(res.dtype)
        a = tvm.nd.array(a_np, ctx)
        b = tvm.nd.array(np.zeros((rows, cols), dtype=res.dtype), ctx)
        fscan(a, b)
        expected = np.cumsum(a_np, axis=0)
        np.testing.assert_allclose(b.asnumpy(), expected)

    for device in ("vulkan", "cuda", "metal", "opencl"):
        check_device(device)
예제 #25
0
def test_scan_multi_out():
    """Check output indices of a two-output scan and JSON round-trip of its op."""
    m = tvm.size_var("m")
    n = tvm.size_var("n")
    x1 = tvm.placeholder((m, n))
    s1 = tvm.placeholder((m, n))
    x2 = tvm.placeholder((m, n))
    s2 = tvm.placeholder((m, n))
    s1_init = tvm.compute((1, n), lambda _, i: x1[0, i])
    s2_init = tvm.compute((1, n), lambda _, i: x2[0, i])
    # The first update reads both states; the second reads only its own.
    s1_update = tvm.compute(
        (m, n), lambda t, i: s1[t - 1, i] + s2[t - 1, i] + x1[t, i])
    s2_update = tvm.compute((m, n), lambda t, i: x2[t, i] + s2[t - 1, i])

    inits = [s1_init, s2_init]
    updates = [s1_update, s2_update]
    states = [s1, s2]
    r0, r1 = tvm.scan(inits, updates, states)
    for expected_index, out in enumerate((r0, r1)):
        assert out.value_index == expected_index
    # Serialising and reloading the op must give back a ScanOp.
    restored = tvm.load_json(tvm.save_json(r0.op))
    assert isinstance(restored, tvm.tensor.ScanOp)
예제 #26
0
def test_scan_multi_out():
    """Check output indices of a two-output scan and JSON round-trip of its op."""
    m = tvm.var("m")
    n = tvm.var("n")
    x1 = tvm.placeholder((m, n))
    s1 = tvm.placeholder((m, n))
    x2 = tvm.placeholder((m, n))
    s2 = tvm.placeholder((m, n))
    s1_init = tvm.compute((1, n), lambda _, i: x1[0, i])
    s2_init = tvm.compute((1, n), lambda _, i: x2[0, i])
    # The first update reads both states; the second reads only its own.
    s1_update = tvm.compute(
        (m, n), lambda t, i: s1[t-1, i] + s2[t-1, i] + x1[t, i])
    s2_update = tvm.compute((m, n), lambda t, i: x2[t, i] + s2[t-1,i])

    inits = [s1_init, s2_init]
    updates = [s1_update, s2_update]
    states = [s1, s2]
    r0, r1 = tvm.scan(inits, updates, states)
    for expected_index, out in enumerate((r0, r1)):
        assert out.value_index == expected_index
    # Serialising and reloading the op must give back a ScanOp.
    restored = tvm.load_json(tvm.save_json(r0.op))
    assert isinstance(restored, tvm.tensor.ScanOp)
예제 #27
0
파일: matexp.py 프로젝트: bddppq/tvm
def rnn_matexp():
    """Build, schedule, time and check a scan that repeatedly multiplies by ``Whh``.

    The scan state starts at all ones and each step computes
    ``state[t] = state[t - 1] x Whh`` via a reduction over ``kh``. The
    schedule caches the state and the weights, factors the reduction across
    ``threadIdx.y`` and binds the output columns to ``blockIdx.x`` /
    ``threadIdx.x``. The kernel is built for "cuda", run twice (the second run
    is timed) and, unless ``SKIP_CHECK`` is set, compared against a NumPy
    reference.

    Reads the names ``DETECT_GLOBAL_BARRIER``, ``PERSIST_KERNEL`` and
    ``SKIP_CHECK``, which are not defined in this function.
    """
    # Concrete sizes used for the host-side buffers and the reference result.
    n_num_step = 128
    n_num_hidden = 1152
    n_batch_size = 4
    detect_global_barrier = DETECT_GLOBAL_BARRIER

    # The step count stays symbolic; hidden and batch sizes are constants.
    num_step = tvm.var("num_step")
    num_hidden = tvm.convert(n_num_hidden)
    batch_size = tvm.convert(n_batch_size)
    num_thread_y = 8
    num_thread_x = 16 * 3
    num_sm = 24

    Whh = tvm.placeholder((num_hidden, num_hidden), name="Whh")
    s_init = tvm.compute((1, batch_size, num_hidden),
                         lambda _, i, j: 1.0, name="init")
    s_state = tvm.placeholder((num_step, batch_size, num_hidden))
    kh = tvm.reduce_axis((0, num_hidden), name="kh")
    s_update = tvm.compute(
        (num_step, batch_size, num_hidden),
        lambda t, i, j: tvm.sum(s_state[t-1, i, kh] * Whh[kh, j], axis=kh),
        name="update")
    s_scan = tvm.scan(s_init, s_update, s_state)
    # schedule
    s = tvm.create_schedule(s_scan.op)
    CL = s_update
    # Cache stages: state -> "shared" -> "local", and weights -> "local".
    SS = s.cache_read(s_state, "shared", [CL])
    SL = s.cache_read(SS, "local", [CL])
    WhhL = s.cache_read(Whh, "local", [CL])
    # Split the reduction into num_thread_y parts and factor out the outer part.
    ko, ki = s[CL].split(s[CL].op.reduce_axis[0], nparts=num_thread_y)
    CLF = s.rfactor(CL, ko)

    block_x = tvm.thread_axis((0, num_sm), "blockIdx.x")
    thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
    thread_y = tvm.thread_axis((0, num_thread_y), "threadIdx.y")
    if PERSIST_KERNEL:
        s[s_scan.op].env_threads([block_x, thread_y, thread_x])

    # Bind the hidden axis of the init stage to blocks and threads.
    bx, xi = s[s_init].split(s_init.op.axis[2], nparts=num_sm)
    tx, xi = s[s_init].split(xi, nparts=num_thread_x)
    s[s_init].bind(bx, block_x)
    s[s_init].bind(tx, thread_x)

    # Same binding for the update stage; its remaining reduction goes to thread_y.
    bx, xi = s[s_update].split(s[CL].op.axis[2], nparts=num_sm)
    tx, xi = s[s_update].split(xi, nparts=num_thread_x)
    s[s_update].bind(bx, block_x)
    s[s_update].bind(tx, thread_x)
    s[CL].bind(s[CL].op.reduce_axis[0], thread_y)
    s[CLF].compute_at(s[CL], s[CL].op.reduce_axis[0])
    # Duplicate store predicate.
    s[CL].set_store_predicate(thread_y.equal(0))

    # Placement of the weight cache depends on the PERSIST_KERNEL setting.
    if PERSIST_KERNEL:
        s[WhhL].compute_at(s[s_scan], thread_x)
        s[WhhL].unroll(WhhL.op.axis[0])
    else:
        s[WhhL].compute_at(s[CLF], CLF.op.axis[3])

    kr, ki = s[CLF].split(CLF.op.reduce_axis[0], nparts=1)
    ko, ki = s[CLF].split(ki, factor=4)
    s[SS].compute_at(s[CLF], kr)
    s[SL].compute_at(s[CLF], ko)

    # Spread the shared-cache load over thread_y and thread_x.
    xo, xi = s[SS].split(SS.op.axis[2], factor=num_thread_x * num_thread_y * 3)
    ty, xi = s[SS].split(xi, nparts=num_thread_y)
    tx, xi = s[SS].split(xi, nparts=num_thread_x)
    s[SS].bind(ty, thread_y)
    s[SS].bind(tx, thread_x)

    def check_device(target):
        # Build for `target`, run twice, print the time of the second run and
        # optionally compare the result with a NumPy reference.
        with tvm.build_config(
                detect_global_barrier=detect_global_barrier,
                auto_unroll_max_step=128,
                unroll_explicit=False):
            f = tvm.build(s, [s_scan, Whh], target)
        ctx = tvm.gpu(0) if target == "cuda" else tvm.cl(0)
        # launch the kernel.
        res_np = np.zeros(
            (n_num_step, n_batch_size, n_num_hidden)).astype("float32")
        # Weights: 2 / n_num_hidden in the first half of the columns, 0 in the rest.
        Whh_np = np.zeros((n_num_hidden, n_num_hidden)).astype("float32")
        Whh_np[:] = 2.0 / n_num_hidden
        Whh_np[:, n_num_hidden//2:] = 0

        res_a = tvm.nd.array(res_np, ctx)
        Whh_a = tvm.nd.array(Whh_np, ctx)
        # Skip first pass as it is compilation
        f(res_a, Whh_a)
        ctx.sync()
        # measure time cost of second step.
        tstart = time.time()
        f(res_a, Whh_a)
        ctx.sync()
        tgap = time.time() - tstart
        print("Time cost=%g" % tgap)
        # correctness
        if not SKIP_CHECK:
            res_gpu = res_a.asnumpy()
            # float64 reference: row 0 stays all ones, later rows are
            # res_cmp[t - 1] x Whh.
            res_cmp = np.ones_like(res_np).astype("float64")
            Whh_np = Whh_np.astype("float64")
            for t in range(1, n_num_step):
                res_cmp[t][:] = np.dot(res_cmp[t - 1], Whh_np)
            # Print every mismatching entry of batch row 0 before asserting.
            for i  in range(n_num_step):
                for j in range(n_num_hidden):
                    if abs(res_cmp[i,0,j] - res_gpu[i,0,j]) > 1e-5:
                        print("%d, %d: %g vs %g" % (i,j, res_cmp[i,0,j], res_gpu[i,0,j]))
            tvm.testing.assert_allclose(res_gpu, res_cmp, rtol=1e-3)
    check_device("cuda")
예제 #28
0
def rnn_matexp():
    """Build, schedule, time and check a scan that repeatedly multiplies by ``Whh``.

    The scan state starts at all ones and each step computes
    ``state[t] = state[t - 1] x Whh`` via a reduction over ``kh``. The
    schedule caches the state and the weights, factors the reduction across
    ``threadIdx.y`` and binds the output columns to ``blockIdx.x`` /
    ``threadIdx.x``. The kernel is built for "cuda", run twice (the second run
    is timed) and, unless ``SKIP_CHECK`` is set, compared against a NumPy
    reference.

    Reads the names ``DETECT_GLOBAL_BARRIER``, ``PERSIST_KERNEL`` and
    ``SKIP_CHECK``, which are not defined in this function.
    """
    # Concrete sizes used for the host-side buffers and the reference result.
    n_num_step = 128
    n_num_hidden = 1152
    n_batch_size = 4
    detect_global_barrier = DETECT_GLOBAL_BARRIER

    # The step count stays symbolic; hidden and batch sizes are constants.
    num_step = tvm.var("num_step")
    num_hidden = tvm.convert(n_num_hidden)
    batch_size = tvm.convert(n_batch_size)
    num_thread_y = 8
    num_thread_x = 16 * 3
    num_sm = 24

    Whh = tvm.placeholder((num_hidden, num_hidden), name="Whh")
    s_init = tvm.compute((1, batch_size, num_hidden),
                         lambda _, i, j: 1.0,
                         name="init")
    s_state = tvm.placeholder((num_step, batch_size, num_hidden))
    kh = tvm.reduce_axis((0, num_hidden), name="kh")
    s_update = tvm.compute(
        (num_step, batch_size, num_hidden),
        lambda t, i, j: tvm.sum(s_state[t - 1, i, kh] * Whh[kh, j], axis=kh),
        name="update")
    s_scan = tvm.scan(s_init, s_update, s_state)
    # schedule
    s = tvm.create_schedule(s_scan.op)
    CL = s_update
    # Cache stages: state -> "shared" -> "local", and weights -> "local".
    SS = s.cache_read(s_state, "shared", [CL])
    SL = s.cache_read(SS, "local", [CL])
    WhhL = s.cache_read(Whh, "local", [CL])
    # Split the reduction into num_thread_y parts and factor out the outer part.
    ko, ki = s[CL].split(s[CL].op.reduce_axis[0], nparts=num_thread_y)
    CLF = s.rfactor(CL, ko)

    block_x = tvm.thread_axis((0, num_sm), "blockIdx.x")
    thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
    thread_y = tvm.thread_axis((0, num_thread_y), "threadIdx.y")
    if PERSIST_KERNEL:
        s[s_scan.op].env_threads([block_x, thread_y, thread_x])

    # Bind the hidden axis of the init stage to blocks and threads.
    bx, xi = s[s_init].split(s_init.op.axis[2], nparts=num_sm)
    tx, xi = s[s_init].split(xi, nparts=num_thread_x)
    s[s_init].bind(bx, block_x)
    s[s_init].bind(tx, thread_x)

    # Same binding for the update stage; its remaining reduction goes to thread_y.
    bx, xi = s[s_update].split(s[CL].op.axis[2], nparts=num_sm)
    tx, xi = s[s_update].split(xi, nparts=num_thread_x)
    s[s_update].bind(bx, block_x)
    s[s_update].bind(tx, thread_x)
    s[CL].bind(s[CL].op.reduce_axis[0], thread_y)
    s[CLF].compute_at(s[CL], s[CL].op.reduce_axis[0])
    # Duplicate store predicate.
    s[CL].set_store_predicate(thread_y.equal(0))

    # Placement of the weight cache depends on the PERSIST_KERNEL setting.
    if PERSIST_KERNEL:
        s[WhhL].compute_at(s[s_scan], thread_x)
        s[WhhL].unroll(WhhL.op.axis[0])
    else:
        s[WhhL].compute_at(s[CLF], CLF.op.axis[3])

    kr, ki = s[CLF].split(CLF.op.reduce_axis[0], nparts=1)
    ko, ki = s[CLF].split(ki, factor=4)
    s[SS].compute_at(s[CLF], kr)
    s[SL].compute_at(s[CLF], ko)

    # Spread the shared-cache load over thread_y and thread_x.
    xo, xi = s[SS].split(SS.op.axis[2], factor=num_thread_x * num_thread_y * 3)
    ty, xi = s[SS].split(xi, nparts=num_thread_y)
    tx, xi = s[SS].split(xi, nparts=num_thread_x)
    s[SS].bind(ty, thread_y)
    s[SS].bind(tx, thread_x)

    def check_device(target):
        # Build for `target`, run twice, print the time of the second run and
        # optionally compare the result with a NumPy reference.
        with tvm.build_config(detect_global_barrier=detect_global_barrier,
                              auto_unroll_max_step=128,
                              unroll_explicit=False):
            f = tvm.build(s, [s_scan, Whh], target)
        ctx = tvm.gpu(0) if target == "cuda" else tvm.cl(0)
        # launch the kernel.
        res_np = np.zeros(
            (n_num_step, n_batch_size, n_num_hidden)).astype("float32")
        # Weights: 2 / n_num_hidden in the first half of the columns, 0 in the rest.
        Whh_np = np.zeros((n_num_hidden, n_num_hidden)).astype("float32")
        Whh_np[:] = 2.0 / n_num_hidden
        Whh_np[:, n_num_hidden // 2:] = 0

        res_a = tvm.nd.array(res_np, ctx)
        Whh_a = tvm.nd.array(Whh_np, ctx)
        # Skip first pass as it is compilation
        f(res_a, Whh_a)
        ctx.sync()
        # measure time cost of second step.
        tstart = time.time()
        f(res_a, Whh_a)
        ctx.sync()
        tgap = time.time() - tstart
        print("Time cost=%g" % tgap)
        # correctness
        if not SKIP_CHECK:
            res_gpu = res_a.asnumpy()
            # float64 reference: row 0 stays all ones, later rows are
            # res_cmp[t - 1] x Whh.
            res_cmp = np.ones_like(res_np).astype("float64")
            Whh_np = Whh_np.astype("float64")
            for t in range(1, n_num_step):
                res_cmp[t][:] = np.dot(res_cmp[t - 1], Whh_np)
            # Print every mismatching entry of batch row 0 before asserting.
            for i in range(n_num_step):
                for j in range(n_num_hidden):
                    if abs(res_cmp[i, 0, j] - res_gpu[i, 0, j]) > 1e-5:
                        print("%d, %d: %g vs %g" %
                              (i, j, res_cmp[i, 0, j], res_gpu[i, 0, j]))
            tvm.testing.assert_allclose(res_gpu, res_cmp, rtol=1e-3)

    check_device("cuda")
예제 #29
0
                        's1')
    nnpu.utils.MarkScope(s_update_1)
    k = tvm.reduce_axis((0, m), 'k1')
    s_update_2 = tvm.compute(h_shape, 
                        lambda t, i: tvm.sum(u_buf[i, k] * h_state[t - 1, k], axis=k),
                        's2')
    nnpu.utils.MarkScope(s_update_2)
    s_update_3 = tvm.compute(h_shape,
                        lambda t, i: s_update_1[t, i] + s_update_2[t, i], 
                        's3')
    nnpu.utils.MarkScope(s_update_3)
    s_update_4 = tvm.compute(h_shape,
                        lambda t, i: s_update_3[t, i] + b_buf[i],
                        's4')
    nnpu.utils.MarkScope(s_update_4)
    s_scan = tvm.scan(h_init_buf, s_update_4, h_state, inputs=[x_buf])
    nnpu.utils.MarkScope(s_scan)

    #res = nnpu.utils.reshape(s_scan, h_shape)
    #res_host, _ = nnpu.utils.CopyBufToH(res, 'sc')
    s = nnpu.create_schedule(s_scan.op)
    # tensorize
    s[s_update_1].tensorize(s_update_1.op.axis[1], 
                            env.intrins.get('GEMM', shape=gemm_shape, mode='inc', reduce=True))
    #s[s_update_2].tensorize(s_update_2.op.axis[1],
    #                        env.intrins.get('GEMM', shape=gemm_shape, mode='w', reduce=True))
    s[s_update_3].tensorize(s_update_3.op.axis[1],
                            env.intrins.get('VAddV', mode='w'))
    #s[s_update_4].tensorize(s_update_4.op.axis[1],
    #                        env.intrins.get('VAddV', mode='w'))
    print(tvm.lower(s, [x, w, u, b, h_init, s_scan], simple_mode=True))
예제 #30
0
파일: scan.py 프로젝트: zhangquan920/tasn
# :code:`s_update` describes how to update the value at timestep t. The update
# value can refer back to the values of previous timestep via state placeholder.
# Note that it is invalid to refer to :code:`s_state` at the current or a later timestep.
#
# The scan takes in state placeholder, initial value and update description.
# It is also recommended(although not necessary) to list the inputs to the scan cell.
# The result of the scan is a tensor, giving the result of :code:`s_state` after the
# update over the time domain.
#
m = tvm.var("m")
n = tvm.var("n")
X = tvm.placeholder((m, n), name="X")
s_state = tvm.placeholder((m, n))
s_init = tvm.compute((1, n), lambda _, i: X[0, i])
s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i])
s_scan = tvm.scan(s_init, s_update, s_state, inputs=[X])

######################################################################
# Schedule the Scan Cell
# ----------------------
# We can schedule the body of the scan by scheduling the update and
# init parts separately. Note that it is invalid to schedule the
# first iteration dimension of the update part.
# To split on the time iteration, the user can schedule on scan_op.scan_axis instead.
#
s = tvm.create_schedule(s_scan.op)
num_thread = 256
block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis("threadIdx.x")
xo, xi = s[s_init].split(s_init.op.axis[1], factor=num_thread)
s[s_init].bind(xo, block_x)
def test_lstm_cell_inline():
    """Check that an LSTM scan with inlined gate stages can be lowered.

    The cell is expressed as a two-state scan (hidden state ``h`` and cell
    state ``c``). All gate tensors are marked ``compute_inline`` and the
    resulting schedule is passed to ``tvm.lower``; the test passes when
    lowering does not raise.
    """
    num_step = 128
    num_input = 256
    num_hidden = 1152
    batch_size = 4
    # Shapes shared by several tensors below.
    state_shape = (num_step, batch_size, num_hidden)
    first_step_shape = (1, batch_size, num_hidden)
    stacked_gate_shape = (num_step, 4, batch_size, num_hidden)

    # Inputs and the two weight tensors (one slice per gate along axis 0).
    X = tvm.placeholder((num_step - 1, batch_size, num_input), name="X")
    Wi2h = tvm.placeholder((4, num_hidden, num_input), name="Wi2h")
    Wh2h = tvm.placeholder((4, num_hidden, num_hidden), name="Wh2h")

    # Scan state placeholders -- h: output hidden state, c: cell state.
    s_state_h = tvm.placeholder(state_shape)
    s_state_c = tvm.placeholder(state_shape)
    # Both states start from zeros.
    s_init_c = tvm.compute(first_step_shape, lambda *i: 0.0, name="init_c")
    s_init_h = tvm.compute(first_step_shape, lambda *i: 0.0, name="init_h")

    # Input-to-hidden contribution, reduced over the input dimension.
    k_input = tvm.reduce_axis((0, num_input), name="ki2h")
    s_i2h = tvm.compute(
        stacked_gate_shape,
        lambda t, x, i, j: tvm.sum(
            X[t - 1, i, k_input] * Wi2h[x, j, k_input], axis=k_input),
        name="s_i2h")
    # Hidden-to-hidden contribution, reduced over the hidden dimension.
    k_hidden = tvm.reduce_axis((0, num_hidden), name="ki2h")
    s_h2h = tvm.compute(
        stacked_gate_shape,
        lambda t, x, i, j: tvm.sum(
            s_state_h[t - 1, i, k_hidden] * Wh2h[x, j, k_hidden],
            axis=k_hidden),
        name="s_h2h")

    # Pre-activation values of the four gates, stacked along axis 1.
    gates = tvm.compute(
        s_i2h.shape, lambda *i: s_i2h(*i) + s_h2h(*i), name="gates")

    def _activated_gate(activation, slot, label):
        # Apply `activation` to slice `slot` of the stacked gate tensor.
        return tvm.compute(
            state_shape,
            lambda t, i, j: activation(gates[t, slot, i, j]),
            name=label)

    in_gate = _activated_gate(tvm.sigmoid, 0, "in_gate")
    in_transform = _activated_gate(tvm.tanh, 1, "in_transform")
    forget_gate = _activated_gate(tvm.sigmoid, 2, "forget_gate")
    out_gate = _activated_gate(tvm.sigmoid, 3, "out_gate")

    # New cell state and hidden state for step t.
    next_c = tvm.compute(
        state_shape,
        lambda t, i, j: (forget_gate[t, i, j] * s_state_c[t - 1, i, j]
                         + in_gate[t, i, j] * in_transform[t, i, j]),
        name="next_c")
    next_h = tvm.compute(
        state_shape,
        lambda t, i, j: out_gate[t, i, j] * tvm.tanh(next_c[t, i, j]),
        name="next_h")
    # Identity copies that serve as the scan's update tensors.
    update_c = tvm.compute(state_shape, lambda *i: next_c(*i), name="update_c")
    update_h = tvm.compute(state_shape, lambda *i: next_h(*i), name="update_h")

    # Tie init / update / state together into a two-output scan.
    scan_h, scan_c = tvm.scan(
        [s_init_h, s_init_c],
        [update_h, update_c],
        [s_state_h, s_state_c],
        inputs=[X],
        name="lstm_scan")

    s = tvm.create_schedule(scan_h.op)
    # Inline every gate stage into its consumer.
    for gate_stage in (gates, in_gate, in_transform, forget_gate, out_gate):
        s[gate_stage].compute_inline()

    # Lowering must succeed with the inlined gates.
    tvm.lower(s, [X, Wi2h, Wh2h, scan_h, scan_c])
예제 #32
0
def lstm():
    """Schedule a persistent-kernel LSTM scan, build it for CUDA and time it.

    The LSTM cell is written as a two-state ``tvm.scan`` (hidden state ``h``
    and cell state ``c``). The input-to-hidden term ``Xi2h`` is taken as a
    placeholder, so only the hidden-to-hidden product is computed here.

    Reads the module-level flags ``PERSIST_KERNEL``, ``UNROLL_WLOAD`` and
    ``DETECT_GLOBAL_BARRIER``. Prints the mean time reported by the TVM
    time evaluator and returns nothing.

    Raises:
        ValueError: if ``PERSIST_KERNEL`` is falsy.
    """
    if not PERSIST_KERNEL:
        raise ValueError("Non persist LSTM not yet supported")
    # Launch configuration: threadIdx.y / threadIdx.x / blockIdx.x extents.
    num_thread_y = 8
    num_thread_x = 16 * 3 // 2
    num_sm = 24
    # Concrete step count used for the NumPy buffers in check_device; the
    # compute definition itself uses the symbolic num_step below.
    n_num_step = 128
    num_step = tvm.var('num_step')
    num_hidden = 1152 // 2
    batch_size = 1
    # Global transition matrix
    # Input hidden channel can be pre-calculated by a gemm
    Xi2h = tvm.placeholder((num_step, batch_size, 4, num_hidden), name="Xi2h")
    # Only handle hidden transition, saves space.
    Wh2h = tvm.placeholder((4, num_hidden, num_hidden), name="Wh2h")
    # h: output hidden state, c: cell state.
    s_state_h = tvm.placeholder((num_step, batch_size, num_hidden))
    s_state_c = tvm.placeholder((num_step, batch_size, num_hidden))
    # Both states start from zeros.
    s_init_c = tvm.compute((1, batch_size, num_hidden),
                           lambda *i: 0.0,
                           name="init_c")
    s_init_h = tvm.compute((1, batch_size, num_hidden),
                           lambda *i: 0.0,
                           name="init_h")
    # LSTM transition
    # Hidden-to-hidden product, reduced over the hidden dimension k.
    k = tvm.reduce_axis((0, num_hidden), name="ki2h")
    s_h2h = tvm.compute(
        (num_step, batch_size, 4, num_hidden),
        lambda t, i, x, j: tvm.sum(s_state_h[t - 1, i, k] * Wh2h[x, j, k],
                                   axis=k),
        name="s_h2h")
    # Gate rules
    # Pre-activation gate values; axis 2 selects the gate (0..3).
    gates = tvm.compute(Xi2h.shape,
                        lambda *i: Xi2h(*i) + s_h2h(*i),
                        name="gates")
    gshape = (num_step, batch_size, num_hidden)
    in_gate = tvm.compute(gshape,
                          lambda t, i, j: tvm.sigmoid(gates[t, i, 0, j]),
                          name="in_gate")
    in_transform = tvm.compute(gshape,
                               lambda t, i, j: tvm.tanh(gates[t, i, 1, j]),
                               name="in_transform")
    forget_gate = tvm.compute(gshape,
                              lambda t, i, j: tvm.sigmoid(gates[t, i, 2, j]),
                              name="forget_gate")
    out_gate = tvm.compute(gshape,
                           lambda t, i, j: tvm.sigmoid(gates[t, i, 3, j]),
                           name="out_gate")
    # New cell state and hidden state for step t.
    next_c = tvm.compute(
        gshape,
        lambda t, i, j: forget_gate[t, i, j] * s_state_c[
            t - 1, i, j] + in_gate[t, i, j] * in_transform[t, i, j],
        name="next_c")
    next_h = tvm.compute(
        gshape,
        lambda t, i, j: out_gate[t, i, j] * tvm.tanh(next_c[t, i, j]),
        name="next_h")
    # Identity copies that serve as the scan's update tensors.
    update_c = tvm.compute(gshape, lambda *i: next_c(*i), name="update_c")
    update_h = tvm.compute(gshape, lambda *i: next_h(*i), name="update_h")
    # Tie init / update / state together into a two-output scan.
    scan_h, scan_c = tvm.scan([s_init_h, s_init_c], [update_h, update_c],
                              [s_state_h, s_state_c],
                              inputs=[Xi2h],
                              name="lstm_scan")
    # schedule
    s = tvm.create_schedule(scan_h.op)
    # Inline gate computations
    s[gates].compute_inline()
    s[in_gate].compute_inline()
    s[in_transform].compute_inline()
    s[forget_gate].compute_inline()
    s[out_gate].compute_inline()

    block_x = tvm.thread_axis((0, num_sm), "blockIdx.x")
    thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
    thread_y = tvm.thread_axis((0, num_thread_y), "threadIdx.y")

    # Cache stages: both states go through a "shared"-scope read cache and
    # the weights through a "local"-scope read cache.
    s_state_h_S = s.cache_read(s_state_h, "shared", [s_h2h])
    s_state_c_S = s.cache_read(s_state_c, "shared", [next_c])
    Wh2hL = s.cache_read(Wh2h, "local", [s_h2h])

    # Split the reduction into num_thread_y parts, factor the outer part out
    # with rfactor, and bind the remaining reduction axis to threadIdx.y.
    ko, ki = s[s_h2h].split(s[s_h2h].op.reduce_axis[0], nparts=num_thread_y)
    s_h2h_rf = s.rfactor(s_h2h, ko)
    s[s_h2h].bind(s[s_h2h].op.reduce_axis[0], thread_y)
    s[s_h2h_rf].compute_at(s[s_h2h], s[s_h2h].op.reduce_axis[0])

    # NOTE(review): the else branch cannot run, because the guard at the top
    # of the function raises when PERSIST_KERNEL is falsy.
    if PERSIST_KERNEL:
        s[scan_h.op].env_threads([block_x, thread_y, thread_x])
        s[Wh2hL].compute_at(s[scan_h.op], thread_x)
    else:
        s[Wh2hL].compute_at(s[s_h2h], s[s_h2h].op.axis[3])

    if UNROLL_WLOAD:
        s[Wh2hL].unroll(Wh2hL.op.axis[0])
        s[Wh2hL].unroll(Wh2hL.op.axis[2])

    s[s_state_h_S].compute_at(s[s_h2h_rf], s[s_h2h_rf].op.axis[3])
    s[s_state_c_S].compute_at(s[scan_h.op], s[scan_h].op.scan_axis)

    # Spread the shared-cache load of h over threadIdx.y and threadIdx.x.
    for ss in [s_state_h_S]:
        xo, xi = s[ss].split(ss.op.axis[2], factor=num_thread_x * num_thread_y)
        ty, xi = s[ss].split(xi, nparts=num_thread_y)
        tx, xi = s[ss].split(xi, nparts=num_thread_x)
        s[ss].bind(ty, thread_y)
        s[ss].bind(tx, thread_x)

    # Distribute the hidden axis of the init stages over blocks and threads.
    for init in [s_init_c, s_init_h]:
        bx, xi = s[init].split(init.op.axis[2], nparts=num_sm)
        tx, xi = s[init].split(xi, nparts=num_thread_x)
        s[init].bind(bx, block_x)
        s[init].bind(tx, thread_x)

    # Store only under the predicate thread_y == 0.
    s[next_c].set_store_predicate(thread_y.equal(0))
    s[next_h].set_store_predicate(thread_y.equal(0))

    # Same block/thread distribution and store predicate for the updates.
    for update in [update_c, update_h]:
        bx, xi = s[update].split(s[update].op.axis[2], nparts=num_sm)
        tx, xi = s[update].split(xi, nparts=num_thread_x)
        s[update].bind(bx, block_x)
        s[update].bind(tx, thread_x)
        s[update].set_store_predicate(thread_y.equal(0))

    # verify we can lower correctly
    def check_device(target):
        # Shadow the symbolic num_step with the concrete count so the NumPy
        # buffers below get a fixed size.
        num_step = n_num_step
        flstm = tvm.build(s, [Xi2h, Wh2h, scan_h, scan_c], target)
        ctx = tvm.gpu(0) if target == "cuda" else tvm.cl(0)
        # launch the kernel.
        scan_h_np = np.zeros(
            (num_step, batch_size, num_hidden)).astype("float32")
        scan_c_np = np.zeros(
            (num_step, batch_size, num_hidden)).astype("float32")
        Xi2h_np = np.random.normal(size=(num_step, batch_size, 4,
                                         num_hidden)).astype("float32")
        Wh2h_np = np.random.normal(size=(4, num_hidden,
                                         num_hidden)).astype("float32")
        scan_h_a = tvm.nd.array(scan_h_np, ctx)
        scan_c_a = tvm.nd.array(scan_c_np, ctx)
        Xi2h_a = tvm.nd.array(Xi2h_np, ctx)
        Wh2h_a = tvm.nd.array(Wh2h_np, ctx)
        flstm(Xi2h_a, Wh2h_a, scan_h_a, scan_c_a)
        ctx.sync()
        # After the first call above, time the kernel with the TVM time
        # evaluator (repeat=1000) and report the mean.
        evaluator = flstm.time_evaluator(flstm.entry_name, ctx, 1, repeat=1000)
        eval_result = evaluator(Xi2h_a, Wh2h_a, scan_h_a, scan_c_a)
        print("Time cost=%g" % eval_result.mean)

    # NOTE(review): the original comment said "set unroll_explicit for more
    # readable code", but unroll_explicit is passed as False here -- confirm
    # which is intended.
    with tvm.build_config(detect_global_barrier=DETECT_GLOBAL_BARRIER,
                          auto_unroll_max_step=128,
                          unroll_explicit=False):
        check_device("cuda")
예제 #33
0
파일: lstm.py 프로젝트: gwli/tvm
def lstm():
    """Schedule a persistent-kernel LSTM scan, build it for CUDA and time it.

    The LSTM cell is written as a two-state ``tvm.scan`` (hidden state ``h``
    and cell state ``c``). The input-to-hidden term ``Xi2h`` is taken as a
    placeholder, so only the hidden-to-hidden product is computed here.

    Reads the module-level flags ``PERSIST_KERNEL``, ``UNROLL_WLOAD`` and
    ``DETECT_GLOBAL_BARRIER``. Prints the wall-clock time of the second
    kernel launch and returns nothing.

    Raises:
        ValueError: if ``PERSIST_KERNEL`` is falsy.
    """
    if not PERSIST_KERNEL:
        raise ValueError("Non persist LSTM not yet supported")
    # Launch configuration: threadIdx.y / threadIdx.x / blockIdx.x extents.
    num_thread_y = 8
    # Floor division keeps these integers on Python 3; true division would
    # yield floats, which are invalid as thread extents, tensor shapes and
    # NumPy array shapes.
    num_thread_x = 16 * 3 // 2
    num_sm = 24
    # Concrete step count used for the NumPy buffers in check_device; the
    # compute definition itself uses the symbolic num_step below.
    n_num_step = 128
    num_step = tvm.var('num_step')
    num_hidden = 1152 // 2
    batch_size = 1
    # Global transition matrix
    # Input hidden channel can be pre-calculated by a gemm
    Xi2h = tvm.placeholder((num_step, batch_size, 4, num_hidden), name="Xi2h")
    # Only handle hidden transition, saves space.
    Wh2h = tvm.placeholder((4, num_hidden, num_hidden), name="Wh2h")
    # h: output hidden state, c: cell state.
    s_state_h = tvm.placeholder((num_step, batch_size, num_hidden))
    s_state_c = tvm.placeholder((num_step, batch_size, num_hidden))
    # Both states start from zeros.
    s_init_c = tvm.compute((1, batch_size, num_hidden),
                           lambda *i: 0.0, name="init_c")
    s_init_h = tvm.compute((1, batch_size, num_hidden),
                           lambda *i: 0.0, name="init_h")
    # LSTM transition
    # Hidden-to-hidden product, reduced over the hidden dimension k.
    k = tvm.reduce_axis((0, num_hidden), name="ki2h")
    s_h2h = tvm.compute(
        (num_step, batch_size, 4, num_hidden),
        lambda t, i, x, j: tvm.sum(s_state_h[t - 1, i, k] * Wh2h[x, j, k], axis=k),
        name="s_h2h")
    # Gate rules
    # Pre-activation gate values; axis 2 selects the gate (0..3).
    gates = tvm.compute(Xi2h.shape, lambda *i:
                        Xi2h(*i) + s_h2h(*i), name="gates")
    gshape = (num_step, batch_size, num_hidden)
    in_gate = tvm.compute(gshape, lambda t, i, j: tvm.sigmoid(gates[t, i, 0, j]), name="in_gate")
    in_transform = tvm.compute(gshape, lambda t, i, j: tvm.tanh(gates[t, i, 1, j]), name="in_transform")
    forget_gate = tvm.compute(gshape, lambda t, i, j: tvm.sigmoid(gates[t, i, 2, j]), name="forget_gate")
    out_gate = tvm.compute(gshape, lambda t, i, j: tvm.sigmoid(gates[t, i, 3, j]), name="out_gate")
    # New cell state and hidden state for step t.
    next_c = tvm.compute(gshape,
                         lambda t, i, j:
                         forget_gate[t, i, j] * s_state_c[t - 1, i, j] +
                         in_gate[t, i, j] * in_transform[t, i, j], name="next_c")
    next_h = tvm.compute(gshape,
                         lambda t, i, j: out_gate[t, i, j] * tvm.tanh(next_c[t, i, j]), name="next_h")
    # Identity copies that serve as the scan's update tensors.
    update_c = tvm.compute(gshape, lambda *i: next_c(*i), name="update_c")
    update_h = tvm.compute(gshape, lambda *i: next_h(*i), name="update_h")
    # Tie init / update / state together into a two-output scan.
    scan_h, scan_c = tvm.scan(
        [s_init_h, s_init_c],
        [update_h, update_c],
        [s_state_h, s_state_c],
        inputs=[Xi2h],
        name="lstm_scan")
    # schedule
    s = tvm.create_schedule(scan_h.op)
    # Inline gate computations
    s[gates].compute_inline()
    s[in_gate].compute_inline()
    s[in_transform].compute_inline()
    s[forget_gate].compute_inline()
    s[out_gate].compute_inline()

    block_x = tvm.thread_axis((0, num_sm), "blockIdx.x")
    thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
    thread_y = tvm.thread_axis((0, num_thread_y), "threadIdx.y")

    # Cache stages: both states go through a "shared"-scope read cache and
    # the weights through a "local"-scope read cache.
    s_state_h_S = s.cache_read(s_state_h, "shared", [s_h2h])
    s_state_c_S = s.cache_read(s_state_c, "shared", [next_c])
    Wh2hL = s.cache_read(Wh2h, "local", [s_h2h])

    # Split the reduction into num_thread_y parts, factor the outer part out
    # with rfactor, and bind the remaining reduction axis to threadIdx.y.
    ko, ki = s[s_h2h].split(s[s_h2h].op.reduce_axis[0], nparts=num_thread_y)
    s_h2h_rf = s.rfactor(s_h2h, ko)
    s[s_h2h].bind(s[s_h2h].op.reduce_axis[0], thread_y)
    s[s_h2h_rf].compute_at(s[s_h2h], s[s_h2h].op.reduce_axis[0])

    # NOTE(review): the else branch cannot run, because the guard at the top
    # of the function raises when PERSIST_KERNEL is falsy.
    if PERSIST_KERNEL:
        s[scan_h.op].env_threads([block_x, thread_y, thread_x])
        s[Wh2hL].compute_at(s[scan_h.op], thread_x)
    else:
        s[Wh2hL].compute_at(s[s_h2h], s[s_h2h].op.axis[3])

    if UNROLL_WLOAD:
        s[Wh2hL].unroll(Wh2hL.op.axis[0])
        s[Wh2hL].unroll(Wh2hL.op.axis[2])

    s[s_state_h_S].compute_at(s[s_h2h_rf], s[s_h2h_rf].op.axis[3])
    s[s_state_c_S].compute_at(s[scan_h.op], s[scan_h].op.scan_axis)

    # Spread the shared-cache load of h over threadIdx.y and threadIdx.x.
    for ss in [s_state_h_S]:
        xo, xi = s[ss].split(ss.op.axis[2], factor=num_thread_x * num_thread_y)
        ty, xi = s[ss].split(xi, nparts=num_thread_y)
        tx, xi = s[ss].split(xi, nparts=num_thread_x)
        s[ss].bind(ty, thread_y)
        s[ss].bind(tx, thread_x)

    # Distribute the hidden axis of the init stages over blocks and threads.
    for init in [s_init_c, s_init_h]:
        bx, xi = s[init].split(init.op.axis[2], nparts=num_sm)
        tx, xi = s[init].split(xi, nparts=num_thread_x)
        s[init].bind(bx, block_x)
        s[init].bind(tx, thread_x)

    # Store only under the predicate thread_y == 0.
    s[next_c].set_store_predicate(thread_y.equal(0))
    s[next_h].set_store_predicate(thread_y.equal(0))

    # Same block/thread distribution and store predicate for the updates.
    for update in [update_c, update_h]:
        bx, xi = s[update].split(s[update].op.axis[2], nparts=num_sm)
        tx, xi = s[update].split(xi, nparts=num_thread_x)
        s[update].bind(bx, block_x)
        s[update].bind(tx, thread_x)
        s[update].set_store_predicate(thread_y.equal(0))

    # verify we can lower correctly
    def check_device(target):
        # Shadow the symbolic num_step with the concrete count so the NumPy
        # buffers below get a fixed size.
        num_step = n_num_step
        flstm = tvm.build(s, [Xi2h, Wh2h, scan_h, scan_c],
                          target)
        ctx = tvm.gpu(0) if target == "cuda" else tvm.cl(0)
        # launch the kernel.
        scan_h_np = np.zeros(
            (num_step, batch_size, num_hidden)).astype("float32")
        scan_c_np = np.zeros(
            (num_step, batch_size, num_hidden)).astype("float32")
        Xi2h_np = np.random.normal(
            size=(num_step, batch_size, 4, num_hidden)).astype("float32")
        Wh2h_np = np.random.normal(
            size=(4, num_hidden, num_hidden)).astype("float32")
        scan_h_a = tvm.nd.array(scan_h_np, ctx)
        scan_c_a = tvm.nd.array(scan_c_np, ctx)
        Xi2h_a = tvm.nd.array(Xi2h_np, ctx)
        Wh2h_a = tvm.nd.array(Wh2h_np, ctx)
        flstm(Xi2h_a, Wh2h_a, scan_h_a, scan_c_a)
        ctx.sync()
        # Measure the wall-clock time of a second launch; the first call
        # above is not included in the measurement.
        tstart = time.time()
        flstm(Xi2h_a, Wh2h_a, scan_h_a, scan_c_a)
        ctx.sync()
        tgap = time.time() - tstart
        print("Time cost=%g" % tgap)

    # NOTE(review): the original comment said "set unroll_explicit for more
    # readable code", but unroll_explicit is passed as False here -- confirm
    # which is intended.
    with tvm.build_config(
            detect_global_barrier=DETECT_GLOBAL_BARRIER,
            auto_unroll_max_step=128,
            unroll_explicit=False):
        check_device("cuda")
import tvm
import numpy as np

# Symbolic extents of the 2-D problem; the first axis is the one the scan
# iterates over (see the indexing with t below).
m = tvm.var('m')
n = tvm.var('n')
X = tvm.placeholder((m, n), name='X')
# Placeholder standing in for the scan state; the update rule reads it at t-1.
s_state = tvm.placeholder((m, n))
# Initial value: a single row copied from row 0 of X.
s_init = tvm.compute((1, n), lambda _, i: X[0, i])
# Update rule: state at row t is the previous state row plus row t of X.
s_update = tvm.compute((m, n), lambda t, i: s_state[t - 1, i] + X[t, i])
# Bind init, update and state placeholder into one scan, listing X as input.
s_scan = tvm.scan(s_init, s_update, s_state, inputs=[X])

# Schedule the Scan Cell
s = tvm.create_schedule(s_scan.op)
num_thread = 256
block_x = tvm.thread_axis('blockIdx.x')
thread_x = tvm.thread_axis('threadIdx.x')
# Split the second (non-scan) axis of the init stage into chunks of
# num_thread; outer part -> blockIdx.x, inner part -> threadIdx.x.
xo, xi = s[s_init].split(s_init.op.axis[1], factor=num_thread)
s[s_init].bind(xo, block_x)
s[s_init].bind(xi, thread_x)
# Same split and binding for the update stage.
xo, xi = s[s_update].split(s_update.op.axis[1], factor=num_thread)
s[s_update].bind(xo, block_x)
s[s_update].bind(xi, thread_x)
# Print the lowered form of the scheduled scan.
print(tvm.lower(s, [X, s_scan], simple_mode=True))

# Build and Verify
f_scan = tvm.build(s, [X, s_scan], 'cuda', name='my_scan')
ctx = tvm.gpu(0)
# From here on n and m are rebound to concrete sizes for the test data; the
# symbolic variables created above are no longer reachable by these names.
n = 1024
m = 10
# Random input of shape (m, n), cast to the scan output's dtype and copied
# to the device context.
a_np = np.random.uniform(size=(m, n)).astype(s_scan.dtype)
a = tvm.nd.array(a_np, ctx=ctx)