Example #1
def CusMatMulCube(input_x1,
                  input_x2,
                  bias=None,
                  output_y={},
                  trans_a=False,
                  trans_b=False,
                  kernel_name="matmulcube"):
    """
    calculating  matrix multiplication with bias, C = A*B + bias, support input
    data with fractal format.

    Parameters:
    shape_a: list or tuple
            Shape of the first tensor a with rank > 1
    shape_b:  list or tuple
            Shape of the second tensor b with the same type with a,
            and shape_a, shape_b must be 2 dims
    src_dtype: str
            The data type of input, support "float32", "float16"
    dst_dtype: str
            The data type of output, support "float32", "float16"
    trans_a: bool
            If True, shape_a is transposed before multiplication
    trans_b: bool
            If True, shape_b is transposed before multiplication
    is_fractal: bool
            If True, the input data format of a and b must be fractal format
    shape_bias: list or tuple
            Shape of bias, only support the input data format with ND

    Returns
    -------
    None
    """
    shape_a = input_x1.get("ori_shape")
    shape_b = input_x2.get("ori_shape")

    if shape_a is not None:
        if len(shape_a) < 2:
            shape_a = input_x1.get("shape")

    if shape_b is not None:
        if len(shape_b) < 2:
            shape_b = input_x2.get("shape")

    shape_a = list(shape_a)
    shape_b = list(shape_b)

    if input_x1.get("format") == "FRACTAL_NZ":
        shape_a = _get_input_shape(shape_a)
        shape_b = _get_input_shape(shape_b)

    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape_a)
    util.check_shape_rule(shape_b)
    util.check_shape_size(shape_a, SHAPE_SIZE_LIMIT)
    util.check_shape_size(shape_b, SHAPE_SIZE_LIMIT)

    if input_x1.get("format") == "FRACTAL_NZ":
        shape_a = [shape_a[1], shape_a[0]]
        trans_a = bool(1 - trans_a)

    if input_x2.get("format") == "FRACTAL_NZ":
        shape_b = [shape_b[1], shape_b[0]]
        trans_b = bool(1 - trans_b)

    shape_bias = ()
    if bias is not None and bool(bias):
        shape_bias = bias.get("shape")
        shape_bias = list(shape_bias)
        shape_bias = _get_bias(shape_bias)

    src_dtype = input_x1.get("dtype").lower()
    dst_dtype = output_y.get("dtype").lower()
    if src_dtype in ("float32", "int32"):
        matmul_vector_cce(shape_a, shape_b, src_dtype, trans_a, trans_b,
                          shape_bias, kernel_name)
        return
    _shape_check(shape_a, shape_b, shape_bias, src_dtype, trans_a, trans_b)
    m_shape = shape_a[len(shape_a) - 2]
    km_shape = shape_a[len(shape_a) - 1]
    kn_shape = shape_b[len(shape_b) - 2]
    n_shape = shape_b[len(shape_b) - 1]

    if src_dtype == "float16":
        block_reduce = cce.BLOCK_REDUCE

    block_in = cce.BLOCK_IN
    block_out = cce.BLOCK_OUT

    if trans_a and km_shape == 1:
        block_in = cce.BLOCK_VECTOR

    if not trans_a and m_shape == 1:
        block_in = cce.BLOCK_VECTOR

    if trans_b and kn_shape == 1:
        block_out = cce.BLOCK_VECTOR

    if not trans_b and n_shape == 1:
        block_out = cce.BLOCK_VECTOR

    if trans_a:
        shape_a_temp = (m_shape // block_reduce, km_shape // block_in,
                        block_reduce, block_in)
    else:
        shape_a_temp = (m_shape // block_in, km_shape // block_reduce,
                        block_in, block_reduce)

    if trans_b:
        shape_b_temp = (kn_shape // block_out, n_shape // block_reduce,
                        block_reduce, block_out)
    else:
        shape_b_temp = (kn_shape // block_reduce, n_shape // block_out,
                        block_out, block_reduce)

    if input_x1.get("format") == "FORMAT_FRACTAL_Z":
        shape_a_temp = (shape_a_temp[0], shape_a_temp[1], shape_a_temp[2],
                        shape_a_temp[3])
        format_a = "fractal"
    elif input_x1.get("format") == "FRACTAL_NZ":
        shape_a_temp = (shape_a_temp[0], shape_a_temp[1], shape_a_temp[2],
                        shape_a_temp[3])
        format_a = "FRACTAL_NZ"
    else:
        shape_a_temp = (shape_a[len(shape_a) - 2], shape_a[len(shape_a) - 1])
        format_a = "ND"

    if input_x2.get("format") == "FORMAT_FRACTAL_Z":
        shape_b_temp = (shape_b_temp[0], shape_b_temp[1], shape_b_temp[2],
                        shape_b_temp[3])
        format_b = "fractal"
    elif input_x2.get("format") == "FRACTAL_NZ":
        shape_b_temp = (shape_b_temp[0], shape_b_temp[1], shape_b_temp[2],
                        shape_b_temp[3])
        format_b = "FRACTAL_NZ"
    else:
        shape_b_temp = (shape_b[len(shape_b) - 2], shape_b[len(shape_b) - 1])
        format_b = "ND"

    tensor_bias = None
    tensor_a = tvm.placeholder(shape_a_temp, name='tensor_a', dtype=src_dtype)
    tensor_b = tvm.placeholder(shape_b_temp, name='tensor_b', dtype=src_dtype)

    if shape_bias:
        tensor_bias = tvm.placeholder(shape_bias,
                                      name='tensor_bias',
                                      dtype=dst_dtype)
    result = te.lang.cce.matmul(tensor_a,
                                tensor_b,
                                trans_a,
                                trans_b,
                                format_a=format_a,
                                format_b=format_b,
                                dst_dtype=dst_dtype,
                                tensor_bias=tensor_bias)

    with tvm.target.cce():
        schedule = generic.auto_schedule(result)

    tensor_list = [tensor_a, tensor_b, result]
    if shape_bias:
        tensor_list = [tensor_a, tensor_b, tensor_bias, result]

    config = {
        "print_ir": False,
        "name": kernel_name,
        "tensor_list": tensor_list
    }

    te.lang.cce.cce_build_code(schedule, config)
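
The fractal repacking above (shape_a_temp / shape_b_temp) can be illustrated outside the TBE stack. Below is a minimal sketch, assuming the cube block sizes cce.BLOCK_IN and cce.BLOCK_REDUCE are both 16 (their usual values for float16 on the Ascend cube unit), of how a plain (M, K) shape maps to the 4-D fractal shape used for tensor_a:

# Illustration only: mirrors the shape_a_temp computation above with assumed
# block sizes of 16.
BLOCK_IN = BLOCK_REDUCE = 16

def to_fractal_shape(m_shape, km_shape, trans_a=False):
    """Return the 4-D fractal shape used for tensor_a in CusMatMulCube."""
    if trans_a:
        return (m_shape // BLOCK_REDUCE, km_shape // BLOCK_IN,
                BLOCK_REDUCE, BLOCK_IN)
    return (m_shape // BLOCK_IN, km_shape // BLOCK_REDUCE,
            BLOCK_IN, BLOCK_REDUCE)

print(to_fractal_shape(128, 64))        # (8, 4, 16, 16)
print(to_fractal_shape(128, 64, True))  # same numbers here, block roles swapped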
Example #2
def check_conv3dbp_input_params(
        shape_filter,  # pylint:disable=R0913,R0914,R0915
        shape_out_backprop,
        input_sizes,
        strides,
        pads,
        dilations,
        filter_dtype,
        out_backprop_dtype,
        res_dtype,
        kernel_name):
    """
    The params check function of conv3d backprop input

    Parameters:
    -------------------------
    shape_filter : The shape of filter.
                   5-D with shape (depth, height, width, batch, channels)

    shape_out_backprop : The shape of gradients.
                         5-D with shape [batch, depth, height, width, channels]

    input_sizes : The shape of feature map.
                  5-D with shape [batch, depth, height, width, channels].

    strides : A list of ints. The stride of the sliding window.

    pads : A list of ints.

    dilations : An optional list of ints. Only support [1, 1, 1, 1, 1] now.

    filter_dtype : The dtype of filter data. Default value is float16.

    out_backprop_dtype : The dtype of gradients data. Default value is float16.

    res_dtype : The dtype of result (De/Dx) data. Default value is float16.

    kernel_name : Cce kernel name.
                  Default value is "conv3d_backprop_input_cce"

    Returns : All transformed params.


    """
    def _check_attr_range(attr_name, attr_value, attr_min, attr_max):
        if attr_value < attr_min or attr_value > attr_max:
            dict_args = {
                'errCode': 'E60011',
                'range': '[{},{}]'.format(attr_min, attr_max),
                'attr_name': attr_name,
                'value': str(attr_value)
            }
            raise RuntimeError(dict_args,
                               err_mana.get_error_message(dict_args))

    def _check_64bits_limitation(attr_name, attr_value, dtype=None):
        if dtype is None:
            bit_ratio = BIT_RATIO_DICT.get("float16")
        else:
            bit_ratio = BIT_RATIO_DICT.get(dtype)
        if attr_value * bit_ratio > DATA_SIZE_MAX:
            dict_args = {
                'errCode': 'E60020',
                'attr_name': attr_name,
            }
            raise RuntimeError(dict_args,
                               err_mana.get_error_message(dict_args))

    def _check_l1_limitation():
        block_size = 16
        w_value = dedy_w * stride_w
        if fmap_w > block_size:
            h_value_max = filter_h_dilation + 1
        elif block_size % fmap_w == 0:
            h_value_max = filter_h_dilation + block_size // fmap_w - 1
        else:
            h_value_max = filter_h_dilation + block_size // fmap_w + 1

        a_l1_size = h_value_max * w_value * \
                    ((filter_d_dilation - 2)//stride_d + 2) * block_size * 2
        b_l1_size = filter_h_dilation * filter_w_dilation * \
                    filter_d_dilation * block_size * block_size * 2
        l1_size = get_soc_spec("L1_SIZE")
        if (a_l1_size + b_l1_size) > l1_size:
            dict_args = {'errCode': 'E60022'}
            raise RuntimeError(dict_args,
                               err_mana.get_error_message(dict_args))

    def _check_shape_error():
        fmap_h_padding = fmap_h + pad_up + pad_down
        fmap_w_padding = fmap_w + pad_left + pad_right
        fmap_d_padding = fmap_deep + pad_head + pad_tail

        if fmap_channel != filter_channel:
            dict_args = {
                'errCode': 'E60108',
                'reason': "Shape error: Fmap's C must be equal to Filter'C."
            }
            raise RuntimeError(dict_args,
                               err_mana.get_error_message(dict_args))
        if dedy_channel != filter_batch:
            dict_args = {
                'errCode': 'E60108',
                'reason': "Shape error: Dedy's C must be equal to Filter'N."
            }
            raise RuntimeError(dict_args,
                               err_mana.get_error_message(dict_args))
        if fmap_batch != dedy_batch:
            dict_args = {
                'errCode': 'E62503',
                'backprop_N': str(dedy_batch),
                'forward_shape': str(fmap_batch)
            }
            raise RuntimeError(dict_args,
                               err_mana.get_error_message(dict_args))
        if filter_h_dilation > fmap_h_padding:
            dict_args = {
                'errCode': 'E62507',
                'dim': 'H',
                'filter_dila': str(filter_h_dilation),
                'input_pad': str(fmap_h_padding)
            }
            raise RuntimeError(dict_args,
                               err_mana.get_error_message(dict_args))
        if filter_w_dilation > fmap_w_padding:
            dict_args = {
                'errCode': 'E62507',
                'dim': 'W',
                'filter_dila': str(filter_w_dilation),
                'input_pad': str(fmap_w_padding)
            }
            raise RuntimeError(dict_args,
                               err_mana.get_error_message(dict_args))
        if filter_d_dilation > fmap_d_padding:
            dict_args = {
                'errCode': 'E62507',
                'dim': 'D',
                'filter_dila': str(filter_d_dilation),
                'input_pad': str(fmap_d_padding)
            }
            raise RuntimeError(dict_args,
                               err_mana.get_error_message(dict_args))
        if ((fmap_h - filter_h_dilation + pad_up + pad_down) // stride_h +
                1) != dedy_h:
            dict_args = {
                'errCode': 'E60024',
            }
            raise RuntimeError(dict_args,
                               err_mana.get_error_message(dict_args))
        if ((fmap_w - filter_w_dilation + pad_left + pad_right) // stride_w +
                1) != dedy_w:
            dict_args = {
                'errCode': 'E60025',
            }
            raise RuntimeError(dict_args,
                               err_mana.get_error_message(dict_args))
        if ((fmap_deep - filter_d_dilation + pad_head + pad_tail) // stride_d +
                1) != dedy_deep:
            dict_args = {
                'errCode': 'E62508',
            }
            raise RuntimeError(dict_args,
                               err_mana.get_error_message(dict_args))

    # Base check, mainly required by interface appearance
    # ===========================================================
    # util check
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape_filter, CONV_BACKPROP_SHAPE_DIM,
                          CONV_BACKPROP_SHAPE_DIM, DEFAULT_MAX_SHAPE_NUM)
    util.check_shape_rule(shape_out_backprop, CONV_BACKPROP_SHAPE_DIM,
                          CONV_BACKPROP_SHAPE_DIM, DEFAULT_MAX_SHAPE_NUM)
    util.check_shape_rule(input_sizes, CONV_BACKPROP_SHAPE_DIM,
                          CONV_BACKPROP_SHAPE_DIM, DEFAULT_MAX_SHAPE_NUM)
    util.check_shape_rule(strides, STRIDES_SHAPE_DIM, STRIDES_SHAPE_DIM,
                          DEFAULT_MAX_SHAPE_NUM)

    # pads check
    if isinstance(pads, (tuple, list)) and \
            len(pads) != CONV_BACKPROP_PAD_SHAPE_DIM:
        dict_args = {
            'errCode': 'E62501',
            'param_name': 'pads',
        }
        raise RuntimeError(dict_args, err_mana.get_error_message(dict_args))

    if isinstance(pads, str) and pads not in ['SAME', 'VALID']:
        dict_args = {
            'errCode': 'E60000',
            'param_name': 'pads',
            'expected_value': 'SAME or VALID',
            'input_value': str(pads),
        }
        raise RuntimeError(dict_args, err_mana.get_error_message(dict_args))
    # dilations check
    util.check_shape_rule(dilations, CONV_BACKPROP_SHAPE_DIM,
                          CONV_BACKPROP_SHAPE_DIM, DEFAULT_MAX_SHAPE_NUM)
    dilation_n, dilation_d, dilation_h, dilation_w, dilation_c = dilations
    if dilation_n != 1 or dilation_c != 1:
        dict_args = {
            'errCode': 'E60023',
            'dilation_n': str(dilation_n),
            'dilation_c': str(dilation_c),
        }
        raise RuntimeError(dict_args, err_mana.get_error_message(dict_args))

    # dtype check
    filter_dtype = filter_dtype.lower()
    out_backprop_dtype = out_backprop_dtype.lower()
    res_dtype = res_dtype.lower()
    util.check_dtype_rule(filter_dtype, ['float16'])
    util.check_dtype_rule(out_backprop_dtype, ['float16'])
    util.check_dtype_rule(res_dtype, ['float16'])

    # the relation limits between shape
    shape_filter = list(shape_filter)
    shape_out_backprop = list(shape_out_backprop)
    input_sizes = list(input_sizes)
    strides = list(strides)
    fmap_batch, fmap_deep, fmap_h, fmap_w, fmap_channel = input_sizes
    dedy_batch, dedy_deep, dedy_h, dedy_w, dedy_channel = shape_out_backprop
    filter_depth, filter_h, \
    filter_w, filter_channel, filter_batch = shape_filter
    _, stride_d, stride_h, stride_w, _ = strides

    filter_h_dilation = (filter_h - 1) * dilation_h + 1
    filter_w_dilation = (filter_w - 1) * dilation_w + 1
    filter_d_dilation = (filter_depth - 1) * dilation_d + 1

    if pads == 'SAME':
        pad_h = align(fmap_h, stride_h) - stride_h + filter_h - fmap_h
        pad_h = max(pad_h, 0)
        pad_up = pad_h // 2
        pad_down = pad_h - pad_up
        pad_w = align(fmap_w, stride_w) - stride_w + filter_w - fmap_w
        pad_w = max(pad_w, 0)
        pad_left = pad_w // 2
        pad_right = pad_w - pad_left
        pad_d = align(fmap_deep, stride_d)\
                - stride_d + filter_depth - fmap_deep
        pad_d = max(pad_d, 0)
        pad_head = pad_d // 2
        pad_tail = pad_d - pad_head

        pads = [pad_head, pad_tail, pad_up, pad_down, pad_left, pad_right]
    elif pads == "VALID":
        pads = PADDING_VAILD
    # pads compute
    pads = list(pads)
    pad_head, pad_tail, pad_up, pad_down, pad_left, pad_right = pads

    fmap_h_padding = fmap_h + pad_up + pad_down
    fmap_w_padding = fmap_w + pad_left + pad_right

    # special cases
    dey_hw_min, fmap_hw_min = DEDY_HW_MIN, FMAP_HW_MIN
    # limitation by chip:
    # if kernel h/w is in [1, 11] and the padded fmap h/w equals the filter h/w,
    # load3d supports h/w equal to 1
    if (1 <= filter_h <= 11) and (1 <= filter_w <= 11) \
            and (fmap_h_padding == filter_h or fmap_w_padding == filter_w):
        dey_hw_min = 1
        fmap_hw_min = 1
    _check_shape_error()
    _check_l1_limitation()

    # Dedy value limit
    _check_attr_range("Dedy's H after expands", dedy_h * stride_h, dey_hw_min,
                      DEDY_HW_MAX)
    _check_attr_range("Dedy's W after expands", dedy_w * stride_w, dey_hw_min,
                      DEDY_HW_MAX)

    # filter value limit
    _check_attr_range("filter's H", filter_h, FILTER_HW_MIN, FILTER_HW_MAX)
    _check_attr_range("filter's W", filter_w, FILTER_HW_MIN, FILTER_HW_MAX)
    _check_attr_range("filter's D", filter_depth, FILTER_HW_MIN, FILTER_D_MAX)

    _check_attr_range("filter H*W", filter_h * filter_w, FILTER_HW_MIN,
                      FILTER_HW_SIZE)

    _check_attr_range("filter H*W*D", filter_h * filter_w * filter_depth,
                      FILTER_HW_MIN, KHWD_COEFF)

    # Fmap value limit
    _check_attr_range("Fmap's H", fmap_h, fmap_hw_min, FMAP_HW_MAX)
    _check_attr_range("Fmap's W", fmap_w, fmap_hw_min, FMAP_HW_MAX)

    # stride value limit
    _check_attr_range("stride's H", stride_h, STRIDE_HW_MIN, STRIDE_HW_MAX)
    _check_attr_range("stride's W", stride_w, STRIDE_HW_MIN, STRIDE_HW_MAX)
    _check_attr_range("stride's H*W", stride_h * stride_w, STRIDE_HW_MIN,
                      STRIDE_SIZE_MAX)
    _check_attr_range("stride's H*W*D", stride_h * stride_w * stride_d,
                      STRIDE_HW_MIN, STRIDE_SIZE_HWD_MAX)

    # check shape size, 64 bits limitation
    # ===========================================================
    c0_size = cce_params.C0_SIZE
    fmap_size = fmap_batch * align(fmap_channel, c0_size) \
                * fmap_deep * fmap_h * fmap_w
    dedy_size = dedy_batch * align(dedy_channel, c0_size) \
                * dedy_deep * dedy_h * dedy_w
    filter_size = align(filter_batch, c0_size) * \
    align(filter_channel, c0_size) * filter_depth * filter_h * filter_w
    _check_64bits_limitation("input", fmap_size, dtype=res_dtype)
    _check_64bits_limitation("out_backprop",
                             dedy_size,
                             dtype=out_backprop_dtype)
    _check_64bits_limitation("filter", filter_size, dtype=filter_dtype)

    result = (shape_filter, shape_out_backprop, input_sizes, strides, pads,
              dilations, filter_dtype, out_backprop_dtype, res_dtype,
              kernel_name)
    return result
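
The padding branch above implements the standard SAME-padding rule, and the dilated kernel extents follow (k - 1) * dilation + 1. A small worked example in plain Python, assuming align(x, a) rounds x up to a multiple of a as in the function:

# Worked example of the SAME-padding and dilation formulas used above
# (align() here is a local stand-in for the helper used by the op).
def align(value, factor):
    return ((value + factor - 1) // factor) * factor

fmap_h, stride_h, filter_h, dilation_h = 28, 2, 3, 1
filter_h_dilation = (filter_h - 1) * dilation_h + 1                      # 3
pad_h = max(align(fmap_h, stride_h) - stride_h + filter_h - fmap_h, 0)   # 1
pad_up, pad_down = pad_h // 2, pad_h - pad_h // 2                        # 0, 1
dedy_h = (fmap_h - filter_h_dilation + pad_up + pad_down) // stride_h + 1
print(pad_up, pad_down, dedy_h)  # 0 1 14, i.e. ceil(28 / 2) output rows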
Example #3
def custom_minimum(shape1,
                   shape2,
                   dtype,
                   kernel_name="cce_tf_minimum",
                   need_build=False,
                   need_print=False):
    """
    do element-wise minimum operation between two input tensors

    Parameters:
    ----------
    shape1 : shape of input data1

    shape2 : shape of input data2

    dtype : source data type, support float16,float32,int32

    kernel_name : cce kernel name, default value is "cce_tf_minimum"

    need_build : if need to build CCEC kernel, default value is False

    need_print : if need to print the ir, default value is False

    Returns
    -------
    None
    """

    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape1)
    util.check_shape_rule(shape2)
    util.check_shape_size(shape1, SHAPE_SIZE_LIMIT)
    util.check_shape_size(shape2, SHAPE_SIZE_LIMIT)

    check_list = ["float16", "float32", "int32"]

    dtype = dtype.lower()
    if dtype not in check_list:
        raise RuntimeError("tf_minimum_cce only support %s while dtype is %s" %
                           (",".join(check_list), dtype))

    shape1, shape2, shape_max = util.produce_shapes(shape1, shape2)
    util.check_shape_size(shape_max, SHAPE_SIZE_LIMIT)

    data1 = tvm.placeholder(shape1, dtype=dtype, name="data1")
    data2 = tvm.placeholder(shape2, dtype=dtype, name="data2")

    with tvm.target.cce():
        data1_tmp1 = te.lang.cce.broadcast(data1, shape_max)
        data2_tmp1 = te.lang.cce.broadcast(data2, shape_max)
        res = te.lang.cce.vmin(data1_tmp1, data2_tmp1)

        sch = generic.auto_schedule(res)

    config = {
        "print_ir": need_print,
        "need_build": need_build,
        "name": kernel_name,
        "tensor_list": [data1, data2, res]
    }
    te.lang.cce.cce_build_code(sch, config)
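
Functionally, the kernel built above is a broadcast element-wise minimum. A NumPy sketch of the same semantics (NumPy is used here only for illustration and plays no part in the TBE build):

import numpy as np

# Broadcast both inputs to a common shape, then take the element-wise minimum,
# mirroring te.lang.cce.broadcast + vmin above.
a = np.array([[1.0, 5.0, 3.0]], dtype=np.float16)   # shape (1, 3)
b = np.array([[2.0], [0.5]], dtype=np.float16)      # shape (2, 1)
print(np.minimum(a, b))                              # broadcast result, shape (2, 3)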
Example #4
def conv2d_backprop_input_d(
        filter,  # pylint: disable=W0622,C0103,R0913,R0914
        out_backprop,
        y,
        input_size,
        strides,
        pads,
        dilations=(1, 1, 1, 1),
        groups=None,
        data_format="NHWC",
        kernel_name="conv2d_backprop_input"):
    """
    algorithm: conv2d_backprop_input

    Parameters
    ----------
    filter: dict with keys(shape and dtype)
            input weight tensor

    out_backprop: dict with keys(shape and dtype)
                  The shape of gradients.

    y: dict with keys(shape and dtype)
       conv2d_backprop_input output tensor, dtype must be assigned

    input_size: The shape of feature map.
                 4-D with shape [batch, channels, height, width].

    strides: tuple/list of 4 integers
             filter move stride

    pads: tuple/list of 4 integers
             [pad_top, pad_bottom, pad_left, pad_right]

    dilations: tuple/list of 4 integers
               filter expand size of dilated conv2d_backprop_input
    groups: int
            param for group conv2d_backprop_input

    data_format: str
            An optional string from: "NHWC", "NCHW". Defaults to "NHWC".
            Specify the data format of the input and output data.

    kernel_name: str
                 kernel name, default value is "conv2d_backprop_input"

    Returns
    -------
    None
    """

    ori_shape_filters = filter.get("ori_shape")
    ori_shape_out_backprop = out_backprop.get("ori_shape")
    ori_shape_res = y.get("ori_shape")

    filters_dtype = filter.get("dtype")
    out_backprop_dtype = out_backprop.get("dtype")
    res_dtype = y.get("dtype")

    ori_format_filters = filter.get("ori_format")
    ori_format_out_backprop = out_backprop.get("ori_format")
    ori_format_res = y.get("ori_format")
    if list(input_size) != list(ori_shape_res):
        dict_args = {}
        dict_args['errCode'] = "E65007"
        dict_args['param1'] = "input_size"
        dict_args['param2'] = "ori_shape of y"
        dict_args['actual_value'] = "{}, {}". \
            format(input_size, ori_shape_res)
        raise RuntimeError(dict_args, err_man.get_error_message(dict_args))
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(ori_shape_filters, CONV_BACKPROP_SHAPE_DIM,
                          CONV_BACKPROP_SHAPE_DIM, DEFAULT_MAX_SHAPE_NUM)
    util.check_shape_rule(ori_shape_out_backprop, CONV_BACKPROP_SHAPE_DIM,
                          CONV_BACKPROP_SHAPE_DIM, DEFAULT_MAX_SHAPE_NUM)
    util.check_shape_rule(input_size, CONV_BACKPROP_SHAPE_DIM,
                          CONV_BACKPROP_SHAPE_DIM, DEFAULT_MAX_SHAPE_NUM)
    util.check_shape_rule(ori_shape_res, CONV_BACKPROP_SHAPE_DIM,
                          CONV_BACKPROP_SHAPE_DIM, DEFAULT_MAX_SHAPE_NUM)
    util.check_shape_rule(dilations, CONV_BACKPROP_SHAPE_DIM,
                          CONV_BACKPROP_SHAPE_DIM, DEFAULT_MAX_SHAPE_NUM)

    if len(strides) == 4:
        h_index = data_format.find('H')
        w_index = data_format.find('W')
        strides = [strides[h_index], strides[w_index]]

    shape_filters = comm.get_filter_shape(ori_format_filters,
                                          ori_shape_filters)

    shape_out_backprop = comm.get_shape_out_backprop(ori_format_out_backprop,
                                                     ori_shape_out_backprop)

    shape_res = comm.get_shape_res(ori_format_res, ori_shape_res)

    dilations = comm.get_shape_dilation(data_format, dilations)

    conv2d_backprop_input_cce(shape_filters, shape_out_backprop, shape_res,
                              strides, pads, dilations, filters_dtype,
                              out_backprop_dtype, res_dtype, kernel_name)
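
The stride handling above keeps only the H and W components of a 4-element strides attribute, located through the data_format string. A minimal sketch of that lookup:

# Sketch of reducing 4-element strides to [stride_h, stride_w] via data_format.
def hw_strides(strides, data_format="NHWC"):
    h_index = data_format.find('H')
    w_index = data_format.find('W')
    return [strides[h_index], strides[w_index]]

print(hw_strides([1, 2, 3, 1], "NHWC"))  # [2, 3]
print(hw_strides([1, 1, 2, 3], "NCHW"))  # [2, 3]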
def check_supported(input_x1,
                    input_x2,
                    bias=None,
                    output_y={},
                    trans_a=False,
                    trans_b=False,
                    kernel_name="matmulcube"):
    """check_supported"""
    shape_a = input_x1.get("shape")
    shape_b = input_x2.get("shape")
    print("shape_a: ", shape_a)
    print("shape_b: ", shape_b)
    src_dtype = input_x1.get("dtype")
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape_a)
    util.check_shape_rule(shape_b)
    util.check_shape_size(shape_a, SHAPE_SIZE_LIMIT)
    util.check_shape_size(shape_b, SHAPE_SIZE_LIMIT)
    try:
        trans_a_f = bool(1 - trans_a)
        if src_dtype in ("float32", "int32"):
            if len(shape_a) != 2 and len(shape_b) != 2:
                return False
            if trans_b:
                if shape_b[0] == 1:
                    return False
            else:
                if shape_b[1] == 1:
                    return False
            if trans_a:
                if trans_b:
                    if shape_a[0] != shape_b[1]:
                        return False
                elif shape_a[0] != shape_b[0]:
                    return False
            elif trans_b:
                if shape_a[1] != shape_b[1]:
                    return False
            elif shape_a[1] != shape_b[0]:
                return False

            if trans_a_f and trans_b and shape_b[1] == 1:
                return False

        if src_dtype == "float16":
            if len(shape_a) != 2 and len(shape_b) != 2:
                return False

            if trans_a:
                m_shape = shape_a[1]
                k_shape = shape_a[0]
            else:
                m_shape = shape_a[0]
                k_shape = shape_a[1]

            if trans_b:
                n_shape = shape_b[0]
                k_b_shape = shape_b[1]
            else:
                n_shape = shape_b[1]
                k_b_shape = shape_b[0]

            if k_shape != k_b_shape:
                return False

            if m_shape == 1 or n_shape == 1:
                if k_shape % 256 != 0:
                    return False

    except RuntimeError as e:
        print(e)
        return False

    return True
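
check_supported only inspects the "shape" and "dtype" entries of its input descriptors, so it can be exercised directly once the surrounding module (util, SHAPE_SIZE_LIMIT) is importable. A hypothetical call with made-up descriptors:

# Hypothetical descriptors for illustration; only "shape" and "dtype" are read.
x1 = {"shape": (128, 256), "dtype": "float16"}
x2 = {"shape": (256, 64), "dtype": "float16"}
print(check_supported(x1, x2))   # True: the reduce dims match (256 == 256)
print(check_supported(x1, {"shape": (128, 64), "dtype": "float16"}))  # False: 256 != 128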
def fake_learned_scale_quant_perchannel_grad_d(
        dout,
        input_x,
        alpha,
        quant_max,
        dx,
        dalpha,
        neg_trunc,
        channel_axis,
        kernel_name="fake_learned_scale_quant_perchannel_grad_d"):
    """FakeLearnedScaleQuantPerChannelGradD"""
    input_shape = input_x.get("shape")
    input_x_shape_ = input_x.get("ori_shape")
    input_x_format = input_x.get("format")
    input_dtype = input_x.get("dtype")
    alpha_shape = alpha.get("ori_shape")
    alpha_dtype = alpha.get("dtype")
    quant_max_shape = quant_max.get("ori_shape")
    quant_max_dtype = quant_max.get("dtype")
    # for Dense weight quant, 2d [co, ci] -> 4d [1, co, ci, 1], so channel_axis_ needs to change to 1.
    if channel_axis == 0 and input_x_shape_[0] != alpha_shape[
            0] and input_x_shape_[1] == alpha_shape[0]:
        channel_axis_ = 1
    else:
        channel_axis_ = channel_axis

    util.check_kernel_name(kernel_name)
    util.check_shape_rule(input_shape)
    util.check_shape_rule(alpha_shape, 1, 1, input_x_shape_[channel_axis_])
    util.check_shape_rule(quant_max_shape, 1, 1, 1)
    util.check_tensor_shape_size(input_shape)
    util.check_tensor_shape_size(alpha_shape)
    util.check_tensor_shape_size(quant_max_shape)

    check_list = ["float32", "float16"]
    input_dtype = input_dtype.lower()
    alpha_dtype = alpha_dtype.lower()
    quant_max_dtype = quant_max_dtype.lower()
    util.check_dtype_rule(input_dtype, check_list)
    util.check_dtype_rule(alpha_dtype, check_list)
    util.check_dtype_rule(quant_max_dtype, check_list)

    shape_c = [1] * len(input_shape)
    shape_c[channel_axis_] = alpha.get("ori_shape")[0]
    if input_x_format == "NC1HWC0" and channel_axis_ == 1:
        shape_c = alpha.get("shape")

    dout_data = tvm.placeholder(input_shape, name="dout", dtype=input_dtype)
    input_data = tvm.placeholder(input_shape, name="x", dtype=input_dtype)
    alpha_data = tvm.placeholder(shape_c, name="alpha_data", dtype=alpha_dtype)
    quant_max_data = tvm.placeholder(quant_max_shape,
                                     name="quant_max_data",
                                     dtype=quant_max_dtype)
    res = fake_learned_scale_quant_perchannel_grad_d_compute(
        dout_data, input_data, alpha_data, quant_max_data, neg_trunc,
        kernel_name)

    with tvm.target.cce():
        sch = generic.auto_schedule(res)

    tensor_list = [dout_data, input_data, alpha_data, quant_max_data
                   ] + list(res)
    config = {
        "print_ir": False,
        "name": kernel_name,
        "tensor_list": tensor_list
    }

    te.lang.cce.cce_build_code(sch, config)
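
The channel_axis remapping near the top of the function handles Dense weights that were reshaped from 2-D [co, ci] to 4-D [1, co, ci, 1]. A small stand-alone sketch of that rule:

# Sketch of the channel-axis correction above: if axis 0 does not match alpha's
# length but axis 1 does, the per-channel axis moves to 1.
def resolve_channel_axis(channel_axis, input_ori_shape, alpha_ori_shape):
    if (channel_axis == 0
            and input_ori_shape[0] != alpha_ori_shape[0]
            and input_ori_shape[1] == alpha_ori_shape[0]):
        return 1
    return channel_axis

print(resolve_channel_axis(0, (1, 64, 32, 1), (64,)))  # 1  (reshaped Dense weight)
print(resolve_channel_axis(0, (64, 32), (64,)))        # 0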
Example #7
def decode_bbox(box_predictions,
                anchors,
                decoded_boxes,
                decode_clip,
                kernel_name="decode_bbox"):
    """
    calculating data

    Parameters
    ----------
    box_predictions : shape and dtype of input
    anchors : shape and dtype of input
    decoded_boxes : shape and dtype of output,
                    should be same shape and type as input
    decode_clip : decode_clip
    kernel_name : kernel name, default value is "decode_bbox"
    Returns
    -------
    None
    """

    # check param & data
    shape_box_predictions = box_predictions.get("shape")
    shape_anchors = anchors.get("shape")
    shape_decoded_boxes = decoded_boxes.get("shape")
    util.check_kernel_name(kernel_name)
    format_box_predictions = box_predictions.get("format")
    format_anchors = anchors.get("format")
    format_decoded_boxes = decoded_boxes.get("format")
    check_format_shape(format_box_predictions, format_anchors,
                       format_decoded_boxes)
    util.check_shape_rule(shape_box_predictions, CONFIG_THREE, CONFIG_FOUR,
                          None)
    util.check_shape_rule(shape_anchors, CONFIG_THREE, CONFIG_FOUR, None)
    util.check_shape_rule(shape_decoded_boxes, CONFIG_TWO, CONFIG_TWO, None)
    util.check_shape_size(shape_box_predictions, SHAPE_SIZE_LIMIT)
    util.check_shape_size(shape_anchors, SHAPE_SIZE_LIMIT)
    util.check_shape_size(shape_decoded_boxes, SHAPE_SIZE_LIMIT)
    util.check_dtype_rule(box_predictions.get("dtype").lower(), ("float16", ))
    util.check_dtype_rule(anchors.get("dtype").lower(), ("float16", ))
    util.check_dtype_rule(decoded_boxes.get("dtype").lower(), ("float16", ))
    if shape_box_predictions != shape_anchors:
        raise RuntimeError("the input shape_box_predictions and anchors)"
                           "must be same")
    if (reduce(lambda x, y: x * y, shape_box_predictions[:])) \
            != (reduce(lambda x, y: x * y, shape_decoded_boxes[:])):
        raise RuntimeError("the input shape (box_predictions and anchors"
                           "is not equal to out shape(decoded_boxes)")
    if (shape_box_predictions[-1] == CONFIG_FOUR
            and len(shape_box_predictions) == CONFIG_THREE):
        if shape_decoded_boxes[1] != CONFIG_FOUR:
            raise RuntimeError("the output shape_decoded_boxes must be 4")
    else:
        if (shape_box_predictions[0] == CONFIG_FOUR
                and len(shape_box_predictions) == CONFIG_FOUR):
            if shape_decoded_boxes[0] != CONFIG_FOUR:
                raise RuntimeError("the output shape_decoded_boxes must be 4")
        else:
            raise RuntimeError("the input shape not in {(4,C,H,W), (H,W,4)}")
    if not isinstance(decode_clip, (float, int)):
        raise RuntimeError("input param type of decode_clip should be Float")
    if decode_clip < 0 or decode_clip > 10:
        raise RuntimeError(
            "input param decode_clip can't be negative and should be in [0, 10]!")
    # init the tiling shape
    print("shape_box_predictions", shape_box_predictions)
    shape = TilingFunc(shape_box_predictions)
    # calculate the decode_bbox
    tik_instance = tik.Tik(tik.Dprofile())
    data_tensor = InitTensor(tik_instance, shape)
    if shape.input_shape[-1] == CONFIG_FOUR \
            and len(shape.input_shape) == CONFIG_THREE:
        decode_bbox_compute(tik_instance, shape, data_tensor, decode_clip,
                            kernel_name)
    if shape.input_shape[0] == CONFIG_FOUR \
            and len(shape.input_shape) == CONFIG_FOUR:
        decode_bbox_compute_transpose(tik_instance, shape, data_tensor,
                                      decode_clip, kernel_name)
    return tik_instance
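
decode_bbox accepts either an (H, W, 4) layout or a (4, C, H, W) layout, and the element count of the 2-D decoded_boxes output must equal that of the input. A small sketch of those consistency rules (illustration only, mirroring the checks above):

from functools import reduce

# Mirror of the layout and size checks performed above.
def check_decode_bbox_shapes(box_shape, decoded_shape):
    total_in = reduce(lambda a, b: a * b, box_shape)
    total_out = reduce(lambda a, b: a * b, decoded_shape)
    if total_in != total_out:
        raise ValueError("input and output element counts must match")
    if len(box_shape) == 3 and box_shape[-1] == 4:
        return "(H, W, 4) layout"
    if len(box_shape) == 4 and box_shape[0] == 4:
        return "(4, C, H, W) layout"
    raise ValueError("the input shape not in {(4,C,H,W), (H,W,4)}")

print(check_decode_bbox_shapes((32, 32, 4), (1024, 4)))    # (H, W, 4) layout
print(check_decode_bbox_shapes((4, 16, 8, 8), (4, 1024)))  # (4, C, H, W) layout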
Example #8
def custom_Exp(shape,
               dtype,
               gamma,
               alpha,
               beta,
               kernel_name="cce_exp",
               need_build=False,
               need_print=False):
    """
    calculate gamma ** (alpha * data + beta), computed as exp(log(gamma) * alpha * data) * (gamma ** beta)
    
    Parameters
    ----------
    shape : shape of data

    dtype : the data type, assume src_dtype equals dst_dtype, only support float16, float32

    gamma : the data type must be the same as the dtype parameter
        base in gamma ** (alpha * data + beta)

    alpha : the data type must be the same as the dtype parameter
        scale in gamma ** (alpha * data + beta)

    beta : the data type must be the same as the dtype parameter
        shift in gamma ** (alpha * data + beta)

    kernel_name : cce kernel name, default value is "cce_exp"

    need_build : if need to build CCEC kernel, default value is False

    need_print : if need to print the ir, default value is False

    Returns
    -------
    None
        
    """
    supported_dtypes = ["float16", "float32"]
    device_api = "DeviceExp"

    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape)
    util.check_shape_size(shape, SHAPE_SIZE_LIMIT)

    if not (dtype.lower() in supported_dtypes):
        raise RuntimeError(
            "caffe_exp_layer_cce only support %s while dtype is %s" %
            (",".join(supported_dtypes), dtype))

    if gamma != -1 and gamma <= 0:  # the device API cc_device_exp_c handles gamma == -1 as e
        raise ValueError(
            "please ensure gamma is greater than 0, where gamma = %s" %
            str(gamma))

    inp_dtype = dtype.lower()
    shape = util.shape_refine(shape)
    data_input = tvm.placeholder(shape, name="data_input", dtype=inp_dtype)

    v_datatype = util.get_device_api_dtype(inp_dtype)
    v_ndim = len(shape)
    block_num = "block_num"
    block_idx = "block_idx"
    padC0 = 0
    p_scale = util.create_param_ptr([alpha], inp_dtype, "p_scale")
    p_shift = util.create_param_ptr([beta], inp_dtype, "p_shift")
    p_base = util.create_param_ptr([gamma], inp_dtype, "p_base")
    p_shape = util.create_param_ptr(shape, "int32", "p_shape")

    # scale --> alpha, shift --> beta, base --> gamma
    output = tvm.extern(
        shape,
        [data_input, p_scale, p_shift, p_base, p_shape],
        lambda ins, outs: tvm.call_extern(
            "int32_t",
            device_api,
            block_num,
            block_idx,
            v_datatype,
            ins[1].access_ptr("r"),  # scale
            ins[2].access_ptr("r"),  # shift
            ins[3].access_ptr("r"),  # base
            v_ndim,
            ins[4].access_ptr("r"),  # shape
            padC0,
            ins[0].access_ptr("r"),  # input x
            outs[0].access_ptr("w")),
        name="output",
        dtype=inp_dtype)

    s = tvm.create_schedule(output.op)

    if need_print:
        with build_config:
            print(tvm.lower(s, [data_input, output], simple_mode=True))
    if need_build:
        with build_config:
            tvm.build(s, [data_input, output], "cce", name=kernel_name)
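
The Exp layer relies on the identity gamma ** (alpha * x + beta) == exp(log(gamma) * alpha * x) * gamma ** beta, which is what the docstring describes. A quick numerical check in plain Python:

import math

# Numerical check of the identity behind the Exp layer above.
gamma, alpha, beta, x = 2.0, 0.5, 1.0, 3.0
lhs = gamma ** (alpha * x + beta)
rhs = math.exp(math.log(gamma) * alpha * x) * gamma ** beta
print(lhs, rhs)                  # both about 5.6569
assert math.isclose(lhs, rhs)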
def custom_logical_and(shape_x,
                       shape_y,
                       dtype,
                       kernel_name="cce_tf_logical_and",
                       need_build=False,
                       need_print=False):
    """
    do element-wise logical-and operation between two input tensors

    Parameters:
    ----------
    shape_x : shape of input data1

    shape_y : shape of input data2

    dtype : source data type, support "bool"

    kernel_name : cce kernel name, default value is "cce_tf_logical_and"

    need_build : if need to build CCEC kernel, default value is False

    need_print : if need to print the ir, default value is False

    Returns
    -------
    None
    """

    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape_x)
    util.check_shape_rule(shape_y)

    check_list = ["bool"]
    if not (dtype.lower() in check_list):
        raise RuntimeError(
            "logical_and_cce only support %s while dtype is %s" %
            (",".join(check_list), dtype))

    util.check_shape_size(shape_x, SHAPE_SIZE_LIMIT)
    util.check_shape_size(shape_y, SHAPE_SIZE_LIMIT)

    inp_dtype = dtype.lower()

    shape_x, shape_y, shape_max = util.produce_shapes(shape_x, shape_y)
    data1 = tvm.placeholder(shape_x, dtype=inp_dtype, name="data1")
    data2 = tvm.placeholder(shape_y, dtype=inp_dtype, name="data2")

    with tvm.target.cce():
        data1_tmp1 = te.lang.cce.broadcast(data1, shape_max)
        data1_tmp2 = te.lang.cce.broadcast(data2, shape_max)

        min_value = tvm.const(0, dtype=inp_dtype)
        res = tvm.compute(
            shape_max,
            lambda *i: tvm.select(
                tvm.all(
                    tvm.any(
                        data1_tmp1(*i) > min_value,
                        data1_tmp1(*i) < -min_value),
                    tvm.any(
                        data1_tmp2(*i) > min_value,
                        data1_tmp2(*i) < -min_value)), True, False),
            name="res")

        sch = tvm.create_schedule(res.op)

    if need_print:
        with build_config:
            print(tvm.lower(sch, [data1, data2, res], simple_mode=True))

    if need_build:
        with build_config:
            tvm.build(sch, [data1, data2, res], "cce", name=kernel_name)
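
The tvm.select expression above treats any non-zero value as True, so after broadcasting the kernel computes nonzero(x) AND nonzero(y). A NumPy sketch of the same semantics (illustration only):

import numpy as np

# A value is "true" iff it is non-zero; the result is the logical AND after
# broadcasting, mirroring the select expression above.
x = np.array([[0, 1, 2]], dtype=np.int8)
y = np.array([[1], [0]], dtype=np.int8)
print(np.logical_and(x != 0, y != 0))
# [[False  True  True]
#  [False False False]]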
Example #10
def custom_Power(shape,
                 dtype,
                 gamma,
                 alpha,
                 beta,
                 kernel_name="cce_caffe_power",
                 need_build=False,
                 need_print=False):
    """
    calculate (alpha * data + beta) ** gamma, calculation method: exp(gamma * log(alpha * data + beta)).
    when alpha * data + beta < 0, the output is a meaningless value.
    Parameters
    ----------
    shape : shape of data

    dtype : the data type, assume src_dtype equals dst_dtype, only support float16, float32

    gamma : the data type must be same with dtype parameter
        args in (alpha * data + beta) ** gamma

    alpha : the data type must be same with dtype parameter
        args in (alpha * data + beta) ** gamma

    beta : the data type must be same with dtype parameter
        args in (alpha * data + beta) ** gamma

    kernel_name : string
        kernel name in generated CCE kernel. default value is "cce_caffe_power"


    need_build : bool
        if need to build CCEC kernel

    need_print : bool
        if need to print Halide IR

    Returns
    -------
    None
        
    """
    supported_dtypes = ["float16", "float32"]
    device_api = "cc_device_pow"

    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape)
    util.check_shape_size(shape, SHAPE_SIZE_LIMIT)

    if not (dtype.lower() in supported_dtypes):
        raise RuntimeError("power_cce only support %s while dtype is %s" %
                           (",".join(supported_dtypes), dtype))

    inp_dtype = dtype.lower()
    shape = util.shape_refine(shape)
    data_input = tvm.placeholder(shape, name="data_input", dtype=inp_dtype)

    v_datatype = util.get_device_api_dtype(inp_dtype)
    v_ndim_x = len(shape)
    v_ndim_y = 0
    p_shape_y = 0
    p_input_y = "nullptr"
    block_num = "block_num"
    block_idx = "block_idx"
    padC0 = 0

    p_scale = util.create_param_ptr([alpha], inp_dtype, "p_scale")
    p_shift = util.create_param_ptr([beta], inp_dtype, "p_shift")
    p_power = util.create_param_ptr([gamma], inp_dtype, "p_power")
    p_shape_x = util.create_param_ptr(shape, "int32", "p_shape_x")

    # scale --> alpha, shift --> beta, power --> gamma
    output = tvm.extern(
        shape,
        [data_input, p_scale, p_shift, p_power, p_shape_x],
        lambda ins, outs: tvm.call_extern(
            "int32_t",
            device_api,
            block_num,
            block_idx,
            v_datatype,
            ins[1].access_ptr("r"),  # scale
            ins[2].access_ptr("r"),  # shift
            ins[3].access_ptr("r"),  # power
            v_ndim_x,
            ins[4].access_ptr("r"),  # shape
            padC0,
            ins[0].access_ptr("r"),  # input x
            v_ndim_y,
            v_ndim_y,
            p_shape_y,
            padC0,
            p_input_y,
            outs[0].access_ptr("w")),
        name="output",
        dtype=inp_dtype)

    s = tvm.create_schedule(output.op)

    if need_print:
        with build_config:
            print(tvm.lower(s, [data_input, output], simple_mode=True))
    if need_build:
        with build_config:
            tvm.build(s, [data_input, output], "cce", name=kernel_name)
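
The Power layer uses the identity (alpha * x + beta) ** gamma == exp(gamma * log(alpha * x + beta)), which is why a non-positive base yields a meaningless value. A quick check:

import math

# Numerical check of the identity behind the Power layer above; the log-based
# form is only defined for a positive base.
alpha, beta, gamma, x = 2.0, 1.0, 3.0, 0.5
base = alpha * x + beta                 # 2.0, must be > 0
lhs = base ** gamma
rhs = math.exp(gamma * math.log(base))
print(lhs, rhs)                         # both 8.0
assert math.isclose(lhs, rhs)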
Example #11
def batchnorm_fold(x,
                   x_sum,
                   x_square_sum,
                   mean,
                   variance,
                   y,
                   batch_mean,
                   batch_std,
                   running_mean,
                   running_std,
                   mean_updated,
                   variance_updated,
                   momentum=0.9,
                   epsilon=1e-5,
                   is_training=True,
                   freeze_bn=0,
                   data_format="NCHW",
                   kernel_name="batchnorm_fold"):
    """batchnorm_fold TBE op"""
    util.check_kernel_name(kernel_name)
    data_format = data_format.upper()
    if data_format != "NCHW":
        raise RuntimeError("The data_format only support NCHW")

    shape_x = x.get("shape")
    shape_mean = mean.get("shape")
    shape_variance = variance.get("shape")
    dtype_x = x.get("dtype")
    dtype_mean = mean.get("dtype")
    dtype_variance = variance.get("dtype")
    for shape in (shape_x, shape_mean, shape_variance):
        util.check_shape_rule(shape)
        util.check_tensor_shape_size(shape)
    check_tuple = ("float16", "float32")
    for dtype in (dtype_x, dtype_mean, dtype_variance):
        util.check_dtype_rule(dtype.lower(), check_tuple)

    format_data = x.get("format").upper()
    if format_data not in ("NCHW", "NC1HWC0"):
        raise RuntimeError("Format of input only support 4D and 5HD")

    if format_data == "NC1HWC0":
        if len(shape_x) != 5:
            raise RuntimeError("batchnorm_fold only support shape 5D"
                               "when input format is NC1HWC0")
        shape_mean = (1, shape_x[1], 1, 1, shape_x[4])
    elif format_data == "NCHW":
        if len(shape_x) < 2 or len(shape_x) > 4:
            raise RuntimeError("batchnorm_fold only support shape 2D to 4D")
        if shape_x[1] != shape_mean[0]:
            raise RuntimeError("data_format is NCHW, shape_bias must"
                               "be equal to the second axis of shape_x")
        shape_mean = (
            1,
            shape_x[1],
        )
        for _ in range(2, len(shape_x)):
            shape_mean = shape_mean + (1, )

    x_input = tvm.placeholder(shape_x, name="x_input", dtype=dtype_x.lower())
    x_sum = tvm.placeholder(shape_mean, name="x_sum", dtype=dtype_x.lower())
    x_square_sum = tvm.placeholder(shape_mean,
                                   name="x_square_sum",
                                   dtype=dtype_x.lower())
    mean = tvm.placeholder(shape_mean, name="mean", dtype=dtype_mean.lower())
    variance = tvm.placeholder(shape_mean,
                               name="variance",
                               dtype=dtype_variance.lower())

    shape_x = te.lang.cce.util.shape_to_list(x_input.shape)
    num = shape_x[0] * shape_x[2] * shape_x[3]
    num_rec = 1.0 / num

    # compute the mean of x
    batch_mean = te.lang.cce.vmuls(x_sum, num_rec)

    # compute the variance of x
    variance_div = te.lang.cce.vmuls(x_square_sum, num_rec)
    mean_square = te.lang.cce.vmul(batch_mean, batch_mean)
    batch_var_biased = te.lang.cce.vsub(variance_div, mean_square)
    batch_std = te.lang.cce.vsqrt(te.lang.cce.vadds(batch_var_biased, epsilon))
    if num == 1:
        batch_var_scaler = 0.0
    else:
        batch_var_scaler = float(num) / (num - 1)
    batch_var_unbiased = te.lang.cce.vmuls(batch_var_biased, batch_var_scaler)

    factor = 1.0 - momentum
    factor_reverse = momentum
    mean_mul = te.lang.cce.vmuls(batch_mean, factor)
    mean_mul_rev = te.lang.cce.vmuls(mean, factor_reverse)
    mean_updated = te.lang.cce.vadd(mean_mul, mean_mul_rev)

    var_mul = te.lang.cce.vmuls(batch_var_unbiased, factor)
    var_mul_rev = te.lang.cce.vmuls(variance, factor_reverse)
    variance_updated = te.lang.cce.vadd(var_mul, var_mul_rev)

    y = te.lang.cce.vadds(x_input, 0.0)
    running_mean = te.lang.cce.vadds(mean, 0.0)
    running_std = te.lang.cce.vsqrt(te.lang.cce.vadds(variance, epsilon))
    res = [
        y, batch_mean, batch_std, running_mean, running_std, mean_updated,
        variance_updated
    ]

    with tvm.target.cce():
        sch = generic.auto_schedule(res)
    config = {
        "name": kernel_name,
        "tensor_list": [x_input, x_sum, x_square_sum, mean, variance] + res
    }
    te.lang.cce.cce_build_code(sch, config)
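
The statistics above follow the usual batch-norm folding rules: a biased batch variance from E[x^2] - E[x]^2, Bessel's correction num / (num - 1) for the unbiased estimate, and an exponential moving average weighted by (1 - momentum). A small NumPy sketch (illustration only):

import numpy as np

# Illustration of the folded batch-norm statistics computed above.
x = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
num, momentum, epsilon = x.size, 0.9, 1e-5

batch_mean = x.sum() / num                                    # 2.5
batch_var_biased = (x ** 2).sum() / num - batch_mean ** 2     # 1.25
batch_std = np.sqrt(batch_var_biased + epsilon)
batch_var_unbiased = batch_var_biased * num / (num - 1)       # ~1.667

running_mean, running_var = 0.0, 1.0                          # previous moving stats
mean_updated = batch_mean * (1 - momentum) + running_mean * momentum
variance_updated = batch_var_unbiased * (1 - momentum) + running_var * momentum
print(batch_mean, batch_std, mean_updated, variance_updated)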
def custom_truncatemod(shape1,
                       shape2,
                       dtype,
                       kernel_name="cce_tf_truncatemod",
                       need_build=False,
                       need_print=False):
    """
    do element-wise truncatemod operation between two input tensors

    Parameters:
    ----------
    shape1 : shape of input data1

    shape2 : shape of input data2

    dtype : source data type, support float16,float32,int32

    kernel_name : cce kernel name, default value is "cce_tf_truncatemod"

    need_build : if need to build CCEC kernel, default value is False

    need_print : if need to print the ir, default value is False

    Returns
    -------
    None
    """
    max_dim = 8
    shape1_len = len(shape1)
    shape2_len = len(shape2)
    if shape1_len > max_dim or shape2_len > max_dim:
        raise RuntimeError(
            "mod_cce only support up to %d dimensions while the shape's dimensions is %d, %d"
            % (max_dim, shape1_len, shape2_len))
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape1)
    util.check_shape_rule(shape2)

    util.check_shape_size(shape1, SHAPE_SIZE_LIMIT)
    util.check_shape_size(shape2, SHAPE_SIZE_LIMIT)

    check_list = ["float16", "float32", "int32"]
    device_api_map = {
        "float16": "cc_device_truncatemod_float16",
        "float32": "cc_device_truncatemod_float",
        "int32": "cc_device_truncatemod_int32"
    }

    dtype = dtype.lower()
    if not (dtype in check_list):
        raise RuntimeError(
            "tf_truncatemod_cce only support %s while dtype is %s" %
            (",".join(check_list), dtype))

    shape1, shape2, shape_out = util.produce_shapes(shape1, shape2)
    util.check_shape_size(shape_out, SHAPE_SIZE_LIMIT)

    inp_dtype = dtype.lower()

    device_api = device_api_map[inp_dtype]

    ## block
    block_num = "block_num"
    block_idx = "block_idx"
    ## x param
    v_xndim_cnt = tvm.const(len(shape1), "int32")
    p_xshape = util.create_param_ptr(shape1, "int32", "p_xshape")
    xpadC0 = tvm.const(0, "int32")
    data_input_x = tvm.placeholder(shape1,
                                   name="data_input_x",
                                   dtype=inp_dtype)
    ## y param
    v_yndim_cnt = tvm.const(len(shape2), "int32")
    p_yshape = util.create_param_ptr(shape2, "int32", "p_yshape")
    ypadC0 = tvm.const(0, "int32")
    data_input_y = tvm.placeholder(shape2,
                                   name="data_input_y",
                                   dtype=inp_dtype)
    ## output
    v_out_ndim_cnt = tvm.const(len(shape_out), "int32")
    p_out_shape = util.create_param_ptr(shape_out, "int32", "p_out_shape")
    out_padC0 = tvm.const(0, "int32")

    output = tvm.extern(
        shape_out,
        [p_xshape, data_input_x, p_yshape, data_input_y, p_out_shape],
        lambda ins, outs: tvm.call_extern(
            "int32_t",
            device_api,
            block_num,
            block_idx,
            v_xndim_cnt,
            ins[0].access_ptr("r"),  # shape x
            xpadC0,
            ins[1].access_ptr("r"),  # input x
            v_yndim_cnt,
            ins[2].access_ptr("r"),  # shape y
            ypadC0,
            ins[3].access_ptr("r"),  # input y
            v_out_ndim_cnt,
            ins[4].access_ptr("r"),  # shape out
            out_padC0,
            outs[0].access_ptr("w")),
        name="output",
        dtype=inp_dtype)

    s = tvm.create_schedule(output.op)

    # print IR
    if need_print:
        with build_config:
            print(
                tvm.lower(s, [data_input_x, data_input_y, output],
                          simple_mode=True))
    # compile to generate the cce file
    if need_build:
        with build_config:
            tvm.build(s, [data_input_x, data_input_y, output],
                      "cce",
                      name=kernel_name)
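
Truncate-mod follows truncation semantics (the remainder takes the sign of the dividend), which matches C's fmod and TensorFlow's truncatemod but differs from Python's % for negative operands. A reference sketch in plain Python (the actual kernel is produced by the DeviceAPI call above):

import math

# Reference behaviour: x - trunc(x / y) * y; the sign follows x.
for x, y in [(7.0, 3.0), (-7.0, 3.0), (7.0, -3.0)]:
    trunc_mod = x - math.trunc(x / y) * y
    print(x, y, trunc_mod, math.fmod(x, y), x % y)
# ( 7,  3) ->  1.0  1.0  1.0
# (-7,  3) -> -1.0 -1.0  2.0   (Python % differs here)
# ( 7, -3) ->  1.0  1.0 -2.0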
Example #13
def custom_tile(shape,
                multiples,
                dtype,
                kernel_name="cce_tile",
                need_build=False,
                need_print=False):
    """Operation and Schedule for tile, construct an array by repeating shape the number of times given by multiply_shape.

    Parameters
    ----------
    shape : shape of the input tensor

    multiples : a list/tuple of ints, the number of repetitions along each axis

    dtype :
        the data type. only support float16, float32, int32, int8, uint8

    kernel_name : cce kernel name, default value is "cce_tile"

    need_build : if need to build CCEC kernel, default value is False

    need_print : if need to print the ir, default value is False

    Returns
    -------
        None
    """
    check_list = ["float16", "float32", "int32", "int8", "uint8"]
    if not (dtype.lower() in check_list):
        raise RuntimeError("tile_cce only support %s while dtype is %s" %
                           (",".join(check_list), dtype))
    tensor_l = []

    inp_dtype = dtype.lower()

    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape)
    util.check_shape_size(shape, SHAPE_SIZE_LIMIT)
    tensor_l.append(tvm.placeholder(shape, name="shape", dtype=inp_dtype))

    for i in range(len(multiples)):
        if not isinstance(multiples[i], int):
            raise RuntimeError("InvalidArgumentError: Expected int value")
        if multiples[i] < 0:
            raise RuntimeError(
                "InvalidArgumentError: Expected int value or multiples[%d] >= 0, but got %d!"
                % (i, multiples[i]))

    tensor_l.append(
        tvm.placeholder(multiples, name="multiples", dtype=inp_dtype))

    out_tensor = compute_tile_cce(a_tuple=tensor_l)

    s = schedule_tile_cce(out_tensor)
    if need_print:
        with build_config:
            print(
                tvm.lower(s, [tensor_l[0], tensor_l[1], out_tensor],
                          simple_mode=True))

    if need_build:
        with build_config:
            tvm.build(s, tensor_l + [out_tensor], "cce", name=kernel_name)
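
The tile semantics targeted above match NumPy's np.tile: each axis of the input is repeated multiples[i] times. A short sketch (illustration only):

import numpy as np

# Each axis is repeated according to multiples, so a (2, 2) input tiled by
# (2, 3) becomes (4, 6).
data = np.array([[1, 2], [3, 4]], dtype=np.int32)
print(np.tile(data, (2, 3)).shape)   # (4, 6)
print(np.tile(data, (2, 3)))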
Example #14
def custom_squared_difference(shape_x,
                              shape_y,
                              dtype,
                              kernel_name="cce_tf_squared_difference",
                              need_build=False,
                              need_print=False):
    """
    algorithm: tf_squared_difference

    calculating data's tf_squared_difference,y= (x - y) * (x - y)

    Parameters
    ----------
    shape_x : shape of input x

    shape_y : shape of input y

    dtype : the data type, assume src_dtype equals dst_dtype, only support \
    float16, float32, int32

    kernel_name : cce kernel name, default value is "cce_tf_squared_difference"

    need_build : if need to build CCEC kernel, default value is False

    need_print : if need to print the ir, default value is False

    Returns
    -------
    None
    """
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape_x)
    util.check_shape_rule(shape_y)
    util.check_shape_size(shape_x, SHAPE_SIZE_LIMIT)
    util.check_shape_size(shape_y, SHAPE_SIZE_LIMIT)

    check_list = ["float16", "float32", "int32"]

    if not dtype.lower() in check_list:
        raise RuntimeError(
            "tf_squared_difference_cce only support %s while dtype is %s" %
            (",".join(check_list), dtype))

    dtype = dtype.lower()

    shape_x, shape_y, shape_max = util.produce_shapes(shape_x, shape_y)
    util.check_shape_size(shape_max, SHAPE_SIZE_LIMIT)

    data_x = tvm.placeholder(shape_x, dtype=dtype, name="data_x")
    data_y = tvm.placeholder(shape_y, dtype=dtype, name="data_y")

    with tvm.target.cce():
        data_x_tmp = te.lang.cce.broadcast(data_x, shape_max)
        data_y_tmp = te.lang.cce.broadcast(data_y, shape_max)
        data_sub = te.lang.cce.vsub(data_x_tmp, data_y_tmp)
        res = te.lang.cce.vmul(data_sub, data_sub)
        sch = generic.auto_schedule(res)

    config = {
        "print_ir": need_print,
        "need_build": need_build,
        "name": kernel_name,
        "tensor_list": [data_x, data_y, res]
    }

    te.lang.cce.cce_build_code(sch, config)
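
The computation is simply (x - y) squared after broadcasting. A NumPy sketch (illustration only):

import numpy as np

# Broadcast, subtract, then square, mirroring the vsub + vmul pipeline above.
x = np.array([[1.0, 2.0, 3.0]], dtype=np.float32)   # (1, 3)
y = np.array([[0.5], [2.0]], dtype=np.float32)      # (2, 1)
print((x - y) ** 2)                                  # shape (2, 3)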
Example #15
def fake_quant_with_min_max_vars_ema(x,
                                     min_val,
                                     max_val,
                                     y,
                                     ema,
                                     ema_decay,
                                     symmetric,
                                     narrow_range,
                                     training,
                                     num_bits,
                                     quant_delay,
                                     kernel_name="fake_quant"):
    """FakeQuantWithMinMax"""
    input_shape = x.get("shape")
    input_dtype = x.get("dtype")
    min_shape = min_val.get("ori_shape")
    min_dtype = min_val.get("dtype")
    max_shape = max_val.get("ori_shape")
    max_dtype = max_val.get("dtype")

    min_shape = util.scalar2tensor_one(min_shape)
    max_shape = util.scalar2tensor_one(max_shape)
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(input_shape)
    util.check_shape_rule(min_shape, 1, 1, 1)
    util.check_shape_rule(max_shape, 1, 1, 1)
    util.check_tensor_shape_size(input_shape)
    util.check_tensor_shape_size(min_shape)
    util.check_tensor_shape_size(max_shape)

    check_list = ["float32", "float16"]
    x_dtype = input_dtype.lower()
    min_dtype = min_dtype.lower()
    max_dtype = max_dtype.lower()
    util.check_dtype_rule(x_dtype, check_list)
    util.check_dtype_rule(min_dtype, check_list)
    util.check_dtype_rule(max_dtype, check_list)

    input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]), )
    shape_min, _, _ = util.produce_shapes(min_shape, input_shape)

    if symmetric:
        quant_min = 0 - 2**(num_bits - 1)
        quant_max = 2**(num_bits - 1) - 1
    else:
        quant_min = 0
        quant_max = 2**num_bits - 1
    if narrow_range:
        quant_min = quant_min + 1

    input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
    min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
    max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
    res = fake_quant_with_min_max_vars_ema_compute(input_data, min_data,
                                                   max_data, y, quant_min,
                                                   quant_max, kernel_name)

    with tvm.target.cce():
        sch = generic.auto_schedule(res)

    tensor_list = [input_data, min_data, max_data, res]
    config = {
        "print_ir": False,
        "name": kernel_name,
        "tensor_list": tensor_list
    }

    te.lang.cce.cce_build_code(sch, config)
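
The quantization range above depends on num_bits, symmetric, and narrow_range. The same formulas evaluated in plain Python for num_bits = 8:

# Quantization range logic used above, evaluated for 8-bit settings.
def quant_range(num_bits, symmetric, narrow_range):
    if symmetric:
        quant_min, quant_max = -2 ** (num_bits - 1), 2 ** (num_bits - 1) - 1
    else:
        quant_min, quant_max = 0, 2 ** num_bits - 1
    if narrow_range:
        quant_min += 1
    return quant_min, quant_max

print(quant_range(8, symmetric=False, narrow_range=False))  # (0, 255)
print(quant_range(8, symmetric=True, narrow_range=True))    # (-127, 127)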
Example #16
def write_select(input_x, output_x, kernel_name="write_select"):
    """
    Write data with offset

    Parameters
    ----------
    input_x : dict
        shape and dtype of input
    output_x : dict
        shape and dtype of output, should be same shape and type as input
    kernel_name : str
        kernel name, default value is "write_select"

    Returns
    -------
    None
    """
    input_shape = input_x.get("shape")
    input_dtype = input_x.get("dtype").lower()
    valid_shape = output_x.get("valid_shape")

    util.check_shape_rule(input_shape)
    util.check_shape_rule(valid_shape)
    util.check_tensor_shape_size(input_shape)
    util.check_tensor_shape_size(valid_shape)
    util.check_kernel_name(kernel_name)

    if util.is_lhisi_version():
        check_list = ["int32", "float16", "int8"]
    else:
        check_list = ["int32", "float16", "float32", "int8"]

    if input_dtype not in check_list:
        raise RuntimeError("write_select only support %s while dtype is %s" %
                           (",".join(check_list), input_dtype))

    if len(valid_shape) != PARA_LIST_LEN:
        raise RuntimeError("the length of valid_shape should be 5")

    dst_out_flag = "DDR"
    if "dst_out_flag" in output_x:
        dst_out_flag = output_x.get("dst_out_flag")

    input_tensor_ph = tvm.placeholder(input_shape,
                                      name="input_tensor_ph",
                                      dtype=input_dtype,
                                      attrs={
                                          "valid_shape": valid_shape,
                                          "dst_out_flag": dst_out_flag
                                      })

    input_tensor = tvm.compute(input_shape,
                               lambda *indice: input_tensor_ph(*indice),
                               name="input_tensor")
    res = write_select_compute(input_tensor, output_x, kernel_name=kernel_name)

    with tvm.target.cce():
        sch = generic.auto_schedule(res)

    config = {"name": kernel_name, "tensor_list": [input_tensor_ph, res]}
    te.lang.cce.cce_build_code(sch, config)
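# Illustrative invocation with assumed values (not taken from the original
# source); valid_shape must contain exactly five entries to match the 5-D
# NC1HWC0 layout checked above:
#
# write_select(
#     input_x={"shape": (1, 4, 32, 32, 16), "dtype": "float16"},
#     output_x={"valid_shape": (1, 4, 28, 28, 16), "dst_out_flag": "DDR"},
#     kernel_name="write_select")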
Example No. 17
def check_conv3dbp_filter_params(shape_x, shape_out_backprop, filter_sizes,
                                 strides, pads, dilations, x_dtype,
                                 out_backprop_dtype, res_dtype, kernel_name):
    """
    The params check function of conv3d_backprop_filter

    Parameters:
    ----------
    shape_x : The shape of the feature map,
              which is 5-D [batch, depth, channels, height, width].

    shape_out_backprop : The shape of the gradients,
                         which is 5-D [batch, depth, channels, height, width].

    filter_sizes : The shape of the filter,
                   which is 5-D [batch, depth, channels, height, width].

    strides : The stride of the sliding window. A list of ints.

    pads : "SAME" or "VALID",
           indicating the type of padding algorithm to use, or a list.

    dilations : An optional list of ints. Default value is [1, 1, 1, 1].

    x_dtype : Feature map data dtype. Default value is float16.

    out_backprop_dtype : Gradients data dtype. Default value is float16.

    res_dtype : Result (De/Dw) data dtype. Default value is float32.

    kernel_name : Kernel name of cce.
                  Default value is "conv3d_backprop_filter_cce".

    Returns : All transformed params.
    ----------
    """
    def _check_attr_range_dw(name, value, attr_min=None, attr_max=None):
        if not attr_min and not attr_max:
            return
        if not attr_min:
            if value > attr_max:
                args_dict = {
                    'errCode': 'E60011',
                    'range': '(,{}]'.format(attr_max),
                    'attr_name': name,
                    'value': str(value)
                }
                raise RuntimeError(args_dict,
                                   err_mana.get_error_message(args_dict))
        elif not attr_max:
            if value < attr_min:
                args_dict = {
                    'errCode': 'E60011',
                    'range': '[{},)'.format(attr_min),
                    'attr_name': name,
                    'value': str(value)
                }
                raise RuntimeError(args_dict,
                                   err_mana.get_error_message(args_dict))
        elif value > attr_max or value < attr_min:
            args_dict = {
                'errCode': 'E60011',
                'range': '[{},{}]'.format(attr_min, attr_max),
                'attr_name': name,
                'value': str(value)
            }
            raise RuntimeError(args_dict,
                               err_mana.get_error_message(args_dict))

    def _check_64bits_limitation(attr_name, attr_value, dtype=None):
        if dtype:
            bit_ratio = BIT_RATIO_DICT.get(dtype)
        else:
            bit_ratio = BIT_RATIO_DICT.get("float16")
        if attr_value * bit_ratio > DATA_SIZE_MAX:
            args_dict = {'errCode': 'E60020', 'attr_name': attr_name}
            raise RuntimeError(args_dict,
                               err_mana.get_error_message(args_dict))

    # First : Base check, Mainly required by interface appearance
    # ===========================================================
    # util check
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape_x, CONV3D_BACKPROP_SHAPE_DIM,
                          CONV3D_BACKPROP_SHAPE_DIM, DEFAULT_MAX_SHAPE_NUM)
    util.check_shape_rule(shape_out_backprop, CONV3D_BACKPROP_SHAPE_DIM,
                          CONV3D_BACKPROP_SHAPE_DIM, DEFAULT_MAX_SHAPE_NUM)
    util.check_shape_rule(filter_sizes, CONV3D_BACKPROP_SHAPE_DIM,
                          CONV3D_BACKPROP_SHAPE_DIM, DEFAULT_MAX_SHAPE_NUM)
    util.check_shape_rule(strides, STRIDES_SHAPE_DIM, STRIDES_SHAPE_DIM,
                          DEFAULT_MAX_SHAPE_NUM)

    def _check_attr_pads():
        # pads check
        if isinstance(pads, (tuple, list)) and \
                len(pads) != PADDING_SHAPE_DIM:
            args_dict = {'errCode': 'E62501', 'param_name': 'pads'}
            raise RuntimeError(args_dict,
                               err_mana.get_error_message(args_dict))

        if isinstance(pads, str) and pads not in PADDING_SUPPORT:
            args_dict = {
                'errCode': 'E60021',
                'expected_pad_mode': '[{}]'.format(PADDING_SUPPORT),
                'actual_pad_mode': str(pads)
            }
            raise RuntimeError(args_dict,
                               err_mana.get_error_message(args_dict))

    _check_attr_pads()

    # dilations check
    util.check_shape_rule(dilations, CONV3D_BACKPROP_SHAPE_DIM,
                          CONV3D_BACKPROP_SHAPE_DIM, DEFAULT_MAX_SHAPE_NUM)
    dilation_n, dilation_d, dilation_c, dilation_h, dilation_w = dilations
    _check_attr_range_dw("dilations's H", dilation_h, DILATION_MIN,
                         DILATION_MAX)
    _check_attr_range_dw("dilations's W", dilation_w, DILATION_MIN,
                         DILATION_MAX)

    if dilation_n != 1 or dilation_c != 1:
        args_dict = {
            'errCode': 'E60023',
            'dilation_n': str(dilation_n),
            'dilation_c': str(dilation_c)
        }
        raise RuntimeError(args_dict, err_mana.get_error_message(args_dict))

    # dtype check
    x_dtype = x_dtype.lower()
    out_backprop_dtype = out_backprop_dtype.lower()
    res_dtype = res_dtype.lower()
    util.check_dtype_rule(x_dtype, ['float16'])
    util.check_dtype_rule(out_backprop_dtype, ['float16'])
    util.check_dtype_rule(res_dtype, ['float32', 'float16'])

    # Second : Further check, mainly required by SRS
    # ===========================================================
    # the relation limits between shape
    shape_x = list(shape_x)
    shape_out_backprop = list(shape_out_backprop)
    filter_sizes = list(filter_sizes)
    strides = list(strides)
    fmap_batch, fmap_d, fmap_channel, fmap_h, fmap_w = shape_x
    dedy_batch, dedy_d, dedy_channel, dedy_h, dedy_w = shape_out_backprop
    filter_batch, filter_d, filter_channel, filter_h, filter_w = filter_sizes
    stride_d, stride_h, stride_w = strides

    filter_d_dilation = (filter_d - 1) * dilation_d + 1
    filter_h_dilation = (filter_h - 1) * dilation_h + 1
    filter_w_dilation = (filter_w - 1) * dilation_w + 1

    # pads compute
    if pads == 'SAME':
        pad_d = \
            align(fmap_d, stride_d) - stride_d + filter_d_dilation - fmap_d
        pad_d = max(pad_d, 0)
        pad_front = pad_d // 2
        pad_back = pad_d - pad_front
        pad_w = \
            align(fmap_w, stride_w) - stride_w + filter_w_dilation - fmap_w
        pad_w = max(pad_w, 0)
        pad_left = pad_w // 2
        pad_right = pad_w - pad_left
        pad_h = \
            align(fmap_h, stride_h) - stride_h + filter_h_dilation - fmap_h
        pad_h = max(pad_h, 0)
        pad_up = pad_h // 2
        pad_down = pad_h - pad_up
        pads = [pad_front, pad_back, pad_up, pad_down, pad_left, pad_right]
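        # Worked example (illustrative, assuming align() rounds up to the next
        # multiple of its second argument): fmap_d = 7, stride_d = 2 and
        # filter_d_dilation = 3 give pad_d = 8 - 2 + 3 - 7 = 2, split into
        # pad_front = 1 and pad_back = 1.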
    elif pads == "VALID":
        pads = PADDING_VAILD
    pads = list(pads)
    pad_front, pad_back, pad_up, pad_down, pad_left, pad_right = pads
    if pad_front >= filter_d_dilation or pad_back >= filter_d_dilation:
        args_dict = {
            'errCode': 'E60013',
            'depth_of_pad': '{}, {}'.format(pad_front, pad_back),
            'depth_of_filter': '{}'.format(filter_d_dilation)
        }
        raise RuntimeError(args_dict, err_mana.get_error_message(args_dict))
    if pad_up >= filter_h_dilation or pad_down >= filter_h_dilation:
        args_dict = {
            'errCode': 'E60016',
            'h_of_filter': '{}'.format(filter_h_dilation),
            'h_of_pad': '{}, {}'.format(pad_up, pad_down)
        }
        raise RuntimeError(args_dict, err_mana.get_error_message(args_dict))
    if pad_left >= filter_w_dilation or pad_right >= filter_w_dilation:
        args_dict = {
            'errCode': 'E60017',
            'w_of_filter': '{}'.format(filter_w_dilation),
            'w_of_pad': '{}, {}'.format(pad_left, pad_right)
        }
        raise RuntimeError(args_dict, err_mana.get_error_message(args_dict))

    fmap_w_padding = fmap_w + pad_left + pad_right
    fmap_h_padding = fmap_h + pad_up + pad_down

    # special cases
    fmap_hw_min, dey_hw_min = FMAP_HW_MIN, DEDY_HW_MIN
    # limitation by chip:
    # if kernel h,w in [1,11] and fmap h/w after padding equals to filter h/w
    # load3d support h,w is 1
    if (1 <= filter_w <= 11) and (1 <= filter_h <= 11) and (1 <= filter_d <= 11)\
            and (fmap_w_padding == filter_w or fmap_h_padding == filter_h):
        fmap_hw_min = 1
        dey_hw_min = 1

    # Dedy value limit
    _check_attr_range_dw("Dedy's H", dedy_h, dey_hw_min, DEDY_HW_MAX)
    _check_attr_range_dw("Dedy's W", dedy_w, dey_hw_min, DEDY_HW_MAX)

    # filter value limit
    _check_attr_range_dw("filter's H", filter_h, FILTER_HW_MIN, FILTER_HW_MAX)
    _check_attr_range_dw("filter's W", filter_w, FILTER_HW_MIN, FILTER_HW_MAX)

    # Fmap value limit
    _check_attr_range_dw("Fmap's H", fmap_h, fmap_hw_min, FMAP_HW_MAX)
    _check_attr_range_dw("Fmap's W", fmap_w, fmap_hw_min, FMAP_HW_MAX)

    # stride value limit
    _check_attr_range_dw("stride's H", stride_h, STRIDE_HW_MIN, STRIDE_HW_MAX)
    _check_attr_range_dw("stride's W", stride_w, STRIDE_HW_MIN, STRIDE_HW_MAX)

    def _check_axis_hw():
        if fmap_batch != dedy_batch:
            args_dict = {
                'errCode': 'E62503',
                'backprop_N': str(dedy_batch),
                'forward_shape': str(fmap_batch)
            }
            raise RuntimeError(args_dict,
                               err_mana.get_error_message(args_dict))
        if dedy_channel != filter_batch:
            args_dict = {
                'errCode': 'E62504',
                'backprop_C': str(dedy_channel),
                'forward_shape': str(filter_batch)
            }
            raise RuntimeError(args_dict,
                               err_mana.get_error_message(args_dict))
        if fmap_channel != filter_channel:
            args_dict = {
                'errCode': 'E60010',
                'channel_of_x': str(fmap_channel),
                'channel_of_filter': str(filter_channel)
            }
            raise RuntimeError(args_dict,
                               err_mana.get_error_message(args_dict))
        if filter_w_dilation > fmap_w_padding:
            args_dict = {
                'errCode': 'E60015',
                'w_of_x': str(fmap_w_padding),
                'w_of_filter': str(filter_w_dilation)
            }
            raise RuntimeError(args_dict,
                               err_mana.get_error_message(args_dict))
        if filter_h_dilation > fmap_h_padding:
            args_dict = {
                'errCode': 'E60014',
                'h_of_x': str(fmap_h_padding),
                'h_of_filter': str(filter_h_dilation)
            }
            raise RuntimeError(args_dict,
                               err_mana.get_error_message(args_dict))

        # Third : value check, Mainly required by the convolution rule
        if ((fmap_w - filter_w_dilation + pad_left + pad_right) // stride_w +
                1) != dedy_w:
            args_dict = {'errCode': 'E60025'}
            raise RuntimeError(args_dict,
                               err_mana.get_error_message(args_dict))
        if ((fmap_h - filter_h_dilation + pad_up + pad_down) // stride_h +
                1) != dedy_h:
            args_dict = {'errCode': 'E60024'}
            raise RuntimeError(args_dict,
                               err_mana.get_error_message(args_dict))

    _check_axis_hw()

    def _min_l1_byte():
        # Forth : L1 limitation, Mainly required by chip
        al1_min_byte = C0 * C0 * 2

        if dedy_w % C0 == 0:
            bl1_min_byte = filter_h_dilation * fmap_w * C0 * 2
        else:
            bl1_min_byte = (filter_h_dilation + stride_h) * fmap_w * C0 * 2
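        # Worked example (illustrative, assuming C0 = 16): filter_h_dilation = 3,
        # stride_h = 2 and fmap_w = 56 with dedy_w not a multiple of C0 give
        # bl1_min_byte = (3 + 2) * 56 * 16 * 2 = 8960 bytes, while al1_min_byte
        # is 16 * 16 * 2 = 512 bytes.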

        l1_size = get_soc_spec("L1_SIZE")  # L1 size
        if (al1_min_byte + bl1_min_byte) > l1_size:
            args_dict = {'errCode': 'E60022'}
            raise RuntimeError(args_dict,
                               err_mana.get_error_message(args_dict))

    _min_l1_byte()
    # Fifth : check shape size, 64 bits limitation
    c0_size = cce_params.C0_SIZE
    fmap_size = fmap_batch * fmap_d * align(fmap_channel,
                                            c0_size) * fmap_h * fmap_w
    dedy_size = dedy_batch * dedy_d * align(dedy_channel,
                                            c0_size) * dedy_h * dedy_w
    filter_size = \
        align(filter_batch, c0_size) * filter_d * align(filter_channel, c0_size) \
        * filter_h * filter_w
    _check_64bits_limitation("fmap_size", fmap_size, dtype=x_dtype)
    _check_64bits_limitation("dedy_size", dedy_size, dtype=out_backprop_dtype)
    _check_64bits_limitation("filter_size", filter_size, dtype=res_dtype)

    result = (shape_x, shape_out_backprop, filter_sizes, strides, pads,
              dilations, x_dtype, out_backprop_dtype, res_dtype, kernel_name)
    return result
def fake_quant_min_max_per_channel_update(
        x,
        min_val,
        max_val,
        min_up,
        max_up,
        ema,
        ema_decay,
        symmetric,
        narrow_range,
        training,
        num_bits,
        channel_axis,
        kernel_name="fake_quant_min_max_per_channel_update"):
    """FakeQuantPerLayer op"""
    x_shape = x.get("ori_shape")
    x_format = x.get("format")
    x_dtype = x.get("dtype")
    min_shape = min_val.get("ori_shape")
    min_dtype = min_val.get("dtype")
    max_shape = max_val.get("ori_shape")
    max_dtype = max_val.get("dtype")

    util.check_kernel_name(kernel_name)
    util.check_shape_rule(x_shape)
    util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis])
    util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis])
    util.check_tensor_shape_size(x_shape)
    util.check_tensor_shape_size(min_shape)
    util.check_tensor_shape_size(max_shape)

    check_list = ["float32", "float16"]
    x_dtype = x_dtype.lower()
    min_dtype = min_dtype.lower()
    max_dtype = max_dtype.lower()
    util.check_dtype_rule(x_dtype, check_list)
    util.check_dtype_rule(min_dtype, check_list)
    util.check_dtype_rule(max_dtype, check_list)

    if symmetric:
        quant_min = 0 - 2**(num_bits - 1)
        quant_max = 2**(num_bits - 1) - 1
    else:
        quant_min = 0
        quant_max = 2**num_bits - 1
    if narrow_range:
        quant_min = quant_min + 1

    shape_c = [min_val.get("shape")[1], min_val.get("shape")[-1]]
    input_data = tvm.placeholder(x.get("shape"), name="x", dtype=x_dtype)
    min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
    max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype)
    res_list = fake_quant_min_max_per_channel_update_compute(
        input_data, min_data, max_data, ema, ema_decay, quant_min, quant_max,
        training, channel_axis, kernel_name)

    with tvm.target.cce():
        sch = generic.auto_schedule(res_list)

    tensor_list = [input_data, min_data, max_data] + list(res_list)
    config = {
        "print_ir": False,
        "name": kernel_name,
        "tensor_list": tensor_list
    }

    te.lang.cce.cce_build_code(sch, config)
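# Illustrative invocation with assumed shapes (not taken from the original
# source): an NC1HWC0 activation whose min/max are tracked per channel along
# channel_axis = 1.
#
# fake_quant_min_max_per_channel_update(
#     x={"shape": (2, 4, 16, 16, 16), "ori_shape": (2, 64, 16, 16),
#        "format": "NC1HWC0", "dtype": "float16"},
#     min_val={"shape": (1, 4, 1, 1, 16), "ori_shape": (64,), "dtype": "float16"},
#     max_val={"shape": (1, 4, 1, 1, 16), "ori_shape": (64,), "dtype": "float16"},
#     min_up=None, max_up=None, ema=True, ema_decay=0.999, symmetric=False,
#     narrow_range=False, training=True, num_bits=8, channel_axis=1)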
def fake_quant_per_layer_grad(dout,
                              x,
                              min_val,
                              max_val,
                              dx,
                              num_bits,
                              symmetric,
                              narrow_range,
                              kernel_name="fake_quant_per_layer_grad"):
    """FakeQuantPerLayerGrad"""
    input_shape = x.get("shape")
    input_dtype = x.get("dtype")
    min_shape = min_val.get("ori_shape")
    min_dtype = min_val.get("dtype")
    max_shape = max_val.get("ori_shape")
    max_dtype = max_val.get("dtype")

    min_shape = util.scalar2tensor_one(min_shape)
    max_shape = util.scalar2tensor_one(max_shape)
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(input_shape)
    util.check_shape_rule(min_shape, 1, 1, 1)
    util.check_shape_rule(max_shape, 1, 1, 1)
    util.check_tensor_shape_size(input_shape)
    util.check_tensor_shape_size(min_shape)
    util.check_tensor_shape_size(max_shape)

    check_list = ["float32", 'float16']
    x_dtype = input_dtype.lower()
    min_dtype = min_dtype.lower()
    max_dtype = max_dtype.lower()
    util.check_dtype_rule(x_dtype, check_list)
    util.check_dtype_rule(min_dtype, check_list)
    util.check_dtype_rule(max_dtype, check_list)

    input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]), )
    shape_min, _, _ = util.produce_shapes(min_shape, input_shape)

    quant_min = 0
    quant_max = 2**num_bits - 1
    if narrow_range:
        quant_min = quant_min + 1

    dout_data = tvm.placeholder(input_shape, name="dout", dtype=x_dtype)
    input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
    min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
    max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
    res = fake_quant_per_layer_grad_compute(dout_data, input_data, min_data,
                                            max_data, quant_min, quant_max,
                                            symmetric, kernel_name)

    with tvm.target.cce():
        sch = generic.auto_schedule(res)

    tensor_list = [dout_data, input_data, min_data, max_data, res]
    config = {
        "print_ir": False,
        "name": kernel_name,
        "tensor_list": tensor_list
    }

    te.lang.cce.cce_build_code(sch, config)
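# A hedged NumPy sketch of the usual straight-through-estimator behaviour for
# the FakeQuant gradient: pass dout through where x lies inside [min_val,
# max_val] and zero it elsewhere. The real fake_quant_per_layer_grad_compute
# may nudge the boundaries before comparing.
import numpy as np

def fake_quant_grad_reference(dout, x, min_val, max_val):
    inside = np.logical_and(x >= min_val, x <= max_val)
    return dout * inside.astype(dout.dtype)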
Example No. 20
def correction_mul_grad(dout,
                        x,
                        batch_std,
                        running_std,
                        dx,
                        d_batch_std,
                        channel,
                        kernel_name="correction_mul_grad"):
    """CorrectionMulGrad op"""
    shape_dout = dout.get("shape")
    shape_x = dout.get("shape")

    dtype_dout = dout.get("dtype")
    dtype_x = x.get("dtype")
    dtype_batch_std = batch_std.get("dtype")
    dtype_running_std = running_std.get("dtype")

    inp_dtype_dout = dtype_dout.lower()
    inp_dtype_x = dtype_x.lower()
    inp_dtype_batch_std = dtype_batch_std.lower()
    inp_dtype_running_std = dtype_running_std.lower()

    util.check_dtype_rule(inp_dtype_dout, ("float16", "float32"))
    util.check_dtype_rule(inp_dtype_x, ("float16", "float32"))
    util.check_dtype_rule(inp_dtype_batch_std, ("float32", ))
    util.check_dtype_rule(inp_dtype_running_std, ("float32", ))
    util.compare_tensor_dict_key(dout, x, "dtype")
    util.compare_tensor_dict_key(dout, x, "shape")
    util.compare_tensor_dict_key(dx, x, "shape")
    util.compare_tensor_dict_key(batch_std, running_std, "shape")
    util.compare_tensor_dict_key(batch_std, d_batch_std, "shape")

    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape_x)
    util.check_shape_size(shape_x, SHAPE_SIZE_LIMIT)

    data_format = dout.get("format")
    ori_format = dout.get("format")
    if data_format.upper() not in ("NC1HWC0", "NCHW"):
        raise RuntimeError("Un supported data format {}".format(data_format))
    if data_format.upper() == "NCHW" and ori_format != "NCHW":
        raise RuntimeError("data_format(NCHW) must same as ori_format")

    shape_c = [1] * len(shape_x)
    shape_c[channel] = batch_std.get("ori_shape")[0]
    if data_format == "NC1HWC0" and channel == 1:
        shape_c = batch_std.get("shape")

    dout_t = tvm.placeholder(shape_dout, name="dout", dtype=inp_dtype_dout)
    x_t = tvm.placeholder(shape_x, name="x", dtype=inp_dtype_x)
    batch_std_t = tvm.placeholder(shape_c,
                                  name="batch_std",
                                  dtype=inp_dtype_batch_std)
    running_std_t = tvm.placeholder(shape_c,
                                    name="running_std",
                                    dtype=inp_dtype_running_std)
    res_list = correction_mul_grad_compute(dout_t, x_t, batch_std_t,
                                           running_std_t, channel, data_format,
                                           kernel_name)

    with tvm.target.cce():
        sch = generic.auto_schedule(res_list)

    tensor_list = [dout_t, x_t, batch_std_t, running_std_t] + list(res_list)
    config = {
        "print_ir": False,
        "name": kernel_name,
        "tensor_list": tensor_list
    }

    te.lang.cce.cce_build_code(sch, config)
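# Assuming the forward op computes y = x * (batch_std / running_std) broadcast
# along `channel` (an assumption about correction_mul, not stated in this
# file), its gradients could be sketched in NumPy as follows.
import numpy as np

def correction_mul_grad_reference(dout, x, batch_std, running_std, channel):
    shape_c = [1] * x.ndim
    shape_c[channel] = -1
    factor = (batch_std / running_std).reshape(shape_c)
    dx = dout * factor                      # d(y)/d(x) = batch_std / running_std
    reduce_axes = tuple(i for i in range(x.ndim) if i != channel)
    d_batch_std = np.sum(dout * x / running_std.reshape(shape_c),
                         axis=reduce_axes)  # d(y)/d(batch_std) = x / running_std
    return dx, d_batch_std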
Example No. 21
def custom_expm1(shape,
                 dtype,
                 kernel_name="cce_tf_expm1",
                 need_build=False,
                 need_print=False):
    """
    algorithm: expm1

    calculating data's expm1, y = (e ** x) - 1, dtype is float16 or float32.

    Parameters
    ----------
    shape : shape of data.

    dtype : the data type, assume src_dtype equals dst_dtype, only support float16, float32.

    kernel_name : cce kernel name, default value is "cce_tf_expm1".

    need_build : if need to build CCEC kernel, default value is False.

    need_print : if need to print the ir, default value is False.

    Returns
    -------
    None

    """

    # [aicpu] int32_t cc_device_exp(uint32_t blockNum, uint32_t blockIdx, int32_t dataType, const void *scale, const void *shift,
    # const void *base, int32_t dimCnt, int32_t *shape, uint32_t padC0, const void *x, void *y);

    supported_dtypes = ["float16", "float32"]

    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape)
    util.check_shape_size(shape, SHAPE_SIZE_LIMIT)

    if not (dtype.lower() in supported_dtypes):
        raise RuntimeError("tf_expm1_cce only support %s while dtype is %s" %
                           (",".join(supported_dtypes), dtype))

    inp_dtype = dtype.lower()
    shape = util.shape_refine(shape)
    data_input = tvm.placeholder(shape, name="data_input", dtype=inp_dtype)

    # step 1. calculate y = e ** x by the aicpu api
    device_api = "DeviceExp"
    v_datatype = util.get_device_api_dtype(inp_dtype)
    v_ndim = len(shape)
    block_num = "block_num"
    block_idx = "block_idx"
    padC0 = 0
    p_scale = util.create_param_ptr([1], inp_dtype, "p_scale")
    p_shift = util.create_param_ptr([0], inp_dtype, "p_shift")
    p_base = util.create_param_ptr([-1], inp_dtype, "p_base")
    p_shape = util.create_param_ptr(shape, "int32", "p_shape")

    output_exp = tvm.extern(
        shape,
        [data_input, p_scale, p_shift, p_base, p_shape],
        lambda ins, outs: tvm.call_extern(
            "int32_t",
            device_api,
            block_num,
            block_idx,
            v_datatype,
            ins[1].access_ptr("r"),  # scale
            ins[2].access_ptr("r"),  # shift
            ins[3].access_ptr("r"),  # base
            v_ndim,
            ins[4].access_ptr("r"),  # shape
            padC0,
            ins[0].access_ptr("r"),  # input x
            outs[0].access_ptr("w")),
        name="output_exp",
        dtype=inp_dtype)

    offset = tvm.const((-1), dtype=inp_dtype)

    # step 2. calculate y = (e ** x) - 1 by tvm
    output = tvm.compute(
        shape,
        lambda *indice: output_exp(*indice) + offset.astype(inp_dtype),
        name="output")

    # step 3. schedule the computation by tvm
    s = tvm.create_schedule(output.op)

    # step 4. build by tvm
    if need_print:
        with build_config:
            print(tvm.lower(s, [data_input, output], simple_mode=True))
    if need_build:
        with build_config:
            tvm.build(s, [data_input, output], "cce", name=kernel_name)
Example No. 22
def fake_quant_perchannel(x,
                          min_val,
                          max_val,
                          y,
                          symmetric,
                          narrow_range,
                          num_bits,
                          channel_axis,
                          kernel_name="fake_quant_perchannel"):
    """FakeQuantPerChannel"""
    x_shape = x.get("shape")
    x_shape_ = x.get("ori_shape")
    x_format = x.get("format")
    x_dtype = x.get("dtype")
    min_shape = min_val.get("ori_shape")
    min_dtype = min_val.get("dtype")
    max_shape = max_val.get("ori_shape")
    max_dtype = max_val.get("dtype")
    # for Dense weight quant, 2d [co, ci] -> 4d [1, co, ci, 1], channel_axis_ needs to change to 1.
    if channel_axis == 0 and x_shape_[0] != min_shape[0] and x_shape_[
            1] == min_shape[0]:
        channel_axis_ = 1
    else:
        channel_axis_ = channel_axis
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(x_shape)
    util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis_])
    util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis_])
    util.check_tensor_shape_size(x_shape)
    util.check_tensor_shape_size(min_shape)
    util.check_tensor_shape_size(max_shape)

    check_list = ["float32", "float16"]
    x_dtype = x_dtype.lower()
    min_dtype = min_dtype.lower()
    max_dtype = max_dtype.lower()
    util.check_dtype_rule(x_dtype, check_list)
    util.check_dtype_rule(min_dtype, check_list)
    util.check_dtype_rule(max_dtype, check_list)

    if symmetric:
        quant_min = 0 - 2**(num_bits - 1)
        quant_max = 2**(num_bits - 1) - 1
    else:
        quant_min = 0
        quant_max = 2**num_bits - 1
    if narrow_range:
        quant_min = quant_min + 1

    shape_c = [1] * len(x_shape)
    shape_c[channel_axis_] = min_val.get("ori_shape")[0]
    if x_format == "NC1HWC0" and channel_axis_ == 1:
        shape_c = min_val.get("shape")
    input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype)
    min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
    max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype)
    res = fake_quant_perchannel_compute(input_data, min_data, max_data, y,
                                        quant_min, quant_max, kernel_name)

    with tvm.target.cce():
        sch = generic.auto_schedule(res)

    tensor_list = [input_data, min_data, max_data, res]
    config = {
        "print_ir": False,
        "name": kernel_name,
        "tensor_list": tensor_list
    }

    te.lang.cce.cce_build_code(sch, config)
Example No. 23
def custom_Upsample(shape,
                    dtype,
                    scale,
                    data_format="channels_last",
                    kernel_name="cce_darknet_upsample",
                    need_build=False,
                    need_print=False):
    """
    Parameters
    ----------
    shape: input tensor's shape

    dtype: input tensor's dtype, support: float16, float32, int32, int8, uint8

    scale: the upsampling factors

    data_format: "channels_last" or "channels_first"

    kernel_name : kernel name, default value is "MyUpsample"

    need_build : if need to build CCEC kernel, default value is False

    need_print : if need to print the ir, default value is False

    Returns
    -------
    None
    """
    """
    TODO:
    Please refer to the TE DSL Manual, And code here with TE DSL.
    """
    inp_dtype = dtype.lower()
    check_list = ["float16", "float32", "int32", "int8", "uint8"]
    if inp_dtype not in check_list:
        raise RuntimeError("upsample only support %s while dtype is %s" %
                           (",".join(check_list), dtype))

    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape)
    util.check_shape_size(shape, SHAPE_SIZE_LIMIT)
    size = (scale, scale)

    shape_size = len(shape)
    if not (shape_size == 4 or shape_size == 5):
        raise RuntimeError(
            "upsample only supports 4-D or 5-D while len(shape) is %d" % len(shape))

    input_tensor = tvm.placeholder(shape, name="input_tensor", dtype=inp_dtype)

    res = None
    if shape_size == 5:
        # shape_size == 5: 5-D special format (N, C1, H, W, C0)
        output_shape = (shape[0], shape[1], shape[2] * size[0],
                        shape[3] * size[1], shape[4])
        res = tvm.compute(
            output_shape, lambda n, c0, h, w, c: input_tensor[n, c0, h // size[
                0], w // size[1], c])
    else:
        if data_format == "channels_last":
            output_shape = (shape[0], shape[1] * size[0], shape[2] * size[1],
                            shape[3])
            res = tvm.compute(
                output_shape, lambda n, h, w, c: input_tensor[n, h // size[0],
                                                              w // size[1], c])
        elif data_format == "channels_first":
            output_shape = (shape[0], shape[1], shape[2] * size[0],
                            shape[3] * size[1])
            res = tvm.compute(
                output_shape, lambda n, c, h, w: input_tensor[n, c, h // size[
                    0], w // size[1]])
        else:
            raise RuntimeError(
                "upsample only supports channels_last|channels_first "
                "while data_format is %s" % data_format)

    schedule = tvm.create_schedule(res.op)
    if need_print:
        with build_config:
            print(tvm.lower(schedule, [input_tensor, res], simple_mode=True))

    if need_build:
        with build_config:
            tvm.build(schedule, [input_tensor, res], "cce", name=kernel_name)
Example No. 24
def custom_greater_equal(shape_x,
                         shape_y,
                         dtype,
                         kernel_name="cce_tf_greater_equal",
                         need_build=False,
                         need_print=False):
    """
    do element-wise greater equal operation between two input tensors

    Parameters:
    ----------
    shape_x : shape of input data1

    shape_y : shape of input data2

    dtype : source data type, support [float16, float32, int32, int8, uint8, bool]

    kernel_name : cce kernel name, default value is "cce_tf_greater_equal"

    need_build : if need to build CCEC kernel, default value is False

    need_print : if need to print the ir, default value is False

    Returns
    -------
    None
    """

    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape_x)
    util.check_shape_rule(shape_y)

    check_list = ["float16", "float32", "int32", "int8", "uint8", "bool"]

    dtype = dtype.lower()
    if not (dtype in check_list):
        raise RuntimeError("tf_equal_cce only support %s while dtype is %s" %
                           (",".join(check_list), dtype))

    util.check_shape_size(shape_x, SHAPE_SIZE_LIMIT)
    util.check_shape_size(shape_y, SHAPE_SIZE_LIMIT)

    shape_x, shape_y, shape_max = util.produce_shapes(shape_x, shape_y)

    util.check_shape_size(shape_max, SHAPE_SIZE_LIMIT)

    data1 = tvm.placeholder(shape_x, dtype=dtype, name="data1")
    data2 = tvm.placeholder(shape_y, dtype=dtype, name="data2")

    data1_tmp1 = te.lang.cce.broadcast(data1, shape_max)
    data2_tmp1 = te.lang.cce.broadcast(data2, shape_max)

    res = tvm.compute(shape_max,
                      lambda *i: data1_tmp1(*i) >= data2_tmp1(*i),
                      name='res')

    sch = tvm.create_schedule(res.op)

    if need_print:
        with build_config:
            print(tvm.lower(sch, [data1, data2, res], simple_mode=True))

    if need_build:
        with build_config:
            tvm.build(sch, [data1, data2, res], "cce", name=kernel_name)
def CusMatMulCubeDenseLeft(input_x1,
                           input_x2,
                           bias=None,
                           output_y={},
                           trans_a=False,
                           trans_b=False,
                           kernel_name="matmulcube"):
    """
    calculating  matrix multiplication with bias, C = A*B + bias, support input
    data with fractal format.

    Parameters:
    shape_a: list or tuple
            Shape of the first tensor a with rank > 1
    shape_b:  list or tuple
            Shape of the second tensor b with the same type with a,
            and shape_a, shape_b must be 2 dims
    src_dtype: str
            The data type of input, support "float32", "float16"
    dst_dtype: str
            The data type of output, support "float32", "float16"
    trans_a: bool
            If True, shape_a is transposed before multiplication
    trans_b: bool
            If True, shape_b is transposed before multiplication
    is_fractal: bool
            If True, the input data format of a and b must be fractal format
    shape_bias: list or tuple
            Shape of bias, only support the input data format with ND

    Returns
    -------
    None
    """
    print("!!!!come into zzt~~~~~~~!!!!")
    shape_a = input_x1.get("ori_shape")
    shape_b = input_x2.get("ori_shape")
    shape_output = output_y.get("ori_shape")
    print("============")
    print(input_x1.get("format"), input_x2.get("format"))
    print(shape_a, shape_b)
    print("============")
    if input_x2.get("format") == "FRACTAL_Z":
        n, c, h, w = shape_b
        c0 = 16
        c1 = c // c0
        if c1 == 0:
            c1 = 1
        shape_b = [n, c1 * h * w * c0]
        shape_a = [n, n]

    if input_x1.get("format") == "FRACTAL_Z":
        n, c, h, w = shape_a
        c0 = 16
        c1 = c // c0
        if c1 == 0:
            c1 = 1
        shape_a = [n, c1 * h * w * c0]
        shape_b = [c1 * h * w * c0, c1 * h * w * c0]

    if input_x2.get("format") == "FRACTAL_NZ":
        shape_a = [shape_b[0], shape_b[0]]
        shape_b = shape_b

    if input_x1.get("format") == "FRACTAL_NZ":
        shape_a = shape_a
        shape_b = [shape_a[1], shape_a[1]]

    shape_a = list(shape_a)
    shape_b = list(shape_b)

    shape_a = _get_input_shape(shape_a)
    shape_b = _get_input_shape(shape_b)

    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape_a)
    util.check_shape_rule(shape_b)
    util.check_shape_size(shape_a, SHAPE_SIZE_LIMIT)
    util.check_shape_size(shape_b, SHAPE_SIZE_LIMIT)

    shape_a = [shape_a[1], shape_a[0]]
    trans_a = bool(1 - trans_a)

    shape_b = [shape_b[1], shape_b[0]]
    trans_b = bool(1 - trans_b)

    shape_bias = ()
    if bias is not None and bool(bias):
        shape_bias = bias.get("shape")
        shape_bias = list(shape_bias)
        shape_bias = _get_bias(shape_bias)

    src_dtype = input_x1.get("dtype").lower()
    dst_dtype = output_y.get("dtype").lower()
    _shape_check(shape_a, shape_b, shape_bias, src_dtype, trans_a, trans_b)

    m_shape = shape_a[len(shape_a) - 2]
    km_shape = shape_a[len(shape_a) - 1]
    kn_shape = shape_b[len(shape_a) - 2]
    n_shape = shape_b[len(shape_a) - 1]

    if src_dtype == "float16":
        block_reduce = cce.BLOCK_REDUCE

    block_in = cce.BLOCK_IN
    block_out = cce.BLOCK_OUT

    if trans_a and km_shape == 1:
        block_in = cce.BLOCK_VECTOR

    if not trans_a and m_shape == 1:
        block_in = cce.BLOCK_VECTOR

    if trans_b and kn_shape == 1:
        block_out = cce.BLOCK_VECTOR

    if not trans_b and n_shape == 1:
        block_out = cce.BLOCK_VECTOR

    if trans_a:
        shape_a_temp = (m_shape // block_reduce, km_shape // block_in,
                        block_reduce, block_in)
    else:
        shape_a_temp = (m_shape // block_in, km_shape // block_reduce,
                        block_in, block_reduce)

    if trans_b:
        shape_b_temp = (kn_shape // block_out, n_shape // block_reduce,
                        block_reduce, block_out)
    else:
        shape_b_temp = (kn_shape // block_reduce, n_shape // block_out,
                        block_out, block_reduce)
    shape_a_temp = (shape_a_temp[0], shape_a_temp[1], shape_a_temp[2],
                    shape_a_temp[3])
    format_a = "FRACTAL_NZ"
    shape_b_temp = (shape_b_temp[0], shape_b_temp[1], shape_b_temp[2],
                    shape_b_temp[3])
    format_b = "FRACTAL_NZ"

    print("=======================================")
    print(shape_a_temp, shape_b_temp)
    print(format_a, format_b)
    print("=======================================")
    tensor_bias = None
    tensor_a = tvm.placeholder(shape_a_temp, name='tensor_a', dtype=src_dtype)
    tensor_b = tvm.placeholder(shape_b_temp, name='tensor_b', dtype=src_dtype)

    if shape_bias:
        tensor_bias = tvm.placeholder(shape_bias,
                                      name='tensor_bias',
                                      dtype=dst_dtype)

    if shape_a_temp[0] == 63 and shape_a_temp[1] == 63 and shape_b_temp[
            0] == 128 and shape_b_temp[1] == 63:
        if util.get_product_version() == util.VERSION_MINI:
            tik_instance = tik.Tik(tik.Dprofile("v100", "mini"))
        else:
            tik_instance = tik.Tik(tik.Dprofile("v100", "cloud"))

        input_x1 = tik_instance.Tensor("float16",
                                       shape_a_temp,
                                       name="left_matrix",
                                       scope=tik.scope_gm)
        input_x2 = tik_instance.Tensor("float16",
                                       shape_b_temp,
                                       name="right_matrix",
                                       scope=tik.scope_gm)
        resMatmul = tik_instance.Tensor("float16",
                                        shape_output,
                                        name="output",
                                        scope=tik.scope_gm)
        with tik_instance.for_range(0, 32, block_num=32) as block_index:
            resMatmul_local_UB = tik_instance.Tensor("float16", (128 * 256, ),
                                                     scope=tik.scope_ubuf,
                                                     name="resMatmul_local_UB")
            resMatmul_local_UB_local_L0C = tik_instance.Tensor(
                "float32", (128 * 256, ),
                scope=tik.scope_cc,
                name="resMatmul_local_UB")
            input_1_local_L1_local_L0A = tik_instance.Tensor(
                "float16", (128 * 128, ),
                scope=tik.scope_ca,
                name="input_1_local_L1_local_L0A")
            input_2_local_L1 = tik_instance.Tensor("float16", (128 * 256, ),
                                                   scope=tik.scope_cbuf,
                                                   name="input_2_local_L1")
            input_1_local_L1 = tik_instance.Tensor("float16", (128 * 128, ),
                                                   scope=tik.scope_cbuf,
                                                   name="input_1_local_L1")
            input_2_local_L1_local_L0B = tik_instance.Tensor(
                "float16", (128 * 256, ),
                scope=tik.scope_cb,
                name="input_2_local_L1_local_L0B")
            core_m_idx = block_index % 8
            core_n_idx = block_index // 8
            with tik_instance.if_scope(core_m_idx != 7):
                tik_instance.data_move(
                    input_1_local_L1,
                    input_x1[core_m_idx * (8 * 256 + 128 * 1008)], 0, 8, 128,
                    55 * 16, 0)
                tik_instance.data_move(
                    input_2_local_L1,
                    input_x2[core_m_idx * 8 * 256 + core_n_idx * 512 * 1008],
                    0, 32, 128, 55 * 16, 0)
                with tik_instance.for_range(0, 8) as cc12:
                    tik_instance.load2dv1(
                        input_1_local_L1_local_L0A[cc12 * 2048],
                        input_1_local_L1[cc12 * 256], 0, 8, 8, 0, False)
                with tik_instance.for_range(0, 2) as cc6:
                    with tik_instance.for_range(0, 8) as cc121:
                        tik_instance.load2dv1(
                            input_2_local_L1_local_L0B[cc121 * 4096],
                            input_2_local_L1[cc6 * 32768 + cc121 * 256], 0, 16,
                            8, 0, True)
                    tik_instance.mmad(resMatmul_local_UB_local_L0C,
                                      input_1_local_L1_local_L0A,
                                      input_2_local_L1_local_L0B, 128, 128,
                                      256, 0)
                    tik_instance.data_move(resMatmul_local_UB,
                                           resMatmul_local_UB_local_L0C, 0, 1,
                                           128, 0, 0, 1)
                    tik_instance.data_move(
                        resMatmul[cc6 * 256 * 1008 + core_m_idx * 8 * 256 +
                                  core_n_idx * 512 * 1008], resMatmul_local_UB,
                        0, 16, 256 // 2, 0, 55 * 16 * 2 // 2)
            with tik_instance.else_scope():
                tik_instance.data_move(
                    input_1_local_L1,
                    input_x1[core_m_idx * (8 * 256 + 128 * 1008)], 0, 7, 112,
                    56 * 16, 0)
                tik_instance.data_move(
                    input_2_local_L1,
                    input_x2[core_m_idx * 8 * 256 + core_n_idx * 512 * 1008],
                    0, 32, 112, 56 * 16, 0)
                with tik_instance.for_range(0, 7) as cc10:
                    tik_instance.load2dv1(
                        input_1_local_L1_local_L0A[cc10 * 1792],
                        input_1_local_L1[cc10 * 256], 0, 7, 7, 0, False)
                with tik_instance.for_range(0, 2) as cc5:
                    with tik_instance.for_range(0, 7) as cc101:
                        tik_instance.load2dv1(
                            input_2_local_L1_local_L0B[cc101 * 4096],
                            input_2_local_L1[cc5 * 28672 + cc101 * 256], 0, 16,
                            7, 0, True)
                    tik_instance.mmad(resMatmul_local_UB_local_L0C,
                                      input_1_local_L1_local_L0A,
                                      input_2_local_L1_local_L0B, 112, 112,
                                      256, 0)
                    tik_instance.data_move(resMatmul_local_UB,
                                           resMatmul_local_UB_local_L0C, 0, 1,
                                           112, 0, 0, 1)
                    tik_instance.data_move(
                        resMatmul[cc5 * 256 * 1008 + core_m_idx * 8 * 256 +
                                  core_n_idx * 512 * 1008], resMatmul_local_UB,
                        0, 16, 224 // 2, 0, 56 * 16 * 2 // 2)
        tik_instance.BuildCCE(kernel_name=kernel_name,
                              inputs=[input_x1, input_x2],
                              outputs=[resMatmul])
        return tik_instance

    print("come into tbe, shape is error!")
    result = te.lang.cce.matmul(tensor_a,
                                tensor_b,
                                trans_a,
                                trans_b,
                                format_a=format_a,
                                format_b=format_b,
                                dst_dtype=dst_dtype,
                                tensor_bias=tensor_bias)

    with tvm.target.cce():
        schedule = generic.auto_schedule(result)

    tensor_list = [tensor_a, tensor_b, result]
    if shape_bias:
        tensor_list = [tensor_a, tensor_b, tensor_bias, result]

    config = {
        "print_ir": False,
        "name": kernel_name,
        "tensor_list": tensor_list
    }

    te.lang.cce.cce_build_code(schedule, config)
Example No. 26
def custom_Reduction(shape,
                     dtype,
                     axis,
                     op,
                     coeff,
                     kernel_name="cce_reductionLayer",
                     need_build=False,
                     need_print=False):
    """
    Reduce a tensor on a certain axis, and scale output with coeff

    Parameters
    ----------
    shape : shape of data

    dtype : source data type, only support float16, float32, int8, uint8

    axis : the first axis to reduce, may be negative to index from the end
           (e.g., -1 for the last axis).
           If axis == 0, the output Blob always has the empty shape (count 1),
           performing reduction across the entire input.

    op : can only be one of "SUM, ASUM (sum of abs), SUMSQ (sum of sqr), MEAN"

    coeff : scale for output

    kernel_name : cce kernel name, default value is "cce_reductionLayer"

    need_build : if need to build CCEC kernel, default value is False

    need_print : if need to print the ir, default value is False

    Returns
    -------
    None

    """
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape)

    check_list = ["float16", "float32", "int8", "uint8"]
    if not dtype.lower() in check_list:
        raise RuntimeError(
            "reductionLayer_cce only supports %s while dtype is %s" %
            (",".join(check_list), dtype))

    reduction_op = ("SUM", "ASUM", "SUMSQ", "MEAN")

    if not isinstance(axis, int):
        raise RuntimeError("type of axis value should be int")
    if op not in reduction_op:
        raise RuntimeError("op can only be one of SUM, ASUM, SUMSQ , MEAN")
    if not isinstance(coeff, (int, float)):
        raise RuntimeError("coeff must be an int or a float")
    axis_origin = axis
    shape_origin = shape
    axis = util.axis_check(len(shape), axis)
    util.check_reduce_shape_rule(shape)
    shape = list(shape)
    shape1 = shape[:axis] + [
        functools_reduce(lambda x, y: x * y, shape[axis:])
    ]
    shape1, axis = util.shape_refine(shape1, axis)
    if not axis:
        axis = [0]
        shape1 = [1] + shape1
    inp_dtype = dtype.lower()
    data = tvm.placeholder(shape1, name="data_input", dtype=inp_dtype)
    with tvm.target.cce():
        res = caffe_reduction_layer_compute([data], shape_origin, dtype,
                                            axis_origin, op, coeff,
                                            kernel_name, need_build,
                                            need_print)

    if op == "MEAN" and (inp_dtype == "int8" or inp_dtype == "uint8"):
        util.check_shape_size(shape, SHAPE_SIZE_LIMIT)
        res = te.lang.cce.cast_to(res, inp_dtype)
        schedule = tvm.create_schedule(res.op)
        if need_print:
            with build_config:
                print(tvm.lower(schedule, [data, res], simple_mode=True))
        if need_build:
            with build_config:
                tvm.build(schedule, [data, res], "cce", name=kernel_name)
    else:
        with tvm.target.cce():
            sch = generic.auto_schedule(res)

        config = {
            "print_ir": need_print,
            "need_build": need_build,
            "name": kernel_name,
            "tensor_list": [data, res]
        }
        te.lang.cce.cce_build_code(sch, config)
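# A hedged NumPy sketch of the Caffe-style reduction described in the
# docstring: flatten every axis from `axis` onwards, reduce, and scale by
# coeff. The on-device compute may differ in dtype handling.
import numpy as np

def reduction_reference(x, axis, op="SUM", coeff=1.0):
    axis = axis % x.ndim
    flat = x.reshape(x.shape[:axis] + (-1,))
    if op == "ASUM":
        flat = np.abs(flat)
    elif op == "SUMSQ":
        flat = flat * flat
    out = flat.mean(axis=-1) if op == "MEAN" else flat.sum(axis=-1)
    return coeff * out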
Example No. 27
def custom_floor(shape,
                 dtype,
                 kernel_name="cce_tf_floor",
                 need_build=False,
                 need_print=False):
    """
    calculate floor(data), calculating data type is float16 or float32

    Parameters
    ----------
    shape : shape of data

    dtype : source data type, assume src_dtype equals dst_type, only support
    float16 or float32

    kernel_name : cce kernel name, default value is "cce_tf_floor"

    need_build : if need to build CCEC kernel, default value is False

    need_print : if need to print the ir, default value is False

    Returns
    -------
    None

    """
    check_list = ["float16", "float32"]
    device_api_map = {
        "float16": "cc_device_floor_float16",
        "float32": "cc_device_floor_float"
    }

    max_dim = 8
    shape_len = len(shape)
    if shape_len > max_dim:
        raise RuntimeError(
            "floor_cce only supports up to %d dimensions while the shape's "
            "dimension is %d" % (max_dim, shape_len))

    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape)
    util.check_shape_size(shape, SHAPE_SIZE_LIMIT)

    if not dtype.lower() in check_list:
        raise RuntimeError("floor_cce only support %s while dtype is %s" %
                           (",".join(check_list), dtype))

    inp_dtype = dtype.lower()
    shape = util.shape_refine(shape)
    data_input = tvm.placeholder(shape, name="data_input", dtype=inp_dtype)
    device_api = device_api_map[inp_dtype]

    block_num = "block_num"
    block_idx = "block_idx"
    v_ndim = tvm.const(len(shape), "int32")
    pad_c0 = tvm.const(0, "int32")
    p_shape = util.create_param_ptr(shape, "int32", "p_shape")

    output = tvm.extern(
        shape,
        [data_input, p_shape],
        lambda ins, outs: tvm.call_extern(
            "int32_t",
            device_api,
            block_num,
            block_idx,
            v_ndim,
            ins[1].access_ptr("r"),  # shape
            pad_c0,
            ins[0].access_ptr("r"),  # input x
            outs[0].access_ptr("w")),
        name="output",
        dtype=inp_dtype)

    schedule = tvm.create_schedule(output.op)

    if need_print:
        with build_config:
            print(tvm.lower(schedule, [data_input, output], simple_mode=True))
    if need_build:
        with build_config:
            tvm.build(schedule, [data_input, output], "cce", name=kernel_name)
Example No. 28
def custom_batch_matmul(shape_x,
                        shape_y,
                        dtype,
                        trans_a=False,
                        trans_b=False,
                        kernel_name="cce_tf_batch_matmul",
                        need_build=False,
                        need_print=False):
    """
    Multiplies slices of two tensors in batches(each slice can be viewed
    as an element of a batch), the output is of the same batch size.

    Each of the individual slices can optionally be transposed before
    multiplication by setting the trans_a or trans_b flag to True, which
    are by default False. The input tensors are 2-D or higher with the
    shape [..., r_x, c_x] and [..., r_y, c_y]. The output tensor is 2-D
    or higher with the shape [..., r_o, c_o], where
    r_o = c_x if trans_a else r_x
    c_o = r_y if trans_b else c_y

    Parameters
    ----------
    shape_x : shape of the first tensor x with rank > 1

    shape_y : shape of the second tensor y with the same type and shape with x

    dtype : the data type, support int8, uint8, float16, float32, int32

    kernel_name : cce kernel name, default value is "cce_batch_matmul"

    trans_a : if True, shape_x is transposed before multiplication

    trans_b : if True, shape_y is transposed before multiplication

    need_build : if need to build CCEC kernel, default value is False

    need_print : if need to print the ir, default value is False

    Returns
    -------
    None
    """
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape_x)
    util.check_shape_rule(shape_y)

    util.check_shape_size(shape_x, SHAPE_SIZE_LIMIT)
    util.check_shape_size(shape_y, SHAPE_SIZE_LIMIT)

    data_dtype = dtype.lower()
    check_list = ["int8", "uint8", "float16", "float32", "int32"]
    if data_dtype not in check_list:
        raise RuntimeError(
            "batch_matmul_cce only supports %s while dtype is %s" %
            (",".join(check_list), dtype))

    def transpose_tensor(shape, size):
        """Transpose the shape, e.g., the shape [..., r_x, c_x] is transposed
        to [..., c_x, r_x].

        Parameters
        ----------
        shape : shape of a tensor

        size : length of the shape

        Returns
        -------
        shape_ori : the transposed shape
        """
        shape_ori = ()
        if size == 1:
            shape_ori = shape_ori + shape
        elif size == 2:
            shape_ori = shape_ori + (shape[1], ) + (shape[0], )
        else:
            shape_ori = shape_ori + (shape[:(size - 2)]) + (
                shape[size - 1], ) + (shape[size - 2], )
        return shape_ori

    def check_matmul(shape_x, shape_y):
        """Check whether batch_matmul is supported or not.

        Parameters
        ----------
        shape_x : shape of the first tensor x

        shape_y : shape of the second tensor y with the same type and shape
        with x

        Returns
        -------
        None
        """
        len_x = len(shape_x)
        len_y = len(shape_y)
        if (len_x < 2) or (len_y < 2):
            raise RuntimeError("Only tensors of rank>=2 are supported!")
        if shape_x[len_x - 1] != shape_y[len_y - 2]:
            raise RuntimeError(
                "Invalid matrix multiplication for the inner 2 dimensions!")
        if (len_x == len_y) and (len_x > 2):
            for i in range(len_x - 2):
                if shape_x[i] != shape_y[i]:
                    raise RuntimeError("Outer dimensions do not match!")
            return
        elif (len_x == len_y) and (len_x == 2):
            return
        else:
            raise RuntimeError("The input tensors are not with the same rank!")

    def _compute(output_shape, x, y, K, trans_a, trans_b, *indices):
        """matmul compuation in terms of the output shape and the transposes

        Parameters
        ----------
        output_shape : the final output shape, e.g., shape_x = (2, 6),
            shape_y = (8, 2), trans_a = True, trans_b = True, then,
            output_shape = (6, 8).

        x : the first input tensor according to shape_x.

        y : the second input tensor according to shape_y.

        K : the number of the axis for sum, in the above example, K = 2.

        trans_a : if True, x needs to be transposed.

        trans_b : if True, y needs to be transposed.

        *indices : the output shape space for tvm.compute.

        Returns
        -------
        tvm.Tensor
        """
        n_len = len(output_shape)
        k = tvm.reduce_axis((0, K), 'k')
        if trans_a is True and trans_b is False:
            # For example, A: (6, 7, 8), B: (6, 7, 9), so the length is n = 3
            # C = A' * B : (6, 8, 9), A' means the transpose of A
            # indices means the space of (6, 8, 9), k = 7
            # x_indices = indices[:1]+(7, )+indices[1:2] = (6, 7, 8)
            # y_indices = indices[:1]+(7, )+indices[2:] = (6, 7, 9)
            x_indices = indices[:(n_len - 2)] + (k, ) + indices[(n_len - 2):
                                                                (n_len - 1)]
            y_indices = indices[:(n_len - 2)] + (k, ) + indices[(n_len - 1):]
            return tvm.sum(x(*x_indices) * y(*y_indices), axis=k)
        elif not trans_a and trans_b:
            # For example, A: (6, 7, 8), B: (6, 9, 8), C = A * B' : (6, 7, 9)
            # indices means the space of (6, 7, 9), n=3, k = 8
            # x_indices = indices[:2]+(8, ) = (6, 7, 8)
            # y_indices = indices[:1]+indices[2:]+(8, ) = (6, 9, 8)
            x_indices = indices[:(n_len - 1)] + (k, )
            y_indices = indices[:(n_len - 2)] + indices[(n_len - 1):] + (k, )
            return tvm.sum(x(*x_indices) * y(*y_indices), axis=k)
        elif trans_a and trans_b:
            # For example, A: (6, 8, 10), B: (6, 12, 8), C = A' * B' : \
            # (6, 10, 12)
            # indices means the space of (6, 10, 12), n=3, k = 8
            # x_indices = indices[:1]+(8, )+indices[1:2] = (6, 8, 10)
            # y_indices = indices[:1]+indices[2:]+(8, ) = (6, 12, 8)
            x_indices = indices[:(n_len - 2)] + (k, ) + indices[(n_len - 2):
                                                                (n_len - 1)]
            y_indices = indices[:(n_len - 2)] + indices[(n_len - 1):] + (k, )
            return tvm.sum(x(*x_indices) * y(*y_indices), axis=k)
        else:
            # For example, A: (6, 15, 16), B: (6, 16, 18), C = A * B : \
            # (6, 15, 18)
            # indices means the space of (6, 15, 18), n=3, k = 16
            # x_indices = indices[:2]+(16, ) = (6, 15, 16)
            # y_indices = indices[:1]+(16, )+indices[2:] = (6, 16, 18)
            x_indices = indices[:(n_len - 1)] + (k, )
            y_indices = indices[:(n_len - 2)] + (k, ) + indices[(n_len - 1):]
            return tvm.sum(x(*x_indices) * y(*y_indices), axis=k)
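    # In every branch the result is the same contraction, written here for a
    # 3-D case as a plain formula (illustrative):
    #   C[b, i, j] = sum_k op_a(x)[b, i, k] * op_b(y)[b, k, j]
    # where op_a / op_b transpose the last two axes when trans_a / trans_b is
    # True; only the mapping of (b, i, j, k) back to the storage order of x
    # and y differs between the branches.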

    def check_supportted_shape_size(shape_x, shape_y, limit, trans_a, trans_b):
        """
        check shape size for operator
        ----------
        shape: shape of data

        limit: limit of the product

        Returns
        -------
        None
        """
        # Check whether the shape is so large that the kernel would time out.
        # With shape_x = (a, b, c, d, e, k) and shape_y = (a, b, c, d, k, f):
        #   t_1 : time consumed by each addition
        #   t_2 : time consumed by each multiplication
        #   t_all : time consumed by a complete calculation
        #   t_all is approximately (a*b*c*d) * (e*k*f) * (t_1 + t_2)
        # Since (t_1 + t_2) is a constant, t_all is proportional to
        # a * b * c * d * e * k * f.

        len_x = len(shape_x)
        len_y = len(shape_y)
        if (len_x < 2) or (len_y < 2):
            raise RuntimeError("Only tensors of rank>=2 are supported!")

        shape_x = list(shape_x)
        shape_y = list(shape_y)

        tmp_shape_x = shape_x[:]
        if trans_a:
            tmp_shape_x = shape_x[:-2] + [shape_x[-1], shape_x[-2]]

        tmp_shape_y = shape_y[:]
        if trans_b:
            tmp_shape_y = shape_y[:-2] + [shape_y[-1], shape_y[-2]]

        union_shape = tmp_shape_x + [tmp_shape_y[-1]]

        union_size = reduce(lambda i, j: i * j, union_shape)

        if union_size > limit:
            raise RuntimeError("the shape is too large to calculate")

    if data_dtype in ["float16", "float32", "int32"]:
        type_shape_map = {
            'float16': SHAPE_SIZE_FP16_LIMIT,
            'float32': SHAPE_SIZE_FP32_LIMIT,
            'int32': SHAPE_SIZE_INT32_LIMIT
        }

        check_supportted_shape_size(shape_x, shape_y,
                                    type_shape_map[data_dtype], trans_a,
                                    trans_b)

    x_size = len(shape_x)
    y_size = len(shape_y)
    shape_a = shape_x
    shape_b = shape_y
    if trans_a is True:
        shape_x = transpose_tensor(shape_x, x_size)

    if trans_b is True:
        shape_y = transpose_tensor(shape_y, y_size)

    check_matmul(shape_x, shape_y)
    last_axis = shape_x[x_size - 1]

    x_temp = tvm.placeholder(shape_a, name="input_1", dtype=data_dtype)
    y_temp = tvm.placeholder(shape_b, name="input_2", dtype=data_dtype)

    # output shape
    output_shape = ()
    for i in range(x_size - 1):
        output_shape = output_shape + (shape_x[i], )
    output_shape = output_shape + (shape_y[x_size - 1], )
    result = tvm.compute(
        output_shape,
        lambda *indices: _compute(output_shape, x_temp, y_temp, last_axis,
                                  trans_a, trans_b, *indices),
        name="result")
    schedule = tvm.create_schedule(result.op)

    if need_print:
        with build_config:
            print(
                tvm.lower(schedule, [x_temp, y_temp, result],
                          simple_mode=True))
    if need_build:
        with build_config:
            tvm.build(schedule, [x_temp, y_temp, result],
                      "cce",
                      name=kernel_name)
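
A minimal NumPy reference for the batch-matmul semantics assembled above may help; it is only a sketch of the math (the helper name batch_matmul_ref is invented here, and the real operator is lowered through tvm.compute / tvm.build for CCE rather than NumPy):

import numpy as np

def batch_matmul_ref(x, y, trans_a=False, trans_b=False):
    """C = op(x) @ op(y), where op transposes the last two axes when requested."""
    if trans_a:
        x = np.swapaxes(x, -1, -2)
    if trans_b:
        y = np.swapaxes(y, -1, -2)
    return np.matmul(x, y)

# e.g. x of shape (6, 7, 8) and y of shape (6, 9, 8) with trans_b=True give a
# (6, 7, 9) result, matching the _compute branch comments above.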
Example No. 29
def custom_log1p(shape,
                 dtype,
                 kernel_name="cce_tf_log1p",
                 need_build=False,
                 need_print=False):
    """
    calculate ln(1 + data), calculating data type is float16 or float32

    Parameters
    ----------
    shape : shape of data

    dtype : source data type; src_dtype is assumed to equal dst_dtype, and only
        float16 and float32 are supported

    kernel_name : cce kernel name, default value is "cce_tf_log1p"

    need_build : whether to build the CCEC kernel, default value is False

    need_print : whether to print the IR, default value is False

    Returns
    -------
    None

    """

    supported_dtypes = ["float16", "float32"]
    device_api = "DeviceLog"

    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape)
    util.check_shape_size(shape, SHAPE_SIZE_LIMIT)

    if dtype.lower() not in supported_dtypes:
        raise RuntimeError("tf_log1p_cce only supports %s, while dtype is %s" %
                           (",".join(supported_dtypes), dtype))

    inp_dtype = dtype.lower()
    shape = util.shape_refine(shape)
    data_input = tvm.placeholder(shape, name="data_input", dtype=inp_dtype)

    v_datatype = util.get_device_api_dtype(inp_dtype)
    v_ndim = len(shape)
    block_num = "block_num"
    block_idx = "block_idx"
    pad_c0 = 0
    p_scale = util.create_param_ptr([1], inp_dtype, "p_scale")
    p_shift = util.create_param_ptr([1], inp_dtype, "p_shift")
    p_base = util.create_param_ptr([-1], inp_dtype, "p_base")
    p_shape = util.create_param_ptr(shape, "int32", "p_shape")

    output = tvm.extern(
        shape,
        [data_input, p_scale, p_shift, p_base, p_shape],
        lambda ins, outs: tvm.call_extern(
            "int32_t",
            device_api,
            block_num,
            block_idx,
            v_datatype,
            ins[1].access_ptr("r"),  # scale
            ins[2].access_ptr("r"),  # shift
            ins[3].access_ptr("r"),  # base
            v_ndim,
            ins[4].access_ptr("r"),  # shape
            pad_c0,
            ins[0].access_ptr("r"),  # input x
            outs[0].access_ptr("w")),
        name="output",
        dtype=inp_dtype)

    schedule = tvm.create_schedule(output.op)

    if need_print:
        with build_config:
            print(tvm.lower(schedule, [data_input, output], simple_mode=True))
    if need_build:
        with build_config:
            tvm.build(schedule, [data_input, output], "cce", name=kernel_name)
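
As a quick sanity reference for the kernel above, the documented math, ln(1 + data), can be mirrored in NumPy; this is only an illustration of the expected numerics (the name log1p_ref is invented here), not the DeviceLog extern path itself:

import numpy as np

def log1p_ref(data):
    """NumPy mirror of the documented semantics: ln(1 + data)."""
    return np.log1p(data)

# e.g. log1p_ref(np.array([0.0, 1.0], dtype=np.float32))
# -> approximately [0.0, 0.6931472]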
Example No. 30
def CusMatMulCubeFraczLeftCast(input_x1,
                               input_x2,
                               bias=None,
                               output_y={},
                               trans_a=False,
                               trans_b=False,
                               kernel_name="CusMatMulCubeFraczLeftCast"):
    """
    calculating  matrix multiplication with bias, C = A*B + bias, support input
    data with fractal format.

    Parameters:
    shape_a: list or tuple
            Shape of the first tensor a with rank > 1
    shape_b:  list or tuple
            Shape of the second tensor b, with the same type as a;
            shape_a and shape_b must both be 2-D
    src_dtype: str
            The data type of input, support "float32", "float16"
    dst_dtype: str
            The data type of output, support "float32", "float16"
    trans_a: bool
            If True, shape_a is transposed before multiplication
    trans_b: bool
            If True, shape_b is transposed before multiplication
    is_fractal: bool
            If True, the input data format of a and b must be fractal format
    shape_bias: list or tuple
            Shape of bias, only support the input data format with ND

    Returns
    -------
    None
    """
    shape_a = input_x1.get("ori_shape")
    shape_b = input_x2.get("ori_shape")
    print("============")
    print(input_x1.get("format"), input_x2.get("format"))
    print(shape_a, shape_b)
    print("============")
    if input_x2.get("format") == "FRACTAL_Z":
        n, c, h, w = shape_b
        c0 = 16
        c1 = c // c0
        if c1 == 0:
            c1 = 1
        shape_b = [n, c1 * h * w * c0]
        shape_a = [n, n]

    if input_x1.get("format") == "FRACTAL_Z":
        n, c, h, w = shape_a
        c0 = 16
        c1 = c // c0
        if c1 == 0:
            c1 = 1
        shape_a = [n, c1 * h * w * c0]
        shape_b = [c1 * h * w * c0, c1 * h * w * c0]
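        # e.g. (illustrative) n=32, c=64, h=w=3 gives c1 = 64 // 16 = 4 and
        # shape_a = [32, 4 * 3 * 3 * 16] = [32, 576]; shape_b becomes the
        # matching square [576, 576].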

    if input_x2.get("format") == "FRACTAL_NZ":
        shape_a = [shape_b[0], shape_b[0]]

    if input_x1.get("format") == "FRACTAL_NZ":
        shape_b = [shape_a[1], shape_a[1]]

    shape_a = list(shape_a)
    shape_b = list(shape_b)

    shape_a = _get_input_shape(shape_a)
    shape_b = _get_input_shape(shape_b)

    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape_a)
    util.check_shape_rule(shape_b)
    util.check_shape_size(shape_a, SHAPE_SIZE_LIMIT)
    util.check_shape_size(shape_b, SHAPE_SIZE_LIMIT)

    shape_a = [shape_a[1], shape_a[0]]
    trans_a = bool(1 - trans_a)

    shape_b = [shape_b[1], shape_b[0]]
    trans_b = bool(1 - trans_b)
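    # The two swaps above treat the incoming (ori) shapes as stored with their
    # last two axes transposed in the fractal layout, so the trans flags are
    # inverted to compensate.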

    shape_bias = ()
    if bias is not None and bool(bias):
        shape_bias = bias.get("shape")
        shape_bias = list(shape_bias)
        shape_bias = _get_bias(shape_bias)

    src_dtype = input_x1.get("dtype").lower()
    _shape_check(shape_a, shape_b, shape_bias, src_dtype, trans_a, trans_b)

    m_shape = shape_a[len(shape_a) - 2]
    km_shape = shape_a[len(shape_a) - 1]
    kn_shape = shape_b[len(shape_a) - 2]
    n_shape = shape_b[len(shape_a) - 1]

    if src_dtype == "float16":
        block_reduce = cce.BLOCK_REDUCE
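    # block_reduce is only assigned for float16 inputs; any other dtype would
    # leave it undefined further down.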

    block_in = cce.BLOCK_IN
    block_out = cce.BLOCK_OUT

    if trans_a and km_shape == 1:
        block_in = cce.BLOCK_VECTOR

    if not trans_a and m_shape == 1:
        block_in = cce.BLOCK_VECTOR

    if trans_b and kn_shape == 1:
        block_out = cce.BLOCK_VECTOR

    if not trans_b and n_shape == 1:
        block_out = cce.BLOCK_VECTOR

    if trans_a:
        shape_a_temp = (m_shape // block_reduce, km_shape // block_in,
                        block_reduce, block_in)
    else:
        shape_a_temp = (m_shape // block_in, km_shape // block_reduce,
                        block_in, block_reduce)

    if trans_b:
        shape_b_temp = (kn_shape // block_out, n_shape // block_reduce,
                        block_reduce, block_out)
    else:
        shape_b_temp = (kn_shape // block_reduce, n_shape // block_out,
                        block_out, block_reduce)
    shape_a_temp = (shape_a_temp[0], shape_a_temp[1], shape_a_temp[2],
                    shape_a_temp[3])
    shape_b_temp = (shape_b_temp[0], shape_b_temp[1], shape_b_temp[2],
                    shape_b_temp[3])
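    # In the fractal layout the two leading axes count the block-sized tiles
    # (typically 16 x 16 for float16) and the two trailing axes index elements
    # inside a single tile.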

    if util.get_product_version() == util.VERSION_MINI:
        tik_instance = tik.Tik(tik.Dprofile("v100", "mini"))
    else:
        tik_instance = tik.Tik(tik.Dprofile("v100", "cloud"))
    input_x1 = tik_instance.Tensor(input_x1.get("dtype"),
                                   shape_a_temp,
                                   name="left_matrix",
                                   scope=tik.scope_gm)
    input_x2 = tik_instance.Tensor(input_x2.get("dtype"),
                                   shape_b_temp,
                                   name="right_matrix",
                                   scope=tik.scope_gm)
    res_matmul = tik_instance.Tensor(output_y.get("dtype"),
                                     output_y.get("shape"),
                                     name="output",
                                     scope=tik.scope_gm)
    DIAG_SIZE = 128
    mo_tile, ko_tile, no_tile, diag_opt = get_cus_tile_info(
        input_x1, input_x2, DIAG_SIZE)
    cus_cube_matmul_cast(tik_instance,
                         input_x1,
                         trans_a,
                         input_x2,
                         trans_b,
                         res_matmul,
                         mo_tile=mo_tile,
                         ko_tile=ko_tile,
                         no_tile=no_tile,
                         diag_opt=diag_opt,
                         diag_size=DIAG_SIZE)
    tik_instance.BuildCCE(kernel_name=kernel_name,
                          inputs=[input_x1, input_x2],
                          outputs=[res_matmul])
    return tik_instance
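
The 4-D fractal shape derivation shared by these cube matmul kernels can also be reproduced on its own. The sketch below is illustrative only: the helper name fractal_shapes is invented here, float16 block sizes of 16 (cce.BLOCK_IN / BLOCK_OUT / BLOCK_REDUCE) are assumed, and the unit-dimension special case is modeled with BLOCK_VECTOR = 1.

def fractal_shapes(m_shape, km_shape, kn_shape, n_shape,
                   trans_a=False, trans_b=False, block=16):
    """Mirror of the shape_a_temp / shape_b_temp computation above."""
    block_in = block_out = block_reduce = block
    if (trans_a and km_shape == 1) or (not trans_a and m_shape == 1):
        block_in = 1   # assumed cce.BLOCK_VECTOR
    if (trans_b and kn_shape == 1) or (not trans_b and n_shape == 1):
        block_out = 1  # assumed cce.BLOCK_VECTOR
    if trans_a:
        shape_a_temp = (m_shape // block_reduce, km_shape // block_in,
                        block_reduce, block_in)
    else:
        shape_a_temp = (m_shape // block_in, km_shape // block_reduce,
                        block_in, block_reduce)
    if trans_b:
        shape_b_temp = (kn_shape // block_out, n_shape // block_reduce,
                        block_reduce, block_out)
    else:
        shape_b_temp = (kn_shape // block_reduce, n_shape // block_out,
                        block_out, block_reduce)
    return shape_a_temp, shape_b_temp

# e.g. fractal_shapes(64, 128, 128, 256) -> ((4, 8, 16, 16), (8, 16, 16, 16))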