def test_fuse_module_train(self):
    """Exercise module fusion on a train-mode model through the QAT flow.

    Phase one fuses ``conv1/bn1/relu1`` and ``sub1.conv/sub1.bn`` one group
    at a time and checks the module types after fusion, after
    ``prepare_qat`` and after ``convert``.  Phase two rebuilds the model,
    fuses both groups with a single ``fuse_modules`` call, runs
    ``quantize_qat`` and applies the same post-conversion checks.
    """

    def expect_type(module, expected, *note):
        # Exact type comparison; an optional failure note is forwarded
        # positionally, exactly as the assertion would receive it inline.
        self.assertEqual(type(module), expected, *note)

    def verify_qat(net):
        # Module types expected once prepare_qat has been applied.
        expect_type(net.conv1, nniqat.ConvBnReLU2d)
        expect_type(net.bn1, nn.Identity)
        expect_type(net.relu1, nn.Identity)
        expect_type(net.sub1.conv, nniqat.ConvBn2d)
        expect_type(net.sub1.bn, nn.Identity)
        expect_type(net.sub2.conv, nn.Conv2d)
        expect_type(net.sub2.relu, nn.ReLU)

    def verify_quantized(net):
        # Module types expected after conversion, followed by an eval pass
        # over the image data.
        expect_type(net.conv1, nniq.ConvReLU2d)
        expect_type(net.bn1, nn.Identity)
        expect_type(net.relu1, nn.Identity)
        expect_type(net.sub1.conv, nnq.Conv2d)
        expect_type(net.sub1.bn, nn.Identity)
        expect_type(net.sub2.conv, nn.Conv2d)
        expect_type(net.sub2.relu, nn.ReLU)
        test_only_eval_fn(net, self.img_data)

    # --- Phase one: step-by-step fusion ---------------------------------
    model = ModelForFusion(default_qat_qconfig).train()
    for group in (['conv1', 'bn1', 'relu1'], ['sub1.conv', 'sub1.bn']):
        model = fuse_modules(model, group)

    expect_type(model.conv1, nni.ConvBnReLU2d,
                "Fused Conv + BN + Relu first layer")
    expect_type(model.bn1, torch.nn.Identity,
                "Fused Conv + BN + Relu (skipped BN)")
    expect_type(model.relu1, torch.nn.Identity,
                "Fused Conv + BN + Relu (skipped Relu)")

    expect_type(model.sub1.conv, nni.ConvBn2d,
                "Fused submodule Conv + BN")
    expect_type(model.sub1.bn, torch.nn.Identity,
                "Fused submodule Conv + BN (skipped BN)")
    # sub2 is not named in any fusion group, so its modules keep their types.
    expect_type(model.sub2.conv, torch.nn.Conv2d,
                "Non-fused submodule Conv")
    expect_type(model.sub2.relu, torch.nn.ReLU,
                "Non-fused submodule ReLU")

    model = prepare_qat(model)
    self.checkObservers(model)
    verify_qat(model)

    test_only_train_fn(model, self.img_data)
    model = convert(model)
    verify_quantized(model)

    # --- Phase two: single-call fusion followed by quantize_qat ---------
    model = ModelForFusion(default_qat_qconfig).train()
    model = fuse_modules(
        model,
        [['conv1', 'bn1', 'relu1'], ['sub1.conv', 'sub1.bn']],
    )
    model = quantize_qat(model, test_only_train_fn, self.img_data)
    verify_quantized(model)
def test_fuse_module_eval(self):
    """Exercise module fusion on an eval-mode model through post-training quantization.

    Phase one fuses ``conv1/bn1/relu1`` and ``sub1.conv/sub1.bn``, checks
    the module types after fusion, then runs ``prepare``, an eval pass and
    ``convert`` and checks the converted module types.  Phase two rebuilds
    the model, fuses it the same way and runs ``quantize`` before applying
    the same post-conversion checks.
    """

    def expect_type(module, expected, *note):
        # Exact type comparison; an optional failure note is forwarded
        # positionally, exactly as the assertion would receive it inline.
        self.assertEqual(type(module), expected, *note)

    def verify_quantized(net):
        # Module types expected after conversion, followed by an eval pass
        # over the image data.
        expect_type(net.conv1, nniq.ConvReLU2d)
        expect_type(net.bn1, nn.Identity)
        expect_type(net.relu1, nn.Identity)
        expect_type(net.sub1.conv, nnq.Conv2d)
        expect_type(net.sub1.bn, nn.Identity)
        expect_type(net.sub2.conv, nn.Conv2d)
        expect_type(net.sub2.relu, nn.ReLU)
        test_only_eval_fn(net, self.img_data)

    # --- Phase one: fuse, then prepare / calibrate / convert ------------
    model = ModelForFusion(default_qconfig)
    model.eval()
    # The return value is discarded, so every check below relies on
    # `model` itself reflecting the fusion.
    fuse_modules(model, [['conv1', 'bn1', 'relu1'],
                         ['sub1.conv', 'sub1.bn']])

    expect_type(model.conv1, nni.ConvReLU2d,
                "Fused Conv + BN + Relu first layer (BN is folded)")
    expect_type(model.conv1[0], nn.Conv2d,
                "Fused Conv + BN + Relu (Conv + folded BN only)")
    expect_type(model.conv1[1], nn.ReLU,
                "Fused Conv + BN + Relu second layer (Relu only)")
    expect_type(model.bn1, nn.Identity,
                "Fused Conv + BN + Relu second layer (Skipped BN)")
    expect_type(model.relu1, nn.Identity,
                "Fused Conv + BN + Relu second layer (Skipped Relu)")

    expect_type(model.sub1.conv, nn.Conv2d,
                "Fused submodule Conv + folded BN")
    expect_type(model.sub1.bn, nn.Identity,
                "Fused submodule (skipped BN)")
    # sub2 is not named in any fusion group, so its modules keep their types.
    expect_type(model.sub2.conv, nn.Conv2d,
                "Non-fused submodule Conv")
    expect_type(model.sub2.relu, torch.nn.ReLU,
                "Non-fused submodule ReLU")

    # As with fuse_modules above, results of prepare/convert are not
    # rebound; the same `model` object is inspected afterwards.
    prepare(model)
    self.checkObservers(model)
    test_only_eval_fn(model, self.img_data)
    convert(model)
    verify_quantized(model)

    # --- Phase two: fuse, then the one-call quantize helper -------------
    model = ModelForFusion(default_qat_qconfig).eval()
    fuse_modules(model, [['conv1', 'bn1', 'relu1'],
                         ['sub1.conv', 'sub1.bn']])
    model = quantize(model, test_only_eval_fn, self.img_data)
    verify_quantized(model)