示例#1
0
    def test_qat_resnet_per_channel(self):
        """Export a QAT-prepared ResNet50 using per-channel weight fake-quant."""
        # Constant-valued input batch (allocated via randn, then filled with 1.0).
        sample = torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)
        net = resnet50()

        net.qconfig = quantization.QConfig(
            activation=quantization.default_fake_quant,
            weight=quantization.default_per_channel_weight_fake_quant,
        )
        quantization.prepare_qat(net, inplace=True)
        for switch_on in (torch.ao.quantization.enable_observer,
                          torch.ao.quantization.enable_fake_quant):
            net.apply(switch_on)

        # One forward pass with observers enabled, then compute the qparams
        # on every FakeQuantize module before turning the observers off.
        net(sample)
        fake_quants = (m for m in net.modules()
                       if isinstance(m, quantization.FakeQuantize))
        for fake_quant in fake_quants:
            fake_quant.calculate_qparams()
        net.apply(torch.ao.quantization.disable_observer)

        self.exportTest(toC(net), toC(sample))
    def test_eval_only_fake_quant(self):
        r"""Run a QAT-prepared model in evaluation mode only.

        Using FakeQuant this way is useful for estimating the accuracy loss
        when we quantize the network.
        """
        qat_model = prepare_qat(ManualLinearQATModel())
        self.checkObservers(qat_model)

        # No training step: switch straight to eval and run the eval helper.
        qat_model.eval()
        test_only_eval_fn(qat_model, self.calib_data)
示例#3
0
    def test_conv_linear(self):
        """QAT round trip for a conv + linear model, manually and via quantize_qat."""
        model = ManualConvLinearQATModel()

        # prepare_qat/convert default to inplace=False and return the
        # transformed module, so the result has to be rebound; otherwise the
        # checks below would inspect the untouched float model.
        model = prepare_qat(model)
        self.checkObservers(model)

        test_only_train_fn(model, self.img_data)
        model = convert(model)

        def checkQuantized(model):
            # All three layers must have been swapped for quantized modules,
            # and the result must still evaluate and be scriptable.
            self.assertEqual(type(model.conv), nnq.Conv2d)
            self.assertEqual(type(model.fc1), nnq.Linear)
            self.assertEqual(type(model.fc2), nnq.Linear)
            test_only_eval_fn(model, self.img_data)
            self.checkScriptable(model, self.img_data)

        checkQuantized(model)

        # Same flow through the one-shot quantize_qat helper.
        model = ManualConvLinearQATModel()
        model = quantize_qat(model, test_only_train_fn, self.img_data)
        checkQuantized(model)
