Esempio n. 1
0
def op_select_format(input_values,
                     output_data,
                     axis,
                     kernel_name="concat_v2_d"):
    """
    select format dynamically

    Registers a single candidate table for concat_v2_d: every listed dtype
    is offered in ND format, for both the input and the output.

    Parameters
    ----------
    input_values : not inspected by this function
    output_data : not inspected by this function
    axis : not inspected by this function
    kernel_name : str
        kernel name, default value is "concat_v2_d"

    Returns
    -------
    whatever get_dynamic_param_in_json returns for [input0, output0]
    """
    # Both comma-separated lists have 11 entries, one format per dtype.
    datatype_4d = "float16,float,int32,int8,int16,int64,uint8,uint16," \
                  "uint32,uint64,bool"
    format_4d = "ND,ND,ND,ND,ND,ND,ND,ND,ND,ND,ND"

    # ND
    input0 = gen_param(classify="input0",
                       name="input_values",
                       datatype=datatype_4d,
                       format=format_4d)
    output0 = gen_param(classify="output0",
                        name="output_data",
                        datatype=datatype_4d,
                        format=format_4d)

    param_list = [input0, output0]
    param_dynamic_in_json = get_dynamic_param_in_json(param_list)
    return param_dynamic_in_json
Esempio n. 2
0
def op_select_format(x, filter, y, block_size, data_format, kernel_name="space_to_depth"):
    """
    select format dynamically

    The dtype list for x/y is identical in both cases. Only the format
    paired with the first dtype (float16) varies: it is NC1HWC0 when
    x's ori_shape[3] <= 64 and NHWC otherwise. y mirrors x exactly, and
    filter is always float16 / FRACTAL_Z.
    """
    datatype = "float16, float, int8, uint8, int16, uint16, int32," \
               "uint32, uint64, int64"

    last_dim_small = x.get("ori_shape")[3] <= 64
    if last_dim_small:
        input_format = "NC1HWC0, NHWC, NHWC, NHWC, NHWC, NHWC, NHWC, NHWC, NHWC, NHWC"
    else:
        input_format = "NHWC, NHWC, NHWC, NHWC, NHWC, NHWC, NHWC, NHWC, NHWC, NHWC"

    filter_dtype = "float16, float16, float16, float16, float16, float16," \
                   "float16, float16, float16, float16"
    filter_format = "FRACTAL_Z,FRACTAL_Z,FRACTAL_Z,FRACTAL_Z,FRACTAL_Z,FRACTAL_Z," \
                    "FRACTAL_Z,FRACTAL_Z,FRACTAL_Z,FRACTAL_Z"

    candidates = [
        gen_param(classify="input0", name="x",
                  datatype=datatype, format=input_format),
        gen_param(classify="input1", name="filter",
                  datatype=filter_dtype, format=filter_format),
        gen_param(classify="output0", name="y",
                  datatype=datatype, format=input_format),
    ]
    return get_dynamic_param_in_json(candidates)
Esempio n. 3
0
def op_select_format(x,
                     sum,
                     square_sum,
                     num_groups,
                     kernel_name="gn_training_reduce"):
    """
    select format dynamically

    x is offered as float16/float in NCHW or NHWC. Both reduction outputs
    (sum and square_sum) are always float in ND format, with four entries
    to line up with x's four candidates.
    """
    reduce_dtype = "float,float,float,float"
    reduce_format = "ND,ND,ND,ND"

    x_param = gen_param(classify="input0",
                        name="x",
                        datatype="float16,float,float16,float",
                        format="NCHW,NHWC,NCHW,NHWC")
    sum_param = gen_param(classify="output0",
                          name="sum",
                          datatype=reduce_dtype,
                          format=reduce_format)
    square_sum_param = gen_param(classify="output1",
                                 name="square_sum",
                                 datatype=reduce_dtype,
                                 format=reduce_format)

    return get_dynamic_param_in_json([x_param, sum_param, square_sum_param])
Esempio n. 4
0
def op_select_format(grad,
                     x1,
                     x2,
                     y,
                     axis,
                     keepdims,
                     kernel_name="softmax_grad_ext"):
    """
    select format dynamically

    FRACTAL_NZ is offered for grad, x1 and y only when three conditions hold:
    x2's original shape is the 1-element shape [1], and both grad's and x1's
    original shapes pass _division_sixteen. Otherwise every tensor is ND.
    x2 is ND in both cases, and all tensors use float16/float.
    """
    grad_shape = util.scalar2tensor_one(grad.get("ori_shape"))
    x1_shape = util.scalar2tensor_one(x1.get("ori_shape"))
    x2_shape = util.scalar2tensor_one(x2.get("ori_shape"))

    # Evaluated eagerly (no short-circuit), matching the original call pattern.
    x2_is_scalar = len(x2_shape) == 1 and x2_shape[0] == 1
    grad_aligned = _division_sixteen(grad_shape)
    x1_aligned = _division_sixteen(x1_shape)

    if x2_is_scalar and grad_aligned and x1_aligned:
        # NZ + NZ + Scalar
        operand_format = "FRACTAL_NZ, FRACTAL_NZ"
        result_format = "FRACTAL_NZ,FRACTAL_NZ"
    else:
        # ND+ND+ND
        operand_format = "ND,ND"
        result_format = "ND,ND"

    dtype = "float16,float"
    candidates = [
        gen_param(classify="input0", name="grad",
                  datatype=dtype, format=operand_format),
        gen_param(classify="input1", name="x1",
                  datatype=dtype, format=operand_format),
        gen_param(classify="input2", name="x2",
                  datatype=dtype, format="ND,ND"),
        gen_param(classify="output0", name="y",
                  datatype=dtype, format=result_format),
    ]
    return get_dynamic_param_in_json(candidates)
Esempio n. 5
0
def op_select_format(input_tensor,
                     input_mask,
                     input_keep_prob,
                     output,
                     kernel_name="dropout_do_mask"):
    """
    select format dynamically

    Each original shape is passed through util.scalar2tensor_one before it
    is tested. When x's shape passes _division_sixteen but neither mask's
    nor keep_prob's shape does, x and y are offered in both ND and
    FRACTAL_NZ (float16 and float). Otherwise everything is ND only.
    mask is always uint8 / ND. keep_prob always shares x's dtypes in ND.
    """
    raw_shapes = [input_tensor.get("ori_shape"),
                  input_mask.get("ori_shape"),
                  input_keep_prob.get("ori_shape")]
    x_shape, mask_shape, prob_shape = [util.scalar2tensor_one(s) for s in raw_shapes]

    if _division_sixteen(x_shape) and not _division_sixteen(
            mask_shape) and not _division_sixteen(prob_shape):
        # Nz+ND+ND
        data_dtype = "float16,float16,float,float"
        data_format = "ND,FRACTAL_NZ,ND,FRACTAL_NZ"
        mask_dtype = "uint8,uint8,uint8,uint8"
        nd_format = "ND,ND,ND,ND"
    else:
        # ND+ND
        data_dtype = "float16,float"
        data_format = "ND,ND"
        mask_dtype = "uint8,uint8"
        nd_format = "ND,ND"

    candidates = [
        gen_param(classify="input0", name="x",
                  datatype=data_dtype, format=data_format),
        gen_param(classify="input1", name="mask",
                  datatype=mask_dtype, format=nd_format),
        gen_param(classify="input2", name="keep_prob",
                  datatype=data_dtype, format=nd_format),
        gen_param(classify="output0", name="y",
                  datatype=data_dtype, format=data_format),
    ]
    return get_dynamic_param_in_json(candidates)
Esempio n. 6
0
def op_select_format(in_dic,
                     filter_dic,
                     out_dic,
                     stride,
                     reverse,
                     kernel_name="pass_through"):
    """
    select format dynamically

    Candidate table for pass_through:
    - x / y: float16, float, int8, uint8, int16, uint16, int32, uint32,
      int64, uint64 in NHWC. On Hi3796CV300ES / Hi3796CV300CS "float" is
      not offered.
    - When filter_dic has a non-empty shape, the first (float16) entry of
      x / y is NC1HWC0 instead of NHWC.
    - filter: float16 / FRACTAL_Z for every entry.

    Previously the FRACTAL_Z lists were split with a backslash *inside* the
    string literal. That embedded the next source line's indentation (a long
    run of spaces) into the format string. They are now built with join.
    """
    product_version = tbe_platform.cce_conf.get_soc_spec("SOC_VERSION")

    dtypes = ["float16", "float", "int8", "uint8", "int16", "uint16",
              "int32", "uint32", "int64", "uint64"]
    if product_version in ("Hi3796CV300ES", "Hi3796CV300CS"):
        # these SoCs use the same list without "float"
        dtypes.remove("float")

    x_formats = ["NHWC"] * len(dtypes)
    if len(filter_dic['shape']) != 0:
        # with a filter, the float16 entry is offered as 5HD
        x_formats[0] = "NC1HWC0"

    dtype0 = ", ".join(dtypes)
    dtype1 = ", ".join(["float16"] * len(dtypes))
    dformat0 = ", ".join(x_formats)
    dformat1 = ", ".join(["FRACTAL_Z"] * len(dtypes))

    input0 = gen_param(classify="input0",
                       name="x",
                       datatype=dtype0,
                       format=dformat0)
    input1 = gen_param(classify="input1",
                       name="filter",
                       datatype=dtype1,
                       format=dformat1)
    output0 = gen_param(classify="output0",
                        name="y",
                        datatype=dtype0,
                        format=dformat0)

    param_list = [input0, input1, output0]
    param_dynamic_in_json = get_dynamic_param_in_json(param_list)

    return param_dynamic_in_json
Esempio n. 7
0
def op_select_format(x,
                     y,
                     perm,
                     shape,
                     transpose_first,
                     kernel_name="confusion_transpose_d"):
    """
    select format dynamically

    _condition(x, perm, shape, transpose_first) decides whether FRACTAL_NZ
    is offered. If it is, every dtype appears twice: first paired with
    FRACTAL_NZ, then with ND. Otherwise only the ND entries are offered.
    x and y always share the same table.
    """
    dtypes = ["float16", "float", "int8", "int16", "int32", "int64",
              "uint8", "uint16", "uint32", "uint64"]
    count = len(dtypes)

    if _condition(x, perm, shape, transpose_first):
        # NZ+ND
        dtype_str = ",".join(dtypes + dtypes)
        format_str = ",".join(["FRACTAL_NZ"] * count + ["ND"] * count)
    else:
        # ND+ND
        dtype_str = ",".join(dtypes)
        format_str = ",".join(["ND"] * count)

    x_param = gen_param(classify="input0",
                        name="x",
                        datatype=dtype_str,
                        format=format_str)
    y_param = gen_param(classify="output0",
                        name="y",
                        datatype=dtype_str,
                        format=format_str)
    return get_dynamic_param_in_json([x_param, y_param])
Esempio n. 8
0
def op_select_format(x, size, y, axis=2, offsets=(0), kernel_name="crop"):
    """
    select format dynamically

    x, size and y share one candidate table. Every dtype is offered in ND.
    When x is a 4-D NCHW tensor cropped along axis >= 2 (after normalising
    a negative axis against the original rank), the same dtypes are
    additionally offered in NC1HWC0. On Hi3796CV300ES / Hi3796CV300CS the
    dtype list omits "float".

    NOTE(review): offsets=(0) is the int 0, not a tuple; it is unused here.
    """
    ori_format = x.get("ori_format").upper()
    ori_shape = x.get("ori_shape")

    soc_version = tbe_platform.cce_conf.get_soc_spec("SOC_VERSION")
    if soc_version in ("Hi3796CV300ES", "Hi3796CV300CS"):
        dtypes = ["float16", "int32", "int8", "int16", "int64", "uint8",
                  "uint16", "uint32", "uint64"]
    else:
        dtypes = ["float16", "float", "int32", "int8", "int16", "int64",
                  "uint8", "uint16", "uint32", "uint64"]

    if axis < 0:
        axis += len(ori_shape)

    formats = ["ND"] * len(dtypes)
    cropping_spatial = (ori_format == "NCHW" and len(ori_shape) == 4
                        and axis >= 2)
    if cropping_spatial:
        formats = formats + ["NC1HWC0"] * len(dtypes)
        dtypes = dtypes * 2

    joined_dtypes = ','.join(dtypes)
    joined_formats = ','.join(formats)

    candidates = [
        gen_param(classify=slot, name=tensor_name,
                  datatype=joined_dtypes, format=joined_formats)
        for slot, tensor_name in (("input0", "x"),
                                  ("input1", "size"),
                                  ("output0", "y"))
    ]
    return get_dynamic_param_in_json(candidates)
