Example #1
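These snippets are excerpted from PyTorch's eager-mode quantization test suite, so they assume a surrounding unittest.TestCase plus helpers such as ModelForFusion, test_only_train_fn and test_only_eval_fn from torch.testing._internal.common_quantization. The short aliases resolve roughly as follows (a sketch based on the current torch.ao module layout):

    import torch
    import torch.nn as nn
    import torch.ao.nn.intrinsic as nni
    import torch.ao.nn.intrinsic.qat as nniqat
    import torch.ao.nn.intrinsic.quantized as nniq
    import torch.ao.nn.quantized as nnq
    from torch.ao.quantization import (
        convert, default_qat_qconfig, default_qconfig, fuse_modules,
        fuse_modules_qat, prepare, prepare_qat, quantize, quantize_qat,
    )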
    def test_fuse_module_train(self):
        model = ModelForFusion(default_qat_qconfig).train()
        # Test step by step fusion
        model = fuse_modules_qat(model, ['conv1', 'bn1', 'relu1'])
        model = fuse_modules_qat(model, ['sub1.conv', 'sub1.bn'])
        self.assertEqual(type(model.conv1), nni.ConvBnReLU2d,
                         msg="Fused Conv + BN + Relu first layer")
        self.assertEqual(type(model.bn1), torch.nn.Identity,
                         msg="Fused Conv + BN + Relu (skipped BN)")
        self.assertEqual(type(model.relu1), torch.nn.Identity,
                         msg="Fused Conv + BN + Relu (skipped Relu)")

        self.assertEqual(type(model.sub1.conv), nni.ConvBn2d,
                         msg="Fused submodule Conv + BN")
        self.assertEqual(type(model.sub1.bn), torch.nn.Identity,
                         msg="Fused submodule Conv + BN (skipped BN)")
        self.assertEqual(type(model.sub2.conv), torch.nn.Conv2d,
                         msg="Non-fused submodule Conv")
        self.assertEqual(type(model.sub2.relu), torch.nn.ReLU,
                         msg="Non-fused submodule ReLU")
        model = prepare_qat(model)
        self.checkObservers(model)

        def checkQAT(model):
            self.assertEqual(type(model.conv1), nniqat.ConvBnReLU2d)
            self.assertEqual(type(model.bn1), nn.Identity)
            self.assertEqual(type(model.relu1), nn.Identity)
            self.assertEqual(type(model.sub1.conv), nniqat.ConvBn2d)
            self.assertEqual(type(model.sub1.bn), nn.Identity)
            self.assertEqual(type(model.sub2.conv), nn.Conv2d)
            self.assertEqual(type(model.sub2.relu), nn.ReLU)

        checkQAT(model)
        test_only_train_fn(model, self.img_data_1d_train)
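        # convert replaces the trained QAT modules with their quantized
        # counterparts (nniq.*/nnq.*), folding BN into the quantized conv.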
        model = convert(model)

        def checkQuantized(model):
            self.assertEqual(type(model.conv1), nniq.ConvReLU2d)
            self.assertEqual(type(model.bn1), nn.Identity)
            self.assertEqual(type(model.relu1), nn.Identity)
            self.assertEqual(type(model.sub1.conv), nnq.Conv2d)
            self.assertEqual(type(model.sub1.bn), nn.Identity)
            self.assertEqual(type(model.sub2.conv), nn.Conv2d)
            self.assertEqual(type(model.sub2.relu), nn.ReLU)
            test_only_eval_fn(model, self.img_data_1d)
            self.checkNoQconfig(model)

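        # Only conv1/bn1/relu1 and sub1 were fused above, so the converted
        # model still routes quantized tensors into a float batch norm.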
        with self.assertRaisesRegex(RuntimeError, "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"):
            checkQuantized(model)

        model = ModelForFusion(default_qat_qconfig).train()
        model = fuse_modules_qat(
            model,
            [['conv1', 'bn1', 'relu1'],
             ['sub1.conv', 'sub1.bn']])
        model = quantize_qat(model, test_only_train_fn, [self.img_data_1d_train])
        with self.assertRaisesRegex(RuntimeError, "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"):
            checkQuantized(model)
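For reference, here is the same QAT flow on a minimal standalone model. This is a sketch, not part of the test suite: TinyQATNet and its layer names are made up for illustration, and 'fbgemm' assumes an x86 backend.

    import torch
    import torch.nn as nn
    from torch.ao.quantization import (
        QuantStub, DeQuantStub, convert, fuse_modules_qat,
        get_default_qat_qconfig, prepare_qat,
    )

    class TinyQATNet(nn.Module):  # hypothetical model, for illustration only
        def __init__(self):
            super().__init__()
            self.quant = QuantStub()
            self.conv = nn.Conv2d(3, 8, 3)
            self.bn = nn.BatchNorm2d(8)
            self.relu = nn.ReLU()
            self.dequant = DeQuantStub()

        def forward(self, x):
            return self.dequant(self.relu(self.bn(self.conv(self.quant(x)))))

    model = TinyQATNet().train()
    model.qconfig = get_default_qat_qconfig('fbgemm')
    model = fuse_modules_qat(model, [['conv', 'bn', 'relu']])  # -> nni.ConvBnReLU2d
    model = prepare_qat(model)            # -> nniqat.ConvBnReLU2d with fake quant
    model(torch.randn(2, 3, 32, 32))      # stands in for an actual training loop
    model = convert(model.eval())         # -> nniq.ConvReLU2d, BN folded
    model(torch.randn(1, 3, 32, 32))      # now runs on quantized kernels

Unlike the test model above, this sketch fuses every float layer and wraps the forward pass in QuantStub/DeQuantStub, so the converted model runs end to end instead of raising the 'QuantizedCPU' error asserted in the test.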
Example #2
    def test_forward_hooks_preserved(self):
        r"""Test case that checks whether forward pre hooks of the first module and
        post forward hooks of the last module in modules list passed to fusion function preserved.
        (e.g. before fusion: [nn.Conv2d (with pre forward hooks), nn.BatchNorm2d, nn.ReLU (with post forward hooks)]
        after fusion: [nni.ConvBnReLU2d (with pre and post hooks), nn.Identity, nn.Identity])
        """
        model = ModelForFusion(default_qat_qconfig).train()

        counter = {
            'pre_forwards': 0,
            'forwards': 0,
        }
        fused = False

        def fw_pre_hook(fused_module_class, h_module, input):
            if fused:
                self.assertEqual(
                    type(h_module), fused_module_class,
                    "After fusion, the first module's forward pre-hook should belong to the fused module"
                )
            counter['pre_forwards'] += 1

        def fw_hook(fused_module_class, h_module, input, output):
            if fused:
                self.assertEqual(
                    type(h_module), fused_module_class,
                    "After fusion, the last module's forward hook should belong to the fused module"
                )
            counter['forwards'] += 1

        # Register two forward pre-hooks and two forward hooks, so each counter should grow by two per inference.
        model.conv1.register_forward_pre_hook(
            lambda *args: fw_pre_hook(nni.ConvBnReLU2d, *args))
        model.sub1.conv.register_forward_pre_hook(
            lambda *args: fw_pre_hook(nni.ConvBn2d, *args))
        model.relu1.register_forward_hook(
            lambda *args: fw_hook(nni.ConvBnReLU2d, *args))
        model.sub1.bn.register_forward_hook(
            lambda *args: fw_hook(nni.ConvBn2d, *args))

        test_only_eval_fn(model, self.img_data_1d)
        self.assertEqual(counter['pre_forwards'], 2 * len(self.img_data_1d))
        self.assertEqual(counter['forwards'], 2 * len(self.img_data_1d))

        # The model is in train mode, so use the QAT-aware fusion (it keeps BN,
        # producing the nni.ConvBnReLU2d / nni.ConvBn2d types asserted above).
        model = fuse_modules_qat(model, ['conv1', 'bn1', 'relu1'])
        model = fuse_modules_qat(model, ['sub1.conv', 'sub1.bn'])

        fused = True
        before_fusion_pre_count = counter['pre_forwards']
        before_fusion_post_count = counter['forwards']
        test_only_eval_fn(model, self.img_data_1d)
        self.assertEqual(counter['pre_forwards'] - before_fusion_pre_count,
                         2 * len(self.img_data_1d))
        self.assertEqual(counter['forwards'] - before_fusion_post_count,
                         2 * len(self.img_data_1d))
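The preserved-hook behavior can also be seen on a bare nn.Sequential. A minimal sketch, assuming eval-mode fusion (so the fused type is nni.ConvReLU2d) and using the Sequential's child names '0'..'2' as the fusion list:

    import torch
    import torch.nn as nn
    import torch.ao.nn.intrinsic as nni
    from torch.ao.quantization import fuse_modules

    net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()).eval()
    seen = []
    net[0].register_forward_pre_hook(lambda mod, inp: seen.append(type(mod)))
    net[2].register_forward_hook(lambda mod, inp, out: seen.append(type(mod)))

    net = fuse_modules(net, [['0', '1', '2']])
    net(torch.randn(1, 3, 16, 16))
    assert seen == [nni.ConvReLU2d, nni.ConvReLU2d]  # both hooks moved to the fused module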
Example #3
    def test_fuse_module_eval(self):
        model = ModelForFusion(default_qconfig)
        model.eval()
        model = fuse_modules(
            model,
            [['conv3', 'bn3', 'relu4'], ['conv1', 'bn1', 'relu1'],
             ['conv2', 'relu2'], ['bn2', 'relu3'], ['sub1.conv', 'sub1.bn']])
        self.assertEqual(
            type(model.conv1),
            nni.ConvReLU2d,
            msg="Fused Conv + BN + Relu first layer (BN is folded)")
        self.assertEqual(type(model.conv1[0]),
                         nn.Conv2d,
                         msg="Fused Conv + BN + Relu (Conv + folded BN only)")
        self.assertEqual(type(model.conv1[1]),
                         nn.ReLU,
                         msg="Fused Conv + BN + Relu second layer (Relu only)")
        self.assertEqual(
            type(model.bn1),
            nn.Identity,
            msg="Fused Conv + BN + Relu second layer (Skipped BN)")
        self.assertEqual(
            type(model.relu1),
            nn.Identity,
            msg="Fused Conv + BN + Relu second layer (Skipped Relu)")
        self.assertEqual(
            type(model.conv2),
            nni.ConvReLU3d,
            msg="Fused Conv + Relu for Conv3d (no BN in this fusion group)")
        self.assertEqual(type(model.bn2),
                         nni.BNReLU3d,
                         msg="Fused BN + Relu first layer (Relu is fused)")
        self.assertEqual(type(model.relu3),
                         nn.Identity,
                         msg="Fused BN + Relu second layer (Skipped Relu)")
        self.assertEqual(type(model.conv2[0]),
                         nn.Conv3d,
                         msg="Fused Conv + Relu (Conv only)")
        self.assertEqual(type(model.conv2[1]),
                         nn.ReLU,
                         msg="Fused Conv + Relu (Relu only)")
        self.assertEqual(
            type(model.relu2),
            nn.Identity,
            msg="Fused Conv + Relu (Skipped Relu)")

        self.assertEqual(type(model.conv3),
                         nni.ConvReLU1d,
                         msg="Fused Conv + Relu for Conv1d (folded BN)")
        self.assertEqual(type(model.conv3[0]),
                         nn.Conv1d,
                         msg="Fused Conv + Relu for Conv1d (Conv only)")
        self.assertEqual(type(model.conv3[1]),
                         nn.ReLU,
                         msg="Fused Conv + Relu for Conv1d (Relu only)")
        self.assertEqual(type(model.bn3),
                         nn.Identity,
                         msg="Fused Conv + BN + Relu for Conv1d (Skipped BN)")

        self.assertEqual(type(model.sub1.conv),
                         nn.Conv2d,
                         msg="Fused submodule Conv + folded BN")
        self.assertEqual(type(model.sub1.bn),
                         nn.Identity,
                         msg="Fused submodule (skipped BN)")
        self.assertEqual(type(model.sub2.conv),
                         nn.Conv2d,
                         msg="Non-fused submodule Conv")
        self.assertEqual(type(model.sub2.relu),
                         nn.ReLU,
                         msg="Non-fused submodule ReLU")

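        # Post-training flow: prepare attaches observers, the eval pass
        # calibrates them, and convert swaps in the quantized modules.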
        model = prepare(model)
        self.checkObservers(model)
        test_only_eval_fn(model, self.img_data_1d)
        model = convert(model)

        def checkQuantized(model):
            self.assertEqual(type(model.conv3), nniq.ConvReLU1d)
            self.assertEqual(type(model.conv1), nniq.ConvReLU2d)
            self.assertEqual(type(model.bn1), nn.Identity)
            self.assertEqual(type(model.relu1), nn.Identity)
            self.assertEqual(type(model.sub1.conv), nnq.Conv2d)
            self.assertEqual(type(model.sub1.bn), nn.Identity)
            self.assertEqual(type(model.sub2.conv), nn.Conv2d)
            self.assertEqual(type(model.sub2.relu), nn.ReLU)
            self.assertEqual(type(model.bn2), nniq.BNReLU3d)
            test_only_eval_fn(model, self.img_data_1d)
            self.checkNoQconfig(model)

        checkQuantized(model)

        model = ModelForFusion(default_qconfig).eval()
        model = fuse_modules(
            model,
            [['conv1', 'bn1', 'relu1'], ['conv2', 'relu2'], ['bn2', 'relu3'],
             ['sub1.conv', 'sub1.bn'], ['conv3', 'bn3', 'relu4']])
        model = quantize(model, test_only_eval_fn, [self.img_data_1d])
        checkQuantized(model)
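Eval-mode fusion folds BatchNorm into the preceding Conv weights, which is why the messages above say "BN is folded". A quick numerical check on a throwaway model (a sketch; shapes and tolerances are illustrative):

    import copy
    import torch
    import torch.nn as nn
    from torch.ao.quantization import fuse_modules

    net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()).eval()
    fused = fuse_modules(copy.deepcopy(net), [['0', '1', '2']])

    x = torch.randn(2, 3, 16, 16)
    # Folding is algebraic, so outputs should agree up to float rounding.
    torch.testing.assert_close(fused(x), net(x), rtol=1e-4, atol=1e-5)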