def test_initialize_deactivate(self):
        """Check that initialize() swaps torch.nn layer classes for quantized ones.

        With Linear mapped explicitly to QuantLinear via ``custom_quant_modules``,
        constructing ``torch.nn.Linear`` / ``torch.nn.Conv2d`` while the patch is
        active must produce exactly the same types as ``quant_nn.QuantLinear`` /
        ``quant_nn.QuantConv2d``.
        """
        no_replace_list = ["Linear"]
        custom_quant_modules = [(torch.nn, "Linear", quant_nn.QuantLinear)]

        quant_modules.initialize(no_replace_list, custom_quant_modules)
        try:
            # Exact type identity (not isinstance): the patched torch.nn name
            # must construct the quantized class itself.
            assert type(quant_nn.QuantLinear(16, 256)) is type(torch.nn.Linear(16, 256))
            assert type(quant_nn.QuantConv2d(16, 256, 3)) is type(torch.nn.Conv2d(16, 256, 3))
        finally:
            # Always restore the original torch.nn classes so a failing assert
            # cannot leak the monkey patch into other tests.
            quant_modules.deactivate()
示例#2
0
    def test_quant_module_replacement(self):
        """Test monkey patching of modules with their quantized versions.

        ``LeNet`` and ``QuantLeNet`` are compared by the exact types of their
        submodules, with the root module skipped. The two lists must differ
        before patching. They must match while ``quant_modules`` patching is
        active. They must differ again after the patch is deactivated.
        """

        def _submodule_types(model):
            # modules() yields the root module first; drop it so only the
            # child layers are compared.
            return [type(mod) for mod in model.modules()][1:]

        # Before any monkey patching, the networks should be different
        assert _submodule_types(LeNet()) != _submodule_types(QuantLeNet())

        # Monkey patch the modules; Linear is mapped to QuantLinear explicitly
        # through custom_quant_modules.
        no_replace_list = ["Linear"]
        custom_quant_modules = [(torch.nn, "Linear", quant_nn.QuantLinear)]

        quant_modules.initialize(no_replace_list, custom_quant_modules)
        try:
            # After monkey patching, the networks should be the same
            assert _submodule_types(LeNet()) == _submodule_types(QuantLeNet())
        finally:
            # Reverse monkey patching even if the assertion above fails, so
            # patched torch.nn classes cannot leak into other tests.
            quant_modules.deactivate()

        # After reversing monkey patching, the networks should again be different
        assert _submodule_types(LeNet()) != _submodule_types(QuantLeNet())
    def test_asp(self):
        """Test the Sparsity (ASP) and QAT toolkits together.

        Builds a LeNet while quant_modules patching is active. Prepares the
        model and an SGD optimizer for pruning with apex ASP, using the
        ``m4n2_1d`` mask calculator. Then runs one forward/backward/optimizer
        step on CUDA.

        Skipped when apex ASP cannot be imported or no CUDA device is present.
        """
        try:
            from apex.contrib.sparsity import ASP
        except ImportError:
            pytest.skip("ASP is not available.")
        # The training step below runs on 'cuda'; skip instead of erroring.
        if not torch.cuda.is_available():
            pytest.skip("CUDA is not available.")

        quant_modules.initialize()
        try:
            model = LeNet()
        finally:
            # Undo the monkey patching even if model construction fails.
            quant_modules.deactivate()

        optimizer = optim.SGD(model.parameters(), lr=0.01)

        ASP.init_model_for_pruning(
            model,
            mask_calculator="m4n2_1d",
            verbosity=2,
            whitelist=[torch.nn.Linear, torch.nn.Conv2d, torch.nn.Conv3d, quant_nn.modules.quant_linear.QuantLinear],
            allow_recompute_mask=False,
            custom_layer_dict={
                quant_nn.QuantConv1d: ['weight'],
                quant_nn.QuantConv2d: ['weight'],
                quant_nn.QuantConv3d: ['weight'],
                quant_nn.QuantConvTranspose1d: ['weight'],
                quant_nn.QuantConvTranspose2d: ['weight'],
                quant_nn.QuantConvTranspose3d: ['weight'],
                quant_nn.QuantLinear: ['weight']
            })
        ASP.init_optimizer_for_pruning(optimizer)
        ASP.compute_sparse_masks()

        device = torch.device('cuda')
        model = model.to(device)
        # Use initialized random data rather than torch.empty garbage. Put the
        # targets on the same device as the model output, as nll_loss requires.
        inputs = torch.randn(16, 1, 28, 28, device=device)
        targets = torch.randint(10, (16,), dtype=torch.int64, device=device)

        output = model(inputs)
        optimizer.zero_grad()
        loss = F.nll_loss(output, targets)
        loss.backward()
        optimizer.step()
