def test_fuse_module_eval(self):
    """Fusing in eval mode folds BN into Conv and replaces fused-away modules with Identity."""
    import torch.nn._intrinsic.modules.fused as torch_fused
    model = ModForFusion()
    model.eval()
    fuse_modules(model, [['conv1', 'bn1', 'relu1'], ['sub1.conv', 'sub1.bn']])
    # (module, expected type, failure message) — checked in the same order
    # the original flat assertions ran.
    expected = [
        (model.conv1, torch_fused.ConvReLU2d,
         "Fused Conv + BN + Relu first layer (BN is folded)"),
        (model.conv1[0], torch.nn.Conv2d,
         "Fused Conv + BN + Relu (Conv + folded BN only)"),
        (model.conv1[1], torch.nn.ReLU,
         "Fused Conv + BN + Relu second layer (Relu only)"),
        (model.bn1, torch.nn.Identity,
         "Fused Conv + BN + Relu second layer (Skipped BN)"),
        (model.relu1, torch.nn.Identity,
         "Fused Conv + BN + Relu second layer (Skipped Relu)"),
        (model.sub1.conv, torch.nn.Conv2d,
         "Fused submodule Conv + folded BN"),
        (model.sub1.bn, torch.nn.Identity,
         "Fused submodule (skipped BN)"),
        (model.sub2.conv, torch.nn.Conv2d,
         "Non-fused submodule Conv"),
        (model.sub2.bn, torch.nn.BatchNorm2d,
         "Non-fused submodule BN"),
    ]
    for module, expected_type, message in expected:
        self.assertEqual(type(module), expected_type, message)
def test_fuse_module_train(self):
    """Fusing in train mode keeps BN statistics live via ConvBn/ConvBnReLU fused modules."""
    import torch.nn._intrinsic.modules.fused as torch_fused
    model = ModForFusion()
    model.train()
    fuse_modules(model, [['conv1', 'bn1', 'relu1'], ['sub1.conv', 'sub1.bn']])
    # (module, expected type, failure message) — checked in the same order
    # the original flat assertions ran.
    expected = [
        (model.conv1, torch_fused.ConvBnReLU2d,
         "Fused Conv + BN + Relu first layer"),
        (model.bn1, torch.nn.Identity,
         "Fused Conv + BN + Relu (skipped BN)"),
        (model.relu1, torch.nn.Identity,
         "Fused Conv + BN + Relu (skipped Relu)"),
        (model.sub1.conv, torch_fused.ConvBn2d,
         "Fused submodule Conv + BN"),
        (model.sub1.bn, torch.nn.Identity,
         "Fused submodule Conv + BN (skipped BN)"),
        (model.sub2.conv, torch.nn.Conv2d,
         "Non-fused submodule Conv"),
        (model.sub2.bn, torch.nn.BatchNorm2d,
         "Non-fused submodule BN"),
    ]
    for module, expected_type, message in expected:
        self.assertEqual(type(module), expected_type, message)