def test_dense_cuda_int8(
    target,
    dev,
    batch_size,
    in_dim,
    out_dim,
    use_bias,
    dense_ref_data,
    in_dtype,
    out_dtype,
):
    """Run ``test_dense`` restricted to the CUDA int8 dense implementation.

    Every parameter is forwarded unchanged, and in the same order, to
    ``test_dense``. The only thing this wrapper adds is the
    ``implementations`` keyword, pinned to the single pair
    ``(topi.cuda.dense_int8, topi.cuda.schedule_dense_int8)``, and the
    surrounding ``Int8Fallback()`` context.
    """
    cuda_int8_impl = (topi.cuda.dense_int8, topi.cuda.schedule_dense_int8)
    forwarded = (
        target,
        dev,
        batch_size,
        in_dim,
        out_dim,
        use_bias,
        dense_ref_data,
        in_dtype,
        out_dtype,
    )

    with Int8Fallback():
        test_dense(*forwarded, implementations=[cuda_int8_impl])
def test_batch_matmul_int8():
    """Call ``verify_batch_matmul_int8`` on a fixed list of workloads.

    Each tuple below is unpacked positionally into
    ``verify_batch_matmul_int8``; all calls run, in order, inside a single
    ``Int8Fallback()`` context.
    """
    # NOTE(review): the meaning of each position is defined by
    # verify_batch_matmul_int8, which is not visible here.
    workloads = (
        (1, 1, 2, 3, 1),
        (1, 1, 16, 24, 32),
        (5, 5, 24, 16, 32),
        (30, 30, 16, 20, 32),
        (1, 5, 16, 16, 32),
        (5, 1, 16, 16, 32),
    )

    with Int8Fallback():
        for workload in workloads:
            verify_batch_matmul_int8(*workload)
def test_conv2d_nhwc():
    """Call ``verify_conv2d_NHWC_gemm_int8`` on a fixed list of workloads.

    Each entry is an ``(args, kwargs)`` pair that is unpacked into
    ``verify_conv2d_NHWC_gemm_int8``. After the whole list has run, one
    ``compile_conv2d_NHWC_gemm_int8_arm`` call is made. Everything executes
    inside a single ``Int8Fallback()`` context.
    """
    # Subset of inception v3 expanded (dilation > 1, batch > 1, 'VALID' padding)
    gemm_workloads = [
        ((1, 3, 299, 32, 3, 2, 'SAME'), {}),
        ((1, 32, 149, 32, 3, 1, 'SAME'), {'dilation': 2}),
        ((4, 32, 147, 64, 3, 1, 'SAME'), {}),
        ((1, 64, 73, 80, 1, 1, 'SAME'), {}),
        ((1, 80, 73, 192, 3, 1, 'SAME'), {}),
        ((1, 192, 35, 48, 1, 1, 'SAME'), {}),
        ((1, 192, 35, 64, 1, 1, 'VALID'), {}),
        ((1, 192, 35, 32, 1, 1, 'SAME'), {}),
        ((1, 48, 35, 64, 5, 1, 'SAME'), {}),
        ((1, 96, 35, 96, 3, 1, 'SAME'), {}),
        ((1, 256, 35, 48, 1, 1, 'SAME'), {}),
        ((1, 256, 35, 64, 1, 1, 'SAME'), {}),
        ((1, 288, 35, 64, 1, 1, 'SAME'), {}),
        ((1, 288, 35, 48, 1, 1, 'SAME'), {}),
        ((1, 96, 35, 96, 3, 2, 'SAME'), {}),
        ((1, 128, 17, 192, 7, 1, 'SAME'), {'dilation': 2}),
        ((1, 160, 17, 160, 7, 1, 'SAME'), {}),
        ((1, 160, 17, 192, 1, 1, 'VALID'), {}),
        ((1, 192, 17, 192, 1, 1, 'SAME'), {}),
        ((1, 768, 5, 128, 1, 1, 'SAME'), {}),
        ((1, 192, 17, 320, 3, 2, 'SAME'), {}),
        ((1, 192, 17, 192, 3, 2, 'SAME'), {}),
        ((1, 1280, 8, 192, 1, 1, 'SAME'), {}),
        ((1, 1280, 8, 384, 1, 1, 'SAME'), {}),
        ((1, 1280, 8, 320, 1, 1, 'SAME'), {}),
        ((1, 1280, 8, 448, 1, 1, 'SAME'), {}),
        ((1, 384, 8, 384, 1, 1, 'SAME'), {}),
        ((1, 384, 8, 384, 3, 1, 'SAME'), {}),
        ((1, 448, 8, 384, 3, 1, 'VALID'), {}),
        ((1, 2048, 8, 320, 1, 1, 'SAME'), {}),
        ((1, 2048, 8, 448, 1, 1, 'SAME'), {'add_bias': True, 'add_relu': True}),
        ((1, 2048, 8, 192, 1, 1, 'SAME'), {'add_bias': True}),
    ]

    with Int8Fallback():
        for positional, keywords in gemm_workloads:
            verify_conv2d_NHWC_gemm_int8(*positional, **keywords)

        # Let's also verify that it compiles fine on AArch64 targets
        compile_conv2d_NHWC_gemm_int8_arm(1, 3, 299, 32, 3, 2, 'SAME')