示例#4
0
    def test_fuse_module_train(self):
        """Fuse modules in train mode, then run prepare_qat, training and convert."""

        def module_at(root, dotted_path):
            # Resolve a dotted attribute path such as 'sub1.conv'.
            for part in dotted_path.split('.'):
                root = getattr(root, part)
            return root

        def assert_types(root, expected):
            # Each entry is (path, expected_type) or (path, expected_type, message);
            # the message, when present, is forwarded positionally.
            for dotted_path, expected_type, *note in expected:
                self.assertEqual(type(module_at(root, dotted_path)),
                                 expected_type, *note)

        model = ModelForFusion(default_qat_qconfig).train()
        # Test step by step fusion
        for group in (['conv1', 'bn1', 'relu1'], ['sub1.conv', 'sub1.bn']):
            model = fuse_modules(model, group)
        assert_types(model, [
            ('conv1', nni.ConvBnReLU2d, "Fused Conv + BN + Relu first layer"),
            ('bn1', torch.nn.Identity, "Fused Conv + BN + Relu (skipped BN)"),
            ('relu1', torch.nn.Identity,
             "Fused Conv + BN + Relu (skipped Relu)"),
            ('sub1.conv', nni.ConvBn2d, "Fused submodule Conv + BN"),
            ('sub1.bn', torch.nn.Identity,
             "Fused submodule Conv + BN (skipped BN)"),
            ('sub2.conv', torch.nn.Conv2d, "Non-fused submodule Conv"),
            ('sub2.relu', torch.nn.ReLU, "Non-fused submodule ReLU"),
        ])
        model = prepare_qat(model)
        self.checkObservers(model)

        def checkQAT(model):
            # Fused modules become their QAT counterparts; the rest is untouched.
            assert_types(model, [
                ('conv1', nniqat.ConvBnReLU2d),
                ('bn1', nn.Identity),
                ('relu1', nn.Identity),
                ('sub1.conv', nniqat.ConvBn2d),
                ('sub1.bn', nn.Identity),
                ('sub2.conv', nn.Conv2d),
                ('sub2.relu', nn.ReLU),
            ])

        checkQAT(model)
        test_only_train_fn(model, self.img_data)
        model = convert(model)

        def checkQuantized(model):
            # After convert the fused modules are quantized ones; then run eval.
            assert_types(model, [
                ('conv1', nniq.ConvReLU2d),
                ('bn1', nn.Identity),
                ('relu1', nn.Identity),
                ('sub1.conv', nnq.Conv2d),
                ('sub1.bn', nn.Identity),
                ('sub2.conv', nn.Conv2d),
                ('sub2.relu', nn.ReLU),
            ])
            test_only_eval_fn(model, self.img_data)

        checkQuantized(model)

        # Second pass: fuse both groups in one call and use quantize_qat.
        model = ModelForFusion(default_qat_qconfig).train()
        model = fuse_modules(
            model, [['conv1', 'bn1', 'relu1'], ['sub1.conv', 'sub1.bn']])
        model = quantize_qat(model, test_only_train_fn, self.img_data)
        checkQuantized(model)
    def test_fixed_qparam_ops(self):
        """Check fake-quant handling of sigmoid / hardsigmoid / tanh modules."""
        fixed_qparam_attrs = ('sigmoid', 'hardsigmoid', 'tanh')

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.sigmoid = torch.nn.Sigmoid()
                self.hardsigmoid = torch.nn.Hardsigmoid()
                self.tanh = torch.nn.Tanh()
                self.quant = QuantStub()
                self.dequant = DeQuantStub()

            def forward(self, x):
                out = self.quant(x)
                for op in (self.sigmoid, self.hardsigmoid, self.tanh):
                    out = op(out)
                return self.dequant(out)

        def checkNoFQModule(m):
            # Neither a fake-quant attribute nor any forward hook may be present.
            for attr in fixed_qparam_attrs:
                submodule = getattr(m, attr)
                self.assertFalse(
                    hasattr(submodule, "activation_post_process"))
                self.assertTrue(len(submodule._forward_hooks.items()) == 0)

        # Train mode + QAT qconfig: each op gets a FixedQParamsFakeQuantize.
        m = M().train()
        m.qconfig = default_qat_qconfig
        m = prepare_qat(m)
        for attr in fixed_qparam_attrs:
            self.assertEqual(type(getattr(m, attr).activation_post_process),
                             FixedQParamsFakeQuantize)
        data = torch.randn(1, 3, 2, 4)
        before_convert = m(data)
        m = convert(m)
        after_convert = m(data)
        self.assertEqual(before_convert, after_convert)
        # make sure activation post process is removed
        for attr in fixed_qparam_attrs:
            converted = getattr(m, attr)
            # verify fake quant module is removed
            self.assertFalse(hasattr(converted, 'activation_post_process'))
            # verify that hooks are removed
            self.assertTrue(len(converted._forward_hooks.items()) == 0)

        # make sure no fake quantize module is inserted for eval mode
        m = M().eval()
        m.qconfig = default_qconfig
        m = prepare(m)
        checkNoFQModule(m)
        m = convert(m)
        checkNoFQModule(m)
def load_student_model(config,
                       *,
                       qat=False,
                       preload_state_dict=False,
                       preload_dir='results'):
    """Build the student model and optionally preload weights / prepare QAT.

    Parameters:
     - config:             passed through to get_model / get_qat_model and to
                           make_student_save_name
     - qat:                if True, build via get_qat_model and prepare the
                           model for quantization-aware training in place
     - preload_state_dict: if True, load a state dict from the checkpoint file
                           named by make_student_save_name(preload_dir, config)
                           with a '.pt' suffix
     - preload_dir:        first argument to make_student_save_name

    Returns: (student_model, student_model_name)
    """
    build_model = get_qat_model if qat else get_model
    student_model, student_model_name = build_model(config, pretrained=False)

    # The model is built with pretrained=False; weights are only preloaded
    # from a previously saved checkpoint when explicitly requested.
    if preload_state_dict:
        checkpoint_path = make_student_save_name(preload_dir, config) + '.pt'
        # map_location='cpu' lets a checkpoint saved on a GPU be read on a
        # CPU-only host; load_state_dict copies the values into the model's
        # existing parameters afterwards.
        # NOTE: torch.load unpickles the file - only load trusted checkpoints.
        state_dict = torch.load(checkpoint_path, map_location='cpu')
        student_model.load_state_dict(state_dict)

    if qat:
        # Weights are loaded before fusion/preparation changes the modules.
        if hasattr(student_model, 'fuse_model'):
            student_model.fuse_model()
        student_model.qconfig = tq.get_default_qat_qconfig('fbgemm')
        tq.prepare_qat(student_model, inplace=True)

    return student_model, student_model_name
    def test_eval_only_fake_quant(self):
        r"""Run a QAT-prepared model in evaluation mode only, per engine.

        Using FakeQuant this way is useful for estimating the accuracy loss
        when we quantize the network.
        """
        for engine in supported_qengines:
            with override_quantized_engine(engine):
                qat_model = prepare_qat(ManualLinearQATModel(engine))
                self.checkObservers(qat_model)

                # No training step: go straight to eval.
                qat_model.eval()
                test_only_eval_fn(qat_model, self.calib_data)