Esempio n. 9
0
def op_select_format(x, y, num, axis, kernel_name="unpack"):
    """
    unpacks the given dimension of a rank R tensor into rank (R-1) tensors.

    Format selection:
    1. ND is always offered for every dtype.
    2. NC1HWC0 is additionally offered when x is a 4-D NCHW/NHWC tensor and
       the unpacked axis is not C. When unpacking along C, the per-output
       size may not be C0-aligned, so NC1HWC0 is not offered.
    """
    ori_format = x.get("ori_format").upper()
    ori_shape = x.get("ori_shape")
    # normalise the axis (including negative values) into [0, rank)
    axis = axis % len(ori_shape)

    # ori_format is a 4-letter layout string, so indexing it by the axis
    # yields the dimension letter being unpacked.
    supports_5hd = (ori_format in ("NCHW", "NHWC")
                    and len(ori_shape) == 4
                    and ori_format[axis] != "C")

    dtypes = ["float16", "float", "int32", "int8", "int16", "int64",
              "uint8", "uint16", "uint32", "uint64"]
    formats = ["ND"] * len(dtypes)
    if supports_5hd:
        formats = formats + ["NC1HWC0"] * len(formats)
        dtypes = dtypes + dtypes

    dtype_str = ','.join(dtypes)
    format_str = ','.join(formats)

    x_param = gen_param(classify="input0",
                        name="x",
                        datatype=dtype_str,
                        format=format_str)
    y_param = gen_param(classify="output0",
                        name="y",
                        datatype=dtype_str,
                        format=format_str)
    return get_dynamic_param_in_json([x_param, y_param])
Esempio n. 10
0
def op_select_format(x, sum, square_sum, gamma, beta, mean, variance,
                     y, batch_mean, batch_variance,
                     momentum, epsilon,
                     kernel_name="in_training_update_v2"):
    """
    select format dynamically

    Every tensor is NC1HWC0. x and y accept float16 or float; all the
    statistic tensors (sum, square_sum, gamma, beta, mean, variance,
    batch_mean, batch_variance) are float in both candidates.
    """
    data_dtype = "float16,float"
    stat_dtype = "float,float"
    fmt = "NC1HWC0,NC1HWC0"

    # (classify slot, registered name, dtype list) in registration order
    layout = (
        ("input0", "x", data_dtype),
        ("input1", "sum", stat_dtype),
        ("input2", "square_sum", stat_dtype),
        ("input3", "gamma", stat_dtype),
        ("input4", "beta", stat_dtype),
        ("input5", "mean", stat_dtype),
        ("input6", "variance", stat_dtype),
        ("output0", "y", data_dtype),
        ("output1", "batch_mean", stat_dtype),
        ("output2", "batch_variance", stat_dtype),
    )
    candidates = [gen_param(classify=slot, name=tensor_name,
                            datatype=slot_dtype, format=fmt)
                  for slot, tensor_name, slot_dtype in layout]
    return get_dynamic_param_in_json(candidates)
Esempio n. 11
0
def op_select_format(x, y, split_dim, num_split, kernel_name="split_d"):
    """
    select format dynamically

    split_d registers one ND-only table: each listed dtype is paired with
    ND for both x and y.
    """
    supported = "float16, float, int32, int8, int16, int64, uint8, uint16, uint32, uint64"
    # one "ND" per dtype, using the same ", " separator as the dtype list
    nd_formats = ", ".join(["ND"] * len(supported.split(",")))

    x_param = gen_param(classify="input0",
                        name="x",
                        datatype=supported,
                        format=nd_formats)
    y_param = gen_param(classify="output0",
                        name="y",
                        datatype=supported,
                        format=nd_formats)
    return get_dynamic_param_in_json([x_param, y_param])
Esempio n. 12
0
def op_select_format(x, sum, square_sum, kernel_name="bn_training_reduce"):
    """
    select format dynamically

    For a 4-D NCHW x whose original shape has N == 1 and H == 1, both NCHW
    and NC1HWC0 are offered. Every other input is NC1HWC0 only.
    x accepts float16/float; sum and square_sum are always float.
    """
    origin_format = x.get("ori_format").upper()
    origin_shape = x.get("ori_shape")

    keep_nchw = (origin_format == "NCHW" and len(origin_shape) == 4
                 and origin_shape[0] == 1 and origin_shape[2] == 1)
    if keep_nchw:
        # NCHW + 5HD candidates
        x_dtype = "float16,float,float16,float"
        stat_dtype = "float,float,float,float"
        fmt = "NCHW,NCHW,NC1HWC0,NC1HWC0"
    else:
        # 5HD only
        x_dtype = "float16,float"
        stat_dtype = "float,float"
        fmt = "NC1HWC0,NC1HWC0"

    candidates = [
        gen_param(classify="input0", name="x",
                  datatype=x_dtype, format=fmt),
        gen_param(classify="output0", name="sum",
                  datatype=stat_dtype, format=fmt),
        gen_param(classify="output1", name="square_sum",
                  datatype=stat_dtype, format=fmt),
    ]
    return get_dynamic_param_in_json(candidates)
Esempio n. 13
0
def op_select_format(input_x, input_y, bias=None, output_z=None, trans_a=False,
                     trans_b=False, kernel_name="matmul"):
    """
    provide dynamic format to FE

    When input_x's dtype is "float16", a single candidate is offered:
    float16 with x1/x2/y in FRACTAL_NZ and bias in ND. Otherwise five
    candidates are offered: float16 FRACTAL_NZ (bias ND), float NHWC,
    float ND, int32 NHWC and int32 ND.

    output_z defaults to None rather than a shared mutable ``{}``. The
    value is not read here, so callers are unaffected.
    """
    src_dtype = input_x.get("dtype")

    if src_dtype == "float16":
        dtype = "float16"
        operand_format = "FRACTAL_NZ"
        bias_format = "ND"
    else:
        dtype = "float16,float,float,int32,int32"
        operand_format = "FRACTAL_NZ,NHWC,ND,NHWC,ND"
        bias_format = "ND,NHWC,ND,NHWC,ND"

    input0 = gen_param(classify="input0", name="x1",
                       datatype=dtype,
                       format=operand_format)
    input1 = gen_param(classify="input1", name="x2",
                       datatype=dtype,
                       format=operand_format)
    input2 = gen_param(classify="input2", name="bias",
                       datatype=dtype,
                       format=bias_format)
    output0 = gen_param(classify="output0", name="y",
                        datatype=dtype,
                        format=operand_format)

    param_list = [input0, input1, input2, output0]
    param_dynamic_in_json = get_dynamic_param_in_json(param_list)

    return param_dynamic_in_json
Esempio n. 14
0
def op_select_format(x,
                     y,
                     ksize,
                     strides,
                     padding="SAME",
                     pads=(0, 0, 0, 0, 0, 0),
                     dilation=(1, 1, 1),
                     ceil_mode=0,
                     data_format="NDHWC",
                     kernel_name="max_pool3d"):
    """
    Select the format for max_pool3d.

    max_pool3d has not been performance-optimized yet. This function exists
    to support the covid_19 scenario: when _is_covid_19 matches x's
    original shape, ksize, strides and data_format, x and y are offered as
    float16 NDHWC. Otherwise they are float16 NDC1HWC0.
    Delete this function once the performance optimization is done.
    """
    covid_case = _is_covid_19(x.get("ori_shape"), ksize, strides, data_format)
    layout = "NDHWC" if covid_case else "NDC1HWC0"

    input0_r = gen_param(classify="input0",
                         name="x",
                         datatype="float16",
                         format=layout)
    output0_r = gen_param(classify="output0",
                          name="y",
                          datatype="float16",
                          format=layout)
    return get_dynamic_param_in_json([input0_r, output0_r])
Esempio n. 15
0
def op_select_format(input_x,
                     output_y,
                     tiles,
                     axis=1,
                     kernel_name="tile_with_axis"):
    """
    select format dynamically

    ND is always offered for every dtype. NC1HWC0 is additionally offered
    when the input has a 4-D original shape and the tiled axis is not the C
    dimension (axis != 3 for NHWC, axis != 1 for NCHW). On Hi3796CV300ES /
    Hi3796CV300CS the dtype list omits float32.

    The original code guarded ``axis_check`` against ``ori_shape is None``
    but then called ``len(ori_shape)`` unguarded in the 5HD condition. That
    raised TypeError for a missing shape. Such inputs now fall back to ND.
    """
    ori_format = input_x.get("ori_format")
    ori_shape = input_x.get("ori_shape")

    if ori_shape is not None:
        axis = util.axis_check(len(ori_shape), axis)

    cce_product = tbe_platform.cce_conf.get_soc_spec("SOC_VERSION")

    if cce_product in ("Hi3796CV300ES", "Hi3796CV300CS"):
        # fp16
        dtypes = ["float16", "int8", "int16", "int32", "int64",
                  "uint8", "uint16", "uint32", "uint64"]
    else:
        # fp16/fp32
        dtypes = ["float16", "float32", "int8", "int16", "int32", "int64",
                  "uint8", "uint16", "uint32", "uint64"]
    formats = ["ND"] * len(dtypes)

    # for 5hd, axis is only valid for n,h,w
    axis_not_c = ((ori_format == "NHWC" and axis != 3) or
                  (ori_format == "NCHW" and axis != 1))
    is_4d = ori_shape is not None and len(ori_shape) == 4
    if axis_not_c and is_4d:
        # NC1HWC0+ND: ND entries first, then the same dtypes as NC1HWC0
        formats = formats + ["NC1HWC0"] * len(dtypes)
        dtypes = dtypes + dtypes

    dtype_str = ",".join(dtypes)
    format_str = ",".join(formats)

    input0 = gen_param(classify="input0",
                       name="x",
                       datatype=dtype_str,
                       format=format_str)
    output0 = gen_param(classify="output0",
                        name="y",
                        datatype=dtype_str,
                        format=format_str)

    param_list = [input0, output0]
    param_dynamic_in_json = get_dynamic_param_in_json(param_list)
    return param_dynamic_in_json
