예제 #1
0
 def _check_attr_range_dw(name, value, attr_min=None, attr_max=None):
     if not attr_min and not attr_max:
         return
     if not attr_min:
         if value > attr_max:
             args_dict = {
                 'errCode': 'E60011',
                 'range': '(,{}]'.format(attr_max),
                 'attr_name': name,
                 'value': str(value)
             }
             raise RuntimeError(args_dict,
                                err_mana.get_error_message(args_dict))
     elif not attr_max:
         if value < attr_min:
             args_dict = {
                 'errCode': 'E60011',
                 'range': '[{},)'.format(attr_min),
                 'attr_name': name,
                 'value': str(value)
             }
             raise RuntimeError(args_dict,
                                err_mana.get_error_message(args_dict))
     elif value > attr_max or value < attr_min:
         args_dict = {
             'errCode': 'E60011',
             'range': '[{},{}]'.format(attr_min, attr_max),
             'attr_name': name,
             'value': str(value)
         }
         raise RuntimeError(args_dict,
                            err_mana.get_error_message(args_dict))
예제 #2
0
    def _calcute_input_shape():
        """
        Return the x and out_backprop shapes, both in NCHW order.

        Reads ori_format_x / ori_shape_x and ori_format_out_backprop /
        ori_shape_out_backprop from the enclosing scope. An NCHW shape is
        returned as the original object; an NHWC shape is permuted into a
        new tuple.

        Raises RuntimeError (E60008) if either format is neither NHWC nor
        NCHW; x is checked first.
        """
        def _to_nchw(shape):
            # NHWC -> NCHW: take N, C, H, W from positions 0, 3, 1, 2.
            return (shape[0], shape[3], shape[1], shape[2])

        def _reject_format(param_name, actual_format):
            dict_args = {
                'errCode': "E60008",
                'param_name': param_name,
                'expected_format_list': "[{}, {}]".format("NHWC", "NCHW"),
                "format": actual_format,
            }
            raise RuntimeError(dict_args, err_man.get_error_message(dict_args))

        if ori_format_x == "NHWC":
            x_shape = _to_nchw(ori_shape_x)
        elif ori_format_x == "NCHW":
            x_shape = ori_shape_x
        else:
            _reject_format("x", ori_format_x)

        if ori_format_out_backprop == "NCHW":
            shape_out = ori_shape_out_backprop
        elif ori_format_out_backprop == "NHWC":
            shape_out = _to_nchw(ori_shape_out_backprop)
        else:
            _reject_format("out_backprop", ori_format_out_backprop)

        return x_shape, shape_out
예제 #3
0
 def _check_attr_range_dw(name, value, attr_min=None, attr_max=None):
     if not attr_min and not attr_max:
         return
     if not attr_min:
         if (not isinstance(value, int)) or value > attr_max:
             dict_args = {}
             dict_args['errCode'] = "E64001"
             dict_args['range'] = "(, {}]".format(attr_max)
             dict_args['attr_name'] = name
             dict_args["value"] = str(value)
             raise RuntimeError(dict_args,
                                err_man.get_error_message(dict_args))
     elif not attr_max:
         if (not isinstance(value, int)) or value < attr_min:
             dict_args = {}
             dict_args['errCode'] = "E64001"
             dict_args['range'] = "[{}, )".format(attr_min)
             dict_args['attr_name'] = name
             dict_args["value"] = str(value)
             raise RuntimeError(dict_args,
                                err_man.get_error_message(dict_args))
     elif(not isinstance(value, int)) or value > attr_max \
             or value < attr_min:
         dict_args = {}
         dict_args['errCode'] = "E64001"
         dict_args['range'] = "[{},{}]".format(attr_min, attr_max)
         dict_args['attr_name'] = name
         dict_args["value"] = str(value)
         raise RuntimeError(dict_args, err_man.get_error_message(dict_args))
