示例#1
0
def _default_dense_pack_config(cfg, M, N, K):
    """Fill ``cfg`` with the default tiling for the packed dense schedule.

    * ``tile_x`` (over N): the innermost factor is the largest divisor of N
      not exceeding twice ``get_fp32_len()``; the middle factor is doubled
      while the remaining outer extent is even and greater than 4.
    * ``tile_y`` (over M): the innermost factor is the largest of 8, 4, 2, 1
      dividing M; the middle factor grows the same way as for x.
    * ``tile_k``: K is left unsplit (``[K, 1]``).
    """
    inner_x_cap = get_fp32_len() * 2
    inner_x = next((f for f in range(inner_x_cap, 0, -1) if N % f == 0), 1)
    rest_x = N // inner_x
    mid_x = 1
    while rest_x // mid_x > 4 and (rest_x // mid_x) % 2 == 0:
        mid_x *= 2

    inner_y = 8
    while M % inner_y != 0:
        inner_y //= 2
    rest_y = M // inner_y
    mid_y = 1
    while rest_y // mid_y > 4 and (rest_y // mid_y) % 2 == 0:
        mid_y *= 2

    cfg["tile_y"] = SplitEntity([rest_y // mid_y, mid_y, inner_y])
    cfg["tile_x"] = SplitEntity([rest_x // mid_x, mid_x, inner_x])
    cfg["tile_k"] = SplitEntity([K, 1])
示例#2
0
def _fallback_schedule_int8(cfg, wkl):
    """Populate ``cfg`` with the default int8 schedule for workload ``wkl``.

    Output channels are blocked by 16 (asserted to divide ``out_filter``);
    input channels by the first of 16, 12, 8, 4 dividing ``in_filter``
    (which is asserted to be a multiple of 4); the output width by its
    largest divisor not exceeding 31. ``unroll_kw`` is disabled.
    """
    out_width = (wkl.width + 2 * wkl.wpad - wkl.wkernel) // wkl.wstride + 1

    oc_bn = 16
    assert wkl.out_filter % oc_bn == 0

    ic_bn = next((cand for cand in range(oc_bn, 0, -4)
                  if wkl.in_filter % cand == 0), 1)
    assert wkl.in_filter % 4 == 0

    reg_n = next((cand for cand in range(31, 0, -1) if out_width % cand == 0), 1)

    cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn])
    cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn])
    cfg["tile_ow"] = SplitEntity([out_width // reg_n, reg_n])
    cfg["unroll_kw"] = OtherOptionEntity(False)
示例#3
0
def _default_dense_pack_config(cfg, M, N, K):
    """Fill ``cfg`` with the default tiling for the packed dense schedule.

    Symbolic extents (``tvm.expr.Var``) are first replaced by 16 so that a
    default schedule can still be generated for dynamic shapes. Then:

    * ``tile_x`` (over N): innermost factor is the largest divisor of N not
      exceeding twice ``get_fp32_len()``; the middle factor is doubled while
      the remaining outer extent is even and greater than 4.
    * ``tile_y`` (over M): innermost factor is the largest of 8, 4, 2, 1
      dividing M; the middle factor grows the same way as for x.
    * ``tile_k``: K is left unsplit (``[K, 1]``).
    """
    M, N, K = (16 if isinstance(dim, tvm.expr.Var) else dim for dim in (M, N, K))

    inner_x_cap = get_fp32_len() * 2
    inner_x = next((f for f in range(inner_x_cap, 0, -1) if N % f == 0), 1)
    rest_x = N // inner_x
    mid_x = 1
    while rest_x // mid_x > 4 and (rest_x // mid_x) % 2 == 0:
        mid_x *= 2

    inner_y = 8
    while M % inner_y != 0:
        inner_y //= 2
    rest_y = M // inner_y
    mid_y = 1
    while rest_y // mid_y > 4 and (rest_y // mid_y) % 2 == 0:
        mid_y *= 2

    cfg["tile_y"] = SplitEntity([rest_y // mid_y, mid_y, inner_y])
    cfg["tile_x"] = SplitEntity([rest_x // mid_x, mid_x, inner_x])
    cfg["tile_k"] = SplitEntity([K, 1])
示例#4
0
def _fallback_schedule_int8(cfg, wkl):
    """Populate ``cfg`` with the default int8 schedule for workload ``wkl``.

    Output width uses the left/right padding, kernel width and width stride
    of ``wkl``. Output channels are blocked by 16 (asserted to divide
    ``out_filter``); input channels by the first of 16, 12, 8, 4 dividing
    ``in_filter`` (asserted to be a multiple of 4); the output width by its
    largest divisor not exceeding 31. ``unroll_kw`` is disabled.
    """
    out_width = (wkl.width + wkl.padl + wkl.padr - wkl.kernel_w) // wkl.stride_w + 1

    oc_bn = 16
    assert wkl.out_filter % oc_bn == 0

    ic_bn = next((cand for cand in range(oc_bn, 0, -4)
                  if wkl.in_filter % cand == 0), 1)
    assert wkl.in_filter % 4 == 0

    reg_n = next((cand for cand in range(31, 0, -1) if out_width % cand == 0), 1)

    cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn])
    cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn])
    cfg["tile_ow"] = SplitEntity([out_width // reg_n, reg_n])
    cfg["unroll_kw"] = OtherOptionEntity(False)
示例#5
0
def _fallback_schedule(cfg, wkl):
    """Populate ``cfg`` with a default spatially tiled schedule for ``wkl``.

    Channel blocks: ``oc_bn`` is the largest divisor of ``out_filter`` not
    exceeding ``get_fp32_len()``; ``ic_bn`` is the largest divisor of
    ``in_filter`` not exceeding ``oc_bn``. The spatial output tile is the
    largest divisor ``ow_factor`` of the output width, paired with the
    largest divisor ``oh_factor`` of the output height, such that
    ``ow_factor * oh_factor < 32``.

    Raises
    ------
    ValueError
        If no such tile exists (only possible when an output extent is < 1).
    """
    simd_width = get_fp32_len()
    HPAD, WPAD = wkl.hpad, wkl.wpad
    HSTR, WSTR = wkl.hstride, wkl.wstride
    out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1
    out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1

    oc_bn = 1
    for bn in range(simd_width, 0, -1):
        if wkl.out_filter % bn == 0:
            oc_bn = bn
            break

    ic_bn = 1
    for bn in range(oc_bn, 0, -1):
        if wkl.in_filter % bn == 0:
            ic_bn = bn
            break

    # For positive ints, ow_factor * oh_factor < 32 <=> oh_factor <= 31 // ow_factor.
    # Capping both ranges up front visits exactly the candidates that can
    # qualify, in the same descending order, instead of scanning all
    # out_width * out_height pairs; the chosen tile is unchanged.
    max_tile = 31
    for ow_factor in range(min(out_width, max_tile), 0, -1):
        if out_width % ow_factor != 0:
            continue
        for oh_factor in range(min(out_height, max_tile // ow_factor), 0, -1):
            if out_height % oh_factor == 0:
                cfg["tile_ic"] = SplitEntity(
                    [wkl.in_filter // ic_bn, ic_bn])
                cfg["tile_oc"] = SplitEntity(
                    [wkl.out_filter // oc_bn, oc_bn])
                cfg["tile_oh"] = OtherOptionEntity(oh_factor)
                cfg["tile_ow"] = SplitEntity(
                    [out_width // ow_factor, ow_factor])
                return
    raise ValueError(
        "cannot decide default schedule for workload: {}".format(wkl))
示例#6
0
def _fallback_schedule(cfg, wkl):
    """Populate ``cfg`` with a default schedule for workload ``wkl``.

    ``oc_bn`` is the largest divisor of ``out_filter`` within the fp32 SIMD
    width, ``ic_bn`` the largest divisor of ``in_filter`` within ``oc_bn``,
    and the output width is blocked by its largest divisor not exceeding 7.
    Only the width padding, kernel and stride of ``wkl`` enter the output
    width computation. ``unroll_kw`` is disabled.
    """
    lanes = get_fp32_len()
    out_width = (wkl.width + 2 * wkl.wpad - wkl.wkernel) // wkl.wstride + 1

    def largest_divisor(value, upper):
        # Largest d in [1, upper] with value % d == 0; 1 when upper < 1.
        return next((d for d in range(upper, 0, -1) if value % d == 0), 1)

    oc_bn = largest_divisor(wkl.out_filter, lanes)
    ic_bn = largest_divisor(wkl.in_filter, oc_bn)
    reg_n = largest_divisor(out_width, 7)

    cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn])
    cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn])
    cfg["tile_ow"] = SplitEntity([out_width // reg_n, reg_n])
    cfg["unroll_kw"] = OtherOptionEntity(False)
示例#7
0
def _fallback_schedule(cfg, wkl):
    """Populate ``cfg`` with a default schedule for workload ``wkl``.

    The output width accounts for kernel dilation. ``oc_bn`` is the largest
    divisor of ``out_filter`` within the 32-bit SIMD lane count, ``ic_bn``
    the largest divisor of ``in_filter`` within ``oc_bn``, and the output
    width is blocked by its largest divisor not exceeding 31.
    ``unroll_kw`` is disabled.
    """
    lanes = get_simd_32bit_lanes()
    kernel_extent_w = (wkl.kernel_w - 1) * wkl.dilation_w + 1
    out_width = (wkl.width + wkl.padl + wkl.padr - kernel_extent_w) // wkl.stride_w + 1

    def largest_divisor(value, upper):
        # Largest d in [1, upper] with value % d == 0; 1 when upper < 1.
        for d in range(upper, 0, -1):
            if value % d == 0:
                return d
        return 1

    oc_bn = largest_divisor(wkl.out_filter, lanes)
    ic_bn = largest_divisor(wkl.in_filter, oc_bn)
    reg_n = largest_divisor(out_width, 31)

    cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn])
    cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn])
    cfg["tile_ow"] = SplitEntity([out_width // reg_n, reg_n])
    cfg["unroll_kw"] = OtherOptionEntity(False)
示例#8
0
def _default_dense_nopack_config(cfg, M, N, K):
    """Fill ``cfg`` with the default tiling for the non-packed dense schedule.

    K is split by its largest divisor not exceeding twice ``get_fp32_len()``;
    N is kept whole as ``[N, 1]`` and M as ``[1, M]``.
    """
    k_cap = 2 * get_fp32_len()
    k_inner = next((f for f in range(k_cap, 0, -1) if K % f == 0), 1)
    cfg["tile_k"] = SplitEntity([K // k_inner, k_inner])
    cfg["tile_x"] = SplitEntity([N, 1])
    cfg["tile_y"] = SplitEntity([1, M])
示例#9
0
def fallback_schedule_cpu_1x1_int8(cfg, wkl, int32_lanes, num_int8_elements):
    """Fallback schedule for 1x1 conv2d int8 on cpu.
    Normally the inner most pattern takes two int8/uint8 tensors
    data[num_int8_elements] and kernel[int32_lanes, num_int8_elements],
    produces a dot product int32/uint32 output[int32_lanes].

    Parameters
    ----------
    cfg : FallbackConfigEntity
        Config updated in place (``tile_ic``, ``tile_oc``, ``tile_oh``,
        ``tile_ow``).
    wkl : Workload
        Convolution workload.
    int32_lanes : int
        How many numbers of int32/uint32 will be produced using intrinsic.
        This is related to output channel.
    num_int8_elements : int
        How many int8/uint8 input elements are multiplied and reduced into
        each int32/uint32 output. This is related to input channel.

    Raises
    ------
    AssertionError
        If ``out_filter`` is not a multiple of ``int32_lanes`` or
        ``in_filter`` is not a multiple of ``num_int8_elements``.
    ValueError
        If no spatial tile exists (only possible when an output extent is < 1).
    """
    pt, pl, pb, pr = wkl.padt, wkl.padl, wkl.padb, wkl.padr
    HSTR, WSTR = wkl.stride_h, wkl.stride_w
    out_height = (wkl.height + pt + pb - wkl.kernel_h) // HSTR + 1
    out_width = (wkl.width + pl + pr - wkl.kernel_w) // WSTR + 1

    assert wkl.out_filter % int32_lanes == 0, "wkl.out_filter=%d, int32_lanes=%d" % (
        wkl.out_filter,
        int32_lanes,
    )
    assert wkl.in_filter % num_int8_elements == 0, "wkl.in_filter=%d, num_int8_elements=%d" % (
        wkl.in_filter,
        num_int8_elements,
    )

    oc_bn = max(int32_lanes, num_int8_elements)
    # The intrinsic consumes num_int8_elements input channels at a time, so
    # only multiples of num_int8_elements (up to oc_bn) are valid input-channel
    # blocks. With num_int8_elements == 4 and oc_bn a multiple of 4 this is
    # oc_bn, oc_bn - 4, ..., 4. The in_filter assertion above guarantees the
    # loop finds at least num_int8_elements.
    ic_bn = 1
    for bn in range(oc_bn - oc_bn % num_int8_elements, 0, -num_int8_elements):
        if wkl.in_filter % bn == 0:
            ic_bn = bn
            break

    # Pick the largest ow_factor, then the largest oh_factor, dividing the
    # output extents with ow_factor * oh_factor < 32. For positive ints that
    # is oh_factor <= 31 // ow_factor, so capping the ranges up front visits
    # the same qualifying candidates in the same order without scanning
    # factors that can never qualify.
    max_tile = 31
    for ow_factor in range(min(out_width, max_tile), 0, -1):
        if out_width % ow_factor != 0:
            continue
        for oh_factor in range(min(out_height, max_tile // ow_factor), 0, -1):
            if out_height % oh_factor == 0:
                cfg["tile_ic"] = SplitEntity(
                    [wkl.in_filter // ic_bn, ic_bn])
                cfg["tile_oc"] = SplitEntity(
                    [wkl.out_filter // oc_bn, oc_bn])
                cfg["tile_oh"] = OtherOptionEntity(oh_factor)
                cfg["tile_ow"] = SplitEntity(
                    [out_width // ow_factor, ow_factor])
                return
    raise ValueError(
        "cannot decide default schedule for workload: {}".format(wkl))
示例#10
0
def schedule_dense_small_batch(cfg, s, C):
    """Schedule float32/64 dense with small batch size

    The reduction axis of ``C`` is split with the ``tile_k`` knob and the
    inner part is rfactored into a separate stage. The remaining reduction is
    bound to ``threadIdx.x``, so threads cooperate on each output element.
    The first two axes of the output stage are bound to ``blockIdx.y`` and
    ``blockIdx.x``.

    Parameters
    ----------
    cfg : autotvm config entity
        ``tile_k`` is defined (and given a fallback value) here.
    s : Schedule
        Schedule modified in place.
    C : Tensor
        The dense compute stage. Its first input is unpacked as a 2-D shape
        ``(_, in_dim)`` below; presumably (batch, in_dim) -- TODO confirm.
    """
    A, _ = C.op.input_tensors
    _, in_dim = get_const_tuple(A.shape)
    cfg.define_split('tile_k', in_dim, num_outputs=2)
    if cfg.is_fallback:
        # Inner factor 64; when in_dim <= 64 the outer part is pinned to 1.
        cfg["tile_k"] = SplitEntity([-1, 64] if in_dim > 64 else [1, 64])

    _, kf = cfg['tile_k'].apply(s, C, C.op.reduce_axis[0])
    # Factor the inner reduction chunk out into its own stage CF.
    CF = s.rfactor(C, kf)

    if C.op in s.outputs:
        Out = C
    else:
        # C feeds a further output stage; compute C inside Out's second axis.
        Out = s.outputs[0].output(0)
        s[C].compute_at(s[Out], s[Out].op.axis[1])
    s[Out].bind(s[Out].op.axis[0], tvm.thread_axis("blockIdx.y"))
    s[Out].bind(s[Out].op.axis[1], tvm.thread_axis("blockIdx.x"))

    # After rfactor, C reduces over CF's partial results; bind that reduction
    # to threadIdx.x so it becomes a cross-thread reduction.
    tx = s[C].op.reduce_axis[0]
    thread_x = tvm.thread_axis("threadIdx.x")
    s[C].bind(tx, thread_x)
    s[CF].compute_at(s[C], tx)
    # Only thread 0 writes the reduced result.
    s[C].set_store_predicate(thread_x.var.equal(0))
    s[Out].set_store_predicate(thread_x.var.equal(0))
示例#11
0
    def _callback(op):
        """Schedule the ``sparse_dense_bsrmm`` stage; other ops are untouched.

        Uses ``s`` and ``cfg`` from the enclosing scope. The second reduce
        axis of the block-level stage is split by ``tile_c`` and rfactored,
        and the remaining reduction is bound to ``threadIdx.x``.
        """
        if op.tag == "sparse_dense_bsrmm":
            y_bsrmm = op.input_tensors[0]
            assert y_bsrmm.op.tag == "sparse_dense_bsrmm_block"
            out = s.outputs[0].output(0)
            # The block stage has two reduce axes; only the second (c) is split.
            (_, c) = s[y_bsrmm].op.reduce_axis

            # The final output is 2-D; map its axes onto the block grid.
            (m_o, n_o) = s[out].op.axis
            s[out].bind(m_o, te.thread_axis("blockIdx.x"))
            s[out].bind(n_o, te.thread_axis("blockIdx.y"))
            s[y_bsrmm].compute_at(s[out], n_o)

            thread_x = te.thread_axis("threadIdx.x")

            cfg.define_split("tile_c", c, num_outputs=2)
            if cfg.is_fallback:
                cfg["tile_c"] = SplitEntity([-1, 8])
            _, ci = cfg['tile_c'].apply(s, y_bsrmm, c)

            # Factor the inner chunk of c into its own stage; y_bsrmm then
            # reduces over the partial results, which is bound to threads.
            y_bsrmm_factored = s.rfactor(y_bsrmm, ci)
            tx = s[y_bsrmm].op.reduce_axis[0]
            s[y_bsrmm].bind(tx, thread_x)
            s[y_bsrmm_factored].compute_at(s[y_bsrmm], tx)
            # Only thread 0 stores the reduced value and the final output.
            s[y_bsrmm].set_store_predicate(thread_x.var.equal(0))
            s[out].set_store_predicate(thread_x.var.equal(0))
示例#12
0
def fallback_schedule_cpu_common_int8(cfg, wkl, int32_lanes,
                                      num_int8_elements):
    """Fallback schedule for conv2d int8 on cpu.
    Normally the inner most pattern takes two int8/uint8 tensors
    data[num_int8_elements] and kernel[int32_lanes, num_int8_elements],
    produces a dot product int32/uint32 output[int32_lanes].

    Parameters
    ----------
    cfg : FallbackConfigEntity
        Config updated in place (``tile_ic``, ``tile_oc``, ``tile_ow``,
        ``unroll_kw``).
    wkl : Workload
        Convolution workload; the output width accounts for kernel dilation.
    int32_lanes : int
        How many numbers of int32/uint32 will be produced using intrinsic.
        This is related to output channel.
    num_int8_elements : int
        How many int8/uint8 input elements are multiplied and reduced into
        each int32/uint32 output. This is related to input channel.

    Raises
    ------
    AssertionError
        If ``out_filter`` is not a multiple of ``int32_lanes`` or
        ``in_filter`` is not a multiple of ``num_int8_elements``.
    """
    pl, pr = wkl.padl, wkl.padr
    WSTR = wkl.stride_w
    dilated_kernel_w = (wkl.kernel_w - 1) * wkl.dilation_w + 1
    out_width = (wkl.width + pl + pr - dilated_kernel_w) // WSTR + 1

    assert wkl.out_filter % int32_lanes == 0, "wkl.out_filter=%d, int32_lanes=%d" % (
        wkl.out_filter,
        int32_lanes,
    )
    assert wkl.in_filter % num_int8_elements == 0, "wkl.in_filter=%d, num_int8_elements=%d" % (
        wkl.in_filter,
        num_int8_elements,
    )

    oc_bn = max(int32_lanes, num_int8_elements)
    # The intrinsic consumes num_int8_elements input channels at a time, so
    # only multiples of num_int8_elements (up to oc_bn) are valid input-channel
    # blocks. With num_int8_elements == 4 and oc_bn a multiple of 4 this is
    # oc_bn, oc_bn - 4, ..., 4. The in_filter assertion above guarantees the
    # loop finds at least num_int8_elements.
    ic_bn = 1
    for bn in range(oc_bn - oc_bn % num_int8_elements, 0, -num_int8_elements):
        if wkl.in_filter % bn == 0:
            ic_bn = bn
            break

    # Block the output width by its largest divisor not exceeding 31.
    reg_n = 1
    for n in range(31, 0, -1):
        if out_width % n == 0:
            reg_n = n
            break

    cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn])
    cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn])
    cfg["tile_ow"] = SplitEntity([out_width // reg_n, reg_n])
    cfg["unroll_kw"] = OtherOptionEntity(False)
示例#13
0
def _default_dense_nopack_config(cfg, M, N, K):
    """Fill ``cfg`` with the default tiling for the non-packed dense schedule.

    Symbolic extents (``tvm.expr.Var``) are first replaced by 16 so that a
    default schedule exists for dynamic shapes. K is then split by its
    largest divisor not exceeding twice ``get_fp32_len()``; N is kept whole
    as ``[N, 1]`` and M as ``[1, M]``.
    """
    M, N, K = (16 if isinstance(dim, tvm.expr.Var) else dim for dim in (M, N, K))

    k_cap = 2 * get_fp32_len()
    k_inner = next((f for f in range(k_cap, 0, -1) if K % f == 0), 1)
    cfg["tile_k"] = SplitEntity([K // k_inner, k_inner])
    cfg["tile_x"] = SplitEntity([N, 1])
    cfg["tile_y"] = SplitEntity([1, M])
示例#14
0
def _get_default_config(cfg,
                        data,
                        kernel,
                        strides,
                        padding,
                        out_dtype,
                        is_depthwise=False):
    """Populate ``cfg`` with hand-picked defaults for Intel graphics conv2d.

    ``ic_bn`` is always 1; ``oc_bn`` is the largest divisor of the output
    channel count not exceeding 16. The output block size
    (``block_oh`` x ``block_ow``) is chosen from the height stride, kernel
    height and output channel count.

    Raises
    ------
    RuntimeError
        If ``is_depthwise`` is set.
    """
    if is_depthwise:
        raise RuntimeError("Depthwise not supported for intel graphics.")

    _, in_channel, _, _ = get_const_tuple(data.shape)
    out_channel, _, hkernel, _ = get_const_tuple(kernel.shape)
    HSTR, _ = strides

    ic_bn = 1
    oc_bn = next((i for i in range(16, 0, -1) if out_channel % i == 0), 16)

    if HSTR == 2:
        block_oh, block_ow = (4, 4) if out_channel + hkernel == 515 else (4, 5)
    elif hkernel == 3:
        block_oh, block_ow = (2, 7) if out_channel == 512 else (2, 14)
    else:
        block_oh, block_ow = 1, 16

    cfg["tile_ic"] = SplitEntity([in_channel // ic_bn, ic_bn])
    cfg["tile_oc"] = SplitEntity([out_channel // oc_bn, oc_bn])
    cfg["block_oh"] = OtherOptionEntity(block_oh)
    cfg["block_ow"] = OtherOptionEntity(block_ow)
示例#15
0
def get_default_conv2d_config(cfg, fc, y, x):
    """Defines conv2d default parameters for split axis for Adreno conv2d and depthwise conv2d

    Virtual-thread factors ``vy``/``vx`` are the largest divisors of y/x not
    exceeding 5 (with ``vy * vx < 9``). Thread factors are then taken from
    what remains: ``tfc`` divides fc (<= 64), and ``ty``/``tx`` divide the
    remaining y/x (<= 16) while keeping ``tfc * ty * tx <= 512``.
    """
    def pick(extent, upper, fits):
        # Largest cand in [1, upper] dividing extent with fits(cand); else 1.
        for cand in range(upper, 0, -1):
            if extent % cand == 0 and fits(cand):
                return cand
        return 1

    vy = pick(y, 5, lambda cand: True)
    vx = pick(x, 5, lambda cand: vy * cand < 9)
    y_rest = y // vy
    x_rest = x // vx

    tfc = pick(fc, 64, lambda cand: True)
    ty = pick(y_rest, 16, lambda cand: tfc * cand <= 512)
    tx = pick(x_rest, 16, lambda cand: tfc * ty * cand <= 512)

    cfg["tile_fc"] = SplitEntity([fc // tfc, 1, tfc])
    cfg["tile_y"] = SplitEntity([y_rest // ty, vy, ty])
    cfg["tile_x"] = SplitEntity([x_rest // tx, vx, tx])
示例#16
0
def _fallback_schedule(cfg, wkl):
    """
    Get default schedule for the workload
    Parameters
    ----------
    cfg : tvm.autotvm.task.space.FallbackConfigEntity
        Fallback config to be updated
    wkl : topi.nn.depthwise_conv2d.Workload
        Convolution workload

    Channel blocks are the largest divisors of ``out_filter`` within the
    32-bit SIMD lane count and of ``in_filter`` within that output block;
    the (dilation-aware) output width is blocked by its largest divisor not
    exceeding 31. ``unroll_kw`` is disabled.
    """
    lanes = get_simd_32bit_lanes()
    span_w = (wkl.kernel_w - 1) * wkl.dilation_w + 1
    out_width = (wkl.width - span_w + wkl.padl + wkl.padr) // wkl.stride_w + 1

    oc_bn = next((d for d in range(lanes, 0, -1) if wkl.out_filter % d == 0), 1)
    ic_bn = next((d for d in range(oc_bn, 0, -1) if wkl.in_filter % d == 0), 1)
    reg_n = next((d for d in range(31, 0, -1) if out_width % d == 0), 1)

    cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn])
    cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn])
    cfg["tile_ow"] = SplitEntity([out_width // reg_n, reg_n])
    cfg["unroll_kw"] = OtherOptionEntity(False)
示例#17
0
 def _fallback_schedule(N, F, Y, X):
     """Fill the enclosing ``cfg`` with a fixed fallback schedule.

     Batch (N) gets an inner factor of 4 when N > 1; output channels (F) are
     split only when F > 1; Y and X use their smallest divisor in [5, 16]
     (or 1 if none); the reduction over input channels uses an inner 16.
     """
     # pylint: disable=unused-argument
     cfg["tile_n"] = SplitEntity([-1, 1, 1, 4] if N > 1 else [1, 1, 1, 1])
     if F > 1:
         cfg["tile_f"] = SplitEntity([-1, 1, 64, 1])

     def _first_divisor(extent):
         # Smallest candidate in [5, 16] that divides extent, else 1.
         return next((cand for cand in range(5, 17) if extent % cand == 0), 1)

     cfg["tile_y"] = SplitEntity([-1, 1, 1, _first_divisor(Y)])
     cfg["tile_x"] = SplitEntity([-1, _first_divisor(X), 1, 1])
     # RC is the input-channel reduction axis.
     cfg["tile_rc"] = SplitEntity([-1, 1, 16])
     cfg["fuse_yx"] = OtherOptionEntity(False)
     cfg["unroll_explicit"] = OtherOptionEntity(True)
     cfg["auto_unroll_max_step"] = OtherOptionEntity(1500)
示例#18
0
def _fallback_schedule(cfg, wkl):
    """
    Get default schedule for the workload
    Parameters
    ----------
    cfg : tvm.autotvm.task.space.FallbackConfigEntity
        Fallback config to be updated
    wkl : topi.nn.depthwise_conv2d.Workload
        Convolution workload

    ``oc_bn`` is the largest divisor of ``out_filter`` within the fp32 SIMD
    width, ``ic_bn`` the largest divisor of ``in_filter`` within ``oc_bn``,
    and the output width is blocked by its largest divisor not exceeding 31.
    ``unroll_kw`` is disabled.
    """
    lanes = get_fp32_len()
    out_width = (wkl.width + 2 * wkl.wpad - wkl.wkernel) // wkl.wstride + 1

    def largest_divisor(value, upper):
        # Largest d in [1, upper] with value % d == 0; 1 when upper < 1.
        return next((d for d in range(upper, 0, -1) if value % d == 0), 1)

    oc_bn = largest_divisor(wkl.out_filter, lanes)
    ic_bn = largest_divisor(wkl.in_filter, oc_bn)
    reg_n = largest_divisor(out_width, 31)

    cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn])
    cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn])
    cfg["tile_ow"] = SplitEntity([out_width // reg_n, reg_n])
    cfg["unroll_kw"] = OtherOptionEntity(False)
示例#19
0
def _fallback_schedule(cfg, wkl):
    """Populate ``cfg`` with a default grouped-convolution schedule for ``wkl``.

    The SIMD width is assumed to be 4. ``oc_bn`` is the largest divisor of the
    per-group output channel count not exceeding 4, and ``ic_bn`` the largest
    divisor of the per-group input depth not exceeding ``oc_bn``; each is
    then clamped to its per-group count. The output width is blocked by its
    largest divisor not exceeding 31. ``unroll_kw`` is disabled.
    """
    simd_width = 4  # assume ARM SIMD Width is 4
    out_width = (wkl.width + wkl.padl + wkl.padr - wkl.kernel_w) // wkl.stride_w + 1
    kernels_per_group = wkl.out_filter // wkl.groups
    kernel_depth = wkl.in_filter // wkl.groups

    def largest_divisor(value, upper):
        # Largest d in [1, upper] with value % d == 0; 1 when upper < 1.
        return next((d for d in range(upper, 0, -1) if value % d == 0), 1)

    # A divisor never exceeds a positive count, so the clamps only matter in
    # degenerate cases such as a per-group count of 0.
    oc_bn = min(largest_divisor(kernels_per_group, simd_width), kernels_per_group)
    ic_bn = min(largest_divisor(kernel_depth, oc_bn), kernel_depth)
    reg_n = largest_divisor(out_width, 31)

    cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn])
    cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn])
    cfg["tile_ow"] = SplitEntity([out_width // reg_n, reg_n])
    cfg["unroll_kw"] = OtherOptionEntity(False)
示例#20
0
def _fallback_schedule(cfg, wkl):
    """Populate ``cfg`` with a default spatially tiled schedule for ``wkl``.

    Output sizes account for kernel dilation. ``oc_bn`` is the largest
    divisor of ``out_filter`` within the 32-bit SIMD lane count; ``ic_bn``
    the largest divisor of ``in_filter`` within ``oc_bn``. The spatial output
    tile is the largest divisor ``ow_factor`` of the output width, paired
    with the largest divisor ``oh_factor`` of the output height, such that
    ``ow_factor * oh_factor < 32``.

    Raises
    ------
    ValueError
        If no such tile exists (only possible when an output extent is < 1).
    """
    simd_width = get_simd_32bit_lanes()
    pt, pl, pb, pr = wkl.padt, wkl.padl, wkl.padb, wkl.padr
    HSTR, WSTR = wkl.stride_h, wkl.stride_w
    dilated_kernel_h = (wkl.kernel_h - 1) * wkl.dilation_h + 1
    dilated_kernel_w = (wkl.kernel_w - 1) * wkl.dilation_w + 1

    out_height = (wkl.height + pt + pb - dilated_kernel_h) // HSTR + 1
    out_width = (wkl.width + pl + pr - dilated_kernel_w) // WSTR + 1

    oc_bn = 1
    for bn in range(simd_width, 0, -1):
        if wkl.out_filter % bn == 0:
            oc_bn = bn
            break

    ic_bn = 1
    for bn in range(oc_bn, 0, -1):
        if wkl.in_filter % bn == 0:
            ic_bn = bn
            break

    # For positive ints, ow_factor * oh_factor < 32 <=> oh_factor <= 31 // ow_factor.
    # Capping both ranges up front visits exactly the candidates that can
    # qualify, in the same descending order, instead of scanning all
    # out_width * out_height pairs; the chosen tile is unchanged.
    max_tile = 31
    for ow_factor in range(min(out_width, max_tile), 0, -1):
        if out_width % ow_factor != 0:
            continue
        for oh_factor in range(min(out_height, max_tile // ow_factor), 0, -1):
            if out_height % oh_factor == 0:
                cfg["tile_ic"] = SplitEntity(
                    [wkl.in_filter // ic_bn, ic_bn])
                cfg["tile_oc"] = SplitEntity(
                    [wkl.out_filter // oc_bn, oc_bn])
                cfg["tile_oh"] = OtherOptionEntity(oh_factor)
                cfg["tile_ow"] = SplitEntity(
                    [out_width // ow_factor, ow_factor])
                return
    raise ValueError(
        "cannot decide default schedule for workload: {}".format(wkl))
示例#21
0
def _schedule_dense_small_batch(cfg, s, C):
    """Schedule a dense stage ``C`` for small batch sizes.

    The reduction axis is split (with the ``tile_k`` knob when the reduction
    length is a known int, otherwise with a fixed factor of 64), the inner
    part is rfactored, and the remaining reduction is bound to
    ``threadIdx.x``. The first two axes of the output stage are bound to
    ``blockIdx.y`` / ``blockIdx.x``.

    Parameters
    ----------
    cfg : autotvm config entity
        ``tile_k`` is defined here when the reduction length is known.
    s : te.Schedule
        Schedule modified in place.
    C : te.Tensor
        Dense compute stage whose inputs are ``(A, weights)``; both shapes
        are unpacked as 2-D below.
    """
    A, weights = C.op.input_tensors
    # If weights is computed from A alone, inline that computation.
    if len(weights.op.input_tensors) == 1 and weights.op.input_tensors[0] == A:
        s[weights].compute_inline()

    _, in_dim_weights = get_const_tuple(weights.shape)
    _, in_dim_A = get_const_tuple(A.shape)

    # Use a statically known reduction length from either operand; a
    # dynamic dim is presumably returned as a non-int expression.
    if isinstance(in_dim_A, int):
        in_dim = in_dim_A
    elif isinstance(in_dim_weights, int):
        in_dim = in_dim_weights
    else:
        in_dim = None

    if in_dim is not None:
        cfg.define_split("tile_k", in_dim, num_outputs=2)
        if cfg.is_fallback:
            # Inner factor 64; when in_dim <= 64 the outer part is pinned to 1.
            cfg["tile_k"] = SplitEntity([-1, 64] if in_dim > 64 else [1, 64])
        _, kf = cfg["tile_k"].apply(s, C, C.op.reduce_axis[0])
    else:
        # Unknown reduction length: no tunable knob, split by a fixed 64.
        tile_k = 64
        _, kf = s[C].split(C.op.reduce_axis[0], tile_k)

    # Factor the inner reduction chunk out into its own stage CF.
    CF = s.rfactor(C, kf)

    if C.op in s.outputs:
        Out = C
    else:
        # C feeds a further output stage; compute C inside Out's second axis.
        Out = s.outputs[0].output(0)
        s[C].compute_at(s[Out], s[Out].op.axis[1])
    s[Out].bind(s[Out].op.axis[0], te.thread_axis("blockIdx.y"))
    s[Out].bind(s[Out].op.axis[1], te.thread_axis("blockIdx.x"))

    # After rfactor, C reduces over CF's partial results; bind that reduction
    # to threadIdx.x so it becomes a cross-thread reduction.
    tx = s[C].op.reduce_axis[0]
    thread_x = te.thread_axis("threadIdx.x")
    s[C].bind(tx, thread_x)
    s[CF].compute_at(s[C], tx)
    # Only thread 0 writes the reduced result.
    s[C].set_store_predicate(thread_x.var.equal(0))
    s[Out].set_store_predicate(thread_x.var.equal(0))
示例#22
0
def schedule_depthwise_conv2d_nhwc(cfg, outs):
    """Create the schedule for depthwise_conv2d_nhwc.

    Parameters
    ----------
    cfg : autotvm config entity
        Receives the knobs ``tile_c``, ``tile_h``, ``tile_w``,
        ``locate_output`` and ``unroll_tile`` (plus ``data_pad_strategy``
        when the convolution input is a ``data_pad`` stage).
    outs : te.tensor.Tensor or list of te.tensor.Tensor
        Output(s) of the computation; ``outs[0]`` is the stage scheduled.

    Returns
    -------
    s : te.Schedule
        The created schedule.
    """
    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
    s = te.create_schedule([x.op for x in outs])
    out = outs[0]

    ##### space definition begin #####
    n, h, w, c = s[out].op.axis
    # Split the number of input/output channels
    cfg.define_split("tile_c", c, num_outputs=2)
    # Split the height of the convolution
    cfg.define_split("tile_h", h, num_outputs=2)
    # Split the width of the convolution
    cfg.define_split("tile_w", w, num_outputs=2)
    # Additional out (e.g., requantization, bias addition, etc..)
    # 0: locate the output on the second last axis of the main computation
    # 1: locate the output closest to the main computation
    cfg.define_knob("locate_output", [0, 1])
    # Determine if we should unroll the computation of the inner tile
    cfg.define_knob("unroll_tile", [True, False])

    # fallback support
    if cfg.is_fallback:
        cfg["tile_c"] = SplitEntity([-1, 8])
        cfg["tile_h"] = SplitEntity([-1, 2])
        cfg["tile_w"] = SplitEntity([-1, 2])
        cfg["locate_output"] = OtherOptionEntity(1)
        cfg["unroll_tile"] = OtherOptionEntity(True)
    ##### space definition end #####

    def schedule_conv(conv):
        """Schedule the depthwise convolution stage; return the fused n/ho axis."""
        conv_data = conv.op.input_tensors[0]
        kernel_data = conv.op.input_tensors[1]
        in_type = conv_data.dtype

        _, _, IC, channel_multiplier = get_const_tuple(kernel_data.shape)

        # NHWC: axis 1 is height and axis 2 is width, matching the
        # (n, h, w, c) unpacking of the output stage, so tile_h/tile_w split
        # the same dimensions here as they do on the output.
        n, h, w, c = conv.op.axis
        r_h, r_w = conv.op.reduce_axis
        ho, hi = cfg["tile_h"].apply(s, conv, h)
        wo, wi = cfg["tile_w"].apply(s, conv, w)
        co, ci = cfg["tile_c"].apply(s, conv, c)

        split_val = cfg["tile_c"].size[-1]
        use_tensorization = (
            (in_type == "int16")
            and (split_val == 8)
            and (IC % split_val == 0)
            and (channel_multiplier == 1)
            and is_aarch64_arm()
        )

        data_pad_value = -1
        if conv_data.name == "data_pad":
            assert isinstance(conv_data.op, tvm.te.ComputeOp)
            # Define a strategy for padding computation
            cfg.define_knob("data_pad_strategy", [1, 2, 3])
            if cfg.is_fallback:
                # We cannot inline padding when tensorizing.
                # So, if we can tensorize, let's compute_at the closest axis
                cfg["data_pad_strategy"] = (
                    OtherOptionEntity(2) if use_tensorization else OtherOptionEntity(3)
                )
            # Compute padding on the third to last axis of the computation
            if cfg["data_pad_strategy"].val == 1:
                s[conv_data].vectorize(list(s[conv_data].op.axis)[-1])
                s[conv_data].compute_at(s[conv], ho)
            # Compute padding on the second to last axis of the computation
            if cfg["data_pad_strategy"].val == 2:
                s[conv_data].vectorize(list(s[conv_data].op.axis)[-1])
                s[conv_data].compute_at(s[conv], wo)
            # Inline padding during computation
            if cfg["data_pad_strategy"].val == 3:
                s[conv_data].compute_inline()
            data_pad_value = cfg["data_pad_strategy"].val

        # Tensorization needs the padding materialized (not inlined).
        if use_tensorization and data_pad_value != 3:
            smlal = smlal_int16_int32()
            s[conv].tensorize(ci, smlal)
        else:
            s[conv].vectorize(ci)

        if cfg["unroll_tile"].val:
            s[conv].unroll(r_h)
            s[conv].unroll(r_w)
            s[conv].unroll(wi)
            s[conv].unroll(hi)

        s[conv].reorder(n, ho, wo, co, hi, wi, r_h, r_w, ci)
        fused_n_ho = s[conv].fuse(n, ho)
        return fused_n_ho

    def schedule_conv_out(out):
        """Tile the output stage; return (hi, wi, fused n/ho axis)."""
        n, h, w, c = out.op.axis
        co, ci = cfg["tile_c"].apply(s, out, c)
        wo, wi = cfg["tile_w"].apply(s, out, w)
        ho, hi = cfg["tile_h"].apply(s, out, h)
        s[out].reorder(n, ho, wo, co, hi, wi, ci)
        # Test the knob's value; the entity object itself is always truthy.
        if cfg["unroll_tile"].val:
            s[out].unroll(wi)
            s[out].unroll(hi)

        if out.dtype in ["int8", "uint8"]:
            # In case of quantized convolution further split the channel in batches of 4 elements
            # so that we can use arm intrinsics to run fixed_point_multiplication
            ci_outer, ci_inner = s[out].split(ci, 4)
            s[out].vectorize(ci_inner)
            s[out].unroll(ci_outer)

        fused_n_ho = s[out].fuse(n, ho)
        return hi, wi, fused_n_ho

    def _callback(op):
        if op.name == "depthwise_conv2d_nhwc_output":
            conv = op.output(0)
            if conv != out:
                # A further stage follows the convolution: tile it and
                # attach the convolution inside it per `locate_output`.
                hi, wi, p_axis = schedule_conv_out(out)
                schedule_conv(conv)
                if cfg["locate_output"].val == 0:
                    s[conv].compute_at(s[out], hi)
                if cfg["locate_output"].val == 1:
                    s[conv].compute_at(s[out], wi)
            else:
                p_axis = schedule_conv(out)

            s[out].parallel(p_axis)

    traverse_inline(s, outs[0].op, _callback)
    return s
示例#23
0
    def schedule(Apad, W, B):
        """Schedule conv2d_hwcn

        Uses ``sch`` and ``cfg`` from the enclosing scope. Inputs are staged
        through shared and then local memory caches; the output's ``fi`` and
        ``ni`` axes are each split four ways (block / vthread / thread /
        inner) by the ``tile_fi`` and ``tile_ni`` knobs.
        """
        # Padding is inlined into its consumers.
        sch[Apad].compute_inline()
        # Two-level staging of both inputs for B: shared, then local.
        AA = sch.cache_read(Apad, "shared", [B])
        WW = sch.cache_read(W, "shared", [B])
        AL = sch.cache_read(AA, "local", [B])
        WL = sch.cache_read(WW, "local", [B])

        if B.op in sch.outputs:
            # B is the final output: accumulate in a local write cache.
            Out = B
            BL = sch.cache_write(Out, "local")
        else:
            # B feeds a later output stage: keep B itself in local scope.
            Out = sch.outputs[0].output(0)
            sch[B].set_scope("local")
            BL = B

        # Output axes, named after the HWCN layout (height, width, filter, batch).
        hi, wi, fi, ni = sch[Out].op.axis

        # Create tuning space
        n_thread_cand = [1, 2, 4, 8, 16, 32]
        vthread_cand = [1, 2, 4, 8]

        # Restrict split factor 1 to vthread candidates and factor 2 to
        # thread-count candidates.
        cfg.define_split(
            'tile_fi',
            fi,
            num_outputs=4,
            filter=lambda x:
            (x.size[1] in vthread_cand and x.size[2] in n_thread_cand))
        cfg.define_split(
            'tile_ni',
            ni,
            num_outputs=4,
            filter=lambda x:
            (x.size[1] in vthread_cand and x.size[2] in n_thread_cand))

        if cfg.is_fallback:
            cfg['tile_fi'] = SplitEntity([-1, 2, 8, 4])
            cfg['tile_ni'] = SplitEntity([-1, 2, 8, 4])

        # Scheduling
        # Chunk size for the input-channel reduction (rc) loop.
        step = 8

        bz = sch[Out].fuse(hi, wi)
        by, tyz, ty, fi = cfg['tile_fi'].apply(sch, Out, fi)
        bx, txz, tx, ni = cfg['tile_ni'].apply(sch, Out, ni)
        sch[Out].reorder(bz, by, bx, tyz, txz, ty, tx, fi, ni)

        sch[Out].bind(bz, tvm.thread_axis('blockIdx.z'))
        sch[Out].bind(by, tvm.thread_axis('blockIdx.y'))
        sch[Out].bind(bx, tvm.thread_axis('blockIdx.x'))
        sch[Out].bind(tyz, tvm.thread_axis('vthread'))
        sch[Out].bind(txz, tvm.thread_axis('vthread'))
        sch[Out].bind(ty, tvm.thread_axis('threadIdx.y'))
        sch[Out].bind(tx, tvm.thread_axis('threadIdx.x'))

        # Schedule BL local write
        sch[BL].compute_at(sch[Out], tx)
        yi, xi, fi, ni = sch[BL].op.axis
        ry, rx, rc = sch[BL].op.reduce_axis
        rco, rci = sch[BL].split(rc, factor=step)
        sch[BL].reorder(rco, ry, rx, rci, fi, ni)
        # Fuse (ry, rx, rco) into one outer reduction loop at which the
        # shared-memory tiles are (re)loaded.
        fuse_index = sch[BL].fuse(ry, rx)
        fuse_index = sch[BL].fuse(fuse_index, rco)
        rx = fuse_index

        sch[AA].compute_at(sch[BL], rx)
        sch[WW].compute_at(sch[BL], rx)
        sch[AL].compute_at(sch[BL], rci)
        sch[WL].compute_at(sch[BL], rci)
        # Schedule for A's shared memory load
        # The nparts values match the threadIdx.y / threadIdx.x extents bound
        # on Out, so the whole thread block cooperates on the load.
        yi, xi, ci, ni = sch[AA].op.axis
        ty, ci = sch[AA].split(ci, nparts=cfg['tile_fi'].size[2])
        tx, ni = sch[AA].split(ni, nparts=cfg['tile_ni'].size[2])
        _, ni = sch[AA].split(ni, factor=4)
        sch[AA].reorder(ty, tx, yi, xi, ci, ni)
        sch[AA].bind(ty, tvm.thread_axis('threadIdx.y'))
        sch[AA].bind(tx, tvm.thread_axis('threadIdx.x'))
        sch[AA].vectorize(ni)
        # Schedule for W's shared memory load
        # Presumably W is laid out (kh, kw, in_channel, out_channel) -- confirm.
        yi, xi, ci, fi = sch[WW].op.axis
        ty, ci = sch[WW].split(ci, nparts=cfg['tile_fi'].size[2])
        tx, fi = sch[WW].split(fi, nparts=cfg['tile_ni'].size[2])
        _, fi = sch[WW].split(fi, factor=4)
        sch[WW].reorder(ty, tx, yi, xi, ci, fi)
        sch[WW].bind(ty, tvm.thread_axis('threadIdx.y'))
        sch[WW].bind(tx, tvm.thread_axis('threadIdx.x'))
        sch[WW].vectorize(fi)
示例#24
0
def schedule_depthwise_conv2d_nhwc(cfg, outs):
    """Create the schedule for depthwise_conv2d_nhwc.

    Parameters
    ----------
    cfg : ConfigEntity
        Tuning config. This template defines 2-way splits ``tile_c``,
        ``tile_h`` and ``tile_w`` over the output's C, H and W axes, the
        ``locate_output`` knob and, when the input is padded, the
        ``data_pad_inline`` knob.
    outs : te.Tensor or list of te.Tensor
        Output tensor(s) of the computation; ``outs[0]`` is scheduled.

    Returns
    -------
    s : te.Schedule
        The created schedule.
    """
    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
    s = te.create_schedule([x.op for x in outs])
    out = outs[0]

    ##### space definition begin #####
    n, h, w, c = s[out].op.axis
    cfg.define_split("tile_c", c, num_outputs=2)
    cfg.define_split("tile_h", h, num_outputs=2)
    cfg.define_split("tile_w", w, num_outputs=2)
    cfg.define_knob("locate_output", [0, 1])

    # fallback support
    if cfg.is_fallback:
        cfg["tile_c"] = SplitEntity([-1, 8])
        cfg["tile_h"] = SplitEntity([-1, 2])
        cfg["tile_w"] = SplitEntity([-1, 2])
        cfg["locate_output"] = OtherOptionEntity(1)
    ##### space definition end #####

    def schedule_conv(conv):
        """Tile and vectorize the depthwise conv stage; return its parallel axis."""
        conv_data = conv.op.input_tensors[0]

        # The conv output uses the same NHWC layout as ``out`` (see the space
        # definition above). Unpack in that order so that tile_h / tile_w are
        # applied to the axes whose extents they were defined on.
        n, h, w, c = conv.op.axis
        r_h, r_w = conv.op.reduce_axis
        ho, hi = cfg["tile_h"].apply(s, conv, h)
        wo, wi = cfg["tile_w"].apply(s, conv, w)
        co, ci = cfg["tile_c"].apply(s, conv, c)

        if conv_data.name == "data_pad":
            assert isinstance(conv_data.op, tvm.te.ComputeOp)
            # Define a policy for padding computation:
            # 1 -> compute at the outer H tile, 2 -> compute at the outer
            # W tile, 3 -> inline the padding into the conv.
            cfg.define_knob("data_pad_inline", [1, 2, 3])
            if cfg.is_fallback:
                cfg["data_pad_inline"] = OtherOptionEntity(3)
            if cfg["data_pad_inline"].val == 1:
                s[conv_data].vectorize(list(s[conv_data].op.axis)[-1])
                s[conv_data].compute_at(s[conv], ho)
            elif cfg["data_pad_inline"].val == 2:
                s[conv_data].vectorize(list(s[conv_data].op.axis)[-1])
                s[conv_data].compute_at(s[conv], wo)
            elif cfg["data_pad_inline"].val == 3:
                s[conv_data].compute_inline()

        # Keep the channel tile innermost so it can be vectorized.
        s[conv].reorder(n, ho, wo, co, hi, wi, r_h, r_w, ci)
        fused_n_ho = s[conv].fuse(n, ho)
        s[conv].vectorize(ci)
        return fused_n_ho

    def schedule_conv_out(out):
        """Tile the final output stage; return (hi, wi, parallel axis)."""
        n, h, w, c = out.op.axis
        co, ci = cfg["tile_c"].apply(s, out, c)
        wo, wi = cfg["tile_w"].apply(s, out, w)
        ho, hi = cfg["tile_h"].apply(s, out, h)
        s[out].reorder(n, ho, wo, co, hi, wi)

        if out.dtype in ["int8", "uint8"]:
            # In case of quantized convolution further split the channel in batches of 4 elements
            # so that we can use arm intrinsics to run fixed_point_multiplication
            _, ci_inner = s[out].split(ci, factor=4)
            s[out].vectorize(ci_inner)

        fused_n_ho = s[out].fuse(n, ho)
        return hi, wi, fused_n_ho

    def _callback(op):
        if op.name == "depthwise_conv2d_nhwc_output":
            conv = op.output(0)
            if conv != out:
                # The conv was fused with later ops: schedule the final output
                # and attach the conv inside one of its inner spatial tiles.
                hi, wi, p_axis = schedule_conv_out(out)
                schedule_conv(conv)
                if cfg["locate_output"].val == 0:
                    s[conv].compute_at(s[out], hi)
                if cfg["locate_output"].val == 1:
                    s[conv].compute_at(s[out], wi)
            else:
                p_axis = schedule_conv(out)

            s[out].parallel(p_axis)

    traverse_inline(s, outs[0].op, _callback)
    return s
示例#25
0
    def _schedule(cfg, op):
        """Schedule one batch_matmul compute op for the GPU.

        Operates on the schedule ``s`` captured from the enclosing scope.
        ``C = op.output(0)`` is a rank-3 (batch, M, N) tensor computed from
        inputs ``A`` and ``B``. Both inputs are staged through "shared" and
        then "local" caches, and the result is accumulated in a "local"
        cache before being written back.

        Parameters
        ----------
        cfg : ConfigEntity
            Tuning config; this function defines the knobs ``tile_y``,
            ``tile_x``, ``tile_k``, ``auto_unroll_max_step`` and
            ``unroll_explicit``.
        op : Operation
            The batch_matmul compute operation to schedule.
        """
        C = op.output(0)
        A, B = s[C].op.input_tensors
        # B is computed from A alone; inline it so the matmul reads A
        # directly (presumably a layout transform of A -- confirm at caller).
        if len(B.op.input_tensors) == 1 and B.op.input_tensors[0] == A:
            s[B].compute_inline()
        # C is (batch, M, N); M and N only feed the fallback tile sizes below.
        _, M, N = get_const_tuple(C.shape)
        # Stage inputs global -> shared -> local; accumulate C in local memory.
        AA = s.cache_read(A, "shared", [C])
        AL = s.cache_read(AA, "local", [C])
        BB = s.cache_read(B, "shared", [C])
        BL = s.cache_read(BB, "local", [C])
        CC = s.cache_write(C, "local")
        # If C was fused into a later op, inline it and schedule the final
        # output of the schedule in its place.
        if op not in s.outputs:
            s[C].compute_inline()
            C = s.outputs[0].output(0)

        b, y, x = s[C].op.axis
        (k, ) = s[CC].op.reduce_axis

        # Tuning space: y/x split into (block, thread, inner); k into
        # (outer, inner). Definition order fixes the config-space layout.
        cfg.define_split("tile_y", y, num_outputs=3)
        cfg.define_split("tile_x", x, num_outputs=3)
        cfg.define_split("tile_k", k, num_outputs=2)
        cfg.define_knob("auto_unroll_max_step", [8, 16, 32, 64])
        target = tvm.target.Target.current()
        if target.kind.name in ["nvptx", "rocm"]:
            # llvm-based backends cannot do non-explicit unrolling
            cfg.define_knob("unroll_explicit", [1])
        else:
            cfg.define_knob("unroll_explicit", [0, 1])

        if cfg.is_fallback:
            # Default: per-block tile from get_max_power2_factor(dim, 64)
            # (presumably the largest power-of-two divisor, capped at 64),
            # with at most 8 threads per axis.
            y_bn = get_max_power2_factor(M, 64)
            x_bn = get_max_power2_factor(N, 64)
            y_nthreads = min(y_bn, 8)
            x_nthreads = min(x_bn, 8)
            cfg["tile_x"] = SplitEntity([-1, x_nthreads, x_bn // x_nthreads])
            cfg["tile_y"] = SplitEntity([-1, y_nthreads, y_bn // y_nthreads])
            cfg["tile_k"] = SplitEntity([-1, 8])
            cfg["auto_unroll_max_step"] = OtherOptionEntity(16)

        by, ty, yi = cfg["tile_y"].apply(s, C, y)
        bx, tx, xi = cfg["tile_x"].apply(s, C, x)

        # Thread axes are shared with the cooperative loads below so that
        # every stage binds the same threadIdx.x / threadIdx.y.
        thread_x = te.thread_axis("threadIdx.x")
        thread_y = te.thread_axis("threadIdx.y")

        s[C].reorder(b, by, bx, ty, tx, yi, xi)
        s[C].bind(b, te.thread_axis("blockIdx.z"))
        s[C].bind(by, te.thread_axis("blockIdx.y"))
        s[C].bind(bx, te.thread_axis("blockIdx.x"))
        s[C].bind(ty, thread_y)
        s[C].bind(tx, thread_x)
        s[C].pragma(yi, "auto_unroll_max_step",
                    cfg["auto_unroll_max_step"].val)
        s[C].pragma(yi, "unroll_explicit", cfg["unroll_explicit"].val)

        # Each thread computes its (yi, xi) sub-tile into the local cache.
        # Shared loads happen once per ko step, local loads once per ki step.
        s[CC].compute_at(s[C], tx)
        _, yi, xi = s[CC].op.axis
        ko, ki = cfg["tile_k"].apply(s, CC, k)
        s[CC].reorder(ko, ki, yi, xi)
        s[CC].pragma(ki, "auto_unroll_max_step",
                     cfg["auto_unroll_max_step"].val)
        s[CC].pragma(ki, "unroll_explicit", cfg["unroll_explicit"].val)

        s[AA].compute_at(s[CC], ko)
        s[AL].compute_at(s[CC], ki)
        s[BB].compute_at(s[CC], ko)
        s[BL].compute_at(s[CC], ki)
        # Cooperative fetch of the A tile: rows across threadIdx.y
        # (tile_y threads), the k chunk across threadIdx.x (tile_x threads),
        # matching the thread extents bound on C above.
        _, y, k = s[AA].op.axis
        ty, yi = s[AA].split(y, nparts=cfg["tile_y"].size[1])
        tx, ki = s[AA].split(k, nparts=cfg["tile_x"].size[1])
        s[AA].reorder(ty, tx, yi, ki)
        s[AA].bind(ty, thread_y)
        s[AA].bind(tx, thread_x)
        s[AA].pragma(yi, "auto_unroll_max_step",
                     cfg["auto_unroll_max_step"].val)
        s[AA].pragma(yi, "unroll_explicit", cfg["unroll_explicit"].val)

        # Same cooperative fetch for the B tile: its rows across threadIdx.y
        # and the k chunk across threadIdx.x.
        _, x, k = s[BB].op.axis
        ty, xi = s[BB].split(x, nparts=cfg["tile_y"].size[1])
        tx, ki = s[BB].split(k, nparts=cfg["tile_x"].size[1])
        s[BB].bind(ty, thread_y)
        s[BB].bind(tx, thread_x)
        s[BB].reorder(ty, tx, xi, ki)
        s[BB].pragma(xi, "auto_unroll_max_step",
                     cfg["auto_unroll_max_step"].val)
        s[BB].pragma(xi, "unroll_explicit", cfg["unroll_explicit"].val)
示例#26
0
def schedule_dense_large_batch(cfg, s, C):
    """Schedule float32/64 dense with large batch size.

    Parameters
    ----------
    cfg : ConfigEntity
        Tuning config. Defines ``tile_x`` (4-way split of the batch axis),
        ``tile_y`` (4-way split of the output-feature axis) and ``tile_k``
        (3-way split of the reduction axis).
    s : Schedule
        The schedule, modified in place.
    C : Tensor
        The dense output, computed from ``A`` with shape (batch, in_dim) and
        ``B`` with shape (out_dim, in_dim).
    """
    A, B = C.op.input_tensors
    batch, in_dim = get_const_tuple(A.shape)
    out_dim, _ = get_const_tuple(B.shape)
    k = C.op.reduce_axis[0]

    # create tuning space
    try:
        block_cand = [64, 128]
        vthread_cand = [2**x for x in range(1, 7)]
        n_thread_cand = [2**x for x in range(3, 7)]

        def _block_filter(x):
            # Keep splits whose (vthread, thread) sizes are in the candidate
            # sets and whose per-block tile is 64 or 128 elements.
            return (x.size[1] in vthread_cand and x.size[2] in n_thread_cand
                    and (x.size[1] * x.size[2] * x.size[3]) in block_cand)

        cfg.define_split('tile_x', batch, num_outputs=4, filter=_block_filter)
        cfg.define_split('tile_y', out_dim, num_outputs=4, filter=_block_filter)
        cfg.define_split('tile_k',
                         in_dim,
                         num_outputs=3,
                         filter=lambda x: x.size[0] > 2)
    except IndexError:
        # Index error happens when no entities left after filtering, which was designed
        # to prune tuning space for better search efficiency.
        logger.debug(
            'Tuning space was created without pruning due to unfit shapes')
        cfg.define_split('tile_x', batch, num_outputs=4)
        cfg.define_split('tile_y', out_dim, num_outputs=4)
        cfg.define_split('tile_k', in_dim, num_outputs=3)

    if cfg.is_fallback:
        if batch > 1:
            cfg['tile_x'] = SplitEntity([-1, 2, 16, 2])
        else:
            cfg['tile_x'] = SplitEntity([1, 1, 1, 1])
        if out_dim > 1:
            cfg['tile_y'] = SplitEntity([-1, 2, 16, 2])
        else:
            cfg['tile_y'] = SplitEntity([1, 1, 1, 1])
        if in_dim > 8:
            cfg['tile_k'] = SplitEntity([-1, 8, 1])
        else:
            cfg['tile_k'] = SplitEntity([-1, 1, 1])

    # Explicit memory access
    AA = s.cache_read(A, "shared", [C])
    BB = s.cache_read(B, "shared", [C])
    AL = s.cache_read(AA, "local", [C])
    BL = s.cache_read(BB, "local", [C])
    CC = s.cache_write(C, "local")

    # Deal with op fusion
    if C.op not in s.outputs:
        s[C].compute_inline()
        C = s.outputs[0].output(0)

    # Split and reorder computation.
    # tile_x -> (blockIdx.x, vthread, threadIdx.x, inner) on the batch axis;
    # tile_y -> (blockIdx.y, vthread, threadIdx.y, inner) on the feature axis.
    bx, txz, tx, xi = cfg['tile_x'].apply(s, C, C.op.axis[0])
    by, tyz, ty, yi = cfg['tile_y'].apply(s, C, C.op.axis[1])
    s[C].reorder(by, bx, tyz, txz, ty, tx, yi, xi)
    s[CC].compute_at(s[C], tx)

    # Binding
    s[C].bind(by, tvm.thread_axis("blockIdx.y"))
    s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[C].bind(tyz, tvm.thread_axis("vthread"))
    s[C].bind(txz, tvm.thread_axis("vthread"))
    s[C].bind(ty, tvm.thread_axis("threadIdx.y"))
    s[C].bind(tx, tvm.thread_axis("threadIdx.x"))

    # Split reduction
    yo, xo = CC.op.axis
    ko, kt, ki = cfg['tile_k'].apply(s, CC, k)
    s[CC].reorder(ko, kt, ki, yo, xo)
    s[AA].compute_at(s[CC], ko)
    s[BB].compute_at(s[CC], ko)
    s[CC].unroll(kt)
    s[AL].compute_at(s[CC], kt)
    s[BL].compute_at(s[CC], kt)

    # The bindings on C fix the kernel's thread extents:
    # threadIdx.x spans tile_x.size[2] and threadIdx.y spans tile_y.size[2].
    # The cooperative loads below must bind each thread tag with that same
    # extent; otherwise any config where the two sizes differ (including the
    # batch == 1 / out_dim == 1 fallbacks above) gets mismatched extents.
    num_thread_x = cfg['tile_x'].size[2]
    num_thread_y = cfg['tile_y'].size[2]

    # Schedule for A's shared memory load
    ty, _ = s[AA].split(s[AA].op.axis[0], nparts=num_thread_y)
    _, xi = s[AA].split(s[AA].op.axis[1], factor=num_thread_x * 4)
    tx, xi = s[AA].split(xi, nparts=num_thread_x)
    s[AA].bind(ty, tvm.thread_axis("threadIdx.y"))
    s[AA].bind(tx, tvm.thread_axis("threadIdx.x"))
    s[AA].double_buffer()

    # Schedule for B's shared memory load
    ty, _ = s[BB].split(s[BB].op.axis[0], nparts=num_thread_y)
    _, xi = s[BB].split(s[BB].op.axis[1], factor=num_thread_x * 4)
    tx, xi = s[BB].split(xi, nparts=num_thread_x)
    s[BB].bind(ty, tvm.thread_axis("threadIdx.y"))
    s[BB].bind(tx, tvm.thread_axis("threadIdx.x"))
    s[BB].double_buffer()
示例#27
0
def _default_batch_matmul_config(cfg, M, N, K):
    """Populate ``cfg`` with the fallback tiling for batch_matmul.

    The reduction extent ``K`` is split with a fixed inner factor of 16.
    ``N`` (knob ``tile_x``) and ``M`` (knob ``tile_y``) are each split with
    the inner factor ``get_max_power2_factor(extent, 8)``.

    NOTE(review): ``K // 16`` is 0 when K < 16 and truncates when K is not a
    multiple of 16; presumably only the inner factor is consumed by the
    schedule -- confirm before relying on the outer factor.
    """
    k_inner = 16
    cfg["tile_k"] = SplitEntity([K // k_inner, k_inner])
    # Same recipe for both spatial axes; x (N) is configured before y (M).
    for knob, extent in (("tile_x", N), ("tile_y", M)):
        inner = get_max_power2_factor(extent, 8)
        cfg[knob] = SplitEntity([extent // inner, inner])