def prepare_model(model_name,
                  data_dir,
                  per_channel_quantization,
                  batch_size_train,
                  batch_size_test,
                  batch_size_onnx,
                  calibrator,
                  pretrained=True,
                  ckpt_path=None,
                  ckpt_url=None):
    """
    Prepare the model for the classification flow.
    Arguments:
        model_name: name to use when accessing torchvision model dictionary
        data_dir: directory with train and val subdirs prepared "imagenet style"
        per_channel_quantization: iff true use per channel quantization for weights
                                   note that this isn't currently supported in ONNX-RT/PyTorch
        batch_size_train: batch size to use when training
        batch_size_test: batch size to use when testing in PyTorch
        batch_size_onnx: batch size to use when testing with ONNX-RT
        calibrator: calibration type to use (max/histogram)

        pretrained: if true a pretrained model will be loaded from torchvision
        ckpt_path: path to load a model checkpoint from, if not pretrained
        ckpt_url: url to download a model checkpoint from, if not pretrained and no path was given
        * at least one of {pretrained, path, url} must be valid

    Raises:
        ValueError: if ``pretrained`` is false and neither ``ckpt_path`` nor
            ``ckpt_url`` is given.

    Returns a tuple:
        (
            model object (in eval mode, moved to CUDA),
            data loader for training,
            data loader for PyTorch testing,
            data loader for ONNX testing
        )
    """
    # Validate before touching any global state.
    if not pretrained and not ckpt_path and not ckpt_url:
        raise ValueError(
            "prepare_model: pretrained=False requires ckpt_path or ckpt_url")

    # Use 'spawn' to avoid CUDA reinitialization with forked subprocesses.
    # set_start_method raises RuntimeError if a start method is already set
    # (e.g. on a second call), so only set it when it is not already 'spawn'.
    if torch.multiprocessing.get_start_method(allow_none=True) != 'spawn':
        torch.multiprocessing.set_start_method('spawn', force=True)

    ## Initialize quantization, model and data loaders
    _set_default_quant_descriptors(per_channel_quantization, calibrator)

    if model_name in models.__dict__:
        model = models.__dict__[model_name](pretrained=pretrained,
                                            quantize=True)
    else:
        # Construct the torchvision model while quantized layers are patched in.
        quant_modules.initialize()
        try:
            model = torchvision.models.__dict__[model_name](pretrained=pretrained)
        finally:
            quant_modules.deactivate()

    if not pretrained:
        model.load_state_dict(_load_checkpoint(ckpt_path, ckpt_url))
    model.eval()
    model.cuda()

    ## Prepare the data loaders
    traindir = os.path.join(data_dir, 'train')
    valdir = os.path.join(data_dir, 'val')
    _args = collections.namedtuple("mock_args",
                                   ["model", "distributed", "cache_dataset"])
    dataset, dataset_test, train_sampler, test_sampler = load_data(
        traindir, valdir,
        _args(model=model_name, distributed=False, cache_dataset=False))

    data_loader_train = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size_train,
        sampler=train_sampler,
        num_workers=4,
        pin_memory=True)

    data_loader_test = torch.utils.data.DataLoader(dataset_test,
                                                   batch_size=batch_size_test,
                                                   sampler=test_sampler,
                                                   num_workers=4,
                                                   pin_memory=True)

    data_loader_onnx = torch.utils.data.DataLoader(dataset_test,
                                                   batch_size=batch_size_onnx,
                                                   sampler=test_sampler,
                                                   num_workers=4,
                                                   pin_memory=True)

    return model, data_loader_train, data_loader_test, data_loader_onnx


def _set_default_quant_descriptors(per_channel_quantization, calibrator):
    """Set default QuantDescriptors on the QuantConv2d/QuantConvTranspose2d/QuantLinear classes.

    Inputs always use ``calibrator``. In per-tensor mode the weights also get
    a per-tensor (``axis=None``) descriptor. In per-channel mode the weight
    descriptors are left untouched.
    """
    quant_layer_types = (quant_nn.QuantConv2d,
                         quant_nn.QuantConvTranspose2d,
                         quant_nn.QuantLinear)
    if per_channel_quantization:
        quant_desc_input = QuantDescriptor(calib_method=calibrator)
        # Keep the library's default weight descriptors. These are presumably
        # per-channel; confirm against pytorch_quantization defaults.
        quant_desc_weight = None
    else:
        ## Force per tensor quantization for onnx runtime
        quant_desc_input = QuantDescriptor(calib_method=calibrator, axis=None)
        quant_desc_weight = QuantDescriptor(calib_method=calibrator, axis=None)

    # Apply to every quantized layer type in both modes. Previously the
    # per-channel branch skipped QuantConvTranspose2d.
    for layer_type in quant_layer_types:
        layer_type.set_default_quant_desc_input(quant_desc_input)
        if quant_desc_weight is not None:
            layer_type.set_default_quant_desc_weight(quant_desc_weight)


def _load_checkpoint(ckpt_path, ckpt_url):
    """Load a checkpoint from ``ckpt_path`` (preferred) or ``ckpt_url`` and return its state dict.

    Checkpoints wrapped as ``{'state_dict': ...}`` or ``{'model': ...}`` are
    unwrapped; anything else is returned as-is.
    """
    # NOTE(review): both loaders unpickle the file; only use trusted checkpoints.
    # map_location='cpu' avoids failures when the checkpoint was saved on a GPU
    # that is absent here; the model is moved to CUDA afterwards anyway.
    if ckpt_path:
        checkpoint = torch.load(ckpt_path, map_location='cpu')
    else:
        checkpoint = load_state_dict_from_url(ckpt_url, map_location='cpu')
    if 'state_dict' in checkpoint:
        return checkpoint['state_dict']
    if 'model' in checkpoint:
        return checkpoint['model']
    return checkpoint