Example 1
def fused_batch_norm_manual_setdim(shape):
    """manual setdim for fused batch norm with dynamic shape"""
    from akg import dim
    info = dim.Dim()
    for i, d in enumerate(DYNAMIC_SETDIM_MAP.get(shape, [])):
        info.setdim(index=0, axis=i, tilel1=d, tilel0=1)
    return str(info)
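
A minimal sketch of how the lookup table and a call might look; the map entries below are invented for illustration:

# Hypothetical map from shape tuples to per-axis L1 tile sizes.
DYNAMIC_SETDIM_MAP = {
    (32, 4, 16): [4, 4, 16],
}

dim_str = fused_batch_norm_manual_setdim((32, 4, 16))
# pass the result as attrs={'dim': dim_str} when building the op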
Example 2
def dropout_set_dim_func(data_tensor, data_mask, keep_prob):
    """Set dim info for dropout; caps the tiled element count at a storage budget."""
    from akg import dim
    shape = [x.value for x in data_tensor.shape if x.value != 1]
    dtype = data_tensor.dtype
    storage = 49152
    if dtype.lower() == 'float16':
        dnum = 1
    else:
        dnum = 2

    info = dim.Dim()
    list_info = []

    def cal_max_divisor(a, threshold):
        for i in range(threshold, 0, -1):
            if a % i == 0:
                return i
        return 1
    for i in range(len(shape) - 1, -1, -1):
        if dnum >= storage:
            list_info.append((i, 1))
        elif dnum * shape[i] > storage:
            list_info.append((i, cal_max_divisor(shape[i], storage // dnum)))
        dnum *= shape[i]

    for i in reversed(list_info):
        info.setdim(index=0, axis=i[0], tilel1=i[1], tilel0=1)

    return str(info)
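
The loop walks the axes innermost-first and, once the running element count would exceed storage, tiles the current axis with its largest divisor that still fits. A standalone sketch of the same heuristic (constants copied from above; the shape is made up):

def _cal_max_divisor(a, threshold):
    # largest divisor of a that does not exceed threshold
    for i in range(threshold, 0, -1):
        if a % i == 0:
            return i
    return 1

storage, dnum = 49152, 1  # float16 case
shape = [32, 1024, 16]
cuts = []
for i in range(len(shape) - 1, -1, -1):
    if dnum >= storage:
        cuts.append((i, 1))  # budget already exhausted: cut this axis to 1
    elif dnum * shape[i] > storage:
        cuts.append((i, _cal_max_divisor(shape[i], storage // dnum)))
    dnum *= shape[i]
print(list(reversed(cuts)))  # [(0, 2)]: 2 * 1024 * 16 = 32768 fits the budget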
Example 3
def set_dims_group(cut_h, cut_co, cut_m, cut_k, cut_n, out_shape_5d, _c_i,
                   _c_o, group, _k_h, _k_w, _s_h, block_size):
    """Set dim info for a group convolution output."""
    from akg import dim
    info = dim.Dim()
    out_n, out_c1, out_h, out_w, out_c0 = out_shape_5d
    tile_out_h = (cut_h - _k_h) // _s_h + 1
    if (out_n > 1):
        info.setdim(index=0, axis=0, tilel1=1, tilel0=0)
    if (out_c1 > 1):
        info.setdim(index=0, axis=0, tilel1=cut_co // block_size, tilel0=0)
    if (out_h > 1):
        info.setdim(index=0, axis='H', tilel1=tile_out_h, tilel0=0)
    if (out_w > 1):
        info.setdim(index=0, axis=3, tilel1=out_w, tilel0=0)
    if (out_c0 > 1):
        info.setdim(index=0, axis=4, tilel1=out_c0, tilel0=0)
    assert _c_i // block_size // group == 1
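    # Note: given the assert above, the following branch can never be taken.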
    if (_c_i // block_size // group > 1):
        info.setdim(index=0,
                    axis=5,
                    tilel1=_c_i // block_size // group,
                    tilel0=0)
    if (_k_h > 1):
        info.setdim(index=0, axis=5, tilel1=_k_h, tilel0=0)
    if (_k_w > 1):
        info.setdim(index=0, axis=5, tilel1=_k_w, tilel0=0)
    return str(info)
Example 4
def set_dims(fmap_shape, filter_shape, pad_, stride_, dilation_, tile_hh,
             tile_coco, tile_mm, tile_kk, tile_nn, block_size):
    """set dim info in attrs."""
    in_n, in_c, in_h, in_w = fmap_shape
    in_c = (in_c + block_size - 1) // block_size * block_size
    in_c1 = in_c // block_size

    # kernel shape (NCHW -> NC1HWC0 -> Fractal)
    k_n, k_c, k_h, k_w = filter_shape
    k_c = (k_c + block_size - 1) // block_size * block_size
    k_n = (k_n + block_size - 1) // block_size * block_size

    padding = (pad_[0], pad_[0], pad_[1], pad_[1])
    p_top, p_bottom, p_left, p_right = padding
    s_h, s_w = (stride_[0], stride_[1])
    d_h, d_w = (dilation_[0], dilation_[1])
    if (tile_hh == in_h):
        tile_hh += p_top + p_bottom
    tile_coco = (tile_coco + block_size - 1) // block_size * block_size
    tile_mm = (tile_mm + block_size - 1) // block_size * block_size
    tile_kk = (tile_kk + block_size - 1) // block_size * block_size
    tile_nn = (tile_nn + block_size - 1) // block_size * block_size


    k_h_d = (k_h - 1) * d_h + 1
    k_w_d = (k_w - 1) * d_w + 1
    out_h = (in_h + p_top + p_bottom - k_h_d) // (s_h) + 1
    tile_out_h = (tile_hh - k_h_d) // s_h + 1
    out_w = (in_w + p_left + p_right - k_w_d) // (s_w) + 1

    out_shape_nc1hwc0 = (in_n, k_n // block_size, out_h, out_w, block_size)
    out_n, out_c1, out_h, out_w, out_c0 = out_shape_nc1hwc0

    if (tile_coco > 0):
        c1_cut = tile_coco // block_size
    else:
        c1_cut = out_c1

    # set dim
    info = dim.Dim()
    if (out_n > 1):
        info.setdim(index=0, axis=0, tilel1=1, tilel0=0)  # n
    if (out_c1 > 1):
        info.setdim(index=0, axis=0, tilel1=c1_cut, tilel0=0)  # c1
    if (out_h > 1):
        info.setdim(index=0, axis="H", tilel1=tile_out_h, tilel0=0)  # h
    if (out_w > 1):
        info.setdim(index=0, axis=3, tilel1=out_w, tilel0=0)  # w
    if (out_c0 > 1):
        info.setdim(index=0, axis=4, tilel1=out_c0, tilel0=0)  # c0

    if (in_c1 > 1):
        info.setdim(index=0, axis=5, tilel1=in_c1, tilel0=0)  # kc1
    if (k_h > 1):
        info.setdim(index=0, axis=5, tilel1=k_h, tilel0=0)  # kh
    if (k_w > 1):
        info.setdim(index=0, axis=5, tilel1=k_w, tilel0=0)  # kw

    return str(info)
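
A hedged call sketch; the shapes and tile sizes below are illustrative, not tuned values:

dim_str = set_dims(fmap_shape=(1, 64, 56, 56), filter_shape=(64, 64, 3, 3),
                   pad_=(1, 1), stride_=(1, 1), dilation_=(1, 1),
                   tile_hh=56, tile_coco=16, tile_mm=16, tile_kk=16,
                   tile_nn=16, block_size=16)
# the returned string goes into attrs={'dim': dim_str} for akg.build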
Example 5
def smooth_l1_loss_grad_get_dim(shape):
    """
    get dim attr for smooth L1 loss grad

    Args:
        shape: the shape of prediction tensor (e.g. [8, 4718, 4])

    Returns:
        dim string for akg.op.build(attrs=...)
    """

    # example shape: [8, 4718, 4]
    # cut dim with the constants below: ((1, 1), (337, 337))
    tensor_size = 1
    for i in shape[:-1]:
        tensor_size *= i
    # if tensor_size exceeds max_tensor_size, cut
    ub_size = 256 * 1024
    # estimated maximum number of data copies in UB
    num_data_copies = 32
    data_size = 4
    # do not cut the last dim
    max_tensor_size = int(ub_size / data_size / num_data_copies / shape[-1])

    if tensor_size > max_tensor_size:
        # find the largest divisor of tensor_size to be the tile size
        # currently the dim size must be divisible by tile size
        tile_size = 1
        for i in range(max_tensor_size, 1, -1):
            if tensor_size % i == 0:
                tile_size = i
                break

        # generate setdim string
        from akg import dim
        info = dim.Dim()
        # do not cut last dim
        for i in range(0, len(shape) - 2):
            info.setdim(index=0, axis=i, tilel1=1, tilel0=1)
        # cut -2 dim
        info.setdim(index=0,
                    axis=len(shape) - 2,
                    tilel1=tile_size,
                    tilel0=tile_size)
        return str(info)
    return ''
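
Working the docstring's example shape through the constants above (the call itself only needs the shape):

# tensor_size = 8 * 4718 = 37744
# max_tensor_size = 256 * 1024 // 4 // 32 // 4 = 512
# 37744 = 2**4 * 7 * 337, so the largest divisor <= 512 is 337
dim_str = smooth_l1_loss_grad_get_dim([8, 4718, 4])
# axis 0 gets tile (1, 1), axis 1 gets tile (337, 337); the last dim is never cut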
Example 6
def test_quant(fmap_shape):
    # input shape(NCHW -> NC1HWC0)
    in_n, in_c, in_h, in_w = fmap_shape
    assert in_c % 32 == 0
    input_shape_nc1hwc0 = (in_n, in_c // 16, in_h, in_w, 16)
    in_n, in_c1, in_h, in_w, in_c0 = input_shape_nc1hwc0

    # placeholder (NC1HWC0)
    FMap = akg.tvm.placeholder(input_shape_nc1hwc0,
                               dtype='float16',
                               name='FMap')

    ScaleQ = akg.tvm.placeholder((16, ), dtype='float16', name='ScaleQ')
    OffsetQ = akg.tvm.placeholder((16, ), dtype='float16', name='OffsetQ')

    out_shape_nc1hwc0 = (in_n, in_c // 32, in_h, in_w, 32)
    print(out_shape_nc1hwc0)
    out_n, out_c1, out_h, out_w, out_c0 = out_shape_nc1hwc0

    # quantize
    Quant = akg.tvm.compute(out_shape_nc1hwc0,
                            lambda n, c1, h, w, c0:
                            (FMap[n, c1 + c0 // 16, h, w, c0 % 16] * ScaleQ[0]
                             + OffsetQ[0]).astype('int8'),
                            name='output')

    info = dim.Dim()
    info.setdim(index=0, axis=0, tilel1=2, tilel0=0)
    info.setdim(index=0, axis=0, tilel1=32, tilel0=0)
    info.setdim(index=0, axis=0, tilel1=32, tilel0=0)
    info.setdim(index=0, axis=0, tilel1=16, tilel0=0)

    # schedule
    s = akg.tvm.create_schedule(Quant.op)
    with akg.build_config(add_lower_pass=utils.debug_mode(0),
                          dump_pass_ir=True):
        mod = akg.build(s, [FMap, ScaleQ, OffsetQ, Quant],
                        'cce',
                        name='cce_quant',
                        attrs={'dim': str(info)},
                        polyhedral=True)

    source_code = mod.imported_modules[0].get_source()
    print(source_code)
Example 7
def set_dims(tiling):
    """Set dim for tiling."""
    from akg import dim
    info = dim.Dim()
    for d, tile_d in enumerate(tiling):
        if len(tile_d) == 2:  # only c1 and c0 tile
            index = 0
            axis = d
            c1 = tile_d[0]
            c0 = tile_d[1]
        elif len(tile_d) == 4:  # index, axis, c1, c0
            index = tile_d[0]
            axis = tile_d[1]
            c1 = tile_d[2]
            c0 = tile_d[3]
        else:
            raise RuntimeError(
                "Each element in tiling should be length-2 (c1_tile, c0_tile) "
                "or length-4 (band_index, axis_index, c1_tile, c0_tile)")
        info.setdim(index=index, axis=axis, tilel1=c1, tilel0=c0)
    return str(info)
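
A usage sketch covering both accepted element forms; the tile values are made up:

dim_str = set_dims([(4, 4), (0, 2, 16, 16)])
# (4, 4): band 0, axis 0, L1 tile 4, L0 tile 4
# (0, 2, 16, 16): band 0, axis 2, L1 tile 16, L0 tile 16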
Example 8
File: conv.py Project: zhuyawen/akg
    def gen_static_dim():
        info = dim.Dim()
        if out_n > 1:
            info.setdim(index=0, axis=0, tilel1=1, tilel0=0)  # n
        if out_c1 > 1:
            info.setdim(index=0, axis=0, tilel1=c1_cut, tilel0=0)  # c1
        if out_h > 1:
            info.setdim(index=0, axis="H", tilel1=tile_out_h, tilel0=0)  # h
        if out_w > 1:
            info.setdim(index=0, axis="W", tilel1=tile_out_w, tilel0=0)  # w
        if out_c0 > 1:
            info.setdim(index=0, axis=4, tilel1=out_c0, tilel0=0)  # c0

        if in_c1 > 1:
            info.setdim(index=0, axis=5, tilel1=in_c1, tilel0=0)  # kc1
        if k_h > 1:
            info.setdim(index=0, axis=5, tilel1=k_h, tilel0=0)  # kh
        if k_w > 1:
            info.setdim(index=0, axis=5, tilel1=k_w, tilel0=0)  # kw
        info.setdim(index=0, axis="KC0", tilel1=block_size, tilel0=0)  # kc0
        return info
Example 9
File: conv.py Project: zhuyawen/akg
    def gen_dynamic_dim():
        info = dim.Dim()
        if dynamic:
            info.setdim(index=0, axis=0, tilel1=1, tilel0=0)  # n
        elif out_n > 1:
            info.setdim(index=0, axis=0, tilel1=1, tilel0=0)  # n

        if dynamic_tiling:
            info.setdim(index=0, axis=0, tilel1=c1_cut_fake, tilel0=0)  # c1
        elif dynamic or out_c1 > 1:
            info.setdim(index=0, axis=0, tilel1=c1_cut, tilel0=0)  # c1

        if dynamic_tiling:
            info.setdim(index=0, axis="H", tilel1=tile_out_h_fake,
                        tilel0=0)  # h
        elif dynamic or out_h > 1:
            info.setdim(index=0, axis="H", tilel1=tile_out_h, tilel0=0)  # h

        if dynamic_tiling:
            info.setdim(index=0, axis="W", tilel1=tile_out_w_fake,
                        tilel0=0)  # w
        elif dynamic or out_w > 1:
            info.setdim(index=0, axis="W", tilel1=tile_out_w, tilel0=0)  # w

        if dynamic or out_c0 > 1:
            info.setdim(index=0, axis=4, tilel1=out_c0, tilel0=0)  # c0

        if dynamic and not use_autotiling:
            info.setdim(index=0, axis=5, tilel1=dynamic_ci_c1, tilel0=0)  # kc1
        elif dynamic or in_c1 > 1:
            info.setdim(index=0, axis=5, tilel1=in_c1, tilel0=0)  # kc1

        if dynamic or k_h > 1:
            info.setdim(index=0, axis=5, tilel1=k_h, tilel0=0)  # kh

        if dynamic or k_w > 1:
            info.setdim(index=0, axis=5, tilel1=k_w, tilel0=0)  # kw

        info.setdim(index=0, axis="KC0", tilel1=block_size, tilel0=0)  # kc0
        return info
Example 10
def test_CCE_Conv(fmap_shape,
                  filter_shape,
                  pad_,
                  stride_,
                  tile_hh=0,
                  tile_coco=0,
                  tile_mm=0,
                  tile_kk=0,
                  tile_nn=0,
                  bypass_l1=False,
                  use_bias=False,
                  kernel_name="quant_conv",
                  cce_path='.'):
    # input shape (NCHW -> NC1HWC0)
    in_n, in_c, in_h, in_w = fmap_shape
    input_shape_nc1hwc0 = (in_n, in_c // block_size, in_h, in_w, block_size)
    # out_shape_nc1hwc0 = (in_n, in_c // 32, in_h, in_w, 32)
    in_n, in_c1, in_h, in_w, in_c0 = input_shape_nc1hwc0

    # kernel shape (NCHW -> NC1HWC0 -> Fractal)
    k_n, k_c, k_h, k_w = filter_shape
    kernel_shape_nc1hwc0 = (k_n, k_c // 32, k_h, k_w, 32)
    k_n, k_c1, k_h, k_w, k_c0 = kernel_shape_nc1hwc0
    kernel_shape_fractal = (k_c // 32 * k_h * k_w, k_n // 16, 16, 32)
    f_ko, f_no, f_ni, f_ki = kernel_shape_fractal

    # bias shape
    bias_shape_nc1hwc0 = (1, k_n // block_size, 1, 1, block_size)

    # padding ((padding_h, padding_w) -> (padding_top, padding_bottom, padding_left, padding_right))
    padding = (pad_[0], pad_[0], pad_[1], pad_[1])
    p_top, p_bottom, p_left, p_right = padding

    # stride (stride_h, stride_w)
    s_h, s_w = stride_

    # A placeholder (NC1HWCO)
    A = akg.tvm.placeholder(input_shape_nc1hwc0, dtype=conv_dtype, name='FMap')
    # B_placeholder (fractal)
    B = akg.tvm.placeholder(kernel_shape_fractal, dtype='int8', name='Filter')
    ScaleQ = akg.tvm.placeholder((16, ), dtype='float16', name='ScaleQ')
    OffsetQ = akg.tvm.placeholder((16, ), dtype='float16', name='OffsetQ')

    out_shape_nc1hwc0 = (in_n, in_c // 32, in_h, in_w, 32)
    q_n, q_c1, q_h, q_w, q_c0 = out_shape_nc1hwc0
    # print out_shape_nc1hwc0
    Quant = akg.tvm.compute(out_shape_nc1hwc0,
                            lambda qn, qc1, qh, qw, qc0:
                            (A[qn, qc1 + qc0 // 16, qh, qw, qc0 % 16] * ScaleQ[
                                0] + OffsetQ[0]).astype('int8'),
                            name='QuantOUT',
                            attrs={'no_inline': 1})

    if use_bias:
        bias_name = 'bias'
        bias_value = akg.tvm.placeholder(bias_shape_nc1hwc0,
                                         dtype=conv_dtype,
                                         name=bias_name)
    else:
        bias_name = 'None'

    # Create reduction variables
    kc1 = akg.tvm.reduce_axis((0, k_c1), name='kc1')
    kh = akg.tvm.reduce_axis((0, k_h), name='kh')
    kw = akg.tvm.reduce_axis((0, k_w), name='kw')
    kc0 = akg.tvm.reduce_axis((0, k_c0), name='kc0')

    out_h = (in_h + p_top + p_bottom - k_h) // (s_h) + 1
    tile_out_h = (tile_hh - k_h) // s_h + 1
    out_w = (in_w + p_left + p_right - k_w) // (s_w) + 1

    out_shape_nc1hwc0 = (in_n, k_n // block_size, out_h, out_w, block_size)
    out_n, out_c1, out_h, out_w, out_c0 = out_shape_nc1hwc0

    if (tile_coco > 0):
        c1_cut = tile_coco // block_size
    else:
        c1_cut = out_c1

    # set dim
    index = 0
    info = dim.Dim()
    if (q_c1 > 1):
        info.setdim(index=index, axis="KO", tilel1=q_c1, tilel0=q_c1)  # ko
    if (q_h > 1):
        info.setdim(index=index,
                    axis="C1",
                    tilel1=tile_out_h,
                    tilel0=tile_out_h)  # c1
    if (q_w > 1):
        info.setdim(index=index, axis="C0", tilel1=q_w, tilel0=q_w)  # c0
    if (q_c0 > 1):
        info.setdim(index=index, axis="KI", tilel1=q_c0, tilel0=q_c0)  # ki

    index += 1
    if (out_c1 > 1):
        info.setdim(index=index, axis="C1", tilel1=c1_cut, tilel0=0)  # c1
    if (out_h > 1):
        info.setdim(index=index, axis="H", tilel1=tile_out_h, tilel0=0)  # h
    if (out_w > 1):
        info.setdim(index=index, axis="W", tilel1=out_w, tilel0=0)  # w
    if (out_c0 > 1):
        info.setdim(index=index, axis="C0", tilel1=out_c0, tilel0=0)  # c0
    if (in_c1 > 1):
        info.setdim(index=index, axis="KC1", tilel1=in_c1 / 2, tilel0=0)  # kc1
    if (k_h > 1):
        info.setdim(index=index, axis="KH", tilel1=k_h, tilel0=0)  # kh
    if (k_w > 1):
        info.setdim(index=index, axis="KW", tilel1=k_w, tilel0=0)  # kw
    info = str(info)

    # Compute the convolution
    output_name = "output0"
    output_bias_name = "output1"

    # print out_shape_nc1hwc0
    C = akg.tvm.compute(
        out_shape_nc1hwc0,
        lambda n, c1, h, w, c0: akg.tvm.sum(akg.tvm.if_then_else(
            akg.tvm.any((h * s_h + kh) < p_top, (h * s_h + kh) >
                        (in_h + p_top - 1), (w * s_w + kw) < p_left,
                        (w * s_w + kw) >
                        (in_w + p_left - 1)), akg.tvm.const(0.0, 'int8'),
            Quant[n, kc1, (h * s_h + kh - p_top),
                  (w * s_w + kw - p_left), kc0]) * B[
                      (kc1 * k_h + kh) * k_w + kw, c1, c0, kc0],
                                            axis=[kc1, kh, kw, kc0]),
        name=output_name,
        attrs={
            "pragma_conv_kernel_n": k_n,
            "pragma_conv_kernel_h": k_h,
            "pragma_conv_kernel_w": k_w,
            "pragma_conv_padding_top": p_top,
            "pragma_conv_padding_bottom": p_bottom,
            "pragma_conv_padding_left": p_left,
            "pragma_conv_padding_right": p_right,
            "pragma_conv_dilation_h": 1,
            "pragma_conv_dilation_w": 1,
            "pragma_conv_bypass_l1": 1 if bypass_l1 else 0,
            "pragma_conv_stride_h": s_h,
            "pragma_conv_stride_w": s_w,
            "pragma_conv_fm_n": in_n,
            "pragma_conv_fm_c": in_c,
            "pragma_conv_fm_h": in_h,
            "pragma_conv_fm_w": in_w,
            "pragma_conv_h_cut": (h_window_cut - 1) * s_h + k_h,
            "pragma_conv_w_cut": (in_w + p_left + p_right),
            "pragma_conv_co_cut": c1_cut * k_c0,
            "pragma_conv_m_cut": tile_mm,
            "pragma_conv_k_cut": tile_kk,
            "pragma_conv_n_cut": tile_nn,
            "feature": Quant.op.name,
            "filter": B.op.name,
            "bias": bias_name,
            "res": output_name,
            "res_bias": output_bias_name
        })

    if use_bias:
        cube = akg.tvm.compute(out_shape_nc1hwc0,
                               lambda n, c1, h, w, c0: C[n, c1, h, w, c0] +
                               bias_value[0, c1, 0, 0, c0],
                               name=output_bias_name)
    else:
        cube = C

    if fusion:
        # leakly relu
        negative_slope = 0.0
        slope_tmp = akg.tvm.const(negative_slope, dtype=conv_dtype)
        # negative_slope*x
        out = akg.lang.cce.vmuls(cube, slope_tmp)
        # max(x,negative_slope*x)
        out = akg.lang.cce.vmax(out, cube)
    else:
        out = cube

    # schedule
    s = akg.tvm.create_schedule(out.op)
    attrs = {"pragma_reschedule": 1, "dim": info}
    if use_bias:
        args = [A, B, ScaleQ, OffsetQ, bias_value, out]
    else:
        args = [A, B, ScaleQ, OffsetQ, out]
    with akg.build_config(add_lower_pass=cce.debug_mode(0), dump_pass_ir=True):
        mod = akg.build(s, args,
                        "cce",
                        name=kernel_name,
                        attrs=attrs,
                        polyhedral=True)
    source_code = mod.imported_modules[0].get_source()
    # print(source_code)
    # utils.create_code(kernel_name, cce_path, source_code)
    if run_cce:
        run_conv(mod, fmap_shape, filter_shape, pad_[0], stride_[0], use_bias)
Example 11
def group_conv(N,
               H,
               W,
               CI,
               CO,
               group,
               KH,
               KW,
               PAD_H,
               PAD_W,
               SH,
               SW,
               cutH,
               cutCo,
               cutM,
               cutK,
               cutN,
               block_size,
               use_bias=False,
               kernel_name='conv'):
    """
    split channels of FeatureMap to some groups,every group has its filter-kernel

    Args:
        args1:a list,the size is 3 if use_bias else the size is 2;
              data[0] akg.tvm.Tensor of type float16 ,shape 5D(N, CI//C0, C0, H, W)
              data[1] akg.tvm.Tensor of type float16 ,shape 6D(CI//(CI//C0)//C0, KH, KW, k_ch*CI//C0, C0, C0)
              data[2] akg.tvm.Tensor of type float16 ,shape 5D(N, CI*k_ch//C0, OH, OW, C0)
        N:batchsize
        H:height of featureMap
        W:width of featureMap
        CI:channel of featureMap
        C0:num of Filters
        group:num of spliting channels of FeatureMap
        KH:height of Filter
        KW:width of Filter
        PAD_H:padding pixels in vertical direction
        PAD_W:padding pixels in horizontal direction
        SH:stride in vertical direction
        SW:stride in horizontal direction
        block_size:a int var
        use_bias:a bool value
    Returns:
        akg.tvm.Tensor of same type as data, shape is 5D(N, C0//block_size, block_size, OH, OW)
    """

    conv_dtype = "float16"

    if cutH == H:
        cutH += PAD_H + PAD_H

    assert CO % group == 0 and CI % group == 0
    assert CO % block_size == 0 and (CI // group) % block_size == 0

    # (N, CI, H, W) -> (N, C0, H, W, C1)
    A = akg.tvm.placeholder((N, CI // block_size, H, W, block_size),
                            dtype=conv_dtype,
                            name="A")
    # (CO, CI // group, KH, KW) -> (CI // group // block * KH * KW, CO // block, block, block)
    B = akg.tvm.placeholder((CI // group // block_size * KH * KW,
                             CO // block_size, block_size, block_size),
                            dtype=conv_dtype,
                            name="B")

    bias = akg.tvm.placeholder((1, CO // block_size, 1, 1, block_size),
                               dtype=conv_dtype,
                               name="bias")

    OH = (H + 2 * PAD_H - KH) // SH + 1
    OW = (W + 2 * PAD_W - KW) // SW + 1

    kc1 = akg.tvm.reduce_axis((0, CI // block_size // group), name="kc1")
    kh = akg.tvm.reduce_axis((0, KH), name="kh")
    kw = akg.tvm.reduce_axis((0, KW), name="kw")
    kc0 = akg.tvm.reduce_axis((0, block_size), name="kc0")

    p_top, p_bottom, p_left, p_right = PAD_H, PAD_H, PAD_W, PAD_W
    output_name = "output"
    output_bias_name = "output_bias"

    C = akg.tvm.compute(
        (N, CO // block_size, OH, OW, block_size),
        lambda n, c1, h, w, c0: akg.lang.ascend.mmad(akg.tvm.if_then_else(
            akg.tvm.any((h * SH + kh) < p_top, (h * SH + kh) > (H + p_top - 1),
                        (w * SW + kw) < p_left, (w * SW + kw) >
                        (W + p_left - 1)), akg.tvm.const(0.0, conv_dtype),
            A[n, c1 // ((CO // block_size) // group) * (
                (CI // block_size) // group) + kc1, (h * SH + kh - p_top),
              (w * SW + kw - p_left), kc0]) * B[
                  (kc1 * KH + kh) * KW + kw, c1, c0, kc0],
                                                     axis=[kc1, kh, kw, kc0]),
        attrs={
            "pragma_conv_kernel_n": CO,
            "pragma_conv_kernel_h": KH,
            "pragma_conv_kernel_w": KW,
            "pragma_conv_padding_top": p_top,
            "pragma_conv_padding_bottom": p_bottom,
            "pragma_conv_padding_left": p_left,
            "pragma_conv_padding_right": p_right,
            "pragma_conv_bypass_l1": 1,
            "pragma_conv_stride_h": SH,
            "pragma_conv_stride_w": SW,
            "pragma_conv_fm_n": N,
            "pragma_conv_fm_c": CI,
            "pragma_conv_fm_h": H,
            "pragma_conv_fm_w": W,
            "pragma_conv_dilation_h": 1,
            "pragma_conv_dilation_w": 1,
            "pragma_conv_h_cut": cutH,
            "pragma_conv_w_cut": W + 2 * PAD_W,
            "pragma_conv_co_cut": cutCo,
            "pragma_conv_m_cut": cutM,
            "pragma_conv_k_cut": cutK,
            "pragma_conv_n_cut": cutN,
            "feature": A.op.name,
            "filter": B.op.name,
            "bias": bias.op.name,
            "res": output_name,
            "res_bias": output_bias_name
        },
        name=output_name)

    if use_bias:
        out = akg.tvm.compute(
            C.shape,
            lambda n, c1, h, w, c0: C[n, c1, h, w, c0] + bias[0, c1, 0, 0, c0],
            name=output_bias_name)
        bufs = [A, B, bias, out]
    else:
        out = C
        bufs = [A, B, out]

    # create schedule for cce
    s = akg.tvm.create_schedule([out.op])

    # set cut / tiling
    out_n, out_c1, out_h, out_w, out_c0 = akg.topi.util.get_const_tuple(
        out.shape)

    # set dim
    tile_out_h = (cutH - KH) // SH + 1

    info = dim.Dim()
    if (out_n > 1):
        info.setdim(index=0, axis=0, tilel1=1, tilel0=0)  # n
    if (out_c1 > 1):
        info.setdim(index=0, axis=0, tilel1=cutCo // block_size,
                    tilel0=0)  # c1
    if (out_h > 1):
        info.setdim(index=0, axis='H', tilel1=tile_out_h, tilel0=0)  # h
    if (out_w > 1):
        info.setdim(index=0, axis=3, tilel1=out_w, tilel0=0)  # w
    if (out_c0 > 1):
        info.setdim(index=0, axis=4, tilel1=out_c0, tilel0=0)  # c0
    assert CI // block_size // group == 1
    if (CI // block_size // group > 1):
        info.setdim(index=0,
                    axis=5,
                    tilel1=CI // block_size // group,
                    tilel0=0)  # kc1
    if (KH > 1):
        info.setdim(index=0, axis=5, tilel1=KH, tilel0=0)  # kh
    if (KW > 1):
        info.setdim(index=0, axis=5, tilel1=KW, tilel0=0)  # kw

    # build
    with akg.build_config(add_lower_pass=debug_mode(0), dump_pass_ir=True):
        mod = akg.build(s,
                        bufs,
                        "cce",
                        name=kernel_name,
                        attrs={"dim": str(info)},
                        polyhedral=True)

    return OH, OW, A, B, C, mod
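
A hedged driver sketch that satisfies the asserts above (CI // block_size // group == 1); every value is illustrative:

OH, OW, A, B, C, mod = group_conv(N=1, H=56, W=56, CI=64, CO=64, group=4,
                                  KH=3, KW=3, PAD_H=1, PAD_W=1, SH=1, SW=1,
                                  cutH=56, cutCo=16, cutM=16, cutK=144,
                                  cutN=16, block_size=16,
                                  kernel_name='group_conv_demo')
# cutK = 144 matches the reduction size kc1 * KH * KW * block_size = 1 * 3 * 3 * 16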
Example 12
def add_a_conv_compute(fmap_shape,
                       filter_shape,
                       pad_,
                       stride_,
                       dilation_,
                       tile_hh=0,
                       tile_coco=0,
                       tile_mm=0,
                       tile_kk=0,
                       tile_nn=0,
                       bypass_l1=False,
                       use_bias=False,
                       block_size=16,
                       conv_dtype='float16'):
    # input shape (NCHW -> NC1HWC0)
    in_n, in_c, in_h, in_w = fmap_shape
    in_c = (in_c + block_size - 1) // block_size * block_size
    # kernel shape (NCHW -> NC1HWC0 -> Fractal)
    k_n, k_c, k_h, k_w = filter_shape
    k_c = (k_c + block_size - 1) // block_size * block_size
    k_n = (k_n + block_size - 1) // block_size * block_size
    # padding((padding_h, padding_w) -> (padding_top, padding_bottom, padding_left, padding_right))
    padding = (pad_[0], pad_[0], pad_[1], pad_[1])
    p_top, p_bottom, p_left, p_right = padding

    # stride (stride_h, stride_w)
    s_h, s_w = stride_

    # dilation (dilation_h, dilation_w)
    d_h, d_w = dilation_

    if tile_hh == in_h:
        tile_hh += p_top + p_bottom
    tile_coco = (tile_coco + block_size - 1) // block_size * block_size
    tile_mm = (tile_mm + block_size - 1) // block_size * block_size
    tile_kk = (tile_kk + block_size - 1) // block_size * block_size
    tile_nn = (tile_nn + block_size - 1) // block_size * block_size

    c0 = block_size
    c1_cut = tile_coco // c0
    h_window_cut = (tile_hh - k_h) // s_h + 1

    out_w = (in_w + p_left + p_right - k_w) // (s_w) + 1

    kernel_name = "add_a_conv_layer_" + str(in_n) + "_" + str(in_c) + "_" + str(in_h) + "_" + str(in_w) \
                  + "_" + str(k_n) + "_" + str(in_c) + "_" + str(k_h) + "_" + str(k_w) \
                  + "_" + str(p_top) + "_" + str(s_h)

    input_shape_nc1hwc0 = (in_n, in_c // block_size, in_h, in_w, block_size)
    in_n, in_c1, in_h, in_w, in_c0 = input_shape_nc1hwc0

    kernel_shape_nc1hwc0 = (k_n, k_c // block_size, k_h, k_w, block_size)
    k_n, k_c1, k_h, k_w, k_c0 = kernel_shape_nc1hwc0
    kernel_shape_fractal = (k_c // block_size * k_h * k_w, k_n // block_size,
                            block_size, block_size)

    # bias shape
    bias_shape_nc1hwc0 = (1, k_n // block_size, 1, 1, block_size)

    # a_value placeholder (NC1HWCO)
    a_tmp = akg.tvm.placeholder(input_shape_nc1hwc0,
                                dtype=conv_dtype,
                                name='a_tmp')
    a_value = akg.tvm.compute(a_tmp.shape, lambda n, kc1, h, w, kc0: a_tmp[n, kc1, h, w, kc0] + 1, \
            name='a_value', attrs={'no_inline': 1})
    # b_value placeholder (fractal)
    b_value = akg.tvm.placeholder(kernel_shape_fractal,
                                  dtype=conv_dtype,
                                  name='b_value')

    if use_bias:
        bias_name = 'bias'
        bias_value = akg.tvm.placeholder(bias_shape_nc1hwc0,
                                         dtype=conv_dtype,
                                         name=bias_name)
    else:
        bias_name = 'None'
        bias_value = None

    # Create reduction variables
    kc1 = akg.tvm.reduce_axis((0, k_c1), name='kc1')
    kh = akg.tvm.reduce_axis((0, k_h), name='kh')
    kw = akg.tvm.reduce_axis((0, k_w), name='kw')
    kc0 = akg.tvm.reduce_axis((0, k_c0), name='kc0')

    k_h_d = (k_h - 1) * d_h + 1
    k_w_d = (k_w - 1) * d_w + 1
    out_h = (in_h + p_top + p_bottom - k_h_d) // (s_h) + 1
    tile_out_h = (tile_hh - k_h_d) // s_h + 1
    out_w = (in_w + p_left + p_right - k_w_d) // (s_w) + 1

    out_shape_nc1hwc0 = (in_n, k_n // block_size, out_h, out_w, block_size)
    _, out_c1, out_h, out_w, out_c0 = out_shape_nc1hwc0

    if tile_coco > 0:
        c1_cut = tile_coco // block_size
    else:
        c1_cut = out_c1

    # set dim
    if s_h > k_h:
        a_cut_h = tile_out_h * s_h
    else:
        a_cut_h = (tile_out_h - 1) * s_h + k_h_d
    a_cut_w = (out_w - 1) * s_w + k_w_d

    index = 0
    info = dim.Dim()
    if in_c1 > 1:
        info.setdim(index=index, axis="C1", tilel1=in_c1, tilel0=in_c1)  # c1
    if in_h > 1:
        info.setdim(index=index, axis="H", tilel1=a_cut_h, tilel0=a_cut_h)  # h
    if in_w > 1:
        info.setdim(index=index, axis="W", tilel1=a_cut_w, tilel0=a_cut_w)  # w
    if in_c0 > 1:
        info.setdim(index=index, axis="C0", tilel1=in_c0, tilel0=in_c0)  # c0

    index += 1
    if out_c1 > 1:
        info.setdim(index=index, axis="C1", tilel1=c1_cut, tilel0=0)  # c1
    if out_h > 1:
        info.setdim(index=index, axis="H", tilel1=tile_out_h, tilel0=0)  # h
    if out_w > 1:
        info.setdim(index=index, axis="W", tilel1=out_w, tilel0=0)  # w
    if out_c0 > 1:
        info.setdim(index=index, axis="C0", tilel1=out_c0, tilel0=0)  # c0
    if in_c1 > 1:
        info.setdim(index=index, axis=5, tilel1=in_c1, tilel0=0)  # kc1
    if k_h > 1:
        info.setdim(index=index, axis=5, tilel1=k_h, tilel0=0)  # kh
    if k_w > 1:
        info.setdim(index=index, axis=5, tilel1=k_w, tilel0=0)  # kw

    # Compute the convolution
    output_name = "c_value"
    output_bias_name = "OUT"
    c_value = akg.tvm.compute(out_shape_nc1hwc0,
                    lambda n, c1, h, w, c0: akg.lang.cce.mmad(
                        akg.tvm.if_then_else(akg.tvm.any((h * s_h + kh) < p_top, (h * s_h + kh) > (in_h + p_top - 1),
                                                 (w * s_w + kw) < p_left, (w * s_w + kw) > (in_w + p_left - 1)),
                                         akg.tvm.const(0.0, 'float16'),
                                         a_value[n, kc1, (h * s_h + (kh * d_h) - p_top), \
                                                 (w * s_w + (kw * d_w) - p_left), kc0])
                        * b_value[(kc1 * k_h + kh) * k_w + kw, c1, c0, kc0],
                        axis=[kc1, kh, kw, kc0]), name=output_name,
                    attrs={
                        "pragma_conv_kernel_n": k_n,
                        "pragma_conv_kernel_h": k_h,
                        "pragma_conv_kernel_w": k_w,
                        "pragma_conv_padding_top": p_top,
                        "pragma_conv_padding_bottom": p_bottom,
                        "pragma_conv_padding_left": p_left,
                        "pragma_conv_padding_right": p_right,
                        "pragma_conv_bypass_l1": 1 if bypass_l1 else 0,
                        "pragma_conv_stride_h": s_h,
                        "pragma_conv_stride_w": s_w,
                        "pragma_conv_dilation_h": d_h,
                        "pragma_conv_dilation_w": d_w,
                        "pragma_conv_fm_n": in_n,
                        "pragma_conv_fm_c": in_c,
                        "pragma_conv_fm_h": in_h,
                        "pragma_conv_fm_w": in_w,
                        "pragma_conv_h_cut": (h_window_cut - 1) * s_h + k_h_d,
                        "pragma_conv_w_cut": (in_w + p_left + p_right),
                        "pragma_conv_co_cut": c1_cut * k_c0,
                        "pragma_conv_m_cut": tile_mm,
                        "pragma_conv_k_cut": tile_kk,
                        "pragma_conv_n_cut": tile_nn,
                        "feature": a_value.op.name,
                        "filter": b_value.op.name,
                        "bias": bias_name,
                        "res": output_name,
                        "res_bias": output_bias_name})

    if use_bias:
        cube = akg.tvm.compute(out_shape_nc1hwc0,
                               lambda n, c1, h, w, c0: c_value[n, c1, h, w, c0]
                               + bias_value[0, c1, 0, 0, c0],
                               name=output_bias_name)
    else:
        cube = c_value
    return cube, a_tmp, b_value, bias_value, kernel_name, str(info)
Example 13
def conv_backprop_input_compute(data,
                                output_shape,
                                filter_shape,
                                input_shape,
                                pad_,
                                stride_,
                                block_size=16,
                                attrs=None,
                                key=None):
    """core computation of conv_backprop_input."""
    _, in_c, w_h, w_w = filter_shape

    # stride (stride_h, stride_w)
    stride_h, stride_w = stride_
    if stride_h != stride_w:
        raise ValueError("stride_h must be equal to stride_w.")

    # output shape (NCHW -> NC1HWC0)
    in_nn, in_cc, in_hh, in_ww = output_shape
    if in_c % block_size != 0:
        raise ValueError("in_c must be divided by block_size.")
    input_shape_nc1hwc0 = (in_nn, in_cc // block_size, in_hh, in_ww,
                           block_size)
    in_nn, _, in_hh, in_ww, _ = input_shape_nc1hwc0
    input_trans_shape_nc1hwc0 = (in_nn, in_cc // block_size, in_hh * stride_h,
                                 in_ww * stride_w, block_size)
    in_n, in_c1, in_h, in_w, _ = input_trans_shape_nc1hwc0

    # kernel shape (NCHW -> NC1HWC0 -> Fractal)
    k_n, k_c, k_h, k_w = filter_shape
    if k_c % block_size != 0:
        raise ValueError("k_c must be divided by block_size.")
    kernel_shape_nc1hwc0 = (k_n, k_c // block_size, k_h, k_w, block_size)
    k_n, k_c1, k_h, k_w, k_c0 = kernel_shape_nc1hwc0
    kernel_shape_trans = (k_n // block_size * k_h * k_w, k_c // block_size,
                          block_size, block_size)
    k_c1 = k_n // block_size
    k_n = k_c

    _, _, input_h, input_w = input_shape

    # padding ((padding_h, padding_w) -> (padding_top, padding_bottom, padding_left, padding_right))
    padding = (pad_[0], pad_[1], pad_[2], pad_[3])
    pad_t, pad_b, pad_l, pad_r = padding

    # padHT -> padHT'
    p_top = k_h - pad_t - 1
    # padHB -> padHB'
    p_bottom = input_h + pad_t - stride_h * (
        (input_h + pad_t + pad_b - k_h) // stride_h + 1)
    # padWL -> padWL'
    p_left = k_w - pad_l - 1
    # padWR -> padWR'
    p_right = input_w + pad_l - stride_w * (
        (input_w + pad_l + pad_r - k_w) // stride_w + 1)

    s_h = 1
    s_w = 1

    # NC1HWCO
    a_value = data[0]

    if data[1].dtype == 'float32':
        b_value = cast.cast(data[1], 'float16')
        tiling_args = cast_tiling_args
    else:
        b_value = data[1]
        tiling_args = conv_backprop_input_tiling_args

    # Create reduction variables
    kc1 = akg.tvm.reduce_axis((0, k_c1), name='kc1')
    kh = akg.tvm.reduce_axis((0, k_h), name='kh')
    kw = akg.tvm.reduce_axis((0, k_w), name='kw')
    kc0 = akg.tvm.reduce_axis((0, k_c0), name='kc0')
    use_auto_tiling = False
    if attrs is not None and 'conv_tile' in attrs and len(
            attrs['conv_tile']) >= 5:
        tile_value = attrs['conv_tile']
    elif key in tiling_args:
        tile_value = tiling_args[key]
    else:
        use_auto_tiling = True

    out_h = (in_h + p_top + p_bottom - k_h) // (s_h) + 1
    out_w = (in_w + p_left + p_right - k_w) // (s_w) + 1
    out_shape_nc1hwc0 = (in_n, k_n // block_size, out_h, out_w, block_size)
    out_n, out_c1, out_h, out_w, out_c0 = out_shape_nc1hwc0

    # set dim
    info = dim.Dim()
    index_ = 0

    if not use_auto_tiling:
        tile_hh = tile_value[0]
        if tile_hh == input_h:
            tile_hh += pad_t + pad_b

        tile_coco = tile_value[1]
        tile_coco = (tile_coco + block_size - 1) // block_size * block_size

        tile_mm = tile_value[2]
        tile_mm = (tile_mm + block_size - 1) // block_size * block_size

        tile_kk = tile_value[3]
        if tile_kk % (block_size * w_h * w_w) != 0:
            logging.warning(
                "tile_k must be a multiple of (block_size * w_h * w_w)")
        tile_kk = (tile_kk + block_size * w_h * w_w -
                   1) // (block_size * w_h * w_w) * (block_size * w_h * w_w)

        tile_nn = tile_value[4]
        tile_nn = (tile_nn + block_size - 1) // block_size * block_size

        tile_ww = input_w
        if len(tile_value) >= 6 and tile_value[5] > 0:
            tile_ww = tile_value[5]
        if tile_ww == input_w:
            tile_ww += pad_l + pad_r

        if tile_hh == in_h:
            tile_hh += p_top + p_bottom
        tile_out_h = (tile_hh - k_h) // s_h + 1

        if tile_ww == in_w:
            tile_ww += p_left + p_right
        tile_out_w = (tile_ww - k_w) // s_w + 1

        if tile_coco > 0:
            c1_cut = tile_coco // block_size
        else:
            c1_cut = out_c1

        if out_n > 1:
            info.setdim(index=index_, axis=0, tilel1=1, tilel0=0)  # n
        if out_c1 > 1:
            info.setdim(index=index_, axis=1, tilel1=c1_cut, tilel0=0)  # c1
        if out_h > 1:
            info.setdim(index=index_, axis="H", tilel1=tile_out_h,
                        tilel0=0)  # h
        if out_w > 1:
            info.setdim(index=index_, axis="W", tilel1=tile_out_w,
                        tilel0=0)  # w
        if out_c0 > 1:
            info.setdim(index=index_, axis=4, tilel1=out_c0, tilel0=0)  # c0
        if in_c1 > 1:
            info.setdim(index=index_, axis=5, tilel1=in_c1, tilel0=0)  # kc1
        if k_h > 1:
            info.setdim(index=index_, axis=5, tilel1=k_h, tilel0=0)  # kh
        if k_w > 1:
            info.setdim(index=index_, axis=5, tilel1=k_w, tilel0=0)  # kw

        info = str(info)
    else:
        info = ""
    # Compute the convolution below

    output_name = "output0"

    # weight_trans [ ko, no, ni, ki ]
    # weight_trans [ co_1, kh, kw, ci_1, ci_0, co_0 ]
    # kw = ko % k_w
    # kh = ko // k_w % k_h
    # co_1 = ko // k_w // k_h
    # ci_1 = no
    # -->
    # weight [ ci_1, kh', kw', co_1, co_0, ci_0 ]
    # weight [ no, k_h - ko // k_w % k_h - 1, k_w - ko % k_w - 1, ko // k_w // k_h, co_0, ci_0 ]
    b_trans = akg.tvm.compute(kernel_shape_trans,
                              lambda ko, no, ni, ki: b_value[
                                  ((no * k_h + k_h - 1 - ko // k_w % k_h) * k_w
                                   + k_w - 1 - ko % k_w), ko //
                                  (k_h * k_w), ki, ni],
                              name='B_trans')

    if ((stride_h > 1) or (stride_w > 1)):

        @akg.tvm.hybrid.script
        def data_trans_hybrid(output, inputs, const_zero):
            """Implements data_trans ( B[n, c1, h * strideH, w * strideW, c0] = A[n, c1, h, w, c0] )."""

            stride_h = output.shape[2] // inputs.shape[2]
            stride_w = output.shape[3] // inputs.shape[3]

            b = allocate(output.shape, output.dtype, 'local')
            for n in range(output.shape[0]):
                for c1 in range(output.shape[1]):
                    for h in range(output.shape[2]):
                        for w in range(output.shape[3]):
                            for c0 in range(output.shape[4]):
                                b[n, c1, h, w, c0] = const_zero
                                if h % stride_h == 0 and w % stride_w == 0:
                                    b[n, c1, h, w,
                                      c0] = inputs[n, c1, h // stride_h,
                                                   w // stride_w, c0]

            return b

        a_trans_init = akg.tvm.placeholder(input_trans_shape_nc1hwc0,
                                           dtype="float16",
                                           name='a_trans')
        const_zero = akg.tvm.const(0, 'float16')
        a_trans = data_trans_hybrid(a_trans_init, a_value, const_zero)
    else:
        a_trans = a_value
    conv_attrs = {
        "pragma_conv_kernel_n": k_n,
        "pragma_conv_kernel_h": k_h,
        "pragma_conv_kernel_w": k_w,
        "pragma_conv_padding_top": p_top,
        "pragma_conv_padding_bottom": p_bottom,
        "pragma_conv_padding_left": p_left,
        "pragma_conv_padding_right": p_right,
        "pragma_conv_bypass_l1": 0,
        "pragma_conv_backprop_input": 1,
        "pragma_conv_stride_h": s_h,
        "pragma_conv_stride_w": s_w,
        "pragma_conv_dilation_h": 1,
        "pragma_conv_dilation_w": 1,
        "pragma_conv_fm_n": in_n,
        "pragma_conv_fm_c": in_c,
        "pragma_conv_fm_h": in_h,
        "pragma_conv_fm_w": in_w,
        "feature": a_trans.op.name,
        "filter": b_value.op.name,
        "bias": 'None',
        "res": output_name
    }
    if not use_auto_tiling:
        conv_attrs["pragma_conv_h_cut"] = (tile_out_h - 1) * s_h + k_h
        conv_attrs["pragma_conv_w_cut"] = (tile_out_w - 1) * s_w + k_w
        conv_attrs["pragma_conv_co_cut"] = c1_cut * k_c0
        conv_attrs["pragma_conv_m_cut"] = tile_mm
        conv_attrs["pragma_conv_k_cut"] = tile_kk
        conv_attrs["pragma_conv_n_cut"] = tile_nn
    res_c = akg.tvm.compute(
        out_shape_nc1hwc0,
        lambda n, c1, h, w, c0: akg.lang.cce.mmad((akg.tvm.if_then_else(
            akg.tvm.any((h * s_h + kh) < p_top, (h * s_h + kh) >
                        (in_h + p_top - 1), (w * s_w + kw) < p_left,
                        (w * s_w + kw) >
                        (in_w + p_left - 1)), akg.tvm.const(0.0, 'float16'),
            a_trans[n, kc1, (h * s_h + kh - p_top),
                    (w * s_w + kw - p_left), kc0]) * b_trans[
                        (kc1 * k_h + kh) * k_w + kw, c1, c0, kc0]).astype(
                            "float32"),
                                                  axis=[kc1, kh, kw, kc0]),
        name=output_name,
        attrs=conv_attrs)

    res_c = cast.cast(res_c, "float16")

    return res_c, {"dim": info, "pragma_reschedule": 1, "pragma_rmselfdep": 0}
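
The data_trans_hybrid step zero-interleaves the output gradient by the stride before the stride-1 transposed convolution; a numpy sketch of the same scatter (shapes illustrative):

import numpy as np

dy = np.arange(4, dtype=np.float16).reshape(1, 1, 2, 2, 1)
s_h = s_w = 2
up = np.zeros((1, 1, 2 * s_h, 2 * s_w, 1), np.float16)
up[:, :, ::s_h, ::s_w, :] = dy
# up[n, c1, h, w, c0] == dy[n, c1, h // s_h, w // s_w, c0] where h and w are
# multiples of the stride, and 0 everywhere else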
Example 14
def avg_pool_5d_hybrid(a_value, kernel, stride, strategy):
    """avgpool with 5d case via hybrid"""
    if len(kernel) != 2:
        raise ValueError("Only support 2-dim kernel!")
    kernel_h, kernel_w = kernel
    stride_h, stride_w = stride
    shape = get_shape(a_value)
    if len(shape) != 5:
        raise ValueError("Only support 5-dim pooling!")
    batch_size, c1_, in_size_h, in_size_w, c0_ = shape
    dtype = a_value.dtype

    [pad_height_head, _, pad_width_head, _], [out_size_h, out_size_w] = \
        cal_pad_shapes_by_strategy(shape, kernel, stride, strategy)

    avg_pre = akg.tvm.const(1.0000 / (kernel_w * kernel_h), dtype=dtype)
    zero = akg.tvm.const(0.0, dtype=dtype)

    @script(capture=locals())
    def avg_pool_hybrid(inputs, zero, avg_pre):
        output = output_tensor((batch_size, c1_, out_size_h, out_size_w, c0_),
                               inputs.dtype)

        for n in range(batch_size):
            for c1 in range(c1_):
                # Head
                for ow in range(out_size_w):
                    for c0 in range(c0_):
                        output[n, c1, 0, ow, c0] = zero
                for ow in range(out_size_w):
                    for kh in range(kernel_h):
                        for kw in range(kernel_w):
                            for c0 in range(c0_):
                                if (kh >= pad_height_head) \
                                        and (ow * stride_w + kw - pad_width_head >= 0) \
                                        and (ow * stride_w + kw <= in_size_w + pad_width_head - 1):
                                    output[n, c1, 0, ow, c0] = output[n, c1, 0, ow, c0] +\
                                        inputs[n, c1, kh - pad_height_head,
                                               ow * stride_w + kw - pad_width_head, c0]
                                else:
                                    output[n, c1, 0, ow, c0] += zero
                for ow in range(out_size_w):
                    for c0 in range(c0_):
                        output[n, c1, 0, ow, c0] *= avg_pre
                # Tail
                for oh in range(out_size_h - 1):
                    for ow in range(out_size_w):
                        for c0 in range(c0_):
                            output[n, c1, oh + 1, ow, c0] = zero
                for oh in range(out_size_h - 1):
                    for ow in range(out_size_w):
                        for kh in range(kernel_h):
                            for kw in range(kernel_w):
                                for c0 in range(c0_):
                                    if ((oh + 1) * stride_h + kh <= in_size_h + pad_height_head - 1)\
                                            and (ow * stride_w + kw >= pad_width_head)\
                                            and (ow * stride_w + kw <= in_size_w + pad_width_head - 1):
                                        output[n, c1, oh + 1, ow, c0] = output[n, c1, oh + 1, ow, c0] +\
                                            inputs[n, c1, (oh + 1) * stride_h +
                                                   kh - pad_height_head, ow * stride_w +
                                                   kw - pad_width_head, c0]
                                    else:
                                        output[n, c1, oh + 1, ow, c0] += zero
                for oh in range(out_size_h - 1):
                    for ow in range(out_size_w):
                        for c0 in range(c0_):
                            output[n, c1, oh + 1, ow, c0] *= avg_pre
        return output

    res_value = avg_pool_hybrid(a_value, zero, avg_pre)

    # set dim
    info = dim.Dim()
    # first part
    info.setdim(index=0, axis=0, tilel1=out_size_w, tilel0=0)  # ow
    info.setdim(index=0, axis=1, tilel1=c0_, tilel0=0)  # c0
    info.setdim(index=0, axis=2, tilel1=kernel_h, tilel0=0)  # kh

    # second part
    info.setdim(index=1, axis=0, tilel1=out_size_h - 1, tilel0=0)  # oh-1
    info.setdim(index=1, axis=1, tilel1=out_size_w, tilel0=0)  # ow
    info.setdim(index=1, axis=2, tilel1=c0_, tilel0=0)  # c0
    info.setdim(index=1, axis=3, tilel1=kernel_h, tilel0=0)  # kh

    info = str(info)

    attrs = {DIM: info}
    return res_value, attrs
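
A hedged usage sketch, assuming the surrounding module's imports and that the pad-strategy helper accepts 'VALID' (shapes illustrative):

a = akg.tvm.placeholder((1, 2, 16, 16, 16), dtype='float16', name='input')
res_value, attrs = avg_pool_5d_hybrid(a, kernel=(4, 4), stride=(4, 4),
                                      strategy='VALID')
# attrs[DIM] carries the dim string for the subsequent build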
Example 15
def conv_backprop_filter_compute(data, input_shape, filter_shape, output_shape, pad_, stride_, dilation_,
                                 block_size=16, attrs=None, key=None):
    """core computation of conv_backprop_filter_compute."""
    # stride (stride_h, stride_w)
    stride_h, stride_w = stride_
    if stride_h != stride_w:
        raise ValueError("stride_h must be equal to stride_w.")
    # conv_backprop_filter input shape (NCHW -> NC1HWC0 -> fractal): load2d L0A
    input_n, input_c, input_h, input_w = output_shape
    if input_c % block_size != 0:
        raise ValueError("output channel must be divided by block_size.")
    if input_n > 32:
        raise ValueError("Batch must be less than or equal to 32.")
    input_shape_nc1hwc0 = (input_n, input_c // block_size, input_h, input_w, block_size)
    input_n, input_c1, input_h, input_w, input_c0 = input_shape_nc1hwc0
    mo = (input_h * input_w + block_size - 1) // block_size
    mi = block_size
    input_trans_shape_fractal = (input_n, input_c1, mo, input_c0, mi)

    # conv_backprop_filter kernel shape (NCHW -> NC1HWC0): img2col L0B
    k_n, k_c, k_h, k_w = input_shape
    if k_c % block_size != 0:
        raise ValueError("input channel must be divided by block_size.")
    kernel_shape_nc1hwc0 = (k_n, k_c // block_size, k_h, k_w, block_size)
    k_n, k_c1, k_h, k_w, k_c0 = kernel_shape_nc1hwc0

    # conv_backprop_filter output shape (NCHW -> NC1HWC0)
    out_n, out_c, out_h, out_w = filter_shape
    if out_n != input_c:
        raise ValueError("out_n must be equal to input_c.")
    output_shape_nc1hwc0 = (out_n, out_c // block_size, out_h, out_w, block_size)
    out_n, out_c1, out_h, out_w, _ = output_shape_nc1hwc0
    output_shape_fractal = (out_c1, out_h, out_w, out_n // block_size, block_size, block_size)
    out_c1, out_h, out_w, out_mo, out_mi, out_ni = output_shape_fractal

    # padding ((padding_h, padding_w) -> (padding_top, padding_bottom, padding_left, padding_right))
    padding = (pad_[0], pad_[1], pad_[2], pad_[3])
    p_top, p_bottom, p_left, p_right = padding

    s_h, s_w = stride_

    data_a = data[0]
    o_n, o_c1, o_h, o_w, o_c0 = data_a.shape
    mo = (o_h * o_w + block_size - 1) // block_size
    mi = block_size
    a_shape_fractal = (o_n, o_c1, mo, mi, o_c0)
    a_fractal = akg.tvm.placeholder(a_shape_fractal, dtype=data_a.dtype, name="backprop")
    a_buf = akg.tvm.decl_buffer(a_shape_fractal, a_fractal.dtype, name="backprop")
    data_b = data[1]
    tiling_args = batch_conv_backprop_filter_tiling_args
    use_autotiling = False
    if k_n == 1:
        tiling_args = conv_backprop_filter_tiling_args
    if attrs is not None and 'conv_tile' in attrs and len(attrs['conv_tile']) >= 10:
        tile = attrs['conv_tile']  # all 10 entries are read below
    elif key in tiling_args:
        tile = tiling_args[key]
    else:
        use_autotiling = True

    in_h = k_h
    in_w = k_w
    if not use_autotiling:
        # set dim
        info = dim.Dim()
        index_ = 0

        # tile = [Ci, KH, KW, Co, Batch, H, W, M, K, N]
        tile_ci = tile[0]
        if tile_ci > k_c1 * k_c0:
            tile_ci = k_c1 * k_c0
        tile_ci = (tile_ci + block_size - 1) // block_size

        tile_kh = tile[1]
        if tile_kh > out_h:
            tile_kh = out_h

        tile_kw = tile[2]
        if tile_kw > out_w:
            tile_kw = out_w

        tile_coco = tile[3]
        if tile_coco > input_c1 * input_c0:
            tile_coco = input_c1 * input_c0
        tile_coco = (tile_coco + block_size - 1) // block_size

        tile_batch = tile[4]
        if tile_batch > input_n:
            tile_batch = input_n
        if tile_batch != 1:
            raise ValueError("tile_batch must be 1.")

        d_h, d_w = dilation_

        tile_hh = tile[5]
        if tile_hh == in_h:
            tile_hh = in_h + p_top + p_bottom
        elif tile_hh > in_h + p_top + p_bottom:
            tile_hh = in_h + p_top + p_bottom
        h_win_cut = (tile_hh - ((out_h - 1) * d_h + 1)) // s_h + 1

        tile_ww = tile[6]
        if tile_ww == in_w:
            tile_ww = in_w + p_left + p_right
        elif tile_ww > in_w + p_left + p_right:
            tile_ww = in_w + p_left + p_right
        w_win_cut = (tile_ww - ((out_w - 1) * d_w + 1)) // s_w + 1

        tile_mm = tile[7]
        tile_kk = tile[8]
        tile_nn = tile[9]

        tile_mm = (tile_mm + block_size - 1) // block_size * block_size
        tile_kk = (tile_kk + block_size - 1) // block_size * block_size
        tile_nn = (tile_nn + block_size - 1) // block_size * block_size

        if out_c1 > 1:
            info.setdim(index=index_, axis=0, tilel1=tile_ci, tilel0=tile_ci)
        if out_h > 1:
            info.setdim(index=index_, axis=0, tilel1=tile_kh, tilel0=tile_kh)
        if out_w > 1:
            info.setdim(index=index_, axis=0, tilel1=tile_kw, tilel0=tile_kw)
        if out_mo > 1:
            info.setdim(index=index_, axis=0, tilel1=tile_coco, tilel0=tile_coco)
        if out_mi > 1:
            info.setdim(index=index_, axis=0, tilel1=out_mi, tilel0=out_mi)  # mi don't tile
        if out_ni > 1:
            info.setdim(index=index_, axis=0, tilel1=out_ni, tilel0=out_ni)  # ni don't tile
        if input_n > 1:
            info.setdim(index=index_, axis=0, tilel1=tile_batch, tilel0=tile_batch)  # Batch tile
        if k_h > 1:
            info.setdim(index=index_, axis="H", tilel1=h_win_cut, tilel0=h_win_cut)  # out_h
        if k_w > 1:
            info.setdim(index=index_, axis="W", tilel1=w_win_cut, tilel0=w_win_cut)  # out_w

        info = str(info)
    else:
        info = ""

    # Compute the convolution
    output_name = "filter"

    a_trans = akg.tvm.compute(input_trans_shape_fractal,
                              lambda n, co1, mo, co0, mi: a_fractal[n, co1, mo, mi, co0], name='dy_trans')

    # Create reduction variables
    no = akg.tvm.reduce_axis((0, input_n), name='no')
    ho = akg.tvm.reduce_axis((0, input_h), name='ho')
    wo = akg.tvm.reduce_axis((0, input_w), name='wo')

    conv_filter_attr = {
        "pragma_conv_kernel_n": out_n,
        "pragma_conv_kernel_h": out_h,
        "pragma_conv_kernel_w": out_w,
        "pragma_conv_padding_top": p_top,
        "pragma_conv_padding_bottom": p_bottom,
        "pragma_conv_padding_left": p_left,
        "pragma_conv_padding_right": p_right,
        "pragma_conv_bypass_l1": 0,
        "pragma_conv_backprop_filter": 1,
        "pragma_conv_stride_h": s_h,
        "pragma_conv_stride_w": s_w,
        "pragma_conv_dilation_h": 1,
        "pragma_conv_dilation_w": 1,
        "pragma_conv_fm_n": k_n,
        "pragma_conv_fm_c": k_c,
        "pragma_conv_fm_h": k_h,
        "pragma_conv_fm_w": k_w,
        "feature": data_b.op.name,
        "filter": a_fractal.op.name,
        "bias": 'None',
        "res": output_name}

    if not use_autotiling:
        conv_filter_attr["pragma_conv_batch_cut"] = tile_batch
        conv_filter_attr["pragma_conv_h_cut"] = (h_win_cut - 1) * s_h + ((out_h - 1) * d_h + 1)
        conv_filter_attr["pragma_conv_w_cut"] = (w_win_cut - 1) * s_w + ((out_w - 1) * d_w + 1)
        conv_filter_attr["pragma_conv_co_cut"] = tile_coco * block_size
        conv_filter_attr["pragma_conv_cin_cut"] = tile_ci * block_size
        conv_filter_attr["pragma_conv_m_cut"] = tile_mm
        conv_filter_attr["pragma_conv_k_cut"] = tile_kk
        conv_filter_attr["pragma_conv_n_cut"] = tile_nn
        conv_filter_attr["pragma_conv_kh_cut"] = tile_kh
        conv_filter_attr["pragma_conv_kw_cut"] = tile_kw

    res_c = akg.tvm.compute(output_shape_fractal,
                            lambda c1, h, w, mo, mi, ni: akg.lang.cce.mmad(
                                (akg.tvm.if_then_else(akg.tvm.any((h + s_h * ho) < p_top,
                                                                  (h + s_h * ho) > (in_h + p_top - 1),
                                                                  (w + s_w * wo) < p_left,
                                                                  (w + s_w * wo) > (in_w + p_left - 1)),
                                                      akg.tvm.const(0.0, 'float16'),
                                                      a_trans[no, mo, (input_w * ho + wo) // 16,
                                                              mi, (input_w * ho + wo) % 16])
                                 * data_b[no, c1, (ho * s_h + h - p_top),
                                          (wo * s_w + w - p_left), ni]).astype("float32"),
                                axis=[no, ho, wo]), name=output_name, attrs=conv_filter_attr)

    return res_c, {"dim": info, "pragma_reschedule": 1, "pragma_conv_special_dma": 1,
                   utils.BINDS: {data_a: a_buf, a_fractal: a_buf}}
Example 16
def depthwise_set_dim_func(data,
                           N,
                           H,
                           W,
                           CI,
                           k_ch,
                           KH,
                           KW,
                           PAD_H,
                           PAD_W,
                           SH,
                           SW,
                           block_size,
                           use_bias=False):
    key = [N, H, W, CI, k_ch, KH, KW, PAD_H, PAD_W, SH, SW]
    hash_key = str(tuple(key))
    clear = True
    if hash_key in depthwise_set_dim_map:
        cutH, cutCo, _, _, _ = depthwise_set_dim_map[hash_key]
        clear = False
    else:
        # fallback when the shape key is missing from depthwise_set_dim_map
        # (cutH, cutCo, cutM, cutK, cutN would otherwise come from the map)
        cutH = (KH - 1) * KH + 1
        cutCo = 16
    group = CI // block_size
    CO = CI * k_ch

    OH = (H + 2 * PAD_H - KH) // SH + 1
    OW = (W + 2 * PAD_W - KW) // SW + 1

    out_n, out_c1, out_h, out_w, out_c0 = [
        N, CO // block_size, OH, OW, block_size
    ]
    # set dim
    tile_out_h = (cutH - KH) // SH + 1

    info = dim.Dim()
    if (out_n > 1):
        info.setdim(index=0, axis=0, tilel1=1, tilel0=0)  # n
    if (out_c1 > 1):
        info.setdim(index=0, axis=0, tilel1=cutCo // block_size,
                    tilel0=0)  # c1
    if (out_h > 1):
        info.setdim(index=0, axis='H', tilel1=tile_out_h, tilel0=0)  # h
    if (out_w > 1):
        info.setdim(index=0, axis=3, tilel1=out_w, tilel0=0)  # w
    if (out_c0 > 1):
        info.setdim(index=0, axis=4, tilel1=out_c0, tilel0=0)  # c0
    assert CI // block_size // group == 1
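    # NOTE: the assert above pins CI // block_size // group to 1, so the
    # branch below is unreachable; it mirrors the generic set_dims_group form.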
    if (CI // block_size // group > 1):
        info.setdim(index=0,
                    axis=5,
                    tilel1=CI // block_size // group,
                    tilel0=0)  # kc1
    if (KH > 1):
        info.setdim(index=0, axis=5, tilel1=KH, tilel0=0)  # kh
    if (KW > 1):
        info.setdim(index=0, axis=5, tilel1=KW, tilel0=0)  # kw
    if clear:
        info = ""
    return str(info)
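For intuition, a worked example of the sizes above (all numbers hypothetical, not taken from depthwise_set_dim_map):

# Hypothetical case: H = W = 32, KH = KW = 3, PAD_H = PAD_W = 1, SH = SW = 1
OH = (32 + 2 * 1 - 3) // 1 + 1   # = 32, "same" spatial size
cutH = (3 - 1) * 3 + 1           # = 7, the fallback above
tile_out_h = (7 - 3) // 1 + 1    # = 5 output rows per L1 tile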
Example No. 17
def conv_02(fmap_shape,
            filter_shape,
            pad_,
            stride_,
            dilation_,
            tile_hh=0,
            tile_coco=0,
            tile_mm=0,
            tile_kk=0,
            tile_nn=0,
            bypass_l1=False,
            use_bias=False,
            block_size=16,
            conv_dtype='float16'):

    # input shape (NCHW -> NC1HWC0)
    in_n, in_c, in_h, in_w = fmap_shape
    in_c = (in_c + block_size - 1) // block_size * block_size
    # kernel shape (NCHW -> NC1HWC0 -> Fractal)
    k_n, k_c, k_h, k_w = filter_shape
    k_c = (k_c + block_size - 1) // block_size * block_size
    k_n = (k_n + block_size - 1) // block_size * block_size

    input_shape_nc1hwc0 = (in_n, in_c // block_size, in_h, in_w, block_size)
    in_n, _, in_h, in_w, _ = input_shape_nc1hwc0

    kernel_shape_nc1hwc0 = (k_n, k_c // block_size, k_h, k_w, block_size)
    k_n, _, k_h, k_w, _ = kernel_shape_nc1hwc0
    kernel_shape_fractal = (k_c // block_size * k_h * k_w, k_n // block_size,
                            block_size, block_size)

    # A placeholder (NC1HWCO)
    A = akg.tvm.placeholder(input_shape_nc1hwc0,
                            dtype=conv_dtype,
                            name="input0")
    # B_placeholder (fractal)
    B = akg.tvm.placeholder(kernel_shape_fractal,
                            dtype=conv_dtype,
                            name="input1")

    if use_bias:
        bias_shape_nc1hwc0 = (1, k_n // block_size, 1, 1, block_size)
        bias_name = "input2"
        bias_value = akg.tvm.placeholder(bias_shape_nc1hwc0,
                                         dtype=conv_dtype,
                                         name=bias_name)
    else:
        bias_name = 'None'
        bias_value = None

    conv_forward = conv_compute_forward(fmap_shape, filter_shape, pad_,
                                        stride_, dilation_, A, B, bias_value,
                                        tile_hh, tile_coco, tile_mm, tile_kk,
                                        tile_nn, bypass_l1, use_bias,
                                        block_size, conv_dtype)

    k_hw = k_h * k_w
    const_shift = k_hw - 1

    # B in Fractal format; result in Fractal format
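    # (kernel positions are reversed via const_shift - i0 % k_hw, i.e. a
    # 180-degree rotation of the k_h x k_w window, while the outer C_in/C_out
    # block axes and the inner (i3, i2) block_size axes are both swapped)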
    def flip_weight(B, k_c, k_hw, const_shift):
        out_shape = (B.shape[1].value * k_hw, k_c // block_size, block_size,
                     block_size)
        B_flip = akg.tvm.compute(
            out_shape,
            lambda i0, i1, i2, i3: B[i1 * k_hw + const_shift - truncmod(
                i0, k_hw),
                                     floordiv(i0, k_hw), i3, i2],
            name=B.name + "_flipped")
        return B_flip

    # H in 5D format; result in 5D format
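    # (inserts s_h - 1 / s_w - 1 zeros between the rows/columns of H, so a
    # stride-1 convolution over the result matches the stride-s backward pass)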
    def strided_head(H, s_h, s_w):
        n, c1, h, w, c0 = H.shape
        out_shape = (n, c1, (h - 1) * s_h + 1, (w - 1) * s_w + 1, c0)
        H_strided = akg.tvm.compute(
            out_shape,
            lambda i0, i1, i2, i3, i4: akg.tvm.expr.Select(
                akg.tvm.any(truncmod(i2, s_h) != 0,
                            truncmod(i3, s_w) != 0),
                akg.tvm.const(0.0, dtype="float16"), H[i0, i1,
                                                       floordiv(i2, s_h),
                                                       floordiv(i3, s_w), i4]),
            name=H.name + "_strided")

        return H_strided

    # A in 5D format; result in 5D format
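    # (swaps batch and channel: (N, C//bs, H, W, bs) -> (C, N//bs, H, W, bs))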
    def transpose_data(A):
        out_shape = (A.shape[1].value * block_size,
                     A.shape[0].value // block_size, A.shape[2].value,
                     A.shape[3].value, block_size)

        A_transpose = akg.tvm.compute(
            out_shape,
            lambda j0, j1, j2, j3, j4: A[j1 * block_size + j4,
                                         floordiv(j0, block_size), j2, j3,
                                         truncmod(j0, block_size)],
            name=A.name + "_transposed")
        return A_transpose

    # Head is in 5D format; result in Fractal format
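    # (reshape 5D -> 6D to split N into blocks, transpose to (N//bs, OH, OW,
    # C1, C0, bs), then reshape to the 4D fractal layout
    # (N//bs * OH * OW, C1, bs, bs))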
    def transpose_convert_head(Head):
        out_shape = ((Head.shape[0].value // block_size) *
                     Head.shape[2].value * Head.shape[3].value,
                     Head.shape[1].value, block_size, block_size)
        tmp_6D_shape = (Head.shape[0].value // block_size, block_size,
                        Head.shape[1].value, Head.shape[2].value,
                        Head.shape[3].value, block_size)
        Head_6D = akg.topi.reshape(Head, tmp_6D_shape)
        Head_6D_transpose = akg.topi.transpose(Head_6D, (0, 3, 4, 2, 5, 1))
        Head_transpose_convert = akg.topi.reshape(Head_6D_transpose, out_shape)
        return Head_transpose_convert

    HEAD = akg.tvm.placeholder(conv_forward.shape,
                               name="Head",
                               dtype='float16')
    Head_transposed_NCHW = (HEAD.shape[1].value * HEAD.shape[4].value,
                            HEAD.shape[0].value, HEAD.shape[2].value,
                            HEAD.shape[3].value)
    s_h, s_w = stride_
    Head_strided_NCHW = (HEAD.shape[0].value,
                         HEAD.shape[1].value * HEAD.shape[4].value,
                         (HEAD.shape[2].value - 1) * s_h + 1,
                         (HEAD.shape[3].value - 1) * s_w + 1)

    A_transposed_NCHW = (in_c, in_n, in_h, in_w)
    K_flip_rot_NCHW = (k_c, k_n, k_h, k_w)

    Head_transposed_converted = transpose_convert_head(HEAD)
    pld_Head_transposed_converted = akg.tvm.placeholder(
        Head_transposed_converted.shape,
        name="Head_trans_fractal",
        dtype=conv_dtype)
    A_transposed = transpose_data(A)
    pld_A_transposed = akg.tvm.placeholder(A_transposed.shape,
                                           name="A_trans",
                                           dtype=conv_dtype)

    info = dim.Dim()
    info.setdim(index=0, axis=0, tilel1=1, tilel0=1)
    info.setdim(index=0, axis=1, tilel1=1, tilel0=1)
    info.setdim(index=0, axis=2, tilel1=1, tilel0=1)
    info.setdim(index=0, axis=3, tilel1=1, tilel0=1)

    B_flip = flip_weight(B, k_c, k_hw, const_shift)
    pld_B_flipped = akg.tvm.placeholder(B_flip.shape,
                                        name="B_flip",
                                        dtype=conv_dtype)

    s_flipped = akg.tvm.create_schedule(B_flip.op)
    with akg.build_config(add_lower_pass=debug_mode(0), dump_pass_ir=True):
        mod_weight_flipped = akg.build(s_flipped, [B, B_flip],
                                       "cce",
                                       name=B.name + "_flipped",
                                       attrs={"dim": str(info)},
                                       polyhedral=True)

    s_transposed_converted = akg.tvm.create_schedule(
        Head_transposed_converted.op)

    with akg.build_config(add_lower_pass=debug_mode(0), dump_pass_ir=True):
        mod_head_transposed_converted = akg.build(
            s_transposed_converted, [HEAD, Head_transposed_converted],
            "cce",
            name="H_trans_converted",
            attrs={"dim": str(info)},
            polyhedral=True)

    Head_strided = strided_head(HEAD, s_h, s_w)
    pld_Head_strided = akg.tvm.placeholder(Head_strided.shape,
                                           name="Head_trans_5D",
                                           dtype=conv_dtype)

    s_strided = akg.tvm.create_schedule(Head_strided.op)
    with akg.build_config(add_lower_pass=debug_mode(0), dump_pass_ir=True):
        mod_head_strided = akg.build(s_strided, [HEAD, Head_strided],
                                     "cce",
                                     name="H_strided",
                                     attrs={"dim": str(info)},
                                     polyhedral=True)

    s_transposed = akg.tvm.create_schedule(A_transposed.op)

    with akg.build_config(add_lower_pass=debug_mode(0), dump_pass_ir=True):
        mod_transposed = akg.build(s_transposed, [A, A_transposed],
                                   "cce",
                                   name="A_transposed",
                                   attrs={"dim": str(info)},
                                   polyhedral=True)

    ad_attrs = {"ad_conv_enable": 1, "ad_conv_reuse_conv": 1}
    jacs = list(
        akg.differentiate(conv_forward, [A], HEAD, ad_attrs,
                          [pld_Head_strided, pld_B_flipped, None]))
    info = set_dims(Head_strided_NCHW, (k_c, k_n, k_h, k_w),
                    (k_h - 1, k_w - 1), (1, 1), (1, 1), tile_hh, tile_coco,
                    tile_mm, tile_kk, tile_nn, block_size)

    sjac = akg.tvm.create_schedule([jacs[0].op])
    with akg.build_config(add_lower_pass=debug_mode(0), dump_pass_ir=True):
        mod_AD_data = akg.build(sjac,
                                [pld_Head_strided, pld_B_flipped, jacs[0]],
                                "cce",
                                name="conv_AD_data",
                                attrs={"dim": str(info)},
                                polyhedral=True)

    conv_data = conv_compute_forward(Head_strided_NCHW, K_flip_rot_NCHW,
                                     (k_h - 1, k_h - 1, k_w - 1, k_w - 1),
                                     (1, 1), (1, 1), pld_Head_strided,
                                     pld_B_flipped, None, tile_hh, tile_coco,
                                     tile_mm, tile_kk, tile_nn, bypass_l1,
                                     use_bias, block_size, conv_dtype)

    info = set_dims(Head_strided_NCHW, (k_c, k_n, k_h, k_w),
                    (k_h - 1, k_w - 1), (1, 1), (1, 1), tile_hh, tile_coco,
                    tile_mm, tile_kk, tile_nn, block_size)

    s_data = akg.tvm.create_schedule(conv_data.op)

    with akg.build_config(add_lower_pass=debug_mode(0), dump_pass_ir=True):
        _ = akg.build(s_data, [pld_Head_strided, pld_B_flipped, conv_data],
                      "cce",
                      name="conv_data",
                      attrs={"dim": str(info)},
                      polyhedral=True)

    ad_attrs = {"ad_conv_enable": 1, "ad_conv_reuse_conv": 1}
    jacs = list(
        akg.differentiate(
            conv_forward, [B], HEAD, ad_attrs,
            [pld_A_transposed, pld_Head_transposed_converted, None]))
    info = set_dims(A_transposed_NCHW, Head_transposed_NCHW, (0, 0), (1, 1),
                    (s_h, s_w), tile_hh, tile_coco, tile_mm, tile_kk, tile_nn,
                    block_size)

    sjac = akg.tvm.create_schedule([jacs[0].op])
    with akg.build_config(add_lower_pass=debug_mode(0), dump_pass_ir=True):
        mod_AD_weight = akg.build(
            sjac, [pld_A_transposed, pld_Head_transposed_converted, jacs[0]],
            "cce",
            name="conv_AD_weight",
            attrs={"dim": str(info)},
            polyhedral=True)

    conv_weight = conv_compute_forward(
        A_transposed_NCHW, Head_transposed_NCHW, (0, 0, 0, 0), (1, 1),
        (s_h, s_w), pld_A_transposed, pld_Head_transposed_converted, None,
        tile_hh, tile_coco, tile_mm, tile_kk, tile_nn, bypass_l1, use_bias,
        block_size, conv_dtype)

    info = set_dims(A_transposed_NCHW, Head_transposed_NCHW, (0, 0), (1, 1),
                    (s_h, s_w), tile_hh, tile_coco, tile_mm, tile_kk, tile_nn,
                    block_size)

    s_weight = akg.tvm.create_schedule(conv_weight.op)

    with akg.build_config(add_lower_pass=debug_mode(0), dump_pass_ir=True):
        akg.build(
            s_weight,
            [pld_A_transposed, pld_Head_transposed_converted, conv_weight],
            "cce",
            name="conv_weight",
            attrs={"dim": str(info)},
            polyhedral=True)

    return mod_AD_data, mod_AD_weight, mod_transposed, mod_head_transposed_converted, mod_head_strided, mod_weight_flipped
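A minimal invocation sketch (hypothetical shapes; assumes an Ascend/CCE build environment, with illustrative tile_* values rather than tuned ones):

mods = conv_02(fmap_shape=(1, 16, 16, 16), filter_shape=(16, 16, 3, 3),
               pad_=(1, 1), stride_=(1, 1), dilation_=(1, 1),
               tile_hh=16, tile_coco=16, tile_mm=16, tile_kk=16, tile_nn=16)
(mod_AD_data, mod_AD_weight, mod_transposed, mod_head_transposed_converted,
 mod_head_strided, mod_weight_flipped) = mods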
Example No. 18
File: base.py Project: zhuyawen/akg
    def get_dim_info(self, arg, is_conv=False):
        info = dim.Dim()
        tile_dims = []
        dims = None
        enable_multicore = None
        dynamic = False
        partial_dynamic = False
        bypass_l1 = False
        if "dynamic" in arg:
            dynamic = True
            if isinstance(arg, tuple):
                arg = list(arg)
                arg.remove("dynamic")
                arg = tuple(arg)
            else:
                arg.remove("dynamic")
        if "partial_dynamic" in arg:
            partial_dynamic = True
            arg.remove("partial_dynamic")
        if "bypassL1" in arg:
            bypass_l1 = True
            arg.remove("bypassL1")
        if is_conv:
            dy = dynamic or partial_dynamic
            if len(arg) == 4:
                conv_tile = arg[3]
                if len(conv_tile) > 0:
                    if not dy:
                        return {
                            "dim": str(info),
                            "conv_tile": conv_tile,
                            "enable_multicore": True,
                            "bypass": 1 if bypass_l1 else 0,
                        }
                    else:
                        return {
                            "dim": str(info),
                            "conv_tile": conv_tile,
                            "dynamic": dynamic,
                            "partial_dynamic": partial_dynamic,
                            "bypass": 1 if bypass_l1 else 0,
                        }
            elif dy and len(arg) == 3:
                return {
                    "dynamic": dynamic,
                    "partial_dynamic": partial_dynamic,
                    "bypass": 1 if bypass_l1 else 0,
                }

        if len(arg) == 5 and not arg[-1]:
            dims = arg[3]
            for d in range(len(dims)):
                tile_dims.append(dims[d][0])
        elif (len(arg) == 5 and arg[-1]) or len(arg) == 4:
            if isinstance(arg[3], (bool, int)):  # only multicore info
                enable_multicore = arg[3]
            elif isinstance(arg[3][-1],
                            (bool, int)):  # dim info and multicore info
                enable_multicore = arg[3][-1]
                dims = arg[3][0]
            else:  # only dim info
                dims = arg[3]
            if dims is not None:
                for i in range(len(dims)):
                    if (isinstance(dims[i][0], int)):
                        # only one index, ((l1,l0),(l1,l0),...)
                        i_dims = dims
                    else:
                        # multiple indices, (((l1,l0),(l1,l0),...), ((l1,l0),(l1,l0),...))
                        i_dims = dims[i]

                    for d in range(len(i_dims)):
                        info.setdim(index=i,
                                    axis=d,
                                    tilel1=i_dims[d][0],
                                    tilel0=i_dims[d][1])

        if len(arg) == 5 and not arg[-1]:
            return {"tile": tile_dims}
        else:
            res = {"dim": str(info), "dynamic": dynamic}
            if enable_multicore:
                res["enable_multicore"] = enable_multicore
            return res
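The arg layouts this helper accepts, as a sketch inferred from the branches above:

# is_conv=True and len(arg) == 4: arg[3] is a conv_tile tuple -> conv path
# len(arg) == 5 with a falsy last element: arg[3] holds per-axis dims -> {"tile": ...}
# len(arg) == 4, or len(arg) == 5 with a truthy last element: arg[3] is either
#   a bool/int (multicore flag only), a (dims, multicore) pair, or bare dim
#   info: ((l1, l0), ...) for one index, or a tuple of such tuples per index
# the strings "dynamic", "partial_dynamic" and "bypassL1" toggle flags and are
#   stripped from arg before the length checks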
Example No. 19
def group_conv_ad(_n, _h, _w, _c_i, _c_o, group, _k_h, _k_w, pad_h, pad_w, _s_h, _s_w,
                  cut_h, cut_co, cut_m, cut_k, cut_n, block_size, use_bias=False, kernel_name='group_conv'):
    conv_dtype = 'float16'
    _a = akg.tvm.placeholder((_n, _c_i // block_size, _h, _w, block_size), name="input0", dtype=conv_dtype)
    _b = akg.tvm.placeholder(((_c_i // group) // block_size * _k_h * _k_w, _c_o // block_size, block_size, block_size),
                             name="input1", dtype=conv_dtype)

    mod_forward = group_conv_forward(_n, _h, _w, _c_i, _c_o, group, _k_h, _k_w, _a, _b, None,
                                     pad_h, pad_w, _s_h, _s_w, cut_h, cut_co, cut_m, cut_k, cut_n, block_size)
    _o_h = mod_forward.shape[2].value
    _o_w = mod_forward.shape[3].value

    head = akg.tvm.placeholder(mod_forward.shape, name="head", dtype=conv_dtype)
    # (_n,_c_o,_o_h,_o_w)--(stride)-->(_n,_c_o,(_o_h-1)*_s_h+1,
    # (_o_w-1)*_s_w+1)--(5d)-->(_n,_c_o/16,(_o_h-1)*_s_h+1,(_o_w-1)*_s_w+1,16)
    pld_head_strided = akg.tvm.placeholder(
        (_n, _c_o // block_size, (_o_h - 1) * _s_h + 1, (_o_w - 1) * _s_w + 1, block_size),
        name="head_strided_5d", dtype=conv_dtype)

    # (_c_o,_c_i//group,_k_h,_k_w)--(flip)-->
    # (_c_i,_c_o//group,_k_h,_k_w)--(Fractal)-->((_c_o//group)/16*_k_h*_k_w, _c_i/16,16,16)
    pld_b_flipped = akg.tvm.placeholder(
        ((_c_o // group) // block_size * _k_h * _k_w, _c_i // block_size, block_size, block_size),
        name="b_flip", dtype=conv_dtype)

    # b in Fractal format; result in Fractal format
    b_group_flipped = group_flip_weight(_b, _k_h, _k_w, group, _c_o // group // block_size, _c_i // group // block_size, block_size)
    s_gr_fl = akg.tvm.create_schedule([b_group_flipped.op])
    info = dim.Dim()
    info.setdim(index=0, axis=0, tilel1=1, tilel0=1)
    info.setdim(index=0, axis=1, tilel1=1, tilel0=1)
    info.setdim(index=0, axis=2, tilel1=1, tilel0=1)
    info.setdim(index=0, axis=3, tilel1=1, tilel0=1)

    with akg.build_config(add_lower_pass=debug_mode(0), dump_pass_ir=False):
        mod_b_group_flip = akg.build(s_gr_fl, [_b, b_group_flipped], "cce", name="b_group_flip",
                                    attrs={"dim": str(info)}, polyhedral=True)

    head_strided = strided_head(head, _s_h, _s_w)
    s_striding = akg.tvm.create_schedule(head_strided.op)

    with akg.build_config(add_lower_pass=debug_mode(0), dump_pass_ir=False):
        mod_head_strided = akg.build(s_striding, [head, head_strided], "cce", name="h_strided",
                                    attrs={"dim": str(info)}, polyhedral=True)

    a_transposed = transpose_regroup(_a, block_size, group)
    s_transposed_nc = akg.tvm.create_schedule(a_transposed.op)
    info = dim.Dim()
    info.setdim(index=0, axis=0, tilel1=16, tilel0=16)
    info.setdim(index=0, axis=1, tilel1=1, tilel0=1)
    info.setdim(index=0, axis=2, tilel1=1, tilel0=1)
    info.setdim(index=0, axis=3, tilel1=1, tilel0=1)

    with akg.build_config(add_lower_pass=debug_mode(0), dump_pass_ir=True):
        mod_transposed_nc = akg.build(s_transposed_nc, [_a, a_transposed], "cce", name="a_transposed",
                                     attrs={"dim": str(info)}, polyhedral=True)

    head_transposed_convert = transpose_convert_head(head, block_size)
    s_transposed_convert = akg.tvm.create_schedule(head_transposed_convert.op)
    info = dim.Dim()
    info.setdim(index=0, axis=0, tilel1=1, tilel0=1)
    info.setdim(index=0, axis=1, tilel1=1, tilel0=1)
    info.setdim(index=0, axis=2, tilel1=1, tilel0=1)
    info.setdim(index=0, axis=3, tilel1=1, tilel0=1)

    with akg.build_config(add_lower_pass=debug_mode(0), dump_pass_ir=True):
        mod_transposed_convert = akg.build(s_transposed_convert, [head, head_transposed_convert], "cce",
                                           name="a_transposed", attrs={"dim": str(info)}, polyhedral=True)

    # Begin with the ad kernels
    ad_attrs = {"ad_conv_enable": 1}
    _jacs_data = list(akg.differentiate(mod_forward, [_a], head, ad_attrs, [pld_head_strided, pld_b_flipped, None]))

    cut_h_e, cut_co_e, cut_m_e, cut_k_e, cut_n_e = ((_o_h - 1) * _s_h + 1 + 2 * (_k_h - 1 - pad_h), 16, _h * _w, 48, 16)
    cut_m_e = ((cut_m_e + block_size - 1) // block_size) * block_size
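    # cut_h_e spans the strided head height (_o_h - 1) * _s_h + 1 plus the
    # implicit transposed-conv padding _k_h - 1 - pad_h on each side.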

    info = set_dims_group(cut_h_e, cut_co_e, cut_m_e, cut_k_e, cut_n_e,
                          expr_to_int(_a.shape), _c_o, _c_i, group, _k_h, _k_w, _s_h, block_size)

    s_data = akg.tvm.create_schedule([_jacs_data[0].op])
    # low_data = akg.lower(s_data, [pld_head_strided, pld_b_flipped, _jacs_data[0]], simple_mode=True)

    with akg.build_config(add_lower_pass=debug_mode(0), dump_pass_ir=False):
        mod_ad_data = akg.build(s_data, [pld_head_strided, pld_b_flipped, _jacs_data[0]], "cce",
                                name="conv_ad_data", attrs={"dim": info}, polyhedral=True)

    # (_n,_c_i,_h,_w)--(trans)-->(_c_i,_n,_h,_w)--(regroup)-->
    # (_c_i//group,_n*group,_h,_w)--(5d)-->(_c_i//group,(_n*group)/16,_h,_w,16)
    pld_x_trans = akg.tvm.placeholder((_c_i // group, (_n * group) // block_size, _h, _w, block_size),
                                      name="x_trans_5d", dtype=conv_dtype)

    # (_n,_c_o,_o_h,_o_w)--(trans)-->
    # (_c_o,_n,_o_h,_o_w)--(Fractal)-->(_n/16*_o_h*_o_w, _c_o/16,16,16)
    pld_head_trans_converted = akg.tvm.placeholder((_n // block_size * _o_h * _o_w, _c_o // block_size, block_size, block_size),
                                                   name="head_trans_convert", dtype=conv_dtype)

    # ad_attrs = {"ad_conv_enable": 1}
    _jacs_weights = list(akg.differentiate(mod_forward, [_b], head, ad_attrs,
                                           [pld_x_trans, pld_head_trans_converted, None]))

    cut_h_e, cut_co_e, cut_m_e, cut_k_e, cut_n_e = (_h + 2 * pad_h, 16, _k_h * _k_w, 48, 16)
    cut_m_e = ((cut_m_e + block_size - 1) // block_size) * block_size

    info = set_dims_group(cut_h_e, cut_co_e, cut_m_e, cut_k_e, cut_n_e,
                          (_c_i // group, _c_o // block_size, _k_h, _k_w, block_size),
                          _n * group, _c_o, group, _o_h, _o_w, 1, block_size)

    s_weights = akg.tvm.create_schedule([_jacs_weights[0].op])

    with akg.build_config(add_lower_pass=debug_mode(0), dump_pass_ir=True):
        mod_ad_weights = akg.build(s_weights, [pld_x_trans, pld_head_trans_converted, _jacs_weights[0]], "cce",
                                   name="conv_ad_weights", attrs={"dim": info}, polyhedral=True)

    print("Forward input data shape: ", _a.shape)
    print("Forward input weight shape: ", _b.shape)
    print("Forward output shape: ", mod_forward.shape)
    print("Backward wrt. DATA input data shape: ", pld_head_strided.shape)
    print("Backward wrt. DATA input weight shape: ", pld_b_flipped.shape)
    print("Backward wrt. DATA output shape: ", _jacs_data[0].shape)
    print("Backward wrt. WEIGHT input data shape: ", pld_x_trans.shape)
    print("Backward wrt. WEIGHT input weight shape: ", pld_head_trans_converted.shape)
    print("Backward wrt. WEIGHT output shape: ", _jacs_weights[0].shape)

    return mod_ad_data, mod_ad_weights, mod_b_group_flip, mod_head_strided, mod_transposed_nc, mod_transposed_convert
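A hypothetical invocation sketch (shapes chosen to satisfy the divisibility assumptions above, e.g. _c_i // block_size // group == 1):

mods = group_conv_ad(_n=1, _h=16, _w=16, _c_i=32, _c_o=32, group=2,
                     _k_h=3, _k_w=3, pad_h=1, pad_w=1, _s_h=1, _s_w=1,
                     cut_h=5, cut_co=16, cut_m=256, cut_k=48, cut_n=16,
                     block_size=16)
mod_ad_data, mod_ad_weights = mods[0], mods[1]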
Example No. 20
def cast_conv_set_dim_func(data,
                           fmap_shape,
                           filter_shape,
                           pad_,
                           stride_,
                           dilation_,
                           use_bias=False,
                           block_size=16,
                           attrs=None):

    if isinstance(stride_, int):
        stride_ = [stride_] * 2
    elif isinstance(stride_, (list, tuple)) and len(stride_) == 1:
        stride_ = list(stride_) * 2
    elif isinstance(stride_, (list, tuple)) and len(stride_) == 2:
        pass
    else:
        raise RuntimeError('illegal stride parameter: %s' % str(stride_))

    if isinstance(pad_, int):
        pad_ = [pad_] * 4
    elif isinstance(pad_, (list, tuple)) and len(pad_) == 1:
        pad_ = list(pad_) * 4
    elif isinstance(pad_, (list, tuple)) and len(pad_) == 4:
        pass
    else:
        raise RuntimeError('illegal pad parameter: %s' % str(pad_))

    if isinstance(dilation_, int):
        dilation_ = [dilation_] * 2
    elif isinstance(dilation_, (list, tuple)) and len(dilation_) == 1:
        dilation_ = list(dilation_) * 2
    elif isinstance(dilation_, (list, tuple)) and len(dilation_) == 2:
        pass
    else:
        raise RuntimeError('illegal dilation parameter: %s' % str(dilation_))

    key = []

    key.append(tuple(fmap_shape))
    key.append(tuple(filter_shape))
    key.append(tuple(pad_))
    key.append(tuple(stride_))
    key.append(tuple(dilation_))

    hash_key = str(tuple(key))

    # input shape (NCHW -> NC1HWC0)
    in_n, in_c, in_h, in_w = fmap_shape
    in_c = (in_c + block_size - 1) // block_size * block_size

    # kernel shape (NCHW -> NC1HWC0 -> Fractal)
    k_n, k_c, k_h, k_w = filter_shape
    k_c = (k_c + block_size - 1) // block_size * block_size
    k_n = (k_n + block_size - 1) // block_size * block_size

    # padding((padding_h, padding_w) -> (padding_top, padding_bottom, padding_left, padding_right))
    padding = (pad_[0], pad_[0], pad_[1], pad_[1])
    p_top, p_bottom, p_left, p_right = padding

    # stride (stride_h, stride_w)
    s_h, s_w = stride_

    # dilation (dilation_h, dilation_w)
    d_h, d_w = dilation_

    k_w_d = (k_w - 1) * d_w + 1
    out_w = (in_w + p_left + p_right - k_w_d) // (s_w) + 1
    bypass_list = [0, 1]
    bypass = 0
    if attrs is not None and 'conv_tile' in attrs and len(
            attrs['conv_tile']) >= 5:
        tile_hh = attrs['conv_tile'][0]
        tile_coco = attrs['conv_tile'][1]
        tile_mm = attrs['conv_tile'][2]
        tile_kk = attrs['conv_tile'][3]
        tile_nn = attrs['conv_tile'][4]
        if len(attrs['conv_tile']) > 5:
            tile_ww = attrs['conv_tile'][5]
        else:
            tile_ww = (out_w - 1) * s_w + k_w_d
        if 'bypass' in attrs:
            bypass = attrs['bypass']
    elif hash_key in cast_conv_set_dim_map:
        configs = cast_conv_set_dim_map[hash_key]
        if isinstance(configs, tuple):
            tiles = configs[0]
            if "bypass" in configs[1]:
                bypass = configs[1]["bypass"]
        else:
            tiles = configs
        if len(tiles) > 5:
            tile_hh, tile_coco, tile_mm, tile_kk, tile_nn, tile_ww = tiles
        else:
            tile_hh, tile_coco, tile_mm, tile_kk, tile_nn = tiles
            tile_ww = (out_w - 1) * s_w + k_w_d
    else:
        tile_hh = (k_h - 1) * d_h + 1 + p_top * s_h
        tile_ww = (out_w - 1) * s_w + k_w_d
        tile_coco = 16
        tile_mm = 16
        tile_kk = 16
        tile_nn = 16
    if bypass not in bypass_list:
        raise RuntimeError("conv_cce only supports bypass in [%s], got %d" %
                           (",".join(str(b) for b in bypass_list), bypass))

    if (tile_hh == in_h):
        tile_hh += p_top + p_bottom
    tile_coco = (tile_coco + block_size - 1) // block_size * block_size
    tile_mm = (tile_mm + block_size - 1) // block_size * block_size
    tile_kk = (tile_kk + block_size - 1) // block_size * block_size
    tile_nn = (tile_nn + block_size - 1) // block_size * block_size

    c0 = block_size
    c1_cut = tile_coco // c0
    h_window_cut = (tile_hh - k_h) // s_h + 1

    out_w = (in_w + p_left + p_right - k_w) // (s_w) + 1

    input_shape_nc1hwc0 = (in_n, in_c // block_size, in_h, in_w, block_size)
    in_n, in_c1, in_h, in_w, in_c0 = input_shape_nc1hwc0

    kernel_shape_nc1hwc0 = (k_n, k_c // block_size, k_h, k_w, block_size)
    k_n, k_c1, k_h, k_w, k_c0 = kernel_shape_nc1hwc0

    k_h_d = (k_h - 1) * d_h + 1
    k_w_d = (k_w - 1) * d_w + 1
    out_h = (in_h + p_top + p_bottom - k_h_d) // (s_h) + 1
    tile_out_h = (tile_hh - k_h_d) // s_h + 1
    out_w = (in_w + p_left + p_right - k_w_d) // (s_w) + 1
    tile_out_w = (tile_ww - k_w_d) // s_w + 1
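    # e.g. in_h = 32, p_top = p_bottom = 1, k_h_d = 3, s_h = 1 gives
    # out_h = (32 + 2 - 3) // 1 + 1 = 32 (a "same"-padded 3x3 convolution)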

    out_shape_nc1hwc0 = (in_n, k_n // block_size, out_h, out_w, block_size)
    out_n, out_c1, out_h, out_w, out_c0 = out_shape_nc1hwc0

    if (tile_coco > 0):
        c1_cut = tile_coco // block_size
    else:
        c1_cut = out_c1

    # set dim
    info = dim.Dim()
    if (out_n > 1):
        info.setdim(index=0, axis=0, tilel1=1, tilel0=0)  # n
    if (out_c1 > 1):
        info.setdim(index=0, axis=0, tilel1=c1_cut, tilel0=0)  # c1
    if (out_h > 1):
        info.setdim(index=0, axis="H", tilel1=tile_out_h, tilel0=0)  # h
    if (out_w > 1):
        info.setdim(index=0, axis="W", tilel1=tile_out_w, tilel0=0)  # w
    if (out_c0 > 1):
        info.setdim(index=0, axis=4, tilel1=out_c0, tilel0=0)  # c0

    if (in_c1 > 1):
        info.setdim(index=0, axis=5, tilel1=in_c1, tilel0=0)  # kc1
    if (k_h > 1):
        info.setdim(index=0, axis=5, tilel1=k_h, tilel0=0)  # kh
    if (k_w > 1):
        info.setdim(index=0, axis=5, tilel1=k_w, tilel0=0)  # kw

    return str(info)  # alternatively: ct_util.set_dims_by_key(hash_key, conv_set_dim_map)
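A minimal usage sketch (hypothetical shapes; data is not referenced in the body above, so None stands in for the tensor):

dim_str = cast_conv_set_dim_func(None, fmap_shape=(1, 16, 32, 32),
                                 filter_shape=(16, 16, 3, 3),
                                 pad_=1, stride_=1, dilation_=1)
# dim_str serializes per-axis (tilel1, tilel0) pairs for akg.build's "dim" attr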