コード例 #1
0
 def fuse_model(self) -> None:
     """Fuse the ``fc1`` and ``activation`` submodules of this model in place."""
     # fuse_modules resolves these attribute names on ``self`` and rewrites them.
     pair_to_fuse = ["fc1", "activation"]
     fuse_modules(self, pair_to_fuse, inplace=True)
コード例 #2
0
ファイル: test_fusion.py プロジェクト: thiagocrepaldi/pytorch
    def test_fuse_module_eval(self):
        """Eval-mode fusion of ModelForFusion followed by post-training quantization.

        Fuses Conv1d+BN+ReLU, Conv2d+BN+ReLU, Conv3d+ReLU, BN3d+ReLU and a
        submodule Conv2d+BN. Checks the fused float module types, then the
        module types after prepare/calibrate/convert. Finally it checks them
        again after the one-shot ``quantize`` helper, with the fusion groups
        listed in a different order.
        """
        model = ModelForFusion(default_qconfig)
        model.eval()
        model = fuse_modules(model, [['conv3', 'bn3', 'relu4'],
                             ['conv1', 'bn1', 'relu1'],
                             ['conv2', 'relu2'],
                             ['bn2', 'relu3'],
                             ['sub1.conv', 'sub1.bn']])

        # Conv2d + BN + ReLU: BN is folded into the conv, leaving ConvReLU2d.
        self.assertEqual(type(model.conv1), nni.ConvReLU2d,
                         msg="Fused Conv + BN + Relu first layer (BN is folded)")
        self.assertEqual(type(model.conv1[0]), nn.Conv2d,
                         msg="Fused Conv + BN + Relu (Conv + folded BN only)")
        self.assertEqual(type(model.conv1[1]), nn.ReLU,
                         msg="Fused Conv + BN + Relu second layer (Relu only)")
        self.assertEqual(type(model.bn1), nn.Identity,
                         msg="Fused Conv + BN + Relu second layer (Skipped BN)")
        self.assertEqual(type(model.relu1), nn.Identity,
                         msg="Fused Conv + BN + Relu third layer (Skipped Relu)")

        # Conv3d + ReLU: this group has no BN.
        self.assertEqual(type(model.conv2), nni.ConvReLU3d,
                         msg="Fused Conv + Relu first layer")
        self.assertEqual(type(model.conv2[0]), nn.Conv3d,
                         msg="Fused Conv + Relu (Conv only)")
        self.assertEqual(type(model.conv2[1]), nn.ReLU,
                         msg="Fused Conv + Relu second layer (Relu only)")
        self.assertEqual(type(model.relu2), nn.Identity,
                         msg="Fused Conv + Relu second layer (Skipped Relu)")

        # BatchNorm3d + ReLU.
        self.assertEqual(type(model.bn2), nni.BNReLU3d,
                         msg="Fused BN + Relu first layer (Relu is folded)")
        self.assertEqual(type(model.relu3), nn.Identity,
                         msg="Fused BN + Relu second layer (Skipped Relu)")

        # Conv1d + BN + ReLU.
        self.assertEqual(type(model.conv3), nni.ConvReLU1d,
                         msg="Fused Conv + BN + Relu for Conv1d (BN is folded)")
        self.assertEqual(type(model.conv3[0]), nn.Conv1d,
                         msg="Fused Conv + BN + Relu for Conv1d (Conv + folded BN only)")
        self.assertEqual(type(model.conv3[1]), nn.ReLU,
                         msg="Fused Conv + BN + Relu for Conv1d (Relu only)")
        self.assertEqual(type(model.bn3), nn.Identity,
                         msg="Fused Conv + BN + Relu for Conv1d (Skipped BN)")
        self.assertEqual(type(model.relu4), nn.Identity,
                         msg="Fused Conv + BN + Relu for Conv1d (Skipped Relu)")

        # sub1 is fused Conv + folded BN; sub2 was not in the fusion list.
        self.assertEqual(type(model.sub1.conv), nn.Conv2d,
                         msg="Fused submodule Conv + folded BN")
        self.assertEqual(type(model.sub1.bn), nn.Identity,
                         msg="Fused submodule (skipped BN)")
        self.assertEqual(type(model.sub2.conv), nn.Conv2d,
                         msg="Non-fused submodule Conv")
        self.assertEqual(type(model.sub2.relu), torch.nn.ReLU,
                         msg="Non-fused submodule ReLU")

        model = prepare(model)
        self.checkObservers(model)
        # Calibrate the observers before conversion.
        test_only_eval_fn(model, self.img_data_1d)
        model = convert(model)

        def checkQuantized(model):
            # Conv1d + BN + ReLU group.
            self.assertEqual(type(model.conv3), nniq.ConvReLU1d)
            self.assertEqual(type(model.bn3), nn.Identity)
            self.assertEqual(type(model.relu4), nn.Identity)
            # Conv2d + BN + ReLU group.
            self.assertEqual(type(model.conv1), nniq.ConvReLU2d)
            self.assertEqual(type(model.bn1), nn.Identity)
            self.assertEqual(type(model.relu1), nn.Identity)
            # Conv3d + ReLU and BN3d + ReLU groups.
            self.assertEqual(type(model.conv2), nniq.ConvReLU3d)
            self.assertEqual(type(model.relu2), nn.Identity)
            self.assertEqual(type(model.bn2), nniq.BNReLU3d)
            self.assertEqual(type(model.relu3), nn.Identity)
            # Fused submodule.
            self.assertEqual(type(model.sub1.conv), nnq.Conv2d)
            self.assertEqual(type(model.sub1.bn), nn.Identity)
            # sub2 was not fused and is expected to remain float after conversion.
            self.assertEqual(type(model.sub2.conv), nn.Conv2d)
            self.assertEqual(type(model.sub2.relu), nn.ReLU)
            test_only_eval_fn(model, self.img_data_1d)
            self.checkNoQconfig(model)

        checkQuantized(model)

        # Same fusion groups in a different order, quantized in one shot.
        model = ModelForFusion(default_qconfig).eval()
        model = fuse_modules(model, [['conv1', 'bn1', 'relu1'],
                             ['conv2', 'relu2'],
                             ['bn2', 'relu3'],
                             ['sub1.conv', 'sub1.bn'],
                             ['conv3', 'bn3', 'relu4']])
        model = quantize(model, test_only_eval_fn, [self.img_data_1d])
        checkQuantized(model)
