def fuse_model(self) -> None: fuse_modules(self, ["fc1", "activation"], inplace=True)
def test_fuse_module_eval(self):
    model = ModelForFusion(default_qconfig)
    model.eval()
    model = fuse_modules(
        model,
        [['conv3', 'bn3', 'relu4'],
         ['conv1', 'bn1', 'relu1'],
         ['conv2', 'relu2'],
         ['bn2', 'relu3'],
         ['sub1.conv', 'sub1.bn']])
    self.assertEqual(type(model.conv1), nni.ConvReLU2d,
                     msg="Fused Conv + BN + Relu first layer (BN is folded)")
    self.assertEqual(type(model.conv1[0]), nn.Conv2d,
                     msg="Fused Conv + BN + Relu (Conv + folded BN only)")
    self.assertEqual(type(model.conv1[1]), nn.ReLU,
                     msg="Fused Conv + BN + Relu second layer (Relu only)")
    self.assertEqual(type(model.bn1), nn.Identity,
                     msg="Fused Conv + BN + Relu second layer (Skipped BN)")
    self.assertEqual(type(model.relu1), nn.Identity,
                     msg="Fused Conv + BN + Relu second layer (Skipped Relu)")
    self.assertEqual(type(model.conv2), nni.ConvReLU3d,
                     msg="Fused Conv + Relu for Conv3d")
    self.assertEqual(type(model.bn2), nni.BNReLU3d,
                     msg="Fused BN + Relu first layer (Relu is folded)")
    self.assertEqual(type(model.relu3), nn.Identity,
                     msg="Fused BN + Relu second layer (Skipped Relu)")
    self.assertEqual(type(model.conv2[0]), nn.Conv3d,
                     msg="Fused Conv + Relu (Conv3d only)")
    self.assertEqual(type(model.conv2[1]), nn.ReLU,
                     msg="Fused Conv + Relu second layer (Relu only)")
    self.assertEqual(type(model.relu2), nn.Identity,
                     msg="Fused Conv + Relu second layer (Skipped Relu)")
    self.assertEqual(type(model.conv3), nni.ConvReLU1d,
                     msg="Fused Conv + Relu for Conv1d (folded BN)")
    self.assertEqual(type(model.conv3[0]), nn.Conv1d,
                     msg="Fused Conv + Relu for Conv1d")
    self.assertEqual(type(model.conv3[1]), nn.ReLU,
                     msg="Fused Conv + Relu for Conv1d")
    self.assertEqual(type(model.bn3), nn.Identity,
                     msg="Fused Conv + BN + Relu for Conv1d (Skipped BN)")
    self.assertEqual(type(model.sub1.conv), nn.Conv2d,
                     msg="Fused submodule Conv + folded BN")
    self.assertEqual(type(model.sub1.bn), nn.Identity,
                     msg="Fused submodule (skipped BN)")
    self.assertEqual(type(model.sub2.conv), nn.Conv2d,
                     msg="Non-fused submodule Conv")
    self.assertEqual(type(model.sub2.relu), torch.nn.ReLU,
                     msg="Non-fused submodule ReLU")
    model = prepare(model)
    self.checkObservers(model)
    test_only_eval_fn(model, self.img_data_1d)
    model = convert(model)

    def checkQuantized(model):
        self.assertEqual(type(model.conv3), nniq.ConvReLU1d)
        self.assertEqual(type(model.conv1), nniq.ConvReLU2d)
        self.assertEqual(type(model.bn1), nn.Identity)
        self.assertEqual(type(model.relu1), nn.Identity)
        self.assertEqual(type(model.sub1.conv), nnq.Conv2d)
        self.assertEqual(type(model.sub1.bn), nn.Identity)
        self.assertEqual(type(model.sub2.conv), nn.Conv2d)
        self.assertEqual(type(model.sub2.relu), nn.ReLU)
        self.assertEqual(type(model.bn2), nniq.BNReLU3d)
        test_only_eval_fn(model, self.img_data_1d)
        self.checkNoQconfig(model)

    checkQuantized(model)

    model = ModelForFusion(default_qconfig).eval()
    model = fuse_modules(
        model,
        [['conv1', 'bn1', 'relu1'],
         ['conv2', 'relu2'],
         ['bn2', 'relu3'],
         ['sub1.conv', 'sub1.bn'],
         ['conv3', 'bn3', 'relu4']])
    model = quantize(model, test_only_eval_fn, [self.img_data_1d])
    checkQuantized(model)
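# Hedged sketch of the eager-mode PTQ pipeline the test above walks through:
# fuse -> prepare -> calibrate -> convert. `model` and `calibration_batches`
# are placeholders, and the fusion list assumes the same conv1/bn1/relu1
# layout as ModelForFusion.
import torch
from torch.ao.quantization import (
    convert, default_qconfig, fuse_modules, prepare,
)

def ptq_with_fusion(model, calibration_batches):
    model.eval()                     # eval-mode fusion folds BN into Conv
    model.qconfig = default_qconfig
    model = fuse_modules(model, [["conv1", "bn1", "relu1"]])
    model = prepare(model)           # insert observers after fusion
    with torch.no_grad():
        for x in calibration_batches:
            model(x)                 # collect activation statistics
    return convert(model)            # swap modules for quantized kernels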
def test_fuse_module_train(self):
    model = ModelForFusion(default_qat_qconfig).train()
    # Test step by step fusion
    model = fuse_modules(model, ['conv1', 'bn1', 'relu1'])
    model = fuse_modules(model, ['sub1.conv', 'sub1.bn'])
    self.assertEqual(type(model.conv1), nni.ConvBnReLU2d,
                     msg="Fused Conv + BN + Relu first layer")
    self.assertEqual(type(model.bn1), torch.nn.Identity,
                     msg="Fused Conv + BN + Relu (skipped BN)")
    self.assertEqual(type(model.relu1), torch.nn.Identity,
                     msg="Fused Conv + BN + Relu (skipped Relu)")
    self.assertEqual(type(model.sub1.conv), nni.ConvBn2d,
                     msg="Fused submodule Conv + BN")
    self.assertEqual(type(model.sub1.bn), torch.nn.Identity,
                     msg="Fused submodule Conv + BN (skipped BN)")
    self.assertEqual(type(model.sub2.conv), torch.nn.Conv2d,
                     msg="Non-fused submodule Conv")
    self.assertEqual(type(model.sub2.relu), torch.nn.ReLU,
                     msg="Non-fused submodule ReLU")
    model = prepare_qat(model)
    self.checkObservers(model)

    def checkQAT(model):
        self.assertEqual(type(model.conv1), nniqat.ConvBnReLU2d)
        self.assertEqual(type(model.bn1), nn.Identity)
        self.assertEqual(type(model.relu1), nn.Identity)
        self.assertEqual(type(model.sub1.conv), nniqat.ConvBn2d)
        self.assertEqual(type(model.sub1.bn), nn.Identity)
        self.assertEqual(type(model.sub2.conv), nn.Conv2d)
        self.assertEqual(type(model.sub2.relu), nn.ReLU)

    checkQAT(model)
    test_only_train_fn(model, self.img_data_1d_train)
    model = convert(model)

    def checkQuantized(model):
        self.assertEqual(type(model.conv1), nniq.ConvReLU2d)
        self.assertEqual(type(model.bn1), nn.Identity)
        self.assertEqual(type(model.relu1), nn.Identity)
        self.assertEqual(type(model.sub1.conv), nnq.Conv2d)
        self.assertEqual(type(model.sub1.bn), nn.Identity)
        self.assertEqual(type(model.sub2.conv), nn.Conv2d)
        self.assertEqual(type(model.sub2.relu), nn.ReLU)
        test_only_eval_fn(model, self.img_data_1d)
        self.checkNoQconfig(model)

    with self.assertRaisesRegex(
        RuntimeError,
        "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'",
    ):
        checkQuantized(model)

    model = ModelForFusion(default_qat_qconfig).train()
    model = fuse_modules(
        model,
        [['conv1', 'bn1', 'relu1'],
         ['sub1.conv', 'sub1.bn']])
    model = quantize_qat(model, test_only_train_fn, [self.img_data_1d_train])
    with self.assertRaisesRegex(
        RuntimeError,
        "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'",
    ):
        checkQuantized(model)
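# Companion sketch for the QAT path above: fuse in train mode, prepare_qat,
# fine-tune, then convert. Note the test targets the older API where
# fuse_modules dispatched on model.training; recent PyTorch releases expose
# training-mode fusion as fuse_modules_qat (used here, assuming it is
# available). `train_fn`/`train_data` are placeholders.
from torch.ao.quantization import (
    convert, fuse_modules_qat, get_default_qat_qconfig, prepare_qat,
)

def qat_with_fusion(model, train_fn, train_data):
    model.train()
    model.qconfig = get_default_qat_qconfig("fbgemm")
    model = fuse_modules_qat(model, [["conv1", "bn1", "relu1"]])  # -> nni.ConvBnReLU2d
    model = prepare_qat(model)   # insert fake-quant modules and observers
    train_fn(model, train_data)  # fine-tune with fake quantization in the loop
    model.eval()
    return convert(model)        # fold BN and swap in quantized kernels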
def test_fusion_sequential_model_train(self):
    for qengine in supported_qengines:
        with override_quantized_engine(qengine):
            model = ModelWithSequentialFusion().train()
            model.to(torch.float)
            fuse_modules(
                model,
                [['conv1', 'relu1'],
                 ['features.0.0', 'features.0.1', 'features.0.2'],
                 ['features.1.0', 'features.1.1', 'features.1.2'],
                 ['features.2.0', 'features.2.1', 'features.2.2'],
                 ['classifier.0', 'classifier.1']],
                inplace=True)
            self.assertEqual(type(model.conv1), nni.ConvReLU2d,
                             msg="Fused Conv + Relu: nni.ConvReLU2d")
            self.assertEqual(type(model.conv1[0]), nn.Conv2d,
                             msg="Fused Conv + Relu: Conv2d")
            self.assertEqual(type(model.conv1[1]), nn.ReLU,
                             msg="Fused Conv + Relu: Relu")
            self.assertEqual(type(model.relu1), nn.Identity,
                             msg="Fused Conv + Relu: Identity")
            for i in range(3):
                self.assertEqual(type(model.features[i][0]), nni.ConvBnReLU2d,
                                 msg="Fused submodule Conv + BN + ReLU")
                self.assertEqual(type(model.features[i][1]), nn.Identity,
                                 msg="Fused submodule (skipped BN)")
                self.assertEqual(type(model.features[i][2]), nn.Identity,
                                 msg="Fused submodule (skipped ReLU)")
            self.assertEqual(type(model.classifier[0]), nni.LinearReLU)
            self.assertEqual(type(model.classifier[1]), nn.Identity)
            model.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
            prepare_qat(model, inplace=True)
            self.checkObservers(model)
            model(self.img_data_2d[0][0])

            def checkQAT(model):
                self.assertEqual(type(model.conv1), nniqat.ConvReLU2d)
                self.assertEqual(type(model.relu1), nn.Identity)
                for i in range(3):
                    self.assertEqual(type(model.features[i][0]), nniqat.ConvBnReLU2d,
                                     msg="Fused submodule Conv + BN + ReLU")
                    self.assertEqual(type(model.features[i][1]), nn.Identity,
                                     msg="Fused submodule (skipped BN)")
                    self.assertEqual(type(model.features[i][2]), nn.Identity,
                                     msg="Fused submodule (skipped ReLU)")
                self.assertEqual(type(model.classifier[0]), nniqat.LinearReLU)
                self.assertEqual(type(model.classifier[1]), nn.Identity)

            checkQAT(model)
            model(self.img_data_2d[1][0])
            convert(model, inplace=True)
            model(self.img_data_2d[1][0])
            self.checkModelWithSequentialQuantized(model)
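# The per-engine loop above relies on the test-suite helper
# override_quantized_engine. A rough standalone equivalent built only on the
# public backend switch is sketched below (the helper's exact behavior may
# differ):
import torch
from contextlib import contextmanager

@contextmanager
def _use_quantized_engine(engine):
    previous = torch.backends.quantized.engine
    torch.backends.quantized.engine = engine  # e.g. "fbgemm" or "qnnpack"
    try:
        yield
    finally:
        torch.backends.quantized.engine = previous

for engine in torch.backends.quantized.supported_engines:
    if engine == "none":
        continue  # skip the placeholder backend
    with _use_quantized_engine(engine):
        pass  # fuse / prepare_qat / convert as in the test body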
def fuse_model(self) -> None:
    for m in self.modules():
        if type(m) is ConvNormActivation:
            fuse_modules(m, ["0", "1", "2"], inplace=True)
        if type(m) is QuantizableInvertedResidual:
            m.fuse_model()
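# Structural sketch of the recursive pattern above: fuse the Conv/BN/ReLU
# triple inside each ConvNormActivation-style block (children named "0", "1",
# "2", as in torchvision Sequential blocks) and let nested quantizable blocks
# fuse themselves. The duck-typed checks are an illustrative variant, not the
# torchvision implementation.
import torch.nn as nn
from torch.ao.quantization import fuse_modules

def fuse_conv_norm_act(root: nn.Module) -> None:
    for m in root.modules():
        children = dict(m.named_children())
        if (type(children.get("0")) is nn.Conv2d
                and type(children.get("1")) is nn.BatchNorm2d
                and type(children.get("2")) is nn.ReLU):
            fuse_modules(m, ["0", "1", "2"], inplace=True)
        elif hasattr(m, "fuse_model") and m is not root:
            m.fuse_model()  # delegate to blocks that define their own fusion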
def fuse_model(self) -> None:
    for idx in range(len(self.conv)):
        if type(self.conv[idx]) is nn.Conv2d:
            fuse_modules(self.conv, [str(idx), str(idx + 1)], inplace=True)
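# Note: the loop above assumes every nn.Conv2d in self.conv is immediately
# followed by its norm layer, so str(idx + 1) always names a fusable partner;
# fuse_modules raises for pairs without a registered fusion pattern.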