示例#8
0
    def test_manual(self):
        """QAT round trip for a linear model, manually and via quantize_qat."""
        model = ManualLinearQATModel()
        model = prepare_qat(model)
        self.checkObservers(model)
        test_only_train_fn(model, self.train_data)
        # convert returns the converted module (inplace defaults to False),
        # so rebind it instead of discarding the result.
        model = convert(model)

        def checkQuantized(model):
            # Both linear layers must be quantized; the model must still
            # evaluate and be scriptable.
            self.assertEqual(type(model.fc1), nnq.Linear)
            self.assertEqual(type(model.fc2), nnq.Linear)
            test_only_eval_fn(model, self.calib_data)
            self.checkScriptable(model, self.calib_data)

        checkQuantized(model)

        # Same flow through the one-shot quantize_qat helper.
        model = quantize_qat(ManualLinearQATModel(), test_only_train_fn, self.train_data)
        checkQuantized(model)
    def test_relu(self):
        """A bare nn.ReLU must pass through prepare_qat/convert unchanged."""
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.relu = nn.ReLU()

            def forward(self, x):
                x = self.relu(x)
                return x

        m = M().train()
        m.qconfig = default_qconfig
        m = prepare_qat(m)
        # make sure no activation_post_process is inserted, neither on the
        # wrapper module nor on the ReLU itself
        self.assertFalse(hasattr(m, "activation_post_process"))
        self.assertFalse(hasattr(m.relu, "activation_post_process"))
        m = convert(m)
        # make sure ReLU module is not changed.
        # assertEqual is required here: assertTrue(a, b) treats b as the
        # failure message and passes for any truthy a.
        self.assertEqual(type(m.relu), nn.ReLU)
    def test_forward_hooks_preserved(self):
        r"""Check that QAT keeps the pre-forward and forward hooks that were
        registered on the original model.
        """
        qengine = torch.backends.quantized.engine
        model = QuantStubModel()
        counter = {'pre_forwards': 0, 'forwards': 0}

        def count_pre_forward(h_module, input):
            counter['pre_forwards'] += 1

        def count_forward(h_module, input, output):
            counter['forwards'] += 1

        model.fc.register_forward_pre_hook(count_pre_forward)
        model.fc.register_forward_hook(count_forward)

        model.qconfig = torch.quantization.get_default_qat_qconfig(qengine)
        model = prepare_qat(model)

        def checkHooksIsPresent(model, before_convert=True):
            # Before convert, the quant stub carries one hook and fc carries
            # one extra forward hook on top of the user-registered one.
            if before_convert:
                self.assertEqual(len(model.quant._forward_hooks.values()), 1,
                                 "Quantization observer hook has disappeared")
            expected_forward_hooks = 2 if before_convert else 1
            fc = model.fc
            self.assertObjectIn(count_pre_forward,
                                fc._forward_pre_hooks.values())
            self.assertObjectIn(count_forward, fc._forward_hooks.values())
            self.assertEqual(
                len(fc._forward_pre_hooks.values()), 1,
                "Extra pre forward hooks have appeared on a layer")
            self.assertEqual(
                len(fc._forward_hooks.values()), expected_forward_hooks,
                "Extra post forward hooks have appeared on a layer")

        checkHooksIsPresent(model, True)
        model(torch.rand(2, 5, dtype=torch.float))
        torch.quantization.convert(model, inplace=True)
        checkHooksIsPresent(model, False)
    def test_manual(self):
        """QAT round trip for a linear model on every supported engine."""

        def checkQuantized(model):
            # Both linear layers quantized; model evaluates, scripts and
            # carries no leftover qconfig.
            self.assertEqual(type(model.fc1), nnq.Linear)
            self.assertEqual(type(model.fc2), nnq.Linear)
            test_only_eval_fn(model, self.calib_data)
            self.checkScriptable(model, self.calib_data)
            self.checkNoQconfig(model)

        for engine in supported_qengines:
            with override_quantized_engine(engine):
                # Manual flow: prepare, train, convert.
                manual = prepare_qat(ManualLinearQATModel(engine))
                self.checkObservers(manual)
                test_only_train_fn(manual, self.train_data)
                manual = convert(manual)
                checkQuantized(manual)

                # One-shot flow through quantize_qat.
                one_shot = quantize_qat(ManualLinearQATModel(engine),
                                        test_only_train_fn,
                                        [self.train_data])
                checkQuantized(one_shot)
    def _test_activation_convert_numerics_impl(self, Act, data):
        """Assert that wrapping ``Act`` in quant/dequant stubs gives the same
        output for ``data`` before and after convert.

        ``Act`` is called with no arguments to build the activation module.
        """
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.act = Act()
                self.quant = QuantStub()
                self.dequant = DeQuantStub()

            def forward(self, x):
                return self.dequant(self.act(self.quant(x)))

        qat_module = M().train()
        qat_module.qconfig = default_qat_qconfig
        qat_module = prepare_qat(qat_module)

        out_before = qat_module(data)
        qat_module = convert(qat_module)
        out_after = qat_module(data)
        self.assertEqual(out_before, out_after)