예제 #4
0
 def _check_attr_range(attr_name, attr_value, attr_min=None, attr_max=None):
     if not attr_min and not attr_max:
         return
     if not attr_min:
         if attr_value > attr_max:
             args_dict = {
                 "errCode": "E60114",
                 "reason": "{} exceed max_value."
                           " max_value={}.".format(attr_name, attr_max),
                 "value": "attr_value = {}".format(attr_value)
             }
             raise RuntimeError(args_dict,
                                err_man.get_error_message(args_dict))
     elif not attr_max:
         if attr_value < attr_min:
             args_dict = {
                 "errCode": "E60114",
                 "reason": "{} less than min_value. "
                           "min_value={}.".format(attr_name, attr_min),
                 "value": "attr_value = {}".format(attr_value)
             }
             raise RuntimeError(args_dict,
                                err_man.get_error_message(args_dict))
     elif attr_value < attr_min or attr_value > attr_max:
         args_dict = {
             "errCode": "E60011",
             "range": "[{},{}]".format(attr_min, attr_max),
             "attr_name": attr_name,
             "value": attr_value
         }
         raise RuntimeError(args_dict,
                            err_man.get_error_message(args_dict))
예제 #5
0
def _bias_check(input_x1, input_x2, bias, trans_a, trans_b):
    if input_x1["ori_format"] == "ND" and input_x2["ori_format"] == \
            "ND" and bias["ori_format"] == "ND":
        shape_a = list(input_x1["ori_shape"])
        shape_b = list(input_x2["ori_shape"])
        shape_bias = list(bias["ori_shape"])

        if trans_a:
            a_m = shape_a[1]
        else:
            a_m = shape_a[0]

        if trans_b:
            b_n = shape_b[0]
        else:
            b_n = shape_b[1]
        if shape_bias != [a_m, b_n]:
            args_dict = {
                "errCode": "E60000",
                "param_name": "c shape",
                "expected_value": str([a_m, b_n]),
                "input_value": "{}".format(shape_bias)
            }
            raise RuntimeError(args_dict, err_man.get_error_message(args_dict))
    else:
        shape_a = list(input_x1["shape"])
        shape_b = list(input_x2["shape"])
        shape_bias = list(bias["shape"])
        if len(shape_bias) == 2:
            shape_bias = [
                ceil(shape_bias[1] / cce.BLOCK_OUT),
                ceil(shape_bias[0] / cce.BLOCK_IN)
            ]
        else:
            shape_bias = shape_bias[:2]
        if input_x2["dtype"] == "int8" and shape_bias != [
                shape_b[1], shape_a[1]
        ]:
            args_dict = {
                "errCode": "E60000",
                "param_name": "c shape",
                "expected_value": str([shape_a[1], shape_b[1]]),
                "input_value": "{}".format(shape_bias)
            }
            raise RuntimeError(args_dict, err_man.get_error_message(args_dict))
        if input_x2["dtype"] == "float16" and shape_bias != [
                shape_b[0], shape_a[1]
        ]:
            args_dict = {
                "errCode": "E60000",
                "param_name": "c shape",
                "expected_value": str([shape_a[1], shape_b[0]]),
                "input_value": "{}".format(shape_bias)
            }
            raise RuntimeError(args_dict, err_man.get_error_message(args_dict))
    def _check_shape(fmap_shape, dout_shape, filter_shape):
        """Check input shape."""
        fmap_n, fmap_c, _, _ = fmap_shape
        dout_n, dout_c, _, _ = dout_shape
        _, _, filter_c, filter_n = filter_shape

        if filter_n != 1:
            dict_args = {
                'errCode': 'E60000',
                'op_name': 'depthwise_conv2d_backprop_filter',
                'param_name': 'filter_n',
                'expected_value': '1',
                'input_value': str(filter_n)
            }
            raise RuntimeError(dict_args,
                               err_mana.get_error_message(dict_args))
        if fmap_c != dout_c:
            dict_args = {
                'errCode': 'E60002',
                'op_name': 'depthwise_conv2d_backprop_filter',
                'attr_name': 'channel value',
                'param1_name': 'fmap_c',
                'param2_name': 'dout_c',
                'param1_value': str(fmap_c),
                'param2_value': str(dout_c)
            }
            raise RuntimeError(dict_args,
                               err_mana.get_error_message(dict_args))
        if fmap_n != dout_n:
            dict_args = {
                'errCode': 'E60002',
                'op_name': 'depthwise_conv2d_backprop_filter',
                'attr_name': 'channel value',
                'param1_name': 'fmap_n',
                'param2_name': 'dout_n',
                'param1_value': str(fmap_n),
                'param2_value': str(dout_n)
            }
            raise RuntimeError(dict_args,
                               err_mana.get_error_message(dict_args))
        if fmap_c != filter_c:
            dict_args = {
                'errCode': 'E60002',
                'op_name': 'depthwise_conv2d_backprop_filter',
                'attr_name': 'channel value',
                'param1_name': 'fmap_c',
                'param2_name': 'filter_c',
                'param1_value': str(fmap_c),
                'param2_value': str(filter_c)
            }
            raise RuntimeError(dict_args,
                               err_mana.get_error_message(dict_args))