def test_group_conv2d_NCHWc_int8():
    """Call ``verify_group_conv2d_NCHWc_int8`` on a fixed list of workloads.

    Each entry is an ``(args, kwargs)`` pair unpacked into
    ``verify_group_conv2d_NCHWc_int8``; all calls run, in order, inside a
    single ``Int8Fallback()`` context.
    """
    # NOTE(review): the meaning of each position is defined by
    # verify_group_conv2d_NCHWc_int8, which is not visible here.
    group_workloads = [
        # ResNeXt-50 workload
        ((1, 128, 56, 128, 3, 1, 1, 1, 32), {}),
        ((1, 256, 56, 256, 3, 2, 1, 1, 32), {}),
        ((1, 256, 28, 256, 3, 1, 1, 1, 32), {}),
        ((1, 512, 28, 512, 3, 2, 1, 1, 32), {}),
        ((1, 512, 14, 512, 3, 1, 1, 1, 32), {}),
        ((1, 1024, 14, 1024, 3, 2, 1, 1, 32), {}),
        ((1, 1024, 7, 1024, 3, 1, 1, 1, 32), {}),
        # bias, relu
        ((1, 128, 56, 128, 3, 1, 1, 1, 32), {"add_relu": True}),
        ((1, 128, 56, 128, 3, 1, 1, 1, 32), {"add_bias": True}),
        ((1, 128, 56, 128, 3, 1, 1, 1, 32), {"add_relu": True, "add_bias": True}),
        # dilation
        ((1, 128, 56, 128, 3, 1, 1, 2, 32), {}),
        # batch size
        ((2, 128, 56, 128, 3, 1, 1, 1, 32), {}),
        ((9, 128, 56, 128, 3, 1, 1, 1, 32), {}),
    ]

    with Int8Fallback():
        for positional, keywords in group_workloads:
            verify_group_conv2d_NCHWc_int8(*positional, **keywords)
def test_dense_int8():
    """Call ``verify_dense_int8(2, 1024, 1000, ...)`` with and without bias.

    The ``use_bias=True`` case runs first, then ``use_bias=False``; both run
    inside a single ``Int8Fallback()`` context.
    """
    with Int8Fallback():
        for bias_enabled in (True, False):
            verify_dense_int8(2, 1024, 1000, use_bias=bias_enabled)
def test_conv2d_transpose_NCHWc_int8():
    """Call ``verify_conv2d_transpose_NCHWc_int8`` on two fixed workloads.

    Each tuple is unpacked positionally into the verifier; both calls run,
    in order, inside a single ``Int8Fallback()`` context.
    """
    # NOTE(review): the two workloads differ only in their last two
    # positions; what those positions mean is defined by the verifier,
    # which is not visible here.
    transpose_workloads = (
        (1, 32, 32, 128, 5, 1, 0),
        (1, 32, 32, 128, 5, 2, 1),
    )

    with Int8Fallback():
        for workload in transpose_workloads:
            verify_conv2d_transpose_NCHWc_int8(*workload)