示例#13
0
    def _test_model_impl(self,
                         mode,
                         name,
                         model,
                         eager_quantizable_model,
                         check_with_eager=True,
                         diff_of_quant=None,
                         diff_from_eager=None):
        """Quantize ``model`` in FX graph mode and compare with eager mode.

        Parameters:
         - mode:  'static' uses default_qconfig, any other value uses
                  default_qat_qconfig; 'qat' trains for one epoch, 'ddp'
                  spawns run_ddp workers, everything else runs the prepared
                  module 10 times on the input
         - name:  model name; 'inception_v3' selects a 299x299 input instead
                  of 224x224, and the name is used as key in the result dicts
         - model: module that is symbolically traced and quantized
         - eager_quantizable_model: eager-mode counterpart, or None to skip
                  the eager comparison
         - check_with_eager: if True, assert that graph and eager results
                  match exactly
         - diff_of_quant / diff_from_eager: optional dicts that receive
                  ``[mode][name]`` max-abs-difference entries (only for
                  non-tuple outputs)
        """
        # Handle the two result dicts independently: a dict supplied by the
        # caller must never be replaced, and an existing [mode] entry must
        # never be cleared, just because the other dict lacks it.
        if diff_of_quant is None:
            diff_of_quant = {}
        if diff_from_eager is None:
            diff_from_eager = {}
        diff_of_quant.setdefault(mode, {})
        diff_from_eager.setdefault(mode, {})

        input_tensor = torch.rand(1, 3, 224, 224)
        input_tensor_inception = torch.rand(1, 3, 299, 299)
        output_value = torch.randint(0, 1, (1, ))

        # print('quantizing:', name, ' mode:', mode)
        if name == 'inception_v3':
            input_value = input_tensor_inception
        else:
            input_value = input_tensor

        qconfig = default_qconfig if mode == 'static' else default_qat_qconfig
        qconfig_dict = {'': qconfig}
        graph_module = symbolic_trace(model)
        # print('graph module:', graph_module.src)
        script = torch.jit.script(graph_module)

        # make sure graph module and script module are both runnable
        original_out = graph_module(input_value)
        is_not_tuple_out = not isinstance(original_out, tuple)
        script_out = script(input_value)
        self.assertEqual(
            (original_out - script_out).abs().max(), 0,
            'Reslut of original graph module and script module does not match')

        # set to train just before quantization
        # NOTE(review): this switches `model`, not the already traced
        # `graph_module` - confirm that is what is intended.
        if mode != 'static':
            model.train()

        graph_module = fuse_fx(graph_module)
        prepared = prepare_fx(graph_module, qconfig_dict)

        if mode == 'ddp':
            mp.spawn(run_ddp,
                     args=(world_size, prepared),
                     nprocs=world_size,
                     join=True)
        elif mode == 'qat':
            assert prepared.training, 'prepared must be in training mode for qat'
            optimizer = torch.optim.SGD(prepared.parameters(), lr=0.0001)
            # criterion is reused by the eager 'qat' branch further down.
            criterion = nn.CrossEntropyLoss()
            train_one_epoch(prepared, criterion,
                            optimizer, [(input_value, output_value)],
                            torch.device('cpu'), 1)
        else:
            # calibration
            for _ in range(10):
                prepared(input_value)

        # print('after observation root:', prepared.root)

        qgraph = convert_fx(prepared)
        # print('after quantization root:', qgraph.root)
        # print('after quantization code:', qgraph.src)
        qgraph.eval()
        qgraph_script = torch.jit.script(qgraph)
        # print('quantized and scripted:', qgraph_script.graph)

        qgraph_out = qgraph(input_value)
        qgraph_script_out = qgraph_script(input_value)

        if is_not_tuple_out:
            diff_of_quant[mode][name] = (original_out - qgraph_out).abs().max()
            assert torch.allclose(qgraph_out,
                                  qgraph_script_out), 'graph, scripted graph'
        else:
            print('tuple output')

        if eager_quantizable_model is not None:
            # comparing to eager mode quantization
            qeager = eager_quantizable_model
            # Forward pass on the float model before it is prepared; the
            # output itself is not used.
            qeager(input_value)
            qeager.qconfig = qconfig
            if mode == 'static':
                qeager.fuse_model()
                prepare(qeager, inplace=True)
            else:
                qeager.train()
                qeager.fuse_model()
                prepare_qat(qeager, inplace=True)

            # calibration
            if mode == 'ddp':
                mp.spawn(run_ddp,
                         args=(world_size, qeager),
                         nprocs=world_size,
                         join=True)
            elif mode == 'qat':
                assert qeager.training, 'qeager should be in training mode for qat'
                optimizer = torch.optim.SGD(qeager.parameters(), lr=0.0001)
                train_one_epoch(qeager, criterion, optimizer,
                                [(input_value, output_value)],
                                torch.device('cpu'), 1)
            else:
                for _ in range(10):
                    qeager(input_value)

            # print('ref after observation:', qeager)

            convert(qeager, inplace=True)
            qeager.eval()

            # print('ref after quantization:', qeager)
            qeager_out = qeager(input_value)
            qeager_script = torch.jit.script(qeager)
            # Run the scripted eager model once to make sure it executes.
            qeager_script(input_value)
            if is_not_tuple_out:
                diff_from_eager[mode][name] = (qeager_out -
                                               qgraph_out).abs().max()
                if check_with_eager:
                    self.assertEqual(
                        diff_from_eager[mode][name], 0,
                        'Result of graph mode quantization and ' +
                        'eager mode quantization on model: ' + name +
                        ' should match. Mode: ' + mode + ' diff:' +
                        str(diff_from_eager[mode][name]))