Esempio n. 16
0
def op_select_format(
        input_x1,
        input_x2,  # pylint: disable=too-many-arguments
        alpha,
        beta,
        bias=None,
        output_y=None,
        trans_a=False,
        trans_b=False,
        kernel_name="gemm"):
    """
    select format dynamically

    Chooses between a fractal candidate set and a plain ND candidate set for
    a (input_x1), b (input_x2), c (bias) and y. The fractal set uses
    FRACTAL_NZ / FRACTAL_Z. It is selected when either:
    - any of a, b or c already has format FRACTAL_NZ or FRACTAL_Z, or
    - b's N dimension (ori_shape[0] if trans_b else ori_shape[1]) is not
      a multiple of cce.BLOCK_OUT.
    alpha and beta are always ND.

    bias defaults to None. The original code called bias.get(...)
    unconditionally and raised AttributeError in that case. A missing bias
    now simply contributes no format to the decision.
    """
    shape_b = input_x2.get("ori_shape")
    format_a = input_x1.get("format")
    format_b = input_x2.get("format")
    format_c = bias.get("format") if bias is not None else None

    if {format_a, format_b, format_c} & {"FRACTAL_NZ", "FRACTAL_Z"}:
        need_transdata = True
    else:
        b_n = shape_b[0] if trans_b else shape_b[1]
        need_transdata = b_n % cce.BLOCK_OUT != 0

    ab_dtype = "float16,float16,int8,int8"
    out_dtype = "float32,float16,int32,float32"
    if need_transdata:
        format_a_out = "FRACTAL_NZ,FRACTAL_NZ,FRACTAL_NZ,FRACTAL_NZ"
        format_b_out = "FRACTAL_NZ,FRACTAL_NZ,FRACTAL_Z,FRACTAL_Z"
        format_c_out = "FRACTAL_NZ,FRACTAL_NZ,ND,FRACTAL_NZ"
        format_y_out = "FRACTAL_NZ,FRACTAL_NZ,FRACTAL_NZ,FRACTAL_NZ"
    else:
        format_a_out = "ND,ND,ND,ND"
        format_b_out = "ND,ND,ND,ND"
        format_c_out = "ND,ND,ND,ND"
        format_y_out = "ND,ND,ND,ND"

    input0 = gen_param(classify="input0", name="a",
                       datatype=ab_dtype, format=format_a_out)
    input1 = gen_param(classify="input1", name="b",
                       datatype=ab_dtype, format=format_b_out)
    input2 = gen_param(classify="input2", name="c",
                       datatype=out_dtype, format=format_c_out)
    output0 = gen_param(classify="output0", name="y",
                        datatype=out_dtype, format=format_y_out)
    input3 = gen_param(classify="input3", name="alpha",
                       datatype=out_dtype, format="ND,ND,ND,ND")
    input4 = gen_param(classify="input4", name="beta",
                       datatype=out_dtype, format="ND,ND,ND,ND")

    param_list = [input0, input1, input2, input3, input4, output0]
    return get_dynamic_param_in_json(param_list)
Esempio n. 17
0
def op_select_format(x,
                     sum,
                     square_sum,
                     scale,
                     offset,
                     y,
                     batch_mean,
                     batch_variance,
                     epsilon,
                     kernel_name="bn_training_update_v2"):
    """
    select format dynamically

    All eight tensors share a single format list.

    When x is a 4-D NCHW tensor with ori_shape[0] == 1 and
    ori_shape[2] == 1, plain NCHW is offered alongside NC1HWC0. Otherwise
    only NC1HWC0 is offered.

    x and y take float16 or float. sum, square_sum, scale, offset,
    batch_mean and batch_variance are always float. Only ``x`` is read
    here; the remaining arguments are accepted for interface compatibility.

    Returns
    -------
    The result of ``get_dynamic_param_in_json`` for
    [x, sum, square_sum, scale, offset, y, batch_mean, batch_variance].
    """
    origin_format = x.get("ori_format").upper()
    origin_shape = x.get("ori_shape")

    nchw_also_supported = (origin_format == "NCHW"
                           and len(origin_shape) == 4
                           and origin_shape[0] == 1
                           and origin_shape[2] == 1)

    if nchw_also_supported:
        # NCHW + NC1HWC0
        shared_format = "NCHW,NCHW,NC1HWC0,NC1HWC0"
        data_dtype = "float16,float,float16,float"
        stat_dtype = "float,float,float,float"
    else:
        # NC1HWC0 only
        shared_format = "NC1HWC0,NC1HWC0"
        data_dtype = "float16,float"
        stat_dtype = "float,float"

    # (classify, name, datatype) in the order the framework expects.
    tensor_table = (
        ("input0", "x", data_dtype),
        ("input1", "sum", stat_dtype),
        ("input2", "square_sum", stat_dtype),
        ("input3", "scale", stat_dtype),
        ("input4", "offset", stat_dtype),
        ("output0", "y", data_dtype),
        ("output1", "batch_mean", stat_dtype),
        ("output2", "batch_variance", stat_dtype),
    )
    param_list = [
        gen_param(classify=slot, name=tensor_name,
                  datatype=dtype_str, format=shared_format)
        for slot, tensor_name, dtype_str in tensor_table
    ]
    return get_dynamic_param_in_json(param_list)
Esempio n. 18
0
def op_select_format(x, bias, y, data_format="NHWC",
                     kernel_name="bias_add"):
    """
    select format dynamically

    The offered combinations depend on the rank of x's ori_shape and on
    whether bias.shape[0] is a multiple of c0 (16):

    * rank 4, bias aligned:   NC1HWC0 x with ND bias, plus NCHW/NHWC pairs
    * rank 4, bias unaligned: NC1HWC0 x with NC1HWC0 bias, plus NCHW/NHWC
    * rank < 4:               NCHW/NHWC pairs only
    * rank > 4, aligned:      NDHWC/NCDHW/NDC1HWC0 x with ND bias
    * rank > 4, unaligned:    NDHWC/NCDHW x with ND bias

    x, bias and y share one dtype list, and y always mirrors x's format
    list. ``y``, ``data_format`` and ``kernel_name`` are not read here.

    Returns
    -------
    The result of ``get_dynamic_param_in_json`` for [x, bias, y].
    """
    shape_bias = bias.get("shape")
    ori_shape_x = x.get("ori_shape")
    c0 = 16

    rank_x = len(ori_shape_x)
    bias_aligned = shape_bias[0] % c0 == 0

    if rank_x == 4:
        dtypes = ("float16,float,int32,float16,float,"
                  "int32,float16,float")
        x_formats = ("NC1HWC0,NC1HWC0,NCHW,NCHW,NCHW,NHWC,"
                     "NHWC,NHWC")
        if bias_aligned:
            # NC1HWC0+ND NCHW+NCHW NHWC+NHWC
            bias_formats = "ND,ND,NCHW,NCHW,NCHW,NHWC,NHWC,NHWC"
        else:
            # NC1HWC0+NC1HWC0 NCHW+NCHW NHWC+NHWC
            bias_formats = x_formats
    elif rank_x < 4:
        # NCHW+NCHW NHWC+NHWC
        dtypes = "int32,float16,float,int32,float16,float"
        x_formats = "NCHW,NCHW,NCHW,NHWC,NHWC,NHWC"
        bias_formats = x_formats
    elif bias_aligned:
        # NDHWC+ND NCDHW+ND NDC1HWC0+ND
        dtypes = ("int32,float16,float,int32,float16,"
                  "float,int32,float16,float")
        x_formats = ("NDHWC,NDHWC,NDHWC,NCDHW,NCDHW,"
                     "NCDHW,NDC1HWC0,NDC1HWC0,NDC1HWC0")
        bias_formats = "ND,ND,ND,ND,ND,ND,ND,ND,ND"
    else:
        # NDHWC+ND NCDHW+ND
        dtypes = "int32,float16,float,int32,float16,float"
        x_formats = "NDHWC,NDHWC,NDHWC,NCDHW,NCDHW,NCDHW"
        bias_formats = "ND,ND,ND,ND,ND,ND"

    param_x = gen_param(classify="input0", name="x",
                        datatype=dtypes, format=x_formats)
    param_bias = gen_param(classify="input1", name="bias",
                           datatype=dtypes, format=bias_formats)
    param_y = gen_param(classify="output0", name="y",
                        datatype=dtypes, format=x_formats)
    return get_dynamic_param_in_json([param_x, param_bias, param_y])
Esempio n. 19
0
def op_select_format(input0,
                     input1,
                     input2,
                     output,
                     kernel_name="fused_mul_add"):
    """
    select format dynamically

    Every variant supports NCHW, NC1HWC0, NHWC and ND for float16, float
    and int32. A fifth, FRACTAL_NZ-based entry per dtype is added when the
    pattern of inputs whose last two dimensions are divisible by 16 (as
    judged by ``_division_sixteen``) is one of:

    * x1 only    -> Nz + ND + ND
    * x1 and x3  -> Nz + ND + Nz
    * x2 only    -> ND + Nz + ND
    * x3 only    -> ND + ND + Nz

    In that entry each qualifying input uses FRACTAL_NZ, the others use
    ND, and y uses FRACTAL_NZ. Any other pattern gets only the four base
    formats.

    Parameters
    ----------
    input0, input1, input2 : dict
        Descriptions of x1, x2, x3. Only "ori_shape" is read; scalars are
        converted to tensors through ``util.scalar2tensor_one``.
    output : dict
        Description of y; not read here.
    kernel_name : str
        Not read here.

    Returns
    -------
    The result of ``get_dynamic_param_in_json`` for [x1, x2, x3, y].
    """
    shape_0 = input0.get("ori_shape")
    shape_1 = input1.get("ori_shape")
    shape_2 = input2.get("ori_shape")

    shape_0 = util.scalar2tensor_one(shape_0)
    shape_1 = util.scalar2tensor_one(shape_1)
    shape_2 = util.scalar2tensor_one(shape_2)

    # Evaluate each divisibility check once rather than in every branch.
    x1_nz = bool(_division_sixteen(shape_0))
    x2_nz = bool(_division_sixteen(shape_1))
    # x3 only matters when x1 and x2 do not both qualify. That pair never
    # gets an NZ variant, so skip the x3 check in that case.
    x3_nz = not (x1_nz and x2_nz) and bool(_division_sixteen(shape_2))

    nz_patterns = {
        (True, False, False),   # Nz + ND + ND
        (True, False, True),    # Nz + ND + Nz
        (False, True, False),   # ND + Nz + ND
        (False, False, True),   # ND + ND + Nz
    }
    use_nz = (x1_nz, x2_nz, x3_nz) in nz_patterns

    dtypes = ("float16", "float", "int32")
    base_formats = ["NCHW", "NC1HWC0", "NHWC", "ND"]
    if use_nz:
        # Extra per-dtype format for x1, x2, x3 and then y.
        extra = [["FRACTAL_NZ" if flag else "ND"]
                 for flag in (x1_nz, x2_nz, x3_nz)]
        extra.append(["FRACTAL_NZ"])
    else:
        extra = [[], [], [], []]

    # Build the comma-separated lists with join so that no indentation
    # whitespace ends up inside the dtype/format strings.
    per_dtype = len(base_formats) + len(extra[0])
    datatype = ",".join(dt for dt in dtypes for _ in range(per_dtype))

    tensors = (("input0", "x1"), ("input1", "x2"),
               ("input2", "x3"), ("output0", "y"))
    param_list = [
        gen_param(classify=classify,
                  name=name,
                  datatype=datatype,
                  format=",".join((base_formats + tail) * len(dtypes)))
        for (classify, name), tail in zip(tensors, extra)
    ]
    return get_dynamic_param_in_json(param_list)
