def test_quantize_dynamic():
    """Check TVM against PyTorch on a dynamically quantized linear layer.

    A single wrapped ``nn.Linear(16, 32)`` is dynamically quantized to qint8 with
    both the per-channel and the default dynamic qconfig, traced on 2-D and 3-D
    random inputs, and executed by PyTorch and by the TVM runtime.
    """

    class LinearWrapper(nn.Module):
        """Parent module around one ``nn.Linear``.

        quantize_dynamic only works correctly when the layer is wrapped like this.
        """

        def __init__(self, in_dim, hidden_dim):
            super().__init__()
            self.linear = nn.Linear(in_dim, hidden_dim)

        def forward(self, inp):
            return self.linear(inp)

    torch.manual_seed(0)
    mod = LinearWrapper(16, 32)
    input_name = "input"

    # Every (qconfig, input shape) combination, qconfig varying slowest.
    cases = [
        (qconfig, ishape)
        for qconfig in (
            torch.quantization.per_channel_dynamic_qconfig,
            torch.quantization.default_dynamic_qconfig,
        )
        for ishape in ((16, 16), (10, 16, 16))
    ]

    for qconfig, ishape in cases:
        qmod = torch.quantization.quantize_dynamic(
            mod, qconfig_spec={nn.Linear: qconfig}, dtype=torch.qint8
        )

        inp = torch.randn(*ishape)
        script_module = torch.jit.trace(qmod, inp).eval()

        with torch.no_grad():
            pt_result = script_module(inp.clone()).numpy()

        runtime = get_tvm_runtime(script_module, input_name, inp.shape)
        runtime.set_input(input_name, inp.numpy().copy())
        runtime.run()
        tvm_result = runtime.get_output(0).numpy()

        # PyTorch 1.4 and 1.5 showed an odd accuracy problem: even with the seed
        # fixed, the same PyTorch version could give slightly different outputs
        # depending on the environment. Results from v1.6 on look reliable, and
        # TVM's outputs are always the same, so the comparison is only made for
        # versions newer than 1.5.1.
        if not is_version_greater_than("1.5.1"):
            continue
        tvm.testing.assert_allclose(tvm_result, pt_result, rtol=1e-4, atol=1e-4)
def test_quantized_modules():
    """Run small quantized modules through PyTorch and TVM and print how they differ.

    Each case is a ``(name, input shape, module, per_channel)`` tuple. The module
    is put in eval mode, passed to ``quantize_model`` together with a random
    input (the return value is not used, so the module is presumably converted
    in place), traced, and executed by both PyTorch and the TVM runtime. For
    every case the maximum and mean absolute difference and the fraction of
    exactly identical output elements are printed. The raw outputs are not
    asserted on.
    """
    imagenet_ishape = (1, 3, 224, 224)

    qmodules = [
        ("relu", imagenet_ishape, ReLU(), False),
        ("upsample bilinear", (1, 3, 64, 64), UpsamplingBilinear(), False),
        ("avgpool", imagenet_ishape, AvgPool2d(), False),
    ]

    # Cases that exist in both a per-tensor and a per-channel flavour.
    for per_channel in [False, True]:
        postfix = ", per_channel" if per_channel else ""
        qmodules += [
            ("conv_bn" + postfix, imagenet_ishape, ConvBn(), per_channel),
            ("conv_bn_relu" + postfix, imagenet_ishape, ConvBn(with_relu=True), per_channel),
            ("linear" + postfix, (16, 16), Linear(), per_channel),
            ("linear_relu" + postfix, (16, 16), Linear(with_relu=True), per_channel),
        ]

    # These cases carry a fixed per_channel flag and do not depend on the loop
    # variable above, so they are registered once instead of once per iteration
    # (which would run every one of them twice).
    qmodules += [
        ("conv_transpose", imagenet_ishape, ConvTranspose(), False),
        ("hsigmoid", imagenet_ishape, Hsigmoid(add_stub=True), False),
        ("hswish", imagenet_ishape, Hswish(add_stub=True), False),
        ("semodule", (1, 16, 64, 64), SqueezeExcite(16, add_stub=True), False),
        ("semodule, per_channel", (1, 16, 64, 64), SqueezeExcite(16, add_stub=True), True),
        ("mul_scalar negative", imagenet_ishape, MulScalarNegative(), False),
    ]

    for module_name, ishape, raw_module, per_channel in qmodules:
        raw_module.eval()
        inp = torch.rand(ishape)

        # quantized conv_transpose2d is supported only with qnnpack engine before torch v1.8.0.
        if module_name == "conv_transpose" and not is_version_greater_than("1.7.1"):
            prev_engine = torch.backends.quantized.engine
            torch.backends.quantized.engine = "qnnpack"
            try:
                quantize_model(raw_module, inp, per_channel=per_channel)
            finally:
                # The engine is process-wide state: restore it even when
                # quantization fails so later tests are not affected.
                torch.backends.quantized.engine = prev_engine
        else:
            quantize_model(raw_module, inp, per_channel=per_channel)

        script_module = torch.jit.trace(raw_module, inp).eval()

        with torch.no_grad():
            pt_result = script_module(inp.clone()).numpy()

        input_name = "input"
        runtime = get_tvm_runtime(script_module, input_name, ishape)
        runtime.set_input(input_name, inp.numpy().copy())
        runtime.run()
        tvm_result = runtime.get_output(0).numpy()

        # Element-wise error, computed once and reused for both statistics.
        abs_diff = np.abs(tvm_result - pt_result)
        max_abs_diff = np.max(abs_diff)
        mean_abs_diff = np.mean(abs_diff)
        num_identical = np.sum(tvm_result == pt_result)
        match_ratio = num_identical / float(tvm_result.size)

        print(module_name, max_abs_diff, mean_abs_diff, match_ratio)

    # NOTE(review): the original carried a "sample outputs" listing at this
    # point; its contents are not visible in the reviewed text.