def train_subject_specific_quant(subject,
                                 epochs=500,
                                 batch_size=32,
                                 lr=0.001,
                                 silent=False,
                                 plot=True,
                                 **kwargs):
    """
    Trains a quantization-aware, subject specific model and converts it into
    a quantized model afterwards.

    Parameters:
     - subject:    subject identifier, passed to get_data and _train_net
     - epochs:     Number of epochs to train
     - batch_size: Batch size used for both the train and the test loader
     - lr:         Learning rate of the Adam optimizer
     - silent:     bool, if True, disable the progress bar and the final
                   accuracy print
     - plot:       bool, forwarded to _train_net
     - kwargs:     Remaining arguments passed to the EEGNetQuant model

    Returns: (model, metrics, history)
     - model:   the converted (quantized) model, on the CPU
     - metrics: result of get_metrics_from_model for the converted model on
                the test loader; metrics[0, 0] is reported as accuracy
     - history: fourth value returned by _train_net
    """
    # Data loaders for the training and the test split.
    train_samples, train_labels = get_data(subject, training=True)
    test_samples, test_labels = get_data(subject, training=False)
    loader_options = {'batch_size': batch_size}
    train_loader = as_data_loader(train_samples, train_labels,
                                  **loader_options)
    test_loader = as_data_loader(test_samples, test_labels, **loader_options)

    # Min/max observers: quint8 for activations, qint8 for weights.
    observer_qconfig = tq.QConfig(
        activation=tq.MinMaxObserver.with_args(dtype=t.quint8),
        weight=tq.MinMaxObserver.with_args(dtype=t.qint8))

    # Build the model, move it to the GPU when one is available, then insert
    # the QAT modules in place.
    model = EEGNetQuant(T=train_samples.shape[2],
                        qconfig=observer_qconfig,
                        **kwargs)
    model.initialize_params()
    if t.cuda.is_available():
        model = model.cuda()
    tq.prepare_qat(model, inplace=True)

    # Loss, optimizer and (absent) scheduler, followed by the setup summary.
    loss_function = t.nn.CrossEntropyLoss()
    optimizer = t.optim.Adam(model.parameters(), lr=lr, eps=1e-7)
    scheduler = None
    print_summary(model, optimizer, loss_function, scheduler)

    progress_options = dict(desc=f"Subject {subject}",
                            total=epochs,
                            leave=False,
                            disable=silent,
                            unit='epoch',
                            ascii=True)
    with tqdm(**progress_options) as pbar:
        # Early stopping is not allowed in this mode, because the testing
        # data cannot be used for training! The metrics returned here are
        # discarded; they are recomputed on the converted model below.
        model, _, _, history = _train_net(subject,
                                          model,
                                          train_loader,
                                          test_loader,
                                          loss_function,
                                          optimizer,
                                          scheduler=scheduler,
                                          epochs=epochs,
                                          early_stopping=False,
                                          plot=plot,
                                          pbar=pbar)

    # Convert on the CPU, then evaluate the quantized model.
    model = model.cpu()
    tq.convert(model, inplace=True)
    metrics = get_metrics_from_model(model, test_loader)

    if not silent:
        print(f"Subject {subject}: accuracy = {metrics[0, 0]}")
    return model, metrics, history