예제 #7
0
 def _align(input_x, input_y):
     if input_y == 0:
         dict_args = {}
         dict_args['errCode'] = "E60108"
         dict_args['reason'] = "Division by zero"
         raise RuntimeError(dict_args, err_man.get_error_message(dict_args))
     return (input_x + input_y - 1) // input_y * input_y
예제 #8
0
    def _check_attr_pads():
        """
        Validate the ``pads`` value from the enclosing scope.

        Raises RuntimeError:
        - E62501 when pads is a tuple/list whose length differs from
          PADDING_SHAPE_DIM;
        - E60021 when pads is a string not contained in PADDING_SUPPORT.
        """
        if isinstance(pads, (tuple, list)):
            if len(pads) != PADDING_SHAPE_DIM:
                args_dict = {'errCode': 'E62501', 'param_name': 'pads'}
                raise RuntimeError(args_dict,
                                   err_mana.get_error_message(args_dict))
        elif isinstance(pads, str):
            if pads not in PADDING_SUPPORT:
                args_dict = {
                    'errCode': 'E60021',
                    'expected_pad_mode': '[{}]'.format(PADDING_SUPPORT),
                    'actual_pad_mode': str(pads)
                }
                raise RuntimeError(args_dict,
                                   err_mana.get_error_message(args_dict))
예제 #9
0
    def _normalize_shape_ndchw(ori_shape,
                               ori_format,
                               format_list,
                               param_name='input_param'):
        """
        normalizing the shape to NDCHW
        """
        if ori_format not in format_list:
            args_dict = {
                'errCode': 'E60008',
                'param_name': param_name,
                'expected_format_list': ','.join(format_list),
                'format': ori_format
            }
            raise RuntimeError(args_dict,
                               err_mana.get_error_message(args_dict))

        n_index = ori_format.find('N')
        d_index = ori_format.find('D')
        c_index = ori_format.find('C')
        h_index = ori_format.find('H')
        w_index = ori_format.find('W')

        new_shape = [
            ori_shape[n_index], ori_shape[d_index], ori_shape[c_index],
            ori_shape[h_index], ori_shape[w_index]
        ]

        return new_shape
예제 #10
0
    def _check_attr_pads():
        """
        Validate the ``pads`` value from the enclosing scope.

        Raises RuntimeError:
        - E60107 when pads is a tuple/list whose length differs from
          CONV_BACKPROP_SHAPE_DIM;
        - E60021 when pads is a string not contained in PADDING_SUPPORT.
        """
        def _fail(dict_args):
            raise RuntimeError(dict_args, err_man.get_error_message(dict_args))

        if isinstance(pads, (tuple, list)):
            if len(pads) != CONV_BACKPROP_SHAPE_DIM:
                _fail({"errCode": "E60107", "param_name": "pads"})
        elif isinstance(pads, str):
            if pads not in PADDING_SUPPORT:
                _fail({
                    'errCode': "E60021",
                    'expected_pad_mode': str(PADDING_SUPPORT),
                    'actual_pad_mode': str(pads),
                })
예제 #11
0
 def _ceil(x_1, x_2):
     if x_2 == 0:
         dict_args = {}
         dict_args['errCode'] = "E60108"
         dict_args['reason'] = "Division by zero"
         raise RuntimeError(dict_args, err_man.get_error_message(dict_args))
     return (x_1 + x_2 - 1) // x_2
예제 #12
0
def get_filter_shape(ori_format_filters, ori_shape_filters):
    """
    Convert a filter shape to NCHW order.

    :param ori_format_filters: layout of ori_shape_filters; one of
        "NCHW", "NHWC" or "HWCN"
    :param ori_shape_filters: filter shape laid out as ori_format_filters
    :return: the shape in NCHW order. For "NCHW" the input object itself is
        returned; for the other layouts a new tuple is built.
    :raises RuntimeError: (E60004) for any other layout
    """
    if ori_format_filters == "NCHW":
        return ori_shape_filters

    # For each supported layout: source positions of N, C, H, W.
    axis_orders = (
        ("NHWC", (0, 3, 1, 2)),
        ("HWCN", (3, 2, 0, 1)),
    )
    for layout, order in axis_orders:
        if ori_format_filters == layout:
            return tuple(ori_shape_filters[pos] for pos in order)

    args_dict = {
        "errCode": "E60004",
        "param_name": "filter",
        "expected_format_list": "[NCHW,NHWC,HWCN]",
        "format": ori_format_filters
    }
    raise RuntimeError(args_dict, err_man.get_error_message(args_dict))
