def scratch_cmsis_depthwise_conv_2d(mace_op, mace_net):
    """Return the scratch-buffer size in bytes for a CMSIS-NN depthwise conv2d.

    Sums the CMSIS-NN working buffer with the per-output-channel bias and
    quantization-parameter buffers.
    """
    channels = mace_op.output_shape[0].dims[3]
    in_dims = NetUtil.get_input_dims(mace_op, mace_net, 0)
    flt_dims = NetUtil.get_input_dims(mace_op, mace_net, 1)
    # CMSIS-NN working buffer: input_channels * kernel_w * kernel_h * 2 bytes
    nn_buffer_bytes = in_dims[3] * flt_dims[2] * flt_dims[1] * 2
    # Per output channel: 4 bytes of bias plus 2 bytes of quant params.
    return nn_buffer_bytes + channels * 4 + channels * 2
def scratch_xtensa_depthwise_conv_2d(mace_op, mace_net):
    """Return the scratch-buffer size in bytes for an Xtensa depthwise conv2d.

    Mirrors the NNLib xa_nn_conv2d_depthwise_getsize computation: a small
    aligned state struct plus a circular input buffer, scaled by 4 as in
    the reference port (presumably headroom — kept as-is), with 4 bytes of
    bias per output channel added on top. Only float32 is supported.
    """
    out_channels = mace_op.output_shape[0].dims[3]

    in_dims = NetUtil.get_input_dims(mace_op, mace_net, 0)
    in_height = in_dims[1]
    in_channels = in_dims[3]

    out_height = mace_op.output_shape[0].dims[1]

    flt_dims = NetUtil.get_input_dims(mace_op, mace_net, 1)
    channels_multiplier = flt_dims[0]
    kernel_h = flt_dims[1]
    kernel_w = flt_dims[2]

    stride_y = NetUtil.get_arg(mace_op, "strides").ints[1]
    pad_y = NetUtil.calc_padding(mace_op, mace_net)[1]

    # xa_nn_conv2d_depthwise_getsize
    data_type = NetUtil.get_arg(mace_op, "T").i
    if data_type == mace_pb2.DT_FLOAT:
        bytewidth = 4  # f32 circular-buffer element width
    else:
        mace_check(False, "Unsupported")

    state_size = aligned_size(24, ALIGNMENT)

    # The circular buffer must cover every input row any kernel window
    # touches, and at least the (top-padded) input itself.
    circ_height = max(kernel_h + (out_height - 1) * stride_y,
                      pad_y + in_height)

    # Depthwise output channels, padded up for the element width in use.
    circ_channels = aligned_size(in_channels * channels_multiplier,
                                 2 if bytewidth == 4 else 4)

    circ_buf_size = bytewidth * circ_height * circ_channels * kernel_w

    return (state_size + circ_buf_size) * 4 + out_channels * 4
def scratch_xtensa_conv_2d(mace_op, mace_net):
    """Return the scratch-buffer size in bytes for an Xtensa standard conv2d.

    Mirrors the NNLib xa_nn_conv2d_std_getsize computation: a 12-byte state
    struct with worst-case alignment slack, a channel-padded circular input
    buffer and one bus-width pad. The total is scaled by 4 as in the
    reference port (presumably headroom — kept as-is) and the per-channel
    bias buffer is added. Only float32 is supported.
    """
    out_channels = mace_op.output_shape[0].dims[3]

    in_dims = NetUtil.get_input_dims(mace_op, mace_net, 0)
    in_height = in_dims[1]
    in_channels = in_dims[3]

    out_height = mace_op.output_shape[0].dims[1]

    flt_dims = NetUtil.get_input_dims(mace_op, mace_net, 1)
    kernel_h = flt_dims[1]
    kernel_w = flt_dims[2]

    stride_y = NetUtil.get_arg(mace_op, "strides").ints[1]
    pad_y = NetUtil.calc_padding(mace_op, mace_net)[1]

    # xa_nn_conv2d_std_getsize
    elem_size = 0
    align = 0

    # 12-byte state struct plus worst-case alignment slack.
    mem_req = 12 + ALIGNMENT - 1

    data_type = NetUtil.get_arg(mace_op, "T").i
    if data_type == mace_pb2.DT_FLOAT:
        elem_size = 4            # f32 element width
        align = ALIGNMENT >> 2   # channel alignment, in elements
    else:
        mace_check(False, "Unsupported")

    # Bottom-padding rows still reached by the last kernel window.
    bottom_pad = max(
        0, kernel_h + (out_height - 1) * stride_y - (pad_y + in_height))
    padded_channels = aligned_size(in_channels, align)
    mem_req += ((pad_y + in_height + bottom_pad)
                * kernel_w * padded_channels * elem_size)
    mem_req += BUS_WIDTH

    return int(mem_req * 4 + out_channels * 4)
 def valid_data_type(self, mace_op, mace_net):
     """Return True when the op's declared data type matches this target.

     A float32 target also accepts bfloat16 ops; any other target requires
     an exact match with ``self._data_type``.
     """
     arg = NetUtil.get_arg(mace_op, MaceKeyword.mace_op_data_type_str)
     mace_check(arg is not None, "mace_op should has a explicit data type")
     if self._data_type != mace_pb2.DT_FLOAT:
         return arg.i == self._data_type
     return arg.i in (mace_pb2.DT_FLOAT, mace_pb2.DT_BFLOAT16)
    def valid_tag(self, mace_op, mace_net):
        """Return True when the op's kernel area selects this impl's tag.

        A kernel area of at least 4 maps to tag "s4"; anything smaller
        maps to the empty tag.
        """
        kernels = NetUtil.get_arg(mace_op, MaceKeyword.mace_kernel_str)
        mace_check(kernels is not None, "Get kernels failed.")
        area = kernels.ints[0] * kernels.ints[1]
        return self._tag == ("s4" if area >= 4 else "")
    def valid_tag(self, mace_op, mace_net):
        """Return True when the op's output size and filter batch select this tag.

        The product of the first three output dims and the filter batch are
        each clamped to 4. A product of at least 4 yields "kb<batch>s4"
        (the clamped size is always 4 there); otherwise the tag is empty.
        """
        out_dims = mace_op.output_shape[0].dims
        volume = out_dims[0] * out_dims[1] * out_dims[2]
        k_batch = NetUtil.get_input_dims(mace_op, mace_net, 1)[0]
        tag = ""
        if volume >= 4:
            # Clamped volume is exactly 4 whenever this branch is taken.
            tag = "kb%ss%s" % (min(k_batch, 4), 4)

        return tag == self._tag
def scratch_pooling(mace_op, mace_net):
    """Return the scratch-buffer size in bytes for a pooling op.

    Allocates 8 bytes per input channel (two 4-byte slots per channel,
    presumably accumulator plus count — matches the original 4 + 4).
    """
    in_dims = NetUtil.get_input_dims(mace_op, mace_net, 0)
    return in_dims[3] * 8