示例#15
0
from torchvision import models

# Pretrained ResNet18, switched to eval mode and moved to the GPU.
qat_resnet18 = models.resnet18(pretrained=True).eval().cuda()

# 1. per-tensor quantization or per-channel quantization
# (original note: tensorrt does not support per-channel quantization)
quantization_type = "per_tensor"
qat_resnet18.qconfig = quantization.QConfig(
    activation=quantization.default_fake_quant,
    weight=(quantization.default_weight_fake_quant
            if quantization_type == "per_tensor" else
            quantization.default_per_channel_weight_fake_quant))

# Insert the fake-quant modules in place and switch on observers + fake quant.
quantization.prepare_qat(qat_resnet18, inplace=True)
qat_resnet18.apply(quantization.enable_observer)
qat_resnet18.apply(quantization.enable_fake_quant)

# One forward pass on a random input, then compute the qparams on every
# FakeQuantize module and switch the observers off again.
dummy_input = torch.randn(1, 3, 224, 224).cuda()
_ = qat_resnet18(dummy_input)
for module in qat_resnet18.modules():
    if not isinstance(module, quantization.FakeQuantize):
        continue
    module.calculate_qparams()
qat_resnet18.apply(quantization.disable_observer)

qat_resnet18.cuda()

# enable_onnx_checker needs to be disabled because ONNX runtime doesn't support opset 13 yet
#torch.onnx.export(qat_resnet18, dummy_input, "resnet18_qat.onnx", verbose=True, opset_version=13, enable_onnx_checker=False)
if (quantization_type == "per_tensor"):
示例#16
0
文件: qnn_test.py 项目: wenxcs/tvm
def test_tuple_lowered():
    """Check that a quantized model returning two tensors is imported into
    Relay as a tuple of two dequantize ops with different scales.
    """
    # See the following discuss thread for details
    # https://discuss.tvm.apache.org/t/bug-frontend-pytorch-relay-ir-is-inconsistent-with-that-of-the-original-model/12010

    class ConvBnRelu(nn.Module):
        def __init__(self,
                     inp,
                     oup,
                     kernel_size=3,
                     stride=1,
                     padding=1,
                     bias=True,
                     groups=1):
            super().__init__()
            # With groups > 1 the convolution keeps `inp` output channels;
            # otherwise it maps `inp` to `oup`.
            out_channels = inp if groups > 1 else oup
            self.conv = nn.Conv2d(inp,
                                  out_channels,
                                  kernel_size,
                                  stride,
                                  padding,
                                  bias=bias,
                                  groups=groups)
            self.bn = nn.BatchNorm2d(out_channels)
            self.relu = nn.ReLU(inplace=True)

        def forward(self, inputs):
            x = self.conv(inputs)
            x = self.bn(x)
            x = self.relu(x)
            return x

    def conv_bn(inp, oup, stride=1, width_multiplier=1):
        # 3x3 conv block without bias; width_multiplier is accepted but unused.
        return ConvBnRelu(inp,
                          oup,
                          kernel_size=3,
                          stride=stride,
                          padding=1,
                          bias=False)

    def conv_dw(inp, oup, stride, width_multiplier=1, padding=1):
        # Grouped 3x3 block followed by a 1x1 block, registered by name.
        named_blocks = (
            ("depth_wise",
             ConvBnRelu(inp,
                        oup,
                        kernel_size=3,
                        stride=stride,
                        padding=padding,
                        bias=False,
                        groups=inp)),
            ("point_wise",
             ConvBnRelu(inp,
                        oup,
                        kernel_size=1,
                        stride=1,
                        padding=0,
                        bias=False)),
        )
        dw_block = nn.Sequential()
        for block_name, block in named_blocks:
            dw_block.add_module(block_name, block)
        return dw_block

    class Backbone(nn.Module):
        def __init__(self, width_multiplier=1):
            super().__init__()
            self.width_multiplier = width_multiplier
            self.conv1 = conv_bn(3, 16, 2, self.width_multiplier)
            self.conv2 = conv_dw(16, 32, 1, self.width_multiplier)

        def forward(self, inputs):
            x1 = self.conv1(inputs)
            x2 = self.conv2(x1)
            return [x1, x2]

    class QuantizableBackbone(nn.Module):
        def __init__(self, inputsize=(128, 128)):
            super().__init__()
            self.quant = QuantStub()
            self.dequant = DeQuantStub()
            self.backbone = Backbone()

        def fuse_model(self):
            # Fuse conv + bn + relu inside every ConvBnRelu block, in place.
            for m in self.modules():
                if isinstance(m, ConvBnRelu):
                    torch.quantization.fuse_modules(m, ["conv", "bn", "relu"],
                                                    inplace=True)

        def forward(self, input):
            input = self.quant(input)
            y0, y1 = self.backbone(input)
            y0 = self.dequant(y0)
            y1 = self.dequant(y1)
            return y0, y1

    fp32_input = torch.randn(1, 3, 128, 128)
    model = QuantizableBackbone()
    model.train()
    model.fuse_model()
    model.qconfig = get_default_qat_qconfig("qnnpack")
    prepare_qat(model, inplace=True)

    # Single forward pass in eval mode before converting.
    model.eval()
    model(fp32_input)

    model_int8 = torch.quantization.convert(model, inplace=True)
    script_module = torch.jit.trace(model_int8, fp32_input).eval()

    input_infos = [("input", (fp32_input.shape, "float32"))]
    mod, _ = relay.frontend.from_pytorch(script_module, input_infos)
    output = mod["main"].body

    # The body must be a 2-tuple of dequantize calls with distinct scales.
    assert isinstance(output, relay.Tuple) and len(output) == 2
    dq1, dq2 = output
    assert str(dq1.op) == "qnn.dequantize" and str(dq2.op) == "qnn.dequantize"
    scale1, scale2 = (dq.args[1].data.numpy().item() for dq in (dq1, dq2))
    assert scale1 != scale2
