def compile_conv2d_NHWC_gemm_int8_arm(
    batch,
    in_channel,
    in_size,
    num_filter,
    kernel,
    stride,
    padding,
    dilation=1,
    add_bias=False,
    add_relu=False,
):
    """Compile-only check of the quantized NHWC GEMM conv2d schedule on an AArch64 target."""
    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
    padding_sum = pad_top + pad_left + pad_bottom + pad_right
    print(
        "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
        % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
    )

    in_height = in_width = in_size
    A = te.placeholder((batch, in_height, in_width, in_channel), name='A', dtype='int8')
    W = te.placeholder((kernel, kernel, in_channel, num_filter), name='W', dtype='int8')
    bias = te.placeholder((num_filter,), name='bias', dtype='int8')
    dtype = 'int32'
    device = "llvm --device arm_cpu --mtriple aarch64-linux-gnu"

    ctx = tvm.context(device, 0)
    if not ctx.exist:
        print("Skip because %s is not enabled" % device)
        return

    print("Compiling on arm AArch64 target: %s" % device)
    with tvm.target.create(device):
        assert is_aarch64_arm(), "AArch64 target not recognized"
        C = topi.arm_cpu.compute_conv2d_NHWC_quantized(
            A, W, (stride, stride), padding, (dilation, dilation), dtype
        )
        if add_bias:
            C = topi.add(C, bias)
        if add_relu:
            C = topi.nn.relu(C)
        s = topi.arm_cpu.schedule_conv2d_NHWC_quantized([C])

        if add_bias:
            tvm.build(
                s,
                [A, W, bias, C],
                device,
                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
            )
            func = tvm.build(
                s,
                [A, W, bias, C],
                device,
                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
            )
        else:
            func = tvm.build(
                s,
                [A, W, C],
                device,
                name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
                % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
            )
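# The variant below reworks the same helper to sweep several AArch64 target
# configurations, covering both the interleaved and the native quantized GEMM
# strategies (including the +v8.2a,+dotprod targets), and uses the
# tvm.device / tvm.testing.device_enabled / tvm.target.Target APIs in place of
# tvm.context / tvm.target.create.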
def compile_conv2d_NHWC_gemm_int8_arm(
    batch,
    in_channel,
    in_size,
    num_filter,
    kernel,
    stride,
    padding,
    dilation=1,
    add_bias=False,
    add_relu=False,
):
    """Compile-only check of the quantized NHWC GEMM conv2d schedules on several AArch64 targets."""
    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
    padding_sum = pad_top + pad_left + pad_bottom + pad_right
    print(
        "Workload: (%d, %d, %d, %d, %d, %d, %d, %d)"
        % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)
    )

    in_height = in_width = in_size
    A = te.placeholder((batch, in_height, in_width, in_channel), name="A", dtype="int8")
    W = te.placeholder((kernel, kernel, in_channel, num_filter), name="W", dtype="int8")
    bias = te.placeholder((num_filter,), name="bias", dtype="int8")
    dtype = "int32"
    devices = [
        (
            "llvm --device arm_cpu --mtriple aarch64-linux-gnu",
            topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved,
            topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved,
        ),
        (
            "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a,+dotprod",
            topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved,
            topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved,
        ),
        (
            "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a,+dotprod",
            topi.arm_cpu.compute_conv2d_NHWC_quantized_native,
            topi.arm_cpu.schedule_conv2d_NHWC_quantized_native,
        ),
        # TODO(giuseros) Need LLVM-11 in order to compile with +i8mm extension
        # (
        #     "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a,+i8mm",
        #     topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved,
        #     topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved,
        # ),
    ]

    for device_tuple in devices:
        target = device_tuple[0]
        compute = device_tuple[1]
        schedule = device_tuple[2]

        dev = tvm.device(target, 0)
        if not tvm.testing.device_enabled(target):
            print("Skip because %s is not enabled" % target)
            return
        print("Compiling on arm AArch64 target: %s" % target)

        with tvm.target.Target(target):
            assert is_aarch64_arm(), "AArch64 target not recognized"

            C = compute(A, W, (stride, stride), padding, (dilation, dilation), dtype)
            if add_bias:
                C = topi.add(C, bias)
            if add_relu:
                C = topi.nn.relu(C)
            s = schedule([C])

            if add_bias:
                tvm.build(
                    s,
                    [A, W, bias, C],
                    target,
                    name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
                    % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
                )
                func = tvm.build(
                    s,
                    [A, W, bias, C],
                    target,
                    name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
                    % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
                )
            else:
                func = tvm.build(
                    s,
                    [A, W, C],
                    target,
                    name="relu_%d_%d_%d_%d_%d_%d_%d_%d"
                    % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation),
                )
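# A minimal usage sketch, assuming the helper above is defined alongside its
# imports (tvm, te, topi, get_pad_tuple, is_aarch64_arm) in the enclosing test
# module. The workload below (1x56x56x64 int8 input, 64 3x3 filters, stride 1,
# padding 1) is an illustrative assumption, not a case taken from the original
# test suite.
if __name__ == "__main__":
    compile_conv2d_NHWC_gemm_int8_arm(
        batch=1,
        in_channel=64,
        in_size=56,
        num_filter=64,
        kernel=3,
        stride=1,
        padding=1,
        dilation=1,
        add_bias=True,
        add_relu=True,
    )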