コード例 #3
0
ファイル: test_fusion.py プロジェクト: xsacha/pytorch
    def test_fuse_module_train(self):
        """Train-mode, step-by-step fusion of ModelForFusion, then QAT.

        Checks the fused float types and the QAT types after ``prepare_qat``.
        It then checks that running the converted model raises a RuntimeError
        about ``aten::native_batch_norm`` on QuantizedCPU. Presumably this is
        one of the BatchNorm layers left unfused here receiving a quantized
        tensor (TODO confirm). The same failure is expected from the one-shot
        ``quantize_qat`` path.
        """
        def resolve(root, path):
            # Walk a dotted attribute path such as "sub1.conv".
            for name in path.split('.'):
                root = getattr(root, name)
            return root

        def expect_types(root, expectations):
            # Assert the type of each (path, expected, msg) entry, in order.
            for path, expected, message in expectations:
                self.assertEqual(type(resolve(root, path)), expected, msg=message)

        def check_qat(qat_model):
            expect_types(qat_model, [
                ('conv1', nniqat.ConvBnReLU2d, None),
                ('bn1', nn.Identity, None),
                ('relu1', nn.Identity, None),
                ('sub1.conv', nniqat.ConvBn2d, None),
                ('sub1.bn', nn.Identity, None),
                ('sub2.conv', nn.Conv2d, None),
                ('sub2.relu', nn.ReLU, None),
            ])

        def check_quantized(quantized_model):
            expect_types(quantized_model, [
                ('conv1', nniq.ConvReLU2d, None),
                ('bn1', nn.Identity, None),
                ('relu1', nn.Identity, None),
                ('sub1.conv', nnq.Conv2d, None),
                ('sub1.bn', nn.Identity, None),
                ('sub2.conv', nn.Conv2d, None),
                ('sub2.relu', nn.ReLU, None),
            ])
            test_only_eval_fn(quantized_model, self.img_data_1d)
            self.checkNoQconfig(quantized_model)

        def expect_native_bn_failure(quantized_model):
            with self.assertRaisesRegex(
                    RuntimeError,
                    "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"
            ):
                check_quantized(quantized_model)

        # Step-by-step fusion: one fuse_modules call per group.
        stepwise = ModelForFusion(default_qat_qconfig).train()
        stepwise = fuse_modules(stepwise, ['conv1', 'bn1', 'relu1'])
        stepwise = fuse_modules(stepwise, ['sub1.conv', 'sub1.bn'])
        expect_types(stepwise, [
            ('conv1', nni.ConvBnReLU2d, "Fused Conv + BN + Relu first layer"),
            ('bn1', torch.nn.Identity, "Fused Conv + BN + Relu (skipped BN)"),
            ('relu1', torch.nn.Identity, "Fused Conv + BN + Relu (skipped Relu)"),
            ('sub1.conv', nni.ConvBn2d, "Fused submodule Conv + BN"),
            ('sub1.bn', torch.nn.Identity, "Fused submodule Conv + BN (skipped BN)"),
            ('sub2.conv', torch.nn.Conv2d, "Non-fused submodule Conv"),
            ('sub2.relu', torch.nn.ReLU, "Non-fused submodule ReLU"),
        ])
        stepwise = prepare_qat(stepwise)
        self.checkObservers(stepwise)
        check_qat(stepwise)
        test_only_train_fn(stepwise, self.img_data_1d_train)
        expect_native_bn_failure(convert(stepwise))

        # One-shot path: fuse both groups in one call, then quantize_qat.
        one_shot = ModelForFusion(default_qat_qconfig).train()
        one_shot = fuse_modules(
            one_shot, [['conv1', 'bn1', 'relu1'], ['sub1.conv', 'sub1.bn']])
        expect_native_bn_failure(
            quantize_qat(one_shot, test_only_train_fn, [self.img_data_1d_train]))