def test_quantized_imagenet():
    """Check that TVM and PyTorch agree on the top-3 labels of quantized models.

    Pretrained torchvision models are quantized per tensor and per channel
    (mobilenet_v3_large is loaded already quantized), traced, and run on one
    real image by both PyTorch and the TVM runtime. Difference statistics are
    printed for every model, and the sets of top-3 labels must be equal.
    """

    def get_transform():
        """Return the preprocessing pipeline: resize, center crop, to-tensor, normalize."""
        import torchvision.transforms as transforms

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        return transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]
        )

    def get_real_image(im_height, im_width):
        """Download the test image and return it resized to the requested size."""
        repo_base = "https://github.com/dmlc/web-data/raw/main/tensorflow/models/InceptionV1/"
        img_name = "elephant-299.jpg"
        # This is a URL, not a filesystem path, so it is built by plain
        # concatenation (repo_base already ends with "/").
        image_url = repo_base + img_name
        img_path = download_testdata(image_url, img_name, module="data")
        # PIL's resize expects the size as (width, height).
        return Image.open(img_path).resize((im_width, im_height))

    def get_imagenet_input():
        """Return the preprocessed test image as a numpy array with a leading batch axis."""
        im = get_real_image(224, 224)
        preprocess = get_transform()
        pt_tensor = preprocess(im)
        return np.expand_dims(pt_tensor.numpy(), 0)

    from torchvision.models.quantization import resnet as qresnet
    from torchvision.models.quantization import mobilenet as qmobilenet

    # Only referenced by the disabled entries below; kept so they can be re-enabled.
    from torchvision.models.quantization import inception as qinception
    from torchvision.models.quantization import googlenet as qgooglenet

    qmodels = []

    for per_channel in [False, True]:
        qmodels += [
            ("resnet18", qresnet.resnet18(pretrained=True), per_channel),
            ("mobilenet_v2", qmobilenet.mobilenet_v2(pretrained=True), per_channel),
            # disable inception test for now, since loading it takes ~5min on torchvision-0.5 due to scipy bug
            # See https://discuss.pytorch.org/t/torchvisions-inception-v3-takes-much-longer-to-load-than-other-models/68756
            # ("inception_v3", qinception.inception_v3(pretrained=True), per_channel),
            # tracing quantized googlenet broken as of v1.6
            # ("googlenet", qgooglenet(pretrained=True), per_channel),
        ]

    if is_version_greater_than("1.7.1"):
        from torchvision.models.quantization import mobilenet_v3_large as qmobilenet_v3_large

        qmodels.append(
            ("mobilenet_v3_large", qmobilenet_v3_large(pretrained=True, quantize=True).eval(), True)
        )

    results = []

    for model_name, raw_model, per_channel in qmodels:
        raw_model.eval()

        if per_channel:
            model_name += ", per channel quantization"
        else:
            model_name += ", per tensor quantization"

        inp = get_imagenet_input()
        pt_inp = torch.from_numpy(inp)

        if "mobilenet_v3_large" not in model_name:
            # mv3 was qat-ed, quantize=True option above makes it already quantized
            quantize_model(raw_model, pt_inp, per_channel=per_channel)

        script_module = torch.jit.trace(raw_model, pt_inp).eval()

        with torch.no_grad():
            pt_result = script_module(pt_inp).numpy()

        input_name = "image"
        runtime = get_tvm_runtime(script_module, input_name, (1, 3, 224, 224))
        runtime.set_input(input_name, inp)
        runtime.run()

        # .numpy() matches the accessor used by the other tests in this file.
        tvm_result = runtime.get_output(0).numpy()

        # Drop the batch axis; every entry holds one vector per framework.
        results.append((model_name, pt_result[0], tvm_result[0]))

    for model_name, pt_result, tvm_result in results:
        abs_diff = np.abs(tvm_result - pt_result)
        max_abs_diff = np.max(abs_diff)
        mean_abs_diff = np.mean(abs_diff)
        num_identical = np.sum(tvm_result == pt_result)
        # Indices of the three largest outputs, largest first.
        pt_top3_labels = np.argsort(pt_result)[::-1][:3]
        tvm_top3_labels = np.argsort(tvm_result)[::-1][:3]

        print("\nModel name: %s" % model_name)
        print("PyTorch top3 label:", pt_top3_labels)
        print("TVM top3 label:", tvm_top3_labels)
        print("max abs diff:", max_abs_diff)
        print("mean abs_diff:", mean_abs_diff)
        print("%d in 1000 raw outputs identical." % num_identical)

        # Only the set of top-3 labels has to agree; their order may differ.
        assert set(pt_top3_labels) == set(tvm_top3_labels)

    # NOTE(review): the original carried a "sample outputs" listing at this
    # point; its contents are not visible in the reviewed text.