def test_conv2d_nchw():
    """Call ``verify_conv2d_NCHWc_int8`` on a fixed list of workloads.

    Each entry is an ``(args, kwargs)`` pair unpacked into
    ``verify_conv2d_NCHWc_int8``; all calls run, in order, inside a single
    ``Int8Fallback()`` context.
    """
    # NOTE(review): another ``test_conv2d_nchw`` is defined later in this
    # view; if both live in the same module the later definition replaces
    # this one -- confirm which is intended.
    # NOTE(review): the meaning of each position is defined by
    # verify_conv2d_NCHWc_int8, which is not visible here.
    nchwc_workloads = [
        # ResNet18 workloads where channels in / out are multiple of oc_block_factor
        ((1, 64, 56, 64, 3, 1, 1), {}),
        ((1, 64, 56, 64, 1, 1, 0), {}),
        ((1, 64, 56, 128, 3, 2, 1), {}),
        ((1, 64, 56, 128, 1, 2, 0), {}),
        ((1, 128, 28, 128, 3, 1, 1), {}),
        ((1, 128, 28, 256, 3, 2, 1), {}),
        ((1, 128, 28, 256, 1, 2, 0), {}),
        ((1, 256, 14, 256, 3, 1, 1), {}),
        ((1, 256, 14, 512, 3, 2, 1), {}),
        ((1, 256, 14, 512, 1, 2, 0), {}),
        ((1, 512, 7, 512, 3, 1, 1), {}),
        # bias, relu
        ((1, 64, 56, 64, 3, 1, 1), {"add_relu": True}),
        ((1, 64, 56, 64, 3, 1, 1), {"add_bias": True}),
        ((1, 64, 56, 64, 3, 1, 1), {"add_bias": True, "add_relu": True}),
        # dilation = 2
        ((1, 64, 56, 64, 3, 1, 1), {"dilation": 2}),
        # batch size
        ((4, 64, 56, 64, 3, 1, 1), {}),
        ((9, 64, 56, 64, 3, 1, 1), {}),
        # weird workloads
        ((4, 4, 4, 4, 4, 4, 4), {}),
        # inception v3 workloads where channels in / out are multiple of oc_block_factor
        ((1, 32, 149, 32, 3, 1, 0), {}),
        ((1, 32, 147, 64, 3, 1, 1), {}),
        ((1, 64, 73, 80, 1, 1, 0), {}),
        ((1, 80, 73, 192, 3, 1, 0), {}),
        ((1, 192, 35, 64, 1, 1, 0), {}),
        ((1, 192, 35, 48, 1, 1, 0), {}),
        ((1, 48, 35, 64, 5, 1, 2), {}),
        ((1, 64, 35, 96, 3, 1, 1), {}),
        ((1, 96, 35, 96, 3, 1, 1), {}),
        ((1, 192, 35, 32, 1, 1, 0), {}),
        ((1, 256, 35, 64, 1, 1, 0), {}),
        ((1, 256, 35, 48, 1, 1, 0), {}),
        ((1, 288, 35, 64, 1, 1, 0), {}),
        ((1, 288, 35, 48, 1, 1, 0), {}),
        ((1, 288, 35, 384, 3, 2, 0), {}),
        ((1, 96, 35, 96, 3, 2, 0), {}),
        ((1, 768, 17, 192, 1, 1, 0), {}),
        ((1, 768, 17, 128, 1, 1, 0), {}),
        ((1, 128, 17, 128, 1, 1, 0), {}),
        ((1, 128, 17, 192, 7, 1, 3), {}),
        ((1, 128, 17, 128, 7, 1, 3), {}),
        ((1, 128, 17, 192, 1, 1, 0), {}),
        ((1, 768, 17, 160, 1, 1, 0), {}),
        ((1, 160, 17, 160, 1, 1, 0), {}),
        ((1, 160, 17, 192, 7, 1, 3), {}),
        ((1, 160, 17, 160, 7, 1, 3), {}),
        ((1, 160, 17, 192, 1, 1, 0), {}),
        ((1, 192, 17, 192, 1, 1, 0), {}),
        ((1, 192, 17, 192, 7, 1, 3), {}),
        ((1, 192, 17, 320, 3, 2, 0), {}),
        ((1, 192, 17, 192, 3, 2, 0), {}),
        ((1, 1280, 8, 320, 1, 1, 0), {}),
        ((1, 1280, 8, 384, 1, 1, 0), {}),
        ((1, 384, 8, 384, 1, 1, 0), {}),
        ((1, 384, 8, 384, 3, 1, 1), {}),
        ((1, 1280, 8, 448, 1, 1, 0), {}),
        ((1, 448, 8, 384, 3, 1, 1), {}),
        ((1, 1280, 8, 192, 1, 1, 0), {}),
        ((1, 2048, 8, 320, 1, 1, 0), {}),
        ((1, 2048, 8, 384, 1, 1, 0), {}),
        ((1, 2048, 8, 448, 1, 1, 0), {}),
        ((1, 2048, 8, 192, 1, 1, 0), {}),
        ((1, 1024, 19, 84, 3, 1, 1), {}),
        # batch > 1
        ((7, 32, 149, 32, 3, 1, 0), {}),
        ((8, 32, 149, 32, 3, 1, 0), {}),
        ((32, 32, 149, 32, 3, 1, 0), {}),
        # Asymmetric padding
        ((1, 32, 35, 64, 7, 2, (0, 0, 1, 1)), {}),
        ((1, 64, 8, 128, 3, 1, (3, 3, 2, 2)), {}),
        ((1, 64, 8, 64, 1, 1, (1, 2, 2, 1)), {}),
        ((1, 64, 17, 192, 1, 1, (1, 2)), {}),
        ((1, 64, 8, 64, 3, 1, (3, 1)), {}),
        ((1, 128, 8, 384, 3, 1, (0, 2)), {}),
        ((1, 64, 8, 64, 1, 1, "VALID"), {}),
        ((1, 388, 8, 64, 3, 1, "VALID"), {}),
        ((1, 512, 19, 64, 1, 1, "SAME"), {}),
        ((1, 64, 16, 32, 2, 1, "SAME"), {}),
        ((1, 64, 8, 64, 3, 1, (1, 2, 2, 1)), {"add_relu": True}),
        ((1, 64, 8, 64, 5, 2, (1, 3)), {"add_bias": True}),
        ((1, 64, 56, 64, 3, 1, "VALID"), {"add_bias": True, "add_relu": True}),
        ((1, 64, 56, 64, 24, 1, "SAME"), {"add_bias": True, "add_relu": True}),
    ]

    with Int8Fallback():
        for positional, keywords in nchwc_workloads:
            verify_conv2d_NCHWc_int8(*positional, **keywords)