コード例 #4
0
ファイル: test_fusion.py プロジェクト: xsacha/pytorch
    def test_fusion_sequential_model_train(self):
        """Train-mode fusion of ModelWithSequentialFusion, then QAT and convert.

        Runs once per supported quantized engine. Checks the fused float module
        types, the QAT module types after ``prepare_qat`` and a forward pass,
        and finally the converted model.
        """
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = ModelWithSequentialFusion().train()
                model.to(torch.float)
                fuse_modules(model,
                             [['conv1', 'relu1'],
                              ['features.0.0', 'features.0.1', 'features.0.2'],
                              ['features.1.0', 'features.1.1', 'features.1.2'],
                              ['features.2.0', 'features.2.1', 'features.2.2'],
                              ['classifier.0', 'classifier.1']],
                             inplace=True)
                self.assertEqual(type(model.conv1),
                                 nni.ConvReLU2d,
                                 msg="Fused Conv + Relu: nni.ConvReLU2d")
                self.assertEqual(type(model.conv1[0]),
                                 nn.Conv2d,
                                 msg="Fused Conv + Relu: Conv2d")
                self.assertEqual(type(model.conv1[1]),
                                 nn.ReLU,
                                 msg="Fused Conv + Relu: Relu")
                self.assertEqual(type(model.relu1),
                                 nn.Identity,
                                 msg="Fused Conv + Relu: Identity")
                for i in range(3):
                    # In train mode each block is expected to become
                    # ConvBnReLU2d (BN kept as a stage, not folded).
                    self.assertEqual(type(model.features[i][0]),
                                     nni.ConvBnReLU2d,
                                     msg="Fused submodule Conv + BN + Relu")
                    self.assertEqual(type(model.features[i][1]),
                                     nn.Identity,
                                     msg="Fused submodule (skipped BN)")
                    self.assertEqual(type(model.features[i][2]),
                                     nn.Identity,
                                     msg="Fused submodule (skipped Relu)")
                self.assertEqual(type(model.classifier[0]), nni.LinearReLU)
                self.assertEqual(type(model.classifier[1]), nn.Identity)
                model.qconfig = torch.ao.quantization.get_default_qat_qconfig(
                    qengine)
                prepare_qat(model, inplace=True)
                self.checkObservers(model)
                model(self.img_data_2d[0][0])

                def checkQAT(qat_model):
                    # All QAT-stage type checks live here (features and
                    # classifier included) so the helper checks the whole model.
                    self.assertEqual(type(qat_model.conv1), nniqat.ConvReLU2d)
                    self.assertEqual(type(qat_model.relu1), nn.Identity)
                    for i in range(3):
                        self.assertEqual(type(qat_model.features[i][0]),
                                         nniqat.ConvBnReLU2d,
                                         msg="Fused submodule Conv + BN + Relu")
                        self.assertEqual(type(qat_model.features[i][1]),
                                         nn.Identity,
                                         msg="Fused submodule (skipped BN)")
                        self.assertEqual(type(qat_model.features[i][2]),
                                         nn.Identity,
                                         msg="Fused submodule (skipped Relu)")
                    self.assertEqual(type(qat_model.classifier[0]),
                                     nniqat.LinearReLU)
                    self.assertEqual(type(qat_model.classifier[1]), nn.Identity)

                checkQAT(model)
                model(self.img_data_2d[1][0])
                convert(model, inplace=True)
                model(self.img_data_2d[1][0])
                self.checkModelWithSequentialQuantized(model)
コード例 #5
0
 def fuse_model(self) -> None:
     """Fuse eligible submodules of this model in place.

     Visits every module yielded by ``self.modules()``. Modules whose exact
     type is ``ConvNormActivation`` have their children ``"0"``, ``"1"`` and
     ``"2"`` fused. Modules whose exact type is ``QuantizableInvertedResidual``
     are asked to fuse themselves via their own ``fuse_model``.
     """
     # NOTE(review): exact ``type(...) is`` matching means subclasses are not
     # matched; presumably intentional — confirm before switching to isinstance.
     for module in self.modules():
         module_type = type(module)
         if module_type is ConvNormActivation:
             fuse_modules(module, ["0", "1", "2"], inplace=True)
         elif module_type is QuantizableInvertedResidual:
             module.fuse_model()
コード例 #6
0
 def fuse_model(self) -> None:
     """Fuse each exact ``nn.Conv2d`` in ``self.conv`` with the module right after it.

     Fusion happens in place on ``self.conv`` via ``fuse_modules``, which
     addresses children by their string position (``"0"``, ``"1"``, ...).
     NOTE(review): a ``nn.Conv2d`` in the last slot would name a missing
     successor; presumably ``fuse_modules`` then fails — not handled here.
     """
     # Index-based on purpose: children are addressed by position name and
     # self.conv is rewritten in place while we walk it.
     for position in range(len(self.conv)):
         if type(self.conv[position]) is not nn.Conv2d:
             continue
         pair = [str(position), str(position + 1)]
         fuse_modules(self.conv, pair, inplace=True)