Esempio n. 20
0
def op_select_format(input_x, input_y, output_z, kernel_name="add"):
    """
   select format dynamically
   """

    def _can_division_sixteen(shape):
        if shape[-1] == 0 or shape[-2] == 0:
            raise RuntimeError("value of shape is illegal")

        if shape[-1] % SIZE_SIXTEEN == 0 and shape[-2] % SIZE_SIXTEEN == 0:
            return True

        return False

    shape_x = input_x.get("ori_shape")
    shape_y = input_y.get("ori_shape")

    shape_x = util.scalar2tensor_one(shape_x)
    shape_y = util.scalar2tensor_one(shape_y)

    format_4d_list = ["NCHW", "NHWC", "HWCN"]
    cce_product = tbe_platform.cce_conf.get_soc_spec("SOC_VERSION")
    if cce_product in ("Hi3796CV300ES", "Hi3796CV300CS"):
        dtype_list = ["float16", "int32"]
    else:
        dtype_list = ["float16", "float32", "int32"]

    format_x = input_x.get("ori_format")
    format_y = input_y.get("ori_format")

    dtype_total = []
    format_nd = ["ND"]
    format_list = ["ND"]
    format_nz = ["FRACTAL_NZ"]
    len_format_list = len(dtype_list)

    # if shape is same, then all formats are supported.
    if list(shape_x) == list(shape_y):
        format_list = ["ND", "FRACTAL_NZ", "NC1HWC0", "FRACTAL_Z", "C1HWNCoC0"]
        for dtype in dtype_list:
            dtype_total = dtype_total + [dtype] * len(format_list)
        format_list = format_list * len(dtype_list)
        input0 = gen_param(classify="input0", name="x1",
                           datatype=",".join(dtype_total),
                           format=",".join(format_list))
        input1 = gen_param(classify="input1", name="x2",
                           datatype=",".join(dtype_total),
                           format=",".join(format_list))
        output0 = gen_param(classify="output0", name="y",
                            datatype=",".join(dtype_total),
                            format=",".join(format_list))

        param_list = [input0, input1, output0]
        param_dynamic_in_json = get_dynamic_param_in_json(param_list)
        return param_dynamic_in_json

    if len(shape_x) == 4 and len(shape_y) == 4 and \
            format_x in format_4d_list and format_y in format_4d_list:
        x_cdim = shape_x[format_x.index("C")]
        x_wdim = shape_x[format_x.index("W")]
        x_hdim = shape_x[format_x.index("H")]
        x_ndim = shape_x[format_x.index("N")]
        y_cdim = shape_y[format_y.index("C")]
        y_wdim = shape_y[format_y.index("W")]
        y_hdim = shape_y[format_y.index("H")]
        y_ndim = shape_y[format_y.index("N")]
    if (len(shape_y) == 1 and shape_y[0] == 1 and len(shape_x) == 4) and \
            format_x in format_4d_list:
        x_cdim = shape_x[format_x.index("C")]
        x_ndim = shape_x[format_x.index("N")]
    if (len(shape_x) == 1 and shape_x[0] == 1 and len(shape_y) == 4) and \
            format_y in format_4d_list:
        y_cdim = shape_y[format_y.index("C")]
        y_ndim = shape_y[format_y.index("N")]

    # ND+ND NZ+NZ 5HD+5HD FZ+FZ
    if len(shape_x) >= 2 and len(shape_y) >= 2 and shape_x[-2:] == shape_y[-2:]:
        format_list.append("FRACTAL_NZ")
        if len(shape_x) == 4 and len(shape_y) == 4 and \
                format_x in format_4d_list and format_y in format_4d_list:
            if x_cdim % 16 == 0 and y_cdim % 16 == 0:
                if format_x == format_y == "NCHW" and \
                        (x_cdim == y_cdim or x_cdim // 16 == 1 or y_cdim // 16 == 1) and \
                        (x_ndim == y_ndim or x_ndim == 1 or y_ndim == 1):
                    format_list.append("NC1HWC0")
                if format_x == format_y == "HWCN":
                    if x_hdim == y_hdim and (x_wdim == 1 or y_wdim == 1):
                        format_list.append("NC1HWC0")
                    if x_wdim == y_wdim and (x_hdim == 1 or y_hdim == 1):
                        format_list.append("NC1HWC0")
                    if x_wdim == y_wdim and x_hdim == y_hdim:
                        format_list.append("NC1HWC0")
                    if (x_wdim == x_hdim == 1) or (y_hdim == y_wdim == 1):
                        format_list.append("NC1HWC0")
                    if (x_hdim == y_wdim == 1) or (x_wdim == y_hdim == 1):
                        format_list.append("NC1HWC0")
                if format_x == format_y == "NHWC":
                    if x_hdim == y_hdim and (x_ndim == 1 or y_ndim == 1):
                        format_list.append("NC1HWC0")
                    if x_ndim == y_ndim and (x_hdim == 1 or y_hdim == 1):
                        format_list.append("NC1HWC0")
                    if x_ndim == y_ndim and x_hdim == y_hdim:
                        format_list.append("NC1HWC0")
                    if (x_ndim == x_hdim == 1) or (y_ndim == y_hdim == 1):
                        format_list.append("NC1HWC0")
                    if (x_ndim == 1 and y_hdim == 1) or (x_hdim == 1 and y_ndim == 1):
                        format_list.append("NC1HWC0")
            if x_cdim % 16 == 0 and y_cdim % 16 == 0 and \
                    y_ndim % 16 == 0 and x_ndim % 16 == 0:
                if (format_x == format_y == "NHWC" and list(shape_x) == list(shape_y)) \
                        or (format_x == format_y == "NCHW" and list(shape_x) == list(shape_y)):
                    format_list.append("FRACTAL_Z")
                if format_x == format_y == "HWCN" and \
                        x_wdim * x_hdim == y_wdim * y_hdim:
                    format_list.append("FRACTAL_Z")
            if list(shape_x) == list(shape_y):
                format_list.append("NC1HWC0")
                format_list.append("FRACTAL_Z")
        for dtype in dtype_list:
            dtype_total = dtype_total + [dtype] * len(format_list)
        format_list = format_list * len_format_list
        input0 = gen_param(classify="input0", name="x1",
                           datatype=",".join(dtype_total),
                           format=",".join(format_list))
        input1 = gen_param(classify="input1", name="x2",
                           datatype=",".join(dtype_total),
                           format=",".join(format_list))
        output0 = gen_param(classify="output0", name="y",
                            datatype=",".join(dtype_total),
                            format=",".join(format_list))

    # NZ+ND,ND+ND,5HD+5HD,FZ+FZ,ND+NZ
    elif len(shape_x) >= 2 and len(shape_y) >= 2 and \
            ((_can_division_sixteen(shape_x) and
              not _can_division_sixteen(shape_y)) or
             (not _can_division_sixteen(shape_x) and
              _can_division_sixteen(shape_y))):
        if len(shape_x) == 4 and len(shape_y) == 4 and \
                format_x in format_4d_list and format_y in format_4d_list:
            if x_cdim % 16 == 0 and y_cdim % 16 == 0:
                if x_cdim == y_cdim or x_cdim // 16 == 1 or y_cdim // 16 == 1:
                    format_list.append("NC1HWC0")
            if x_cdim % 16 == 0 and x_ndim % 16 == 0 and \
                    y_cdim % 16 == 0 and y_ndim % 16 == 0:
                if format_x == format_y == "NCHW" and \
                        x_hdim * x_wdim == y_hdim * y_wdim and x_cdim == y_cdim:
                    if x_ndim == y_ndim:
                        format_list.append("FRACTAL_Z")
                    if (x_ndim // 16 == 1 and y_ndim % 16 == 0) or \
                            (y_ndim // 16 == 1 and x_ndim % 16 == 0):
                        format_list.append("FRACTAL_Z")
                if format_x == format_y == "NHWC" and \
                        x_hdim * x_wdim == y_hdim * y_wdim and \
                        x_ndim == y_ndim and x_cdim == y_cdim:
                    format_list.append("FRACTAL_Z")
        for dtype in dtype_list:
            dtype_total = dtype_total + [dtype] * len(format_list)
        format_list = format_list * len_format_list
        for dtype in dtype_list:
            dtype_total = dtype_total + [dtype] * 1
        format_list0 = format_list + format_nz * len_format_list
        format_list1 = format_list + format_nd * len_format_list
        if _can_division_sixteen(shape_x) and not _can_division_sixteen(shape_y):
            input0 = gen_param(classify="input0", name="x1",
                               datatype=",".join(dtype_total),
                               format=",".join(format_list0))
            input1 = gen_param(classify="input1", name="x2",
                               datatype=",".join(dtype_total),
                               format=",".join(format_list1))
            output0 = gen_param(classify="output0", name="y",
                                datatype=",".join(dtype_total),
                                format=",".join(format_list0))
        else:
            input0 = gen_param(classify="input0", name="x1",
                               datatype=",".join(dtype_total),
                               format=",".join(format_list1))
            input1 = gen_param(classify="input1", name="x2",
                               datatype=",".join(dtype_total),
                               format=",".join(format_list0))
            output0 = gen_param(classify="output0", name="y",
                                datatype=",".join(dtype_total),
                                format=",".join(format_list1))

    # 5HD+scalar,ND+ND,FZ+scalar
    elif len(shape_x) >= 2 and len(shape_y) == 1 and shape_y[0] == 1:
        if len(shape_x) == 4 and len(shape_y) == 1 and format_x in format_4d_list:
            if x_cdim % 16 == 0:
                format_list.append("NC1HWC0")
            if x_cdim % 16 == 0 and x_ndim % 16 == 0:
                format_list.append("FRACTAL_Z")
        for dtype in dtype_list:
            dtype_total = dtype_total + [dtype] * len(format_list)
        format_list = format_list * len_format_list
        for dtype in dtype_list:
            dtype_total = dtype_total + [dtype] * 1
        format_list0 = format_list + format_nd * len_format_list
        format_list1 = format_nd * len(format_list) + format_nd * len_format_list
        input0 = gen_param(classify="input0", name="x1",
                           datatype=",".join(dtype_total),
                           format=",".join(format_list0))
        input1 = gen_param(classify="input1", name="x2",
                           datatype=",".join(dtype_total),
                           format=",".join(format_list1))
        output0 = gen_param(classify="output0", name="y",
                            datatype=",".join(dtype_total),
                            format=",".join(format_list0))

    # ND+ND,scalar+5HD,scalar+FZ
    elif len(shape_y) >= 2 and len(shape_x) == 1 and shape_x[0] == 1:
        if len(shape_x) == 1 and len(shape_y) == 4 and format_y in format_4d_list:
            if y_cdim % 16 == 0:
                format_list.append("NC1HWC0")
            if y_cdim % 16 == 0 and y_ndim % 16 == 0:
                format_list.append("FRACTAL_Z")
        for dtype in dtype_list:
            dtype_total = dtype_total + [dtype] * len(format_list)
        format_list = format_list * len_format_list
        for dtype in dtype_list:
            dtype_total = dtype_total + [dtype] * 1
        format_list0 = format_list + format_nd * len_format_list
        format_list1 = format_nd * len(format_list) + format_nd * len_format_list
        input0 = gen_param(classify="input0", name="x1",
                           datatype=",".join(dtype_total),
                           format=",".join(format_list1))
        input1 = gen_param(classify="input1", name="x2",
                           datatype=",".join(dtype_total),
                           format=",".join(format_list0))
        output0 = gen_param(classify="output0", name="y",
                            datatype=",".join(dtype_total),
                            format=",".join(format_list0))
    # ND+ND,5HD+5HD
    else:
        if len(shape_x) == 1 and len(shape_y) == 1 and \
                shape_x[0] % 16 == 0 and shape_y[0] % 16 == 0:
            format_list.append("NC1HWC0")
        if len(shape_x) == 4 and len(shape_y) == 4 \
                and format_x in format_4d_list and format_y in format_4d_list:
            if format_x == format_y == "NCHW" or format_x == format_y == "HWCN" \
                    or format_x == format_y == "NHWC":
                if x_cdim % 16 == 0 and y_cdim % 16 == 0:
                    if (x_cdim // 16 == 1 or y_cdim // 16 == 1) or x_cdim == y_cdim:
                        if x_ndim == y_ndim:
                            if x_hdim == y_hdim and (x_wdim == 1 or y_wdim == 1):
                                format_list.append("NC1HWC0")
                            if x_wdim == y_wdim and (x_hdim == 1 or y_hdim == 1):
                                format_list.append("NC1HWC0")
                            if x_hdim == y_hdim and x_wdim == y_wdim:
                                format_list.append("NC1HWC0")
                            if (x_wdim == x_hdim == 1) or (y_wdim == y_hdim == 1):
                                format_list.append("NC1HWC0")
                            if (x_hdim == 1 and y_wdim == 1) or (x_wdim == 1 and y_hdim == 1):
                                format_list.append("NC1HWC0")
                        if x_hdim == y_hdim:
                            if x_ndim == y_ndim and (x_wdim == 1 or y_wdim == 1):
                                format_list.append("NC1HWC0")
                            if x_wdim == y_wdim and (x_ndim == 1 or y_ndim == 1):
                                format_list.append("NC1HWC0")
                            if x_ndim == y_ndim and x_wdim == y_wdim:
                                format_list.append("NC1HWC0")
                            if (x_ndim == x_wdim == 1) or (y_ndim == y_wdim == 1):
                                format_list.append("NC1HWC0")
                            if (x_ndim == 1 and y_wdim == 1) or (x_wdim == 1 and y_ndim == 1):
                                format_list.append("NC1HWC0")
                        if x_wdim == y_wdim:
                            if x_ndim == y_ndim and (x_hdim == 1 or y_hdim == 1):
                                format_list.append("NC1HWC0")
                            if x_hdim == y_hdim and (x_ndim == 1 or y_ndim == 1):
                                format_list.append("NC1HWC0")
                            if x_ndim == y_ndim and x_hdim == y_hdim:
                                format_list.append("NC1HWC0")
                            if (x_ndim == x_hdim == 1) or (y_ndim == y_hdim == 1):
                                format_list.append("NC1HWC0")
                            if (x_ndim == 1 and y_hdim == 1) or (x_hdim == 1 and y_ndim == 1):
                                format_list.append("NC1HWC0")
        for dtype in dtype_list:
            dtype_total = dtype_total + [dtype] * len(format_list)
        len_format_list = len(dtype_list)
        format_list = format_list * len_format_list
        input0 = gen_param(classify="input0", name="x1",
                           datatype=",".join(dtype_total),
                           format=",".join(format_list))
        input1 = gen_param(classify="input1", name="x2",
                           datatype=",".join(dtype_total),
                           format=",".join(format_list))
        output0 = gen_param(classify="output0", name="y",
                            datatype=",".join(dtype_total),
                            format=",".join(format_list))

    param_list = [input0, input1, output0]
    param_dynamic_in_json = get_dynamic_param_in_json(param_list)
    return param_dynamic_in_json
Esempio n. 21
0
def op_select_format(input_x, input_y, output_z, kernel_name="add"):
    """
    select format dynamically, supporting dynamic shape format selecting

    Builds the supported (dtype, format) combinations for the two inputs
    ("x1", "x2") and the output ("y") of the add op from the original
    shapes and original formats of the two inputs.

    Parameters
    ----------
    input_x : dict
        first input; only "ori_shape" and "ori_format" are read
    input_y : dict
        second input; only "ori_shape" and "ori_format" are read
    output_z : dict
        output description; not read by this function
    kernel_name : str
        kernel name, default value is "add"; not read by this function

    Returns
    -------
    The value of get_dynamic_param_in_json for [x1, x2, y]. All three
    parameters share one comma-joined dtype list, and entry i of each
    format list pairs with entry i of that dtype list.
    """
    shape_x = input_x.get("ori_shape")
    shape_y = input_y.get("ori_shape")

    # Presumably maps a scalar (empty) shape to a 1-element shape so the rank
    # checks below work -- TODO confirm against util.scalar2tensor_one.
    shape_x = util.scalar2tensor_one(shape_x)
    shape_y = util.scalar2tensor_one(shape_y)

    format_4d_list = ["NCHW", "NHWC", "HWCN"]
    cce_product = tbe_platform.cce_conf.get_soc_spec("SOC_VERSION")
    # These SOC versions are not offered float32 combinations.
    if cce_product in ("Hi3796CV300ES", "Hi3796CV300CS"):
        dtype_list = ["float16", "int32"]
    else:
        dtype_list = ["float16", "float32", "int32"]

    format_x = input_x.get("ori_format")
    format_y = input_y.get("ori_format")

    # dtype_total and the per-parameter format lists are built in parallel;
    # index i of each describes the same supported combination.
    dtype_total = []
    format_nd = ["ND"]
    format_list = ["ND"]
    format_nz = ["FRACTAL_NZ"]
    format_5hd = ["NC1HWC0"]
    len_format_list = len(dtype_list)
    # add_nd_nz: x is 1-D, y has rank >= 2, their last dims are equal, and
    # x's single dim equals some dim of y and is a multiple of 16.
    # add_nz_nd: the same test with the roles of x and y swapped.
    add_nd_nz = False
    add_nz_nd = False
    if len(shape_x) == 1 and len(shape_y) >= 2 and shape_x[-1] == shape_y[-1]:
        for i in range(0, len(shape_y)):
            if shape_x[0] == shape_y[i] and shape_x[0] % 16 == 0:
                add_nd_nz = True
                break
    if len(shape_y) == 1 and len(shape_x) >= 2 and shape_x[-1] == shape_y[-1]:
        for i in range(0, len(shape_x)):
            if shape_y[0] == shape_x[i] and shape_y[0] % 16 == 0:
                add_nz_nd = True
                break

    # Per-axis sizes are read at the position of the axis letter in the
    # format string (e.g. "NHWC".index("C") == 3). These names are bound
    # only when the rank/format tests pass, and every branch below repeats
    # the same test before reading them.
    if len(shape_x) == 4 and len(shape_y) == 4 and \
            format_x in format_4d_list and format_y in format_4d_list:
        x_cdim = shape_x[format_x.index("C")]
        x_wdim = shape_x[format_x.index("W")]
        x_hdim = shape_x[format_x.index("H")]
        x_ndim = shape_x[format_x.index("N")]
        y_cdim = shape_y[format_y.index("C")]
        y_wdim = shape_y[format_y.index("W")]
        y_hdim = shape_y[format_y.index("H")]
        y_ndim = shape_y[format_y.index("N")]
    if (len(shape_y) == 1 and len(shape_x) == 4) and \
            format_x in format_4d_list:
        x_cdim = shape_x[format_x.index("C")]
        x_ndim = shape_x[format_x.index("N")]
    if (len(shape_x) == 1 and len(shape_y) == 4) and \
            format_y in format_4d_list:
        y_cdim = shape_y[format_y.index("C")]
        y_ndim = shape_y[format_y.index("N")]

    # ND+ND NZ+NZ 5HD+5HD FZ+FZ
    # Both inputs have rank >= 2 and identical last two dims; x1, x2 and y
    # all receive the same format list.
    # NOTE(review): several independent conditions below may append the same
    # format more than once; the duplicates are kept in the emitted lists.
    if len(shape_x) >= 2 and len(shape_y) >= 2 and \
            shape_x[-2:] == shape_y[-2:]:
        format_list.append("FRACTAL_NZ")
        if len(shape_x) == 4 and len(shape_y) == 4 and \
                format_x in format_4d_list and format_y in format_4d_list:
            # NC1HWC0 requires both C dims to be multiples of 16, plus a
            # layout-specific relation between the remaining dims.
            if x_cdim % 16 == 0 and y_cdim % 16 == 0:
                if format_x == format_y == "NCHW" and \
                        (x_cdim == y_cdim or x_cdim // 16 == 1 or \
                         y_cdim // 16 == 1) and (x_ndim == y_ndim or \
                                                 x_ndim == 1 or y_ndim == 1):
                    format_list.append("NC1HWC0")
                if format_x == format_y == "HWCN":
                    if x_hdim == y_hdim and (x_wdim == 1 or y_wdim == 1):
                        format_list.append("NC1HWC0")
                    if x_wdim == y_wdim and (x_hdim == 1 or y_hdim == 1):
                        format_list.append("NC1HWC0")
                    if x_wdim == y_wdim and x_hdim == y_hdim:
                        format_list.append("NC1HWC0")
                    if (x_wdim == x_hdim == 1) or (y_hdim == y_wdim == 1):
                        format_list.append("NC1HWC0")
                    if (x_hdim == y_wdim == 1) or (x_wdim == y_hdim == 1):
                        format_list.append("NC1HWC0")
                if format_x == format_y == "NHWC":
                    if x_hdim == y_hdim and (x_ndim == 1 or y_ndim == 1):
                        format_list.append("NC1HWC0")
                    if x_ndim == y_ndim and (x_hdim == 1 or y_hdim == 1):
                        format_list.append("NC1HWC0")
                    if x_ndim == y_ndim and x_hdim == y_hdim:
                        format_list.append("NC1HWC0")
                    if (x_ndim == x_hdim == 1) or (y_ndim == y_hdim == 1):
                        format_list.append("NC1HWC0")
                    if (x_ndim == 1 and y_hdim == 1) or (x_hdim == 1
                                                         and y_ndim == 1):
                        format_list.append("NC1HWC0")
            # FRACTAL_Z requires both C and N dims of both inputs to be
            # multiples of 16.
            if x_cdim % 16 == 0 and y_cdim % 16 == 0 and y_ndim % 16 == 0 and \
                    x_ndim % 16 == 0:
                if (format_x == format_y == "NHWC" and \
                    list(shape_x) == list(shape_y)) or \
                        (format_x == format_y == "NCHW" and \
                         list(shape_x) == list(shape_y)):
                    format_list.append("FRACTAL_Z")
                if format_x == format_y == "HWCN" and \
                        x_wdim * x_hdim == y_wdim * y_hdim:
                    format_list.append("FRACTAL_Z")
            # Identical 4-D shapes: 5HD and FZ are offered without the
            # multiple-of-16 checks above.
            if list(shape_x) == list(shape_y):
                format_list.append("NC1HWC0")
                format_list.append("FRACTAL_Z")
        # Pair each dtype with every collected format.
        for dtype in dtype_list:
            dtype_total = dtype_total + [dtype] * len(format_list)
        format_list = format_list * len_format_list
        format_list_input0 = format_list
        format_list_input1 = format_list
        format_list_output = format_list

    # NZ+ND,ND+ND,5HD+5HD,FZ+FZ,ND+NZ
    # Both inputs have rank >= 2 and exactly one of them passes
    # _can_division_sixteen (defined elsewhere; presumably a multiple-of-16
    # check on the trailing dims -- TODO confirm).
    elif len(shape_x) >= 2 and len(shape_y) >= 2 and \
            ((_can_division_sixteen(shape_x) and \
              not _can_division_sixteen(shape_y)) or \
             (not _can_division_sixteen(shape_x) and \
              _can_division_sixteen(shape_y))):
        if len(shape_x) == 4 and len(
                shape_y) == 4 and format_x in format_4d_list and \
                format_y in format_4d_list:
            if x_cdim % 16 == 0 and y_cdim % 16 == 0:
                if x_cdim == y_cdim or x_cdim // 16 == 1 or y_cdim // 16 == 1:
                    format_list.append("NC1HWC0")
            if x_cdim % 16 == 0 and x_ndim % 16 == 0 and y_cdim % 16 == 0 and \
                    y_ndim % 16 == 0:
                if format_x == format_y == "NCHW" and \
                        x_hdim * x_wdim == y_hdim * y_wdim and \
                        x_cdim == y_cdim:
                    if x_ndim == y_ndim:
                        format_list.append("FRACTAL_Z")
                    if (x_ndim // 16 == 1
                            and y_ndim % 16 == 0) or (y_ndim // 16 == 1
                                                      and x_ndim % 16 == 0):
                        format_list.append("FRACTAL_Z")
                if format_x == format_y == "NHWC" and \
                        x_hdim * x_wdim == y_hdim * y_wdim and \
                        x_ndim == y_ndim and x_cdim == y_cdim:
                    format_list.append("FRACTAL_Z")
        for dtype in dtype_list:
            dtype_total = dtype_total + [dtype] * len(format_list)
        format_list = format_list * len_format_list
        # One extra entry per dtype: the input that passes
        # _can_division_sixteen uses FRACTAL_NZ, the other uses ND, and the
        # output follows the FRACTAL_NZ side.
        for dtype in dtype_list:
            dtype_total = dtype_total + [dtype] * 1
        format_list0 = format_list + format_nz * len_format_list
        format_list1 = format_list + format_nd * len_format_list
        if _can_division_sixteen(
                shape_x) and not _can_division_sixteen(shape_y):
            format_list_input0 = format_list0
            format_list_input1 = format_list1
            format_list_output = format_list0
        else:
            format_list_input0 = format_list1
            format_list_input1 = format_list0
            format_list_output = format_list0

    # One input is 1-D and matches a multiple-of-16 dim of the other (see
    # add_nd_nz / add_nz_nd above): ND for everything, plus one extra entry
    # per dtype with FRACTAL_NZ on the higher-rank side and ND on the 1-D side.
    elif add_nd_nz or add_nz_nd:
        for dtype in dtype_list:
            dtype_total = dtype_total + [dtype] * len(format_list)
        format_list = format_list * len_format_list
        for dtype in dtype_list:
            dtype_total = dtype_total + [dtype] * 1
        format_list0 = format_list + format_nz * len_format_list
        format_list1 = format_list + format_nd * len_format_list
        # 1-D y against 4-D x: if y's length and x's C dim are multiples of
        # 16, the extra entries become NC1HWC0 on both sides instead.
        # NOTE(review): y's length was matched against x's LAST dim above,
        # which is the C dim only for NHWC; confirm 5HD is intended for
        # NCHW/HWCN here.
        if len(shape_y) == 1 and len(
                shape_x) == 4 and format_x in format_4d_list:
            if shape_y[0] % 16 == 0 and x_cdim % 16 == 0:
                format_list0 = format_list + format_5hd * len_format_list
                format_list1 = format_list + format_5hd * len_format_list
        if add_nz_nd:
            format_list_input0 = format_list0
            format_list_input1 = format_list1
            format_list_output = format_list0
        else:
            format_list_input0 = format_list1
            format_list_input1 = format_list0
            format_list_output = format_list0

    # 5HD+scalar,ND+ND,FZ+scalar
    # x has rank >= 2 and y is a single-element 1-D tensor. y is ND in every
    # entry; x and the output may additionally use NC1HWC0 / FRACTAL_Z.
    elif len(shape_x) >= 2 and len(shape_y) == 1 and shape_y[0] == 1:
        if len(shape_x) == 4 and len(
                shape_y) == 1 and format_x in format_4d_list:
            if x_cdim % 16 == 0:
                format_list.append("NC1HWC0")
            if x_cdim % 16 == 0 and x_ndim % 16 == 0:
                format_list.append("FRACTAL_Z")
        for dtype in dtype_list:
            dtype_total = dtype_total + [dtype] * len(format_list)
        format_list = format_list * len_format_list
        for dtype in dtype_list:
            dtype_total = dtype_total + [dtype] * 1
        format_list0 = format_list + format_nd * len_format_list
        format_list1 = format_nd * len(
            format_list) + format_nd * len_format_list
        format_list_input0 = format_list0
        format_list_input1 = format_list1
        format_list_output = format_list0

    # ND+ND,scalar+5HD,scalar+FZ
    # Mirror of the previous branch: x is the single-element 1-D tensor.
    elif len(shape_y) >= 2 and len(shape_x) == 1 and shape_x[0] == 1:
        if len(shape_x) == 1 and len(
                shape_y) == 4 and format_y in format_4d_list:
            if y_cdim % 16 == 0:
                format_list.append("NC1HWC0")
            if y_cdim % 16 == 0 and y_ndim % 16 == 0:
                format_list.append("FRACTAL_Z")
        for dtype in dtype_list:
            dtype_total = dtype_total + [dtype] * len(format_list)
        format_list = format_list * len_format_list
        for dtype in dtype_list:
            dtype_total = dtype_total + [dtype] * 1
        format_list0 = format_list + format_nd * len_format_list
        format_list1 = format_nd * len(
            format_list) + format_nd * len_format_list
        format_list_input0 = format_list1
        format_list_input1 = format_list0
        format_list_output = format_list0
    # ND+ND,5HD+5HD
    # Everything else (e.g. two 1-D inputs, or 4-D pairs whose last two dims
    # differ): x1, x2 and y share one list of ND plus any NC1HWC0 entries.
    else:
        if len(shape_x) == 1 and len(shape_y) == 1 and \
                shape_x[0] % 16 == 0 and shape_y[0] % 16 == 0:
            format_list.append("NC1HWC0")
        if len(shape_x) == 4 and len(
                shape_y) == 4 and format_x in format_4d_list and \
                format_y in format_4d_list:
            if format_x == format_y == "NCHW" or \
                    format_x == format_y == "HWCN" or \
                    format_x == format_y == "NHWC":
                if x_cdim % 16 == 0 and y_cdim % 16 == 0:
                    if (x_cdim // 16 == 1 or y_cdim // 16 == 1) or (x_cdim
                                                                    == y_cdim):
                        # Two of N/H/W equal, with the third dim equal or 1
                        # on at least one side, in each of the three groups.
                        if x_ndim == y_ndim:
                            if x_hdim == y_hdim and (x_wdim == 1
                                                     or y_wdim == 1):
                                format_list.append("NC1HWC0")
                            if x_wdim == y_wdim and (x_hdim == 1
                                                     or y_hdim == 1):
                                format_list.append("NC1HWC0")
                            if x_hdim == y_hdim and x_wdim == y_wdim:
                                format_list.append("NC1HWC0")
                            if (x_wdim == x_hdim == 1) or (y_wdim == y_hdim ==
                                                           1):
                                format_list.append("NC1HWC0")
                            if (x_hdim == 1
                                    and y_wdim == 1) or (x_wdim == 1
                                                         and y_hdim == 1):
                                format_list.append("NC1HWC0")
                        if x_hdim == y_hdim:
                            if x_ndim == y_ndim and (x_wdim == 1
                                                     or y_wdim == 1):
                                format_list.append("NC1HWC0")
                            if x_wdim == y_wdim and (x_ndim == 1
                                                     or y_ndim == 1):
                                format_list.append("NC1HWC0")
                            if x_ndim == y_ndim and x_wdim == y_wdim:
                                format_list.append("NC1HWC0")
                            if (x_ndim == x_wdim == 1) or (y_ndim == y_wdim ==
                                                           1):
                                format_list.append("NC1HWC0")
                            if (x_ndim == 1
                                    and y_wdim == 1) or (x_wdim == 1
                                                         and y_ndim == 1):
                                format_list.append("NC1HWC0")
                        if x_wdim == y_wdim:
                            if x_ndim == y_ndim and (x_hdim == 1
                                                     or y_hdim == 1):
                                format_list.append("NC1HWC0")
                            if x_hdim == y_hdim and (x_ndim == 1
                                                     or y_ndim == 1):
                                format_list.append("NC1HWC0")
                            if x_ndim == y_ndim and x_hdim == y_hdim:
                                format_list.append("NC1HWC0")
                            if (x_ndim == x_hdim == 1) or (y_ndim == y_hdim ==
                                                           1):
                                format_list.append("NC1HWC0")
                            if (x_ndim == 1
                                    and y_hdim == 1) or (x_hdim == 1
                                                         and y_ndim == 1):
                                format_list.append("NC1HWC0")
        for dtype in dtype_list:
            dtype_total = dtype_total + [dtype] * len(format_list)
        len_format_list = len(dtype_list)
        format_list = format_list * len_format_list
        format_list_input0 = format_list
        format_list_input1 = format_list
        format_list_output = format_list

    # Ranks differ and the shapes broadcast (per _can_broadcast, defined
    # elsewhere): if either input is NHWC, add one NC1HWC0 entry per dtype
    # to all three lists.
    if _can_broadcast(shape_x, shape_y) and len(shape_x) != len(shape_y):
        x_format = input_x.get("ori_format")
        y_format = input_y.get("ori_format")
        if x_format == "NHWC" or y_format == "NHWC":
            formats = ["NC1HWC0"]
            for item in formats:
                dtype_total = dtype_total + dtype_list
                format_list_input0 = format_list_input0 + [item
                                                           ] * len(dtype_list)
                format_list_input1 = format_list_input1 + [item
                                                           ] * len(dtype_list)
                format_list_output = format_list_output + [item
                                                           ] * len(dtype_list)

    input0 = gen_param(classify="input0",
                       name="x1",
                       datatype=",".join(dtype_total),
                       format=",".join(format_list_input0))
    input1 = gen_param(classify="input1",
                       name="x2",
                       datatype=",".join(dtype_total),
                       format=",".join(format_list_input1))
    output0 = gen_param(classify="output0",
                        name="y",
                        datatype=",".join(dtype_total),
                        format=",".join(format_list_output))

    param_list = [input0, input1, output0]
    param_dynamic_in_json = get_dynamic_param_in_json(param_list)
    return param_dynamic_in_json
Esempio n. 22
0
def op_select_format(condition, x1, x2, y, kernel_name="select"):
    """
    Choose the dtype/format combinations supported by the select op.

    Two layouts are distinguished:

    * condition is not 1-D, or condition, x1 and x2 are all 1-D: all four
      parameters share one format list. It always contains ND. When
      condition, x1 and x2 have the same original format from
      NCHW/NHWC/HWCN and identical original shapes, it also contains
      FRACTAL_Z, FRACTAL_NZ and NC1HWC0.
    * otherwise (1-D condition): condition is always ND. x1, x2 and y get
      FRACTAL_NZ and ND when x1 and x2 share a format from NCHW/NHWC/HWCN
      and both have rank > 2, plus NC1HWC0 when both are exactly 4-D and
      that format is NHWC or NCHW. In every other case they get ND only.

    condition is always bool. x1, x2 and y use float16/int32/int8/uint8,
    with float added (second) when the platform reports float32 support
    for te.lang.cce.vmul.

    Parameters
    ----------
    condition, x1, x2 : dict
        tensor descriptions; only "ori_shape" and "ori_format" are read
    y : dict
        output description; not read here
    kernel_name : str
        kernel name, default value is "select"; not read here

    Returns
    -------
    The value of get_dynamic_param_in_json for condition, x1, x2 and y.
    """
    cond_shape = condition.get("ori_shape")
    x1_shape = x1.get("ori_shape")
    x2_shape = x2.get("ori_shape")

    layouts_4d = ["NCHW", "NHWC", "HWCN"]

    cond_fmt = condition.get("ori_format")
    x1_fmt = x1.get("ori_format")
    x2_fmt = x2.get("ori_format")

    if tbe_platform.cce_conf.api_check_support("te.lang.cce.vmul",
                                               "float32"):
        data_dtypes = ["float16", "float", "int32", "int8", "uint8"]
    else:
        data_dtypes = ["float16", "int32", "int8", "uint8"]

    shared_layout = (len(cond_shape) != 1
                     or (len(x1_shape) == 1 and len(x2_shape) == 1))

    if shared_layout:
        # NZ+NZ, ND+ND, 5HD+5HD, FZ+FZ: every tensor uses the same format.
        per_dtype_formats = ["ND"]
        if (cond_fmt == x1_fmt == x2_fmt and x1_fmt in layouts_4d
                and list(cond_shape) == list(x1_shape) == list(x2_shape)):
            per_dtype_formats.extend(["FRACTAL_Z", "FRACTAL_NZ", "NC1HWC0"])
    elif x1_fmt != x2_fmt:
        per_dtype_formats = ["ND"]
    elif (len(x1_shape) == 4 and len(x2_shape) == 4
          and x1_fmt in layouts_4d and x2_fmt in layouts_4d):
        per_dtype_formats = ["FRACTAL_NZ", "ND"]
        if x1_fmt in ("NHWC", "NCHW"):
            per_dtype_formats.append("NC1HWC0")
    elif (len(x1_shape) > 2 and len(x2_shape) > 2
          and x1_fmt in layouts_4d and x2_fmt in layouts_4d):
        per_dtype_formats = ["FRACTAL_NZ", "ND"]
    else:
        per_dtype_formats = ["ND"]

    # Every data dtype is paired with every selected format, dtype-major.
    data_dtype_seq = [dtype for dtype in data_dtypes
                      for _ in per_dtype_formats]
    data_format_seq = per_dtype_formats * len(data_dtypes)

    # The condition tensor is bool throughout; it mirrors the data formats
    # in the shared layout and is plain ND otherwise.
    cond_dtype_seq = ["bool"] * len(data_dtype_seq)
    if shared_layout:
        cond_format_seq = data_format_seq
    else:
        cond_format_seq = ["ND"] * len(data_dtype_seq)

    data_dtype_str = ",".join(data_dtype_seq)
    data_format_str = ",".join(data_format_seq)
    param_list = [
        gen_param(classify="input0", name="condition",
                  datatype=",".join(cond_dtype_seq),
                  format=",".join(cond_format_seq)),
        gen_param(classify="input1", name="x1",
                  datatype=data_dtype_str, format=data_format_str),
        gen_param(classify="input2", name="x2",
                  datatype=data_dtype_str, format=data_format_str),
        gen_param(classify="output0", name="y",
                  datatype=data_dtype_str, format=data_format_str),
    ]
    return get_dynamic_param_in_json(param_list)
Esempio n. 23
0
def op_select_format(input_dy, input_x, input_variance, input_mean, input_gamma,
                     output_pd_x, kernel_name="layer_norm_x_backprop"):
    """
    Choose the dtype/format combinations for layer_norm_x_backprop.

    All six parameters (dy, x, variance, mean, gamma, pd_x) share one
    datatype string; the format string depends on the parameter:

    * if _check_dynamic_format(dy ori_shape, gamma ori_shape, 16) holds,
      every parameter gets NCHW/NC1HWC0/NHWC/ND for float16 and then for
      float;
    * otherwise two extra leading entries (float16, then float) are added.
      They are FRACTAL_NZ for dy, x and pd_x and ND for variance, mean and
      gamma, and are followed by the same NCHW/NC1HWC0/NHWC/ND entries.

    Parameters
    ----------
    input_dy : dict
        input dy; only "ori_shape" is read here
    input_x : dict
        input x; not read here
    input_variance : dict
        input variance; not read here
    input_mean : dict
        input mean; not read here
    input_gamma : dict
        input gamma; only "ori_shape" is read here
    output_pd_x : dict
        output pd_x; not read here
    kernel_name : str
        kernel name, default value is "layer_norm_x_backprop"; not read here

    Returns
    -------
    The value of get_dynamic_param_in_json for the six parameters above.
    """
    raw_dy_shape = input_dy.get("ori_shape")
    raw_gamma_shape = input_gamma.get("ori_shape")
    dy_shape = util.scalar2tensor_one(raw_dy_shape)
    gamma_shape = util.scalar2tensor_one(raw_gamma_shape)
    block_c0 = 16

    if _check_dynamic_format(dy_shape, gamma_shape, block_c0):
        dtypes = "float16,float16,float16,float16,float,float,float,float"
        main_formats = "NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0,NHWC,ND"
        side_formats = main_formats
    else:
        # NOTE(review): this string has a space after the first comma;
        # presumably gen_param tolerates it -- confirm.
        dtypes = ("float16, float,float16,float16,float16,"
                  "float16,float,float,float,float")
        main_formats = ("FRACTAL_NZ,FRACTAL_NZ,NCHW,NC1HWC0,NHWC,ND,"
                        "NCHW,NC1HWC0,NHWC,ND")
        side_formats = "ND,ND,NCHW,NC1HWC0,NHWC,ND,NCHW,NC1HWC0,NHWC,ND"

    # (classify slot, parameter name, format string), in registration order.
    slots = (
        ("input0", "dy", main_formats),
        ("input1", "x", main_formats),
        ("input2", "variance", side_formats),
        ("input3", "mean", side_formats),
        ("input4", "gamma", side_formats),
        ("output0", "pd_x", main_formats),
    )
    params = [gen_param(classify=slot, name=tensor_name,
                        datatype=dtypes, format=fmt)
              for slot, tensor_name, fmt in slots]

    return get_dynamic_param_in_json(params)
Esempio n. 24
0
def op_select_format(x,
                     bias,
                     y,
                     axis=1,
                     num_axes=1,
                     bias_from_blob=True,
                     kernel_name="bias"):
    """
    Select the dtype/format combinations supported by the bias op.

    x and y always share the same format string. The bias format depends
    on the runtime bias shape:

    * 4-D original x: x/y use NC1HWC0 and ND entries. bias uses the same
      formats, except that a bias whose runtime shape is exactly [1] is ND
      in every entry.
    * any other rank: x, bias and y are all ND.

    On SOC versions Hi3796CV300ES / Hi3796CV300CS only float16 entries are
    offered; elsewhere, float16 and float are both offered.

    Parameters
    ----------
    x : dict
        input; only "ori_shape" is read here
    bias : dict
        bias input; only its runtime "shape" is read here
    y : dict
        output description; not read here
    axis, num_axes, bias_from_blob, kernel_name
        not read here; kept for interface compatibility

    Returns
    -------
    The value of get_dynamic_param_in_json for x, bias and y.
    """
    shape_x_ori = x.get("ori_shape")
    shape_bias = bias.get("shape")

    # A bias whose runtime shape is exactly [1] is kept ND in every entry.
    bias_is_single = len(shape_bias) == 1 and shape_bias[0] == 1

    # Query the SOC version once; these two variants get float16 only.
    fp16_only = tbe_platform.cce_conf.get_soc_spec("SOC_VERSION") in (
        "Hi3796CV300ES", "Hi3796CV300CS")

    if len(shape_x_ori) == 4:
        # NC1HWC0+ND
        if fp16_only:
            datatype = "float16,float16"
            x_format = "NC1HWC0,ND"
            bias_format = "ND,ND" if bias_is_single else "NC1HWC0,ND"
        else:
            datatype = "float16,float,float16,float"
            x_format = "NC1HWC0,NC1HWC0,ND,ND"
            if bias_is_single:
                bias_format = "ND,ND,ND,ND"
            else:
                bias_format = "NC1HWC0,NC1HWC0,ND,ND"
    else:
        # ND+ND
        if fp16_only:
            datatype = "float16"
            x_format = "ND"
        else:
            datatype = "float16,float"
            x_format = "ND,ND"
        bias_format = x_format

    input0 = gen_param(classify="input0",
                       name="x",
                       datatype=datatype,
                       format=x_format)
    input1 = gen_param(classify="input1",
                       name="bias",
                       datatype=datatype,
                       format=bias_format)
    output0 = gen_param(classify="output0",
                        name="y",
                        datatype=datatype,
                        format=x_format)

    param_list = [input0, input1, output0]
    param_dynamic_in_json = get_dynamic_param_in_json(param_list)
    return param_dynamic_in_json
Esempio n. 25
0
def op_select_format(grads, x, diff_scale, diff_offset, scale,
                     batch_mean, batch_variance, y, epsilon,
                     kernel_name="bn_training_reduce_grad"):
    """
    select format dynamically

    Only ``grads`` is inspected. When its original format is NCHW and its
    original shape is 4-D with dims 0 and 2 equal to 1 (i.e. (1, C, 1, W)),
    NCHW entries are offered in front of the NC1HWC0 entries; otherwise
    only NC1HWC0 is offered.

    The feature-map tensors (grads, x, y) alternate float16/float per entry,
    while the statistics tensors (diff_scale, diff_offset, scale,
    batch_mean, batch_variance) are always float.

    Returns
    -------
    param_dynamic_in_json
        the value of ``get_dynamic_param_in_json`` for the eight tensors.
    """
    grads_layout = grads.get("ori_format").upper()
    grads_shape = grads.get("ori_shape")

    nchw_allowed = (grads_layout == "NCHW" and len(grads_shape) == 4
                    and grads_shape[0] == 1 and grads_shape[2] == 1)
    if nchw_allowed:
        layouts = ["NCHW", "NCHW", "NC1HWC0", "NC1HWC0"]
    else:
        layouts = ["NC1HWC0", "NC1HWC0"]

    layout_str = ",".join(layouts)
    # feature maps alternate float16/float; statistics are always float
    map_dtypes = ",".join(["float16", "float"] * (len(layouts) // 2))
    stat_dtypes = ",".join(["float"] * len(layouts))

    tensor_table = (
        ("input0", "grads", map_dtypes),
        ("input1", "x", map_dtypes),
        ("input2", "diff_scale", stat_dtypes),
        ("input3", "diff_offset", stat_dtypes),
        ("input4", "scale", stat_dtypes),
        ("input5", "batch_mean", stat_dtypes),
        ("input6", "batch_variance", stat_dtypes),
        ("output0", "y", map_dtypes),
    )
    param_list = [
        gen_param(classify=slot, name=tensor, datatype=dtypes,
                  format=layout_str)
        for slot, tensor, dtypes in tensor_table
    ]
    return get_dynamic_param_in_json(param_list)
Esempio n. 26
0
def op_select_format(input_x,
                     output_y,
                     operation=1,
                     axis=0,
                     coeff=1.0,
                     kernel_name="reduction"):
    """
    select format dynamically for Reduction

    ND is always offered (float16 only on Hi3796CV300ES/Hi3796CV300CS,
    float16 and float32 elsewhere). A float16 NC1HWC0 entry is appended
    when the input's original format is NCHW or NHWC, its original rank is
    at least 4, and ``axis`` (after normalising a negative value) is not
    the position of C in that format.

    Parameters
    ----------
    input_x : dict
        input tensor; "ori_shape" and "ori_format" are read.
    output_y : dict
        output tensor (not read here).
    operation : int
        1:SUM, 2:ASUM (sum of abs), 3:SUMSQ (sum of sqr), 4:MEAN
        (not read here).
    axis : int
        first axis to reduce; negative values count from the end.
    coeff : float
        scale for output (not read here).
    kernel_name : str
        kernel name, default value is "reduction".

    Returns
    -------
    param_dynamic_in_json
    """
    ori_shape = input_x.get("ori_shape")
    ori_format = input_x.get("ori_format")

    first_axis = axis + len(ori_shape) if axis < 0 else axis
    # index of C in each original format that can be mapped to 5HD
    c_axis = {"NCHW": 1, "NHWC": 3}.get(ori_format)
    support_5hd = (len(ori_shape) >= 4 and c_axis is not None
                   and first_axis != c_axis)

    soc_version = tbe_platform.cce_conf.get_soc_spec("SOC_VERSION")
    if soc_version in ("Hi3796CV300ES", "Hi3796CV300CS"):
        pairs = [("float16", "ND")]
    else:
        pairs = [("float16", "ND"), ("float32", "ND")]
    if support_5hd:
        pairs.append(("float16", "NC1HWC0"))

    dtype_str = ",".join(dtype for dtype, _ in pairs)
    format_str = ",".join(layout for _, layout in pairs)

    param_list = [
        gen_param(classify="input0", name="x",
                  datatype=dtype_str, format=format_str),
        gen_param(classify="output0", name="y",
                  datatype=dtype_str, format=format_str),
    ]
    return get_dynamic_param_in_json(param_list)
Esempio n. 27
0
def op_select_format(input_x, output_x, multiples, kernel_name="tile_d"):
    """Select the dtype/format combinations supported by TileD.

    TileD broadcasts ``input_x`` by ``multiples``. ND is always offered for
    float16, float and int32; a bool input additionally gets a bool ND
    entry and never gets 5HD. For a non-bool 4-D input in NCHW or NHWC
    whose shape is (N, 1, 1, 1), with ``multiples`` of length 4 and the
    C-position multiple divisible by 16, NCHW and NHWC inputs mapping to
    NC1HWC0 outputs are appended for float16, float and int32.

    Parameters
    ----------
    input_x : dict
        shape, format and dtype of the input ("shape", "format", "dtype").
    output_x : dict
        dict of output (not read here).
    multiples : list or tuple
        number of replicates per axis.
    kernel_name : str
        kernel name, default value is "tile_d".

    Returns
    -------
    param_dynamic_in_json
    """
    shape = list(input_x.get("shape"))
    layout = input_x.get("format")
    in_dtype = input_x.get("dtype")

    base_dtypes = ["float16", "float", "int32"]
    if in_dtype == "bool":
        # bool is only supported in ND
        out_dtypes = base_dtypes + ["bool"]
        in_formats = ["ND"] * len(out_dtypes)
        out_formats = ["ND"] * len(out_dtypes)
    else:
        out_dtypes = list(base_dtypes)
        in_formats = ["ND"] * len(base_dtypes)
        out_formats = ["ND"] * len(base_dtypes)
        # position of C inside `multiples` for each 5HD-capable layout
        c_index = {"NCHW": 1, "NHWC": 3}.get(layout)
        if (c_index is not None
                and len(shape) == 4 and len(multiples) == 4
                and shape[1:] == [1, 1, 1]
                and multiples[c_index] % 16 == 0):
            out_dtypes += base_dtypes * 2
            in_formats += (["NCHW"] * len(base_dtypes)
                           + ["NHWC"] * len(base_dtypes))
            out_formats += ["NC1HWC0"] * (2 * len(base_dtypes))

    dtype_str = ",".join(out_dtypes)
    input0 = gen_param(classify="input0", name="x",
                       datatype=dtype_str, format=",".join(in_formats))
    output0 = gen_param(classify="output0", name="y",
                        datatype=dtype_str, format=",".join(out_formats))
    return get_dynamic_param_in_json([input0, output0])
Esempio n. 28
0
def op_select_format(input_value,
                     output_data,
                     size_splits,
                     split_dim,
                     num_split,
                     kernel_name="split_v_d"):
    """Select the dtype/format combinations supported by split_v_d.

    split_v_d splits a tensor into len(size_splits) tensors along one
    dimension. ND is always offered for every dtype in ``dtype_base``.
    Extra entries are appended as follows:

    * NC1HWC0 (``dtype_5hd``) when every output is a 4-D NCHW/NHWC tensor
      and, for outputs split along C, the C size is a multiple of C0
      (32 when the input dtype is int8, 16 otherwise);
    * FRACTAL_NZ (``dtype_base``) when splitting along axis 0 of an input
      whose ori_format starts with "N" and whose rank is above 2;
    * float16/int16/uint16 NC1HWC0 when ``SplitWith5HD.check_op_select()``
      accepts the split and all ``size_splits`` are equal.

    Parameters
    ----------
    input_value: dict
        the dict of input tensor.
    output_data: list or tuple
        the list of output tensor.
    size_splits: list or tuple
        a Python list containing the sizes of each output tensor
        along `split_dim`.
    split_dim: int
        the dimension along which to split; may be negative.
    num_split: int
        used to specify the number of outputs.
    kernel_name: str
        cce kernel name, default value is "split_v_d".

    Returns
    -------
    param_dynamic_in_json
        the value returned by ``get_dynamic_param_in_json`` for input ``x``
        and output ``y``.
    """
    dtype = input_value.get("dtype").lower()
    # C0 length is 32 for int8 and 16 for every other dtype
    c0_len = 32 if dtype == "int8" else 16
    input_ori_shape = input_value.get("ori_shape")
    input_ori_format = input_value.get("ori_format")
    split_dim = split_dim % len(input_ori_shape)

    is_support_5hd = True
    for output_dict in output_data:
        ori_format = output_dict.get("ori_format").upper()
        ori_shape = output_dict.get("ori_shape")

        if ori_format not in ("NCHW", "NHWC") or len(ori_shape) != 4:
            is_support_5hd = False
            break

        # Splitting along N, H or W keeps NC1HWC0 usable for this output.
        # Keep checking the remaining outputs instead of stopping here, so
        # every output's format/rank is still validated.
        if ori_format[split_dim] != "C":
            continue

        # Splitting along C: each output's C size must be C0 aligned.
        if ori_shape[ori_format.index("C")] % c0_len != 0:
            is_support_5hd = False
            break

    # NOTE(review): "ND"[0] is also "N", so ND inputs pass this test too.
    is_support_nz = (input_ori_format[0] == "N" and split_dim == 0
                     and len(input_ori_shape) > 2)

    split_with_5hd_not_align = \
        SplitWith5HD(input_value, output_data,
                     split_dim, num_split, kernel_name)
    is_support_other_5hd = split_with_5hd_not_align.check_op_select()
    # the non-aligned 5HD path is only offered for equal-sized splits
    if len(set(size_splits)) != 1:
        is_support_other_5hd = False

    dtype_base = [
        "float16", "float", "int32", "int8", "int16", "int64", "uint8",
        "uint16", "uint32", "uint64"
    ]
    dtype_5hd = [
        "float16", "float", "int32", "int8", "int16", "uint16", "uint32"
    ]
    dtype_base_out = list(dtype_base)
    format_base_out = ["ND"] * len(dtype_base)

    if is_support_5hd:
        dtype_base_out += dtype_5hd
        format_base_out += ["NC1HWC0"] * len(dtype_5hd)

    if is_support_nz:
        dtype_base_out += dtype_base
        format_base_out += ["FRACTAL_NZ"] * len(dtype_base)

    if is_support_other_5hd:
        dtype_base_out += ["float16", "int16", "uint16"]
        format_base_out += ["NC1HWC0"] * 3

    dtype_str = ','.join(dtype_base_out)
    format_str = ','.join(format_base_out)

    input0 = gen_param(
        classify="input0", name="x", datatype=dtype_str, format=format_str)
    output0 = gen_param(
        classify="output0", name="y", datatype=dtype_str, format=format_str)
    param_list = [input0, output0]
    param_dynamic_in_json = get_dynamic_param_in_json(param_list)

    return param_dynamic_in_json
Esempio n. 29
0
def op_select_format(input_x, input_gamma, input_beta,
                     output_y, output_mean, output_variance,
                     begin_norm_axis, begin_params_axis,
                     kernel_name="layer_norm"):
    """
    select format dynamically

    Every tensor is offered in float16 and float for a set of plain
    layouts: NCHW, NC1HWC0, NHWC, ND when ``begin_params_axis == 0``,
    otherwise NCHW, NHWC, ND. In each plain entry all six tensors share
    the same layout.

    In front of those, FRACTAL_NZ for ``x``/``y`` paired with ND for
    gamma, beta, mean and variance is listed (float16 then float) when
    gamma's original rank is below 2 and ``_division_sixteen(shape_x)``
    is true.

    Returns
    -------
    param_dynamic_in_json
    """
    shape_x = util.scalar2tensor_one(input_x.get("ori_shape"))
    shape_gamma = util.scalar2tensor_one(input_gamma.get("ori_shape"))

    if begin_params_axis == 0:
        plain_layouts = ("NCHW", "NC1HWC0", "NHWC", "ND")
    else:
        plain_layouts = ("NCHW", "NHWC", "ND")

    # Each entry: (dtype, layout of x/y, layout of gamma/beta/mean/variance)
    entries = []
    # Nz + ND cannot be supported when len(shape_gamma) >= 2 or when
    # _division_sixteen(shape_x) is False.
    if len(shape_gamma) < 2 and _division_sixteen(shape_x):
        entries += [("float16", "FRACTAL_NZ", "ND"),
                    ("float", "FRACTAL_NZ", "ND")]
    for dtype in ("float16", "float"):
        entries += [(dtype, layout, layout) for layout in plain_layouts]

    dtype_str = ",".join(entry[0] for entry in entries)
    data_layouts = ",".join(entry[1] for entry in entries)
    param_layouts = ",".join(entry[2] for entry in entries)

    param_list = [
        gen_param(classify="input0", name="x",
                  datatype=dtype_str, format=data_layouts),
        gen_param(classify="input1", name="gamma",
                  datatype=dtype_str, format=param_layouts),
        gen_param(classify="input2", name="beta",
                  datatype=dtype_str, format=param_layouts),
        gen_param(classify="output0", name="y",
                  datatype=dtype_str, format=data_layouts),
        gen_param(classify="output1", name="mean",
                  datatype=dtype_str, format=param_layouts),
        gen_param(classify="output2", name="variance",
                  datatype=dtype_str, format=param_layouts),
    ]
    param_dynamic_in_json = get_dynamic_param_in_json(param_list)
    return param_dynamic_in_json
Esempio n. 30
0
def op_select_format(inputs, weights, bias, offset_w, outputs, strides,
                     pads, dilations, groups=1, data_format='NHWC',
                     offset_x=0, kernel_name="conv2d"):
    """
    select format dynamically

    Only ``inputs`` and ``weights`` are read; the remaining arguments keep
    the operator signature and are ignored here.

    One of three candidate lists is produced:
      * C04 lists (FRACTAL_Z_C04 filters, plus NC1HWC0_C04 inputs on the
        SoCs that use it) when the input has at most 4 channels, the kernel
        is not 1x1, and the first-layer restriction does not apply;
      * a single float16 entry carrying ``unknownshape_format`` when only
        the batch dim, or only the H and W dims, of the input are -1;
      * otherwise a static float16/int8 list.
    """
    c04_soc_versions = ("Ascend710", "Ascend615", "Ascend610",
                        "Hi3796CV300CS")

    def _on_c04_soc():
        """Return True when the current SoC is one that uses NC1HWC0_C04."""
        return cce_conf.get_soc_spec("SOC_VERSION") in c04_soc_versions

    def _as_nchw(shape, layout, allowed, tensor_name):
        """Return *shape* permuted from *layout* into N, C, H, W order."""
        if layout == "NCHW":
            return shape
        if layout in allowed:
            return [shape[layout.index(dim)] for dim in "NCHW"]
        err_man.raise_err_input_format_invalid("conv2d", tensor_name,
                                               allowed, layout)
        return None

    shape_fm = _as_nchw(scalar2tensor_one(inputs.get("ori_shape")),
                        inputs.get("ori_format"),
                        ["NCHW", "NHWC"], "inputs")

    shape_w = weights.get("ori_shape")
    if not isinstance(shape_w, (tuple, list)) or len(shape_w) != 4:
        err_man.raise_err_should_be_4d("conv2d", "weights")
    shape_filter = _as_nchw(shape_w, weights.get("ori_format"),
                            ["NCHW", "NHWC", "HWCN"], "weights")

    # C04 layouts are considered only for inputs with at most 4 channels
    # and kernels that are not 1x1.
    use_c04 = shape_fm[1] <= 4 and \
        not (shape_filter[2] == 1 and shape_filter[3] == 1)
    # format NC1HWC0_C04 can only be used at the first conv layer; on the
    # SoCs that use it, require is_first_layer == 1.
    if inputs.get("is_first_layer") != 1 and _on_c04_soc():
        use_c04 = False

    unknown_shape_slots = ()
    if use_c04:
        if _on_c04_soc():
            x_layouts = "NC1HWC0_C04,NC1HWC0,NC1HWC0_C04,NC1HWC0"
        else:
            x_layouts = "NC1HWC0,NC1HWC0,NC1HWC0,NC1HWC0"
        table = (
            ("input0", "x", "float16,float16,int8,int8", x_layouts),
            ("input1", "filter", "float16,float16,int8,int8",
             "FRACTAL_Z_C04,FRACTAL_Z,FRACTAL_Z_C04,FRACTAL_Z"),
            ("input2", "bias", "float16,float16,int32,int32",
             "ND,ND,ND,ND"),
            ("input3", "offset_w", "int8,int8,int8,int8", "ND,ND,ND,ND"),
            ("output0", "y", "float16,float16,int32,int32",
             "NC1HWC0,NC1HWC0,NC1HWC0,NC1HWC0"),
        )
    elif (shape_fm[0] == -1 and -1 not in shape_fm[1:]) or \
            (shape_fm[2] == -1 and shape_fm[3] == -1
             and -1 not in shape_fm[:2]):
        # only dynamic_hw or dynamic_batch is supported by dynamic conv2d
        table = (
            ("input0", "x", "float16", "NC1HWC0"),
            ("input1", "filter", "float16", "FRACTAL_Z"),
            ("input2", "bias", "float16", "ND"),
            ("input3", "offset_w", "int8", "ND"),
            ("output0", "y", "float16", "NC1HWC0"),
        )
        unknown_shape_slots = ("input0", "input1", "output0")
    else:
        table = (
            ("input0", "x", "float16,int8", "NC1HWC0,NC1HWC0"),
            ("input1", "filter", "float16,int8", "FRACTAL_Z,FRACTAL_Z"),
            ("input2", "bias", "float16,int32", "ND,ND"),
            ("input3", "offset_w", "int8,int8", "ND,ND"),
            ("output0", "y", "float16,int32", "NC1HWC0,NC1HWC0"),
        )

    param_list = []
    for slot, tensor, dtypes, layouts in table:
        extra = {"unknownshape_format": layouts} \
            if slot in unknown_shape_slots else {}
        param_list.append(gen_param(classify=slot, name=tensor,
                                    datatype=dtypes, format=layouts,
                                    **extra))
    return get_dynamic_param_in_json(param_list)