Example 1
def _spatial_pack_data_only(wkl, sch, data):
    H, W = wkl.height, wkl.width
    CI, CO = wkl.in_filter, wkl.out_filter
    KH, KW = wkl.hkernel, wkl.wkernel
    HPAD, WPAD = wkl.hpad, wkl.wpad
    HSTR, WSTR = wkl.hstride, wkl.wstride
    HCAT, WCAT = KH - 1, KW - 1

    VH = sch.vh
    VW = sch.vw
    VC = sch.vc
    UNROLL = sch.unroll

    TH = H + 2 * HPAD
    TW = W + 2 * WPAD
    OH = (H + 2 * HPAD - KH) // HSTR + 1
    OW = (W + 2 * WPAD - KW) // WSTR + 1

    dshape = (1, CI, H, W)
    dpshape = (1, CI, TH, TW)
    dvshape = (1, TH // (VH * HSTR), TW // (VW * WSTR), CI, VH * HSTR + HCAT,
               VW * WSTR + WCAT)

    DOPAD = (HPAD != 0 or WPAD != 0)
    if DOPAD:
        data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")
    else:
        data_pad = data

    data_vec = tvm.compute(dvshape, lambda n, h, w, ci, vh, vw: \
        data_pad[n][ci][h*VH*HSTR+vh][w*VW*WSTR+vw], name='data_vec')

    s = tvm.create_schedule(data_vec.op)
    # traverse(s, data_vec.op)

    # schedule for data_vec
    A0, A1 = data_pad, data_vec
    if DOPAD:
        s[A0].compute_inline()
    n, h, w, ci, vh, vw = s[A1].op.axis
    s[A1].fuse(vh, vw)
    if sch.ba == 1:
        oaxis = h
        paxis = h
    else:
        oh, ih = s[A1].split(h, sch.ba)
        oaxis = oh
        paxis = ih
    s[A1].parallel(paxis)
    s[A1].pragma(oaxis, "parallel_launch_point")
    s[A1].pragma(paxis, "parallel_stride_pattern")
    s[A1].pragma(oaxis, "parallel_barrier_when_finish")

    return data_vec, s
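# A minimal NumPy sketch (not from the original source) of the data_vec
# packing above: each VH x VW block of outputs gets its whole input
# footprint, including a (KH - 1)-row / (KW - 1)-column halo, copied into
# one contiguous tile. Sizes are illustrative assumptions; the tile count
# here is derived from the output size, whereas the snippet's dvshape uses
# TH // (VH * HSTR), which agrees only when the dimensions divide evenly.
import numpy as np

CI, TH, TW = 2, 16, 16          # padded input (1, CI, TH, TW)
KH = KW = 3
HSTR = WSTR = 1
VH = VW = 2                     # outputs per tile
HCAT, WCAT = KH - 1, KW - 1     # halo

OH = (TH - KH) // HSTR + 1
OW = (TW - KW) // WSTR + 1
data_pad = np.random.rand(1, CI, TH, TW).astype(np.float32)

tiles_h, tiles_w = OH // VH, OW // VW
data_vec = np.empty((1, tiles_h, tiles_w, CI,
                     VH * HSTR + HCAT, VW * WSTR + WCAT), np.float32)
for h in range(tiles_h):
    for w in range(tiles_w):
        hs, ws = h * VH * HSTR, w * VW * WSTR
        data_vec[0, h, w] = data_pad[0, :, hs:hs + VH * HSTR + HCAT,
                                     ws:ws + VW * WSTR + WCAT]

# every KH x KW window needed by a tile's outputs lies inside its block
assert np.allclose(
    data_vec[0, 1, 1, :, :KH, :KW],
    data_pad[0, :, VH * HSTR:VH * HSTR + KH, VW * WSTR:VW * WSTR + KW])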
Example 2
def decl_winograd(data, U, stride, padding, out_dtype):
    """declare winograd fast convolution F(2x2, 3x3) for conv2d"""
    N, C, H, W = [util.get_const_int(x) for x in data.shape]
    _, _, C, K = [util.get_const_int(x) for x in U.shape]
    HPAD, WPAD = 1, 1
    if isinstance(stride, (tuple, list)):
        HSTR, WSTR = stride
    else:
        HSTR, WSTR = stride, stride

    assert HSTR == 1 and WSTR == 1 and HPAD == 1 and WPAD == 1
    data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")

    m = 2
    r = 3
    alpha = m + r - 1
    nH, nW = (H + m - 1) // m, (W + m - 1) // m
    P = N * nH * nW

    # pack input tile
    input_tile = tvm.compute(
        (C, P, alpha, alpha),
        lambda c, b, eps, nu: tvm.select(
            b < P, data_pad[b // (nH * nW)][c][b // nW % nH * m + eps][
                b % nW * m + nu], tvm.const(0, data_pad.dtype)),
        name='d')

    V = decl_V_minimal(input_tile, alpha, C, P)

    # batch gemm
    c = tvm.reduce_axis((0, C), name='c')
    M = tvm.compute(
        (alpha, alpha, K, P),
        lambda eps, nu, k, b: tvm.sum(U[eps][nu][c][k] * V[eps][nu][c][b],
                                      axis=c),
        name='M')

    # inverse transform and unpack
    output = decl_output_minimal(M, N, K, H, W, P, m, nH, nW)

    return output
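# Sketch (an illustration, not part of the source) of the index arithmetic
# in the input_tile pack above: the flat tile index b decodes into
# (batch, tile_row, tile_col), and each alpha x alpha input tile starts at
# a multiple of m, so consecutive tiles overlap by alpha - m = r - 1 pixels.
N, H, W = 2, 8, 8
m, r = 2, 3
alpha = m + r - 1
nH, nW = (H + m - 1) // m, (W + m - 1) // m
P = N * nH * nW

for b in range(P):
    n = b // (nH * nW)          # batch
    th = b // nW % nH           # tile row
    tw = b % nW                 # tile column
    assert b == n * nH * nW + th * nW + tw
    # rows read by this tile: th*m .. th*m + alpha - 1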
Example 3
def _depthwise_spatial_pack(args, data, kernel, strides, padding, dilation,
                            out_dtype):
    """depthwise_conv2d_arm_cpu's inner implement"""
    is_var, u_vh, u_vw, u_vc = args
    out_dtype = out_dtype or data.dtype

    u_n, u_c, ih, iw = data.shape if is_var else get_const_tuple(data.shape)

    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation

    if len(kernel.shape) == 4:
        pre_packed = False
        u_c, um, ukh, ukw = kernel.shape if is_var else get_const_tuple(
            kernel.shape)
    else:  # kernel tensor is pre packed
        pre_packed = True
        u_c, um, ukh, ukw, u_vc = kernel.shape if is_var else get_const_tuple(
            kernel.shape)
        u_c = u_c * u_vc

    dilated_kernel_h = (ukh - 1) * dilation_h + 1
    dilated_kernel_w = (ukw - 1) * dilation_w + 1

    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w))
    hstr, wstr = strides if isinstance(strides,
                                       (tuple, list)) else (strides, strides)
    u_oh = (ih + pad_top + pad_down - dilated_kernel_h) // hstr + 1
    u_ow = (iw + pad_left + pad_right - dilated_kernel_w) // wstr + 1
    # pack data
    hpad = pad_top + pad_down
    wpad = pad_left + pad_right
    dopad = hpad != 0 or wpad != 0
    if dopad:
        data_pad = pad(
            data,
            (0, 0, pad_top, pad_left),
            (0, 0, pad_down, pad_right),
            name="data_pad",
        )
    else:
        data_pad = data

    oh_div = u_oh // u_vh
    ow_div = u_ow // u_vw
    kvshape = (u_c // u_vc, um, ukh, ukw, u_vc)
    ovshape = (u_n, u_c * um // u_vc, oh_div, ow_div, u_vh, u_vw, u_vc)
    oshape = (u_n, u_c * um, oh_div * u_vh, ow_div * u_vw)

    if dilation_h != 1 or dilation_w != 1:
        # undilate input data
        dvshape = (u_n, oh_div, ow_div, u_c, ukh, ukw, u_vh, u_vw)
        data_vec = tvm.compute(
            dvshape,
            lambda n, h, w, c, kh, kw, vh, vw: data_pad[n][c][
                (h * u_vh + vh) * hstr + kh * dilation_h][
                    (w * u_vw + vw) * wstr + kw * dilation_w],
            name="data_vec_undilated",
        )
    else:
        dvshape = (u_n, oh_div, ow_div, u_c, u_vh * hstr + ukh - 1,
                   u_vw * wstr + ukw - 1)
        data_vec = tvm.compute(
            dvshape,
            lambda n, h, w, c, vh, vw: data_pad[n][c][h * u_vh * hstr + vh][
                w * u_vw * wstr + vw],
            name="data_vec",
        )

    if pre_packed:
        kernel_vec = kernel
    else:
        kernel_vec = tvm.compute(
            kvshape,
            lambda co, m, kh, kw, vc: kernel[co * u_vc + vc][m][kh][kw],
            name="kernel_vec",
        )

    kh = tvm.reduce_axis((0, ukh), name="kh")
    kw = tvm.reduce_axis((0, ukw), name="kw")

    if dilation_h != 1 or dilation_w != 1:
        conv = tvm.compute(
            ovshape,
            lambda n, co, h, w, vh, vw, vc: tvm.sum(
                data_vec[n, h, w, (co * u_vc + vc) // um, kh, kw, vh, vw].
                astype(out_dtype) * kernel_vec[co // um, co % um, kh, kw, vc
                                               ].astype(out_dtype),
                axis=[kh, kw],
            ),
            name="depthwise_conv",
        )
    else:
        conv = tvm.compute(
            ovshape,
            lambda n, co, h, w, vh, vw, vc: tvm.sum(
                data_vec[n, h, w, (co * u_vc + vc) // um, vh * hstr + kh, vw *
                         wstr + kw].astype(out_dtype) * kernel_vec[
                             co // um, co % um, kh, kw, vc].astype(out_dtype),
                axis=[kh, kw],
            ),
            name="depthwise_conv",
        )

    output = tvm.compute(
        oshape,
        lambda n, co, h, w: conv[n][co // u_vc][h // u_vh][w // u_vw][h % u_vh]
        [w % u_vw][co % u_vc],
        name="output_unpack",
        tag="spatial_depthwise_conv_nchw_output",
    )
    return output
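# A plain NumPy reference (not from the source) for NCHW depthwise
# convolution with a channel multiplier, handy as ground truth when
# validating packed layouts like the compute above. In the standard topi
# convention, output channel co maps to input channel co // m and
# multiplier index co % m. Shapes and the helper name are illustrative.
import numpy as np

def depthwise_conv2d_nchw_ref(x, w, stride, pad):
    n, c, ih, iw = x.shape
    _, m, kh, kw = w.shape                  # (channels, multiplier, kh, kw)
    xp = np.pad(x, ((0, 0), (0, 0), (pad, pad), (pad, pad)))
    oh = (ih + 2 * pad - kh) // stride + 1
    ow = (iw + 2 * pad - kw) // stride + 1
    out = np.zeros((n, c * m, oh, ow), x.dtype)
    for co in range(c * m):
        ci, mm = co // m, co % m
        for i in range(oh):
            for j in range(ow):
                win = xp[:, ci, i * stride:i * stride + kh,
                         j * stride:j * stride + kw]
                out[:, co, i, j] = (win * w[ci, mm]).sum(axis=(1, 2))
    return out

x = np.random.rand(1, 4, 8, 8).astype(np.float32)
w = np.random.rand(4, 2, 3, 3).astype(np.float32)   # multiplier um = 2
assert depthwise_conv2d_nchw_ref(x, w, stride=1, pad=1).shape == (1, 8, 8, 8)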
Example 4
def decl_winograd(cfg,
                  data,
                  kernel,
                  strides,
                  padding,
                  layout,
                  out_dtype,
                  VK=6,
                  VP=8,
                  packed_output=False):
    # return _baseline_winograd(cfg, data, kernel, strides, padding, layout, out_dtype)
    N, CI, IH, IW = get_const_tuple(data.shape)
    CO, _, KH, KW = get_const_tuple(kernel.shape)
    HSTR, WSTR = strides if isinstance(strides,
                                       (tuple, list)) else (strides, strides)
    HPAD, WPAD, _, _ = get_pad_tuple(padding, (KH, KW))

    assert layout == 'NCHW'
    assert KH == 3 and KW == 3 and HPAD == 1 and WPAD == 1 and HSTR == 1 and WSTR == 1
    data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")

    A_data = np.array(
        [[1, 1, 1, 1, 1, 32, 32, 0], [0, 1, -1, 2, -2, 16, -16, 0],
         [0, 1, 1, 4, 4, 8, 8, 0], [0, 1, -1, 8, -8, 4, -4, 0],
         [0, 1, 1, 16, 16, 2, 2, 0], [0, 1, -1, 32, -32, 1, -1, 1]],
        dtype=np.float32).T
    G_data = np.array(
        [[1, 0, 0], [-2 / 9, -2 / 9, -2 / 9], [-2 / 9, 2 / 9, -2 / 9],
         [1 / 90, 1 / 45, 2 / 45], [1 / 90, -1 / 45, 2 / 45],
         [1 / 45, 1 / 90, 1 / 180], [1 / 45, -1 / 90, 1 / 180], [0, 0, 1]],
        dtype=np.float32)
    B_data = np.array([[1, 0, -21 / 4, 0, 21 / 4, 0, -1, 0],
                       [0, 1, 1, -17 / 4, -17 / 4, 1, 1, 0],
                       [0, -1, 1, 17 / 4, -17 / 4, -1, 1, 0],
                       [0, 1 / 2, 1 / 4, -5 / 2, -5 / 4, 2, 1, 0],
                       [0, -1 / 2, 1 / 4, 5 / 2, -5 / 4, -2, 1, 0],
                       [0, 2, 4, -5 / 2, -5, 1 / 2, 1, 0],
                       [0, -2, 4, 5 / 2, -5, -1 / 2, 1, 0],
                       [0, -1, 0, 21 / 4, 0, -21 / 4, 0, 1]],
                      dtype=np.float32).T

    m = A_data.shape[1]
    r = 3
    alpha = m + r - 1

    C = CI

    H = (IH + 2 * HPAD - 3) // HSTR + 1
    W = (IW + 2 * WPAD - 3) // WSTR + 1
    nH, nW = (H + m - 1) // m, (W + m - 1) // m

    def round_up(a, b):
        return ((a + b - 1) // b) * b

    K = round_up(CO, VK)
    P = round_up(N * nH * nW, VP)

    assert K % VK == 0
    assert P % VP == 0

    G = const_matrix(G_data, 'G')
    r_kh = tvm.reduce_axis((0, KH), 'r_kh')
    r_kw = tvm.reduce_axis((0, KW), 'r_kw')
    assert K >= CO
    if K > CO:
        kernel_pad = pad(kernel, (0, 0, 0, 0), (K - CO, 0, 0, 0),
                         name="kernel_pad")
    else:
        kernel_pad = kernel
    input_tile = tvm.placeholder(shape=(P // VP, C, alpha, alpha, VP),
                                 dtype='float32',
                                 name="input_tile")
    U = tvm.placeholder(shape=(K // VK, alpha, alpha, C, VK),
                        dtype='float32',
                        name="U")

    #U = tvm.compute(
    #    (K // VK, alpha, alpha, C, VK), lambda k, eps, nu, c, kk:
    #    tvm.sum(kernel_pad[k * VK + kk][c][r_kh][r_kw].astype(out_dtype) *
    #            G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]), name='U')

    ## pack input tile
    #input_tile = tvm.compute((P // VP, C, alpha, alpha, VP),
    #                         lambda b, c, eps, nu, bb:
    #                         data_pad[(b*VP+bb) // (nH*nW)][c][(b*VP+bb) // nW % nH * m + eps]
    #                         [(b*VP+bb) % nW * m + nu],
    #                         name='d')

    def compute_B_T_dot_X(b, c, eps, nu, bb):
        temp_expr = {}
        for j in range(alpha):
            wd0 = input_tile[b][c][0][j][bb] - input_tile[b][c][6][j][bb]
            d4_sub_d2 = input_tile[b][c][4][j][bb] - input_tile[b][c][2][j][bb]
            wd7 = input_tile[b][c][7][j][bb] - input_tile[b][c][1][j][bb]
            d3_sub_d5 = input_tile[b][c][3][j][bb] - input_tile[b][c][5][j][bb]
            wd1 = input_tile[b][c][2][j][bb] + input_tile[b][c][6][j][bb]
            wd2 = input_tile[b][c][1][j][bb] + input_tile[b][c][5][j][bb]
            wd4 = input_tile[b][c][5][j][bb] + input_tile[b][c][1][j][bb] * 0.25
            wd5 = input_tile[b][c][6][j][bb] - input_tile[b][c][4][j][bb] * 5
            wd3 = input_tile[b][c][6][j][bb] + input_tile[b][c][2][j][bb] * 0.25
            wd6 = input_tile[b][c][1][j][bb] + input_tile[b][c][5][j][bb] * 0.25

            wd0 = wd0 + d4_sub_d2 * 5.25
            wd7 = wd7 + d3_sub_d5 * 5.25

            wd1 = wd1 - input_tile[b][c][4][j][bb] * 4.25
            wd2 = wd2 - input_tile[b][c][3][j][bb] * 4.25

            wd3 = wd3 - input_tile[b][c][4][j][bb] * 1.25
            wd5 = wd5 + input_tile[b][c][2][j][bb] * 4
            wd4 = wd4 - input_tile[b][c][3][j][bb] * 1.25
            wd6 = wd6 - input_tile[b][c][3][j][bb] * 1.25

            temp_expr[(0, j)] = wd0
            temp_expr[(1, j)] = wd1 + wd2
            temp_expr[(2, j)] = wd1 - wd2
            temp_expr[(3, j)] = wd3 + wd4 * 2
            temp_expr[(4, j)] = wd3 - wd4 * 2
            temp_expr[(5, j)] = wd5 + wd6 * 2
            temp_expr[(6, j)] = wd5 - wd6 * 2
            temp_expr[(7, j)] = wd7

        now = tvm.const(0.0, "float32")
        for ii in range(alpha):
            for jj in range(alpha):
                now = tvm.select(tvm.all(eps == ii, nu == jj),
                                 temp_expr[(ii, jj)], now)
        return now

    B_T_dot_X = tvm.compute((P // VP, C, alpha, alpha, VP),
                            compute_B_T_dot_X,
                            name="B_T_dot_X")

    def compute_X_dot_B(b, eps, nu, c, bb):
        temp_expr = {}

        for i in range(alpha):
            wd0 = B_T_dot_X[b][c][i][0][bb] - B_T_dot_X[b][c][i][6][bb]
            d4_sub_d2 = B_T_dot_X[b][c][i][4][bb] - B_T_dot_X[b][c][i][2][bb]
            wd7 = B_T_dot_X[b][c][i][7][bb] - B_T_dot_X[b][c][i][1][bb]
            d3_sub_d5 = B_T_dot_X[b][c][i][3][bb] - B_T_dot_X[b][c][i][5][bb]
            wd1 = B_T_dot_X[b][c][i][2][bb] + B_T_dot_X[b][c][i][6][bb]
            wd2 = B_T_dot_X[b][c][i][1][bb] + B_T_dot_X[b][c][i][5][bb]
            wd4 = B_T_dot_X[b][c][i][5][bb] + B_T_dot_X[b][c][i][1][bb] * 0.25
            wd5 = B_T_dot_X[b][c][i][6][bb] - B_T_dot_X[b][c][i][4][bb] * 5
            wd3 = B_T_dot_X[b][c][i][6][bb] + B_T_dot_X[b][c][i][2][bb] * 0.25
            wd6 = B_T_dot_X[b][c][i][1][bb] + B_T_dot_X[b][c][i][5][bb] * 0.25

            wd0 = wd0 + d4_sub_d2 * 5.25
            wd7 = wd7 + d3_sub_d5 * 5.25

            wd1 = wd1 - B_T_dot_X[b][c][i][4][bb] * 4.25
            wd2 = wd2 - B_T_dot_X[b][c][i][3][bb] * 4.25

            wd3 = wd3 - B_T_dot_X[b][c][i][4][bb] * 1.25
            wd5 = wd5 + B_T_dot_X[b][c][i][2][bb] * 4
            wd4 = wd4 - B_T_dot_X[b][c][i][3][bb] * 1.25
            wd6 = wd6 - B_T_dot_X[b][c][i][3][bb] * 1.25

            temp_expr[(i, 0)] = wd0
            temp_expr[(i, 1)] = wd1 + wd2
            temp_expr[(i, 2)] = wd1 - wd2
            temp_expr[(i, 3)] = wd3 + wd4 * 2
            temp_expr[(i, 4)] = wd3 - wd4 * 2
            temp_expr[(i, 5)] = wd5 + wd6 * 2
            temp_expr[(i, 6)] = wd5 - wd6 * 2
            temp_expr[(i, 7)] = wd7

        now = tvm.const(0.0, "float32")
        for ii in range(alpha):
            for jj in range(alpha):
                now = tvm.select(tvm.all(eps == ii, nu == jj),
                                 temp_expr[(ii, jj)], now)
        return now

    V = tvm.compute((P // VP, alpha, alpha, C, VP), compute_X_dot_B, name="V")

    # batch gemm
    c = tvm.reduce_axis((0, C), name='c')
    M = tvm.compute((K // VK, P // VP, alpha, alpha, VK, VP),
                    lambda k, b, eps, nu, kk, bb: tvm.sum(
                        U[k][eps][nu][c][kk] * V[b][eps][nu][c][bb], axis=c),
                    name='M')

    def compute_A_T_dot_M(k, b, eps, nu, kk, bb):
        temp_expr = {}

        for j in range(alpha):
            m1_add_m2 = M[k][b][1][j][kk][bb] + M[k][b][2][j][kk][bb]
            m1_sub_m2 = M[k][b][1][j][kk][bb] - M[k][b][2][j][kk][bb]
            m3_add_m4 = M[k][b][3][j][kk][bb] + M[k][b][4][j][kk][bb]
            m3_sub_m4 = M[k][b][3][j][kk][bb] - M[k][b][4][j][kk][bb]
            m5_add_m6 = M[k][b][5][j][kk][bb] + M[k][b][6][j][kk][bb]
            m5_sub_m6 = M[k][b][5][j][kk][bb] - M[k][b][6][j][kk][bb]
            s0 = M[k][b][0][j][kk][bb] + m1_add_m2
            s5 = M[k][b][7][j][kk][bb] + m1_sub_m2
            s1 = m1_sub_m2 + m5_sub_m6 * 16
            s4 = m1_add_m2 + m3_add_m4 * 16
            s2 = m1_add_m2 + 8 * m5_add_m6
            s3 = m1_sub_m2 + 8 * m3_sub_m4
            s0 = s0 + m5_add_m6 * 32
            s5 = s5 + m3_sub_m4 * 32
            s1 = s1 + m3_sub_m4 * 2
            s4 = s4 + m5_add_m6 * 2
            s0 = s0 + m3_add_m4
            s5 = s5 + m5_sub_m6
            s2 = s2 + m3_add_m4 * 4
            s3 = s3 + m5_sub_m6 * 4
            temp_expr[(0, j)] = s0
            temp_expr[(1, j)] = s1
            temp_expr[(2, j)] = s2
            temp_expr[(3, j)] = s3
            temp_expr[(4, j)] = s4
            temp_expr[(5, j)] = s5
        now = tvm.const(0.0, "float32")
        for ii in range(m):
            for jj in range(alpha):
                now = tvm.select(tvm.all(eps == ii, nu == jj),
                                 temp_expr[(ii, jj)], now)
        return now

    A_T_dot_M = tvm.compute((K // VK, P // VP, m, alpha, VK, VP),
                            compute_A_T_dot_M,
                            name="A_T_dot_M")

    def compute_X_dot_A(k, b, eps, nu, kk, bb):
        temp_expr = {}

        for i in range(m):
            m1_add_m2 = A_T_dot_M[k][b][i][1][kk][bb] + A_T_dot_M[k][b][i][2][
                kk][bb]
            m1_sub_m2 = A_T_dot_M[k][b][i][1][kk][bb] - A_T_dot_M[k][b][i][2][
                kk][bb]
            m3_add_m4 = A_T_dot_M[k][b][i][3][kk][bb] + A_T_dot_M[k][b][i][4][
                kk][bb]
            m3_sub_m4 = A_T_dot_M[k][b][i][3][kk][bb] - A_T_dot_M[k][b][i][4][
                kk][bb]
            m5_add_m6 = A_T_dot_M[k][b][i][5][kk][bb] + A_T_dot_M[k][b][i][6][
                kk][bb]
            m5_sub_m6 = A_T_dot_M[k][b][i][5][kk][bb] - A_T_dot_M[k][b][i][6][
                kk][bb]
            s0 = A_T_dot_M[k][b][i][0][kk][bb] + m1_add_m2
            s5 = A_T_dot_M[k][b][i][7][kk][bb] + m1_sub_m2
            s1 = m1_sub_m2 + m5_sub_m6 * 16
            s4 = m1_add_m2 + m3_add_m4 * 16
            s2 = m1_add_m2 + 8 * m5_add_m6
            s3 = m1_sub_m2 + 8 * m3_sub_m4
            s0 = s0 + m5_add_m6 * 32
            s5 = s5 + m3_sub_m4 * 32
            s1 = s1 + m3_sub_m4 * 2
            s4 = s4 + m5_add_m6 * 2
            s0 = s0 + m3_add_m4
            s5 = s5 + m5_sub_m6
            s2 = s2 + m3_add_m4 * 4
            s3 = s3 + m5_sub_m6 * 4
            temp_expr[(i, 0)] = s0
            temp_expr[(i, 1)] = s1
            temp_expr[(i, 2)] = s2
            temp_expr[(i, 3)] = s3
            temp_expr[(i, 4)] = s4
            temp_expr[(i, 5)] = s5
        now = tvm.const(0.0, "float32")
        for ii in range(m):
            for jj in range(m):
                now = tvm.select(tvm.all(eps == ii, nu == jj),
                                 temp_expr[(ii, jj)], now)
        return now

    Y = tvm.compute((K // VK, P // VP, m, m, VK, VP),
                    compute_X_dot_A,
                    name="Y")

    # unpack output
    def _output(n, k_, h, w):
        b_idx = n * nH * nW + (h // m) * nW + w // m
        b = b_idx // VP
        bb = b_idx % VP
        k = k_ // VK
        kk = k_ % VK
        return Y[k][b][h % m][w % m][kk][bb]

    output = tvm.compute((N, CO, H, W),
                         _output,
                         name='output',
                         tag='winograd_conv_output')

    if cfg:
        cfg.add_flop(2 * N * K * H * W * KH * KW * C)

    return Y, input_tile, U, output
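# NumPy check (not from the original source) that the hand-unrolled
# wd0..wd7 pipeline in compute_B_T_dot_X is exactly multiplication by the
# untransposed B_data above, i.e. one 1-D pass of the F(6x6, 3x3) input
# transform. d is a random 8-vector standing in for one tile column.
import numpy as np

B_T = np.array([[1, 0, -21/4, 0, 21/4, 0, -1, 0],
                [0, 1, 1, -17/4, -17/4, 1, 1, 0],
                [0, -1, 1, 17/4, -17/4, -1, 1, 0],
                [0, 1/2, 1/4, -5/2, -5/4, 2, 1, 0],
                [0, -1/2, 1/4, 5/2, -5/4, -2, 1, 0],
                [0, 2, 4, -5/2, -5, 1/2, 1, 0],
                [0, -2, 4, 5/2, -5, -1/2, 1, 0],
                [0, -1, 0, 21/4, 0, -21/4, 0, 1]])

d = np.random.rand(8)
wd0 = d[0] - d[6] + (d[4] - d[2]) * 5.25
wd7 = d[7] - d[1] + (d[3] - d[5]) * 5.25
wd1 = d[2] + d[6] - d[4] * 4.25
wd2 = d[1] + d[5] - d[3] * 4.25
wd3 = d[6] + d[2] * 0.25 - d[4] * 1.25
wd4 = d[5] + d[1] * 0.25 - d[3] * 1.25
wd5 = d[6] - d[4] * 5 + d[2] * 4
wd6 = d[1] + d[5] * 0.25 - d[3] * 1.25

unrolled = np.array([wd0, wd1 + wd2, wd1 - wd2, wd3 + wd4 * 2,
                     wd3 - wd4 * 2, wd5 + wd6 * 2, wd5 - wd6 * 2, wd7])
assert np.allclose(unrolled, B_T @ d)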
Example 5
def _conv_spatial_pack_asm(args, data, kernel, strides, padding, dilation,
                           out_dtype):
    """_conv_spatial_pack_asm"""
    is_var, vh_, vw_, vc_ = args

    # create workload according to raw arguments
    out_dtype = out_dtype or data.dtype
    n_, ci_, ih_, iw_ = data.shape if is_var else get_const_tuple(data.shape)

    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation

    if len(kernel.shape) == 4:
        pre_packed = False
        co_, _, kh_, kw_ = kernel.shape if is_var else get_const_tuple(
            kernel.shape)
    else:  # kernel tensor is pre packed
        pre_packed = True
        co_, _, kh_, kw_, vc_ = kernel.shape if is_var else get_const_tuple(
            kernel.shape)
        co_ = co_ * vc_

    dilated_kernel_h = (kh_ - 1) * dilation_h + 1
    dilated_kernel_w = (kw_ - 1) * dilation_w + 1
    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w))
    hstr, wstr = strides if isinstance(strides,
                                       (tuple, list)) else (strides, strides)
    oh_ = (ih_ + pad_top + pad_bottom - dilated_kernel_h) // hstr + 1
    ow_ = (iw_ + pad_left + pad_right - dilated_kernel_w) // wstr + 1
    data_pad = pad(data, [0, 0, pad_top, pad_left],
                   [0, 0, pad_bottom, pad_right])

    oh_div = oh_ // vh_
    ow_div = ow_ // vw_
    kvshape = (co_ // vc_, ci_, kh_, kw_, vc_)
    ovshape = (n_, co_ // vc_, oh_div, ow_div, vh_, vw_, vc_)
    oshape = (n_, co_, oh_div * vh_, ow_div * vw_)

    if dilation_h != 1 or dilation_w != 1:
        # undilate input data
        dvshape = (n_, oh_div, ow_div, kh_, kw_, vh_, vw_, ci_)
        data_vec = tvm.compute(
            dvshape,
            lambda n, h, w, kh, kw, vh, vw, ci: data_pad[n][ci][
                (h * vh_ + vh) * hstr + kh * dilation_h][
                    (w * vw_ + vw) * wstr + kw * dilation_w],
            name="data_vec_undilated",
        )
    else:
        dvshape = (n_, oh_div, ow_div, (vh_ - 1) * hstr + kh_,
                   (vw_ - 1) * wstr + kw_, ci_)
        data_vec = tvm.compute(
            dvshape,
            lambda n, h, w, vh, vw, ci: data_pad[n][ci][h * vh_ * hstr + vh][
                w * vw_ * wstr + vw],
            name="data_vec",
        )

    if pre_packed:
        kernel_vec = kernel
    else:
        kernel_vec = tvm.compute(
            kvshape,
            lambda co, ci, kh, kw, vc: kernel[co * vc_ + vc][ci][kh][kw],
            name="kernel_vec",
        )

    ci = tvm.reduce_axis((0, ci_), name="ci")
    kh = tvm.reduce_axis((0, kh_), name="kh")
    kw = tvm.reduce_axis((0, kw_), name="kw")

    # asm begin----
    type_map = {
        "int8": "int32",
        "uint8": "uint32",
        "float32": "float32",
        "float16": "float16",
    }
    acum_dtype = type_map[data.dtype]
    attrs = {
        "SH": hstr,
        "SW": wstr,
        "PH": pad_top,
        "PW": pad_left,
        "DILA_H": dilation_h,
        "DILA_W": dilation_w,
        "VH": vh_,
        "VW": vw_,
        "VC": vc_,
        "ACUM_DTYPE": acum_dtype,
    }
    # asm end----

    if dilation_h != 1 or dilation_w != 1:
        conv = tvm.compute(
            ovshape,
            lambda n, co, h, w, vh, vw, vc: tvm.sum(
                data_vec[n, h, w, kh, kw, vh, vw, ci].astype(out_dtype) *
                kernel_vec[co, ci, kh, kw, vc].astype(out_dtype),
                axis=[ci, kh, kw],
            ),
            name="conv",
            attrs=attrs,
        )
    else:
        conv = tvm.compute(
            ovshape,
            lambda n, co, h, w, vh, vw, vc: tvm.sum(
                data_vec[n, h, w, vh * hstr + kh, vw * wstr + kw, ci].astype(
                    out_dtype) * kernel_vec[co, ci, kh, kw, vc].astype(
                        out_dtype),
                axis=[ci, kh, kw],
            ),
            name="conv",
            attrs=attrs,
        )

    output = tvm.compute(
        oshape,
        lambda n, co, h, w: conv[n][co // vc_][h // vh_][w // vw_][h % vh_][
            w % vw_][co % vc_],
        name="output_unpack",
        tag="asm_conv2d_output",
    )

    return output
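# Note (illustrative, not from the source): the per-tile footprint in
# dvshape above is exact -- vh outputs with stride s and kernel height k
# read (vh - 1) * s + k input rows, the last window starting at (vh - 1)*s.
# Example 1 sized its tiles as VH * HSTR + KH - 1 instead, which matches
# for stride 1 and over-allocates by s - 1 rows otherwise.
def footprint(vh, s, k):
    return (vh - 1) * s + k

assert footprint(vh=4, s=1, k=3) == 6    # equals 4 * 1 + 3 - 1
assert footprint(vh=4, s=2, k=3) == 9    # vs 4 * 2 + 3 - 1 == 10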
Example 6
def decl_winograd(data, U, stride, padding, out_dtype):
    """declare winograd fast convolution F(2x2, 3x3) for conv2d"""
    N, C, H, W = [util.get_const_int(x) for x in data.shape]
    _, _, C, K = [util.get_const_int(x) for x in U.shape]
    HPAD, WPAD = 1, 1
    if isinstance(stride, (tuple, list)):
        HSTR, WSTR = stride
    else:
        HSTR, WSTR = stride, stride

    assert HSTR == 1 and WSTR == 1 and HPAD == 1 and WPAD == 1
    data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")

    B_data = np.array(
        [[1, 0, 0, 0], [0, 1, -1, 1], [-1, 1, 1, 0], [0, 0, 0, -1]], out_dtype)

    A_data = np.array([
        [1, 0],
        [1, 1],
        [1, -1],
        [0, -1],
    ], out_dtype)

    m = 2
    r = 3
    alpha = m + r - 1

    nH, nW = (H + m - 1) // m, (W + m - 1) // m
    P = N * nH * nW

    # pack input tile
    input_tile = tvm.compute(
        (C, P, alpha, alpha),
        lambda c, b, eps, nu: tvm.select(
            b < P, data_pad[b // (nH * nW)][c][b // nW % nH * m + eps][
                b % nW * m + nu], tvm.const(0, data_pad.dtype)),
        name='d')

    # transform image
    B = const_array(B_data, 'B')
    r_eps = tvm.reduce_axis((0, alpha), 'r_eps')
    r_nu = tvm.reduce_axis((0, alpha), 'r_nu')
    V = tvm.compute((alpha, alpha, C, P),
                    lambda eps, nu, c, b: tvm.sum(input_tile[c][b][r_eps][
                        r_nu] * B[r_eps][eps] * B[r_nu][nu],
                                                  axis=[r_eps, r_nu]),
                    name='V')

    # batch gemm
    c = tvm.reduce_axis((0, C), name='c')
    M = tvm.compute(
        (alpha, alpha, K, P),
        lambda eps, nu, k, b: tvm.sum(U[eps][nu][c][k] * V[eps][nu][c][b],
                                      axis=c),
        name='M')

    # inverse transform and unpack
    A = const_array(A_data, 'A')
    r_eps = tvm.reduce_axis((0, alpha), 'r_eps')
    r_nu = tvm.reduce_axis((0, alpha), 'r_nu')
    output = tvm.compute(
        (N, K, H, W),
        lambda n, k, h, w: tvm.sum(M[r_eps][r_nu][k][n * nH * nW + (
            h // m) * nW + w // m] * A[r_eps][h % m] * A[r_nu][w % m],
                                   axis=[r_eps, r_nu]),
        name='output')

    return output
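# 1-D NumPy sanity check (not from the source) of the F(2, 3) matrices
# above: A.T @ ((G @ g) * (B.T @ d)) reproduces a direct 3-tap convolution
# of 4 samples. U is precomputed by the caller, so the kernel transform G
# never appears in the snippet; the G below is the standard F(2, 3) choice
# and is an assumption here.
import numpy as np

B = np.array([[1, 0, 0, 0], [0, 1, -1, 1], [-1, 1, 1, 0], [0, 0, 0, -1]],
             dtype=np.float64)
A = np.array([[1, 0], [1, 1], [1, -1], [0, -1]], dtype=np.float64)
G = np.array([[1, 0, 0], [0.5, 0.5, 0.5], [0.5, -0.5, 0.5], [0, 0, 1]],
             dtype=np.float64)

d = np.random.rand(4)                      # four input samples
g = np.random.rand(3)                      # three filter taps
winograd = A.T @ ((G @ g) * (B.T @ d))     # 4 multiplies in the pointwise step
direct = np.array([np.dot(g, d[i:i + 3]) for i in range(2)])
assert np.allclose(winograd, direct)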
Example 7
def compute_conv2d_gemm_without_weight_transform(cfg,
                                                 data, B_interleaved_t, strides, padding, dilation,
                                                 out_dtype, kernel_size, output_channels):
    """Compute conv2d by transforming the input,
    executing GEMM and transforming the output back"""
    batches, IH, IW, IC = get_const_tuple(data.shape)

    KH, KW = kernel_size
    OC = output_channels

    K_AREA = KH * KW

    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation

    dilated_kernel_h = (KH - 1) * dilation_h + 1
    dilated_kernel_w = (KW - 1) * dilation_w + 1

    pad_top, pad_left, pad_down, pad_right = \
        get_pad_tuple(padding, (dilated_kernel_h, dilated_kernel_w))
    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)

    OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
    if pad_top or pad_left or pad_down or pad_right:
        data_pad = nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0],
                          name="data_pad")
    else:
        data_pad = data

    # --- Im2col
    M = OH * OW
    K = IC * K_AREA
    N = OC

    A_shape = (batches, M, K)
    if K_AREA == 1:
        A = te.compute(A_shape, lambda n, x, y: data_pad[n, HSTR * (x // OW), WSTR * (x % OW), y],
                       name='data_flatten')
    else:
        A = te.compute(A_shape, lambda n, x, y:
                       data_pad[n,
                                HSTR * (x // OW) + dilation_h * ((y // IC) // KW),
                                WSTR * (x % OW) + dilation_w * ((y // IC) % KW),
                                y % IC],
                       name='data_im2col')
    N_transformed = B_interleaved_t.shape[0]

    # --- Pad if necessary
    idxm = tvm.tir.indexmod

    pad_m = 0
    pad_k = 0

    if M % 4 != 0:
        pad_m = 4 - (M % 4)

    if K % 16 != 0:
        pad_k = 16 - (K % 16)

    M_padded = M + pad_m
    K_padded = K + pad_k

    pad_before = (0, 0, 0)
    pad_after = (0, pad_m, pad_k)

    if pad_m != 0 or pad_k != 0:
        A = nn.pad(A, pad_before=pad_before, pad_after=pad_after, name="A_padded")

    # --- GEMM: A*B'
    k = te.reduce_axis((0, K_padded), "k")

    A_interleaved = te.compute((batches, M_padded // 4, K_padded // 16, 4, 16),
                               lambda b, x, y, z, w: A[b, z + 4 * x, w + 16 * y],
                               name='A_interleaved')

    C_interleaved = te.compute((batches, M_padded // 4, N_transformed, 4, 4),
                               lambda b, x, y, w, z:
                               te.sum(A_interleaved[b, x, k//16, w, idxm(k, 16)].astype(out_dtype)*
                                      B_interleaved_t[y, k//16, z, idxm(k, 16)].astype(out_dtype),
                                      axis=k),
                               name='C_interleaved')

    # --- Unpack C
    C = te.compute((batches, M, N),
                   lambda b, x, y:
                   C_interleaved[b, x // 4, y // 4, idxm(x, 4), idxm(y, 4)],
                   name="C", tag='injective')

    # --- Produce the conv output
    out_shape = (batches, OH, OW, OC)
    out = te.compute(out_shape, lambda b, x, y, z: C(b, y + OW * x, z),
                     name='conv2d_gemm_output')

    return out
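# NumPy model (illustrative sizes, not from the source) of the unit-dilation
# im2col above: row x of A holds the flattened IC*KH*KW receptive field of
# output pixel (x // OW, x % OW), channel index fastest, matching the
# y % IC / (y // IC) % KW / (y // IC) // KW decode. A GEMM against the
# flattened kernel then yields the convolution.
import numpy as np

IH = IW = 6
IC, OC = 2, 4
KH = KW = 3
HSTR = WSTR = 1
OH, OW = (IH - KH) // HSTR + 1, (IW - KW) // WSTR + 1
M, K = OH * OW, IC * KH * KW

x = np.random.rand(1, IH, IW, IC).astype(np.float32)    # NHWC, already padded
A = np.empty((1, M, K), np.float32)
for row in range(M):
    for col in range(K):
        kh, kw, ic = (col // IC) // KW, (col // IC) % KW, col % IC
        A[0, row, col] = x[0, HSTR * (row // OW) + kh,
                           WSTR * (row % OW) + kw, ic]

w = np.random.rand(KH, KW, IC, OC).astype(np.float32)
out = (A @ w.reshape(K, OC)).reshape(1, OH, OW, OC)

# spot-check one output pixel against the direct definition
oh, ow = 1, 2
ref = sum(x[0, oh + kh, ow + kw, ic] * w[kh, kw, ic, 0]
          for kh in range(KH) for kw in range(KW) for ic in range(IC))
assert np.isclose(out[0, oh, ow, 0], ref, atol=1e-4)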
Example 8
def _spatial_conv_all(wkl, sch, data, kernel, out_dtype):
    H, W = wkl.height, wkl.width
    CI, CO = wkl.in_filter, wkl.out_filter
    KH, KW = wkl.hkernel, wkl.wkernel
    HPAD, WPAD = wkl.hpad, wkl.wpad
    HSTR, WSTR = wkl.hstride, wkl.wstride
    HCAT, WCAT = KH - 1, KW - 1

    VH = sch.vh
    VW = sch.vw
    VC = sch.vc
    UNROLL = sch.unroll

    TH = H + 2 * HPAD
    TW = W + 2 * WPAD
    OH = (H + 2 * HPAD - KH) // HSTR + 1
    OW = (W + 2 * WPAD - KW) // WSTR + 1

    dshape = (1, CI, H, W)
    dpshape = (1, CI, TH, TW)
    dvshape = (1, TH // (VH * HSTR), TW // (VW * WSTR), CI, VH * HSTR + HCAT, VW * WSTR + WCAT)

    DOPAD = (HPAD != 0 or WPAD != 0)
    if DOPAD:
        data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")
    else:
        data_pad = data

    data_vec = tvm.compute(dvshape, lambda n, h, w, ci, vh, vw: \
        data_pad[n][ci][h * VH * HSTR + vh][w * VW * WSTR + vw], name='data_vec')

    kshape = (CO, CI, KH, KW)
    kvshape = (CO // VC, CI, KH, KW, VC)

    kernel_vec = tvm.compute(kvshape, lambda co, ci, dh, dw, vc: \
        kernel[co * VC + vc][ci][dh][dw], name='kernel_vec')

    ci = tvm.reduce_axis((0, CI), name='ci')
    dh = tvm.reduce_axis((0, KH), name='dh')
    dw = tvm.reduce_axis((0, KW), name='dw')

    ovshape = (1, CO // VC, OH // VH, OW // VW, VH, VW, VC)
    oshape = (1, CO, OH, OW)

    conv = tvm.compute(ovshape, lambda n, co, h, w, vh, vw, vc: \
        tvm.sum(data_vec[n, h, w, ci, vh * HSTR + dh, vw * WSTR + dw].astype(out_dtype) *
                kernel_vec[co, ci, dh, dw, vc].astype(out_dtype),
                axis=[ci, dh, dw]), name='conv')
    output = tvm.compute(oshape, lambda n, co, h, w:
    conv[n][co // VC][h // VH][w // VW][h % VH][w % VW][co % VC],
                         name='output_unpack', tag='spatial_conv_output')

    # build one schedule, rooted at the final output op, and reuse it below;
    # re-creating the schedule later would discard the data_vec/kernel_vec
    # scheduling done here
    s = tvm.create_schedule(output.op)
    traverse(s, output.op)

    # schedule for data_vec
    A0, A1 = data_pad, data_vec
    if DOPAD:
        s[A0].compute_inline()
    _, h, _, _, _, _ = s[A1].op.axis
    if sch.ba == 1:
        oaxis = h
        paxis = h
    else:
        oh, ih = s[A1].split(h, sch.ba)
        oaxis = oh
        paxis = ih
    s[A1].parallel(paxis)
    s[A1].pragma(oaxis, "parallel_launch_point")
    s[A1].pragma(paxis, "parallel_stride_pattern")
    s[A1].pragma(oaxis, "parallel_barrier_when_finish")

    # schedule for kernel_vec
    B, B0 = kernel, kernel_vec
    co, _, _, _, _ = s[B0].op.axis
    if sch.bc == 1:
        oaxis = co
        paxis = co
    else:
        oco, ico = s[B0].split(co, sch.bc)
        oaxis = oco
        paxis = ico
    s[B0].parallel(paxis)
    s[B0].pragma(oaxis, "parallel_launch_point")
    s[B0].pragma(paxis, "parallel_stride_pattern")
    s[B0].pragma(oaxis, "parallel_barrier_when_finish")

    # schedule for conv & unpack
    C0, C = conv, output

    CC = s.cache_write(C0, "global")
    _, co, oh, ow, vh, vw, vc = s[C0].op.axis
    if UNROLL:
        s[C0].unroll(vw)
    s[C0].vectorize(vc)

    s[CC].compute_at(s[C0], ow)
    _, co, oh, ow, vh, vw, vc = s[CC].op.axis
    ci, dh, dw = s[CC].op.reduce_axis
    s[CC].reorder(ci, dh, vh, dw, vw, vc)

    if UNROLL:
        s[CC].unroll(vw)
    s[CC].vectorize(vc)

    n, co, h, w = s[C].op.axis
    co, vc = s[C].split(co, VC)
    oh, ow, vh, vw = s[C].tile(h, w, VH, VW)
    s[C].reorder(n, co, oh, ow, vh, vw, vc)
    # if C != C1:
    #     s[C1].compute_inline()
    s[C0].compute_at(s[C], ow)

    if sch.bc == 1:
        oaxis = co
        paxis = co
    else:
        oco, ico = s[C].split(co, sch.bc)
        oaxis = oco
        paxis = ico

    s[C].parallel(paxis)
    s[C].pragma(oaxis, "parallel_launch_point")
    s[C].pragma(paxis, "parallel_stride_pattern")
    s[C].pragma(oaxis, "parallel_barrier_when_finish")

    return C, s
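# The output_unpack stage above is a pure layout change. A NumPy round-trip
# check (illustrative sizes, assuming the dimensions divide evenly, as the
# kernel does): the per-element lambda equals one transpose plus a reshape.
import numpy as np

CO, OH, OW = 8, 8, 8
VC, VH, VW = 4, 2, 2
conv = np.random.rand(1, CO // VC, OH // VH, OW // VW, VH, VW, VC)

output = np.empty((1, CO, OH, OW))
for co in range(CO):
    for h in range(OH):
        for w in range(OW):
            output[0, co, h, w] = conv[0, co // VC, h // VH, w // VW,
                                       h % VH, w % VW, co % VC]

# (n, CO/VC, OH/VH, OW/VW, VH, VW, VC) -> (n, CO/VC, VC, OH/VH, VH, OW/VW, VW)
assert np.allclose(output,
                   conv.transpose(0, 1, 6, 2, 4, 3, 5).reshape(1, CO, OH, OW))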
Example 9
def _im2col_pack(wkl, sch, data, kernel, stride, padding, out_dtype):
    """ Compute convolution with im2col pack layout. """
    assert data.shape[0].value == 1, \
        "im2col pack convolution only supports batch size=1"

    N = 1
    H, W = wkl.height, wkl.width
    CI = wkl.in_filter
    CO = wkl.out_filter
    KH, KW = wkl.hkernel, wkl.wkernel
    HPAD, WPAD = wkl.hpad, wkl.wpad
    HSTR, WSTR = wkl.hstride, wkl.wstride

    OH = (H + 2 * HPAD - KH) // HSTR + 1
    OW = (W + 2 * WPAD - KW) // WSTR + 1

    P = sch.vp
    Q = sch.vq
    UNROLL = sch.unroll

    dshape = (N, CI, H, W)
    dpshape = (N, CI, H + 2 * HPAD, W + 2 * WPAD)
    dcshape = (N, OH, OW, CI, KH, KW)
    dvshape = (N, OH * OW // P, CI, KH, KW, P)

    kshape = (CO, CI, KH, KW)
    kvshape = (CO // Q, CI, KH, KW, Q)

    ovshape = (N, CO // Q, OH * OW // P, P, Q)
    oshape = (N, CO, OH, OW)

    ############### declaration

    DO_PAD = (wkl.hpad != 0 or wkl.wpad != 0)
    if DO_PAD:
        data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")
    else:
        data_pad = data

    data_col = tvm.compute(dcshape, lambda n, oh, ow, ci, hk, wk: \
        data_pad[n][ci][oh*HSTR+hk][ow*WSTR+wk], name='data_col')

    data_vec = tvm.compute(dvshape, lambda n, im, ci, hk, wk, vim: \
        data_col[n][(im*P+vim)//OW][(im*P+vim)%OW][ci][hk][wk], name='data_vec')


    kernel_vec = tvm.compute(kvshape, lambda co, ci, dh, dw, vc: \
        kernel[co*Q+vc][ci][dh][dw], name='kernel_vec')

    ci = tvm.reduce_axis((0, CI), name='ci')
    hk = tvm.reduce_axis((0, KH), name='hk')
    wk = tvm.reduce_axis((0, KW), name='wk')

    conv = tvm.compute(ovshape, lambda n, co, im, vim, vco: \
        tvm.sum(data_vec[n][im][ci][hk][wk][vim].astype(out_dtype) *
                kernel_vec[co][ci][hk][wk][vco].astype(out_dtype),
                axis=[ci, hk, wk]), name='conv')

    output = tvm.compute(oshape, lambda n, co, h, w: \
                         conv[n][co//Q][(h*OW+w)//P][(h*OW+w)%P][co%Q],
                         name='output_vec', tag='im2col_conv_output')

    return output