Example #1
def test_lstm_cell_inline():
    num_step = 128
    num_input = 256
    num_hidden = 1152
    batch_size = 4
    # Global transition matrix
    X = tvm.placeholder((num_step - 1, batch_size, num_input), name="X")
    Wi2h = tvm.placeholder((4, num_hidden, num_input), name="Wi2h")
    Wh2h = tvm.placeholder((4, num_hidden, num_hidden), name="Wh2h")
    # h: output hidden state, c: cell state.
    s_state_h = tvm.placeholder((num_step, batch_size, num_hidden))
    s_state_c = tvm.placeholder((num_step, batch_size, num_hidden))
    s_init_c = tvm.compute((1, batch_size, num_hidden),
                           lambda *i: 0.0, name="init_c")
    s_init_h = tvm.compute((1, batch_size, num_hidden),
                           lambda *i: 0.0, name="init_h")
    # LSTM transition
    k = tvm.reduce_axis((0, num_input), name="ki2h")
    s_i2h = tvm.compute(
        (num_step, 4, batch_size, num_hidden),
        lambda t, x, i, j: tvm.sum(X[t - 1, i, k] * Wi2h[x, j, k], axis=k),
        name="s_i2h")
    k = tvm.reduce_axis((0, num_hidden), name="ki2h")
    s_h2h = tvm.compute(
        (num_step, 4, batch_size, num_hidden),
        lambda t, x, i, j: tvm.sum(s_state_h[t - 1, i, k] * Wh2h[x, j, k], axis=k),
        name="s_h2h")
    # Gate rules
    gates = tvm.compute(s_i2h.shape, lambda *i:
                        s_i2h(*i) + s_h2h(*i), name="gates")
    gshape = (num_step, batch_size, num_hidden)
    in_gate = tvm.compute(gshape, lambda t, i, j: tvm.sigmoid(gates[t, 0, i, j]), name="in_gate")
    in_transform = tvm.compute(gshape, lambda t, i, j: tvm.tanh(gates[t, 1, i, j]), name="in_transform")
    forget_gate = tvm.compute(gshape, lambda t, i, j: tvm.sigmoid(gates[t, 2, i, j]), name="forget_gate")
    out_gate = tvm.compute(gshape, lambda t, i, j: tvm.sigmoid(gates[t, 3, i, j]), name="out_gate")
    next_c = tvm.compute(gshape,
                         lambda t, i, j:
                         forget_gate[t, i, j] * s_state_c[t - 1, i, j] +
                         in_gate[t, i, j] * in_transform[t, i, j], name="next_c")
    next_h = tvm.compute(gshape,
                         lambda t, i, j: out_gate[t, i, j] * tvm.tanh(next_c[t, i, j]), name="next_h")
    update_c = tvm.compute(gshape, lambda *i: next_c(*i), name="update_c")
    update_h = tvm.compute(gshape, lambda *i: next_h(*i), name="update_h")
    # schedule
    scan_h, scan_c = tvm.scan(
        [s_init_h, s_init_c],
        [update_h, update_c],
        [s_state_h, s_state_c],
        inputs=[X],
        name="lstm_scan")
    # schedule
    s = tvm.create_schedule(scan_h.op)
    # Inline gate computations
    s[gates].compute_inline()
    s[in_gate].compute_inline()
    s[in_transform].compute_inline()
    s[forget_gate].compute_inline()
    s[out_gate].compute_inline()
    # verify we can lower correctly
    tvm.lower(s, [X, Wi2h, Wh2h, scan_h, scan_c])
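For cross-checking what the tvm.compute stages above express, here is a minimal NumPy reference for a single LSTM step with the same gate layout (gate axis order: in, transform, forget, out). This is a sketch, not part of the original test; the function name and shapes are illustrative.

import numpy as np

def lstm_step_ref(x, h_prev, c_prev, Wi2h, Wh2h):
    # x: (batch, num_input); h_prev, c_prev: (batch, num_hidden)
    # Wi2h: (4, num_hidden, num_input); Wh2h: (4, num_hidden, num_hidden)
    sigmoid = lambda v: 1.0 / (1.0 + np.exp(-v))
    # gates[x, b, j] = sum_k x[b, k] * Wi2h[x, j, k] + sum_k h_prev[b, k] * Wh2h[x, j, k]
    g = np.einsum('bk,xjk->xbj', x, Wi2h) + np.einsum('bk,xjk->xbj', h_prev, Wh2h)
    c = sigmoid(g[2]) * c_prev + sigmoid(g[0]) * np.tanh(g[1])  # forget * c_prev + in * transform
    h = sigmoid(g[3]) * np.tanh(c)                              # out * tanh(c)
    return h, c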
Example #2
    def check_llvm_sigmoid(n):
        A = tvm.placeholder((n,), name='A')
        B = tvm.compute((n,), lambda i: tvm.sigmoid(A[i]), name='B')

        s = tvm.create_schedule(B.op)
        f = tvm.build(s, [A, B], "llvm")

        a = tvm.nd.array(np.full((n,), -1000, 'float32'))
        b = tvm.nd.empty((n,), 'float32')
        f(a, b)
        tvm.testing.assert_allclose(b.asnumpy(), np.zeros((n,), 'float32'))
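The assert against zeros works because sigmoid(-1000) = 1 / (1 + exp(1000)) underflows to 0 in float32. The snippet assumes numpy is imported as np and that TVM was built with the LLVM backend enabled. A hypothetical invocation with arbitrary sizes:

    check_llvm_sigmoid(8)
    check_llvm_sigmoid(16)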
Example #4
def sigmoid(x):
    """Take sigmoid tanh of input x.

    Parameters
    ----------
    x : tvm.Tensor
        Input argument.

    Returns
    -------
    y : tvm.Tensor
        The result.
    """
    return tvm.compute(x.shape, lambda *i: tvm.sigmoid(x(*i)))
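A minimal usage sketch, assuming the same TVM 0.x API as the surrounding examples (the shape and the "llvm" target are illustrative):

A = tvm.placeholder((64,), name='A')
B = sigmoid(A)                    # elementwise 1 / (1 + exp(-A))
s = tvm.create_schedule(B.op)
f = tvm.build(s, [A, B], 'llvm')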
Example #5
File: math.py Project: gwli/tvm
def sigmoid(x):
    """Take sigmoid tanh of input x.

    Parameters
    ----------
    x : tvm.Tensor
        Input argument.

    Returns
    -------
    y : tvm.Tensor
        The result.
    """
    return tvm.compute(x.shape, lambda *i: tvm.sigmoid(x(*i)))
Example #6
with ScheduleProcHelper(), nnpu.Environment('./nnpu_config_fp32.yaml'):
    env = nnpu.get_env()
    nnpu.set_device(env, type=args.sim)
    dtype_n, dtype_w = env.cfg['dtype_n'], env.cfg['dtype_w']

    assert dtype_w in ['float32', 'float16'], 'when testing activation function, float dtype is needed'

    shape = (64, )
    a = tvm.placeholder(shape, dtype_w, 'a')
    a_buf = tvm.compute(shape, lambda *i: a(*i), 'a_buf')

    exp = tvm.compute(shape, lambda i: tvm.exp(a_buf[i]), 'exp')
    log = tvm.compute(shape, lambda i: tvm.log(a_buf[i]), 'log')
    tanh = tvm.compute(shape, lambda i: tvm.tanh(a_buf[i]), 'tanh')
    sigmoid = tvm.compute(shape, lambda i: tvm.sigmoid(a_buf[i]), 'sigmoid')

    # k = tvm.reduce_axis((0, 16), 'k0')
    # sum = tvm.compute((1, ), lambda i: tvm.sum(sigmoid[k], axis=k), 'sum')
    # nnpu.utils.MarkScope(sum)

    # softmax = tvm.compute(shape, lambda i: sigmoid[i] / sum[0], 'softmax')
    # nnpu.utils.MarkScope(softmax)
    # softmax_host, _ = nnpu.utils.CopyBufToH(softmax, 'softmax')

    s = nnpu.create_schedule([exp.op, log.op, tanh.op, sigmoid.op])
    # cache write
    exp_buf = s.cache_write(exp, env.get_scope('buffer0'))
    log_buf = s.cache_write(log, env.get_scope('buffer0'))
    tanh_buf = s.cache_write(tanh, env.get_scope('buffer0'))
    sigmoid_buf = s.cache_write(sigmoid, env.get_scope('buffer0'))
Example #7
def test_lstm_cell_inline():
    num_step = 128
    num_input = 256
    num_hidden = 1152
    batch_size = 4
    # Global transition matrix
    X = tvm.placeholder((num_step - 1, batch_size, num_input), name="X")
    Wi2h = tvm.placeholder((4, num_hidden, num_input), name="Wi2h")
    Wh2h = tvm.placeholder((4, num_hidden, num_hidden), name="Wh2h")
    # h: output hidden state, c: cell state.
    s_state_h = tvm.placeholder((num_step, batch_size, num_hidden))
    s_state_c = tvm.placeholder((num_step, batch_size, num_hidden))
    s_init_c = tvm.compute((1, batch_size, num_hidden),
                           lambda *i: 0.0,
                           name="init_c")
    s_init_h = tvm.compute((1, batch_size, num_hidden),
                           lambda *i: 0.0,
                           name="init_h")
    # LSTM transition
    k = tvm.reduce_axis((0, num_input), name="ki2h")
    s_i2h = tvm.compute(
        (num_step, 4, batch_size, num_hidden),
        lambda t, x, i, j: tvm.sum(X[t - 1, i, k] * Wi2h[x, j, k], axis=k),
        name="s_i2h")
    k = tvm.reduce_axis((0, num_hidden), name="ki2h")
    s_h2h = tvm.compute(
        (num_step, 4, batch_size, num_hidden),
        lambda t, x, i, j: tvm.sum(s_state_h[t - 1, i, k] * Wh2h[x, j, k],
                                   axis=k),
        name="s_h2h")
    # Gate rules
    gates = tvm.compute(s_i2h.shape,
                        lambda *i: s_i2h(*i) + s_h2h(*i),
                        name="gates")
    gshape = (num_step, batch_size, num_hidden)
    in_gate = tvm.compute(gshape,
                          lambda t, i, j: tvm.sigmoid(gates[t, 0, i, j]),
                          name="in_gate")
    in_transform = tvm.compute(gshape,
                               lambda t, i, j: tvm.tanh(gates[t, 1, i, j]),
                               name="in_transform")
    forget_gate = tvm.compute(gshape,
                              lambda t, i, j: tvm.sigmoid(gates[t, 2, i, j]),
                              name="forget_gate")
    out_gate = tvm.compute(gshape,
                           lambda t, i, j: tvm.sigmoid(gates[t, 3, i, j]),
                           name="out_gate")
    next_c = tvm.compute(
        gshape,
        lambda t, i, j: forget_gate[t, i, j] * s_state_c[
            t - 1, i, j] + in_gate[t, i, j] * in_transform[t, i, j],
        name="next_c")
    next_h = tvm.compute(
        gshape,
        lambda t, i, j: out_gate[t, i, j] * tvm.tanh(next_c[t, i, j]),
        name="next_h")
    update_c = tvm.compute(gshape, lambda *i: next_c(*i), name="update_c")
    update_h = tvm.compute(gshape, lambda *i: next_h(*i), name="update_h")
    # schedule
    scan_h, scan_c = tvm.scan([s_init_h, s_init_c], [update_h, update_c],
                              [s_state_h, s_state_c],
                              inputs=[X],
                              name="lstm_scan")
    # schedule
    s = tvm.create_schedule(scan_h.op)
    # Inline gate computations
    s[gates].compute_inline()
    s[in_gate].compute_inline()
    s[in_transform].compute_inline()
    s[forget_gate].compute_inline()
    s[out_gate].compute_inline()
    # verify we can lower correctly
    tvm.lower(s, [X, Wi2h, Wh2h, scan_h, scan_c])
Example #8
def lstm():
    if not PERSIST_KERNEL:
        raise ValueError("Non persist LSTM not yet supported")
    num_thread_y = 8
    num_thread_x = 16 * 3 // 2
    num_sm = 24
    n_num_step = 128
    num_step = tvm.var('num_step')
    num_hidden = 1152 // 2
    batch_size = 1
    # Global transition matrix
    # Input hidden channel can be pre-calculated by a gemm
    Xi2h = tvm.placeholder((num_step, batch_size, 4, num_hidden), name="Xi2h")
    # Only handle hidden transition, saves space.
    Wh2h = tvm.placeholder((4, num_hidden, num_hidden), name="Wh2h")
    # h: output hidden state, c: cell state.
    s_state_h = tvm.placeholder((num_step, batch_size, num_hidden))
    s_state_c = tvm.placeholder((num_step, batch_size, num_hidden))
    s_init_c = tvm.compute((1, batch_size, num_hidden),
                           lambda *i: 0.0,
                           name="init_c")
    s_init_h = tvm.compute((1, batch_size, num_hidden),
                           lambda *i: 0.0,
                           name="init_h")
    # LSTM transition
    k = tvm.reduce_axis((0, num_hidden), name="ki2h")
    s_h2h = tvm.compute(
        (num_step, batch_size, 4, num_hidden),
        lambda t, i, x, j: tvm.sum(s_state_h[t - 1, i, k] * Wh2h[x, j, k],
                                   axis=k),
        name="s_h2h")
    # Gate rules
    gates = tvm.compute(Xi2h.shape,
                        lambda *i: Xi2h(*i) + s_h2h(*i),
                        name="gates")
    gshape = (num_step, batch_size, num_hidden)
    in_gate = tvm.compute(gshape,
                          lambda t, i, j: tvm.sigmoid(gates[t, i, 0, j]),
                          name="in_gate")
    in_transform = tvm.compute(gshape,
                               lambda t, i, j: tvm.tanh(gates[t, i, 1, j]),
                               name="in_transform")
    forget_gate = tvm.compute(gshape,
                              lambda t, i, j: tvm.sigmoid(gates[t, i, 2, j]),
                              name="forget_gate")
    out_gate = tvm.compute(gshape,
                           lambda t, i, j: tvm.sigmoid(gates[t, i, 3, j]),
                           name="out_gate")
    next_c = tvm.compute(
        gshape,
        lambda t, i, j: forget_gate[t, i, j] * s_state_c[
            t - 1, i, j] + in_gate[t, i, j] * in_transform[t, i, j],
        name="next_c")
    next_h = tvm.compute(
        gshape,
        lambda t, i, j: out_gate[t, i, j] * tvm.tanh(next_c[t, i, j]),
        name="next_h")
    update_c = tvm.compute(gshape, lambda *i: next_c(*i), name="update_c")
    update_h = tvm.compute(gshape, lambda *i: next_h(*i), name="update_h")
    # schedule
    scan_h, scan_c = tvm.scan([s_init_h, s_init_c], [update_h, update_c],
                              [s_state_h, s_state_c],
                              inputs=[Xi2h],
                              name="lstm_scan")
    # schedule
    s = tvm.create_schedule(scan_h.op)
    # Inline gate computations
    s[gates].compute_inline()
    s[in_gate].compute_inline()
    s[in_transform].compute_inline()
    s[forget_gate].compute_inline()
    s[out_gate].compute_inline()

    block_x = tvm.thread_axis((0, num_sm), "blockIdx.x")
    thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
    thread_y = tvm.thread_axis((0, num_thread_y), "threadIdx.y")

    s_state_h_S = s.cache_read(s_state_h, "shared", [s_h2h])
    s_state_c_S = s.cache_read(s_state_c, "shared", [next_c])
    Wh2hL = s.cache_read(Wh2h, "local", [s_h2h])

    ko, ki = s[s_h2h].split(s[s_h2h].op.reduce_axis[0], nparts=num_thread_y)
    s_h2h_rf = s.rfactor(s_h2h, ko)
    s[s_h2h].bind(s[s_h2h].op.reduce_axis[0], thread_y)
    s[s_h2h_rf].compute_at(s[s_h2h], s[s_h2h].op.reduce_axis[0])

    if PERSIST_KERNEL:
        s[scan_h.op].env_threads([block_x, thread_y, thread_x])
        s[Wh2hL].compute_at(s[scan_h.op], thread_x)
    else:
        s[Wh2hL].compute_at(s[s_h2h], s[s_h2h].op.axis[3])

    if UNROLL_WLOAD:
        s[Wh2hL].unroll(Wh2hL.op.axis[0])
        s[Wh2hL].unroll(Wh2hL.op.axis[2])

    s[s_state_h_S].compute_at(s[s_h2h_rf], s[s_h2h_rf].op.axis[3])
    s[s_state_c_S].compute_at(s[scan_h.op], s[scan_h].op.scan_axis)

    for ss in [s_state_h_S]:
        xo, xi = s[ss].split(ss.op.axis[2], factor=num_thread_x * num_thread_y)
        ty, xi = s[ss].split(xi, nparts=num_thread_y)
        tx, xi = s[ss].split(xi, nparts=num_thread_x)
        s[ss].bind(ty, thread_y)
        s[ss].bind(tx, thread_x)

    for init in [s_init_c, s_init_h]:
        bx, xi = s[init].split(init.op.axis[2], nparts=num_sm)
        tx, xi = s[init].split(xi, nparts=num_thread_x)
        s[init].bind(bx, block_x)
        s[init].bind(tx, thread_x)

    s[next_c].set_store_predicate(thread_y.equal(0))
    s[next_h].set_store_predicate(thread_y.equal(0))

    for update in [update_c, update_h]:
        bx, xi = s[update].split(s[update].op.axis[2], nparts=num_sm)
        tx, xi = s[update].split(xi, nparts=num_thread_x)
        s[update].bind(bx, block_x)
        s[update].bind(tx, thread_x)
        s[update].set_store_predicate(thread_y.equal(0))

    # verify we can lower correctly
    def check_device(target):
        num_step = n_num_step
        flstm = tvm.build(s, [Xi2h, Wh2h, scan_h, scan_c], target)
        ctx = tvm.gpu(0) if target == "cuda" else tvm.cl(0)
        # launch the kernel.
        scan_h_np = np.zeros(
            (num_step, batch_size, num_hidden)).astype("float32")
        scan_c_np = np.zeros(
            (num_step, batch_size, num_hidden)).astype("float32")
        Xi2h_np = np.random.normal(size=(num_step, batch_size, 4,
                                         num_hidden)).astype("float32")
        Wh2h_np = np.random.normal(size=(4, num_hidden,
                                         num_hidden)).astype("float32")
        scan_h_a = tvm.nd.array(scan_h_np, ctx)
        scan_c_a = tvm.nd.array(scan_c_np, ctx)
        Xi2h_a = tvm.nd.array(Xi2h_np, ctx)
        Wh2h_a = tvm.nd.array(Wh2h_np, ctx)
        flstm(Xi2h_a, Wh2h_a, scan_h_a, scan_c_a)
        ctx.sync()
        # measure time cost of second step.
        evaluator = flstm.time_evaluator(flstm.entry_name, ctx, 1, repeat=1000)
        eval_result = evaluator(Xi2h_a, Wh2h_a, scan_h_a, scan_c_a)
        print("Time cost=%g" % eval_result.mean)

    # set unroll_explicit for more readable code.
    with tvm.build_config(detect_global_barrier=DETECT_GLOBAL_BARRIER,
                          auto_unroll_max_step=128,
                          unroll_explicit=False):
        check_device("cuda")
Example #9
def single_lstm():
    num_gate = 4
    hidden_size = tvm.var('hidden_size')
    batch_size = tvm.var('batch_size')
    input_size = tvm.var('input_size')

    # A single LSTM block operations without unrolling
    # '*' linear transformation
    # '(*)' elementwise multiplication
    # F_t = sigmoid( W_f * x_t + R_f * h_t-1 + b_f )
    # I_t = sigmoid( W_i * x_t + R_i * h_t-1 + b_i )
    # O_t = sigmoid( W_o * x_t + R_o * h_t-1 + b_o )
    # C'_t = tanh( W_c * x_t + R_c * h_t-1 + b_c )
    # C_t = F_t (*) C_t-1 + I_t (*) C'_t
    # h_t = O_t (*) tanh( C_t )

    # Global transition matrix

    # input X[0..t-1]
    X = tvm.placeholder((batch_size, input_size), name="X")
    Prev_h = tvm.placeholder((batch_size, hidden_size), name="Prev_h")
    Prev_c = tvm.placeholder((batch_size, hidden_size), name="Prev_c")

    # Parameters
    # Weight matrices [W_i, W_f, W_o, W_c]: 4 * hidden_size * input_size
    # Bias: 4 * hidden_size
    Wi2h = tvm.placeholder((num_gate, hidden_size, input_size), name="Wi2h")
    Bi2h = tvm.placeholder((num_gate, hidden_size), name="Bi2h")

    # Weight matrices [R_i, R_f, R_o, R_c]: 4 * hidden_size * hidden_size
    # Only handle hidden transition, saves space.
    Wh2h = tvm.placeholder((num_gate, hidden_size, hidden_size), name="Wh2h")
    Bh2h = tvm.placeholder((num_gate, hidden_size), name="Bh2h")

    # LSTM transition
    # [W_i, W_f, W_o, W_c] * X_t: 4 * num_hidden
    l = tvm.reduce_axis((0, input_size), name="li2h")
    i2h = tvm.compute((batch_size, num_gate, hidden_size),
                      lambda i, x, j: tvm.sum(X[i, l] * Wi2h[x, j, l], axis=l),
                      name="i2h")

    # [R_i, R_f, R_o, R_c] * h_t-1: 4 * hidden_size
    # R: hidden_size * hidden_size, h: hidden_size * 1
    k = tvm.reduce_axis((0, hidden_size), name="ki2h")
    h2h = tvm.compute(
        (batch_size, num_gate, hidden_size),
        lambda i, x, j: tvm.sum(Prev_h[i, k] * Wh2h[x, j, k], axis=k),
        name="h2h")

    gates = tvm.compute(
        (batch_size, num_gate, hidden_size),
        lambda i, j, k: i2h[i, j, k] + h2h[i, j, k] + Bi2h[j, k] + Bh2h[j, k],
        name="gates")
    gshape = (batch_size, hidden_size)
    in_gate = tvm.compute(gshape,
                          lambda i, j: tvm.sigmoid(gates[i, 0, j]),
                          name="in_gate")
    forget_gate = tvm.compute(gshape,
                              lambda i, j: tvm.sigmoid(gates[i, 1, j]),
                              name="forget_gate")
    out_gate = tvm.compute(gshape,
                           lambda i, j: tvm.sigmoid(gates[i, 2, j]),
                           name="out_gate")
    in_transform = tvm.compute(gshape,
                               lambda i, j: tvm.tanh(gates[i, 3, j]),
                               name="in_transform")

    # C_t = F_t o C_t-1 + I_t o C'_t
    state_c = tvm.compute((batch_size, hidden_size),
                          lambda i, j: forget_gate[i, j] * Prev_c[i, j] +
                          in_gate[i, j] * in_transform[i, j],
                          name="state_c")
    # h_t = O_t o tanh( C_t )
    # state_h = tvm.compute((batch_size, hidden_size),
    #    lambda i, j: out_gate[i, j] * tvm.tanh(state_c[i, j]), name="state_h")
    out_c, out_h = tvm.compute(
        (batch_size, hidden_size),
        lambda i, j: (state_c[i, j], out_gate[i, j] * tvm.tanh(state_c[i, j])),
        name="outputs_c_h")
    # schedule
    s = tvm.create_schedule(out_h.op)
    print(
        tvm.lower(s, [X, Prev_h, Prev_c, Wi2h, Bi2h, Wh2h, Bh2h, out_c, out_h],
                  simple_mode=True))
    lstm = tvm.build(s,
                     [X, Prev_h, Prev_c, Wi2h, Bi2h, Wh2h, Bh2h, out_c, out_h],
                     name="single_lstm")
    print(lstm)

    lstm.save("remy_single_lstm.o")
    print(lstm.imported_modules)
    cc.create_shared("remy_single_lstm.so", ["remy_single_lstm.o"])
Example #10
File: lstm.py Project: gwli/tvm
def lstm():
    if not PERSIST_KERNEL:
        raise ValueError("Non persist LSTM not yet supported")
    detect_global_barrier = DETECT_GLOBAL_BARRIER
    num_thread_y = 8
    num_thread_x = 16 * 3 // 2
    num_sm = 24
    n_num_step = 128
    num_step = tvm.var('num_step')
    num_hidden = 1152 // 2
    batch_size = 1
    # Global transition matrix
    # Input hidden channel can be pre-calculated by a gemm
    Xi2h = tvm.placeholder((num_step, batch_size, 4, num_hidden), name="Xi2h")
    # Only handle hidden transition, saves space.
    Wh2h = tvm.placeholder((4, num_hidden, num_hidden), name="Wh2h")
    # h: output hidden state, c: cell state.
    s_state_h = tvm.placeholder((num_step, batch_size, num_hidden))
    s_state_c = tvm.placeholder((num_step, batch_size, num_hidden))
    s_init_c = tvm.compute((1, batch_size, num_hidden),
                           lambda *i: 0.0, name="init_c")
    s_init_h = tvm.compute((1, batch_size, num_hidden),
                           lambda *i: 0.0, name="init_h")
    # LSTM transition
    k = tvm.reduce_axis((0, num_hidden), name="ki2h")
    s_h2h = tvm.compute(
        (num_step, batch_size, 4, num_hidden),
        lambda t, i, x, j: tvm.sum(s_state_h[t - 1, i, k] * Wh2h[x, j, k], axis=k),
        name="s_h2h")
    # Gate rules
    gates = tvm.compute(Xi2h.shape, lambda *i:
                        Xi2h(*i) + s_h2h(*i), name="gates")
    gshape = (num_step, batch_size, num_hidden)
    in_gate = tvm.compute(gshape, lambda t, i, j: tvm.sigmoid(gates[t, i, 0, j]), name="in_gate")
    in_transform = tvm.compute(gshape, lambda t, i, j: tvm.tanh(gates[t, i, 1, j]), name="in_transform")
    forget_gate = tvm.compute(gshape, lambda t, i, j: tvm.sigmoid(gates[t, i, 2, j]), name="forget_gate")
    out_gate = tvm.compute(gshape, lambda t, i, j: tvm.sigmoid(gates[t, i, 3, j]), name="out_gate")
    next_c = tvm.compute(gshape,
                         lambda t, i, j:
                         forget_gate[t, i, j] * s_state_c[t - 1, i, j] +
                         in_gate[t, i, j] * in_transform[t, i, j], name="next_c")
    next_h = tvm.compute(gshape,
                         lambda t, i, j: out_gate[t, i, j] * tvm.tanh(next_c[t, i, j]), name="next_h")
    update_c = tvm.compute(gshape, lambda *i: next_c(*i), name="update_c")
    update_h = tvm.compute(gshape, lambda *i: next_h(*i), name="update_h")
    # schedule
    scan_h, scan_c = tvm.scan(
        [s_init_h, s_init_c],
        [update_h, update_c],
        [s_state_h, s_state_c],
        inputs=[Xi2h],
        name="lstm_scan")
    # schedule
    s = tvm.create_schedule(scan_h.op)
    # Inline gate computations
    s[gates].compute_inline()
    s[in_gate].compute_inline()
    s[in_transform].compute_inline()
    s[forget_gate].compute_inline()
    s[out_gate].compute_inline()

    block_x = tvm.thread_axis((0, num_sm), "blockIdx.x")
    thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
    thread_y = tvm.thread_axis((0, num_thread_y), "threadIdx.y")

    s_state_h_S = s.cache_read(s_state_h, "shared", [s_h2h])
    s_state_c_S = s.cache_read(s_state_c, "shared", [next_c])
    Wh2hL = s.cache_read(Wh2h, "local", [s_h2h])

    ko, ki = s[s_h2h].split(s[s_h2h].op.reduce_axis[0], nparts=num_thread_y)
    s_h2h_rf = s.rfactor(s_h2h, ko)
    s[s_h2h].bind(s[s_h2h].op.reduce_axis[0], thread_y)
    s[s_h2h_rf].compute_at(s[s_h2h], s[s_h2h].op.reduce_axis[0])

    if PERSIST_KERNEL:
        s[scan_h.op].env_threads([block_x, thread_y, thread_x])
        s[Wh2hL].compute_at(s[scan_h.op], thread_x)
    else:
        s[Wh2hL].compute_at(s[s_h2h], s[s_h2h].op.axis[3])

    if UNROLL_WLOAD:
        s[Wh2hL].unroll(Wh2hL.op.axis[0])
        s[Wh2hL].unroll(Wh2hL.op.axis[2])

    s[s_state_h_S].compute_at(s[s_h2h_rf], s[s_h2h_rf].op.axis[3])
    s[s_state_c_S].compute_at(s[scan_h.op], s[scan_h].op.scan_axis)

    for ss in [s_state_h_S]:
        xo, xi = s[ss].split(ss.op.axis[2], factor=num_thread_x * num_thread_y)
        ty, xi = s[ss].split(xi, nparts=num_thread_y)
        tx, xi = s[ss].split(xi, nparts=num_thread_x)
        s[ss].bind(ty, thread_y)
        s[ss].bind(tx, thread_x)

    for init in [s_init_c, s_init_h]:
        bx, xi = s[init].split(init.op.axis[2], nparts=num_sm)
        tx, xi = s[init].split(xi, nparts=num_thread_x)
        s[init].bind(bx, block_x)
        s[init].bind(tx, thread_x)

    s[next_c].set_store_predicate(thread_y.equal(0))
    s[next_h].set_store_predicate(thread_y.equal(0))

    for update in [update_c, update_h]:
        bx, xi = s[update].split(s[update].op.axis[2], nparts=num_sm)
        tx, xi = s[update].split(xi, nparts=num_thread_x)
        s[update].bind(bx, block_x)
        s[update].bind(tx, thread_x)
        s[update].set_store_predicate(thread_y.equal(0))

    # verify we can lower correctly
    def check_device(target):
        num_step = n_num_step
        flstm = tvm.build(s, [Xi2h, Wh2h, scan_h, scan_c],
                          target)
        ctx = tvm.gpu(0) if target == "cuda" else tvm.cl(0)
        # launch the kernel.
        scan_h_np = np.zeros(
            (num_step, batch_size, num_hidden)).astype("float32")
        scan_c_np = np.zeros(
            (num_step, batch_size, num_hidden)).astype("float32")
        Xi2h_np = np.random.normal(
            size=(num_step, batch_size, 4, num_hidden)).astype("float32")
        Wh2h_np = np.random.normal(
            size=(4, num_hidden, num_hidden)).astype("float32")
        scan_h_a = tvm.nd.array(scan_h_np, ctx)
        scan_c_a = tvm.nd.array(scan_c_np, ctx)
        Xi2h_a = tvm.nd.array(Xi2h_np, ctx)
        Wh2h_a = tvm.nd.array(Wh2h_np, ctx)
        flstm(Xi2h_a, Wh2h_a, scan_h_a, scan_c_a)
        ctx.sync()
        # measure time cost of second step.
        tstart = time.time()
        flstm(Xi2h_a, Wh2h_a, scan_h_a, scan_c_a)
        ctx.sync()
        tgap = time.time() - tstart
        print("Time cost=%g" % tgap)

    # set unroll_explicit for more readable code.
    with tvm.build_config(
            detect_global_barrier=DETECT_GLOBAL_BARRIER,
            auto_unroll_max_step=128,
            unroll_explicit=False):
        check_device("cuda")