Beispiel #1
0
    def test_fuse_module_eval(self):
        """Check that fuse_modules folds BN into the conv when in eval mode.

        conv1/bn1/relu1 are fused into a ConvReLU2d whose children are a
        plain Conv2d (BN folded in) and a ReLU. sub1.conv/sub1.bn are fused
        into a single Conv2d. The slots whose modules were absorbed are
        replaced by Identity. sub2 is not listed for fusion and must keep
        its original Conv2d/BatchNorm2d.
        """
        import torch.nn._intrinsic.modules.fused as torch_fused
        testMod = ModForFusion()
        testMod.eval()
        fuse_modules(testMod,
                     [['conv1', 'bn1', 'relu1'], ['sub1.conv', 'sub1.bn']])

        # Top-level Conv + BN + ReLU: BN is folded, so only Conv and ReLU remain.
        self.assertEqual(type(testMod.conv1), torch_fused.ConvReLU2d,
                         "Fused Conv + BN + Relu first layer (BN is folded)")
        self.assertEqual(type(testMod.conv1[0]), torch.nn.Conv2d,
                         "Fused Conv + BN + Relu (Conv + folded BN only)")
        self.assertEqual(type(testMod.conv1[1]), torch.nn.ReLU,
                         "Fused Conv + BN + Relu second layer (Relu only)")
        # The absorbed modules' original slots become no-ops.
        self.assertEqual(type(testMod.bn1), torch.nn.Identity,
                         "Fused Conv + BN + Relu (skipped BN)")
        self.assertEqual(type(testMod.relu1), torch.nn.Identity,
                         "Fused Conv + BN + Relu (skipped Relu)")

        # Submodule Conv + BN: collapses to a single Conv2d with BN folded in.
        self.assertEqual(type(testMod.sub1.conv), torch.nn.Conv2d,
                         "Fused submodule Conv + folded BN")
        self.assertEqual(type(testMod.sub1.bn), torch.nn.Identity,
                         "Fused submodule Conv + BN (skipped BN)")
        # sub2 was not named in the fusion list and must be left unchanged.
        self.assertEqual(type(testMod.sub2.conv), torch.nn.Conv2d,
                         "Non-fused submodule Conv")
        self.assertEqual(type(testMod.sub2.bn), torch.nn.BatchNorm2d,
                         "Non-fused submodule BN")
Beispiel #2
0
    def test_fuse_module_train(self):
        """Check that fuse_modules keeps BN as a separate child in train mode.

        conv1/bn1/relu1 are fused into a ConvBnReLU2d container holding the
        original Conv2d, BatchNorm2d and ReLU. sub1.conv/sub1.bn are fused
        into a ConvBn2d holding Conv2d and BatchNorm2d. The slots whose
        modules were absorbed are replaced by Identity. sub2 is not listed
        for fusion and must keep its original Conv2d/BatchNorm2d.
        """
        import torch.nn._intrinsic.modules.fused as torch_fused
        testMod = ModForFusion()
        testMod.train()
        fuse_modules(testMod,
                     [['conv1', 'bn1', 'relu1'], ['sub1.conv', 'sub1.bn']])

        # Top-level Conv + BN + ReLU: BN is NOT folded in train mode, so the
        # fused container must still hold all three original module types.
        self.assertEqual(type(testMod.conv1), torch_fused.ConvBnReLU2d,
                         "Fused Conv + BN + Relu first layer")
        self.assertEqual(type(testMod.conv1[0]), torch.nn.Conv2d,
                         "Fused Conv + BN + Relu first layer (Conv)")
        self.assertEqual(type(testMod.conv1[1]), torch.nn.BatchNorm2d,
                         "Fused Conv + BN + Relu second layer (BN)")
        self.assertEqual(type(testMod.conv1[2]), torch.nn.ReLU,
                         "Fused Conv + BN + Relu third layer (Relu)")
        # The absorbed modules' original slots become no-ops.
        self.assertEqual(type(testMod.bn1), torch.nn.Identity,
                         "Fused Conv + BN + Relu (skipped BN)")
        self.assertEqual(type(testMod.relu1), torch.nn.Identity,
                         "Fused Conv + BN + Relu (skipped Relu)")

        # Submodule Conv + BN: fused container keeps both children.
        self.assertEqual(type(testMod.sub1.conv), torch_fused.ConvBn2d,
                         "Fused submodule Conv + BN")
        self.assertEqual(type(testMod.sub1.conv[0]), torch.nn.Conv2d,
                         "Fused submodule Conv + BN first layer (Conv)")
        self.assertEqual(type(testMod.sub1.conv[1]), torch.nn.BatchNorm2d,
                         "Fused submodule Conv + BN second layer (BN)")
        self.assertEqual(type(testMod.sub1.bn), torch.nn.Identity,
                         "Fused submodule Conv + BN (skipped BN)")
        # sub2 was not named in the fusion list and must be left unchanged.
        self.assertEqual(type(testMod.sub2.conv), torch.nn.Conv2d,
                         "Non-fused submodule Conv")
        self.assertEqual(type(testMod.sub2.bn), torch.nn.BatchNorm2d,
                         "Non-fused submodule BN")