Code example #1
    def test_fake_quant_quant_per_channel_other_prec(self):
        kernel_size = 3

        quant_desc_input = QuantDescriptor(num_bits=4)
        quant_desc_weight = QuantDescriptor(num_bits=3, axis=0)

        quant_conv_object = quant_conv.QuantConv3d(
            _NUM_IN_CHANNELS,
            _NUM_OUT_CHANNELS,
            kernel_size,
            bias=False,
            quant_desc_input=quant_desc_input,
            quant_desc_weight=quant_desc_weight)
        test_input = torch.randn(16, _NUM_IN_CHANNELS, 8, 8, 8)

        test_input_quantizer = TensorQuantizer(quant_desc_input)
        weight_quantizer = TensorQuantizer(quant_desc_weight)

        quant_input = test_input_quantizer(test_input)

        weight_copy = quant_conv_object.weight.clone()
        quant_weight = weight_quantizer(weight_copy)

        out1 = F.conv3d(quant_input, quant_weight)
        out2 = quant_conv_object(test_input)
        np.testing.assert_array_equal(out1.detach().cpu().numpy(), out2.detach().cpu().numpy())
Code example #2
    def test_set_default_quant_desc(self):
        quant_conv_layer = quant_conv.Conv2d(32, 257, 3)
        assert quant_conv_layer.input_quantizer._axis is None
        assert quant_conv_layer.weight_quantizer._axis == 0

        # set default to a different one
        quant_desc_input = QuantDescriptor(num_bits=11)
        quant_desc_weight = QuantDescriptor(num_bits=13, axis=1)
        quant_conv.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
        quant_conv.QuantConv2d.set_default_quant_desc_weight(quant_desc_weight)

        # Create one with default descriptor
        quant_conv_layer = quant_conv.Conv2d(32, 257, 3)
        # Check quant_desc in quantizer created with default descriptor
        assert quant_conv_layer.input_quantizer._num_bits == quant_desc_input.num_bits
        assert quant_conv_layer.weight_quantizer._axis == quant_desc_weight.axis

        # Test default is per class
        quant_conv_layer = quant_conv.Conv3d(31, 255, 5)
        assert quant_conv_layer.input_quantizer._num_bits != quant_desc_input.num_bits
        assert quant_conv_layer.weight_quantizer._axis != quant_desc_weight.axis

        # Reset default
        quant_conv.QuantConv2d.set_default_quant_desc_input(QuantDescriptor())
        quant_conv.QuantConv2d.set_default_quant_desc_weight(QuantDescriptor(axis=0))
Code example #3
def set_default_quantizers(args):
    """Set default quantizers before creating the model."""

    if args.calibrator == 'max':
        calib_method = 'max'
    elif args.calibrator == 'percentile':
        if args.percentile is None:
            raise ValueError(
                'Specify --percentile when using percentile calibrator')
        calib_method = 'histogram'
    elif args.calibrator == 'mse':
        calib_method = 'histogram'
    elif args.calibrator == 'entropy':
        calib_method = 'histogram'
    else:
        raise ValueError(f'Invalid calibrator {args.calibrator}')

    input_desc = QuantDescriptor(
        num_bits=args.aprec,
        calib_method=calib_method,
        narrow_range=not args.quant_asymmetric,
    )
    weight_desc = QuantDescriptor(
        num_bits=args.wprec,
        axis=(None if args.quant_per_tensor else (0, )),
    )
    quant_nn.QuantLinear.set_default_quant_desc_input(input_desc)
    quant_nn.QuantLinear.set_default_quant_desc_weight(weight_desc)
Code example #4
    def test_against_unquantized(self):
        kernel_size = 3
        test_input = torch.randn(16, _NUM_IN_CHANNELS, 24, 24, 24).cuda()

        torch.manual_seed(1234)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(1234)
        fake_quant_conv3d = quant_conv.QuantConv3d(
            _NUM_IN_CHANNELS,
            _NUM_OUT_CHANNELS,
            kernel_size,
            bias=True,
            quant_desc_input=QuantDescriptor(num_bits=16),
            quant_desc_weight=QuantDescriptor(num_bits=16, axis=0)).cuda()

        # Reset seed. Make sure weight and bias are the same
        torch.manual_seed(1234)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(1234)
        conv3d = nn.Conv3d(_NUM_IN_CHANNELS, _NUM_OUT_CHANNELS, kernel_size, bias=True).cuda()

        fake_quant_output = fake_quant_conv3d(test_input)
        output = conv3d(test_input)

        test_utils.compare(fake_quant_output, output, rtol=1e-6, atol=2e-4)
Code example #5
    def test_calibration(self):
        quant_model = QuantLeNet(quant_desc_input=QuantDescriptor(),
                                 quant_desc_weight=QuantDescriptor()).cuda()

        for name, module in quant_model.named_modules():
            if name.endswith("_quantizer"):
                if module._calibrator is not None:
                    module.disable_quant()
                    module.enable_calib()
                else:
                    module.disable()
                print(F"{name:40}: {module}")

        quant_model(torch.rand(16, 1, 224, 224, device="cuda"))

        # Load calib result and disable calibration
        for name, module in quant_model.named_modules():
            if name.endswith("_quantizer"):
                if module._calibrator is not None:
                    module.load_calib_amax()
                    module.enable_quant()
                    module.disable_calib()
                else:
                    module.enable()
        quant_model.cuda()
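The enable/disable toggling above is the standard two-phase calibration flow: collect statistics first, then load amax and re-enable quantization. A minimal sketch factoring it into reusable helpers; the function names are illustrative, not library API:

def start_calibration(model):
    # Collect statistics instead of quantizing.
    for name, module in model.named_modules():
        if name.endswith("_quantizer"):
            if module._calibrator is not None:
                module.disable_quant()
                module.enable_calib()
            else:
                module.disable()

def finish_calibration(model):
    # Load amax from the calibrators and switch back to quantizing.
    for name, module in model.named_modules():
        if name.endswith("_quantizer"):
            if module._calibrator is not None:
                module.load_calib_amax()
                module.enable_quant()
                module.disable_calib()
            else:
                module.enable()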
Code example #6
    def prepare_config_and_inputs(self):
        # Set default quantizers before creating the model.
        import pytorch_quantization.nn as quant_nn
        from pytorch_quantization.tensor_quant import QuantDescriptor

        # The default tensor quantizer is set to use Max calibration method
        input_desc = QuantDescriptor(num_bits=8, calib_method="max")
        # The default tensor quantizer is set to be per-channel quantization for weights
        weight_desc = QuantDescriptor(num_bits=8, axis=(0,))
        quant_nn.QuantLinear.set_default_quant_desc_input(input_desc)
        quant_nn.QuantLinear.set_default_quant_desc_weight(weight_desc)
        # Since the QDQBert model is tested in a single run without calibration, use fake
        # quantization so the quantized tensors come out as float tensors in the end.
        quant_nn.TensorQuantizer.use_fb_fake_quant = True

        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        token_type_ids = None
        if self.use_token_type_ids:
            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = self.get_config()

        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
Code example #7
    def test_simple_build(self):
        """test instantiation"""
        quant_model = QuantLeNet(quant_desc_input=QuantDescriptor(), quant_desc_weight=QuantDescriptor())
        for name, module in quant_model.named_modules():
            if "quantizer" in name:
                module.disable()

        input_desc = tensor_quant.QUANT_DESC_8BIT_PER_TENSOR
        weight_desc = tensor_quant.QUANT_DESC_8BIT_PER_TENSOR
        quant_model = QuantLeNet(quant_desc_input=input_desc, quant_desc_weight=weight_desc)

        input_desc = QuantDescriptor(amax=6.)
        weight_desc = QuantDescriptor(amax=1.)
        quant_model = QuantLeNet(quant_desc_input=input_desc, quant_desc_weight=weight_desc)
Code example #8
    def test_fake_quant_quant_per_channel_bias(self):
        kernel_size = 3

        quant_conv_object = quant_conv.QuantConv3d(
            _NUM_IN_CHANNELS,
            _NUM_OUT_CHANNELS,
            kernel_size,
            bias=True,
            quant_desc_weight=QuantDescriptor(axis=0))
        test_input = torch.randn(8, _NUM_IN_CHANNELS, 8, 8, 8)

        quant_input = tensor_quant.fake_tensor_quant(
            test_input, torch.max(torch.abs(test_input)))

        weight_copy = quant_conv_object.weight.clone()
        quant_weight = tensor_quant.fake_tensor_quant(
            weight_copy,
            torch.max(torch.abs(weight_copy).view(_NUM_OUT_CHANNELS, -1),
                      dim=1,
                      keepdim=True)[0].view(_NUM_OUT_CHANNELS, 1, 1, 1, 1))

        out1 = F.conv3d(quant_input, quant_weight, bias=quant_conv_object.bias)
        out2 = quant_conv_object(test_input)
        np.testing.assert_array_equal(out1.detach().cpu().numpy(),
                                      out2.detach().cpu().numpy())
Code example #9
def select_default_calib_method(calib_method='histogram'):
    """Set up selected calibration method in whole network"""
    quant_desc_input = QuantDescriptor(calib_method=calib_method)
    quant_nn.QuantConv1d.set_default_quant_desc_input(quant_desc_input)
    quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
    quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)
    quant_nn.QuantAdaptiveAvgPool2d.set_default_quant_desc_input(
        quant_desc_input)
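Like set_default_quant_desc_input itself, this changes class-level defaults and therefore only affects layers created afterwards. A short usage sketch (layer shapes are arbitrary):

select_default_calib_method('max')      # or 'histogram', the default
layer = quant_nn.QuantConv2d(3, 16, 3)  # its input quantizer now carries a max calibrator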
Code example #10
    def test_inference_no_head_absolute_embedding(self):
        # Set default quantizers before creating the model.
        import pytorch_quantization.nn as quant_nn
        from pytorch_quantization.tensor_quant import QuantDescriptor

        # The default tensor quantizer is set to use Max calibration method
        input_desc = QuantDescriptor(num_bits=8, calib_method="max")
        # The default tensor quantizer is set to be per-channel quantization for weights
        weight_desc = QuantDescriptor(num_bits=8, axis=(0,))
        quant_nn.QuantLinear.set_default_quant_desc_input(input_desc)
        quant_nn.QuantLinear.set_default_quant_desc_weight(weight_desc)

        model = QDQBertModel.from_pretrained("bert-base-uncased")
        input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]])
        attention_mask = torch.tensor([[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
        output = model(input_ids, attention_mask=attention_mask)[0]
        expected_shape = torch.Size((1, 11, 768))
        self.assertEqual(output.shape, expected_shape)
        expected_slice = torch.tensor(
            [[[0.4571, -0.0735, 0.8594], [0.2774, -0.0278, 0.8794], [0.3548, -0.0473, 0.7593]]]
        )
        self.assertTrue(torch.allclose(output[:, 1:4, 1:4], expected_slice, atol=1e-4))
Code example #11
def set_default_quantizers(args):
    """Set default quantizers before creating the model."""

    if args.calibrator == "max":
        calib_method = "max"
    elif args.calibrator == "percentile":
        if args.percentile is None:
            raise ValueError(
                "Specify --percentile when using percentile calibrator")
        calib_method = "histogram"
    elif args.calibrator == "mse":
        calib_method = "histogram"
    else:
        raise ValueError(f"Invalid calibrator {args.calibrator}")

    input_desc = QuantDescriptor(num_bits=args.aprec,
                                 calib_method=calib_method)
    weight_desc = QuantDescriptor(num_bits=args.wprec,
                                  axis=(None if args.quant_per_tensor else
                                        (0, )))
    quant_nn.QuantLinear.set_default_quant_desc_input(input_desc)
    quant_nn.QuantLinear.set_default_quant_desc_weight(weight_desc)
Code example #12
    def __init__(self,
                 quant_desc=QuantDescriptor(),
                 disabled=False,
                 if_quant=True,
                 if_clip=False,
                 if_calib=False):
        """Initialize quantizer and set up required variables"""
        super(TensorQuantizer, self).__init__()
        # Expand quant_desc. Using quant_desc.__dict__ would be easier, but assigning fields one-by-one explicitly gives more control
        self._num_bits = quant_desc.num_bits
        self._fake_quant = quant_desc.fake_quant
        self._axis = quant_desc.axis
        self._scale_amax = quant_desc.scale_amax
        self._learn_amax = quant_desc.learn_amax
        self._unsigned = quant_desc.unsigned
        self._narrow_range = quant_desc.narrow_range

        self._scale = None if not quant_desc.fake_quant else 1.
        self._disabled = disabled
        self._if_quant = if_quant
        self._if_clip = False
        self._if_calib = if_calib

        if quant_desc.amax is not None:
            self.register_buffer('_amax', torch.tensor(quant_desc.amax))

        # Clip module consumes a lot of memory, so only create it if learn_amax is True
        if self._learn_amax:
            init_amax = quant_desc.amax if quant_desc.amax is not None else 1.
            self.clip = Clip(-init_amax,
                             init_amax,
                             learn_min=True,
                             learn_max=True)
            # It makes more sense to enable clip stage (which learns amax) if learn_amax is true
            self.enable_clip()
        if if_clip:
            self.enable_clip()

        if quant_desc.calib_method == "histogram":
            logging.info("Creating histogram calibrator")
            self._calibrator = calib.HistogramCalibrator(
                num_bits=self._num_bits,
                axis=self._axis,
                unsigned=self._unsigned)
        elif quant_desc.calib_method == "max":
            logging.info("Creating Max calibrator")
            self._calibrator = calib.MaxCalibrator(num_bits=self._num_bits,
                                                   axis=self._axis,
                                                   unsigned=self._unsigned)
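Given this constructor, a few illustrative instantiations built only from the descriptor fields shown above (the values are arbitrary):

# Fixed-range per-tensor quantizer: amax is registered as a buffer, no calibration needed.
tq_fixed = TensorQuantizer(QuantDescriptor(num_bits=8, amax=6.0))

# Per-channel (axis 0) quantizer; the default "max" calib_method supports per-axis amax.
tq_perchannel = TensorQuantizer(QuantDescriptor(num_bits=8, axis=0))

# Learned-range quantizer: creates the Clip module and enables the clip stage.
tq_learned = TensorQuantizer(QuantDescriptor(num_bits=8, amax=1.0, learn_amax=True))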
Code example #13
    def test_fake_quant_per_tensor(self):

        quant_instancenorm_object = quant_instancenorm.QuantInstanceNorm1d(
            NUM_CHANNELS, affine=True, quant_desc_input=QuantDescriptor())

        test_input = torch.randn(8, NUM_CHANNELS, 128)
        quant_input = tensor_quant.fake_tensor_quant(
            test_input, torch.max(torch.abs(test_input)))

        out1 = quant_instancenorm_object(test_input)
        out2 = F.instance_norm(quant_input,
                               quant_instancenorm_object.running_mean,
                               quant_instancenorm_object.running_var,
                               quant_instancenorm_object.weight,
                               quant_instancenorm_object.bias)
        np.testing.assert_array_equal(out1.detach().cpu().numpy(),
                                      out2.detach().cpu().numpy())
Code example #14
    def test_weight_fake_quant_per_tensor(self):
        kernel_size = 3

        quant_conv_object = quant_conv.QuantConv2d(
            _NUM_IN_CHANNELS,
            _NUM_OUT_CHANNELS,
            kernel_size,
            bias=False,
            quant_desc_weight=QuantDescriptor())
        quant_conv_object.input_quantizer.disable()
        test_input = torch.randn(16, _NUM_IN_CHANNELS, 256, 256)

        weight_copy = quant_conv_object.weight.clone()
        quant_weight = tensor_quant.fake_tensor_quant(weight_copy, torch.max(torch.abs(weight_copy)))

        out1 = F.conv2d(test_input, quant_weight)
        out2 = quant_conv_object(test_input)
        np.testing.assert_array_equal(out1.detach().cpu().numpy(), out2.detach().cpu().numpy())
Code example #15
    def test_weight_fake_quant_per_channel(self):
        kernel_size = 3

        quant_conv_object = quant_conv.QuantConv1d(
            _NUM_IN_CHANNELS,
            _NUM_OUT_CHANNELS,
            kernel_size,
            bias=False,
            quant_desc_weight=QuantDescriptor(axis=0))
        quant_conv_object.input_quantizer.disable()
        test_input = torch.randn(16, _NUM_IN_CHANNELS, 256)

        weight_copy = quant_conv_object.weight.clone()
        amax = quant_utils.reduce_amax(weight_copy, axis=(1, 2))
        quant_weight = tensor_quant.fake_tensor_quant(weight_copy, amax)

        out1 = F.conv1d(test_input, quant_weight)
        out2 = quant_conv_object(test_input)
        np.testing.assert_array_equal(out1.detach().cpu().numpy(), out2.detach().cpu().numpy())
Code example #16
    def test_fake_quant_per_channel_bias(self):
        kernel_size = 3

        quant_conv_object = quant_conv.QuantConvTranspose1d(
            _NUM_IN_CHANNELS,
            _NUM_OUT_CHANNELS,
            kernel_size,
            bias=True,
            quant_desc_weight=QuantDescriptor(axis=1))
        test_input = torch.randn(2, _NUM_IN_CHANNELS, 2)

        quant_input = tensor_quant.fake_tensor_quant(test_input, torch.max(torch.abs(test_input)))

        weight_copy = quant_conv_object.weight.clone()
        amax = quant_utils.reduce_amax(weight_copy, axis=(0, 2))
        quant_weight = tensor_quant.fake_tensor_quant(weight_copy, amax)

        out1 = F.conv_transpose1d(quant_input, quant_weight, bias=quant_conv_object.bias)
        out2 = quant_conv_object(test_input)
        np.testing.assert_array_equal(out1.detach().cpu().numpy(), out2.detach().cpu().numpy())
Code example #17
File: models.py  Project: leo-XUKANG/TensorRT-1
def quant_lenet():
    return QuantLeNet(quant_desc_input=QuantDescriptor(),
                      quant_desc_weight=QuantDescriptor())
Code example #18
import loadMnistData

torch.manual_seed(97)
np.random.seed(97)

nImageHeight = 28
nImageWidth = 28
nTrainBatchSize = 128
nCalibrationBatchSize = 4
onnxFile = "model.onnx"
trtFile = "./model.plan"
inputImage = dataPath + "8.png"
calibrator = ["max", "histogram"][1]
percentileList = [99.9, 99.99, 99.999, 99.9999]

quant_desc_input = QuantDescriptor(calib_method=calibrator, axis=None)
qnn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
qnn.QuantConvTranspose2d.set_default_quant_desc_input(quant_desc_input)
qnn.QuantLinear.set_default_quant_desc_input(quant_desc_input)
quant_desc_weight = QuantDescriptor(calib_method=calibrator, axis=None)
qnn.QuantConv2d.set_default_quant_desc_weight(quant_desc_weight)
qnn.QuantConvTranspose2d.set_default_quant_desc_weight(quant_desc_weight)
qnn.QuantLinear.set_default_quant_desc_weight(quant_desc_weight)


class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        #self.conv1 = torch.nn.Conv2d(1, 32, (5, 5), padding=(2, 2), bias=True)  # swapped for the corresponding Quantize-series API
        self.conv1 = qnn.QuantConv2d(1, 32, (5, 5), padding=(2, 2), bias=True)
        #self.conv2 = torch.nn.Conv2d(32, 64, (5, 5), padding=(2, 2), bias=True)
        self.conv2 = qnn.QuantConv2d(32, 64, (5, 5), padding=(2, 2), bias=True)
Code example #19
def main():
    parser = ArgumentParser()
    parser.add_argument(
        "--asr_model",
        type=str,
        default="QuartzNet15x5Base-En",
        required=True,
        help="Pass: '******'",
    )
    parser.add_argument("--dataset",
                        type=str,
                        required=True,
                        help="path to evaluation data")
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument(
        "--normalize_text",
        default=True,
        type=bool,
        help="Normalize transcripts or not. Set to False for non-English.")
    parser.add_argument('--num_calib_batch',
                        default=1,
                        type=int,
                        help="Number of batches for calibration.")
    parser.add_argument('--calibrator',
                        type=str,
                        choices=["max", "histogram"],
                        default="max")
    parser.add_argument('--percentile',
                        nargs='+',
                        type=float,
                        default=[99.9, 99.99, 99.999, 99.9999])
    parser.add_argument("--amp",
                        action="store_true",
                        help="Use AMP in calibration.")
    parser.set_defaults(amp=False)

    args = parser.parse_args()
    torch.set_grad_enabled(False)

    # Initialize quantization
    quant_desc_input = QuantDescriptor(calib_method=args.calibrator)
    quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
    quant_nn.QuantConvTranspose2d.set_default_quant_desc_input(
        quant_desc_input)
    quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)

    if args.asr_model.endswith('.nemo'):
        logging.info(f"Using local ASR model from {args.asr_model}")
        asr_model_cfg = EncDecCTCModel.restore_from(
            restore_path=args.asr_model, return_config=True)
        with open_dict(asr_model_cfg):
            asr_model_cfg.encoder.quantize = True
        asr_model = EncDecCTCModel.restore_from(
            restore_path=args.asr_model, override_config_path=asr_model_cfg)

    else:
        logging.info(f"Using NGC cloud ASR model {args.asr_model}")
        asr_model_cfg = EncDecCTCModel.from_pretrained(
            model_name=args.asr_model, return_config=True)
        with open_dict(asr_model_cfg):
            asr_model_cfg.encoder.quantize = True
        asr_model = EncDecCTCModel.from_pretrained(
            model_name=args.asr_model, override_config_path=asr_model_cfg)

    asr_model.setup_test_data(
        test_data_config={
            'sample_rate': 16000,
            'manifest_filepath': args.dataset,
            'labels': asr_model.decoder.vocabulary,
            'batch_size': args.batch_size,
            'normalize_transcripts': args.normalize_text,
            'shuffle': True,
        })

    if can_gpu:
        asr_model = asr_model.cuda()
    asr_model.eval()

    # Enable calibrators
    for name, module in asr_model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.disable_quant()
                module.enable_calib()
            else:
                module.disable()

    for i, test_batch in enumerate(asr_model.test_dataloader()):
        if can_gpu:
            test_batch = [x.cuda() for x in test_batch]
        if args.amp:
            with autocast():
                _ = asr_model(input_signal=test_batch[0],
                              input_signal_length=test_batch[1])
        else:
            _ = asr_model(input_signal=test_batch[0],
                          input_signal_length=test_batch[1])
        if i >= args.num_calib_batch:
            break

    # Save calibrated model(s)
    model_name = args.asr_model.replace(
        ".nemo", "") if args.asr_model.endswith(".nemo") else args.asr_model
    if args.calibrator != "histogram":
        compute_amax(asr_model, method="max")
        asr_model.save_to(
            f"{model_name}-max-{args.num_calib_batch*args.batch_size}.nemo")
    else:
        for percentile in args.percentile:
            print(f"{percentile} percentile calibration")
            compute_amax(asr_model, method="percentile", percentile=percentile)
            asr_model.save_to(
                f"{model_name}-percentile-{percentile}-{args.num_calib_batch*args.batch_size}.nemo"
            )

        for method in ["mse", "entropy"]:
            print(f"{method} calibration")
            compute_amax(asr_model, method=method)
            asr_model.save_to(
                f"{model_name}-{method}-{args.num_calib_batch*args.batch_size}.nemo"
            )
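compute_amax is not defined in this snippet; the following is a minimal sketch of what such a helper conventionally does, assumed from the load_calib_amax pattern in code example #5 rather than taken from the original script:

def compute_amax(model, **kwargs):
    # Assumed helper: load calibration results into every TensorQuantizer,
    # forwarding method/percentile kwargs to histogram-based calibrators.
    from pytorch_quantization import calib
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                if isinstance(module._calibrator, calib.MaxCalibrator):
                    module.load_calib_amax()
                else:
                    module.load_calib_amax(**kwargs)
    if can_gpu:
        model.cuda()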
Code example #20
def prepare_model(model_name,
                  data_dir,
                  per_channel_quantization,
                  batch_size_train,
                  batch_size_test,
                  batch_size_onnx,
                  calibrator,
                  pretrained=True,
                  ckpt_path=None,
                  ckpt_url=None):
    """
    Prepare the model for the classification flow.
    Arguments:
        model_name: name to use when accessing torchvision model dictionary
        data_dir: directory with train and val subdirs prepared "imagenet style"
        per_channel_quantization: if true, use per-channel quantization for weights
                                   (note that this isn't currently supported in ONNX-RT/PyTorch)
        batch_size_train: batch size to use when training
        batch_size_test: batch size to use when testing in Pytorch
        batch_size_onnx: batch size to use when testing with ONNX-RT
        calibrator: calibration type to use (max/histogram)

        pretrained: if true a pretrained model will be loaded from torchvision
        ckpt_path: path to load a model checkpoint from, if not pretrained
        ckpt_url: url to download a model checkpoint from, if not pretrained and no path was given
        * at least one of {pretrained, path, url} must be valid

    The method returns the following list:
        [
            Model object,
            data loader for training,
            data loader for Pytorch testing,
            data loader for onnx testing
        ]
    """
    # Use 'spawn' to avoid CUDA reinitialization with forked subprocess
    torch.multiprocessing.set_start_method('spawn')

    ## Initialize quantization, model and data loaders
    if per_channel_quantization:
        quant_desc_input = QuantDescriptor(calib_method=calibrator)
        quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
        quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)
    else:
        ## Force per tensor quantization for onnx runtime
        quant_desc_input = QuantDescriptor(calib_method=calibrator, axis=None)
        quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
        quant_nn.QuantConvTranspose2d.set_default_quant_desc_input(
            quant_desc_input)
        quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)

        quant_desc_weight = QuantDescriptor(calib_method=calibrator, axis=None)
        quant_nn.QuantConv2d.set_default_quant_desc_weight(quant_desc_weight)
        quant_nn.QuantConvTranspose2d.set_default_quant_desc_weight(
            quant_desc_weight)
        quant_nn.QuantLinear.set_default_quant_desc_weight(quant_desc_weight)

    quant_modules.initialize()

    model = torchvision.models.__dict__[model_name](pretrained=pretrained)
    if not pretrained:
        if ckpt_path:
            checkpoint = torch.load(ckpt_path)
        else:
            checkpoint = load_state_dict_from_url(ckpt_url)
        if 'state_dict' in checkpoint.keys():
            checkpoint = checkpoint['state_dict']
        elif 'model' in checkpoint.keys():
            checkpoint = checkpoint['model']
        model.load_state_dict(checkpoint)
    model.eval()
    model.cuda()

    ## Prepare the data loaders
    traindir = os.path.join(data_dir, 'train')
    valdir = os.path.join(data_dir, 'val')
    dataset, dataset_test, train_sampler, test_sampler = load_data(
        traindir, valdir, False, False)

    data_loader_train = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size_train,
        sampler=train_sampler,
        num_workers=16,
        pin_memory=True)

    data_loader_test = torch.utils.data.DataLoader(dataset_test,
                                                   batch_size=batch_size_test,
                                                   sampler=test_sampler,
                                                   num_workers=4,
                                                   pin_memory=True)

    data_loader_onnx = torch.utils.data.DataLoader(dataset_test,
                                                   batch_size=batch_size_onnx,
                                                   sampler=test_sampler,
                                                   num_workers=4,
                                                   pin_memory=True)

    return model, data_loader_train, data_loader_test, data_loader_onnx
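A hypothetical invocation; the model name, data directory, and batch sizes below are placeholders:

model, dl_train, dl_test, dl_onnx = prepare_model(
    model_name='resnet50',           # any torchvision classification model
    data_dir='/data/imagenet',       # must contain train/ and val/ subdirs
    per_channel_quantization=False,  # per-tensor, for ONNX-RT compatibility
    batch_size_train=64,
    batch_size_test=64,
    batch_size_onnx=1,
    calibrator='histogram')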
Code example #21
    def test_simple(self):
        quant_lenet = QuantLeNet(quant_desc_input=QuantDescriptor(),
                                 quant_desc_weight=QuantDescriptor())
        quant_lenet.eval()
        helper.quant_weight_inplace(quant_lenet)