예제 #13
0
    def _check_pads():
        """
        Validate the ``pads`` value from the enclosing scope.

        Raises RuntimeError:
        - E60107 when pads is a tuple/list whose length differs from
          CONV_BACKPROP_SHAPE_DIM;
        - E60021 when pads is a string not contained in PADDING_SUPPORT.
        """
        def _fail(args_dict):
            raise RuntimeError(args_dict, err_man.get_error_message(args_dict))

        if isinstance(pads, (tuple, list)):
            if len(pads) != CONV_BACKPROP_SHAPE_DIM:
                _fail({"errCode": "E60107", "param_name": "pads"})
        elif isinstance(pads, str):
            if pads not in PADDING_SUPPORT:
                _fail({
                    "errCode": "E60021",
                    "expected_pad_mode": PADDING_SUPPORT,
                    "actual_pad_mode": pads
                })
예제 #14
0
    def _check_shape_relation():
        """
        Cross-check fmap, filter and dedy shapes and the pads.

        All operands are read from the enclosing scope. Raises RuntimeError
        on the first violated rule:
        - E60002: fmap/filter channel, dedy channel/filter batch, or
          fmap/dedy batch disagree;
        - E60014: dilated filter height exceeds padded fmap height (only when
          h_match_rule is truthy);
        - E60015: dilated filter width exceeds padded fmap width;
        - E60016 / E60017: a height / width pad is not smaller than the
          dilated filter height / width.
        """
        if fmap_channel != filter_channel:
            args_dict = {
                "errCode": "E60002",
                "attr_name": "shape",
                "param1_name": "Fmap's C",
                "param1_value": fmap_channel,
                "param2_name": "Filter's C",
                "param2_value": filter_channel
            }
            raise RuntimeError(args_dict, err_man.get_error_message(args_dict))
        if dedy_channel != filter_batch:
            args_dict = {
                "errCode": "E60002",
                "attr_name": "shape",
                "param1_name": "Dedy's C",
                "param1_value": dedy_channel,
                "param2_name": "Filter's N",
                "param2_value": filter_batch
            }
            raise RuntimeError(args_dict, err_man.get_error_message(args_dict))
        if fmap_batch != dedy_batch:
            args_dict = {
                "errCode": "E60002",
                "attr_name": "shape",
                "param1_name": "Fmap's N",
                "param1_value": fmap_batch,
                "param2_name": "Dedy's N",
                "param2_value": dedy_batch
            }
            raise RuntimeError(args_dict, err_man.get_error_message(args_dict))
        if filter_h_dilation > fmap_h_padding and h_match_rule:
            args_dict = {
                "errCode": "E60014",
                "h_of_x": fmap_h_padding,
                "h_of_filter": filter_h_dilation
            }
            raise RuntimeError(args_dict, err_man.get_error_message(args_dict))
        if filter_w_dilation > fmap_w_padding:
            # Width rule: use the width error code and keys (E60015,
            # w_of_x / w_of_filter), not the height ones.
            args_dict = {
                "errCode": "E60015",
                "w_of_x": fmap_w_padding,
                "w_of_filter": filter_w_dilation
            }
            raise RuntimeError(args_dict, err_man.get_error_message(args_dict))

        if pad_up >= filter_h_dilation or pad_down >= filter_h_dilation:
            args_dict = {
                "errCode": "E60016",
                "h_of_filter": filter_h_dilation,
                "h_of_pad": max(pad_up, pad_down)
            }
            raise RuntimeError(args_dict, err_man.get_error_message(args_dict))
        if pad_left >= filter_w_dilation or pad_right >= filter_w_dilation:
            args_dict = {
                "errCode": "E60017",
                "w_of_filter": filter_w_dilation,
                "w_of_pad": max(pad_left, pad_right)
            }
            raise RuntimeError(args_dict, err_man.get_error_message(args_dict))
