Example #1
def _sigmoid_compute(input_x):
    """
    calculating sigmoid
    """
    data_input = input_x
    dtype = input_x.dtype
    # check whether the vexp/vmuls intrinsics support float32 on this platform
    exp_support = tbe_platform.cce_conf.api_check_support(
        "te.lang.cce.vexp", "float32")
    mul_support = tbe_platform.cce_conf.api_check_support(
        "te.lang.cce.vmuls", "float32")
    if dtype == "float32" and not mul_support:
        error_manager_vector.raise_err_check_params_rules(
            "DynamicGRU", 'vmuls should support float32', 'mul_support',
            str(mul_support))

    const_num_neg_one = tvm.const(-1, dtype=dtype)
    const_num_one = tvm.const(1, dtype=dtype)
    # exp(-x); round-trip through float16 when vexp lacks float32 support
    tmp_negative = tbe.vmuls(data_input, const_num_neg_one)
    if dtype == "float32" and not exp_support:
        tmp_negative = tbe.cast_to(tmp_negative, "float16")
    tmp_exp = tbe.vexp(tmp_negative)
    if dtype == "float32" and not exp_support:
        tmp_exp = tbe.cast_to(tmp_exp, "float32")
    # 1 + exp(-x)
    tmp_sum = tbe.vadds(tmp_exp, const_num_one)
    if dtype == "float32":
        inp_shape = tmp_sum.shape
        tensor_one = tbe.broadcast(tvm.const(1, dtype), inp_shape)
        res = tbe.vdiv(tensor_one, tmp_sum)
    else:
        res = tbe.vrec(tmp_sum)

    return res
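
For reference, a minimal NumPy model of the same pipeline (a hypothetical helper, not part of the operator; the exp_support flag mimics the platform check above):

import numpy as np

def _sigmoid_reference(x, exp_support=True):
    """NumPy sketch of _sigmoid_compute: 1 / (1 + exp(-x))."""
    neg = -x
    if x.dtype == np.float32 and not exp_support:
        # mirrors the cast_to(..., "float16") fallback around vexp
        neg = neg.astype(np.float16)
    e = np.exp(neg)
    if x.dtype == np.float32 and not exp_support:
        e = e.astype(np.float32)
    s = e + 1
    if x.dtype == np.float32:
        return np.float32(1.0) / s      # vdiv path: broadcast ones divided by s
    return (1.0 / s).astype(x.dtype)    # vrec path: vector reciprocal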
Example #2
def add_compute(input_x, input_y, output_z, kernel_name="add"):
    """
    calculate the element-wise add: c = a + b

    Parameters
    ----------
    input_x: TVM tensor
        the placeholder of first input data
    input_y: TVM tensor
        the placeholder of second input data
    output_z: dict
        shape and dtype of the output; should be the broadcast shape and type of the inputs
    kernel_name: str
        cce kernel name, default value is "add"

    Returns
    -------
    res : TVM tensor, the result of the element-wise add
    """
    shape_x = shape_util.shape_to_list(input_x.shape)
    shape_y = shape_util.shape_to_list(input_y.shape)

    shape_x, shape_y, shape_max = shape_util.broadcast_shapes(
        shape_x,
        shape_y,
        param_name_input1="input_x",
        param_name_input2="input_y")
    # total number of elements in the broadcast result
    shape_size = reduce(lambda x, y: x * y, shape_max)
    if shape_size > SHAPE_SIZE_LIMIT:
        raise RuntimeError("the shape is too large to calculate")

    input_x = tbe.broadcast(input_x, shape_max)
    input_y = tbe.broadcast(input_y, shape_max)
    res = tbe.vadd(input_x, input_y)

    return res
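
A minimal usage sketch, assuming the standard te placeholders and that shape_util, tbe, and SHAPE_SIZE_LIMIT are in scope as in the surrounding operator file:

# hypothetical driver: broadcast add of (16, 1) and (1, 16) -> (16, 16)
data_x = tvm.placeholder((16, 1), name="input_x", dtype="float16")
data_y = tvm.placeholder((1, 16), name="input_y", dtype="float16")
res = add_compute(data_x, data_y, output_z={}, kernel_name="add_demo")
# res holds the element-wise sum over the broadcast shape (16, 16) and can
# then be scheduled and built with the usual tbe flow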
Example #3
def leaky_relu_demo_compute(x, y, negative_slope=0, kernel_name="leaky_relu"):
    """
    compute for caffe_relu_layer_cce
    """
    inp_dtype = x.dtype.lower()
    shape = x.shape

    # The original relu logic remains unchanged.
    if negative_slope == 0:
        if inp_dtype in ("float32", "int32"):
            tensor_zero = tbe.broadcast(tvm.const(0, inp_dtype), shape)
            data_res = tbe.vmax(x, tensor_zero)
        else:
            data_res = tbe.vrelu(x)

        data_res = tbe.cast_to(data_res, inp_dtype)

        return data_res
    # negative_slope != 0: max(x, a*x) equals leaky ReLU for a <= 1,
    # min(x, a*x) equals it for a > 1
    if inp_dtype in ("float16", "float32"):
        slope_tmp = tvm.const(negative_slope, dtype=inp_dtype)
        tmp = tbe.vmuls(x, slope_tmp)
        if negative_slope <= 1:
            res = tbe.vmax(x, tmp)
        else:
            res = tbe.vmin(x, tmp)
    else:
        # inp_dtype in ("int32", "int8"): cast the scaled product back to the
        # original dtype before the max/min comparison
        slope_tmp = tvm.const(negative_slope, dtype=inp_dtype)
        tmp = tbe.vmuls(x, slope_tmp)
        tmp_oritype = tbe.cast_to(tmp, inp_dtype)
        if negative_slope <= 1:
            res = tbe.vmax(x, tmp_oritype)
        else:
            res = tbe.vmin(x, tmp_oritype)

        res = tbe.cast_to(res, inp_dtype)

    return res
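
The branching above leans on a simple identity: max(x, a*x) equals leaky ReLU for 0 <= a <= 1, while min(x, a*x) equals it for a > 1. A quick NumPy check (hypothetical, not part of the operator):

import numpy as np

def leaky_relu_reference(x, a):
    # direct definition: x for x > 0, a * x otherwise
    return np.where(x > 0, x, a * x)

x = np.array([-2.0, -0.5, 0.0, 0.5, 2.0], dtype=np.float32)
for a in (0.1, 1.0):
    assert np.allclose(np.maximum(x, a * x), leaky_relu_reference(x, a))
for a in (1.5, 3.0):
    assert np.allclose(np.minimum(x, a * x), leaky_relu_reference(x, a))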
Example #4
def _dynamic_gru_inner(input_list, custom_list):
    input_x = input_list[0]
    weight1 = input_list[1]
    weight2 = input_list[2]
    bias1 = input_list[3]
    bias2 = input_list[4]
    s_init_h_gm = input_list[5]
    s_state_h_gm_last = input_list[6]

    is_gate_output = custom_list[0]
    is_first_round = custom_list[1]
    is_global_init = custom_list[2]

    input_dtype = 'float16'
    bias_dtype = bias1.dtype
    fp16_input_output = bias_dtype == 'float16'

    shape_x_input = input_x.shape
    shape_w1_input = weight1.shape
    w1_size = 2
    w2_size = 1
    t_size = shape_x_input[0].value
    m_size = shape_x_input[2].value
    k_size = shape_w1_input[1].value
    hidden_size = shape_w1_input[3].value
    in_x = k_size - hidden_size

    shape_b_1 = (1, k_size, w1_size, hidden_size, 16, 16)
    shape_b_2 = (1, k_size, w2_size, hidden_size, 16, 16)
    shape_c_1 = (1, w1_size, hidden_size, m_size, 16, 16)
    shape_c_2 = (1, w2_size, hidden_size, m_size, 16, 16)
    shape_bias_1 = (1, w1_size, hidden_size, 1, 1, 16)
    shape_bias_2 = (1, hidden_size, 1, 1, 16)
    shape_i = (1, hidden_size, m_size, 16, 16)
    shape_i_t = (t_size, hidden_size, m_size, 16, 16)
    k0_size = 16

    if is_first_round and not is_global_init:
        s_state_h = tvm.compute(
            shape_i,
            lambda *indices: tvm.const(0.0, dtype='float32'),
            name='s_state_h')
        s_state_h_fp16 = tvm.compute(
            shape_i,
            lambda *indices: s_state_h(*indices).astype('float16'),
            name="s_state_h_fp16")
    else:
        last_h = s_init_h_gm if is_first_round else s_state_h_gm_last
        if fp16_input_output:
            s_state_h_fp16 = tvm.compute(shape_i,
                                         lambda *indices: last_h(*indices),
                                         name='s_state_h_fp16')
            s_state_h = tvm.compute(
                shape_i,
                lambda *indices: s_state_h_fp16(*indices).astype('float32'),
                name="s_state_h")
        else:
            s_state_h = tvm.compute(shape_i,
                                    lambda *indices: last_h(*indices),
                                    name='s_state_h')
            s_state_h_fp16 = tvm.compute(
                shape_i,
                lambda *indices: s_state_h(*indices).astype('float16'),
                name="s_state_h_fp16")

    # compute
    # input_x and s_state_h must first be copied to UB and cast to float16
    shape_a_z_bigz = (1, m_size, k_size, 16, 16)

    # input_x and s_state_h are stored as Nz and need a transpose to zZ,
    # so swap axes 1 and 2
    a_l1_1 = tvm.compute(
        shape_a_z_bigz,
        lambda *indice: tvm.select(
            indice[2] < in_x,
            input_x[indice[0], indice[2], indice[1], indice[3], indice[4]],
            s_state_h_fp16[0, indice[2] - in_x, indice[1], indice[3], indice[4]]),
        name="a_l1_1",
        tag="concat")
    b_l1_1 = tvm.compute(shape_b_1,
                         lambda *indices: weight1(*indices),
                         name='b_l1_1')
    a_l0a_1 = tvm.compute(shape_a_z_bigz,
                          lambda *indices: a_l1_1(*indices),
                          name="a_l0a_1")
    b_l0b_1 = tvm.compute(shape_b_1,
                          lambda *indices: b_l1_1(*indices),
                          name="b_l0b_1")
    k1_1 = tvm.reduce_axis((0, k_size), name='k1_1')
    k0_1 = tvm.reduce_axis((0, k0_size), name='k0_1')
    c_l0c_1 = tvm.compute(shape_c_1,
                          lambda t, nb_0, nb_1, mb, mp, np:
                          tvm.sum((a_l0a_1[t, mb, k1_1, mp, k0_1] * \
                                   b_l0b_1[t, k1_1, nb_0, nb_1, np, k0_1]) \
                                  .astype('float32'),
                                  axis=[k1_1, k0_1]),
                          name='c_l0c_1')
    c_ub_1 = tvm.compute(shape_c_1,
                         lambda *indices: c_l0c_1(*indices),
                         name="c_ub_1")
    bias_ub_1 = tvm.compute(shape_bias_1,
                            lambda *indices: bias1(*indices),
                            name='bias_ub_1')
    bias_ub_1_fp32 = bias_ub_1
    if fp16_input_output:
        bias_ub_1_fp32 = tvm.compute(
            shape_bias_1,
            lambda *indices: bias_ub_1(*indices).astype('float32'),
            name="bias_ub_1_fp32")
    bias_bc_ub_1 = tbe.broadcast(bias_ub_1_fp32, shape_c_1)
    c_ub_bias_1 = tbe.vadd(c_ub_1, bias_bc_ub_1)

    # split the fused matmul result into the reset gate (r) and update gate (i)
    r_t_index = 0
    i_t_index = 1
    r_t = tvm.compute(
        shape_i,
        lambda t, i, j, k, l: c_ub_bias_1(t, r_t_index, i, j, k, l),
        name="r_t")
    i_t = tvm.compute(
        shape_i,
        lambda t, i, j, k, l: c_ub_bias_1(t, i_t_index, i, j, k, l),
        name="i_t")
    r_t_sigmoid = _sigmoid_compute(r_t)
    i_t_sigmoid = _sigmoid_compute(i_t)
    r_t_mid = r_t_sigmoid
    i_t_mid = i_t_sigmoid
    if is_gate_output:
        if fp16_input_output:
            r_t_sigmoid_fp16 = tvm.compute(
                shape_i,
                lambda *indices: r_t_sigmoid(*indices).astype('float16'),
                name="r_t_sigmoid_fp16")
            i_t_sigmoid_fp16 = tvm.compute(
                shape_i,
                lambda *indices: i_t_sigmoid(*indices).astype('float16'),
                name="i_t_sigmoid_fp16")

            r_t_gm = tvm.compute(shape_i,
                                 lambda *indices: r_t_sigmoid_fp16(*indices),
                                 name="r_t_gm")
            i_t_gm = tvm.compute(shape_i,
                                 lambda *indices: i_t_sigmoid_fp16(*indices),
                                 name="i_t_gm")

            r_t_gm_back = tvm.compute(shape_i,
                                      lambda *indices: r_t_gm(*indices),
                                      name="r_t_gm_back")
            i_t_gm_back = tvm.compute(shape_i,
                                      lambda *indices: i_t_gm(*indices),
                                      name="i_t_gm_back")

            r_t_gm_back_fp32 = tvm.compute(
                shape_i,
                lambda *indices: r_t_gm_back(*indices).astype('float32'),
                name="r_t_gm_back_fp32")
            i_t_gm_back_fp32 = tvm.compute(
                shape_i,
                lambda *indices: i_t_gm_back(*indices).astype('float32'),
                name="i_t_gm_back_fp32")

            r_t_mid = r_t_gm_back_fp32
            i_t_mid = i_t_gm_back_fp32
        else:
            r_t_gm = tvm.compute(shape_i,
                                 lambda *indices: r_t_sigmoid(*indices),
                                 name="r_t_gm")
            i_t_gm = tvm.compute(shape_i,
                                 lambda *indices: i_t_sigmoid(*indices),
                                 name="i_t_gm")

            r_t_gm_back = tvm.compute(shape_i,
                                      lambda *indices: r_t_gm(*indices),
                                      name="r_t_gm_back")
            i_t_gm_back = tvm.compute(shape_i,
                                      lambda *indices: i_t_gm(*indices),
                                      name="i_t_gm_back")

            r_t_mid = r_t_gm_back
            i_t_mid = i_t_gm_back
    r_t_h = tbe.vmul(r_t_mid, s_state_h)
    r_t_h_fp16 = \
        tvm.compute(shape_i,
                    lambda *indices: r_t_h(*indices).astype(input_dtype),
                    name="r_t_h_fp16")

    # second matmul: concat [x, r_t * h] against weight2
    a_l1_2 = tvm.compute(
        shape_a_z_bigz,
        lambda *indice: tvm.select(
            indice[2] < in_x,
            input_x[indice[0], indice[2], indice[1], indice[3], indice[4]],
            r_t_h_fp16[0, indice[2] - in_x, indice[1], indice[3], indice[4]]),
        name="a_l1_2",
        tag="concat")

    b_l1_2 = tvm.compute(shape_b_2,
                         lambda *indices: weight2(*indices),
                         name='b_l1_2')
    a_l0a_2 = tvm.compute(shape_a_z_bigz,
                          lambda *indices: a_l1_2(*indices),
                          name="a_l0a_2")
    b_l0b_2 = tvm.compute(shape_b_2,
                          lambda *indices: b_l1_2(*indices),
                          name="b_l0b_2")
    k1_2 = tvm.reduce_axis((0, k_size), name='k1_2')
    k0_2 = tvm.reduce_axis((0, k0_size), name='k0_2')
    c_l0c_2 = tvm.compute(shape_c_2,
                          lambda t, nb_0, nb_1, mb, mp, np:
                          tvm.sum((a_l0a_2[t, mb, k1_2, mp, k0_2] * \
                                   b_l0b_2[t, k1_2, nb_0, nb_1, np, k0_2]) \
                                  .astype('float32'),
                                  axis=[k1_2, k0_2]),
                          name='c_l0c_2')
    c_ub_2 = tvm.compute(shape_i,
                         lambda t, h, m, i, j: c_l0c_2(t, 0, h, m, i, j),
                         name="c_ub_2")
    bias_ub_2 = tvm.compute(shape_bias_2,
                            lambda t, h, m, i, j: bias2(t, h, m, i, j),
                            name='bias_ub_2')
    bias_ub_2_fp32 = bias_ub_2
    if fp16_input_output:
        bias_ub_2_fp32 = tvm.compute(
            shape_bias_2,
            lambda *indices: bias_ub_2(*indices).astype('float32'),
            name="bias_ub_2_fp32")
    bias_bc_ub_2 = tbe.broadcast(bias_ub_2_fp32, shape_i)
    c_ub_bias_2 = tbe.vadd(c_ub_2, bias_bc_ub_2)

    # candidate state n_t = tanh(matmul result + bias)
    h_t_tanh = _tanh_compute(c_ub_bias_2)
    h_t_tanh_mid = h_t_tanh
    if is_gate_output:
        if fp16_input_output:
            h_t_tanh_fp16 = tvm.compute(
                shape_i,
                lambda *indices: h_t_tanh(*indices).astype('float16'),
                name="h_t_tanh_fp16")
            n_t_gm = tvm.compute(shape_i,
                                 lambda *indices: h_t_tanh_fp16(*indices),
                                 name="n_t_gm")
            n_t_gm_back = tvm.compute(shape_i,
                                      lambda *indices: n_t_gm(*indices),
                                      name="n_t_gm_back")
            n_t_gm_back_fp32 = tvm.compute(
                shape_i,
                lambda *indices: n_t_gm_back(*indices).astype('float32'),
                name="n_t_gm_back_fp32")
            h_t_tanh_mid = n_t_gm_back_fp32
        else:
            n_t_gm = tvm.compute(shape_i,
                                 lambda *indices: h_t_tanh(*indices),
                                 name="n_t_gm")
            n_t_gm_back = tvm.compute(shape_i,
                                      lambda *indices: n_t_gm(*indices),
                                      name="n_t_gm_back")
            h_t_tanh_mid = n_t_gm_back

    # GRU state update: h_new = i * (h_prev - n) + n, i.e. i * h_prev + (1 - i) * n
    c_t_tmp1 = tbe.vsub(s_state_h, h_t_tanh_mid)
    c_t_tmp2 = tbe.vmul(c_t_tmp1, i_t_mid)
    update_h = tbe.vadd(c_t_tmp2, h_t_tanh_mid)
    update_h_ub = update_h
    if fp16_input_output:
        update_h_fp16 = tvm.compute(
            shape_i_t,
            lambda *indices: update_h(*indices).astype('float16'),
            name="update_h_fp16")
        update_h_ub = update_h_fp16
    update_y_gm = tvm.compute(shape_i_t,
                              lambda t, i, j, k, l: update_h_ub(0, i, j, k, l),
                              name="update_y_gm")
    update_y_gm_back = tvm.compute(
        shape_i_t,
        lambda t, i, j, k, l: update_y_gm(0, i, j, k, l),
        name="update_y_gm_back")
    update_h_gm = tvm.compute(
        shape_i_t,
        lambda t, i, j, k, l: update_y_gm_back(0, i, j, k, l),
        name="update_h_gm")
    # end compute

    # schedule
    s = tvm.schedule.create_schedule([update_h_gm.op])

    def gen_reversed_subgraph_list(out_tensor, tensor_list):
        """
        traverse tensors by Depth-First-Search
        """
        if out_tensor is None:
            return
        stack = [out_tensor]
        visited_list = []
        while stack:
            cur_tensor = stack.pop()
            visited_list.append(cur_tensor)
            for in_tensor in cur_tensor.op.input_tensors:
                if in_tensor not in visited_list:
                    stack.append(in_tensor)
                    if "elewise" in in_tensor.op.tag or \
                            "broadcast" == in_tensor.op.tag:
                        if in_tensor not in tensor_list:
                            tensor_list.append(in_tensor)

    elewise_tensors_r_t_h_fp16 = []
    gen_reversed_subgraph_list(r_t_h_fp16, elewise_tensors_r_t_h_fp16)

    elewise_tensors = []
    tmp_tensors = []
    gen_reversed_subgraph_list(update_h_gm, tmp_tensors)
    for i in tmp_tensors:
        if i not in elewise_tensors_r_t_h_fp16:
            elewise_tensors.append(i)

    # set scope
    s[s_state_h].set_scope(tbe_platform.scope_ubuf)
    s[s_state_h_fp16].set_scope(tbe_platform.scope_ubuf)
    s[a_l1_1].set_scope(tbe_platform.scope_cbuf)
    s[b_l1_1].set_scope(tbe_platform.scope_cbuf)
    s[a_l0a_1].set_scope(tbe_platform.scope_ca)
    s[b_l0b_1].set_scope(tbe_platform.scope_cb)
    s[c_l0c_1].set_scope(tbe_platform.scope_cc)
    s[c_ub_1].set_scope(tbe_platform.scope_ubuf)
    s[bias_ub_1].set_scope(tbe_platform.scope_ubuf)
    s[bias_bc_ub_1].set_scope(tbe_platform.scope_ubuf)
    s[r_t_h_fp16].set_scope(tbe_platform.scope_ubuf)
    s[a_l1_2].set_scope(tbe_platform.scope_cbuf)
    s[b_l1_2].set_scope(tbe_platform.scope_cbuf)
    s[a_l0a_2].set_scope(tbe_platform.scope_ca)
    s[b_l0b_2].set_scope(tbe_platform.scope_cb)
    s[c_l0c_2].set_scope(tbe_platform.scope_cc)
    s[c_ub_2].set_scope(tbe_platform.scope_ubuf)
    s[bias_ub_2].set_scope(tbe_platform.scope_ubuf)
    s[bias_bc_ub_2].set_scope(tbe_platform.scope_ubuf)
    s[update_y_gm_back].set_scope(tbe_platform.scope_ubuf)
    if is_gate_output:
        s[r_t_gm_back].set_scope(tbe_platform.scope_ubuf)
        s[i_t_gm_back].set_scope(tbe_platform.scope_ubuf)
        s[n_t_gm_back].set_scope(tbe_platform.scope_ubuf)
        if fp16_input_output:
            s[r_t_sigmoid_fp16].set_scope(tbe_platform.scope_ubuf)
            s[i_t_sigmoid_fp16].set_scope(tbe_platform.scope_ubuf)
            s[h_t_tanh_fp16].set_scope(tbe_platform.scope_ubuf)
            s[r_t_gm_back_fp32].set_scope(tbe_platform.scope_ubuf)
            s[i_t_gm_back_fp32].set_scope(tbe_platform.scope_ubuf)
            s[n_t_gm_back_fp32].set_scope(tbe_platform.scope_ubuf)
    if fp16_input_output:
        s[bias_ub_1_fp32].set_scope(tbe_platform.scope_ubuf)
        s[bias_ub_2_fp32].set_scope(tbe_platform.scope_ubuf)
        s[update_h_fp16].set_scope(tbe_platform.scope_ubuf)

    # compute inline
    compute_inline_tensors = [i_t, r_t]
    for tensor in compute_inline_tensors:
        s[tensor].compute_inline()

    # matmul tiling
    factor_l1_m, factor_l1_n, factor_l1_k, factor_l0_m, factor_l0_n, factor_l0_k = \
        _get_tiling(m_size, k_size, hidden_size)

    l1_n_outer_1, l1_n_inner_1 = s[c_l0c_1].split(c_l0c_1.op.axis[2],
                                                  factor=factor_l1_n)
    l1_m_outer_1, l1_m_inner_1 = s[c_l0c_1].split(c_l0c_1.op.axis[3],
                                                  factor=factor_l1_m)
    l1_k_outer_1, l1_k_inner_1 = s[c_l0c_1].split(c_l0c_1.op.reduce_axis[0],
                                                  factor=factor_l1_k)
    l0_n_outer_1, l0_n_inner_1 = s[c_l0c_1].split(l1_n_inner_1,
                                                  factor=factor_l0_n)
    l0_m_outer_1, l0_m_inner_1 = s[c_l0c_1].split(l1_m_inner_1,
                                                  factor=factor_l0_m)
    l0_k_outer_1, l0_k_inner_1 = s[c_l0c_1].split(l1_k_inner_1,
                                                  factor=factor_l0_k)
    s[c_l0c_1].reorder(c_l0c_1.op.axis[0], l1_n_outer_1, l1_k_outer_1,
                       c_l0c_1.op.axis[1], l1_m_outer_1, l0_n_outer_1,
                       l0_m_outer_1, l0_k_outer_1, l0_n_inner_1, l0_m_inner_1,
                       c_l0c_1.op.axis[4], c_l0c_1.op.axis[5], l0_k_inner_1,
                       c_l0c_1.op.reduce_axis[1])
    s[a_l1_1].double_buffer()
    s[b_l1_1].double_buffer()
    s[a_l0a_1].double_buffer()
    s[b_l0b_1].double_buffer()
    s[c_l0c_1].double_buffer()
    s[c_ub_1].double_buffer()
    s[a_l1_1].compute_at(s[c_l0c_1], l1_k_outer_1)
    s[b_l1_1].compute_at(s[c_l0c_1], c_l0c_1.op.axis[1])
    s[a_l0a_1].compute_at(s[c_l0c_1], l1_k_outer_1)
    s[b_l0b_1].compute_at(s[c_l0c_1], l0_k_outer_1)

    c_ub_bias_1_outer, c_ub_bias_1_inner = s[c_ub_bias_1].split(
        c_ub_bias_1.op.axis[2], factor=factor_l1_n)
    s[c_ub_bias_1].reorder(c_ub_bias_1.op.axis[0], c_ub_bias_1_outer,
                           c_ub_bias_1.op.axis[1], c_ub_bias_1_inner,
                           c_ub_bias_1.op.axis[3], c_ub_bias_1.op.axis[4],
                           c_ub_bias_1.op.axis[5])
    s[c_l0c_1].compute_at(s[c_ub_bias_1], c_ub_bias_1_outer)
    s[c_ub_1].compute_at(s[c_ub_bias_1], c_ub_bias_1_outer)
    s[bias_ub_1].compute_at(s[c_ub_bias_1], c_ub_bias_1_outer)
    s[bias_bc_ub_1].compute_at(s[c_ub_bias_1], c_ub_bias_1_outer)
    if fp16_input_output:
        s[bias_ub_1_fp32].compute_at(s[c_ub_bias_1], c_ub_bias_1_outer)
    s[c_ub_bias_1].emit_insn(c_ub_bias_1.op.axis[1], 'vector_add')

    r_t_h_fp16_outer, r_t_h_fp16_inner = s[r_t_h_fp16].split(
        r_t_h_fp16.op.axis[1], factor=factor_l1_n)
    for tensor in elewise_tensors_r_t_h_fp16:
        s[tensor].set_scope(tbe_platform.scope_ubuf)
        if tensor == c_ub_bias_1:
            continue
        s[tensor].compute_at(s[r_t_h_fp16], r_t_h_fp16_outer)
        insn = _get_emit_insn_map(tensor)
        s[tensor].emit_insn(tensor.op.axis[0], insn)
    if is_gate_output:
        s[r_t_gm].compute_at(s[r_t_h_fp16], r_t_h_fp16_outer)
        s[r_t_gm_back].compute_at(s[r_t_h_fp16], r_t_h_fp16_outer)
        if fp16_input_output:
            s[r_t_sigmoid_fp16].compute_at(s[r_t_h_fp16], r_t_h_fp16_outer)
            s[r_t_gm_back_fp32].compute_at(s[r_t_h_fp16], r_t_h_fp16_outer)
    s[r_t_h_fp16].emit_insn(r_t_h_fp16_inner, 'vector_conv')

    l1_n_outer_2, l1_n_inner_2 = s[c_l0c_2].split(c_l0c_2.op.axis[2],
                                                  factor=factor_l1_n)
    l1_m_outer_2, l1_m_inner_2 = s[c_l0c_2].split(c_l0c_2.op.axis[3],
                                                  factor=factor_l1_m)
    l1_k_outer_2, l1_k_inner_2 = s[c_l0c_2].split(c_l0c_2.op.reduce_axis[0],
                                                  factor=factor_l1_k)
    l0_n_outer_2, l0_n_inner_2 = s[c_l0c_2].split(l1_n_inner_2,
                                                  factor=factor_l0_n)
    l0_m_outer_2, l0_m_inner_2 = s[c_l0c_2].split(l1_m_inner_2,
                                                  factor=factor_l0_m)
    l0_k_outer_2, l0_k_inner_2 = s[c_l0c_2].split(l1_k_inner_2,
                                                  factor=factor_l0_k)
    s[c_l0c_2].reorder(c_l0c_2.op.axis[0], l1_n_outer_2, l1_k_outer_2,
                       c_l0c_2.op.axis[1], l1_m_outer_2, l0_n_outer_2,
                       l0_m_outer_2, l0_k_outer_2, l0_n_inner_2, l0_m_inner_2,
                       c_l0c_2.op.axis[4], c_l0c_2.op.axis[5], l0_k_inner_2,
                       c_l0c_2.op.reduce_axis[1])
    s[a_l1_2].double_buffer()
    s[b_l1_2].double_buffer()
    s[a_l0a_2].double_buffer()
    s[b_l0b_2].double_buffer()
    s[c_l0c_2].double_buffer()
    s[c_ub_2].double_buffer()
    s[a_l1_2].compute_at(s[c_l0c_2], l1_k_outer_2)
    s[b_l1_2].compute_at(s[c_l0c_2], c_l0c_2.op.axis[1])
    s[a_l0a_2].compute_at(s[c_l0c_2], l1_k_outer_2)
    s[b_l0b_2].compute_at(s[c_l0c_2], l0_k_outer_2)

    update_h_gm_outer, update_h_gm_inner = s[update_h_gm].split(
        update_h_gm.op.axis[1], factor=factor_l1_n)
    s[c_l0c_2].compute_at(s[update_h_gm], update_h_gm_outer)
    s[c_ub_2].compute_at(s[update_h_gm], update_h_gm_outer)
    s[bias_ub_2].compute_at(s[update_h_gm], update_h_gm_outer)
    s[bias_bc_ub_2].compute_at(s[update_h_gm], update_h_gm_outer)
    s[c_ub_bias_2].compute_at(s[update_h_gm], update_h_gm_outer)
    s[update_y_gm].compute_at(s[update_h_gm], update_h_gm_outer)
    s[update_y_gm_back].compute_at(s[update_h_gm], update_h_gm_outer)
    if fp16_input_output:
        s[bias_ub_2_fp32].compute_at(s[update_h_gm], update_h_gm_outer)
        s[update_h_fp16].compute_at(s[update_h_gm], update_h_gm_outer)
    if is_gate_output:
        s[i_t_gm].compute_at(s[update_h_gm], update_h_gm_outer)
        s[i_t_gm_back].compute_at(s[update_h_gm], update_h_gm_outer)
        s[n_t_gm].compute_at(s[update_h_gm], update_h_gm_outer)
        s[n_t_gm_back].compute_at(s[update_h_gm], update_h_gm_outer)
        if fp16_input_output:
            s[i_t_sigmoid_fp16].compute_at(s[update_h_gm], update_h_gm_outer)
            s[i_t_gm_back_fp32].compute_at(s[update_h_gm], update_h_gm_outer)
            s[h_t_tanh_fp16].compute_at(s[update_h_gm], update_h_gm_outer)
            s[n_t_gm_back_fp32].compute_at(s[update_h_gm], update_h_gm_outer)

    for tensor in elewise_tensors:
        s[tensor].set_scope(tbe_platform.scope_ubuf)
        s[tensor].compute_at(s[update_h_gm], update_h_gm_outer)
        insn = _get_emit_insn_map(tensor)
        s[tensor].emit_insn(tensor.op.axis[0], insn)

    # emit insn
    if is_first_round and not is_global_init:
        s[s_state_h].emit_insn(s_state_h.op.axis[0], 'broadcast')
        s[s_state_h_fp16].emit_insn(s_state_h_fp16.op.axis[0], 'vector_conv')
    else:
        if fp16_input_output:
            s[s_state_h_fp16].emit_insn(s_state_h_fp16.op.axis[0], 'dma_copy')
            s[s_state_h].emit_insn(s_state_h.op.axis[0], 'vector_conv')
        else:
            s[s_state_h].emit_insn(s_state_h.op.axis[0], 'dma_copy')
            s[s_state_h_fp16].emit_insn(s_state_h_fp16.op.axis[0],
                                        'vector_conv')

    s[a_l1_1].emit_insn(a_l1_1.op.axis[0], 'dma_copy')
    s[b_l1_1].emit_insn(b_l1_1.op.axis[0], 'dma_copy')
    s[a_l0a_1].emit_insn(a_l0a_1.op.axis[0], 'dma_copy')
    s[b_l0b_1].emit_insn(b_l0b_1.op.axis[0], 'dma_copy')
    mad_dict = {"mad_pattern": 0, "k_outer": [l1_k_outer_1, l0_k_outer_1]}
    s[c_l0c_1].emit_insn(l0_n_inner_1, 'mad', mad_dict)
    s[c_ub_1].emit_insn(c_ub_1.op.axis[0], 'dma_copy')
    s[bias_ub_1].emit_insn(bias_ub_1.op.axis[0], 'dma_copy')
    if fp16_input_output:
        s[bias_ub_1_fp32].emit_insn(bias_ub_1_fp32.op.axis[0], 'vector_conv')
        s[bias_ub_2_fp32].emit_insn(bias_ub_2_fp32.op.axis[0], 'vector_conv')
        s[update_h_fp16].emit_insn(update_h_fp16.op.axis[0], 'vector_conv')
    s[bias_bc_ub_1].emit_insn(bias_bc_ub_1.op.axis[0], 'unified_broadcast')
    s[a_l1_2].emit_insn(a_l1_2.op.axis[0], 'dma_copy')
    s[b_l1_2].emit_insn(b_l1_2.op.axis[0], 'dma_copy')
    s[a_l0a_2].emit_insn(a_l0a_2.op.axis[0], 'dma_copy')
    s[b_l0b_2].emit_insn(b_l0b_2.op.axis[0], 'dma_copy')
    mad_dict = {"mad_pattern": 0, "k_outer": [l1_k_outer_2, l0_k_outer_2]}
    s[c_l0c_2].emit_insn(l0_n_inner_2, 'mad', mad_dict)
    s[c_ub_2].emit_insn(c_ub_2.op.axis[0], 'dma_copy')
    s[bias_ub_2].emit_insn(bias_ub_2.op.axis[0], 'dma_copy')
    s[bias_bc_ub_2].emit_insn(bias_bc_ub_2.op.axis[0], 'unified_broadcast')
    s[update_y_gm].emit_insn(update_y_gm.op.axis[0], 'dma_copy')
    s[update_y_gm_back].emit_insn(update_y_gm_back.op.axis[0], 'phony_insn')
    s[update_y_gm_back].reused_by(update_h_ub)
    if is_gate_output:
        s[r_t_gm].emit_insn(r_t_gm.op.axis[0], 'dma_copy')
        s[i_t_gm].emit_insn(i_t_gm.op.axis[0], 'dma_copy')
        s[n_t_gm].emit_insn(n_t_gm.op.axis[0], 'dma_copy')
        s[r_t_gm_back].emit_insn(r_t_gm_back.op.axis[0], 'phony_insn')
        s[i_t_gm_back].emit_insn(i_t_gm_back.op.axis[0], 'phony_insn')
        s[n_t_gm_back].emit_insn(n_t_gm_back.op.axis[0], 'phony_insn')
        if fp16_input_output:
            s[r_t_sigmoid_fp16].emit_insn(r_t_sigmoid_fp16.op.axis[0],
                                          'vector_conv')
            s[i_t_sigmoid_fp16].emit_insn(i_t_sigmoid_fp16.op.axis[0],
                                          'vector_conv')
            s[h_t_tanh_fp16].emit_insn(h_t_tanh_fp16.op.axis[0], 'vector_conv')
            s[r_t_gm_back_fp32].emit_insn(r_t_gm_back_fp32.op.axis[0],
                                          'phony_insn')
            s[i_t_gm_back_fp32].emit_insn(i_t_gm_back_fp32.op.axis[0],
                                          'phony_insn')
            s[n_t_gm_back_fp32].emit_insn(n_t_gm_back_fp32.op.axis[0],
                                          'phony_insn')
            s[r_t_gm_back_fp32].reused_by(r_t_sigmoid)
            s[i_t_gm_back_fp32].reused_by(i_t_sigmoid)
            s[n_t_gm_back_fp32].reused_by(h_t_tanh)
            s[r_t_gm_back].reused_by(r_t_sigmoid_fp16)
            s[i_t_gm_back].reused_by(i_t_sigmoid_fp16)
            s[n_t_gm_back].reused_by(h_t_tanh_fp16)
        else:
            s[r_t_gm_back].reused_by(r_t_sigmoid)
            s[i_t_gm_back].reused_by(i_t_sigmoid)
            s[n_t_gm_back].reused_by(h_t_tanh)
    s[update_h_gm].emit_insn(update_h_gm_inner, 'dma_copy')

    output_list = [update_y_gm, update_h_gm]
    if is_gate_output:
        output_list.append(r_t_gm)
        output_list.append(i_t_gm)
        output_list.append(n_t_gm)
    return output_list, s
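
For orientation, a compact NumPy model of the single-step recurrence the kernel above builds (a hypothetical reference that ignores the Nz/zZ tiling and fp16 round-trips): the first fused matmul over [x, h_prev] yields the r and i gates, the second matmul over [x, r * h_prev] yields the candidate n, and the state update folds to i * (h_prev - n) + n:

import numpy as np

def _sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))

def gru_cell_reference(x, h_prev, w1, b1, w2, b2):
    """One step of the recurrence above in plain NumPy.

    x: (batch, in_x), h_prev: (batch, hidden)
    w1: (in_x + hidden, 2 * hidden), b1: (2 * hidden,)
    w2: (in_x + hidden, hidden),     b2: (hidden,)
    """
    hidden = h_prev.shape[1]
    # first matmul on [x, h_prev]; split the result into the r and i gates
    gates = np.concatenate([x, h_prev], axis=1) @ w1 + b1
    r = _sigmoid(gates[:, :hidden])
    i = _sigmoid(gates[:, hidden:])
    # second matmul on [x, r * h_prev] gives the candidate state n
    n = np.tanh(np.concatenate([x, r * h_prev], axis=1) @ w2 + b2)
    # update_h = i * (h_prev - n) + n, i.e. i * h_prev + (1 - i) * n
    return i * (h_prev - n) + n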