def test_conv2d_nchw():
    """Call the NCHWc and NCHW int8 conv2d verifiers on fixed workloads.

    Each entry in the two tables is an ``(args, kwargs)`` pair. All
    ``verify_conv2d_NCHWc_int8`` calls run first, followed by the
    ``verify_conv2d_nchw_int8`` calls; everything executes inside a single
    ``Int8Fallback()`` context.
    """
    # NOTE(review): an earlier ``test_conv2d_nchw`` is also defined in this
    # view; if both live in the same module this definition replaces the
    # earlier one -- confirm which is intended.
    # NOTE(review): the meaning of each position is defined by the verify_*
    # helpers, which are not visible here.
    nchwc_workloads = [
        # ResNet18 workloads where channels in / out are multiple of oc_block_factor
        ((1, 64, 56, 64, 3, 1, 1), {}),
        ((1, 64, 56, 64, 1, 1, 0), {}),
        ((1, 64, 56, 128, 3, 2, 1), {}),
        ((1, 64, 56, 128, 1, 2, 0), {}),
        ((1, 128, 28, 128, 3, 1, 1), {}),
        ((1, 128, 28, 256, 3, 2, 1), {}),
        ((1, 128, 28, 256, 1, 2, 0), {}),
        ((1, 256, 14, 256, 3, 1, 1), {}),
        ((1, 256, 14, 512, 3, 2, 1), {}),
        ((1, 256, 14, 512, 1, 2, 0), {}),
        ((1, 512, 7, 512, 3, 1, 1), {}),
        # bias, relu
        ((1, 64, 56, 64, 3, 1, 1), {"add_relu": True}),
        ((1, 64, 56, 64, 3, 1, 1), {"add_bias": True}),
        ((1, 64, 56, 64, 3, 1, 1), {"add_bias": True, "add_relu": True}),
        # dilation = 2
        ((1, 64, 56, 64, 3, 1, 1), {"dilation": 2}),
        # batch size
        ((4, 64, 56, 64, 3, 1, 1), {}),
        ((9, 64, 56, 64, 3, 1, 1), {}),
        # weird workloads
        ((4, 4, 4, 4, 4, 4, 4), {}),
        # inception v3 workloads where channels in / out are multiple of oc_block_factor
        ((1, 32, 149, 32, 3, 1, 0), {}),
        ((1, 32, 147, 64, 3, 1, 1), {}),
        ((1, 64, 73, 80, 1, 1, 0), {}),
        ((1, 80, 73, 192, 3, 1, 0), {}),
        ((1, 192, 35, 64, 1, 1, 0), {}),
        ((1, 192, 35, 48, 1, 1, 0), {}),
        ((1, 48, 35, 64, 5, 1, 2), {}),
        ((1, 64, 35, 96, 3, 1, 1), {}),
        ((1, 96, 35, 96, 3, 1, 1), {}),
        ((1, 192, 35, 32, 1, 1, 0), {}),
        ((1, 256, 35, 64, 1, 1, 0), {}),
        ((1, 256, 35, 48, 1, 1, 0), {}),
        ((1, 288, 35, 64, 1, 1, 0), {}),
        ((1, 288, 35, 48, 1, 1, 0), {}),
        ((1, 288, 35, 384, 3, 2, 0), {}),
        ((1, 96, 35, 96, 3, 2, 0), {}),
        ((1, 768, 17, 192, 1, 1, 0), {}),
        ((1, 768, 17, 128, 1, 1, 0), {}),
        ((1, 128, 17, 128, 1, 1, 0), {}),
        ((1, 128, 17, 192, 7, 1, 3), {}),
        ((1, 128, 17, 128, 7, 1, 3), {}),
        ((1, 128, 17, 192, 1, 1, 0), {}),
        ((1, 768, 17, 160, 1, 1, 0), {}),
        ((1, 160, 17, 160, 1, 1, 0), {}),
        ((1, 160, 17, 192, 7, 1, 3), {}),
        ((1, 160, 17, 160, 7, 1, 3), {}),
        ((1, 160, 17, 192, 1, 1, 0), {}),
        ((1, 192, 17, 192, 1, 1, 0), {}),
        ((1, 192, 17, 192, 7, 1, 3), {}),
        ((1, 192, 17, 320, 3, 2, 0), {}),
        ((1, 192, 17, 192, 3, 2, 0), {}),
        ((1, 1280, 8, 320, 1, 1, 0), {}),
        ((1, 1280, 8, 384, 1, 1, 0), {}),
        ((1, 384, 8, 384, 1, 1, 0), {}),
        ((1, 384, 8, 384, 3, 1, 1), {}),
        ((1, 1280, 8, 448, 1, 1, 0), {}),
        ((1, 448, 8, 384, 3, 1, 1), {}),
        ((1, 1280, 8, 192, 1, 1, 0), {}),
        ((1, 2048, 8, 320, 1, 1, 0), {}),
        ((1, 2048, 8, 384, 1, 1, 0), {}),
        ((1, 2048, 8, 448, 1, 1, 0), {}),
        ((1, 2048, 8, 192, 1, 1, 0), {}),
        ((1, 1024, 19, 84, 3, 1, 1), {}),
        # batch > 1
        ((7, 32, 149, 32, 3, 1, 0), {}),
        ((8, 32, 149, 32, 3, 1, 0), {}),
        ((32, 32, 149, 32, 3, 1, 0), {}),
        # Asymmetric padding
        ((1, 32, 35, 64, 7, 2, (0, 0, 1, 1)), {}),
        ((1, 64, 8, 128, 3, 1, (3, 3, 2, 2)), {}),
        ((1, 64, 8, 64, 1, 1, (1, 2, 2, 1)), {}),
        ((1, 64, 17, 192, 1, 1, (1, 2)), {}),
        ((1, 64, 8, 64, 3, 1, (3, 1)), {}),
        ((1, 128, 8, 384, 3, 1, (0, 2)), {}),
        ((1, 64, 8, 64, 1, 1, "VALID"), {}),
        ((1, 388, 8, 64, 3, 1, "VALID"), {}),
        ((1, 512, 19, 64, 1, 1, "SAME"), {}),
        ((1, 64, 16, 32, 2, 1, "SAME"), {}),
        ((1, 64, 8, 64, 3, 1, (1, 2, 2, 1)), {"add_relu": True}),
        ((1, 64, 8, 64, 5, 2, (1, 3)), {"add_bias": True}),
        ((1, 64, 56, 64, 3, 1, "VALID"), {"add_bias": True, "add_relu": True}),
        ((1, 64, 56, 64, 24, 1, "SAME"), {"add_bias": True, "add_relu": True}),
    ]

    # Conv2d NCHW int8 schedule testing. Internally, it uses NCHWc schedule. So, just
    # performing basic testing - one test for all different scenarios - batch, dilation etc..
    nchw_workloads = [
        ((1, 64, 56, 64, 3, 1, 1), {}),
        ((1, 64, 56, 64, 3, 1, 1), {"add_relu": True}),
        ((1, 64, 56, 64, 3, 1, 1), {"dilation": 2}),
        ((9, 64, 56, 64, 3, 1, 1), {}),
        ((4, 4, 4, 4, 4, 4, 4), {}),
        ((1, 32, 149, 32, 3, 1, 0), {}),
        ((7, 32, 149, 32, 3, 1, 0), {}),
        ((1, 32, 35, 64, 7, 2, (0, 0, 1, 1)), {}),
    ]

    with Int8Fallback():
        for positional, keywords in nchwc_workloads:
            verify_conv2d_NCHWc_int8(*positional, **keywords)

        for positional, keywords in nchw_workloads:
            verify_conv2d_nchw_int8(*positional, **keywords)