예제 #15
0
 def _check_64bits_limitation(attr_name, attr_value, dtype=None):
     """
     Raise if attr_value scaled by the dtype's bit ratio exceeds DATA_SIZE_MAX.

     The ratio is looked up in BIT_RATIO_DICT under dtype, or under
     "float16" when dtype is falsy.

     Raises RuntimeError (E60020) when the scaled value is too large.
     """
     ratio_key = dtype if dtype else "float16"
     bit_ratio = BIT_RATIO_DICT.get(ratio_key)
     scaled_value = attr_value * bit_ratio
     if scaled_value > DATA_SIZE_MAX:
         args_dict = {'errCode': 'E60020', 'attr_name': attr_name}
         raise RuntimeError(args_dict,
                            err_mana.get_error_message(args_dict))
예제 #16
0
def align(input_x, input_y):
    """
    Round input_x up to a multiple of input_y.

    For a positive input_y this is the smallest multiple of input_y that is
    >= input_x.

    Raises RuntimeError (E62502) when input_y is 0.
    """
    if input_y == 0:
        args_dict = {
            'errCode': 'E62502',
            'first_operand': str(input_x),
            'second_operand': str(input_y)
        }
        raise RuntimeError(args_dict, err_mana.get_error_message(args_dict))
    blocks = (input_x + input_y - 1) // input_y
    return blocks * input_y
예제 #17
0
 def _ceil(x_1, x_2):
     if x_2 == 0:
         args_dict = {
             'errCode': 'E62502',
             'first_operand': str(x_1),
             'second_operand': str(x_2)
         }
         raise RuntimeError(args_dict,
                            err_mana.get_error_message(args_dict))
     return (x_1 + x_2 - 1) // x_2
예제 #18
0
    def _check_shape(fmap_shape, filter_shape):
        """
        Check the input shapes of depthwise_conv2d.

        - fmap_shape must have FEATURE_MAP_DIM entries (E60008). This is
          checked before unpacking, so a wrong rank yields this error
          instead of an unpacking ValueError.
        - fmap_shape[1] (in_c1) must equal filter_shape[0] (E60002).
        - filter_shape[3] (filter_k, the multiplier) must be 1 (E60000).

        fmap_data_format and FEATURE_MAP_DIM come from the enclosing scope.
        """
        # The shape of feature map must be 5HD; validate its rank first,
        # because the unpacking below would otherwise fail with ValueError.
        if len(fmap_shape) != FEATURE_MAP_DIM:
            dict_args = {
                'errCode': 'E60008',
                'op_name': 'depthwise_conv2d',
                'param_name': 'featuremap',
                'expected_format_list': '[{}]'.format('NC1HWC0'),
                'format': fmap_data_format
            }
            raise RuntimeError(dict_args,
                               err_mana.get_error_message(dict_args))

        _, in_c1, _, _, _ = fmap_shape
        filter_c1, _, _, filter_k, _, _ = filter_shape

        # check feature map shape of c, equal filter of c
        if in_c1 != filter_c1:
            dict_args = {
                'errCode': 'E60002',
                'op_name': 'depthwise_conv2d',
                'attr_name': 'channel',
                'param1_name': 'fmap',
                'param2_name': 'filter',
                'param1_value': str(in_c1),
                'param2_value': str(filter_c1)
            }
            raise RuntimeError(dict_args,
                               err_mana.get_error_message(dict_args))

        # check multiplier equal 1
        if filter_k != 1:
            dict_args = {
                'errCode': 'E60000',
                'op_name': 'depthwise_conv2d',
                'param_name': 'filter_k',
                'expected_value': '1',
                'input_value': str(filter_k)
            }
            raise RuntimeError(dict_args,
                               err_mana.get_error_message(dict_args))
예제 #19
0
 def _check_64bits_limitation(attr_name, attr_value, dtype=None):
     """
     Raise if attr_value scaled by the dtype's bit ratio exceeds DATA_SIZE_MAX.

     The ratio is looked up in BIT_RATIO_DICT under dtype, or under
     "float16" when dtype is falsy.

     Raises RuntimeError (E60020) when the scaled value is too large.
     """
     bit_ratio = BIT_RATIO_DICT.get(dtype if dtype else "float16")
     if attr_value * bit_ratio > DATA_SIZE_MAX:
         dict_args = {"errCode": "E60020", "attr_name": attr_name}
         raise RuntimeError(dict_args, err_man.get_error_message(dict_args))
예제 #20
0
 def _check_attr_range(attr_name, attr_value, attr_min, attr_max):
     if attr_value < attr_min or attr_value > attr_max:
         dict_args = {
             'errCode': 'E60011',
             'range': '[{},{}]'.format(attr_min, attr_max),
             'attr_name': attr_name,
             'value': str(attr_value)
         }
         raise RuntimeError(dict_args,
                            err_mana.get_error_message(dict_args))
