# Imports assumed by this example (pre-0.7 TVM, where topi was a separate
# package); `is_aarch64_arm` is assumed to be a helper defined alongside the test.
import tvm
from tvm import te
import topi
from topi.nn.util import get_pad_tuple


def compile_conv2d_NHWC_gemm_int8_arm(batch, in_channel, in_size, num_filter, kernel, stride,
                                      padding, dilation=1, add_bias=False, add_relu=False):
    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
    padding_sum = pad_top + pad_left + pad_bottom + pad_right
    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter,
                                                          kernel, stride, padding_sum, dilation))

    in_height = in_width = in_size
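    # NHWC int8 activations and HWIO int8 weights feed the quantized GEMM path.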
    A = te.placeholder((batch, in_height, in_width, in_channel), name='A', dtype='int8')
    W = te.placeholder((kernel, kernel, in_channel, num_filter), name='W', dtype='int8')
    bias = te.placeholder((num_filter,), name='bias', dtype='int8')
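    # Accumulate the int8 products into int32 to avoid overflow in the reduction.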
    dtype = 'int32'
    device = "llvm --device arm_cpu --mtriple aarch64-linux-gnu"

    ctx = tvm.context(device, 0)
    if not ctx.exist:
        print("Skip because %s is not enabled" % device)
        return
    print("Compiling on arm AArch64 target: %s" % device)
    with tvm.target.create(device):
        assert is_aarch64_arm(), "AArch64 target not recognized"

        C = topi.arm_cpu.compute_conv2d_NHWC_quantized(A, W, (stride, stride), padding,
                                                       (dilation, dilation), dtype)
        if add_bias:
            C = topi.add(C, bias)
        if add_relu:
            C = topi.nn.relu(C)
        s = topi.arm_cpu.schedule_conv2d_NHWC_quantized([C])

    if add_bias:
        # Compiling successfully is the check here; the built function is not run.
        func = tvm.build(s, [A, W, bias, C], device,
                         name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch,
                                                                in_channel,
                                                                in_size,
                                                                num_filter,
                                                                kernel,
                                                                stride,
                                                                padding_sum,
                                                                dilation))
    else:
        func = tvm.build(s, [A, W, C], device,
                         name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch,
                                                                in_channel,
                                                                in_size,
                                                                num_filter,
                                                                kernel,
                                                                stride,
                                                                padding_sum,
                                                                dilation))
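
# A minimal usage sketch (the workload below is hypothetical; any conv2d shape
# works): a 1x56x56x64 int8 NHWC input, 64 3x3 filters, stride 1, padding 1.
compile_conv2d_NHWC_gemm_int8_arm(1, 64, 56, 64, 3, 1, 1)
compile_conv2d_NHWC_gemm_int8_arm(1, 64, 56, 64, 3, 1, 1, add_bias=True, add_relu=True)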
Example n. 2
# Imports assumed by this example (post-0.7 TVM, with topi under tvm);
# `is_aarch64_arm` is again assumed to be a helper defined alongside the test.
import tvm
import tvm.testing
from tvm import te, topi
from tvm.topi.nn.utils import get_pad_tuple


def compile_conv2d_NHWC_gemm_int8_arm(
    batch,
    in_channel,
    in_size,
    num_filter,
    kernel,
    stride,
    padding,
    dilation=1,
    add_bias=False,
    add_relu=False,
):
    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(
        padding, (kernel, kernel))
    padding_sum = pad_top + pad_left + pad_bottom + pad_right
    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" %
          (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum,
           dilation))

    in_height = in_width = in_size
    A = te.placeholder((batch, in_height, in_width, in_channel),
                       name="A",
                       dtype="int8")
    W = te.placeholder((kernel, kernel, in_channel, num_filter),
                       name="W",
                       dtype="int8")
    bias = te.placeholder((num_filter,), name="bias", dtype="int8")
    dtype = "int32"
    devices = [
        (
            "llvm --device arm_cpu --mtriple aarch64-linux-gnu",
            topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved,
            topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved,
        ),
        (
            "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a,+dotprod",
            topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved,
            topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved,
        ),
        (
            "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a,+dotprod",
            topi.arm_cpu.compute_conv2d_NHWC_quantized_native,
            topi.arm_cpu.schedule_conv2d_NHWC_quantized_native,
        ),
        # TODO(giuseros) Need LLVM-11 in order to compile with +i8mm extension
        # (
        #   "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a,+i8mm",
        #   topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved,
        #   topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved,
        # ),
    ]

    for target, compute, schedule in devices:
        if not tvm.testing.device_enabled(target):
            print("Skip because %s is not enabled" % target)
            return
        print("Compiling on arm AArch64 target: %s" % target)
        with tvm.target.Target(target):
            assert is_aarch64_arm(), "AArch64 target not recognized"

            C = compute(A, W, (stride, stride), padding, (dilation, dilation),
                        dtype)
            if add_bias:
                C = topi.add(C, bias)
            if add_relu:
                C = topi.nn.relu(C)
            s = schedule([C])

        if add_bias:
            # Compiling successfully is the check here; the built function is
            # not run.
            func = tvm.build(
                s,
                [A, W, bias, C],
                target,
                name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
                (batch, in_channel, in_size, num_filter, kernel, stride,
                 padding_sum, dilation),
            )
        else:
            func = tvm.build(
                s,
                [A, W, C],
                target,
                name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
                (batch, in_channel, in_size, num_filter, kernel, stride,
                 padding_sum, dilation),
            )
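
# As above, a hypothetical smoke-test invocation; each enabled target/strategy
# pair in `devices` is compiled in turn.
compile_conv2d_NHWC_gemm_int8_arm(1, 32, 28, 32, 3, 1, 1, add_bias=True)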