def test_qat_resnet_per_channel(self):
    # Quantize ResNet50 model
    x = torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)
    qat_resnet50 = resnet50()

    qat_resnet50.qconfig = quantization.QConfig(
        activation=quantization.default_fake_quant,
        weight=quantization.default_per_channel_weight_fake_quant,
    )
    quantization.prepare_qat(qat_resnet50, inplace=True)
    qat_resnet50.apply(torch.ao.quantization.enable_observer)
    qat_resnet50.apply(torch.ao.quantization.enable_fake_quant)

    _ = qat_resnet50(x)
    for module in qat_resnet50.modules():
        if isinstance(module, quantization.FakeQuantize):
            module.calculate_qparams()
    qat_resnet50.apply(torch.ao.quantization.disable_observer)

    self.exportTest(toC(qat_resnet50), toC(x))
def test_eval_only_fake_quant(self):
    r"""Using FakeQuant in evaluation-only mode;
    this is useful for estimating the accuracy loss when quantizing the network.
    """
    model = ManualLinearQATModel()

    model = prepare_qat(model)
    self.checkObservers(model)

    model.eval()
    test_only_eval_fn(model, self.calib_data)
def test_conv_linear(self):
    model = ManualConvLinearQATModel()

    prepare_qat(model)
    self.checkObservers(model)

    test_only_train_fn(model, self.img_data)
    convert(model)

    def checkQuantized(model):
        self.assertEqual(type(model.conv), nnq.Conv2d)
        self.assertEqual(type(model.fc1), nnq.Linear)
        self.assertEqual(type(model.fc2), nnq.Linear)
        test_only_eval_fn(model, self.img_data)
        self.checkScriptable(model, self.img_data)

    checkQuantized(model)

    model = ManualConvLinearQATModel()
    model = quantize_qat(model, test_only_train_fn, self.img_data)
    checkQuantized(model)
def test_fuse_module_train(self):
    model = ModelForFusion(default_qat_qconfig).train()
    # Test step by step fusion
    model = fuse_modules(model, ['conv1', 'bn1', 'relu1'])
    model = fuse_modules(model, ['sub1.conv', 'sub1.bn'])
    self.assertEqual(type(model.conv1), nni.ConvBnReLU2d,
                     "Fused Conv + BN + Relu first layer")
    self.assertEqual(type(model.bn1), torch.nn.Identity,
                     "Fused Conv + BN + Relu (skipped BN)")
    self.assertEqual(type(model.relu1), torch.nn.Identity,
                     "Fused Conv + BN + Relu (skipped Relu)")
    self.assertEqual(type(model.sub1.conv), nni.ConvBn2d,
                     "Fused submodule Conv + BN")
    self.assertEqual(type(model.sub1.bn), torch.nn.Identity,
                     "Fused submodule Conv + BN (skipped BN)")
    self.assertEqual(type(model.sub2.conv), torch.nn.Conv2d,
                     "Non-fused submodule Conv")
    self.assertEqual(type(model.sub2.relu), torch.nn.ReLU,
                     "Non-fused submodule ReLU")

    model = prepare_qat(model)
    self.checkObservers(model)

    def checkQAT(model):
        self.assertEqual(type(model.conv1), nniqat.ConvBnReLU2d)
        self.assertEqual(type(model.bn1), nn.Identity)
        self.assertEqual(type(model.relu1), nn.Identity)
        self.assertEqual(type(model.sub1.conv), nniqat.ConvBn2d)
        self.assertEqual(type(model.sub1.bn), nn.Identity)
        self.assertEqual(type(model.sub2.conv), nn.Conv2d)
        self.assertEqual(type(model.sub2.relu), nn.ReLU)

    checkQAT(model)
    test_only_train_fn(model, self.img_data)
    model = convert(model)

    def checkQuantized(model):
        self.assertEqual(type(model.conv1), nniq.ConvReLU2d)
        self.assertEqual(type(model.bn1), nn.Identity)
        self.assertEqual(type(model.relu1), nn.Identity)
        self.assertEqual(type(model.sub1.conv), nnq.Conv2d)
        self.assertEqual(type(model.sub1.bn), nn.Identity)
        self.assertEqual(type(model.sub2.conv), nn.Conv2d)
        self.assertEqual(type(model.sub2.relu), nn.ReLU)
        test_only_eval_fn(model, self.img_data)

    checkQuantized(model)

    model = ModelForFusion(default_qat_qconfig).train()
    model = fuse_modules(
        model, [['conv1', 'bn1', 'relu1'], ['sub1.conv', 'sub1.bn']])
    model = quantize_qat(model, test_only_train_fn, self.img_data)
    checkQuantized(model)
def test_fixed_qparam_ops(self):
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.sigmoid = torch.nn.Sigmoid()
            self.hardsigmoid = torch.nn.Hardsigmoid()
            self.tanh = torch.nn.Tanh()
            self.quant = QuantStub()
            self.dequant = DeQuantStub()

        def forward(self, x):
            x = self.quant(x)
            x = self.sigmoid(x)
            x = self.hardsigmoid(x)
            x = self.tanh(x)
            x = self.dequant(x)
            return x

    m = M().train()
    m.qconfig = default_qat_qconfig
    m = prepare_qat(m)
    for attr in ['sigmoid', 'hardsigmoid', 'tanh']:
        self.assertEqual(type(getattr(m, attr).activation_post_process),
                         FixedQParamsFakeQuantize)
    data = torch.randn(1, 3, 2, 4)
    before_convert = m(data)
    m = convert(m)
    after_convert = m(data)
    self.assertEqual(before_convert, after_convert)
    # make sure activation post process is removed
    for attr in ['sigmoid', 'hardsigmoid', 'tanh']:
        # verify fake quant module is removed
        self.assertFalse(
            hasattr(getattr(m, attr), 'activation_post_process'))
        # verify that hooks are removed
        self.assertTrue(len(getattr(m, attr)._forward_hooks.items()) == 0)

    # make sure no fake quantize module is inserted for eval mode
    def checkNoFQModule(m):
        for attr in ['sigmoid', 'hardsigmoid', 'tanh']:
            self.assertFalse(
                hasattr(getattr(m, attr), "activation_post_process"))
            self.assertTrue(
                len(getattr(m, attr)._forward_hooks.items()) == 0)

    m = M().eval()
    m.qconfig = default_qconfig
    m = prepare(m)
    checkNoFQModule(m)
    m = convert(m)
    checkNoFQModule(m)
def load_student_model(config, *, qat=False, preload_state_dict=False, preload_dir='results'):
    if qat:
        student_model, student_model_name = get_qat_model(config, pretrained=False)
    else:
        student_model, student_model_name = get_model(config, pretrained=False)

    # We don't preload from saved models, but rather from pretrained.
    if preload_state_dict:
        state_dict = make_student_save_name(preload_dir, config) + '.pt'
        state_dict = torch.load(state_dict)
        student_model.load_state_dict(state_dict)

    if qat:
        if hasattr(student_model, 'fuse_model'):
            student_model.fuse_model()
        student_model.qconfig = tq.get_default_qat_qconfig('fbgemm')
        tq.prepare_qat(student_model, inplace=True)

    return student_model, student_model_name
def test_eval_only_fake_quant(self):
    r"""Using FakeQuant in evaluation-only mode;
    this is useful for estimating the accuracy loss when quantizing the network.
    """
    for qengine in supported_qengines:
        with override_quantized_engine(qengine):
            model = ManualLinearQATModel(qengine)

            model = prepare_qat(model)
            self.checkObservers(model)

            model.eval()
            test_only_eval_fn(model, self.calib_data)
def test_manual(self):
    model = ManualLinearQATModel()

    model = prepare_qat(model)
    self.checkObservers(model)

    test_only_train_fn(model, self.train_data)
    convert(model)

    def checkQuantized(model):
        self.assertEqual(type(model.fc1), nnq.Linear)
        self.assertEqual(type(model.fc2), nnq.Linear)
        test_only_eval_fn(model, self.calib_data)
        self.checkScriptable(model, self.calib_data)

    checkQuantized(model)

    model = quantize_qat(ManualLinearQATModel(), test_only_train_fn,
                         self.train_data)
    checkQuantized(model)
def test_relu(self):
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.relu = nn.ReLU()

        def forward(self, x):
            x = self.relu(x)
            return x

    m = M().train()
    m.qconfig = default_qconfig
    m = prepare_qat(m)
    # make sure no activation_post_process is inserted for relu
    self.assertFalse(hasattr(m, "activation_post_process"))
    m = convert(m)
    # make sure ReLU module is not changed
    self.assertEqual(type(m.relu), nn.ReLU)
def test_forward_hooks_preserved(self):
    r"""Test that QAT preserves the pre-forward and post-forward hooks of the original model.
    """
    qengine = torch.backends.quantized.engine
    model = QuantStubModel()
    counter = {
        'pre_forwards': 0,
        'forwards': 0,
    }

    def fw_pre_hook(h_module, input):
        counter['pre_forwards'] += 1

    def fw_hook(h_module, input, output):
        counter['forwards'] += 1

    model.fc.register_forward_pre_hook(fw_pre_hook)
    model.fc.register_forward_hook(fw_hook)

    model.qconfig = torch.quantization.get_default_qat_qconfig(qengine)
    model = prepare_qat(model)

    def checkHooksIsPresent(model, before_convert=True):
        forward_hooks = 1
        if before_convert:
            self.assertEqual(len(model.quant._forward_hooks.values()), 1,
                             "Quantization observer hook has disappeared")
            forward_hooks = 2
        self.assertObjectIn(fw_pre_hook, model.fc._forward_pre_hooks.values())
        self.assertObjectIn(fw_hook, model.fc._forward_hooks.values())
        self.assertEqual(len(model.fc._forward_pre_hooks.values()), 1,
                         "Extra pre forward hooks have appeared on a layer")
        self.assertEqual(len(model.fc._forward_hooks.values()), forward_hooks,
                         "Extra post forward hooks have appeared on a layer")

    checkHooksIsPresent(model, True)
    x = torch.rand(2, 5, dtype=torch.float)
    model(x)
    torch.quantization.convert(model, inplace=True)
    checkHooksIsPresent(model, False)
def test_manual(self):
    for qengine in supported_qengines:
        with override_quantized_engine(qengine):
            model = ManualLinearQATModel(qengine)

            model = prepare_qat(model)
            self.checkObservers(model)

            test_only_train_fn(model, self.train_data)
            model = convert(model)

            def checkQuantized(model):
                self.assertEqual(type(model.fc1), nnq.Linear)
                self.assertEqual(type(model.fc2), nnq.Linear)
                test_only_eval_fn(model, self.calib_data)
                self.checkScriptable(model, self.calib_data)
                self.checkNoQconfig(model)

            checkQuantized(model)

            model = quantize_qat(ManualLinearQATModel(qengine),
                                 test_only_train_fn, [self.train_data])
            checkQuantized(model)
def _test_activation_convert_numerics_impl(self, Act, data):
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.act = Act()
            self.quant = QuantStub()
            self.dequant = DeQuantStub()

        def forward(self, x):
            x = self.quant(x)
            x = self.act(x)
            x = self.dequant(x)
            return x

    m = M().train()
    m.qconfig = default_qat_qconfig
    m = prepare_qat(m)
    before_convert = m(data)
    m = convert(m)
    after_convert = m(data)
    self.assertEqual(before_convert, after_convert)
def _test_model_impl(
        self, mode, name, model, eager_quantizable_model,
        check_with_eager=True,
        diff_of_quant=None,
        diff_from_eager=None):
    if diff_of_quant is None or diff_from_eager is None:
        diff_of_quant = {}
        diff_from_eager = {}

    if mode not in diff_of_quant or mode not in diff_from_eager:
        diff_of_quant[mode] = {}
        diff_from_eager[mode] = {}

    input_tensor = torch.rand(1, 3, 224, 224)
    input_tensor_inception = torch.rand(1, 3, 299, 299)
    output_value = torch.randint(0, 1, (1,))

    # print('quantizing:', name, ' mode:', mode)
    if name == 'inception_v3':
        input_value = input_tensor_inception
    else:
        input_value = input_tensor

    qconfig = default_qconfig if mode == 'static' else default_qat_qconfig
    qconfig_dict = {'': qconfig}
    graph_module = symbolic_trace(model)
    # print('graph module:', graph_module.src)
    script = torch.jit.script(graph_module)

    # make sure graph module and script module are both runnable
    original_out = graph_module(input_value)
    is_not_tuple_out = not isinstance(original_out, tuple)
    script_out = script(input_value)
    self.assertEqual(
        (original_out - script_out).abs().max(), 0,
        'Result of original graph module and script module does not match')

    # set to train just before quantization
    if mode != 'static':
        model.train()

    graph_module = fuse_fx(graph_module)
    prepared = prepare_fx(graph_module, qconfig_dict)

    if mode == 'ddp':
        mp.spawn(run_ddp,
                 args=(world_size, prepared),
                 nprocs=world_size,
                 join=True)
    elif mode == 'qat':
        assert prepared.training, 'prepared must be in training mode for qat'
        optimizer = torch.optim.SGD(prepared.parameters(), lr=0.0001)
        criterion = nn.CrossEntropyLoss()
        train_one_epoch(prepared, criterion, optimizer,
                        [(input_value, output_value)], torch.device('cpu'), 1)
    else:
        for i in range(10):
            prepared(input_value)

    # print('after observation root:', prepared.root)

    qgraph = convert_fx(prepared)
    # print('after quantization root:', qgraph.root)
    # print('after quantization code:', qgraph.src)
    qgraph.eval()
    qgraph_script = torch.jit.script(qgraph)
    # print('quantized and scripted:', qgraph_script.graph)

    qgraph_out = qgraph(input_value)
    qgraph_script = qgraph_script(input_value)

    if is_not_tuple_out:
        diff_of_quant[mode][name] = (original_out - qgraph_out).abs().max()
        assert torch.allclose(qgraph_out, qgraph_script), 'graph, scripted graph'
    else:
        print('tuple output')

    if eager_quantizable_model is not None:
        # comparing to eager mode quantization
        qeager = eager_quantizable_model
        ref_out = qeager(input_value)
        qeager.qconfig = qconfig
        if mode == 'static':
            qeager.fuse_model()
            prepare(qeager, inplace=True)
        else:
            qeager.train()
            qeager.fuse_model()
            prepare_qat(qeager, inplace=True)

        # calibration
        if mode == 'ddp':
            mp.spawn(run_ddp,
                     args=(world_size, qeager),
                     nprocs=world_size,
                     join=True)
        elif mode == 'qat':
            assert qeager.training, 'qeager should be in training mode for qat'
            optimizer = torch.optim.SGD(qeager.parameters(), lr=0.0001)
            train_one_epoch(qeager, criterion, optimizer,
                            [(input_value, output_value)], torch.device('cpu'), 1)
        else:
            for i in range(10):
                qeager(input_value)

        # print('ref after observation:', qeager)

        convert(qeager, inplace=True)
        qeager.eval()

        # print('ref after quantization:', qeager)
        qeager_out = qeager(input_value)
        qeager_script = torch.jit.script(qeager)
        qscript_out = qeager_script(input_value)

        if is_not_tuple_out:
            diff_from_eager[mode][name] = (qeager_out - qgraph_out).abs().max()
            if check_with_eager:
                self.assertEqual(
                    diff_from_eager[mode][name], 0,
                    'Result of graph mode quantization and ' +
                    'eager mode quantization on model: ' + name +
                    ' should match. Mode: ' + mode +
                    ' diff:' + str(diff_from_eager[mode][name]))
def train_subject_specific_quant(subject, epochs=500, batch_size=32, lr=0.001, silent=False,
                                 plot=True, **kwargs):
    """
    Trains a subject-specific model for the given subject.

    Parameters:
     - subject:    Integer in the range 1 <= subject <= 9
     - epochs:     Number of epochs to train
     - batch_size: Batch size
     - lr:         Learning rate
     - silent:     bool, if True, hide all output including the progress bar
     - plot:       bool, if True, generate plots
     - kwargs:     Remaining arguments passed to the EEGnet model

    Returns: (model, metrics, history)
     - model:   t.nn.Module, trained model
     - metrics: t.tensor, size=[1, 4], accuracy, precision, recall, f1
     - history: training history as returned by _train_net
    """
    # load the data
    train_samples, train_labels = get_data(subject, training=True)
    test_samples, test_labels = get_data(subject, training=False)
    train_loader = as_data_loader(train_samples, train_labels, batch_size=batch_size)
    # test_loader = as_data_loader(test_samples, test_labels, batch_size=test_labels.shape[0])
    test_loader = as_data_loader(test_samples, test_labels, batch_size=batch_size)

    # prepare quantization configuration
    qconfig = tq.QConfig(
        activation=tq.MinMaxObserver.with_args(dtype=t.quint8),
        weight=tq.MinMaxObserver.with_args(dtype=t.qint8))

    # prepare the model
    model = EEGNetQuant(T=train_samples.shape[2], qconfig=qconfig, **kwargs)
    model.initialize_params()
    if t.cuda.is_available():
        model = model.cuda()

    # prepare the quantization
    tq.prepare_qat(model, inplace=True)

    # prepare loss function and optimizer
    loss_function = t.nn.CrossEntropyLoss()
    optimizer = t.optim.Adam(model.parameters(), lr=lr, eps=1e-7)
    scheduler = None

    # print the training setup
    print_summary(model, optimizer, loss_function, scheduler)

    # prepare progress bar
    with tqdm(desc=f"Subject {subject}", total=epochs, leave=False, disable=silent,
              unit='epoch', ascii=True) as pbar:

        # Early stopping is not allowed in this mode, because the testing data cannot be
        # used for training!
        model, metrics, _, history = _train_net(subject, model, train_loader, test_loader,
                                                loss_function, optimizer, scheduler=scheduler,
                                                epochs=epochs, early_stopping=False, plot=plot,
                                                pbar=pbar)

    # convert the model into a quantized model
    model = model.cpu()
    tq.convert(model, inplace=True)
    metrics = get_metrics_from_model(model, test_loader)

    if not silent:
        print(f"Subject {subject}: accuracy = {metrics[0, 0]}")

    return model, metrics, history
from torchvision import models

qat_resnet18 = models.resnet18(pretrained=True).eval().cuda()

quantization_type = "per_tensor"

# 1. per-tensor quantization or per-channel quantization
#    (TensorRT does not support per-channel quantization)
if quantization_type == "per_tensor":
    qat_resnet18.qconfig = quantization.QConfig(
        activation=quantization.default_fake_quant,
        weight=quantization.default_weight_fake_quant)
else:
    qat_resnet18.qconfig = quantization.QConfig(
        activation=quantization.default_fake_quant,
        weight=quantization.default_per_channel_weight_fake_quant)

quantization.prepare_qat(qat_resnet18, inplace=True)
qat_resnet18.apply(quantization.enable_observer)
qat_resnet18.apply(quantization.enable_fake_quant)

dummy_input = torch.randn(1, 3, 224, 224).cuda()
_ = qat_resnet18(dummy_input)
for module in qat_resnet18.modules():
    if isinstance(module, quantization.FakeQuantize):
        module.calculate_qparams()
qat_resnet18.apply(quantization.disable_observer)

qat_resnet18.cuda()

# enable_onnx_checker needs to be disabled because ONNX runtime doesn't support opset 13 yet
# torch.onnx.export(qat_resnet18, dummy_input, "resnet18_qat.onnx", verbose=True,
#                   opset_version=13, enable_onnx_checker=False)

if quantization_type == "per_tensor":
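    # NOTE: the original snippet is truncated at this branch. A minimal sketch of a
    # possible completion, assuming the intent is to run the commented-out
    # torch.onnx.export call above once per quantization type; the output file
    # name below is hypothetical, not from the original source.
    torch.onnx.export(qat_resnet18, dummy_input, "resnet18_qat_per_tensor.onnx",
                      verbose=True, opset_version=13)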
def test_tuple_lowered():
    # See the following discuss thread for details
    # https://discuss.tvm.apache.org/t/bug-frontend-pytorch-relay-ir-is-inconsistent-with-that-of-the-original-model/12010

    class ConvBnRelu(nn.Module):
        def __init__(self, inp, oup, kernel_size=3, stride=1, padding=1, bias=True, groups=1):
            super(ConvBnRelu, self).__init__()
            if groups > 1:
                self.conv = nn.Conv2d(inp, inp, kernel_size, stride, padding, bias=bias, groups=groups)
                self.bn = nn.BatchNorm2d(inp)
            else:
                self.conv = nn.Conv2d(inp, oup, kernel_size, stride, padding, bias=bias, groups=groups)
                self.bn = nn.BatchNorm2d(oup)
            self.relu = nn.ReLU(inplace=True)

        def forward(self, inputs):
            x = self.conv(inputs)
            x = self.bn(x)
            x = self.relu(x)
            return x

    def conv_bn(inp, oup, stride=1, width_multiplier=1):
        return ConvBnRelu(inp, oup, kernel_size=3, stride=stride, padding=1, bias=False)

    def conv_dw(inp, oup, stride, width_multiplier=1, padding=1):
        dw_block = nn.Sequential()
        depth_wise = ConvBnRelu(inp, oup, kernel_size=3, stride=stride, padding=padding,
                                bias=False, groups=inp)
        point_wise = ConvBnRelu(inp, oup, kernel_size=1, stride=1, padding=0, bias=False)
        dw_block.add_module("depth_wise", depth_wise)
        dw_block.add_module("point_wise", point_wise)
        return dw_block

    class Backbone(nn.Module):
        def __init__(self, width_multiplier=1):
            super(Backbone, self).__init__()
            self.width_multiplier = width_multiplier
            self.conv1 = conv_bn(3, 16, 2, self.width_multiplier)
            self.conv2 = conv_dw(16, 32, 1, self.width_multiplier)

        def forward(self, inputs):
            x1 = self.conv1(inputs)
            x2 = self.conv2(x1)
            return [x1, x2]

    class QuantizableBackbone(nn.Module):
        def __init__(self, inputsize=(128, 128)):
            super(QuantizableBackbone, self).__init__()
            self.quant = QuantStub()
            self.dequant = DeQuantStub()
            self.backbone = Backbone()

        def fuse_model(self):
            for idx, m in enumerate(self.modules()):
                if type(m) == ConvBnRelu:
                    torch.quantization.fuse_modules(m, ["conv", "bn", "relu"], inplace=True)

        def forward(self, input):
            input = self.quant(input)
            y0, y1 = self.backbone(input)
            y0 = self.dequant(y0)
            y1 = self.dequant(y1)
            return y0, y1

    fp32_input = torch.randn(1, 3, 128, 128)
    model = QuantizableBackbone()
    model.train()
    model.fuse_model()
    model.qconfig = get_default_qat_qconfig("qnnpack")

    prepare_qat(model, inplace=True)

    model.eval()
    model(fp32_input)

    model_int8 = torch.quantization.convert(model, inplace=True)
    script_module = torch.jit.trace(model_int8, fp32_input).eval()

    input_infos = [("input", (fp32_input.shape, "float32"))]
    mod, _ = relay.frontend.from_pytorch(script_module, input_infos)
    output = mod["main"].body

    assert isinstance(output, relay.Tuple) and len(output) == 2
    dq1, dq2 = output
    assert str(dq1.op) == "qnn.dequantize" and str(dq2.op) == "qnn.dequantize"
    scale1 = dq1.args[1].data.numpy().item()
    scale2 = dq2.args[1].data.numpy().item()
    assert scale1 != scale2
def test_fuse_module_train(self):
    model = ModelForFusion(default_qat_qconfig).train()
    # Test step by step fusion
    model = fuse_modules(model, ['conv1', 'bn1', 'relu1'])
    model = fuse_modules(model, ['sub1.conv', 'sub1.bn'])
    self.assertEqual(type(model.conv1), nni.ConvBnReLU2d,
                     msg="Fused Conv + BN + Relu first layer")
    self.assertEqual(type(model.bn1), torch.nn.Identity,
                     msg="Fused Conv + BN + Relu (skipped BN)")
    self.assertEqual(type(model.relu1), torch.nn.Identity,
                     msg="Fused Conv + BN + Relu (skipped Relu)")
    self.assertEqual(type(model.sub1.conv), nni.ConvBn2d,
                     msg="Fused submodule Conv + BN")
    self.assertEqual(type(model.sub1.bn), torch.nn.Identity,
                     msg="Fused submodule Conv + BN (skipped BN)")
    self.assertEqual(type(model.sub2.conv), torch.nn.Conv2d,
                     msg="Non-fused submodule Conv")
    self.assertEqual(type(model.sub2.relu), torch.nn.ReLU,
                     msg="Non-fused submodule ReLU")

    model = prepare_qat(model)
    self.checkObservers(model)

    def checkQAT(model):
        self.assertEqual(type(model.conv1), nniqat.ConvBnReLU2d)
        self.assertEqual(type(model.bn1), nn.Identity)
        self.assertEqual(type(model.relu1), nn.Identity)
        self.assertEqual(type(model.sub1.conv), nniqat.ConvBn2d)
        self.assertEqual(type(model.sub1.bn), nn.Identity)
        self.assertEqual(type(model.sub2.conv), nn.Conv2d)
        self.assertEqual(type(model.sub2.relu), nn.ReLU)

    checkQAT(model)
    test_only_train_fn(model, self.img_data_1d_train)
    model = convert(model)

    def checkQuantized(model):
        self.assertEqual(type(model.conv1), nniq.ConvReLU2d)
        self.assertEqual(type(model.bn1), nn.Identity)
        self.assertEqual(type(model.relu1), nn.Identity)
        self.assertEqual(type(model.sub1.conv), nnq.Conv2d)
        self.assertEqual(type(model.sub1.bn), nn.Identity)
        self.assertEqual(type(model.sub2.conv), nn.Conv2d)
        self.assertEqual(type(model.sub2.relu), nn.ReLU)
        test_only_eval_fn(model, self.img_data_1d)
        self.checkNoQconfig(model)

    with self.assertRaisesRegex(
            RuntimeError,
            "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"):
        checkQuantized(model)

    model = ModelForFusion(default_qat_qconfig).train()
    model = fuse_modules(
        model, [['conv1', 'bn1', 'relu1'], ['sub1.conv', 'sub1.bn']])
    model = quantize_qat(model, test_only_train_fn, [self.img_data_1d_train])
    with self.assertRaisesRegex(
            RuntimeError,
            "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"):
        checkQuantized(model)
def test_fusion_sequential_model_train(self):
    for qengine in supported_qengines:
        with override_quantized_engine(qengine):
            model = ModelWithSequentialFusion().train()
            model.to(torch.float)
            fuse_modules(model,
                         [['conv1', 'relu1'],
                          ['features.0.0', 'features.0.1', 'features.0.2'],
                          ['features.1.0', 'features.1.1', 'features.1.2'],
                          ['features.2.0', 'features.2.1', 'features.2.2'],
                          ['classifier.0', 'classifier.1']],
                         inplace=True)
            self.assertEqual(type(model.conv1), nni.ConvReLU2d,
                             msg="Fused Conv + Relu: nni.ConvReLU2d")
            self.assertEqual(type(model.conv1[0]), nn.Conv2d,
                             msg="Fused Conv + Relu: Conv2d")
            self.assertEqual(type(model.conv1[1]), nn.ReLU,
                             msg="Fused Conv + Relu: Relu")
            self.assertEqual(type(model.relu1), nn.Identity,
                             msg="Fused Conv + Relu: Identity")
            for i in range(3):
                self.assertEqual(type(model.features[i][0]), nni.ConvBnReLU2d,
                                 msg="Fused submodule Conv + folded BN")
                self.assertEqual(type(model.features[i][1]), nn.Identity,
                                 msg="Fused submodule (skipped BN)")
                self.assertEqual(type(model.features[i][2]), nn.Identity,
                                 msg="Non-fused submodule Conv")
            self.assertEqual(type(model.classifier[0]), nni.LinearReLU)
            self.assertEqual(type(model.classifier[1]), nn.Identity)

            model.qconfig = torch.quantization.get_default_qat_qconfig(qengine)
            prepare_qat(model, inplace=True)
            self.checkObservers(model)
            model(self.img_data_2d[0][0])

            def checkQAT(model):
                self.assertEqual(type(model.conv1), nniqat.ConvReLU2d)
                self.assertEqual(type(model.relu1), nn.Identity)
                for i in range(3):
                    self.assertEqual(type(model.features[i][0]), nniqat.ConvBnReLU2d,
                                     msg="Fused submodule Conv + folded BN")
                    self.assertEqual(type(model.features[i][1]), nn.Identity,
                                     msg="Fused submodule (skipped BN)")
                    self.assertEqual(type(model.features[i][2]), nn.Identity,
                                     msg="Non-fused submodule Conv")
                self.assertEqual(type(model.classifier[0]), nniqat.LinearReLU)
                self.assertEqual(type(model.classifier[1]), nn.Identity)

            checkQAT(model)
            model(self.img_data_2d[1][0])
            convert(model, inplace=True)
            model(self.img_data_2d[1][0])
            self.checkModelWithSequentialQuantized(model)