예제 #21
0
    def _check_axis_hw():
        """
        Validate fmap / dedy / filter shape relations and the output size.

        Every operand is read from the enclosing scope. Raises RuntimeError
        on the first failed rule, in this order:
        - E62503: fmap batch != dedy batch;
        - E62504: dedy channel != filter batch;
        - E60010: fmap channel != filter channel;
        - E60015 / E60014: dilated filter width / height exceeds padded
          fmap width / height;
        - E60025 / E60024: the convolution output width / height computed
          from fmap, filter, pads and stride differs from dedy's.
        """
        def _fail(args_dict):
            raise RuntimeError(args_dict,
                               err_mana.get_error_message(args_dict))

        if fmap_batch != dedy_batch:
            _fail({'errCode': 'E62503',
                   'backprop_N': str(dedy_batch),
                   'forward_shape': str(fmap_batch)})
        if dedy_channel != filter_batch:
            _fail({'errCode': 'E62504',
                   'backprop_C': str(dedy_channel),
                   'forward_shape': str(filter_batch)})
        if fmap_channel != filter_channel:
            _fail({'errCode': 'E60010',
                   'channel_of_x': str(fmap_channel),
                   'channel_of_filter': str(filter_channel)})
        if filter_w_dilation > fmap_w_padding:
            _fail({'errCode': 'E60015',
                   'w_of_x': str(fmap_w_padding),
                   'w_of_filter': str(filter_w_dilation)})
        if filter_h_dilation > fmap_h_padding:
            _fail({'errCode': 'E60014',
                   'h_of_x': str(fmap_h_padding),
                   'h_of_filter': str(filter_h_dilation)})

        # Third : value check, Mainly required by the convolution rule
        out_w = (fmap_w - filter_w_dilation + pad_left + pad_right) \
            // stride_w + 1
        if out_w != dedy_w:
            _fail({'errCode': 'E60025'})
        out_h = (fmap_h - filter_h_dilation + pad_up + pad_down) \
            // stride_h + 1
        if out_h != dedy_h:
            _fail({'errCode': 'E60024'})
예제 #22
0
def config_dynamic_para(shape_dedy):
    """
    Classify which axes of the out_backprop shape are dynamic (-1).

    :param shape_dedy: indexable shape with at least four entries
    :return: "dynamic_hw" when axes 2 and 3 are -1 and axes 0 and 1 are not;
        "dynamic_batch" when axis 0 is -1 and no later axis is -1
    :raises RuntimeError: (E60108) for any other combination
    """
    hw_is_dynamic = shape_dedy[2] == shape_dedy[3] == -1
    if hw_is_dynamic and shape_dedy[0] != -1 and shape_dedy[1] != -1:
        return "dynamic_hw"

    only_batch_dynamic = shape_dedy[0] == -1 and -1 not in shape_dedy[1:]
    if only_batch_dynamic:
        return "dynamic_batch"

    args_dict = {
        "errCode": "E60108",
        "op_name": "out_backprop",
        "reason": "only support dynamic_hw and dynamic_batch now."
    }
    raise RuntimeError(args_dict, err_man.get_error_message(args_dict))
예제 #23
0
    def _min_l1_byte():
        """
        Raise if the minimal A and B tiles do not fit into L1.

        Reads C0, dedy_w, filter_h_dilation, stride_h and fmap_w from the
        enclosing scope; the capacity comes from get_soc_spec("L1_SIZE").
        The B row count grows by stride_h when dedy_w is not a multiple of
        C0. Raises RuntimeError (E60022) when the sum exceeds L1.
        """
        # Forth : L1 limitation, Mainly required by chip
        # (the trailing "* 2" is presumably bytes per element - TODO confirm)
        al1_min_byte = C0 * C0 * 2
        if dedy_w % C0 == 0:
            bl1_rows = filter_h_dilation
        else:
            bl1_rows = filter_h_dilation + stride_h
        bl1_min_byte = bl1_rows * fmap_w * C0 * 2

        if al1_min_byte + bl1_min_byte > get_soc_spec("L1_SIZE"):
            args_dict = {'errCode': 'E60022'}
            raise RuntimeError(args_dict,
                               err_mana.get_error_message(args_dict))