示例#17
0
    def test_fuse_module_train(self):
        """Fuse modules in train mode, run prepare_qat / convert, and expect
        evaluating the quantized model to raise the batch-norm RuntimeError.
        """
        unsupported_batch_norm = (
            "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"
        )

        def module_at(root, dotted_path):
            # Resolve a dotted attribute path such as 'sub1.conv'.
            for part in dotted_path.split('.'):
                root = getattr(root, part)
            return root

        def assert_types(root, expected):
            # Each entry is (path, expected_type) or (path, expected_type, message);
            # the message, when present, is forwarded as msg=.
            for dotted_path, expected_type, *note in expected:
                extra = {'msg': note[0]} if note else {}
                self.assertEqual(type(module_at(root, dotted_path)),
                                 expected_type, **extra)

        model = ModelForFusion(default_qat_qconfig).train()
        # Test step by step fusion
        for group in (['conv1', 'bn1', 'relu1'], ['sub1.conv', 'sub1.bn']):
            model = fuse_modules(model, group)
        assert_types(model, [
            ('conv1', nni.ConvBnReLU2d, "Fused Conv + BN + Relu first layer"),
            ('bn1', torch.nn.Identity, "Fused Conv + BN + Relu (skipped BN)"),
            ('relu1', torch.nn.Identity,
             "Fused Conv + BN + Relu (skipped Relu)"),
            ('sub1.conv', nni.ConvBn2d, "Fused submodule Conv + BN"),
            ('sub1.bn', torch.nn.Identity,
             "Fused submodule Conv + BN (skipped BN)"),
            ('sub2.conv', torch.nn.Conv2d, "Non-fused submodule Conv"),
            ('sub2.relu', torch.nn.ReLU, "Non-fused submodule ReLU"),
        ])
        model = prepare_qat(model)
        self.checkObservers(model)

        def checkQAT(model):
            # Fused modules become their QAT counterparts; the rest is untouched.
            assert_types(model, [
                ('conv1', nniqat.ConvBnReLU2d),
                ('bn1', nn.Identity),
                ('relu1', nn.Identity),
                ('sub1.conv', nniqat.ConvBn2d),
                ('sub1.bn', nn.Identity),
                ('sub2.conv', nn.Conv2d),
                ('sub2.relu', nn.ReLU),
            ])

        checkQAT(model)
        test_only_train_fn(model, self.img_data_1d_train)
        model = convert(model)

        def checkQuantized(model):
            # Type checks first, then evaluation and the qconfig check.
            assert_types(model, [
                ('conv1', nniq.ConvReLU2d),
                ('bn1', nn.Identity),
                ('relu1', nn.Identity),
                ('sub1.conv', nnq.Conv2d),
                ('sub1.bn', nn.Identity),
                ('sub2.conv', nn.Conv2d),
                ('sub2.relu', nn.ReLU),
            ])
            test_only_eval_fn(model, self.img_data_1d)
            self.checkNoQconfig(model)

        with self.assertRaisesRegex(RuntimeError, unsupported_batch_norm):
            checkQuantized(model)

        # Second pass: fuse both groups in one call and use quantize_qat.
        model = ModelForFusion(default_qat_qconfig).train()
        model = fuse_modules(
            model, [['conv1', 'bn1', 'relu1'], ['sub1.conv', 'sub1.bn']])
        model = quantize_qat(model, test_only_train_fn,
                             [self.img_data_1d_train])
        with self.assertRaisesRegex(RuntimeError, unsupported_batch_norm):
            checkQuantized(model)