예제 #24
0
def align(x_1, x_2):
    """
    Round x_1 up to a multiple of x_2.

    :param x_1: value to align
    :param x_2: alignment; must be non-zero
    :return: for a positive x_2, the minimum y with y >= x_1 and
        y % x_2 == 0
    :raises RuntimeError: (E60114) when x_2 is 0
    """
    if x_2 == 0:
        detail = "x_1 = {}, x_2 = {}".format(x_1, x_2)
        args_dict = {
            "errCode": "E60114",
            "reason": "Division by zero",
            "value": detail
        }
        raise RuntimeError(args_dict, err_man.get_error_message(args_dict))
    block_count = (x_1 + x_2 - 1) // x_2
    return block_count * x_2
예제 #25
0
    def _check_axis_hw():
        """
        Validate fmap / dedy / filter shape relations and the output size.

        Every operand is read from the enclosing scope. Raises RuntimeError
        on the first failed rule, in this order:
        - E64002: fmap/dedy batch, dedy channel/filter batch, or fmap/filter
          channel disagree;
        - E60015 / E60014: dilated filter width / height exceeds padded
          fmap width / height;
        - E60025 / E60024: the convolution output width / height computed
          from fmap, filter, pads and stride differs from dedy's.
        """
        def _fail(dict_args):
            raise RuntimeError(dict_args, err_man.get_error_message(dict_args))

        def _fail_mismatch(name1, name2, value1, value2):
            _fail({
                'errCode': "E64002",
                'param1': name1,
                'param2': name2,
                'actual_value': "{}, {}".format(value1, value2),
            })

        if fmap_batch != dedy_batch:
            _fail_mismatch("Fmap's N", "Dedy's N", fmap_batch, dedy_batch)
        if dedy_channel != filter_batch:
            _fail_mismatch("Dedy's C", "Filter's N", dedy_channel,
                           filter_batch)
        if fmap_channel != filter_channel:
            _fail_mismatch("Fmap's C", "Filter's C", fmap_channel,
                           filter_channel)
        if filter_w_dilation > fmap_w_padding:
            _fail({"errCode": "E60015",
                   "w_of_x": str(fmap_w_padding),
                   "w_of_filter": str(filter_w_dilation)})
        if filter_h_dilation > fmap_h_padding:
            _fail({"errCode": "E60014",
                   "h_of_x": str(fmap_h_padding),
                   "h_of_filter": str(filter_h_dilation)})

        # Third : value check, Mainly required by the convolution rule
        expected_w = (fmap_w - filter_w_dilation + pad_left + pad_right) \
            // stride_w + 1
        if expected_w != dedy_w:
            _fail({"errCode": "E60025"})
        expected_h = (fmap_h - filter_h_dilation + pad_up + pad_down) \
            // stride_h + 1
        if expected_h != dedy_h:
            _fail({"errCode": "E60024"})
예제 #26
0
    def _min_l1_byte():
        """
        Raise if the minimal A and B tiles do not fit into L1.

        Reads C0, fmap_w, stride_w, stride_h, filter_w_dilation,
        filter_h_dilation and dedy_w from the enclosing scope. The B row
        width is fmap_w, or (C0 - 1) * stride_w + filter_w_dilation when
        _is_conv1d_situation() is truthy; the row count grows by stride_h
        when dedy_w is not a multiple of C0. The capacity comes from
        get_soc_spec("L1_SIZE"). Raises RuntimeError (E60026) when the sum
        exceeds it.
        """
        # Forth : L1 limitation, Mainly required by chip
        # (the trailing "* 2" is presumably bytes per element - TODO confirm)
        al1_min_byte = C0 * C0 * 2
        if _is_conv1d_situation():
            kl1_min = (C0 - 1) * stride_w + filter_w_dilation
        else:
            kl1_min = fmap_w
        if dedy_w % C0 == 0:
            bl1_rows = filter_h_dilation
        else:
            bl1_rows = filter_h_dilation + stride_h
        bl1_min_byte = bl1_rows * kl1_min * C0 * 2

        if al1_min_byte + bl1_min_byte > get_soc_spec("L1_SIZE"):
            dict_args = {"errCode": "E60026"}
            raise RuntimeError(dict_args, err_man.get_error_message(dict_args))
예제 #27
0
    def _check_inputs_rules():
        """
        Validate the ranks of the operator inputs and attributes.

        All values come from the enclosing scope. Raises RuntimeError:
        - E60107, naming the first offending parameter, when out_backprop,
          x or y is not a 4-element tuple/list, when strides does not have
          2 entries, or when filter_size / dilations do not have 4 entries;
        - E64002 when filter_size differs from the ori_shape of y.
        """
        def _raise_bad_length(param_name):
            dict_args = dict()
            dict_args["errCode"] = "E60107"
            dict_args["param_name"] = param_name
            raise RuntimeError(dict_args, err_man.get_error_message(dict_args))

        shaped_inputs = (
            (ori_shape_out_backprop, "out_backprop"),
            (ori_shape_x, "x"),
            (ori_shape_res, "y"),
        )
        for shape, param_name in shaped_inputs:
            if not isinstance(shape, (tuple, list)) or len(shape) != 4:
                _raise_bad_length(param_name)

        sized_attrs = (
            (strides, "strides", 2),
            (filter_size, "filter_size", 4),
            (dilations, "dilations", 4),
        )
        for value, param_name, expected_len in sized_attrs:
            if len(value) != expected_len:
                _raise_bad_length(param_name)

        if list(filter_size) != list(ori_shape_res):
            dict_args = {
                'errCode': "E64002",
                'param1': "filter_size",
                'param2': "ori_shape of y",
                'actual_value': "{}, {}".format(filter_size, ori_shape_res),
            }
            raise RuntimeError(dict_args, err_man.get_error_message(dict_args))
예제 #28
0
    def _check_l1_limitation():
        """
        Raise if the estimated A and B L1 footprints exceed the chip's L1.

        All operands come from the enclosing scope; the capacity is read
        with get_soc_spec("L1_SIZE"). Raises RuntimeError (E60022) when
        a_l1_size + b_l1_size is larger than it.
        """
        block_size = 16
        w_value = dedy_w * stride_w
        # Number of fmap rows A may need; depends on how fmap_w relates to
        # one 16-element block.
        if fmap_w > block_size:
            h_value_max = filter_h_dilation + 1
        elif block_size % fmap_w != 0:
            h_value_max = filter_h_dilation + block_size // fmap_w + 1
        else:
            h_value_max = filter_h_dilation + block_size // fmap_w - 1

        depth_factor = (filter_d_dilation - 2) // stride_d + 2
        a_l1_size = h_value_max * w_value * depth_factor * block_size * 2
        b_l1_size = (filter_h_dilation * filter_w_dilation
                     * filter_d_dilation * block_size * block_size * 2)
        if a_l1_size + b_l1_size > get_soc_spec("L1_SIZE"):
            dict_args = {'errCode': 'E60022'}
            raise RuntimeError(dict_args,
                               err_mana.get_error_message(dict_args))
예제 #29
0
def get_shape_dilation(data_format, dilations):
    """
    Get the dilations in NCHW order from dilations laid out as data_format.

    :param data_format: layout of dilations; "NCHW" or "NHWC"
    :param dilations: dilations in data_format order (indexed 0..3 for
        "NHWC")
    :return: dilations in NCHW order. For "NCHW" the input object itself is
        returned; for "NHWC" a new tuple is built.
    :raises RuntimeError: (E60004) for any other data_format
    """
    if data_format == "NCHW":
        shape_dilations = dilations
    elif data_format == "NHWC":
        # NHWC -> NCHW: take N, C, H, W from positions 0, 3, 1, 2.
        shape_dilations = (dilations[0], dilations[3], dilations[1],
                           dilations[2])
    else:
        dict_args = {}
        dict_args['errCode'] = "E60004"
        dict_args['param_name'] = "data_format"
        dict_args['expected_format_list'] = "[{}, {}]".format("NHWC", "NCHW")
        dict_args["format"] = data_format
        raise RuntimeError(dict_args, err_man.get_error_message(dict_args))
    return shape_dilations
예제 #30
0
def ceil(x_1, x_2):
    """
    Ceiling division, computed as (x_1 + x_2 - 1) // x_2.

    Parameters
    ----------
    x_1: int
        dividend
    x_2: int
        divisor; must be non-zero

    Returns
    -------
    result of the ceiling division

    Raises
    ------
    RuntimeError
        E62502 when x_2 is 0.
    """
    if x_2 == 0:
        dict_args = dict(errCode='E62502')
        dict_args['first_operand'] = str(x_1)
        dict_args['second_operand'] = str(x_2)
        raise RuntimeError(dict_args, err_mana.get_error_message(dict_args))
    padded = x_1 + x_2 - 1
    return padded // x_2