示例#18
0
    def test_fusion_sequential_model_train(self):
        """Fuse a model with nested Sequential blocks in train mode and take
        it through prepare_qat and convert on every supported engine.
        """
        fusion_groups = [
            ['conv1', 'relu1'],
            ['features.0.0', 'features.0.1', 'features.0.2'],
            ['features.1.0', 'features.1.1', 'features.1.2'],
            ['features.2.0', 'features.2.1', 'features.2.2'],
            ['classifier.0', 'classifier.1'],
        ]

        def check_feature_blocks(net, fused_type):
            # In each of the three feature blocks, slot 0 holds the fused
            # module and slots 1 and 2 are Identity placeholders.
            for i in range(3):
                block = net.features[i]
                self.assertEqual(type(block[0]),
                                 fused_type,
                                 msg="Fused submodule Conv + folded BN")
                self.assertEqual(type(block[1]),
                                 nn.Identity,
                                 msg="Fused submodule (skipped BN)")
                self.assertEqual(type(block[2]),
                                 nn.Identity,
                                 msg="Non-fused submodule Conv")

        def checkQAT(model):
            self.assertEqual(type(model.conv1), nniqat.ConvReLU2d)
            self.assertEqual(type(model.relu1), nn.Identity)

        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = ModelWithSequentialFusion().train()
                model.to(torch.float)
                fuse_modules(model, fusion_groups, inplace=True)

                # State right after fusion.
                self.assertEqual(type(model.conv1),
                                 nni.ConvReLU2d,
                                 msg="Fused Conv + Relu: nni.ConvReLU2d")
                self.assertEqual(type(model.conv1[0]),
                                 nn.Conv2d,
                                 msg="Fused Conv + Relu: Conv2d")
                self.assertEqual(type(model.conv1[1]),
                                 nn.ReLU,
                                 msg="Fused Conv + Relu: Relu")
                self.assertEqual(type(model.relu1),
                                 nn.Identity,
                                 msg="Fused Conv + Relu: Identity")
                check_feature_blocks(model, nni.ConvBnReLU2d)
                self.assertEqual(type(model.classifier[0]), nni.LinearReLU)
                self.assertEqual(type(model.classifier[1]), nn.Identity)

                model.qconfig = torch.quantization.get_default_qat_qconfig(
                    qengine)
                prepare_qat(model, inplace=True)
                self.checkObservers(model)
                model(self.img_data_2d[0][0])

                # State after prepare_qat. The feature/classifier checks run
                # inline, before the conv1/relu1 checks in checkQAT.
                check_feature_blocks(model, nniqat.ConvBnReLU2d)
                self.assertEqual(type(model.classifier[0]), nniqat.LinearReLU)
                self.assertEqual(type(model.classifier[1]), nn.Identity)
                checkQAT(model)

                # Forward pass before and after the in-place convert.
                model(self.img_data_2d[1][0])
                convert(model, inplace=True)
                model(self.img_data_2d[1][0])
                self.checkModelWithSequentialQuantized(model)